Reduce the complexity of graphs/minimum_spanning_tree_prims.py (#7952)

* Lower the --max-complexity threshold in the file .flake8

* Add test

* Reduce the complexity of graphs/minimum_spanning_tree_prims.py

* Remove backslashes

* Remove # noqa: E741

* Fix the flake8 E741 issues

* Refactor

* Fix
This commit is contained in:
Maxim Smolskiy 2022-11-03 00:16:44 +03:00 committed by GitHub
parent db5215f60e
commit a02de964d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 76 additions and 53 deletions

View File

@ -1,7 +1,7 @@
[flake8] [flake8]
max-line-length = 88 max-line-length = 88
# max-complexity should be 10 # max-complexity should be 10
max-complexity = 21 max-complexity = 20
extend-ignore = extend-ignore =
# Formatting style for `black` # Formatting style for `black`
E203 # Whitespace before ':' E203 # Whitespace before ':'

View File

@ -2,40 +2,45 @@ import sys
from collections import defaultdict from collections import defaultdict
def prisms_algorithm(l): # noqa: E741 class Heap:
def __init__(self):
self.node_position = []
node_position = [] def get_position(self, vertex):
return self.node_position[vertex]
def get_position(vertex): def set_position(self, vertex, pos):
return node_position[vertex] self.node_position[vertex] = pos
def set_position(vertex, pos): def top_to_bottom(self, heap, start, size, positions):
node_position[vertex] = pos
def top_to_bottom(heap, start, size, positions):
if start > size // 2 - 1: if start > size // 2 - 1:
return return
else: else:
if 2 * start + 2 >= size: if 2 * start + 2 >= size:
m = 2 * start + 1 smallest_child = 2 * start + 1
else: else:
if heap[2 * start + 1] < heap[2 * start + 2]: if heap[2 * start + 1] < heap[2 * start + 2]:
m = 2 * start + 1 smallest_child = 2 * start + 1
else: else:
m = 2 * start + 2 smallest_child = 2 * start + 2
if heap[m] < heap[start]: if heap[smallest_child] < heap[start]:
temp, temp1 = heap[m], positions[m] temp, temp1 = heap[smallest_child], positions[smallest_child]
heap[m], positions[m] = heap[start], positions[start] heap[smallest_child], positions[smallest_child] = (
heap[start],
positions[start],
)
heap[start], positions[start] = temp, temp1 heap[start], positions[start] = temp, temp1
temp = get_position(positions[m]) temp = self.get_position(positions[smallest_child])
set_position(positions[m], get_position(positions[start])) self.set_position(
set_position(positions[start], temp) positions[smallest_child], self.get_position(positions[start])
)
self.set_position(positions[start], temp)
top_to_bottom(heap, m, size, positions) self.top_to_bottom(heap, smallest_child, size, positions)
# Update function if value of any node in min-heap decreases # Update function if value of any node in min-heap decreases
def bottom_to_top(val, index, heap, position): def bottom_to_top(self, val, index, heap, position):
temp = position[index] temp = position[index]
while index != 0: while index != 0:
@ -47,70 +52,88 @@ def prisms_algorithm(l): # noqa: E741
if val < heap[parent]: if val < heap[parent]:
heap[index] = heap[parent] heap[index] = heap[parent]
position[index] = position[parent] position[index] = position[parent]
set_position(position[parent], index) self.set_position(position[parent], index)
else: else:
heap[index] = val heap[index] = val
position[index] = temp position[index] = temp
set_position(temp, index) self.set_position(temp, index)
break break
index = parent index = parent
else: else:
heap[0] = val heap[0] = val
position[0] = temp position[0] = temp
set_position(temp, 0) self.set_position(temp, 0)
def heapify(heap, positions): def heapify(self, heap, positions):
start = len(heap) // 2 - 1 start = len(heap) // 2 - 1
for i in range(start, -1, -1): for i in range(start, -1, -1):
top_to_bottom(heap, i, len(heap), positions) self.top_to_bottom(heap, i, len(heap), positions)
def delete_minimum(heap, positions): def delete_minimum(self, heap, positions):
temp = positions[0] temp = positions[0]
heap[0] = sys.maxsize heap[0] = sys.maxsize
top_to_bottom(heap, 0, len(heap), positions) self.top_to_bottom(heap, 0, len(heap), positions)
return temp return temp
visited = [0 for i in range(len(l))]
nbr_tv = [-1 for i in range(len(l))] # Neighboring Tree Vertex of selected vertex def prisms_algorithm(adjacency_list):
"""
>>> adjacency_list = {0: [[1, 1], [3, 3]],
... 1: [[0, 1], [2, 6], [3, 5], [4, 1]],
... 2: [[1, 6], [4, 5], [5, 2]],
... 3: [[0, 3], [1, 5], [4, 1]],
... 4: [[1, 1], [2, 5], [3, 1], [5, 4]],
... 5: [[2, 2], [4, 4]]}
>>> prisms_algorithm(adjacency_list)
[(0, 1), (1, 4), (4, 3), (4, 5), (5, 2)]
"""
heap = Heap()
visited = [0] * len(adjacency_list)
nbr_tv = [-1] * len(adjacency_list) # Neighboring Tree Vertex of selected vertex
# Minimum Distance of explored vertex with neighboring vertex of partial tree # Minimum Distance of explored vertex with neighboring vertex of partial tree
# formed in graph # formed in graph
distance_tv = [] # Heap of Distance of vertices from their neighboring vertex distance_tv = [] # Heap of Distance of vertices from their neighboring vertex
positions = [] positions = []
for x in range(len(l)): for vertex in range(len(adjacency_list)):
p = sys.maxsize distance_tv.append(sys.maxsize)
distance_tv.append(p) positions.append(vertex)
positions.append(x) heap.node_position.append(vertex)
node_position.append(x)
tree_edges = [] tree_edges = []
visited[0] = 1 visited[0] = 1
distance_tv[0] = sys.maxsize distance_tv[0] = sys.maxsize
for x in l[0]: for neighbor, distance in adjacency_list[0]:
nbr_tv[x[0]] = 0 nbr_tv[neighbor] = 0
distance_tv[x[0]] = x[1] distance_tv[neighbor] = distance
heapify(distance_tv, positions) heap.heapify(distance_tv, positions)
for _ in range(1, len(l)): for _ in range(1, len(adjacency_list)):
vertex = delete_minimum(distance_tv, positions) vertex = heap.delete_minimum(distance_tv, positions)
if visited[vertex] == 0: if visited[vertex] == 0:
tree_edges.append((nbr_tv[vertex], vertex)) tree_edges.append((nbr_tv[vertex], vertex))
visited[vertex] = 1 visited[vertex] = 1
for v in l[vertex]: for neighbor, distance in adjacency_list[vertex]:
if visited[v[0]] == 0 and v[1] < distance_tv[get_position(v[0])]: if (
distance_tv[get_position(v[0])] = v[1] visited[neighbor] == 0
bottom_to_top(v[1], get_position(v[0]), distance_tv, positions) and distance < distance_tv[heap.get_position(neighbor)]
nbr_tv[v[0]] = vertex ):
distance_tv[heap.get_position(neighbor)] = distance
heap.bottom_to_top(
distance, heap.get_position(neighbor), distance_tv, positions
)
nbr_tv[neighbor] = vertex
return tree_edges return tree_edges
if __name__ == "__main__": # pragma: no cover if __name__ == "__main__": # pragma: no cover
# < --------- Prims Algorithm --------- > # < --------- Prims Algorithm --------- >
n = int(input("Enter number of vertices: ").strip()) edges_number = int(input("Enter number of edges: ").strip())
e = int(input("Enter number of edges: ").strip()) adjacency_list = defaultdict(list)
adjlist = defaultdict(list) for _ in range(edges_number):
for x in range(e): edge = [int(x) for x in input().strip().split()]
l = [int(x) for x in input().strip().split()] # noqa: E741 adjacency_list[edge[0]].append([edge[1], edge[2]])
adjlist[l[0]].append([l[1], l[2]]) adjacency_list[edge[1]].append([edge[0], edge[2]])
adjlist[l[1]].append([l[0], l[2]]) print(prisms_algorithm(adjacency_list))
print(prisms_algorithm(adjlist))