splay_tree_dict

ソースコード

from titan_pylib.data_structures.splay_tree.splay_tree_dict import SplayTreeDict

view on github

展開済みコード

  1# from titan_pylib.data_structures.splay_tree.splay_tree_dict import SplayTreeDict
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9from array import array
 10from typing import Generic, Iterator, TypeVar, Any
 11
 12T = TypeVar("T", bound=SupportsLessThan)
 13
 14
 15class SplayTreeDict(Generic[T]):
 16
 17    def __init__(self, e: T, default: Any = 0, reserve: int = 1) -> None:
 18        # e: keyとして使わない値
 19        # default: valのdefault値
 20        if reserve < 1:
 21            reserve = 1
 22        self._keys: list[T] = [e] * reserve
 23        self._vals: list[Any] = [0] * reserve
 24        self._child = array("I", bytes(8 * reserve))
 25        self._end: int = 1
 26        self._root: int = 0
 27        self._len: int = 0
 28        self._default: Any = default
 29        self._e: T = e
 30
 31    def reserve(self, n: int) -> None:
 32        assert n >= 0, "ValueError"
 33        self._keys += [self._e] * n
 34        self._vals += [0] * n
 35        self._child += array("I", bytes(8 * n))
 36
 37    def _make_node(self, key: T, val: Any) -> int:
 38        if self._end >= len(self._keys):
 39            self._keys.append(key)
 40            self._vals.append(val)
 41            self._child.append(0)
 42            self._child.append(0)
 43        else:
 44            self._keys[self._end] = key
 45            self._vals[self._end] = val
 46        self._end += 1
 47        return self._end - 1
 48
 49    def _set_search_splay(self, key: T) -> None:
 50        node = self._root
 51        keys, child = self._keys, self._child
 52        if (not node) or keys[node] == key:
 53            return
 54        left, right = 0, 0
 55        while keys[node] != key:
 56            d = key > keys[node]
 57            if not child[node << 1 | d]:
 58                break
 59            if (d and key > keys[child[node << 1 | 1]]) or (
 60                d ^ 1 and key < keys[child[node << 1]]
 61            ):
 62                new = child[node << 1 | d]
 63                child[node << 1 | d] = child[new << 1 | (d ^ 1)]
 64                child[new << 1 | (d ^ 1)] = node
 65                node = new
 66                if not child[node << 1 | d]:
 67                    break
 68            if d:
 69                child[left << 1 | 1] = node
 70                left = node
 71            else:
 72                child[right << 1] = node
 73                right = node
 74            node = child[node << 1 | d]
 75        child[right << 1] = child[node << 1 | 1]
 76        child[left << 1 | 1] = child[node << 1]
 77        child[node << 1] = child[1]
 78        child[node << 1 | 1] = child[0]
 79        self._root = node
 80
 81    def _get_min_splay(self, node: int) -> int:
 82        child = self._child
 83        if (not node) or (not child[node << 1]):
 84            return node
 85        right = 0
 86        while child[node << 1]:
 87            new = child[node << 1]
 88            child[node << 1] = child[new << 1 | 1]
 89            child[new << 1 | 1] = node
 90            if not child[new << 1]:
 91                break
 92            child[right << 1] = new
 93            right = new
 94            node = child[new << 1]
 95        child[right << 1] = child[node << 1 | 1]
 96        child[1] = child[node << 1]
 97        child[node << 1] = child[1]
 98        child[node << 1 | 1] = child[0]
 99        return node
100
101    def __setitem__(self, key: T, val: Any):
102        if not self._root:
103            self._root = self._make_node(key, val)
104            self._len += 1
105            return
106        self._set_search_splay(key)
107        if self._keys[self._root] == key:
108            self._vals[self._root] = val
109            return
110        node = self._make_node(key, val)
111        d = self._keys[self._root] < key
112        self._child[node << 1 | (d ^ 1)] = self._root
113        self._child[node << 1 | d] = self._child[self._root << 1 | d]
114        self._child[self._root << 1 | d] = 0
115        self._root = node
116        self._len += 1
117
118    def __delitem__(self, key: T) -> None:
119        if self._root == 0:
120            return
121        self._set_search_splay(key)
122        if self._keys[self._root] != key:
123            return
124        if self._child[self._root << 1] == 0:
125            self._root = self._child[self._root << 1 | 1]
126        elif self._child[self._root << 1 | 1] == 0:
127            self._root = self._child[self._root << 1]
128        else:
129            node = self._get_min_splay(self._child[self._root << 1 | 1])
130            self._child[node << 1] = self._child[self._root << 1]
131            self._root = node
132        self._len -= 1
133
134    def tolist(self) -> list[tuple[T, Any]]:
135        node = self._root
136        child, keys, vals = self._child, self._keys, self._vals
137        stack, res = [], []
138        while stack or node:
139            if node:
140                stack.append(node)
141                node = child[node << 1]
142            else:
143                node = stack.pop()
144                res.append((keys[node], vals[node]))
145                node = child[node << 1 | 1]
146        return res
147
148    def keys(self) -> Iterator[T]:
149        node = self._root
150        child, keys = self._child, self._keys
151        stack = []
152        while stack or node:
153            if node:
154                stack.append(node)
155                node = child[node << 1]
156            else:
157                node = stack.pop()
158                yield keys[node]
159                node = child[node << 1 | 1]
160
161    def vals(self) -> Iterator[Any]:
162        node = self._root
163        child, vals = self._child, self._vals
164        stack = []
165        while stack or node:
166            if node:
167                stack.append(node)
168                node = child[node << 1]
169            else:
170                node = stack.pop()
171                yield vals[node]
172                node = child[node << 1 | 1]
173
174    def items(self) -> Iterator[tuple[T, Any]]:
175        node = self._root
176        child, keys, vals = self._child, self._keys, self._vals
177        stack = []
178        while stack or node:
179            if node:
180                stack.append(node)
181                node = child[node << 1]
182            else:
183                node = stack.pop()
184                yield (keys[node], vals[node])
185                node = child[node << 1 | 1]
186
187    def __getitem__(self, key: T) -> Any:
188        self._set_search_splay(key)
189        if self._root == 0 or self._keys[self._root] != key:
190            return self._default
191        return self._vals[self._root]
192
193    def __contains__(self, key: T):
194        self._set_search_splay(key)
195        return self._keys[self._root] == key
196
197    def __len__(self):
198        return self._len
199
200    def __bool__(self):
201        return self._root > 0
202
203    def __str__(self):
204        return "{" + ", ".join(map(str, self.tolist())) + "}"
205
206    def __repr__(self):
207        return f"SplayTreeDict({self})"

仕様

class SplayTreeDict(e: T, default: Any = 0, reserve: int = 1)[source]

Bases: Generic[T]

items() -> Iterator[tuple[T, Any]] [source]
keys() -> Iterator[T] [source]
reserve(n: int) -> None [source]
tolist() -> list[tuple[T, Any]] [source]
vals() -> Iterator[Any] [source]