range_set

ソースコード

from titan_pylib.data_structures.set.range_set import RangeSet

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.set.range_set import RangeSet
  2# from titan_pylib.data_structures.splay_tree.splay_tree_set_top_down import (
  3#     SplayTreeSetTopDown,
  4# )
  5raise NotImplementedError
  6# https://atcoder.jp/contests/abc370/submissions/59284972
  7
  8# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  9from typing import Protocol
 10
 11
 12class SupportsLessThan(Protocol):
 13
 14    def __lt__(self, other) -> bool: ...
 15# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
 16# from titan_pylib.my_class.supports_less_than import SupportsLessThan
 17from abc import ABC, abstractmethod
 18from typing import Iterable, Optional, Iterator, TypeVar, Generic
 19
 20T = TypeVar("T", bound=SupportsLessThan)
 21
 22
 23class OrderedSetInterface(ABC, Generic[T]):
 24
 25    @abstractmethod
 26    def __init__(self, a: Iterable[T]) -> None:
 27        raise NotImplementedError
 28
 29    @abstractmethod
 30    def add(self, key: T) -> bool:
 31        raise NotImplementedError
 32
 33    @abstractmethod
 34    def discard(self, key: T) -> bool:
 35        raise NotImplementedError
 36
 37    @abstractmethod
 38    def remove(self, key: T) -> None:
 39        raise NotImplementedError
 40
 41    @abstractmethod
 42    def le(self, key: T) -> Optional[T]:
 43        raise NotImplementedError
 44
 45    @abstractmethod
 46    def lt(self, key: T) -> Optional[T]:
 47        raise NotImplementedError
 48
 49    @abstractmethod
 50    def ge(self, key: T) -> Optional[T]:
 51        raise NotImplementedError
 52
 53    @abstractmethod
 54    def gt(self, key: T) -> Optional[T]:
 55        raise NotImplementedError
 56
 57    @abstractmethod
 58    def get_max(self) -> Optional[T]:
 59        raise NotImplementedError
 60
 61    @abstractmethod
 62    def get_min(self) -> Optional[T]:
 63        raise NotImplementedError
 64
 65    @abstractmethod
 66    def pop_max(self) -> T:
 67        raise NotImplementedError
 68
 69    @abstractmethod
 70    def pop_min(self) -> T:
 71        raise NotImplementedError
 72
 73    @abstractmethod
 74    def clear(self) -> None:
 75        raise NotImplementedError
 76
 77    @abstractmethod
 78    def tolist(self) -> list[T]:
 79        raise NotImplementedError
 80
 81    @abstractmethod
 82    def __iter__(self) -> Iterator:
 83        raise NotImplementedError
 84
 85    @abstractmethod
 86    def __next__(self) -> T:
 87        raise NotImplementedError
 88
 89    @abstractmethod
 90    def __contains__(self, key: T) -> bool:
 91        raise NotImplementedError
 92
 93    @abstractmethod
 94    def __len__(self) -> int:
 95        raise NotImplementedError
 96
 97    @abstractmethod
 98    def __bool__(self) -> bool:
 99        raise NotImplementedError
100
101    @abstractmethod
102    def __str__(self) -> str:
103        raise NotImplementedError
104
105    @abstractmethod
106    def __repr__(self) -> str:
107        raise NotImplementedError
from array import array
from typing import Optional, Generic, Iterable, TypeVar

T = TypeVar("T", bound=SupportsLessThan)


