import os import json import jax import jax.numpy as jnp import numpy as np from dataclasses import dataclass, field, asdict from enum import Enum from tqdm import tqdm from sklearn.metrics import f1_score import uuid import pickle import sys # Import from your existing model file from model import ModelConfig, TrainConfig, train_model, get_attention_fn # Define an Enum for the data source for clarity and type safety # Overwrite the original TrainConfig to include our new parameters class TrainConfig: """Configuration for the training process.""" learning_rate: float = 1e-3 epochs: int = 100 batch_size: int = 4096 verbose: bool = False # Set to True to see epoch loss during training log: bool = True log_epoch_interval: int = 10 # --- New parameters for this script --- data_directory:str = "datasets/" + sys.argv[1] + "_dataset" # Threshold for converting attention scores to a binary graph f1_threshold: float = -0.4 def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]: """ Converts simulation trajectories into input-output pairs for the model. Input: state at time t. Target: state at time t+1. Args: trajectories: A numpy array of shape (num_sims, num_timesteps, num_agents). batch_size: The desired batch size for training. Returns: A tuple of (batched_inputs, batched_targets). """ # For each simulation, create (input, target) pairs all_inputs = [] all_targets = [] num_sims, num_timesteps, num_agents = trajectories.shape # trajectories = np.reshape(trajectories, shape=(num_sims * num_timesteps, num_agents)) for i_sim in range(num_sims): for j_tstep in range(num_timesteps-1): input = trajectories[i_sim, j_tstep, :] target = trajectories[i_sim, j_tstep + 1, :] all_inputs.append(input) all_targets.append(target) all_indices = np.arange(len(all_inputs)) np.random.shuffle(all_indices) all_inputs = np.array(all_inputs) all_targets = np.array(all_targets) all_inputs = all_inputs[all_indices] all_targets = all_targets[all_indices] # for sim_idx in range(num_sims): # # Input is state from t=0 to t=T-2 # inputs = trajectories[sim_idx, :-1, :] # # Target is state from t=1 to t=T-1 # targets = trajectories[sim_idx, 1:, :] # all_inputs.append(inputs) # all_targets.append(targets) # Concatenate all pairs from all simulations # Shape -> (num_sims * (num_timesteps - 1), num_agents) # full_dataset_inputs = np.concatenate(all_inputs, axis=0) # full_dataset_targets = np.concatenate(all_targets, axis=0) # Reshape to have a feature dimension # Shape -> (total_samples, num_agents, 1) full_dataset_inputs = np.expand_dims(all_inputs, axis=-1) full_dataset_targets = np.expand_dims(all_targets, axis=-1) # Create batches num_samples = full_dataset_inputs.shape[0] num_batches = num_samples // batch_size # Truncate to full batches truncated_inputs = full_dataset_inputs[:num_batches * batch_size] truncated_targets = full_dataset_targets[:num_batches * batch_size] # Reshape into batches # Shape -> (num_batches, batch_size, num_agents, 1) batched_inputs = truncated_inputs.reshape(num_batches, batch_size, num_agents, 1) batched_targets = truncated_targets.reshape(num_batches, batch_size, num_agents, 1) return batched_inputs, batched_targets def calculate_f1_score( params: dict, model_config: ModelConfig, true_graph: np.ndarray, threshold: float ) -> float: """ Extracts the learned attention graph, thresholds it, and computes the F1 score. """ # Get the learned attention matrix (N, N) learned_graph_scores = np.array(get_attention_fn(params, model_config)) # Normalize scores to [0, 1] for consistent thresholding (optional but good practice) # This uses min-max scaling on the flattened array # flat_scores = learned_graph_scores.flatten() # min_s, max_s = flat_scores.min(), flat_scores.max() # if max_s > min_s: # learned_graph_scores = (learned_graph_scores - min_s) / (max_s - min_s) # Threshold to get a binary predicted graph predicted_graph = (learned_graph_scores > threshold).astype(int) # The diagonal is not part of the prediction task # np.fill_diagonal(predicted_graph, 0) # np.fill_diagonal(true_graph, 0) # Flatten both graphs to treat this as a binary classification problem true_flat = true_graph.flatten() pred_flat = predicted_graph.flatten() return f1_score(true_flat, pred_flat) def main(): """Main script to run the training and evaluation pipeline.""" train_config = TrainConfig() # Check if the data directory exists if not os.path.isdir(train_config.data_directory): print(f"Error: Data directory '{train_config.data_directory}' not found.") print(f"Please run the data generation script for '{train_config.data_directory}' first.") return print(f"šŸš€ Starting training pipeline for '{train_config.data_directory}' data.") # Get sorted list of agent directories agent_dirs = sorted( [d for d in os.listdir(train_config.data_directory) if d.startswith("agents_")], key=lambda x: int(x.split('_')[1]) ) for agent_dir_name in agent_dirs: agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name) all_results_for_agent = [] graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")]) print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...") for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"): file_path = os.path.join(agent_dir_path, graph_file_name) with open(file_path, 'r') as f: data = json.load(f) # 1. Load and Prepare Data trajectories = np.array(data['trajectories']) s, l, n = trajectories.shape # trajectories = trajectories.T # np.random.shuffle(trajectories) # trajectories = np.random.shuffle(trajectories) true_graph = np.array(data['adjacency_matrix']) inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size) # 2. Configure Model num_agents = trajectories.shape[-1] model_config = ModelConfig( ) model_config.num_agents=num_agents model_config.input_dim=1 # Each agent has a single state value at time t model_config.output_dim=1 model_config.embedding_dim=32 # 3. Train the Model # This relies on the modified train_model that returns final params final_params, train_logs = train_model( config=model_config, inputs=inputs, targets=targets, true_graph=true_graph, train_config=train_config ) # 4. Evaluate f1 = calculate_f1_score( final_params, model_config, true_graph, train_config.f1_threshold ) loss_history_serializable = { epoch: [loss.item() for loss in losses] for epoch, losses in train_logs['loss_history'].items() } random_id = str(uuid.uuid4()) # 5. Log Results result_log = { # "model_name": random_id, "source_file": graph_file_name, "graph_metrics": data['graph_metrics'], "f1_score": f1, "training_loss_history": loss_history_serializable, "config": { # Manually create the dictionary for the model config "model": { "num_agents": model_config.num_agents, "input_dim": model_config.input_dim, "output_dim": model_config.output_dim, "embedding_dim": model_config.embedding_dim }, # This is correct because TrainConfig is a dataclass "training": vars(train_config) } } result_final_params = final_params all_results_for_agent.append(result_log) # 6. Save aggregated results for this agent count results_dir = os.path.join(agent_dir_path, "results") os.makedirs(results_dir, exist_ok=True) output_file = os.path.join(results_dir, "summary_results.json") with open(output_file, 'w') as f: json.dump(all_results_for_agent, f, indent=2) model_path = os.path.join(results_dir, "model_params") os.makedirs(model_path, exist_ok=True) with open(os.path.join(model_path,"model_params" + ".pkl"), "wb") as f: pickle.dump(final_params, f) print(f"āœ… Results for {agent_dir_name} saved to {output_file}") print("\nšŸŽ‰ Pipeline finished successfully!") if __name__ == "__main__": main()