# from titan_pylib.graph.rooted_tree import RootedTree
class RootedTree:
    """Rooted view of a tree given as an undirected adjacency list.

    Distances from the root and a parent-before-child vertex order are
    computed eagerly in ``__init__``; every other quantity is computed on
    first request and cached on the instance.  LCA-based queries
    (``get_lca``, ``get_dist``, ``is_on_path``, ``get_path``) require the
    instance to be built with ``lca=True``.

    Getter methods return the cached lists themselves, not copies.
    """

    def __init__(
        self, _G: list[list[int]], _root: int, cp: bool = False, lca: bool = False
    ):
        """Build the rooted tree.

        Args:
            _G: Adjacency list; ``_G[v]`` lists the neighbours of vertex ``v``.
            _root: Vertex used as the root.
            cp: If True, compute child/parent tables immediately.
            lca: If True, build the doubling table needed by the LCA queries.
        """
        self._n: int = len(_G)
        self._G: list[list[int]] = _G
        self._root: int = _root
        self._height: int = -1
        self._toposo: list[int] = []
        self._dist: list[int] = []
        self._descendant_num: list[int] = []
        self._child: list[list[int]] = []
        self._child_num: list[int] = []
        self._parents: list[int] = []
        self._diameter: tuple[int, int, int] = (-1, -1, -1)
        self._bipartite_graph: list[int] = []
        self._cp = cp
        self._lca = lca
        self._K: int = self._n.bit_length()
        # The K x N doubling table is only needed for LCA queries, so it is
        # allocated by _calc_doubling() instead of unconditionally here.
        self._doubling: list[list[int]] = []
        self._calc_dist_toposo()
        if cp:
            self._calc_child_parents()
        if lca:
            self._calc_doubling()

    def __str__(self):
        """Return a multi-line dump: one line per vertex, ordered by depth."""
        self._calc_child_parents()
        ret = ["<RootedTree> ["]
        ret.extend(
            [
                f" dist:{str(d).zfill(2)} - v:{str(i).zfill(2)} - p:{str(self._parents[i]).zfill(2)} - child:{sorted(self._child[i])}"
                for i, d in sorted(enumerate(self._dist), key=lambda x: x[1])
            ]
        )
        ret.append("]")
        return "\n".join(ret)

    def _calc_dist_toposo(self) -> None:
        """Calc dist and toposo. / O(N)

        ``_toposo`` lists vertices in discovery order, so every vertex appears
        after the vertex it was discovered from (its parent).
        """
        # Run directly from __init__.
        graph, root = self._G, self._root
        dist = [-1] * self._n
        dist[root] = 0
        order = [root]
        stack = [root]
        while stack:
            v = stack.pop()
            nd = dist[v] + 1
            for x in graph[v]:
                if dist[x] != -1:
                    continue  # already discovered
                dist[x] = nd
                stack.append(x)
                order.append(x)
        self._dist = dist
        self._toposo = order

    def _calc_child_parents(self) -> None:
        """Calc child and parents. / O(N)

        The parent of the root stays ``-1``.  Does nothing when the tables
        have already been filled.
        """
        if self._child and self._child_num and self._parents:
            return
        graph, dist = self._G, self._dist
        child_num = [0] * self._n
        child = [[] for _ in range(self._n)]
        parents = [-1] * self._n
        for v in self._toposo[::-1]:
            for x in graph[v]:
                if dist[x] < dist[v]:
                    # The only neighbour closer to the root is the parent.
                    parents[v] = x
                    continue
                child[v].append(x)
                child_num[v] += 1
        self._child_num = child_num
        self._child = child
        self._parents = parents

    def get_dists(self) -> list[int]:
        """Return the list of distances from the root. / O(1)"""
        return self._dist

    def get_toposo(self) -> list[int]:
        """Return vertices in parent-before-child order. / O(1)"""
        return self._toposo

    def get_height(self) -> int:
        """Return height (max distance from the root). / O(N), cached"""
        if self._height > -1:
            return self._height
        self._height = max(self._dist)
        return self._height

    def get_descendant_num(self) -> list[int]:
        """Return descendant_num (vertex itself excluded). / O(N), cached"""
        if self._descendant_num:
            return self._descendant_num
        graph, dist = self._G, self._dist
        descendant_num = [0] * self._n
        # Reverse discovery order visits children before their parent.
        for v in self._toposo[::-1]:
            for x in graph[v]:
                if dist[x] < dist[v]:
                    continue  # x is the parent of v
                # Child x contributes itself plus all of its descendants.
                descendant_num[v] += descendant_num[x] + 1
        self._descendant_num = descendant_num
        return self._descendant_num

    def get_child(self) -> list[list[int]]:
        """Return child lists per vertex. / O(N) first call, then O(1)"""
        if self._child:
            return self._child
        self._calc_child_parents()
        return self._child

    def get_child_num(self) -> list[int]:
        """Return child_num per vertex. / O(N) first call, then O(1)"""
        if self._child_num:
            return self._child_num
        self._calc_child_parents()
        return self._child_num

    def get_parents(self) -> list[int]:
        """Return parents (-1 for the root). / O(N) first call, then O(1)"""
        if self._parents:
            return self._parents
        self._calc_child_parents()
        return self._parents

    def get_diameter(self) -> tuple[int, int, int]:
        """Return diameter of tree. (diameter, start, stop) / O(N), cached

        ``start`` is a vertex farthest from the root and ``stop`` is a vertex
        farthest from ``start``.
        """
        if self._diameter[0] > -1:
            return self._diameter
        s = self._dist.index(self.get_height())
        stack = [s]
        ndist = [-1] * self._n
        ndist[s] = 0
        while stack:
            v = stack.pop()
            nd = ndist[v] + 1
            for x in self._G[v]:
                if ndist[x] != -1:
                    continue
                ndist[x] = nd
                stack.append(x)
        diameter = max(ndist)
        t = ndist.index(diameter)
        self._diameter = (diameter, s, t)
        return self._diameter

    def get_bipartite_graph(self) -> list[int]:
        """Return a 2-colouring: 1 at even depth (root included), 0 at odd depth.

        Vertices without a computed depth keep -1. / O(N), cached
        """
        if self._bipartite_graph:
            return self._bipartite_graph
        # The colour alternates along every edge, so it is the parity of the
        # depth already stored in _dist; no second traversal is needed.
        self._bipartite_graph = [
            -1 if d < 0 else 1 - (d & 1) for d in self._dist
        ]
        return self._bipartite_graph

    def _calc_doubling(self) -> None:
        """Calc doubling if self._lca. / O(NlogN)

        ``_doubling[k][v]`` is the ancestor of ``v`` that is ``2**k`` steps
        up, or -1 when that would go above the root.
        """
        if not self._parents:
            self._calc_child_parents()
        doubling = [self._parents[:]]
        for _ in range(self._K - 1):
            prev = doubling[-1]
            doubling.append([-1 if p < 0 else prev[p] for p in prev])
        self._doubling = doubling

    def get_lca(self, u: int, v: int) -> int:
        """Return LCA of (u, v). / O(logN)

        Requires the instance to have been built with ``lca=True``.
        """
        assert (
            self._lca
        ), f"Error: {self.__class__.__name__}.get_lca({u}, {v}), `lca` must be True"
        doubling, dist = self._doubling, self._dist
        if dist[u] < dist[v]:
            u, v = v, u
        # Lift the deeper vertex u to the depth of v.
        diff = dist[u] - dist[v]
        for k in range(self._K):
            if diff >> k & 1:
                u = doubling[k][u]
        if u == v:
            return u
        # Lift both while their ancestors differ; they end just below the LCA.
        for k in range(self._K - 1, -1, -1):
            if doubling[k][u] != doubling[k][v]:
                u = doubling[k][u]
                v = doubling[k][v]
        return doubling[0][u]

    def get_dist(self, u: int, v: int, vertex: bool = False) -> int:
        """Return dist(u -> v). / O(logN)

        Counts edges; with ``vertex=True`` the result is one larger (the
        number of vertices on the path).  Requires ``lca=True``.
        """
        return (
            self._dist[u] + self._dist[v] - 2 * self._dist[self.get_lca(u, v)] + vertex
        )

    def is_on_path(self, u: int, v: int, a: int) -> bool:
        """Return True if (a is on path(u - v)) else False. / O(logN)"""
        return self.get_dist(u, a) + self.get_dist(a, v) == self.get_dist(u, v)

    def get_path(self, u: int, v: int) -> list[int]:
        """Return path (u -> v), both endpoints included. / O(logN + |path|)

        Requires the instance to have been built with ``lca=True``.
        """
        assert self._lca, f"{self.__class__.__name__}.get_path(), `lca` must be True"
        if u == v:
            return [u]
        parents = self.get_parents()

        def strictly_between(start: int, anc: int) -> list[int]:
            # Vertices met while climbing from start to anc, both excluded.
            inner = []
            while start != anc:
                start = parents[start]
                if start == anc:
                    break
                inner.append(start)
            return inner

        lca = self.get_lca(u, v)
        path = [u]
        path.extend(strictly_between(u, lca))
        if u != lca and v != lca:
            # When the LCA is an endpoint it is already covered by u or v.
            path.append(lca)
        path.extend(strictly_between(v, lca)[::-1])
        path.append(v)
        return path

    def dfs_in_out(self) -> tuple[list[int], list[int]]:
        """Return (nodein, nodeout) DFS timestamps per vertex. / O(N)

        Entering and leaving a vertex each consume one tick of a shared
        counter that starts at 0.
        """
        curtime = -1
        # A non-negative entry means "enter v"; ~v (negative) means "leave v".
        todo = [~self._root, self._root]
        nodein = [-1] * self._n
        nodeout = [-1] * self._n
        if not self._parents:
            self._calc_child_parents()
        graph, parents = self._G, self._parents
        while todo:
            curtime += 1
            v = todo.pop()
            if v >= 0:
                nodein[v] = curtime
                for x in graph[v]:
                    if parents[v] != x:
                        todo.append(~x)
                        todo.append(x)
            else:
                nodeout[~v] = curtime
        return nodein, nodeout