wb tree¶
ソースコード¶
1#include <iostream>
2#include <vector>
3#include <cmath>
4#include <cassert>
5#include <memory>
6using namespace std;
7
namespace titan23 {

    /// Lower balance bound of the weight-balanced tree. The rebalancing code
    /// aims to keep every node's balance() inside [ALPHA, 1 - ALPHA];
    /// isok() asserts exactly that range.
    const double ALPHA = 1 - std::sqrt(2.0) / 2;

    /// Threshold applied to the heavy child's balance() that selects a double
    /// rotation instead of a single one (see _balance_left / _balance_right).
    const double BETA = (1 - 2 * ALPHA) / (1 - ALPHA);

    /// Sequence container (list addressed by position) stored in a
    /// weight-balanced binary tree.
    ///
    /// Ownership note: nodes are allocated with `new` and are never freed by
    /// this class, and copying a WBTree copies only the root pointer. After
    /// split() / merge() several handles may therefore reference the same
    /// nodes; only the handles documented as valid should be used afterwards.
    template <class T>
    class WBTree {

      private:
        class Node;
        using NodePtr = Node*;
        using MyWBTree = WBTree<T>;

        /// One tree node: a key, two children and the size of its subtree.
        class Node {
          public:
            T key;
            NodePtr left, right;
            int size;

            Node(T key) : key(key), left(nullptr), right(nullptr), size(1) {}

            /// Weight of the left subtree relative to the whole subtree:
            /// (|left| + 1) / (size + 1).
            double balance() const {
                return ((left ? left->size : 0) + 1.0) / (size + 1.0);
            }

            /// Recomputes `size` from the children (which must be up to date).
            void _update() {
                size = 1 + (left ? left->size : 0) + (right ? right->size : 0);
            }
        };

        /// Replaces the tree with one built from `a`, always choosing the
        /// middle element of the current range as the subtree root.
        void _build(const std::vector<T> &a) {
            if (a.empty()) {
                root = nullptr;
                return;
            }
            // Builds the half-open index range [l, r); requires l < r.
            auto rec = [&] (auto &&self, int l, int r) -> NodePtr {
                const int mid = (l + r) >> 1;
                NodePtr node = new Node(a[mid]);
                if (l != mid) node->left = self(self, l, mid);
                if (mid + 1 != r) node->right = self(self, mid + 1, r);
                node->_update();
                return node;
            };
            root = rec(rec, 0, static_cast<int>(a.size()));
        }

        /// Right rotation: the left child becomes the subtree root.
        /// Requires node->left != nullptr. Returns the new subtree root.
        NodePtr _rotate_right(NodePtr node) {
            NodePtr u = node->left;
            node->left = u->right;
            u->right = node;
            node->_update();  // node is now below u, so refresh it first
            u->_update();
            return u;
        }

        /// Left rotation: the right child becomes the subtree root.
        /// Requires node->right != nullptr. Returns the new subtree root.
        NodePtr _rotate_left(NodePtr node) {
            NodePtr u = node->right;
            node->right = u->left;
            u->left = node;
            node->_update();  // node is now below u, so refresh it first
            u->_update();
            return u;
        }

        /// Rebalances a right-heavy node with a left rotation, preceded by a
        /// right rotation of the right child when that child is left-heavy
        /// (balance() >= BETA). Requires node->right != nullptr.
        NodePtr _balance_left(NodePtr node) {
            NodePtr u = node->right;
            if (u->balance() >= BETA) {
                node->right = _rotate_right(u);
            }
            return _rotate_left(node);
        }

        /// Mirror image of _balance_left for a left-heavy node.
        /// Requires node->left != nullptr.
        NodePtr _balance_right(NodePtr node) {
            NodePtr u = node->left;
            if (u->balance() <= 1.0 - BETA) {
                node->left = _rotate_left(u);
            }
            return _rotate_right(node);
        }

        /// Joins `l`, the single node `mid` and `r` (in that order) into one
        /// tree and returns its root. `l` and/or `r` may be nullptr; the
        /// children of `mid` are overwritten.
        NodePtr _merge_with_root(NodePtr l, NodePtr mid, NodePtr r) {
            const int ls = l ? l->size : 0;
            const int rs = r ? r->size : 0;
            // Balance `mid` would have if l and r were attached directly.
            const double diff = (double)(ls + 1.0) / (ls + rs + 2.0);
            if (diff > 1.0 - ALPHA) {
                // l is too heavy: descend along its right spine, then repair l.
                l->right = _merge_with_root(l->right, mid, r);
                l->_update();
                if (!(ALPHA <= l->balance() && l->balance() <= 1.0 - ALPHA)) {
                    return _balance_left(l);
                }
                return l;
            }
            if (diff < ALPHA) {
                // r is too heavy: descend along its left spine, then repair r.
                r->left = _merge_with_root(l, mid, r->left);
                r->_update();
                if (!(ALPHA <= r->balance() && r->balance() <= 1.0 - ALPHA)) {
                    return _balance_right(r);
                }
                return r;
            }
            mid->left = l;
            mid->right = r;
            mid->_update();
            return mid;
        }

