replicated mecc
This commit is contained in:
136
model.py
Normal file
136
model.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import jax
|
||||
from jax import random
|
||||
import jax.numpy as jnp
|
||||
from train import ModelConfig, TrainConfig
|
||||
import optax
|
||||
from functools import partial
|
||||
|
||||
def init_linear_layer(
    key: jax.Array,
    in_features: int,
    out_features: int,
    use_bias: bool = True
):
    """Create the parameter dict for a single linear layer.

    The weight matrix has shape ``(in_features, out_features)`` and is
    sampled uniformly between ``-limit`` and ``limit`` where
    ``limit = sqrt(6 / in_features)``. When ``use_bias`` is true, a bias
    vector of shape ``(out_features,)`` is sampled from the same range using
    a separate subkey.

    Args:
        key: PRNG key; split internally into a weight key and a bias key.
        in_features: Number of rows of ``W`` (layer input size).
        out_features: Number of columns of ``W`` (layer output size).
        use_bias: Whether the result contains a ``'b'`` entry.

    Returns:
        A dict holding ``'W'`` and, if ``use_bias`` is true, ``'b'``.
    """
    # The key is always split in two, even when no bias is requested, so the
    # weights drawn for a given key do not depend on ``use_bias``.
    weight_key, bias_key = random.split(key)
    bound = jnp.sqrt(6 / in_features)

    layer = {
        'W': random.uniform(
            weight_key, (in_features, out_features), minval=-bound, maxval=bound
        )
    }
    if not use_bias:
        return layer

    layer['b'] = random.uniform(
        bias_key, (out_features,), minval=-bound, maxval=bound
    )
    return layer
|
||||
|
||||
def init_fn(key: jax.Array, config: ModelConfig):
    """Build the full parameter pytree for the model.

    The incoming key is split into four subkeys, one per parameter group.

    Args:
        key: PRNG key used to derive all parameter subkeys.
        config: Model configuration; reads ``num_agents``, ``embedding_dim``,
            ``input_dim`` and ``output_dim``.

    Returns:
        A nested dict with the entries:

        * ``"agent_embeddings"``: ``{"weight": ...}`` of shape
          ``(num_agents, embedding_dim)``, drawn from a standard normal.
        * ``"translate"``: linear layer ``input_dim -> embedding_dim``.
        * ``"attn_proj"``: bias-free linear layer
          ``embedding_dim -> 2 * embedding_dim``.
        * ``"head"``: linear layer ``embedding_dim -> output_dim``.
    """
    embed_key, translate_key, attn_key, head_key = random.split(key, 4)
    dim = config.embedding_dim

    embedding_table = random.normal(embed_key, shape=(config.num_agents, dim))
    translate = init_linear_layer(translate_key, config.input_dim, dim)
    # Twice the embedding width: the output is later split into two halves.
    attn_proj = init_linear_layer(attn_key, dim, 2 * dim, use_bias=False)
    head = init_linear_layer(head_key, dim, config.output_dim)

    return {
        "agent_embeddings": {"weight": embedding_table},
        "translate": translate,
        "attn_proj": attn_proj,
        "head": head,
    }
|
||||
|
||||
def forward(params: dict, input_timesteps: jax.Array, config: ModelConfig):
    """Run the model on a batch of input timesteps and return predictions.

    Keys and queries are computed from the learned agent embeddings only, so
    the attention weights do not depend on the input values. The values are a
    linear translation of the input, mixed across agents by the attention
    weights and passed through the output head.

    Args:
        params: Parameter dict with ``"agent_embeddings"``, ``"attn_proj"``,
            ``"translate"`` and ``"head"`` entries.
        input_timesteps: Rank-3 array unpacked as
            ``(batch_size, num_agents, features)``.
        config: Model configuration; only ``embedding_dim`` is read here.

    Returns:
        The head output, one prediction row per agent per batch element.
    """
    batch_size, num_agents, _ = input_timesteps.shape

    # Replicate the embedding table across the batch dimension.
    embed_shape = (batch_size, num_agents, config.embedding_dim)
    embeddings = jnp.broadcast_to(params["agent_embeddings"]["weight"], embed_shape)

    # One projection yields both halves: keys first, queries second.
    keys, queries = jnp.split(embeddings @ params["attn_proj"]['W'], 2, axis=-1)

    translate = params["translate"]
    head = params["head"]
    values = input_timesteps @ translate['W'] + translate['b']

    # NOTE(review): scores are scaled by sqrt(num_agents) here, whereas
    # get_attention_fn scales by sqrt(embedding_dim) -- confirm this
    # difference is intentional.
    scores = (queries @ keys.transpose(0, 2, 1)) / jnp.sqrt(num_agents)
    weights = jax.nn.softmax(scores, axis=-1)
    mixed = weights @ values

    return mixed @ head['W'] + head['b']
