avl_tree_set3
ソースコード
from titan_pylib.data_structures.avl_tree.avl_tree_set3 import AVLTreeSet3
展開済みコード
1# from titan_pylib.data_structures.avl_tree.avl_tree_set3 import AVLTreeSet3
2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from typing import Protocol


class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with ``<``.

    Any object providing ``__lt__`` satisfies this protocol; it is used as
    the upper bound of the element type variable ``T`` below.
    """

    def __lt__(self, other) -> bool:
        """Return whether ``self`` sorts strictly before ``other``."""


from abc import ABC, abstractmethod
from typing import Generic, Iterable, Iterator, Optional, TypeVar

# Element type for ordered-set containers: anything comparable with ``<``.
T = TypeVar("T", bound=SupportsLessThan)
14
15
class OrderedSetInterface(ABC, Generic[T]):
    """Abstract interface for ordered-set containers over elements of type ``T``.

    Every method is abstract; reaching one of these bodies (e.g. through
    ``super()``) raises ``NotImplementedError``.  The per-method contracts
    below describe the behaviour of the in-file implementation
    ``AVLTreeSet3``; other implementations are expected to match it.
    """

    # ---- construction -------------------------------------------------

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Build the container from the elements of ``a``."""
        raise NotImplementedError

    # ---- mutation -----------------------------------------------------

    @abstractmethod
    def add(self, key: T) -> bool:
        """Insert ``key``; return ``True`` if it was not already present."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T) -> bool:
        """Delete ``key`` if present; return whether something was deleted."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T) -> None:
        """Delete ``key``; raise ``KeyError`` if it is absent."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove and return the largest element."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError

    # ---- neighbour queries -------------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        raise NotImplementedError

    # ---- extremes -----------------------------------------------------

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when empty."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when empty."""
        raise NotImplementedError

    # ---- conversion and container protocol ----------------------------

    @abstractmethod
    def tolist(self) -> list[T]:
        """Return all elements as a list in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Return an iterator over the elements in ascending order."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Advance the container's own iteration cursor."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is stored."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of stored elements."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the container is non-empty."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the elements."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Return a debugging representation."""
        raise NotImplementedError
# from titan_pylib.my_class.supports_less_than import SupportsLessThan
from typing import Generic, Iterable, Optional, Sequence, TypeVar

