dynamic wavelet matrix

ソースコード

#include <cassert>
#include <cstdint>
#include <iostream>
#include <queue>
#include <tuple>
#include <vector>
#include "titan_cpplib/data_structures/avl_tree_bit_vector.cpp"
#include "titan_cpplib/others/print.cpp"
using namespace std;
  8
  9// DynamicWaveletMatrix
 10namespace titan23 {
 11
 12    /**
 13     * @brief 動的ウェーブレット行列
 14     *
 15     * @note `DynamicWavletTree` の方が平均的には高速かもしれません
 16     *
 17     * @tparam T
 18     */
 19    template<typename T>
 20    class DynamicWaveletMatrix {
 21      private:
 22        T _sigma;
 23        int _log;
 24        vector<AVLTreeBitVector> _v;
 25        vector<int> _mid;
 26        int _size;
 27
 28        int bit_length(const int n) const {
 29            return n == 0 ? 0 : 32 - __builtin_clz(n);
 30        }
 31
 32        int bit_length(const long long n) const {
 33            return n == 0 ? 0 : 64 - __builtin_clz(n);
 34        }
 35
 36        int bit_length(const unsigned long long n) const {
 37            return n == 0 ? 0 : 64 - __builtin_clz(n);
 38        }
 39
 40        void _build(vector<T> a) {
 41            if (a.empty()) return;
 42            vector<uint8_t> v(_size);
 43            for (int bit = _log-1; bit >= 0; --bit) {
 44                vector<T> zero, one;
 45                for (int i = 0; i < _size; ++i) {
 46                    if ((a[i] >> bit) & 1) {
 47                        v[i] = 1;
 48                        one.emplace_back(a[i]);
 49                    } else {
 50                        v[i] = 0;
 51                        zero.emplace_back(a[i]);
 52                    }
 53                }
 54                _mid[bit] = zero.size();
 55                _v[bit] = AVLTreeBitVector(v);
 56                a = zero;
 57                a.insert(a.end(), one.begin(), one.end());
 58            }
 59        }
 60
 61      public:
 62        DynamicWaveletMatrix(const T sigma)
 63                : _sigma(sigma), _log(bit_length(sigma-1)), _v(_log), _mid(_log), _size(0) {
 64        }
 65
 66        DynamicWaveletMatrix(const T sigma, vector<T> &a)
 67                : _sigma(sigma), _log(bit_length(sigma)), _v(_log), _mid(_log), _size(a.size()) {
 68            _build(a);
 69        }
 70
 71        void reserve(const int n) {
 72            for (int i = 0; i < _log; ++i) {
 73                _v[i].reserve(n);
 74            }
 75        }
 76
 77        void insert(int k, T x) {
 78            for (int bit = _log-1; bit >= 0; --bit) {
 79                if ((x >> bit) & 1) {
 80                    int s = _v[bit]._insert_and_rank1(k, 1);
 81                    k = s + _mid[bit];
 82                } else {
 83                    int s = _v[bit]._insert_and_rank1(k, 0);
 84                    k -= s;
 85                    ++_mid[bit];
 86                }
 87            }
 88            _size++;
 89        }
 90
 91        T pop(int k) {
 92            T ans = 0;
 93            for (int bit = _log-1; bit >= 0; --bit) {
 94                int sb = _v[bit]._access_pop_and_rank1(k);
 95                int s = sb >> 1;
 96                if (sb & 1) {
 97                    ans |= (T)1 << bit;
 98                    k = s + _mid[bit];
 99                } else {
100                    --_mid[bit];
101                    k -= s;
102                }
103            }
104            _size--;
105            return ans;
106        }
107
108        void set(int k, T x) {
109            assert(0 <= k && k < _size);
110            pop(k);
111            insert(k, x);
112        }
113
114        int rank(int r, T x) const {
115            int l = 0;
116            for (int bit = _log-1; bit >= 0; --bit) {
117                if ((x >> bit) & 1) {
118                    l = _v[bit].rank1(l) + _mid[bit];
119                    r = _v[bit].rank1(r) + _mid[bit];
120                } else {
121                    l = _v[bit].rank0(l);
122                    r = _v[bit].rank0(r);
123                }
124            }
125            return r - l;
126        }
127
128        T access(int k) const {
129            T s = 0;
130            for (int bit = _log-1; bit >= 0; --bit) {
131                if (_v[bit].access(k)) {
132                    s |= (T)1 << bit;
133                    k = _v[bit].rank1(k) + _mid[bit];
134                } else {
135                    k = _v[bit].rank0(k);
136                }
137            }
138            return s;
139        }
140
141        T kth_smallest(int l, int r, int k) const {
142            T s = 0;
143            for (int bit = _log-1; bit >= 0; --bit) {
144                int l0 = _v[bit].rank0(l);
145                int r0 = _v[bit].rank0(r);
146                int cnt = r0 - l0;
147                if (cnt <= k) {
148                    s |= (T)1 << bit;
149                    k -= cnt;
150                    l = l - l0 + _mid[bit];
151                    r = r - r0 + _mid[bit];
152                } else {
153                    l = l0;
154                    r = r0;
155                }
156            }
157            return s;
158        }
159
160        T kth_largest(int l, int r, int k) const {
161            return kth_smallest(l, r, r-l-k-1);
162        }
163
164        vector<pair<T, int>> topk(int l, int r, int k) {
165            priority_queue<tuple<int, T, int, char>> hq;
166            hq.emplace(r-l, 0, l, _log-1);
167            vector<pair<T, int>> ans;
168            while (!hq.empty()) {
169                auto [length, x, l, bit] = hq.top();
170                hq.pop();
171                if (bit == -1) {
172                    ans.emplace_back(x, length);
173                    --k;
174                    if (k == 0) break;
175                } else {
176                    int r = l + length;
177                    int l0 = _v[bit].rank0(l);
178                    int r0 = _v[bit].rank0(r);
179                    if (l0 < r0) hq.emplace(r0-l0, x, l0, bit-1);
180                    int l1 = _v[bit].rank1(l) + _mid[bit];
181                    int r1 = _v[bit].rank1(r) + _mid[bit];
182                    if (l1 < r1) hq.emplace(r1-l1, x|((T)1<<(T)bit), l1, bit-1);
183                }
184            }
185            return ans;
186        }
187
188        int select(int k, T x) const {
189            T s = 0;
190            for (int bit = _log-1; bit >= 0; --bit) {
191                if ((x >> bit) & 1) {
192                    s = _v[bit].rank0(_size) + _v[bit].rank1(s);
193                } else {
194                    s = _v[bit].rank0(s);
195                }
196            }
197            s += k;
198            for (int bit = 0; bit < _log; ++bit) {
199                if ((x >> bit) & 1) {
200                    s = _v[bit].select1(s - _v[bit].rank0(_size));
201                } else {
202                    s = _v[bit].select0(s);
203                }
204            }
205            return s;
206        }
207
208        int range_freq(int l, int r, T x) const {
209            int ans = 0;
210            for (int bit = _log-1; bit >= 0; --bit) {
211                int l0 = _v[bit].rank0(l);
212                int r0 = _v[bit].rank0(r);
213                if ((x >> bit) & 1) {
214                    ans += r0 - l0;
215                    l += _mid[bit] - l0;
216                    r += _mid[bit] - r0;
217                } else {
218                    l = l0;
219                    r = r0;
220                }
221            }
222            return ans;
223        }
224
225        int range_freq(int l, int r, int x, int y) const {
226            return range_freq(l, r, y) - range_freq(l, r, x);
227        }
228
229        //! 区間[l, r)で、x未満のうち最大の要素を返す
230        T prev_value(int l, int r, T x) const {
231            int k = range_freq(l, r, x)-1;
232            if (k < 0) return -1;
233            return kth_smallest(l, r, k);
234        }
235
236        //! 区間[l, r)で、x以上のうち最小の要素を返す
237        T next_value(int l, int r, T x) const {
238            int k = range_freq(l, r, x);
239            if (k >= r-l) return -1;
240            return kth_smallest(l, r, k);
241        }
242
243        int range_count(int l, int r, T x) const {
244            return rank(r, x) - rank(l, x);
245        }
246
247        vector<T> tovector() const {
248            vector<T> a(_size);
249            for (int i = 0; i < _size; ++i) {
250                a[i] = access(i);
251            }
252            return a;
253        }
254
255        void print() const {
256            vector<T> a = tovector();
257            int n = (int)a.size();
258            cout << "[";
259            for (int i = 0; i < n-1; ++i) {
260                cout << a[i] << ", ";
261            }
262            if (n > 0) {
263                cout << a.back();
264            }
265            cout << "]";
266            cout << endl;
267        }
268
269        friend ostream& operator<<(ostream& os, const titan23::DynamicWaveletMatrix<T> &dwm) {
270            vector<T> a = dwm.tovector();
271            os << a;
272            return os;
273        }
274    };
275} // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/dynamic_wavelet_matrix.cpp”