1# from titan_pylib.data_structures.wavelet_matrix.cumulative_sum_wavelet_matrix import CumulativeSumWaveletMatrix
2# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
3# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
4# BitVectorInterface,
5# )
6from abc import ABC, abstractmethod
7
8
9class BitVectorInterface(ABC):
10
11 @abstractmethod
12 def access(self, k: int) -> int:
13 raise NotImplementedError
14
15 @abstractmethod
16 def __getitem__(self, k: int) -> int:
17 raise NotImplementedError
18
19 @abstractmethod
20 def rank0(self, r: int) -> int:
21 raise NotImplementedError
22
23 @abstractmethod
24 def rank1(self, r: int) -> int:
25 raise NotImplementedError
26
27 @abstractmethod
28 def rank(self, r: int, v: int) -> int:
29 raise NotImplementedError
30
31 @abstractmethod
32 def select0(self, k: int) -> int:
33 raise NotImplementedError
34
35 @abstractmethod
36 def select1(self, k: int) -> int:
37 raise NotImplementedError
38
39 @abstractmethod
40 def select(self, k: int, v: int) -> int:
41 raise NotImplementedError
42
43 @abstractmethod
44 def __len__(self) -> int:
45 raise NotImplementedError
46
47 @abstractmethod
48 def __str__(self) -> str:
49 raise NotImplementedError
50
51 @abstractmethod
52 def __repr__(self) -> str:
53 raise NotImplementedError
54from array import array
55
56
57class BitVector(BitVectorInterface):
58 """コンパクトな bit vector です。"""
59
60 def __init__(self, n: int):
61 """長さ ``n`` の ``BitVector`` です。
62
63 bit を保持するのに ``array[I]`` を使用します。
64 ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
65
66 累積和を保持するのに同様の ``array[I]`` を使用します。
67 32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
68 """
69 assert 0 <= n < 4294967295
70 self.N = n
71 self.block_size = (n + 31) >> 5
72 b = bytes(4 * (self.block_size + 1))
73 self.bit = array("I", b)
74 self.acc = array("I", b)
75
76 @staticmethod
77 def _popcount(x: int) -> int:
78 x = x - ((x >> 1) & 0x55555555)
79 x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
80 x = x + (x >> 4) & 0x0F0F0F0F
81 x += x >> 8
82 x += x >> 16
83 return x & 0x0000007F
84
85 def set(self, k: int) -> None:
86 """``k`` 番目の bit を ``1`` にします。
87 :math:`O(1)` です。
88
89 Args:
90 k (int): インデックスです。
91 """
92 self.bit[k >> 5] |= 1 << (k & 31)
93
94 def build(self) -> None:
95 """構築します。
96 **これ以降 ``set`` メソッドを使用してはいけません。**
97 :math:`O(n)` です。
98 """
99 acc, bit = self.acc, self.bit
100 for i in range(self.block_size):
101 acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
102
103 def access(self, k: int) -> int:
104 """``k`` 番目の bit を返します。
105 :math:`O(1)` です。
106 """
107 return (self.bit[k >> 5] >> (k & 31)) & 1
108
109 def __getitem__(self, k: int) -> int:
110 return (self.bit[k >> 5] >> (k & 31)) & 1
111
112 def rank0(self, r: int) -> int:
113 """``a[0, r)`` に含まれる ``0`` の個数を返します。
114 :math:`O(1)` です。
115 """
116 return r - (
117 self.acc[r >> 5]
118 + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
119 )
120
121 def rank1(self, r: int) -> int:
122 """``a[0, r)`` に含まれる ``1`` の個数を返します。
123 :math:`O(1)` です。
124 """
125 return self.acc[r >> 5] + BitVector._popcount(
126 self.bit[r >> 5] & ((1 << (r & 31)) - 1)
127 )
128
129 def rank(self, r: int, v: int) -> int:
130 """``a[0, r)`` に含まれる ``v`` の個数を返します。
131 :math:`O(1)` です。
132 """
133 return self.rank1(r) if v else self.rank0(r)
134
135 def select0(self, k: int) -> int:
136 """``k`` 番目の ``0`` のインデックスを返します。
137 :math:`O(\\log{n})` です。
138 """
139 if k < 0 or self.rank0(self.N) <= k:
140 return -1
141 l, r = 0, self.block_size + 1
142 while r - l > 1:
143 m = (l + r) >> 1
144 if m * 32 - self.acc[m] > k:
145 r = m
146 else:
147 l = m
148 indx = 32 * l
149 k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
150 l, r = indx, indx + 32
151 while r - l > 1:
152 m = (l + r) >> 1
153 if self.rank0(m) > k:
154 r = m
155 else:
156 l = m
157 return l
158
159 def select1(self, k: int) -> int:
160 """``k`` 番目の ``1`` のインデックスを返します。
161 :math:`O(\\log{n})` です。
162 """
163 if k < 0 or self.rank1(self.N) <= k:
164 return -1
165 l, r = 0, self.block_size + 1
166 while r - l > 1:
167 m = (l + r) >> 1
168 if self.acc[m] > k:
169 r = m
170 else:
171 l = m
172 indx = 32 * l
173 k = k - self.acc[l] + self.rank1(indx)
174 l, r = indx, indx + 32
175 while r - l > 1:
176 m = (l + r) >> 1
177 if self.rank1(m) > k:
178 r = m
179 else:
180 l = m
181 return l
182
183 def select(self, k: int, v: int) -> int:
184 """``k`` 番目の ``v`` のインデックスを返します。
185 :math:`O(\\log{n})` です。
186 """
187 return self.select1(k) if v else self.select0(k)
188
189 def __len__(self):
190 return self.N
191
192 def __str__(self):
193 return str([self.access(i) for i in range(self.N)])
194
195 def __repr__(self):
196 return f"{self.__class__.__name__}({self})"
197# from titan_pylib.data_structures.cumulative_sum.cumulative_sum import CumulativeSum
198from typing import Iterable
199
200
201class CumulativeSum:
202 """1次元累積和です。"""
203
204 def __init__(self, a: Iterable[int], e: int = 0):
205 """
206 :math:`O(n)` です。
207
208 Args:
209 a (Iterable[int]): ``CumulativeSum`` を構築する配列です。
210 e (int): 単位元です。デフォルトは ``0`` です。
211 """
212 a = list(a)
213 n = len(a)
214 acc = [e] * (n + 1)
215 for i in range(n):
216 acc[i + 1] = acc[i] + a[i]
217 self.n = n
218 self.acc = acc
219 self.a = a
220
221 def pref(self, r: int) -> int:
222 """区間 ``[0, r)`` の演算結果を返します。
223 :math:`O(1)` です。
224
225 Args:
226 r (int): インデックスです。
227 """
228 assert (
229 0 <= r <= self.n
230 ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self.n}"
231 return self.acc[r]
232
233 def all_sum(self) -> int:
234 """区間 `[0, n)` の演算結果を返します。
235 :math:`O(1)` です。
236
237 Args:
238 l (int): インデックスです。
239 r (int): インデックスです。
240 """
241 return self.acc[-1]
242
243 def sum(self, l: int, r: int) -> int:
244 """区間 `[l, r)` の演算結果を返します。
245 :math:`O(1)` です。
246
247 Args:
248 l (int): インデックスです。
249 r (int): インデックスです。
250 """
251 assert (
252 0 <= l <= r <= self.n
253 ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self.n}"
254 return self.acc[r] - self.acc[l]
255
256 prod = sum
257 all_prod = all_sum
258
259 def __getitem__(self, k: int) -> int:
260 assert (
261 -self.n <= k < self.n
262 ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
263 return self.a[k]
264
265 def __len__(self) -> int:
266 return len(self.a)
267
268 def __str__(self) -> str:
269 return str(self.acc)
270
271 __repr__ = __str__
272from array import array
273from typing import Sequence, Iterable
274from bisect import bisect_left
275
276
277class CumulativeSumWaveletMatrix:
278
279 def __init__(self, sigma: int, pos: Iterable[tuple[int, int, int]] = []) -> None:
280 """
281 Args:
282 sigma (int): yの最大値
283 pos (list[tuple[int, int, int]], optional): list[(x, y, w)]
284 """
285 self.sigma: int = sigma
286 self.log: int = (sigma - 1).bit_length()
287 self.mid: array[int] = array("I", bytes(4 * self.log))
288 self.xy: list[tuple[int, int]] = self._sort_unique([(x, y) for x, y, _ in pos])
289 self.y: list[int] = self._sort_unique([y for _, y in self.xy])
290 self.size: int = len(self.xy)
291 self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
292 self._build([bisect_left(self.y, y) for _, y in self.xy])
293 ws = [[0] * self.size for _ in range(self.log)]
294 for x, y, w in pos:
295 k = bisect_left(self.xy, (x, y))
296 i_y = bisect_left(self.y, y)
297 for bit in range(self.log - 1, -1, -1):
298 if i_y >> bit & 1:
299 k = self.v[bit].rank1(k) + self.mid[bit]
300 else:
301 k = self.v[bit].rank0(k)
302 ws[bit][k] += w
303 self.bit: list[CumulativeSum] = [CumulativeSum(a) for a in ws]
304
305 def _build(self, a: Sequence[int]) -> None:
306 # 列 a から wm を構築する
307 for bit in range(self.log - 1, -1, -1):
308 # bit目の0/1に応じてvを構築 + aを安定ソート
309 v = self.v[bit]
310 zero, one = [], []
311 for i, e in enumerate(a):
312 if e >> bit & 1:
313 v.set(i)
314 one.append(e)
315 else:
316 zero.append(e)
317 v.build()
318 self.mid[bit] = len(zero) # 境界をmid[bit]に保持
319 a = zero + one
320
321 def _sort_unique(self, a: list) -> list:
322 if not a:
323 return a
324 a.sort()
325 b = [a[0]]
326 for e in a:
327 if b[-1] == e:
328 continue
329 b.append(e)
330 return b
331
332 def _sum(self, l: int, r: int, x: int) -> int:
333 ans = 0
334 for bit in range(self.log - 1, -1, -1):
335 l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
336 if x >> bit & 1:
337 l += self.mid[bit] - l0
338 r += self.mid[bit] - r0
339 ans += self.bit[bit].sum(l0, r0)
340 else:
341 l, r = l0, r0
342 return ans
343
344 def sum(self, w1: int, w2: int, h1: int, h2: int) -> int:
345 """sum([w1, w2) x [h1, h2))"""
346 assert 0 <= w1 <= w2
347 assert 0 <= h1 <= h2
348 l = bisect_left(self.xy, (w1, 0))
349 r = bisect_left(self.xy, (w2, 0))
350 return self._sum(l, r, bisect_left(self.y, h2)) - self._sum(
351 l, r, bisect_left(self.y, h1)
352 )