lazy_link_cut_tree¶
ソースコード¶
from titan_pylib.data_structures.dynamic_connectivity.lazy_link_cut_tree import LazyLinkCutTree
展開済みコード¶
1# from titan_pylib.data_structures.dynamic_connectivity.lazy_link_cut_tree import LazyLinkCutTree
2from array import array
3from typing import Generic, TypeVar, Callable, Iterable, Union
4# from titan_pylib.data_structures.dynamic_connectivity.link_cut_tree import LinkCutTree
5from array import array
6
7
8class LinkCutTree:
9 """LinkCutTree です。"""
10
11 # - link / cut / merge / split
12 # - root / same
13 # - lca / path_length / path_kth_elm
14 # など
15
16 def __init__(self, n: int) -> None:
17 self.n = n
18 self.arr: array[int] = array("I", [self.n, self.n, self.n, 0] * (self.n + 1))
19 # node.left : arr[node<<2|0]
20 # node.right : arr[node<<2|1]
21 # node.par : arr[node<<2|2]
22 # node.rev : arr[node<<2|3]
23 self.size: array[int] = array("I", [1] * (self.n + 1))
24 self.size[-1] = 0
25 self.group_cnt = self.n
26
27 def _is_root(self, node: int) -> bool:
28 return (self.arr[node << 2 | 2] == self.n) or not (
29 self.arr[self.arr[node << 2 | 2] << 2] == node
30 or self.arr[self.arr[node << 2 | 2] << 2 | 1] == node
31 )
32
33 def _propagate(self, node: int) -> None:
34 if node == self.n:
35 return
36 arr = self.arr
37 if arr[node << 2 | 3]:
38 arr[node << 2 | 3] = 0
39 ln, rn = arr[node << 2], arr[node << 2 | 1]
40 arr[node << 2] = rn
41 arr[node << 2 | 1] = ln
42 arr[ln << 2 | 3] ^= 1
43 arr[rn << 2 | 3] ^= 1
44
45 def _update(self, node: int) -> None:
46 if node == self.n:
47 return
48 ln, rn = self.arr[node << 2], self.arr[node << 2 | 1]
49 self._propagate(ln)
50 self._propagate(rn)
51 self.size[node] = 1 + self.size[ln] + self.size[rn]
52
53 def _update_triple(self, x: int, y: int, z: int) -> None:
54 self._propagate(self.arr[x << 2])
55 self._propagate(self.arr[x << 2 | 1])
56 self._propagate(self.arr[y << 2])
57 self._propagate(self.arr[y << 2 | 1])
58 self.size[z] = self.size[x]
59 self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]]
60 self.size[y] = 1 + self.size[self.arr[y << 2]] + self.size[self.arr[y << 2 | 1]]
61
62 def _update_double(self, x: int, y: int) -> None:
63 self._propagate(self.arr[x << 2])
64 self._propagate(self.arr[x << 2 | 1])
65 self.size[y] = self.size[x]
66 self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]]
67
68 def _splay(self, node: int) -> None:
69 # splayを抜けた後、nodeは遅延伝播済みにするようにする
70 # (splay後のnodeのleft,rightにアクセスしやすいと非常にラクなはず)
71 if node == self.n:
72 return
73 _propagate, _is_root, _update_triple = (
74 self._propagate,
75 self._is_root,
76 self._update_triple,
77 )
78 _propagate(node)
79 if _is_root(node):
80 return
81 arr = self.arr
82 pnode = arr[node << 2 | 2]
83 while not _is_root(pnode):
84 gnode = arr[pnode << 2 | 2]
85 _propagate(gnode)
86 _propagate(pnode)
87 _propagate(node)
88 f = arr[pnode << 2] == node
89 g = (arr[gnode << 2 | f] == pnode) ^ (arr[pnode << 2 | f] == node)
90 nnode = (node if g else pnode) << 2 | f ^ g
91 arr[pnode << 2 | f ^ 1] = arr[node << 2 | f]
92 arr[gnode << 2 | f ^ g ^ 1] = arr[nnode]
93 arr[node << 2 | f] = pnode
94 arr[nnode] = gnode
95 arr[node << 2 | 2] = arr[gnode << 2 | 2]
96 arr[gnode << 2 | 2] = nnode >> 2
97 arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode
98 arr[arr[gnode << 2 | f ^ g ^ 1] << 2 | 2] = gnode
99 arr[pnode << 2 | 2] = node
100 _update_triple(gnode, pnode, node)
101 pnode = arr[node << 2 | 2]
102 if arr[pnode << 2] == gnode:
103 arr[pnode << 2] = node
104 elif arr[pnode << 2 | 1] == gnode:
105 arr[pnode << 2 | 1] = node
106 else:
107 return
108 _propagate(pnode)
109 _propagate(node)
110 f = arr[pnode << 2] == node
111 arr[pnode << 2 | f ^ 1] = arr[node << 2 | f]
112 arr[node << 2 | f] = pnode
113 arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode
114 arr[node << 2 | 2] = arr[pnode << 2 | 2]
115 arr[pnode << 2 | 2] = node
116 self._update_double(pnode, node)
117
118 def expose(self, v: int) -> int:
119 """``v`` が属する木において、その木を管理しているsplay木の根から ``v`` までのパスを作ります。
120 償却 :math:`O(\\log{n})` です。
121 """
122 arr, n, _splay, _update = self.arr, self.n, self._splay, self._update
123 pre = v
124 while arr[v << 2 | 2] != n:
125 _splay(v)
126 arr[v << 2 | 1] = n
127 _update(v)
128 if arr[v << 2 | 2] == n:
129 break
130 pre = arr[v << 2 | 2]
131 _splay(pre)
132 arr[pre << 2 | 1] = v
133 _update(pre)
134 arr[v << 2 | 1] = n
135 _update(v)
136 return pre
137
138 def lca(self, u: int, v: int, root: int) -> int:
139 """``root`` を根としたときの、 ``u``, ``v`` の LCA を返します。
140 償却 :math:`O(\\log{n})` です。
141 """
142 self.evert(root)
143 self.expose(u)
144 return self.expose(v)
145
146 def link(self, c: int, p: int) -> None:
147 """辺 ``(c -> p)`` を追加します。
148 償却 :math:`O(\\log{n})` です。
149
150 制約:
151 ``c`` は元の木の根でなければならないです。
152 """
153 assert not self.same(c, p)
154 self.expose(c)
155 self.expose(p)
156 self.arr[c << 2 | 2] = p
157 self.arr[p << 2 | 1] = c
158 self._update(p)
159 self.group_cnt -= 1
160
161 def cut(self, c: int) -> None:
162 """辺 ``{c -> cの親}`` を削除します。
163 償却 :math:`O(\\log{n})` です。
164
165 制約:
166 ``c`` は元の木の根であってはいけないです。
167 """
168 arr = self.arr
169 self.expose(c)
170 assert arr[c << 2] != self.n
171 arr[arr[c << 2] << 2 | 2] = self.n
172 arr[c << 2] = self.n
173 self._update(c)
174 self.group_cnt += 1
175
176 def group_count(self) -> int:
177 """連結成分数を返します。
178 :math:`O(1)` です。
179 """
180 return self.group_cnt
181
182 def root(self, v: int) -> int:
183 """``v`` が属する木の根を返します。
184 償却 :math:`O(\\log{n})` です。
185 """
186 self.expose(v)
187 arr, n = self.arr, self.n
188 while arr[v << 2] != n:
189 v = arr[v << 2]
190 self._propagate(v)
191 self._splay(v)
192 return v
193
194 def same(self, u: int, v: int) -> bool:
195 """連結判定です。
196 償却 :math:`O(\\log{n})` です。
197
198 Returns:
199 bool: ``u``, ``v`` が同じ連結成分であれば ``True`` を、そうでなければ ``False`` を返します。
200 """
201 return self.root(u) == self.root(v)
202
203 def evert(self, v: int) -> None:
204 """``v`` を根にします。
205 償却 :math:`O(\\log{n})` です。
206 """
207 self.expose(v)
208 self.arr[v << 2 | 3] ^= 1
209 self._propagate(v)
210
211 def merge(self, u: int, v: int) -> bool:
212 """``u``, ``v`` が同じ連結成分なら ``False`` を返します。
213 そうでなければ辺 ``{u -> v}`` を追加して ``True`` を返します。
214 償却 :math:`O(\\log{n})` です。
215 """
216 if self.same(u, v):
217 return False
218 self.evert(u)
219 self.expose(v)
220 self.arr[u << 2 | 2] = v
221 self.arr[v << 2 | 1] = u
222 self._update(v)
223 self.group_cnt -= 1
224 return True
225
226 def split(self, u: int, v: int) -> bool:
227 """辺 ``{u -> v}`` があれば削除し ``True`` を返します。
228 そうでなければ何もせず ``False`` を返します。
229 償却 :math:`O(\\log{n})` です。
230 """
231 self.evert(u)
232 self.cut(v)
233 return True
234
235 def path_length(self, u: int, v: int) -> int:
236 """``u`` から ``v`` へのパスに含まれる頂点の数を返します。
237 存在しないときは ``-1`` を返します。
238 償却 :math:`O(\\log{n})` です。
239 """
240 if not self.same(u, v):
241 return -1
242 self.evert(u)
243 self.expose(v)
244 return self.size[v]
245
246 def path_kth_elm(self, s: int, t: int, k: int) -> int:
247 """``u`` から ``v`` へ ``k`` 個進んだ頂点を返します。
248 存在しないときは ``-1`` を返します。
249 償却 :math:`O(\\log{n})` です。
250 """
251 self.evert(s)
252 self.expose(t)
253 if self.size[t] <= k:
254 return -1
255 size, arr = self.size, self.arr
256 while True:
257 self._propagate(t)
258 s = size[arr[t << 2]]
259 if s == k:
260 self._splay(t)
261 return t
262 t = arr[t << 2 | (s < k)]
263 if s < k:
264 k -= s + 1
265
266 def __str__(self):
267 return f"{self.__class__.__name__}"
268
269 __repr__ = __str__
270
271T = TypeVar("T")
272F = TypeVar("F")
273
274
275class LazyLinkCutTree(LinkCutTree, Generic[T, F]):
276 """LazyLinkCutTree です。"""
277
278 # パスクエリ全部載せ
279 # - link / cut / merge / split
280 # - prod / apply / getitem / setitem
281 # - root / same
282 # - lca / path_length / path_kth_elm
283 # など
284
285 # opがいらないならupdateを即returnするように変更したり、
286 # 可換opならupdateを短縮したりなど
287
288 def __init__(
289 self,
290 n_or_a: Union[int, Iterable[T]],
291 op: Callable[[T, T], T],
292 mapping: Callable[[F, T], T],
293 composition: Callable[[F, F], F],
294 e: T,
295 id: F,
296 ) -> None:
297 """
298 各引数は遅延セグ木のアレです。よしなに。
299
300 Args:
301 op (Callable[[T, T], T]): 非可換でも構いません。
302 """
303 self.op = op
304 self.mapping = mapping
305 self.composition = composition
306 self.e = e
307 self.id = id
308 self.key: list[T] = [e] * (n_or_a) if isinstance(n_or_a, int) else list(n_or_a)
309 self.n = len(self.key)
310 self.key.append(e)
311 self.data: list[T] = [x for x in self.key for _ in range(2)]
312 self.lazy: list[F] = [id] * (self.n + 1)
313 self.arr: array[int] = array("I", [self.n, self.n, self.n, 0] * (self.n + 1))
314 # node.left : arr[node<<2|0]
315 # node.right : arr[node<<2|1]
316 # node.par : arr[node<<2|2]
317 # node.rev : arr[node<<2|3]
318 self.size: array[int] = array("I", [1] * (self.n + 1))
319 self.size[-1] = 0
320 self.group_cnt = self.n
321
322 def _propagate_lazy(self, node: int, f: F) -> None:
323 if node == self.n:
324 return
325 self.key[node] = self.mapping(f, self.key[node])
326 self.data[node << 1] = self.mapping(f, self.data[node << 1])
327 self.data[node << 1 | 1] = self.mapping(f, self.data[node << 1 | 1])
328 self.lazy[node] = (
329 f if self.lazy[node] == self.id else self.composition(f, self.lazy[node])
330 )
331
332 def _propagate(self, node: int) -> None:
333 if node == self.n:
334 return
335 arr = self.arr
336 if arr[node << 2 | 3]:
337 self.data[node << 1], self.data[node << 1 | 1] = (
338 self.data[node << 1 | 1],
339 self.data[node << 1],
340 )
341 arr[node << 2 | 3] = 0
342 ln, rn = arr[node << 2], arr[node << 2 | 1]
343 arr[node << 2] = rn
344 arr[node << 2 | 1] = ln
345 arr[ln << 2 | 3] ^= 1
346 arr[rn << 2 | 3] ^= 1
347 if self.lazy[node] == self.id:
348 return
349 self._propagate_lazy(self.arr[node << 2], self.lazy[node])
350 self._propagate_lazy(self.arr[node << 2 | 1], self.lazy[node])
351 self.lazy[node] = self.id
352
353 def _update(self, node: int) -> None:
354 if node == self.n:
355 return
356 ln, rn = self.arr[node << 2], self.arr[node << 2 | 1]
357 self._propagate(ln)
358 self._propagate(rn)
359 self.data[node << 1] = self.op(
360 self.op(self.data[ln << 1], self.key[node]), self.data[rn << 1]
361 )
362 self.data[node << 1 | 1] = self.op(
363 self.op(self.data[rn << 1 | 1], self.key[node]), self.data[ln << 1 | 1]
364 )
365 self.size[node] = 1 + self.size[ln] + self.size[rn]
366
367 def _update_triple(self, x: int, y: int, z: int) -> None:
368 data, key, arr, size = self.data, self.key, self.arr, self.size
369 lx, rx = arr[x << 2], arr[x << 2 | 1]
370 ly, ry = arr[y << 2], arr[y << 2 | 1]
371 self._propagate(lx)
372 self._propagate(rx)
373 self._propagate(ly)
374 self._propagate(ry)
375 data[z << 1] = data[x << 1]
376 data[x << 1] = self.op(self.op(data[lx << 1], key[x]), data[rx << 1])
377 data[y << 1] = self.op(self.op(data[ly << 1], key[y]), data[ry << 1])
378 data[z << 1 | 1] = data[x << 1 | 1]
379 data[x << 1 | 1] = self.op(
380 self.op(data[rx << 1 | 1], key[x]), data[lx << 1 | 1]
381 )
382 data[y << 1 | 1] = self.op(
383 self.op(data[ry << 1 | 1], key[y]), data[ly << 1 | 1]
384 )
385 size[z] = size[x]
386 size[x] = 1 + size[lx] + size[rx]
387 size[y] = 1 + size[ly] + size[ry]
388
389 def _update_double(self, x: int, y: int) -> None:
390 data, key, arr, size = self.data, self.key, self.arr, self.size
391 lx, rx = arr[x << 2], arr[x << 2 | 1]
392 self._propagate(lx)
393 self._propagate(rx)
394 data[y << 1] = data[x << 1]
395 data[x << 1] = self.op(self.op(data[lx << 1], key[x]), data[rx << 1])
396 data[y << 1 | 1] = data[x << 1 | 1]
397 data[x << 1 | 1] = self.op(
398 self.op(data[rx << 1 | 1], key[x]), data[lx << 1 | 1]
399 )
400 size[y] = size[x]
401 size[x] = 1 + size[lx] + size[rx]
402
403 def path_prod(self, u: int, v: int) -> T:
404 """``u`` から ``v`` へのパスの総積を返します。
405 償却 :math:`O(\\log{n})` です。
406 """
407 self.evert(u)
408 self.expose(v)
409 return self.data[v << 1]
410
411 def path_apply(self, u: int, v: int, f: F) -> None:
412 """``u`` から ``v`` へのパスに ``f`` を作用させます。
413 償却 :math:`O(\\log{n})` です。
414 """
415 self.evert(u)
416 self.expose(v)
417 self._propagate_lazy(v, f)
418
419 def __setitem__(self, k: int, v: T):
420 """頂点 ``k`` の値を ``v`` に更新します。
421 償却 :math:`O(\\log{n})` です。
422 """
423 self._splay(k)
424 self.key[k] = v
425 self._update(k)
426
427 def __getitem__(self, k: int) -> T:
428 """頂点 ``k`` の値を返します。
429 償却 :math:`O(\\log{n})` です。
430 """
431 self._splay(k)
432 return self.key[k]
433
434 def __str__(self):
435 return str([self[i] for i in range(self.n)])
436
437 __repr__ = __str__
仕様¶
- class LazyLinkCutTree(n_or_a: int | Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F)[source]¶
Bases:
LinkCutTree
,Generic
[T
,F
]LazyLinkCutTree です。