136 lines
4.9 KiB
Python
136 lines
4.9 KiB
Python
import jax
|
|
from jax import random
|
|
import jax.numpy as jnp
|
|
from train import ModelConfig, TrainConfig
|
|
import optax
|
|
from functools import partial
|
|
|
|
def init_linear_layer(
    key: jax.Array,
    in_features: int,
    out_features: int,
    use_bias: bool = True
):
    """Create the parameter dict of a dense (affine) layer.

    Every entry is sampled from U(-bound, bound) with
    ``bound = sqrt(6 / in_features)``; the same bound is used for the bias.

    Args:
        key: PRNG key. It is split into a weight sub-key and a bias sub-key
            (the bias sub-key is simply unused when ``use_bias`` is False).
        in_features: input width (fan-in) of the layer.
        out_features: output width of the layer.
        use_bias: when True, also create the bias vector ``'b'``.

    Returns:
        ``{'W': (in_features, out_features) array}``, plus
        ``'b': (out_features,) array`` when ``use_bias`` is True.
    """
    weight_key, bias_key = random.split(key)
    bound = jnp.sqrt(6/in_features)

    def _sample(sub_key, shape):
        # Symmetric uniform draw shared by the weight and the bias.
        return random.uniform(sub_key, shape, minval=-bound, maxval=bound)

    layer = {'W': _sample(weight_key, (in_features, out_features))}
    if not use_bias:
        return layer

    layer['b'] = _sample(bias_key, (out_features,))
    return layer
|
|
|
|
def init_fn(key: jax.Array, config: ModelConfig):
    """Build and return the model's parameter pytree.

    Entries of the returned dict:
      * ``"agent_embeddings"["weight"]``: standard-normal matrix of shape
        ``(config.num_agents, config.embedding_dim)``.
      * ``"translate"``: linear layer ``input_dim -> embedding_dim`` (with bias).
      * ``"attn_proj"``: bias-free linear layer
        ``embedding_dim -> 2 * embedding_dim``; its output is later split
        into keys and queries.
      * ``"head"``: linear layer ``embedding_dim -> output_dim`` (with bias).

    Args:
        key: PRNG key; split into one independent sub-key per component.
        config: model configuration supplying ``num_agents``, ``input_dim``,
            ``embedding_dim`` and ``output_dim``.
    """
    embed_key, translate_key, proj_key, head_key = random.split(key, 4)
    dim = config.embedding_dim

    embedding_table = random.normal(embed_key, shape=(config.num_agents, dim))

    return {
        "agent_embeddings": {"weight": embedding_table},
        "translate": init_linear_layer(translate_key, config.input_dim, dim),
        "attn_proj": init_linear_layer(proj_key, dim, 2 * dim, use_bias=False),
        "head": init_linear_layer(head_key, dim, config.output_dim),
    }
|
|
|
|
def forward(params: dict, input_timesteps: jax.Array, config: ModelConfig):
    """Run the model on a batch of per-agent inputs and return predictions.

    The agent-to-agent attention depends only on the learned agent embeddings
    (not on the inputs), so a single (num_agents, num_agents) weight matrix is
    computed once and applied to every batch element.

    Args:
        params: parameter pytree with the layout built by ``init_fn``.
        input_timesteps: array of shape ``(batch_size, num_agents, input_dim)``.
        config: model configuration. Not read here (all sizes come from
            ``params``) but kept so the signature matches existing callers.

    Returns:
        Predictions of shape ``(batch_size, num_agents, output_dim)``.
    """
    embeddings = params["agent_embeddings"]["weight"]        # (num_agents, embedding_dim)
    attn_proj_out = embeddings @ params["attn_proj"]['W']    # (num_agents, 2 * embedding_dim)
    k, q = jnp.split(attn_proj_out, 2, axis=-1)

    # Scaled dot-product attention divides by sqrt(d_k), the key/query width,
    # so the logits' magnitude does not grow with the embedding size. The
    # previous divisor sqrt(num_agents) disagreed with get_attention_fn, which
    # uses sqrt(q.shape[-1]); the logged graphs now match the model's attention.
    att_scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    att_weights = jax.nn.softmax(att_scores, axis=-1)       # each row sums to 1

    v = input_timesteps @ params["translate"]['W'] + params["translate"]['b']
    # (num_agents, num_agents) @ (batch, num_agents, embedding_dim) broadcasts
    # over the leading batch dimension.
    weighted_average = att_weights @ v

    prediction = weighted_average @ params["head"]['W'] + params["head"]['b']
    return prediction
