Files
graph_recognition_w_attn/generate_data_consensus.py

206 lines
7.9 KiB
Python
Raw Permalink Normal View History

2025-07-31 01:12:53 -04:00
import os
import json
import time
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
import networkx.algorithms.community as nx_comm
from tqdm import tqdm
import sims
def _sample_graph(rng: np.random.Generator, num_agents: int, graph_type: str) -> nx.Graph:
    """Draw one graph from the NetworkX family named by ``graph_type``.

    The result is not guaranteed to be connected. Raises ValueError for an
    unknown ``graph_type``.
    """
    if graph_type == "erdos_renyi":
        # p = 2.0 * np.log(num_agents) / num_agents
        p = 0.5
        return nx.erdos_renyi_graph(num_agents, p, seed=rng)
    if graph_type == "watts_strogatz":
        k = 2  # each node starts joined to its k nearest ring neighbours
        p = 0.1  # per-edge rewiring probability
        return nx.watts_strogatz_graph(num_agents, k, p, seed=rng)
    if graph_type == "barabasi_albert":
        m = 2  # edges attached from each new node
        # NetworkX requires m < num_agents; clamp for very small graphs.
        if m >= num_agents:
            m = max(1, num_agents - 1)
        return nx.barabasi_albert_graph(num_agents, m, seed=rng)
    if graph_type == "powerlaw_cluster":
        m = 3  # Number of random edges to add for each new node
        p = 0.1  # Probability of adding a triangle after adding a random edge
        if m >= num_agents:
            m = max(1, num_agents - 1)
        return nx.powerlaw_cluster_graph(num_agents, m, p, seed=rng)
    raise ValueError(
        f"Unknown graph_type {graph_type!r}; expected one of "
        "'erdos_renyi', 'watts_strogatz', 'barabasi_albert', 'powerlaw_cluster'."
    )


