1# from titan_pylib.data_structures.fenwick_tree.fenwick_tree_abst import FenwickTreeAbst
2from typing import Union, Iterable, TypeVar, Generic, Callable
3
# Type variable for the element type stored in FenwickTreeAbst.
T = TypeVar("T")
5
6
class FenwickTreeAbst(Generic[T]):
    """Fenwick tree (binary indexed tree) over a caller-defined group.

    The combining operation ``op``, the inverse ``inv`` and the identity ``e``
    are supplied by the caller, so the tree is not limited to integer
    addition (e.g. XOR can be used with ``inv`` as the identity function).

    Range queries are answered as ``op(pref(r), inv(pref(l)))``-style
    combinations, so ``op`` must be associative, ``e`` must be its identity
    and ``op(x, inv(x))`` must equal ``e``. Partial results are combined in
    an order that is not strictly left-to-right, so ``op`` is also assumed
    to be commutative.
    """

    def __init__(
        self,
        n_or_a: Union[Iterable[T], int],
        op: Callable[[T, T], T],
        inv: Callable[[T], T],
        e: T,
    ) -> None:
        """Build the tree.

        Args:
            n_or_a: Either a size ``n`` (every element starts as ``e``) or an
                iterable of initial elements.
            op: Associative, commutative binary operation.
            inv: Inverse of an element with respect to ``op``.
            e: Identity element of ``op``.
        """
        if isinstance(n_or_a, int):
            self._size = n_or_a
            self._tree = [e] * (self._size + 1)
        else:
            a = n_or_a if isinstance(n_or_a, list) else list(n_or_a)
            self._size = len(a)
            # Nodes are 1-indexed; _tree[0] is an unused slot holding e.
            # `[e] + a` builds a new list, so the caller's list is not mutated.
            self._tree = [e] + a
            # Linear-time build: fold each node into the node that covers it.
            # Node n never has a parent within range, so it is skipped.
            for i in range(1, self._size):
                parent = i + (i & -i)
                if parent <= self._size:
                    self._tree[parent] = op(self._tree[parent], self._tree[i])
        self.op = op
        self.inv = inv
        self.e = e
        # Power of two >= size (for size >= 1). Not used by the methods of this
        # class, but kept so existing attribute access keeps working.
        self._s = 1 << (self._size - 1).bit_length()

    def pref(self, r: int) -> T:
        """Return the combination of elements in the half-open range [0, r)."""
        assert (
            0 <= r <= self._size
        ), f"IndexError: {self.__class__.__name__}.pref({r}), n={self._size}"
        ret = self.e
        while r > 0:
            ret = self.op(ret, self._tree[r])
            r -= r & -r  # drop the lowest set bit
        return ret

    def suff(self, l: int) -> T:
        """Return the combination of elements in the half-open range [l, n).

        ``l == n`` is accepted and yields the identity ``e`` (empty suffix).
        """
        assert (
            0 <= l <= self._size
        ), f"IndexError: {self.__class__.__name__}.suff({l}), n={self._size}"
        return self.op(self.pref(self._size), self.inv(self.pref(l)))

    def sum(self, l: int, r: int) -> T:
        """Return the combination of elements in the half-open range [l, r).

        Both endpoints walk down by clearing their lowest set bit, and the walk
        stops once they meet. This avoids the shared part of the two prefixes
        instead of computing ``pref(r)`` and ``pref(l)`` in full.
        """
        assert (
            0 <= l <= r <= self._size
        ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), n={self._size}"
        _tree = self._tree
        op = self.op
        inv = self.inv
        res = self.e
        while r > l:
            res = op(res, _tree[r])
            r &= r - 1
        while l > r:
            # Remove the surplus part with the group inverse. The original code
            # used `+=` here, which is only correct for numeric addition.
            res = op(res, inv(_tree[l]))
            l &= l - 1
        return res

    prod = sum

    def __getitem__(self, k: int) -> T:
        """Return element ``k``; negative indices count from the end."""
        assert (
            -self._size <= k < self._size
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._size}"
        if k < 0:
            k += self._size
        return self.op(self.pref(k + 1), self.inv(self.pref(k)))

    def add(self, k: int, x: T) -> None:
        """Replace element ``k`` with ``op(a[k], x)``."""
        assert (
            0 <= k < self._size
        ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self._size}"
        k += 1
        while k <= self._size:
            self._tree[k] = self.op(self._tree[k], x)
            k += k & -k  # move to the next node covering position k

    def __setitem__(self, k: int, x: T) -> None:
        """Set element ``k`` to ``x``; negative indices count from the end."""
        assert (
            -self._size <= k < self._size
        ), f"IndexError: {self.__class__.__name__}.__setitem__({k}, {x}), n={self._size}"
        if k < 0:
            k += self._size
        pre = self.__getitem__(k)
        # Add the difference x - pre, expressed as op(x, inv(pre)).
        self.add(k, self.op(x, self.inv(pre)))

    def tolist(self) -> list[T]:
        """Return the current elements as a plain list."""
        sub = [self.pref(i) for i in range(self._size + 1)]
        return [self.op(sub[i + 1], self.inv(sub[i])) for i in range(self._size)]

    def __str__(self) -> str:
        return str(self.tolist())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"