|
|
|
|
def get_attention_fn(params: dict, config: ModelConfig):
    """Return the learned agent-to-agent attention scores (pre-softmax).

    Computes ``q @ k.T / sqrt(d_k)`` from the agent embeddings. Entry
    ``[i, j]`` is the score of agent ``i``'s query against agent ``j``'s key.
    Apply ``jax.nn.softmax(scores, axis=-1)`` to obtain row-normalized
    attention weights.

    This is a pure, traceable function for analysis. It returns a
    ``jax.Array``, not a NumPy array; call ``jax.device_get`` on the result
    if a host-side NumPy copy is needed for logging.

    Args:
        params: parameter pytree with the layout built by ``init_fn``.
        config: model configuration (not read; kept for signature compatibility).

    Returns:
        ``jax.Array`` of shape ``(num_agents, num_agents)``.
    """
    embeddings = params['agent_embeddings']['weight']

    # Project embeddings to get keys (k) and queries (q) for the global graph.
    attn_proj_out = embeddings @ params['attn_proj']['W']
    k, q = jnp.split(attn_proj_out, 2, axis=-1)

    # Scale by sqrt(d_k) = sqrt(embedding_dim), as in standard scaled
    # dot-product attention (and the original get_attention method).
    return (q @ k.T) / jnp.sqrt(q.shape[-1])
|
|
|
|
def train_model(config: ModelConfig, inputs: jax.Array, targets: jax.Array,
                true_graph: jax.Array,
                train_config: TrainConfig,
                *,
                seed: int = 0,
                ):
    """Train the model with AdamW on a mean-absolute-error (L1) loss.

    Args:
        config: model configuration. It is a static argument of the jitted
            update step, so it must be hashable.
        inputs: sequence of input batches; ``inputs[i]`` is fed to ``forward``.
        targets: sequence of target batches aligned one-to-one with ``inputs``.
        true_graph: ground-truth graph; not used for training, only returned
            in the logs.
        train_config: supplies ``learning_rate``, ``epochs``, ``verbose``,
            ``log`` and ``log_epoch_interval``.
        seed: seed of the PRNG key used for parameter initialization
            (keyword-only; defaults to 0, the previously hard-coded value).

    Returns:
        ``(params, all_logs)``. ``all_logs`` holds three entries:
          * ``"loss_history"``: maps ``epoch_{i}`` to a one-element list
            holding that epoch's mean batch loss.
          * ``"graphs"``: maps ``epoch_{i}`` to a list of host-side attention
            score matrices; the list is empty for epochs that were not logged.
          * ``"true_graph"``: the ``true_graph`` argument, unchanged.

    Raises:
        ValueError: if ``inputs`` is empty, or if ``inputs`` and ``targets``
            contain different numbers of batches.
    """
    num_batches = len(inputs)
    # Fail fast instead of a ZeroDivisionError after a pass, or an IndexError
    # part-way through training.
    if num_batches == 0:
        raise ValueError("train_model needs at least one input batch")
    if len(targets) != num_batches:
        raise ValueError(
            f"inputs has {num_batches} batches but targets has {len(targets)}"
        )

    key = random.PRNGKey(seed)
    _, init_key = random.split(key)
    params = init_fn(init_key, config)

    optimizer = optax.adamw(train_config.learning_rate)
    opt_state = optimizer.init(params)

    def loss_fn(p, x_batch, y_batch, config):
        predictions = forward(p, x_batch, config)
        return jnp.mean(jnp.abs(predictions - y_batch))

    @partial(jax.jit, static_argnames=['config'])
    def update_step(params, opt_state, x_batch, y_batch, config):
        # Only `params` (argnum 0) is differentiated; config passes through.
        loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, config)
        # adamw's decoupled weight decay needs the current params.
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss_val

    loss_history = {f"epoch_{i}": [] for i in range(train_config.epochs)}
    graphs = {f"epoch_{i}": [] for i in range(train_config.epochs)}

    for epoch in range(train_config.epochs):
        running_loss = 0.0
        for x, y in zip(inputs, targets):
            params, opt_state, loss_val = update_step(params, opt_state, x, y, config)
            running_loss += loss_val

        epoch_loss = running_loss / num_batches
        loss_history[f"epoch_{epoch}"].append(epoch_loss)

        if train_config.verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss:.6f}")

        if train_config.log and epoch % train_config.log_epoch_interval == 0:
            # Copy to host (NumPy) so the logs do not pin one device buffer
            # per logged epoch.
            attn = jax.device_get(get_attention_fn(params, config))
            graphs[f"epoch_{epoch}"].append(attn)

    all_logs = {
        "loss_history": loss_history,
        "graphs": graphs,
        "true_graph": true_graph,
    }

    return params, all_logs