wordsize_tree_multiset

ソースコード

from titan_pylib.data_structures.set.wordsize_tree_multiset import WordsizeTreeMultiset

view on github

展開済みコード

  1# from titan_pylib.data_structures.set.wordsize_tree_multiset import WordsizeTreeMultiset
  2# from titan_pylib.data_structures.set.wordsize_tree_set import WordsizeTreeSet
  3from array import array
  4from typing import Iterable, Optional
  5
  6
  7class WordsizeTreeSet:
  8    """``[0, u)`` の整数集合を管理する32分木です。
  9    空間 :math:`O(u)` であることに注意してください。
 10    """
 11
 12    def __init__(self, u: int, a: Iterable[int] = []) -> None:
 13        """:math:`O(u)` です。"""
 14        assert u >= 0
 15        u += 1  # 念のため
 16        self.u = u
 17        data = []
 18        len_ = 0
 19        if a:
 20            u >>= 5
 21            A = array("I", bytes(4 * (u + 1)))
 22            for a_ in a:
 23                assert (
 24                    0 <= a_ < self.u
 25                ), f"ValueError: {self.__class__.__name__}.__init__, {a_}, u={u}"
 26                if A[a_ >> 5] >> (a_ & 31) & 1 == 0:
 27                    len_ += 1
 28                    A[a_ >> 5] |= 1 << (a_ & 31)
 29            data.append(A)
 30            while u:
 31                a = array("I", bytes(4 * ((u >> 5) + 1)))
 32                for i in range(u + 1):
 33                    if A[i]:
 34                        a[i >> 5] |= 1 << (i & 31)
 35                data.append(a)
 36                A = a
 37                u >>= 5
 38        else:
 39            while u:
 40                u >>= 5
 41                data.append(array("I", bytes(4 * (u + 1))))
 42        self.data: list[array[int]] = data
 43        self.len: int = len_
 44        self.len_data: int = len(data)
 45
 46    def add(self, v: int) -> bool:
 47        """整数 ``v`` を個追加します。
 48        :math:`O(\\log{u})` です。
 49        """
 50        assert (
 51            0 <= v < self.u
 52        ), f"ValueError: {self.__class__.__name__}.add({v}), u={self.u}"
 53        if self.data[0][v >> 5] >> (v & 31) & 1:
 54            return False
 55        self.len += 1
 56        for a in self.data:
 57            a[v >> 5] |= 1 << (v & 31)
 58            v >>= 5
 59        return True
 60
 61    def discard(self, v: int) -> bool:
 62        """整数 ``v`` を削除します。
 63        :math:`O(\\log{u})` です。
 64        """
 65        assert (
 66            0 <= v < self.u
 67        ), f"ValueError: {self.__class__.__name__}.discard({v}), u={self.u}"
 68        if self.data[0][v >> 5] >> (v & 31) & 1 == 0:
 69            return False
 70        self.len -= 1
 71        for a in self.data:
 72            a[v >> 5] &= ~(1 << (v & 31))
 73            v >>= 5
 74            if a[v]:
 75                break
 76        return True
 77
 78    def remove(self, v: int) -> None:
 79        """整数 ``v`` を削除します。
 80        :math:`O(\\log{u})` です。
 81
 82        Note: ``v`` が存在しないとき、例外を投げます。
 83        """
 84        assert (
 85            0 <= v < self.u
 86        ), f"ValueError: {self.__class__.__name__}.remove({v}), u={self.u}"
 87        assert self.discard(v), f"ValueError: {v} not in self."
 88
 89    def ge(self, v: int) -> Optional[int]:
 90        """``v`` 以上で最小の要素を返します。存在しないとき、 ``None``を返します。
 91        :math:`O(\\log{u})` です。
 92        """
 93        assert (
 94            0 <= v < self.u
 95        ), f"ValueError: {self.__class__.__name__}.ge({v}), u={self.u}"
 96        data = self.data
 97        d = 0
 98        while True:
 99            if d >= self.len_data or v >> 5 >= len(data[d]):
