class FenwickTree2D:
    """Two-dimensional Fenwick tree (binary indexed tree) over an H x W grid.

    Supports point updates (``add`` / ``set``) and axis-aligned rectangle
    sums (``sum``) in O(log H * log W). Cell indices are 0-based and
    rectangle bounds are half-open: ``sum(h1, w1, h2, w2)`` covers rows
    ``[h1, h2)`` and columns ``[w1, w2)``.

    Index checks use ``assert`` and are therefore skipped under ``python -O``.
    """

    def __init__(
        self, h: int, w: int, a: "list[list[int]] | None" = None
    ) -> None:
        """Create an ``h`` x ``w`` tree, optionally initialised from ``a``.

        Args:
            h: Number of rows.
            w: Number of columns.
            a: Optional initial grid of ``h`` rows with ``w`` values each.
                If omitted, ``None`` or empty, every cell starts at 0.

        Complexity: O(HW).

        Raises:
            AssertionError: if ``a`` is given and its shape is not ``h`` x ``w``.
        """
        # One extra row and column of padding: the tree is 1-based and is
        # stored row-major in a flat list (index = row * self._w + col).
        self._h = h + 1
        self._w = w + 1
        self._bit = [0] * (self._h * self._w)
        # ``None`` default instead of a shared mutable ``[]`` default.
        if a:
            assert len(a) == h
            # Validate every row, not just the first. Otherwise a short row
            # fails half-way through the build and a long row is silently
            # truncated.
            assert all(len(row) == w for row in a)
            self._build(a)

    def _build(self, a: list[list[int]]) -> None:
        """Add every value of grid ``a`` into the tree in O(HW).

        Internal node (i, j) (1-based) holds the sum of the cells in rows
        (i - lowbit(i), i] and columns (j - lowbit(j), j]. The tree for ``a``
        is built in a scratch array and then added into ``self._bit``. This
        works because the tree is linear in its input, so calling this on a
        non-empty tree still behaves like adding each value with ``add``.
        """
        _h, _w = self._h, self._w
        tree = [0] * (_h * _w)
        # 1) Copy the raw values into the 1-based layout.
        for i in range(1, _h):
            row = a[i - 1]
            base = i * _w
            for j in range(1, _w):
                tree[base + j] = row[j - 1]
        # 2) Within each row, push every node into its Fenwick parent column.
        #    Increasing j guarantees a node is complete before it is pushed.
        for i in range(1, _h):
            base = i * _w
            for j in range(1, _w):
                parent = j + (j & -j)
                if parent < _w:
                    tree[base + parent] += tree[base + j]
        # 3) Push every (now column-aggregated) row into its parent row.
        for i in range(1, _h):
            parent = i + (i & -i)
            if parent < _h:
                src, dst = i * _w, parent * _w
                for j in range(1, _w):
                    tree[dst + j] += tree[src + j]
        _bit = self._bit
        for k, v in enumerate(tree):
            if v:
                _bit[k] += v

    def add(self, h: int, w: int, x: int) -> None:
        """Add ``x`` to cell ``a[h][w]``. O(log H * log W)."""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        _h, _w, _bit = self._h, self._w, self._bit
        i = h + 1  # shift to the tree's 1-based indexing
        while i < _h:
            base = i * _w
            j = w + 1
            while j < _w:
                _bit[base + j] += x
                j += j & -j
            i += i & -i

    def set(self, h: int, w: int, x: int) -> None:
        """Overwrite cell ``a[h][w]`` with ``x``. O(log H * log W)."""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        self.add(h, w, x - self.get(h, w))

    def _sum(self, h: int, w: int) -> int:
        """Return the sum of ``a`` over rows ``[0, h)`` x columns ``[0, w)``.

        ``h`` may equal the number of rows and ``w`` the number of columns.
        O(log H * log W).
        """
        assert 0 <= h < self._h, "IndexError"
        assert 0 <= w < self._w, "IndexError"
        ret = 0
        _w, _bit = self._w, self._bit
        while h > 0:
            base = h * _w
            j = w
            while j > 0:
                ret += _bit[base + j]
                j -= j & -j
            h -= h & -h
        return ret

    def sum(self, h1: int, w1: int, h2: int, w2: int) -> int:
        """Return the sum of ``a`` over rows ``[h1, h2)`` x columns ``[w1, w2)``.

        Computed by inclusion-exclusion over four prefix sums.
        O(log H * log W).
        """
        assert 0 <= h1 <= h2 < self._h, "IndexError"
        assert 0 <= w1 <= w2 < self._w, "IndexError"
        return (
            self._sum(h2, w2)
            - self._sum(h2, w1)
            - self._sum(h1, w2)
            + self._sum(h1, w1)
        )

    def get(self, h: int, w: int) -> int:
        """Return the current value of cell ``a[h][w]``. O(log H * log W)."""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        return self.sum(h, w, h + 1, w + 1)

    def __str__(self) -> str:
        """Render the grid as one comma-separated line per row, wrapped in brackets."""
        rows = [
            ", ".join(str(self.get(i, j)) for j in range(self._w - 1))
            for i in range(self._h - 1)
        ]
        return "[\n " + "\n ".join(rows) + "\n]"