avl tree bit vector

ソースコード

  1#include <iostream>
  2#include <vector>
  3#include <stack>
  4#include <cassert>
  5#include <tuple>
  6#include <nmmintrin.h>
  7#include <stdint.h>
  8using namespace std;
  9
 10// AVLTreeBitVector
 11namespace titan23 {
 12
 13class AVLTreeBitVector {
 14  private:
 15    using Node = int;
 16    // using uint64 = unsigned long long;
 17    // static constexpr const char _W = 63;
 18    using uint128 = __uint128_t;
 19    static constexpr const char _W = 127;
 20    Node _root, _end;
 21    vector<uint128> _key;
 22    vector<Node> _left, _right;
 23    vector<int> _size, _total;
 24    vector<char> _bit_len, _balance;
 25
 26    void _build(const vector<uint8_t> &a) {
 27        auto rec = [&] (auto &&rec, Node l, Node r) -> pair<Node, char> {
 28            Node mid = (l + r) >> 1;
 29            char hl = 0, hr = 0;
 30            if (l != mid) {
 31                tie(_left[mid], hl) = rec(rec, l, mid);
 32                _size[mid] += _size[_left[mid]];
 33                _total[mid] += _total[_left[mid]];
 34            }
 35            if (mid + 1 != r) {
 36                tie(_right[mid], hr) = rec(rec, mid+1, r);
 37                _size[mid] += _size[_right[mid]];
 38                _total[mid] += _total[_right[mid]];
 39            }
 40            _balance[mid] = hl - hr;
 41            return {mid, (max(hl, hr)+1)};
 42        };
 43
 44        const int n = a.size();
 45        reserve(n);
 46        Node pre_end = _end;
 47        int indx = _end;
 48        for (int i = 0; i < n; i += _W) {
 49            int j = 0;
 50            int pop = 0;
 51            uint128 v = 0;
 52            while (j < _W && i + j < n) {
 53                v <<= 1;
 54                if (a[i+j]) {
 55                    v |= a[i+j];
 56                    ++pop;
 57                }
 58                j++;
 59            }
 60            _key[indx] = v;
 61            _bit_len[indx] = j;
 62            _size[indx] = j;
 63            _total[indx] = pop;
 64            ++indx;
 65        }
 66        this->_end = indx;
 67        this->_root = rec(rec, pre_end, _end).first;
 68    }
 69
 70    int _popcount(const uint128 n) const {
 71        return __builtin_popcountll(n >> 64) + __builtin_popcountll(n);
 72        // return __builtin_popcountll(n);
 73    }
 74
 75    Node _rotate_L(Node node) {
 76        Node u = _left[node];
 77        _size[u] = _size[node];
 78        _total[u] = _total[node];
 79        _size[node] -= _size[_left[u]] + _bit_len[u];
 80        _total[node] -= _total[_left[u]] + _popcount(_key[u]);
 81        _left[node] = _right[u];
 82        _right[u] = node;
 83        if (_balance[u] == 1) {
 84            _balance[u] = 0;
 85            _balance[node] = 0;
 86        } else {
 87            _balance[u] = -1;
 88            _balance[node] = 1;
 89        }
 90        return u;
 91    }
 92
 93    Node _rotate_R(Node node) {
 94        Node u = _right[node];
 95        _size[u] = _size[node];
 96        _total[u] = _total[node];
 97        _size[node] -= _size[_right[u]] + _bit_len[u];
 98        _total[node] -= _total[_right[u]] + _popcount(_key[u]);
 99        _right[node] = _left[u];
100        _left[u] = node;
101        if (_balance[u] == -1) {
102            _balance[u] = 0;
103            _balance[node] = 0;
104        } else {
105            _balance[u] = 1;
106            _balance[node] = -1;
107        }
108        return u;
109    }
110
111    void _update_balance(Node node) {
112        if (_balance[node] == 1) {
113            _balance[_right[node]] = -1;
114            _balance[_left[node]] = 0;
115        } else if (_balance[node] == -1) {
116            _balance[_right[node]] = 0;
117            _balance[_left[node]] = 1;
118        } else {
119            _balance[_right[node]] = 0;
120            _balance[_left[node]] = 0;
121        }
122        _balance[node] = 0;
123    }
124
125    Node _rotate_LR(Node node) {
126        Node B = _left[node];
127        Node E = _right[B];
128        _size[E] = _size[node];
129        _size[node] -= _size[B] - _size[_right[E]];
130        _size[B] -= _size[_right[E]] + _bit_len[E];
131        _total[E] = _total[node];
132        _total[node] -= _total[B] - _total[_right[E]];
133        _total[B] -= _total[_right[E]] + _popcount(_key[E]);
134        _right[B] = _left[E];
135        _left[E] = B;
136        _left[node] = _right[E];
137        _right[E] = node;
138        _update_balance(E);
139        return E;
140    }
141
142    Node _rotate_RL(Node node) {
143        Node C = _right[node];
144        Node D = _left[C];
145        _size[D] = _size[node];
146        _size[node] -= _size[C] - _size[_left[D]];
147        _size[C] -= _size[_left[D]] + _bit_len[D];
148        _total[D] = _total[node];
149        _total[node] -= _total[C] - _total[_left[D]];
150        _total[C] -= _total[_left[D]] + _popcount(_key[D]);
151        _left[C] = _right[D];
152        _right[D] = C;
153        _right[node] = _left[D];
154        _left[D] = node;
155        _update_balance(D);
156        return D;
157    }
158
159    int _pref(int r) const {
160        Node node = _root;
161        int s = 0;
162        while (r > 0) {
163            int t = _size[_left[node]] + _bit_len[node];
164            if (t - _bit_len[node] < r && r <= t) {
165                r -= _size[_left[node]];
166                s += _total[_left[node]] + _popcount(_key[node] >> (_bit_len[node] - r));
167                break;
168            }
169            if (t > r) {
170                node = _left[node];
171            } else {
172                s += _total[_left[node]] + _popcount(_key[node]);
173                node = _right[node];
174                r -= t;
175            }
176        }
177        return s;
178    }
179
180    Node _make_node(const bool new_key, const char new_bit_len) {
181        if (_end >= _key.size()) {
182            _key.emplace_back(new_key);
183            _bit_len.emplace_back(new_bit_len);
184            _size.emplace_back(new_bit_len);
185            _total.emplace_back(new_key);
186            _left.emplace_back(0);
187            _right.emplace_back(0);
188            _balance.emplace_back(0);
189        } else {
190            _key[_end] = new_key;
191            _bit_len[_end] = new_bit_len;
192            _size[_end] = new_bit_len;
193            _total[_end] = new_key;
194        }
195        return _end++;
196    }
197
198    uint128 _bit_insert(uint128 v, char bl, bool key) const {
199        return ((((v >> bl) << 1) | key) << bl) | (v & (((uint128)1<<bl)-1));
200    }
201
202    uint128 _bit_pop(uint128 v, char bl) const {
203        return ((v >> bl) << ((bl-1))) | (v & (((uint128)1<<(bl-1))-1));
204    }
205
206    void _pop_under(stack<Node> &path, int d, Node node, int res) {
207        int fd = 0, lmax_total = 0;
208        char lmax_bit_len = 0;
209        if (_left[node] && _right[node]) {
210            path.emplace(node);
211            d = (d << 1) | 1;
212            Node lmax = _left[node];
213            while (_right[lmax]) {
214                path.emplace(lmax);
215                d <<= 1;
216                fd = (fd << 1) | 1;
217                lmax = _right[lmax];
218            }
219            lmax_total = _popcount(_key[lmax]);
220            lmax_bit_len = _bit_len[lmax];
221            _key[node] = _key[lmax];
222            _bit_len[node] = lmax_bit_len;
223            node = lmax;
224        }
225        Node cnode = _left[node] == 0 ? _right[node] : _left[node];
226        if (!path.empty()) {
227            ((d & 1) ? _left[path.top()] : _right[path.top()]) = cnode;
228        } else {
229            _root = cnode;
230            return;
231        }
232        while (!path.empty()) {
233            Node new_node = 0;
234            node = path.top(); path.pop();
235            _balance[node] -= (d & 1) ? 1 : -1;
236            _size[node] -= (fd & 1) ? lmax_bit_len : 1;
237            _total[node] -= (fd & 1) ? lmax_total : res;
238            d >>= 1;
239            fd >>= 1;
240            if (_balance[node] == 2) {
241                new_node = _balance[_left[node]] < 0 ? _rotate_LR(node) : _rotate_L(node);
242            } else if (_balance[node] == -2) {
243                new_node = _balance[_right[node]] > 0 ? _rotate_RL(node) : _rotate_R(node);
244            } else if (_balance[node] != 0) {
245                break;
246            }
247            if (new_node) {
248                if (path.empty()) {
249                    _root = new_node;
250                    return;
251                }
252                ((d & 1) ? _left[path.top()] : _right[path.top()]) = new_node;
253                if (_balance[new_node] != 0) break;
254            }
255        }
256        while (!path.empty()) {
257            node = path.top(); path.pop();
258            _size[node] -= (fd & 1) ? lmax_bit_len : 1;
259            _total[node] -= (fd & 1) ? lmax_total : res;
260            fd >>= 1;
261        }
262    }
263
264    void _debug_acc() {
265        auto rec = [&] (auto &&rec, Node node) -> int {
266            int acc = _popcount(_key[node]);
267            if (_left[node]) acc += rec(rec, _left[node]);
268            if (_right[node]) acc += rec(rec, _right[node]);
269            if (acc != _total[node]) {
270                assert(false);
271            }
272            return acc;
273        };
274        rec(rec, _root);
275        cout << "debug_acc ok." << endl;
276    }
277
278    public:
279    AVLTreeBitVector()
280        : _root(0), _end(1),
281          _key(1, 0),
282          _left(1, 0), _right(1, 0),
283          _size(1, 0), _total(1, 0),
284          _bit_len(1, 0), _balance(1, 0) {
285    }
286
287    AVLTreeBitVector(const vector<uint8_t> &a)
288        : _root(0), _end(1),
289          _key(1, 0),
290          _left(1, 0), _right(1, 0),
291          _size(1, 0), _total(1, 0),
292          _bit_len(1, 0), _balance(1, 0) {
293        if (!a.empty()) _build(a);
294    }
295
296    void reserve(int n) {
297        n = n / _W + 1;
298        _key.insert(_key.end(), n, (uint128)0);
299        _left.insert(_left.end(), n, 0);
300        _right.insert(_right.end(), n, 0);
301        _size.insert(_size.end(), n, 0);
302        _total.insert(_total.end(), n, 0);
303        _bit_len.insert(_bit_len.end(), n, (char)0);
304        _balance.insert(_balance.end(), n, (char)0);
305    }
306
307    void insert(int k, bool key) {
308        if (!_root) {
309            Node new_node = _make_node(key, 1);
310            _root = new_node;
311            return;
312        }
313        Node node = _root;
314        int d = 0;
315        stack<Node> path;
316        while (node) {
317            int t = _size[_left[node]] + _bit_len[node];
318            if (t - _bit_len[node] <= k && k <= t) break;
319            d <<= 1;
320            _size[node]++;
321            _total[node] += key;
322            path.emplace(node);
323            node = (t > k) ? _left[node] : _right[node];
324            if (t > k) d |= 1;
325            else k -= t;
326        }
327        k -= _size[_left[node]];
328        if (_bit_len[node] < _W) {
329            uint128 v = _key[node];
330            char bl = _bit_len[node] - k;
331            _key[node] = _bit_insert(v, bl, key);
332            _bit_len[node]++;
333            _size[node]++;
334            _total[node] += key;
335            return;
336        }
337        path.emplace(node);
338        _size[node]++;
339        _total[node] += key;
340        uint128 v = _key[node];
341        char bl = _W - k;
342        v = _bit_insert(v, bl, key);
343        uint128 left_key = v >> _W;
344        char left_key_popcount = left_key & 1;
345        _key[node] = v & (((uint128)1 << _W) - 1);
346        node = _left[node];
347        d = (d << 1) | 1;
348        if (!node) {
349            if (_bit_len[path.top()] < _W) {
350                _bit_len[path.top()]++;
351                _key[path.top()] = (_key[path.top()] << 1) | left_key;
352                return;
353            } else {
354                Node new_node = _make_node(left_key, 1);
355                _left[path.top()] = new_node;
356            }
357        } else {
358            path.emplace(node);
359            _size[node]++;
360            _total[node] += left_key_popcount;
361            d <<= 1;
362            while (_right[node]) {
363                node = _right[node];
364                path.emplace(node);
365                _size[node]++;
366                _total[node] += left_key_popcount;
367                d <<= 1;
368            }
369            if (_bit_len[node] < _W) {
370                _bit_len[node]++;
371                _key[node] = (_key[node] << 1) | left_key;
372                return;
373            } else {
374                Node new_node = _make_node(left_key, 1);
375                _right[node] = new_node;
376            }
377        }
378        Node new_node = 0;
379        while (!path.empty()) {
380            node = path.top(); path.pop();
381            _balance[node] += (d & 1) ? 1 : -1;
382            d >>= 1;
383            if (_balance[node] == 0) break;
384            if (_balance[node] == 2) {
385                new_node = _balance[_left[node]] == -1 ? _rotate_LR(node) : _rotate_L(node);
386                break;
387            } else if (_balance[node] == -2) {
388                new_node = _balance[_right[node]] == 1 ? _rotate_RL(node) : _rotate_R(node);
389                break;
390            }
391        }
392        if (new_node) {
393            if (!path.empty()) {
394                if (d & 1) {
395                    _left[path.top()] = new_node;
396                } else {
397                    _right[path.top()] = new_node;
398                }
399            } else {
400                _root = new_node;
401            }
402        }
403    }
404
405    bool pop(int k) {
406        Node node = _root;
407        int d = 0;
408        stack<Node> path;
409        while (node) {
410            int t = _size[_left[node]] + _bit_len[node];
411            if (t - _bit_len[node] <= k && k < t) break;
412            path.emplace(node);
413            node = t > k ? _left[node] : _right[node];
414            d <<= 1;
415            if (t > k) d |= 1;
416            else k -= t;
417        }
418        k -= _size[_left[node]];
419        uint128 v = _key[node];
420        bool res = (v >> (_bit_len[node] - k - 1)) & 1;
421        if (_bit_len[node] == 1) {
422            _pop_under(path, d, node, res);
423            return res;
424        }
425        _key[node] = _bit_pop(v, _bit_len[node]-k);
426        --_bit_len[node];
427        --_size[node];
428        _total[node] -= res;
429        while (!path.empty()) {
430            node = path.top(); path.pop();
431            --_size[node];
432            _total[node] -= res;
433        }
434        return res;
435    }
436
437    void set(int k, bool v) {
438        Node node = _root;
439        stack<Node> path;
440        while (true) {
441            int t = _size[_left[node]] + _bit_len[node];
442            path.emplace(node);
443            if (t - _bit_len[node] <= k && k < t) {
444                k -= _size[_left[node]];
445                if (v) {
446                    _key[node] |= (uint128)1 << k;
447                } else {
448                    _key[node] &= ~((uint128)1 << k);
449                }
450                break;
451            }
452            if (t > k) {
453                node = _left[node];
454            } else {
455                node = _right[node];
456                k -= t;
457            }
458        }
459        while (!path.empty()) {
460            node = path.top(); path.pop();
461            _total[node] = _popcount(_key[node]) + _total[_left[node]] + _total[_right[node]];
462        }
463    }
464
465    vector<uint8_t> tovector() const {
466        vector<uint8_t> a(len());
467        if (!_root) return a;
468        int indx = 0;
469        stack<Node> st;
470        Node node = _root;
471        while ((!st.empty()) || node) {
472            if (node) {
473                st.emplace(node);
474                node = _left[node];
475            } else {
476                node = st.top(); st.pop();
477                uint128 key = _key[node];
478                for (int i = _bit_len[node]-1; i >= 0; --i) {
479                    a[indx++] = key >> i & 1;
480                }
481                node = _right[node];
482            }
483        }
484        return a;
485        // auto rec = [&] (auto &&rec, Node node) -> void {
486        //     if (_left[node]) rec(rec, _left[node]);
487        //     uint128 key = _key[node];
488        //     for (int i = _bit_len[node]-1; i >= 0; --i) {
489        //         a[indx++] = key >> i & 1;
490        //     }
491        //     if (_right[node]) rec(rec, _right[node]);
492        // };
493        // rec(rec, _root);
494        // return a;
495    }
496
497    bool access(int k) const {
498        Node node = _root;
499        while (true) {
500            int t = _size[_left[node]] + _bit_len[node];
501            if (t - _bit_len[node] <= k && k < t) {
502                k -= _size[_left[node]];
503                return (_key[node] >> (_bit_len[node] - k - 1)) & 1;
504            }
505            if (t > k) {
506                node = _left[node];
507            } else {
508                node = _right[node];
509                k -= t;
510            }
511        }
512    }
513
514    int rank0(int r) const {
515        return r - _pref(r);
516    }
517
518    int rank1(int r) const {
519        return _pref(r);
520    }
521
522    int rank(int r, bool v) const {
523        return v ? rank1(r) : rank0(r);
524    }
525
526    int select0(int k) const {
527        Node node = _root;
528        int s = 0;
529        while (true) {
530            int t = _size[_left[node]] - _total[_left[node]];
531            if (k < t) {
532                node = _left[node];
533            } else if (k >= t + _bit_len[node] - _popcount(_key[node])) {
534                s += _size[_left[node]] + _bit_len[node];
535                k -= t + _bit_len[node] - _popcount(_key[node]);
536                node = _right[node];
537            } else {
538                k -= t;
539                char l = 0, r = _bit_len[node];
540                while (r - l > 1) {
541                    char m = (l + r) >> 1;
542                    if (m - _popcount(_key[node]>>(_bit_len[node]-m)) > k) r = m;
543                    else l = m;
544                }
545                s += _size[_left[node]] + l;
546                break;
547            }
548        }
549        return s;
550    }
551
552    int select1(int k) const {
553        Node node = _root;
554        int s = 0;
555        while (true) {
556            if (k < _total[_left[node]]) {
557                node = _left[node];
558            } else if (k >= _total[_left[node]] + _popcount(_key[node])) {
559                s += _size[_left[node]] + _bit_len[node];
560                k -= _total[_left[node]] + _popcount(_key[node]);
561                node = _right[node];
562            } else {
563                k -= _total[_left[node]];
564                char l = 0, r = _bit_len[node];
565                while (r - l > 1) {
566                    char m = (l + r) >> 1;
567                    if (_popcount(_key[node]>>(_bit_len[node]-m)) > k) r = m;
568                    else l = m;
569                }
570                s += _size[_left[node]] + l;
571                break;
572            }
573        }
574        return s;
575    }
576
577    int select(int k, bool v) const {
578        return v ? select1(k) : select0(k);
579    }
580
581    int _insert_and_rank1(int k, bool key) {
582        if (_root == 0) {
583            Node new_node = _make_node(key, 1);
584            _root = new_node;
585            return 0;
586        }
587        Node node = _root;
588        int s = 0;
589        stack<Node> path;
590        int d = 0;
591        while (node) {
592            int t = _size[_left[node]] + _bit_len[node];
593            if (t - _bit_len[node] <= k && k <= t) break;
594            if (t <= k) {
595                s += _total[_left[node]] + _popcount(_key[node]);
596            }
597            d <<= 1;
598            _size[node]++;
599            _total[node] += key;
600            path.emplace(node);
601            node = t > k ? _left[node] : _right[node];
602            if (t > k) d |= 1;
603            else k -= t;
604        }
605        k -= _size[_left[node]];
606        s += _total[_left[node]] + _popcount(_key[node] >> (_bit_len[node] - k));
607        if (_bit_len[node] < _W) {
608            uint128 v = _key[node];
609            char bl = _bit_len[node] - k;
610            _key[node] = _bit_insert(v, bl, key);
611            _bit_len[node]++;
612            _size[node]++;
613            _total[node] += key;
614            return s;
615        }
616        path.emplace(node);
617        _size[node]++;
618        _total[node] += key;
619        uint128 v = _key[node];
620        char bl = _W - k;
621        v = _bit_insert(v, bl, key);
622        uint128 left_key = v >> _W;
623        char left_key_popcount = left_key & 1;
624        _key[node] = v & (((uint128)1 << _W) - 1);
625        node = _left[node];
626        d = d << 1 | 1;
627        if (!node) {
628            if (_bit_len[path.top()] < _W) {
629                _bit_len[path.top()]++;
630                _key[path.top()] = (_key[path.top()] << 1) | left_key;
631                return s;
632            } else {
633                Node new_node = _make_node(left_key, 1);
634                _left[path.top()] = new_node;
635            }
636        } else {
637            path.emplace(node);
638            _size[node]++;
639            _total[node] += left_key_popcount;
640            d <<= 1;
641            while (_right[node]) {
642                node = _right[node];
643                path.emplace(node);
644                _size[node]++;
645                _total[node] += left_key_popcount;
646                d <<= 1;
647            }
648            if (_bit_len[node] < _W) {
649                _bit_len[node]++;
650                _key[node] = (_key[node] << 1) | left_key;
651                return s;
652            } else {
653                Node new_node = _make_node(left_key, 1);
654                _right[node] = new_node;
655            }
656        }
657        Node new_node = 0;
658        while (!path.empty()) {
659            node = path.top(); path.pop();
660            _balance[node] += (d & 1) ? 1 : -1;
661            d >>= 1;
662            if (_balance[node] == 0) break;
663            if (_balance[node] == 2) {
664                new_node = _balance[_left[node]] == -1 ? _rotate_LR(node) : _rotate_L(node);
665                break;
666            } else if (_balance[node] == -2) {
667                new_node = _balance[_right[node]] == 1 ? _rotate_RL(node) : _rotate_R(node);
668                break;
669            }
670        }
671        if (new_node) {
672            if (!path.empty()) {
673                ((d & 1) ? _left[path.top()] : _right[path.top()]) = new_node;
674            } else {
675                _root = new_node;
676            }
677        }
678        return s;
679    }
680
681    int _access_pop_and_rank1(int k) {
682        int s = 0, d = 0;
683        Node node = _root;
684        stack<Node> path;
685        while (node) {
686            int t = _size[_left[node]] + _bit_len[node];
687            if (t - _bit_len[node] <= k && k < t) break;
688            if (t <= k) {
689                s += _total[_left[node]] + _popcount(_key[node]);
690            }
691            path.emplace(node);
692            node = t > k ? _left[node] : _right[node];
693            d <<= 1;
694            if (t > k) d |= 1;
695            else k -= t;
696        }
697        k -= _size[_left[node]];
698        s += _total[_left[node]] + _popcount(_key[node] >> (_bit_len[node] - k));
699        uint128 v = _key[node];
700        bool res = v >> (_bit_len[node] - k - 1) & 1;
701        if (_bit_len[node] == 1) {
702            _pop_under(path, d, node, res);
703            return (s << 1) | res;
704        }
705        _key[node] = _bit_pop(v, _bit_len[node]-k);
706        --_bit_len[node];
707        --_size[node];
708        _total[node] -= res;
709        while (!path.empty()) {
710            node = path.top(); path.pop();
711            --_size[node];
712            _total[node] -= res;
713        }
714        return (s << 1) | res;
715    }
716
717    pair<bool, int> _access_ans_rank1(int k) const {
718        Node node = _root;
719        int s = 0;
720        bool res;
721        while (true) {
722            int t = _size[_left[node]] + _bit_len[node];
723            if (t - _bit_len[node] <= k && k < t) {
724                k -= _size[_left[node]];
725                s += _total[_left[node]] + _popcount(_key[node] >> (_bit_len[node] - k));
726                res = (_key[node] >> (_bit_len[node] - k - 1)) & 1;
727                break;
728            }
729            if (t > k) {
730                node = _left[node];
731            } else {
732                s += _total[_left[node]] + _popcount(_key[node]);
733                node = _right[node];
734                k -= t;
735            }
736        }
737        return make_pair(res, s);
738    }
739
740    void print() const {
741        vector<uint8_t> a = tovector();
742        int n = (int)a.size();
743        cout << "[";
744        for (int i = 0; i < n-1; ++i) {
745            cout << a[i] << ", ";
746        }
747        if (n > 0) {
748            cout << a.back();
749        }
750        cout << "]";
751        cout << endl;
752    }
753
754    bool empty() const {
755        return len() == 0;
756    }
757
758    int len() const {
759        return _size[_root];
760    }
761};
762} // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/avl_tree_bit_vector.cpp”