Source code for titan_pylib.graph.weighted_rooted_tree

class WeightedRootedTree:
    """Rooted tree with weighted edges and cached per-vertex information.

    The tree is given as an adjacency list ``_G`` where ``_G[v]`` holds
    ``(neighbour, weight)`` pairs, and is rooted at ``_root``.  Weighted
    distances from the root, edge depths (``_rank``) and a parent-before-child
    vertex order are computed in the constructor; everything else is computed
    on first request and cached.

    Args:
        _G: Adjacency list of the tree; each undirected edge appears in both
            endpoint lists.
        _root: Index of the root vertex.
        cp: If True, child / parent tables are built in the constructor.
        lca: If True, the binary-lifting table needed by ``get_lca``,
            ``get_dist``, ``is_on_path`` and ``get_path`` is built in the
            constructor.
    """

    def __init__(
        self,
        _G: list[list[tuple[int, int]]],
        _root: int,
        cp: bool = False,
        lca: bool = False,
    ):
        self._n = len(_G)
        self._G = _G
        self._root = _root
        self._height = -1
        self._toposo = []
        self._dist = []
        self._descendant_num = []
        self._child = []
        self._child_num = []
        self._parents = []
        self._diameter = (-1, -1, -1)
        self._bipartite_graph = []
        self._cp = cp
        self._lca = lca
        self._rank = []
        # Smallest K >= 1 with 2**K >= n: enough bits for any depth difference.
        K = 1
        while 1 << K < self._n:
            K += 1
        self._K = K
        # The K x N lifting table is only built by _calc_doubling(), so trees
        # created with lca=False do not pay O(N log N) memory for it.
        self._doubling = []
        self._calc_dist_toposo()
        if cp:
            self._calc_child_parents()
        if lca:
            self._calc_doubling()

    def __str__(self):
        self._calc_child_parents()
        ret = ["<WeightedRootedTree> ["]
        ret.extend(
            [
                f" dist:{str(d).zfill(2)} - v:{str(i).zfill(2)} - p:{str(self._parents[i]).zfill(2)} - child:{sorted(self._child[i])}"
                for i, d in sorted(enumerate(self._dist), key=lambda x: x[1])
            ]
        )
        ret.append("]")
        return "\n".join(ret)

    def _calc_dist_toposo(self) -> None:
        """Compute weighted distance, edge depth and visiting order. / O(N)"""
        # Executed directly from __init__.
        _G, _root = self._G, self._root
        _dist = [-1] * self._n
        _rank = [-1] * self._n
        _dist[_root] = 0
        _rank[_root] = 0
        _toposo = [_root]
        todo = [_root]
        while todo:
            v = todo.pop()
            d = _dist[v]
            r = _rank[v]
            for x, c in _G[v]:
                # `_rank` is the visited marker: it is >= 0 for every visited
                # vertex regardless of edge weights, unlike `_dist`.
                if _rank[x] != -1:
                    continue
                _dist[x] = d + c
                _rank[x] = r + 1
                todo.append(x)
                _toposo.append(x)
        self._dist = _dist
        self._rank = _rank
        self._toposo = _toposo

    def _calc_child_parents(self) -> None:
        """Compute children, child counts and parents (cached). / O(N)"""
        if self._child and self._child_num and self._parents:
            return
        _G, _rank = self._G, self._rank
        _child_num = [0] * self._n
        _child = [[] for _ in range(self._n)]
        _parents = [-1] * self._n
        for v in self._toposo[::-1]:
            for x, _ in _G[v]:
                # The only neighbour with a smaller edge depth is the parent.
                if _rank[x] < _rank[v]:
                    _parents[v] = x
                    continue
                _child[v].append(x)
                _child_num[v] += 1
        self._child_num = _child_num
        self._child = _child
        self._parents = _parents

    def get_dists(self) -> list[int]:
        """Return the weighted distance from the root for every vertex.

        The list is computed in the constructor; this call is O(1).
        """
        return self._dist

    def get_toposo(self) -> list[int]:
        """Return the vertices in an order where parents precede children.

        The list is computed in the constructor; this call is O(1).
        """
        return self._toposo

    def get_height(self) -> int:
        """Return the largest weighted distance from the root. / O(N) once"""
        if self._height > -1:
            return self._height
        self._height = max(self._dist)
        return self._height

    def get_descendant_num(self) -> list[int]:
        """Return, for every vertex, the number of proper descendants.

        O(N) on the first call, cached afterwards.
        """
        if self._descendant_num:
            return self._descendant_num
        _G, _rank = self._G, self._rank
        _descendant_num = [1] * self._n
        # Reverse visiting order guarantees children are finished first.
        for v in self._toposo[::-1]:
            for x, _ in _G[v]:
                # Identify the parent by edge depth, not by weighted distance:
                # a zero-weight edge gives parent and child the same distance.
                if _rank[x] < _rank[v]:
                    continue
                _descendant_num[v] += _descendant_num[x]
        # Subtree sizes include the vertex itself; drop it.
        for i in range(self._n):
            _descendant_num[i] -= 1
        self._descendant_num = _descendant_num
        return self._descendant_num

    def get_child(self) -> list[list[int]]:
        """Return the list of children of every vertex. / O(N) once"""
        if self._child:
            return self._child
        self._calc_child_parents()
        return self._child

    def get_child_num(self) -> list[int]:
        """Return the number of children of every vertex. / O(N) once"""
        if self._child_num:
            return self._child_num
        self._calc_child_parents()
        return self._child_num

    def get_parents(self) -> list[int]:
        """Return the parent of every vertex (-1 for the root). / O(N) once"""
        if self._parents:
            return self._parents
        self._calc_child_parents()
        return self._parents

    def get_diameter(self) -> tuple[int, int, int]:
        """Return the diameter of the tree as (diameter, start, stop). / O(N)

        Uses a second farthest-vertex search starting from a vertex that is
        farthest from the root.  The -1 sentinels used for caching and for
        the visited check assume non-negative edge weights.
        """
        if self._diameter[0] > -1:
            return self._diameter
        s = self._dist.index(self.get_height())
        todo = [s]
        ndist = [-1] * self._n
        ndist[s] = 0
        while todo:
            v = todo.pop()
            d = ndist[v]
            for x, c in self._G[v]:
                if ndist[x] != -1:
                    continue
                ndist[x] = d + c
                todo.append(x)
        diameter = max(ndist)
        t = ndist.index(diameter)
        self._diameter = (diameter, s, t)
        return self._diameter

    def get_bipartite_graph(self) -> list[int]:
        """Return a 2-colouring of the vertices; the root gets colour 1.

        Adjacent vertices receive different colours (0 / 1).
        O(N) on the first call, cached afterwards.
        """
        if self._bipartite_graph:
            return self._bipartite_graph
        self._bipartite_graph = [-1] * self._n
        _bipartite_graph = self._bipartite_graph
        _bipartite_graph[self._root] = 1
        todo = [self._root]
        while todo:
            v = todo.pop()
            nc = 0 if _bipartite_graph[v] else 1
            for x, _ in self._G[v]:
                if _bipartite_graph[x] != -1:
                    continue
                _bipartite_graph[x] = nc
                todo.append(x)
        return _bipartite_graph

    def _calc_doubling(self) -> None:
        """Build the binary-lifting ancestor table. / O(NlogN)

        ``_doubling[k][v]`` is the 2**k-th ancestor of ``v`` or -1 if it does
        not exist.
        """
        if not self._parents:
            self._calc_child_parents()
        # Level 0 is a copy so the table never aliases the parents list.
        doubling = [self._parents[:]]
        for _ in range(self._K - 1):
            prev = doubling[-1]
            doubling.append([-1 if p < 0 else prev[p] for p in prev])
        self._doubling = doubling

    def get_lca(self, u: int, v: int) -> int:
        """Return the lowest common ancestor of u and v. / O(logN)

        Raises:
            AssertionError: if the tree was not built with ``lca=True``.
        """
        assert self._lca, f"{self.__class__.__name__}.get_lca(), `lca` must be True"
        _doubling, _rank = self._doubling, self._rank
        if _rank[u] < _rank[v]:
            u, v = v, u
        # Lift the deeper vertex to the depth of the other one.
        _r = _rank[u] - _rank[v]
        for k in range(self._K):
            if _r >> k & 1:
                u = _doubling[k][u]
        if u == v:
            return u
        # Lift both as far as possible while their ancestors still differ.
        for k in range(self._K - 1, -1, -1):
            if _doubling[k][u] != _doubling[k][v]:
                u = _doubling[k][u]
                v = _doubling[k][v]
        return _doubling[0][u]

    def get_dist(self, u: int, v: int) -> int:
        """Return the weighted distance between u and v. / O(logN)

        Requires ``lca=True`` (see ``get_lca``).
        """
        return self._dist[u] + self._dist[v] - 2 * self._dist[self.get_lca(u, v)]

    def is_on_path(self, u: int, v: int, a: int) -> bool:
        """Return True if a lies on the path between u and v. / O(logN)

        Requires ``lca=True`` (see ``get_lca``).  The test is done on edge
        counts rather than weighted distances, so zero-weight edges cannot
        make an off-path vertex look like it is on the path.
        """
        _rank = self._rank

        def edge_count(x: int, y: int) -> int:
            # Number of edges on the path x - y.
            return _rank[x] + _rank[y] - 2 * _rank[self.get_lca(x, y)]

        return edge_count(u, a) + edge_count(a, v) == edge_count(u, v)

    def get_path(self, u: int, v: int) -> list[int]:
        """Return the vertices on the path u -> v. / O(logN + |path|)

        Raises:
            AssertionError: if the tree was not built with ``lca=True``.
        """
        assert self._lca, f"{self.__class__.__name__}, `lca` must be True"
        if u == v:
            return [u]
        self.get_parents()

        def get_path_lca(u: int, v: int) -> list[int]:
            # Vertices strictly between u and its ancestor v, walking upwards.
            path = []
            while u != v:
                u = self._parents[u]
                if u == v:
                    break
                path.append(u)
            return path

        lca = self.get_lca(u, v)
        path = [u]
        path.extend(get_path_lca(u, lca))
        # When an endpoint is the LCA it is already emitted as that endpoint.
        if u != lca and v != lca:
            path.append(lca)
        path.extend(get_path_lca(v, lca)[::-1])
        path.append(v)
        return path

    def dfs_in_out(self) -> tuple[list[int], list[int]]:
        """Return DFS entry and exit timestamps for every vertex. / O(N)

        A single counter is advanced on every entry and every exit event, so
        timestamps range over ``0 .. 2N - 1``.
        """
        curtime = -1
        # A vertex v on the stack means "enter v"; ~v (negative) means "leave v".
        todo = [~self._root, self._root]
        intime = [-1] * self._n
        outtime = [-1] * self._n
        seen = [False] * self._n
        seen[self._root] = True
        while todo:
            curtime += 1
            v = todo.pop()
            if v >= 0:
                intime[v] = curtime
                for x, _ in self._G[v]:
                    if not seen[x]:
                        todo.append(~x)
                        todo.append(x)
                        seen[x] = True
            else:
                outtime[~v] = curtime
        return intime, outtime