directed_mst

ソースコード

from titan_pylib.others.directed_mst import SkewHeap

view on github

展開済みコード

  1# from titan_pylib.others.directed_mst import directed_mst
  2# from titan_pylib.data_structures.union_find.union_find import UnionFind
  3from collections import defaultdict
  4
  5
  6class UnionFind:
  7
  8    def __init__(self, n: int) -> None:
  9        """``n`` 個の要素からなる ``UnionFind`` を構築します。
 10        :math:`O(n)` です。
 11        """
 12        self._n: int = n
 13        self._group_numbers: int = n
 14        self._parents: list[int] = [-1] * n
 15
 16    def root(self, x: int) -> int:
 17        """要素 ``x`` を含む集合の代表元を返します。
 18        :math:`O(\\alpha(n))` です。
 19        """
 20        a = x
 21        while self._parents[a] >= 0:
 22            a = self._parents[a]
 23        while self._parents[x] >= 0:
 24            y = x
 25            x = self._parents[x]
 26            self._parents[y] = a
 27        return a
 28
 29    def unite(self, x: int, y: int) -> bool:
 30        """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
 31        :math:`O(\\alpha(n))` です。
 32
 33        Returns:
 34          bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。
 35        """
 36        x = self.root(x)
 37        y = self.root(y)
 38        if x == y:
 39            return False
 40        self._group_numbers -= 1
 41        if self._parents[x] > self._parents[y]:
 42            x, y = y, x
 43        self._parents[x] += self._parents[y]
 44        self._parents[y] = x
 45        return True
 46
 47    def unite_right(self, x: int, y: int) -> int:
 48        # x -> y
 49        x = self.root(x)
 50        y = self.root(y)
 51        if x == y:
 52            return x
 53        self._group_numbers -= 1
 54        self._parents[y] += self._parents[x]
 55        self._parents[x] = y
 56        return y
 57
 58    def unite_left(self, x: int, y: int) -> int:
 59        # x <- y
 60        x = self.root(x)
 61        y = self.root(y)
 62        if x == y:
 63            return x
 64        self._group_numbers -= 1
 65        self._parents[x] += self._parents[y]
 66        self._parents[y] = x
 67        return x
 68
 69    def size(self, x: int) -> int:
 70        """要素 ``x`` を含む集合の要素数を返します。
 71        :math:`O(\\alpha(n))` です。
 72        """
 73        return -self._parents[self.root(x)]
 74
 75    def same(self, x: int, y: int) -> bool:
 76        """
 77        要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
 78        そうでないなら ``False`` を返します。
 79        :math:`O(\\alpha(n))` です。
 80        """
 81        return self.root(x) == self.root(y)
 82
 83    def members(self, x: int) -> list[int]:
 84        """要素 ``x`` を含む集合を返します。"""
 85        x = self.root(x)
 86        return [i for i in range(self._n) if self.root(i) == x]
 87
 88    def all_roots(self) -> list[int]:
 89        """全ての集合の代表元からなるリストを返します。
 90        :math:`O(n)` です。
 91
 92        Returns:
 93          list[int]: 昇順であることが保証されます。
 94        """
 95        return [i for i, x in enumerate(self._parents) if x < 0]
 96
 97    def group_count(self) -> int:
 98        """集合の総数を返します。
 99        :math:`O(1)` です。
100        """
101        return self._group_numbers
102
103    def all_group_members(self) -> defaultdict:
104        """
105        key に代表元、 value に key を代表元とする集合のリストをもつ defaultdict を返します。
106        :math:`O(n\\alpha(n))` です。
107        """
108        group_members = defaultdict(list)
109        for member in range(self._n):
110            group_members[self.root(member)].append(member)
111        return group_members
112
113    def clear(self) -> None:
114        """集合の連結状態をなくします(初期状態に戻します)。
115        :math:`O(n)` です。
116        """
117        self._group_numbers = self._n
118        for i in range(self._n):
119            self._parents[i] = -1
120
121    def __str__(self) -> str:
122        """よしなにします。
123        :math:`O(n\\alpha(n))` です。
124        """
125        return (
126            f"<{self.__class__.__name__}> [\n"
127            + "\n".join(f"  {k}: {v}" for k, v in self.all_group_members().items())
128            + "\n]"
129        )
130
131
132class SkewHeap:
133
134    class Node:
135        def __init__(self, val):
136            self.l = None
137            self.r = None
138            self.val = val
139            self.add = 0
140
141        def lazy_propagate(self):
142            if self.l is not None:
143                self.l.add += self.add
144            if self.r is not None:
145                self.r.add += self.add
146            self.val += self.add
147            self.add = 0
148
149    def __init__(self):
150        self.root = None
151
152    def _meld(self, a, b):
153        if a is None:
154            return b
155        if b is None:
156            return a
157        if b.val + b.add < a.val + a.add:
158            a, b = b, a
159        a.lazy_propagate()
160        a.r = self._meld(a.r, b)
161        a.l, a.r = a.r, a.l
162        return a
163
164    @property
165    def min(self):
166        self.root.lazy_propagate()
167        return self.root.val
168
169    def push(self, val):
170        nd = self.Node(val)
171        self.root = self._meld(self.root, nd)
172
173    def pop(self):
174        rt = self.root
175        rt.lazy_propagate()
176        self.root = self._meld(rt.l, rt.r)
177        return rt.val
178
179    def meld(self, other):
180        self.root = self._meld(self.root, other.root)
181
182    def add(self, val):
183        self.root.add += val
184
185    def empty(self):
186        return self.root is None
187
188
189def directed_mst(n, edges, root):
190    OFFSET = len(edges)
191    from_ = [0] * n
192    from_cost = [0] * n
193    from_heap = [SkewHeap() for _ in range(n)]
194
195    uf = UnionFind(n)
196    par_e = [-1] * m
197    stem = [-1] * n
198    used = [0] * n
199    used[root] = 2
200    idxs = []
201
202    for idx, (fr, to, cost) in enumerate(edges):
203        from_heap[to].push(cost * OFFSET + idx)
204
205    res = 0
206    for v in range(n):
207        if used[v] != 0:
208            continue
209        processing = []
210        chi_e = []
211        cycle = 0
212        while used[v] != 2:
213            used[v] = 1
214            processing.append(v)
215            if from_heap[v].empty():
216                return -1, par
217            from_cost[v], idx = divmod(from_heap[v].pop(), OFFSET)
218            from_[v] = uf.root(edges[idx][0])
219            if stem[v] == -1:
220                stem[v] = idx
221            if from_[v] == v:
222                continue
223            res += from_cost[v]
224            idxs.append(idx)
225            while cycle:
226                par_e[chi_e.pop()] = idx
227                cycle -= 1
228            chi_e.append(idx)
229            if used[from_[v]] == 1:
230                p = v
231                while True:
232                    if not from_heap[p].empty():
233                        from_heap[p].add(-from_cost[p] * OFFSET)
234                    if p != v:
235                        uf.merge(v, p)
236                        from_heap[v].meld(from_heap[p])
237                    p = uf.root(from_[p])
238                    cycle += 1
239                    if p == v:
240                        break
241            else:
242                v = from_[v]
243        for v in processing:
244            used[v] = 2
245
246    used_e = [False] * m
247    tree = [-1] * n
248    for idx in reversed(idxs):
249        if used_e[idx]:
250            continue
251        fr, to, cost = edges[idx]
252        tree[to] = fr
253        x = stem[to]
254        while x != idx:
255            used_e[x] = True
256            x = par_e[x]
257    return res, tree
258
259
260n, m, root = map(int, input().split())
261edges = [list(map(int, input().split())) for i in range(m)]
262
263
264res, par = directed_mst(n, edges, root)
265if res == -1:
266    print(res)
267else:
268    print(res)
269    print(*[p if p != -1 else i for i, p in enumerate(par)])

仕様