fix: tensor dimensions

fix execution error
This commit is contained in:
Henri Bourcereau 2025-06-29 11:30:34 +02:00
parent 6a7b1cbebc
commit b98a135749
2 changed files with 5 additions and 4 deletions

View file

@ -148,7 +148,7 @@ impl BurnDqnAgent {
}
// Exploitation : choisir la meilleure action selon le Q-network
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device)
let state_tensor = Tensor::<MyBackend, 1>::from_floats(state, &self.device)
.reshape([1, self.config.state_size]);
let q_values = self.q_network.forward(state_tensor);
@ -191,7 +191,7 @@ impl BurnDqnAgent {
// Préparer les tenseurs d'état
let states: Vec<f32> = batch.iter().flat_map(|exp| exp.state.clone()).collect();
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states.as_slice(), &self.device)
let state_tensor = Tensor::<MyBackend, 1>::from_floats(states.as_slice(), &self.device)
.reshape([self.config.batch_size, self.config.state_size]);
// Calculer les Q-values actuelles

View file

@ -19,5 +19,6 @@ pythonlib:
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
trainbot:
#python ./store/python/trainModel.py
# cargo run --bin=train_dqn
cargo run --bin=train_burn_rl
# cargo run --bin=train_dqn # ok
# cargo run --bin=train_burn_rl # doesn't save model
cargo run --bin=train_dqn_full