fix: convert_action from_action_index
This commit is contained in:
parent
1e18b784d1
commit
b92c9eb7ff
4 changed files with 139 additions and 4 deletions
|
|
@ -9,10 +9,46 @@ pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
|
|||
let mut state = env.state();
|
||||
let mut done = false;
|
||||
while !done {
|
||||
// // Get q values for current state
|
||||
// let model = agent.model().as_ref().unwrap();
|
||||
// let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
|
||||
// let q_values = model.infer(state_tensor);
|
||||
//
|
||||
// // Get valid actions
|
||||
// let valid_actions = get_valid_actions(&state);
|
||||
// if valid_actions.is_empty() {
|
||||
// break; // No valid actions, end of episode
|
||||
// }
|
||||
//
|
||||
// // Set q values of non valid actions to the lowest
|
||||
// let mut masked_q_values = q_values.clone();
|
||||
// let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
|
||||
// for (index, q_value) in q_values_vec.iter().enumerate() {
|
||||
// if !valid_actions.contains(&E::ActionType::from(index as u32)) {
|
||||
// masked_q_values = masked_q_values.clone().mask_fill(
|
||||
// masked_q_values.clone().equal_elem(*q_value),
|
||||
// f32::NEG_INFINITY,
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Get action with the highest q-value
|
||||
// let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
|
||||
// let action = E::ActionType::from(action_index);
|
||||
//
|
||||
// // Execute action
|
||||
// let snapshot = env.step(action);
|
||||
// state = *snapshot.state();
|
||||
// // println!("{:?}", state);
|
||||
// done = snapshot.done();
|
||||
|
||||
if let Some(action) = agent.react(&state) {
|
||||
// println!("before : {:?}", state);
|
||||
// println!("action : {:?}", action);
|
||||
let snapshot = env.step(action);
|
||||
state = *snapshot.state();
|
||||
// println!("{:?}", state);
|
||||
// println!("after : {:?}", state);
|
||||
// done = true;
|
||||
done = snapshot.done();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue