diff --git a/.vscode/settings.json b/.vscode/settings.json index ef16fa1aa..b36f128d4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,14 @@ { "githubPullRequests.ignoredPullRequestBranches": [ "master" - ] + ], + "python.testing.unittestArgs": [ + "-v", + "-s", + ".", + "-p", + "*test.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true } diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index 9fca72374..e54046d5d 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -150,6 +150,8 @@ def rl_rotation(node: MyNode) -> MyNode: def insert_node(node: MyNode | None, data: Any) -> MyNode | None: if node is None: return MyNode(data) + if data == node.get_data(): + return node if data < node.get_data(): node.set_left(insert_node(node.get_left(), data)) if ( @@ -195,46 +197,67 @@ def get_left_most(root: MyNode) -> Any: return root.get_data() -def del_node(root: MyNode, data: Any) -> MyNode | None: - left_child = root.get_left() - right_child = root.get_right() - if root.get_data() == data: - if left_child is not None and right_child is not None: - temp_data = get_left_most(right_child) - root.set_data(temp_data) - root.set_right(del_node(right_child, temp_data)) - elif left_child is not None: - root = left_child - elif right_child is not None: - root = right_child - else: - return None - elif root.get_data() > data: - if left_child is None: - print("No such data") - return root - else: - root.set_left(del_node(left_child, data)) - # root.get_data() < data - elif right_child is None: - return root - else: - root.set_right(del_node(right_child, data)) +def get_balance(node: MyNode | None) -> int: + if node is None: + return 0 + return get_height(node.get_left()) - get_height(node.get_right()) - if get_height(right_child) - get_height(left_child) == 2: + +def get_min_value_node(node: MyNode) -> MyNode: + # Returns the node with the minimum value in the tree that is leftmost node + # Function get_left_most is not used here because it returns the value of the node + while True: + left_child = node.get_left() + if left_child is None: + break + node = left_child + return node + + +def del_node(root: MyNode | None, data: Any) -> MyNode | None: + if root is None: + print(f"{data} not found in the tree") + return None + + if root.get_data() > data: + left_child = del_node(root.get_left(), data) + root.set_left(left_child) + elif root.get_data() < data: + right_child = del_node(root.get_right(), data) + root.set_right(right_child) + else: + if root.get_left() is None: + return root.get_right() + elif root.get_right() is None: + return root.get_left() + right_child = root.get_right() assert right_child is not None - if get_height(right_child.get_right()) > get_height(right_child.get_left()): - root = left_rotation(root) - else: - root = rl_rotation(root) - elif get_height(right_child) - get_height(left_child) == -2: + temp = get_min_value_node(right_child) + root.set_data(temp.get_data()) + root.set_right(del_node(root.get_right(), temp.get_data())) + + root.set_height( + 1 + my_max(get_height(root.get_left()), get_height(root.get_right())) + ) + + balance = get_balance(root) + + if balance > 1: + left_child = root.get_left() assert left_child is not None - if get_height(left_child.get_left()) > get_height(left_child.get_right()): - root = right_rotation(root) - else: - root = lr_rotation(root) - height = my_max(get_height(root.get_right()), get_height(root.get_left())) + 1 - root.set_height(height) + if get_balance(left_child) >= 0: + return right_rotation(root) + root.set_left(left_rotation(left_child)) + return right_rotation(root) + + if balance < -1: + right_child = root.get_right() + assert right_child is not None + if get_balance(right_child) <= 0: + return left_rotation(root) + root.set_right(right_rotation(right_child)) + return left_rotation(root) + return root @@ -264,7 +287,7 @@ class AVLtree: ************************************* >>> t.get_height() 2 - >>> t.del_node(3) + >>> t.delete(3) delete:3 >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n")) 4 @@ -282,7 +305,7 @@ class AVLtree: print("insert:" + str(data)) self.root = insert_node(self.root, data) - def del_node(self, data: Any) -> None: + def delete(self, data: Any) -> None: print("delete:" + str(data)) if self.root is None: print("Tree is empty!") @@ -341,5 +364,5 @@ if __name__ == "__main__": print(str(t)) random.shuffle(lst) for i in lst: - t.del_node(i) + t.delete(i) print(str(t))