persistent segment tree

ソースコード

#include <cassert>
#include <iostream>
#include <memory>
#include <stack>
#include <utility>
#include <vector>
using namespace std;
  5
  6// PersistentSegmentTree
  7namespace titan23 {
  8
  9    template <class T,
 10              T (*op)(T, T),
 11              T (*e)()>
 12    class PersistentSegmentTree {
 13      private:
 14        struct Node;
 15
 16        using NodePtr = shared_ptr<Node>;
 17        // using NodePtr = Node*;
 18
 19        NodePtr root;
 20
 21        struct Node {
 22            T key, data;
 23            int size;
 24            NodePtr left, right;
 25
 26            Node() : size(0), left(nullptr), right(nullptr) {}
 27            Node(T key) : key(key), data(key), size(1), left(nullptr), right(nullptr) {}
 28
 29            NodePtr copy() {
 30                NodePtr node = make_shared<Node>(this->key);
 31                // NodePtr node = new Node(this->key);
 32                node->data = this->data;
 33                node->size = this->size;
 34                node->left = this->left;
 35                node->right = this->right;
 36                return node;
 37            }
 38
 39            void update() {
 40                this->size = 1;
 41                this->data = this->key;
 42                if (this->left) {
 43                    this->size += this->left->size;
 44                    this->data = op(this->left->data, this->data);
 45                }
 46                if (this->right) {
 47                    this->size += this->right->size;
 48                    this->data = op(this->data, this->right->data);
 49                }
 50            }
 51        };
 52
 53        void _build(const vector<T> &a) {
 54            auto build = [&] (auto &&build, int l, int r) -> NodePtr {
 55                int mid = (l + r) >> 1;
 56                NodePtr node = make_shared<Node>(a[mid]);
 57                // NodePtr node = new Node(a[mid]);
 58                if (l != mid) node->left = build(build, l, mid);
 59                if (mid+1 != r) node->right = build(build, mid+1, r);
 60                node->update();
 61                return node;
 62            };
 63
 64            if (a.empty()) {
 65                this->root = nullptr;
 66                return;
 67            }
 68            this->root = build(build, 0, (int)a.size());
 69        }
 70
 71        PersistentSegmentTree(NodePtr root) : root(root) {}
 72
 73        PersistentSegmentTree<T, op, e> _new(NodePtr root) const {
 74            return PersistentSegmentTree<T, op, e>(root);;
 75        }
 76
 77      public:
 78        PersistentSegmentTree() : root(nullptr) {}
 79        PersistentSegmentTree(const vector<T> a) {
 80            _build(a);
 81        }
 82
 83        T prod(int l, int r) const {
 84            assert(0 <= l && l <= r && r <= len());
 85
 86            auto dfs = [&] (auto &&dfs, NodePtr node, int left, int right) -> T {
 87                if (right <= l || r <= left) return e();
 88                if (l <= left && right < r) return node->data;
 89                int lsize = node->left ? node->left->size : 0;
 90                T res = e();
 91                if (node->left) {
 92                    res = dfs(dfs, node->left, left, left+lsize);
 93                }
 94                if (l <= left + lsize && left + lsize < r) {
 95                    res = op(res, node->key);
 96                }
 97                if (node->right) {
 98                    res = op(res, dfs(dfs, node->right, left+lsize+1, right));
 99                }
100                return res;
101            };
102
103            return dfs(dfs, this->root, 0, len());
104        }
105
106        PersistentSegmentTree<T, op, e> set(int k, T v) const {
107            assert(this->root);
108            NodePtr node = this->root->copy();
109            NodePtr nroot = node;
110            NodePtr pnode = nullptr;
111            int d = 0;
112            stack<NodePtr> path;
113            path.emplace(node);
114            while (true) {
115                int t = (node->left) ? node->left->size : 0;
116                if (t == k) {
117                    node = node->copy();
118                    node->key = v;
119                    path.emplace(node);
120                    if (pnode) {
121                        if (d) {
122                            pnode->left = node;
123                        } else {
124                            pnode->right = node;
125                        }
126                    } else {
127                        nroot = node;
128                    }
129                    while (!path.empty()) {
130                        node = path.top();
131                        path.pop();
132                        node->update();
133                    }
134                    return _new(nroot);
135                }
136
137                pnode = node;
138                if (t < k) {
139                    k -= t + 1;
140                    d = 0;
141                    node = node->right->copy();
142                    pnode->right = node;
143                } else {
144                    d = 1;
145                    node = node->left->copy();
146                    pnode->left = node;
147                }
148                path.emplace(node);
149            }
150        }
151
152        T get(int k) const {
153            assert(0 <= k && k < len());
154            assert(this->root);
155            NodePtr node = this->root;
156            while (true) {
157                int t = node->left ? node->left->size : 0;
158                if (t == k) {
159                    return node->key;
160                }
161                if (t < k) {
162                    k -= t + 1;
163                    node = node->right;
164                } else {
165                    node = node->left;
166                }
167            }
168        }
169
170        PersistentSegmentTree<T, op, e> copy() {
171            return _new(this->root ? this->root->copy() : nullptr);
172        }
173
174        vector<T> tolist() const {
175            vector<T> a;
176            a.reserve(len());
177            NodePtr node = root;
178            stack<NodePtr> s;
179            while (!s.empty() || node) {
180                if (node) {
181                    s.emplace(node);
182                    node = node->left;
183                } else {
184                    node = s.top();
185                    s.pop();
186                    a.emplace_back(node->key);
187                    node = node->right;
188                }
189            }
190            return a;
191        }
192
193        int len() const {
194            return this->root ? this->root->size : 0;
195        }
196
197        void print() const {
198            vector<T> a = tolist();
199            cout << "[";
200            for (int i = 0; i < (int)a.size(); ++i) {
201                cout << a[i];
202                if (i != (int)a.size()-1) {
203                    cout << ", ";
204                }
205            }
206            cout << "]" << endl;
207        }
208    };
209}  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/persistent_segment_tree.cpp”