Source code for titan_pylib.data_structures.dict.hash_dict

  1import random
  2from typing import Iterator, Any, Final
  3
  4_titan_pylib_HashDict_K: Final[int] = 0x517CC1B727220A95
  5
  6
[docs] 7class HashDict: 8 """ハッシュテーブルです。 9 組み込み辞書の ``dict`` よりやや遅いです。 10 """ 11 12 def __init__(self, e: int = -1, default: Any = 0, reserve: int = -1) -> None: 13 """ 14 Args: 15 e (int, optional): ``int`` 型で ``key`` として使用しない値です。 16 ``key`` を ``int`` 型以外のもので指定したいときは ``_hash(key) -> int`` 関数をいじってください。 17 default (Any, optional): 存在しないキーにアクセスしたときの値です。 18 """ 19 # e: keyとして使わない値 20 # default: valのdefault値 21 self._keys: list[int] = [e] 22 self._vals: list[Any] = [default] 23 self._msk: int = 0 24 self._xor: int = random.getrandbits(1) 25 if reserve > 0: 26 self._keys: list[int] = [e] * (1 << (reserve.bit_length())) 27 self._vals: list[Any] = [default] * (1 << (reserve.bit_length())) 28 self._msk = (1 << (len(self._keys) - 1).bit_length()) - 1 29 self._xor = random.getrandbits((len(self._keys) - 1).bit_length()) 30 self._e: int = e 31 self._len: int = 0 32 self._default: Any = default 33 34 def _rebuild(self) -> None: 35 old_keys, old_vals, _e = self._keys, self._vals, self._e 36 self._keys = [_e] * (2 * len(old_keys)) 37 self._vals = [self._default] * len(self._keys) 38 self._len = 0 39 self._msk = (1 << (len(self._keys) - 1).bit_length()) - 1 40 self._xor = random.getrandbits((len(self._keys) - 1).bit_length()) 41 for i in range(len(old_keys)): 42 if old_keys[i] != _e: 43 self.set(old_keys[i], old_vals[i]) 44 45 def _hash(self, key: int) -> int: 46 return ( 47 ((((key >> 32) & self._msk) ^ (key & self._msk) ^ self._xor)) 48 * (_titan_pylib_HashDict_K & self._msk) 49 ) & self._msk 50
[docs] 51 def get(self, key: int, default: Any = None) -> Any: 52 """ 53 キーが ``key`` の値を返します。 54 存在しない場合、引数 ``default`` に ``None`` 以外を指定した場合は ``default`` が、 55 そうでない場合はコンストラクタで設定した ``default`` が返ります。 56 57 期待 :math:`O(1)` です。 58 """ 59 assert ( 60 key != self._e 61 ), f"KeyError: HashDict.get({key}, {default}), {key} cannot be equal to {self._e}" 62 l, _keys, _e = len(self._keys), self._keys, self._e 63 h = self._hash(key) 64 while True: 65 x = _keys[h] 66 if x == _e: 67 return self._vals[h] if default is None else default 68 if x == key: 69 return self._vals[h] 70 h = 0 if h == l - 1 else h + 1
71
[docs] 72 def add(self, key: int, val: Any, default: Any) -> None: 73 assert ( 74 key != self._e 75 ), f"KeyError: HashDict.add({key}, {default}), {key} cannot be equal to {self._e}" 76 l, _keys, _e = len(self._keys), self._keys, self._e 77 h = self._hash(key) 78 while True: 79 x = _keys[h] 80 if x == _e: 81 self._vals[h] = val 82 return 83 if x == key: 84 self._vals[h] += val 85 return 86 h = 0 if h == l - 1 else h + 1
87
[docs] 88 def set(self, key: int, val: Any) -> None: 89 """キーを ``key`` として ``val`` を格納します。 90 ``key`` が既に存在している場合は上書きされます。 91 92 期待 :math:`O(1)` です。 93 """ 94 assert ( 95 key != self._e 96 ), f"KeyError: HashDict.set({key}, {val}), {key} cannot be equal to {self._e}" 97 l, _keys, _e = len(self._keys), self._keys, self._e 98 l -= 1 99 h = self._hash(key) 100 while True: 101 x = _keys[h] 102 if x == _e: 103 _keys[h] = key 104 self._vals[h] = val 105 self._len += 1 106 if 2 * self._len > len(self._keys): 107 self._rebuild() 108 return 109 if x == key: 110 self._vals[h] = val 111 return 112 h = 0 if h == l else h + 1
113
[docs] 114 def __contains__(self, key: int) -> bool: 115 """存在判定です。 116 117 期待 :math:`O(1)` です。 118 119 Returns: 120 bool: ``key`` が存在すれば ``True`` を、そうでなければ ``False`` を返します。 121 """ 122 assert ( 123 key != self._e 124 ), f"KeyError: {key} in HashDict, {key} cannot be equal to {self._e}" 125 l, _keys, _e = len(self._keys), self._keys, self._e 126 h = self._hash(key) 127 while True: 128 x = _keys[h] 129 if x == _e: 130 return False 131 if x == key: 132 return True 133 h += 1 134 if h == l: 135 h = 0
136 137 __getitem__ = get 138 __setitem__ = set 139
[docs] 140 def keys(self) -> Iterator[int]: 141 """``key 集合`` を列挙するイテレータです。""" 142 _keys, _e = self._keys, self._e 143 for i in range(len(_keys)): 144 if _keys[i] != _e: 145 yield _keys[i]
146
[docs] 147 def values(self) -> Iterator[Any]: 148 """``val 集合`` を列挙するイテレータです。""" 149 _keys, _vals, _e = self._keys, self._vals, self._e 150 for i in range(len(_keys)): 151 if _keys[i] != _e: 152 yield _vals[i]
153
[docs] 154 def items(self) -> Iterator[tuple[int, Any]]: 155 """``key とそれに対応する val のタプル`` を列挙するイテレータです。""" 156 _keys, _vals, _e = self._keys, self._vals, self._e 157 for i in range(len(_keys)): 158 if _keys[i] != _e: 159 yield _keys[i], _vals[i]
160 161 def __str__(self): 162 return "{" + ", ".join(map(lambda x: f"{x[0]}: {x[1]}", self.items())) + "}" 163 164 def __len__(self): 165 return self._len