1from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
2from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from collections import deque
4from bisect import bisect_left, bisect_right, insort
5from typing import Deque, Generic, TypeVar, Optional, Iterable
6
# Type variable for the elements stored in BTreeSet; bound to SupportsLessThan.
7T = TypeVar("T", bound=SupportsLessThan)
8
9
[docs]
class BTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of distinct keys stored in a B-tree.

    Every node keeps its keys in a sorted list; an internal node with ``k``
    keys owns ``k + 1`` children.  A node is split as soon as it holds more
    than ``_m`` keys.  While deleting, a child that holds exactly ``_m // 2``
    keys is topped up (by borrowing from a sibling or by merging with one)
    before the search descends into it.

    Note:
        Iteration keeps its cursor on the set object itself (``_iter_val``),
        so only one iteration over a given set can be in progress at a time;
        a nested ``for`` loop over the same set resets the outer loop.

    Args:
        a: Initial elements.  Duplicates are ignored.
    """

    class _Node:
        """A single B-tree node: sorted keys plus (for internal nodes) children."""

        def __init__(self):
            self.key: list = []
            # Empty for a leaf; otherwise holds len(self.key) + 1 sub-trees.
            self.child: list["BTreeSet._Node"] = []

        def is_leaf(self) -> bool:
            """Return True when this node has no children."""
            return not self.child

        def split(self, i: int) -> "BTreeSet._Node":
            """Cut this node at index ``i`` and return the new right half.

            Keys ``[i:]`` and children ``[i + 1:]`` move to the returned node;
            this node keeps keys ``[:i]`` and children ``[:i + 1]``.  The
            caller is expected to have removed the separator key beforehand.
            """
            right = BTreeSet._Node()
            self.key, right.key = self.key[:i], self.key[i:]
            self.child, right.child = self.child[: i + 1], self.child[i + 1 :]
            return right

        def insert_key(self, i: int, key: T) -> None:
            """Insert ``key`` at position ``i`` of the key list."""
            self.key.insert(i, key)

        def insert_child(self, i: int, node: "BTreeSet._Node") -> None:
            """Insert ``node`` at position ``i`` of the child list."""
            self.child.insert(i, node)

        def append_key(self, key: T) -> None:
            """Append ``key`` after the current last key."""
            self.key.append(key)

        def append_child(self, node: "BTreeSet._Node") -> None:
            """Append ``node`` after the current last child."""
            self.child.append(node)

        def pop_key(self, i: int = -1) -> T:
            """Remove and return the key at index ``i`` (last key by default)."""
            return self.key.pop(i)

        def len_key(self) -> int:
            """Return the number of keys stored in this node."""
            return len(self.key)

        def insort_key(self, key: T) -> None:
            """Insert ``key`` keeping the key list sorted."""
            insort(self.key, key)

        def pop_child(self, i: int = -1) -> "BTreeSet._Node":
            """Remove and return the child at index ``i`` (last child by default)."""
            return self.child.pop(i)

        def extend_key(self, keys: list[T]) -> None:
            """Append all of ``keys`` to the key list."""
            self.key += keys

        def extend_child(self, children: list["BTreeSet._Node"]) -> None:
            """Append all of ``children`` to the child list."""
            self.child += children

        def __str__(self):
            return str(self.key)

        __repr__ = __str__

    def __init__(self, a: Iterable[T] = ()):
        # Immutable default: avoids sharing one list object between calls.
        self._m: int = 1000  # a node is split once it holds more than _m keys
        self._root: "BTreeSet._Node" = BTreeSet._Node()
        self._len: int = 0
        # Cursor used by __iter__/__next__; None means "no iteration / exhausted".
        self._iter_val: Optional[T] = None
        self._build(a)

    def _build(self, a: Iterable[T]):
        """Insert every element of ``a`` one by one."""
        for e in a:
            self.add(e)

    def _is_over(self, node: "BTreeSet._Node") -> bool:
        """Return True when ``node`` holds more than ``_m`` keys."""
        return node.len_key() > self._m

    def _split_node(self, node: "BTreeSet._Node"):
        """Split ``node`` around its middle key.

        Returns:
            ``(center, right)``: the removed middle key and the new node that
            received everything to the right of it.  ``node`` keeps the left
            part.
        """
        mid = node.len_key() // 2
        center = node.pop_key(mid)
        right = node.split(mid)
        return center, right

    def add(self, key: T) -> bool:
        """Insert ``key``.

        Returns:
            True if the key was inserted, False if it was already present.
        """
        node = self._root
        path = []  # (ancestor, index of the child we descended into)
        while True:
            i = bisect_left(node.key, key)
            if i < node.len_key() and node.key[i] == key:
                return False
            if i >= len(node.child):  # reached a leaf
                break
            path.append((node, i))
            node = node.child[i]
        # The key is absent, so the bisect_left position is the insertion point.
        node.insert_key(i, key)
        self._len += 1
        # Propagate splits upward for as long as the current node overflows.
        while path and self._is_over(node):
            parent, idx = path.pop()
            center, right = self._split_node(node)
            # `center` came from child `idx`, so it belongs at key slot `idx`.
            parent.insert_key(idx, center)
            parent.insert_child(idx + 1, right)
            node = parent
        if self._is_over(node):
            # Still overflowing: split it and grow the tree by one level.
            new_root = BTreeSet._Node()
            center, right = self._split_node(node)
            new_root.append_key(center)
            new_root.append_child(node)
            new_root.append_child(right)
            self._root = new_root
        return True

    def __contains__(self, key: T) -> bool:
        """Return True if ``key`` is stored in the set."""
        node = self._root
        while True:
            i = bisect_left(node.key, key)
            if i < node.len_key() and node.key[i] == key:
                return True
            if node.is_leaf():
                return False
            node = node.child[i]

    def _borrow_from_right(self, node: "BTreeSet._Node", i: int) -> None:
        """Move one key from ``node.child[i + 1]`` into ``node.child[i]``.

        The separator ``node.key[i]`` goes down to the end of child ``i`` and
        the first key of the right sibling takes its place; the sibling's
        first child (if any) follows the key.
        """
        dst, src = node.child[i], node.child[i + 1]
        dst.append_key(node.key[i])
        node.key[i] = src.pop_key(0)
        if src.child:
            dst.append_child(src.pop_child(0))

    def _borrow_from_left(self, node: "BTreeSet._Node", i: int) -> None:
        """Move one key from ``node.child[i - 1]`` into ``node.child[i]``.

        The separator ``node.key[i - 1]`` goes down to the front of child
        ``i`` and the last key of the left sibling takes its place; the
        sibling's last child (if any) follows the key.
        """
        dst, src = node.child[i], node.child[i - 1]
        dst.insert_key(0, node.key[i - 1])
        node.key[i - 1] = src.pop_key()
        if src.child:
            dst.insert_child(0, src.pop_child())

    def _discard_right(self, node: "BTreeSet._Node") -> T:
        """Remove and return the largest key of the sub-tree rooted at ``node``."""
        half = self._m // 2
        while not node.is_leaf():
            last = node.child[-1]
            if last.len_key() == half:
                if node.child[-2].len_key() > half:
                    self._borrow_from_left(node, len(node.child) - 1)
                    node = last
                    continue
                # Neither sibling can spare a key: merge the last two children.
                merged = self._merge(node, node.len_key() - 1)
                if node is self._root and not node.key:
                    self._root = merged
                node = merged
                continue
            node = last
        return node.pop_key()

    def _discard_left(self, node: "BTreeSet._Node") -> T:
        """Remove and return the smallest key of the sub-tree rooted at ``node``."""
        half = self._m // 2
        while not node.is_leaf():
            first = node.child[0]
            if first.len_key() == half:
                if node.child[1].len_key() > half:
                    self._borrow_from_right(node, 0)
                    node = first
                    continue
                # Neither sibling can spare a key: merge the first two children.
                merged = self._merge(node, 0)
                if node is self._root and not node.key:
                    self._root = merged
                node = merged
                continue
            node = first
        return node.pop_key(0)

    def _merge(self, node: "BTreeSet._Node", i: int) -> "BTreeSet._Node":
        """Merge ``child[i]``, separator ``key[i]`` and ``child[i + 1]``.

        The separator and the right child are removed from ``node``; the
        combined node (the former ``child[i]``) is returned.
        """
        y = node.child[i]
        z = node.pop_child(i + 1)
        y.append_key(node.pop_key(i))
        y.extend_key(z.key)
        y.extend_child(z.child)
        return y

    def _merge_key(self, key: T, node: "BTreeSet._Node", i: int) -> None:
        """Delete ``key`` located at ``node.key[i]`` of an internal node.

        The key is overwritten by the largest key of the left sub-tree or the
        smallest key of the right sub-tree when that sub-tree's root holds more
        than ``_m // 2`` keys; otherwise both sub-trees are merged around the
        key and the deletion continues inside the merged node.
        """
        half = self._m // 2
        if node.child[i].len_key() > half:
            node.key[i] = self._discard_right(node.child[i])
            return
        if node.child[i + 1].len_key() > half:
            node.key[i] = self._discard_left(node.child[i + 1])
            return
        y = self._merge(node, i)
        self._discard(key, y)
        if node is self._root and not node.key:
            # The root lost its only key: the merged node becomes the root.
            self._root = y

    def _discard(self, key: T, node: Optional["BTreeSet._Node"] = None) -> bool:
        """Delete ``key`` from the sub-tree rooted at ``node`` (default: root).

        Does not touch ``_len``; that is the caller's job.

        Returns:
            True if the key was found and removed, False otherwise.
        """
        if node is None:
            node = self._root
        if not node.key:
            return False
        half = self._m // 2
        while True:
            i = bisect_left(node.key, key)
            found = i < node.len_key() and node.key[i] == key
            if node.is_leaf():
                if found:
                    node.pop_key(i)
                    return True
                return False
            if found:
                assert i + 1 < len(node.child)
                self._merge_key(key, node, i)
                return True
            if node.child[i].len_key() == half:
                # Top up the child before descending into it.
                if i + 1 < len(node.child) and node.child[i + 1].len_key() > half:
                    self._borrow_from_right(node, i)
                    node = node.child[i]
                    continue
                if i - 1 >= 0 and node.child[i - 1].len_key() > half:
                    self._borrow_from_left(node, i)
                    node = node.child[i]
                    continue
                if i + 1 >= len(node.child):
                    # No right sibling: merge with the left one instead.
                    i -= 1
                merged = self._merge(node, i)
                if node is self._root and not node.key:
                    self._root = merged
                node = merged
                continue
            node = node.child[i]

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            True if the key was removed, False if it was not in the set.
        """
        if self._discard(key):
            self._len -= 1
            return True
        return False

    def remove(self, key: T) -> None:
        """Remove ``key``.

        Raises:
            ValueError: If ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise ValueError(f"{key!r} is not in {self.__class__.__name__}")

    def tolist(self) -> list[T]:
        """Return all elements as a list in ascending order."""
        out: list = []

        def dfs(node) -> None:
            if not node.child:
                out.extend(node.key)
                return
            dfs(node.child[0])
            for i, k in enumerate(node.key):
                out.append(k)
                dfs(node.child[i + 1])

        dfs(self._root)
        return out

    def get_max(self) -> Optional[T]:
        """Return the largest element, or None when the set is empty."""
        node = self._root
        while node.child:
            node = node.child[-1]
        return node.key[-1] if node.key else None

    def get_min(self) -> Optional[T]:
        """Return the smallest element, or None when the set is empty."""
        node = self._root
        while node.child:
            node = node.child[0]
        return node.key[0] if node.key else None

    def debug(self) -> None:
        """Print the tree: each internal node with its children, then the keys per level."""
        levels: list = []  # levels[d] collects the key lists of all nodes at depth d
        dq: Deque[tuple["BTreeSet._Node", int]] = deque([(self._root, 0)])
        while dq:
            node, d = dq.popleft()
            if d == len(levels):
                # Grow on demand so a tree of any height can be printed.
                levels.append([])
            levels[d].append(node.key)
            if node.child:
                print(node, "child=", node.child)
                for e in node.child:
                    dq.append((e, d + 1))
        for level in levels:
            for e in level:
                print(e, end=" ")
            print()

    def pop_max(self) -> T:
        """Remove and return the largest element (asserts the set is non-empty)."""
        res = self.get_max()
        assert (
            res is not None
        ), f"IndexError: pop_max from empty {self.__class__.__name__}."
        self.discard(res)
        return res

    def pop_min(self) -> T:
        """Remove and return the smallest element (asserts the set is non-empty)."""
        res = self.get_min()
        assert (
            res is not None
        ), f"IndexError: pop_min from empty {self.__class__.__name__}."
        self.discard(res)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or None if there is none."""
        res, node = None, self._root
        while node.key:
            i = bisect_left(node.key, key)
            if i < node.len_key():
                if node.key[i] == key:
                    return node.key[i]
                # Best candidate so far; a deeper node may hold a tighter one.
                res = node.key[i]
            if not node.child:
                break
            node = node.child[i]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or None if there is none."""
        res, node = None, self._root
        while node.key:
            i = bisect_right(node.key, key)
            if i < node.len_key():
                res = node.key[i]
            if not node.child:
                break
            node = node.child[i]
        return res

    def le(self, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or None if there is none."""
        res, node = None, self._root
        while node.key:
            i = bisect_left(node.key, key)
            if i < node.len_key() and node.key[i] == key:
                return node.key[i]
            if i - 1 >= 0:
                res = node.key[i - 1]
            if not node.child:
                break
            node = node.child[i]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or None if there is none."""
        res, node = None, self._root
        while node.key:
            i = bisect_left(node.key, key)
            if i - 1 >= 0:
                res = node.key[i - 1]
            if not node.child:
                break
            node = node.child[i]
        return res

    def clear(self) -> None:
        """Remove every element."""
        self._root = BTreeSet._Node()
        # Keep the cached size in step with the (now empty) tree so that
        # len() and bool() report the cleared state.
        self._len = 0

    def __iter__(self):
        # The cursor lives on the set itself; see the class-level note.
        self._iter_val = self.get_min()
        return self

    def __next__(self):
        if self._iter_val is None:
            raise StopIteration
        p = self._iter_val
        # Step to the successor by searching the tree again from the root.
        self._iter_val = self.gt(self._iter_val)
        return p

    def __bool__(self):
        return self._len > 0

    def __len__(self):
        return self._len

    def __str__(self):
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        return f"{self.__class__.__name__}({self.tolist()})"