Source code for titan_pylib.data_structures.fenwick_tree.fenwick_tree_2D

class FenwickTree2D:
    """Two-dimensional Fenwick tree (binary indexed tree) over an H x W grid.

    Supports point addition / assignment and rectangle-sum queries, each in
    O(log H * log W). Public indices are 0-based. Internally the tree is
    1-based and stored row-major in a flat list of (H + 1) * (W + 1) cells.
    Row 0 and column 0 of that list stay zero.
    """

    def __init__(
        self, h: int, w: int, a: "list[list[int]] | None" = None
    ) -> None:
        """Create an h x w tree, optionally initialised from the grid ``a``.

        Args:
            h: number of rows.
            w: number of columns.
            a: optional initial values. When non-empty it must have ``h``
                rows, and its first row must have ``w`` columns. ``None`` (the
                default) or an empty list means an all-zero grid.

        Building from ``a`` takes O(H * W).
        """
        self._h = h + 1  # padded height: index 0 is the 1-based sentinel row
        self._w = w + 1  # padded width: index 0 is the sentinel column
        self._bit = [0] * (self._h * self._w)
        if a:
            # A non-empty ``a`` with len(a) == h implies h >= 1, so a[0] exists.
            assert len(a) == h
            assert len(a[0]) == w
            self._build(a)

    def _build(self, a: list[list[int]]) -> None:
        """Fill the (all-zero) tree from grid ``a`` in O(H * W).

        The 2D tree cell T[i][j] sums a over the row range
        (i - lowbit(i), i] crossed with the column range (j - lowbit(j), j].
        Because this is separable, we run the linear 1D Fenwick construction
        along every row and then along every column.
        """
        _h, _w, _bit = self._h, self._w, self._bit
        # Copy values into their 1-based slots.
        for i in range(_h - 1):
            row = a[i]
            base = (i + 1) * _w
            for j in range(_w - 1):
                _bit[base + j + 1] = row[j]
        # 1D linear build along each row: push each cell into its parent.
        # Every contributor of cell j has a smaller index, so cell j is
        # already complete when it is pushed.
        for i in range(1, _h):
            base = i * _w
            for j in range(1, _w):
                parent = j + (j & -j)
                if parent < _w:
                    _bit[base + parent] += _bit[base + j]
        # The same 1D build along each column, pushing whole rows at once.
        for i in range(1, _h):
            parent = i + (i & -i)
            if parent < _h:
                src = i * _w
                dst = parent * _w
                for j in range(1, _w):
                    _bit[dst + j] += _bit[src + j]

    def add(self, h: int, w: int, x: int) -> None:
        """Add x to a[h][w]. / O(logH * logW)"""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        h += 1  # switch to the internal 1-based indices
        w += 1
        _h, _w, _bit = self._h, self._w, self._bit
        while h < _h:
            j = w
            while j < _w:
                _bit[h * _w + j] += x
                j += j & -j
            h += h & -h

    def set(self, h: int, w: int, x: int) -> None:
        """Set a[h][w] to x. / O(logH * logW)"""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        self.add(h, w, x - self.get(h, w))

    def _sum(self, h: int, w: int) -> int:
        """Return sum([0, h) x [0, w)) of a. / O(logH * logW)"""
        assert 0 <= h < self._h, "IndexError"
        assert 0 <= w < self._w, "IndexError"
        ret = 0
        _w, _bit = self._w, self._bit
        # The exclusive 0-based bounds equal the inclusive 1-based bounds.
        while h > 0:
            j = w
            while j > 0:
                ret += _bit[h * _w + j]
                j -= j & -j
            h -= h & -h
        return ret

    def sum(self, h1: int, w1: int, h2: int, w2: int) -> int:
        """Return sum([h1, h2) x [w1, w2)) of a. / O(logH * logW)"""
        assert 0 <= h1 <= h2 < self._h, "IndexError"
        assert 0 <= w1 <= w2 < self._w, "IndexError"
        # Inclusion-exclusion over four prefix rectangles.
        return (
            self._sum(h2, w2)
            - self._sum(h2, w1)
            - self._sum(h1, w2)
            + self._sum(h1, w1)
        )

    def get(self, h: int, w: int) -> int:
        """Return a[h][w]. / O(logH * logW)"""
        assert 0 <= h < self._h - 1, "IndexError"
        assert 0 <= w < self._w - 1, "IndexError"
        return self.sum(h, w, h + 1, w + 1)

    def __str__(self) -> str:
        """Render the grid one row per line, values comma-separated."""
        ret = []
        for i in range(self._h - 1):
            ret.append(
                ", ".join(map(str, ((self.get(i, j)) for j in range(self._w - 1))))
            )
        return "[\n " + "\n ".join(map(str, ret)) + "\n]"