Files
graph_recognition_w_attn/generate_data_kuramoto.py

162 lines
6.5 KiB
Python
Raw Permalink Normal View History

2025-07-31 01:12:53 -04:00
import os
import json
import time
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
from tqdm import tqdm
from functools import partial
from dataclasses import asdict
import sims
# --- Import necessary components from your Kuramoto simulation file ---
from sims.kuramoto import (
KuramotoConfig,
run_kuramoto_simulation,
phase_coherence,
mean_frequency,
)
from generate_data_consensus import generate_connected_graph, calculate_graph_metrics
# ==============================================================================
# SECTION 1: BATCHED OPERATIONS FOR EFFICIENT DATA GENERATION
# ==============================================================================
@partial(jax.jit, static_argnames=("config",))
def run_batched_simulations(
    thetas0_batch: jax.Array,  # (num_sims, N)
    omegas_batch: jax.Array,  # (num_sims, N)
    adj_mat: jax.Array,  # (N, N)
    config: KuramotoConfig,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Simulate and analyse a batch of Kuramoto runs that share one graph.

    Each row of ``thetas0_batch`` / ``omegas_batch`` is an independent
    simulation; ``adj_mat`` and ``config`` are shared by every row. The
    whole function is jit-compiled with ``config`` as a static argument.

    Args:
        thetas0_batch: Initial phases, one row per simulation (num_sims, N).
        omegas_batch: Natural frequencies, one row per simulation (num_sims, N).
        adj_mat: Adjacency matrix used by every simulation (N, N).
        config: Simulation settings; static under ``jax.jit``.

    Returns:
        ``(trajectories, final_coherence, mean_freqs)``:
        - trajectories: every run's output stacked on a leading sim axis
          (num_sims, T, N)
        - final_coherence: last value of each run's ``phase_coherence``
          series (num_sims,)
        - mean_freqs: ``mean_frequency`` evaluated per run (num_sims, N)
    """
    # Per-simulation arrays are mapped along axis 0; the graph and the
    # config are broadcast unchanged to every simulation.
    batch_axes = (0, 0, None, None)

    simulate_all = jax.vmap(run_kuramoto_simulation, in_axes=batch_axes)
    trajectories = simulate_all(thetas0_batch, omegas_batch, adj_mat, config)

    # phase_coherence produces a time series per run; keep only its last entry.
    coherence_over_time = jax.vmap(phase_coherence)(trajectories)  # (num_sims, T)
    final_coherence = coherence_over_time[:, -1]

    frequencies_all = jax.vmap(mean_frequency, in_axes=batch_axes)
    mean_freqs = frequencies_all(trajectories, omegas_batch, adj_mat, config)

    return trajectories, final_coherence, mean_freqs
def generate_batched_initial_states(key: jax.Array, num_sims: int, config: KuramotoConfig):
    """Draw ``num_sims`` independent (initial phase, natural frequency) pairs.

    ``key`` is split into one sub-key per simulation. Each sub-key is split
    again into a phase key and a frequency key, and sampling is vmapped over
    the per-simulation sub-keys.

    Args:
        key: JAX PRNG key.
        num_sims: Number of initial states to draw.
        config: Only ``config.num_agents`` is read.

    Returns:
        ``(thetas0_batch, omegas_batch)``, each of shape
        (num_sims, config.num_agents). Phases are uniform on [0, 2*pi);
        frequencies are standard normal (mean 0, std 1).
    """
    n_agents = config.num_agents

    def _sample_one(sub_key):
        phase_key, freq_key = jax.random.split(sub_key)
        phases = jax.random.uniform(phase_key, (n_agents,), minval=0, maxval=2 * jnp.pi)
        freqs = jax.random.normal(freq_key, (n_agents,))  # mean 0, std 1
        return phases, freqs

    per_sim_keys = jax.random.split(key, num_sims)
    return jax.vmap(_sample_one)(per_sim_keys)
def main():
    """Generate the Kuramoto dataset and save one JSON file per graph.

    For every agent count in ``AGENT_COUNTS``:
    - build ``GRAPHS_PER_AGENT_COUNT`` graphs with ``generate_connected_graph``,
      giving each algorithm in ``GRAPH_GEN_ALGOS`` a contiguous, near-equal
      run of graph indices;
    - on each graph, run ``SIMS_PER_GRAPH`` simulations from random initial
      phases and frequencies;
    - write config, graph metrics, adjacency matrix, initial conditions,
      results and (optionally) full trajectories to
      ``OUTPUT_DIR/agents_<N>/graph_<idx>.json``.

    The run is seeded from the wall clock. The seed is printed and stored in
    every output file under ``"seed"`` so the dataset can be regenerated.
    """
    # --- Configuration ---
    GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
    AGENT_COUNTS = range(5, 51, 5)  # 5, 10, ..., 50 agents
    # AGENT_COUNTS = [50]
    GRAPHS_PER_AGENT_COUNT = 100
    SIMS_PER_GRAPH = 100
    OUTPUT_DIR = "datasets/kuramoto_dataset"
    # Full trajectories dominate file size; set to False to keep only the
    # initial conditions and summary results.
    SAVE_TRAJECTORIES = True

    # Zero-pad graph indices so files sort lexically in generation order.
    # The width is 2 for the default 100 graphs, matching the previous naming.
    index_width = max(2, len(str(GRAPHS_PER_AGENT_COUNT - 1)))

    # --- Setup ---
    seed = int(time.time())
    main_key = jax.random.PRNGKey(seed)
    numpy_rng = np.random.default_rng(seed=seed + 1)
    print("🚀 Starting Kuramoto data generation...")
    print(f"Random seed: {seed}")
    print(f"Saving dataset to: '{OUTPUT_DIR}/'")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Main Loop ---
    for n_agents in tqdm(AGENT_COUNTS, desc="Agent Counts"):
        agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
        os.makedirs(agent_folder, exist_ok=True)

        # Build the config once per agent count. It is a static argument of
        # the jitted simulator, so reusing one object gives jax.jit the same
        # static value for every graph of this size. Depending on how
        # KuramotoConfig hashes (not visible here), a fresh object per graph
        # could force a recompile each time.
        config = KuramotoConfig()
        config.num_agents = n_agents
        config.coupling = 1.5
        config.T = 15.0
        config_log = {
            "num_agents": config.num_agents,
            "coupling": config.coupling,
            "dt": config.dt,
            "T": config.T,
            "normalize_by_degree": config.normalize_by_degree,
            "directed": config.directed,
            "weighted": config.weighted,
        }

        for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
            # Map graph_idx proportionally onto the algorithm list. This equals
            # graph_idx // (GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)) when
            # the count divides evenly. Unlike that form, it never indexes past
            # the end or divides by zero when it does not.
            algo_idx = graph_idx * len(GRAPH_GEN_ALGOS) // GRAPHS_PER_AGENT_COUNT
            graph_algo = GRAPH_GEN_ALGOS[algo_idx]

            # 1. Keys and graph. The middle key is unused (graphs are drawn
            # from numpy_rng), but the three-way split keeps the JAX key
            # stream the same as earlier versions of this script.
            main_key, _graph_key, state_key = jax.random.split(main_key, 3)
            G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, graph_algo)
            adj_matrix_jax = jnp.array(adj_matrix_np)
            graph_metrics = calculate_graph_metrics(G)
            graph_metrics["graph_type"] = graph_type

            # 2. Generate a batch of initial conditions
            thetas0_batch, omegas_batch = generate_batched_initial_states(state_key, SIMS_PER_GRAPH, config)

            # 3. Run all simulations and analyses in a single jitted call
            trajectories, final_R, mean_freqs = run_batched_simulations(
                thetas0_batch, omegas_batch, adj_matrix_jax, config
            )
            trajectories.block_until_ready()  # Ensure computation is finished

            # 4. Package all data for logging
            log_data = {
                "seed": seed,
                "config": config_log,
                "graph_metrics": graph_metrics,
                "adjacency_matrix": adj_matrix_np.tolist(),
                "initial_conditions": {
                    "thetas0": thetas0_batch.tolist(),
                    "omegas": omegas_batch.tolist(),
                },
                "results": {
                    "final_coherence": final_R.tolist(),
                    "mean_frequencies": mean_freqs.tolist(),
                },
            }
            if SAVE_TRAJECTORIES:
                # Warning: full trajectories can create very large files.
                log_data["trajectories"] = trajectories.tolist()

            # 5. Save the complete log to a JSON file
            file_path = os.path.join(agent_folder, f"graph_{graph_idx:0{index_width}d}.json")
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(log_data, f, indent=2)

    print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")


if __name__ == "__main__":
    main()