[mypy] Fix type annotations in wavelet_tree.py (#5641)

* [mypy] Fix type annotations for wavelet_tree.py

* fix a typo
This commit is contained in:
Dylan Buchi 2021-10-28 17:53:02 -03:00 committed by GitHub
parent 61e1dd27b0
commit 0590d736fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -31,7 +31,7 @@ class Node:
return f"min_value: {self.minn}, max_value: {self.maxx}"
def build_tree(arr: list[int]) -> Node:
def build_tree(arr: list[int]) -> Node | None:
"""
Builds the tree for arr and returns the root
of the constructed tree
@ -51,7 +51,10 @@ def build_tree(arr: list[int]) -> Node:
then recursively build trees for left_arr and right_arr
"""
pivot = (root.minn + root.maxx) // 2
left_arr, right_arr = [], []
left_arr: list[int] = []
right_arr: list[int] = []
for index, num in enumerate(arr):
if num <= pivot:
left_arr.append(num)
@ -63,7 +66,7 @@ def build_tree(arr: list[int]) -> Node:
return root
def rank_till_index(node: Node, num: int, index: int) -> int:
def rank_till_index(node: Node | None, num: int, index: int) -> int:
"""
Returns the number of occurrences of num in interval [0, index] in the list
@ -79,7 +82,7 @@ def rank_till_index(node: Node, num: int, index: int) -> int:
>>> rank_till_index(root, 0, 9)
1
"""
if index < 0:
if index < 0 or node is None:
return 0
# Leaf node cases
if node.minn == node.maxx:
@ -93,7 +96,7 @@ def rank_till_index(node: Node, num: int, index: int) -> int:
return rank_till_index(node.right, num, index - node.map_left[index])
def rank(node: Node, num: int, start: int, end: int) -> int:
def rank(node: Node | None, num: int, start: int, end: int) -> int:
"""
Returns the number of occurrences of num in interval [start, end] in the list
@ -114,7 +117,7 @@ def rank(node: Node, num: int, start: int, end: int) -> int:
return rank_till_end - rank_before_start
def quantile(node: Node, index: int, start: int, end: int) -> int:
def quantile(node: Node | None, index: int, start: int, end: int) -> int:
"""
Returns the index'th smallest element in interval [start, end] in the list
index is 0-indexed
@ -129,7 +132,7 @@ def quantile(node: Node, index: int, start: int, end: int) -> int:
>>> quantile(root, 4, 2, 5)
-1
"""
if index > (end - start) or start > end:
if index > (end - start) or start > end or node is None:
return -1
# Leaf node case
if node.minn == node.maxx:
@ -155,10 +158,10 @@ def quantile(node: Node, index: int, start: int, end: int) -> int:
def range_counting(
node: Node, start: int, end: int, start_num: int, end_num: int
node: Node | None, start: int, end: int, start_num: int, end_num: int
) -> int:
"""
Returns the number of elememts in range [start_num, end_num]
Returns the number of elements in range [start_num, end_num]
in interval [start, end] in the list
>>> root = build_tree(test_array)
@ -175,6 +178,7 @@ def range_counting(
"""
if (
start > end
or node is None
or start_num > end_num
or node.minn > end_num
or node.maxx < start_num