fully_retroactive_union_find¶
ソースコード¶
from titan_pylib.data_structures.union_find.fully_retroactive_union_find import FullyRetroactiveUnionFind
展開済みコード¶
1# from titan_pylib.data_structures.union_find.fully_retroactive_union_find import FullyRetroactiveUnionFind
2# from titan_pylib.data_structures.dynamic_connectivity.link_cut_tree import LinkCutTree
3from array import array
4
5
6class LinkCutTree:
7 """LinkCutTree です。"""
8
9 # - link / cut / merge / split
10 # - root / same
11 # - lca / path_length / path_kth_elm
12 # など
13
14 def __init__(self, n: int) -> None:
15 self.n = n
16 self.arr: array[int] = array("I", [self.n, self.n, self.n, 0] * (self.n + 1))
17 # node.left : arr[node<<2|0]
18 # node.right : arr[node<<2|1]
19 # node.par : arr[node<<2|2]
20 # node.rev : arr[node<<2|3]
21 self.size: array[int] = array("I", [1] * (self.n + 1))
22 self.size[-1] = 0
23 self.group_cnt = self.n
24
25 def _is_root(self, node: int) -> bool:
26 return (self.arr[node << 2 | 2] == self.n) or not (
27 self.arr[self.arr[node << 2 | 2] << 2] == node
28 or self.arr[self.arr[node << 2 | 2] << 2 | 1] == node
29 )
30
31 def _propagate(self, node: int) -> None:
32 if node == self.n:
33 return
34 arr = self.arr
35 if arr[node << 2 | 3]:
36 arr[node << 2 | 3] = 0
37 ln, rn = arr[node << 2], arr[node << 2 | 1]
38 arr[node << 2] = rn
39 arr[node << 2 | 1] = ln
40 arr[ln << 2 | 3] ^= 1
41 arr[rn << 2 | 3] ^= 1
42
43 def _update(self, node: int) -> None:
44 if node == self.n:
45 return
46 ln, rn = self.arr[node << 2], self.arr[node << 2 | 1]
47 self._propagate(ln)
48 self._propagate(rn)
49 self.size[node] = 1 + self.size[ln] + self.size[rn]
50
51 def _update_triple(self, x: int, y: int, z: int) -> None:
52 self._propagate(self.arr[x << 2])
53 self._propagate(self.arr[x << 2 | 1])
54 self._propagate(self.arr[y << 2])
55 self._propagate(self.arr[y << 2 | 1])
56 self.size[z] = self.size[x]
57 self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]]
58 self.size[y] = 1 + self.size[self.arr[y << 2]] + self.size[self.arr[y << 2 | 1]]
59
60 def _update_double(self, x: int, y: int) -> None:
61 self._propagate(self.arr[x << 2])
62 self._propagate(self.arr[x << 2 | 1])
63 self.size[y] = self.size[x]
64 self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]]
65
66 def _splay(self, node: int) -> None:
67 # splayを抜けた後、nodeは遅延伝播済みにするようにする
68 # (splay後のnodeのleft,rightにアクセスしやすいと非常にラクなはず)
69 if node == self.n:
70 return
71 _propagate, _is_root, _update_triple = (
72 self._propagate,
73 self._is_root,
74 self._update_triple,
75 )
76 _propagate(node)
77 if _is_root(node):
78 return
79 arr = self.arr
80 pnode = arr[node << 2 | 2]
81 while not _is_root(pnode):
82 gnode = arr[pnode << 2 | 2]
83 _propagate(gnode)
84 _propagate(pnode)
85 _propagate(node)
86 f = arr[pnode << 2] == node
87 g = (arr[gnode << 2 | f] == pnode) ^ (arr[pnode << 2 | f] == node)
88 nnode = (node if g else pnode) << 2 | f ^ g
89 arr[pnode << 2 | f ^ 1] = arr[node << 2 | f]
90 arr[gnode << 2 | f ^ g ^ 1] = arr[nnode]
91 arr[node << 2 | f] = pnode
92 arr[nnode] = gnode
93 arr[node << 2 | 2] = arr[gnode << 2 | 2]
94 arr[gnode << 2 | 2] = nnode >> 2
95 arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode
96 arr[arr[gnode << 2 | f ^ g ^ 1] << 2 | 2] = gnode
97 arr[pnode << 2 | 2] = node
98 _update_triple(gnode, pnode, node)
99 pnode = arr[node << 2 | 2]
100 if arr[pnode << 2] == gnode:
101 arr[pnode << 2] = node
102 elif arr[pnode << 2 | 1] == gnode:
103 arr[pnode << 2 | 1] = node
104 else:
105 return
106 _propagate(pnode)
107 _propagate(node)
108 f = arr[pnode << 2] == node
109 arr[pnode << 2 | f ^ 1] = arr[node << 2 | f]
110 arr[node << 2 | f] = pnode
111 arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode
112 arr[node << 2 | 2] = arr[pnode << 2 | 2]
113 arr[pnode << 2 | 2] = node
114 self._update_double(pnode, node)
115
116 def expose(self, v: int) -> int:
117 """``v`` が属する木において、その木を管理しているsplay木の根から ``v`` までのパスを作ります。
118 償却 :math:`O(\\log{n})` です。
119 """
120 arr, n, _splay, _update = self.arr, self.n, self._splay, self._update
121 pre = v
122 while arr[v << 2 | 2] != n:
123 _splay(v)
124 arr[v << 2 | 1] = n
125 _update(v)
126 if arr[v << 2 | 2] == n:
127 break
128 pre = arr[v << 2 | 2]
129 _splay(pre)
130 arr[pre << 2 | 1] = v
131 _update(pre)
132 arr[v << 2 | 1] = n
133 _update(v)
134 return pre
135
136 def lca(self, u: int, v: int, root: int) -> int:
137 """``root`` を根としたときの、 ``u``, ``v`` の LCA を返します。
138 償却 :math:`O(\\log{n})` です。
139 """
140 self.evert(root)
141 self.expose(u)
142 return self.expose(v)
143
144 def link(self, c: int, p: int) -> None:
145 """辺 ``(c -> p)`` を追加します。
146 償却 :math:`O(\\log{n})` です。
147
148 制約:
149 ``c`` は元の木の根でなければならないです。
150 """
151 assert not self.same(c, p)
152 self.expose(c)
153 self.expose(p)
154 self.arr[c << 2 | 2] = p
155 self.arr[p << 2 | 1] = c
156 self._update(p)
157 self.group_cnt -= 1
158
159 def cut(self, c: int) -> None:
160 """辺 ``{c -> cの親}`` を削除します。
161 償却 :math:`O(\\log{n})` です。
162
163 制約:
164 ``c`` は元の木の根であってはいけないです。
165 """
166 arr = self.arr
167 self.expose(c)
168 assert arr[c << 2] != self.n
169 arr[arr[c << 2] << 2 | 2] = self.n
170 arr[c << 2] = self.n
171 self._update(c)
172 self.group_cnt += 1
173
174 def group_count(self) -> int:
175 """連結成分数を返します。
176 :math:`O(1)` です。
177 """
178 return self.group_cnt
179
180 def root(self, v: int) -> int:
181 """``v`` が属する木の根を返します。
182 償却 :math:`O(\\log{n})` です。
183 """
184 self.expose(v)
185 arr, n = self.arr, self.n
186 while arr[v << 2] != n:
187 v = arr[v << 2]
188 self._propagate(v)
189 self._splay(v)
190 return v
191
192 def same(self, u: int, v: int) -> bool:
193 """連結判定です。
194 償却 :math:`O(\\log{n})` です。
195
196 Returns:
197 bool: ``u``, ``v`` が同じ連結成分であれば ``True`` を、そうでなければ ``False`` を返します。
198 """
199 return self.root(u) == self.root(v)
200
201 def evert(self, v: int) -> None:
202 """``v`` を根にします。
203 償却 :math:`O(\\log{n})` です。
204 """
205 self.expose(v)
206 self.arr[v << 2 | 3] ^= 1
207 self._propagate(v)
208
209 def merge(self, u: int, v: int) -> bool:
210 """``u``, ``v`` が同じ連結成分なら ``False`` を返します。
211 そうでなければ辺 ``{u -> v}`` を追加して ``True`` を返します。
212 償却 :math:`O(\\log{n})` です。
213 """
214 if self.same(u, v):
215 return False
216 self.evert(u)
217 self.expose(v)
218 self.arr[u << 2 | 2] = v
219 self.arr[v << 2 | 1] = u
220 self._update(v)
221 self.group_cnt -= 1
222 return True
223
224 def split(self, u: int, v: int) -> bool:
225 """辺 ``{u -> v}`` があれば削除し ``True`` を返します。
226 そうでなければ何もせず ``False`` を返します。
227 償却 :math:`O(\\log{n})` です。
228 """
229 self.evert(u)
230 self.cut(v)
231 return True
232
233 def path_length(self, u: int, v: int) -> int:
234 """``u`` から ``v`` へのパスに含まれる頂点の数を返します。
235 存在しないときは ``-1`` を返します。
236 償却 :math:`O(\\log{n})` です。
237 """
238 if not self.same(u, v):
239 return -1
240 self.evert(u)
241 self.expose(v)
242 return self.size[v]
243
244 def path_kth_elm(self, s: int, t: int, k: int) -> int:
245 """``u`` から ``v`` へ ``k`` 個進んだ頂点を返します。
246 存在しないときは ``-1`` を返します。
247 償却 :math:`O(\\log{n})` です。
248 """
249 self.evert(s)
250 self.expose(t)
251 if self.size[t] <= k:
252 return -1
253 size, arr = self.size, self.arr
254 while True:
255 self._propagate(t)
256 s = size[arr[t << 2]]
257 if s == k:
258 self._splay(t)
259 return t
260 t = arr[t << 2 | (s < k)]
261 if s < k:
262 k -= s + 1
263
264 def __str__(self):
265 return f"{self.__class__.__name__}"
266
267 __repr__ = __str__
268
269
270class FullyRetroactiveUnionFind:
271
272 def __init__(self, n: int, m: int) -> None:
273 """頂点数 ``n`` 、クエリ列の長さ ``m`` の ``FullyRetroactiveUnionFind`` を作ります。
274
275 ここで、クエリは `unite` のみです。
276
277 :math:`O(n+m)` です。
278
279 Args:
280 n (int): 頂点数です。
281 m (int): クエリ列の長さです。
282 """
283 m += 1
284 self.n: int = n
285 self.edge: list[tuple[int, int, int]] = [()] * m
286 self.node_pool: set[int] = set(range(n, n + m))
287 self.lct: LinkCutTree[int, None] = LinkCutTree(
288 n + m,
289 op=lambda s, t: s if s > t else t,
290 mapping=lambda f, s: -1,
291 composition=lambda f, g: None,
292 e=-1,
293 id=None,
294 )
295
296 def unite(self, u: int, v: int, t: int) -> None:
297 """時刻 ``t`` のクエリを ``unite(u, v)`` にします。
298
299 償却 :math:`O(\\log{(n+m)})` です。
300
301 Args:
302 u (int): 集合の要素です。
303 v (int): 集合の要素です。
304 t (int): 時刻です。
305
306 Note:
307 ``disconnect`` を使用する場合、 ``u``, ``v`` が連結されていてはいけません。
308 """
309 node = self.node_pool.pop()
310 self.edge[t] = (u, v, node)
311 self.lct[node] = t
312 self.lct.merge(u, node)
313 self.lct.merge(node, v)
314
315 def disconnect(self, t: int) -> None:
316 """時刻 ``t`` の連結クエリをなくして、そのクエリの2頂点を非連結にします。
317
318 償却 :math:`O(\\log{(n+m)})` です。
319
320 Args:
321 t (int): 時刻です。
322
323 Note:
324 時刻 ``t`` のクエリは連結クエリでないといけません。
325 """
326 assert self.edge[t] is not None
327 u, v, node = self.edge[t]
328 self.node_pool.add(node)
329 self.edge[t] = None
330 self.lct.split(u, node)
331 self.lct.split(node, v)
332
333 def same(self, u: int, v: int, t: int) -> bool:
334 """時刻 ``t`` で ``u``, ``v`` の連結判定をします。
335
336 償却 :math:`O(\\log{(n+m)})` です。
337
338 Args:
339 u (int): 集合の要素です。
340 v (int): 集合の要素です。
341 t (int): 時刻です。
342
343 Returns:
344 bool:
345 """
346 if not self.lct.same(u, v):
347 return False
348 return self.lct.path_prod(u, v) <= t
仕様¶
- class FullyRetroactiveUnionFind(n: int, m: int)[source]¶
Bases:
object
- disconnect(t: int) None [source]¶
時刻
t
の連結クエリをなくして、そのクエリの2頂点を非連結にします。償却 \(O(\log{(n+m)})\) です。
- Parameters:
t (int) – 時刻です。
Note
時刻
t
のクエリは連結クエリでないといけません。