splay_tree_set

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_set import SplayTreeSet

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.splay_tree_set import SplayTreeSet
  2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedSetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T) -> bool:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def remove(self, key: T) -> None:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def le(self, key: T) -> Optional[T]:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def lt(self, key: T) -> Optional[T]:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def ge(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def gt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def get_max(self) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def get_min(self) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def pop_max(self) -> T:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def pop_min(self) -> T:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def clear(self) -> None:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def tolist(self) -> list[T]:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def __iter__(self) -> Iterator:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def __next__(self) -> T:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __contains__(self, key: T) -> bool:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __len__(self) -> int:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __bool__(self) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __str__(self) -> str:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __repr__(self) -> str:
100        raise NotImplementedError
101# from titan_pylib.my_class.supports_less_than import SupportsLessThan
102from array import array
103from typing import Generic, Iterable, TypeVar, Optional
104
T = TypeVar("T", bound=SupportsLessThan)


class SplayTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of distinct keys stored in an array-backed splay tree.

    Nodes are integer indices into parallel arrays:

    * ``keys[i]`` -- key held by node ``i``;
    * ``size[i]`` -- number of nodes in the subtree rooted at ``i``;
    * ``child[2*i]`` / ``child[2*i + 1]`` -- left / right child of ``i``.

    Index 0 is a null sentinel (size 0, no children, key ``e``).
    ``self.node`` is the index of the root, or 0 when the set is empty.

    Almost every operation splays the node it reaches to the root.  This
    includes read-only queries such as ``le``, ``__contains__`` and
    ``__getitem__``, so lookups change the tree shape as well as updates.

    Args:
        a: Initial elements; duplicates are dropped.
        e: Filler key stored in the sentinel and in reserved, unused slots.
            It is reported as a member only if it was actually inserted.
    """

    def __init__(self, a: Iterable[T] = [], e: T = 0):
        # ``a`` is only read (copied or rebound), never mutated, so the shared
        # list default is harmless.
        self.keys: list[T] = [e]
        self.size = array("I", bytes(4))  # size[0] == 0: the null sentinel
        self.child = array("I", bytes(8))  # child[0] == child[1] == 0
        self.end = 1  # index of the next free node slot
        self.node = 0  # root index; 0 means "empty"
        self.e = e
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: list[T]) -> None:
        """Fill the (still empty) tree with ``a`` as a perfectly balanced BST."""

        def build_range(lo: int, hi: int) -> int:
            # Link nodes lo..hi-1 into a balanced subtree and return its root.
            mid = (lo + hi) >> 1
            if lo != mid:
                child[mid << 1] = build_range(lo, mid)
            if mid + 1 != hi:
                child[mid << 1 | 1] = build_range(mid + 1, hi)
            size[mid] = 1 + size[child[mid << 1]] + size[child[mid << 1 | 1]]
            return mid

        # Only pay for dedup + sort when the input isn't already strictly
        # increasing.
        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(set(a))
        n = len(a)
        key, child, size = self.keys, self.child, self.size
        self.reserve(n - len(key) + 2)
        self.end += n
        # Node i (1 <= i <= n) holds the i-th smallest key, so splitting the
        # index range at its midpoint yields a valid BST.
        key[1 : n + 1] = a
        self.node = build_range(1, n + 1)

    def _make_node(self, key: T) -> int:
        """Allocate a detached leaf holding ``key`` and return its index."""
        end = self.end
        if end >= len(self.keys):
            self.keys.append(key)
            self.size.append(1)
            self.child.append(0)
            self.child.append(0)
        else:
            # Reuse a pre-allocated slot.  Reset it completely, because
            # clear() lets previously used slots be handed out again.
            self.keys[end] = key
            self.size[end] = 1
            self.child[end << 1] = 0
            self.child[end << 1 | 1] = 0
        self.end += 1
        return end

    def _update(self, node: int) -> None:
        """Recompute ``size[node]`` from its children's sizes."""
        self.size[node] = (
            1 + self.size[self.child[node << 1]] + self.size[self.child[node << 1 | 1]]
        )

    def _update_triple(self, x: int, y: int, z: int) -> None:
        """Fix sizes after a double rotation.

        The rotation made ``z`` the root of the subtree formerly rooted at
        ``x``.  ``z`` therefore inherits ``x``'s old size.  ``x`` is recomputed
        before ``y`` because afterwards ``y`` is either ``x``'s parent or its
        sibling.
        """
        size, child = self.size, self.child
        size[z] = size[x]
        size[x] = 1 + size[child[x << 1]] + size[child[x << 1 | 1]]
        size[y] = 1 + size[child[y << 1]] + size[child[y << 1 | 1]]

    def _splay(self, path: list[int], d: int) -> int:
        """Rotate a node up to the top of ``path`` and return its index.

        ``path`` holds the ancestors of the target node, top first, so
        ``path[-1]`` is its parent.  ``d`` records the direction of each step
        down the path, one bit per step, with the final step in the lowest
        bit: 1 means "went to the left child" and 0 "went to the right child".
        ``path`` is consumed.  The returned node replaces ``path[0]`` as the
        root of that subtree, and the caller must store it.
        """
        child = self.child
        g = d & 1  # direction parent -> target
        while len(path) > 1:
            node = path.pop()  # parent
            pnode = path.pop()  # grandparent
            f = d >> 1 & 1  # direction grandparent -> parent
            tmp = child[node << 1 | g ^ 1]  # the target
            # Zig-zig (g == f) and zig-zag (g != f) share this code path.
            child[node << 1 | g ^ 1] = child[tmp << 1 | g]
            child[tmp << 1 | f ^ (g != f)] = node
            child[pnode << 1 | f ^ 1] = child[(node if g == f else tmp) << 1 | f]
            child[(node if g == f else tmp) << 1 | f] = pnode
            self._update_triple(pnode, node, tmp)
            if not path:
                return tmp
            d >>= 2
            g = d & 1
            # Re-attach the rotated subtree below its new parent.
            child[path[-1] << 1 | g ^ 1] = tmp
        # One ancestor left: a final single rotation (zig).
        pnode = path[0]
        node = child[pnode << 1 | g ^ 1]
        child[pnode << 1 | g ^ 1] = child[node << 1 | g]
        child[node << 1 | g] = pnode
        self._update(pnode)
        self._update(node)
        return node

    def _set_search_splay(self, key: T) -> None:
        """Splay ``key``'s node to the root.

        If ``key`` is absent, splay the last node visited while searching
        for it instead; that node is ``key``'s predecessor or successor.
        Does nothing on an empty tree.
        """
        node = self.node
        keys, child = self.keys, self.child
        if (not node) or keys[node] == key:
            return
        path = []
        d = 0
        while keys[node] != key:
            nxt = child[node << 1 | (key > keys[node])]
            if not nxt:
                break
            path.append(node)
            d, node = d << 1 | (key < keys[node]), nxt
        if path:
            self.node = self._splay(path, d)

    def _set_kth_elm_splay(self, k: int) -> None:
        """Splay the ``k``-th smallest node (0-based) to the root.

        A negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is outside ``[-len, len)``.  Without this
                check the descent would run into the null sentinel and either
                corrupt it or loop forever.
        """
        size, child = self.size, self.child
        node = self.node
        n = size[node]
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("SplayTreeSet index out of range")
        d = 0
        path = []
        while True:
            t = size[child[node << 1]]  # rank of `node` inside its subtree
            if t == k:
                if path:
                    self.node = self._splay(path, d)
                return
            path.append(node)
            d, node = d << 1 | (t > k), child[node << 1 | (t < k)]
            if t < k:
                k -= t + 1

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns the new subtree root, or 0 if ``node`` is 0.
        """
        if not node:
            return 0
        child = self.child
        if not child[node << 1]:
            return node
        path = []
        while child[node << 1]:
            path.append(node)
            node = child[node << 1]
        return self._splay(path, (1 << len(path)) - 1)  # all steps: left

    def _get_max_splay(self, node: int) -> int:
        """Splay the maximum of the subtree rooted at ``node`` to its top.

        Returns the new subtree root, or 0 if ``node`` is 0.
        """
        if not node:
            return 0
        child = self.child
        if not child[node << 1 | 1]:
            return node
        path = []
        while child[node << 1 | 1]:
            path.append(node)
            node = child[node << 1 | 1]
        return self._splay(path, 0)  # all steps: right

    def _delete_root(self) -> None:
        """Unlink the current (non-null) root and join its two subtrees."""
        child = self.child
        root = self.node
        left, right = child[root << 1], child[root << 1 | 1]
        if not left:
            self.node = right
        elif not right:
            self.node = left
        else:
            # Once splayed to the top of the right subtree, its minimum has no
            # left child, so the whole left subtree can hang there.
            node = self._get_min_splay(right)
            child[node << 1] = left
            self._update(node)
            self.node = node

    def reserve(self, n: int) -> None:
        """Pre-allocate ``n`` additional node slots."""
        assert n >= 0, f"ValueError: SplayTreeSet.reserve({n})"
        self.keys += [self.e] * n
        self.size += array("I", [1] * n)
        self.child += array("I", bytes(8 * n))

    def add(self, key: T) -> bool:
        """Insert ``key``; return True if it was not already present."""
        keys, child = self.keys, self.child
        if not self.node:
            # Empty tree: do not compare against the sentinel key ``e``.
            self.node = self._make_node(key)
            return True
        self._set_search_splay(key)
        if keys[self.node] == key:
            return False
        node = self._make_node(key)
        # The root is now key's neighbour.  Put the new node above it, taking
        # over the root's subtree on key's side (f: 1 = right, 0 = left).
        f = key > keys[self.node]
        child[node << 1 | f ^ 1] = self.node
        child[node << 1 | f] = child[self.node << 1 | f]
        child[self.node << 1 | f] = 0
        self._update(child[node << 1 | f ^ 1])
        self._update(node)
        self.node = node
        return True

    def discard(self, key: T) -> bool:
        """Delete ``key`` if present; return whether it was present."""
        if not self.node:
            return False
        self._set_search_splay(key)
        if self.keys[self.node] != key:
            return False
        self._delete_root()
        return True

    def remove(self, key: T) -> None:
        """Delete ``key``.

        Raises:
            KeyError: if ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest key ``<= key``, or None; splays the search path."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            if keys[node] == key:
                res = keys[node]
                break
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key > keys[node]:
                res = keys[node]
            node = child[node << 1 | (key > keys[node])]
        else:
            # Fell off the tree: splay the last real node instead.
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest key ``< key``, or None; splays the search path."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            path.append(node)
            d = d << 1 | (key <= keys[node])
            if key > keys[node]:
                res = keys[node]
            node = child[node << 1 | (key > keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest key ``>= key``, or None; splays the search path."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            if keys[node] == key:
                res = keys[node]
                break
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key < keys[node]:
                res = keys[node]
            node = child[node << 1 | (key >= keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest key ``> key``, or None; splays the search path."""
        node = self.node
        keys, child = self.keys, self.child
        path = []
        d = 0
        res = None
        while node:
            path.append(node)
            d = d << 1 | (key < keys[node])
            if key < keys[node]:
                res = keys[node]
            node = child[node << 1 | (key >= keys[node])]
        else:
            if path:
                path.pop()
                d >>= 1
        if path:
            self.node = self._splay(path, d)
        return res

    def index(self, key: T) -> int:
        """Return the number of keys strictly less than ``key``."""
        if not self.node:
            return 0
        self._set_search_splay(key)
        res = self.size[self.child[self.node << 1]]
        if self.keys[self.node] < key:
            res += 1
        return res

    def index_right(self, key: T) -> int:
        """Return the number of keys less than or equal to ``key``."""
        if not self.node:
            return 0
        self._set_search_splay(key)
        res = self.size[self.child[self.node << 1]]
        if self.keys[self.node] <= key:
            res += 1
        return res

    def get_min(self) -> T:
        """Return the smallest key; raises IndexError if the set is empty."""
        return self.__getitem__(0)

    def get_max(self) -> T:
        """Return the largest key; raises IndexError if the set is empty."""
        return self.__getitem__(-1)

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest key (default: the largest).

        Raises:
            IndexError: if ``k`` is out of range for a non-empty set.
        """
        assert self.node, f"IndexError: SplayTreeSet.pop({k})"
        keys, child = self.keys, self.child
        if k == -1:
            # Fast path: the splayed maximum has no right child.
            node = self._get_max_splay(self.node)
            self.node = child[node << 1]
            return keys[node]
        self._set_kth_elm_splay(k)
        res = keys[self.node]
        self._delete_root()
        return res

    def pop_max(self) -> T:
        """Remove and return the largest key."""
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest key."""
        assert self.node, "IndexError: SplayTreeSet.pop_min()"
        node = self._get_min_splay(self.node)
        # The splayed minimum has no left child.
        self.node = self.child[node << 1 | 1]
        return self.keys[node]

    def clear(self) -> None:
        """Remove every key."""
        self.node = 0
        # Every node is now unreachable, so all slots may be reused; without
        # this the arrays would grow on every fill/clear cycle.
        self.end = 1

    def tolist(self) -> list[T]:
        """Return the keys in ascending order (in-order walk; no splaying)."""
        node = self.node
        child, keys = self.child, self.keys
        stack, res = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                res.append(keys[node])
                node = child[node << 1 | 1]
        return res

    def __iter__(self):
        # NOTE: the cursor lives on the instance and __iter__ returns self,
        # so nested iterations over the same set share (and reset) it.
        self.__iter = 0
        return self

    def __next__(self):
        if self.__iter == self.__len__():
            raise StopIteration
        self.__iter += 1
        return self.__getitem__(self.__iter - 1)

    def __reversed__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(-i - 1)

    def __contains__(self, key: T):
        if not self.node:
            # Empty tree: keys[0] is the filler ``e``, not a member.
            return False
        self._set_search_splay(key)
        return self.keys[self.node] == key

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest key; negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range.
        """
        self._set_kth_elm_splay(k)
        return self.keys[self.node]

    def __len__(self):
        return self.size[self.node]

    def __bool__(self):
        return self.node != 0

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"SplayTreeSet({self})"

仕様

class SplayTreeSet(a: Iterable[T] = [], e: T = 0) [source]

Bases: OrderedSetInterface, Generic[T]

add(key: T) -> bool [source]
clear() -> None [source]
discard(key: T) -> bool [source]
ge(key: T) -> T | None [source]
get_max() -> T [source]
get_min() -> T [source]
gt(key: T) -> T | None [source]
index(key: T) -> int [source]
index_right(key: T) -> int [source]
le(key: T) -> T | None [source]
lt(key: T) -> T | None [source]
pop(k: int = -1) -> T [source]
pop_max() -> T [source]
pop_min() -> T [source]
remove(key: T) -> None [source]
reserve(n: int) -> None [source]
tolist() -> list[T] [source]