b_tree_set

ソースコード

from titan_pylib.data_structures.b_tree.b_tree_set import BTreeSet

view on github

展開済みコード

  1# from titan_pylib.data_structures.b_tree.b_tree_set import BTreeSet
  2# from titan_pylib.my_class.ordered_set_interface import OrderedSetInterface
  3# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  4from typing import Protocol
  5
  6
  7class SupportsLessThan(Protocol):
  8
  9    def __lt__(self, other) -> bool: ...
 10from abc import ABC, abstractmethod
 11from typing import Iterable, Optional, Iterator, TypeVar, Generic
 12
# Element type of the ordered containers below; bound to SupportsLessThan.
T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class OrderedSetInterface(ABC, Generic[T]):
 17
 18    @abstractmethod
 19    def __init__(self, a: Iterable[T]) -> None:
 20        raise NotImplementedError
 21
 22    @abstractmethod
 23    def add(self, key: T) -> bool:
 24        raise NotImplementedError
 25
 26    @abstractmethod
 27    def discard(self, key: T) -> bool:
 28        raise NotImplementedError
 29
 30    @abstractmethod
 31    def remove(self, key: T) -> None:
 32        raise NotImplementedError
 33
 34    @abstractmethod
 35    def le(self, key: T) -> Optional[T]:
 36        raise NotImplementedError
 37
 38    @abstractmethod
 39    def lt(self, key: T) -> Optional[T]:
 40        raise NotImplementedError
 41
 42    @abstractmethod
 43    def ge(self, key: T) -> Optional[T]:
 44        raise NotImplementedError
 45
 46    @abstractmethod
 47    def gt(self, key: T) -> Optional[T]:
 48        raise NotImplementedError
 49
 50    @abstractmethod
 51    def get_max(self) -> Optional[T]:
 52        raise NotImplementedError
 53
 54    @abstractmethod
 55    def get_min(self) -> Optional[T]:
 56        raise NotImplementedError
 57
 58    @abstractmethod
 59    def pop_max(self) -> T:
 60        raise NotImplementedError
 61
 62    @abstractmethod
 63    def pop_min(self) -> T:
 64        raise NotImplementedError
 65
 66    @abstractmethod
 67    def clear(self) -> None:
 68        raise NotImplementedError
 69
 70    @abstractmethod
 71    def tolist(self) -> list[T]:
 72        raise NotImplementedError
 73
 74    @abstractmethod
 75    def __iter__(self) -> Iterator:
 76        raise NotImplementedError
 77
 78    @abstractmethod
 79    def __next__(self) -> T:
 80        raise NotImplementedError
 81
 82    @abstractmethod
 83    def __contains__(self, key: T) -> bool:
 84        raise NotImplementedError
 85
 86    @abstractmethod
 87    def __len__(self) -> int:
 88        raise NotImplementedError
 89
 90    @abstractmethod
 91    def __bool__(self) -> bool:
 92        raise NotImplementedError
 93
 94    @abstractmethod
 95    def __str__(self) -> str:
 96        raise NotImplementedError
 97
 98    @abstractmethod
 99    def __repr__(self) -> str:
