Files
graph_recognition_w_attn/train_and_eval.py

290 lines
10 KiB
Python
Raw Normal View History

2025-07-31 01:12:53 -04:00
import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from dataclasses import dataclass, field, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score
import uuid
import pickle
import sys
2025-08-04 12:44:35 -04:00
from config_ import TrainConfig, ModelConfig, NoiseType
from model import train_model, get_attention_fn
2025-07-31 01:12:53 -04:00
2025-08-04 12:44:35 -04:00
def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_config: TrainConfig, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts simulation trajectories into shuffled, batched input-output pairs.

    Input: state at time t. Target: state at time t+1.

    Args:
        key: JAX PRNG key, used only to sample input noise.
        trajectories: A numpy array of shape (num_sims, num_timesteps, num_agents).
        train_config: Supplies noise_type / noise_level for input corruption.
        batch_size: The desired batch size for training.

    Returns:
        A tuple (batched_inputs, batched_targets), each of shape
        (num_batches, batch_size, num_agents, 1). Samples that do not fill
        a complete batch are dropped.

    Raises:
        ValueError: If noise_level > 0 but noise_type is not a supported kind.
    """
    num_sims, num_timesteps, num_agents = trajectories.shape

    # Vectorized (input, target) pairing: inputs are the states at times
    # 0..T-2, targets the states at 1..T-1, flattened over (sim, time).
    # This matches the order the previous Python double-loop produced.
    all_inputs = trajectories[:, :-1, :].reshape(-1, num_agents)
    all_targets = trajectories[:, 1:, :].reshape(-1, num_agents)

    # Shuffle samples (uses NumPy's global RNG, matching prior behavior —
    # the JAX key is reserved for noise below).
    perm = np.random.permutation(all_inputs.shape[0])
    all_inputs = all_inputs[perm]
    all_targets = all_targets[perm]

    # Optionally corrupt the inputs. noise_level is checked first so a zero
    # level short-circuits without touching the noise machinery at all.
    if train_config.noise_level > 0 and train_config.noise_type != NoiseType.NONE:
        noise_shape = all_inputs.shape
        if train_config.noise_type == NoiseType.NORMAL:
            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
        elif train_config.noise_type == NoiseType.UNIFORM:
            noise = jax.random.uniform(
                key, noise_shape,
                minval=-train_config.noise_level,
                maxval=train_config.noise_level,
            )
        else:
            # Previously an unrecognized noise type fell through the elif
            # chain and raised a NameError on `noise`; fail fast instead.
            raise ValueError(f"Unsupported noise type: {train_config.noise_type}")
        # Non-in-place add: `+=` would attempt an unsafe float->int cast
        # when the trajectories are integer-typed.
        all_inputs = all_inputs + np.asarray(noise)

    # Add a trailing feature dimension -> (num_samples, num_agents, 1).
    all_inputs = np.expand_dims(all_inputs, axis=-1)
    full_dataset_targets = np.expand_dims(all_targets, axis=-1)

    # Truncate to whole batches, then reshape to
    # (num_batches, batch_size, num_agents, 1).
    num_samples = all_inputs.shape[0]
    num_batches = num_samples // batch_size
    truncated_inputs = all_inputs[:num_batches * batch_size]
    truncated_targets = full_dataset_targets[:num_batches * batch_size]
    batched_inputs = truncated_inputs.reshape(num_batches, batch_size, num_agents, 1)
    batched_targets = truncated_targets.reshape(num_batches, batch_size, num_agents, 1)
    return batched_inputs, batched_targets
2025-09-01 14:46:34 -04:00
def f1_score_np(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Compute the binary F1 score between two label arrays.

    Parameters
    ----------
    y_true : np.ndarray
        Ground truth (correct) labels.
    y_pred : np.ndarray
        Predicted labels.

    Returns
    -------
    float
        The F1 score; 0.0 whenever precision + recall is zero.
    """
    # Coerce both label arrays to binary integer form.
    truth = np.asarray(y_true).astype(int)
    guess = np.asarray(y_pred).astype(int)

    # Confusion-matrix counts for the positive class.
    true_pos = int(np.sum((truth == 1) & (guess == 1)))
    false_pos = int(np.sum((truth == 0) & (guess == 1)))
    false_neg = int(np.sum((truth == 1) & (guess == 0)))

    precision = true_pos / (true_pos + false_pos) if true_pos + false_pos else 0.0
    recall = true_pos / (true_pos + false_neg) if true_pos + false_neg else 0.0

    # F1 is the harmonic mean of precision and recall; degenerate case is 0.
    denom = precision + recall
    return 2.0 * precision * recall / denom if denom else 0.0
