avl_tree_bit_vector

ソースコード

from titan_pylib.data_structures.bit_vector.avl_tree_bit_vector import AVLTreeBitVector

view on github

展開済みコード

  1# from titan_pylib.data_structures.bit_vector.avl_tree_bit_vector import AVLTreeBitVector
  2# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
  3#     BitVectorInterface,
  4# )
  5from abc import ABC, abstractmethod
  6
  7
  8class BitVectorInterface(ABC):
  9
 10    @abstractmethod
 11    def access(self, k: int) -> int:
 12        raise NotImplementedError
 13
 14    @abstractmethod
 15    def __getitem__(self, k: int) -> int:
 16        raise NotImplementedError
 17
 18    @abstractmethod
 19    def rank0(self, r: int) -> int:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def rank1(self, r: int) -> int:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def rank(self, r: int, v: int) -> int:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def select0(self, k: int) -> int:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def select1(self, k: int) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def select(self, k: int, v: int) -> int:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def __len__(self) -> int:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def __str__(self) -> str:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def __repr__(self) -> str:
 52        raise NotImplementedError
 53from array import array
 54from typing import Iterable, Final, Sequence
 55
 56titan_pylib_AVLTreeBitVector_W: Final[int] = 31
 57
 58
 59class AVLTreeBitVector(BitVectorInterface):
 60    """AVL木で書かれたビットベクトルです。簡潔でもなんでもありません。
 61
 62    bit列を管理するわけですが、各節点は 1~32 bit を持つようにしています。
 63    これにより、最大 32 倍高速化が行えます。(16~32bitとするといいんだろうけど)
 64    """
 65
 66    @staticmethod
 67    def _popcount(x: int) -> int:
 68        x = x - ((x >> 1) & 0x55555555)
 69        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
 70        x = x + (x >> 4) & 0x0F0F0F0F
 71        x += x >> 8
 72        x += x >> 16
 73        return x & 0x0000007F
 74
 75    def __init__(self, a: Iterable[int] = []):
 76        """
 77        :math:`O(n)` です。
 78
 79        Args:
 80          a (Iterable[int], optional): 構築元の配列です。
 81        """
 82        self.root = 0
 83        self.bit_len = array("B", bytes(1))
 84        self.key = array("I", bytes(4))
 85        self.size = array("I", bytes(4))
 86        self.total = array("I", bytes(4))
 87        self.left = array("I", bytes(4))
 88        self.right = array("I", bytes(4))
 89        self.balance = array("b", bytes(1))
 90        self.end = 1
 91        if a:
 92            self._build(a)
 93
 94    def reserve(self, n: int) -> None:
 95        """``n`` 要素分のメモリを確保します。
 96        :math:`O(n)` です。
 97        """
 98        n = n // titan_pylib_AVLTreeBitVector_W + 1
 99        a = array("I", bytes(4 * n))
