rooted_tree

ソースコード

from titan_pylib.graph.rooted_tree import RootedTree

view on github

展開済みコード

  1# from titan_pylib.graph.rooted_tree import RootedTree
  2class RootedTree:
  3
  4    def __init__(
  5        self, _G: list[list[int]], _root: int, cp: bool = False, lca: bool = False
  6    ):
  7        self._n: int = len(_G)
  8        self._G: list[list[int]] = _G
  9        self._root: int = _root
 10        self._height: int = -1
 11        self._toposo: list[int] = []
 12        self._dist: list[int] = []
 13        self._descendant_num: list[int] = []
 14        self._child: list[list[int]] = []
 15        self._child_num: list[int] = []
 16        self._parents: list[int] = []
 17        self._diameter: tuple[int, int, int] = (-1, -1, -1)
 18        self._bipartite_graph = []
 19        self._cp = cp
 20        self._lca = lca
 21        K = self._n.bit_length()
 22        self._K = K
 23        self._doubling = [[-1] * self._n for _ in range(self._K)]
 24        self._calc_dist_toposo()
 25        if cp:
 26            self._calc_child_parents()
 27        if lca:
 28            self._calc_doubling()
 29
 30    def __str__(self):
 31        self._calc_child_parents()
 32        ret = ["<RootedTree> ["]
 33        ret.extend(
 34            [
 35                f"  dist:{str(d).zfill(2)} - v:{str(i).zfill(2)} - p:{str(self._parents[i]).zfill(2)} - child:{sorted(self._child[i])}"
 36                for i, d in sorted(enumerate(self._dist), key=lambda x: x[1])
 37            ]
 38        )
 39        ret.append("]")
 40        return "\n".join(ret)
 41
 42    def _calc_dist_toposo(self) -> None:
 43        """Calc dist and toposo. / O(N)"""
 44        # initメソッドで直接実行
 45        _G, _root = self._G, self._root
 46        _dist = [-1] * self._n
 47        _dist[_root] = 0
 48        _toposo = []
 49        _toposo.append(_root)
 50        todo = [_root]
 51        while todo:
 52            v = todo.pop()
 53            d = _dist[v]
 54            for x in _G[v]:
 55                if _dist[x] != -1:
 56                    continue
 57                _dist[x] = d + 1
 58                todo.append(x)
 59                _toposo.append(x)
 60        self._dist = _dist
 61        self._toposo = _toposo
 62
 63    def _calc_child_parents(self) -> None:
 64        """Calc child and parents. / O(N)"""
 65        if self._child and self._child_num and self._parents:
 66            return
 67        _G, _dist = self._G, self._dist
 68        _child_num = [0] * self._n
 69        _child = [[] for _ in range(self._n)]
 70        _parents = [-1] * self._n
 71        for v in self._toposo[::-1]:
 72            for x in _G[v]:
 73                if _dist[x] < _dist[v]:
 74                    _parents[v] = x
 75                    continue
 76                _child[v].append(x)
 77                _child_num[v] += 1
 78        self._child_num = _child_num
 79        self._child = _child
 80        self._parents = _parents
 81
 82    def get_dists(self) -> list[int]:
 83        """Return dist from root. / O(N)"""
 84        return self._dist
 85
 86    def get_toposo(self) -> list[int]:
 87        """Return toposo. / O(N)"""
 88        return self._toposo
 89
 90    def get_height(self) -> int:
 91        """Return height. / O(N)"""
 92        if self._height > -1:
 93            return self._height
 94        self._height = max(self._dist)
 95        return self._height
 96
 97    def get_descendant_num(self) -> list[int]:
 98        """Return descendant_num. / O(N)"""
 99        if self._descendant_num:
100            return self._descendant_num
101        _G, _dist = self._G, self._dist
102        _descendant_num = [1] * self._n
103        for v in self._toposo[::-1]:
104            for x in _G[v]:
105                if _dist[x] < _dist[v]:
106                    continue
107                _descendant_num[v] += _descendant_num[x]
108        for i in range(self._n):
109            _descendant_num[i] -= 1
110        self._descendant_num = _descendant_num
111        return self._descendant_num
112
113    def get_child(self) -> list[list[int]]:
114        """Return child / O(N)"""
115        if self._child:
116            return self._child
117        self._calc_child_parents()
118        return self._child
119
120    def get_child_num(self) -> list[int]:
121        """Return child_num. / O(N)"""
122        if self._child_num:
123            return self._child_num
124        self._calc_child_parents()
125        return self._child_num
126
127    def get_parents(self) -> list[int]:
128        """Return parents. / O(N)"""
129        if self._parents:
130            return self._parents
131        self._calc_child_parents()
132        return self._parents
133
134    def get_diameter(self) -> tuple[int, int, int]:
135        """Return diameter of tree. (diameter, start, stop) / O(N)"""
136        if self._diameter[0] > -1:
137            return self._diameter
138        s = self._dist.index(self.get_height())
139        todo = [s]
140        ndist = [-1] * self._n
141        ndist[s] = 0
142        while todo:
143            v = todo.pop()
144            d = ndist[v]
145            for x in self._G[v]:
146                if ndist[x] != -1:
147                    continue
148                ndist[x] = d + 1
149                todo.append(x)
150        diameter = max(ndist)
151        t = ndist.index(diameter)
152        self._diameter = (diameter, s, t)
153        return self._diameter
154
155    def get_bipartite_graph(self) -> list[int]:
156        """Return [1 if root else 0]. / O(N)"""
157        if self._bipartite_graph:
158            return self._bipartite_graph
159        self._bipartite_graph = [-1] * self._n
160        self._bipartite_graph[self._root] = 1
161        todo = [self._root]
162        while todo:
163            v = todo.pop()
164            nc = 0 if self._bipartite_graph[v] else 1
165            for x in self._G[v]:
166                if self._bipartite_graph[x] != -1:
167                    continue
168                self._bipartite_graph[x] = nc
169                todo.append(x)
170        return self._bipartite_graph
171
172    def _calc_doubling(self) -> None:
173        "Calc doubling if self._lca. / O(NlogN)"
174        if not self._parents:
175            self._calc_child_parents()
176        _doubling = self._doubling
177        for i in range(self._n):
178            _doubling[0][i] = self._parents[i]
179        for k in range(self._K - 1):
180            for v in range(self._n):
181                if _doubling[k][v] < 0:
182                    _doubling[k + 1][v] = -1
183                else:
184                    _doubling[k + 1][v] = _doubling[k][_doubling[k][v]]
185
186    def get_lca(self, u: int, v: int) -> int:
187        """Return LCA of (u, v). / O(logN)"""
188        assert (
189            self._lca
190        ), f"Error: {self.__class__.__name__}.get_lca({u}, {v}), `lca` must be True"
191        _doubling, _dist = self._doubling, self._dist
192        if _dist[u] < _dist[v]:
193            u, v = v, u
194        _r = _dist[u] - _dist[v]
195        for k in range(self._K):
196            if _r >> k & 1:
197                u = _doubling[k][u]
198        if u == v:
199            return u
200        for k in range(self._K - 1, -1, -1):
201            if _doubling[k][u] != _doubling[k][v]:
202                u = _doubling[k][u]
203                v = _doubling[k][v]
204        return _doubling[0][u]
205
206    def get_dist(self, u: int, v: int, vertex: bool = False) -> int:
207        """Return dist(u -> v). / O(logN)"""
208        return (
209            self._dist[u] + self._dist[v] - 2 * self._dist[self.get_lca(u, v)] + vertex
210        )
211
212    def is_on_path(self, u: int, v: int, a: int) -> bool:
213        """Return True if (a is on path(u - v)) else False. / O(logN)"""
214        return self.get_dist(u, a) + self.get_dist(a, v) == self.get_dist(u, v)
215
216    def get_path(self, u: int, v: int) -> list[int]:
217        """Return path (u -> v). / O(logN + |path|)"""
218        assert self._lca, f"{self.__class__.__name__}.get_path(), `lca` must be True"
219        if u == v:
220            return [u]
221        self.get_parents()
222
223        def get_path_lca(u: int, v: int) -> list[int]:
224            path = []
225            while u != v:
226                u = self._parents[u]
227                if u == v:
228                    break
229                path.append(u)
230            return path
231
232        lca = self.get_lca(u, v)
233        path = [u]
234        path.extend(get_path_lca(u, lca))
235        if u != lca and v != lca:
236            path.append(lca)
237        path.extend(get_path_lca(v, lca)[::-1])
238        path.append(v)
239        return path
240
241    def dfs_in_out(self) -> tuple[list[int], list[int]]:
242        curtime = -1
243        todo = [~self._root, self._root]
244        nodein = [-1] * self._n
245        nodeout = [-1] * self._n
246        if not self._parents:
247            self._calc_child_parents()
248        _G, _parents = self._G, self._parents
249        while todo:
250            curtime += 1
251            v = todo.pop()
252            if v >= 0:
253                nodein[v] = curtime
254                for x in _G[v]:
255                    if _parents[v] != x:
256                        todo.append(~x)
257                        todo.append(x)
258            else:
259                nodeout[~v] = curtime
260        return nodein, nodeout

仕様

class RootedTree(_G: list[list[int]], _root: int, cp: bool = False, lca: bool = False)

Bases: object

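dfs_in_out() → tuple[list[int], list[int]]

Return (nodein, nodeout): the DFS entry and exit timestamps of every vertex. / O(N)

get_bipartite_graph() → list[int]

Return a 2-coloring of the vertices: 1 for vertices at even depth (including the root), 0 for vertices at odd depth. / O(N)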

get_child() → list[list[int]]

Return the list of children of each vertex. / O(N)

get_child_num() → list[int]

Return the number of children of each vertex. / O(N)

get_descendant_num() → list[int]

Return the number of descendants of each vertex, not counting the vertex itself. / O(N)

get_diameter() → tuple[int, int, int]

Return the diameter of the tree as (diameter, start, stop). / O(N)

get_dist(u: int, v: int, vertex: bool = False) → int

Return the distance from u to v in edges, or in vertices on the path if vertex is True. / O(logN)

get_dists() → list[int]

Return the distance of every vertex from the root. / O(N)

get_height() → int

Return the height of the tree (the maximum depth). / O(N)

get_lca(u: int, v: int) → int

Return the lowest common ancestor of u and v. / O(logN)

get_parents() → list[int]

Return the parent of each vertex (-1 for the root). / O(N)

get_path(u: int, v: int) → list[int]

Return the vertices on the path from u to v, both endpoints included. / O(logN + |path|)

get_toposo() → list[int]

Return a topological order of the vertices (every parent appears before its children). / O(N)

is_on_path(u: int, v: int, a: int) → bool

Return True if a lies on the path between u and v, else False. / O(logN)