splay_tree_multiset_top_down

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_multiset_top_down import SplayTreeMultisetTopDown

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.splay_tree_multiset_top_down import SplayTreeMultisetTopDown
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9# from titan_pylib.my_class.ordered_multiset_interface import OrderedMultisetInterface
 10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
 11from abc import ABC, abstractmethod
 12from typing import Iterable, Optional, Iterator, TypeVar, Generic
 13
 14T = TypeVar("T", bound=SupportsLessThan)
 15
 16
 17class OrderedMultisetInterface(ABC, Generic[T]):
 18
 19    @abstractmethod
 20    def __init__(self, a: Iterable[T]) -> None:
 21        raise NotImplementedError
 22
 23    @abstractmethod
 24    def add(self, key: T, cnt: int) -> None:
 25        raise NotImplementedError
 26
 27    @abstractmethod
 28    def discard(self, key: T, cnt: int) -> bool:
 29        raise NotImplementedError
 30
 31    @abstractmethod
 32    def discard_all(self, key: T) -> bool:
 33        raise NotImplementedError
 34
 35    @abstractmethod
 36    def count(self, key: T) -> int:
 37        raise NotImplementedError
 38
 39    @abstractmethod
 40    def remove(self, key: T, cnt: int) -> None:
 41        raise NotImplementedError
 42
 43    @abstractmethod
 44    def le(self, key: T) -> Optional[T]:
 45        raise NotImplementedError
 46
 47    @abstractmethod
 48    def lt(self, key: T) -> Optional[T]:
 49        raise NotImplementedError
 50
 51    @abstractmethod
 52    def ge(self, key: T) -> Optional[T]:
 53        raise NotImplementedError
 54
 55    @abstractmethod
 56    def gt(self, key: T) -> Optional[T]:
 57        raise NotImplementedError
 58
 59    @abstractmethod
 60    def get_max(self) -> Optional[T]:
 61        raise NotImplementedError
 62
 63    @abstractmethod
 64    def get_min(self) -> Optional[T]:
 65        raise NotImplementedError
 66
 67    @abstractmethod
 68    def pop_max(self) -> T:
 69        raise NotImplementedError
 70
 71    @abstractmethod
 72    def pop_min(self) -> T:
 73        raise NotImplementedError
 74
 75    @abstractmethod
 76    def clear(self) -> None:
 77        raise NotImplementedError
 78
 79    @abstractmethod
 80    def tolist(self) -> list[T]:
 81        raise NotImplementedError
 82
 83    @abstractmethod
 84    def __iter__(self) -> Iterator:
 85        raise NotImplementedError
 86
 87    @abstractmethod
 88    def __next__(self) -> T:
 89        raise NotImplementedError
 90
 91    @abstractmethod
 92    def __contains__(self, key: T) -> bool:
 93        raise NotImplementedError
 94
 95    @abstractmethod
 96    def __len__(self) -> int:
 97        raise NotImplementedError
 98
 99    @abstractmethod
