Source code for titan_pylib.data_structures.segment_tree.range_set_range_composite

 1from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
 2from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet
 3from typing import Union, Callable, TypeVar, Generic, Iterable
 4
 5T = TypeVar("T")
 6
 7
[docs] 8class RangeSetRangeComposite(Generic[T]): 9 """区間更新+区間積です。""" 10 11 def __init__( 12 self, 13 n_or_a: Union[int, Iterable[T]], 14 op: Callable[[T, T], T], 15 pow_: Callable[[T, int], T], 16 e: T, 17 ) -> None: 18 """ 19 :math:`O(nlogn)` です。 20 21 Args: 22 n_or_a (Union[int, Iterable[T]]): n or a 23 op (Callable[[T, T], T]): 2項演算です。 24 pow_ (Callable[[T, int], T]): 累乗演算です。 25 e (T): 単位元です。 26 """ 27 self.op = op 28 self.pow = pow_ 29 self.e = e 30 a = [e] * n_or_a if isinstance(n_or_a, int) else list(n_or_a) 31 a.append(e) 32 self.seg = SegmentTree(a, op, e) 33 self.n = len(self.seg) 34 self.indx = WordsizeTreeSet(self.n + 1, range(self.n + 1)) 35 self.val = a 36 self.beki = [1] * self.n 37
[docs] 38 def prod(self, l: int, r: int) -> T: 39 """区間 ``[l, r)`` の総積を返します。 40 :math:`O(logn)` です。 41 ``op`` を :math:`O(logn)` 回、 ``pow_`` を :math:`O(1)` 回呼び出します。 42 """ 43 ll = self.indx.ge(l) 44 rr = self.indx.le(r) 45 ans = self.e 46 if ll != l: 47 l0 = self.indx.le(l) 48 beki = self.beki[l0] - (l - l0) if l0 + self.beki[l0] <= r else r - l 49 ans = self.pow(self.val[l0], beki) 50 if ll < rr: 51 ans = self.op(ans, self.seg.prod(ll, rr)) 52 if rr != r and l <= rr: 53 ans = self.op(ans, self.pow(self.val[rr], r - rr)) 54 return ans
55
[docs] 56 def apply(self, l: int, r: int, f: T) -> None: 57 """区間 ``[l, r)`` を ``f`` に更新します。 58 :math:`O(logn)` です。 59 ``op`` を :math:`O(logn)` 回、 ``pow_`` を :math:`O(1)` 回呼び出します。 60 """ 61 indx, val, beki, seg = self.indx, self.val, self.beki, self.seg 62 63 l0 = indx.le(l) 64 r0 = indx.le(r) 65 if l != l0: 66 seg[l0] = self.pow(val[l0], l - l0) 67 if r != r0: 68 beki[r] = beki[r0] - (r - r0) 69 indx.add(r) 70 val[r] = val[r0] 71 seg[r] = self.pow(val[r], beki[r]) 72 if l != l0: 73 beki[l0] = l - l0 74 75 i = indx.gt(l) 76 while i < r: 77 seg[i] = self.e 78 indx.discard(i) 79 i = indx.gt(i) 80 val[l] = f 81 indx.add(l) 82 beki[l] = r - l 83 seg[l] = self.pow(f, beki[l])