AVL Tree Bit Vector
ソースコード
1#include <iostream>
2#include <vector>
3#include <stack>
4#include <cassert>
5#include <tuple>
6#include <nmmintrin.h>
7#include <stdint.h>
8using namespace std;
9
// AVLTreeBitVector: dynamic bit vector with insert / pop / set / access /
// rank / select, each walking one root-to-leaf path of an AVL tree.
namespace titan23 {

/**
 * @brief Dynamic bit vector backed by an AVL tree of packed bit chunks.
 *
 * Each tree node stores up to `_W` (= 127) consecutive bits in a
 * `__uint128_t`. Bits inside a node are stored MSB-first: the node's first
 * bit lives at bit index `_bit_len - 1` and its last bit at index 0.
 *
 * Node id 0 is a sentinel "null" node whose size / total stay 0, so child
 * lookups such as `_size[_left[x]]` need no null checks.
 *
 * Per-node data is kept in parallel vectors:
 *  - `_key`      packed bits of the node itself
 *  - `_bit_len`  number of valid bits in `_key`
 *  - `_size`     number of bits in the node's whole subtree
 *  - `_total`    number of 1 bits in the node's whole subtree
 *  - `_balance`  height(left subtree) - height(right subtree)
 *
 * Positions and ranks are 0-indexed. Arguments are not range-checked;
 * out-of-range positions or select indices lead to undefined results.
 * Requires `__uint128_t` and `__builtin_popcountll` (GCC / Clang).
 */
class AVLTreeBitVector {
  private:
    using Node = int;
    using uint128 = __uint128_t;
    // Capacity of one node in bits. One bit of the 128-bit word is left free
    // because inserting into a full node briefly needs 128 bits.
    // (A 64-bit variant would use unsigned long long with _W = 63.)
    static constexpr const char _W = 127;

    Node _root, _end;                      // root id (0 = empty tree), next unused node id
    std::vector<uint128> _key;             // packed bits of each node, MSB-first
    std::vector<Node> _left, _right;       // child ids (0 = none)
    std::vector<int> _size, _total;        // subtree bit count / subtree popcount
    std::vector<char> _bit_len, _balance;  // bits in this node / height(left) - height(right)

    // Builds a balanced tree holding `a`. Called only on an empty tree.
    // Any nonzero byte of `a` is stored as a 1 bit.
    void _build(const std::vector<uint8_t> &a) {
        // Links node ids [l, r) into a balanced subtree rooted at the middle
        // id; returns (subtree root, subtree height).
        auto rec = [&](auto &&rec, Node l, Node r) -> std::pair<Node, char> {
            const Node mid = (l + r) >> 1;
            char hl = 0, hr = 0;
            if (l != mid) {
                std::tie(_left[mid], hl) = rec(rec, l, mid);
                _size[mid] += _size[_left[mid]];
                _total[mid] += _total[_left[mid]];
            }
            if (mid + 1 != r) {
                std::tie(_right[mid], hr) = rec(rec, mid + 1, r);
                _size[mid] += _size[_right[mid]];
                _total[mid] += _total[_right[mid]];
            }
            _balance[mid] = hl - hr;
            return {mid, static_cast<char>(std::max(hl, hr) + 1)};
        };

        const int n = a.size();
        reserve(n);
        const Node pre_end = _end;
        Node indx = _end;
        // Pack the input into consecutive nodes of up to _W bits each.
        for (int i = 0; i < n; i += _W) {
            int j = 0;
            int pop = 0;
            uint128 v = 0;
            while (j < _W && i + j < n) {
                v <<= 1;
                if (a[i + j]) {
                    // Nonzero bytes count as a single 1 bit; OR-ing the raw
                    // byte would set higher bits and corrupt earlier positions.
                    v |= 1;
                    ++pop;
                }
                ++j;
            }
            _key[indx] = v;
            _bit_len[indx] = j;
            _size[indx] = j;
            _total[indx] = pop;
            ++indx;
        }
        _end = indx;
        _root = rec(rec, pre_end, _end).first;
    }

    // Number of set bits in a 128-bit word (two 64-bit popcounts).
    int _popcount(const uint128 n) const {
        return __builtin_popcountll(static_cast<unsigned long long>(n >> 64)) +
               __builtin_popcountll(static_cast<unsigned long long>(n));
    }

