Source code for titan_pylib.data_structures.segment_tree.segment_tree_RmQ

  1from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  2    SegmentTreeInterface,
  3)
  4from titan_pylib.my_class.supports_less_than import SupportsLessThan
  5from typing import Generic, Iterable, TypeVar, Union
  6
  7T = TypeVar("T", bound=SupportsLessThan)
  8
  9
class SegmentTreeRmQ(SegmentTreeInterface, Generic[T]):
    """Segment tree specialised for range-minimum queries (RmQ).

    The tree is kept in a flat list ``_data`` of length ``2 * _size`` where
    ``_size`` is a power of two that is at least the element count ``_n``.
    Element ``k`` is stored in the leaf ``_data[_size + k]`` and every inner
    node ``i`` holds the smaller of its children ``2 * i`` and ``2 * i + 1``.
    Leaves beyond ``_n`` are padded with ``e``.
    """

    def __init__(self, _n_or_a: Union[int, Iterable[T]], e: T) -> None:
        """Build the tree.

        Args:
            _n_or_a: Either the number of elements (every element then starts
                as ``e``) or an iterable holding the initial elements.
            e: Padding value and starting accumulator of every query.  It is
                expected to compare greater than or equal to every stored
                value so that it acts as the identity of ``min``.
        """
        self._e = e
        if isinstance(_n_or_a, int):
            leaves = None
            self._n = _n_or_a
        else:
            leaves = list(_n_or_a)
            self._n = len(leaves)
        # Smallest exponent such that 2 ** _log >= _n.
        self._log = (self._n - 1).bit_length()
        self._size = 1 << self._log
        data = [e] * (self._size << 1)
        if leaves is not None:
            data[self._size : self._size + self._n] = leaves
            # Fill the inner nodes bottom-up so children are ready first.
            for i in range(self._size - 1, 0, -1):
                left, right = data[i << 1], data[i << 1 | 1]
                data[i] = left if left < right else right
        self._data = data

    def set(self, k: int, v: T) -> None:
        """Assign ``v`` to position ``k`` and refresh its ancestors.

        A negative ``k`` counts from the end.  The index is validated with an
        assertion.
        """
        if k < 0:
            k += self._n
        assert (
            0 <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.set({k}: int, {v}: T), n={self._n}"
        data = self._data
        k += self._size
        data[k] = v
        # Walk from the leaf to the root, recomputing one node per level.
        for _ in range(self._log):
            k >>= 1
            left, right = data[k << 1], data[k << 1 | 1]
            data[k] = left if left < right else right

    def get(self, k: int) -> T:
        """Return the element at position ``k`` (negative ``k`` counts from the end)."""
        if k < 0:
            k += self._n
        assert (
            0 <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.get({k}: int), n={self._n}"
        return self._data[k + self._size]

    def prod(self, l: int, r: int) -> T:
        """Return the minimum over the half-open range ``[l, r)``.

        The accumulator starts at ``e``, so ``e`` is returned when ``l == r``.
        """
        assert (
            0 <= l <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.prod({l}: int, {r}: int)"
        data = self._data
        l += self._size
        r += self._size
        res = self._e
        while l < r:
            if l & 1:
                # l is a right child: take it and step past it.
                if res > data[l]:
                    res = data[l]
                l += 1
            if r & 1:
                # r is exclusive: step onto the left sibling and take it.
                r ^= 1
                if res > data[r]:
                    res = data[r]
            l >>= 1
            r >>= 1
        return res

    def all_prod(self) -> T:
        """Return the value stored at the root, i.e. the minimum of all elements."""
        return self._data[1]

    def max_right(self, l: int, f=lambda lr: lr) -> int:
        """Search to the right of ``l`` for the longest range accepted by ``f``.

        Args:
            l: Left end of the range, ``0 <= l <= n``.
            f: Predicate applied to the running minimum.  ``f(e)`` must be
                truthy.  The default returns its argument unchanged.

        Returns:
            An index ``r`` with ``l <= r <= n``.  When ``f`` is monotone this
            is the largest ``r`` for which ``f(min of [l, r))`` holds.
        """
        assert (
            0 <= l <= self._n
        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true."
        if l == self._n:
            return self._n
        data, size = self._data, self._size
        l += size
        s = self._e
        while True:
            # Climb while l is a left child so the node covers a maximal block.
            while l & 1 == 0:
                l >>= 1
            cand = min(s, data[l])
            if not f(cand):
                # The answer lies inside this node: descend to the exact leaf.
                while l < size:
                    l <<= 1
                    if f(min(s, data[l])):
                        if s > data[l]:
                            s = data[l]
                        l += 1
                return l - size
            s = cand
            l += 1
            # l being a power of two means the right edge of the tree was passed.
            if l & -l == l:
                break
        return self._n

    def min_left(self, r: int, f=lambda lr: lr) -> int:
        """Search to the left of ``r`` for the longest range accepted by ``f``.

        Args:
            r: Exclusive right end of the range, ``0 <= r <= n``.
            f: Predicate applied to the running minimum.  ``f(e)`` must be
                truthy.  The default returns its argument unchanged.

        Returns:
            An index ``l`` with ``0 <= l <= r``.  When ``f`` is monotone this
            is the smallest ``l`` for which ``f(min of [l, r))`` holds.
        """
        assert (
            0 <= r <= self._n
        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
        assert f(
            self._e
        ), f"{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true."
        if r == 0:
            return 0
        data, size = self._data, self._size
        r += size
        s = self._e
        while True:
            r -= 1
            # Climb while r is a right child so the node covers a maximal block.
            while r > 1 and r & 1:
                r >>= 1
            cand = min(data[r], s)
            if not f(cand):
                # The answer lies inside this node: descend to the exact leaf.
                while r < size:
                    r = r << 1 | 1
                    if f(min(data[r], s)):
                        if s > data[r]:
                            s = data[r]
                        r -= 1
                return r + 1 - size
            s = cand
            # r being a power of two means the left edge of the tree was reached.
            if r & -r == r:
                break
        return 0

    def tolist(self) -> list[T]:
        """Return the stored elements as a new list."""
        return self._data[self._size : self._size + self._n]

    def show(self) -> None:
        """Print every level of the tree, root first, one level per line."""
        rows = [
            " " + " ".join(map(str, self._data[1 << i : 1 << (i + 1)]))
            for i in range(self._log + 1)
        ]
        print(f"<{self.__class__.__name__}> [\n" + "\n".join(rows) + "\n]")

    def __getitem__(self, k: int) -> T:
        """Return the element at ``k``; supports ``-n <= k < n``."""
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}: int), n={self._n}"
        return self.get(k)

    def __setitem__(self, k: int, v: T):
        """Assign ``v`` to position ``k``; supports ``-n <= k < n``."""
        assert (
            -self._n <= k < self._n
        ), f"IndexError: {self.__class__.__name__}.__setitem__({k}: int, {v}: T), n={self._n}"
        self.set(k, v)

    def __str__(self):
        """Return the stored elements formatted like a list."""
        return "[" + ", ".join(map(str, self.tolist())) + "]"

    def __repr__(self):
        """Return ``ClassName([...])`` using the ``__str__`` rendering."""
        return f"{self.__class__.__name__}({self})"