import jax
import jax.numpy as jnp
import optax
from functools import partial
from jax import random

from train import ModelConfig, TrainConfig


def init_linear_layer(
    key: jax.Array, in_features: int, out_features: int, use_bias: bool = True
):
    """
    Initializes the weights and biases of a linear layer.
    """
    key_w, key_b = random.split(key)
    # Kaiming/He-style uniform initialization scaled by the fan-in.
    limit = jnp.sqrt(6 / in_features)
    W = random.uniform(key_w, (in_features, out_features), minval=-limit, maxval=limit)
    params = {'W': W}
    if use_bias:
        b = random.uniform(key_b, (out_features,), minval=-limit, maxval=limit)
        params['b'] = b
    return params


def init_fn(key: jax.Array, config: ModelConfig):
    """
    Initializes all model parameters. Returns a pytree.
    """
    key_embed, key_translate, key_attn_proj, key_head = random.split(key, 4)
    params = {
        "agent_embeddings": {
            "weight": random.normal(key_embed, shape=(config.num_agents, config.embedding_dim))
        },
        "translate": init_linear_layer(key_translate, config.input_dim, config.embedding_dim),
        # A single projection whose output is later split into keys and queries.
        "attn_proj": init_linear_layer(
            key_attn_proj, config.embedding_dim, 2 * config.embedding_dim, use_bias=False
        ),
        "head": init_linear_layer(key_head, config.embedding_dim, config.output_dim),
    }
    return params


def forward(params: dict, input_timesteps: jax.Array, config: ModelConfig):
    """
    Model's forward function. Takes in the parameters and input timesteps,
    returns predictions.
    """
    batch_size, num_agents, _ = input_timesteps.shape

    # Broadcast the learned per-agent embeddings across the batch.
    agent_embed = params["agent_embeddings"]["weight"]
    agent_embed = jnp.broadcast_to(agent_embed, (batch_size, num_agents, config.embedding_dim))

    # Keys and queries come from the agent embeddings; values from the inputs.
    attn_proj_out = agent_embed @ params["attn_proj"]['W']
    k, q = jnp.split(attn_proj_out, 2, axis=-1)
    v = input_timesteps @ params["translate"]['W'] + params["translate"]['b']

    # Scaled dot-product attention: divide by sqrt(d_k), consistent with
    # get_attention_fn below.
    attn_scores = (q @ k.transpose(0, 2, 1)) / jnp.sqrt(q.shape[-1])
    attn_weights = jax.nn.softmax(attn_scores, axis=-1)
    weighted_average = attn_weights @ v

    prediction = weighted_average @ params["head"]['W'] + params["head"]['b']
    return prediction


def get_attention_fn(params: dict, config: ModelConfig):
    """
    Calculates and returns the learned attention matrix between agents.
    This is a pure function for analysis.
    """
    embeddings = params['agent_embeddings']['weight']

    # Project embeddings to get keys (k) and queries (q) for the global graph.
    attn_proj_out = embeddings @ params['attn_proj']['W']
    k, q = jnp.split(attn_proj_out, 2, axis=-1)

    # Scale by sqrt(d_k) = sqrt(embedding_dim).
    attn_scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    # Returned as a jax.Array; convert (e.g. with jax.device_get) before
    # logging if a NumPy array is needed.
    return jnp.asarray(attn_scores)


def train_model(
    config: ModelConfig,
    inputs: jax.Array,
    targets: jax.Array,
    true_graph: jax.Array,
    train_config: TrainConfig,
):
    key = random.PRNGKey(0)
    key, init_key = random.split(key)
    params = init_fn(init_key, config)

    optimizer = optax.adamw(train_config.learning_rate)
    opt_state = optimizer.init(params)

    def loss_fn(p, x_batch, y_batch, config):
        predictions = forward(p, x_batch, config)
        # Mean absolute error.
        loss = jnp.mean(jnp.abs(predictions - y_batch))
        return loss

    @partial(jax.jit, static_argnames=['config'])
    def update_step(params, opt_state, x_batch, y_batch, config):
        # Pass all of loss_fn's arguments through value_and_grad.
        loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, config)
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss_val

    # Per-epoch lists so losses and logged graphs can be appended below.
    loss_history = {f"epoch_{i}": [] for i in range(train_config.epochs)}
    graphs = {f"epoch_{i}": [] for i in range(train_config.epochs)}
    num_batches = len(inputs)

    for epoch in range(train_config.epochs):
        running_loss = 0.0
        for batch_num in range(num_batches):
            x, y = inputs[batch_num], targets[batch_num]
            # config is a static argument of the jitted update_step.
            params, opt_state, loss_val = update_step(params, opt_state, x, y, config)
            running_loss += loss_val

        epoch_loss = running_loss / num_batches
        loss_history[f"epoch_{epoch}"].append(epoch_loss)

        if train_config.verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1:3d} | Loss: {epoch_loss:.6f}")

        if train_config.log and epoch % train_config.log_epoch_interval == 0:
            attn = get_attention_fn(params, config)
            graphs[f"epoch_{epoch}"].append(attn)

    all_logs = {
        "loss_history": loss_history,
        "graphs": graphs,
        "true_graph": true_graph,
    }
    return params, all_logs
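

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This assumes ModelConfig and TrainConfig
# are hashable (e.g. frozen) dataclasses that accept the fields used above as
# keyword arguments; all sizes and values below are hypothetical, chosen only
# so the shapes line up with what forward expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ModelConfig(
        num_agents=8, embedding_dim=16, input_dim=4, output_dim=4
    )
    train_config = TrainConfig(
        learning_rate=1e-3,
        epochs=100,
        verbose=True,
        log=True,
        log_epoch_interval=10,
    )

    # Dummy data shaped (num_batches, batch_size, num_agents, feature_dim),
    # so that inputs[batch_num] matches forward's expected input shape.
    key_x, key_y = random.split(random.PRNGKey(42))
    inputs = random.normal(key_x, (10, 32, 8, 4))
    targets = random.normal(key_y, (10, 32, 8, 4))
    true_graph = jnp.eye(8)  # Placeholder ground-truth adjacency matrix.

    params, logs = train_model(config, inputs, targets, true_graph, train_config)
    print("Learned attention graph:\n", get_attention_fn(params, config))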