[mypy] Fix type annotations in data_structures/binary_tree/red_black_tree.py (#5739)

* [mypy] Fix type annotations in red_black_tree.py

* Remove blank lines

* Update red_black_tree.py
This commit is contained in:
Dylan Buchi 2021-11-04 12:38:43 -03:00 committed by GitHub
parent e835e96856
commit 7a605766fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -51,6 +51,8 @@ class RedBlackTree:
""" """
parent = self.parent parent = self.parent
right = self.right right = self.right
if right is None:
return self
self.right = right.left self.right = right.left
if self.right: if self.right:
self.right.parent = self self.right.parent = self
@ -69,6 +71,8 @@ class RedBlackTree:
returns the new root to this subtree. returns the new root to this subtree.
Performing one rotation can be done in O(1). Performing one rotation can be done in O(1).
""" """
if self.left is None:
return self
parent = self.parent parent = self.parent
left = self.left left = self.left
self.left = left.right self.left = left.right
@ -123,23 +127,30 @@ class RedBlackTree:
if color(uncle) == 0: if color(uncle) == 0:
if self.is_left() and self.parent.is_right(): if self.is_left() and self.parent.is_right():
self.parent.rotate_right() self.parent.rotate_right()
self.right._insert_repair() if self.right:
self.right._insert_repair()
elif self.is_right() and self.parent.is_left(): elif self.is_right() and self.parent.is_left():
self.parent.rotate_left() self.parent.rotate_left()
self.left._insert_repair() if self.left:
self.left._insert_repair()
elif self.is_left(): elif self.is_left():
self.grandparent.rotate_right() if self.grandparent:
self.parent.color = 0 self.grandparent.rotate_right()
self.parent.right.color = 1 self.parent.color = 0
if self.parent.right:
self.parent.right.color = 1
else: else:
self.grandparent.rotate_left() if self.grandparent:
self.parent.color = 0 self.grandparent.rotate_left()
self.parent.left.color = 1 self.parent.color = 0
if self.parent.left:
self.parent.left.color = 1
else: else:
self.parent.color = 0 self.parent.color = 0
uncle.color = 0 if uncle and self.grandparent:
self.grandparent.color = 1 uncle.color = 0
self.grandparent._insert_repair() self.grandparent.color = 1
self.grandparent._insert_repair()
def remove(self, label: int) -> RedBlackTree: def remove(self, label: int) -> RedBlackTree:
"""Remove label from this tree.""" """Remove label from this tree."""
@ -149,8 +160,9 @@ class RedBlackTree:
# so we replace this node with the greatest one less than # so we replace this node with the greatest one less than
# it and remove that. # it and remove that.
value = self.left.get_max() value = self.left.get_max()
self.label = value if value is not None:
self.left.remove(value) self.label = value
self.left.remove(value)
else: else:
# This node has at most one non-None child, so we don't # This node has at most one non-None child, so we don't
# need to replace # need to replace
@ -160,10 +172,11 @@ class RedBlackTree:
# The only way this happens to a node with one child # The only way this happens to a node with one child
# is if both children are None leaves. # is if both children are None leaves.
# We can just remove this node and call it a day. # We can just remove this node and call it a day.
if self.is_left(): if self.parent:
self.parent.left = None if self.is_left():
else: self.parent.left = None
self.parent.right = None else:
self.parent.right = None
else: else:
# The node is black # The node is black
if child is None: if child is None:
@ -188,7 +201,7 @@ class RedBlackTree:
self.left.parent = self self.left.parent = self
if self.right: if self.right:
self.right.parent = self self.right.parent = self
elif self.label > label: elif self.label is not None and self.label > label:
if self.left: if self.left:
self.left.remove(label) self.left.remove(label)
else: else:
@ -198,6 +211,13 @@ class RedBlackTree:
def _remove_repair(self) -> None: def _remove_repair(self) -> None:
"""Repair the coloring of the tree that may have been messed up.""" """Repair the coloring of the tree that may have been messed up."""
if (
self.parent is None
or self.sibling is None
or self.parent.sibling is None
or self.grandparent is None
):
return
if color(self.sibling) == 1: if color(self.sibling) == 1:
self.sibling.color = 0 self.sibling.color = 0
self.parent.color = 1 self.parent.color = 1
@ -231,7 +251,8 @@ class RedBlackTree:
): ):
self.sibling.rotate_right() self.sibling.rotate_right()
self.sibling.color = 0 self.sibling.color = 0
self.sibling.right.color = 1 if self.sibling.right:
self.sibling.right.color = 1
if ( if (
self.is_right() self.is_right()
and color(self.sibling) == 0 and color(self.sibling) == 0
@ -240,7 +261,8 @@ class RedBlackTree:
): ):
self.sibling.rotate_left() self.sibling.rotate_left()
self.sibling.color = 0 self.sibling.color = 0
self.sibling.left.color = 1 if self.sibling.left:
self.sibling.left.color = 1
if ( if (
self.is_left() self.is_left()
and color(self.sibling) == 0 and color(self.sibling) == 0
@ -275,21 +297,17 @@ class RedBlackTree:
""" """
# I assume property 1 to hold because there is nothing that can # I assume property 1 to hold because there is nothing that can
# make the color be anything other than 0 or 1. # make the color be anything other than 0 or 1.
# Property 2 # Property 2
if self.color: if self.color:
# The root was red # The root was red
print("Property 2") print("Property 2")
return False return False
# Property 3 does not need to be checked, because None is assumed # Property 3 does not need to be checked, because None is assumed
# to be black and is all the leaves. # to be black and is all the leaves.
# Property 4 # Property 4
if not self.check_coloring(): if not self.check_coloring():
print("Property 4") print("Property 4")
return False return False
# Property 5 # Property 5
if self.black_height() is None: if self.black_height() is None:
print("Property 5") print("Property 5")
@ -297,7 +315,7 @@ class RedBlackTree:
# All properties were met # All properties were met
return True return True
def check_coloring(self) -> None: def check_coloring(self) -> bool:
"""A helper function to recursively check Property 4 of a """A helper function to recursively check Property 4 of a
Red-Black Tree. See check_color_properties for more info. Red-Black Tree. See check_color_properties for more info.
""" """
@ -310,12 +328,12 @@ class RedBlackTree:
return False return False
return True return True
def black_height(self) -> int: def black_height(self) -> int | None:
"""Returns the number of black nodes from this node to the """Returns the number of black nodes from this node to the
leaves of the tree, or None if there isn't one such value (the leaves of the tree, or None if there isn't one such value (the
tree is color incorrectly). tree is color incorrectly).
""" """
if self is None: if self is None or self.left is None or self.right is None:
# If we're already at a leaf, there is no path # If we're already at a leaf, there is no path
return 1 return 1
left = RedBlackTree.black_height(self.left) left = RedBlackTree.black_height(self.left)
@ -332,21 +350,21 @@ class RedBlackTree:
# Here are functions which are general to all binary search trees # Here are functions which are general to all binary search trees
def __contains__(self, label) -> bool: def __contains__(self, label: int) -> bool:
"""Search through the tree for label, returning True iff it is """Search through the tree for label, returning True iff it is
found somewhere in the tree. found somewhere in the tree.
Guaranteed to run in O(log(n)) time. Guaranteed to run in O(log(n)) time.
""" """
return self.search(label) is not None return self.search(label) is not None
def search(self, label: int) -> RedBlackTree: def search(self, label: int) -> RedBlackTree | None:
"""Search through the tree for label, returning its node if """Search through the tree for label, returning its node if
it's found, and None otherwise. it's found, and None otherwise.
This method is guaranteed to run in O(log(n)) time. This method is guaranteed to run in O(log(n)) time.
""" """
if self.label == label: if self.label == label:
return self return self
elif label > self.label: elif self.label is not None and label > self.label:
if self.right is None: if self.right is None:
return None return None
else: else:
@ -357,12 +375,12 @@ class RedBlackTree:
else: else:
return self.left.search(label) return self.left.search(label)
def floor(self, label: int) -> int: def floor(self, label: int) -> int | None:
"""Returns the largest element in this tree which is at most label. """Returns the largest element in this tree which is at most label.
This method is guaranteed to run in O(log(n)) time.""" This method is guaranteed to run in O(log(n)) time."""
if self.label == label: if self.label == label:
return self.label return self.label
elif self.label > label: elif self.label is not None and self.label > label:
if self.left: if self.left:
return self.left.floor(label) return self.left.floor(label)
else: else:
@ -374,13 +392,13 @@ class RedBlackTree:
return attempt return attempt
return self.label return self.label
def ceil(self, label: int) -> int: def ceil(self, label: int) -> int | None:
"""Returns the smallest element in this tree which is at least label. """Returns the smallest element in this tree which is at least label.
This method is guaranteed to run in O(log(n)) time. This method is guaranteed to run in O(log(n)) time.
""" """
if self.label == label: if self.label == label:
return self.label return self.label
elif self.label < label: elif self.label is not None and self.label < label:
if self.right: if self.right:
return self.right.ceil(label) return self.right.ceil(label)
else: else:
@ -392,7 +410,7 @@ class RedBlackTree:
return attempt return attempt
return self.label return self.label
def get_max(self) -> int: def get_max(self) -> int | None:
"""Returns the largest element in this tree. """Returns the largest element in this tree.
This method is guaranteed to run in O(log(n)) time. This method is guaranteed to run in O(log(n)) time.
""" """
@ -402,7 +420,7 @@ class RedBlackTree:
else: else:
return self.label return self.label
def get_min(self) -> int: def get_min(self) -> int | None:
"""Returns the smallest element in this tree. """Returns the smallest element in this tree.
This method is guaranteed to run in O(log(n)) time. This method is guaranteed to run in O(log(n)) time.
""" """
@ -413,7 +431,7 @@ class RedBlackTree:
return self.label return self.label
@property @property
def grandparent(self) -> RedBlackTree: def grandparent(self) -> RedBlackTree | None:
"""Get the current node's grandparent, or None if it doesn't exist.""" """Get the current node's grandparent, or None if it doesn't exist."""
if self.parent is None: if self.parent is None:
return None return None
@ -421,7 +439,7 @@ class RedBlackTree:
return self.parent.parent return self.parent.parent
@property @property
def sibling(self) -> RedBlackTree: def sibling(self) -> RedBlackTree | None:
"""Get the current node's sibling, or None if it doesn't exist.""" """Get the current node's sibling, or None if it doesn't exist."""
if self.parent is None: if self.parent is None:
return None return None
@ -432,11 +450,15 @@ class RedBlackTree:
def is_left(self) -> bool: def is_left(self) -> bool:
"""Returns true iff this node is the left child of its parent.""" """Returns true iff this node is the left child of its parent."""
return self.parent and self.parent.left is self if self.parent is None:
return False
return self.parent.left is self.parent.left is self
def is_right(self) -> bool: def is_right(self) -> bool:
"""Returns true iff this node is the right child of its parent.""" """Returns true iff this node is the right child of its parent."""
return self.parent and self.parent.right is self if self.parent is None:
return False
return self.parent.right is self
def __bool__(self) -> bool: def __bool__(self) -> bool:
return True return True
@ -452,21 +474,21 @@ class RedBlackTree:
ln += len(self.right) ln += len(self.right)
return ln return ln
def preorder_traverse(self) -> Iterator[int]: def preorder_traverse(self) -> Iterator[int | None]:
yield self.label yield self.label
if self.left: if self.left:
yield from self.left.preorder_traverse() yield from self.left.preorder_traverse()
if self.right: if self.right:
yield from self.right.preorder_traverse() yield from self.right.preorder_traverse()
def inorder_traverse(self) -> Iterator[int]: def inorder_traverse(self) -> Iterator[int | None]:
if self.left: if self.left:
yield from self.left.inorder_traverse() yield from self.left.inorder_traverse()
yield self.label yield self.label
if self.right: if self.right:
yield from self.right.inorder_traverse() yield from self.right.inorder_traverse()
def postorder_traverse(self) -> Iterator[int]: def postorder_traverse(self) -> Iterator[int | None]:
if self.left: if self.left:
yield from self.left.postorder_traverse() yield from self.left.postorder_traverse()
if self.right: if self.right:
@ -488,15 +510,17 @@ class RedBlackTree:
indent=1, indent=1,
) )
def __eq__(self, other) -> bool: def __eq__(self, other: object) -> bool:
"""Test if two trees are equal.""" """Test if two trees are equal."""
if not isinstance(other, RedBlackTree):
return NotImplemented
if self.label == other.label: if self.label == other.label:
return self.left == other.left and self.right == other.right return self.left == other.left and self.right == other.right
else: else:
return False return False
def color(node) -> int: def color(node: RedBlackTree | None) -> int:
"""Returns the color of a node, allowing for None leaves.""" """Returns the color of a node, allowing for None leaves."""
if node is None: if node is None:
return 0 return 0
@ -699,19 +723,12 @@ def main() -> None:
>>> pytests() >>> pytests()
""" """
print_results("Rotating right and left", test_rotations()) print_results("Rotating right and left", test_rotations())
print_results("Inserting", test_insert()) print_results("Inserting", test_insert())
print_results("Searching", test_insert_and_search()) print_results("Searching", test_insert_and_search())
print_results("Deleting", test_insert_delete()) print_results("Deleting", test_insert_delete())
print_results("Floor and ceil", test_floor_ceil()) print_results("Floor and ceil", test_floor_ceil())
print_results("Tree traversal", test_tree_traversal()) print_results("Tree traversal", test_tree_traversal())
print_results("Tree traversal", test_tree_chaining()) print_results("Tree traversal", test_tree_chaining())
print("Testing tree balancing...") print("Testing tree balancing...")
print("This should only be a few seconds.") print("This should only be a few seconds.")
test_insertion_speed() test_insertion_speed()