Dynamic Wavelet Tree
ソースコード
#include <cassert>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>
#include "titan_cpplib/data_structures/avl_tree_bit_vector.cpp"
#include "titan_cpplib/others/print.cpp"
using namespace std;
6
7// DynamicWaveletTree
8namespace titan23 {
9
10 /**
11 * @brief 動的ウェーブレット木
12 *
13 * @tparam T 値の型
14 */
15 template<typename T>
16 class DynamicWaveletTree {
17 private:
18 struct Node;
19 Node* root;
20 T _sigma;
21 int _log;
22 int _size;
23
24 struct Node {
25 Node* left;
26 Node* right;
27 Node* par;
28 AVLTreeBitVector v;
29 Node() : left(nullptr), right(nullptr), par(nullptr) {}
30 Node(const vector<uint8_t> &a) : left(nullptr), right(nullptr), par(nullptr) {
31 v = AVLTreeBitVector(a);
32 }
33 };
34
35 int bit_length(const int n) const {
36 return n > 0 ? 32 - __builtin_clz(n) : 0;
37 }
38
39 void _build(const vector<T> &a) {
40 vector<int> buff0(a.size()), buff1;
41 auto build = [&] (auto &&build,
42 int bit,
43 bool flag01,
44 int s0, int g0,
45 int s1, int g1
46 ) -> Node* {
47 int s = flag01 ? s1 : s0;
48 int g = flag01 ? g1 : g0;
49 if (s == g || bit < 0) return nullptr;
50 vector<int> &vec = flag01 ? buff1 : buff0;
51 vector<uint8_t> v(g-s, 0);
52 int start_0 = buff0.size(), start_1 = buff1.size();
53 for (int i = s; i < g; ++i) {
54 if (a[vec[i]] >> bit & 1) {
55 v[i-s] = 1;
56 buff1.emplace_back(vec[i]);
57 } else {
58 buff0.emplace_back(vec[i]);
59 }
60 }
61 int end_0 = buff0.size(), end_1 = buff1.size();
62 Node* node = new Node(v);
63 node->left = build(build, bit-1, 0, start_0, end_0, start_1, end_1);
64 if (node->left) node->left->par = node;
65 node->right = build(build, bit-1, 1, start_0, end_0, start_1, end_1);
66 if (node->right) node->right->par = node;
67 return node;
68 };
69 for (int i = 0; i < a.size(); ++i) {
70 buff0[i] = i;
71 }
72 this->root = build(build, _log-1, 0, 0, a.size(), 0, 0);
73 if (this->root == nullptr) {
74 this->root = new Node();
75 }
76 }
77
78 public:
79 //! 各要素が `[0, sigma)` の `DynamicWaveletTree` を作成する / `O(1)`
80 DynamicWaveletTree(const T sigma)
81 : _sigma(sigma), _log(bit_length(sigma)), _size(0) {
82 root = new Node();
83 }
84
85 //! 各要素が `[0, sigma)` の `DynamicWaveletTree` を作成する / `O(nlog(σ))`
86 DynamicWaveletTree(const T sigma, vector<T> &a)
87 : _sigma(sigma), _log(bit_length(sigma)), _size(a.size()) {
88 _build(a);
89 }
90
91 //! 位置 `k` に `x` を挿入する / `O(log(n)log(σ))`
92 void insert(int k, T x) {
93 assert(0 <= k && k <= len());
94 assert(0 <= x && x < _sigma);
95 Node* node = root;
96 for (int bit = _log-1; bit >= 0; --bit) {
97 if ((x >> bit) & 1) {
98 k = node->v._insert_and_rank1(k, 1);
99 if (!node->right) {
100 node->right = new Node();
101 node->right->par = node;
102 }
103 node = node->right;
104 } else {
105 k -= node->v._insert_and_rank1(k, 0);
106 if (!node->left) {
107 node->left = new Node();
108 node->left->par = node;
109 }
110 node = node->left;
111 }
112 }
113 _size++;
114 }
115
116 //! 位置 `k` の値を削除して返す / `O(log(n)log(σ))`
117 T pop(int k) {
118 assert(0 <= k && k < len());
119 Node* node = root;
120 T ans = 0;
121 for (int bit = _log-1; node && bit >= 0; --bit) {
122 int sb = node->v._access_pop_and_rank1(k);
123 if (sb & 1) {
124 ans |= (T)1 << bit;
125 k = sb >> 1;
126 node = node->right;
127 } else {
128 k -= sb >> 1;
129 node = node->left;
130 }
131 }
132 _size--;
133 return ans;
134 }
135
136 //! 位置 `k` の値を `x` に更新する / `O(log(n)log(σ))`
137 void set(int k, T x) {
138 assert(0 <= k && k < len());
139 assert(0 <= x && x < _sigma);
140 pop(k);
141 insert(k, x);
142 }
143
144 //! 区間 `[0, r)` の `x` の個数を返す / `O(log(n)log(σ))`
145 int rank(int r, T x) const {
146 assert(0 <= r && r <= len());
147 Node* node = root;
148 int l = 0;
149 for (int bit = _log-1; node && bit >= 0; --bit) {
150 if ((x >> bit) & 1) {
151 l = node->v.rank1(l);
152 r = node->v.rank1(r);
153 node = node->right;
154 } else {
155 l = node->v.rank0(l);
156 r = node->v.rank0(r);
157 node = node->left;
158 }
159 }
160 return r - l;
161 }
162
163 //! 区間 `[l, r)` の `x` の個数を返す / `O(log(n)log(σ))`
164 int range_count(int l, int r, T x) const {
165 assert(0 <= l && l <= r && r <= len());
166 return rank(r, x) - rank(l, x);
167 }
168
169 //! `k` 番目の要素を返す / `O(log(n)log(σ))`
170 T access(int k) const {
171 assert(0 <= k && k < len());
172 Node* node = root;
173 T s = 0;
174 for (int bit = _log-1; bit >= 0; --bit) {
175 auto [b, r] = node->v._access_ans_rank1(k);
176 if (b) {
177 s |= (T)1 << bit;
178 k = r;
179 node = node->right;
180 } else {
181 k -= r;
182 node = node->left;
183 }
184 }
185 return s;
186 }
187
188 //! 区間 `[l, r)` で昇順 `k` 番目の値を返す / `O(log(n)log(σ))`
189 T kth_smallest(int l, int r, int k) const {
190 assert(0 <= l && l <= r && r <= len());
191 assert(0 <= k && k < r-l);
192 Node* node = root;
193 T s = 0;
194 for (int bit = _log-1; node && bit >= 0; --bit) {
195 int l0 = node->v.rank0(l);
196 int r0 = node->v.rank0(r);
197 int cnt = r0 - l0;
198 if (cnt <= k) {
199 s |= (T)1 << bit;
200 k -= cnt;
201 l -= l0;
202 r -= r0;
203 node = node->right;
204 } else {
205 l = l0;
206 r = r0;
207 node = node->left;
208 }
209 }
210 return s;
211 }
212
213 //! 区間 `[l, r)` で降順 `k` 番目の値を返す / `O(log(n)log(σ))`
214 T kth_largest(int l, int r, int k) const {
215 return kth_smallest(l, r, r-l-k-1);
216 }
217
218 //! 区間 `[l, r)` で `x` 未満の要素の個数を返す / `O(log(n)log(σ))`
219 int range_freq(int l, int r, const T &x) const {
220 Node* node = root;
221 int ans = 0;
222 for (int bit = _log-1; node && bit >= 0; --bit) {
223 int l0 = node->v.rank0(l);
224 int r0 = node->v.rank0(r);
225 if ((x >> bit) & 1) {
226 ans += r0 - l0;
227 l -= l0;
228 r -= r0;
229 node = node->right;
230 } else {
231 l = l0;
232 r = r0;
233 node = node->left;
234 }
235 }
236 return ans;
237 }
238
239 //! 区間 `[l, r)` で `x` 以上 `y` 未満の要素の個数を返す / `O(log(n)log(σ))`
240 int range_freq(int l, int r, int x, int y) const {
241 return range_freq(l, r, y) - range_freq(l, r, x);
242 }
243
244 //! `k` 番目の `x` の位置を返す / `O(log(n)log(σ))`
245 int select(int k, T x) const {
246 Node* node = root;
247 for (int bit = _log-1; bit > 0; --bit) {
248 if ((x >> bit) & 1) {
249 node = node->right;
250 } else {
251 node = node->left;
252 }
253 }
254 for (int bit = 0; bit < _log; ++bit) {
255 if ((x >> bit) & 1) {
256 k = node->v.select1(k);
257 } else {
258 k = node->v.select0(k);
259 }
260 node = node->par;
261 }
262 return k;
263 }
264
265 //! `k` 番目の `x` の位置を返して削除する / `O(log(n)log(σ))`
266 int select_remove(int k, T x) {
267 Node* node = root;
268 for (int bit = _log-1; bit > 0; --bit) {
269 if ((x >> bit) & 1) {
270 node = node->right;
271 } else {
272 node = node->left;
273 }
274 }
275 for (int bit = 0; bit < _log; ++bit) {
276 if ((x >> bit) & 1) {
277 k = node->v.select1(k);
278 } else {
279 k = node->v.select0(k);
280 }
281 node->v.pop(k);
282 node = node->par;
283 }
284 _size--;
285 return k;
286 }
287
288 //! 区間[l, r)で、x未満のうち最大の要素を返す
289 T prev_value(int l, int r, T x) const {
290 int k = range_freq(l, r, x)-1;
291 if (k < 0) return -1;
292 return kth_smallest(l, r, k);
293 }
294
295 //! 区間[l, r)で、x以上のうち最小の要素を返す
296 T next_value(int l, int r, T x) const {
297 int k = range_freq(l, r, x);
298 if (k >= r-l) return -1;
299 return kth_smallest(l, r, k);
300 }
301
302 //! 要素数を返す / `O(1)`
303 int len() const {
304 return _size;
305 }
306
307 //! `vector` にして返す / `O(nlog(σ))`
308 //! (n 回 access するよりも高速)
309 vector<T> tovector() const {
310 vector<T> a(len(), 0);
311 vector<int> buff0(a.size()), buff1;
312 auto dfs = [&] (auto &&dfs,
313 Node* node,
314 int bit,
315 bool flag01,
316 int s0, int g0,
317 int s1, int g1
318 ) -> void {
319 int s = flag01 ? s1 : s0;
320 int g = flag01 ? g1 : g0;
321 if (s == g || bit < 0) return;
322 vector<int> &vec = flag01 ? buff1 : buff0;
323 const vector<uint8_t> &v = node->v.tovector();
324 int start_0 = buff0.size(), start_1 = buff1.size();
325 for (int i = s; i < g; ++i) {
326 if (v[i-s]) {
327 a[vec[i]] |= (T)1 << bit;
328 buff1.emplace_back(vec[i]);
329 } else {
330 buff0.emplace_back(vec[i]);
331 }
332 }
333 int end_0 = buff0.size(), end_1 = buff1.size();
334 dfs(dfs, node->left, bit-1, 0, start_0, end_0, start_1, end_1);
335 dfs(dfs, node->right, bit-1, 1, start_0, end_0, start_1, end_1);
336 };
337 for (int i = 0; i < a.size(); ++i) {
338 buff0[i] = i;
339 }
340 dfs(dfs, this->root, _log-1, 0, 0, a.size(), 0, 0);
341 return a;
342 }
343
344 //! 表示する / `O(nlog(σ))`
345 void print() const {
346 vector<T> a = tovector();
347 int n = (int)a.size();
348 cout << "[";
349 for (int i = 0; i < n-1; ++i) {
350 cout << a[i] << ", ";
351 }
352 if (n > 0) {
353 cout << a.back();
354 }
355 cout << "]";
356 cout << endl;
357 }
358
359 friend ostream& operator<<(ostream& os, const titan23::DynamicWaveletTree<T> &dwm) {
360 vector<T> a = dwm.tovector();
361 os << a;
362 return os;
363 }
364 };
365} // namespace titan23
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/dynamic_wavelet_tree.cpp”