splay_tree_multiset

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_multiset import SplayTreeMultiset

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.splay_tree_multiset import SplayTreeMultiset
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9import sys
 10from typing import Iterator, Optional, Generic, Iterable, TypeVar
 11
 12T = TypeVar("T", bound=SupportsLessThan)
 13
 14
 15class SplayTreeMultiset(Generic[T]):
 16
 17    class Node:
 18
 19        def __init__(self, key: T, val: int):
 20            self.key: T = key
 21            self.size: int = 1
 22            self.val: int = val
 23            self.valsize: int = val
 24            self.left: Optional["SplayTreeMultiset.Node"] = None
 25            self.right: Optional["SplayTreeMultiset.Node"] = None
 26
 27        def __str__(self):
 28            if self.left is None and self.right is None:
 29                return f"key:{self.key, self.size, self.val, self.valsize}\n"
 30            return f"key:{self.key, self.size, self.val, self.valsize},\n left:{self.left},\n right:{self.right}\n"
 31
 32    def __init__(self, a: Iterable[T] = []) -> None:
 33        self.root: Optional["SplayTreeMultiset.Node"] = None
 34        if a:
 35            self._build(a)
 36
 37    def _build(self, a: Iterable[T]) -> None:
 38        Node = SplayTreeMultiset.Node
 39
 40        def sort(l: int, r: int) -> SplayTreeMultiset.Node:
 41            mid = (l + r) >> 1
 42            node = Node(key[mid], val[mid])
 43            if l != mid:
 44                node.left = sort(l, mid)
 45            if mid + 1 != r:
 46                node.right = sort(mid + 1, r)
 47            self._update(node)
 48            return node
 49
 50        key, val = self._rle(sorted(a))
 51        if len(key) == 0:
 52            return
 53        self.root = sort(0, len(key))
 54
 55    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
 56        x = []
 57        y = []
 58        x.append(a[0])
 59        y.append(1)
 60        for i, e in enumerate(a):
 61            if i == 0:
 62                continue
 63            if e == x[-1]:
 64                y[-1] += 1
 65                continue
 66            x.append(e)
 67            y.append(1)
 68        return x, y
 69
 70    def _update(self, node: Node) -> None:
 71        if node.left is None:
 72            if node.right is None:
 73                node.size = 1
 74                node.valsize = node.val
 75            else:
 76                node.size = 1 + node.right.size
 77                node.valsize = node.val + node.right.valsize
 78        else:
 79            if node.right is None:
 80                node.size = 1 + node.left.size
 81                node.valsize = node.val + node.left.valsize
 82            else:
 83                node.size = 1 + node.left.size + node.right.size
 84                node.valsize = node.val + node.left.valsize + node.right.valsize
 85
 86    def _splay(self, path: list[Node], d: int) -> Node:
 87        for _ in range(len(path) >> 1):
 88            node = path.pop()
 89            pnode = path.pop()
 90            if d & 1 == d >> 1 & 1:
 91                if d & 1:
 92                    tmp = node.left
 93                    node.left = tmp.right
 94                    tmp.right = node
 95                    pnode.left = node.right
 96                    node.right = pnode
 97                else:
 98                    tmp = node.right
 99                    node.right = tmp.left
