1# from titan_pylib.data_structures.wavelet_matrix.fenwick_tree_wavelet_matrix import FenwickTreeWaveletMatrix
2# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
3# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
4# BitVectorInterface,
5# )
6from abc import ABC, abstractmethod
7
8
9class BitVectorInterface(ABC):
10
11 @abstractmethod
12 def access(self, k: int) -> int:
13 raise NotImplementedError
14
15 @abstractmethod
16 def __getitem__(self, k: int) -> int:
17 raise NotImplementedError
18
19 @abstractmethod
20 def rank0(self, r: int) -> int:
21 raise NotImplementedError
22
23 @abstractmethod
24 def rank1(self, r: int) -> int:
25 raise NotImplementedError
26
27 @abstractmethod
28 def rank(self, r: int, v: int) -> int:
29 raise NotImplementedError
30
31 @abstractmethod
32 def select0(self, k: int) -> int:
33 raise NotImplementedError
34
35 @abstractmethod
36 def select1(self, k: int) -> int:
37 raise NotImplementedError
38
39 @abstractmethod
40 def select(self, k: int, v: int) -> int:
41 raise NotImplementedError
42
43 @abstractmethod
44 def __len__(self) -> int:
45 raise NotImplementedError
46
47 @abstractmethod
48 def __str__(self) -> str:
49 raise NotImplementedError
50
51 @abstractmethod
52 def __repr__(self) -> str:
53 raise NotImplementedError
54from array import array
55
56
57class BitVector(BitVectorInterface):
58 """コンパクトな bit vector です。"""
59
60 def __init__(self, n: int):
61 """長さ ``n`` の ``BitVector`` です。
62
63 bit を保持するのに ``array[I]`` を使用します。
64 ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
65
66 累積和を保持するのに同様の ``array[I]`` を使用します。
67 32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
68 """
69 assert 0 <= n < 4294967295
70 self.N = n
71 self.block_size = (n + 31) >> 5
72 b = bytes(4 * (self.block_size + 1))
73 self.bit = array("I", b)
74 self.acc = array("I", b)
75
76 @staticmethod
77 def _popcount(x: int) -> int:
78 x = x - ((x >> 1) & 0x55555555)
79 x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
80 x = x + (x >> 4) & 0x0F0F0F0F
81 x += x >> 8
82 x += x >> 16
83 return x & 0x0000007F
84
85 def set(self, k: int) -> None:
86 """``k`` 番目の bit を ``1`` にします。
87 :math:`O(1)` です。
88
89 Args:
90 k (int): インデックスです。
91 """
92 self.bit[k >> 5] |= 1 << (k & 31)
93
94 def build(self) -> None:
95 """構築します。
96 **これ以降 ``set`` メソッドを使用してはいけません。**
97 :math:`O(n)` です。
98 """
99 acc, bit = self.acc, self.bit
100 for i in range(self.block_size):
101 acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
102
103 def access(self, k: int) -> int:
104 """``k`` 番目の bit を返します。
105 :math:`O(1)` です。
106 """
107 return (self.bit[k >> 5] >> (k & 31)) & 1
108
109 def __getitem__(self, k: int) -> int:
110 return (self.bit[k >> 5] >> (k & 31)) & 1
111
112 def rank0(self, r: int) -> int:
113 """``a[0, r)`` に含まれる ``0`` の個数を返します。
114 :math:`O(1)` です。
115 """
116 return r - (
117 self.acc[r >> 5]
118 + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
119 )
120
121 def rank1(self, r: int) -> int:
122 """``a[0, r)`` に含まれる ``1`` の個数を返します。
123 :math:`O(1)` です。
124 """
125 return self.acc[r >> 5] + BitVector._popcount(
126 self.bit[r >> 5] & ((1 << (r & 31)) - 1)
127 )
128
129 def rank(self, r: int, v: int) -> int:
130 """``a[0, r)`` に含まれる ``v`` の個数を返します。
131 :math:`O(1)` です。
132 """
133 return self.rank1(r) if v else self.rank0(r)
134
135 def select0(self, k: int) -> int:
136 """``k`` 番目の ``0`` のインデックスを返します。
137 :math:`O(\\log{n})` です。
138 """
139 if k < 0 or self.rank0(self.N) <= k:
140 return -1
141 l, r = 0, self.block_size + 1
142 while r - l > 1:
143 m = (l + r) >> 1
144 if m * 32 - self.acc[m] > k:
145 r = m
146 else:
147 l = m
148 indx = 32 * l
149 k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
150 l, r = indx, indx + 32
151 while r - l > 1:
152 m = (l + r) >> 1
153 if self.rank0(m) > k:
154 r = m
155 else:
156 l = m
157 return l
158
159 def select1(self, k: int) -> int:
160 """``k`` 番目の ``1`` のインデックスを返します。
161 :math:`O(\\log{n})` です。
162 """
163 if k < 0 or self.rank1(self.N) <= k:
164 return -1
165 l, r = 0, self.block_size + 1
166 while r - l > 1:
167 m = (l + r) >> 1
168 if self.acc[m] > k:
169 r = m
170 else:
171 l = m
172 indx = 32 * l
173 k = k - self.acc[l] + self.rank1(indx)
174 l, r = indx, indx + 32
175 while r - l > 1:
176 m = (l + r) >> 1
177 if self.rank1(m) > k:
178 r = m
179 else:
180 l = m
181 return l
182
183 def select(self, k: int, v: int) -> int:
184 """``k`` 番目の ``v`` のインデックスを返します。
185 :math:`O(\\log{n})` です。
186 """
187 return self.select1(k) if v else self.select0(k)
188
189 def __len__(self):
190 return self.N
191
192 def __str__(self):
193 return str([self.access(i) for i in range(self.N)])
194
195 def __repr__(self):
196 return f"{self.__class__.__name__}({self})"
197# from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
198from typing import Union, Iterable, Optional
199
200
201class FenwickTree:
202 """FenwickTreeです。"""
203
204 def __init__(self, n_or_a: Union[Iterable[int], int]):
205 """構築します。
206 :math:`O(n)` です。
207
208 Args:
209 n_or_a (Union[Iterable[int], int]): `n_or_a` が `int` のとき、初期値 `0` 、長さ `n` で構築します。
210 `n_or_a` が `Iterable` のとき、初期値 `a` で構築します。
211 """
212 if isinstance(n_or_a, int):
213 self._size = n_or_a
214 self._tree = [0] * (self._size + 1)
215 else:
216 a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
217 _size = len(a)
218 _tree = [0] + a
219 for i in range(1, _size):
220 if i + (i & -i) <= _size:
221 _tree[i + (i & -i)] += _tree[i]
222 self._size = _size
223 self._tree = _tree
224 self._s = 1 << (self._size - 1).bit_length()
225
226 def pref(self, r: int) -> int:
227 """区間 ``[0, r)`` の総和を返します。
228 :math:`O(\\log{n})` です。
229 """
230 assert (
231 0 <= r <= self._size
232 ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
233 ret, _tree = 0, self._tree
234 while r > 0:
235 ret += _tree[r]
236 r &= r - 1
237 return ret
238
239 def suff(self, l: int) -> int:
240 """区間 ``[l, n)`` の総和を返します。
241 :math:`O(\\log{n})` です。
242 """
243 assert (
244 0 <= l < self._size
245 ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
246 return self.pref(self._size) - self.pref(l)
247
248 def sum(self, l: int, r: int) -> int:
249 """区間 ``[l, r)`` の総和を返します。
250 :math:`O(\\log{n})` です。
251 """
252 assert (
253 0 <= l <= r <= self._size
254 ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
255 _tree = self._tree
256 res = 0
257 while r > l:
258 res += _tree[r]
259 r &= r - 1
260 while l > r:
261 res -= _tree[l]
262 l &= l - 1
263 return res
264
265 prod = sum
266
267 def __getitem__(self, k: int) -> int:
268 """位置 ``k`` の要素を返します。
269 :math:`O(\\log{n})` です。
270 """
271 assert (
272 -self._size <= k < self._size
273 ), f"IndexError: {self.__class__.__name__}[{k}], n={self._size}"
274 if k < 0:
275 k += self._size
276 return self.sum(k, k + 1)
277
278 def add(self, k: int, x: int) -> None:
279 """``k`` 番目の値に ``x`` を加えます。
280 :math:`O(\\log{n})` です。
281 """
282 assert (
283 0 <= k < self._size
284 ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
285 k += 1
286 _tree = self._tree
287 while k <= self._size:
288 _tree[k] += x
289 k += k & -k
290
291 def __setitem__(self, k: int, x: int):
292 """``k`` 番目の値を ``x`` に更新します。
293 :math:`O(\\log{n})` です。
294 """
295 assert (
296 -self._size <= k < self._size
297 ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self._size}"
298 if k < 0:
299 k += self._size
300 pre = self[k]
301 self.add(k, x - pre)
302
303 def bisect_left(self, w: int) -> Optional[int]:
304 i, s, _size, _tree = 0, self._s, self._size, self._tree
305 while s:
306 if i + s <= _size and _tree[i + s] < w:
307 w -= _tree[i + s]
308 i += s
309 s >>= 1
310 return i if w else None
311
312 def bisect_right(self, w: int) -> int:
313 i, s, _size, _tree = 0, self._s, self._size, self._tree
314 while s:
315 if i + s <= _size and _tree[i + s] <= w:
316 w -= _tree[i + s]
317 i += s
318 s >>= 1
319 return i
320
321 def _pop(self, k: int) -> int:
322 assert k >= 0
323 i, acc, s, _size, _tree = 0, 0, self._s, self._size, self._tree
324 while s:
325 if i + s <= _size:
326 if acc + _tree[i + s] <= k:
327 acc += _tree[i + s]
328 i += s
329 else:
330 _tree[i + s] -= 1
331 s >>= 1
332 return i
333
334 def tolist(self) -> list[int]:
335 """リストにして返します。
336 :math:`O(n)` です。
337 """
338 sub = [self.pref(i) for i in range(self._size + 1)]
339 return [sub[i + 1] - sub[i] for i in range(self._size)]
340
341 @staticmethod
342 def get_inversion_num(a: list[int], compress: bool = False) -> int:
343 inv = 0
344 if compress:
345 a_ = sorted(set(a))
346 z = {e: i for i, e in enumerate(a_)}
347 fw = FenwickTree(len(a_) + 1)
348 for i, e in enumerate(a):
349 inv += i - fw.pref(z[e] + 1)
350 fw.add(z[e], 1)
351 else:
352 fw = FenwickTree(len(a) + 1)
353 for i, e in enumerate(a):
354 inv += i - fw.pref(e + 1)
355 fw.add(e, 1)
356 return inv
357
358 def __str__(self):
359 return str(self.tolist())
360
361 def __repr__(self):
362 return f"{self.__class__.__name__}({self})"
363from array import array
364from typing import Sequence
365from bisect import bisect_left
366
367
368class FenwickTreeWaveletMatrix:
369
370 def __init__(self, sigma: int, pos: list[tuple[int, int, int]] = []):
371 self.sigma: int = sigma
372 self.log: int = (sigma - 1).bit_length()
373 self.mid: array[int] = array("I", bytes(4 * self.log))
374 self.xy: list[tuple[int, int]] = self._sort_unique([(x, y) for x, y, _ in pos])
375 self.y: list[int] = self._sort_unique([y for _, y, _ in pos])
376 self.size: int = len(self.xy)
377 self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
378 self._build([bisect_left(self.y, y) for _, y in self.xy])
379 ws = [[0] * self.size for _ in range(self.log)]
380 for x, y, w in pos:
381 k = bisect_left(self.xy, (x, y))
382 i_y = bisect_left(self.y, y)
383 for bit in range(self.log - 1, -1, -1):
384 if i_y >> bit & 1:
385 k = self.v[bit].rank1(k) + self.mid[bit]
386 else:
387 k = self.v[bit].rank0(k)
388 ws[bit][k] += w
389 self.bit: list[FenwickTree] = [FenwickTree(a) for a in ws]
390
391 def _build(self, a: Sequence[int]) -> None:
392 # 列 a から wm を構築する
393 for bit in range(self.log - 1, -1, -1):
394 # bit目の0/1に応じてvを構築 + aを安定ソート
395 v = self.v[bit]
396 zero, one = [], []
397 for i, e in enumerate(a):
398 if e >> bit & 1:
399 v.set(i)
400 one.append(e)
401 else:
402 zero.append(e)
403 v.build()
404 self.mid[bit] = len(zero) # 境界をmid[bit]に保持
405 a = zero + one
406
407 def _sort_unique(self, a: list) -> list:
408 if not a:
409 return a
410 a.sort()
411 b = [a[0]]
412 for e in a:
413 if b[-1] == e:
414 continue
415 b.append(e)
416 return b
417
418 def add_point(self, x: int, y: int, w: int) -> None:
419 k = bisect_left(self.xy, (x, y))
420 i_y = bisect_left(self.y, y)
421 for bit in range(self.log - 1, -1, -1):
422 if i_y >> bit & 1:
423 k = self.v[bit].rank1(k) + self.mid[bit]
424 else:
425 k = self.v[bit].rank0(k)
426 self.bit[bit].add(k, w)
427
428 def _sum(self, l: int, r: int, x: int) -> int:
429 ans = 0
430 for bit in range(self.log - 1, -1, -1):
431 l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
432 if x >> bit & 1:
433 l += self.mid[bit] - l0
434 r += self.mid[bit] - r0
435 ans += self.bit[bit].sum(l0, r0)
436 else:
437 l, r = l0, r0
438 return ans
439
440 def sum(self, w1: int, w2: int, h1: int, h2: int) -> int:
441 # sum([w1, w2) x [h1, h2))
442 l = bisect_left(self.xy, (w1, 0))
443 r = bisect_left(self.xy, (w2, 0))
444 return self._sum(l, r, bisect_left(self.y, h2)) - self._sum(
445 l, r, bisect_left(self.y, h1)
446 )