import os
import json
import time

import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
import networkx.algorithms.community as nx_comm
from tqdm import tqdm
from dotenv import load_dotenv

import sims

load_dotenv()


# NetworkX generator for each supported graph family, called as
# builder(num_agents, rng). Parameters match the original hard-coded values.
_GRAPH_BUILDERS = {
    "erdos_renyi": lambda n, rng: nx.erdos_renyi_graph(n, 0.5, seed=rng),
    "watts_strogatz": lambda n, rng: nx.watts_strogatz_graph(n, 4, 0.5, seed=rng),
    "barabasi_albert": lambda n, rng: nx.barabasi_albert_graph(n, 2, seed=rng),
    "powerlaw_cluster": lambda n, rng: nx.powerlaw_cluster_graph(n, 2, 0.5, seed=rng),
}


def generate_connected_graph(
    rng: np.random.Generator,
    num_agents: int,
    graph_type: str,
    max_attempts: int = 1000,
) -> tuple[nx.Graph, str, np.ndarray]:
    """Generate a random, undirected, connected graph of the requested family.

    The base graph is drawn with the NetworkX generator for ``graph_type`` and
    re-drawn until it is connected (Erdos-Renyi and Watts-Strogatz draws can
    come out disconnected). A self-loop is then added to every node, so the
    returned adjacency matrix has ones on its diagonal.

    Args:
        rng: NumPy random generator, passed to NetworkX as ``seed``.
        num_agents: Number of nodes in the graph.
        graph_type: One of "erdos_renyi", "watts_strogatz", "barabasi_albert",
            "powerlaw_cluster".
        max_attempts: Maximum number of draws before giving up.

    Returns:
        A tuple ``(G, graph_type, adjacency_matrix)``. ``adjacency_matrix`` is
        a ``(num_agents, num_agents)`` integer array including the self-loops,
        and ``G`` is the graph built from that matrix.

    Raises:
        ValueError: If ``graph_type`` is not a supported family.
        RuntimeError: If no connected graph is drawn within ``max_attempts``.
    """
    try:
        build = _GRAPH_BUILDERS[graph_type]
    except KeyError:
        raise ValueError(
            f"Unknown graph_type {graph_type!r}; expected one of {sorted(_GRAPH_BUILDERS)}"
        ) from None

    for _ in range(max_attempts):
        base_graph = build(num_agents, rng)
        if nx.is_connected(base_graph):
            break
    else:
        raise RuntimeError(
            f"No connected {graph_type} graph with {num_agents} nodes "
            f"after {max_attempts} attempts"
        )

    # Same node order (G.nodes()) and integer entries as nx.adjacency_matrix.
    graph_matrix = nx.to_numpy_array(base_graph, dtype=int)
    # Self-loop on every agent: the diagonal of the returned matrix is 1.
    np.fill_diagonal(graph_matrix, 1)
    graph = nx.Graph(graph_matrix)
    return graph, graph_type, graph_matrix


def _spectral_and_community_metrics(G: nx.Graph, max_nodes: int) -> dict:
    """Return Laplacian spectrum, algebraic connectivity and Louvain modularity.

    Each value is None when the graph is outside ``1 < nodes < max_nodes`` or
    the corresponding computation fails.
    """
    result = {
        "laplacian_eigenvalues": None,
        "algebraic_connectivity": None,
        "modularity": None,
    }
    if not 1 < G.number_of_nodes() < max_nodes:
        return result

    try:
        eigenvalues = sorted(float(v) for v in nx.laplacian_spectrum(G))
    except (nx.NetworkXError, np.linalg.LinAlgError):
        pass
    else:
        result["laplacian_eigenvalues"] = eigenvalues
        # Second-smallest Laplacian eigenvalue.
        result["algebraic_connectivity"] = eigenvalues[1]

    if G.number_of_edges() > 0:
        try:
            communities = nx_comm.louvain_communities(G, seed=123)
            result["modularity"] = nx_comm.modularity(G, communities)
        except nx.NetworkXError:
            pass  # Community detection may fail on some graphs; keep None.
    return result


def calculate_graph_metrics(
    G: nx.Graph,
    *,
    include_spectral: bool = False,
    max_distance_nodes: int = 250,
    max_spectral_nodes: int = 500,
) -> dict:
    """Calculate and return a dictionary of key graph metrics.

    Self-loops are excluded from every metric because ``nx.density`` and the
    edge count would otherwise count them (density could exceed 1). Their
    count is reported separately as ``number_of_self_loops``.

    Args:
        G: Undirected graph, possibly containing self-loops.
        include_spectral: Also compute Laplacian eigenvalues, algebraic
            connectivity and Louvain modularity (off by default because it is slow).
        max_distance_nodes: Distance metrics (shortest paths, diameter,
            eccentricity) are only computed for connected graphs with fewer
            nodes than this; otherwise they are None.
        max_spectral_nodes: Node limit (exclusive) for the optional spectral
            and community metrics.

    Returns:
        Dictionary of metrics. Per-node metrics ("clustering_coefficient",
        "eccentricity") are dicts keyed by node.
    """
    num_self_loops = nx.number_of_selfloops(G)
    simple = G.copy()
    simple.remove_edges_from(list(nx.selfloop_edges(simple)))

    metrics = {}
    num_nodes = simple.number_of_nodes()
    num_edges = simple.number_of_edges()

    # --- Basic Properties ---
    metrics["number_of_nodes"] = num_nodes
    metrics["number_of_edges"] = num_edges
    metrics["number_of_self_loops"] = num_self_loops
    metrics["average_degree"] = (2 * num_edges / num_nodes) if num_nodes > 0 else 0
    metrics["edge_density"] = nx.density(simple)

    # --- Connectivity ---
    metrics["is_connected"] = nx.is_connected(simple)
    metrics["number_connected_components"] = nx.number_connected_components(simple)

    # --- Clustering ---
    metrics["average_clustering_coefficient"] = nx.average_clustering(simple)
    metrics["clustering_coefficient"] = nx.clustering(simple)  # Per-node clustering

    # --- Distance-Based Metrics (connected, smaller graphs only; expensive) ---
    if metrics["is_connected"] and num_nodes < max_distance_nodes:
        metrics["average_shortest_path_length"] = nx.average_shortest_path_length(simple)
        metrics["diameter"] = nx.diameter(simple)
        metrics["eccentricity"] = nx.eccentricity(simple)
    else:
        # Disconnected or too large.
        metrics["average_shortest_path_length"] = None
        metrics["diameter"] = None
        metrics["eccentricity"] = None

    # --- Spectral & Community Metrics (opt-in; potentially slow) ---
    if include_spectral:
        metrics.update(_spectral_and_community_metrics(simple, max_spectral_nodes))

    return metrics


