persistent set

ソースコード

  #include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <memory>
#include <stack>
#include <utility>
#include <vector>
#include "titan_cpplib/others/print.cpp"
using namespace std;
  9
 10namespace titan23 {
 11
 12template <typename T>
 13class PersistentSet {
 14  private:
 15    class Node;
 16    using NodePtr = shared_ptr<Node>;
 17    // using NodePtr = Node*;
 18    static constexpr int DELTA = 3;
 19    static constexpr int GAMMA = 2;
 20    NodePtr root;
 21
 22    class Node {
 23      public:
 24        T key;
 25        int size;
 26        NodePtr left;
 27        NodePtr right;
 28
 29        Node(T key) : key(key), size(1), left(nullptr), right(nullptr) {}
 30
 31        NodePtr copy() const {
 32            NodePtr node = make_shared<Node>(key);
 33            node->size = size;
 34            node->left = left;
 35            node->right = right;
 36            return node;
 37        }
 38
 39        int weight_right() const {
 40            return right ? right->size + 1 : 1;
 41        }
 42
 43        int weight_left() const {
 44            return left ? left->size + 1 : 1;
 45        }
 46
 47        void update() {
 48            size = 1;
 49            if (left) size += left->size;
 50            if (right) size += right->size;
 51        }
 52
 53        void balance_check() const {
 54            if (!weight_left()*DELTA >= weight_right()) {
 55                cerr << weight_left() << ", " << weight_right() << endl;
 56                cerr << "not weight_left()*DELTA >= weight_right()." << endl;
 57                assert(false);
 58            }
 59            if (!weight_right() * DELTA >= weight_left()) {
 60                cerr << weight_left() << ", " << weight_right() << endl;
 61                cerr << "not weight_right() * DELTA >= weight_left()." << endl;
 62                assert(false);
 63            }
 64        }
 65
 66        void print() const {
 67            vector<T> a;
 68            auto dfs = [&] (auto &&dfs, const Node* node) -> void {
 69                if (!node) return;
 70                if (node->left)  dfs(dfs, node->left.get());
 71                a.emplace_back(node->key);
 72                if (node->right) dfs(dfs, node->right.get());
 73            };
 74            dfs(dfs, this);
 75            cerr << a << endl;
 76        }
 77
 78        void debug() const {
 79            cout << "this : key=" << key << ", size=" << size << endl;
 80            if (left)  cout << "to-left" << endl;
 81            if (right) cout << "to-right" << endl;
 82            cout << endl;
 83            if (left)  left->print();
 84            if (right) right->print();
 85        }
 86    };
 87
 88    void _build(vector<T> a) {
 89        auto build = [&] (auto &&build, int l, int r) -> NodePtr {
 90            int mid = (l + r) >> 1;
 91            NodePtr node = make_shared<Node>(a[mid]);
 92            if (l != mid) node->left = build(build, l, mid);
 93            if (mid+1 != r) node->right = build(build, mid+1, r);
 94            node->update();
 95            return node;
 96        };
 97        sort(a.begin(), a.end());
 98        root = build(build, 0, (int)a.size());
 99    }
100
101    NodePtr _rotate_right(NodePtr &node) {
102        NodePtr u = node->left->copy();
103        node->left = u->right;
104        u->right = node;
105        node->update();
106        u->update();
107        return u;
108    }
109
110    NodePtr _rotate_left(NodePtr &node) {
111        NodePtr u = node->right->copy();
112        node->right = u->left;
113        u->left = node;
114        node->update();
115        u->update();
116        return u;
117    }
118
119    NodePtr _balance_left(NodePtr &node) {
120        node->right = node->right->copy();
121        NodePtr u = node->right;
122        if (node->right->weight_left() >= node->right->weight_right() * GAMMA) {
123            node->right = _rotate_right(u);
124        }
125        u = _rotate_left(node);
126        return u;
127    }
128
129    NodePtr _balance_right(NodePtr &node) {
130        node->left = node->left->copy();
131        NodePtr u = node->left;
132        if (node->left->weight_right() >= node->left->weight_left() * GAMMA) {
133            node->left = _rotate_left(u);
134        }
135        u = _rotate_right(node);
136        return u;
137    }
138
139    int weight(NodePtr node) const {
140        return node ? node->size + 1 : 1;
141    }
142
143    NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
144        if (weight(r) * DELTA < weight(l)) {
145            l = l->copy();
146            l->right = _merge_with_root(l->right, root, r);
147            l->update();
148            if (weight(l->left) * DELTA < weight(l->right)) {
149                return _balance_left(l);
150            }
151            return l;
152        } else if (weight(l) * DELTA < weight(r)) {
153            r = r->copy();
154            r->left = _merge_with_root(l, root, r->left);
155            r->update();
156            if (weight(r->right) * DELTA < weight(r->left)) {
157                return _balance_right(r);
158            }
159            return r;
160        }
161        root = root->copy();
162        root->left = l;
163        root->right = r;
164        root->update();
165        return root;
166    }
167
168    pair<NodePtr, NodePtr> _pop_right(NodePtr &node) {
169        return _split_node_idx(node, node->size-1);
170    }
171
172    NodePtr _merge_node(NodePtr l, NodePtr r) {
173        if ((!l) && (!r)) {return nullptr;}
174        if (!l) {return r->copy();}
175        if (!r) {return l->copy();}
176        l = l->copy();
177        r = r->copy();
178        auto [l_, root_] = _pop_right(l);
179        return _merge_with_root(l_, root_, r);
180    }
181
182    pair<NodePtr, NodePtr> _split_node_key(NodePtr &node, const T &key) {
183        if (!node) { return {nullptr, nullptr}; }
184        if (node->key == key) {
185            return {_merge_with_root(node->left, node, nullptr), node->right};
186        } else if (node->key > key) {
187            auto [l, r] = _split_node_key(node->left, key);
188            return {l, _merge_with_root(r, node, node->right)};
189        } else {
190            auto [l, r] = _split_node_key(node->right, key);
191            return {_merge_with_root(node->left, node, l), r};
192        }
193    }
194
195    pair<NodePtr, NodePtr> _split_node_idx(NodePtr &node, int k) {
196        if (!node) {return {nullptr, nullptr};}
197        int tmp = node->left ? k-node->left->size : k;
198        if (tmp == 0) {
199            return {node->left, _merge_with_root(nullptr, node, node->right)};
200        } else if (tmp < 0) {
201            auto [l, r] = _split_node_idx(node->left, k);
202            return {l, _merge_with_root(r, node, node->right)};
203        } else {
204            auto [l, r] = _split_node_idx(node->right, tmp-1);
205            return {_merge_with_root(node->left, node, l), r};
206        }
207    }
208
209    PersistentSet<T> _new(NodePtr root) const {
210        return PersistentSet<T>(root);
211    }
212
213    PersistentSet(NodePtr root) : root(root) {}
214
215  public:
216    PersistentSet() : root(nullptr) {}
217
218    PersistentSet(vector<T> &a) { _build(a); }
219
220    PersistentSet<T> merge(PersistentSet<T> other) {
221        NodePtr root = _merge_node(this->root, other.root);
222        return _new(root);
223    }
224
225    pair<PersistentSet<T>, PersistentSet<T>> split(int k) {
226        auto [l, r] = _split_node(this->root, k);
227        return {_new(l), _new(r)};
228    }
229
230    PersistentSet<T> add(T key) {
231        if (contains(key)) return _new(this->root->copy());
232        auto [s, t] = _split_node_key(root, key);
233        NodePtr new_node = make_shared<Node>(key);
234        return _new(_merge_with_root(s, new_node, t));
235    }
236
237    PersistentSet<T> remove(T key) {
238        if (!contains(key)) return _new(this->root ? this->root->copy() : nullptr);
239        auto [s_, t] = _split_node_key(this->root, key);
240        auto [s, tmp] = _pop_right(s_);
241        assert(tmp->key == key);
242        NodePtr root = _merge_node(s, t);
243        return _new(root);
244    }
245
246    bool contains(T key) const {
247        NodePtr node = root;
248        while (node) {
249            if (key == node->key) return true;
250            node = key < node->key ? node->left : node->right;
251        }
252        return false;
253    }
254
255    T get(int k) const {
256        assert(0 <= k && k < len());
257        NodePtr node = root;
258        while (true) {
259            assert(node);
260            int t = node->left ? (1 + node->left->size) : 1;
261            if (t-1 <= k && k < t) return node->key;
262            if (t > k) {
263                node = node->left;
264            } else {
265                k -= t;
266                node = node->right;
267            }
268        }
269    }
270
271    int index(const T &key) const {
272        int k = 0;
273        NodePtr node = root;
274        while (node) {
275            if (key == node->key) {
276                k += node->left ? node->left->size : 0;
277                break;
278            }
279            if (key < node->key) {
280                node = node->left;
281            } else {
282                k += node->left ? (node->left->size + 1) : 1;
283                node = node->right;
284            }
285        }
286        return k;
287    }
288
289    int index_right(const T &key) const {
290        int k = 0;
291        NodePtr node = root;
292        while (node) {
293            if (key == node->key) {
294                k += node->left ? (node->left->size + 1) : 1;
295                break;
296            }
297            if (key < node->key) {
298                node = node->left;
299            } else {
300                k += node->left ? (node->left->size + 1) : 1;
301                node = node->right;
302            }
303        }
304        return k;
305    }
306
307    pair<PersistentSet<T>, T> pop(int k) {
308        assert(0 <= k && k < len());
309        auto [s_, t] = _split_node(this->root, k+1);
310        auto [s, tmp] = _pop_right(s_);
311        NodePtr root = _merge_node(s, t);
312        return {_new(root), tmp->key};
313    }
314
315    vector<T> tovector() {
316        NodePtr node = root;
317        stack<NodePtr> s;
318        vector<T> a;
319        a.reserve(len());
320        while (!s.empty() || node) {
321            if (node) {
322                s.emplace(node);
323                node = node->left;
324            } else {
325                node = s.top(); s.pop();
326                a.emplace_back(node->key);
327                node = node->right;
328            }
329        }
330        return a;
331    }
332
333    PersistentSet<T> copy() const {
334        return _new(this->root ? this->root->copy() : nullptr);
335    }
336
337    T get(int k) {
338        assert(0 <= k && k < len());
339        NodePtr node = root;
340        while (1) {
341            int t = node->left ? node->left->size : 0;
342            if (t == k) {
343                return node->key;
344            }
345            if (t < k) {
346                k -= t + 1;
347                node = node->right;
348            } else {
349                node = node->left;
350            }
351        }
352    }
353
354    int len() const {
355        return root ? root->size : 0;
356    }
357
358    void check() const {
359        auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
360            int ls = 0, rs = 0;
361            int height = 0;
362            int h;
363            if (node->left) {
364                pair<int, int> res = rec(rec, node->left);
365                ls = res.first;
366                h = res.second;
367                height = max(height, h);
368            }
369            if (node->right) {
370                pair<int, int> res = rec(rec, node->right);
371                rs = res.first;
372                h = res.second;
373                height = max(height, h);
374            }
375            node->balance_check();
376            int s = ls + rs + 1;
377            assert(s == node->size);
378            return {s, height+1};
379        };
380        if (root == nullptr) return;
381        auto [_, h] = rec(rec, root);
382        cerr << PRINT_GREEN << "OK : height=" << h << PRINT_NONE << endl;
383    }
384
385    friend ostream& operator<<(ostream& os, PersistentSet<T> &tree) {
386        vector<T> a = tree.tovector();
387        os << "{";
388        for (int i = 0; i < (int)a.size()-1; ++i) {
389            os << a[i] << ", ";
390        }
391        if (!a.empty()) os << a.back();
392        os << "}";
393        return os;
394    }
395};
396}  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/persistent_set.cpp”