1# from titan_pylib.data_structures.fenwick_tree.fenwick_tree_2D import FenwickTree2D
class FenwickTree2D:
    """2D Fenwick tree (binary indexed tree) over an H x W grid.

    Supports point updates and rectangle-sum queries on a[i][j]
    (0-indexed, ``0 <= i < H``, ``0 <= j < W``), each in O(log H * log W).

    Internally the tree is a flat, 1-indexed (H + 1) x (W + 1) array: tree
    cell (i, j) is stored at ``i * (W + 1) + j``; row 0 and column 0 are
    never written and always hold 0.
    """

    def __init__(self, h: int, w: int, a: "list[list[int]] | None" = None) -> None:
        """Create an H x W tree, optionally initialised from ``a``. / O(HW)

        Args:
            h: number of rows H.
            w: number of columns W.
            a: optional initial grid with exactly ``h`` rows of ``w`` values.
                ``None`` or an empty list gives an all-zero grid.
        """
        # Stored dimensions include the unused index-0 row/column.
        self._h = h + 1
        self._w = w + 1
        self._bit = [0] * (self._h * self._w)
        if a:
            assert len(a) == h
            # Check every row, not just the first, so jagged input is rejected.
            assert all(len(row) == w for row in a)
            self._build(a)

    def _build(self, a: list[list[int]]) -> None:
        """Fill the freshly zeroed tree from grid ``a``. / O(HW)

        A 2D Fenwick tree is the 1D construction applied along one axis and
        then the other, so this does the linear 1D build (push each node into
        its parent ``i + (i & -i)``) first within every row, then across rows.
        Only called from ``__init__``, right after ``_bit`` was zero-filled.
        """
        _h, _w, _bit = self._h, self._w, self._bit
        # Copy a into the 1-indexed flat layout.
        for i in range(1, _h):
            row = a[i - 1]
            base = i * _w
            for j in range(1, _w):
                _bit[base + j] = row[j - 1]
        # 1D build inside each row: cell j is complete before it is pushed,
        # because all of its contributors have smaller column indices.
        for i in range(1, _h):
            base = i * _w
            for j in range(1, _w):
                k = j + (j & -j)
                if k < _w:
                    _bit[base + k] += _bit[base + j]
        # 1D build across rows: push whole row i into its parent row.
        for i in range(1, _h):
            k = i + (i & -i)
            if k < _h:
                src, dst = i * _w, k * _w
                for j in range(1, _w):
                    _bit[dst + j] += _bit[src + j]

    def add(self, h: int, w: int, x: int) -> None:
        """Add x to a[h][w]. / O(logH * logW)"""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        # Shift to the tree's 1-indexed coordinates.
        h += 1
        w += 1
        _h, _w, _bit = self._h, self._w, self._bit
        while h < _h:
            j = w
            while j < _w:
                _bit[h * _w + j] += x
                j += j & -j
            h += h & -h

    def set(self, h: int, w: int, x: int) -> None:
        """Set a[h][w] to x. / O(logH * logW)"""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        self.add(h, w, x - self.get(h, w))

    def _sum(self, h: int, w: int) -> int:
        """Return sum([0, h) x [0, w)) of a. / O(logH * logW)

        ``h`` and ``w`` are exclusive bounds, so ``0 <= h <= H`` and
        ``0 <= w <= W``; a zero bound gives an empty sum of 0.
        """
        assert 0 <= h < self._h, "IndexError"
        assert 0 <= w < self._w, "IndexError"
        ret = 0
        _w, _bit = self._w, self._bit
        while h > 0:
            j = w
            while j > 0:
                ret += _bit[h * _w + j]
                j -= j & -j
            h -= h & -h
        return ret

    def sum(self, h1: int, w1: int, h2: int, w2: int) -> int:
        """Return sum([h1, h2) x [w1, w2)) of a. / O(logH * logW)

        Computed by inclusion-exclusion over four prefix sums.
        """
        assert 0 <= h1 <= h2 < self._h, "IndexError"
        assert 0 <= w1 <= w2 < self._w, "IndexError"
        return (
            self._sum(h2, w2)
            - self._sum(h2, w1)
            - self._sum(h1, w2)
            + self._sum(h1, w1)
        )

    def get(self, h: int, w: int) -> int:
        """Return a[h][w]. / O(logH * logW)"""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        return self.sum(h, w, h + 1, w + 1)

    def __str__(self) -> str:
        """Render the grid with one comma-separated row per line."""
        rows = [
            ", ".join(str(self.get(i, j)) for j in range(self._w - 1))
            for i in range(self._h - 1)
        ]
        return "[\n " + "\n ".join(rows) + "\n]"