From 95862303a6527f4bf111e6f3f783fd66b7b426f3 Mon Sep 17 00:00:00 2001 From: Hasanul Islam Date: Mon, 5 Jul 2021 12:23:18 +0600 Subject: [PATCH] Fix mypy at prims_algo_2 (#4527) --- graphs/minimum_spanning_tree_prims2.py | 56 +++++++++++++------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/graphs/minimum_spanning_tree_prims2.py b/graphs/minimum_spanning_tree_prims2.py index 10ed736c9..c3444c36f 100644 --- a/graphs/minimum_spanning_tree_prims2.py +++ b/graphs/minimum_spanning_tree_prims2.py @@ -8,7 +8,9 @@ connection from the tree to another vertex. """ from sys import maxsize -from typing import Dict, Optional, Tuple, Union +from typing import Generic, Optional, TypeVar + +T = TypeVar("T") def get_parent_position(position: int) -> int: @@ -43,7 +45,7 @@ def get_child_right_position(position: int) -> int: return (2 * position) + 2 -class MinPriorityQueue: +class MinPriorityQueue(Generic[T]): """ Minimum Priority Queue Class @@ -80,9 +82,9 @@ class MinPriorityQueue: """ def __init__(self) -> None: - self.heap = [] - self.position_map = {} - self.elements = 0 + self.heap: list[tuple[T, int]] = [] + self.position_map: dict[T, int] = {} + self.elements: int = 0 def __len__(self) -> int: return self.elements @@ -94,14 +96,14 @@ class MinPriorityQueue: # Check if the priority queue is empty return self.elements == 0 - def push(self, elem: Union[int, str], weight: int) -> None: + def push(self, elem: T, weight: int) -> None: # Add an element with given priority to the queue self.heap.append((elem, weight)) self.position_map[elem] = self.elements self.elements += 1 self._bubble_up(elem) - def extract_min(self) -> Union[int, str]: + def extract_min(self) -> T: # Remove and return the element with lowest weight (highest priority) if self.elements > 1: self._swap_nodes(0, self.elements - 1) @@ -113,7 +115,7 @@ class MinPriorityQueue: self._bubble_down(bubble_down_elem) return elem - def update_key(self, elem: Union[int, str], weight: int) -> None: + def update_key(self, elem: T, weight: int) -> None: # Update the weight of the given key position = self.position_map[elem] self.heap[position] = (elem, weight) @@ -127,7 +129,7 @@ class MinPriorityQueue: else: self._bubble_down(elem) - def _bubble_up(self, elem: Union[int, str]) -> None: + def _bubble_up(self, elem: T) -> None: # Place a node at the proper position (upward movement) [to be used internally # only] curr_pos = self.position_map[elem] @@ -141,7 +143,7 @@ class MinPriorityQueue: return self._bubble_up(elem) return - def _bubble_down(self, elem: Union[int, str]) -> None: + def _bubble_down(self, elem: T) -> None: # Place a node at the proper position (downward movement) [to be used # internally only] curr_pos = self.position_map[elem] @@ -182,7 +184,7 @@ class MinPriorityQueue: self.position_map[node2_elem] = node1_pos -class GraphUndirectedWeighted: +class GraphUndirectedWeighted(Generic[T]): """ Graph Undirected Weighted Class @@ -192,8 +194,8 @@ class GraphUndirectedWeighted: """ def __init__(self) -> None: - self.connections = {} - self.nodes = 0 + self.connections: dict[T, dict[T, int]] = {} + self.nodes: int = 0 def __repr__(self) -> str: return str(self.connections) @@ -201,15 +203,13 @@ class GraphUndirectedWeighted: def __len__(self) -> int: return self.nodes - def add_node(self, node: Union[int, str]) -> None: + def add_node(self, node: T) -> None: # Add a node in the graph if it is not in the graph if node not in self.connections: self.connections[node] = {} self.nodes += 1 - def add_edge( - self, node1: Union[int, str], node2: Union[int, str], weight: int - ) -> None: + def add_edge(self, node1: T, node2: T, weight: int) -> None: # Add an edge between 2 nodes in the graph self.add_node(node1) self.add_node(node2) @@ -218,8 +218,8 @@ class GraphUndirectedWeighted: def prims_algo( - graph: GraphUndirectedWeighted, -) -> Tuple[Dict[str, int], Dict[str, Optional[str]]]: + graph: GraphUndirectedWeighted[T], +) -> tuple[dict[T, int], dict[T, Optional[T]]]: """ >>> graph = GraphUndirectedWeighted() @@ -239,10 +239,13 @@ def prims_algo( 13 """ # prim's algorithm for minimum spanning tree - dist = {node: maxsize for node in graph.connections} - parent = {node: None for node in graph.connections} - priority_queue = MinPriorityQueue() - [priority_queue.push(node, weight) for node, weight in dist.items()] + dist: dict[T, int] = {node: maxsize for node in graph.connections} + parent: dict[T, Optional[T]] = {node: None for node in graph.connections} + + priority_queue: MinPriorityQueue[T] = MinPriorityQueue() + for node, weight in dist.items(): + priority_queue.push(node, weight) + if priority_queue.is_empty(): return dist, parent @@ -254,6 +257,7 @@ def prims_algo( dist[neighbour] = dist[node] + graph.connections[node][neighbour] priority_queue.update_key(neighbour, dist[neighbour]) parent[neighbour] = node + # running prim's algorithm while not priority_queue.is_empty(): node = priority_queue.extract_min() @@ -263,9 +267,3 @@ def prims_algo( priority_queue.update_key(neighbour, dist[neighbour]) parent[neighbour] = node return dist, parent - - -if __name__ == "__main__": - from doctest import testmod - - testmod()