wb tree seg
ソースコード
1#include <iostream>
2#include <vector>
3#include <stack>
4#include <cassert>
5#include "titan_cpplib/others/print.cpp"
6using namespace std;
7
8namespace titan23 {
9
10 template <class T,
11 T (*op)(T, T),
12 T (*e)()>
13 class WBSegTree {
14 private:
15 class Node;
16 using NodePtr = Node*;
17 using MyLazyWBTree = WBSegTree<T, op, e>;
18 static constexpr int DELTA = 3;
19 static constexpr int GAMMA = 2;
20
21 class Node {
22 public:
23 NodePtr left;
24 NodePtr right;
25 T key, data;
26 int size;
27
28 Node(T key) : left(nullptr), right(nullptr), key(key), data(key), size(1) {}
29
30 int weight_right() const {
31 return right ? right->size + 1 : 1;
32 }
33
34 int weight_left() const {
35 return left ? left->size + 1 : 1;
36 }
37
38 void balance_check() const {
39 if (!weight_left()*DELTA >= weight_right()) {
40 cerr << weight_left() << ", " << weight_right() << endl;
41 cerr << "not weight_left()*DELTA >= weight_right()." << endl;
42 assert(false);
43 }
44 if (!weight_right() * DELTA >= weight_left()) {
45 cerr << weight_left() << ", " << weight_right() << endl;
46 cerr << "not weight_right() * DELTA >= weight_left()." << endl;
47 assert(false);
48 }
49 }
50
51 void update() {
52 size = 1;
53 data = key;
54 if (left) {
55 size += left->size;
56 data = op(left->data, key);
57 }
58 if (right) {
59 size += right->size;
60 data = op(data, right->data);
61 }
62 }
63 };
64
65 void _build(vector<T> const &a) {
66 auto build = [&] (auto &&build, int l, int r) -> NodePtr {
67 int mid = (l + r) >> 1;
68 NodePtr node = new Node(a[mid]);
69 if (l != mid) node->left = build(build, l, mid);
70 if (mid+1 != r) node->right = build(build, mid+1, r);
71 node->update();
72 return node;
73 };
74 if (a.empty()) {
75 this->root = nullptr;
76 return;
77 }
78 this->root = build(build, 0, (int)a.size());
79 }
80
81 NodePtr _rotate_right(NodePtr node) {
82 NodePtr u = node->left;
83 node->left = u->right;
84 u->right = node;
85 node->update();
86 u->update();
87 return u;
88 }
89
90 NodePtr _rotate_left(NodePtr node) {
91 NodePtr u = node->right;
92 node->right = u->left;
93 u->left = node;
94 node->update();
95 u->update();
96 return u;
97 }
98
99 NodePtr _balance_left(NodePtr node) {
100 NodePtr u = node->right;
101 if (node->right->weight_left() >= node->right->weight_right() * GAMMA) {
102 node->right = _rotate_right(u);
103 }
104 return _rotate_left(node);
105 }
106
107 NodePtr _balance_right(NodePtr node) {
108 NodePtr u = node->left;
109 if (node->left->weight_right() >= node->left->weight_left() * GAMMA) {
110 node->left = _rotate_left(u);
111 }
112 return _rotate_right(node);
113 }
114
115 int weight(NodePtr node) const {
116 return node ? node->size + 1 : 1;
117 }
118
119 NodePtr _merge_with_root(NodePtr l, NodePtr root, NodePtr r) {
120 if (weight(l) * DELTA < weight(r)) {
121 r->left = _merge_with_root(l, root, r->left);
122 r->update();
123 if (weight(r->right) * DELTA < weight(r->left)) {
124 return _balance_right(r);
125 }
126 return r;
127 } else if (weight(r) * DELTA < weight(l)) {
128 l->right = _merge_with_root(l->right, root, r);
129 l->update();
130 if (weight(l->left) * DELTA < weight(l->right)) {
131 return _balance_left(l);
132 }
133 return l;
134 }
135 root->left = l;
136 root->right = r;
137 root->update();
138 return root;
139 }
140
141 pair<NodePtr, NodePtr> _pop_right(NodePtr node) {
142 return _split_node(node, node->size-1);
143 }
144
145 NodePtr _merge_node(NodePtr l, NodePtr r) {
146 if ((!l) && (!r)) {return nullptr;}
147 if (!l) {return r;}
148 if (!r) {return l;}
149 auto [l_, root_] = _pop_right(l);
150 return _merge_with_root(l_, root_, r);
151 }
152
153 pair<NodePtr, NodePtr> _split_node(NodePtr node, int k) {
154 if (!node) return {nullptr, nullptr};
155 int tmp = node->left ? k-node->left->size : k;
156 NodePtr nl = node->left;
157 NodePtr nr = node->right;
158 if (tmp == 0) {
159 return {nl, _merge_with_root(nullptr, node, nr)};
160 } else if (tmp < 0) {
161 auto [l, r] = _split_node(nl, k);
162 return {l, _merge_with_root(r, node, nr)};
163 } else {
164 auto [l, r] = _split_node(nr, tmp-1);
165 return {_merge_with_root(node->left, node, l), r};
166 }
167 }
168
169 WBSegTree(NodePtr &root): root(root) {}
170
171 MyLazyWBTree _new(NodePtr root) {
172 return MyLazyWBTree(root);
173 }
174
175 public:
176 NodePtr root;
177
178 WBSegTree() : root(nullptr) {}
179
180 WBSegTree(vector<T> const &a) { _build(a); }
181
182 void merge(MyLazyWBTree &other) {
183 this->root = _merge_node(this->root, other.root);
184 }
185
186 pair<MyLazyWBTree, MyLazyWBTree> split(const int k) {
187 auto [l, r] = _split_node(this->root, k);
188 return {_new(l), _new(r)};
189 }
190
191 T all_prod() const {
192 return this->root ? this->root->data : e();
193 }
194
195 T prod(const int l, const int r) {
196 assert(0 <= l && l <= r && r <= len());
197 if (l == r) return e();
198 auto dfs = [&] (auto &&dfs, NodePtr node, int left, int right) -> T {
199 if (right <= l || r <= left) return e();
200 if (l <= left && right < r) return node->data;
201 int lsize = node->left ? node->left->size : 0;
202 T res = e();
203 if (node->left) res = dfs(dfs, node->left, left, left+lsize);
204 if (l <= left+lsize && left+lsize < r) res = op(res, node->key);
205 if (node->right) res = op(res, dfs(dfs, node->right, left+lsize+1, right));
206 return res;
207 };
208 return dfs(dfs, root, 0, len());
209 }
210
211 void insert(int k, const T key) {
212 assert(0 <= k && k <= len());
213 auto [s, t] = _split_node(root, k);
214 NodePtr new_node = new Node(key);
215 this->root = _merge_with_root(s, new_node, t);
216 }
217
218 T pop(int k) {
219 assert(0 <= k && k < len());
220 stack<NodePtr> path;
221 stack<short> path_d;
222 NodePtr node = this->root;
223 T res;
224 int d = 0;
225 while (1) {
226 int t = node->left? node->left->size: 0;
227 if (t == k) {
228 res = node->key;
229 break;
230 }
231 int d = (t < k)? 0: 1;
232 path.emplace(node);
233 path_d.emplace(d);
234 if (d) {
235 node = node->left;
236 } else {
237 k -= t + 1;
238 node = node->right;
239 }
240 }
241 if (node->left && node->right) {
242 NodePtr lmax = node->left;
243 path.emplace(node);
244 path_d.emplace(1);
245 while (lmax->right) {
246 path.emplace(lmax);
247 path_d.emplace(0);
248 lmax = lmax->right;
249 }
250 node->key = lmax->key;
251 node = lmax;
252 }
253 NodePtr cnode = (node->left) ? node->left : node->right;
254 if (!path.empty()) {
255 if (path_d.top()) path.top()->left = cnode;
256 else path.top()->right = cnode;
257 } else {
258 this->root = cnode;
259 return res;
260 }
261 while (!path.empty()) {
262 NodePtr new_node = nullptr;
263 node = path.top();
264 node->update();
265 path.pop();
266 path_d.pop();
267 if (weight(node->left) * DELTA < weight(node->right)) {
268 new_node = _balance_left(node);
269 } else if (weight(node->right) * DELTA < weight(node->left)) {
270 new_node = _balance_right(node);
271 }
272 if (new_node) {
273 if (path.empty()) {
274 this->root = new_node;
275 break;
276 }
277 if (path_d.top()) {
278 path.top()->left = new_node;
279 } else {
280 path.top()->right = new_node;
281 }
282 }
283 }
284 return res;
285 }
286
287 vector<T> tovector() {
288 NodePtr node = root;
289 stack<NodePtr> s;
290 vector<T> a;
291 a.reserve(len());
292 while ((!s.empty()) || node) {
293 if (node) {
294 s.emplace(node);
295 node = node->left;
296 } else {
297 node = s.top();
298 s.pop();
299 a.emplace_back(node->key);
300 node = node->right;
301 }
302 }
303 return a;
304 }
305
306 void set(int k, T v) {
307 assert(0 <= k && k < len());
308 NodePtr node = root;
309 NodePtr root = node;
310 NodePtr pnode = nullptr;
311 int d = 0;
312 stack<NodePtr> path;
313 path.emplace(node);
314 while (1) {
315 int t = node->left ? node->left->size : 0;
316 if (t == k) {
317 node->key = v;
318 path.emplace(node);
319 if (pnode) {
320 if (d) pnode->left = node;
321 else pnode->right = node;
322 }
323 while (!path.empty()) {
324 path.top()->update();
325 path.pop();
326 }
327 return;
328 }
329 pnode = node;
330 d = (t < k) ? 0 : 1;
331 if (d) {
332 pnode->left = node = node->left;
333 } else {
334 k -= t + 1;
335 pnode->right = node = node->right;
336 }
337 path.emplace(node);
338 }
339 }
340
341 T get(int k) {
342 assert(0 <= k && k < len());
343 NodePtr node = root;
344 while (1) {
345 int t = node->left ? node->left->size : 0;
346 if (t == k) {
347 return node->key;
348 }
349 if (t < k) {
350 k -= t + 1;
351 node = node->right;
352 } else {
353 node = node->left;
354 }
355 }
356 }
357
358 void print() {
359 vector<T> a = tovector();
360 cout << "[";
361 for (int i = 0; i < (int)a.size()-1; ++i) {
362 cout << a[i] << ", ";
363 }
364 if (!a.empty()) cout << a.back();
365 cout << "]" << endl;
366 }
367
368 friend ostream& operator<<(ostream& os, WBSegTree<T, op, e> &tree) {
369 vector<T> a = tree.tovector();
370 os << "[";
371 for (int i = 0; i < (int)a.size()-1; ++i) {
372 os << a[i] << ", ";
373 }
374 if (!a.empty()) os << a.back();
375 os << "]";
376 return os;
377 }
378
379 int len() const {
380 return root ? root->size : 0;
381 }
382
383 void check() const {
384 auto rec = [&] (auto &&rec, NodePtr node) -> pair<int, int> {
385 int ls = 0, rs = 0;
386 int height = 0;
387 int h;
388 if (node->left) {
389 pair<int, int> res = rec(rec, node->left);
390 ls = res.first;
391 h = res.second;
392 height = max(height, h);
393 }
394 if (node->right) {
395 pair<int, int> res = rec(rec, node->right);
396 rs = res.first;
397 h = res.second;
398 height = max(height, h);
399 }
400 node->balance_check();
401 int s = ls + rs + 1;
402 assert(s == node->size);
403 return {s, height+1};
404 };
405 if (root == nullptr) return;
406 auto [_, h] = rec(rec, root);
407 cerr << PRINT_GREEN << "OK : height=" << h << PRINT_NONE << endl;
408 }
409 };
410} // namespace titan23
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/wb_tree_seg.cpp”