fully_retroactive_union_find

ソースコード

from titan_pylib.data_structures.union_find.fully_retroactive_union_find import FullyRetroactiveUnionFind

view on github

展開済みコード

  1# from titan_pylib.data_structures.union_find.fully_retroactive_union_find import FullyRetroactiveUnionFind
  2# from titan_pylib.data_structures.dynamic_connectivity.link_cut_tree import LinkCutTree
  3from array import array
  4
  5
  6class LinkCutTree:
  7    """LinkCutTree です。"""
  8
  9    # - link / cut / merge / split
 10    # - root / same
 11    # - lca / path_length / path_kth_elm
 12    # など
 13
 14    def __init__(self, n: int) -> None:
 15        self.n = n
 16        self.arr: array[int] = array("I", [self.n, self.n, self.n, 0] * (self.n + 1))
 17        # node.left  : arr[node<<2|0]
 18        # node.right : arr[node<<2|1]
 19        # node.par   : arr[node<<2|2]
 20        # node.rev   : arr[node<<2|3]
 21        self.size: array[int] = array("I", [1] * (self.n + 1))
 22        self.size[-1] = 0
 23        self.group_cnt = self.n
 24
 25    def _is_root(self, node: int) -> bool:
 26        return (self.arr[node << 2 | 2] == self.n) or not (
 27            self.arr[self.arr[node << 2 | 2] << 2] == node
 28            or self.arr[self.arr[node << 2 | 2] << 2 | 1] == node
 29        )
 30
 31    def _propagate(self, node: int) -> None:
 32        if node == self.n:
 33            return
 34        arr = self.arr
 35        if arr[node << 2 | 3]:
 36            arr[node << 2 | 3] = 0
 37            ln, rn = arr[node << 2], arr[node << 2 | 1]
 38            arr[node << 2] = rn
 39            arr[node << 2 | 1] = ln
 40            arr[ln << 2 | 3] ^= 1
 41            arr[rn << 2 | 3] ^= 1
 42
 43    def _update(self, node: int) -> None:
 44        if node == self.n:
 45            return
 46        ln, rn = self.arr[node << 2], self.arr[node << 2 | 1]
 47        self._propagate(ln)
 48        self._propagate(rn)
 49        self.size[node] = 1 + self.size[ln] + self.size[rn]
 50
 51    def _update_triple(self, x: int, y: int, z: int) -> None:
 52        self._propagate(self.arr[x << 2])
 53        self._propagate(self.arr[x << 2 | 1])
 54        self._propagate(self.arr[y << 2])
 55        self._propagate(self.arr[y << 2 | 1])
 56        self.size[z] = self.size[x]
 57        self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]]
 58        self.size[y] = 1 + self.size[self.arr[y << 2]] + self.size[self.arr[y << 2 | 1]]
 59
 60    def _update_double(self, x: int, y: int) -> None:
 61        self._propagate(self.arr[x << 2])
 62        self._propagate(self.arr[x << 2 | 1])
 63        self.size[y] = self.size[x]
 64        self.size[x] = 1 + self.size[self.arr[x << 2]] + self.size[self.arr[x << 2 | 1]]
 65
 66    def _splay(self, node: int) -> None:
 67        # splayを抜けた後、nodeは遅延伝播済みにするようにする
 68        # (splay後のnodeのleft,rightにアクセスしやすいと非常にラクなはず)
 69        if node == self.n:
 70            return
 71        _propagate, _is_root, _update_triple = (
 72            self._propagate,
 73            self._is_root,
 74            self._update_triple,
 75        )
 76        _propagate(node)
 77        if _is_root(node):
 78            return
 79        arr = self.arr
 80        pnode = arr[node << 2 | 2]
 81        while not _is_root(pnode):
 82            gnode = arr[pnode << 2 | 2]
 83            _propagate(gnode)
 84            _propagate(pnode)
 85            _propagate(node)
 86            f = arr[pnode << 2] == node
 87            g = (arr[gnode << 2 | f] == pnode) ^ (arr[pnode << 2 | f] == node)
 88            nnode = (node if g else pnode) << 2 | f ^ g
 89            arr[pnode << 2 | f ^ 1] = arr[node << 2 | f]
 90            arr[gnode << 2 | f ^ g ^ 1] = arr[nnode]
 91            arr[node << 2 | f] = pnode
 92            arr[nnode] = gnode
 93            arr[node << 2 | 2] = arr[gnode << 2 | 2]
 94            arr[gnode << 2 | 2] = nnode >> 2
 95            arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode
 96            arr[arr[gnode << 2 | f ^ g ^ 1] << 2 | 2] = gnode
 97            arr[pnode << 2 | 2] = node
 98            _update_triple(gnode, pnode, node)
 99            pnode = arr[node << 2 | 2]
