import os
import json
import time

import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
import networkx.algorithms.community as nx_comm
from tqdm import tqdm

import sims


# Graph families understood by generate_connected_graph / _sample_graph.
_SUPPORTED_GRAPH_TYPES = ("erdos_renyi", "watts_strogatz", "barabasi_albert", "powerlaw_cluster")


def _sample_graph(rng: np.random.Generator, num_agents: int, graph_type: str) -> nx.Graph:
    """Draw a single graph of the requested family; it may be disconnected.

    Raises:
        ValueError: If ``graph_type`` is not one of ``_SUPPORTED_GRAPH_TYPES``.
    """
    if graph_type == "erdos_renyi":
        # p = 2.0 * np.log(num_agents) / num_agents
        p = 0.5
        return nx.erdos_renyi_graph(num_agents, p, seed=rng)
    if graph_type == "watts_strogatz":
        k = 2
        p = 0.1
        return nx.watts_strogatz_graph(num_agents, k, p, seed=rng)
    if graph_type == "barabasi_albert":
        m = 2
        if m >= num_agents:
            m = max(1, num_agents - 1)
        return nx.barabasi_albert_graph(num_agents, m, seed=rng)
    if graph_type == "powerlaw_cluster":
        m = 3  # Number of random edges to add for each new node
        p = 0.1  # Probability of adding a triangle after adding a random edge
        if m >= num_agents:
            m = max(1, num_agents - 1)
        return nx.powerlaw_cluster_graph(num_agents, m, p, seed=rng)
    raise ValueError(
        f"Unknown graph_type {graph_type!r}; expected one of {_SUPPORTED_GRAPH_TYPES}"
    )


def generate_connected_graph(
    rng: np.random.Generator,
    num_agents: int,
    graph_type: str,
    max_attempts: int = 1000,
) -> tuple[nx.Graph, str, np.ndarray]:
    """
    Generate a random, undirected, unweighted, connected graph.

    Graphs of the requested family are sampled repeatedly until a connected
    one is drawn (Erdos-Renyi and Watts-Strogatz samples can be disconnected).
    A self-loop is then added to every node.

    Args:
        rng: NumPy random generator, passed to NetworkX as the seed.
        num_agents: Number of nodes in the graph.
        graph_type: One of "erdos_renyi", "watts_strogatz",
            "barabasi_albert" or "powerlaw_cluster".
        max_attempts: Maximum number of graphs sampled before giving up.

    Returns:
        A tuple containing the networkx Graph object (including self-loops),
        the graph type name, and its float32 adjacency matrix.

    Raises:
        ValueError: If ``graph_type`` is unknown or ``num_agents`` < 1.
        RuntimeError: If no connected graph was drawn in ``max_attempts`` tries.
    """
    if num_agents < 1:
        # nx.is_connected is undefined for the null graph.
        raise ValueError("num_agents must be at least 1")

    for _ in range(max_attempts):
        G = _sample_graph(rng, num_agents, graph_type)
        if nx.is_connected(G):
            break
    else:
        raise RuntimeError(
            f"Could not sample a connected {graph_type} graph with "
            f"{num_agents} nodes in {max_attempts} attempts"
        )

    # Add self-loops, as they are often assumed in consensus algorithms
    G.add_edges_from((i, i) for i in range(num_agents))

    adj_matrix = nx.to_numpy_array(G, dtype=np.float32)
    return G, graph_type, adj_matrix


def calculate_graph_metrics(G: nx.Graph) -> dict:
    """
    Calculate and return a dictionary of key graph metrics.

    This function computes basic properties, connectivity, clustering,
    community structure, and spectral properties of the graph.
    Computationally expensive metrics are skipped (set to None) for larger
    graphs to keep generation fast.

    Topology metrics are computed on a copy of ``G`` with self-loops removed:
    the self-loops added for the consensus dynamics would otherwise inflate
    the edge count and average degree, and push the density above 1. The
    number of self-loops in ``G`` is reported separately as
    ``number_of_self_loops``.

    Args:
        G: Undirected networkx graph, possibly containing self-loops.

    Returns:
        Dict of metric name -> value. Per-node metrics
        (``clustering_coefficient``, ``eccentricity``) are dicts keyed by node.
    """
    H = G.copy()
    # Materialize the iterator before mutating the graph it walks over.
    H.remove_edges_from(list(nx.selfloop_edges(H)))

    metrics = {}
    num_nodes = H.number_of_nodes()
    num_edges = H.number_of_edges()

    # --- Basic Properties ---
    metrics["number_of_nodes"] = num_nodes
    metrics["number_of_edges"] = num_edges
    metrics["number_of_self_loops"] = nx.number_of_selfloops(G)
    metrics["average_degree"] = (2 * num_edges / num_nodes) if num_nodes > 0 else 0
    metrics["edge_density"] = nx.density(H)

    # --- Connectivity ---
    # nx.is_connected raises on the null graph, so treat it as not connected.
    metrics["is_connected"] = nx.is_connected(H) if num_nodes > 0 else False
    metrics["number_connected_components"] = nx.number_connected_components(H)

    # --- Clustering ---
    metrics["average_clustering_coefficient"] = nx.average_clustering(H) if num_nodes > 0 else 0.0
    metrics["clustering_coefficient"] = nx.clustering(H)  # Per-node clustering

    # --- Distance-Based Metrics (for connected graphs) ---
    # These are computationally intensive and only run on smaller, connected graphs.
    if metrics["is_connected"] and num_nodes < 250:
        metrics["average_shortest_path_length"] = nx.average_shortest_path_length(H)
        metrics["diameter"] = nx.diameter(H)
        metrics["eccentricity"] = nx.eccentricity(H)
    else:
        # Set to None if graph is disconnected or too large
        metrics["average_shortest_path_length"] = None
        metrics["diameter"] = None
        metrics["eccentricity"] = None

    # --- Spectral & Community Metrics (Potentially Slow) ---
    # These are also limited to smaller graphs.
    metrics["laplacian_eigenvalues"] = None
    metrics["algebraic_connectivity"] = None
    metrics["modularity"] = None
    if 1 < num_nodes < 500:
        # Eigenvalues of the Laplacian matrix (best-effort: failures leave None).
        try:
            laplacian_eigenvalues = sorted(float(v) for v in nx.laplacian_spectrum(H))
        except Exception:
            pass
        else:
            metrics["laplacian_eigenvalues"] = laplacian_eigenvalues
            # The second-smallest eigenvalue of the Laplacian matrix
            metrics["algebraic_connectivity"] = laplacian_eigenvalues[1]

        # Modularity using the Louvain community detection algorithm
        if num_edges > 0:
            try:
                communities = nx_comm.louvain_communities(H, seed=123)
                metrics["modularity"] = nx_comm.modularity(H, communities)
            except Exception:
                metrics["modularity"] = None  # Algorithm may fail on some graphs

    return metrics


