1from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
2 SegmentTreeInterface,
3)
4from titan_pylib.my_class.supports_less_than import SupportsLessThan
5from typing import Generic, Iterable, TypeVar, Union
6
# Element type stored in SegmentTreeRmQ. It is bound to SupportsLessThan
# because the tree only ever orders elements (via `<` / `min`) to take minima.
T = TypeVar("T", bound=SupportsLessThan)
8
9
[docs]
class SegmentTreeRmQ(SegmentTreeInterface, Generic[T]):
    """Range-minimum-query (RmQ) segment tree.

    Supports point assignment, range-minimum queries and the two binary
    searches ``max_right`` / ``min_left``, each in O(log n).

    Layout: a flat list ``_data`` of length ``2 * _size`` used as a 1-indexed
    implicit binary tree. Node ``i`` has children ``2i`` and ``2i + 1``, the
    leaves are ``_data[_size:_size + _n]``, and any unused leaves hold ``e``.

    Elements are combined only with ``min`` / ``<``. Unused leaves are padded
    with ``e``, so ``e`` should compare no smaller than any stored value
    (e.g. ``float("inf")``). ``e`` is also what ``prod`` returns for an empty
    range.
    """

    def __init__(self, _n_or_a: Union[int, Iterable[T]], e: T) -> None:
        """Build the tree.

        Args:
            _n_or_a: Either the number of elements ``n`` (every position starts
                as ``e``) or an iterable of initial values.
            e: Identity for ``min``; see the class docstring.
        """
        self._e = e
        if isinstance(_n_or_a, int):
            a = None
            n = _n_or_a
        else:
            a = list(_n_or_a)
            n = len(a)
        self._n = n
        # Round the leaf count up to a power of two (n == 0 still gives 2 leaves).
        self._log = (n - 1).bit_length()
        self._size = size = 1 << self._log
        data = [e] * (size << 1)
        if a is not None:
            data[size : size + n] = a
            # Fill internal nodes bottom-up; each holds the min of its children.
            for i in range(size - 1, 0, -1):
                left, right = data[i << 1], data[i << 1 | 1]
                data[i] = left if left < right else right
        self._data = data

    def set(self, k: int, v: T) -> None:
        """Assign ``v`` to position ``k``; negative ``k`` counts from the end."""
        if k < 0:
            k += self._n
        assert (
            0 <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._n}"
        data = self._data
        k += self._size
        data[k] = v
        # Walk up to the root, recomputing every ancestor's minimum.
        for _ in range(self._log):
            k >>= 1
            left, right = data[k << 1], data[k << 1 | 1]
            data[k] = left if left < right else right

    def get(self, k: int) -> T:
        """Return the value at position ``k``; negative ``k`` counts from the end."""
        if k < 0:
            k += self._n
        assert (
            0 <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._n}"
        return self._data[k + self._size]

    def prod(self, l: int, r: int) -> T:
        """Return the minimum over positions ``[l, r)``, or ``e`` if ``l == r``."""
        assert (
            0 <= l <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int)"
        data = self._data
        l += self._size
        r += self._size
        res = self._e
        # Standard bottom-up traversal of the half-open range [l, r).
        while l < r:
            if l & 1:
                if data[l] < res:
                    res = data[l]
                l += 1
            if r & 1:
                r ^= 1  # r is odd here, so this is r - 1
                if data[r] < res:
                    res = data[r]
            l >>= 1
            r >>= 1
        return res

    def all_prod(self) -> T:
        """Return the minimum over the whole array (the root node)."""
        return self._data[1]

    def max_right(self, l: int, f=lambda lr: lr) -> int:
        """Binary search to the right of ``l``.

        Assuming ``f`` is monotone (once false, it stays false as the range
        grows), return the largest ``r`` such that ``f(prod(l, r))`` is truthy.
        ``f(e)`` must be truthy.
        """
        assert (
            0 <= l <= self._n
        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
        if l == self._n:
            return self._n
        data, size = self._data, self._size
        l += size
        s = self._e
        while True:
            # Climb while l is a left child: the parent starts at the same leaf.
            while (l & 1) == 0:
                l >>= 1
            if not f(min(s, data[l])):
                # The boundary lies inside node l: descend, greedily absorbing
                # left children that still satisfy f.
                while l < size:
                    l <<= 1
                    if f(min(s, data[l])):
                        if data[l] < s:
                            s = data[l]
                        l += 1
                return l - size
            s = min(s, data[l])
            l += 1
            # l is now a power of two: everything through the last leaf is covered.
            if l & -l == l:
                break
        return self._n

    def min_left(self, r: int, f=lambda lr: lr) -> int:
        """Binary search to the left of ``r``.

        Assuming ``f`` is monotone (once false, it stays false as the range
        grows), return the smallest ``l`` such that ``f(prod(l, r))`` is truthy.
        ``f(e)`` must be truthy.
        """
        assert (
            0 <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
        if r == 0:
            return 0
        data, size = self._data, self._size
        r += size
        s = self._e
        while True:
            r -= 1
            # Climb while r is a right child (but never above the root).
            while r > 1 and r & 1:
                r >>= 1
            if not f(min(data[r], s)):
                # The boundary lies inside node r: descend, greedily absorbing
                # right children that still satisfy f.
                while r < size:
                    r = r << 1 | 1
                    if f(min(data[r], s)):
                        if data[r] < s:
                            s = data[r]
                        r -= 1
                return r + 1 - size
            s = min(data[r], s)
            # r is a power of two: everything back to the first leaf is covered.
            if r & -r == r:
                break
        return 0

    def tolist(self) -> list[T]:
        """Return the current leaf values as a new list of length ``n``."""
        return [self.get(i) for i in range(self._n)]

    def show(self) -> None:
        """Print every level of the tree, root first (debugging aid)."""
        data = self._data
        # Level i occupies indices [2**i, 2**(i + 1)) of _data.
        rows = [
            " " + " ".join(map(str, data[1 << i : 2 << i]))
            for i in range(self._log + 1)
        ]
        print(f"<{self.__class__.__name__}> [\n" + "\n".join(rows) + "\n]")

    def __getitem__(self, k: int) -> T:
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}: int), n={self._n}"
        return self.get(k)

    def __setitem__(self, k: int, v: T) -> None:
        # Message format now matches set()/get(); the "(" was previously missing.
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__setitem__({k}: int, {v}: T), n={self._n}"
        self.set(k, v)

    def __str__(self) -> str:
        return "[" + ", ".join(map(str, self.tolist())) + "]"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"