bit_vector

ソースコード

from titan_pylib.data_structures.bit_vector.bit_vector import BitVector

view on github

展開済みコード

  1# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
  2# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
  3#     BitVectorInterface,
  4# )
  5from abc import ABC, abstractmethod
  6
  7
  8class BitVectorInterface(ABC):
  9
 10    @abstractmethod
 11    def access(self, k: int) -> int:
 12        raise NotImplementedError
 13
 14    @abstractmethod
 15    def __getitem__(self, k: int) -> int:
 16        raise NotImplementedError
 17
 18    @abstractmethod
 19    def rank0(self, r: int) -> int:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def rank1(self, r: int) -> int:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def rank(self, r: int, v: int) -> int:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def select0(self, k: int) -> int:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def select1(self, k: int) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def select(self, k: int, v: int) -> int:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def __len__(self) -> int:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def __str__(self) -> str:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def __repr__(self) -> str:
 52        raise NotImplementedError
 53from array import array
 54
 55
 56class BitVector(BitVectorInterface):
 57    """コンパクトな bit vector です。"""
 58
 59    def __init__(self, n: int):
 60        """長さ ``n`` の ``BitVector`` です。
 61
 62        bit を保持するのに ``array[I]`` を使用します。
 63        ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
 64
 65        累積和を保持するのに同様の ``array[I]`` を使用します。
 66        32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
 67        """
 68        assert 0 <= n < 4294967295
 69        self.N = n
 70        self.block_size = (n + 31) >> 5
 71        b = bytes(4 * (self.block_size + 1))
 72        self.bit = array("I", b)
 73        self.acc = array("I", b)
 74
 75    @staticmethod
 76    def _popcount(x: int) -> int:
 77        x = x - ((x >> 1) & 0x55555555)
 78        x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
 79        x = x + (x >> 4) & 0x0F0F0F0F
 80        x += x >> 8
 81        x += x >> 16
 82        return x & 0x0000007F
 83
 84    def set(self, k: int) -> None:
 85        """``k`` 番目の bit を ``1`` にします。
 86        :math:`O(1)` です。
 87
 88        Args:
 89          k (int): インデックスです。
 90        """
 91        self.bit[k >> 5] |= 1 << (k & 31)
 92
 93    def build(self) -> None:
 94        """構築します。
 95        **これ以降 ``set`` メソッドを使用してはいけません。**
 96        :math:`O(n)` です。
 97        """
 98        acc, bit = self.acc, self.bit
 99        for i in range(self.block_size):
100            acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
101
102    def access(self, k: int) -> int:
103        """``k`` 番目の bit を返します。
104        :math:`O(1)` です。
105        """
106        return (self.bit[k >> 5] >> (k & 31)) & 1
107
108    def __getitem__(self, k: int) -> int:
109        return (self.bit[k >> 5] >> (k & 31)) & 1
110
111    def rank0(self, r: int) -> int:
112        """``a[0, r)`` に含まれる ``0`` の個数を返します。
113        :math:`O(1)` です。
114        """
115        return r - (
116            self.acc[r >> 5]
117            + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
118        )
119
120    def rank1(self, r: int) -> int:
121        """``a[0, r)`` に含まれる ``1`` の個数を返します。
122        :math:`O(1)` です。
123        """
124        return self.acc[r >> 5] + BitVector._popcount(
125            self.bit[r >> 5] & ((1 << (r & 31)) - 1)
126        )
127
128    def rank(self, r: int, v: int) -> int:
129        """``a[0, r)`` に含まれる ``v`` の個数を返します。
130        :math:`O(1)` です。
131        """
132        return self.rank1(r) if v else self.rank0(r)
133
134    def select0(self, k: int) -> int:
135        """``k`` 番目の ``0`` のインデックスを返します。
136        :math:`O(\\log{n})` です。
137        """
138        if k < 0 or self.rank0(self.N) <= k:
139            return -1
140        l, r = 0, self.block_size + 1
141        while r - l > 1:
142            m = (l + r) >> 1
143            if m * 32 - self.acc[m] > k:
144                r = m
145            else:
146                l = m
147        indx = 32 * l
148        k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
149        l, r = indx, indx + 32
150        while r - l > 1:
151            m = (l + r) >> 1
152            if self.rank0(m) > k:
153                r = m
154            else:
155                l = m
156        return l
157
158    def select1(self, k: int) -> int:
159        """``k`` 番目の ``1`` のインデックスを返します。
160        :math:`O(\\log{n})` です。
161        """
162        if k < 0 or self.rank1(self.N) <= k:
163            return -1
164        l, r = 0, self.block_size + 1
165        while r - l > 1:
166            m = (l + r) >> 1
167            if self.acc[m] > k:
168                r = m
169            else:
170                l = m
171        indx = 32 * l
172        k = k - self.acc[l] + self.rank1(indx)
173        l, r = indx, indx + 32
174        while r - l > 1:
175            m = (l + r) >> 1
176            if self.rank1(m) > k:
177                r = m
178            else:
179                l = m
180        return l
181
182    def select(self, k: int, v: int) -> int:
183        """``k`` 番目の ``v`` のインデックスを返します。
184        :math:`O(\\log{n})` です。
185        """
186        return self.select1(k) if v else self.select0(k)
187
188    def __len__(self):
189        return self.N
190
191    def __str__(self):
192        return str([self.access(i) for i in range(self.N)])
193
194    def __repr__(self):
195        return f"{self.__class__.__name__}({self})"

仕様

class BitVector(n: int)[source]

Bases: BitVectorInterface

コンパクトな bit vector です。

access(k: int) int[source]

k 番目の bit を返します。 \(O(1)\) です。

build() None[source]

構築します。 これ以降 ``set`` メソッドを使用してはいけません。 \(O(n)\) です。

rank(r: int, v: int) int[source]

a[0, r) に含まれる v の個数を返します。 \(O(1)\) です。

rank0(r: int) int[source]

a[0, r) に含まれる 0 の個数を返します。 \(O(1)\) です。

rank1(r: int) int[source]

a[0, r) に含まれる 1 の個数を返します。 \(O(1)\) です。

select(k: int, v: int) int[source]

k 番目の v のインデックスを返します。 \(O(\log{n})\) です。

select0(k: int) int[source]

k 番目の 0 のインデックスを返します。 \(O(\log{n})\) です。

select1(k: int) int[source]

k 番目の 1 のインデックスを返します。 \(O(\log{n})\) です。

set(k: int) None[source]

k 番目の bit を 1 にします。 \(O(1)\) です。

Parameters:

k (int) – インデックスです。