import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from dataclasses import dataclass, field, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score
import uuid
import pickle
import sys

from config_ import TrainConfig, ModelConfig, NoiseType
from model import train_model, get_attention_fn


def prepare_data_for_model(
    key: jax.Array,
    trajectories: np.ndarray,
    train_config: TrainConfig,
    batch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts simulation trajectories into batched input-output pairs.

    Input: state at time t. Target: state at time t+1. Pairs from all
    simulations are pooled, shuffled, optionally perturbed with input
    noise, and truncated to whole batches.

    Args:
        key: PRNG key used to draw the input noise.
        trajectories: Array of shape (num_sims, num_timesteps, num_agents).
        train_config: Training configuration providing noise type and level.
        batch_size: The desired batch size for training.

    Returns:
        A tuple (batched_inputs, batched_targets), each of shape
        (num_batches, batch_size, num_agents, 1).

    Raises:
        ValueError: If the configured noise type is not supported.
    """
    num_sims, num_timesteps, num_agents = trajectories.shape

    # Vectorized construction of (state_t, state_{t+1}) pairs; produces the
    # same sample order as looping over sims then timesteps.
    all_inputs = trajectories[:, :-1, :].reshape(-1, num_agents)
    all_targets = trajectories[:, 1:, :].reshape(-1, num_agents)

    # Shuffle with one shared permutation so inputs/targets stay aligned.
    permutation = np.random.permutation(all_inputs.shape[0])
    all_inputs = all_inputs[permutation]
    all_targets = all_targets[permutation]

    if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
        noise_shape = all_inputs.shape
        if train_config.noise_type == NoiseType.NORMAL:
            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
        elif train_config.noise_type == NoiseType.UNIFORM:
            noise = jax.random.uniform(
                key,
                noise_shape,
                minval=-train_config.noise_level,
                maxval=train_config.noise_level,
            )
        else:
            # Previously an unhandled noise type fell through to an undefined
            # `noise` variable (NameError); fail loudly and explicitly instead.
            raise ValueError(f"Unsupported noise type: {train_config.noise_type}")
        # Out-of-place add: `+=` would raise on integer-dtype trajectories
        # because float noise cannot be cast in place under same-kind rules.
        all_inputs = all_inputs + np.asarray(noise)

    # Trailing feature dimension expected by the model: (samples, agents, 1).
    all_inputs = np.expand_dims(all_inputs, axis=-1)
    all_targets = np.expand_dims(all_targets, axis=-1)

    # Truncate to a whole number of batches, then reshape to
    # (num_batches, batch_size, num_agents, 1).
    num_samples = all_inputs.shape[0]
    num_batches = num_samples // batch_size
    batched_inputs = all_inputs[: num_batches * batch_size].reshape(
        num_batches, batch_size, num_agents, 1
    )
    batched_targets = all_targets[: num_batches * batch_size].reshape(
        num_batches, batch_size, num_agents, 1
    )
    return batched_inputs, batched_targets


def f1_score_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Compute the binary F1 score between two numpy arrays.

    Parameters
    ----------
    y_true : np.ndarray
        Ground truth (correct) labels, coercible to {0, 1} ints.
    y_pred : np.ndarray
        Predicted labels, coercible to {0, 1} ints.

    Returns
    -------
    float
        The F1 score; 0.0 when precision + recall is zero.
    """
    # Ensure binary integer arrays (0 or 1).
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)

    # True positives, false positives, and false negatives.
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))

    # Precision and recall, guarding against zero denominators.
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)


