wbt_list

ソースコード

from titan_pylib.data_structures.wbt.wbt_list import WBTList

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.wbt.wbt_list import WBTList
  2# from titan_pylib.data_structures.wbt._wbt_list_node import _WBTListNode
  3# from titan_pylib.data_structures.wbt._wbt_node_base import _WBTNodeBase
  4from typing import Generic, TypeVar, Optional, Final
  5
  6T = TypeVar("T")
  7
  8
  9class _WBTNodeBase(Generic[T]):
 10    """WBTノードのベースクラス
 11    size, par, left, rightをもつ
 12    """
 13
 14    __slots__ = "_size", "_par", "_left", "_right"
 15    DELTA: Final[int] = 3
 16    GAMMA: Final[int] = 2
 17
 18    def __init__(self) -> None:
 19        self._size: int = 1
 20        self._par: Optional[_WBTNodeBase[T]] = None
 21        self._left: Optional[_WBTNodeBase[T]] = None
 22        self._right: Optional[_WBTNodeBase[T]] = None
 23
 24    def _rebalance(self) -> "_WBTNodeBase[T]":
 25        """根までを再構築する
 26
 27        Returns:
 28            _WBTNodeBase[T]: 根ノード
 29        """
 30        node = self
 31        while True:
 32            node._update()
 33            wl, wr = node._weight_left(), node._weight_right()
 34            if wl * _WBTNodeBase.DELTA < wr:
 35                if (
 36                    node._right._weight_left()
 37                    >= node._right._weight_right() * _WBTNodeBase.GAMMA
 38                ):
 39                    node._right = node._right._rotate_right()
 40                node = node._rotate_left()
 41            elif wr * _WBTNodeBase.DELTA < wl:
 42                if (
 43                    node._left._weight_right()
 44                    >= node._left._weight_left() * _WBTNodeBase.GAMMA
 45                ):
 46                    node._left = node._left._rotate_left()
 47                node = node._rotate_right()
 48            if not node._par:
 49                return node
 50            node = node._par
 51
 52    def _copy_from(self, other: "_WBTNodeBase[T]") -> None:
 53        self._size = other._size
 54        if other._left:
 55            other._left._par = self
 56        if other._right:
 57            other._right._par = self
 58        if other._par:
 59            if other._par._left is other:
 60                other._par._left = self
 61            else:
 62                other._par._right = self
 63        self._par = other._par
 64        self._left = other._left
 65        self._right = other._right
 66
 67    def _weight_left(self) -> int:
 68        return self._left._size + 1 if self._left else 1
 69
 70    def _weight_right(self) -> int:
 71        return self._right._size + 1 if self._right else 1
 72
 73    def _update(self) -> None:
 74        self._size = (
 75            1
 76            + (self._left._size if self._left else 0)
 77            + (self._right._size if self._right else 0)
 78        )
 79
 80    def _rotate_right(self) -> "_WBTNodeBase[T]":
 81        u = self._left
 82        u._size = self._size
 83        self._size -= u._left._size + 1 if u._left else 1
 84        u._par = self._par
 85        self._left = u._right
 86        if u._right:
 87            u._right._par = self
 88        u._right = self
 89        self._par = u
 90        if u._par:
 91            if u._par._left is self:
 92                u._par._left = u
 93            else:
 94                u._par._right = u
 95        return u
 96
 97    def _rotate_left(self) -> "_WBTNodeBase[T]":
 98        u = self._right
 99        u._size = self._size
