splay_tree_multiset2

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_multiset2 import SplayTreeMultiset2

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.splay_tree_multiset2 import SplayTreeMultiset2
  2import sys
  3from typing import Generic, Iterable, TypeVar, Optional
  4
  5T = TypeVar("T")
  6
  7
  8class SplayTreeMultiset2(Generic[T]):
  9
 10    class Node:
 11
 12        def __init__(self, key: T, val: int):
 13            self.key = key
 14            self.val = val
 15            self.left = None
 16            self.right = None
 17
 18        def __str__(self):
 19            if self.left is None and self.right is None:
 20                return f"key:{self.key, self.val}\n"
 21            return (
 22                f"key:{self.key, self.val},\n left:{self.left},\n right:{self.right}\n"
 23            )
 24
 25    def __init__(self, a: Iterable[T] = []):
 26        self.node = None
 27        self._len = 0
 28        self._len_elm = 0
 29        if not (hasattr(a, "__getitem__") and hasattr(a, "__len__")):
 30            a = list(a)
 31        if a:
 32            self._build(a)
 33
 34    def _build(self, a: Iterable[T]) -> None:
 35        Node = SplayTreeMultiset2.Node
 36
 37        def sort(l: int, r: int) -> SplayTreeMultiset2.Node:
 38            mid = (l + r) >> 1
 39            node = Node(key[mid], val[mid])
 40            if l != mid:
 41                node.left = sort(l, mid)
 42            if mid + 1 != r:
 43                node.right = sort(mid + 1, r)
 44            return node
 45
 46        a = sorted(a)
 47        self._len = len(a)
 48        key, val = self._rle(sorted(a))
 49        self._len_elm = len(key)
 50        self.node = sort(0, len(key))
 51
 52    def _rle(self, a: list[T]) -> tuple[list[T], list[int]]:
 53        x = []
 54        y = []
 55        x.append(a[0])
 56        y.append(1)
 57        for i, e in enumerate(a):
 58            if i == 0:
 59                continue
 60            if e == x[-1]:
 61                y[-1] += 1
 62                continue
 63            x.append(e)
 64            y.append(1)
 65        return x, y
 66
 67    def _splay(self, path: list[Node], di: int) -> Node:
 68        for _ in range(len(path) >> 1):
 69            node = path.pop()
 70            pnode = path.pop()
 71            if di & 1 == di >> 1 & 1:
 72                if di & 1 == 1:
 73                    tmp = node.left
 74                    node.left = tmp.right
 75                    tmp.right = node
 76                    pnode.left = node.right
 77                    node.right = pnode
 78                else:
 79                    tmp = node.right
 80                    node.right = tmp.left
 81                    tmp.left = node
 82                    pnode.right = node.left
 83                    node.left = pnode
 84            else:
 85                if di & 1 == 1:
 86                    tmp = node.left
 87                    node.left = tmp.right
 88                    pnode.right = tmp.left
 89                    tmp.right = node
 90                    tmp.left = pnode
 91                else:
 92                    tmp = node.right
 93                    node.right = tmp.left
 94                    pnode.left = tmp.right
 95                    tmp.left = node
 96                    tmp.right = pnode
 97            if not path:
 98                return tmp
 99            di >>= 2
