class RootedTree:
    """Rooted-tree utilities over an undirected adjacency list.

    The constructor runs one stack-based traversal from the root to compute
    per-vertex depth and a visiting order.  Everything else (children,
    parents, subtree sizes, diameter, 2-colouring, LCA table) is derived from
    those two arrays, either eagerly (``cp`` / ``lca`` flags) or lazily on
    first request, and then cached on the instance.

    Vertices are expected to be the integers ``0 .. len(_G) - 1``.  Getter
    methods return the cached internal lists themselves, not copies.
    """

    def __init__(
        self, _G: list[list[int]], _root: int, cp: bool = False, lca: bool = False
    ):
        """Build the tree rooted at ``_root``.

        Args:
            _G: Adjacency list; ``_G[v]`` lists the neighbours of ``v``.
            _root: Index of the root vertex.
            cp: If True, compute child / parent tables immediately.
            lca: If True, build the binary-lifting table needed by
                ``get_lca``, ``get_dist``, ``is_on_path`` and ``get_path``.
        """
        self._n: int = len(_G)
        self._G: list[list[int]] = _G
        self._root: int = _root
        self._height: int = -1  # -1 means "not computed yet"
        self._toposo: list[int] = []
        self._dist: list[int] = []
        self._descendant_num: list[int] = []
        self._child: list[list[int]] = []
        self._child_num: list[int] = []
        self._parents: list[int] = []
        self._diameter: tuple[int, int, int] = (-1, -1, -1)
        self._bipartite_graph: list[int] = []
        self._cp = cp
        self._lca = lca
        # Number of binary-lifting levels: any depth is < n < 2 ** K.
        self._K: int = self._n.bit_length()
        # The K x N table is only allocated by _calc_doubling(), so trees
        # created with lca=False do not pay O(N log N) memory for nothing.
        self._doubling: list[list[int]] = []
        self._calc_dist_toposo()
        if cp:
            self._calc_child_parents()
        if lca:
            self._calc_doubling()

    def __str__(self):
        """Return a multi-line listing of vertices ordered by depth."""
        self._calc_child_parents()
        ret = ["<RootedTree> ["]
        ret.extend(
            [
                f" dist:{str(d).zfill(2)} - v:{str(i).zfill(2)} - p:{str(self._parents[i]).zfill(2)} - child:{sorted(self._child[i])}"
                for i, d in sorted(enumerate(self._dist), key=lambda x: x[1])
            ]
        )
        ret.append("]")
        return "\n".join(ret)

    def _calc_dist_toposo(self) -> None:
        """Compute depth of every vertex and the visiting order. / O(N)

        Called directly from ``__init__``.  A vertex is appended to the order
        when it is first discovered, so every parent precedes its children.
        Vertices not reachable from the root keep depth ``-1``.
        """
        graph, root = self._G, self._root
        dist = [-1] * self._n
        dist[root] = 0
        order = [root]
        stack = [root]
        while stack:
            v = stack.pop()
            next_depth = dist[v] + 1
            for x in graph[v]:
                if dist[x] != -1:
                    continue  # already discovered (this is v's parent)
                dist[x] = next_depth
                stack.append(x)
                order.append(x)
        self._dist = dist
        self._toposo = order

    def _calc_child_parents(self) -> None:
        """Compute child lists, child counts and parents. / O(N)

        No-op when the tables are already populated.  The root (and any
        unreachable vertex) gets parent ``-1``.
        """
        if self._child and self._child_num and self._parents:
            return
        graph, dist = self._G, self._dist
        child_num = [0] * self._n
        child: list[list[int]] = [[] for _ in range(self._n)]
        parents = [-1] * self._n
        for v in self._toposo[::-1]:
            for x in graph[v]:
                if dist[x] < dist[v]:
                    # The only shallower neighbour is the parent.
                    parents[v] = x
                    continue
                child[v].append(x)
                child_num[v] += 1
        self._child_num = child_num
        self._child = child
        self._parents = parents

    def get_dists(self) -> list[int]:
        """Return the list of depths (distance from root). / O(1)"""
        return self._dist

    def get_toposo(self) -> list[int]:
        """Return the visiting order in which parents precede children. / O(1)"""
        return self._toposo

    def get_height(self) -> int:
        """Return the maximum depth. / O(N) first call, cached afterwards"""
        if self._height > -1:
            return self._height
        self._height = max(self._dist)
        return self._height

    def get_descendant_num(self) -> list[int]:
        """Return, per vertex, the number of proper descendants. / O(N)"""
        if self._descendant_num:
            return self._descendant_num
        graph, dist = self._G, self._dist
        # Start from subtree sizes that include the vertex itself.
        size = [1] * self._n
        # Reverse visiting order handles children before their parent.
        for v in self._toposo[::-1]:
            for x in graph[v]:
                if dist[x] < dist[v]:
                    continue  # skip the parent
                size[v] += size[x]
        # Exclude the vertex itself from its own count.
        self._descendant_num = [s - 1 for s in size]
        return self._descendant_num

    def get_child(self) -> list[list[int]]:
        """Return the child list of every vertex. / O(N) first call"""
        if self._child:
            return self._child
        self._calc_child_parents()
        return self._child

    def get_child_num(self) -> list[int]:
        """Return the number of children of every vertex. / O(N) first call"""
        if self._child_num:
            return self._child_num
        self._calc_child_parents()
        return self._child_num

    def get_parents(self) -> list[int]:
        """Return the parent of every vertex (``-1`` for the root). / O(N) first call"""
        if self._parents:
            return self._parents
        self._calc_child_parents()
        return self._parents

    def get_diameter(self) -> tuple[int, int, int]:
        """Return ``(diameter, start, stop)`` of the tree. / O(N)

        ``start`` is the first deepest vertex from the root; ``stop`` is the
        first vertex farthest from ``start``.
        """
        if self._diameter[0] > -1:
            return self._diameter
        s = self._dist.index(self.get_height())
        stack = [s]
        ndist = [-1] * self._n
        ndist[s] = 0
        while stack:
            v = stack.pop()
            d = ndist[v]
            for x in self._G[v]:
                if ndist[x] != -1:
                    continue
                ndist[x] = d + 1
                stack.append(x)
        diameter = max(ndist)
        t = ndist.index(diameter)
        self._diameter = (diameter, s, t)
        return self._diameter

    def get_bipartite_graph(self) -> list[int]:
        """Return a 2-colouring of the vertices. / O(N)

        The root is coloured ``1`` and colours alternate with depth, so a
        vertex gets ``1`` at even depth and ``0`` at odd depth.  Vertices not
        reachable from the root are ``-1``.
        """
        if self._bipartite_graph:
            return self._bipartite_graph
        # The colour is just the parity of the depth computed in __init__,
        # so no second traversal is required.
        self._bipartite_graph = [-1 if d < 0 else 1 - (d & 1) for d in self._dist]
        return self._bipartite_graph

    def _calc_doubling(self) -> None:
        """Build the binary-lifting ancestor table. / O(NlogN)

        ``_doubling[k][v]`` is the ``2**k``-th ancestor of ``v`` or ``-1``
        when that ancestor does not exist.
        """
        if not self._parents:
            self._calc_child_parents()
        table = [list(self._parents)]
        for _ in range(self._K - 1):
            prev = table[-1]
            table.append([-1 if p < 0 else prev[p] for p in prev])
        self._doubling = table

    def get_lca(self, u: int, v: int) -> int:
        """Return the lowest common ancestor of ``u`` and ``v``. / O(logN)

        Raises:
            AssertionError: If the tree was built with ``lca=False``.
        """
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; otherwise the query would run on an empty table.
        if not self._lca:
            raise AssertionError(
                f"Error: {self.__class__.__name__}.get_lca({u}, {v}), `lca` must be True"
            )
        doubling, dist = self._doubling, self._dist
        if dist[u] < dist[v]:
            u, v = v, u
        # Lift the deeper vertex until both are at the same depth.
        diff = dist[u] - dist[v]
        for k in range(self._K):
            if diff >> k & 1:
                u = doubling[k][u]
        if u == v:
            return u
        # Lift both as far as possible while their ancestors still differ.
        for k in range(self._K - 1, -1, -1):
            if doubling[k][u] != doubling[k][v]:
                u = doubling[k][u]
                v = doubling[k][v]
        return doubling[0][u]

    def get_dist(self, u: int, v: int, vertex: bool = False) -> int:
        """Return the distance between ``u`` and ``v``. / O(logN)

        Counts edges by default; with ``vertex=True`` one is added, giving
        the number of vertices on the path.  Requires ``lca=True``.
        """
        return (
            self._dist[u] + self._dist[v] - 2 * self._dist[self.get_lca(u, v)] + vertex
        )

    def is_on_path(self, u: int, v: int, a: int) -> bool:
        """Return True if ``a`` lies on the path ``u - v``. / O(logN)"""
        return self.get_dist(u, a) + self.get_dist(a, v) == self.get_dist(u, v)

    def get_path(self, u: int, v: int) -> list[int]:
        """Return the vertices on the path ``u -> v``, inclusive. / O(logN + |path|)

        Raises:
            AssertionError: If the tree was built with ``lca=False``.
        """
        if not self._lca:
            raise AssertionError(
                f"{self.__class__.__name__}.get_path(), `lca` must be True"
            )
        if u == v:
            return [u]
        parents = self.get_parents()
        lca = self.get_lca(u, v)

        # Climb from u up to (but excluding) the LCA.
        up = []
        x = u
        while x != lca:
            up.append(x)
            x = parents[x]

        # Climb from v likewise; reversed, this is the descent from the LCA.
        down = []
        x = v
        while x != lca:
            down.append(x)
            x = parents[x]

        up.append(lca)
        up.extend(reversed(down))
        return up

    def dfs_in_out(self) -> tuple[list[int], list[int]]:
        """Return DFS entry and exit timestamps for every vertex. / O(N)

        A single counter advances on every entry and every exit event, so
        the timestamps range over ``0 .. 2N - 1`` when all vertices are
        reachable.  Unreached vertices keep ``-1`` in both lists.
        """
        curtime = -1
        # ~v (always negative) on the stack marks "leave v".
        todo = [~self._root, self._root]
        nodein = [-1] * self._n
        nodeout = [-1] * self._n
        if not self._parents:
            self._calc_child_parents()
        graph, parents = self._G, self._parents
        while todo:
            curtime += 1
            v = todo.pop()
            if v >= 0:
                nodein[v] = curtime
                for x in graph[v]:
                    if parents[v] != x:
                        todo.append(~x)
                        todo.append(x)
            else:
                nodeout[~v] = curtime
        return nodein, nodeout