splay_tree_set_top_down

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_set_top_down import SplayTreeSetTopDown

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.splay_tree_set_top_down import SplayTreeSetTopDown
  2raise NotImplementedError
  3# https://atcoder.jp/contests/abc370/submissions/59284972
  4
  5# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  6from typing import Protocol
  7
  8
  9class SupportsLessThan(Protocol):
 10
 11    def __lt__(self, other) -> bool: ...
 12# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
 13# from titan_pylib.my_class.supports_less_than import SupportsLessThan
 14from abc import ABC, abstractmethod
 15from typing import Iterable, Optional, Iterator, TypeVar, Generic
 16
 17T = TypeVar("T", bound=SupportsLessThan)
 18
 19
 20class OrderedSetInterface(ABC, Generic[T]):
 21
 22    @abstractmethod
 23    def __init__(self, a: Iterable[T]) -> None:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def add(self, key: T) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def discard(self, key: T) -> bool:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def remove(self, key: T) -> None:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def le(self, key: T) -> Optional[T]:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def lt(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def ge(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def gt(self, key: T) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def get_max(self) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def get_min(self) -> Optional[T]:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def pop_max(self) -> T:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def pop_min(self) -> T:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def clear(self) -> None:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def tolist(self) -> list[T]:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def __iter__(self) -> Iterator:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __next__(self) -> T:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __contains__(self, key: T) -> bool:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __len__(self) -> int:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __bool__(self) -> bool:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __str__(self) -> str:
100        raise NotImplementedError
101
102    @abstractmethod
103    def __repr__(self) -> str:
104        raise NotImplementedError
from array import array
from typing import Optional, Generic, Iterable, TypeVar

T = TypeVar("T", bound=SupportsLessThan)


class SplayTreeSetTopDown(OrderedSetInterface, Generic[T]):
    """Ordered set of distinct keys backed by a top-down splay tree.

    Storage layout (parallel arrays of node indices, no per-node objects):

    * ``keys[i]`` is the key stored in node ``i``.
    * ``child[2*i]`` / ``child[2*i + 1]`` are the left / right child indices
      of node ``i``; ``0`` means "no child".
    * Node ``0`` is a sentinel (``keys[0] == e``) and never holds an element.
      Its child slots ``child[0]`` / ``child[1]`` double as the scratch header
      of the top-down splay: ``child[1]`` collects the root of the "left" tree
      (keys below the target) and ``child[0]`` the root of the "right" tree
      (keys above the target).
    * ``end`` is the next unused node index; removed nodes are not recycled
      (only ``clear`` releases storage).

    Every lookup (``in``, ``ge``/``gt``/``le``/``lt``, ``get_min``/``get_max``)
    splays, so even read-only queries restructure the tree.
    """

    def __init__(self, a: Iterable[T] = (), e: T = 0):
        """Build the set from ``a``; duplicates are dropped.

        Args:
            a: initial elements, in any order.
            e: filler key stored in the sentinel node 0 and in reserved but
               unused slots; it is never reported as an element.
        """
        self.keys: list[T] = [e]
        self.child = array("I", [0, 0])
        self.end: int = 1
        self.root: int = 0
        self.len: int = 0
        self.e: T = e
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Fill the (empty) tree with the elements of ``a`` as a balanced BST."""

        def rec(l: int, r: int) -> int:
            # Root node range [l, r) at its midpoint and link both halves.
            mid = (l + r) >> 1
            if l != mid:
                child[mid << 1] = rec(l, mid)
            if mid + 1 != r:
                child[mid << 1 | 1] = rec(mid + 1, r)
            return mid

        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            # Not strictly increasing: sort, then drop duplicates.
            a = sorted(a)
            b = [a[0]]
            for x in a:
                if b[-1] == x:
                    continue
                b.append(x)
            a = b
        n = len(a)
        key, child = self.keys, self.child
        self.reserve(n - len(key) + 2)
        self.end += n
        # Nodes 1..n hold the sorted keys, so in-order == index order.
        key[1 : n + 1] = a
        self.root = rec(1, n + 1)
        self.len = n

    def _make_node(self, key: T) -> int:
        """Allocate a childless node holding ``key`` and return its index."""
        keys, child = self.keys, self.child
        idx = self.end
        if idx >= len(keys):
            keys.append(key)
            child.append(0)
            child.append(0)
        else:
            # Reuse a slot pre-allocated by ``reserve``; zero its links so the
            # node is guaranteed childless.
            keys[idx] = key
            child[idx << 1] = 0
            child[idx << 1 | 1] = 0
        self.end = idx + 1
        return idx

    def _rotate_right(self, node: int) -> int:
        """Rotate ``node`` with its left child; return the new subtree root."""
        child = self.child
        u = child[node << 1]
        child[node << 1] = child[u << 1 | 1]
        child[u << 1 | 1] = node
        return u

    def _rotate_left(self, node: int) -> int:
        """Rotate ``node`` with its right child; return the new subtree root."""
        child = self.child
        u = child[node << 1 | 1]
        child[node << 1 | 1] = child[u << 1]
        child[u << 1] = node
        return u

    def _set_search_splay(self, key: T) -> None:
        """Top-down splay toward ``key`` and make the result the root.

        The new root is the node holding ``key`` or, when ``key`` is absent,
        the last node on its search path (its predecessor or successor).
        No-op on an empty tree or when the root already holds ``key``.
        """

        def set_left(node: int) -> int:
            # Hang ``node`` as the new maximum of the left tree; descend right.
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            # Hang ``node`` as the new minimum of the right tree; descend left.
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        keys, child = self.keys, self.child
        if (not node) or keys[node] == key:
            return
        left, right = 0, 0
        while keys[node] != key:
            if key > keys[node]:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    # zig-zig: rotate first, then link.
                    node = self._rotate_left(node)
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    # zig-zag: link into both side trees.
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
            else:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
        # Reassemble: side trees adopt node's subtrees, node adopts the side
        # trees (left-tree root in child[1], right-tree root in child[0]).
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns the new subtree root (which has no left child). An empty
        subtree (``node == 0``) is returned unchanged. Overwrites the scratch
        slots ``child[0]`` / ``child[1]``.
        """
        child = self.child
        if (not node) or (not child[node << 1]):
            return node
        right = 0
        while child[node << 1]:
            node = self._rotate_right(node)
            if not child[node << 1]:
                break
            child[right << 1] = node
            right = node
            node = child[node << 1]
        child[right << 1] = child[node << 1 | 1]
        child[1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        return node

    def _get_max_splay(self, node: int) -> int:
        """Splay the maximum of the subtree rooted at ``node`` to its top.

        Returns the new subtree root (which has no right child). An empty
        subtree (``node == 0``) is returned unchanged. Overwrites the scratch
        slots ``child[0]`` / ``child[1]``.
        """
        child = self.child
        if (not node) or (not child[node << 1 | 1]):
            return node
        left = 0
        while child[node << 1 | 1]:
            node = self._rotate_left(node)
            if not child[node << 1 | 1]:
                break
            child[left << 1 | 1] = node
            left = node
            node = child[node << 1 | 1]
        child[0] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        return node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` more nodes (``n`` must be >= 0)."""
        assert n >= 0, "ValueError"
        self.keys += [self.e] * n
        # Two child links per node; explicit zero count instead of a byte
        # count, so this does not depend on the platform size of "I".
        self.child += array("I", [0]) * (2 * n)

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            ``True`` if ``key`` was inserted, ``False`` if it was already
            present. In both cases ``key`` ends up at the root.
        """
        if not self.root:
            self.root = self._make_node(key)
            self.len += 1
            return True
        keys, child = self.keys, self.child
        self._set_search_splay(key)
        root = self.root
        if keys[root] == key:
            return False
        # The splayed root is key's predecessor or successor: the new node
        # becomes the root, takes the old root's subtree on key's side, and
        # keeps the old root on the other side.
        node = self._make_node(key)
        f = key > keys[root]
        child[node << 1 | f] = child[root << 1 | f]
        child[node << 1 | f ^ 1] = root
        child[root << 1 | f] = 0
        self.root = node
        self.len += 1  # count the newly inserted element
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; return ``True`` if it was removed."""
        if not self.root:
            return False
        self._set_search_splay(key)
        keys, child = self.keys, self.child
        if keys[self.root] != key:
            return False
        if not child[self.root << 1]:
            self.root = child[self.root << 1 | 1]
        elif not child[self.root << 1 | 1]:
            self.root = child[self.root << 1]
        else:
            # Join: the right subtree's minimum has no left child, so it can
            # adopt the old root's left subtree.
            node = self._get_min_splay(child[self.root << 1 | 1])
            child[node << 1] = child[self.root << 1]
            self.root = node
        self.len -= 1
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            KeyError: if ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None`` if none exists.

        Splays along the search path for ``key`` (no splay when the root
        already holds ``key``).
        """

        def set_left(node: int) -> int:
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        if keys[node] == key:
            return key
        ge = None
        left, right = 0, 0
        while True:
            if keys[node] == key:
                ge = key
                break
            if key < keys[node]:
                # node is a candidate; any later candidate is smaller.
                ge = keys[node]
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    ge = keys[node]
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
            else:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return ge

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None`` if none exists.

        Splays along the search path for ``key``.
        """

        def set_left(node: int) -> int:
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        gt = None
        left, right = 0, 0
        while True:
            if key < keys[node]:
                gt = keys[node]
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    gt = keys[node]
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
            else:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return gt

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None`` if none exists.

        Splays along the search path for ``key`` (no splay when the root
        already holds ``key``).
        """

        def set_left(node: int) -> int:
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        if keys[node] == key:
            return key
        le = None
        left, right = 0, 0
        while True:
            if keys[node] == key:
                le = key
                break
            if key > keys[node]:
                # node is a candidate; any later candidate is larger.
                le = keys[node]
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_left(node)
                    le = keys[node]
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
            else:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return le

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None`` if none exists.

        Splays along the search path for ``key``.
        """

        def set_left(node: int) -> int:
            nonlocal left
            child[left << 1 | 1] = node
            left = node
            return child[node << 1 | 1]

        def set_right(node: int) -> int:
            nonlocal right
            child[right << 1] = node
            right = node
            return child[node << 1]

        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        lt = None
        left, right = 0, 0
        while True:
            if key > keys[node]:
                lt = keys[node]
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_left(node)
                    lt = keys[node]
                    if not child[node << 1 | 1]:
                        break
                    node = set_left(node)
                elif (
                    child[child[node << 1 | 1] << 1]
                    and key < keys[child[child[node << 1 | 1] << 1]]
                ):
                    node = set_left(node)
                    node = set_right(node)
                else:
                    node = set_left(node)
            else:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1]:
                        break
                    node = set_right(node)
                elif (
                    child[child[node << 1] << 1 | 1]
                    and key > keys[child[child[node << 1] << 1 | 1]]
                ):
                    node = set_right(node)
                    node = set_left(node)
                else:
                    node = set_right(node)
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return lt

    def tolist(self) -> list[T]:
        """Return all elements in ascending order (iterative in-order walk, no splay)."""
        node = self.root
        child, keys = self.child, self.keys
        stack, res = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                res.append(keys[node])
                node = child[node << 1 | 1]
        return res

    def get_max(self) -> T:
        """Return the largest element, splaying it to the root (asserts non-empty)."""
        assert self.root, "IndexError: get_max() from empty SplayTreeSetTopDown"
        self.root = self._get_max_splay(self.root)
        return self.keys[self.root]

    def get_min(self) -> T:
        """Return the smallest element, splaying it to the root (asserts non-empty)."""
        assert self.root, "IndexError: get_min() from empty SplayTreeSetTopDown"
        self.root = self._get_min_splay(self.root)
        return self.keys[self.root]

    def pop_max(self) -> T:
        """Remove and return the largest element (asserts non-empty)."""
        assert self.root, "IndexError: pop_max() from empty SplayTreeSetTopDown"
        self.len -= 1
        node = self._get_max_splay(self.root)
        # The splayed maximum has no right child; its left subtree is the rest.
        self.root = self.child[node << 1]
        return self.keys[node]

    def pop_min(self) -> T:
        """Remove and return the smallest element (asserts non-empty)."""
        assert self.root, "IndexError: pop_min() from empty SplayTreeSetTopDown"
        self.len -= 1
        node = self._get_min_splay(self.root)
        # The splayed minimum has no left child; its right subtree is the rest.
        self.root = self.child[node << 1 | 1]
        return self.keys[node]

    def clear(self) -> None:
        """Remove every element and release all node storage."""
        self.keys = [self.e]
        self.child = array("I", [0, 0])
        self.end = 1
        self.root = 0
        self.len = 0

    def __iter__(self):
        """Start an ascending iteration; the set object is its own iterator."""
        # An empty set must yield nothing instead of tripping get_min's assert.
        self.it = self.get_min() if self.root else None
        return self

    def __next__(self):
        """Return the next element in ascending order, or raise StopIteration."""
        if self.it is None:
            raise StopIteration
        res = self.it
        self.it = self.gt(res)
        return res

    def __contains__(self, key: T):
        """Return whether ``key`` is in the set (splays toward ``key``)."""
        if not self.root:
            # Empty tree: keys[0] is the sentinel filler ``e``, not an element.
            return False
        self._set_search_splay(key)
        return self.keys[self.root] == key

    def __len__(self):
        """Return the number of elements."""
        return self.len

    def __bool__(self):
        """Return whether the set is non-empty."""
        return self.root != 0

    def __str__(self):
        """Render as ``{a, b, ...}`` in ascending order."""
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"SplayTreeSetTopDown({self})"

仕様