multiset sum splay
ソースコード
#include <bits/stdc++.h>
using namespace std;

namespace titan23 {

/**
 * @brief Sorted multiset backed by a splay tree, with sums over ranges of sorted positions.
 *
 * Elements are kept in non-decreasing order (duplicates allowed). Every node caches the
 * size and the sum of its subtree, so sum(l, r) returns the sum of the elements whose
 * 0-based positions in sorted order lie in [l, r).
 *
 * Operations required of T: copy, `<`, `==`, `>=`, `+=`, value-initialisation (`T{}` is
 * the empty sum) and `operator<<` (only for the print helpers).
 *
 * Every operation splays, so even get() and sum() restructure the tree. add / discard /
 * remove / pop / get / sum / split / merge are amortised O(log n); len() is O(1);
 * tovector(), get_height() and test() are O(n).
 *
 * NOTE(review): nodes are created with `new` and never deleted (there is no destructor),
 * and split() hands subtrees to new objects that alias the same nodes. A destructor
 * cannot be added safely until node ownership is made explicit.
 */
template<typename T>
class MultisetSum {
 private:
  struct Node;
  using NodePtr = Node*;
  std::stack<NodePtr> unused_node;  // nodes released by remove_root(), recycled by add()
  NodePtr root;

  struct Node {
    int size;                  // number of nodes in this subtree
    NodePtr par, left, right;  // parent / children, nullptr when absent
    T key, sum;                // sum = total of all keys in this subtree

    Node() : size(0), par(nullptr), left(nullptr), right(nullptr) {}

    explicit Node(T key)
        : size(1), par(nullptr), left(nullptr), right(nullptr), key(key), sum(key) {}

    // Reset a recycled node into a detached one-element subtree holding `key`.
    void init(const T &key) {
      this->size = 1;
      this->par = nullptr;
      this->left = nullptr;
      this->right = nullptr;
      this->key = key;
      this->sum = key;
    }

    // Recompute size and sum from the children; required after any child pointer changes.
    void update() {
      this->size = 1;
      this->sum = this->key;
      if (this->left) {
        this->size += this->left->size;
        this->sum += this->left->sum;
      }
      if (this->right) {
        this->size += this->right->size;
        this->sum += this->right->sum;
      }
    }

    // Lift the left child above this node. The in-order sequence is unchanged and the
    // parent's subtree keeps the same nodes, so only this node and the lifted child
    // need update().
    void rotate_right() {
      NodePtr u = this->left;
      assert(u);
      this->left = u->right;
      u->right = this;
      if (this->par) {
        if (this->par->left == this) {
          this->par->left = u;
        } else {
          assert(this->par->right == this);
          this->par->right = u;
        }
      }
      u->par = this->par;
      if (this->left) this->left->par = this;
      this->par = u;
      this->update();
      u->update();
    }

    // Mirror image of rotate_right(): lift the right child above this node.
    void rotate_left() {
      NodePtr u = this->right;
      assert(u);
      this->right = u->left;
      u->left = this;
      if (this->par) {
        if (this->par->left == this) {
          this->par->left = u;
        } else {
          assert(this->par->right == this);
          this->par->right = u;
        }
      }
      u->par = this->par;
      if (this->right) this->right->par = this;
      this->par = u;
      this->update();
      u->update();
    }

    // Bottom-up splay: rotate this node upward until it has no parent.
    // Zig-zig (same side) rotates the grandparent first, then the parent;
    // zig-zag (opposite sides) lifts this node twice; a final single rotation
    // handles an odd depth.
    void splay() {
      while (this->par && this->par->par) {
        if (this->par->left == this) {
          if (this->par->par->left == this->par) {
            this->par->par->rotate_right();
            this->par->rotate_right();
          } else {
            this->par->rotate_right();
            this->par->rotate_left();
          }
        } else {
          if (this->par->par->right == this->par) {
            this->par->par->rotate_left();
            this->par->rotate_left();
          } else {
            this->par->rotate_left();
            this->par->rotate_right();
          }
        }
      }
      if (this->par) {
        if (this->par->left == this) {
          this->par->rotate_right();
        } else {
          this->par->rotate_left();
        }
      }
      assert(this->par == nullptr);
    }

    // Splay the leftmost (minimum) node of this subtree to the top and return it.
    NodePtr left_splay() {
      NodePtr node = this;
      while (node->left) node = node->left;
      node->splay();
      assert(node->left == nullptr);
      return node;
    }

    // Splay the rightmost (maximum) node of this subtree to the top and return it.
    NodePtr right_splay() {
      NodePtr node = this;
      while (node->right) node = node->right;
      node->splay();
      assert(node->right == nullptr);
      return node;
    }
  };

