segment_tree

ソースコード

from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree

view on github

展開済みコード

  1# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
  2# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  3#     SegmentTreeInterface,
  4# )
  5from abc import ABC, abstractmethod
  6from typing import TypeVar, Generic, Union, Iterable, Callable
  7
  8T = TypeVar("T")
  9
 10
 11class SegmentTreeInterface(ABC, Generic[T]):
 12
 13    @abstractmethod
 14    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
 15        raise NotImplementedError
 16
 17    @abstractmethod
 18    def set(self, k: int, v: T) -> None:
 19        raise NotImplementedError
 20
 21    @abstractmethod
 22    def get(self, k: int) -> T:
 23        raise NotImplementedError
 24
 25    @abstractmethod
 26    def prod(self, l: int, r: int) -> T:
 27        raise NotImplementedError
 28
 29    @abstractmethod
 30    def all_prod(self) -> T:
 31        raise NotImplementedError
 32
 33    @abstractmethod
 34    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
 35        raise NotImplementedError
 36
 37    @abstractmethod
 38    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
 39        raise NotImplementedError
 40
 41    @abstractmethod
 42    def tolist(self) -> list[T]:
 43        raise NotImplementedError
 44
 45    @abstractmethod
 46    def __getitem__(self, k: int) -> T:
 47        raise NotImplementedError
 48
 49    @abstractmethod
 50    def __setitem__(self, k: int, v: T) -> None:
 51        raise NotImplementedError
 52
 53    @abstractmethod
 54    def __str__(self):
 55        raise NotImplementedError
 56
 57    @abstractmethod
 58    def __repr__(self):
 59        raise NotImplementedError
 60from typing import Generic, Iterable, TypeVar, Callable, Union
 61
 62T = TypeVar("T")
 63
 64
 65class SegmentTree(SegmentTreeInterface, Generic[T]):
 66    """セグ木です。非再帰です。"""
 67
 68    def __init__(
 69        self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
 70    ) -> None:
 71        """``SegmentTree`` を構築します。
 72        :math:`O(n)` です。
 73
 74        Args:
 75            n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
 76                                              ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
 77            op (Callable[[T, T], T]): 2項演算の関数です。
 78            e (T): 単位元です。
 79        """
 80        self._op = op
 81        self._e = e
 82        if isinstance(n_or_a, int):
 83            self._n = n_or_a
 84            self._log = (self._n - 1).bit_length()
 85            self._size = 1 << self._log
 86            self._data = [e] * (self._size << 1)
 87        else:
 88            n_or_a = list(n_or_a)
 89            self._n = len(n_or_a)
 90            self._log = (self._n - 1).bit_length()
 91            self._size = 1 << self._log
 92            _data = [e] * (self._size << 1)
 93            _data[self._size : self._size + self._n] = n_or_a
 94            for i in range(self._size - 1, 0, -1):
 95                _data[i] = op(_data[i << 1], _data[i << 1 | 1])
 96            self._data = _data
 97
 98    def set(self, k: int, v: T) -> None:
 99        """一点更新です。
100        :math:`O(\\log{n})` です。
101
102        Args:
103            k (int): 更新するインデックスです。
104            v (T): 更新する値です。
105
106        制約:
107            :math:`-n \\leq n \\leq k < n`
108        """
109        assert (
110            -self._n <= k < self._n
111        ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
112        if k < 0:
113            k += self._n
114        k += self._size
115        self._data[k] = v
116        for _ in range(self._log):
117            k >>= 1
118            self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
119
120    def get(self, k: int) -> T:
121        """一点取得です。
122        :math:`O(1)` です。
123
124        Args:
125            k (int): インデックスです。
126
127        制約:
128            :math:`-n \\leq n \\leq k < n`
129        """
130        assert (
131            -self._n <= k < self._n
132        ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
133        if k < 0:
134            k += self._n
135        return self._data[k + self._size]
136
137    def prod(self, l: int, r: int) -> T:
138        """区間 ``[l, r)`` の総積を返します。
139        :math:`O(\\log{n})` です。
140
141        Args:
142            l (int): インデックスです。
143            r (int): インデックスです。
144
145        制約:
146            :math:`0 \\leq l \\leq r \\leq n`
147        """
148        assert (
149            0 <= l <= r <= self._n
150        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
151        l += self._size
152        r += self._size
153        lres = self._e
154        rres = self._e
155        while l < r:
156            if l & 1:
157                lres = self._op(lres, self._data[l])
158                l += 1
159            if r & 1:
160                rres = self._op(self._data[r ^ 1], rres)
161            l >>= 1
162            r >>= 1
163        return self._op(lres, rres)
164
165    def all_prod(self) -> T:
166        """区間 ``[0, n)`` の総積を返します。
167        :math:`O(1)` です。
168        """
169        return self._data[1]
170
171    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
172        """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
173        assert (
174            0 <= l <= self._n
175        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
176        # assert f(self._e), \
177        #     f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
178        if l == self._n:
179            return self._n
180        l += self._size
181        s = self._e
182        while True:
183            while l & 1 == 0:
184                l >>= 1
185            if not f(self._op(s, self._data[l])):
186                while l < self._size:
187                    l <<= 1
188                    if f(self._op(s, self._data[l])):
189                        s = self._op(s, self._data[l])
190                        l |= 1
191                return l - self._size
192            s = self._op(s, self._data[l])
193            l += 1
194            if l & -l == l:
195                break
196        return self._n
197
198    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
199        """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
200        assert (
201            0 <= r <= self._n
202        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
203        # assert f(self._e), \
204        #     f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
205        if r == 0:
206            return 0
207        r += self._size
208        s = self._e
209        while True:
210            r -= 1
211            while r > 1 and r & 1:
212                r >>= 1
213            if not f(self._op(self._data[r], s)):
214                while r < self._size:
215                    r = r << 1 | 1
216                    if f(self._op(self._data[r], s)):
217                        s = self._op(self._data[r], s)
218                        r ^= 1
219                return r + 1 - self._size
220            s = self._op(self._data[r], s)
221            if r & -r == r:
222                break
223        return 0
224
225    def tolist(self) -> list[T]:
226        """リストにして返します。
227        :math:`O(n)` です。
228        """
229        return [self.get(i) for i in range(self._n)]
230
231    def show(self) -> None:
232        """デバッグ用のメソッドです。"""
233        print(
234            f"<{self.__class__.__name__}> [\n"
235            + "\n".join(
236                [
237                    "  "
238                    + " ".join(
239                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
240                    )
241                    for i in range(self._log + 1)
242                ]
243            )
244            + "\n]"
245        )
246
247    def __getitem__(self, k: int) -> T:
248        assert (
249            -self._n <= k < self._n
250        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
251        return self.get(k)
252
253    def __setitem__(self, k: int, v: T):
254        assert (
255            -self._n <= k < self._n
256        ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
257        self.set(k, v)
258
259    def __len__(self) -> int:
260        return self._n
261
262    def __str__(self) -> str:
263        return str(self.tolist())
264
265    def __repr__(self) -> str:
266        return f"{self.__class__.__name__}({self})"

仕様

class SegmentTree(n_or_a: int | Iterable[T], op: Callable[[T, T], T], e: T)[source]

Bases: SegmentTreeInterface, Generic[T]

セグ木です。非再帰です。

all_prod() T[source]

区間 [0, n) の総積を返します。 O(1) です。

get(k: int) T[source]

一点取得です。 O(1) です。

Parameters:

k (int) – インデックスです。

制約:

nnk<n

max_right(l: int, f: Callable[[T], bool]) int[source]

Find the largest index R s.t. f([l, R)) == True. / O(log{n})

min_left(r: int, f: Callable[[T], bool]) int[source]

Find the smallest index L s.t. f([L, r)) == True. / O(log{n})

prod(l: int, r: int) T[source]

区間 [l, r) の総積を返します。 O(logn) です。

Parameters:
  • l (int) – インデックスです。

  • r (int) – インデックスです。

制約:

0lrn

set(k: int, v: T) None[source]

一点更新です。 O(logn) です。

Parameters:
  • k (int) – 更新するインデックスです。

  • v (T) – 更新する値です。

制約:

nnk<n

show() None[source]

デバッグ用のメソッドです。

tolist() list[T][source]

リストにして返します。 O(n) です。