    // Single rotation for a left-heavy node (balance +2): the left child `u`
    // becomes the subtree root. Handles u.balance == +1 and u.balance == 0
    // (the latter only arises during deletion). Returns the new root.
    Node _rotate_L(Node node) {
        const Node u = _left[node];
        _size[u] = _size[node];
        _total[u] = _total[node];
        _size[node] -= _size[_left[u]] + _bit_len[u];
        _total[node] -= _total[_left[u]] + _popcount(_key[u]);
        _left[node] = _right[u];
        _right[u] = node;
        if (_balance[u] == 1) {
            _balance[u] = 0;
            _balance[node] = 0;
        } else {
            _balance[u] = -1;
            _balance[node] = 1;
        }
        return u;
    }

    // Mirror of _rotate_L for a right-heavy node (balance -2): the right child
    // `u` becomes the subtree root. Returns the new root.
    Node _rotate_R(Node node) {
        const Node u = _right[node];
        _size[u] = _size[node];
        _total[u] = _total[node];
        _size[node] -= _size[_right[u]] + _bit_len[u];
        _total[node] -= _total[_right[u]] + _popcount(_key[u]);
        _right[node] = _left[u];
        _left[u] = node;
        if (_balance[u] == -1) {
            _balance[u] = 0;
            _balance[node] = 0;
        } else {
            _balance[u] = 1;
            _balance[node] = -1;
        }
        return u;
    }

    // After a double rotation, `node` is the new subtree root. Its old balance
    // decides the balances of its two (new) children; it becomes balanced.
    void _update_balance(Node node) {
        if (_balance[node] == 1) {
            _balance[_right[node]] = -1;
            _balance[_left[node]] = 0;
        } else if (_balance[node] == -1) {
            _balance[_right[node]] = 0;
            _balance[_left[node]] = 1;
        } else {
            _balance[_right[node]] = 0;
            _balance[_left[node]] = 0;
        }
        _balance[node] = 0;
    }

    // Double rotation for a left-heavy node whose left child B is right-heavy:
    // B's right child E becomes the subtree root. Returns E.
    Node _rotate_LR(Node node) {
        const Node B = _left[node];
        const Node E = _right[B];
        _size[E] = _size[node];
        _size[node] -= _size[B] - _size[_right[E]];
        _size[B] -= _size[_right[E]] + _bit_len[E];
        _total[E] = _total[node];
        _total[node] -= _total[B] - _total[_right[E]];
        _total[B] -= _total[_right[E]] + _popcount(_key[E]);
        _right[B] = _left[E];
        _left[E] = B;
        _left[node] = _right[E];
        _right[E] = node;
        _update_balance(E);
        return E;
    }

    // Mirror of _rotate_LR: the right child C is left-heavy, so C's left child
    // D becomes the subtree root. Returns D.
    Node _rotate_RL(Node node) {
        const Node C = _right[node];
        const Node D = _left[C];
        _size[D] = _size[node];
        _size[node] -= _size[C] - _size[_left[D]];
        _size[C] -= _size[_left[D]] + _bit_len[D];
        _total[D] = _total[node];
        _total[node] -= _total[C] - _total[_left[D]];
        _total[C] -= _total[_left[D]] + _popcount(_key[D]);
        _left[C] = _right[D];
        _right[D] = C;
        _right[node] = _left[D];
        _left[D] = node;
        _update_balance(D);
        return D;
    }

    // Number of 1 bits among the first r positions (expects 0 <= r <= len()).
    int _pref(int r) const {
        Node node = _root;
        int s = 0;
        while (r > 0) {
            const int t = _size[_left[node]] + _bit_len[node];
            if (t - _bit_len[node] < r && r <= t) {
                // The prefix ends inside this node: count its leading bits.
                r -= _size[_left[node]];
                s += _total[_left[node]] + _popcount(_key[node] >> (_bit_len[node] - r));
                break;
            }
            if (t > r) {
                node = _left[node];
            } else {
                s += _total[_left[node]] + _popcount(_key[node]);
                node = _right[node];
                r -= t;
            }
        }
        return s;
    }

