wavelet matrix cumulative sum

ソースコード

  1#include <vector>
  2#include <algorithm>
  3#include "titan_cpplib/data_structures/bit_vector.cpp"
  4#include "titan_cpplib/data_structures/cumulative_sum.cpp"
  5using namespace std;
  6
  7// WaveletMatrixCumulativeSum
  8namespace titan23 {
  9
 10    /**
 11     * @brief 
 12     * 
 13     * @tparam T 点の座標を表す型
 14     * @tparam W 重みを表す型
 15     */
 16    template<typename T, typename W>
 17    class WaveletMatrixCumulativeSum {
 18
 19      private:
 20        T sigma;
 21        int log;
 22        vector<tuple<T, T, W>> pos;
 23        vector<BitVector> v;
 24        vector<pair<T, T>> xy;
 25        vector<T> y;
 26        vector<int> mid;
 27        vector<titan23::CumulativeSum<W>> cumsum;
 28        int n;
 29
 30        int bit_length(T n) const {
 31            int b = 0;
 32            while (n) {
 33                n >>= 1;
 34                ++b;
 35            }
 36            return b;
 37        }
 38
 39        void _build(vector<T> a) {
 40            for (int bit = log-1; bit >= 0; --bit) {
 41                vector<T> zero, one;
 42                v[bit] = BitVector(n);
 43                for (int i = 0; i < n; ++i) {
 44                    if ((a[i] >> bit) & 1) {
 45                        v[bit].set(i);
 46                        one.emplace_back(a[i]);
 47                    } else {
 48                        zero.emplace_back(a[i]);
 49                    }
 50                }
 51                v[bit].build();
 52                mid[bit] = zero.size();
 53                a = zero;
 54                a.insert(a.end(), one.begin(), one.end());
 55            }
 56        }
 57
 58        template<typename S>
 59        static void sort_unique(vector<S> &a) {
 60            std::sort(a.begin(), a.end());
 61            a.erase(std::unique(a.begin(), a.end()), a.end());
 62        }
 63
 64        W _sum(int l, int r, int x) const {
 65            W ans = 0;
 66            for (int bit = log-1; bit >= 0; --bit) {
 67                int l0 = v[bit].rank0(l);
 68                int r0 = v[bit].rank0(r);
 69                if ((x>>bit) & 1) {
 70                    l += mid[bit] - l0;
 71                    r += mid[bit] - r0;
 72                    ans += cumsum[bit].sum(l0, r0);
 73                } else {
 74                    l = l0;
 75                    r = r0;
 76                }
 77            }
 78            return ans;
 79        }
 80
 81      public:
 82        WaveletMatrixCumulativeSum() {}
 83        WaveletMatrixCumulativeSum(const T sigma)
 84            : sigma(sigma), log(bit_length(sigma-1)), v(log), mid(log), cumsum(log) {
 85        }
 86
 87        void reserve(const int cap) {
 88            pos.reserve(cap);
 89        }
 90
 91        void set_point(T x, T y, W w) {
 92            pos.emplace_back(x, y, w);
 93        }
 94
 95        void build() {
 96            xy.reserve(pos.size());
 97            for (const auto &[x, y, w]: pos) {
 98                xy.emplace_back(x, y);
 99            }
100            sort_unique(xy);
101
102            this->n = xy.size();
103
104            y.reserve(n);
105            for (const auto &[x, y_]: xy) {
106                y.emplace_back(y_);
107            }
108            sort_unique(y);
109
110            vector<int> a;
111            for (const auto &[x, y_]: xy) {
112                a.emplace_back(lower_bound(y.begin(), y.end(), y_) - y.begin());
113            }
114            _build(a);
115
116            vector<vector<W>> ws(log, vector<W>(n, 0));
117            for (const auto [x, y_, w]: pos) {
118                int k = lower_bound(xy.begin(), xy.end(), make_pair(x, y_)) - xy.begin();
119                int i_y = lower_bound(y.begin(), y.end(), y_) - y.begin();
120                for (int bit = log-1; bit >= 0; --bit) {
121                    if ((i_y >> bit) & 1) {
122                        k = v[bit].rank1(k) + mid[bit];
123                    } else {
124                        k = v[bit].rank0(k);
125                    }
126                    ws[bit][k] += w;
127                }
128            }
129
130            for (int i = 0; i < log; ++i) {
131                cumsum[i] = titan23::CumulativeSum<W>(ws[i], 0);
132            }
133        }
134
135        W sum(int w1, int w2, int h1, int h2) const {
136            int l = lower_bound(xy.begin(), xy.end(), make_pair(w1, 0)) - xy.begin();
137            int r = lower_bound(xy.begin(), xy.end(), make_pair(w2, 0)) - xy.begin();
138            return _sum(l, r, lower_bound(y.begin(), y.end(), h2) - y.begin())
139                    - _sum(l, r, lower_bound(y.begin(), y.end(), h1) - y.begin());
140        }
141    };
142}  // namespace titan23

仕様

Warning

doxygenfile: Cannot find file “titan_cpplib/data_structures/wavelet_matrix_cumulative_sum.cpp”