Wavelet matrix cumulative sum
ソースコード¶
1#include <vector>
2#include <algorithm>
3#include "titan_cpplib/data_structures/bit_vector.cpp"
4#include "titan_cpplib/data_structures/cumulative_sum.cpp"
5using namespace std;
6
7// WaveletMatrixCumulativeSum
8namespace titan23 {
9
10 /**
11 * @brief
12 *
13 * @tparam T 点の座標を表す型
14 * @tparam W 重みを表す型
15 */
16 template<typename T, typename W>
17 class WaveletMatrixCumulativeSum {
18
19 private:
20 T sigma;
21 int log;
22 vector<tuple<T, T, W>> pos;
23 vector<BitVector> v;
24 vector<pair<T, T>> xy;
25 vector<T> y;
26 vector<int> mid;
27 vector<titan23::CumulativeSum<W>> cumsum;
28 int n;
29
30 int bit_length(T n) const {
31 int b = 0;
32 while (n) {
33 n >>= 1;
34 ++b;
35 }
36 return b;
37 }
38
39 void _build(vector<T> a) {
40 for (int bit = log-1; bit >= 0; --bit) {
41 vector<T> zero, one;
42 v[bit] = BitVector(n);
43 for (int i = 0; i < n; ++i) {
44 if ((a[i] >> bit) & 1) {
45 v[bit].set(i);
46 one.emplace_back(a[i]);
47 } else {
48 zero.emplace_back(a[i]);
49 }
50 }
51 v[bit].build();
52 mid[bit] = zero.size();
53 a = zero;
54 a.insert(a.end(), one.begin(), one.end());
55 }
56 }
57
58 template<typename S>
59 static void sort_unique(vector<S> &a) {
60 std::sort(a.begin(), a.end());
61 a.erase(std::unique(a.begin(), a.end()), a.end());
62 }
63
64 W _sum(int l, int r, int x) const {
65 W ans = 0;
66 for (int bit = log-1; bit >= 0; --bit) {
67 int l0 = v[bit].rank0(l);
68 int r0 = v[bit].rank0(r);
69 if ((x>>bit) & 1) {
70 l += mid[bit] - l0;
71 r += mid[bit] - r0;
72 ans += cumsum[bit].sum(l0, r0);
73 } else {
74 l = l0;
75 r = r0;
76 }
77 }
78 return ans;
79 }
80
81 public:
82 WaveletMatrixCumulativeSum() {}
83 WaveletMatrixCumulativeSum(const T sigma)
84 : sigma(sigma), log(bit_length(sigma-1)), v(log), mid(log), cumsum(log) {
85 }
86
87 void reserve(const int cap) {
88 pos.reserve(cap);
89 }
90
91 void set_point(T x, T y, W w) {
92 pos.emplace_back(x, y, w);
93 }
94
95 void build() {
96 xy.reserve(pos.size());
97 for (const auto &[x, y, w]: pos) {
98 xy.emplace_back(x, y);
99 }
100 sort_unique(xy);
101
102 this->n = xy.size();
103
104 y.reserve(n);
105 for (const auto &[x, y_]: xy) {
106 y.emplace_back(y_);
107 }
108 sort_unique(y);
109
110 vector<int> a;
111 for (const auto &[x, y_]: xy) {
112 a.emplace_back(lower_bound(y.begin(), y.end(), y_) - y.begin());
113 }
114 _build(a);
115
116 vector<vector<W>> ws(log, vector<W>(n, 0));
117 for (const auto [x, y_, w]: pos) {
118 int k = lower_bound(xy.begin(), xy.end(), make_pair(x, y_)) - xy.begin();
119 int i_y = lower_bound(y.begin(), y.end(), y_) - y.begin();
120 for (int bit = log-1; bit >= 0; --bit) {
121 if ((i_y >> bit) & 1) {
122 k = v[bit].rank1(k) + mid[bit];
123 } else {
124 k = v[bit].rank0(k);
125 }
126 ws[bit][k] += w;
127 }
128 }
129
130 for (int i = 0; i < log; ++i) {
131 cumsum[i] = titan23::CumulativeSum<W>(ws[i], 0);
132 }
133 }
134
135 W sum(int w1, int w2, int h1, int h2) const {
136 int l = lower_bound(xy.begin(), xy.end(), make_pair(w1, 0)) - xy.begin();
137 int r = lower_bound(xy.begin(), xy.end(), make_pair(w2, 0)) - xy.begin();
138 return _sum(l, r, lower_bound(y.begin(), y.end(), h2) - y.begin())
139 - _sum(l, r, lower_bound(y.begin(), y.end(), h1) - y.begin());
140 }
141 };
142} // namespace titan23
仕様¶
Warning
doxygenfile: Cannot find file “titan_cpplib/data_structures/wavelet_matrix_cumulative_sum.cpp”