# Source code for titan_pylib.data_structures.bst_base.bst_set_node_base
from typing import TypeVar, Generic, Optional

T = TypeVar("T")
Node = TypeVar("Node")
# Node is meant to be specified by a Protocol exposing ``key``, ``left`` and
# ``right`` (the order-statistic helpers below also read ``size``).
# NOTE: no such Protocol is declared here; Node is a plain, unbounded TypeVar.


class BSTSetNodeBase(Generic[T, Node]):
    """Stateless helpers for binary search trees that store a set of keys.

    Each tree helper takes the root node of a (sub)tree -- or ``None`` for an
    empty tree -- and only reads it; nothing here modifies nodes. Nodes are
    duck-typed: the helpers access ``node.key``, ``node.left`` and
    ``node.right``. ``index``, ``index_right`` and ``kth_elm`` also read
    ``.size`` on child nodes and treat it as the number of keys in that
    subtree. Keys must support ``<`` and ``==``. Smaller keys are expected in
    the left subtree and larger keys in the right subtree.
    """

    @staticmethod
    def sort_unique(a: list[T]) -> list[T]:
        """Return the distinct elements of ``a`` in ascending order.

        If ``a`` is already strictly increasing (this includes the empty list
        and single-element lists), the very same list object is returned.
        Otherwise a new sorted, de-duplicated list is built. ``a`` itself is
        never modified.
        """
        if not a:
            # The de-duplication below seeds itself with a[0]; without this
            # guard an empty input raised IndexError.
            return a
        if not all(a[i] < a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
            new_a = [a[0]]
            for elm in a:
                # Sorted input places duplicates next to each other, so only
                # the last kept element needs comparing.
                if new_a[-1] == elm:
                    continue
                new_a.append(elm)
            a = new_a
        return a

    @staticmethod
    def contains(node: Node, key: T) -> bool:
        """Return True if ``key`` is stored in the subtree rooted at ``node``.

        Follows a single root-to-leaf search path. ``node`` is tested by
        truthiness, so ``None`` means "empty".
        """
        while node:
            if key == node.key:
                return True
            node = node.left if key < node.key else node.right
        return False

    @staticmethod
    def get_min(node: Node) -> Optional[T]:
        """Return the smallest key (the leftmost node), or None if empty."""
        if not node:
            return None
        while node.left:
            node = node.left
        return node.key

    @staticmethod
    def get_max(node: Node) -> Optional[T]:
        """Return the largest key (the rightmost node), or None if empty."""
        if not node:
            return None
        while node.right:
            node = node.right
        return node.key

    @staticmethod
    def le(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``<= key``, or None if none exists."""
        res = None
        while node is not None:
            if key == node.key:
                # An exact match is the best possible answer.
                res = key
                break
            if key < node.key:
                node = node.left
            else:
                # node.key < key: a candidate; look right for a closer one.
                res = node.key
                node = node.right
        return res

    @staticmethod
    def lt(node: Node, key: T) -> Optional[T]:
        """Return the largest stored key ``< key``, or None if none exists."""
        res = None
        while node is not None:
            if key <= node.key:
                node = node.left
            else:
                # node.key < key: a candidate; look right for a closer one.
                res = node.key
                node = node.right
        return res

    @staticmethod
    def ge(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``>= key``, or None if none exists."""
        res = None
        while node is not None:
            if key == node.key:
                # An exact match is the best possible answer.
                res = key
                break
            if key < node.key:
                # key < node.key: a candidate; look left for a closer one.
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def gt(node: Node, key: T) -> Optional[T]:
        """Return the smallest stored key ``> key``, or None if none exists."""
        res = None
        while node is not None:
            if key < node.key:
                # key < node.key: a candidate; look left for a closer one.
                res = node.key
                node = node.left
            else:
                node = node.right
        return res

    @staticmethod
    def index(node: Node, key: T) -> int:
        """Return how many stored keys are strictly less than ``key``.

        This is the position ``key`` has, or would be inserted at, in the
        sorted key sequence. It is analogous to ``bisect.bisect_left``.
        """
        k = 0
        while node is not None:
            if key == node.key:
                # Everything in the left subtree is smaller than key.
                if node.left is not None:
                    k += node.left.size
                break
            if key < node.key:
                node = node.left
            else:
                # The left subtree and this node are all smaller than key.
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    @staticmethod
    def index_right(node: Node, key: T) -> int:
        """Return how many stored keys are less than or equal to ``key``.

        This is analogous to ``bisect.bisect_right`` on the sorted keys.
        """
        k = 0
        while node is not None:
            if key == node.key:
                # Count the left subtree plus the matching node itself.
                k += 1 if node.left is None else node.left.size + 1
                break
            if key < node.key:
                node = node.left
            else:
                k += 1 if node.left is None else node.left.size + 1
                node = node.right
        return k

    @staticmethod
    def tolist(node: Node) -> list[T]:
        """Return all keys in ascending order (in-order traversal).

        Uses an explicit stack instead of recursion, so tree height is not
        limited by Python's recursion limit.
        """
        stack = []
        res = []
        while stack or node:
            if node:
                # Descend as far left as possible, remembering the path.
                stack.append(node)
                node = node.left
            else:
                node = stack.pop()
                res.append(node.key)
                node = node.right
        return res

    @staticmethod
    def kth_elm(node: Node, k: int, _len: int) -> T:
        """Return the ``k``-th smallest key (0-indexed) of the tree at ``node``.

        A negative ``k`` counts from the end: ``_len`` (expected to be the
        number of keys in the tree) is added to it once.

        Raises:
            IndexError: if ``k`` does not address an existing key, including
                when the tree is empty.
        """
        requested = k
        if k < 0:
            k += _len
        while node is not None:
            t = 0 if node.left is None else node.left.size
            if t == k:
                return node.key
            if t > k:
                node = node.left
            else:
                # Skip the left subtree and this node, then search the right.
                node = node.right
                k -= t + 1
        # Walking off the tree means k was out of range. This used to surface
        # as an AttributeError on ``None.left``.
        raise IndexError(f"kth_elm index {requested} out of range")