Source code for titan_pylib.data_structures.splay_tree.splay_tree_dict

  1from titan_pylib.my_class.supports_less_than import SupportsLessThan
  2from array import array
  3from typing import Generic, Iterator, TypeVar, Any
  4
# Key type for SplayTreeDict: keys are ordered with <, > and compared with ==.
T = TypeVar("T", bound=SupportsLessThan)
  6
  7
class SplayTreeDict(Generic[T]):
    """Ordered mapping from keys to values, stored as a top-down splay tree.

    Storage is array based. Node ``i`` holds ``_keys[i]`` / ``_vals[i]``, and
    its left / right child ids are ``_child[2 * i]`` / ``_child[2 * i + 1]``
    (``0`` means "no child"). Node ``0`` is a sentinel, never an entry; its
    two child slots serve as scratch list heads while splaying. Node ids are
    allocated sequentially and are not reused after deletion.

    Every lookup (``d[k]``, ``k in d``), insertion and deletion splays the
    accessed key -- or the last node on its search path -- to the root, so
    even reads restructure the tree.

    Differences from ``dict``: ``d[k]`` returns ``default`` for a missing key
    (without inserting it), and ``del d[k]`` on a missing key does nothing.
    """

    def __init__(self, e: T, default: Any = 0, reserve: int = 1) -> None:
        """Create an empty map.

        Args:
            e: Placeholder key that fills the sentinel slot 0 and unused
                pre-allocated slots (intended to be a value that is not used
                as a real key).
            default: Value returned by ``self[key]`` when ``key`` is absent.
            reserve: Number of node slots to pre-allocate, including the
                sentinel; values below 1 are raised to 1.
        """
        if reserve < 1:
            reserve = 1
        self._keys: list[T] = [e] * reserve
        self._vals: list[Any] = [0] * reserve
        # Two child ids (left, right) per node; all 0 = "no child". Built by
        # element repetition so the length does not depend on the C size of "I".
        self._child = array("I", [0]) * (2 * reserve)
        self._end: int = 1  # next unused node id (id 0 is the sentinel)
        self._root: int = 0  # 0 means the tree is empty
        self._len: int = 0
        self._default: Any = default
        self._e: T = e

    def reserve(self, n: int) -> None:
        """Pre-allocate slots for ``n`` more nodes.

        Raises:
            ValueError: If ``n`` is negative.
        """
        # An explicit raise instead of ``assert`` so the check survives ``python -O``.
        if n < 0:
            raise ValueError(f"reserve() needs a non-negative count, got {n}")
        self._keys += [self._e] * n
        self._vals += [0] * n
        self._child += array("I", [0]) * (2 * n)

    def _make_node(self, key: T, val: Any) -> int:
        """Store ``(key, val)`` in the next free slot and return its node id.

        The new node's child slots are 0: appended slots start zeroed, and
        pre-reserved slots at or beyond ``_end`` are never written before use.
        """
        if self._end >= len(self._keys):
            self._keys.append(key)
            self._vals.append(val)
            self._child.append(0)
            self._child.append(0)
        else:
            self._keys[self._end] = key
            self._vals[self._end] = val
        self._end += 1
        return self._end - 1

    def _set_search_splay(self, key: T) -> None:
        """Top-down splay toward ``key``.

        Afterwards the root is ``key``'s node if present; otherwise it is the
        last node on the search path (``key``'s in-order predecessor or
        successor). No-op on an empty tree.
        """
        node = self._root
        keys, child = self._keys, self._child
        if (not node) or keys[node] == key:
            return
        # left: largest node linked into the "less than key" tree so far;
        # right: smallest node linked into the "greater than key" tree.
        # Starting at 0 makes the first links land in the sentinel's slots
        # child[1] / child[0], which then act as the heads of those trees.
        left, right = 0, 0
        while keys[node] != key:
            d = key > keys[node]  # bool used as 0/1: True -> descend right
            if not child[node << 1 | d]:
                break
            if (d and key > keys[child[node << 1 | 1]]) or (
                d ^ 1 and key < keys[child[node << 1]]
            ):
                # zig-zig: rotate the child up before linking.
                new = child[node << 1 | d]
                child[node << 1 | d] = child[new << 1 | (d ^ 1)]
                child[new << 1 | (d ^ 1)] = node
                node = new
                if not child[node << 1 | d]:
                    break
            if d:
                child[left << 1 | 1] = node
                left = node
            else:
                child[right << 1] = node
                right = node
            node = child[node << 1 | d]
        # Reassemble: node's subtrees hang off the inner ends of the side
        # trees, and the side trees become node's children.
        child[right << 1] = child[node << 1 | 1]
        child[left << 1 | 1] = child[node << 1]
        child[node << 1] = child[1]
        child[node << 1 | 1] = child[0]
        self._root = node

    def _get_min_splay(self, node: int) -> int:
        """Splay the minimum of the subtree rooted at ``node`` to its top.

        Returns the id of that minimum, which becomes the subtree's root and
        has no left child. Returns ``node`` unchanged when it is 0 or already
        has no left child.
        """
        child = self._child
        if (not node) or (not child[node << 1]):
            return node
        right = 0  # smallest node linked into the "greater" chain so far
        while child[node << 1]:
            # Rotate node's left child up and continue from it; it is the
            # minimum as soon as it has no left child of its own.
            new = child[node << 1]
            child[node << 1] = child[new << 1 | 1]
            child[new << 1 | 1] = node
            node = new
            if not child[node << 1]:
                break
            child[right << 1] = node
            right = node
            node = child[node << 1]
        # node is the minimum: its right subtree goes under the smallest
        # linked node, and the linked chain (headed at child[0]) becomes its
        # right child. When nothing was linked, child[0] round-trips the
        # original right subtree unchanged.
        child[right << 1] = child[node << 1 | 1]
        child[node << 1 | 1] = child[0]
        return node

    def __setitem__(self, key: T, val: Any) -> None:
        """Insert ``key`` or overwrite its value; ``key`` ends up at the root."""
        if not self._root:
            self._root = self._make_node(key, val)
            self._len += 1
            return
        self._set_search_splay(key)
        if self._keys[self._root] == key:
            self._vals[self._root] = val
            return
        node = self._make_node(key, val)
        # After the splay, the root is key's predecessor or successor, so the
        # old root and its subtree on key's far side split cleanly around the
        # new node.
        d = self._keys[self._root] < key
        self._child[node << 1 | (d ^ 1)] = self._root
        self._child[node << 1 | d] = self._child[self._root << 1 | d]
        self._child[self._root << 1 | d] = 0
        self._root = node
        self._len += 1

    def __delitem__(self, key: T) -> None:
        """Remove ``key`` if present; a missing key is silently ignored."""
        if self._root == 0:
            return
        self._set_search_splay(key)
        if self._keys[self._root] != key:
            return
        if self._child[self._root << 1] == 0:
            self._root = self._child[self._root << 1 | 1]
        elif self._child[self._root << 1 | 1] == 0:
            self._root = self._child[self._root << 1]
        else:
            # The successor becomes the root of the right subtree with no
            # left child; the old left subtree is attached there.
            node = self._get_min_splay(self._child[self._root << 1 | 1])
            self._child[node << 1] = self._child[self._root << 1]
            self._root = node
        self._len -= 1

    def _inorder_nodes(self) -> list[int]:
        """Return the ids of all live nodes in ascending key order.

        Iterative, so deep trees cannot hit the recursion limit. The list is
        a snapshot and is unaffected by later splaying.
        """
        child = self._child
        order: list[int] = []
        stack: list[int] = []
        node = self._root
        while stack or node:
            if node:
                stack.append(node)
                node = child[node << 1]
            else:
                node = stack.pop()
                order.append(node)
                node = child[node << 1 | 1]
        return order

    def tolist(self) -> list[tuple[T, Any]]:
        """Return all ``(key, value)`` pairs in ascending key order."""
        keys, vals = self._keys, self._vals
        return [(keys[node], vals[node]) for node in self._inorder_nodes()]

    def keys(self) -> Iterator[T]:
        """Yield keys in ascending order.

        The order is taken from a snapshot of the tree when iteration starts,
        so lookups during iteration (which splay) do not disturb it.
        """
        keys = self._keys
        for node in self._inorder_nodes():
            yield keys[node]

    def vals(self) -> Iterator[Any]:
        """Yield values in ascending key order.

        The order is snapshotted like ``keys()``; each value is read as it is
        yielded.
        """
        vals = self._vals
        for node in self._inorder_nodes():
            yield vals[node]

    def items(self) -> Iterator[tuple[T, Any]]:
        """Yield ``(key, value)`` pairs in ascending key order.

        The order is snapshotted like ``keys()``.
        """
        keys, vals = self._keys, self._vals
        for node in self._inorder_nodes():
            yield (keys[node], vals[node])

    def __iter__(self) -> Iterator[T]:
        """Iterate over keys in ascending order, like ``dict``.

        Without this, ``iter()`` would fall back to calling ``self[0]``,
        ``self[1]``, ... which never raises here and would loop forever.
        """
        return self.keys()

    def __getitem__(self, key: T) -> Any:
        """Return the value for ``key``, or ``default`` if it is absent.

        A missing key is not inserted.
        """
        self._set_search_splay(key)
        if self._root == 0 or self._keys[self._root] != key:
            return self._default
        return self._vals[self._root]

    def __contains__(self, key: T) -> bool:
        """Return True if ``key`` is stored (splays toward ``key``)."""
        self._set_search_splay(key)
        # Guard the empty tree: root 0 is the sentinel, whose key is ``e``.
        return self._root != 0 and self._keys[self._root] == key

    def __len__(self) -> int:
        return self._len

    def __bool__(self) -> bool:
        return self._root > 0

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return f"SplayTreeDict({self})"