euler_tour_tree

ソースコード

from titan_pylib.data_structures.dynamic_connectivity.euler_tour_tree import EulerTourTree

view on github

展開済みコード

  1# from titan_pylib.data_structures.dynamic_connectivity.euler_tour_tree import EulerTourTree
  2from typing import Generator, Generic, TypeVar, Callable, Iterable, Optional, Union
  3from types import GeneratorType
  4
  5T = TypeVar("T")
  6F = TypeVar("F")
  7
  8
  9class EulerTourTree(Generic[T, F]):
 10
 11    class _Node:
 12
 13        def __init__(self, key: T, lazy: F):
 14            self.key: T = key
 15            self.data: T = key
 16            self.lazy: F = lazy
 17            self.par: Optional[EulerTourTree._Node] = None
 18            self.left: Optional[EulerTourTree._Node] = None
 19            self.right: Optional[EulerTourTree._Node] = None
 20
 21        def __str__(self):
 22            if self.left is None and self.right is None:
 23                return f"(key,par):{self.key,self.data,self.lazy,(self.par.key if self.par else None)}\n"
 24            return f"(key,par):{self.key,self.data,self.lazy,(self.par.key if self.par else None)},\n left:{self.left},\n right:{self.right}\n"
 25
 26        __repr__ = __str__
 27
 28    def __init__(
 29        self,
 30        n_or_a: Union[int, Iterable[T]],
 31        op: Callable[[T, T], T],
 32        mapping: Callable[[F, T], T],
 33        composition: Callable[[F, F], F],
 34        e: T,
 35        id: F,
 36    ) -> None:
 37        self.op = op
 38        self.mapping = mapping
 39        self.composition = composition
 40        self.e = e
 41        self.id = id
 42        a = [e for _ in range(n_or_a)] if isinstance(n_or_a, int) else list(n_or_a)
 43        self.n: int = len(a)
 44        self.ptr_vertex: list[EulerTourTree._Node] = [
 45            EulerTourTree._Node(elem, id) for i, elem in enumerate(a)
 46        ]
 47        self.ptr_edge: dict[tuple[int, int], EulerTourTree._Node] = {}
 48        self._group_numbers: int = self.n
 49
 50    @staticmethod
 51    def antirec(func, stack=[]):
 52        # 参考: https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
 53        def wrappedfunc(*args, **kwargs):
 54            if stack:
 55                return func(*args, **kwargs)
 56            to = func(*args, **kwargs)
 57            while True:
 58                if isinstance(to, GeneratorType):
 59                    stack.append(to)
 60                    to = next(to)
 61                else:
 62                    stack.pop()
 63                    if not stack:
 64                        break
 65                    to = stack[-1].send(to)
 66            return to
 67
 68        return wrappedfunc
 69
 70    def build(self, G: list[list[int]]) -> None:
 71        """隣接リスト ``G`` をもとにして、辺を張ります。
 72        :math:`O(n)` です。
 73
 74        Args:
 75          G (list[list[int]]): 隣接リストです。
 76
 77        Note:
 78          ``build`` メソッドを使用する場合は他のメソッドより前に使用しなければなりません。
 79        """
 80        n, ptr_vertex, ptr_edge, e, id = (
 81            self.n,
 82            self.ptr_vertex,
 83            self.ptr_edge,
 84            self.e,
 85            self.id,
 86        )
 87        seen = [0] * n
 88        _Node = EulerTourTree._Node
 89
 90        @EulerTourTree.antirec
 91        def dfs(v: int, p: int = -1) -> Generator:
 92            a.append(v * n + v)
 93            for x in G[v]:
 94                if x == p:
 95                    continue
 96                a.append(v * n + x)
 97                yield dfs(x, v)
 98                a.append(x * n + v)
 99            yield
