1import random
2from typing import Iterable, Iterator
3
4
[docs]
class HashSet:
    """Set of ``int`` keys stored in an open-addressing table with linear probing.

    Slot states are encoded with two reserved integers chosen at construction:
    ``not_seen`` marks a slot that has never held a key (a probe stops there)
    and ``deleted`` marks a tombstone left by :meth:`discard` (a probe walks
    past it). Neither sentinel can be stored in the set.

    A key's home slot is ``(key ^ mask) % table_size``; the mask is redrawn
    with ``random.getrandbits`` on every rebuild. Sets with more than 1000 keys
    are also rebuilt after a run of operations proportional to their size —
    presumably to blunt adversarial collision patterns.

    Invariants kept after every public operation:
      * live keys occupy at most a third of the table (otherwise it grows);
      * live keys plus tombstones occupy at most two thirds of the table
        (otherwise it is rebuilt in place), so every probe sequence ends at an
        empty slot.
    """

    def __init__(
        self, a: Iterable[int] = (), not_seen: int = -1, deleted: int = -2
    ) -> None:
        """Create a set holding the elements of ``a``.

        Args:
            a: Initial keys; duplicates are collapsed.
            not_seen: Sentinel for never-used slots; may not be added.
            deleted: Sentinel for tombstones; may not be added.

        Raises:
            ValueError: if ``not_seen == deleted`` (tombstones would then look
                like empty slots and cut probe chains short).
        """
        if not_seen == deleted:
            raise ValueError(
                f"HashSet: not_seen and deleted must differ, both are {not_seen}"
            )
        self._keys: list[int] = [not_seen]
        self._empty: int = not_seen
        self._deleted: int = deleted
        self._len: int = 0  # number of live keys
        self._dellen: int = 0  # number of tombstone slots in self._keys
        self._query_count: int = 0  # operations since the last rebuild
        self._being_rebuild: bool = False  # disables _query_check while rebuilding
        self._xor: int = random.getrandbits(1)
        for e in a:
            self.add(e)

    def reserve(self, n: int) -> None:
        """Grow the table by ``3 * n + 4`` slots so about ``n`` more keys fit.

        Every existing key is re-inserted, because slot positions depend on
        both the table length and the XOR mask (which is redrawn here).
        """
        old_keys = self._keys
        # max(..., 0) mirrors list repetition by a negative count (no growth).
        self._keys = [self._empty] * (len(old_keys) + max(3 * n + 4, 0))
        self._inner_rebuild(old_keys)

    def _inner_rebuild(self, old_keys: list[int]) -> None:
        """Re-insert every live key of ``old_keys`` into ``self._keys``.

        ``self._keys`` must already be a fresh table filled with the empty
        sentinel. Counters are reset and a new XOR mask is drawn first.
        """
        _empty, _deleted = self._empty, self._deleted
        self._len = 0
        self._dellen = 0
        self._being_rebuild = True
        self._xor = random.getrandbits(len(self._keys).bit_length())
        for k in old_keys:
            if k != _empty and k != _deleted:
                self.add(k)
        self._query_count = 0
        self._being_rebuild = False

    def _rebuild(self) -> None:
        """Rebuild at the same size: drops tombstones and redraws the mask."""
        old_keys = self._keys
        self._keys = [self._empty] * len(old_keys)
        self._inner_rebuild(old_keys)

    def _rebuild_extend(self) -> None:
        """Rebuild into a table of ``3 * old_size + 4`` slots."""
        old_keys = self._keys
        self._keys = [self._empty] * (3 * len(old_keys) + 4)
        self._inner_rebuild(old_keys)

    def _rebuild_shrink(self) -> None:
        """Rebuild into a table of ``old_size // 3 + 4`` slots."""
        old_keys = self._keys
        self._keys = [self._empty] * (len(old_keys) // 3 + 4)
        self._inner_rebuild(old_keys)

    # Backward-compatible alias for the previously misspelled private name.
    _rebuid_shrink = _rebuild_shrink

    def _query_check(self) -> None:
        """Count one operation; rebuild in place once a set with more than
        1000 keys has seen more than ``len / 3`` operations since its last
        rebuild. A no-op while a rebuild is in progress.

        May replace ``self._keys``, so callers must read the table afterwards.
        """
        if self._being_rebuild:
            return
        self._query_count += 1
        if self._len > 1000 and self._query_count * 3 > self._len:
            self._rebuild()

    def _hash(self, key: int) -> int:
        """Return the home slot of ``key`` in the current table."""
        return (key ^ self._xor) % len(self._keys)

    def _find(self, key: int) -> int:
        """Return the index of the slot holding ``key``, or -1 if absent.

        Walks the probe sequence from the home slot, skipping tombstones,
        until it meets ``key`` or an empty slot. Does not count as a query.
        """
        _keys, _empty = self._keys, self._empty
        l = len(_keys)
        h = self._hash(key)
        for _ in range(l):
            k = _keys[h]
            if k == key:
                return h
            if k == _empty:
                return -1
            h += 1
            if h == l:
                h = 0
        return -1

    def add(self, key: int) -> bool:
        """Insert ``key``.

        Returns:
            True if ``key`` was absent and has been inserted, False if it was
            already present.

        Raises:
            AssertionError: if ``key`` equals a sentinel (assertions enabled).
        """
        assert (
            key != self._empty
        ), f"ValueError: HashSet.add({key}), {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: HashSet.add({key}), {key} cannot be equal to {self._deleted}"
        # Must run before reading self._keys: it may rebuild and swap the list.
        self._query_check()
        _keys, _empty, _deleted = self._keys, self._empty, self._deleted
        l = len(_keys)
        # Target slot: the first tombstone on the probe path if there is one,
        # else the empty slot ending the path. The walk must continue past
        # tombstones because ``key`` may already sit further along the chain.
        slot = -1
        h = self._hash(key)
        for _ in range(l):
            k = _keys[h]
            if k == key:
                return False
            if k == _empty:
                if slot < 0:
                    slot = h
                break
            if k == _deleted and slot < 0:
                slot = h
            h += 1
            if h == l:
                h = 0
        if slot < 0:
            # Unreachable while the class invariants hold.
            raise RuntimeError("HashSet: no free slot in the table")
        if _keys[slot] == _deleted:
            self._dellen -= 1  # tombstone reused
        _keys[slot] = key
        self._len += 1
        if 3 * self._len > l:
            self._rebuild_extend()
        elif 3 * (self._len + self._dellen) > 2 * l:
            # Too few empty slots remain; clear tombstones so probes terminate.
            self._rebuild()
        return True

    def discard(self, key: int) -> bool:
        """Remove ``key`` if present.

        Returns:
            True if ``key`` was removed, False if it was not in the set.

        Raises:
            AssertionError: if ``key`` equals a sentinel (assertions enabled).
        """
        assert (
            key != self._empty
        ), f"ValueError: HashSet.discard({key}), {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: HashSet.discard({key}), {key} cannot be equal to {self._deleted}"
        self._query_check()
        h = self._find(key)
        if h < 0:
            return False
        self._keys[h] = self._deleted
        self._dellen += 1
        self._len -= 1
        if 9 * self._len < len(self._keys):
            self._rebuild_shrink()  # table is very sparse; shrink it
        elif self._len > 1000 and self._dellen * 20 > self._len:
            self._rebuild()  # purge accumulated tombstones
        return True

    def __contains__(self, key: int) -> bool:
        """Return True if ``key`` is in the set (also counts as a query)."""
        assert (
            key != self._empty
        ), f"ValueError: {key} in HashSet, {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: {key} in HashSet, {key} cannot be equal to {self._deleted}"
        self._query_check()
        return self._find(key) >= 0

    def __iter__(self) -> Iterator[int]:
        """Yield the live keys in table order (do not mutate while iterating)."""
        _empty, _deleted = self._empty, self._deleted
        remaining = self._len
        if remaining == 0:
            return
        for k in self._keys:
            if k != _empty and k != _deleted:
                yield k
                remaining -= 1
                if remaining == 0:
                    return  # every live key seen; skip the rest of the table

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"

    def __len__(self) -> int:
        return self._len