100        raise NotImplementedError
101# from titan_pylib.my_class.supports_less_than import SupportsLessThan
102from collections import deque
103from bisect import bisect_left, bisect_right, insort
104from typing import Deque, Generic, TypeVar, Optional, Iterable
105
# Same TypeVar declaration as earlier in this listing, repeated here
# (presumably because several modules were inlined into one file).
T = TypeVar("T", bound=SupportsLessThan)
107
108
class BTreeSet(OrderedSetInterface, Generic[T]):
    """Ordered set of unique keys stored in a B-tree.

    Each node keeps its keys in a sorted list; an internal node also keeps one
    more child than it has keys. A node is split as soon as it holds more than
    ``_m`` keys. During deletion, a child holding exactly ``_m // 2`` keys is
    topped up (by borrowing from a sibling or merging with one) before the
    search descends into it.
    """

    class _Node:
        """One B-tree node: a sorted key list plus, for internal nodes, children."""

        def __init__(self):
            # Keys stored in this node, in ascending order.
            self.key: list = []
            # Child nodes; stays empty for a leaf.
            self.child: list["BTreeSet._Node"] = []

        def is_leaf(self) -> bool:
            """Return True when the node has no children."""
            return not self.child

        def split(self, i: int) -> "BTreeSet._Node":
            """Keep ``key[:i]`` / ``child[:i + 1]`` here and return a new node with the rest."""
            right = BTreeSet._Node()
            self.key, right.key = self.key[:i], self.key[i:]
            self.child, right.child = self.child[: i + 1], self.child[i + 1 :]
            return right

        def insert_key(self, i: int, key: T) -> None:
            """Insert ``key`` at position ``i`` of the key list."""
            self.key.insert(i, key)

        def insert_child(self, i: int, node: "BTreeSet._Node") -> None:
            """Insert ``node`` at position ``i`` of the child list."""
            self.child.insert(i, node)

        def append_key(self, key: T) -> None:
            """Append ``key`` to the end of the key list."""
            self.key.append(key)

        def append_child(self, node: "BTreeSet._Node") -> None:
            """Append ``node`` to the end of the child list."""
            self.child.append(node)

        def pop_key(self, i: int = -1) -> T:
            """Remove and return the key at position ``i`` (the last key by default)."""
            return self.key.pop(i)

        def len_key(self) -> int:
            """Return the number of keys stored in this node."""
            return len(self.key)

        def insort_key(self, key: T) -> None:
            """Insert ``key`` into the key list at its sorted position."""
            insort(self.key, key)

        def pop_child(self, i: int = -1) -> "BTreeSet._Node":
            """Remove and return the child at position ``i`` (the last child by default)."""
            return self.child.pop(i)

        def extend_key(self, keys: list[T]) -> None:
            """Append all of ``keys`` to the key list."""
            self.key += keys

        def extend_child(self, children: list["BTreeSet._Node"]) -> None:
            """Append all of ``children`` to the child list."""
            self.child += children

        def __str__(self):
            """Return the string form of the key list."""
            return str(self.key)

        __repr__ = __str__

    def __init__(self, a: Iterable[T] = []):
        """Create the set and insert every element of ``a``.

        Args:
            a: Initial elements; duplicates are ignored. The default list is
                only iterated, never mutated.
        """
        self._m: int = 1000  # a node is split once it holds more than _m keys
        self._root: "BTreeSet._Node" = BTreeSet._Node()
        self._len: int = 0
        self._build(a)

    def _build(self, a: Iterable[T]):
        """Insert the elements of ``a`` one at a time."""
        for e in a:
            self.add(e)

    def _is_over(self, node: "BTreeSet._Node") -> bool:
        """Return True when ``node`` holds more than ``_m`` keys and must be split."""
        return node.len_key() > self._m

    def add(self, key: T) -> bool:
        """Insert ``key`` into the set.

        Returns:
            True if ``key`` was inserted, False if it was already present.
        """
        node = self._root
        stack = []  # internal nodes visited on the way down, root first
        while True:
            i = bisect_left(node.key, key)
            if i < node.len_key() and node.key[i] == key:
                return False
            if i >= len(node.child):
                break
            stack.append(node)
            node = node.child[i]
        self._len += 1
        # ``i`` is still the insertion point found for this leaf, so no second
        # binary search is needed.
        node.insert_key(i, key)
        # Split overfull nodes bottom-up, pushing the middle key into the parent.
        while stack:
            if not self._is_over(node):
                break
            pnode = stack.pop()
            i = node.len_key() // 2
            center = node.pop_key(i)
            right = node.split(i)
            indx = bisect_left(pnode.key, center)
            pnode.insert_key(indx, center)
            pnode.insert_child(indx + 1, right)
            node = pnode
        # An overfull node without a parent is the root: grow the tree by one level.
        if self._is_over(node):
            pnode = BTreeSet._Node()
            i = node.len_key() // 2
            center = node.pop_key(i)
            right = node.split(i)
            pnode.append_key(center)
            pnode.append_child(node)
            pnode.append_child(right)
            self._root = pnode
        return True

    def __contains__(self, key: T) -> bool:
        """Return True if ``key`` is in the set."""
        node = self._root
        while True:
            i = bisect_left(node.key, key)
            if i < node.len_key() and node.key[i] == key:
                return True
            if node.is_leaf():
                break
            node = node.child[i]
        return False

    def _discard_right(self, node: "BTreeSet._Node") -> T:
        """Remove and return the last key of the rightmost leaf below ``node``.

        Before stepping into a rightmost child that holds exactly ``_m // 2``
        keys, the child either receives a key from its left sibling (rotated
        through the parent's last key) or is merged with that sibling.
        """
        while not node.is_leaf():
            if node.child[-1].len_key() == self._m // 2:
                if node.child[-2].len_key() > self._m // 2:
                    # Rotate right: parent separator moves down, sibling's last key moves up.
                    cnode = node.child[-2]
                    node.child[-1].insert_key(0, node.key[-1])
                    node.key[-1] = cnode.pop_key()
                    if cnode.child:
                        node.child[-1].insert_child(0, cnode.pop_child())
                    node = node.child[-1]
                    continue
                cnode = self._merge(node, node.len_key() - 1)
                if node is self._root and not node.key:
                    self._root = cnode
                node = cnode
                continue
            node = node.child[-1]
        return node.pop_key()

    def _discard_left(self, node: "BTreeSet._Node") -> T:
        """Remove and return the first key of the leftmost leaf below ``node``.

        Mirror image of ``_discard_right``: a leftmost child holding exactly
        ``_m // 2`` keys first receives a key from its right sibling or is
        merged with it.
        """
        while not node.is_leaf():
            if node.child[0].len_key() == self._m // 2:
                if node.child[1].len_key() > self._m // 2:
                    # Rotate left: parent separator moves down, sibling's first key moves up.
                    cnode = node.child[1]
                    node.child[0].append_key(node.key[0])
                    node.key[0] = cnode.pop_key(0)
                    if cnode.child:
                        node.child[0].append_child(cnode.pop_child(0))
                    node = node.child[0]
                    continue
                cnode = self._merge(node, 0)
                if node is self._root and not node.key:
                    self._root = cnode
                node = cnode
                continue
            node = node.child[0]
        return node.pop_key(0)

    def _merge(self, node: "BTreeSet._Node", i: int) -> "BTreeSet._Node":
        """Fold key ``i`` of ``node`` and child ``i + 1`` into child ``i``.

        Returns:
            The merged child (the former ``node.child[i]``).
        """
        y = node.child[i]
        z = node.pop_child(i + 1)
        y.append_key(node.pop_key(i))
        y.extend_key(z.key)
        y.extend_child(z.child)
        return y

    def _merge_key(self, key: T, node: "BTreeSet._Node", i: int) -> None:
        """Delete ``key``, which is stored at ``node.key[i]`` of an internal node.

        The slot is overwritten with the last key below the left child or the
        first key below the right child when that child holds more than
        ``_m // 2`` keys. Otherwise both children are merged around ``key``
        and the deletion continues inside the merged node.
        """
        if node.child[i].len_key() > self._m // 2:
            node.key[i] = self._discard_right(node.child[i])
            return
        if node.child[i + 1].len_key() > self._m // 2:
            node.key[i] = self._discard_left(node.child[i + 1])
            return
        y = self._merge(node, i)
        self._discard(key, y)
        # The merge may have taken the root's only key; the merged node becomes the root.
        if node is self._root and not node.key:
            self._root = y

    def _discard(self, key: T, node: Optional["BTreeSet._Node"] = None) -> bool:
        """Delete ``key`` from the subtree rooted at ``node`` (the root by default).

        Does not touch ``_len``; the caller is responsible for the element count.

        Returns:
            True if ``key`` was found and removed, False otherwise.
        """
        if node is None:
            node = self._root
        if not node.key:
            return False
        while True:
            i = bisect_left(node.key, key)
            if node.is_leaf():
                if i < node.len_key() and node.key[i] == key:
                    node.pop_key(i)
                    return True
                return False
            if i < node.len_key() and node.key[i] == key:
                assert i + 1 < len(node.child)
                self._merge_key(key, node, i)
                return True
            # The next child is at the lower bound: top it up before descending.
            if node.child[i].len_key() == self._m // 2:
                if (
                    i + 1 < len(node.child)
                    and node.child[i + 1].len_key() > self._m // 2
                ):
                    # Borrow the first key of the right sibling through the separator.
                    cnode = node.child[i + 1]
                    node.child[i].append_key(node.key[i])
                    node.key[i] = cnode.pop_key(0)
                    if cnode.child:
                        node.child[i].append_child(cnode.pop_child(0))
                    node = node.child[i]
                    continue
                if i - 1 >= 0 and node.child[i - 1].len_key() > self._m // 2:
                    # Borrow the last key of the left sibling through the separator.
                    cnode = node.child[i - 1]
                    node.child[i].insert_key(0, node.key[i - 1])
                    node.key[i - 1] = cnode.pop_key()
                    if cnode.child:
                        node.child[i].insert_child(0, cnode.pop_child())
                    node = node.child[i]
                    continue
                # No sibling can spare a key: merge. The last child has no right
                # sibling, so it is merged with its left sibling instead.
                if i + 1 >= len(node.child):
                    i -= 1
                cnode = self._merge(node, i)
                if node is self._root and not node.key:
                    self._root = cnode
                node = cnode
                continue
            node = node.child[i]

    def discard(self, key: T) -> bool:
        """Remove ``key`` if present.

        Returns:
            True if ``key`` was removed, False if it was not in the set.
        """
        if self._discard(key):
            self._len -= 1
            return True
        return False

    def remove(self, key: T) -> None:
        """Remove ``key`` from the set.

        Raises:
            ValueError: If ``key`` is not in the set.
        """
        if self.discard(key):
            return
        raise ValueError(f"{key!r} is not in {self.__class__.__name__}")

    def tolist(self) -> list[T]:
        """Return all keys as a list in ascending order."""
        a = []

        def dfs(node):
            # In-order walk: child 0, then each key followed by the child to its right.
            if not node.child:
                a.extend(node.key)
                return
            dfs(node.child[0])
            for i in range(node.len_key()):
                a.append(node.key[i])
                dfs(node.child[i + 1])

        dfs(self._root)
        return a

    def get_max(self) -> Optional[T]:
        """Return the largest key, or None if the set is empty."""
        node = self._root
        while True:
            if not node.child:
                return node.key[-1] if node.key else None
            node = node.child[-1]

    def get_min(self) -> Optional[T]:
        """Return the smallest key, or None if the set is empty."""
        node = self._root
        while True:
            if not node.child:
                return node.key[0] if node.key else None
            node = node.child[0]

    def debug(self) -> None:
        """Print the tree for inspection.

        Prints every internal node with its children while walking the tree
        breadth-first, then prints the key lists of each depth on one line.
        """
        # One bucket per depth, created on demand so the tree height is not capped.
        dep: list[list] = []
        dq: Deque[tuple["BTreeSet._Node", int]] = deque([(self._root, 0)])
        while dq:
            node, d = dq.popleft()
            # Breadth-first order means the depth grows by at most one at a time.
            if d == len(dep):
                dep.append([])
            dep[d].append(node.key)
            if node.child:
                print(node, "child=", node.child)
            for e in node.child:
                if e:
                    dq.append((e, d + 1))
        for level in dep:
            for e in level:
                print(e, end="  ")
            print()

    def pop_max(self) -> T:
        """Remove and return the largest key.

        Raises:
            AssertionError: If the set is empty (only while assertions are enabled).
        """
        res = self.get_max()
        assert (
            res is not None
        ), f"IndexError: pop_max from empty {self.__class__.__name__}."
        self.discard(res)
        return res

    def pop_min(self) -> T:
        """Remove and return the smallest key.

        Raises:
            AssertionError: If the set is empty (only while assertions are enabled).
        """
        res = self.get_min()
        assert (
            res is not None
        ), f"IndexError: pop_min from empty {self.__class__.__name__}."
        self.discard(res)
        return res

    def ge(self, key: T) -> Optional[T]:
        """Return the smallest element that is >= ``key``, or None if there is none."""
        res, node = None, self._root
        while node.key:
            i = bisect_left(node.key, key)
            if i < node.len_key() and node.key[i] == key:
                return node.key[i]
            if i < node.len_key():
                # Best candidate so far; a deeper node may hold a closer one.
                res = node.key[i]
            if not node.child:
                break
            node = node.child[i]
        return res

    def gt(self, key: T) -> Optional[T]:
        """Return the smallest element that is > ``key``, or None if there is none."""
        res, node = None, self._root
        while node.key:
            i = bisect_right(node.key, key)
            if i < node.len_key():
                res = node.key[i]
            if not node.child:
                break
            node = node.child[i]
        return res

    def le(self, key: T) -> Optional[T]:
        """Return the largest element that is <= ``key``, or None if there is none."""
        res, node = None, self._root
        while node.key:
            i = bisect_left(node.key, key)
            if i < node.len_key() and node.key[i] == key:
                return node.key[i]
            if i - 1 >= 0:
                res = node.key[i - 1]
            if not node.child:
                break
            node = node.child[i]
        return res

    def lt(self, key: T) -> Optional[T]:
        """Return the largest element that is < ``key``, or None if there is none."""
        res, node = None, self._root
        while node.key:
            i = bisect_left(node.key, key)
            if i - 1 >= 0:
                res = node.key[i - 1]
            if not node.child:
                break
            node = node.child[i]
        return res

    def clear(self) -> None:
        """Remove every element from the set."""
        self._root = BTreeSet._Node()
        # Reset the count too, otherwise len() and bool() report the old size.
        self._len = 0

    def __iter__(self):
        """Start an ascending iteration and return ``self`` as the iterator.

        The cursor is stored on the set itself, so starting a second iteration
        over the same set restarts the first one.
        """
        self._iter_val = self.get_min()
        return self

    def __next__(self):
        """Return the current element and move the cursor to the next larger one."""
        if self._iter_val is None:
            raise StopIteration
        p = self._iter_val
        self._iter_val = self.gt(self._iter_val)
        return p

    def __bool__(self):
        """Return True when the set holds at least one element."""
        return self._len > 0

    def __len__(self):
        """Return the number of elements."""
        return self._len

    def __str__(self):
        """Return the elements in ascending order, formatted like ``{1, 2, 3}``."""
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self):
        """Return ``ClassName([...])`` with the elements in ascending order."""
        return f"{self.__class__.__name__}({self.tolist()})"

仕様

class BTreeSet(a: Iterable[T] = [])[source]

Bases: OrderedSetInterface, Generic[T]

add(key: T) → bool [source]
clear() → None [source]
debug() → None [source]
discard(key: T) → bool [source]
ge(key: T) → T | None [source]
get_max() → T | None [source]
get_min() → T | None [source]
gt(key: T) → T | None [source]
le(key: T) → T | None [source]
lt(key: T) → T | None [source]
pop_max() → T [source]
pop_min() → T [source]
remove(key: T) → None [source]
tolist() → list[T] [source]