1from titan_pylib.data_structures.wbt._wbt_list_node import _WBTListNode
2from typing import Generic, TypeVar, Optional, Iterable, Callable
3
# Type variable for the element type stored in WBTList.
T = TypeVar("T")
5
6
class WBTList(Generic[T]):
    """Positional sequence stored in a weight-balanced binary tree.

    Elements are addressed by index (implicit key) rather than by value.
    The tree nodes are ``_WBTListNode`` instances; this class only wires
    them together via split/merge.  Node-level helpers (``_update``,
    ``_propagate``, ``_balance_left``/``_balance_right``, ``_apply_rev``,
    ``_balance_check`` and the ``DELTA`` balance ratio) are defined on
    ``_WBTListNode`` and are not visible here.

    Note: ``split`` and ``extend`` move nodes between lists instead of
    copying them, so the list whose nodes were moved is left empty.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a list holding the elements of ``a`` in iteration order.

        Args:
            a: initial elements.  The default is an immutable empty tuple so
                no mutable default object is shared between calls.
        """
        self._root = None  # root _WBTListNode, or None when the list is empty
        self.__build(a)

    def __build(self, a: Iterable[T]) -> None:
        """Build a balanced tree over ``a`` and install it as the root."""

        def build(
            l: int, r: int, pnode: Optional[_WBTListNode] = None
        ) -> Optional[_WBTListNode]:
            # Subtree for a[l:r]: the middle element becomes the subtree root,
            # so the two halves differ in size by at most one.
            if l == r:
                return None
            mid = (l + r) // 2
            node = _WBTListNode(self, a[mid])
            node._left = build(l, mid, node)
            node._right = build(mid + 1, r, node)
            node._par = pnode
            # Presumably recomputes cached fields (e.g. _size) from children.
            node._update()
            return node

        if not isinstance(a, list):
            a = list(a)  # random access is needed to pick midpoints
        if not a:
            return
        self._root = build(0, len(a))

    @staticmethod
    def _weight(node: Optional[_WBTListNode]) -> int:
        """Return the balancing weight of ``node``: subtree size + 1 (1 for None)."""
        return node._size + 1 if node else 1

    def _merge_with_root(
        self,
        l: Optional[_WBTListNode],
        root: _WBTListNode,
        r: Optional[_WBTListNode],
    ) -> _WBTListNode:
        """Concatenate ``l``, the single node ``root`` and ``r``, in that order.

        If one side outweighs the other by more than a factor of
        ``_WBTListNode.DELTA``, the join is pushed down into the heavier
        side (along its inner spine) and the path is rebalanced on the way
        back up.  Otherwise ``root`` simply adopts ``l`` and ``r``.

        Returns:
            The root of the joined subtree.
        """
        if self._weight(l) * _WBTListNode.DELTA < self._weight(r):
            # r is much heavier: join into r's left subtree.
            r._propagate()  # flush pending lazy state before touching children
            r._left = self._merge_with_root(l, root, r._left)
            r._left._par = r
            r._par = None
            r._update()
            if self._weight(r._right) * _WBTListNode.DELTA < self._weight(r._left):
                return r._balance_right()
            return r
        elif self._weight(r) * _WBTListNode.DELTA < self._weight(l):
            # l is much heavier: join into l's right subtree.
            l._propagate()
            l._right = self._merge_with_root(l._right, root, r)
            l._right._par = l
            l._par = None
            l._update()
            if self._weight(l._left) * _WBTListNode.DELTA < self._weight(l._right):
                return l._balance_left()
            return l
        else:
            # Weights are comparable: root becomes the parent of both sides.
            # root._par is left as-is; callers re-link it where needed.
            root._left = l
            root._right = r
            if l:
                l._par = root
            if r:
                r._par = root
            root._update()
            return root

    def _split_node(
        self, node: Optional[_WBTListNode], k: int
    ) -> tuple[Optional[_WBTListNode], Optional[_WBTListNode]]:
        """Split ``node``'s subtree into its first ``k`` elements and the rest.

        ``k <= 0`` yields ``(None, whole)``; ``k >= size`` yields
        ``(whole, None)``.  Each non-None result root gets the ``_par`` that
        ``node`` had before the split.
        """
        if not node:
            return None, None
        node._propagate()
        par = node._par
        # u is k measured relative to node's own position (= left size).
        u = k if node._left is None else k - node._left._size
        s, t = None, None
        if u == 0:
            # Cut exactly before node: node starts the right part.
            s = node._left
            t = self._merge_with_root(None, node, node._right)
        elif u < 0:
            # Cut falls inside the left subtree.
            s, t = self._split_node(node._left, k)
            t = self._merge_with_root(t, node, node._right)
        else:
            # Cut falls inside the right subtree (skip node itself: u - 1).
            s, t = self._split_node(node._right, u - 1)
            s = self._merge_with_root(node._left, node, s)
        if s:
            s._par = par
        if t:
            t._par = par
        return s, t

    def find_order(self, k: int) -> _WBTListNode:
        """Return the node at position ``k``; negative ``k`` counts from the end.

        Raises:
            IndexError: if ``k`` is out of range (including on an empty list).
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("WBTList index out of range")
        node = self._root
        while True:
            node._propagate()
            t = node._left._size if node._left else 0
            if t == k:
                return node
            if t < k:
                k -= t + 1
                node = node._right
            else:
                node = node._left

    def split(self, k: int) -> tuple["WBTList[T]", "WBTList[T]"]:
        """Split into ``(first k elements, remaining elements)``.

        The nodes are moved, not copied.  This list is emptied afterwards,
        because its old root node now belongs to one of the returned halves.
        """
        lnode, rnode = self._split_node(self._root, k)
        self._root = None
        l, r = WBTList(), WBTList()
        l._root = lnode
        r._root = rnode
        return l, r

    def _pop_max(
        self, node: _WBTListNode
    ) -> tuple[Optional[_WBTListNode], _WBTListNode]:
        """Detach the last element of ``node``'s subtree.

        Returns:
            ``(rest, last)``: ``last`` is the detached single node holding the
            last element; ``rest`` is the remaining subtree (None if empty).
        """
        rest, last = self._split_node(node, node._size - 1)
        return rest, last

    def _merge_node(
        self, l: Optional[_WBTListNode], r: Optional[_WBTListNode]
    ) -> Optional[_WBTListNode]:
        """Concatenate two subtrees (all of ``l`` followed by all of ``r``)."""
        if l is None:
            return r
        if r is None:
            return l
        # Use l's last node as the joining root between the two sides.
        l, tmp = self._pop_max(l)
        return self._merge_with_root(l, tmp, r)

    def extend(self, other: Iterable[T]) -> None:
        """Append all elements of ``other`` to the end of this list.

        When ``other`` is a different ``WBTList``, its nodes are moved (not
        copied) and ``other`` is left empty.  Any other iterable, including
        this list itself, is first copied into a fresh tree, so
        ``lst.extend(lst)`` duplicates the contents like ``list.extend``.
        """
        if other is self or not isinstance(other, WBTList):
            # Merging a tree with itself would corrupt it; copy instead.
            other = WBTList(other)
        self._root = self._merge_node(self._root, other._root)
        other._root = None  # its nodes now belong to self

    def insert(self, k: int, key: T) -> None:
        """Insert ``key`` so that it ends up at position ``k``.

        Follows ``list.insert``: negative ``k`` counts from the end, and
        out-of-range positions clamp to the front/back.
        """
        if k < 0:
            k += len(self)  # still-negative k is clamped to 0 by the split
        s, t = self._split_node(self._root, k)
        self._root = self._merge_with_root(s, _WBTListNode(self, key), t)

    def pop(self, k: int = -1) -> T:
        """Remove and return the element at position ``k`` (default: last).

        Raises:
            IndexError: if the list is empty or ``k`` is out of range.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("pop index out of range")
        # s holds positions [0, k]; its last node is the one to remove.
        s, t = self._split_node(self._root, k + 1)
        s, tmp = self._pop_max(s)
        self._root = self._merge_node(s, t)
        return tmp._key

    def _check(self, verbose: bool = False) -> None:
        """Debug helper used during development.

        Asserts that parent links and subtree sizes are consistent and runs
        each node's ``_balance_check``.  If ``verbose``, prints the tree
        height when all checks pass.
        """
        if self._root is None:
            if verbose:
                print("ok. 0 (empty)")
            return

        # Returns (_size, height) of the subtree rooted at node.
        def dfs(node: _WBTListNode) -> tuple[int, int]:
            h = 0
            s = 1
            if node._left:
                assert node._left._par is node
                ls, lh = dfs(node._left)
                s += ls
                h = max(h, lh)
            if node._right:
                assert node._right._par is node
                rs, rh = dfs(node._right)
                s += rs
                h = max(h, rh)
            assert node._size == s
            node._balance_check()
            return s, h + 1

        assert self._root._par is None
        _, h = dfs(self._root)
        if verbose:
            print(f"ok. {h}")

    def reverse(self, l: int, r: int) -> None:
        """Reverse the elements at positions ``[l, r)`` in place.

        An empty range (``l >= r``) is a no-op.
        """
        rest, right = self._split_node(self._root, r)
        left, mid = self._split_node(rest, l)
        if mid is not None:
            mid._apply_rev()
        self._root = self._merge_node(self._merge_node(left, mid), right)

    def __len__(self) -> int:
        return self._root._size if self._root else 0

    def __iter__(self):
        """Yield the elements in order (iterative in-order traversal)."""
        node = self._root
        stack: list[_WBTListNode] = []
        while stack or node:
            if node:
                node._propagate()  # flush lazy state before reading children
                stack.append(node)
                node = node._left
            else:
                node = stack.pop()
                yield node._key
                node = node._right

    def __str__(self) -> str:
        return str(list(self))

    def __repr__(self) -> str:
        return f"WBTList({list(self)})"