Persistent segment tree
ソースコード (source code)
#include <cassert>
#include <iostream>
#include <memory>
#include <stack>
#include <vector>
using namespace std;
// PersistentSegmentTree
namespace titan23 {

    /**
     * @brief Persistent sequence supporting point update and range fold.
     *
     * The sequence is stored as a binary tree whose in-order traversal yields
     * the elements. Every node keeps its own element (`key`), the number of
     * nodes in its subtree (`size`) and the fold of its subtree (`data`).
     * The tree is built balanced (midpoint split) and `set` never changes its
     * shape, so `get`, `set` and `prod` each visit O(log n) nodes.
     *
     * Persistence: `set` does not modify `*this`. It copies the nodes on the
     * root-to-target path and returns a new tree that shares every other node
     * with the old one. Nodes are reference counted with std::shared_ptr, so a
     * node lives as long as any version that reaches it.
     *
     * @tparam T  element type
     * @tparam op binary operation used to fold elements (assumed associative)
     * @tparam e  returns the identity element of op
     */
    template <class T,
              T (*op)(T, T),
              T (*e)()>
    class PersistentSegmentTree {
      private:
        struct Node;

        using NodePtr = std::shared_ptr<Node>;

        NodePtr root;

        struct Node {
            T key;   // the element stored at this position
            T data;  // fold (by op) of the whole subtree, in in-order sequence
            int size;  // number of nodes in the subtree
            NodePtr left, right;

            Node() : size(0), left(nullptr), right(nullptr) {}
            Node(T key) : key(key), data(key), size(1), left(nullptr), right(nullptr) {}

            // Shallow copy: the returned node shares both children with this one.
            NodePtr copy() const {
                return std::make_shared<Node>(*this);
            }

            // Recomputes size and data from key and the (already up-to-date) children.
            void update() {
                this->size = 1;
                this->data = this->key;
                if (this->left) {
                    this->size += this->left->size;
                    this->data = op(this->left->data, this->data);
                }
                if (this->right) {
                    this->size += this->right->size;
                    this->data = op(this->data, this->right->data);
                }
            }
        };

        // Builds a balanced tree over `a`: a[mid] becomes the subtree root and
        // both halves are built recursively. An empty `a` yields an empty tree.
        void _build(const std::vector<T> &a) {
            if (a.empty()) {
                this->root = nullptr;
                return;
            }
            auto build = [&](auto &&self, int l, int r) -> NodePtr {
                const int mid = (l + r) >> 1;
                NodePtr node = std::make_shared<Node>(a[mid]);
                if (l != mid) node->left = self(self, l, mid);
                if (mid + 1 != r) node->right = self(self, mid + 1, r);
                node->update();
                return node;
            };
            this->root = build(build, 0, static_cast<int>(a.size()));
        }

        explicit PersistentSegmentTree(NodePtr root) : root(root) {}

        // Wraps an existing root (possibly shared with other versions) in a tree object.
        PersistentSegmentTree<T, op, e> _new(NodePtr root) const {
            return PersistentSegmentTree<T, op, e>(root);
        }

      public:
        /// @brief Creates an empty sequence.
        PersistentSegmentTree() : root(nullptr) {}

        /// @brief Creates a sequence holding the elements of `a` in order.
        PersistentSegmentTree(const std::vector<T> &a) : root(nullptr) {
            _build(a);
        }

        /**
         * @brief Returns op-fold of elements [l, r); e() when l == r.
         * @pre 0 <= l <= r <= len()
         */
        T prod(int l, int r) const {
            assert(0 <= l && l <= r && r <= len());

            // Fold of `node`'s subtree, whose elements occupy indices
            // [left, right), restricted to the query range [l, r).
            auto dfs = [&](auto &&self, const Node *node, int left, int right) -> T {
                if (right <= l || r <= left) return e();          // disjoint
                if (l <= left && right <= r) return node->data;   // fully covered
                const int lsize = node->left ? node->left->size : 0;
                const int pos = left + lsize;  // index of node->key
                T res = e();
                if (node->left) {
                    res = self(self, node->left.get(), left, pos);
                }
                if (l <= pos && pos < r) {
                    res = op(res, node->key);
                }
                if (node->right) {
                    res = op(res, self(self, node->right.get(), pos + 1, right));
                }
                return res;
            };

            // For an empty tree the range [0, 0) is disjoint, so the null
            // root is never dereferenced.
            return dfs(dfs, this->root.get(), 0, len());
        }

        /**
         * @brief Returns a new version in which element k is replaced by v.
         *
         * `*this` is left unchanged; only the nodes on the path from the root
         * to position k are copied, all others are shared.
         * @pre 0 <= k < len()
         */
        PersistentSegmentTree<T, op, e> set(int k, T v) const {
            assert(this->root);
            assert(0 <= k && k < len());
            NodePtr nroot = this->root->copy();
            std::vector<Node*> path;  // freshly copied nodes, root first
            Node *node = nroot.get();
            while (true) {
                path.push_back(node);
                const int t = node->left ? node->left->size : 0;
                if (t == k) break;
                // Replace the child we descend into with a private copy so
                // the nodes of the original version are never mutated.
                if (t < k) {
                    k -= t + 1;
                    node->right = node->right->copy();
                    node = node->right.get();
                } else {
                    node->left = node->left->copy();
                    node = node->left.get();
                }
            }
            node->key = v;
            // Refresh size/data bottom-up along the copied path.
            for (auto it = path.rbegin(); it != path.rend(); ++it) {
                (*it)->update();
            }
            return _new(nroot);
        }

        /**
         * @brief Returns element k.
         * @pre 0 <= k < len()
         */
        T get(int k) const {
            assert(0 <= k && k < len());
            assert(this->root);
            const Node *node = this->root.get();
            while (true) {
                const int t = node->left ? node->left->size : 0;
                if (t == k) {
                    return node->key;
                }
                if (t < k) {
                    k -= t + 1;
                    node = node->right.get();
                } else {
                    node = node->left.get();
                }
            }
        }

        /// @brief Returns a tree with a duplicated root node; all other nodes are shared.
        PersistentSegmentTree<T, op, e> copy() const {
            return _new(this->root ? this->root->copy() : nullptr);
        }

        /// @brief Returns all elements in order (iterative in-order traversal).
        std::vector<T> tolist() const {
            std::vector<T> a;
            a.reserve(len());
            std::stack<const Node*> pending;
            const Node *node = this->root.get();
            while (node || !pending.empty()) {
                // Descend to the leftmost node not yet emitted.
                while (node) {
                    pending.push(node);
                    node = node->left.get();
                }
                node = pending.top();
                pending.pop();
                a.push_back(node->key);
                node = node->right.get();
            }
            return a;
        }

        /// @brief Returns the number of elements.
        int len() const {
            return this->root ? this->root->size : 0;
        }

        /// @brief Writes the elements to std::cout as "[a0, a1, ...]" followed by a newline.
        void print() const {
            const std::vector<T> a = tolist();
            std::cout << "[";
            for (int i = 0; i < static_cast<int>(a.size()); ++i) {
                std::cout << a[i];
                if (i != static_cast<int>(a.size()) - 1) {
                    std::cout << ", ";
                }
            }
            std::cout << "]" << std::endl;
        }
    };

} // namespace titan23
仕様 (specification)
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/persistent_segment_tree.cpp”, so the API specification could not be generated.