bst_multiset_array_base

ソースコード

from titan_pylib.data_structures.bst_base.bst_multiset_array_base import BSTMultisetArrayBase

view on github

展開済みコード

  1# from titan_pylib.data_structures.bst_base.bst_multiset_array_base import BSTMultisetArrayBase
  2from typing import TypeVar, Generic, Optional
  3
  4T = TypeVar("T")
  5BST = TypeVar("BST")
  6# protcolで、key,val,left,right を規定
  7
  8
  9class BSTMultisetArrayBase(Generic[BST, T]):
 10
 11    @staticmethod
 12    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
 13        keys, vals = [a[0]], [1]
 14        for i, elm in enumerate(a):
 15            if i == 0:
 16                continue
 17            if elm == keys[-1]:
 18                vals[-1] += 1
 19                continue
 20            keys.append(elm)
 21            vals.append(1)
 22        return keys, vals
 23
 24    @staticmethod
 25    def count(bst: BST, key: T) -> int:
 26        keys, left, right = bst.key, bst.left, bst.right
 27        node = bst.root
 28        while node:
 29            if keys[node] == key:
 30                return bst.val[node]
 31            node = left[node] if key < keys[node] else right[node]
 32        return 0
 33
 34    @staticmethod
 35    def le(bst: BST, key: T) -> Optional[T]:
 36        keys, left, right = bst.key, bst.left, bst.right
 37        res = None
 38        node = bst.root
 39        while node:
 40            if key == keys[node]:
 41                res = key
 42                break
 43            if key < keys[node]:
 44                node = left[node]
 45            else:
 46                res = keys[node]
 47                node = right[node]
 48        return res
 49
 50    @staticmethod
 51    def lt(bst: BST, key: T) -> Optional[T]:
 52        keys, left, right = bst.key, bst.left, bst.right
 53        res = None
 54        node = bst.root
 55        while node:
 56            if key <= keys[node]:
 57                node = left[node]
 58            else:
 59                res = keys[node]
 60                node = right[node]
 61        return res
 62
 63    @staticmethod
 64    def ge(bst: BST, key: T) -> Optional[T]:
 65        keys, left, right = bst.key, bst.left, bst.right
 66        res = None
 67        node = bst.root
 68        while node:
 69            if key == keys[node]:
 70                res = key
 71                break
 72            if key < keys[node]:
 73                res = keys[node]
 74                node = left[node]
 75            else:
 76                node = right[node]
 77        return res
 78
 79    @staticmethod
 80    def gt(bst: BST, key: T) -> Optional[T]:
 81        keys, left, right = bst.key, bst.left, bst.right
 82        res = None
 83        node = bst.root
 84        while node:
 85            if key < keys[node]:
 86                res = keys[node]
 87                node = left[node]
 88            else:
 89                node = right[node]
 90        return res
 91
 92    @staticmethod
 93    def index(bst: BST, key: T) -> int:
 94        keys, left, right, vals, valsize = (
 95            bst.key,
 96            bst.left,
 97            bst.right,
 98            bst.val,
 99            bst.valsize,
100        )
101        k = 0
102        node = bst.root
103        while node:
104            if key == keys[node]:
105                if left[node]:
106                    k += valsize[left[node]]
107                break
108            if key < keys[node]:
109                node = left[node]
110            else:
111                k += valsize[left[node]] + vals[node]
112                node = right[node]
113        return k
114
115    @staticmethod
116    def index_right(bst: BST, key: T) -> int:
117        keys, left, right, vals, valsize = (
118            bst.key,
119            bst.left,
120            bst.right,
121            bst.val,
122            bst.valsize,
123        )
124        k = 0
125        node = bst.root
126        while node:
127            if key == keys[node]:
128                k += valsize[left[node]] + vals[node]
129                break
130            if key < keys[node]:
131                node = left[node]
132            else:
133                k += valsize[left[node]] + vals[node]
134                node = right[node]
135        return k
136
137    @staticmethod
138    def _kth_elm(bst: BST, k: int) -> tuple[T, int]:
139        left, right, vals, valsize = bst.left, bst.right, bst.val, bst.valsize
140        if k < 0:
141            k += len(bst)
142        node = bst.root
143        while True:
144            t = vals[node] + valsize[left[node]]
145            if t - vals[node] <= k < t:
146                return bst.key[node], vals[node]
147            if t > k:
148                node = left[node]
149            else:
150                node = right[node]
151                k -= t
152
153    @staticmethod
154    def contains(bst: BST, key: T) -> bool:
155        left, right, keys = bst.left, bst.right, bst.key
156        node = bst.root
157        while node:
158            if keys[node] == key:
159                return True
160            node = left[node] if key < keys[node] else right[node]
161        return False
162
163    @staticmethod
164    def tolist(bst: BST) -> list[T]:
165        left, right, keys, vals = bst.left, bst.right, bst.key, bst.val
166        node = bst.root
167        stack, a = [], []
168        while stack or node:
169            if node:
170                stack.append(node)
171                node = left[node]
172            else:
173                node = stack.pop()
174                key = keys[node]
175                for _ in range(vals[node]):
176                    a.append(key)
177                node = right[node]
178        return a

仕様

class BSTMultisetArrayBase[source]

Bases: Generic[BST, T]

static contains(bst: BST, key: T) → bool [source]
static count(bst: BST, key: T) → int [source]
static ge(bst: BST, key: T) → T | None [source]
static gt(bst: BST, key: T) → T | None [source]
static index(bst: BST, key: T) → int [source]
static index_right(bst: BST, key: T) → int [source]
static le(bst: BST, key: T) → T | None [source]
static lt(bst: BST, key: T) → T | None [source]
static tolist(bst: BST) → list[T] [source]