Dynamic Wavelet Matrix
ソースコード
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>
#include "titan_cpplib/data_structures/avl_tree_bit_vector.cpp"
#include "titan_cpplib/others/print.cpp"
using namespace std;
8
9// DynamicWaveletMatrix
10namespace titan23 {
11
12 /**
13 * @brief 動的ウェーブレット行列
14 *
15 * @note `DynamicWavletTree` の方が平均的には高速かもしれません
16 *
17 * @tparam T
18 */
19 template<typename T>
20 class DynamicWaveletMatrix {
21 private:
22 T _sigma;
23 int _log;
24 vector<AVLTreeBitVector> _v;
25 vector<int> _mid;
26 int _size;
27
28 int bit_length(const int n) const {
29 return n == 0 ? 0 : 32 - __builtin_clz(n);
30 }
31
32 int bit_length(const long long n) const {
33 return n == 0 ? 0 : 64 - __builtin_clz(n);
34 }
35
36 int bit_length(const unsigned long long n) const {
37 return n == 0 ? 0 : 64 - __builtin_clz(n);
38 }
39
40 void _build(vector<T> a) {
41 if (a.empty()) return;
42 vector<uint8_t> v(_size);
43 for (int bit = _log-1; bit >= 0; --bit) {
44 vector<T> zero, one;
45 for (int i = 0; i < _size; ++i) {
46 if ((a[i] >> bit) & 1) {
47 v[i] = 1;
48 one.emplace_back(a[i]);
49 } else {
50 v[i] = 0;
51 zero.emplace_back(a[i]);
52 }
53 }
54 _mid[bit] = zero.size();
55 _v[bit] = AVLTreeBitVector(v);
56 a = zero;
57 a.insert(a.end(), one.begin(), one.end());
58 }
59 }
60
61 public:
62 DynamicWaveletMatrix(const T sigma)
63 : _sigma(sigma), _log(bit_length(sigma-1)), _v(_log), _mid(_log), _size(0) {
64 }
65
66 DynamicWaveletMatrix(const T sigma, vector<T> &a)
67 : _sigma(sigma), _log(bit_length(sigma)), _v(_log), _mid(_log), _size(a.size()) {
68 _build(a);
69 }
70
71 void reserve(const int n) {
72 for (int i = 0; i < _log; ++i) {
73 _v[i].reserve(n);
74 }
75 }
76
77 void insert(int k, T x) {
78 for (int bit = _log-1; bit >= 0; --bit) {
79 if ((x >> bit) & 1) {
80 int s = _v[bit]._insert_and_rank1(k, 1);
81 k = s + _mid[bit];
82 } else {
83 int s = _v[bit]._insert_and_rank1(k, 0);
84 k -= s;
85 ++_mid[bit];
86 }
87 }
88 _size++;
89 }
90
91 T pop(int k) {
92 T ans = 0;
93 for (int bit = _log-1; bit >= 0; --bit) {
94 int sb = _v[bit]._access_pop_and_rank1(k);
95 int s = sb >> 1;
96 if (sb & 1) {
97 ans |= (T)1 << bit;
98 k = s + _mid[bit];
99 } else {
100 --_mid[bit];
101 k -= s;
102 }
103 }
104 _size--;
105 return ans;
106 }
107
108 void set(int k, T x) {
109 assert(0 <= k && k < _size);
110 pop(k);
111 insert(k, x);
112 }
113
114 int rank(int r, T x) const {
115 int l = 0;
116 for (int bit = _log-1; bit >= 0; --bit) {
117 if ((x >> bit) & 1) {
118 l = _v[bit].rank1(l) + _mid[bit];
119 r = _v[bit].rank1(r) + _mid[bit];
120 } else {
121 l = _v[bit].rank0(l);
122 r = _v[bit].rank0(r);
123 }
124 }
125 return r - l;
126 }
127
128 T access(int k) const {
129 T s = 0;
130 for (int bit = _log-1; bit >= 0; --bit) {
131 if (_v[bit].access(k)) {
132 s |= (T)1 << bit;
133 k = _v[bit].rank1(k) + _mid[bit];
134 } else {
135 k = _v[bit].rank0(k);
136 }
137 }
138 return s;
139 }
140
141 T kth_smallest(int l, int r, int k) const {
142 T s = 0;
143 for (int bit = _log-1; bit >= 0; --bit) {
144 int l0 = _v[bit].rank0(l);
145 int r0 = _v[bit].rank0(r);
146 int cnt = r0 - l0;
147 if (cnt <= k) {
148 s |= (T)1 << bit;
149 k -= cnt;
150 l = l - l0 + _mid[bit];
151 r = r - r0 + _mid[bit];
152 } else {
153 l = l0;
154 r = r0;
155 }
156 }
157 return s;
158 }
159
160 T kth_largest(int l, int r, int k) const {
161 return kth_smallest(l, r, r-l-k-1);
162 }
163
164 vector<pair<T, int>> topk(int l, int r, int k) {
165 priority_queue<tuple<int, T, int, char>> hq;
166 hq.emplace(r-l, 0, l, _log-1);
167 vector<pair<T, int>> ans;
168 while (!hq.empty()) {
169 auto [length, x, l, bit] = hq.top();
170 hq.pop();
171 if (bit == -1) {
172 ans.emplace_back(x, length);
173 --k;
174 if (k == 0) break;
175 } else {
176 int r = l + length;
177 int l0 = _v[bit].rank0(l);
178 int r0 = _v[bit].rank0(r);
179 if (l0 < r0) hq.emplace(r0-l0, x, l0, bit-1);
180 int l1 = _v[bit].rank1(l) + _mid[bit];
181 int r1 = _v[bit].rank1(r) + _mid[bit];
182 if (l1 < r1) hq.emplace(r1-l1, x|((T)1<<(T)bit), l1, bit-1);
183 }
184 }
185 return ans;
186 }
187
188 int select(int k, T x) const {
189 T s = 0;
190 for (int bit = _log-1; bit >= 0; --bit) {
191 if ((x >> bit) & 1) {
192 s = _v[bit].rank0(_size) + _v[bit].rank1(s);
193 } else {
194 s = _v[bit].rank0(s);
195 }
196 }
197 s += k;
198 for (int bit = 0; bit < _log; ++bit) {
199 if ((x >> bit) & 1) {
200 s = _v[bit].select1(s - _v[bit].rank0(_size));
201 } else {
202 s = _v[bit].select0(s);
203 }
204 }
205 return s;
206 }
207
208 int range_freq(int l, int r, T x) const {
209 int ans = 0;
210 for (int bit = _log-1; bit >= 0; --bit) {
211 int l0 = _v[bit].rank0(l);
212 int r0 = _v[bit].rank0(r);
213 if ((x >> bit) & 1) {
214 ans += r0 - l0;
215 l += _mid[bit] - l0;
216 r += _mid[bit] - r0;
217 } else {
218 l = l0;
219 r = r0;
220 }
221 }
222 return ans;
223 }
224
225 int range_freq(int l, int r, int x, int y) const {
226 return range_freq(l, r, y) - range_freq(l, r, x);
227 }
228
229 //! 区間[l, r)で、x未満のうち最大の要素を返す
230 T prev_value(int l, int r, T x) const {
231 int k = range_freq(l, r, x)-1;
232 if (k < 0) return -1;
233 return kth_smallest(l, r, k);
234 }
235
236 //! 区間[l, r)で、x以上のうち最小の要素を返す
237 T next_value(int l, int r, T x) const {
238 int k = range_freq(l, r, x);
239 if (k >= r-l) return -1;
240 return kth_smallest(l, r, k);
241 }
242
243 int range_count(int l, int r, T x) const {
244 return rank(r, x) - rank(l, x);
245 }
246
247 vector<T> tovector() const {
248 vector<T> a(_size);
249 for (int i = 0; i < _size; ++i) {
250 a[i] = access(i);
251 }
252 return a;
253 }
254
255 void print() const {
256 vector<T> a = tovector();
257 int n = (int)a.size();
258 cout << "[";
259 for (int i = 0; i < n-1; ++i) {
260 cout << a[i] << ", ";
261 }
262 if (n > 0) {
263 cout << a.back();
264 }
265 cout << "]";
266 cout << endl;
267 }
268
269 friend ostream& operator<<(ostream& os, const titan23::DynamicWaveletMatrix<T> &dwm) {
270 vector<T> a = dwm.tovector();
271 os << a;
272 return os;
273 }
274 };
275} // namespace titan23
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/dynamic_wavelet_matrix.cpp”