# Element type stored by AVLTreeSet3: any value comparable with ``<``.
T = TypeVar("T", bound=SupportsLessThan)
105
106
class AVLTreeSet3(OrderedSetInterface, Generic[T]):
    """AVL tree used as an ordered set.

    Every node stores the size of its subtree, which supports positional
    access (``tree[k]``, ``pop(k)``) and rank queries (``index``,
    ``index_right``).  The tree is rebalanced whenever a node's balance
    factor reaches +/-2, keeping every balance factor in [-1, 1].
    Nodes are instances of the nested ``class Node()``.
    """

    class Node:
        """Tree node: a key, its subtree size and its balance factor."""

        # Many nodes are created; slots avoid a per-instance __dict__.
        __slots__ = ("key", "size", "left", "right", "balance")

        def __init__(self, key: T):
            self.key: T = key
            # Number of nodes in the subtree rooted here (including itself).
            self.size: int = 1
            self.left: Optional["AVLTreeSet3.Node"] = None
            self.right: Optional["AVLTreeSet3.Node"] = None
            # height(left subtree) - height(right subtree).
            self.balance: int = 0

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size}\n"
            return (
                f"key:{self.key, self.size},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set holding the elements of ``a``.

        Duplicates are collapsed and input order does not matter
        (see ``_build``).  The default is an immutable empty tuple, so no
        state is shared between instances.
        """
        self.node = None
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Sequence[T]) -> None:
        """Replace the tree with a perfectly balanced tree over ``a``.

        If ``a`` is not strictly increasing it is first replaced by
        ``sorted(set(a))``.
        """
        Node = AVLTreeSet3.Node

        def rec(l: int, r: int) -> tuple[AVLTreeSet3.Node, int]:
            # Build the subtree for a[l:r] (non-empty); return (root, height).
            mid = (l + r) >> 1
            node = Node(a[mid])
            hl, hr = 0, 0
            if l != mid:
                node.left, hl = rec(l, mid)
                node.size += node.left.size
            if mid + 1 != r:
                node.right, hr = rec(mid + 1, r)
                node.size += node.right.size
            node.balance = hl - hr
            return node, max(hl, hr) + 1

        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(set(a))
        self.node = rec(0, len(a))[0]

    def _rotate_L(self, node: Node) -> Node:
        """Promote ``node.left`` over ``node`` (fixes a left-heavy +2 node).

        Returns the new subtree root.  Sizes and balance factors of both
        nodes are updated; the ``u.balance == 0`` branch covers the case
        that can arise from ``discard``.
        """
        u = node.left
        u.size = node.size
        # node loses u itself and u's left subtree.
        node.size -= 1 if u.left is None else u.left.size + 1
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Mirror of ``_rotate_L``: promote ``node.right`` (fixes a -2 node)."""
        u = node.right
        u.size = node.size
        # node loses u itself and u's right subtree.
        node.size -= 1 if u.right is None else u.right.size + 1
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Fix balance factors after a double rotation rooted at ``node``.

        ``node.balance`` still holds its pre-rotation value, which decides
        the children's new balances; ``node`` itself ends balanced.
        """
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Double rotation for a +2 node whose left child leans right.

        ``E`` (the left child's right child) becomes the new subtree root
        and is returned.
        """
        B = node.left
        E = B.right
        E.size = node.size
        # node keeps only E.right from the B side; B drops E and E.right.
        if E.right is None:
            node.size -= B.size
            B.size -= 1
        else:
            node.size -= B.size - E.right.size
            B.size -= E.right.size + 1
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Mirror of ``_rotate_LR`` for a -2 node whose right child leans left."""
        C = node.right
        D = C.left
        D.size = node.size
        # node keeps only D.left from the C side; C drops D and D.left.
        if D.left is None:
            node.size -= C.size
            C.size -= 1
        else:
            node.size -= C.size - D.left.size
            C.size -= D.left.size + 1
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> T:
        """Return the ``k``-th smallest key (negative ``k`` counts from the end).

        No bounds checking is done here; callers validate ``k``.
        """
        if k < 0:
            k += self.node.size
        node = self.node
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            elif t < k:
                k -= t + 1
                node = node.right
            else:
                node = node.left

    def add(self, key: T) -> bool:
        """Insert ``key``; return ``True`` if inserted, ``False`` if present."""
        if self.node is None:
            self.node = AVLTreeSet3.Node(key)
            return True
        pnode = self.node
        path = []
        # Directions taken from the root, one bit per step with the most
        # recent step in the lowest bit: 1 = went left, 0 = went right.
        di = 0
        while pnode is not None:
            if key == pnode.key:
                return False
            elif key < pnode.key:
                path.append(pnode)
                di <<= 1
                di |= 1
                pnode = pnode.left
            else:
                path.append(pnode)
                di <<= 1
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeSet3.Node(key)
        else:
            path[-1].right = AVLTreeSet3.Node(key)
        new_node = None
        # Walk back up, updating sizes/balances until the height stops
        # changing or a rotation restores it.
        while path:
            pnode = path.pop()
            pnode.size += 1
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            # Re-attach the rotated subtree to its parent (or as the root).
            if path:
                gnode = path.pop()
                gnode.size += 1
                if di & 1:
                    gnode.left = new_node
                else:
                    gnode.right = new_node
            else:
                self.node = new_node
        # Ancestors above the stopping point only need their sizes bumped.
        for p in path:
            p.size += 1
        return True

    def discard(self, key: T) -> bool:
        """Delete ``key`` if present; return whether it was deleted."""
        # Same direction encoding as in ``add``: low bit = most recent step,
        # 1 = went left, 0 = went right.
        di = 0
        path = []
        node = self.node
        while node:
            if key == node.key:
                break
            elif key < node.key:
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                path.append(node)
                di <<= 1
                node = node.right
        else:
            return False
        if node.left and node.right:
            # Two children: copy the predecessor (max of the left subtree)
            # into this node and physically remove the predecessor instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = node.left
            while lmax.right:
                path.append(lmax)
                di <<= 1
                lmax = lmax.right
            node.key = lmax.key
            node = lmax
        # ``node`` now has at most one child; splice it out.
        cnode = node.right if node.left is None else node.left
        if path:
            if di & 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.node = cnode
            return True
        # Walk back up, rebalancing while the subtree height keeps shrinking.
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            di >>= 1
            pnode.size -= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # Balance became +/-1: this subtree's height is unchanged.
                break
            if new_node:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                if new_node.balance != 0:
                    # Rotation kept the subtree height; stop rebalancing.
                    break
        # Remaining ancestors only need their sizes decremented.
        for p in path:
            p.size -= 1
        return True

    def remove(self, key: T) -> None:
        """Delete ``key``; raise ``KeyError`` if it is absent."""
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None``."""
        res = None
        node = self.node
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 0 if node.left is None else node.left.size
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (default: the max).

        Negative ``k`` counts from the end.  An empty set or an out-of-range
        ``k`` fails the assertion (same style as ``__getitem__``) instead of
        crashing inside the tree walk.
        """
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop({k}), pop({k}) from Empty {self.__class__.__name__}"
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}.pop({k}), len={len(self)}"
        x = self._kth_elm(k)
        self.discard(x)
        return x

    def pop_max(self) -> T:
        """Remove and return the largest element."""
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop_max(), pop_max from Empty {self.__class__.__name__}"
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop_min(), pop_min from Empty {self.__class__.__name__}"
        return self.pop(0)

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when empty."""
        if self.node is None:
            return None
        return self._kth_elm(-1)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when empty."""
        if self.node is None:
            return None
        return self._kth_elm(0)

    def clear(self) -> None:
        """Remove every element."""
        self.node = None

    def tolist(self) -> list[T]:
        """Return all elements in ascending order (in-order traversal)."""
        a = []
        if self.node is None:
            return a

        def rec(node):
            if node.left is not None:
                rec(node.left)
            a.append(node.key)
            if node.right is not None:
                rec(node.right)

        rec(self.node)
        return a

    def __contains__(self, key: T) -> bool:
        node = self.node
        while node is not None:
            if key == node.key:
                return True
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return False

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element (negative ``k`` from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), len={len(self)}"
        return self._kth_elm(k)

    def __iter__(self):
        """Return a fresh ascending iterator.

        Each call gets its own cursor, so nested loops or ``zip(s, s)``
        work.  The legacy shared cursor used by ``__next__`` is still reset
        so that ``iter(s)`` followed by ``next(s)`` keeps working.
        """
        self.__iter = 0
        return self._iter_by_index()

    def _iter_by_index(self):
        """Yield elements by position, re-checking the length each step.

        Re-reading ``len(self)`` lets iteration stop cleanly if the set
        shrinks while being iterated.
        """
        i = 0
        while i < len(self):
            yield self._kth_elm(i)
            i += 1

    def __next__(self):
        """Advance the legacy shared cursor (set up by ``__iter__``)."""
        # ``>=`` rather than ``==`` so a set that shrank since the last call
        # ends the iteration instead of failing the __getitem__ assertion.
        if self.__iter >= self.__len__():
            raise StopIteration
        res = self.__getitem__(self.__iter)
        self.__iter += 1
        return res

    def __reversed__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(-i - 1)

    def __len__(self):
        return 0 if self.node is None else self.node.size

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"AVLTreeSet3({str(self)})"
仕様
- class AVLTreeSet3(a: Iterable[T] = [])[source]

  Bases: OrderedSetInterface, Generic[T]

  集合としての AVL木 です。size を持ちます。``class Node()`` を用いています。