    // Allocates a detached leaf holding `new_bit_len` bits with value
    // `new_key`, reusing a slot pre-allocated by reserve() when available.
    Node _make_node(const bool new_key, const char new_bit_len) {
        if (_end >= static_cast<int>(_key.size())) {
            _key.emplace_back(new_key);
            _bit_len.emplace_back(new_bit_len);
            _size.emplace_back(new_bit_len);
            _total.emplace_back(new_key);
            _left.emplace_back(0);
            _right.emplace_back(0);
            _balance.emplace_back(0);
        } else {
            _key[_end] = new_key;
            _bit_len[_end] = new_bit_len;
            _size[_end] = new_bit_len;
            _total[_end] = new_key;
            // Reset links explicitly so a reused slot is always a clean leaf.
            _left[_end] = 0;
            _right[_end] = 0;
            _balance[_end] = 0;
        }
        return _end++;
    }

    // Inserts bit `key` so that the lowest `bl` bits of `v` stay below it.
    uint128 _bit_insert(uint128 v, char bl, bool key) const {
        return ((((v >> bl) << 1) | key) << bl) | (v & (((uint128)1 << bl) - 1));
    }

    // Removes the bit at index `bl - 1` of `v`, shifting higher bits down.
    uint128 _bit_pop(uint128 v, char bl) const {
        return ((v >> bl) << (bl - 1)) | (v & (((uint128)1 << (bl - 1)) - 1));
    }

    // Removes `node`, which holds exactly one bit of value `res`.
    // `path` holds the ancestors of `node` (root at the bottom). The low bit of
    // `d` is 1 when `node` is the left child of path.top(); each higher bit is
    // the direction one level further up. Updates the ancestors' counters and
    // restores the AVL balance.
    void _pop_under(std::stack<Node> &path, unsigned long long d, Node node, int res) {
        // fd marks (with 1 bits) the path entries strictly between `node` and
        // its in-order predecessor; those subtrees lose the whole predecessor
        // node instead of a single bit.
        unsigned long long fd = 0;
        int lmax_total = 0;
        char lmax_bit_len = 0;
        if (_left[node] && _right[node]) {
            // Two children: move the predecessor's contents into `node`, then
            // unlink the predecessor instead.
            path.emplace(node);
            d = (d << 1) | 1;
            Node lmax = _left[node];
            while (_right[lmax]) {
                path.emplace(lmax);
                d <<= 1;
                fd = (fd << 1) | 1;
                lmax = _right[lmax];
            }
            lmax_total = _popcount(_key[lmax]);
            lmax_bit_len = _bit_len[lmax];
            _key[node] = _key[lmax];
            _bit_len[node] = lmax_bit_len;
            node = lmax;
        }
        // `node` now has at most one child: splice that child into its place.
        const Node cnode = _left[node] == 0 ? _right[node] : _left[node];
        if (path.empty()) {
            _root = cnode;
            return;
        }
        ((d & 1) ? _left[path.top()] : _right[path.top()]) = cnode;

        // Walk upward while the subtree height keeps shrinking.
        while (!path.empty()) {
            Node new_node = 0;
            node = path.top();
            path.pop();
            _balance[node] -= (d & 1) ? 1 : -1;
            _size[node] -= (fd & 1) ? lmax_bit_len : 1;
            _total[node] -= (fd & 1) ? lmax_total : res;
            d >>= 1;
            fd >>= 1;
            if (_balance[node] == 2) {
                new_node = _balance[_left[node]] < 0 ? _rotate_LR(node) : _rotate_L(node);
            } else if (_balance[node] == -2) {
                new_node = _balance[_right[node]] > 0 ? _rotate_RL(node) : _rotate_R(node);
            } else if (_balance[node] != 0) {
                break;  // height unchanged: no more rebalancing needed
            }
            if (new_node) {
                if (path.empty()) {
                    _root = new_node;
                    return;
                }
                ((d & 1) ? _left[path.top()] : _right[path.top()]) = new_node;
                if (_balance[new_node] != 0) break;
            }
        }
        // Remaining ancestors only need their counters adjusted.
        while (!path.empty()) {
            node = path.top();
            path.pop();
            _size[node] -= (fd & 1) ? lmax_bit_len : 1;
            _total[node] -= (fd & 1) ? lmax_total : res;
            fd >>= 1;
        }
    }

    // Debug helper: asserts every node's `_total` equals the recomputed
    // popcount of its subtree, then prints a confirmation line.
    void _debug_acc() {
        auto rec = [&](auto &&rec, Node node) -> int {
            int acc = _popcount(_key[node]);
            if (_left[node]) acc += rec(rec, _left[node]);
            if (_right[node]) acc += rec(rec, _right[node]);
            assert(acc == _total[node]);
            return acc;
        };
        rec(rec, _root);
        std::cout << "debug_acc ok." << std::endl;
    }

