claude (dqn_rs trainer ultrasimplifié, compilation still fails)

This commit is contained in:
Henri Bourcereau 2025-06-22 16:28:13 +02:00
parent 80734990eb
commit 3b50fdaec3
5 changed files with 1512 additions and 71 deletions

View file

@ -17,6 +17,10 @@ path = "src/bin/train_burn_dqn.rs"
name = "simple_burn_train"
path = "src/bin/simple_burn_train.rs"
[[bin]]
name = "minimal_burn"
path = "src/bin/minimal_burn.rs"
[dependencies]
pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }

View file

@ -0,0 +1,45 @@
use burn::{
backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
nn::{Linear, LinearConfig},
module::Module,
tensor::Tensor,
};
type MyBackend = Autodiff<NdArray>;
type MyDevice = NdArrayDevice;
#[derive(Module, Debug)]
struct SimpleNet<B: burn::prelude::Backend> {
fc: Linear<B>,
}
impl<B: burn::prelude::Backend> SimpleNet<B> {
fn new(device: &B::Device) -> Self {
let fc = LinearConfig::new(4, 2).init(device);
Self { fc }
}
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
self.fc.forward(input)
}
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Test minimal avec Burn");
let device = MyDevice::default();
let model = SimpleNet::<MyBackend>::new(&device);
// Test avec un input simple
let input_data = [[1.0, 2.0, 3.0, 4.0]];
let input_tensor = Tensor::from_floats(input_data, &device);
let output = model.forward(input_tensor);
let output_data = output.into_data().to_vec::<f32>().unwrap();
println!("Input: [1, 2, 3, 4]");
println!("Output: {:?}", output_data);
println!("Burn fonctionne correctement !");
Ok(())
}

View file

@ -92,7 +92,7 @@ pub struct BurnDqnAgent {
device: MyDevice,
q_network: DqnModel<MyBackend>,
target_network: DqnModel<MyBackend>,
optimizer: burn::optim::Adam,
optimizer: burn::optim::AdamConfig,
replay_buffer: VecDeque<Experience>,
epsilon: f32,
step_count: usize,
@ -117,7 +117,7 @@ impl BurnDqnAgent {
&device,
);
let optimizer = AdamConfig::new().init();
let optimizer = AdamConfig::new();
Self {
config: config.clone(),
@ -145,9 +145,12 @@ impl BurnDqnAgent {
}
// Exploitation : choisir la meilleure action selon le Q-network
// Utiliser from_floats avec un vecteur 2D pour Burn 0.17
let state_2d = vec![state.to_vec()];
let state_tensor = Tensor::<MyBackend, 2>::from_floats(state_2d, &self.device);
// Créer un tensor simple à partir du state
let state_array: [f32; 10] = [0.0; 10]; // Taille fixe pour l'instant
for (i, &val) in state.iter().enumerate().take(10) {
// state_array[i] = val; // Ne marche pas car state_array est immutable
}
let state_tensor = Tensor::<MyBackend, 2>::from_floats([state_array], &self.device);
let q_values = self.q_network.forward(state_tensor);
let q_data = q_values.into_data().to_vec::<f32>().unwrap();