AVL Tree Set
ソースコード
#include <algorithm>
#include <cassert>
#include <iostream>
#include <stack>
#include <type_traits>
#include <utility>
#include <vector>

#include "titan_cpplib/data_structures/bbst_node.cpp"

using namespace std;

9// AVLTreeSet
10namespace titan23 {
11
12 template<typename T>
13 class AVLTreeSet {
14 public:
15 class AVLTreeSetNode {
16 public:
17 using AVLTreeSetNodePtr = AVLTreeSetNode*;
18 T key;
19 int size, height;
20 AVLTreeSetNodePtr par, left, right;
21
22 AVLTreeSetNode() : size(0), height(0), par(nullptr), left(nullptr), right(nullptr) {}
23 AVLTreeSetNode(const T &key) : key(key), size(1), height(1), par(nullptr), left(nullptr), right(nullptr) {}
24
25 int balance() const {
26 int hl = left ? left->height : 0;
27 int hr = right ? right->height : 0;
28 return hl - hr;
29 }
30
31 void update() {
32 size = 1 + (left ? left->size : 0) + (right ? right->size : 0);
33 height = 1 + max((left ? left->height : 0), (right ? right->height : 0));
34 }
35 };
36
37 private:
38 T missing;
39 using AVLTreeSetNodePtr = AVLTreeSetNode*;
40 AVLTreeSetNodePtr root;
41
42 void _remove_balance(AVLTreeSetNodePtr node) {
43 while (node) {
44 AVLTreeSetNodePtr new_node = nullptr;
45 node->update();
46 if (node->balance() == 2) {
47 new_node = node->left->balance() == -1 ? BBSTNode<AVLTreeSetNodePtr>::rotate_LR(node) : BBSTNode<AVLTreeSetNodePtr>::rotate_right(node);
48 } else if (node->balance() == -2) {
49 new_node = node->right->balance() == 1 ? BBSTNode<AVLTreeSetNodePtr>::rotate_RL(node) : BBSTNode<AVLTreeSetNodePtr>::rotate_left(node);
50 } else if (node->balance() != 0) {
51 node = node->par;
52 break;
53 }
54 if (!new_node) {
55 node = node->par;
56 continue;
57 }
58 if (!new_node->par) {
59 this->root = new_node;
60 return;
61 }
62 node = new_node->par;
63 if (new_node->key < node->key) {
64 node->left = new_node;
65 } else {
66 node->right = new_node;
67 }
68 if (new_node->balance() != 0) break;
69 }
70 while (node) {
71 node->update();
72 node = node->par;
73 }
74 }
75
76 void _add_balance(AVLTreeSetNodePtr node) {
77 AVLTreeSetNodePtr new_node = nullptr;
78 while (node) {
79 node->update();
80 if (node->balance() == 0) {
81 node = node->par;
82 break;
83 }
84 if (node->balance() == 2) {
85 new_node = node->left->balance() == -1 ? BBSTNode<AVLTreeSetNodePtr>::rotate_LR(node) : BBSTNode<AVLTreeSetNodePtr>::rotate_right(node);
86 break;
87 } else if (node->balance() == -2) {
88 new_node = node->right->balance() == 1 ? BBSTNode<AVLTreeSetNodePtr>::rotate_RL(node) : BBSTNode<AVLTreeSetNodePtr>::rotate_left(node);
89 break;
90 }
91 node = node->par;
92 }
93 if (new_node) {
94 node = new_node->par;
95 if (node) {
96 if (new_node->key < node->key) {
97 node->left = new_node;
98 } else {
99 node->right = new_node;
100 }
101 } else {
102 this->root = new_node;
103 }
104 }
105 while (node) {
106 node->update();
107 node = node->par;
108 }
109 }
110
111 public:
112 AVLTreeSet() : missing(-1), root(nullptr) {}
113 AVLTreeSet(T missing) : missing(-1), root(nullptr) {}
114 AVLTreeSet(vector<T> &a, T missing) : missing(-1) {
115 this->root = build(a);
116 }
117
118 AVLTreeSetNodePtr build(vector<T> a) {
119 auto _build = [&] (auto &&_build, int l, int r) -> AVLTreeSetNodePtr {
120 int mid = (l + r) / 2;
121 AVLTreeSetNodePtr node = new AVLTreeSetNode(a[mid]);
122 if (l != mid) {
123 node->left = _build(_build, l, mid);
124 node->left->par = node;
125 }
126 if (mid+1 != r) {
127 node->right = _build(_build, mid+1, r);
128 node->right->par = node;
129 }
130 node->update();
131 return node;
132 };
133
134 if (a.empty()) return nullptr;
135 int n = a.size();
136 bool is_sorted = true;
137 for (int i = 0; i < n-1; ++i) {
138 if (!(a[i] < a[i+1])) {
139 is_sorted = false;
140 break;
141 }
142 }
143 if (!is_sorted) {
144 sort(a.begin(), a.end());
145 a.erase(unique(a.begin(), a.end()), a.end());
146 }
147 return _build(_build, 0, a.size());
148 }
149
150 bool add(const T &key) {
151 if (!root) {
152 root = new AVLTreeSetNode(key);
153 return true;
154 }
155 AVLTreeSetNodePtr pnode = nullptr;
156 AVLTreeSetNodePtr node = root;
157 while (node) {
158 if (key == node->key) return false;
159 pnode = node;
160 node = key < node->key ? node->left : node->right;
161 }
162 if (key < pnode->key) {
163 pnode->left = new AVLTreeSetNode(key);
164 pnode->left->par = pnode;
165 } else {
166 pnode->right = new AVLTreeSetNode(key);
167 pnode->right->par = pnode;
168 }
169 _add_balance(pnode);
170 return true;
171 }
172
173 void remove_iter(AVLTreeSetNodePtr node) {
174 AVLTreeSetNodePtr pnode = node->par;
175 if (node->left && node->right) {
176 pnode = node;
177 AVLTreeSetNodePtr mnode = node->left;
178 while (mnode->right) {
179 pnode = mnode;
180 mnode = mnode->right;
181 }
182 node->key = mnode->key;
183 node = mnode;
184 }
185 AVLTreeSetNodePtr cnode = (!node->left) ? node->right : node->left;
186 if (cnode) {
187 cnode->par = pnode;
188 }
189 if (pnode) {
190 if (node->key <= pnode->key) {
191 pnode->left = cnode;
192 } else {
193 pnode->right = cnode;
194 }
195 _remove_balance(pnode);
196 } else {
197 root = cnode;
198 }
199 }
200
201 bool discard(const T &key) {
202 AVLTreeSetNodePtr node = find_key(key);
203 if (!node) return false;
204 remove_iter(node);
205 return true;
206 }
207
208 void remove(const T &key) {
209 AVLTreeSetNodePtr node = find_key(key);
210 assert(node);
211 remove_iter(node);
212 }
213
214 T le(const T &key) const {
215 T res = missing;
216 AVLTreeSetNodePtr node = root;
217 while (node) {
218 if (key == node->key) {
219 res = node->key;
220 break;
221 }
222 if (key < node->key) {
223 node = node->left;
224 } else {
225 res = node->key;
226 node = node->right;
227 }
228 }
229 return res;
230 }
231
232 T lt(const T &key) const {
233 T res = missing;
234 AVLTreeSetNodePtr node = root;
235 while (node) {
236 if (key <= node->key) {
237 node = node->left;
238 } else {
239 res = node->key;
240 node = node->right;
241 }
242 }
243 return res;
244 }
245
246 T ge(const T &key) const {
247 T res = missing;
248 AVLTreeSetNodePtr node = root;
249 while (node) {
250 if (key == node->key) {
251 res = node->key;
252 break;
253 }
254 if (key < node->key) {
255 res = node->key;
256 node = node->left;
257 } else {
258 node = node->right;
259 }
260 }
261 return res;
262 }
263
264 T gt(const T &key) const {
265 T res = missing;
266 AVLTreeSetNodePtr node = root;
267 while (node) {
268 if (key < node->key) {
269 res = node->key;
270 node = node->left;
271 } else {
272 node = node->right;
273 }
274 }
275 return res;
276 }
277
278 int index(const T &key) const {
279 int k = 0;
280 AVLTreeSetNodePtr node = root;
281 while (node) {
282 if (key == node->key) {
283 k += node->left ? node->left->size : 0;
284 break;
285 }
286 if (key < node->key) {
287 node = node->left;
288 } else {
289 k += node->left ? (node->left->size + 1) : 1;
290 node = node->right;
291 }
292 }
293 return k;
294 }
295
296 int index_right(const T &key) const {
297 int k = 0;
298 AVLTreeSetNodePtr node = root;
299 while (node) {
300 if (key == node->key) {
301 k += node->left ? (node->left->size + 1) : 1;
302 break;
303 }
304 if (key < node->key) {
305 node = node->left;
306 } else {
307 k += node->left ? (node->left->size + 1) : 1;
308 node = node->right;
309 }
310 }
311 return k;
312 }
313
314 AVLTreeSetNodePtr find_key(const T &key) const {
315 AVLTreeSetNodePtr node = root;
316 while (node) {
317 if (key == node->key) return node;
318 node = key < node->key ? node->left : node->right;
319 }
320 return nullptr;
321 }
322
323 AVLTreeSetNodePtr find_kth(int k) const {
324 AVLTreeSetNodePtr node = root;
325 while (true) {
326 int t = node->left ? node->left->size : 0;
327 if (t == k) return node;
328 if (t < k) {
329 k -= t + 1;
330 node = node->right;
331 } else {
332 node = node->left;
333 }
334 }
335 }
336
337 T pop(int k=-1) {
338 AVLTreeSetNodePtr node = find_kth(k);
339 T key = node->key;
340 remove_iter(node);
341 return key;
342 }
343
344 vector<T> tovector() const {
345 vector<T> a;
346 a.reserve(len());
347 stack<AVLTreeSetNodePtr> st;
348 AVLTreeSetNodePtr node = root;
349 while ((!st.empty()) || node) {
350 if (node) {
351 st.emplace(node);
352 node = node->left;
353 } else {
354 node = st.top();
355 st.pop();
356 a.emplace_back(node->key);
357 node = node->right;
358 }
359 }
360 return a;
361 }
362
363 bool contains(T key) const {
364 return find_key(key) != nullptr;
365 }
366
367 T get(int k) const {
368 return find_kth(k)->key;
369 }
370
371 int len() const {
372 return root ? root->size : 0;
373 }
374
375 void check() const {
376 if (!root) {
377 cout << "height=0" << endl;;
378 cout << "check ok empty." << endl;;
379 return;
380 }
381 cout << "height=" << root->height << endl;
382
383 auto dfs = [&] (auto &&dfs, AVLTreeSetNodePtr node) -> void {
384 int h = 0;
385 int b = 0;
386 int s = 1;
387 if (node->left) {
388 assert(node->left->par == node);
389 assert(node->key > node->left->key);
390 dfs(dfs, node->left);
391 h = max(h, node->left->height);
392 b += node->left->height;
393 s += node->left->size;
394 }
395 if (node->right) {
396 assert(node->right->par == node);
397 assert(node->key < node->right->key);
398 dfs(dfs, node->right);
399 h = max(h, node->right->height);
400 b -= node->right->height;
401 s += node->right->size;
402 }
403 assert(node->height == h+1);
404 assert(-1 <= b && b <= 1);
405 assert(node->size == s);
406 };
407 dfs(dfs, root);
408 cout << "check ok." << endl;
409 }
410
411 friend ostream& operator<<(ostream& os, const titan23::AVLTreeSet<T>& s) {
412 vector<T> a = s.tovector();
413 int n = a.size();
414 os << "{";
415 for (int i = 0; i < n - 1; ++i) {
416 os << a[i] << ", ";
417 }
418 if (n > 0) os << a.back();
419 os << "}";
420 return os;
421 }
422 };
423}
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/avl_tree_set.cpp”