From 2595cf059d677c39513a9d75f1736bc5b84d6298 Mon Sep 17 00:00:00 2001 From: Hao LI <8520588+Leo-LiHao@users.noreply.github.com> Date: Fri, 5 Feb 2021 00:59:38 +0800 Subject: [PATCH] [mypy] Add/fix type annotations for binary trees in data structures (#4085) * fix mypy: data_structures:binary_tree * mypy --strict for binary_trees in data_structures * fix pre-commit Co-authored-by: LiHao --- .../binary_search_tree_recursive.py | 99 ++++++++++++------- .../binary_tree/lazy_segment_tree.py | 9 +- data_structures/binary_tree/treap.py | 33 ++++--- 3 files changed, 84 insertions(+), 57 deletions(-) diff --git a/data_structures/binary_tree/binary_search_tree_recursive.py b/data_structures/binary_tree/binary_search_tree_recursive.py index f1e46e33c..a05e28a7b 100644 --- a/data_structures/binary_tree/binary_search_tree_recursive.py +++ b/data_structures/binary_tree/binary_search_tree_recursive.py @@ -8,21 +8,22 @@ To run an example: python binary_search_tree_recursive.py """ import unittest +from typing import Iterator, Optional class Node: - def __init__(self, label: int, parent): + def __init__(self, label: int, parent: Optional["Node"]) -> None: self.label = label self.parent = parent - self.left = None - self.right = None + self.left: Optional[Node] = None + self.right: Optional[Node] = None class BinarySearchTree: - def __init__(self): - self.root = None + def __init__(self) -> None: + self.root: Optional[Node] = None - def empty(self): + def empty(self) -> None: """ Empties the tree @@ -46,7 +47,7 @@ class BinarySearchTree: """ return self.root is None - def put(self, label: int): + def put(self, label: int) -> None: """ Put a new node in the tree @@ -65,7 +66,9 @@ class BinarySearchTree: """ self.root = self._put(self.root, label) - def _put(self, node: Node, label: int, parent: Node = None) -> Node: + def _put( + self, node: Optional[Node], label: int, parent: Optional[Node] = None + ) -> Node: if node is None: node = Node(label, parent) else: @@ -95,7 +98,7 @@ class BinarySearchTree: """ return self._search(self.root, label) - def _search(self, node: Node, label: int) -> Node: + def _search(self, node: Optional[Node], label: int) -> Node: if node is None: raise Exception(f"Node with label {label} does not exist") else: @@ -106,7 +109,7 @@ class BinarySearchTree: return node - def remove(self, label: int): + def remove(self, label: int) -> None: """ Removes a node in the tree @@ -122,13 +125,7 @@ class BinarySearchTree: Exception: Node with label 3 does not exist """ node = self.search(label) - if not node.right and not node.left: - self._reassign_nodes(node, None) - elif not node.right and node.left: - self._reassign_nodes(node, node.left) - elif node.right and not node.left: - self._reassign_nodes(node, node.right) - else: + if node.right and node.left: lowest_node = self._get_lowest_node(node.right) lowest_node.left = node.left lowest_node.right = node.right @@ -136,8 +133,14 @@ class BinarySearchTree: if node.right: node.right.parent = lowest_node self._reassign_nodes(node, lowest_node) + elif not node.right and node.left: + self._reassign_nodes(node, node.left) + elif node.right and not node.left: + self._reassign_nodes(node, node.right) + else: + self._reassign_nodes(node, None) - def _reassign_nodes(self, node: Node, new_children: Node): + def _reassign_nodes(self, node: Node, new_children: Optional[Node]) -> None: if new_children: new_children.parent = node.parent @@ -192,7 +195,7 @@ class BinarySearchTree: >>> t.get_max_label() 10 """ - if self.is_empty(): + if self.root is None: raise Exception("Binary search tree is empty") node = self.root @@ -216,7 +219,7 @@ class BinarySearchTree: >>> t.get_min_label() 8 """ - if self.is_empty(): + if self.root is None: raise Exception("Binary search tree is empty") node = self.root @@ -225,7 +228,7 @@ class BinarySearchTree: return node.label - def inorder_traversal(self) -> list: + def inorder_traversal(self) -> Iterator[Node]: """ Return the inorder traversal of the tree @@ -241,13 +244,13 @@ class BinarySearchTree: """ return self._inorder_traversal(self.root) - def _inorder_traversal(self, node: Node) -> list: + def _inorder_traversal(self, node: Optional[Node]) -> Iterator[Node]: if node is not None: yield from self._inorder_traversal(node.left) yield node yield from self._inorder_traversal(node.right) - def preorder_traversal(self) -> list: + def preorder_traversal(self) -> Iterator[Node]: """ Return the preorder traversal of the tree @@ -263,7 +266,7 @@ class BinarySearchTree: """ return self._preorder_traversal(self.root) - def _preorder_traversal(self, node: Node) -> list: + def _preorder_traversal(self, node: Optional[Node]) -> Iterator[Node]: if node is not None: yield node yield from self._preorder_traversal(node.left) @@ -272,7 +275,7 @@ class BinarySearchTree: class BinarySearchTreeTest(unittest.TestCase): @staticmethod - def _get_binary_search_tree(): + def _get_binary_search_tree() -> BinarySearchTree: r""" 8 / \ @@ -298,7 +301,7 @@ class BinarySearchTreeTest(unittest.TestCase): return t - def test_put(self): + def test_put(self) -> None: t = BinarySearchTree() assert t.is_empty() @@ -306,6 +309,7 @@ class BinarySearchTreeTest(unittest.TestCase): r""" 8 """ + assert t.root is not None assert t.root.parent is None assert t.root.label == 8 @@ -315,6 +319,7 @@ class BinarySearchTreeTest(unittest.TestCase): \ 10 """ + assert t.root.right is not None assert t.root.right.parent == t.root assert t.root.right.label == 10 @@ -324,6 +329,7 @@ class BinarySearchTreeTest(unittest.TestCase): / \ 3 10 """ + assert t.root.left is not None assert t.root.left.parent == t.root assert t.root.left.label == 3 @@ -335,6 +341,7 @@ class BinarySearchTreeTest(unittest.TestCase): \ 6 """ + assert t.root.left.right is not None assert t.root.left.right.parent == t.root.left assert t.root.left.right.label == 6 @@ -346,13 +353,14 @@ class BinarySearchTreeTest(unittest.TestCase): / \ 1 6 """ + assert t.root.left.left is not None assert t.root.left.left.parent == t.root.left assert t.root.left.left.label == 1 with self.assertRaises(Exception): t.put(1) - def test_search(self): + def test_search(self) -> None: t = self._get_binary_search_tree() node = t.search(6) @@ -364,7 +372,7 @@ class BinarySearchTreeTest(unittest.TestCase): with self.assertRaises(Exception): t.search(2) - def test_remove(self): + def test_remove(self) -> None: t = self._get_binary_search_tree() t.remove(13) @@ -379,6 +387,9 @@ class BinarySearchTreeTest(unittest.TestCase): \ 5 """ + assert t.root is not None + assert t.root.right is not None + assert t.root.right.right is not None assert t.root.right.right.right is None assert t.root.right.right.left is None @@ -394,6 +405,9 @@ class BinarySearchTreeTest(unittest.TestCase): \ 5 """ + assert t.root.left is not None + assert t.root.left.right is not None + assert t.root.left.right.left is not None assert t.root.left.right.right is None assert t.root.left.right.left.label == 4 @@ -407,6 +421,8 @@ class BinarySearchTreeTest(unittest.TestCase): \ 5 """ + assert t.root.left.left is not None + assert t.root.left.right.right is not None assert t.root.left.left.label == 1 assert t.root.left.right.label == 4 assert t.root.left.right.right.label == 5 @@ -422,6 +438,7 @@ class BinarySearchTreeTest(unittest.TestCase): / \ \ 1 5 14 """ + assert t.root is not None assert t.root.left.label == 4 assert t.root.left.right.label == 5 assert t.root.left.left.label == 1 @@ -437,13 +454,15 @@ class BinarySearchTreeTest(unittest.TestCase): / \ 1 14 """ + assert t.root.left is not None + assert t.root.left.left is not None assert t.root.left.label == 5 assert t.root.left.right is None assert t.root.left.left.label == 1 assert t.root.left.parent == t.root assert t.root.left.left.parent == t.root.left - def test_remove_2(self): + def test_remove_2(self) -> None: t = self._get_binary_search_tree() t.remove(3) @@ -456,6 +475,12 @@ class BinarySearchTreeTest(unittest.TestCase): / \ / 5 7 13 """ + assert t.root is not None + assert t.root.left is not None + assert t.root.left.left is not None + assert t.root.left.right is not None + assert t.root.left.right.left is not None + assert t.root.left.right.right is not None assert t.root.left.label == 4 assert t.root.left.right.label == 6 assert t.root.left.left.label == 1 @@ -466,25 +491,25 @@ class BinarySearchTreeTest(unittest.TestCase): assert t.root.left.left.parent == t.root.left assert t.root.left.right.left.parent == t.root.left.right - def test_empty(self): + def test_empty(self) -> None: t = self._get_binary_search_tree() t.empty() assert t.root is None - def test_is_empty(self): + def test_is_empty(self) -> None: t = self._get_binary_search_tree() assert not t.is_empty() t.empty() assert t.is_empty() - def test_exists(self): + def test_exists(self) -> None: t = self._get_binary_search_tree() assert t.exists(6) assert not t.exists(-1) - def test_get_max_label(self): + def test_get_max_label(self) -> None: t = self._get_binary_search_tree() assert t.get_max_label() == 14 @@ -493,7 +518,7 @@ class BinarySearchTreeTest(unittest.TestCase): with self.assertRaises(Exception): t.get_max_label() - def test_get_min_label(self): + def test_get_min_label(self) -> None: t = self._get_binary_search_tree() assert t.get_min_label() == 1 @@ -502,20 +527,20 @@ class BinarySearchTreeTest(unittest.TestCase): with self.assertRaises(Exception): t.get_min_label() - def test_inorder_traversal(self): + def test_inorder_traversal(self) -> None: t = self._get_binary_search_tree() inorder_traversal_nodes = [i.label for i in t.inorder_traversal()] assert inorder_traversal_nodes == [1, 3, 4, 5, 6, 7, 8, 10, 13, 14] - def test_preorder_traversal(self): + def test_preorder_traversal(self) -> None: t = self._get_binary_search_tree() preorder_traversal_nodes = [i.label for i in t.preorder_traversal()] assert preorder_traversal_nodes == [8, 3, 1, 6, 4, 5, 7, 10, 14, 13] -def binary_search_tree_example(): +def binary_search_tree_example() -> None: r""" Example 8 diff --git a/data_structures/binary_tree/lazy_segment_tree.py b/data_structures/binary_tree/lazy_segment_tree.py index 5bc79e74e..9066db294 100644 --- a/data_structures/binary_tree/lazy_segment_tree.py +++ b/data_structures/binary_tree/lazy_segment_tree.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +from typing import List, Union class SegmentTree: @@ -37,7 +38,7 @@ class SegmentTree: return idx * 2 + 1 def build( - self, idx: int, left_element: int, right_element: int, A: list[int] + self, idx: int, left_element: int, right_element: int, A: List[int] ) -> None: if left_element == right_element: self.segment_tree[idx] = A[left_element - 1] @@ -88,7 +89,7 @@ class SegmentTree: # query with O(lg n) def query( self, idx: int, left_element: int, right_element: int, a: int, b: int - ) -> int: + ) -> Union[int, float]: """ query(1, 1, size, a, b) for query max of [a,b] >>> A = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] @@ -118,8 +119,8 @@ class SegmentTree: q2 = self.query(self.right(idx), mid + 1, right_element, a, b) return max(q1, q2) - def __str__(self) -> None: - return [self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)] + def __str__(self) -> str: + return str([self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)]) if __name__ == "__main__": diff --git a/data_structures/binary_tree/treap.py b/data_structures/binary_tree/treap.py index 26648f7ab..a09dcc928 100644 --- a/data_structures/binary_tree/treap.py +++ b/data_structures/binary_tree/treap.py @@ -3,6 +3,7 @@ from __future__ import annotations from random import random +from typing import Optional, Tuple class Node: @@ -11,13 +12,13 @@ class Node: Treap is a binary tree by value and heap by priority """ - def __init__(self, value: int = None): + def __init__(self, value: Optional[int] = None): self.value = value self.prior = random() - self.left = None - self.right = None + self.left: Optional[Node] = None + self.right: Optional[Node] = None - def __repr__(self): + def __repr__(self) -> str: from pprint import pformat if self.left is None and self.right is None: @@ -27,14 +28,14 @@ class Node: {f"{self.value}: {self.prior:.5}": (self.left, self.right)}, indent=1 ) - def __str__(self): + def __str__(self) -> str: value = str(self.value) + " " left = str(self.left or "") right = str(self.right or "") return value + left + right -def split(root: Node, value: int) -> tuple[Node, Node]: +def split(root: Optional[Node], value: int) -> Tuple[Optional[Node], Optional[Node]]: """ We split current tree into 2 trees with value: @@ -42,9 +43,9 @@ def split(root: Node, value: int) -> tuple[Node, Node]: Right tree contains all values greater or equal, than split value """ if root is None: # None tree is split into 2 Nones - return (None, None) + return None, None elif root.value is None: - return (None, None) + return None, None else: if value < root.value: """ @@ -54,16 +55,16 @@ def split(root: Node, value: int) -> tuple[Node, Node]: Right tree's left son: right part of that split """ left, root.left = split(root.left, value) - return (left, root) + return left, root else: """ Just symmetric to previous case """ root.right, right = split(root.right, value) - return (root, right) + return root, right -def merge(left: Node, right: Node) -> Node: +def merge(left: Optional[Node], right: Optional[Node]) -> Optional[Node]: """ We merge 2 trees into one. Note: all left tree's values must be less than all right tree's @@ -85,7 +86,7 @@ def merge(left: Node, right: Node) -> Node: return right -def insert(root: Node, value: int) -> Node: +def insert(root: Optional[Node], value: int) -> Optional[Node]: """ Insert element @@ -98,7 +99,7 @@ def insert(root: Node, value: int) -> Node: return merge(merge(left, node), right) -def erase(root: Node, value: int) -> Node: +def erase(root: Optional[Node], value: int) -> Optional[Node]: """ Erase element @@ -111,7 +112,7 @@ def erase(root: Node, value: int) -> Node: return merge(left, right) -def inorder(root: Node): +def inorder(root: Optional[Node]) -> None: """ Just recursive print of a tree """ @@ -123,7 +124,7 @@ def inorder(root: Node): inorder(root.right) -def interactTreap(root, args): +def interactTreap(root: Optional[Node], args: str) -> Optional[Node]: """ Commands: + value to add value into treap @@ -160,7 +161,7 @@ def interactTreap(root, args): return root -def main(): +def main() -> None: """After each command, program prints treap""" root = None print(