1from typing import Callable
2from collections import defaultdict
3
4
[docs]
5class OfflineDynamicConnectivity:
6 """OfflineDynamicConnectivity
7
8 参考:
9 [ちょっと変わったセグメント木の使い方(ei1333の日記)](https://ei1333.hateblo.jp/entry/2017/12/14/000000)
10
11 Note:
12 内部では辺を ``dict`` で管理しています。メモリに注意です。
13 """
14
[docs]
15 class UndoableUnionFind:
16 """内部で管理される `UndoableUnionFind` です。"""
17
18 def __init__(self, n: int):
19 self._n: int = n
20 self._parents: list[int] = [-1] * n
21 self._all_sum: list[int] = [0] * n
22 self._one_sum: list[int] = [0] * n
23 self._history: list[tuple[int, int, int]] = []
24 self._group_count: int = n
25
26 def _undo(self) -> None:
27 assert self._history, "UndoableUnionFind._undo() with non history"
28 y, py, all_sum_y = self._history.pop()
29 if y == -1:
30 return
31 x, px, all_sum_x = self._history.pop()
32 self._group_count += 1
33 self._parents[y] = py
34 self._parents[x] = px
35 s = (self._all_sum[x] - all_sum_y - all_sum_x) // (-py - px) * (-py)
36 self._all_sum[y] += s
37 self._all_sum[x] -= all_sum_y + s
38 self._one_sum[x] -= self._one_sum[y]
39
[docs]
40 def root(self, x: int) -> int:
41 """要素 ``x`` を含む集合の代表元を返します。
42 :math:`O(\\log{n})` です。
43 """
44 while self._parents[x] >= 0:
45 x = self._parents[x]
46 return x
47
[docs]
48 def unite(self, x: int, y: int) -> bool:
49 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。
50 :math:`O(\\log{n})` です。
51
52 Returns:
53 bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。
54 """
55 x = self.root(x)
56 y = self.root(y)
57 if x == y:
58 self._history.append((-1, -1, -1))
59 return False
60 if self._parents[x] > self._parents[y]:
61 x, y = y, x
62 self._group_count -= 1
63 self._history.append((x, self._parents[x], self._all_sum[x]))
64 self._history.append((y, self._parents[y], self._all_sum[y]))
65 self._all_sum[x] += self._all_sum[y]
66 self._one_sum[x] += self._one_sum[y]
67 self._parents[x] += self._parents[y]
68 self._parents[y] = x
69 return True
70
[docs]
71 def size(self, x: int) -> int:
72 """要素 ``x`` を含む集合の要素数を返します。
73 :math:`O(\\log{n})` です。
74 """
75 return -self._parents[self.root(x)]
76
[docs]
77 def same(self, x: int, y: int) -> bool:
78 """
79 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、
80 そうでないなら ``False`` を返します。
81 :math:`O(\\log{n})` です。
82 """
83 return self.root(x) == self.root(y)
84
[docs]
85 def add_point(self, x: int, v: int) -> None:
86 """頂点 ``x`` に値 ``v`` を加算します。
87 :math:`O(\\log{n})` です。
88 """
89 while x >= 0:
90 self._one_sum[x] += v
91 x = self._parents[x]
92
[docs]
93 def add_group(self, x: int, v: int) -> None:
94 """頂点 ``x`` を含む連結成分の要素それぞれに ``v`` を加算します。
95 :math:`O(\\log{n})` です。
96 """
97 x = self.root(x)
98 self._all_sum[x] += v * self.size(x)
99
[docs]
100 def group_count(self) -> int:
101 """集合の総数を返します。
102 :math:`O(1)` です。
103 """
104 return self._group_count
105
[docs]
106 def group_sum(self, x: int) -> int:
107 """``x`` を要素に含む集合の総和を求めます。
108 :math:`O(n\\log{n})` です。
109 """
110 x = self.root(x)
111 return self._one_sum[x] + self._all_sum[x]
112
[docs]
113 def all_group_members(self) -> defaultdict:
114 """``key`` に代表元、 ``value`` に ``key`` を代表元とする集合のリストをもつ ``defaultdict`` を返します。
115 :math:`O(n\\log{n})` です。
116 """
117 group_members = defaultdict(list)
118 for member in range(self._n):
119 group_members[self.root(member)].append(member)
120 return group_members
121
122 def __str__(self):
123 return (
124 "<offline-dc.uf> [\n"
125 + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items())
126 + "\n]"
127 )
128
129 def __init__(self, n: int) -> None:
130 """初期状態を頂点数 ``n`` の無向グラフとします。
131 :math:`O(n)` です。
132
133 Args:
134 n (int): 頂点数です。
135 """
136 self._n = n
137 self._query_count = 0
138 self._bit = n.bit_length() + 1
139 self._msk = (1 << self._bit) - 1
140 self._start = defaultdict(lambda: [0, 0])
141 self._edge_data = []
142 self.uf = OfflineDynamicConnectivity.UndoableUnionFind(n)
143
[docs]
144 def add_edge(self, u: int, v: int) -> None:
145 """辺 ``{u, v}`` を追加します。
146 :math:`O(1)` です。
147 """
148 assert 0 <= u < self._n and 0 <= v < self._n
149 if u > v:
150 u, v = v, u
151 edge = u << self._bit | v
152 if self._start[edge][0] == 0:
153 self._start[edge][1] = self._query_count
154 self._start[edge][0] += 1
155
[docs]
156 def delete_edge(self, u: int, v: int) -> None:
157 """辺 ``{u, v}`` を削除します。
158 :math:`O(1)` です。
159 """
160 assert 0 <= u < self._n and 0 <= v < self._n
161 if u > v:
162 u, v = v, u
163 edge = u << self._bit | v
164 if self._start[edge][0] == 1:
165 self._edge_data.append((self._start[edge][1], self._query_count, edge))
166 self._start[edge][0] -= 1
167
[docs]
168 def next_query(self) -> None:
169 """クエリカウントを 1 進めます。
170 :math:`O(1)` です。
171 """
172 self._query_count += 1
173
[docs]
174 def run(self, out: Callable[[int], None]) -> None:
175 """実行します。
176 :math:`O(q \\log{q} \\log{n})` です。
177
178 Args:
179 out (Callable[[int], None]): クエリ番号 ``k`` を引数にとります。
180 """
181 # O(qlogqlogn)
182 uf, bit, msk, q = self.uf, self._bit, self._msk, self._query_count
183 log = (q - 1).bit_length()
184 size = 1 << log
185 size2 = size * 2
186 data = [[] for _ in range(size << 1)]
187
188 def add(l, r, edge):
189 l += size
190 r += size
191 while l < r:
192 if l & 1:
193 data[l].append(edge)
194 l += 1
195 if r & 1:
196 data[r ^ 1].append(edge)
197 l >>= 1
198 r >>= 1
199
200 for edge, p in self._start.items():
201 if p[0] != 0:
202 add(p[1], self._query_count, edge)
203 for l, r, edge in self._edge_data:
204 add(l, r, edge)
205
206 todo = [1]
207 while todo:
208 v = todo.pop()
209 if v >= 0:
210 for uv in data[v]:
211 uf.unite(uv >> bit, uv & msk)
212 todo.append(~v)
213 if v << 1 | 1 < size2:
214 todo.append(v << 1 | 1)
215 todo.append(v << 1)
216 elif v - size < q:
217 out(v - size)
218 else:
219 for _ in data[~v]:
220 uf._undo()
221
222 def __repr__(self):
223 return f"OfflineDynamicConnectivity({self._n})"