Source code for titan_pylib.data_structures.union_find.union_find
1from collections import defaultdict
2
3
[docs]
4class UnionFind:
5
6 def __init__(self, n: int) -> None:
7 """``n`` 個の要素からなる ``UnionFind`` を構築します。
8 :math:`O(n)` です。
9 """
10 self._n: int = n
11 self._group_numbers: int = n
12 self._parents: list[int] = [-1] * n
13
[docs]
14 def root(self, x: int) -> int:
15 """要素 ``x`` を含む集合の代表元を返します。
16 :math:`O(\\alpha(n))` です。
17 """
18 a = x
19 while self._parents[a] >= 0:
20 a = self._parents[a]
21 while self._parents[x] >= 0:
22 y = x
23 x = self._parents[x]
24 self._parents[y] = a
25 return a
26
[docs]
27 def unite(self, x: int, y: int) -> bool:
28 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
29 :math:`O(\\alpha(n))` です。
30
31 Returns:
32 bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。
33 """
34 x = self.root(x)
35 y = self.root(y)
36 if x == y:
37 return False
38 self._group_numbers -= 1
39 if self._parents[x] > self._parents[y]:
40 x, y = y, x
41 self._parents[x] += self._parents[y]
42 self._parents[y] = x
43 return True
44
[docs]
45 def unite_right(self, x: int, y: int) -> int:
46 # x -> y
47 x = self.root(x)
48 y = self.root(y)
49 if x == y:
50 return x
51 self._group_numbers -= 1
52 self._parents[y] += self._parents[x]
53 self._parents[x] = y
54 return y
55
[docs]
56 def unite_left(self, x: int, y: int) -> int:
57 # x <- y
58 x = self.root(x)
59 y = self.root(y)
60 if x == y:
61 return x
62 self._group_numbers -= 1
63 self._parents[x] += self._parents[y]
64 self._parents[y] = x
65 return x
66
[docs]
67 def size(self, x: int) -> int:
68 """要素 ``x`` を含む集合の要素数を返します。
69 :math:`O(\\alpha(n))` です。
70 """
71 return -self._parents[self.root(x)]
72
[docs]
73 def same(self, x: int, y: int) -> bool:
74 """
75 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
76 そうでないなら ``False`` を返します。
77 :math:`O(\\alpha(n))` です。
78 """
79 return self.root(x) == self.root(y)
80
[docs]
81 def members(self, x: int) -> list[int]:
82 """要素 ``x`` を含む集合を返します。"""
83 x = self.root(x)
84 return [i for i in range(self._n) if self.root(i) == x]
85
[docs]
86 def all_roots(self) -> list[int]:
87 """全ての集合の代表元からなるリストを返します。
88 :math:`O(n)` です。
89
90 Returns:
91 list[int]: 昇順であることが保証されます。
92 """
93 return [i for i, x in enumerate(self._parents) if x < 0]
94
[docs]
95 def group_count(self) -> int:
96 """集合の総数を返します。
97 :math:`O(1)` です。
98 """
99 return self._group_numbers
100
[docs]
101 def all_group_members(self) -> defaultdict:
102 """
103 key に代表元、 value に key を代表元とする集合のリストをもつ defaultdict を返します。
104 :math:`O(n\\alpha(n))` です。
105 """
106 group_members = defaultdict(list)
107 for member in range(self._n):
108 group_members[self.root(member)].append(member)
109 return group_members
110
[docs]
111 def clear(self) -> None:
112 """集合の連結状態をなくします(初期状態に戻します)。
113 :math:`O(n)` です。
114 """
115 self._group_numbers = self._n
116 for i in range(self._n):
117 self._parents[i] = -1
118
[docs]
119 def __str__(self) -> str:
120 """よしなにします。
121 :math:`O(n\\alpha(n))` です。
122 """
123 return (
124 f"<{self.__class__.__name__}> [\n"
125 + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items())
126 + "\n]"
127 )