multi_hash_string

ソースコード

from titan_pylib.string.multi_hash_string import MultiHashStringBase
from titan_pylib.string.multi_hash_string import MultiHashString

view on github

展開済みコード

  1# from titan_pylib.string.multi_hash_string import MultiHashString
  2# from titan_pylib.string.hash_string import HashStringBase, HashString
  3# ref: https://qiita.com/keymoon/items/11fac5627672a6d6a9f6
  4# from titan_pylib.data_structures.segment_tree.segment_tree import SegmentTree
  5# from titan_pylib.data_structures.segment_tree.segment_tree_interface import (
  6#     SegmentTreeInterface,
  7# )
  8from abc import ABC, abstractmethod
  9from typing import TypeVar, Generic, Union, Iterable, Callable
 10
 11T = TypeVar("T")
 12
 13
 14class SegmentTreeInterface(ABC, Generic[T]):
 15
 16    @abstractmethod
 17    def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
 18        raise NotImplementedError
 19
 20    @abstractmethod
 21    def set(self, k: int, v: T) -> None:
 22        raise NotImplementedError
 23
 24    @abstractmethod
 25    def get(self, k: int) -> T:
 26        raise NotImplementedError
 27
 28    @abstractmethod
 29    def prod(self, l: int, r: int) -> T:
 30        raise NotImplementedError
 31
 32    @abstractmethod
 33    def all_prod(self) -> T:
 34        raise NotImplementedError
 35
 36    @abstractmethod
 37    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
 38        raise NotImplementedError
 39
 40    @abstractmethod
 41    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
 42        raise NotImplementedError
 43
 44    @abstractmethod
 45    def tolist(self) -> list[T]:
 46        raise NotImplementedError
 47
 48    @abstractmethod
 49    def __getitem__(self, k: int) -> T:
 50        raise NotImplementedError
 51
 52    @abstractmethod
 53    def __setitem__(self, k: int, v: T) -> None:
 54        raise NotImplementedError
 55
 56    @abstractmethod
 57    def __str__(self):
 58        raise NotImplementedError
 59
 60    @abstractmethod
 61    def __repr__(self):
 62        raise NotImplementedError
 63from typing import Generic, Iterable, TypeVar, Callable, Union
 64
 65T = TypeVar("T")
 66
 67
 68class SegmentTree(SegmentTreeInterface, Generic[T]):
 69    """セグ木です。非再帰です。"""
 70
 71    def __init__(
 72        self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T
 73    ) -> None:
 74        """``SegmentTree`` を構築します。
 75        :math:`O(n)` です。
 76
 77        Args:
 78            n_or_a (Union[int, Iterable[T]]): ``n: int`` のとき、 ``e`` を初期値として長さ ``n`` の ``SegmentTree`` を構築します。
 79                                              ``a: Iterable[T]`` のとき、 ``a`` から ``SegmentTree`` を構築します。
 80            op (Callable[[T, T], T]): 2項演算の関数です。
 81            e (T): 単位元です。
 82        """
 83        self._op = op
 84        self._e = e
 85        if isinstance(n_or_a, int):
 86            self._n = n_or_a
 87            self._log = (self._n - 1).bit_length()
 88            self._size = 1 << self._log
 89            self._data = [e] * (self._size << 1)
 90        else:
 91            n_or_a = list(n_or_a)
 92            self._n = len(n_or_a)
 93            self._log = (self._n - 1).bit_length()
 94            self._size = 1 << self._log
 95            _data = [e] * (self._size << 1)
 96            _data[self._size : self._size + self._n] = n_or_a
 97            for i in range(self._size - 1, 0, -1):
 98                _data[i] = op(_data[i << 1], _data[i << 1 | 1])
 99            self._data = _data
