euler_tour

ソースコード

from titan_pylib.graph.euler_tour import EulerTour

view on github

展開済みコード

  1# from titan_pylib.graph.euler_tour import EulerTour
  2# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
  3from typing import Union, Iterable, Optional
  4
  5
  6class FenwickTree:
  7    """FenwickTreeです。"""
  8
  9    def __init__(self, n_or_a: Union[Iterable[int], int]):
 10        """構築します。
 11        :math:`O(n)` です。
 12
 13        Args:
 14          n_or_a (Union[Iterable[int], int]): `n_or_a` が `int` のとき、初期値 `0` 、長さ `n` で構築します。
 15                                              `n_or_a` が `Iterable` のとき、初期値 `a` で構築します。
 16        """
 17        if isinstance(n_or_a, int):
 18            self._size = n_or_a
 19            self._tree = [0] * (self._size + 1)
 20        else:
 21            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
 22            _size = len(a)
 23            _tree = [0] + a
 24            for i in range(1, _size):
 25                if i + (i & -i) <= _size:
 26                    _tree[i + (i & -i)] += _tree[i]
 27            self._size = _size
 28            self._tree = _tree
 29        self._s = 1 << (self._size - 1).bit_length()
 30
 31    def pref(self, r: int) -> int:
 32        """区間 ``[0, r)`` の総和を返します。
 33        :math:`O(\\log{n})` です。
 34        """
 35        assert (
 36            0 <= r <= self._size
 37        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
 38        ret, _tree = 0, self._tree
 39        while r > 0:
 40            ret += _tree[r]
 41            r &= r - 1
 42        return ret
 43
 44    def suff(self, l: int) -> int:
 45        """区間 ``[l, n)`` の総和を返します。
 46        :math:`O(\\log{n})` です。
 47        """
 48        assert (
 49            0 <= l < self._size
 50        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
 51        return self.pref(self._size) - self.pref(l)
 52
 53    def sum(self, l: int, r: int) -> int:
 54        """区間 ``[l, r)`` の総和を返します。
 55        :math:`O(\\log{n})` です。
 56        """
 57        assert (
 58            0 <= l <= r <= self._size
 59        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
 60        _tree = self._tree
 61        res = 0
 62        while r > l:
 63            res += _tree[r]
 64            r &= r - 1
 65        while l > r:
 66            res -= _tree[l]
 67            l &= l - 1
 68        return res
 69
 70    prod = sum
 71
 72    def __getitem__(self, k: int) -> int:
 73        """位置 ``k`` の要素を返します。
 74        :math:`O(\\log{n})` です。
 75        """
 76        assert (
 77            -self._size <= k < self._size
 78        ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
 79        if k < 0:
 80            k += self._size
 81        return self.sum(k, k + 1)
 82
 83    def add(self, k: int, x: int) -> None:
 84        """``k`` 番目の値に ``x`` を加えます。
 85        :math:`O(\\log{n})` です。
 86        """
 87        assert (
 88            0 <= k < self._size
 89        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
 90        k += 1
 91        _tree = self._tree
 92        while k <= self._size:
 93            _tree[k] += x
 94            k += k & -k
 95
 96    def __setitem__(self, k: int, x: int):
 97        """``k`` 番目の値を ``x`` に更新します。
 98        :math:`O(\\log{n})` です。
 99        """
100        assert (
101            -self._size <= k < self._size
102        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
103        if k < 0:
104            k += self._size
105        pre = self[k]
106        self.add(k, x - pre)
107
108    def bisect_left(self, w: int) -> Optional[int]:
109        i, s, _size, _tree = 0, self._s, self._size, self._tree
110        while s:
111            if i + s <= _size and _tree[i + s] < w:
112                w -= _tree[i + s]
113                i += s
114            s >>= 1
115        return i if w else None
116
117    def bisect_right(self, w: int) -> int:
118        i, s, _size, _tree = 0, self._s, self._size, self._tree
119        while s:
120            if i + s <= _size and _tree[i + s] <= w:
121                w -= _tree[i + s]
122                i += s
123            s >>= 1
124        return i
125
126    def _pop(self, k: int) -> int:
127        assert k >= 0
128        i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
129        while s:
130            if i + s <= _size:
131                if acc + _tree[i + s] <= k:
132                    acc += _tree[i + s]
133                    i += s
134                else:
135                    _tree[i + s] -= 1
136            s >>= 1
137        return i
138
139    def tolist(self) -> list[int]:
140        """リストにして返します。
141        :math:`O(n)` です。
142        """
143        sub = [self.pref(i) for i in range(self._size + 1)]
144        return [sub[i + 1] - sub[i] for i in range(self._size)]
145
146    @staticmethod
147    def get_inversion_num(a: list[int], compress: bool = False) -> int:
148        inv = 0
149        if compress:
150            a_ = sorted(set(a))
151            z = {e: i for i, e in enumerate(a_)}
152            fw = FenwickTree(len(a_) + 1)
153            for i, e in enumerate(a):
154                inv += i - fw.pref(z[e] + 1)
155                fw.add(z[e], 1)
156        else:
157            fw = FenwickTree(len(a) + 1)
158            for i, e in enumerate(a):
159                inv += i - fw.pref(e + 1)
160                fw.add(e, 1)
161        return inv
162
163    def __str__(self):
164        return str(self.tolist())
165
166    def __repr__(self):
167        return f"{self.__class__.__name__}({self})"
168# from titan_pylib.data_structures.segment_tree.segment_tree_RmQ import SegmentTreeRmQ
169# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
170#     SegmentTreeInterface,
171# )
172from abc import ABC, abstractmethod
173from typing import TypeVar, Generic, Union, Iterable, Callable
174
T = TypeVar("T")


class SegmentTreeInterface(ABC, Generic[T]):
    """Abstract interface shared by the segment tree implementations.

    The constructor signature takes initial data (a length or an iterable),
    a binary operation ``op`` and its identity ``e``; subclasses may narrow
    it (``SegmentTreeRmQ`` fixes ``op`` to ``min``). Every method is abstract
    and raises ``NotImplementedError`` unless overridden.
    """

    @abstractmethod
    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
        raise NotImplementedError

    @abstractmethod
    def set(self, k: int, v: T) -> None:
        """Sets the ``k``-th value to ``v``."""
        raise NotImplementedError

    @abstractmethod
    def get(self, k: int) -> T:
        """Returns the ``k``-th value."""
        raise NotImplementedError

    @abstractmethod
    def prod(self, l: int, r: int) -> T:
        """Returns the product of the values in ``[l, r)``."""
        raise NotImplementedError

    @abstractmethod
    def all_prod(self) -> T:
        """Returns the product of all values."""
        raise NotImplementedError

    @abstractmethod
    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
        raise NotImplementedError

    @abstractmethod
    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, k: int) -> T:
        raise NotImplementedError

    @abstractmethod
    def __setitem__(self, k: int, v: T) -> None:
        raise NotImplementedError

    @abstractmethod
    def __str__(self):
        raise NotImplementedError

    @abstractmethod
    def __repr__(self):
        raise NotImplementedError
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from typing import Protocol


class SupportsLessThan(Protocol):
    """Structural type for values that can be compared with ``<``."""

    def __lt__(self, other) -> bool: ...
from typing import Generic, Iterable, TypeVar, Union

# Rebinds T with an ordering bound for the range-minimum tree below.
T = TypeVar("T", bound=SupportsLessThan)


class SegmentTreeRmQ(SegmentTreeInterface, Generic[T]):
    """Segment tree answering range-minimum queries.

    The operation is fixed to ``min`` (evaluated with ``<`` / ``>``), so the
    elements only need to be comparable. ``prod`` starts from ``e`` and the
    padding leaves hold ``e``, so ``e`` must not be smaller than any stored
    value.
    """

    def __init__(self, _n_or_a: Union[int, Iterable[T]], e: T) -> None:
        """Builds the tree.

        Args:
          _n_or_a: a length ``n`` (every value starts as ``e``) or an
            iterable of initial values.
          e: identity element for ``min``.
        """
        self._e = e
        if isinstance(_n_or_a, int):
            self._n = _n_or_a
            self._log = (self._n - 1).bit_length()
            self._size = 1 << self._log
            self._data = [self._e] * (self._size << 1)
        else:
            _n_or_a = list(_n_or_a)
            self._n = len(_n_or_a)
            self._log = (self._n - 1).bit_length()
            self._size = 1 << self._log
            _data = [self._e] * (self._size << 1)
            _data[self._size : self._size + self._n] = _n_or_a
            # Node i aggregates children 2i and 2i+1; ties keep the right one.
            for i in range(self._size - 1, 0, -1):
                _data[i] = (
                    _data[i << 1]
                    if _data[i << 1] < _data[i << 1 | 1]
                    else _data[i << 1 | 1]
                )
            self._data = _data

    def set(self, k: int, v: T) -> None:
        """Sets value ``k`` (negative indices allowed) to ``v`` in O(log n)."""
        if k < 0:
            k += self._n
        assert (
            0 <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._n}"
        k += self._size
        self._data[k] = v
        for _ in range(self._log):
            k >>= 1
            self._data[k] = (
                self._data[k << 1]
                if self._data[k << 1] < self._data[k << 1 | 1]
                else self._data[k << 1 | 1]
            )

    def get(self, k: int) -> T:
        """Returns value ``k`` (negative indices allowed) in O(1)."""
        if k < 0:
            k += self._n
        assert (
            0 <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._n}"
        return self._data[k + self._size]

    def prod(self, l: int, r: int) -> T:
        """Returns ``min`` of ``[l, r)`` (``e`` when empty) in O(log n)."""
        assert (
            0 <= l <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int), n={self._n}"
        l += self._size
        r += self._size
        res = self._e
        while l < r:
            if l & 1:
                if res > self._data[l]:
                    res = self._data[l]
                l += 1
            if r & 1:
                r ^= 1
                if res > self._data[r]:
                    res = self._data[r]
            l >>= 1
            r >>= 1
        return res

    def all_prod(self) -> T:
        """Returns the minimum over all values (the root node)."""
        return self._data[1]

    def max_right(self, l: int, f=lambda lr: lr):
        """Returns an ``r`` with ``f(prod(l, r))`` true and, if ``r < n``,
        ``f(prod(l, r + 1))`` false.

        For monotone ``f`` this is the largest such ``r``. ``f(e)`` must be
        true.
        """
        assert (
            0 <= l <= self._n
        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
        if l == self._n:
            return self._n
        l += self._size
        s = self._e
        while True:
            while l & 1 == 0:
                l >>= 1
            if not f(min(s, self._data[l])):
                # Descend into the node where f first fails.
                while l < self._size:
                    l <<= 1
                    if f(min(s, self._data[l])):
                        if s > self._data[l]:
                            s = self._data[l]
                        l += 1
                return l - self._size
            s = min(s, self._data[l])
            l += 1
            if l & -l == l:
                break
        return self._n

    def min_left(self, r: int, f=lambda lr: lr):
        """Returns an ``l`` with ``f(prod(l, r))`` true and, if ``l > 0``,
        ``f(prod(l - 1, r))`` false.

        For monotone ``f`` this is the smallest such ``l``. ``f(e)`` must be
        true.
        """
        assert (
            0 <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
        if r == 0:
            return 0
        r += self._size
        s = self._e
        while True:
            r -= 1
            while r > 1 and r & 1:
                r >>= 1
            if not f(min(self._data[r], s)):
                # Descend into the node where f first fails.
                while r < self._size:
                    r = r << 1 | 1
                    if f(min(self._data[r], s)):
                        if s > self._data[r]:
                            s = self._data[r]
                        r -= 1
                return r + 1 - self._size
            s = min(self._data[r], s)
            if r & -r == r:
                break
        return 0

    def tolist(self) -> list[T]:
        """Returns the ``n`` stored values as a list."""
        return [self.get(i) for i in range(self._n)]

    def show(self) -> None:
        """Prints the tree level by level (debugging aid)."""
        print(
            f"<{self.__class__.__name__}> [\n"
            + "\n".join(
                [
                    "  "
                    + " ".join(
                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
                    )
                    for i in range(self._log + 1)
                ]
            )
            + "\n]"
        )

    def __getitem__(self, k: int) -> T:
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}: int), n={self._n}"
        return self.get(k)

    def __setitem__(self, k: int, v: T):
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__setitem__({k}: int, {v}: T), n={self._n}"
        self.set(k, v)

    def __str__(self):
        return "[" + ", ".join(map(str, (self.get(i) for i in range(self._n)))) + "]"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"
402
403
class EulerTour:
    """Euler tour of a rooted tree with LCA, subtree-sum and path-sum queries.

    ``G[v]`` lists ``(neighbor, edge_cost)`` pairs. Only vertices reachable
    from ``root`` are visited; ``G`` is expected to be an undirected tree
    with every edge listed from both endpoints.

    The tour ``_path`` has ``2 * n`` slots for a tree: one "enter" event per
    vertex, one "return to parent" event per edge and a final event for
    ``root``. ``_nodein[v]`` is the time ``v`` is entered; ``_nodeout[v]`` is
    ``_nodein[v] + 1`` for a leaf and otherwise the time of the last return
    to ``v``. The enter times of ``v``'s subtree are exactly the enter times
    in ``[_nodein[v], _nodeout[v])``.
    """

    def __init__(
        self,
        G: list[list[tuple[int, int]]],
        root: int,
        vertexcost: Optional[list[int]] = None,
    ) -> None:
        """Builds the tour.

        Args:
          G: adjacency list of ``(neighbor, edge_cost)`` pairs.
          root: root vertex.
          vertexcost: initial vertex costs; all ``0`` when omitted or empty.
            The list is copied, so later updates never modify it.
        """
        n = len(G)
        # Private copy: add_vertex/set_vertex update it in place.
        vertexcost = list(vertexcost) if vertexcost else [0] * n

        path = [0] * (2 * n)
        vcost1 = [0] * (2 * n)  # vertex cost at its enter time (subtree sums)
        ecost1 = [0] * (2 * n)  # edge cost at its child's enter time (subtree sums)
        ecost2 = [0] * (2 * n)  # +cost on enter, -cost on return (root-path sums)
        nodein = [0] * n
        nodeout = [0] * n
        depth = [-1] * n  # doubles as the "visited" marker

        curtime = -1
        depth[root] = 0
        # (~v, c) is a "return to v" event; c is the cost of the edge just
        # walked back over, so it can be cancelled in ecost2.
        stack: list[tuple[int, int]] = [(~root, 0), (root, 0)]
        while stack:
            curtime += 1
            v, ec = stack.pop()
            if v >= 0:
                nodein[v] = curtime
                path[curtime] = v
                ecost1[curtime] = ec
                ecost2[curtime] = ec
                vcost1[curtime] = vertexcost[v]
                is_leaf = True
                for x, c in G[v]:
                    if depth[x] != -1:
                        continue
                    is_leaf = False
                    depth[x] = depth[v] + 1
                    stack.append((~v, c))
                    stack.append((x, c))
                if is_leaf:
                    # The next event is the return to the parent.
                    nodeout[v] = curtime + 1
            else:
                v = ~v
                path[curtime] = v
                ecost2[curtime] = -ec
                nodeout[v] = curtime  # the last return to v wins

        # Root-path vertex sums: +cost at nodein[v] and -cost at nodeout[v],
        # the same layout add_vertex maintains. (Cancelling the parent's cost
        # at every return event would subtract it once per finished child.)
        vcost2 = [0] * (2 * n)
        for v in range(n):
            vcost2[nodein[v]] += vertexcost[v]
            vcost2[nodeout[v]] -= vertexcost[v]

        # ---------------------- #

        self._n = n
        self._depth = depth
        self._nodein = nodein
        self._nodeout = nodeout
        self._vertexcost = vertexcost
        self._path = path

        self._vcost_subtree = FenwickTree(vcost1)
        self._vcost_path = FenwickTree(vcost2)
        self._ecost_subtree = FenwickTree(ecost1)
        self._ecost_path = FenwickTree(ecost2)

        # Key = (depth << bit) + tour index, so a range minimum picks the
        # shallowest vertex and ``& msk`` recovers its tour index.
        bit = len(path).bit_length()
        self.msk = (1 << bit) - 1
        a: list[int] = [(depth[v] << bit) + i for i, v in enumerate(path)]
        self._st: SegmentTreeRmQ[int] = SegmentTreeRmQ(a, e=max(a))

    def lca(self, u: int, v: int) -> int:
        """Returns the lowest common ancestor of ``u`` and ``v``."""
        if u == v:
            return u
        l = min(self._nodein[u], self._nodein[v])
        r = max(self._nodeout[u], self._nodeout[v])
        ind = self._st.prod(l, r) & self.msk
        return self._path[ind]

    def lca_mul(self, a: list[int]) -> int:
        """Returns the lowest common ancestor of all vertices in ``a``.

        ``a`` must be non-empty.
        """
        # Enter times go up to len(path) - 1, so start l above that.
        l, r = len(self._path), -1
        for e in a:
            l = min(l, self._nodein[e])
            r = max(r, self._nodeout[e])
        ind = self._st.prod(l, r) & self.msk
        return self._path[ind]

    def subtree_vcost(self, v: int) -> int:
        """Returns the sum of vertex costs in the subtree of ``v``."""
        l = self._nodein[v]
        r = self._nodeout[v]
        return self._vcost_subtree.prod(l, r)

    def subtree_ecost(self, v: int) -> int:
        """Returns the sum of edge costs inside the subtree of ``v``.

        Each edge's cost is stored at its child's enter time, so the
        entry slot of ``v`` itself (its parent edge) is skipped.
        """
        l = self._nodein[v]
        r = self._nodeout[v]
        return self._ecost_subtree.prod(l + 1, r)

    def _path_vcost(self, v: int) -> int:
        """Returns the sum of vertex costs from the root to ``v`` (both included)."""
        return self._vcost_path.pref(self._nodein[v] + 1)

    def _path_ecost(self, v: int) -> int:
        """Returns the sum of edge costs from the root to ``v``."""
        return self._ecost_path.pref(self._nodein[v] + 1)

    def path_vcost(self, u: int, v: int) -> int:
        """Returns the sum of vertex costs on the ``u``-``v`` path (endpoints included)."""
        a = self.lca(u, v)
        return (
            self._path_vcost(u)
            + self._path_vcost(v)
            - 2 * self._path_vcost(a)
            + self._vertexcost[a]
        )

    def path_ecost(self, u: int, v: int) -> int:
        """Returns the sum of edge costs on the ``u``-``v`` path."""
        return (
            self._path_ecost(u)
            + self._path_ecost(v)
            - 2 * self._path_ecost(self.lca(u, v))
        )

    def add_vertex(self, v: int, w: int) -> None:
        """Add w to vertex v. / O(logN)"""
        l = self._nodein[v]
        r = self._nodeout[v]
        self._vcost_subtree.add(l, w)
        # v's cost is active for enter times in [l, r), i.e. v's subtree.
        self._vcost_path.add(l, w)
        self._vcost_path.add(r, -w)
        self._vertexcost[v] += w

    def set_vertex(self, v: int, w: int) -> None:
        """Set w to vertex v. / O(logN)"""
        self.add_vertex(v, w - self._vertexcost[v])

    def add_edge(self, u: int, v: int, w: int) -> None:
        """Add w to edge([u - v]). / O(logN)

        ``u`` and ``v`` must be adjacent; the deeper endpoint identifies the
        edge.
        """
        if self._depth[u] < self._depth[v]:
            u, v = v, u
        l = self._nodein[u]
        r = self._nodeout[u]
        # Subtree sums store the edge only at its child's enter time, so no
        # cancelling entry is needed (one would cancel it for ancestors).
        self._ecost_subtree.add(l, w)
        # Root-path sums: active for the enter times of u's subtree. Slot r is
        # never an enter time, so cancelling at r + 1 is equivalent.
        self._ecost_path.add(l, w)
        self._ecost_path.add(r + 1, -w)

    def set_edge(self, u: int, v: int, w: int) -> None:
        """Set w to edge([u - v]). / O(logN)

        ``u`` and ``v`` must be adjacent, so that ``path_ecost(u, v)`` is the
        current cost of that edge.
        """
        self.add_edge(u, v, w - self.path_ecost(u, v))

仕様

class EulerTour(G: list[list[tuple[int, int]]], root: int, vertexcost: list[int] = [])[source]

Bases: object

add_edge(u: int, v: int, w: int) → None [source]

Add w to edge([u - v]). / O(logN)

add_vertex(v: int, w: int) → None [source]

Add w to vertex v. / O(logN)

lca(u: int, v: int) → int [source]
lca_mul(a: list[int]) → int [source]
path_ecost(u: int, v: int) → int [source]
path_vcost(u: int, v: int) → int [source]
set_edge(u: int, v: int, w: int) → None [source]

Set w to edge([u - v]). / O(logN)

set_vertex(v: int, w: int) → None [source]

Set w to vertex v. / O(logN)

subtree_ecost(v: int) → int [source]
subtree_vcost(v: int) → int [source]