100                return None
101            m = data[d][v >> 5] & ((~0) << (v & 31))
102            if m == 0:
103                d += 1
104                v = (v >> 5) + 1
105            else:
106                v = (v >> 5 << 5) + (m & -m).bit_length() - 1
107                if d == 0:
108                    break
109                v <<= 5
110                d -= 1
111        return v
112
113    def gt(self, v: int) -> Optional[int]:
114        """``v`` より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。
115        :math:`O(\\log{u})` です。
116        """
117        assert (
118            0 <= v < self.u
119        ), f"ValueError: {self.__class__.__name__}.gt({v}), u={self.u}"
120        if v + 1 == self.u:
121            return
122        return self.ge(v + 1)
123
124    def le(self, v: int) -> Optional[int]:
125        """``v`` 以下で最大の要素を返します。存在しないとき、 ``None``を返します。
126        :math:`O(\\log{u})` です。
127        """
128        assert (
129            0 <= v < self.u
130        ), f"ValueError: {self.__class__.__name__}.le({v}), u={self.u}"
131        data = self.data
132        d = 0
133        while True:
134            if v < 0 or d >= self.len_data:
135                return None
136            m = data[d][v >> 5] & ~((~1) << (v & 31))
137            if m == 0:
138                d += 1
139                v = (v >> 5) - 1
140            else:
141                v = (v >> 5 << 5) + m.bit_length() - 1
142                if d == 0:
143                    break
144                v <<= 5
145                v += 31
146                d -= 1
147        return v
148
149    def lt(self, v: int) -> Optional[int]:
150        """``v`` より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。
151        :math:`O(\\log{u})` です。
152        """
153        assert (
154            0 <= v < self.u
155        ), f"ValueError: {self.__class__.__name__}.lt({v}), u={self.u}"
156        if v - 1 == 0:
157            return
158        return self.le(v - 1)
159
160    def get_min(self) -> Optional[int]:
161        """`最小値を返します。存在しないとき、 ``None``を返します。
162        :math:`O(\\log{u})` です。
163        """
164        return self.ge(0)
165
166    def get_max(self) -> Optional[int]:
167        """最大値を返します。存在しないとき、 ``None``を返します。
168        :math:`O(\\log{u})` です。
169        """
170        return self.le(self.u - 1)
171
172    def pop_min(self) -> int:
173        """最小値を削除して返します。
174        :math:`O(\\log{u})` です。
175        """
176        v = self.get_min()
177        assert (
178            v is not None
179        ), f"IndexError: pop_min() from empty {self.__class__.__name__}."
180        self.discard(v)
181        return v
182
183    def pop_max(self) -> int:
184        """最大値を削除して返します。
185        :math:`O(\\log{u})` です。
186        """
187        v = self.get_max()
188        assert (
189            v is not None
190        ), f"IndexError: pop_max() from empty {self.__class__.__name__}."
191        self.discard(v)
192        return v
193
194    def clear(self) -> None:
195        """集合を空にします。
196        :math:`O(n\\log{u})` です。
197        """
198        for e in self:
199            self.discard(e)
200        self.len = 0
201
202    def tolist(self) -> list[int]:
203        """リストにして返します。
204        :math:`O(n\\log{u})` です。
205        """
206        return [x for x in self]
207
208    def __bool__(self):
209        return self.len > 0
210
211    def __len__(self):
212        return self.len
213
214    def __contains__(self, v: int):
215        assert (
216            0 <= v < self.u
217        ), f"ValueError: {v} in {self.__class__.__name__}, u={self.u}"
218        return self.data[0][v >> 5] >> (v & 31) & 1 == 1
219
220    def __iter__(self):
221        self._val = self.ge(0)
222        return self
223
224    def __next__(self):
225        if self._val is None:
226            raise StopIteration
227        pre = self._val
228        self._val = self.gt(pre)
229        return pre
230
231    def __str__(self):
232        return "{" + ", ".join(map(str, self)) + "}"
233
234    def __repr__(self):
235        return f"{self.__class__.__name__}({self.u}, {self})"
236from typing import Iterable, Optional, Iterator
237
238
239class WordsizeTreeMultiset:
240    """``[0, u)`` の整数多重集合を管理する32分木です。
241    空間 :math:`O(u)` であることに注意してください。
242    """
243
244    def __init__(self, u: int, a: Iterable[int] = []) -> None:
245        """:math:`O(u)` です。"""
246        u += 1  # 念のため
247        assert u >= 0
248        self.u = u
249        self.len: int = 0
250        self.st: WordsizeTreeSet = WordsizeTreeSet(u, a)
251        cnt = [0] * (u + 1)
252        for a_ in a:
253            self.len += 1
254            cnt[a_] += 1
255        self.cnt: list[int] = cnt
256
257    def add(self, v: int, cnt: int = 1) -> None:
258        """整数 ``v`` を ``cnt`` 個追加します。
259        :math:`O(\\log{u})` です。
260        """
261        assert (
262            0 <= v < self.u
263        ), f"ValueError: {self.__class__.__name__}.add({v}, {cnt}), u={self.u}"
264        self.len += cnt
265        if self.cnt[v]:
266            self.cnt[v] += cnt
267        else:
268            self.cnt[v] = cnt
269            self.st.add(v)
270
271    def discard(self, v: int, cnt: int = 1) -> bool:
272        """整数 ``v`` を ``cnt`` 個削除します。
273        :math:`O(\\log{u})` です。
274        """
275        assert (
276            0 <= v < self.u
277        ), f"ValueError: {self.__class__.__name__}.discard({v}), u={self.u}"
278        if self.cnt[v] == 0:
279            return False
280        c = self.cnt[v]
281        if c > cnt:
282            self.cnt[v] -= cnt
283            self.len -= cnt
284        else:
285            self.len -= c
286            self.cnt[v] = 0
287            self.st.discard(v)
288        return True
289
290    def remove(self, v: int) -> None:
291        """整数 ``v`` を削除します。
292        :math:`O(\\log{u})` です。
293
294        Note: ``v`` が存在しないとき、例外を投げます。
295        """
296        assert (
297            0 <= v < self.u
298        ), f"ValueError: {self.__class__.__name__}.remove({v}), u={self.u}"
299        assert self.discard(v), f"ValueError: {v} not in self."
300
301    def count(self, v: int) -> int:
302        """整数 ``v`` の個数を返します。
303        :math:`O(1)` です。
304        """
305        assert (
306            0 <= v < self.u
307        ), f"ValueError: {self.__class__.__name__}.count({v}), u={self.u}"
308        return self.cnt[v]
309
310    def ge(self, v: int) -> Optional[int]:
311        """``v`` 以上で最小の要素を返します。存在しないとき、 ``None``を返します。
312        :math:`O(\\log{u})` です。
313        """
314        assert (
315            0 <= v < self.u
316        ), f"ValueError: {self.__class__.__name__}.ge({v}), u={self.u}"
317        return self.st.ge(v)
318
319    def gt(self, v: int) -> Optional[int]:
320        """``v`` より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。
321        :math:`O(\\log{u})` です。
322        """
323        assert (
324            0 <= v < self.u
325        ), f"ValueError: {self.__class__.__name__}.gt({v}), u={self.u}"
326        return self.ge(v + 1)
327
328    def le(self, v: int) -> Optional[int]:
329        """``v`` 以下で最大の要素を返します。存在しないとき、 ``None``を返します。
330        :math:`O(\\log{u})` です。
331        """
332        assert (
333            0 <= v < self.u
334        ), f"ValueError: {self.__class__.__name__}.le({v}), u={self.u}"
335        return self.st.le(v)
336
337    def lt(self, v: int) -> Optional[int]:
338        """``v`` より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。
339        :math:`O(\\log{u})` です。
340        """
341        assert (
342            0 <= v < self.u
343        ), f"ValueError: {self.__class__.__name__}.lt({v}), u={self.u}"
344        return self.le(v - 1)
345
346    def get_min(self) -> Optional[int]:
347        """`最小値を返します。存在しないとき、 ``None``を返します。
348        :math:`O(\\log{u})` です。
349        """
350        return self.st.ge(0)
351
352    def get_max(self) -> Optional[int]:
353        """最大値を返します。存在しないとき、 ``None``を返します。
354        :math:`O(\\log{u})` です。
355        """
356        return self.st.le(self.st.u - 1)
357
358    def pop_min(self) -> int:
359        """最小値を削除して返します。
360        :math:`O(\\log{u})` です。
361        """
362        assert self, f"IndexError: pop_min() from empty {self.__class__.__name__}."
363        x = self.st.get_min()
364        self.discard(x)
365        return x
366
367    def pop_max(self) -> int:
368        """最大値を削除して返します。
369        :math:`O(\\log{u})` です。
370        """
371        assert self, f"IndexError: pop_max() from empty {self.__class__.__name__}."
372        x = self.st.get_max()
373        self.discard(x)
374        return x
375
376    def keys(self) -> Iterator[int]:
377        """集合に含まれている要素(重複無し)を昇順にイテレートします。
378        :math:`O(n\\log{u})` です。
379        """
380        v = self.st.get_min()
381        while v is not None:
382            yield v
383            v = self.st.gt(v)
384
385    def values(self) -> Iterator[int]:
386        """集合に含まれている要素の個数を、要素の昇順にイテレートします。
387        :math:`O(n\\log{u})` です。
388        """
389        v = self.st.get_min()
390        while v is not None:
391            yield self.cnt[v]
392            v = self.st.gt(v)
393
394    def items(self) -> Iterator[tuple[int, int]]:
395        """集合に含まれている要素とその個数を、要素の昇順にイテレートします。
396        :math:`O(n\\log{u})` です。
397        """
398        v = self.st.get_min()
399        while v is not None:
400            yield (v, self.cnt[v])
401            v = self.st.gt(v)
402
403    def clear(self) -> None:
404        """集合を空にします。
405        :math:`O(n\\log{u})` です。
406        """
407        for e in self:
408            self.cnt[e] = 0
409            self.st.discard(e)
410        self.len = 0
411
412    def tolist(self) -> list[int]:
413        """リストにして返します。
414        :math:`O(n\\log{u})` です。
415        """
416        return [x for x in self]
417
418    def __contains__(self, v: int):
419        """:math:`O(1)` です。"""
420        return self.cnt[v] > 0
421
422    def __bool__(self):
423        return self.len > 0
424
425    def __len__(self):
426        return self.len
427
428    def __iter__(self):
429        self.__val = self.st.get_min()
430        self.__valcnt = 1
431        return self
432
433    def __next__(self):
434        if self.__val is None:
435            raise StopIteration
436        pre = self.__val
437        self.__valcnt += 1
438        if self.__valcnt > self.cnt[self.__val]:
439            self.__valcnt = 1
440            self.__val = self.gt(self.__val)
441        return pre
442
443    def __str__(self):
444        return "{" + ", ".join(map(str, self)) + "}"
445
446    def __repr__(self):
447        return (
448            f"{self.__class__.__name__}({self.u}, [" + ", ".join(map(str, self)) + "])"
449        )

