1# from titan_pylib.data_structures.set.fenwick_tree_set import FenwickTreeSet
2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Protocol
4
5
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with ``<``."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` compares strictly less than ``other``."""
9# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
10from typing import Union, Iterable, Optional
11
12
class FenwickTree:
    """Fenwick tree (binary indexed tree) over a sequence of integers.

    Point updates and prefix / range sums run in :math:`O(\\log{n})`.
    The public API is 0-indexed; internally ``_tree`` is 1-indexed and
    ``_tree[0]`` is unused.
    """

    def __init__(self, n_or_a: Union[Iterable[int], int]):
        """Build the tree.
        :math:`O(n)`.

        Args:
            n_or_a (Union[Iterable[int], int]): If ``n_or_a`` is an ``int``,
                build a tree of length ``n`` with every value ``0``.
                If it is an ``Iterable``, build a tree whose initial values are
                its elements (a list argument is copied, not modified).
        """
        if isinstance(n_or_a, int):
            self._size = n_or_a
            self._tree = [0] * (self._size + 1)
        else:
            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
            _size = len(a)
            _tree = [0] + a
            # Linear-time build: add each node into the next node that covers it.
            for i in range(1, _size):
                if i + (i & -i) <= _size:
                    _tree[i + (i & -i)] += _tree[i]
            self._size = _size
            self._tree = _tree
        # A power of two no smaller than n: the first step of the
        # binary-lifting descent used by bisect_left / bisect_right / _pop.
        self._s = 1 << (self._size - 1).bit_length()

    def pref(self, r: int) -> int:
        """Return the sum over ``[0, r)``.
        :math:`O(\\log{n})`.
        """
        assert (
            0 <= r <= self._size
        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
        ret, _tree = 0, self._tree
        while r > 0:
            ret += _tree[r]
            r &= r - 1  # drop the lowest set bit: move to the previous disjoint node
        return ret

    def suff(self, l: int) -> int:
        """Return the sum over ``[l, n)``.

        ``l == n`` denotes the empty suffix and yields ``0``.
        :math:`O(\\log{n})`.
        """
        assert (
            0 <= l <= self._size
        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
        return self.sum(l, self._size)

    def sum(self, l: int, r: int) -> int:
        """Return the sum over ``[l, r)``.
        :math:`O(\\log{n})`.
        """
        assert (
            0 <= l <= r <= self._size
        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
        _tree = self._tree
        res = 0
        # Walk the prefix chains of r and l only until they meet; the nodes
        # shared by pref(r) and pref(l) would cancel, so they are skipped.
        while r > l:
            res += _tree[r]
            r &= r - 1
        while l > r:
            res -= _tree[l]
            l &= l - 1
        return res

    prod = sum

    def __getitem__(self, k: int) -> int:
        """Return the value at position ``k`` (negative ``k`` counts from the end).
        :math:`O(\\log{n})`.
        """
        assert (
            -self._size <= k < self._size
        ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
        if k < 0:
            k += self._size
        return self.sum(k, k + 1)

    def add(self, k: int, x: int) -> None:
        """Add ``x`` to the value at position ``k``.
        :math:`O(\\log{n})`.
        """
        assert (
            0 <= k < self._size
        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
        k += 1
        _tree = self._tree
        while k <= self._size:
            _tree[k] += x
            k += k & -k  # move to the next node that also covers position k
        
    def __setitem__(self, k: int, x: int):
        """Set the value at position ``k`` to ``x`` (negative ``k`` counts from the end).
        :math:`O(\\log{n})`.
        """
        assert (
            -self._size <= k < self._size
        ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
        if k < 0:
            k += self._size
        pre = self[k]
        self.add(k, x - pre)

    def bisect_left(self, w: int) -> Optional[int]:
        """Return the smallest position ``i`` with ``pref(i + 1) >= w``.

        Assumes every value is non-negative. Returns ``n`` when the total is
        less than ``w``, and ``None`` when ``w == 0``.
        :math:`O(\\log{n})`.
        """
        i, s, _size, _tree = 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size and _tree[i + s] < w:
                w -= _tree[i + s]
                i += s
            s >>= 1
        return i if w else None

    def bisect_right(self, w: int) -> int:
        """Return the largest ``i`` in ``[0, n]`` with ``pref(i) <= w``.

        Assumes every value is non-negative. With 0/1 values and
        ``0 <= w < pref(n)``, this is the position of the ``w``-th
        (0-indexed) position holding a 1.
        :math:`O(\\log{n})`.
        """
        i, s, _size, _tree = 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size and _tree[i + s] <= w:
                w -= _tree[i + s]
                i += s
            s >>= 1
        return i

    def _pop(self, k: int) -> int:
        """Decrement the value at position ``bisect_right(k)`` by one and return it.

        The search and the update share a single descent. Assumes
        non-negative values and ``0 <= k < pref(n)``.
        :math:`O(\\log{n})`.
        """
        assert k >= 0
        i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
        while s:
            if i + s <= _size:
                if acc + _tree[i + s] <= k:
                    acc += _tree[i + s]
                    i += s
                else:
                    # The target position lies inside node i + s, so that
                    # node's partial sum loses one.
                    _tree[i + s] -= 1
            s >>= 1
        return i

    def tolist(self) -> list[int]:
        """Return the current values as a list.
        :math:`O(n)`.
        """
        _size = self._size
        a = self._tree[:]
        # Undo the linear-time build in reverse order: subtract every child
        # node from the parent that covers it. When i is processed, a[i] still
        # holds the original node value because its own children are < i.
        for i in range(_size, 0, -1):
            j = i + (i & -i)
            if j <= _size:
                a[j] -= a[i]
        return a[1:]

    @staticmethod
    def get_inversion_num(a: list[int], compress: bool = False) -> int:
        """Return the number of inversions: pairs ``i < j`` with ``a[i] > a[j]``.

        Args:
            a (list[int]): The sequence.
            compress (bool): If ``True``, coordinate-compress ``a`` first, so any
                mutually comparable hashable values are accepted. If ``False``,
                every value must be a non-negative ``int``; the tree is sized
                to cover ``max(a)``.
        """
        inv = 0
        if compress:
            a_ = sorted(set(a))
            z = {e: i for i, e in enumerate(a_)}
            fw = FenwickTree(len(a_) + 1)
            for i, e in enumerate(a):
                inv += i - fw.pref(z[e] + 1)
                fw.add(z[e], 1)
        else:
            # Cover the largest value, not just len(a), so values > len(a) work.
            fw = FenwickTree(max(len(a), max(a, default=0)) + 1)
            for i, e in enumerate(a):
                # i earlier elements, pref(e + 1) of which are <= e.
                inv += i - fw.pref(e + 1)
                fw.add(e, 1)
        return inv

    def __str__(self):
        return str(self.tolist())

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"
175from typing import Iterable, TypeVar, Generic, Union, Optional
176
# Element type stored in FenwickTreeSet; bound to SupportsLessThan because the
# universe of keys is sorted (ordered with ``<``) when the set is built.
T = TypeVar("T", bound=SupportsLessThan)
178
179
class FenwickTreeSet(Generic[T]):
    """Ordered set over a fixed universe of keys, backed by a FenwickTree.

    Every key that may ever be stored must be supplied up front via ``_used``.
    Keys are mapped to positions ``0..size-1`` in sorted order, and a
    FenwickTree over the per-position counts answers rank, k-th element,
    predecessor and successor queries in :math:`O(\\log{n})`.
    """

    def __init__(
        self,
        _used: Union[int, Iterable[T]],
        _a: Iterable[T] = (),
        compress=True,
        _multi=False,
    ) -> None:
        """Build the set.

        Args:
            _used (Union[int, Iterable[T]]): The universe. An ``int`` ``n``
                means the keys ``0, 1, ..., n-1``; otherwise every distinct
                element of the iterable is a valid key.
            _a (Iterable[T]): Initial elements; each must be in the universe.
            compress (bool): If ``True``, keys are mapped to positions through
                a dict. If ``False``, a key is used directly as an index into
                the sorted universe, which only makes sense for ``0..n-1``.
            _multi (bool): If ``True``, every occurrence in ``_a`` is counted;
                otherwise duplicates in ``_a`` are stored once.
                NOTE(review): the methods below treat keys as present/absent,
                so ``_multi=True`` presumably targets a subclass — confirm.
        """
        self._len = 0
        # Legacy cursor used only by __next__ (see __iter__).
        self._iter = 0
        if isinstance(_used, int):
            self._to_origin = list(range(_used))
        elif isinstance(_used, set):
            self._to_origin = sorted(_used)
        else:
            self._to_origin = sorted(set(_used))
        self._to_zaatsu: Union[dict[T, int], list] = (
            {key: i for i, key in enumerate(self._to_origin)}
            if compress
            else self._to_origin
        )
        self._size = len(self._to_origin)
        self._cnt = [0] * self._size
        _a = list(_a)
        if _a:
            a_ = [0] * self._size
            if _multi:
                self._len = len(_a)
                for v in _a:
                    i = self._to_zaatsu[v]
                    a_[i] += 1
                    self._cnt[i] += 1
            else:
                for v in _a:
                    i = self._to_zaatsu[v]
                    if self._cnt[i] == 0:
                        self._len += 1
                        a_[i] = 1
                        self._cnt[i] = 1
            self._fw = FenwickTree(a_)
        else:
            self._fw = FenwickTree(self._size)

    def _find(self, key) -> Optional[int]:
        """Return the position of ``key`` in the universe, or ``None`` if it is not part of it."""
        to_zaatsu = self._to_zaatsu
        if isinstance(to_zaatsu, dict):
            return to_zaatsu.get(key)
        # compress=False: the key itself is used as a list index, so reject
        # out-of-range values instead of letting negative ones wrap around.
        if 0 <= key < self._size:
            return to_zaatsu[key]
        return None

    def add(self, key: T) -> bool:
        """Insert ``key``; return ``True`` if it was newly added.

        Raises ``KeyError`` if ``key`` is not in the universe (dict mapping).
        """
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return False
        self._len += 1
        self._cnt[i] = 1
        self._fw.add(i, 1)
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError`` if it is not present."""
        if not self.discard(key):
            raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return whether something was removed.

        Like ``set.discard``, keys outside the universe are simply not present.
        """
        i = self._find(key)
        if i is None or not self._cnt[i]:
            return False
        self._len -= 1
        self._cnt[i] = 0
        self._fw.add(i, -1)
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None`` if there is none."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        # pref(i) elements are < key, so the predecessor has rank pref(i) - 1.
        pref = self._fw.pref(i) - 1
        return None if pref < 0 else self._to_origin[self._fw.bisect_right(pref)]

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None`` if there is none."""
        pref = self._fw.pref(self._to_zaatsu[key]) - 1
        return None if pref < 0 else self._to_origin[self._fw.bisect_right(pref)]

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None`` if there is none."""
        i = self._to_zaatsu[key]
        if self._cnt[i]:
            return key
        # pref(i + 1) elements are <= key, so that is the successor's rank.
        pref = self._fw.pref(i + 1)
        return (
            None if pref >= self._len else self._to_origin[self._fw.bisect_right(pref)]
        )

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None`` if there is none."""
        pref = self._fw.pref(self._to_zaatsu[key] + 1)
        return (
            None if pref >= self._len else self._to_origin[self._fw.bisect_right(pref)]
        )

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        return self._fw.pref(self._to_zaatsu[key])

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        return self._fw.pref(self._to_zaatsu[key] + 1)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet.pop({k}), Index out of range."
        if k < 0:
            k += self._len
        self._len -= 1
        x = self._fw._pop(k)
        self._cnt[x] = 0
        return self._to_origin[x]

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
        return self.pop(0)

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        assert (
            self._len > 0
        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
        return self.pop(-1)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` if the set is empty."""
        if not self:
            return None
        return self[0]

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` if the set is empty."""
        if not self:
            return None
        return self[-1]

    def __getitem__(self, k):
        """Return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -self._len <= k < self._len
        ), f"IndexError: FenwickTreeSet[{k}], Index out of range."
        if k < 0:
            k += self._len
        return self._to_origin[self._fw.bisect_right(k)]

    def __iter__(self):
        """Iterate over the elements in ascending order.

        Each call returns an independent iterator, so nested loops over the
        same set (or ``str(self)`` inside a loop over it) work correctly.
        """
        self._iter = 0  # keeps the legacy ``iter(s); next(s)`` pattern working
        return self._iter_ascending()

    def _iter_ascending(self):
        """Yield the elements in ascending order, one rank at a time."""
        to_origin, fw = self._to_origin, self._fw
        k = 0
        while k < self._len:
            yield to_origin[fw.bisect_right(k)]
            k += 1

    def __next__(self):
        """Legacy iterator step driven by the shared ``_iter`` cursor.

        Kept for backward compatibility with code that calls ``next()`` on the
        set object itself; ``for`` loops use the iterator from ``__iter__``.
        """
        if self._iter == self._len:
            raise StopIteration
        res = self._to_origin[self._fw.bisect_right(self._iter)]
        self._iter += 1
        return res

    def __reversed__(self):
        """Iterate over the elements in descending order."""
        _to_origin = self._to_origin
        for i in range(self._len):
            yield _to_origin[self._fw.bisect_right(self._len - i - 1)]

    def __len__(self):
        return self._len

    def __contains__(self, key: T):
        """Return whether ``key`` is stored; keys outside the universe are never stored."""
        i = self._find(key)
        return i is not None and self._cnt[i] > 0

    def __bool__(self):
        return self._len > 0

    def __str__(self):
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self})"