import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from dataclasses import dataclass, field, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score

# Import from your existing model file
from model import ModelConfig, train_model, get_attention_fn


# --- MODIFICATION: Enums for clarity and safety ---
class DataSource(Enum):
    CONSENSUS = "consensus"
    KURAMOTO = "kuramoto"


class NoiseType(Enum):
    NONE = "none"
    NORMAL = "normal"
    UNIFORM = "uniform"


# --- MODIFICATION: Updated TrainConfig ---
@dataclass
class TrainConfig:
    """Configuration for the training process."""
    learning_rate: float = 1e-3
    epochs: int = 50
    batch_size: int = 64
    verbose: bool = False
    log: bool = True
    log_epoch_interval: int = 10

    # Data and noise parameters
    data_source: DataSource = DataSource.CONSENSUS
    noise_type: NoiseType = NoiseType.NORMAL
    noise_level: float = 0.1  # Stddev for Normal, half-width for Uniform

    # Evaluation parameter
    f1_threshold: float = 0.5

    @property
    def source_data_directory(self) -> str:
        """The directory where the source dataset is located."""
        return f"{self.data_source.value}_dataset"

    @property
    def results_directory_name(self) -> str:
        """Generates a unique output directory name for separate logging."""
        if self.noise_type == NoiseType.NONE or self.noise_level == 0:
            return "results_noiseless"
        return f"results_noise_{self.noise_type.value}_{self.noise_level}"


# --- MODIFICATION: Updated data prep function to add noise ---
def prepare_data_for_model(
    trajectories: np.ndarray,
    key: jax.Array,
    train_config: TrainConfig
) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts trajectories to input-output pairs and adds noise to the inputs.
    """
    all_inputs, all_targets = [], []
    num_sims, num_timesteps, num_agents = trajectories.shape

    for sim_idx in range(num_sims):
        all_inputs.append(trajectories[sim_idx, :-1, :])
        all_targets.append(trajectories[sim_idx, 1:, :])

    full_dataset_inputs = np.concatenate(all_inputs, axis=0)
    full_dataset_targets = np.concatenate(all_targets, axis=0)

    # --- NOISE INJECTION BLOCK ---
    if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
        noise_shape = full_dataset_inputs.shape
        if train_config.noise_type == NoiseType.NORMAL:
            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
        elif train_config.noise_type == NoiseType.UNIFORM:
            noise = jax.random.uniform(
                key, noise_shape,
                minval=-train_config.noise_level,
                maxval=train_config.noise_level
            )
        full_dataset_inputs += np.array(noise)  # Add noise to inputs
    # --- END NOISE BLOCK ---

    full_dataset_inputs = np.expand_dims(full_dataset_inputs, axis=-1)
    full_dataset_targets = np.expand_dims(full_dataset_targets, axis=-1)

    num_samples = full_dataset_inputs.shape[0]
    num_batches = num_samples // train_config.batch_size

    truncated_inputs = full_dataset_inputs[:num_batches * train_config.batch_size]
    truncated_targets = full_dataset_targets[:num_batches * train_config.batch_size]

    batched_inputs = truncated_inputs.reshape(num_batches, train_config.batch_size, num_agents, 1)
    batched_targets = truncated_targets.reshape(num_batches, train_config.batch_size, num_agents, 1)

    return batched_inputs, batched_targets
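

# Example of the resulting shapes (illustrative numbers): with 10 simulations of 100
# timesteps over 5 agents and batch_size=64, the 10 * 99 = 990 (input, target) pairs are
# truncated to 15 full batches, giving arrays of shape (15, 64, 5, 1); leftover samples
# beyond the last full batch are dropped.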


def calculate_f1_score(
    params: dict,
    model_config: ModelConfig,
    true_graph: np.ndarray,
    threshold: float
) -> float:
    """Extracts the learned graph and computes the F1 score."""
    learned_scores = np.array(get_attention_fn(params, model_config))
    flat_scores = learned_scores.flatten()
    min_s, max_s = flat_scores.min(), flat_scores.max()
    if max_s > min_s:
        learned_scores = (learned_scores - min_s) / (max_s - min_s)
    predicted_graph = (learned_scores > threshold).astype(int)

    # Self-loops are excluded from the comparison (note: zeroes the diagonals in place).
    np.fill_diagonal(predicted_graph, 0)
    np.fill_diagonal(true_graph, 0)

    # Cast to a plain float so the score stays JSON-serializable downstream.
    return float(f1_score(true_graph.flatten(), predicted_graph.flatten()))
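

# The pipeline below uses the single f1_threshold from TrainConfig; a simple sweep over
# thresholds (illustrative only) could be run as, e.g.:
#   best_f1 = max(calculate_f1_score(params, cfg, graph, t) for t in np.linspace(0.1, 0.9, 9))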


def main():
    """Main script to run the training and evaluation pipeline."""

    # Configure your training run here
    train_config = TrainConfig(
        noise_type=NoiseType.NORMAL,
        noise_level=0.1
    )

    if not os.path.isdir(train_config.source_data_directory):
        print(f"Error: Source data '{train_config.source_data_directory}' not found.")
        return

    print(f"🚀 Starting training pipeline for '{train_config.data_source.value}' data.")
    print(f"Noise Configuration: type={train_config.noise_type.value}, level={train_config.noise_level}")

    # --- MODIFICATION: Main JAX key for noise generation ---
    main_key = jax.random.PRNGKey(42)
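
    # Expected layout (assumed to come from the dataset-generation step):
    #   <source>_dataset/agents_<N>/<graph>.json, where each JSON file provides
    #   'trajectories', 'adjacency_matrix', and 'graph_metrics'.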

    agent_dirs = sorted(
        [d for d in os.listdir(train_config.source_data_directory) if d.startswith("agents_")],
        key=lambda x: int(x.split('_')[1])
    )

    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(train_config.source_data_directory, agent_dir_name)
        all_results_for_agent = []
        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])

        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")

        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)

            with open(file_path, 'r') as f:
                data = json.load(f)

            main_key, data_key = jax.random.split(main_key)
            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])

            # --- MODIFICATION: Pass key and config to data prep ---
            inputs, targets = prepare_data_for_model(trajectories, data_key, train_config)

            num_agents = int(trajectories.shape[-1])
            model_config = ModelConfig()
            model_config.num_agents = num_agents
            model_config.input_dim = 1
            model_config.output_dim = 1
            model_config.embedding_dim = 32
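            # Only num_agents changes per graph; the remaining model hyperparameters are
            # fixed here and echoed in the result log below.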

            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )

            f1 = calculate_f1_score(final_params, model_config, true_graph, train_config.f1_threshold)
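
            # train_logs is assumed to expose {'loss_history': {epoch: [per-batch losses]}}
            # (as returned by train_model in model.py); losses are converted to plain
            # Python floats so the summary can be written as JSON below.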
            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }

            # asdict() keeps Enum members as-is, so replace them with their string values
            # to keep the config JSON-serializable.
            train_config_serializable = {
                **asdict(train_config),
                "data_source": train_config.data_source.value,
                "noise_type": train_config.noise_type.value,
            }

            result_log = {
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    "training": train_config_serializable
                }
            }
            all_results_for_agent.append(result_log)

        # --- MODIFICATION: Save to a separate results directory ---
        results_dir = os.path.join(agent_dir_path, train_config.results_directory_name)
        os.makedirs(results_dir, exist_ok=True)

        output_file = os.path.join(results_dir, "summary_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results_for_agent, f, indent=2)

        print(f"✅ Results for {agent_dir_name} saved to {output_file}")

    print("\n🎉 Pipeline finished successfully!")


if __name__ == "__main__":
    main()