range_set_range_composite

ソースコード

from titan_pylib.data_structures.segment_tree.range_set_range_composite import RangeSetRangeComposite

view on github

展開済みコード

  1# from titan_pylib.data_structures.segment_tree.range_set_range_composite import RangeSetRangeComposite
  2# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
  3# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  4#     SegmentTreeInterface,
  5# )
  6from abc import ABC, abstractmethod
  7from typing import TypeVar, Generic, Union, Iterable, Callable
  8
  9T = TypeVar("T")
 10
 11
 12class SegmentTreeInterface(ABC, Generic[T]):
 13
 14    @abstractmethod
 15    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
 16        raise NotImplementedError
 17
 18    @abstractmethod
 19    def set(self, k: int, v: T) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def get(self, k: int) -> T:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def prod(self, l: int, r: int) -> T:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def all_prod(self) -> T:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def tolist(self) -> list[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def __getitem__(self, k: int) -> T:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def __setitem__(self, k: int, v: T) -> None:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def __str__(self):
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def __repr__(self):
 60        raise NotImplementedError
 61from typing import Generic, Iterable, TypeVar, Callable, Union
 62
 63T = TypeVar("T")
 64
 65
 66class SegmentTree(SegmentTreeInterface, Generic[T]):
 67    """セグ木です。非再帰です。"""
 68
 69    def __init__(
 70        self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
 71    ) -> None:
 72        """``SegmentTree`` を構築します。
 73        :math:`O(n)` です。
 74
 75        Args:
 76            n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
 77                                              ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
 78            op (Callable[[T, T], T]): 2項演算の関数です。
 79            e (T): 単位元です。
 80        """
 81        self._op = op
 82        self._e = e
 83        if isinstance(n_or_a, int):
 84            self._n = n_or_a
 85            self._log = (self._n - 1).bit_length()
 86            self._size = 1 << self._log
 87            self._data = [e] * (self._size << 1)
 88        else:
 89            n_or_a = list(n_or_a)
 90            self._n = len(n_or_a)
 91            self._log = (self._n - 1).bit_length()
 92            self._size = 1 << self._log
 93            _data = [e] * (self._size << 1)
 94            _data[self._size : self._size + self._n] = n_or_a
 95            for i in range(self._size - 1, 0, -1):
 96                _data[i] = op(_data[i << 1], _data[i << 1 | 1])
 97            self._data = _data
 98
 99    def set(self, k: int, v: T) -> None:
100        """一点更新です。
101        :math:`O(\\log{n})` です。
102
103        Args:
104            k (int): 更新するインデックスです。
105            v (T): 更新する値です。
106
107        制約:
108            :math:`-n \\leq n \\leq k < n`
109        """
110        assert (
111            -self._n <= k < self._n
112        ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
113        if k < 0:
114            k += self._n
115        k += self._size
116        self._data[k] = v
117        for _ in range(self._log):
118            k >>= 1
119            self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
120
121    def get(self, k: int) -> T:
122        """一点取得です。
123        :math:`O(1)` です。
124
125        Args:
126            k (int): インデックスです。
127
128        制約:
129            :math:`-n \\leq n \\leq k < n`
130        """
131        assert (
132            -self._n <= k < self._n
133        ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
134        if k < 0:
135            k += self._n
136        return self._data[k + self._size]
137
138    def prod(self, l: int, r: int) -> T:
139        """区間 ``[l, r)`` の総積を返します。
140        :math:`O(\\log{n})` です。
141
142        Args:
143            l (int): インデックスです。
144            r (int): インデックスです。
145
146        制約:
147            :math:`0 \\leq l \\leq r \\leq n`
148        """
149        assert (
150            0 <= l <= r <= self._n
151        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
152        l += self._size
153        r += self._size
154        lres = self._e
155        rres = self._e
156        while l < r:
157            if l & 1:
158                lres = self._op(lres, self._data[l])
159                l += 1
160            if r & 1:
161                rres = self._op(self._data[r ^ 1], rres)
162            l >>= 1
163            r >>= 1
164        return self._op(lres, rres)
165
166    def all_prod(self) -> T:
167        """区間 ``[0, n)`` の総積を返します。
168        :math:`O(1)` です。
169        """
170        return self._data[1]
171
172    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
173        """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
174        assert (
175            0 <= l <= self._n
176        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
177        # assert f(self._e), \
178        #     f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
179        if l == self._n:
180            return self._n
181        l += self._size
182        s = self._e
183        while True:
184            while l & 1 == 0:
185                l >>= 1
186            if not f(self._op(s, self._data[l])):
187                while l < self._size:
188                    l <<= 1
189                    if f(self._op(s, self._data[l])):
190                        s = self._op(s, self._data[l])
191                        l |= 1
192                return l - self._size
193            s = self._op(s, self._data[l])
194            l += 1
195            if l & -l == l:
196                break
197        return self._n
198
199    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
200        """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
201        assert (
202            0 <= r <= self._n
203        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
204        # assert f(self._e), \
205        #     f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
206        if r == 0:
207            return 0
208        r += self._size
209        s = self._e
210        while True:
211            r -= 1
212            while r > 1 and r & 1:
213                r >>= 1
214            if not f(self._op(self._data[r], s)):
215                while r < self._size:
216                    r = r << 1 | 1
217                    if f(self._op(self._data[r], s)):
218                        s = self._op(self._data[r], s)
219                        r ^= 1
220                return r + 1 - self._size
221            s = self._op(self._data[r], s)
222            if r & -r == r:
223                break
224        return 0
225
226    def tolist(self) -> list[T]:
227        """リストにして返します。
228        :math:`O(n)` です。
229        """
230        return [self.get(i) for i in range(self._n)]
231
232    def show(self) -> None:
233        """デバッグ用のメソッドです。"""
234        print(
235            f"<{self.__class__.__name__}> [\n"
236            + "\n".join(
237                [
238                    "  "
239                    + " ".join(
240                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
241                    )
242                    for i in range(self._log + 1)
243                ]
244            )
245            + "\n]"
246        )
247
248    def __getitem__(self, k: int) -> T:
249        assert (
250            -self._n <= k < self._n
251        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
252        return self.get(k)
253
254    def __setitem__(self, k: int, v: T):
255        assert (
256            -self._n <= k < self._n
257        ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
258        self.set(k, v)
259
260    def __len__(self) -> int:
261        return self._n
262
263    def __str__(self) -> str:
264        return str(self.tolist())
265
266    def __repr__(self) -> str:
267        return f"{self.__class__.__name__}({self})"
268# from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet
269from array import array
270from typing import Iterable, Optional
271
272
273class WordsizeTreeSet:
274    """``[0, u)`` の整数集合を管理する32分木です。
275    空間 :math:`O(u)` であることに注意してください。
276    """
277
278    def __init__(self, u: int, a: Iterable[int] = []) -> None:
279        """:math:`O(u)` です。"""
280        assert u >= 0
281        u += 1  # 念のため
282        self.u = u
283        data = []
284        len_ = 0
285        if a:
286            u >>= 5
287            A = array("I", bytes(4 * (u + 1)))
288            for a_ in a:
289                assert (
290                    0 <= a_ < self.u
291                ), f"ValueError: {self.__class__.__name__}.__init__, {a_}, u={u}"
292                if A[a_ >> 5] >> (a_ & 31) & 1 == 0:
293                    len_ += 1
294                    A[a_ >> 5] |= 1 << (a_ & 31)
295            data.append(A)
296            while u:
297                a = array("I", bytes(4 * ((u >> 5) + 1)))
298                for i in range(u + 1):
299                    if A[i]:
300                        a[i >> 5] |= 1 << (i & 31)
301                data.append(a)
302                A = a
303                u >>= 5
304        else:
305            while u:
306                u >>= 5
307                data.append(array("I", bytes(4 * (u + 1))))
308        self.data: list[array[int]] = data
309        self.len: int = len_
310        self.len_data: int = len(data)
311
312    def add(self, v: int) -> bool:
313        """整数 ``v`` を個追加します。
314        :math:`O(\\log{u})` です。
315        """
316        assert (
317            0 <= v < self.u
318        ), f"ValueError: {self.__class__.__name__}.add({v}), u={self.u}"
319        if self.data[0][v >> 5] >> (v & 31) & 1:
320            return False
321        self.len += 1
322        for a in self.data:
323            a[v >> 5] |= 1 << (v & 31)
324            v >>= 5
325        return True
326
327    def discard(self, v: int) -> bool:
328        """整数 ``v`` を削除します。
329        :math:`O(\\log{u})` です。
330        """
331        assert (
332            0 <= v < self.u
333        ), f"ValueError: {self.__class__.__name__}.discard({v}), u={self.u}"
334        if self.data[0][v >> 5] >> (v & 31) & 1 == 0:
335            return False
336        self.len -= 1
337        for a in self.data:
338            a[v >> 5] &= ~(1 << (v & 31))
339            v >>= 5
340            if a[v]:
341                break
342        return True
343
344    def remove(self, v: int) -> None:
345        """整数 ``v`` を削除します。
346        :math:`O(\\log{u})` です。
347
348        Note: ``v`` が存在しないとき、例外を投げます。
349        """
350        assert (
351            0 <= v < self.u
352        ), f"ValueError: {self.__class__.__name__}.remove({v}), u={self.u}"
353        assert self.discard(v), f"ValueError: {v} not in self."
354
355    def ge(self, v: int) -> Optional[int]:
356        """``v`` 以上で最小の要素を返します。存在しないとき、 ``None``を返します。
357        :math:`O(\\log{u})` です。
358        """
359        assert (
360            0 <= v < self.u
361        ), f"ValueError: {self.__class__.__name__}.ge({v}), u={self.u}"
362        data = self.data
363        d = 0
364        while True:
365            if d >= self.len_data or v >> 5 >= len(data[d]):
366                return None
367            m = data[d][v >> 5] & ((~0) << (v & 31))
368            if m == 0:
369                d += 1
370                v = (v >> 5) + 1
371            else:
372                v = (v >> 5 << 5) + (m & -m).bit_length() - 1
373                if d == 0:
374                    break
375                v <<= 5
376                d -= 1
377        return v
378
379    def gt(self, v: int) -> Optional[int]:
380        """``v`` より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。
381        :math:`O(\\log{u})` です。
382        """
383        assert (
384            0 <= v < self.u
385        ), f"ValueError: {self.__class__.__name__}.gt({v}), u={self.u}"
386        if v + 1 == self.u:
387            return
388        return self.ge(v + 1)
389
390    def le(self, v: int) -> Optional[int]:
391        """``v`` 以下で最大の要素を返します。存在しないとき、 ``None``を返します。
392        :math:`O(\\log{u})` です。
393        """
394        assert (
395            0 <= v < self.u
396        ), f"ValueError: {self.__class__.__name__}.le({v}), u={self.u}"
397        data = self.data
398        d = 0
399        while True:
400            if v < 0 or d >= self.len_data:
401                return None
402            m = data[d][v >> 5] & ~((~1) << (v & 31))
403            if m == 0:
404                d += 1
405                v = (v >> 5) - 1
406            else:
407                v = (v >> 5 << 5) + m.bit_length() - 1
408                if d == 0:
409                    break
410                v <<= 5
411                v += 31
412                d -= 1
413        return v
414
415    def lt(self, v: int) -> Optional[int]:
416        """``v`` より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。
417        :math:`O(\\log{u})` です。
418        """
419        assert (
420            0 <= v < self.u
421        ), f"ValueError: {self.__class__.__name__}.lt({v}), u={self.u}"
422        if v - 1 == 0:
423            return
424        return self.le(v - 1)
425
426    def get_min(self) -> Optional[int]:
427        """`最小値を返します。存在しないとき、 ``None``を返します。
428        :math:`O(\\log{u})` です。
429        """
430        return self.ge(0)
431
432    def get_max(self) -> Optional[int]:
433        """最大値を返します。存在しないとき、 ``None``を返します。
434        :math:`O(\\log{u})` です。
435        """
436        return self.le(self.u - 1)
437
438    def pop_min(self) -> int:
439        """最小値を削除して返します。
440        :math:`O(\\log{u})` です。
441        """
442        v = self.get_min()
443        assert (
444            v is not None
445        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
446        self.discard(v)
447        return v
448
449    def pop_max(self) -> int:
450        """最大値を削除して返します。
451        :math:`O(\\log{u})` です。
452        """
453        v = self.get_max()
454        assert (
455            v is not None
456        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
457        self.discard(v)
458        return v
459
460    def clear(self) -> None:
461        """集合を空にします。
462        :math:`O(n\\log{u})` です。
463        """
464        for e in self:
465            self.discard(e)
466        self.len = 0
467
468    def tolist(self) -> list[int]:
469        """リストにして返します。
470        :math:`O(n\\log{u})` です。
471        """
472        return [x for x in self]
473
474    def __bool__(self):
475        return self.len > 0
476
477    def __len__(self):
478        return self.len
479
480    def __contains__(self, v: int):
481        assert (
482            0 <= v < self.u
483        ), f"ValueError: {v} in {self.__class__.__name__}, u={self.u}"
484        return self.data[0][v >> 5] >> (v & 31) & 1 == 1
485
486    def __iter__(self):
487        self._val = self.ge(0)
488        return self
489
490    def __next__(self):
491        if self._val is None:
492            raise StopIteration
493        pre = self._val
494        self._val = self.gt(pre)
495        return pre
496
497    def __str__(self):
498        return "{" + ", ".join(map(str, self)) + "}"
499
500    def __repr__(self):
501        return f"{self.__class__.__name__}({self.u}, {self})"
502from typing import Union, Callable, TypeVar, Generic, Iterable
503
504T = TypeVar("T")
505
506
507class RangeSetRangeComposite(Generic[T]):
508    """区間更新+区間積です。"""
509
510    def __init__(
511        self,
512        n_or_a: Union[int, Iterable[T]],
513        op: Callable[[T, T], T],
514        pow_: Callable[[T, int], T],
515        e: T,
516    ) -> None:
517        """
518        :math:`O(nlogn)` です。
519
520        Args:
521          n_or_a (Union[int, Iterable[T]]): n or a
522          op (Callable[[T, T], T]): 2項演算です。
523          pow_ (Callable[[T, int], T]): 累乗演算です。
524          e (T): 単位元です。
525        """
526        self.op = op
527        self.pow = pow_
528        self.e = e
529        a = [e] * n_or_a if isinstance(n_or_a, int) else list(n_or_a)
530        a.append(e)
531        self.seg = SegmentTree(a, op, e)
532        self.n = len(self.seg)
533        self.indx = WordsizeTreeSet(self.n + 1, range(self.n + 1))
534        self.val = a
535        self.beki = [1] * self.n
536
537    def prod(self, l: int, r: int) -> T:
538        """区間 ``[l, r)`` の総積を返します。
539        :math:`O(logn)` です。
540        ``op`` を :math:`O(logn)` 回、 ``pow_`` を :math:`O(1)` 回呼び出します。
541        """
542        ll = self.indx.ge(l)
543        rr = self.indx.le(r)
544        ans = self.e
545        if ll != l:
546            l0 = self.indx.le(l)
547            beki = self.beki[l0] - (l - l0) if l0 + self.beki[l0] <= r else r - l
548            ans = self.pow(self.val[l0], beki)
549        if ll < rr:
550            ans = self.op(ans, self.seg.prod(ll, rr))
551        if rr != r and l <= rr:
552            ans = self.op(ans, self.pow(self.val[rr], r - rr))
553        return ans
554
555    def apply(self, l: int, r: int, f: T) -> None:
556        """区間 ``[l, r)`` を ``f`` に更新します。
557        :math:`O(logn)` です。
558        ``op`` を :math:`O(logn)` 回、 ``pow_`` を :math:`O(1)` 回呼び出します。
559        """
560        indx, val, beki, seg = self.indx, self.val, self.beki, self.seg
561
562        l0 = indx.le(l)
563        r0 = indx.le(r)
564        if l != l0:
565            seg[l0] = self.pow(val[l0], l - l0)
566        if r != r0:
567            beki[r] = beki[r0] - (r - r0)
568            indx.add(r)
569            val[r] = val[r0]
570            seg[r] = self.pow(val[r], beki[r])
571        if l != l0:
572            beki[l0] = l - l0
573
574        i = indx.gt(l)
575        while i < r:
576            seg[i] = self.e
577            indx.discard(i)
578            i = indx.gt(i)
579        val[l] = f
580        indx.add(l)
581        beki[l] = r - l
582        seg[l] = self.pow(f, beki[l])

仕様

class RangeSetRangeComposite(n_or_a: int | Iterable[T], op: Callable[[T, T], T], pow_: Callable[[T, int], T], e: T)[source]

Bases: Generic[T]

区間更新+区間積です。

apply(l: int, r: int, f: T) None[source]

区間 [l, r)f に更新します。 \(O(logn)\) です。 op\(O(logn)\) 回、 pow_\(O(1)\) 回呼び出します。

prod(l: int, r: int) T[source]

区間 [l, r) の総積を返します。 \(O(logn)\) です。 op\(O(logn)\) 回、 pow_\(O(1)\) 回呼び出します。