100
101        @EulerTourTree.antirec
102        def rec(l: int, r: int) -> Generator:
103            mid = (l + r) >> 1
104            u, v = divmod(a[mid], n)
105            node = ptr_vertex[u] if u == v else _Node(e, id)
106            if u == v:
107                seen[u] = 1
108            else:
109                ptr_edge[u * n + v] = node
110            if l != mid:
111                node.left = yield rec(l, mid)
112                node.left.par = node
113            if mid + 1 != r:
114                node.right = yield rec(mid + 1, r)
115                node.right.par = node
116            self._update(node)
117            yield node
118
119        for root in range(self.n):
120            if seen[root]:
121                continue
122            a: list[int] = []
123            dfs(root)
124            rec(0, len(a))
125
126    def _popleft(self, v: _Node) -> Optional[_Node]:
127        v = self._left_splay(v)
128        if v.right:
129            v.right.par = None
130        return v.right
131
132    def _pop(self, v: _Node) -> Optional[_Node]:
133        v = self._right_splay(v)
134        if v.left:
135            v.left.par = None
136        return v.left
137
138    def _split_left(self, v: _Node) -> tuple[_Node, Optional[_Node]]:
139        # x, yに分割する。ただし、xはvを含む
140        self._splay(v)
141        x, y = v, v.right
142        if y:
143            y.par = None
144        x.right = None
145        self._update(x)
146        return x, y
147
148    def _split_right(self, v: _Node) -> tuple[Optional[_Node], _Node]:
149        # x, yに分割する。ただし、yはvを含む
150        self._splay(v)
151        x, y = v.left, v
152        if x:
153            x.par = None
154        y.left = None
155        self._update(y)
156        return x, y
157
158    def _merge(self, u: Optional[_Node], v: Optional[_Node]) -> None:
159        if u is None or v is None:
160            return
161        u = self._right_splay(u)
162        self._splay(v)
163        u.right = v
164        v.par = u
165        self._update(u)
166
167    def _splay(self, node: _Node) -> None:
168        self._propagate(node)
169        while node.par is not None and node.par.par is not None:
170            pnode = node.par
171            gnode = pnode.par
172            self._propagate(gnode)
173            self._propagate(pnode)
174            self._propagate(node)
175            node.par = gnode.par
176            if (gnode.left is pnode) == (pnode.left is node):
177                if pnode.left is node:
178                    tmp1 = node.right
179                    pnode.left = tmp1
180                    node.right = pnode
181                    pnode.par = node
182                    tmp2 = pnode.right
183                    gnode.left = tmp2
184                    pnode.right = gnode
185                    gnode.par = pnode
186                else:
187                    tmp1 = node.left
188                    pnode.right = tmp1
189                    node.left = pnode
190                    pnode.par = node
191                    tmp2 = pnode.left
192                    gnode.right = tmp2
193                    pnode.left = gnode
194                    gnode.par = pnode
195                if tmp1:
196                    tmp1.par = pnode
197                if tmp2:
198                    tmp2.par = gnode
199            else:
200                if pnode.left is node:
201                    tmp1 = node.right
202                    pnode.left = tmp1
203                    node.right = pnode
204                    tmp2 = node.left
205                    gnode.right = tmp2
206                    node.left = gnode
207                    pnode.par = node
208                    gnode.par = node
209                else:
210                    tmp1 = node.left
211                    pnode.right = tmp1
212                    node.left = pnode
213                    tmp2 = node.right
214                    gnode.left = tmp2
215                    node.right = gnode
216                    pnode.par = node
217                    gnode.par = node
218                if tmp1:
219                    tmp1.par = pnode
220                if tmp2:
221                    tmp2.par = gnode
222            self._update(gnode)
223            self._update(pnode)
224            self._update(node)
225            if node.par is None:
226                return
227            if node.par.left is gnode:
228                node.par.left = node
229            else:
230                node.par.right = node
231        if node.par is None:
232            return
233        pnode = node.par
234        self._propagate(pnode)
235        self._propagate(node)
236        if pnode.left is node:
237            pnode.left = node.right
238            if pnode.left:
239                pnode.left.par = pnode
240            node.right = pnode
241        else:
242            pnode.right = node.left
243            if pnode.right:
244                pnode.right.par = pnode
245            node.left = pnode
246        node.par = None
247        pnode.par = node
248        self._update(pnode)
249        self._update(node)
250
251    def _left_splay(self, node: _Node) -> _Node:
252        self._splay(node)
253        while node.left is not None:
254            node = node.left
255        self._splay(node)
256        return node
257
258    def _right_splay(self, node: _Node) -> _Node:
259        self._splay(node)
260        while node.right is not None:
261            node = node.right
262        self._splay(node)
263        return node
264
265    def _propagate(self, node: Optional[_Node]) -> None:
266        if node is None or node.lazy == self.id:
267            return
268        if node.left:
269            node.left.key = self.mapping(node.lazy, node.left.key)
270            node.left.data = self.mapping(node.lazy, node.left.data)
271            node.left.lazy = self.composition(node.lazy, node.left.lazy)
272        if node.right:
273            node.right.key = self.mapping(node.lazy, node.right.key)
274            node.right.data = self.mapping(node.lazy, node.right.data)
275            node.right.lazy = self.composition(node.lazy, node.right.lazy)
276        node.lazy = self.id
277
278    def _update(self, node: _Node) -> None:
279        self._propagate(node.left)
280        self._propagate(node.right)
281        node.data = node.key
282        if node.left:
283            node.data = self.op(node.left.data, node.data)
284        if node.right:
285            node.data = self.op(node.data, node.right.data)
286
287    def link(self, u: int, v: int) -> None:
288        """辺 ``{u, v}`` を追加します。
289        :math:`O(\\log{n})` です。
290
291        Note:
292          ``u`` と ``v`` が同じ連結成分であってはいけません。
293        """
294        # add edge{u, v}
295        self.reroot(u)
296        self.reroot(v)
297        assert (
298            u * self.n + v not in self.ptr_edge
299        ), f"EulerTourTree.link(), {(u, v)} in ptr_edge"
300        assert (
301            v * self.n + u not in self.ptr_edge
302        ), f"EulerTourTree.link(), {(v, u)} in ptr_edge"
303        uv_node = EulerTourTree._Node(self.e, self.id)
304        vu_node = EulerTourTree._Node(self.e, self.id)
305        self.ptr_edge[u * self.n + v] = uv_node
306        self.ptr_edge[v * self.n + u] = vu_node
307        u_node = self.ptr_vertex[u]
308        v_node = self.ptr_vertex[v]
309        self._merge(u_node, uv_node)
310        self._merge(uv_node, v_node)
311        self._merge(v_node, vu_node)
312        self._group_numbers -= 1
313
314    def cut(self, u: int, v: int) -> None:
315        """辺 ``{u, v}`` を削除します。
316        :math:`O(\\log{n})` です。
317
318        Note:
319          辺 ``{u, v}`` が存在してなければいけません。
320        """
321        # erace edge{u, v}
322        self.reroot(v)
323        self.reroot(u)
324        assert (
325            u * self.n + v in self.ptr_edge
326        ), f"EulerTourTree.cut(), {(u, v)} not in ptr_edge"
327        assert (
328            v * self.n + u in self.ptr_edge
329        ), f"EulerTourTree.cut(), {(v, u)} not in ptr_edge"
330        uv_node = self.ptr_edge.pop(u * self.n + v)
331        vu_node = self.ptr_edge.pop(v * self.n + u)
332        a, _ = self._split_left(uv_node)
333        _, c = self._split_right(vu_node)
334        a = self._pop(a)
335        c = self._popleft(c)
336        self._merge(a, c)
337        self._group_numbers += 1
338
339    def leader(self, v: int) -> _Node:
340        """頂点 ``v`` を含む木の代表元を返します。
341        :math:`O(\\log{n})` です。
342
343        Note:
344          ``reroot`` すると変わるので注意です。
345        """
346        # vを含む木の代表元
347        # rerootすると変わるので注意
348        return self._left_splay(self.ptr_vertex[v])
349
350    def reroot(self, v: int) -> None:
351        """頂点 ``v`` を含む木の根を ``v`` にします。
352
353        :math:`O(\\log{n})` です。
354        """
355        node = self.ptr_vertex[v]
356        x, y = self._split_right(node)
357        self._merge(y, x)
358        self._splay(node)
359
360    def same(self, u: int, v: int) -> bool:
361        """
362        頂点 ``u`` と ``v`` が同じ連結成分にいれば ``True`` を、
363        そうでなければ ``False`` を返します。
364
365        :math:`O(\\log{n})` です。
366        """
367        u_node = self.ptr_vertex[u]
368        v_node = self.ptr_vertex[v]
369        self._splay(u_node)
370        self._splay(v_node)
371        return u_node.par is not None or u_node is v_node
372
373    def _show(self) -> None:
374        # for debug
375        print("+++++++++++++++++++++++++++")
376        for i, v in enumerate(self.ptr_vertex):
377            print((i, i), v, end="\n\n")
378        for k, v in self.ptr_edge.items():
379            print(k, v, end="\n\n")
380        print("+++++++++++++++++++++++++++")
381
382    def subtree_apply(self, v: int, p: int, f: F) -> None:
383        """頂点 ``v`` を根としたときの部分木に ``f`` を作用します。
384
385        ``v`` の親は ``p`` です。
386        ``v`` の親が存在しないときは ``p=-1`` として下さい。
387
388        :math:`O(\\log{n})` です。
389
390        Args:
391          v (int): 根です。
392          p (int): ``v`` の親です。
393          f (F): 作用素です。
394        """
395        if p == -1:
396            v_node = self.ptr_vertex[v]
397            self._splay(v_node)
398            v_node.key = self.mapping(f, v_node.key)
399            v_node.data = self.mapping(f, v_node.data)
400            v_node.lazy = self.composition(f, v_node.lazy)
401            return
402        self.reroot(v)
403        self.reroot(p)
404        assert (
405            p * self.n + v in self.ptr_edge
406        ), f"EulerTourTree.subtree_apply(), {(p, v)} not in ptr_edge"
407        assert (
408            v * self.n + p in self.ptr_edge
409        ), f"EulerTourTree.subtree_apply(), {(v, p)} not in ptr_edge"
410        v_node = self.ptr_vertex[v]
411        a, b = self._split_right(self.ptr_edge[p * self.n + v])
412        b, d = self._split_left(self.ptr_edge[v * self.n + p])
413        self._splay(v_node)
414        v_node.key = self.mapping(f, v_node.key)
415        v_node.data = self.mapping(f, v_node.data)
416        v_node.lazy = self.composition(f, v_node.lazy)
417        self._propagate(v_node)
418        self._merge(a, b)
419        self._merge(b, d)
420
421    def subtree_sum(self, v: int, p: int) -> T:
422        """頂点 ``v`` を根としたときの部分木の総和を返します。
423
424        ``v`` の親は ``p`` です。
425        ``v`` の親が存在しないときは ``p=-1`` として下さい。
426
427        :math:`O(\\log{n})` です。
428
429        Args:
430          v (int): 根です。
431          p (int): ``v`` の親です。
432        """
433        if p == -1:
434            v_node = self.ptr_vertex[v]
435            self._splay(v_node)
436            return v_node.data
437        self.reroot(v)
438        self.reroot(p)
439        assert (
440            p * self.n + v in self.ptr_edge
441        ), f"EulerTourTree.subtree_sum(), {(p, v)} not in ptr_edge"
442        assert (
443            v * self.n + p in self.ptr_edge
444        ), f"EulerTourTree.subtree_sum(), {(v, p)} not in ptr_edge"
445        v_node = self.ptr_vertex[v]
446        a, b = self._split_right(self.ptr_edge[p * self.n + v])
447        b, d = self._split_left(self.ptr_edge[v * self.n + p])
448        self._splay(v_node)
449        res = v_node.data
450        self._merge(a, b)
451        self._merge(b, d)
452        return res
453
454    def group_count(self) -> int:
455        """連結成分の個数を返します。
456        :math:`O(1)` です。
457        """
458        return self._group_numbers
459
460    def get_vertex(self, v: int) -> T:
461        """頂点 ``v`` の ``key`` を返します。
462        :math:`O(\\log{n})` です。
463        """
464        node = self.ptr_vertex[v]
465        self._splay(node)
466        return node.key
467
468    def set_vertex(self, v: int, val: T) -> None:
469        """頂点 ``v`` の ``key`` を ``val`` に更新します。
470        :math:`O(\\log{n})` です。
471        """
472        node = self.ptr_vertex[v]
473        self._splay(node)
474        node.key = val
475        self._update(node)
476
477    def __getitem__(self, v: int) -> T:
478        return self.get_vertex(v)
479
480    def __setitem__(self, v: int, val: T) -> None:
481        return self.set_vertex(v, val)

