wavelet matrix

ソースコード

  1#include <vector>
  2#include <queue>
  3#include "titan_cpplib/data_structures/bit_vector.cpp"
  4using namespace std;
  5
  6// WaveletMatrix
  7namespace titan23 {
  8
  9    template<typename T>
 10    class WaveletMatrix {
 11
 12      private:
 13        T sigma;
 14        int log;
 15        vector<BitVector> v;
 16        vector<int> mid;
 17        int n;
 18
 19        int bit_length(const int n) const {
 20            return n == 0 ? 0 : 32 - __builtin_clz(n);
 21        }
 22
 23        void build(vector<T> a) {
 24            for (int bit = log-1; bit >= 0; --bit) {
 25                vector<T> zero, one;
 26                v[bit] = BitVector(n);
 27                for (int i = 0; i < n; ++i) {
 28                    if ((a[i] >> bit) & 1) {
 29                        v[bit].set(i);
 30                        one.emplace_back(a[i]);
 31                    } else {
 32                        zero.emplace_back(a[i]);
 33                    }
 34                }
 35                v[bit].build();
 36                mid[bit] = zero.size();
 37                a = zero;
 38                a.insert(a.end(), one.begin(), one.end());
 39                assert(a.size() == n);
 40            }
 41        }
 42
 43      public:
 44        WaveletMatrix() {}
 45
 46        WaveletMatrix(const T sigma)
 47            : sigma(sigma), log(bit_length(sigma-1)), v(log), mid(log), n(0) {}
 48
 49        WaveletMatrix(const T sigma, const vector<T> &a)
 50            : sigma(sigma), log(bit_length(sigma-1)), v(log), mid(log), n(a.size()) {
 51            build(a);
 52        }
 53
 54        T access(int k) const {
 55            T s = 0;
 56            for (int bit = log-1; bit >= 0; --bit) {
 57                if (v[bit].access(k)) {
 58                    s |= (T)1 << bit;
 59                    k = v[bit].rank1(k) + mid[bit];
 60                } else {
 61                    k = v[bit].rank0(k);
 62                }
 63            }
 64            return s;
 65        }
 66
 67        //! `a[0, r)` に含まれる `x` の個数を返します。
 68        int rank(int r, int x) const {
 69            int l = 0;
 70            for (int bit = log-1; bit >= 0; --bit) {
 71                if ((x >> bit) & 1) {
 72                    l = v[bit].rank1(l) + mid[bit];
 73                    r = v[bit].rank1(r) + mid[bit];
 74                } else {
 75                    l = v[bit].rank0(l);
 76                    r = v[bit].rank0(r);
 77                }
 78            }
 79            return r - l;
 80        }
 81
 82        // `k` 番目の `v` のインデックスを返す。
 83        int select(int k, int x) const {
 84            int s = 0;
 85            for (int bit = log-1; bit >= 0; --bit) {
 86                if ((x >> bit) & 1) {
 87                    s = v[bit].rank0(n) + v[bit].rank1(s);
 88                } else {
 89                    s = v[bit].rank0(s);
 90                }
 91            }
 92            s += k;
 93            for (int bit = 0; bit < log; ++bit) {
 94                if ((x >> bit) & 1) {
 95                    s = v[bit].select1(s - v[bit].rank0(n));
 96                } else {
 97                    s = v[bit].select0(s);
 98                }
 99            }
100            return s;
101        }
102
103        // `a[l, r)` の中で k 番目に **小さい** 値を返します。
104        T kth_smallest(int l, int r, int k) const {
105            T s = 0;
106            for (int bit = log-1; bit >= 0; --bit) {
107                const int r0 = v[bit].rank0(r), l0 = v[bit].rank0(l);
108                const int cnt = r0 - l0;
109                if (cnt <= k) {
110                    s |= (T)1 << bit;
111                    k -= cnt;
112                    l = l - l0 + mid[bit];
113                    r = r - r0 + mid[bit];
114                } else {
115                    l = l0;
116                    r = r0;
117                }
118            }
119            return s;
120        }
121
122        T kth_largest(int l, int r, int k) const {
123            return kth_smallest(l, r, r-l-k-1);
124        }
125
126        // `a[l, r)` の中で、要素を出現回数が多い順にその頻度とともに `k` 個返します。
127        vector<pair<int, int>> topk(int l, int r, int k) {
128            // heap[-length, x, l, bit]
129            priority_queue<tuple<int, T, int, int>> hq;
130            hq.emplace(r-l, 0, l, log-1);
131            vector<pair<T, int>> ans;
132            while (!hq.empty()) {
133                auto [length, x, l, bit] = hq.top();
134                hq.pop();
135                if (bit == -1) {
136                    ans.emplace_back(x, length);
137                    k -= 1;
138                    if (k == 0) break;
139                } else {
140                    r = l + length;
141                    int l0 = v[bit].rank0(l);
142                    int r0 = v[bit].rank0(r);
143                    if (l0 < r0) hq.emplace(r0-l0, x, l0, bit-1);
144                    int l1 = v[bit].rank1(l) + mid[bit];
145                    int r1 = v[bit].rank1(r) + mid[bit];
146                    if (l1 < r1) hq.emplace(r1-l1, x|((T)1<<bit), l1, bit-1);
147                }
148            }
149            return ans;
150        }
151
152        T sum(int l, int r) const {
153            assert(false);
154            T s = 0;
155            for (const auto &[sum, cnt]: topk(l, r, r-l)) {
156                s += sum * cnt;
157            }
158            return s;
159        }
160
161        // a[l, r) で x 未満の要素の数を返す'''
162        int range_freq(int l, int r, int x) const {
163            int ans = 0;
164            for (int bit = log-1; bit >= 0; --bit) {
165                int l0 = v[bit].rank0(l), r0 = v[bit].rank0(r);
166                if ((x >> bit) & 1) {
167                    ans += r0 - l0;
168                    l += mid[bit] - l0;
169                    r += mid[bit] - r0;
170                } else {
171                    l = l0;
172                    r = r0;
173                }
174            }
175            return ans;
176        }
177
178        //`a[l, r)` に含まれる、 `x` 以上 `y` 未満である要素の個数を返します。
179        int range_freq(int l, int r, int x, int y) const {
180            return range_freq(l, r, y) - range_freq(l, r, x);
181        }
182
183        //`a[l, r)` で、`x` 以上 `y` 未満であるような要素のうち最大の要素を返します。
184        T prev_value(int l, int r, int x) const {
185            return kth_smallest(l, r, range_freq(l, r, x)-1);
186        }
187
188        T next_value(int l, int r, int x) const {
189            return kth_smallest(l, r, range_freq(l, r, x));
190        }
191
192        //`a[l, r)` に含まれる `x` の個数を返します。
193        int range_count(int l, int r, int x) const {
194            return rank(r, x) - rank(l, x);
195        }
196
197        int len() const {
198            return n;
199        }
200
201        friend ostream& operator<<(ostream& os, const titan23::WaveletMatrix<T>& wm) {
202            int n = wm.len();
203            os << "[";
204            for (int i = 0; i < n; ++i) {
205                os << wm.access(i);
206                if (i != n-1) os << ", ";
207            }
208            os << "]";
209            return os;
210        }
211    };
212}  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/wavelet_matrix.cpp