binary trie multiset

ソースコード

#include <cassert>
#include <iostream>
#include <vector>
using namespace std;
  4
  // BinaryTrieMultiset: multiset of integers in [0, 2^bit) stored in a binary trie,
// with a lazy global XOR (all_xor), order statistics (get / index) and
// neighbour queries (lt / le / gt / ge).
namespace titan23 {

  /**
   * @brief Multiset of integers in [0, 2^bit) backed by a binary trie.
   *
   * Nodes live in parallel vectors (left, right, par, size). Index 0 is a null
   * sentinel whose size stays 0, so size[child] may be read even when the child
   * does not exist. Index 1 is the root. A leaf (depth `bit`) stores the
   * multiplicity of one key; an inner node stores the total multiplicity of
   * its subtree.
   *
   * all_xor(x) is lazy: keys are stored as (logical key) ^ _xor_val and every
   * query translates between logical and stored bits while descending.
   *
   * @tparam T integer key type; must be able to represent 2^bit.
   */
  template<typename T>
  class BinaryTrieMultiset {
   private:
    std::vector<unsigned int> left, right, par, size;
    unsigned int _end, _root, _bit;
    T _lim, _xor_val;

    // Allocates a zeroed node and returns its index.
    unsigned int _make_node() {
      if (_end >= left.size()) {
        left.emplace_back(0);
        right.emplace_back(0);
        par.emplace_back(0);
        size.emplace_back(0);
      } else {
        // Defensive: a slot below left.size() could hold stale links.
        left[_end] = right[_end] = par[_end] = size[_end] = 0;
      }
      return _end++;
    }

    // Stored child of `node` on side `bit` (0 = left, 1 = right); 0 if absent.
    unsigned int _child(unsigned int node, unsigned int bit) const {
      return bit ? right[node] : left[node];
    }

    // Leaf index holding logical key `key`, or -1 if absent / out of [0, _lim).
    int _find(T key) const {
      if (key < 0 || key >= _lim) return -1;  // avoid aliasing on the low bits
      key ^= _xor_val;
      unsigned int node = _root;
      for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
        node = _child(node, static_cast<unsigned int>((key >> i) & 1));
        if (!node) return -1;
      }
      return static_cast<int>(node);
    }

    // Subtracts cnt from `node` and every ancestor up to (and including) the root.
    void _sub_path(unsigned int node, unsigned int cnt) {
      while (node) {
        size[node] -= cnt;
        node = par[node];
      }
    }

    // Deletes leaf `node` completely: unlinks every ancestor that becomes empty,
    // then subtracts the leaf's multiplicity along the surviving path.
    void _remove(unsigned int node) {
      const unsigned int cnt = size[node];
      for (unsigned int i = 0; i < _bit; ++i) {
        size[node] -= cnt;
        const unsigned int p = par[node];
        const bool was_left = (left[p] == node);
        if (was_left) {
          left[p] = 0;
        } else {
          right[p] = 0;
        }
        node = p;
        // The parent still has another child, so it (and everything above) survives.
        if (was_left ? right[p] : left[p]) break;
      }
      _sub_path(node, cnt);
    }

