replicated mecc
This commit is contained in:
259
train_and_eval.py
Normal file
259
train_and_eval.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import os
|
||||
import json
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from enum import Enum
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import f1_score
|
||||
import uuid
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
|
||||
# Import from your existing model file
|
||||
from model import ModelConfig, TrainConfig, train_model, get_attention_fn
|
||||
|
||||
# NOTE: no data-source Enum is defined in this file; the dataset is selected by the first command-line argument (see TrainConfig.data_directory).
|
||||
|
||||
# Overwrite the original TrainConfig to include our new parameters
|
||||
|
||||
def _default_data_directory() -> str:
    """Build the dataset directory path from the first command-line argument.

    Returns:
        The path ``"datasets/<name>_dataset"`` where ``<name>`` is ``sys.argv[1]``.

    Raises:
        IndexError: If no dataset name was passed on the command line.
    """
    if len(sys.argv) < 2:
        raise IndexError(
            "Missing dataset name: pass it as the first command-line argument "
            "(data is read from 'datasets/<name>_dataset')."
        )
    return "datasets/" + sys.argv[1] + "_dataset"


@dataclass
class TrainConfig:
    """Configuration for the training process.

    This class shadows the ``TrainConfig`` imported from ``model`` and adds the
    script-specific fields ``data_directory`` and ``f1_threshold``.

    It is a real dataclass so that every field becomes an instance attribute;
    without the decorator ``vars(TrainConfig())`` is empty and the configuration
    cannot be serialised into the result logs.
    """

    learning_rate: float = 1e-3
    epochs: int = 100
    batch_size: int = 4096
    verbose: bool = False  # Set to True to see epoch loss during training
    log: bool = True
    log_epoch_interval: int = 10

    # --- New parameters for this script ---
    # Resolved lazily (at instantiation, not at import) so the module can be
    # imported without command-line arguments.
    data_directory: str = field(default_factory=_default_data_directory)
    # Threshold for converting attention scores to a binary graph
    f1_threshold: float = -0.4
