Source code for titan_pylib.data_structures.bit_vector.bit_vector
1from titan_pylib.data_structures.bit_vector.bit_vector_interface import (
2 BitVectorInterface,
3)
4from array import array
5
6
[docs]
7class BitVector(BitVectorInterface):
8 """コンパクトな bit vector です。"""
9
10 def __init__(self, n: int):
11 """長さ ``n`` の ``BitVector`` です。
12
13 bit を保持するのに ``array[I]`` を使用します。
14 ``block_size= n / 32`` として、使用bitは ``32*block_size=2n bit`` です。
15
16 累積和を保持するのに同様の ``array[I]`` を使用します。
17 32bitごとの和を保存しています。同様に使用bitは ``2n bit`` です。
18 """
19 assert 0 <= n < 4294967295
20 self.N = n
21 self.block_size = (n + 31) >> 5
22 b = bytes(4 * (self.block_size + 1))
23 self.bit = array("I", b)
24 self.acc = array("I", b)
25
26 @staticmethod
27 def _popcount(x: int) -> int:
28 x = x - ((x >> 1) & 0x55555555)
29 x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
30 x = x + (x >> 4) & 0x0F0F0F0F
31 x += x >> 8
32 x += x >> 16
33 return x & 0x0000007F
34
[docs]
35 def set(self, k: int) -> None:
36 """``k`` 番目の bit を ``1`` にします。
37 :math:`O(1)` です。
38
39 Args:
40 k (int): インデックスです。
41 """
42 self.bit[k >> 5] |= 1 << (k & 31)
43
[docs]
44 def build(self) -> None:
45 """構築します。
46 **これ以降 ``set`` メソッドを使用してはいけません。**
47 :math:`O(n)` です。
48 """
49 acc, bit = self.acc, self.bit
50 for i in range(self.block_size):
51 acc[i + 1] = acc[i] + BitVector._popcount(bit[i])
52
[docs]
53 def access(self, k: int) -> int:
54 """``k`` 番目の bit を返します。
55 :math:`O(1)` です。
56 """
57 return (self.bit[k >> 5] >> (k & 31)) & 1
58
59 def __getitem__(self, k: int) -> int:
60 return (self.bit[k >> 5] >> (k & 31)) & 1
61
[docs]
62 def rank0(self, r: int) -> int:
63 """``a[0, r)`` に含まれる ``0`` の個数を返します。
64 :math:`O(1)` です。
65 """
66 return r - (
67 self.acc[r >> 5]
68 + BitVector._popcount(self.bit[r >> 5] & ((1 << (r & 31)) - 1))
69 )
70
[docs]
71 def rank1(self, r: int) -> int:
72 """``a[0, r)`` に含まれる ``1`` の個数を返します。
73 :math:`O(1)` です。
74 """
75 return self.acc[r >> 5] + BitVector._popcount(
76 self.bit[r >> 5] & ((1 << (r & 31)) - 1)
77 )
78
[docs]
79 def rank(self, r: int, v: int) -> int:
80 """``a[0, r)`` に含まれる ``v`` の個数を返します。
81 :math:`O(1)` です。
82 """
83 return self.rank1(r) if v else self.rank0(r)
84
[docs]
85 def select0(self, k: int) -> int:
86 """``k`` 番目の ``0`` のインデックスを返します。
87 :math:`O(\\log{n})` です。
88 """
89 if k < 0 or self.rank0(self.N) <= k:
90 return -1
91 l, r = 0, self.block_size + 1
92 while r - l > 1:
93 m = (l + r) >> 1
94 if m * 32 - self.acc[m] > k:
95 r = m
96 else:
97 l = m
98 indx = 32 * l
99 k = k - (l * 32 - self.acc[l]) + self.rank0(indx)
100 l, r = indx, indx + 32
101 while r - l > 1:
102 m = (l + r) >> 1
103 if self.rank0(m) > k:
104 r = m
105 else:
106 l = m
107 return l
108
[docs]
109 def select1(self, k: int) -> int:
110 """``k`` 番目の ``1`` のインデックスを返します。
111 :math:`O(\\log{n})` です。
112 """
113 if k < 0 or self.rank1(self.N) <= k:
114 return -1
115 l, r = 0, self.block_size + 1
116 while r - l > 1:
117 m = (l + r) >> 1
118 if self.acc[m] > k:
119 r = m
120 else:
121 l = m
122 indx = 32 * l
123 k = k - self.acc[l] + self.rank1(indx)
124 l, r = indx, indx + 32
125 while r - l > 1:
126 m = (l + r) >> 1
127 if self.rank1(m) > k:
128 r = m
129 else:
130 l = m
131 return l
132
[docs]
133 def select(self, k: int, v: int) -> int:
134 """``k`` 番目の ``v`` のインデックスを返します。
135 :math:`O(\\log{n})` です。
136 """
137 return self.select1(k) if v else self.select0(k)
138
139 def __len__(self):
140 return self.N
141
142 def __str__(self):
143 return str([self.access(i) for i in range(self.N)])
144
145 def __repr__(self):
146 return f"{self.__class__.__name__}({self})"