100                    tmp.left = node
101                    pnode.right = node.left
102                    node.left = pnode
103            else:
104                if d & 1:
105                    tmp = node.left
106                    node.left = tmp.right
107                    pnode.right = tmp.left
108                    tmp.right = node
109                    tmp.left = pnode
110                else:
111                    tmp = node.right
112                    node.right = tmp.left
113                    pnode.left = tmp.right
114                    tmp.left = node
115                    tmp.right = pnode
116            self._update(pnode)
117            self._update(node)
118            self._update(tmp)
119            if not path:
120                return tmp
121            d >>= 2
122            if d & 1:
123                path[-1].left = tmp
124            else:
125                path[-1].right = tmp
126        gnode = path[0]
127        if d & 1:
128            node = gnode.left
129            gnode.left = node.right
130            node.right = gnode
131        else:
132            node = gnode.right
133            gnode.right = node.left
134            node.left = gnode
135        self._update(gnode)
136        self._update(node)
137        return node
138
139    def _set_search_splay(self, key: T) -> None:
140        node = self.root
141        if node is None or node.key == key:
142            return
143        path = []
144        d = 0
145        while True:
146            if node.key == key:
147                break
148            if key < node.key:
149                if node.left is None:
150                    break
151                path.append(node)
152                d <<= 1
153                d |= 1
154                node = node.left
155            else:
156                if node.right is None:
157                    break
158                path.append(node)
159                d <<= 1
160                node = node.right
161        if path:
162            self.root = self._splay(path, d)
163
164    def _set_kth_elm_splay(self, k: int) -> None:
165        if k < 0:
166            k += self.__len__()
167        d = 0
168        node = self.root
169        path = []
170        while True:
171            t = node.val if node.left is None else node.val + node.left.valsize
172            if t - node.val <= k < t:
173                if path:
174                    self.root = self._splay(path, d)
175                break
176            elif t > k:
177                path.append(node)
178                d <<= 1
179                d |= 1
180                node = node.left
181            else:
182                path.append(node)
183                d <<= 1
184                node = node.right
185                k -= t
186
187    def _set_kth_elm_tree_splay(self, k: int) -> None:
188        if k < 0:
189            k += self.len_elm()
190        assert 0 <= k < self.len_elm()
191        d = 0
192        node = self.root
193        path = []
194        while True:
195            t = 0 if node.left is None else node.left.size
196            if t == k:
197                if path:
198                    self.root = self._splay(path, d)
199                return
200            elif t > k:
201                path.append(node)
202                d <<= 1
203                d |= 1
204                node = node.left
205            else:
206                path.append(node)
207                d <<= 1
208                node = node.right
209                k -= t + 1
210
211    def _get_min_splay(self, node: Node) -> Node:
212        if node is None or node.left is None:
213            return node
214        path = []
215        while node.left is not None:
216            path.append(node)
217            node = node.left
218        return self._splay(path, (1 << len(path)) - 1)
219
220    def _get_max_splay(self, node: Node) -> Node:
221        if node is None or node.right is None:
222            return node
223        path = []
224        while node.right is not None:
225            path.append(node)
226            node = node.right
227        return self._splay(path, 0)
228
229    def add(self, key: T, val: int = 1) -> None:
230        if self.root is None:
231            self.root = SplayTreeMultiset.Node(key, val)
232            return
233        self._set_search_splay(key)
234        if self.root.key == key:
235            self.root.val += val
236            self._update(self.root)
237            return
238        node = SplayTreeMultiset.Node(key, val)
239        if key < self.root.key:
240            node.left = self.root.left
241            node.right = self.root
242            self.root.left = None
243            self._update(node.right)
244        else:
245            node.left = self.root
246            node.right = self.root.right
247            self.root.right = None
248            self._update(node.left)
249        self._update(node)
250        self.root = node
251        return
252
253    def discard(self, key: T, val: int = 1) -> bool:
254        if self.root is None:
255            return False
256        self._set_search_splay(key)
257        if self.root.key != key:
258            return False
259        if self.root.val > val:
260            self.root.val -= val
261            self._update(self.root)
262            return True
263        if self.root.left is None:
264            self.root = self.root.right
265        elif self.root.right is None:
266            self.root = self.root.left
267        else:
268            node = self._get_min_splay(self.root.right)
269            node.left = self.root.left
270            self._update(node)
271            self.root = node
272        return True
273
274    def discard_all(self, key: T) -> bool:
275        return self.discard(key, self.count(key))
276
277    def count(self, key: T) -> int:
278        if self.root is None:
279            return 0
280        self._set_search_splay(key)
281        return self.root.val if self.root.key == key else 0
282
283    def le(self, key: T) -> Optional[T]:
284        node = self.root
285        if node is None:
286            return None
287        path = []
288        d = 0
289        res = None
290        while True:
291            if node.key == key:
292                res = key
293                break
294            elif key < node.key:
295                if node.left is None:
296                    break
297                path.append(node)
298                d <<= 1
299                d |= 1
300                node = node.left
301            else:
302                res = node.key
303                if node.right is None:
304                    break
305                path.append(node)
306                d <<= 1
307                node = node.right
308        if path:
309            self.root = self._splay(path, d)
310        return res
311
312    def lt(self, key: T) -> Optional[T]:
313        node = self.root
314        path = []
315        d = 0
316        res = None
317        while node is not None:
318            if key <= node.key:
319                path.append(node)
320                d <<= 1
321                d |= 1
322                node = node.left
323            else:
324                path.append(node)
325                d <<= 1
326                res = node.key
327                node = node.right
328        else:
329            if path:
330                path.pop()
331                d >>= 1
332        if path:
333            self.root = self._splay(path, d)
334        return res
335
336    def ge(self, key: T) -> Optional[T]:
337        node = self.root
338        if node is None:
339            return None
340        path = []
341        d = 0
342        res = None
343        while True:
344            if node.key == key:
345                res = node.key
346                break
347            elif key < node.key:
348                res = node.key
349                if node.left is None:
350                    break
351                path.append(node)
352                d <<= 1
353                d |= 1
354                node = node.left
355            else:
356                if node.right is None:
357                    break
358                path.append(node)
359                d <<= 1
360                node = node.right
361        if path:
362            self.root = self._splay(path, d)
363        return res
364
365    def gt(self, key: T) -> Optional[T]:
366        node = self.root
367        path = []
368        d = 0
369        res = None
370        while node is not None:
371            if key < node.key:
372                path.append(node)
373                d <<= 1
374                d |= 1
375                res = node.key
376                node = node.left
377            else:
378                path.append(node)
379                d <<= 1
380                node = node.right
381        else:
382            if path:
383                path.pop()
384                d >>= 1
385        if path:
386            self.root = self._splay(path, d)
387        return res
388
389    def index(self, key: T) -> int:
390        if self.root is None:
391            return 0
392        self._set_search_splay(key)
393        res = 0 if self.root.left is None else self.root.left.valsize
394        if self.root.key < key:
395            res += self.root.val
396        return res
397
398    def index_right(self, key: T) -> int:
399        if self.root is None:
400            return 0
401        self._set_search_splay(key)
402        res = 0 if self.root.left is None else self.root.left.valsize
403        if self.root.key <= key:
404            res += self.root.val
405        return res
406
407    def index_keys(self, key: T) -> int:
408        if self.root is None:
409            return 0
410        self._set_search_splay(key)
411        res = 0 if self.root.left is None else self.root.left.size
412        if self.root.key < key:
413            res += 1
414        return res
415
416    def index_right_keys(self, key: T) -> int:
417        if self.root is None:
418            return 0
419        self._set_search_splay(key)
420        res = 0 if self.root.left is None else self.root.left.size
421        if self.root.key <= key:
422            res += 1
423        return res
424
425    def pop(self, k: int = -1) -> T:
426        self._set_kth_elm_splay(k)
427        res = self.root.key
428        self.discard(res)
429        return res
430
431    def pop_max(self) -> T:
432        return self.pop()
433
434    def pop_min(self) -> T:
435        return self.pop(0)
436
437    def tolist(self) -> list[T]:
438        a = []
439        if self.root is None:
440            return a
441        if sys.getrecursionlimit() < self.len_elm():
442            sys.setrecursionlimit(self.len_elm() + 1)
443
444        def rec(node):
445            if node.left is not None:
446                rec(node.left)
447            for _ in range(node.val):
448                a.append(node.key)
449            if node.right is not None:
450                rec(node.right)
451
452        rec(self.root)
453        return a
454
455    def tolist_items(self) -> list[tuple[T, int]]:
456        a = []
457        if self.root is None:
458            return a
459        if sys.getrecursionlimit() < self.len_elm():
460            sys.setrecursionlimit(self.len_elm() + 1)
461
462        def rec(node):
463            if node.left is not None:
464                rec(node.left)
465            a.append((node.key, node.val))
466            if node.right is not None:
467                rec(node.right)
468
469        rec(self.root)
470        return a
471
472    def get_elm(self, k: int) -> T:
473        assert -self.len_elm() <= k < self.len_elm()
474        self._set_kth_elm_tree_splay(k)
475        return self.root.key
476
477    def items(self) -> Iterator[tuple[T, int]]:
478        for i in range(self.len_elm()):
479            self._set_kth_elm_tree_splay(i)
480            yield self.root.key, self.root.val
481
482    def keys(self) -> Iterator[T]:
483        for i in range(self.len_elm()):
484            self._set_kth_elm_tree_splay(i)
485            yield self.root.key
486
487    def values(self) -> Iterator[int]:
488        for i in range(self.len_elm()):
489            self._set_kth_elm_tree_splay(i)
490            yield self.root.val
491
492    def len_elm(self) -> int:
493        return 0 if self.root is None else self.root.size
494
495    def show(self) -> None:
496        print(
497            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
498        )
499
500    def clear(self) -> None:
501        self.root = None
502
503    def __iter__(self):
504        self.__iter = 0
505        return self
506
507    def __next__(self):
508        if self.__iter == self.__len__():
509            raise StopIteration
510        res = self.__getitem__(self.__iter)
511        self.__iter += 1
512        return res
513
514    def __reversed__(self):
515        for i in range(self.__len__()):
516            yield self.__getitem__(-i - 1)
517
518    def __contains__(self, key: T) -> bool:
519        self._set_search_splay(key)
520        return self.root is not None and self.root.key == key
521
522    def __getitem__(self, k: int) -> T:
523        self._set_kth_elm_splay(k)
524        return self.root.key
525
526    def __len__(self):
527        return 0 if self.root is None else self.root.valsize
528
529    def __bool__(self):
530        return self.root is not None
531
532    def __str__(self):
533        return "{" + ", ".join(map(str, self.tolist())) + "}"
534
535    def __repr__(self):
536        return f"SplayTreeMultiset({self.tolist()})"

仕様

class SplayTreeMultiset(a: Iterable[T] = [])[source]

Bases: Generic[T]

class Node(key: T, val: int)[source]

Bases: object

add(key: T, val: int = 1) → None [source]
clear() → None [source]
count(key: T) → int [source]
discard(key: T, val: int = 1) → bool [source]
discard_all(key: T) → bool [source]
ge(key: T) → T | None [source]
get_elm(k: int) → T [source]
gt(key: T) → T | None [source]
index(key: T) → int [source]
index_keys(key: T) → int [source]
index_right(key: T) → int [source]
index_right_keys(key: T) → int [source]
items() → Iterator[tuple[T, int]] [source]
keys() → Iterator[T] [source]
le(key: T) → T | None [source]
len_elm() → int [source]
lt(key: T) → T | None [source]
pop(k: int = -1) → T [source]
pop_max() → T [source]
pop_min() → T [source]
show() → None [source]
tolist() → list[T] [source]
tolist_items() → list[tuple[T, int]] [source]
values() → Iterator[int] [source]