static_range_mode_query

ソースコード

from titan_pylib.data_structures.static_array_query.static_range_mode_query import StaticRangeModeQuery

View on GitHub

展開済みコード

  1# from titan_pylib.data_structures.static_array_query.static_range_mode_query import StaticRangeModeQuery
  2from typing import Generic, Iterable, TypeVar
  3
  4T = TypeVar("T")
  5
  6
  7class StaticRangeModeQuery(Generic[T]):
  8    """静的な列に対する区間最頻値クエリに答えます。
  9    <構築 :math:`O(n\\sqrt{n})` , 空間 :math:`O(n)` , クエリ :math:`O(\\sqrt{n)})` >
 10
 11    参考: https://noshi91.hatenablog.com/entry/2020/10/26/140105
 12    """
 13
 14    @staticmethod
 15    def _sort_unique(a: list[T]) -> list[T]:
 16        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
 17            a = sorted(a)
 18            new_a = [a[0]]
 19            for elm in a:
 20                if new_a[-1] == elm:
 21                    continue
 22                new_a.append(elm)
 23            a = new_a
 24        return a
 25
 26    def __init__(self, a: Iterable[T], compress: bool = True) -> None:
 27        """``a`` から ``StaticRangeModeQuery`` を構築します。
 28        :math:`O(n \\sqrt{n})` です。
 29
 30        Args:
 31          a (Iterable[T]):
 32          compress (bool, optional): ``False`` なら座標圧縮しません。
 33        """
 34
 35        a: list[T] = list(a)
 36        self.to_origin: list[T] = []
 37        self.compress: bool = compress
 38        if compress:
 39            self.to_origin = StaticRangeModeQuery._sort_unique(a)
 40            to_zaatsu: dict[T, int] = {x: i for i, x in enumerate(self.to_origin)}
 41            self.a: list[int] = [to_zaatsu[x] for x in a]
 42        else:
 43            assert max(a) < len(self.a), "ValueError"
 44            self.a: list[int] = a
 45
 46        self.n: int = len(self.a)
 47        self.u: int = max(self.a) + 1
 48        self.size: int = int(self.n**0.5) + 1
 49        self.bucket_cnt: int = (self.n + self.size - 1) // self.size
 50        self.data: list[list[int]] = [
 51            self.a[k * self.size : (k + 1) * self.size] for k in range(self.bucket_cnt)
 52        ]
 53
 54        # (freq, val)
 55        self.bucket_data: list[list[tuple[int, int]]] = [
 56            [(0, -1)] * (self.bucket_cnt + 1) for _ in range(self.bucket_cnt + 1)
 57        ]
 58        self._calc_all_blocks()
 59
 60        self.indx: list[list[int]] = [[] for _ in range(self.u)]
 61        self.inv_indx: list[int] = [-1] * self.n
 62        self._calc_index()
 63
 64    def _calc_all_blocks(self) -> None:
 65        """``bucket_data`` を計算します
 66        :math:`O(n \\sqrt{n})` です。
 67
 68        bucket_data[i][j] := data[i:j] の (freq, val)
 69        """
 70        data, bucket_data = self.data, self.bucket_data
 71        freqs = [0] * self.u
 72        for i in range(self.bucket_cnt):
 73            freq, val = -1, -1
 74            for j in range(i + 1, self.bucket_cnt + 1):
 75                for x in data[j - 1]:
 76                    freqs[x] += 1
 77                    if freqs[x] > freq:
 78                        freq, val = freqs[x], x
 79                bucket_data[i][j] = (freq, val)
 80            for j in range(i + 1, self.bucket_cnt + 1):
 81                for x in data[j - 1]:
 82                    freqs[x] = 0
 83
 84    def _calc_index(self):
 85        """``indx``, ``inv_indx`` を計算します
 86        :math:`O(n)` です。
 87
 88        indx[x]: 値 x の、 a におけるインデックス(昇順)
 89        inv_indx[i]: aにおける位置i の、indx[a[i]] でのインデックス
 90        """
 91        indx, inv_indx = self.indx, self.inv_indx
 92        for i, e in enumerate(self.a):
 93            inv_indx[i] = len(indx[e])
 94            indx[e].append(i)
 95
 96    def mode(self, l: int, r: int) -> tuple[T, int]:
 97        """区間 ``[l, r)`` の最頻値とその頻度を返します。
 98
 99        Args:
100          l (int):
101          r (int):
102
103        Returns:
104          tuple[T, int]: (最頻値, 頻度) のタプルです。
105        """
106        assert 0 <= l < r <= self.n
107        L, R = l, r
108        k1 = l // self.size
109        k2 = r // self.size
110        l -= k1 * self.size
111        r -= k2 * self.size
112
113        freq, val = 0, -1
114
115        if k1 == k2:
116            a, indx, inv_indx = self.a, self.indx, self.inv_indx
117            for i in range(L, R):
118                x = a[i]
119                k = inv_indx[i]
120                freq_cand = freq + 1
121                while (
122                    k + freq_cand - 1 < len(indx[x]) and indx[x][k + freq_cand - 1] < R
123                ):
124                    freq, val = freq_cand, x
125                    freq_cand += 1
126
127        else:
128            data, indx, inv_indx = self.data, self.indx, self.inv_indx
129
130            freq, val = self.bucket_data[k1 + 1][k2]
131
132            # left
133            for i in range(l, len(data[k1])):
134                x = data[k1][i]
135                k = inv_indx[k1 * self.size + i]
136                freq_cand = freq + 1
137                while (
138                    k + freq_cand - 1 < len(indx[x]) and indx[x][k + freq_cand - 1] < R
139                ):
140                    freq, val = freq_cand, x
141                    freq_cand += 1
142
143            # right
144            for i in range(r):
145                x = data[k2][i]
146                k = inv_indx[k2 * self.size + i]
147                freq_cand = freq + 1
148                while 0 <= k - (freq_cand - 1) and L <= indx[x][k - (freq_cand - 1)]:
149                    freq, val = freq_cand, x
150                    freq_cand += 1
151
152        val = self.to_origin[val] if self.compress else val
153        return val, freq

仕様

class StaticRangeModeQuery(a: Iterable[T], compress: bool = True)[source]

Bases: Generic[T]

静的な列に対する区間最頻値クエリに答えます。 <構築 \(O(n\sqrt{n})\) , 空間 \(O(n)\) , クエリ \(O(\sqrt{n})\) >

参考: https://noshi91.hatenablog.com/entry/2020/10/26/140105

mode(l: int, r: int) -> tuple[T, int][source]

区間 [l, r) の最頻値とその頻度を返します。

Parameters:
  • l (int)

  • r (int)

Returns:

(最頻値, 頻度) のタプルです。

Return type:

tuple[T, int]