segment_tree_RSQ

ソースコード

from titan_pylib.data_structures.segment_tree.segment_tree_RSQ import SegmentTreeRSQ

view on github

展開済みコード

  1# from titan_pylib.data_structures.segment_tree.segment_tree_RSQ import SegmentTreeRSQ
  2# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  3#     SegmentTreeInterface,
  4# )
  5from abc import ABC, abstractmethod
  6from typing import TypeVar, Generic, Union, Iterable, Callable
  7
  8T = TypeVar("T")
  9
 10
 11class SegmentTreeInterface(ABC, Generic[T]):
 12
 13    @abstractmethod
 14    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
 15        raise NotImplementedError
 16
 17    @abstractmethod
 18    def set(self, k: int, v: T) -> None:
 19        raise NotImplementedError
 20
 21    @abstractmethod
 22    def get(self, k: int) -> T:
 23        raise NotImplementedError
 24
 25    @abstractmethod
 26    def prod(self, l: int, r: int) -> T:
 27        raise NotImplementedError
 28
 29    @abstractmethod
 30    def all_prod(self) -> T:
 31        raise NotImplementedError
 32
 33    @abstractmethod
 34    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
 35        raise NotImplementedError
 36
 37    @abstractmethod
 38    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
 39        raise NotImplementedError
 40
 41    @abstractmethod
 42    def tolist(self) -> list[T]:
 43        raise NotImplementedError
 44
 45    @abstractmethod
 46    def __getitem__(self, k: int) -> T:
 47        raise NotImplementedError
 48
 49    @abstractmethod
 50    def __setitem__(self, k: int, v: T) -> None:
 51        raise NotImplementedError
 52
 53    @abstractmethod
 54    def __str__(self):
 55        raise NotImplementedError
 56
 57    @abstractmethod
 58    def __repr__(self):
 59        raise NotImplementedError
 60# from titan_pylib.my_class.supports_add import SupportsAdd
 61from typing import Protocol
 62
 63
 64class SupportsAdd(Protocol):
 65
 66    def __add__(self, other): ...
 67    def __iadd__(self, other): ...
 68    def __radd__(self, other): ...
 69from typing import Generic, Iterable, TypeVar, Union
 70
 71T = TypeVar("T", bound=SupportsAdd)
 72
 73
 74class SegmentTreeRSQ(SegmentTreeInterface, Generic[T]):
 75    """RSQ セグ木です。"""
 76
 77    def __init__(self, _n_or_a: Union[int, Iterable[T]], e: T = 0) -> None:
 78        self._e = e
 79        if isinstance(_n_or_a, int):
 80            self._n = _n_or_a
 81            self._log = (self._n - 1).bit_length()
 82            self._size = 1 << self._log
 83            self._data = [self._e] * (self._size << 1)
 84        else:
 85            _n_or_a = list(_n_or_a)
 86            self._n = len(_n_or_a)
 87            self._log = (self._n - 1).bit_length()
 88            self._size = 1 << self._log
 89            _data = [self._e] * (self._size << 1)
 90            _data[self._size : self._size + self._n] = _n_or_a
 91            for i in range(self._size - 1, 0, -1):
 92                _data[i] = _data[i << 1] + _data[i << 1 | 1]
 93            self._data = _data
 94
 95    def set(self, k: int, v: T) -> None:
 96        assert (
 97            -self._n <= k < self._n
 98        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._n}"
 99        if k < 0:
100            k += self._n
101        k += self._size
102        self._data[k] = v
103        for _ in range(self._log):
104            k >>= 1
105            self._data[k] = self._data[k << 1] + self._data[k << 1 | 1]
106
107    def get(self, k: int) -> T:
108        assert (
109            -self._n <= k < self._n
110        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._n}"
111        if k < 0:
112            k += self._n
113        return self._data[k + self._size]
114
115    def prod(self, l: int, r: int):
116        assert (
117            0 <= l <= r <= self._n
118        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int)"
119        l += self._size
120        r += self._size
121        res = self._e
122        while l < r:
123            if l & 1:
124                res += self._data[l]
125                l += 1
126            if r & 1:
127                res += self._data[r ^ 1]
128            l >>= 1
129            r >>= 1
130        return res
131
132    def all_prod(self):
133        return self._data[1]
134
135    def max_right(self, l: int, f=lambda lr: lr):
136        assert (
137            0 <= l <= self._n
138        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
139        assert f(
140            self._e
141        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
142        if l == self._n:
143            return self._n
144        l += self._size
145        s = self._e
146        while True:
147            while l & 1 == 0:
148                l >>= 1
149            if not f(s + self._data[l]):
150                while l < self._size:
151                    l <<= 1
152                    if f(s + self._data[l]):
153                        s += self._data[l]
154                        l += 1
155                return l - self._size
156            s += self._data[l]
157            l += 1
158            if l & -l == l:
159                break
160        return self._n
161
162    def min_left(self, r: int, f=lambda lr: lr):
163        assert (
164            0 <= r <= self._n
165        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
166        assert f(
167            self._e
168        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
169        if r == 0:
170            return 0
171        r += self._size
172        s = self._e
173        while True:
174            r -= 1
175            while r > 1 and r & 1:
176                r >>= 1
177            if not f(self._data[r] + s):
178                while r < self._size:
179                    r = r << 1 | 1
180                    if f(self._data[r] + s):
181                        s += self._data[r]
182                        r -= 1
183                return r + 1 - self._size
184            s += self._data[r]
185            if r & -r == r:
186                break
187        return 0
188
189    def tolist(self) -> list[T]:
190        return [self.get(i) for i in range(self._n)]
191
192    def show(self) -> None:
193        print(
194            f"<{self.__class__.__name__}> [\n"
195            + "\n".join(
196                [
197                    "  "
198                    + " ".join(
199                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
200                    )
201                    for i in range(self._log + 1)
202                ]
203            )
204            + "\n]"
205        )
206
207    def __getitem__(self, k: int) -> T:
208        assert (
209            -self._n <= k < self._n
210        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}: int), n={self._n}"
211        return self.get(k)
212
213    def __setitem__(self, k: int, v: T):
214        assert (
215            -self._n <= k < self._n
216        ), f"IndexError: {self.__class__.__name__}.__setitem__{k}: int, {v}: T), n={self._n}"
217        self.set(k, v)
218
219    def __str__(self):
220        return "[" + ", ".join(map(str, (self.get(i) for i in range(self._n)))) + "]"
221
222    def __repr__(self):
223        return f"{self.__class__.__name__}({self})"

仕様

class SegmentTreeRSQ(_n_or_a: int | Iterable[T], e: T = 0)[source]

Bases: SegmentTreeInterface, Generic[T]

RSQ セグ木です。

all_prod()[source]
get(k: int) → T[source]
max_right(l: int, f=<function SegmentTreeRSQ.<lambda>>)[source]
min_left(r: int, f=<function SegmentTreeRSQ.<lambda>>)[source]
prod(l: int, r: int)[source]
set(k: int, v: T) → None[source]
show() → None[source]
tolist() → list[T][source]