Source code for titan_pylib.data_structures.dynamic_connectivity.link_cut_tree

  1from array import array
  2
  3
[docs] 4class LinkCutTree: 5 """LinkCutTree です。""" 6 7 # - link / cut / merge / split 8 # - root / same 9 # - lca / path_length / path_kth_elm 10 # など 11 12 def __init__(self, n: int) -> None: 13 self.n = n 14 self.arr: array[int] = array("I", [self.n, self.n, self.n, 0] * (self.n + 1)) 15 # node.left : arr[node<<2|0] 16 # node.right : arr[node<<2|1] 17 # node.par : arr[node<<2|2] 18 # node.rev : arr[node<<2|3] 19 self.size: array[int] = array("I", [1] * (self.n + 1)) 20 self.size[-1] = 0 21 self.group_cnt = self.n 22 23 def _is_root(self, node: int) -> bool: 24 return (self.arr[node << 2 | 2] == self.n) or not ( 25 self.arr[self.arr[node << 2 | 2] << 2] == node 26 or self.arr[self.arr[node << 2 | 2] << 2 | 1] == node 27 ) 28 29 def _propagate(self, node: int) -> None: 30 if node == self.n: 31 return 32 arr = self.arr 33 if arr[node << 2 | 3]: 34 arr[node << 2 | 3] = 0 35 ln, rn = arr[node << 2], arr[node << 2 | 1] 36 arr[node << 2] = rn 37 arr[node << 2 | 1] = ln 38 arr[ln << 2 | 3] ^= 1 39 arr[rn << 2 | 3] ^= 1 40 41 def _update(self, node: int) -> None: 42 if node == self.n: 43 return 44 ln, rn = self.arr[node << 2], self.arr[node << 2 | 1] 45 self._propagate(ln) 46 self._propagate(rn) 47 self.size[node] = 1 + self.size[ln] + self.size[rn] 48 49 def _update_triple(self, x: int, y: int, z: int) -> None: 50 self._propagate(self.arr[x << 2]) 51 self._propagate(self.arr[x << 2 | 1]) 52 self._propagate(self.arr[y << 2]) 53 self._propagate(self.arr[y << 2 | 1]) 54 self.size[z] = self.size[x] 55 self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]] 56 self.size[y] = 1 + self.size[self.arr[y << 2]] + self.size[self.arr[y << 2 | 1]] 57 58 def _update_double(self, x: int, y: int) -> None: 59 self._propagate(self.arr[x << 2]) 60 self._propagate(self.arr[x << 2 | 1]) 61 self.size[y] = self.size[x] 62 self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]] 63 64 def _splay(self, node: int) -> None: 65 # splayを抜けた後、nodeは遅延伝播済みにするようにする 66 # (splay後のnodeのleft,rightにアクセスしやすいと非常にラクなはず) 67 if node == self.n: 68 return 69 _propagate, _is_root, _update_triple = ( 70 self._propagate, 71 self._is_root, 72 self._update_triple, 73 ) 74 _propagate(node) 75 if _is_root(node): 76 return 77 arr = self.arr 78 pnode = arr[node << 2 | 2] 79 while not _is_root(pnode): 80 gnode = arr[pnode << 2 | 2] 81 _propagate(gnode) 82 _propagate(pnode) 83 _propagate(node) 84 f = arr[pnode << 2] == node 85 g = (arr[gnode << 2 | f] == pnode) ^ (arr[pnode << 2 | f] == node) 86 nnode = (node if g else pnode) << 2 | f ^ g 87 arr[pnode << 2 | f ^ 1] = arr[node << 2 | f] 88 arr[gnode << 2 | f ^ g ^ 1] = arr[nnode] 89 arr[node << 2 | f] = pnode 90 arr[nnode] = gnode 91 arr[node << 2 | 2] = arr[gnode << 2 | 2] 92 arr[gnode << 2 | 2] = nnode >> 2 93 arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode 94 arr[arr[gnode << 2 | f ^ g ^ 1] << 2 | 2] = gnode 95 arr[pnode << 2 | 2] = node 96 _update_triple(gnode, pnode, node) 97 pnode = arr[node << 2 | 2] 98 if arr[pnode << 2] == gnode: 99 arr[pnode << 2] = node 100 elif arr[pnode << 2 | 1] == gnode: 101 arr[pnode << 2 | 1] = node 102 else: 103 return 104 _propagate(pnode) 105 _propagate(node) 106 f = arr[pnode << 2] == node 107 arr[pnode << 2 | f ^ 1] = arr[node << 2 | f] 108 arr[node << 2 | f] = pnode 109 arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode 110 arr[node << 2 | 2] = arr[pnode << 2 | 2] 111 arr[pnode << 2 | 2] = node 112 self._update_double(pnode, node) 113
[docs] 114 def expose(self, v: int) -> int: 115 """``v`` が属する木において、その木を管理しているsplay木の根から ``v`` までのパスを作ります。 116 償却 :math:`O(\\log{n})` です。 117 """ 118 arr, n, _splay, _update = self.arr, self.n, self._splay, self._update 119 pre = v 120 while arr[v << 2 | 2] != n: 121 _splay(v) 122 arr[v << 2 | 1] = n 123 _update(v) 124 if arr[v << 2 | 2] == n: 125 break 126 pre = arr[v << 2 | 2] 127 _splay(pre) 128 arr[pre << 2 | 1] = v 129 _update(pre) 130 arr[v << 2 | 1] = n 131 _update(v) 132 return pre
133
[docs] 134 def lca(self, u: int, v: int, root: int) -> int: 135 """``root`` を根としたときの、 ``u``, ``v`` の LCA を返します。 136 償却 :math:`O(\\log{n})` です。 137 """ 138 self.evert(root) 139 self.expose(u) 140 return self.expose(v)
141 156
[docs] 157 def cut(self, c: int) -> None: 158 """辺 ``{c -> cの親}`` を削除します。 159 償却 :math:`O(\\log{n})` です。 160 161 制約: 162 ``c`` は元の木の根であってはいけないです。 163 """ 164 arr = self.arr 165 self.expose(c) 166 assert arr[c << 2] != self.n 167 arr[arr[c << 2] << 2 | 2] = self.n 168 arr[c << 2] = self.n 169 self._update(c) 170 self.group_cnt += 1
171
[docs] 172 def group_count(self) -> int: 173 """連結成分数を返します。 174 :math:`O(1)` です。 175 """ 176 return self.group_cnt
177
[docs] 178 def root(self, v: int) -> int: 179 """``v`` が属する木の根を返します。 180 償却 :math:`O(\\log{n})` です。 181 """ 182 self.expose(v) 183 arr, n = self.arr, self.n 184 while arr[v << 2] != n: 185 v = arr[v << 2] 186 self._propagate(v) 187 self._splay(v) 188 return v
189
[docs] 190 def same(self, u: int, v: int) -> bool: 191 """連結判定です。 192 償却 :math:`O(\\log{n})` です。 193 194 Returns: 195 bool: ``u``, ``v`` が同じ連結成分であれば ``True`` を、そうでなければ ``False`` を返します。 196 """ 197 return self.root(u) == self.root(v)
198
[docs] 199 def evert(self, v: int) -> None: 200 """``v`` を根にします。 201 償却 :math:`O(\\log{n})` です。 202 """ 203 self.expose(v) 204 self.arr[v << 2 | 3] ^= 1 205 self._propagate(v)
206
[docs] 207 def merge(self, u: int, v: int) -> bool: 208 """``u``, ``v`` が同じ連結成分なら ``False`` を返します。 209 そうでなければ辺 ``{u -> v}`` を追加して ``True`` を返します。 210 償却 :math:`O(\\log{n})` です。 211 """ 212 if self.same(u, v): 213 return False 214 self.evert(u) 215 self.expose(v) 216 self.arr[u << 2 | 2] = v 217 self.arr[v << 2 | 1] = u 218 self._update(v) 219 self.group_cnt -= 1 220 return True
221
[docs] 222 def split(self, u: int, v: int) -> bool: 223 """辺 ``{u -> v}`` があれば削除し ``True`` を返します。 224 そうでなければ何もせず ``False`` を返します。 225 償却 :math:`O(\\log{n})` です。 226 """ 227 self.evert(u) 228 self.cut(v) 229 return True
230
[docs] 231 def path_length(self, u: int, v: int) -> int: 232 """``u`` から ``v`` へのパスに含まれる頂点の数を返します。 233 存在しないときは ``-1`` を返します。 234 償却 :math:`O(\\log{n})` です。 235 """ 236 if not self.same(u, v): 237 return -1 238 self.evert(u) 239 self.expose(v) 240 return self.size[v]
241
[docs] 242 def path_kth_elm(self, s: int, t: int, k: int) -> int: 243 """``u`` から ``v`` へ ``k`` 個進んだ頂点を返します。 244 存在しないときは ``-1`` を返します。 245 償却 :math:`O(\\log{n})` です。 246 """ 247 self.evert(s) 248 self.expose(t) 249 if self.size[t] <= k: 250 return -1 251 size, arr = self.size, self.arr 252 while True: 253 self._propagate(t) 254 s = size[arr[t << 2]] 255 if s == k: 256 self._splay(t) 257 return t 258 t = arr[t << 2 | (s < k)] 259 if s < k: 260 k -= s + 1
261 262 def __str__(self): 263 return f"{self.__class__.__name__}" 264 265 __repr__ = __str__