mex_multiset

ソースコード

from titan_pylib.data_structures.set.mex_multiset import MexMultiset

view on github

展開済みコード

  1# from titan_pylib.data_structures.set.mex_multiset import MexMultiset
  2# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
  3# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  4#     SegmentTreeInterface,
  5# )
  6from abc import ABC, abstractmethod
  7from typing import TypeVar, Generic, Union, Iterable, Callable
  8
  9T = TypeVar("T")
 10
 11
 12class SegmentTreeInterface(ABC, Generic[T]):
 13
 14    @abstractmethod
 15    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
 16        raise NotImplementedError
 17
 18    @abstractmethod
 19    def set(self, k: int, v: T) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def get(self, k: int) -> T:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def prod(self, l: int, r: int) -> T:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def all_prod(self) -> T:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def tolist(self) -> list[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def __getitem__(self, k: int) -> T:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def __setitem__(self, k: int, v: T) -> None:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def __str__(self):
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def __repr__(self):
 60        raise NotImplementedError
 61from typing import Generic, Iterable, TypeVar, Callable, Union
 62
 63T = TypeVar("T")
 64
 65
 66class SegmentTree(SegmentTreeInterface, Generic[T]):
 67    """セグ木です。非再帰です。"""
 68
 69    def __init__(
 70        self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
 71    ) -> None:
 72        """``SegmentTree`` を構築します。
 73        :math:`O(n)` です。
 74
 75        Args:
 76            n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
 77                                              ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
 78            op (Callable[[T, T], T]): 2項演算の関数です。
 79            e (T): 単位元です。
 80        """
 81        self._op = op
 82        self._e = e
 83        if isinstance(n_or_a, int):
 84            self._n = n_or_a
 85            self._log = (self._n - 1).bit_length()
 86            self._size = 1 << self._log
 87            self._data = [e] * (self._size << 1)
 88        else:
 89            n_or_a = list(n_or_a)
 90            self._n = len(n_or_a)
 91            self._log = (self._n - 1).bit_length()
 92            self._size = 1 << self._log
 93            _data = [e] * (self._size << 1)
 94            _data[self._size : self._size + self._n] = n_or_a
 95            for i in range(self._size - 1, 0, -1):
 96                _data[i] = op(_data[i << 1], _data[i << 1 | 1])
 97            self._data = _data
 98
 99    def set(self, k: int, v: T) -> None:
100        """一点更新です。
101        :math:`O(\\log{n})` です。
102
103        Args:
104            k (int): 更新するインデックスです。
105            v (T): 更新する値です。
106
107        制約:
108            :math:`-n \\leq n \\leq k < n`
109        """
110        assert (
111            -self._n <= k < self._n
112        ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
113        if k < 0:
114            k += self._n
115        k += self._size
116        self._data[k] = v
117        for _ in range(self._log):
118            k >>= 1
119            self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
120
121    def get(self, k: int) -> T:
122        """一点取得です。
123        :math:`O(1)` です。
124
125        Args:
126            k (int): インデックスです。
127
128        制約:
129            :math:`-n \\leq n \\leq k < n`
130        """
131        assert (
132            -self._n <= k < self._n
133        ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
134        if k < 0:
135            k += self._n
136        return self._data[k + self._size]
137
138    def prod(self, l: int, r: int) -> T:
139        """区間 ``[l, r)`` の総積を返します。
140        :math:`O(\\log{n})` です。
141
142        Args:
143            l (int): インデックスです。
144            r (int): インデックスです。
145
146        制約:
147            :math:`0 \\leq l \\leq r \\leq n`
148        """
149        assert (
150            0 <= l <= r <= self._n
151        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
152        l += self._size
153        r += self._size
154        lres = self._e
155        rres = self._e
156        while l < r:
157            if l & 1:
158                lres = self._op(lres, self._data[l])
159                l += 1
160            if r & 1:
161                rres = self._op(self._data[r ^ 1], rres)
162            l >>= 1
163            r >>= 1
164        return self._op(lres, rres)
165
166    def all_prod(self) -> T:
167        """区間 ``[0, n)`` の総積を返します。
168        :math:`O(1)` です。
169        """
170        return self._data[1]
171
172    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
173        """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
174        assert (
175            0 <= l <= self._n
176        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
177        # assert f(self._e), \
178        #     f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
179        if l == self._n:
180            return self._n
181        l += self._size
182        s = self._e
183        while True:
184            while l & 1 == 0:
185                l >>= 1
186            if not f(self._op(s, self._data[l])):
187                while l < self._size:
188                    l <<= 1
189                    if f(self._op(s, self._data[l])):
190                        s = self._op(s, self._data[l])
191                        l |= 1
192                return l - self._size
193            s = self._op(s, self._data[l])
194            l += 1
195            if l & -l == l:
196                break
197        return self._n
198
199    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
200        """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
201        assert (
202            0 <= r <= self._n
203        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
204        # assert f(self._e), \
205        #     f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
206        if r == 0:
207            return 0
208        r += self._size
209        s = self._e
210        while True:
211            r -= 1
212            while r > 1 and r & 1:
213                r >>= 1
214            if not f(self._op(self._data[r], s)):
215                while r < self._size:
216                    r = r << 1 | 1
217                    if f(self._op(self._data[r], s)):
218                        s = self._op(self._data[r], s)
219                        r ^= 1
220                return r + 1 - self._size
221            s = self._op(self._data[r], s)
222            if r & -r == r:
223                break
224        return 0
225
226    def tolist(self) -> list[T]:
227        """リストにして返します。
228        :math:`O(n)` です。
229        """
230        return [self.get(i) for i in range(self._n)]
231
232    def show(self) -> None:
233        """デバッグ用のメソッドです。"""
234        print(
235            f"<{self.__class__.__name__}> [\n"
236            + "\n".join(
237                [
238                    "  "
239                    + " ".join(
240                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
241                    )
242                    for i in range(self._log + 1)
243                ]
244            )
245            + "\n]"
246        )
247
248    def __getitem__(self, k: int) -> T:
249        assert (
250            -self._n <= k < self._n
251        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
252        return self.get(k)
253
254    def __setitem__(self, k: int, v: T):
255        assert (
256            -self._n <= k < self._n
257        ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
258        self.set(k, v)
259
260    def __len__(self) -> int:
261        return self._n
262
263    def __str__(self) -> str:
264        return str(self.tolist())
265
266    def __repr__(self) -> str:
267        return f"{self.__class__.__name__}({self})"
268from typing import Iterable
269
270
271class MexMultiset:
272    """``MexMultiset`` です。
273
274    各操作は `log` がつきますが、ANDセグ木の ``log`` で割と軽いです。
275    """
276
277    def __init__(self, u: int, a: Iterable[int] = []) -> None:
278        """``[0, u)`` の範囲の mex を計算する ``MexMultiset`` を構築します。
279
280        時間・空間共に :math:`O(u)` です。
281
282        Args:
283          u (int): 値の上限です。
284        """
285        data = [0] * (u + 1)
286        init_data = [1] * (u + 1)
287        for e in a:
288            if e <= u:
289                data[e] += 1
290                init_data[e] = 0
291        self.u: int = u
292        self.data: list[int] = data
293        self.seg: SegmentTree[int] = SegmentTree(init_data, op=lambda s, t: s | t, e=0)
294
295    def add(self, key: int) -> None:
296        """``key`` を追加します。
297
298        :math:`O(\\log{n})` です。
299        """
300        if key > self.u:
301            return
302        if self.data[key] == 0:
303            self.seg[key] = 0
304        self.data[key] += 1
305
306    def remove(self, key: int) -> None:
307        """``key`` を削除します。 ``key`` は存在していなければなりません。
308
309        :math:`O(\\log{n})` です。
310        """
311        if key > self.u:
312            return
313        if self.data[key] == 1:
314            self.seg[key] = 1
315        self.data[key] -= 1
316
317    def mex(self) -> int:
318        """mex を返します。
319
320        :math:`O(\\log{n})` です。
321        """
322        return self.seg.max_right(0, lambda lr: lr == 0)

仕様

class MexMultiset(u: int, a: Iterable[int] = [])[source]

Bases: object

MexMultiset です。

各操作は log がつきますが、ANDセグ木の log で割と軽いです。

add(key: int) None[source]

key を追加します。

\(O(\log{n})\) です。

mex() int[source]

mex を返します。

\(O(\log{n})\) です。

remove(key: int) None[source]

key を削除します。 key は存在していなければなりません。

\(O(\log{n})\) です。