diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index 9fca72374..96ef19cff 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -196,8 +196,13 @@ def get_left_most(root: MyNode) -> Any: def del_node(root: MyNode, data: Any) -> MyNode | None: + if root is None: + print("Nothing to delete") + return None + left_child = root.get_left() right_child = root.get_right() + if root.get_data() == data: if left_child is not None and right_child is not None: temp_data = get_left_most(right_child) @@ -221,20 +226,23 @@ def del_node(root: MyNode, data: Any) -> MyNode | None: else: root.set_right(del_node(right_child, data)) - if get_height(right_child) - get_height(left_child) == 2: + root.set_height(my_max(get_height(root.get_right()), get_height(root.get_left())) + 1) + + balance_factor = get_height(root.get_left()) - get_height(root.get_right()) + + if balance_factor == 2: assert right_child is not None if get_height(right_child.get_right()) > get_height(right_child.get_left()): root = left_rotation(root) else: root = rl_rotation(root) - elif get_height(right_child) - get_height(left_child) == -2: + elif balance_factor == -2: assert left_child is not None if get_height(left_child.get_left()) > get_height(left_child.get_right()): root = right_rotation(root) else: root = lr_rotation(root) - height = my_max(get_height(root.get_right()), get_height(root.get_left())) + 1 - root.set_height(height) + return root