100        self._size -= u._right._size + 1 if u._right else 1
101        u._par = self._par
102        self._right = u._left
103        if u._left:
104            u._left._par = self
105        u._left = self
106        self._par = u
107        if u._par:
108            if u._par._left is self:
109                u._par._left = u
110            else:
111                u._par._right = u
112        return u
113
114    def _balance_check(self) -> None:
115        if not self._weight_left() * _WBTNodeBase.DELTA >= self._weight_right():
116            print(self._weight_left(), self._weight_right(), flush=True)
117            print(self)
118            assert False, f"self._weight_left() * DELTA >= self._weight_right()"
119        if not self._weight_right() * _WBTNodeBase.DELTA >= self._weight_left():
120            print(self._weight_left(), self._weight_right(), flush=True)
121            print(self)
122            assert False, f"self._weight_right() * DELTA >= self._weight_left()"
123
124    def _min(self) -> "_WBTNodeBase[T]":
125        node = self
126        while node._left:
127            node = node._left
128        return node
129
130    def _max(self) -> "_WBTNodeBase[T]":
131        node = self
132        while node._right:
133            node = node._right
134        return node
135
136    def _next(self) -> Optional["_WBTNodeBase[T]"]:
137        if self._right:
138            return self._right._min()
139        now, pre = self, None
140        while now and now._right is pre:
141            now, pre = now._par, now
142        return now
143
144    def _prev(self) -> Optional["_WBTNodeBase[T]"]:
145        if self._left:
146            return self._left._max()
147        now, pre = self, None
148        while now and now._left is pre:
149            now, pre = now._par, now
150        return now
151
152    def __add__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
153        node = self
154        for _ in range(other):
155            node = node._next()
156        return node
157
158    def __sub__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
159        node = self
160        for _ in range(other):
161            node = node._prev()
162        return node
163
164    __iadd__ = __add__
165    __isub__ = __sub__
166
167    def __str__(self) -> str:
168        # if self._left is None and self._right is None:
169        #     return f"key:{self._key, self._size}\n"
170        # return f"key:{self._key, self._size},\n _left:{self._left},\n _right:{self._right}\n"
171        return str(self._key)
172
173    __repr__ = __str__
174from typing import Generic, TypeVar, Optional, TYPE_CHECKING
175
176if TYPE_CHECKING:
177    from titan_pylib.data_structures.wbt.wbt_list import WBTList
178
179T = TypeVar("T")
180
181
class _WBTListNode(_WBTNodeBase, Generic[T]):
    """Node of ``WBTList``.

    Extends the WBT base node with the owning list (``_tree``), the stored
    element (``_key``) and a lazy reversal flag (``_rev``). A set ``_rev``
    means this node's children still have to be swapped and the flag pushed
    down to them. ``_propagate`` does that.
    """

    # Only the fields that are new in this subclass. ``_size``, ``_par``,
    # ``_left`` and ``_right`` are already slots of ``_WBTNodeBase``.
    # Re-declaring them would add a second, shadowing set of slot descriptors:
    # wasted memory per node and, per the Python data-model reference,
    # undefined meaning.
    __slots__ = (
        "_tree",
        "_key",
        "_rev",
    )

    def __init__(self, tree: "WBTList[T]", key: T) -> None:
        super().__init__()
        # Owning list; stored but not read anywhere in this module.
        self._tree: WBTList[T] = tree
        self._key: T = key
        # 1 while this subtree has a pending (not yet propagated) reversal.
        self._rev: int = 0
        # Narrow the base-class attribute types for type checkers.
        self._left: "_WBTListNode[T]"
        self._right: "_WBTListNode[T]"
        self._par: "_WBTListNode[T]"

    def __str__(self) -> str:
        if self._left is None and self._right is None:
            return f"key:{self._key, self._size}\n"
        return f"key:{self._key, self._size},\n _left:{self._left},\n _right:{self._right}\n"

    def _check(self):
        """Debug helper: assert parent links and sizes throughout this subtree."""

        def dfs(node: "_WBTListNode"):
            s = 1
            if node._left:
                assert node._left._par is node
                s += node._left._size
                dfs(node._left)
            if node._right:
                assert node._right._par is node
                s += node._right._size
                dfs(node._right)
            assert s == node._size

        dfs(self)

    def propagate_above(self) -> None:
        """Push pending reversals down the path from the root to this node (inclusive)."""
        stack: list["_WBTListNode[T]"] = []
        node = self
        while node:
            stack.append(node)
            node = node._par
        # Pop root first so each node is propagated before its child on the path.
        while stack:
            node = stack.pop()
            node._propagate()

    def update_above(self) -> None:
        """Recompute ``_size`` for this node and every ancestor.

        Note:
            Assumes reversals on this path have already been propagated.
        """
        node = self
        while node:
            node._update()
            node = node._par

    def _update(self) -> None:
        """Recompute ``_size`` from the children's sizes."""
        self._size = 1
        if self._left:
            self._size += self._left._size
        if self._right:
            self._size += self._right._size

    def _apply_rev(self) -> None:
        """Toggle the pending-reversal flag of this subtree."""
        self._rev ^= 1

    def _propagate(self) -> None:
        """If a reversal is pending, swap the children and pass the flag down to them."""
        if self._rev:
            self._left, self._right = self._right, self._left
            if self._left:
                self._left._apply_rev()
            if self._right:
                self._right._apply_rev()
            self._rev = 0

    def _rotate_right(self) -> "_WBTListNode[T]":
        """Rotate right around this node and return the new subtree root.

        The left child is propagated here before it moves up. This node is
        expected to carry no pending reversal (callers propagate it first).
        """
        u = self._left
        u._propagate()
        u._par = self._par
        self._left = u._right
        if u._right:
            u._right._par = self
        u._right = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        # self is now below u, so it must be updated first.
        self._update()
        u._update()
        return u

    def _rotate_left(self) -> "_WBTListNode[T]":
        """Mirror image of ``_rotate_right``."""
        u = self._right
        u._propagate()
        u._par = self._par
        self._right = u._left
        if u._left:
            u._left._par = self
        u._left = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        self._update()
        u._update()
        return u

    def _balance_left(self) -> "_WBTListNode[T]":
        """Fix a right-heavy node by a single or double left rotation; return the new root."""
        self._right._propagate()
        if self._right._weight_left() >= self._right._weight_right() * self.GAMMA:
            self._right = self._right._rotate_right()
        return self._rotate_left()

    def _balance_right(self) -> "_WBTListNode[T]":
        """Fix a left-heavy node by a single or double right rotation; return the new root."""
        self._left._propagate()
        if self._left._weight_right() >= self._left._weight_left() * self.GAMMA:
            self._left = self._left._rotate_left()
        return self._rotate_right()

    def _min(self) -> "_WBTListNode[T]":
        """Return the leftmost node of this subtree, propagating reversals on the way."""
        self.propagate_above()
        assert self._rev == 0
        node = self
        while node._left:
            node = node._left
            node._propagate()
        return node

    def _max(self) -> "_WBTListNode[T]":
        """Return the rightmost node of this subtree, propagating reversals on the way."""
        self.propagate_above()
        assert self._rev == 0
        node = self
        while node._right:
            node = node._right
            node._propagate()
        return node

    def _next(self) -> Optional["_WBTListNode[T]"]:
        """Return the in-order successor, or ``None`` if this is the last node."""
        # Ancestors must be propagated so that left/right reflect the real order.
        self.propagate_above()
        if self._right:
            return self._right._min()
        now, pre = self, None
        while now and now._right is pre:
            now, pre = now._par, now
        return now

    def _prev(self) -> Optional["_WBTListNode[T]"]:
        """Return the in-order predecessor, or ``None`` if this is the first node."""
        self.propagate_above()
        if self._left:
            return self._left._max()
        now, pre = self, None
        while now and now._left is pre:
            now, pre = now._par, now
        return now
