2019-09-23 03:08:20 +00:00
|
|
|
"""
|
2024-03-13 06:52:41 +00:00
|
|
|
Disjoint set.
|
|
|
|
Reference: https://en.wikipedia.org/wiki/Disjoint-set_data_structure
|
2019-09-23 03:08:20 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
class Node:
    """An element of a disjoint-set forest carrying an integer payload."""

    def __init__(self, data: int) -> None:
        self.data = data
        # Start out as a singleton set (own parent, rank 0) so that find_set and
        # union_set work even if make_set was never called; make_set resets
        # these fields to exactly the same values.
        self.rank: int = 0
        self.parent: Node = self

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.data!r})"
|
2019-09-23 03:08:20 +00:00
|
|
|
|
|
|
|
|
2021-10-11 16:34:30 +00:00
|
|
|
def make_set(x: Node) -> None:
    """
    Make x a singleton set whose representative is x itself.

    Sets x's parent to x and its rank to 0. Only x's own fields are written;
    any other node that already points at x is left untouched.
    """
    # rank is an upper bound on the height of the tree rooted at x (union_set
    # only increments it when two roots of equal rank are merged); a fresh
    # singleton tree has height 0.
    x.rank = 0
    # A root is recognised by being its own parent.
    x.parent = x
|
|
|
|
|
|
|
|
|
2021-10-11 16:34:30 +00:00
|
|
|
def union_set(x: Node, y: Node) -> None:
    """
    Merge the set containing x with the set containing y (union by rank).

    The root with the strictly larger rank becomes the parent of the other
    root, keeping the trees shallow. On a rank tie, x's root is attached under
    y's root and y's root's rank is incremented. Does nothing if x and y are
    already in the same set.
    """
    root_x = find_set(x)
    root_y = find_set(y)

    if root_x == root_y:
        # Same representative: the sets are already merged.
        return

    if root_x.rank > root_y.rank:
        root_y.parent = root_x
        return

    # root_x.rank <= root_y.rank: hang the shallower (or equal) tree under y's root.
    root_x.parent = root_y
    if root_x.rank == root_y.rank:
        # Equal heights: the merged tree may now be one level taller.
        root_y.rank += 1
|
|
|
|
|
|
|
|
|
2021-10-11 16:34:30 +00:00
|
|
|
def find_set(x: Node) -> Node:
    """
    Return the representative (root) of the set that contains x.

    Applies full path compression: every node on the path from x to the root
    is re-pointed directly at the root, which speeds up later lookups. The
    walk is iterative, so a long parent chain cannot exceed Python's
    recursion limit.
    """
    # First pass: climb parent links until reaching the node that is its own parent.
    root = x
    while root.parent is not root:
        root = root.parent

    # Second pass: point every node on the path straight at the root.
    node = x
    while node is not root:
        next_node = node.parent
        node.parent = root
        node = next_node

    return root
|
|
|
|
|
|
|
|
|
|
|
|
def find_python_set(node: Node, sets: "tuple[set[int], ...] | None" = None) -> set:
    """
    Return the Python Standard Library set, out of *sets*, that contains node.data.

    This is a reference partition used to check the disjoint-set structure.
    When *sets* is omitted, the partition ({0, 1, 2}, {3, 4, 5}) is built
    fresh on every call, so mutating a returned set cannot affect later calls.

    :param node: node whose ``data`` value is looked up.
    :param sets: candidate sets to search in order; any iterable of sets.
    :return: the first set that contains ``node.data``.
    :raises ValueError: if no set contains ``node.data``.
    """
    if sets is None:
        sets = ({0, 1, 2}, {3, 4, 5})
    else:
        # Materialise once so an iterator argument is not exhausted before the
        # error message is formatted.
        sets = tuple(sets)

    for s in sets:
        if node.data in s:
            return s

    msg = f"{node.data} is not in {sets}"
    raise ValueError(msg)
|
2019-09-23 03:08:20 +00:00
|
|
|
|
|
|
|
|
2021-10-11 16:34:30 +00:00
|
|
|
def test_disjoint_set() -> None:
    """
    Check make_set/union_set/find_set against the partition {0, 1, 2} | {3, 4, 5}.

    Two vertices must share a representative exactly when their reference
    Python sets overlap.

    >>> test_disjoint_set()
    """
    vertices = [Node(i) for i in range(6)]
    for vertex in vertices:
        make_set(vertex)

    merges = ((0, 1), (1, 2), (3, 4), (3, 5))
    for left_idx, right_idx in merges:
        union_set(vertices[left_idx], vertices[right_idx])

    for left in vertices:
        for right in vertices:
            expected_same = not find_python_set(left).isdisjoint(
                find_python_set(right)
            )
            actual_same = find_set(left) == find_set(right)
            assert expected_same == actual_same
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the self-check when executed as a script; importing the module is side-effect free.
    test_disjoint_set()
|