def generate_connected_graph(
    rng: np.random.Generator,
    num_agents: int,
    graph_type: str,
    max_attempts: int = 1000,
) -> tuple[nx.Graph, str, np.ndarray]:
    """Generate a random, undirected, unweighted, connected graph.

    The NetworkX generator selected by ``graph_type`` is sampled repeatedly
    (drawing randomness from ``rng``) until it yields a connected graph.
    A self-loop is then added to every node.

    Args:
        rng: NumPy generator passed as ``seed`` to the NetworkX generators.
        num_agents: Number of nodes in the graph.
        graph_type: One of "erdos_renyi", "watts_strogatz",
            "barabasi_albert", "powerlaw_cluster".
        max_attempts: Maximum number of samples drawn before giving up.

    Returns:
        A tuple ``(G, graph_type, adj_matrix)``: the networkx Graph (with
        self-loops), the graph type name that was passed in, and the dense
        float32 adjacency matrix of ``G``.

    Raises:
        ValueError: If ``graph_type`` is unknown or ``max_attempts < 1``.
        RuntimeError: If no connected graph is produced within
            ``max_attempts`` samples.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    # Random edge sampling (Erdos-Renyi) and edge rewiring (Watts-Strogatz)
    # can leave the graph disconnected, so resample until it is connected.
    for _ in range(max_attempts):
        G = _sample_graph(rng, num_agents, graph_type)
        if nx.is_connected(G):
            break
    else:
        raise RuntimeError(
            f"Could not generate a connected {graph_type!r} graph with "
            f"{num_agents} nodes in {max_attempts} attempts."
        )

    # Add self-loops, as they are often assumed in consensus algorithms
    G.add_edges_from((i, i) for i in range(num_agents))
    adj_matrix = nx.to_numpy_array(G, dtype=np.float32)
    return G, graph_type, adj_matrix
def calculate_graph_metrics(
    G: nx.Graph,
    max_nodes_for_distance: int = 250,
    max_nodes_for_spectral: int = 500,
) -> dict:
    """Calculate and return a dictionary of key graph metrics.

    Computes basic properties, connectivity and clustering for every graph.
    Distance-based, spectral (Laplacian) and community (Louvain modularity)
    metrics are computed only for small enough graphs. A metric that is
    skipped, or whose computation fails, is stored as ``None``.

    Note: NetworkX counts self-loops as edges. ``number_of_edges``,
    ``average_degree`` and ``edge_density`` therefore include any self-loops
    present in ``G``; ``generate_connected_graph`` adds one per node.

    Args:
        G: Undirected networkx graph.
        max_nodes_for_distance: Distance metrics (average shortest path,
            diameter, eccentricity) are computed only for connected graphs
            with fewer nodes than this.
        max_nodes_for_spectral: Laplacian spectrum, algebraic connectivity
            and modularity are computed only when
            ``1 < number_of_nodes < max_nodes_for_spectral``.

    Returns:
        A dict mapping metric names to values. Per-node metrics
        (``clustering_coefficient``, ``eccentricity``) are dicts keyed by node.
    """
    metrics = {}
    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()

    # --- Basic Properties ---
    metrics["number_of_nodes"] = num_nodes
    metrics["number_of_edges"] = num_edges
    metrics["average_degree"] = (2 * num_edges / num_nodes) if num_nodes > 0 else 0
    metrics["edge_density"] = nx.density(G)

    # --- Connectivity ---
    # nx.is_connected raises on the null graph; treat an empty graph as
    # not connected instead of crashing.
    metrics["is_connected"] = num_nodes > 0 and nx.is_connected(G)
    metrics["number_connected_components"] = nx.number_connected_components(G)

    # --- Clustering ---
    # nx.average_clustering divides by the node count, so guard the null graph.
    metrics["average_clustering_coefficient"] = nx.average_clustering(G) if num_nodes > 0 else 0.0
    metrics["clustering_coefficient"] = nx.clustering(G)  # Per-node clustering

    # --- Distance-Based Metrics (for connected graphs) ---
    # These are computationally intensive and only run on smaller, connected graphs.
    if metrics["is_connected"] and num_nodes < max_nodes_for_distance:
        metrics["average_shortest_path_length"] = nx.average_shortest_path_length(G)
        metrics["diameter"] = nx.diameter(G)
        metrics["eccentricity"] = nx.eccentricity(G)
    else:
        # Undefined for disconnected graphs; skipped for large ones.
        metrics["average_shortest_path_length"] = None
        metrics["diameter"] = None
        metrics["eccentricity"] = None

    # --- Spectral & Community Metrics (Potentially Slow) ---
    if 1 < num_nodes < max_nodes_for_spectral:
        # Eigenvalues of the Laplacian matrix, ascending, as plain floats.
        try:
            laplacian_eigenvalues = [float(ev) for ev in sorted(nx.laplacian_spectrum(G))]
            metrics["laplacian_eigenvalues"] = laplacian_eigenvalues
            # The second-smallest eigenvalue of the Laplacian matrix
            metrics["algebraic_connectivity"] = laplacian_eigenvalues[1]
        except Exception:
            # Best-effort: a failed eigen-decomposition is recorded as missing.
            metrics["laplacian_eigenvalues"] = None
            metrics["algebraic_connectivity"] = None

        # Modularity of the partition found by Louvain community detection
        # (fixed seed so the partition is reproducible).
        if num_edges > 0:
            try:
                communities = nx_comm.louvain_communities(G, seed=123)
                metrics["modularity"] = nx_comm.modularity(G, communities)
            except Exception:
                metrics["modularity"] = None  # Algorithm may fail on some graphs
        else:
            metrics["modularity"] = None
    else:
        metrics["laplacian_eigenvalues"] = None
        metrics["algebraic_connectivity"] = None
        metrics["modularity"] = None

    return metrics
def generate_multiple_initial_states(key: jax.Array, num_sims: int, num_agents: int, max_range: float) -> jax.Array:
    """Draw random initial agent states for a batch of simulations.

    Every entry is sampled independently and uniformly from
    ``[-max_range, max_range)`` using ``key``. Rows are independent draws,
    not guaranteed to be distinct.

    Args:
        key: JAX PRNG key consumed by the draw.
        num_sims: Number of simulations (rows).
        num_agents: Number of agents per simulation (columns).
        max_range: Half-width of the symmetric sampling interval.

    Returns:
        Array of shape ``(num_sims, num_agents)``.
    """
    batch_shape = (num_sims, num_agents)
    low, high = -max_range, max_range
    return jax.random.uniform(key, shape=batch_shape, minval=low, maxval=high)
def main():
    """Generate the consensus-simulation dataset and save it as JSON files.

    For each agent count in AGENT_COUNTS:
    - Generate GRAPHS_PER_AGENT_COUNT graphs. Graph indices are split into
      contiguous, near-equal ranges, one range per entry of GRAPH_GEN_ALGOS.
    - For each graph, draw SIMS_PER_GRAPH initial-state vectors and simulate
      them with ``sims.consensus.run_consensus_sim``.
    - Write the configuration, graph metrics, adjacency matrix, initial states,
      trajectories and random seed to
      ``OUTPUT_DIR/agents_<N>/graph_<idx>.json``.
    """
    # --- Configuration ---
    GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
    AGENT_COUNTS = range(5, 51, 5)  # 5, 10, ..., 50 agents
    GRAPHS_PER_AGENT_COUNT = 100
    SIMS_PER_GRAPH = 100
    OUTPUT_DIR = "datasets/consensus_dataset"

    # --- Setup ---
    # The seed comes from wall-clock time, so it is printed and stored in
    # every output file; without it a run cannot be reproduced.
    # JAX keys derive from `seed`, the NumPy graph RNG from `seed + 1`.
    seed = int(time.time())
    main_jax_key = jax.random.PRNGKey(seed)
    numpy_rng = np.random.default_rng(seed=seed + 1)
    num_algos = len(GRAPH_GEN_ALGOS)

    print("🚀 Starting data generation...")
    print(f"Random seed: {seed}")
    print(f"Saving dataset to directory: '{OUTPUT_DIR}/'")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Main Generation Loop ---
    for n_agents in tqdm(AGENT_COUNTS, desc="Overall Progress"):
        agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
        os.makedirs(agent_folder, exist_ok=True)
        for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
            # Proportional mapping: contiguous, near-equal index ranges per
            # algorithm. Always in range(num_algos), even when the graph count
            # is not divisible by (or is smaller than) the number of algorithms.
            graph_algo_idx = graph_idx * num_algos // GRAPHS_PER_AGENT_COUNT

            # 1. Configure the simulation using the imported class
            config = sims.consensus.ConsensusConfig()
            config.num_agents = n_agents
            config.num_sims = SIMS_PER_GRAPH

            # 2. Generate graph and its metrics
            G, graph_type, adj_matrix_np = generate_connected_graph(
                numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx]
            )
            adj_matrix_jax = jnp.array(adj_matrix_np)
            graph_metrics = calculate_graph_metrics(G)
            graph_metrics["graph_type"] = graph_type

            # 3. Generate initial states for all SIMS_PER_GRAPH simulations
            main_jax_key, states_key = jax.random.split(main_jax_key)
            initial_states = generate_multiple_initial_states(
                states_key, config.num_sims, config.num_agents, config.max_range
            )

            # 4. Run all simulations using the imported JAX function
            trajectories = sims.consensus.run_consensus_sim(adj_matrix_jax, initial_states, config)

            # 5. Package all data into a dictionary for logging
            log_data = {
                "simulation_config": {
                    "num_agents": config.num_agents,
                    "num_sims": config.num_sims,
                    "num_time_steps": config.num_time_steps,
                    "step_size": config.step_size,
                    "max_range": config.max_range,
                    "directed": config.directed,
                    "weighted": config.weighted,
                },
                "graph_metrics": graph_metrics,
                "adjacency_matrix": adj_matrix_np.tolist(),
                "initial_states": initial_states.tolist(),
                "trajectories": trajectories.tolist(),
                "seed": seed,
            }

            # 6. Save the complete log to a JSON file
            file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(log_data, f, indent=2)

    print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")


if __name__ == "__main__":
    main()