From 01601e6382e92b5a3806fb3a44357460dff97ee7 Mon Sep 17 00:00:00 2001 From: luoheng <1301089462@qq.com> Date: Mon, 23 Sep 2019 11:08:20 +0800 Subject: [PATCH] Add disjoint set (#1194) * Add disjoint set * disjoint set: add doctest, make code more Pythonic * disjoint set: replace x.p with x.parent * disjoint set: add test and refercence --- data_structures/disjoint_set/disjoint_set.py | 79 ++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 data_structures/disjoint_set/disjoint_set.py diff --git a/data_structures/disjoint_set/disjoint_set.py b/data_structures/disjoint_set/disjoint_set.py new file mode 100644 index 000000000..a93b89621 --- /dev/null +++ b/data_structures/disjoint_set/disjoint_set.py @@ -0,0 +1,79 @@ +""" + disjoint set + Reference: https://en.wikipedia.org/wiki/Disjoint-set_data_structure +""" + + +class Node: + def __init__(self, data): + self.data = data + + +def make_set(x): + """ + make x as a set. + """ + # rank is the distance from x to its' parent + # root's rank is 0 + x.rank = 0 + x.parent = x + + +def union_set(x, y): + """ + union two sets. + set with bigger rank should be parent, so that the + disjoint set tree will be more flat. + """ + x, y = find_set(x), find_set(y) + if x.rank > y.rank: + y.parent = x + else: + x.parent = y + if x.rank == y.rank: + y.rank += 1 + + +def find_set(x): + """ + return the parent of x + """ + if x != x.parent: + x.parent = find_set(x.parent) + return x.parent + + +def find_python_set(node: Node) -> set: + """ + Return a Python Standard Library set that contains i. + """ + sets = ({0, 1, 2}, {3, 4, 5}) + for s in sets: + if node.data in s: + return s + raise ValueError(f"{node.data} is not in {sets}") + + +def test_disjoint_set(): + """ + >>> test_disjoint_set() + """ + vertex = [Node(i) for i in range(6)] + for v in vertex: + make_set(v) + + union_set(vertex[0], vertex[1]) + union_set(vertex[1], vertex[2]) + union_set(vertex[3], vertex[4]) + union_set(vertex[3], vertex[5]) + + for node0 in vertex: + for node1 in vertex: + if find_python_set(node0).isdisjoint(find_python_set(node1)): + assert find_set(node0) != find_set(node1) + else: + assert find_set(node0) == find_set(node1) + + +if __name__ == "__main__": + test_disjoint_set()