Source code for titan_pylib.data_structures.set.hash_set

  1import random
  2from typing import Iterable, Iterator
  3
  4
class HashSet:
    """Hash set of integers using open addressing with linear probing.

    Two integer sentinel values mark unused table slots (``not_seen``) and
    tombstones left by removed keys (``deleted``); neither may be stored as
    a key.  The hash is ``(key ^ mask) % table_size`` where ``mask`` is a
    random value that is re-drawn on every rebuild.

    The table is kept at most one third full: it is enlarged when live keys
    exceed a third of the slots, shrunk when they drop below a ninth, and
    rebuilt in place when tombstones pile up or after a run of queries on a
    set with more than 1000 elements.
    """

    def __init__(self, a: Iterable[int] = (), not_seen: int = -1, deleted: int = -2):
        """Create a set holding the integers of ``a``.

        Args:
            a: Initial keys.
            not_seen: Sentinel stored in never-used slots; cannot be a key.
            deleted: Sentinel stored in tombstone slots; cannot be a key.
        """
        self._keys: list[int] = [not_seen]
        self._empty: int = not_seen
        self._deleted: int = deleted
        self._len: int = 0  # number of live keys
        self._dellen: int = 0  # number of tombstones currently in the table
        self._query_count: int = 0  # queries since the last rebuild
        self._being_rebuild: bool = False
        self._xor: int = random.getrandbits(1)
        for e in a:
            self.add(e)

    def reserve(self, n: int) -> None:
        """Grow the table by ``3 * n + 4`` slots and re-draw the hash mask.

        Keys already stored are re-inserted into the enlarged table, so the
        method is safe to call on a non-empty set.
        """
        old_keys = self._keys
        extra = max(0, 3 * n + 4)  # a negative ``n`` must never shrink the table
        self._keys = [self._empty] * (len(old_keys) + extra)
        self._inner_rebuild(old_keys)

    def _inner_rebuild(self, old_keys: list[int]) -> None:
        """Re-insert the live keys of ``old_keys`` into ``self._keys``.

        ``self._keys`` must already be a fresh all-empty table with more
        slots than ``old_keys`` has live keys.  Resets the tombstone and
        query counters and draws a new hash mask.
        """
        empty, deleted = self._empty, self._deleted
        table = self._keys
        size = len(table)
        self._being_rebuild = True
        self._xor = mask = random.getrandbits(size.bit_length())
        live = 0
        for k in old_keys:
            if k == empty or k == deleted:
                continue
            # The table holds neither tombstones nor duplicates here, so the
            # first empty slot on the probe path is the right place.
            pos = (k ^ mask) % size
            while table[pos] != empty:
                pos += 1
                if pos == size:
                    pos = 0
            table[pos] = k
            live += 1
        self._len = live
        self._dellen = 0
        self._query_count = 0
        self._being_rebuild = False

    def _rebuild(self) -> None:
        """Rehash into a fresh table of the same size, dropping tombstones."""
        old_keys = self._keys
        self._keys = [self._empty] * len(old_keys)
        self._inner_rebuild(old_keys)

    def _rebuild_extend(self) -> None:
        """Rehash into a table of ``3 * size + 4`` slots."""
        old_keys = self._keys
        self._keys = [self._empty] * (3 * len(old_keys) + 4)
        self._inner_rebuild(old_keys)

    # NOTE: the misspelled name is kept so existing references keep working.
    def _rebuid_shrink(self) -> None:
        """Rehash into a table of ``size // 3 + 4`` slots."""
        old_keys = self._keys
        self._keys = [self._empty] * (len(old_keys) // 3 + 4)
        self._inner_rebuild(old_keys)

    def _query_check(self) -> None:
        """Count one query and rebuild once enough have accumulated.

        A rebuild replaces ``self._keys`` and the hash mask, so callers must
        invoke this BEFORE reading either of them.
        """
        if self._being_rebuild:
            return
        self._query_count += 1
        if self._len > 1000 and self._query_count * 3 > self._len:
            self._rebuild()

    def _hash(self, key: int) -> int:
        """Return the home slot of ``key`` in the current table."""
        return (key ^ self._xor) % len(self._keys)

    def add(self, key: int) -> bool:
        """Insert ``key``.

        Returns:
            ``True`` if the key was inserted, ``False`` if it was already
            present.

        Raises:
            AssertionError: If ``key`` equals one of the two sentinels.
        """
        assert (
            key != self._empty
        ), f"ValueError: HashSet.add({key}), {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: HashSet.add({key}), {key} cannot be equal to {self._deleted}"
        # Must run before the table is read: it may swap in a new table.
        self._query_check()
        table, empty, deleted = self._keys, self._empty, self._deleted
        size = len(table)
        pos = self._hash(key)
        target = -1  # first reusable slot seen on the probe path
        for _ in range(size):
            cur = table[pos]
            if cur == key:
                return False
            if cur == empty:
                if target < 0:
                    target = pos
                break
            # A tombstone is only a candidate: the key may still sit further
            # along the chain, so keep probing until an empty slot.
            if cur == deleted and target < 0:
                target = pos
            pos += 1
            if pos == size:
                pos = 0
        assert target >= 0  # live keys fill at most a third of the table
        if table[target] == deleted:
            self._dellen -= 1
        table[target] = key
        self._len += 1
        if 3 * self._len > size:
            self._rebuild_extend()
        return True

    def discard(self, key: int) -> bool:
        """Remove ``key`` if present.

        Returns:
            ``True`` if the key was removed, ``False`` if it was absent.

        Raises:
            AssertionError: If ``key`` equals one of the two sentinels.
        """
        assert (
            key != self._empty
        ), f"ValueError: HashSet.discard({key}), {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: HashSet.discard({key}), {key} cannot be equal to {self._deleted}"
        # Must run before the table is read: it may swap in a new table.
        self._query_check()
        table, empty = self._keys, self._empty
        size = len(table)
        pos = self._hash(key)
        for _ in range(size):
            cur = table[pos]
            if cur == empty:
                return False
            if cur == key:
                table[pos] = self._deleted
                self._dellen += 1
                self._len -= 1
                if 3 * 3 * self._len < size:
                    self._rebuid_shrink()
                elif (self._len > 1000 and self._dellen * 20 > self._len) or 2 * (
                    self._len + self._dellen
                ) > size:
                    # Second clause keeps at least half of the slots empty
                    # after a removal, so probes always reach an empty slot
                    # even in small tables.
                    self._rebuild()
                return True
            pos += 1
            if pos == size:
                pos = 0
        return False  # every slot probed without finding the key

    def __contains__(self, key: int) -> bool:
        """Return whether ``key`` is in the set.

        Raises:
            AssertionError: If ``key`` equals one of the two sentinels.
        """
        assert (
            key != self._empty
        ), f"ValueError: {key} in HashSet, {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: {key} in HashSet, {key} cannot be equal to {self._deleted}"
        # Must run before the table is read: it may swap in a new table.
        self._query_check()
        table, empty = self._keys, self._empty
        size = len(table)
        pos = self._hash(key)
        for _ in range(size):
            cur = table[pos]
            if cur == empty:
                return False
            if cur == key:
                return True
            pos += 1
            if pos == size:
                pos = 0
        return False

    def __iter__(self) -> Iterator[int]:
        """Yield the stored keys in table order."""
        empty, deleted = self._empty, self._deleted
        remaining = len(self)
        if remaining == 0:
            return
        for k in self._keys:
            if k != empty and k != deleted:
                remaining -= 1
                yield k
                if remaining == 0:  # nothing live is left in the tail
                    return

    def __str__(self) -> str:
        """Return the keys formatted as ``{k1, k2, ...}``."""
        return "{" + ", ".join(map(str, self)) + "}"

    def __len__(self) -> int:
        """Return the number of stored keys."""
        return self._len