replicated mecc
.gitignore (new file, vendored, +7)
@@ -0,0 +1,7 @@
.venv
.env
__pycache__/
datasets/
temp.*
test.*
							
								
								
									
.vscode/launch.json (new file, vendored, +15)
@@ -0,0 +1,15 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: Current File",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal"
        }
    ]
}
							
								
								
									
										
__pycache__/generate_data_consensus.cpython-312.pyc (new binary file, not shown)
__pycache__/model.cpython-312.pyc (new binary file, not shown)
__pycache__/train.cpython-312.pyc (new binary file, not shown)
							
								
								
									
generate_data_consensus.py (new file, +206)
@@ -0,0 +1,206 @@
import os
import json
import time
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
import networkx.algorithms.community as nx_comm
from tqdm import tqdm
import sims


def generate_connected_graph(rng: np.random.Generator, num_agents: int, graph_type: str) -> tuple[nx.Graph, str, np.ndarray]:
    """
    Generates a random, undirected, unweighted, connected graph of the
    requested type, resampling until the graph is connected.

    Returns:
        A tuple containing the networkx Graph object, the algorithm name, and the adjacency matrix.
    """
    while True:
        if graph_type == "erdos_renyi":
            # p = 2.0 * np.log(num_agents) / num_agents
            p = 0.5
            G = nx.erdos_renyi_graph(num_agents, p, seed=rng)

        elif graph_type == "watts_strogatz":
            k = 2
            p = 0.1
            G = nx.watts_strogatz_graph(num_agents, k, p, seed=rng)

        elif graph_type == "barabasi_albert":
            m = 2
            if m >= num_agents: m = max(1, num_agents - 1)
            G = nx.barabasi_albert_graph(num_agents, m, seed=rng)

        elif graph_type == "powerlaw_cluster":
            m = 3    # Number of random edges to add for each new node
            p = 0.1  # Probability of adding a triangle after adding a random edge
            if m >= num_agents: m = max(1, num_agents - 1)
            G = nx.powerlaw_cluster_graph(num_agents, m, p, seed=rng)

        else:
            raise ValueError(f"Unknown graph_type: {graph_type}")

        # Erdos-Renyi and Watts-Strogatz samples can be disconnected; resample until connected
        if nx.is_connected(G):
            break

    # Add self-loops, as they are often assumed in consensus algorithms
    G.add_edges_from([(i, i) for i in range(num_agents)])
    adj_matrix = nx.to_numpy_array(G, dtype=np.float32)

    return G, graph_type, adj_matrix

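# Usage sketch (illustrative, assumes the function above as-is):
#   rng = np.random.default_rng(0)
#   G, gtype, A = generate_connected_graph(rng, num_agents=10, graph_type="watts_strogatz")
#   nx.is_connected(G)   # True, guaranteed by the resampling loop
#   A.shape              # (10, 10); diagonal is nonzero due to the self-loops
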
def calculate_graph_metrics(G: nx.Graph) -> dict:
    """
    Calculates and returns a dictionary of key graph metrics.

    This function computes basic properties, connectivity, clustering,
    community structure, and spectral properties of the graph.
    Computationally expensive metrics are skipped for larger graphs to keep
    generation fast.
    """
    metrics = {}
    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()

    # --- Basic Properties ---
    metrics["number_of_nodes"] = num_nodes
    metrics["number_of_edges"] = num_edges
    metrics["average_degree"] = (2 * num_edges / num_nodes) if num_nodes > 0 else 0
    metrics["edge_density"] = nx.density(G)

    # --- Connectivity ---
    metrics["is_connected"] = nx.is_connected(G)
    metrics["number_connected_components"] = nx.number_connected_components(G)

    # --- Clustering ---
    metrics["average_clustering_coefficient"] = nx.average_clustering(G)
    metrics["clustering_coefficient"] = nx.clustering(G)  # Per-node clustering

    # --- Distance-Based Metrics (for connected graphs) ---
    # These are computationally intensive and only run on smaller, connected graphs.
    if metrics["is_connected"] and num_nodes < 250:
        metrics["average_shortest_path_length"] = nx.average_shortest_path_length(G)
        metrics["diameter"] = nx.diameter(G)
        metrics["eccentricity"] = nx.eccentricity(G)
    else:
        # Set to None if the graph is disconnected or too large
        metrics["average_shortest_path_length"] = None
        metrics["diameter"] = None
        metrics["eccentricity"] = None

    # --- Spectral & Community Metrics (Potentially Slow) ---
    # These are also limited to smaller graphs.
    if 1 < num_nodes < 500:
        # Eigenvalues of the Laplacian matrix (cast to plain floats so the
        # metrics stay JSON-serializable; np.float64 would break json.dump)
        try:
            laplacian_eigenvalues = sorted(float(ev) for ev in nx.laplacian_spectrum(G))
            metrics["laplacian_eigenvalues"] = laplacian_eigenvalues
            # The second-smallest eigenvalue of the Laplacian matrix
            metrics["algebraic_connectivity"] = laplacian_eigenvalues[1]
        except Exception:
            metrics["laplacian_eigenvalues"] = None
            metrics["algebraic_connectivity"] = None

        # Modularity using the Louvain community detection algorithm
        if num_edges > 0:
            try:
                communities = nx_comm.louvain_communities(G, seed=123)
                metrics["modularity"] = nx_comm.modularity(G, communities)
            except Exception:
                metrics["modularity"] = None  # Algorithm may fail on some graphs
        else:
            metrics["modularity"] = None
    else:
        metrics["laplacian_eigenvalues"] = None
        metrics["algebraic_connectivity"] = None
        metrics["modularity"] = None

    return metrics

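# Cross-check sketch (illustrative): the reported algebraic_connectivity is
# lambda_2 of the graph Laplacian, which NetworkX can also compute directly,
# and which governs how fast linear consensus mixes (larger is faster).
#   G = nx.erdos_renyi_graph(10, 0.5, seed=1)
#   m = calculate_graph_metrics(G)
#   abs(m["algebraic_connectivity"] - nx.algebraic_connectivity(G)) < 1e-6
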
def generate_multiple_initial_states(key: jax.Array, num_sims: int, num_agents: int, max_range: float) -> jax.Array:
    """Generate a batch of unique random initial states for the agents."""
    return jax.random.uniform(
        key,
        shape=(num_sims, num_agents),
        minval=-max_range,
        maxval=max_range
    )

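# Key-splitting sketch (standard JAX practice, shown for context): splitting
# the main key before each draw keeps every batch reproducible for a fixed seed.
#   key = jax.random.PRNGKey(0)
#   key, states_key = jax.random.split(key)
#   generate_multiple_initial_states(states_key, num_sims=4, num_agents=3, max_range=1.0).shape
#   # -> (4, 3)
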
def main():
    """Main script to generate and save the consensus simulation dataset."""
    # --- Configuration ---
    GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
    AGENT_COUNTS = range(5, 51, 5)  # 5, 10, ..., 50 agents
    GRAPHS_PER_AGENT_COUNT = 100
    GRAPHS_PER_GRAPH_ALGO = GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)
    SIMS_PER_GRAPH = 100
    OUTPUT_DIR = "datasets/consensus_dataset"

    # --- Setup ---
    seed = int(time.time())
    main_jax_key = jax.random.PRNGKey(seed)
    numpy_rng = np.random.default_rng(seed=seed + 1)

    print("🚀 Starting data generation...")
    print(f"Saving dataset to directory: '{OUTPUT_DIR}/'")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Main Generation Loop ---
    for n_agents in tqdm(AGENT_COUNTS, desc="Overall Progress"):
        agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
        os.makedirs(agent_folder, exist_ok=True)

        for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
            graph_algo_idx = graph_idx // GRAPHS_PER_GRAPH_ALGO

            # 1. Configure the simulation using the imported class
            config = sims.consensus.ConsensusConfig()
            config.num_agents = n_agents
            config.num_sims = SIMS_PER_GRAPH

            # 2. Generate graph and its metrics
            G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx])
            adj_matrix_jax = jnp.array(adj_matrix_np)
            graph_metrics = calculate_graph_metrics(G)
            graph_metrics["graph_type"] = graph_type

            # 3. Generate initial states for the SIMS_PER_GRAPH simulations
            main_jax_key, states_key = jax.random.split(main_jax_key)
            initial_states = generate_multiple_initial_states(
                states_key, config.num_sims, config.num_agents, config.max_range
            )

            # 4. Run all simulations using the imported JAX function
            trajectories = sims.consensus.run_consensus_sim(adj_matrix_jax, initial_states, config)

            # 5. Package all data into a dictionary for logging
            log_data = {
                "simulation_config": {
                    "num_agents": config.num_agents,
                    "num_sims": config.num_sims,
                    "num_time_steps": config.num_time_steps,
                    "step_size": config.step_size,
                    "max_range": config.max_range,
                    "directed": config.directed,
                    "weighted": config.weighted,
                },
                "graph_metrics": graph_metrics,
                "adjacency_matrix": adj_matrix_np.tolist(),
                "initial_states": initial_states.tolist(),
                "trajectories": trajectories.tolist()
            }

            # 6. Save the complete log to a JSON file
            file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
            with open(file_path, 'w') as f:
                json.dump(log_data, f, indent=2)

    print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")


if __name__ == "__main__":
    main()
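Each generated JSON file is self-contained. A minimal loading sketch (the path below assumes the default OUTPUT_DIR layout from the script above):

import json
import numpy as np

with open("datasets/consensus_dataset/agents_5/graph_00.json") as f:
    record = json.load(f)

A = np.array(record["adjacency_matrix"])   # (N, N) adjacency with self-loops
X = np.array(record["trajectories"])       # (num_sims, num_time_steps, N)
lam2 = record["graph_metrics"]["algebraic_connectivity"]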
							
								
								
									
generate_data_kuramoto.py (new file, +162)
@@ -0,0 +1,162 @@
import os
import json
import time
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
from tqdm import tqdm
from functools import partial
from dataclasses import asdict, replace
import sims
# --- Import necessary components from the Kuramoto simulation module ---
from sims.kuramoto import (
    KuramotoConfig,
    run_kuramoto_simulation,
    phase_coherence,
    mean_frequency,
)

from generate_data_consensus import generate_connected_graph, calculate_graph_metrics

# ==============================================================================
# SECTION 1: BATCHED OPERATIONS FOR EFFICIENT DATA GENERATION
# ==============================================================================

@partial(jax.jit, static_argnames=("config",))
def run_batched_simulations(
    thetas0_batch: jax.Array,    # (num_sims, N)
    omegas_batch: jax.Array,     # (num_sims, N)
    adj_mat: jax.Array,          # (N, N)
    config: KuramotoConfig
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """
    Runs many Kuramoto simulations in parallel for the same graph but different
    initial conditions, and performs analysis.

    Returns:
        A tuple containing:
        - All trajectories (num_sims, T, N)
        - Final coherence R(T) for each sim (num_sims,)
        - Mean frequencies for each sim (num_sims, N)
    """
    # Create a batched version of the simulation runner using vmap.
    # This maps the function over the first axis of thetas0_batch and omegas_batch.
    vmapped_runner = jax.vmap(
        run_kuramoto_simulation,
        in_axes=(0, 0, None, None),  # Map over thetas0, omegas; adj_mat and config are fixed
        out_axes=0
    )
    all_trajectories = vmapped_runner(thetas0_batch, omegas_batch, adj_mat, config)

    # Analyze the results in a batched manner.
    # phase_coherence works along the time axis, so we vmap over the sim axis.
    vmapped_coherence = jax.vmap(phase_coherence)(all_trajectories)  # -> (num_sims, T)
    final_coherence = vmapped_coherence[:, -1]  # -> (num_sims,)

    # vmap the mean_frequency calculation over trajectories and omegas
    vmapped_mean_freq = jax.vmap(mean_frequency, in_axes=(0, 0, None, None))
    all_mean_freqs = vmapped_mean_freq(all_trajectories, omegas_batch, adj_mat, config)

    return all_trajectories, final_coherence, all_mean_freqs

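# vmap semantics in one toy example (illustration only): in_axes=(0, None)
# maps over the leading axis of the first argument and broadcasts the second.
#   f = lambda x, w: x @ w
#   xs = jnp.ones((8, 3))   # batch of 8 vectors
#   w = jnp.ones((3, 2))    # shared weights
#   jax.vmap(f, in_axes=(0, None))(xs, w).shape   # (8, 2)
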
def generate_batched_initial_states(key: jax.Array, num_sims: int, config: KuramotoConfig):
    """Generates a batch of initial phases and natural frequencies."""
    keys = jax.random.split(key, num_sims)

    # A per-simulation generator written to be vmap-friendly
    @jax.vmap
    def generate_single_state(k):
        key_theta, key_omega = jax.random.split(k)
        thetas0 = jax.random.uniform(key_theta, (config.num_agents,), minval=0, maxval=2 * jnp.pi)
        omegas = jax.random.normal(key_omega, (config.num_agents,))  # Standard normal: mean 0, std 1
        return thetas0, omegas

    thetas0_batch, omegas_batch = generate_single_state(keys)
    return thetas0_batch, omegas_batch


def main():
    """Main script to generate and save the Kuramoto simulation dataset."""
    # --- Configuration ---
    GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
    AGENT_COUNTS = range(5, 51, 5)  # 5, 10, ..., 50 agents
    # AGENT_COUNTS = [50]
    GRAPHS_PER_AGENT_COUNT = 100
    GRAPHS_PER_GRAPH_ALGO = GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)
    SIMS_PER_GRAPH = 100
    OUTPUT_DIR = "datasets/kuramoto_dataset"

    # --- Setup ---
    seed = int(time.time())
    main_key = jax.random.PRNGKey(seed)
    numpy_rng = np.random.default_rng(seed=seed + 1)

    print("🚀 Starting Kuramoto data generation...")
    print(f"Saving dataset to: '{OUTPUT_DIR}/'")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Main Loop ---
    for n_agents in tqdm(AGENT_COUNTS, desc="Agent Counts"):
        agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
        os.makedirs(agent_folder, exist_ok=True)

        for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):

            graph_algo_idx = graph_idx // GRAPHS_PER_GRAPH_ALGO
            # 1. Setup config, keys, and graph.
            # KuramotoConfig is a frozen dataclass, so fields cannot be
            # assigned in place; build a new instance instead. time_steps is
            # a plain field computed from the class defaults, so it must be
            # updated explicitly when T changes.
            config = replace(
                KuramotoConfig(),
                num_agents=n_agents,
                coupling=1.5,
                T=15.0,
                time_steps=int(15.0 / KuramotoConfig.dt),
            )
            main_key, graph_key, state_key = jax.random.split(main_key, 3)

            G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx])
            adj_matrix_jax = jnp.array(adj_matrix_np)
            graph_metrics = calculate_graph_metrics(G)
            graph_metrics["graph_type"] = graph_type

            # 2. Generate a batch of initial conditions
            thetas0_batch, omegas_batch = generate_batched_initial_states(state_key, SIMS_PER_GRAPH, config)

            # 3. Run all simulations and analyses in a single, efficient call
            trajectories, final_R, mean_freqs = run_batched_simulations(
                thetas0_batch, omegas_batch, adj_matrix_jax, config
            )
            trajectories.block_until_ready()  # Ensure computation is finished

            # 4. Package all data for logging
            log_data = {
                "config": {
                    "num_agents": config.num_agents,
                    "coupling": config.coupling,
                    "dt": config.dt,
                    "T": config.T,
                    "normalize_by_degree": config.normalize_by_degree,
                    "directed": config.directed,
                    "weighted": config.weighted,
                },
                "graph_metrics": graph_metrics,
                "adjacency_matrix": adj_matrix_np.tolist(),
                "initial_conditions": {
                    "thetas0": thetas0_batch.tolist(),
                    "omegas": omegas_batch.tolist()
                },
                "results": {
                    "final_coherence": final_R.tolist(),
                    "mean_frequencies": mean_freqs.tolist()
                },
                # Optionally save full trajectories. Warning: can create large files.
                "trajectories": trajectories.tolist()
            }

            # 5. Save the complete log to a JSON file
            file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
            with open(file_path, 'w') as f:
                json.dump(log_data, f, indent=2)

    print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")


if __name__ == "__main__":
    main()
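For reference, the coherence R(t) logged above is the standard Kuramoto order parameter; a minimal sketch of the usual definition (phase_coherence in sims.kuramoto is assumed to implement something equivalent):

import jax.numpy as jnp

def order_parameter(thetas: jnp.ndarray) -> jnp.ndarray:
    """R(t) = |mean_j exp(i * theta_j(t))| for a trajectory shaped (T, N)."""
    return jnp.abs(jnp.mean(jnp.exp(1j * thetas), axis=-1))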
							
								
								
									
model.py (new file, +136)
@@ -0,0 +1,136 @@
import jax
from jax import random
import jax.numpy as jnp
from train import ModelConfig, TrainConfig
import optax
from functools import partial

def init_linear_layer(
        key: jax.Array,
        in_features: int,
        out_features: int,
        use_bias: bool = True
        ):
    """
    Initializes the weights and biases of a linear layer with He-uniform
    scaling (limit = sqrt(6 / fan_in)).
    """
    key_w, key_b = random.split(key)
    limit = jnp.sqrt(6 / in_features)
    W = random.uniform(key_w, (in_features, out_features), minval=-limit, maxval=limit)
    params = {'W': W}
    if use_bias:
        b = random.uniform(key_b, (out_features,), minval=-limit, maxval=limit)
        params['b'] = b
    return params

def init_fn(key: jax.Array, config: ModelConfig):
    """
    Initializes all model parameters. Returns a pytree.
    """
    key_embed, key_translate, key_attn_proj, key_head = random.split(key, 4)

    params = {
        "agent_embeddings": {
            "weight": random.normal(key_embed, shape=(config.num_agents, config.embedding_dim))
        },
        "translate": init_linear_layer(key_translate, config.input_dim, config.embedding_dim),
        "attn_proj": init_linear_layer(key_attn_proj, config.embedding_dim, 2 * config.embedding_dim, use_bias=False),
        "head": init_linear_layer(key_head, config.embedding_dim, config.output_dim)
    }

    return params

def forward(params: dict, input_timesteps: jax.Array, config: ModelConfig):
    """
    The model's forward pass. Takes the parameters and input timesteps of
    shape (batch, num_agents, input_dim) and returns predictions of shape
    (batch, num_agents, output_dim).
    """
    batch_size, num_agents, _ = input_timesteps.shape

    # Per-agent embeddings are shared across the batch
    agent_embed = params["agent_embeddings"]["weight"]
    agent_embed = jnp.broadcast_to(agent_embed, (batch_size, num_agents, config.embedding_dim))

    # Keys/queries come from the static agent embeddings; values from the inputs
    attn_proj_out = agent_embed @ params["attn_proj"]['W']
    k, q = jnp.split(attn_proj_out, 2, axis=-1)
    v = input_timesteps @ params["translate"]['W'] + params["translate"]['b']
    att_scores = (q @ k.transpose(0, 2, 1)) / jnp.sqrt(num_agents)
    att_weights = jax.nn.softmax(att_scores, axis=-1)
    weighted_average = att_weights @ v
    prediction = weighted_average @ params["head"]['W'] + params["head"]['b']

    return prediction

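# Shape walkthrough on random inputs (sketch; 10 agents, scalar states):
#   cfg = ModelConfig()
#   cfg.num_agents, cfg.embedding_dim = 10, 64
#   params = init_fn(random.PRNGKey(0), cfg)
#   x = jnp.ones((32, 10, 1))        # (batch, agents, input_dim)
#   forward(params, x, cfg).shape    # (32, 10, 1)
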
def get_attention_fn(params: dict, config: ModelConfig):
    """
    Calculates and returns the learned attention matrix between agents.
    This is a pure function for analysis.
    """
    embeddings = params['agent_embeddings']['weight']

    # Project embeddings to get keys (k) and queries (q) for the global graph
    attn_proj_out = embeddings @ params['attn_proj']['W']
    k, q = jnp.split(attn_proj_out, 2, axis=-1)

    # Note: scaled by sqrt(embedding_dim) here, unlike forward(), which
    # scales by sqrt(num_agents)
    attn_scores = (q @ k.T) / jnp.sqrt(q.shape[-1])

    return jnp.asarray(attn_scores)  # A JAX array; callers convert with np.array for logging

def train_model(config: ModelConfig, inputs: jax.Array, targets: jax.Array,
                true_graph: jax.Array,
                train_config: TrainConfig,
                ):
    key = random.PRNGKey(0)
    key, init_key = random.split(key)
    params = init_fn(init_key, config)

    optimizer = optax.adamw(train_config.learning_rate)
    opt_state = optimizer.init(params)

    def loss_fn(p, x_batch, y_batch, config):
        # Mean absolute error between one-step predictions and targets
        predictions = forward(p, x_batch, config)
        loss = jnp.mean(jnp.abs(predictions - y_batch))
        return loss

    @partial(jax.jit, static_argnames=['config'])
    def update_step(params, opt_state, x_batch, y_batch, config):
        loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, config)

        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)

        return new_params, new_opt_state, loss_val

    loss_history = {f"epoch_{i}": [] for i in range(train_config.epochs)}
    graphs = {f"epoch_{i}": [] for i in range(train_config.epochs)}
    num_batches = len(inputs)

    for epoch in range(train_config.epochs):
        running_loss = 0.0
        for batch_num in range(num_batches):
            x, y = inputs[batch_num], targets[batch_num]
            # config is passed explicitly because it is a static argument for JIT
            params, opt_state, loss_val = update_step(params, opt_state, x, y, config)
            running_loss += loss_val

        epoch_loss = running_loss / num_batches
        loss_history[f"epoch_{epoch}"].append(epoch_loss)

        if train_config.verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss:.6f}")

        if train_config.log and epoch % train_config.log_epoch_interval == 0:
            attn = get_attention_fn(params, config)
            graphs[f"epoch_{epoch}"].append(attn)

    all_logs = {
        "loss_history": loss_history,
        "graphs": graphs,
        "true_graph": true_graph,
    }

    return params, all_logs
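A smoke-test sketch for train_model on synthetic data (hypothetical shapes; TrainConfig fields as defined in train.py):

import numpy as np
from jax import random
from model import train_model
from train import ModelConfig, TrainConfig

cfg = ModelConfig()
cfg.num_agents, cfg.embedding_dim = 5, 16
tcfg = TrainConfig()
tcfg.epochs = 2

inputs = np.random.randn(3, 8, 5, 1)    # (num_batches, batch, agents, 1)
targets = np.random.randn(3, 8, 5, 1)
params, logs = train_model(cfg, inputs, targets, true_graph=np.eye(5), train_config=tcfg)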
@@ -10,6 +10,9 @@ dependencies = [
    "jax[cuda12]>=0.7.0",
    "jupyter>=1.1.1",
    "matplotlib>=3.10.3",
    "networkx>=3.5",
    "optax>=0.2.5",
    "scikit-learn>=1.7.1",
    "seaborn>=0.13.2",
    "tqdm>=4.67.1",
]
							
								
								
									
sims/__init__.py (new file, +2)
@@ -0,0 +1,2 @@
from .consensus import *
from .kuramoto import *
sims/__pycache__/__init__.cpython-312.pyc (new binary file, not shown)
sims/__pycache__/consensus.cpython-312.pyc (new binary file, not shown)
sims/__pycache__/kuramoto.cpython-312.pyc (new binary file, not shown)
@@ -61,6 +61,10 @@ def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
        config: ConsensusConfig
            Config for Consensus dynamics

    Returns
    ---------------------
    adj_matrices: jax.Array (num_agents, num_agents)
        Random matrix
    """
    rand_matrix = jax.random.uniform(key, shape=(config.num_agents, config.num_agents))
    # idxs = jnp.arange(config.num_agents)
@@ -90,7 +94,7 @@ def generate_random_agent_states(key: jax.Array, config: ConsensusConfig):

    Returns
    ---------------------
        rand_states: jax.Array (num_agents, 1)
        rand_states: jax.Array (num_sims, num_agents)

    """
    rand_states = jax.random.uniform(key, shape=(config.num_sims, config.num_agents), minval=-config.max_range, maxval=config.max_range)
@@ -113,7 +117,7 @@ def run_consensus_sim(adj_mat: jax.Array, initial_agent_state: jax.Array, config
        config: ConsensusConfig
            Config for Consensus dynamics
    """

    # batched consensus step (meant for many initial states)
    batched_consensus_step = jax.vmap(consensus_step, in_axes=(None, 0, None), out_axes=0)

    def step(x_prev, _):
@@ -5,26 +5,18 @@ from functools import partial
import numpy as np
import matplotlib.pyplot as plt

# -------------------- Configuration --------------------
@dataclass(frozen=True)
class KuramotoConfig:
    """Configuration for the Kuramoto model simulation."""
    num_agents: int = 10         # N: Number of oscillators
    coupling: float = 1.0        # K: Coupling strength
    dt: float = 0.01             # Δt: Integration time step
    T: float = 10.0              # Total simulation time

    # Adjacency matrix properties
    normalize_by_degree: bool = True
    time_steps: int = int(T / dt)
    normalize_by_degree: bool = False
    directed: bool = False
    weighted: bool = False

    @property
    def num_time_steps(self) -> int:
        """Total number of simulation steps."""
        return int(self.T / self.dt)

# -------------------- Core Dynamics --------------------
@partial(jax.jit, static_argnames=("config",))
def kuramoto_derivative(theta: jax.Array,        # (N,) phase angles
                        omega: jax.Array,        # (N,) natural frequencies
@@ -88,7 +80,7 @@ def run_kuramoto_simulation(
        scan_fn,
        thetas0,
        None,
        length=config.num_time_steps
        length=config.time_steps
    )
    return trajectory
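One subtlety in this change: a class-level default like time_steps = int(T / dt) is evaluated once, at class-definition time, from the default T and dt, so it does not track per-instance overrides. A minimal illustration:

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Cfg:
    T: float = 10.0
    dt: float = 0.01
    time_steps: int = int(T / dt)   # computed once from the defaults above

cfg = replace(Cfg(), T=15.0)
print(cfg.time_steps)               # still 1000, not 1500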
							
								
								
									
test.ipynb (192 lines changed)
File diff suppressed because one or more lines are too long
							
								
								
									
train.py (new file, +19)
@@ -0,0 +1,19 @@
# Left as plain classes rather than dataclasses: model.py passes ModelConfig
# as a jax.jit static argument, which must be hashable, and a default
# (non-frozen) dataclass instance is not.
class ModelConfig:
    num_agents: int = 10
    embedding_dim: int = 64
    input_dim: int = 1
    output_dim: int = 1
    simulation_type: str = "consensus"

class TrainConfig:
    epochs: int = 100
    learning_rate: float = 1e-3
    verbose: bool = True
    log: bool = True
    log_epoch_interval: int = 10
							
								
								
									
train_and_eval.py (new file, +259)
@@ -0,0 +1,259 @@
import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from dataclasses import dataclass, field, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score
import uuid
import pickle
import sys


# Import from the existing model file
from model import ModelConfig, TrainConfig, train_model, get_attention_fn

# Override the imported TrainConfig with this script's parameters

@dataclass
class TrainConfig:
    """Configuration for the training process."""
    learning_rate: float = 1e-3
    epochs: int = 100
    batch_size: int = 4096
    verbose: bool = False  # Set to True to see epoch loss during training
    log: bool = True
    log_epoch_interval: int = 10

    # --- New parameters for this script ---
    # Dataset name is read from the command line at class-definition time,
    # e.g. `python train_and_eval.py consensus`
    data_directory: str = "datasets/" + sys.argv[1] + "_dataset"
    # Threshold for converting raw attention scores to a binary graph
    f1_threshold: float = -0.4

def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts simulation trajectories into input-output pairs for the model.
    Input: state at time t. Target: state at time t+1.

    Args:
        trajectories: A numpy array of shape (num_sims, num_timesteps, num_agents).
        batch_size: The desired batch size for training.

    Returns:
        A tuple of (batched_inputs, batched_targets).
    """
    # For each simulation, create (input, target) pairs
    all_inputs = []
    all_targets = []

    num_sims, num_timesteps, num_agents = trajectories.shape

    for i_sim in range(num_sims):
        for j_tstep in range(num_timesteps - 1):
            all_inputs.append(trajectories[i_sim, j_tstep, :])
            all_targets.append(trajectories[i_sim, j_tstep + 1, :])

    # Shuffle the pairs jointly so batches mix simulations and timesteps
    all_indices = np.arange(len(all_inputs))
    np.random.shuffle(all_indices)
    all_inputs = np.array(all_inputs)[all_indices]
    all_targets = np.array(all_targets)[all_indices]

    # Reshape to have a feature dimension
    # Shape -> (total_samples, num_agents, 1)
    full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
    full_dataset_targets = np.expand_dims(all_targets, axis=-1)

    # Create batches, truncating to full batches
    num_samples = full_dataset_inputs.shape[0]
    num_batches = num_samples // batch_size
    truncated_inputs = full_dataset_inputs[:num_batches * batch_size]
    truncated_targets = full_dataset_targets[:num_batches * batch_size]

    # Reshape into batches
    # Shape -> (num_batches, batch_size, num_agents, 1)
    batched_inputs = truncated_inputs.reshape(num_batches, batch_size, num_agents, 1)
    batched_targets = truncated_targets.reshape(num_batches, batch_size, num_agents, 1)

    return batched_inputs, batched_targets

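# Shape sketch (toy numbers): 4 sims x 11 timesteps gives 40 (t, t+1) pairs;
# with batch_size=16 that truncates to 2 full batches.
#   traj = np.random.randn(4, 11, 5)        # (sims, timesteps, agents)
#   xb, yb = prepare_data_for_model(traj, batch_size=16)
#   xb.shape, yb.shape                      # (2, 16, 5, 1), (2, 16, 5, 1)
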
def calculate_f1_score(
    params: dict,
    model_config: ModelConfig,
    true_graph: np.ndarray,
    threshold: float
) -> float:
    """
    Extracts the learned attention graph, thresholds it, and computes the F1 score.
    """
    # Get the learned attention matrix (N, N)
    learned_graph_scores = np.array(get_attention_fn(params, model_config))

    # Threshold the raw scores to get a binary predicted graph (the default
    # f1_threshold of -0.4 is tuned for unnormalized attention scores; the
    # noise-injection variant of this script min-max normalizes instead)
    predicted_graph = (learned_graph_scores > threshold).astype(int)

    # Flatten both graphs to treat this as a binary classification problem
    # over all (i, j) entries, self-loops included
    true_flat = true_graph.flatten()
    pred_flat = predicted_graph.flatten()

    return f1_score(true_flat, pred_flat)

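# Worked micro-example of the thresholding (illustrative values only):
#   scores = np.array([[0.9, -0.6], [-0.2, 0.8]])
#   (scores > -0.4).astype(int)   # [[1, 0], [1, 1]]
#   # f1_score then compares this flattened prediction against the
#   # flattened 0/1 adjacency matrix.
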
def main():
    """Main script to run the training and evaluation pipeline."""

    train_config = TrainConfig()

    # Check that the data directory exists
    if not os.path.isdir(train_config.data_directory):
        print(f"Error: Data directory '{train_config.data_directory}' not found.")
        print(f"Please run the data generation script for '{train_config.data_directory}' first.")
        return

    print(f"🚀 Starting training pipeline for '{train_config.data_directory}' data.")

    # Get the agent directories, sorted by agent count
    agent_dirs = sorted(
        [d for d in os.listdir(train_config.data_directory) if d.startswith("agents_")],
        key=lambda x: int(x.split('_')[1])
    )

    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)

        all_results_for_agent = []

        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])

        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")

        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)

            with open(file_path, 'r') as f:
                data = json.load(f)

            # 1. Load and Prepare Data
            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])
            inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)

            # 2. Configure Model
            num_agents = trajectories.shape[-1]
            model_config = ModelConfig()
            model_config.num_agents = num_agents
            model_config.input_dim = 1  # Each agent has a single state value at time t
            model_config.output_dim = 1
            model_config.embedding_dim = 32

            # 3. Train the Model
            # This relies on the modified train_model that returns final params
            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )

            # 4. Evaluate
            f1 = calculate_f1_score(
                final_params,
                model_config,
                true_graph,
                train_config.f1_threshold
            )

            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }

            random_id = str(uuid.uuid4())

            # 5. Log Results
            result_log = {
                # "model_name": random_id,
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    # Manually build the dictionary for the model config,
                    # since ModelConfig is a plain class
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    # vars() works here because TrainConfig is a dataclass
                    "training": vars(train_config)
                }
            }

            all_results_for_agent.append(result_log)

        # 6. Save aggregated results for this agent count
        results_dir = os.path.join(agent_dir_path, "results")
        os.makedirs(results_dir, exist_ok=True)

        output_file = os.path.join(results_dir, "summary_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results_for_agent, f, indent=2)

        # Note: only the params from the last trained graph in this
        # directory are pickled
        model_path = os.path.join(results_dir, "model_params")
        os.makedirs(model_path, exist_ok=True)
        with open(os.path.join(model_path, "model_params.pkl"), "wb") as f:
            pickle.dump(final_params, f)

        print(f"✅ Results for {agent_dir_name} saved to {output_file}")

    print("\n🎉 Pipeline finished successfully!")


if __name__ == "__main__":
    main()
							
								
								
									
train_and_eval_w_noise.py (new file, +217)
@@ -0,0 +1,217 @@
import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from dataclasses import dataclass, field, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score

# Import from the existing model file
from model import ModelConfig, train_model, get_attention_fn

# --- MODIFICATION: Enums for clarity and safety ---
class DataSource(Enum):
    CONSENSUS = "consensus"
    KURAMOTO = "kuramoto"

class NoiseType(Enum):
    NONE = "none"
    NORMAL = "normal"
    UNIFORM = "uniform"

# --- MODIFICATION: Updated TrainConfig ---
@dataclass
class TrainConfig:
    """Configuration for the training process."""
    learning_rate: float = 1e-3
    epochs: int = 50
    batch_size: int = 64
    verbose: bool = False
    log: bool = True
    log_epoch_interval: int = 10

    # Data and Noise parameters
    data_source: DataSource = DataSource.CONSENSUS
    noise_type: NoiseType = NoiseType.NORMAL
    noise_level: float = 0.1  # Stddev for Normal, half-width for Uniform

    # Evaluation parameter
    f1_threshold: float = 0.5

    @property
    def source_data_directory(self) -> str:
        """The directory where the source dataset is located."""
        return f"{self.data_source.value}_dataset"

    @property
    def results_directory_name(self) -> str:
        """Generates a unique output directory name for separate logging."""
        if self.noise_type == NoiseType.NONE or self.noise_level == 0:
            return "results_noiseless"
        return f"results_noise_{self.noise_type.value}_{self.noise_level}"

# --- MODIFICATION: Updated data prep function to add noise ---
def prepare_data_for_model(
    trajectories: np.ndarray,
    key: jax.Array,
    train_config: TrainConfig
) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts trajectories to input-output pairs and adds noise to the inputs.
    """
    all_inputs, all_targets = [], []
    num_sims, num_timesteps, num_agents = trajectories.shape

    for sim_idx in range(num_sims):
        all_inputs.append(trajectories[sim_idx, :-1, :])
        all_targets.append(trajectories[sim_idx, 1:, :])

    full_dataset_inputs = np.concatenate(all_inputs, axis=0)
    full_dataset_targets = np.concatenate(all_targets, axis=0)

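    # At this point row i of full_dataset_inputs is a state at time t and row i
    # of full_dataset_targets is the same simulation's state at time t+1, i.e.
    # the arrays hold one-step-ahead prediction pairs.
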
    # --- NOISE INJECTION BLOCK ---
    if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
        noise_shape = full_dataset_inputs.shape
        if train_config.noise_type == NoiseType.NORMAL:
            noise = jax.random.normal(key, noise_shape) * train_config.noise_level
        elif train_config.noise_type == NoiseType.UNIFORM:
            noise = jax.random.uniform(
                key, noise_shape,
                minval=-train_config.noise_level,
                maxval=train_config.noise_level
            )
        full_dataset_inputs += np.array(noise)  # Add noise to inputs
    # --- END NOISE BLOCK ---

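    # Note: jax.random.normal/uniform return float32 device arrays under JAX's
    # defaults; np.array(noise) copies them to host, and the in-place += keeps
    # the inputs' original NumPy dtype.
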
    full_dataset_inputs = np.expand_dims(full_dataset_inputs, axis=-1)
    full_dataset_targets = np.expand_dims(full_dataset_targets, axis=-1)

    num_samples = full_dataset_inputs.shape[0]
    num_batches = num_samples // train_config.batch_size

    truncated_inputs = full_dataset_inputs[:num_batches * train_config.batch_size]
    truncated_targets = full_dataset_targets[:num_batches * train_config.batch_size]

    batched_inputs = truncated_inputs.reshape(num_batches, train_config.batch_size, num_agents, 1)
    batched_targets = truncated_targets.reshape(num_batches, train_config.batch_size, num_agents, 1)

    return batched_inputs, batched_targets

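# Batching sketch with hypothetical sizes: 100 simulations of 51 timesteps over
# 5 agents give 100 * 50 = 5000 pairs; with batch_size=64 that is 78 full
# batches, and the remaining 5000 - 78 * 64 = 8 samples are dropped by the
# truncation in prepare_data_for_model.
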
def calculate_f1_score(
    params: dict,
    model_config: ModelConfig,
    true_graph: np.ndarray,
    threshold: float
) -> float:
    """Extracts the learned graph and computes the F1 score."""
    learned_scores = np.array(get_attention_fn(params, model_config))
    flat_scores = learned_scores.flatten()
    min_s, max_s = flat_scores.min(), flat_scores.max()
    if max_s > min_s:
        learned_scores = (learned_scores - min_s) / (max_s - min_s)
    predicted_graph = (learned_scores > threshold).astype(int)

    np.fill_diagonal(predicted_graph, 0)
    np.fill_diagonal(true_graph, 0)

    return f1_score(true_graph.flatten(), predicted_graph.flatten())

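# The min-max rescaling in calculate_f1_score maps raw attention scores onto
# [0, 1] so a single f1_threshold is comparable across graphs; both diagonals
# are zeroed because self-loops are excluded from the recovered topology (note
# that np.fill_diagonal mutates true_graph in place).
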
def main():
    """Main script to run the training and evaluation pipeline."""

    # Configure your training run here
    train_config = TrainConfig(
        noise_type=NoiseType.NORMAL,
        noise_level=0.1
    )

    if not os.path.isdir(train_config.source_data_directory):
        print(f"Error: Source data '{train_config.source_data_directory}' not found.")
        return

    print(f"🚀 Starting training pipeline for '{train_config.data_source.value}' data.")
 | 
			
		||||
    print(f"Noise Configuration: type={train_config.noise_type.value}, level={train_config.noise_level}")
 | 
			
		||||
 | 
			
		||||
    # --- MODIFICATION: Main JAX key for noise generation ---
 | 
			
		||||
    main_key = jax.random.PRNGKey(42)
 | 
			
		||||
 | 
			
		||||
    agent_dirs = sorted(
 | 
			
		||||
        [d for d in os.listdir(train_config.source_data_directory) if d.startswith("agents_")],
 | 
			
		||||
        key=lambda x: int(x.split('_')[1])
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
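    # Sorting with int(x.split('_')[1]) keeps numeric order ("agents_2" before
    # "agents_10"), which a plain lexicographic sort would break.
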
    for agent_dir_name in agent_dirs:
        agent_dir_path = os.path.join(train_config.source_data_directory, agent_dir_name)
        all_results_for_agent = []
        graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])

        print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")

        for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
            file_path = os.path.join(agent_dir_path, graph_file_name)

            with open(file_path, 'r') as f:
                data = json.load(f)

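            # Split off a fresh key per graph so each file gets independent
            # noise while the whole run stays reproducible from PRNGKey(42).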
            main_key, data_key = jax.random.split(main_key)
            trajectories = np.array(data['trajectories'])
            true_graph = np.array(data['adjacency_matrix'])

            # --- MODIFICATION: Pass key and config to data prep ---
            inputs, targets = prepare_data_for_model(trajectories, data_key, train_config)

            num_agents = int(trajectories.shape[-1])
            model_config = ModelConfig()
            model_config.num_agents = num_agents
            model_config.input_dim = 1
            model_config.output_dim = 1
            model_config.embedding_dim = 32

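            # Only num_agents varies across graph files; the 1-D input/output
            # and the 32-dim embedding are held fixed for this sweep.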
            final_params, train_logs = train_model(
                config=model_config,
                inputs=inputs,
                targets=targets,
                true_graph=true_graph,
                train_config=train_config
            )

            f1 = calculate_f1_score(final_params, model_config, true_graph, train_config.f1_threshold)

            # loss.item() converts JAX/NumPy scalars to plain Python floats so
            # the history can be written out as JSON.
            loss_history_serializable = {
                epoch: [loss.item() for loss in losses]
                for epoch, losses in train_logs['loss_history'].items()
            }

            result_log = {
                "source_file": graph_file_name,
                "graph_metrics": data['graph_metrics'],
                "f1_score": f1,
                "training_loss_history": loss_history_serializable,
                "config": {
                    "model": {
                        "num_agents": model_config.num_agents,
                        "input_dim": model_config.input_dim,
                        "output_dim": model_config.output_dim,
                        "embedding_dim": model_config.embedding_dim
                    },
                    # asdict() keeps the Enum members as-is, which json.dump
                    # cannot serialize; overwrite them with their string values.
                    "training": {
                        **asdict(train_config),
                        "data_source": train_config.data_source.value,
                        "noise_type": train_config.noise_type.value
                    }
                }
            }
            all_results_for_agent.append(result_log)

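        # Each result_log bundles provenance (source_file), the ground-truth
        # graph metrics, the F1 score, the loss history, and the full config,
        # so summary_results.json is self-describing.
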
        # --- MODIFICATION: Save to a separate results directory ---
        results_dir = os.path.join(agent_dir_path, train_config.results_directory_name)
        os.makedirs(results_dir, exist_ok=True)

        output_file = os.path.join(results_dir, "summary_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results_for_agent, f, indent=2)

        print(f"✅ Results for {agent_dir_name} saved to {output_file}")

    print("\n🎉 Pipeline finished successfully!")


if __name__ == "__main__":
    main()

114  uv.lock  generated
@@ -6,6 +6,15 @@ resolution-markers = [
    "python_full_version < '3.13'",
]

[[package]]
name = "absl-py"
version = "2.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" },
]

[[package]]
name = "anyio"
version = "4.9.0"
@@ -218,6 +227,24 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" },
]

[[package]]
name = "chex"
version = "0.1.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "absl-py" },
    { name = "jax" },
    { name = "jaxlib" },
    { name = "numpy" },
    { name = "setuptools" },
    { name = "toolz" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/77/70/53c7d404ce9e2a94009aea7f77ef6e392f6740e071c62683a506647c520f/chex-0.1.90.tar.gz", hash = "sha256:d3c375aeb6154b08f1cccd2bee4ed83659ee2198a6acf1160d2fe2e4a6c87b5c", size = 92363, upload-time = "2025-07-23T19:50:47.945Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/6f/3d/46bb04776c465cea2dd8aa2d4b61ab610b707f798f47838ef7e6105b025c/chex-0.1.90-py3-none-any.whl", hash = "sha256:fce3de82588f72d4796e545e574a433aa29229cbdcf792555e41bead24b704ae", size = 101047, upload-time = "2025-07-23T19:50:46.603Z" },
]

[[package]]
name = "colorama"
version = "0.4.6"
@@ -386,6 +413,9 @@ dependencies = [
    { name = "jax", extra = ["cuda12"] },
    { name = "jupyter" },
    { name = "matplotlib" },
    { name = "networkx" },
    { name = "optax" },
    { name = "scikit-learn" },
    { name = "seaborn" },
    { name = "tqdm" },
]
@@ -397,6 +427,9 @@ requires-dist = [
    { name = "jax", extras = ["cuda12"], specifier = ">=0.7.0" },
    { name = "jupyter", specifier = ">=1.1.1" },
    { name = "matplotlib", specifier = ">=3.10.3" },
    { name = "networkx", specifier = ">=3.5" },
    { name = "optax", specifier = ">=0.2.5" },
    { name = "scikit-learn", specifier = ">=1.7.1" },
    { name = "seaborn", specifier = ">=0.13.2" },
    { name = "tqdm", specifier = ">=4.67.1" },
]
@@ -641,6 +674,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
]

[[package]]
name = "joblib"
version = "1.5.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/dc/fe/0f5a938c54105553436dbff7a61dc4fed4b1b2c98852f8833beaf4d5968f/joblib-1.5.1.tar.gz", hash = "sha256:f4f86e351f39fe3d0d32a9f2c3d8af1ee4cec285aafcb27003dda5205576b444", size = 330475, upload-time = "2025-05-23T12:04:37.097Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/7d/4f/1195bbac8e0c2acc5f740661631d8d750dc38d4a32b23ee5df3cde6f4e0d/joblib-1.5.1-py3-none-any.whl", hash = "sha256:4719a31f054c7d766948dcd83e9613686b27114f190f717cec7eaa2084f8a74a", size = 307746, upload-time = "2025-05-23T12:04:35.124Z" },
]

[[package]]
name = "json5"
version = "0.12.0"
@@ -1141,6 +1183,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" },
]

[[package]]
name = "networkx"
version = "3.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" },
]

[[package]]
name = "notebook"
version = "7.4.4"
@@ -1373,6 +1424,22 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" },
]

[[package]]
name = "optax"
version = "0.2.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "absl-py" },
    { name = "chex" },
    { name = "jax" },
    { name = "jaxlib" },
    { name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/75/1e011953c48be502d4d84fa8458e91be7c6f983002511669bddd7b1a065f/optax-0.2.5.tar.gz", hash = "sha256:b2e38c7aea376186deae758ba7a258e6ef760c6f6131e9e11bc561c65386d594", size = 258548, upload-time = "2025-06-10T17:00:47.544Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/b9/33/f86091c706db1a5459f501830241afff2ecab3532725c188ea57be6e54de/optax-0.2.5-py3-none-any.whl", hash = "sha256:966deae936207f268ac8f564d8ed228d645ac1aaddefbbf194096d2299b24ba8", size = 354324, upload-time = "2025-06-10T17:00:46.062Z" },
]

[[package]]
name = "overrides"
version = "7.7.0"
@@ -1862,6 +1929,35 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/75/04/5302cea1aa26d886d34cadbf2dc77d90d7737e576c0065f357b96dc7a1a6/rpds_py-0.26.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f14440b9573a6f76b4ee4770c13f0b5921f71dde3b6fcb8dabbefd13b7fe05d7", size = 232821, upload-time = "2025-07-01T15:55:55.167Z" },
]

[[package]]
name = "scikit-learn"
version = "1.7.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "joblib" },
    { name = "numpy" },
    { name = "scipy" },
    { name = "threadpoolctl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/41/84/5f4af978fff619706b8961accac84780a6d298d82a8873446f72edb4ead0/scikit_learn-1.7.1.tar.gz", hash = "sha256:24b3f1e976a4665aa74ee0fcaac2b8fccc6ae77c8e07ab25da3ba6d3292b9802", size = 7190445, upload-time = "2025-07-18T08:01:54.5Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/cb/16/57f176585b35ed865f51b04117947fe20f130f78940c6477b6d66279c9c2/scikit_learn-1.7.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3cee419b49b5bbae8796ecd690f97aa412ef1674410c23fc3257c6b8b85b8087", size = 9260431, upload-time = "2025-07-18T08:01:22.77Z" },
    { url = "https://files.pythonhosted.org/packages/67/4e/899317092f5efcab0e9bc929e3391341cec8fb0e816c4789686770024580/scikit_learn-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2fd8b8d35817b0d9ebf0b576f7d5ffbbabdb55536b0655a8aaae629d7ffd2e1f", size = 8637191, upload-time = "2025-07-18T08:01:24.731Z" },
    { url = "https://files.pythonhosted.org/packages/f3/1b/998312db6d361ded1dd56b457ada371a8d8d77ca2195a7d18fd8a1736f21/scikit_learn-1.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:588410fa19a96a69763202f1d6b7b91d5d7a5d73be36e189bc6396bfb355bd87", size = 9486346, upload-time = "2025-07-18T08:01:26.713Z" },
    { url = "https://files.pythonhosted.org/packages/ad/09/a2aa0b4e644e5c4ede7006748f24e72863ba2ae71897fecfd832afea01b4/scikit_learn-1.7.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3142f0abe1ad1d1c31a2ae987621e41f6b578144a911ff4ac94781a583adad7", size = 9290988, upload-time = "2025-07-18T08:01:28.938Z" },
    { url = "https://files.pythonhosted.org/packages/15/fa/c61a787e35f05f17fc10523f567677ec4eeee5f95aa4798dbbbcd9625617/scikit_learn-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:3ddd9092c1bd469acab337d87930067c87eac6bd544f8d5027430983f1e1ae88", size = 8735568, upload-time = "2025-07-18T08:01:30.936Z" },
    { url = "https://files.pythonhosted.org/packages/52/f8/e0533303f318a0f37b88300d21f79b6ac067188d4824f1047a37214ab718/scikit_learn-1.7.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b7839687fa46d02e01035ad775982f2470be2668e13ddd151f0f55a5bf123bae", size = 9213143, upload-time = "2025-07-18T08:01:32.942Z" },
    { url = "https://files.pythonhosted.org/packages/71/f3/f1df377d1bdfc3e3e2adc9c119c238b182293e6740df4cbeac6de2cc3e23/scikit_learn-1.7.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a10f276639195a96c86aa572ee0698ad64ee939a7b042060b98bd1930c261d10", size = 8591977, upload-time = "2025-07-18T08:01:34.967Z" },
    { url = "https://files.pythonhosted.org/packages/99/72/c86a4cd867816350fe8dee13f30222340b9cd6b96173955819a5561810c5/scikit_learn-1.7.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:13679981fdaebc10cc4c13c43344416a86fcbc61449cb3e6517e1df9d12c8309", size = 9436142, upload-time = "2025-07-18T08:01:37.397Z" },
    { url = "https://files.pythonhosted.org/packages/e8/66/277967b29bd297538dc7a6ecfb1a7dce751beabd0d7f7a2233be7a4f7832/scikit_learn-1.7.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4f1262883c6a63f067a980a8cdd2d2e7f2513dddcef6a9eaada6416a7a7cbe43", size = 9282996, upload-time = "2025-07-18T08:01:39.721Z" },
    { url = "https://files.pythonhosted.org/packages/e2/47/9291cfa1db1dae9880420d1e07dbc7e8dd4a7cdbc42eaba22512e6bde958/scikit_learn-1.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:ca6d31fb10e04d50bfd2b50d66744729dbb512d4efd0223b864e2fdbfc4cee11", size = 8707418, upload-time = "2025-07-18T08:01:42.124Z" },
    { url = "https://files.pythonhosted.org/packages/61/95/45726819beccdaa34d3362ea9b2ff9f2b5d3b8bf721bd632675870308ceb/scikit_learn-1.7.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:781674d096303cfe3d351ae6963ff7c958db61cde3421cd490e3a5a58f2a94ae", size = 9561466, upload-time = "2025-07-18T08:01:44.195Z" },
    { url = "https://files.pythonhosted.org/packages/ee/1c/6f4b3344805de783d20a51eb24d4c9ad4b11a7f75c1801e6ec6d777361fd/scikit_learn-1.7.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:10679f7f125fe7ecd5fad37dd1aa2daae7e3ad8df7f3eefa08901b8254b3e12c", size = 9040467, upload-time = "2025-07-18T08:01:46.671Z" },
    { url = "https://files.pythonhosted.org/packages/6f/80/abe18fe471af9f1d181904203d62697998b27d9b62124cd281d740ded2f9/scikit_learn-1.7.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1f812729e38c8cb37f760dce71a9b83ccfb04f59b3dca7c6079dcdc60544fa9e", size = 9532052, upload-time = "2025-07-18T08:01:48.676Z" },
    { url = "https://files.pythonhosted.org/packages/14/82/b21aa1e0c4cee7e74864d3a5a721ab8fcae5ca55033cb6263dca297ed35b/scikit_learn-1.7.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:88e1a20131cf741b84b89567e1717f27a2ced228e0f29103426102bc2e3b8ef7", size = 9361575, upload-time = "2025-07-18T08:01:50.639Z" },
    { url = "https://files.pythonhosted.org/packages/f2/20/f4777fcd5627dc6695fa6b92179d0edb7a3ac1b91bcd9a1c7f64fa7ade23/scikit_learn-1.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b1bd1d919210b6a10b7554b717c9000b5485aa95a1d0f177ae0d7ee8ec750da5", size = 9277310, upload-time = "2025-07-18T08:01:52.547Z" },
]

[[package]]
name = "scipy"
version = "1.16.0"
@@ -1987,6 +2083,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154, upload-time = "2024-03-12T14:34:36.569Z" },
]

[[package]]
name = "threadpoolctl"
version = "3.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
]

[[package]]
name = "tinycss2"
version = "1.4.0"
@@ -1999,6 +2104,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" },
]

[[package]]
name = "toolz"
version = "1.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790, upload-time = "2024-10-04T16:17:04.001Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383, upload-time = "2024-10-04T16:17:01.533Z" },
]

[[package]]
name = "tornado"
version = "6.5.1"