import os
import json
import jax
import numpy as np
from dataclasses import dataclass, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score

# Import from your existing model file
from model import ModelConfig, train_model, get_attention_fn


# --- MODIFICATION: Enums for clarity and safety ---
class DataSource(Enum):
    CONSENSUS = "consensus"
    KURAMOTO = "kuramoto"


class NoiseType(Enum):
    NONE = "none"
    NORMAL = "normal"
    UNIFORM = "uniform"


# --- MODIFICATION: Updated TrainConfig ---
@dataclass
class TrainConfig:
    """Configuration for the training process."""
    learning_rate: float = 1e-3
    epochs: int = 50
    batch_size: int = 64
    verbose: bool = False
    log: bool = True
    log_epoch_interval: int = 10

    # Data and noise parameters
    data_source: DataSource = DataSource.CONSENSUS
    noise_type: NoiseType = NoiseType.NORMAL
    noise_level: float = 0.1  # Stddev for Normal, half-width for Uniform

    # Evaluation parameter
    f1_threshold: float = 0.5

    @property
    def source_data_directory(self) -> str:
        """The directory where the source dataset is located."""
        return f"{self.data_source.value}_dataset"

    @property
    def results_directory_name(self) -> str:
        """Generates a unique output directory name for separate logging."""
        if self.noise_type == NoiseType.NONE or self.noise_level == 0:
            return "results_noiseless"
        return f"results_noise_{self.noise_type.value}_{self.noise_level}"


# --- MODIFICATION: Updated data prep function to add noise ---
def prepare_data_for_model(
    trajectories: np.ndarray, key: jax.Array, train_config: TrainConfig
) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts trajectories to input-output pairs and adds noise to the inputs.
    """
    all_inputs, all_targets = [], []
    num_sims, num_timesteps, num_agents = trajectories.shape
    for sim_idx in range(num_sims):
        # Inputs are steps t = 0..T-2; targets are the next step t = 1..T-1.
        all_inputs.append(trajectories[sim_idx, :-1, :])
        all_targets.append(trajectories[sim_idx, 1:, :])

    full_dataset_inputs = np.concatenate(all_inputs, axis=0)
    full_dataset_targets = np.concatenate(all_targets, axis=0)

    # --- NOISE INJECTION BLOCK ---
    if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
        noise_shape = full_dataset_inputs.shape
        if train_config.noise_type == NoiseType.NORMAL:
            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
        elif train_config.noise_type == NoiseType.UNIFORM:
            noise = jax.random.uniform(
                key, noise_shape,
                minval=-train_config.noise_level,
                maxval=train_config.noise_level,
            )
        # Noise is added to the inputs only; targets stay clean.
        full_dataset_inputs = full_dataset_inputs + np.array(noise)
    # --- END NOISE BLOCK ---

    full_dataset_inputs = np.expand_dims(full_dataset_inputs, axis=-1)
    full_dataset_targets = np.expand_dims(full_dataset_targets, axis=-1)

    # Truncate to a whole number of batches, then reshape.
    num_samples = full_dataset_inputs.shape[0]
    num_batches = num_samples // train_config.batch_size
    truncated_inputs = full_dataset_inputs[:num_batches * train_config.batch_size]
    truncated_targets = full_dataset_targets[:num_batches * train_config.batch_size]

    batched_inputs = truncated_inputs.reshape(num_batches, train_config.batch_size, num_agents, 1)
    batched_targets = truncated_targets.reshape(num_batches, train_config.batch_size, num_agents, 1)
    return batched_inputs, batched_targets


def calculate_f1_score(
    params: dict, model_config: ModelConfig, true_graph: np.ndarray, threshold: float
) -> float:
    """Extracts the learned graph and computes the F1 score."""
    learned_scores = np.array(get_attention_fn(params, model_config))

    # Min-max normalize the scores to [0, 1] so a fixed threshold is meaningful.
    flat_scores = learned_scores.flatten()
    min_s, max_s = flat_scores.min(), flat_scores.max()
    if max_s > min_s:
        learned_scores = (learned_scores - min_s) / (max_s - min_s)

    predicted_graph = (learned_scores > threshold).astype(int)

    # Ignore self-loops in both graphs before scoring. Copy the true graph
    # so the caller's array is not mutated in place.
    np.fill_diagonal(predicted_graph, 0)
    true_graph = true_graph.copy()
    np.fill_diagonal(true_graph, 0)
    return f1_score(true_graph.flatten(), predicted_graph.flatten())
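# --- Optional helper (illustrative sketch, not called by the pipeline) ---
# calculate_f1_score evaluates a single fixed threshold (0.5 by default).
# If the normalized attention scores are skewed, a sweep over candidate
# thresholds shows how sensitive the F1 score is to that choice. The helper
# below uses only functions already defined in this file; its name and the
# `num_thresholds` knob are assumptions introduced for this example.
def sweep_f1_thresholds(
    params: dict, model_config: ModelConfig, true_graph: np.ndarray,
    num_thresholds: int = 19
) -> tuple[float, float]:
    """Returns (best_threshold, best_f1) over an evenly spaced grid."""
    thresholds = np.linspace(0.05, 0.95, num_thresholds)
    scored = [
        (t, calculate_f1_score(params, model_config, true_graph, t))
        for t in thresholds
    ]
    return max(scored, key=lambda pair: pair[1])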
def main():
    """Main script to run the training and evaluation pipeline."""
    # Configure your training run here
    train_config = TrainConfig(
        noise_type=NoiseType.NORMAL,
        noise_level=0.1,
    )

    if not os.path.isdir(train_config.source_data_directory):
        print(f"Error: Source data '{train_config.source_data_directory}' not found.")
        return

    print(f"šŸš€ Starting training pipeline for '{train_config.data_source.value}' data.")
    print(f"Noise configuration: type={train_config.noise_type.value}, level={train_config.noise_level}")

    # --- MODIFICATION: Main JAX key for noise generation ---
    main_key = jax.random.PRNGKey(42)

    agent_dirs = sorted(
        [d for d in os.listdir(train_config.source_data_directory) if d.startswith("agents_")],
        key=lambda x: int(x.split('_')[1])
    )

    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(train_config.source_data_directory, agent_dir_name)
        all_results_for_agent = []

        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")

        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)
            with open(file_path, 'r') as f:
                data = json.load(f)

            # Fresh subkey per graph so each run draws independent noise.
            main_key, data_key = jax.random.split(main_key)

            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])

            # --- MODIFICATION: Pass key and config to data prep ---
            inputs, targets = prepare_data_for_model(trajectories, data_key, train_config)

            num_agents = int(trajectories.shape[-1])
            model_config = ModelConfig()
            model_config.num_agents = num_agents
            model_config.input_dim = 1
            model_config.output_dim = 1
            model_config.embedding_dim = 32

            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )

            f1 = calculate_f1_score(final_params, model_config, true_graph, train_config.f1_threshold)

            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }

            # asdict() keeps Enum members as-is, which json.dump cannot
            # serialize, so convert them to their string values first.
            train_config_serializable = asdict(train_config)
            train_config_serializable['data_source'] = train_config.data_source.value
            train_config_serializable['noise_type'] = train_config.noise_type.value

            result_log = {
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    "training": train_config_serializable
                }
            }
            all_results_for_agent.append(result_log)

        # --- MODIFICATION: Save to a separate results directory ---
        results_dir = os.path.join(agent_dir_path, train_config.results_directory_name)
        os.makedirs(results_dir, exist_ok=True)
        output_file = os.path.join(results_dir, "summary_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results_for_agent, f, indent=2)
        print(f"āœ… Results for {agent_dir_name} saved to {output_file}")

    print("\nšŸŽ‰ Pipeline finished successfully!")


if __name__ == "__main__":
    main()
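
# --- Optional sanity check (illustrative sketch, not part of the pipeline) ---
# Verifies the shape invariants of prepare_data_for_model on synthetic data:
# each simulation of T steps yields T-1 (input, target) pairs, and the pairs
# are truncated to a whole number of batches. The function name and all
# numbers below are made up for the example; call it manually if useful.
def _smoke_test_prepare_data():
    cfg = TrainConfig(batch_size=4, noise_type=NoiseType.UNIFORM, noise_level=0.05)
    key = jax.random.PRNGKey(0)
    trajectories = np.random.rand(3, 11, 5)  # 3 sims, 11 steps, 5 agents
    inputs, targets = prepare_data_for_model(trajectories, key, cfg)
    # 3 sims x 10 pairs = 30 samples; 30 // 4 = 7 batches of 4.
    assert inputs.shape == (7, 4, 5, 1)
    assert targets.shape == (7, 4, 5, 1)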