1# from titan_pylib.data_structures.bst_base.bst_multiset_array_base import BSTMultisetArrayBase
2from typing import TypeVar, Generic, Optional
3
T = TypeVar("T")
BST = TypeVar("BST")
# TODO: declare the attributes BST must expose (key, val, left, right) with a
# typing.Protocol instead of leaving BST unconstrained.


class BSTMultisetArrayBase(Generic[BST, T]):
    """Read-only query helpers shared by array-backed BST multisets.

    Every method is a ``staticmethod`` that receives the tree as ``bst`` and
    only reads from it. The tree stores its nodes as parallel arrays indexed
    by node id:

    - ``bst.key[node]``: key held by ``node``.
    - ``bst.val[node]``: number of copies of that key.
    - ``bst.left[node]`` / ``bst.right[node]``: child node ids.
    - ``bst.valsize[node]``: total multiplicity of the subtree rooted at
      ``node`` (the rank queries below rely on this).
    - ``bst.root``: id of the root node.

    Node id ``0`` acts as the null node: every descent stops when it reaches
    ``0``, and ``bst.valsize[0]`` must be ``0`` because the rank queries add
    ``valsize[left[node]]`` without checking whether the child exists.
    ``len(bst)`` must equal the total number of stored elements.
    """

    @staticmethod
    def _rle(a: list[T]) -> tuple[list[T], list[int]]:
        """Run-length encode ``a``.

        Returns ``(keys, vals)`` where ``keys[i]`` is the value of the i-th
        run and ``vals[i]`` its length. Only adjacent equal elements are
        merged, so a sorted input yields one entry per distinct key.
        An empty input yields ``([], [])``.
        """
        keys: list[T] = []
        vals: list[int] = []
        for elm in a:
            if keys and elm == keys[-1]:
                vals[-1] += 1
            else:
                keys.append(elm)
                vals.append(1)
        return keys, vals

    @staticmethod
    def count(bst: BST, key: T) -> int:
        """Return how many copies of ``key`` are stored (0 if absent)."""
        keys, left, right = bst.key, bst.left, bst.right
        node = bst.root
        while node:
            if keys[node] == key:
                return bst.val[node]
            node = left[node] if key < keys[node] else right[node]
        return 0

    @staticmethod
    def le(bst: BST, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or ``None`` if there is none.

        On an exact match the stored key is returned, matching ``lt``/``gt``,
        which also return stored keys.
        """
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            k = keys[node]
            if key == k:
                return k
            if key < k:
                node = left[node]
            else:
                # k is a candidate; a larger one may still exist on the right.
                res = k
                node = right[node]
        return res

    @staticmethod
    def lt(bst: BST, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or ``None`` if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key <= keys[node]:
                node = left[node]
            else:
                res = keys[node]
                node = right[node]
        return res

    @staticmethod
    def ge(bst: BST, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or ``None`` if there is none.

        On an exact match the stored key is returned, matching ``lt``/``gt``,
        which also return stored keys.
        """
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            k = keys[node]
            if key == k:
                return k
            if key < k:
                # k is a candidate; a smaller one may still exist on the left.
                res = k
                node = left[node]
            else:
                node = right[node]
        return res

    @staticmethod
    def gt(bst: BST, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or ``None`` if there is none."""
        keys, left, right = bst.key, bst.left, bst.right
        res = None
        node = bst.root
        while node:
            if key < keys[node]:
                res = keys[node]
                node = left[node]
            else:
                node = right[node]
        return res

    @staticmethod
    def index(bst: BST, key: T) -> int:
        """Return the number of stored elements ``< key`` (copies counted separately)."""
        keys, left, right, vals, valsize = (
            bst.key,
            bst.left,
            bst.right,
            bst.val,
            bst.valsize,
        )
        k = 0
        node = bst.root
        while node:
            if key == keys[node]:
                # The left subtree is strictly smaller; copies of key are not.
                # valsize[0] == 0 covers a missing left child.
                k += valsize[left[node]]
                break
            if key < keys[node]:
                node = left[node]
            else:
                # The left subtree and this node's copies are all < key.
                k += valsize[left[node]] + vals[node]
                node = right[node]
        return k

    @staticmethod
    def index_right(bst: BST, key: T) -> int:
        """Return the number of stored elements ``<= key`` (copies counted separately)."""
        keys, left, right, vals, valsize = (
            bst.key,
            bst.left,
            bst.right,
            bst.val,
            bst.valsize,
        )
        k = 0
        node = bst.root
        while node:
            if key == keys[node]:
                k += valsize[left[node]] + vals[node]
                break
            if key < keys[node]:
                node = left[node]
            else:
                k += valsize[left[node]] + vals[node]
                node = right[node]
        return k

    @staticmethod
    def _kth_elm(bst: BST, k: int) -> tuple[T, int]:
        """Return ``(key, multiplicity)`` of the ``k``-th smallest element.

        ``k`` is 0-indexed over the sorted multiset, with each copy counted
        separately. A negative ``k`` counts from the end, as with list indexing.

        Raises:
            IndexError: if ``k`` is out of range (always, for an empty tree).
        """
        left, right, vals, valsize = bst.left, bst.right, bst.val, bst.valsize
        n = len(bst)
        i = k + n if k < 0 else k
        if not 0 <= i < n:
            # The descent below only exits on a match, so an out-of-range
            # index would walk into the null node instead of failing cleanly.
            raise IndexError(f"index {k} out of range for size {n}")
        node = bst.root
        while True:
            lsize = valsize[left[node]]
            cnt = vals[node]
            if i < lsize:
                node = left[node]
            elif i < lsize + cnt:
                return bst.key[node], cnt
            else:
                i -= lsize + cnt
                node = right[node]

    @staticmethod
    def contains(bst: BST, key: T) -> bool:
        """Return ``True`` if at least one copy of ``key`` is stored."""
        left, right, keys = bst.left, bst.right, bst.key
        node = bst.root
        while node:
            if keys[node] == key:
                return True
            node = left[node] if key < keys[node] else right[node]
        return False

    @staticmethod
    def tolist(bst: BST) -> list[T]:
        """Return every element in ascending order, each repeated by its multiplicity."""
        left, right, keys, vals = bst.left, bst.right, bst.key, bst.val
        node = bst.root
        stack: list[int] = []
        a: list[T] = []
        # Iterative in-order traversal (avoids recursion depth limits).
        while stack or node:
            if node:
                stack.append(node)
                node = left[node]
            else:
                node = stack.pop()
                a.extend([keys[node]] * vals[node])
                node = right[node]
        return a