replicated mecc
This commit is contained in:
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
.venv
|
||||
.env
|
||||
__pycaches__
|
||||
datasets/
|
||||
temp.*
|
||||
test.*
|
||||
|
15
.vscode/launch.json
vendored
Normal file
15
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
BIN
__pycache__/generate_data_consensus.cpython-312.pyc
Normal file
BIN
__pycache__/generate_data_consensus.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
__pycache__/model.cpython-312.pyc
Normal file
BIN
__pycache__/model.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/train.cpython-312.pyc
Normal file
BIN
__pycache__/train.cpython-312.pyc
Normal file
Binary file not shown.
206
generate_data_consensus.py
Normal file
206
generate_data_consensus.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
import networkx.algorithms.community as nx_comm
|
||||
from tqdm import tqdm
|
||||
import sims
|
||||
|
||||
|
||||
|
||||
def generate_connected_graph(rng: np.random.Generator, num_agents: int, graph_type: str) -> tuple[nx.Graph, str, np.ndarray]:
|
||||
"""
|
||||
Generates a random, undirected, unweighted, connected graph.
|
||||
It randomly selects a NetworkX algorithm and retries until the graph is connected.
|
||||
|
||||
Returns:
|
||||
A tuple containing the networkx Graph object, the algorithm name, and the adjacency matrix.
|
||||
"""
|
||||
|
||||
G = None
|
||||
|
||||
|
||||
if graph_type == "erdos_renyi":
|
||||
|
||||
# p = 2.0 * np.log(num_agents) / num_agents
|
||||
p = 0.5
|
||||
G = nx.erdos_renyi_graph(num_agents, p, seed=rng)
|
||||
|
||||
elif graph_type == "watts_strogatz":
|
||||
k = 2
|
||||
p = 0.1
|
||||
G = nx.watts_strogatz_graph(num_agents, k, p, seed=rng)
|
||||
|
||||
elif graph_type == "barabasi_albert":
|
||||
m = 2
|
||||
if m >= num_agents: m = max(1, num_agents - 1)
|
||||
G = nx.barabasi_albert_graph(num_agents, m, seed=rng)
|
||||
|
||||
elif graph_type == "powerlaw_cluster":
|
||||
m = 3 # Number of random edges to add for each new node
|
||||
p = 0.1 # Probability of adding a triangle after adding a random edge
|
||||
if m >= num_agents: m = max(1, num_agents - 1)
|
||||
G = nx.powerlaw_cluster_graph(num_agents, m, p, seed=rng)
|
||||
|
||||
# Add self-loops, as they are often assumed in consensus algorithms
|
||||
G.add_edges_from([(i, i) for i in range(num_agents)])
|
||||
adj_matrix = nx.to_numpy_array(G, dtype=np.float32)
|
||||
|
||||
return G, graph_type, adj_matrix
|
||||
|
||||
def calculate_graph_metrics(G: nx.Graph) -> dict:
|
||||
"""
|
||||
Calculates and returns a dictionary of key graph metrics.
|
||||
|
||||
This function computes basic properties, connectivity, clustering,
|
||||
community structure, and spectral properties of the graph.
|
||||
Computationally expensive metrics are skipped for larger graphs to ensure performance.
|
||||
"""
|
||||
metrics = {}
|
||||
num_nodes = G.number_of_nodes()
|
||||
num_edges = G.number_of_edges()
|
||||
|
||||
# --- Basic Properties ---
|
||||
metrics["number_of_nodes"] = num_nodes
|
||||
metrics["number_of_edges"] = num_edges
|
||||
metrics["average_degree"] = (2 * num_edges / num_nodes) if num_nodes > 0 else 0
|
||||
metrics["edge_density"] = nx.density(G)
|
||||
|
||||
# --- Connectivity ---
|
||||
metrics["is_connected"] = nx.is_connected(G)
|
||||
metrics["number_connected_components"] = nx.number_connected_components(G)
|
||||
|
||||
# --- Clustering ---
|
||||
metrics["average_clustering_coefficient"] = nx.average_clustering(G)
|
||||
metrics["clustering_coefficient"] = nx.clustering(G) # Per-node clustering
|
||||
|
||||
# --- Distance-Based Metrics (for connected graphs) ---
|
||||
# These are computationally intensive and only run on smaller, connected graphs.
|
||||
if metrics["is_connected"] and num_nodes < 250:
|
||||
metrics["average_shortest_path_length"] = nx.average_shortest_path_length(G)
|
||||
metrics["diameter"] = nx.diameter(G)
|
||||
metrics["eccentricity"] = nx.eccentricity(G)
|
||||
else:
|
||||
# Set to None if graph is disconnected or too large
|
||||
metrics["average_shortest_path_length"] = None
|
||||
metrics["diameter"] = None
|
||||
metrics["eccentricity"] = None
|
||||
|
||||
# --- Spectral & Community Metrics (Potentially Slow) ---
|
||||
# These are also limited to smaller graphs.
|
||||
if 1 < num_nodes < 500:
|
||||
# Eigenvalues of the Laplacian matrix
|
||||
try:
|
||||
laplacian_eigenvalues = sorted(nx.laplacian_spectrum(G))
|
||||
metrics["laplacian_eigenvalues"] = laplacian_eigenvalues
|
||||
# The second-smallest eigenvalue of the Laplacian matrix
|
||||
metrics["algebraic_connectivity"] = laplacian_eigenvalues[1]
|
||||
except Exception:
|
||||
metrics["laplacian_eigenvalues"] = None
|
||||
metrics["algebraic_connectivity"] = None
|
||||
|
||||
# Modularity using the Louvain community detection algorithm
|
||||
if num_edges > 0:
|
||||
try:
|
||||
communities = nx_comm.louvain_communities(G, seed=123)
|
||||
metrics["modularity"] = nx_comm.modularity(G, communities)
|
||||
except Exception:
|
||||
metrics["modularity"] = None # Algorithm may fail on some graphs
|
||||
else:
|
||||
metrics["modularity"] = None
|
||||
else:
|
||||
metrics["laplacian_eigenvalues"] = None
|
||||
metrics["algebraic_connectivity"] = None
|
||||
metrics["modularity"] = None
|
||||
|
||||
return metrics
|
||||
|
||||
def generate_multiple_initial_states(key: jax.Array, num_sims: int, num_agents: int, max_range: float) -> jax.Array:
|
||||
"""Generate a batch of unique random initial states for the agents."""
|
||||
return jax.random.uniform(
|
||||
key,
|
||||
shape=(num_sims, num_agents),
|
||||
minval=-max_range,
|
||||
maxval=max_range
|
||||
)
|
||||
|
||||
def main():
|
||||
"""Main script to generate and save the consensus simulation dataset."""
|
||||
# --- Configuration ---
|
||||
GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
|
||||
AGENT_COUNTS = range(5, 51, 5)# 5, 10, ..., 50 agents
|
||||
GRAPHS_PER_AGENT_COUNT = 100
|
||||
GRAPHS_PER_GRAPH_ALGO = GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)
|
||||
SIMS_PER_GRAPH = 100
|
||||
OUTPUT_DIR = "datasets/consensus_dataset"
|
||||
|
||||
# --- Setup ---
|
||||
seed = int(time.time())
|
||||
main_jax_key = jax.random.PRNGKey(seed)
|
||||
numpy_rng = np.random.default_rng(seed=seed + 1)
|
||||
|
||||
print(f"🚀 Starting data generation...")
|
||||
print(f"Saving dataset to directory: '{OUTPUT_DIR}/'")
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# --- Main Generation Loop ---
|
||||
for n_agents in tqdm(AGENT_COUNTS, desc="Overall Progress"):
|
||||
agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
|
||||
os.makedirs(agent_folder, exist_ok=True)
|
||||
|
||||
for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
|
||||
graph_algo_idx = graph_idx // GRAPHS_PER_GRAPH_ALGO
|
||||
|
||||
# 1. Configure the simulation using the imported class
|
||||
config = sims.consensus.ConsensusConfig()
|
||||
|
||||
config.num_agents=n_agents
|
||||
config.num_sims=SIMS_PER_GRAPH
|
||||
|
||||
# 2. Generate graph and its metrics
|
||||
G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx])
|
||||
adj_matrix_jax = jnp.array(adj_matrix_np)
|
||||
graph_metrics = calculate_graph_metrics(G)
|
||||
graph_metrics["graph_type"] = graph_type
|
||||
|
||||
# 3. Generate initial states for the 100 simulations
|
||||
main_jax_key, states_key = jax.random.split(main_jax_key)
|
||||
initial_states = generate_multiple_initial_states(
|
||||
states_key, config.num_sims, config.num_agents, config.max_range
|
||||
)
|
||||
|
||||
# 4. Run all simulations using the imported JAX function
|
||||
trajectories = sims.consensus.run_consensus_sim(adj_matrix_jax, initial_states, config)
|
||||
|
||||
|
||||
|
||||
# 5. Package all data into a dictionary for logging
|
||||
log_data = {
|
||||
"simulation_config": {
|
||||
"num_agents": config.num_agents,
|
||||
"num_sims": config.num_sims,
|
||||
"num_time_steps": config.num_time_steps,
|
||||
"step_size": config.step_size,
|
||||
"max_range": config.max_range,
|
||||
"directed": config.directed,
|
||||
"weighted": config.weighted,
|
||||
|
||||
},
|
||||
"graph_metrics": graph_metrics,
|
||||
"adjacency_matrix": adj_matrix_np.tolist(),
|
||||
"initial_states": initial_states.tolist(),
|
||||
"trajectories": trajectories.tolist()
|
||||
}
|
||||
|
||||
# 6. Save the complete log to a JSON file
|
||||
file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(log_data, f, indent=2)
|
||||
|
||||
print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
162
generate_data_kuramoto.py
Normal file
162
generate_data_kuramoto.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from dataclasses import asdict
|
||||
import sims
|
||||
# --- Import necessary components from your Kuramoto simulation file ---
|
||||
from sims.kuramoto import (
|
||||
KuramotoConfig,
|
||||
run_kuramoto_simulation,
|
||||
phase_coherence,
|
||||
mean_frequency,
|
||||
)
|
||||
|
||||
from generate_data_consensus import generate_connected_graph, calculate_graph_metrics
|
||||
|
||||
# ==============================================================================
|
||||
# SECTION 1: BATCHED OPERATIONS FOR EFFICIENT DATA GENERATION
|
||||
# ==============================================================================
|
||||
|
||||
@partial(jax.jit, static_argnames=("config",))
|
||||
def run_batched_simulations(
|
||||
thetas0_batch: jax.Array, # (num_sims, N)
|
||||
omegas_batch: jax.Array, # (num_sims, N)
|
||||
adj_mat: jax.Array, # (N, N)
|
||||
config: KuramotoConfig
|
||||
) -> tuple[jax.Array, jax.Array, jax.Array]:
|
||||
"""
|
||||
Runs many Kuramoto simulations in parallel for the same graph but different
|
||||
initial conditions, and performs analysis.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- All trajectories (num_sims, T, N)
|
||||
- Final coherence R(T) for each sim (num_sims,)
|
||||
- Mean frequencies for each sim (num_sims, N)
|
||||
"""
|
||||
# Create a batched version of the simulation runner using vmap.
|
||||
# This maps the function over the first axis of thetas0_batch and omegas_batch.
|
||||
vmapped_runner = jax.vmap(
|
||||
run_kuramoto_simulation,
|
||||
in_axes=(0, 0, None, None), # Map over thetas0, omegas; adj_mat and config are fixed
|
||||
out_axes=0
|
||||
)
|
||||
all_trajectories = vmapped_runner(thetas0_batch, omegas_batch, adj_mat, config)
|
||||
|
||||
# Analyze the results in a batched manner
|
||||
# phase_coherence naturally works on the time axis, so we vmap over the sim axis.
|
||||
vmapped_coherence = jax.vmap(phase_coherence)(all_trajectories) # -> (num_sims, T)
|
||||
final_coherence = vmapped_coherence[:, -1] # -> (num_sims,)
|
||||
|
||||
# vmap the mean_frequency calculation over trajectories and omegas
|
||||
vmapped_mean_freq = jax.vmap(mean_frequency, in_axes=(0, 0, None, None))
|
||||
all_mean_freqs = vmapped_mean_freq(all_trajectories, omegas_batch, adj_mat, config)
|
||||
|
||||
return all_trajectories, final_coherence, all_mean_freqs
|
||||
|
||||
def generate_batched_initial_states(key: jax.Array, num_sims: int, config: KuramotoConfig):
|
||||
"""Generates a batch of initial states."""
|
||||
keys = jax.random.split(key, num_sims)
|
||||
|
||||
# We will just write a new generator that is vmap-friendly
|
||||
@jax.vmap
|
||||
def generate_single_state(k):
|
||||
key_theta, key_omega = jax.random.split(k)
|
||||
thetas0 = jax.random.uniform(key_theta, (config.num_agents,), minval=0, maxval=2 * jnp.pi)
|
||||
omegas = jax.random.normal(key_omega, (config.num_agents,)) # Using std=1, mean=0
|
||||
return thetas0, omegas
|
||||
|
||||
thetas0_batch, omegas_batch = generate_single_state(keys)
|
||||
return thetas0_batch, omegas_batch
|
||||
|
||||
|
||||
def main():
|
||||
"""Main script to generate and save the Kuramoto simulation dataset."""
|
||||
# --- Configuration ---
|
||||
GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
|
||||
AGENT_COUNTS = range(5, 51, 5) # 5, 10, ..., 50 agents
|
||||
# AGENT_COUNTS = [50]
|
||||
GRAPHS_PER_AGENT_COUNT = 100
|
||||
GRAPHS_PER_GRAPH_ALGO = GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)
|
||||
SIMS_PER_GRAPH = 100
|
||||
OUTPUT_DIR = "datasets/kuramoto_dataset"
|
||||
|
||||
# --- Setup ---
|
||||
seed = int(time.time())
|
||||
main_key = jax.random.PRNGKey(seed)
|
||||
numpy_rng = np.random.default_rng(seed=seed + 1)
|
||||
|
||||
print(f"🚀 Starting Kuramoto data generation...")
|
||||
print(f"Saving dataset to: '{OUTPUT_DIR}/'")
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# --- Main Loop ---
|
||||
for n_agents in tqdm(AGENT_COUNTS, desc="Agent Counts"):
|
||||
agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
|
||||
os.makedirs(agent_folder, exist_ok=True)
|
||||
|
||||
for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
|
||||
|
||||
graph_algo_idx = graph_idx // GRAPHS_PER_GRAPH_ALGO
|
||||
# 1. Setup config, keys, and graph
|
||||
config = KuramotoConfig()
|
||||
config.num_agents=n_agents
|
||||
config.coupling=1.5
|
||||
config.T=15.0
|
||||
main_key, graph_key, state_key = jax.random.split(main_key, 3)
|
||||
|
||||
G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx])
|
||||
adj_matrix_jax = jnp.array(adj_matrix_np)
|
||||
graph_metrics = calculate_graph_metrics(G)
|
||||
graph_metrics["graph_type"] = graph_type
|
||||
|
||||
# 2. Generate a batch of initial conditions
|
||||
thetas0_batch, omegas_batch = generate_batched_initial_states(state_key, SIMS_PER_GRAPH, config)
|
||||
|
||||
# 3. Run all simulations and analyses in a single, efficient call
|
||||
trajectories, final_R, mean_freqs = run_batched_simulations(
|
||||
thetas0_batch, omegas_batch, adj_matrix_jax, config
|
||||
)
|
||||
trajectories.block_until_ready() # Ensure computation is finished
|
||||
|
||||
# 4. Package all data for logging
|
||||
log_data = {
|
||||
"config": {
|
||||
"num_agents" : config.num_agents,
|
||||
"coupling" : config.coupling,
|
||||
"dt" : config.dt,
|
||||
"T" : config.T,
|
||||
"normalize_by_degree" : config.normalize_by_degree,
|
||||
"directed" : config.directed,
|
||||
"weighted" : config.weighted,
|
||||
},
|
||||
|
||||
"graph_metrics": graph_metrics,
|
||||
"adjacency_matrix": adj_matrix_np.tolist(),
|
||||
"initial_conditions": {
|
||||
"thetas0": thetas0_batch.tolist(),
|
||||
"omegas": omegas_batch.tolist()
|
||||
},
|
||||
"results": {
|
||||
"final_coherence": final_R.tolist(),
|
||||
"mean_frequencies": mean_freqs.tolist()
|
||||
},
|
||||
# Optionally save full trajectories. Warning: can create large files.
|
||||
"trajectories": trajectories.tolist()
|
||||
}
|
||||
|
||||
# 6. Save the complete log to a JSON file
|
||||
file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(log_data, f, indent=2)
|
||||
|
||||
print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
136
model.py
Normal file
136
model.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import jax
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
from train import ModelConfig, TrainConfig
|
||||
import optax
|
||||
from functools import partial
|
||||
|
||||
def init_linear_layer(
|
||||
key: jax.Array,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
use_bias: bool = True
|
||||
):
|
||||
|
||||
"""
|
||||
Initializes the weights and biases in a linear layer.
|
||||
"""
|
||||
key_w, key_b = random.split(key)
|
||||
limit = jnp.sqrt(6/in_features)
|
||||
W = random.uniform(key_w, (in_features, out_features), minval=-limit, maxval=limit)
|
||||
params = {'W': W}
|
||||
if use_bias:
|
||||
b = random.uniform(key_b, (out_features,), minval=-limit, maxval=limit)
|
||||
params['b'] = b
|
||||
return params
|
||||
|
||||
def init_fn(key: jax.Array, config: ModelConfig):
|
||||
"""
|
||||
Initializes all model parameters. Returns a pytree
|
||||
"""
|
||||
|
||||
key_embed, key_translate, key_attn_proj, key_head = random.split(key, 4)
|
||||
|
||||
params = {
|
||||
"agent_embeddings" : {
|
||||
"weight" : random.normal(key_embed, shape=(config.num_agents, config.embedding_dim))
|
||||
},
|
||||
"translate": init_linear_layer(key_translate, config.input_dim, config.embedding_dim),
|
||||
"attn_proj": init_linear_layer(key_attn_proj, config.embedding_dim, 2 * config.embedding_dim, use_bias=False),
|
||||
"head": init_linear_layer(key_head, config.embedding_dim, config.output_dim)
|
||||
}
|
||||
|
||||
return params
|
||||
|
||||
def forward(params: dict, input_timesteps: jax.Array, config: ModelConfig):
|
||||
"""
|
||||
Model's forward function. Takes in the parameters and inptu timesteps, returns predictions
|
||||
"""
|
||||
batch_size, num_agents, _ = input_timesteps.shape
|
||||
|
||||
agent_embed = params["agent_embeddings"]["weight"]
|
||||
agent_embed = jnp.broadcast_to(agent_embed, (batch_size, num_agents, config.embedding_dim))
|
||||
|
||||
attn_proj_out = agent_embed @ params["attn_proj"]['W']
|
||||
k, q = jnp.split(attn_proj_out, 2, axis=-1)
|
||||
v = input_timesteps @ params["translate"]['W'] + params["translate"]['b']
|
||||
att_scores = (q @ k.transpose(0, 2, 1) )/ jnp.sqrt(num_agents)
|
||||
att_weights = jax.nn.softmax(att_scores, axis=-1)
|
||||
weighted_average = att_weights @ v
|
||||
prediction = weighted_average @ params["head"]['W'] + params["head"]['b']
|
||||
|
||||
return prediction
|
||||
|
||||
def get_attention_fn(params: dict, config: ModelConfig):
|
||||
"""
|
||||
Calculates and returns the learned attention matrix between agents.
|
||||
This is a pure function for analysis.
|
||||
"""
|
||||
embeddings = params['agent_embeddings']['weight']
|
||||
|
||||
# Project embeddings to get keys (k) and queries (q) for the global graph
|
||||
attn_proj_out = embeddings @ params['attn_proj']['W']
|
||||
k, q = jnp.split(attn_proj_out, 2, axis=-1)
|
||||
|
||||
# Note: Using sqrt(embedding_dim) as in the original get_attention method
|
||||
attn_scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
|
||||
|
||||
return jnp.asarray(attn_scores) # Return as NumPy array for logging
|
||||
|
||||
def train_model(config: ModelConfig, inputs: jax.Array, targets: jax.Array,
|
||||
true_graph: jax.Array,
|
||||
train_config: TrainConfig,
|
||||
):
|
||||
|
||||
key = random.PRNGKey(0)
|
||||
key, init_key = random.split(key)
|
||||
params = init_fn(init_key, config)
|
||||
|
||||
optimizer = optax.adamw(train_config.learning_rate)
|
||||
opt_state = optimizer.init(params)
|
||||
|
||||
def loss_fn(p, x_batch, y_batch, config):
|
||||
predictions = forward(p, x_batch, config)
|
||||
loss = jnp.mean(jnp.abs(predictions - y_batch))
|
||||
return loss
|
||||
|
||||
@partial(jax.jit, static_argnames=['config'])
|
||||
def update_step(params, opt_state, x_batch, y_batch, config):
|
||||
# FIX 1: Pass all necessary arguments to the loss function here.
|
||||
loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, config)
|
||||
|
||||
updates, new_opt_state = optimizer.update(grads, opt_state, params)
|
||||
new_params = optax.apply_updates(params, updates)
|
||||
|
||||
return new_params, new_opt_state, loss_val
|
||||
|
||||
loss_history = {f"epoch_{i}": [] for i in range(train_config.epochs)}
|
||||
# FIX 2: Initialize with empty lists `[]` instead of `None`.
|
||||
graphs = {f"epoch_{i}": [] for i in range(train_config.epochs)}
|
||||
num_batches = len(inputs)
|
||||
|
||||
for epoch in range(train_config.epochs):
|
||||
running_loss = 0.0
|
||||
for batch_num in range(num_batches):
|
||||
x, y = inputs[batch_num], targets[batch_num]
|
||||
# FIX 3: Pass the `config` object, as it's a static argument for JIT.
|
||||
params, opt_state, loss_val = update_step(params, opt_state, x, y, config)
|
||||
running_loss += loss_val
|
||||
|
||||
epoch_loss = running_loss / num_batches
|
||||
loss_history[f"epoch_{epoch}"].append(epoch_loss)
|
||||
|
||||
if train_config.verbose and (epoch + 1) % 10 == 0:
|
||||
print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss:.6f}")
|
||||
|
||||
if train_config.log and epoch % train_config.log_epoch_interval == 0:
|
||||
attn = get_attention_fn(params, config)
|
||||
graphs[f"epoch_{epoch}"].append(attn)
|
||||
|
||||
all_logs = {
|
||||
"loss_history": loss_history,
|
||||
"graphs": graphs,
|
||||
"true_graph": true_graph,
|
||||
}
|
||||
|
||||
return params, all_logs
|
@@ -10,6 +10,9 @@ dependencies = [
|
||||
"jax[cuda12]>=0.7.0",
|
||||
"jupyter>=1.1.1",
|
||||
"matplotlib>=3.10.3",
|
||||
"networkx>=3.5",
|
||||
"optax>=0.2.5",
|
||||
"scikit-learn>=1.7.1",
|
||||
"seaborn>=0.13.2",
|
||||
"tqdm>=4.67.1",
|
||||
]
|
||||
]
|
||||
|
2
sims/__init__.py
Normal file
2
sims/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .consensus import *
|
||||
from .kuramoto import *
|
BIN
sims/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
sims/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
sims/__pycache__/consensus.cpython-312.pyc
Normal file
BIN
sims/__pycache__/consensus.cpython-312.pyc
Normal file
Binary file not shown.
BIN
sims/__pycache__/kuramoto.cpython-312.pyc
Normal file
BIN
sims/__pycache__/kuramoto.cpython-312.pyc
Normal file
Binary file not shown.
@@ -10,7 +10,7 @@ class ConsensusConfig:
|
||||
Config class for Consensus dynamics sims
|
||||
"""
|
||||
num_sims: int = 500 # Number of consensus sims
|
||||
num_agents: int = 5# Number of agents in the consensus simulation
|
||||
num_agents: int = 5 # Number of agents in the consensus simulation
|
||||
max_range: float = 1 # Max range of values each agent can take
|
||||
step_size: float = 0.1 # Target range for length of simulation
|
||||
directed: bool = False # Consensus graph directed?
|
||||
@@ -61,6 +61,10 @@ def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
|
||||
config: ConsensusConfig
|
||||
Config for Consensus dyanmics
|
||||
|
||||
Returns
|
||||
---------------------
|
||||
adj_matrices: jax.Array (num_agents, num_agents)
|
||||
Random matrix
|
||||
"""
|
||||
rand_matrix = jax.random.uniform(key, shape=(config.num_agents, config.num_agents))
|
||||
# idxs = jnp.arange(config.num_agents)
|
||||
@@ -90,7 +94,7 @@ def generate_random_agent_states(key: jax.Array, config: ConsensusConfig):
|
||||
|
||||
Returns
|
||||
---------------------
|
||||
rand_states: jax.Array (num_agents, 1)
|
||||
rand_states: jax.Array (num_sims, num_agents)
|
||||
|
||||
"""
|
||||
rand_states = jax.random.uniform(key, shape=(config.num_sims, config.num_agents), minval=-config.max_range, maxval=config.max_range)
|
||||
@@ -113,7 +117,7 @@ def run_consensus_sim(adj_mat: jax.Array, initial_agent_state: jax.Array, config
|
||||
config: ConsensusConfig
|
||||
Config for Consensus dynamics
|
||||
"""
|
||||
|
||||
# batched consensus step (meant for many initial states)
|
||||
batched_consensus_step = jax.vmap(consensus_step, in_axes=(None, 0, None), out_axes=0)
|
||||
|
||||
def step(x_prev, _):
|
@@ -5,26 +5,18 @@ from functools import partial
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# -------------------- Configuration --------------------
|
||||
@dataclass(frozen=True)
|
||||
class KuramotoConfig:
|
||||
"""Configuration for the Kuramoto model simulation."""
|
||||
num_agents: int = 10 # N: Number of oscillators
|
||||
coupling: float = 1.0 # K: Coupling strength
|
||||
dt: float = 0.01 # Δt: Integration time step
|
||||
T: float = 10.0 # Total simulation time
|
||||
|
||||
# Adjacency matrix properties
|
||||
normalize_by_degree: bool = True
|
||||
time_steps: int = int(T/dt)
|
||||
normalize_by_degree: bool = False
|
||||
directed: bool = False
|
||||
weighted: bool = False
|
||||
|
||||
@property
|
||||
def num_time_steps(self) -> int:
|
||||
"""Total number of simulation steps."""
|
||||
return int(self.T / self.dt)
|
||||
|
||||
# -------------------- Core Dynamics --------------------
|
||||
@partial(jax.jit, static_argnames=("config",))
|
||||
def kuramoto_derivative(theta: jax.Array, # (N,) phase angles
|
||||
omega: jax.Array, # (N,) natural frequencies
|
||||
@@ -88,7 +80,7 @@ def run_kuramoto_simulation(
|
||||
scan_fn,
|
||||
thetas0,
|
||||
None,
|
||||
length=config.num_time_steps
|
||||
length=config.time_steps
|
||||
)
|
||||
return trajectory
|
||||
|
192
test.ipynb
192
test.ipynb
File diff suppressed because one or more lines are too long
19
train.py
Normal file
19
train.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
class ModelConfig:
|
||||
num_agents: int = 10
|
||||
embedding_dim: int = 64
|
||||
input_dim: int = 1
|
||||
output_dim: int = 1
|
||||
simulation_type:str = "consensus"
|
||||
|
||||
class TrainConfig:
|
||||
epochs: float = 100
|
||||
learning_rate: float = 1e-3
|
||||
verbose: bool = True
|
||||
log: bool = True
|
||||
log_epoch_interval: int = 10
|
||||
|
||||
|
||||
|
||||
|
259
train_and_eval.py
Normal file
259
train_and_eval.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import os
|
||||
import json
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from enum import Enum
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import f1_score
|
||||
import uuid
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
|
||||
# Import from your existing model file
|
||||
from model import ModelConfig, TrainConfig, train_model, get_attention_fn
|
||||
|
||||
# Define an Enum for the data source for clarity and type safety
|
||||
|
||||
# Overwrite the original TrainConfig to include our new parameters
|
||||
|
||||
class TrainConfig:
|
||||
"""Configuration for the training process."""
|
||||
learning_rate: float = 1e-3
|
||||
epochs: int = 100
|
||||
batch_size: int = 4096
|
||||
verbose: bool = False # Set to True to see epoch loss during training
|
||||
log: bool = True
|
||||
log_epoch_interval: int = 10
|
||||
|
||||
# --- New parameters for this script ---
|
||||
data_directory:str = "datasets/" + sys.argv[1] + "_dataset"
|
||||
# Threshold for converting attention scores to a binary graph
|
||||
f1_threshold: float = -0.4
|
||||
|
||||
|
||||
def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Converts simulation trajectories into input-output pairs for the model.
|
||||
Input: state at time t. Target: state at time t+1.
|
||||
|
||||
Args:
|
||||
trajectories: A numpy array of shape (num_sims, num_timesteps, num_agents).
|
||||
batch_size: The desired batch size for training.
|
||||
|
||||
Returns:
|
||||
A tuple of (batched_inputs, batched_targets).
|
||||
"""
|
||||
# For each simulation, create (input, target) pairs
|
||||
all_inputs = []
|
||||
all_targets = []
|
||||
|
||||
num_sims, num_timesteps, num_agents = trajectories.shape
|
||||
|
||||
# trajectories = np.reshape(trajectories, shape=(num_sims * num_timesteps, num_agents))
|
||||
|
||||
for i_sim in range(num_sims):
|
||||
for j_tstep in range(num_timesteps-1):
|
||||
input = trajectories[i_sim, j_tstep, :]
|
||||
target = trajectories[i_sim, j_tstep + 1, :]
|
||||
all_inputs.append(input)
|
||||
all_targets.append(target)
|
||||
|
||||
all_indices = np.arange(len(all_inputs))
|
||||
np.random.shuffle(all_indices)
|
||||
all_inputs = np.array(all_inputs)
|
||||
all_targets = np.array(all_targets)
|
||||
|
||||
all_inputs = all_inputs[all_indices]
|
||||
all_targets = all_targets[all_indices]
|
||||
|
||||
# for sim_idx in range(num_sims):
|
||||
# # Input is state from t=0 to t=T-2
|
||||
# inputs = trajectories[sim_idx, :-1, :]
|
||||
# # Target is state from t=1 to t=T-1
|
||||
# targets = trajectories[sim_idx, 1:, :]
|
||||
# all_inputs.append(inputs)
|
||||
# all_targets.append(targets)
|
||||
|
||||
# Concatenate all pairs from all simulations
|
||||
# Shape -> (num_sims * (num_timesteps - 1), num_agents)
|
||||
# full_dataset_inputs = np.concatenate(all_inputs, axis=0)
|
||||
# full_dataset_targets = np.concatenate(all_targets, axis=0)
|
||||
|
||||
# Reshape to have a feature dimension
|
||||
# Shape -> (total_samples, num_agents, 1)
|
||||
full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
|
||||
full_dataset_targets = np.expand_dims(all_targets, axis=-1)
|
||||
|
||||
# Create batches
|
||||
num_samples = full_dataset_inputs.shape[0]
|
||||
num_batches = num_samples // batch_size
|
||||
|
||||
# Truncate to full batches
|
||||
truncated_inputs = full_dataset_inputs[:num_batches * batch_size]
|
||||
truncated_targets = full_dataset_targets[:num_batches * batch_size]
|
||||
|
||||
# Reshape into batches
|
||||
# Shape -> (num_batches, batch_size, num_agents, 1)
|
||||
batched_inputs = truncated_inputs.reshape(num_batches, batch_size, num_agents, 1)
|
||||
batched_targets = truncated_targets.reshape(num_batches, batch_size, num_agents, 1)
|
||||
|
||||
return batched_inputs, batched_targets
|
||||
|
||||
def calculate_f1_score(
|
||||
params: dict,
|
||||
model_config: ModelConfig,
|
||||
true_graph: np.ndarray,
|
||||
threshold: float
|
||||
) -> float:
|
||||
"""
|
||||
Extracts the learned attention graph, thresholds it, and computes the F1 score.
|
||||
"""
|
||||
# Get the learned attention matrix (N, N)
|
||||
learned_graph_scores = np.array(get_attention_fn(params, model_config))
|
||||
|
||||
# Normalize scores to [0, 1] for consistent thresholding (optional but good practice)
|
||||
# This uses min-max scaling on the flattened array
|
||||
# flat_scores = learned_graph_scores.flatten()
|
||||
# min_s, max_s = flat_scores.min(), flat_scores.max()
|
||||
# if max_s > min_s:
|
||||
# learned_graph_scores = (learned_graph_scores - min_s) / (max_s - min_s)
|
||||
|
||||
# Threshold to get a binary predicted graph
|
||||
predicted_graph = (learned_graph_scores > threshold).astype(int)
|
||||
|
||||
# The diagonal is not part of the prediction task
|
||||
# np.fill_diagonal(predicted_graph, 0)
|
||||
# np.fill_diagonal(true_graph, 0)
|
||||
|
||||
# Flatten both graphs to treat this as a binary classification problem
|
||||
true_flat = true_graph.flatten()
|
||||
pred_flat = predicted_graph.flatten()
|
||||
|
||||
return f1_score(true_flat, pred_flat)
|
||||
|
||||
def main():
|
||||
"""Main script to run the training and evaluation pipeline."""
|
||||
|
||||
train_config = TrainConfig()
|
||||
|
||||
# Check if the data directory exists
|
||||
if not os.path.isdir(train_config.data_directory):
|
||||
print(f"Error: Data directory '{train_config.data_directory}' not found.")
|
||||
print(f"Please run the data generation script for '{train_config.data_directory}' first.")
|
||||
return
|
||||
|
||||
print(f"🚀 Starting training pipeline for '{train_config.data_directory}' data.")
|
||||
|
||||
# Get sorted list of agent directories
|
||||
agent_dirs = sorted(
|
||||
[d for d in os.listdir(train_config.data_directory) if d.startswith("agents_")],
|
||||
key=lambda x: int(x.split('_')[1])
|
||||
)
|
||||
|
||||
for agent_dir_name in agent_dirs:
|
||||
agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
|
||||
|
||||
all_results_for_agent = []
|
||||
|
||||
graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
|
||||
|
||||
print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
|
||||
|
||||
for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
|
||||
file_path = os.path.join(agent_dir_path, graph_file_name)
|
||||
|
||||
with open(file_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 1. Load and Prepare Data
|
||||
trajectories = np.array(data['trajectories'])
|
||||
s, l, n = trajectories.shape
|
||||
# trajectories = trajectories.T
|
||||
# np.random.shuffle(trajectories)
|
||||
# trajectories = np.random.shuffle(trajectories)
|
||||
true_graph = np.array(data['adjacency_matrix'])
|
||||
inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)
|
||||
|
||||
# 2. Configure Model
|
||||
num_agents = trajectories.shape[-1]
|
||||
model_config = ModelConfig(
|
||||
)
|
||||
|
||||
model_config.num_agents=num_agents
|
||||
model_config.input_dim=1 # Each agent has a single state value at time t
|
||||
model_config.output_dim=1
|
||||
model_config.embedding_dim=32
|
||||
|
||||
# 3. Train the Model
|
||||
# This relies on the modified train_model that returns final params
|
||||
final_params, train_logs = train_model(
|
||||
config=model_config,
|
||||
inputs=inputs,
|
||||
targets=targets,
|
||||
true_graph=true_graph,
|
||||
train_config=train_config
|
||||
)
|
||||
|
||||
# 4. Evaluate
|
||||
f1 = calculate_f1_score(
|
||||
final_params,
|
||||
model_config,
|
||||
true_graph,
|
||||
train_config.f1_threshold
|
||||
)
|
||||
|
||||
loss_history_serializable = {
|
||||
epoch: [loss.item() for loss in losses]
|
||||
for epoch, losses in train_logs['loss_history'].items()
|
||||
}
|
||||
|
||||
random_id = str(uuid.uuid4())
|
||||
|
||||
# 5. Log Results
|
||||
result_log = {
|
||||
# "model_name": random_id,
|
||||
"source_file": graph_file_name,
|
||||
"graph_metrics": data['graph_metrics'],
|
||||
"f1_score": f1,
|
||||
"training_loss_history": loss_history_serializable,
|
||||
"config": {
|
||||
# Manually create the dictionary for the model config
|
||||
"model": {
|
||||
"num_agents": model_config.num_agents,
|
||||
"input_dim": model_config.input_dim,
|
||||
"output_dim": model_config.output_dim,
|
||||
"embedding_dim": model_config.embedding_dim
|
||||
},
|
||||
# This is correct because TrainConfig is a dataclass
|
||||
"training": vars(train_config)
|
||||
}
|
||||
}
|
||||
|
||||
result_final_params = final_params
|
||||
|
||||
|
||||
all_results_for_agent.append(result_log)
|
||||
|
||||
# 6. Save aggregated results for this agent count
|
||||
results_dir = os.path.join(agent_dir_path, "results")
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
|
||||
output_file = os.path.join(results_dir, "summary_results.json")
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(all_results_for_agent, f, indent=2)
|
||||
|
||||
|
||||
model_path = os.path.join(results_dir, "model_params")
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
with open(os.path.join(model_path,"model_params" + ".pkl"), "wb") as f:
|
||||
pickle.dump(final_params, f)
|
||||
|
||||
print(f"✅ Results for {agent_dir_name} saved to {output_file}")
|
||||
|
||||
print("\n🎉 Pipeline finished successfully!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
217
train_and_eval_w_noise.py
Normal file
217
train_and_eval_w_noise.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import os
|
||||
import json
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from enum import Enum
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import f1_score
|
||||
|
||||
# Import from your existing model file
|
||||
from model import ModelConfig, train_model, get_attention_fn
|
||||
|
||||
# --- MODIFICATION: Enums for clarity and safety ---
|
||||
class DataSource(Enum):
|
||||
CONSENSUS = "consensus"
|
||||
KURAMOTO = "kuramoto"
|
||||
|
||||
class NoiseType(Enum):
|
||||
NONE = "none"
|
||||
NORMAL = "normal"
|
||||
UNIFORM = "uniform"
|
||||
|
||||
# --- MODIFICATION: Updated TrainConfig ---
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
"""Configuration for the training process."""
|
||||
learning_rate: float = 1e-3
|
||||
epochs: int = 50
|
||||
batch_size: int = 64
|
||||
verbose: bool = False
|
||||
log: bool = True
|
||||
log_epoch_interval: int = 10
|
||||
|
||||
# Data and Noise parameters
|
||||
data_source: DataSource = DataSource.CONSENSUS
|
||||
noise_type: NoiseType = NoiseType.NORMAL
|
||||
noise_level: float = 0.1 # Stddev for Normal, half-width for Uniform
|
||||
|
||||
# Evaluation parameter
|
||||
f1_threshold: float = 0.5
|
||||
|
||||
@property
|
||||
def source_data_directory(self) -> str:
|
||||
"""The directory where the source dataset is located."""
|
||||
return f"{self.data_source.value}_dataset"
|
||||
|
||||
@property
|
||||
def results_directory_name(self) -> str:
|
||||
"""Generates a unique output directory name for separate logging."""
|
||||
if self.noise_type == NoiseType.NONE or self.noise_level == 0:
|
||||
return f"results_noiseless"
|
||||
return f"results_noise_{self.noise_type.value}_{self.noise_level}"
|
||||
|
||||
# --- MODIFICATION: Updated data prep function to add noise ---
|
||||
def prepare_data_for_model(
|
||||
trajectories: np.ndarray,
|
||||
key: jax.Array,
|
||||
train_config: TrainConfig
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Converts trajectories to input-output pairs and adds noise to the inputs.
|
||||
"""
|
||||
all_inputs, all_targets = [], []
|
||||
num_sims, num_timesteps, num_agents = trajectories.shape
|
||||
|
||||
for sim_idx in range(num_sims):
|
||||
all_inputs.append(trajectories[sim_idx, :-1, :])
|
||||
all_targets.append(trajectories[sim_idx, 1:, :])
|
||||
|
||||
full_dataset_inputs = np.concatenate(all_inputs, axis=0)
|
||||
full_dataset_targets = np.concatenate(all_targets, axis=0)
|
||||
|
||||
# --- NOISE INJECTION BLOCK ---
|
||||
if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
|
||||
noise_shape = full_dataset_inputs.shape
|
||||
if train_config.noise_type == NoiseType.NORMAL:
|
||||
noise = jax.random.normal(key, noise_shape) * train_config.noise_level
|
||||
elif train_config.noise_type == NoiseType.UNIFORM:
|
||||
noise = jax.random.uniform(
|
||||
key, noise_shape,
|
||||
minval=-train_config.noise_level,
|
||||
maxval=train_config.noise_level
|
||||
)
|
||||
full_dataset_inputs += np.array(noise) # Add noise to inputs
|
||||
# --- END NOISE BLOCK ---
|
||||
|
||||
full_dataset_inputs = np.expand_dims(full_dataset_inputs, axis=-1)
|
||||
full_dataset_targets = np.expand_dims(full_dataset_targets, axis=-1)
|
||||
|
||||
num_samples = full_dataset_inputs.shape[0]
|
||||
num_batches = num_samples // train_config.batch_size
|
||||
|
||||
truncated_inputs = full_dataset_inputs[:num_batches * train_config.batch_size]
|
||||
truncated_targets = full_dataset_targets[:num_batches * train_config.batch_size]
|
||||
|
||||
batched_inputs = truncated_inputs.reshape(num_batches, train_config.batch_size, num_agents, 1)
|
||||
batched_targets = truncated_targets.reshape(num_batches, train_config.batch_size, num_agents, 1)
|
||||
|
||||
return batched_inputs, batched_targets
|
||||
|
||||
def calculate_f1_score(
|
||||
params: dict,
|
||||
model_config: ModelConfig,
|
||||
true_graph: np.ndarray,
|
||||
threshold: float
|
||||
) -> float:
|
||||
"""Extracts the learned graph and computes the F1 score."""
|
||||
learned_scores = np.array(get_attention_fn(params, model_config))
|
||||
flat_scores = learned_scores.flatten()
|
||||
min_s, max_s = flat_scores.min(), flat_scores.max()
|
||||
if max_s > min_s:
|
||||
learned_scores = (learned_scores - min_s) / (max_s - min_s)
|
||||
predicted_graph = (learned_scores > threshold).astype(int)
|
||||
|
||||
np.fill_diagonal(predicted_graph, 0)
|
||||
np.fill_diagonal(true_graph, 0)
|
||||
|
||||
return f1_score(true_graph.flatten(), predicted_graph.flatten())
|
||||
|
||||
def main():
|
||||
"""Main script to run the training and evaluation pipeline."""
|
||||
|
||||
# Configure your training run here
|
||||
train_config = TrainConfig(
|
||||
noise_type=NoiseType.NORMAL,
|
||||
noise_level=0.1
|
||||
)
|
||||
|
||||
if not os.path.isdir(train_config.source_data_directory):
|
||||
print(f"Error: Source data '{train_config.source_data_directory}' not found.")
|
||||
return
|
||||
|
||||
print(f"🚀 Starting training pipeline for '{train_config.data_source.value}' data.")
|
||||
print(f"Noise Configuration: type={train_config.noise_type.value}, level={train_config.noise_level}")
|
||||
|
||||
# --- MODIFICATION: Main JAX key for noise generation ---
|
||||
main_key = jax.random.PRNGKey(42)
|
||||
|
||||
agent_dirs = sorted(
|
||||
[d for d in os.listdir(train_config.source_data_directory) if d.startswith("agents_")],
|
||||
key=lambda x: int(x.split('_')[1])
|
||||
)
|
||||
|
||||
for agent_dir_name in agent_dirs:
|
||||
agent_dir_path = os.path.join(train_config.source_data_directory, agent_dir_name)
|
||||
all_results_for_agent = []
|
||||
graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
|
||||
|
||||
print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
|
||||
|
||||
for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
|
||||
file_path = os.path.join(agent_dir_path, graph_file_name)
|
||||
|
||||
with open(file_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
main_key, data_key = jax.random.split(main_key)
|
||||
trajectories = np.array(data['trajectories'])
|
||||
true_graph = np.array(data['adjacency_matrix'])
|
||||
|
||||
# --- MODIFICATION: Pass key and config to data prep ---
|
||||
inputs, targets = prepare_data_for_model(trajectories, data_key, train_config)
|
||||
|
||||
num_agents = int(trajectories.shape[-1])
|
||||
model_config = ModelConfig()
|
||||
model_config.num_agents = num_agents
|
||||
model_config.input_dim = 1
|
||||
model_config.output_dim = 1
|
||||
model_config.embedding_dim = 32
|
||||
|
||||
final_params, train_logs = train_model(
|
||||
config=model_config,
|
||||
inputs=inputs,
|
||||
targets=targets,
|
||||
true_graph=true_graph,
|
||||
train_config=train_config
|
||||
)
|
||||
|
||||
f1 = calculate_f1_score(final_params, model_config, true_graph, train_config.f1_threshold)
|
||||
|
||||
loss_history_serializable = {
|
||||
epoch: [loss.item() for loss in losses]
|
||||
for epoch, losses in train_logs['loss_history'].items()
|
||||
}
|
||||
|
||||
result_log = {
|
||||
"source_file": graph_file_name,
|
||||
"graph_metrics": data['graph_metrics'],
|
||||
"f1_score": f1,
|
||||
"training_loss_history": loss_history_serializable,
|
||||
"config": {
|
||||
"model": {
|
||||
"num_agents": model_config.num_agents,
|
||||
"input_dim": model_config.input_dim,
|
||||
"output_dim": model_config.output_dim,
|
||||
"embedding_dim": model_config.embedding_dim
|
||||
},
|
||||
"training": asdict(train_config)
|
||||
}
|
||||
}
|
||||
all_results_for_agent.append(result_log)
|
||||
|
||||
# --- MODIFICATION: Save to a separate results directory ---
|
||||
results_dir = os.path.join(agent_dir_path, train_config.results_directory_name)
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
|
||||
output_file = os.path.join(results_dir, "summary_results.json")
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(all_results_for_agent, f, indent=2)
|
||||
|
||||
print(f"✅ Results for {agent_dir_name} saved to {output_file}")
|
||||
|
||||
print("\n🎉 Pipeline finished successfully!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
114
uv.lock
generated
114
uv.lock
generated
@@ -6,6 +6,15 @@ resolution-markers = [
|
||||
"python_full_version < '3.13'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
version = "2.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.9.0"
|
||||
@@ -218,6 +227,24 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "chex"
|
||||
version = "0.1.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "absl-py" },
|
||||
{ name = "jax" },
|
||||
{ name = "jaxlib" },
|
||||
{ name = "numpy" },
|
||||
{ name = "setuptools" },
|
||||
{ name = "toolz" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/77/70/53c7d404ce9e2a94009aea7f77ef6e392f6740e071c62683a506647c520f/chex-0.1.90.tar.gz", hash = "sha256:d3c375aeb6154b08f1cccd2bee4ed83659ee2198a6acf1160d2fe2e4a6c87b5c", size = 92363, upload-time = "2025-07-23T19:50:47.945Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/3d/46bb04776c465cea2dd8aa2d4b61ab610b707f798f47838ef7e6105b025c/chex-0.1.90-py3-none-any.whl", hash = "sha256:fce3de82588f72d4796e545e574a433aa29229cbdcf792555e41bead24b704ae", size = 101047, upload-time = "2025-07-23T19:50:46.603Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
@@ -386,6 +413,9 @@ dependencies = [
|
||||
{ name = "jax", extra = ["cuda12"] },
|
||||
{ name = "jupyter" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "networkx" },
|
||||
{ name = "optax" },
|
||||
{ name = "scikit-learn" },
|
||||
{ name = "seaborn" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
@@ -397,6 +427,9 @@ requires-dist = [
|
||||
{ name = "jax", extras = ["cuda12"], specifier = ">=0.7.0" },
|
||||
{ name = "jupyter", specifier = ">=1.1.1" },
|
||||
{ name = "matplotlib", specifier = ">=3.10.3" },
|
||||
{ name = "networkx", specifier = ">=3.5" },
|
||||
{ name = "optax", specifier = ">=0.2.5" },
|
||||
{ name = "scikit-learn", specifier = ">=1.7.1" },
|
||||
{ name = "seaborn", specifier = ">=0.13.2" },
|
||||
{ name = "tqdm", specifier = ">=4.67.1" },
|
||||
]
|
||||
@@ -641,6 +674,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "joblib"
|
||||
version = "1.5.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/dc/fe/0f5a938c54105553436dbff7a61dc4fed4b1b2c98852f8833beaf4d5968f/joblib-1.5.1.tar.gz", hash = "sha256:f4f86e351f39fe3d0d32a9f2c3d8af1ee4cec285aafcb27003dda5205576b444", size = 330475, upload-time = "2025-05-23T12:04:37.097Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/4f/1195bbac8e0c2acc5f740661631d8d750dc38d4a32b23ee5df3cde6f4e0d/joblib-1.5.1-py3-none-any.whl", hash = "sha256:4719a31f054c7d766948dcd83e9613686b27114f190f717cec7eaa2084f8a74a", size = 307746, upload-time = "2025-05-23T12:04:35.124Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json5"
|
||||
version = "0.12.0"
|
||||
@@ -1141,6 +1183,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "networkx"
|
||||
version = "3.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "notebook"
|
||||
version = "7.4.4"
|
||||
@@ -1373,6 +1424,22 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "optax"
|
||||
version = "0.2.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "absl-py" },
|
||||
{ name = "chex" },
|
||||
{ name = "jax" },
|
||||
{ name = "jaxlib" },
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c0/75/1e011953c48be502d4d84fa8458e91be7c6f983002511669bddd7b1a065f/optax-0.2.5.tar.gz", hash = "sha256:b2e38c7aea376186deae758ba7a258e6ef760c6f6131e9e11bc561c65386d594", size = 258548, upload-time = "2025-06-10T17:00:47.544Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/33/f86091c706db1a5459f501830241afff2ecab3532725c188ea57be6e54de/optax-0.2.5-py3-none-any.whl", hash = "sha256:966deae936207f268ac8f564d8ed228d645ac1aaddefbbf194096d2299b24ba8", size = 354324, upload-time = "2025-06-10T17:00:46.062Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "overrides"
|
||||
version = "7.7.0"
|
||||
@@ -1862,6 +1929,35 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/75/04/5302cea1aa26d886d34cadbf2dc77d90d7737e576c0065f357b96dc7a1a6/rpds_py-0.26.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f14440b9573a6f76b4ee4770c13f0b5921f71dde3b6fcb8dabbefd13b7fe05d7", size = 232821, upload-time = "2025-07-01T15:55:55.167Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scikit-learn"
|
||||
version = "1.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "joblib" },
|
||||
{ name = "numpy" },
|
||||
{ name = "scipy" },
|
||||
{ name = "threadpoolctl" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/41/84/5f4af978fff619706b8961accac84780a6d298d82a8873446f72edb4ead0/scikit_learn-1.7.1.tar.gz", hash = "sha256:24b3f1e976a4665aa74ee0fcaac2b8fccc6ae77c8e07ab25da3ba6d3292b9802", size = 7190445, upload-time = "2025-07-18T08:01:54.5Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/16/57f176585b35ed865f51b04117947fe20f130f78940c6477b6d66279c9c2/scikit_learn-1.7.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3cee419b49b5bbae8796ecd690f97aa412ef1674410c23fc3257c6b8b85b8087", size = 9260431, upload-time = "2025-07-18T08:01:22.77Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/67/4e/899317092f5efcab0e9bc929e3391341cec8fb0e816c4789686770024580/scikit_learn-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2fd8b8d35817b0d9ebf0b576f7d5ffbbabdb55536b0655a8aaae629d7ffd2e1f", size = 8637191, upload-time = "2025-07-18T08:01:24.731Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/1b/998312db6d361ded1dd56b457ada371a8d8d77ca2195a7d18fd8a1736f21/scikit_learn-1.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:588410fa19a96a69763202f1d6b7b91d5d7a5d73be36e189bc6396bfb355bd87", size = 9486346, upload-time = "2025-07-18T08:01:26.713Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/09/a2aa0b4e644e5c4ede7006748f24e72863ba2ae71897fecfd832afea01b4/scikit_learn-1.7.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3142f0abe1ad1d1c31a2ae987621e41f6b578144a911ff4ac94781a583adad7", size = 9290988, upload-time = "2025-07-18T08:01:28.938Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/fa/c61a787e35f05f17fc10523f567677ec4eeee5f95aa4798dbbbcd9625617/scikit_learn-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:3ddd9092c1bd469acab337d87930067c87eac6bd544f8d5027430983f1e1ae88", size = 8735568, upload-time = "2025-07-18T08:01:30.936Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/f8/e0533303f318a0f37b88300d21f79b6ac067188d4824f1047a37214ab718/scikit_learn-1.7.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b7839687fa46d02e01035ad775982f2470be2668e13ddd151f0f55a5bf123bae", size = 9213143, upload-time = "2025-07-18T08:01:32.942Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/71/f3/f1df377d1bdfc3e3e2adc9c119c238b182293e6740df4cbeac6de2cc3e23/scikit_learn-1.7.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a10f276639195a96c86aa572ee0698ad64ee939a7b042060b98bd1930c261d10", size = 8591977, upload-time = "2025-07-18T08:01:34.967Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/72/c86a4cd867816350fe8dee13f30222340b9cd6b96173955819a5561810c5/scikit_learn-1.7.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:13679981fdaebc10cc4c13c43344416a86fcbc61449cb3e6517e1df9d12c8309", size = 9436142, upload-time = "2025-07-18T08:01:37.397Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/66/277967b29bd297538dc7a6ecfb1a7dce751beabd0d7f7a2233be7a4f7832/scikit_learn-1.7.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4f1262883c6a63f067a980a8cdd2d2e7f2513dddcef6a9eaada6416a7a7cbe43", size = 9282996, upload-time = "2025-07-18T08:01:39.721Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/47/9291cfa1db1dae9880420d1e07dbc7e8dd4a7cdbc42eaba22512e6bde958/scikit_learn-1.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:ca6d31fb10e04d50bfd2b50d66744729dbb512d4efd0223b864e2fdbfc4cee11", size = 8707418, upload-time = "2025-07-18T08:01:42.124Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/61/95/45726819beccdaa34d3362ea9b2ff9f2b5d3b8bf721bd632675870308ceb/scikit_learn-1.7.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:781674d096303cfe3d351ae6963ff7c958db61cde3421cd490e3a5a58f2a94ae", size = 9561466, upload-time = "2025-07-18T08:01:44.195Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/1c/6f4b3344805de783d20a51eb24d4c9ad4b11a7f75c1801e6ec6d777361fd/scikit_learn-1.7.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:10679f7f125fe7ecd5fad37dd1aa2daae7e3ad8df7f3eefa08901b8254b3e12c", size = 9040467, upload-time = "2025-07-18T08:01:46.671Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/80/abe18fe471af9f1d181904203d62697998b27d9b62124cd281d740ded2f9/scikit_learn-1.7.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1f812729e38c8cb37f760dce71a9b83ccfb04f59b3dca7c6079dcdc60544fa9e", size = 9532052, upload-time = "2025-07-18T08:01:48.676Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/82/b21aa1e0c4cee7e74864d3a5a721ab8fcae5ca55033cb6263dca297ed35b/scikit_learn-1.7.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:88e1a20131cf741b84b89567e1717f27a2ced228e0f29103426102bc2e3b8ef7", size = 9361575, upload-time = "2025-07-18T08:01:50.639Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/20/f4777fcd5627dc6695fa6b92179d0edb7a3ac1b91bcd9a1c7f64fa7ade23/scikit_learn-1.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b1bd1d919210b6a10b7554b717c9000b5485aa95a1d0f177ae0d7ee8ec750da5", size = 9277310, upload-time = "2025-07-18T08:01:52.547Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scipy"
|
||||
version = "1.16.0"
|
||||
@@ -1987,6 +2083,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154, upload-time = "2024-03-12T14:34:36.569Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "threadpoolctl"
|
||||
version = "3.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinycss2"
|
||||
version = "1.4.0"
|
||||
@@ -1999,6 +2104,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toolz"
|
||||
version = "1.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790, upload-time = "2024-10-04T16:17:04.001Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383, upload-time = "2024-10-04T16:17:01.533Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.5.1"
|
||||
|
Reference in New Issue
Block a user