AVL Tree Set

ソースコード

  1#include <iostream>
  2#include <algorithm>
  3#include <vector>
  4#include <cassert>
  5#include <stack>
  6#include "titan_cpplib/data_structures/bbst_node.cpp"
  7using namespace std;
  8
  9// AVLTreeSet
 10namespace titan23 {
 11
 12    template<typename T>
 13    class AVLTreeSet {
 14      public:
 15        class AVLTreeSetNode {
 16          public:
 17            using AVLTreeSetNodePtr = AVLTreeSetNode*;
 18            T key;
 19            int size, height;
 20            AVLTreeSetNodePtr par, left, right;
 21
 22            AVLTreeSetNode() : size(0), height(0), par(nullptr), left(nullptr), right(nullptr) {}
 23            AVLTreeSetNode(const T &key) : key(key), size(1), height(1), par(nullptr), left(nullptr), right(nullptr) {}
 24
 25            int balance() const {
 26                int hl = left ? left->height : 0;
 27                int hr = right ? right->height : 0;
 28                return hl - hr;
 29            }
 30
 31            void update() {
 32                size = 1 + (left ? left->size : 0) + (right ? right->size : 0);
 33                height = 1 + max((left ? left->height : 0), (right ? right->height : 0));
 34            }
 35        };
 36
 37      private:
 38        T missing;
 39        using AVLTreeSetNodePtr = AVLTreeSetNode*;
 40        AVLTreeSetNodePtr root;
 41
 42        void _remove_balance(AVLTreeSetNodePtr node) {
 43            while (node) {
 44                AVLTreeSetNodePtr new_node = nullptr;
 45                node->update();
 46                if (node->balance() == 2) {
 47                    new_node = node->left->balance() == -1 ? BBSTNode<AVLTreeSetNodePtr>::rotate_LR(node) : BBSTNode<AVLTreeSetNodePtr>::rotate_right(node);
 48                } else if (node->balance() == -2) {
 49                    new_node = node->right->balance() == 1 ? BBSTNode<AVLTreeSetNodePtr>::rotate_RL(node) : BBSTNode<AVLTreeSetNodePtr>::rotate_left(node);
 50                } else if (node->balance() != 0) {
 51                    node = node->par;
 52                    break;
 53                }
 54                if (!new_node) {
 55                    node = node->par;
 56                    continue;
 57                }
 58                if (!new_node->par) {
 59                    this->root = new_node;
 60                    return;
 61                }
 62                node = new_node->par;
 63                if (new_node->key < node->key) {
 64                    node->left = new_node;
 65                } else {
 66                    node->right = new_node;
 67                }
 68                if (new_node->balance() != 0) break;
 69            }
 70            while (node) {
 71                node->update();
 72                node = node->par;
 73            }
 74        }
 75
 76        void _add_balance(AVLTreeSetNodePtr node) {
 77            AVLTreeSetNodePtr new_node = nullptr;
 78            while (node) {
 79                node->update();
 80                if (node->balance() == 0) {
 81                    node = node->par;
 82                    break;
 83                }
 84                if (node->balance() == 2) {
 85                    new_node = node->left->balance() == -1 ? BBSTNode<AVLTreeSetNodePtr>::rotate_LR(node) : BBSTNode<AVLTreeSetNodePtr>::rotate_right(node);
 86                    break;
 87                } else if (node->balance() == -2) {
 88                    new_node = node->right->balance() == 1 ? BBSTNode<AVLTreeSetNodePtr>::rotate_RL(node) : BBSTNode<AVLTreeSetNodePtr>::rotate_left(node);
 89                    break;
 90                }
 91                node = node->par;
 92            }
 93            if (new_node) {
 94                node = new_node->par;
 95                if (node) {
 96                    if (new_node->key < node->key) {
 97                        node->left = new_node;
 98                    } else {
 99                        node->right = new_node;
100                    }
101                } else {
102                    this->root = new_node;
103                }
104            }
105            while (node) {
106                node->update();
107                node = node->par;
108            }
109        }
110
111      public:
112        AVLTreeSet() : missing(-1), root(nullptr) {}
113        AVLTreeSet(T missing) : missing(-1), root(nullptr) {}
114        AVLTreeSet(vector<T> &a, T missing) : missing(-1) {
115            this->root = build(a);
116        }
117
118        AVLTreeSetNodePtr build(vector<T> a) {
119            auto _build = [&] (auto &&_build, int l, int r) -> AVLTreeSetNodePtr {
120                int mid = (l + r) / 2;
121                AVLTreeSetNodePtr node = new AVLTreeSetNode(a[mid]);
122                if (l != mid) {
123                    node->left = _build(_build, l, mid);
124                    node->left->par = node;
125                }
126                if (mid+1 != r) {
127                    node->right = _build(_build, mid+1, r);
128                    node->right->par = node;
129                }
130                node->update();
131                return node;
132            };
133
134            if (a.empty()) return nullptr;
135            int n = a.size();
136            bool is_sorted = true;
137            for (int i = 0; i < n-1; ++i) {
138                if (!(a[i] < a[i+1])) {
139                    is_sorted = false;
140                    break;
141                }
142            }
143            if (!is_sorted) {
144                sort(a.begin(), a.end());
145                a.erase(unique(a.begin(), a.end()), a.end());
146            }
147            return _build(_build, 0, a.size());
148        }
149
150        bool add(const T &key) {
151            if (!root) {
152                root = new AVLTreeSetNode(key);
153                return true;
154            }
155            AVLTreeSetNodePtr pnode = nullptr;
156            AVLTreeSetNodePtr node = root;
157            while (node) {
158                if (key == node->key) return false;
159                pnode = node;
160                node = key < node->key ? node->left : node->right;
161            }
162            if (key < pnode->key) {
163                pnode->left = new AVLTreeSetNode(key);
164                pnode->left->par = pnode;
165            } else {
166                pnode->right = new AVLTreeSetNode(key);
167                pnode->right->par = pnode;
168            }
169            _add_balance(pnode);
170            return true;
171        }
172
173        void remove_iter(AVLTreeSetNodePtr node) {
174            AVLTreeSetNodePtr pnode = node->par;
175            if (node->left && node->right) {
176                pnode = node;
177                AVLTreeSetNodePtr mnode = node->left;
178                while (mnode->right) {
179                    pnode = mnode;
180                    mnode = mnode->right;
181                }
182                node->key = mnode->key;
183                node = mnode;
184            }
185            AVLTreeSetNodePtr cnode = (!node->left) ? node->right : node->left;
186            if (cnode) {
187                cnode->par = pnode;
188            }
189            if (pnode) {
190                if (node->key <= pnode->key) {
191                pnode->left = cnode;
192                } else {
193                pnode->right = cnode;
194                }
195                _remove_balance(pnode);
196            } else {
197                root = cnode;
198            }
199        }
200
201        bool discard(const T &key) {
202            AVLTreeSetNodePtr node = find_key(key);
203            if (!node) return false;
204            remove_iter(node);
205            return true;
206        }
207
208        void remove(const T &key) {
209            AVLTreeSetNodePtr node = find_key(key);
210            assert(node);
211            remove_iter(node);
212        }
213
214        T le(const T &key) const {
215            T res = missing;
216            AVLTreeSetNodePtr node = root;
217            while (node) {
218                if (key == node->key) {
219                    res = node->key;
220                    break;
221                }
222                if (key < node->key) {
223                    node = node->left;
224                } else {
225                    res = node->key;
226                    node = node->right;
227                }
228            }
229            return res;
230        }
231
232        T lt(const T &key) const {
233            T res = missing;
234            AVLTreeSetNodePtr node = root;
235            while (node) {
236                if (key <= node->key) {
237                    node = node->left;
238                } else {
239                    res = node->key;
240                    node = node->right;
241                }
242            }
243            return res;
244        }
245
246        T ge(const T &key) const {
247            T res = missing;
248            AVLTreeSetNodePtr node = root;
249            while (node) {
250                if (key == node->key) {
251                    res = node->key;
252                    break;
253                }
254                if (key < node->key) {
255                    res = node->key;
256                    node = node->left;
257                } else {
258                    node = node->right;
259                }
260            }
261            return res;
262        }
263
264        T gt(const T &key) const {
265            T res = missing;
266            AVLTreeSetNodePtr node = root;
267            while (node) {
268                if (key < node->key) {
269                    res = node->key;
270                    node = node->left;
271                } else {
272                    node = node->right;
273                }
274            }
275            return res;
276        }
277
278        int index(const T &key) const {
279            int k = 0;
280            AVLTreeSetNodePtr node = root;
281            while (node) {
282                if (key == node->key) {
283                    k += node->left ? node->left->size : 0;
284                    break;
285                }
286                if (key < node->key) {
287                    node = node->left;
288                } else {
289                    k += node->left ? (node->left->size + 1) : 1;
290                    node = node->right;
291                }
292            }
293            return k;
294        }
295
296        int index_right(const T &key) const {
297            int k = 0;
298            AVLTreeSetNodePtr node = root;
299            while (node) {
300                if (key == node->key) {
301                    k += node->left ? (node->left->size + 1) : 1;
302                    break;
303                }
304                if (key < node->key) {
305                    node = node->left;
306                } else {
307                    k += node->left ? (node->left->size + 1) : 1;
308                    node = node->right;
309                }
310            }
311            return k;
312        }
313
314        AVLTreeSetNodePtr find_key(const T &key) const {
315            AVLTreeSetNodePtr node = root;
316            while (node) {
317                if (key == node->key) return node;
318                node = key < node->key ? node->left : node->right;
319            }
320            return nullptr;
321        }
322
323        AVLTreeSetNodePtr find_kth(int k) const {
324            AVLTreeSetNodePtr node = root;
325            while (true) {
326                int t = node->left ? node->left->size : 0;
327                if (t == k) return node;
328                if (t < k) {
329                    k -= t + 1;
330                    node = node->right;
331                } else {
332                    node = node->left;
333                }
334            }
335        }
336
337        T pop(int k=-1) {
338            AVLTreeSetNodePtr node = find_kth(k);
339            T key = node->key;
340            remove_iter(node);
341            return key;
342        }
343
344        vector<T> tovector() const {
345            vector<T> a;
346            a.reserve(len());
347            stack<AVLTreeSetNodePtr> st;
348            AVLTreeSetNodePtr node = root;
349            while ((!st.empty()) || node) {
350                if (node) {
351                    st.emplace(node);
352                    node = node->left;
353                } else {
354                    node = st.top();
355                    st.pop();
356                    a.emplace_back(node->key);
357                    node = node->right;
358                }
359            }
360            return a;
361        }
362
363        bool contains(T key) const {
364            return find_key(key) != nullptr;
365        }
366
367        T get(int k) const {
368            return find_kth(k)->key;
369        }
370
371        int len() const {
372            return root ? root->size : 0;
373        }
374
375        void check() const {
376            if (!root) {
377                cout << "height=0" << endl;;
378                cout << "check ok empty." << endl;;
379                return;
380            }
381            cout << "height=" << root->height << endl;
382
383            auto dfs = [&] (auto &&dfs, AVLTreeSetNodePtr node) -> void {
384                int h = 0;
385                int b = 0;
386                int s = 1;
387                if (node->left) {
388                    assert(node->left->par == node);
389                    assert(node->key > node->left->key);
390                    dfs(dfs, node->left);
391                    h = max(h, node->left->height);
392                    b += node->left->height;
393                    s += node->left->size;
394                }
395                if (node->right) {
396                    assert(node->right->par == node);
397                    assert(node->key < node->right->key);
398                    dfs(dfs, node->right);
399                    h = max(h, node->right->height);
400                    b -= node->right->height;
401                    s += node->right->size;
402                }
403                assert(node->height == h+1);
404                assert(-1 <= b && b <= 1);
405                assert(node->size == s);
406            };
407            dfs(dfs, root);
408            cout << "check ok." << endl;
409        }
410
411        friend ostream& operator<<(ostream& os, const titan23::AVLTreeSet<T>& s) {
412            vector<T> a = s.tovector();
413            int n = a.size();
414            os << "{";
415            for (int i = 0; i < n - 1; ++i) {
416                os << a[i] << ", ";
417            }
418            if (n > 0) os << a.back();
419            os << "}";
420            return os;
421        }
422    };
423}

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/avl_tree_set.cpp”