def generate_multiple_initial_states(key: jax.Array, num_sims: int, num_agents: int, max_range: float) -> jax.Array:
    """Generate a (num_sims, num_agents) batch of uniform initial states in [-max_range, max_range)."""
    return jax.random.uniform(
        key, shape=(num_sims, num_agents), minval=-max_range, maxval=max_range
    )


def main():
    """Generate the consensus simulation dataset and save it as JSON files.

    For every agent count, ``GRAPHS_PER_AGENT_COUNT`` connected graphs are
    sampled. Consecutive runs of graph indices are assigned to each entry of
    ``GRAPH_GEN_ALGOS`` in order. ``SIMS_PER_GRAPH`` consensus simulations
    are run on every graph. Each graph's data is written to
    ``OUTPUT_DIR/agents_<N>/graph_<idx>.json``.
    """
    # --- Configuration ---
    GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
    AGENT_COUNTS = range(5, 51, 5)  # 5, 10, ..., 50 agents
    GRAPHS_PER_AGENT_COUNT = 100
    SIMS_PER_GRAPH = 100
    OUTPUT_DIR = "datasets/consensus_dataset"

    # Zero-pad graph indices so file names sort correctly (at least 2 digits).
    index_width = max(2, len(str(GRAPHS_PER_AGENT_COUNT - 1)))

    # --- Setup ---
    seed = int(time.time())
    main_jax_key = jax.random.PRNGKey(seed)
    numpy_rng = np.random.default_rng(seed=seed + 1)

    print("🚀 Starting data generation...")
    print(f"Random seed: {seed}")
    print(f"Saving dataset to directory: '{OUTPUT_DIR}/'")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Main Generation Loop ---
    for n_agents in tqdm(AGENT_COUNTS, desc="Overall Progress"):
        agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
        os.makedirs(agent_folder, exist_ok=True)

        for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
            # Proportional mapping: always a valid index, and identical to
            # graph_idx // (GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS))
            # whenever the count divides evenly.
            graph_algo_idx = graph_idx * len(GRAPH_GEN_ALGOS) // GRAPHS_PER_AGENT_COUNT

            # 1. Configure the simulation using the imported class
            config = sims.consensus.ConsensusConfig()
            config.num_agents = n_agents
            config.num_sims = SIMS_PER_GRAPH

            # 2. Generate graph and its metrics
            G, graph_type, adj_matrix_np = generate_connected_graph(
                numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx]
            )
            adj_matrix_jax = jnp.array(adj_matrix_np)
            graph_metrics = calculate_graph_metrics(G)
            graph_metrics["graph_type"] = graph_type

            # 3. Generate initial states for every simulation on this graph
            main_jax_key, states_key = jax.random.split(main_jax_key)
            initial_states = generate_multiple_initial_states(
                states_key, config.num_sims, config.num_agents, config.max_range
            )

            # 4. Run all simulations using the imported JAX function
            trajectories = sims.consensus.run_consensus_sim(adj_matrix_jax, initial_states, config)

            # 5. Package all data into a dictionary for logging
            log_data = {
                "generation_seed": seed,
                "simulation_config": {
                    "num_agents": config.num_agents,
                    "num_sims": config.num_sims,
                    "num_time_steps": config.num_time_steps,
                    "step_size": config.step_size,
                    "max_range": config.max_range,
                    "directed": config.directed,
                    "weighted": config.weighted,
                },
                "graph_metrics": graph_metrics,
                "adjacency_matrix": adj_matrix_np.tolist(),
                "initial_states": initial_states.tolist(),
                "trajectories": trajectories.tolist(),
            }

            # 6. Save the complete log to a JSON file
            file_path = os.path.join(agent_folder, f"graph_{graph_idx:0{index_width}d}.json")
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(log_data, f, indent=2)

    print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")


if __name__ == "__main__":
    main()