[mypy] Fix type annotations in data_structures/binary_tree/red_black_tree.py (#5739)

* [mypy] Fix type annotations in red_black_tree.py

* Remove blank lines

* Update red_black_tree.py
This commit is contained in:
Dylan Buchi 2021-11-04 12:38:43 -03:00 committed by GitHub
parent e835e96856
commit 7a605766fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -51,6 +51,8 @@ class RedBlackTree:
"""
parent = self.parent
right = self.right
if right is None:
return self
self.right = right.left
if self.right:
self.right.parent = self
@ -69,6 +71,8 @@ class RedBlackTree:
returns the new root to this subtree.
Performing one rotation can be done in O(1).
"""
if self.left is None:
return self
parent = self.parent
left = self.left
self.left = left.right
@ -123,23 +127,30 @@ class RedBlackTree:
if color(uncle) == 0:
if self.is_left() and self.parent.is_right():
self.parent.rotate_right()
self.right._insert_repair()
if self.right:
self.right._insert_repair()
elif self.is_right() and self.parent.is_left():
self.parent.rotate_left()
self.left._insert_repair()
if self.left:
self.left._insert_repair()
elif self.is_left():
self.grandparent.rotate_right()
self.parent.color = 0
self.parent.right.color = 1
if self.grandparent:
self.grandparent.rotate_right()
self.parent.color = 0
if self.parent.right:
self.parent.right.color = 1
else:
self.grandparent.rotate_left()
self.parent.color = 0
self.parent.left.color = 1
if self.grandparent:
self.grandparent.rotate_left()
self.parent.color = 0
if self.parent.left:
self.parent.left.color = 1
else:
self.parent.color = 0
uncle.color = 0
self.grandparent.color = 1
self.grandparent._insert_repair()
if uncle and self.grandparent:
uncle.color = 0
self.grandparent.color = 1
self.grandparent._insert_repair()
def remove(self, label: int) -> RedBlackTree:
"""Remove label from this tree."""
@ -149,8 +160,9 @@ class RedBlackTree:
# so we replace this node with the greatest one less than
# it and remove that.
value = self.left.get_max()
self.label = value
self.left.remove(value)
if value is not None:
self.label = value
self.left.remove(value)
else:
# This node has at most one non-None child, so we don't
# need to replace
@ -160,10 +172,11 @@ class RedBlackTree:
# The only way this happens to a node with one child
# is if both children are None leaves.
# We can just remove this node and call it a day.
if self.is_left():
self.parent.left = None
else:
self.parent.right = None
if self.parent:
if self.is_left():
self.parent.left = None
else:
self.parent.right = None
else:
# The node is black
if child is None:
@ -188,7 +201,7 @@ class RedBlackTree:
self.left.parent = self
if self.right:
self.right.parent = self
elif self.label > label:
elif self.label is not None and self.label > label:
if self.left:
self.left.remove(label)
else:
@ -198,6 +211,13 @@ class RedBlackTree:
def _remove_repair(self) -> None:
"""Repair the coloring of the tree that may have been messed up."""
if (
self.parent is None
or self.sibling is None
or self.parent.sibling is None
or self.grandparent is None
):
return
if color(self.sibling) == 1:
self.sibling.color = 0
self.parent.color = 1
@ -231,7 +251,8 @@ class RedBlackTree:
):
self.sibling.rotate_right()
self.sibling.color = 0
self.sibling.right.color = 1
if self.sibling.right:
self.sibling.right.color = 1
if (
self.is_right()
and color(self.sibling) == 0
@ -240,7 +261,8 @@ class RedBlackTree:
):
self.sibling.rotate_left()
self.sibling.color = 0
self.sibling.left.color = 1
if self.sibling.left:
self.sibling.left.color = 1
if (
self.is_left()
and color(self.sibling) == 0
@ -275,21 +297,17 @@ class RedBlackTree:
"""
# I assume property 1 to hold because there is nothing that can
# make the color be anything other than 0 or 1.
# Property 2
if self.color:
# The root was red
print("Property 2")
return False
# Property 3 does not need to be checked, because None is assumed
# to be black and is all the leaves.
# Property 4
if not self.check_coloring():
print("Property 4")
return False
# Property 5
if self.black_height() is None:
print("Property 5")
@ -297,7 +315,7 @@ class RedBlackTree:
# All properties were met
return True
def check_coloring(self) -> None:
def check_coloring(self) -> bool:
"""A helper function to recursively check Property 4 of a
Red-Black Tree. See check_color_properties for more info.
"""
@ -310,12 +328,12 @@ class RedBlackTree:
return False
return True
def black_height(self) -> int:
def black_height(self) -> int | None:
"""Returns the number of black nodes from this node to the
leaves of the tree, or None if there isn't one such value (the
tree is color incorrectly).
"""
if self is None:
if self is None or self.left is None or self.right is None:
# If we're already at a leaf, there is no path
return 1
left = RedBlackTree.black_height(self.left)
@ -332,21 +350,21 @@ class RedBlackTree:
# Here are functions which are general to all binary search trees
def __contains__(self, label) -> bool:
def __contains__(self, label: int) -> bool:
"""Search through the tree for label, returning True iff it is
found somewhere in the tree.
Guaranteed to run in O(log(n)) time.
"""
return self.search(label) is not None
def search(self, label: int) -> RedBlackTree:
def search(self, label: int) -> RedBlackTree | None:
"""Search through the tree for label, returning its node if
it's found, and None otherwise.
This method is guaranteed to run in O(log(n)) time.
"""
if self.label == label:
return self
elif label > self.label:
elif self.label is not None and label > self.label:
if self.right is None:
return None
else:
@ -357,12 +375,12 @@ class RedBlackTree:
else:
return self.left.search(label)
def floor(self, label: int) -> int:
def floor(self, label: int) -> int | None:
"""Returns the largest element in this tree which is at most label.
This method is guaranteed to run in O(log(n)) time."""
if self.label == label:
return self.label
elif self.label > label:
elif self.label is not None and self.label > label:
if self.left:
return self.left.floor(label)
else:
@ -374,13 +392,13 @@ class RedBlackTree:
return attempt
return self.label
def ceil(self, label: int) -> int:
def ceil(self, label: int) -> int | None:
"""Returns the smallest element in this tree which is at least label.
This method is guaranteed to run in O(log(n)) time.
"""
if self.label == label:
return self.label
elif self.label < label:
elif self.label is not None and self.label < label:
if self.right:
return self.right.ceil(label)
else:
@ -392,7 +410,7 @@ class RedBlackTree:
return attempt
return self.label
def get_max(self) -> int:
def get_max(self) -> int | None:
"""Returns the largest element in this tree.
This method is guaranteed to run in O(log(n)) time.
"""
@ -402,7 +420,7 @@ class RedBlackTree:
else:
return self.label
def get_min(self) -> int:
def get_min(self) -> int | None:
"""Returns the smallest element in this tree.
This method is guaranteed to run in O(log(n)) time.
"""
@ -413,7 +431,7 @@ class RedBlackTree:
return self.label
@property
def grandparent(self) -> RedBlackTree:
def grandparent(self) -> RedBlackTree | None:
"""Get the current node's grandparent, or None if it doesn't exist."""
if self.parent is None:
return None
@ -421,7 +439,7 @@ class RedBlackTree:
return self.parent.parent
@property
def sibling(self) -> RedBlackTree:
def sibling(self) -> RedBlackTree | None:
"""Get the current node's sibling, or None if it doesn't exist."""
if self.parent is None:
return None
@ -432,11 +450,15 @@ class RedBlackTree:
def is_left(self) -> bool:
"""Returns true iff this node is the left child of its parent."""
return self.parent and self.parent.left is self
if self.parent is None:
return False
return self.parent.left is self.parent.left is self
def is_right(self) -> bool:
"""Returns true iff this node is the right child of its parent."""
return self.parent and self.parent.right is self
if self.parent is None:
return False
return self.parent.right is self
def __bool__(self) -> bool:
return True
@ -452,21 +474,21 @@ class RedBlackTree:
ln += len(self.right)
return ln
def preorder_traverse(self) -> Iterator[int]:
def preorder_traverse(self) -> Iterator[int | None]:
yield self.label
if self.left:
yield from self.left.preorder_traverse()
if self.right:
yield from self.right.preorder_traverse()
def inorder_traverse(self) -> Iterator[int]:
def inorder_traverse(self) -> Iterator[int | None]:
if self.left:
yield from self.left.inorder_traverse()
yield self.label
if self.right:
yield from self.right.inorder_traverse()
def postorder_traverse(self) -> Iterator[int]:
def postorder_traverse(self) -> Iterator[int | None]:
if self.left:
yield from self.left.postorder_traverse()
if self.right:
@ -488,15 +510,17 @@ class RedBlackTree:
indent=1,
)
def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Test if two trees are equal."""
if not isinstance(other, RedBlackTree):
return NotImplemented
if self.label == other.label:
return self.left == other.left and self.right == other.right
else:
return False
def color(node) -> int:
def color(node: RedBlackTree | None) -> int:
"""Returns the color of a node, allowing for None leaves."""
if node is None:
return 0
@ -699,19 +723,12 @@ def main() -> None:
>>> pytests()
"""
print_results("Rotating right and left", test_rotations())
print_results("Inserting", test_insert())
print_results("Searching", test_insert_and_search())
print_results("Deleting", test_insert_delete())
print_results("Floor and ceil", test_floor_ceil())
print_results("Tree traversal", test_tree_traversal())
print_results("Tree traversal", test_tree_chaining())
print("Testing tree balancing...")
print("This should only be a few seconds.")
test_insertion_speed()