avl_tree_set2
ソースコード
from titan_pylib.data_structures.avl_tree.avl_tree_set2 import AVLTreeSet2
展開済みコード
1# from titan_pylib.data_structures.avl_tree.avl_tree_set2 import AVLTreeSet2
2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Protocol
4
5
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with ``<``.

    Any object providing ``__lt__`` satisfies it; no explicit inheritance
    is required.
    """

    def __lt__(self, other) -> bool:
        """Return whether ``self`` sorts strictly before ``other``."""
9# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
11from abc import ABC, abstractmethod
12from typing import Iterable, Optional, Iterator, TypeVar, Generic
13
# Element type variable for ordered-set classes; bounded by SupportsLessThan,
# i.e. elements are expected to support ``<``.
T = TypeVar("T", bound=SupportsLessThan)
15
16
class OrderedSetInterface(ABC, Generic[T]):
    """Abstract interface for an ordered set of mutually comparable keys.

    Every method is abstract: a subclass that leaves any of them
    unimplemented cannot be instantiated (``abc`` raises ``TypeError``).
    Each default body raises ``NotImplementedError``.
    """

    # ---- construction -------------------------------------------------

    @abstractmethod
    def __init__(self, a: Iterable[T]) -> None:
        """Initialise the set from the elements of ``a``."""
        raise NotImplementedError

    # ---- mutation -----------------------------------------------------

    @abstractmethod
    def add(self, key: T) -> bool:
        """Insert ``key``; the result tells whether the set changed."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, key: T) -> bool:
        """Remove ``key`` if present; the result tells whether it was."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, key: T) -> None:
        """Remove ``key``, which is expected to be present."""
        raise NotImplementedError

    @abstractmethod
    def pop_max(self) -> T:
        """Remove and return the largest element."""
        raise NotImplementedError

    @abstractmethod
    def pop_min(self) -> T:
        """Remove and return the smallest element."""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove every element."""
        raise NotImplementedError

    # ---- ordered lookups ----------------------------------------------

    @abstractmethod
    def le(self, key: T) -> Optional[T]:
        """Largest element ``<= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def lt(self, key: T) -> Optional[T]:
        """Largest element ``< key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def ge(self, key: T) -> Optional[T]:
        """Smallest element ``>= key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def gt(self, key: T) -> Optional[T]:
        """Smallest element ``> key``, or ``None``."""
        raise NotImplementedError

    @abstractmethod
    def get_max(self) -> Optional[T]:
        """Largest element, or ``None`` when empty."""
        raise NotImplementedError

    @abstractmethod
    def get_min(self) -> Optional[T]:
        """Smallest element, or ``None`` when empty."""
        raise NotImplementedError

    # ---- bulk access and Python protocols -----------------------------

    @abstractmethod
    def tolist(self) -> list[T]:
        """All elements as a list."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Start an iteration over the elements."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self) -> T:
        """Advance the iteration started by ``__iter__``."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: T) -> bool:
        """Membership test for ``key``."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Number of elements."""
        raise NotImplementedError

    @abstractmethod
    def __bool__(self) -> bool:
        """Whether the set is non-empty."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """Human-readable rendering."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """Debug rendering."""
        raise NotImplementedError
102from array import array
103from typing import Generic, Iterable, TypeVar, Optional
104
# Re-declared because this expanded listing inlines several modules (see the
# ``# from titan_pylib...`` markers); same definition and bound as before.
T = TypeVar("T", bound=SupportsLessThan)
106
107
class AVLTreeSet2(OrderedSetInterface, Generic[T]):
    """AVLTreeSet2

    An AVL tree used as an ordered set.

    Nodes are stored in parallel arrays indexed by an integer node id:
    ``key`` (a list), ``left`` / ``right`` (``array("I")`` child ids) and
    ``balance`` (``array("b")``). Id ``0`` is a sentinel meaning "no node".
    No per-node subtree size is stored, which keeps the structure light.

    ``balance[v]`` is height(left subtree) - height(right subtree) and is kept
    in ``{-1, 0, 1}`` between operations.

    Node ids are handed out sequentially from ``end``. Ids of removed nodes
    are not recycled individually; ``clear`` makes all storage reusable.
    """

    def __init__(self, a: Iterable[T] = []) -> None:
        """Create a set holding the distinct elements of ``a``.

        ``a`` is only read (its elements are copied), so the shared list
        default is never mutated.
        """
        self.root = 0
        self._len = 0
        # Slot 0 is the null sentinel that every empty child link points to.
        self.key = [0]
        # Built from [0] rather than raw bytes so the size does not depend on
        # the platform's C ``unsigned int`` width.
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1  # id the next created node will receive
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` additional nodes.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"reserve: n must be non-negative, got {n}")
        # Extend in place: other methods may hold local references to these
        # exact list/array objects.
        self.key += [0] * n
        zeros = array("I", [0]) * n
        self.left += zeros
        self.right += zeros
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Fill an empty tree with the distinct elements of ``a``.

        Produces a perfectly balanced tree. Sorting and deduplication happen
        only when ``a`` is not already strictly increasing.
        """
        left, right, balance = self.left, self.right, self.balance

        def build_range(lo: int, hi: int) -> tuple[int, int]:
            # Make the middle id of [lo, hi) the subtree root, recurse on the
            # halves, and return (root id, subtree height).
            mid = (lo + hi) >> 1
            hl = hr = 0
            if lo != mid:
                left[mid], hl = build_range(lo, mid)
            if mid + 1 != hi:
                right[mid], hr = build_range(mid + 1, hi)
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            unique: list[T] = []
            for x in sorted(a):
                if not unique or x != unique[-1]:
                    unique.append(x)
            a = unique
            n = len(a)
        self._len = n
        start = self.end
        self.end += n
        shortfall = self.end - len(self.key)
        if shortfall > 0:
            self.reserve(shortfall)
        self.key[start : start + n] = a
        self.root = build_range(start, start + n)[0]

    def _rotate_L(self, node: int) -> int:
        """Single rotation for a left-heavy ``node`` (balance +2).

        Lifts the left child ``u`` into ``node``'s place and returns ``u``.
        Callers guarantee ``balance[u]`` is +1 or 0 (0 only during deletion).
        """
        left, right, balance = self.left, self.right, self.balance
        u = left[node]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Mirror of ``_rotate_L`` for a right-heavy ``node`` (balance -2)."""
        left, right, balance = self.left, self.right, self.balance
        u = right[node]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balances after a double rotation made ``node`` the subtree root.

        On entry ``balance[node]`` still holds its pre-rotation value, which
        decides how the two new children end up.
        """
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Double rotation for a left-heavy node whose left child leans right."""
        left, right = self.left, self.right
        B = left[node]
        E = right[B]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Double rotation for a right-heavy node whose right child leans left."""
        left, right = self.left, self.right
        C = right[node]
        D = left[C]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _make_node(self, key: T) -> int:
        """Allocate a leaf holding ``key`` and return its id."""
        node = self.end
        if node < len(self.key):
            # Reuse a reserved or cleared slot; its links may be stale.
            self.key[node] = key
            self.left[node] = 0
            self.right[node] = 0
            self.balance[node] = 0
        else:
            self.key.append(key)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        self.end = node + 1
        return node

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            ``True`` if ``key`` was inserted, ``False`` if already present.
        """
        if self.root == 0:
            self.root = self._make_node(key)
            self._len = 1
            return True
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        node = self.root
        path = []
        # Bit i of ``di`` (least significant first) is 1 when the step out of
        # ``path[-1 - i]`` went to the left child.
        di = 0
        while node:
            if key == keys[node]:
                return False
            di <<= 1
            path.append(node)
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        self._len += 1
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        # Retrace: the subtree on the ``di & 1`` side of each ancestor grew.
        new_node = 0
        while path:
            pnode = path.pop()
            balance[pnode] += 1 if di & 1 else -1
            di >>= 1
            if balance[pnode] == 0:
                break  # this subtree's height is unchanged
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
                break  # rotation restores the pre-insert height
            if balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node:
            if path:
                gnode = path[-1]
                if di & 1:
                    left[gnode] = new_node
                else:
                    right[gnode] = new_node
            else:
                self.root = new_node
        return True

    def remove(self, key: T) -> bool:
        """Remove ``key``.

        Raises:
            KeyError: if ``key`` is not present.
        """
        if self.discard(key):
            return True
        raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            ``True`` if ``key`` was removed, ``False`` if it was absent.
        """
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        # Same encoding as in ``add``: bit i of ``di`` is 1 when the step out
        # of ``path[-1 - i]`` went left.
        di = 0
        path = []
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._len -= 1
        if left[node] and right[node]:
            # Two children: copy in the in-order predecessor (max of the left
            # subtree) and unlink that node instead; it has no right child.
            path.append(node)
            di = (di << 1) | 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        # ``node`` has at most one child now; splice it out.
        cnode = right[node] if left[node] == 0 else left[node]
        if not path:
            self.root = cnode
            return True
        if di & 1:
            left[path[-1]] = cnode
        else:
            right[path[-1]] = cnode
        # Retrace: the subtree on the ``di & 1`` side of each ancestor shrank.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            di >>= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                break  # now +-1: this subtree's height is unchanged
            if new_node:
                if not path:
                    self.root = new_node
                    return True
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node]:
                    break  # rotation kept the subtree height unchanged
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None`` if none exists."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]  # the stored element, not the probe
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None`` if none exists."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None`` if none exists."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                return keys[node]  # the stored element, not the probe
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None`` if none exists."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def get_max(self) -> Optional[T]:
        """Return the largest element, or ``None`` when the set is empty."""
        if not self:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or ``None`` when the set is empty."""
        if not self:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def pop_min(self) -> T:
        """Remove and return the smallest element.

        Raises:
            IndexError: if the set is empty.
        """
        if self.root == 0:
            # Previously returned the sentinel key 0 and drove _len negative.
            raise IndexError("pop_min from empty AVLTreeSet2")
        self._len -= 1
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        path = []
        node = self.root
        while left[node]:
            path.append(node)
            node = left[node]
        res = keys[node]
        cnode = right[node]  # the minimum has no left child
        if not path:
            self.root = cnode
            return res
        left[path[-1]] = cnode
        # Every ancestor on the path lost height on its left side.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return res
                left[path[-1]] = new_node
                if balance[new_node]:
                    break
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element.

        Raises:
            IndexError: if the set is empty.
        """
        if self.root == 0:
            # Previously returned the sentinel key 0 and drove _len negative.
            raise IndexError("pop_max from empty AVLTreeSet2")
        self._len -= 1
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        path = []
        node = self.root
        while right[node]:
            path.append(node)
            node = right[node]
        res = keys[node]
        cnode = left[node]  # the maximum has no right child
        if not path:
            self.root = cnode
            return res
        right[path[-1]] = cnode
        # Every ancestor on the path lost height on its right side.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] += 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                break
            if new_node:
                if not path:
                    self.root = new_node
                    return res
                right[path[-1]] = new_node
                if balance[new_node]:
                    break
        return res

    def clear(self) -> None:
        """Remove every element.

        Node storage is kept and reused by later insertions (``_make_node``
        resets reused slots). Keys of removed nodes stay referenced until
        their slots are overwritten.
        """
        self.root = 0
        self._len = 0  # previously left stale, so len() was wrong after clear
        self.end = 1

    def tolist(self) -> list[T]:
        """Return all elements in ascending order (iterative in-order walk)."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack, a = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def __contains__(self, key: T) -> bool:
        """Return whether ``key`` is in the set."""
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __iter__(self) -> "AVLTreeSet2[T]":
        """Start ascending iteration.

        The cursor is stored on the set itself, so only one iteration can be
        active at a time. Each step is a ``gt`` lookup from the last element
        returned.
        """
        self.it = self.get_min()
        return self

    def __next__(self) -> T:
        """Return the next element, or raise ``StopIteration`` at the end."""
        if self.it is None:
            raise StopIteration
        res = self.it
        self.it = self.gt(res)
        return res

    def __len__(self) -> int:
        """Return the number of elements."""
        return self._len

    def __bool__(self) -> bool:
        """Return whether the set is non-empty."""
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"AVLTreeSet2({self})"
仕様
- class AVLTreeSet2(a: Iterable[T] = []) [source]
  Bases: OrderedSetInterface, Generic[T]
  集合としての AVL 木です。配列を用いてノードを表現しています。各ノードに部分木のサイズ (size) を持たないため、その分軽量です。