Python/data_structures/binary_tree/binary_search_tree_recursive.py
Joan Martin Miralles a38e143cf8
Binary search tree using recursion (#1839)
* Binary search tree using recursion

* updating DIRECTORY.md

Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
2020-04-07 18:09:05 +02:00

614 lines
15 KiB
Python

"""
This is a Python 3 implementation of a binary search tree using recursion.

To run the tests:
    python -m unittest binary_search_tree_recursive.py

To run the example:
    python binary_search_tree_recursive.py
"""
import unittest
from typing import Iterator, Optional
class Node:
    """
    One element of the binary search tree.

    Stores the label together with a link to the parent node and links to
    the left and right children.
    """

    def __init__(self, label: int, parent):
        # The value held by this node and the upward link given by the caller.
        self.label = label
        self.parent = parent
        # A freshly created node has no children yet.
        self.left = self.right = None
class BinarySearchTree:
    """
    Binary search tree of unique labels.

    Insertion, search and the traversals are implemented recursively.
    ``root`` is the topmost node, or None while the tree is empty.
    """

    def __init__(self) -> None:
        self.root: Optional["Node"] = None

    def empty(self) -> None:
        """
        Empties the tree

        >>> t = BinarySearchTree()
        >>> assert t.root is None
        >>> t.put(8)
        >>> assert t.root is not None
        >>> t.empty()
        >>> assert t.root is None
        """
        self.root = None

    def is_empty(self) -> bool:
        """
        Checks if the tree is empty

        >>> t = BinarySearchTree()
        >>> t.is_empty()
        True
        >>> t.put(8)
        >>> t.is_empty()
        False
        """
        return self.root is None

    def put(self, label: int) -> None:
        """
        Put a new node in the tree

        Raises Exception if a node with the same label already exists.

        >>> t = BinarySearchTree()
        >>> t.put(8)
        >>> assert t.root.parent is None
        >>> assert t.root.label == 8

        >>> t.put(10)
        >>> assert t.root.right.parent == t.root
        >>> assert t.root.right.label == 10

        >>> t.put(3)
        >>> assert t.root.left.parent == t.root
        >>> assert t.root.left.label == 3
        """
        self.root = self._put(self.root, label)

    def _put(
        self,
        node: Optional["Node"],
        label: int,
        parent: Optional["Node"] = None,
    ) -> "Node":
        """
        Insert ``label`` below ``node`` and return the root of that subtree.

        ``parent`` is the node the new leaf is attached to when ``node`` is
        None.  Raises Exception when the label is already present.
        """
        if node is None:
            # Reached an empty slot: this is where the new leaf belongs.
            return Node(label, parent)
        if label < node.label:
            node.left = self._put(node.left, label, node)
        elif label > node.label:
            node.right = self._put(node.right, label, node)
        else:
            raise Exception(f"Node with label {label} already exists")
        return node

    def search(self, label: int) -> "Node":
        """
        Searches a node in the tree

        Raises Exception if no node carries ``label``.

        >>> t = BinarySearchTree()
        >>> t.put(8)
        >>> t.put(10)
        >>> node = t.search(8)
        >>> assert node.label == 8

        >>> node = t.search(3)
        Traceback (most recent call last):
            ...
        Exception: Node with label 3 does not exist
        """
        return self._search(self.root, label)

    def _search(self, node: Optional["Node"], label: int) -> "Node":
        """Return the node labelled ``label`` in the subtree rooted at ``node``."""
        if node is None:
            raise Exception(f"Node with label {label} does not exist")
        if label < node.label:
            return self._search(node.left, label)
        if label > node.label:
            return self._search(node.right, label)
        return node

    def remove(self, label: int) -> None:
        """
        Removes a node in the tree

        Raises Exception if no node carries ``label``.

        >>> t = BinarySearchTree()
        >>> t.put(8)
        >>> t.put(10)
        >>> t.remove(8)
        >>> assert t.root.label == 10

        >>> t.remove(3)
        Traceback (most recent call last):
            ...
        Exception: Node with label 3 does not exist
        """
        node = self.search(label)

        if node.left is None or node.right is None:
            # Zero or one child: that child (or None) simply takes node's place.
            only_child = node.left if node.left is not None else node.right
            self._reassign_nodes(node, only_child)
            return

        # Two children: the smallest node of the right subtree replaces node.
        # _get_lowest_node() has already unlinked it from its old position.
        successor = self._get_lowest_node(node.right)
        successor.left = node.left
        successor.right = node.right
        node.left.parent = successor
        # node.right is None here when the successor was node's right child
        # and had no right subtree of its own.
        if node.right is not None:
            node.right.parent = successor
        self._reassign_nodes(node, successor)

    def _reassign_nodes(self, node: "Node", new_children: Optional["Node"]) -> None:
        """
        Put ``new_children`` (possibly None) where ``node`` hangs in the tree.

        Updates the parent's child link, or ``self.root`` when ``node`` has no
        parent, and the replacement's parent link.
        """
        parent = node.parent
        if new_children is not None:
            new_children.parent = parent
        if parent is None:
            self.root = new_children
        elif parent.right is node:
            parent.right = new_children
        else:
            parent.left = new_children

    def _get_lowest_node(self, node: "Node") -> "Node":
        """
        Unlink and return the smallest node of the subtree rooted at ``node``.

        The returned node's right child, if any, takes its former position.
        """
        if node.left is not None:
            return self._get_lowest_node(node.left)
        self._reassign_nodes(node, node.right)
        return node

    def exists(self, label: int) -> bool:
        """
        Checks if a node exists in the tree

        The tree is walked directly instead of calling search() and catching
        its exception, so an unrelated failure (for instance a RecursionError
        on a very deep tree) is never misreported as a missing label.

        >>> t = BinarySearchTree()
        >>> t.put(8)
        >>> t.put(10)
        >>> t.exists(8)
        True

        >>> t.exists(3)
        False
        """
        node = self.root
        while node is not None:
            if label < node.label:
                node = node.left
            elif label > node.label:
                node = node.right
            else:
                return True
        return False

    def get_max_label(self) -> int:
        """
        Gets the max label inserted in the tree

        Raises Exception if the tree is empty.

        >>> t = BinarySearchTree()
        >>> t.get_max_label()
        Traceback (most recent call last):
            ...
        Exception: Binary search tree is empty

        >>> t.put(8)
        >>> t.put(10)
        >>> t.get_max_label()
        10
        """
        if self.is_empty():
            raise Exception("Binary search tree is empty")

        # The largest label sits at the end of the right-most path.
        node = self.root
        while node.right is not None:
            node = node.right
        return node.label

    def get_min_label(self) -> int:
        """
        Gets the min label inserted in the tree

        Raises Exception if the tree is empty.

        >>> t = BinarySearchTree()
        >>> t.get_min_label()
        Traceback (most recent call last):
            ...
        Exception: Binary search tree is empty

        >>> t.put(8)
        >>> t.put(10)
        >>> t.get_min_label()
        8
        """
        if self.is_empty():
            raise Exception("Binary search tree is empty")

        # The smallest label sits at the end of the left-most path.
        node = self.root
        while node.left is not None:
            node = node.left
        return node.label

    def inorder_traversal(self) -> Iterator["Node"]:
        """
        Return an iterator over the nodes of the tree in inorder
        (left subtree, node, right subtree)

        >>> t = BinarySearchTree()
        >>> [i.label for i in t.inorder_traversal()]
        []

        >>> t.put(8)
        >>> t.put(10)
        >>> t.put(9)
        >>> [i.label for i in t.inorder_traversal()]
        [8, 9, 10]
        """
        return self._inorder_traversal(self.root)

    def _inorder_traversal(self, node: Optional["Node"]) -> Iterator["Node"]:
        """Yield the nodes of the subtree rooted at ``node`` in inorder."""
        if node is not None:
            yield from self._inorder_traversal(node.left)
            yield node
            yield from self._inorder_traversal(node.right)

    def preorder_traversal(self) -> Iterator["Node"]:
        """
        Return an iterator over the nodes of the tree in preorder
        (node, left subtree, right subtree)

        >>> t = BinarySearchTree()
        >>> [i.label for i in t.preorder_traversal()]
        []

        >>> t.put(8)
        >>> t.put(10)
        >>> t.put(9)
        >>> [i.label for i in t.preorder_traversal()]
        [8, 10, 9]
        """
        return self._preorder_traversal(self.root)

    def _preorder_traversal(self, node: Optional["Node"]) -> Iterator["Node"]:
        """Yield the nodes of the subtree rooted at ``node`` in preorder."""
        if node is not None:
            yield node
            yield from self._preorder_traversal(node.left)
            yield from self._preorder_traversal(node.right)
class BinarySearchTreeTest(unittest.TestCase):
    # unittest assertion methods are used instead of bare ``assert`` so the
    # checks still run under ``python -O`` and failures report both values.

    @staticmethod
    def _get_binary_search_tree():
        r"""
        Build the tree shared by most tests:

                  8
                 / \
                3   10
               / \    \
              1   6    14
                 / \   /
                4   7 13
                 \
                  5
        """
        t = BinarySearchTree()
        for label in (8, 3, 6, 1, 10, 14, 13, 4, 7, 5):
            t.put(label)
        return t

    def test_put(self):
        t = BinarySearchTree()
        self.assertTrue(t.is_empty())

        t.put(8)
        #       8
        self.assertIsNone(t.root.parent)
        self.assertEqual(t.root.label, 8)

        t.put(10)
        #       8
        #        \
        #         10
        self.assertIs(t.root.right.parent, t.root)
        self.assertEqual(t.root.right.label, 10)

        t.put(3)
        #       8
        #      / \
        #     3   10
        self.assertIs(t.root.left.parent, t.root)
        self.assertEqual(t.root.left.label, 3)

        t.put(6)
        #       8
        #      / \
        #     3   10
        #      \
        #       6
        self.assertIs(t.root.left.right.parent, t.root.left)
        self.assertEqual(t.root.left.right.label, 6)

        t.put(1)
        #       8
        #      / \
        #     3   10
        #    / \
        #   1   6
        self.assertIs(t.root.left.left.parent, t.root.left)
        self.assertEqual(t.root.left.left.label, 1)

        # Duplicate labels are rejected.
        with self.assertRaises(Exception):
            t.put(1)

    def test_search(self):
        t = self._get_binary_search_tree()

        node = t.search(6)
        self.assertEqual(node.label, 6)

        node = t.search(13)
        self.assertEqual(node.label, 13)

        with self.assertRaises(Exception):
            t.search(2)

    def test_remove(self):
        t = self._get_binary_search_tree()

        t.remove(13)
        #       8
        #      / \
        #     3   10
        #    / \    \
        #   1   6    14
        #      / \
        #     4   7
        #      \
        #       5
        self.assertIsNone(t.root.right.right.right)
        self.assertIsNone(t.root.right.right.left)

        t.remove(7)
        #       8
        #      / \
        #     3   10
        #    / \    \
        #   1   6    14
        #      /
        #     4
        #      \
        #       5
        self.assertIsNone(t.root.left.right.right)
        self.assertEqual(t.root.left.right.left.label, 4)

        t.remove(6)
        #       8
        #      / \
        #     3   10
        #    / \    \
        #   1   4    14
        #        \
        #         5
        self.assertEqual(t.root.left.left.label, 1)
        self.assertEqual(t.root.left.right.label, 4)
        self.assertEqual(t.root.left.right.right.label, 5)
        self.assertIsNone(t.root.left.right.left)
        self.assertIs(t.root.left.left.parent, t.root.left)
        self.assertIs(t.root.left.right.parent, t.root.left)

        t.remove(3)
        #       8
        #      / \
        #     4   10
        #    / \    \
        #   1   5    14
        self.assertEqual(t.root.left.label, 4)
        self.assertEqual(t.root.left.right.label, 5)
        self.assertEqual(t.root.left.left.label, 1)
        self.assertIs(t.root.left.parent, t.root)
        self.assertIs(t.root.left.left.parent, t.root.left)
        self.assertIs(t.root.left.right.parent, t.root.left)

        t.remove(4)
        #       8
        #      / \
        #     5   10
        #    /      \
        #   1        14
        self.assertEqual(t.root.left.label, 5)
        self.assertIsNone(t.root.left.right)
        self.assertEqual(t.root.left.left.label, 1)
        self.assertIs(t.root.left.parent, t.root)
        self.assertIs(t.root.left.left.parent, t.root.left)

    def test_remove_2(self):
        t = self._get_binary_search_tree()

        t.remove(3)
        #        8
        #       / \
        #      4   10
        #     / \    \
        #    1   6    14
        #       / \   /
        #      5   7 13
        self.assertEqual(t.root.left.label, 4)
        self.assertEqual(t.root.left.right.label, 6)
        self.assertEqual(t.root.left.left.label, 1)
        self.assertEqual(t.root.left.right.right.label, 7)
        self.assertEqual(t.root.left.right.left.label, 5)
        self.assertIs(t.root.left.parent, t.root)
        self.assertIs(t.root.left.right.parent, t.root.left)
        self.assertIs(t.root.left.left.parent, t.root.left)
        self.assertIs(t.root.left.right.left.parent, t.root.left.right)

    def test_empty(self):
        t = self._get_binary_search_tree()
        t.empty()
        self.assertIsNone(t.root)

    def test_is_empty(self):
        t = self._get_binary_search_tree()
        self.assertFalse(t.is_empty())

        t.empty()
        self.assertTrue(t.is_empty())

    def test_exists(self):
        t = self._get_binary_search_tree()

        self.assertTrue(t.exists(6))
        self.assertFalse(t.exists(-1))

    def test_get_max_label(self):
        t = self._get_binary_search_tree()

        self.assertEqual(t.get_max_label(), 14)

        t.empty()
        with self.assertRaises(Exception):
            t.get_max_label()

    def test_get_min_label(self):
        t = self._get_binary_search_tree()

        self.assertEqual(t.get_min_label(), 1)

        t.empty()
        with self.assertRaises(Exception):
            t.get_min_label()

    def test_inorder_traversal(self):
        t = self._get_binary_search_tree()

        inorder_traversal_nodes = [i.label for i in t.inorder_traversal()]
        self.assertEqual(
            inorder_traversal_nodes, [1, 3, 4, 5, 6, 7, 8, 10, 13, 14]
        )

    def test_preorder_traversal(self):
        t = self._get_binary_search_tree()

        preorder_traversal_nodes = [i.label for i in t.preorder_traversal()]
        self.assertEqual(
            preorder_traversal_nodes, [8, 3, 1, 6, 4, 5, 7, 10, 14, 13]
        )
def binary_search_tree_example():
    r"""
    Build a sample tree, print some queries on it, delete several labels and
    print the queries again.

    Example
                  8
                 / \
                3   10
               / \    \
              1   6    14
                 / \   /
                4   7 13
                 \
                  5

    Example After Deletion
                  4
                 / \
                1   7
                   /
                  5
    """
    t = BinarySearchTree()
    for label in (8, 3, 6, 1, 10, 14, 13, 4, 7, 5):
        t.put(label)

    print(
        """
              8
             / \\
            3   10
           / \\    \\
          1   6    14
             / \\   /
            4   7 13
             \\
              5
        """
    )

    print("Label 6 exists:", t.exists(6))
    print("Label 13 exists:", t.exists(13))
    print("Label -1 exists:", t.exists(-1))
    print("Label 12 exists:", t.exists(12))

    # Prints all the labels of the tree in inorder traversal
    inorder_traversal_nodes = [i.label for i in t.inorder_traversal()]
    print("Inorder traversal:", inorder_traversal_nodes)

    # Prints all the labels of the tree in preorder traversal
    preorder_traversal_nodes = [i.label for i in t.preorder_traversal()]
    print("Preorder traversal:", preorder_traversal_nodes)

    print("Max. label:", t.get_max_label())
    print("Min. label:", t.get_min_label())

    # Delete elements
    print("\nDeleting elements 13, 10, 8, 3, 6, 14")
    for label in (13, 10, 8, 3, 6, 14):
        t.remove(label)

    # 5 is smaller than 7, so it ends up as the LEFT child of 7.
    print(
        """
              4
             / \\
            1   7
               /
              5
        """
    )

    # Prints all the labels of the tree in inorder traversal after delete
    inorder_traversal_nodes = [i.label for i in t.inorder_traversal()]
    print("Inorder traversal after delete:", inorder_traversal_nodes)

    # Prints all the labels of the tree in preorder traversal after delete
    preorder_traversal_nodes = [i.label for i in t.preorder_traversal()]
    print("Preorder traversal after delete:", preorder_traversal_nodes)

    print("Max. label:", t.get_max_label())
    print("Min. label:", t.get_min_label())
# Run the demonstration only when the file is executed as a script.
if __name__ == "__main__":
    binary_search_tree_example()