1from typing import Iterable, Sequence, Union
2
3
[docs]
4class FenwickTreeRAQ:
5 """区間加算/区間和クエリができます。。"""
6
7 def __init__(self, n_or_a: Union[Iterable[int], int]):
8 """構築します。
9 :math:`O(n)` です。
10
11 Args:
12 n_or_a (Union[Iterable[int], int]): 構築元のものです。
13 """
14 if isinstance(n_or_a, int):
15 self.n = n_or_a
16 self.bit0 = [0] * (n_or_a + 2)
17 self.bit1 = [0] * (n_or_a + 2)
18 self.bit_size = self.n + 1
19 else:
20 if not isinstance(n_or_a, Sequence):
21 n_or_a = list(n_or_a)
22 self.n = len(n_or_a)
23 self.bit0 = [0] * (self.n + 2)
24 self.bit1 = [0] * (self.n + 2)
25 self.bit_size = self.n + 1
26 for i, e in enumerate(n_or_a):
27 self.add_range(i, i + 1, e)
28
29 def __add(self, bit: list[int], k: int, x: int) -> None:
30 k += 1
31 while k <= self.bit_size:
32 bit[k] += x
33 k += k & -k
34
35 def __pref(self, bit: list[int], r: int) -> int:
36 ret = 0
37 while r > 0:
38 ret += bit[r]
39 r -= r & -r
40 return ret
41
[docs]
42 def add(self, k: int, x: int) -> None:
43 """``k`` 番目に ``x`` を加算します。
44 :math:`O(\\log{n})` です。
45
46 Args:
47 k (int):
48 x (int):
49 """
50 assert (
51 0 <= k < self.n
52 ), f"IndexError: {self.__class__.__name__}.add({k}, {x}), n={self.n}"
53 self.add_range(k, k + 1, x)
54
[docs]
55 def add_range(self, l: int, r: int, x: int) -> None:
56 """区間 ``[l, r)`` に ``x`` を加算します。
57 :math:`O(\\log{n})` です。
58
59 Args:
60 l (int):
61 r (int):
62 x (int):
63 """
64 assert (
65 0 <= l <= r <= self.n
66 ), f"IndexError: {self.__class__.__name__}.add_range({l}, {r}, {x}), l={l},r={r},n={self.n}"
67 self.__add(self.bit0, l, -x * l)
68 self.__add(self.bit0, r, x * r)
69 self.__add(self.bit1, l, x)
70 self.__add(self.bit1, r, -x)
71
[docs]
72 def sum(self, l: int, r: int) -> int:
73 """区間 ``[l, r)`` の総和を返します。
74 :math:`O(\\log{n})` です。
75
76 Args:
77 l (int):
78 r (int):
79
80 Returns:
81 int:
82 """
83 assert (
84 0 <= l <= r <= self.n
85 ), f"IndexError: {self.__class__.__name__}.sum({l}, {r}), l={l},r={r},n={self.n}"
86 return (
87 self.__pref(self.bit0, r)
88 + r * self.__pref(self.bit1, r)
89 - self.__pref(self.bit0, l)
90 - l * self.__pref(self.bit1, l)
91 )
92
[docs]
93 def tolist(self) -> list[int]:
94 """``list`` にして返します。
95
96 Returns:
97 list[int]:
98 """
99 return [self.sum(i, i + 1) for i in range(self.n)]
100
[docs]
101 def __getitem__(self, k: int) -> int:
102 """``k`` 番目の値を返します。
103 ``sum(k, k+1)`` と等価です。
104 :math:`O(\\log{n})` です。
105
106 Args:
107 k (int):
108
109 Returns:
110 int:
111 """
112 assert (
113 0 <= k < self.n
114 ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
115 return self.sum(k, k + 1)
116
117 def __str__(self):
118 return str(self.tolist())
119
120 def __repr__(self):
121 return f"{self.__class__.__name__}({self})"