training and plotting updates
BIN  __pycache__/config_.cpython-312.pyc  (new file; binary not shown)
BIN  __pycache__/train_and_eval.cpython-312.pyc  (new file; binary not shown)
config_.py  (new file)
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from enum import Enum
+
+
+class NoiseType(Enum):
+    NONE = "none"
+    NORMAL = "normal"
+    UNIFORM = "uniform"
+
+
+@dataclass
+class ModelConfig:
+    num_agents: int
+    embedding_dim: int = 16
+    input_dim: int = 1
+    output_dim: int = 1
+
+
+@dataclass
+class TrainConfig:
+    """Configuration for the training process."""
+    learning_rate: float = 1e-3
+    epochs: int = 100
+    batch_size: int = 32
+    verbose: bool = False  # Set to True to see epoch loss during training
+    log: bool = True
+    log_epoch_interval: int = 10
+    noise_type: NoiseType = NoiseType.NONE
+    noise_level: float = 0.01
+
+    # --- New parameters for this script ---
+    data_directory: str = ""  # filled in at runtime by main() in train_and_eval.py
+    # Threshold for converting attention scores to a binary graph
+    f1_threshold: float = -0.5
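
Note: a minimal usage sketch of the new config module as reconstructed above, assuming both classes are @dataclass-decorated (the field values here are illustrative, not from the commit):

    from config_ import ModelConfig, TrainConfig, NoiseType

    model_config = ModelConfig(num_agents=10)  # embedding_dim, input_dim, output_dim keep their defaults
    train_config = TrainConfig(noise_type=NoiseType.UNIFORM, noise_level=0.05)
    train_config.data_directory = "datasets/consensus_dataset"  # set at runtime, as main() does below
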
model.py  (modified)
@@ -1,10 +1,12 @@
 import jax
 from jax import random
 import jax.numpy as jnp
-from train import ModelConfig, TrainConfig
+from config_ import ModelConfig, TrainConfig
 import optax
 from functools import partial
 
+
+
 def init_linear_layer(
     key: jax.Array,
     in_features: int,
plot_results.py  (new file)
@@ -0,0 +1,58 @@
+import os
+import sys
+import json
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import numpy as np
+from sklearn.metrics import f1_score
+
+if len(sys.argv) < 2:
+    data_dir = "datasets/consensus_dataset"
+else:
+    data_dir = "datasets/" + sys.argv[1]
+
+datapoints = {}
+THRESHOLD = 0.2
+
+for folder in tqdm(os.listdir(data_dir)):
+    num_agents = int(folder.split("_")[1])  # Extract num agents from the folder name
+    folder_path = os.path.join(data_dir, folder)
+
+    # Load the per-graph training summaries written by train_and_eval.py
+    with open(os.path.join(folder_path, "results/NoiseType.NONE", "summary_results.json"), "r") as f:
+        summary_results = json.load(f)
+
+    # Sort the graph files so their order matches the training order
+    # (train_and_eval.py iterates over sorted .json files); filtering on
+    # the .json suffix also skips the "results" folder
+    graph_files = sorted(f for f in os.listdir(folder_path) if f.endswith(".json"))
+
+    for summ_results, graph in zip(summary_results, graph_files):
+        # Load run data
+        with open(os.path.join(folder_path, graph), "r") as f:
+            run_data = json.load(f)
+
+        true_graph = np.array(run_data["adjacency_matrix"])
+        learned_graph = np.array(summ_results["raw_attention"])
+        predicted_graph = (learned_graph > THRESHOLD).astype(int)
+
+        true_flat = true_graph.flatten()
+        pred_flat = predicted_graph.flatten()
+
+        calc_f1_score = f1_score(true_flat, pred_flat)
+        datapoints.setdefault(num_agents, []).append(calc_f1_score)
+
+# Average the F1 scores for each agent count
+for key in datapoints:
+    datapoints[key] = sum(datapoints[key]) / len(datapoints[key])
+
+x = []
+y = []
+for num_agents, mean_f1 in sorted(datapoints.items()):
+    x.append(num_agents)
+    y.append(mean_f1)
+
+plt.plot(x, y)
+plt.xlabel("Number of agents")
+plt.ylabel("Mean F1 score")
+plt.show()
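
Note: a worked micro-example of the threshold-then-F1 step in plot_results.py, using sklearn.metrics.f1_score on flattened adjacency matrices (the attention values are made up for illustration):

    import numpy as np
    from sklearn.metrics import f1_score

    true_graph = np.array([[0, 1], [1, 0]])                 # ground-truth adjacency
    learned_graph = np.array([[0.05, 0.90], [0.45, 0.10]])  # raw attention scores

    predicted_graph = (learned_graph > 0.2).astype(int)     # -> [[0, 1], [1, 0]]
    score = f1_score(true_graph.flatten(), predicted_graph.flatten())
    print(score)  # 1.0: both edges recovered with no false positives
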
test.ipynb  (306 lines changed; diff suppressed because one or more lines are too long)
train.py  (deleted)
@@ -1,19 +0,0 @@
-from dataclasses import dataclass
-
-
-class ModelConfig:
-    num_agents: int = 10
-    embedding_dim: int = 64
-    input_dim: int = 1
-    output_dim: int = 1
-    simulation_type: str = "consensus"
-
-
-class TrainConfig:
-    epochs: float = 100
-    learning_rate: float = 1e-3
-    verbose: bool = True
-    log: bool = True
-    log_epoch_interval: int = 10
-
-
train_and_eval.py  (modified)
@@ -11,30 +11,11 @@ import uuid
 import pickle
 import sys
 
-# Import from your existing model file
-from model import ModelConfig, TrainConfig, train_model, get_attention_fn
-
-# Define an Enum for the data source for clarity and type safety
-
-# Overwrite the original TrainConfig to include our new parameters
-class TrainConfig:
-    """Configuration for the training process."""
-    learning_rate: float = 1e-3
-    epochs: int = 100
-    batch_size: int = 4096
-    verbose: bool = False  # Set to True to see epoch loss during training
-    log: bool = True
-    log_epoch_interval: int = 10
-
-    # --- New parameters for this script ---
-    data_directory: str = "datasets/" + sys.argv[1] + "_dataset"
-    # Threshold for converting attention scores to a binary graph
-    f1_threshold: float = -0.4
-
-def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
+from config_ import TrainConfig, ModelConfig, NoiseType
+from model import train_model, get_attention_fn
+
+
+def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_config: TrainConfig, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
     """
     Converts simulation trajectories into input-output pairs for the model.
     Input: state at time t. Target: state at time t+1.
@@ -69,21 +50,19 @@ def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
     all_inputs = all_inputs[all_indices]
     all_targets = all_targets[all_indices]
 
-    # for sim_idx in range(num_sims):
-    #     # Input is state from t=0 to t=T-2
-    #     inputs = trajectories[sim_idx, :-1, :]
-    #     # Target is state from t=1 to t=T-1
-    #     targets = trajectories[sim_idx, 1:, :]
-    #     all_inputs.append(inputs)
-    #     all_targets.append(targets)
-
-    # Concatenate all pairs from all simulations
-    # Shape -> (num_sims * (num_timesteps - 1), num_agents)
-    # full_dataset_inputs = np.concatenate(all_inputs, axis=0)
-    # full_dataset_targets = np.concatenate(all_targets, axis=0)
-
-    # Reshape to have a feature dimension
-    # Shape -> (total_samples, num_agents, 1)
+    # Optionally inject noise into the inputs (applied to all_inputs, the
+    # array that exists at this point; full_dataset_inputs is built below)
+    if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
+        noise_shape = all_inputs.shape
+        if train_config.noise_type == NoiseType.NORMAL:
+            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
+        elif train_config.noise_type == NoiseType.UNIFORM:
+            noise = jax.random.uniform(
+                key, noise_shape,
+                minval=-train_config.noise_level,
+                maxval=train_config.noise_level
+            )
+        all_inputs += np.array(noise)  # Add noise to inputs only, not targets
+
     full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
     full_dataset_targets = np.expand_dims(all_targets, axis=-1)
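
Note: the noise-injection branch above as a self-contained sketch, assuming the reconstructed config_.py; in the real pipeline the key comes from jax.random.split in main(), and the array shape here is illustrative:

    import jax
    import numpy as np
    from config_ import NoiseType, TrainConfig

    key = jax.random.PRNGKey(0)
    inputs = np.zeros((8, 4))  # (samples, agents)
    cfg = TrainConfig(noise_type=NoiseType.NORMAL, noise_level=0.01)

    if cfg.noise_type != NoiseType.NONE and cfg.noise_level > 0:
        if cfg.noise_type == NoiseType.NORMAL:
            noise = jax.random.normal(key, inputs.shape) * cfg.noise_level
        else:  # NoiseType.UNIFORM
            noise = jax.random.uniform(key, inputs.shape,
                                       minval=-cfg.noise_level, maxval=cfg.noise_level)
        inputs += np.array(noise)  # perturb the inputs only; targets stay clean
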
@@ -138,6 +117,7 @@ def main():
     """Main script to run the training and evaluation pipeline."""
 
     train_config = TrainConfig()
+    train_config.data_directory = "datasets/" + sys.argv[1] + "_dataset"
 
     # Check if the data directory exists
     if not os.path.isdir(train_config.data_directory):
@@ -153,6 +133,8 @@ def main():
         key=lambda x: int(x.split('_')[1])
     )
 
+    starter_key = jax.random.PRNGKey(49)
+
     for agent_dir_name in agent_dirs:
         agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
@@ -160,6 +142,13 @@ def main():
 
         graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
 
+        results_dir = os.path.join(agent_dir_path, "results")
+        os.makedirs(results_dir, exist_ok=True)
+
+        subdir = str(train_config.noise_type)
+        sub_results_dir = os.path.join(results_dir, subdir)
+        os.makedirs(sub_results_dir, exist_ok=True)
+
         print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
 
         for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
@@ -175,7 +164,8 @@ def main():
             # np.random.shuffle(trajectories)
             # trajectories = np.random.shuffle(trajectories)
             true_graph = np.array(data['adjacency_matrix'])
-            inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)
+            starter_key, data_key = jax.random.split(starter_key)
+            inputs, targets = prepare_data_for_model(data_key, trajectories, train_config, train_config.batch_size)
 
             # 2. Configure Model
             num_agents = trajectories.shape[-1]
@@ -228,28 +218,34 @@ def main():
                     "embedding_dim": model_config.embedding_dim
                 },
                 # This is correct because TrainConfig is a dataclass
-                "training": vars(train_config)
-            }
+                "training": {
+                    "epochs": train_config.epochs,
+                    "learning_rate": train_config.learning_rate,
+                    "verbose": train_config.verbose,
+                    "log": train_config.log,
+                    "log_epoch_interval": train_config.log_epoch_interval,
+                    "noise_type": str(train_config.noise_type),
+                    "noise_level": train_config.noise_level,
+                },
+                "raw_attention": np.array(get_attention_fn(final_params, model_config)).tolist()
             }
 
             result_final_params = final_params
 
+            model_path = os.path.join(sub_results_dir, "model_params")
+            os.makedirs(model_path, exist_ok=True)
+            with open(os.path.join(model_path, "model_for_" + graph_file_name + ".pkl"), "wb") as f:
+                pickle.dump(final_params, f)
+
             all_results_for_agent.append(result_log)
 
-        # 6. Save aggregated results for this agent count
-        results_dir = os.path.join(agent_dir_path, "results")
-        os.makedirs(results_dir, exist_ok=True)
-
-        output_file = os.path.join(results_dir, "summary_results.json")
+        # 6. Save aggregated results for this agent count
+        output_file = os.path.join(sub_results_dir, "summary_results.json")
         with open(output_file, 'w') as f:
             json.dump(all_results_for_agent, f, indent=2)
 
-        model_path = os.path.join(results_dir, "model_params")
-        os.makedirs(model_path, exist_ok=True)
-        with open(os.path.join(model_path, "model_params" + ".pkl"), "wb") as f:
-            pickle.dump(final_params, f)
-
         print(f"✅ Results for {agent_dir_name} saved to {output_file}")
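
Note: str(train_config.noise_type) evaluates to "NoiseType.NONE", which is why plot_results.py reads from results/NoiseType.NONE above. A minimal sketch of loading one of the new per-graph parameter pickles back (the dataset, folder, and graph file names are illustrative, not from the commit):

    import os
    import pickle

    sub_results_dir = "datasets/consensus_dataset/agents_10/results/NoiseType.NONE"
    model_path = os.path.join(sub_results_dir, "model_params")
    with open(os.path.join(model_path, "model_for_graph_0.json.pkl"), "rb") as f:
        final_params = pickle.load(f)  # the same parameter pytree train_model returned
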