euler_tour_tree¶
ソースコード¶
from titan_pylib.data_structures.dynamic_connectivity.euler_tour_tree import EulerTourTree
展開済みコード¶
1# from titan_pylib.data_structures.dynamic_connectivity.euler_tour_tree import EulerTourTree
2from typing import Generator, Generic, TypeVar, Callable, Iterable, Optional, Union
3from types import GeneratorType
4
5T = TypeVar("T")
6F = TypeVar("F")
7
8
9class EulerTourTree(Generic[T, F]):
10
11 class _Node:
12
13 def __init__(self, key: T, lazy: F):
14 self.key: T = key
15 self.data: T = key
16 self.lazy: F = lazy
17 self.par: Optional[EulerTourTree._Node] = None
18 self.left: Optional[EulerTourTree._Node] = None
19 self.right: Optional[EulerTourTree._Node] = None
20
21 def __str__(self):
22 if self.left is None and self.right is None:
23 return f"(key,par):{self.key,self.data,self.lazy,(self.par.key if self.par else None)}\n"
24 return f"(key,par):{self.key,self.data,self.lazy,(self.par.key if self.par else None)},\n left:{self.left},\n right:{self.right}\n"
25
26 __repr__ = __str__
27
28 def __init__(
29 self,
30 n_or_a: Union[int, Iterable[T]],
31 op: Callable[[T, T], T],
32 mapping: Callable[[F, T], T],
33 composition: Callable[[F, F], F],
34 e: T,
35 id: F,
36 ) -> None:
37 self.op = op
38 self.mapping = mapping
39 self.composition = composition
40 self.e = e
41 self.id = id
42 a = [e for _ in range(n_or_a)] if isinstance(n_or_a, int) else list(n_or_a)
43 self.n: int = len(a)
44 self.ptr_vertex: list[EulerTourTree._Node] = [
45 EulerTourTree._Node(elem, id) for i, elem in enumerate(a)
46 ]
47 self.ptr_edge: dict[tuple[int, int], EulerTourTree._Node] = {}
48 self._group_numbers: int = self.n
49
50 @staticmethod
51 def antirec(func, stack=[]):
52 # 参考: https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
53 def wrappedfunc(*args, **kwargs):
54 if stack:
55 return func(*args, **kwargs)
56 to = func(*args, **kwargs)
57 while True:
58 if isinstance(to, GeneratorType):
59 stack.append(to)
60 to = next(to)
61 else:
62 stack.pop()
63 if not stack:
64 break
65 to = stack[-1].send(to)
66 return to
67
68 return wrappedfunc
69
70 def build(self, G: list[list[int]]) -> None:
71 """隣接リスト ``G`` をもとにして、辺を張ります。
72 :math:`O(n)` です。
73
74 Args:
75 G (list[list[int]]): 隣接リストです。
76
77 Note:
78 ``build`` メソッドを使用する場合は他のメソッドより前に使用しなければなりません。
79 """
80 n, ptr_vertex, ptr_edge, e, id = (
81 self.n,
82 self.ptr_vertex,
83 self.ptr_edge,
84 self.e,
85 self.id,
86 )
87 seen = [0] * n
88 _Node = EulerTourTree._Node
89
90 @EulerTourTree.antirec
91 def dfs(v: int, p: int = -1) -> Generator:
92 a.append(v * n + v)
93 for x in G[v]:
94 if x == p:
95 continue
96 a.append(v * n + x)
97 yield dfs(x, v)
98 a.append(x * n + v)
99 yield
100
101 @EulerTourTree.antirec
102 def rec(l: int, r: int) -> Generator:
103 mid = (l + r) >> 1
104 u, v = divmod(a[mid], n)
105 node = ptr_vertex[u] if u == v else _Node(e, id)
106 if u == v:
107 seen[u] = 1
108 else:
109 ptr_edge[u * n + v] = node
110 if l != mid:
111 node.left = yield rec(l, mid)
112 node.left.par = node
113 if mid + 1 != r:
114 node.right = yield rec(mid + 1, r)
115 node.right.par = node
116 self._update(node)
117 yield node
118
119 for root in range(self.n):
120 if seen[root]:
121 continue
122 a: list[int] = []
123 dfs(root)
124 rec(0, len(a))
125
126 def _popleft(self, v: _Node) -> Optional[_Node]:
127 v = self._left_splay(v)
128 if v.right:
129 v.right.par = None
130 return v.right
131
132 def _pop(self, v: _Node) -> Optional[_Node]:
133 v = self._right_splay(v)
134 if v.left:
135 v.left.par = None
136 return v.left
137
138 def _split_left(self, v: _Node) -> tuple[_Node, Optional[_Node]]:
139 # x, yに分割する。ただし、xはvを含む
140 self._splay(v)
141 x, y = v, v.right
142 if y:
143 y.par = None
144 x.right = None
145 self._update(x)
146 return x, y
147
148 def _split_right(self, v: _Node) -> tuple[Optional[_Node], _Node]:
149 # x, yに分割する。ただし、yはvを含む
150 self._splay(v)
151 x, y = v.left, v
152 if x:
153 x.par = None
154 y.left = None
155 self._update(y)
156 return x, y
157
158 def _merge(self, u: Optional[_Node], v: Optional[_Node]) -> None:
159 if u is None or v is None:
160 return
161 u = self._right_splay(u)
162 self._splay(v)
163 u.right = v
164 v.par = u
165 self._update(u)
166
167 def _splay(self, node: _Node) -> None:
168 self._propagate(node)
169 while node.par is not None and node.par.par is not None:
170 pnode = node.par
171 gnode = pnode.par
172 self._propagate(gnode)
173 self._propagate(pnode)
174 self._propagate(node)
175 node.par = gnode.par
176 if (gnode.left is pnode) == (pnode.left is node):
177 if pnode.left is node:
178 tmp1 = node.right
179 pnode.left = tmp1
180 node.right = pnode
181 pnode.par = node
182 tmp2 = pnode.right
183 gnode.left = tmp2
184 pnode.right = gnode
185 gnode.par = pnode
186 else:
187 tmp1 = node.left
188 pnode.right = tmp1
189 node.left = pnode
190 pnode.par = node
191 tmp2 = pnode.left
192 gnode.right = tmp2
193 pnode.left = gnode
194 gnode.par = pnode
195 if tmp1:
196 tmp1.par = pnode
197 if tmp2:
198 tmp2.par = gnode
199 else:
200 if pnode.left is node:
201 tmp1 = node.right
202 pnode.left = tmp1
203 node.right = pnode
204 tmp2 = node.left
205 gnode.right = tmp2
206 node.left = gnode
207 pnode.par = node
208 gnode.par = node
209 else:
210 tmp1 = node.left
211 pnode.right = tmp1
212 node.left = pnode
213 tmp2 = node.right
214 gnode.left = tmp2
215 node.right = gnode
216 pnode.par = node
217 gnode.par = node
218 if tmp1:
219 tmp1.par = pnode
220 if tmp2:
221 tmp2.par = gnode
222 self._update(gnode)
223 self._update(pnode)
224 self._update(node)
225 if node.par is None:
226 return
227 if node.par.left is gnode:
228 node.par.left = node
229 else:
230 node.par.right = node
231 if node.par is None:
232 return
233 pnode = node.par
234 self._propagate(pnode)
235 self._propagate(node)
236 if pnode.left is node:
237 pnode.left = node.right
238 if pnode.left:
239 pnode.left.par = pnode
240 node.right = pnode
241 else:
242 pnode.right = node.left
243 if pnode.right:
244 pnode.right.par = pnode
245 node.left = pnode
246 node.par = None
247 pnode.par = node
248 self._update(pnode)
249 self._update(node)
250
251 def _left_splay(self, node: _Node) -> _Node:
252 self._splay(node)
253 while node.left is not None:
254 node = node.left
255 self._splay(node)
256 return node
257
258 def _right_splay(self, node: _Node) -> _Node:
259 self._splay(node)
260 while node.right is not None:
261 node = node.right
262 self._splay(node)
263 return node
264
265 def _propagate(self, node: Optional[_Node]) -> None:
266 if node is None or node.lazy == self.id:
267 return
268 if node.left:
269 node.left.key = self.mapping(node.lazy, node.left.key)
270 node.left.data = self.mapping(node.lazy, node.left.data)
271 node.left.lazy = self.composition(node.lazy, node.left.lazy)
272 if node.right:
273 node.right.key = self.mapping(node.lazy, node.right.key)
274 node.right.data = self.mapping(node.lazy, node.right.data)
275 node.right.lazy = self.composition(node.lazy, node.right.lazy)
276 node.lazy = self.id
277
278 def _update(self, node: _Node) -> None:
279 self._propagate(node.left)
280 self._propagate(node.right)
281 node.data = node.key
282 if node.left:
283 node.data = self.op(node.left.data, node.data)
284 if node.right:
285 node.data = self.op(node.data, node.right.data)
286
287 def link(self, u: int, v: int) -> None:
288 """辺 ``{u, v}`` を追加します。
289 :math:`O(\\log{n})` です。
290
291 Note:
292 ``u`` と ``v`` が同じ連結成分であってはいけません。
293 """
294 # add edge{u, v}
295 self.reroot(u)
296 self.reroot(v)
297 assert (
298 u * self.n + v not in self.ptr_edge
299 ), f"EulerTourTree.link(), {(u, v)} in ptr_edge"
300 assert (
301 v * self.n + u not in self.ptr_edge
302 ), f"EulerTourTree.link(), {(v, u)} in ptr_edge"
303 uv_node = EulerTourTree._Node(self.e, self.id)
304 vu_node = EulerTourTree._Node(self.e, self.id)
305 self.ptr_edge[u * self.n + v] = uv_node
306 self.ptr_edge[v * self.n + u] = vu_node
307 u_node = self.ptr_vertex[u]
308 v_node = self.ptr_vertex[v]
309 self._merge(u_node, uv_node)
310 self._merge(uv_node, v_node)
311 self._merge(v_node, vu_node)
312 self._group_numbers -= 1
313
314 def cut(self, u: int, v: int) -> None:
315 """辺 ``{u, v}`` を削除します。
316 :math:`O(\\log{n})` です。
317
318 Note:
319 辺 ``{u, v}`` が存在してなければいけません。
320 """
321 # erace edge{u, v}
322 self.reroot(v)
323 self.reroot(u)
324 assert (
325 u * self.n + v in self.ptr_edge
326 ), f"EulerTourTree.cut(), {(u, v)} not in ptr_edge"
327 assert (
328 v * self.n + u in self.ptr_edge
329 ), f"EulerTourTree.cut(), {(v, u)} not in ptr_edge"
330 uv_node = self.ptr_edge.pop(u * self.n + v)
331 vu_node = self.ptr_edge.pop(v * self.n + u)
332 a, _ = self._split_left(uv_node)
333 _, c = self._split_right(vu_node)
334 a = self._pop(a)
335 c = self._popleft(c)
336 self._merge(a, c)
337 self._group_numbers += 1
338
339 def leader(self, v: int) -> _Node:
340 """頂点 ``v`` を含む木の代表元を返します。
341 :math:`O(\\log{n})` です。
342
343 Note:
344 ``reroot`` すると変わるので注意です。
345 """
346 # vを含む木の代表元
347 # rerootすると変わるので注意
348 return self._left_splay(self.ptr_vertex[v])
349
350 def reroot(self, v: int) -> None:
351 """頂点 ``v`` を含む木の根を ``v`` にします。
352
353 :math:`O(\\log{n})` です。
354 """
355 node = self.ptr_vertex[v]
356 x, y = self._split_right(node)
357 self._merge(y, x)
358 self._splay(node)
359
360 def same(self, u: int, v: int) -> bool:
361 """
362 頂点 ``u`` と ``v`` が同じ連結成分にいれば ``True`` を、
363 そうでなければ ``False`` を返します。
364
365 :math:`O(\\log{n})` です。
366 """
367 u_node = self.ptr_vertex[u]
368 v_node = self.ptr_vertex[v]
369 self._splay(u_node)
370 self._splay(v_node)
371 return u_node.par is not None or u_node is v_node
372
373 def _show(self) -> None:
374 # for debug
375 print("+++++++++++++++++++++++++++")
376 for i, v in enumerate(self.ptr_vertex):
377 print((i, i), v, end="\n\n")
378 for k, v in self.ptr_edge.items():
379 print(k, v, end="\n\n")
380 print("+++++++++++++++++++++++++++")
381
382 def subtree_apply(self, v: int, p: int, f: F) -> None:
383 """頂点 ``v`` を根としたときの部分木に ``f`` を作用します。
384
385 ``v`` の親は ``p`` です。
386 ``v`` の親が存在しないときは ``p=-1`` として下さい。
387
388 :math:`O(\\log{n})` です。
389
390 Args:
391 v (int): 根です。
392 p (int): ``v`` の親です。
393 f (F): 作用素です。
394 """
395 if p == -1:
396 v_node = self.ptr_vertex[v]
397 self._splay(v_node)
398 v_node.key = self.mapping(f, v_node.key)
399 v_node.data = self.mapping(f, v_node.data)
400 v_node.lazy = self.composition(f, v_node.lazy)
401 return
402 self.reroot(v)
403 self.reroot(p)
404 assert (
405 p * self.n + v in self.ptr_edge
406 ), f"EulerTourTree.subtree_apply(), {(p, v)} not in ptr_edge"
407 assert (
408 v * self.n + p in self.ptr_edge
409 ), f"EulerTourTree.subtree_apply(), {(v, p)} not in ptr_edge"
410 v_node = self.ptr_vertex[v]
411 a, b = self._split_right(self.ptr_edge[p * self.n + v])
412 b, d = self._split_left(self.ptr_edge[v * self.n + p])
413 self._splay(v_node)
414 v_node.key = self.mapping(f, v_node.key)
415 v_node.data = self.mapping(f, v_node.data)
416 v_node.lazy = self.composition(f, v_node.lazy)
417 self._propagate(v_node)
418 self._merge(a, b)
419 self._merge(b, d)
420
421 def subtree_sum(self, v: int, p: int) -> T:
422 """頂点 ``v`` を根としたときの部分木の総和を返します。
423
424 ``v`` の親は ``p`` です。
425 ``v`` の親が存在しないときは ``p=-1`` として下さい。
426
427 :math:`O(\\log{n})` です。
428
429 Args:
430 v (int): 根です。
431 p (int): ``v`` の親です。
432 """
433 if p == -1:
434 v_node = self.ptr_vertex[v]
435 self._splay(v_node)
436 return v_node.data
437 self.reroot(v)
438 self.reroot(p)
439 assert (
440 p * self.n + v in self.ptr_edge
441 ), f"EulerTourTree.subtree_sum(), {(p, v)} not in ptr_edge"
442 assert (
443 v * self.n + p in self.ptr_edge
444 ), f"EulerTourTree.subtree_sum(), {(v, p)} not in ptr_edge"
445 v_node = self.ptr_vertex[v]
446 a, b = self._split_right(self.ptr_edge[p * self.n + v])
447 b, d = self._split_left(self.ptr_edge[v * self.n + p])
448 self._splay(v_node)
449 res = v_node.data
450 self._merge(a, b)
451 self._merge(b, d)
452 return res
453
454 def group_count(self) -> int:
455 """連結成分の個数を返します。
456 :math:`O(1)` です。
457 """
458 return self._group_numbers
459
460 def get_vertex(self, v: int) -> T:
461 """頂点 ``v`` の ``key`` を返します。
462 :math:`O(\\log{n})` です。
463 """
464 node = self.ptr_vertex[v]
465 self._splay(node)
466 return node.key
467
468 def set_vertex(self, v: int, val: T) -> None:
469 """頂点 ``v`` の ``key`` を ``val`` に更新します。
470 :math:`O(\\log{n})` です。
471 """
472 node = self.ptr_vertex[v]
473 self._splay(node)
474 node.key = val
475 self._update(node)
476
477 def __getitem__(self, v: int) -> T:
478 return self.get_vertex(v)
479
480 def __setitem__(self, v: int, val: T) -> None:
481 return self.set_vertex(v, val)
仕様¶
- class EulerTourTree(n_or_a: int | Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F)[source]¶
Bases:
Generic
[T
,F
]- build(G: list[list[int]]) None [source]¶
隣接リスト
G
をもとにして、辺を張ります。 \(O(n)\) です。- Parameters:
G (list[list[int]]) – 隣接リストです。
Note
build
メソッドを使用する場合は他のメソッドより前に使用しなければなりません。
- cut(u: int, v: int) None [source]¶
辺
{u, v}
を削除します。 \(O(\log{n})\) です。Note
辺
{u, v}
が存在してなければいけません。
- link(u: int, v: int) None [source]¶
辺
{u, v}
を追加します。 \(O(\log{n})\) です。Note
u
とv
が同じ連結成分であってはいけません。
- same(u: int, v: int) bool [source]¶
頂点
u
とv
が同じ連結成分にいればTrue
を、 そうでなければFalse
を返します。\(O(\log{n})\) です。