100    def __bool__(self) -> bool:
101        raise NotImplementedError
102
103    @abstractmethod
104    def __str__(self) -> str:
105        raise NotImplementedError
106
107    @abstractmethod
108    def __repr__(self) -> str:
109        raise NotImplementedError
110from array import array
111from typing import Optional, Generic, Iterable, Sequence, TypeVar
112
113T = TypeVar("T", bound=SupportsLessThan)
114
115
class SplayTreeMultisetTopDown(OrderedMultisetInterface, Generic[T]):
    """Ordered multiset backed by a top-down splay tree stored in flat arrays.

    Node ``i`` keeps its key in ``keys[i]``, its multiplicity in ``vals[i]``
    and its left / right child indices in ``child[2 * i]`` / ``child[2 * i + 1]``.
    Index ``0`` is the "no node" sentinel.  While a splay is in progress the
    sentinel's two child slots are used as scratch space: ``child[0]`` holds
    the root of the partial tree of larger keys and ``child[1]`` the root of
    the partial tree of smaller keys.

    Most queries (``count``, ``ge``, ``lt``, ``in`` ...) restructure the tree
    as a side effect, so even read-style operations mutate the object.
    """

    def __init__(self, a: Iterable[T] = [], e: T = 0):
        """Create the multiset from the elements of ``a``.

        Args:
            a: Initial elements; they are sorted first if not already
                non-decreasing.
            e: Placeholder stored in the sentinel slot and in reserved but
                unused key slots.
        """
        # NOTE: the list default is only read (sorting builds a new list),
        # so sharing it between calls is harmless; kept for compatibility.
        self.keys: list[T] = [e]
        self.vals: list[int] = [0]
        # Zero-filled child slots for the sentinel node 0.
        self.child = array("I", bytes(8))
        self.end: int = 1  # index of the next unused node slot
        self.root: int = 0
        self.len: int = 0  # number of elements counted with multiplicity
        self.e: T = e
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _rle(self, a: Sequence[T]) -> tuple[list[T], list[int]]:
        """Run-length encode the non-empty sequence ``a``.

        Returns:
            ``(x, y)`` where ``x`` lists each run's value and ``y`` the
            matching run lengths.
        """
        x = []
        y = []
        x.append(a[0])
        y.append(1)
        for i, e in enumerate(a):
            if i == 0:
                continue
            if e == x[-1]:
                y[-1] += 1
                continue
            x.append(e)
            y.append(1)
        return x, y

    def _build(self, a: Sequence[T]) -> None:
        """Build the tree from the non-empty sequence ``a`` (called by ``__init__``)."""

        def rec(l: int, r: int) -> int:
            # Make the middle slot of [l, r) the subtree root and recurse on
            # both halves, which yields a balanced tree over sorted slots.
            mid = (l + r) >> 1
            if l != mid:
                child[mid << 1] = rec(l, mid)
            if mid + 1 != r:
                child[mid << 1 | 1] = rec(mid + 1, r)
            return mid

        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        x, y = self._rle(a)
        n = len(x)
        keys, vals, child = self.keys, self.vals, self.child
        # reserve() extends the three containers in place, so the local
        # aliases bound above stay valid.
        self.reserve(n - len(keys) + 2)
        self.end += n
        keys[1 : n + 1] = x
        vals[1 : n + 1] = y
        self.root = rec(1, n + 1)
        self.len = len(a)

    def _make_node(self, key: T, val: int) -> int:
        """Store ``key`` with multiplicity ``val`` in a fresh slot and return its index."""
        if self.end >= len(self.keys):
            self.keys.append(key)
            self.vals.append(val)
            self.child.append(0)
            self.child.append(0)
        else:
            # Reuse a slot previously allocated by reserve().
            self.keys[self.end] = key
            self.vals[self.end] = val
        self.end += 1
        return self.end - 1

    def _rotate_left(self, node: int) -> int:
        """Lift the left child of ``node`` above it and return the lifted node."""
        child = self.child
        u = child[node << 1]
        child[node << 1] = child[u << 1 | 1]
        child[u << 1 | 1] = node
        return u

    def _rotate_right(self, node: int) -> int:
        """Lift the right child of ``node`` above it and return the lifted node."""
        child = self.child
        u = child[node << 1 | 1]
        child[node << 1 | 1] = child[u << 1]
        child[u << 1] = node
        return u

    def _set_search_splay(self, key: T) -> None:
        """Splay the node holding ``key`` to the root.

        If ``key`` is absent, the last node visited on the search path
        becomes the root instead.  Does nothing on an empty tree.
        """
        node = self.root
        keys, child = self.keys, self.child
        if (not node) or keys[node] == key:
            return
        # left / right: attachment points of the partial trees of smaller /
        # larger keys; 0 means "still empty" (the sentinel's slots are used).
        left, right = 0, 0
        while keys[node] != key:
            f = key > keys[node]
            if not child[node << 1 | f]:
                break
            if f:
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
            else:
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
            node = child[node << 1 | f]
        # Reassemble: hang node's subtrees on the partial trees, then make
        # the partial trees node's children.
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns:
            The index of the new subtree root (``node`` itself when it is 0
            or already has no left child).
        """
        child = self.child
        if (not node) or (not child[node << 1]):
            return node
        right = 0
        while child[node << 1]:
            node = self._rotate_left(node)
            if not child[node << 1]:
                break
            child[right << 1] = node
            right = node
            node = child[node << 1]
        child[right << 1] = child[node << 1 | 1]
        # Only left steps were taken, so the partial tree of smaller keys is
        # just node's (empty) left subtree.
        child[1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        return node

    def _get_max_splay(self, node: int) -> int:
        """Splay the maximum of the subtree rooted at ``node`` to its top.

        Returns:
            The index of the new subtree root (``node`` itself when it is 0
            or already has no right child).
        """
        child = self.child
        if (not node) or (not child[node << 1 | 1]):
            return node
        left = 0
        while child[node << 1 | 1]:
            node = self._rotate_right(node)
            if not child[node << 1 | 1]:
                break
            child[left << 1 | 1] = node
            left = node
            node = child[node << 1 | 1]
        # Mirror image of _get_min_splay: only right steps were taken.
        child[0] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        return node

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional nodes.

        ``n`` must be non-negative (checked with an ``assert``).
        """
        assert n >= 0, "ValueError"
        self.keys += [self.e] * n
        self.vals += [0] * n
        self.child += array("I", bytes(8 * n))

    def add(self, key: T, val: int = 1) -> bool:
        """Insert ``val`` copies of ``key``.

        Note: every path returns ``None``; the ``bool`` annotation is kept
        unchanged for interface compatibility.
        """
        self.len += val
        if not self.root:
            self.root = self._make_node(key, val)
            return
        keys, vals, child = self.keys, self.vals, self.child
        self._set_search_splay(key)
        if keys[self.root] == key:
            vals[self.root] += val
            return
        # key is absent: the new node becomes the root, taking over the old
        # root's subtree on the far side and the old root on the near side.
        node = self._make_node(key, val)
        f = key > keys[self.root]
        child[node << 1 | f] = child[self.root << 1 | f]
        child[node << 1 | f ^ 1] = self.root
        child[self.root << 1 | f] = 0
        self.root = node

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        If no more than ``val`` copies are stored, the key is removed
        entirely.

        Returns:
            ``True`` if ``key`` was present, ``False`` otherwise.
        """
        if not self.root:
            return False
        self._set_search_splay(key)
        keys, vals, child = self.keys, self.vals, self.child
        if keys[self.root] != key:
            return False
        if vals[self.root] > val:
            vals[self.root] -= val
            self.len -= val
            return True
        self.len -= vals[self.root]
        if not child[self.root << 1]:
            self.root = child[self.root << 1 | 1]
        elif not child[self.root << 1 | 1]:
            self.root = child[self.root << 1]
        else:
            # Join: the minimum of the right subtree has no left child once
            # splayed, so the left subtree can be attached there.
            node = self._get_min_splay(child[self.root << 1 | 1])
            child[node << 1] = child[self.root << 1]
            self.root = node
        return True

    def discard_all(self, key: T) -> bool:
        """Remove every copy of ``key``; return whether it was present."""
        return self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Remove ``val`` copies of ``key``.

        Raises:
            KeyError: If fewer than ``val`` copies are stored.
        """
        c = self.count(key)
        if c < val:
            raise KeyError(key)
        self.discard(key, val)

    def count(self, key: T) -> int:
        """Return the number of stored copies of ``key`` (0 if absent)."""
        if not self.root:
            return 0
        self._set_search_splay(key)
        root = self.root
        # When key is absent the splay leaves a neighbouring key at the
        # root, so its multiplicity must not be reported.
        return self.vals[root] if self.keys[root] == key else 0

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None`` if none exists."""
        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        if keys[node] == key:
            return key
        ge = None
        left, right = 0, 0
        while True:
            if keys[node] == key:
                ge = key
                break
            if key < keys[node]:
                # Candidate answer; smaller candidates can only be on the left.
                ge = keys[node]
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    ge = keys[node]
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
                node = child[node << 1]
            else:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
                node = child[node << 1 | 1]
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return ge

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None`` if none exists."""
        node = self.root
        if not node:
            return None
        gt = None
        keys, child = self.keys, self.child
        left, right = 0, 0
        while True:
            if key < keys[node]:
                # Candidate answer; smaller candidates can only be on the left.
                gt = keys[node]
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    gt = keys[node]
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
                node = child[node << 1]
            else:
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
                node = child[node << 1 | 1]
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return gt

    def le(self, key: T) -> Optional[T]:
        """Return the greatest stored key ``<= key``, or ``None`` if none exists."""
        node = self.root
        if not node:
            return None
        keys, child = self.keys, self.child
        if keys[node] == key:
            return key
        le = None
        left, right = 0, 0
        while True:
            if keys[node] == key:
                le = key
                break
            if key < keys[node]:
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
                node = child[node << 1]
            else:
                # Candidate answer; larger candidates can only be on the right.
                le = keys[node]
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    le = keys[node]
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
                node = child[node << 1 | 1]
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return le

    def lt(self, key: T) -> Optional[T]:
        """Return the greatest stored key ``< key``, or ``None`` if none exists."""
        node = self.root
        if not node:
            return None
        lt = None
        keys, child = self.keys, self.child
        left, right = 0, 0
        while True:
            if not keys[node] < key:
                # keys[node] >= key: not a candidate, the answer (if any)
                # lies in the left subtree.
                if not child[node << 1]:
                    break
                if key < keys[child[node << 1]]:
                    node = self._rotate_left(node)
                    if not child[node << 1]:
                        break
                child[right << 1] = node
                right = node
                node = child[node << 1]
            else:
                # keys[node] < key: candidate answer; larger candidates can
                # only be on the right.
                lt = keys[node]
                if not child[node << 1 | 1]:
                    break
                if key > keys[child[node << 1 | 1]]:
                    node = self._rotate_right(node)
                    lt = keys[node]
                    if not child[node << 1 | 1]:
                        break
                child[left << 1 | 1] = node
                left = node
                node = child[node << 1 | 1]
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self.root = node
        return lt

    def tolist(self) -> list[T]:
        """Return all elements in ascending order, each repeated by its multiplicity."""
        node = self.root
        child, vals, keys = self.child, self.vals, self.keys
        # Iterative in-order traversal with an explicit stack.
        stack, res = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                for _ in range(vals[node]):
                    res.append(keys[node])
                node = child[node << 1 | 1]
        return res

    def get_max(self) -> T:
        """Return the largest key; the multiset must be non-empty (asserted)."""
        assert self.root, "IndexError: get_max() from empty SplayTreeMultisetTopDown"
        self.root = self._get_max_splay(self.root)
        return self.keys[self.root]

    def get_min(self) -> T:
        """Return the smallest key; the multiset must be non-empty (asserted)."""
        assert self.root, "IndexError: get_min() from empty SplayTreeMultisetTopDown"
        self.root = self._get_min_splay(self.root)
        return self.keys[self.root]

    def pop_max(self) -> T:
        """Remove one copy of the largest key and return it (non-empty asserted)."""
        assert self.root, "IndexError: pop_max() from empty SplayTreeMultisetTopDown"
        node = self._get_max_splay(self.root)
        self.len -= 1
        if self.vals[node] > 1:
            self.vals[node] -= 1
            self.root = node
        else:
            # The maximum has no right child, so its left subtree is the rest.
            self.root = self.child[node << 1]
        return self.keys[node]

    def pop_min(self) -> T:
        """Remove one copy of the smallest key and return it (non-empty asserted)."""
        assert self.root, "IndexError: pop_min() from empty SplayTreeMultisetTopDown"
        node = self._get_min_splay(self.root)
        self.len -= 1
        if self.vals[node] > 1:
            self.vals[node] -= 1
            self.root = node
        else:
            # The minimum has no left child, so its right subtree is the rest.
            self.root = self.child[node << 1 | 1]
        return self.keys[node]

    def clear(self) -> None:
        """Empty the multiset; already allocated node slots are not reclaimed."""
        self.root = 0
        self.len = 0

    def __contains__(self, key: T):
        """Return whether at least one copy of ``key`` is stored."""
        if not self.root:
            # Empty tree: keys[0] is only the placeholder ``e`` and must not
            # be compared against ``key``.
            return False
        self._set_search_splay(key)
        return self.keys[self.root] == key

    def __iter__(self):
        """Iteration is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def __next__(self):
        """Iteration is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of elements, counted with multiplicity."""
        return self.len

    def __bool__(self):
        """Return whether the tree has a root, i.e. the multiset is non-empty."""
        return self.root != 0

    def __str__(self):
        """Return the sorted elements rendered as ``{a, b, ...}``."""
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        """Return ``SplayTreeMultisetTopDown({...})``."""
        return f"SplayTreeMultisetTopDown({self})"

仕様

class SplayTreeMultisetTopDown(a: Iterable[T] = [], e: T = 0)[source]

Bases: OrderedMultisetInterface, Generic[T]

add(key: T, val: int = 1) bool[source]
clear() None[source]
count(key: T) int[source]
discard(key: T, val: int = 1) bool[source]
discard_all(key: T) bool[source]
ge(key: T) T | None[source]
get_max() T[source]
get_min() T[source]
gt(key: T) T | None[source]
le(key: T) T | None[source]
lt(key: T) T | None[source]
pop_max() T[source]
pop_min() T[source]
remove(key: T, val: int = 1) None[source]
reserve(n: int) None[source]
tolist() list[T][source]