splay_tree_multiset_sum

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_multiset_sum import SplayTreeMultisetSum

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.splay_tree_multiset_sum import SplayTreeMultisetSum
  2import sys
  3from typing import Iterator, Optional, Generic, Iterable, TypeVar
  4
  5T = TypeVar("T")
  6
  7
  8class SplayTreeMultisetSum(Generic[T]):
  9
 10    class Node:
 11
 12        def __init__(self, key, cnt: int) -> None:
 13            self.data = key * cnt
 14            self.key = key
 15            self.size = 1
 16            self.cnt = cnt
 17            self.cntsize = cnt
 18            self.left = None
 19            self.right = None
 20
 21        def __str__(self) -> str:
 22            if self.left is None and self.right is None:
 23                return f"key:{self.key, self.size, self.cnt, self.cntsize}\n"
 24            return f"key:{self.key, self.size, self.cnt, self.cntsize},\n left:{self.left},\n right:{self.right}\n"
 25
 26        __repr__ = __str__
 27
 28    def __init__(self, e: T, a: Iterable[T] = [], _node=None) -> None:
 29        self.node = _node
 30        self.e = e
 31        if a:
 32            self._build(a)
 33
 34    def _build(self, a: Iterable[T]) -> None:
 35        Node = SplayTreeMultisetSum.Node
 36
 37        def sort(l: int, r: int) -> SplayTreeMultisetSum.Node:
 38            mid = (l + r) >> 1
 39            node = Node(key[mid], cnt[mid])
 40            if l != mid:
 41                node.left = sort(l, mid)
 42            if mid + 1 != r:
 43                node.right = sort(mid + 1, r)
 44            self._update(node)
 45            return node
 46
 47        key, cnt = self._rle(sorted(a))
 48        self.node = sort(0, len(key))
 49
 50    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
 51        x = []
 52        y = []
 53        x.append(a[0])
 54        y.append(1)
 55        for i, e in enumerate(a):
 56            if i == 0:
 57                continue
 58            if e == x[-1]:
 59                y[-1] += 1
 60                continue
 61            x.append(e)
 62            y.append(1)
 63        return x, y
 64
 65    def _update(self, node: Node) -> None:
 66        if node.left is None:
 67            if node.right is None:
 68                node.size = 1
 69                node.cntsize = node.cnt
 70                node.data = node.key * node.cnt
 71            else:
 72                node.size = 1 + node.right.size
 73                node.cntsize = node.cnt + node.right.cntsize
 74                node.data = node.key * node.cnt + node.right.data
 75        else:
 76            if node.right is None:
 77                node.size = 1 + node.left.size
 78                node.cntsize = node.cnt + node.left.cntsize
 79                node.data = node.key * node.cnt + node.left.data
 80            else:
 81                node.size = 1 + node.left.size + node.right.size
 82                node.cntsize = node.cnt + node.left.cntsize + node.right.cntsize
 83                node.data = node.key * node.cnt + node.left.data + node.right.data
 84
 85    def _splay(self, path: list[Node], d: int) -> Node:
 86        for _ in range(len(path) >> 1):
 87            node = path.pop()
 88            pnode = path.pop()
 89            if d & 1 == d >> 1 & 1:
 90                if d & 1:
 91                    tmp = node.left
 92                    node.left = tmp.right
 93                    tmp.right = node
 94                    pnode.left = node.right
 95                    node.right = pnode
 96                else:
 97                    tmp = node.right
 98                    node.right = tmp.left
 99                    tmp.left = node
