Fix mypy errors at kruskal_2 (#4528)

This commit is contained in:
Hasanul Islam 2021-07-08 12:46:43 +06:00 committed by GitHub
parent 4412eafaac
commit 256c319ce2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,78 +1,93 @@
from __future__ import annotations from __future__ import annotations
from typing import Generic, TypeVar
class DisjointSetTreeNode: T = TypeVar("T")
class DisjointSetTreeNode(Generic[T]):
# Disjoint Set Node to store the parent and rank # Disjoint Set Node to store the parent and rank
def __init__(self, key: int) -> None: def __init__(self, data: T) -> None:
self.key = key self.data = data
self.parent = self self.parent = self
self.rank = 0 self.rank = 0
class DisjointSetTree: class DisjointSetTree(Generic[T]):
# Disjoint Set DataStructure # Disjoint Set DataStructure
def __init__(self): def __init__(self) -> None:
# map from node name to the node object # map from node name to the node object
self.map = {} self.map: dict[T, DisjointSetTreeNode[T]] = {}
def make_set(self, x: int) -> None: def make_set(self, data: T) -> None:
# create a new set with x as its member # create a new set with x as its member
self.map[x] = DisjointSetTreeNode(x) self.map[data] = DisjointSetTreeNode(data)
def find_set(self, x: int) -> DisjointSetTreeNode: def find_set(self, data: T) -> DisjointSetTreeNode[T]:
# find the set x belongs to (with path-compression) # find the set x belongs to (with path-compression)
elem_ref = self.map[x] elem_ref = self.map[data]
if elem_ref != elem_ref.parent: if elem_ref != elem_ref.parent:
elem_ref.parent = self.find_set(elem_ref.parent.key) elem_ref.parent = self.find_set(elem_ref.parent.data)
return elem_ref.parent return elem_ref.parent
def link(self, x: int, y: int) -> None: def link(
self, node1: DisjointSetTreeNode[T], node2: DisjointSetTreeNode[T]
) -> None:
# helper function for union operation # helper function for union operation
if x.rank > y.rank: if node1.rank > node2.rank:
y.parent = x node2.parent = node1
else: else:
x.parent = y node1.parent = node2
if x.rank == y.rank: if node1.rank == node2.rank:
y.rank += 1 node2.rank += 1
def union(self, x: int, y: int) -> None: def union(self, data1: T, data2: T) -> None:
# merge 2 disjoint sets # merge 2 disjoint sets
self.link(self.find_set(x), self.find_set(y)) self.link(self.find_set(data1), self.find_set(data2))
class GraphUndirectedWeighted: class GraphUndirectedWeighted(Generic[T]):
def __init__(self): def __init__(self) -> None:
# connections: map from the node to the neighbouring nodes (with weights) # connections: map from the node to the neighbouring nodes (with weights)
self.connections = {} self.connections: dict[T, dict[T, int]] = {}
def add_node(self, node: int) -> None: def add_node(self, node: T) -> None:
# add a node ONLY if its not present in the graph # add a node ONLY if its not present in the graph
if node not in self.connections: if node not in self.connections:
self.connections[node] = {} self.connections[node] = {}
def add_edge(self, node1: int, node2: int, weight: int) -> None: def add_edge(self, node1: T, node2: T, weight: int) -> None:
# add an edge with the given weight # add an edge with the given weight
self.add_node(node1) self.add_node(node1)
self.add_node(node2) self.add_node(node2)
self.connections[node1][node2] = weight self.connections[node1][node2] = weight
self.connections[node2][node1] = weight self.connections[node2][node1] = weight
def kruskal(self) -> GraphUndirectedWeighted: def kruskal(self) -> GraphUndirectedWeighted[T]:
# Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph # Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph
""" """
Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
Example: Example:
>>> g1 = GraphUndirectedWeighted[int]()
>>> graph = GraphUndirectedWeighted() >>> g1.add_edge(1, 2, 1)
>>> graph.add_edge(1, 2, 1) >>> g1.add_edge(2, 3, 2)
>>> graph.add_edge(2, 3, 2) >>> g1.add_edge(3, 4, 1)
>>> graph.add_edge(3, 4, 1) >>> g1.add_edge(3, 5, 100) # Removed in MST
>>> graph.add_edge(3, 5, 100) # Removed in MST >>> g1.add_edge(4, 5, 5)
>>> graph.add_edge(4, 5, 5) >>> assert 5 in g1.connections[3]
>>> assert 5 in graph.connections[3] >>> mst = g1.kruskal()
>>> mst = graph.kruskal()
>>> assert 5 not in mst.connections[3] >>> assert 5 not in mst.connections[3]
>>> g2 = GraphUndirectedWeighted[str]()
>>> g2.add_edge('A', 'B', 1)
>>> g2.add_edge('B', 'C', 2)
>>> g2.add_edge('C', 'D', 1)
>>> g2.add_edge('C', 'E', 100) # Removed in MST
>>> g2.add_edge('D', 'E', 5)
>>> assert 'E' in g2.connections["C"]
>>> mst = g2.kruskal()
>>> assert 'E' not in mst.connections['C']
""" """
# getting the edges in ascending order of weights # getting the edges in ascending order of weights
@ -84,26 +99,23 @@ class GraphUndirectedWeighted:
seen.add((end, start)) seen.add((end, start))
edges.append((start, end, self.connections[start][end])) edges.append((start, end, self.connections[start][end]))
edges.sort(key=lambda x: x[2]) edges.sort(key=lambda x: x[2])
# creating the disjoint set # creating the disjoint set
disjoint_set = DisjointSetTree() disjoint_set = DisjointSetTree[T]()
[disjoint_set.make_set(node) for node in self.connections] for node in self.connections:
disjoint_set.make_set(node)
# MST generation # MST generation
num_edges = 0 num_edges = 0
index = 0 index = 0
graph = GraphUndirectedWeighted() graph = GraphUndirectedWeighted[T]()
while num_edges < len(self.connections) - 1: while num_edges < len(self.connections) - 1:
u, v, w = edges[index] u, v, w = edges[index]
index += 1 index += 1
parentu = disjoint_set.find_set(u) parent_u = disjoint_set.find_set(u)
parentv = disjoint_set.find_set(v) parent_v = disjoint_set.find_set(v)
if parentu != parentv: if parent_u != parent_v:
num_edges += 1 num_edges += 1
graph.add_edge(u, v, w) graph.add_edge(u, v, w)
disjoint_set.union(u, v) disjoint_set.union(u, v)
return graph return graph
if __name__ == "__main__":
import doctest
doctest.testmod()