wavelet_matrix¶
ソースコード¶
from titan_pylib.data_structures.wavelet_matrix.wavelet_matrix import WaveletMatrix
展開済みコード¶
1# from titan_pylib.data_structures.wavelet_matrix.wavelet_matrix import WaveletMatrix
2# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
3# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
4# BitVectorInterface,
5# )
6from abc import ABC, abstractmethod
7
8
9class BitVectorInterface(ABC):
10
11 @abstractmethod
12 def access(self, k: int) -> int:
13 raise NotImplementedError
14
15 @abstractmethod
16 def __getitem__(self, k: int) -> int:
17 raise NotImplementedError
18
19 @abstractmethod
20 def rank0(self, r: int) -> int:
21 raise NotImplementedError
22
23 @abstractmethod
24 def rank1(self, r: int) -> int:
25 raise NotImplementedError
26
27 @abstractmethod
28 def rank(self, r: int, v: int) -> int:
29 raise NotImplementedError
30
31 @abstractmethod
32 def select0(self, k: int) -> int:
33 raise NotImplementedError
34
35 @abstractmethod
36 def select1(self, k: int) -> int:
37 raise NotImplementedError
38
39 @abstractmethod
40 def select(self, k: int, v: int) -> int:
41 raise NotImplementedError
42
43 @abstractmethod
44 def __len__(self) -> int:
45 raise NotImplementedError
46
47 @abstractmethod
48 def __str__(self) -> str:
49 raise NotImplementedError
50
51 @abstractmethod
52 def __repr__(self) -> str:
53 raise NotImplementedError
54from array import array
55
56
57class BitVector(BitVectorInterface):
58 """コンパクトな bit vector です。"""
59
60 def __init__(self, n: int):
61 """長さ ``n`` の ``BitVector`` です。
62
63 bit を保持するのに ``array[I]`` を使用します。
64 ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
65
66 累積和を保持するのに同様の ``array[I]`` を使用します。
67 32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
68 """
69 assert 0 <= n < 4294967295
70 self.N = n
71 self.block_size = (n + 31) >> 5
72 b = bytes(4 * (self.block_size + 1))
73 self.bit = array("I", b)
74 self.acc = array("I", b)
75
76 @staticmethod
77 def _popcount(x: int) -> int:
78 x = x - ((x >> 1) & 0x55555555)
79 x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
80 x = x + (x >> 4) & 0x0F0F0F0F
81 x += x >> 8
82 x += x >> 16
83 return x & 0x0000007F
84
85 def set(self, k: int) -> None:
86 """``k`` 番目の bit を ``1`` にします。
87 :math:`O(1)` です。
88
89 Args:
90 k (int): インデックスです。
91 """
92 self.bit[k >> 5] |= 1 << (k & 31)
93
94 def build(self) -> None:
95 """構築します。
96 **これ以降 ``set`` メソッドを使用してはいけません。**
97 :math:`O(n)` です。
98 """
99 acc, bit = self.acc, self.bit
100 for i in range(self.block_size):
101 acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
102
103 def access(self, k: int) -> int:
104 """``k`` 番目の bit を返します。
105 :math:`O(1)` です。
106 """
107 return (self.bit[k >> 5] >> (k & 31)) & 1
108
109 def __getitem__(self, k: int) -> int:
110 return (self.bit[k >> 5] >> (k & 31)) & 1
111
112 def rank0(self, r: int) -> int:
113 """``a[0, r)`` に含まれる ``0`` の個数を返します。
114 :math:`O(1)` です。
115 """
116 return r - (
117 self.acc[r >> 5]
118 + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
119 )
120
121 def rank1(self, r: int) -> int:
122 """``a[0, r)`` に含まれる ``1`` の個数を返します。
123 :math:`O(1)` です。
124 """
125 return self.acc[r >> 5] + BitVector._popcount(
126 self.bit[r >> 5] & ((1 << (r & 31)) - 1)
127 )
128
129 def rank(self, r: int, v: int) -> int:
130 """``a[0, r)`` に含まれる ``v`` の個数を返します。
131 :math:`O(1)` です。
132 """
133 return self.rank1(r) if v else self.rank0(r)
134
135 def select0(self, k: int) -> int:
136 """``k`` 番目の ``0`` のインデックスを返します。
137 :math:`O(\\log{n})` です。
138 """
139 if k < 0 or self.rank0(self.N) <= k:
140 return -1
141 l, r = 0, self.block_size + 1
142 while r - l > 1:
143 m = (l + r) >> 1
144 if m * 32 - self.acc[m] > k:
145 r = m
146 else:
147 l = m
148 indx = 32 * l
149 k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
150 l, r = indx, indx + 32
151 while r - l > 1:
152 m = (l + r) >> 1
153 if self.rank0(m) > k:
154 r = m
155 else:
156 l = m
157 return l
158
159 def select1(self, k: int) -> int:
160 """``k`` 番目の ``1`` のインデックスを返します。
161 :math:`O(\\log{n})` です。
162 """
163 if k < 0 or self.rank1(self.N) <= k:
164 return -1
165 l, r = 0, self.block_size + 1
166 while r - l > 1:
167 m = (l + r) >> 1
168 if self.acc[m] > k:
169 r = m
170 else:
171 l = m
172 indx = 32 * l
173 k = k - self.acc[l] + self.rank1(indx)
174 l, r = indx, indx + 32
175 while r - l > 1:
176 m = (l + r) >> 1
177 if self.rank1(m) > k:
178 r = m
179 else:
180 l = m
181 return l
182
183 def select(self, k: int, v: int) -> int:
184 """``k`` 番目の ``v`` のインデックスを返します。
185 :math:`O(\\log{n})` です。
186 """
187 return self.select1(k) if v else self.select0(k)
188
189 def __len__(self):
190 return self.N
191
192 def __str__(self):
193 return str([self.access(i) for i in range(self.N)])
194
195 def __repr__(self):
196 return f"{self.__class__.__name__}({self})"
197from typing import Sequence
198from heapq import heappush, heappop
199from array import array
200
201
202class WaveletMatrix:
203 """``WaveletMatrix`` です。
204 静的であることに注意してください。
205
206 以下の仕様の計算量には嘘があるかもしれません。import 元の ``BitVector`` の計算量も参考にしてください。
207
208 参考:
209 `https://miti-7.hatenablog.com/entry/2018/04/28/152259 <https://miti-7.hatenablog.com/entry/2018/04/28/152259>`_
210 `https://www.slideshare.net/pfi/ss-15916040 <https://www.slideshare.net/pfi/ss-15916040>`_
211 `デwiki <https://scrapbox.io/data-structures/Wavelet_Matrix>`_
212 """
213
214 def __init__(self, sigma: int, a: Sequence[int] = []) -> None:
215 """``[0, sigma)`` の整数列を管理する ``WaveletMatrix`` を構築します。
216 :math:`O(n\\log{\\sigma})` です。
217
218 Args:
219 sigma (int): 扱う整数の上限です。
220 a (Sequence[int], optional): 構築する配列です。
221 """
222 self.sigma: int = sigma
223 self.log: int = (sigma - 1).bit_length()
224 self.mid: array[int] = array("I", bytes(4 * self.log))
225 self.size: int = len(a)
226 self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
227 self._build(a)
228
229 def _build(self, a: Sequence[int]) -> None:
230 # 列 a から wm を構築する
231 for bit in range(self.log - 1, -1, -1):
232 # bit目の0/1に応じてvを構築 + aを安定ソート
233 v = self.v[bit]
234 zero, one = [], []
235 for i, e in enumerate(a):
236 if e >> bit & 1:
237 v.set(i)
238 one.append(e)
239 else:
240 zero.append(e)
241 v.build()
242 self.mid[bit] = len(zero) # 境界をmid[bit]に保持
243 a = zero + one
244
245 def access(self, k: int) -> int:
246 """``k`` 番目の値を返します。
247 :math:`O(\\log{\\sigma})` です。
248
249 Args:
250 k (int): インデックスです。
251 """
252 assert (
253 -self.size <= k < self.size
254 ), f"IndexError: {self.__class__.__name__}.access({k}), size={self.size}"
255 if k < 0:
256 k += self.size
257 s = 0 # 答え
258 for bit in range(self.log - 1, -1, -1):
259 if self.v[bit].access(k):
260 # k番目が立ってたら、
261 # kまでの1とすべての0が次のk
262 s |= 1 << bit
263 k = self.v[bit].rank1(k) + self.mid[bit]
264 else:
265 # kまでの0が次のk
266 k = self.v[bit].rank0(k)
267 return s
268
269 def __getitem__(self, k: int) -> int:
270 assert (
271 -self.size <= k < self.size
272 ), f"IndexError: {self.__class__.__name__}[{k}], size={self.size}"
273 return self.access(k)
274
275 def rank(self, r: int, x: int) -> int:
276 """``a[0, r)`` に含まれる ``x`` の個数を返します。
277 :math:`O(\\log{\\sigma})` です。
278 """
279 assert (
280 0 <= r <= self.size
281 ), f"IndexError: {self.__class__.__name__}.rank(), r={r}, size={self.size}"
282 assert (
283 0 <= x < 1 << self.log
284 ), f"ValueError: {self.__class__.__name__}.rank(), x={x}, LIM={1<<self.log}"
285 l = 0
286 mid = self.mid
287 for bit in range(self.log - 1, -1, -1):
288 # 位置 r より左に x が何個あるか
289 # x の bit 目で場合分け
290 if x >> bit & 1:
291 # 立ってたら、次のl, rは以下
292 l = self.v[bit].rank1(l) + mid[bit]
293 r = self.v[bit].rank1(r) + mid[bit]
294 else:
295 # そうでなければ次のl, rは以下
296 l = self.v[bit].rank0(l)
297 r = self.v[bit].rank0(r)
298 return r - l
299
300 def select(self, k: int, x: int) -> int:
301 """``k`` 番目の ``v`` のインデックスを返します。
302 :math:`O(\\log{\\sigma})` です。
303 """
304 assert (
305 0 <= k < self.size
306 ), f"IndexError: {self.__class__.__name__}.select({k}, {x}), k={k}, size={self.size}"
307 assert (
308 0 <= x < 1 << self.log
309 ), f"ValueError: {self.__class__.__name__}.select({k}, {x}), x={x}, LIM={1<<self.log}"
310 # x の開始位置 s を探す
311 s = 0
312 for bit in range(self.log - 1, -1, -1):
313 if x >> bit & 1:
314 s = self.v[bit].rank0(self.size) + self.v[bit].rank1(s)
315 else:
316 s = self.v[bit].rank0(s)
317 s += k # s から k 進んだ位置が、元の列で何番目か調べる
318 for bit in range(self.log):
319 if x >> bit & 1:
320 s = self.v[bit].select1(s - self.v[bit].rank0(self.size))
321 else:
322 s = self.v[bit].select0(s)
323 return s
324
325 def kth_smallest(self, l: int, r: int, k: int) -> int:
326 """``a[l, r)`` の中で ``k`` 番目に **小さい** 値を返します。
327 :math:`O(\\log{\\sigma})` です。
328 """
329 assert (
330 0 <= l <= r <= self.size
331 ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), size={self.size}"
332 assert (
333 0 <= k < r - l
334 ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), wrong k"
335 s = 0
336 mid = self.mid
337 for bit in range(self.log - 1, -1, -1):
338 r0, l0 = self.v[bit].rank0(r), self.v[bit].rank0(l)
339 cnt = r0 - l0 # 区間内の 0 の個数
340 if cnt <= k: # 0 が k 以下のとき、 k 番目は 1
341 s |= 1 << bit
342 k -= cnt
343 # この 1 が次の bit 列でどこに行くか
344 l = l - l0 + mid[bit]
345 r = r - r0 + mid[bit]
346 else:
347 # この 0 が次の bit 列でどこに行くか
348 l = l0
349 r = r0
350 return s
351
352 quantile = kth_smallest
353
354 def kth_largest(self, l: int, r: int, k: int) -> int:
355 """``a[l, r)`` の中で ``k`` 番目に **大きい値** を返します。
356 :math:`O(\\log{\\sigma})` です。
357 """
358 assert (
359 0 <= l <= r <= self.size
360 ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), size={self.size}"
361 assert (
362 0 <= k < r - l
363 ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), wrong k"
364 return self.kth_smallest(l, r, r - l - k - 1)
365
366 def topk(self, l: int, r: int, k: int) -> list[tuple[int, int]]:
367 """``a[l, r)`` の中で、要素を出現回数が多い順にその頻度とともに ``k`` 個返します。
368 :math:`O(\\min(r-l, \\sigam) \\log(\\sigam))` です。
369
370 Note:
371 :math:`\\sigma` が大きい場合、計算量に注意です。
372
373 Returns:
374 list[tuple[int, int]]: ``(要素, 頻度)`` を要素とする配列です。
375 """
376 assert (
377 0 <= l <= r <= self.size
378 ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), size={self.size}"
379 assert (
380 0 <= k < r - l
381 ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), wrong k"
382 # heap[-length, x, l, bit]
383 hq: list[tuple[int, int, int, int]] = [(-(r - l), 0, l, self.log - 1)]
384 ans = []
385 while hq:
386 length, x, l, bit = heappop(hq)
387 length = -length
388 if bit == -1:
389 ans.append((x, length))
390 k -= 1
391 if k == 0:
392 break
393 else:
394 r = l + length
395 l0 = self.v[bit].rank0(l)
396 r0 = self.v[bit].rank0(r)
397 if l0 < r0:
398 heappush(hq, (-(r0 - l0), x, l0, bit - 1))
399 l1 = self.v[bit].rank1(l) + self.mid[bit]
400 r1 = self.v[bit].rank1(r) + self.mid[bit]
401 if l1 < r1:
402 heappush(hq, (-(r1 - l1), x | (1 << bit), l1, bit - 1))
403 return ans
404
405 def sum(self, l: int, r: int) -> int:
406 """``topk`` メソッドを用いて ``a[l, r)`` の総和を返します。
407 計算量に注意です。
408 """
409 assert False, "Yabai Keisanryo Error"
410 return sum(k * v for k, v in self.topk(l, r, r - l))
411
412 def _range_freq(self, l: int, r: int, x: int) -> int:
413 """a[l, r) で x 未満の要素の数を返す"""
414 ans = 0
415 for bit in range(self.log - 1, -1, -1):
416 l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
417 if x >> bit & 1:
418 # bit が立ってたら、区間の 0 の個数を答えに加算し、新たな区間は 1 のみ
419 ans += r0 - l0
420 # 1 が次の bit 列でどこに行くか
421 l += self.mid[bit] - l0
422 r += self.mid[bit] - r0
423 else:
424 # 0 が次の bit 列でどこに行くか
425 l, r = l0, r0
426 return ans
427
428 def range_freq(self, l: int, r: int, x: int, y: int) -> int:
429 """``a[l, r)`` に含まれる、 ``x`` 以上 ``y`` 未満である要素の個数を返します。
430 :math:`O(\\log{\\sigma})` です。
431 """
432 assert (
433 0 <= l <= r <= self.size
434 ), f"IndexError: {self.__class__.__name__}.range_freq({l}, {r}, {x}, {y})"
435 assert 0 <= x <= y < self.sigma, f"ValueError"
436 return self._range_freq(l, r, y) - self._range_freq(l, r, x)
437
438 def prev_value(self, l: int, r: int, x: int) -> int:
439 """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最大の要素を返します。
440 :math:`O(\\log{\\sigma})` です。
441 """
442 assert (
443 0 <= l <= r <= self.size
444 ), f"IndexError: {self.__class__.__name__}.prev_value({l}, {r}, {x})"
445 return self.kth_smallest(l, r, self._range_freq(l, r, x) - 1)
446
447 def next_value(self, l: int, r: int, x: int) -> int:
448 """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最小の要素を返します。
449 :math:`O(\\log{\\sigma})` です。
450 """
451 assert (
452 0 <= l <= r <= self.size
453 ), f"IndexError: {self.__class__.__name__}.next_value({l}, {r}, {x})"
454 return self.kth_smallest(l, r, self._range_freq(l, r, x))
455
456 def range_count(self, l: int, r: int, x: int) -> int:
457 """``a[l, r)`` に含まれる ``x`` の個数を返します。
458 ``wm.rank(r, x) - wm.rank(l, x)`` と等価です。
459 :math:`O(\\log{\\sigma})` です。
460 """
461 assert (
462 0 <= l <= r <= self.size
463 ), f"IndexError: {self.__class__.__name__}.range_count({l}, {r}, {x})"
464 return self.rank(r, x) - self.rank(l, x)
465
466 def __len__(self) -> int:
467 return self.size
468
469 def __str__(self) -> str:
470 return (
471 f"{self.__class__.__name__}({[self.access(i) for i in range(self.size)]})"
472 )
473
474 __repr__ = __str__
仕様¶
- class WaveletMatrix(sigma: int, a: Sequence[int] = [])[source]¶
Bases:
object
WaveletMatrix
です。 静的であることに注意してください。以下の仕様の計算量には嘘があるかもしれません。import 元の
BitVector
の計算量も参考にしてください。- 参考:
https://miti-7.hatenablog.com/entry/2018/04/28/152259 https://www.slideshare.net/pfi/ss-15916040 デwiki
- kth_largest(l: int, r: int, k: int) int [source]¶
a[l, r)
の中でk
番目に 大きい値 を返します。 \(O(\log{\sigma})\) です。
- kth_smallest(l: int, r: int, k: int) int [source]¶
a[l, r)
の中でk
番目に 小さい 値を返します。 \(O(\log{\sigma})\) です。
- next_value(l: int, r: int, x: int) int [source]¶
a[l, r)
で、x
以上y
未満であるような要素のうち最小の要素を返します。 \(O(\log{\sigma})\) です。
- prev_value(l: int, r: int, x: int) int [source]¶
a[l, r)
で、x
以上y
未満であるような要素のうち最大の要素を返します。 \(O(\log{\sigma})\) です。
- quantile(l: int, r: int, k: int) int ¶
a[l, r)
の中でk
番目に 小さい 値を返します。 \(O(\log{\sigma})\) です。
- range_count(l: int, r: int, x: int) int [source]¶
a[l, r)
に含まれるx
の個数を返します。wm.rank(r, x) - wm.rank(l, x)
と等価です。 \(O(\log{\sigma})\) です。