lazy_segment_tree

ソースコード

from titan_pylib.data_structures.segment_tree.lazy_segment_tree import LazySegmentTree

view on github

展開済みコード

  1# from titan_pylib.data_structures.segment_tree.lazy_segment_tree import LazySegmentTree
  2from typing import Union, Callable, TypeVar, Generic, Iterable
  3
  4T = TypeVar("T")
  5F = TypeVar("F")
  6
  7
  8class LazySegmentTree(Generic[T, F]):
  9    """遅延セグ木です。"""
 10
 11    def __init__(
 12        self,
 13        n_or_a: Union[int, Iterable[T]],
 14        op: Callable[[T, T], T],
 15        mapping: Callable[[F, T], T],
 16        composition: Callable[[F, F], F],
 17        e: T,
 18        id: F,
 19    ) -> None:
 20        self.op: Callable[[T, T], T] = op
 21        self.mapping: Callable[[F, T], T] = mapping
 22        self.composition: Callable[[F, F], F] = composition
 23        self.e: T = e
 24        self.id: F = id
 25        if isinstance(n_or_a, int):
 26            self.n = n_or_a
 27            self.log = (self.n - 1).bit_length()
 28            self.size = 1 << self.log
 29            self.data = [e] * (self.size << 1)
 30        else:
 31            a = list(n_or_a)
 32            self.n = len(a)
 33            self.log = (self.n - 1).bit_length()
 34            self.size = 1 << self.log
 35            data = [e] * (self.size << 1)
 36            data[self.size : self.size + self.n] = a
 37            for i in range(self.size - 1, 0, -1):
 38                data[i] = op(data[i << 1], data[i << 1 | 1])
 39            self.data = data
 40        self.lazy = [id] * self.size
 41
 42    def _update(self, k: int) -> None:
 43        self.data[k] = self.op(self.data[k << 1], self.data[k << 1 | 1])
 44
 45    def _all_apply(self, k: int, f: F) -> None:
 46        self.data[k] = self.mapping(f, self.data[k])
 47        if k >= self.size:
 48            return
 49        self.lazy[k] = self.composition(f, self.lazy[k])
 50
 51    def _propagate(self, k: int) -> None:
 52        if self.lazy[k] == self.id:
 53            return
 54        self._all_apply(k << 1, self.lazy[k])
 55        self._all_apply(k << 1 | 1, self.lazy[k])
 56        self.lazy[k] = self.id
 57
 58    def apply_point(self, k: int, f: F) -> None:
 59        k += self.size
 60        for i in range(self.log, 0, -1):
 61            self._propagate(k >> i)
 62        self.data[k] = self.mapping(f, self.data[k])
 63        for i in range(1, self.log + 1):
 64            self._update(k >> i)
 65
 66    def _upper_propagate(self, l: int, r: int) -> None:
 67        for i in range(self.log, 0, -1):
 68            if l >> i << i != l:
 69                self._propagate(l >> i)
 70            if (r >> i << i != r) and (l >> i != (r - 1) >> i or l >> i << i == l):
 71                self._propagate((r - 1) >> i)
 72
 73    def apply(self, l: int, r: int, f: F) -> None:
 74        assert (
 75            0 <= l <= r <= self.n
 76        ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), n={self.n}"
 77        if l == r:
 78            return
 79        if f == self.id:
 80            return
 81        l += self.size
 82        r += self.size
 83        self._upper_propagate(l, r)
 84        l2, r2 = l, r
 85        while l < r:
 86            if l & 1:
 87                self._all_apply(l, f)
 88                l += 1
 89            if r & 1:
 90                self._all_apply(r ^ 1, f)
 91            l >>= 1
 92            r >>= 1
 93        ll, rr = l2, r2 - 1
 94        for i in range(1, self.log + 1):
 95            ll >>= 1
 96            rr >>= 1
 97            if ll << i != l2:
 98                self._update(ll)
 99            if (ll << i == l2 or ll != rr) and (r2 >> i << i != r2):