100            if di & 1 == 1:
101                path[-1].left = tmp
102            else:
103                path[-1].right = tmp
104        gnode = path[0]
105        if di & 1 == 1:
106            node = gnode.left
107            gnode.left = node.right
108            node.right = gnode
109        else:
110            node = gnode.right
111            gnode.right = node.left
112            node.left = gnode
113        return node
114
115    def _set_search_splay(self, key: T) -> None:
116        node = self.node
117        if node is None or node.key == key:
118            return
119        path = []
120        di = 0
121        while True:
122            if node.key == key:
123                break
124            elif key < node.key:
125                if node.left is None:
126                    break
127                path.append(node)
128                di <<= 1
129                di |= 1
130                node = node.left
131            else:
132                if node.right is None:
133                    break
134                path.append(node)
135                di <<= 1
136                node = node.right
137        if path:
138            self.node = self._splay(path, di)
139
140    def _get_min_splay(self, node: Node) -> Node:
141        if node is None or node.left is None:
142            return node
143        path = []
144        while node.left is not None:
145            path.append(node)
146            node = node.left
147        return self._splay(path, (1 << len(path)) - 1)
148
149    def _get_max_splay(self, node: Node) -> Node:
150        if node is None or node.right is None:
151            return node
152        path = []
153        while node.right is not None:
154            path.append(node)
155            node = node.right
156        return self._splay(path, 0)
157
158    def add(self, key: T, val: int = 1) -> None:
159        self._len += val
160        if self.node is None:
161            self._len_elm += 1
162            self.node = SplayTreeMultiset2.Node(key, val)
163            return
164        self._set_search_splay(key)
165        if self.node.key == key:
166            self.node.val += val
167            return
168        self._len_elm += 1
169        node = SplayTreeMultiset2.Node(key, val)
170        if key < self.node.key:
171            node.left = self.node.left
172            node.right = self.node
173            self.node.left = None
174        else:
175            node.left = self.node
176            node.right = self.node.right
177            self.node.right = None
178        self.node = node
179        return
180
181    def discard(self, key: T, val: int = 1) -> bool:
182        if self.node is None:
183            return False
184        self._set_search_splay(key)
185        if self.node.key != key:
186            return False
187        if self.node.val > val:
188            self.node.val -= val
189            self._len -= val
190            return True
191        self._len -= self.node.val
192        self._len_elm -= 1
193        if self.node.left is None:
194            self.node = self.node.right
195        elif self.node.right is None:
196            self.node = self.node.left
197        else:
198            node = self._get_min_splay(self.node.right)
199            node.left = self.node.left
200            self.node = node
201        return True
202
203    def discard_all(self, key: T) -> bool:
204        return self.discar(key, self.count(key))
205
206    def count(self, key: T) -> int:
207        if self.node is None:
208            return 0
209        self._set_search_splay(key)
210        return self.node.val if self.node.key == key else 0
211
212    def le(self, key: T) -> Optional[T]:
213        node = self.node
214        if node is None:
215            return None
216        path = []
217        di = 0
218        res = None
219        while True:
220            if node.key == key:
221                res = key
222                break
223            elif key < node.key:
224                if node.left is None:
225                    break
226                path.append(node)
227                di <<= 1
228                di |= 1
229                node = node.left
230            else:
231                res = node.key
232                if node.right is None:
233                    break
234                path.append(node)
235                di <<= 1
236                node = node.right
237        if path:
238            self.node = self._splay(path, di)
239        return res
240
241    def lt(self, key: T) -> Optional[T]:
242        node = self.node
243        if node is None:
244            return None
245        path = []
246        di = 0
247        res = None
248        while True:
249            if key <= node.key:
250                if node.left is None:
251                    break
252                path.append(node)
253                di <<= 1
254                di |= 1
255                node = node.left
256            else:
257                res = node.key
258                if node.right is None:
259                    break
260                path.append(node)
261                di <<= 1
262                node = node.right
263        if path:
264            self.node = self._splay(path, di)
265        return res
266
267    def ge(self, key: T) -> Optional[T]:
268        node = self.node
269        if node is None:
270            return None
271        path = []
272        di = 0
273        res = None
274        while True:
275            if node.key == key:
276                res = node.key
277                break
278            elif key < node.key:
279                res = node.key
280                if node.left is None:
281                    break
282                path.append(node)
283                di <<= 1
284                di |= 1
285                node = node.left
286            else:
287                if node.right is None:
288                    break
289                path.append(node)
290                di <<= 1
291                node = node.right
292        if path:
293            self.node = self._splay(path, di)
294        return res
295
296    def gt(self, key: T) -> Optional[T]:
297        node = self.node
298        if node is None:
299            return None
300        path = []
301        di = 0
302        res = None
303        while True:
304            if key < node.key:
305                res = node.key
306                if node.left is None:
307                    break
308                path.append(node)
309                di <<= 1
310                di |= 1
311                node = node.left
312            else:
313                if node.right is None:
314                    break
315                path.append(node)
316                di <<= 1
317                node = node.right
318        if path:
319            self.node = self._splay(path, di)
320        return res
321
322    def pop_max(self) -> T:
323        self.node = self._get_max_splay(self.node)
324        res = self.node.key
325        self.discard(res)
326        return res
327
328    def pop_min(self) -> T:
329        self.node = self._get_min_splay(self.node)
330        res = self.node.key
331        self.discard(res)
332        return res
333
334    def get_min(self) -> Optional[T]:
335        if self.node is None:
336            return
337        self.node = self._get_min_splay(self.node)
338        return self.node.key
339
340    def get_max(self) -> Optional[T]:
341        if self.node is None:
342            return
343        self.node = self._get_max_splay(self.node)
344        return self.node.key
345
346    def tolist(self) -> list[T]:
347        a = []
348        if self.node is None:
349            return a
350        if sys.getrecursionlimit() < self.len_elm():
351            sys.setrecursionlimit(self.len_elm() + 1)
352
353        def rec(node):
354            if node.left is not None:
355                rec(node.left)
356            a.extend([node.key] * node.val)
357            if node.right is not None:
358                rec(node.right)
359
360        rec(self.node)
361        return a
362
363    def tolist_items(self) -> list[tuple[T, int]]:
364        a = []
365        if self.node is None:
366            return a
367        if sys.getrecursionlimit() < self._len_elm():
368            sys.setrecursionlimit(self._len_elm() + 1)
369
370        def rec(node):
371            if node.left is not None:
372                rec(node.left)
373            a.append((node.key, node.val))
374            if node.right is not None:
375                rec(node.right)
376
377        rec(self.node)
378        return a
379
380    def len_elm(self) -> int:
381        return self._len_elm
382
383    def clear(self) -> None:
384        self.node = None
385
386    def __getitem__(self, k):  # 先s頭と末尾しか対応していない
387        if k == -1 or k == self._len - 1:
388            return self.get_max()
389        elif k == 0:
390            return self.get_min()
391        raise IndexError
392
393    def __contains__(self, key: T) -> bool:
394        self._set_search_splay(key)
395        return self.node is not None and self.node.key == key
396
397    def __len__(self):
398        return self._len
399
400    def __bool__(self):
401        return self.node is not None
402
403    def __str__(self):
404        return "{" + ", ".join(map(str, self.tolist())) + "}"
405
406    def __repr__(self):
407        return f"SplayTreeMultiset2({self.tolist()})"

仕様

class SplayTreeMultiset2(a: Iterable[T] = [])[source]

Bases: Generic[T]

class Node(key: T, val: int)[source]

Bases: object

add(key: T, val: int = 1) → None[source]
clear() → None[source]
count(key: T) → int[source]
discard(key: T, val: int = 1) → bool[source]
discard_all(key: T) → bool[source]
ge(key: T) → T | None[source]
get_max() → T | None[source]
get_min() → T | None[source]
gt(key: T) → T | None[source]
le(key: T) → T | None[source]
len_elm() → int[source]
lt(key: T) → T | None[source]
pop_max() → T[source]
pop_min() → T[source]
tolist() → list[T][source]
tolist_items() → list[tuple[T, int]][source]