1# from titan_pylib.data_structures.wbt.wbt_list import WBTList
2# from titan_pylib.data_structures.wbt._wbt_list_node import _WBTListNode
3# from titan_pylib.data_structures.wbt._wbt_node_base import _WBTNodeBase
4from typing import Generic, TypeVar, Optional, Final
5
# Type variable for the element (key) type stored in tree nodes.
T = TypeVar("T")
7
8
9class _WBTNodeBase(Generic[T]):
10 """WBTノードのベースクラス
11 size, par, left, rightをもつ
12 """
13
14 __slots__ = "_size", "_par", "_left", "_right"
15 DELTA: Final[int] = 3
16 GAMMA: Final[int] = 2
17
18 def __init__(self) -> None:
19 self._size: int = 1
20 self._par: Optional[_WBTNodeBase[T]] = None
21 self._left: Optional[_WBTNodeBase[T]] = None
22 self._right: Optional[_WBTNodeBase[T]] = None
23
24 def _rebalance(self) -> "_WBTNodeBase[T]":
25 """根までを再構築する
26
27 Returns:
28 _WBTNodeBase[T]: 根ノード
29 """
30 node = self
31 while True:
32 node._update()
33 wl, wr = node._weight_left(), node._weight_right()
34 if wl * _WBTNodeBase.DELTA < wr:
35 if (
36 node._right._weight_left()
37 >= node._right._weight_right() * _WBTNodeBase.GAMMA
38 ):
39 node._right = node._right._rotate_right()
40 node = node._rotate_left()
41 elif wr * _WBTNodeBase.DELTA < wl:
42 if (
43 node._left._weight_right()
44 >= node._left._weight_left() * _WBTNodeBase.GAMMA
45 ):
46 node._left = node._left._rotate_left()
47 node = node._rotate_right()
48 if not node._par:
49 return node
50 node = node._par
51
52 def _copy_from(self, other: "_WBTNodeBase[T]") -> None:
53 self._size = other._size
54 if other._left:
55 other._left._par = self
56 if other._right:
57 other._right._par = self
58 if other._par:
59 if other._par._left is other:
60 other._par._left = self
61 else:
62 other._par._right = self
63 self._par = other._par
64 self._left = other._left
65 self._right = other._right
66
67 def _weight_left(self) -> int:
68 return self._left._size + 1 if self._left else 1
69
70 def _weight_right(self) -> int:
71 return self._right._size + 1 if self._right else 1
72
73 def _update(self) -> None:
74 self._size = (
75 1
76 + (self._left._size if self._left else 0)
77 + (self._right._size if self._right else 0)
78 )
79
80 def _rotate_right(self) -> "_WBTNodeBase[T]":
81 u = self._left
82 u._size = self._size
83 self._size -= u._left._size + 1 if u._left else 1
84 u._par = self._par
85 self._left = u._right
86 if u._right:
87 u._right._par = self
88 u._right = self
89 self._par = u
90 if u._par:
91 if u._par._left is self:
92 u._par._left = u
93 else:
94 u._par._right = u
95 return u
96
97 def _rotate_left(self) -> "_WBTNodeBase[T]":
98 u = self._right
99 u._size = self._size
100 self._size -= u._right._size + 1 if u._right else 1
101 u._par = self._par
102 self._right = u._left
103 if u._left:
104 u._left._par = self
105 u._left = self
106 self._par = u
107 if u._par:
108 if u._par._left is self:
109 u._par._left = u
110 else:
111 u._par._right = u
112 return u
113
114 def _balance_check(self) -> None:
115 if not self._weight_left() * _WBTNodeBase.DELTA >= self._weight_right():
116 print(self._weight_left(), self._weight_right(), flush=True)
117 print(self)
118 assert False, f"self._weight_left() * DELTA >= self._weight_right()"
119 if not self._weight_right() * _WBTNodeBase.DELTA >= self._weight_left():
120 print(self._weight_left(), self._weight_right(), flush=True)
121 print(self)
122 assert False, f"self._weight_right() * DELTA >= self._weight_left()"
123
124 def _min(self) -> "_WBTNodeBase[T]":
125 node = self
126 while node._left:
127 node = node._left
128 return node
129
130 def _max(self) -> "_WBTNodeBase[T]":
131 node = self
132 while node._right:
133 node = node._right
134 return node
135
136 def _next(self) -> Optional["_WBTNodeBase[T]"]:
137 if self._right:
138 return self._right._min()
139 now, pre = self, None
140 while now and now._right is pre:
141 now, pre = now._par, now
142 return now
143
144 def _prev(self) -> Optional["_WBTNodeBase[T]"]:
145 if self._left:
146 return self._left._max()
147 now, pre = self, None
148 while now and now._left is pre:
149 now, pre = now._par, now
150 return now
151
152 def __add__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
153 node = self
154 for _ in range(other):
155 node = node._next()
156 return node
157
158 def __sub__(self, other: int) -> Optional["_WBTNodeBase[T]"]:
159 node = self
160 for _ in range(other):
161 node = node._prev()
162 return node
163
164 __iadd__ = __add__
165 __isub__ = __sub__
166
167 def __str__(self) -> str:
168 # if self._left is None and self._right is None:
169 # return f"key:{self._key, self._size}\n"
170 # return f"key:{self._key, self._size},\n _left:{self._left},\n _right:{self._right}\n"
171 return str(self._key)
172
173 __repr__ = __str__
174from typing import Generic, TypeVar, Optional, TYPE_CHECKING
175
176if TYPE_CHECKING:
177 from titan_pylib.data_structures.wbt.wbt_list import WBTList
178
# Type variable for the element type; re-declared because this section was
# bundled from a separate module that declares its own.
T = TypeVar("T")
180
181
class _WBTListNode(_WBTNodeBase, Generic[T]):
    """Node of a ``WBTList``: a WBT node with a value and a lazy reverse tag.

    ``_rev`` is a pending-reversal flag. When it is set, this node's two
    children have not yet been swapped and the flag has not yet been pushed
    down to them (see ``_propagate``). Code that navigates ``_left`` /
    ``_right`` by position therefore propagates first.
    """

    # ``_size``, ``_par``, ``_left`` and ``_right`` are already slots of
    # ``_WBTNodeBase``. Redeclaring them here would shadow the base slots,
    # wasting memory and (per the Python data model docs) making the
    # program's meaning undefined, so only the new fields are declared.
    __slots__ = (
        "_tree",
        "_key",
        "_rev",
    )

    def __init__(self, tree: "WBTList[T]", key: T) -> None:
        super().__init__()
        # Owning list; not read anywhere in the visible code.
        self._tree: WBTList[T] = tree
        self._key: T = key
        # Pending-reversal flag (0 or 1), consumed by _propagate.
        self._rev: int = 0
        # Narrow the inherited link types for type checkers only.
        self._left: "_WBTListNode[T]"
        self._right: "_WBTListNode[T]"
        self._par: "_WBTListNode[T]"

    def __str__(self) -> str:
        """Debug dump of this subtree (keys and sizes, recursively)."""
        if self._left is None and self._right is None:
            return f"key:{self._key, self._size}\n"
        return f"key:{self._key, self._size},\n _left:{self._left},\n _right:{self._right}\n"

    def _check(self):
        """Debug helper: verify parent links and sizes in this subtree."""

        def dfs(node: "_WBTListNode"):
            s = 1
            if node._left:
                assert node._left._par is node
                s += node._left._size
                dfs(node._left)
            if node._right:
                assert node._right._par is node
                s += node._right._size
                dfs(node._right)
            assert s == node._size

        dfs(self)

    def propagate_above(self) -> None:
        """Push pending reversal tags down the root-to-self path.

        Ancestors are processed from the root downwards, so afterwards this
        node and every ancestor have ``_rev == 0``.
        """
        stack: list["_WBTListNode[T]"] = []
        node = self
        while node:
            stack.append(node)
            node = node._par
        while stack:
            node = stack.pop()
            node._propagate()

    def update_above(self) -> None:
        """Recompute ``_size`` for this node and every ancestor.

        Note:
            All of these nodes are expected to have had their reversal tags
            propagated already.
        """
        node = self
        while node:
            node._update()
            node = node._par

    def _update(self) -> None:
        """Recompute ``_size`` from the children's sizes."""
        self._size = 1
        if self._left:
            self._size += self._left._size
        if self._right:
            self._size += self._right._size

    def _apply_rev(self) -> None:
        """Mark this subtree as reversed (toggles, so twice cancels out)."""
        self._rev ^= 1

    def _propagate(self) -> None:
        """Apply a pending reversal here and hand the tag to the children."""
        if self._rev:
            self._left, self._right = self._right, self._left
            if self._left:
                self._left._apply_rev()
            if self._right:
                self._right._apply_rev()
            self._rev = 0

    def _rotate_right(self) -> "_WBTListNode[T]":
        """Rotate right around this node and return the new subtree root.

        Only the left child is propagated here; ``self`` is expected to be
        propagated by the caller. Sizes are recomputed bottom-up.
        """
        u = self._left
        u._propagate()
        u._par = self._par
        self._left = u._right
        if u._right:
            u._right._par = self
        u._right = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        self._update()
        u._update()
        return u

    def _rotate_left(self) -> "_WBTListNode[T]":
        """Mirror image of ``_rotate_right``."""
        u = self._right
        u._propagate()
        u._par = self._par
        self._right = u._left
        if u._left:
            u._left._par = self
        u._left = self
        self._par = u
        if u._par:
            if u._par._left is self:
                u._par._left = u
            else:
                u._par._right = u
        self._update()
        u._update()
        return u

    def _balance_left(self) -> "_WBTListNode[T]":
        """Fix a too-heavy right side with a single or double left rotation."""
        self._right._propagate()
        if self._right._weight_left() >= self._right._weight_right() * self.GAMMA:
            self._right = self._right._rotate_right()
        return self._rotate_left()

    def _balance_right(self) -> "_WBTListNode[T]":
        """Fix a too-heavy left side with a single or double right rotation."""
        self._left._propagate()
        if self._left._weight_right() >= self._left._weight_left() * self.GAMMA:
            self._left = self._left._rotate_left()
        return self._rotate_right()

    def _min(self) -> "_WBTListNode[T]":
        """Leftmost node of this subtree, propagating tags along the way."""
        self.propagate_above()
        assert self._rev == 0
        node = self
        while node._left:
            node = node._left
            node._propagate()
        return node

    def _max(self) -> "_WBTListNode[T]":
        """Rightmost node of this subtree, propagating tags along the way."""
        self.propagate_above()
        assert self._rev == 0
        node = self
        while node._right:
            node = node._right
            node._propagate()
        return node

    def _next(self) -> Optional["_WBTListNode[T]"]:
        """In-order successor, or None if this is the last node."""
        self.propagate_above()
        if self._right:
            return self._right._min()
        now, pre = self, None
        while now and now._right is pre:
            now, pre = now._par, now
        return now

    def _prev(self) -> Optional["_WBTListNode[T]"]:
        """In-order predecessor, or None if this is the first node."""
        self.propagate_above()
        if self._left:
            return self._left._max()
        now, pre = self, None
        while now and now._left is pre:
            now, pre = now._par, now
        return now
