fix: tensor dimensions
fix execution error
This commit is contained in:
parent
6a7b1cbebc
commit
b98a135749
|
|
@ -148,7 +148,7 @@ impl BurnDqnAgent {
|
|||
}
|
||||
|
||||
// Exploitation : choisir la meilleure action selon le Q-network
|
||||
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state, &self.device)
|
||||
let state_tensor = Tensor::<MyBackend, 1>::from_floats(state, &self.device)
|
||||
.reshape([1, self.config.state_size]);
|
||||
let q_values = self.q_network.forward(state_tensor);
|
||||
|
||||
|
|
@ -191,7 +191,7 @@ impl BurnDqnAgent {
|
|||
|
||||
// Préparer les tenseurs d'état
|
||||
let states: Vec<f32> = batch.iter().flat_map(|exp| exp.state.clone()).collect();
|
||||
let state_tensor = Tensor::<MyBackend, 2>::from_floats(states.as_slice(), &self.device)
|
||||
let state_tensor = Tensor::<MyBackend, 1>::from_floats(states.as_slice(), &self.device)
|
||||
.reshape([self.config.batch_size, self.config.state_size]);
|
||||
|
||||
// Calculer les Q-values actuelles
|
||||
|
|
|
|||
5
justfile
5
justfile
|
|
@ -19,5 +19,6 @@ pythonlib:
|
|||
pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
|
||||
trainbot:
|
||||
#python ./store/python/trainModel.py
|
||||
# cargo run --bin=train_dqn
|
||||
cargo run --bin=train_burn_rl
|
||||
# cargo run --bin=train_dqn # ok
|
||||
# cargo run --bin=train_burn_rl # doesn't save model
|
||||
cargo run --bin=train_dqn_full
|
||||
|
|
|
|||
Loading…
Reference in a new issue