lca

ソースコード

from titan_pylib.graph.lca import LCA

view on github

展開済みコード

  1# from titan_pylib.graph.lca import LCA
  2# from titan_pylib.data_structures.sparse_table.sparse_table_RmQ import SparseTableRmQ
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from typing import Generic, TypeVar, Iterable
 11
 12T = TypeVar("T", bound=SupportsLessThan)
 13
 14
 15class SparseTableRmQ(Generic[T]):
 16    """
 17    2項演算を :math:`\\min` にしたものです。
 18    """
 19
 20    def __init__(self, a: Iterable[T], e: T):
 21        if not isinstance(a, list):
 22            a = list(a)
 23        self.size = len(a)
 24        log = self.size.bit_length() - 1
 25        data = [a] + [[]] * log
 26        for i in range(log):
 27            pre = data[i]
 28            l = 1 << i
 29            data[i + 1] = [
 30                pre[j] if pre[j] < pre[j + l] else pre[j + l]
 31                for j in range(len(pre) - l)
 32            ]
 33        self.data = data
 34        self.e = e
 35
 36    def prod(self, l: int, r: int) -> T:
 37        assert 0 <= l <= r <= self.size
 38        if l == r:
 39            return self.e
 40        u = (r - l).bit_length() - 1
 41        return (
 42            self.data[u][l]
 43            if self.data[u][l] < self.data[u][r - (1 << u)]
 44            else self.data[u][r - (1 << u)]
 45        )
 46
 47    def __getitem__(self, k: int) -> T:
 48        assert 0 <= k < self.size
 49        return self.data[0][k]
 50
 51    def __len__(self):
 52        return self.size
 53
 54    def __str__(self):
 55        return str(self.data[0])
 56
 57    def __repr__(self):
 58        return f"{self.__class__.__name__}({self.data[0]}, {self.e})"
 59
 60
 61class LCA:
 62    """LCA を定数倍良く求めます。
 63
 64    :math:`< O(NlogN), O(1) >`
 65    https://github.com/cheran-senthil/PyRival/blob/master/pyrival/graphs/lca.py
 66    """
 67
 68    def __init__(self, G: list[list[int]], root: int) -> None:
 69        """根が ``root`` の重み無し隣接リスト ``G`` で表されるグラフに対して LCA を求めます。
 70        時間・空間 :math:`O(n\\log{n})` です。
 71
 72        Args:
 73          G (list[list[int]]): 隣接リストです。
 74          root (int): 根です。
 75        """
 76        _n = len(G)
 77        path = [-1] * _n
 78        nodein = [-1] * _n
 79        par = [-1] * _n
 80        curtime = -1
 81        stack = [root]
 82        while stack:
 83            v = stack.pop()
 84            path[curtime] = par[v]
 85            curtime += 1
 86            nodein[v] = curtime
 87            for x in G[v]:
 88                if nodein[x] != -1:
 89                    continue
 90                par[x] = v
 91                stack.append(x)
 92        self._n = _n
 93        self._path = path
 94        self._nodein = nodein
 95        self._st: SparseTableRmQ[int] = SparseTableRmQ((nodein[v] for v in path), e=_n)
 96
 97    def lca(self, u: int, v: int) -> int:
 98        """頂点 ``u`` と頂点 ``v`` の LCA を返します。
 99        :math:`O(1)` です。
100        """
101        if u == v:
102            return u
103        l, r = self._nodein[u], self._nodein[v]
104        if l > r:
105            l, r = r, l
106        return self._path[self._st.prod(l, r)]
107
108    def lca_mul(self, a: list[int]) -> int:
109        """頂点集合 ``a`` の LCA を返します。"""
110        if all(a[i] == a[i + 1] for i in range(len(a) - 1)):
111            return a[0]
112        l = self._n + 1
113        r = -l
114        for e in a:
115            e = self._nodein[e]
116            if l > e:
117                l = e
118            if r < e:
119                r = e
120        return self._path[self._st.prod(l, r)]

仕様

class LCA(G: list[list[int]], root: int)[source]

Bases: object

LCA を定数倍良く求めます。

\(< O(NlogN), O(1) >\) https://github.com/cheran-senthil/PyRival/blob/master/pyrival/graphs/lca.py

lca(u: int, v: int) int[source]

頂点 u と頂点 v の LCA を返します。 \(O(1)\) です。

lca_mul(a: list[int]) int[source]

頂点集合 a の LCA を返します。