Source code for titan_pylib.data_structures.fenwick_tree.fenwick_tree_RAQ

  1from typing import Iterable, Sequence, Union
  2
  3
[docs] 4class FenwickTreeRAQ: 5 """区間加算/区間和クエリができます。。""" 6 7 def __init__(self, n_or_a: Union[Iterable[int], int]): 8 """構築します。 9 :math:`O(n)` です。 10 11 Args: 12 n_or_a (Union[Iterable[int], int]): 構築元のものです。 13 """ 14 if isinstance(n_or_a, int): 15 self.n = n_or_a 16 self.bit0 = [0] * (n_or_a + 2) 17 self.bit1 = [0] * (n_or_a + 2) 18 self.bit_size = self.n + 1 19 else: 20 if not isinstance(n_or_a, Sequence): 21 n_or_a = list(n_or_a) 22 self.n = len(n_or_a) 23 self.bit0 = [0] * (self.n + 2) 24 self.bit1 = [0] * (self.n + 2) 25 self.bit_size = self.n + 1 26 for i, e in enumerate(n_or_a): 27 self.add_range(i, i + 1, e) 28 29 def __add(self, bit: list[int], k: int, x: int) -> None: 30 k += 1 31 while k <= self.bit_size: 32 bit[k] += x 33 k += k & -k 34 35 def __pref(self, bit: list[int], r: int) -> int: 36 ret = 0 37 while r > 0: 38 ret += bit[r] 39 r -= r & -r 40 return ret 41
[docs] 42 def add(self, k: int, x: int) -> None: 43 """``k`` 番目に ``x`` を加算します。 44 :math:`O(\\log{n})` です。 45 46 Args: 47 k (int): 48 x (int): 49 """ 50 assert ( 51 0 <= k < self.n 52 ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self.n}" 53 self.add_range(k, k + 1, x)
54
[docs] 55 def add_range(self, l: int, r: int, x: int) -> None: 56 """区間 ``[l, r)`` に ``x`` を加算します。 57 :math:`O(\\log{n})` です。 58 59 Args: 60 l (int): 61 r (int): 62 x (int): 63 """ 64 assert ( 65 0 <= l <= r <= self.n 66 ), f"IndexError: {self.__class__.__name__}.add_range({l}, {r}, {x}), l={l},r={r},n={self.n}" 67 self.__add(self.bit0, l, -x * l) 68 self.__add(self.bit0, r, x * r) 69 self.__add(self.bit1, l, x) 70 self.__add(self.bit1, r, -x)
71
[docs] 72 def sum(self, l: int, r: int) -> int: 73 """区間 ``[l, r)`` の総和を返します。 74 :math:`O(\\log{n})` です。 75 76 Args: 77 l (int): 78 r (int): 79 80 Returns: 81 int: 82 """ 83 assert ( 84 0 <= l <= r <= self.n 85 ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), l={l},r={r},n={self.n}" 86 return ( 87 self.__pref(self.bit0, r) 88 + r * self.__pref(self.bit1, r) 89 - self.__pref(self.bit0, l) 90 - l * self.__pref(self.bit1, l) 91 )
92
[docs] 93 def tolist(self) -> list[int]: 94 """``list`` にして返します。 95 96 Returns: 97 list[int]: 98 """ 99 return [self.sum(i, i + 1) for i in range(self.n)]
100
[docs] 101 def __getitem__(self, k: int) -> int: 102 """``k`` 番目の値を返します。 103 ``sum(k, k+1)`` と等価です。 104 :math:`O(\\log{n})` です。 105 106 Args: 107 k (int): 108 109 Returns: 110 int: 111 """ 112 assert ( 113 0 <= k < self.n 114 ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}" 115 return self.sum(k, k + 1)
116 117 def __str__(self): 118 return str(self.tolist()) 119 120 def __repr__(self): 121 return f"{self.__class__.__name__}({self})"