Update binary_search_tree.py

This commit is contained in:
Christian Clauss 2023-10-27 01:10:38 +02:00 committed by GitHub
parent fe4aad0ec9
commit 6937b4b258
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,74 +10,65 @@ Example
/ \ / / \ /
4 7 13 4 7 13
>>> t = BinarySearchTree() >>> tree = BinarySearchTree()
>>> t.insert(8, 3, 6, 1, 10, 14, 13, 4, 7) >>> tree.insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
>>> print(" ".join(repr(i.value) for i in t.traversal_tree())) >>> tuple(node.value for node in tree.traversal_tree()) # inorder traversal (sorted)
8 3 1 6 4 7 10 14 13
>>> tuple(i.value for i in t.traversal_tree(inorder))
(1, 3, 4, 6, 7, 8, 10, 13, 14) (1, 3, 4, 6, 7, 8, 10, 13, 14)
>>> tuple(t) >>> tuple(node.value for node in tree.traversal_tree(postorder))
(1, 3, 4, 6, 7, 8, 10, 13, 14) (1, 4, 7, 6, 3, 13, 14, 10, 8)
>>> t.find_kth_smallest(3, t.root)
4
>>> tuple(t)[3-1]
4
>>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder))) >>> tuple(tree)
1 4 7 6 3 13 14 10 8 (1, 3, 4, 6, 7, 8, 10, 13, 14)
>>> t.remove(20) >>> iter_t = iter(tree)
>>> next(iter_t)
1
>>> next(iter_t)
3
>>> tuple(tree)[3-1] # 3rd smallest element in a zero-indexed tuple
4
>>> sum(tree)
66
>>> tuple(node.value for node in tree.traversal_tree(postorder))
(1, 4, 7, 6, 3, 13, 14, 10, 8)
>>> tree.remove(20)
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: Value 20 not found ValueError: Value 20 not found
>>> BinarySearchTree().search(6)
Traceback (most recent call last):
...
IndexError: Warning: Tree is empty! please use another.
Other example: Other example:
>>> testlist = (8, 3, 6, 1, 10, 14, 13, 4, 7) >>> values = (8, 3, 6, 1, 10, 14, 13, 4, 7)
>>> t = BinarySearchTree() >>> tree = BinarySearchTree()
>>> for i in testlist: >>> for value in values:
... t.insert(i) ... tree.insert(value)
Prints all the elements of the list in order traversal Prints all the elements of the list in order traversal
>>> print(t) >>> print(tree)
{'8': ({'3': (1, {'6': (4, 7)})}, {'10': (None, {'14': (13, None)})})} {'8': ({'3': (1, {'6': (4, 7)})}, {'10': (None, {'14': (13, None)})})}
Test existence Test existence
>>> t.search(6) is not None >>> 6 in tree
True True
>>> 6 in t >>> -1 in tree
True
>>> t.search(-1) is not None
False
>>> -1 in t
False False
>>> t.search(6).is_right >>> tree.search(6).is_right
True True
>>> t.search(1).is_right >>> tree.search(1).is_right
False False
>>> t.get_max().value >>> max(tree)
14 14
>>> max(t) >>> min(tree)
14
>>> t.get_min().value
1 1
>>> min(t) >>> not tree
1
>>> t.empty()
False False
>>> not t >>> for value in values:
False ... tree.remove(value)
>>> for i in testlist: >>> list(tree)
... t.remove(i) []
>>> t.empty() >>> not tree
True
>>> not t
True True
""" """
from __future__ import annotations from __future__ import annotations
@ -144,15 +135,12 @@ class BinarySearchTree:
else: else:
self.root = new_children self.root = new_children
def empty(self) -> bool:
return self.root is None
def __insert(self, value) -> None: def __insert(self, value) -> None:
""" """
Insert a new node in Binary Search Tree with value label Insert a new node in Binary Search Tree with value label
""" """
new_node = Node(value) # create a new Node new_node = Node(value) # create a new Node
if self.empty(): # if Tree is empty if not self: # if Tree is empty
self.root = new_node # set its root self.root = new_node # set its root
else: # Tree is not empty else: # Tree is not empty
parent_node = self.root # from root parent_node = self.root # from root
@ -178,47 +166,32 @@ class BinarySearchTree:
self.__insert(value) self.__insert(value)
def search(self, value) -> Node | None: def search(self, value) -> Node | None:
if self.empty(): if not self:
raise IndexError("Warning: Tree is empty! please use another.") raise IndexError("Warning: Tree is empty! please use another.")
else: node = self.root
node = self.root # use lazy evaluation here to avoid NoneType Attribute error
# use lazy evaluation here to avoid NoneType Attribute error while node and node.value is not value:
while node is not None and node.value is not value: node = node.left if value < node.value else node.right
node = node.left if value < node.value else node.right return node
return node
def get_max(self, node: Node | None = None) -> Node | None: def get_max(self, node: Node | None = None) -> Node | None:
""" """
We go deep on the right branch We go deep on the right branch
""" """
if node is None: if node is None:
if self.root is None: if not self.root:
return None return None
node = self.root node = self.root
if not self.empty(): if self:
while node.right is not None: while node.right is not None:
node = node.right node = node.right
return node return node
def get_min(self, node: Node | None = None) -> Node | None:
"""
We go deep on the left branch
"""
if node is None:
node = self.root
if self.root is None:
return None
if not self.empty():
node = self.root
while node.left is not None:
node = node.left
return node
def remove(self, value: int) -> None: def remove(self, value: int) -> None:
# Look for the node with that label # Look for the node with that label
node = self.search(value) node = self.search(value)
if node is None: if not node:
msg = f"Value {value} not found" msg = f"Value {value} not found"
raise ValueError(msg) raise ValueError(msg)
@ -229,29 +202,18 @@ class BinarySearchTree:
elif node.right is None: # Has only left children elif node.right is None: # Has only left children
self.__reassign_nodes(node, node.left) self.__reassign_nodes(node, node.left)
else: else:
predecessor = self.get_max( # Gets the max value of the left branch
node.left predecessor = self.get_max(node.left)
) # Gets the max value of the left branch
self.remove(predecessor.value) # type: ignore self.remove(predecessor.value) # type: ignore
node.value = ( # Assigns the value to the node to delete and keep tree structure
predecessor.value # type: ignore node.value = predecessor.value # type: ignore
) # Assigns the value to the node to delete and keep tree structure
def preorder_traverse(self, node: Node | None) -> Iterable: @classmethod
if node is not None: def preorder_traverse(cls, node: Node | None) -> Iterable:
if node:
yield node # Preorder Traversal yield node # Preorder Traversal
yield from self.preorder_traverse(node.left) yield from cls.preorder_traverse(node.left)
yield from self.preorder_traverse(node.right) yield from cls.preorder_traverse(node.right)
def traversal_tree(self, traversal_function=None) -> Any:
"""
This function traversal the tree.
You can pass a function to traversal the tree as needed by client code
"""
if traversal_function is None:
return self.preorder_traverse(self.root)
else:
return traversal_function(self.root)
def inorder(self, arr: list, node: Node | None) -> None: def inorder(self, arr: list, node: Node | None) -> None:
"""Perform an inorder traversal and append values of the nodes to """Perform an inorder traversal and append values of the nodes to
@ -261,11 +223,12 @@ class BinarySearchTree:
arr.append(node.value) arr.append(node.value)
self.inorder(arr, node.right) self.inorder(arr, node.right)
def find_kth_smallest(self, k: int, node: Node) -> int: def traversal_tree(self, traversal_function=None) -> Any:
"""Return the kth smallest element in a binary search tree""" """
arr: list[int] = [] This function traversal the tree.
self.inorder(arr, node) # append all values to list using inorder traversal You can pass a function to traversal the tree as needed by client code
return arr[k - 1] """
return (traversal_function or inorder)(self.root)
def inorder(curr_node: Node | None) -> list[Node]: def inorder(curr_node: Node | None) -> list[Node]:
@ -273,17 +236,17 @@ def inorder(curr_node: Node | None) -> list[Node]:
inorder (left, self, right) inorder (left, self, right)
""" """
node_list = [] node_list = []
if curr_node is not None: if curr_node:
node_list = inorder(curr_node.left) + [curr_node] + inorder(curr_node.right) node_list = inorder(curr_node.left) + [curr_node] + inorder(curr_node.right)
return node_list return node_list
def postorder(curr_node: Node | None) -> list[Node]: def postorder(curr_node: Node | None) -> list[Node]:
""" """
postOrder (left, right, self) postorder (left, right, self)
""" """
node_list = [] node_list = []
if curr_node is not None: if curr_node:
node_list = postorder(curr_node.left) + postorder(curr_node.right) + [curr_node] node_list = postorder(curr_node.left) + postorder(curr_node.right) + [curr_node]
return node_list return node_list