1from typing import Any, Iterator
2
3
[docs]
4class HLD:
5
6 def __init__(self, G: list[list[int]], root: int):
7 """``root`` を根とする木 ``G`` を HLD します。
8 :math:`O(n)` です。
9
10 Args:
11 G (list[list[int]]): 木を表す隣接リストです。
12 root (int): 根です。
13 """
14 n = len(G)
15 self.n: int = n
16 self.G: list[list[int]] = G
17 self.size: list[int] = [1] * n
18 self.par: list[int] = [-1] * n
19 self.dep: list[int] = [-1] * n
20 self.nodein: list[int] = [0] * n
21 self.nodeout: list[int] = [0] * n
22 self.head: list[int] = [0] * n
23 self.hld: list[int] = [0] * n
24 self._dfs(root)
25
26 def _dfs(self, root: int) -> None:
27 dep, par, size, G = self.dep, self.par, self.size, self.G
28 dep[root] = 0
29 stack = [~root, root]
30 while stack:
31 v = stack.pop()
32 if v >= 0:
33 dep_nxt = dep[v] + 1
34 for x in G[v]:
35 if dep[x] != -1:
36 continue
37 dep[x] = dep_nxt
38 stack.append(~x)
39 stack.append(x)
40 else:
41 v = ~v
42 G_v, dep_v = G[v], dep[v]
43 for i, x in enumerate(G_v):
44 if dep[x] < dep_v:
45 par[v] = x
46 continue
47 size[v] += size[x]
48 if size[x] > size[G_v[0]]:
49 G_v[0], G_v[i] = G_v[i], G_v[0]
50
51 head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
52 curtime = 0
53 stack = [~root, root]
54 while stack:
55 v = stack.pop()
56 if v >= 0:
57 if par[v] == -1:
58 head[v] = v
59 nodein[v] = curtime
60 hld[curtime] = v
61 curtime += 1
62 if not G[v]:
63 continue
64 G_v0 = G[v][0]
65 for x in reversed(G[v]):
66 if x == par[v]:
67 continue
68 head[x] = head[v] if x == G_v0 else x
69 stack.append(~x)
70 stack.append(x)
71 else:
72 nodeout[~v] = curtime
73
[docs]
74 def build_list(self, a: list[Any]) -> list[Any]:
75 """``hld配列`` を基にインデックスを振りなおします。非破壊的です。
76 :math:`O(n)` です。
77
78 Args:
79 a (list[Any]): 元の配列です。
80
81 Returns:
82 list[Any]: 振りなおし後の配列です。
83 """
84 return [a[e] for e in self.hld]
85
[docs]
86 def for_each_vertex_path(self, u: int, v: int) -> Iterator[tuple[int, int]]:
87 """``u-v`` パスに対応する区間のインデックスを返します。
88 :math:`O(\\log{n})` です。
89 """
90 head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
91 while head[u] != head[v]:
92 if dep[head[u]] < dep[head[v]]:
93 u, v = v, u
94 yield nodein[head[u]], nodein[u] + 1
95 u = par[head[u]]
96 if dep[u] < dep[v]:
97 u, v = v, u
98 yield nodein[v], nodein[u] + 1
99
[docs]
100 def for_each_vertex_subtree(self, v: int) -> Iterator[tuple[int, int]]:
101 """頂点 ``v`` の部分木に対応する区間のインデックスを返します。
102 :math:`O(1)` です。
103 """
104 yield self.nodein[v], self.nodeout[v]
105
[docs]
106 def path_kth_elm(self, s: int, t: int, k: int) -> int:
107 """``s`` から ``t`` に向かって ``k`` 個進んだ頂点のインデックスを返します。
108 存在しないときは ``-1`` を返します。
109 :math:`O(\\log{n})` です。
110 """
111 head, dep, par = self.head, self.dep, self.par
112 lca = self.lca(s, t)
113 d = dep[s] + dep[t] - 2 * dep[lca]
114 if d < k:
115 return -1
116 if dep[s] - dep[lca] < k:
117 s = t
118 k = d - k
119 hs = head[s]
120 while dep[s] - dep[hs] < k:
121 k -= dep[s] - dep[hs] + 1
122 s = par[hs]
123 hs = head[s]
124 return self.hld[self.nodein[s] - k]
125
[docs]
126 def lca(self, u: int, v: int) -> int:
127 """``u``, ``v`` の LCA を返します。
128 :math:`O(\\log{n})` です。
129 """
130 nodein, head, par = self.nodein, self.head, self.par
131 while True:
132 if nodein[u] > nodein[v]:
133 u, v = v, u
134 if head[u] == head[v]:
135 return u
136 v = par[head[v]]
137
[docs]
138 def dist(self, u: int, v: int) -> int:
139 return self.dep[u] + self.dep[v] - 2 * self.dep[self.lca(u, v)]
140
[docs]
141 def is_on_path(self, u: int, v: int, a: int) -> bool:
142 """Return True if (a is on path(u - v)) else False. / O(logN)"""
143 return self.dist(u, a) + self.dist(a, v) == self.dist(u, v)