import jax
from jax import random
import jax.numpy as jnp
from train import ModelConfig, TrainConfig
import optax
from functools import partial


def init_linear_layer(
    key: jax.Array,
    in_features: int,
    out_features: int,
    use_bias: bool = True,
):
    """
    Initializes the weights and biases in a linear layer.
    """
    key_w, key_b = random.split(key)
    limit = jnp.sqrt(6 / in_features)
    W = random.uniform(key_w, (in_features, out_features), minval=-limit, maxval=limit)
    params = {'W': W}
    if use_bias:
        b = random.uniform(key_b, (out_features,), minval=-limit, maxval=limit)
        params['b'] = b
    return params
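
# The uniform bound sqrt(6 / fan_in) above matches the He/Kaiming-uniform bound
# (gain sqrt(2) for ReLU). A hypothetical call and the resulting pytree shapes:
#   init_linear_layer(random.PRNGKey(0), in_features=3, out_features=8)
#   -> {'W': (3, 8), 'b': (8,)}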


def init_fn(key: jax.Array, config: ModelConfig):
    """
    Initializes all model parameters. Returns a pytree.
    """
    key_embed, key_translate, key_attn_proj, key_head = random.split(key, 4)
    params = {
        "agent_embeddings": {
            "weight": random.normal(key_embed, shape=(config.num_agents, config.embedding_dim))
        },
        "translate": init_linear_layer(key_translate, config.input_dim, config.embedding_dim),
        "attn_proj": init_linear_layer(key_attn_proj, config.embedding_dim, 2 * config.embedding_dim, use_bias=False),
        "head": init_linear_layer(key_head, config.embedding_dim, config.output_dim),
    }
    return params
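
# A quick shape check of the full parameter pytree (hypothetical config values):
#   config = ModelConfig(num_agents=5, input_dim=3, embedding_dim=16, output_dim=3)
#   jax.tree_util.tree_map(jnp.shape, init_fn(random.PRNGKey(0), config))
#   -> {'agent_embeddings': {'weight': (5, 16)},
#       'attn_proj': {'W': (16, 32)},
#       'head': {'W': (16, 3), 'b': (3,)},
#       'translate': {'W': (3, 16), 'b': (16,)}}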


def forward(params: dict, input_timesteps: jax.Array, config: ModelConfig):
    """
    Model's forward function. Takes the parameters and input timesteps, returns predictions.
    """
    batch_size, num_agents, _ = input_timesteps.shape
    agent_embed = params["agent_embeddings"]["weight"]
    agent_embed = jnp.broadcast_to(agent_embed, (batch_size, num_agents, config.embedding_dim))
    attn_proj_out = agent_embed @ params["attn_proj"]['W']
    k, q = jnp.split(attn_proj_out, 2, axis=-1)
    v = input_timesteps @ params["translate"]['W'] + params["translate"]['b']
    att_scores = (q @ k.transpose(0, 2, 1)) / jnp.sqrt(num_agents)
    att_weights = jax.nn.softmax(att_scores, axis=-1)
    weighted_average = att_weights @ v
    prediction = weighted_average @ params["head"]['W'] + params["head"]['b']
    return prediction
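
# Shape walkthrough for forward(), with B = batch_size, A = num_agents, E = embedding_dim:
#   agent_embed:  (B, A, E)
#   k, q:         (B, A, E) each, split from the (B, A, 2E) projection
#   v:            (B, A, E), a linear map of the (B, A, input_dim) timesteps
#   att_weights:  (B, A, A), softmax over the key/agent axis
#   prediction:   (B, A, output_dim)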


def get_attention_fn(params: dict, config: ModelConfig):
    """
    Calculates and returns the learned attention matrix between agents.
    This is a pure function for analysis.
    """
    embeddings = params['agent_embeddings']['weight']
    # Project embeddings to get keys (k) and queries (q) for the global graph
    attn_proj_out = embeddings @ params['attn_proj']['W']
    k, q = jnp.split(attn_proj_out, 2, axis=-1)
    # Note: scaled by sqrt(embedding_dim), as in the original get_attention method
    attn_scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    return jax.device_get(attn_scores)  # return as a NumPy array for logging


def train_model(
    config: ModelConfig,
    inputs: jax.Array,
    targets: jax.Array,
    true_graph: jax.Array,
    train_config: TrainConfig,
):
    key = random.PRNGKey(0)
    key, init_key = random.split(key)
    params = init_fn(init_key, config)
    optimizer = optax.adamw(train_config.learning_rate)
    opt_state = optimizer.init(params)

    def loss_fn(p, x_batch, y_batch, config):
        predictions = forward(p, x_batch, config)
        loss = jnp.mean(jnp.abs(predictions - y_batch))
        return loss

    @partial(jax.jit, static_argnames=['config'])
    def update_step(params, opt_state, x_batch, y_batch, config):
        # Differentiate the loss w.r.t. params, forwarding the batch and config.
        loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, config)
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss_val
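
    # Because `config` is a static argument of the jitted step, it must be hashable
    # (e.g. a frozen dataclass); each distinct config value triggers a fresh compile.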

    # Per-epoch logs; each entry starts as an empty list so results can be appended.
    loss_history = {f"epoch_{i}": [] for i in range(train_config.epochs)}
    graphs = {f"epoch_{i}": [] for i in range(train_config.epochs)}

    num_batches = len(inputs)
    for epoch in range(train_config.epochs):
        running_loss = 0.0
        for batch_num in range(num_batches):
            x, y = inputs[batch_num], targets[batch_num]
            # `config` is passed explicitly because it is a static argument of the jitted step.
            params, opt_state, loss_val = update_step(params, opt_state, x, y, config)
            running_loss += loss_val
        epoch_loss = running_loss / num_batches
        loss_history[f"epoch_{epoch}"].append(epoch_loss)
        if train_config.verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss:.6f}")
        if train_config.log and epoch % train_config.log_epoch_interval == 0:
            attn = get_attention_fn(params, config)
            graphs[f"epoch_{epoch}"].append(attn)

    all_logs = {
        "loss_history": loss_history,
        "graphs": graphs,
        "true_graph": true_graph,
    }
    return params, all_logs
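

if __name__ == "__main__":
    # Minimal usage sketch. ModelConfig/TrainConfig are constructed with only the
    # fields referenced above; any other required fields or defaults defined in
    # train.py are assumptions, as are the synthetic data shapes below.
    config = ModelConfig(num_agents=4, input_dim=2, embedding_dim=8, output_dim=2)
    train_config = TrainConfig(learning_rate=1e-3, epochs=20, verbose=True,
                               log=True, log_epoch_interval=5)

    key = random.PRNGKey(42)
    k_x, k_y, k_g = random.split(key, 3)
    # inputs/targets are indexed as inputs[batch_num] -> (batch_size, num_agents, dim),
    # so the leading axis is the number of batches.
    inputs = random.normal(k_x, (10, 32, config.num_agents, config.input_dim))
    targets = random.normal(k_y, (10, 32, config.num_agents, config.output_dim))
    true_graph = random.uniform(k_g, (config.num_agents, config.num_agents))

    trained_params, logs = train_model(config, inputs, targets, true_graph, train_config)
    print("final attention graph:\n", get_attention_fn(trained_params, config))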