hash dict

ソースコード

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

using namespace std;

  8// HashDict
  9namespace titan23 {
 10
 11    template<typename V>
 12    class HashDict {
 13      private:
 14        using u64 = unsigned long long;
 15        static constexpr const u64 K = 0x517cc1b727220a95;
 16        static constexpr const int M = 2;
 17        vector<u64> exist;
 18        vector<u64> keys;
 19        vector<V> vals;
 20        int msk, xor_;
 21        int size;
 22
 23        int hash(const u64 &key) const {
 24            return (((((key>>32)&msk) ^ (key&msk) ^ xor_)) * (HashDict::K & msk)) & msk;
 25        }
 26
 27        int bit_length(const int x) const {
 28            if (x == 0) return 0;
 29            return 32 - __builtin_clz(x);
 30        }
 31
 32        void rebuild() {
 33            vector<u64> old_exist = exist;
 34            vector<u64> old_keys = keys;
 35            vector<V> old_vals = vals;
 36            exist.resize(HashDict::M*old_exist.size()+1);
 37            fill(exist.begin(), exist.end(), 0);
 38            keys.resize(HashDict::M*old_keys.size());
 39            vals.resize(HashDict::M*old_vals.size());
 40            size = 0;
 41            msk = (1<<bit_length(keys.size()-1))-1;
 42            random_device rd;
 43            mt19937 gen(rd());
 44            uniform_int_distribution<int> dis(0, msk);
 45            xor_ = dis(gen);
 46            for (int i = 0; i < (int)old_keys.size(); ++i) {
 47                if (old_exist[i>>6]>>(i&63)&1) {
 48                    set(old_keys[i], old_vals[i]);
 49                }
 50            }
 51        }
 52
 53      public:
 54        HashDict() : exist(1, 0), keys(1), vals(1), msk(0), xor_(0), size(0) {}
 55
 56        HashDict(const int n) {
 57            int s = 1<<bit_length(n);
 58            s *= HashDict::M;
 59            assert(s > 0);
 60            exist.resize((s>>6)+1, 0);
 61            keys.resize(s);
 62            vals.resize(s);
 63            msk = (1<<bit_length(keys.size()-1))-1;
 64            random_device rd;
 65            mt19937 gen(rd());
 66            uniform_int_distribution<int> dis(0, msk);
 67            xor_ = dis(gen);
 68            size = 0;
 69        }
 70
 71        pair<int, bool> get_pos(const u64 &key) const {
 72            int h = hash(key);
 73            while (true) {
 74                if (!(exist[h>>6]>>(h&63)&1)) return {h, false};
 75                if (keys[h] == key) return {h, true};
 76                h = (h + 1) & msk;
 77            }
 78        }
 79
 80        V get(const u64 key) const {
 81            const auto [pos, exist_res] = get_pos(key);
 82            if (!exist_res) return V();
 83            else return vals[pos];
 84        }
 85
 86        V get(const u64 key, const V missing) const {
 87            const auto [pos, exist_res] = get_pos(key);
 88            if (!exist_res) return missing;
 89            else return vals[pos];
 90        }
 91
 92        bool contains(const u64 key) const {
 93            return get_pos(key).second;
 94        }
 95
 96        pair<int, bool> pos(const u64 key) const {
 97            return get_pos(key);
 98        }
 99
100        V operator[] (const u64 key) const {
101            return get(key);
102        }
103
104        V inner_get(const pair<int, bool> &dat) {
105            const auto [pos, is_exist] = dat;
106            if (!is_exist) return V();
107            return vals[pos];
108        }
109
110        void inner_set(const pair<int, bool> &dat, const u64 key, const V val) {
111            const auto [pos, is_exist] = dat;
112            vals[pos] = val;
113            if (!is_exist) {
114                exist[pos>>6] |= 1ull<<(pos&63);
115                keys[pos] = key;
116                ++size;
117                if (HashDict::M*size > keys.size()) {
118                    rebuild();
119                }
120            }
121        }
122
123        void set(const u64 key, const V val) {
124            const auto [pos, is_exist] = get_pos(key);
125            vals[pos] = val;
126            if (!is_exist) {
127                exist[pos>>6] |= 1ull<<(pos&63);
128                keys[pos] = key;
129                ++size;
130                if (HashDict::M*size > keys.size()) {
131                    rebuild();
132                }
133            }
134        }
135
136        //! keyがすでにあればtrue, なければ挿入してfalse / `O(1)`
137        bool contains_set(const u64 key, const V val) {
138            const auto [pos, is_exist] = get_pos(key);
139            if (val < vals[pos]) {
140                vals[pos] = val;
141            } else {
142                return false;
143            }
144            if (!is_exist) {
145                exist[pos>>6] |= 1ull<<(pos&63);
146                keys[pos] = key;
147                ++size;
148                if (HashDict::M*size > keys.size()) {
149                    rebuild();
150                }
151                return false;
152            }
153            return true;
154        }
155
156        //! keyがすでにあればtrue, なければ挿入してfalse / `O(1)`
157        bool contains_insert(const u64 key) {
158            const auto [pos, is_exist] = get_pos(key);
159            if (!is_exist) {
160                exist[pos>>6] |= 1ull<<(pos&63);
161                keys[pos] = key;
162                ++size;
163                if (HashDict::M*size > keys.size()) {
164                    rebuild();
165                }
166                return false;
167            }
168            return true;
169        }
170
171        vector<V> values() const {
172            vector<V> res;
173            res.reserve(len());
174            for (int i = 0; i < (int)keys.size(); ++i) {
175                if (exist[i>>6]>>(i&63)&1) {
176                    res.emplace_back(vals[i]);
177                }
178            }
179            return res;
180        }
181
182        vector<pair<u64, V>> items() const {
183            vector<pair<u64, V>> res;
184            res.reserve(len());
185            for (int i = 0; i < (int)keys.size(); ++i) {
186                if (exist[i>>6]>>(i&63)&1) {
187                    res.emplace_back(keys[i], vals[i]);
188                }
189            }
190            return res;
191        }
192
193        //! 全ての要素を削除する / `O(n/w)`
194        void clear() {
195            this->size = 0;
196            fill(exist.begin(), exist.end(), 0);
197        }
198
199        int len() const {
200            return size;
201        }
202    };
203} // namespaced titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/hash_dict.cpp”