Persistent multiset
ソースコード
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <memory>
#include <stack>
#include <utility>
#include <vector>
#include "titan_cpplib/others/print.cpp"
using namespace std;

10namespace titan23 {
11
12template <typename T>
13class PersistentMultiset {
14 private:
15 class Node;
16 using NodePtr = shared_ptr<Node>;
17 // using NodePtr = Node*;
18 static constexpr int DELTA = 3;
19 static constexpr int GAMMA = 2;
20 NodePtr root;
21
22 class Node {
23 public:
24 T key;
25 int size, cnt, cnt_subtree;
26 NodePtr left;
27 NodePtr right;
28
29 Node(T key, int cnt = 1) : key(key), size(1), cnt(cnt), cnt_subtree(cnt), left(nullptr), right(nullptr) {}
30
31 NodePtr copy() const {
32 NodePtr node = make_shared<Node>(key);
33 node->size = size;
34 node->cnt = cnt;
35 node->cnt_subtree = cnt_subtree;
36 node->left = left;
37 node->right = right;
38 return node;
39 }
40
41 int weight_right() const {
42 return right ? right->size + 1 : 1;
43 }
44
45 int weight_left() const {
46 return left ? left->size + 1 : 1;
47 }
48
49 void update() {
50 size = 1;
51 cnt_subtree = cnt;
52 if (left) {
53 size += left->size;
54 cnt_subtree += left->cnt_subtree;
55 }
56 if (right) {
57 size += right->size;
58 cnt_subtree += right->cnt_subtree;
59 }
60 }
61
62 void balance_check() const {
63 if (!weight_left()*DELTA >= weight_right()) {
64 cerr << weight_left() << ", " << weight_right() << endl;
65 cerr << "not weight_left()*DELTA >= weight_right()." << endl;
66 assert(false);
67 }
68 if (!weight_right() * DELTA >= weight_left()) {
69 cerr << weight_left() << ", " << weight_right() << endl;
70 cerr << "not weight_right() * DELTA >= weight_left()." << endl;
71 assert(false);
72 }
73 }
74
75 void print() const {
76 vector<T> a;
77 auto dfs = [&] (auto &&dfs, const Node* node) -> void {
78 if (!node) return;
79 if (node->left) dfs(dfs, node->left.get());
80 a.emplace_back(node->key);
81 if (node->right) dfs(dfs, node->right.get());
82 };
83 dfs(dfs, this);
84 cerr << a << endl;
85 }
86
87 void debug() const {
88 cout << "this : key=" << key << ", size=" << size << endl;
89 if (left) cout << "to-left" << endl;
90 if (right) cout << "to-right" << endl;
91 cout << endl;
92 if (left) left->print();
93 if (right) right->print();
94 }
95 };
96
97 void _build(vector<T> a) {
98 auto build = [&] (auto &&build, int l, int r) -> NodePtr {
99 int mid = (l + r) >> 1;
100 NodePtr node = make_shared<Node>(a[mid]);
101 if (l != mid) node->left = build(build, l, mid);
102 if (mid+1 != r) node->right = build(build, mid+1, r);
103 node->update();
104 return node;
105 };
106 sort(a.begin(), a.end());
107 root = build(build, 0, (int)a.size());
108 }
109
110 NodePtr _rotate_right(NodePtr &node) {
111 NodePtr u = node->left->copy();
112 node->left = u->right;
113 u->right = node;
114 node->update();
115 u->update();
116 return u;
117 }
118
119 NodePtr _rotate_left(NodePtr &node) {
120 NodePtr u = node->right->copy();
121 node->right = u->left;
122 u->left = node;
123 node->update();
124 u->update();
125 return u;
126 }
127
128 NodePtr _balance_left(NodePtr &node) {
129 node->right = node->right->copy();
130 NodePtr u = node->right;
131 if (node->right->weight_left() >= node->right->weight_right() * GAMMA) {
132 node->right = _rotate_right(u);
133 }
134 u = _rotate_left(node);
135 return u;
136 }
137
138 NodePtr _balance_right(NodePtr &node) {
139 node->left = node->left->copy();
140 NodePtr u = node->left;
141 if (node->left->weight_right() >= node->left->weight_left() * GAMMA) {
142 node->left = _rotate_left(u);
143 }
144 u = _rotate_right(node);
145 return u;
146 }
147
148 int weight(NodePtr node) const {
149 return node ? node->size + 1 : 1;
150 }
151
152 NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
153 if (weight(r) * DELTA < weight(l)) {
154 l = l->copy();
155 l->right = _merge_with_root(l->right, root, r);
156 l->update();
157 if (weight(l->left) * DELTA < weight(l->right)) {
158 return _balance_left(l);
159 }
160 return l;
161 } else if (weight(l) * DELTA < weight(r)) {
162 r = r->copy();
163 r->left = _merge_with_root(l, root, r->left);
164 r->update();
165 if (weight(r->right) * DELTA < weight(r->left)) {
166 return _balance_right(r);
167 }
168 return r;
169 }
170 root = root->copy();
171 root->left = l;
172 root->right = r;
173 root->update();
174 return root;
175 }
176
177 pair<NodePtr, NodePtr> _pop_right(NodePtr &node) {
178 return _split_node_idx(node, node->size-1);
179 }
180
181 NodePtr _merge_node(NodePtr l, NodePtr r) {
182 if ((!l) && (!r)) {return nullptr;}
183 if (!l) {return r->copy();}
184 if (!r) {return l->copy();}
185 l = l->copy();
186 r = r->copy();
187 auto [l_, root_] = _pop_right(l);
188 return _merge_with_root(l_, root_, r);
189 }
190
191 pair<NodePtr, NodePtr> _split_node_key(NodePtr &node, const T &key) {
192 if (!node) { return {nullptr, nullptr}; }
193 if (node->key == key) {
194 return {_merge_with_root(node->left, node, nullptr), node->right};
195 } else if (node->key > key) {
196 auto [l, r] = _split_node_key(node->left, key);
197 return {l, _merge_with_root(r, node, node->right)};
198 } else {
199 auto [l, r] = _split_node_key(node->right, key);
200 return {_merge_with_root(node->left, node, l), r};
201 }
202 }
203
204 pair<NodePtr, NodePtr> _split_node_idx(NodePtr &node, int k) {
205 if (!node) {return {nullptr, nullptr};}
206 int tmp = node->left ? k-node->left->size : k;
207 if (tmp == 0) {
208 return {node->left, _merge_with_root(nullptr, node, node->right)};
209 } else if (tmp < 0) {
210 auto [l, r] = _split_node_idx(node->left, k);
211 return {l, _merge_with_root(r, node, node->right)};
212 } else {
213 auto [l, r] = _split_node_idx(node->right, tmp-1);
214 return {_merge_with_root(node->left, node, l), r};
215 }
216 }
217
218 PersistentMultiset<T> _new(NodePtr root) const {
219 return PersistentMultiset<T>(root);
220 }
221
222 PersistentMultiset(NodePtr root) : root(root) {}
223
224 public:
225 PersistentMultiset() : root(nullptr) {}
226
227 PersistentMultiset(vector<T> &a) { _build(a); }
228
229 PersistentMultiset<T> merge(PersistentMultiset<T> other) {
230 NodePtr root = _merge_node(this->root, other.root);
231 return _new(root);
232 }
233
234 pair<PersistentMultiset<T>, PersistentMultiset<T>> split(int k) {
235 auto [l, r] = _split_node(this->root, k);
236 return {_new(l), _new(r)};
237 }
238
239 NodePtr find(T key) const {
240 NodePtr node = root;
241 while (node) {
242 s.emplace(node);
243 if (key == node->key) return node;
244 node = key < node->key ? node->left : node->right;
245 }
246 return nullptr;
247 }
248
249 PersistentMultiset<T> add(T key, int cnt = 1) {
250 NodePtr it = find(key);
251 if (it != nullptr) {
252 assert(this->root);
253 NodePtr new_root = this->root->copy();
254 NodePtr node = new_root;
255 while (node) {
256 node->cnt_subtree += cnt;
257 if (key == node->key) {
258 node->cnt += cnt;
259 break;
260 }
261 node = key < node->key ? node->left->copy() : node->right->copy();
262 }
263 return _new(new_root);
264 }
265 auto [s, t] = _split_node_key(root, key);
266 NodePtr new_node = make_shared<Node>(key, cnt);
267 return _new(_merge_with_root(s, new_node, t));
268 }
269
270 PersistentMultiset<T> remove(T key) {
271 NodePtr it = find(key);
272 if (it != nullptr && it->cnt > 1) {
273 assert(this->root);
274 NodePtr new_root = this->root->copy();
275 NodePtr node = new_root;
276 while (node) {
277 node->cnt_subtree -= cnt;
278 if (key == node->key) {
279 node->cnt -= cnt;
280 break;
281 }
282 node = key < node->key ? node->left->copy() : node->right->copy();
283 }
284 return _new(new_root);
285 }
286 auto [s_, t] = _split_node_key(this->root, key);
287 auto [s, tmp] = _pop_right(s_);
288 assert(tmp->key == key);
289 NodePtr root = _merge_node(s, t);
290 return _new(root);
291 }
292
293 bool contains(T key) const {
294 NodePtr node = root;
295 while (node) {
296 if (key == node->key) return true;
297 node = key < node->key ? node->left : node->right;
298 }
299 return false;
300 }
301
302 T get(int k) const {
303 assert(0 <= k && k < len());
304 NodePtr node = root;
305 while (true) {
306 assert(node);
307 int t = node->left ? (node->cnt + node->left->cnt_subtree) : node->cnt;
308 if (t-node->cnt <= k && k < t) return node->key;
309 if (t > k) {
310 node = node->left;
311 } else {
312 k -= t;
313 node = node->right;
314 }
315 }
316 }
317
318 int index(const T &key) const {
319 int k = 0;
320 NodePtr node = root;
321 while (node) {
322 if (key == node->key) {
323 k += node->left ? node->left->cnt_subtree : 0;
324 break;
325 }
326 if (key < node->key) {
327 node = node->left;
328 } else {
329 k += node->left ? (node->left->cnt_subtree + node->cnt) : node->cnt;
330 node = node->right;
331 }
332 }
333 return k;
334 }
335
336 int index_right(const T &key) const {
337 int k = 0;
338 NodePtr node = root;
339 while (node) {
340 if (key == node->key) {
341 k += node->left ? (node->left->cnt_subtree + node->cnt) : node->cnt;
342 break;
343 }
344 if (key < node->key) {
345 node = node->left;
346 } else {
347 k += node->left ? (node->left->cnt_subtree + node->cnt) : node->cnt;
348 node = node->right;
349 }
350 }
351 return k;
352 }
353
354 pair<PersistentMultiset<T>, T> pop(int k) {
355 assert(0 <= k && k < len());
356 auto [s_, t] = _split_node(this->root, k+1);
357 auto [s, tmp] = _pop_right(s_);
358 NodePtr root = _merge_node(s, t);
359 return {_new(root), tmp->key};
360 }
361
362 vector<T> tovector() {
363 NodePtr node = root;
364 stack<NodePtr> s;
365 vector<T> a;
366 a.reserve(len());
367 while (!s.empty() || node) {
368 if (node) {
369 s.emplace(node);
370 node = node->left;
371 } else {
372 node = s.top(); s.pop();
373 a.emplace_back(node->key);
374 node = node->right;
375 }
376 }
377 return a;
378 }
379
380 PersistentMultiset<T> copy() const {
381 return _new(this->root ? this->root->copy() : nullptr);
382 }
383
384 T get(int k) {
385 assert(0 <= k && k < len());
386 NodePtr node = root;
387 while (1) {
388 int t = node->left ? node->left->size : 0;
389 if (t == k) {
390 return node->key;
391 }
392 if (t < k) {
393 k -= t + 1;
394 node = node->right;
395 } else {
396 node = node->left;
397 }
398 }
399 }
400
401 int len() const {
402 return root ? root->size : 0;
403 }
404
405 void check() const {
406 auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
407 int ls = 0, rs = 0;
408 int height = 0;
409 int h;
410 if (node->left) {
411 pair<int, int> res = rec(rec, node->left);
412 ls = res.first;
413 h = res.second;
414 height = max(height, h);
415 }
416 if (node->right) {
417 pair<int, int> res = rec(rec, node->right);
418 rs = res.first;
419 h = res.second;
420 height = max(height, h);
421 }
422 node->balance_check();
423 int s = ls + rs + 1;
424 assert(s == node->size);
425 return {s, height+1};
426 };
427 if (root == nullptr) return;
428 auto [_, h] = rec(rec, root);
429 cerr << PRINT_GREEN << "OK : height=" << h << PRINT_NONE << endl;
430 }
431
432 friend ostream& operator<<(ostream& os, PersistentMultiset<T> &tree) {
433 vector<T> a = tree.tovector();
434 os << "{";
435 for (int i = 0; i < (int)a.size()-1; ++i) {
436 os << a[i] << ", ";
437 }
438 if (!a.empty()) os << a.back();
439 os << "}";
440 return os;
441 }
442};
443} // namespace titan23
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/persistent_multiset.cpp”