100        self.bit_len += array("B", bytes(n))
101        self.key += a
102        self.size += a
103        self.total += a
104        self.left += a
105        self.right += a
106        self.balance += array("b", bytes(n))
107
108    def _build(self, a: Iterable[int]) -> None:
109        key, bit_len, left, right, size, balance, total = (
110            self.key,
111            self.bit_len,
112            self.left,
113            self.right,
114            self.size,
115            self.balance,
116            self.total,
117        )
118        _popcount = AVLTreeBitVector._popcount
119
120        def rec(lr: int) -> int:
121            l, r = lr >> bit, lr & msk
122            mid = (l + r) >> 1
123            hl, hr = 0, 0
124            if l != mid:
125                le = rec(l << bit | mid)
126                left[mid], hl = le >> bit, le & msk
127                size[mid] += size[left[mid]]
128                total[mid] += total[left[mid]]
129            if mid + 1 != r:
130                ri = rec((mid + 1) << bit | r)
131                right[mid], hr = ri >> bit, ri & msk
132                size[mid] += size[right[mid]]
133                total[mid] += total[right[mid]]
134            balance[mid] = hl - hr
135            return mid << bit | (max(hl, hr) + 1)
136
137        if not isinstance(a, Sequence):
138            a = list(a)
139        n = len(a)
140        bit = n.bit_length() + 2
141        msk = (1 << bit) - 1
142        end = self.end
143        self.reserve(n)
144        i = 0
145        indx = end
146        for i in range(0, n, titan_pylib_AVLTreeBitVector_W):
147            j = 0
148            v = 0
149            while j < titan_pylib_AVLTreeBitVector_W and i + j < n:
150                v <<= 1
151                v |= a[i + j]
152                j += 1
153            key[indx] = v
154            bit_len[indx] = j
155            size[indx] = j
156            total[indx] = _popcount(v)
157            indx += 1
158        self.end = indx
159        self.root = rec(end << bit | self.end) >> bit
160
161    def _rotate_L(self, node: int) -> int:
162        left, right, size, balance, total = (
163            self.left,
164            self.right,
165            self.size,
166            self.balance,
167            self.total,
168        )
169        u = left[node]
170        size[u] = size[node]
171        total[u] = total[node]
172        size[node] -= size[left[u]] + self.bit_len[u]
173        total[node] -= total[left[u]] + AVLTreeBitVector._popcount(self.key[u])
174        left[node] = right[u]
175        right[u] = node
176        if balance[u] == 1:
177            balance[u] = 0
178            balance[node] = 0
179        else:
180            balance[u] = -1
181            balance[node] = 1
182        return u
183
184    def _rotate_R(self, node: int) -> int:
185        left, right, size, balance, total = (
186            self.left,
187            self.right,
188            self.size,
189            self.balance,
190            self.total,
191        )
192        u = right[node]
193        size[u] = size[node]
194        total[u] = total[node]
195        size[node] -= size[right[u]] + self.bit_len[u]
196        total[node] -= total[right[u]] + AVLTreeBitVector._popcount(self.key[u])
197        right[node] = left[u]
198        left[u] = node
199        if balance[u] == -1:
200            balance[u] = 0
201            balance[node] = 0
202        else:
203            balance[u] = 1
204            balance[node] = -1
205        return u
206
207    def _update_balance(self, node: int) -> None:
208        balance = self.balance
209        if balance[node] == 1:
210            balance[self.right[node]] = -1
211            balance[self.left[node]] = 0
212        elif balance[node] == -1:
213            balance[self.right[node]] = 0
214            balance[self.left[node]] = 1
215        else:
216            balance[self.right[node]] = 0
217            balance[self.left[node]] = 0
218        balance[node] = 0
219
220    def _rotate_LR(self, node: int) -> int:
221        left, right, size, total = self.left, self.right, self.size, self.total
222        B = left[node]
223        E = right[B]
224        size[E] = size[node]
225        size[node] -= size[B] - size[right[E]]
226        size[B] -= size[right[E]] + self.bit_len[E]
227        total[E] = total[node]
228        total[node] -= total[B] - total[right[E]]
229        total[B] -= total[right[E]] + AVLTreeBitVector._popcount(self.key[E])
230        right[B] = left[E]
231        left[E] = B
232        left[node] = right[E]
233        right[E] = node
234        self._update_balance(E)
235        return E
236
237    def _rotate_RL(self, node: int) -> int:
238        left, right, size, total = self.left, self.right, self.size, self.total
239        C = right[node]
240        D = left[C]
241        size[D] = size[node]
242        size[node] -= size[C] - size[left[D]]
243        size[C] -= size[left[D]] + self.bit_len[D]
244        total[D] = total[node]
245        total[node] -= total[C] - total[left[D]]
246        total[C] -= total[left[D]] + AVLTreeBitVector._popcount(self.key[D])
247        left[C] = right[D]
248        right[D] = C
249        right[node] = left[D]
250        left[D] = node
251        self._update_balance(D)
252        return D
253
254    def _pref(self, r: int) -> int:
255        left, right, bit_len, size, key, total = (
256            self.left,
257            self.right,
258            self.bit_len,
259            self.size,
260            self.key,
261            self.total,
262        )
263        node = self.root
264        s = 0
265        while r > 0:
266            t = size[left[node]] + bit_len[node]
267            if t - bit_len[node] < r <= t:
268                r -= size[left[node]]
269                s += total[left[node]] + AVLTreeBitVector._popcount(
270                    key[node] >> (bit_len[node] - r)
271                )
272                break
273            if t > r:
274                node = left[node]
275            else:
276                s += total[left[node]] + AVLTreeBitVector._popcount(key[node])
277                node = right[node]
278                r -= t
279        return s
280
281    def _make_node(self, key: int, bit_len: int) -> int:
282        end = self.end
283        if end >= len(self.key):
284            self.key.append(key)
285            self.bit_len.append(bit_len)
286            self.size.append(bit_len)
287            self.total.append(AVLTreeBitVector._popcount(key))
288            self.left.append(0)
289            self.right.append(0)
290            self.balance.append(0)
291        else:
292            self.key[end] = key
293            self.bit_len[end] = bit_len
294            self.size[end] = bit_len
295            self.total[end] = AVLTreeBitVector._popcount(key)
296        self.end += 1
297        return end
298
299    def insert(self, k: int, key: int) -> None:
300        """``k`` 番目に ``v`` を挿入します。
301        :math:`O(\\log{n})` です。
302
303        Args:
304          k (int): 挿入位置のインデックスです。
305          key (int): 挿入する値です。 ``0`` または ``1`` である必要があります。
306        """
307        if self.root == 0:
308            self.root = self._make_node(key, 1)
309            return
310        left, right, size, bit_len, balance, keys, total = (
311            self.left,
312            self.right,
313            self.size,
314            self.bit_len,
315            self.balance,
316            self.key,
317            self.total,
318        )
319        node = self.root
320        path = []
321        d = 0
322        while node:
323            t = size[left[node]] + bit_len[node]
324            if t - bit_len[node] <= k <= t:
325                break
326            d <<= 1
327            size[node] += 1
328            total[node] += key
329            path.append(node)
330            node = left[node] if t > k else right[node]
331            if t > k:
332                d |= 1
333            else:
334                k -= t
335        k -= size[left[node]]
336        if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
337            v = keys[node]
338            bl = bit_len[node] - k
339            keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
340            bit_len[node] += 1
341            size[node] += 1
342            total[node] += key
343            return
344        path.append(node)
345        size[node] += 1
346        total[node] += key
347        v = keys[node]
348        bl = titan_pylib_AVLTreeBitVector_W - k
349        v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
350        left_key = v >> titan_pylib_AVLTreeBitVector_W
351        left_key_popcount = left_key & 1
352        keys[node] = v & ((1 << titan_pylib_AVLTreeBitVector_W) - 1)
353        node = left[node]
354        d <<= 1
355        d |= 1
356        if not node:
357            if bit_len[path[-1]] < titan_pylib_AVLTreeBitVector_W:
358                bit_len[path[-1]] += 1
359                keys[path[-1]] = (keys[path[-1]] << 1) | left_key
360                return
361            else:
362                left[path[-1]] = self._make_node(left_key, 1)
363        else:
364            path.append(node)
365            size[node] += 1
366            total[node] += left_key_popcount
367            d <<= 1
368            while right[node]:
369                node = right[node]
370                path.append(node)
371                size[node] += 1
372                total[node] += left_key_popcount
373                d <<= 1
374            if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
375                bit_len[node] += 1
376                keys[node] = (keys[node] << 1) | left_key
377                return
378            else:
379                right[node] = self._make_node(left_key, 1)
380        new_node = 0
381        while path:
382            node = path.pop()
383            balance[node] += 1 if d & 1 else -1
384            d >>= 1
385            if balance[node] == 0:
386                break
387            if balance[node] == 2:
388                new_node = (
389                    self._rotate_LR(node)
390                    if balance[left[node]] == -1
391                    else self._rotate_L(node)
392                )
393                break
394            elif balance[node] == -2:
395                new_node = (
396                    self._rotate_RL(node)
397                    if balance[right[node]] == 1
398                    else self._rotate_R(node)
399                )
400                break
401        if new_node:
402            if path:
403                if d & 1:
404                    left[path[-1]] = new_node
405                else:
406                    right[path[-1]] = new_node
407            else:
408                self.root = new_node
409
410    def _pop_under(self, path: list[int], d: int, node: int, res: int) -> None:
411        left, right, size, bit_len, balance, keys, total = (
412            self.left,
413            self.right,
414            self.size,
415            self.bit_len,
416            self.balance,
417            self.key,
418            self.total,
419        )
420        fd, lmax_total, lmax_bit_len = 0, 0, 0
421        if left[node] and right[node]:
422            path.append(node)
423            d <<= 1
424            d |= 1
425            lmax = left[node]
426            while right[lmax]:
427                path.append(lmax)
428                d <<= 1
429                fd <<= 1
430                fd |= 1
431                lmax = right[lmax]
432            lmax_total = AVLTreeBitVector._popcount(keys[lmax])
433            lmax_bit_len = bit_len[lmax]
434            keys[node] = keys[lmax]
435            bit_len[node] = lmax_bit_len
436            node = lmax
437        cnode = right[node] if left[node] == 0 else left[node]
438        if path:
439            if d & 1:
440                left[path[-1]] = cnode
441            else:
442                right[path[-1]] = cnode
443        else:
444            self.root = cnode
445            return
446        while path:
447            new_node = 0
448            node = path.pop()
449            balance[node] -= 1 if d & 1 else -1
450            size[node] -= lmax_bit_len if fd & 1 else 1
451            total[node] -= lmax_total if fd & 1 else res
452            d >>= 1
453            fd >>= 1
454            if balance[node] == 2:
455                new_node = (
456                    self._rotate_LR(node)
457                    if balance[left[node]] < 0
458                    else self._rotate_L(node)
459                )
460            elif balance[node] == -2:
461                new_node = (
462                    self._rotate_RL(node)
463                    if balance[right[node]] > 0
464                    else self._rotate_R(node)
465                )
466            elif balance[node] != 0:
467                break
468            if new_node:
469                if not path:
470                    self.root = new_node
471                    return
472                if d & 1:
473                    left[path[-1]] = new_node
474                else:
475                    right[path[-1]] = new_node
476                if balance[new_node] != 0:
477                    break
478        while path:
479            node = path.pop()
480            size[node] -= lmax_bit_len if fd & 1 else 1
481            total[node] -= lmax_total if fd & 1 else res
482            fd >>= 1
483
484    def pop(self, k: int) -> int:
485        """``k`` 番目の要素を削除し、その値を返します。
486        :math:`O(\\log{n})` です。
487
488        Args:
489          k (int): 削除位置のインデックスです。
490        """
491        assert 0 <= k < len(self)
492        left, right, size = self.left, self.right, self.size
493        bit_len, keys, total = self.bit_len, self.key, self.total
494        node = self.root
495        d = 0
496        path = []
497        while node:
498            t = size[left[node]] + bit_len[node]
499            if t - bit_len[node] <= k < t:
500                break
501            path.append(node)
502            node = left[node] if t > k else right[node]
503            d <<= 1
504            if t > k:
505                d |= 1
506            else:
507                k -= t
508        k -= size[left[node]]
509        v = keys[node]
510        res = v >> (bit_len[node] - k - 1) & 1
511        if bit_len[node] == 1:
512            self._pop_under(path, d, node, res)
513            return res
514        keys[node] = ((v >> (bit_len[node] - k)) << ((bit_len[node] - k - 1))) | (
515            v & ((1 << (bit_len[node] - k - 1)) - 1)
516        )
517        bit_len[node] -= 1
518        size[node] -= 1
519        total[node] -= res
520        for p in path:
521            size[p] -= 1
522            total[p] -= res
523        return res
524
525    def set(self, k: int, v: int) -> None:
526        """``k`` 番目の値を ``v`` に更新します。
527        :math:`O(\\log{n})` です。
528
529        Args:
530          k (int): 更新位置のインデックスです。
531          key (int): 更新する値です。 ``0`` または ``1`` である必要があります。
532        """
533        self.__setitem__(k, v)
534
535    def tolist(self) -> list[int]:
536        """リストにして返します。
537        :math:`O(n)` です。
538        """
539        left, right, key, bit_len = self.left, self.right, self.key, self.bit_len
540        a = []
541        if not self.root:
542            return a
543
544        def rec(node):
545            if left[node]:
546                rec(left[node])
547            for i in range(bit_len[node] - 1, -1, -1):
548                a.append(key[node] >> i & 1)
549            if right[node]:
550                rec(right[node])
551
552        rec(self.root)
553        return a
554
555    def _debug_acc(self) -> None:
556        """デバッグ用のメソッドです。
557        key,totalをチェックします。
558        """
559        left, right = self.left, self.right
560        key = self.key
561
562        def rec(node):
563            acc = self._popcount(key[node])
564            if left[node]:
565                acc += rec(left[node])
566            if right[node]:
567                acc += rec(right[node])
568            if acc != self.total[node]:
569                # self.debug()
570                assert False, "acc Error"
571            return acc
572
573        rec(self.root)
574        print("debug_acc ok.")
575
576    def access(self, k: int) -> int:
577        """``k`` 番目の値を返します。
578        :math:`O(\\log{n})` です。
579
580        Args:
581          k (int): 取得位置のインデックスです。
582        """
583        return self.__getitem__(k)
584
585    def rank0(self, r: int) -> int:
586        """``a[0, r)`` に含まれる ``0`` の個数を返します。
587        :math:`O(\\log{n})` です。
588        """
589        return r - self._pref(r)
590
591    def rank1(self, r: int) -> int:
592        """``a[0, r)`` に含まれる ``1`` の個数を返します。
593        :math:`O(\\log{n})` です。
594        """
595        return self._pref(r)
596
597    def rank(self, r: int, v: int) -> int:
598        """``a[0, r)`` に含まれる ``v`` の個数を返します。
599        :math:`O(\\log{n})` です。
600        """
601        return self.rank1(r) if v else self.rank0(r)
602
603    def select0(self, k: int) -> int:
604        """``k`` 番目の ``0`` のインデックスを返します。
605        :math:`O(\\log{n}^2)` です。
606        """
607        if k < 0 or self.rank0(len(self)) <= k:
608            return -1
609        l, r = 0, len(self)
610        while r - l > 1:
611            m = (l + r) >> 1
612            if m - self._pref(m) > k:
613                r = m
614            else:
615                l = m
616        return l
617
618    def select1(self, k: int) -> int:
619        """``k`` 番目の ``1`` のインデックスを返します。
620        :math:`O(\\log{n}^2)` です。
621        """
622        if k < 0 or self.rank1(len(self)) <= k:
623            return -1
624        l, r = 0, len(self)
625        while r - l > 1:
626            m = (l + r) >> 1
627            if self._pref(m) > k:
628                r = m
629            else:
630                l = m
631        return l
632
633    def select(self, k: int, v: int) -> int:
634        """``k`` 番目の ``v`` のインデックスを返します。
635        :math:`O(\\log{n}^2)` です。
636        """
637        return self.select1(k) if v else self.select0(k)
638
639    def _insert_and_rank1(self, k: int, key: int) -> int:
640        if self.root == 0:
641            self.root = self._make_node(key, 1)
642            return 0
643        left, right, size, bit_len, balance, keys, total = (
644            self.left,
645            self.right,
646            self.size,
647            self.bit_len,
648            self.balance,
649            self.key,
650            self.total,
651        )
652        node = self.root
653        s = 0
654        path = []
655        d = 0
656        while node:
657            t = size[left[node]] + bit_len[node]
658            if t - bit_len[node] <= k <= t:
659                break
660            if t <= k:
661                s += total[left[node]] + AVLTreeBitVector._popcount(keys[node])
662            d <<= 1
663            size[node] += 1
664            total[node] += key
665            path.append(node)
666            node = left[node] if t > k else right[node]
667            if t > k:
668                d |= 1
669            else:
670                k -= t
671        k -= size[left[node]]
672        s += total[left[node]] + AVLTreeBitVector._popcount(
673            keys[node] >> (bit_len[node] - k)
674        )
675        if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
676            v = keys[node]
677            bl = bit_len[node] - k
678            keys[node] = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
679            bit_len[node] += 1
680            size[node] += 1
681            total[node] += key
682            return s
683        path.append(node)
684        size[node] += 1
685        total[node] += key
686        v = keys[node]
687        bl = titan_pylib_AVLTreeBitVector_W - k
688        v = (((v >> bl) << 1 | key) << bl) | (v & ((1 << bl) - 1))
689        left_key = v >> titan_pylib_AVLTreeBitVector_W
690        left_key_popcount = left_key & 1
691        keys[node] = v & ((1 << titan_pylib_AVLTreeBitVector_W) - 1)
692        node = left[node]
693        d <<= 1
694        d |= 1
695        if not node:
696            if bit_len[path[-1]] < titan_pylib_AVLTreeBitVector_W:
697                bit_len[path[-1]] += 1
698                keys[path[-1]] = (keys[path[-1]] << 1) | left_key
699                return s
700            else:
701                left[path[-1]] = self._make_node(left_key, 1)
702        else:
703            path.append(node)
704            size[node] += 1
705            total[node] += left_key_popcount
706            d <<= 1
707            while right[node]:
708                node = right[node]
709                path.append(node)
710                size[node] += 1
711                total[node] += left_key_popcount
712                d <<= 1
713            if bit_len[node] < titan_pylib_AVLTreeBitVector_W:
714                bit_len[node] += 1
715                keys[node] = (keys[node] << 1) | left_key
716                return s
717            else:
718                right[node] = self._make_node(left_key, 1)
719        new_node = 0
720        while path:
721            node = path.pop()
722            balance[node] += 1 if d & 1 else -1
723            d >>= 1
724            if balance[node] == 0:
725                break
726            if balance[node] == 2:
727                new_node = (
728                    self._rotate_LR(node)
729                    if balance[left[node]] == -1
730                    else self._rotate_L(node)
731                )
732                break
733            elif balance[node] == -2:
734                new_node = (
735                    self._rotate_RL(node)
736                    if balance[right[node]] == 1
737                    else self._rotate_R(node)
738                )
739                break
740        if new_node:
741            if path:
742                if d & 1:
743                    left[path[-1]] = new_node
744                else:
745                    right[path[-1]] = new_node
746            else:
747                self.root = new_node
748        return s
749
750    def _access_pop_and_rank1(self, k: int) -> int:
751        assert 0 <= k < len(self)
752        left, right, size = self.left, self.right, self.size
753        bit_len, keys, total = self.bit_len, self.key, self.total
754        s = 0
755        node = self.root
756        d = 0
757        path = []
758        while node:
759            t = size[left[node]] + bit_len[node]
760            if t - bit_len[node] <= k < t:
761                break
762            if t <= k:
763                s += total[left[node]] + AVLTreeBitVector._popcount(keys[node])
764            path.append(node)
765            node = left[node] if t > k else right[node]
766            d <<= 1
767            if t > k:
768                d |= 1
769            else:
770                k -= t
771        k -= size[left[node]]
772        s += total[left[node]] + AVLTreeBitVector._popcount(
773            keys[node] >> (bit_len[node] - k)
774        )
775        v = keys[node]
776        res = v >> (bit_len[node] - k - 1) & 1
777        if bit_len[node] == 1:
778            self._pop_under(path, d, node, res)
779            return s << 1 | res
780        keys[node] = ((v >> (bit_len[node] - k)) << ((bit_len[node] - k - 1))) | (
781            v & ((1 << (bit_len[node] - k - 1)) - 1)
782        )
783        bit_len[node] -= 1
784        size[node] -= 1
785        total[node] -= res
786        for p in path:
787            size[p] -= 1
788            total[p] -= res
789        return s << 1 | res
790
791    def __getitem__(self, k: int) -> int:
792        """``k`` 番目の要素を返します。
793        :math:`O(\\log{n})` です。
794        """
795        assert 0 <= k < len(self)
796        left, right, bit_len, size, key = (
797            self.left,
798            self.right,
799            self.bit_len,
800            self.size,
801            self.key,
802        )
803        node = self.root
804        while True:
805            t = size[left[node]] + bit_len[node]
806            if t - bit_len[node] <= k < t:
807                k -= size[left[node]]
808                return key[node] >> (bit_len[node] - k - 1) & 1
809            if t > k:
810                node = left[node]
811            else:
812                node = right[node]
813                k -= t
814
815    def __setitem__(self, k: int, v: int) -> None:
816        """``k`` 番目の要素を ``v`` に更新します。
817        :math:`O(\\log{n})` です。
818        """
819        left, right, bit_len, size, key, total = (
820            self.left,
821            self.right,
822            self.bit_len,
823            self.size,
824            self.key,
825            self.total,
826        )
827        assert v == 0 or v == 1, "ValueError"
828        node = self.root
829        path = []
830        while True:
831            t = size[left[node]] + bit_len[node]
832            path.append(node)
833            if t - bit_len[node] <= k < t:
834                k -= size[left[node]]
835                if v:
836                    key[node] |= 1 << k
837                else:
838                    key[node] &= ~(1 << k)
839                break
840            elif t > k:
841                node = left[node]
842            else:
843                node = right[node]
844                k -= t
845        while path:
846            node = path.pop()
847            total[node] = (
848                AVLTreeBitVector._popcount(key[node])
849                + total[left[node]]
850                + total[right[node]]
851            )
852
853    def __str__(self):
854        return str(self.tolist())
855
856    def __len__(self):
857        return self.size[self.root]
858
859    def __repr__(self):
860        return f"{self.__class__.__name__}({self})"

