training and plotting updates

__pycache__/config_.cpython-312.pyc | BIN (new file)
Binary file not shown.

__pycache__/train_and_eval.cpython-312.pyc | BIN (new file)
Binary file not shown.

config_.py | 33 (new file)
@@ -0,0 +1,33 @@
+import sys
+from enum import Enum
+
+class NoiseType(Enum):
+    NONE = "none"
+    NORMAL = "normal"
+    UNIFORM = "uniform"
+
+
+class ModelConfig:
+    num_agents: int
+    embedding_dim: int = 16
+    input_dim: int = 1
+    output_dim: int = 1
+
+
+
+class TrainConfig:
+    """Configuration for the training process."""
+    learning_rate: float = 1e-3
+    epochs: int = 100
+    batch_size: int = 32
+    verbose: bool = False  # Set to True to see epoch loss during training
+    log: bool = True
+    log_epoch_interval: int = 10
+    noise_type: NoiseType = NoiseType.NONE
+    noise_level: float = 0.01
+    
+    # --- New parameters for this script ---
+    data_directory:str 
+    # Threshold for converting attention scores to a binary graph
+    f1_threshold: float = -0.5
+    
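As written, these are plain classes with class-level annotations, not @dataclass-decorated classes: defaults live on the class, and bare annotations such as num_agents and data_directory create no attribute until one is assigned on the instance (as main() later does for data_directory). A minimal usage sketch under that assumption, with hypothetical values:

    from config_ import ModelConfig, TrainConfig, NoiseType

    model_config = ModelConfig()
    model_config.num_agents = 10  # hypothetical; no default exists for this field

    train_config = TrainConfig()
    train_config.data_directory = "datasets/consensus_dataset"  # hypothetical path
    train_config.noise_type = NoiseType.NORMAL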
							
								
								
									
model.py | 4
@@ -1,10 +1,12 @@
 import jax
 from jax import random
 import jax.numpy as jnp
-from train import ModelConfig, TrainConfig
+from config_ import ModelConfig, TrainConfig
 import optax
 from functools import partial
 
+
+
 def init_linear_layer(
         key: jax.Array,
         in_features: int, 
							
								
								
									
plot_results.py | 74 (new file)
@@ -0,0 +1,74 @@
+import os
+import sys
+import json
+import pickle
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import numpy as np
+from config_ import ModelConfig
+from train_and_eval import calculate_f1_score
+from sklearn.metrics import f1_score
+
+if len(sys.argv) < 2:
+    data_dir = "datasets/consensus_dataset"
+else:
+    data_dir = "datasets/" + sys.argv[1]
+
+datapoints = {}
+THRESHOLD = 0.2
+
+for folder in tqdm(os.listdir(data_dir)):
+    num_agents = int(folder.split("_")[1]) # Extract num agents
+    
+    folder_path = os.path.join(data_dir, folder)
+
+    # Load model config from summary json
+    with open(os.path.join(folder_path, "results/NoiseType.NONE", "summary_results.json"), "r") as f:
+        summary_results = json.load(f)
+
+    
+    for i, graph in enumerate(os.listdir(folder_path)):
+
+        # train_summary_results
+        summ_results = summary_results[i-1]
+
+        if graph == "results": # ignore the result folder
+            continue
+
+        graph_path = os.path.join(folder_path, graph)
+
+        # Load run data
+        with open(os.path.join(folder_path, graph), "r") as f:
+            run_data = json.load(f)
+
+        true_graph = np.array(run_data["adjacency_matrix"])
+        
+        learned_graph = np.array(summ_results["raw_attention"])
+
+        predicted_graph = (learned_graph > THRESHOLD).astype(int)
+
+        true_flat = true_graph.flatten()
+        pred_flat = predicted_graph.flatten()
+        
+        calc_f1_score = f1_score(true_flat, pred_flat)
+
+
+        datapoints[num_agents] = datapoints.get(num_agents, [])
+        datapoints[num_agents].append(calc_f1_score)
+
+
+for key in datapoints.keys():
+    datapoints[key] = sum(datapoints[key])/len(datapoints[key])
+
+
+x = []
+y = []
+
+for item in datapoints.items():
+    x.append(item[0])
+    y.append(item[1])
+
+plt.plot(x, y)
+plt.show()         
+
+        
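The core metric here is binary F1 between the thresholded attention matrix and the true adjacency matrix, both flattened. A self-contained sketch of that step with hypothetical values:

    import numpy as np
    from sklearn.metrics import f1_score

    THRESHOLD = 0.2  # same cutoff as plot_results.py

    true_graph = np.array([[0, 1], [1, 0]])           # hypothetical adjacency
    attention  = np.array([[0.05, 0.9], [0.1, 0.3]])  # hypothetical raw attention

    predicted_graph = (attention > THRESHOLD).astype(int)  # [[0, 1], [0, 1]]
    print(f1_score(true_graph.flatten(), predicted_graph.flatten()))  # 0.5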
							
								
								
									
test.ipynb | 306
File diff suppressed because one or more lines are too long.
							
								
								
									
train.py | 19 (file deleted)
@@ -1,19 +0,0 @@
-from dataclasses import dataclass
-
-class ModelConfig:
-    num_agents: int = 10
-    embedding_dim: int = 64
-    input_dim: int = 1
-    output_dim: int = 1
-    simulation_type:str = "consensus"
-
-class TrainConfig:
-    epochs: float = 100
-    learning_rate: float = 1e-3
-    verbose: bool = True
-    log: bool = True
-    log_epoch_interval: int = 10
-    
-
-
-

train_and_eval.py
@@ -11,30 +11,11 @@ import uuid
 import pickle
 import sys
 
-# Import from your existing model file
-from model import ModelConfig, TrainConfig, train_model, get_attention_fn
-
-# Define an Enum for the data source for clarity and type safety
-# Overwrite the original TrainConfig to include our new parameters
-
-class TrainConfig:
-    """Configuration for the training process."""
-    learning_rate: float = 1e-3
-    epochs: int = 100
-    batch_size: int = 4096
-    verbose: bool = False  # Set to True to see epoch loss during training
-    log: bool = True
-    log_epoch_interval: int = 10
-    
-    # --- New parameters for this script ---
-    data_directory:str = "datasets/" + sys.argv[1] + "_dataset"
-    # Threshold for converting attention scores to a binary graph
-    f1_threshold: float = -0.4
-    
-
-def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
+from config_ import TrainConfig, ModelConfig, NoiseType
+from model import train_model, get_attention_fn
+
+
+def prepare_data_for_model(key: jax.Array, trajectories: np.ndarray, train_config: TrainConfig, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
     """
     Converts simulation trajectories into input-output pairs for the model.
     Input: state at time t. Target: state at time t+1.
@@ -69,21 +50,19 @@ def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[n
     all_inputs = all_inputs[all_indices]
     all_targets = all_targets[all_indices]
             
-    # for sim_idx in range(num_sims):
-    #     # Input is state from t=0 to t=T-2
-    #     inputs = trajectories[sim_idx, :-1, :]
-    #     # Target is state from t=1 to t=T-1
-    #     targets = trajectories[sim_idx, 1:, :]
-    #     all_inputs.append(inputs)
-    #     all_targets.append(targets)
-        
-    # Concatenate all pairs from all simulations
-    # Shape -> (num_sims * (num_timesteps - 1), num_agents)
-    # full_dataset_inputs = np.concatenate(all_inputs, axis=0)
-    # full_dataset_targets = np.concatenate(all_targets, axis=0)
+    if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
+        noise_shape = full_dataset_inputs.shape
+        if train_config.noise_type == NoiseType.NORMAL:
+            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
+        elif train_config.noise_type == NoiseType.UNIFORM:
+            noise = jax.random.uniform(
+                key, noise_shape, 
+                minval=-train_config.noise_level, 
+                maxval=train_config.noise_level
+            )
+        full_dataset_inputs += np.array(noise) # Add noise to inputs
+
     
-    # Reshape to have a feature dimension
-    # Shape -> (total_samples, num_agents, 1)
     full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
     full_dataset_targets = np.expand_dims(all_targets, axis=-1)
     
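The new branch draws noise functionally with JAX's PRNG and adds it to the NumPy inputs before the feature dimension is appended. A standalone sketch of the same pattern, assuming a float (samples, agents) array:

    import jax
    import numpy as np

    key = jax.random.PRNGKey(0)
    inputs = np.zeros((8, 4), dtype=np.float32)  # hypothetical (samples, agents)
    noise_level = 0.01

    noise = jax.random.normal(key, inputs.shape) * noise_level            # NoiseType.NORMAL
    # noise = jax.random.uniform(key, inputs.shape,
    #                            minval=-noise_level, maxval=noise_level) # NoiseType.UNIFORM
    inputs += np.array(noise)  # materialize the JAX array before the in-place add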
@@ -138,6 +117,7 @@ def main():
     """Main script to run the training and evaluation pipeline."""
     
     train_config = TrainConfig()
+    train_config.data_directory = "datasets/" + sys.argv[1] + "_dataset"
     
     # Check if the data directory exists
     if not os.path.isdir(train_config.data_directory):
@@ -153,6 +133,8 @@ def main():
         key=lambda x: int(x.split('_')[1])
     )
 
+    starter_key = jax.random.PRNGKey(49)
+
     for agent_dir_name in agent_dirs:
         agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
         
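Seeding starter_key once here and splitting it per graph later in main() gives every run an independent but reproducible noise stream; reusing a single key would replay identical noise for every graph. A minimal sketch of the splitting discipline:

    import jax

    starter_key = jax.random.PRNGKey(49)
    for _ in range(3):
        starter_key, data_key = jax.random.split(starter_key)
        # each data_key yields a fresh, deterministic stream
        noise = jax.random.normal(data_key, (4,))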
@@ -160,6 +142,13 @@
         
         graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
         
+        results_dir = os.path.join(agent_dir_path, "results")
+        os.makedirs(results_dir, exist_ok=True)
+
+        subdir = str(train_config.noise_type)
+        sub_results_dir = os.path.join(results_dir, subdir)
+        os.makedirs(sub_results_dir, exist_ok=True)
+
         print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
         
         for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
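The results tree now nests one subdirectory per noise type, since str() of a NoiseType member renders as e.g. "NoiseType.NONE"; this matches the path plot_results.py reads. A sketch of the resulting layout under a hypothetical agents_10 directory:

    import os

    agent_dir_path = "datasets/consensus_dataset/agents_10"  # hypothetical
    results_dir = os.path.join(agent_dir_path, "results")
    sub_results_dir = os.path.join(results_dir, "NoiseType.NONE")  # str(NoiseType.NONE)
    os.makedirs(sub_results_dir, exist_ok=True)
    # -> datasets/consensus_dataset/agents_10/results/NoiseType.NONE/summary_results.json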
@@ -175,7 +164,8 @@
             # np.random.shuffle(trajectories)
             # trajectories = np.random.shuffle(trajectories)
             true_graph = np.array(data['adjacency_matrix'])
-            inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)
+            starter_key, data_key = jax.random.split(starter_key)
+            inputs, targets = prepare_data_for_model(data_key, trajectories, train_config, train_config.batch_size)
             
             # 2. Configure Model
             num_agents = trajectories.shape[-1]
@@ -228,28 +218,34 @@
                         "embedding_dim": model_config.embedding_dim
                     },
                     # This is correct because TrainConfig is a dataclass
-                    "training": vars(train_config)
-                }
+                    "training": {
+                        "epochs" : train_config.epochs,
+                        "learning_rate" : train_config.learning_rate,
+                        "verbose" : train_config.verbose,
+                        "log" : train_config.log,
+                        "log_epoch_interval" : train_config.log_epoch_interval,
+                        "noise_type": str(train_config.noise_type),
+                        "noise_level": train_config.noise_level,
+                        
+                    }
+                },
+                "raw_attention": np.array(get_attention_fn(final_params, model_config)).tolist()
             }
+
             result_final_params = final_params
             
+            model_path = os.path.join(sub_results_dir, "model_params")
+            os.makedirs(model_path, exist_ok=True)
+            with open(os.path.join(model_path,"model_for_" + graph_file_name + ".pkl"), "wb") as f:
+                pickle.dump(final_params, f)
+
             all_results_for_agent.append(result_log)
 
-        # 6. Save aggregated results for this agent count
-        results_dir = os.path.join(agent_dir_path, "results")
-        os.makedirs(results_dir, exist_ok=True)
-        
-        output_file = os.path.join(results_dir, "summary_results.json")
+        output_file = os.path.join(sub_results_dir, "summary_results.json")
         with open(output_file, 'w') as f:
             json.dump(all_results_for_agent, f, indent=2)
 
-        model_path = os.path.join(results_dir, "model_params")
-        os.makedirs(model_path, exist_ok=True)
-        with open(os.path.join(model_path,"model_params" + ".pkl"), "wb") as f:
-            pickle.dump(final_params, f)
-            
         print(f"✅ Results for {agent_dir_name} saved to {output_file}")
         
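The switch from vars(train_config) to an explicit dict matters for two reasons: TrainConfig's defaults are class attributes, so vars() on a fresh instance returns almost nothing, and NoiseType members are not JSON-serializable, so the enum must be rendered with str(). A minimal sketch of the failure mode and the fix:

    import json
    from enum import Enum

    class NoiseType(Enum):
        NONE = "none"

    class TrainConfig:
        epochs: int = 100
        noise_type: NoiseType = NoiseType.NONE

    cfg = TrainConfig()
    print(vars(cfg))  # {} -- class-level defaults are invisible to vars()
    # json.dumps({"noise_type": cfg.noise_type})  # would raise TypeError
    print(json.dumps({"epochs": cfg.epochs, "noise_type": str(cfg.noise_type)}))
    # {"epochs": 100, "noise_type": "NoiseType.NONE"}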