# graph_recognition_w_attn/train_and_eval.py

import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from dataclasses import dataclass, field, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score
import uuid
import pickle
import sys
# Import from your existing model file
from model import ModelConfig, TrainConfig, train_model, get_attention_fn
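
# Expected on-disk layout (inferred from the directory walking and JSON keys used
# below; the "<system>" name is whatever is passed on the command line):
#   datasets/<system>_dataset/
#       agents_<N>/
#           *.json   # each with 'trajectories', 'adjacency_matrix', 'graph_metrics'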

# Override the TrainConfig imported from model with one that includes the extra
# parameters this script needs.
@dataclass
class TrainConfig:
    """Configuration for the training process."""
    learning_rate: float = 1e-3
    epochs: int = 100
    batch_size: int = 4096
    verbose: bool = False  # Set to True to see epoch loss during training
    log: bool = True
    log_epoch_interval: int = 10
    # --- New parameters for this script ---
    data_directory: str = "datasets/" + sys.argv[1] + "_dataset"
    # Threshold for converting attention scores to a binary graph
    f1_threshold: float = -0.4
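
# Usage (illustrative): the positional command-line argument selects the dataset,
# e.g. `python train_and_eval.py kuramoto` resolves data_directory to
# "datasets/kuramoto_dataset". The system name "kuramoto" is only an example; it
# must match a directory produced by the data generation script. Note that the
# default is evaluated when the class is defined, so the argument must be present
# whenever the script is run.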

def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts simulation trajectories into input-output pairs for the model.
    Input: state at time t. Target: state at time t+1.

    Args:
        trajectories: A numpy array of shape (num_sims, num_timesteps, num_agents).
        batch_size: The desired batch size for training.

    Returns:
        A tuple of (batched_inputs, batched_targets).
    """
    # For each simulation, create (input, target) pairs of consecutive timesteps
    all_inputs = []
    all_targets = []
    num_sims, num_timesteps, num_agents = trajectories.shape
    for i_sim in range(num_sims):
        for j_tstep in range(num_timesteps - 1):
            all_inputs.append(trajectories[i_sim, j_tstep, :])
            all_targets.append(trajectories[i_sim, j_tstep + 1, :])

    # Shuffle inputs and targets with the same permutation
    all_indices = np.arange(len(all_inputs))
    np.random.shuffle(all_indices)
    all_inputs = np.array(all_inputs)[all_indices]
    all_targets = np.array(all_targets)[all_indices]

    # Reshape to have a feature dimension
    # Shape -> (total_samples, num_agents, 1)
    full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
    full_dataset_targets = np.expand_dims(all_targets, axis=-1)

    # Create batches, truncating to full batches
    num_samples = full_dataset_inputs.shape[0]
    num_batches = num_samples // batch_size
    truncated_inputs = full_dataset_inputs[:num_batches * batch_size]
    truncated_targets = full_dataset_targets[:num_batches * batch_size]

    # Reshape into batches
    # Shape -> (num_batches, batch_size, num_agents, 1)
    batched_inputs = truncated_inputs.reshape(num_batches, batch_size, num_agents, 1)
    batched_targets = truncated_targets.reshape(num_batches, batch_size, num_agents, 1)
    return batched_inputs, batched_targets
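
def _demo_prepare_data_shapes():
    """Illustrative sanity check (not called by the pipeline): with synthetic
    trajectories of shape (10, 500, 20) and batch_size=4096, the 10 * 499 = 4990
    (input, target) pairs truncate to a single full batch of shape (1, 4096, 20, 1)."""
    demo_trajectories = np.random.rand(10, 500, 20)
    demo_inputs, demo_targets = prepare_data_for_model(demo_trajectories, batch_size=4096)
    assert demo_inputs.shape == (1, 4096, 20, 1)
    assert demo_targets.shape == (1, 4096, 20, 1)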

def calculate_f1_score(
    params: dict,
    model_config: ModelConfig,
    true_graph: np.ndarray,
    threshold: float,
) -> float:
    """
    Extracts the learned attention graph, thresholds it, and computes the F1 score.
    """
    # Get the learned attention matrix (N, N)
    learned_graph_scores = np.array(get_attention_fn(params, model_config))

    # Threshold to get a binary predicted graph
    predicted_graph = (learned_graph_scores > threshold).astype(int)

    # Flatten both graphs to treat this as a binary classification problem.
    # The diagonal is kept; zero it out with np.fill_diagonal on both graphs
    # if self-edges should be excluded from the evaluation.
    true_flat = true_graph.flatten()
    pred_flat = predicted_graph.flatten()
    return f1_score(true_flat, pred_flat)
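
def _demo_threshold_f1():
    """Illustrative example (not called by the pipeline): threshold a small,
    made-up score matrix and compare it against a known adjacency matrix with
    sklearn's f1_score. The numbers are invented for illustration only."""
    demo_scores = np.array([[0.9, -0.2], [0.3, -0.8]])
    demo_true_graph = np.array([[1, 0], [1, 0]])
    demo_predicted = (demo_scores > 0.0).astype(int)
    print(f1_score(demo_true_graph.flatten(), demo_predicted.flatten()))  # prints 1.0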

def main():
    """Main script to run the training and evaluation pipeline."""
    train_config = TrainConfig()

    # Check if the data directory exists
    if not os.path.isdir(train_config.data_directory):
        print(f"Error: Data directory '{train_config.data_directory}' not found.")
        print(f"Please run the data generation script for '{train_config.data_directory}' first.")
        return

    print(f"🚀 Starting training pipeline for '{train_config.data_directory}' data.")

    # Get sorted list of agent directories
    agent_dirs = sorted(
        [d for d in os.listdir(train_config.data_directory) if d.startswith("agents_")],
        key=lambda x: int(x.split('_')[1])
    )

    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
        all_results_for_agent = []
        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")

        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)
            with open(file_path, 'r') as f:
                data = json.load(f)

            # 1. Load and Prepare Data
            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])
            inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)

            # 2. Configure Model
            num_agents = trajectories.shape[-1]
            model_config = ModelConfig()
            model_config.num_agents = num_agents
            model_config.input_dim = 1  # Each agent has a single state value at time t
            model_config.output_dim = 1
            model_config.embedding_dim = 32

            # 3. Train the Model
            # This relies on the modified train_model that returns the final params
            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )

            # 4. Evaluate
            f1 = calculate_f1_score(
                final_params,
                model_config,
                true_graph,
                train_config.f1_threshold
            )

            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }

            # 5. Log Results
            result_log = {
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    # Manually create the dictionary for the model config
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    # TrainConfig is a dataclass, so asdict gives a JSON-serializable dict
                    "training": asdict(train_config)
                }
            }
            all_results_for_agent.append(result_log)

        # 6. Save aggregated results for this agent count
        results_dir = os.path.join(agent_dir_path, "results")
        os.makedirs(results_dir, exist_ok=True)
        output_file = os.path.join(results_dir, "summary_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results_for_agent, f, indent=2)

        # Note: this saves only the parameters from the last graph trained for this agent count
        model_path = os.path.join(results_dir, "model_params")
        os.makedirs(model_path, exist_ok=True)
        with open(os.path.join(model_path, "model_params.pkl"), "wb") as f:
            pickle.dump(final_params, f)

        print(f"✅ Results for {agent_dir_name} saved to {output_file}")

    print("\n🎉 Pipeline finished successfully!")


if __name__ == "__main__":
    main()