1from titan_pylib.my_class.supports_less_than import SupportsLessThan
2from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
3from array import array
4from typing import Generic, Iterable, TypeVar, Optional
5
# Key type of the set: bound to SupportsLessThan because keys are ordered
# with `<` (and compared with `==`) throughout AVLTreeSet2.
T = TypeVar("T", bound=SupportsLessThan)
7
8
class AVLTreeSet2(OrderedSetInterface, Generic[T]):
    """AVLTreeSet2
    An ordered set implemented as an AVL tree.

    Nodes are stored in parallel arrays (``key``, ``left``, ``right``,
    ``balance``) indexed by node id. Id 0 is a null sentinel, so a child link
    of 0 means "no child". ``balance[node]`` is
    height(left subtree) - height(right subtree).
    Nodes do not store subtree sizes, which keeps the structure lightweight
    (consequently no k-th element queries are offered).
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a set holding the elements of ``a`` (duplicates dropped).

        Args:
            a: Initial elements. A strictly increasing list is used as-is;
                anything else is sorted and deduplicated first.
        """
        self.root = 0  # id of the root node; 0 means the tree is empty
        self._len = 0
        # Slot 0 is the null sentinel, so every real node has a non-zero id.
        self.key = [0]
        # array("I", [0]) rather than bytes(4): the size of "I" is
        # platform-dependent, and this always yields exactly one element.
        self.left = array("I", [0])
        self.right = array("I", [0])
        self.balance = array("b", [0])
        self.end = 1  # next unused slot id
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def reserve(self, n: int) -> None:
        """Pre-allocate storage for ``n`` more nodes (no-op when ``n <= 0``)."""
        if n <= 0:
            return
        # Extend in place (+=) so references captured elsewhere, e.g. the
        # locals in _build, keep pointing at the live storage.
        self.key += [0] * n
        zeros = array("I", [0]) * n
        self.left += zeros
        self.right += zeros
        self.balance += array("b", [0]) * n

    def _build(self, a: list[T]) -> None:
        """Build a perfectly balanced tree from ``a``.

        ``a`` is sorted and deduplicated first unless it is already strictly
        increasing. Called from ``__init__`` while storage holds only the
        sentinel slot.
        """
        left, right, balance = self.left, self.right, self.balance

        def build_range(l: int, r: int) -> tuple[int, int]:
            # Link node ids [l, r) into a balanced subtree; return (root, height).
            mid = (l + r) >> 1
            node = mid
            hl, hr = 0, 0
            if l != mid:
                left[node], hl = build_range(l, mid)
            if mid + 1 != r:
                right[node], hr = build_range(mid + 1, r)
            balance[node] = hl - hr
            return node, max(hl, hr) + 1

        n = len(a)
        if n == 0:
            return
        if not all(a[i] < a[i + 1] for i in range(n - 1)):
            b = sorted(a)
            a = [b[0]]
            for i in range(1, n):
                if b[i] != a[-1]:
                    a.append(b[i])
            n = len(a)
        self._len = n
        end = self.end
        self.end += n
        self.reserve(n)
        self.key[end : end + n] = a
        self.root = build_range(end, n + end)[0]

    def _rotate_L(self, node: int) -> int:
        """Single rotation for a left-heavy ``node`` (balance +2).

        The left child ``u`` becomes the subtree root; returns ``u``.
        """
        left, right, balance = self.left, self.right, self.balance
        u = left[node]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (only possible after a deletion): height is kept.
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Single rotation for a right-heavy ``node`` (balance -2).

        The right child ``u`` becomes the subtree root; returns ``u``.
        """
        left, right, balance = self.left, self.right, self.balance
        u = right[node]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (only possible after a deletion): height is kept.
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balance factors after a double rotation.

        ``node`` is the new subtree root. Its pre-rotation balance decides
        its two children's factors; ``node`` itself ends balanced.
        """
        balance = self.balance
        if balance[node] == 1:
            balance[self.right[node]] = -1
            balance[self.left[node]] = 0
        elif balance[node] == -1:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 1
        else:
            balance[self.right[node]] = 0
            balance[self.left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Double rotation for a left-heavy ``node`` whose left child is right-heavy.

        Returns the new subtree root.
        """
        left, right = self.left, self.right
        B = left[node]
        E = right[B]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Double rotation for a right-heavy ``node`` whose right child is left-heavy.

        Returns the new subtree root.
        """
        left, right = self.left, self.right
        C = right[node]
        D = left[C]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _make_node(self, key: T) -> int:
        """Allocate a childless, balanced node holding ``key``; return its id."""
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Reusing a pre-allocated slot (from reserve() or recycled by
            # clear()); it may hold stale links, so reset them.
            self.key[end] = key
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end += 1
        return end

    def add(self, key: T) -> bool:
        """Insert ``key``. Return True if it was added, False if already present."""
        if self.root == 0:
            self.root = self._make_node(key)
            self._len = 1
            return True
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        node = self.root
        path = []
        # ``di`` records the descent as bits: 1 = went left, 0 = went right;
        # the least significant bit is the most recent step.
        di = 0
        while node:
            if key == keys[node]:
                return False
            di <<= 1
            path.append(node)
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        self._len += 1
        if di & 1:
            left[path[-1]] = self._make_node(key)
        else:
            right[path[-1]] = self._make_node(key)
        # Retrace towards the root updating balance factors; an insertion
        # needs at most one (single or double) rotation, then we stop.
        new_node = 0
        while path:
            pnode = path.pop()
            balance[pnode] += 1 if di & 1 else -1
            di >>= 1
            if balance[pnode] == 0:
                break  # subtree height unchanged
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
                break
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
                break
        if new_node:
            # Re-link the rotated subtree under its former parent (or as root).
            if path:
                gnode = path.pop()
                if di & 1:
                    left[gnode] = new_node
                else:
                    right[gnode] = new_node
            else:
                self.root = new_node
        return True

    def remove(self, key: T) -> bool:
        """Remove ``key``; raise ``KeyError`` if it is not present."""
        if self.discard(key):
            return True
        raise KeyError(key)

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present. Return True if removed, False otherwise."""
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        # Descent bits as in add(): 1 = went left, 0 = went right.
        di = 0
        path = []
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        self._len -= 1
        if left[node] and right[node]:
            # Two children: copy the in-order predecessor (max of the left
            # subtree) into this node, then unlink the predecessor instead.
            path.append(node)
            di <<= 1
            di |= 1
            lmax = left[node]
            while right[lmax]:
                path.append(lmax)
                di <<= 1
                lmax = right[lmax]
            keys[node] = keys[lmax]
            node = lmax
        # ``node`` now has at most one child; splice it out.
        cnode = right[node] if left[node] == 0 else left[node]
        if path:
            if di & 1:
                left[path[-1]] = cnode
            else:
                right[path[-1]] = cnode
        else:
            self.root = cnode
            return True
        # Retrace towards the root. Unlike insertion, deletion may rotate at
        # several levels; stop as soon as a subtree keeps its height.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1 if di & 1 else -1
            di >>= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                break  # balance is now +-1: subtree height unchanged
            if new_node:
                if not path:
                    self.root = new_node
                    return True
                if di & 1:
                    left[path[-1]] = new_node
                else:
                    right[path[-1]] = new_node
                if balance[new_node]:
                    break  # rotated subtree kept its height
        return True

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        keys, left, right = self.key, self.left, self.right
        res = None
        node = self.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the set is empty."""
        if not self:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the set is empty."""
        if not self:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def pop_min(self) -> T:
        """Remove and return the smallest element.

        Raises:
            IndexError: if the set is empty.
        """
        if self.root == 0:
            # Without this guard the sentinel key 0 was returned and _len
            # went negative.
            raise IndexError("pop_min from an empty AVLTreeSet2")
        self._len -= 1
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        path = []
        node = self.root
        while left[node]:
            path.append(node)
            node = left[node]
        res = keys[node]
        # The minimum has no left child; replace it by its right child.
        cnode = right[node]
        if path:
            left[path[-1]] = cnode
        else:
            self.root = cnode
            return res
        # Every step on ``path`` went left, so each parent lost height on its
        # left side.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] -= 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                break  # subtree height unchanged
            if new_node:
                if not path:
                    self.root = new_node
                    return res
                left[path[-1]] = new_node
                if balance[new_node]:
                    break
        return res

    def pop_max(self) -> T:
        """Remove and return the largest element.

        Raises:
            IndexError: if the set is empty.
        """
        if self.root == 0:
            # Without this guard the sentinel key 0 was returned and _len
            # went negative.
            raise IndexError("pop_max from an empty AVLTreeSet2")
        self._len -= 1
        left, right, balance, keys = self.left, self.right, self.balance, self.key
        path = []
        node = self.root
        while right[node]:
            path.append(node)
            node = right[node]
        res = keys[node]
        # The maximum has no right child; replace it by its left child.
        cnode = right[node] if left[node] == 0 else left[node]
        if path:
            right[path[-1]] = cnode
        else:
            self.root = cnode
            return res
        # Every step on ``path`` went right, so each parent lost height on its
        # right side.
        while path:
            new_node = 0
            pnode = path.pop()
            balance[pnode] += 1
            if balance[pnode] == 2:
                new_node = (
                    self._rotate_LR(pnode)
                    if balance[left[pnode]] == -1
                    else self._rotate_L(pnode)
                )
            elif balance[pnode] == -2:
                new_node = (
                    self._rotate_RL(pnode)
                    if balance[right[pnode]] == 1
                    else self._rotate_R(pnode)
                )
            elif balance[pnode]:
                break  # subtree height unchanged
            if new_node:
                if not path:
                    self.root = new_node
                    return res
                right[path[-1]] = new_node
                if balance[new_node]:
                    break
        return res

    def clear(self) -> None:
        """Remove all elements.

        Node storage is not shrunk; its slots are reused by later insertions.
        """
        self.root = 0
        self._len = 0  # previously left stale, so len() was wrong after clear()
        self.end = 1  # recycle slots; _make_node resets a reused slot's links

    def tolist(self) -> list[T]:
        """Return the elements in ascending order (iterative in-order walk)."""
        left, right, keys = self.left, self.right, self.key
        node = self.root
        stack, a = [], []
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append(keys[node])
                node = right[node]
        return a

    def __contains__(self, key: T) -> bool:
        keys, left, right = self.key, self.left, self.right
        node = self.root
        while node:
            if key == keys[node]:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    def __iter__(self):
        """Yield the elements in ascending order.

        Each call returns an independent generator, so nested loops over the
        same set work. It advances by value (``gt``), as the original cursor did.
        """
        key = self.get_min()
        while key is not None:
            yield key
            key = self.gt(key)

    def __next__(self) -> T:
        # Legacy manual cursor kept for backward compatibility. It advances
        # ``self.it`` (if a caller set it) to the next larger element.
        # ``for`` loops and ``iter()`` use the generator from __iter__ and
        # never touch this cursor.
        it = getattr(self, "it", None)
        if it is None:
            raise StopIteration
        self.it = self.gt(it)
        return it

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"AVLTreeSet2({self})"