1from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
2from titan_pylib.data_structures.cumulative_sum.cumulative_sum import CumulativeSum
3from array import array
4from typing import Sequence, Iterable
5from bisect import bisect_left
6
7
[docs]
class CumulativeSumWaveletMatrix:
    """Static set of weighted 2D points answering rectangle weight-sum queries.

    The points are sorted by ``(x, y)`` and deduplicated, the y values are
    coordinate-compressed, and a wavelet matrix is built over the compressed
    y indices.  For every bit level a ``CumulativeSum`` holds the point
    weights in that level's (stably partitioned) order, so a query only needs
    ``rank`` lookups and one range sum per level.
    """

    def __init__(self, sigma: int, pos: Iterable[tuple[int, int, int]] = ()) -> None:
        """Build the structure from weighted points.

        Args:
            sigma (int): Bound on the y coordinates.  It only sets a minimum
                number of bit levels; the level count is raised when the
                number of distinct y values requires more.
            pos (Iterable[tuple[int, int, int]], optional): Points given as
                ``(x, y, w)``.  Weights of points sharing the same ``(x, y)``
                are added together.
        """
        # pos is traversed twice below; materialize it so that a one-shot
        # iterator (generator, map, ...) does not lose all of its weights.
        points = list(pos)

        self.sigma: int = sigma
        self.xy: list[tuple[int, int]] = self._sort_unique(
            [(x, y) for x, y, _ in points]
        )
        self.y: list[int] = self._sort_unique([y for _, y in self.xy])
        self.size: int = len(self.xy)

        # The matrix stores compressed y indices (0 .. len(self.y) - 1), so
        # there must be enough levels to represent the largest index even if
        # sigma was chosen too small.
        index_bits = max(len(self.y) - 1, 0).bit_length()
        self.log: int = max((sigma - 1).bit_length(), index_bits)

        # mid[bit] = number of elements whose `bit` is 0 at that level.
        self.mid: array = array("I", [0] * self.log)
        self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
        self._build([bisect_left(self.y, y) for _, y in self.xy])

        level_weights = [[0] * self.size for _ in range(self.log)]
        base_weights = [0] * self.size  # weights in sorted-(x, y) order
        for x, y, w in points:
            k = bisect_left(self.xy, (x, y))
            base_weights[k] += w
            i_y = bisect_left(self.y, y)
            # Follow the point down the matrix and drop its weight at the
            # position it occupies after each level's partition.
            for bit in range(self.log - 1, -1, -1):
                if i_y >> bit & 1:
                    k = self.v[bit].rank1(k) + self.mid[bit]
                else:
                    k = self.v[bit].rank0(k)
                level_weights[bit][k] += w
        self.bit: list[CumulativeSum] = [CumulativeSum(a) for a in level_weights]

        # Plain prefix sums over the sorted points.  Needed by _sum when the
        # y bound lies above every index representable in self.log bits.
        prefix = [0]
        for w in base_weights:
            prefix.append(prefix[-1] + w)
        self._prefix: list[int] = prefix

    def _build(self, a: Sequence[int]) -> None:
        """Fill the per-level bit vectors and ``mid`` from the sequence ``a``."""
        for bit in range(self.log - 1, -1, -1):
            # Mark the elements whose `bit` is 1, then stably partition `a`
            # into (zeros, ones) for the next, lower level.
            v = self.v[bit]
            zero, one = [], []
            for i, e in enumerate(a):
                if e >> bit & 1:
                    v.set(i)
                    one.append(e)
                else:
                    zero.append(e)
            v.build()
            self.mid[bit] = len(zero)  # boundary between the 0-part and 1-part
            a = zero + one

    def _sort_unique(self, a: list) -> list:
        """Return the distinct elements of ``a`` in ascending order.

        The argument itself is left unmodified.
        """
        ordered = sorted(a)
        unique = ordered[:1]
        for e in ordered:
            if unique[-1] != e:
                unique.append(e)
        return unique

    def _sum(self, l: int, r: int, x: int) -> int:
        """Sum the weights at sorted positions ``[l, r)`` with y index ``< x``."""
        if x >> self.log:
            # x exceeds every stored index (all of them fit in self.log bits),
            # so the whole range qualifies.  The bit loop below would only see
            # the low bits of x and wrongly treat the bound as smaller.
            return self._prefix[r] - self._prefix[l]
        ans = 0
        for bit in range(self.log - 1, -1, -1):
            l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
            if x >> bit & 1:
                # Elements with a 0 here are smaller than x on this prefix:
                # take their weights, then continue into the 1-part.
                ans += self.bit[bit].sum(l0, r0)
                l += self.mid[bit] - l0
                r += self.mid[bit] - r0
            else:
                l, r = l0, r0
        return ans

    def sum(self, w1: int, w2: int, h1: int, h2: int) -> int:
        """Return the total weight of points in ``[w1, w2) x [h1, h2)``.

        Args:
            w1 (int): Inclusive lower bound on x.
            w2 (int): Exclusive upper bound on x.
            h1 (int): Inclusive lower bound on y.
            h2 (int): Exclusive upper bound on y.

        Returns:
            int: Sum of the weights of all points with ``w1 <= x < w2`` and
            ``h1 <= y < h2``.
        """
        assert 0 <= w1 <= w2
        assert 0 <= h1 <= h2
        # A 1-tuple (w,) sorts before every (w, y), so these are the exact
        # boundaries of the x range regardless of the y values stored.
        l = bisect_left(self.xy, (w1,))
        r = bisect_left(self.xy, (w2,))
        upper = self._sum(l, r, bisect_left(self.y, h2))
        lower = self._sum(l, r, bisect_left(self.y, h1))
        return upper - lower