diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index cb043cf18..71dede2cc 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -1,6 +1,11 @@ """ -An auto-balanced binary tree! +Implementation of an auto-balanced binary tree! +For doctests run following command: +python3 -m doctest -v avl_tree.py +For testing run: +python avl_tree.py """ + import math import random @@ -11,7 +16,7 @@ class my_queue: self.head = 0 self.tail = 0 - def isEmpty(self): + def is_empty(self): return self.head == self.tail def push(self, data): @@ -39,39 +44,39 @@ class my_node: self.right = None self.height = 1 - def getdata(self): + def get_data(self): return self.data - def getleft(self): + def get_left(self): return self.left - def getright(self): + def get_right(self): return self.right - def getheight(self): + def get_height(self): return self.height - def setdata(self, data): + def set_data(self, data): self.data = data return - def setleft(self, node): + def set_left(self, node): self.left = node return - def setright(self, node): + def set_right(self, node): self.right = node return - def setheight(self, height): + def set_height(self, height): self.height = height return -def getheight(node): +def get_height(node): if node is None: return 0 - return node.getheight() + return node.get_height() def my_max(a, b): @@ -80,7 +85,7 @@ def my_max(a, b): return b -def leftrotation(node): +def right_rotation(node): r""" A B / \ / \ @@ -89,138 +94,171 @@ def leftrotation(node): Bl Br UB Br C / UB - UB = unbalanced node """ - print("left rotation node:", node.getdata()) - ret = node.getleft() - node.setleft(ret.getright()) - ret.setright(node) - h1 = my_max(getheight(node.getright()), getheight(node.getleft())) + 1 - node.setheight(h1) - h2 = my_max(getheight(ret.getright()), getheight(ret.getleft())) + 1 - ret.setheight(h2) + print("left rotation node:", node.get_data()) + ret = node.get_left() + node.set_left(ret.get_right()) + ret.set_right(node) + h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1 + node.set_height(h1) + h2 = my_max(get_height(ret.get_right()), get_height(ret.get_left())) + 1 + ret.set_height(h2) return ret -def rightrotation(node): +def left_rotation(node): """ - a mirror symmetry rotation of the leftrotation + a mirror symmetry rotation of the left_rotation """ - print("right rotation node:", node.getdata()) - ret = node.getright() - node.setright(ret.getleft()) - ret.setleft(node) - h1 = my_max(getheight(node.getright()), getheight(node.getleft())) + 1 - node.setheight(h1) - h2 = my_max(getheight(ret.getright()), getheight(ret.getleft())) + 1 - ret.setheight(h2) + print("right rotation node:", node.get_data()) + ret = node.get_right() + node.set_right(ret.get_left()) + ret.set_left(node) + h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1 + node.set_height(h1) + h2 = my_max(get_height(ret.get_right()), get_height(ret.get_left())) + 1 + ret.set_height(h2) return ret -def rlrotation(node): +def lr_rotation(node): r""" A A Br / \ / \ / \ - B C RR Br C LR B A + B C LR Br C RR B A / \ --> / \ --> / / \ Bl Br B UB Bl UB C \ / UB Bl - RR = rightrotation LR = leftrotation + RR = right_rotation LR = left_rotation """ - node.setleft(rightrotation(node.getleft())) - return leftrotation(node) + node.set_left(left_rotation(node.get_left())) + return right_rotation(node) -def lrrotation(node): - node.setright(leftrotation(node.getright())) - return rightrotation(node) +def rl_rotation(node): + node.set_right(right_rotation(node.get_right())) + return left_rotation(node) def insert_node(node, data): if node is None: return my_node(data) - if data < node.getdata(): - node.setleft(insert_node(node.getleft(), data)) + if data < node.get_data(): + node.set_left(insert_node(node.get_left(), data)) if ( - getheight(node.getleft()) - getheight(node.getright()) == 2 + get_height(node.get_left()) - get_height(node.get_right()) == 2 ): # an unbalance detected if ( - data < node.getleft().getdata() + data < node.get_left().get_data() ): # new node is the left child of the left child - node = leftrotation(node) + node = right_rotation(node) else: - node = rlrotation(node) # new node is the right child of the left child + node = lr_rotation(node) else: - node.setright(insert_node(node.getright(), data)) - if getheight(node.getright()) - getheight(node.getleft()) == 2: - if data < node.getright().getdata(): - node = lrrotation(node) + node.set_right(insert_node(node.get_right(), data)) + if get_height(node.get_right()) - get_height(node.get_left()) == 2: + if data < node.get_right().get_data(): + node = rl_rotation(node) else: - node = rightrotation(node) - h1 = my_max(getheight(node.getright()), getheight(node.getleft())) + 1 - node.setheight(h1) + node = left_rotation(node) + h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1 + node.set_height(h1) return node -def getRightMost(root): - while root.getright() is not None: - root = root.getright() - return root.getdata() +def get_rightMost(root): + while root.get_right() is not None: + root = root.get_right() + return root.get_data() -def getLeftMost(root): - while root.getleft() is not None: - root = root.getleft() - return root.getdata() +def get_leftMost(root): + while root.get_left() is not None: + root = root.get_left() + return root.get_data() def del_node(root, data): - if root.getdata() == data: - if root.getleft() is not None and root.getright() is not None: - temp_data = getLeftMost(root.getright()) - root.setdata(temp_data) - root.setright(del_node(root.getright(), temp_data)) - elif root.getleft() is not None: - root = root.getleft() + if root.get_data() == data: + if root.get_left() is not None and root.get_right() is not None: + temp_data = get_leftMost(root.get_right()) + root.set_data(temp_data) + root.set_right(del_node(root.get_right(), temp_data)) + elif root.get_left() is not None: + root = root.get_left() else: - root = root.getright() - elif root.getdata() > data: - if root.getleft() is None: + root = root.get_right() + elif root.get_data() > data: + if root.get_left() is None: print("No such data") return root else: - root.setleft(del_node(root.getleft(), data)) - elif root.getdata() < data: - if root.getright() is None: + root.set_left(del_node(root.get_left(), data)) + elif root.get_data() < data: + if root.get_right() is None: return root else: - root.setright(del_node(root.getright(), data)) + root.set_right(del_node(root.get_right(), data)) if root is None: return root - if getheight(root.getright()) - getheight(root.getleft()) == 2: - if getheight(root.getright().getright()) > getheight(root.getright().getleft()): - root = rightrotation(root) + if get_height(root.get_right()) - get_height(root.get_left()) == 2: + if get_height(root.get_right().get_right()) > \ + get_height(root.get_right().get_left()): + root = left_rotation(root) else: - root = lrrotation(root) - elif getheight(root.getright()) - getheight(root.getleft()) == -2: - if getheight(root.getleft().getleft()) > getheight(root.getleft().getright()): - root = leftrotation(root) + root = rl_rotation(root) + elif get_height(root.get_right()) - get_height(root.get_left()) == -2: + if get_height(root.get_left().get_left()) > \ + get_height(root.get_left().get_right()): + root = right_rotation(root) else: - root = rlrotation(root) - height = my_max(getheight(root.getright()), getheight(root.getleft())) + 1 - root.setheight(height) + root = lr_rotation(root) + height = my_max(get_height(root.get_right()), get_height(root.get_left())) + 1 + root.set_height(height) return root class AVLtree: + """ + An AVL tree doctest + Examples: + >>> t = AVLtree() + >>> t.insert(4) + insert:4 + >>> print(str(t).replace(" \\n","\\n")) + 4 + ************************************* + >>> t.insert(2) + insert:2 + >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n")) + 4 + 2 * + ************************************* + >>> t.insert(3) + insert:3 + right rotation node: 2 + left rotation node: 4 + >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n")) + 3 + 2 4 + ************************************* + >>> t.get_height() + 2 + >>> t.del_node(3) + delete:3 + >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n")) + 4 + 2 * + ************************************* + """ def __init__(self): self.root = None - def getheight(self): + def get_height(self): # print("yyy") - return getheight(self.root) + return get_height(self.root) def insert(self, data): print("insert:" + str(data)) @@ -233,56 +271,54 @@ class AVLtree: return self.root = del_node(self.root, data) - def traversale(self): # a level traversale, gives a more intuitive look on the tree + def __str__(self): # a level traversale, gives a more intuitive look on the tree + output = "" q = my_queue() q.push(self.root) - layer = self.getheight() + layer = self.get_height() if layer == 0: - return + return output cnt = 0 - while not q.isEmpty(): + while not q.is_empty(): node = q.pop() space = " " * int(math.pow(2, layer - 1)) - print(space, end="") + output += space if node is None: - print("*", end="") + output += "*" q.push(None) q.push(None) else: - print(node.getdata(), end="") - q.push(node.getleft()) - q.push(node.getright()) - print(space, end="") + output += str(node.get_data()) + q.push(node.get_left()) + q.push(node.get_right()) + output += space cnt = cnt + 1 for i in range(100): if cnt == math.pow(2, i) - 1: layer = layer - 1 if layer == 0: - print() - print("*************************************") - return - print() + output += "\n*************************************" + return output + output += "\n" break - print() - print("*************************************") - return + output += "\n*************************************" + return output - def test(self): - getheight(None) - print("****") - self.getheight() + +def _test(): + import doctest + doctest.testmod() if __name__ == "__main__": + _test() t = AVLtree() - t.traversale() lst = list(range(10)) random.shuffle(lst) for i in lst: t.insert(i) - t.traversale() - + print(str(t)) random.shuffle(lst) for i in lst: t.del_node(i) - t.traversale() + print(str(t))