1from titan_pylib.my_class.supports_less_than import SupportsLessThan
2from array import array
3from typing import Generic, Iterator, TypeVar, Any
4
# Key type of SplayTreeDict. Bounded by SupportsLessThan; note that the tree
# also compares keys with ">", "==" and "!=".
T = TypeVar("T", bound=SupportsLessThan)
6
7
[docs]
class SplayTreeDict(Generic[T]):
    """Ordered dictionary backed by a top-down splay tree.

    Nodes live in parallel arrays indexed by an integer id: ``_keys[i]`` and
    ``_vals[i]`` hold the entry, ``_child[2*i]`` / ``_child[2*i + 1]`` hold the
    left / right child ids. Id 0 is the null sentinel; its two child slots
    (``_child[0]`` and ``_child[1]``) are reused as scratch headers while
    splaying. Ids are handed out sequentially and never recycled, so deleted
    entries keep their slots.

    Keys are compared with ``<``, ``>``, ``==`` and ``!=``. Every lookup
    (``d[k]``, ``k in d``, assignment, deletion) restructures the tree, so do
    not look up or modify the dictionary while iterating over it.
    """

    def __init__(self, e: T, default: Any = 0, reserve: int = 1) -> None:
        """
        Args:
            e: A value that is never used as a key; it fills the sentinel
                slot 0 and any pre-allocated slots.
            default: Value returned by ``d[key]`` when ``key`` is absent.
            reserve: Initial number of node slots, including the sentinel
                slot 0 (values below 1 are treated as 1).
        """
        reserve = max(reserve, 1)
        self._keys: list[T] = [e] * reserve
        self._vals: list[Any] = [0] * reserve
        # Two child links (left, right) per node. Built from a zero item
        # rather than raw bytes so it does not depend on the itemsize of "I".
        self._child = array("I", [0]) * (2 * reserve)
        self._end: int = 1  # next unused node id (0 is the sentinel)
        self._root: int = 0  # 0 means the tree is empty
        self._len: int = 0
        self._default: Any = default
        self._e: T = e

    def reserve(self, n: int) -> None:
        """Pre-allocate ``n`` more node slots so later inserts need not grow the arrays."""
        assert n >= 0, "ValueError"
        self._keys += [self._e] * n
        self._vals += [0] * n
        self._child += array("I", [0]) * (2 * n)

    def _make_node(self, key: T, val: Any) -> int:
        """Store ``(key, val)`` in the next unused slot and return its id."""
        if self._end >= len(self._keys):
            self._keys.append(key)
            self._vals.append(val)
            self._child.append(0)
            self._child.append(0)
        else:
            # Pre-allocated slot: its child links are still zero because ids
            # are never reused.
            self._keys[self._end] = key
            self._vals[self._end] = val
        self._end += 1
        return self._end - 1

    def _set_search_splay(self, key: T) -> None:
        """Top-down splay for ``key``.

        Afterwards the root holds ``key`` if present; otherwise it holds the
        last node visited on the search path (a neighbour of ``key``).
        Nodes smaller than ``key`` are collected in a left tree (header slot
        ``_child[1]``), larger ones in a right tree (header slot ``_child[0]``).
        """
        node = self._root
        keys, child = self._keys, self._child
        if (not node) or keys[node] == key:
            return
        left, right = 0, 0  # last node linked into the left / right tree
        while keys[node] != key:
            d = key > keys[node]  # direction to descend: 1 = right, 0 = left
            if not child[node << 1 | d]:
                break
            if (d and key > keys[child[node << 1 | 1]]) or (
                d ^ 1 and key < keys[child[node << 1]]
            ):
                # Zig-zig: rotate the child in direction d above ``node``.
                new = child[node << 1 | d]
                child[node << 1 | d] = child[new << 1 | (d ^ 1)]
                child[new << 1 | (d ^ 1)] = node
                node = new
                if not child[node << 1 | d]:
                    break
            if d:
                # ``node`` < key: append it to the left tree.
                child[left << 1 | 1] = node
                left = node
            else:
                # ``node`` > key: append it to the right tree.
                child[right << 1] = node
                right = node
            node = child[node << 1 | d]
        # Reassemble: node's subtrees go to the side trees, which become its
        # new children.
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self._root = node

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns the id of the new subtree root, which has no left child
        (or 0 if ``node`` is 0). Only the right-tree header ``_child[0]`` is
        needed because the search only ever walks left.
        """
        child = self._child
        if (not node) or (not child[node << 1]):
            return node
        right = 0  # last node linked into the right tree (0 = header)
        while child[node << 1]:
            # Rotate right: the left child moves above ``node``.
            new = child[node << 1]
            child[node << 1] = child[new << 1 | 1]
            child[new << 1 | 1] = node
            # Continue from the rotated-up node; otherwise ``new`` (the
            # minimum when it has no left child) would be dropped.
            node = new
            if not child[node << 1]:
                break
            # Append ``node`` to the right tree and keep descending left.
            child[right << 1] = node
            right = node
            node = child[node << 1]
        # ``node`` is the minimum: its right subtree goes under the right
        # tree, and the right tree becomes its right child.
        child[right << 1] = child[node << 1 | 1]
        child[node << 1 | 1] = child[0]
        return node

    def __setitem__(self, key: T, val: Any) -> None:
        """Insert ``key`` or overwrite its value; the entry becomes the root."""
        if not self._root:
            self._root = self._make_node(key, val)
            self._len += 1
            return
        self._set_search_splay(key)
        if self._keys[self._root] == key:
            self._vals[self._root] = val
            return
        node = self._make_node(key, val)
        # Split the splayed tree around the old root: it goes on one side of
        # the new node, its subtree on the other side moves across.
        d = self._keys[self._root] < key
        self._child[node << 1 | (d ^ 1)] = self._root
        self._child[node << 1 | d] = self._child[self._root << 1 | d]
        self._child[self._root << 1 | d] = 0
        self._root = node
        self._len += 1

    def __delitem__(self, key: T) -> None:
        """Remove ``key``; does nothing if it is absent."""
        if self._root == 0:
            return
        self._set_search_splay(key)
        if self._keys[self._root] != key:
            return
        if self._child[self._root << 1] == 0:
            self._root = self._child[self._root << 1 | 1]
        elif self._child[self._root << 1 | 1] == 0:
            self._root = self._child[self._root << 1]
        else:
            # Bring the successor to the top of the right subtree; it has no
            # left child, so the old root's left subtree can hang there.
            node = self._get_min_splay(self._child[self._root << 1 | 1])
            self._child[node << 1] = self._child[self._root << 1]
            self._root = node
        self._len -= 1

    def _iter_nodes(self) -> Iterator[int]:
        """Yield node ids in ascending key order (iterative in-order walk)."""
        child = self._child
        node = self._root
        stack: list[int] = []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                yield node
                node = child[node << 1 | 1]

    def tolist(self) -> list[tuple[T, Any]]:
        """Return all ``(key, value)`` pairs in ascending key order."""
        return list(self.items())

    def keys(self) -> Iterator[T]:
        """Yield keys in ascending order."""
        keys = self._keys
        for node in self._iter_nodes():
            yield keys[node]

    def vals(self) -> Iterator[Any]:
        """Yield values in ascending key order."""
        vals = self._vals
        for node in self._iter_nodes():
            yield vals[node]

    def items(self) -> Iterator[tuple[T, Any]]:
        """Yield ``(key, value)`` pairs in ascending key order."""
        keys, vals = self._keys, self._vals
        for node in self._iter_nodes():
            yield (keys[node], vals[node])

    def __iter__(self) -> Iterator[T]:
        """Iterate over keys in ascending order, like ``dict``.

        Without this, ``iter()`` would fall back to calling ``self[0]``,
        ``self[1]``, ... which never raises and so never terminates.
        """
        return self.keys()

    def __getitem__(self, key: T) -> Any:
        """Return the value for ``key``, or ``default`` if absent (not inserted)."""
        self._set_search_splay(key)
        if self._root == 0 or self._keys[self._root] != key:
            return self._default
        return self._vals[self._root]

    def __contains__(self, key: T) -> bool:
        self._set_search_splay(key)
        # Check the root explicitly: slot 0 holds the sentinel ``e``, which
        # must not be reported as present in an empty tree.
        return self._root != 0 and self._keys[self._root] == key

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self._root > 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"SplayTreeDict({self})"