diff --git a/data_structures/binary_tree/binary_search_tree.py b/data_structures/binary_tree/binary_search_tree.py index 3f214d011..2b43d4692 100644 --- a/data_structures/binary_tree/binary_search_tree.py +++ b/data_structures/binary_tree/binary_search_tree.py @@ -107,12 +107,14 @@ class Node: """ >>> list(Node(0)) [0] - >>> list(Node(0, Node(-1), Node(1), None)) + >>> list(Node(0, Node(-1), Node(1))) [-1, 0, 1] """ - yield from self.left or [] + if self.left: + yield from self.left yield self.value - yield from self.right or [] + if self.right: + yield from self.right def __repr__(self) -> str: from pprint import pformat @@ -143,10 +145,10 @@ class BinarySearchTree: return str(self.root) def __reassign_nodes(self, node: Node, new_children: Node | None) -> None: - if new_children is not None: # reset its kids + if new_children is not None: new_children.parent = node.parent - if node.parent is not None: # reset its parent - if node.is_right: # If it is the right child + if node.parent is not None: + if node.is_right: node.parent.right = new_children else: node.parent.left = new_children @@ -167,37 +169,37 @@ class BinarySearchTree: """ return not self.root - def __insert(self, value) -> None: + def __insert(self, value: int) -> None: """ Insert a new node in Binary Search Tree with value label """ - new_node = Node(value) # create a new Node - if self.empty(): # if Tree is empty - self.root = new_node # set its root - else: # Tree is not empty - parent_node = self.root # from root - if parent_node is None: - return - while True: # While we don't get to a leaf - if value < parent_node.value: # We go left + new_node = Node(value) + if self.empty(): + self.root = new_node + else: + parent_node = self.root + while True: + if value < parent_node.value: if parent_node.left is None: - parent_node.left = new_node # We insert the new node in a leaf + parent_node.left = new_node + new_node.parent = parent_node break else: parent_node = parent_node.left - elif parent_node.right is None: - parent_node.right = new_node - break else: - parent_node = parent_node.right - new_node.parent = parent_node + if parent_node.right is None: + parent_node.right = new_node + new_node.parent = parent_node + break + else: + parent_node = parent_node.right - def insert(self, *values) -> Self: + def insert(self, *values: int) -> Self: for value in values: self.__insert(value) return self - def search(self, value) -> Node | None: + def search(self, value: int) -> Node | None: """ >>> tree = BinarySearchTree().insert(10, 20, 30, 40, 50) >>> tree.search(10) @@ -221,15 +223,12 @@ class BinarySearchTree: ... IndexError: Warning: Tree is empty! please use another. """ - if self.empty(): raise IndexError("Warning: Tree is empty! please use another.") - else: - node = self.root - # use lazy evaluation here to avoid NoneType Attribute error - while node is not None and node.value is not value: - node = node.left if value < node.value else node.right - return node + node = self.root + while node is not None and node.value != value: + node = node.left if value < node.value else node.right + return node def get_max(self, node: Node | None = None) -> Node | None: """ @@ -237,21 +236,19 @@ class BinarySearchTree: >>> BinarySearchTree().insert(10, 20, 30, 40, 50).get_max() 50 - >>> BinarySearchTree().insert(-5, -1, 0.1, -0.3, -4.5).get_max() - {'0.1': (-0.3, None)} + >>> BinarySearchTree().insert(-5, -1, 0, -0.3, -4.5).get_max() + {'0': (-0.3, None)} >>> BinarySearchTree().insert(1, 78.3, 30, 74.0, 1).get_max() {'78.3': ({'30': (1, 74.0)}, None)} >>> BinarySearchTree().insert(1, 783, 30, 740, 1).get_max() {'783': ({'30': (1, 740)}, None)} """ if node is None: - if self.root is None: + if self.empty(): return None node = self.root - - if not self.empty(): - while node.right is not None: - node = node.right + while node.right is not None: + node = node.right return node def get_min(self, node: Node | None = None) -> Node | None: @@ -268,54 +265,47 @@ class BinarySearchTree: {'1': (None, {'783': ({'30': (1, 740)}, None)})} """ if node is None: + if self.empty(): + return None node = self.root - if self.root is None: - return None - if not self.empty(): - node = self.root - while node.left is not None: - node = node.left + while node.left is not None: + node = node.left return node def remove(self, value: int) -> None: - # Look for the node with that label node = self.search(value) if node is None: - msg = f"Value {value} not found" - raise ValueError(msg) + raise ValueError(f"Value {value} not found") - if node.left is None and node.right is None: # If it has no children + if node.left is None and node.right is None: self.__reassign_nodes(node, None) - elif node.left is None: # Has only right children + elif node.left is None: self.__reassign_nodes(node, node.right) - elif node.right is None: # Has only left children + elif node.right is None: self.__reassign_nodes(node, node.left) else: - predecessor = self.get_max( - node.left - ) # Gets the max value of the left branch - self.remove(predecessor.value) # type: ignore[union-attr] - node.value = ( - predecessor.value # type: ignore[union-attr] - ) # Assigns the value to the node to delete and keep tree structure + predecessor = self.get_max(node.left) + if predecessor: + self.remove(predecessor.value) + node.value = predecessor.value - def preorder_traverse(self, node: Node | None) -> Iterable: + def preorder_traverse(self, node: Node | None) -> Iterable[Node]: if node is not None: - yield node # Preorder Traversal + yield node yield from self.preorder_traverse(node.left) yield from self.preorder_traverse(node.right) def traversal_tree(self, traversal_function=None) -> Any: """ - This function traversal the tree. - You can pass a function to traversal the tree as needed by client code + This function traverses the tree. + You can pass a function to traverse the tree as needed by client code """ if traversal_function is None: - return self.preorder_traverse(self.root) + return list(self.preorder_traverse(self.root)) else: return traversal_function(self.root) - def inorder(self, arr: list, node: Node | None) -> None: + def inorder(self, arr: list[int], node: Node | None) -> None: """Perform an inorder traversal and append values of the nodes to a list named arr""" if node: @@ -326,8 +316,10 @@ class BinarySearchTree: def find_kth_smallest(self, k: int, node: Node) -> int: """Return the kth smallest element in a binary search tree""" arr: list[int] = [] - self.inorder(arr, node) # append all values to list using inorder traversal - return arr[k - 1] + self.inorder(arr, node) + if 0 < k <= len(arr): + return arr[k - 1] + raise IndexError("k is out of bounds") def inorder(curr_node: Node | None) -> list[Node]: @@ -346,11 +338,4 @@ def postorder(curr_node: Node | None) -> list[Node]: """ node_list = [] if curr_node is not None: - node_list = postorder(curr_node.left) + postorder(curr_node.right) + [curr_node] - return node_list - - -if __name__ == "__main__": - import doctest - - doctest.testmod(verbose=True) + node_list = postorder(curr_node.left)