  public:
    /// @brief Creates an empty bit vector.
    AVLTreeBitVector()
        : _root(0), _end(1),
          _key(1, 0),
          _left(1, 0), _right(1, 0),
          _size(1, 0), _total(1, 0),
          _bit_len(1, 0), _balance(1, 0) {}

    /// @brief Creates a bit vector holding `a`; any nonzero byte becomes a 1 bit.
    AVLTreeBitVector(const std::vector<uint8_t> &a) : AVLTreeBitVector() {
        if (!a.empty()) _build(a);
    }

    /// @brief Appends zeroed node slots for about `n` more bits (n / _W + 1
    /// nodes), so later node creation reuses them instead of growing vectors.
    void reserve(int n) {
        n = n / _W + 1;
        _key.insert(_key.end(), n, (uint128)0);
        _left.insert(_left.end(), n, 0);
        _right.insert(_right.end(), n, 0);
        _size.insert(_size.end(), n, 0);
        _total.insert(_total.end(), n, 0);
        _bit_len.insert(_bit_len.end(), n, (char)0);
        _balance.insert(_balance.end(), n, (char)0);
    }

    /// @brief Inserts bit `key` so that it ends up at position `k` (0 <= k <= len()).
    void insert(int k, bool key) {
        // Same tree update as _insert_and_rank1; the rank is simply discarded.
        _insert_and_rank1(k, key);
    }

    /// @brief Removes and returns the bit at position `k` (0 <= k < len()).
    bool pop(int k) {
        // The low bit of the packed result is the removed bit.
        return _access_pop_and_rank1(k) & 1;
    }

    /// @brief Overwrites the bit at position `k` (0 <= k < len()) with `v`.
    void set(int k, bool v) {
        Node node = _root;
        std::stack<Node> path;
        while (true) {
            const int t = _size[_left[node]] + _bit_len[node];
            path.emplace(node);
            if (t - _bit_len[node] <= k && k < t) {
                k -= _size[_left[node]];
                // Bits are stored MSB-first: in-node position k is bit index
                // (_bit_len - k - 1), matching access() and tovector().
                const uint128 mask = (uint128)1 << (_bit_len[node] - k - 1);
                if (v) {
                    _key[node] |= mask;
                } else {
                    _key[node] &= ~mask;
                }
                break;
            }
            if (t > k) {
                node = _left[node];
            } else {
                node = _right[node];
                k -= t;
            }
        }
        // Recompute subtree popcounts bottom-up along the visited path.
        while (!path.empty()) {
            node = path.top();
            path.pop();
            _total[node] = _popcount(_key[node]) + _total[_left[node]] + _total[_right[node]];
        }
    }

    /// @brief Returns all bits in order as 0/1 bytes.
    std::vector<uint8_t> tovector() const {
        std::vector<uint8_t> a(len());
        if (!_root) return a;
        int indx = 0;
        std::stack<Node> st;
        Node node = _root;
        // Iterative in-order traversal.
        while (!st.empty() || node) {
            if (node) {
                st.emplace(node);
                node = _left[node];
            } else {
                node = st.top();
                st.pop();
                const uint128 key = _key[node];
                for (int i = _bit_len[node] - 1; i >= 0; --i) {
                    a[indx++] = static_cast<uint8_t>((key >> i) & 1);
                }
                node = _right[node];
            }
        }
        return a;
    }

    /// @brief Returns the bit at position `k` (0 <= k < len()).
    bool access(int k) const {
        Node node = _root;
        while (true) {
            const int t = _size[_left[node]] + _bit_len[node];
            if (t - _bit_len[node] <= k && k < t) {
                k -= _size[_left[node]];
                return (_key[node] >> (_bit_len[node] - k - 1)) & 1;
            }
            if (t > k) {
                node = _left[node];
            } else {
                node = _right[node];
                k -= t;
            }
        }
    }

    /// @brief Number of 0 bits among the first `r` positions.
    int rank0(int r) const {
        return r - _pref(r);
    }

    /// @brief Number of 1 bits among the first `r` positions.
    int rank1(int r) const {
        return _pref(r);
    }

    /// @brief Number of bits equal to `v` among the first `r` positions.
    int rank(int r, bool v) const {
        return v ? rank1(r) : rank0(r);
    }

