Beam search with tree
ソースコード
#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "titan_cpplib/others/print.cpp"
#include "titan_cpplib/algorithm/random.cpp"
#include "titan_cpplib/ahc/state_pool.cpp"
#include "titan_cpplib/ahc/timer.cpp"
#include "titan_cpplib/data_structures/hash_set.cpp"
using namespace std;
10
11//! 木上のビームサーチライブラリ
12namespace beam_search_with_tree {
13
//! Score type produced by State::try_op / State::get_score.
using ScoreType = long long int;
//! Hash type used to detect duplicate states within one turn.
using HashType = unsigned long long int;

//! One transition of the state.
//! Problem-specific fields go here, including anything State::rollback needs
//! (State::try_op receives the action by non-const reference so it can record them).
struct Action {
    Action() {}

    //! Debug printer; currently writes nothing and just returns the stream.
    friend ostream& operator<<(ostream& out, const Action& /*unused*/) {
        return out;
    }
};
24
25class State {
26 private:
27 titan23::Random srand;
28 ScoreType score;
29 HashType hash;
30
31 public:
32 // TODO Stateを初期化する
33 void init() {
34 this->score = 0;
35 }
36
37 // TODO
38 //! `action` をしたときの評価値とハッシュ値を返す
39 //! ロールバックに必要な情報はすべてactionにメモしておく
40 pair<ScoreType, HashType> try_op(Action &action) const {
41 }
42
43 // TODO
44 //! `action` をする
45 void apply_op(const Action &action) {
46 }
47
48 // TODO
49 //! `action` を戻す
50 void rollback(const Action &action) {
51 }
52
53 // TODO
54 //! 現状態から遷移可能な `Action` の `vector` を返す
55 vector<Action> get_actions() const {
56 }
57
58 ScoreType get_score() {
59 return this->score;
60 }
61
62 void print() const {
63 }
64};
65
//! Parameters of BeamSearchWithTree::search.
//! Members default to 0 so a default-initialized BeamParam never holds indeterminate values;
//! it remains an aggregate, so `BeamParam{max_turn, beam_width}` still works.
struct BeamParam {
    int MAX_TURN = 0;    //!< number of turns to search (the loop runs for turn < MAX_TURN)
    int BEAM_WIDTH = 0;  //!< maximum number of candidates kept per turn
};

//! ID of a TreeNode, as issued by treenode_pool.gen().
using TreeNodeID = int;
//! ID of a SubState, as issued by substate_pool.gen().
using SubStateID = int;
73
74//! try_opした結果をメモしておく構造体
75struct SubState {
76 TreeNodeID par;
77 Action action;
78 ScoreType score;
79
80 SubState() : par(-1) {}
81 SubState(TreeNodeID par, const Action &action, ScoreType score) : par(par), action(action), score(score) {}
82};
83
84//! ビームサーチの過程を表す木のノード
85struct TreeNode {
86 TreeNodeID par;
87 Action pre_action;
88 ScoreType score;
89 vector<TreeNodeID> child;
90
91 TreeNode() : par(-1) {}
92
93 bool is_leaf() const {
94 return child.empty();
95 }
96};
97
98titan23::StatePool<TreeNode> treenode_pool;
99titan23::StatePool<SubState> substate_pool;
100
101
102class BeamSearchWithTree {
103 private:
104 ScoreType best_score;
105 TreeNodeID best_id;
106 titan23::HashSet seen;
107
108 void get_next_beam_recursion(State* state, TreeNodeID node, vector<SubStateID> &next_beam, int depth, const int beam_width) {
109 if (depth == 0) { // 葉
110 vector<Action> actions = state->get_actions();
111 for (Action &action : actions) {
112 auto [score, hash] = state->try_op(action);
113 if (seen.contains_insert(hash)) continue;
114 SubStateID substate = substate_pool.gen();
115 substate_pool.get(substate)->par = node;
116 substate_pool.get(substate)->action = action;
117 substate_pool.get(substate)->score = score;
118 next_beam.emplace_back(substate);
119 }
120 return;
121 }
122 for (const TreeNodeID nxt_node : treenode_pool.get(node)->child) {
123 state->apply_op(treenode_pool.get(nxt_node)->pre_action);
124 get_next_beam_recursion(state, nxt_node, next_beam, depth-1, beam_width);
125 state->rollback(treenode_pool.get(nxt_node)->pre_action);
126 }
127 }
128
129 tuple<int, TreeNodeID, vector<SubStateID>> get_next_beam(State* state, TreeNodeID node, int turn, const int beam_width) {
130 int cnt = 0;
131 while (true) { // 一本道は行くだけ
132 if (treenode_pool.get(node)->child.size() != 1) break;
133 ++cnt;
134 node = treenode_pool.get(node)->child[0];
135 state->apply_op(treenode_pool.get(node)->pre_action);
136 }
137 vector<SubStateID> next_beam;
138 seen.clear();
139 get_next_beam_recursion(state, node, next_beam, turn-cnt, beam_width);
140 return make_tuple(cnt, node, next_beam);
141 }
142
143 //! 不要なNodeを削除し、木を更新する
144 bool update_tree(const TreeNodeID node, const int depth) {
145 if (treenode_pool.get(node)->is_leaf()) return depth == 0;
146 int idx = 0;
147 while (idx < treenode_pool.get(node)->child.size()) {
148 TreeNodeID nxt_node = treenode_pool.get(node)->child[idx];
149 if (!update_tree(nxt_node, depth-1)) {
150 treenode_pool.del(nxt_node);
151 treenode_pool.get(node)->child.erase(treenode_pool.get(node)->child.begin() + idx);
152 continue;
153 }
154 ++idx;
155 }
156 return idx > 0;
157 }
158
159 //! node以上のパスを返す
160 vector<Action> get_path(TreeNodeID node) {
161 vector<Action> result;
162 while (node != -1 && treenode_pool.get(node)->par != -1) {
163 result.emplace_back(treenode_pool.get(node)->pre_action);
164 node = treenode_pool.get(node)->par;
165 }
166 reverse(result.begin(), result.end());
167 return result;
168 }
169
170 //! for debug
171 void print_tree(State* state, const TreeNodeID node, int depth) {
172 }
173
174 //! node以下で、葉かつ最も評価値の良いノードを見るける / 葉はターン数からは判断していないので注意
175 void get_best_node(TreeNodeID node) {
176 if (treenode_pool.get(node)->is_leaf()) {
177 if (best_id == -1 || treenode_pool.get(node)->score < best_score) {
178 best_score = treenode_pool.get(node)->score;
179 best_id = node;
180 }
181 return;
182 }
183 for (TreeNodeID nxt_node : treenode_pool.get(node)->child) {
184 get_best_node(nxt_node);
185 }
186 }
187
188 vector<Action> get_result(TreeNodeID root) {
189 best_id = -1; // 更新
190 get_best_node(root);
191 TreeNodeID node = best_id;
192 vector<Action> result = get_path(node);
193 cerr << treenode_pool.get_size() << endl;
194 return result;
195 }
196
197 public:
198 /**
199 * @brief ビームサーチをする
200 *
201 * @param param ターン数、ビーム幅を指定するパラメータ構造体
202 * @param verbose 途中結果のスコアを標準エラー出力するかどうか
203 * @return vector<Action>
204 */
205 vector<Action> search(const BeamParam ¶m, const bool verbose = false) {
206 TreeNodeID root = treenode_pool.gen();
207 treenode_pool.get(root)->child.clear();
208 treenode_pool.get(root)->par = -1;
209
210 State* state = new State;
211 state->init();
212
213 this->seen = titan23::HashSet(param.BEAM_WIDTH * 4); // TODO
214
215 int now_turn = 0;
216
217 for (int turn = 0; turn < param.MAX_TURN; ++turn) {
218 if (verbose) cerr << "# turn : " << turn << endl;
219
220 // 次のビーム候補を求める
221 auto [apply_only_turn, next_root, next_beam] = get_next_beam(state, root, turn-now_turn, param.BEAM_WIDTH);
222 root = next_root;
223 now_turn += apply_only_turn;
224 assert(!next_beam.empty());
225 if (verbose) {
226 cerr << "min_score=" << substate_pool.get((*min_element(next_beam.begin(), next_beam.end(), [] (const SubStateID &left, const SubStateID &right) {
227 return substate_pool.get(left)->score < substate_pool.get(right)->score;
228 })))->score << endl;
229 }
230
231 // ビームを絞る
232 int beam_width = min(param.BEAM_WIDTH, (int)next_beam.size());
233 nth_element(next_beam.begin(), next_beam.begin() + beam_width, next_beam.end(), [&] (const SubStateID &left, const SubStateID &right) {
234 return substate_pool.get(left)->score < substate_pool.get(right)->score;
235 });
236
237 // 探索木の更新
238 for (int i = 0; i < beam_width; ++i) {
239 SubStateID s = next_beam[i];
240 TreeNodeID new_node = treenode_pool.gen();
241 treenode_pool.get(substate_pool.get(s)->par)->child.emplace_back(new_node);
242 treenode_pool.get(new_node)->par = substate_pool.get(s)->par;
243 treenode_pool.get(new_node)->pre_action = substate_pool.get(s)->action;
244 treenode_pool.get(new_node)->score = substate_pool.get(s)->score;
245 }
246 substate_pool.clear();
247 update_tree(root, turn-now_turn+1);
248 }
249
250 // 答えを復元する
251 vector<Action> result = get_result(root);
252 return result;
253 }
254};
255} // namespace beam_search
256
257// int main() {
258// beam_search_with_tree::BeamParam param;
259// param.MAX_TURN = 2500;
260// param.BEAM_WIDTH = 1000;
261// beam_search_with_tree::BeamSearchWithTree bs;
262// vector<beam_search_with_tree::Action> ans = bs.search(param, true);
263// }
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/ahc/beam_search_with_tree.cpp”