WB Tree (weight-balanced tree)

ソースコード

  1#include <iostream>
  2#include <vector>
  3#include <cmath>
  4#include <cassert>
  5#include <memory>
  6using namespace std;
  7
  8namespace titan23 {
  9    const double ALPHA = 1 - sqrt(2) / 2;
 10    const double BETA = (1 - 2 * ALPHA) / (1 - ALPHA);
 11
 12    template <class T>
 13    class WBTree {
 14
 15      private:
 16        class Node;
 17        using NodePtr = Node*;
 18        using MyWBTree = WBTree<T>;
 19
 20        class Node {
 21          public:
 22            T key;
 23            NodePtr left, right;
 24            int size;
 25
 26            Node(T key) : key(key), left(nullptr), right(nullptr), size(1) {}
 27
 28            double balance() const {
 29                return ((left ? left->size : 0) + 1.0) / (size + 1.0);
 30            }
 31
 32            void _update() {
 33                size = 1;
 34                if (left) {
 35                    size += left->size;
 36                }
 37                if (right) {
 38                    size += right->size;
 39                }
 40            }
 41        };
 42
 43        void _build(vector<T> const &a) {
 44            if (a.empty()) {
 45                root = nullptr;
 46                return;
 47            }
 48            auto build = [&] (auto &&build, int l, int r) -> NodePtr {
 49                int mid = (l + r) >> 1;
 50                NodePtr node = new Node(a[mid]);
 51                if (l != mid) node->left = build(build, l, mid);
 52                if (mid+1 != r) node->right = build(build, mid+1, r);
 53                node->_update();
 54                return node;
 55            };
 56            root = build(build, 0, (int)a.size());
 57        }
 58
 59        NodePtr _rotate_right(NodePtr node) {
 60            NodePtr u = node->left;
 61            node->left = u->right;
 62            u->right = node;
 63            node->_update();
 64            u->_update();
 65            return u;
 66        }
 67
 68        NodePtr _rotate_left(NodePtr node) {
 69            NodePtr u = node->right;
 70            node->right = u->left;
 71            u->left = node;
 72            node->_update();
 73            u->_update();
 74            return u;
 75        }
 76
 77        NodePtr _balance_left(NodePtr node) {
 78            NodePtr u = node->right;
 79            if (u->balance() >= BETA) {
 80                node->right = _rotate_right(u);
 81            }
 82            return _rotate_left(node);
 83        }
 84
 85        NodePtr _balance_right(NodePtr node) {
 86            NodePtr u = node->left;
 87            if (u->balance() <= 1.0 - BETA) {
 88                node->left = _rotate_left(u);
 89            }
 90            return _rotate_right(node);
 91        }
 92
 93        NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
 94            int ls = l ? l->size : 0;
 95            int rs = r ? r->size : 0;
 96            double diff = (double)(ls+1.0) / (ls+rs+2.0);
 97            if (diff > 1.0-ALPHA) {
 98                l->right = _merge_with_root(l->right, root, r);
 99                l->_update();
100                if (!(ALPHA <= l->balance() && l->balance() <= 1.0-ALPHA)) {
101                    return _balance_left(l);
102                }
103                return l;
104            }
105            if (diff < ALPHA) {
106                r->left = _merge_with_root(l, root, r->left);
107                r->_update();
108                if (!(ALPHA <= r->balance() && r->balance() <= 1.0-ALPHA)) {
109                    return _balance_right(r);
110                }
111                return r;
112            }
113            root->left = l;
114            root->right = r;
115            root->_update();
116            return root;
117        }
118
119        pair<NodePtr, NodePtr> _pop_right(NodePtr node) {
120            vector<NodePtr> path;
121            NodePtr mx = node;
122            while (node->right) {
123                path.emplace_back(node);
124                node = node->right;
125                mx = node;
126            }
127            path.emplace_back(node->left? node->left: nullptr);
128            while ((int)path.size() > 1) {
129                    node = path.back();
130                    path.pop_back();
131                    if (!node) {
132                    path.back()->right = nullptr;
133                    path.back()->_update();
134                    continue;
135                }
136                double b = node->balance();
137                if (ALPHA <= b && b <= 1.0-ALPHA) {
138                    path.back()->right = node;
139                } else if (b > 1.0-ALPHA) {
140                    path.back()->right = _balance_right(node);
141                } else {
142                    path.back()->right = _balance_left(node);
143                }
144                path.back()->_update();
145            }
146            if (path[0]) {
147                double b = path[0]->balance();
148                if (b > 1.0-ALPHA) {
149                    path[0] = _balance_right(path[0]);
150                } else if (b < ALPHA) {
151                    path[0] = _balance_left(path[0]);
152                }
153            }
154            mx->left = nullptr;
155            mx->_update();
156            return {path[0], mx};
157        }
158
159        NodePtr _merge_node(NodePtr l, NodePtr r) {
160            if ((!l) && (!r)) {return nullptr;}
161            if (!l) {return r;}
162            if (!r) {return l;}
163            auto [l_, root_] = _pop_right(l);
164            return _merge_with_root(l_, root_, r);
165        }
166
167        pair<NodePtr, NodePtr> _split_node(NodePtr node, int k) {
168            if (!node) return {nullptr, nullptr};
169            int tmp = node->left? k-node->left->size: k;
170            NodePtr nl = node->left;
171            NodePtr nr = node->right;
172            if (tmp == 0) {
173                return {nl, _merge_with_root(nullptr, node, nr)};
174            } else if (tmp < 0) {
175                auto [l, r] = _split_node(nl, k);
176                return {l, _merge_with_root(r, node, nr)};
177            } else {
178                auto [l, r] = _split_node(nr, tmp-1);
179                return {_merge_with_root(node->left, node, l), r};
180            }
181        }
182
183        WBTree<T> _new(NodePtr root) {
184            WBTree<T> p(root);
185            return p;
186        }
187
188        WBTree(NodePtr &root): root(root) {}
189
190      public:
191        NodePtr root;
192
193        WBTree() : root(nullptr) {}
194
195        WBTree(vector<T> const &a) { _build(a); }
196
197        void build(vector<T> a) {
198            if (a.empty()) {
199                this->root = nullptr;
200                return;
201            }
202            auto _build = [&] (auto &&_build, int l, int r) -> NodePtr {
203                int mid = (l + r) >> 1;
204                NodePtr node = new Node(a[mid]);
205                if (l != mid) node->left = _build(_build, l, mid);
206                if (mid+1 != r) node->right = _build(_build, mid+1, r);
207                node->_update();
208                return node;
209            };
210            this->root = _build(_build, 0, (int)a.size());
211        }
212
213        void merge(MyWBTree &other) {
214            this->root = _merge_node(this->root, other.root);
215        }
216
217        pair<MyWBTree, MyWBTree> split(const int k) {
218            auto [l, r] = _split_node(this->root, k);
219            return {_new(l), _new(r)};
220        }
221
222        void insert(int k, const T key) {
223            assert(0 <= k && k <= len());
224            auto [s, t] = _split_node(root, k);
225            NodePtr new_node = new Node(key);
226            this->root = _merge_with_root(s, new_node, t);
227        }
228
229        void emplace_back(T key) {
230            insert(len(), key);
231        }
232
233        T pop(int k) {
234            if (k < 0) k += len();
235            assert(0 <= k && k < len());
236            vector<NodePtr> path;
237            vector<short> path_d;
238            NodePtr node = this->root;
239            T res;
240            int d = 0;
241            while (1) {
242                int t = node->left? node->left->size: 0;
243                if (t == k) {
244                    res = node->key;
245                    break;
246                }
247                int d = (t < k)? 0: 1;
248                path.emplace_back(node);
249                path_d.emplace_back(d);
250                if (d) {
251                    node = node->left;
252                } else {
253                    k -= t + 1;
254                    node = node->right;
255                }
256            }
257            if (node->left && node->right) {
258                NodePtr lmax = node->left;
259                path.emplace_back(node);
260                path_d.emplace_back(1);
261                while (lmax->right) {
262                    path.emplace_back(lmax);
263                    path_d.emplace_back(0);
264                    lmax = lmax->right;
265                }
266                node->key = lmax->key;
267                node = lmax;
268            }
269            NodePtr cnode = (node->left)? node->left: node->right;
270            if (!path.empty()) {
271                if (path_d.back()) path.back()->left = cnode;
272                else path.back()->right = cnode;
273            } else {
274                this->root = cnode;
275                return res;
276            }
277            while (!path.empty()) {
278                NodePtr new_node = nullptr;
279                node = path.back();
280                node->_update();
281                path.pop_back();
282                path_d.pop_back();
283                double b = node->balance();
284                if (b < ALPHA) {
285                    new_node = _balance_left(node);
286                } else if (b > 1.0-ALPHA) {
287                    new_node = _balance_right(node);
288                }
289                if (new_node) {
290                    if (path.empty()) {
291                        this->root = new_node;
292                        break;
293                    }
294                    if (path_d.back()) {
295                        path.back()->left = new_node;
296                    } else {
297                        path.back()->right = new_node;
298                    }
299                }
300            }
301            return res;
302        }
303
304        vector<T> tovector() {
305            NodePtr node = root;
306            vector<NodePtr> stack;
307            vector<T> a;
308            a.reserve(len());
309            while ((!stack.empty()) || node) {
310                if (node) {
311                stack.emplace_back(node);
312                node = node->left;
313                } else {
314                node = stack.back();
315                stack.pop_back();
316                a.emplace_back(node->key);
317                node = node->right;
318                }
319            }
320            return a;
321        }
322
323        void set(int k, T v) {
324            assert(0 <= k && k < len());
325            NodePtr node = root;
326            NodePtr root = node;
327            NodePtr pnode = nullptr;
328            int d = 0;
329            vector<NodePtr> path = {node};
330            while (1) {
331                int t = node->left? node->left->size: 0;
332                if (t == k) {
333                    node->key = v;
334                    path.emplace_back(node);
335                    if (pnode) {
336                        if (d) pnode->left = node;
337                        else pnode->right = node;
338                    }
339                    while (!path.empty()) {
340                        path.back()->_update();
341                        path.pop_back();
342                    }
343                    return;
344                }
345                pnode = node;
346                d = (t < k)? 0: 1;
347                if (d) {
348                    pnode->left = node = node->left;
349                } else {
350                    k -= t + 1;
351                    pnode->right = node = node->right;
352                }
353                path.emplace_back(node);
354            }
355        }
356
357        T operator[] (int k) const {
358            if (k < 0) k += len();
359            assert(0 <= k && k < len());
360            NodePtr node = root;
361            while (1) {
362                    int t = node->left? node->left->size: 0;
363                if (t == k) {
364                    return node->key;
365                }
366                if (t < k) {
367                    k -= t + 1;
368                    node = node->right;
369                } else {
370                    node = node->left;
371                }
372            }
373        }
374
375        void clear() {
376            this->root = nullptr;
377        }
378
379        T& operator[] (int k) {
380            if (k < 0) k += len();
381            assert(0 <= k && k < len());
382            NodePtr node = root;
383            while (1) {
384                int t = node->left? node->left->size: 0;
385                if (t == k) {
386                    return node->key;
387                }
388                if (t < k) {
389                    k -= t + 1;
390                    node = node->right;
391                } else {
392                    node = node->left;
393                }
394            }
395        }
396
397        void print() {
398            vector<T> a = tovector();
399            cout << "[";
400            for (int i = 0; i < (int)a.size()-1; ++i) {
401                cout << a[i] << ", ";
402            }
403            if (!a.empty()) cout << a.back();
404            cout << "]" << endl;
405        }
406
407        int len() const {
408            return root ? root->size : 0;
409        }
410
411        void isok() {
412            auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
413                int ls = 0, rs = 0;
414                int height = 0;
415                int h;
416                if (node->left) {
417                pair<int, int> res = rec(rec, node->left);
418                ls = res.first;
419                h = res.second;
420                height = max(height, h);
421                }
422                if (node->right) {
423                pair<int, int> res = rec(rec, node->right);
424                rs = res.first;
425                h = res.second;
426                height = max(height, h);
427                }
428                int s = ls + rs + 1;
429                double b = (double)(ls+1) / (s+1);
430                assert(s == node->size);
431                assert(ALPHA <= b && b <= 1-ALPHA);
432                return {s, height+1};
433            };
434            if (root == nullptr) return;
435            auto [_, h] = rec(rec, root);
436            // printf("height=%d\n", h);
437        }
438    };
439} // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/wb_tree.cpp”