hld

ソースコード

from titan_pylib.graph.hld.hld import HLD

view on github

展開済みコード

  1# from titan_pylib.graph.hld.hld import HLD
  2from typing import Any, Iterator
  3
  4
  5class HLD:
  6
  7    def __init__(self, G: list[list[int]], root: int):
  8        """``root`` を根とする木 ``G`` を HLD します。
  9        :math:`O(n)` です。
 10
 11        Args:
 12          G (list[list[int]]): 木を表す隣接リストです。
 13          root (int): 根です。
 14        """
 15        n = len(G)
 16        self.n: int = n
 17        self.G: list[list[int]] = G
 18        self.size: list[int] = [1] * n
 19        self.par: list[int] = [-1] * n
 20        self.dep: list[int] = [-1] * n
 21        self.nodein: list[int] = [0] * n
 22        self.nodeout: list[int] = [0] * n
 23        self.head: list[int] = [0] * n
 24        self.hld: list[int] = [0] * n
 25        self._dfs(root)
 26
 27    def _dfs(self, root: int) -> None:
 28        dep, par, size, G = self.dep, self.par, self.size, self.G
 29        dep[root] = 0
 30        stack = [~root, root]
 31        while stack:
 32            v = stack.pop()
 33            if v >= 0:
 34                dep_nxt = dep[v] + 1
 35                for x in G[v]:
 36                    if dep[x] != -1:
 37                        continue
 38                    dep[x] = dep_nxt
 39                    stack.append(~x)
 40                    stack.append(x)
 41            else:
 42                v = ~v
 43                G_v, dep_v = G[v], dep[v]
 44                for i, x in enumerate(G_v):
 45                    if dep[x] < dep_v:
 46                        par[v] = x
 47                        continue
 48                    size[v] += size[x]
 49                    if size[x] > size[G_v[0]]:
 50                        G_v[0], G_v[i] = G_v[i], G_v[0]
 51
 52        head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
 53        curtime = 0
 54        stack = [~root, root]
 55        while stack:
 56            v = stack.pop()
 57            if v >= 0:
 58                if par[v] == -1:
 59                    head[v] = v
 60                nodein[v] = curtime
 61                hld[curtime] = v
 62                curtime += 1
 63                if not G[v]:
 64                    continue
 65                G_v0 = G[v][0]
 66                for x in reversed(G[v]):
 67                    if x == par[v]:
 68                        continue
 69                    head[x] = head[v] if x == G_v0 else x
 70                    stack.append(~x)
 71                    stack.append(x)
 72            else:
 73                nodeout[~v] = curtime
 74
 75    def build_list(self, a: list[Any]) -> list[Any]:
 76        """``hld配列`` を基にインデックスを振りなおします。非破壊的です。
 77        :math:`O(n)` です。
 78
 79        Args:
 80            a (list[Any]): 元の配列です。
 81
 82        Returns:
 83            list[Any]: 振りなおし後の配列です。
 84        """
 85        return [a[e] for e in self.hld]
 86
 87    def for_each_vertex_path(self, u: int, v: int) -> Iterator[tuple[int, int]]:
 88        """``u-v`` パスに対応する区間のインデックスを返します。
 89        :math:`O(\\log{n})` です。
 90        """
 91        head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
 92        while head[u] != head[v]:
 93            if dep[head[u]] < dep[head[v]]:
 94                u, v = v, u
 95            yield nodein[head[u]], nodein[u] + 1
 96            u = par[head[u]]
 97        if dep[u] < dep[v]:
 98            u, v = v, u
 99        yield nodein[v], nodein[u] + 1
100
101    def for_each_vertex_subtree(self, v: int) -> Iterator[tuple[int, int]]:
102        """頂点 ``v`` の部分木に対応する区間のインデックスを返します。
103        :math:`O(1)` です。
104        """
105        yield self.nodein[v], self.nodeout[v]
106
107    def path_kth_elm(self, s: int, t: int, k: int) -> int:
108        """``s`` から ``t`` に向かって ``k`` 個進んだ頂点のインデックスを返します。
109        存在しないときは ``-1`` を返します。
110        :math:`O(\\log{n})` です。
111        """
112        head, dep, par = self.head, self.dep, self.par
113        lca = self.lca(s, t)
114        d = dep[s] + dep[t] - 2 * dep[lca]
115        if d < k:
116            return -1
117        if dep[s] - dep[lca] < k:
118            s = t
119            k = d - k
120        hs = head[s]
121        while dep[s] - dep[hs] < k:
122            k -= dep[s] - dep[hs] + 1
123            s = par[hs]
124            hs = head[s]
125        return self.hld[self.nodein[s] - k]
126
127    def lca(self, u: int, v: int) -> int:
128        """``u``, ``v`` の LCA を返します。
129        :math:`O(\\log{n})` です。
130        """
131        nodein, head, par = self.nodein, self.head, self.par
132        while True:
133            if nodein[u] > nodein[v]:
134                u, v = v, u
135            if head[u] == head[v]:
136                return u
137            v = par[head[v]]
138
139    def dist(self, u: int, v: int) -> int:
140        return self.dep[u] + self.dep[v] - 2 * self.dep[self.lca(u, v)]
141
142    def is_on_path(self, u: int, v: int, a: int) -> bool:
143        """Return True if (a is on path(u - v)) else False. / O(logN)"""
144        return self.dist(u, a) + self.dist(a, v) == self.dist(u, v)

仕様

class HLD(G: list[list[int]], root: int)[source]

Bases: object

build_list(a: list[Any]) list[Any][source]

hld配列 を基にインデックスを振りなおします。非破壊的です。 \(O(n)\) です。

Parameters:

a (list[Any]) – 元の配列です。

Returns:

振りなおし後の配列です。

Return type:

list[Any]

dist(u: int, v: int) int[source]
for_each_vertex_path(u: int, v: int) Iterator[tuple[int, int]][source]

u-v パスに対応する区間のインデックスを返します。 \(O(\log{n})\) です。

for_each_vertex_subtree(v: int) Iterator[tuple[int, int]][source]

頂点 v の部分木に対応する区間のインデックスを返します。 \(O(1)\) です。

is_on_path(u: int, v: int, a: int) bool[source]

Return True if (a is on path(u - v)) else False. / O(logN)

lca(u: int, v: int) int[source]

u, v の LCA を返します。 \(O(\log{n})\) です。

path_kth_elm(s: int, t: int, k: int) int[source]

s から t に向かって k 個進んだ頂点のインデックスを返します。 存在しないときは -1 を返します。 \(O(\log{n})\) です。