|
||||
|
||||
|
||||
def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts simulation trajectories into shuffled, batched input-output pairs.

    Input: state at time t. Target: state at time t+1.

    Args:
        trajectories: An array of shape (num_sims, num_timesteps, num_agents).
        batch_size: The desired batch size for training (must be positive).

    Returns:
        A tuple of (batched_inputs, batched_targets), each of shape
        (num_batches, batch_size, num_agents, 1). Samples that do not fill a
        complete batch are dropped, so both arrays have zero batches when fewer
        than ``batch_size`` pairs are available.

    Raises:
        ValueError: If ``trajectories`` is not 3-dimensional or ``batch_size``
            is not positive.
    """
    trajectories = np.asarray(trajectories)
    if trajectories.ndim != 3:
        raise ValueError(
            "trajectories must have shape (num_sims, num_timesteps, num_agents), "
            f"got an array with shape {trajectories.shape}"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_sims, num_timesteps, num_agents = trajectories.shape
    num_pairs = num_sims * max(num_timesteps - 1, 0)

    # Pair every state with its successor. Slicing keeps the simulation-major,
    # time-minor ordering that an explicit double loop would produce.
    # Shape -> (num_sims * (num_timesteps - 1), num_agents)
    flat_inputs = trajectories[:, :-1, :].reshape(num_pairs, num_agents)
    flat_targets = trajectories[:, 1:, :].reshape(num_pairs, num_agents)

    # Shuffle inputs and targets with one shared permutation so pairs stay aligned.
    permutation = np.arange(num_pairs)
    np.random.shuffle(permutation)

    # Truncate to full batches
    num_batches = num_pairs // batch_size
    kept = permutation[:num_batches * batch_size]

    # Add the trailing feature dimension and group into batches.
    # Shape -> (num_batches, batch_size, num_agents, 1)
    batched_inputs = flat_inputs[kept].reshape(num_batches, batch_size, num_agents, 1)
    batched_targets = flat_targets[kept].reshape(num_batches, batch_size, num_agents, 1)

    return batched_inputs, batched_targets
|
||||
|
||||
def calculate_f1_score(
    params: dict,
    model_config: ModelConfig,
    true_graph: np.ndarray,
    threshold: float
) -> float:
    """
    Extracts the learned attention graph, thresholds it, and computes the F1 score.

    Args:
        params: Model parameters passed to ``get_attention_fn``.
        model_config: Model configuration passed to ``get_attention_fn``.
        true_graph: Ground-truth adjacency matrix (array-like) with as many
            entries as the learned attention matrix.
        threshold: Raw attention scores strictly greater than this value are
            treated as predicted edges.

    Returns:
        The F1 score of the predicted edges against ``true_graph``, computed
        over all matrix entries (the diagonal is included).

    Raises:
        ValueError: If the learned and true graphs have different sizes.
    """
    # Get the learned attention matrix
    learned_graph_scores = np.array(get_attention_fn(params, model_config))
    true_graph = np.asarray(true_graph)

    if learned_graph_scores.size != true_graph.size:
        raise ValueError(
            f"Learned graph has shape {learned_graph_scores.shape} but the true "
            f"graph has shape {true_graph.shape}; they must have the same size."
        )

    # Threshold the raw (un-normalised) scores to get a binary predicted graph
    predicted_graph = (learned_graph_scores > threshold).astype(int)

    # Flatten both graphs to treat this as a binary classification problem
    true_flat = true_graph.flatten()
    pred_flat = predicted_graph.flatten()

    # Convert to a builtin float so the result is always JSON-serialisable.
    return float(f1_score(true_flat, pred_flat))
|
||||
|
||||
def _config_to_dict(config) -> dict:
    """Return the public, non-callable settings of *config* as a plain dict.

    Dataclass instances are converted with ``dataclasses.asdict``. For plain
    classes whose settings live in class attributes (where ``vars(instance)``
    would be empty) the public attributes are collected instead.
    """
    from dataclasses import asdict, is_dataclass

    if is_dataclass(config) and not isinstance(config, type):
        return asdict(config)
    return {
        name: getattr(config, name)
        for name in dir(config)
        if not name.startswith("_") and not callable(getattr(config, name))
    }


def main():
    """Main script to run the training and evaluation pipeline.

    For every ``agents_<N>`` sub-directory of the configured data directory,
    trains one model per ``*.json`` graph file, scores the learned attention
    graph against the true adjacency matrix, and writes the aggregated results
    to ``<agent_dir>/results/summary_results.json``. The parameters of the last
    model trained in that directory are pickled to
    ``<agent_dir>/results/model_params/model_params.pkl``.
    """
    train_config = TrainConfig()
    data_directory = train_config.data_directory

    # Check if the data directory exists
    if not os.path.isdir(data_directory):
        print(f"Error: Data directory '{data_directory}' not found.")
        print(f"Please run the data generation script for '{data_directory}' first.")
        return

    print(f"🚀 Starting training pipeline for '{data_directory}' data.")

    # Get list of agent directories, sorted numerically by agent count
    agent_dirs = sorted(
        [d for d in os.listdir(data_directory) if d.startswith("agents_")],
        key=lambda x: int(x.split('_')[1])
    )

    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(data_directory, agent_dir_name)

        all_results_for_agent = []
        # Reset per directory so parameters from a previous agent count can
        # never be saved under this one.
        final_params = None

        graph_files = sorted(f for f in os.listdir(agent_dir_path) if f.endswith(".json"))

        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")

        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)

            with open(file_path, 'r', encoding="utf-8") as f:
                data = json.load(f)

            # 1. Load and Prepare Data
            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])
            inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)

            # 2. Configure Model
            model_config = ModelConfig()
            model_config.num_agents = trajectories.shape[-1]
            model_config.input_dim = 1  # Each agent has a single state value at time t
            model_config.output_dim = 1
            model_config.embedding_dim = 32

            # 3. Train the Model
            # This relies on the modified train_model that returns final params
            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )

            # 4. Evaluate
            f1 = calculate_f1_score(
                final_params,
                model_config,
                true_graph,
                train_config.f1_threshold
            )

            # Array scalars are not JSON-serialisable; convert to Python numbers.
            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }

            # 5. Log Results
            all_results_for_agent.append({
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    # Manually create the dictionary for the model config
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    # vars() would be empty if the settings are class
                    # attributes, so collect them explicitly.
                    "training": _config_to_dict(train_config)
                }
            })

        # 6. Save aggregated results for this agent count
        results_dir = os.path.join(agent_dir_path, "results")
        os.makedirs(results_dir, exist_ok=True)

        output_file = os.path.join(results_dir, "summary_results.json")
        with open(output_file, 'w', encoding="utf-8") as f:
            json.dump(all_results_for_agent, f, indent=2)

        # NOTE: only the parameters of the last graph trained in this directory
        # are persisted. Nothing is written when no graph file was processed.
        if final_params is not None:
            model_path = os.path.join(results_dir, "model_params")
            os.makedirs(model_path, exist_ok=True)
            with open(os.path.join(model_path, "model_params" + ".pkl"), "wb") as f:
                pickle.dump(final_params, f)

        print(f"✅ Results for {agent_dir_name} saved to {output_file}")

    print("\n🎉 Pipeline finished successfully!")
|
||||
|
||||
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
Reference in New Issue
Block a user