        /// Detaches the rightmost node of the non-empty tree `node`.
        /// Returns {root of the remaining tree (may be nullptr), detached node};
        /// the detached node is returned as a childless node of size 1.
        std::pair<NodePtr, NodePtr> _pop_right(NodePtr node) {
            std::vector<NodePtr> path;  // right spine, excluding the rightmost node
            NodePtr mx = node;
            while (node->right) {
                path.emplace_back(node);
                node = node->right;
                mx = node;
            }
            // The left subtree of the rightmost node takes that node's place.
            path.emplace_back(node->left);
            // Re-attach bottom-up, rebalancing each subtree before linking it.
            while (path.size() > 1) {
                node = path.back();
                path.pop_back();
                if (!node) {
                    path.back()->right = nullptr;
                    path.back()->_update();
                    continue;
                }
                const double b = node->balance();
                if (ALPHA <= b && b <= 1.0 - ALPHA) {
                    path.back()->right = node;
                } else if (b > 1.0 - ALPHA) {
                    path.back()->right = _balance_right(node);
                } else {
                    path.back()->right = _balance_left(node);
                }
                path.back()->_update();
            }
            if (path[0]) {
                const double b = path[0]->balance();
                if (b > 1.0 - ALPHA) {
                    path[0] = _balance_right(path[0]);
                } else if (b < ALPHA) {
                    path[0] = _balance_left(path[0]);
                }
            }
            mx->left = nullptr;
            mx->_update();
            return {path[0], mx};
        }

        /// Concatenates two trees (either may be nullptr) and returns the root.
        NodePtr _merge_node(NodePtr l, NodePtr r) {
            if (!l) return r;  // also covers the case where both are empty
            if (!r) return l;
            auto [rest, pivot] = _pop_right(l);
            return _merge_with_root(rest, pivot, r);
        }

        /// Splits the tree rooted at `node` into {first k elements, the rest}.
        /// The original nodes are reused, so `node` must not be used as a
        /// tree root afterwards.
        std::pair<NodePtr, NodePtr> _split_node(NodePtr node, int k) {
            if (!node) return {nullptr, nullptr};
            NodePtr nl = node->left;
            NodePtr nr = node->right;
            // Position of the cut relative to `node` itself.
            const int tmp = nl ? k - nl->size : k;
            if (tmp == 0) {
                return {nl, _merge_with_root(nullptr, node, nr)};
            } else if (tmp < 0) {
                auto [l, r] = _split_node(nl, k);
                return {l, _merge_with_root(r, node, nr)};
            } else {
                auto [l, r] = _split_node(nr, tmp - 1);
                return {_merge_with_root(nl, node, l), r};
            }
        }

        /// Wraps an existing node pointer in a WBTree handle (no copy of nodes).
        WBTree<T> _new(NodePtr root) {
            WBTree<T> p(root);
            return p;
        }

        WBTree(NodePtr &root): root(root) {}

      public:
        NodePtr root;

        /// Creates an empty sequence.
        WBTree() : root(nullptr) {}

        /// Creates a sequence holding the elements of `a` in order.
        WBTree(std::vector<T> const &a) { _build(a); }

        /// Replaces the contents with the elements of `a` in order.
        /// Previously held nodes are dropped without being freed.
        void build(const std::vector<T> &a) {
            _build(a);
        }

        /// Appends the elements of `other` to this sequence.
        /// `other.root` is left untouched and afterwards aliases nodes that
        /// belong to this tree, so `other` must not be used as a tree any more.
        void merge(MyWBTree &other) {
            this->root = _merge_node(this->root, other.root);
        }

        /// Splits into {first k elements, remaining elements}.
        /// The nodes are reused and `this->root` is not updated, so this
        /// object must not be used as a tree after the call.
        std::pair<MyWBTree, MyWBTree> split(const int k) {
            auto [l, r] = _split_node(this->root, k);
            return {_new(l), _new(r)};
        }

        /// Inserts `key` so that it ends up at position k (0 <= k <= len()).
        void insert(int k, const T key) {
            assert(0 <= k && k <= len());
            auto [s, t] = _split_node(root, k);
            NodePtr new_node = new Node(key);
            this->root = _merge_with_root(s, new_node, t);
        }

        /// Appends `key` at the end of the sequence.
        void emplace_back(T key) {
            insert(len(), key);
        }