仕様

class WordsizeTreeMultiset(u: int, a: Iterable[int] = [])[source]

Bases: object

[0, u) の整数多重集合を管理する32分木です。 空間 \(O(u)\) であることに注意してください。

__contains__(v: int)[source]

\(O(1)\) です。

add(v: int, cnt: int = 1) None[source]

整数 vcnt 個追加します。 \(O(\log{u})\) です。

clear() None[source]

集合を空にします。 \(O(n\log{u})\) です。

count(v: int) int[source]

整数 v の個数を返します。 \(O(1)\) です。

discard(v: int, cnt: int = 1) bool[source]

整数 vcnt 個削除します。 \(O(\log{u})\) です。

ge(v: int) int | None[source]

v 以上で最小の要素を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

get_max() int | None[source]

最大値を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

get_min() int | None[source]

最小値を返します。存在しないとき、 ``None``を返します。 :math:`O(log{u}) です。

gt(v: int) int | None[source]

v より大きい値で最小の要素を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

items() Iterator[tuple[int, int]][source]

集合に含まれている要素とその個数を、要素の昇順にイテレートします。 \(O(n\log{u})\) です。

keys() Iterator[int][source]

集合に含まれている要素(重複無し)を昇順にイテレートします。 \(O(n\log{u})\) です。

le(v: int) int | None[source]

v 以下で最大の要素を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

lt(v: int) int | None[source]

v より小さい値で最大の要素を返します。存在しないとき、 ``None``を返します。 \(O(\log{u})\) です。

pop_max() int[source]

最大値を削除して返します。 \(O(\log{u})\) です。

pop_min() int[source]

最小値を削除して返します。 \(O(\log{u})\) です。

remove(v: int) None[source]

整数 v を削除します。 \(O(\log{u})\) です。

Note: v が存在しないとき、例外を投げます。

tolist() list[int][source]

リストにして返します。 \(O(n\log{u})\) です。

values() Iterator[int][source]

集合に含まれている要素の個数を、要素の昇順にイテレートします。 \(O(n\log{u})\) です。