fenwick_tree_multiset

ソースコード

from titan_pylib.data_structures.set.fenwick_tree_multiset import FenwickTreeMultiset

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.set.fenwick_tree_multiset import FenwickTreeMultiset
  2# from titan_pylib.data_structures.set.fenwick_tree_set import FenwickTreeSet
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
 11from typing import Union, Iterable, Optional
 12
 13
 14class FenwickTree:
 15    """FenwickTreeです。"""
 16
 17    def __init__(self, n_or_a: Union[Iterable[int], int]):
 18        """構築します。
 19        :math:`O(n)` です。
 20
 21        Args:
 22          n_or_a (Union[Iterable[int], int]): `n_or_a` が `int` のとき、初期値 `0` 、長さ `n` で構築します。
 23                                              `n_or_a` が `Iterable` のとき、初期値 `a` で構築します。
 24        """
 25        if isinstance(n_or_a, int):
 26            self._size = n_or_a
 27            self._tree = [0] * (self._size + 1)
 28        else:
 29            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
 30            _size = len(a)
 31            _tree = [0] + a
 32            for i in range(1, _size):
 33                if i + (i & -i) <= _size:
 34                    _tree[i + (i & -i)] += _tree[i]
 35            self._size = _size
 36            self._tree = _tree
 37        self._s = 1 << (self._size - 1).bit_length()
 38
 39    def pref(self, r: int) -> int:
 40        """区間 ``[0, r)`` の総和を返します。
 41        :math:`O(\\log{n})` です。
 42        """
 43        assert (
 44            0 <= r <= self._size
 45        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
 46        ret, _tree = 0, self._tree
 47        while r > 0:
 48            ret += _tree[r]
 49            r &= r - 1
 50        return ret
 51
 52    def suff(self, l: int) -> int:
 53        """区間 ``[l, n)`` の総和を返します。
 54        :math:`O(\\log{n})` です。
 55        """
 56        assert (
 57            0 <= l < self._size
 58        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
 59        return self.pref(self._size) - self.pref(l)
 60
 61    def sum(self, l: int, r: int) -> int:
 62        """区間 ``[l, r)`` の総和を返します。
 63        :math:`O(\\log{n})` です。
 64        """
 65        assert (
 66            0 <= l <= r <= self._size
 67        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
 68        _tree = self._tree
 69        res = 0
 70        while r > l:
 71            res += _tree[r]
 72            r &= r - 1
 73        while l > r:
 74            res -= _tree[l]
 75            l &= l - 1
 76        return res
 77
 78    prod = sum
 79
 80    def __getitem__(self, k: int) -> int:
 81        """位置 ``k`` の要素を返します。
 82        :math:`O(\\log{n})` です。
 83        """
 84        assert (
 85            -self._size <= k < self._size
 86        ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
 87        if k < 0:
 88            k += self._size
 89        return self.sum(k, k + 1)
 90
 91    def add(self, k: int, x: int) -> None:
 92        """``k`` 番目の値に ``x`` を加えます。
 93        :math:`O(\\log{n})` です。
 94        """
 95        assert (
 96            0 <= k < self._size
 97        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
 98        k += 1
 99        _tree = self._tree
100        while k <= self._size:
101            _tree[k] += x
102            k += k & -k
103
104    def __setitem__(self, k: int, x: int):
105        """``k`` 番目の値を ``x`` に更新します。
106        :math:`O(\\log{n})` です。
107        """
108        assert (
109            -self._size <= k < self._size
110        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
111        if k < 0:
112            k += self._size
113        pre = self[k]
114        self.add(k, x - pre)
115
116    def bisect_left(self, w: int) -> Optional[int]:
117        i, s, _size, _tree = 0, self._s, self._size, self._tree
118        while s:
119            if i + s <= _size and _tree[i + s] < w:
120                w -= _tree[i + s]
121                i += s
122            s >>= 1
123        return i if w else None
124
125    def bisect_right(self, w: int) -> int:
126        i, s, _size, _tree = 0, self._s, self._size, self._tree
127        while s:
128            if i + s <= _size and _tree[i + s] <= w:
129                w -= _tree[i + s]
130                i += s
131            s >>= 1
132        return i
133
134    def _pop(self, k: int) -> int:
135        assert k >= 0
136        i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
137        while s:
138            if i + s <= _size:
139                if acc + _tree[i + s] <= k:
140                    acc += _tree[i + s]
141                    i += s
142                else:
143                    _tree[i + s] -= 1
144            s >>= 1
145        return i
146
147    def tolist(self) -> list[int]:
148        """リストにして返します。
149        :math:`O(n)` です。
150        """
151        sub = [self.pref(i) for i in range(self._size + 1)]
152        return [sub[i + 1] - sub[i] for i in range(self._size)]
153
154    @staticmethod
155    def get_inversion_num(a: list[int], compress: bool = False) -> int:
156        inv = 0
157        if compress:
158            a_ = sorted(set(a))
159            z = {e: i for i, e in enumerate(a_)}
160            fw = FenwickTree(len(a_) + 1)
161            for i, e in enumerate(a):
162                inv += i - fw.pref(z[e] + 1)
163                fw.add(z[e], 1)
164        else:
165            fw = FenwickTree(len(a) + 1)
166            for i, e in enumerate(a):
167                inv += i - fw.pref(e + 1)
168                fw.add(e, 1)
169        return inv
170
171    def __str__(self):
172        return str(self.tolist())
173
174    def __repr__(self):
175        return f"{self.__class__.__name__}({self})"
176from typing import Iterable, TypeVar, Generic, Union, Optional
177
# Key type for FenwickTreeSet.  Keys must support ``<`` because the set's
# constructor sorts its universe of keys.
T = TypeVar("T", bound=SupportsLessThan)
179
180
class FenwickTreeSet(Generic[T]):
    """Ordered set over a fixed, pre-declared universe of keys.

    The universe is sorted and each key is mapped to its rank ("zaatsu",
    i.e. coordinate compression).  ``self._cnt[i]`` holds the multiplicity of
    the i-th key, and a FenwickTree holds the same counts to answer rank and
    k-th-element queries in O(log n).
    """

    def __init__(
        self,
        _used: Union[int, Iterable[T]],
        _a: Iterable[T] = (),
        compress=True,
        _multi=False,
    ) -> None:
        """
        Args:
          _used: Universe of keys that may ever be stored.  An ``int`` ``n``
            means the keys ``0..n-1``.
          _a: Initial elements; every element must belong to ``_used``.
          compress: If ``True``, keys are mapped to indices via a dict.  If
            ``False``, the sorted key list itself is used as the mapping, so
            key ``k`` maps to ``sorted_keys[k]``.  That is only correct when
            the keys are exactly ``0..n-1``.
          _multi: If ``True``, duplicates in ``_a`` are counted (used by
            multiset subclasses); otherwise each key is stored at most once.
        """
        self._len = 0
        # Cursor for the legacy ``__next__`` protocol; see ``__next__``.
        self._iter = 0
        if isinstance(_used, int):
            self._to_origin = list(range(_used))
        elif isinstance(_used, set):
            self._to_origin = sorted(_used)
        else:
            self._to_origin = sorted(set(_used))
        # dict key -> index when compressing, otherwise the key list itself.
        self._to_zaatsu: dict[T, int] = (
            {key: i for i, key in enumerate(self._to_origin)}
            if compress
            else self._to_origin
        )
        self._size = len(self._to_origin)
        self._cnt = [0] * self._size
        _a = list(_a)
        if _a:
            a_ = [0] * self._size
            if _multi:
                self._len = len(_a)
                for v in _a:
                    i = self._to_zaatsu[v]
                    a_[i] += 1
                    self._cnt[i] += 1
            else:
                for v in _a:
                    i = self._to_zaatsu[v]
                    if self._cnt[i] == 0:
                        self._len += 1
                        a_[i] = 1
                        self._cnt[i] = 1
            self._fw = FenwickTree(a_)
        else:
            self._fw = FenwickTree(self._size)

    def add(self, key: T) -> bool:
        """Insert ``key``.  Return ``True`` if it was newly added."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return False
        self._len += 1
        self._cnt[i] = 1
        self._fw.add(i, 1)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError`` if it is not present."""
        if not self.discard(key):
            raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.  Return whether it was removed."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            self._len -= 1
            self._cnt[i] = 0
            self._fw.add(i, -1)
            return True
        return False

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored element ``<= key``, or ``None``."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        pref = self._fw.pref(i) - 1  # rank of the last element below key
        return None if pref < 0 else self._to_origin[self._fw.bisect_right(pref)]

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored element ``< key``, or ``None``."""
        pref = self._fw.pref(self._to_zaatsu[key]) - 1
        return None if pref < 0 else self._to_origin[self._fw.bisect_right(pref)]

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored element ``>= key``, or ``None``."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        pref = self._fw.pref(i + 1)  # rank of the first element above key
        return (
            None if pref >= self._len else self._to_origin[self._fw.bisect_right(pref)]
        )

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored element ``> key``, or ``None``."""
        pref = self._fw.pref(self._to_zaatsu[key] + 1)
        return (
            None if pref >= self._len else self._to_origin[self._fw.bisect_right(pref)]
        )

    def index(self, key: T) -> int:
        """Return the number of stored elements ``< key``."""
        return self._fw.pref(self._to_zaatsu[key])

    def index_right(self, key: T) -> int:
        """Return the number of stored elements ``<= key``."""
        return self._fw.pref(self._to_zaatsu[key] + 1)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (negative ``k`` allowed)."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet.pop({k}), Index out of range."
        if k < 0:
            k += self._len
        self._len -= 1
        # _pop locates the k-th element and decrements the tree in one pass.
        x = self._fw._pop(k)
        self._cnt[x] = 0
        return self._to_origin[x]

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
        return self.pop(-1)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` if empty."""
        if not self:
            return
        return self[0]

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` if empty."""
        if not self:
            return
        return self[-1]

    def __getitem__(self, k):
        """Return the ``k``-th smallest element (negative ``k`` allowed)."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet[{k}], Index out of range."
        if k < 0:
            k += self._len
        return self._to_origin[self._fw.bisect_right(k)]

    def __iter__(self) -> Iterable[T]:
        """Yield elements in ascending order.

        Each call returns an independent generator, so nested loops or
        ``str(self)`` inside a loop over ``self`` do not interfere.
        """
        _to_origin, bisect_right = self._to_origin, self._fw.bisect_right
        for k in range(self._len):
            yield _to_origin[bisect_right(k)]

    def __next__(self):
        # Legacy instance-level cursor, kept for backward compatibility.
        # ``for`` loops go through ``__iter__`` above and never touch it.
        if self._iter == self._len:
            raise StopIteration
        res = self._to_origin[self._fw.bisect_right(self._iter)]
        self._iter += 1
        return res

    def __reversed__(self):
        """Yield elements in descending order."""
        _to_origin = self._to_origin
        for i in range(self._len):
            yield _to_origin[self._fw.bisect_right(self._len - i - 1)]

    def __len__(self):
        return self._len

    def __contains__(self, key: T):
        return self._cnt[self._to_zaatsu[key]] > 0

    def __bool__(self):
        return self._len > 0

    def __str__(self):
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"
349from typing import Iterable, TypeVar, Generic, Union
350
# Re-binds ``T`` as the key type for FenwickTreeMultiset.
# NOTE(review): this TypeVar has no bound, but the inherited constructor
# sorts the key universe, so keys must still support ``<``.
T = TypeVar("T")
352
353
class FenwickTreeMultiset(FenwickTreeSet, Generic[T]):
    """Ordered multiset over a fixed, pre-declared universe of keys.

    Extends FenwickTreeSet so that each key may be stored several times.
    ``self._cnt[i]`` is the multiplicity of the i-th key, and the underlying
    FenwickTree holds the same counts for rank / k-th queries.
    """

    def __init__(
        self, used: Union[int, Iterable[T]], a: Iterable[T] = (), compress: bool = True
    ) -> None:
        """
        Args:
          used (Union[int, Iterable[T]]): Universe of keys that may ever be
            stored.  An ``int`` ``n`` means ``0..n-1``.
          a (Iterable[T], optional): Initial elements; duplicates are counted.
          compress (bool, optional): Whether to coordinate-compress ``used``
            (``True``: compress).
        """
        super().__init__(used, a, compress=compress, _multi=True)

    def add(self, key: T, num: int = 1) -> None:
        """Insert ``num`` copies of ``key``.  A non-positive ``num`` is a no-op."""
        if num <= 0:
            return
        i = self._to_zaatsu[key]
        self._len += num
        self._cnt[i] += num
        self._fw.add(i, num)

    def remove(self, key: T, num: int = 1) -> None:
        """Remove up to ``num`` copies of ``key``.

        Raises:
          KeyError: if nothing was removed (``key`` absent or ``num <= 0``).
        """
        if not self.discard(key, num):
            raise KeyError(key)

    def discard(self, key: T, num: int = 1) -> bool:
        """Remove up to ``num`` copies of ``key``.

        ``num`` is clamped to the current count.  Return ``True`` if at least
        one copy was removed.
        """
        i = self._to_zaatsu[key]
        num = min(num, self._cnt[i])
        if num <= 0:
            return False
        self._len -= num
        self._cnt[i] -= num
        self._fw.add(i, -num)
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``.  Return whether anything was removed."""
        return self.discard(key, num=self.count(key))

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key``."""
        return self._cnt[self._to_zaatsu[key]]

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the ``k``-th smallest element."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: {self.__class__.__name__}.pop({k}), len={self._len}"
        x = self[k]
        self.discard(x)
        return x

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
        return self.pop(-1)

    def items(self) -> Iterable[tuple[T, int]]:
        """Yield ``(key, count)`` for each distinct stored key, in ascending order."""
        _to_origin, _cnt = self._to_origin, self._cnt
        bisect_right = self._fw.bisect_right
        rank = 0
        while rank < self._len:
            # Index of the element at this rank; all its copies are skipped at once.
            i = bisect_right(rank)
            cnt = _cnt[i]
            rank += cnt
            yield _to_origin[i], cnt

    def show(self) -> None:
        """Print the multiset as ``{key: count, ...}``."""
        print("{" + ", ".join(f"{i[0]}: {i[1]}" for i in self.items()) + "}")

仕様

class FenwickTreeMultiset(used: int | Iterable[T], a: Iterable[T] = [], compress: bool = True)[source]

Bases: FenwickTreeSet, Generic[T]

add(key: T, num: int = 1) → None [source]
count(key: T) → int [source]
discard(key: T, num: int = 1) → bool [source]
discard_all(key: T) → bool [source]
items() → Iterable[tuple[T, int]] [source]
pop(k: int = -1) → T [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T, num: int = 1) → None [source]
show() → None [source]