fix: convert_action from_action_index

This commit is contained in:
Henri Bourcereau 2025-07-25 17:26:02 +02:00
parent 1e18b784d1
commit b92c9eb7ff
4 changed files with 139 additions and 4 deletions

View file

@ -9,10 +9,46 @@ pub fn demo_model<E: Environment>(agent: impl Agent<E>) {
let mut state = env.state();
let mut done = false;
while !done {
// // Get q values for current state
// let model = agent.model().as_ref().unwrap();
// let state_tensor = E::StateType::to_tensor(&state).unsqueeze();
// let q_values = model.infer(state_tensor);
//
// // Get valid actions
// let valid_actions = get_valid_actions(&state);
// if valid_actions.is_empty() {
// break; // No valid actions, end of episode
// }
//
// // Set q values of non valid actions to the lowest
// let mut masked_q_values = q_values.clone();
// let q_values_vec: Vec<f32> = q_values.into_data().into_vec().unwrap();
// for (index, q_value) in q_values_vec.iter().enumerate() {
// if !valid_actions.contains(&E::ActionType::from(index as u32)) {
// masked_q_values = masked_q_values.clone().mask_fill(
// masked_q_values.clone().equal_elem(*q_value),
// f32::NEG_INFINITY,
// );
// }
// }
//
// // Get action with the highest q-value
// let action_index = masked_q_values.argmax(1).into_scalar().to_u32();
// let action = E::ActionType::from(action_index);
//
// // Execute action
// let snapshot = env.step(action);
// state = *snapshot.state();
// // println!("{:?}", state);
// done = snapshot.done();
if let Some(action) = agent.react(&state) {
// println!("before : {:?}", state);
// println!("action : {:?}", action);
let snapshot = env.step(action);
state = *snapshot.state();
// println!("{:?}", state);
// println!("after : {:?}", state);
// done = true;
done = snapshot.done();
}
}