    // Descends to the leaf holding the k-th smallest element (0-indexed, in
    // logical order). Writes that element's logical value to `value` and
    // returns the leaf index. Requires 0 <= k < len().
    unsigned int _kth_node(int k, T &value) const {
      unsigned int node = _root;
      T res = 0;
      for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
        const unsigned int xb = static_cast<unsigned int>((_xor_val >> i) & 1);
        // The stored side equal to the xor bit holds the logically smaller half.
        const unsigned int lo = _child(node, xb);
        const int t = static_cast<int>(size[lo]);
        res <<= 1;
        if (t <= k) {
          k -= t;
          res |= 1;
          node = _child(node, xb ^ 1u);
        } else {
          node = lo;
        }
      }
      value = res;
      return node;
    }

    // Number of elements < key (inclusive == false) or <= key (inclusive == true),
    // in logical order. Keys below 0 count nothing; keys >= _lim count everything.
    int _count_below(T key, bool inclusive) const {
      if (key < 0) return 0;
      if (key >= _lim) return len();
      int k = 0;
      unsigned int node = _root;
      for (int i = static_cast<int>(_bit) - 1; i >= 0 && node; --i) {
        const unsigned int kb = static_cast<unsigned int>((key >> i) & 1);
        const unsigned int xb = static_cast<unsigned int>((_xor_val >> i) & 1);
        // Logical bit b lives on stored side b ^ xb; when the key's bit is 1,
        // the whole logical-0 sibling subtree is smaller than the key.
        if (kb) k += static_cast<int>(size[_child(node, xb)]);
        node = _child(node, kb ^ xb);
      }
      if (inclusive && node) k += static_cast<int>(size[node]);  // full multiplicity
      return k;
    }

   public:
    /// @brief Creates an empty multiset for keys in [0, 2^bit).
    BinaryTrieMultiset(const unsigned int bit)
        : left(2, 0), right(2, 0), par(2, 0), size(2, 0),
          _end(2), _root(1), _bit(bit), _lim((1ll) << bit), _xor_val(0) {}

    /// @brief Reserves capacity for n trie nodes.
    void reserve(const int n) {
      left.reserve(n);
      right.reserve(n);
      par.reserve(n);
      size.reserve(n);
    }

    /// @brief Inserts `key` cnt times. Requires 0 <= key < 2^bit.
    void add(T key, int cnt = 1) {
      assert(0 <= key && key < _lim);
      key ^= _xor_val;
      unsigned int node = _root;
      for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
        size[node] += cnt;
        const unsigned int b = static_cast<unsigned int>((key >> i) & 1);
        unsigned int nxt = _child(node, b);
        if (!nxt) {
          // Create first: _make_node may reallocate the vectors.
          nxt = _make_node();
          (b ? right : left)[node] = nxt;
          par[nxt] = node;
        }
        node = nxt;
      }
      size[node] += cnt;
    }

    /// @brief Returns whether `key` is present (false for out-of-range keys).
    bool contains(T key) const {
      return _find(key) != -1;
    }

    /// @brief Removes up to cnt copies of `key`.
    /// @return false if `key` was not present, true otherwise.
    bool discard(T key, int cnt = 1) {
      assert(0 <= key && key < _lim);
      const int found = _find(key);
      if (found < 0) return false;
      const unsigned int node = static_cast<unsigned int>(found);
      if (size[node] <= static_cast<unsigned int>(cnt)) {
        _remove(node);
      } else {
        _sub_path(node, cnt);
      }
      return true;
    }

    /// @brief Removes every copy of `key`; returns whether it was present.
    bool discard_all(T key) {
      return discard(key, count(key));
    }

    /// @brief Removes exactly cnt copies of `key`; asserts that they exist.
    void remove(T key, int cnt = 1) {
      assert(0 <= key && key < _lim);
      const int found = _find(key);
      assert(found >= 0);
      if (found < 0) return;
      const unsigned int node = static_cast<unsigned int>(found);
      const unsigned int c = size[node];
      const unsigned int want = static_cast<unsigned int>(cnt);
      assert(c >= want);
      if (c == want) {
        _remove(node);
      } else if (c > want) {
        _sub_path(node, cnt);
      }
    }

    /// @brief Multiplicity of `key` (0 if absent or out of range).
    int count(T key) const {
      const int found = _find(key);
      return found < 0 ? 0 : static_cast<int>(size[found]);
    }

    /// @brief Removes and returns the k-th smallest element (negative k counts from the end).
    T pop(int k = -1) {
      if (k < 0) k += len();
      assert(0 <= k && k < len());
      T res = 0;
      const unsigned int node = _kth_node(k, res);
      if (size[node] == 1) {
        _remove(node);
      } else {
        _sub_path(node, 1);
      }
      return res;  // already the logical value; must not be XORed again
    }

    /// @brief Removes and returns the minimum.
    T pop_min() {
      return pop(0);
    }

    /// @brief Removes and returns the maximum.
    T pop_max() {
      return pop(-1);
    }

    /// @brief Replaces every element e by e ^ x (lazily, O(1)).
    void all_xor(T x) {
      _xor_val ^= x;
    }

    /// @brief Smallest element; requires a non-empty set.
    T get_min() const {
      assert(len() > 0);
      return get(0);
    }

    /// @brief Largest element; requires a non-empty set.
    T get_max() const {
      assert(len() > 0);
      return get(len() - 1);
    }

    /// @brief Number of elements strictly less than `key`. Requires 0 <= key < 2^bit.
    int index(T key) const {
      assert(0 <= key && key < _lim);
      return _count_below(key, false);
    }

    /// @brief Number of elements less than or equal to `key` (counting multiplicity).
    int index_right(T key) const {
      return _count_below(key, true);
    }

    /// @brief k-th smallest element, 0-indexed (negative k counts from the end).
    T get(int k) const {
      if (k < 0) k += len();
      assert(0 <= k && k < len());
      T res = 0;
      _kth_node(k, res);
      return res;
    }

    /// @brief Smallest element strictly greater than `key`, or -1 if none.
    T gt(T key) const {
      const int i = _count_below(key, true);
      return i >= len() ? T(-1) : get(i);
    }

    /// @brief Largest element strictly less than `key`, or -1 if none.
    T lt(T key) const {
      const int i = _count_below(key, false) - 1;
      return i < 0 ? T(-1) : get(i);
    }

    /// @brief Smallest element greater than or equal to `key`, or -1 if none.
    T ge(T key) const {
      const int i = _count_below(key, false);
      return i >= len() ? T(-1) : get(i);
    }

    /// @brief Largest element less than or equal to `key`, or -1 if none.
    T le(T key) const {
      const int i = _count_below(key, true) - 1;
      return i < 0 ? T(-1) : get(i);
    }

    /// @brief All elements in ascending order, duplicates included.
    std::vector<T> tovector() const {
      std::vector<T> a;
      const int n = len();
      a.reserve(n);
      for (int i = 0; i < n; ++i) {
        a.emplace_back(get(i));
      }
      return a;
    }

    /// @brief Whether the multiset has no elements.
    bool empty() const {
      return size[_root] == 0;
    }

    /// @brief Removes every element and resets the lazy XOR.
    void clear() {
      // assign() keeps capacity, so a reserve() done earlier is not lost.
      left.assign(2, 0);
      right.assign(2, 0);
      par.assign(2, 0);
      size.assign(2, 0);
      _end = 2;
      _root = 1;
      _xor_val = 0;  // unobservable for an empty set; restores the fresh state
    }

    /// @brief Prints the elements as "{a, b, c}" followed by a newline.
    void print() const {
      const std::vector<T> a = tovector();
      std::cout << "{";
      bool first = true;
      for (const T &e : a) {
        if (!first) std::cout << ", ";
        std::cout << e;
        first = false;
      }
      std::cout << "}" << std::endl;
    }

    /// @brief Total number of elements, counting multiplicity.
    int len() const {
      return static_cast<int>(size[_root]);
    }
  };
}  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/binary_trie_multiset.cpp”