    /// @brief Position of the k-th (0-indexed) 0 bit; requires k < rank0(len()).
    int select0(int k) const {
        Node node = _root;
        int s = 0;
        while (true) {
            const int t = _size[_left[node]] - _total[_left[node]];  // zeros in left subtree
            const int zeros_here = _bit_len[node] - _popcount(_key[node]);
            if (k < t) {
                node = _left[node];
            } else if (k >= t + zeros_here) {
                s += _size[_left[node]] + _bit_len[node];
                k -= t + zeros_here;
                node = _right[node];
            } else {
                k -= t;
                // Largest prefix length l of this node containing <= k zeros;
                // the wanted zero sits at in-node position l.
                int l = 0, r = _bit_len[node];
                while (r - l > 1) {
                    const int m = (l + r) >> 1;
                    if (m - _popcount(_key[node] >> (_bit_len[node] - m)) > k) {
                        r = m;
                    } else {
                        l = m;
                    }
                }
                s += _size[_left[node]] + l;
                break;
            }
        }
        return s;
    }

    /// @brief Position of the k-th (0-indexed) 1 bit; requires k < rank1(len()).
    int select1(int k) const {
        Node node = _root;
        int s = 0;
        while (true) {
            const int t = _total[_left[node]];  // ones in left subtree
            const int ones_here = _popcount(_key[node]);
            if (k < t) {
                node = _left[node];
            } else if (k >= t + ones_here) {
                s += _size[_left[node]] + _bit_len[node];
                k -= t + ones_here;
                node = _right[node];
            } else {
                k -= t;
                // Largest prefix length l of this node containing <= k ones.
                int l = 0, r = _bit_len[node];
                while (r - l > 1) {
                    const int m = (l + r) >> 1;
                    if (_popcount(_key[node] >> (_bit_len[node] - m)) > k) {
                        r = m;
                    } else {
                        l = m;
                    }
                }
                s += _size[_left[node]] + l;
                break;
            }
        }
        return s;
    }

    /// @brief Position of the k-th (0-indexed) bit equal to `v`.
    int select(int k, bool v) const {
        return v ? select1(k) : select0(k);
    }

    /// @brief Inserts bit `key` at position `k` (0 <= k <= len()) and returns
    /// the number of 1 bits before position `k` (i.e. rank1(k) before insertion).
    int _insert_and_rank1(int k, bool key) {
        if (_root == 0) {
            const Node new_node = _make_node(key, 1);
            _root = new_node;
            return 0;
        }
        Node node = _root;
        int s = 0;
        std::stack<Node> path;
        unsigned long long d = 0;  // direction bits: 1 = went left (low bit = latest)
        // Descend to the node receiving the bit, pre-incrementing ancestors.
        while (node) {
            const int t = _size[_left[node]] + _bit_len[node];
            if (t - _bit_len[node] <= k && k <= t) break;
            if (t <= k) {
                s += _total[_left[node]] + _popcount(_key[node]);
            }
            d <<= 1;
            _size[node]++;
            _total[node] += key;
            path.emplace(node);
            if (t > k) {
                node = _left[node];
                d |= 1;
            } else {
                node = _right[node];
                k -= t;
            }
        }
        k -= _size[_left[node]];
        s += _total[_left[node]] + _popcount(_key[node] >> (_bit_len[node] - k));
        if (_bit_len[node] < _W) {
            // Room in this node: splice the bit in place.
            _key[node] = _bit_insert(_key[node], _bit_len[node] - k, key);
            _bit_len[node]++;
            _size[node]++;
            _total[node] += key;
            return s;
        }

        // Full node: insert anyway (temporarily 128 bits), then spill the
        // node's first bit to the end of its in-order predecessor.
        path.emplace(node);
        _size[node]++;
        _total[node] += key;
        const uint128 v = _bit_insert(_key[node], _W - k, key);
        const uint128 left_key = v >> _W;
        const int left_key_popcount = static_cast<int>(left_key & 1);
        _key[node] = v & (((uint128)1 << _W) - 1);
        node = _left[node];
        d = (d << 1) | 1;
        if (!node) {
            // No left subtree: the spilled bit becomes a new left leaf.
            const Node new_leaf = _make_node(left_key, 1);
            _left[path.top()] = new_leaf;
        } else {
            // Walk to the rightmost node of the left subtree (the predecessor).
            path.emplace(node);
            _size[node]++;
            _total[node] += left_key_popcount;
            d <<= 1;
            while (_right[node]) {
                node = _right[node];
                path.emplace(node);
                _size[node]++;
                _total[node] += left_key_popcount;
                d <<= 1;
            }
            if (_bit_len[node] < _W) {
                // Predecessor has room: append the spilled bit as its last bit.
                _bit_len[node]++;
                _key[node] = (_key[node] << 1) | left_key;
                return s;
            }
            const Node new_leaf = _make_node(left_key, 1);
            _right[node] = new_leaf;
        }

        // A leaf was added: propagate the height increase; at most one
        // (single or double) rotation is needed.
        Node new_node = 0;
        while (!path.empty()) {
            node = path.top();
            path.pop();
            _balance[node] += (d & 1) ? 1 : -1;
            d >>= 1;
            if (_balance[node] == 0) break;
            if (_balance[node] == 2) {
                new_node = _balance[_left[node]] == -1 ? _rotate_LR(node) : _rotate_L(node);
                break;
            }
            if (_balance[node] == -2) {
                new_node = _balance[_right[node]] == 1 ? _rotate_RL(node) : _rotate_R(node);
                break;
            }
        }
        if (new_node) {
            if (path.empty()) {
                _root = new_node;
            } else {
                ((d & 1) ? _left[path.top()] : _right[path.top()]) = new_node;
            }
        }
        return s;
    }