100                self._update(rr)
101
102    def all_apply(self, f: F) -> None:
103        self.lazy[1] = self.composition(f, self.lazy[1])
104
105    def prod(self, l: int, r: int) -> T:
106        assert (
107            0 <= l <= r <= self.n
108        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), n={self.n}"
109        if l == r:
110            return self.e
111        l += self.size
112        r += self.size
113        self._upper_propagate(l, r)
114        lres = self.e
115        rres = self.e
116        while l < r:
117            if l & 1:
118                lres = self.op(lres, self.data[l])
119                l += 1
120            if r & 1:
121                rres = self.op(self.data[r ^ 1], rres)
122            l >>= 1
123            r >>= 1
124        return self.op(lres, rres)
125
126    def all_prod(self) -> T:
127        return self.data[1]
128
129    def all_propagate(self) -> None:
130        for i in range(self.size):
131            self._propagate(i)
132
133    def tolist(self) -> list[T]:
134        self.all_propagate()
135        return self.data[self.size : self.size + self.n]
136
137    def max_right(self, l, f) -> int:
138        assert 0 <= l <= self.n
139        # assert f(self.e)
140        if l == self.size:
141            return self.n
142        l += self.size
143        for i in range(self.log, 0, -1):
144            self._propagate(l >> i)
145        s = self.e
146        while True:
147            while l & 1 == 0:
148                l >>= 1
149            if not f(self.op(s, self.data[l])):
150                while l < self.size:
151                    self._propagate(l)
152                    l <<= 1
153                    if f(self.op(s, self.data[l])):
154                        s = self.op(s, self.data[l])
155                        l |= 1
156                return l - self.size
157            s = self.op(s, self.data[l])
158            l += 1
159            if l & -l == l:
160                break
161        return self.n
162
163    def min_left(self, r: int, f) -> int:
164        assert 0 <= r <= self.n
165        # assert f(self.e)
166        if r == 0:
167            return 0
168        r += self.size
169        for i in range(self.log, 0, -1):
170            self._propagate((r - 1) >> i)
171        s = self.e
172        while True:
173            r -= 1
174            while r > 1 and r & 1:
175                r >>= 1
176            if not f(self.op(self.data[r], s)):
177                while r < self.size:
178                    self._propagate(r)
179                    r = r << 1 | 1
180                    if f(self.op(self.data[r], s)):
181                        s = self.op(self.data[r], s)
182                        r ^= 1
183                return r + 1 - self.size
184            s = self.op(self.data[r], s)
185            if r & -r == r:
186                break
187        return 0
188
189    def __getitem__(self, k: int) -> T:
190        assert (
191            -self.n <= k < self.n
192        ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}"
193        if k < 0:
194            k += self.n
195        k += self.size
196        for i in range(self.log, 0, -1):
197            self._propagate(k >> i)
198        return self.data[k]
199
200    def __setitem__(self, k: int, v: T):
201        assert (
202            -self.n <= k < self.n
203        ), f"IndexError: {self.__class__.__name__}[{k}] = {v}, n={self.n}"
204        if k < 0:
205            k += self.n
206        k += self.size
207        for i in range(self.log, 0, -1):
208            self._propagate(k >> i)
209        self.data[k] = v
210        for i in range(1, self.log + 1):
211            self._update(k >> i)
212
213    def __str__(self) -> str:
214        return str(self.tolist())
215
216    def __repr__(self):
217        return f"{self.__class__.__name__}({self})"

仕様

class LazySegmentTree(n_or_a: int | Iterable[T], op: Callable[[T, T], T], mapping: Callable[[F, T], T], composition: Callable[[F, F], F], e: T, id: F)[source]

Bases: Generic[T, F]

遅延セグ木です。

all_apply(f: F) → None[source]
all_prod() → T[source]
all_propagate() → None[source]
apply(l: int, r: int, f: F) → None[source]
apply_point(k: int, f: F) → None[source]
max_right(l, f) → int[source]
min_left(r: int, f) → int[source]
prod(l: int, r: int) → T[source]
tolist() → list[T][source]