replicated mecc
This commit is contained in:
206
generate_data_consensus.py
Normal file
206
generate_data_consensus.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
import networkx.algorithms.community as nx_comm
|
||||
from tqdm import tqdm
|
||||
import sims
|
||||
|
||||
|
||||
|
||||
def generate_connected_graph(rng: np.random.Generator, num_agents: int, graph_type: str) -> tuple[nx.Graph, str, np.ndarray]:
|
||||
"""
|
||||
Generates a random, undirected, unweighted, connected graph.
|
||||
It randomly selects a NetworkX algorithm and retries until the graph is connected.
|
||||
|
||||
Returns:
|
||||
A tuple containing the networkx Graph object, the algorithm name, and the adjacency matrix.
|
||||
"""
|
||||
|
||||
G = None
|
||||
|
||||
|
||||
if graph_type == "erdos_renyi":
|
||||
|
||||
# p = 2.0 * np.log(num_agents) / num_agents
|
||||
p = 0.5
|
||||
G = nx.erdos_renyi_graph(num_agents, p, seed=rng)
|
||||
|
||||
elif graph_type == "watts_strogatz":
|
||||
k = 2
|
||||
p = 0.1
|
||||
G = nx.watts_strogatz_graph(num_agents, k, p, seed=rng)
|
||||
|
||||
elif graph_type == "barabasi_albert":
|
||||
m = 2
|
||||
if m >= num_agents: m = max(1, num_agents - 1)
|
||||
G = nx.barabasi_albert_graph(num_agents, m, seed=rng)
|
||||
|
||||
elif graph_type == "powerlaw_cluster":
|
||||
m = 3 # Number of random edges to add for each new node
|
||||
p = 0.1 # Probability of adding a triangle after adding a random edge
|
||||
if m >= num_agents: m = max(1, num_agents - 1)
|
||||
G = nx.powerlaw_cluster_graph(num_agents, m, p, seed=rng)
|
||||
|
||||
# Add self-loops, as they are often assumed in consensus algorithms
|
||||
G.add_edges_from([(i, i) for i in range(num_agents)])
|
||||
adj_matrix = nx.to_numpy_array(G, dtype=np.float32)
|
||||
|
||||
return G, graph_type, adj_matrix
|
||||
|
||||
def calculate_graph_metrics(G: nx.Graph) -> dict:
|
||||
"""
|
||||
Calculates and returns a dictionary of key graph metrics.
|
||||
|
||||
This function computes basic properties, connectivity, clustering,
|
||||
community structure, and spectral properties of the graph.
|
||||
Computationally expensive metrics are skipped for larger graphs to ensure performance.
|
||||
"""
|
||||
metrics = {}
|
||||
num_nodes = G.number_of_nodes()
|
||||
num_edges = G.number_of_edges()
|
||||
|
||||
# --- Basic Properties ---
|
||||
metrics["number_of_nodes"] = num_nodes
|
||||
metrics["number_of_edges"] = num_edges
|
||||
metrics["average_degree"] = (2 * num_edges / num_nodes) if num_nodes > 0 else 0
|
||||
metrics["edge_density"] = nx.density(G)
|
||||
|
||||
# --- Connectivity ---
|
||||
metrics["is_connected"] = nx.is_connected(G)
|
||||
metrics["number_connected_components"] = nx.number_connected_components(G)
|
||||
|
||||
# --- Clustering ---
|
||||
metrics["average_clustering_coefficient"] = nx.average_clustering(G)
|
||||
metrics["clustering_coefficient"] = nx.clustering(G) # Per-node clustering
|
||||
|
||||
# --- Distance-Based Metrics (for connected graphs) ---
|
||||
# These are computationally intensive and only run on smaller, connected graphs.
|
||||
if metrics["is_connected"] and num_nodes < 250:
|
||||
metrics["average_shortest_path_length"] = nx.average_shortest_path_length(G)
|
||||
metrics["diameter"] = nx.diameter(G)
|
||||
metrics["eccentricity"] = nx.eccentricity(G)
|
||||
else:
|
||||
# Set to None if graph is disconnected or too large
|
||||
metrics["average_shortest_path_length"] = None
|
||||
metrics["diameter"] = None
|
||||
metrics["eccentricity"] = None
|
||||
|
||||
# --- Spectral & Community Metrics (Potentially Slow) ---
|
||||
# These are also limited to smaller graphs.
|
||||
if 1 < num_nodes < 500:
|
||||
# Eigenvalues of the Laplacian matrix
|
||||
try:
|
||||
laplacian_eigenvalues = sorted(nx.laplacian_spectrum(G))
|
||||
metrics["laplacian_eigenvalues"] = laplacian_eigenvalues
|
||||
# The second-smallest eigenvalue of the Laplacian matrix
|
||||
metrics["algebraic_connectivity"] = laplacian_eigenvalues[1]
|
||||
except Exception:
|
||||
metrics["laplacian_eigenvalues"] = None
|
||||
metrics["algebraic_connectivity"] = None
|
||||
|
||||
# Modularity using the Louvain community detection algorithm
|
||||
if num_edges > 0:
|
||||
try:
|
||||
communities = nx_comm.louvain_communities(G, seed=123)
|
||||
metrics["modularity"] = nx_comm.modularity(G, communities)
|
||||
except Exception:
|
||||
metrics["modularity"] = None # Algorithm may fail on some graphs
|
||||
else:
|
||||
metrics["modularity"] = None
|
||||
else:
|
||||
metrics["laplacian_eigenvalues"] = None
|
||||
metrics["algebraic_connectivity"] = None
|
||||
metrics["modularity"] = None
|
||||
|
||||
return metrics
|
||||
|
||||
def generate_multiple_initial_states(key: jax.Array, num_sims: int, num_agents: int, max_range: float) -> jax.Array:
|
||||
"""Generate a batch of unique random initial states for the agents."""
|
||||
return jax.random.uniform(
|
||||
key,
|
||||
shape=(num_sims, num_agents),
|
||||
minval=-max_range,
|
||||
maxval=max_range
|
||||
)
|
||||
|
||||
def main():
|
||||
"""Main script to generate and save the consensus simulation dataset."""
|
||||
# --- Configuration ---
|
||||
GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
|
||||
AGENT_COUNTS = range(5, 51, 5)# 5, 10, ..., 50 agents
|
||||
GRAPHS_PER_AGENT_COUNT = 100
|
||||
GRAPHS_PER_GRAPH_ALGO = GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)
|
||||
SIMS_PER_GRAPH = 100
|
||||
OUTPUT_DIR = "datasets/consensus_dataset"
|
||||
|
||||
# --- Setup ---
|
||||
seed = int(time.time())
|
||||
main_jax_key = jax.random.PRNGKey(seed)
|
||||
numpy_rng = np.random.default_rng(seed=seed + 1)
|
||||
|
||||
print(f"🚀 Starting data generation...")
|
||||
print(f"Saving dataset to directory: '{OUTPUT_DIR}/'")
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# --- Main Generation Loop ---
|
||||
for n_agents in tqdm(AGENT_COUNTS, desc="Overall Progress"):
|
||||
agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
|
||||
os.makedirs(agent_folder, exist_ok=True)
|
||||
|
||||
for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
|
||||
graph_algo_idx = graph_idx // GRAPHS_PER_GRAPH_ALGO
|
||||
|
||||
# 1. Configure the simulation using the imported class
|
||||
config = sims.consensus.ConsensusConfig()
|
||||
|
||||
config.num_agents=n_agents
|
||||
config.num_sims=SIMS_PER_GRAPH
|
||||
|
||||
# 2. Generate graph and its metrics
|
||||
G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx])
|
||||
adj_matrix_jax = jnp.array(adj_matrix_np)
|
||||
graph_metrics = calculate_graph_metrics(G)
|
||||
graph_metrics["graph_type"] = graph_type
|
||||
|
||||
# 3. Generate initial states for the 100 simulations
|
||||
main_jax_key, states_key = jax.random.split(main_jax_key)
|
||||
initial_states = generate_multiple_initial_states(
|
||||
states_key, config.num_sims, config.num_agents, config.max_range
|
||||
)
|
||||
|
||||
# 4. Run all simulations using the imported JAX function
|
||||
trajectories = sims.consensus.run_consensus_sim(adj_matrix_jax, initial_states, config)
|
||||
|
||||
|
||||
|
||||
# 5. Package all data into a dictionary for logging
|
||||
log_data = {
|
||||
"simulation_config": {
|
||||
"num_agents": config.num_agents,
|
||||
"num_sims": config.num_sims,
|
||||
"num_time_steps": config.num_time_steps,
|
||||
"step_size": config.step_size,
|
||||
"max_range": config.max_range,
|
||||
"directed": config.directed,
|
||||
"weighted": config.weighted,
|
||||
|
||||
},
|
||||
"graph_metrics": graph_metrics,
|
||||
"adjacency_matrix": adj_matrix_np.tolist(),
|
||||
"initial_states": initial_states.tolist(),
|
||||
"trajectories": trajectories.tolist()
|
||||
}
|
||||
|
||||
# 6. Save the complete log to a JSON file
|
||||
file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(log_data, f, indent=2)
|
||||
|
||||
print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user