from functools import wraps
from types import GeneratorType
from typing import Generator, Generic, TypeVar, Callable, Iterable, Optional, Union

# Value type stored on tour nodes and aggregated with ``op`` (identity ``e``).
T = TypeVar("T")
# Lazy-action type: acts on values via ``mapping``, combined via ``composition``.
F = TypeVar("F")


class EulerTourTree(Generic[T, F]):
    """Dynamic forest kept as Euler tours stored in splay trees.

    Vertex ``v`` owns the node ``ptr_vertex[v]`` and the directed edge
    ``(u, v)`` owns the node ``ptr_edge[u * n + v]``; the in-order sequence
    of a splay tree is the Euler tour of one tree of the forest.  Vertex
    nodes hold user values, edge nodes are created with the key ``e``.
    Node values are aggregated with ``op`` and lazy actions of type ``F``
    are applied through ``mapping`` / ``composition``.
    """

    class _Node:
        """Splay-tree node standing for one vertex or one directed edge."""

        __slots__ = ("key", "data", "lazy", "par", "left", "right")

        def __init__(self, key: T, lazy: F):
            self.key: T = key  # value of this node alone
            self.data: T = key  # op-aggregate of this node's whole subtree
            # action already applied to key/data but not yet to the children
            self.lazy: F = lazy
            self.par: Optional[EulerTourTree._Node] = None
            self.left: Optional[EulerTourTree._Node] = None
            self.right: Optional[EulerTourTree._Node] = None

        def __str__(self):
            if self.left is None and self.right is None:
                return f"(key,par):{self.key,self.data,self.lazy,(self.par.key if self.par else None)}\n"
            return f"(key,par):{self.key,self.data,self.lazy,(self.par.key if self.par else None)},\n left:{self.left},\n right:{self.right}\n"

        __repr__ = __str__

    def __init__(
        self,
        n_or_a: Union[int, Iterable[T]],
        op: Callable[[T, T], T],
        mapping: Callable[[F, T], T],
        composition: Callable[[F, F], F],
        e: T,
        id: F,
    ) -> None:
        """Create a forest of isolated vertices.

        Args:
            n_or_a: number of vertices (every value starts as ``e``) or an
                iterable with the initial value of each vertex.
            op: binary operation used to aggregate values.
            mapping: ``mapping(f, x)`` applies action ``f`` to value ``x``.
            composition: ``composition(f, g)`` is the action "g, then f".
            e: value used for fresh vertices (int form) and for edge nodes.
            id: the "no pending action" value; nodes equal to it skip pushes.
        """
        self.op = op
        self.mapping = mapping
        self.composition = composition
        self.e = e
        self.id = id
        a = [e] * n_or_a if isinstance(n_or_a, int) else list(n_or_a)
        self.n: int = len(a)
        self.ptr_vertex: list[EulerTourTree._Node] = [
            EulerTourTree._Node(elem, id) for elem in a
        ]
        # Directed edge (u, v) -> its tour node, keyed by the int u * n + v.
        self.ptr_edge: dict[int, EulerTourTree._Node] = {}
        self._group_numbers: int = self.n

    @staticmethod
    def antirec(func, stack=[]):
        """Run a generator-based recursive function without Python recursion.

        ``func`` is written as a generator: it ``yield``s a recursive call
        (which returns a generator) to "call" it, receives that call's result
        as the value of the ``yield`` expression, and its last ``yield value``
        is its own return value.  The outermost call drives all generators on
        an explicit stack.

        Reference: https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
        """
        # ``stack`` is evaluated once and deliberately shared by every
        # decorated function: a non-empty stack means a driver loop is already
        # running, so nested calls just hand their generator back to it.

        @wraps(func)
        def wrappedfunc(*args, **kwargs):
            if stack:
                return func(*args, **kwargs)
            to = func(*args, **kwargs)
            if not isinstance(to, GeneratorType):
                # Plain function: nothing to drive (the original popped an
                # empty list here and raised IndexError).
                return to
            try:
                while True:
                    if isinstance(to, GeneratorType):
                        stack.append(to)
                        to = next(to)
                    else:
                        stack.pop()
                        if not stack:
                            break
                        to = stack[-1].send(to)
            except BaseException:
                # Everything on the stack belongs to this driver (it was empty
                # on entry); clear it so later calls are not mistaken for
                # nested ones and silently return raw generators.
                stack.clear()
                raise
            return to

        return wrappedfunc

    def build(self, G: list[list[int]]) -> None:
        """Create the edges of the forest described by the adjacency list ``G``.

        :math:`O(n)` plus the size of ``G``.

        Args:
            G (list[list[int]]): adjacency list indexed by vertex (at least
                ``n`` entries).  The parent of ``v`` may appear in ``G[v]``;
                it is skipped.

        Note:
            ``build`` must be called before any other method.  ``G`` must be
            a forest: the DFS only skips the parent, so a cycle never ends.
        """
        n, ptr_vertex, ptr_edge, e, ident = (
            self.n,
            self.ptr_vertex,
            self.ptr_edge,
            self.e,
            self.id,
        )
        seen = [0] * n
        _Node = EulerTourTree._Node

        @EulerTourTree.antirec
        def dfs(v: int, p: int = -1) -> Generator:
            # Euler tour of v's tree into ``a``: vertex v, then for each
            # child x the edge (v, x), x's tour and the edge (x, v).
            # Every entry is encoded as u * n + w (u == w for vertices).
            a.append(v * n + v)
            for x in G[v]:
                if x == p:
                    continue
                a.append(v * n + x)
                yield dfs(x, v)
                a.append(x * n + v)
            yield

        @EulerTourTree.antirec
        def rec(l: int, r: int) -> Generator:
            # Build a balanced tree over a[l:r] and yield its root.
            mid = (l + r) >> 1
            u, v = divmod(a[mid], n)
            if u == v:
                node = ptr_vertex[u]
                seen[u] = 1
            else:
                node = _Node(e, ident)
                ptr_edge[u * n + v] = node
            if l != mid:
                node.left = yield rec(l, mid)
                node.left.par = node
            if mid + 1 != r:
                node.right = yield rec(mid + 1, r)
                node.right.par = node
            self._update(node)
            yield node

        components = 0
        for root in range(n):
            if seen[root]:
                continue
            components += 1
            a: list[int] = []
            dfs(root)
            rec(0, len(a))
        # Each DFS root starts a new tree, so this is the component count
        # (previously left at n, making group_count() wrong after build).
        self._group_numbers = components

    def _popleft(self, v: _Node) -> Optional[_Node]:
        """Detach the first node of v's sequence; return the rest's root (or None)."""
        v = self._left_splay(v)
        rest = v.right
        if rest is not None:
            rest.par = None
            # Do not leave the removed node pointing into the live tree.
            v.right = None
            self._update(v)
        return rest

    def _pop(self, v: _Node) -> Optional[_Node]:
        """Detach the last node of v's sequence; return the rest's root (or None)."""
        v = self._right_splay(v)
        rest = v.left
        if rest is not None:
            rest.par = None
            v.left = None
            self._update(v)
        return rest

    def _split_left(self, v: _Node) -> tuple[_Node, Optional[_Node]]:
        """Split right after ``v``: (prefix ending with v, suffix or None)."""
        self._splay(v)
        x, y = v, v.right
        if y:
            y.par = None
        x.right = None
        self._update(x)
        return x, y

    def _split_right(self, v: _Node) -> tuple[Optional[_Node], _Node]:
        """Split right before ``v``: (prefix or None, suffix starting with v)."""
        self._splay(v)
        x, y = v.left, v
        if x:
            x.par = None
        y.left = None
        self._update(y)
        return x, y

    def _merge(self, u: Optional[_Node], v: Optional[_Node]) -> None:
        """Append v's sequence after u's sequence; no-op if either is None."""
        if u is None or v is None:
            return
        u = self._right_splay(u)
        self._splay(v)
        u.right = v
        v.par = u
        self._update(u)

    def _splay(self, node: _Node) -> None:
        """Rotate ``node`` to the root of its splay tree.

        Pending actions are pushed top-down (grandparent, parent, node) before
        each rotation, so the rotated nodes carry no lazy value.
        """
        self._propagate(node)
        while node.par is not None and node.par.par is not None:
            pnode = node.par
            gnode = pnode.par
            self._propagate(gnode)
            self._propagate(pnode)
            self._propagate(node)
            node.par = gnode.par
            if (gnode.left is pnode) == (pnode.left is node):
                # zig-zig: node and pnode are children on the same side
                if pnode.left is node:
                    tmp1 = node.right
                    pnode.left = tmp1
                    node.right = pnode
                    pnode.par = node
                    tmp2 = pnode.right
                    gnode.left = tmp2
                    pnode.right = gnode
                    gnode.par = pnode
                else:
                    tmp1 = node.left
                    pnode.right = tmp1
                    node.left = pnode
                    pnode.par = node
                    tmp2 = pnode.left
                    gnode.right = tmp2
                    pnode.left = gnode
                    gnode.par = pnode
                if tmp1:
                    tmp1.par = pnode
                if tmp2:
                    tmp2.par = gnode
            else:
                # zig-zag: node becomes the parent of both pnode and gnode
                if pnode.left is node:
                    tmp1 = node.right
                    pnode.left = tmp1
                    node.right = pnode
                    tmp2 = node.left
                    gnode.right = tmp2
                    node.left = gnode
                    pnode.par = node
                    gnode.par = node
                else:
                    tmp1 = node.left
                    pnode.right = tmp1
                    node.left = pnode
                    tmp2 = node.right
                    gnode.left = tmp2
                    node.right = gnode
                    pnode.par = node
                    gnode.par = node
                if tmp1:
                    tmp1.par = pnode
                if tmp2:
                    tmp2.par = gnode
            # bottom-up: gnode/pnode are now below node
            self._update(gnode)
            self._update(pnode)
            self._update(node)
            if node.par is None:
                return
            # re-attach node where gnode used to hang
            if node.par.left is gnode:
                node.par.left = node
            else:
                node.par.right = node
        if node.par is None:
            return
        # final single rotation (zig) with the root
        pnode = node.par
        self._propagate(pnode)
        self._propagate(node)
        if pnode.left is node:
            pnode.left = node.right
            if pnode.left:
                pnode.left.par = pnode
            node.right = pnode
        else:
            pnode.right = node.left
            if pnode.right:
                pnode.right.par = pnode
            node.left = pnode
        node.par = None
        pnode.par = node
        self._update(pnode)
        self._update(node)

    def _left_splay(self, node: _Node) -> _Node:
        """Splay the first node of node's sequence to the root and return it."""
        self._splay(node)
        while node.left is not None:
            node = node.left
        self._splay(node)
        return node

    def _right_splay(self, node: _Node) -> _Node:
        """Splay the last node of node's sequence to the root and return it."""
        self._splay(node)
        while node.right is not None:
            node = node.right
        self._splay(node)
        return node

    def _apply(self, node: _Node, f: F) -> None:
        """Act with ``f`` on node (key and aggregate now, children lazily)."""
        node.key = self.mapping(f, node.key)
        node.data = self.mapping(f, node.data)
        node.lazy = self.composition(f, node.lazy)

    def _propagate(self, node: Optional[_Node]) -> None:
        """Push node's pending action down to its children."""
        if node is None or node.lazy == self.id:
            return
        # Inlined rather than calling _apply: this is the hottest path.
        if node.left:
            node.left.key = self.mapping(node.lazy, node.left.key)
            node.left.data = self.mapping(node.lazy, node.left.data)
            node.left.lazy = self.composition(node.lazy, node.left.lazy)
        if node.right:
            node.right.key = self.mapping(node.lazy, node.right.key)
            node.right.data = self.mapping(node.lazy, node.right.data)
            node.right.lazy = self.composition(node.lazy, node.right.lazy)
        node.lazy = self.id

    def _update(self, node: _Node) -> None:
        """Recompute ``node.data`` as op(left.data, key, right.data)."""
        self._propagate(node.left)
        self._propagate(node.right)
        node.data = node.key
        if node.left:
            node.data = self.op(node.left.data, node.data)
        if node.right:
            node.data = self.op(node.data, node.right.data)

    def link(self, u: int, v: int) -> None:
        """Add the edge ``{u, v}``.

        :math:`O(\\log{n})` amortized.

        Note:
            ``u`` and ``v`` must not already be in the same connected component.
        """
        n = self.n
        # Both tours must start at their endpoint so the result is
        # [u's tour] (u, v) [v's tour] (v, u).
        self.reroot(u)
        self.reroot(v)
        assert (
            u * n + v not in self.ptr_edge
        ), f"EulerTourTree.link(), {(u, v)} in ptr_edge"
        assert (
            v * n + u not in self.ptr_edge
        ), f"EulerTourTree.link(), {(v, u)} in ptr_edge"
        uv_node = EulerTourTree._Node(self.e, self.id)
        vu_node = EulerTourTree._Node(self.e, self.id)
        self.ptr_edge[u * n + v] = uv_node
        self.ptr_edge[v * n + u] = vu_node
        u_node = self.ptr_vertex[u]
        v_node = self.ptr_vertex[v]
        self._merge(u_node, uv_node)
        self._merge(uv_node, v_node)
        self._merge(v_node, vu_node)
        self._group_numbers -= 1

    def cut(self, u: int, v: int) -> None:
        """Remove the edge ``{u, v}``.

        :math:`O(\\log{n})` amortized.

        Note:
            The edge ``{u, v}`` must exist.
        """
        n = self.n
        # With the tour starting at u, it reads
        # [u ...] (u, v) [v's side] (v, u) [...]; rerooting v first is unneeded.
        self.reroot(u)
        assert (
            u * n + v in self.ptr_edge
        ), f"EulerTourTree.cut(), {(u, v)} not in ptr_edge"
        assert (
            v * n + u in self.ptr_edge
        ), f"EulerTourTree.cut(), {(v, u)} not in ptr_edge"
        uv_node = self.ptr_edge.pop(u * n + v)
        vu_node = self.ptr_edge.pop(v * n + u)
        head, _ = self._split_left(uv_node)  # [u ... (u, v)]
        _, tail = self._split_right(vu_node)  # [(v, u) ...]; v's side is left alone
        head = self._pop(head)  # drop (u, v)
        tail = self._popleft(tail)  # drop (v, u)
        self._merge(head, tail)
        self._group_numbers += 1

    def leader(self, v: int) -> _Node:
        """Return the representative node of the tree containing ``v``.

        It is the first node of the tour, so it changes after ``reroot``.
        :math:`O(\\log{n})` amortized.
        """
        return self._left_splay(self.ptr_vertex[v])

    def reroot(self, v: int) -> None:
        """Make ``v`` the root of its tree (rotate its tour to start at ``v``).

        :math:`O(\\log{n})` amortized.
        """
        node = self.ptr_vertex[v]
        before, rest = self._split_right(node)
        self._merge(rest, before)
        self._splay(node)

    def same(self, u: int, v: int) -> bool:
        """Return ``True`` iff ``u`` and ``v`` are in the same connected component.

        :math:`O(\\log{n})` amortized.
        """
        u_node = self.ptr_vertex[u]
        v_node = self.ptr_vertex[v]
        self._splay(u_node)
        self._splay(v_node)
        # After splaying v, u stops being a root only if it shares v's tree.
        return u_node.par is not None or u_node is v_node

    def _show(self) -> None:
        """Debug helper: print every vertex and edge node."""
        print("+++++++++++++++++++++++++++")
        for i, v in enumerate(self.ptr_vertex):
            print((i, i), v, end="\n\n")
        for k, v in self.ptr_edge.items():
            print(k, v, end="\n\n")
        print("+++++++++++++++++++++++++++")

    def _isolate_subtree(
        self, v: int, p: int, caller: str
    ) -> tuple[_Node, _Node, Optional[_Node]]:
        """Cut out the tour segment of v's subtree (parent ``p``).

        Returns ``(before, v_node, after)`` where ``v_node`` is splayed to the
        root of the middle segment ``[(p, v) ... (v, p)]``; re-join with
        ``_merge(before, v_node)`` then ``_merge(v_node, after)``.
        """
        n = self.n
        # Starting the tour at p puts (p, v) before (v, p).
        self.reroot(p)
        assert (
            p * n + v in self.ptr_edge
        ), f"EulerTourTree.{caller}(), {(p, v)} not in ptr_edge"
        assert (
            v * n + p in self.ptr_edge
        ), f"EulerTourTree.{caller}(), {(v, p)} not in ptr_edge"
        before, _ = self._split_right(self.ptr_edge[p * n + v])
        _, after = self._split_left(self.ptr_edge[v * n + p])
        v_node = self.ptr_vertex[v]
        self._splay(v_node)
        return before, v_node, after

    def subtree_apply(self, v: int, p: int, f: F) -> None:
        """Apply ``f`` to the subtree of ``v`` when ``v``'s parent is ``p``.

        Use ``p=-1`` when ``v`` has no parent (then the whole tree is hit).
        The edge nodes (p, v) and (v, p) of the segment are acted on as well.
        :math:`O(\\log{n})` amortized.

        Args:
            v (int): root of the subtree.
            p (int): parent of ``v``, or -1.
            f (F): action to apply.
        """
        if p == -1:
            v_node = self.ptr_vertex[v]
            self._splay(v_node)
            self._apply(v_node, f)
            return
        before, v_node, after = self._isolate_subtree(v, p, "subtree_apply")
        self._apply(v_node, f)
        self._propagate(v_node)
        self._merge(before, v_node)
        self._merge(v_node, after)

    def subtree_sum(self, v: int, p: int) -> T:
        """Return the aggregate of the subtree of ``v`` when ``v``'s parent is ``p``.

        Use ``p=-1`` when ``v`` has no parent (then the whole tree is summed).
        The edge nodes (p, v) and (v, p) are included in the aggregate.
        :math:`O(\\log{n})` amortized.

        Args:
            v (int): root of the subtree.
            p (int): parent of ``v``, or -1.
        """
        if p == -1:
            v_node = self.ptr_vertex[v]
            self._splay(v_node)
            return v_node.data
        before, v_node, after = self._isolate_subtree(v, p, "subtree_sum")
        res = v_node.data
        self._merge(before, v_node)
        self._merge(v_node, after)
        return res

    def group_count(self) -> int:
        """Return the number of connected components. :math:`O(1)`."""
        return self._group_numbers

    def get_vertex(self, v: int) -> T:
        """Return the value (``key``) of vertex ``v``. :math:`O(\\log{n})` amortized."""
        node = self.ptr_vertex[v]
        # Splaying pushes every ancestor's pending action onto the node.
        self._splay(node)
        return node.key

    def set_vertex(self, v: int, val: T) -> None:
        """Set the value (``key``) of vertex ``v`` to ``val``. :math:`O(\\log{n})` amortized."""
        node = self.ptr_vertex[v]
        self._splay(node)
        node.key = val
        self._update(node)

    def __getitem__(self, v: int) -> T:
        return self.get_vertex(v)

    def __setitem__(self, v: int, val: T) -> None:
        return self.set_vertex(v, val)