100                    pnode.right = node.left
101                    node.left = pnode
102            else:
103                if d & 1:
104                    tmp = node.left
105                    node.left = tmp.right
106                    pnode.right = tmp.left
107                    tmp.right = node
108                    tmp.left = pnode
109                else:
110                    tmp = node.right
111                    node.right = tmp.left
112                    pnode.left = tmp.right
113                    tmp.left = node
114                    tmp.right = pnode
115            self._update(pnode)
116            self._update(node)
117            self._update(tmp)
118            if not path:
119                return tmp
120            d >>= 2
121            if d & 1:
122                path[-1].left = tmp
123            else:
124                path[-1].right = tmp
125        gnode = path[0]
126        if d & 1:
127            node = gnode.left
128            gnode.left = node.right
129            node.right = gnode
130        else:
131            node = gnode.right
132            gnode.right = node.left
133            node.left = gnode
134        self._update(gnode)
135        self._update(node)
136        return node
137
138    def _set_search_splay(self, key: T) -> None:
139        node = self.node
140        if node is None or node.key == key:
141            return
142        path = []
143        d = 0
144        while True:
145            if node.key == key:
146                break
147            elif key < node.key:
148                if node.left is None:
149                    break
150                path.append(node)
151                d <<= 1
152                d |= 1
153                node = node.left
154            else:
155                if node.right is None:
156                    break
157                path.append(node)
158                d <<= 1
159                node = node.right
160        if path:
161            self.node = self._splay(path, d)
162
163    def _set_kth_elm_splay(self, k: int) -> None:
164        if k < 0:
165            k += self.__len__()
166        d = 0
167        node = self.node
168        path = []
169        while True:
170            t = node.cnt if node.left is None else node.cnt + node.left.cntsize
171            if t - node.cnt <= k < t:
172                if path:
173                    self.node = self._splay(path, d)
174                break
175            elif t > k:
176                path.append(node)
177                d <<= 1
178                d |= 1
179                node = node.left
180            else:
181                path.append(node)
182                d <<= 1
183                node = node.right
184                k -= t
185
186    def _set_kth_elm_tree_splay(self, k: int) -> None:
187        if k < 0:
188            k += self.len_elm()
189        assert 0 <= k < self.len_elm()
190        d = 0
191        node = self.node
192        path = []
193        while True:
194            t = 0 if node.left is None else node.left.size
195            if t == k:
196                if path:
197                    self.node = self._splay(path, d)
198                return
199            elif t > k:
200                path.append(node)
201                d <<= 1
202                d |= 1
203                node = node.left
204            else:
205                path.append(node)
206                d <<= 1
207                node = node.right
208                k -= t + 1
209
210    def _get_min_splay(self, node: Node) -> Node:
211        if node is None or node.left is None:
212            return node
213        path = []
214        while node.left is not None:
215            path.append(node)
216            node = node.left
217        return self._splay(path, (1 << len(path)) - 1)
218
219    def _get_max_splay(self, node: Node) -> Node:
220        if node is None or node.right is None:
221            return node
222        path = []
223        while node.right is not None:
224            path.append(node)
225            node = node.right
226        return self._splay(path, 0)
227
228    def add(self, key: T, cnt: int = 1) -> None:
229        if self.node is None:
230            self.node = SplayTreeMultisetSum.Node(key, cnt)
231            return
232        self._set_search_splay(key)
233        if self.node.key == key:
234            self.node.cnt += cnt
235            self._update(self.node)
236            return
237        node = SplayTreeMultisetSum.Node(key, cnt)
238        if key < self.node.key:
239            node.left = self.node.left
240            node.right = self.node
241            self.node.left = None
242            self._update(node.right)
243        else:
244            node.left = self.node
245            node.right = self.node.right
246            self.node.right = None
247            self._update(node.left)
248        self._update(node)
249        self.node = node
250        return
251
252    def discard(self, key: T, cnt: int = 1) -> bool:
253        if self.node is None:
254            return False
255        self._set_search_splay(key)
256        if self.node.key != key:
257            return False
258        if self.node.cnt > cnt:
259            self.node.cnt -= cnt
260            self._update(self.node)
261            return True
262        if self.node.left is None:
263            self.node = self.node.right
264        elif self.node.right is None:
265            self.node = self.node.left
266        else:
267            node = self._get_min_splay(self.node.right)
268            node.left = self.node.left
269            self._update(node)
270            self.node = node
271        return True
272
273    def discard_all(self, key: T) -> bool:
274        return self.discard(key, self.count(key))
275
276    def count(self, key: T) -> int:
277        if self.node is None:
278            return 0
279        self._set_search_splay(key)
280        return self.node.cnt if self.node.key == key else 0
281
282    def le(self, key: T) -> Optional[T]:
283        node = self.node
284        if node is None:
285            return None
286        path = []
287        d = 0
288        res = None
289        while True:
290            if node.key == key:
291                res = key
292                break
293            elif key < node.key:
294                if node.left is None:
295                    break
296                path.append(node)
297                d <<= 1
298                d |= 1
299                node = node.left
300            else:
301                res = node.key
302                if node.right is None:
303                    break
304                path.append(node)
305                d <<= 1
306                node = node.right
307        if path:
308            self.node = self._splay(path, d)
309        return res
310
311    def lt(self, key: T) -> Optional[T]:
312        node = self.node
313        path = []
314        d = 0
315        res = None
316        while node is not None:
317            if key <= node.key:
318                path.append(node)
319                d <<= 1
320                d |= 1
321                node = node.left
322            else:
323                path.append(node)
324                d <<= 1
325                res = node.key
326                node = node.right
327        else:
328            if path:
329                path.pop()
330                d >>= 1
331        if path:
332            self.node = self._splay(path, d)
333        return res
334
335    def ge(self, key: T) -> Optional[T]:
336        node = self.node
337        if node is None:
338            return None
339        path = []
340        d = 0
341        res = None
342        while True:
343            if node.key == key:
344                res = node.key
345                break
346            elif key < node.key:
347                res = node.key
348                if node.left is None:
349                    break
350                path.append(node)
351                d <<= 1
352                d |= 1
353                node = node.left
354            else:
355                if node.right is None:
356                    break
357                path.append(node)
358                d <<= 1
359                node = node.right
360        if path:
361            self.node = self._splay(path, d)
362        return res
363
364    def gt(self, key: T) -> Optional[T]:
365        node = self.node
366        path = []
367        d = 0
368        res = None
369        while node is not None:
370            if key < node.key:
371                path.append(node)
372                d <<= 1
373                d |= 1
374                res = node.key
375                node = node.left
376            else:
377                path.append(node)
378                d <<= 1
379                node = node.right
380        else:
381            if path:
382                path.pop()
383                d >>= 1
384        if path:
385            self.node = self._splay(path, d)
386        return res
387
388    def index(self, key: T) -> int:
389        if self.node is None:
390            return 0
391        self._set_search_splay(key)
392        res = 0 if self.node.left is None else self.node.left.cntsize
393        if self.node.key < key:
394            res += self.node.cnt
395        return res
396
397    def index_right(self, key: T) -> int:
398        if self.node is None:
399            return 0
400        self._set_search_splay(key)
401        res = 0 if self.node.left is None else self.node.left.cntsize
402        if self.node.key <= key:
403            res += self.node.cnt
404        return res
405
406    def index_keys(self, key: T) -> int:
407        if self.node is None:
408            return 0
409        self._set_search_splay(key)
410        res = 0 if self.node.left is None else self.node.left.size
411        if self.node.key < key:
412            res += 1
413        return res
414
415    def index_right_keys(self, key: T) -> int:
416        if self.node is None:
417            return 0
418        self._set_search_splay(key)
419        res = 0 if self.node.left is None else self.node.left.size
420        if self.node.key <= key:
421            res += 1
422        return res
423
424    def pop(self, k: int = -1) -> T:
425        self._set_kth_elm_splay(k)
426        res = self.node.key
427        self.discard(res)
428        return res
429
430    def pop_max(self) -> T:
431        return self.pop()
432
433    def pop_min(self) -> T:
434        return self.pop(0)
435
436    def tolist(self) -> list[T]:
437        a = []
438        if self.node is None:
439            return a
440        if sys.getrecursionlimit() < self.len_elm():
441            sys.setrecursionlimit(self.len_elm() + 1)
442
443        def rec(node):
444            if node.left is not None:
445                rec(node.left)
446            for _ in range(node.cnt):
447                a.append(node.key)
448            if node.right is not None:
449                rec(node.right)
450
451        rec(self.node)
452        return a
453
454    def tolist_items(self) -> list[tuple[T, int]]:
455        a = []
456        if self.node is None:
457            return a
458        if sys.getrecursionlimit() < self.len_elm():
459            sys.setrecursionlimit(self.len_elm() + 1)
460
461        def rec(node):
462            if node.left is not None:
463                rec(node.left)
464            a.append((node.key, node.cnt))
465            if node.right is not None:
466                rec(node.right)
467
468        rec(self.node)
469        return a
470
471    def get_elm(self, k: int) -> T:
472        assert -self.len_elm() <= k < self.len_elm()
473        self._set_kth_elm_tree_splay(k)
474        return self.node.key
475
476    def items(self) -> Iterator[tuple[T, int]]:
477        for i in range(self.len_elm()):
478            self._set_kth_elm_tree_splay(i)
479            yield self.node.key, self.node.cnt
480
481    def keys(self) -> Iterator[T]:
482        for i in range(self.len_elm()):
483            self._set_kth_elm_tree_splay(i)
484            yield self.node.key
485
486    def cntues(self) -> Iterator[int]:
487        for i in range(self.len_elm()):
488            self._set_kth_elm_tree_splay(i)
489            yield self.node.cnt
490
491    def len_elm(self) -> int:
492        return 0 if self.node is None else self.node.size
493
494    def show(self) -> None:
495        print(
496            "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.tolist_items())) + "}"
497        )
498
499    def clear(self) -> None:
500        self.node = None
501
502    def get_sum(self) -> T:
503        if self.node is None:
504            return self.e
505        return self.node.data
506
507    def merge(self, other: "SplayTreeMultisetSum") -> None:
508        if self.node is None:
509            self.node = other.node
510            return
511        if other.node is None:
512            return
513        self.node = self._get_max_splay(self.node)
514        other.node = other._get_min_splay(other.node)
515        assert self.node.key <= other.node.key
516        if self.node.key == other.node.key:
517            self.node.cnt += other.node.cnt
518            self._update(self.node)
519            other.discard(other.node.key, other.node.cnt)
520        self.node.right = other.node
521        self._update(self.node)
522
523    def split(self, k: int) -> tuple["SplayTreeMultisetSum", "SplayTreeMultisetSum"]:
524        if self.node is None:
525            return SplayTreeMultisetSum(self.e), self
526        if k >= len(self):
527            return self, SplayTreeMultisetSum(self.e)
528        self._set_kth_elm_splay(k)
529        if self.node.left is None:
530            left = SplayTreeMultisetSum(self.e)
531        else:
532            left = SplayTreeMultisetSum(self.e, _node=self.node.left)
533        if len(left) != k:
534            assert k - len(left) > 0
535            self.node.cnt -= k - len(left)
536            self._update(self.node)
537            root = SplayTreeMultisetSum.Node(self.node.key, k - len(left))
538            root.left = left.node
539            left = SplayTreeMultisetSum(self.e, _node=root)
540            left._update(left.node)
541        if self.node:
542            self.node.left = None
543            self._update(self.node)
544        return left, self
545
546    def __iter__(self):
547        self.__iter = 0
548        return self
549
550    def __next__(self):
551        if self.__iter == self.__len__():
552            raise StopIteration
553        res = self.__getitem__(self.__iter)
554        self.__iter += 1
555        return res
556
557    def __reversed__(self):
558        for i in range(self.__len__()):
559            yield self.__getitem__(-i - 1)
560
561    def __contains__(self, key: T) -> bool:
562        self._set_search_splay(key)
563        return self.node is not None and self.node.key == key
564
565    def __getitem__(self, k: int) -> T:
566        self._set_kth_elm_splay(k)
567        return self.node.key
568
569    def __len__(self):
570        return 0 if self.node is None else self.node.cntsize
571
572    def __bool__(self):
573        return self.node is not None
574
575    def __str__(self):
576        return "{" + ", ".join(map(str, self.tolist())) + "}"
577
578    def __repr__(self):
579        return "SplayTreeMultisetSum(" + str(self) + ")"

仕様

class SplayTreeMultisetSum(e: T, a: Iterable[T] = [], _node=None)[source]

Bases: Generic[T]

class Node(key, cnt: int)[source]

Bases: object

add(key: T, cnt: int = 1) → None [source]
clear() → None [source]
cntues() → Iterator[int] [source]
count(key: T) → int [source]
discard(key: T, cnt: int = 1) → bool [source]
discard_all(key: T) → bool [source]
ge(key: T) → T | None [source]
get_elm(k: int) → T [source]
get_sum() → T [source]
gt(key: T) → T | None [source]
index(key: T) → int [source]
index_keys(key: T) → int [source]
index_right(key: T) → int [source]
index_right_keys(key: T) → int [source]
items() → Iterator[tuple[T, int]] [source]
keys() → Iterator[T] [source]
le(key: T) → T | None [source]
len_elm() → int [source]
lt(key: T) → T | None [source]
merge(other: SplayTreeMultisetSum) → None [source]
pop(k: int = -1) → T [source]
pop_max() → T [source]
pop_min() → T [source]
show() → None [source]
split(k: int) → tuple[SplayTreeMultisetSum, SplayTreeMultisetSum] [source]
tolist() → list[T] [source]
tolist_items() → list[tuple[T, int]] [source]