class WeightedRootedTree:
    """Rooted tree with weighted edges.

    ``_G[v]`` is a list of ``(neighbor, weight)`` pairs. Parent detection
    looks for the neighbor of smaller rank in ``_G[v]``, so every edge is
    expected to be listed in both endpoints' adjacency lists. Derived data
    (children, parents, descendants, diameter, ...) is computed lazily and
    cached on the instance.

    Args:
        _G: adjacency list of the tree, ``_G[v] = [(x, weight), ...]``.
        _root: index of the root vertex.
        cp: if True, compute children / parents eagerly in ``__init__``.
        lca: if True, build the binary-lifting table; required by
            ``get_lca``, ``get_dist``, ``get_path`` and ``is_on_path``.
    """

    def __init__(
        self,
        _G: list[list[tuple[int, int]]],
        _root: int,
        cp: bool = False,
        lca: bool = False,
    ):
        self._n = len(_G)
        self._G = _G
        self._root = _root
        self._height = -1
        self._toposo: list[int] = []
        self._dist: list[int] = []
        self._descendant_num: list[int] = []
        self._child: list[list[int]] = []
        self._child_num: list[int] = []
        self._parents: list[int] = []
        self._diameter = (-1, -1, -1)
        self._bipartite_graph: list[int] = []
        self._cp = cp
        self._lca = lca
        self._rank: list[int] = []
        # Smallest K >= 1 with 2**K >= n. Every rank (edge depth) is at most
        # n - 1, so K doubling levels can lift a vertex to any ancestor.
        K = 1
        while 1 << K < self._n:
            K += 1
        self._K = K
        # The doubling table is O(N log N); it is only built when LCA
        # queries were requested (see _calc_doubling).
        self._doubling: list[list[int]] = []
        self._calc_dist_toposo()
        if cp:
            self._calc_child_parents()
        if lca:
            self._calc_doubling()

    def __str__(self):
        self._calc_child_parents()
        ret = ["<WeightedRootedTree> ["]
        ret.extend(
            [
                f" dist:{str(d).zfill(2)} - v:{str(i).zfill(2)} - p:{str(self._parents[i]).zfill(2)} - child:{sorted(self._child[i])}"
                for i, d in sorted(enumerate(self._dist), key=lambda x: x[1])
            ]
        )
        ret.append("]")
        return "\n".join(ret)

    def _calc_dist_toposo(self) -> None:
        """Calc dist, rank and toposo. / O(N)

        Iterative DFS from the root. ``_dist`` is the weighted distance from
        the root, ``_rank`` the number of edges from the root (-1 if
        unreached), and ``_toposo`` lists each vertex after its parent.
        """
        # Called directly from __init__.
        _G, _root = self._G, self._root
        _dist = [-1] * self._n
        _rank = [-1] * self._n
        _dist[_root] = 0
        _rank[_root] = 0
        _toposo = [_root]
        todo = [_root]
        while todo:
            v = todo.pop()
            d = _dist[v]
            r = _rank[v]
            for x, c in _G[v]:
                if _dist[x] != -1:
                    continue
                _dist[x] = d + c
                _rank[x] = r + 1
                todo.append(x)
                _toposo.append(x)
        self._dist = _dist
        self._rank = _rank
        self._toposo = _toposo

    def _calc_child_parents(self) -> None:
        """Calc child and parents. / O(N)

        The parent of ``v`` is its neighbor with a smaller rank; every other
        neighbor is a child. The root's parent stays -1.
        """
        if self._child and self._child_num and self._parents:
            return
        _G, _rank = self._G, self._rank
        _child_num = [0] * self._n
        _child = [[] for _ in range(self._n)]
        _parents = [-1] * self._n
        for v in self._toposo[::-1]:
            for x, _ in _G[v]:
                if _rank[x] < _rank[v]:
                    _parents[v] = x
                    continue
                _child[v].append(x)
                _child_num[v] += 1
        self._child_num = _child_num
        self._child = _child
        self._parents = _parents

    def get_dists(self) -> list[int]:
        """Return weighted dist from root for each vertex. / O(N)"""
        return self._dist

    def get_toposo(self) -> list[int]:
        """Return toposo (each vertex appears after its parent). / O(N)"""
        return self._toposo

    def get_height(self) -> int:
        """Return height: the maximum weighted dist from root. / O(N)"""
        if self._height > -1:
            return self._height
        self._height = max(self._dist)
        return self._height

    def get_descendant_num(self) -> list[int]:
        """Return, for each vertex, its number of proper descendants. / O(N)

        The parent is identified by rank (edge count from the root), not by
        weighted distance: across a zero-weight (or negative) edge the
        parent's weighted distance is not strictly smaller than the child's.
        """
        if self._descendant_num:
            return self._descendant_num
        _G, _rank = self._G, self._rank
        _descendant_num = [0] * self._n
        # Reverse topological order: every child is finished before its parent.
        for v in reversed(self._toposo):
            for x, _ in _G[v]:
                if _rank[x] < _rank[v]:  # x is the parent of v
                    continue
                _descendant_num[v] += _descendant_num[x] + 1
        self._descendant_num = _descendant_num
        return self._descendant_num

    def get_child(self) -> list[list[int]]:
        """Return child / O(N)"""
        if self._child:
            return self._child
        self._calc_child_parents()
        return self._child

    def get_child_num(self) -> list[int]:
        """Return child_num. / O(N)"""
        if self._child_num:
            return self._child_num
        self._calc_child_parents()
        return self._child_num

    def get_parents(self) -> list[int]:
        """Return parents (-1 for the root). / O(N)"""
        if self._parents:
            return self._parents
        self._calc_child_parents()
        return self._parents

    def get_diameter(self) -> tuple[int, int, int]:
        """Return diameter of tree. (diameter, start, stop) / O(N)

        Two-sweep method: start from the vertex farthest from the root, then
        take the vertex farthest from it. This relies on non-negative weights.
        """
        if self._diameter[0] > -1:
            return self._diameter
        s = self._dist.index(self.get_height())
        todo = [s]
        ndist = [-1] * self._n
        ndist[s] = 0
        while todo:
            v = todo.pop()
            d = ndist[v]
            for x, c in self._G[v]:
                if ndist[x] != -1:
                    continue
                ndist[x] = d + c
                todo.append(x)
        diameter = max(ndist)
        t = ndist.index(diameter)
        self._diameter = (diameter, s, t)
        return self._diameter

    def get_bipartite_graph(self) -> list[int]:
        """Return a 2-coloring: 1 for the root and even-depth vertices, else 0. / O(N)"""
        if self._bipartite_graph:
            return self._bipartite_graph
        self._bipartite_graph = [-1] * self._n
        _bipartite_graph = self._bipartite_graph
        _bipartite_graph[self._root] = 1
        todo = [self._root]
        while todo:
            v = todo.pop()
            nc = 0 if _bipartite_graph[v] else 1
            for x, _ in self._G[v]:
                if _bipartite_graph[x] != -1:
                    continue
                _bipartite_graph[x] = nc
                todo.append(x)
        return _bipartite_graph

    def _calc_doubling(self) -> None:
        """Build the binary-lifting table used by get_lca. / O(NlogN)"""
        if not self._parents:
            self._calc_child_parents()
        # doubling[k][v] is the 2**k-th ancestor of v, or -1 if it does not exist.
        doubling = [list(self._parents)]
        for _ in range(self._K - 1):
            prev = doubling[-1]
            doubling.append([-1 if p < 0 else prev[p] for p in prev])
        self._doubling = doubling

    def get_lca(self, u: int, v: int) -> int:
        """Return LCA of (u, v). / O(logN)"""
        assert self._lca, f"{self.__class__.__name__}.get_lca(), `lca` must be True"
        _doubling, _rank = self._doubling, self._rank
        if _rank[u] < _rank[v]:
            u, v = v, u
        # Lift the deeper vertex to the same rank as the shallower one.
        _r = _rank[u] - _rank[v]
        for k in range(self._K):
            if _r >> k & 1:
                u = _doubling[k][u]
        if u == v:
            return u
        # Lift both while their ancestors differ; they end just below the LCA.
        for k in range(self._K - 1, -1, -1):
            if _doubling[k][u] != _doubling[k][v]:
                u = _doubling[k][u]
                v = _doubling[k][v]
        return _doubling[0][u]

    def get_dist(self, u: int, v: int) -> int:
        """Return weighted dist(u - v). Requires ``lca=True``. / O(logN)"""
        return self._dist[u] + self._dist[v] - 2 * self._dist[self.get_lca(u, v)]

    def is_on_path(self, u: int, v: int, a: int) -> bool:
        """Return True if (a is on path(u - v)) else False. / O(logN)

        Requires ``lca=True``. Uses rank (unit edge) distances: in a tree,
        d(u, a) + d(a, v) == d(u, v) holds exactly when a lies on the u-v
        path, provided every edge has positive length. Weighted distances do
        not guarantee that (zero-weight edges), so ranks are used instead.
        """
        _rank = self._rank

        def rank_dist(x: int, y: int) -> int:
            # Number of edges on the path x - y.
            return _rank[x] + _rank[y] - 2 * _rank[self.get_lca(x, y)]

        return rank_dist(u, a) + rank_dist(a, v) == rank_dist(u, v)

    def get_path(self, u: int, v: int) -> list[int]:
        """Return path (u -> v) as a list of vertices. / O(logN + |path|)"""
        assert self._lca, f"{self.__class__.__name__}, `lca` must be True"
        if u == v:
            return [u]
        self.get_parents()

        def get_path_lca(u: int, v: int) -> list[int]:
            # Vertices strictly between u and its ancestor v, walking upward.
            path = []
            while u != v:
                u = self._parents[u]
                if u == v:
                    break
                path.append(u)
            return path

        lca = self.get_lca(u, v)
        path = [u]
        path.extend(get_path_lca(u, lca))
        if u != lca and v != lca:
            path.append(lca)
        path.extend(get_path_lca(v, lca)[::-1])
        path.append(v)
        return path

    def dfs_in_out(self) -> tuple[list[int], list[int]]:
        """Return (intime, outtime) of an iterative DFS from the root. / O(N)

        One counter is advanced on both entering and leaving a vertex, so for
        a connected tree the times lie in [0, 2N). A vertex's exit marker is
        pushed below its own entry, so its [intime, outtime] interval encloses
        the intervals of all its descendants.
        """
        curtime = -1
        todo = [~self._root, self._root]
        intime = [-1] * self._n
        outtime = [-1] * self._n
        seen = [False] * self._n
        seen[self._root] = True
        while todo:
            curtime += 1
            v = todo.pop()
            if v >= 0:
                intime[v] = curtime
                for x, _ in self._G[v]:
                    if not seen[x]:
                        todo.append(~x)
                        todo.append(x)
                        seen[x] = True
            else:
                # Negative entries (~x) mark leaving vertex x.
                outtime[~v] = curtime
        return intime, outtime