1# from titan_pylib.data_structures.splay_tree.splay_tree_dict import SplayTreeDict
2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
3from typing import Protocol
4
5
class SupportsLessThan(Protocol):
    """Structural type for values that can be ordered with the ``<`` operator."""

    def __lt__(self, other) -> bool:
        """Return whether ``self`` orders strictly before ``other``."""
9from array import array
10from typing import Generic, Iterator, TypeVar, Any
11
# Key type for SplayTreeDict: any value orderable via ``<`` (SupportsLessThan).
T = TypeVar("T", bound=SupportsLessThan)
13
14
class SplayTreeDict(Generic[T]):
    """Ordered key -> value map backed by a top-down splay tree.

    Nodes live in parallel arrays indexed by an integer id: ``_keys[i]``,
    ``_vals[i]`` and the child links ``_child[2*i]`` (left) and
    ``_child[2*i+1]`` (right).  Id 0 is a sentinel meaning "no node"; its two
    child slots are reused as scratch headers for the left/right trees built
    while splaying.  Deleted nodes are unlinked, but their slots are never
    recycled.

    Reading a missing key returns ``default`` instead of raising ``KeyError``,
    and deleting a missing key is a no-op.  Lookups (``d[k]``, ``k in d``)
    splay the tree, so they mutate its internal shape.
    """

    def __init__(self, e: T, default: Any = 0, reserve: int = 1) -> None:
        """Create an empty map.

        Args:
            e: A value that is never used as a key.  It fills unused key
                slots, including the sentinel slot 0.
            default: Value returned by ``__getitem__`` for missing keys.
            reserve: Number of node slots to preallocate, including the
                sentinel.  Values below 1 are raised to 1.
        """
        if reserve < 1:
            reserve = 1
        self._keys: list[T] = [e] * reserve
        self._vals: list[Any] = [0] * reserve
        # Two child slots per node.  Sizing by element count (not raw bytes)
        # keeps this independent of the platform's itemsize for "I".
        self._child = array("I", [0]) * (2 * reserve)
        self._end: int = 1  # next unused node id (0 is the sentinel)
        self._root: int = 0
        self._len: int = 0
        self._default: Any = default
        self._e: T = e

    def reserve(self, n: int) -> None:
        """Preallocate ``n`` more node slots so later inserts avoid appends."""
        assert n >= 0, "ValueError"
        self._keys += [self._e] * n
        self._vals += [0] * n
        self._child += array("I", [0]) * (2 * n)

    def _make_node(self, key: T, val: Any) -> int:
        """Allocate a fresh node holding ``key``/``val`` and return its id."""
        if self._end >= len(self._keys):
            self._keys.append(key)
            self._vals.append(val)
            self._child.append(0)
            self._child.append(0)
        else:
            # Preallocated slot.  Slots are never recycled, so its child
            # links are still 0.
            self._keys[self._end] = key
            self._vals[self._end] = val
        self._end += 1
        return self._end - 1

    def _set_search_splay(self, key: T) -> None:
        """Top-down splay that brings ``key`` (or its neighbour) to the root.

        If ``key`` is absent, the last node on its search path becomes the
        root; that node is ``key``'s in-order predecessor or successor.
        Nodes passed on the way down are collected into a left tree (< key)
        and a right tree (> key).  Their headers are the sentinel's slots
        ``_child[1]`` and ``_child[0]``, and both trees are re-attached under
        the new root at the end.
        """
        node = self._root
        keys, child = self._keys, self._child
        if (not node) or keys[node] == key:
            return
        # Last node appended to the left / right tree (0 = header only).
        left, right = 0, 0
        while keys[node] != key:
            d = key > keys[node]  # direction to descend: 0 = left, 1 = right
            if not child[node << 1 | d]:
                break
            # Zig-zig: key lies beyond the child in the same direction, so
            # rotate that child above ``node`` before linking.
            if (d and key > keys[child[node << 1 | 1]]) or (
                d ^ 1 and key < keys[child[node << 1]]
            ):
                new = child[node << 1 | d]
                child[node << 1 | d] = child[new << 1 | (d ^ 1)]
                child[new << 1 | (d ^ 1)] = node
                node = new
                if not child[node << 1 | d]:
                    break
            # Link ``node`` (with its subtree on the far side) into L or R.
            if d:
                child[left << 1 | 1] = node
                left = node
            else:
                child[right << 1] = node
                right = node
            node = child[node << 1 | d]
        # Reassemble.  ``node``'s subtrees go to the inner ends of L and R,
        # then L and R (headed by the sentinel slots) become its children.
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self._root = node

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns the id of the new subtree root, which is the minimum and has
        no left child.  ``self._root`` is not touched; the caller must
        re-link the returned node.  If ``node`` is 0 or already has no left
        child, it is returned unchanged.
        """
        child = self._child
        if (not node) or (not child[node << 1]):
            return node
        right = 0  # last node linked into the right tree (header: child[0])
        while child[node << 1]:
            # Rotate right: the left child ``new`` moves above ``node``.
            new = child[node << 1]
            child[node << 1] = child[new << 1 | 1]
            child[new << 1 | 1] = node
            # Continue from the rotated-up node.  Without this, the minimum
            # would be lost on the break below.
            node = new
            if not child[node << 1]:
                break  # ``node`` is the minimum
            # Link ``node`` into the right tree and keep walking left.
            child[right << 1] = node
            right = node
            node = child[node << 1]
        # ``node`` is the minimum and its left link is already 0.  Give its
        # right subtree to the deepest right-tree node, then hang the right
        # tree beneath it.
        child[right << 1] = child[node << 1 | 1]
        child[node << 1 | 1] = child[0]
        return node

    def __setitem__(self, key: T, val: Any) -> None:
        """Insert ``key`` or overwrite its value; ``key`` ends up at the root."""
        if not self._root:
            self._root = self._make_node(key, val)
            self._len += 1
            return
        self._set_search_splay(key)
        root = self._root
        if self._keys[root] == key:
            self._vals[root] = val
            return
        node = self._make_node(key, val)
        child = self._child
        # The root is now key's predecessor (d = 1) or successor (d = 0).
        # The old root goes on the opposite side of the new node, and its
        # d-side subtree (keys beyond ``key``) moves over to the new node.
        d = self._keys[root] < key
        child[node << 1 | (d ^ 1)] = root
        child[node << 1 | d] = child[root << 1 | d]
        child[root << 1 | d] = 0
        self._root = node
        self._len += 1

    def __delitem__(self, key: T) -> None:
        """Remove ``key`` if present; missing keys are silently ignored."""
        if self._root == 0:
            return
        self._set_search_splay(key)
        root = self._root
        if self._keys[root] != key:
            return
        child = self._child
        if child[root << 1] == 0:
            self._root = child[root << 1 | 1]
        elif child[root << 1 | 1] == 0:
            self._root = child[root << 1]
        else:
            # Join the two subtrees.  After splaying, the right subtree's
            # minimum has no left child, so the old left subtree hangs there.
            node = self._get_min_splay(child[root << 1 | 1])
            child[node << 1] = child[root << 1]
            self._root = node
        self._len -= 1

    def _inorder(self) -> Iterator[int]:
        """Yield node ids in ascending key order (iterative in-order walk)."""
        child = self._child
        node = self._root
        stack: list[int] = []
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                yield node
                node = child[node << 1 | 1]

    def tolist(self) -> list[tuple[T, Any]]:
        """Return all ``(key, value)`` pairs in ascending key order."""
        return list(self.items())

    def keys(self) -> Iterator[T]:
        """Yield keys in ascending order."""
        keys = self._keys
        for node in self._inorder():
            yield keys[node]

    def vals(self) -> Iterator[Any]:
        """Yield values in ascending key order."""
        vals = self._vals
        for node in self._inorder():
            yield vals[node]

    def items(self) -> Iterator[tuple[T, Any]]:
        """Yield ``(key, value)`` pairs in ascending key order."""
        keys, vals = self._keys, self._vals
        for node in self._inorder():
            yield (keys[node], vals[node])

    def __iter__(self) -> Iterator[T]:
        # Without this, ``for k in d`` falls back to the legacy
        # ``__getitem__(0, 1, ...)`` protocol and never terminates, because
        # ``__getitem__`` returns ``default`` instead of raising IndexError.
        return self.keys()

    def __getitem__(self, key: T) -> Any:
        """Return the value for ``key``, or ``default`` if it is absent."""
        self._set_search_splay(key)
        if self._root == 0 or self._keys[self._root] != key:
            return self._default
        return self._vals[self._root]

    def __contains__(self, key: T) -> bool:
        self._set_search_splay(key)
        # Guard the empty tree so the sentinel key ``e`` never reports present.
        return self._root != 0 and self._keys[self._root] == key

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self._root > 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"SplayTreeDict({self})"