Dynamic Wavelet Tree

ソースコード

#include <cassert>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>
#include "titan_cpplib/data_structures/avl_tree_bit_vector.cpp"
#include "titan_cpplib/others/print.cpp"
using namespace std;

  7// DynamicWaveletTree
  8namespace titan23 {
  9
 10    /**
 11     * @brief 動的ウェーブレット木
 12     * 
 13     * @tparam T 値の型
 14     */
 15    template<typename T>
 16    class DynamicWaveletTree {
 17      private:
 18        struct Node;
 19        Node* root;
 20        T _sigma;
 21        int _log;
 22        int _size;
 23
 24        struct Node {
 25            Node* left;
 26            Node* right;
 27            Node* par;
 28            AVLTreeBitVector v;
 29            Node() : left(nullptr), right(nullptr), par(nullptr) {}
 30            Node(const vector<uint8_t> &a) : left(nullptr), right(nullptr), par(nullptr) {
 31                v = AVLTreeBitVector(a);
 32            }
 33        };
 34
 35        int bit_length(const int n) const {
 36            return n > 0 ? 32 - __builtin_clz(n) : 0;
 37        }
 38
 39        void _build(const vector<T> &a) {
 40            vector<int> buff0(a.size()), buff1;
 41            auto build = [&] (auto &&build,
 42                              int bit,
 43                              bool flag01,
 44                              int s0, int g0,
 45                              int s1, int g1
 46                              ) -> Node* {
 47                int s = flag01 ? s1 : s0;
 48                int g = flag01 ? g1 : g0;
 49                if (s == g || bit < 0) return nullptr;
 50                vector<int> &vec = flag01 ? buff1 : buff0;
 51                vector<uint8_t> v(g-s, 0);
 52                int start_0 = buff0.size(), start_1 = buff1.size();
 53                for (int i = s; i < g; ++i) {
 54                    if (a[vec[i]] >> bit & 1) {
 55                        v[i-s] = 1;
 56                        buff1.emplace_back(vec[i]);
 57                    } else {
 58                        buff0.emplace_back(vec[i]);
 59                    }
 60                }
 61                int end_0 = buff0.size(), end_1 = buff1.size();
 62                Node* node = new Node(v);
 63                node->left  = build(build, bit-1, 0, start_0, end_0, start_1, end_1);
 64                if (node->left) node->left->par = node;
 65                node->right = build(build, bit-1, 1, start_0, end_0, start_1, end_1);
 66                if (node->right) node->right->par = node;
 67                return node;
 68            };
 69            for (int i = 0; i < a.size(); ++i) {
 70                buff0[i] = i;
 71            }
 72            this->root = build(build, _log-1, 0, 0, a.size(), 0, 0);
 73            if (this->root == nullptr) {
 74                this->root = new Node();
 75            }
 76        }
 77
 78      public:
 79        //! 各要素が `[0, sigma)` の `DynamicWaveletTree` を作成する / `O(1)`
 80        DynamicWaveletTree(const T sigma)
 81                : _sigma(sigma), _log(bit_length(sigma)), _size(0) {
 82            root = new Node();
 83        }
 84
 85        //! 各要素が `[0, sigma)` の `DynamicWaveletTree` を作成する / `O(nlog(σ))`
 86        DynamicWaveletTree(const T sigma, vector<T> &a)
 87                : _sigma(sigma), _log(bit_length(sigma)), _size(a.size()) {
 88            _build(a);
 89        }
 90
 91        //! 位置 `k` に `x` を挿入する / `O(log(n)log(σ))`
 92        void insert(int k, T x) {
 93            assert(0 <= k && k <= len());
 94            assert(0 <= x && x < _sigma);
 95            Node* node = root;
 96            for (int bit = _log-1; bit >= 0; --bit) {
 97                if ((x >> bit) & 1) {
 98                    k = node->v._insert_and_rank1(k, 1);
 99                    if (!node->right) {
100                        node->right = new Node();
101                        node->right->par = node;
102                    }
103                    node = node->right;
104                } else {
105                    k -= node->v._insert_and_rank1(k, 0);
106                    if (!node->left) {
107                        node->left = new Node();
108                        node->left->par = node;
109                    }
110                    node = node->left;
111                }
112            }
113            _size++;
114        }
115
116        //! 位置 `k` の値を削除して返す / `O(log(n)log(σ))`
117        T pop(int k) {
118            assert(0 <= k && k < len());
119            Node* node = root;
120            T ans = 0;
121            for (int bit = _log-1; node && bit >= 0; --bit) {
122                int sb = node->v._access_pop_and_rank1(k);
123                if (sb & 1) {
124                    ans |= (T)1 << bit;
125                    k = sb >> 1;
126                    node = node->right;
127                } else {
128                    k -= sb >> 1;
129                    node = node->left;
130                }
131            }
132            _size--;
133            return ans;
134        }
135
136        //! 位置 `k` の値を `x` に更新する / `O(log(n)log(σ))`
137        void set(int k, T x) {
138            assert(0 <= k && k < len());
139            assert(0 <= x && x < _sigma);
140            pop(k);
141            insert(k, x);
142        }
143
144        //! 区間 `[0, r)` の `x` の個数を返す / `O(log(n)log(σ))`
145        int rank(int r, T x) const {
146            assert(0 <= r && r <= len());
147            Node* node = root;
148            int l = 0;
149            for (int bit = _log-1; node && bit >= 0; --bit) {
150                if ((x >> bit) & 1) {
151                    l = node->v.rank1(l);
152                    r = node->v.rank1(r);
153                    node = node->right;
154                } else {
155                    l = node->v.rank0(l);
156                    r = node->v.rank0(r);
157                    node = node->left;
158                }
159            }
160            return r - l;
161        }
162
163        //! 区間 `[l, r)` の `x` の個数を返す / `O(log(n)log(σ))`
164        int range_count(int l, int r, T x) const {
165            assert(0 <= l && l <= r && r <= len());
166            return rank(r, x) - rank(l, x);
167        }
168
169        //! `k` 番目の要素を返す / `O(log(n)log(σ))`
170        T access(int k) const {
171            assert(0 <= k && k < len());
172            Node* node = root;
173            T s = 0;
174            for (int bit = _log-1; bit >= 0; --bit) {
175                auto [b, r] = node->v._access_ans_rank1(k);
176                if (b) {
177                    s |= (T)1 << bit;
178                    k = r;
179                    node = node->right;
180                } else {
181                    k -= r;
182                    node = node->left;
183                }
184            }
185            return s;
186        }
187
188        //! 区間 `[l, r)` で昇順 `k` 番目の値を返す / `O(log(n)log(σ))`
189        T kth_smallest(int l, int r, int k) const {
190            assert(0 <= l && l <= r && r <= len());
191            assert(0 <= k && k < r-l);
192            Node* node = root;
193            T s = 0;
194            for (int bit = _log-1; node && bit >= 0; --bit) {
195                int l0 = node->v.rank0(l);
196                int r0 = node->v.rank0(r);
197                int cnt = r0 - l0;
198                if (cnt <= k) {
199                    s |= (T)1 << bit;
200                    k -= cnt;
201                    l -= l0;
202                    r -= r0;
203                    node = node->right;
204                } else {
205                    l = l0;
206                    r = r0;
207                    node = node->left;
208                }
209            }
210            return s;
211        }
212
213        //! 区間 `[l, r)` で降順 `k` 番目の値を返す / `O(log(n)log(σ))`
214        T kth_largest(int l, int r, int k) const {
215            return kth_smallest(l, r, r-l-k-1);
216        }
217
218        //! 区間 `[l, r)` で `x` 未満の要素の個数を返す / `O(log(n)log(σ))`
219        int range_freq(int l, int r, const T &x) const {
220            Node* node = root;
221            int ans = 0;
222            for (int bit = _log-1; node && bit >= 0; --bit) {
223                int l0 = node->v.rank0(l);
224                int r0 = node->v.rank0(r);
225                if ((x >> bit) & 1) {
226                    ans += r0 - l0;
227                    l -= l0;
228                    r -= r0;
229                    node = node->right;
230                } else {
231                    l = l0;
232                    r = r0;
233                    node = node->left;
234                }
235            }
236            return ans;
237        }
238
239        //! 区間 `[l, r)` で `x` 以上 `y` 未満の要素の個数を返す / `O(log(n)log(σ))`
240        int range_freq(int l, int r, int x, int y) const {
241            return range_freq(l, r, y) - range_freq(l, r, x);
242        }
243
244        //! `k` 番目の `x` の位置を返す / `O(log(n)log(σ))`
245        int select(int k, T x) const {
246            Node* node = root;
247            for (int bit = _log-1; bit > 0; --bit) {
248                if ((x >> bit) & 1) {
249                    node = node->right;
250                } else {
251                    node = node->left;
252                }
253            }
254            for (int bit = 0; bit < _log; ++bit) {
255                if ((x >> bit) & 1) {
256                    k = node->v.select1(k);
257                } else {
258                    k = node->v.select0(k);
259                }
260                node = node->par;
261            }
262            return k;
263        }
264
265        //! `k` 番目の `x` の位置を返して削除する / `O(log(n)log(σ))`
266        int select_remove(int k, T x) {
267            Node* node = root;
268            for (int bit = _log-1; bit > 0; --bit) {
269                if ((x >> bit) & 1) {
270                    node = node->right;
271                } else {
272                    node = node->left;
273                }
274            }
275            for (int bit = 0; bit < _log; ++bit) {
276                if ((x >> bit) & 1) {
277                    k = node->v.select1(k);
278                } else {
279                    k = node->v.select0(k);
280                }
281                node->v.pop(k);
282                node = node->par;
283            }
284            _size--;
285            return k;
286        }
287
288        //! 区間[l, r)で、x未満のうち最大の要素を返す
289        T prev_value(int l, int r, T x) const {
290            int k = range_freq(l, r, x)-1;
291            if (k < 0) return -1;
292            return kth_smallest(l, r, k);
293        }
294
295        //! 区間[l, r)で、x以上のうち最小の要素を返す
296        T next_value(int l, int r, T x) const {
297            int k = range_freq(l, r, x);
298            if (k >= r-l) return -1;
299            return kth_smallest(l, r, k);
300        }
301
302        //! 要素数を返す / `O(1)`
303        int len() const {
304            return _size;
305        }
306
307        //! `vector` にして返す / `O(nlog(σ))`
308        //! (n 回 access するよりも高速)
309        vector<T> tovector() const {
310            vector<T> a(len(), 0);
311            vector<int> buff0(a.size()), buff1;
312            auto dfs = [&] (auto &&dfs,
313                            Node* node,
314                            int bit,
315                            bool flag01,
316                            int s0, int g0,
317                            int s1, int g1
318                            ) -> void {
319                int s = flag01 ? s1 : s0;
320                int g = flag01 ? g1 : g0;
321                if (s == g || bit < 0) return;
322                vector<int> &vec = flag01 ? buff1 : buff0;
323                const vector<uint8_t> &v = node->v.tovector();
324                int start_0 = buff0.size(), start_1 = buff1.size();
325                for (int i = s; i < g; ++i) {
326                    if (v[i-s]) {
327                        a[vec[i]] |= (T)1 << bit;
328                        buff1.emplace_back(vec[i]);
329                    } else {
330                        buff0.emplace_back(vec[i]);
331                    }
332                }
333                int end_0 = buff0.size(), end_1 = buff1.size();
334                dfs(dfs, node->left,  bit-1, 0, start_0, end_0, start_1, end_1);
335                dfs(dfs, node->right, bit-1, 1, start_0, end_0, start_1, end_1);
336            };
337            for (int i = 0; i < a.size(); ++i) {
338                buff0[i] = i;
339            }
340            dfs(dfs, this->root, _log-1, 0, 0, a.size(), 0, 0);
341            return a;
342        }
343
344        //! 表示する / `O(nlog(σ))`
345        void print() const {
346            vector<T> a = tovector();
347            int n = (int)a.size();
348            cout << "[";
349            for (int i = 0; i < n-1; ++i) {
350                cout << a[i] << ", ";
351            }
352            if (n > 0) {
353                cout << a.back();
354            }
355            cout << "]";
356            cout << endl;
357        }
358
359        friend ostream& operator<<(ostream& os, const titan23::DynamicWaveletTree<T> &dwm) {
360            vector<T> a = dwm.tovector();
361            os << a;
362            return os;
363        }
364    };
365} // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/dynamic_wavelet_tree.cpp”