wb tree seg

ソースコード

#include <algorithm>
#include <cassert>
#include <iostream>
#include <stack>
#include <utility>
#include <vector>
#include "titan_cpplib/others/print.cpp"
using namespace std;
  7
  8namespace titan23 {
  9
 10    template <class T,
 11              T (*op)(T, T),
 12              T (*e)()>
 13    class WBSegTree {
 14      private:
 15        class Node;
 16        using NodePtr = Node*;
 17        using MyLazyWBTree = WBSegTree<T, op, e>;
 18        static constexpr int DELTA = 3;
 19        static constexpr int GAMMA = 2;
 20
 21        class Node {
 22          public:
 23            NodePtr left;
 24            NodePtr right;
 25            T key, data;
 26            int size;
 27
 28            Node(T key) : left(nullptr), right(nullptr), key(key), data(key), size(1) {}
 29
 30            int weight_right() const {
 31                return right ? right->size + 1 : 1;
 32            }
 33
 34            int weight_left() const {
 35                return left ? left->size + 1 : 1;
 36            }
 37
 38            void balance_check() const {
 39                if (!weight_left()*DELTA >= weight_right()) {
 40                    cerr << weight_left() << ", " << weight_right() << endl;
 41                    cerr << "not weight_left()*DELTA >= weight_right()." << endl;
 42                    assert(false);
 43                }
 44                if (!weight_right() * DELTA >= weight_left()) {
 45                    cerr << weight_left() << ", " << weight_right() << endl;
 46                    cerr << "not weight_right() * DELTA >= weight_left()." << endl;
 47                    assert(false);
 48                }
 49            }
 50
 51            void update() {
 52                size = 1;
 53                data = key;
 54                if (left) {
 55                    size += left->size;
 56                    data = op(left->data, key);
 57                }
 58                if (right) {
 59                    size += right->size;
 60                    data = op(data, right->data);
 61                }
 62            }
 63        };
 64
 65        void _build(vector<T> const &a) {
 66            auto build = [&] (auto &&build, int l, int r) -> NodePtr {
 67                int mid = (l + r) >> 1;
 68                NodePtr node = new Node(a[mid]);
 69                if (l != mid) node->left = build(build, l, mid);
 70                if (mid+1 != r) node->right = build(build, mid+1, r);
 71                node->update();
 72                return node;
 73            };
 74            if (a.empty()) {
 75                this->root = nullptr;
 76                return;
 77            }
 78            this->root = build(build, 0, (int)a.size());
 79        }
 80
 81        NodePtr _rotate_right(NodePtr node) {
 82            NodePtr u = node->left;
 83            node->left = u->right;
 84            u->right = node;
 85            node->update();
 86            u->update();
 87            return u;
 88        }
 89
 90        NodePtr _rotate_left(NodePtr node) {
 91            NodePtr u = node->right;
 92            node->right = u->left;
 93            u->left = node;
 94            node->update();
 95            u->update();
 96            return u;
 97        }
 98
 99        NodePtr _balance_left(NodePtr node) {
100            NodePtr u = node->right;
101            if (node->right->weight_left() >= node->right->weight_right() * GAMMA) {
102                node->right = _rotate_right(u);
103            }
104            return _rotate_left(node);
105        }
106
107        NodePtr _balance_right(NodePtr node) {
108            NodePtr u = node->left;
109            if (node->left->weight_right() >= node->left->weight_left() * GAMMA) {
110                node->left = _rotate_left(u);
111            }
112            return _rotate_right(node);
113        }
114
115        int weight(NodePtr node) const {
116            return node ? node->size + 1 : 1;
117        }
118
119        NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
120            if (weight(l) * DELTA < weight(r)) {
121                r->left = _merge_with_root(l, root, r->left);
122                r->update();
123                if (weight(r->right) * DELTA < weight(r->left)) {
124                    return _balance_right(r);
125                }
126                return r;
127            } else if (weight(r) * DELTA < weight(l)) {
128                l->right = _merge_with_root(l->right, root, r);
129                l->update();
130                if (weight(l->left) * DELTA < weight(l->right)) {
131                    return _balance_left(l);
132                }
133                return l;
134            }
135            root->left = l;
136            root->right = r;
137            root->update();
138            return root;
139        }
140
141        pair<NodePtr, NodePtr> _pop_right(NodePtr node) {
142            return _split_node(node, node->size-1);
143        }
144
145        NodePtr _merge_node(NodePtr l, NodePtr r) {
146            if ((!l) && (!r)) {return nullptr;}
147            if (!l) {return r;}
148            if (!r) {return l;}
149            auto [l_, root_] = _pop_right(l);
150            return _merge_with_root(l_, root_, r);
151        }
152
153        pair<NodePtr, NodePtr> _split_node(NodePtr node, int k) {
154            if (!node) return {nullptr, nullptr};
155            int tmp = node->left ? k-node->left->size : k;
156            NodePtr nl = node->left;
157            NodePtr nr = node->right;
158            if (tmp == 0) {
159                return {nl, _merge_with_root(nullptr, node, nr)};
160            } else if (tmp < 0) {
161                auto [l, r] = _split_node(nl, k);
162                return {l, _merge_with_root(r, node, nr)};
163            } else {
164                auto [l, r] = _split_node(nr, tmp-1);
165                return {_merge_with_root(node->left, node, l), r};
166            }
167        }
168
169        WBSegTree(NodePtr &root): root(root) {}
170
171        MyLazyWBTree _new(NodePtr root) {
172            return MyLazyWBTree(root);
173        }
174
175    public:
176        NodePtr root;
177
178        WBSegTree() : root(nullptr) {}
179
180        WBSegTree(vector<T> const &a) { _build(a); }
181
182        void merge(MyLazyWBTree &other) {
183            this->root = _merge_node(this->root, other.root);
184        }
185
186        pair<MyLazyWBTree, MyLazyWBTree> split(const int k) {
187            auto [l, r] = _split_node(this->root, k);
188            return {_new(l), _new(r)};
189        }
190
191        T all_prod() const {
192            return this->root ? this->root->data : e();
193        }
194
195        T prod(const int l, const int r) {
196            assert(0 <= l && l <= r && r <= len());
197            if (l == r) return e();
198            auto dfs = [&] (auto &&dfs, NodePtr node, int left, int right) -> T {
199                if (right <= l || r <= left) return e();
200                if (l <= left && right < r) return node->data;
201                int lsize = node->left ? node->left->size : 0;
202                T res = e();
203                if (node->left) res = dfs(dfs, node->left, left, left+lsize);
204                if (l <= left+lsize && left+lsize < r) res = op(res, node->key);
205                if (node->right) res = op(res, dfs(dfs, node->right, left+lsize+1, right));
206                return res;
207            };
208            return dfs(dfs, root, 0, len());
209        }
210
211        void insert(int k, const T key) {
212            assert(0 <= k && k <= len());
213            auto [s, t] = _split_node(root, k);
214            NodePtr new_node = new Node(key);
215            this->root = _merge_with_root(s, new_node, t);
216        }
217
218        T pop(int k) {
219            assert(0 <= k && k < len());
220            stack<NodePtr> path;
221            stack<short> path_d;
222            NodePtr node = this->root;
223            T res;
224            int d = 0;
225            while (1) {
226                int t = node->left? node->left->size: 0;
227                if (t == k) {
228                    res = node->key;
229                    break;
230                }
231                int d = (t < k)? 0: 1;
232                path.emplace(node);
233                path_d.emplace(d);
234                if (d) {
235                    node = node->left;
236                } else {
237                    k -= t + 1;
238                    node = node->right;
239                }
240            }
241            if (node->left && node->right) {
242                NodePtr lmax = node->left;
243                path.emplace(node);
244                path_d.emplace(1);
245                while (lmax->right) {
246                    path.emplace(lmax);
247                    path_d.emplace(0);
248                    lmax = lmax->right;
249                }
250                node->key = lmax->key;
251                node = lmax;
252            }
253            NodePtr cnode = (node->left) ? node->left : node->right;
254            if (!path.empty()) {
255                if (path_d.top()) path.top()->left = cnode;
256                else path.top()->right = cnode;
257            } else {
258                this->root = cnode;
259                return res;
260            }
261            while (!path.empty()) {
262                NodePtr new_node = nullptr;
263                node = path.top();
264                node->update();
265                path.pop();
266                path_d.pop();
267                if (weight(node->left) * DELTA < weight(node->right)) {
268                    new_node = _balance_left(node);
269                } else if (weight(node->right) * DELTA < weight(node->left)) {
270                    new_node = _balance_right(node);
271                }
272                if (new_node) {
273                    if (path.empty()) {
274                        this->root = new_node;
275                        break;
276                    }
277                    if (path_d.top()) {
278                        path.top()->left = new_node;
279                    } else {
280                        path.top()->right = new_node;
281                    }
282                }
283            }
284            return res;
285        }
286
287        vector<T> tovector() {
288            NodePtr node = root;
289            stack<NodePtr> s;
290            vector<T> a;
291            a.reserve(len());
292            while ((!s.empty()) || node) {
293                if (node) {
294                    s.emplace(node);
295                    node = node->left;
296                } else {
297                    node = s.top();
298                    s.pop();
299                    a.emplace_back(node->key);
300                    node = node->right;
301                }
302            }
303            return a;
304        }
305
306        void set(int k, T v) {
307            assert(0 <= k && k < len());
308            NodePtr node = root;
309            NodePtr root = node;
310            NodePtr pnode = nullptr;
311            int d = 0;
312            stack<NodePtr> path;
313            path.emplace(node);
314            while (1) {
315                int t = node->left ? node->left->size : 0;
316                if (t == k) {
317                    node->key = v;
318                    path.emplace(node);
319                    if (pnode) {
320                        if (d) pnode->left = node;
321                        else pnode->right = node;
322                    }
323                    while (!path.empty()) {
324                        path.top()->update();
325                        path.pop();
326                    }
327                    return;
328                }
329                pnode = node;
330                d = (t < k) ? 0 : 1;
331                if (d) {
332                    pnode->left = node = node->left;
333                } else {
334                    k -= t + 1;
335                    pnode->right = node = node->right;
336                }
337                path.emplace(node);
338            }
339        }
340
341        T get(int k) {
342            assert(0 <= k && k < len());
343            NodePtr node = root;
344            while (1) {
345                int t = node->left ? node->left->size : 0;
346                if (t == k) {
347                    return node->key;
348                }
349                if (t < k) {
350                    k -= t + 1;
351                    node = node->right;
352                } else {
353                    node = node->left;
354                }
355            }
356        }
357
358        void print() {
359            vector<T> a = tovector();
360            cout << "[";
361            for (int i = 0; i < (int)a.size()-1; ++i) {
362                cout << a[i] << ", ";
363            }
364            if (!a.empty()) cout << a.back();
365            cout << "]" << endl;
366        }
367
368        friend ostream& operator<<(ostream& os, WBSegTree<T, op, e> &tree) {
369            vector<T> a = tree.tovector();
370            os << "[";
371            for (int i = 0; i < (int)a.size()-1; ++i) {
372                os << a[i] << ", ";
373            }
374            if (!a.empty()) os << a.back();
375            os << "]";
376            return os;
377        }
378
379        int len() const {
380            return root ? root->size : 0;
381        }
382
383        void check() const {
384            auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
385                int ls = 0, rs = 0;
386                int height = 0;
387                int h;
388                if (node->left) {
389                    pair<int, int> res = rec(rec, node->left);
390                    ls = res.first;
391                    h = res.second;
392                    height = max(height, h);
393                }
394                if (node->right) {
395                    pair<int, int> res = rec(rec, node->right);
396                    rs = res.first;
397                    h = res.second;
398                    height = max(height, h);
399                }
400                node->balance_check();
401                int s = ls + rs + 1;
402                assert(s == node->size);
403                return {s, height+1};
404            };
405            if (root == nullptr) return;
406            auto [_, h] = rec(rec, root);
407            cerr << PRINT_GREEN << "OK : height=" << h << PRINT_NONE << endl;
408        }
409    };
410} // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/wb_tree_seg.cpp”