diff --git a/bot/src/strategy/burn_dqn_agent.rs b/bot/src/strategy/burn_dqn_agent.rs
index 1f1c01a..3830fd1 100644
--- a/bot/src/strategy/burn_dqn_agent.rs
+++ b/bot/src/strategy/burn_dqn_agent.rs
@@ -148,7 +148,7 @@
         }

         // Exploitation : choisir la meilleure action selon le Q-network
-        let state_tensor = Tensor::::from_floats(state, &self.device)
+        let state_tensor = Tensor::::from_floats(state, &self.device)
             .reshape([1, self.config.state_size]);

         let q_values = self.q_network.forward(state_tensor);
@@ -191,7 +191,7 @@
         // Préparer les tenseurs d'état
         let states: Vec = batch.iter().flat_map(|exp| exp.state.clone()).collect();

-        let state_tensor = Tensor::::from_floats(states.as_slice(), &self.device)
+        let state_tensor = Tensor::::from_floats(states.as_slice(), &self.device)
             .reshape([self.config.batch_size, self.config.state_size]);

         // Calculer les Q-values actuelles
diff --git a/justfile b/justfile
index b4e2c4b..bb1d86e 100644
--- a/justfile
+++ b/justfile
@@ -19,5 +19,6 @@ pythonlib:
 	pip install --no-deps --force-reinstall --prefix .devenv/state/venv target/wheels/*.whl
 trainbot:
 	#python ./store/python/trainModel.py
-	# cargo run --bin=train_dqn
-	cargo run --bin=train_burn_rl
+	# cargo run --bin=train_dqn # ok
+	# cargo run --bin=train_burn_rl # doesn't save model
+	cargo run --bin=train_dqn_full