347from typing import Generic, TypeVar, Optional, Iterable, Callable
348
349T = TypeVar("T")
350
351
class WBTList(Generic[T]):
    """Sequence backed by a weight-balanced tree.

    Supports positional insert / pop, split, concatenation (``extend``) and
    in-place reversal of a range. Nodes carry a lazy reversal flag, so every
    traversal propagates pending reversals before looking at a node's
    children.
    """

    def __init__(
        self,
        a: Iterable[T] = (),
    ) -> None:
        """Create a list holding the elements of ``a`` in order."""
        self._root: Optional[_WBTListNode[T]] = None
        self.__build(a)

    def __build(self, a: Iterable[T]) -> None:
        """Set the tree to a perfectly balanced tree built from ``a``."""

        def build(
            l: int, r: int, pnode: Optional[_WBTListNode] = None
        ) -> Optional[_WBTListNode]:
            # Build the subtree for a[l:r] with parent pnode; None if empty.
            if l == r:
                return None
            mid = (l + r) // 2
            node = _WBTListNode(self, a[mid])
            node._left = build(l, mid, node)
            node._right = build(mid + 1, r, node)
            node._par = pnode
            node._update()
            return node

        if not isinstance(a, list):
            a = list(a)
        if not a:
            return
        self._root = build(0, len(a))

    @staticmethod
    def _weight(node: Optional[_WBTListNode]) -> int:
        """Balance weight of a subtree: its size + 1 (1 for an empty subtree)."""
        return node._size + 1 if node else 1

    def _merge_with_root(
        self,
        l: Optional[_WBTListNode],
        root: _WBTListNode,
        r: Optional[_WBTListNode],
    ) -> _WBTListNode:
        """Join ``l``, the single node ``root`` and ``r`` (in that order) into one tree.

        If one side is too heavy, descend into it until the weights are
        compatible, then rebalance on the way back up. Returns the new root.
        """
        if self._weight(l) * _WBTListNode.DELTA < self._weight(r):
            r._propagate()
            r._left = self._merge_with_root(l, root, r._left)
            r._left._par = r
            r._par = None
            r._update()
            if self._weight(r._right) * _WBTListNode.DELTA < self._weight(r._left):
                return r._balance_right()
            return r
        elif self._weight(r) * _WBTListNode.DELTA < self._weight(l):
            l._propagate()
            l._right = self._merge_with_root(l._right, root, r)
            l._right._par = l
            l._par = None
            l._update()
            if self._weight(l._left) * _WBTListNode.DELTA < self._weight(l._right):
                return l._balance_left()
            return l
        else:
            root._left = l
            root._right = r
            if l:
                l._par = root
            if r:
                r._par = root
            root._update()
            return root

    def _split_node(
        self, node: Optional[_WBTListNode], k: int
    ) -> tuple[Optional[_WBTListNode], Optional[_WBTListNode]]:
        """Split the subtree at ``node`` into its first ``k`` elements and the rest.

        Both returned roots get ``node``'s former parent as their parent.
        """
        if not node:
            return None, None
        node._propagate()
        par = node._par
        # u: position of k relative to node (0 means split just before node).
        u = k if node._left is None else k - node._left._size
        if u == 0:
            s = node._left
            t = self._merge_with_root(None, node, node._right)
        elif u < 0:
            s, t = self._split_node(node._left, k)
            t = self._merge_with_root(t, node, node._right)
        else:
            s, t = self._split_node(node._right, u - 1)
            s = self._merge_with_root(node._left, node, s)
        if s:
            s._par = par
        if t:
            t._par = par
        return s, t

    def find_order(self, k: int) -> "_WBTListNode[T]":
        """Return the node at position ``k``; negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range (including an empty list).
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("WBTList index out of range")
        node = self._root
        while True:
            node._propagate()
            t = node._left._size if node._left else 0
            if t == k:
                return node
            if t < k:
                k -= t + 1
                node = node._right
            else:
                node = node._left

    def split(self, k: int) -> tuple["WBTList", "WBTList"]:
        """Split into two new lists: the first ``k`` elements and the rest.

        The nodes are moved, not copied, so this list is left empty.
        """
        lnode, rnode = self._split_node(self._root, k)
        # The old root now lives inside one of the halves; detach it so this
        # list cannot alias (and later corrupt) either half.
        self._root = None
        left, right = WBTList(), WBTList()
        left._root = lnode
        right._root = rnode
        return left, right

    def _pop_max(self, node: _WBTListNode) -> tuple[Optional[_WBTListNode], _WBTListNode]:
        """Detach the last node of the non-empty tree ``node``; return (rest, last)."""
        l, tmp = self._split_node(node, node._size - 1)
        return l, tmp

    def _merge_node(
        self, l: Optional[_WBTListNode], r: Optional[_WBTListNode]
    ) -> Optional[_WBTListNode]:
        """Concatenate two trees; the last node of ``l`` becomes the joining root."""
        if l is None:
            return r
        if r is None:
            return l
        l, tmp = self._pop_max(l)
        return self._merge_with_root(l, tmp, r)

    def extend(self, other: "WBTList[T]") -> None:
        """Append all elements of ``other`` to this list.

        The nodes are moved, not copied, so ``other`` is left empty.

        Raises:
            ValueError: if ``other`` is this list itself.
        """
        if other is self:
            raise ValueError("cannot extend a WBTList with itself")
        self._root = self._merge_node(self._root, other._root)
        # other's nodes now belong to this tree; drop its reference to them.
        other._root = None

    def insert(self, k: int, key) -> None:
        """Insert ``key`` before position ``k``.

        ``k <= 0`` inserts at the front and ``k >= len(self)`` appends.
        """
        s, t = self._split_node(self._root, k)
        self._root = self._merge_with_root(s, _WBTListNode(self, key), t)

    def pop(self, k: int = -1):
        """Remove and return the element at position ``k`` (default: the last one).

        Negative ``k`` counts from the end, as with ``list.pop``.

        Raises:
            IndexError: if the list is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        s, t = self._split_node(self._root, k + 1)
        s, tmp = self._pop_max(s)
        self._root = self._merge_node(s, t)
        return tmp._key

    def _check(self, verbose: bool = False) -> None:
        """Debug helper: verify parent links, sizes and balance of the whole tree.

        If ``verbose``, print the tree height when every check passes.
        """
        if self._root is None:
            if verbose:
                print("ok. 0 (empty)")
            return

        # Returns (size, height) of the subtree rooted at node.
        def dfs(node: _WBTListNode) -> tuple[int, int]:
            h = 0
            s = 1
            if node._left:
                assert node._left._par is node
                ls, lh = dfs(node._left)
                s += ls
                h = max(h, lh)
            if node._right:
                assert node._right._par is node
                rs, rh = dfs(node._right)
                s += rs
                h = max(h, rh)
            assert node._size == s
            node._balance_check()
            return s, h + 1

        assert self._root._par is None
        _, h = dfs(self._root)
        if verbose:
            print(f"ok. {h}")

    def reverse(self, l: int, r: int) -> None:
        """Reverse the elements in positions ``[l, r)`` in place; an empty range is a no-op."""
        head_mid, tail = self._split_node(self._root, r)
        head, mid = self._split_node(head_mid, l)
        # mid is None when the range is empty (l >= r); nothing to flip then.
        if mid is not None:
            mid._apply_rev()
        self._root = self._merge_node(self._merge_node(head, mid), tail)

    def __len__(self):
        return self._root._size if self._root else 0

    def __iter__(self):
        # Iterative in-order traversal, propagating reversals before each
        # node's children are followed.
        node = self._root
        stack: list[_WBTListNode] = []
        while stack or node:
            if node:
                node._propagate()
                stack.append(node)
                node = node._left
            else:
                node = stack.pop()
                yield node._key
                node = node._right

    def __str__(self):
        return str(list(self))

仕様

class WBTList(a: Iterable[T] = [])[source]

Bases: Generic[T]

extend(other: WBTList[T]) → None [source]
find_order(k: int) → _WBTListNode[T] [source]
insert(k: int, key) → None [source]
pop(k: int) [source]
reverse(l, r) [source]
split(k: int) → tuple[WBTList, WBTList] [source]