hld¶
ソースコード¶
from titan_pylib.graph.hld.hld import HLD
展開済みコード¶
1# from titan_pylib.graph.hld.hld import HLD
2from typing import Any, Iterator
3
4
5class HLD:
6
7 def __init__(self, G: list[list[int]], root: int):
8 """``root`` を根とする木 ``G`` を HLD します。
9 :math:`O(n)` です。
10
11 Args:
12 G (list[list[int]]): 木を表す隣接リストです。
13 root (int): 根です。
14 """
15 n = len(G)
16 self.n: int = n
17 self.G: list[list[int]] = G
18 self.size: list[int] = [1] * n
19 self.par: list[int] = [-1] * n
20 self.dep: list[int] = [-1] * n
21 self.nodein: list[int] = [0] * n
22 self.nodeout: list[int] = [0] * n
23 self.head: list[int] = [0] * n
24 self.hld: list[int] = [0] * n
25 self._dfs(root)
26
27 def _dfs(self, root: int) -> None:
28 dep, par, size, G = self.dep, self.par, self.size, self.G
29 dep[root] = 0
30 stack = [~root, root]
31 while stack:
32 v = stack.pop()
33 if v >= 0:
34 dep_nxt = dep[v] + 1
35 for x in G[v]:
36 if dep[x] != -1:
37 continue
38 dep[x] = dep_nxt
39 stack.append(~x)
40 stack.append(x)
41 else:
42 v = ~v
43 G_v, dep_v = G[v], dep[v]
44 for i, x in enumerate(G_v):
45 if dep[x] < dep_v:
46 par[v] = x
47 continue
48 size[v] += size[x]
49 if size[x] > size[G_v[0]]:
50 G_v[0], G_v[i] = G_v[i], G_v[0]
51
52 head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
53 curtime = 0
54 stack = [~root, root]
55 while stack:
56 v = stack.pop()
57 if v >= 0:
58 if par[v] == -1:
59 head[v] = v
60 nodein[v] = curtime
61 hld[curtime] = v
62 curtime += 1
63 if not G[v]:
64 continue
65 G_v0 = G[v][0]
66 for x in reversed(G[v]):
67 if x == par[v]:
68 continue
69 head[x] = head[v] if x == G_v0 else x
70 stack.append(~x)
71 stack.append(x)
72 else:
73 nodeout[~v] = curtime
74
75 def build_list(self, a: list[Any]) -> list[Any]:
76 """``hld配列`` を基にインデックスを振りなおします。非破壊的です。
77 :math:`O(n)` です。
78
79 Args:
80 a (list[Any]): 元の配列です。
81
82 Returns:
83 list[Any]: 振りなおし後の配列です。
84 """
85 return [a[e] for e in self.hld]
86
87 def for_each_vertex_path(self, u: int, v: int) -> Iterator[tuple[int, int]]:
88 """``u-v`` パスに対応する区間のインデックスを返します。
89 :math:`O(\\log{n})` です。
90 """
91 head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
92 while head[u] != head[v]:
93 if dep[head[u]] < dep[head[v]]:
94 u, v = v, u
95 yield nodein[head[u]], nodein[u] + 1
96 u = par[head[u]]
97 if dep[u] < dep[v]:
98 u, v = v, u
99 yield nodein[v], nodein[u] + 1
100
101 def for_each_vertex_subtree(self, v: int) -> Iterator[tuple[int, int]]:
102 """頂点 ``v`` の部分木に対応する区間のインデックスを返します。
103 :math:`O(1)` です。
104 """
105 yield self.nodein[v], self.nodeout[v]
106
107 def path_kth_elm(self, s: int, t: int, k: int) -> int:
108 """``s`` から ``t`` に向かって ``k`` 個進んだ頂点のインデックスを返します。
109 存在しないときは ``-1`` を返します。
110 :math:`O(\\log{n})` です。
111 """
112 head, dep, par = self.head, self.dep, self.par
113 lca = self.lca(s, t)
114 d = dep[s] + dep[t] - 2 * dep[lca]
115 if d < k:
116 return -1
117 if dep[s] - dep[lca] < k:
118 s = t
119 k = d - k
120 hs = head[s]
121 while dep[s] - dep[hs] < k:
122 k -= dep[s] - dep[hs] + 1
123 s = par[hs]
124 hs = head[s]
125 return self.hld[self.nodein[s] - k]
126
127 def lca(self, u: int, v: int) -> int:
128 """``u``, ``v`` の LCA を返します。
129 :math:`O(\\log{n})` です。
130 """
131 nodein, head, par = self.nodein, self.head, self.par
132 while True:
133 if nodein[u] > nodein[v]:
134 u, v = v, u
135 if head[u] == head[v]:
136 return u
137 v = par[head[v]]
138
139 def dist(self, u: int, v: int) -> int:
140 return self.dep[u] + self.dep[v] - 2 * self.dep[self.lca(u, v)]
141
142 def is_on_path(self, u: int, v: int, a: int) -> bool:
143 """Return True if (a is on path(u - v)) else False. / O(logN)"""
144 return self.dist(u, a) + self.dist(a, v) == self.dist(u, v)
仕様¶
- class HLD(G: list[list[int]], root: int)[source]¶
Bases:
object
- build_list(a: list[Any]) list[Any] [source]¶
hld配列
を基にインデックスを振りなおします。非破壊的です。 \(O(n)\) です。- Parameters:
a (list[Any]) – 元の配列です。
- Returns:
振りなおし後の配列です。
- Return type:
list[Any]
- for_each_vertex_path(u: int, v: int) Iterator[tuple[int, int]] [source]¶
u-v
パスに対応する区間のインデックスを返します。 \(O(\log{n})\) です。
- for_each_vertex_subtree(v: int) Iterator[tuple[int, int]] [source]¶
頂点
v
の部分木に対応する区間のインデックスを返します。 \(O(1)\) です。