1from titan_pylib.my_class.supports_less_than import SupportsLessThan
2from titan_pylib.data_structures.bst_base.bst_multiset_array_base import (
3 BSTMultisetArrayBase,
4)
5from typing import Generic, Iterable, TypeVar, Optional
6from array import array
7
# Key type of the multiset. Keys are ordered with "<" (and matched with
# "=="), so the TypeVar is bounded by SupportsLessThan.
T = TypeVar("T", bound=SupportsLessThan)
9
10
[docs]
class AVLTreeMultiset2(Generic[T]):
    """Multiset implemented as an AVL tree.

    Nodes are stored in parallel arrays indexed by node id; index 0 is a
    sentinel meaning "no node". Each distinct key occupies one node and
    ``val`` holds its multiplicity. Subtree sizes are not stored, which keeps
    the per-node footprint small.

    Attributes:
        root: index of the root node (0 when the tree is empty).
        key, val: key and multiplicity of each node.
        left, right: child node indices (0 = no child).
        balance: height(left subtree) - height(right subtree) of each node;
            kept in [-1, 1] between operations by the rotations below.
        end: index the next freshly allocated node will receive.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Create a multiset containing the elements of ``a``.

        Args:
            a: initial elements in any order; duplicates are counted.
        """
        self.root = 0
        self._len = 0  # total element count, including multiplicities
        # Slot 0 is the sentinel node, so child index 0 means "no child".
        self.key = [0]
        self.val = [0]
        self.left = array("I", bytes(4))
        self.right = array("I", bytes(4))
        self.balance = array("b", bytes(1))
        self.end = 1
        if not isinstance(a, list):
            a = list(a)
        if a:
            self._build(a)

    def _make_node(self, key: T, val: int) -> int:
        """Allocate a leaf node holding ``val`` copies of ``key``.

        Returns:
            The index of the new node.
        """
        end = self.end
        if end >= len(self.key):
            self.key.append(key)
            self.val.append(val)
            self.left.append(0)
            self.right.append(0)
            self.balance.append(0)
        else:
            # Reuse a slot pre-allocated by reserve(); reset every field so
            # the new leaf can never inherit stale links or balance.
            self.key[end] = key
            self.val[end] = val
            self.left[end] = 0
            self.right[end] = 0
            self.balance[end] = 0
        self.end = end + 1
        return end

    def reserve(self, n: int) -> None:
        """Append ``n`` zero-initialised node slots to the backing storage.

        Later insertions fill these slots before the arrays grow again.
        Non-positive ``n`` is a no-op.
        """
        if n <= 0:
            return
        zeros = [0] * n
        self.key += zeros
        self.val += zeros
        links = array("I", bytes(4 * n))
        self.left += links
        self.right += links
        self.balance += array("b", bytes(n))

    def _build(self, a: list[T]) -> None:
        """Bulk-load ``a`` into the (empty) tree as a perfectly balanced BST."""
        self._len = len(a)
        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        # Run-length encode: distinct keys and their multiplicities.
        xs, counts = BSTMultisetArrayBase[AVLTreeMultiset2, T]._rle(a)
        n = len(xs)
        start = self.end
        self.reserve(n)
        self.end += n
        self.key[start : start + n] = xs
        self.val[start : start + n] = counts
        left, right, balance = self.left, self.right, self.balance

        def build(lo: int, hi: int) -> tuple[int, int]:
            # Make the middle index of [lo, hi) the subtree root and
            # return (root index, subtree height).
            mid = (lo + hi) >> 1
            hl = hr = 0
            if lo != mid:
                left[mid], hl = build(lo, mid)
            if mid + 1 != hi:
                right[mid], hr = build(mid + 1, hi)
            balance[mid] = hl - hr
            return mid, max(hl, hr) + 1

        self.root = build(start, start + n)[0]

    def _rotate_L(self, node: int) -> int:
        """Single rotation for a left-heavy ``node`` (balance +2).

        Lifts the left child ``u`` into ``node``'s position and returns ``u``.
        Callers use it only when ``balance[u]`` is 1 or 0.
        """
        left, right, balance = self.left, self.right, self.balance
        u = left[node]
        left[node] = right[u]
        right[u] = node
        if balance[u] == 1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (deletion only): the subtree keeps its height.
            balance[u] = -1
            balance[node] = 1
        return u

    def _rotate_R(self, node: int) -> int:
        """Single rotation for a right-heavy ``node`` (balance -2).

        Lifts the right child ``u`` into ``node``'s position and returns ``u``.
        Callers use it only when ``balance[u]`` is -1 or 0.
        """
        left, right, balance = self.left, self.right, self.balance
        u = right[node]
        right[node] = left[u]
        left[u] = node
        if balance[u] == -1:
            balance[u] = 0
            balance[node] = 0
        else:
            # balance[u] == 0 (deletion only): the subtree keeps its height.
            balance[u] = 1
            balance[node] = -1
        return u

    def _update_balance(self, node: int) -> None:
        """Fix balances after a double rotation that made ``node`` the root.

        ``balance[node]`` still holds its pre-rotation value, which tells
        which of its former subtrees was taller and therefore which new
        child ends up balanced.
        """
        left, right, balance = self.left, self.right, self.balance
        if balance[node] == 1:
            balance[right[node]] = -1
            balance[left[node]] = 0
        elif balance[node] == -1:
            balance[right[node]] = 0
            balance[left[node]] = 1
        else:
            balance[right[node]] = 0
            balance[left[node]] = 0
        balance[node] = 0

    def _rotate_LR(self, node: int) -> int:
        """Double rotation for ``node`` whose left child is right-heavy.

        Returns the new subtree root (the left child's right child).
        """
        left, right = self.left, self.right
        B = left[node]
        E = right[B]
        right[B] = left[E]
        left[E] = B
        left[node] = right[E]
        right[E] = node
        self._update_balance(E)
        return E

    def _rotate_RL(self, node: int) -> int:
        """Double rotation for ``node`` whose right child is left-heavy.

        Returns the new subtree root (the right child's left child).
        """
        left, right = self.left, self.right
        C = right[node]
        D = left[C]
        left[C] = right[D]
        right[D] = C
        right[node] = left[D]
        left[D] = node
        self._update_balance(D)
        return D

    def _discard(self, node: int, path: list[int], di: int) -> bool:
        """Unlink ``node`` from the tree and restore the AVL balance.

        Args:
            node: index of the node to remove (all its copies go with it).
            path: ancestors of ``node`` from the root downwards; consumed.
            di: direction bits for ``path``; the lowest bit describes the
                step from ``path[-1]`` to its child (1 = left, 0 = right).

        Returns:
            Always True.
        """
        left, right, keys, vals, balance = (
            self.left,
            self.right,
            self.key,
            self.val,
            self.balance,
        )
        if left[node] and right[node]:
            # Two children: copy the in-order predecessor (maximum of the
            # left subtree) into ``node`` and unlink the predecessor instead,
            # since it has at most one child.
            path.append(node)
            di = (di << 1) | 1
            pred = left[node]
            while right[pred]:
                path.append(pred)
                di <<= 1
                pred = right[pred]
            keys[node] = keys[pred]
            vals[node] = vals[pred]
            node = pred
        child = left[node] if left[node] else right[node]
        if not path:
            self.root = child
            return True
        if di & 1:
            left[path[-1]] = child
        else:
            right[path[-1]] = child
        # Walk back up: on each ancestor, the side given by ``di & 1`` has
        # just become one level shorter.
        while path:
            parent = path.pop()
            balance[parent] += -1 if di & 1 else 1
            di >>= 1
            b = balance[parent]
            if b == 2:
                sub = (
                    self._rotate_LR(parent)
                    if balance[left[parent]] < 0
                    else self._rotate_L(parent)
                )
            elif b == -2:
                sub = (
                    self._rotate_RL(parent)
                    if balance[right[parent]] > 0
                    else self._rotate_R(parent)
                )
            elif b != 0:
                # |balance| became 1: this subtree's height is unchanged.
                break
            else:
                # balance became 0: the subtree shrank; keep propagating.
                continue
            # A rotation replaced ``parent``; hook the new root back in.
            if not path:
                self.root = sub
                return True
            if di & 1:
                left[path[-1]] = sub
            else:
                right[path[-1]] = sub
            if balance[sub] != 0:
                # Single rotation around a balanced child keeps the height.
                break
        return True

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` copies of ``key``.

        If ``val`` is at least the current multiplicity, every copy is
        removed and the node is unlinked from the tree.

        Returns:
            False if ``key`` is not present, True otherwise.
        """
        keys, vals, left, right = self.key, self.val, self.left, self.right
        path: list[int] = []
        di = 0
        node = self.root
        while node:
            if key == keys[node]:
                break
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        else:
            return False
        cnt = vals[node]
        if val >= cnt:
            # Removing every copy: unlink the node instead of leaving a
            # zero-count node behind in the tree.
            self._len -= cnt
            self._discard(node, path, di)
        else:
            self._len -= val
            vals[node] = cnt - val
        return True

    def discard_all(self, key: T) -> None:
        """Remove every copy of ``key``; does nothing if ``key`` is absent."""
        self.discard(key, self.count(key))

    def remove(self, key: T, val: int = 1) -> None:
        """Like :meth:`discard`, but raise ``KeyError`` if ``key`` is absent."""
        if self.discard(key, val):
            return
        raise KeyError(key)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` copies of ``key`` (``val`` is expected to be >= 1)."""
        self._len += val
        if self.root == 0:
            self.root = self._make_node(key, val)
            return
        left, right, keys, balance = self.left, self.right, self.key, self.balance
        node = self.root
        di = 0
        path: list[int] = []
        while node:
            if key == keys[node]:
                self.val[node] += val
                return
            path.append(node)
            di <<= 1
            if key < keys[node]:
                di |= 1
                node = left[node]
            else:
                node = right[node]
        leaf = self._make_node(key, val)
        if di & 1:
            left[path[-1]] = leaf
        else:
            right[path[-1]] = leaf
        # Walk back up: on each ancestor, the side given by ``di & 1`` has
        # just grown by one level.
        while path:
            node = path.pop()
            balance[node] += 1 if di & 1 else -1
            di >>= 1
            b = balance[node]
            if b == 0:
                # Heights evened out; nothing above changes.
                return
            if b == 2 or b == -2:
                if b == 2:
                    sub = (
                        self._rotate_LR(node)
                        if balance[left[node]] < 0
                        else self._rotate_L(node)
                    )
                else:
                    sub = (
                        self._rotate_RL(node)
                        if balance[right[node]] > 0
                        else self._rotate_R(node)
                    )
                # After an insertion one (single or double) rotation restores
                # the subtree's previous height, so the walk stops here.
                if path:
                    if di & 1:
                        left[path[-1]] = sub
                    else:
                        right[path[-1]] = sub
                else:
                    self.root = sub
                return
            # |balance| == 1: this subtree got taller; continue upward.

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key``.

        Delegates to ``BSTMultisetArrayBase.count``.
        """
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].count(self, key)

    def le(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetArrayBase.le``.

        Expected to return the largest element <= ``key``, or None.
        """
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].le(self, key)

    def lt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetArrayBase.lt``.

        Expected to return the largest element < ``key``, or None.
        """
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].lt(self, key)

    def ge(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetArrayBase.ge``.

        Expected to return the smallest element >= ``key``, or None.
        """
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].ge(self, key)

    def gt(self, key: T) -> Optional[T]:
        """Delegate to ``BSTMultisetArrayBase.gt``.

        Expected to return the smallest element > ``key``, or None.
        """
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].gt(self, key)

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None if the multiset is empty."""
        if self.root == 0:
            return None
        left = self.left
        node = self.root
        while left[node]:
            node = left[node]
        return self.key[node]

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None if the multiset is empty."""
        if self.root == 0:
            return None
        right = self.right
        node = self.root
        while right[node]:
            node = right[node]
        return self.key[node]

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest element.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.root == 0:
            raise IndexError("pop_min from an empty AVLTreeMultiset2")
        left, vals, keys = self.left, self.val, self.key
        node = self.root
        path: list[int] = []
        while left[node]:
            path.append(node)
            node = left[node]
        x = keys[node]
        self._len -= 1
        if vals[node] == 1:
            # Every step on the path went left, so all direction bits are 1.
            self._discard(node, path, (1 << len(path)) - 1)
        else:
            vals[node] -= 1
        return x

    def pop_max(self) -> T:
        """Remove and return one copy of the largest element.

        Raises:
            IndexError: if the multiset is empty.
        """
        if self.root == 0:
            raise IndexError("pop_max from an empty AVLTreeMultiset2")
        right, vals, keys = self.right, self.val, self.key
        node = self.root
        path: list[int] = []
        while right[node]:
            path.append(node)
            node = right[node]
        x = keys[node]
        self._len -= 1
        if vals[node] == 1:
            # Every step on the path went right, so all direction bits are 0.
            self._discard(node, path, 0)
        else:
            vals[node] -= 1
        return x

    def clear(self) -> None:
        """Remove every element (allocated node storage is kept)."""
        self.root = 0
        self._len = 0

    def tolist(self) -> list[T]:
        """Delegate to ``BSTMultisetArrayBase.tolist``."""
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].tolist(self)

    def tolist_items(self) -> list[tuple[T, int]]:
        """Return ``(key, multiplicity)`` pairs in ascending key order."""
        left, right, keys, vals = self.left, self.right, self.key, self.val
        node = self.root
        stack: list[int] = []
        a: list[tuple[T, int]] = []
        # Iterative in-order traversal.
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.append((keys[node], vals[node]))
                node = right[node]
        return a

    def __contains__(self, key: T) -> bool:
        return BSTMultisetArrayBase[AVLTreeMultiset2, T].contains(self, key)

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self.root != 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.tolist()})"