1# from titan_pylib.graph.weighted_rooted_tree import WeightedRootedTree
class WeightedRootedTree:
    """Rooted tree with weighted edges, given as an adjacency list.

    ``_G[v]`` is a list of ``(neighbor, weight)`` pairs; each edge is expected
    to be listed at both endpoints. Root distances, depths and a topological
    order are computed in ``__init__``; children/parents and the LCA doubling
    table are computed on demand, or eagerly when ``cp`` / ``lca`` are set.
    """

    def __init__(
        self,
        _G: list[list[tuple[int, int]]],
        _root: int,
        cp: bool = False,
        lca: bool = False,
    ):
        """
        Args:
            _G: Adjacency list; ``_G[v]`` holds ``(neighbor, weight)`` pairs.
            _root: Root vertex.
            cp: If True, compute children and parents immediately.
            lca: If True, build the doubling table required by ``get_lca``,
                ``get_dist``, ``get_path`` and ``is_on_path``.
        """
        self._n = len(_G)
        self._G = _G
        self._root = _root
        self._height = -1
        self._toposo: list[int] = []
        self._dist: list[int] = []
        self._descendant_num: list[int] = []
        self._child: list[list[int]] = []
        self._child_num: list[int] = []
        self._parents: list[int] = []
        self._diameter = (-1, -1, -1)
        self._bipartite_graph: list[int] = []
        self._cp = cp
        self._lca = lca
        self._rank: list[int] = []
        # Smallest K >= 1 with 2**K >= n, so every depth difference (< n) fits in K bits.
        K = 1
        while 1 << K < self._n:
            K += 1
        self._K = K
        # Filled by _calc_doubling; left empty when LCA support is not requested
        # to avoid the O(N log N) allocation.
        self._doubling: list[list[int]] = []
        self._calc_dist_toposo()
        if cp:
            self._calc_child_parents()
        if lca:
            self._calc_doubling()

    def __str__(self):
        """Return one line per vertex (sorted by distance) with its parent and children."""
        self._calc_child_parents()
        ret = ["<WeightedRootedTree> ["]
        ret.extend(
            [
                f" dist:{str(d).zfill(2)} - v:{str(i).zfill(2)} - p:{str(self._parents[i]).zfill(2)} - child:{sorted(self._child[i])}"
                for i, d in sorted(enumerate(self._dist), key=lambda x: x[1])
            ]
        )
        ret.append("]")
        return "\n".join(ret)

    def _calc_dist_toposo(self) -> None:
        """Calc dist, rank (depth) and toposo. / O(N)

        Called directly from ``__init__``. ``_toposo`` lists every reached
        vertex after its parent. ``_rank`` doubles as the visited marker: it
        is -1 only for unvisited vertices, whereas a weighted distance may
        legitimately be -1 when edge weights are negative.
        """
        _G, _root = self._G, self._root
        _dist = [-1] * self._n
        _rank = [-1] * self._n
        _dist[_root] = 0
        _rank[_root] = 0
        _toposo = [_root]
        todo = [_root]
        while todo:
            v = todo.pop()
            d = _dist[v]
            r = _rank[v] + 1
            for x, c in _G[v]:
                if _rank[x] != -1:
                    continue
                _dist[x] = d + c
                _rank[x] = r
                todo.append(x)
                # Appended on discovery, after its parent was appended.
                _toposo.append(x)
        self._dist = _dist
        self._rank = _rank
        self._toposo = _toposo

    def _calc_child_parents(self) -> None:
        """Calc child, child_num and parents (root's parent is -1). / O(N)"""
        if self._child and self._child_num and self._parents:
            return
        _G, _rank = self._G, self._rank
        _child_num = [0] * self._n
        _child = [[] for _ in range(self._n)]
        _parents = [-1] * self._n
        for v in reversed(self._toposo):
            for x, _ in _G[v]:
                # The only shallower neighbour of v in a tree is its parent.
                if _rank[x] < _rank[v]:
                    _parents[v] = x
                    continue
                _child[v].append(x)
                _child_num[v] += 1
        self._child_num = _child_num
        self._child = _child
        self._parents = _parents

    def get_dists(self) -> list[int]:
        """Return weighted dist from root for every vertex. / O(1)"""
        return self._dist

    def get_toposo(self) -> list[int]:
        """Return a topological order (each vertex after its parent). / O(1)"""
        return self._toposo

    def get_height(self) -> int:
        """Return the maximum weighted dist from the root (cached). / O(N)"""
        if self._height > -1:
            return self._height
        self._height = max(self._dist)
        return self._height

    def get_descendant_num(self) -> list[int]:
        """Return the number of proper descendants (excluding itself) of each vertex. / O(N)"""
        if self._descendant_num:
            return self._descendant_num
        _G, _rank = self._G, self._rank
        _descendant_num = [1] * self._n
        for v in reversed(self._toposo):
            for x, _ in _G[v]:
                # Compare depths, not weighted distances: with zero or negative
                # edge weights the parent's distance need not be smaller.
                if _rank[x] < _rank[v]:
                    continue
                _descendant_num[v] += _descendant_num[x]
        self._descendant_num = [k - 1 for k in _descendant_num]
        return self._descendant_num

    def get_child(self) -> list[list[int]]:
        """Return child lists. / O(N)"""
        if not self._child:
            self._calc_child_parents()
        return self._child

    def get_child_num(self) -> list[int]:
        """Return the number of children of each vertex. / O(N)"""
        if not self._child_num:
            self._calc_child_parents()
        return self._child_num

    def get_parents(self) -> list[int]:
        """Return parents (-1 for the root). / O(N)"""
        if not self._parents:
            self._calc_child_parents()
        return self._parents

    def get_diameter(self) -> tuple[int, int, int]:
        """Return diameter of tree. (diameter, start, stop) / O(N)

        Double sweep: start from the vertex farthest from the root and take the
        vertex farthest from it. This assumes non-negative edge weights.
        """
        if self._diameter[0] > -1:
            return self._diameter
        s = self._dist.index(self.get_height())
        ndist = [-1] * self._n
        # Explicit visited flags: a distance value is not a safe sentinel.
        seen = [False] * self._n
        ndist[s] = 0
        seen[s] = True
        todo = [s]
        while todo:
            v = todo.pop()
            d = ndist[v]
            for x, c in self._G[v]:
                if seen[x]:
                    continue
                seen[x] = True
                ndist[x] = d + c
                todo.append(x)
        diameter = max(ndist)
        t = ndist.index(diameter)
        self._diameter = (diameter, s, t)
        return self._diameter

    def get_bipartite_graph(self) -> list[int]:
        """Return a 2-coloring: 1 for the root and vertices at even depth, 0 at odd depth. / O(N)"""
        if self._bipartite_graph:
            return self._bipartite_graph
        self._bipartite_graph = [-1] * self._n
        _bipartite_graph = self._bipartite_graph
        _bipartite_graph[self._root] = 1
        todo = [self._root]
        while todo:
            v = todo.pop()
            nc = 0 if _bipartite_graph[v] else 1
            for x, _ in self._G[v]:
                if _bipartite_graph[x] != -1:
                    continue
                _bipartite_graph[x] = nc
                todo.append(x)
        return _bipartite_graph

    def _calc_doubling(self) -> None:
        """Build the doubling table used by LCA queries. / O(NlogN)

        ``_doubling[k][v]`` is the 2**k-th ancestor of ``v``, or -1 if none.
        """
        if not self._parents:
            self._calc_child_parents()
        table = [self._parents[:]]
        for _ in range(self._K - 1):
            prev = table[-1]
            table.append([-1 if p < 0 else prev[p] for p in prev])
        self._doubling = table

    def get_lca(self, u: int, v: int) -> int:
        """Return LCA of (u, v). / O(logN)"""
        assert self._lca, f"{self.__class__.__name__}.get_lca(), `lca` must be True"
        _doubling, _rank = self._doubling, self._rank
        if _rank[u] < _rank[v]:
            u, v = v, u
        # Lift the deeper vertex to the same depth.
        _r = _rank[u] - _rank[v]
        for k in range(self._K):
            if _r >> k & 1:
                u = _doubling[k][u]
        if u == v:
            return u
        # Lift both while their ancestors differ; they end just below the LCA.
        for k in range(self._K - 1, -1, -1):
            if _doubling[k][u] != _doubling[k][v]:
                u = _doubling[k][u]
                v = _doubling[k][v]
        return _doubling[0][u]

    def get_dist(self, u: int, v: int) -> int:
        """Return weighted dist(u - v). / O(logN)"""
        return self._dist[u] + self._dist[v] - 2 * self._dist[self.get_lca(u, v)]

    def is_on_path(self, u: int, v: int, a: int) -> bool:
        """Return True if (a is on path(u - v), endpoints included) else False. / O(logN)

        Uses edge counts (depths) instead of weighted distances, so a vertex
        hanging off the path by a zero-weight edge is not mistaken for a path
        vertex.
        """
        assert self._lca, f"{self.__class__.__name__}.is_on_path(), `lca` must be True"
        _rank = self._rank

        def hops(x: int, y: int) -> int:
            # Number of edges between x and y.
            return _rank[x] + _rank[y] - 2 * _rank[self.get_lca(x, y)]

        return hops(u, a) + hops(a, v) == hops(u, v)

    def get_path(self, u: int, v: int) -> list[int]:
        """Return path (u -> v), both endpoints included. / O(logN + |path|)"""
        assert self._lca, f"{self.__class__.__name__}, `lca` must be True"
        if u == v:
            return [u]
        self.get_parents()

        def get_path_lca(u: int, v: int) -> list[int]:
            # Vertices strictly between u and its ancestor v.
            path = []
            while u != v:
                u = self._parents[u]
                if u == v:
                    break
                path.append(u)
            return path

        lca = self.get_lca(u, v)
        path = [u]
        path.extend(get_path_lca(u, lca))
        if u != lca and v != lca:
            path.append(lca)
        path.extend(get_path_lca(v, lca)[::-1])
        path.append(v)
        return path

    def dfs_in_out(self) -> tuple[list[int], list[int]]:
        """Return (intime, outtime) of an iterative DFS from the root. / O(N)

        Entry and exit events share one counter, so all timestamps are distinct
        and w is in the subtree of v iff intime[v] <= intime[w] < outtime[v].
        """
        curtime = -1
        # A non-negative entry means "enter v"; ~v means "leave v".
        todo = [~self._root, self._root]
        intime = [-1] * self._n
        outtime = [-1] * self._n
        seen = [False] * self._n
        seen[self._root] = True
        while todo:
            curtime += 1
            v = todo.pop()
            if v >= 0:
                intime[v] = curtime
                for x, _ in self._G[v]:
                    if not seen[x]:
                        todo.append(~x)
                        todo.append(x)
                        seen[x] = True
            else:
                outtime[~v] = curtime
        return intime, outtime