dynamic_wavelet_matrix¶
ソースコード¶
from titan_pylib.data_structures.wavelet_matrix.dynamic_wavelet_matrix import DynamicWaveletMatrix
展開済みコード¶
1# from titan_pylib.data_structures.wavelet_matrix.dynamic_wavelet_matrix import DynamicWaveletMatrix
2# from titan_pylib.data_structures.bit_vector.avl_tree_bit_vector import AVLTreeBitVector
3# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
4# BitVectorInterface,
5# )
6from abc import ABC, abstractmethod
7
8
9class BitVectorInterface(ABC):
10
11 @abstractmethod
12 def access(self, k: int) -> int:
13 raise NotImplementedError
14
15 @abstractmethod
16 def __getitem__(self, k: int) -> int:
17 raise NotImplementedError
18
19 @abstractmethod
20 def rank0(self, r: int) -> int:
21 raise NotImplementedError
22
23 @abstractmethod
24 def rank1(self, r: int) -> int:
25 raise NotImplementedError
26
27 @abstractmethod
28 def rank(self, r: int, v: int) -> int:
29 raise NotImplementedError
30
31 @abstractmethod
32 def select0(self, k: int) -> int:
33 raise NotImplementedError
34
35 @abstractmethod
36 def select1(self, k: int) -> int:
37 raise NotImplementedError
38
39 @abstractmethod
40 def select(self, k: int, v: int) -> int:
41 raise NotImplementedError
42
43 @abstractmethod
44 def __len__(self) -> int:
45 raise NotImplementedError
46
47 @abstractmethod
48 def __str__(self) -> str:
49 raise NotImplementedError
50
51 @abstractmethod
52 def __repr__(self) -> str:
53 raise NotImplementedError
54from array import array
55from typing import Iterable, Final, Sequence
56
57titan_pylib_AVLTreeBitVector_W: Final[int] = 31
58
59
60class AVLTreeBitVector(BitVectorInterface):
61 """AVL木で書かれたビットベクトルです。簡潔でもなんでもありません。
62
63 bit列を管理するわけですが、各節点は 1~32 bit を持つようにしています。
64 これにより、最大 32 倍高速化が行えます。(16~32bitとするといいんだろうけど)
65 """
66
67 @staticmethod
68 def _popcount(x: int) -> int:
69 x = x - ((x >> 1) & 0x55555555)
70 x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
71 x = x + (x >> 4) & 0x0F0F0F0F
72 x += x >> 8
73 x += x >> 16
74 return x & 0x0000007F
75
76 def __init__(self, a: Iterable[int] = []):
77 """
78 :math:`O(n)` です。
79
80 Args:
81 a (Iterable[int], optional): 構築元の配列です。
82 """
83 self.root = 0
84 self.bit_len = array("B", bytes(1))
85 self.key = array("I", bytes(4))
86 self.size = array("I", bytes(4))
87 self.total = array("I", bytes(4))
88 self.left = array("I", bytes(4))
89 self.right = array("I", bytes(4))
90 self.balance = array("b", bytes(1))
91 self.end = 1
92 if a:
93 self._build(a)
94
95 def reserve(self, n: int) -> None:
96 """``n`` 要素分のメモリを確保します。
97 :math:`O(n)` です。
98 """
99 n = n // titan_pylib_AVLTreeBitVector_W + 1
100 a = array("I", bytes(4 * n))
101 self.bit_len += array("B", bytes(n))
102 self.key += a
103 self.size += a
104 self.total += a
105 self.left += a
106 self.right += a
107 self.balance += array("b", bytes(n))
108
109 def _build(self, a: Iterable[int]) -> None:
110 key, bit_len, left, right, size, balance, total = (
111 self.key,
112 self.bit_len,
113 self.left,
114 self.right,
115 self.size,
116 self.balance,
117 self.total,
118 )
119 _popcount = AVLTreeBitVector._popcount
120
121 def rec(lr: int) -> int:
122 l, r = lr >> bit, lr & msk
123 mid = (l + r) >> 1
124 hl, hr = 0, 0
125 if l != mid:
126 le = rec(l << bit | mid)
127 left[mid], hl = le >> bit, le & msk
128 size[mid] += size[left[mid]]
129 total[mid] += total[left[mid]]
130 if mid + 1 != r:
131 ri = rec((mid + 1) << bit | r)
132 right[mid], hr = ri >> bit, ri & msk
133 size[mid] += size[right[mid]]
134 total[mid] += total[right[mid]]
135 balance[mid] = hl - hr
136 return mid << bit | (max(hl, hr) + 1)
137
138 if not isinstance(a, Sequence):
139 a = list(a)
140 n = len(a)
141 bit = n.bit_length() + 2
142 msk = (1 << bit) - 1
143 end = self.end
144 self.reserve(n)
145 i = 0
146 indx = end
147 for i in range(0, n, titan_pylib_AVLTreeBitVector_W):
148 j = 0
149 v = 0
150 while j < titan_pylib_AVLTreeBitVector_W and i + j < n:
151 v <<= 1
152 v |= a[i + j]
153 j += 1
154 key[indx] = v
155 bit_len[indx] = j
156 size[indx] = j
157 total[indx] = _popcount(v)
158 indx += 1
159 self.end = indx
160 self.root = rec(end << bit | self.end) >> bit
161
162 def _rotate_L(self, node: int) -> int:
163 left, right, size, balance, total = (
164 self.left,
165 self.right,
166 self.size,
167 self.balance,
168 self.total,
169 )
170 u = left[node]
171 size[u] = size[node]
172 total[u] = total[node]
173 size[node] -= size[left[u]] + self.bit_len[u]
174 total[node] -= total[left[u]] + AVLTreeBitVector._popcount(self.key[u])
175 left[node] = right[u]
176 right[u] = node
177 if balance[u] == 1:
178 balance[u] = 0
179 balance[node] = 0
180 else:
181 balance[u] = -1
182 balance[node] = 1
183 return u
184
185 def _rotate_R(self, node: int) -> int:
186 left, right, size, balance, total = (
187 self.left,
188 self.right,
189 self.size,
190 self.balance,
191 self.total,
192 )
193 u = right[node]
194 size[u] = size[node]
195 total[u] = total[node]
196 size[node] -= size[right[u]] + self.bit_len[u]
197 total[node] -= total[right[u]] + AVLTreeBitVector._popcount(self.key[u])
198 right[node] = left[u]
199 left[u] = node
200 if balance[u] == -1:
201 balance[u] = 0
202 balance[node] = 0
203 else:
204 balance[u] = 1
205 balance[node] = -1
206 return u
207
208 def _update_balance(self, node: int) -> None:
209 balance = self.balance
210 if balance[node] == 1:
211 balance[self.right[node]] = -1
212 balance[self.left[node]] = 0
213 elif balance[node] == -1:
214 balance[self.right[node]] = 0
215 balance[self.left[node]] = 1
216 else:
217 balance[self.right[node]] = 0
218 balance[self.left[node]] = 0
219 balance[node] = 0
220
221 def _rotate_LR(self, node: int) -> int:
222 left, right, size, total = self.left, self.right, self.size, self.total
223 B = left[node]
224 E = right[B]
225 size[E] = size[node]
226 size[node] -= size[B] - size[right[E]]
227 size[B] -= size[right[E]] + self.bit_len[E]
228 total[E] = total[node]
229 total[node] -= total[B] - total[right[E]]
230 total[B] -= total[right[E]] + AVLTreeBitVector._popcount(self.key[E])
231 right[B] = left[E]
232 left[E] = B
233 left[node] = right[E]
234 right[E] = node
235 self._update_balance(E)
236 return E
237
238 def _rotate_RL(self, node: int) -> int:
239 left, right, size, total = self.left, self.right, self.size, self.total
240 C = right[node]
241 D = left[C]
242 size[D] = size[node]
243 size[node] -= size[C] - size[left[D]]
244 size[C] -= size[left[D]] + self.bit_len[D]
245 total[D] = total[node]
246 total[node] -= total[C] - total[left[D]]
247 total[C] -= total[left[D]] + AVLTreeBitVector._popcount(self.key[D])
248 left[C] = right[D]
249 right[D] = C
250 right[node] = left[D]
251 left[D] = node
252 self._update_balance(D)
253 return D
254
255 def _pref(self, r: int) -> int:
256 left, right, bit_len, size, key, total = (
257 self.left,
258 self.right,
259 self.bit_len,
260 self.size,
261 self.key,
262 self.total,
263 )
264 node = self.root
265 s = 0
266 while r > 0:
267 t = size[left[node]] + bit_len[node]
268 if t - bit_len[node] < r <= t:
269 r -= size[left[node]]
270 s += total[left[node]] + AVLTreeBitVector._popcount(
271 key[node] >> (bit_len[node] - r)
272 )
273 break
274 if t > r:
275 node = left[node]
276 else:
277 s += total[left[node]] + AVLTreeBitVector._popcount(key[node])
278 node = right[node]
279 r -= t
280 return s
281
282 def _make_node(self, key: int, bit_len: int) -> int:
283 end = self.end
284 if end >= len(self.key):
285 self.key.append(key)
286 self.bit_len.append(bit_len)
287 self.size.append(bit_len)
288 self.total.append(AVLTreeBitVector._popcount(key))
289 self.left.append(0)
290 self.right.append(0)
291 self.balance.append(0)
292 else:
293 self.key[end] = key
294 self.bit_len[end] = bit_len
295 self.size[end] = bit_len
296 self.total[end] = AVLTreeBitVector._popcount(key)
297 self.end += 1
298 return end
299
300 def insert(self, k: int, key: int) -> None:
301 """``k`` 番目に ``v`` を挿入します。
302 :math:`O(\\log{n})` です。
303
304 Args:
305 k (int): 挿入位置のインデックスです。
306 key (int): 挿入する値です。 ``0`` または ``1`` である必要があります。
307 """
308 if self.root == 0:
309 self.root = self._make_node(key, 1)
310 return
311 left, right, size, bit_len, balance, keys, total = (
312 self.left,
313 self.right,
314 self.size,
315 self.bit_len,
316 self.balance,
317 self.key,
318 self.total,
319 )
320 node = self.root
321 path = []
322 d = 0
323 while node:
324 t = size[left[node]] + bit_len[node]
325 if t - bit_len[node] <= k <= t:
326 break
327 d <<= 1
328 size[node] += 1
329 total[node] += key
330 path.append(node)
331 node = left[node] if t > k else right[node]
332 if t > k:
333 d |= 1
334 else:
335 k -= t
336 k -= size[left[node]]
337 if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
338 v = keys[node]
339 bl = bit_len[node] - k
340 keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
341 bit_len[node] += 1
342 size[node] += 1
343 total[node] += key
344 return
345 path.append(node)
346 size[node] += 1
347 total[node] += key
348 v = keys[node]
349 bl = titan_pylib_AVLTreeBitVector_W - k
350 v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
351 left_key = v >> titan_pylib_AVLTreeBitVector_W
352 left_key_popcount = left_key & 1
353 keys[node] = v & ((1 << titan_pylib_AVLTreeBitVector_W) - 1)
354 node = left[node]
355 d <<= 1
356 d |= 1
357 if not node:
358 if bit_len[path[-1]] < titan_pylib_AVLTreeBitVector_W:
359 bit_len[path[-1]] += 1
360 keys[path[-1]] = (keys[path[-1]] << 1) | left_key
361 return
362 else:
363 left[path[-1]] = self._make_node(left_key, 1)
364 else:
365 path.append(node)
366 size[node] += 1
367 total[node] += left_key_popcount
368 d <<= 1
369 while right[node]:
370 node = right[node]
371 path.append(node)
372 size[node] += 1
373 total[node] += left_key_popcount
374 d <<= 1
375 if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
376 bit_len[node] += 1
377 keys[node] = (keys[node] << 1) | left_key
378 return
379 else:
380 right[node] = self._make_node(left_key, 1)
381 new_node = 0
382 while path:
383 node = path.pop()
384 balance[node] += 1 if d & 1 else -1
385 d >>= 1
386 if balance[node] == 0:
387 break
388 if balance[node] == 2:
389 new_node = (
390 self._rotate_LR(node)
391 if balance[left[node]] == -1
392 else self._rotate_L(node)
393 )
394 break
395 elif balance[node] == -2:
396 new_node = (
397 self._rotate_RL(node)
398 if balance[right[node]] == 1
399 else self._rotate_R(node)
400 )
401 break
402 if new_node:
403 if path:
404 if d & 1:
405 left[path[-1]] = new_node
406 else:
407 right[path[-1]] = new_node
408 else:
409 self.root = new_node
410
411 def _pop_under(self, path: list[int], d: int, node: int, res: int) -> None:
412 left, right, size, bit_len, balance, keys, total = (
413 self.left,
414 self.right,
415 self.size,
416 self.bit_len,
417 self.balance,
418 self.key,
419 self.total,
420 )
421 fd, lmax_total, lmax_bit_len = 0, 0, 0
422 if left[node] and right[node]:
423 path.append(node)
424 d <<= 1
425 d |= 1
426 lmax = left[node]
427 while right[lmax]:
428 path.append(lmax)
429 d <<= 1
430 fd <<= 1
431 fd |= 1
432 lmax = right[lmax]
433 lmax_total = AVLTreeBitVector._popcount(keys[lmax])
434 lmax_bit_len = bit_len[lmax]
435 keys[node] = keys[lmax]
436 bit_len[node] = lmax_bit_len
437 node = lmax
438 cnode = right[node] if left[node] == 0 else left[node]
439 if path:
440 if d & 1:
441 left[path[-1]] = cnode
442 else:
443 right[path[-1]] = cnode
444 else:
445 self.root = cnode
446 return
447 while path:
448 new_node = 0
449 node = path.pop()
450 balance[node] -= 1 if d & 1 else -1
451 size[node] -= lmax_bit_len if fd & 1 else 1
452 total[node] -= lmax_total if fd & 1 else res
453 d >>= 1
454 fd >>= 1
455 if balance[node] == 2:
456 new_node = (
457 self._rotate_LR(node)
458 if balance[left[node]] < 0
459 else self._rotate_L(node)
460 )
461 elif balance[node] == -2:
462 new_node = (
463 self._rotate_RL(node)
464 if balance[right[node]] > 0
465 else self._rotate_R(node)
466 )
467 elif balance[node] != 0:
468 break
469 if new_node:
470 if not path:
471 self.root = new_node
472 return
473 if d & 1:
474 left[path[-1]] = new_node
475 else:
476 right[path[-1]] = new_node
477 if balance[new_node] != 0:
478 break
479 while path:
480 node = path.pop()
481 size[node] -= lmax_bit_len if fd & 1 else 1
482 total[node] -= lmax_total if fd & 1 else res
483 fd >>= 1
484
485 def pop(self, k: int) -> int:
486 """``k`` 番目の要素を削除し、その値を返します。
487 :math:`O(\\log{n})` です。
488
489 Args:
490 k (int): 削除位置のインデックスです。
491 """
492 assert 0 <= k < len(self)
493 left, right, size = self.left, self.right, self.size
494 bit_len, keys, total = self.bit_len, self.key, self.total
495 node = self.root
496 d = 0
497 path = []
498 while node:
499 t = size[left[node]] + bit_len[node]
500 if t - bit_len[node] <= k < t:
501 break
502 path.append(node)
503 node = left[node] if t > k else right[node]
504 d <<= 1
505 if t > k:
506 d |= 1
507 else:
508 k -= t
509 k -= size[left[node]]
510 v = keys[node]
511 res = v >> (bit_len[node] - k - 1) & 1
512 if bit_len[node] == 1:
513 self._pop_under(path, d, node, res)
514 return res
515 keys[node] = ((v >> (bit_len[node] - k)) << ((bit_len[node] - k - 1))) | (
516 v & ((1 << (bit_len[node] - k - 1)) - 1)
517 )
518 bit_len[node] -= 1
519 size[node] -= 1
520 total[node] -= res
521 for p in path:
522 size[p] -= 1
523 total[p] -= res
524 return res
525
526 def set(self, k: int, v: int) -> None:
527 """``k`` 番目の値を ``v`` に更新します。
528 :math:`O(\\log{n})` です。
529
530 Args:
531 k (int): 更新位置のインデックスです。
532 key (int): 更新する値です。 ``0`` または ``1`` である必要があります。
533 """
534 self.__setitem__(k, v)
535
536 def tolist(self) -> list[int]:
537 """リストにして返します。
538 :math:`O(n)` です。
539 """
540 left, right, key, bit_len = self.left, self.right, self.key, self.bit_len
541 a = []
542 if not self.root:
543 return a
544
545 def rec(node):
546 if left[node]:
547 rec(left[node])
548 for i in range(bit_len[node] - 1, -1, -1):
549 a.append(key[node] >> i & 1)
550 if right[node]:
551 rec(right[node])
552
553 rec(self.root)
554 return a
555
556 def _debug_acc(self) -> None:
557 """デバッグ用のメソッドです。
558 key,totalをチェックします。
559 """
560 left, right = self.left, self.right
561 key = self.key
562
563 def rec(node):
564 acc = self._popcount(key[node])
565 if left[node]:
566 acc += rec(left[node])
567 if right[node]:
568 acc += rec(right[node])
569 if acc != self.total[node]:
570 # self.debug()
571 assert False, "acc Error"
572 return acc
573
574 rec(self.root)
575 print("debug_acc ok.")
576
577 def access(self, k: int) -> int:
578 """``k`` 番目の値を返します。
579 :math:`O(\\log{n})` です。
580
581 Args:
582 k (int): 取得位置のインデックスです。
583 """
584 return self.__getitem__(k)
585
586 def rank0(self, r: int) -> int:
587 """``a[0, r)`` に含まれる ``0`` の個数を返します。
588 :math:`O(\\log{n})` です。
589 """
590 return r - self._pref(r)
591
592 def rank1(self, r: int) -> int:
593 """``a[0, r)`` に含まれる ``1`` の個数を返します。
594 :math:`O(\\log{n})` です。
595 """
596 return self._pref(r)
597
598 def rank(self, r: int, v: int) -> int:
599 """``a[0, r)`` に含まれる ``v`` の個数を返します。
600 :math:`O(\\log{n})` です。
601 """
602 return self.rank1(r) if v else self.rank0(r)
603
604 def select0(self, k: int) -> int:
605 """``k`` 番目の ``0`` のインデックスを返します。
606 :math:`O(\\log{n}^2)` です。
607 """
608 if k < 0 or self.rank0(len(self)) <= k:
609 return -1
610 l, r = 0, len(self)
611 while r - l > 1:
612 m = (l + r) >> 1
613 if m - self._pref(m) > k:
614 r = m
615 else:
616 l = m
617 return l
618
619 def select1(self, k: int) -> int:
620 """``k`` 番目の ``1`` のインデックスを返します。
621 :math:`O(\\log{n}^2)` です。
622 """
623 if k < 0 or self.rank1(len(self)) <= k:
624 return -1
625 l, r = 0, len(self)
626 while r - l > 1:
627 m = (l + r) >> 1
628 if self._pref(m) > k:
629 r = m
630 else:
631 l = m
632 return l
633
634 def select(self, k: int, v: int) -> int:
635 """``k`` 番目の ``v`` のインデックスを返します。
636 :math:`O(\\log{n}^2)` です。
637 """
638 return self.select1(k) if v else self.select0(k)
639
640 def _insert_and_rank1(self, k: int, key: int) -> int:
641 if self.root == 0:
642 self.root = self._make_node(key, 1)
643 return 0
644 left, right, size, bit_len, balance, keys, total = (
645 self.left,
646 self.right,
647 self.size,
648 self.bit_len,
649 self.balance,
650 self.key,
651 self.total,
652 )
653 node = self.root
654 s = 0
655 path = []
656 d = 0
657 while node:
658 t = size[left[node]] + bit_len[node]
659 if t - bit_len[node] <= k <= t:
660 break
661 if t <= k:
662 s += total[left[node]] + AVLTreeBitVector._popcount(keys[node])
663 d <<= 1
664 size[node] += 1
665 total[node] += key
666 path.append(node)
667 node = left[node] if t > k else right[node]
668 if t > k:
669 d |= 1
670 else:
671 k -= t
672 k -= size[left[node]]
673 s += total[left[node]] + AVLTreeBitVector._popcount(
674 keys[node] >> (bit_len[node] - k)
675 )
676 if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
677 v = keys[node]
678 bl = bit_len[node] - k
679 keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
680 bit_len[node] += 1
681 size[node] += 1
682 total[node] += key
683 return s
684 path.append(node)
685 size[node] += 1
686 total[node] += key
687 v = keys[node]
688 bl = titan_pylib_AVLTreeBitVector_W - k
689 v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
690 left_key = v >> titan_pylib_AVLTreeBitVector_W
691 left_key_popcount = left_key & 1
692 keys[node] = v & ((1 << titan_pylib_AVLTreeBitVector_W) - 1)
693 node = left[node]
694 d <<= 1
695 d |= 1
696 if not node:
697 if bit_len[path[-1]] < titan_pylib_AVLTreeBitVector_W:
698 bit_len[path[-1]] += 1
699 keys[path[-1]] = (keys[path[-1]] << 1) | left_key
700 return s
701 else:
702 left[path[-1]] = self._make_node(left_key, 1)
703 else:
704 path.append(node)
705 size[node] += 1
706 total[node] += left_key_popcount
707 d <<= 1
708 while right[node]:
709 node = right[node]
710 path.append(node)
711 size[node] += 1
712 total[node] += left_key_popcount
713 d <<= 1
714 if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
715 bit_len[node] += 1
716 keys[node] = (keys[node] << 1) | left_key
717 return s
718 else:
719 right[node] = self._make_node(left_key, 1)
720 new_node = 0
721 while path:
722 node = path.pop()
723 balance[node] += 1 if d & 1 else -1
724 d >>= 1
725 if balance[node] == 0:
726 break
727 if balance[node] == 2:
728 new_node = (
729 self._rotate_LR(node)
730 if balance[left[node]] == -1
731 else self._rotate_L(node)
732 )
733 break
734 elif balance[node] == -2:
735 new_node = (
736 self._rotate_RL(node)
737 if balance[right[node]] == 1
738 else self._rotate_R(node)
739 )
740 break
741 if new_node:
742 if path:
743 if d & 1:
744 left[path[-1]] = new_node
745 else:
746 right[path[-1]] = new_node
747 else:
748 self.root = new_node
749 return s
750
751 def _access_pop_and_rank1(self, k: int) -> int:
752 assert 0 <= k < len(self)
753 left, right, size = self.left, self.right, self.size
754 bit_len, keys, total = self.bit_len, self.key, self.total
755 s = 0
756 node = self.root
757 d = 0
758 path = []
759 while node:
760 t = size[left[node]] + bit_len[node]
761 if t - bit_len[node] <= k < t:
762 break
763 if t <= k:
764 s += total[left[node]] + AVLTreeBitVector._popcount(keys[node])
765 path.append(node)
766 node = left[node] if t > k else right[node]
767 d <<= 1
768 if t > k:
769 d |= 1
770 else:
771 k -= t
772 k -= size[left[node]]
773 s += total[left[node]] + AVLTreeBitVector._popcount(
774 keys[node] >> (bit_len[node] - k)
775 )
776 v = keys[node]
777 res = v >> (bit_len[node] - k - 1) & 1
778 if bit_len[node] == 1:
779 self._pop_under(path, d, node, res)
780 return s << 1 | res
781 keys[node] = ((v >> (bit_len[node] - k)) << ((bit_len[node] - k - 1))) | (
782 v & ((1 << (bit_len[node] - k - 1)) - 1)
783 )
784 bit_len[node] -= 1
785 size[node] -= 1
786 total[node] -= res
787 for p in path:
788 size[p] -= 1
789 total[p] -= res
790 return s << 1 | res
791
792 def __getitem__(self, k: int) -> int:
793 """``k`` 番目の要素を返します。
794 :math:`O(\\log{n})` です。
795 """
796 assert 0 <= k < len(self)
797 left, right, bit_len, size, key = (
798 self.left,
799 self.right,
800 self.bit_len,
801 self.size,
802 self.key,
803 )
804 node = self.root
805 while True:
806 t = size[left[node]] + bit_len[node]
807 if t - bit_len[node] <= k < t:
808 k -= size[left[node]]
809 return key[node] >> (bit_len[node] - k - 1) & 1
810 if t > k:
811 node = left[node]
812 else:
813 node = right[node]
814 k -= t
815
816 def __setitem__(self, k: int, v: int) -> None:
817 """``k`` 番目の要素を ``v`` に更新します。
818 :math:`O(\\log{n})` です。
819 """
820 left, right, bit_len, size, key, total = (
821 self.left,
822 self.right,
823 self.bit_len,
824 self.size,
825 self.key,
826 self.total,
827 )
828 assert v == 0 or v == 1, "ValueError"
829 node = self.root
830 path = []
831 while True:
832 t = size[left[node]] + bit_len[node]
833 path.append(node)
834 if t - bit_len[node] <= k < t:
835 k -= size[left[node]]
836 if v:
837 key[node] |= 1 << k
838 else:
839 key[node] &= ~(1 << k)
840 break
841 elif t > k:
842 node = left[node]
843 else:
844 node = right[node]
845 k -= t
846 while path:
847 node = path.pop()
848 total[node] = (
849 AVLTreeBitVector._popcount(key[node])
850 + total[left[node]]
851 + total[right[node]]
852 )
853
854 def __str__(self):
855 return str(self.tolist())
856
857 def __len__(self):
858 return self.size[self.root]
859
860 def __repr__(self):
861 return f"{self.__class__.__name__}({self})"
862# from titan_pylib.data_structures.wavelet_matrix.wavelet_matrix import WaveletMatrix
863# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
864# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
865# BitVectorInterface,
866# )
867from array import array
868
869
870class BitVector(BitVectorInterface):
871 """コンパクトな bit vector です。"""
872
873 def __init__(self, n: int):
874 """長さ ``n`` の ``BitVector`` です。
875
876 bit を保持するのに ``array[I]`` を使用します。
877 ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
878
879 累積和を保持するのに同様の ``array[I]`` を使用します。
880 32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
881 """
882 assert 0 <= n < 4294967295
883 self.N = n
884 self.block_size = (n + 31) >> 5
885 b = bytes(4 * (self.block_size + 1))
886 self.bit = array("I", b)
887 self.acc = array("I", b)
888
889 @staticmethod
890 def _popcount(x: int) -> int:
891 x = x - ((x >> 1) & 0x55555555)
892 x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
893 x = x + (x >> 4) & 0x0F0F0F0F
894 x += x >> 8
895 x += x >> 16
896 return x & 0x0000007F
897
898 def set(self, k: int) -> None:
899 """``k`` 番目の bit を ``1`` にします。
900 :math:`O(1)` です。
901
902 Args:
903 k (int): インデックスです。
904 """
905 self.bit[k >> 5] |= 1 << (k & 31)
906
907 def build(self) -> None:
908 """構築します。
909 **これ以降 ``set`` メソッドを使用してはいけません。**
910 :math:`O(n)` です。
911 """
912 acc, bit = self.acc, self.bit
913 for i in range(self.block_size):
914 acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
915
916 def access(self, k: int) -> int:
917 """``k`` 番目の bit を返します。
918 :math:`O(1)` です。
919 """
920 return (self.bit[k >> 5] >> (k & 31)) & 1
921
922 def __getitem__(self, k: int) -> int:
923 return (self.bit[k >> 5] >> (k & 31)) & 1
924
925 def rank0(self, r: int) -> int:
926 """``a[0, r)`` に含まれる ``0`` の個数を返します。
927 :math:`O(1)` です。
928 """
929 return r - (
930 self.acc[r >> 5]
931 + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
932 )
933
934 def rank1(self, r: int) -> int:
935 """``a[0, r)`` に含まれる ``1`` の個数を返します。
936 :math:`O(1)` です。
937 """
938 return self.acc[r >> 5] + BitVector._popcount(
939 self.bit[r >> 5] & ((1 << (r & 31)) - 1)
940 )
941
942 def rank(self, r: int, v: int) -> int:
943 """``a[0, r)`` に含まれる ``v`` の個数を返します。
944 :math:`O(1)` です。
945 """
946 return self.rank1(r) if v else self.rank0(r)
947
948 def select0(self, k: int) -> int:
949 """``k`` 番目の ``0`` のインデックスを返します。
950 :math:`O(\\log{n})` です。
951 """
952 if k < 0 or self.rank0(self.N) <= k:
953 return -1
954 l, r = 0, self.block_size + 1
955 while r - l > 1:
956 m = (l + r) >> 1
957 if m * 32 - self.acc[m] > k:
958 r = m
959 else:
960 l = m
961 indx = 32 * l
962 k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
963 l, r = indx, indx + 32
964 while r - l > 1:
965 m = (l + r) >> 1
966 if self.rank0(m) > k:
967 r = m
968 else:
969 l = m
970 return l
971
972 def select1(self, k: int) -> int:
973 """``k`` 番目の ``1`` のインデックスを返します。
974 :math:`O(\\log{n})` です。
975 """
976 if k < 0 or self.rank1(self.N) <= k:
977 return -1
978 l, r = 0, self.block_size + 1
979 while r - l > 1:
980 m = (l + r) >> 1
981 if self.acc[m] > k:
982 r = m
983 else:
984 l = m
985 indx = 32 * l
986 k = k - self.acc[l] + self.rank1(indx)
987 l, r = indx, indx + 32
988 while r - l > 1:
989 m = (l + r) >> 1
990 if self.rank1(m) > k:
991 r = m
992 else:
993 l = m
994 return l
995
996 def select(self, k: int, v: int) -> int:
997 """``k`` 番目の ``v`` のインデックスを返します。
998 :math:`O(\\log{n})` です。
999 """
1000 return self.select1(k) if v else self.select0(k)
1001
1002 def __len__(self):
1003 return self.N
1004
1005 def __str__(self):
1006 return str([self.access(i) for i in range(self.N)])
1007
1008 def __repr__(self):
1009 return f"{self.__class__.__name__}({self})"
1010from typing import Sequence
1011from heapq import heappush, heappop
1012from array import array
1013
1014
1015class WaveletMatrix:
1016 """``WaveletMatrix`` です。
1017 静的であることに注意してください。
1018
1019 以下の仕様の計算量には嘘があるかもしれません。import 元の ``BitVector`` の計算量も参考にしてください。
1020
1021 参考:
1022 `https://miti-7.hatenablog.com/entry/2018/04/28/152259 <https://miti-7.hatenablog.com/entry/2018/04/28/152259>`_
1023 `https://www.slideshare.net/pfi/ss-15916040 <https://www.slideshare.net/pfi/ss-15916040>`_
1024 `デwiki <https://scrapbox.io/data-structures/Wavelet_Matrix>`_
1025 """
1026
1027 def __init__(self, sigma: int, a: Sequence[int] = []) -> None:
1028 """``[0, sigma)`` の整数列を管理する ``WaveletMatrix`` を構築します。
1029 :math:`O(n\\log{\\sigma})` です。
1030
1031 Args:
1032 sigma (int): 扱う整数の上限です。
1033 a (Sequence[int], optional): 構築する配列です。
1034 """
1035 self.sigma: int = sigma
1036 self.log: int = (sigma - 1).bit_length()
1037 self.mid: array[int] = array("I", bytes(4 * self.log))
1038 self.size: int = len(a)
1039 self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
1040 self._build(a)
1041
1042 def _build(self, a: Sequence[int]) -> None:
1043 # 列 a から wm を構築する
1044 for bit in range(self.log - 1, -1, -1):
1045 # bit目の0/1に応じてvを構築 + aを安定ソート
1046 v = self.v[bit]
1047 zero, one = [], []
1048 for i, e in enumerate(a):
1049 if e >> bit & 1:
1050 v.set(i)
1051 one.append(e)
1052 else:
1053 zero.append(e)
1054 v.build()
1055 self.mid[bit] = len(zero) # 境界をmid[bit]に保持
1056 a = zero + one
1057
1058 def access(self, k: int) -> int:
1059 """``k`` 番目の値を返します。
1060 :math:`O(\\log{\\sigma})` です。
1061
1062 Args:
1063 k (int): インデックスです。
1064 """
1065 assert (
1066 -self.size <= k < self.size
1067 ), f"IndexError: {self.__class__.__name__}.access({k}), size={self.size}"
1068 if k < 0:
1069 k += self.size
1070 s = 0 # 答え
1071 for bit in range(self.log - 1, -1, -1):
1072 if self.v[bit].access(k):
1073 # k番目が立ってたら、
1074 # kまでの1とすべての0が次のk
1075 s |= 1 << bit
1076 k = self.v[bit].rank1(k) + self.mid[bit]
1077 else:
1078 # kまでの0が次のk
1079 k = self.v[bit].rank0(k)
1080 return s
1081
1082 def __getitem__(self, k: int) -> int:
1083 assert (
1084 -self.size <= k < self.size
1085 ), f"IndexError: {self.__class__.__name__}[{k}], size={self.size}"
1086 return self.access(k)
1087
1088 def rank(self, r: int, x: int) -> int:
1089 """``a[0, r)`` に含まれる ``x`` の個数を返します。
1090 :math:`O(\\log{\\sigma})` です。
1091 """
1092 assert (
1093 0 <= r <= self.size
1094 ), f"IndexError: {self.__class__.__name__}.rank(), r={r}, size={self.size}"
1095 assert (
1096 0 <= x < 1 << self.log
1097 ), f"ValueError: {self.__class__.__name__}.rank(), x={x}, LIM={1<<self.log}"
1098 l = 0
1099 mid = self.mid
1100 for bit in range(self.log - 1, -1, -1):
1101 # 位置 r より左に x が何個あるか
1102 # x の bit 目で場合分け
1103 if x >> bit & 1:
1104 # 立ってたら、次のl, rは以下
1105 l = self.v[bit].rank1(l) + mid[bit]
1106 r = self.v[bit].rank1(r) + mid[bit]
1107 else:
1108 # そうでなければ次のl, rは以下
1109 l = self.v[bit].rank0(l)
1110 r = self.v[bit].rank0(r)
1111 return r - l
1112
1113 def select(self, k: int, x: int) -> int:
1114 """``k`` 番目の ``v`` のインデックスを返します。
1115 :math:`O(\\log{\\sigma})` です。
1116 """
1117 assert (
1118 0 <= k < self.size
1119 ), f"IndexError: {self.__class__.__name__}.select({k}, {x}), k={k}, size={self.size}"
1120 assert (
1121 0 <= x < 1 << self.log
1122 ), f"ValueError: {self.__class__.__name__}.select({k}, {x}), x={x}, LIM={1<<self.log}"
1123 # x の開始位置 s を探す
1124 s = 0
1125 for bit in range(self.log - 1, -1, -1):
1126 if x >> bit & 1:
1127 s = self.v[bit].rank0(self.size) + self.v[bit].rank1(s)
1128 else:
1129 s = self.v[bit].rank0(s)
1130 s += k # s から k 進んだ位置が、元の列で何番目か調べる
1131 for bit in range(self.log):
1132 if x >> bit & 1:
1133 s = self.v[bit].select1(s - self.v[bit].rank0(self.size))
1134 else:
1135 s = self.v[bit].select0(s)
1136 return s
1137
1138 def kth_smallest(self, l: int, r: int, k: int) -> int:
1139 """``a[l, r)`` の中で ``k`` 番目に **小さい** 値を返します。
1140 :math:`O(\\log{\\sigma})` です。
1141 """
1142 assert (
1143 0 <= l <= r <= self.size
1144 ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), size={self.size}"
1145 assert (
1146 0 <= k < r - l
1147 ), f"IndexError: {self.__class__.__name__}.kth_smallest({l}, {r}, {k}), wrong k"
1148 s = 0
1149 mid = self.mid
1150 for bit in range(self.log - 1, -1, -1):
1151 r0, l0 = self.v[bit].rank0(r), self.v[bit].rank0(l)
1152 cnt = r0 - l0 # 区間内の 0 の個数
1153 if cnt <= k: # 0 が k 以下のとき、 k 番目は 1
1154 s |= 1 << bit
1155 k -= cnt
1156 # この 1 が次の bit 列でどこに行くか
1157 l = l - l0 + mid[bit]
1158 r = r - r0 + mid[bit]
1159 else:
1160 # この 0 が次の bit 列でどこに行くか
1161 l = l0
1162 r = r0
1163 return s
1164
1165 quantile = kth_smallest
1166
1167 def kth_largest(self, l: int, r: int, k: int) -> int:
1168 """``a[l, r)`` の中で ``k`` 番目に **大きい値** を返します。
1169 :math:`O(\\log{\\sigma})` です。
1170 """
1171 assert (
1172 0 <= l <= r <= self.size
1173 ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), size={self.size}"
1174 assert (
1175 0 <= k < r - l
1176 ), f"IndexError: {self.__class__.__name__}.kth_largest({l}, {r}, {k}), wrong k"
1177 return self.kth_smallest(l, r, r - l - k - 1)
1178
1179 def topk(self, l: int, r: int, k: int) -> list[tuple[int, int]]:
1180 """``a[l, r)`` の中で、要素を出現回数が多い順にその頻度とともに ``k`` 個返します。
1181 :math:`O(\\min(r-l, \\sigam) \\log(\\sigam))` です。
1182
1183 Note:
1184 :math:`\\sigma` が大きい場合、計算量に注意です。
1185
1186 Returns:
1187 list[tuple[int, int]]: ``(要素, 頻度)`` を要素とする配列です。
1188 """
1189 assert (
1190 0 <= l <= r <= self.size
1191 ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), size={self.size}"
1192 assert (
1193 0 <= k < r - l
1194 ), f"IndexError: {self.__class__.__name__}.topk({l}, {r}, {k}), wrong k"
1195 # heap[-length, x, l, bit]
1196 hq: list[tuple[int, int, int, int]] = [(-(r - l), 0, l, self.log - 1)]
1197 ans = []
1198 while hq:
1199 length, x, l, bit = heappop(hq)
1200 length = -length
1201 if bit == -1:
1202 ans.append((x, length))
1203 k -= 1
1204 if k == 0:
1205 break
1206 else:
1207 r = l + length
1208 l0 = self.v[bit].rank0(l)
1209 r0 = self.v[bit].rank0(r)
1210 if l0 < r0:
1211 heappush(hq, (-(r0 - l0), x, l0, bit - 1))
1212 l1 = self.v[bit].rank1(l) + self.mid[bit]
1213 r1 = self.v[bit].rank1(r) + self.mid[bit]
1214 if l1 < r1:
1215 heappush(hq, (-(r1 - l1), x | (1 << bit), l1, bit - 1))
1216 return ans
1217
1218 def sum(self, l: int, r: int) -> int:
1219 """``topk`` メソッドを用いて ``a[l, r)`` の総和を返します。
1220 計算量に注意です。
1221 """
1222 assert False, "Yabai Keisanryo Error"
1223 return sum(k * v for k, v in self.topk(l, r, r - l))
1224
1225 def _range_freq(self, l: int, r: int, x: int) -> int:
1226 """a[l, r) で x 未満の要素の数を返す"""
1227 ans = 0
1228 for bit in range(self.log - 1, -1, -1):
1229 l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
1230 if x >> bit & 1:
1231 # bit が立ってたら、区間の 0 の個数を答えに加算し、新たな区間は 1 のみ
1232 ans += r0 - l0
1233 # 1 が次の bit 列でどこに行くか
1234 l += self.mid[bit] - l0
1235 r += self.mid[bit] - r0
1236 else:
1237 # 0 が次の bit 列でどこに行くか
1238 l, r = l0, r0
1239 return ans
1240
1241 def range_freq(self, l: int, r: int, x: int, y: int) -> int:
1242 """``a[l, r)`` に含まれる、 ``x`` 以上 ``y`` 未満である要素の個数を返します。
1243 :math:`O(\\log{\\sigma})` です。
1244 """
1245 assert (
1246 0 <= l <= r <= self.size
1247 ), f"IndexError: {self.__class__.__name__}.range_freq({l}, {r}, {x}, {y})"
1248 assert 0 <= x <= y < self.sigma, f"ValueError"
1249 return self._range_freq(l, r, y) - self._range_freq(l, r, x)
1250
1251 def prev_value(self, l: int, r: int, x: int) -> int:
1252 """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最大の要素を返します。
1253 :math:`O(\\log{\\sigma})` です。
1254 """
1255 assert (
1256 0 <= l <= r <= self.size
1257 ), f"IndexError: {self.__class__.__name__}.prev_value({l}, {r}, {x})"
1258 return self.kth_smallest(l, r, self._range_freq(l, r, x) - 1)
1259
1260 def next_value(self, l: int, r: int, x: int) -> int:
1261 """``a[l, r)`` で、``x`` 以上 ``y`` 未満であるような要素のうち最小の要素を返します。
1262 :math:`O(\\log{\\sigma})` です。
1263 """
1264 assert (
1265 0 <= l <= r <= self.size
1266 ), f"IndexError: {self.__class__.__name__}.next_value({l}, {r}, {x})"
1267 return self.kth_smallest(l, r, self._range_freq(l, r, x))
1268
1269 def range_count(self, l: int, r: int, x: int) -> int:
1270 """``a[l, r)`` に含まれる ``x`` の個数を返します。
1271 ``wm.rank(r, x) - wm.rank(l, x)`` と等価です。
1272 :math:`O(\\log{\\sigma})` です。
1273 """
1274 assert (
1275 0 <= l <= r <= self.size
1276 ), f"IndexError: {self.__class__.__name__}.range_count({l}, {r}, {x})"
1277 return self.rank(r, x) - self.rank(l, x)
1278
1279 def __len__(self) -> int:
1280 return self.size
1281
1282 def __str__(self) -> str:
1283 return (
1284 f"{self.__class__.__name__}({[self.access(i) for i in range(self.size)]})"
1285 )
1286
1287 __repr__ = __str__
1288from typing import Sequence
1289from array import array
1290
1291
1292class DynamicWaveletMatrix(WaveletMatrix):
1293 """動的ウェーブレット行列です。
1294
1295 (静的)ウェーブレット行列でできる操作に加えて ``insert / pop / set`` 等ができます。
1296 - ``BitVector`` を平衡二分木にしています(``AVLTreeBitVector``)。あらゆる操作に平衡二分木の log がつきます。これヤバくね
1297
1298 :math:`O(n\\log{(\\sigma)})` です。
1299 """
1300
1301 def __init__(self, sigma: int, a: Sequence[int] = []) -> None:
1302 self.sigma: int = sigma
1303 self.log: int = (sigma - 1).bit_length()
1304 self.v: list[AVLTreeBitVector] = [AVLTreeBitVector()] * self.log
1305 self.mid: array[int] = array("I", bytes(4 * self.log))
1306 self.size: int = len(a)
1307 self._build(a)
1308
1309 def _build(self, a: Sequence[int]) -> None:
1310 v = array("B", bytes(self.size))
1311 for bit in range(self.log - 1, -1, -1):
1312 # bit目の0/1に応じてvを構築 + aを安定ソート
1313 zero, one = [], []
1314 for i, e in enumerate(a):
1315 if e >> bit & 1:
1316 v[i] = 1
1317 one.append(e)
1318 else:
1319 v[i] = 0
1320 zero.append(e)
1321 self.mid[bit] = len(zero) # 境界をmid[bit]に保持
1322 self.v[bit] = AVLTreeBitVector(v)
1323 a = zero + one
1324
1325 def reserve(self, n: int) -> None:
1326 """``n`` 要素分のメモリを確保します。
1327 :math:`O(n)` です。
1328 """
1329 assert n >= 0, f"ValueError: {self.__class__.__name__}.reserve({n})"
1330 for v in self.v:
1331 v.reserve(n)
1332
1333 def insert(self, k: int, x: int) -> None:
1334 """位置 ``k`` に ``x`` を挿入します。
1335 :math:`O(\\log{(n)}\\log{(\\sigma)})` です。
1336 """
1337 assert (
1338 0 <= k <= self.size
1339 ), f"IndexError: {self.__class__.__name__}.insert({k}, {x}), n={self.size}"
1340 assert (
1341 0 <= x < 1 << self.log
1342 ), f"ValueError: {self.__class__.__name__}.insert({k}, {x}), LIM={1<<self.log}"
1343 mid = self.mid
1344 for bit in range(self.log - 1, -1, -1):
1345 v = self.v[bit]
1346 # if x >> bit & 1:
1347 # v.insert(k, 1)
1348 # k = v.rank1(k) + mid[bit]
1349 # else:
1350 # v.insert(k, 0)
1351 # mid[bit] += 1
1352 # k = v.rank0(k)
1353 if x >> bit & 1:
1354 s = v._insert_and_rank1(k, 1)
1355 k = s + mid[bit]
1356 else:
1357 s = v._insert_and_rank1(k, 0)
1358 k -= s
1359 mid[bit] += 1
1360 self.size += 1
1361
1362 def pop(self, k: int) -> int:
1363 """位置 ``k`` の要素を削除し、その値を返します。
1364 :math:`O(\\log{(n)}\\log{(\\sigma)})` です。
1365 """
1366 assert (
1367 0 <= k < self.size
1368 ), f"IndexError: {self.__class__.__name__}.pop({k}), n={self.size}"
1369 mid = self.mid
1370 ans = 0
1371 for bit in range(self.log - 1, -1, -1):
1372 v = self.v[bit]
1373 # K = k
1374 # if v.access(k):
1375 # ans |= 1 << bit
1376 # k = v.rank1(k) + mid[bit]
1377 # else:
1378 # mid[bit] -= 1
1379 # k = v.rank0(k)
1380 # v.pop(K)
1381 sb = v._access_pop_and_rank1(k)
1382 s = sb >> 1
1383 if sb & 1:
1384 ans |= 1 << bit
1385 k = s + mid[bit]
1386 else:
1387 mid[bit] -= 1
1388 k -= s
1389 self.size -= 1
1390 return ans
1391
1392 def set(self, k: int, x: int) -> None:
1393 """位置 ``k`` の要素を ``x`` に更新します。
1394 :math:`O(\\log{(n)}\\log{(\\sigma)})` です。
1395 """
1396 assert (
1397 0 <= k < self.size
1398 ), f"IndexError: {self.__class__.__name__}.set({k}, {x}), n={self.size}"
1399 assert (
1400 0 <= x < 1 << self.log
1401 ), f"ValueError: {self.__class__.__name__}.set({k}, {x}), LIM={1<<self.log}"
1402 self.pop(k)
1403 self.insert(k, x)
1404
1405 def __setitem__(self, k: int, x: int):
1406 assert (
1407 0 <= k < self.size
1408 ), f"IndexError: {self.__class__.__name__}[{k}] = {x}, n={self.size}"
1409 assert (
1410 0 <= x < 1 << self.log
1411 ), f"ValueError: {self.__class__.__name__}[{k}] = {x}, LIM={1<<self.log}"
1412 self.set(k, x)
1413
1414 def __str__(self):
1415 return f"{self.__class__.__name__}({[self[i] for i in range(self.size)]})"
仕様¶
- class DynamicWaveletMatrix(sigma: int, a: Sequence[int] = [])[source]¶
Bases:
WaveletMatrix
動的ウェーブレット行列です。
- (静的)ウェーブレット行列でできる操作に加えて
insert / pop / set
等ができます。 BitVector
を平衡二分木にしています(AVLTreeBitVector
)。あらゆる操作に平衡二分木の log がつきます。これヤバくね
\(O(n\log{(\sigma)})\) です。
- (静的)ウェーブレット行列でできる操作に加えて