347from typing import Generic, TypeVar, Optional, Iterable, Callable
348
# Type variable for the list's element type (re-declared for this bundled
# module section).
T = TypeVar("T")
350
351
class WBTList(Generic[T]):
    """Sequence stored in a weight-balanced tree with lazy range reversal.

    Elements are kept in in-order position order. Supported operations:
    positional ``insert`` / ``pop`` / ``find_order``, ``split``, ``extend``
    (concatenation), ``reverse`` of a half-open range, ``len`` and in-order
    iteration. Every structural change is built from ``_split_node`` and
    ``_merge_with_root``.
    """

    def __init__(
        self,
        a: Iterable[T] = (),
    ) -> None:
        """Create a list holding the elements of ``a`` in order.

        Args:
            a: initial elements. Defaults to an empty tuple, which is
                immutable, so no default object is shared between calls as
                it would be with ``[]``.
        """
        self._root: Optional[_WBTListNode] = None
        self.__build(a)

    def __build(self, a: Iterable[T]) -> None:
        """Store a balanced tree of ``a`` as the root (no-op if ``a`` is empty)."""

        def build(
            l: int, r: int, pnode: Optional[_WBTListNode] = None
        ) -> Optional[_WBTListNode]:
            # Root a[l:r] at its middle element; build both halves recursively.
            if l == r:
                return None
            mid = (l + r) // 2
            node = _WBTListNode(self, a[mid])
            node._left = build(l, mid, node)
            node._right = build(mid + 1, r, node)
            node._par = pnode
            node._update()
            return node

        if not isinstance(a, list):
            a = list(a)
        if not a:
            return
        self._root = build(0, len(a))

    @staticmethod
    def _weight(node: Optional[_WBTListNode]) -> int:
        """WBT weight of a subtree: its size plus one (1 for an empty tree)."""
        return node._size + 1 if node else 1

    def _merge_with_root(
        self,
        l: Optional[_WBTListNode],
        root: _WBTListNode,
        r: Optional[_WBTListNode],
    ) -> _WBTListNode:
        """Join tree ``l``, the single node ``root`` and tree ``r`` in order.

        ``root``'s old children are overwritten. If one side is more than
        DELTA times heavier, ``root`` is pushed down that side's inner spine
        and the result is rebalanced on the way back up. The returned root's
        ``_par`` is None, except in the direct (balanced) case, where
        ``root._par`` is left untouched.

        NOTE(review): ``root`` is never propagated here, so its ``_rev``
        presumably must already be 0.
        """
        if self._weight(l) * _WBTListNode.DELTA < self._weight(r):
            # r is much heavier: descend into r's left spine.
            r._propagate()
            r._left = self._merge_with_root(l, root, r._left)
            r._left._par = r
            r._par = None
            r._update()
            if self._weight(r._right) * _WBTListNode.DELTA < self._weight(r._left):
                return r._balance_right()
            return r
        elif self._weight(r) * _WBTListNode.DELTA < self._weight(l):
            # l is much heavier: descend into l's right spine.
            l._propagate()
            l._right = self._merge_with_root(l._right, root, r)
            l._right._par = l
            l._par = None
            l._update()
            if self._weight(l._left) * _WBTListNode.DELTA < self._weight(l._right):
                return l._balance_left()
            return l
        else:
            # Weights are compatible: root simply adopts l and r.
            root._left = l
            root._right = r
            if l:
                l._par = root
            if r:
                r._par = root
            root._update()
            return root

    def _split_node(
        self, node: Optional[_WBTListNode], k: int
    ) -> tuple[Optional[_WBTListNode], Optional[_WBTListNode]]:
        """Split the subtree at ``node`` into its first ``k`` elements and the rest.

        ``k <= 0`` puts everything on the right; ``k >= size`` puts
        everything on the left. Both returned roots get ``node``'s former
        parent as ``_par`` (None when ``node`` was a whole-tree root).
        """
        if not node:
            return None, None
        node._propagate()
        par = node._par
        # u = position of k relative to this node (0 means "split right here").
        u = k if node._left is None else k - node._left._size
        s, t = None, None
        if u == 0:
            s = node._left
            t = self._merge_with_root(None, node, node._right)
        elif u < 0:
            s, t = self._split_node(node._left, k)
            t = self._merge_with_root(t, node, node._right)
        else:
            s, t = self._split_node(node._right, u - 1)
            s = self._merge_with_root(node._left, node, s)
        if s:
            s._par = par
        if t:
            t._par = par
        return s, t

    def find_order(self, k: int) -> "_WBTListNode[T]":
        """Return the node at position ``k`` (negative ``k`` counts from the end).

        Raises:
            IndexError: if ``k`` is out of range (including an empty list).
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("WBTList.find_order index out of range")
        node = self._root
        while True:
            node._propagate()
            t = node._left._size if node._left else 0
            if t == k:
                return node
            if t < k:
                k -= t + 1
                node = node._right
            else:
                node = node._left

    def split(self, k: int) -> tuple["WBTList", "WBTList"]:
        """Return a list of the first ``k`` elements and a list of the rest.

        Note:
            Nodes are moved, not copied. ``self._root`` still points at a
            node that now lives in one of the returned lists, so treat
            ``self`` as consumed and do not use it afterwards.
        """
        lnode, rnode = self._split_node(self._root, k)
        l, r = WBTList(), WBTList()
        l._root = lnode
        r._root = rnode
        return l, r

    def _pop_max(self, node: _WBTListNode) -> tuple[_WBTListNode, _WBTListNode]:
        """Detach the last element of a non-empty tree.

        Returns ``(rest, last)``, where ``last`` is a single-node tree.
        """
        l, tmp = self._split_node(node, node._size - 1)
        return l, tmp

    def _merge_node(self, l: _WBTListNode, r: _WBTListNode) -> _WBTListNode:
        """Concatenate two trees (either may be None)."""
        if l is None:
            return r
        if r is None:
            return l
        # Use l's last node as the joining root.
        l, tmp = self._pop_max(l)
        return self._merge_with_root(l, tmp, r)

    def extend(self, other: "WBTList[T]") -> None:
        """Append all elements of ``other`` to the end of this list.

        Note:
            ``other``'s nodes are linked in without copying and ``other``
            still references them, so ``other`` should not be used
            afterwards.
        """
        self._root = self._merge_node(self._root, other._root)

    def insert(self, k: int, key: T) -> None:
        """Insert ``key`` before position ``k``, following ``list.insert``.

        A negative ``k`` counts from the end. Positions past either end
        clamp to the front or back.
        """
        if k < 0:
            k += len(self)
        s, t = self._split_node(self._root, k)
        self._root = self._merge_with_root(s, _WBTListNode(self, key), t)

    def pop(self, k: int = -1) -> T:
        """Remove and return the element at position ``k`` (default: last).

        A negative ``k`` counts from the end.

        Raises:
            IndexError: if the list is empty or ``k`` is out of range.
                Previously an index past the end silently popped the last
                element, and negative indices crashed.
        """
        n = len(self)
        if k < 0:
            k += n
        if not 0 <= k < n:
            raise IndexError("WBTList.pop index out of range")
        s, t = self._split_node(self._root, k + 1)
        s, tmp = self._pop_max(s)
        self._root = self._merge_node(s, t)
        return tmp._key

    def _check(self, verbose: bool = False) -> None:
        """Debug helper: verify parent links, sizes and balance of every node.

        With ``verbose``, prints the tree height on success.
        """
        if self._root is None:
            if verbose:
                print("ok. 0 (empty)")
            return

        # Returns (subtree size, subtree height).
        def dfs(node: _WBTListNode) -> tuple[int, int]:
            h = 0
            s = 1
            if node._left:
                assert node._left._par is node
                ls, lh = dfs(node._left)
                s += ls
                h = max(h, lh)
            if node._right:
                assert node._right._par is node
                rs, rh = dfs(node._right)
                s += rs
                h = max(h, rh)
            assert node._size == s
            node._balance_check()
            return s, h + 1

        assert self._root._par is None
        _, h = dfs(self._root)
        if verbose:
            print(f"ok. {h}")

    def reverse(self, l, r):
        """Reverse the elements in the half-open range ``[l, r)`` in place.

        Indices are not normalised. An empty range is a no-op (previously it
        raised AttributeError).
        """
        left_mid, right = self._split_node(self._root, r)
        left, mid = self._split_node(left_mid, l)
        if mid is not None:
            mid._apply_rev()
        self._root = self._merge_node(self._merge_node(left, mid), right)

    def __len__(self):
        return self._root._size if self._root else 0

    def __iter__(self):
        """Yield the keys in order (iterative in-order walk, propagating tags)."""
        node = self._root
        stack: list[_WBTListNode] = []
        while stack or node:
            if node:
                node._propagate()
                stack.append(node)
                node = node._left
            else:
                node = stack.pop()
                yield node._key
                node = node._right

    def __str__(self):
        return str(list(self))