beam search euler
ソースコード
#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "titan_cpplib/algorithm/random.cpp"
#include "titan_cpplib/ahc/state_pool.cpp"
#include "titan_cpplib/ahc/timer.cpp"
#include "titan_cpplib/data_structures/hash_set.cpp"
#include "titan_cpplib/others/print.cpp"

using namespace std;
13
//! `rep(i, n)`: declares an int counter `i` and loops while `i < n` (n is re-evaluated every iteration).
#define rep(i, n) for (int i = 0; (i) < (n); (i)++)
15
//! 木上のビームサーチライブラリ (beam search over a search tree stored as an Euler tour)
17namespace beam_search_with_tree {
18
using ScoreType = long long;
using HashType = unsigned long long;
//! Large sentinel score (1e9); Action uses it for scores that have not been computed yet.
constexpr ScoreType INF = 1000000000LL;

// Action
//! One move of the search.
//! The pre/nxt score and hash fields are presumably filled by State::try_op() with whatever
//! State::rollback() needs (TODO confirm against the concrete State implementation).
//! Every member has a default initializer, so a default-constructed Action is never
//! left with indeterminate values.
struct Action {
    char d = 0;  // the move itself; this is what operator<< prints
    ScoreType pre_score = INF, nxt_score = INF;
    HashType pre_hash = 0, nxt_hash = 0;

    Action() = default;
    // Intentionally implicit, as before, so callers may still convert a char into an Action.
    Action(const char d) : d(d) {}
};

//! Print an action as its move character `d`.
std::ostream& operator<<(std::ostream& os, const Action &action) {
    return os << action.d;
}
36
37class State {
38 private:
39
40 public:
41 ScoreType score;
42 HashType hash;
43
44 State() {}
45
46 // TODO Stateを初期化する
47 void init() {}
48
49 // TODO
50 //! `action` をしたときの評価値とハッシュ値を返す
51 //! ロールバックに必要な情報はすべてactionにメモしておく
52 pair<ScoreType, HashType> try_op(Action &action) const {}
53
54 bool is_done() const {}
55
56 // TODO
57 //! `action` をする
58 void apply_op(const Action &action) {}
59
60 // TODO
61 //! `action` を戻す
62 void rollback(const Action &action) {}
63
64 // TODO
65 //! 現状態から遷移可能な `Action` の `vector` を返す
66 vector<Action> get_actions() const {}
67
68 void print() const {}
69};
70
//! Parameters of BeamSearchWithTree::search().
//! Still an aggregate, so `BeamParam{max_turn, beam_width}` works. A default-constructed
//! value is zero-initialized instead of indeterminate.
struct BeamParam {
    int MAX_TURN = 0;    // maximum number of turns (search depth)
    int BEAM_WIDTH = 0;  // number of candidates kept per turn
};
75
76class BeamSearchWithTree {
77 // ref: https://eijirou-kyopro.hatenablog.com/entry/2024/02/01/115639
78
79 private:
80 static constexpr const int PRE_ORDER = -1;
81 static constexpr const int POST_ORDER = -2;
82
83 titan23::HashSet seen;
84 using ActionIDType = int;
85 ActionIDType ActionID;
86 vector<Action> result;
87
88 // ビームサーチの過程を表す木
89 // <dir or id, action, action_id>
90 // dir or id := 葉のとき、leaf_id
91 // そうでないとき、行きがけなら-1、帰りがけなら-2
92 vector<tuple<int, Action, ActionIDType>> tree;
93
94 // 次のビーム候補を保持する配列
95 vector<tuple<int, ScoreType, Action, ActionIDType>> next_beam; // <par, score, action, action_id>
96
97 vector<vector<int>> next_beam_data;
98
99 void get_next_beam(State* state, const int turn) {
100 next_beam.clear();
101 next_beam.reserve(tree.size() * 4); // TODO
102 seen.clear();
103
104 if (turn == 0) {
105 vector<Action> actions = state->get_actions();
106 for (Action &action : actions) {
107 auto [score, hash] = state->try_op(action);
108 if (seen.contains_insert(hash)) continue;
109 next_beam.emplace_back(PRE_ORDER, score, action, ActionID);
110 ActionID++;
111 }
112 return;
113 }
114
115 int leaf_id = 0;
116 for (int i = 0; i < tree.size(); ++i) {
117 auto [dir_or_leaf_id, action, _] = tree[i];
118 if (dir_or_leaf_id >= 0) {
119 state->apply_op(action);
120 vector<Action> actions = state->get_actions();
121 std::get<0>(tree[i]) = leaf_id;
122 for (Action &action : actions) {
123 auto [score, hash] = state->try_op(action);
124 if (seen.contains_insert(hash)) continue;
125 next_beam.emplace_back(leaf_id, score, action, ActionID);
126 ActionID++;
127 }
128 ++leaf_id;
129 state->rollback(action);
130 } else if (dir_or_leaf_id == PRE_ORDER) {
131 state->apply_op(action);
132 } else {
133 state->rollback(action);
134 }
135 }
136 }
137
138 //! 不要なNodeを削除し、木を更新する
139 int update_tree(State* state, const int turn) {
140 vector<tuple<int, Action, ActionIDType>> new_tree;
141 new_tree.reserve(tree.size());
142 if (turn == 0) {
143 for (auto &[par, _, new_action, action_id] : next_beam) {
144 assert(par == -1);
145 new_tree.emplace_back(0, new_action, action_id);
146 }
147 swap(tree, new_tree);
148 return 0;
149 }
150
151 int i = 0;
152 int apply_only_turn = 0;
153 while (true) {
154 const auto &[dir_or_leaf_id, action, action_id] = tree[i];
155 // 行きがけかつ帰りがけのaction_idが一致しているなら、一本道なので行くだけ
156 if (dir_or_leaf_id == PRE_ORDER && action_id == std::get<2>(tree.back())) {
157 ++i;
158 result.emplace_back(action);
159 state->apply_op(action);
160 tree.pop_back();
161 apply_only_turn++;
162 } else {
163 break;
164 }
165 }
166
167 for (; i < tree.size(); ++i) {
168 const auto &[dir_or_leaf_id, action, action_id] = tree[i];
169 if (dir_or_leaf_id >= 0) {
170 if (next_beam_data[dir_or_leaf_id].empty()) continue;
171 new_tree.emplace_back(PRE_ORDER, action, action_id);
172 for (const int beam_idx : next_beam_data[dir_or_leaf_id]) {
173 auto &[_, __, new_action, new_action_id] = next_beam[beam_idx];
174 new_tree.emplace_back(dir_or_leaf_id, new_action, new_action_id);
175 }
176 new_tree.emplace_back(POST_ORDER, action, action_id);
177 next_beam_data[dir_or_leaf_id].clear();
178 } else if (dir_or_leaf_id == PRE_ORDER) {
179 new_tree.emplace_back(PRE_ORDER, action, action_id);
180 } else {
181 int pre_dir = std::get<0>(new_tree.back());
182 if (pre_dir == PRE_ORDER) {
183 new_tree.pop_back(); // 一つ前が行きがけなら、削除して追加しない
184 } else {
185 new_tree.emplace_back(POST_ORDER, action, action_id);
186 }
187 }
188 }
189 swap(tree, new_tree);
190 return apply_only_turn;
191 }
192
193 void get_result() {
194 int best_id = -1;
195 ScoreType best_score = 0;
196 for (auto [par, score, _, __] : next_beam) {
197 if (best_id == -1 || score < best_score) {
198 best_score = score;
199 best_id = par;
200 }
201 }
202 assert(best_id != -1);
203 for (const auto &[dir_or_leaf_id, action, _] : tree) {
204 if (dir_or_leaf_id >= 0) {
205 if (best_id == dir_or_leaf_id) {
206 result.emplace_back(action);
207 return;
208 }
209 } else if (dir_or_leaf_id == PRE_ORDER) {
210 result.emplace_back(action);
211 } else {
212 result.pop_back();
213 }
214 }
215 cerr << PRINT_RED << "Error: 解が見つかりませんでした" << PRINT_NONE << endl;
216 assert(false);
217 }
218
219 public:
220
221 /**
222 * @brief ビームサーチをする
223 *
224 * @param param ターン数、ビーム幅を指定するパラメータ構造体
225 * @param verbose ログ出力するかどうか
226 * @return vector<Action>
227 */
228 vector<Action> search(const BeamParam ¶m, const bool verbose = false) {
229 if (verbose) cerr << PRINT_GREEN << "Info: start search()" << PRINT_NONE << endl;
230
231 ActionID = 0;
232 State* state = new State;
233 state->init();
234
235 this->seen = titan23::HashSet(param.BEAM_WIDTH * 4); // TODO
236
237 int now_turn = 0;
238 for (int turn = 0; turn < param.MAX_TURN; ++turn) {
239 if (verbose) cerr << "Info: # turn : " << turn+1 << endl;
240
241 // 次のビーム候補を求める
242 get_next_beam(state, turn-now_turn);
243
244 if (next_beam.empty()) {
245 cerr << PRINT_RED << "Error: \t次の候補が見つかりませんでした" << PRINT_NONE << endl;
246 assert(!next_beam.empty());
247 }
248
249 // ビームを絞る
250 int beam_width = min(param.BEAM_WIDTH, (int)next_beam.size());
251 assert(beam_width <= param.BEAM_WIDTH);
252
253 nth_element(next_beam.begin(), next_beam.begin() + beam_width, next_beam.end(), [&] (const tuple<int, ScoreType, Action, ActionIDType> &left, const tuple<int, ScoreType, Action, ActionIDType> &right) {
254 return std::get<1>(left) < std::get<1>(right);
255 });
256
257 tuple<int, ScoreType, Action, ActionIDType> bests = *min_element(next_beam.begin(), next_beam.begin() + beam_width, [&] (const tuple<int, ScoreType, Action, ActionIDType> &left, const tuple<int, ScoreType, Action, ActionIDType> &right) {
258 return std::get<1>(left) < std::get<1>(right);
259 });
260 if (verbose) cerr << "Info: \tbest_score = " << std::get<1>(bests) << endl;
261 if (std::get<1>(bests) == 0) { // TODO 終了条件
262 cerr << PRINT_GREEN << "Info: find valid solution." << PRINT_NONE << endl;
263 get_result();
264 result.emplace_back(std::get<2>(bests));
265 return result;
266 }
267
268 // 探索木の更新
269 if (turn != 0) {
270 if (next_beam_data.size() < next_beam.size()) {
271 next_beam_data.resize(next_beam.size());
272 }
273 for (int i = 0; i < beam_width; ++i) {
274 auto &[par, _, new_action, new_action_id] = next_beam[i];
275 next_beam_data[par].emplace_back(i);
276 }
277 }
278 int apply_only_turn = update_tree(state, turn);
279 now_turn += apply_only_turn;
280 }
281
282 // 答えを復元する
283 if (verbose) cerr << PRINT_GREEN << "Info: MAX_TURN finished." << PRINT_NONE << endl;
284 get_result();
285 return result;
286 }
287};
} // namespace beam_search_with_tree
289
// Usage example:
// int main() {
//     beam_search_with_tree::BeamParam param;
//     param.MAX_TURN = 1000;
//     param.BEAM_WIDTH = 100;
//     beam_search_with_tree::init_zhs();  // NOTE(review): init_zhs() is not defined in this file; define it or drop this line before enabling the example
//     beam_search_with_tree::BeamSearchWithTree bs;
//     vector<beam_search_with_tree::Action> ans = bs.search(param, true);
//     for (const beam_search_with_tree::Action &action : ans) {
//         cout << action;
//     }
//     cout << endl;
//     cerr << "Score = " << ans.size() << endl;
//     return 0;
// }
仕様
Warning
doxygenfile: Cannot find file “titan_cpplib/ahc/beam_search_euler.cpp”