1from titan_pylib.data_structures.wbt._wbt_multiset_node import _WBTMultisetNode
2from typing import Generic, TypeVar, Optional, Iterable, Iterator
3
# Key type stored in WBTMultiset. Keys are compared with ==, <, <= and >
# throughout the tree search, so they must support those operators.
T = TypeVar("T")
5
6
[docs]
class WBTMultiset(Generic[T]):
    """Sorted multiset stored in a self-rebalancing binary search tree.

    Each distinct key occupies one ``_WBTMultisetNode``; how many copies of
    the key are present is kept in that node's ``_count``.  Every node also
    caches ``_count_size``, the total multiplicity of its subtree, which drives
    the order-statistic operations (``find_order``, ``index``, ``index_right``,
    ``__getitem__``).  Structural rebalancing is delegated to
    ``_WBTMultisetNode._rebalance`` (presumably weight-balanced, per the "WBT"
    name -- confirm in the node class).  The nodes holding the smallest and
    largest keys are cached in ``_min`` / ``_max`` so ``get_min`` / ``get_max``
    do not walk the tree.

    Keys must support ``==``, ``<``, ``<=`` and ``>``.
    """

    __slots__ = "_root", "_min", "_max"

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Build a multiset from the elements of ``a`` (duplicates are kept).

        Args:
            a: Initial elements, in any order.  The default is an immutable
                empty tuple rather than ``[]`` to avoid a shared mutable
                default argument.
        """
        self._root: Optional[_WBTMultisetNode[T]] = None
        self._min: Optional[_WBTMultisetNode[T]] = None
        self._max: Optional[_WBTMultisetNode[T]] = None
        self.__build(a)

    def __build(self, a: Iterable[T]) -> None:
        """Populate the empty tree from the elements of ``a``.

        The elements are sorted if necessary, run-length encoded into
        (distinct key, multiplicity) pairs, and the pairs are turned into a
        tree by recursively choosing the middle pair as each subtree's root.
        """
        items = list(a)
        if not items:
            return
        # Skip the sort when the input is already non-decreasing.
        if not all(items[i] <= items[i + 1] for i in range(len(items) - 1)):
            items.sort()

        # Run-length encode the sorted items: keys[i] occurs vals[i] times.
        keys: list[T] = []
        vals: list[int] = []
        for elm in items:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)

        def build(
            l: int, r: int, pnode: Optional[_WBTMultisetNode[T]] = None
        ) -> Optional[_WBTMultisetNode[T]]:
            # Build the subtree for keys[l:r]; an empty range yields None.
            if l == r:
                return None
            mid = (l + r) // 2
            node = _WBTMultisetNode(keys[mid], vals[mid])
            node._left = build(l, mid, node)
            node._right = build(mid + 1, r, node)
            node._par = pnode
            # Recompute the node's cached aggregates from its children.
            node._update()
            return node

        self._root = build(0, len(keys))
        self._max = self._root._max()
        self._min = self._root._min()

    def add(self, key: T, count: int = 1) -> None:
        """Insert ``count`` copies of ``key``.

        If ``key`` is already present only its multiplicity grows; otherwise a
        new leaf is attached and ``_rebalance`` is run from the leaf's parent,
        whose return value becomes the new root.
        """
        if not self._root:
            self._root = _WBTMultisetNode(key, count)
            self._max = self._root
            self._min = self._root
            return
        pnode = None
        node = self._root
        while node:
            # Every node on the search path gains ``count`` elements below it.
            node._count_size += count
            if key == node._key:
                node._count += count
                return
            pnode = node
            node = node._left if key < node._key else node._right
        new_node = _WBTMultisetNode(key, count)
        new_node._par = pnode
        if key < pnode._key:
            pnode._left = new_node
            if key < self._min._key:
                self._min = new_node
        else:
            pnode._right = new_node
            if key > self._max._key:
                self._max = new_node
        self._root = pnode._rebalance()

    def find_key(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node storing ``key``, or ``None`` if it is absent."""
        node = self._root
        while node:
            if key == node._key:
                return node
            node = node._left if key < node._key else node._right
        return None

    def find_order(self, k: int) -> _WBTMultisetNode[T]:
        """Return the node holding the ``k``-th smallest element (0-indexed).

        Duplicates are counted individually, so a node with multiplicity ``c``
        covers ``c`` consecutive ranks.  A negative ``k`` counts from the end,
        like a list index (``pop()``'s default of ``-1`` relies on this).

        Raises:
            AssertionError: if ``k`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        assert (
            -n <= k < n
        ), f"IndexError: {self.__class__.__name__}.find_order({k}), len={n}"
        if k < 0:
            k += n
        node = self._root
        while True:
            # t = elements in node's left subtree plus node's own copies.
            t = node._left._count_size + node._count if node._left else node._count
            if t - node._count <= k < t:
                return node
            if t > k:
                node = node._left
            else:
                node = node._right
                k -= t

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 if absent)."""
        node = self.find_key(key)
        return node._count if node is not None else 0

    def _discard_node(self, node: _WBTMultisetNode[T], count: int) -> None:
        """Remove up to ``count`` copies of the key stored in ``node``.

        If ``node`` holds ``count`` copies or fewer, it is unlinked with
        ``remove_iter``.  Otherwise its multiplicity is reduced and the cached
        ``_count_size`` of ``node`` and every ancestor is decremented, so that
        ``len()`` and the order-statistic queries stay consistent.
        """
        if node._count <= count:
            self.remove_iter(node)
            return
        node._count -= count
        while node:
            node._count_size -= count
            node = node._par

    def remove_iter(self, node: _WBTMultisetNode[T]) -> None:
        """Unlink ``node`` -- i.e. every copy of its key -- from the tree.

        ``node`` must belong to this tree.  When it has two children, its
        in-order predecessor ``mnode`` is the node physically spliced out and
        is then handed ``node``'s place via ``_copy_from``.
        """
        # Keep the cached extremes pointing at nodes that stay in the tree.
        if node is self._min:
            self._min = self._min._next()
        if node is self._max:
            self._max = self._max._prev()
        delnode = node
        pnode, mnode = node._par, None
        if node._left and node._right:
            # Two children: find the in-order predecessor (rightmost node of
            # the left subtree); it is the node spliced out below.
            pnode, mnode = node, node._left
            while mnode._right:
                pnode, mnode = mnode, mnode._right
            # delnode takes over mnode's multiplicity; mnode replaces
            # delnode at the end of this method.
            node._count = mnode._count
            node = mnode
        # ``node`` now has at most one child: splice it out of the tree.
        cnode = node._right if not node._left else node._left
        if cnode:
            cnode._par = pnode
        if pnode:
            if pnode._left is node:
                pnode._left = cnode
            else:
                pnode._right = cnode
            self._root = pnode._rebalance()
        else:
            self._root = cnode
        if mnode:
            # Put the predecessor in delnode's position (presumably
            # _copy_from relinks parent/children -- see the node class).
            if self._root is delnode:
                self._root = mnode
            mnode._copy_from(delnode)
        del delnode

    def remove(self, key: T, count: int = 1) -> None:
        """Remove ``count`` copies of ``key`` (every copy if fewer exist).

        Raises:
            AssertionError: if ``key`` is not present.
        """
        node = self.find_key(key)
        assert node, f"KeyError: {key} is not found."
        self._discard_node(node, count)

    def discard(self, key: T, count: int = 1) -> bool:
        """Remove ``count`` copies of ``key`` if present.

        Returns:
            ``True`` if ``key`` was present, ``False`` otherwise.
        """
        node = self.find_key(key)
        if node is None:
            return False
        self._discard_node(node, count)
        return True

    def pop(self, k: int = -1) -> T:
        """Remove and return one copy of the ``k``-th smallest element.

        ``k`` follows list-index conventions; the default ``-1`` pops one copy
        of the largest element.
        """
        node = self.find_order(k)
        key = node._key
        self._discard_node(node, 1)
        return key

    def le_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node with the largest key ``<= key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = node
                break
            if key < node._key:
                node = node._left
            else:
                res = node
                node = node._right
        return res

    def lt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node with the largest key ``< key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key <= node._key:
                node = node._left
            else:
                res = node
                node = node._right
        return res

    def ge_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node with the smallest key ``>= key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = node
                break
            if key < node._key:
                res = node
                node = node._left
            else:
                node = node._right
        return res

    def gt_iter(self, key: T) -> Optional[_WBTMultisetNode[T]]:
        """Return the node with the smallest key ``> key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key < node._key:
                res = node
                node = node._left
            else:
                node = node._right
        return res

    def le(self, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = key
                break
            if key < node._key:
                node = node._left
            else:
                res = node._key
                node = node._right
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key <= node._key:
                node = node._left
            else:
                res = node._key
                node = node._right
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key == node._key:
                res = key
                break
            if key < node._key:
                res = node._key
                node = node._left
            else:
                node = node._right
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None``."""
        res = None
        node = self._root
        while node:
            if key < node._key:
                res = node._key
                node = node._left
            else:
                node = node._right
        return res

    def index(self, key: T) -> int:
        """Return the number of elements strictly less than ``key``."""
        k = 0
        node = self._root
        while node:
            if key == node._key:
                k += node._left._count_size if node._left else 0
                break
            if key < node._key:
                node = node._left
            else:
                k += node._left._count_size + node._count if node._left else node._count
                node = node._right
        return k

    def index_right(self, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        k = 0
        node = self._root
        while node:
            if key == node._key:
                k += node._left._count_size + node._count if node._left else node._count
                break
            if key < node._key:
                node = node._left
            else:
                k += node._left._count_size + node._count if node._left else node._count
                node = node._right
        return k

    def tolist(self) -> list[T]:
        """Return all elements, duplicates included, in ascending order."""
        return list(self)

    def get_min(self) -> T:
        """Return the smallest key (asserts the multiset is non-empty)."""
        assert self._min
        return self._min._key

    def get_max(self) -> T:
        """Return the largest key (asserts the multiset is non-empty)."""
        assert self._max
        return self._max._key

    def pop_min(self) -> T:
        """Remove and return one copy of the smallest key."""
        assert self._min
        key = self._min._key
        self._discard_node(self._min, 1)
        return key

    def pop_max(self) -> T:
        """Remove and return one copy of the largest key."""
        assert self._max
        key = self._max._key
        self._discard_node(self._max, 1)
        return key

    def check(self) -> None:
        """Assert the tree's structural invariants (debugging aid).

        Checks strict key ordering between parent and children, the cached
        ``_size`` and ``_count_size`` of every node, and each node's own
        ``_balance_check``.
        """
        if self._root is None:
            return

        def dfs(node: _WBTMultisetNode[T]) -> tuple[int, int, int]:
            # Returns (node count, element count, height) of the subtree.
            h = 0
            s = 1
            cs = node._count
            if node._left:
                assert node._key > node._left._key
                ls, lcs, lh = dfs(node._left)
                s += ls
                cs += lcs
                h = max(h, lh)
            if node._right:
                assert node._key < node._right._key
                rs, rcs, rh = dfs(node._right)
                s += rs
                cs += rcs
                h = max(h, rh)
            assert node._size == s
            assert node._count_size == cs
            node._balance_check()
            return s, cs, h + 1

        dfs(self._root)

    def __contains__(self, key: T) -> bool:
        return self.find_key(key) is not None

    def __getitem__(self, k: int) -> T:
        """Return the ``k``-th smallest element; negative ``k`` is allowed."""
        assert (
            -len(self) <= k < len(self)
        ), f"IndexError: {self.__class__.__name__}[{k}], len={len(self)}"
        if k < 0:
            k += len(self)
        if k == 0:
            return self.get_min()
        if k == len(self) - 1:
            return self.get_max()
        return self.find_order(k)._key

    def __delitem__(self, k: int) -> None:
        """Remove one copy of the ``k``-th smallest element (negative ``k`` ok)."""
        self._discard_node(self.find_order(k), 1)

    def __len__(self) -> int:
        """Return the number of elements, duplicates included."""
        return self._root._count_size if self._root else 0

    def __iter__(self) -> Iterator[T]:
        """Yield every element in ascending order (duplicates repeated)."""
        stack: list[_WBTMultisetNode[T]] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = node._left
            else:
                node = stack.pop()
                for _ in range(node._count):
                    yield node._key
                node = node._right

    def __reversed__(self) -> Iterator[T]:
        """Yield every element in descending order (duplicates repeated)."""
        stack: list[_WBTMultisetNode[T]] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = node._right
            else:
                node = stack.pop()
                for _ in range(node._count):
                    yield node._key
                node = node._left

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            + "["
            + ", ".join(map(str, self.tolist()))
            + "])"
        )