dynamic_segment_tree

ソースコード

from titan_pylib.data_structures.segment_tree.dynamic_segment_tree import DynamicSegmentTree

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.segment_tree.dynamic_segment_tree import DynamicSegmentTree
  2# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  3#     SegmentTreeInterface,
  4# )
  5from abc import ABC, abstractmethod
  6from typing import TypeVar, Generic, Union, Iterable, Callable
  7
  8T = TypeVar("T")
  9
 10
 11class SegmentTreeInterface(ABC, Generic[T]):
 12
 13    @abstractmethod
 14    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
 15        raise NotImplementedError
 16
 17    @abstractmethod
 18    def set(self, k: int, v: T) -> None:
 19        raise NotImplementedError
 20
 21    @abstractmethod
 22    def get(self, k: int) -> T:
 23        raise NotImplementedError
 24
 25    @abstractmethod
 26    def prod(self, l: int, r: int) -> T:
 27        raise NotImplementedError
 28
 29    @abstractmethod
 30    def all_prod(self) -> T:
 31        raise NotImplementedError
 32
 33    @abstractmethod
 34    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
 35        raise NotImplementedError
 36
 37    @abstractmethod
 38    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
 39        raise NotImplementedError
 40
 41    @abstractmethod
 42    def tolist(self) -> list[T]:
 43        raise NotImplementedError
 44
 45    @abstractmethod
 46    def __getitem__(self, k: int) -> T:
 47        raise NotImplementedError
 48
 49    @abstractmethod
 50    def __setitem__(self, k: int, v: T) -> None:
 51        raise NotImplementedError
 52
 53    @abstractmethod
 54    def __str__(self):
 55        raise NotImplementedError
 56
 57    @abstractmethod
 58    def __repr__(self):
 59        raise NotImplementedError
 60from typing import Generic, TypeVar, Callable
 61
 62T = TypeVar("T")
 63
 64
 65class DynamicSegmentTree(SegmentTreeInterface, Generic[T]):
 66    """動的セグ木です。"""
 67
 68    def __init__(self, u: int, op: Callable[[T, T], T], e: T):
 69        self._op = op
 70        self._e = e
 71        self._u = u
 72        self._log = (self._u - 1).bit_length()
 73        self._size = 1 << self._log
 74        self._data: dict[int, T] = {}
 75
 76    def set(self, k: int, v: T) -> None:
 77        assert (
 78            -self._u <= k < self._u
 79        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._u}"
 80        if k < 0:
 81            k += self._u
 82        k += self._size
 83        self._data[k] = v
 84        e = self._e
 85        for _ in range(self._log):
 86            k >>= 1
 87            self._data[k] = self._op(
 88                self._data.get(k << 1, e), self._data.get(k << 1 | 1, e)
 89            )
 90
 91    def get(self, k: int) -> T:
 92        assert (
 93            -self._u <= k < self._u
 94        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._u}"
 95        if k < 0:
 96            k += self._u
 97        return self._data.get(k + self._size, self._e)
 98
 99    def prod(self, l: int, r: int) -> T:
100        assert (
101            0 <= l <= r <= self._u
102        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int)"
103        l += self._size
104        r += self._size
105        e = self._e
106        lres = e
107        rres = e
108        while l < r:
109            if l & 1:
110                lres = self._op(lres, self._data.get(l, e))
111                l += 1
112            if r & 1:
113                rres = self._op(self._data.get(r ^ 1, e), rres)
114            l >>= 1
115            r >>= 1
116        return self._op(lres, rres)
117
118    def all_prod(self) -> T:
119        return self._data[1]
120
121    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
122        """Find the largest index R s.t. f([l, R)) == True. / O(logU)"""
123        assert (
124            0 <= l <= self._u
125        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
126        assert f(
127            self._e
128        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
129        if l == self._u:
130            return self._u
131        l += self._size
132        e = self._e
133        s = e
134        while True:
135            while l & 1 == 0:
136                l >>= 1
137            if not f(self._op(s, self._data.get(l, e))):
138                while l < self._size:
139                    l <<= 1
140                    if f(self._op(s, self._data.get(l, e))):
141                        s = self._op(s, self._data.get(l, e))
142                        l |= 1
143                return l - self._size
144            s = self._op(s, self._data.get(l, e))
145            l += 1
146            if l & -l == l:
147                break
148        return self._u
149
150    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
151        """Find the smallest index L s.t. f([L, r)) == True. / O(logU)"""
152        assert (
153            0 <= r <= self._u
154        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
155        assert f(
156            self._e
157        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
158        if r == 0:
159            return 0
160        r += self._size
161        e = self._e
162        s = e
163        while True:
164            r -= 1
165            while r > 1 and r & 1:
166                r >>= 1
167            if not f(self._op(self._data.get(r, e), s)):
168                while r < self._size:
169                    r = r << 1 | 1
170                    if f(self._op(self._data.get(r, e), s)):
171                        s = self._op(self._data.get(r, e), s)
172                        r ^= 1
173                return r + 1 - self._size
174            s = self._op(self._data.get(r, e), s)
175            if r & -r == r:
176                break
177        return 0
178
179    def tolist(self) -> list[T]:
180        return [self.get(i) for i in range(self._u)]
181
182    def __getitem__(self, k: int) -> T:
183        assert (
184            -self._u <= k < self._u
185        ), f"IndexError: {self.__class__.__name__}[{k}]: int), n={self._u}"
186        return self.get(k)
187
188    def __setitem__(self, k: int, v: T) -> None:
189        assert (
190            -self._u <= k < self._u
191        ), f"IndexError: {self.__class__.__name__}.__setitem__{k}: int, {v}: T), n={self._u}"
192        self.set(k, v)
193
194    def __str__(self) -> str:
195        return str(self.tolist())
196
197    def __repr__(self) -> str:
198        return f"{self.__class__.__name__}({self})"

仕様

class DynamicSegmentTree(u: int, op: Callable[[T, T], T], e: T) [source]

Bases: SegmentTreeInterface, Generic[T]

動的セグ木です。

all_prod() → T [source]
get(k: int) → T [source]
max_right(l: int, f: Callable[[T], bool]) → int [source]

Find the largest index R s.t. f([l, R)) == True. / O(logU)

min_left(r: int, f: Callable[[T], bool]) → int [source]

Find the smallest index L s.t. f([L, r)) == True. / O(logU)

prod(l: int, r: int) → T [source]
set(k: int, v: T) → None [source]
tolist() → list[T] [source]