1# from titan_pylib.data_structures.set.hash_set import HashSet
2import random
3from typing import Iterable, Iterator
4
5
class HashSet:
    """Set of ints backed by an open-addressing table with linear probing.

    Every slot holds a live key, the ``not_seen`` sentinel (a never-used slot)
    or the ``deleted`` sentinel (a tombstone left by :meth:`discard`); neither
    sentinel may be stored as a key.  A key's home slot is
    ``(key ^ xor) % capacity`` where ``xor`` is redrawn at random every time
    the table is rebuilt.

    Rebuild policy:
      * grow to ``3 * capacity + 4`` slots when live keys exceed 1/3 of it;
      * shrink to ``capacity // 3 + 4`` slots when live keys fall below 1/9;
      * rebuild at the same capacity when tombstones pile up, and
        periodically on large sets (more than 1000 keys) after roughly
        ``len / 3`` operations, which also redraws ``xor``.
    """

    def __init__(
        self, a: Iterable[int] = (), not_seen: int = -1, deleted: int = -2
    ) -> None:
        """Build a set holding the ints of *a*.

        Args:
            a: Initial elements; duplicates are ignored.
            not_seen: Sentinel marking never-used slots; cannot be added.
            deleted: Sentinel marking tombstones; cannot be added.
        """
        self._keys: list[int] = [not_seen]
        self._empty: int = not_seen
        self._deleted: int = deleted
        self._len: int = 0  # number of live keys
        self._dellen: int = 0  # number of tombstone slots in the table
        self._query_count: int = 0  # operations since the last rebuild
        # True while _inner_rebuild re-inserts keys; disables _query_check.
        self._being_rebuild: bool = False
        self._xor: int = random.getrandbits(1)
        for e in a:
            self.add(e)

    def reserve(self, n: int) -> None:
        """Enlarge the table by ``3 * n + 4`` slots and re-insert every key.

        Re-insertion is required: both the capacity and ``xor`` change, so
        keys left at their old positions would no longer be found.
        A negative growth (``3 * n + 4 < 0``) adds no slots.
        """
        old_keys = self._keys
        self._keys = [self._empty] * (len(old_keys) + max(3 * n + 4, 0))
        self._inner_rebuild(old_keys)

    def _inner_rebuild(self, old_keys: list[int]) -> None:
        """Re-insert the live keys of *old_keys* into the already-allocated
        ``self._keys``, resetting all counters and redrawing ``xor``."""
        _empty, _deleted = self._empty, self._deleted
        self._len = 0
        self._dellen = 0
        self._being_rebuild = True
        self._xor = random.getrandbits(len(self._keys).bit_length())
        for k in old_keys:
            if k != _empty and k != _deleted:
                self.add(k)
        self._query_count = 0
        self._being_rebuild = False

    def _rebuild(self) -> None:
        """Rebuild at the same capacity (drops tombstones, redraws xor)."""
        old_keys = self._keys
        self._keys = [self._empty] * len(old_keys)
        self._inner_rebuild(old_keys)

    def _rebuild_extend(self) -> None:
        """Rebuild into a table of ``3 * capacity + 4`` slots."""
        old_keys = self._keys
        self._keys = [self._empty] * (3 * len(old_keys) + 4)
        self._inner_rebuild(old_keys)

    def _rebuild_shrink(self) -> None:
        """Rebuild into a table of ``capacity // 3 + 4`` slots."""
        old_keys = self._keys
        self._keys = [self._empty] * (len(old_keys) // 3 + 4)
        self._inner_rebuild(old_keys)

    # Backward-compatible alias for the previous (misspelled) private name.
    _rebuid_shrink = _rebuild_shrink

    def _query_check(self) -> None:
        """Count one operation; rebuild a large set after many operations.

        NOTE: this may replace ``self._keys`` and ``self._xor``, so callers
        must read both only *after* calling it.
        """
        if self._being_rebuild:
            return
        self._query_count += 1
        if self._len > 1000 and self._query_count * 3 > self._len:
            self._rebuild()

    def _hash(self, key: int) -> int:
        """Return the home slot of *key* in the current table."""
        return (key ^ self._xor) % len(self._keys)

    def add(self, key: int) -> bool:
        """Insert *key*; return True if it was absent, False otherwise."""
        assert (
            key != self._empty
        ), f"ValueError: HashSet.add({key}), {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: HashSet.add({key}), {key} cannot be equal to {self._deleted}"
        # Must run before reading self._keys: it may swap in a new table.
        self._query_check()
        _keys, _empty, _deleted = self._keys, self._empty, self._deleted
        l = len(_keys)
        slot = -1  # first reusable slot (tombstone or empty) on the probe path
        H = self._hash(key)
        for h in range(H, H + l):
            if h >= l:
                h -= l
            k = _keys[h]
            if k == key:
                return False
            if k == _empty:
                if slot == -1:
                    slot = h
                # An empty slot ends the probe chain: key is absent.
                break
            if k == _deleted and slot == -1:
                # Remember the tombstone but keep probing: the key may
                # still be stored further along the chain.
                slot = h
        if slot == -1:
            # Unreachable while live keys occupy at most 1/3 of the table.
            raise AssertionError("HashSet.add: no free slot")
        if _keys[slot] == _deleted:
            self._dellen -= 1
        _keys[slot] = key
        self._len += 1
        if 3 * self._len > l:
            self._rebuild_extend()
        return True

    def discard(self, key: int) -> bool:
        """Remove *key* if present; return True if it was removed."""
        assert (
            key != self._empty
        ), f"ValueError: HashSet.discard({key}), {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: HashSet.discard({key}), {key} cannot be equal to {self._deleted}"
        # Must run before reading self._keys: it may swap in a new table.
        self._query_check()
        _keys, _empty = self._keys, self._empty
        l = len(_keys)
        H = self._hash(key)
        for h in range(H, H + l):
            if h >= l:
                h -= l
            k = _keys[h]
            if k == _empty:
                return False
            if k == key:
                _keys[h] = self._deleted
                self._dellen += 1
                self._len -= 1
                if 9 * self._len < l:
                    self._rebuild_shrink()
                elif (self._len > 1000 and self._dellen * 20 > self._len) or (
                    # Keep never-used slots available: a miss only stops at
                    # an empty slot, so a table of keys and tombstones alone
                    # would make every unsuccessful probe scan all of it.
                    2 * (self._len + self._dellen) > l
                ):
                    self._rebuild()
                return True
        # Whole table scanned without an empty slot or a match.
        return False

    def __contains__(self, key: int) -> bool:
        """Return True if *key* is in the set."""
        assert (
            key != self._empty
        ), f"ValueError: {key} in HashSet, {key} cannot be equal to {self._empty}"
        assert (
            key != self._deleted
        ), f"ValueError: {key} in HashSet, {key} cannot be equal to {self._deleted}"
        # Must run before reading self._keys: it may swap in a new table.
        self._query_check()
        _keys, _empty = self._keys, self._empty
        l = len(_keys)
        H = self._hash(key)
        for h in range(H, H + l):
            if h >= l:
                h -= l
            k = _keys[h]
            if k == _empty:
                return False
            if k == key:
                return True
        return False

    def __iter__(self) -> Iterator[int]:
        """Yield the live keys in table order (not sorted)."""
        _empty, _deleted = self._empty, self._deleted
        cnt = len(self)
        for k in self._keys:
            if k != _empty and k != _deleted:
                cnt -= 1
                yield k
                if cnt == 0:
                    return

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self)) + "}"

    def __len__(self) -> int:
        return self._len