# Source code for titan_pylib.graph.rooted_tree

class RootedTree:
    """Rooted tree given as an undirected adjacency list.

    Depths and a parents-before-children vertex order are computed eagerly in
    ``__init__``. Everything else (children/parents, subtree sizes, diameter,
    doubling table for LCA, ...) is computed on first use and cached.
    Getters return the internal cached lists, not copies; do not mutate them.
    """

    def __init__(
        self, _G: list[list[int]], _root: int, cp: bool = False, lca: bool = False
    ):
        """
        Args:
            _G: Adjacency list of an undirected tree with ``len(_G)`` vertices.
                Stored by reference (not copied).
            _root: Root vertex.
            cp: If True, compute children/parents eagerly.
            lca: If True, build the doubling table for LCA queries eagerly.
                Otherwise it is built on the first LCA-based query.
        """
        self._n: int = len(_G)
        self._G: list[list[int]] = _G
        self._root: int = _root
        self._height: int = -1
        self._toposo: list[int] = []
        self._dist: list[int] = []
        self._descendant_num: list[int] = []
        self._child: list[list[int]] = []
        self._child_num: list[int] = []
        self._parents: list[int] = []
        self._diameter: tuple[int, int, int] = (-1, -1, -1)
        self._bipartite_graph: list[int] = []
        self._cp = cp
        self._lca = lca
        # K doubling levels cover any depth difference <= n - 1.
        self._K: int = self._n.bit_length()
        # _doubling[k][v] = 2**k-th ancestor of v (-1 if none); empty until built.
        self._doubling: list[list[int]] = []
        self._calc_dist_toposo()
        if cp:
            self._calc_child_parents()
        if lca:
            self._calc_doubling()

    def __str__(self) -> str:
        """Return a multi-line dump with one row per vertex, ordered by depth."""
        self._calc_child_parents()
        ret = ["<RootedTree> ["]
        ret.extend(
            f" dist:{str(d).zfill(2)} - v:{str(i).zfill(2)} - p:{str(self._parents[i]).zfill(2)} - child:{sorted(self._child[i])}"
            for i, d in sorted(enumerate(self._dist), key=lambda x: x[1])
        )
        ret.append("]")
        return "\n".join(ret)

    def _calc_dist_toposo(self) -> None:
        """Compute the depth of every vertex and a parents-first order. / O(N)

        Called directly from ``__init__``. Uses an explicit stack instead of
        recursion. A vertex is appended to ``_toposo`` when first discovered,
        which is always after its parent.
        """
        _G, _root = self._G, self._root
        _dist = [-1] * self._n
        _dist[_root] = 0
        _toposo = [_root]
        todo = [_root]
        while todo:
            v = todo.pop()
            d = _dist[v]
            for x in _G[v]:
                if _dist[x] != -1:
                    continue
                _dist[x] = d + 1
                todo.append(x)
                _toposo.append(x)
        self._dist = _dist
        self._toposo = _toposo

    def _calc_child_parents(self) -> None:
        """Compute children, child counts and parents (root's parent is -1). / O(N)

        No-op if already computed.
        """
        if self._child and self._child_num and self._parents:
            return
        _G, _dist = self._G, self._dist
        _child: list[list[int]] = [[] for _ in range(self._n)]
        _parents = [-1] * self._n
        for v in self._toposo:
            for x in _G[v]:
                # In a tree, the only shallower neighbour is the parent.
                if _dist[x] < _dist[v]:
                    _parents[v] = x
                else:
                    _child[v].append(x)
        self._child_num = [len(c) for c in _child]
        self._child = _child
        self._parents = _parents

    def get_dists(self) -> list[int]:
        """Return the depth (distance from the root) of every vertex. / O(1)"""
        return self._dist

    def get_toposo(self) -> list[int]:
        """Return the vertices in an order where every parent precedes its children. / O(1)"""
        return self._toposo

    def get_height(self) -> int:
        """Return the height of the tree (maximum depth). / O(N) first call"""
        if self._height > -1:
            return self._height
        self._height = max(self._dist)
        return self._height

    def get_descendant_num(self) -> list[int]:
        """Return the number of proper descendants of every vertex. / O(N)

        ``get_descendant_num()[v]`` equals the subtree size of ``v`` minus one.
        """
        if self._descendant_num:
            return self._descendant_num
        _G, _dist = self._G, self._dist
        size = [1] * self._n
        # Children before parents, so each child's subtree size is final when read.
        for v in reversed(self._toposo):
            for x in _G[v]:
                if _dist[x] < _dist[v]:
                    continue
                size[v] += size[x]
        self._descendant_num = [s - 1 for s in size]
        return self._descendant_num

    def get_child(self) -> list[list[int]]:
        """Return the children of every vertex. / O(N) first call"""
        self._calc_child_parents()
        return self._child

    def get_child_num(self) -> list[int]:
        """Return the number of children of every vertex. / O(N) first call"""
        self._calc_child_parents()
        return self._child_num

    def get_parents(self) -> list[int]:
        """Return the parent of every vertex (-1 for the root). / O(N) first call"""
        self._calc_child_parents()
        return self._parents

    def get_diameter(self) -> tuple[int, int, int]:
        """Return (diameter, start, stop) of the tree. / O(N)

        The diameter is measured in edges. ``start`` is a deepest vertex from
        the root and ``stop`` is a vertex farthest from ``start``.
        """
        if self._diameter[0] > -1:
            return self._diameter
        s = self._dist.index(self.get_height())
        todo = [s]
        ndist = [-1] * self._n
        ndist[s] = 0
        while todo:
            v = todo.pop()
            d = ndist[v]
            for x in self._G[v]:
                if ndist[x] != -1:
                    continue
                ndist[x] = d + 1
                todo.append(x)
        diameter = max(ndist)
        t = ndist.index(diameter)
        self._diameter = (diameter, s, t)
        return self._diameter

    def get_bipartite_graph(self) -> list[int]:
        """Return a 2-coloring of the vertices by depth parity. / O(N)

        The value is 1 for vertices at even depth (including the root) and 0
        for vertices at odd depth.
        """
        if self._bipartite_graph:
            return self._bipartite_graph
        color = [-1] * self._n
        color[self._root] = 1
        todo = [self._root]
        while todo:
            v = todo.pop()
            nc = 0 if color[v] else 1
            for x in self._G[v]:
                if color[x] != -1:
                    continue
                color[x] = nc
                todo.append(x)
        self._bipartite_graph = color
        return self._bipartite_graph

    def _calc_doubling(self) -> None:
        """Build the binary-lifting (doubling) table used by LCA queries. / O(NlogN)"""
        self._calc_child_parents()
        _doubling = [self._parents[:]]
        for _ in range(self._K - 1):
            prev = _doubling[-1]
            # 2**(k+1)-th ancestor = 2**k-th ancestor of the 2**k-th ancestor.
            _doubling.append([-1 if p < 0 else prev[p] for p in prev])
        self._doubling = _doubling

    def get_lca(self, u: int, v: int) -> int:
        """Return the lowest common ancestor of (u, v). / O(logN)

        The doubling table is built on the first call (O(NlogN)) when the tree
        was not constructed with ``lca=True``.
        """
        if not self._doubling:
            self._calc_doubling()
        _doubling, _dist = self._doubling, self._dist
        if _dist[u] < _dist[v]:
            u, v = v, u
        # Lift the deeper vertex to the depth of the shallower one.
        _r = _dist[u] - _dist[v]
        for k in range(self._K):
            if _r >> k & 1:
                u = _doubling[k][u]
        if u == v:
            return u
        # Lift both while their ancestors differ; they end just below the LCA.
        for k in range(self._K - 1, -1, -1):
            if _doubling[k][u] != _doubling[k][v]:
                u = _doubling[k][u]
                v = _doubling[k][v]
        return _doubling[0][u]

    def get_dist(self, u: int, v: int, vertex: bool = False) -> int:
        """Return dist(u -> v) in edges, or in vertices (edges + 1) if ``vertex``. / O(logN)"""
        return (
            self._dist[u] + self._dist[v] - 2 * self._dist[self.get_lca(u, v)] + vertex
        )

    def is_on_path(self, u: int, v: int, a: int) -> bool:
        """Return True if a lies on the path u - v (endpoints included). / O(logN)"""
        return self.get_dist(u, a) + self.get_dist(a, v) == self.get_dist(u, v)

    def get_path(self, u: int, v: int) -> list[int]:
        """Return the vertices on the path u -> v, both endpoints included. / O(logN + |path|)"""
        if u == v:
            return [u]
        self.get_parents()

        def get_path_lca(u: int, v: int) -> list[int]:
            # Vertices strictly between u and its ancestor v, walking upward.
            path = []
            while u != v:
                u = self._parents[u]
                if u == v:
                    break
                path.append(u)
            return path

        lca = self.get_lca(u, v)
        path = [u]
        path.extend(get_path_lca(u, lca))
        if u != lca and v != lca:
            path.append(lca)
        path.extend(get_path_lca(v, lca)[::-1])
        path.append(v)
        return path

    def dfs_in_out(self) -> tuple[list[int], list[int]]:
        """Return (nodein, nodeout) timestamps of an iterative DFS from the root. / O(N)

        Entering and leaving a vertex each consume one tick, so all timestamps
        lie in [0, 2N). Visits nest, so ``v`` is in the subtree of ``u`` iff
        ``nodein[u] <= nodein[v]`` and ``nodeout[v] <= nodeout[u]``.
        """
        curtime = -1
        # A non-negative entry means "enter v"; ~v (negative) means "leave v".
        todo = [~self._root, self._root]
        nodein = [-1] * self._n
        nodeout = [-1] * self._n
        self._calc_child_parents()
        _G, _parents = self._G, self._parents
        while todo:
            curtime += 1
            v = todo.pop()
            if v >= 0:
                nodein[v] = curtime
                for x in _G[v]:
                    if _parents[v] != x:
                        todo.append(~x)
                        todo.append(x)
            else:
                nodeout[~v] = curtime
        return nodein, nodeout