diff --git a/data_structures/binary_tree/red_black_tree.py b/data_structures/binary_tree/red_black_tree.py index e27757f20..35517f307 100644 --- a/data_structures/binary_tree/red_black_tree.py +++ b/data_structures/binary_tree/red_black_tree.py @@ -51,6 +51,8 @@ class RedBlackTree: """ parent = self.parent right = self.right + if right is None: + return self self.right = right.left if self.right: self.right.parent = self @@ -69,6 +71,8 @@ class RedBlackTree: returns the new root to this subtree. Performing one rotation can be done in O(1). """ + if self.left is None: + return self parent = self.parent left = self.left self.left = left.right @@ -123,23 +127,30 @@ class RedBlackTree: if color(uncle) == 0: if self.is_left() and self.parent.is_right(): self.parent.rotate_right() - self.right._insert_repair() + if self.right: + self.right._insert_repair() elif self.is_right() and self.parent.is_left(): self.parent.rotate_left() - self.left._insert_repair() + if self.left: + self.left._insert_repair() elif self.is_left(): - self.grandparent.rotate_right() - self.parent.color = 0 - self.parent.right.color = 1 + if self.grandparent: + self.grandparent.rotate_right() + self.parent.color = 0 + if self.parent.right: + self.parent.right.color = 1 else: - self.grandparent.rotate_left() - self.parent.color = 0 - self.parent.left.color = 1 + if self.grandparent: + self.grandparent.rotate_left() + self.parent.color = 0 + if self.parent.left: + self.parent.left.color = 1 else: self.parent.color = 0 - uncle.color = 0 - self.grandparent.color = 1 - self.grandparent._insert_repair() + if uncle and self.grandparent: + uncle.color = 0 + self.grandparent.color = 1 + self.grandparent._insert_repair() def remove(self, label: int) -> RedBlackTree: """Remove label from this tree.""" @@ -149,8 +160,9 @@ class RedBlackTree: # so we replace this node with the greatest one less than # it and remove that. value = self.left.get_max() - self.label = value - self.left.remove(value) + if value is not None: + self.label = value + self.left.remove(value) else: # This node has at most one non-None child, so we don't # need to replace @@ -160,10 +172,11 @@ class RedBlackTree: # The only way this happens to a node with one child # is if both children are None leaves. # We can just remove this node and call it a day. - if self.is_left(): - self.parent.left = None - else: - self.parent.right = None + if self.parent: + if self.is_left(): + self.parent.left = None + else: + self.parent.right = None else: # The node is black if child is None: @@ -188,7 +201,7 @@ class RedBlackTree: self.left.parent = self if self.right: self.right.parent = self - elif self.label > label: + elif self.label is not None and self.label > label: if self.left: self.left.remove(label) else: @@ -198,6 +211,13 @@ class RedBlackTree: def _remove_repair(self) -> None: """Repair the coloring of the tree that may have been messed up.""" + if ( + self.parent is None + or self.sibling is None + or self.parent.sibling is None + or self.grandparent is None + ): + return if color(self.sibling) == 1: self.sibling.color = 0 self.parent.color = 1 @@ -231,7 +251,8 @@ class RedBlackTree: ): self.sibling.rotate_right() self.sibling.color = 0 - self.sibling.right.color = 1 + if self.sibling.right: + self.sibling.right.color = 1 if ( self.is_right() and color(self.sibling) == 0 @@ -240,7 +261,8 @@ class RedBlackTree: ): self.sibling.rotate_left() self.sibling.color = 0 - self.sibling.left.color = 1 + if self.sibling.left: + self.sibling.left.color = 1 if ( self.is_left() and color(self.sibling) == 0 @@ -275,21 +297,17 @@ class RedBlackTree: """ # I assume property 1 to hold because there is nothing that can # make the color be anything other than 0 or 1. - # Property 2 if self.color: # The root was red print("Property 2") return False - # Property 3 does not need to be checked, because None is assumed # to be black and is all the leaves. - # Property 4 if not self.check_coloring(): print("Property 4") return False - # Property 5 if self.black_height() is None: print("Property 5") @@ -297,7 +315,7 @@ class RedBlackTree: # All properties were met return True - def check_coloring(self) -> None: + def check_coloring(self) -> bool: """A helper function to recursively check Property 4 of a Red-Black Tree. See check_color_properties for more info. """ @@ -310,12 +328,12 @@ class RedBlackTree: return False return True - def black_height(self) -> int: + def black_height(self) -> int | None: """Returns the number of black nodes from this node to the leaves of the tree, or None if there isn't one such value (the tree is color incorrectly). """ - if self is None: + if self is None or self.left is None or self.right is None: # If we're already at a leaf, there is no path return 1 left = RedBlackTree.black_height(self.left) @@ -332,21 +350,21 @@ class RedBlackTree: # Here are functions which are general to all binary search trees - def __contains__(self, label) -> bool: + def __contains__(self, label: int) -> bool: """Search through the tree for label, returning True iff it is found somewhere in the tree. Guaranteed to run in O(log(n)) time. """ return self.search(label) is not None - def search(self, label: int) -> RedBlackTree: + def search(self, label: int) -> RedBlackTree | None: """Search through the tree for label, returning its node if it's found, and None otherwise. This method is guaranteed to run in O(log(n)) time. """ if self.label == label: return self - elif label > self.label: + elif self.label is not None and label > self.label: if self.right is None: return None else: @@ -357,12 +375,12 @@ class RedBlackTree: else: return self.left.search(label) - def floor(self, label: int) -> int: + def floor(self, label: int) -> int | None: """Returns the largest element in this tree which is at most label. This method is guaranteed to run in O(log(n)) time.""" if self.label == label: return self.label - elif self.label > label: + elif self.label is not None and self.label > label: if self.left: return self.left.floor(label) else: @@ -374,13 +392,13 @@ class RedBlackTree: return attempt return self.label - def ceil(self, label: int) -> int: + def ceil(self, label: int) -> int | None: """Returns the smallest element in this tree which is at least label. This method is guaranteed to run in O(log(n)) time. """ if self.label == label: return self.label - elif self.label < label: + elif self.label is not None and self.label < label: if self.right: return self.right.ceil(label) else: @@ -392,7 +410,7 @@ class RedBlackTree: return attempt return self.label - def get_max(self) -> int: + def get_max(self) -> int | None: """Returns the largest element in this tree. This method is guaranteed to run in O(log(n)) time. """ @@ -402,7 +420,7 @@ class RedBlackTree: else: return self.label - def get_min(self) -> int: + def get_min(self) -> int | None: """Returns the smallest element in this tree. This method is guaranteed to run in O(log(n)) time. """ @@ -413,7 +431,7 @@ class RedBlackTree: return self.label @property - def grandparent(self) -> RedBlackTree: + def grandparent(self) -> RedBlackTree | None: """Get the current node's grandparent, or None if it doesn't exist.""" if self.parent is None: return None @@ -421,7 +439,7 @@ class RedBlackTree: return self.parent.parent @property - def sibling(self) -> RedBlackTree: + def sibling(self) -> RedBlackTree | None: """Get the current node's sibling, or None if it doesn't exist.""" if self.parent is None: return None @@ -432,11 +450,15 @@ class RedBlackTree: def is_left(self) -> bool: """Returns true iff this node is the left child of its parent.""" - return self.parent and self.parent.left is self + if self.parent is None: + return False + return self.parent.left is self.parent.left is self def is_right(self) -> bool: """Returns true iff this node is the right child of its parent.""" - return self.parent and self.parent.right is self + if self.parent is None: + return False + return self.parent.right is self def __bool__(self) -> bool: return True @@ -452,21 +474,21 @@ class RedBlackTree: ln += len(self.right) return ln - def preorder_traverse(self) -> Iterator[int]: + def preorder_traverse(self) -> Iterator[int | None]: yield self.label if self.left: yield from self.left.preorder_traverse() if self.right: yield from self.right.preorder_traverse() - def inorder_traverse(self) -> Iterator[int]: + def inorder_traverse(self) -> Iterator[int | None]: if self.left: yield from self.left.inorder_traverse() yield self.label if self.right: yield from self.right.inorder_traverse() - def postorder_traverse(self) -> Iterator[int]: + def postorder_traverse(self) -> Iterator[int | None]: if self.left: yield from self.left.postorder_traverse() if self.right: @@ -488,15 +510,17 @@ class RedBlackTree: indent=1, ) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: """Test if two trees are equal.""" + if not isinstance(other, RedBlackTree): + return NotImplemented if self.label == other.label: return self.left == other.left and self.right == other.right else: return False -def color(node) -> int: +def color(node: RedBlackTree | None) -> int: """Returns the color of a node, allowing for None leaves.""" if node is None: return 0 @@ -699,19 +723,12 @@ def main() -> None: >>> pytests() """ print_results("Rotating right and left", test_rotations()) - print_results("Inserting", test_insert()) - print_results("Searching", test_insert_and_search()) - print_results("Deleting", test_insert_delete()) - print_results("Floor and ceil", test_floor_ceil()) - print_results("Tree traversal", test_tree_traversal()) - print_results("Tree traversal", test_tree_chaining()) - print("Testing tree balancing...") print("This should only be a few seconds.") test_insertion_speed()