training and plotting updates
This commit is contained in:
BIN
__pycache__/config_.cpython-312.pyc
Normal file
BIN
__pycache__/config_.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
__pycache__/train_and_eval.cpython-312.pyc
Normal file
BIN
__pycache__/train_and_eval.cpython-312.pyc
Normal file
Binary file not shown.
33
config_.py
Normal file
33
config_.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
class NoiseType(Enum):
|
||||
NONE = "none"
|
||||
NORMAL = "normal"
|
||||
UNIFORM = "uniform"
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
num_agents: int
|
||||
embedding_dim: int = 16
|
||||
input_dim: int = 1
|
||||
output_dim: int = 1
|
||||
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
"""Configuration for the training process."""
|
||||
learning_rate: float = 1e-3
|
||||
epochs: int = 100
|
||||
batch_size: int = 32
|
||||
verbose: bool = False # Set to True to see epoch loss during training
|
||||
log: bool = True
|
||||
log_epoch_interval: int = 10
|
||||
noise_type: NoiseType = NoiseType.NONE
|
||||
noise_level: float = 0.01
|
||||
|
||||
# --- New parameters for this script ---
|
||||
data_directory:str
|
||||
# Threshold for converting attention scores to a binary graph
|
||||
f1_threshold: float = -0.5
|
||||
|
4
model.py
4
model.py
@@ -1,10 +1,12 @@
|
||||
import jax
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
from train import ModelConfig, TrainConfig
|
||||
from config_ import ModelConfig, TrainConfig
|
||||
import optax
|
||||
from functools import partial
|
||||
|
||||
|
||||
|
||||
def init_linear_layer(
|
||||
key: jax.Array,
|
||||
in_features: int,
|
||||
|
74
plot_results.py
Normal file
74
plot_results.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from config_ import ModelConfig
|
||||
from train_and_eval import calculate_f1_score
|
||||
from sklearn.metrics import f1_score
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
data_dir = "datasets/consensus_dataset"
|
||||
else:
|
||||
data_dir = "datasets/" + sys.argv[1]
|
||||
|
||||
datapoints = {}
|
||||
THRESHOLD = 0.2
|
||||
|
||||
for folder in tqdm(os.listdir(data_dir)):
|
||||
num_agents = int(folder.split("_")[1]) # Extract num agents
|
||||
|
||||
folder_path = os.path.join(data_dir, folder)
|
||||
|
||||
# Load model config from summary json
|
||||
with open(os.path.join(folder_path, "results/NoiseType.NONE", "summary_results.json"), "r") as f:
|
||||
summary_results = json.load(f)
|
||||
|
||||
|
||||
for i, graph in enumerate(os.listdir(folder_path)):
|
||||
|
||||
# train_summary_results
|
||||
summ_results = summary_results[i-1]
|
||||
|
||||
if graph == "results": # ignore the result folder
|
||||
continue
|
||||
|
||||
graph_path = os.path.join(folder_path, graph)
|
||||
|
||||
# Load run data
|
||||
with open(os.path.join(folder_path, graph), "r") as f:
|
||||
run_data = json.load(f)
|
||||
|
||||
true_graph = np.array(run_data["adjacency_matrix"])
|
||||
|
||||
learned_graph = np.array(summ_results["raw_attention"])
|
||||
|
||||
predicted_graph = (learned_graph > THRESHOLD).astype(int)
|
||||
|
||||
true_flat = true_graph.flatten()
|
||||
pred_flat = predicted_graph.flatten()
|
||||
|
||||
calc_f1_score = f1_score(true_flat, pred_flat)
|
||||
|
||||
|
||||
datapoints[num_agents] = datapoints.get(num_agents, [])
|
||||
datapoints[num_agents].append(calc_f1_score)
|
||||
|
||||
|
||||
for key in datapoints.keys():
|
||||
datapoints[key] = sum(datapoints[key])/len(datapoints[key])
|
||||
|
||||
|
||||
x = []
|
||||
y = []
|
||||
|
||||
for item in datapoints.items():
|
||||
x.append(item[0])
|
||||
y.append(item[1])
|
||||
|
||||
plt.plot(x, y)
|
||||
plt.show()
|
||||
|
||||
|
306
test.ipynb
306
test.ipynb
File diff suppressed because one or more lines are too long
19
train.py
19
train.py
@@ -1,19 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
class ModelConfig:
|
||||
num_agents: int = 10
|
||||
embedding_dim: int = 64
|
||||
input_dim: int = 1
|
||||
output_dim: int = 1
|
||||
simulation_type:str = "consensus"
|
||||
|
||||
class TrainConfig:
|
||||
epochs: float = 100
|
||||
learning_rate: float = 1e-3
|
||||
verbose: bool = True
|
||||
log: bool = True
|
||||
log_epoch_interval: int = 10
|
||||
|
||||
|
||||
|
||||
|
@@ -11,30 +11,11 @@ import uuid
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
|
||||
# Import from your existing model file
|
||||
from model import ModelConfig, TrainConfig, train_model, get_attention_fn
|
||||
|
||||
# Define an Enum for the data source for clarity and type safety
|
||||
|
||||
# Overwrite the original TrainConfig to include our new parameters
|
||||
|
||||
class TrainConfig:
|
||||
"""Configuration for the training process."""
|
||||
learning_rate: float = 1e-3
|
||||
epochs: int = 100
|
||||
batch_size: int = 4096
|
||||
verbose: bool = False # Set to True to see epoch loss during training
|
||||
log: bool = True
|
||||
log_epoch_interval: int = 10
|
||||
|
||||
# --- New parameters for this script ---
|
||||
data_directory:str = "datasets/" + sys.argv[1] + "_dataset"
|
||||
# Threshold for converting attention scores to a binary graph
|
||||
f1_threshold: float = -0.4
|
||||
from config_ import TrainConfig, ModelConfig, NoiseType
|
||||
from model import train_model, get_attention_fn
|
||||
|
||||
|
||||
def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_config: TrainConfig, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Converts simulation trajectories into input-output pairs for the model.
|
||||
Input: state at time t. Target: state at time t+1.
|
||||
@@ -69,21 +50,19 @@ def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[n
|
||||
all_inputs = all_inputs[all_indices]
|
||||
all_targets = all_targets[all_indices]
|
||||
|
||||
# for sim_idx in range(num_sims):
|
||||
# # Input is state from t=0 to t=T-2
|
||||
# inputs = trajectories[sim_idx, :-1, :]
|
||||
# # Target is state from t=1 to t=T-1
|
||||
# targets = trajectories[sim_idx, 1:, :]
|
||||
# all_inputs.append(inputs)
|
||||
# all_targets.append(targets)
|
||||
if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
|
||||
noise_shape = full_dataset_inputs.shape
|
||||
if train_config.noise_type == NoiseType.NORMAL:
|
||||
noise = jax.random.normal(key, noise_shape) * train_config.noise_level
|
||||
elif train_config.noise_type == NoiseType.UNIFORM:
|
||||
noise = jax.random.uniform(
|
||||
key, noise_shape,
|
||||
minval=-train_config.noise_level,
|
||||
maxval=train_config.noise_level
|
||||
)
|
||||
full_dataset_inputs += np.array(noise) # Add noise to inputs
|
||||
|
||||
# Concatenate all pairs from all simulations
|
||||
# Shape -> (num_sims * (num_timesteps - 1), num_agents)
|
||||
# full_dataset_inputs = np.concatenate(all_inputs, axis=0)
|
||||
# full_dataset_targets = np.concatenate(all_targets, axis=0)
|
||||
|
||||
# Reshape to have a feature dimension
|
||||
# Shape -> (total_samples, num_agents, 1)
|
||||
full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
|
||||
full_dataset_targets = np.expand_dims(all_targets, axis=-1)
|
||||
|
||||
@@ -138,6 +117,7 @@ def main():
|
||||
"""Main script to run the training and evaluation pipeline."""
|
||||
|
||||
train_config = TrainConfig()
|
||||
train_config.data_directory = "datasets/" + sys.argv[1] + "_dataset"
|
||||
|
||||
# Check if the data directory exists
|
||||
if not os.path.isdir(train_config.data_directory):
|
||||
@@ -153,6 +133,8 @@ def main():
|
||||
key=lambda x: int(x.split('_')[1])
|
||||
)
|
||||
|
||||
starter_key = jax.random.PRNGKey(49)
|
||||
|
||||
for agent_dir_name in agent_dirs:
|
||||
agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
|
||||
|
||||
@@ -160,6 +142,13 @@ def main():
|
||||
|
||||
graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
|
||||
|
||||
results_dir = os.path.join(agent_dir_path, "results")
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
|
||||
subdir = str(train_config.noise_type)
|
||||
sub_results_dir = os.path.join(results_dir, subdir)
|
||||
os.makedirs(sub_results_dir, exist_ok=True)
|
||||
|
||||
print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
|
||||
|
||||
for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
|
||||
@@ -175,7 +164,8 @@ def main():
|
||||
# np.random.shuffle(trajectories)
|
||||
# trajectories = np.random.shuffle(trajectories)
|
||||
true_graph = np.array(data['adjacency_matrix'])
|
||||
inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)
|
||||
starter_key, data_key = jax.random.split(starter_key)
|
||||
inputs, targets = prepare_data_for_model(data_key, trajectories, train_config, train_config.batch_size)
|
||||
|
||||
# 2. Configure Model
|
||||
num_agents = trajectories.shape[-1]
|
||||
@@ -228,28 +218,34 @@ def main():
|
||||
"embedding_dim": model_config.embedding_dim
|
||||
},
|
||||
# This is correct because TrainConfig is a dataclass
|
||||
"training": vars(train_config)
|
||||
}
|
||||
"training": {
|
||||
"epochs" : train_config.epochs,
|
||||
"learning_rate" : train_config.learning_rate,
|
||||
"verbose" : train_config.verbose,
|
||||
"log" : train_config.log,
|
||||
"log_epoch_interval" : train_config.log_epoch_interval,
|
||||
"noise_type": str(train_config.noise_type),
|
||||
"noise_level": train_config.noise_level,
|
||||
|
||||
}
|
||||
},
|
||||
"raw_attention": np.array(get_attention_fn(final_params, model_config)).tolist()
|
||||
}
|
||||
|
||||
result_final_params = final_params
|
||||
|
||||
model_path = os.path.join(sub_results_dir, "model_params")
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
with open(os.path.join(model_path,"model_for_" + graph_file_name + ".pkl"), "wb") as f:
|
||||
pickle.dump(final_params, f)
|
||||
|
||||
all_results_for_agent.append(result_log)
|
||||
|
||||
# 6. Save aggregated results for this agent count
|
||||
results_dir = os.path.join(agent_dir_path, "results")
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
|
||||
output_file = os.path.join(results_dir, "summary_results.json")
|
||||
output_file = os.path.join(sub_results_dir, "summary_results.json")
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(all_results_for_agent, f, indent=2)
|
||||
|
||||
|
||||
model_path = os.path.join(results_dir, "model_params")
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
with open(os.path.join(model_path,"model_params" + ".pkl"), "wb") as f:
|
||||
pickle.dump(final_params, f)
|
||||
|
||||
print(f"✅ Results for {agent_dir_name} saved to {output_file}")
|
||||
|
||||
|
Reference in New Issue
Block a user