segment_tree_RSQ¶
ソースコード¶
from titan_pylib.data_structures.segment_tree.segment_tree_RSQ import SegmentTreeRSQ
展開済みコード¶
1# from titan_pylib.data_structures.segment_tree.segment_tree_RSQ import SegmentTreeRSQ
2# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
3# SegmentTreeInterface,
4# )
5from abc import ABC, abstractmethod
6from typing import TypeVar, Generic, Union, Iterable, Callable
7
T = TypeVar("T")


class SegmentTreeInterface(ABC, Generic[T]):
    """Abstract base class listing the operations a segment tree must offer.

    Every method is abstract; a concrete subclass has to override all of
    them before it can be instantiated.  Each default body simply raises
    ``NotImplementedError``.
    """

    @abstractmethod
    def __init__(
        self,
        n_or_a: Union[int, Iterable[T]],
        op: Callable[[T, T], T],
        e: T,
    ):
        """Build the tree from a length or an iterable, an operation and ``e``."""
        raise NotImplementedError

    @abstractmethod
    def set(self, k: int, v: T) -> None:
        """Store ``v`` at index ``k``."""
        raise NotImplementedError

    @abstractmethod
    def get(self, k: int) -> T:
        """Return the element at index ``k``."""
        raise NotImplementedError

    @abstractmethod
    def prod(self, l: int, r: int) -> T:
        """Return the aggregate for the index pair ``(l, r)``."""
        raise NotImplementedError

    @abstractmethod
    def all_prod(self) -> T:
        """Return the aggregate over all elements."""
        raise NotImplementedError

    @abstractmethod
    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
        """Search to the right of ``l`` using predicate ``f``; return an index."""
        raise NotImplementedError

    @abstractmethod
    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
        """Search to the left of ``r`` using predicate ``f``; return an index."""
        raise NotImplementedError

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return the stored elements as a list."""
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, k: int) -> T:
        """Subscript read access (``tree[k]``)."""
        raise NotImplementedError

    @abstractmethod
    def __setitem__(self, k: int, v: T) -> None:
        """Subscript write access (``tree[k] = v``)."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self):
        """Return a human-readable string form."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self):
        """Return a debugging string form."""
        raise NotImplementedError
60# from titan_pylib.my_class.supports_add import SupportsAdd
61from typing import Protocol
62
63
class SupportsAdd(Protocol):
    """Structural type for objects that implement the ``+`` operator family."""

    def __add__(self, other):
        """Binary ``self + other``."""
        ...

    def __iadd__(self, other):
        """Augmented ``self += other``."""
        ...

    def __radd__(self, other):
        """Reflected ``other + self``."""
        ...
69from typing import Generic, Iterable, TypeVar, Union
70
T = TypeVar("T", bound=SupportsAdd)


class SegmentTreeRSQ(SegmentTreeInterface, Generic[T]):
    """Segment tree for range-sum queries (RSQ), combining elements with ``+``.

    Layout: ``_data`` has ``2 * _size`` slots, where ``_size`` is a power of
    two at least as large as ``_n``.  Leaves live at
    ``_data[_size : _size + _n]``; node ``i`` stores
    ``_data[2 * i] + _data[2 * i + 1]``; unused leaves hold the identity
    ``_e``.

    Accumulators are always rebuilt with ``a + b`` (never ``+=``) so that an
    element type whose ``__iadd__`` mutates in place cannot corrupt ``_e`` or
    the padding leaves that alias it, and operands are always combined in
    left-to-right index order so non-commutative ``+`` (e.g. ``str``) works.
    """

    def __init__(self, _n_or_a: Union[int, Iterable[T]], e: T = 0) -> None:
        """Create the tree.

        Args:
            _n_or_a: Either the number of elements (all initialised to ``e``)
                or an iterable with the initial elements.
            e: Identity element for ``+``.
        """
        self._e = e
        if isinstance(_n_or_a, int):
            self._n = _n_or_a
            self._log = (self._n - 1).bit_length()
            self._size = 1 << self._log
            self._data = [self._e] * (self._size << 1)
        else:
            _n_or_a = list(_n_or_a)
            self._n = len(_n_or_a)
            self._log = (self._n - 1).bit_length()
            self._size = 1 << self._log
            _data = [self._e] * (self._size << 1)
            _data[self._size : self._size + self._n] = _n_or_a
            # Fill internal nodes bottom-up from the leaves.
            for i in range(self._size - 1, 0, -1):
                _data[i] = _data[i << 1] + _data[i << 1 | 1]
            self._data = _data

    def set(self, k: int, v: T) -> None:
        """Replace element ``k`` with ``v`` and refresh its ``_log`` ancestors.

        Negative ``k`` counts from the end, like list indexing.
        """
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._n}"
        if k < 0:
            k += self._n
        k += self._size
        self._data[k] = v
        for _ in range(self._log):
            k >>= 1
            self._data[k] = self._data[k << 1] + self._data[k << 1 | 1]

    def get(self, k: int) -> T:
        """Return element ``k``; negative ``k`` counts from the end."""
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._n}"
        if k < 0:
            k += self._n
        return self._data[k + self._size]

    def prod(self, l: int, r: int):
        """Return the sum of the elements in the half-open range ``[l, r)``.

        For ``l == r`` the result is ``e + e``.
        """
        assert (
            0 <= l <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int)"
        l += self._size
        r += self._size
        # Two accumulators keep index order: nodes taken from the left
        # boundary are appended, nodes from the right boundary are prepended.
        left = self._e
        right = self._e
        while l < r:
            if l & 1:
                left = left + self._data[l]
                l += 1
            if r & 1:
                # r is odd here, so r ^ 1 is the node just left of r.
                right = self._data[r ^ 1] + right
            l >>= 1
            r >>= 1
        return left + right

    def all_prod(self):
        """Return the value stored at the root, i.e. the sum of everything."""
        return self._data[1]

    def max_right(self, l: int, f=lambda lr: lr):
        """Binary-search to the right of ``l`` on the tree.

        Returns an index ``r`` in ``[l, n]`` such that ``f`` holds for the sum
        of ``[l, r)`` and, if ``r < n``, fails for the sum of ``[l, r + 1)``.
        ``f`` is expected to be monotone in the usual segment-tree sense and
        ``f(e)`` must be truthy.  The default ``f`` returns its argument
        unchanged, so callers normally pass their own predicate.
        """
        assert (
            0 <= l <= self._n
        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
        if l == self._n:
            return self._n
        l += self._size
        s = self._e
        while True:
            # Climb while l is a left child so the node covers a maximal block.
            while l & 1 == 0:
                l >>= 1
            if not f(s + self._data[l]):
                # The answer lies inside this node: descend to the exact leaf.
                while l < self._size:
                    l <<= 1
                    if f(s + self._data[l]):
                        s = s + self._data[l]
                        l += 1
                return l - self._size
            s = s + self._data[l]
            l += 1
            # l is a power of two: we stepped past the right edge of the tree.
            if l & -l == l:
                break
        return self._n

    def min_left(self, r: int, f=lambda lr: lr):
        """Binary-search to the left of ``r`` on the tree.

        Returns an index ``l`` in ``[0, r]`` such that ``f`` holds for the sum
        of ``[l, r)`` and, if ``l > 0``, fails for the sum of ``[l - 1, r)``.
        ``f`` is expected to be monotone in the usual segment-tree sense and
        ``f(e)`` must be truthy.  The default ``f`` returns its argument
        unchanged, so callers normally pass their own predicate.
        """
        assert (
            0 <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
        if r == 0:
            return 0
        r += self._size
        s = self._e
        while True:
            r -= 1
            # Climb while r is a right child so the node covers a maximal block.
            while r > 1 and r & 1:
                r >>= 1
            if not f(self._data[r] + s):
                # The answer lies inside this node: descend to the exact leaf.
                while r < self._size:
                    r = r << 1 | 1
                    if f(self._data[r] + s):
                        # Prepend: this node lies to the left of what s covers.
                        s = self._data[r] + s
                        r -= 1
                return r + 1 - self._size
            s = self._data[r] + s
            # r is a power of two: we reached the left edge of the tree.
            if r & -r == r:
                break
        return 0

    def tolist(self) -> list[T]:
        """Return a new list with the ``_n`` stored elements in index order."""
        return self._data[self._size : self._size + self._n]

    def show(self) -> None:
        """Print every level of the tree, root first, one level per line."""
        levels = [
            " " + " ".join(map(str, self._data[1 << i : 1 << (i + 1)]))
            for i in range(self._log + 1)
        ]
        print(
            f"<{self.__class__.__name__}> [\n"
            + "\n".join(levels)
            + "\n]"
        )

    def __getitem__(self, k: int) -> T:
        """``tree[k]`` — same as :meth:`get`."""
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}: int), n={self._n}"
        return self.get(k)

    def __setitem__(self, k: int, v: T):
        """``tree[k] = v`` — same as :meth:`set`."""
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__setitem__({k}: int, {v}: T), n={self._n}"
        self.set(k, v)

    def __str__(self):
        """Return the elements formatted as ``[a, b, ...]`` using ``str``."""
        return "[" + ", ".join(map(str, self.tolist())) + "]"

    def __repr__(self):
        """Return ``ClassName([a, b, ...])``."""
        return f"{self.__class__.__name__}({self})"
仕様¶
- class SegmentTreeRSQ(_n_or_a: int | Iterable[T], e: T = 0)[source]¶
Bases: SegmentTreeInterface, Generic[T]
RSQ セグ木です。