min_max_multiset

ソースコード

from titan_pylib.data_structures.set.min_max_multiset import MinMaxMultiset

view on github

展開済みコード

  1# from titan_pylib.data_structures.set.min_max_multiset import MinMaxMultiset
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9# from titan_pylib.data_structures.heap.double_ended_heap import DoubleEndedHeap
 10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
 11from typing import Generic, Iterable, TypeVar
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class DoubleEndedHeap(Generic[T]):
 17    """
 18    - [両端優先度付きキューのInterval-Heap実装](https://natsugiri.hatenablog.com/entry/2016/10/10/035445)
 19    - [Double-ended priority queue(wikipedia)](https://en.wikipedia.org/wiki/Double-ended_priority_queue)
 20    """
 21
 22    def __init__(self, a: Iterable[T] = []) -> None:
 23        """構築します。
 24        :math:`O(n)` です。
 25
 26        Args:
 27          a (Iterable[T], optional): 構築元の配列です。
 28        """
 29        self._data = list(a)
 30        self._heapify()
 31
 32    def _heapify(self) -> None:
 33        n = len(self._data)
 34        for i in range(n - 1, -1, -1):
 35            if i & 1 and self._data[i - 1] < self._data[i]:
 36                self._data[i - 1], self._data[i] = self._data[i], self._data[i - 1]
 37            k = self._down(i)
 38            self._up(k, i)
 39
 40    def add(self, key: T) -> None:
 41        """``key`` を1つ追加します。
 42        :math:`O(\\log{n})` です。
 43        """
 44        self._data.append(key)
 45        self._up(len(self._data) - 1)
 46
 47    def pop_min(self) -> T:
 48        """最小の要素を削除して返します。
 49        :math:`O(\\log{n})` です。
 50        """
 51        if len(self._data) < 3:
 52            res = self._data.pop()
 53        else:
 54            self._data[1], self._data[-1] = self._data[-1], self._data[1]
 55            res = self._data.pop()
 56            k = self._down(1)
 57            self._up(k)
 58        return res
 59
 60    def pop_max(self) -> T:
 61        """最大の要素を削除して返します。
 62        :math:`O(\\log{n})` です。
 63        """
 64        if len(self._data) < 2:
 65            res = self._data.pop()
 66        else:
 67            self._data[0], self._data[-1] = self._data[-1], self._data[0]
 68            res = self._data.pop()
 69            self._up(self._down(0))
 70        return res
 71
 72    def get_min(self) -> T:
 73        """最小の要素を返します。
 74        :math:`O(1)` です。
 75        """
 76        return self._data[0] if len(self._data) < 2 else self._data[1]
 77
 78    def get_max(self) -> T:
 79        """最大の要素を返します。
 80        :math:`O(1)` です。
 81        """
 82        return self._data[0]
 83
 84    def __len__(self) -> int:
 85        return len(self._data)
 86
 87    def __bool__(self) -> bool:
 88        return len(self._data) > 0
 89
 90    def _parent(self, k: int) -> int:
 91        return ((k >> 1) - 1) & ~1
 92
 93    def _down(self, k: int) -> int:
 94        n = len(self._data)
 95        if k & 1:
 96            while k << 1 | 1 < n:
 97                c = 2 * k + 3
 98                if n <= c or self._data[c - 2] < self._data[c]:
 99                    c -= 2
