1from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
2from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
3from array import array
4from typing import Sequence
5from bisect import bisect_left
6
7
[docs]
class FenwickTreeWaveletMatrix:
    """Weighted 2-D point set supporting point-weight updates and rectangle sums.

    The set of points is fixed at construction: each ``(x, y, w)`` in ``pos``
    registers the point ``(x, y)`` with initial weight ``w`` (weights of
    duplicate points are added together). Internally the distinct points are
    sorted by ``(x, y)``, a wavelet matrix is built over the ranks of their
    y-coordinates, and every level keeps a ``FenwickTree`` of the weights in
    that level's order so weights can be changed later with ``add_point``.
    """

    def __init__(
        self, sigma: int, pos: Sequence[tuple[int, int, int]] = ()
    ) -> None:
        """Build the structure from the weighted points ``pos``.

        Args:
            sigma: Kept for backward compatibility and stored as ``self.sigma``.
                The number of levels is derived from the number of distinct
                y-coordinates instead, because y values are rank-compressed.
            pos: ``(x, y, w)`` triples. Every point that will later be passed
                to ``add_point`` must appear here. It is iterated several
                times, so it must be a re-iterable sequence.
        """
        self.sigma: int = sigma
        # Distinct points ordered by (x, y); the index of a point in this list
        # is its position in the bottom sequence of the wavelet matrix.
        self.xy: list[tuple[int, int]] = self._sort_unique(
            [(x, y) for x, y, _ in pos]
        )
        # Distinct y-coordinates; a point's value in the matrix is its rank here.
        self.y: list[int] = self._sort_unique([y for _, y, _ in pos])
        self.size: int = len(self.xy)
        # Thresholds handed to _sum are ranks in [0, len(self.y)], so 2**log
        # must be strictly greater than len(self.y); otherwise the largest
        # threshold has no set bit inside the levels and _sum returns 0.
        self.log: int = len(self.y).bit_length()
        # Per-level count of zeros (boundary of the stable partition).
        self.mid: array[int] = array("I", [0]) * self.log
        self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
        self._build([bisect_left(self.y, y) for _, y in self.xy])
        # ws[bit][k]: initial weight at position k of level `bit`, where k is
        # the position after that level's stable partition.
        ws = [[0] * self.size for _ in range(self.log)]
        for x, y, w in pos:
            k = bisect_left(self.xy, (x, y))
            i_y = bisect_left(self.y, y)
            for bit in range(self.log - 1, -1, -1):
                if i_y >> bit & 1:
                    k = self.v[bit].rank1(k) + self.mid[bit]
                else:
                    k = self.v[bit].rank0(k)
                ws[bit][k] += w
        self.bit: list[FenwickTree] = [FenwickTree(a) for a in ws]

    def _build(self, a: Sequence[int]) -> None:
        """Build the wavelet matrix levels from the rank sequence ``a``."""
        for bit in range(self.log - 1, -1, -1):
            # Mark elements whose `bit`-th bit is 1 in v, then stable-partition
            # a by that bit (zeros first) to form the next level's order.
            v = self.v[bit]
            zero, one = [], []
            for i, e in enumerate(a):
                if e >> bit & 1:
                    v.set(i)
                    one.append(e)
                else:
                    zero.append(e)
            v.build()
            # Keep the partition boundary (number of zeros) in mid[bit].
            self.mid[bit] = len(zero)
            a = zero + one

    def _sort_unique(self, a: list) -> list:
        """Return the distinct elements of ``a`` in ascending order."""
        return sorted(set(a))

    def add_point(self, x: int, y: int, w: int) -> None:
        """Add ``w`` to the weight of the registered point ``(x, y)``.

        Raises:
            ValueError: if ``(x, y)`` was not supplied in ``pos`` at
                construction (the structure cannot hold new points).
        """
        k = bisect_left(self.xy, (x, y))
        if k == self.size or self.xy[k] != (x, y):
            raise ValueError(f"point {(x, y)} was not registered at construction")
        i_y = bisect_left(self.y, y)
        for bit in range(self.log - 1, -1, -1):
            if i_y >> bit & 1:
                k = self.v[bit].rank1(k) + self.mid[bit]
            else:
                k = self.v[bit].rank0(k)
            self.bit[bit].add(k, w)

    def _sum(self, l: int, r: int, x: int) -> int:
        """Return the total weight at bottom positions [l, r) whose y-rank is < x."""
        ans = 0
        for bit in range(self.log - 1, -1, -1):
            l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
            if x >> bit & 1:
                # Elements with a 0 at this bit (positions [l0, r0) after the
                # partition) are all below x: count them, follow the 1 side.
                l += self.mid[bit] - l0
                r += self.mid[bit] - r0
                ans += self.bit[bit].sum(l0, r0)
            else:
                l, r = l0, r0
        return ans

    def sum(self, w1: int, w2: int, h1: int, h2: int) -> int:
        """Return the total weight of points with w1 <= x < w2 and h1 <= y < h2.

        Expects ``w1 <= w2`` and ``h1 <= h2``.
        """
        # Use -inf as the y component so the bounds depend on x alone; a
        # literal 0 mis-placed points whose y-coordinate is negative.
        lowest = float("-inf")
        l = bisect_left(self.xy, (w1, lowest))
        r = bisect_left(self.xy, (w2, lowest))
        return self._sum(l, r, bisect_left(self.y, h2)) - self._sum(
            l, r, bisect_left(self.y, h1)
        )