training and plotting updates
This commit is contained in:
		
							
								
								
									
										
											BIN
										
									
								
								__pycache__/config_.cpython-312.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								__pycache__/config_.cpython-312.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								__pycache__/train_and_eval.cpython-312.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								__pycache__/train_and_eval.cpython-312.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										33
									
								
								config_.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								config_.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,33 @@
 | 
			
		||||
import sys
 | 
			
		||||
from enum import Enum
 | 
			
		||||
 | 
			
		||||
class NoiseType(Enum):
 | 
			
		||||
    NONE = "none"
 | 
			
		||||
    NORMAL = "normal"
 | 
			
		||||
    UNIFORM = "uniform"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelConfig:
 | 
			
		||||
    num_agents: int
 | 
			
		||||
    embedding_dim: int = 16
 | 
			
		||||
    input_dim: int = 1
 | 
			
		||||
    output_dim: int = 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TrainConfig:
 | 
			
		||||
    """Configuration for the training process."""
 | 
			
		||||
    learning_rate: float = 1e-3
 | 
			
		||||
    epochs: int = 100
 | 
			
		||||
    batch_size: int = 32
 | 
			
		||||
    verbose: bool = False  # Set to True to see epoch loss during training
 | 
			
		||||
    log: bool = True
 | 
			
		||||
    log_epoch_interval: int = 10
 | 
			
		||||
    noise_type: NoiseType = NoiseType.NONE
 | 
			
		||||
    noise_level: float = 0.01
 | 
			
		||||
    
 | 
			
		||||
    # --- New parameters for this script ---
 | 
			
		||||
    data_directory:str 
 | 
			
		||||
    # Threshold for converting attention scores to a binary graph
 | 
			
		||||
    f1_threshold: float = -0.5
 | 
			
		||||
    
 | 
			
		||||
							
								
								
									
										4
									
								
								model.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								model.py
									
									
									
									
									
								
							@@ -1,10 +1,12 @@
 | 
			
		||||
import jax
 | 
			
		||||
from jax import random
 | 
			
		||||
import jax.numpy as jnp
 | 
			
		||||
from train import ModelConfig, TrainConfig
 | 
			
		||||
from config_ import ModelConfig, TrainConfig
 | 
			
		||||
