multiset sum splay

ソースコード

  1#include<bits/stdc++.h>
  2using namespace std;
  3
  namespace titan23 {

  /**
   * @brief Sorted multiset over T backed by a splay tree, with rank-range sums.
   *
   * Elements are kept in non-decreasing order and addressed by 0-indexed rank.
   * Each node caches the size and the sum (built with operator+=) of its
   * subtree, so rank lookups and rank-range sums are amortized O(log n).
   *
   * T must be copyable, comparable with ==, < and >=, and support +=;
   * sum() additionally builds T(0) as the empty sum.
   *
   * Memory: nodes come from `new` and are recycled through an internal free
   * list when removed; they are never deleted. The implicitly generated copy
   * shares nodes (and the free list) with the original, so only one of the
   * copies should be used afterwards.
   */
  template<typename T>
  class MultisetSum {
    private:
      struct Node;
      using NodePtr = Node*;
      std::stack<NodePtr> unused_node;  // removed nodes, reused by add()
      NodePtr root;                     // nullptr when the set is empty

      struct Node {
          int size;                     // number of nodes in this subtree
          NodePtr par, left, right;
          T key, sum;                   // sum = total of the keys in this subtree

          Node() : size(0), par(nullptr), left(nullptr), right(nullptr) {}

          explicit Node(const T &key)
              : size(1), par(nullptr), left(nullptr), right(nullptr), key(key), sum(key) {}

          // Reset a recycled node into a detached single-element node.
          void init(const T &key) {
              this->size = 1;
              this->par = nullptr;
              this->left = nullptr;
              this->right = nullptr;
              this->key = key;
              this->sum = key;
          }

          // Recompute size and sum from this node's key and its children's caches.
          void update() {
              this->size = 1;
              this->sum = this->key;
              if (this->left) {
                  this->size += this->left->size;
                  this->sum += this->left->sum;
              }
              if (this->right) {
                  this->size += this->right->size;
                  this->sum += this->right->sum;
              }
          }

          // Lift the left child into this node's position (this must have a left child).
          void rotate_right() {
              NodePtr u = this->left;
              assert(u);
              this->left = u->right;
              u->right = this;
              if (this->par) {
                  if (this->par->left == this) {
                      this->par->left = u;
                  } else {
                      assert(this->par->right == this);
                      this->par->right = u;
                  }
              }
              u->par = this->par;
              if (this->left) this->left->par = this;
              this->par = u;
              this->update();
              u->update();
          }

          // Lift the right child into this node's position (this must have a right child).
          void rotate_left() {
              NodePtr u = this->right;
              assert(u);
              this->right = u->left;
              u->left = this;
              if (this->par) {
                  if (this->par->left == this) {
                      this->par->left = u;
                  } else {
                      assert(this->par->right == this);
                      this->par->right = u;
                  }
              }
              u->par = this->par;
              if (this->right) this->right->par = this;
              this->par = u;
              this->update();
              u->update();
          }

          // Rotate this node up (zig-zig / zig-zag steps, then a final zig)
          // until it has no parent.
          void splay() {
              while (this->par && this->par->par) {
                  if (this->par->left == this) {
                      if (this->par->par->left == this->par) {
                          this->par->par->rotate_right();
                          this->par->rotate_right();
                      } else {
                          this->par->rotate_right();
                          this->par->rotate_left();
                      }
                  } else {
                      if (this->par->par->right == this->par) {
                          this->par->par->rotate_left();
                          this->par->rotate_left();
                      } else {
                          this->par->rotate_left();
                          this->par->rotate_right();
                      }
                  }
              }
              if (this->par) {
                  if (this->par->left == this) {
                      this->par->rotate_right();
                  } else {
                      this->par->rotate_left();
                  }
              }
              assert(this->par == nullptr);
          }

          // Splay the leftmost node of this subtree to the top and return it.
          NodePtr left_splay() {
              NodePtr node = this;
              while (node->left) node = node->left;
              node->splay();
              assert(node->left == nullptr);
              return node;
          }

          // Splay the rightmost node of this subtree to the top and return it.
          NodePtr right_splay() {
              NodePtr node = this;
              while (node->right) node = node->right;
              node->splay();
              assert(node->right == nullptr);
              return node;
          }
      };

      // Keys of the subtree rooted at `node`, in order (iterative, no recursion).
      static std::vector<T> inorder_keys(NodePtr node) {
          std::vector<T> keys;
          if (node) keys.reserve(node->size);
          std::vector<NodePtr> pending;
          while (node || !pending.empty()) {
              if (node) {
                  pending.push_back(node);
                  node = node->left;
              } else {
                  node = pending.back();
                  pending.pop_back();
                  keys.push_back(node->key);
                  node = node->right;
              }
          }
          return keys;
      }

      // Search for `key` from `node`. Splays and returns a node whose key equals
      // `key` if one is on the search path; otherwise splays and returns the last
      // node visited. Returns nullptr only when `node` is nullptr.
      NodePtr find_splay(NodePtr node, const T &key) {
          NodePtr last = nullptr;
          while (node) {
              if (node->key == key) {
                  node->splay();
                  return node;
              }
              last = node;
              node = (key < node->key) ? node->left : node->right;
          }
          if (last) last->splay();
          return last;
      }

      // Splay and return the node of rank k (0-indexed) in the subtree at `node`.
      // Requires 0 <= k < node->size.
      NodePtr kth_splay(NodePtr node, int k) {
          assert(node != nullptr && 0 <= k && k < node->size);
          while (true) {
              int t = node->left ? node->left->size : 0;
              if (t == k) {
                  node->splay();
                  return node;
              }
              if (t < k) {
                  k -= t + 1;
                  node = node->right;
              } else {
                  node = node->left;
              }
          }
      }

      // Unlink the current root, push it onto the free list and join its subtrees.
      void remove_root() {
          assert(this->root && this->root->par == nullptr);
          unused_node.emplace(this->root);
          NodePtr new_root;
          if (!this->root->left) {
              new_root = this->root->right;
          } else if (!this->root->right) {
              new_root = this->root->left;
          } else {
              new_root = this->root->left;
              new_root->par = nullptr;
              new_root = new_root->right_splay();
              new_root->right = this->root->right;
              new_root->right->par = new_root;
              new_root->update();
          }
          if (new_root) new_root->par = nullptr;
          this->root = new_root;
      }

      explicit MultisetSum(NodePtr root) : root(root) {}

      // Split the tree at `node` so that the left part holds exactly k nodes
      // (k is clamped to [0, size]). Returns (left, right) roots.
      std::pair<NodePtr, NodePtr> split_node_kth(NodePtr node, int k) {
          if (node == nullptr || k <= 0) return std::make_pair(nullptr, node);
          if (k >= node->size) return std::make_pair(node, nullptr);
          node = this->kth_splay(node, k);
          NodePtr left_root = node->left;
          if (left_root) {
              left_root->par = nullptr;
              node->left = nullptr;
              node->update();
          }
          return std::make_pair(left_root, node);
      }

      // Concatenate two trees; every key in `left` must be <= every key in `right`.
      NodePtr merge_node(NodePtr left, NodePtr right) {
          if (left == nullptr) return right;
          if (right == nullptr) return left;
          left = left->right_splay();
          left->right = right;
          right->par = left;
          left->update();
          return left;
      }

      MultisetSum<T> gen(NodePtr root_node) const {
          return MultisetSum<T>(root_node);
      }

    public:
      MultisetSum() : root(nullptr) {}

      // Split into (first k elements, remaining elements). The nodes move into the
      // returned sets, so *this is left empty.
      std::pair<MultisetSum<T>, MultisetSum<T>> split(int k) {
          auto [left, right] = split_node_kth(this->root, k);
          this->root = nullptr;  // nodes now belong to the returned halves
          return std::make_pair(gen(left), gen(right));
      }

      // Append all elements of `other` after those of *this and leave `other` empty.
      // Precondition (not checked): every key in *this <= every key in `other`.
      void merge(MultisetSum<T> &other) {
          assert(&other != this);  // self-merge would link the tree to itself
          this->root = merge_node(this->root, other.root);
          other.root = nullptr;
      }

      // Debug: print the keys of the subtree at `node` as "[k0, k1, ..., ]".
      void print_node(NodePtr node) const {
          const std::vector<T> keys = inorder_keys(node);
          std::cout << "[";
          for (const T &key : keys) {
              std::cout << key << ", ";
          }
          std::cout << "]" << std::endl;
      }

      //! Sum of the elements at ranks [l, r); out-of-range bounds are clamped.
      T sum(int l, int r) {
          NodePtr a, b, c;
          std::tie(b, c) = split_node_kth(this->root, r);
          std::tie(a, b) = split_node_kth(b, l);
          T res = b ? b->sum : T(0);
          this->root = merge_node(merge_node(a, b), c);
          return res;
      }

      // Remove one occurrence of `key` if present; returns whether one was removed.
      bool discard(const T &key) {
          if (this->root == nullptr) return false;
          this->root = this->find_splay(this->root, key);
          if (this->root->key == key) {
              remove_root();
              return true;
          }
          return false;
      }

      // Remove one occurrence of `key`; asserts that it is present.
      void remove(const T &key) {
          assert(this->root != nullptr);
          this->root = this->find_splay(this->root, key);
          assert(this->root->key == key);
          remove_root();
      }

      // Remove and return the element of rank k (requires 0 <= k < len()).
      T pop(int k) {
          assert(this->root != nullptr);
          this->root = this->kth_splay(this->root, k);
          T res = this->root->key;
          remove_root();
          return res;
      }

      // Insert `key` (duplicates allowed), reusing a freed node when available.
      void add(T key) {
          this->root = this->find_splay(this->root, key);
          NodePtr node;
          if (unused_node.empty()) {
              node = new Node(key);
          } else {
              node = unused_node.top();
              unused_node.pop();
              node->init(key);
          }
          if (this->root) {
              // The splayed root is a neighbour of `key` in sorted order, so the new
              // node can be placed directly beside it.
              if (this->root->key >= key) {
                  node->left = this->root->left;
                  if (node->left) node->left->par = node;
                  this->root->left = nullptr;
                  node->right = this->root;
                  node->right->par = node;
              } else {
                  node->right = this->root->right;
                  if (node->right) node->right->par = node;
                  this->root->right = nullptr;
                  node->left = this->root;
                  node->left->par = node;
              }
              this->root->update();
              node->update();
          }
          assert(node->par == nullptr);
          this->root = node;
      }

      // Element of rank k (requires 0 <= k < len()).
      T get(int k) {
          this->root = this->kth_splay(this->root, k);
          return this->root->key;
      }

      int len() const {
          return this->root ? this->root->size : 0;
      }

      // Number of levels in the tree (0 when empty). Iterative so that the
      // O(n)-deep chains a splay tree can form do not overflow the call stack.
      int get_height() const {
          int height = 0;
          std::vector<std::pair<NodePtr, int>> pending;
          if (this->root) pending.emplace_back(this->root, 1);
          while (!pending.empty()) {
              const NodePtr node = pending.back().first;
              const int depth = pending.back().second;
              pending.pop_back();
              height = std::max(height, depth);
              if (node->left) pending.emplace_back(node->left, depth + 1);
              if (node->right) pending.emplace_back(node->right, depth + 1);
          }
          return height;
      }

      // All elements in non-decreasing order.
      std::vector<T> tovector() const {
          return inorder_keys(this->root);
      }

      // Debug: assert that the in-order sequence is non-decreasing.
      void test_sorted() const {
          std::vector<T> a = tovector();
          int n = a.size();
          for (int i = 0; i < n-1; ++i) {
              assert(a[i] <= a[i+1]);
          }
      }

      // Debug: assert parent links and cached subtree sizes are consistent.
      // Iterative for the same stack-depth reason as get_height().
      void test() const {
          std::vector<std::pair<NodePtr, NodePtr>> pending;  // (node, expected parent)
          if (this->root) pending.emplace_back(this->root, nullptr);
          while (!pending.empty()) {
              const NodePtr node = pending.back().first;
              [[maybe_unused]] const NodePtr parent = pending.back().second;
              pending.pop_back();
              assert(node->par == parent);
              [[maybe_unused]] const int expected_size = 1
                  + (node->left ? node->left->size : 0)
                  + (node->right ? node->right->size : 0);
              assert(node->size == expected_size);
              if (node->left) pending.emplace_back(node->left, node);
              if (node->right) pending.emplace_back(node->right, node);
          }
      }

      // Debug: print all elements as "[k0, k1, ..., kn]".
      void print() const {
          std::vector<T> a = tovector();
          int n = a.size();
          std::cout << "[";
          for (int i = 0; i < n-1; ++i) {
              std::cout << a[i] << ", ";
          }
          if (n-1 >= 0) {
              std::cout << a[n-1];
          }
          std::cout << "]" << std::endl;
      }
  };

  }  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/multiset_sum_splay.cpp”