multiset sum qd

ソースコード

#include <algorithm>
#include <cassert>
#include <cmath>
#include <ostream>
#include <utility>
#include <vector>
using namespace std;
  4
  5// MultisetSum
  6namespace titan23 {
  7
  8template<typename T>
  9class MultisetSum {
 10    // ref: https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
 11
 12  private:
 13    const int BUCKET_RATIO = 16;
 14    const int SPLIT_RATIO = 24;
 15    int n;
 16    T missing;
 17    vector<vector<T>> data;
 18    vector<T> bucket_data;
 19
 20    int bisect_left(const vector<T> &a, const T &key) const {
 21        return lower_bound(a.begin(), a.end(), key) - a.begin();
 22    }
 23
 24    int bisect_right(const vector<T> &a, const T &key) const {
 25        return upper_bound(a.begin(), a.end(), key) - a.begin();
 26    }
 27
 28    pair<int, int> get_pos(const T key) const {
 29        if (data.empty()) return {0, 0};
 30        if (key > data.back().back()) return {data.size()-1, data.back().size()};
 31        int idx = lower_bound(data.begin(), data.end(), key, [&] (const vector<T> &vec, const T &key) -> bool {
 32            return vec.back() < key;
 33        }) - data.begin();
 34        assert(idx < data.size() && key <= data[idx].back());
 35        return {idx, bisect_left(data[idx], key)};
 36    }
 37
 38    void rebuild_split(int i) {
 39        int m = data[i].size();
 40        data.insert(data.begin() + i+1, vector<T>(data[i].begin() + m/2, data[i].end()));
 41        data[i].erase(data[i].begin() + m/2, data[i].end());
 42        T right_sum = 0;
 43        for (const T x : data[i+1]) right_sum += x;
 44        bucket_data[i] -= right_sum;
 45        bucket_data.insert(bucket_data.begin() + i+1, right_sum);
 46    }
 47
 48  public:
 49    MultisetSum() : n(0) {}
 50
 51    MultisetSum(T missing) : n(0), missing(missing) {}
 52
 53    MultisetSum(vector<T> a, T missing) : n(a.size()), missing(missing) {
 54        for (int i = 0; i < n-1; ++i) {
 55            if (a[i] > a[i+1]) {
 56                sort(a.begin(), a.end());
 57                break;
 58            }
 59        }
 60        int bucket_cnt = sqrt(n / BUCKET_RATIO) + 1;
 61        int bucket_size = n / bucket_cnt + 1;
 62        data.resize(bucket_cnt);
 63        bucket_data.resize(bucket_cnt, 0);
 64        for (int k = 0; k < bucket_cnt; ++k) {
 65            int size = min(bucket_size, n-k*bucket_size);
 66            if (size <= 0) {
 67                for (int l = k; l < bucket_cnt; ++l) {
 68                    data.pop_back();
 69                    bucket_data.pop_back();
 70                }
 71                break;
 72            }
 73            data[k] = vector<T>(a.begin()+k*bucket_size, a.begin()+k*bucket_size+size);
 74            for (const T &x : data[k]) bucket_data[k] += x;
 75        }
 76    }
 77
 78    void add(const T &key) {
 79        if (n == 0) {
 80            data.push_back({key});
 81            bucket_data.push_back(key);
 82            n = 1;
 83            return;
 84        }
 85        auto [i, pos] = get_pos(key);
 86        data[i].insert(data[i].begin() + pos, key);
 87        bucket_data[i] += key;
 88        n++;
 89        if (data[i].size() > data.size() * SPLIT_RATIO) {
 90            rebuild_split(i);
 91        }
 92    }
 93
 94    bool discard(const T &key) {
 95        auto [i, pos] = get_pos(key);
 96        if (i >= data.size() || pos >= data[i].size() || data[i][pos] != key) {
 97            return false;
 98        }
 99        data[i].erase(data[i].begin() + pos);
100        bucket_data[i] -= key;
101        n--;
102        if (data[i].empty()) {
103            data.erase(data.begin() + i);
104            bucket_data.erase(bucket_data.begin() + i);
105        }
106        return true;
107    }
108
109    void remove(const T &key) {
110        assert(discard(key));
111    }
112
113    T operator[] (int k) const {
114        for (const vector<T> &d : data) {
115            if (k < d.size()) return d[k];
116            k -= d.size();
117        }
118    }
119
120    T lt(const T &key) const {
121        for (auto it = this->data.rbegin(); it != this->data.rend(); ++it) {
122            const vector<T> &d = *it;
123            if (d[0] < key) {
124                int index = bisect_left(d, key) - 1;
125                if (index >= 0) return d[index];
126            }
127        }
128        return this->missing;
129    }
130
131    T le(const T &key) const {
132        for (auto it = this->data.rbegin(); it != this->data.rend(); ++it) {
133            const vector<T> &d = *it;
134            if (d[0] <= key) {
135                int index = bisect_right(d, key) - 1;
136                if (index >= 0) return d[index];
137            }
138        }
139        return this->missing;
140    }
141
142    T gt(const T &key) const {
143        for (const vector<T> &d : this->data) {
144            if (d.back() > key) {
145                int index = bisect_right(d, key);
146                if (index < d.size()) return d[index];
147            }
148        }
149        return this->missing;
150    }
151
152    T ge(const T &key) const {
153        for (const vector<T> &d : this->data) {
154            if (d.back() >= key) {
155                int index = bisect_left(d, key);
156                if (index < d.size()) return d[index];
157            }
158        }
159        return this->missing;
160    }
161
162    int index(const T &x) const {
163        int ans = 0;
164        for (const vector<T> &d : this->data) {
165            if (d.back() >= x) return ans + bisect_left(d, x);
166            ans += d.size();
167        }
168        return ans;
169    }
170
171    int index_right(const T &x) const {
172        int ans = 0;
173        for (const vector<T> &d : this->data) {
174            if (d.back() > x) return ans + bisect_right(d, x);
175            ans += d.size();
176        }
177        return ans;
178    }
179
180    int count(const T &key) const {
181        return index_right(key) - index(key);
182    }
183
184    bool contains(const T &key) const {
185        auto [i, pos] = get_pos(key);
186        return i < data.size() && pos < data[i].size() && data[i][pos] == key;
187    }
188
189    // [l, r)の総和を返す
190    T sum(int l, int r) const {
191        T sum = 0;
192        int u = 0, v = 0;
193        for (int i = 0; i < data.size(); ++i) {
194            v = u + data[i].size();
195            if (l <= u && v <= r) {
196                sum += bucket_data[i];
197            } else if (v > l && u < r) {
198                int start = max(l, u);
199                int end = min(r, v);
200                for (int j = start; j < end; ++j) {
201                    sum += data[i][j - u];
202                }
203            }
204            u = v;
205            if (u >= r) break;
206        }
207        return sum;
208    }
209
210    int size() const {
211        return n;
212    }
213
214    int len() const {
215        return n;
216    }
217
218    vector<T> tovector() const {
219        vector<T> a;
220        a.reserve(n);
221        for (const vector<T> &d : data) {
222            a.insert(a.end(), d.begin(), d.end());
223        }
224        return a;
225    }
226
227    friend ostream& operator<<(ostream& os, const titan23::MultisetSum<T> &ms) {
228        vector<T> a = ms.tovector();
229        os << "{";
230        int n = ms.len();
231        for (int i = 0; i < n; ++i) {
232            os << a[i];
233            if (i != n-1) os << ", ";
234        }
235        os << "}";
236        return os;
237    }
238};
239} // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/multiset_sum_qd.cpp”