1from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
2 SegmentTreeInterface,
3)
4from titan_pylib.my_class.supports_add import SupportsAdd
5from typing import Generic, Iterable, TypeVar, Union
6
# Element type stored in the segment tree; it must support `+` (bound to SupportsAdd).
T = TypeVar("T", bound=SupportsAdd)
8
9
class SegmentTreeRSQ(SegmentTreeInterface, Generic[T]):
    """Segment tree for range sum queries (RSQ).

    Elements are combined with ``+`` strictly in index order (the left operand
    always covers smaller indices). So ``+`` only needs to be associative, not
    commutative. ``e`` pads the unused leaves and seeds every fold. It must
    therefore be an identity for ``+`` (``e + x == x + e == x``), e.g. ``0``
    for numbers.

    Partial results are always rebuilt with ``a = a + b`` instead of
    ``a += b``. For mutable element types (``list``, NumPy arrays, ...) the
    augmented form works in place and would silently modify ``e``.
    """

    def __init__(self, _n_or_a: Union[int, Iterable[T]], e: T = 0) -> None:
        """Build the tree.

        Args:
            _n_or_a: Either the number of elements (all initialised to ``e``)
                or an iterable of initial values.
            e: Identity element of ``+``. Defaults to ``0``.
        """
        self._e = e
        if isinstance(_n_or_a, int):
            self._n = _n_or_a
            self._log = (self._n - 1).bit_length()
            self._size = 1 << self._log
            self._data = [self._e] * (self._size << 1)
            return
        values = list(_n_or_a)
        self._n = len(values)
        self._log = (self._n - 1).bit_length()
        self._size = 1 << self._log
        data = [self._e] * (self._size << 1)
        # Leaves live at [size, size + n); the rest of the leaf row stays e.
        data[self._size : self._size + self._n] = values
        # Internal node i holds data[2i] + data[2i + 1]; fill bottom-up.
        for i in range(self._size - 1, 0, -1):
            data[i] = data[i << 1] + data[i << 1 | 1]
        self._data = data

    def set(self, k: int, v: T) -> None:
        """Assign ``v`` to position ``k`` and recompute its ancestors.

        Negative ``k`` counts from the end. O(log n) additions.
        """
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._n}"
        if k < 0:
            k += self._n
        data = self._data
        k += self._size
        data[k] = v
        for _ in range(self._log):
            k >>= 1
            data[k] = data[k << 1] + data[k << 1 | 1]

    def get(self, k: int) -> T:
        """Return the value at position ``k`` (negative counts from the end)."""
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._n}"
        if k < 0:
            k += self._n
        return self._data[k + self._size]

    def prod(self, l: int, r: int) -> T:
        """Return ``a[l] + a[l+1] + ... + a[r-1]`` (``e`` when ``l == r``)."""
        assert (
            0 <= l <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int)"
        data = self._data
        l += self._size
        r += self._size
        # Keep the left and right partial sums apart so the final result is
        # added in index order even when `+` is not commutative.
        sml = self._e
        smr = self._e
        while l < r:
            if l & 1:
                sml = sml + data[l]
                l += 1
            if r & 1:
                r -= 1
                smr = data[r] + smr
            l >>= 1
            r >>= 1
        return sml + smr

    def all_prod(self) -> T:
        """Return the sum of all elements (the root node)."""
        return self._data[1]

    def max_right(self, l: int, f=lambda lr: lr) -> int:
        """Binary-search rightwards from ``l``.

        Returns an ``r`` with ``l <= r <= n`` such that ``f(prod(l, r))`` is
        true and either ``r == n`` or ``f(prod(l, r + 1))`` is false.
        ``f(e)`` must be true. ``f`` is expected to be monotone (true for a
        prefix of candidate ``r`` values), since the descent relies on it.
        """
        assert (
            0 <= l <= self._n
        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
        if l == self._n:
            return self._n
        data = self._data
        size = self._size
        l += size
        s = self._e
        while True:
            # Climb while l is a left child: its parent covers the same start.
            while l & 1 == 0:
                l >>= 1
            nxt = s + data[l]
            if not f(nxt):
                # Descend into this node to find the exact boundary.
                while l < size:
                    l <<= 1
                    nxt = s + data[l]
                    if f(nxt):
                        s = nxt
                        l += 1
                return l - size
            s = nxt
            l += 1
            # l became a power of two: everything up to the end was consumed.
            if l & -l == l:
                break
        return self._n

    def min_left(self, r: int, f=lambda lr: lr) -> int:
        """Binary-search leftwards from ``r``.

        Returns an ``l`` with ``0 <= l <= r`` such that ``f(prod(l, r))`` is
        true and either ``l == 0`` or ``f(prod(l - 1, r))`` is false.
        ``f(e)`` must be true. ``f`` is expected to be monotone, since the
        descent relies on it.
        """
        assert (
            0 <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
        if r == 0:
            return 0
        data = self._data
        size = self._size
        r += size
        s = self._e
        while True:
            r -= 1
            # Climb while r is a right child: its parent covers the same end.
            while r > 1 and r & 1:
                r >>= 1
            # The new node lies to the LEFT of everything accumulated so far.
            nxt = data[r] + s
            if not f(nxt):
                while r < size:
                    r = r << 1 | 1
                    nxt = data[r] + s
                    if f(nxt):
                        s = nxt
                        r -= 1
                return r + 1 - size
            s = nxt
            # r is a power of two: everything down to index 0 was consumed.
            if r & -r == r:
                break
        return 0

    def tolist(self) -> list[T]:
        """Return the ``n`` stored values as a new list."""
        return self._data[self._size : self._size + self._n]

    def show(self) -> None:
        """Print every level of the tree, root first (debugging aid)."""
        rows = [
            " " + " ".join(map(str, self._data[1 << i : 2 << i]))
            for i in range(self._log + 1)
        ]
        print(f"<{self.__class__.__name__}> [\n" + "\n".join(rows) + "\n]")

    def __getitem__(self, k: int) -> T:
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}: int), n={self._n}"
        return self.get(k)

    def __setitem__(self, k: int, v: T) -> None:
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__setitem__({k}: int, {v}: T), n={self._n}"
        self.set(k, v)

    def __str__(self) -> str:
        return "[" + ", ".join(map(str, self.tolist())) + "]"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"