仕様

class AVLTreeBitVector(a: Iterable[int] = [])[source]

Bases: BitVectorInterface

AVL木で書かれたビットベクトルです。簡潔でもなんでもありません。

bit列を管理するわけですが、各節点は 1~32 bit を持つようにしています。 これにより、最大 32 倍高速化が行えます。(16~32bitとするといいんだろうけど)

__getitem__(k: int) int[source]

k 番目の要素を返します。 \(O(\log{n})\) です。

__setitem__(k: int, v: int) None[source]

k 番目の要素を v に更新します。 \(O(\log{n})\) です。

access(k: int) int[source]

k 番目の値を返します。 \(O(\log{n})\) です。

Parameters:

k (int) – 取得位置のインデックスです。

insert(k: int, key: int) None[source]

k 番目に v を挿入します。 \(O(\log{n})\) です。

Parameters:
  • k (int) – 挿入位置のインデックスです。

  • key (int) – 挿入する値です。 0 または 1 である必要があります。

pop(k: int) int[source]

k 番目の要素を削除し、その値を返します。 \(O(\log{n})\) です。

Parameters:

k (int) – 削除位置のインデックスです。

rank(r: int, v: int) int[source]

a[0, r) に含まれる v の個数を返します。 \(O(\log{n})\) です。

rank0(r: int) int[source]

a[0, r) に含まれる 0 の個数を返します。 \(O(\log{n})\) です。

rank1(r: int) int[source]

a[0, r) に含まれる 1 の個数を返します。 \(O(\log{n})\) です。

reserve(n: int) None[source]

n 要素分のメモリを確保します。 \(O(n)\) です。

select(k: int, v: int) int[source]

k 番目の v のインデックスを返します。 \(O(\log{n}^2)\) です。

select0(k: int) int[source]

k 番目の 0 のインデックスを返します。 \(O(\log{n}^2)\) です。

select1(k: int) int[source]

k 番目の 1 のインデックスを返します。 \(O(\log{n}^2)\) です。

set(k: int, v: int) None[source]

k 番目の値を v に更新します。 \(O(\log{n})\) です。

Parameters:
  • k (int) – 更新位置のインデックスです。

  • key (int) – 更新する値です。 0 または 1 である必要があります。

tolist() list[int][source]

リストにして返します。 \(O(n)\) です。