fenwick_tree_multiset
ソースコード
from titan_pylib.data_structures.set.fenwick_tree_multiset import FenwickTreeMultiset
展開済みコード
# from titan_pylib.data_structures.set.fenwick_tree_multiset import FenwickTreeMultiset
# from titan_pylib.data_structures.set.fenwick_tree_set import FenwickTreeSet
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from typing import Protocol


class SupportsLessThan(Protocol):
    """Structural (duck-typed) protocol for values orderable with ``<``.

    Any object that defines ``__lt__`` satisfies it; no inheritance needed.
    """

    def __lt__(self, other) -> bool:
        """Return whether ``self`` orders strictly before ``other``."""
        ...
# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
from typing import Union, Iterable, Optional


class FenwickTree:
    """Fenwick tree (binary indexed tree) over integers.

    Positions are 0-indexed externally (``0 .. n-1``).  Internally ``_tree``
    is 1-indexed: ``_tree[j]`` holds the sum of the values at 1-indexed
    positions ``(j - lowbit(j), j]``; ``_tree[0]`` is unused.
    """

    def __init__(self, n_or_a: Union[Iterable[int], int]):
        """Build the tree.
        :math:`O(n)`.

        Args:
            n_or_a (Union[Iterable[int], int]): when an ``int`` ``n``, build a
                tree of length ``n`` with every value ``0``; when an
                ``Iterable`` ``a``, build it with ``a`` as the initial values
                (the given list itself is not modified).
        """
        if isinstance(n_or_a, int):
            self._size = n_or_a
            self._tree = [0] * (self._size + 1)
        else:
            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
            _size = len(a)
            _tree = [0] + a
            # Linear build: node i is complete once every smaller index has
            # been processed, so push it into its parent i + lowbit(i).
            # (i == _size never has a parent inside the array.)
            for i in range(1, _size):
                parent = i + (i & -i)
                if parent <= _size:
                    _tree[parent] += _tree[i]
            self._size = _size
            self._tree = _tree
        # Smallest power of two >= n (for n >= 1): the first step width of the
        # binary-lifting descents in bisect_left / bisect_right / _pop.
        self._s = 1 << (self._size - 1).bit_length()

    def pref(self, r: int) -> int:
        """Return the sum over ``[0, r)``.
        :math:`O(\\log{n})`.
        """
        assert (
            0 <= r <= self._size
        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
        ret, _tree = 0, self._tree
        while r > 0:
            ret += _tree[r]
            r &= r - 1  # drop the lowest set bit: jump to the preceding node
        return ret

    def suff(self, l: int) -> int:
        """Return the sum over ``[l, n)``; ``suff(n)`` is the empty sum ``0``.
        :math:`O(\\log{n})`.
        """
        # ``l == n`` is accepted, consistent with pref(n) and sum(n, n).
        assert (
            0 <= l <= self._size
        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
        return self.sum(l, self._size)

    def sum(self, l: int, r: int) -> int:
        """Return the sum over ``[l, r)``.
        :math:`O(\\log{n})`.
        """
        assert (
            0 <= l <= r <= self._size
        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
        _tree = self._tree
        res = 0
        # Computes pref(r) - pref(l) while skipping the nodes the two prefix
        # chains share: r walks down until it is <= l, then l walks down
        # until it meets r again (it always lands exactly on r).
        while r > l:
            res += _tree[r]
            r &= r - 1
        while l > r:
            res -= _tree[l]
            l &= l - 1
        return res

    prod = sum

    def __getitem__(self, k: int) -> int:
        """Return the value at position ``k`` (negative ``k`` counts from the end).
        :math:`O(\\log{n})`.
        """
        assert (
            -self._size <= k < self._size
        ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
        if k < 0:
            k += self._size
        return self.sum(k, k + 1)

    def add(self, k: int, x: int) -> None:
        """Add ``x`` to the value at position ``k``.
        :math:`O(\\log{n})`.
        """
        assert (
            0 <= k < self._size
        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
        k += 1
        _tree = self._tree
        while k <= self._size:
            _tree[k] += x
            k += k & -k  # climb to the next node whose range covers k
        return None

    def __setitem__(self, k: int, x: int):
        """Set the value at position ``k`` to ``x`` (negative ``k`` allowed).
        :math:`O(\\log{n})`.
        """
        assert (
            -self._size <= k < self._size
        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
        if k < 0:
            k += self._size
        pre = self[k]
        self.add(k, x - pre)

    def bisect_left(self, w: int) -> Optional[int]:
        """Binary-lifting search for the first prefix that reaches ``w``.

        Assuming every stored value is non-negative: for ``w > 0`` returns the
        smallest 0-indexed ``i`` with ``pref(i + 1) >= w`` (``n`` when the
        total is below ``w``); for ``w == 0`` returns ``None``.
        :math:`O(\\log{n})`.
        """
        i, s, _size, _tree = 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size and _tree[i + s] < w:
                w -= _tree[i + s]
                i += s
            s >>= 1
        return i if w else None

    def bisect_right(self, w: int) -> int:
        """Return the largest ``i`` in ``[0, n]`` with ``pref(i) <= w``.

        Assumes non-negative values and ``w >= 0``.  When values are counts,
        the result is the 0-indexed position holding the ``w``-th unit
        (0-indexed).
        :math:`O(\\log{n})`.
        """
        i, s, _size, _tree = 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size and _tree[i + s] <= w:
                w -= _tree[i + s]
                i += s
            s >>= 1
        return i

    def _pop(self, k: int) -> int:
        """Find the position of the ``k``-th unit and decrement it, in one descent.

        Returns the 0-indexed ``i`` with ``pref(i) <= k < pref(i + 1)`` and
        subtracts 1 from the value at ``i``.  Assumes non-negative values and
        ``0 <= k < pref(n)``; the caller is responsible for that check.
        :math:`O(\\log{n})`.
        """
        assert k >= 0
        i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size:
                if acc + _tree[i + s] <= k:
                    acc += _tree[i + s]
                    i += s
                else:
                    # The answer lies inside node i + s, so that node's range
                    # covers the position being decremented.
                    _tree[i + s] -= 1
            s >>= 1
        return i

    def tolist(self) -> list[int]:
        """Return the values as a list.
        :math:`O(n)`.
        """
        # Undo the linear build: walking indices downward, each node still
        # holds its full range sum when it is subtracted from its parent.
        n = self._size
        raw = self._tree[:]
        for i in range(n, 0, -1):
            parent = i + (i & -i)
            if parent <= n:
                raw[parent] -= raw[i]
        return raw[1:]

    @staticmethod
    def get_inversion_num(a: list[int], compress: bool = False) -> int:
        """Count pairs ``i < j`` with ``a[i] > a[j]``.
        :math:`O(n \\log{n})`.

        Args:
            a (list[int]): the sequence.
            compress (bool, optional): when ``True``, coordinate-compress the
                values first (they only need to be hashable and mutually
                comparable).  When ``False``, values are used directly as
                indices and must be non-negative integers.
        """
        inv = 0
        if compress:
            a_ = sorted(set(a))
            z = {e: i for i, e in enumerate(a_)}
            fw = FenwickTree(len(a_) + 1)
            for i, e in enumerate(a):
                # i - (#earlier values <= e) == #earlier values > e
                inv += i - fw.pref(z[e] + 1)
                fw.add(z[e], 1)
        else:
            # Size by the largest value too, so values above len(a) no longer
            # trip the bounds assertions (same size as before otherwise).
            fw = FenwickTree(max(len(a), max(a, default=0)) + 1)
            for i, e in enumerate(a):
                inv += i - fw.pref(e + 1)
                fw.add(e, 1)
        return inv

    def __str__(self):
        return str(self.tolist())

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"
from typing import Iterable, Iterator, TypeVar, Generic, Union, Optional

T = TypeVar("T", bound=SupportsLessThan)


class FenwickTreeSet(Generic[T]):
    """Ordered set over a universe of keys fixed at construction.

    Each key of the universe is mapped to an index in ``[0, size)``.  The
    per-index counts are stored in ``_cnt`` and mirrored in a
    :class:`FenwickTree`, so rank, k-th, predecessor and successor queries
    run in :math:`O(\\log{n})`.  Every ``key`` argument must belong to the
    universe.
    """

    def __init__(
        self,
        _used: Union[int, Iterable[T]],
        _a: Iterable[T] = (),
        compress=True,
        _multi=False,
    ) -> None:
        """
        Args:
            _used (Union[int, Iterable[T]]): the universe.  An ``int`` ``n``
                means ``0 .. n-1``; otherwise the distinct values of the
                iterable.
            _a (Iterable[T], optional): initial contents.
            compress (bool, optional): when ``False`` and ``_used`` is an
                ``int``, keys are used as indices directly (no dict lookup).
                Any other universe is always mapped through a dict.
            _multi (bool, optional): when ``True``, duplicates in ``_a`` are
                counted (multiset); otherwise they are collapsed.
        """
        self._len = 0  # number of stored elements, counting multiplicity
        if isinstance(_used, int):
            self._to_origin = list(range(_used))
        elif isinstance(_used, set):
            self._to_origin = sorted(_used)
        else:
            self._to_origin = sorted(set(_used))
        # key -> index.  Indexing the sorted value list by key is only the
        # identity when the universe is exactly range(n); for any other
        # universe it would map keys to wrong positions, so use a dict there.
        self._to_zaatsu: Union[dict[T, int], list[int]] = (
            {key: i for i, key in enumerate(self._to_origin)}
            if compress or not isinstance(_used, int)
            else self._to_origin
        )
        self._size = len(self._to_origin)
        self._cnt = [0] * self._size
        _a = list(_a)
        if _a:
            a_ = [0] * self._size
            if _multi:
                self._len = len(_a)
                for v in _a:
                    i = self._to_zaatsu[v]
                    a_[i] += 1
                    self._cnt[i] += 1
            else:
                for v in _a:
                    i = self._to_zaatsu[v]
                    if self._cnt[i] == 0:
                        self._len += 1
                        a_[i] = 1
                        self._cnt[i] = 1
            self._fw = FenwickTree(a_)
        else:
            self._fw = FenwickTree(self._size)

    def add(self, key: T) -> bool:
        """Insert ``key``; return ``False`` if it was already present."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return False
        self._len += 1
        self._cnt[i] = 1
        self._fw.add(i, 1)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError`` if it is not present."""
        if not self.discard(key):
            raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return whether something was removed."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            self._len -= 1
            self._cnt[i] = 0
            self._fw.add(i, -1)
            return True
        return False

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        rank = self._fw.pref(i) - 1  # rank of the largest element < key
        return None if rank < 0 else self._to_origin[self._fw.bisect_right(rank)]

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        rank = self._fw.pref(self._to_zaatsu[key]) - 1
        return None if rank < 0 else self._to_origin[self._fw.bisect_right(rank)]

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        rank = self._fw.pref(i + 1)  # rank of the smallest element > key
        return (
            None if rank >= self._len else self._to_origin[self._fw.bisect_right(rank)]
        )

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        rank = self._fw.pref(self._to_zaatsu[key] + 1)
        return (
            None if rank >= self._len else self._to_origin[self._fw.bisect_right(rank)]
        )

    def index(self, key: T) -> int:
        """Return the number of elements ``< key``."""
        return self._fw.pref(self._to_zaatsu[key])

    def index_right(self, key: T) -> int:
        """Return the number of elements ``<= key``."""
        return self._fw.pref(self._to_zaatsu[key] + 1)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (negative from the top)."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet.pop({k}), Index out of range."
        if k < 0:
            k += self._len
        self._len -= 1
        # Locate and decrement the k-th element in a single tree descent.
        x = self._fw._pop(k)
        self._cnt[x] = 0
        return self._to_origin[x]

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
        return self.pop(-1)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when empty."""
        if not self:
            return None
        return self[0]

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when empty."""
        if not self:
            return None
        return self[-1]

    def __getitem__(self, k):
        """Return the ``k``-th smallest element (negative ``k`` from the top)."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet[{k}], Index out of range."
        if k < 0:
            k += self._len
        return self._to_origin[self._fw.bisect_right(k)]

    def __iter__(self) -> Iterator[T]:
        """Iterate over the elements in ascending order.

        Returns a fresh generator on each call, so nested loops over the same
        set, or ``str(self)`` inside such a loop, no longer share a cursor.
        The legacy ``next(self)`` cursor is still reset for compatibility.
        """
        self._iter = 0
        return self._iter_ascending()

    def _iter_ascending(self) -> Iterator[T]:
        """Generator yielding the element of each rank ``0 .. len-1``."""
        _to_origin, bisect_right = self._to_origin, self._fw.bisect_right
        rank = 0
        while rank < self._len:
            yield _to_origin[bisect_right(rank)]
            rank += 1

    def __next__(self):
        # Legacy manual cursor (``iter(s); next(s)``).  for-loops use the
        # independent generator returned by __iter__ instead.
        if self._iter == self._len:
            raise StopIteration
        res = self._to_origin[self._fw.bisect_right(self._iter)]
        self._iter += 1
        return res

    def __reversed__(self):
        """Iterate over the elements in descending order."""
        _to_origin = self._to_origin
        for i in range(self._len):
            yield _to_origin[self._fw.bisect_right(self._len - i - 1)]

    def __len__(self):
        return self._len

    def __contains__(self, key: T):
        # ``key`` must belong to the universe; unknown keys raise from the lookup.
        return self._cnt[self._to_zaatsu[key]] > 0

    def __bool__(self):
        return self._len > 0

    def __str__(self):
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"
from typing import Iterable, TypeVar, Generic, Union

T = TypeVar("T")


class FenwickTreeMultiset(FenwickTreeSet, Generic[T]):
    """Ordered multiset over a fixed universe of keys, backed by a Fenwick tree.

    Per-key multiplicities live in ``_cnt`` and are mirrored in the tree, so
    the inherited rank / k-th / neighbour queries and iteration all count
    duplicates.
    """

    def __init__(
        self, used: Union[int, Iterable[T]], a: Iterable[T] = (), compress: bool = True
    ) -> None:
        """
        Args:
            used (Union[int, Iterable[T]]): the universe of keys that may be stored
            a (Iterable[T], optional): initial contents (duplicates are counted)
            compress (bool, optional): whether to coordinate-compress the keys
                (``True``: compress)
        """
        super().__init__(used, a, compress=compress, _multi=True)

    def add(self, key: T, num: int = 1) -> None:
        """Insert ``num`` copies of ``key``; ``num <= 0`` is a no-op."""
        if num <= 0:
            return
        i = self._to_zaatsu[key]
        self._len += num
        self._cnt[i] += num
        self._fw.add(i, num)

    def remove(self, key: T, num: int = 1) -> None:
        """Remove up to ``num`` copies; raise ``KeyError`` if none were removed."""
        if not self.discard(key, num):
            raise KeyError(key)

    def discard(self, key: T, num: int = 1) -> bool:
        """Remove up to ``num`` copies of ``key``.

        Returns ``True`` if at least one copy was removed.
        """
        i = self._to_zaatsu[key]
        num = min(num, self._cnt[i])
        if num <= 0:
            return False
        self._len -= num
        self._cnt[i] -= num
        self._fw.add(i, -num)
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether anything was removed."""
        return self.discard(key, num=self.count(key))

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key``."""
        return self._cnt[self._to_zaatsu[key]]

    def pop(self, k: int = -1) -> T:
        """Remove one copy of the ``k``-th smallest element and return it."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: {self.__class__.__name__}.pop({k}), len={self._len}"
        x = self[k]
        self.discard(x)
        return x

    def pop_min(self) -> T:
        """Remove one copy of the smallest element and return it."""
        assert (
            self._len > 0
        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove one copy of the largest element and return it."""
        assert (
            self._len > 0
        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
        return self.pop(-1)

    def items(self) -> Iterable[tuple[T, int]]:
        """Yield ``(key, count)`` for each distinct stored key in ascending order."""
        rank = 0
        while rank < self._len:
            # The tree lives in ``self._fw``; there is no ``_bisect_right``
            # on the set itself.
            i = self._fw.bisect_right(rank)
            cnt = self._cnt[i]
            rank += cnt  # skip every copy of this key at once
            yield self._to_origin[i], cnt

    def show(self) -> None:
        """Print the multiset as ``{key: count, ...}``."""
        print("{" + ", ".join(f"{i[0]}: {i[1]}" for i in self.items()) + "}")
仕様
- class FenwickTreeMultiset(used: int | Iterable[T], a: Iterable[T] = [], compress: bool = True) [source]
  Bases: FenwickTreeSet, Generic[T]