persistent multiset

ソースコード

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <memory>
#include <stack>
#include <utility>
#include <vector>
#include "titan_cpplib/others/print.cpp"
using namespace std;

 10namespace titan23 {
 11
 12template <typename T>
 13class PersistentMultiset {
 14  private:
 15    class Node;
 16    using NodePtr = shared_ptr<Node>;
 17    // using NodePtr = Node*;
 18    static constexpr int DELTA = 3;
 19    static constexpr int GAMMA = 2;
 20    NodePtr root;
 21
 22    class Node {
 23      public:
 24        T key;
 25        int size, cnt, cnt_subtree;
 26        NodePtr left;
 27        NodePtr right;
 28
 29        Node(T key, int cnt = 1) : key(key), size(1), cnt(cnt), cnt_subtree(cnt), left(nullptr), right(nullptr) {}
 30
 31        NodePtr copy() const {
 32            NodePtr node = make_shared<Node>(key);
 33            node->size = size;
 34            node->cnt = cnt;
 35            node->cnt_subtree = cnt_subtree;
 36            node->left = left;
 37            node->right = right;
 38            return node;
 39        }
 40
 41        int weight_right() const {
 42            return right ? right->size + 1 : 1;
 43        }
 44
 45        int weight_left() const {
 46            return left ? left->size + 1 : 1;
 47        }
 48
 49        void update() {
 50            size = 1;
 51            cnt_subtree = cnt;
 52            if (left) {
 53                size += left->size;
 54                cnt_subtree += left->cnt_subtree;
 55            }
 56            if (right) {
 57                size += right->size;
 58                cnt_subtree += right->cnt_subtree;
 59            }
 60        }
 61
 62        void balance_check() const {
 63            if (!weight_left()*DELTA >= weight_right()) {
 64                cerr << weight_left() << ", " << weight_right() << endl;
 65                cerr << "not weight_left()*DELTA >= weight_right()." << endl;
 66                assert(false);
 67            }
 68            if (!weight_right() * DELTA >= weight_left()) {
 69                cerr << weight_left() << ", " << weight_right() << endl;
 70                cerr << "not weight_right() * DELTA >= weight_left()." << endl;
 71                assert(false);
 72            }
 73        }
 74
 75        void print() const {
 76            vector<T> a;
 77            auto dfs = [&] (auto &&dfs, const Node* node) -> void {
 78                if (!node) return;
 79                if (node->left)  dfs(dfs, node->left.get());
 80                a.emplace_back(node->key);
 81                if (node->right) dfs(dfs, node->right.get());
 82            };
 83            dfs(dfs, this);
 84            cerr << a << endl;
 85        }
 86
 87        void debug() const {
 88            cout << "this : key=" << key << ", size=" << size << endl;
 89            if (left)  cout << "to-left" << endl;
 90            if (right) cout << "to-right" << endl;
 91            cout << endl;
 92            if (left)  left->print();
 93            if (right) right->print();
 94        }
 95    };
 96
 97    void _build(vector<T> a) {
 98        auto build = [&] (auto &&build, int l, int r) -> NodePtr {
 99            int mid = (l + r) >> 1;
100            NodePtr node = make_shared<Node>(a[mid]);
101            if (l != mid) node->left = build(build, l, mid);
102            if (mid+1 != r) node->right = build(build, mid+1, r);
103            node->update();
104            return node;
105        };
106        sort(a.begin(), a.end());
107        root = build(build, 0, (int)a.size());
108    }
109
110    NodePtr _rotate_right(NodePtr &node) {
111        NodePtr u = node->left->copy();
112        node->left = u->right;
113        u->right = node;
114        node->update();
115        u->update();
116        return u;
117    }
118
119    NodePtr _rotate_left(NodePtr &node) {
120        NodePtr u = node->right->copy();
121        node->right = u->left;
122        u->left = node;
123        node->update();
124        u->update();
125        return u;
126    }
127
128    NodePtr _balance_left(NodePtr &node) {
129        node->right = node->right->copy();
130        NodePtr u = node->right;
131        if (node->right->weight_left() >= node->right->weight_right() * GAMMA) {
132            node->right = _rotate_right(u);
133        }
134        u = _rotate_left(node);
135        return u;
136    }
137
138    NodePtr _balance_right(NodePtr &node) {
139        node->left = node->left->copy();
140        NodePtr u = node->left;
141        if (node->left->weight_right() >= node->left->weight_left() * GAMMA) {
142            node->left = _rotate_left(u);
143        }
144        u = _rotate_right(node);
145        return u;
146    }
147
148    int weight(NodePtr node) const {
149        return node ? node->size + 1 : 1;
150    }
151
152    NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
153        if (weight(r) * DELTA < weight(l)) {
154            l = l->copy();
155            l->right = _merge_with_root(l->right, root, r);
156            l->update();
157            if (weight(l->left) * DELTA < weight(l->right)) {
158                return _balance_left(l);
159            }
160            return l;
161        } else if (weight(l) * DELTA < weight(r)) {
162            r = r->copy();
163            r->left = _merge_with_root(l, root, r->left);
164            r->update();
165            if (weight(r->right) * DELTA < weight(r->left)) {
166                return _balance_right(r);
167            }
168            return r;
169        }
170        root = root->copy();
171        root->left = l;
172        root->right = r;
173        root->update();
174        return root;
175    }
176
177    pair<NodePtr, NodePtr> _pop_right(NodePtr &node) {
178        return _split_node_idx(node, node->size-1);
179    }
180
181    NodePtr _merge_node(NodePtr l, NodePtr r) {
182        if ((!l) && (!r)) {return nullptr;}
183        if (!l) {return r->copy();}
184        if (!r) {return l->copy();}
185        l = l->copy();
186        r = r->copy();
187        auto [l_, root_] = _pop_right(l);
188        return _merge_with_root(l_, root_, r);
189    }
190
191    pair<NodePtr, NodePtr> _split_node_key(NodePtr &node, const T &key) {
192        if (!node) { return {nullptr, nullptr}; }
193        if (node->key == key) {
194            return {_merge_with_root(node->left, node, nullptr), node->right};
195        } else if (node->key > key) {
196            auto [l, r] = _split_node_key(node->left, key);
197            return {l, _merge_with_root(r, node, node->right)};
198        } else {
199            auto [l, r] = _split_node_key(node->right, key);
200            return {_merge_with_root(node->left, node, l), r};
201        }
202    }
203
204    pair<NodePtr, NodePtr> _split_node_idx(NodePtr &node, int k) {
205        if (!node) {return {nullptr, nullptr};}
206        int tmp = node->left ? k-node->left->size : k;
207        if (tmp == 0) {
208            return {node->left, _merge_with_root(nullptr, node, node->right)};
209        } else if (tmp < 0) {
210            auto [l, r] = _split_node_idx(node->left, k);
211            return {l, _merge_with_root(r, node, node->right)};
212        } else {
213            auto [l, r] = _split_node_idx(node->right, tmp-1);
214            return {_merge_with_root(node->left, node, l), r};
215        }
216    }
217
218    PersistentMultiset<T> _new(NodePtr root) const {
219        return PersistentMultiset<T>(root);
220    }
221
222    PersistentMultiset(NodePtr root) : root(root) {}
223
224  public:
225    PersistentMultiset() : root(nullptr) {}
226
227    PersistentMultiset(vector<T> &a) { _build(a); }
228
229    PersistentMultiset<T> merge(PersistentMultiset<T> other) {
230        NodePtr root = _merge_node(this->root, other.root);
231        return _new(root);
232    }
233
234    pair<PersistentMultiset<T>, PersistentMultiset<T>> split(int k) {
235        auto [l, r] = _split_node(this->root, k);
236        return {_new(l), _new(r)};
237    }
238
239    NodePtr find(T key) const {
240        NodePtr node = root;
241        while (node) {
242            s.emplace(node);
243            if (key == node->key) return node;
244            node = key < node->key ? node->left : node->right;
245        }
246        return nullptr;
247    }
248
249    PersistentMultiset<T> add(T key, int cnt = 1) {
250        NodePtr it = find(key);
251        if (it != nullptr) {
252            assert(this->root);
253            NodePtr new_root = this->root->copy();
254            NodePtr node = new_root;
255            while (node) {
256                node->cnt_subtree += cnt;
257                if (key == node->key) {
258                    node->cnt += cnt;
259                    break;
260                }
261                node = key < node->key ? node->left->copy() : node->right->copy();
262            }
263            return _new(new_root);
264        }
265        auto [s, t] = _split_node_key(root, key);
266        NodePtr new_node = make_shared<Node>(key, cnt);
267        return _new(_merge_with_root(s, new_node, t));
268    }
269
270    PersistentMultiset<T> remove(T key) {
271        NodePtr it = find(key);
272        if (it != nullptr && it->cnt > 1) {
273            assert(this->root);
274            NodePtr new_root = this->root->copy();
275            NodePtr node = new_root;
276            while (node) {
277                node->cnt_subtree -= cnt;
278                if (key == node->key) {
279                    node->cnt -= cnt;
280                    break;
281                }
282                node = key < node->key ? node->left->copy() : node->right->copy();
283            }
284            return _new(new_root);
285        }
286        auto [s_, t] = _split_node_key(this->root, key);
287        auto [s, tmp] = _pop_right(s_);
288        assert(tmp->key == key);
289        NodePtr root = _merge_node(s, t);
290        return _new(root);
291    }
292
293    bool contains(T key) const {
294        NodePtr node = root;
295        while (node) {
296            if (key == node->key) return true;
297            node = key < node->key ? node->left : node->right;
298        }
299        return false;
300    }
301
302    T get(int k) const {
303        assert(0 <= k && k < len());
304        NodePtr node = root;
305        while (true) {
306            assert(node);
307            int t = node->left ? (node->cnt + node->left->cnt_subtree) : node->cnt;
308            if (t-node->cnt <= k && k < t) return node->key;
309            if (t > k) {
310                node = node->left;
311            } else {
312                k -= t;
313                node = node->right;
314            }
315        }
316    }
317
318    int index(const T &key) const {
319        int k = 0;
320        NodePtr node = root;
321        while (node) {
322            if (key == node->key) {
323                k += node->left ? node->left->cnt_subtree : 0;
324                break;
325            }
326            if (key < node->key) {
327                node = node->left;
328            } else {
329                k += node->left ? (node->left->cnt_subtree + node->cnt) : node->cnt;
330                node = node->right;
331            }
332        }
333        return k;
334    }
335
336    int index_right(const T &key) const {
337        int k = 0;
338        NodePtr node = root;
339        while (node) {
340            if (key == node->key) {
341                k += node->left ? (node->left->cnt_subtree + node->cnt) : node->cnt;
342                break;
343            }
344            if (key < node->key) {
345                node = node->left;
346            } else {
347                k += node->left ? (node->left->cnt_subtree + node->cnt) : node->cnt;
348                node = node->right;
349            }
350        }
351        return k;
352    }
353
354    pair<PersistentMultiset<T>, T> pop(int k) {
355        assert(0 <= k && k < len());
356        auto [s_, t] = _split_node(this->root, k+1);
357        auto [s, tmp] = _pop_right(s_);
358        NodePtr root = _merge_node(s, t);
359        return {_new(root), tmp->key};
360    }
361
362    vector<T> tovector() {
363        NodePtr node = root;
364        stack<NodePtr> s;
365        vector<T> a;
366        a.reserve(len());
367        while (!s.empty() || node) {
368            if (node) {
369                s.emplace(node);
370                node = node->left;
371            } else {
372                node = s.top(); s.pop();
373                a.emplace_back(node->key);
374                node = node->right;
375            }
376        }
377        return a;
378    }
379
380    PersistentMultiset<T> copy() const {
381        return _new(this->root ? this->root->copy() : nullptr);
382    }
383
384    T get(int k) {
385        assert(0 <= k && k < len());
386        NodePtr node = root;
387        while (1) {
388            int t = node->left ? node->left->size : 0;
389            if (t == k) {
390                return node->key;
391            }
392            if (t < k) {
393                k -= t + 1;
394                node = node->right;
395            } else {
396                node = node->left;
397            }
398        }
399    }
400
401    int len() const {
402        return root ? root->size : 0;
403    }
404
405    void check() const {
406        auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
407            int ls = 0, rs = 0;
408            int height = 0;
409            int h;
410            if (node->left) {
411                pair<int, int> res = rec(rec, node->left);
412                ls = res.first;
413                h = res.second;
414                height = max(height, h);
415            }
416            if (node->right) {
417                pair<int, int> res = rec(rec, node->right);
418                rs = res.first;
419                h = res.second;
420                height = max(height, h);
421            }
422            node->balance_check();
423            int s = ls + rs + 1;
424            assert(s == node->size);
425            return {s, height+1};
426        };
427        if (root == nullptr) return;
428        auto [_, h] = rec(rec, root);
429        cerr << PRINT_GREEN << "OK : height=" << h << PRINT_NONE << endl;
430    }
431
432    friend ostream& operator<<(ostream& os, PersistentMultiset<T> &tree) {
433        vector<T> a = tree.tovector();
434        os << "{";
435        for (int i = 0; i < (int)a.size()-1; ++i) {
436            os << a[i] << ", ";
437        }
438        if (!a.empty()) os << a.back();
439        os << "}";
440        return os;
441    }
442};
443}  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/persistent_multiset.cpp”