|
||||
|
||||
def get_attention_fn(params: dict, config: ModelConfig):
    """Compute the scaled query-key score matrix between agents.

    The scores come from the learned agent embeddings alone. They are scaled
    by the square root of the query width and returned without a softmax,
    i.e. as unnormalised attention scores.

    Args:
        params: Parameter dict; reads ``params['agent_embeddings']['weight']``
            and ``params['attn_proj']['W']``.
        config: Accepted for interface symmetry; not read by this function.

    Returns:
        A JAX array of shape ``(num_agents, num_agents)`` where entry
        ``[i, j]`` is the scaled dot product of agent ``i``'s query with
        agent ``j``'s key.
    """
    # One projection yields both halves: keys first, queries second.
    projected = params['agent_embeddings']['weight'] @ params['attn_proj']['W']
    keys, queries = jnp.split(projected, 2, axis=-1)

    # Scale by sqrt of the per-half width (the embedding dimension).
    scale = jnp.sqrt(queries.shape[-1])
    raw_scores = queries @ keys.T

    # jnp.asarray keeps the result a JAX array (not a NumPy array).
    return jnp.asarray(raw_scores / scale)
|
||||
|
||||
def train_model(config: ModelConfig, inputs: jax.Array, targets: jax.Array,
                true_graph: jax.Array,
                train_config: TrainConfig,
                seed: int = 0,
                ):
    """Train the model with AdamW on a mean-absolute-error loss.

    Args:
        config: Model configuration. It is passed to the jitted update step
            as a static argument, so it must be hashable.
        inputs: Batched inputs, indexed along the first axis; ``inputs[i]``
            is the i-th batch and ``len(inputs)`` is the number of batches.
        targets: Batched targets, indexed the same way as ``inputs``.
        true_graph: Passed through unchanged into the returned logs.
        train_config: Training configuration; reads ``learning_rate``,
            ``epochs``, ``verbose``, ``log`` and ``log_epoch_interval``.
        seed: Seed for the parameter-initialisation PRNG key. The default of
            ``0`` reproduces the previously hard-coded behaviour.

    Returns:
        A tuple ``(params, all_logs)``. ``all_logs`` is a dict with:

        * ``"loss_history"``: maps ``"epoch_{i}"`` to a one-element list
          holding that epoch's mean batch loss.
        * ``"graphs"``: maps ``"epoch_{i}"`` to a list holding the attention
          scores captured at that epoch (empty for epochs not logged).
        * ``"true_graph"``: the ``true_graph`` argument.
    """
    # Same derivation as before: the second half of the split seeds init.
    _, init_key = random.split(random.PRNGKey(seed))
    params = init_fn(init_key, config)

    optimizer = optax.adamw(train_config.learning_rate)
    opt_state = optimizer.init(params)

    def loss_fn(p, x_batch, y_batch, config):
        # Mean absolute error between predictions and targets.
        predictions = forward(p, x_batch, config)
        return jnp.mean(jnp.abs(predictions - y_batch))

    @partial(jax.jit, static_argnames=['config'])
    def update_step(params, opt_state, x_batch, y_batch, config):
        # Gradients are taken with respect to the first argument (params).
        loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, config)

        # AdamW needs the current params to apply weight decay.
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)

        return new_params, new_opt_state, loss_val

    # Every epoch gets an entry up front; graph lists stay empty for epochs
    # that are not logged.
    loss_history = {f"epoch_{i}": [] for i in range(train_config.epochs)}
    graphs = {f"epoch_{i}": [] for i in range(train_config.epochs)}
    num_batches = len(inputs)

    for epoch in range(train_config.epochs):
        running_loss = 0.0
        for batch_num in range(num_batches):
            x, y = inputs[batch_num], targets[batch_num]
            params, opt_state, loss_val = update_step(params, opt_state, x, y, config)
            running_loss += loss_val

        epoch_loss = running_loss / num_batches
        loss_history[f"epoch_{epoch}"].append(epoch_loss)

        if train_config.verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss:.6f}")

        if train_config.log and epoch % train_config.log_epoch_interval == 0:
            attn = get_attention_fn(params, config)
            graphs[f"epoch_{epoch}"].append(attn)

    all_logs = {
        "loss_history": loss_history,
        "graphs": graphs,
        "true_graph": true_graph,
    }

    return params, all_logs
|
Reference in New Issue
Block a user