avl_tree_bit_vector¶
ソースコード¶
from titan_pylib.data_structures.bit_vector.avl_tree_bit_vector import AVLTreeBitVector
展開済みコード¶
1# from titan_pylib.data_structures.bit_vector.avl_tree_bit_vector import AVLTreeBitVector
2# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
3# BitVectorInterface,
4# )
5from abc import ABC, abstractmethod
6
7
class BitVectorInterface(ABC):
    """Abstract interface shared by bit vector implementations.

    Concrete classes provide random access, rank/select queries for both bit
    values, ``len()`` and string conversion. Every method here is abstract and
    only raises ``NotImplementedError`` if reached through ``super()``.
    """

    # --- random access -------------------------------------------------

    @abstractmethod
    def access(self, k: int) -> int:
        """Return the bit at index ``k``."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, k: int) -> int:
        """Return the bit at index ``k`` (``bv[k]``)."""
        raise NotImplementedError

    # --- rank: number of 0s / 1s in the prefix a[0, r) ------------------

    @abstractmethod
    def rank0(self, r: int) -> int:
        """Return how many ``0`` bits lie in ``a[0, r)``."""
        raise NotImplementedError

    @abstractmethod
    def rank1(self, r: int) -> int:
        """Return how many ``1`` bits lie in ``a[0, r)``."""
        raise NotImplementedError

    @abstractmethod
    def rank(self, r: int, v: int) -> int:
        """Return how many bits equal to ``v`` lie in ``a[0, r)``."""
        raise NotImplementedError

    # --- select: position of the k-th 0 / 1 -----------------------------

    @abstractmethod
    def select0(self, k: int) -> int:
        """Return the index of the ``k``-th ``0`` bit."""
        raise NotImplementedError

    @abstractmethod
    def select1(self, k: int) -> int:
        """Return the index of the ``k``-th ``1`` bit."""
        raise NotImplementedError

    @abstractmethod
    def select(self, k: int, v: int) -> int:
        """Return the index of the ``k``-th bit equal to ``v``."""
        raise NotImplementedError

    # --- container protocol ---------------------------------------------

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored bits."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the bits."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Return the debugging representation."""
        raise NotImplementedError
