binary trie set

ソースコード

  #include <iostream>
  #include <vector>
  // NOTE(review): a file-scope `using namespace std;` leaks into every file that
  // includes this one; it is kept only so existing callers that rely on it still compile.
  using namespace std;

  namespace titan23 {

    /**
     * @brief Set of non-negative integers stored in a binary trie of fixed depth.
     *
     * Only the low `bit` bits of a key are examined. Supports insertion, deletion,
     * membership, rank / k-th element queries, predecessor / successor queries and an
     * O(1) "xor every element by x" update (all_xor). Most operations follow a single
     * root-to-leaf path, i.e. O(bit).
     *
     * An element v is stored as v ^ _xor_val, so all_xor() only updates _xor_val. Any
     * traversal that depends on order must therefore pick, at level i, the stored child
     * whose bit equals bit i of _xor_val in order to reach the smaller actual values.
     *
     * Node 0 is a null sentinel (no children, size 0) and node 1 is the root. Removed
     * nodes are unlinked from the trie, but their slots are never reused.
     *
     * Queries that have no answer return T(-1) (which wraps around for unsigned T).
     */
    template<typename T>
    struct BinaryTrieSet {
      vector<unsigned int> left, right, par, size;  // child links, parent link, #elements in subtree
      unsigned int _end, _root, _bit;               // next unused node index, root index, key width
      T _lim, _xor_val;                             // 2^bit, and the lazily applied global xor

      // Value returned by _find() when the key is absent.
      static constexpr unsigned int _npos = static_cast<unsigned int>(-1);

      /// @param bit Key width in bits; only bits [0, bit) of each key are used.
      BinaryTrieSet(unsigned int bit)
          : left(2), right(2), par(2), size(2),
            _end(2), _root(1), _bit(bit),
            _lim(static_cast<T>(1ULL << bit)), _xor_val(0) {}

      /// Pre-allocates capacity for `n` nodes in every node array.
      void reserve(const int n) {
        left.reserve(n);
        right.reserve(n);
        par.reserve(n);
        size.reserve(n);
      }

      /// Allocates a node with no children, no parent and size 0; returns its index.
      unsigned int _make_node() {
        if (_end >= left.size()) {
          left.emplace_back(0);
          right.emplace_back(0);
          par.emplace_back(0);
          size.emplace_back(0);
        } else {
          // The slot already exists; clear it so stale data cannot leak into the new node.
          left[_end] = right[_end] = par[_end] = size[_end] = 0;
        }
        return _end++;
      }

      /// Bit `i` of `v`, as 0 or 1.
      static unsigned int _bit_at(T v, int i) {
        return static_cast<unsigned int>((v >> i) & 1);
      }

      /// Child of `node` along stored bit `b` (0 -> left, 1 -> right); 0 if absent.
      unsigned int _child(unsigned int node, unsigned int b) const {
        return b ? right[node] : left[node];
      }

      /// Turns a possibly negative index (counting from the end) into [0, len()).
      /// @return false if `k` is out of range.
      bool _normalize_index(int &k) const {
        const int n = len();
        if (k < 0) k += n;
        return 0 <= k && k < n;
      }

      /// Returns the leaf holding the (actual) value `key`, or _npos if it is not in the set.
      unsigned int _find(T key) const {
        key ^= _xor_val;
        unsigned int node = _root;
        for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
          node = _child(node, _bit_at(key, i));
          if (!node) return _npos;
        }
        return size[node] ? node : _npos;
      }

      /// Inserts `key`.
      /// @return true if it was inserted, false if it was already present.
      bool add(T key) {
        key ^= _xor_val;
        unsigned int node = _root;
        for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
          const unsigned int b = _bit_at(key, i);
          unsigned int next = _child(node, b);
          if (!next) {
            // _make_node() may reallocate the vectors, so write through indices, not references.
            next = _make_node();
            (b ? right : left)[node] = next;
            par[next] = node;
          }
          node = next;
        }
        if (size[node]) return false;
        // Count the new element in the leaf and every ancestor; par[_root] == 0 ends the walk.
        for (; node; node = par[node]) size[node] += 1;
        return true;
      }

      /// @return true if `key` is in the set.
      bool contains(const T key) const {
        return _find(key) != _npos;
      }

      /// Removes the element stored at leaf `node`, which must be a present leaf.
      /// First walks upward, unlinking every node whose subtree becomes empty.
      /// Then it decrements the sizes of the remaining ancestors.
      void _discard(unsigned int node) {
        for (unsigned int i = 0; i < _bit; ++i) {
          size[node] -= 1;
          const unsigned int p = par[node];
          const bool was_left = (left[p] == node);
          (was_left ? left : right)[p] = 0;
          node = p;
          // `p` still has its other child, so it stays in the trie: stop unlinking.
          if (was_left ? right[p] : left[p]) break;
        }
        for (; node; node = par[node]) size[node] -= 1;
      }

      /// Removes `key`.
      /// @return true if it was present, false otherwise.
      bool discard(T key) {
        const unsigned int node = _find(key);
        if (node == _npos) return false;
        _discard(node);
        return true;
      }

      /// Descends to the k-th smallest element (0-indexed; requires 0 <= k < len()).
      /// Writes the element's actual value to `value` and returns its leaf node.
      unsigned int _kth_node(int k, T &value) const {
        unsigned int node = _root;
        T res = 0;
        for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
          // Under the global xor, the stored child with bit x holds the actual-0 (smaller) half.
          const unsigned int x = _bit_at(_xor_val, i);
          const unsigned int lo = _child(node, x);
          const int t = static_cast<int>(size[lo]);
          res <<= 1;
          if (t <= k) {
            k -= t;
            res |= 1;
            node = _child(node, x ^ 1);
          } else {
            node = lo;
          }
        }
        value = res;  // built from actual bits, so no further xor is needed
        return node;
      }

      /// Removes and returns the k-th smallest element. A negative k counts from the end.
      /// Returns T(-1) and leaves the set unchanged if k is out of range (e.g. the set is empty).
      T pop(int k) {
        if (!_normalize_index(k)) return -1;
        T value{};
        const unsigned int node = _kth_node(k, value);
        _discard(node);
        return value;
      }

      /// Removes and returns the smallest element (T(-1) if empty).
      T pop_min() {
        return pop(0);
      }

      /// Removes and returns the largest element (T(-1) if empty).
      T pop_max() {
        return pop(-1);
      }

      /// Replaces every element v by v ^ x in O(1).
      void all_xor(T x) {
        _xor_val ^= x;
      }

      /// Smallest element, or T(-1) if the set is empty.
      T get_min() const {
        if (empty()) return -1;
        unsigned int node = _root;
        T res = 0;
        for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
          const unsigned int x = _bit_at(_xor_val, i);
          res <<= 1;
          if (_child(node, x)) {
            node = _child(node, x);      // actual bit 0 available
          } else {
            node = _child(node, x ^ 1);  // forced to actual bit 1
            res |= 1;
          }
        }
        return res;
      }

      /// Largest element, or T(-1) if the set is empty.
      T get_max() const {
        if (empty()) return -1;
        unsigned int node = _root;
        T res = 0;
        for (int i = static_cast<int>(_bit) - 1; i >= 0; --i) {
          const unsigned int x = _bit_at(_xor_val, i);
          res <<= 1;
          if (_child(node, x ^ 1)) {
            node = _child(node, x ^ 1);  // actual bit 1 available
            res |= 1;
          } else {
            node = _child(node, x);      // forced to actual bit 0
          }
        }
        return res;
      }

      /// Counts the elements strictly less than `key`, honoring the global xor.
      /// On return, `node` is key's leaf if the path to it exists, otherwise 0.
      int _count_less(T key, unsigned int &node) const {
        int k = 0;
        node = _root;
        for (int i = static_cast<int>(_bit) - 1; i >= 0 && node; --i) {
          const unsigned int x = _bit_at(_xor_val, i);  // stored bit of the actual-0 child
          const unsigned int kb = _bit_at(key, i);       // actual bit of key
          // If key has actual bit 1 here, the whole actual-0 half is smaller than key.
          if (kb) k += static_cast<int>(size[_child(node, x)]);
          node = _child(node, kb ^ x);
        }
        return k;
      }

      /// Number of elements strictly less than `key`.
      int index(T key) const {
        unsigned int node = 0;
        return _count_less(key, node);
      }

      /// Number of elements less than or equal to `key`.
      int index_right(T key) const {
        unsigned int node = 0;
        const int k = _count_less(key, node);
        // size[0] == 0, and a reached leaf has size 1 iff key is present.
        return k + static_cast<int>(size[node]);
      }

      /// k-th smallest element (0-indexed; a negative k counts from the end).
      /// Returns T(-1) if k is out of range.
      T kth_elm(int k) const {
        if (!_normalize_index(k)) return -1;
        T value{};
        _kth_node(k, value);
        return value;
      }

      /// Smallest element strictly greater than `key`, or T(-1) if none.
      T gt(T key) const {
        const int i = index_right(key);
        return i < len() ? kth_elm(i) : T(-1);
      }

      /// Largest element strictly less than `key`, or T(-1) if none.
      T lt(T key) const {
        const int i = index(key) - 1;
        return i >= 0 ? kth_elm(i) : T(-1);
      }

      /// Smallest element greater than or equal to `key`, or T(-1) if none.
      T ge(T key) const {
        const int i = index(key);
        return i < len() ? kth_elm(i) : T(-1);
      }

      /// Largest element less than or equal to `key`, or T(-1) if none.
      /// Uses index_right(key) rather than index(key + 1), so key == 2^bit - 1 does not overflow.
      T le(T key) const {
        const int i = index_right(key) - 1;
        return i >= 0 ? kth_elm(i) : T(-1);
      }

      /// All elements in ascending order.
      vector<T> tovector() const {
        const int n = len();
        vector<T> a;
        a.reserve(n);
        for (int i = 0; i < n; ++i) a.emplace_back(kth_elm(i));
        return a;
      }

      /// @return true if the set has no elements.
      bool empty() const {
        return size[_root] == 0;
      }

      /// Writes the elements to stdout as "{a, b, c}" followed by a newline (flushed).
      void print() const {
        const vector<T> a = tovector();
        cout << "{";
        for (std::size_t i = 0; i < a.size(); ++i) {
          if (i) cout << ", ";
          cout << a[i];
        }
        cout << "}" << endl;
      }

      /// Number of elements in the set.
      int len() const {
        return static_cast<int>(size[_root]);
      }
    };
  }  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/binary_trie_set.cpp”