import optax
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init_linear_layer(
 | 
			
		||||
        key: jax.Array,
 | 
			
		||||
        in_features: int, 
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										74
									
								
								plot_results.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								plot_results.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,74 @@
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import json
 | 
			
		||||
import pickle
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
import numpy as np
 | 
			
		||||
from config_ import ModelConfig
 | 
			
		||||
from train_and_eval import calculate_f1_score
 | 
			
		||||
from sklearn.metrics import f1_score
 | 
			
		||||
 | 
			
		||||
if len(sys.argv) < 2:
 | 
			
		||||
    data_dir = "datasets/consensus_dataset"
 | 
			
		||||
else:
 | 
			
		||||
    data_dir = "datasets/" + sys.argv[1]
 | 
			
		||||
 | 
			
		||||
datapoints = {}
 | 
			
		||||
THRESHOLD = 0.2
 | 
			
		||||
 | 
			
		||||
for folder in tqdm(os.listdir(data_dir)):
 | 
			
		||||
    num_agents = int(folder.split("_")[1]) # Extract num agents
 | 
			
		||||
    
 | 
			
		||||
    folder_path = os.path.join(data_dir, folder)
 | 
			
		||||
 | 
			
		||||
    # Load model config from summary json
 | 
			
		||||
    with open(os.path.join(folder_path, "results/NoiseType.NONE", "summary_results.json"), "r") as f:
 | 
			
		||||
        summary_results = json.load(f)
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    for i, graph in enumerate(os.listdir(folder_path)):
 | 
			
		||||
 | 
			
		||||
        # train_summary_results
 | 
			
		||||
        summ_results = summary_results[i-1]
 | 
			
		||||
 | 
			
		||||
        if graph == "results": # ignore the result folder
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        graph_path = os.path.join(folder_path, graph)
 | 
			
		||||
 | 
			
		||||
        # Load run data
 | 
			
		||||
        with open(os.path.join(folder_path, graph), "r") as f:
 | 
			
		||||
            run_data = json.load(f)
 | 
			
		||||
 | 
			
		||||
        true_graph = np.array(run_data["adjacency_matrix"])
 | 
			
		||||
        
 | 
			
		||||
        learned_graph = np.array(summ_results["raw_attention"])
 | 
			
		||||
 | 
			
		||||
        predicted_graph = (learned_graph > THRESHOLD).astype(int)
 | 
			
		||||
 | 
			
		||||
        true_flat = true_graph.flatten()
 | 
			
		||||
        pred_flat = predicted_graph.flatten()
 | 
			
		||||
        
 | 
			
		||||
        calc_f1_score = f1_score(true_flat, pred_flat)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        datapoints[num_agents] = datapoints.get(num_agents, [])
 | 
			
		||||
        datapoints[num_agents].append(calc_f1_score)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
for key in datapoints.keys():
 | 
			
		||||
    datapoints[key] = sum(datapoints[key])/len(datapoints[key])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
x = []
 | 
			
		||||
y = []
 | 
			
		||||
 | 
			
		||||
for item in datapoints.items():
 | 
			
		||||
    x.append(item[0])
 | 
			
		||||
    y.append(item[1])
 | 
			
		||||
 | 
			
		||||
plt.plot(x, y)
 | 
			
		||||
plt.show()         
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
							
								
								
									
										306
									
								
								test.ipynb
									
									
									
									
									
								
							
							
						
						
									
										306
									
								
								test.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										19
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								train.py
									
									
									
									
									
								
							@@ -1,19 +0,0 @@
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
 | 
			
		||||
class ModelConfig:
 | 
			
		||||
    num_agents: int = 10
 | 
			
		||||
    embedding_dim: int = 64
 | 
			
		||||
    input_dim: int = 1
 | 
			
		||||
    output_dim: int = 1
 | 
			
		||||
    simulation_type:str = "consensus"
 | 
			
		||||
 | 
			
		||||
class TrainConfig:
 | 
			
		||||
    epochs: float = 100
 | 
			
		||||
    learning_rate: float = 1e-3
 | 
			
		||||
    verbose: bool = True
 | 
			
		||||
    log: bool = True
 | 
			
		||||
    log_epoch_interval: int = 10
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -11,30 +11,11 @@ import uuid
 | 
			
		||||
import pickle
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Import from your existing model file
 | 
			
		||||
from model import ModelConfig, TrainConfig, train_model, get_attention_fn
 | 
			
		||||
 | 
			
		||||
# Define an Enum for the data source for clarity and type safety
 | 
			
		||||
 | 
			
		||||
# Overwrite the original TrainConfig to include our new parameters
 | 
			
		||||
 | 
			
		||||
class TrainConfig:
 | 
			
		||||
    """Configuration for the training process."""
 | 
			
		||||
    learning_rate: float = 1e-3
 | 
			
		||||
    epochs: int = 100
 | 
			
		||||
    batch_size: int = 4096
 | 
			
		||||
    verbose: bool = False  # Set to True to see epoch loss during training
 | 
			
		||||
    log: bool = True
 | 
			
		||||
    log_epoch_interval: int = 10
 | 
			
		||||
    
 | 
			
		||||
    # --- New parameters for this script ---
 | 
			
		||||
    data_directory:str = "datasets/" + sys.argv[1] + "_dataset"
 | 
			
		||||
    # Threshold for converting attention scores to a binary graph
 | 
			
		||||
    f1_threshold: float = -0.4
 | 
			
		||||
from config_ import TrainConfig, ModelConfig, NoiseType
 | 
			
		||||
from model import train_model, get_attention_fn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
 | 
			
		||||
def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_config: TrainConfig, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
 | 
			
		||||
    """
 | 
			
		||||
    Converts simulation trajectories into input-output pairs for the model.
 | 
			
		||||
    Input: state at time t. Target: state at time t+1.
 | 
			
		||||
@@ -69,21 +50,19 @@ def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[n
 | 
			
		||||
    all_inputs = all_inputs[all_indices]
 | 
			
		||||
    all_targets = all_targets[all_indices]
 | 
			
		||||
            
 | 
			
		||||
    # for sim_idx in range(num_sims):
 | 
			
		||||
    #     # Input is state from t=0 to t=T-2
 | 
			
		||||
    #     inputs = trajectories[sim_idx, :-1, :]
 | 
			
		||||
    #     # Target is state from t=1 to t=T-1
 | 
			
		||||
    #     targets = trajectories[sim_idx, 1:, :]
 | 
			
		||||
    #     all_inputs.append(inputs)
 | 
			
		||||
    #     all_targets.append(targets)
 | 
			
		||||
    if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
 | 
			
		||||
        noise_shape = full_dataset_inputs.shape
 | 
			
		||||
        if train_config.noise_type == NoiseType.NORMAL:
 | 
			
		||||
            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
 | 
			
		||||
        elif train_config.noise_type == NoiseType.UNIFORM:
 | 
			
		||||
            noise = jax.random.uniform(
 | 
			
		||||
                key, noise_shape, 
 | 
			
		||||
                minval=-train_config.noise_level, 
 | 
			
		||||
                maxval=train_config.noise_level
 | 
			
		||||
            )
 | 
			
		||||
        full_dataset_inputs += np.array(noise) # Add noise to inputs
 | 
			
		||||
 | 
			
		||||
    # Concatenate all pairs from all simulations
 | 
			
		||||
    # Shape -> (num_sims * (num_timesteps - 1), num_agents)
 | 
			
		||||
    # full_dataset_inputs = np.concatenate(all_inputs, axis=0)
 | 
			
		||||
    # full_dataset_targets = np.concatenate(all_targets, axis=0)
 | 
			
		||||
    
 | 
			
		||||
    # Reshape to have a feature dimension
 | 
			
		||||
    # Shape -> (total_samples, num_agents, 1)
 | 
			
		||||
    full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
 | 
			
		||||
    full_dataset_targets = np.expand_dims(all_targets, axis=-1)
 | 
			
		||||
    
 | 
			
		||||
@@ -138,6 +117,7 @@ def main():
 | 
			
		||||
    """Main script to run the training and evaluation pipeline."""
 | 
			
		||||
    
 | 
			
		||||
    train_config = TrainConfig()
 | 
			
		||||
    train_config.data_directory = "datasets/" + sys.argv[1] + "_dataset"
 | 
			
		||||
    
 | 
			
		||||
    # Check if the data directory exists
 | 
			
		||||
    if not os.path.isdir(train_config.data_directory):
 | 
			
		||||
@@ -153,6 +133,8 @@ def main():
 | 
			
		||||
        key=lambda x: int(x.split('_')[1])
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    starter_key = jax.random.PRNGKey(49)
 | 
			
		||||
 | 
			
		||||
    for agent_dir_name in agent_dirs:
 | 
			
		||||
        agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
 | 
			
		||||
        
 | 
			
		||||
@@ -160,6 +142,13 @@ def main():
 | 
			
		||||
        
 | 
			
		||||
        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
 | 
			
		||||
        
 | 
			
		||||
        results_dir = os.path.join(agent_dir_path, "results")
 | 
			
		||||
        os.makedirs(results_dir, exist_ok=True)
 | 
			
		||||
 | 
			
		||||
        subdir = str(train_config.noise_type)
 | 
			
		||||
        sub_results_dir = os.path.join(results_dir, subdir)
 | 
			
		||||
        os.makedirs(sub_results_dir, exist_ok=True)
 | 
			
		||||
 | 
			
		||||
        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
 | 
			
		||||
        
 | 
			
		||||
        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
 | 
			
		||||
@@ -175,7 +164,8 @@ def main():
 | 
			
		||||
            # np.random.shuffle(trajectories)
 | 
			
		||||
            # trajectories = np.random.shuffle(trajectories)
 | 
			
		||||
            true_graph = np.array(data['adjacency_matrix'])
 | 
			
		||||
            inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)
 | 
			
		||||
            starter_key, data_key = jax.random.split(starter_key)
 | 
			
		||||
            inputs, targets = prepare_data_for_model(data_key, trajectories, train_config, train_config.batch_size)
 | 
			
		||||
            
 | 
			
		||||
            # 2. Configure Model
 | 
			
		||||
            num_agents = trajectories.shape[-1]
 | 
			
		||||
@@ -228,28 +218,34 @@ def main():
 | 
			
		||||
                        "embedding_dim": model_config.embedding_dim
 | 
			
		||||
                    },
 | 
			
		||||
                    # This is correct because TrainConfig is a dataclass
 | 
			
		||||
                    "training": vars(train_config)
 | 
			
		||||
                    "training": {
 | 
			
		||||
                        "epochs" : train_config.epochs,
 | 
			
		||||
                        "learning_rate" : train_config.learning_rate,
 | 
			
		||||
                        "verbose" : train_config.verbose,
 | 
			
		||||
                        "log" : train_config.log,
 | 
			
		||||
                        "log_epoch_interval" : train_config.log_epoch_interval,
 | 
			
		||||
                        "noise_type": str(train_config.noise_type),
 | 
			
		||||
                        "noise_level": train_config.noise_level,
 | 
			
		||||
                        
 | 
			
		||||
                    }
 | 
			
		||||
                },
 | 
			
		||||
                "raw_attention": np.array(get_attention_fn(final_params, model_config)).tolist()
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            result_final_params = final_params
 | 
			
		||||
            
 | 
			
		||||
            model_path = os.path.join(sub_results_dir, "model_params")
 | 
			
		||||
            os.makedirs(model_path, exist_ok=True)
 | 
			
		||||
            with open(os.path.join(model_path,"model_for_" + graph_file_name + ".pkl"), "wb") as f:
 | 
			
		||||
                pickle.dump(final_params, f)
 | 
			
		||||
 | 
			
		||||
            all_results_for_agent.append(result_log)
 | 
			
		||||
 | 
			
		||||
        # 6. Save aggregated results for this agent count
 | 
			
		||||
        results_dir = os.path.join(agent_dir_path, "results")
 | 
			
		||||
        os.makedirs(results_dir, exist_ok=True)
 | 
			
		||||
        
 | 
			
		||||
        output_file = os.path.join(results_dir, "summary_results.json")
 | 
			
		||||
        output_file = os.path.join(sub_results_dir, "summary_results.json")
 | 
			
		||||
        with open(output_file, 'w') as f:
 | 
			
		||||
            json.dump(all_results_for_agent, f, indent=2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        model_path = os.path.join(results_dir, "model_params")
 | 
			
		||||
        os.makedirs(model_path, exist_ok=True)
 | 
			
		||||
        with open(os.path.join(model_path,"model_params" + ".pkl"), "wb") as f:
 | 
			
		||||
            pickle.dump(final_params, f)
 | 
			
		||||
            
 | 
			
		||||
        print(f"✅ Results for {agent_dir_name} saved to {output_file}")
 | 
			
		||||
        
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user