# 162 lines · 6.5 KiB · Python
import os
|
|
import json
|
|
import time
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
import networkx as nx
|
|
from tqdm import tqdm
|
|
from functools import partial
|
|
from dataclasses import asdict
|
|
import sims
|
|
# --- Import necessary components from your Kuramoto simulation file ---
|
|
from sims.kuramoto import (
|
|
KuramotoConfig,
|
|
run_kuramoto_simulation,
|
|
phase_coherence,
|
|
mean_frequency,
|
|
)
|
|
|
|
from generate_data_consensus import generate_connected_graph, calculate_graph_metrics
|
|
|
|
# ==============================================================================
|
|
# SECTION 1: BATCHED OPERATIONS FOR EFFICIENT DATA GENERATION
|
|
# ==============================================================================
|
|
|
|
@partial(jax.jit, static_argnames=("config",))
def run_batched_simulations(
    thetas0_batch: jax.Array,  # (num_sims, N)
    omegas_batch: jax.Array,  # (num_sims, N)
    adj_mat: jax.Array,  # (N, N)
    config: KuramotoConfig,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Simulate a batch of Kuramoto runs on one shared graph and summarize them.

    Each row of ``thetas0_batch`` / ``omegas_batch`` describes one independent
    run. The adjacency matrix and the configuration are shared by every run.

    Args:
        thetas0_batch: Initial phases, one row per simulation, (num_sims, N).
        omegas_batch: Natural frequencies, one row per simulation, (num_sims, N).
        adj_mat: Adjacency matrix of the shared graph, (N, N).
        config: Simulation configuration. It is a static argument of
            ``jax.jit``, so it must be hashable, and a config that compares
            unequal to previously seen ones causes a new trace/compile.

    Returns:
        ``(trajectories, final_coherence, mean_freqs)``:
          - ``trajectories``: the stacked outputs of ``run_kuramoto_simulation``,
            one per run (expected (num_sims, T, N)).
          - ``final_coherence``: the last entry of each run's
            ``phase_coherence`` series, (num_sims,).
          - ``mean_freqs``: the stacked outputs of ``mean_frequency``, one per
            run (expected (num_sims, N)).
    """
    # Batch over axis 0 of the per-run inputs; graph and config are broadcast
    # unchanged to every run.
    per_run_axes = (0, 0, None, None)

    simulate_all = jax.vmap(run_kuramoto_simulation, in_axes=per_run_axes)
    trajectories = simulate_all(thetas0_batch, omegas_batch, adj_mat, config)

    # phase_coherence is applied to one run's trajectory at a time; only the
    # value at the final time step is kept.
    coherence_series = jax.vmap(phase_coherence)(trajectories)
    final_coherence = coherence_series[:, -1]

    mean_freq_all = jax.vmap(mean_frequency, in_axes=per_run_axes)
    mean_freqs = mean_freq_all(trajectories, omegas_batch, adj_mat, config)

    return trajectories, final_coherence, mean_freqs
|
|
def generate_batched_initial_states(
    key: jax.Array,
    num_sims: int,
    config: KuramotoConfig,
    *,
    omega_mean: float = 0.0,
    omega_std: float = 1.0,
) -> tuple[jax.Array, jax.Array]:
    """Sample initial phases and natural frequencies for a batch of simulations.

    ``key`` is split into ``num_sims`` subkeys, one per simulation. Each subkey
    is split again into one key for the phases and one for the frequencies.

    Args:
        key: PRNG key from which every sample in the batch is derived.
        num_sims: Number of simulations (rows) to generate.
        config: Only ``config.num_agents`` is read; it sets the row length N.
        omega_mean: Mean of the Gaussian natural frequencies (default 0.0).
        omega_std: Standard deviation of the Gaussian natural frequencies
            (default 1.0). The defaults give a standard normal.

    Returns:
        ``(thetas0_batch, omegas_batch)``, each of shape (num_sims, N):
        phases drawn uniformly from [0, 2*pi), and frequencies
        ``omega_mean + omega_std * standard_normal``.
    """
    num_agents = config.num_agents
    sim_keys = jax.random.split(key, num_sims)

    def _sample_one(sim_key):
        # Draw one simulation's phases and frequencies from its own subkey.
        theta_key, omega_key = jax.random.split(sim_key)
        thetas0 = jax.random.uniform(
            theta_key, (num_agents,), minval=0, maxval=2 * jnp.pi
        )
        omegas = omega_mean + omega_std * jax.random.normal(omega_key, (num_agents,))
        return thetas0, omegas

    # vmap over the leading axis of sim_keys -> arrays of shape (num_sims, N).
    thetas0_batch, omegas_batch = jax.vmap(_sample_one)(sim_keys)
    return thetas0_batch, omegas_batch
|
|
|
|
def main() -> None:
    """Generate the Kuramoto simulation dataset and save it as JSON files.

    For every agent count in ``AGENT_COUNTS``, this builds
    ``GRAPHS_PER_AGENT_COUNT`` connected graphs. Each algorithm in
    ``GRAPH_GEN_ALGOS`` is assigned a contiguous, near-equal chunk of graph
    indices. On each graph it runs ``SIMS_PER_GRAPH`` simulations from random
    initial conditions. It writes one JSON file per graph to
    ``OUTPUT_DIR/agents_<N>/graph_<idx>.json``.

    Randomness is seeded from the wall clock; the seed is printed so it can be
    recorded alongside the dataset.
    """
    # --- Configuration ---
    GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
    AGENT_COUNTS = range(5, 51, 5)  # 5, 10, ..., 50 agents
    GRAPHS_PER_AGENT_COUNT = 100
    SIMS_PER_GRAPH = 100
    # Full trajectories dominate file size; set to False to store only the
    # initial conditions and summary results.
    SAVE_TRAJECTORIES = True
    OUTPUT_DIR = "datasets/kuramoto_dataset"

    num_algos = len(GRAPH_GEN_ALGOS)
    # Zero-pad graph indices to the width of the largest index (at least 2),
    # so file names sort in index order.
    index_width = max(2, len(str(GRAPHS_PER_AGENT_COUNT - 1)))

    # --- Setup ---
    seed = int(time.time())
    main_key = jax.random.PRNGKey(seed)
    numpy_rng = np.random.default_rng(seed=seed + 1)

    print("🚀 Starting Kuramoto data generation...")
    print(f"Random seed: {seed}")
    print(f"Saving dataset to: '{OUTPUT_DIR}/'")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Main Loop ---
    for n_agents in tqdm(AGENT_COUNTS, desc="Agent Counts"):
        agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
        os.makedirs(agent_folder, exist_ok=True)

        # The config depends only on n_agents, so build it once per agent count
        # instead of once per graph. It is a static argument of the jitted
        # run_batched_simulations. Reusing one object also avoids a retrace per
        # graph if KuramotoConfig hashes by identity (NOTE(review): confirm how
        # KuramotoConfig implements __hash__/__eq__).
        config = KuramotoConfig()
        config.num_agents = n_agents
        config.coupling = 1.5
        config.T = 15.0

        for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
            # Proportional bucketing: graph_idx < GRAPHS_PER_AGENT_COUNT
            # guarantees the result is < num_algos. This holds even when the
            # graph count is not a multiple of the number of algorithms.
            algo_name = GRAPH_GEN_ALGOS[graph_idx * num_algos // GRAPHS_PER_AGENT_COUNT]

            # 1. Fresh PRNG keys and a connected graph. The middle key is unused,
            # because graphs come from numpy_rng. The three-way split is kept so
            # the sequence of state keys is unchanged.
            main_key, _, state_key = jax.random.split(main_key, 3)

            G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, algo_name)
            adj_matrix_jax = jnp.array(adj_matrix_np)
            graph_metrics = calculate_graph_metrics(G)
            graph_metrics["graph_type"] = graph_type

            # 2. Generate a batch of initial conditions
            thetas0_batch, omegas_batch = generate_batched_initial_states(state_key, SIMS_PER_GRAPH, config)

            # 3. Run all simulations and analyses in a single jitted call
            trajectories, final_R, mean_freqs = run_batched_simulations(
                thetas0_batch, omegas_batch, adj_matrix_jax, config
            )
            trajectories.block_until_ready()  # Ensure computation is finished

            # 4. Package all data for logging
            log_data = {
                "config": {
                    "num_agents": config.num_agents,
                    "coupling": config.coupling,
                    "dt": config.dt,
                    "T": config.T,
                    "normalize_by_degree": config.normalize_by_degree,
                    "directed": config.directed,
                    "weighted": config.weighted,
                },
                "graph_metrics": graph_metrics,
                "adjacency_matrix": adj_matrix_np.tolist(),
                "initial_conditions": {
                    "thetas0": thetas0_batch.tolist(),
                    "omegas": omegas_batch.tolist(),
                },
                "results": {
                    "final_coherence": final_R.tolist(),
                    "mean_frequencies": mean_freqs.tolist(),
                },
            }
            if SAVE_TRAJECTORIES:
                # Warning: this can create very large files.
                log_data["trajectories"] = trajectories.tolist()

            # 5. Save the complete log to a JSON file
            file_path = os.path.join(agent_folder, f"graph_{graph_idx:0{index_width}d}.json")
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(log_data, f, indent=2)

    print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")
|
|
if __name__ == "__main__":
    # Script entry point: generate the dataset only when executed directly,
    # not when this module is imported.
    main()