wavelet matrix¶
ソースコード¶
1#include <vector>
2#include <queue>
3#include "titan_cpplib/data_structures/bit_vector.cpp"
4using namespace std;
5
6// WaveletMatrix
7namespace titan23 {
8
9 template<typename T>
10 class WaveletMatrix {
11
12 private:
13 T sigma;
14 int log;
15 vector<BitVector> v;
16 vector<int> mid;
17 int n;
18
19 int bit_length(const int n) const {
20 return n == 0 ? 0 : 32 - __builtin_clz(n);
21 }
22
23 void build(vector<T> a) {
24 for (int bit = log-1; bit >= 0; --bit) {
25 vector<T> zero, one;
26 v[bit] = BitVector(n);
27 for (int i = 0; i < n; ++i) {
28 if ((a[i] >> bit) & 1) {
29 v[bit].set(i);
30 one.emplace_back(a[i]);
31 } else {
32 zero.emplace_back(a[i]);
33 }
34 }
35 v[bit].build();
36 mid[bit] = zero.size();
37 a = zero;
38 a.insert(a.end(), one.begin(), one.end());
39 assert(a.size() == n);
40 }
41 }
42
43 public:
44 WaveletMatrix() {}
45
46 WaveletMatrix(const T sigma)
47 : sigma(sigma), log(bit_length(sigma-1)), v(log), mid(log), n(0) {}
48
49 WaveletMatrix(const T sigma, const vector<T> &a)
50 : sigma(sigma), log(bit_length(sigma-1)), v(log), mid(log), n(a.size()) {
51 build(a);
52 }
53
54 T access(int k) const {
55 T s = 0;
56 for (int bit = log-1; bit >= 0; --bit) {
57 if (v[bit].access(k)) {
58 s |= (T)1 << bit;
59 k = v[bit].rank1(k) + mid[bit];
60 } else {
61 k = v[bit].rank0(k);
62 }
63 }
64 return s;
65 }
66
67 //! `a[0, r)` に含まれる `x` の個数を返します。
68 int rank(int r, int x) const {
69 int l = 0;
70 for (int bit = log-1; bit >= 0; --bit) {
71 if ((x >> bit) & 1) {
72 l = v[bit].rank1(l) + mid[bit];
73 r = v[bit].rank1(r) + mid[bit];
74 } else {
75 l = v[bit].rank0(l);
76 r = v[bit].rank0(r);
77 }
78 }
79 return r - l;
80 }
81
82 // `k` 番目の `v` のインデックスを返す。
83 int select(int k, int x) const {
84 int s = 0;
85 for (int bit = log-1; bit >= 0; --bit) {
86 if ((x >> bit) & 1) {
87 s = v[bit].rank0(n) + v[bit].rank1(s);
88 } else {
89 s = v[bit].rank0(s);
90 }
91 }
92 s += k;
93 for (int bit = 0; bit < log; ++bit) {
94 if ((x >> bit) & 1) {
95 s = v[bit].select1(s - v[bit].rank0(n));
96 } else {
97 s = v[bit].select0(s);
98 }
99 }
100 return s;
101 }
102
103 // `a[l, r)` の中で k 番目に **小さい** 値を返します。
104 T kth_smallest(int l, int r, int k) const {
105 T s = 0;
106 for (int bit = log-1; bit >= 0; --bit) {
107 const int r0 = v[bit].rank0(r), l0 = v[bit].rank0(l);
108 const int cnt = r0 - l0;
109 if (cnt <= k) {
110 s |= (T)1 << bit;
111 k -= cnt;
112 l = l - l0 + mid[bit];
113 r = r - r0 + mid[bit];
114 } else {
115 l = l0;
116 r = r0;
117 }
118 }
119 return s;
120 }
121
122 T kth_largest(int l, int r, int k) const {
123 return kth_smallest(l, r, r-l-k-1);
124 }
125
126 // `a[l, r)` の中で、要素を出現回数が多い順にその頻度とともに `k` 個返します。
127 vector<pair<int, int>> topk(int l, int r, int k) {
128 // heap[-length, x, l, bit]
129 priority_queue<tuple<int, T, int, int>> hq;
130 hq.emplace(r-l, 0, l, log-1);
131 vector<pair<T, int>> ans;
132 while (!hq.empty()) {
133 auto [length, x, l, bit] = hq.top();
134 hq.pop();
135 if (bit == -1) {
136 ans.emplace_back(x, length);
137 k -= 1;
138 if (k == 0) break;
139 } else {
140 r = l + length;
141 int l0 = v[bit].rank0(l);
142 int r0 = v[bit].rank0(r);
143 if (l0 < r0) hq.emplace(r0-l0, x, l0, bit-1);
144 int l1 = v[bit].rank1(l) + mid[bit];
145 int r1 = v[bit].rank1(r) + mid[bit];
146 if (l1 < r1) hq.emplace(r1-l1, x|((T)1<<bit), l1, bit-1);
147 }
148 }
149 return ans;
150 }
151
152 T sum(int l, int r) const {
153 assert(false);
154 T s = 0;
155 for (const auto &[sum, cnt]: topk(l, r, r-l)) {
156 s += sum * cnt;
157 }
158 return s;
159 }
160
161 // a[l, r) で x 未満の要素の数を返す'''
162 int range_freq(int l, int r, int x) const {
163 int ans = 0;
164 for (int bit = log-1; bit >= 0; --bit) {
165 int l0 = v[bit].rank0(l), r0 = v[bit].rank0(r);
166 if ((x >> bit) & 1) {
167 ans += r0 - l0;
168 l += mid[bit] - l0;
169 r += mid[bit] - r0;
170 } else {
171 l = l0;
172 r = r0;
173 }
174 }
175 return ans;
176 }
177
178 //`a[l, r)` に含まれる、 `x` 以上 `y` 未満である要素の個数を返します。
179 int range_freq(int l, int r, int x, int y) const {
180 return range_freq(l, r, y) - range_freq(l, r, x);
181 }
182
183 //`a[l, r)` で、`x` 以上 `y` 未満であるような要素のうち最大の要素を返します。
184 T prev_value(int l, int r, int x) const {
185 return kth_smallest(l, r, range_freq(l, r, x)-1);
186 }
187
188 T next_value(int l, int r, int x) const {
189 return kth_smallest(l, r, range_freq(l, r, x));
190 }
191
192 //`a[l, r)` に含まれる `x` の個数を返します。
193 int range_count(int l, int r, int x) const {
194 return rank(r, x) - rank(l, x);
195 }
196
197 int len() const {
198 return n;
199 }
200
201 friend ostream& operator<<(ostream& os, const titan23::WaveletMatrix<T>& wm) {
202 int n = wm.len();
203 os << "[";
204 for (int i = 0; i < n; ++i) {
205 os << wm.access(i);
206 if (i != n-1) os << ", ";
207 }
208 os << "]";
209 return os;
210 }
211 };
212} // namespace titan23
仕様¶
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/wavelet_matrix.cpp