Source code for titan_pylib.data_structures.dynamic_connectivity.euler_tour_tree

from functools import wraps
from types import GeneratorType
from typing import Callable, Generator, Generic, Iterable, Optional, TypeVar, Union
  3
T = TypeVar("T")  # type of vertex values and of their ``op`` aggregates
F = TypeVar("F")  # type of the lazy operators applied to values via ``mapping``
  6
  7
class EulerTourTree(Generic[T, F]):
    """Dynamic forest whose trees are stored as Euler tours in splay trees.

    Every vertex carries a value of type ``T``.  Values are aggregated with
    ``op`` and can be updated lazily with operators of type ``F``.  Each tree
    is kept as its Euler tour (one entry per vertex and one per directed edge)
    inside a splay tree.  This supports connectivity queries, edge cuts,
    rerooting, and subtree aggregation / update.
    """

    class _Node:
        """One entry of an Euler tour: either a vertex or a directed edge."""

        # A node is allocated for every vertex and for both directions of every
        # edge, so drop the per-instance ``__dict__``.
        __slots__ = ("key", "data", "lazy", "par", "left", "right")

        def __init__(self, key: T, lazy: F):
            self.key: T = key  # value stored at this entry
            self.data: T = key  # ``op``-aggregate over this node's splay subtree
            self.lazy: F = lazy  # operator still to be pushed to the children
            self.par: Optional[EulerTourTree._Node] = None
            self.left: Optional[EulerTourTree._Node] = None
            self.right: Optional[EulerTourTree._Node] = None

        def __str__(self):
            # Debug dump: (key, data, lazy, parent key), followed by both subtrees.
            if self.left is None and self.right is None:
                return f"(key,par):{self.key,self.data,self.lazy,(self.par.key if self.par else None)}\n"
            return f"(key,par):{self.key,self.data,self.lazy,(self.par.key if self.par else None)},\n left:{self.left},\n right:{self.right}\n"

        __repr__ = __str__

    def __init__(
        self,
        n_or_a: Union[int, Iterable[T]],
        op: Callable[[T, T], T],
        mapping: Callable[[F, T], T],
        composition: Callable[[F, F], F],
        e: T,
        id: F,
    ) -> None:
        """Create a forest of isolated vertices (no edges yet).

        Args:
            n_or_a: Either the number of vertices, all initialised to ``e``,
                or an iterable of initial vertex values.
            op: Binary operation used to aggregate values.
            mapping: ``mapping(f, x)`` applies operator ``f`` to value ``x``.
            composition: ``composition(f, g)`` combines a newly applied
                operator ``f`` with an operator ``g`` that is still pending.
            e: Value given to edge entries (and to vertices when ``n_or_a``
                is an int); presumably the identity element of ``op``.
            id: Identity operator.  Nodes whose pending operator equals ``id``
                are skipped when operators are pushed down.
        """
        self.op = op
        self.mapping = mapping
        self.composition = composition
        self.e = e
        self.id = id
        a = [e] * n_or_a if isinstance(n_or_a, int) else list(n_or_a)
        self.n: int = len(a)
        self.ptr_vertex: list[EulerTourTree._Node] = [
            EulerTourTree._Node(elem, id) for elem in a
        ]
        # The node of directed edge (u, v) is stored under the integer key u * n + v.
        self.ptr_edge: dict[int, EulerTourTree._Node] = {}
        # Number of connected components; every vertex starts isolated.
        self._group_numbers: int = self.n