def calculate_f1_score(
    params: dict, model_config: ModelConfig, true_graph: np.ndarray, threshold: float
) -> float:
    """
    Extracts the learned attention graph, thresholds it, and computes the
    F1 score against the true adjacency matrix.

    Args:
        params: Trained model parameters.
        model_config: Model configuration used to extract attention.
        true_graph: Ground-truth adjacency matrix of shape (N, N).
        threshold: Scores strictly above this value count as edges.

    Returns:
        F1 score treating edge prediction as binary classification.
    """
    # Learned attention matrix (N, N).
    learned_graph_scores = np.array(get_attention_fn(params, model_config))

    # Threshold to a binary predicted graph. The diagonal is deliberately
    # left in place (self-edges are part of the comparison as written).
    predicted_graph = (learned_graph_scores > threshold).astype(int)

    # Flatten both graphs to treat this as a binary classification problem.
    return f1_score_np(true_graph.flatten(), predicted_graph.flatten())


def main():
    """Main script to run the training and evaluation pipeline."""
    # Guard against a missing CLI argument instead of an IndexError.
    if len(sys.argv) < 2:
        print("Usage: python <script> <dataset_name>")
        return

    train_config = TrainConfig()
    train_config.data_directory = "datasets/" + sys.argv[1] + "_dataset"

    # Check if the data directory exists.
    if not os.path.isdir(train_config.data_directory):
        print(f"Error: Data directory '{train_config.data_directory}' not found.")
        print(f"Please run the data generation script for '{train_config.data_directory}' first.")
        return

    print(f"Starting training pipeline for '{train_config.data_directory}' data.")

    # Agent directories sorted numerically by agent count ("agents_<N>").
    agent_dirs = sorted(
        [d for d in os.listdir(train_config.data_directory) if d.startswith("agents_")],
        key=lambda x: int(x.split('_')[1])
    )

    starter_key = jax.random.PRNGKey(49)

    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
        all_results_for_agent = []
        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])

        # Results live under results/<noise_type>/<noise_level>/.
        results_dir = os.path.join(agent_dir_path, "results")
        os.makedirs(results_dir, exist_ok=True)
        sub_results_dir = os.path.join(
            results_dir, str(train_config.noise_type), str(train_config.noise_level)
        )
        os.makedirs(sub_results_dir, exist_ok=True)

        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)
            with open(file_path, 'r') as f:
                data = json.load(f)

            # 1. Load and prepare data.
            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])
            starter_key, data_key = jax.random.split(starter_key)
            inputs, targets = prepare_data_for_model(
                data_key, trajectories, train_config, train_config.batch_size
            )

            # 2. Configure model.
            model_config = ModelConfig()
            model_config.num_agents = trajectories.shape[-1]
            model_config.input_dim = 1   # each agent has a single state value at time t
            model_config.output_dim = 1
            model_config.embedding_dim = 32

            # 3. Train the model (train_model returns final params and logs).
            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )

            # 4. Evaluate.
            f1 = calculate_f1_score(
                final_params, model_config, true_graph, train_config.f1_threshold
            )

            # Convert logged JAX/array losses into plain Python floats for JSON.
            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }

            # 5. Log results.
            result_log = {
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    # Manually create the dictionary for the model config.
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    "training": {
                        "epochs": train_config.epochs,
                        "learning_rate": train_config.learning_rate,
                        "verbose": train_config.verbose,
                        "log": train_config.log,
                        "log_epoch_interval": train_config.log_epoch_interval,
                        "noise_type": str(train_config.noise_type),
                        "noise_level": train_config.noise_level,
                    }
                },
                "raw_attention": np.array(get_attention_fn(final_params, model_config)).tolist()
            }

            # Persist the trained parameters alongside the results.
            model_path = os.path.join(sub_results_dir, "model_params")
            os.makedirs(model_path, exist_ok=True)
            with open(os.path.join(model_path, "model_for_" + graph_file_name + ".pkl"), "wb") as f:
                pickle.dump(final_params, f)

            all_results_for_agent.append(result_log)

        output_file = os.path.join(sub_results_dir, "summary_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results_for_agent, f, indent=2)
        print(f"āœ… Results for {agent_dir_name} saved to {output_file}")

    print("\nšŸŽ‰ Pipeline finished successfully!")


if __name__ == "__main__":
    main()