Source code for titan_pylib.data_structures.bst_base.bst_multiset_array_base

  1from typing import TypeVar, Generic, Optional
  2
  3T = TypeVar("T")
  4BST = TypeVar("BST")
  5# protcolで、key,val,left,right を規定
  6
  7
[docs] 8class BSTMultisetArrayBase(Generic[BST, T]): 9 10 @staticmethod 11 def _rle(a: list[T]) -> tuple[list[T], list[int]]: 12 keys, vals = [a[0]], [1] 13 for i, elm in enumerate(a): 14 if i == 0: 15 continue 16 if elm == keys[-1]: 17 vals[-1] += 1 18 continue 19 keys.append(elm) 20 vals.append(1) 21 return keys, vals 22
[docs] 23 @staticmethod 24 def count(bst: BST, key: T) -> int: 25 keys, left, right = bst.key, bst.left, bst.right 26 node = bst.root 27 while node: 28 if keys[node] == key: 29 return bst.val[node] 30 node = left[node] if key < keys[node] else right[node] 31 return 0
32
[docs] 33 @staticmethod 34 def le(bst: BST, key: T) -> Optional[T]: 35 keys, left, right = bst.key, bst.left, bst.right 36 res = None 37 node = bst.root 38 while node: 39 if key == keys[node]: 40 res = key 41 break 42 if key < keys[node]: 43 node = left[node] 44 else: 45 res = keys[node] 46 node = right[node] 47 return res
48
[docs] 49 @staticmethod 50 def lt(bst: BST, key: T) -> Optional[T]: 51 keys, left, right = bst.key, bst.left, bst.right 52 res = None 53 node = bst.root 54 while node: 55 if key <= keys[node]: 56 node = left[node] 57 else: 58 res = keys[node] 59 node = right[node] 60 return res
61
[docs] 62 @staticmethod 63 def ge(bst: BST, key: T) -> Optional[T]: 64 keys, left, right = bst.key, bst.left, bst.right 65 res = None 66 node = bst.root 67 while node: 68 if key == keys[node]: 69 res = key 70 break 71 if key < keys[node]: 72 res = keys[node] 73 node = left[node] 74 else: 75 node = right[node] 76 return res
77
[docs] 78 @staticmethod 79 def gt(bst: BST, key: T) -> Optional[T]: 80 keys, left, right = bst.key, bst.left, bst.right 81 res = None 82 node = bst.root 83 while node: 84 if key < keys[node]: 85 res = keys[node] 86 node = left[node] 87 else: 88 node = right[node] 89 return res
90
[docs] 91 @staticmethod 92 def index(bst: BST, key: T) -> int: 93 keys, left, right, vals, valsize = ( 94 bst.key, 95 bst.left, 96 bst.right, 97 bst.val, 98 bst.valsize, 99 ) 100 k = 0 101 node = bst.root 102 while node: 103 if key == keys[node]: 104 if left[node]: 105 k += valsize[left[node]] 106 break 107 if key < keys[node]: 108 node = left[node] 109 else: 110 k += valsize[left[node]] + vals[node] 111 node = right[node] 112 return k
113
[docs] 114 @staticmethod 115 def index_right(bst: BST, key: T) -> int: 116 keys, left, right, vals, valsize = ( 117 bst.key, 118 bst.left, 119 bst.right, 120 bst.val, 121 bst.valsize, 122 ) 123 k = 0 124 node = bst.root 125 while node: 126 if key == keys[node]: 127 k += valsize[left[node]] + vals[node] 128 break 129 if key < keys[node]: 130 node = left[node] 131 else: 132 k += valsize[left[node]] + vals[node] 133 node = right[node] 134 return k
135 136 @staticmethod 137 def _kth_elm(bst: BST, k: int) -> tuple[T, int]: 138 left, right, vals, valsize = bst.left, bst.right, bst.val, bst.valsize 139 if k < 0: 140 k += len(bst) 141 node = bst.root 142 while True: 143 t = vals[node] + valsize[left[node]] 144 if t - vals[node] <= k < t: 145 return bst.key[node], vals[node] 146 if t > k: 147 node = left[node] 148 else: 149 node = right[node] 150 k -= t 151
[docs] 152 @staticmethod 153 def contains(bst: BST, key: T) -> bool: 154 left, right, keys = bst.left, bst.right, bst.key 155 node = bst.root 156 while node: 157 if keys[node] == key: 158 return True 159 node = left[node] if key < keys[node] else right[node] 160 return False
161
[docs] 162 @staticmethod 163 def tolist(bst: BST) -> list[T]: 164 left, right, keys, vals = bst.left, bst.right, bst.key, bst.val 165 node = bst.root 166 stack, a = [], [] 167 while stack or node: 168 if node: 169 stack.append(node) 170 node = left[node] 171 else: 172 node = stack.pop() 173 key = keys[node] 174 for _ in range(vals[node]): 175 a.append(key) 176 node = right[node] 177 return a