From 0590d736fa61833c8f8591f7aa3bbea88b8274f9 Mon Sep 17 00:00:00 2001 From: Dylan Buchi Date: Thu, 28 Oct 2021 17:53:02 -0300 Subject: [PATCH] [mypy] Fix type annotations in `wavelet_tree.py` (#5641) * [mypy] Fix type annotations for wavelet_tree.py * fix a typo --- data_structures/binary_tree/wavelet_tree.py | 22 ++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/data_structures/binary_tree/wavelet_tree.py b/data_structures/binary_tree/wavelet_tree.py index 173a88ab7..8d7145189 100644 --- a/data_structures/binary_tree/wavelet_tree.py +++ b/data_structures/binary_tree/wavelet_tree.py @@ -31,7 +31,7 @@ class Node: return f"min_value: {self.minn}, max_value: {self.maxx}" -def build_tree(arr: list[int]) -> Node: +def build_tree(arr: list[int]) -> Node | None: """ Builds the tree for arr and returns the root of the constructed tree @@ -51,7 +51,10 @@ def build_tree(arr: list[int]) -> Node: then recursively build trees for left_arr and right_arr """ pivot = (root.minn + root.maxx) // 2 - left_arr, right_arr = [], [] + + left_arr: list[int] = [] + right_arr: list[int] = [] + for index, num in enumerate(arr): if num <= pivot: left_arr.append(num) @@ -63,7 +66,7 @@ def build_tree(arr: list[int]) -> Node: return root -def rank_till_index(node: Node, num: int, index: int) -> int: +def rank_till_index(node: Node | None, num: int, index: int) -> int: """ Returns the number of occurrences of num in interval [0, index] in the list @@ -79,7 +82,7 @@ def rank_till_index(node: Node, num: int, index: int) -> int: >>> rank_till_index(root, 0, 9) 1 """ - if index < 0: + if index < 0 or node is None: return 0 # Leaf node cases if node.minn == node.maxx: @@ -93,7 +96,7 @@ def rank_till_index(node: Node, num: int, index: int) -> int: return rank_till_index(node.right, num, index - node.map_left[index]) -def rank(node: Node, num: int, start: int, end: int) -> int: +def rank(node: Node | None, num: int, start: int, end: int) -> int: """ Returns the number of occurrences of num in interval [start, end] in the list @@ -114,7 +117,7 @@ def rank(node: Node, num: int, start: int, end: int) -> int: return rank_till_end - rank_before_start -def quantile(node: Node, index: int, start: int, end: int) -> int: +def quantile(node: Node | None, index: int, start: int, end: int) -> int: """ Returns the index'th smallest element in interval [start, end] in the list index is 0-indexed @@ -129,7 +132,7 @@ def quantile(node: Node, index: int, start: int, end: int) -> int: >>> quantile(root, 4, 2, 5) -1 """ - if index > (end - start) or start > end: + if index > (end - start) or start > end or node is None: return -1 # Leaf node case if node.minn == node.maxx: @@ -155,10 +158,10 @@ def quantile(node: Node, index: int, start: int, end: int) -> int: def range_counting( - node: Node, start: int, end: int, start_num: int, end_num: int + node: Node | None, start: int, end: int, start_num: int, end_num: int ) -> int: """ - Returns the number of elememts in range [start_num, end_num] + Returns the number of elements in range [start_num, end_num] in interval [start, end] in the list >>> root = build_tree(test_array) @@ -175,6 +178,7 @@ def range_counting( """ if ( start > end + or node is None or start_num > end_num or node.minn > end_num or node.maxx < start_num