replicated mecc
This commit is contained in:
162
generate_data_kuramoto.py
Normal file
162
generate_data_kuramoto.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from dataclasses import asdict
|
||||
import sims
|
||||
# --- Import necessary components from your Kuramoto simulation file ---
|
||||
from sims.kuramoto import (
|
||||
KuramotoConfig,
|
||||
run_kuramoto_simulation,
|
||||
phase_coherence,
|
||||
mean_frequency,
|
||||
)
|
||||
|
||||
from generate_data_consensus import generate_connected_graph, calculate_graph_metrics
|
||||
|
||||
# ==============================================================================
|
||||
# SECTION 1: BATCHED OPERATIONS FOR EFFICIENT DATA GENERATION
|
||||
# ==============================================================================
|
||||
|
||||
@partial(jax.jit, static_argnames=("config",))
|
||||
def run_batched_simulations(
|
||||
thetas0_batch: jax.Array, # (num_sims, N)
|
||||
omegas_batch: jax.Array, # (num_sims, N)
|
||||
adj_mat: jax.Array, # (N, N)
|
||||
config: KuramotoConfig
|
||||
) -> tuple[jax.Array, jax.Array, jax.Array]:
|
||||
"""
|
||||
Runs many Kuramoto simulations in parallel for the same graph but different
|
||||
initial conditions, and performs analysis.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- All trajectories (num_sims, T, N)
|
||||
- Final coherence R(T) for each sim (num_sims,)
|
||||
- Mean frequencies for each sim (num_sims, N)
|
||||
"""
|
||||
# Create a batched version of the simulation runner using vmap.
|
||||
# This maps the function over the first axis of thetas0_batch and omegas_batch.
|
||||
vmapped_runner = jax.vmap(
|
||||
run_kuramoto_simulation,
|
||||
in_axes=(0, 0, None, None), # Map over thetas0, omegas; adj_mat and config are fixed
|
||||
out_axes=0
|
||||
)
|
||||
all_trajectories = vmapped_runner(thetas0_batch, omegas_batch, adj_mat, config)
|
||||
|
||||
# Analyze the results in a batched manner
|
||||
# phase_coherence naturally works on the time axis, so we vmap over the sim axis.
|
||||
vmapped_coherence = jax.vmap(phase_coherence)(all_trajectories) # -> (num_sims, T)
|
||||
final_coherence = vmapped_coherence[:, -1] # -> (num_sims,)
|
||||
|
||||
# vmap the mean_frequency calculation over trajectories and omegas
|
||||
vmapped_mean_freq = jax.vmap(mean_frequency, in_axes=(0, 0, None, None))
|
||||
all_mean_freqs = vmapped_mean_freq(all_trajectories, omegas_batch, adj_mat, config)
|
||||
|
||||
return all_trajectories, final_coherence, all_mean_freqs
|
||||
|
||||
def generate_batched_initial_states(key: jax.Array, num_sims: int, config: KuramotoConfig):
|
||||
"""Generates a batch of initial states."""
|
||||
keys = jax.random.split(key, num_sims)
|
||||
|
||||
# We will just write a new generator that is vmap-friendly
|
||||
@jax.vmap
|
||||
def generate_single_state(k):
|
||||
key_theta, key_omega = jax.random.split(k)
|
||||
thetas0 = jax.random.uniform(key_theta, (config.num_agents,), minval=0, maxval=2 * jnp.pi)
|
||||
omegas = jax.random.normal(key_omega, (config.num_agents,)) # Using std=1, mean=0
|
||||
return thetas0, omegas
|
||||
|
||||
thetas0_batch, omegas_batch = generate_single_state(keys)
|
||||
return thetas0_batch, omegas_batch
|
||||
|
||||
|
||||
def main():
|
||||
"""Main script to generate and save the Kuramoto simulation dataset."""
|
||||
# --- Configuration ---
|
||||
GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
|
||||
AGENT_COUNTS = range(5, 51, 5) # 5, 10, ..., 50 agents
|
||||
# AGENT_COUNTS = [50]
|
||||
GRAPHS_PER_AGENT_COUNT = 100
|
||||
GRAPHS_PER_GRAPH_ALGO = GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)
|
||||
SIMS_PER_GRAPH = 100
|
||||
OUTPUT_DIR = "datasets/kuramoto_dataset"
|
||||
|
||||
# --- Setup ---
|
||||
seed = int(time.time())
|
||||
main_key = jax.random.PRNGKey(seed)
|
||||
numpy_rng = np.random.default_rng(seed=seed + 1)
|
||||
|
||||
print(f"🚀 Starting Kuramoto data generation...")
|
||||
print(f"Saving dataset to: '{OUTPUT_DIR}/'")
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# --- Main Loop ---
|
||||
for n_agents in tqdm(AGENT_COUNTS, desc="Agent Counts"):
|
||||
agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
|
||||
os.makedirs(agent_folder, exist_ok=True)
|
||||
|
||||
for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
|
||||
|
||||
graph_algo_idx = graph_idx // GRAPHS_PER_GRAPH_ALGO
|
||||
# 1. Setup config, keys, and graph
|
||||
config = KuramotoConfig()
|
||||
config.num_agents=n_agents
|
||||
config.coupling=1.5
|
||||
config.T=15.0
|
||||
main_key, graph_key, state_key = jax.random.split(main_key, 3)
|
||||
|
||||
G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx])
|
||||
adj_matrix_jax = jnp.array(adj_matrix_np)
|
||||
graph_metrics = calculate_graph_metrics(G)
|
||||
graph_metrics["graph_type"] = graph_type
|
||||
|
||||
# 2. Generate a batch of initial conditions
|
||||
thetas0_batch, omegas_batch = generate_batched_initial_states(state_key, SIMS_PER_GRAPH, config)
|
||||
|
||||
# 3. Run all simulations and analyses in a single, efficient call
|
||||
trajectories, final_R, mean_freqs = run_batched_simulations(
|
||||
thetas0_batch, omegas_batch, adj_matrix_jax, config
|
||||
)
|
||||
trajectories.block_until_ready() # Ensure computation is finished
|
||||
|
||||
# 4. Package all data for logging
|
||||
log_data = {
|
||||
"config": {
|
||||
"num_agents" : config.num_agents,
|
||||
"coupling" : config.coupling,
|
||||
"dt" : config.dt,
|
||||
"T" : config.T,
|
||||
"normalize_by_degree" : config.normalize_by_degree,
|
||||
"directed" : config.directed,
|
||||
"weighted" : config.weighted,
|
||||
},
|
||||
|
||||
"graph_metrics": graph_metrics,
|
||||
"adjacency_matrix": adj_matrix_np.tolist(),
|
||||
"initial_conditions": {
|
||||
"thetas0": thetas0_batch.tolist(),
|
||||
"omegas": omegas_batch.tolist()
|
||||
},
|
||||
"results": {
|
||||
"final_coherence": final_R.tolist(),
|
||||
"mean_frequencies": mean_freqs.tolist()
|
||||
},
|
||||
# Optionally save full trajectories. Warning: can create large files.
|
||||
"trajectories": trajectories.tolist()
|
||||
}
|
||||
|
||||
# 6. Save the complete log to a JSON file
|
||||
file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(log_data, f, indent=2)
|
||||
|
||||
print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user