100
101    def set(self, k: int, v: T) -> None:
102        """一点更新です。
103        :math:`O(\\log{n})` です。
104
105        Args:
106            k (int): 更新するインデックスです。
107            v (T): 更新する値です。
108
109        制約:
110            :math:`-n \\leq n \\leq k < n`
111        """
112        assert (
113            -self._n <= k < self._n
114        ), f"IndexError: {self.__class__.__name__}.set({k}, {v}), n={self._n}"
115        if k < 0:
116            k += self._n
117        k += self._size
118        self._data[k] = v
119        for _ in range(self._log):
120            k >>= 1
121            self._data[k] = self._op(self._data[k << 1], self._data[k << 1 | 1])
122
123    def get(self, k: int) -> T:
124        """一点取得です。
125        :math:`O(1)` です。
126
127        Args:
128            k (int): インデックスです。
129
130        制約:
131            :math:`-n \\leq n \\leq k < n`
132        """
133        assert (
134            -self._n <= k < self._n
135        ), f"IndexError: {self.__class__.__name__}.get({k}), n={self._n}"
136        if k < 0:
137            k += self._n
138        return self._data[k + self._size]
139
140    def prod(self, l: int, r: int) -> T:
141        """区間 ``[l, r)`` の総積を返します。
142        :math:`O(\\log{n})` です。
143
144        Args:
145            l (int): インデックスです。
146            r (int): インデックスです。
147
148        制約:
149            :math:`0 \\leq l \\leq r \\leq n`
150        """
151        assert (
152            0 <= l <= r <= self._n
153        ), f"IndexError: {self.__class__.__name__}.prod({l}, {r})"
154        l += self._size
155        r += self._size
156        lres = self._e
157        rres = self._e
158        while l < r:
159            if l & 1:
160                lres = self._op(lres, self._data[l])
161                l += 1
162            if r & 1:
163                rres = self._op(self._data[r ^ 1], rres)
164            l >>= 1
165            r >>= 1
166        return self._op(lres, rres)
167
168    def all_prod(self) -> T:
169        """区間 ``[0, n)`` の総積を返します。
170        :math:`O(1)` です。
171        """
172        return self._data[1]
173
174    def max_right(self, l: int, f: Callable[[T], bool]) -> int:
175        """Find the largest index R s.t. f([l, R)) == True. / O(\\log{n})"""
176        assert (
177            0 <= l <= self._n
178        ), f"IndexError: {self.__class__.__name__}.max_right({l}, f) index out of range"
179        # assert f(self._e), \
180        #     f'{self.__class__.__name__}.max_right({l}, f), f({self._e}) must be true.'
181        if l == self._n:
182            return self._n
183        l += self._size
184        s = self._e
185        while True:
186            while l & 1 == 0:
187                l >>= 1
188            if not f(self._op(s, self._data[l])):
189                while l < self._size:
190                    l <<= 1
191                    if f(self._op(s, self._data[l])):
192                        s = self._op(s, self._data[l])
193                        l |= 1
194                return l - self._size
195            s = self._op(s, self._data[l])
196            l += 1
197            if l & -l == l:
198                break
199        return self._n
200
201    def min_left(self, r: int, f: Callable[[T], bool]) -> int:
202        """Find the smallest index L s.t. f([L, r)) == True. / O(\\log{n})"""
203        assert (
204            0 <= r <= self._n
205        ), f"IndexError: {self.__class__.__name__}.min_left({r}, f) index out of range"
206        # assert f(self._e), \
207        #     f'{self.__class__.__name__}.min_left({r}, f), f({self._e}) must be true.'
208        if r == 0:
209            return 0
210        r += self._size
211        s = self._e
212        while True:
213            r -= 1
214            while r > 1 and r & 1:
215                r >>= 1
216            if not f(self._op(self._data[r], s)):
217                while r < self._size:
218                    r = r << 1 | 1
219                    if f(self._op(self._data[r], s)):
220                        s = self._op(self._data[r], s)
221                        r ^= 1
222                return r + 1 - self._size
223            s = self._op(self._data[r], s)
224            if r & -r == r:
225                break
226        return 0
227
228    def tolist(self) -> list[T]:
229        """リストにして返します。
230        :math:`O(n)` です。
231        """
232        return [self.get(i) for i in range(self._n)]
233
234    def show(self) -> None:
235        """デバッグ用のメソッドです。"""
236        print(
237            f"<{self.__class__.__name__}> [\n"
238            + "\n".join(
239                [
240                    "  "
241                    + " ".join(
242                        map(str, [self._data[(1 << i) + j] for j in range(1 << i)])
243                    )
244                    for i in range(self._log + 1)
245                ]
246            )
247            + "\n]"
248        )
249
250    def __getitem__(self, k: int) -> T:
251        assert (
252            -self._n <= k < self._n
253        ), f"IndexError: {self.__class__.__name__}.__getitem__({k}), n={self._n}"
254        return self.get(k)
255
256    def __setitem__(self, k: int, v: T):
257        assert (
258            -self._n <= k < self._n
259        ), f"IndexError: {self.__class__.__name__}.__setitem__{k}, {v}), n={self._n}"
260        self.set(k, v)
261
262    def __len__(self) -> int:
263        return self._n
264
265    def __str__(self) -> str:
266        return str(self.tolist())
267
268    def __repr__(self) -> str:
269        return f"{self.__class__.__name__}({self})"
270from typing import Optional, Final
271import random
272import string
273
274_titan_pylib_HashString_MOD: Final[int] = (1 << 61) - 1
275_titan_pylib_HashString_DIC: Final[dict[str, int]] = {
276    c: i for i, c in enumerate(string.ascii_lowercase, 1)
277}
278_titan_pylib_HashString_MASK30: Final[int] = (1 << 30) - 1
279_titan_pylib_HashString_MASK31: Final[int] = (1 << 31) - 1
280_titan_pylib_HashString_MASK61: Final[int] = _titan_pylib_HashString_MOD
281
282
283class HashStringBase:
284    """HashStringのベースクラスです。"""
285
286    def __init__(self, n: int = 0, base: int = -1, seed: Optional[int] = None) -> None:
287        """
288        :math:`O(n)` です。
289
290        Args:
291            n (int): 文字列の長さの上限です。上限を超えても問題ありません。
292            base (int, optional): Defaults to -1.
293            seed (Optional[int], optional): Defaults to None.
294        """
295        rand = random.Random(seed)
296        base = rand.randint(37, 10**9) if base < 0 else base
297        powb = [1] * (n + 1)
298        invb = [1] * (n + 1)
299        invbpow = pow(base, -1, _titan_pylib_HashString_MOD)
300        for i in range(1, n + 1):
301            powb[i] = HashStringBase.get_mul(powb[i - 1], base)
302            invb[i] = HashStringBase.get_mul(invb[i - 1], invbpow)
303        self.n = n
304        self.base = base
305        self.invpow = invbpow
306        self.powb = powb
307        self.invb = invb
308
309    @staticmethod
310    def get_mul(a: int, b: int) -> int:
311        au = a >> 31
312        ad = a & _titan_pylib_HashString_MASK31
313        bu = b >> 31
314        bd = b & _titan_pylib_HashString_MASK31
315        mid = ad * bu + au * bd
316        midu = mid >> 30
317        midd = mid & _titan_pylib_HashString_MASK30
318        return HashStringBase.get_mod(au * bu * 2 + midu + (midd << 31) + ad * bd)
319
320    @staticmethod
321    def get_mod(x: int) -> int:
322        xu = x >> 61
323        xd = x & _titan_pylib_HashString_MASK61
324        res = xu + xd
325        if res >= _titan_pylib_HashString_MOD:
326            res -= _titan_pylib_HashString_MOD
327        return res
328
329    def extend(self, cap: int) -> None:
330        pre_cap = len(self.powb)
331        powb, invb = self.powb, self.invb
332        powb += [0] * cap
333        invb += [0] * cap
334        invbpow = pow(self.base, -1, _titan_pylib_HashString_MOD)
335        for i in range(pre_cap, pre_cap + cap):
336            powb[i] = HashStringBase.get_mul(powb[i - 1], self.base)
337            invb[i] = HashStringBase.get_mul(invb[i - 1], invbpow)
338
339    def get_cap(self) -> int:
340        return len(self.powb)
341
342    def unite(self, h1: int, h2: int, k: int) -> int:
343        # len(h2) == k
344        # h1 <- h2
345        if k >= self.get_cap():
346            self.extend(k - self.get_cap() + 1)
347        return self.get_mod(self.get_mul(h1, self.powb[k]) + h2)
348
349
350class HashString:
351
352    def __init__(self, hsb: HashStringBase, s: str, update: bool = False) -> None:
353        """ロリハを構築します。
354        :math:`O(n)` です。
355
356        Args:
357            hsb (HashStringBase): ベースクラスです。
358            s (str): ロリハを構築する文字列です。
359            update (bool, optional): ``update=True`` のとき、1点更新が可能になります。
360        """
361        n = len(s)
362        data = [0] * n
363        acc = [0] * (n + 1)
364        if n >= hsb.get_cap():
365            hsb.extend(n - hsb.get_cap() + 1)
366        powb = hsb.powb
367        for i, c in enumerate(s):
368            data[i] = hsb.get_mul(powb[n - i - 1], _titan_pylib_HashString_DIC[c])
369            acc[i + 1] = hsb.get_mod(acc[i] + data[i])
370        self.hsb = hsb
371        self.n = n
372        self.acc = acc
373        self.used_seg = False
374        if update:
375            self.seg = SegmentTree(
376                data, lambda s, t: (s + t) % _titan_pylib_HashString_MOD, 0
377            )
378
379    def get(self, l: int, r: int) -> int:
380        """``s[l, r)`` のハッシュ値を返します。
381        1点更新処理後は :math:`O(\\log{n})` 、そうでなければ :math:`O(1)` です。
382
383        Args:
384            l (int): インデックスです。
385            r (int): インデックスです。
386
387        Returns:
388            int: ハッシュ値です。
389        """
390        assert 0 <= l <= r <= self.n
391        if self.used_seg:
392            return self.hsb.get_mul(self.seg.prod(l, r), self.hsb.invb[self.n - r])
393        return self.hsb.get_mul(
394            self.hsb.get_mod(self.acc[r] - self.acc[l]), self.hsb.invb[self.n - r]
395        )
396
397    def __getitem__(self, k: int) -> int:
398        """``s[k]`` のハッシュ値を返します。
399        1点更新処理後は :math:`O(\\log{n})` 、そうでなければ :math:`O(1)` です。
400
401        Args:
402            k (int): インデックスです。
403
404        Returns:
405            int: ハッシュ値です。
406        """
407        return self.get(k, k + 1)
408
409    def set(self, k: int, c: str) -> None:
410        """`k` 番目の文字を `c` に更新します。
411        :math:`O(\\log{n})` です。また、今後の ``get()`` が :math:`O(\\log{n})` になります。
412
413        Args:
414            k (int): インデックスです。
415            c (str): 更新する文字です。
416        """
417        self.used_seg = True
418        self.seg[k] = self.hsb.get_mul(
419            self.hsb.powb[self.n - k - 1], _titan_pylib_HashString_DIC[c]
420        )
421
422    def __setitem__(self, k: int, c: str) -> None:
423        return self.set(k, c)
424
425    def __len__(self):
426        return self.n
427
428    def get_lcp(self) -> list[int]:
429        """lcp配列を返します。
430        :math:`O(n\\log{n})` です。
431        """
432        a = [0] * self.n
433        memo = [-1] * (self.n + 1)
434        for i in range(self.n):
435            ok, ng = 0, self.n - i + 1
436            while ng - ok > 1:
437                mid = (ok + ng) >> 1
438                if memo[mid] == -1:
439                    memo[mid] = self.get(0, mid)
440                if memo[mid] == self.get(i, i + mid):
441                    ok = mid
442                else:
443                    ng = mid
444            a[i] = ok
445        return a
446from typing import Optional
447import random
448
449
450class MultiHashStringBase:
451
452    def __init__(
453        self,
454        n: int,
455        base_cnt: int = 1,
456        base_list: list[int] = [],
457        seed: Optional[int] = None,
458    ) -> None:
459        if seed is None:
460            seed = random.randint(0, 10**9)
461        assert (
462            base_cnt > 0
463        ), f"ValueError: {self.__class__.__name__} base_cnt must be > 0"
464        base_list = (
465            base_list
466            if len(base_list) == base_cnt
467            else [random.randint(37, 10**9) for _ in range(base_cnt)]
468        )
469        hsb = tuple(HashStringBase(n, base_list[i]) for i in range(base_cnt))
470        self.hsb = hsb
471
472
473class MultiHashString:
474
475    def __init__(self, hsb: MultiHashStringBase, s: str, update: bool = False) -> None:
476        self.hsb = hsb
477        self.hs = tuple(HashString(hsb, s, update=update) for hsb in self.hsb.hsb)
478
479    def get(self, l: int, r: int) -> tuple[int]:
480        return tuple(hs.get(l, r) for hs in self.hs)
481
482    def __getitem__(self, k: int) -> tuple[int]:
483        return self.get(k, k + 1)
484
485    def set(self, k: int, c: str) -> None:
486        for hs in self.hs:
487            hs.set(k, c)
488
489    def __setitem__(self, k: int, c: str) -> None:
490        self.set(k, c)

仕様

class MultiHashString(hsb: MultiHashStringBase, s: str, update: bool = False)[source]

Bases: object

get(l: int, r: int) tuple[int][source]
set(k: int, c: str) None[source]
class MultiHashStringBase(n: int, base_cnt: int = 1, base_list: list[int] = [], seed: int | None = None)[source]

Bases: object