min_max_set

ソースコード

from titan_pylib.data_structures.set.min_max_set import MinMaxSet

view on github

展開済みコード

  1# from titan_pylib.data_structures.set.min_max_set import MinMaxSet
  2# from titan_pylib.my_class.supports_less_than import SupportsLessThan
  3from typing import Protocol
  4
  5
  6class SupportsLessThan(Protocol):
  7
  8    def __lt__(self, other) -> bool: ...
  9# from titan_pylib.data_structures.heap.double_ended_heap import DoubleEndedHeap
 10# from titan_pylib.my_class.supports_less_than import SupportsLessThan
 11from typing import Generic, Iterable, TypeVar
 12
 13T = TypeVar("T", bound=SupportsLessThan)
 14
 15
 16class DoubleEndedHeap(Generic[T]):
 17    """
 18    - [両端優先度付きキューのInterval-Heap実装](https://natsugiri.hatenablog.com/entry/2016/10/10/035445)
 19    - [Double-ended priority queue(wikipedia)](https://en.wikipedia.org/wiki/Double-ended_priority_queue)
 20    """
 21
 22    def __init__(self, a: Iterable[T] = []) -> None:
 23        """構築します。
 24        :math:`O(n)` です。
 25
 26        Args:
 27          a (Iterable[T], optional): 構築元の配列です。
 28        """
 29        self._data = list(a)
 30        self._heapify()
 31
 32    def _heapify(self) -> None:
 33        n = len(self._data)
 34        for i in range(n - 1, -1, -1):
 35            if i & 1 and self._data[i - 1] < self._data[i]:
 36                self._data[i - 1], self._data[i] = self._data[i], self._data[i - 1]
 37            k = self._down(i)
 38            self._up(k, i)
 39
 40    def add(self, key: T) -> None:
 41        """``key`` を1つ追加します。
 42        :math:`O(\\log{n})` です。
 43        """
 44        self._data.append(key)
 45        self._up(len(self._data) - 1)
 46
 47    def pop_min(self) -> T:
 48        """最小の要素を削除して返します。
 49        :math:`O(\\log{n})` です。
 50        """
 51        if len(self._data) < 3:
 52            res = self._data.pop()
 53        else:
 54            self._data[1], self._data[-1] = self._data[-1], self._data[1]
 55            res = self._data.pop()
 56            k = self._down(1)
 57            self._up(k)
 58        return res
 59
 60    def pop_max(self) -> T:
 61        """最大の要素を削除して返します。
 62        :math:`O(\\log{n})` です。
 63        """
 64        if len(self._data) < 2:
 65            res = self._data.pop()
 66        else:
 67            self._data[0], self._data[-1] = self._data[-1], self._data[0]
 68            res = self._data.pop()
 69            self._up(self._down(0))
 70        return res
 71
 72    def get_min(self) -> T:
 73        """最小の要素を返します。
 74        :math:`O(1)` です。
 75        """
 76        return self._data[0] if len(self._data) < 2 else self._data[1]
 77
 78    def get_max(self) -> T:
 79        """最大の要素を返します。
 80        :math:`O(1)` です。
 81        """
 82        return self._data[0]
 83
 84    def __len__(self) -> int:
 85        return len(self._data)
 86
 87    def __bool__(self) -> bool:
 88        return len(self._data) > 0
 89
 90    def _parent(self, k: int) -> int:
 91        return ((k >> 1) - 1) & ~1
 92
 93    def _down(self, k: int) -> int:
 94        n = len(self._data)
 95        if k & 1:
 96            while k << 1 | 1 < n:
 97                c = 2 * k + 3
 98                if n <= c or self._data[c - 2] < self._data[c]:
 99                    c -= 2
100                if c < n and self._data[c] < self._data[k]:
101                    self._data[k], self._data[c] = self._data[c], self._data[k]
102                    k = c
103                else:
104                    break
105        else:
106            while 2 * k + 2 < n:
107                c = 2 * k + 4
108                if n <= c or self._data[c] < self._data[c - 2]:
109                    c -= 2
110                if c < n and self._data[k] < self._data[c]:
111                    self._data[k], self._data[c] = self._data[c], self._data[k]
112                    k = c
113                else:
114                    break
115        return k
116
117    def _up(self, k: int, root: int = 1) -> int:
118        if (k | 1) < len(self._data) and self._data[k & ~1] < self._data[k | 1]:
119            self._data[k & ~1], self._data[k | 1] = (
120                self._data[k | 1],
121                self._data[k & ~1],
122            )
123            k ^= 1
124        while root < k:
125            p = self._parent(k)
126            if not self._data[p] < self._data[k]:
127                break
128            self._data[p], self._data[k] = self._data[k], self._data[p]
129            k = p
130        while root < k:
131            p = self._parent(k) | 1
132            if not self._data[k] < self._data[p]:
133                break
134            self._data[p], self._data[k] = self._data[k], self._data[p]
135            k = p
136        return k
137
138    def tolist(self) -> list[T]:
139        return sorted(self._data)
140
141    def __str__(self) -> str:
142        return str(self.tolist())
143
144    def __repr__(self) -> str:
145        return f"{self.__class__.__name__}({self})"
146from typing import Generic, Iterable, TypeVar
147
148T = TypeVar("T", bound=SupportsLessThan)
149
150
151class MinMaxSet(Generic[T]):
152
153    def __init__(self, a: Iterable[T] = []):
154        a = set(a)
155        self.data = a
156        self.heap = DoubleEndedHeap(a)
157
158    def add(self, key: T) -> bool:
159        if key not in self.data:
160            self.heap.add(key)
161            self.data.add(key)
162            return True
163        return False
164
165    def discard(self, key: T) -> bool:
166        if key in self.data:
167            self.data.discard(key)
168            return True
169        return False
170
171    def pop_min(self) -> T:
172        while True:
173            v = self.heap.pop_min()
174            if v in self.data:
175                self.data.discard(v)
176                return v
177
178    def pop_max(self) -> T:
179        while True:
180            v = self.heap.pop_max()
181            if v in self.data:
182                self.data.discard(v)
183                return v
184
185    def get_min(self) -> T:
186        while True:
187            v = self.heap.get_min()
188            if v in self.data:
189                return v
190            else:
191                self.heap.pop_min()
192
193    def get_max(self) -> T:
194        while True:
195            v = self.heap.get_max()
196            if v in self.data:
197                return v
198            else:
199                self.heap.pop_max()
200
201    def tolist(self) -> list[T]:
202        return sorted(self.data)
203
204    def __contains__(self, key: T):
205        return key in self.data
206
207    def __getitem__(self, k: int):  # 末尾と先頭のみ
208        if k == -1 or k == len(self.data) - 1:
209            return self.get_max()
210        elif k == 0:
211            return self.get_min()
212        raise IndexError
213
214    def __len__(self):
215        return len(self.data)
216
217    def __str__(self):
218        return "{" + ", ".join(map(str, sorted(self.data))) + "}"
219
220    def __repr__(self):
221        return f"MinMaxSet({self})"

仕様

class MinMaxSet(a: Iterable[T] = [])[source]

Bases: Generic[T]

add(key: T) bool[source]
discard(key: T) bool[source]
get_max() T[source]
get_min() T[source]
pop_max() T[source]
pop_min() T[source]
tolist() list[T][source]