        /// Removes and returns the element at position k.
        /// A negative k counts from the end (k += len()).
        T pop(int k) {
            if (k < 0) k += len();
            assert(0 <= k && k < len());
            std::vector<NodePtr> path;  // ancestors of the node being unlinked
            std::vector<short> path_d;  // per ancestor: 1 = went left, 0 = went right
            NodePtr node = this->root;
            // Descend to the node holding position k.
            while (true) {
                const int t = node->left ? node->left->size : 0;
                if (t == k) break;
                const short d = (t < k) ? 0 : 1;
                path.emplace_back(node);
                path_d.emplace_back(d);
                if (d) {
                    node = node->left;
                } else {
                    k -= t + 1;
                    node = node->right;
                }
            }
            T res = node->key;
            if (node->left && node->right) {
                // Two children: overwrite the key with its in-order predecessor
                // and unlink the predecessor node instead.
                NodePtr lmax = node->left;
                path.emplace_back(node);
                path_d.emplace_back(1);
                while (lmax->right) {
                    path.emplace_back(lmax);
                    path_d.emplace_back(0);
                    lmax = lmax->right;
                }
                node->key = lmax->key;
                node = lmax;
            }
            // `node` now has at most one child; splice that child into its place.
            NodePtr cnode = node->left ? node->left : node->right;
            if (path.empty()) {
                this->root = cnode;
                return res;
            }
            if (path_d.back()) path.back()->left = cnode;
            else path.back()->right = cnode;
            // Walk back up: refresh sizes and rebalance where needed.
            while (!path.empty()) {
                NodePtr new_node = nullptr;
                node = path.back();
                node->_update();
                path.pop_back();
                path_d.pop_back();
                const double b = node->balance();
                if (b < ALPHA) {
                    new_node = _balance_left(node);
                } else if (b > 1.0 - ALPHA) {
                    new_node = _balance_right(node);
                }
                if (new_node) {
                    if (path.empty()) {
                        this->root = new_node;
                        break;
                    }
                    // path_d.back() now describes the edge parent -> node.
                    if (path_d.back()) {
                        path.back()->left = new_node;
                    } else {
                        path.back()->right = new_node;
                    }
                }
            }
            return res;
        }

        /// Returns the elements in sequence order (iterative in-order walk).
        std::vector<T> tovector() const {
            std::vector<T> a;
            a.reserve(len());
            std::vector<NodePtr> stack;
            NodePtr node = root;
            while (node || !stack.empty()) {
                if (node) {
                    stack.emplace_back(node);
                    node = node->left;
                } else {
                    node = stack.back();
                    stack.pop_back();
                    a.emplace_back(node->key);
                    node = node->right;
                }
            }
            return a;
        }

        /// Overwrites the element at position k with `v`.
        /// A negative k counts from the end, as in operator[] and pop().
        void set(int k, T v) {
            // Changing a key does not change any subtree size, so no size
            // refresh or re-linking is needed.
            (*this)[k] = v;
        }

        /// Returns a copy of the element at position k
        /// (a negative k counts from the end).
        T operator[] (int k) const {
            if (k < 0) k += len();
            assert(0 <= k && k < len());
            NodePtr node = root;
            while (true) {
                const int t = node->left ? node->left->size : 0;
                if (t == k) {
                    return node->key;
                }
                if (t < k) {
                    k -= t + 1;
                    node = node->right;
                } else {
                    node = node->left;
                }
            }
        }

        /// Makes the sequence empty. The nodes are dropped, not freed.
        void clear() {
            this->root = nullptr;
        }

        /// Returns a reference to the element at position k
        /// (a negative k counts from the end).
        T& operator[] (int k) {
            if (k < 0) k += len();
            assert(0 <= k && k < len());
            NodePtr node = root;
            while (true) {
                const int t = node->left ? node->left->size : 0;
                if (t == k) {
                    return node->key;
                }
                if (t < k) {
                    k -= t + 1;
                    node = node->right;
                } else {
                    node = node->left;
                }
            }
        }

        /// Writes the sequence to std::cout as "[a, b, c]" followed by a newline.
        void print() const {
            const std::vector<T> a = tovector();
            std::cout << "[";
            for (int i = 0; i < (int)a.size() - 1; ++i) {
                std::cout << a[i] << ", ";
            }
            if (!a.empty()) std::cout << a.back();
            std::cout << "]" << std::endl;
        }

        /// Number of elements in the sequence.
        int len() const {
            return root ? root->size : 0;
        }

        /// Debug check: asserts that every stored size is correct and that
        /// every node's balance lies in [ALPHA, 1 - ALPHA]. Does nothing
        /// observable when assertions are disabled.
        void isok() const {
            if (root == nullptr) return;
            // Returns {subtree size, subtree height} of `node`.
            auto rec = [&] (auto &&self, NodePtr node) -> std::pair<int, int> {
                int ls = 0, rs = 0;
                int height = 0;
                if (node->left) {
                    const std::pair<int, int> res = self(self, node->left);
                    ls = res.first;
                    if (res.second > height) height = res.second;
                }
                if (node->right) {
                    const std::pair<int, int> res = self(self, node->right);
                    rs = res.first;
                    if (res.second > height) height = res.second;
                }
                const int s = ls + rs + 1;
                const double b = (double)(ls + 1) / (s + 1);
                assert(s == node->size);
                assert(ALPHA <= b && b <= 1 - ALPHA);
                (void)s;  // silence unused warnings when NDEBUG is defined
                (void)b;
                return {s, height + 1};
            };
            rec(rec, root);
        }
    };
}  // namespace titan23
仕様¶
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/wb_tree.cpp”