Source code for titan_pylib.data_structures.segment_tree.lazy_segment_tree

  1from typing import Union, Callable, TypeVar, Generic, Iterable
  2
  3T = TypeVar("T")
  4F = TypeVar("F")
  5
  6
[docs] 7class LazySegmentTree(Generic[T, F]): 8 """遅延セグ木です。""" 9 10 def __init__( 11 self, 12 n_or_a: Union[int, Iterable[T]], 13 op: Callable[[T, T], T], 14 mapping: Callable[[F, T], T], 15 composition: Callable[[F, F], F], 16 e: T, 17 id: F, 18 ) -> None: 19 self.op: Callable[[T, T], T] = op 20 self.mapping: Callable[[F, T], T] = mapping 21 self.composition: Callable[[F, F], F] = composition 22 self.e: T = e 23 self.id: F = id 24 if isinstance(n_or_a, int): 25 self.n = n_or_a 26 self.log = (self.n - 1).bit_length() 27 self.size = 1 << self.log 28 self.data = [e] * (self.size << 1) 29 else: 30 a = list(n_or_a) 31 self.n = len(a) 32 self.log = (self.n - 1).bit_length() 33 self.size = 1 << self.log 34 data = [e] * (self.size << 1) 35 data[self.size : self.size + self.n] = a 36 for i in range(self.size - 1, 0, -1): 37 data[i] = op(data[i << 1], data[i << 1 | 1]) 38 self.data = data 39 self.lazy = [id] * self.size 40 41 def _update(self, k: int) -> None: 42 self.data[k] = self.op(self.data[k << 1], self.data[k << 1 | 1]) 43 44 def _all_apply(self, k: int, f: F) -> None: 45 self.data[k] = self.mapping(f, self.data[k]) 46 if k >= self.size: 47 return 48 self.lazy[k] = self.composition(f, self.lazy[k]) 49 50 def _propagate(self, k: int) -> None: 51 if self.lazy[k] == self.id: 52 return 53 self._all_apply(k << 1, self.lazy[k]) 54 self._all_apply(k << 1 | 1, self.lazy[k]) 55 self.lazy[k] = self.id 56
[docs] 57 def apply_point(self, k: int, f: F) -> None: 58 k += self.size 59 for i in range(self.log, 0, -1): 60 self._propagate(k >> i) 61 self.data[k] = self.mapping(f, self.data[k]) 62 for i in range(1, self.log + 1): 63 self._update(k >> i)
64 65 def _upper_propagate(self, l: int, r: int) -> None: 66 for i in range(self.log, 0, -1): 67 if l >> i << i != l: 68 self._propagate(l >> i) 69 if (r >> i << i != r) and (l >> i != (r - 1) >> i or l >> i << i == l): 70 self._propagate((r - 1) >> i) 71
[docs] 72 def apply(self, l: int, r: int, f: F) -> None: 73 assert ( 74 0 <= l <= r <= self.n 75 ), f"IndexError: {self.__class__.__name__}.apply({l}, {r}, {f}), n={self.n}" 76 if l == r: 77 return 78 if f == self.id: 79 return 80 l += self.size 81 r += self.size 82 self._upper_propagate(l, r) 83 l2, r2 = l, r 84 while l < r: 85 if l & 1: 86 self._all_apply(l, f) 87 l += 1 88 if r & 1: 89 self._all_apply(r ^ 1, f) 90 l >>= 1 91 r >>= 1 92 ll, rr = l2, r2 - 1 93 for i in range(1, self.log + 1): 94 ll >>= 1 95 rr >>= 1 96 if ll << i != l2: 97 self._update(ll) 98 if (ll << i == l2 or ll != rr) and (r2 >> i << i != r2): 99 self._update(rr)
100
[docs] 101 def all_apply(self, f: F) -> None: 102 self.lazy[1] = self.composition(f, self.lazy[1])
103
[docs] 104 def prod(self, l: int, r: int) -> T: 105 assert ( 106 0 <= l <= r <= self.n 107 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r}), n={self.n}" 108 if l == r: 109 return self.e 110 l += self.size 111 r += self.size 112 self._upper_propagate(l, r) 113 lres = self.e 114 rres = self.e 115 while l < r: 116 if l & 1: 117 lres = self.op(lres, self.data[l]) 118 l += 1 119 if r & 1: 120 rres = self.op(self.data[r ^ 1], rres) 121 l >>= 1 122 r >>= 1 123 return self.op(lres, rres)
124
[docs] 125 def all_prod(self) -> T: 126 return self.data[1]
127
[docs] 128 def all_propagate(self) -> None: 129 for i in range(self.size): 130 self._propagate(i)
131
[docs] 132 def tolist(self) -> list[T]: 133 self.all_propagate() 134 return self.data[self.size : self.size + self.n]
135
[docs] 136 def max_right(self, l, f) -> int: 137 assert 0 <= l <= self.n 138 # assert f(self.e) 139 if l == self.size: 140 return self.n 141 l += self.size 142 for i in range(self.log, 0, -1): 143 self._propagate(l >> i) 144 s = self.e 145 while True: 146 while l & 1 == 0: 147 l >>= 1 148 if not f(self.op(s, self.data[l])): 149 while l < self.size: 150 self._propagate(l) 151 l <<= 1 152 if f(self.op(s, self.data[l])): 153 s = self.op(s, self.data[l]) 154 l |= 1 155 return l - self.size 156 s = self.op(s, self.data[l]) 157 l += 1 158 if l & -l == l: 159 break 160 return self.n
161
[docs] 162 def min_left(self, r: int, f) -> int: 163 assert 0 <= r <= self.n 164 # assert f(self.e) 165 if r == 0: 166 return 0 167 r += self.size 168 for i in range(self.log, 0, -1): 169 self._propagate((r - 1) >> i) 170 s = self.e 171 while True: 172 r -= 1 173 while r > 1 and r & 1: 174 r >>= 1 175 if not f(self.op(self.data[r], s)): 176 while r < self.size: 177 self._propagate(r) 178 r = r << 1 | 1 179 if f(self.op(self.data[r], s)): 180 s = self.op(self.data[r], s) 181 r ^= 1 182 return r + 1 - self.size 183 s = self.op(self.data[r], s) 184 if r & -r == r: 185 break 186 return 0
187 188 def __getitem__(self, k: int) -> T: 189 assert ( 190 -self.n <= k < self.n 191 ), f"IndexError: {self.__class__.__name__}[{k}], n={self.n}" 192 if k < 0: 193 k += self.n 194 k += self.size 195 for i in range(self.log, 0, -1): 196 self._propagate(k >> i) 197 return self.data[k] 198 199 def __setitem__(self, k: int, v: T): 200 assert ( 201 -self.n <= k < self.n 202 ), f"IndexError: {self.__class__.__name__}[{k}] = {v}, n={self.n}" 203 if k < 0: 204 k += self.n 205 k += self.size 206 for i in range(self.log, 0, -1): 207 self._propagate(k >> i) 208 self.data[k] = v 209 for i in range(1, self.log + 1): 210 self._update(k >> i) 211 212 def __str__(self) -> str: 213 return str(self.tolist()) 214 215 def __repr__(self): 216 return f"{self.__class__.__name__}({self})"