weighted_rooted_tree

ソースコード

from titan_pylib.graph.weighted_rooted_tree import WeightedRootedTree

View on GitHub

展開済みコード

  1# from titan_pylib.graph.weighted_rooted_tree import WeightedRootedTree
  2class WeightedRootedTree:
  3
  4    def __init__(
  5        self,
  6        _G: list[list[tuple[int, int]]],
  7        _root: int,
  8        cp: bool = False,
  9        lca: bool = False,
 10    ):
 11        self._n = len(_G)
 12        self._G = _G
 13        self._root = _root
 14        self._height = -1
 15        self._toposo = []
 16        self._dist = []
 17        self._descendant_num = []
 18        self._child = []
 19        self._child_num = []
 20        self._parents = []
 21        self._diameter = (-1, -1, -1)
 22        self._bipartite_graph = []
 23        self._cp = cp
 24        self._lca = lca
 25        self._rank = []
 26        K = 1
 27        while 1 << K < self._n:
 28            K += 1
 29        self._K = K
 30        self._doubling = [[-1] * self._n for _ in range(self._K)]
 31        self._calc_dist_toposo()
 32        if cp:
 33            self._calc_child_parents()
 34        if lca:
 35            self._calc_doubling()
 36
 37    def __str__(self):
 38        self._calc_child_parents()
 39        ret = ["<WeightedRootedTree> ["]
 40        ret.extend(
 41            [
 42                f"  dist:{str(d).zfill(2)} - v:{str(i).zfill(2)} - p:{str(self._parents[i]).zfill(2)} - child:{sorted(self._child[i])}"
 43                for i, d in sorted(enumerate(self._dist), key=lambda x: x[1])
 44            ]
 45        )
 46        ret.append("]")
 47        return "\n".join(ret)
 48
 49    def _calc_dist_toposo(self) -> None:
 50        """Calc dist and toposo. / O(N)"""
 51        # initメソッドで直接実行
 52        _G, _root = self._G, self._root
 53        _dist = [-1] * self._n
 54        _rank = [-1] * self._n
 55        _dist[_root] = 0
 56        _rank[_root] = 0
 57        _toposo = []
 58        _toposo.append(_root)
 59        todo = [_root]
 60        while todo:
 61            v = todo.pop()
 62            d = _dist[v]
 63            r = _rank[v]
 64            for x, c in _G[v]:
 65                if _dist[x] != -1:
 66                    continue
 67                _dist[x] = d + c
 68                _rank[x] = r + 1
 69                todo.append(x)
 70                _toposo.append(x)
 71        self._dist = _dist
 72        self._rank = _rank
 73        self._toposo = _toposo
 74
 75    def _calc_child_parents(self) -> None:
 76        """Calc child and parents. / O(N)"""
 77        if self._child and self._child_num and self._parents:
 78            return
 79        _G, _rank = self._G, self._rank
 80        _child_num = [0] * self._n
 81        _child = [[] for _ in range(self._n)]
 82        _parents = [-1] * self._n
 83        for v in self._toposo[::-1]:
 84            for x, _ in _G[v]:
 85                if _rank[x] < _rank[v]:
 86                    _parents[v] = x
 87                    continue
 88                _child[v].append(x)
 89                _child_num[v] += 1
 90        self._child_num = _child_num
 91        self._child = _child
 92        self._parents = _parents
 93
 94    def get_dists(self) -> list[int]:
 95        """Return dist from root. / O(N)"""
 96        return self._dist
 97
 98    def get_toposo(self) -> list[int]:
 99        """Return toposo. / O(N)"""
100        return self._toposo
101
102    def get_height(self) -> int:
103        """Return height. / O(N)"""
104        if self._height > -1:
105            return self._height
106        self._height = max(self._dist)
107        return self._height
108
109    def get_descendant_num(self) -> list[int]:
110        """Return descendant_num. / O(N)"""
111        if self._descendant_num:
112            return self._descendant_num
113        _G, _dist = self._G, self._dist
114        _descendant_num = [1] * self._n
115        for v in self._toposo[::-1]:
116            for x, _ in _G[v]:
117                if _dist[x] < _dist[v]:
118                    continue
119                _descendant_num[v] += _descendant_num[x]
120        for i in range(self._n):
121            _descendant_num[i] -= 1
122        self._descendant_num = _descendant_num
123        return self._descendant_num
124
125    def get_child(self) -> list[list[int]]:
126        """Return child / O(N)"""
127        if self._child:
128            return self._child
129        self._calc_child_parents()
130        return self._child
131
132    def get_child_num(self) -> list[int]:
133        """Return child_num. / O(N)"""
134        if self._child_num:
135            return self._child_num
136        self._calc_child_parents()
137        return self._child_num
138
139    def get_parents(self) -> list[int]:
140        """Return parents. / O(N)"""
141        if self._parents:
142            return self._parents
143        self._calc_child_parents()
144        return self._parents
145
146    def get_diameter(self) -> tuple[int, int, int]:
147        """Return diameter of tree. (diameter, start, stop) / O(N)"""
148        if self._diameter[0] > -1:
149            return self._diameter
150        s = self._dist.index(self.get_height())
151        todo = [s]
152        ndist = [-1] * self._n
153        ndist[s] = 0
154        while todo:
155            v = todo.pop()
156            d = ndist[v]
157            for x, c in self._G[v]:
158                if ndist[x] != -1:
159                    continue
160                ndist[x] = d + c
161                todo.append(x)
162        diameter = max(ndist)
163        t = ndist.index(diameter)
164        self._diameter = (diameter, s, t)
165        return self._diameter
166
167    def get_bipartite_graph(self) -> list[int]:
168        """Return [1 if root else 0]. / O(N)"""
169        if self._bipartite_graph:
170            return self._bipartite_graph
171        self._bipartite_graph = [-1] * self._n
172        _bipartite_graph = self._bipartite_graph
173        _bipartite_graph[self._root] = 1
174        todo = [self._root]
175        while todo:
176            v = todo.pop()
177            nc = 0 if _bipartite_graph[v] else 1
178            for x, _ in self._G[v]:
179                if _bipartite_graph[x] != -1:
180                    continue
181                _bipartite_graph[x] = nc
182                todo.append(x)
183        return _bipartite_graph
184
185    def _calc_doubling(self) -> None:
186        "Calc doubling if self._lca. / O(NlogN)"
187        if not self._parents:
188            self._calc_child_parents()
189        for i in range(self._n):
190            self._doubling[0][i] = self._parents[i]
191        for k in range(self._K - 1):
192            for v in range(self._n):
193                if self._doubling[k][v] < 0:
194                    self._doubling[k + 1][v] = -1
195                else:
196                    self._doubling[k + 1][v] = self._doubling[k][self._doubling[k][v]]
197
198    def get_lca(self, u: int, v: int) -> int:
199        """Return LCA of (u, v). / O(logN)"""
200        assert self._lca, f"{self.__class__.__name__}.get_lca(), `lca` must be True"
201        _doubling, _rank = self._doubling, self._rank
202        if _rank[u] < _rank[v]:
203            u, v = v, u
204        _r = _rank[u] - _rank[v]
205        for k in range(self._K):
206            if _r >> k & 1:
207                u = _doubling[k][u]
208        if u == v:
209            return u
210        for k in range(self._K - 1, -1, -1):
211            if _doubling[k][u] != _doubling[k][v]:
212                u = _doubling[k][u]
213                v = _doubling[k][v]
214        return _doubling[0][u]
215
216    def get_dist(self, u: int, v: int) -> int:
217        """Return dist(u - v). / O(logN)"""
218        return self._dist[u] + self._dist[v] - 2 * self._dist[self.get_lca(u, v)]
219
220    def is_on_path(self, u: int, v: int, a: int) -> bool:
221        """Return True if (a is on path(u - v)) else False. / O(logN)"""
222        raise NotImplementedError
223        return self.get_dist(u, a) + self.get_dist(a, v) == self.get_dist(
224            u, v
225        )  # rank??
226
227    def get_path(self, u: int, v: int) -> list[int]:
228        """Return path (u -> v). / O(logN + |path|)"""
229        assert self._lca, f"{self.__class__.__name__}, `lca` must be True"
230        if u == v:
231            return [u]
232        self.get_parents()
233
234        def get_path_lca(u: int, v: int) -> list[int]:
235            path = []
236            while u != v:
237                u = self._parents[u]
238                if u == v:
239                    break
240                path.append(u)
241            return path
242
243        lca = self.get_lca(u, v)
244        path = [u]
245        path.extend(get_path_lca(u, lca))
246        if u != lca and v != lca:
247            path.append(lca)
248        path.extend(get_path_lca(v, lca)[::-1])
249        path.append(v)
250        return path
251
252    def dfs_in_out(self) -> tuple[list[int], list[int]]:
253        curtime = -1
254        todo = [~self._root, self._root]
255        intime = [-1] * self._n
256        outtime = [-1] * self._n
257        seen = [False] * self._n
258        seen[self._root] = True
259        while todo:
260            curtime += 1
261            v = todo.pop()
262            if v >= 0:
263                intime[v] = curtime
264                for x, _ in self._G[v]:
265                    if not seen[x]:
266                        todo.append(~x)
267                        todo.append(x)
268                        seen[x] = True
269            else:
270                outtime[~v] = curtime
271        return intime, outtime

仕様

class WeightedRootedTree(_G: list[list[tuple[int, int]]], _root: int, cp: bool = False, lca: bool = False) [source]

Bases: object

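dfs_in_out() → tuple[list[int], list[int]] [source]

Return (intime, outtime) of an iterative DFS from the root; entry and exit events share one clock. / O(N)

get_bipartite_graph() → list[int] [source]

Return a 2-coloring of the vertices by depth parity: the root gets 1, its children 0, and so on. / O(N)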

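get_child() → list[list[int]] [source]

Return the children of each vertex. / O(N)

get_child_num() → list[int] [source]

Return the number of children of each vertex. / O(N)

get_descendant_num() → list[int] [source]

Return the number of proper descendants of each vertex (the vertex itself is not counted). / O(N)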

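get_diameter() → tuple[int, int, int] [source]

Return the diameter of the tree as (diameter, start, stop). / O(N)

get_dist(u: int, v: int) → int [source]

Return the weighted distance between u and v (requires lca=True). / O(logN)

get_dists() → list[int] [source]

Return the weighted distance of every vertex from the root. / O(N)

get_height() → int [source]

Return the height (the maximum weighted distance from the root). / O(N)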

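get_lca(u: int, v: int) → int [source]

Return the lowest common ancestor of u and v (requires lca=True). / O(logN)

get_parents() → list[int] [source]

Return the parent of each vertex (-1 for the root). / O(N)

get_path(u: int, v: int) → list[int] [source]

Return the vertices on the path from u to v, both ends included (requires lca=True). / O(logN + |path|)

get_toposo() → list[int] [source]

Return the vertices in an order where every parent precedes its children. / O(N)

is_on_path(u: int, v: int, a: int) → bool [source]

Return True if a is on the path between u and v, else False. / O(logN)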