offline_dynamic_connectivity

ソースコード

from titan_pylib.data_structures.dynamic_connectivity.offline_dynamic_connectivity import OfflineDynamicConnectivity

view on github

展開済みコード

  1# from titan_pylib.data_structures.dynamic_connectivity.offline_dynamic_connectivity import OfflineDynamicConnectivity
  2from typing import Callable
  3from collections import defaultdict
  4
  5
  6class OfflineDynamicConnectivity:
  7    """OfflineDynamicConnectivity
  8
  9    参考:
 10      [ちょっと変わったセグメント木の使い方(ei1333の日記)](https://ei1333.hateblo.jp/entry/2017/12/14/000000)
 11
 12    Note:
 13      内部では辺を ``dict`` で管理しています。メモリに注意です。
 14    """
 15
 16    class UndoableUnionFind:
 17        """内部で管理される `UndoableUnionFind` です。"""
 18
 19        def __init__(self, n: int):
 20            self._n: int = n
 21            self._parents: list[int] = [-1] * n
 22            self._all_sum: list[int] = [0] * n
 23            self._one_sum: list[int] = [0] * n
 24            self._history: list[tuple[int, int, int]] = []
 25            self._group_count: int = n
 26
 27        def _undo(self) -> None:
 28            assert self._history, "UndoableUnionFind._undo() with non history"
 29            y, py, all_sum_y = self._history.pop()
 30            if y == -1:
 31                return
 32            x, px, all_sum_x = self._history.pop()
 33            self._group_count += 1
 34            self._parents[y] = py
 35            self._parents[x] = px
 36            s = (self._all_sum[x] - all_sum_y - all_sum_x) // (-py - px) * (-py)
 37            self._all_sum[y] += s
 38            self._all_sum[x] -= all_sum_y + s
 39            self._one_sum[x] -= self._one_sum[y]
 40
 41        def root(self, x: int) -> int:
 42            """要素 ``x`` を含む集合の代表元を返します。
 43            :math:`O(\\log{n})` です。
 44            """
 45            while self._parents[x] >= 0:
 46                x = self._parents[x]
 47            return x
 48
 49        def unite(self, x: int, y: int) -> bool:
 50            """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
 51            :math:`O(\\log{n})` です。
 52
 53            Returns:
 54              bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。
 55            """
 56            x = self.root(x)
 57            y = self.root(y)
 58            if x == y:
 59                self._history.append((-1, -1, -1))
 60                return False
 61            if self._parents[x] > self._parents[y]:
 62                x, y = y, x
 63            self._group_count -= 1
 64            self._history.append((x, self._parents[x], self._all_sum[x]))
 65            self._history.append((y, self._parents[y], self._all_sum[y]))
 66            self._all_sum[x] += self._all_sum[y]
 67            self._one_sum[x] += self._one_sum[y]
 68            self._parents[x] += self._parents[y]
 69            self._parents[y] = x
 70            return True
 71
 72        def size(self, x: int) -> int:
 73            """要素 ``x`` を含む集合の要素数を返します。
 74            :math:`O(\\log{n})` です。
 75            """
 76            return -self._parents[self.root(x)]
 77
 78        def same(self, x: int, y: int) -> bool:
 79            """
 80            要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
 81            そうでないなら ``False`` を返します。
 82            :math:`O(\\log{n})` です。
 83            """
 84            return self.root(x) == self.root(y)
 85
 86        def add_point(self, x: int, v: int) -> None:
 87            """頂点 ``x`` に値 ``v`` を加算します。
 88            :math:`O(\\log{n})` です。
 89            """
 90            while x >= 0:
 91                self._one_sum[x] += v
 92                x = self._parents[x]
 93
 94        def add_group(self, x: int, v: int) -> None:
 95            """頂点 ``x`` を含む連結成分の要素それぞれに ``v`` を加算します。
 96            :math:`O(\\log{n})` です。
 97            """
 98            x = self.root(x)
 99            self._all_sum[x] += v * self.size(x)