[docs] 49 @staticmethod 50 def antirec(func, stack=[]): 51 # 参考: https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py 52 def wrappedfunc(*args, **kwargs): 53 if stack: 54 return func(*args, **kwargs) 55 to = func(*args, **kwargs) 56 while True: 57 if isinstance(to, GeneratorType): 58 stack.append(to) 59 to = next(to) 60 else: 61 stack.pop() 62 if not stack: 63 break 64 to = stack[-1].send(to) 65 return to 66 67 return wrappedfunc
68
[docs] 69 def build(self, G: list[list[int]]) -> None: 70 """隣接リスト ``G`` をもとにして、辺を張ります。 71 :math:`O(n)` です。 72 73 Args: 74 G (list[list[int]]): 隣接リストです。 75 76 Note: 77 ``build`` メソッドを使用する場合は他のメソッドより前に使用しなければなりません。 78 """ 79 n, ptr_vertex, ptr_edge, e, id = ( 80 self.n, 81 self.ptr_vertex, 82 self.ptr_edge, 83 self.e, 84 self.id, 85 ) 86 seen = [0] * n 87 _Node = EulerTourTree._Node 88 89 @EulerTourTree.antirec 90 def dfs(v: int, p: int = -1) -> Generator: 91 a.append(v * n + v) 92 for x in G[v]: 93 if x == p: 94 continue 95 a.append(v * n + x) 96 yield dfs(x, v) 97 a.append(x * n + v) 98 yield 99 100 @EulerTourTree.antirec 101 def rec(l: int, r: int) -> Generator: 102 mid = (l + r) >> 1 103 u, v = divmod(a[mid], n) 104 node = ptr_vertex[u] if u == v else _Node(e, id) 105 if u == v: 106 seen[u] = 1 107 else: 108 ptr_edge[u * n + v] = node 109 if l != mid: 110 node.left = yield rec(l, mid) 111 node.left.par = node 112 if mid + 1 != r: 113 node.right = yield rec(mid + 1, r) 114 node.right.par = node 115 self._update(node) 116 yield node 117 118 for root in range(self.n): 119 if seen[root]: 120 continue 121 a: list[int] = [] 122 dfs(root) 123 rec(0, len(a))
124 125 def _popleft(self, v: _Node) -> Optional[_Node]: 126 v = self._left_splay(v) 127 if v.right: 128 v.right.par = None 129 return v.right 130 131 def _pop(self, v: _Node) -> Optional[_Node]: 132 v = self._right_splay(v) 133 if v.left: 134 v.left.par = None 135 return v.left 136 137 def _split_left(self, v: _Node) -> tuple[_Node, Optional[_Node]]: 138 # x, yに分割する。ただし、xはvを含む 139 self._splay(v) 140 x, y = v, v.right 141 if y: 142 y.par = None 143 x.right = None 144 self._update(x) 145 return x, y 146 147 def _split_right(self, v: _Node) -> tuple[Optional[_Node], _Node]: 148 # x, yに分割する。ただし、yはvを含む 149 self._splay(v) 150 x, y = v.left, v 151 if x: 152 x.par = None 153 y.left = None 154 self._update(y) 155 return x, y 156 157 def _merge(self, u: Optional[_Node], v: Optional[_Node]) -> None: 158 if u is None or v is None: 159 return 160 u = self._right_splay(u) 161 self._splay(v) 162 u.right = v 163 v.par = u 164 self._update(u) 165 166 def _splay(self, node: _Node) -> None: 167 self._propagate(node) 168 while node.par is not None and node.par.par is not None: 169 pnode = node.par 170 gnode = pnode.par 171 self._propagate(gnode) 172 self._propagate(pnode) 173 self._propagate(node) 174 node.par = gnode.par 175 if (gnode.left is pnode) == (pnode.left is node): 176 if pnode.left is node: 177 tmp1 = node.right 178 pnode.left = tmp1 179 node.right = pnode 180 pnode.par = node 181 tmp2 = pnode.right 182 gnode.left = tmp2 183 pnode.right = gnode 184 gnode.par = pnode 185 else: 186 tmp1 = node.left 187 pnode.right = tmp1 188 node.left = pnode 189 pnode.par = node 190 tmp2 = pnode.left 191 gnode.right = tmp2 192 pnode.left = gnode 193 gnode.par = pnode 194 if tmp1: 195 tmp1.par = pnode 196 if tmp2: 197 tmp2.par = gnode 198 else: 199 if pnode.left is node: 200 tmp1 = node.right 201 pnode.left = tmp1 202 node.right = pnode 203 tmp2 = node.left 204 gnode.right = tmp2 205 node.left = gnode 206 pnode.par = node 207 gnode.par = node 208 else: 209 tmp1 = node.left 210 pnode.right = 
tmp1 211 node.left = pnode 212 tmp2 = node.right 213 gnode.left = tmp2 214 node.right = gnode 215 pnode.par = node 216 gnode.par = node 217 if tmp1: 218 tmp1.par = pnode 219 if tmp2: 220 tmp2.par = gnode 221 self._update(gnode) 222 self._update(pnode) 223 self._update(node) 224 if node.par is None: 225 return 226 if node.par.left is gnode: 227 node.par.left = node 228 else: 229 node.par.right = node 230 if node.par is None: 231 return 232 pnode = node.par 233 self._propagate(pnode) 234 self._propagate(node) 235 if pnode.left is node: 236 pnode.left = node.right 237 if pnode.left: 238 pnode.left.par = pnode 239 node.right = pnode 240 else: 241 pnode.right = node.left 242 if pnode.right: 243 pnode.right.par = pnode 244 node.left = pnode 245 node.par = None 246 pnode.par = node 247 self._update(pnode) 248 self._update(node) 249 250 def _left_splay(self, node: _Node) -> _Node: 251 self._splay(node) 252 while node.left is not None: 253 node = node.left 254 self._splay(node) 255 return node 256 257 def _right_splay(self, node: _Node) -> _Node: 258 self._splay(node) 259 while node.right is not None: 260 node = node.right 261 self._splay(node) 262 return node 263 264 def _propagate(self, node: Optional[_Node]) -> None: 265 if node is None or node.lazy == self.id: 266 return 267 if node.left: 268 node.left.key = self.mapping(node.lazy, node.left.key) 269 node.left.data = self.mapping(node.lazy, node.left.data) 270 node.left.lazy = self.composition(node.lazy, node.left.lazy) 271 if node.right: 272 node.right.key = self.mapping(node.lazy, node.right.key) 273 node.right.data = self.mapping(node.lazy, node.right.data) 274 node.right.lazy = self.composition(node.lazy, node.right.lazy) 275 node.lazy = self.id 276 277 def _update(self, node: _Node) -> None: 278 self._propagate(node.left) 279 self._propagate(node.right) 280 node.data = node.key 281 if node.left: 282 node.data = self.op(node.left.data, node.data) 283 if node.right: 284 node.data = self.op(node.data, node.right.data) 
285 312
[docs] 313 def cut(self, u: int, v: int) -> None: 314 """辺 ``{u, v}`` を削除します。 315 :math:`O(\\log{n})` です。 316 317 Note: 318 辺 ``{u, v}`` が存在してなければいけません。 319 """ 320 # erace edge{u, v} 321 self.reroot(v) 322 self.reroot(u) 323 assert ( 324 u * self.n + v in self.ptr_edge 325 ), f"EulerTourTree.cut(), {(u, v)} not in ptr_edge" 326 assert ( 327 v * self.n + u in self.ptr_edge 328 ), f"EulerTourTree.cut(), {(v, u)} not in ptr_edge" 329 uv_node = self.ptr_edge.pop(u * self.n + v) 330 vu_node = self.ptr_edge.pop(v * self.n + u) 331 a, _ = self._split_left(uv_node) 332 _, c = self._split_right(vu_node) 333 a = self._pop(a) 334 c = self._popleft(c) 335 self._merge(a, c) 336 self._group_numbers += 1
337
[docs] 338 def leader(self, v: int) -> _Node: 339 """頂点 ``v`` を含む木の代表元を返します。 340 :math:`O(\\log{n})` です。 341 342 Note: 343 ``reroot`` すると変わるので注意です。 344 """ 345 # vを含む木の代表元 346 # rerootすると変わるので注意 347 return self._left_splay(self.ptr_vertex[v])
348
[docs] 349 def reroot(self, v: int) -> None: 350 """頂点 ``v`` を含む木の根を ``v`` にします。 351 352 :math:`O(\\log{n})` です。 353 """ 354 node = self.ptr_vertex[v] 355 x, y = self._split_right(node) 356 self._merge(y, x) 357 self._splay(node)
358
[docs] 359 def same(self, u: int, v: int) -> bool: 360 """ 361 頂点 ``u`` と ``v`` が同じ連結成分にいれば ``True`` を、 362 そうでなければ ``False`` を返します。 363 364 :math:`O(\\log{n})` です。 365 """ 366 u_node = self.ptr_vertex[u] 367 v_node = self.ptr_vertex[v] 368 self._splay(u_node) 369 self._splay(v_node) 370 return u_node.par is not None or u_node is v_node
371 372 def _show(self) -> None: 373 # for debug 374 print("+++++++++++++++++++++++++++") 375 for i, v in enumerate(self.ptr_vertex): 376 print((i, i), v, end="\n\n") 377 for k, v in self.ptr_edge.items(): 378 print(k, v, end="\n\n") 379 print("+++++++++++++++++++++++++++") 380
[docs] 381 def subtree_apply(self, v: int, p: int, f: F) -> None: 382 """頂点 ``v`` を根としたときの部分木に ``f`` を作用します。 383 384 ``v`` の親は ``p`` です。 385 ``v`` の親が存在しないときは ``p=-1`` として下さい。 386 387 :math:`O(\\log{n})` です。 388 389 Args: 390 v (int): 根です。 391 p (int): ``v`` の親です。 392 f (F): 作用素です。 393 """ 394 if p == -1: 395 v_node = self.ptr_vertex[v] 396 self._splay(v_node) 397 v_node.key = self.mapping(f, v_node.key) 398 v_node.data = self.mapping(f, v_node.data) 399 v_node.lazy = self.composition(f, v_node.lazy) 400 return 401 self.reroot(v) 402 self.reroot(p) 403 assert ( 404 p * self.n + v in self.ptr_edge 405 ), f"EulerTourTree.subtree_apply(), {(p, v)} not in ptr_edge" 406 assert ( 407 v * self.n + p in self.ptr_edge 408 ), f"EulerTourTree.subtree_apply(), {(v, p)} not in ptr_edge" 409 v_node = self.ptr_vertex[v] 410 a, b = self._split_right(self.ptr_edge[p * self.n + v]) 411 b, d = self._split_left(self.ptr_edge[v * self.n + p]) 412 self._splay(v_node) 413 v_node.key = self.mapping(f, v_node.key) 414 v_node.data = self.mapping(f, v_node.data) 415 v_node.lazy = self.composition(f, v_node.lazy) 416 self._propagate(v_node) 417 self._merge(a, b) 418 self._merge(b, d)
419
[docs] 420 def subtree_sum(self, v: int, p: int) -> T: 421 """頂点 ``v`` を根としたときの部分木の総和を返します。 422 423 ``v`` の親は ``p`` です。 424 ``v`` の親が存在しないときは ``p=-1`` として下さい。 425 426 :math:`O(\\log{n})` です。 427 428 Args: 429 v (int): 根です。 430 p (int): ``v`` の親です。 431 """ 432 if p == -1: 433 v_node = self.ptr_vertex[v] 434 self._splay(v_node) 435 return v_node.data 436 self.reroot(v) 437 self.reroot(p) 438 assert ( 439 p * self.n + v in self.ptr_edge 440 ), f"EulerTourTree.subtree_sum(), {(p, v)} not in ptr_edge" 441 assert ( 442 v * self.n + p in self.ptr_edge 443 ), f"EulerTourTree.subtree_sum(), {(v, p)} not in ptr_edge" 444 v_node = self.ptr_vertex[v] 445 a, b = self._split_right(self.ptr_edge[p * self.n + v]) 446 b, d = self._split_left(self.ptr_edge[v * self.n + p]) 447 self._splay(v_node) 448 res = v_node.data 449 self._merge(a, b) 450 self._merge(b, d) 451 return res
452
[docs] 453 def group_count(self) -> int: 454 """連結成分の個数を返します。 455 :math:`O(1)` です。 456 """ 457 return self._group_numbers
458
[docs] 459 def get_vertex(self, v: int) -> T: 460 """頂点 ``v`` の ``key`` を返します。 461 :math:`O(\\log{n})` です。 462 """ 463 node = self.ptr_vertex[v] 464 self._splay(node) 465 return node.key
466
[docs] 467 def set_vertex(self, v: int, val: T) -> None: 468 """頂点 ``v`` の ``key`` を ``val`` に更新します。 469 :math:`O(\\log{n})` です。 470 """ 471 node = self.ptr_vertex[v] 472 self._splay(node) 473 node.key = val 474 self._update(node)
475 476 def __getitem__(self, v: int) -> T: 477 return self.get_vertex(v) 478 479 def __setitem__(self, v: int, val: T) -> None: 480 return self.set_vertex(v, val)