仕様

class EulerTourTree(n_or_a: int | Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F)[source]

Bases: Generic[T, F]

static antirec(func, stack=[])[source]
build(G: list[list[int]]) None[source]

隣接リスト G をもとにして、辺を張ります。 \(O(n)\) です。

Parameters:

G (list[list[int]]) – 隣接リストです。

Note

build メソッドを使用する場合は他のメソッドより前に使用しなければなりません。

cut(u: int, v: int) None[source]

{u, v} を削除します。 \(O(\log{n})\) です。

Note

{u, v} が存在してなければいけません。

get_vertex(v: int) T[source]

頂点 vkey を返します。 \(O(\log{n})\) です。

group_count() int[source]

連結成分の個数を返します。 \(O(1)\) です。

leader(v: int) _Node[source]

頂点 v を含む木の代表元を返します。 \(O(\log{n})\) です。

Note

reroot すると変わるので注意です。

{u, v} を追加します。 \(O(\log{n})\) です。

Note

uv が同じ連結成分であってはいけません。

reroot(v: int) None[source]

頂点 v を含む木の根を v にします。

\(O(\log{n})\) です。

same(u: int, v: int) bool[source]

頂点 uv が同じ連結成分にいれば True を、 そうでなければ False を返します。

\(O(\log{n})\) です。

set_vertex(v: int, val: T) None[source]

頂点 vkeyval に更新します。 \(O(\log{n})\) です。

subtree_apply(v: int, p: int, f: F) None[source]

頂点 v を根としたときの部分木に f を作用します。

v の親は p です。 v の親が存在しないときは p=-1 として下さい。

\(O(\log{n})\) です。

Parameters:
  • v (int) – 根です。

  • p (int) – v の親です。

  • f (F) – 作用素です。

subtree_sum(v: int, p: int) T[source]

頂点 v を根としたときの部分木の総和を返します。

v の親は p です。 v の親が存在しないときは p=-1 として下さい。

\(O(\log{n})\) です。

Parameters:
  • v (int) – 根です。

  • p (int) – v の親です。