dynamic bit vector

ソースコード

  1#include <iostream>
  2#include <vector>
  3#include <cassert>
  4#include <cmath>
  5using namespace std;
  6
  7// DynamicBitVector
  8namespace titan23 {
  9
 10    class DynamicBitVector {
 11      private:
 12        static const int BUCKET_MAX = 1000;
 13        vector<vector<uint8_t>> data;
 14        vector<int> bucket_data;
 15        int _size;
 16        int tot_one;
 17
 18        void build(const vector<uint8_t> &a) {
 19            long long s = len();
 20            int bucket_size = max((int)(s+BUCKET_MAX-1)/BUCKET_MAX, (int)ceil(sqrt(s)));
 21            data.resize(bucket_size);
 22            bucket_data.resize(bucket_size);
 23            for (int i = 0; i < bucket_size; ++i) {
 24                int start = s*i/bucket_size;
 25                int stop = min((int)len(), (int)(s*(i+1)/bucket_size));
 26                vector<uint8_t> d(a.begin()+start, a.begin()+stop);
 27                int sum = 0;
 28                for (const uint8_t &x: d) sum += x;
 29                data[i] = d;
 30                tot_one += sum;
 31                bucket_data[i] = sum;
 32            }
 33        }
 34
 35        pair<int, int> get_bucket(int k) const {
 36            if (k == len()) return {-1, -1};
 37            if (k < len()/2) {
 38                for (int i = 0; i < data.size(); ++i) {
 39                    if (k < data[i].size()) return {i, k};
 40                    k -= data[i].size();
 41                }
 42            } else {
 43                int tot = len();
 44                for (int i = data.size()-1; i >= 0; --i) {
 45                    if (tot-data[i].size() <= k) {
 46                        return {i, k-(tot-data[i].size())};
 47                    }
 48                    tot -= data[i].size();
 49                }
 50            }
 51            assert(false);
 52        }
 53
 54      public:
 55        DynamicBitVector() : _size(0), tot_one(0) {}
 56            DynamicBitVector(const vector<uint8_t> &a) : _size(a.size()), tot_one(0) {
 57            build(a);
 58        }
 59
 60        void insert(int k, bool key) {
 61            assert(0 <= k && k <= len());
 62            if (data.empty()) {
 63                ++_size;
 64                tot_one += key;
 65                bucket_data.emplace_back(key);
 66                data.push_back({key});
 67                return;
 68            }
 69            auto [bucket_pos, bit_pos] = get_bucket(k);
 70            if (bucket_pos == -1) {
 71                bucket_pos = data.size()-1;
 72                bucket_data.back() += key;
 73                data.back().emplace_back(key);
 74            } else {
 75                bucket_data[bucket_pos] += key;
 76                data[bucket_pos].insert(data[bucket_pos].begin() + bit_pos, key);
 77            }
 78            if (data[bucket_pos].size() > BUCKET_MAX) {
 79                vector<uint8_t> right(data[bucket_pos].begin() + BUCKET_MAX/2, data[bucket_pos].end());
 80                data[bucket_pos].erase(data[bucket_pos].begin() + BUCKET_MAX/2, data[bucket_pos].end());
 81                data.emplace(data.begin() + bucket_pos+1, right);
 82                bucket_data.insert(bucket_data.begin() + bucket_pos, 0);
 83                bucket_data[bucket_pos] = 0;
 84                bucket_data[bucket_pos+1] = 0;
 85                for (const uint8_t x: data[bucket_pos]) bucket_data[bucket_pos] += x;
 86                for (const uint8_t x: data[bucket_pos+1]) bucket_data[bucket_pos+1] += x;
 87            }
 88            ++_size;
 89            tot_one += key;
 90        }
 91
 92        bool access(int k) const {
 93            assert(0 <= k && k < len());
 94            auto [bucket_pos, bit_pos] = get_bucket(k);
 95            return data[bucket_pos][bit_pos];
 96        }
 97
 98        bool pop(int k) {
 99            assert(0 <= k && k < len());
100            auto [bucket_pos, bit_pos] = get_bucket(k);
101            bool res = data[bucket_pos][bit_pos];
102            bucket_data[bucket_pos] -= res;
103            data[bucket_pos].erase(data[bucket_pos].begin() + bit_pos);
104            tot_one -= res;
105            --_size;
106            if (data[bucket_pos].empty()) {
107                data.erase(data.begin() + bucket_pos);
108                bucket_data.erase(bucket_data.begin() + bucket_pos);
109            }
110            return res;
111        }
112
113        void set(int k, bool v) {
114            assert(0 <= k && k < len());
115            auto [bucket_pos, bit_pos] = get_bucket(k);
116            data[bucket_pos][bit_pos] = v;
117        }
118
119        int rank0(int r) const {
120            assert(0 <= r && r <= len());
121            return r - rank1(r);
122        }
123
124        int rank1(int r) const {
125            assert(0 <= r && r <= len());
126            int s = 0;
127            for (int i = 0; i < data.size(); ++i) {
128                if (r < data[i].size()) {
129                    const vector<uint8_t> &d = data[i];
130                    for (int j = 0; j < r; ++j) {
131                        if (d[j]) ++s;
132                    }
133                    return s;
134                }
135                s += bucket_data[i];
136                r -= data[i].size();
137            }
138            return s;
139            assert(false);
140        }
141
142        int rank(int r, bool key) const {
143            assert(0 <= r && r <= len());
144            return key ? rank1(r) : rank0(r);
145        }
146
147        int select0(int k) const {
148            int s = 0;
149            for (int i = 0; i < data.size(); ++i) {
150                if (k < data[i].size() - bucket_data[i]) {
151                for (const uint8_t &x: data[i]) {
152                    if (!x) --k;
153                    if (k < 0) return s;
154                    s++;
155                }
156                assert(false);
157                }
158                s += data[i].size();
159                k -= data[i].size() - bucket_data[i];
160            }
161            assert(false);
162        }
163
164        int select1(int k) const {
165            int s = 0;
166            for (int i = 0; i < data.size(); ++i) {
167                if (k < bucket_data[i]) {
168                for (const uint8_t &x: data[i]) {
169                    if (x) --k;
170                    if (k < 0) return s;
171                    s++;
172                }
173                }
174                s += data[i].size();
175                k -= bucket_data[i];
176            }
177            assert(false);
178        }
179
180        int select(int k, bool key) const {
181            return key ? select1(k) : select0(k);
182        }
183
184        int _insert_and_rank1(int k, bool key) {
185            int s = 0;
186            int bucket_pos = -1, bit_pos = -1;
187            if (k < len()/2) {
188                for (int i = 0; i < data.size(); ++i) {
189                    if (k < data[i].size()) {
190                        bucket_pos = i;
191                        bit_pos = k;
192                        const vector<uint8_t> &d = data[i];
193                        for (int j = 0; j < k; ++j) {
194                            s += d[j];
195                        }
196                        break;
197                    }
198                    s += bucket_data[i];
199                    k -= data[i].size();
200                }
201            } else {
202                int tot = len();
203                s = tot_one;
204                for (int i = data.size()-1; i >= 0; --i) {
205                    if (tot-data[i].size() <= k) {
206                        bucket_pos = i;
207                        bit_pos = k-(tot-data[i].size());
208                        const vector<uint8_t> &d = data[i];
209                        for (int j = bit_pos; j < d.size(); ++j) {
210                            s -= d[j];
211                        }
212                        break;
213                    }
214                    tot -= data[i].size();
215                    s -= bucket_data[i];
216                }
217            }
218
219            {
220                ++_size;
221                tot_one += key;
222                if (data.empty()) {
223                    bucket_data.emplace_back(key);
224                    data.push_back({{key}});
225                    return s;
226                }
227                if (bucket_pos == -1) {
228                    bucket_pos = data.size()-1;
229                    bucket_data.back() += key;
230                    data.back().emplace_back(key);
231                } else {
232                    bucket_data[bucket_pos] += key;
233                    data[bucket_pos].insert(data[bucket_pos].begin() + bit_pos, key);
234                }
235                if (data[bucket_pos].size() > BUCKET_MAX) {
236                vector<uint8_t> right(data[bucket_pos].begin() + BUCKET_MAX/2, data[bucket_pos].end());
237                data[bucket_pos].erase(data[bucket_pos].begin() + BUCKET_MAX/2, data[bucket_pos].end());
238                data.emplace(data.begin() + bucket_pos+1, right);
239                bucket_data.insert(bucket_data.begin() + bucket_pos, 0);
240                bucket_data[bucket_pos] = 0;
241                bucket_data[bucket_pos+1] = 0;
242                for (const uint8_t &x: data[bucket_pos]) bucket_data[bucket_pos] += x;
243                for (const uint8_t &x: data[bucket_pos+1]) bucket_data[bucket_pos+1] += x;
244                }
245            }
246            return s;
247        }
248
249        int _access_pop_and_rank1(int k) {
250            int prek = k;
251            int s = 0;
252            int bucket_pos, bit_pos;
253            bool res;
254            for (int i = 0; i < data.size(); ++i) {
255                if (k < data[i].size()) {
256                    bucket_pos = i;
257                    bit_pos = k;
258                    res = data[bucket_pos][bit_pos];
259                    const vector<uint8_t> &d = data[i];
260                    for (int j = 0; j < k; ++j) {
261                        if (d[j]) ++s;
262                    }
263                    break;
264                }
265                s += bucket_data[i];
266                k -= data[i].size();
267            }
268            bucket_data[bucket_pos] -= res;
269            data[bucket_pos].erase(data[bucket_pos].begin() + bit_pos);
270            tot_one -= res;
271            --_size;
272            if (data[bucket_pos].empty()) {
273                data.erase(data.begin() + bucket_pos);
274                bucket_data.erase(bucket_data.begin() + bucket_pos);
275            }
276            return s << 1 | res;
277        }
278
279        pair<bool, int> _access_ans_rank1(int k) const {
280            assert(0 <= k && k < len());
281            int s = 0;
282            for (int i = 0; i < data.size(); ++i) {
283                if (k < data[i].size()) {
284                    const vector<uint8_t> &d = data[i];
285                    for (int j = 0; j < k; ++j) {
286                        s += d[j];
287                    }
288                    return {data[i][k], s};
289                }
290                s += bucket_data[i];
291                k -= data[i].size();
292            }
293            assert(false);
294        }
295
296        vector<uint8_t> tovector() const {
297            vector<uint8_t> a(len());
298            int ptr = 0;
299            for (const vector<uint8_t> &d: data) for (const uint8_t &x: d) {
300                a[ptr++] = x;
301            }
302            return a;
303        }
304
305        void print() const {
306            vector<uint8_t> a = tovector();
307            int n = (int)a.size();
308            assert(n == len());
309            cout << "[";
310            for (int i = 0; i < n-1; ++i) {
311                cout << a[i] << ", ";
312            }
313            if (n > 0) {
314                cout << a.back();
315            }
316            cout << "]";
317            cout << endl;
318        }
319
320        bool empty() const { return _size == 0; }
321        int len() const { return _size; }
322    };
323} // name space titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/dynamic_bit_vector.cpp”