Persistent Lazy WB Tree (persistent lazy weight-balanced tree)

ソースコード

  1#include <iostream>
  2#include <vector>
  3#include <cmath>
  4#include <cassert>
  5#include <stack>
  6#include <memory>
  7#include "titan_cpplib/others/print.cpp"
  8using namespace std;
  9
 10namespace titan23 {
 11
 12template <class T,
 13            class F,
 14            T (*op)(T, T),
 15            T (*mapping)(F, T),
 16            F (*composition)(F, F),
 17            T (*e)(),
 18            F (*id)()>
 19class PersistentLazyWBTree {
 20  private:
 21    class Node;
 22    using NodePtr = shared_ptr<Node>;
 23    // using NodePtr = Node*;
 24    using MyPersistentLazyWBTree = PersistentLazyWBTree<T, F, op, mapping, composition, e, id>;
 25    static constexpr int DELTA = 3;
 26    static constexpr int GAMMA = 2;
 27    NodePtr root;
 28
 29    class Node {
 30      public:
 31        T key, data;
 32        NodePtr left;
 33        NodePtr right;
 34        F lazy;
 35        int rev, size;
 36
 37        Node(T key, F lazy) : key(key), data(key), left(nullptr), right(nullptr), lazy(lazy), rev(0), size(1) {}
 38
 39        NodePtr copy() const {
 40            NodePtr node = make_shared<Node>(key, lazy);
 41            // NodePtr node = new Node(key, lazy);
 42            node->data = data;
 43            node->left = left;
 44            node->right = right;
 45            node->rev = rev;
 46            node->size = size;
 47            return node;
 48        }
 49
 50        int weight_right() const {
 51            return right ? right->size + 1 : 1;
 52        }
 53
 54        int weight_left() const {
 55            return left ? left->size + 1 : 1;
 56        }
 57
 58        void propagate() {
 59            if (rev) {
 60                NodePtr l = left  ? left->copy()  : nullptr;
 61                NodePtr r = right ? right->copy() : nullptr;
 62                left = r;
 63                right = l;
 64                if (left) left->rev ^= 1;
 65                if (right) right->rev ^= 1;
 66                rev = 0;
 67            }
 68            if (lazy != id()) {
 69                if (left) {
 70                    left = left->copy();
 71                    left->key = mapping(lazy, left->key);
 72                    left->data = mapping(lazy, left->data);
 73                    left->lazy = composition(lazy, left->lazy);
 74                }
 75                if (right) {
 76                    right = right->copy();
 77                    right->key = mapping(lazy, right->key);
 78                    right->data = mapping(lazy, right->data);
 79                    right->lazy = composition(lazy, right->lazy);
 80                }
 81                lazy = id();
 82            }
 83        }
 84
 85        void update() {
 86            size = 1;
 87            data = key;
 88            if (left) {
 89                size += left->size;
 90                data = op(left->data, key);
 91            }
 92            if (right) {
 93                size += right->size;
 94                data = op(data, right->data);
 95            }
 96        }
 97
 98        void balance_check() const {
 99            if (!weight_left()*DELTA >= weight_right()) {
100                cerr << weight_left() << ", " << weight_right() << endl;
101                cerr << "not weight_left()*DELTA >= weight_right()." << endl;
102                assert(false);
103            }
104            if (!weight_right() * DELTA >= weight_left()) {
105                cerr << weight_left() << ", " << weight_right() << endl;
106                cerr << "not weight_right() * DELTA >= weight_left()." << endl;
107                assert(false);
108            }
109        }
110
111        void print() const {
112            cout << "this : key=" << key << ", size=" << size << endl;
113            if (left)  cout << "to-left" << endl;
114            if (right) cout << "to-right" << endl;
115            cout << endl;
116            if (left)  left->print();
117            if (right) right->print();
118        }
119    };
120
121    void _build(vector<T> const &a) {
122        auto build = [&] (auto &&build, int l, int r) -> NodePtr {
123            int mid = (l + r) >> 1;
124            NodePtr node = make_shared<Node>(a[mid], id());
125            // NodePtr node = new Node(a[mid], id());
126            if (l != mid) node->left = build(build, l, mid);
127            if (mid+1 != r) node->right = build(build, mid+1, r);
128            node->update();
129            return node;
130        };
131        root = build(build, 0, (int)a.size());
132    }
133
134    NodePtr _rotate_right(NodePtr &node) {
135        NodePtr u = node->left->copy();
136        node->left = u->right;
137        u->right = node;
138        node->update();
139        u->update();
140        return u;
141    }
142
143    NodePtr _rotate_left(NodePtr &node) {
144        NodePtr u = node->right->copy();
145        node->right = u->left;
146        u->left = node;
147        node->update();
148        u->update();
149        return u;
150    }
151
152    NodePtr _balance_left(NodePtr &node) {
153        node->right->propagate();
154        node->right = node->right->copy();
155        NodePtr u = node->right;
156        if (node->right->weight_left() >= node->right->weight_right() * GAMMA) {
157            u->left->propagate();
158            node->right = _rotate_right(u);
159        }
160        u = _rotate_left(node);
161        return u;
162    }
163
164    NodePtr _balance_right(NodePtr &node) {
165        node->left->propagate();
166        node->left = node->left->copy();
167        NodePtr u = node->left;
168        if (node->left->weight_right() >= node->left->weight_left() * GAMMA) {
169            u->right->propagate();
170            node->left = _rotate_left(u);
171        }
172        u = _rotate_right(node);
173        return u;
174    }
175
176    int weight(NodePtr node) const {
177        return node ? node->size + 1 : 1;
178    }
179
180    NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
181        if (weight(r) * DELTA < weight(l)) {
182            l->propagate();
183            l = l->copy();
184            l->right = _merge_with_root(l->right, root, r);
185            l->update();
186            if (weight(l->left) * DELTA < weight(l->right)) {
187                return _balance_left(l);
188            }
189            return l;
190        } else if (weight(l) * DELTA < weight(r)) {
191            r->propagate();
192            r = r->copy();
193            r->left = _merge_with_root(l, root, r->left);
194            r->update();
195            if (weight(r->right) * DELTA < weight(r->left)) {
196                return _balance_right(r);
197            }
198            return r;
199        }
200        root = root->copy();
201        root->left = l;
202        root->right = r;
203        root->update();
204        return root;
205    }
206
207    pair<NodePtr, NodePtr> _pop_right(NodePtr &node) {
208        return _split_node(node, node->size-1);
209    }
210
211    NodePtr _merge_node(NodePtr l, NodePtr r) {
212        if ((!l) && (!r)) { return nullptr; }
213        if (!l) { return r->copy(); }
214        if (!r) { return l->copy(); }
215        l = l->copy();
216        r = r->copy();
217        auto [l_, root_] = _pop_right(l);
218        return _merge_with_root(l_, root_, r);
219    }
220
221    pair<NodePtr, NodePtr> _split_node(NodePtr &node, int k) {
222        if (!node) {return {nullptr, nullptr};}
223        node->propagate();
224        int tmp = node->left ? k-node->left->size : k;
225        if (tmp == 0) {
226            return {node->left, _merge_with_root(nullptr, node, node->right)};
227        } else if (tmp < 0) {
228            auto [l, r] = _split_node(node->left, k);
229            return {l, _merge_with_root(r, node, node->right)};
230        } else {
231            auto [l, r] = _split_node(node->right, tmp-1);
232            return {_merge_with_root(node->left, node, l), r};
233        }
234    }
235
236    MyPersistentLazyWBTree _new(NodePtr root) {
237        return MyPersistentLazyWBTree(root);
238    }
239
240    PersistentLazyWBTree(NodePtr root) : root(root) {}
241
242  public:
243    PersistentLazyWBTree() : root(nullptr) {}
244
245    PersistentLazyWBTree(vector<T> &a) { _build(a); }
246
247    MyPersistentLazyWBTree merge(MyPersistentLazyWBTree other) {
248        NodePtr root = _merge_node(this->root, other.root);
249        return _new(root);
250    }
251
252    pair<MyPersistentLazyWBTree, MyPersistentLazyWBTree> split(int k) {
253        auto [l, r] = _split_node(this->root, k);
254        return {_new(l), _new(r)};
255    }
256
257    MyPersistentLazyWBTree apply(int l, int r, F f) {
258        if (l >= r) return _new(this->root ? root->copy() : nullptr);
259        auto dfs = [&] (auto &&dfs, NodePtr node, int left, int right) -> NodePtr {
260            if (right <= l || r <= left) return node;
261            node->propagate();
262            NodePtr nnode = node->copy();
263            if (l <= left && right < r) {
264                nnode->key = mapping(f, nnode->key);
265                nnode->data = mapping(f, nnode->data);
266                nnode->lazy = composition(f, nnode->lazy);
267                return nnode;
268            }
269            int lsize = nnode->left ? nnode->left->size : 0;
270            if (nnode->left) nnode->left = dfs(dfs, nnode->left, left, left+lsize);
271            if (l <= left+lsize && left+lsize < r) nnode->key = mapping(f, nnode->key);;
272            if (nnode->right) nnode->right = dfs(dfs, nnode->right, left+lsize+1, right);
273            nnode->update();
274            return nnode;
275        };
276        return _new(dfs(dfs, root, 0, len()));
277    }
278
279    T prod(int l, int r) {
280        assert(0 <= l && l <= r && r <= len());
281        if (l == r) return e();
282        auto dfs = [&] (auto &&dfs, NodePtr node, int left, int right) -> T {
283            if (right <= l || r <= left) return e();
284            node->propagate();
285            if (l <= left && right < r) return node->data;
286            int lsize = node->left ? node->left->size : 0;
287            T res = e();
288            if (node->left) res = dfs(dfs, node->left, left, left+lsize);
289            if (l <= left+lsize && left+lsize < r) res = op(res, node->key);
290            if (node->right) res = op(res, dfs(dfs, node->right, left+lsize+1, right));
291            return res;
292        };
293        return dfs(dfs, root, 0, len());
294    }
295
296    MyPersistentLazyWBTree insert(int k, T key) {
297        assert(0 <= k && k <= len());
298        auto [s, t] = _split_node(root, k);
299        NodePtr new_node = make_shared<Node>(key, id());
300        return _new(_merge_with_root(s, new_node, t));
301    }
302
303    pair<MyPersistentLazyWBTree, T> pop(int k) {
304        assert(0 <= k && k < len());
305        auto [s_, t] = _split_node(this->root, k+1);
306        auto [s, tmp] = _pop_right(s_);
307        NodePtr root = _merge_node(s, t);
308        return {_new(root), tmp->key};
309    }
310
311    MyPersistentLazyWBTree reverse(int l, int r) {
312        assert(0 <= l && l <= r && r <= len());
313        if (l >= r) return _new(root ? root->copy() : nullptr);
314        auto [s_, t] = _split_node(root, r);
315        auto [u, s] = _split_node(s_, l);
316        s->rev ^= 1;
317        NodePtr root = _merge_node(_merge_node(u, s), t);
318        return _new(root);
319    }
320
321    vector<T> tovector() {
322        NodePtr node = root;
323        stack<NodePtr> s;
324        vector<T> a;
325        a.reserve(len());
326        while (!s.empty() || node) {
327            if (node) {
328                node->propagate();
329                s.emplace(node);
330                node = node->left;
331            } else {
332                node = s.top(); s.pop();
333                a.emplace_back(node->key);
334                node = node->right;
335            }
336        }
337        return a;
338    }
339
340    MyPersistentLazyWBTree copy() const {
341        return _new(root ? root->copy() : nullptr);
342    }
343
344    MyPersistentLazyWBTree set(int k, T v) {
345        assert(0 <= k && k < len());
346        NodePtr node = root->copy();
347        NodePtr root = node;
348        NodePtr pnode = nullptr;
349        int d = 0;
350        stack<NodePtr> path = {node};
351        while (1) {
352            node->propagate();
353            int t = node->left ? node->left->size : 0;
354            if (t == k) {
355                node = node->copy();
356                node->key = v;
357                path.emplace(node);
358                if (d) pnode->left = node;
359                else pnode->right = node;
360                while (!path.empty()) {
361                    path.top()->update();
362                    path.pop();
363                }
364                return _new(root);
365            }
366            pnode = node;
367            if (t < k) {
368                k -= t + 1;
369                node = node->right->copy();
370                d = 0;
371            } else {
372                d = 1;
373                node = node->left->copy();
374            }
375            path.emplace_back(node);
376            if (d) pnode->left = node;
377            else pnode->right = node;
378        }
379    }
380
381    T get(int k) {
382        assert(0 <= k && k < len());
383        NodePtr node = root;
384        while (1) {
385            node->propagate();
386            int t = node->left ? node->left->size : 0;
387            if (t == k) {
388                return node->key;
389            }
390            if (t < k) {
391                k -= t + 1;
392                node = node->right;
393            } else {
394                node = node->left;
395            }
396        }
397    }
398
399    void print() {
400        vector<T> a = tovector();
401        cout << "[";
402        for (int i = 0; i < (int)a.size()-1; ++i) {
403            cout << a[i] << ", ";
404        }
405        if (!a.empty()) cout << a.back();
406        cout << "]" << endl;
407    }
408
409    int len() const {
410        return root ? root->size : 0;
411    }
412
413    void check() const {
414        auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
415            int ls = 0, rs = 0;
416            int height = 0;
417            int h;
418            if (node->left) {
419                pair<int, int> res = rec(rec, node->left);
420                ls = res.first;
421                h = res.second;
422                height = max(height, h);
423            }
424            if (node->right) {
425                pair<int, int> res = rec(rec, node->right);
426                rs = res.first;
427                h = res.second;
428                height = max(height, h);
429            }
430            node->balance_check();
431            int s = ls + rs + 1;
432            assert(s == node->size);
433            return {s, height+1};
434        };
435        if (root == nullptr) return;
436        auto [_, h] = rec(rec, root);
437        cerr << PRINT_GREEN << "OK : height=" << h << PRINT_NONE << endl;
438    }
439
440    friend ostream& operator<<(ostream& os, PersistentLazyWBTree<T, F, op, mapping, composition, e, id> &tree) {
441        vector<T> a = tree.tovector();
442        os << "[";
443        for (int i = 0; i < (int)a.size()-1; ++i) {
444            os << a[i] << ", ";
445        }
446        if (!a.empty()) os << a.back();
447        os << "]";
448        return os;
449    }
450};
451}  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/persistent_lazy_wb_tree.cpp”