1from titan_pylib.data_structures.wbt._wbt_set_node import _WBTSetNode
2from typing import Generic, TypeVar, Optional, Iterable, Iterator
3
4T = TypeVar("T")
5
6
[docs]
7class WBTSet(Generic[T]):
8 """重み平衡木で実装された順序付き集合"""
9
10 __slots__ = "_root", "_min", "_max"
11
12 def __init__(self, a: Iterable[T] = []) -> None:
13 """イテラブル ``a`` から ``WBTSet`` を構築します。
14
15 Args:
16 a (Iterable[T], optional): 構築元のイテラブルです。
17
18 計算量:
19
20 ソート済みなら :math:`O(n)` 、そうでないなら :math:`O(n \\log{n})`
21 """
22 self._root: Optional[_WBTSetNode[T]] = None
23 self._min: Optional[_WBTSetNode[T]] = None
24 self._max: Optional[_WBTSetNode[T]] = None
25 self.__build(a)
26
27 def __build(self, a: Iterable[T]) -> None:
28 """再帰的に構築する関数"""
29
30 def build(
31 l: int, r: int, pnode: Optional[_WBTSetNode[T]] = None
32 ) -> _WBTSetNode[T]:
33 if l == r:
34 return None
35 mid = (l + r) // 2
36 node = _WBTSetNode(a[mid])
37 node._left = build(l, mid, node)
38 node._right = build(mid + 1, r, node)
39 node._par = pnode
40 node._update()
41 return node
42
43 a = list(a)
44 if not a:
45 return
46 if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
47 a.sort()
48 new_a = [a[0]]
49 for elm in a:
50 if new_a[-1] == elm:
51 continue
52 new_a.append(elm)
53 a = new_a
54 self._root = build(0, len(a))
55 self._max = self._root._max()
56 self._min = self._root._min()
57
[docs]
58 def add(self, key: T) -> bool:
59 """既に ``key`` が存在していれば何もせず ``False`` を返し、
60 存在していれば ``key`` を 1 つ追加して ``True`` を返します。
61
62 Args:
63 key (T): 追加するキーです。
64
65 Returns:
66 bool: ``key`` を追加したら ``True`` 、そうでなければ ``False`` を返します。
67
68 計算量:
69 :math:`O(\\log{n})`
70 """
71 if not self._root:
72 self._root = _WBTSetNode(key)
73 self._max = self._root
74 self._min = self._root
75 return True
76 pnode = None
77 node = self._root
78 while node:
79 if key == node._key:
80 return False
81 pnode = node
82 node = node._left if key < node._key else node._right
83 if key < pnode._key:
84 pnode._left = _WBTSetNode(key)
85 if key < self._min._key:
86 self._min = pnode._left
87 pnode._left._par = pnode
88 else:
89 pnode._right = _WBTSetNode(key)
90 if key > self._max._key:
91 self._max = pnode._right
92 pnode._right._par = pnode
93 self._root = pnode._rebalance()
94 return True
95
[docs]
96 def find_key(self, key: T) -> Optional[_WBTSetNode[T]]:
97 """``key`` が存在すれば ``key`` を指すノードを返します。
98 そうでなければ ``None`` を返します。
99
100 Args:
101 key (T):
102
103 Returns:
104 Optional[_WBTSetNode[T]]:
105
106 計算量:
107 :math:`O(\\log{n})`
108 """
109 node = self._root
110 while node:
111 if key == node._key:
112 return node
113 node = node._left if key < node._key else node._right
114 return None
115
[docs]
116 def find_order(self, k: int) -> _WBTSetNode[T]:
117 """昇順 ``k`` 番目のノードを返します。
118
119 Args:
120 k (int):
121
122 Returns:
123 _WBTSetNode[T]:
124
125 計算量:
126 :math:`O(\\log{n})`
127
128 制約:
129 :math:`-n \\leq k \\le n`
130 """
131 if k < 0:
132 k += len(self)
133 node = self._root
134 while True:
135 t = node._left._size if node._left else 0
136 if t == k:
137 return node
138 if t < k:
139 k -= t + 1
140 node = node._right
141 else:
142 node = node._left
143
[docs]
144 def count(self, key: T) -> int:
145 return 1 if self.find_key(key) is not None else 0
146
[docs]
147 def remove_iter(self, node: _WBTSetNode[T]) -> None:
148 """``node`` を削除します。
149
150 Args:
151 node (_WBTSetNode[T]):
152
153 計算量:
154 :math:`O(\\log{n})`
155 """
156 if node is self._min:
157 self._min = self._min._next()
158 if node is self._max:
159 self._max = self._max._prev()
160 delnode = node
161 pnode, mnode = node._par, None
162 if node._left and node._right:
163 pnode, mnode = node, node._left
164 while mnode._right:
165 pnode, mnode = mnode, mnode._right
166 node = mnode
167 cnode = node._right if not node._left else node._left
168 if cnode:
169 cnode._par = pnode
170 if pnode:
171 if pnode._left is node:
172 pnode._left = cnode
173 else:
174 pnode._right = cnode
175 self._root = pnode._rebalance()
176 else:
177 self._root = cnode
178 if mnode:
179 if self._root is delnode:
180 self._root = mnode
181 mnode._copy_from(delnode)
182 del delnode
183
[docs]
184 def remove(self, key: T) -> None:
185 """``key`` を削除します。
186
187 Args:
188 key (T): 削除する ``key`` です。
189
190 計算量:
191 :math:`O(\\log{n})`
192
193 Note:
194 ``key`` が存在しない場合、 ``AssertionError`` を出します。
195 """
196 node = self.find_key(key)
197 assert node, f"KeyError: {key} is not exist."
198 self.remove_iter(node)
199
[docs]
200 def discard(self, key: T) -> bool:
201 """``key`` が存在すれば削除して ``True`` を返します。
202 存在しなければなにもせず ``False`` を返します。
203
204 Args:
205 key (T): 削除する ``key`` です。
206
207 Returns:
208 bool: ``key`` が存在したかどうか
209
210 計算量:
211 :math:`O(\\log{n})`
212 """
213 node = self.find_key(key)
214 if node is None:
215 return False
216 self.remove_iter(node)
217 return True
218
[docs]
219 def pop(self, k: int = -1) -> T:
220 """``k`` 番目の値を削除して返します。
221 引数指定がない場合は最大の値を削除して返します。
222
223 Args:
224 k (int, optional): 削除するインデックスです。
225
226 Returns:
227 T: ``k`` 番目の値です。
228
229 計算量:
230 :math:`O(\\log{n})`
231 """
232 node = self.find_order(k)
233 key = node._key
234 self.remove_iter(node)
235 return key
236
[docs]
237 def le_iter(self, key: T) -> Optional[_WBTSetNode[T]]:
238 """``key`` 以下で最大のノードを返します。存在しないときは ``None`` を返します。
239
240 計算量:
241 :math:`O(\\log{n})`
242 """
243 res = None
244 node = self._root
245 while node:
246 if key == node._key:
247 res = node
248 break
249 if key < node._key:
250 node = node._left
251 else:
252 res = node
253 node = node._right
254 return res
255
[docs]
256 def lt_iter(self, key: T) -> Optional[_WBTSetNode[T]]:
257 """``key`` より小さい値で最大のノードを返します。存在しないときは ``None`` を返します。
258
259 計算量:
260 :math:`O(\\log{n})`
261 """
262 res = None
263 node = self._root
264 while node:
265 if key <= node._key:
266 node = node._left
267 else:
268 res = node
269 node = node._right
270 return res
271
[docs]
272 def ge_iter(self, key: T) -> Optional[_WBTSetNode[T]]:
273 """``key`` 以上で最小のノードを返します。存在しないときは ``None`` を返します。
274
275 計算量:
276 :math:`O(\\log{n})`
277 """
278 res = None
279 node = self._root
280 while node:
281 if key == node._key:
282 res = node
283 break
284 if key < node._key:
285 res = node
286 node = node._left
287 else:
288 node = node._right
289 return res
290
[docs]
291 def gt_iter(self, key: T) -> Optional[_WBTSetNode[T]]:
292 """``key`` より大きい値で最小のノードを返します。存在しないときは ``None`` を返します。
293
294 計算量:
295 :math:`O(\\log{n})`
296 """
297 res = None
298 node = self._root
299 while node:
300 if key < node._key:
301 res = node
302 node = node._left
303 else:
304 node = node._right
305 return res
306
[docs]
307 def le(self, key: T) -> Optional[T]:
308 """``key`` 以下で最大の要素を返します。存在しないときは ``None`` を返します。
309
310 計算量:
311 :math:`O(\\log{n})`
312 """
313 res = None
314 node = self._root
315 while node:
316 if key == node._key:
317 res = key
318 break
319 if key < node._key:
320 node = node._left
321 else:
322 res = node._key
323 node = node._right
324 return res
325
[docs]
326 def lt(self, key: T) -> Optional[T]:
327 """``key`` より小さい値で最大の要素を返します。存在しないときは ``None`` を返します。
328
329 計算量:
330 :math:`O(\\log{n})`
331 """
332 res = None
333 node = self._root
334 while node:
335 if key <= node._key:
336 node = node._left
337 else:
338 res = node._key
339 node = node._right
340 return res
341
[docs]
342 def ge(self, key: T) -> Optional[T]:
343 """``key`` 以上で最小の要素を返します。存在しないときは ``None`` を返します。
344
345 計算量:
346 :math:`O(\\log{n})`
347 """
348 res = None
349 node = self._root
350 while node:
351 if key == node._key:
352 res = key
353 break
354 if key < node._key:
355 res = node._key
356 node = node._left
357 else:
358 node = node._right
359 return res
360
[docs]
361 def gt(self, key: T) -> Optional[T]:
362 """``key`` より大きい値で最小の要素を返します。存在しないときは ``None`` を返します。
363
364 計算量:
365 :math:`O(\\log{n})`
366 """
367 res = None
368 node = self._root
369 while node:
370 if key < node._key:
371 res = node._key
372 node = node._left
373 else:
374 node = node._right
375 return res
376
[docs]
377 def index(self, key: T) -> int:
378 """``key`` より小さい値を個数を返します。
379
380 Args:
381 key (T):
382
383 Returns:
384 int:
385
386 計算量:
387 :math:`O(\\log{n})`
388 """
389 k = 0
390 node = self._root
391 while node:
392 if key == node._key:
393 k += node._left._size if node._left else 0
394 break
395 if key < node._key:
396 node = node._left
397 else:
398 k += node._left._size + 1 if node._left else 1
399 node = node._right
400 return k
401
[docs]
402 def index_right(self, key: T) -> int:
403 """``key`` 以下の値を個数を返します。
404
405 Args:
406 key (T):
407
408 Returns:
409 int:
410
411 計算量:
412 :math:`O(\\log{n})`
413 """
414 k = 0
415 node = self._root
416 while node:
417 if key == node._key:
418 k += node._left._size + 1 if node._left else 1
419 break
420 if key < node._key:
421 node = node._left
422 else:
423 k += node._left._size + 1 if node._left else 1
424 node = node._right
425 return k
426
[docs]
427 def get_min(self) -> T:
428 """最小の要素を返します。
429
430 Returns:
431 T:
432
433 計算量:
434 :math:`O(1)`
435
436 制約:
437 :math:`0 < n`
438 """
439 assert self._min
440 return self._min._key
441
[docs]
442 def get_max(self) -> T:
443 """最大の要素を返します。
444
445 Returns:
446 T:
447
448 計算量:
449 :math:`O(1)`
450
451 制約:
452 :math:`0 < n`
453 """
454 assert self._max
455 return self._max._key
456
[docs]
457 def pop_min(self) -> T:
458 """最小の要素を削除して返します。
459
460 Returns:
461 T:
462
463 計算量:
464 :math:`O(\\log{n})`
465
466 制約:
467 :math:`0 < n`
468 """
469 assert self._min
470 key = self._min._key
471 self.remove_iter(self._min)
472 return key
473
[docs]
474 def pop_max(self) -> T:
475 """最大の要素を削除して返します。
476
477 Returns:
478 T:
479
480 計算量:
481 :math:`O(\\log{n})`
482
483 制約:
484 :math:`0 < n`
485 """
486 assert self._max
487 key = self._max._key
488 self.remove_iter(self._max)
489 return key
490
491 def _check(self) -> int:
492 """作業用デバック関数
493 size,key,balanceをチェックして、正しければ高さを表示する
494 """
495 if self._root is None:
496 # print("ok. 0 (empty)")
497 return 0
498
499 # _size, height
500 def dfs(node: _WBTSetNode[T]) -> tuple[int, int]:
501 h = 0
502 s = 1
503 if node._left:
504 assert node._key > node._left._key
505 ls, lh = dfs(node._left)
506 s += ls
507 h = max(h, lh)
508 if node._right:
509 assert node._key < node._right._key
510 rs, rh = dfs(node._right)
511 s += rs
512 h = max(h, rh)
513 assert node._size == s
514 node._balance_check()
515 return s, h + 1
516
517 _, h = dfs(self._root)
518 # print(f"ok. {h}")
519 return h
520
[docs]
521 def __contains__(self, key: T) -> bool:
522 """``key`` が存在すれば ``True`` 、そうでなければ ``False`` を返します。
523
524 Args:
525 key (T):
526
527 Returns:
528 bool:
529
530 計算量:
531 :math:`O(\\log{n})`
532 """
533 return self.find_key(key) is not None
534
[docs]
535 def __getitem__(self, k: int) -> T:
536 """昇順 ``k`` 番目の値を返します。
537
538 Args:
539 k (int):
540
541 Returns:
542 T:
543
544 計算量:
545 k = 0 または k = n-1 の場合: :math:`O(1)`
546 そうでない場合: :math:`O(\\log{n})`
547 """
548 assert (
549 -len(self) <= k < len(self)
550 ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
551 if k < 0:
552 k += len(self)
553 if k == 0:
554 return self.get_min()
555 if k == len(self) - 1:
556 return self.get_max()
557 return self.find_order(k)._key
558
559 # def __delitem__(self, k: int) -> None:
560 # self.remove_iter(self.find_order(k))
561
[docs]
562 def __len__(self) -> int:
563 """要素数を返します。
564
565 Returns:
566 int:
567
568 計算量:
569 :math:`O(1)`
570 """
571 return self._root._size if self._root else 0
572
[docs]
573 def __iter__(self) -> Iterator[T]:
574 """昇順に値を返します。
575
576 Yields:
577 Iterator[T]:
578
579 計算量:
580 全体で :math:`O(n)`
581 """
582 stack: list[_WBTSetNode[T]] = []
583 node = self._root
584 while stack or node:
585 if node:
586 stack.append(node)
587 node = node._left
588 else:
589 node = stack.pop()
590 yield node._key
591 node = node._right
592
[docs]
593 def __reversed__(self) -> Iterator[T]:
594 """降順に値を返します。
595
596 Yields:
597 Iterator[T]:
598
599 計算量:
600 全体で :math:`O(n)`
601 """
602 stack: list[_WBTSetNode[T]] = []
603 node = self._root
604 while stack or node:
605 if node:
606 stack.append(node)
607 node = node._right
608 else:
609 node = stack.pop()
610 yield node._key
611 node = node._left
612
613 def __str__(self) -> str:
614 return "{" + ", ".join(map(str, self)) + "}"
615
616 def __repr__(self) -> str:
617 return f"{self.__class__.__name__}(" + "{" + ", ".join(map(str, self)) + "})"