[mypy] Add/fix type annotations for binary trees in data structures (#4085)

* fix mypy: data_structures:binary_tree

* mypy --strict for binary_trees in data_structures

* fix pre-commit

Co-authored-by: LiHao <leo_how@163.com>
This commit is contained in:
Hao LI 2021-02-05 00:59:38 +08:00 committed by GitHub
parent 97b6ca2b19
commit 2595cf059d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 57 deletions

View File

@ -8,21 +8,22 @@ To run an example:
python binary_search_tree_recursive.py
"""
import unittest
from typing import Iterator, Optional
class Node:
def __init__(self, label: int, parent):
def __init__(self, label: int, parent: Optional["Node"]) -> None:
self.label = label
self.parent = parent
self.left = None
self.right = None
self.left: Optional[Node] = None
self.right: Optional[Node] = None
class BinarySearchTree:
def __init__(self):
self.root = None
def __init__(self) -> None:
self.root: Optional[Node] = None
def empty(self):
def empty(self) -> None:
"""
Empties the tree
@ -46,7 +47,7 @@ class BinarySearchTree:
"""
return self.root is None
def put(self, label: int):
def put(self, label: int) -> None:
"""
Put a new node in the tree
@ -65,7 +66,9 @@ class BinarySearchTree:
"""
self.root = self._put(self.root, label)
def _put(self, node: Node, label: int, parent: Node = None) -> Node:
def _put(
self, node: Optional[Node], label: int, parent: Optional[Node] = None
) -> Node:
if node is None:
node = Node(label, parent)
else:
@ -95,7 +98,7 @@ class BinarySearchTree:
"""
return self._search(self.root, label)
def _search(self, node: Node, label: int) -> Node:
def _search(self, node: Optional[Node], label: int) -> Node:
if node is None:
raise Exception(f"Node with label {label} does not exist")
else:
@ -106,7 +109,7 @@ class BinarySearchTree:
return node
def remove(self, label: int):
def remove(self, label: int) -> None:
"""
Removes a node in the tree
@ -122,13 +125,7 @@ class BinarySearchTree:
Exception: Node with label 3 does not exist
"""
node = self.search(label)
if not node.right and not node.left:
self._reassign_nodes(node, None)
elif not node.right and node.left:
self._reassign_nodes(node, node.left)
elif node.right and not node.left:
self._reassign_nodes(node, node.right)
else:
if node.right and node.left:
lowest_node = self._get_lowest_node(node.right)
lowest_node.left = node.left
lowest_node.right = node.right
@ -136,8 +133,14 @@ class BinarySearchTree:
if node.right:
node.right.parent = lowest_node
self._reassign_nodes(node, lowest_node)
elif not node.right and node.left:
self._reassign_nodes(node, node.left)
elif node.right and not node.left:
self._reassign_nodes(node, node.right)
else:
self._reassign_nodes(node, None)
def _reassign_nodes(self, node: Node, new_children: Node):
def _reassign_nodes(self, node: Node, new_children: Optional[Node]) -> None:
if new_children:
new_children.parent = node.parent
@ -192,7 +195,7 @@ class BinarySearchTree:
>>> t.get_max_label()
10
"""
if self.is_empty():
if self.root is None:
raise Exception("Binary search tree is empty")
node = self.root
@ -216,7 +219,7 @@ class BinarySearchTree:
>>> t.get_min_label()
8
"""
if self.is_empty():
if self.root is None:
raise Exception("Binary search tree is empty")
node = self.root
@ -225,7 +228,7 @@ class BinarySearchTree:
return node.label
def inorder_traversal(self) -> list:
def inorder_traversal(self) -> Iterator[Node]:
"""
Return the inorder traversal of the tree
@ -241,13 +244,13 @@ class BinarySearchTree:
"""
return self._inorder_traversal(self.root)
def _inorder_traversal(self, node: Node) -> list:
def _inorder_traversal(self, node: Optional[Node]) -> Iterator[Node]:
if node is not None:
yield from self._inorder_traversal(node.left)
yield node
yield from self._inorder_traversal(node.right)
def preorder_traversal(self) -> list:
def preorder_traversal(self) -> Iterator[Node]:
"""
Return the preorder traversal of the tree
@ -263,7 +266,7 @@ class BinarySearchTree:
"""
return self._preorder_traversal(self.root)
def _preorder_traversal(self, node: Node) -> list:
def _preorder_traversal(self, node: Optional[Node]) -> Iterator[Node]:
if node is not None:
yield node
yield from self._preorder_traversal(node.left)
@ -272,7 +275,7 @@ class BinarySearchTree:
class BinarySearchTreeTest(unittest.TestCase):
@staticmethod
def _get_binary_search_tree():
def _get_binary_search_tree() -> BinarySearchTree:
r"""
8
/ \
@ -298,7 +301,7 @@ class BinarySearchTreeTest(unittest.TestCase):
return t
def test_put(self):
def test_put(self) -> None:
t = BinarySearchTree()
assert t.is_empty()
@ -306,6 +309,7 @@ class BinarySearchTreeTest(unittest.TestCase):
r"""
8
"""
assert t.root is not None
assert t.root.parent is None
assert t.root.label == 8
@ -315,6 +319,7 @@ class BinarySearchTreeTest(unittest.TestCase):
\
10
"""
assert t.root.right is not None
assert t.root.right.parent == t.root
assert t.root.right.label == 10
@ -324,6 +329,7 @@ class BinarySearchTreeTest(unittest.TestCase):
/ \
3 10
"""
assert t.root.left is not None
assert t.root.left.parent == t.root
assert t.root.left.label == 3
@ -335,6 +341,7 @@ class BinarySearchTreeTest(unittest.TestCase):
\
6
"""
assert t.root.left.right is not None
assert t.root.left.right.parent == t.root.left
assert t.root.left.right.label == 6
@ -346,13 +353,14 @@ class BinarySearchTreeTest(unittest.TestCase):
/ \
1 6
"""
assert t.root.left.left is not None
assert t.root.left.left.parent == t.root.left
assert t.root.left.left.label == 1
with self.assertRaises(Exception):
t.put(1)
def test_search(self):
def test_search(self) -> None:
t = self._get_binary_search_tree()
node = t.search(6)
@ -364,7 +372,7 @@ class BinarySearchTreeTest(unittest.TestCase):
with self.assertRaises(Exception):
t.search(2)
def test_remove(self):
def test_remove(self) -> None:
t = self._get_binary_search_tree()
t.remove(13)
@ -379,6 +387,9 @@ class BinarySearchTreeTest(unittest.TestCase):
\
5
"""
assert t.root is not None
assert t.root.right is not None
assert t.root.right.right is not None
assert t.root.right.right.right is None
assert t.root.right.right.left is None
@ -394,6 +405,9 @@ class BinarySearchTreeTest(unittest.TestCase):
\
5
"""
assert t.root.left is not None
assert t.root.left.right is not None
assert t.root.left.right.left is not None
assert t.root.left.right.right is None
assert t.root.left.right.left.label == 4
@ -407,6 +421,8 @@ class BinarySearchTreeTest(unittest.TestCase):
\
5
"""
assert t.root.left.left is not None
assert t.root.left.right.right is not None
assert t.root.left.left.label == 1
assert t.root.left.right.label == 4
assert t.root.left.right.right.label == 5
@ -422,6 +438,7 @@ class BinarySearchTreeTest(unittest.TestCase):
/ \ \
1 5 14
"""
assert t.root is not None
assert t.root.left.label == 4
assert t.root.left.right.label == 5
assert t.root.left.left.label == 1
@ -437,13 +454,15 @@ class BinarySearchTreeTest(unittest.TestCase):
/ \
1 14
"""
assert t.root.left is not None
assert t.root.left.left is not None
assert t.root.left.label == 5
assert t.root.left.right is None
assert t.root.left.left.label == 1
assert t.root.left.parent == t.root
assert t.root.left.left.parent == t.root.left
def test_remove_2(self):
def test_remove_2(self) -> None:
t = self._get_binary_search_tree()
t.remove(3)
@ -456,6 +475,12 @@ class BinarySearchTreeTest(unittest.TestCase):
/ \ /
5 7 13
"""
assert t.root is not None
assert t.root.left is not None
assert t.root.left.left is not None
assert t.root.left.right is not None
assert t.root.left.right.left is not None
assert t.root.left.right.right is not None
assert t.root.left.label == 4
assert t.root.left.right.label == 6
assert t.root.left.left.label == 1
@ -466,25 +491,25 @@ class BinarySearchTreeTest(unittest.TestCase):
assert t.root.left.left.parent == t.root.left
assert t.root.left.right.left.parent == t.root.left.right
def test_empty(self):
def test_empty(self) -> None:
t = self._get_binary_search_tree()
t.empty()
assert t.root is None
def test_is_empty(self):
def test_is_empty(self) -> None:
t = self._get_binary_search_tree()
assert not t.is_empty()
t.empty()
assert t.is_empty()
def test_exists(self):
def test_exists(self) -> None:
t = self._get_binary_search_tree()
assert t.exists(6)
assert not t.exists(-1)
def test_get_max_label(self):
def test_get_max_label(self) -> None:
t = self._get_binary_search_tree()
assert t.get_max_label() == 14
@ -493,7 +518,7 @@ class BinarySearchTreeTest(unittest.TestCase):
with self.assertRaises(Exception):
t.get_max_label()
def test_get_min_label(self):
def test_get_min_label(self) -> None:
t = self._get_binary_search_tree()
assert t.get_min_label() == 1
@ -502,20 +527,20 @@ class BinarySearchTreeTest(unittest.TestCase):
with self.assertRaises(Exception):
t.get_min_label()
def test_inorder_traversal(self):
def test_inorder_traversal(self) -> None:
t = self._get_binary_search_tree()
inorder_traversal_nodes = [i.label for i in t.inorder_traversal()]
assert inorder_traversal_nodes == [1, 3, 4, 5, 6, 7, 8, 10, 13, 14]
def test_preorder_traversal(self):
def test_preorder_traversal(self) -> None:
t = self._get_binary_search_tree()
preorder_traversal_nodes = [i.label for i in t.preorder_traversal()]
assert preorder_traversal_nodes == [8, 3, 1, 6, 4, 5, 7, 10, 14, 13]
def binary_search_tree_example():
def binary_search_tree_example() -> None:
r"""
Example
8

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import math
from typing import List, Union
class SegmentTree:
@ -37,7 +38,7 @@ class SegmentTree:
return idx * 2 + 1
def build(
self, idx: int, left_element: int, right_element: int, A: list[int]
self, idx: int, left_element: int, right_element: int, A: List[int]
) -> None:
if left_element == right_element:
self.segment_tree[idx] = A[left_element - 1]
@ -88,7 +89,7 @@ class SegmentTree:
# query with O(lg n)
def query(
self, idx: int, left_element: int, right_element: int, a: int, b: int
) -> int:
) -> Union[int, float]:
"""
query(1, 1, size, a, b) for query max of [a,b]
>>> A = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]
@ -118,8 +119,8 @@ class SegmentTree:
q2 = self.query(self.right(idx), mid + 1, right_element, a, b)
return max(q1, q2)
def __str__(self) -> None:
return [self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)]
def __str__(self) -> str:
return str([self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)])
if __name__ == "__main__":

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from random import random
from typing import Optional, Tuple
class Node:
@ -11,13 +12,13 @@ class Node:
Treap is a binary tree by value and heap by priority
"""
def __init__(self, value: int = None):
def __init__(self, value: Optional[int] = None):
self.value = value
self.prior = random()
self.left = None
self.right = None
self.left: Optional[Node] = None
self.right: Optional[Node] = None
def __repr__(self):
def __repr__(self) -> str:
from pprint import pformat
if self.left is None and self.right is None:
@ -27,14 +28,14 @@ class Node:
{f"{self.value}: {self.prior:.5}": (self.left, self.right)}, indent=1
)
def __str__(self):
def __str__(self) -> str:
value = str(self.value) + " "
left = str(self.left or "")
right = str(self.right or "")
return value + left + right
def split(root: Node, value: int) -> tuple[Node, Node]:
def split(root: Optional[Node], value: int) -> Tuple[Optional[Node], Optional[Node]]:
"""
We split current tree into 2 trees with value:
@ -42,9 +43,9 @@ def split(root: Node, value: int) -> tuple[Node, Node]:
Right tree contains all values greater or equal, than split value
"""
if root is None: # None tree is split into 2 Nones
return (None, None)
return None, None
elif root.value is None:
return (None, None)
return None, None
else:
if value < root.value:
"""
@ -54,16 +55,16 @@ def split(root: Node, value: int) -> tuple[Node, Node]:
Right tree's left son: right part of that split
"""
left, root.left = split(root.left, value)
return (left, root)
return left, root
else:
"""
Just symmetric to previous case
"""
root.right, right = split(root.right, value)
return (root, right)
return root, right
def merge(left: Node, right: Node) -> Node:
def merge(left: Optional[Node], right: Optional[Node]) -> Optional[Node]:
"""
We merge 2 trees into one.
Note: all left tree's values must be less than all right tree's
@ -85,7 +86,7 @@ def merge(left: Node, right: Node) -> Node:
return right
def insert(root: Node, value: int) -> Node:
def insert(root: Optional[Node], value: int) -> Optional[Node]:
"""
Insert element
@ -98,7 +99,7 @@ def insert(root: Node, value: int) -> Node:
return merge(merge(left, node), right)
def erase(root: Node, value: int) -> Node:
def erase(root: Optional[Node], value: int) -> Optional[Node]:
"""
Erase element
@ -111,7 +112,7 @@ def erase(root: Node, value: int) -> Node:
return merge(left, right)
def inorder(root: Node):
def inorder(root: Optional[Node]) -> None:
"""
Just recursive print of a tree
"""
@ -123,7 +124,7 @@ def inorder(root: Node):
inorder(root.right)
def interactTreap(root, args):
def interactTreap(root: Optional[Node], args: str) -> Optional[Node]:
"""
Commands:
+ value to add value into treap
@ -160,7 +161,7 @@ def interactTreap(root, args):
return root
def main():
def main() -> None:
"""After each command, program prints treap"""
root = None
print(