  // Search the tree rooted at `node` for `key`. If a node with that key exists, it is
  // splayed and returned; otherwise the last node visited (adjacent to where `key`
  // would be inserted) is splayed and returned. Returns nullptr only for an empty tree.
  NodePtr find_splay(NodePtr node, const T &key) {
    NodePtr pnode = nullptr;
    while (node) {
      if (node->key == key) {
        node->splay();
        return node;
      }
      pnode = node;
      if (key < node->key) {
        node = node->left;
      } else {
        node = node->right;
      }
    }
    if (pnode) {
      pnode->splay();
      return pnode;
    }
    return node;
  }

  // Splay and return the node at 0-based sorted position k in the tree rooted at `node`.
  // Requires 0 <= k < node->size.
  NodePtr kth_splay(NodePtr node, int k) {
    while (true) {
      assert(node != nullptr);  // fires when k is out of range
      int t = node->left ? node->left->size : 0;
      if (t == k) {
        node->splay();
        return node;
      }
      if (t < k) {
        k -= t + 1;
        node = node->right;
      } else {
        node = node->left;
      }
    }
  }

  // Delete the current root: it goes to unused_node for reuse and its two subtrees are
  // joined (the maximum of the left subtree becomes the new root).
  void remove_root() {
    assert(this->root && this->root->par == nullptr);
    unused_node.emplace(this->root);
    NodePtr new_root;
    if (!this->root->left) {
      new_root = this->root->right;
    } else if (!this->root->right) {
      new_root = this->root->left;
    } else {
      new_root = this->root->left;
      new_root->par = nullptr;
      new_root = new_root->right_splay();
      new_root->right = this->root->right;
      new_root->right->par = new_root;
      new_root->update();
    }
    if (new_root) new_root->par = nullptr;
    this->root = new_root;
  }

  // Wrap an existing tree (used by split()); the new object shares those nodes.
  explicit MultisetSum(NodePtr root) : root(root) {}

  // Split the tree rooted at `node` (assumed to be a root) into
  // (first k elements, remaining elements); the left part has size == k.
  // k is clamped to [0, size].
  std::pair<NodePtr, NodePtr> split_node_kth(NodePtr node, int k) {
    if (node == nullptr || k <= 0) return std::make_pair(nullptr, node);
    if (k >= node->size) return std::make_pair(node, nullptr);
    node = this->kth_splay(node, k);
    NodePtr left_root = node->left;
    if (left_root) {
      left_root->par = nullptr;
      node->left = nullptr;
      node->update();
    }
    return std::make_pair(left_root, node);
  }

  // Concatenate two trees. The result stays sorted only if every key in `left`
  // is <= every key in `right`.
  NodePtr merge_node(NodePtr left, NodePtr right) {
    if (left == nullptr) return right;
    if (right == nullptr) return left;
    left = left->right_splay();
    left->right = right;
    right->par = left;
    left->update();
    return left;
  }

  MultisetSum<T> gen(NodePtr root_node) const {
    return MultisetSum<T>(root_node);
  }

 public:
  MultisetSum() : root(nullptr) {}

  //! Split into (first k elements, the rest); k is clamped to [0, len()].
  //! The nodes move to the returned objects, so *this becomes empty.
  std::pair<MultisetSum<T>, MultisetSum<T>> split(int k) {
    auto [left, right] = split_node_kth(this->root, k);
    // The old root pointer would now point into the middle of one of the halves.
    this->root = nullptr;
    return std::make_pair(gen(left), gen(right));
  }

  //! Append all elements of `other` after those of *this and leave `other` empty.
  //! Precondition: every key in *this is <= every key in `other` (not checked).
  void merge(MultisetSum<T> &other) {
    assert(this != &other);
    this->root = merge_node(this->root, other.root);
    other.root = nullptr;  // the nodes now belong to *this
  }

  //! Debug helper: print the keys of the subtree rooted at `node` in sorted order.
  void print_node(NodePtr node) {
    std::stack<NodePtr> st;
    std::vector<T> a;
    while ((!st.empty()) || node) {
      if (node) {
        st.emplace(node);
        node = node->left;
      } else {
        node = st.top();
        st.pop();
        a.emplace_back(node->key);
        node = node->right;
      }
    }
    std::cout << "[";
    int n = a.size();
    for (int i = 0; i < n; ++i) {
      std::cout << a[i] << ", ";
    }
    std::cout << "]" << std::endl;
  }

