[mypy] Add type hints and docstrings to heap.py (#3013)

* Add type hints and docstrings to heap.py

- Add type hints
- Add docstrings
- Add explanatory comments 
- Improve code readability
- Change to use f-string

* Fix import sorting

* fixup! Format Python code with psf/black push

* Fix static type error

* Fix failing test

* Fix type hints

* Add return annotation

Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com>
This commit is contained in:
Mark Huang 2020-12-26 11:12:37 +08:00 committed by GitHub
parent 8f47d9f807
commit 207ac957ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,101 +1,138 @@
#!/usr/bin/python3 from typing import Iterable, List, Optional
class Heap: class Heap:
""" """A Max Heap Implementation
>>> unsorted = [103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5] >>> unsorted = [103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5]
>>> h = Heap() >>> h = Heap()
>>> h.build_heap(unsorted) >>> h.build_max_heap(unsorted)
>>> h.display() >>> print(h)
[209, 201, 25, 103, 107, 15, 1, 9, 7, 11, 5] [209, 201, 25, 103, 107, 15, 1, 9, 7, 11, 5]
>>> >>>
>>> h.get_max() >>> h.extract_max()
209 209
>>> h.display() >>> print(h)
[201, 107, 25, 103, 11, 15, 1, 9, 7, 5] [201, 107, 25, 103, 11, 15, 1, 9, 7, 5]
>>> >>>
>>> h.insert(100) >>> h.insert(100)
>>> h.display() >>> print(h)
[201, 107, 25, 103, 100, 15, 1, 9, 7, 5, 11] [201, 107, 25, 103, 100, 15, 1, 9, 7, 5, 11]
>>> >>>
>>> h.heap_sort() >>> h.heap_sort()
>>> h.display() >>> print(h)
[1, 5, 7, 9, 11, 15, 25, 100, 103, 107, 201] [1, 5, 7, 9, 11, 15, 25, 100, 103, 107, 201]
>>>
""" """
def __init__(self): def __init__(self) -> None:
self.h = [] self.h: List[float] = []
self.curr_size = 0 self.heap_size: int = 0
def get_left_child_index(self, i): def __repr__(self) -> str:
left_child_index = 2 * i + 1 return str(self.h)
if left_child_index < self.curr_size:
def parent_index(self, child_idx: int) -> Optional[int]:
""" return the parent index of given child """
if child_idx > 0:
return (child_idx - 1) // 2
return None
def left_child_idx(self, parent_idx: int) -> Optional[int]:
"""
return the left child index if the left child exists.
if not, return None.
"""
left_child_index = 2 * parent_idx + 1
if left_child_index < self.heap_size:
return left_child_index return left_child_index
return None return None
def get_right_child(self, i): def right_child_idx(self, parent_idx: int) -> Optional[int]:
right_child_index = 2 * i + 2 """
if right_child_index < self.curr_size: return the right child index if the right child exists.
if not, return None.
"""
right_child_index = 2 * parent_idx + 2
if right_child_index < self.heap_size:
return right_child_index return right_child_index
return None return None
def max_heapify(self, index): def max_heapify(self, index: int) -> None:
if index < self.curr_size: """
largest = index correct a single violation of the heap property in a subtree's root.
lc = self.get_left_child_index(index) """
rc = self.get_right_child(index) if index < self.heap_size:
if lc is not None and self.h[lc] > self.h[largest]: violation: int = index
largest = lc left_child = self.left_child_idx(index)
if rc is not None and self.h[rc] > self.h[largest]: right_child = self.right_child_idx(index)
largest = rc # check which child is larger than its parent
if largest != index: if left_child is not None and self.h[left_child] > self.h[violation]:
self.h[largest], self.h[index] = self.h[index], self.h[largest] violation = left_child
self.max_heapify(largest) if right_child is not None and self.h[right_child] > self.h[violation]:
violation = right_child
# if violation indeed exists
if violation != index:
# swap to fix the violation
self.h[violation], self.h[index] = self.h[index], self.h[violation]
# fix the subsequent violation recursively if any
self.max_heapify(violation)
def build_heap(self, collection): def build_max_heap(self, collection: Iterable[float]) -> None:
self.curr_size = len(collection) """ build max heap from an unsorted array"""
self.h = list(collection) self.h = list(collection)
if self.curr_size <= 1: self.heap_size = len(self.h)
return if self.heap_size > 1:
for i in range(self.curr_size // 2 - 1, -1, -1): # max_heapify from right to left but exclude leaves (last level)
self.max_heapify(i) for i in range(self.heap_size // 2 - 1, -1, -1):
self.max_heapify(i)
def get_max(self): def max(self) -> float:
if self.curr_size >= 2: """ return the max in the heap """
if self.heap_size >= 1:
return self.h[0]
else:
raise Exception("Empty heap")
def extract_max(self) -> float:
""" get and remove max from heap """
if self.heap_size >= 2:
me = self.h[0] me = self.h[0]
self.h[0] = self.h.pop(-1) self.h[0] = self.h.pop(-1)
self.curr_size -= 1 self.heap_size -= 1
self.max_heapify(0) self.max_heapify(0)
return me return me
elif self.curr_size == 1: elif self.heap_size == 1:
self.curr_size -= 1 self.heap_size -= 1
return self.h.pop(-1) return self.h.pop(-1)
return None else:
raise Exception("Empty heap")
def heap_sort(self): def insert(self, value: float) -> None:
size = self.curr_size """ insert a new value into the max heap """
self.h.append(value)
idx = (self.heap_size - 1) // 2
self.heap_size += 1
while idx >= 0:
self.max_heapify(idx)
idx = (idx - 1) // 2
def heap_sort(self) -> None:
size = self.heap_size
for j in range(size - 1, 0, -1): for j in range(size - 1, 0, -1):
self.h[0], self.h[j] = self.h[j], self.h[0] self.h[0], self.h[j] = self.h[j], self.h[0]
self.curr_size -= 1 self.heap_size -= 1
self.max_heapify(0) self.max_heapify(0)
self.curr_size = size self.heap_size = size
def insert(self, data):
self.h.append(data)
curr = (self.curr_size - 1) // 2
self.curr_size += 1
while curr >= 0:
self.max_heapify(curr)
curr = (curr - 1) // 2
def display(self):
print(self.h)
def main(): if __name__ == "__main__":
import doctest
# run doc test
doctest.testmod()
# demo
for unsorted in [ for unsorted in [
[],
[0], [0],
[2], [2],
[3, 5], [3, 5],
@ -110,26 +147,17 @@ def main():
[103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5], [103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5],
[-45, -2, -5], [-45, -2, -5],
]: ]:
print("source unsorted list: %s" % unsorted) print(f"unsorted array: {unsorted}")
h = Heap() heap = Heap()
h.build_heap(unsorted) heap.build_max_heap(unsorted)
print("after build heap: ", end=" ") print(f"after build heap: {heap}")
h.display()
print("max value: %s" % h.get_max()) print(f"max value: {heap.extract_max()}")
print("delete max value: ", end=" ") print(f"after max value removed: {heap}")
h.display()
h.insert(100) heap.insert(100)
print("after insert new value 100: ", end=" ") print(f"after new value 100 inserted: {heap}")
h.display()
h.heap_sort() heap.heap_sort()
print("heap sort: ", end=" ") print(f"heap-sorted array: {heap}\n")
h.display()
print()
if __name__ == "__main__":
main()