""" This is a python3 implementation of binary search tree using recursion To run tests: python -m unittest binary_search_tree_recursive.py To run an example: python binary_search_tree_recursive.py """ import unittest class Node: def __init__(self, label: int, parent): self.label = label self.parent = parent self.left = None self.right = None class BinarySearchTree: def __init__(self): self.root = None def empty(self): """ Empties the tree >>> t = BinarySearchTree() >>> assert t.root is None >>> t.put(8) >>> assert t.root is not None """ self.root = None def is_empty(self) -> bool: """ Checks if the tree is empty >>> t = BinarySearchTree() >>> t.is_empty() True >>> t.put(8) >>> t.is_empty() False """ return self.root is None def put(self, label: int): """ Put a new node in the tree >>> t = BinarySearchTree() >>> t.put(8) >>> assert t.root.parent is None >>> assert t.root.label == 8 >>> t.put(10) >>> assert t.root.right.parent == t.root >>> assert t.root.right.label == 10 >>> t.put(3) >>> assert t.root.left.parent == t.root >>> assert t.root.left.label == 3 """ self.root = self._put(self.root, label) def _put(self, node: Node, label: int, parent: Node = None) -> Node: if node is None: node = Node(label, parent) else: if label < node.label: node.left = self._put(node.left, label, node) elif label > node.label: node.right = self._put(node.right, label, node) else: raise Exception(f"Node with label {label} already exists") return node def search(self, label: int) -> Node: """ Searches a node in the tree >>> t = BinarySearchTree() >>> t.put(8) >>> t.put(10) >>> node = t.search(8) >>> assert node.label == 8 >>> node = t.search(3) Traceback (most recent call last): ... Exception: Node with label 3 does not exist """ return self._search(self.root, label) def _search(self, node: Node, label: int) -> Node: if node is None: raise Exception(f"Node with label {label} does not exist") else: if label < node.label: node = self._search(node.left, label) elif label > node.label: node = self._search(node.right, label) return node def remove(self, label: int): """ Removes a node in the tree >>> t = BinarySearchTree() >>> t.put(8) >>> t.put(10) >>> t.remove(8) >>> assert t.root.label == 10 >>> t.remove(3) Traceback (most recent call last): ... Exception: Node with label 3 does not exist """ node = self.search(label) if not node.right and not node.left: self._reassign_nodes(node, None) elif not node.right and node.left: self._reassign_nodes(node, node.left) elif node.right and not node.left: self._reassign_nodes(node, node.right) else: lowest_node = self._get_lowest_node(node.right) lowest_node.left = node.left lowest_node.right = node.right node.left.parent = lowest_node if node.right: node.right.parent = lowest_node self._reassign_nodes(node, lowest_node) def _reassign_nodes(self, node: Node, new_children: Node): if new_children: new_children.parent = node.parent if node.parent: if node.parent.right == node: node.parent.right = new_children else: node.parent.left = new_children else: self.root = new_children def _get_lowest_node(self, node: Node) -> Node: if node.left: lowest_node = self._get_lowest_node(node.left) else: lowest_node = node self._reassign_nodes(node, node.right) return lowest_node def exists(self, label: int) -> bool: """ Checks if a node exists in the tree >>> t = BinarySearchTree() >>> t.put(8) >>> t.put(10) >>> t.exists(8) True >>> t.exists(3) False """ try: self.search(label) return True except Exception: return False def get_max_label(self) -> int: """ Gets the max label inserted in the tree >>> t = BinarySearchTree() >>> t.get_max_label() Traceback (most recent call last): ... Exception: Binary search tree is empty >>> t.put(8) >>> t.put(10) >>> t.get_max_label() 10 """ if self.is_empty(): raise Exception("Binary search tree is empty") node = self.root while node.right is not None: node = node.right return node.label def get_min_label(self) -> int: """ Gets the min label inserted in the tree >>> t = BinarySearchTree() >>> t.get_min_label() Traceback (most recent call last): ... Exception: Binary search tree is empty >>> t.put(8) >>> t.put(10) >>> t.get_min_label() 8 """ if self.is_empty(): raise Exception("Binary search tree is empty") node = self.root while node.left is not None: node = node.left return node.label def inorder_traversal(self) -> list: """ Return the inorder traversal of the tree >>> t = BinarySearchTree() >>> [i.label for i in t.inorder_traversal()] [] >>> t.put(8) >>> t.put(10) >>> t.put(9) >>> [i.label for i in t.inorder_traversal()] [8, 9, 10] """ return self._inorder_traversal(self.root) def _inorder_traversal(self, node: Node) -> list: if node is not None: yield from self._inorder_traversal(node.left) yield node yield from self._inorder_traversal(node.right) def preorder_traversal(self) -> list: """ Return the preorder traversal of the tree >>> t = BinarySearchTree() >>> [i.label for i in t.preorder_traversal()] [] >>> t.put(8) >>> t.put(10) >>> t.put(9) >>> [i.label for i in t.preorder_traversal()] [8, 10, 9] """ return self._preorder_traversal(self.root) def _preorder_traversal(self, node: Node) -> list: if node is not None: yield node yield from self._preorder_traversal(node.left) yield from self._preorder_traversal(node.right) class BinarySearchTreeTest(unittest.TestCase): @staticmethod def _get_binary_search_tree(): r""" 8 / \ 3 10 / \ \ 1 6 14 / \ / 4 7 13 \ 5 """ t = BinarySearchTree() t.put(8) t.put(3) t.put(6) t.put(1) t.put(10) t.put(14) t.put(13) t.put(4) t.put(7) t.put(5) return t def test_put(self): t = BinarySearchTree() assert t.is_empty() t.put(8) r""" 8 """ assert t.root.parent is None assert t.root.label == 8 t.put(10) r""" 8 \ 10 """ assert t.root.right.parent == t.root assert t.root.right.label == 10 t.put(3) r""" 8 / \ 3 10 """ assert t.root.left.parent == t.root assert t.root.left.label == 3 t.put(6) r""" 8 / \ 3 10 \ 6 """ assert t.root.left.right.parent == t.root.left assert t.root.left.right.label == 6 t.put(1) r""" 8 / \ 3 10 / \ 1 6 """ assert t.root.left.left.parent == t.root.left assert t.root.left.left.label == 1 with self.assertRaises(Exception): t.put(1) def test_search(self): t = self._get_binary_search_tree() node = t.search(6) assert node.label == 6 node = t.search(13) assert node.label == 13 with self.assertRaises(Exception): t.search(2) def test_remove(self): t = self._get_binary_search_tree() t.remove(13) r""" 8 / \ 3 10 / \ \ 1 6 14 / \ 4 7 \ 5 """ assert t.root.right.right.right is None assert t.root.right.right.left is None t.remove(7) r""" 8 / \ 3 10 / \ \ 1 6 14 / 4 \ 5 """ assert t.root.left.right.right is None assert t.root.left.right.left.label == 4 t.remove(6) r""" 8 / \ 3 10 / \ \ 1 4 14 \ 5 """ assert t.root.left.left.label == 1 assert t.root.left.right.label == 4 assert t.root.left.right.right.label == 5 assert t.root.left.right.left is None assert t.root.left.left.parent == t.root.left assert t.root.left.right.parent == t.root.left t.remove(3) r""" 8 / \ 4 10 / \ \ 1 5 14 """ assert t.root.left.label == 4 assert t.root.left.right.label == 5 assert t.root.left.left.label == 1 assert t.root.left.parent == t.root assert t.root.left.left.parent == t.root.left assert t.root.left.right.parent == t.root.left t.remove(4) r""" 8 / \ 5 10 / \ 1 14 """ assert t.root.left.label == 5 assert t.root.left.right is None assert t.root.left.left.label == 1 assert t.root.left.parent == t.root assert t.root.left.left.parent == t.root.left def test_remove_2(self): t = self._get_binary_search_tree() t.remove(3) r""" 8 / \ 4 10 / \ \ 1 6 14 / \ / 5 7 13 """ assert t.root.left.label == 4 assert t.root.left.right.label == 6 assert t.root.left.left.label == 1 assert t.root.left.right.right.label == 7 assert t.root.left.right.left.label == 5 assert t.root.left.parent == t.root assert t.root.left.right.parent == t.root.left assert t.root.left.left.parent == t.root.left assert t.root.left.right.left.parent == t.root.left.right def test_empty(self): t = self._get_binary_search_tree() t.empty() assert t.root is None def test_is_empty(self): t = self._get_binary_search_tree() assert not t.is_empty() t.empty() assert t.is_empty() def test_exists(self): t = self._get_binary_search_tree() assert t.exists(6) assert not t.exists(-1) def test_get_max_label(self): t = self._get_binary_search_tree() assert t.get_max_label() == 14 t.empty() with self.assertRaises(Exception): t.get_max_label() def test_get_min_label(self): t = self._get_binary_search_tree() assert t.get_min_label() == 1 t.empty() with self.assertRaises(Exception): t.get_min_label() def test_inorder_traversal(self): t = self._get_binary_search_tree() inorder_traversal_nodes = [i.label for i in t.inorder_traversal()] assert inorder_traversal_nodes == [1, 3, 4, 5, 6, 7, 8, 10, 13, 14] def test_preorder_traversal(self): t = self._get_binary_search_tree() preorder_traversal_nodes = [i.label for i in t.preorder_traversal()] assert preorder_traversal_nodes == [8, 3, 1, 6, 4, 5, 7, 10, 14, 13] def binary_search_tree_example(): r""" Example 8 / \ 3 10 / \ \ 1 6 14 / \ / 4 7 13 \ 5 Example After Deletion 4 / \ 1 7 \ 5 """ t = BinarySearchTree() t.put(8) t.put(3) t.put(6) t.put(1) t.put(10) t.put(14) t.put(13) t.put(4) t.put(7) t.put(5) print( """ 8 / \\ 3 10 / \\ \\ 1 6 14 / \\ / 4 7 13 \\ 5 """ ) print("Label 6 exists:", t.exists(6)) print("Label 13 exists:", t.exists(13)) print("Label -1 exists:", t.exists(-1)) print("Label 12 exists:", t.exists(12)) # Prints all the elements of the list in inorder traversal inorder_traversal_nodes = [i.label for i in t.inorder_traversal()] print("Inorder traversal:", inorder_traversal_nodes) # Prints all the elements of the list in preorder traversal preorder_traversal_nodes = [i.label for i in t.preorder_traversal()] print("Preorder traversal:", preorder_traversal_nodes) print("Max. label:", t.get_max_label()) print("Min. label:", t.get_min_label()) # Delete elements print("\nDeleting elements 13, 10, 8, 3, 6, 14") print( """ 4 / \\ 1 7 \\ 5 """ ) t.remove(13) t.remove(10) t.remove(8) t.remove(3) t.remove(6) t.remove(14) # Prints all the elements of the list in inorder traversal after delete inorder_traversal_nodes = [i.label for i in t.inorder_traversal()] print("Inorder traversal after delete:", inorder_traversal_nodes) # Prints all the elements of the list in preorder traversal after delete preorder_traversal_nodes = [i.label for i in t.preorder_traversal()] print("Preorder traversal after delete:", preorder_traversal_nodes) print("Max. label:", t.get_max_label()) print("Min. label:", t.get_min_label()) if __name__ == "__main__": binary_search_tree_example()