# Reconstructed from a patch that adds three files under
# data_structures/UnionFind/: an empty __init__.py, union_find.py and
# tests_union_find.py. The implementation and its tests are laid out below
# as one runnable module.
import unittest


# --- data_structures/UnionFind/union_find.py ---
class UnionFind:
    """
    https://en.wikipedia.org/wiki/Disjoint-set_data_structure

    The union-find is a disjoint-set data structure.

    You can merge the sets that contain two elements and tell whether
    two elements belong to the same set.

    It's used on the Kruskal Algorithm
    (https://en.wikipedia.org/wiki/Kruskal%27s_algorithm)

    The elements are the integers in range [0, size] (both ends
    included), so the structure holds size + 1 elements.
    """

    def __init__(self, size):
        """
        Create size + 1 singleton sets, one per element in [0, size].

        Raises ValueError if size is not greater than 0.
        """
        if size <= 0:
            raise ValueError("size should be greater than 0")

        self.size = size

        # The plus 1 below is because the elements are in the inclusive
        # range [0, size].

        # Every set begins with only itself, so each element is its own
        # parent.
        self.root = list(range(size + 1))

        # Rank of the tree rooted at each element, used by the
        # union-by-rank heuristic. Only the entries of current roots
        # are meaningful.
        self.weight = [0] * (size + 1)

    def union(self, u, v):
        """
        Merge the set containing u with the set containing v.

        Does nothing when u and v already belong to the same set.
        Complexity: log(n) worst case, near-constant amortized.

        Raises ValueError if u or v is outside [0, size].
        """
        self._validate_element_range(u, "u")
        self._validate_element_range(v, "v")

        rootu = self._root(u)
        rootv = self._root(v)

        # Already in the same set (this also covers u == v). Returning
        # here matters: falling through would bump the rank of the
        # shared root even though the tree did not get any taller.
        if rootu == rootv:
            return

        # Union by rank: hang the lower-rank tree under the higher-rank
        # root so the tree height stays logarithmic.
        weight_u = self.weight[rootu]
        weight_v = self.weight[rootv]
        if weight_u >= weight_v:
            self.root[rootv] = rootu
            # Two trees of equal rank produce a tree one level taller.
            if weight_u == weight_v:
                self.weight[rootu] += 1
        else:
            self.root[rootu] = rootv

    def same_set(self, u, v):
        """
        Return True if the elements u and v belong to the same set.

        Raises ValueError if u or v is outside [0, size].
        """
        self._validate_element_range(u, "u")
        self._validate_element_range(v, "v")

        return self._root(u) == self._root(v)

    def _root(self, u):
        """
        Return the root (representative) of the set containing u.

        Applies path compression: every element visited on the way up
        is re-pointed directly at the root, so later lookups are
        shorter. Done iteratively to avoid recursion.
        """
        # First pass: climb to the root.
        root = u
        while root != self.root[root]:
            root = self.root[root]

        # Second pass: re-point every node on the path at the root.
        while u != root:
            parent = self.root[u]
            self.root[u] = root
            u = parent

        return root

    def _validate_element_range(self, u, element_name):
        """
        Raise ValueError if u is not in range [0, size].

        element_name is only used to build the error message.
        """
        if u < 0 or u > self.size:
            msg = ("element {0} with value {1} "
                   "should be in range [0~{2}]")\
                   .format(element_name, u, self.size)
            raise ValueError(msg)


# --- data_structures/UnionFind/tests_union_find.py ---
# UnionFind is defined above in this module, so it is used directly.
class TestUnionFind(unittest.TestCase):
    def test_init_with_valid_size(self):
        uf = UnionFind(5)
        self.assertEqual(uf.size, 5)

    def test_init_with_invalid_size(self):
        with self.assertRaises(ValueError):
            UnionFind(0)

        with self.assertRaises(ValueError):
            UnionFind(-5)

    def test_union_with_valid_values(self):
        uf = UnionFind(10)

        for i in range(11):
            for j in range(11):
                uf.union(i, j)

    def test_union_with_invalid_values(self):
        uf = UnionFind(10)

        with self.assertRaises(ValueError):
            uf.union(-1, 1)

        with self.assertRaises(ValueError):
            uf.union(11, 1)

    def test_union_of_joined_elements_keeps_rank(self):
        # Regression test: merging two elements that already share a
        # root must not inflate the rank of that root.
        uf = UnionFind(10)

        uf.union(1, 2)
        rank = uf.weight[uf._root(1)]

        uf.union(1, 2)
        uf.union(2, 1)
        self.assertEqual(uf.weight[uf._root(1)], rank)

    def test_same_set_with_valid_values(self):
        uf = UnionFind(10)

        for i in range(11):
            for j in range(11):
                if i == j:
                    self.assertTrue(uf.same_set(i, j))
                else:
                    self.assertFalse(uf.same_set(i, j))

        uf.union(1, 2)
        self.assertTrue(uf.same_set(1, 2))

        uf.union(3, 4)
        self.assertTrue(uf.same_set(3, 4))

        self.assertFalse(uf.same_set(1, 3))
        self.assertFalse(uf.same_set(1, 4))
        self.assertFalse(uf.same_set(2, 3))
        self.assertFalse(uf.same_set(2, 4))

        uf.union(1, 3)
        self.assertTrue(uf.same_set(1, 3))
        self.assertTrue(uf.same_set(1, 4))
        self.assertTrue(uf.same_set(2, 3))
        self.assertTrue(uf.same_set(2, 4))

        uf.union(4, 10)
        self.assertTrue(uf.same_set(1, 10))
        self.assertTrue(uf.same_set(2, 10))
        self.assertTrue(uf.same_set(3, 10))
        self.assertTrue(uf.same_set(4, 10))

    def test_same_set_with_invalid_values(self):
        uf = UnionFind(10)

        with self.assertRaises(ValueError):
            uf.same_set(-1, 1)

        with self.assertRaises(ValueError):
            uf.same_set(11, 0)


if __name__ == '__main__':
    unittest.main()