bit_vector¶
ソースコード¶
from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
展開済みコード¶
1# from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
2# from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
3# BitVectorInterface,
4# )
5from abc import ABC, abstractmethod
6
7
8class BitVectorInterface(ABC):
9
10 @abstractmethod
11 def access(self, k: int) -> int:
12 raise NotImplementedError
13
14 @abstractmethod
15 def __getitem__(self, k: int) -> int:
16 raise NotImplementedError
17
18 @abstractmethod
19 def rank0(self, r: int) -> int:
20 raise NotImplementedError
21
22 @abstractmethod
23 def rank1(self, r: int) -> int:
24 raise NotImplementedError
25
26 @abstractmethod
27 def rank(self, r: int, v: int) -> int:
28 raise NotImplementedError
29
30 @abstractmethod
31 def select0(self, k: int) -> int:
32 raise NotImplementedError
33
34 @abstractmethod
35 def select1(self, k: int) -> int:
36 raise NotImplementedError
37
38 @abstractmethod
39 def select(self, k: int, v: int) -> int:
40 raise NotImplementedError
41
42 @abstractmethod
43 def __len__(self) -> int:
44 raise NotImplementedError
45
46 @abstractmethod
47 def __str__(self) -> str:
48 raise NotImplementedError
49
50 @abstractmethod
51 def __repr__(self) -> str:
52 raise NotImplementedError
53from array import array
54
55
56class BitVector(BitVectorInterface):
57 """コンパクトな bit vector です。"""
58
59 def __init__(self, n: int):
60 """長さ ``n`` の ``BitVector`` です。
61
62 bit を保持するのに ``array[I]`` を使用します。
63 ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
64
65 累積和を保持するのに同様の ``array[I]`` を使用します。
66 32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
67 """
68 assert 0 <= n < 4294967295
69 self.N = n
70 self.block_size = (n + 31) >> 5
71 b = bytes(4 * (self.block_size + 1))
72 self.bit = array("I", b)
73 self.acc = array("I", b)
74
75 @staticmethod
76 def _popcount(x: int) -> int:
77 x = x - ((x >> 1) & 0x55555555)
78 x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
79 x = x + (x >> 4) & 0x0F0F0F0F
80 x += x >> 8
81 x += x >> 16
82 return x & 0x0000007F
83
84 def set(self, k: int) -> None:
85 """``k`` 番目の bit を ``1`` にします。
86 :math:`O(1)` です。
87
88 Args:
89 k (int): インデックスです。
90 """
91 self.bit[k >> 5] |= 1 << (k & 31)
92
93 def build(self) -> None:
94 """構築します。
95 **これ以降 ``set`` メソッドを使用してはいけません。**
96 :math:`O(n)` です。
97 """
98 acc, bit = self.acc, self.bit
99 for i in range(self.block_size):
100 acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
101
102 def access(self, k: int) -> int:
103 """``k`` 番目の bit を返します。
104 :math:`O(1)` です。
105 """
106 return (self.bit[k >> 5] >> (k & 31)) & 1
107
108 def __getitem__(self, k: int) -> int:
109 return (self.bit[k >> 5] >> (k & 31)) & 1
110
111 def rank0(self, r: int) -> int:
112 """``a[0, r)`` に含まれる ``0`` の個数を返します。
113 :math:`O(1)` です。
114 """
115 return r - (
116 self.acc[r >> 5]
117 + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
118 )
119
120 def rank1(self, r: int) -> int:
121 """``a[0, r)`` に含まれる ``1`` の個数を返します。
122 :math:`O(1)` です。
123 """
124 return self.acc[r >> 5] + BitVector._popcount(
125 self.bit[r >> 5] & ((1 << (r & 31)) - 1)
126 )
127
128 def rank(self, r: int, v: int) -> int:
129 """``a[0, r)`` に含まれる ``v`` の個数を返します。
130 :math:`O(1)` です。
131 """
132 return self.rank1(r) if v else self.rank0(r)
133
134 def select0(self, k: int) -> int:
135 """``k`` 番目の ``0`` のインデックスを返します。
136 :math:`O(\\log{n})` です。
137 """
138 if k < 0 or self.rank0(self.N) <= k:
139 return -1
140 l, r = 0, self.block_size + 1
141 while r - l > 1:
142 m = (l + r) >> 1
143 if m * 32 - self.acc[m] > k:
144 r = m
145 else:
146 l = m
147 indx = 32 * l
148 k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
149 l, r = indx, indx + 32
150 while r - l > 1:
151 m = (l + r) >> 1
152 if self.rank0(m) > k:
153 r = m
154 else:
155 l = m
156 return l
157
158 def select1(self, k: int) -> int:
159 """``k`` 番目の ``1`` のインデックスを返します。
160 :math:`O(\\log{n})` です。
161 """
162 if k < 0 or self.rank1(self.N) <= k:
163 return -1
164 l, r = 0, self.block_size + 1
165 while r - l > 1:
166 m = (l + r) >> 1
167 if self.acc[m] > k:
168 r = m
169 else:
170 l = m
171 indx = 32 * l
172 k = k - self.acc[l] + self.rank1(indx)
173 l, r = indx, indx + 32
174 while r - l > 1:
175 m = (l + r) >> 1
176 if self.rank1(m) > k:
177 r = m
178 else:
179 l = m
180 return l
181
182 def select(self, k: int, v: int) -> int:
183 """``k`` 番目の ``v`` のインデックスを返します。
184 :math:`O(\\log{n})` です。
185 """
186 return self.select1(k) if v else self.select0(k)
187
188 def __len__(self):
189 return self.N
190
191 def __str__(self):
192 return str([self.access(i) for i in range(self.N)])
193
194 def __repr__(self):
195 return f"{self.__class__.__name__}({self})"
仕様¶
- class BitVector(n: int)[source]¶
Bases:
BitVectorInterface
コンパクトな bit vector です。