[docs]
class DynamicFenwickTree2D:
    """2D Fenwick tree (binary indexed tree) that creates nodes only where needed.

    Nodes are stored in a dict of dicts keyed by 1-based Fenwick indices, so
    memory grows with the number of nodes touched by updates rather than H * W.
    """

    def __init__(self, h: int, w: int, a: "list[list[int]] | None" = None):
        """Create an H x W grid of zeros, optionally initialised from ``a``.

        Args:
            h: number of rows (H).
            w: number of columns (W).
            a: optional initial values; must have exactly ``h`` rows and the
                first row must have exactly ``w`` columns.

        Building from ``a`` is O(HW logH logW).
        """
        # Fenwick indices are 1-based, so the internal bounds are size + 1.
        self._h: int = h + 1
        self._w: int = w + 1
        self._bit: dict[int, dict[int, int]] = {}
        if a:
            self._build(a)

    def _build(self, a: list[list[int]]) -> None:
        """Add every non-zero entry of ``a`` (an H x W grid) into the tree."""
        rows, cols = self._h - 1, self._w - 1
        assert len(a) == rows and len(a[0]) == cols
        for i in range(rows):
            row = a[i]
            for j in range(cols):
                # Zeros contribute nothing; skipping them keeps the tree sparse.
                if row[j] != 0:
                    self.add(i, j, row[j])

    def add(self, h: int, w: int, x) -> None:
        """Add x to a[h][w]. / O(logH * logW)

        Raises:
            IndexError: if (h, w) lies outside [0, H) x [0, W). Without this
                check a negative index would make the update loop spin forever
                (``0 & -0 == 0``), and too-large indices were silently dropped.
        """
        if not (0 <= h < self._h - 1 and 0 <= w < self._w - 1):
            raise IndexError(f"index ({h}, {w}) out of range")
        _h, _w, _bit = self._h, self._w, self._bit
        i = h + 1
        while i < _h:
            bit_i = _bit.get(i)
            if bit_i is None:
                bit_i = _bit[i] = {}
            j = w + 1
            while j < _w:
                if j in bit_i:
                    bit_i[j] += x
                else:
                    bit_i[j] = x
                j += j & -j
            i += i & -i

    def set(self, h: int, w: int, x) -> None:
        """Set a[h][w] to x. / O(logH * logW)"""
        self.add(h, w, x - self.get(h, w))

    def _sum(self, h: int, w: int) -> int:
        """Return sum([0, h) x [0, w)) of a. / O(logH * logW)"""
        # Clamp to the grid. Cells past H/W are implicitly zero, but a Fenwick
        # prefix walk that starts past the last index visits nodes that updates
        # never create (e.g. node 4 when H == 3), which would yield 0.
        h = min(h, self._h - 1)
        w = min(w, self._w - 1)
        bit = self._bit
        ret = 0
        while h > 0:
            bit_h = bit.get(h)
            if bit_h is not None:
                j = w
                while j > 0:
                    ret += bit_h.get(j, 0)
                    j -= j & -j
            h -= h & -h
        return ret

    def sum(self, h1: int, w1: int, h2: int, w2: int) -> int:
        """Return sum([h1, h2) x [w1, w2)) of a. / O(logH * logW)

        Bounds beyond the grid are clamped (cells outside are zero).
        """
        assert h1 <= h2 and w1 <= w2
        # Inclusion-exclusion over the four prefix rectangles.
        return (
            self._sum(h2, w2)
            - self._sum(h2, w1)
            - self._sum(h1, w2)
            + self._sum(h1, w1)
        )

    def get(self, h: int, w: int) -> int:
        """Return a[h][w]. / O(logH * logW)"""
        # Arguments follow sum's (h1, w1, h2, w2) order: the 1x1 cell at (h, w).
        return self.sum(h, w, h + 1, w + 1)