  //! Sum of the elements at sorted positions [l, r); T{} when the range is empty.
  T sum(int l, int r) {
    NodePtr a, b, c;
    std::tie(b, c) = split_node_kth(this->root, r);
    std::tie(a, b) = split_node_kth(b, l);
    T res = b ? b->sum : T{};
    a = merge_node(a, b);
    a = merge_node(a, c);
    this->root = a;
    return res;
  }

  //! Remove one occurrence of `key` if present; returns whether one was removed.
  bool discard(const T &key) {
    if (this->root == nullptr) return false;
    this->root = this->find_splay(this->root, key);
    if (this->root->key == key) {
      remove_root();
      return true;
    }
    return false;
  }

  //! Remove one occurrence of `key`, which must be present (asserted).
  void remove(const T &key) {
    assert(this->root != nullptr);
    this->root = this->find_splay(this->root, key);
    assert(this->root->key == key);
    remove_root();
  }

  //! Remove and return the element at sorted position k (0 <= k < len()).
  T pop(int k) {
    assert(0 <= k && k < len());
    this->root = this->kth_splay(this->root, k);
    T res = this->root->key;
    remove_root();
    return res;
  }

  //! Insert `key`, reusing a previously removed node when one is available.
  void add(T key) {
    this->root = this->find_splay(this->root, key);
    NodePtr node;
    if (unused_node.empty()) {
      node = new Node(key);
    } else {
      node = unused_node.top();
      unused_node.pop();
      node->init(key);
    }
    if (this->root) {
      // Make `node` the new root, placed just before or just after the splayed root.
      if (this->root->key >= key) {
        node->left = this->root->left;
        if (node->left) node->left->par = node;
        this->root->left = nullptr;
        node->right = this->root;
        node->right->par = node;
      } else {
        node->right = this->root->right;
        if (node->right) node->right->par = node;
        this->root->right = nullptr;
        node->left = this->root;
        node->left->par = node;
      }
      this->root->update();
      node->update();
    }
    assert(node->par == nullptr);
    this->root = node;
  }

  //! Element at sorted position k (0 <= k < len()).
  T get(int k) {
    assert(0 <= k && k < len());
    this->root = this->kth_splay(this->root, k);
    return this->root->key;
  }

  //! Number of stored elements.
  int len() const {
    return this->root ? this->root->size : 0;
  }

  //! Height of the tree in nodes (0 when empty). Iterative, because a splay tree can
  //! degenerate into a path of length n (e.g. after inserting keys in increasing order).
  int get_height() const {
    if (this->root == nullptr) return 0;
    int height = 0;
    std::vector<std::pair<NodePtr, int>> pending;
    pending.emplace_back(this->root, 1);
    while (!pending.empty()) {
      auto [node, depth] = pending.back();
      pending.pop_back();
      height = std::max(height, depth);
      if (node->left) pending.emplace_back(node->left, depth + 1);
      if (node->right) pending.emplace_back(node->right, depth + 1);
    }
    return height;
  }

  //! All elements in sorted order.
  std::vector<T> tovector() const {
    NodePtr node = this->root;
    std::stack<NodePtr> st;
    std::vector<T> a;
    a.reserve(len());
    while ((!st.empty()) || node) {
      if (node) {
        st.emplace(node);
        node = node->left;
      } else {
        node = st.top();
        st.pop();
        a.emplace_back(node->key);
        node = node->right;
      }
    }
    return a;
  }

  //! Debug check: the in-order sequence is non-decreasing.
  void test_sorted() const {
    std::vector<T> a = tovector();
    int n = a.size();
    for (int i = 0; i < n-1; ++i) {
      assert(a[i] <= a[i+1]);
    }
  }

  //! Debug check: every node's parent pointer matches the node that links to it.
  //! Iterative for the same degenerate-depth reason as get_height().
  void test() const {
    std::vector<std::pair<NodePtr, NodePtr>> pending;  // (node, expected parent)
    pending.emplace_back(this->root, nullptr);
    while (!pending.empty()) {
      auto [node, pnode] = pending.back();
      pending.pop_back();
      if (node == nullptr) continue;
      assert(node->par == pnode);
      pending.emplace_back(node->right, node);
      pending.emplace_back(node->left, node);
    }
  }

  //! Debug helper: print all elements in sorted order as "[a, b, c]".
  void print() const {
    std::vector<T> a = tovector();
    int n = a.size();
    std::cout << "[";
    for (int i = 0; i < n-1; ++i) {
      std::cout << a[i] << ", ";
    }
    if (n-1 >= 0) {
      std::cout << a[n-1];
    }
    std::cout << "]" << std::endl;
  }
};

}  // namespace titan23
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/multiset_sum_splay.cpp”