AVL Tree Multiset
ソースコード (Source code)
#include <algorithm>
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>
#include "titan_cpplib/data_structures/bbst_node.cpp"
using namespace std;

8// AVLTreeMultiset
9namespace titan23 {
10
11 template<typename T>
12 class AVLTreeMultiset {
13 public:
14 class AVLTreeMultisetNode {
15 public:
16 using AVLTreeMultisetNodePtr = AVLTreeMultisetNode*;
17 T key;
18 int val, valsize, height;
19 AVLTreeMultisetNodePtr par, left, right;
20
21 AVLTreeMultisetNode() {}
22 AVLTreeMultisetNode(const T &key, const int val) : key(key), val(val), valsize(val), height(1), par(nullptr), left(nullptr), right(nullptr) {}
23
24 int balance() const {
25 int hl = left ? left->height : 0;
26 int hr = right ? right->height : 0;
27 return hl - hr;
28 }
29
30 void update() {
31 valsize = val + (left ? left->valsize : 0) + (right ? right->valsize : 0);
32 height = 1 + max((left ? left->height : 0), (right ? right->height : 0));
33 }
34 };
35
36 T missing;
37 using AVLTreeMultisetNodePtr = AVLTreeMultisetNode*;
38 AVLTreeMultisetNodePtr root;
39
40 AVLTreeMultisetNodePtr build(vector<T> a) {
41 vector<T> x;
42 vector<int> y;
43
44 auto _build = [&] (auto &&_build, int l, int r) -> AVLTreeMultisetNodePtr {
45 int mid = (l + r) / 2;
46 AVLTreeMultisetNodePtr node = new AVLTreeMultisetNode(x[mid], y[mid]);
47 if (l != mid) {
48 node->left = _build(_build, l, mid);
49 node->left->par = node;
50 }
51 if (mid+1 != r) {
52 node->right = _build(_build, mid+1, r);
53 node->right->par = node;
54 }
55 node->update();
56 return node;
57 };
58
59 if (a.empty()) return nullptr;
60 int n = a.size();
61 bool is_sorted = true;
62 for (int i = 0; i < n-1; ++i) {
63 if (!(a[i] <= a[i+1])) {
64 is_sorted = false;
65 break;
66 }
67 }
68 if (!is_sorted) {
69 sort(a.begin(), a.end());
70 }
71
72 x = {a[0]};
73 y = {1};
74 for (int i = 1; i < n; ++i) {
75 if (a[i] == x.back()) {
76 ++y.back();
77 continue;
78 }
79 x.emplace_back(a[i]);
80 y.emplace_back(1);
81 }
82 return _build(_build, 0, x.size());
83 }
84
85 void _remove_balance(AVLTreeMultisetNodePtr node) {
86 while (node) {
87 AVLTreeMultisetNodePtr new_node = nullptr;
88 node->update();
89 if (node->balance() == 2) {
90 new_node = node->left->balance() == -1 ? BBSTNode<AVLTreeMultisetNodePtr>::rotate_LR(node) : BBSTNode<AVLTreeMultisetNodePtr>::rotate_right(node);
91 } else if (node->balance() == -2) {
92 new_node = node->right->balance() == 1 ? BBSTNode<AVLTreeMultisetNodePtr>::rotate_RL(node) : BBSTNode<AVLTreeMultisetNodePtr>::rotate_left(node);
93 } else if (node->balance() != 0) {
94 node = node->par;
95 break;
96 }
97 if (!new_node) {
98 node = node->par;
99 continue;
100 }
101 if (!new_node->par) {
102 this->root = new_node;
103 return;
104 }
105 node = new_node->par;
106 if (new_node->key < node->key) {
107 node->left = new_node;
108 } else {
109 node->right = new_node;
110 }
111 if (new_node->balance() != 0) break;
112 }
113 while (node) {
114 node->update();
115 node = node->par;
116 }
117 }
118
119 void _add_balance(AVLTreeMultisetNodePtr node) {
120 AVLTreeMultisetNodePtr new_node = nullptr;
121 while (node) {
122 node->update();
123 if (node->balance() == 0) {
124 node = node->par;
125 break;
126 }
127 if (node->balance() == 2) {
128 new_node = node->left->balance() == -1 ? BBSTNode<AVLTreeMultisetNodePtr>::rotate_LR(node) : BBSTNode<AVLTreeMultisetNodePtr>::rotate_right(node);
129 break;
130 } else if (node->balance() == -2) {
131 new_node = node->right->balance() == 1 ? BBSTNode<AVLTreeMultisetNodePtr>::rotate_RL(node) : BBSTNode<AVLTreeMultisetNodePtr>::rotate_left(node);
132 break;
133 }
134 node = node->par;
135 }
136 if (new_node) {
137 node = new_node->par;
138 if (node) {
139 if (new_node->key < node->key) {
140 node->left = new_node;
141 } else {
142 node->right = new_node;
143 }
144 } else {
145 this->root = new_node;
146 }
147 }
148 while (node) {
149 node->update();
150 node = node->par;
151 }
152 }
153
154 void _update_par(AVLTreeMultisetNodePtr node) {
155 while (node) {
156 node->update();
157 node = node->par;
158 }
159 }
160
161 AVLTreeMultisetNodePtr find_key(const T &key) const {
162 AVLTreeMultisetNodePtr node = root;
163 while (node) {
164 if (key == node->key) return node;
165 node = key < node->key ? node->left : node->right;
166 }
167 return nullptr;
168 }
169
170 AVLTreeMultisetNodePtr find_kth(int k) const {
171 AVLTreeMultisetNodePtr node = root;
172 while (true) {
173 assert(node);
174 int t = node->left ? (node->val + node->left->valsize) : node->val;
175 if (t-node->val <= k && k < t) return node;
176 if (t > k) {
177 node = node->left;
178 } else {
179 k -= t;
180 node = node->right;
181 }
182 }
183 }
184
185 public:
186 AVLTreeMultiset() : root(nullptr) {}
187 AVLTreeMultiset(T missing) : missing(missing), root(nullptr) {}
188 AVLTreeMultiset(vector<T> &a, T missing) : missing(missing) {
189 this->root = build(a);
190 }
191
192 void add(const T &key, int val=1) {
193 if (!root) {
194 root = new AVLTreeMultisetNode(key, val);
195 return;
196 }
197 AVLTreeMultisetNodePtr pnode = nullptr;
198 AVLTreeMultisetNodePtr node = root;
199 while (node) {
200 if (key == node->key) {
201 node->val += val;
202 _update_par(node);
203 return;
204 }
205 pnode = node;
206 node = key < node->key ? node->left : node->right;
207 }
208 if (key < pnode->key) {
209 pnode->left = new AVLTreeMultisetNode(key, val);
210 pnode->left->par = pnode;
211 } else {
212 pnode->right = new AVLTreeMultisetNode(key, val);
213 pnode->right->par = pnode;
214 }
215 _add_balance(pnode);
216 }
217
218 void remove_iter(AVLTreeMultisetNodePtr node) {
219 AVLTreeMultisetNodePtr pnode = node->par;
220 if (node->left && node->right) {
221 pnode = node;
222 AVLTreeMultisetNodePtr mnode = node->left;
223 while (mnode->right) {
224 pnode = mnode;
225 mnode = mnode->right;
226 }
227 node->key = mnode->key;
228 node->val = mnode->val;
229 node = mnode;
230 }
231 AVLTreeMultisetNodePtr cnode = (!node->left) ? node->right : node->left;
232 if (cnode) cnode->par = pnode;
233 if (pnode) {
234 if (node->key <= pnode->key) {
235 pnode->left = cnode;
236 } else {
237 pnode->right = cnode;
238 }
239 _remove_balance(pnode);
240 } else {
241 root = cnode;
242 }
243 }
244
245 bool discard(const T &key, int val=1) {
246 AVLTreeMultisetNodePtr node = find_key(key);
247 if (!node) return false;
248 node->val -= val;
249 if (node->val <= 0) {
250 remove_iter(node);
251 } else {
252 _update_par(node);
253 }
254 return true;
255 }
256
257 void remove(const T &key, int val=1) {
258 AVLTreeMultisetNodePtr node = find_key(key);
259 assert(node);
260 node->val -= val;
261 if (node->val <= 0) {
262 remove_iter(node);
263 } else {
264 _update_par(node);
265 }
266 }
267
268 T le(const T &key) const {
269 T res = missing;
270 AVLTreeMultisetNodePtr node = root;
271 while (node) {
272 if (key == node->key) {
273 res = node->key;
274 break;
275 }
276 if (key < node->key) {
277 node = node->left;
278 } else {
279 res = node->key;
280 node = node->right;
281 }
282 }
283 return res;
284 }
285
286 T lt(const T &key) const {
287 T res = missing;
288 AVLTreeMultisetNodePtr node = root;
289 while (node) {
290 if (key <= node->key) {
291 node = node->left;
292 } else {
293 res = node->key;
294 node = node->right;
295 }
296 }
297 return res;
298 }
299
300 T ge(const T &key) const {
301 T res = missing;
302 AVLTreeMultisetNodePtr node = root;
303 while (node) {
304 if (key == node->key) {
305 res = node->key;
306 break;
307 }
308 if (key < node->key) {
309 res = node->key;
310 node = node->left;
311 } else {
312 node = node->right;
313 }
314 }
315 return res;
316 }
317
318 T gt(const T &key) const {
319 T res = missing;
320 AVLTreeMultisetNodePtr node = root;
321 while (node) {
322 if (key < node->key) {
323 res = node->key;
324 node = node->left;
325 } else {
326 node = node->right;
327 }
328 }
329 return res;
330 }
331
332 int index(const T &key) const {
333 int k = 0;
334 AVLTreeMultisetNodePtr node = root;
335 while (node) {
336 if (key == node->key) {
337 k += node->left ? node->left->valsize : 0;
338 break;
339 }
340 if (key < node->key) {
341 node = node->left;
342 } else {
343 k += node->left ? (node->left->valsize + node->val) : node->val;
344 node = node->right;
345 }
346 }
347 return k;
348 }
349
350 int index_right(const T &key) const {
351 int k = 0;
352 AVLTreeMultisetNodePtr node = root;
353 while (node) {
354 if (key == node->key) {
355 k += node->left ? (node->left->valsize + node->val) : node->val;
356 break;
357 }
358 if (key < node->key) {
359 node = node->left;
360 } else {
361 k += node->left ? (node->left->valsize + node->val) : node->val;
362 node = node->right;
363 }
364 }
365 return k;
366 }
367
368 T pop(int k=-1) {
369 AVLTreeMultisetNodePtr node = find_kth(k);
370 T key = node->key;
371 node->val -= 1;
372 if (node->val == 0) {
373 remove_iter(node);
374 } else {
375 _update_par(node);
376 }
377 return key;
378 }
379
380 vector<T> tovector() const {
381 vector<T> a;
382 a.reserve(len());
383 vector<AVLTreeMultisetNodePtr> st;
384 AVLTreeMultisetNodePtr node = root;
385 while ((!st.empty()) || node) {
386 if (node) {
387 st.emplace_back(node);
388 node = node->left;
389 } else {
390 node = st.back();
391 st.pop_back();
392 for (int i = 0; i < node->val; ++i) {
393 a.emplace_back(node->key);
394 }
395 node = node->right;
396 }
397 }
398 return a;
399 }
400
401 bool contains(T key) const {
402 return find_key(key) != nullptr;
403 }
404
405 T get(int k) const {
406 return find_kth(k)->key;
407 }
408
409 int len() const {
410 return root ? root->valsize : 0;
411 }
412
413 void print() const {
414 vector<T> a = tovector();
415 int n = a.size();
416 cout << "{";
417 for (int i = 0; i < n-1; ++i) {
418 cout << a[i] << ", ";
419 }
420 if (n > 0) cout << a.back();
421 cout << "}" << endl;
422 }
423
424 void check() const {
425 if (!root) {
426 // cout << "height=0" << endl;
427 // cout << "check ok empty." << endl;
428 return;
429 }
430 // cout << "height=" << root->height << endl;
431
432 auto dfs = [&] (auto &&dfs, AVLTreeMultisetNodePtr node) -> void {
433 int h = 0;
434 int b = 0;
435 int vs = node->val;
436 if (node->left) {
437 assert(node->left->par == node);
438 assert(node->key > node->left->key);
439 dfs(dfs, node->left);
440 h = max(h, node->left->height);
441 b += node->left->height;
442 vs += node->left->valsize;
443 }
444 if (node->right) {
445 assert(node->right->par == node);
446 assert(node->key < node->right->key);
447 dfs(dfs, node->right);
448 h = max(h, node->right->height);
449 b -= node->right->height;
450 vs += node->right->valsize;
451 }
452 assert(node->valsize == vs);
453 assert(node->height == h+1);
454 assert(-1 <= b && b <= 1);
455 };
456 dfs(dfs, root);
457 // cout << "check ok." << endl;
458 }
459
460 friend ostream& operator<<(ostream& os, const titan23::AVLTreeMultiset<T>& s) {
461 vector<T> a = s.tovector();
462 int n = a.size();
463 os << "{";
464 for (int i = 0; i < n - 1; ++i) {
465 os << a[i] << ", ";
466 }
467 if (n > 0) os << a.back();
468 os << "}";
469 return os;
470 }
471 };
472} // namespace titan23
仕様 (Specification)
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/avl_tree_multiset.cpp”