replicated mecc

This commit is contained in:
2025-07-31 01:12:53 -04:00
parent 1a0425d549
commit 5a0c479c1e
21 changed files with 1304 additions and 54 deletions

7
.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
.venv
.env
__pycache__/
datasets/
temp.*
test.*

15
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

206
generate_data_consensus.py Normal file
View File

@@ -0,0 +1,206 @@
import os
import json
import time
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
import networkx.algorithms.community as nx_comm
from tqdm import tqdm
import sims
def generate_connected_graph(rng: np.random.Generator, num_agents: int, graph_type: str) -> tuple[nx.Graph, str, np.ndarray]:
"""
Generates a random, undirected, unweighted, connected graph.
It randomly selects a NetworkX algorithm and retries until the graph is connected.
Returns:
A tuple containing the networkx Graph object, the algorithm name, and the adjacency matrix.
"""
G = None
if graph_type == "erdos_renyi":
# p = 2.0 * np.log(num_agents) / num_agents
p = 0.5
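# With p = 0.5 the Erdos-Renyi graphs are dense, so for n >= 5 they are connected with
# very high probability; the commented-out expression above is the usual ~2*ln(n)/n
# connectivity threshold, kept for reference.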
G = nx.erdos_renyi_graph(num_agents, p, seed=rng)
elif graph_type == "watts_strogatz":
k = 2
p = 0.1
G = nx.watts_strogatz_graph(num_agents, k, p, seed=rng)
elif graph_type == "barabasi_albert":
m = 2
if m >= num_agents: m = max(1, num_agents - 1)
G = nx.barabasi_albert_graph(num_agents, m, seed=rng)
elif graph_type == "powerlaw_cluster":
m = 3 # Number of random edges to add for each new node
p = 0.1 # Probability of adding a triangle after adding a random edge
if m >= num_agents: m = max(1, num_agents - 1)
G = nx.powerlaw_cluster_graph(num_agents, m, p, seed=rng)
# Add self-loops, as they are often assumed in consensus algorithms
G.add_edges_from([(i, i) for i in range(num_agents)])
adj_matrix = nx.to_numpy_array(G, dtype=np.float32)
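# adj_matrix is a symmetric 0/1 matrix with ones on the diagonal (self-loops).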
return G, graph_type, adj_matrix
def calculate_graph_metrics(G: nx.Graph) -> dict:
"""
Calculates and returns a dictionary of key graph metrics.
This function computes basic properties, connectivity, clustering,
community structure, and spectral properties of the graph.
Computationally expensive metrics are skipped for larger graphs to ensure performance.
"""
metrics = {}
num_nodes = G.number_of_nodes()
num_edges = G.number_of_edges()
# --- Basic Properties ---
metrics["number_of_nodes"] = num_nodes
metrics["number_of_edges"] = num_edges
metrics["average_degree"] = (2 * num_edges / num_nodes) if num_nodes > 0 else 0
metrics["edge_density"] = nx.density(G)
# --- Connectivity ---
metrics["is_connected"] = nx.is_connected(G)
metrics["number_connected_components"] = nx.number_connected_components(G)
# --- Clustering ---
metrics["average_clustering_coefficient"] = nx.average_clustering(G)
metrics["clustering_coefficient"] = nx.clustering(G) # Per-node clustering
# --- Distance-Based Metrics (for connected graphs) ---
# These are computationally intensive and only run on smaller, connected graphs.
if metrics["is_connected"] and num_nodes < 250:
metrics["average_shortest_path_length"] = nx.average_shortest_path_length(G)
metrics["diameter"] = nx.diameter(G)
metrics["eccentricity"] = nx.eccentricity(G)
else:
# Set to None if graph is disconnected or too large
metrics["average_shortest_path_length"] = None
metrics["diameter"] = None
metrics["eccentricity"] = None
# --- Spectral & Community Metrics (Potentially Slow) ---
# These are also limited to smaller graphs.
if 1 < num_nodes < 500:
# Eigenvalues of the Laplacian matrix
try:
laplacian_eigenvalues = [float(v) for v in sorted(nx.laplacian_spectrum(G))]
metrics["laplacian_eigenvalues"] = laplacian_eigenvalues
# The second-smallest eigenvalue of the Laplacian (algebraic connectivity);
# values are cast to Python floats so the metrics stay JSON-serializable.
metrics["algebraic_connectivity"] = laplacian_eigenvalues[1]
except Exception:
metrics["laplacian_eigenvalues"] = None
metrics["algebraic_connectivity"] = None
# Modularity using the Louvain community detection algorithm
if num_edges > 0:
try:
communities = nx_comm.louvain_communities(G, seed=123)
metrics["modularity"] = nx_comm.modularity(G, communities)
except Exception:
metrics["modularity"] = None # Algorithm may fail on some graphs
else:
metrics["modularity"] = None
else:
metrics["laplacian_eigenvalues"] = None
metrics["algebraic_connectivity"] = None
metrics["modularity"] = None
return metrics
def generate_multiple_initial_states(key: jax.Array, num_sims: int, num_agents: int, max_range: float) -> jax.Array:
"""Generate a batch of unique random initial states for the agents."""
return jax.random.uniform(
key,
shape=(num_sims, num_agents),
minval=-max_range,
maxval=max_range
)
def main():
"""Main script to generate and save the consensus simulation dataset."""
# --- Configuration ---
GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
AGENT_COUNTS = range(5, 51, 5)# 5, 10, ..., 50 agents
GRAPHS_PER_AGENT_COUNT = 100
GRAPHS_PER_GRAPH_ALGO = GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)
SIMS_PER_GRAPH = 100
OUTPUT_DIR = "datasets/consensus_dataset"
# --- Setup ---
seed = int(time.time())
main_jax_key = jax.random.PRNGKey(seed)
numpy_rng = np.random.default_rng(seed=seed + 1)
print(f"🚀 Starting data generation...")
print(f"Saving dataset to directory: '{OUTPUT_DIR}/'")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# --- Main Generation Loop ---
for n_agents in tqdm(AGENT_COUNTS, desc="Overall Progress"):
agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
os.makedirs(agent_folder, exist_ok=True)
for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
graph_algo_idx = graph_idx // GRAPHS_PER_GRAPH_ALGO
# 1. Configure the simulation using the imported class
config = sims.consensus.ConsensusConfig()
config.num_agents=n_agents
config.num_sims=SIMS_PER_GRAPH
# 2. Generate graph and its metrics
G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx])
adj_matrix_jax = jnp.array(adj_matrix_np)
graph_metrics = calculate_graph_metrics(G)
graph_metrics["graph_type"] = graph_type
# 3. Generate initial states for the 100 simulations
main_jax_key, states_key = jax.random.split(main_jax_key)
initial_states = generate_multiple_initial_states(
states_key, config.num_sims, config.num_agents, config.max_range
)
# 4. Run all simulations using the imported JAX function
trajectories = sims.consensus.run_consensus_sim(adj_matrix_jax, initial_states, config)
# 5. Package all data into a dictionary for logging
log_data = {
"simulation_config": {
"num_agents": config.num_agents,
"num_sims": config.num_sims,
"num_time_steps": config.num_time_steps,
"step_size": config.step_size,
"max_range": config.max_range,
"directed": config.directed,
"weighted": config.weighted,
},
"graph_metrics": graph_metrics,
"adjacency_matrix": adj_matrix_np.tolist(),
"initial_states": initial_states.tolist(),
"trajectories": trajectories.tolist()
}
# 6. Save the complete log to a JSON file
file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
with open(file_path, 'w') as f:
json.dump(log_data, f, indent=2)
print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")
if __name__ == "__main__":
main()

162
generate_data_kuramoto.py Normal file
View File

@@ -0,0 +1,162 @@
import os
import json
import time
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
from tqdm import tqdm
from functools import partial
from dataclasses import asdict
import sims
# --- Import the Kuramoto simulation components from sims/kuramoto.py ---
from sims.kuramoto import (
KuramotoConfig,
run_kuramoto_simulation,
phase_coherence,
mean_frequency,
)
from generate_data_consensus import generate_connected_graph, calculate_graph_metrics
# ==============================================================================
# SECTION 1: BATCHED OPERATIONS FOR EFFICIENT DATA GENERATION
# ==============================================================================
@partial(jax.jit, static_argnames=("config",))
def run_batched_simulations(
thetas0_batch: jax.Array, # (num_sims, N)
omegas_batch: jax.Array, # (num_sims, N)
adj_mat: jax.Array, # (N, N)
config: KuramotoConfig
) -> tuple[jax.Array, jax.Array, jax.Array]:
"""
Runs many Kuramoto simulations in parallel for the same graph but different
initial conditions, and performs analysis.
Returns:
A tuple containing:
- All trajectories (num_sims, T, N)
- Final coherence R(T) for each sim (num_sims,)
- Mean frequencies for each sim (num_sims, N)
"""
# Create a batched version of the simulation runner using vmap.
# This maps the function over the first axis of thetas0_batch and omegas_batch.
vmapped_runner = jax.vmap(
run_kuramoto_simulation,
in_axes=(0, 0, None, None), # Map over thetas0, omegas; adj_mat and config are fixed
out_axes=0
)
all_trajectories = vmapped_runner(thetas0_batch, omegas_batch, adj_mat, config)
# Analyze the results in a batched manner
# phase_coherence naturally works on the time axis, so we vmap over the sim axis.
vmapped_coherence = jax.vmap(phase_coherence)(all_trajectories) # -> (num_sims, T)
final_coherence = vmapped_coherence[:, -1] # -> (num_sims,)
# vmap the mean_frequency calculation over trajectories and omegas
vmapped_mean_freq = jax.vmap(mean_frequency, in_axes=(0, 0, None, None))
all_mean_freqs = vmapped_mean_freq(all_trajectories, omegas_batch, adj_mat, config)
return all_trajectories, final_coherence, all_mean_freqs
def generate_batched_initial_states(key: jax.Array, num_sims: int, config: KuramotoConfig):
"""Generates a batch of initial states."""
keys = jax.random.split(key, num_sims)
# A vmap-friendly generator for a single simulation's initial phases and natural frequencies
@jax.vmap
def generate_single_state(k):
key_theta, key_omega = jax.random.split(k)
thetas0 = jax.random.uniform(key_theta, (config.num_agents,), minval=0, maxval=2 * jnp.pi)
omegas = jax.random.normal(key_omega, (config.num_agents,)) # Using std=1, mean=0
return thetas0, omegas
thetas0_batch, omegas_batch = generate_single_state(keys)
return thetas0_batch, omegas_batch
def main():
"""Main script to generate and save the Kuramoto simulation dataset."""
# --- Configuration ---
GRAPH_GEN_ALGOS = ["erdos_renyi", "barabasi_albert", "powerlaw_cluster", "watts_strogatz"]
AGENT_COUNTS = range(5, 51, 5) # 5, 10, ..., 50 agents
# AGENT_COUNTS = [50]
GRAPHS_PER_AGENT_COUNT = 100
GRAPHS_PER_GRAPH_ALGO = GRAPHS_PER_AGENT_COUNT // len(GRAPH_GEN_ALGOS)
SIMS_PER_GRAPH = 100
OUTPUT_DIR = "datasets/kuramoto_dataset"
# --- Setup ---
seed = int(time.time())
main_key = jax.random.PRNGKey(seed)
numpy_rng = np.random.default_rng(seed=seed + 1)
print(f"🚀 Starting Kuramoto data generation...")
print(f"Saving dataset to: '{OUTPUT_DIR}/'")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# --- Main Loop ---
for n_agents in tqdm(AGENT_COUNTS, desc="Agent Counts"):
agent_folder = os.path.join(OUTPUT_DIR, f"agents_{n_agents}")
os.makedirs(agent_folder, exist_ok=True)
for graph_idx in tqdm(range(GRAPHS_PER_AGENT_COUNT), desc=f"Graphs for N={n_agents}", leave=False):
graph_algo_idx = graph_idx // GRAPHS_PER_GRAPH_ALGO
# 1. Setup config, keys, and graph
# KuramotoConfig is a frozen dataclass, so its fields are set via the constructor
config = KuramotoConfig(num_agents=n_agents, coupling=1.5, T=15.0)
main_key, graph_key, state_key = jax.random.split(main_key, 3)
G, graph_type, adj_matrix_np = generate_connected_graph(numpy_rng, config.num_agents, GRAPH_GEN_ALGOS[graph_algo_idx])
adj_matrix_jax = jnp.array(adj_matrix_np)
graph_metrics = calculate_graph_metrics(G)
graph_metrics["graph_type"] = graph_type
# 2. Generate a batch of initial conditions
thetas0_batch, omegas_batch = generate_batched_initial_states(state_key, SIMS_PER_GRAPH, config)
# 3. Run all simulations and analyses in a single, efficient call
trajectories, final_R, mean_freqs = run_batched_simulations(
thetas0_batch, omegas_batch, adj_matrix_jax, config
)
trajectories.block_until_ready() # Ensure computation is finished
# 4. Package all data for logging
log_data = {
"config": {
"num_agents" : config.num_agents,
"coupling" : config.coupling,
"dt" : config.dt,
"T" : config.T,
"normalize_by_degree" : config.normalize_by_degree,
"directed" : config.directed,
"weighted" : config.weighted,
},
"graph_metrics": graph_metrics,
"adjacency_matrix": adj_matrix_np.tolist(),
"initial_conditions": {
"thetas0": thetas0_batch.tolist(),
"omegas": omegas_batch.tolist()
},
"results": {
"final_coherence": final_R.tolist(),
"mean_frequencies": mean_freqs.tolist()
},
# Optionally save full trajectories. Warning: can create large files.
"trajectories": trajectories.tolist()
}
# 5. Save the complete log to a JSON file
file_path = os.path.join(agent_folder, f"graph_{graph_idx:02d}.json")
with open(file_path, 'w') as f:
json.dump(log_data, f, indent=2)
print(f"\n✅ Data generation complete! Check the '{OUTPUT_DIR}' folder.")
if __name__ == "__main__":
main()

136
model.py Normal file
View File

@@ -0,0 +1,136 @@
import jax
from jax import random
import jax.numpy as jnp
from train import ModelConfig, TrainConfig
import optax
from functools import partial
def init_linear_layer(
key: jax.Array,
in_features: int,
out_features: int,
use_bias: bool = True
):
"""
Initializes the weights and biases in a linear layer.
"""
key_w, key_b = random.split(key)
limit = jnp.sqrt(6/in_features)
W = random.uniform(key_w, (in_features, out_features), minval=-limit, maxval=limit)
params = {'W': W}
if use_bias:
b = random.uniform(key_b, (out_features,), minval=-limit, maxval=limit)
params['b'] = b
return params
def init_fn(key: jax.Array, config: ModelConfig):
"""
Initializes all model parameters. Returns a pytree
"""
key_embed, key_translate, key_attn_proj, key_head = random.split(key, 4)
params = {
"agent_embeddings" : {
"weight" : random.normal(key_embed, shape=(config.num_agents, config.embedding_dim))
},
"translate": init_linear_layer(key_translate, config.input_dim, config.embedding_dim),
"attn_proj": init_linear_layer(key_attn_proj, config.embedding_dim, 2 * config.embedding_dim, use_bias=False),
"head": init_linear_layer(key_head, config.embedding_dim, config.output_dim)
}
return params
def forward(params: dict, input_timesteps: jax.Array, config: ModelConfig):
"""
Model forward pass. Takes the parameters and input timesteps, returns next-step predictions.
"""
batch_size, num_agents, _ = input_timesteps.shape
agent_embed = params["agent_embeddings"]["weight"]
agent_embed = jnp.broadcast_to(agent_embed, (batch_size, num_agents, config.embedding_dim))
attn_proj_out = agent_embed @ params["attn_proj"]['W']
k, q = jnp.split(attn_proj_out, 2, axis=-1)
v = input_timesteps @ params["translate"]['W'] + params["translate"]['b']
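# Shapes: q and k come from the static agent embeddings, (batch, N, embedding_dim);
# v comes from the observed states at time t, (batch, N, embedding_dim).
# att_weights below is (batch, N, N): one attention row per agent.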
att_scores = (q @ k.transpose(0, 2, 1)) / jnp.sqrt(num_agents)
att_weights = jax.nn.softmax(att_scores, axis=-1)
weighted_average = att_weights @ v
prediction = weighted_average @ params["head"]['W'] + params["head"]['b']
return prediction
def get_attention_fn(params: dict, config: ModelConfig):
"""
Calculates and returns the learned attention matrix between agents.
This is a pure function for analysis.
"""
embeddings = params['agent_embeddings']['weight']
# Project embeddings to get keys (k) and queries (q) for the global graph
attn_proj_out = embeddings @ params['attn_proj']['W']
k, q = jnp.split(attn_proj_out, 2, axis=-1)
# Note: Using sqrt(embedding_dim) as in the original get_attention method
attn_scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
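# attn_scores[i, j] is the unnormalized weight agent i places on agent j;
# thresholding this matrix yields the predicted interaction graph used for evaluation.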
return jnp.asarray(attn_scores) # Raw attention scores; callers convert to NumPy for logging
def train_model(config: ModelConfig, inputs: jax.Array, targets: jax.Array,
true_graph: jax.Array,
train_config: TrainConfig,
):
key = random.PRNGKey(0)
key, init_key = random.split(key)
params = init_fn(init_key, config)
optimizer = optax.adamw(train_config.learning_rate)
opt_state = optimizer.init(params)
def loss_fn(p, x_batch, y_batch, config):
predictions = forward(p, x_batch, config)
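# Mean absolute error between predicted and true next-step agent states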
loss = jnp.mean(jnp.abs(predictions - y_batch))
return loss
@partial(jax.jit, static_argnames=['config'])
def update_step(params, opt_state, x_batch, y_batch, config):
# Differentiate the loss w.r.t. params; the batch and config are passed straight through.
loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, config)
updates, new_opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state, loss_val
loss_history = {f"epoch_{i}": [] for i in range(train_config.epochs)}
# One list of logged attention matrices per epoch.
graphs = {f"epoch_{i}": [] for i in range(train_config.epochs)}
num_batches = len(inputs)
for epoch in range(train_config.epochs):
running_loss = 0.0
for batch_num in range(num_batches):
x, y = inputs[batch_num], targets[batch_num]
# config is a static argument for the jitted update step.
params, opt_state, loss_val = update_step(params, opt_state, x, y, config)
running_loss += loss_val
epoch_loss = running_loss / num_batches
loss_history[f"epoch_{epoch}"].append(epoch_loss)
if train_config.verbose and (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss:.6f}")
if train_config.log and epoch % train_config.log_epoch_interval == 0:
attn = get_attention_fn(params, config)
graphs[f"epoch_{epoch}"].append(attn)
all_logs = {
"loss_history": loss_history,
"graphs": graphs,
"true_graph": true_graph,
}
return params, all_logs

pyproject.toml
View File

@@ -10,6 +10,9 @@ dependencies = [
"jax[cuda12]>=0.7.0",
"jupyter>=1.1.1",
"matplotlib>=3.10.3",
"networkx>=3.5",
"optax>=0.2.5",
"scikit-learn>=1.7.1",
"seaborn>=0.13.2",
"tqdm>=4.67.1",
]

2
sims/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .consensus import *
from .kuramoto import *

Binary file not shown.

Binary file not shown.

Binary file not shown.

sims/consensus.py
View File

@@ -10,7 +10,7 @@ class ConsensusConfig:
Config class for Consensus dynamics sims
"""
num_sims: int = 500 # Number of consensus sims
num_agents: int = 5# Number of agents in the consensus simulation
num_agents: int = 5 # Number of agents in the consensus simulation
max_range: float = 1 # Max range of values each agent can take
step_size: float = 0.1 # Target range for length of simulation
directed: bool = False # Consensus graph directed?
@@ -61,6 +61,10 @@ def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
config: ConsensusConfig
Config for Consensus dynamics
Returns
---------------------
adj_matrices: jax.Array (num_agents, num_agents)
Random matrix
"""
rand_matrix = jax.random.uniform(key, shape=(config.num_agents, config.num_agents))
# idxs = jnp.arange(config.num_agents)
@@ -90,7 +94,7 @@ def generate_random_agent_states(key: jax.Array, config: ConsensusConfig):
Returns
---------------------
rand_states: jax.Array (num_agents, 1)
rand_states: jax.Array (num_sims, num_agents)
"""
rand_states = jax.random.uniform(key, shape=(config.num_sims, config.num_agents), minval=-config.max_range, maxval=config.max_range)
@@ -113,7 +117,7 @@ def run_consensus_sim(adj_mat: jax.Array, initial_agent_state: jax.Array, config
config: ConsensusConfig
Config for Consensus dynamics
"""
# batched consensus step (meant for many initial states)
batched_consensus_step = jax.vmap(consensus_step, in_axes=(None, 0, None), out_axes=0)
def step(x_prev, _):

sims/kuramoto.py
View File

@@ -5,26 +5,18 @@ from functools import partial
import numpy as np
import matplotlib.pyplot as plt
# -------------------- Configuration --------------------
@dataclass(frozen=True)
class KuramotoConfig:
"""Configuration for the Kuramoto model simulation."""
num_agents: int = 10 # N: Number of oscillators
coupling: float = 1.0 # K: Coupling strength
dt: float = 0.01 # Δt: Integration time step
T: float = 10.0 # Total simulation time
# Adjacency matrix properties
normalize_by_degree: bool = True
time_steps: int = int(T/dt)
normalize_by_degree: bool = False
directed: bool = False
weighted: bool = False
@property
def num_time_steps(self) -> int:
"""Total number of simulation steps."""
return int(self.T / self.dt)
# -------------------- Core Dynamics --------------------
@partial(jax.jit, static_argnames=("config",))
def kuramoto_derivative(theta: jax.Array, # (N,) phase angles
omega: jax.Array, # (N,) natural frequencies
@@ -88,7 +80,7 @@ def run_kuramoto_simulation(
scan_fn,
thetas0,
None,
length=config.num_time_steps
length=config.time_steps
)
return trajectory

File diff suppressed because one or more lines are too long

19
train.py Normal file
View File

@@ -0,0 +1,19 @@
from dataclasses import dataclass
class ModelConfig:
num_agents: int = 10
embedding_dim: int = 64
input_dim: int = 1
output_dim: int = 1
simulation_type: str = "consensus"
class TrainConfig:
epochs: int = 100
learning_rate: float = 1e-3
verbose: bool = True
log: bool = True
log_epoch_interval: int = 10

259
train_and_eval.py Normal file
View File

@@ -0,0 +1,259 @@
import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from dataclasses import dataclass, field, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score
import uuid
import pickle
import sys
# Import the model, config classes, and training loop from model.py
from model import ModelConfig, TrainConfig, train_model, get_attention_fn
# This TrainConfig shadows the one imported above and adds the parameters this script needs.
@dataclass
class TrainConfig:
"""Configuration for the training process."""
learning_rate: float = 1e-3
epochs: int = 100
batch_size: int = 4096
verbose: bool = False # Set to True to see epoch loss during training
log: bool = True
log_epoch_interval: int = 10
# --- New parameters for this script ---
data_directory: str = "datasets/" + sys.argv[1] + "_dataset"
# Threshold for converting attention scores to a binary graph
f1_threshold: float = -0.4
def prepare_data_for_model(trajectories: np.ndarray, batch_size: int) -> tuple[np.ndarray, np.ndarray]:
"""
Converts simulation trajectories into input-output pairs for the model.
Input: state at time t. Target: state at time t+1.
Args:
trajectories: A numpy array of shape (num_sims, num_timesteps, num_agents).
batch_size: The desired batch size for training.
Returns:
A tuple of (batched_inputs, batched_targets).
"""
# For each simulation, create (input, target) pairs
all_inputs = []
all_targets = []
num_sims, num_timesteps, num_agents = trajectories.shape
# trajectories = np.reshape(trajectories, shape=(num_sims * num_timesteps, num_agents))
for i_sim in range(num_sims):
for j_tstep in range(num_timesteps-1):
input = trajectories[i_sim, j_tstep, :]
target = trajectories[i_sim, j_tstep + 1, :]
all_inputs.append(input)
all_targets.append(target)
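# Shuffle input/target pairs jointly so each batch mixes samples from different simulations and timesteps.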
all_indices = np.arange(len(all_inputs))
np.random.shuffle(all_indices)
all_inputs = np.array(all_inputs)
all_targets = np.array(all_targets)
all_inputs = all_inputs[all_indices]
all_targets = all_targets[all_indices]
# for sim_idx in range(num_sims):
# # Input is state from t=0 to t=T-2
# inputs = trajectories[sim_idx, :-1, :]
# # Target is state from t=1 to t=T-1
# targets = trajectories[sim_idx, 1:, :]
# all_inputs.append(inputs)
# all_targets.append(targets)
# Concatenate all pairs from all simulations
# Shape -> (num_sims * (num_timesteps - 1), num_agents)
# full_dataset_inputs = np.concatenate(all_inputs, axis=0)
# full_dataset_targets = np.concatenate(all_targets, axis=0)
# Reshape to have a feature dimension
# Shape -> (total_samples, num_agents, 1)
full_dataset_inputs = np.expand_dims(all_inputs, axis=-1)
full_dataset_targets = np.expand_dims(all_targets, axis=-1)
# Create batches
num_samples = full_dataset_inputs.shape[0]
num_batches = num_samples // batch_size
# Truncate to full batches
truncated_inputs = full_dataset_inputs[:num_batches * batch_size]
truncated_targets = full_dataset_targets[:num_batches * batch_size]
# Reshape into batches
# Shape -> (num_batches, batch_size, num_agents, 1)
batched_inputs = truncated_inputs.reshape(num_batches, batch_size, num_agents, 1)
batched_targets = truncated_targets.reshape(num_batches, batch_size, num_agents, 1)
return batched_inputs, batched_targets
def calculate_f1_score(
params: dict,
model_config: ModelConfig,
true_graph: np.ndarray,
threshold: float
) -> float:
"""
Extracts the learned attention graph, thresholds it, and computes the F1 score.
"""
# Get the learned attention matrix (N, N)
learned_graph_scores = np.array(get_attention_fn(params, model_config))
# Normalize scores to [0, 1] for consistent thresholding (optional but good practice)
# This uses min-max scaling on the flattened array
# flat_scores = learned_graph_scores.flatten()
# min_s, max_s = flat_scores.min(), flat_scores.max()
# if max_s > min_s:
# learned_graph_scores = (learned_graph_scores - min_s) / (max_s - min_s)
# Threshold to get a binary predicted graph
predicted_graph = (learned_graph_scores > threshold).astype(int)
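# Note: the threshold is applied to the raw attention scores (the min-max normalization
# above is commented out), which is why f1_threshold defaults to a negative value.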
# The diagonal is not part of the prediction task
# np.fill_diagonal(predicted_graph, 0)
# np.fill_diagonal(true_graph, 0)
# Flatten both graphs to treat this as a binary classification problem
true_flat = true_graph.flatten()
pred_flat = predicted_graph.flatten()
return f1_score(true_flat, pred_flat)
def main():
"""Main script to run the training and evaluation pipeline."""
train_config = TrainConfig()
# Check if the data directory exists
if not os.path.isdir(train_config.data_directory):
print(f"Error: Data directory '{train_config.data_directory}' not found.")
print(f"Please run the data generation script for '{train_config.data_directory}' first.")
return
print(f"🚀 Starting training pipeline for '{train_config.data_directory}' data.")
# Get sorted list of agent directories
agent_dirs = sorted(
[d for d in os.listdir(train_config.data_directory) if d.startswith("agents_")],
key=lambda x: int(x.split('_')[1])
)
for agent_dir_name in agent_dirs:
agent_dir_path = os.path.join(train_config.data_directory, agent_dir_name)
all_results_for_agent = []
graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
file_path = os.path.join(agent_dir_path, graph_file_name)
with open(file_path, 'r') as f:
data = json.load(f)
# 1. Load and Prepare Data
trajectories = np.array(data['trajectories'])
s, l, n = trajectories.shape
# trajectories = trajectories.T
# np.random.shuffle(trajectories)
# trajectories = np.random.shuffle(trajectories)
true_graph = np.array(data['adjacency_matrix'])
inputs, targets = prepare_data_for_model(trajectories, train_config.batch_size)
# 2. Configure Model
num_agents = trajectories.shape[-1]
model_config = ModelConfig()
model_config.num_agents=num_agents
model_config.input_dim=1 # Each agent has a single state value at time t
model_config.output_dim=1
model_config.embedding_dim=32
# 3. Train the Model
# This relies on the modified train_model that returns final params
final_params, train_logs = train_model(
config=model_config,
inputs=inputs,
targets=targets,
true_graph=true_graph,
train_config=train_config
)
# 4. Evaluate
f1 = calculate_f1_score(
final_params,
model_config,
true_graph,
train_config.f1_threshold
)
loss_history_serializable = {
epoch: [loss.item() for loss in losses]
for epoch, losses in train_logs['loss_history'].items()
}
random_id = str(uuid.uuid4())
# 5. Log Results
result_log = {
# "model_name": random_id,
"source_file": graph_file_name,
"graph_metrics": data['graph_metrics'],
"f1_score": f1,
"training_loss_history": loss_history_serializable,
"config": {
# Manually create the dictionary for the model config
"model": {
"num_agents": model_config.num_agents,
"input_dim": model_config.input_dim,
"output_dim": model_config.output_dim,
"embedding_dim": model_config.embedding_dim
},
# This is correct because TrainConfig is a dataclass
"training": vars(train_config)
}
}
result_final_params = final_params
all_results_for_agent.append(result_log)
# 6. Save aggregated results for this agent count
results_dir = os.path.join(agent_dir_path, "results")
os.makedirs(results_dir, exist_ok=True)
output_file = os.path.join(results_dir, "summary_results.json")
with open(output_file, 'w') as f:
json.dump(all_results_for_agent, f, indent=2)
model_path = os.path.join(results_dir, "model_params")
os.makedirs(model_path, exist_ok=True)
with open(os.path.join(model_path,"model_params" + ".pkl"), "wb") as f:
pickle.dump(final_params, f)
print(f"✅ Results for {agent_dir_name} saved to {output_file}")
print("\n🎉 Pipeline finished successfully!")
if __name__ == "__main__":
main()

217
train_and_eval_w_noise.py Normal file
View File

@@ -0,0 +1,217 @@
import os
import json
import jax
import jax.numpy as jnp
import numpy as np
from dataclasses import dataclass, field, asdict
from enum import Enum
from tqdm import tqdm
from sklearn.metrics import f1_score
# Import the model components from model.py
from model import ModelConfig, train_model, get_attention_fn
# --- MODIFICATION: Enums for clarity and safety ---
class DataSource(Enum):
CONSENSUS = "consensus"
KURAMOTO = "kuramoto"
class NoiseType(Enum):
NONE = "none"
NORMAL = "normal"
UNIFORM = "uniform"
# --- MODIFICATION: Updated TrainConfig ---
@dataclass
class TrainConfig:
"""Configuration for the training process."""
learning_rate: float = 1e-3
epochs: int = 50
batch_size: int = 64
verbose: bool = False
log: bool = True
log_epoch_interval: int = 10
# Data and Noise parameters
data_source: DataSource = DataSource.CONSENSUS
noise_type: NoiseType = NoiseType.NORMAL
noise_level: float = 0.1 # Stddev for Normal, half-width for Uniform
# Evaluation parameter
f1_threshold: float = 0.5
@property
def source_data_directory(self) -> str:
"""The directory where the source dataset is located."""
return f"{self.data_source.value}_dataset"
@property
def results_directory_name(self) -> str:
"""Generates a unique output directory name for separate logging."""
if self.noise_type == NoiseType.NONE or self.noise_level == 0:
return f"results_noiseless"
return f"results_noise_{self.noise_type.value}_{self.noise_level}"
# --- MODIFICATION: Updated data prep function to add noise ---
def prepare_data_for_model(
trajectories: np.ndarray,
key: jax.Array,
train_config: TrainConfig
) -> tuple[np.ndarray, np.ndarray]:
"""
Converts trajectories to input-output pairs and adds noise to the inputs.
"""
all_inputs, all_targets = [], []
num_sims, num_timesteps, num_agents = trajectories.shape
for sim_idx in range(num_sims):
all_inputs.append(trajectories[sim_idx, :-1, :])
all_targets.append(trajectories[sim_idx, 1:, :])
full_dataset_inputs = np.concatenate(all_inputs, axis=0)
full_dataset_targets = np.concatenate(all_targets, axis=0)
# --- NOISE INJECTION BLOCK ---
if train_config.noise_type != NoiseType.NONE and train_config.noise_level > 0:
noise_shape = full_dataset_inputs.shape
if train_config.noise_type == NoiseType.NORMAL:
noise = jax.random.normal(key, noise_shape) * train_config.noise_level
elif train_config.noise_type == NoiseType.UNIFORM:
noise = jax.random.uniform(
key, noise_shape,
minval=-train_config.noise_level,
maxval=train_config.noise_level
)
full_dataset_inputs += np.array(noise) # Add noise to inputs
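# Targets are left noiseless, so the model is trained to predict the clean next state from a noisy observation.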
# --- END NOISE BLOCK ---
full_dataset_inputs = np.expand_dims(full_dataset_inputs, axis=-1)
full_dataset_targets = np.expand_dims(full_dataset_targets, axis=-1)
num_samples = full_dataset_inputs.shape[0]
num_batches = num_samples // train_config.batch_size
truncated_inputs = full_dataset_inputs[:num_batches * train_config.batch_size]
truncated_targets = full_dataset_targets[:num_batches * train_config.batch_size]
batched_inputs = truncated_inputs.reshape(num_batches, train_config.batch_size, num_agents, 1)
batched_targets = truncated_targets.reshape(num_batches, train_config.batch_size, num_agents, 1)
return batched_inputs, batched_targets
def calculate_f1_score(
params: dict,
model_config: ModelConfig,
true_graph: np.ndarray,
threshold: float
) -> float:
"""Extracts the learned graph and computes the F1 score."""
learned_scores = np.array(get_attention_fn(params, model_config))
flat_scores = learned_scores.flatten()
min_s, max_s = flat_scores.min(), flat_scores.max()
if max_s > min_s:
learned_scores = (learned_scores - min_s) / (max_s - min_s)
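# Unlike train_and_eval.py, the scores here are min-max normalized to [0, 1], so a threshold of 0.5 is meaningful.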
predicted_graph = (learned_scores > threshold).astype(int)
np.fill_diagonal(predicted_graph, 0)
np.fill_diagonal(true_graph, 0)
return f1_score(true_graph.flatten(), predicted_graph.flatten())
def main():
"""Main script to run the training and evaluation pipeline."""
# Configure your training run here
train_config = TrainConfig(
noise_type=NoiseType.NORMAL,
noise_level=0.1
)
if not os.path.isdir(train_config.source_data_directory):
print(f"Error: Source data '{train_config.source_data_directory}' not found.")
return
print(f"🚀 Starting training pipeline for '{train_config.data_source.value}' data.")
print(f"Noise Configuration: type={train_config.noise_type.value}, level={train_config.noise_level}")
# --- MODIFICATION: Main JAX key for noise generation ---
main_key = jax.random.PRNGKey(42)
agent_dirs = sorted(
[d for d in os.listdir(train_config.source_data_directory) if d.startswith("agents_")],
key=lambda x: int(x.split('_')[1])
)
for agent_dir_name in agent_dirs:
agent_dir_path = os.path.join(train_config.source_data_directory, agent_dir_name)
all_results_for_agent = []
graph_files = sorted([f for f in os.listdir(agent_dir_path) if f.endswith(".json")])
print(f"\nProcessing {len(graph_files)} graphs for {agent_dir_name}...")
for graph_file_name in tqdm(graph_files, desc=f"Training on {agent_dir_name}"):
file_path = os.path.join(agent_dir_path, graph_file_name)
with open(file_path, 'r') as f:
data = json.load(f)
main_key, data_key = jax.random.split(main_key)
trajectories = np.array(data['trajectories'])
true_graph = np.array(data['adjacency_matrix'])
# --- MODIFICATION: Pass key and config to data prep ---
inputs, targets = prepare_data_for_model(trajectories, data_key, train_config)
num_agents = int(trajectories.shape[-1])
model_config = ModelConfig()
model_config.num_agents = num_agents
model_config.input_dim = 1
model_config.output_dim = 1
model_config.embedding_dim = 32
final_params, train_logs = train_model(
config=model_config,
inputs=inputs,
targets=targets,
true_graph=true_graph,
train_config=train_config
)
f1 = calculate_f1_score(final_params, model_config, true_graph, train_config.f1_threshold)
loss_history_serializable = {
epoch: [loss.item() for loss in losses]
for epoch, losses in train_logs['loss_history'].items()
}
result_log = {
"source_file": graph_file_name,
"graph_metrics": data['graph_metrics'],
"f1_score": f1,
"training_loss_history": loss_history_serializable,
"config": {
"model": {
"num_agents": model_config.num_agents,
"input_dim": model_config.input_dim,
"output_dim": model_config.output_dim,
"embedding_dim": model_config.embedding_dim
},
"training": asdict(train_config)
}
}
all_results_for_agent.append(result_log)
# --- MODIFICATION: Save to a separate results directory ---
results_dir = os.path.join(agent_dir_path, train_config.results_directory_name)
os.makedirs(results_dir, exist_ok=True)
output_file = os.path.join(results_dir, "summary_results.json")
with open(output_file, 'w') as f:
json.dump(all_results_for_agent, f, indent=2)
print(f"✅ Results for {agent_dir_name} saved to {output_file}")
print("\n🎉 Pipeline finished successfully!")
if __name__ == "__main__":
main()

114
uv.lock generated
View File

@@ -6,6 +6,15 @@ resolution-markers = [
"python_full_version < '3.13'",
]
[[package]]
name = "absl-py"
version = "2.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" },
]
[[package]]
name = "anyio"
version = "4.9.0"
@@ -218,6 +227,24 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" },
]
[[package]]
name = "chex"
version = "0.1.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "absl-py" },
{ name = "jax" },
{ name = "jaxlib" },
{ name = "numpy" },
{ name = "setuptools" },
{ name = "toolz" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/77/70/53c7d404ce9e2a94009aea7f77ef6e392f6740e071c62683a506647c520f/chex-0.1.90.tar.gz", hash = "sha256:d3c375aeb6154b08f1cccd2bee4ed83659ee2198a6acf1160d2fe2e4a6c87b5c", size = 92363, upload-time = "2025-07-23T19:50:47.945Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6f/3d/46bb04776c465cea2dd8aa2d4b61ab610b707f798f47838ef7e6105b025c/chex-0.1.90-py3-none-any.whl", hash = "sha256:fce3de82588f72d4796e545e574a433aa29229cbdcf792555e41bead24b704ae", size = 101047, upload-time = "2025-07-23T19:50:46.603Z" },
]
[[package]]
name = "colorama"
version = "0.4.6"
@@ -386,6 +413,9 @@ dependencies = [
{ name = "jax", extra = ["cuda12"] },
{ name = "jupyter" },
{ name = "matplotlib" },
{ name = "networkx" },
{ name = "optax" },
{ name = "scikit-learn" },
{ name = "seaborn" },
{ name = "tqdm" },
]
@@ -397,6 +427,9 @@ requires-dist = [
{ name = "jax", extras = ["cuda12"], specifier = ">=0.7.0" },
{ name = "jupyter", specifier = ">=1.1.1" },
{ name = "matplotlib", specifier = ">=3.10.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "optax", specifier = ">=0.2.5" },
{ name = "scikit-learn", specifier = ">=1.7.1" },
{ name = "seaborn", specifier = ">=0.13.2" },
{ name = "tqdm", specifier = ">=4.67.1" },
]
@@ -641,6 +674,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
]
[[package]]
name = "joblib"
version = "1.5.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/dc/fe/0f5a938c54105553436dbff7a61dc4fed4b1b2c98852f8833beaf4d5968f/joblib-1.5.1.tar.gz", hash = "sha256:f4f86e351f39fe3d0d32a9f2c3d8af1ee4cec285aafcb27003dda5205576b444", size = 330475, upload-time = "2025-05-23T12:04:37.097Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7d/4f/1195bbac8e0c2acc5f740661631d8d750dc38d4a32b23ee5df3cde6f4e0d/joblib-1.5.1-py3-none-any.whl", hash = "sha256:4719a31f054c7d766948dcd83e9613686b27114f190f717cec7eaa2084f8a74a", size = 307746, upload-time = "2025-05-23T12:04:35.124Z" },
]
[[package]]
name = "json5"
version = "0.12.0"
@@ -1141,6 +1183,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" },
]
[[package]]
name = "networkx"
version = "3.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" },
]
[[package]]
name = "notebook"
version = "7.4.4"
@@ -1373,6 +1424,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932, upload-time = "2024-09-26T14:33:23.039Z" },
]
[[package]]
name = "optax"
version = "0.2.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "absl-py" },
{ name = "chex" },
{ name = "jax" },
{ name = "jaxlib" },
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/75/1e011953c48be502d4d84fa8458e91be7c6f983002511669bddd7b1a065f/optax-0.2.5.tar.gz", hash = "sha256:b2e38c7aea376186deae758ba7a258e6ef760c6f6131e9e11bc561c65386d594", size = 258548, upload-time = "2025-06-10T17:00:47.544Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b9/33/f86091c706db1a5459f501830241afff2ecab3532725c188ea57be6e54de/optax-0.2.5-py3-none-any.whl", hash = "sha256:966deae936207f268ac8f564d8ed228d645ac1aaddefbbf194096d2299b24ba8", size = 354324, upload-time = "2025-06-10T17:00:46.062Z" },
]
[[package]]
name = "overrides"
version = "7.7.0"
@@ -1862,6 +1929,35 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/75/04/5302cea1aa26d886d34cadbf2dc77d90d7737e576c0065f357b96dc7a1a6/rpds_py-0.26.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f14440b9573a6f76b4ee4770c13f0b5921f71dde3b6fcb8dabbefd13b7fe05d7", size = 232821, upload-time = "2025-07-01T15:55:55.167Z" },
]
[[package]]
name = "scikit-learn"
version = "1.7.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "joblib" },
{ name = "numpy" },
{ name = "scipy" },
{ name = "threadpoolctl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/41/84/5f4af978fff619706b8961accac84780a6d298d82a8873446f72edb4ead0/scikit_learn-1.7.1.tar.gz", hash = "sha256:24b3f1e976a4665aa74ee0fcaac2b8fccc6ae77c8e07ab25da3ba6d3292b9802", size = 7190445, upload-time = "2025-07-18T08:01:54.5Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/16/57f176585b35ed865f51b04117947fe20f130f78940c6477b6d66279c9c2/scikit_learn-1.7.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3cee419b49b5bbae8796ecd690f97aa412ef1674410c23fc3257c6b8b85b8087", size = 9260431, upload-time = "2025-07-18T08:01:22.77Z" },
{ url = "https://files.pythonhosted.org/packages/67/4e/899317092f5efcab0e9bc929e3391341cec8fb0e816c4789686770024580/scikit_learn-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2fd8b8d35817b0d9ebf0b576f7d5ffbbabdb55536b0655a8aaae629d7ffd2e1f", size = 8637191, upload-time = "2025-07-18T08:01:24.731Z" },
{ url = "https://files.pythonhosted.org/packages/f3/1b/998312db6d361ded1dd56b457ada371a8d8d77ca2195a7d18fd8a1736f21/scikit_learn-1.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:588410fa19a96a69763202f1d6b7b91d5d7a5d73be36e189bc6396bfb355bd87", size = 9486346, upload-time = "2025-07-18T08:01:26.713Z" },
{ url = "https://files.pythonhosted.org/packages/ad/09/a2aa0b4e644e5c4ede7006748f24e72863ba2ae71897fecfd832afea01b4/scikit_learn-1.7.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3142f0abe1ad1d1c31a2ae987621e41f6b578144a911ff4ac94781a583adad7", size = 9290988, upload-time = "2025-07-18T08:01:28.938Z" },
{ url = "https://files.pythonhosted.org/packages/15/fa/c61a787e35f05f17fc10523f567677ec4eeee5f95aa4798dbbbcd9625617/scikit_learn-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:3ddd9092c1bd469acab337d87930067c87eac6bd544f8d5027430983f1e1ae88", size = 8735568, upload-time = "2025-07-18T08:01:30.936Z" },
{ url = "https://files.pythonhosted.org/packages/52/f8/e0533303f318a0f37b88300d21f79b6ac067188d4824f1047a37214ab718/scikit_learn-1.7.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b7839687fa46d02e01035ad775982f2470be2668e13ddd151f0f55a5bf123bae", size = 9213143, upload-time = "2025-07-18T08:01:32.942Z" },
{ url = "https://files.pythonhosted.org/packages/71/f3/f1df377d1bdfc3e3e2adc9c119c238b182293e6740df4cbeac6de2cc3e23/scikit_learn-1.7.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a10f276639195a96c86aa572ee0698ad64ee939a7b042060b98bd1930c261d10", size = 8591977, upload-time = "2025-07-18T08:01:34.967Z" },
{ url = "https://files.pythonhosted.org/packages/99/72/c86a4cd867816350fe8dee13f30222340b9cd6b96173955819a5561810c5/scikit_learn-1.7.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:13679981fdaebc10cc4c13c43344416a86fcbc61449cb3e6517e1df9d12c8309", size = 9436142, upload-time = "2025-07-18T08:01:37.397Z" },
{ url = "https://files.pythonhosted.org/packages/e8/66/277967b29bd297538dc7a6ecfb1a7dce751beabd0d7f7a2233be7a4f7832/scikit_learn-1.7.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4f1262883c6a63f067a980a8cdd2d2e7f2513dddcef6a9eaada6416a7a7cbe43", size = 9282996, upload-time = "2025-07-18T08:01:39.721Z" },
{ url = "https://files.pythonhosted.org/packages/e2/47/9291cfa1db1dae9880420d1e07dbc7e8dd4a7cdbc42eaba22512e6bde958/scikit_learn-1.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:ca6d31fb10e04d50bfd2b50d66744729dbb512d4efd0223b864e2fdbfc4cee11", size = 8707418, upload-time = "2025-07-18T08:01:42.124Z" },
{ url = "https://files.pythonhosted.org/packages/61/95/45726819beccdaa34d3362ea9b2ff9f2b5d3b8bf721bd632675870308ceb/scikit_learn-1.7.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:781674d096303cfe3d351ae6963ff7c958db61cde3421cd490e3a5a58f2a94ae", size = 9561466, upload-time = "2025-07-18T08:01:44.195Z" },
{ url = "https://files.pythonhosted.org/packages/ee/1c/6f4b3344805de783d20a51eb24d4c9ad4b11a7f75c1801e6ec6d777361fd/scikit_learn-1.7.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:10679f7f125fe7ecd5fad37dd1aa2daae7e3ad8df7f3eefa08901b8254b3e12c", size = 9040467, upload-time = "2025-07-18T08:01:46.671Z" },
{ url = "https://files.pythonhosted.org/packages/6f/80/abe18fe471af9f1d181904203d62697998b27d9b62124cd281d740ded2f9/scikit_learn-1.7.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1f812729e38c8cb37f760dce71a9b83ccfb04f59b3dca7c6079dcdc60544fa9e", size = 9532052, upload-time = "2025-07-18T08:01:48.676Z" },
{ url = "https://files.pythonhosted.org/packages/14/82/b21aa1e0c4cee7e74864d3a5a721ab8fcae5ca55033cb6263dca297ed35b/scikit_learn-1.7.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:88e1a20131cf741b84b89567e1717f27a2ced228e0f29103426102bc2e3b8ef7", size = 9361575, upload-time = "2025-07-18T08:01:50.639Z" },
{ url = "https://files.pythonhosted.org/packages/f2/20/f4777fcd5627dc6695fa6b92179d0edb7a3ac1b91bcd9a1c7f64fa7ade23/scikit_learn-1.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b1bd1d919210b6a10b7554b717c9000b5485aa95a1d0f177ae0d7ee8ec750da5", size = 9277310, upload-time = "2025-07-18T08:01:52.547Z" },
]
[[package]]
name = "scipy"
version = "1.16.0"
@@ -1987,6 +2083,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154, upload-time = "2024-03-12T14:34:36.569Z" },
]
[[package]]
name = "threadpoolctl"
version = "3.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
]
[[package]]
name = "tinycss2"
version = "1.4.0"
@@ -1999,6 +2104,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" },
]
[[package]]
name = "toolz"
version = "1.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790, upload-time = "2024-10-04T16:17:04.001Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383, upload-time = "2024-10-04T16:17:01.533Z" },
]
[[package]]
name = "tornado"
version = "6.5.1"