mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-24 05:21:09 +00:00
Merge pull request #186 from gilbertoalexsantos/union-find
Created Union-Find algorithm
This commit is contained in:
commit
8f3abf5ce6
0
data_structures/UnionFind/__init__.py
Normal file
0
data_structures/UnionFind/__init__.py
Normal file
77
data_structures/UnionFind/tests_union_find.py
Normal file
77
data_structures/UnionFind/tests_union_find.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
from union_find import UnionFind
|
||||
import unittest
|
||||
|
||||
|
||||
class TestUnionFind(unittest.TestCase):
|
||||
def test_init_with_valid_size(self):
|
||||
uf = UnionFind(5)
|
||||
self.assertEqual(uf.size, 5)
|
||||
|
||||
def test_init_with_invalid_size(self):
|
||||
with self.assertRaises(ValueError):
|
||||
uf = UnionFind(0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
uf = UnionFind(-5)
|
||||
|
||||
def test_union_with_valid_values(self):
|
||||
uf = UnionFind(10)
|
||||
|
||||
for i in range(11):
|
||||
for j in range(11):
|
||||
uf.union(i, j)
|
||||
|
||||
def test_union_with_invalid_values(self):
|
||||
uf = UnionFind(10)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
uf.union(-1, 1)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
uf.union(11, 1)
|
||||
|
||||
def test_same_set_with_valid_values(self):
|
||||
uf = UnionFind(10)
|
||||
|
||||
for i in range(11):
|
||||
for j in range(11):
|
||||
if i == j:
|
||||
self.assertTrue(uf.same_set(i, j))
|
||||
else:
|
||||
self.assertFalse(uf.same_set(i, j))
|
||||
|
||||
uf.union(1, 2)
|
||||
self.assertTrue(uf.same_set(1, 2))
|
||||
|
||||
uf.union(3, 4)
|
||||
self.assertTrue(uf.same_set(3, 4))
|
||||
|
||||
self.assertFalse(uf.same_set(1, 3))
|
||||
self.assertFalse(uf.same_set(1, 4))
|
||||
self.assertFalse(uf.same_set(2, 3))
|
||||
self.assertFalse(uf.same_set(2, 4))
|
||||
|
||||
uf.union(1, 3)
|
||||
self.assertTrue(uf.same_set(1, 3))
|
||||
self.assertTrue(uf.same_set(1, 4))
|
||||
self.assertTrue(uf.same_set(2, 3))
|
||||
self.assertTrue(uf.same_set(2, 4))
|
||||
|
||||
uf.union(4, 10)
|
||||
self.assertTrue(uf.same_set(1, 10))
|
||||
self.assertTrue(uf.same_set(2, 10))
|
||||
self.assertTrue(uf.same_set(3, 10))
|
||||
self.assertTrue(uf.same_set(4, 10))
|
||||
|
||||
def test_same_set_with_invalid_values(self):
|
||||
uf = UnionFind(10)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
uf.same_set(-1, 1)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
uf.same_set(11, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
87
data_structures/UnionFind/union_find.py
Normal file
87
data_structures/UnionFind/union_find.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
class UnionFind():
|
||||
"""
|
||||
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
|
||||
|
||||
The union-find is a disjoint-set data structure
|
||||
|
||||
You can merge two sets and tell if one set belongs to
|
||||
another one.
|
||||
|
||||
It's used on the Kruskal Algorithm
|
||||
(https://en.wikipedia.org/wiki/Kruskal%27s_algorithm)
|
||||
|
||||
The elements are in range [0, size]
|
||||
"""
|
||||
def __init__(self, size):
|
||||
if size <= 0:
|
||||
raise ValueError("size should be greater than 0")
|
||||
|
||||
self.size = size
|
||||
|
||||
# The below plus 1 is because we are using elements
|
||||
# in range [0, size]. It makes more sense.
|
||||
|
||||
# Every set begins with only itself
|
||||
self.root = [i for i in range(size+1)]
|
||||
|
||||
# This is used for heuristic union by rank
|
||||
self.weight = [0 for i in range(size+1)]
|
||||
|
||||
def union(self, u, v):
|
||||
"""
|
||||
Union of the sets u and v.
|
||||
Complexity: log(n).
|
||||
Amortized complexity: < 5 (it's very fast).
|
||||
"""
|
||||
|
||||
self._validate_element_range(u, "u")
|
||||
self._validate_element_range(v, "v")
|
||||
|
||||
if u == v:
|
||||
return
|
||||
|
||||
# Using union by rank will guarantee the
|
||||
# log(n) complexity
|
||||
rootu = self._root(u)
|
||||
rootv = self._root(v)
|
||||
weight_u = self.weight[rootu]
|
||||
weight_v = self.weight[rootv]
|
||||
if weight_u >= weight_v:
|
||||
self.root[rootv] = rootu
|
||||
if weight_u == weight_v:
|
||||
self.weight[rootu] += 1
|
||||
else:
|
||||
self.root[rootu] = rootv
|
||||
|
||||
def same_set(self, u, v):
|
||||
"""
|
||||
Return true if the elements u and v belongs to
|
||||
the same set
|
||||
"""
|
||||
|
||||
self._validate_element_range(u, "u")
|
||||
self._validate_element_range(v, "v")
|
||||
|
||||
return self._root(u) == self._root(v)
|
||||
|
||||
def _root(self, u):
|
||||
"""
|
||||
Get the element set root.
|
||||
This uses the heuristic path compression
|
||||
See wikipedia article for more details.
|
||||
"""
|
||||
|
||||
if u != self.root[u]:
|
||||
self.root[u] = self._root(self.root[u])
|
||||
|
||||
return self.root[u]
|
||||
|
||||
def _validate_element_range(self, u, element_name):
|
||||
"""
|
||||
Raises ValueError if element is not in range
|
||||
"""
|
||||
if u < 0 or u > self.size:
|
||||
msg = ("element {0} with value {1} "
|
||||
"should be in range [0~{2}]")\
|
||||
.format(element_name, u, self.size)
|
||||
raise ValueError(msg)
|
Loading…
Reference in New Issue
Block a user