100            if arr[pnode << 2] == gnode:
101                arr[pnode << 2] = node
102            elif arr[pnode << 2 | 1] == gnode:
103                arr[pnode << 2 | 1] = node
104            else:
105                return
106        _propagate(pnode)
107        _propagate(node)
108        f = arr[pnode << 2] == node
109        arr[pnode << 2 | f ^ 1] = arr[node << 2 | f]
110        arr[node << 2 | f] = pnode
111        arr[arr[pnode << 2 | f ^ 1] << 2 | 2] = pnode
112        arr[node << 2 | 2] = arr[pnode << 2 | 2]
113        arr[pnode << 2 | 2] = node
114        self._update_double(pnode, node)
115
116    def expose(self, v: int) -> int:
117        """``v`` が属する木において、その木を管理しているsplay木の根から ``v`` までのパスを作ります。
118        償却 :math:`O(\\log{n})` です。
119        """
120        arr, n, _splay, _update = self.arr, self.n, self._splay, self._update
121        pre = v
122        while arr[v << 2 | 2] != n:
123            _splay(v)
124            arr[v << 2 | 1] = n
125            _update(v)
126            if arr[v << 2 | 2] == n:
127                break
128            pre = arr[v << 2 | 2]
129            _splay(pre)
130            arr[pre << 2 | 1] = v
131            _update(pre)
132        arr[v << 2 | 1] = n
133        _update(v)
134        return pre
135
136    def lca(self, u: int, v: int, root: int) -> int:
137        """``root`` を根としたときの、 ``u``, ``v`` の LCA を返します。
138        償却 :math:`O(\\log{n})` です。
139        """
140        self.evert(root)
141        self.expose(u)
142        return self.expose(v)
143
144    def link(self, c: int, p: int) -> None:
145        """辺 ``(c -> p)`` を追加します。
146        償却 :math:`O(\\log{n})` です。
147
148        制約:
149          ``c`` は元の木の根でなければならないです。
150        """
151        assert not self.same(c, p)
152        self.expose(c)
153        self.expose(p)
154        self.arr[c << 2 | 2] = p
155        self.arr[p << 2 | 1] = c
156        self._update(p)
157        self.group_cnt -= 1
158
159    def cut(self, c: int) -> None:
160        """辺 ``{c -> cの親}`` を削除します。
161        償却 :math:`O(\\log{n})` です。
162
163        制約:
164          ``c`` は元の木の根であってはいけないです。
165        """
166        arr = self.arr
167        self.expose(c)
168        assert arr[c << 2] != self.n
169        arr[arr[c << 2] << 2 | 2] = self.n
170        arr[c << 2] = self.n
171        self._update(c)
172        self.group_cnt += 1
173
174    def group_count(self) -> int:
175        """連結成分数を返します。
176        :math:`O(1)` です。
177        """
178        return self.group_cnt
179
180    def root(self, v: int) -> int:
181        """``v`` が属する木の根を返します。
182        償却 :math:`O(\\log{n})` です。
183        """
184        self.expose(v)
185        arr, n = self.arr, self.n
186        while arr[v << 2] != n:
187            v = arr[v << 2]
188            self._propagate(v)
189        self._splay(v)
190        return v
191
192    def same(self, u: int, v: int) -> bool:
193        """連結判定です。
194        償却 :math:`O(\\log{n})` です。
195
196        Returns:
197          bool: ``u``, ``v`` が同じ連結成分であれば ``True`` を、そうでなければ ``False`` を返します。
198        """
199        return self.root(u) == self.root(v)
200
201    def evert(self, v: int) -> None:
202        """``v`` を根にします。
203        償却 :math:`O(\\log{n})` です。
204        """
205        self.expose(v)
206        self.arr[v << 2 | 3] ^= 1
207        self._propagate(v)
208
209    def merge(self, u: int, v: int) -> bool:
210        """``u``, ``v`` が同じ連結成分なら ``False`` を返します。
211        そうでなければ辺 ``{u -> v}`` を追加して ``True`` を返します。
212        償却 :math:`O(\\log{n})` です。
213        """
214        if self.same(u, v):
215            return False
216        self.evert(u)
217        self.expose(v)
218        self.arr[u << 2 | 2] = v
219        self.arr[v << 2 | 1] = u
220        self._update(v)
221        self.group_cnt -= 1
222        return True
223
224    def split(self, u: int, v: int) -> bool:
225        """辺 ``{u -> v}`` があれば削除し ``True`` を返します。
226        そうでなければ何もせず ``False`` を返します。
227        償却 :math:`O(\\log{n})` です。
228        """
229        self.evert(u)
230        self.cut(v)
231        return True
232
233    def path_length(self, u: int, v: int) -> int:
234        """``u`` から ``v`` へのパスに含まれる頂点の数を返します。
235        存在しないときは ``-1`` を返します。
236        償却 :math:`O(\\log{n})` です。
237        """
238        if not self.same(u, v):
239            return -1
240        self.evert(u)
241        self.expose(v)
242        return self.size[v]
243
244    def path_kth_elm(self, s: int, t: int, k: int) -> int:
245        """``u`` から ``v`` へ ``k`` 個進んだ頂点を返します。
246        存在しないときは ``-1`` を返します。
247        償却 :math:`O(\\log{n})` です。
248        """
249        self.evert(s)
250        self.expose(t)
251        if self.size[t] <= k:
252            return -1
253        size, arr = self.size, self.arr
254        while True:
255            self._propagate(t)
256            s = size[arr[t << 2]]
257            if s == k:
258                self._splay(t)
259                return t
260            t = arr[t << 2 | (s < k)]
261            if s < k:
262                k -= s + 1
263
264    def __str__(self):
265        return f"{self.__class__.__name__}"
266
267    __repr__ = __str__
268
269
270class FullyRetroactiveUnionFind:
271
272    def __init__(self, n: int, m: int) -> None:
273        """頂点数 ``n`` 、クエリ列の長さ ``m`` の ``FullyRetroactiveUnionFind`` を作ります。
274
275        ここで、クエリは `unite` のみです。
276
277        :math:`O(n+m)` です。
278
279        Args:
280          n (int): 頂点数です。
281          m (int): クエリ列の長さです。
282        """
283        m += 1
284        self.n: int = n
285        self.edge: list[tuple[int, int, int]] = [()] * m
286        self.node_pool: set[int] = set(range(n, n + m))
287        self.lct: LinkCutTree[int, None] = LinkCutTree(
288            n + m,
289            op=lambda s, t: s if s > t else t,
290            mapping=lambda f, s: -1,
291            composition=lambda f, g: None,
292            e=-1,
293            id=None,
294        )
295
296    def unite(self, u: int, v: int, t: int) -> None:
297        """時刻 ``t`` のクエリを ``unite(u, v)`` にします。
298
299        償却 :math:`O(\\log{(n+m)})` です。
300
301        Args:
302          u (int): 集合の要素です。
303          v (int): 集合の要素です。
304          t (int): 時刻です。
305
306        Note:
307          ``disconnect`` を使用する場合、 ``u``, ``v`` が連結されていてはいけません。
308        """
309        node = self.node_pool.pop()
310        self.edge[t] = (u, v, node)
311        self.lct[node] = t
312        self.lct.merge(u, node)
313        self.lct.merge(node, v)
314
315    def disconnect(self, t: int) -> None:
316        """時刻 ``t`` の連結クエリをなくして、そのクエリの2頂点を非連結にします。
317
318        償却 :math:`O(\\log{(n+m)})` です。
319
320        Args:
321          t (int): 時刻です。
322
323        Note:
324          時刻 ``t`` のクエリは連結クエリでないといけません。
325        """
326        assert self.edge[t] is not None
327        u, v, node = self.edge[t]
328        self.node_pool.add(node)
329        self.edge[t] = None
330        self.lct.split(u, node)
331        self.lct.split(node, v)
332
333    def same(self, u: int, v: int, t: int) -> bool:
334        """時刻 ``t`` で ``u``, ``v`` の連結判定をします。
335
336        償却 :math:`O(\\log{(n+m)})` です。
337
338        Args:
339          u (int): 集合の要素です。
340          v (int): 集合の要素です。
341          t (int): 時刻です。
342
343        Returns:
344          bool:
345        """
346        if not self.lct.same(u, v):
347            return False
348        return self.lct.path_prod(u, v) <= t

仕様

class FullyRetroactiveUnionFind(n: int, m: int)[source]

Bases: object

disconnect(t: int) None[source]

時刻 t の連結クエリをなくして、そのクエリの2頂点を非連結にします。

償却 \(O(\log{(n+m)})\) です。

Parameters:

t (int) – 時刻です。

Note

時刻 t のクエリは連結クエリでないといけません。

same(u: int, v: int, t: int) bool[source]

時刻 tu, v の連結判定をします。

償却 \(O(\log{(n+m)})\) です。

Parameters:
  • u (int) – 集合の要素です。

  • v (int) – 集合の要素です。

  • t (int) – 時刻です。

Return type:

bool

unite(u: int, v: int, t: int) None[source]

時刻 t のクエリを unite(u, v) にします。

償却 \(O(\log{(n+m)})\) です。

Parameters:
  • u (int) – 集合の要素です。

  • v (int) – 集合の要素です。

  • t (int) – 時刻です。

Note

disconnect を使用する場合、 u, v が連結されていてはいけません。