from typing import TypeVar, Generic, Optional

T = TypeVar("T")
BST = TypeVar("BST")
# TODO: describe the attributes every ``bst`` argument must provide (root, key, val,
# left, right, valsize, __len__) with a typing.Protocol instead of a bare TypeVar.


class BSTMultisetArrayBase(Generic[BST, T]):
    """Static query helpers shared by array-backed BST multiset implementations.

    Every method takes the tree as its first argument ``bst`` and reads only
    these attributes from it:

    - ``bst.root``: index of the root node.
    - ``bst.key[i]``: key stored at node ``i``.
    - ``bst.val[i]``: multiplicity of that key.
    - ``bst.left[i]`` / ``bst.right[i]``: child indices; a falsy index
      (e.g. ``0``) means "no child" and terminates the descent loops.
    - ``bst.valsize[i]``: total multiplicity of the subtree rooted at ``i``.
      The size arithmetic also indexes ``valsize`` with the empty-child
      index, so that entry must be ``0``.
    - ``len(bst)``: total number of elements (only used by ``_kth_elm``).
    """

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a``.

        Returns ``(keys, vals)``: ``keys`` holds the first element of each run
        of consecutive equal elements and ``vals`` the matching run lengths.
        An empty input yields ``([], [])``.
        """
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def count(bst: BST, key: T) -> int:
        """Return the multiplicity of ``key`` in ``bst`` (0 if absent)."""
        keys, left, right = bst.key, bst.left, bst.right
        node = bst.root
        while node:
            if keys[node] == key:
                return bst.val[node]
            node = left[node] if key < keys[node] else right[node]
        return 0

    @staticmethod
    def le(bst: BST, key: T) -> Optional[T]:
        """Return the largest element ``<= key``, or ``None`` if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                node = left[node]
            else:
                # keys[node] < key: best candidate so far; look for a larger one.
                res = keys[node]
                node = right[node]
        return res

    @staticmethod
    def lt(bst: BST, key: T) -> Optional[T]:
        """Return the largest element ``< key``, or ``None`` if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    @staticmethod
    def ge(bst: BST, key: T) -> Optional[T]:
        """Return the smallest element ``>= key``, or ``None`` if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key == keys[node]:
                res = key
                break
            if key < keys[node]:
                # key < keys[node]: best candidate so far; look for a smaller one.
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    @staticmethod
    def gt(bst: BST, key: T) -> Optional[T]:
        """Return the smallest element ``> key``, or ``None`` if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    @staticmethod
    def index(bst: BST, key: T) -> int:
        """Return the number of elements strictly less than ``key``.

        Multiplicities are counted, so this is the position ``key`` would take
        if inserted before any equal elements.
        """
        keys, left, right, vals, valsize = (
            bst.key,
            bst.left,
            bst.right,
            bst.val,
            bst.valsize,
        )
        k = 0
        node = bst.root
        while node:
            if key == keys[node]:
                # Only the left subtree holds elements smaller than key.
                k += valsize[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                # This node and its whole left subtree are smaller than key.
                k += valsize[left[node]] + vals[node]
                node = right[node]
        return k

    @staticmethod
    def index_right(bst: BST, key: T) -> int:
        """Return the number of elements less than or equal to ``key``."""
        keys, left, right, vals, valsize = (
            bst.key,
            bst.left,
            bst.right,
            bst.val,
            bst.valsize,
        )
        k = 0
        node = bst.root
        while node:
            if key == keys[node]:
                k += valsize[left[node]] + vals[node]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += valsize[left[node]] + vals[node]
                node = right[node]
        return k

    @staticmethod
    def _kth_elm(bst: BST, k: int) -> tuple[T, int]:
        """Return ``(key, multiplicity)`` of the ``k``-th smallest element.

        ``k`` is 0-based and counts multiplicities; a negative ``k`` counts
        from the end, as with list indexing.

        Raises:
            IndexError: if ``k`` is out of range (including an empty tree).
        """
        left, right, vals, valsize = bst.left, bst.right, bst.val, bst.valsize
        if k < 0:
            k += len(bst)
        node = bst.root
        # Stopping on the nil child guards against out-of-range k: without it
        # the walk reaches the nil node, where t == 0, and never terminates.
        while node:
            t = vals[node] + valsize[left[node]]
            # Positions [t - vals[node], t) of this subtree belong to this node.
            if t - vals[node] <= k < t:
                return bst.key[node], vals[node]
            if t > k:
                node = left[node]
            else:
                node = right[node]
                k -= t
        raise IndexError("k-th element index out of range")

    @staticmethod
    def contains(bst: BST, key: T) -> bool:
        """Return ``True`` if a node with key ``key`` exists in ``bst``."""
        left, right, keys = bst.left, bst.right, bst.key
        node = bst.root
        while node:
            if keys[node] == key:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    @staticmethod
    def tolist(bst: BST) -> list[T]:
        """Return every element in ascending order, repeated by multiplicity."""
        left, right, keys, vals = bst.left, bst.right, bst.key, bst.val
        node = bst.root
        stack = []
        a: list[T] = []
        # Iterative in-order traversal (avoids recursion-depth limits).
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.extend([keys[node]] * vals[node])
                node = right[node]
        return a