100                if c < n and self._data[c] < self._data[k]:
101                    self._data[k], self._data[c] = self._data[c], self._data[k]
102                    k = c
103                else:
104                    break
105        else:
106            while 2 * k + 2 < n:
107                c = 2 * k + 4
108                if n <= c or self._data[c] < self._data[c - 2]:
109                    c -= 2
110                if c < n and self._data[k] < self._data[c]:
111                    self._data[k], self._data[c] = self._data[c], self._data[k]
112                    k = c
113                else:
114                    break
115        return k
116
117    def _up(self, k: int, root: int = 1) -> int:
118        if (k | 1) < len(self._data) and self._data[k & ~1] < self._data[k | 1]:
119            self._data[k & ~1], self._data[k | 1] = (
120                self._data[k | 1],
121                self._data[k & ~1],
122            )
123            k ^= 1
124        while root < k:
125            p = self._parent(k)
126            if not self._data[p] < self._data[k]:
127                break
128            self._data[p], self._data[k] = self._data[k], self._data[p]
129            k = p
130        while root < k:
131            p = self._parent(k) | 1
132            if not self._data[k] < self._data[p]:
133                break
134            self._data[p], self._data[k] = self._data[k], self._data[p]
135            k = p
136        return k
137
138    def tolist(self) -> list[T]:
139        return sorted(self._data)
140
141    def __str__(self) -> str:
142        return str(self.tolist())
143
144    def __repr__(self) -> str:
145        return f"{self.__class__.__name__}({self})"
146from typing import Generic, Iterable, TypeVar
147
148T = TypeVar("T", bound=SupportsLessThan)
149
150
class MinMaxMultiset(Generic[T]):
    """Multiset that gives access to both its smallest and largest element.

    ``data`` maps every distinct key to its multiplicity, ``heap`` is a
    ``DoubleEndedHeap`` over keys and ``len`` is the total number of stored
    elements. Deletion is lazy: ``discard`` only updates ``data``, and heap
    entries whose key is no longer in ``data`` are dropped when they reach
    either end of the heap.
    """

    def __init__(self, a: Iterable[T] = ()) -> None:
        """Build the multiset from ``a``.

        Args:
          a (Iterable[T], optional): Initial elements (hashable and
            comparable with ``<``). They are copied into a new list.
        """
        a = list(a)
        data = {}
        for x in a:
            data[x] = data.get(x, 0) + 1
        self.data = data
        self.heap = DoubleEndedHeap(a)
        self.len = len(a)

    def add(self, key: T, val: int = 1) -> None:
        """Insert ``val`` occurrences of ``key``.

        Nothing happens when ``val`` is 0. ``val`` is not validated; it is
        expected to be positive.
        """
        if val == 0:
            return
        if key in self.data:
            # Every key in ``data`` already has an entry in the heap, so a
            # second entry would only have to be skipped later as stale.
            self.data[key] += val
        else:
            self.heap.add(key)
            self.data[key] = val
        self.len += val

    def discard(self, key: T, val: int = 1) -> bool:
        """Remove up to ``val`` occurrences of ``key``.

        If ``val`` is at least the current multiplicity, the key is removed
        entirely. The heap is not touched here; stale entries are skipped
        lazily by the min/max accessors.

        Returns:
          bool: ``False`` if ``key`` was not present, ``True`` otherwise.
        """
        if key not in self.data:
            return False
        cnt = self.data[key]
        if val < cnt:
            self.len -= val
            self.data[key] -= val
        else:
            self.len -= cnt
            del self.data[key]
        return True

    def pop_min(self) -> T:
        """Remove one occurrence of the smallest element and return it.

        Raises:
          IndexError: If the multiset is empty.
        """
        v = self.get_min()
        if self.data[v] == 1:
            # Last occurrence: drop the key from the heap as well.
            self.heap.pop_min()
            del self.data[v]
        else:
            self.data[v] -= 1
        self.len -= 1
        return v

    def pop_max(self) -> T:
        """Remove one occurrence of the largest element and return it.

        Raises:
          IndexError: If the multiset is empty.
        """
        v = self.get_max()
        if self.data[v] == 1:
            # Last occurrence: drop the key from the heap as well.
            self.heap.pop_max()
            del self.data[v]
        else:
            self.data[v] -= 1
        self.len -= 1
        return v

    def get_min(self) -> T:
        """Return the smallest element without removing it.

        Stale heap entries (keys already discarded) found at the minimum end
        are popped on the way.

        Raises:
          IndexError: If the multiset is empty.
        """
        while True:
            v = self.heap.get_min()
            if v in self.data:
                return v
            self.heap.pop_min()

    def get_max(self) -> T:
        """Return the largest element without removing it.

        Stale heap entries (keys already discarded) found at the maximum end
        are popped on the way.

        Raises:
          IndexError: If the multiset is empty.
        """
        while True:
            v = self.heap.get_max()
            if v in self.data:
                return v
            self.heap.pop_max()

    def count(self, key: T) -> int:
        """Return the multiplicity of ``key`` (0 when it is absent)."""
        return self.data.get(key, 0)

    def tolist(self) -> list[T]:
        """Return all elements, with repetitions, in ascending order."""
        return sorted(k for k, v in self.data.items() for _ in range(v))

    def len_elm(self) -> int:
        """Return the number of distinct keys."""
        return len(self.data)

    def __contains__(self, key: T) -> bool:
        return key in self.data

    def __len__(self) -> int:
        return self.len

    def __str__(self) -> str:
        return "{" + ", ".join(map(str, self.tolist())) + "}"

    def __repr__(self) -> str:
        return "MinMaxMultiset([" + ", ".join(map(str, self.tolist())) + "])"

仕様

class MinMaxMultiset(a: Iterable[T] = [])[source]

Bases: Generic[T]

add(key: T, val: int = 1) → None [source]
count(key: T) → int [source]
discard(key: T, val: int = 1) → bool [source]
get_max() → T [source]
get_min() → T [source]
len_elm() → int [source]
pop_max() → T [source]
pop_min() → T [source]
tolist() → list[T] [source]