types: Update binary search tree typehints (#7197)

* types: Update binary search tree typehints

* refactor: Don't return `self` in `:meth:insert`

* test: Fix failing doctests

* Apply suggestions from code review

Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com>
This commit is contained in:
Caeden 2022-10-15 23:51:23 +01:00 committed by GitHub
parent 553624fcd4
commit c94e215c8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,15 +2,18 @@
A binary search Tree A binary search Tree
""" """
from collections.abc import Iterable
from typing import Any
class Node: class Node:
def __init__(self, value, parent): def __init__(self, value: int | None = None):
self.value = value self.value = value
self.parent = parent # Added in order to delete a node easier self.parent: Node | None = None # Added in order to delete a node easier
self.left = None self.left: Node | None = None
self.right = None self.right: Node | None = None
def __repr__(self): def __repr__(self) -> str:
from pprint import pformat from pprint import pformat
if self.left is None and self.right is None: if self.left is None and self.right is None:
@ -19,16 +22,16 @@ class Node:
class BinarySearchTree: class BinarySearchTree:
def __init__(self, root=None): def __init__(self, root: Node | None = None):
self.root = root self.root = root
def __str__(self): def __str__(self) -> str:
""" """
Return a string of all the Nodes using in order traversal Return a string of all the Nodes using in order traversal
""" """
return str(self.root) return str(self.root)
def __reassign_nodes(self, node, new_children): def __reassign_nodes(self, node: Node, new_children: Node | None) -> None:
if new_children is not None: # reset its kids if new_children is not None: # reset its kids
new_children.parent = node.parent new_children.parent = node.parent
if node.parent is not None: # reset its parent if node.parent is not None: # reset its parent
@ -37,23 +40,27 @@ class BinarySearchTree:
else: else:
node.parent.left = new_children node.parent.left = new_children
else: else:
self.root = new_children self.root = None
def is_right(self, node): def is_right(self, node: Node) -> bool:
return node == node.parent.right if node.parent and node.parent.right:
return node == node.parent.right
return False
def empty(self): def empty(self) -> bool:
return self.root is None return self.root is None
def __insert(self, value): def __insert(self, value) -> None:
""" """
Insert a new node in Binary Search Tree with value label Insert a new node in Binary Search Tree with value label
""" """
new_node = Node(value, None) # create a new Node new_node = Node(value) # create a new Node
if self.empty(): # if Tree is empty if self.empty(): # if Tree is empty
self.root = new_node # set its root self.root = new_node # set its root
else: # Tree is not empty else: # Tree is not empty
parent_node = self.root # from root parent_node = self.root # from root
if parent_node is None:
return None
while True: # While we don't get to a leaf while True: # While we don't get to a leaf
if value < parent_node.value: # We go left if value < parent_node.value: # We go left
if parent_node.left is None: if parent_node.left is None:
@ -69,12 +76,11 @@ class BinarySearchTree:
parent_node = parent_node.right parent_node = parent_node.right
new_node.parent = parent_node new_node.parent = parent_node
def insert(self, *values): def insert(self, *values) -> None:
for value in values: for value in values:
self.__insert(value) self.__insert(value)
return self
def search(self, value): def search(self, value) -> Node | None:
if self.empty(): if self.empty():
raise IndexError("Warning: Tree is empty! please use another.") raise IndexError("Warning: Tree is empty! please use another.")
else: else:
@ -84,30 +90,35 @@ class BinarySearchTree:
node = node.left if value < node.value else node.right node = node.left if value < node.value else node.right
return node return node
def get_max(self, node=None): def get_max(self, node: Node | None = None) -> Node | None:
""" """
We go deep on the right branch We go deep on the right branch
""" """
if node is None: if node is None:
if self.root is None:
return None
node = self.root node = self.root
if not self.empty(): if not self.empty():
while node.right is not None: while node.right is not None:
node = node.right node = node.right
return node return node
def get_min(self, node=None): def get_min(self, node: Node | None = None) -> Node | None:
""" """
We go deep on the left branch We go deep on the left branch
""" """
if node is None: if node is None:
node = self.root node = self.root
if self.root is None:
return None
if not self.empty(): if not self.empty():
node = self.root node = self.root
while node.left is not None: while node.left is not None:
node = node.left node = node.left
return node return node
def remove(self, value): def remove(self, value: int) -> None:
node = self.search(value) # Look for the node with that label node = self.search(value) # Look for the node with that label
if node is not None: if node is not None:
if node.left is None and node.right is None: # If it has no children if node.left is None and node.right is None: # If it has no children
@ -120,18 +131,18 @@ class BinarySearchTree:
tmp_node = self.get_max( tmp_node = self.get_max(
node.left node.left
) # Gets the max value of the left branch ) # Gets the max value of the left branch
self.remove(tmp_node.value) self.remove(tmp_node.value) # type: ignore
node.value = ( node.value = (
tmp_node.value tmp_node.value # type: ignore
) # Assigns the value to the node to delete and keep tree structure ) # Assigns the value to the node to delete and keep tree structure
def preorder_traverse(self, node): def preorder_traverse(self, node: Node | None) -> Iterable:
if node is not None: if node is not None:
yield node # Preorder Traversal yield node # Preorder Traversal
yield from self.preorder_traverse(node.left) yield from self.preorder_traverse(node.left)
yield from self.preorder_traverse(node.right) yield from self.preorder_traverse(node.right)
def traversal_tree(self, traversal_function=None): def traversal_tree(self, traversal_function=None) -> Any:
""" """
This function traversal the tree. This function traversal the tree.
You can pass a function to traversal the tree as needed by client code You can pass a function to traversal the tree as needed by client code
@ -141,7 +152,7 @@ class BinarySearchTree:
else: else:
return traversal_function(self.root) return traversal_function(self.root)
def inorder(self, arr: list, node: Node): def inorder(self, arr: list, node: Node | None) -> None:
"""Perform an inorder traversal and append values of the nodes to """Perform an inorder traversal and append values of the nodes to
a list named arr""" a list named arr"""
if node: if node:
@ -151,12 +162,12 @@ class BinarySearchTree:
def find_kth_smallest(self, k: int, node: Node) -> int: def find_kth_smallest(self, k: int, node: Node) -> int:
"""Return the kth smallest element in a binary search tree""" """Return the kth smallest element in a binary search tree"""
arr: list = [] arr: list[int] = []
self.inorder(arr, node) # append all values to list using inorder traversal self.inorder(arr, node) # append all values to list using inorder traversal
return arr[k - 1] return arr[k - 1]
def postorder(curr_node): def postorder(curr_node: Node | None) -> list[Node]:
""" """
postOrder (left, right, self) postOrder (left, right, self)
""" """
@ -166,7 +177,7 @@ def postorder(curr_node):
return node_list return node_list
def binary_search_tree(): def binary_search_tree() -> None:
r""" r"""
Example Example
8 8
@ -177,7 +188,8 @@ def binary_search_tree():
/ \ / / \ /
4 7 13 4 7 13
>>> t = BinarySearchTree().insert(8, 3, 6, 1, 10, 14, 13, 4, 7) >>> t = BinarySearchTree()
>>> t.insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
>>> print(" ".join(repr(i.value) for i in t.traversal_tree())) >>> print(" ".join(repr(i.value) for i in t.traversal_tree()))
8 3 1 6 4 7 10 14 13 8 3 1 6 4 7 10 14 13
>>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder))) >>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder)))
@ -206,8 +218,8 @@ def binary_search_tree():
print("The value -1 doesn't exist") print("The value -1 doesn't exist")
if not t.empty(): if not t.empty():
print("Max Value: ", t.get_max().value) print("Max Value: ", t.get_max().value) # type: ignore
print("Min Value: ", t.get_min().value) print("Min Value: ", t.get_min().value) # type: ignore
for i in testlist: for i in testlist:
t.remove(i) t.remove(i)
@ -217,5 +229,4 @@ def binary_search_tree():
if __name__ == "__main__": if __name__ == "__main__":
import doctest import doctest
doctest.testmod() doctest.testmod(verbose=True)
# binary_search_tree()