Source code for titan_pylib.data_structures.dynamic_connectivity.offline_dynamic_connectivity

  1from typing import Callable
  2from collections import defaultdict
  3
  4
[docs] 5class OfflineDynamicConnectivity: 6 """OfflineDynamicConnectivity 7 8 参考: 9 [ちょっと変わったセグメント木の使い方(ei1333の日記)](https://ei1333.hateblo.jp/entry/2017/12/14/000000) 10 11 Note: 12 内部では辺を ``dict`` で管理しています。メモリに注意です。 13 """ 14
[docs] 15 class UndoableUnionFind: 16 """内部で管理される `UndoableUnionFind` です。""" 17 18 def __init__(self, n: int): 19 self._n: int = n 20 self._parents: list[int] = [-1] * n 21 self._all_sum: list[int] = [0] * n 22 self._one_sum: list[int] = [0] * n 23 self._history: list[tuple[int, int, int]] = [] 24 self._group_count: int = n 25 26 def _undo(self) -> None: 27 assert self._history, "UndoableUnionFind._undo() with non history" 28 y, py, all_sum_y = self._history.pop() 29 if y == -1: 30 return 31 x, px, all_sum_x = self._history.pop() 32 self._group_count += 1 33 self._parents[y] = py 34 self._parents[x] = px 35 s = (self._all_sum[x] - all_sum_y - all_sum_x) // (-py - px) * (-py) 36 self._all_sum[y] += s 37 self._all_sum[x] -= all_sum_y + s 38 self._one_sum[x] -= self._one_sum[y] 39
[docs] 40 def root(self, x: int) -> int: 41 """要素 ``x`` を含む集合の代表元を返します。 42 :math:`O(\\log{n})` です。 43 """ 44 while self._parents[x] >= 0: 45 x = self._parents[x] 46 return x
47
[docs] 48 def unite(self, x: int, y: int) -> bool: 49 """要素 ``x`` を含む集合と要素 ``y`` を含む集合を併合します。 50 :math:`O(\\log{n})` です。 51 52 Returns: 53 bool: もともと同じ集合であれば ``False``、そうでなければ ``True`` を返します。 54 """ 55 x = self.root(x) 56 y = self.root(y) 57 if x == y: 58 self._history.append((-1, -1, -1)) 59 return False 60 if self._parents[x] > self._parents[y]: 61 x, y = y, x 62 self._group_count -= 1 63 self._history.append((x, self._parents[x], self._all_sum[x])) 64 self._history.append((y, self._parents[y], self._all_sum[y])) 65 self._all_sum[x] += self._all_sum[y] 66 self._one_sum[x] += self._one_sum[y] 67 self._parents[x] += self._parents[y] 68 self._parents[y] = x 69 return True
70
[docs] 71 def size(self, x: int) -> int: 72 """要素 ``x`` を含む集合の要素数を返します。 73 :math:`O(\\log{n})` です。 74 """ 75 return -self._parents[self.root(x)]
76
[docs] 77 def same(self, x: int, y: int) -> bool: 78 """ 79 要素 ``x`` と ``y`` が同じ集合に属するなら ``True`` を、 80 そうでないなら ``False`` を返します。 81 :math:`O(\\log{n})` です。 82 """ 83 return self.root(x) == self.root(y)
84
[docs] 85 def add_point(self, x: int, v: int) -> None: 86 """頂点 ``x`` に値 ``v`` を加算します。 87 :math:`O(\\log{n})` です。 88 """ 89 while x >= 0: 90 self._one_sum[x] += v 91 x = self._parents[x]
92
[docs] 93 def add_group(self, x: int, v: int) -> None: 94 """頂点 ``x`` を含む連結成分の要素それぞれに ``v`` を加算します。 95 :math:`O(\\log{n})` です。 96 """ 97 x = self.root(x) 98 self._all_sum[x] += v * self.size(x)
99
[docs] 100 def group_count(self) -> int: 101 """集合の総数を返します。 102 :math:`O(1)` です。 103 """ 104 return self._group_count
105
[docs] 106 def group_sum(self, x: int) -> int: 107 """``x`` を要素に含む集合の総和を求めます。 108 :math:`O(n\\log{n})` です。 109 """ 110 x = self.root(x) 111 return self._one_sum[x] + self._all_sum[x]
112
[docs] 113 def all_group_members(self) -> defaultdict: 114 """``key`` に代表元、 ``value`` に ``key`` を代表元とする集合のリストをもつ ``defaultdict`` を返します。 115 :math:`O(n\\log{n})` です。 116 """ 117 group_members = defaultdict(list) 118 for member in range(self._n): 119 group_members[self.root(member)].append(member) 120 return group_members
121 122 def __str__(self): 123 return ( 124 "<offline-dc.uf> [\n" 125 + "\n".join(f" {k}: {v}" for k, v in self.all_group_members().items()) 126 + "\n]" 127 )
128 129 def __init__(self, n: int) -> None: 130 """初期状態を頂点数 ``n`` の無向グラフとします。 131 :math:`O(n)` です。 132 133 Args: 134 n (int): 頂点数です。 135 """ 136 self._n = n 137 self._query_count = 0 138 self._bit = n.bit_length() + 1 139 self._msk = (1 << self._bit) - 1 140 self._start = defaultdict(lambda: [0, 0]) 141 self._edge_data = [] 142 self.uf = OfflineDynamicConnectivity.UndoableUnionFind(n) 143
[docs] 144 def add_edge(self, u: int, v: int) -> None: 145 """辺 ``{u, v}`` を追加します。 146 :math:`O(1)` です。 147 """ 148 assert 0 <= u < self._n and 0 <= v < self._n 149 if u > v: 150 u, v = v, u 151 edge = u << self._bit | v 152 if self._start[edge][0] == 0: 153 self._start[edge][1] = self._query_count 154 self._start[edge][0] += 1
155
[docs] 156 def delete_edge(self, u: int, v: int) -> None: 157 """辺 ``{u, v}`` を削除します。 158 :math:`O(1)` です。 159 """ 160 assert 0 <= u < self._n and 0 <= v < self._n 161 if u > v: 162 u, v = v, u 163 edge = u << self._bit | v 164 if self._start[edge][0] == 1: 165 self._edge_data.append((self._start[edge][1], self._query_count, edge)) 166 self._start[edge][0] -= 1
167
[docs] 168 def next_query(self) -> None: 169 """クエリカウントを 1 進めます。 170 :math:`O(1)` です。 171 """ 172 self._query_count += 1
173
[docs] 174 def run(self, out: Callable[[int], None]) -> None: 175 """実行します。 176 :math:`O(q \\log{q} \\log{n})` です。 177 178 Args: 179 out (Callable[[int], None]): クエリ番号 ``k`` を引数にとります。 180 """ 181 # O(qlogqlogn) 182 uf, bit, msk, q = self.uf, self._bit, self._msk, self._query_count 183 log = (q - 1).bit_length() 184 size = 1 << log 185 size2 = size * 2 186 data = [[] for _ in range(size << 1)] 187 188 def add(l, r, edge): 189 l += size 190 r += size 191 while l < r: 192 if l & 1: 193 data[l].append(edge) 194 l += 1 195 if r & 1: 196 data[r ^ 1].append(edge) 197 l >>= 1 198 r >>= 1 199 200 for edge, p in self._start.items(): 201 if p[0] != 0: 202 add(p[1], self._query_count, edge) 203 for l, r, edge in self._edge_data: 204 add(l, r, edge) 205 206 todo = [1] 207 while todo: 208 v = todo.pop() 209 if v >= 0: 210 for uv in data[v]: 211 uf.unite(uv >> bit, uv & msk) 212 todo.append(~v) 213 if v << 1 | 1 < size2: 214 todo.append(v << 1 | 1) 215 todo.append(v << 1) 216 elif v - size < q: 217 out(v - size) 218 else: 219 for _ in data[~v]: 220 uf._undo()
221 222 def __repr__(self): 223 return f"OfflineDynamicConnectivity({self._n})"