diff --git a/data_structures/binary_tree/binary_search_tree.py b/data_structures/binary_tree/binary_search_tree.py index 51a651be0..fc60540a1 100644 --- a/data_structures/binary_tree/binary_search_tree.py +++ b/data_structures/binary_tree/binary_search_tree.py @@ -2,15 +2,18 @@ A binary search Tree """ +from collections.abc import Iterable +from typing import Any + class Node: - def __init__(self, value, parent): + def __init__(self, value: int | None = None): self.value = value - self.parent = parent # Added in order to delete a node easier - self.left = None - self.right = None + self.parent: Node | None = None # Added in order to delete a node easier + self.left: Node | None = None + self.right: Node | None = None - def __repr__(self): + def __repr__(self) -> str: from pprint import pformat if self.left is None and self.right is None: @@ -19,16 +22,16 @@ class Node: class BinarySearchTree: - def __init__(self, root=None): + def __init__(self, root: Node | None = None): self.root = root - def __str__(self): + def __str__(self) -> str: """ Return a string of all the Nodes using in order traversal """ return str(self.root) - def __reassign_nodes(self, node, new_children): + def __reassign_nodes(self, node: Node, new_children: Node | None) -> None: if new_children is not None: # reset its kids new_children.parent = node.parent if node.parent is not None: # reset its parent @@ -37,23 +40,27 @@ class BinarySearchTree: else: node.parent.left = new_children else: - self.root = new_children + self.root = None - def is_right(self, node): - return node == node.parent.right + def is_right(self, node: Node) -> bool: + if node.parent and node.parent.right: + return node == node.parent.right + return False - def empty(self): + def empty(self) -> bool: return self.root is None - def __insert(self, value): + def __insert(self, value) -> None: """ Insert a new node in Binary Search Tree with value label """ - new_node = Node(value, None) # create a new Node + new_node = Node(value) # create a new Node if self.empty(): # if Tree is empty self.root = new_node # set its root else: # Tree is not empty parent_node = self.root # from root + if parent_node is None: + return None while True: # While we don't get to a leaf if value < parent_node.value: # We go left if parent_node.left is None: @@ -69,12 +76,11 @@ class BinarySearchTree: parent_node = parent_node.right new_node.parent = parent_node - def insert(self, *values): + def insert(self, *values) -> None: for value in values: self.__insert(value) - return self - def search(self, value): + def search(self, value) -> Node | None: if self.empty(): raise IndexError("Warning: Tree is empty! please use another.") else: @@ -84,30 +90,35 @@ class BinarySearchTree: node = node.left if value < node.value else node.right return node - def get_max(self, node=None): + def get_max(self, node: Node | None = None) -> Node | None: """ We go deep on the right branch """ if node is None: + if self.root is None: + return None node = self.root + if not self.empty(): while node.right is not None: node = node.right return node - def get_min(self, node=None): + def get_min(self, node: Node | None = None) -> Node | None: """ We go deep on the left branch """ if node is None: node = self.root + if self.root is None: + return None if not self.empty(): node = self.root while node.left is not None: node = node.left return node - def remove(self, value): + def remove(self, value: int) -> None: node = self.search(value) # Look for the node with that label if node is not None: if node.left is None and node.right is None: # If it has no children @@ -120,18 +131,18 @@ class BinarySearchTree: tmp_node = self.get_max( node.left ) # Gets the max value of the left branch - self.remove(tmp_node.value) + self.remove(tmp_node.value) # type: ignore node.value = ( - tmp_node.value + tmp_node.value # type: ignore ) # Assigns the value to the node to delete and keep tree structure - def preorder_traverse(self, node): + def preorder_traverse(self, node: Node | None) -> Iterable: if node is not None: yield node # Preorder Traversal yield from self.preorder_traverse(node.left) yield from self.preorder_traverse(node.right) - def traversal_tree(self, traversal_function=None): + def traversal_tree(self, traversal_function=None) -> Any: """ This function traversal the tree. You can pass a function to traversal the tree as needed by client code @@ -141,7 +152,7 @@ class BinarySearchTree: else: return traversal_function(self.root) - def inorder(self, arr: list, node: Node): + def inorder(self, arr: list, node: Node | None) -> None: """Perform an inorder traversal and append values of the nodes to a list named arr""" if node: @@ -151,12 +162,12 @@ class BinarySearchTree: def find_kth_smallest(self, k: int, node: Node) -> int: """Return the kth smallest element in a binary search tree""" - arr: list = [] + arr: list[int] = [] self.inorder(arr, node) # append all values to list using inorder traversal return arr[k - 1] -def postorder(curr_node): +def postorder(curr_node: Node | None) -> list[Node]: """ postOrder (left, right, self) """ @@ -166,7 +177,7 @@ def postorder(curr_node): return node_list -def binary_search_tree(): +def binary_search_tree() -> None: r""" Example 8 @@ -177,7 +188,8 @@ def binary_search_tree(): / \ / 4 7 13 - >>> t = BinarySearchTree().insert(8, 3, 6, 1, 10, 14, 13, 4, 7) + >>> t = BinarySearchTree() + >>> t.insert(8, 3, 6, 1, 10, 14, 13, 4, 7) >>> print(" ".join(repr(i.value) for i in t.traversal_tree())) 8 3 1 6 4 7 10 14 13 >>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder))) @@ -206,8 +218,8 @@ def binary_search_tree(): print("The value -1 doesn't exist") if not t.empty(): - print("Max Value: ", t.get_max().value) - print("Min Value: ", t.get_min().value) + print("Max Value: ", t.get_max().value) # type: ignore + print("Min Value: ", t.get_min().value) # type: ignore for i in testlist: t.remove(i) @@ -217,5 +229,4 @@ def binary_search_tree(): if __name__ == "__main__": import doctest - doctest.testmod() - # binary_search_tree() + doctest.testmod(verbose=True)