Source code for titan_pylib.graph.hld.hld

  1from typing import Any, Iterator
  2
  3
[docs] 4class HLD: 5 6 def __init__(self, G: list[list[int]], root: int): 7 """``root`` を根とする木 ``G`` を HLD します。 8 :math:`O(n)` です。 9 10 Args: 11 G (list[list[int]]): 木を表す隣接リストです。 12 root (int): 根です。 13 """ 14 n = len(G) 15 self.n: int = n 16 self.G: list[list[int]] = G 17 self.size: list[int] = [1] * n 18 self.par: list[int] = [-1] * n 19 self.dep: list[int] = [-1] * n 20 self.nodein: list[int] = [0] * n 21 self.nodeout: list[int] = [0] * n 22 self.head: list[int] = [0] * n 23 self.hld: list[int] = [0] * n 24 self._dfs(root) 25 26 def _dfs(self, root: int) -> None: 27 dep, par, size, G = self.dep, self.par, self.size, self.G 28 dep[root] = 0 29 stack = [~root, root] 30 while stack: 31 v = stack.pop() 32 if v >= 0: 33 dep_nxt = dep[v] + 1 34 for x in G[v]: 35 if dep[x] != -1: 36 continue 37 dep[x] = dep_nxt 38 stack.append(~x) 39 stack.append(x) 40 else: 41 v = ~v 42 G_v, dep_v = G[v], dep[v] 43 for i, x in enumerate(G_v): 44 if dep[x] < dep_v: 45 par[v] = x 46 continue 47 size[v] += size[x] 48 if size[x] > size[G_v[0]]: 49 G_v[0], G_v[i] = G_v[i], G_v[0] 50 51 head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld 52 curtime = 0 53 stack = [~root, root] 54 while stack: 55 v = stack.pop() 56 if v >= 0: 57 if par[v] == -1: 58 head[v] = v 59 nodein[v] = curtime 60 hld[curtime] = v 61 curtime += 1 62 if not G[v]: 63 continue 64 G_v0 = G[v][0] 65 for x in reversed(G[v]): 66 if x == par[v]: 67 continue 68 head[x] = head[v] if x == G_v0 else x 69 stack.append(~x) 70 stack.append(x) 71 else: 72 nodeout[~v] = curtime 73
[docs] 74 def build_list(self, a: list[Any]) -> list[Any]: 75 """``hld配列`` を基にインデックスを振りなおします。非破壊的です。 76 :math:`O(n)` です。 77 78 Args: 79 a (list[Any]): 元の配列です。 80 81 Returns: 82 list[Any]: 振りなおし後の配列です。 83 """ 84 return [a[e] for e in self.hld]
85
[docs] 86 def for_each_vertex_path(self, u: int, v: int) -> Iterator[tuple[int, int]]: 87 """``u-v`` パスに対応する区間のインデックスを返します。 88 :math:`O(\\log{n})` です。 89 """ 90 head, nodein, dep, par = self.head, self.nodein, self.dep, self.par 91 while head[u] != head[v]: 92 if dep[head[u]] < dep[head[v]]: 93 u, v = v, u 94 yield nodein[head[u]], nodein[u] + 1 95 u = par[head[u]] 96 if dep[u] < dep[v]: 97 u, v = v, u 98 yield nodein[v], nodein[u] + 1
99
[docs] 100 def for_each_vertex_subtree(self, v: int) -> Iterator[tuple[int, int]]: 101 """頂点 ``v`` の部分木に対応する区間のインデックスを返します。 102 :math:`O(1)` です。 103 """ 104 yield self.nodein[v], self.nodeout[v]
105
[docs] 106 def path_kth_elm(self, s: int, t: int, k: int) -> int: 107 """``s`` から ``t`` に向かって ``k`` 個進んだ頂点のインデックスを返します。 108 存在しないときは ``-1`` を返します。 109 :math:`O(\\log{n})` です。 110 """ 111 head, dep, par = self.head, self.dep, self.par 112 lca = self.lca(s, t) 113 d = dep[s] + dep[t] - 2 * dep[lca] 114 if d < k: 115 return -1 116 if dep[s] - dep[lca] < k: 117 s = t 118 k = d - k 119 hs = head[s] 120 while dep[s] - dep[hs] < k: 121 k -= dep[s] - dep[hs] + 1 122 s = par[hs] 123 hs = head[s] 124 return self.hld[self.nodein[s] - k]
125
[docs] 126 def lca(self, u: int, v: int) -> int: 127 """``u``, ``v`` の LCA を返します。 128 :math:`O(\\log{n})` です。 129 """ 130 nodein, head, par = self.nodein, self.head, self.par 131 while True: 132 if nodein[u] > nodein[v]: 133 u, v = v, u 134 if head[u] == head[v]: 135 return u 136 v = par[head[v]]
137
[docs] 138 def dist(self, u: int, v: int) -> int: 139 return self.dep[u] + self.dep[v] - 2 * self.dep[self.lca(u, v)]
140
[docs] 141 def is_on_path(self, u: int, v: int, a: int) -> bool: 142 """Return True if (a is on path(u - v)) else False. / O(logN)""" 143 return self.dist(u, a) + self.dist(a, v) == self.dist(u, v)