Source code for titan_pylib.data_structures.wbt.wbt_list

  1from titan_pylib.data_structures.wbt._wbt_list_node import _WBTListNode
  2from typing import Generic, TypeVar, Optional, Iterable, Callable
  3
  4T = TypeVar("T")
  5
  6
class WBTList(Generic[T]):
    """List-like sequence stored in a weight-balanced binary tree.

    Elements live in ``_WBTListNode`` objects ordered by position (in-order
    traversal yields the sequence).  Supported operations are positional
    ``insert`` / ``pop``, ``split``, ``extend``, range ``reverse``, ``len``
    and iteration.

    Balancing compares subtree weights (``size + 1``) scaled by
    ``_WBTListNode.DELTA``; the rotations themselves are delegated to the
    node's ``_balance_left`` / ``_balance_right`` helpers.
    """

    def __init__(
        self,
        a: Iterable[T] = (),
    ) -> None:
        """Create a list holding the elements of ``a`` in iteration order.

        Args:
            a: Initial elements.  Defaults to an empty, immutable tuple so
                that no mutable default object is shared between calls.
        """
        self._root = None
        self.__build(a)

    def __build(self, a: Iterable[T]) -> None:
        """Build the tree from ``a`` by recursively taking the middle element."""
        if not isinstance(a, list):
            a = list(a)
        if not a:
            return

        def build(
            lo: int, hi: int, pnode: "Optional[_WBTListNode]" = None
        ) -> "Optional[_WBTListNode]":
            # Builds the subtree for the half-open index range [lo, hi).
            if lo == hi:
                return None
            mid = (lo + hi) // 2
            node = _WBTListNode(self, a[mid])
            node._left = build(lo, mid, node)
            node._right = build(mid + 1, hi, node)
            node._par = pnode
            node._update()
            return node

        self._root = build(0, len(a))

    @classmethod
    def _weight(cls, node: "Optional[_WBTListNode]") -> int:
        """Return the weight of ``node``: its ``_size`` plus one (1 if falsy)."""
        return node._size + 1 if node else 1

    def _merge_with_root(
        self,
        l: "Optional[_WBTListNode]",
        root: "_WBTListNode",
        r: "Optional[_WBTListNode]",
    ) -> "_WBTListNode":
        """Join ``l``, the single node ``root`` and ``r`` into one tree.

        The resulting sequence is ``l`` followed by ``root`` followed by
        ``r``.  When one side is too heavy relative to the other (by the
        ``DELTA`` factor), ``root`` is pushed down into the heavier side and
        that side is rebalanced if needed.

        Returns:
            The root of the merged tree.  In the two descending branches its
            ``_par`` is reset to ``None``; in the direct-attach branch
            ``root._par`` is left as it was.
        """
        delta = _WBTListNode.DELTA
        if self._weight(l) * delta < self._weight(r):
            # Right side is too heavy: descend along r's left spine.
            r._propagate()
            r._left = self._merge_with_root(l, root, r._left)
            r._left._par = r
            r._par = None
            r._update()
            if self._weight(r._right) * delta < self._weight(r._left):
                return r._balance_right()
            return r
        if self._weight(r) * delta < self._weight(l):
            # Left side is too heavy: descend along l's right spine.
            l._propagate()
            l._right = self._merge_with_root(l._right, root, r)
            l._right._par = l
            l._par = None
            l._update()
            if self._weight(l._left) * delta < self._weight(l._right):
                return l._balance_left()
            return l
        # Weights are compatible: attach both sides directly under root.
        root._left = l
        root._right = r
        if l:
            l._par = root
        if r:
            r._par = root
        root._update()
        return root

    def _split_node(
        self, node: "Optional[_WBTListNode]", k: int
    ) -> "tuple[Optional[_WBTListNode], Optional[_WBTListNode]]":
        """Split the subtree at ``node`` into its first ``k`` elements and the rest.

        Returns:
            ``(s, t)`` where ``s`` holds the first ``k`` elements and ``t``
            the remainder; either may be ``None``.  Both results receive the
            original ``node._par`` as their parent pointer.
        """
        if not node:
            return None, None
        node._propagate()
        par = node._par
        # u is k relative to this node: 0 means "split right before node".
        u = k if node._left is None else k - node._left._size
        s, t = None, None
        if u == 0:
            s = node._left
            t = self._merge_with_root(None, node, node._right)
        elif u < 0:
            s, t = self._split_node(node._left, k)
            t = self._merge_with_root(t, node, node._right)
        else:
            s, t = self._split_node(node._right, u - 1)
            s = self._merge_with_root(node._left, node, s)
        if s:
            s._par = par
        if t:
            t._par = par
        return s, t

    def find_order(self, k: int) -> "_WBTListNode":
        """Return the node at position ``k``.

        Args:
            k: Zero-based index; a negative value counts from the end.

        Returns:
            The node (not the key) stored at that position.

        Raises:
            IndexError: If ``k`` is out of range or the list is empty.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("WBTList index out of range")
        node = self._root
        while True:
            node._propagate()
            t = node._left._size if node._left else 0
            if t == k:
                return node
            if t < k:
                k -= t + 1
                node = node._right
            else:
                node = node._left

    def split(self, k: int) -> "tuple[WBTList, WBTList]":
        """Split into two lists: the first ``k`` elements and the rest.

        NOTE(review): nodes are moved, not copied, and ``self._root`` is not
        reset here, so ``self`` should presumably not be used afterwards.
        """
        lnode, rnode = self._split_node(self._root, k)
        l, r = WBTList(), WBTList()
        l._root = lnode
        r._root = rnode
        return l, r

    def _pop_max(
        self, node: "_WBTListNode"
    ) -> "tuple[Optional[_WBTListNode], _WBTListNode]":
        """Detach the last element of the subtree at ``node``.

        Returns:
            ``(rest, last)`` where ``last`` is the tree holding only the
            final element and ``rest`` is everything before it (or ``None``).
        """
        return self._split_node(node, node._size - 1)

    def _merge_node(
        self, l: "Optional[_WBTListNode]", r: "Optional[_WBTListNode]"
    ) -> "Optional[_WBTListNode]":
        """Concatenate two subtrees; the last node of ``l`` becomes the joint."""
        if l is None:
            return r
        if r is None:
            return l
        l, tmp = self._pop_max(l)
        return self._merge_with_root(l, tmp, r)

    def extend(self, other: "WBTList[T]") -> None:
        """Append all elements of ``other`` to this list.

        NOTE(review): ``other``'s nodes are linked into this tree without
        copying and ``other._root`` is left untouched, so ``other`` should
        presumably not be used afterwards.
        """
        self._root = self._merge_node(self._root, other._root)

    def insert(self, k: int, key) -> None:
        """Insert ``key`` so that it is preceded by the first ``k`` elements."""
        s, t = self._split_node(self._root, k)
        self._root = self._merge_with_root(s, _WBTListNode(self, key), t)

    def pop(self, k: int = -1):
        """Remove and return the key at position ``k``.

        Args:
            k: Zero-based index; a negative value counts from the end.
                Defaults to the last element.

        Raises:
            IndexError: If the list is empty or ``k`` is out of range.
        """
        n = len(self)
        if n == 0:
            raise IndexError("pop from empty WBTList")
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        # Cut after position k, then peel the last node off the left part.
        s, t = self._split_node(self._root, k + 1)
        s, tmp = self._pop_max(s)
        self._root = self._merge_node(s, t)
        return tmp._key

    def _check(self, verbose: bool = False) -> None:
        """Debug helper used during development.

        Verifies parent pointers and subtree sizes, and calls each node's
        ``_balance_check``.  When ``verbose`` is true and everything is
        consistent, prints the height of the tree.
        """
        if self._root is None:
            if verbose:
                print("ok. 0 (empty)")
            return

        # Returns (_size, height) of the subtree rooted at node.
        def dfs(node: "_WBTListNode") -> "tuple[int, int]":
            h = 0
            s = 1
            if node._left:
                assert node._left._par is node
                ls, lh = dfs(node._left)
                s += ls
                h = max(h, lh)
            if node._right:
                assert node._right._par is node
                rs, rh = dfs(node._right)
                s += rs
                h = max(h, rh)
            assert node._size == s
            node._balance_check()
            return s, h + 1

        assert self._root._par is None
        _, h = dfs(self._root)
        if verbose:
            print(f"ok. {h}")

    def reverse(self, l: int, r: int) -> None:
        """Reverse the elements in the half-open index range ``[l, r)``.

        An empty range (``l >= r``), or one that selects no elements, is a
        no-op instead of failing on a missing middle subtree.
        """
        if l >= r:
            return
        head, tail = self._split_node(self._root, r)
        front, mid = self._split_node(head, l)
        if mid is not None:
            # The reversal itself is delegated to the node's _apply_rev().
            mid._apply_rev()
        self._root = self._merge_node(self._merge_node(front, mid), tail)

    def __len__(self):
        """Return the number of stored elements."""
        return self._root._size if self._root else 0

    def __iter__(self):
        """Yield the keys in sequence order (iterative in-order traversal)."""
        node = self._root
        stack: "list[_WBTListNode]" = []
        while stack or node:
            if node:
                node._propagate()
                stack.append(node)
                node = node._left
            else:
                node = stack.pop()
                yield node._key
                node = node._right

    def __str__(self):
        """Return the ``str`` of the equivalent built-in list."""
        return str(list(self))