2025-07-31 01:12:53 -04:00
def calculate_f1_score(
    params: dict,
    model_config: ModelConfig,
    true_graph: np.ndarray,
    threshold: float
) -> float:
    """
    Extracts the learned attention graph, thresholds it, and computes the F1 score.

    The raw (N, N) attention scores are binarized at ``threshold`` and the
    result is compared entry-wise against ``true_graph`` as a flat binary
    classification problem.
    """
    # Pull the learned attention matrix out of the trained parameters.
    attention_scores = np.array(get_attention_fn(params, model_config))

    # An edge is predicted wherever its attention score exceeds the threshold.
    predicted_graph = (attention_scores > threshold).astype(int)

    # Flatten both adjacency matrices and score them as binary labels.
    return f1_score_np(true_graph.flatten(), predicted_graph.flatten())
2025-07-31 01:12:53 -04:00
def main():
    """Main script to run the training and evaluation pipeline.

    Usage: python train_and_eval.py <dataset_name>

    For every "agents_<N>" directory inside datasets/<dataset_name>_dataset,
    trains one model per graph JSON file, evaluates how well the learned
    attention recovers the true adjacency matrix (F1 score), and writes a
    per-agent summary JSON plus pickled model parameters under
    results/<noise_type>/<noise_level>/.
    """
    # Guard: the dataset name is a required positional argument.
    if len(sys.argv) < 2:
        print("Usage: python train_and_eval.py <dataset_name>")
        return

    train_config = TrainConfig()
    train_config.data_directory = "datasets/" + sys.argv[1] + "_dataset"

    # Check if the data directory exists
    if not os.path.isdir(train_config.data_directory):
        print(f"Error: Data directory '{train_config.data_directory}' not found.")
        print(f"Please run the data generation script for '{train_config.data_directory}' first.")
        return

    print(f"Starting training pipeline for '{train_config.data_directory}' data.")

    # Agent directories are named "agents_<N>"; sort numerically by N.
    agent_dirs = sorted(
        [d for d in os.listdir(train_config.data_directory) if d.startswith("agents_")],
        key=lambda x: int(x.split('_')[1])
    )

    starter_key = jax.random.PRNGKey(49)

    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
        all_results_for_agent = []
        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])

        # Results are grouped under results/<noise_type>/<noise_level>/.
        results_dir = os.path.join(agent_dir_path, "results")
        os.makedirs(results_dir, exist_ok=True)
        sub_results_dir = os.path.join(
            results_dir, str(train_config.noise_type), str(train_config.noise_level)
        )
        os.makedirs(sub_results_dir, exist_ok=True)

        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)
            with open(file_path, 'r') as f:
                data = json.load(f)

            # 1. Load and Prepare Data
            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])
            # Split the key so every graph gets independent input noise.
            starter_key, data_key = jax.random.split(starter_key)
            inputs, targets = prepare_data_for_model(
                data_key, trajectories, train_config, train_config.batch_size
            )

            # 2. Configure Model
            num_agents = trajectories.shape[-1]
            model_config = ModelConfig()
            model_config.num_agents = num_agents
            model_config.input_dim = 1  # Each agent has a single state value at time t
            model_config.output_dim = 1
            model_config.embedding_dim = 32

            # 3. Train the Model
            # This relies on the modified train_model that returns final params
            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )

            # 4. Evaluate
            f1 = calculate_f1_score(
                final_params,
                model_config,
                true_graph,
                train_config.f1_threshold
            )

            # Losses are array scalars; convert to plain floats for JSON.
            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }

            # 5. Log Results
            result_log = {
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    # Serialize the model config field-by-field.
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    "training": {
                        "epochs": train_config.epochs,
                        "learning_rate": train_config.learning_rate,
                        "verbose": train_config.verbose,
                        "log": train_config.log,
                        "log_epoch_interval": train_config.log_epoch_interval,
                        "noise_type": str(train_config.noise_type),
                        "noise_level": train_config.noise_level,
                    }
                },
                "raw_attention": np.array(get_attention_fn(final_params, model_config)).tolist()
            }

            # Persist the trained parameters alongside the summary.
            model_path = os.path.join(sub_results_dir, "model_params")
            os.makedirs(model_path, exist_ok=True)
            with open(os.path.join(model_path, "model_for_" + graph_file_name + ".pkl"), "wb") as f:
                pickle.dump(final_params, f)

            all_results_for_agent.append(result_log)

        # One summary file per agent-count directory.
        output_file = os.path.join(sub_results_dir, "summary_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results_for_agent, f, indent=2)
        print(f"✅ Results for {agent_dir_name} saved to {output_file}")

    print("\n🎉 Pipeline finished successfully!")


if __name__ == "__main__":
    main()