from typing import Generic, Iterable, Iterator, Optional, Sequence, TypeVar

from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
from titan_pylib.my_class.supports_less_than import SupportsLessThan
4
# Type of the stored keys. The tree only compares keys with ``<`` and ``==``,
# so the bound is the project's SupportsLessThan type.
T = TypeVar("T", bound=SupportsLessThan)
6
7
class AVLTreeSet3(OrderedSetInterface, Generic[T]):
    """Ordered set implemented as an AVL tree.

    Every node stores the size of its subtree, which makes positional
    operations (``s[k]``, ``index``, ``index_right``, ``pop(k)``) available
    alongside membership, insertion, deletion and neighbour queries
    (``le`` / ``lt`` / ``ge`` / ``gt``).

    Nodes are instances of the nested ``class Node()``. Keys are compared
    with ``<`` and ``==``; adding a key that is already present is a no-op
    (set semantics).

    The set must not be modified while it is being iterated.
    """

    class Node:
        """One tree node: key, subtree size, children and balance factor."""

        # A tree can hold very many nodes; a fixed attribute layout avoids a
        # per-instance __dict__ and makes attribute access faster.
        __slots__ = ("key", "size", "left", "right", "balance")

        def __init__(self, key: T):
            self.key: T = key
            # Number of nodes in the subtree rooted at this node (self included).
            self.size: int = 1
            self.left: Optional["AVLTreeSet3.Node"] = None
            self.right: Optional["AVLTreeSet3.Node"] = None
            # height(left) - height(right); rebalancing triggers at +-2.
            self.balance: int = 0

        def __str__(self):
            if self.left is None and self.right is None:
                return f"key:{self.key, self.size}\n"
            return (
                f"key:{self.key, self.size},\n left:{self.left},\n right:{self.right}\n"
            )

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set containing the elements of ``a``.

        Args:
            a: Initial elements. Duplicates are dropped; the input need not
               be sorted. (Immutable default avoids the shared-mutable-default
               pitfall; an empty tuple behaves exactly like the old ``[]``.)
        """
        self.node: Optional["AVLTreeSet3.Node"] = None
        if not isinstance(a, Sequence):
            a = list(a)
        if a:
            self._build(a)

    def _build(self, a: Sequence[T]) -> None:
        """Build a height-balanced tree from ``a`` (non-empty).

        Runs in O(n) when ``a`` is already strictly increasing; otherwise the
        input is deduplicated and sorted first (O(n log n)).
        """
        Node = AVLTreeSet3.Node

        def rec(l: int, r: int) -> tuple[AVLTreeSet3.Node, int]:
            # Build the subtree for a[l:r] rooted at its middle element;
            # return (subtree root, subtree height).
            mid = (l + r) >> 1
            node = Node(a[mid])
            hl, hr = 0, 0
            if l != mid:
                node.left, hl = rec(l, mid)
                node.size += node.left.size
            if mid + 1 != r:
                node.right, hr = rec(mid + 1, r)
                node.size += node.right.size
            node.balance = hl - hr
            return node, max(hl, hr) + 1

        # Fast path: skip sorting when the input is already strictly increasing.
        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(set(a))
        self.node = rec(0, len(a))[0]

    def _rotate_L(self, node: Node) -> Node:
        """Fix a left-heavy ``node`` (balance 2) with a single rotation.

        ``node.left`` becomes the new subtree root, which is returned.
        Subtree sizes and balance factors are updated in place.
        """
        u = node.left
        u.size = node.size
        # node loses u and u's left subtree.
        node.size -= 1 if u.left is None else u.left.size + 1
        node.left = u.right
        u.right = node
        if u.balance == 1:
            u.balance = 0
            node.balance = 0
        else:
            # Callers route u.balance == -1 to _rotate_LR, so here u.balance
            # is 0 and the subtree keeps its height.
            u.balance = -1
            node.balance = 1
        return u

    def _rotate_R(self, node: Node) -> Node:
        """Fix a right-heavy ``node`` (balance -2) with a single rotation.

        ``node.right`` becomes the new subtree root, which is returned.
        Subtree sizes and balance factors are updated in place.
        """
        u = node.right
        u.size = node.size
        # node loses u and u's right subtree.
        node.size -= 1 if u.right is None else u.right.size + 1
        node.right = u.left
        u.left = node
        if u.balance == -1:
            u.balance = 0
            node.balance = 0
        else:
            # Callers route u.balance == 1 to _rotate_RL, so here u.balance
            # is 0 and the subtree keeps its height.
            u.balance = 1
            node.balance = -1
        return u

    def _update_balance(self, node: Node) -> None:
        """Reset balance factors after a double rotation whose new root is ``node``.

        ``node.balance`` still holds the pre-rotation value of the middle
        node; it decides which of the two children ends up leaning.
        """
        if node.balance == 1:
            node.right.balance = -1
            node.left.balance = 0
        elif node.balance == -1:
            node.right.balance = 0
            node.left.balance = 1
        else:
            node.right.balance = 0
            node.left.balance = 0
        # The root of a double rotation is always perfectly balanced.
        node.balance = 0

    def _rotate_LR(self, node: Node) -> Node:
        """Fix a left-right imbalance with a double rotation.

        Here ``node.balance == 2`` and ``node.left.balance == -1``. The grandchild
        ``node.left.right`` becomes the new subtree root, which is returned.
        """
        B = node.left
        E = B.right
        E.size = node.size
        if E.right is None:
            node.size -= B.size
            B.size -= 1
        else:
            # node keeps E.right; B loses E and E.right.
            node.size -= B.size - E.right.size
            B.size -= E.right.size + 1
        B.right = E.left
        E.left = B
        node.left = E.right
        E.right = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: Node) -> Node:
        """Fix a right-left imbalance with a double rotation.

        Here ``node.balance == -2`` and ``node.right.balance == 1``. The grandchild
        ``node.right.left`` becomes the new subtree root, which is returned.
        """
        C = node.right
        D = C.left
        D.size = node.size
        if D.left is None:
            node.size -= C.size
            C.size -= 1
        else:
            # node keeps D.left; C loses D and D.left.
            node.size -= C.size - D.left.size
            C.size -= D.left.size + 1
        C.left = D.right
        D.right = C
        node.right = D.left
        D.left = node
        self._update_balance(D)
        return D

    def _kth_elm(self, k: int) -> T:
        """Return the ``k``-th smallest key (0-based; negative counts from the end).

        The caller must guarantee the tree is non-empty and ``k`` is in range.
        """
        if k < 0:
            k += self.node.size
        node = self.node
        while True:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            elif t < k:
                k -= t + 1
                node = node.right
            else:
                node = node.left

    def _inorder(self, reverse: bool = False) -> Iterator[T]:
        """Yield all keys in ascending order (descending if ``reverse``).

        Uses an explicit stack, so a full pass is O(n) with O(height) extra
        memory. Each call returns an independent iterator.
        """
        stack: list[AVLTreeSet3.Node] = []
        node = self.node
        while stack or node is not None:
            if node is not None:
                stack.append(node)
                node = node.right if reverse else node.left
            else:
                node = stack.pop()
                yield node.key
                node = node.left if reverse else node.right

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            True if the key was inserted, False if it was already present.
        """
        if self.node is None:
            self.node = AVLTreeSet3.Node(key)
            return True
        pnode = self.node
        path = []
        # Bit stack of the directions taken from the root: bit 1 = went left,
        # bit 0 = went right; the lowest bit is the most recent step.
        di = 0
        while pnode is not None:
            if key == pnode.key:
                return False
            elif key < pnode.key:
                path.append(pnode)
                di <<= 1
                di |= 1
                pnode = pnode.left
            else:
                path.append(pnode)
                di <<= 1
                pnode = pnode.right
        if di & 1:
            path[-1].left = AVLTreeSet3.Node(key)
        else:
            path[-1].right = AVLTreeSet3.Node(key)
        # Walk back up, updating sizes and balances, until the height stops
        # changing (balance becomes 0) or a rotation restores it.
        new_node = None
        while path:
            pnode = path.pop()
            pnode.size += 1
            pnode.balance += 1 if di & 1 else -1
            di >>= 1
            if pnode.balance == 0:
                break
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
                break
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node is not None:
            # Re-attach the rotated subtree to its parent (or make it the root).
            if path:
                gnode = path.pop()
                gnode.size += 1
                if di & 1:
                    gnode.left = new_node
                else:
                    gnode.right = new_node
            else:
                self.node = new_node
        # Ancestors above the point where rebalancing stopped still gain one node.
        for p in path:
            p.size += 1
        return True

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            True if the key was removed, False if it was not in the set.
        """
        # Same direction bit stack as in add(): 1 = left, 0 = right.
        di = 0
        path = []
        node = self.node
        while node:
            if key == node.key:
                break
            elif key < node.key:
                path.append(node)
                di <<= 1
                di |= 1
                node = node.left
            else:
                path.append(node)
                di <<= 1
                node = node.right
        else:
            return False
        if node.left and node.right:
            # Two children: copy the in-order predecessor (max of the left
            # subtree) into this node, then physically remove the predecessor.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = node.left
            while lmax.right:
                path.append(lmax)
                di <<= 1
                lmax = lmax.right
            node.key = lmax.key
            node = lmax
        # `node` now has at most one child; splice it out.
        cnode = node.right if node.left is None else node.left
        if path:
            if di & 1:
                path[-1].left = cnode
            else:
                path[-1].right = cnode
        else:
            self.node = cnode
            return True
        # Walk back up, updating sizes and balances, until the subtree height
        # stops shrinking.
        while path:
            new_node = None
            pnode = path.pop()
            pnode.balance -= 1 if di & 1 else -1
            di >>= 1
            pnode.size -= 1
            if pnode.balance == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if pnode.left.balance == -1
                    else self._rotate_L(pnode)
                )
            elif pnode.balance == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if pnode.right.balance == 1
                    else self._rotate_R(pnode)
                )
            elif pnode.balance != 0:
                # Height of this subtree is unchanged; stop propagating.
                break
            if new_node:
                if not path:
                    self.node = new_node
                    return True
                if di & 1:
                    path[-1].left = new_node
                else:
                    path[-1].right = new_node
                # A rotation that leaves its root unbalanced kept the height.
                if new_node.balance != 0:
                    break
        # Ancestors above the point where rebalancing stopped still lose one node.
        for p in path:
            p.size -= 1
        return True

    def remove(self, key: T) -> None:
        """Remove ``key``; raise ``KeyError`` if it is not present."""
        if self.discard(key):
            return
        raise KeyError(key)

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        res = None
        node = self.node
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                res = node.key
                node = node.right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        res = None
        node = self.node
        while node is not None:
            if key == node.key:
                res = key
                break
            elif key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        res = None
        node = self.node
        while node is not None:
            if key < node.key:
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 0 if node.left is None else node.left.size
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        k = 0
        node = self.node
        while node is not None:
            if key == node.key:
                k += 1 if node.left is None else node.left.size + 1
                break
            elif key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    def pop(self, k: int = -1) -> T:
        """Remove and return the ``k``-th smallest element (default: the largest).

        Negative ``k`` counts from the end. An empty set or out-of-range ``k``
        fails the assertion (previously an out-of-range ``k`` crashed with
        ``AttributeError`` while walking off the tree).
        """
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop({k}), pop({k}) from Empty {self.__class__.__name__}"
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}.pop({k}), len={len(self)}"
        x = self._kth_elm(k)
        self.discard(x)
        return x

    def pop_max(self) -> T:
        """Remove and return the largest element; the set must be non-empty."""
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop_max(), pop_max from Empty {self.__class__.__name__}"
        return self.pop()

    def pop_min(self) -> T:
        """Remove and return the smallest element; the set must be non-empty."""
        assert (
            self.node is not None
        ), f"IndexError: {self.__class__.__name__}.pop_min(), pop_min from Empty {self.__class__.__name__}"
        return self.pop(0)

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        if self.node is None:
            return None
        return self._kth_elm(-1)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        if self.node is None:
            return None
        return self._kth_elm(0)

    def clear(self) -> None:
        """Remove all elements."""
        self.node = None

    def tolist(self) -> list[T]:
        """Return all elements as a new list in ascending order."""
        return list(self._inorder())

    def __contains__(self, key: T) -> bool:
        node = self.node
        while node is not None:
            if key == node.key:
                return True
            elif key < node.key:
                node = node.left
            else:
                node = node.right
        return False

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element (negative ``k`` counts from the end)."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), len={len(self)}"
        return self._kth_elm(k)

    def __iter__(self) -> Iterator[T]:
        """Return a fresh ascending iterator over the elements (O(n) total).

        The returned iterator is independent of the set object, so nested or
        simultaneous loops over the same set work correctly. (Previously the
        set was its own iterator and an inner loop exhausted the outer one.)
        """
        # Keep the legacy cursor used by __next__ in sync, so code that calls
        # iter(s) and then next(s) directly still behaves as before.
        self.__iter = 0
        return self._inorder()

    def __next__(self) -> T:
        # Legacy index-based cursor, kept for backward compatibility; normal
        # `for` loops use the iterator returned by __iter__ instead.
        if self.__iter == self.__len__():
            raise StopIteration
        res = self.__getitem__(self.__iter)
        self.__iter += 1
        return res

    def __reversed__(self) -> Iterator[T]:
        """Return a descending iterator over the elements (O(n) total)."""
        return self._inorder(reverse=True)

    def __len__(self):
        return 0 if self.node is None else self.node.size

    def __bool__(self):
        return self.node is not None

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"AVLTreeSet3({str(self)})"