53from array import array
54from typing import Iterable, Final, Sequence
55
# Capacity, in bits, of a single AVLTreeBitVector node: a node's ``bit_len``
# grows only while it is below this value, and ``_build`` packs at most this
# many bits per node.
titan_pylib_AVLTreeBitVector_W: Final[int] = 31
57
58
class AVLTreeBitVector(BitVectorInterface):
    """Bit vector backed by an AVL tree (not a succinct data structure).

    The bit sequence is cut into chunks and every tree node stores one chunk of
    1 to ``titan_pylib_AVLTreeBitVector_W`` (= 31) bits, packed MSB-first into
    ``key[node]``: the first bit of the chunk is the most significant bit.
    Packing many bits per node makes operations up to about 31 times faster
    than a one-bit-per-node tree. Keeping nodes at least half full (16-31
    bits) would likely be better, but that is not enforced.

    Nodes live in parallel arrays indexed by node id. Id ``0`` is a sentinel
    meaning "no node"; its ``size``/``total``/``bit_len`` stay ``0``.
    Per node:

    - ``key``: the packed bits held by the node itself,
    - ``bit_len``: number of bits in ``key``,
    - ``size``: number of bits in the whole subtree,
    - ``total``: number of ``1`` bits in the whole subtree,
    - ``left`` / ``right``: child ids (``0`` = none),
    - ``balance``: height(left subtree) - height(right subtree).
    """

    @staticmethod
    def _popcount(x: int) -> int:
        """Return the number of set bits of ``x`` (``0 <= x < 2**32``), SWAR style."""
        x = x - ((x >> 1) & 0x55555555)
        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
        x = (x + (x >> 4)) & 0x0F0F0F0F
        x += x >> 8
        x += x >> 16
        return x & 0x0000007F

    def __init__(self, a: Iterable[int] = ()):
        """Build the bit vector from ``a``. :math:`O(n)`.

        Args:
            a (Iterable[int], optional): Initial bits; each must be ``0`` or ``1``.
        """
        self.root = 0
        # Index 0 of every array is the sentinel "null" node.
        self.bit_len = array("B", bytes(1))
        self.key = array("I", bytes(4))
        self.size = array("I", bytes(4))
        self.total = array("I", bytes(4))
        self.left = array("I", bytes(4))
        self.right = array("I", bytes(4))
        self.balance = array("b", bytes(1))
        self.end = 1  # next unused node id
        if a:
            self._build(a)

    def reserve(self, n: int) -> None:
        """Pre-allocate node storage for about ``n`` more bits. :math:`O(n)`.

        Appends ``n // W + 1`` zeroed slots to every node array; ``_make_node``
        fills these before it starts appending.
        """
        n = n // titan_pylib_AVLTreeBitVector_W + 1
        a = array("I", bytes(4 * n))
        self.bit_len += array("B", bytes(n))
        self.key += a
        self.size += a
        self.total += a
        self.left += a
        self.right += a
        self.balance += array("b", bytes(n))

    def _build(self, a: Iterable[int]) -> None:
        """Pack ``a`` into W-bit nodes and link them as a balanced tree."""
        key, bit_len, left, right, size, balance, total = (
            self.key,
            self.bit_len,
            self.left,
            self.right,
            self.size,
            self.balance,
            self.total,
        )
        _popcount = AVLTreeBitVector._popcount
        W = titan_pylib_AVLTreeBitVector_W

        def rec(lr: int) -> int:
            # ``lr`` encodes the half-open node-id range [l, r) as ``l << bit | r``.
            # Returns ``subtree_root << bit | subtree_height``.
            l, r = lr >> bit, lr & msk
            mid = (l + r) >> 1
            hl, hr = 0, 0
            if l != mid:
                le = rec(l << bit | mid)
                left[mid], hl = le >> bit, le & msk
                size[mid] += size[left[mid]]
                total[mid] += total[left[mid]]
            if mid + 1 != r:
                ri = rec((mid + 1) << bit | r)
                right[mid], hr = ri >> bit, ri & msk
                size[mid] += size[right[mid]]
                total[mid] += total[right[mid]]
            balance[mid] = hl - hr
            return mid << bit | (max(hl, hr) + 1)

        if not isinstance(a, Sequence):
            a = list(a)
        n = len(a)
        if n == 0:
            # An empty iterator passes ``if a:`` in ``__init__``; building from
            # it would make ``rec`` recurse on an empty range forever.
            return
        bit = n.bit_length() + 2
        msk = (1 << bit) - 1
        end = self.end
        self.reserve(n)
        indx = end
        for i in range(0, n, W):
            v = 0
            j = 0
            while j < W and i + j < n:
                v = (v << 1) | a[i + j]
                j += 1
            key[indx] = v
            bit_len[indx] = j
            size[indx] = j
            total[indx] = _popcount(v)
            indx += 1
        self.end = indx
        self.root = rec(end << bit | self.end) >> bit

    def _rotate_L(self, node: int) -> int:
        """Rotate left-heavy ``node`` with its left child ``u``; return ``u``, the new subtree root."""
        left, right, size, balance, total = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.total,
        )
        u = left[node]
        size[u] = size[node]
        total[u] = total[node]
        size[node] -= size[left[u]] + self.bit_len[u]
        total[node] -= total[left[u]] + AVLTreeBitVector._popcount(self.key[u])
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # ``u`` was balanced (only possible while deleting).
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Rotate right-heavy ``node`` with its right child ``u``; return ``u``, the new subtree root."""
        left, right, size, balance, total = (
            self.left,
            self.right,
            self.size,
            self.balance,
            self.total,
        )
        u = right[node]
        size[u] = size[node]
        total[u] = total[node]
        size[node] -= size[right[u]] + self.bit_len[u]
        total[node] -= total[right[u]] + AVLTreeBitVector._popcount(self.key[u])
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            # ``u`` was balanced (only possible while deleting).
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balances after a double rotation made ``node`` the subtree root.

        Uses ``node``'s balance from before the rotation to set its children,
        then marks ``node`` itself as balanced.
        """
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Double rotation for a left-heavy ``node`` whose left child is right-heavy.

        ``E = right[left[node]]`` becomes and is returned as the new subtree root.
        """
        left, right, size, total = self.left, self.right, self.size, self.total
        B = left[node]
        E = right[B]
        size[E] = size[node]
        size[node] -= size[B] - size[right[E]]
        size[B] -= size[right[E]] + self.bit_len[E]
        total[E] = total[node]
        total[node] -= total[B] - total[right[E]]
        total[B] -= total[right[E]] + AVLTreeBitVector._popcount(self.key[E])
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Double rotation for a right-heavy ``node`` whose right child is left-heavy.

        ``D = left[right[node]]`` becomes and is returned as the new subtree root.
        """
        left, right, size, total = self.left, self.right, self.size, self.total
        C = right[node]
        D = left[C]
        size[D] = size[node]
        size[node] -= size[C] - size[left[D]]
        size[C] -= size[left[D]] + self.bit_len[D]
        total[D] = total[node]
        total[node] -= total[C] - total[left[D]]
        total[C] -= total[left[D]] + AVLTreeBitVector._popcount(self.key[D])
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _pref(self, r: int) -> int:
        """Return the number of ``1`` bits in ``a[0, r)``, with ``r`` clamped to ``[0, len(self)]``."""
        left, right, bit_len, size, key, total = (
            self.left,
            self.right,
            self.bit_len,
            self.size,
            self.key,
            self.total,
        )
        node = self.root
        if r >= size[node]:
            # Whole vector. This also keeps the descent below from walking past
            # the last node into the size-0 sentinel, where it would loop forever.
            return total[node]
        s = 0
        while r > 0:
            t = size[left[node]] + bit_len[node]
            if t - bit_len[node] < r <= t:
                # The prefix ends inside this node: count its first ``r`` bits.
                r -= size[left[node]]
                s += total[left[node]] + AVLTreeBitVector._popcount(
                    key[node] >> (bit_len[node] - r)
                )
                break
            if t > r:
                node = left[node]
            else:
                s += total[left[node]] + AVLTreeBitVector._popcount(key[node])
                node = right[node]
                r -= t
        return s

    def _make_node(self, key: int, bit_len: int) -> int:
        """Allocate a leaf holding the ``bit_len`` packed bits ``key``; return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.bit_len.append(bit_len)
            self.size.append(bit_len)
            self.total.append(AVLTreeBitVector._popcount(key))
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Slot pre-allocated by ``reserve``; reset the links as well so it
            # is a clean leaf whatever it held before.
            self.key[end] = key
            self.bit_len[end] = bit_len
            self.size[end] = bit_len
            self.total[end] = AVLTreeBitVector._popcount(key)
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end += 1
        return end

    def _rebalance_after_insert(self, path: list[int], d: int) -> None:
        """Update balances bottom-up after a leaf was attached below ``path[-1]``.

        ``d`` holds one direction bit per entry of ``path`` (least significant
        bit = last entry; ``1`` = the new leaf is in that node's left subtree).
        At most one single or double rotation is performed. ``path`` is consumed.
        """
        left, right, balance = self.left, self.right, self.balance
        new_node = 0
        while path:
            node = path.pop()
            balance[node] += 1 if d & 1 else -1
            d >>= 1
            if balance[node] == 0:
                # Height of this subtree did not change.
                break
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] == -1
                    else self._rotate_L(node)
                )
                break
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] == 1
                    else self._rotate_R(node)
                )
                break
        if new_node:
            if path:
                if d & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
            else:
                self.root = new_node

    def _insert_into(
        self, path: list[int], d: int, node: int, k: int, key: int
    ) -> None:
        """Insert bit ``key`` at offset ``k`` (``0 <= k <= bit_len[node]``) of ``node``.

        ``path`` lists the ancestors of ``node`` and ``d`` the direction bits
        taken from them (``1`` = left). Their ``size``/``total`` must already
        account for the new bit.
        """
        W = titan_pylib_AVLTreeBitVector_W
        left, right, size, bit_len, keys, total = (
            self.left,
            self.right,
            self.size,
            self.bit_len,
            self.key,
            self.total,
        )
        v = keys[node]
        if bit_len[node] < W:
            # Room left: splice the bit in; ``bl`` bits follow the insert point.
            bl = bit_len[node] - k
            keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
            bit_len[node] += 1
            size[node] += 1
            total[node] += key
            return
        # Full node: splice the bit in, then push the chunk's first (most
        # significant) bit out to the in-order predecessor so it stays W bits.
        path.append(node)
        size[node] += 1
        total[node] += key
        bl = W - k
        v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
        left_key = v >> W  # the pushed-out bit, 0 or 1
        keys[node] = v & ((1 << W) - 1)
        d = d << 1 | 1
        node = left[node]
        if not node:
            # No left subtree: the pushed-out bit becomes a new left leaf.
            left[path[-1]] = self._make_node(left_key, 1)
        else:
            # Walk to the rightmost node of the left subtree (the predecessor).
            path.append(node)
            size[node] += 1
            total[node] += left_key
            d <<= 1
            while right[node]:
                node = right[node]
                path.append(node)
                size[node] += 1
                total[node] += left_key
                d <<= 1
            if bit_len[node] < W:
                # The predecessor has room: append the bit at its end.
                bit_len[node] += 1
                keys[node] = (keys[node] << 1) | left_key
                return
            right[node] = self._make_node(left_key, 1)
        self._rebalance_after_insert(path, d)

    def insert(self, k: int, key: int) -> None:
        """Insert the bit ``key`` so that it becomes the ``k``-th element.

        :math:`O(\\log{n})`.

        Args:
            k (int): Insert position, ``0 <= k <= len(self)``.
            key (int): Bit to insert; must be ``0`` or ``1``.
        """
        assert 0 <= k <= len(self)
        assert key == 0 or key == 1, "ValueError"
        if self.root == 0:
            self.root = self._make_node(key, 1)
            return
        left, right, size, bit_len, total = (
            self.left,
            self.right,
            self.size,
            self.bit_len,
            self.total,
        )
        node = self.root
        path = []
        d = 0  # direction bits of ``path``: 1 = went left, 0 = went right
        while True:
            t = size[left[node]] + bit_len[node]
            if t - bit_len[node] <= k <= t:
                break
            d <<= 1
            size[node] += 1
            total[node] += key
            path.append(node)
            if t > k:
                node = left[node]
                d |= 1
            else:
                node = right[node]
                k -= t
        self._insert_into(path, d, node, k - size[left[node]], key)

    def _pop_under(self, path: list[int], d: int, node: int, res: int) -> None:
        """Unlink ``node`` (a 1-bit node whose bit is ``res``) and rebalance.

        ``path``/``d`` describe the ancestors of ``node``; their counters have
        not yet been decremented. If ``node`` has two children, it takes over
        the contents of its in-order predecessor ``lmax``, and ``lmax`` is
        unlinked instead.
        """
        left, right, size, bit_len, balance, keys, total = (
            self.left,
            self.right,
            self.size,
            self.bit_len,
            self.balance,
            self.key,
            self.total,
        )
        # ``fd`` marks (with 1 bits) the path entries strictly between ``node``
        # and ``lmax``: those lose lmax's bits rather than the single ``res``.
        fd, lmax_total, lmax_bit_len = 0, 0, 0
        if left[node] and right[node]:
            path.append(node)
            d <<= 1
            d |= 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                d <<= 1
                fd <<= 1
                fd |= 1
                lmax = right[lmax]
            lmax_total = AVLTreeBitVector._popcount(keys[lmax])
            lmax_bit_len = bit_len[lmax]
            keys[node] = keys[lmax]
            bit_len[node] = lmax_bit_len
            node = lmax
        # ``node`` now has at most one child; splice it out.
        cnode = right[node] if left[node] == 0 else left[node]
        if path:
            if d & 1:
                left[path[-1]] = cnode
            else:
                right[path[-1]] = cnode
        else:
            self.root = cnode
            return
        while path:
            new_node = 0
            node = path.pop()
            balance[node] -= 1 if d & 1 else -1
            size[node] -= lmax_bit_len if fd & 1 else 1
            total[node] -= lmax_total if fd & 1 else res
            d >>= 1
            fd >>= 1
            if balance[node] == 2:
                new_node = (
                    self._rotate_LR(node)
                    if balance[left[node]] < 0
                    else self._rotate_L(node)
                )
            elif balance[node] == -2:
                new_node = (
                    self._rotate_RL(node)
                    if balance[right[node]] > 0
                    else self._rotate_R(node)
                )
            elif balance[node] != 0:
                # Subtree height unchanged: stop rebalancing.
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return
                if d & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node] != 0:
                    break
        # Remaining ancestors only need their counters fixed.
        while path:
            node = path.pop()
            size[node] -= lmax_bit_len if fd & 1 else 1
            total[node] -= lmax_total if fd & 1 else res
            fd >>= 1

    def _erase_in_node(self, path: list[int], d: int, node: int, k: int) -> int:
        """Remove the bit at offset ``k`` of ``node`` and return it.

        ``path``/``d`` describe the ancestors of ``node``, whose counters are
        decremented here.
        """
        bit_len, keys, size, total = self.bit_len, self.key, self.size, self.total
        v = keys[node]
        bl = bit_len[node] - k - 1  # number of bits after offset ``k``
        res = v >> bl & 1
        if bit_len[node] == 1:
            self._pop_under(path, d, node, res)
            return res
        keys[node] = ((v >> (bl + 1)) << bl) | (v & ((1 << bl) - 1))
        bit_len[node] -= 1
        size[node] -= 1
        total[node] -= res
        for p in path:
            size[p] -= 1
            total[p] -= res
        return res

    def pop(self, k: int) -> int:
        """Remove the ``k``-th element and return its value. :math:`O(\\log{n})`.

        Args:
            k (int): Index to remove, ``0 <= k < len(self)``.
        """
        assert 0 <= k < len(self)
        left, right, size, bit_len = self.left, self.right, self.size, self.bit_len
        node = self.root
        d = 0
        path = []
        while True:
            t = size[left[node]] + bit_len[node]
            if t - bit_len[node] <= k < t:
                break
            path.append(node)
            d <<= 1
            if t > k:
                node = left[node]
                d |= 1
            else:
                node = right[node]
                k -= t
        return self._erase_in_node(path, d, node, k - size[left[node]])

    def set(self, k: int, v: int) -> None:
        """Set the ``k``-th element to ``v``. :math:`O(\\log{n})`.

        Args:
            k (int): Index to update.
            v (int): New value; must be ``0`` or ``1``.
        """
        self.__setitem__(k, v)

    def tolist(self) -> list[int]:
        """Return the bits as a list. :math:`O(n)`."""
        left, right, key, bit_len = self.left, self.right, self.key, self.bit_len
        a = []
        if not self.root:
            return a

        def rec(node: int) -> None:
            if left[node]:
                rec(left[node])
            # MSB-first: the node's first bit is its most significant one.
            for i in range(bit_len[node] - 1, -1, -1):
                a.append(key[node] >> i & 1)
            if right[node]:
                rec(right[node])

        rec(self.root)
        return a

    def _debug_acc(self) -> None:
        """Debug helper: assert every node's ``total`` equals the popcount of its subtree."""
        left, right = self.left, self.right
        key = self.key

        def rec(node: int) -> int:
            acc = self._popcount(key[node])
            if left[node]:
                acc += rec(left[node])
            if right[node]:
                acc += rec(right[node])
            if acc != self.total[node]:
                assert False, "acc Error"
            return acc

        rec(self.root)
        print("debug_acc ok.")

    def access(self, k: int) -> int:
        """Return the ``k``-th element. :math:`O(\\log{n})`.

        Args:
            k (int): Index to read.
        """
        return self.__getitem__(k)

    def rank0(self, r: int) -> int:
        """Return the number of ``0`` bits in ``a[0, r)``. :math:`O(\\log{n})`.

        ``r`` is clamped to ``[0, len(self)]``.
        """
        r = min(r, len(self))
        if r <= 0:
            return 0
        return r - self._pref(r)

    def rank1(self, r: int) -> int:
        """Return the number of ``1`` bits in ``a[0, r)``. :math:`O(\\log{n})`.

        ``r`` is clamped to ``[0, len(self)]``.
        """
        return self._pref(r)

    def rank(self, r: int, v: int) -> int:
        """Return the number of bits equal to ``v`` in ``a[0, r)``. :math:`O(\\log{n})`."""
        return self.rank1(r) if v else self.rank0(r)

    def select0(self, k: int) -> int:
        """Return the index of the ``k``-th (0-based) ``0``, or ``-1`` if none.

        :math:`O(\\log{n}^2)` (binary search over ``rank0``).
        """
        if k < 0 or self.rank0(len(self)) <= k:
            return -1
        # Invariant: rank0(l) <= k < rank0(r).
        l, r = 0, len(self)
        while r - l > 1:
            m = (l + r) >> 1
            if m - self._pref(m) > k:
                r = m
            else:
                l = m
        return l

    def select1(self, k: int) -> int:
        """Return the index of the ``k``-th (0-based) ``1``, or ``-1`` if none.

        :math:`O(\\log{n}^2)` (binary search over ``rank1``).
        """
        if k < 0 or self.rank1(len(self)) <= k:
            return -1
        # Invariant: rank1(l) <= k < rank1(r).
        l, r = 0, len(self)
        while r - l > 1:
            m = (l + r) >> 1
            if self._pref(m) > k:
                r = m
            else:
                l = m
        return l

    def select(self, k: int, v: int) -> int:
        """Return the index of the ``k``-th ``v``, or ``-1`` if none. :math:`O(\\log{n}^2)`."""
        return self.select1(k) if v else self.select0(k)

    def _insert_and_rank1(self, k: int, key: int) -> int:
        """Insert ``key`` at index ``k`` (as :meth:`insert`) and return the
        number of ``1`` bits in ``a[0, k)`` before the insertion."""
        assert 0 <= k <= len(self)
        assert key == 0 or key == 1, "ValueError"
        if self.root == 0:
            self.root = self._make_node(key, 1)
            return 0
        left, right, size, bit_len, keys, total = (
            self.left,
            self.right,
            self.size,
            self.bit_len,
            self.key,
            self.total,
        )
        _popcount = AVLTreeBitVector._popcount
        node = self.root
        s = 0
        path = []
        d = 0  # direction bits of ``path``: 1 = went left, 0 = went right
        while True:
            t = size[left[node]] + bit_len[node]
            if t - bit_len[node] <= k <= t:
                break
            d <<= 1
            size[node] += 1
            total[node] += key
            path.append(node)
            if t > k:
                node = left[node]
                d |= 1
            else:
                s += total[left[node]] + _popcount(keys[node])
                node = right[node]
                k -= t
        k -= size[left[node]]
        # Count before the node is modified: left subtree + first ``k`` bits.
        s += total[left[node]] + _popcount(keys[node] >> (bit_len[node] - k))
        self._insert_into(path, d, node, k, key)
        return s

    def _access_pop_and_rank1(self, k: int) -> int:
        """Remove the ``k``-th element and return ``rank1(k) << 1 | value``,
        where ``rank1(k)`` is measured before the removal."""
        assert 0 <= k < len(self)
        left, right, size = self.left, self.right, self.size
        bit_len, keys, total = self.bit_len, self.key, self.total
        _popcount = AVLTreeBitVector._popcount
        s = 0
        node = self.root
        d = 0
        path = []
        while True:
            t = size[left[node]] + bit_len[node]
            if t - bit_len[node] <= k < t:
                break
            path.append(node)
            d <<= 1
            if t > k:
                node = left[node]
                d |= 1
            else:
                s += total[left[node]] + _popcount(keys[node])
                node = right[node]
                k -= t
        k -= size[left[node]]
        s += total[left[node]] + _popcount(keys[node] >> (bit_len[node] - k))
        res = self._erase_in_node(path, d, node, k)
        return s << 1 | res

    def __getitem__(self, k: int) -> int:
        """Return the ``k``-th element. :math:`O(\\log{n})`."""
        assert 0 <= k < len(self)
        left, right, bit_len, size, key = (
            self.left,
            self.right,
            self.bit_len,
            self.size,
            self.key,
        )
        node = self.root
        while True:
            t = size[left[node]] + bit_len[node]
            if t - bit_len[node] <= k < t:
                k -= size[left[node]]
                return key[node] >> (bit_len[node] - k - 1) & 1
            if t > k:
                node = left[node]
            else:
                node = right[node]
                k -= t

    def __setitem__(self, k: int, v: int) -> None:
        """Set the ``k``-th element to ``v`` (``0`` or ``1``). :math:`O(\\log{n})`."""
        assert 0 <= k < len(self)
        assert v == 0 or v == 1, "ValueError"
        left, right, bit_len, size, key, total = (
            self.left,
            self.right,
            self.bit_len,
            self.size,
            self.key,
            self.total,
        )
        node = self.root
        path = []
        while True:
            t = size[left[node]] + bit_len[node]
            path.append(node)
            if t - bit_len[node] <= k < t:
                break
            if t > k:
                node = left[node]
            else:
                node = right[node]
                k -= t
        # Bits are stored MSB-first (see ``__getitem__``/``tolist``), so offset
        # ``k`` within the node is bit ``bit_len - k - 1``, not bit ``k``.
        shift = bit_len[node] - (k - size[left[node]]) - 1
        old = key[node] >> shift & 1
        if old == v:
            return
        key[node] ^= 1 << shift
        delta = v - old
        for p in path:
            total[p] += delta

    def __str__(self) -> str:
        return str(self.tolist())

    def __len__(self) -> int:
        return self.size[self.root]

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"
仕様¶
- class AVLTreeBitVector(a: Iterable[int] = [])[source]¶
Bases:
BitVectorInterface
AVL木で書かれたビットベクトルです。簡潔でもなんでもありません。
bit列を管理するわけですが、各節点は 1~31 bit (titan_pylib_AVLTreeBitVector_W = 31) を持つようにしています。 これにより、最大でおよそ 31 倍の高速化が行えます。(各節点を 16~31 bit に保つようにするとよいかもしれません)
- insert(k: int, key: int) → None [source]
  ``k`` 番目に ``key`` を挿入します。 \(O(\log n)\) です。
  - Parameters:
    - k (int) – 挿入位置のインデックスです。
    - key (int) – 挿入する値です。 ``0`` または ``1`` である必要があります。
- pop(k: int) → int [source]
  ``k`` 番目の要素を削除し、その値を返します。 \(O(\log n)\) です。
  - Parameters:
    - k (int) – 削除位置のインデックスです。