class SplayTreeSetTopDown(OrderedSetInterface, Generic[T]):
    """Ordered set of distinct keys stored in an array-based top-down splay tree.

    Storage layout (node indices are plain ints):

    * ``keys[i]`` is the key of node ``i``. Node 0 is a null sentinel whose key
      is ``e``; a child index of 0 means "no child".
    * ``child[2*i]`` and ``child[2*i + 1]`` are the left and right child of node
      ``i``. Node 0's two slots (``child[0]`` / ``child[1]``) are overwritten as
      scratch headers of the right / left trees built while splaying top-down.
    * ``end`` is the index the next new node receives. Removed nodes are not
      recycled, so storage only grows.

    Queries splay the node they reach to the root, so even read-only methods
    such as ``ge`` or ``__contains__`` restructure the tree.
    """

    def __init__(self, a: Iterable[T] = (), e: T = 0):
        """Create the set from the keys in ``a``.

        Args:
            a: initial keys in any order; duplicates are dropped. The default is
                an immutable empty tuple, and ``a`` is never mutated.
            e: filler key stored in the null sentinel and in reserved, unused
                slots.
        """
        self.keys: list[T] = [e]
        # bytes(8) -> two zeroed links for sentinel node 0 (assumes 4-byte "I").
        self.child = array("I", bytes(8))
        self.end: int = 1  # real nodes start at index 1
        self.root: int = 0  # 0 means the tree is empty
        self.len: int = 0
        self.e: T = e
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Load ``a`` into the empty tree as a balanced BST.

        ``a`` is sorted and deduplicated unless it is already strictly
        increasing. Its keys occupy nodes ``1..n`` in ascending order.
        """

        def rec(l: int, r: int) -> int:
            # Root the half-open node range [l, r) at its midpoint.
            mid = (l + r) >> 1
            if l != mid:
                child[mid << 1] = rec(l, mid)
            if mid + 1 != r:
                child[mid << 1 | 1] = rec(mid + 1, r)
            return mid

        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
            b = [a[0]]
            for e in a:
                if b[-1] == e:
                    continue
                b.append(e)
            a = b
        n = len(a)
        key, child = self.keys, self.child
        # reserve() extends the list and the array in place (+=), so the local
        # aliases above stay valid.
        self.reserve(n - len(key) + 2)
        self.end += n
        key[1 : n + 1] = a
        self.root = rec(1, n + 1)
        self.len = n

    def _make_node(self, key: T) -> int:
        """Allocate a leaf node holding ``key`` and return its index."""
        if self.end >= len(self.keys):
            self.keys.append(key)
            self.child.append(0)
            self.child.append(0)
        else:
            # Reuse a reserved slot; clear its links so the node is a true leaf
            # whatever the slot held before.
            self.keys[self.end] = key
            self.child[self.end << 1] = 0
            self.child[self.end << 1 | 1] = 0
        self.end += 1
        return self.end - 1

    def _rotate_right(self, node: int) -> int:
        """Rotate ``node``'s left child above it; return the new subtree root."""
        child = self.child
        u = child[node << 1]
        child[node << 1] = child[u << 1 | 1]
        child[u << 1 | 1] = node
        return u

    def _rotate_left(self, node: int) -> int:
        """Rotate ``node``'s right child above it; return the new subtree root."""
        child = self.child
        u = child[node << 1 | 1]
        child[node << 1 | 1] = child[u << 1]
        child[u << 1] = node
        return u

    def _set_search_splay(self, key: T) -> None:
        """Top-down splay for ``key``.

        Afterwards ``self.root`` is the node holding ``key`` if present,
        otherwise the last node on the search path, which is adjacent to ``key``
        in sorted order. Does nothing on an empty tree.
        """

        def set_left(node: int) -> int:
            # Hang ``node`` (with its left subtree) on the right spine of the
            # left tree, then continue at its right child.
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            # Mirror image: hang ``node`` on the left spine of the right tree.
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        keys, child = self.keys, self.child
        if (not node) or keys[node] == key:
            return
        # Attachment points of the left/right trees; starting at 0 makes
        # child[1] / child[0] hold their roots.
        left, right = 0, 0
        while keys[node] != key:
            if key > keys[node]:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    # zig-zig: rotate first, then link.
                    node = self._rotate_left(node)
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    # zig-zag: link on both sides.
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
            else:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
        # Reassemble: node's subtrees become the inner edges of the left/right
        # trees, which in turn become node's children.
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns the new subtree root (it has no left child), or ``node``
        unchanged when it is 0 or has no left child. ``self.root`` is not
        updated here; callers store the result.
        """
        child = self.child
        if (not node) or (not child[node << 1]):
            return node
        right = 0
        while child[node << 1]:
            node = self._rotate_right(node)
            if not child[node << 1]:
                break
            child[right << 1] = node
            right = node
            node = child[node << 1]
        child[right << 1] = child[node << 1 | 1]
        child[1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        return node

    def _get_max_splay(self, node: int) -> int:
        """Splay the maximum of the subtree rooted at ``node`` to its top.

        Returns the new subtree root (it has no right child), or ``node``
        unchanged when it is 0 or has no right child. ``self.root`` is not
        updated here; callers store the result.
        """
        child = self.child
        if (not node) or (not child[node << 1 | 1]):
            return node
        left = 0
        while child[node << 1 | 1]:
            node = self._rotate_left(node)
            if not child[node << 1 | 1]:
                break
            child[left << 1 | 1] = node
            left = node
            node = child[node << 1 | 1]
        child[0] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        return node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` more nodes (keys filled with ``e``)."""
        assert n >= 0, "ValueError"
        self.keys += [self.e] * n
        self.child += array("I", bytes(8 * n))

    def add(self, key: T) -> bool:
        """Insert ``key``; return True if it was absent, False otherwise."""
        if not self.root:
            self.root = self._make_node(key)
            self.len += 1
            return True
        keys, child = self.keys, self.child
        self._set_search_splay(key)
        if keys[self.root] == key:
            return False
        # The root is now adjacent to key, so the new node takes over the root's
        # subtree on key's side and adopts the old root on the other side.
        node = self._make_node(key)
        f = key > keys[self.root]
        child[node << 1 | f] = child[self.root << 1 | f]
        child[node << 1 | f ^ 1] = self.root
        child[self.root << 1 | f] = 0
        self.root = node
        self.len += 1  # count every successful insert, not only the first
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return True if it was removed."""
        if not self.root:
            return False
        self._set_search_splay(key)
        keys, child = self.keys, self.child
        if keys[self.root] != key:
            return False
        if not child[self.root << 1]:
            self.root = child[self.root << 1 | 1]
        elif not child[self.root << 1 | 1]:
            self.root = child[self.root << 1]
        else:
            # Join: the right subtree's minimum (no left child after splaying)
            # adopts the old left subtree.
            node = self._get_min_splay(child[self.root << 1 | 1])
            child[node << 1] = child[self.root << 1]
            self.root = node
        self.len -= 1
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise KeyError if it is absent."""
        if self.discard(key):
            return
        raise KeyError(key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key >= ``key`` (None if none); splays the path."""

        def set_left(node: int) -> int:
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        if keys[node] == key:
            return key
        ge = None
        left, right = 0, 0
        while True:
            if keys[node] == key:
                ge = key
                break
            if key < keys[node]:
                # node is a candidate; a better one can only lie to the left.
                ge = keys[node]
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    ge = keys[node]
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
            else:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return ge

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key > ``key`` (None if none); splays the path."""

        def set_left(node: int) -> int:
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        gt = None
        left, right = 0, 0
        while True:
            if key < keys[node]:
                gt = keys[node]
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    gt = keys[node]
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
            else:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return gt

    def le(self, key: T) -> Optional[T]:
        """Return the largest key <= ``key`` (None if none); splays the path."""

        def set_left(node: int) -> int:
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        if keys[node] == key:
            return key
        le = None
        left, right = 0, 0
        while True:
            if keys[node] == key:
                le = key
                break
            if key > keys[node]:
                # node is a candidate; a better one can only lie to the right.
                le = keys[node]
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_left(node)
                    le = keys[node]
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
            else:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return le

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key < ``key`` (None if none); splays the path."""

        def set_left(node: int) -> int:
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        lt = None
        left, right = 0, 0
        while True:
            if key > keys[node]:
                lt = keys[node]
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_left(node)
                    lt = keys[node]
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
            else:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return lt

    def tolist(self) -> list[T]:
        """Return all keys in ascending order (iterative in-order walk)."""
        node = self.root
        child, keys = self.child, self.keys
        stack, res = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                res.append(keys[node])
                node = child[node << 1 | 1]
        return res

    def get_max(self) -> T:
        """Return the largest key; the set must be non-empty (asserted)."""
        assert self.root, "IndexError: get_max() from empty SplayTreeSetTopDown"
        self.root = self._get_max_splay(self.root)
        return self.keys[self.root]

    def get_min(self) -> T:
        """Return the smallest key; the set must be non-empty (asserted)."""
        assert self.root, "IndexError: get_min() from empty SplayTreeSetTopDown"
        self.root = self._get_min_splay(self.root)
        return self.keys[self.root]

    def pop_max(self) -> T:
        """Remove and return the largest key; the set must be non-empty."""
        assert self.root, "IndexError: pop_max() from empty SplayTreeSetTopDown"
        self.len -= 1
        node = self._get_max_splay(self.root)
        self.root = self.child[node << 1]
        return self.keys[node]

    def pop_min(self) -> T:
        """Remove and return the smallest key; the set must be non-empty."""
        assert self.root, "IndexError: pop_min() from empty SplayTreeSetTopDown"
        self.len -= 1
        node = self._get_min_splay(self.root)
        self.root = self.child[node << 1 | 1]
        return self.keys[node]

    def clear(self) -> None:
        """Remove every key.

        Node storage is kept (and not reused); the root and the size are reset.
        """
        self.root = 0
        self.len = 0

    def __iter__(self):
        """Start ascending iteration; the set acts as its own iterator.

        Iteration state lives on the instance, so nested iterations over the
        same set interfere with each other.
        """
        # get_min() asserts on an empty set; an empty set simply yields nothing.
        self.it = self.get_min() if self.root else None
        return self

    def __next__(self):
        if self.it is None:
            raise StopIteration
        res = self.it
        self.it = self.gt(res)
        return res

    def __contains__(self, key: T) -> bool:
        # Without this guard an empty tree would compare against the sentinel
        # key keys[0] (== e), making ``e in SplayTreeSetTopDown()`` True.
        if not self.root:
            return False
        self._set_search_splay(key)
        return self.keys[self.root] == key

    def __len__(self):
        return self.len

    def __bool__(self):
        return self.root != 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"SplayTreeSetTopDown({self})"
from typing import Iterable


class RangeSet:
    """Multiset of ints supporting single-point add/discard and ``mex``.

    Only point insertion, point removal and mex queries are provided.

    Multiplicities live in ``dic``. Each maximal run ``[l, r]`` of consecutive
    non-negative values that are present is stored in the splay tree as the two
    markers ``(l, 1)`` and ``(r, 2)``; a single-value run stores both ``(v, 1)``
    and ``(v, 2)``. Negative values are counted in ``dic`` but never enter the
    tree.
    """

    def __init__(self, a: Iterable[int] = ()):
        # Immutable default instead of a shared list; ``a`` is only read.
        self.data: SplayTreeSetTopDown[tuple[int, int]] = SplayTreeSetTopDown()
        self.dic: dict[int, int] = {}  # value -> multiplicity
        for a_ in sorted(a):
            self.add(a_)

    def add(self, x: int) -> None:
        """Insert one copy of ``x``."""
        if x in self.dic:
            self.dic[x] += 1
            return
        self.dic[x] = 1
        if x < 0:
            return
        # Smallest marker >= (x - 1, 2): the end marker of a run ending at
        # x - 1 if there is one, otherwise the first marker to the right of x.
        t = self.data.ge((x - 1, 2))
        if t is not None and t[0] == x - 1:
            # x extends the run ending at x - 1: drop that end marker.
            self.data.discard(t)
            t = self.data.gt(t)
        else:
            # x starts a new run.
            self.data.add((x, 1))
        if t is not None and t[0] == x + 1:
            # A run starts at x + 1: merge by dropping its start marker.
            self.data.discard(t)
        else:
            # Otherwise the run now ends at x.
            self.data.add((x, 2))

    def discard(self, x: int) -> None:
        """Remove one copy of ``x``; do nothing if ``x`` is absent."""
        if x not in self.dic:
            return
        if self.dic[x] > 1:
            self.dic[x] -= 1
            return
        del self.dic[x]
        if x < 0:
            return
        # Largest marker < (x, 2): the start marker of the run containing x.
        t = self.data.lt((x, 2))
        if t[0] == x:
            # The run started at x: drop its start marker.
            self.data.discard(t)
        else:
            # The left part of the run now ends at x - 1.
            self.data.add((x - 1, 2))
        # Smallest marker >= (x, 2): the end marker of the run containing x.
        t = self.data.ge((x, 2))
        if t[0] == x:
            # The run ended at x: drop its end marker.
            self.data.discard(t)
        else:
            # The right part of the run now starts at x + 1.
            self.data.add((x + 1, 1))

    def mex(self) -> int:
        """Return the smallest non-negative integer with zero copies."""
        t = self.data.ge((0, -1))  # smallest marker overall
        if t is None or t[0] != 0:
            return 0
        # A run starts at 0; the marker right after its start is its end.
        t = self.data.gt(t)
        return t[0] + 1

    def count(self, x: int) -> int:
        """Return the multiplicity of ``x`` (0 when absent)."""
        return self.dic.get(x, 0)

    def __contains__(self, x: int) -> bool:
        return x in self.dic

    def __str__(self):
        """Render every stored copy in ascending order, e.g. ``{1, 1, 3}``."""
        return (
            "{"
            + ", ".join(
                map(str, sorted(k for k, v in self.dic.items() for _ in range(v)))
            )
            + "}"
        )

仕様