    /// @brief Removes the bit at position `k` (0 <= k < len()).
    /// @return (rank1(k) << 1) | removed_bit, with rank1 taken before removal.
    int _access_pop_and_rank1(int k) {
        int s = 0;
        unsigned long long d = 0;  // direction bits: 1 = went left (low bit = latest)
        Node node = _root;
        std::stack<Node> path;
        while (node) {
            const int t = _size[_left[node]] + _bit_len[node];
            if (t - _bit_len[node] <= k && k < t) break;
            if (t <= k) {
                s += _total[_left[node]] + _popcount(_key[node]);
            }
            path.emplace(node);
            d <<= 1;
            if (t > k) {
                node = _left[node];
                d |= 1;
            } else {
                node = _right[node];
                k -= t;
            }
        }
        k -= _size[_left[node]];
        s += _total[_left[node]] + _popcount(_key[node] >> (_bit_len[node] - k));
        const uint128 v = _key[node];
        const bool res = (v >> (_bit_len[node] - k - 1)) & 1;
        if (_bit_len[node] == 1) {
            // The node becomes empty: unlink it and rebalance.
            _pop_under(path, d, node, res);
            return (s << 1) | res;
        }
        _key[node] = _bit_pop(v, _bit_len[node] - k);
        --_bit_len[node];
        --_size[node];
        _total[node] -= res;
        while (!path.empty()) {
            node = path.top();
            path.pop();
            --_size[node];
            _total[node] -= res;
        }
        return (s << 1) | res;
    }

    /// @brief Returns (bit at position k, rank1(k)) in one traversal.
    std::pair<bool, int> _access_ans_rank1(int k) const {
        Node node = _root;
        int s = 0;
        bool res = false;
        while (true) {
            const int t = _size[_left[node]] + _bit_len[node];
            if (t - _bit_len[node] <= k && k < t) {
                k -= _size[_left[node]];
                s += _total[_left[node]] + _popcount(_key[node] >> (_bit_len[node] - k));
                res = (_key[node] >> (_bit_len[node] - k - 1)) & 1;
                break;
            }
            if (t > k) {
                node = _left[node];
            } else {
                s += _total[_left[node]] + _popcount(_key[node]);
                node = _right[node];
                k -= t;
            }
        }
        return std::make_pair(res, s);
    }

    /// @brief Prints the bits as "[b0, b1, ...]" followed by a newline.
    void print() const {
        const std::vector<uint8_t> a = tovector();
        const int n = static_cast<int>(a.size());
        std::cout << "[";
        for (int i = 0; i < n - 1; ++i) {
            // Cast: streaming a uint8_t would print a raw character, not 0/1.
            std::cout << static_cast<int>(a[i]) << ", ";
        }
        if (n > 0) {
            std::cout << static_cast<int>(a.back());
        }
        std::cout << "]";
        std::cout << std::endl;
    }

    /// @brief True when the vector holds no bits.
    bool empty() const {
        return len() == 0;
    }

    /// @brief Number of bits stored.
    int len() const {
        return _size[_root];
    }
};

}  // namespace titan23
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/avl_tree_bit_vector.cpp”