union_find¶
ソースコード¶
from titan_pylib.data_structures.union_find.union_find import UnionFind
展開済みコード¶
1# from titan_pylib.data_structures.union_find.union_find import UnionFind
2from collections import defaultdict
3
4
5class UnionFind:
6
7 def __init__(self, n: int) -> None:
8 """``n`` 個の要素からなる ``UnionFind`` を構築します。
9 :math:`O(n)` です。
10 """
11 self._n: int = n
12 self._group_numbers: int = n
13 self._parents: list[int] = [-1] * n
14
15 def root(self, x: int) -> int:
16 """要素 ``x`` を含む集合の代表元を返します。
17 :math:`O(\\alpha(n))` です。
18 """
19 a = x
20 while self._parents[a] >= 0:
21 a = self._parents[a]
22 while self._parents[x] >= 0:
23 y = x
24 x = self._parents[x]
25 self._parents[y] = a
26 return a
27
28 def unite(self, x: int, y: int) -> bool:
29 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
30 :math:`O(\\alpha(n))` です。
31
32 Returns:
33 bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。
34 """
35 x = self.root(x)
36 y = self.root(y)
37 if x == y:
38 return False
39 self._group_numbers -= 1
40 if self._parents[x] > self._parents[y]:
41 x, y = y, x
42 self._parents[x] += self._parents[y]
43 self._parents[y] = x
44 return True
45
46 def unite_right(self, x: int, y: int) -> int:
47 # x -> y
48 x = self.root(x)
49 y = self.root(y)
50 if x == y:
51 return x
52 self._group_numbers -= 1
53 self._parents[y] += self._parents[x]
54 self._parents[x] = y
55 return y
56
57 def unite_left(self, x: int, y: int) -> int:
58 # x <- y
59 x = self.root(x)
60 y = self.root(y)
61 if x == y:
62 return x
63 self._group_numbers -= 1
64 self._parents[x] += self._parents[y]
65 self._parents[y] = x
66 return x
67
68 def size(self, x: int) -> int:
69 """要素 ``x`` を含む集合の要素数を返します。
70 :math:`O(\\alpha(n))` です。
71 """
72 return -self._parents[self.root(x)]
73
74 def same(self, x: int, y: int) -> bool:
75 """
76 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
77 そうでないなら ``False`` を返します。
78 :math:`O(\\alpha(n))` です。
79 """
80 return self.root(x) == self.root(y)
81
82 def members(self, x: int) -> list[int]:
83 """要素 ``x`` を含む集合を返します。"""
84 x = self.root(x)
85 return [i for i in range(self._n) if self.root(i) == x]
86
87 def all_roots(self) -> list[int]:
88 """全ての集合の代表元からなるリストを返します。
89 :math:`O(n)` です。
90
91 Returns:
92 list[int]: 昇順であることが保証されます。
93 """
94 return [i for i, x in enumerate(self._parents) if x < 0]
95
96 def group_count(self) -> int:
97 """集合の総数を返します。
98 :math:`O(1)` です。
99 """
100 return self._group_numbers
101
102 def all_group_members(self) -> defaultdict:
103 """
104 key に代表元、 value に key を代表元とする集合のリストをもつ defaultdict を返します。
105 :math:`O(n\\alpha(n))` です。
106 """
107 group_members = defaultdict(list)
108 for member in range(self._n):
109 group_members[self.root(member)].append(member)
110 return group_members
111
112 def clear(self) -> None:
113 """集合の連結状態をなくします(初期状態に戻します)。
114 :math:`O(n)` です。
115 """
116 self._group_numbers = self._n
117 for i in range(self._n):
118 self._parents[i] = -1
119
120 def __str__(self) -> str:
121 """よしなにします。
122 :math:`O(n\\alpha(n))` です。
123 """
124 return (
125 f"<{self.__class__.__name__}> [\n"
126 + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items())
127 + "\n]"
128 )
仕様¶
- class UnionFind(n: int)[source]¶
Bases:
object
- all_group_members() defaultdict [source]¶
key に代表元、 value に key を代表元とする集合のリストをもつ defaultdict を返します。 \(O(n\alpha(n))\) です。
- all_roots() list[int] [source]¶
全ての集合の代表元からなるリストを返します。 \(O(n)\) です。
- Returns:
昇順であることが保証されます。
- Return type:
list[int]
- same(x: int, y: int) bool [source]¶
要素
x
とy
が同じ集合に属するならTrue
を、 そうでないならFalse
を返します。 \(O(\alpha(n))\) です。