1from array import array
2
3
[docs]
4class LinkCutTree:
5 """LinkCutTree です。"""
6
7 # - link / cut / merge / split
8 # - root / same
9 # - lca / path_length / path_kth_elm
10 # など
11
12 def __init__(self, n: int) -> None:
13 self.n = n
14 self.arr: array[int] = array("I", [self.n, self.n, self.n, 0] * (self.n + 1))
15 # node.left : arr[node<<2|0]
16 # node.right : arr[node<<2|1]
17 # node.par : arr[node<<2|2]
18 # node.rev : arr[node<<2|3]
19 self.size: array[int] = array("I", [1] * (self.n + 1))
20 self.size[-1] = 0
21 self.group_cnt = self.n
22
23 def _is_root(self, node: int) -> bool:
24 return (self.arr[node << 2 | 2] == self.n) or not (
25 self.arr[self.arr[node << 2 | 2] << 2] == node
26 or self.arr[self.arr[node << 2 | 2] << 2 | 1] == node
27 )
28
29 def _propagate(self, node: int) -> None:
30 if node == self.n:
31 return
32 arr = self.arr
33 if arr[node << 2 | 3]:
34 arr[node << 2 | 3] = 0
35 ln, rn = arr[node << 2], arr[node << 2 | 1]
36 arr[node << 2] = rn
37 arr[node << 2 | 1] = ln
38 arr[ln << 2 | 3] ^= 1
39 arr[rn << 2 | 3] ^= 1
40
41 def _update(self, node: int) -> None:
42 if node == self.n:
43 return
44 ln, rn = self.arr[node << 2], self.arr[node << 2 | 1]
45 self._propagate(ln)
46 self._propagate(rn)
47 self.size[node] = 1 + self.size[ln] + self.size[rn]
48
49 def _update_triple(self, x: int, y: int, z: int) -> None:
50 self._propagate(self.arr[x << 2])
51 self._propagate(self.arr[x << 2 | 1])
52 self._propagate(self.arr[y << 2])
53 self._propagate(self.arr[y << 2 | 1])
54 self.size[z] = self.size[x]
55 self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]]
56 self.size[y] = 1 + self.size[self.arr[y << 2]] + self.size[self.arr[y << 2 | 1]]
57
58 def _update_double(self, x: int, y: int) -> None:
59 self._propagate(self.arr[x << 2])
60 self._propagate(self.arr[x << 2 | 1])
61 self.size[y] = self.size[x]
62 self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]]
63
64 def _splay(self, node: int) -> None:
65 # splayを抜けた後、nodeは遅延伝播済みにするようにする
66 # (splay後のnodeのleft,rightにアクセスしやすいと非常にラクなはず)
67 if node == self.n:
68 return
69 _propagate, _is_root, _update_triple = (
70 self._propagate,
71 self._is_root,
72 self._update_triple,
73 )
74 _propagate(node)
75 if _is_root(node):
76 return
77 arr = self.arr
78 pnode = arr[node << 2 | 2]
79 while not _is_root(pnode):
80 gnode = arr[pnode << 2 | 2]
81 _propagate(gnode)
82 _propagate(pnode)
83 _propagate(node)
84 f = arr[pnode << 2] == node
85 g = (arr[gnode << 2 | f] == pnode) ^ (arr[pnode << 2 | f] == node)
86 nnode = (node if g else pnode) << 2 | f ^ g
87 arr[pnode << 2 | f ^ 1] = arr[node << 2 | f]
88 arr[gnode << 2 | f ^ g ^ 1] = arr[nnode]
89 arr[node << 2 | f] = pnode
90 arr[nnode] = gnode
91 arr[node << 2 | 2] = arr[gnode << 2 | 2]
92 arr[gnode << 2 | 2] = nnode >> 2
93 arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode
94 arr[arr[gnode << 2 | f ^ g ^ 1] << 2 | 2] = gnode
95 arr[pnode << 2 | 2] = node
96 _update_triple(gnode, pnode, node)
97 pnode = arr[node << 2 | 2]
98 if arr[pnode << 2] == gnode:
99 arr[pnode << 2] = node
100 elif arr[pnode << 2 | 1] == gnode:
101 arr[pnode << 2 | 1] = node
102 else:
103 return
104 _propagate(pnode)
105 _propagate(node)
106 f = arr[pnode << 2] == node
107 arr[pnode << 2 | f ^ 1] = arr[node << 2 | f]
108 arr[node << 2 | f] = pnode
109 arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode
110 arr[node << 2 | 2] = arr[pnode << 2 | 2]
111 arr[pnode << 2 | 2] = node
112 self._update_double(pnode, node)
113
[docs]
114 def expose(self, v: int) -> int:
115 """``v`` が属する木において、その木を管理しているsplay木の根から ``v`` までのパスを作ります。
116 償却 :math:`O(\\log{n})` です。
117 """
118 arr, n, _splay, _update = self.arr, self.n, self._splay, self._update
119 pre = v
120 while arr[v << 2 | 2] != n:
121 _splay(v)
122 arr[v << 2 | 1] = n
123 _update(v)
124 if arr[v << 2 | 2] == n:
125 break
126 pre = arr[v << 2 | 2]
127 _splay(pre)
128 arr[pre << 2 | 1] = v
129 _update(pre)
130 arr[v << 2 | 1] = n
131 _update(v)
132 return pre
133
[docs]
134 def lca(self, u: int, v: int, root: int) -> int:
135 """``root`` を根としたときの、 ``u``, ``v`` の LCA を返します。
136 償却 :math:`O(\\log{n})` です。
137 """
138 self.evert(root)
139 self.expose(u)
140 return self.expose(v)
141
[docs]
142 def link(self, c: int, p: int) -> None:
143 """辺 ``(c -> p)`` を追加します。
144 償却 :math:`O(\\log{n})` です。
145
146 制約:
147 ``c`` は元の木の根でなければならないです。
148 """
149 assert not self.same(c, p)
150 self.expose(c)
151 self.expose(p)
152 self.arr[c << 2 | 2] = p
153 self.arr[p << 2 | 1] = c
154 self._update(p)
155 self.group_cnt -= 1
156
[docs]
157 def cut(self, c: int) -> None:
158 """辺 ``{c -> cの親}`` を削除します。
159 償却 :math:`O(\\log{n})` です。
160
161 制約:
162 ``c`` は元の木の根であってはいけないです。
163 """
164 arr = self.arr
165 self.expose(c)
166 assert arr[c << 2] != self.n
167 arr[arr[c << 2] << 2 | 2] = self.n
168 arr[c << 2] = self.n
169 self._update(c)
170 self.group_cnt += 1
171
[docs]
172 def group_count(self) -> int:
173 """連結成分数を返します。
174 :math:`O(1)` です。
175 """
176 return self.group_cnt
177
[docs]
178 def root(self, v: int) -> int:
179 """``v`` が属する木の根を返します。
180 償却 :math:`O(\\log{n})` です。
181 """
182 self.expose(v)
183 arr, n = self.arr, self.n
184 while arr[v << 2] != n:
185 v = arr[v << 2]
186 self._propagate(v)
187 self._splay(v)
188 return v
189
[docs]
190 def same(self, u: int, v: int) -> bool:
191 """連結判定です。
192 償却 :math:`O(\\log{n})` です。
193
194 Returns:
195 bool: ``u``, ``v`` が同じ連結成分であれば ``True`` を、そうでなければ ``False`` を返します。
196 """
197 return self.root(u) == self.root(v)
198
[docs]
199 def evert(self, v: int) -> None:
200 """``v`` を根にします。
201 償却 :math:`O(\\log{n})` です。
202 """
203 self.expose(v)
204 self.arr[v << 2 | 3] ^= 1
205 self._propagate(v)
206
[docs]
207 def merge(self, u: int, v: int) -> bool:
208 """``u``, ``v`` が同じ連結成分なら ``False`` を返します。
209 そうでなければ辺 ``{u -> v}`` を追加して ``True`` を返します。
210 償却 :math:`O(\\log{n})` です。
211 """
212 if self.same(u, v):
213 return False
214 self.evert(u)
215 self.expose(v)
216 self.arr[u << 2 | 2] = v
217 self.arr[v << 2 | 1] = u
218 self._update(v)
219 self.group_cnt -= 1
220 return True
221
[docs]
222 def split(self, u: int, v: int) -> bool:
223 """辺 ``{u -> v}`` があれば削除し ``True`` を返します。
224 そうでなければ何もせず ``False`` を返します。
225 償却 :math:`O(\\log{n})` です。
226 """
227 self.evert(u)
228 self.cut(v)
229 return True
230
[docs]
231 def path_length(self, u: int, v: int) -> int:
232 """``u`` から ``v`` へのパスに含まれる頂点の数を返します。
233 存在しないときは ``-1`` を返します。
234 償却 :math:`O(\\log{n})` です。
235 """
236 if not self.same(u, v):
237 return -1
238 self.evert(u)
239 self.expose(v)
240 return self.size[v]
241
[docs]
242 def path_kth_elm(self, s: int, t: int, k: int) -> int:
243 """``u`` から ``v`` へ ``k`` 個進んだ頂点を返します。
244 存在しないときは ``-1`` を返します。
245 償却 :math:`O(\\log{n})` です。
246 """
247 self.evert(s)
248 self.expose(t)
249 if self.size[t] <= k:
250 return -1
251 size, arr = self.size, self.arr
252 while True:
253 self._propagate(t)
254 s = size[arr[t << 2]]
255 if s == k:
256 self._splay(t)
257 return t
258 t = arr[t << 2 | (s < k)]
259 if s < k:
260 k -= s + 1
261
262 def __str__(self):
263 return f"{self.__class__.__name__}"
264
265 __repr__ = __str__