Binary trie multiset
ソースコード (Source code)
#include <cassert>
#include <iostream>
#include <vector>
using namespace std;

// BinaryTrieMultiset
namespace titan23 {

    /**
     * @brief Multiset of integers in [0, 2^bit) backed by a binary trie.
     *
     * Supports insertion/removal with multiplicity, order-statistic queries
     * (get, index, index_right), neighbour queries (lt, le, gt, ge), and
     * all_xor(x), which XORs every stored value with x by updating a single mask.
     *
     * Representation:
     *  - node 0 is a sentinel whose children and size stay 0, so size[left[v]]
     *    is safe to read even when v has no left child;
     *  - node 1 is the root and every leaf sits at depth `bit`;
     *  - size[v] counts the elements (with multiplicity) in v's subtree;
     *  - a logical value v is stored along the path of (v ^ _xor_val).
     * Nodes are not recycled when elements are removed; clear() releases them.
     *
     * @tparam T integer key type. lt/le/gt/ge return T(-1) when no answer exists.
     */
    template<typename T>
    class BinaryTrieMultiset {
      private:
        std::vector<unsigned int> left, right, par, size;
        unsigned int _end, _root, _bit;
        T _lim, _xor_val;

        // Allocates a node with all fields zero and returns its index.
        unsigned int _make_node() {
            if (_end >= left.size()) {
                left.emplace_back(0);
                right.emplace_back(0);
                par.emplace_back(0);
                size.emplace_back(0);
            } else {
                // An already-existing slot must not carry stale links or counts.
                left[_end] = right[_end] = par[_end] = size[_end] = 0;
            }
            return _end++;
        }

        // Bit i of the global XOR mask.
        bool _xor_bit(int i) const {
            return (_xor_val >> i) & 1;
        }

        // Subtracts cnt from size[] of `node` and of every ancestor up to the root.
        void _decrement_path(unsigned int node, unsigned int cnt) {
            while (node) {
                size[node] -= cnt;
                node = par[node];
            }
        }

        // Returns the leaf holding logical value `key`, or -1 if it is absent
        // or outside [0, 2^bit).
        int _find(T key) const {
            if (!(0 <= key && key < _lim)) return -1;
            key ^= _xor_val;
            unsigned int node = _root;
            for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
                node = ((key >> i) & 1) ? right[node] : left[node];
                if (!node) return -1;
            }
            return static_cast<int>(node);
        }

        // Deletes every copy stored at leaf `node`, unlinking the chain of
        // nodes that becomes empty, then fixes the sizes of surviving ancestors.
        void _remove(unsigned int node) {
            const unsigned int cnt = size[node];
            for (unsigned int depth = 0; depth < _bit; ++depth) {
                size[node] -= cnt;
                const unsigned int p = par[node];
                const bool was_left = (left[p] == node);
                if (was_left) {
                    left[p] = 0;
                } else {
                    right[p] = 0;
                }
                node = p;
                // The parent still has another child, so it (and above) survives.
                if (was_left ? right[node] : left[node]) break;
            }
            _decrement_path(node, cnt);
        }

