multi_hash_string¶
ソースコード¶
from titan_pylib.string.multi_hash_string import MultiHashStringBase
from titan_pylib.string.multi_hash_string import MultiHashString
展開済みコード¶
1# from titan_pylib.string.multi_hash_string import MultiHashString
2# from titan_pylib.string.hash_string import HashStringBase, HashString
3# ref: https://qiita.com/keymoon/items/11fac5627672a6d6a9f6
4# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
5# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
6# SegmentTreeInterface,
7# )
8from abc import ABC, abstractmethod
9from typing import TypeVar, Generic, Union, Iterable, Callable
10
11T = TypeVar("T")
12
13
14class SegmentTreeInterface(ABC, Generic[T]):
15
16 @abstractmethod
17 def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
18 raise NotImplementedError
19
20 @abstractmethod
21 def set(self, k: int, v: T) -> None:
22 raise NotImplementedError
23
24 @abstractmethod
25 def get(self, k: int) -> T:
26 raise NotImplementedError
27
28 @abstractmethod
29 def prod(self, l: int, r: int) -> T:
30 raise NotImplementedError
31
32 @abstractmethod
33 def all_prod(self) -> T:
34 raise NotImplementedError
35
36 @abstractmethod
37 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
38 raise NotImplementedError
39
40 @abstractmethod
41 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
42 raise NotImplementedError
43
44 @abstractmethod
45 def tolist(self) -> list[T]:
46 raise NotImplementedError
47
48 @abstractmethod
49 def __getitem__(self, k: int) -> T:
50 raise NotImplementedError
51
52 @abstractmethod
53 def __setitem__(self, k: int, v: T) -> None:
54 raise NotImplementedError
55
56 @abstractmethod
57 def __str__(self):
58 raise NotImplementedError
59
60 @abstractmethod
61 def __repr__(self):
62 raise NotImplementedError
63from typing import Generic, Iterable, TypeVar, Callable, Union
64
65T = TypeVar("T")
66
67
68class SegmentTree(SegmentTreeInterface, Generic[T]):
69 """セグ木です。非再帰です。"""
70
71 def __init__(
72 self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
73 ) -> None:
74 """``SegmentTree`` を構築します。
75 :math:`O(n)` です。
76
77 Args:
78 n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
79 ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
80 op (Callable[[T, T], T]): 2項演算の関数です。
81 e (T): 単位元です。
82 """
83 self._op = op
84 self._e = e
85 if isinstance(n_or_a, int):
86 self._n = n_or_a
87 self._log = (self._n - 1).bit_length()
88 self._size = 1 << self._log
89 self._data = [e] * (self._size << 1)
90 else:
91 n_or_a = list(n_or_a)
92 self._n = len(n_or_a)
93 self._log = (self._n - 1).bit_length()
94 self._size = 1 << self._log
95 _data = [e] * (self._size << 1)
96 _data[self._size : self._size + self._n] = n_or_a
97 for i in range(self._size - 1, 0, -1):
98 _data[i] = op(_data[i << 1], _data[i << 1 | 1])
99 self._data = _data
100
101 def set(self, k: int, v: T) -> None:
102 """一点更新です。
103 :math:`O(\\log{n})` です。
104
105 Args:
106 k (int): 更新するインデックスです。
107 v (T): 更新する値です。
108
109 制約:
110 :math:`-n \\leq n \\leq k < n`
111 """
112 assert (
113 -self._n <= k < self._n
114 ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
115 if k < 0:
116 k += self._n
117 k += self._size
118 self._data[k] = v
119 for _ in range(self._log):
120 k >>= 1
121 self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
122
123 def get(self, k: int) -> T:
124 """一点取得です。
125 :math:`O(1)` です。
126
127 Args:
128 k (int): インデックスです。
129
130 制約:
131 :math:`-n \\leq n \\leq k < n`
132 """
133 assert (
134 -self._n <= k < self._n
135 ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
136 if k < 0:
137 k += self._n
138 return self._data[k + self._size]
139
140 def prod(self, l: int, r: int) -> T:
141 """区間 ``[l, r)`` の総積を返します。
142 :math:`O(\\log{n})` です。
143
144 Args:
145 l (int): インデックスです。
146 r (int): インデックスです。
147
148 制約:
149 :math:`0 \\leq l \\leq r \\leq n`
150 """
151 assert (
152 0 <= l <= r <= self._n
153 ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
154 l += self._size
155 r += self._size
156 lres = self._e
157 rres = self._e
158 while l < r:
159 if l & 1:
160 lres = self._op(lres, self._data[l])
161 l += 1
162 if r & 1:
163 rres = self._op(self._data[r ^ 1], rres)
164 l >>= 1
165 r >>= 1
166 return self._op(lres, rres)
167
168 def all_prod(self) -> T:
169 """区間 ``[0, n)`` の総積を返します。
170 :math:`O(1)` です。
171 """
172 return self._data[1]
173
174 def max_right(self, l: int, f: Callable[[T], bool]) -> int:
175 """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
176 assert (
177 0 <= l <= self._n
178 ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
179 # assert f(self._e), \
180 # f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
181 if l == self._n:
182 return self._n
183 l += self._size
184 s = self._e
185 while True:
186 while l & 1 == 0:
187 l >>= 1
188 if not f(self._op(s, self._data[l])):
189 while l < self._size:
190 l <<= 1
191 if f(self._op(s, self._data[l])):
192 s = self._op(s, self._data[l])
193 l |= 1
194 return l - self._size
195 s = self._op(s, self._data[l])
196 l += 1
197 if l & -l == l:
198 break
199 return self._n
200
201 def min_left(self, r: int, f: Callable[[T], bool]) -> int:
202 """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
203 assert (
204 0 <= r <= self._n
205 ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
206 # assert f(self._e), \
207 # f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
208 if r == 0:
209 return 0
210 r += self._size
211 s = self._e
212 while True:
213 r -= 1
214 while r > 1 and r & 1:
215 r >>= 1
216 if not f(self._op(self._data[r], s)):
217 while r < self._size:
218 r = r << 1 | 1
219 if f(self._op(self._data[r], s)):
220 s = self._op(self._data[r], s)
221 r ^= 1
222 return r + 1 - self._size
223 s = self._op(self._data[r], s)
224 if r & -r == r:
225 break
226 return 0
227
228 def tolist(self) -> list[T]:
229 """リストにして返します。
230 :math:`O(n)` です。
231 """
232 return [self.get(i) for i in range(self._n)]
233
234 def show(self) -> None:
235 """デバッグ用のメソッドです。"""
236 print(
237 f"<{self.__class__.__name__}> [\n"
238 + "\n".join(
239 [
240 " "
241 + " ".join(
242 map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
243 )
244 for i in range(self._log + 1)
245 ]
246 )
247 + "\n]"
248 )
249
250 def __getitem__(self, k: int) -> T:
251 assert (
252 -self._n <= k < self._n
253 ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
254 return self.get(k)
255
256 def __setitem__(self, k: int, v: T):
257 assert (
258 -self._n <= k < self._n
259 ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
260 self.set(k, v)
261
262 def __len__(self) -> int:
263 return self._n
264
265 def __str__(self) -> str:
266 return str(self.tolist())
267
268 def __repr__(self) -> str:
269 return f"{self.__class__.__name__}({self})"
270from typing import Optional, Final
271import random
272import string
273
274_titan_pylib_HashString_MOD: Final[int] = (1 << 61) - 1
275_titan_pylib_HashString_DIC: Final[dict[str, int]] = {
276 c: i for i, c in enumerate(string.ascii_lowercase, 1)
277}
278_titan_pylib_HashString_MASK30: Final[int] = (1 << 30) - 1
279_titan_pylib_HashString_MASK31: Final[int] = (1 << 31) - 1
280_titan_pylib_HashString_MASK61: Final[int] = _titan_pylib_HashString_MOD
281
282
283class HashStringBase:
284 """HashStringのベースクラスです。"""
285
286 def __init__(self, n: int = 0, base: int = -1, seed: Optional[int] = None) -> None:
287 """
288 :math:`O(n)` です。
289
290 Args:
291 n (int): 文字列の長さの上限です。上限を超えても問題ありません。
292 base (int, optional): Defaults to -1.
293 seed (Optional[int], optional): Defaults to None.
294 """
295 rand = random.Random(seed)
296 base = rand.randint(37, 10**9) if base < 0 else base
297 powb = [1] * (n + 1)
298 invb = [1] * (n + 1)
299 invbpow = pow(base, -1, _titan_pylib_HashString_MOD)
300 for i in range(1, n + 1):
301 powb[i] = HashStringBase.get_mul(powb[i - 1], base)
302 invb[i] = HashStringBase.get_mul(invb[i - 1], invbpow)
303 self.n = n
304 self.base = base
305 self.invpow = invbpow
306 self.powb = powb
307 self.invb = invb
308
309 @staticmethod
310 def get_mul(a: int, b: int) -> int:
311 au = a >> 31
312 ad = a & _titan_pylib_HashString_MASK31
313 bu = b >> 31
314 bd = b & _titan_pylib_HashString_MASK31
315 mid = ad * bu + au * bd
316 midu = mid >> 30
317 midd = mid & _titan_pylib_HashString_MASK30
318 return HashStringBase.get_mod(au * bu * 2 + midu + (midd << 31) + ad * bd)
319
320 @staticmethod
321 def get_mod(x: int) -> int:
322 xu = x >> 61
323 xd = x & _titan_pylib_HashString_MASK61
324 res = xu + xd
325 if res >= _titan_pylib_HashString_MOD:
326 res -= _titan_pylib_HashString_MOD
327 return res
328
329 def extend(self, cap: int) -> None:
330 pre_cap = len(self.powb)
331 powb, invb = self.powb, self.invb
332 powb += [0] * cap
333 invb += [0] * cap
334 invbpow = pow(self.base, -1, _titan_pylib_HashString_MOD)
335 for i in range(pre_cap, pre_cap + cap):
336 powb[i] = HashStringBase.get_mul(powb[i - 1], self.base)
337 invb[i] = HashStringBase.get_mul(invb[i - 1], invbpow)
338
339 def get_cap(self) -> int:
340 return len(self.powb)
341
342 def unite(self, h1: int, h2: int, k: int) -> int:
343 # len(h2) == k
344 # h1 <- h2
345 if k >= self.get_cap():
346 self.extend(k - self.get_cap() + 1)
347 return self.get_mod(self.get_mul(h1, self.powb[k]) + h2)
348
349
350class HashString:
351
352 def __init__(self, hsb: HashStringBase, s: str, update: bool = False) -> None:
353 """ロリハを構築します。
354 :math:`O(n)` です。
355
356 Args:
357 hsb (HashStringBase): ベースクラスです。
358 s (str): ロリハを構築する文字列です。
359 update (bool, optional): ``update=True`` のとき、1点更新が可能になります。
360 """
361 n = len(s)
362 data = [0] * n
363 acc = [0] * (n + 1)
364 if n >= hsb.get_cap():
365 hsb.extend(n - hsb.get_cap() + 1)
366 powb = hsb.powb
367 for i, c in enumerate(s):
368 data[i] = hsb.get_mul(powb[n - i - 1], _titan_pylib_HashString_DIC[c])
369 acc[i + 1] = hsb.get_mod(acc[i] + data[i])
370 self.hsb = hsb
371 self.n = n
372 self.acc = acc
373 self.used_seg = False
374 if update:
375 self.seg = SegmentTree(
376 data, lambda s, t: (s + t) % _titan_pylib_HashString_MOD, 0
377 )
378
379 def get(self, l: int, r: int) -> int:
380 """``s[l, r)`` のハッシュ値を返します。
381 1点更新処理後は :math:`O(\\log{n})` 、そうでなければ :math:`O(1)` です。
382
383 Args:
384 l (int): インデックスです。
385 r (int): インデックスです。
386
387 Returns:
388 int: ハッシュ値です。
389 """
390 assert 0 <= l <= r <= self.n
391 if self.used_seg:
392 return self.hsb.get_mul(self.seg.prod(l, r), self.hsb.invb[self.n - r])
393 return self.hsb.get_mul(
394 self.hsb.get_mod(self.acc[r] - self.acc[l]), self.hsb.invb[self.n - r]
395 )
396
397 def __getitem__(self, k: int) -> int:
398 """``s[k]`` のハッシュ値を返します。
399 1点更新処理後は :math:`O(\\log{n})` 、そうでなければ :math:`O(1)` です。
400
401 Args:
402 k (int): インデックスです。
403
404 Returns:
405 int: ハッシュ値です。
406 """
407 return self.get(k, k + 1)
408
409 def set(self, k: int, c: str) -> None:
410 """`k` 番目の文字を `c` に更新します。
411 :math:`O(\\log{n})` です。また、今後の ``get()`` が :math:`O(\\log{n})` になります。
412
413 Args:
414 k (int): インデックスです。
415 c (str): 更新する文字です。
416 """
417 self.used_seg = True
418 self.seg[k] = self.hsb.get_mul(
419 self.hsb.powb[self.n - k - 1], _titan_pylib_HashString_DIC[c]
420 )
421
422 def __setitem__(self, k: int, c: str) -> None:
423 return self.set(k, c)
424
425 def __len__(self):
426 return self.n
427
428 def get_lcp(self) -> list[int]:
429 """lcp配列を返します。
430 :math:`O(n\\log{n})` です。
431 """
432 a = [0] * self.n
433 memo = [-1] * (self.n + 1)
434 for i in range(self.n):
435 ok, ng = 0, self.n - i + 1
436 while ng - ok > 1:
437 mid = (ok + ng) >> 1
438 if memo[mid] == -1:
439 memo[mid] = self.get(0, mid)
440 if memo[mid] == self.get(i, i + mid):
441 ok = mid
442 else:
443 ng = mid
444 a[i] = ok
445 return a
446from typing import Optional
447import random
448
449
450class MultiHashStringBase:
451
452 def __init__(
453 self,
454 n: int,
455 base_cnt: int = 1,
456 base_list: list[int] = [],
457 seed: Optional[int] = None,
458 ) -> None:
459 if seed is None:
460 seed = random.randint(0, 10**9)
461 assert (
462 base_cnt > 0
463 ), f"ValueError: {self.__class__.__name__} base_cnt must be > 0"
464 base_list = (
465 base_list
466 if len(base_list) == base_cnt
467 else [random.randint(37, 10**9) for _ in range(base_cnt)]
468 )
469 hsb = tuple(HashStringBase(n, base_list[i]) for i in range(base_cnt))
470 self.hsb = hsb
471
472
473class MultiHashString:
474
475 def __init__(self, hsb: MultiHashStringBase, s: str, update: bool = False) -> None:
476 self.hsb = hsb
477 self.hs = tuple(HashString(hsb, s, update=update) for hsb in self.hsb.hsb)
478
479 def get(self, l: int, r: int) -> tuple[int]:
480 return tuple(hs.get(l, r) for hs in self.hs)
481
482 def __getitem__(self, k: int) -> tuple[int]:
483 return self.get(k, k + 1)
484
485 def set(self, k: int, c: str) -> None:
486 for hs in self.hs:
487 hs.set(k, c)
488
489 def __setitem__(self, k: int, c: str) -> None:
490 self.set(k, c)
仕様¶
- class MultiHashString(hsb: MultiHashStringBase, s: str, update: bool = False)[source]¶
Bases:
object