Lazy WB Tree (weight-balanced binary tree with lazy propagation)

ソースコード

#include <algorithm>
#include <cassert>
#include <iostream>
#include <stack>
#include <utility>
#include <vector>
using namespace std;

  7namespace titan23 {
  8
  9    template <class T,
 10                class F,
 11                T (*op)(T, T),
 12                T (*mapping)(F, T),
 13                F (*composition)(F, F),
 14                T (*e)(),
 15                F (*id)()>
 16    class LazyWBTree {
 17      private:
 18        class Node;
 19        using NodePtr = Node*;
 20        using MyLazyWBTree = LazyWBTree<T, F, op, mapping, composition, e, id>;
 21        static constexpr int DELTA = 3;
 22        static constexpr int GAMMA = 2;
 23
 24        class Node {
 25          public:
 26            NodePtr left;
 27            NodePtr right;
 28            T key, data;
 29            F lazy;
 30            bool rev;
 31            int size;
 32
 33            Node(T key, F lazy) : left(nullptr), right(nullptr), key(key), data(key), lazy(lazy), rev(false), size(1) {}
 34
 35            int weight_right() const {
 36                return right ? right->size + 1 : 1;
 37            }
 38
 39            int weight_left() const {
 40                return left ? left->size + 1 : 1;
 41            }
 42
 43            void balance_check() const {
 44                if (!weight_left()*DELTA >= weight_right()) {
 45                    cerr << weight_left() << ", " << weight_right() << endl;
 46                    cerr << "not weight_left()*DELTA >= weight_right()." << endl;
 47                    assert(false);
 48                }
 49                if (!weight_right() * DELTA >= weight_left()) {
 50                    cerr << weight_left() << ", " << weight_right() << endl;
 51                    cerr << "not weight_right() * DELTA >= weight_left()." << endl;
 52                    assert(false);
 53                }
 54            }
 55
 56            void update() {
 57                size = 1;
 58                data = key;
 59                if (left) {
 60                    size += left->size;
 61                    data = op(left->data, key);
 62                }
 63                if (right) {
 64                    size += right->size;
 65                    data = op(data, right->data);
 66                }
 67            }
 68
 69            void propagate() {
 70                if (rev) {
 71                    swap(left, right);
 72                    if (left) left->rev ^= 1;
 73                    if (right) right->rev ^= 1;
 74                    rev = 0;
 75                }
 76                if (lazy != id()) {
 77                    if (left) {
 78                        left->key = mapping(lazy, left->key);
 79                        left->data = mapping(lazy, left->data);
 80                        left->lazy = composition(lazy, left->lazy);
 81                    }
 82                    if (right) {
 83                        right->key = mapping(lazy, right->key);
 84                        right->data = mapping(lazy, right->data);
 85                        right->lazy = composition(lazy, right->lazy);
 86                    }
 87                    lazy = id();
 88                }
 89            }
 90
 91            void print() const {
 92                cout << "this : key=" << key << ", size=" << size << endl;
 93                if (left)  cout << "to-left" << endl;
 94                if (right) cout << "to-right" << endl;
 95                cout << endl;
 96                if (left)  left->print();
 97                if (right) right->print();
 98            }
 99        };
100
101        void _build(vector<T> const &a) {
102            auto build = [&] (auto &&build, int l, int r) -> NodePtr {
103                int mid = (l + r) >> 1;
104                NodePtr node = new Node(a[mid], id());
105                if (l != mid) node->left = build(build, l, mid);
106                if (mid+1 != r) node->right = build(build, mid+1, r);
107                node->update();
108                return node;
109            };
110            if (a.empty()) {
111                this->root = nullptr;
112                return;
113            }
114            this->root = build(build, 0, (int)a.size());
115        }
116
117        NodePtr _rotate_right(NodePtr node) {
118            NodePtr u = node->left;
119            node->left = u->right;
120            u->right = node;
121            node->update();
122            u->update();
123            return u;
124        }
125
126        NodePtr _rotate_left(NodePtr node) {
127            NodePtr u = node->right;
128            node->right = u->left;
129            u->left = node;
130            node->update();
131            u->update();
132            return u;
133        }
134
135        NodePtr _balance_left(NodePtr node) {
136            NodePtr u = node->right;
137            u->propagate();
138            if (node->right->weight_left() >= node->right->weight_right() * GAMMA) {
139                u->left->propagate();
140                node->right = _rotate_right(u);
141            }
142            return _rotate_left(node);
143        }
144
145        NodePtr _balance_right(NodePtr node) {
146            NodePtr u = node->left;
147            u->propagate();
148            if (node->left->weight_right() >= node->left->weight_left() * GAMMA) {
149                u->right->propagate();
150                node->left = _rotate_left(u);
151            }
152            return _rotate_right(node);
153        }
154
155        int weight(NodePtr node) const {
156            return node ? node->size + 1 : 1;
157        }
158
159        NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
160            if (weight(l) * DELTA < weight(r)) {
161                r->propagate();
162                r->left = _merge_with_root(l, root, r->left);
163                r->update();
164                if (weight(r->right) * DELTA < weight(r->left)) {
165                    return _balance_right(r);
166                }
167                return r;
168            } else if (weight(r) * DELTA < weight(l)) {
169                l->propagate();
170                l->right = _merge_with_root(l->right, root, r);
171                l->update();
172                if (weight(l->left) * DELTA < weight(l->right)) {
173                    return _balance_left(l);
174                }
175                return l;
176            }
177            root->left = l;
178            root->right = r;
179            root->update();
180            return root;
181        }
182
183        pair<NodePtr, NodePtr> _pop_right(NodePtr node) {
184            return _split_node(node, node->size-1);
185        }
186
187        NodePtr _merge_node(NodePtr l, NodePtr r) {
188            if ((!l) && (!r)) {return nullptr;}
189            if (!l) {return r;}
190            if (!r) {return l;}
191            auto [l_, root_] = _pop_right(l);
192            return _merge_with_root(l_, root_, r);
193        }
194
195        pair<NodePtr, NodePtr> _split_node(NodePtr node, int k) {
196            if (!node) return {nullptr, nullptr};
197            node->propagate();
198            int tmp = node->left ? k-node->left->size : k;
199            NodePtr nl = node->left;
200            NodePtr nr = node->right;
201            if (tmp == 0) {
202                return {nl, _merge_with_root(nullptr, node, nr)};
203            } else if (tmp < 0) {
204                auto [l, r] = _split_node(nl, k);
205                return {l, _merge_with_root(r, node, nr)};
206            } else {
207                auto [l, r] = _split_node(nr, tmp-1);
208                return {_merge_with_root(node->left, node, l), r};
209            }
210        }
211
212        LazyWBTree(NodePtr &root): root(root) {}
213
214        MyLazyWBTree _new(NodePtr root) {
215            return MyLazyWBTree(root);
216        }
217
218    public:
219        NodePtr root;
220
221        LazyWBTree() : root(nullptr) {}
222
223        LazyWBTree(vector<T> const &a) { _build(a); }
224
225        void merge(MyLazyWBTree &other) {
226            this->root = _merge_node(this->root, other.root);
227        }
228
229        pair<MyLazyWBTree, MyLazyWBTree> split(const int k) {
230            auto [l, r] = _split_node(this->root, k);
231            return {_new(l), _new(r)};
232        }
233
234        void all_apply(const F f) {
235            if (!this->root) return;
236            this->root->key = mapping(f, this->root->key);
237            this->root->data = mapping(f, this->root->data);
238            this->root->lazy = composition(f, this->root->lazy);
239        }
240
241        void apply(const int l, const int r, const F f) {
242            assert(0 <= l && l <= r && r <= len());
243            if (l == r) return;
244            auto dfs = [&] (auto &&dfs, NodePtr node, int left, int right) -> void {
245                if (right <= l || r <= left) return;
246                node->propagate();
247                if (l <= left && right < r) {
248                    node->key = mapping(f, node->key);
249                    node->data = mapping(f, node->data);
250                    node->lazy = composition(f, node->lazy);
251                    return;
252                }
253                int lsize = node->left ? node->left->size : 0;
254                if (node->left) dfs(dfs, node->left, left, left+lsize);
255                if (l <= left+lsize && left+lsize < r) node->key = mapping(f, node->key);;
256                if (node->right) dfs(dfs, node->right, left+lsize+1, right);
257                node->update();
258            };
259            dfs(dfs, root, 0, len());
260        }
261
262        T all_prod() const {
263            return this->root ? this->root->data : e();
264        }
265
266        T prod(const int l, const int r) {
267            assert(0 <= l && l <= r && r <= len());
268            if (l == r) return e();
269            auto dfs = [&] (auto &&dfs, NodePtr node, int left, int right) -> T {
270                if (right <= l || r <= left) return e();
271                node->propagate();
272                if (l <= left && right < r) return node->data;
273                int lsize = node->left ? node->left->size : 0;
274                T res = e();
275                if (node->left) res = dfs(dfs, node->left, left, left+lsize);
276                if (l <= left+lsize && left+lsize < r) res = op(res, node->key);
277                if (node->right) res = op(res, dfs(dfs, node->right, left+lsize+1, right));
278                return res;
279            };
280            return dfs(dfs, root, 0, len());
281        }
282
283        void insert(int k, const T key) {
284            assert(0 <= k && k <= len());
285            auto [s, t] = _split_node(root, k);
286            NodePtr new_node = new Node(key, id());
287            this->root = _merge_with_root(s, new_node, t);
288        }
289
290        T pop(int k) {
291            assert(0 <= k && k < len());
292            stack<NodePtr> path;
293            stack<short> path_d;
294            NodePtr node = this->root;
295            T res;
296            int d = 0;
297            while (1) {
298                node->propagate();
299                int t = node->left? node->left->size: 0;
300                if (t == k) {
301                    res = node->key;
302                    break;
303                }
304                int d = (t < k)? 0: 1;
305                path.emplace(node);
306                path_d.emplace(d);
307                if (d) {
308                    node = node->left;
309                } else {
310                    k -= t + 1;
311                    node = node->right;
312                }
313            }
314            if (node->left && node->right) {
315                node->propagate();
316                NodePtr lmax = node->left;
317                path.emplace(node);
318                path_d.emplace(1);
319                lmax->propagate();
320                while (lmax->right) {
321                    path.emplace(lmax);
322                    path_d.emplace(0);
323                    lmax = lmax->right;
324                    lmax->propagate();
325                }
326                node->key = lmax->key;
327                node = lmax;
328            }
329            NodePtr cnode = (node->left) ? node->left : node->right;
330            if (!path.empty()) {
331                if (path_d.top()) path.top()->left = cnode;
332                else path.top()->right = cnode;
333            } else {
334                this->root = cnode;
335                return res;
336            }
337            while (!path.empty()) {
338                NodePtr new_node = nullptr;
339                node = path.top();
340                node->update();
341                path.pop();
342                path_d.pop();
343                if (weight(node->left) * DELTA < weight(node->right)) {
344                    new_node = _balance_left(node);
345                } else if (weight(node->right) * DELTA < weight(node->left)) {
346                    new_node = _balance_right(node);
347                }
348                if (new_node) {
349                    if (path.empty()) {
350                        this->root = new_node;
351                        break;
352                    }
353                    if (path_d.top()) {
354                        path.top()->left = new_node;
355                    } else {
356                        path.top()->right = new_node;
357                    }
358                }
359            }
360            return res;
361        }
362
363        void reverse(const int l, const int r) {
364            assert(0 <= l && l <= r && r <= len());
365            if (l == r) return;
366            auto [s_, t] = _split_node(root, r);
367            auto [u, s] = _split_node(s_, l);
368            s->rev ^= 1;
369            this->root = _merge_node(_merge_node(u, s), t);
370        }
371
372        vector<T> tovector() {
373            NodePtr node = root;
374            stack<NodePtr> s;
375            vector<T> a;
376            a.reserve(len());
377            while ((!s.empty()) || node) {
378                if (node) {
379                    node->propagate();
380                    s.emplace(node);
381                    node = node->left;
382                } else {
383                    node = s.top();
384                    s.pop();
385                    a.emplace_back(node->key);
386                    node = node->right;
387                }
388            }
389            return a;
390        }
391
392        void set(int k, T v) {
393            assert(0 <= k && k < len());
394            NodePtr node = root;
395            NodePtr root = node;
396            NodePtr pnode = nullptr;
397            int d = 0;
398            stack<NodePtr> path = {node};
399            while (1) {
400                node->propagate();
401                int t = node->left? node->left->size: 0;
402                if (t == k) {
403                    node->key = v;
404                    path.emplace(node);
405                    if (pnode) {
406                        if (d) pnode->left = node;
407                        else pnode->right = node;
408                    }
409                    while (!path.empty()) {
410                        path.top()->update();
411                        path.pop();
412                    }
413                    return;
414                }
415                pnode = node;
416                d = (t < k) ? 0 : 1;
417                if (d) {
418                    pnode->left = node = node->left;
419                } else {
420                    k -= t + 1;
421                    pnode->right = node = node->right;
422                }
423                path.emplace(node);
424            }
425        }
426
427        T get(int k) {
428            assert(0 <= k && k < len());
429            NodePtr node = root;
430            while (1) {
431                node->propagate();
432                int t = node->left ? node->left->size : 0;
433                if (t == k) {
434                    return node->key;
435                }
436                if (t < k) {
437                    k -= t + 1;
438                    node = node->right;
439                } else {
440                    node = node->left;
441                }
442            }
443        }
444
445        void print() {
446            vector<T> a = tovector();
447            cout << "[";
448            for (int i = 0; i < (int)a.size()-1; ++i) {
449                cout << a[i] << ", ";
450            }
451            if (!a.empty()) cout << a.back();
452            cout << "]" << endl;
453        }
454
455        friend ostream& operator<<(ostream& os, LazyWBTree<T, F, op, mapping, composition, e, id> &tree) {
456            vector<T> a = tree.tovector();
457            os << "[";
458            for (int i = 0; i < (int)a.size()-1; ++i) {
459                os << a[i] << ", ";
460            }
461            if (!a.empty()) os << a.back();
462            os << "]";
463            return os;
464        }
465
466        int len() const {
467            return root ? root->size : 0;
468        }
469
470        void check() const {
471            auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
472                int ls = 0, rs = 0;
473                int height = 0;
474                int h;
475                if (node->left) {
476                    pair<int, int> res = rec(rec, node->left);
477                    ls = res.first;
478                    h = res.second;
479                    height = max(height, h);
480                }
481                if (node->right) {
482                    pair<int, int> res = rec(rec, node->right);
483                    rs = res.first;
484                    h = res.second;
485                    height = max(height, h);
486                }
487                node->balance_check();
488                int s = ls + rs + 1;
489                assert(s == node->size);
490                return {s, height+1};
491            };
492            if (root == nullptr) return;
493            auto [_, h] = rec(rec, root);
494            cerr << PRINT_GREEN << "OK : height=" << h << PRINT_NONE << endl;
495        }
496    };
497} // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/lazy_wb_tree.cpp”