        // Walks to the leaf of the k-th smallest logical value (0-indexed,
        // multiplicity counted), stores that value in `value` and returns the leaf.
        // Requires 0 <= k < len().
        unsigned int _kth_leaf(int k, T &value) const {
            unsigned int node = _root;
            T res = 0;
            for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
                // Logical bit 0 lives in the stored child whose bit equals the XOR bit.
                const bool xb = _xor_bit(i);
                const unsigned int zero = xb ? right[node] : left[node];
                const unsigned int one = xb ? left[node] : right[node];
                const int t = static_cast<int>(size[zero]);
                res <<= 1;
                if (t <= k) {
                    k -= t;
                    res |= 1;
                    node = one;
                } else {
                    node = zero;
                }
            }
            value = res;  // res is built from logical bits: no extra XOR needed
            return node;
        }

        // Number of elements < key (inclusive == false) or <= key (inclusive == true).
        // Keys outside [0, 2^bit) are clamped instead of asserting.
        int _rank(T key, bool inclusive) const {
            if (key < 0) return 0;
            if (key >= _lim) return len();
            int k = 0;
            unsigned int node = _root;
            for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
                const bool kb = (key >> i) & 1;
                const bool xb = _xor_bit(i);
                // When key has logical bit 1 here, the whole logical-0 subtree is smaller.
                if (kb) k += static_cast<int>(size[xb ? right[node] : left[node]]);
                node = (kb != xb) ? right[node] : left[node];
                if (!node) return k;
            }
            // node is key's own leaf: count all of its copies when inclusive.
            if (inclusive) k += static_cast<int>(size[node]);
            return k;
        }

      public:
        // Creates an empty multiset for values in [0, 2^bit).
        BinaryTrieMultiset(const unsigned int bit)
            : left(2, 0), right(2, 0), par(2, 0), size(2, 0),
              _end(2), _root(1), _bit(bit), _lim(1ll << bit), _xor_val(0) {}

        // Pre-allocates room for n trie nodes (each add creates at most `bit` nodes).
        void reserve(const int n) {
            left.reserve(n);
            right.reserve(n);
            par.reserve(n);
            size.reserve(n);
        }

        // Inserts cnt copies of key. Requires 0 <= key < 2^bit and cnt >= 0.
        void add(T key, int cnt = 1) {
            assert(0 <= key && key < _lim);
            assert(cnt >= 0);
            // Adding nothing must not leave an empty path behind: contains()
            // and get_min() would otherwise see a value that is not present.
            if (cnt == 0) return;
            key ^= _xor_val;
            unsigned int node = _root;
            for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
                size[node] += cnt;
                std::vector<unsigned int> &child = ((key >> i) & 1) ? right : left;
                if (!child[node]) {
                    // Create first: _make_node may grow (reallocate) the vectors.
                    const unsigned int new_node = _make_node();
                    child[node] = new_node;
                    par[new_node] = node;
                }
                node = child[node];
            }
            size[node] += cnt;
        }

        // True if at least one copy of key is stored.
        bool contains(T key) const {
            return _find(key) != -1;
        }

        // Removes up to cnt copies of key (all of them if fewer exist).
        // Returns false if key was absent.
        bool discard(T key, int cnt = 1) {
            assert(0 <= key && key < _lim);
            const int found = _find(key);
            if (found == -1) return false;
            const unsigned int node = static_cast<unsigned int>(found);
            if (size[node] <= static_cast<unsigned int>(cnt)) {
                _remove(node);
            } else {
                _decrement_path(node, static_cast<unsigned int>(cnt));
            }
            return true;
        }

        // Removes every copy of key; returns false if key was absent.
        bool discard_all(T key) {
            return discard(key, count(key));
        }

        // Removes exactly cnt copies of key; asserts that key occurs at least cnt times.
        void remove(T key, int cnt = 1) {
            assert(0 <= key && key < _lim);
            const int found = _find(key);
            assert(found != -1);
            if (found == -1) return;
            const unsigned int node = static_cast<unsigned int>(found);
            const unsigned int c = size[node];
            const unsigned int want = static_cast<unsigned int>(cnt);
            if (c < want) {
                assert(false);
            } else if (c == want) {
                _remove(node);
            } else {
                _decrement_path(node, want);
            }
        }

        // Number of copies of key (0 if absent).
        int count(T key) const {
            const int node = _find(key);
            return node == -1 ? 0 : static_cast<int>(size[node]);
        }

        // Removes and returns the k-th smallest value; negative k counts from the end.
        T pop(int k = -1) {
            if (k < 0) k += len();
            assert(0 <= k && k < len());
            T value = 0;
            const unsigned int node = _kth_leaf(k, value);
            if (size[node] == 1) {
                _remove(node);
            } else {
                _decrement_path(node, 1);
            }
            return value;
        }

        T pop_min() {
            return pop(0);
        }

        T pop_max() {
            return pop(-1);
        }

        // XORs every stored value with x.
        void all_xor(T x) {
            _xor_val ^= x;
        }

        // Smallest value; requires a non-empty multiset.
        T get_min() const {
            assert(len() > 0);
            return get(0);
        }

        // Largest value; requires a non-empty multiset.
        T get_max() const {
            assert(len() > 0);
            return get(len() - 1);
        }

        // Number of elements strictly less than key.
        int index(T key) const {
            return _rank(key, false);
        }

        // Number of elements less than or equal to key.
        int index_right(T key) const {
            return _rank(key, true);
        }

        // k-th smallest value (0-indexed); negative k counts from the end.
        T get(int k) const {
            if (k < 0) k += len();
            assert(0 <= k && k < len());
            T value = 0;
            _kth_leaf(k, value);
            return value;
        }

        // Smallest value > key, or -1.
        T gt(T key) const {
            const int i = index_right(key);
            return i >= len() ? static_cast<T>(-1) : get(i);
        }

        // Largest value < key, or -1.
        T lt(T key) const {
            const int i = index(key) - 1;
            return i < 0 ? static_cast<T>(-1) : get(i);
        }

        // Smallest value >= key, or -1.
        T ge(T key) const {
            const int i = index(key);
            return i >= len() ? static_cast<T>(-1) : get(i);
        }

        // Largest value <= key, or -1.
        T le(T key) const {
            const int i = index_right(key) - 1;
            return i < 0 ? static_cast<T>(-1) : get(i);
        }

        // All values in ascending order, duplicates repeated.
        std::vector<T> tovector() const {
            std::vector<T> a;
            const int n = len();
            a.reserve(n);
            for (int i = 0; i < n; ++i) {
                a.emplace_back(get(i));
            }
            return a;
        }

        bool empty() const {
            return size[_root] == 0;
        }

        // Removes every element and releases all nodes except the sentinel and root.
        void clear() {
            left.assign(2, 0);
            right.assign(2, 0);
            par.assign(2, 0);
            size.assign(2, 0);
            _end = 2;
            _root = 1;
            _xor_val = 0;
        }

        // Prints the contents as "{a, b, c}" followed by a newline.
        void print() const {
            const std::vector<T> a = tovector();
            std::cout << "{";
            for (std::size_t i = 0; i < a.size(); ++i) {
                if (i) std::cout << ", ";
                std::cout << a[i];
            }
            std::cout << "}" << std::endl;
        }

        // Total number of elements, counting multiplicity.
        int len() const {
            return static_cast<int>(size[_root]);
        }
    };
} // namespace titan23
仕様 (Specification)
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/binary_trie_multiset.cpp”