def generate_multiple_initial_states(key: jax.Array, num_sims: int, num_agents: int, max_range: float) -> jax.Array:
    """Draw a ``(num_sims, num_agents)`` batch of initial states uniformly from [-max_range, max_range)."""
    return jax.random.uniform(
        key, shape=(num_sims, num_agents), minval=-max_range, maxval=max_range
    )


def _build_log_data(config, graph_metrics, adj_matrix_np, initial_states, trajectories) -> dict:
    """Package one graph's simulation config, metrics and results for JSON output."""
    return {
        "simulation_config": {
            "num_agents": config.num_agents,
            "num_sims": config.num_sims,
            "num_time_steps": config.num_time_steps,
            "step_size": config.step_size,
            "max_range": config.max_range,
            "directed": config.directed,
            "weighted": config.weighted,
        },
        "graph_metrics": graph_metrics,
        "adjacency_matrix": adj_matrix_np.tolist(),
        "initial_states": initial_states.tolist(),
        "trajectories": trajectories.tolist(),
    }


def main():
    """Generate the consensus simulation dataset and save one JSON file per graph.

    For each agent count, GRAPHS_PER_AGENT_COUNT graphs are generated. They
    are split into equal contiguous blocks, one per entry of GRAPH_GEN_ALGOS.
    SIMS_PER_GRAPH consensus simulations are run on each graph, and the
    results are written to OUTPUT_DIR/agents_<N>/graph_<idx>.json.
    """
    # --- Configuration ---
    GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
    AGENT_COUNTS = range(5, 51, 5)  # 5, 10, ..., 50 agents
    GRAPHS_PER_AGENT_COUNT = 80
    SIMS_PER_GRAPH = 400
    OUTPUT_DIR = "datasets/consensus_dataset"

    GRAPHS_PER_GRAPH_ALGO, remainder = divmod(GRAPHS_PER_AGENT_COUNT, len(GRAPH_GEN_ALGOS))
    if GRAPHS_PER_GRAPH_ALGO == 0 or remainder:
        # Otherwise graph_idx // GRAPHS_PER_GRAPH_ALGO divides by zero or indexes past the list.
        raise ValueError(
            f"GRAPHS_PER_AGENT_COUNT ({GRAPHS_PER_AGENT_COUNT}) must be a positive "
            f"multiple of the number of graph algorithms ({len(GRAPH_GEN_ALGOS)})"
        )
    # Zero-pad file indices so they sort lexically (2 digits for up to 100 graphs).
    index_width = max(2, len(str(GRAPHS_PER_AGENT_COUNT - 1)))

    # --- Setup ---
    seed = int(time.time())
    main_jax_key = jax.random.PRNGKey(seed)
    numpy_rng = np.random.default_rng(seed=seed + 1)

    print("🚀 Starting data generation...")
    print(f"Random seed: {seed}")
    print(f"Saving dataset to directory: '{OUTPUT_DIR}/'")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Main Generation Loop ---
    for n_agents in tqdm(AGENT_COUNTS, desc="Overall Progress"):
        agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
        os.makedirs(agent_folder, exist_ok=True)

        for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
            graph_algo = GRAPH_GEN_ALGOS[graph_idx // GRAPHS_PER_GRAPH_ALGO]

            # 1. Configure the simulation using the imported class
            config = sims.consensus.ConsensusConfig()
            config.num_agents = n_agents
            config.num_sims = SIMS_PER_GRAPH

            # 2. Generate graph and its metrics
            G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, graph_algo)
            adj_matrix_jax = jnp.array(adj_matrix_np)
            graph_metrics = calculate_graph_metrics(G)
            graph_metrics["graph_type"] = graph_type

            # 3. Generate initial states for this graph's SIMS_PER_GRAPH simulations
            main_jax_key, states_key = jax.random.split(main_jax_key)
            initial_states = generate_multiple_initial_states(
                states_key, config.num_sims, config.num_agents, config.max_range
            )

            # 4. Run all simulations using the imported JAX function
            trajectories = sims.consensus.run_consensus_sim(adj_matrix_jax, initial_states, config)

            # 5. Package all data into a dictionary for logging
            log_data = _build_log_data(config, graph_metrics, adj_matrix_np, initial_states, trajectories)

            # 6. Save the complete log to a JSON file
            file_path = os.path.join(agent_folder, f"graph_{graph_idx:0{index_width}d}.json")
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(log_data, f, indent=2)

    print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")


if __name__ == "__main__":
    main()