AVL Tree Multiset

ソースコード

  1#include <iostream>
  2#include <algorithm>
  3#include <vector>
  4#include <cassert>
  5#include "titan_cpplib/data_structures/bbst_node.cpp"
  6using namespace std;
  7
  8// AVLTreeMultiset
  namespace titan23 {

  /**
   * @brief Ordered multiset of `T` stored in an AVL tree.
   *
   * Each distinct key occupies one node holding its multiplicity (`val`).
   * Every node also caches the total multiplicity of its subtree (`valsize`)
   * and the subtree height, so rank queries (`get`, `pop`, `index`,
   * `index_right`) walk a single root-to-node path.
   *
   * Requirements on `T`: `<`, `<=` and `==` are used for ordering, and `>`
   * is additionally used by `check()`.
   *
   * NOTE(review): nodes are created with `new`. A node unlinked by a removal
   * is freed immediately, but the class declares no destructor or copy
   * operations: copies share nodes, and nodes still in the tree when the
   * object is destroyed are never freed.
   */
  template<typename T>
  class AVLTreeMultiset {
    public:
      /// One tree node: a distinct key together with its multiplicity.
      class AVLTreeMultisetNode {
        public:
          using AVLTreeMultisetNodePtr = AVLTreeMultisetNode*;
          T key;
          int val;      // multiplicity of `key`
          int valsize;  // sum of `val` over this subtree
          int height;   // height of this subtree; a leaf has height 1
          AVLTreeMultisetNodePtr par, left, right;

          AVLTreeMultisetNode()
              : key(), val(0), valsize(0), height(0), par(nullptr), left(nullptr), right(nullptr) {}
          AVLTreeMultisetNode(const T &key, const int val)
              : key(key), val(val), valsize(val), height(1), par(nullptr), left(nullptr), right(nullptr) {}

          /// Height of the left subtree minus height of the right subtree.
          int balance() const {
              const int hl = left ? left->height : 0;
              const int hr = right ? right->height : 0;
              return hl - hr;
          }

          /// Recomputes `valsize` and `height` from this node and its children.
          void update() {
              valsize = val + (left ? left->valsize : 0) + (right ? right->valsize : 0);
              height = 1 + std::max(left ? left->height : 0, right ? right->height : 0);
          }
      };

      T missing{};  // returned by le/lt/ge/gt when no key qualifies (value-initialized by default)
      using AVLTreeMultisetNodePtr = AVLTreeMultisetNode*;
      AVLTreeMultisetNodePtr root = nullptr;

      // Rotation helpers. Each returns the new root of the rotated subtree,
      // whose `par` is set to the old subtree root's parent, and refreshes the
      // cached fields of the nodes whose children changed. The parent's child
      // pointer is NOT rewired here; the balancing code does that by comparing keys.

      /// Lifts `node->left` above `node`. Requires `node->left != nullptr`.
      static AVLTreeMultisetNodePtr _rotate_right(AVLTreeMultisetNodePtr node) {
          AVLTreeMultisetNodePtr u = node->left;
          u->par = node->par;
          node->left = u->right;
          if (u->right) u->right->par = node;
          u->right = node;
          node->par = u;
          node->update();
          u->update();
          return u;
      }

      /// Lifts `node->right` above `node`. Requires `node->right != nullptr`.
      static AVLTreeMultisetNodePtr _rotate_left(AVLTreeMultisetNodePtr node) {
          AVLTreeMultisetNodePtr u = node->right;
          u->par = node->par;
          node->right = u->left;
          if (u->left) u->left->par = node;
          u->left = node;
          node->par = u;
          node->update();
          u->update();
          return u;
      }

      /// Double rotation for a left-heavy node whose left child is right-heavy.
      static AVLTreeMultisetNodePtr _rotate_LR(AVLTreeMultisetNodePtr node) {
          node->left = _rotate_left(node->left);
          return _rotate_right(node);
      }

      /// Double rotation for a right-heavy node whose right child is left-heavy.
      static AVLTreeMultisetNodePtr _rotate_RL(AVLTreeMultisetNodePtr node) {
          node->right = _rotate_right(node->right);
          return _rotate_left(node);
      }

      /**
       * @brief Builds a balanced tree from `a` and returns its root.
       *
       * `a` may be in any order; equal elements are merged into one node whose
       * `val` is their count. Returns nullptr for an empty input. Does not
       * modify `root`.
       */
      AVLTreeMultisetNodePtr build(std::vector<T> a) {
          if (a.empty()) return nullptr;
          const int n = static_cast<int>(a.size());
          bool sorted_input = true;
          for (int i = 0; i + 1 < n; ++i) {
              if (!(a[i] <= a[i + 1])) {
                  sorted_input = false;
                  break;
              }
          }
          if (!sorted_input) std::sort(a.begin(), a.end());

          // Run-length encode the sorted input: x = distinct keys, y = counts.
          std::vector<T> x = {a[0]};
          std::vector<int> y = {1};
          for (int i = 1; i < n; ++i) {
              if (a[i] == x.back()) {
                  ++y.back();
              } else {
                  x.emplace_back(a[i]);
                  y.emplace_back(1);
              }
          }

          // Build [l, r) around its midpoint: sibling subtree sizes differ by
          // at most one, so the result satisfies the AVL balance condition.
          auto build_range = [&](auto &&self, int l, int r) -> AVLTreeMultisetNodePtr {
              const int mid = (l + r) / 2;
              AVLTreeMultisetNodePtr node = new AVLTreeMultisetNode(x[mid], y[mid]);
              if (l != mid) {
                  node->left = self(self, l, mid);
                  node->left->par = node;
              }
              if (mid + 1 != r) {
                  node->right = self(self, mid + 1, r);
                  node->right->par = node;
              }
              node->update();
              return node;
          };
          return build_range(build_range, 0, static_cast<int>(x.size()));
      }

      /// Restores balance after a child of `node` was unlinked, walking towards
      /// the root, then refreshes the cached fields of the remaining ancestors.
      void _remove_balance(AVLTreeMultisetNodePtr node) {
          while (node) {
              AVLTreeMultisetNodePtr new_node = nullptr;
              node->update();
              if (node->balance() == 2) {
                  new_node = node->left->balance() == -1 ? _rotate_LR(node) : _rotate_right(node);
              } else if (node->balance() == -2) {
                  new_node = node->right->balance() == 1 ? _rotate_RL(node) : _rotate_left(node);
              } else if (node->balance() != 0) {
                  // |balance| == 1: this subtree kept its height, so stop rebalancing.
                  node = node->par;
                  break;
              }
              if (!new_node) {
                  // balance == 0: this subtree became shorter; keep going up.
                  node = node->par;
                  continue;
              }
              if (!new_node->par) {
                  root = new_node;
                  return;
              }
              node = new_node->par;
              if (new_node->key < node->key) {
                  node->left = new_node;
              } else {
                  node->right = new_node;
              }
              // After a rotation the subtree shrank only if its new root is balanced.
              if (new_node->balance() != 0) break;
          }
          _update_par(node);
      }

      /// Restores balance after a leaf was attached below `node`, then
      /// refreshes the cached fields of the remaining ancestors.
      void _add_balance(AVLTreeMultisetNodePtr node) {
          AVLTreeMultisetNodePtr new_node = nullptr;
          while (node) {
              node->update();
              const int b = node->balance();
              if (b == 0) {
                  // Height unchanged: no further rebalancing needed.
                  node = node->par;
                  break;
              }
              if (b == 2) {
                  new_node = node->left->balance() == -1 ? _rotate_LR(node) : _rotate_right(node);
                  break;
              }
              if (b == -2) {
                  new_node = node->right->balance() == 1 ? _rotate_RL(node) : _rotate_left(node);
                  break;
              }
              node = node->par;
          }
          if (new_node) {
              node = new_node->par;
              if (!node) {
                  root = new_node;
              } else if (new_node->key < node->key) {
                  node->left = new_node;
              } else {
                  node->right = new_node;
              }
          }
          _update_par(node);
      }

      /// Calls `update()` on `node` and every ancestor up to the root.
      void _update_par(AVLTreeMultisetNodePtr node) {
          while (node) {
              node->update();
              node = node->par;
          }
      }

      /// Returns the node holding `key`, or nullptr if `key` is absent.
      AVLTreeMultisetNodePtr find_key(const T &key) const {
          AVLTreeMultisetNodePtr node = root;
          while (node) {
              if (key == node->key) return node;
              node = key < node->key ? node->left : node->right;
          }
          return nullptr;
      }

      /**
       * @brief Returns the node containing the k-th smallest element (0-indexed).
       *
       * A negative `k` counts from the end, so -1 is the largest element.
       * Asserts that the normalized index lies in [0, len()).
       */
      AVLTreeMultisetNodePtr find_kth(int k) const {
          const int n = len();
          if (k < 0) k += n;
          assert(0 <= k && k < n);
          AVLTreeMultisetNodePtr node = root;
          while (true) {
              assert(node);
              const int left_size = node->left ? node->left->valsize : 0;
              if (k < left_size) {
                  node = node->left;
              } else if (k < left_size + node->val) {
                  return node;
              } else {
                  k -= left_size + node->val;
                  node = node->right;
              }
          }
      }

    public:
      AVLTreeMultiset() : root(nullptr) {}
      AVLTreeMultiset(T missing) : missing(missing), root(nullptr) {}
      /// Builds the multiset from the elements of `a` (any order).
      AVLTreeMultiset(const std::vector<T> &a, T missing) : missing(missing) {
          root = build(a);
      }

      /// Inserts `val` copies of `key`. NOTE(review): `val` is not validated; pass a positive count.
      void add(const T &key, int val=1) {
          if (!root) {
              root = new AVLTreeMultisetNode(key, val);
              return;
          }
          AVLTreeMultisetNodePtr pnode = nullptr;
          AVLTreeMultisetNodePtr node = root;
          while (node) {
              if (key == node->key) {
                  node->val += val;
                  _update_par(node);
                  return;
              }
              pnode = node;
              node = key < node->key ? node->left : node->right;
          }
          if (key < pnode->key) {
              pnode->left = new AVLTreeMultisetNode(key, val);
              pnode->left->par = pnode;
          } else {
              pnode->right = new AVLTreeMultisetNode(key, val);
              pnode->right->par = pnode;
          }
          _add_balance(pnode);
      }

      /**
       * @brief Removes the whole node `node` (all copies of its key) from the tree.
       *
       * `node` must belong to this tree. The node that is physically unlinked
       * (either `node` itself or its in-order predecessor, whose key/count are
       * moved into `node`) is freed, so `node` must not be used afterwards.
       */
      void remove_iter(AVLTreeMultisetNodePtr node) {
          AVLTreeMultisetNodePtr pnode = node->par;
          if (node->left && node->right) {
              // Two children: copy the in-order predecessor (the rightmost node
              // of the left subtree, which has no right child) into `node`,
              // then unlink the predecessor instead.
              pnode = node;
              AVLTreeMultisetNodePtr mnode = node->left;
              while (mnode->right) {
                  pnode = mnode;
                  mnode = mnode->right;
              }
              node->key = mnode->key;
              node->val = mnode->val;
              node = mnode;
          }
          AVLTreeMultisetNodePtr cnode = node->left ? node->left : node->right;
          if (cnode) cnode->par = pnode;
          if (pnode) {
              // `<=` covers the predecessor case where pnode now holds an equal key.
              if (node->key <= pnode->key) {
                  pnode->left = cnode;
              } else {
                  pnode->right = cnode;
              }
              _remove_balance(pnode);
          } else {
              root = cnode;
          }
          delete node;
      }

      /// Removes up to `val` copies of `key`; returns false if `key` is absent.
      bool discard(const T &key, int val=1) {
          AVLTreeMultisetNodePtr node = find_key(key);
          if (!node) return false;
          node->val -= val;
          if (node->val <= 0) {
              remove_iter(node);
          } else {
              _update_par(node);
          }
          return true;
      }

      /// Removes up to `val` copies of `key`; asserts that `key` is present.
      void remove(const T &key, int val=1) {
          AVLTreeMultisetNodePtr node = find_key(key);
          assert(node);
          node->val -= val;
          if (node->val <= 0) {
              remove_iter(node);
          } else {
              _update_par(node);
          }
      }

      /// Largest stored key <= `key`, or `missing` if there is none.
      T le(const T &key) const {
          T res = missing;
          AVLTreeMultisetNodePtr node = root;
          while (node) {
              if (key == node->key) {
                  res = node->key;
                  break;
              }
              if (key < node->key) {
                  node = node->left;
              } else {
                  res = node->key;
                  node = node->right;
              }
          }
          return res;
      }

      /// Largest stored key < `key`, or `missing` if there is none.
      T lt(const T &key) const {
          T res = missing;
          AVLTreeMultisetNodePtr node = root;
          while (node) {
              if (key <= node->key) {
                  node = node->left;
              } else {
                  res = node->key;
                  node = node->right;
              }
          }
          return res;
      }

      /// Smallest stored key >= `key`, or `missing` if there is none.
      T ge(const T &key) const {
          T res = missing;
          AVLTreeMultisetNodePtr node = root;
          while (node) {
              if (key == node->key) {
                  res = node->key;
                  break;
              }
              if (key < node->key) {
                  res = node->key;
                  node = node->left;
              } else {
                  node = node->right;
              }
          }
          return res;
      }

      /// Smallest stored key > `key`, or `missing` if there is none.
      T gt(const T &key) const {
          T res = missing;
          AVLTreeMultisetNodePtr node = root;
          while (node) {
              if (key < node->key) {
                  res = node->key;
                  node = node->left;
              } else {
                  node = node->right;
              }
          }
          return res;
      }

      /// Number of elements (counting multiplicity) strictly less than `key`.
      int index(const T &key) const {
          int k = 0;
          AVLTreeMultisetNodePtr node = root;
          while (node) {
              if (key == node->key) {
                  k += node->left ? node->left->valsize : 0;
                  break;
              }
              if (key < node->key) {
                  node = node->left;
              } else {
                  k += node->left ? (node->left->valsize + node->val) : node->val;
                  node = node->right;
              }
          }
          return k;
      }

      /// Number of elements (counting multiplicity) less than or equal to `key`.
      int index_right(const T &key) const {
          int k = 0;
          AVLTreeMultisetNodePtr node = root;
          while (node) {
              if (key == node->key) {
                  k += node->left ? (node->left->valsize + node->val) : node->val;
                  break;
              }
              if (key < node->key) {
                  node = node->left;
              } else {
                  k += node->left ? (node->left->valsize + node->val) : node->val;
                  node = node->right;
              }
          }
          return k;
      }

      /// Removes one copy of the k-th smallest element and returns it.
      /// Negative `k` counts from the end; the default (-1) pops the largest.
      T pop(int k=-1) {
          AVLTreeMultisetNodePtr node = find_kth(k);
          T key = node->key;  // copy before the node may be freed
          node->val -= 1;
          if (node->val <= 0) {
              remove_iter(node);
          } else {
              _update_par(node);
          }
          return key;
      }

      /// All elements in ascending order, each key repeated by its multiplicity.
      std::vector<T> tovector() const {
          std::vector<T> a;
          a.reserve(len());
          std::vector<AVLTreeMultisetNodePtr> st;
          AVLTreeMultisetNodePtr node = root;
          while (!st.empty() || node) {
              if (node) {
                  st.emplace_back(node);
                  node = node->left;
              } else {
                  node = st.back();
                  st.pop_back();
                  for (int i = 0; i < node->val; ++i) {
                      a.emplace_back(node->key);
                  }
                  node = node->right;
              }
          }
          return a;
      }

      /// True if at least one copy of `key` is stored.
      bool contains(const T &key) const {
          return find_key(key) != nullptr;
      }

      /// The k-th smallest element (0-indexed; negative `k` counts from the end).
      T get(int k) const {
          return find_kth(k)->key;
      }

      /// Total number of elements, counting multiplicity.
      int len() const {
          return root ? root->valsize : 0;
      }

      /// Prints the elements as `{a, b, c}` followed by a newline to stdout.
      void print() const {
          std::vector<T> a = tovector();
          const int n = static_cast<int>(a.size());
          std::cout << "{";
          for (int i = 0; i < n-1; ++i) {
              std::cout << a[i] << ", ";
          }
          if (n > 0) std::cout << a.back();
          std::cout << "}" << std::endl;
      }

      /// Debug helper: asserts parent links, key order, cached sizes/heights
      /// and the AVL balance condition for every node.
      void check() const {
          if (!root) return;
          auto dfs = [&](auto &&dfs, AVLTreeMultisetNodePtr node) -> void {
              int h = 0;
              int b = 0;
              int vs = node->val;
              if (node->left) {
                  assert(node->left->par == node);
                  assert(node->key > node->left->key);
                  dfs(dfs, node->left);
                  h = std::max(h, node->left->height);
                  b += node->left->height;
                  vs += node->left->valsize;
              }
              if (node->right) {
                  assert(node->right->par == node);
                  assert(node->key < node->right->key);
                  dfs(dfs, node->right);
                  h = std::max(h, node->right->height);
                  b -= node->right->height;
                  vs += node->right->valsize;
              }
              assert(node->valsize == vs);
              assert(node->height == h+1);
              assert(-1 <= b && b <= 1);
          };
          dfs(dfs, root);
      }

      friend std::ostream& operator<<(std::ostream& os, const titan23::AVLTreeMultiset<T>& s) {
          std::vector<T> a = s.tovector();
          const int n = static_cast<int>(a.size());
          os << "{";
          for (int i = 0; i < n - 1; ++i) {
              os << a[i] << ", ";
          }
          if (n > 0) os << a.back();
          os << "}";
          return os;
      }
  };

  } // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/avl_tree_multiset.cpp”