import os
import json
import time
from dataclasses import asdict
from functools import partial

import jax
import jax.numpy as jnp
import networkx as nx
import numpy as np
from tqdm import tqdm

import sims
# --- Import necessary components from your Kuramoto simulation file ---
from sims.kuramoto import (
    KuramotoConfig,
    run_kuramoto_simulation,
    phase_coherence,
    mean_frequency,
)
from generate_data_consensus import generate_connected_graph, calculate_graph_metrics


# ==============================================================================
# SECTION 1: BATCHED OPERATIONS FOR EFFICIENT DATA GENERATION
# ==============================================================================

@partial(jax.jit, static_argnames=("config",))
def run_batched_simulations(
    thetas0_batch: jax.Array,  # (num_sims, N)
    omegas_batch: jax.Array,   # (num_sims, N)
    adj_mat: jax.Array,        # (N, N)
    config: KuramotoConfig,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Run many Kuramoto simulations on one graph and analyse them in a batch.

    Every simulation shares ``adj_mat`` and ``config`` but has its own initial
    phases and natural frequencies.

    Args:
        thetas0_batch: Initial phases, one row per simulation.
        omegas_batch: Natural frequencies, one row per simulation.
        adj_mat: Adjacency matrix shared by all simulations.
        config: Simulation settings. It is passed to ``jax.jit`` as a static
            argument, so it must be hashable and each distinct value triggers
            a new compilation.
            NOTE(review): confirm ``KuramotoConfig`` is hashable.

    Returns:
        A tuple containing:
        - All trajectories (num_sims, T, N)
        - Final coherence R(T) for each sim (num_sims,)
        - Mean frequencies for each sim (num_sims, N)
    """
    # Map the single-simulation runner over the leading axis of the initial
    # conditions; the graph and the config are broadcast to every simulation.
    vmapped_runner = jax.vmap(
        run_kuramoto_simulation,
        in_axes=(0, 0, None, None),
        out_axes=0,
    )
    all_trajectories = vmapped_runner(thetas0_batch, omegas_batch, adj_mat, config)

    # phase_coherence handles one trajectory, so vmap it over the sim axis and
    # keep only the last time step of each simulation.
    coherence_over_time = jax.vmap(phase_coherence)(all_trajectories)  # -> (num_sims, T)
    final_coherence = coherence_over_time[:, -1]  # -> (num_sims,)

    # Mean frequencies are computed per simulation from its own trajectory and
    # natural frequencies.
    vmapped_mean_freq = jax.vmap(mean_frequency, in_axes=(0, 0, None, None))
    all_mean_freqs = vmapped_mean_freq(all_trajectories, omegas_batch, adj_mat, config)

    return all_trajectories, final_coherence, all_mean_freqs


def generate_batched_initial_states(
    key: jax.Array,
    num_sims: int,
    config: KuramotoConfig,
) -> tuple[jax.Array, jax.Array]:
    """Draw ``num_sims`` independent sets of initial conditions.

    Args:
        key: JAX PRNG key; it is split into one sub-key per simulation.
        num_sims: Number of initial states to generate.
        config: Only ``config.num_agents`` is read, to size each state.

    Returns:
        ``(thetas0_batch, omegas_batch)``, each of shape
        ``(num_sims, config.num_agents)``. Phases are uniform on [0, 2*pi);
        natural frequencies are standard normal.
    """
    keys = jax.random.split(key, num_sims)

    @jax.vmap
    def generate_single_state(k):
        # Independent sub-keys so phases and frequencies are uncorrelated.
        key_theta, key_omega = jax.random.split(k)
        thetas0 = jax.random.uniform(
            key_theta, (config.num_agents,), minval=0, maxval=2 * jnp.pi
        )
        omegas = jax.random.normal(key_omega, (config.num_agents,))  # mean=0, std=1
        return thetas0, omegas

    thetas0_batch, omegas_batch = generate_single_state(keys)
    return thetas0_batch, omegas_batch


def main() -> None:
    """Generate the Kuramoto simulation dataset and write it to disk.

    For every agent count, a fixed number of connected graphs is generated
    (split evenly across the graph algorithms). On each graph a batch of
    simulations is run, and the configuration, graph, initial conditions,
    results and full trajectories are written to one JSON file per graph under
    ``OUTPUT_DIR/agents_<N>/``.
    """
    # --- Configuration ---
    GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
    AGENT_COUNTS = range(5, 51, 5)  # 5, 10, ..., 50 agents
    GRAPHS_PER_AGENT_COUNT = 100
    # Floor at 1 so the integer division below can never divide by zero when
    # there are fewer graphs than algorithms.
    GRAPHS_PER_GRAPH_ALGO = max(1, GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS))
    SIMS_PER_GRAPH = 100
    OUTPUT_DIR = "datasets/kuramoto_dataset"

    # --- Setup ---
    # The seed is time-based, so report it: without it a run cannot be repeated.
    seed = int(time.time())
    main_key = jax.random.PRNGKey(seed)
    numpy_rng = np.random.default_rng(seed=seed + 1)

    print("🚀 Starting Kuramoto data generation...")
    print(f"Random seed: {seed}")
    print(f"Saving dataset to: '{OUTPUT_DIR}/'")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Main Loop ---
    for n_agents in tqdm(AGENT_COUNTS, desc="Agent Counts"):
        agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
        os.makedirs(agent_folder, exist_ok=True)

        for graph_idx in tqdm(
            range(GRAPHS_PER_AGENT_COUNT),
            desc=f"Graphs for N={n_agents}",
            leave=False,
        ):
            # Clamp so leftover graphs (when the count is not a multiple of
            # the number of algorithms) use the last algorithm instead of
            # indexing past the end of the list.
            graph_algo_idx = min(
                graph_idx // GRAPHS_PER_GRAPH_ALGO,
                len(GRAPH_GEN_ALGOS) - 1,
            )

            # 1. Setup config, keys, and graph
            config = KuramotoConfig()
            config.num_agents = n_agents
            config.coupling = 1.5
            config.T = 15.0

            # The graph itself is drawn from numpy_rng; the 3-way split is kept
            # so the sequence of JAX keys is unchanged.
            main_key, _graph_key, state_key = jax.random.split(main_key, 3)

            G, graph_type, adj_matrix_np = generate_connected_graph(
                numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx]
            )
            adj_matrix_jax = jnp.array(adj_matrix_np)
            graph_metrics = calculate_graph_metrics(G)
            graph_metrics["graph_type"] = graph_type

            # 2. Generate a batch of initial conditions
            thetas0_batch, omegas_batch = generate_batched_initial_states(
                state_key, SIMS_PER_GRAPH, config
            )

            # 3. Run all simulations and analyses in a single, efficient call
            trajectories, final_R, mean_freqs = run_batched_simulations(
                thetas0_batch, omegas_batch, adj_matrix_jax, config
            )
            trajectories.block_until_ready()  # Ensure computation is finished

            # 4. Package all data for logging
            log_data = {
                "config": {
                    "num_agents": config.num_agents,
                    "coupling": config.coupling,
                    "dt": config.dt,
                    "T": config.T,
                    "normalize_by_degree": config.normalize_by_degree,
                    "directed": config.directed,
                    "weighted": config.weighted,
                },
                "graph_metrics": graph_metrics,
                "adjacency_matrix": adj_matrix_np.tolist(),
                "initial_conditions": {
                    "thetas0": thetas0_batch.tolist(),
                    "omegas": omegas_batch.tolist(),
                },
                "results": {
                    "final_coherence": final_R.tolist(),
                    "mean_frequencies": mean_freqs.tolist(),
                },
                # Full trajectories. Warning: can create large files.
                "trajectories": trajectories.tolist(),
            }

            # 5. Save the complete log to a JSON file
            file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(log_data, f, indent=2)

    print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")


if __name__ == "__main__":
    main()