# graph_recognition_w_attn/train_and_eval_w_noise.py
import os
import json

import jax
import numpy as np
from dataclasses import dataclass, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score

# Import from your existing model file
from model import ModelConfig, train_model, get_attention_fn

# --- MODIFICATION: Enums for clarity and safety ---
class DataSource(Enum):
    CONSENSUS = "consensus"
    KURAMOTO = "kuramoto"


class NoiseType(Enum):
    NONE = "none"
    NORMAL = "normal"
    UNIFORM = "uniform"

# --- MODIFICATION: Updated TrainConfig ---
@dataclass
class TrainConfig:
    """Configuration for the training process."""
    learning_rate: float = 1e-3
    epochs: int = 50
    batch_size: int = 64
    verbose: bool = False
    log: bool = True
    log_epoch_interval: int = 10

    # Data and noise parameters
    data_source: DataSource = DataSource.CONSENSUS
    noise_type: NoiseType = NoiseType.NORMAL
    noise_level: float = 0.1  # Stddev for Normal, half-width for Uniform

    # Evaluation parameter
    f1_threshold: float = 0.5

    @property
    def source_data_directory(self) -> str:
        """The directory where the source dataset is located."""
        return f"{self.data_source.value}_dataset"

    @property
    def results_directory_name(self) -> str:
        """Generates a unique output directory name for separate logging."""
        if self.noise_type == NoiseType.NONE or self.noise_level == 0:
            return "results_noiseless"
        return f"results_noise_{self.noise_type.value}_{self.noise_level}"

# --- MODIFICATION: Updated data prep function to add noise ---
def prepare_data_for_model(
    trajectories: np.ndarray,
    key: jax.Array,
    train_config: TrainConfig
) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts trajectories to input-output pairs and adds noise to the inputs.
    """
    all_inputs, all_targets = [], []
    num_sims, _, num_agents = trajectories.shape  # (sims, timesteps, agents)
    for sim_idx in range(num_sims):
        # One-step-ahead prediction: input at time t, target at time t + 1.
        all_inputs.append(trajectories[sim_idx, :-1, :])
        all_targets.append(trajectories[sim_idx, 1:, :])
    full_dataset_inputs = np.concatenate(all_inputs, axis=0)
    full_dataset_targets = np.concatenate(all_targets, axis=0)

    # --- NOISE INJECTION BLOCK ---
    if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
        noise_shape = full_dataset_inputs.shape
        if train_config.noise_type == NoiseType.NORMAL:
            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
        elif train_config.noise_type == NoiseType.UNIFORM:
            noise = jax.random.uniform(
                key, noise_shape,
                minval=-train_config.noise_level,
                maxval=train_config.noise_level
            )
        full_dataset_inputs += np.array(noise)  # Noise corrupts inputs only; targets stay clean
    # --- END NOISE BLOCK ---

    full_dataset_inputs = np.expand_dims(full_dataset_inputs, axis=-1)
    full_dataset_targets = np.expand_dims(full_dataset_targets, axis=-1)

    # Truncate so the sample count divides evenly into batches.
    num_samples = full_dataset_inputs.shape[0]
    num_batches = num_samples // train_config.batch_size
    truncated_inputs = full_dataset_inputs[:num_batches * train_config.batch_size]
    truncated_targets = full_dataset_targets[:num_batches * train_config.batch_size]
    batched_inputs = truncated_inputs.reshape(num_batches, train_config.batch_size, num_agents, 1)
    batched_targets = truncated_targets.reshape(num_batches, train_config.batch_size, num_agents, 1)
    return batched_inputs, batched_targets
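
# Shape walk-through (sketch): with trajectories of shape (S, T, N), each
# simulation yields T - 1 (input, target) rows, so the stacked dataset has
# S * (T - 1) samples of shape (N,). After expand_dims and batching, both
# returned arrays are (num_batches, batch_size, N, 1), where
# num_batches = (S * (T - 1)) // batch_size and any remainder is dropped.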

def calculate_f1_score(
    params: dict,
    model_config: ModelConfig,
    true_graph: np.ndarray,
    threshold: float
) -> float:
    """Extracts the learned graph and computes the F1 score."""
    learned_scores = np.array(get_attention_fn(params, model_config))

    # Min-max normalize the attention scores to [0, 1] before thresholding.
    flat_scores = learned_scores.flatten()
    min_s, max_s = flat_scores.min(), flat_scores.max()
    if max_s > min_s:
        learned_scores = (learned_scores - min_s) / (max_s - min_s)

    predicted_graph = (learned_scores > threshold).astype(int)
    # Ignore self-loops; note that fill_diagonal modifies both arrays in place.
    np.fill_diagonal(predicted_graph, 0)
    np.fill_diagonal(true_graph, 0)
    return f1_score(true_graph.flatten(), predicted_graph.flatten())
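
# Worked example (illustrative): normalized scores [[0.0, 0.9], [0.2, 0.0]]
# with threshold 0.5 predict a single edge at position (0, 1); since the
# diagonal is zeroed in both the predicted and true graphs, self-loops never
# count toward the F1 score.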

def main():
    """Main script to run the training and evaluation pipeline."""
    # Configure your training run here
    train_config = TrainConfig(
        noise_type=NoiseType.NORMAL,
        noise_level=0.1
    )
    if not os.path.isdir(train_config.source_data_directory):
        print(f"Error: Source data '{train_config.source_data_directory}' not found.")
        return

    print(f"🚀 Starting training pipeline for '{train_config.data_source.value}' data.")
    print(f"Noise Configuration: type={train_config.noise_type.value}, level={train_config.noise_level}")

    # --- MODIFICATION: Main JAX key for noise generation ---
    main_key = jax.random.PRNGKey(42)

    agent_dirs = sorted(
        [d for d in os.listdir(train_config.source_data_directory) if d.startswith("agents_")],
        key=lambda x: int(x.split('_')[1])
    )
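
    # The sort key assumes subdirectories named "agents_<N>" (e.g. "agents_5",
    # "agents_10"), so numeric ordering holds where a plain lexicographic sort
    # would put "agents_10" before "agents_5".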

    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(train_config.source_data_directory, agent_dir_name)
        all_results_for_agent = []
        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")

        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)
            with open(file_path, 'r') as f:
                data = json.load(f)

            # Split the key so each graph gets independent noise.
            main_key, data_key = jax.random.split(main_key)
            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])

            # --- MODIFICATION: Pass key and config to data prep ---
            inputs, targets = prepare_data_for_model(trajectories, data_key, train_config)

            num_agents = int(trajectories.shape[-1])
            model_config = ModelConfig()
            model_config.num_agents = num_agents
            model_config.input_dim = 1
            model_config.output_dim = 1
            model_config.embedding_dim = 32

            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )
            f1 = calculate_f1_score(final_params, model_config, true_graph, train_config.f1_threshold)

            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }
            result_log = {
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    # Enum members are not JSON serializable, so store their values.
                    "training": {
                        k: (v.value if isinstance(v, Enum) else v)
                        for k, v in asdict(train_config).items()
                    }
                }
            }
            all_results_for_agent.append(result_log)

        # --- MODIFICATION: Save to a separate results directory ---
        results_dir = os.path.join(agent_dir_path, train_config.results_directory_name)
        os.makedirs(results_dir, exist_ok=True)
        output_file = os.path.join(results_dir, "summary_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results_for_agent, f, indent=2)
        print(f"✅ Results for {agent_dir_name} saved to {output_file}")

    print("\n🎉 Pipeline finished successfully!")


if __name__ == "__main__":
    main()
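
# Quick shape sanity check (illustrative sketch, not part of the pipeline;
# uncomment to run against synthetic data):
#
#   cfg = TrainConfig(batch_size=4, noise_type=NoiseType.UNIFORM, noise_level=0.05)
#   key = jax.random.PRNGKey(0)
#   trajs = np.random.rand(2, 11, 5)   # (num_sims=2, timesteps=11, agents=5)
#   x, y = prepare_data_for_model(trajs, key, cfg)
#   assert x.shape == (5, 4, 5, 1) and y.shape == (5, 4, 5, 1)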