Source code for titan_pylib.data_structures.set.fenwick_tree_multiset
from typing import Generic, Iterable, Iterator, Optional, TypeVar, Union

from titan_pylib.data_structures.set.fenwick_tree_set import FenwickTreeSet
3
# Element type stored in the multiset (used to parameterize FenwickTreeMultiset).
T = TypeVar("T")
5
6
[docs]
class FenwickTreeMultiset(FenwickTreeSet, Generic[T]):
    """Multiset of keys drawn from a fixed universe, backed by a Fenwick tree.

    Each key of the universe ``used`` maps to a slot index via
    ``self._to_zaatsu`` (and back via ``self._to_origin``). ``self._cnt[i]``
    holds the multiplicity of slot ``i``, and every change to it is mirrored
    into the Fenwick tree ``self._fw`` with the same delta. ``self._len`` is
    the total number of stored elements, counting multiplicity.

    NOTE(review): positional access (``self[k]``, ``self._bisect_right``)
    comes from :class:`FenwickTreeSet` and is assumed to account for
    multiplicities via ``self._fw`` — confirm against the parent class.
    """

    def __init__(
        self,
        used: Union[int, Iterable[T]],
        a: Optional[Iterable[T]] = None,
        compress: bool = True,
    ) -> None:
        """
        Args:
            used (Union[int, Iterable[T]]): The universe of elements that may
                be stored. (Presumably an ``int`` ``n`` means ``0..n-1`` —
                see :class:`FenwickTreeSet`.)
            a (Optional[Iterable[T]], optional): Initial elements. ``None``
                (the default) means the multiset starts empty.
            compress (bool, optional): Whether to coordinate-compress
                ``used`` (``True``: compress).
        """
        # Build a fresh list per call: a ``[]`` default would be one shared
        # object handed to the parent by every instance.
        super().__init__(
            used, [] if a is None else a, compress=compress, _multi=True
        )

    def add(self, key: T, num: int = 1) -> None:
        """Insert ``num`` copies of ``key``; a no-op when ``num <= 0``.

        ``key`` must belong to the universe ``used``; otherwise the
        ``_to_zaatsu`` lookup fails (presumably with ``KeyError``).
        """
        if num <= 0:
            return
        i = self._to_zaatsu[key]
        self._len += num
        self._cnt[i] += num
        # Keep the Fenwick tree in sync with the per-slot counts.
        self._fw.add(i, num)

    def remove(self, key: T, num: int = 1) -> None:
        """Remove up to ``num`` copies of ``key``.

        Like :meth:`discard`, the amount removed is clamped to the current
        multiplicity, so asking for more copies than exist removes all of
        them without error.

        Raises:
            KeyError: if no copy was removed (``key`` absent or ``num <= 0``).
        """
        if not self.discard(key, num):
            raise KeyError(key)

    def discard(self, key: T, num: int = 1) -> bool:
        """Remove up to ``num`` copies of ``key``.

        Returns:
            ``True`` if at least one copy was removed, ``False`` otherwise
            (``key`` has multiplicity 0, or ``num <= 0``).
        """
        i = self._to_zaatsu[key]
        # Never remove more copies than are present.
        num = min(num, self._cnt[i])
        if num <= 0:
            return False
        self._len -= num
        self._cnt[i] -= num
        self._fw.add(i, -num)
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return ``True`` if any was removed."""
        return self.discard(key, num=self.count(key))

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if it is not stored)."""
        return self._cnt[self._to_zaatsu[key]]

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the ``k``-th element (0-indexed).

        Negative ``k`` counts from the end, so ``pop(0)`` removes the minimum
        and ``pop(-1)`` (the default) the maximum. Only one copy is removed.

        Note:
            The bounds check is an ``assert`` and is skipped under ``python -O``.
        """
        assert (
            -self._len <= k < self._len
        ), f"IndexError: {self.__class__.__name__}.pop({k}), len={self._len}"
        x = self[k]
        self.discard(x)
        return x

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
        return self.pop(-1)

    def items(self) -> Iterator[tuple[T, int]]:
        """Yield ``(key, multiplicity)`` for every stored key, in slot order.

        Walks the multiset by element position: ``_bisect_right(pos)`` is used
        to find the slot holding the ``pos``-th element, and the walk then
        jumps past all copies of that key at once.
        """
        pos = 0
        while pos < self._len:
            i = self._bisect_right(pos)
            # Read the slot's count directly rather than mapping the key back
            # through ``_to_zaatsu`` via ``count()``; this is the same counter
            # that add/discard keep in sync with the Fenwick weight of slot i.
            cnt = self._cnt[i]
            pos += cnt
            yield self._to_origin[i], cnt

    def show(self) -> None:
        """Print the multiset as ``{key: multiplicity, ...}``."""
        print("{" + ", ".join(f"{i[0]}: {i[1]}" for i in self.items()) + "}")
[docs]
68 def items(self) -> Iterable[tuple[T, int]]:
69 _iter = 0
70 while _iter < self._len:
71 res = self._to_origin[self._bisect_right(_iter)]
72 cnt = self.count(res)
73 _iter += cnt
74 yield res, cnt
75
[docs]
76 def show(self) -> None:
77 print("{" + ", ".join(f"{i[0]}: {i[1]}" for i in self.items()) + "}")