Persistent set
ソースコード (Source code)
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <memory>
#include <stack>
#include <utility>
#include <vector>
#include "titan_cpplib/others/print.cpp"
using namespace std;

10namespace titan23 {
11
12template <typename T>
13class PersistentSet {
14 private:
15 class Node;
16 using NodePtr = shared_ptr<Node>;
17 // using NodePtr = Node*;
18 static constexpr int DELTA = 3;
19 static constexpr int GAMMA = 2;
20 NodePtr root;
21
22 class Node {
23 public:
24 T key;
25 int size;
26 NodePtr left;
27 NodePtr right;
28
29 Node(T key) : key(key), size(1), left(nullptr), right(nullptr) {}
30
31 NodePtr copy() const {
32 NodePtr node = make_shared<Node>(key);
33 node->size = size;
34 node->left = left;
35 node->right = right;
36 return node;
37 }
38
39 int weight_right() const {
40 return right ? right->size + 1 : 1;
41 }
42
43 int weight_left() const {
44 return left ? left->size + 1 : 1;
45 }
46
47 void update() {
48 size = 1;
49 if (left) size += left->size;
50 if (right) size += right->size;
51 }
52
53 void balance_check() const {
54 if (!weight_left()*DELTA >= weight_right()) {
55 cerr << weight_left() << ", " << weight_right() << endl;
56 cerr << "not weight_left()*DELTA >= weight_right()." << endl;
57 assert(false);
58 }
59 if (!weight_right() * DELTA >= weight_left()) {
60 cerr << weight_left() << ", " << weight_right() << endl;
61 cerr << "not weight_right() * DELTA >= weight_left()." << endl;
62 assert(false);
63 }
64 }
65
66 void print() const {
67 vector<T> a;
68 auto dfs = [&] (auto &&dfs, const Node* node) -> void {
69 if (!node) return;
70 if (node->left) dfs(dfs, node->left.get());
71 a.emplace_back(node->key);
72 if (node->right) dfs(dfs, node->right.get());
73 };
74 dfs(dfs, this);
75 cerr << a << endl;
76 }
77
78 void debug() const {
79 cout << "this : key=" << key << ", size=" << size << endl;
80 if (left) cout << "to-left" << endl;
81 if (right) cout << "to-right" << endl;
82 cout << endl;
83 if (left) left->print();
84 if (right) right->print();
85 }
86 };
87
88 void _build(vector<T> a) {
89 auto build = [&] (auto &&build, int l, int r) -> NodePtr {
90 int mid = (l + r) >> 1;
91 NodePtr node = make_shared<Node>(a[mid]);
92 if (l != mid) node->left = build(build, l, mid);
93 if (mid+1 != r) node->right = build(build, mid+1, r);
94 node->update();
95 return node;
96 };
97 sort(a.begin(), a.end());
98 root = build(build, 0, (int)a.size());
99 }
100
101 NodePtr _rotate_right(NodePtr &node) {
102 NodePtr u = node->left->copy();
103 node->left = u->right;
104 u->right = node;
105 node->update();
106 u->update();
107 return u;
108 }
109
110 NodePtr _rotate_left(NodePtr &node) {
111 NodePtr u = node->right->copy();
112 node->right = u->left;
113 u->left = node;
114 node->update();
115 u->update();
116 return u;
117 }
118
119 NodePtr _balance_left(NodePtr &node) {
120 node->right = node->right->copy();
121 NodePtr u = node->right;
122 if (node->right->weight_left() >= node->right->weight_right() * GAMMA) {
123 node->right = _rotate_right(u);
124 }
125 u = _rotate_left(node);
126 return u;
127 }
128
129 NodePtr _balance_right(NodePtr &node) {
130 node->left = node->left->copy();
131 NodePtr u = node->left;
132 if (node->left->weight_right() >= node->left->weight_left() * GAMMA) {
133 node->left = _rotate_left(u);
134 }
135 u = _rotate_right(node);
136 return u;
137 }
138
139 int weight(NodePtr node) const {
140 return node ? node->size + 1 : 1;
141 }
142
143 NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
144 if (weight(r) * DELTA < weight(l)) {
145 l = l->copy();
146 l->right = _merge_with_root(l->right, root, r);
147 l->update();
148 if (weight(l->left) * DELTA < weight(l->right)) {
149 return _balance_left(l);
150 }
151 return l;
152 } else if (weight(l) * DELTA < weight(r)) {
153 r = r->copy();
154 r->left = _merge_with_root(l, root, r->left);
155 r->update();
156 if (weight(r->right) * DELTA < weight(r->left)) {
157 return _balance_right(r);
158 }
159 return r;
160 }
161 root = root->copy();
162 root->left = l;
163 root->right = r;
164 root->update();
165 return root;
166 }
167
168 pair<NodePtr, NodePtr> _pop_right(NodePtr &node) {
169 return _split_node_idx(node, node->size-1);
170 }
171
172 NodePtr _merge_node(NodePtr l, NodePtr r) {
173 if ((!l) && (!r)) {return nullptr;}
174 if (!l) {return r->copy();}
175 if (!r) {return l->copy();}
176 l = l->copy();
177 r = r->copy();
178 auto [l_, root_] = _pop_right(l);
179 return _merge_with_root(l_, root_, r);
180 }
181
182 pair<NodePtr, NodePtr> _split_node_key(NodePtr &node, const T &key) {
183 if (!node) { return {nullptr, nullptr}; }
184 if (node->key == key) {
185 return {_merge_with_root(node->left, node, nullptr), node->right};
186 } else if (node->key > key) {
187 auto [l, r] = _split_node_key(node->left, key);
188 return {l, _merge_with_root(r, node, node->right)};
189 } else {
190 auto [l, r] = _split_node_key(node->right, key);
191 return {_merge_with_root(node->left, node, l), r};
192 }
193 }
194
195 pair<NodePtr, NodePtr> _split_node_idx(NodePtr &node, int k) {
196 if (!node) {return {nullptr, nullptr};}
197 int tmp = node->left ? k-node->left->size : k;
198 if (tmp == 0) {
199 return {node->left, _merge_with_root(nullptr, node, node->right)};
200 } else if (tmp < 0) {
201 auto [l, r] = _split_node_idx(node->left, k);
202 return {l, _merge_with_root(r, node, node->right)};
203 } else {
204 auto [l, r] = _split_node_idx(node->right, tmp-1);
205 return {_merge_with_root(node->left, node, l), r};
206 }
207 }
208
209 PersistentSet<T> _new(NodePtr root) const {
210 return PersistentSet<T>(root);
211 }
212
213 PersistentSet(NodePtr root) : root(root) {}
214
215 public:
216 PersistentSet() : root(nullptr) {}
217
218 PersistentSet(vector<T> &a) { _build(a); }
219
220 PersistentSet<T> merge(PersistentSet<T> other) {
221 NodePtr root = _merge_node(this->root, other.root);
222 return _new(root);
223 }
224
225 pair<PersistentSet<T>, PersistentSet<T>> split(int k) {
226 auto [l, r] = _split_node(this->root, k);
227 return {_new(l), _new(r)};
228 }
229
230 PersistentSet<T> add(T key) {
231 if (contains(key)) return _new(this->root->copy());
232 auto [s, t] = _split_node_key(root, key);
233 NodePtr new_node = make_shared<Node>(key);
234 return _new(_merge_with_root(s, new_node, t));
235 }
236
237 PersistentSet<T> remove(T key) {
238 if (!contains(key)) return _new(this->root ? this->root->copy() : nullptr);
239 auto [s_, t] = _split_node_key(this->root, key);
240 auto [s, tmp] = _pop_right(s_);
241 assert(tmp->key == key);
242 NodePtr root = _merge_node(s, t);
243 return _new(root);
244 }
245
246 bool contains(T key) const {
247 NodePtr node = root;
248 while (node) {
249 if (key == node->key) return true;
250 node = key < node->key ? node->left : node->right;
251 }
252 return false;
253 }
254
255 T get(int k) const {
256 assert(0 <= k && k < len());
257 NodePtr node = root;
258 while (true) {
259 assert(node);
260 int t = node->left ? (1 + node->left->size) : 1;
261 if (t-1 <= k && k < t) return node->key;
262 if (t > k) {
263 node = node->left;
264 } else {
265 k -= t;
266 node = node->right;
267 }
268 }
269 }
270
271 int index(const T &key) const {
272 int k = 0;
273 NodePtr node = root;
274 while (node) {
275 if (key == node->key) {
276 k += node->left ? node->left->size : 0;
277 break;
278 }
279 if (key < node->key) {
280 node = node->left;
281 } else {
282 k += node->left ? (node->left->size + 1) : 1;
283 node = node->right;
284 }
285 }
286 return k;
287 }
288
289 int index_right(const T &key) const {
290 int k = 0;
291 NodePtr node = root;
292 while (node) {
293 if (key == node->key) {
294 k += node->left ? (node->left->size + 1) : 1;
295 break;
296 }
297 if (key < node->key) {
298 node = node->left;
299 } else {
300 k += node->left ? (node->left->size + 1) : 1;
301 node = node->right;
302 }
303 }
304 return k;
305 }
306
307 pair<PersistentSet<T>, T> pop(int k) {
308 assert(0 <= k && k < len());
309 auto [s_, t] = _split_node(this->root, k+1);
310 auto [s, tmp] = _pop_right(s_);
311 NodePtr root = _merge_node(s, t);
312 return {_new(root), tmp->key};
313 }
314
315 vector<T> tovector() {
316 NodePtr node = root;
317 stack<NodePtr> s;
318 vector<T> a;
319 a.reserve(len());
320 while (!s.empty() || node) {
321 if (node) {
322 s.emplace(node);
323 node = node->left;
324 } else {
325 node = s.top(); s.pop();
326 a.emplace_back(node->key);
327 node = node->right;
328 }
329 }
330 return a;
331 }
332
333 PersistentSet<T> copy() const {
334 return _new(this->root ? this->root->copy() : nullptr);
335 }
336
337 T get(int k) {
338 assert(0 <= k && k < len());
339 NodePtr node = root;
340 while (1) {
341 int t = node->left ? node->left->size : 0;
342 if (t == k) {
343 return node->key;
344 }
345 if (t < k) {
346 k -= t + 1;
347 node = node->right;
348 } else {
349 node = node->left;
350 }
351 }
352 }
353
354 int len() const {
355 return root ? root->size : 0;
356 }
357
358 void check() const {
359 auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
360 int ls = 0, rs = 0;
361 int height = 0;
362 int h;
363 if (node->left) {
364 pair<int, int> res = rec(rec, node->left);
365 ls = res.first;
366 h = res.second;
367 height = max(height, h);
368 }
369 if (node->right) {
370 pair<int, int> res = rec(rec, node->right);
371 rs = res.first;
372 h = res.second;
373 height = max(height, h);
374 }
375 node->balance_check();
376 int s = ls + rs + 1;
377 assert(s == node->size);
378 return {s, height+1};
379 };
380 if (root == nullptr) return;
381 auto [_, h] = rec(rec, root);
382 cerr << PRINT_GREEN << "OK : height=" << h << PRINT_NONE << endl;
383 }
384
385 friend ostream& operator<<(ostream& os, PersistentSet<T> &tree) {
386 vector<T> a = tree.tovector();
387 os << "{";
388 for (int i = 0; i < (int)a.size()-1; ++i) {
389 os << a[i] << ", ";
390 }
391 if (!a.empty()) os << a.back();
392 os << "}";
393 return os;
394 }
395};
396} // namespace titan23
仕様 (Specification)
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/persistent_set.cpp” — Doxygen could not locate this source file, so the API specification was not generated.