100
101        def group_count(self) -> int:
102            """集合の総数を返します。
103            :math:`O(1)` です。
104            """
105            return self._group_count
106
107        def group_sum(self, x: int) -> int:
108            """``x`` を要素に含む集合の総和を求めます。
109            :math:`O(n\\log{n})` です。
110            """
111            x = self.root(x)
112            return self._one_sum[x] + self._all_sum[x]
113
114        def all_group_members(self) -> defaultdict:
115            """``key`` に代表元、 ``value`` に ``key`` を代表元とする集合のリストをもつ ``defaultdict`` を返します。
116            :math:`O(n\\log{n})` です。
117            """
118            group_members = defaultdict(list)
119            for member in range(self._n):
120                group_members[self.root(member)].append(member)
121            return group_members
122
123        def __str__(self):
124            return (
125                "<offline-dc.uf> [\n"
126                + "\n".join(f"  {k}: {v}" for k, v in self.all_group_members().items())
127                + "\n]"
128            )
129
130    def __init__(self, n: int) -> None:
131        """初期状態を頂点数 ``n`` の無向グラフとします。
132        :math:`O(n)` です。
133
134        Args:
135          n (int): 頂点数です。
136        """
137        self._n = n
138        self._query_count = 0
139        self._bit = n.bit_length() + 1
140        self._msk = (1 << self._bit) - 1
141        self._start = defaultdict(lambda: [0, 0])
142        self._edge_data = []
143        self.uf = OfflineDynamicConnectivity.UndoableUnionFind(n)
144
145    def add_edge(self, u: int, v: int) -> None:
146        """辺 ``{u, v}`` を追加します。
147        :math:`O(1)` です。
148        """
149        assert 0 <= u < self._n and 0 <= v < self._n
150        if u > v:
151            u, v = v, u
152        edge = u << self._bit | v
153        if self._start[edge][0] == 0:
154            self._start[edge][1] = self._query_count
155        self._start[edge][0] += 1
156
157    def delete_edge(self, u: int, v: int) -> None:
158        """辺 ``{u, v}`` を削除します。
159        :math:`O(1)` です。
160        """
161        assert 0 <= u < self._n and 0 <= v < self._n
162        if u > v:
163            u, v = v, u
164        edge = u << self._bit | v
165        if self._start[edge][0] == 1:
166            self._edge_data.append((self._start[edge][1], self._query_count, edge))
167        self._start[edge][0] -= 1
168
169    def next_query(self) -> None:
170        """クエリカウントを 1 進めます。
171        :math:`O(1)` です。
172        """
173        self._query_count += 1
174
175    def run(self, out: Callable[[int], None]) -> None:
176        """実行します。
177        :math:`O(q \\log{q} \\log{n})` です。
178
179        Args:
180          out (Callable[[int], None]): クエリ番号 ``k`` を引数にとります。
181        """
182        # O(qlogqlogn)
183        uf, bit, msk, q = self.uf, self._bit, self._msk, self._query_count
184        log = (q - 1).bit_length()
185        size = 1 << log
186        size2 = size * 2
187        data = [[] for _ in range(size << 1)]
188
189        def add(l, r, edge):
190            l += size
191            r += size
192            while l < r:
193                if l & 1:
194                    data[l].append(edge)
195                    l += 1
196                if r & 1:
197                    data[r ^ 1].append(edge)
198                l >>= 1
199                r >>= 1
200
201        for edge, p in self._start.items():
202            if p[0] != 0:
203                add(p[1], self._query_count, edge)
204        for l, r, edge in self._edge_data:
205            add(l, r, edge)
206
207        todo = [1]
208        while todo:
209            v = todo.pop()
210            if v >= 0:
211                for uv in data[v]:
212                    uf.unite(uv >> bit, uv & msk)
213                todo.append(~v)
214                if v << 1 | 1 < size2:
215                    todo.append(v << 1 | 1)
216                    todo.append(v << 1)
217                elif v - size < q:
218                    out(v - size)
219            else:
220                for _ in data[~v]:
221                    uf._undo()
222
223    def __repr__(self):
224        return f"OfflineDynamicConnectivity({self._n})"

仕様

class OfflineDynamicConnectivity(n: int)[source]

Bases: object

参考:

[ちょっと変わったセグメント木の使い方(ei1333の日記)](https://ei1333.hateblo.jp/entry/2017/12/14/000000)

Note

内部では辺を dict で管理しています。メモリに注意です。

class UndoableUnionFind(n: int)[source]

Bases: object

内部で管理される UndoableUnionFind です。

add_group(x: int, v: int) None[source]

頂点 x を含む連結成分の要素それぞれに v を加算します。 \(O(\log{n})\) です。

add_point(x: int, v: int) None[source]

頂点 x に値 v を加算します。 \(O(\log{n})\) です。

all_group_members() defaultdict[source]

key に代表元、 valuekey を代表元とする集合のリストをもつ defaultdict を返します。 \(O(n\log{n})\) です。

group_count() int[source]

集合の総数を返します。 \(O(1)\) です。

group_sum(x: int) int[source]

x を要素に含む集合の総和を求めます。 \(O(n\log{n})\) です。

root(x: int) int[source]

要素 x を含む集合の代表元を返します。 \(O(\log{n})\) です。

same(x: int, y: int) bool[source]

要素 xy が同じ集合に属するなら True を、 そうでないなら False を返します。 \(O(\log{n})\) です。

size(x: int) int[source]

要素 x を含む集合の要素数を返します。 \(O(\log{n})\) です。

unite(x: int, y: int) bool[source]

要素 x を含む集合と要素 y を含む集合を併合します。 \(O(\log{n})\) です。

Returns:

もともと同じ集合であれば False、そうでなければ True を返します。

Return type:

bool

add_edge(u: int, v: int) None[source]

{u, v} を追加します。 \(O(1)\) です。

delete_edge(u: int, v: int) None[source]

{u, v} を削除します。 \(O(1)\) です。

next_query() None[source]

クエリカウントを 1 進めます。 \(O(1)\) です。

run(out: Callable[[int], None]) None[source]

実行します。 \(O(q \log{q} \log{n})\) です。

Parameters:

out (Callable[[int], None]) – クエリ番号 k を引数にとります。