diff --git a/DIRECTORY.md b/DIRECTORY.md index 6a3d31709..d3a378c3a 100644 --- a/DIRECTORY.md +++ b/DIRECTORY.md @@ -293,6 +293,8 @@ * [Scc Kosaraju](https://github.com/TheAlgorithms/Python/blob/master/graphs/scc_kosaraju.py) * [Strongly Connected Components](https://github.com/TheAlgorithms/Python/blob/master/graphs/strongly_connected_components.py) * [Tarjans Scc](https://github.com/TheAlgorithms/Python/blob/master/graphs/tarjans_scc.py) + * Tests + * [Test Min Spanning Tree Kruskal](https://github.com/TheAlgorithms/Python/blob/master/graphs/tests/test_min_spanning_tree_kruskal.py) ## Greedy Method * [Greedy Knapsack](https://github.com/TheAlgorithms/Python/blob/master/greedy_method/greedy_knapsack.py) diff --git a/graphs/minimum_spanning_tree_kruskal.py b/graphs/minimum_spanning_tree_kruskal.py index 91b44f650..610baf4b5 100644 --- a/graphs/minimum_spanning_tree_kruskal.py +++ b/graphs/minimum_spanning_tree_kruskal.py @@ -1,13 +1,5 @@ -if __name__ == "__main__": - num_nodes, num_edges = list(map(int, input().strip().split())) - - edges = [] - - for i in range(num_edges): - node1, node2, cost = list(map(int, input().strip().split())) - edges.append((i, node1, node2, cost)) - - edges = sorted(edges, key=lambda edge: edge[3]) +def kruskal(num_nodes, num_edges, edges): + edges = sorted(edges, key=lambda edge: edge[2]) parent = list(range(num_nodes)) @@ -20,13 +12,22 @@ if __name__ == "__main__": minimum_spanning_tree = [] for edge in edges: - parent_a = find_parent(edge[1]) - parent_b = find_parent(edge[2]) + parent_a = find_parent(edge[0]) + parent_b = find_parent(edge[1]) if parent_a != parent_b: - minimum_spanning_tree_cost += edge[3] + minimum_spanning_tree_cost += edge[2] minimum_spanning_tree.append(edge) parent[parent_a] = parent_b - print(minimum_spanning_tree_cost) - for edge in minimum_spanning_tree: - print(edge) + return minimum_spanning_tree + + +if __name__ == "__main__": # pragma: no cover + num_nodes, num_edges = list(map(int, input().strip().split())) + edges = [] + + for _ in range(num_edges): + node1, node2, cost = [int(x) for x in input().strip().split()] + edges.append((node1, node2, cost)) + + kruskal(num_nodes, num_edges, edges) diff --git a/graphs/tests/test_min_spanning_tree_kruskal.py b/graphs/tests/test_min_spanning_tree_kruskal.py new file mode 100644 index 000000000..3a527aef3 --- /dev/null +++ b/graphs/tests/test_min_spanning_tree_kruskal.py @@ -0,0 +1,36 @@ +from graphs.minimum_spanning_tree_kruskal import kruskal + + +def test_kruskal_successful_result(): + num_nodes, num_edges = 9, 14 + edges = [ + [0, 1, 4], + [0, 7, 8], + [1, 2, 8], + [7, 8, 7], + [7, 6, 1], + [2, 8, 2], + [8, 6, 6], + [2, 3, 7], + [2, 5, 4], + [6, 5, 2], + [3, 5, 14], + [3, 4, 9], + [5, 4, 10], + [1, 7, 11], + ] + + result = kruskal(num_nodes, num_edges, edges) + + expected = [ + [7, 6, 1], + [2, 8, 2], + [6, 5, 2], + [0, 1, 4], + [2, 5, 4], + [2, 3, 7], + [0, 7, 8], + [3, 4, 9], + ] + + assert sorted(expected) == sorted(result)