import jax
import jax.numpy as jnp


class ConsensusConfig:
    """
    Config class for Consensus dynamics sims

    NOTE(review): this is a plain class with annotated class attributes (no
    ``@dataclass`` and no ``__init__``), so the fields below cannot be passed
    to the constructor; ``num_agents`` has no default and must be assigned on
    the instance (or a subclass) before it is read.
    """
    num_agents: int  # Number of agents in the consensus simulation
    max_range: float = 100  # Max range of values each agent can take
    step_size: float = 1  # Step size scaling the Laplacian term in consensus_step


def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig) -> jax.Array:
    """
    Takes a step given the adjacency matrix and the current agent state using consensus dynamics.

    The update is ``x_next = x - step_size * (L @ x)`` where ``L = D - A`` is the
    graph Laplacian built from the adjacency matrix ``A`` and its degree matrix ``D``.

    Parameters
    -----------------------------
    adj_matrix : jax.Array (num_agents, num_agents)
        A jax array containing the adjacency matrix for the consensus step.
    agent_states: jax.Array (num_agents)
        A jax array containing the current agent state
    config: ConsensusConfig
        Config class for Consensus Dynamics; only ``step_size`` is read here.

    Returns
    ------------------------------
    updated_agent_state: jax.Array (num_agents)
        A jax array containing the updated agent state
    """
    # Graph Laplacian L = D - A.
    # NOTE(review): degrees are taken as column sums (axis=0). This equals the
    # row sums only for a symmetric (undirected) adjacency matrix; for directed
    # graphs confirm which degree convention is intended.
    degree_matrix = jnp.diag(adj_matrix.sum(axis=0))
    laplacian = degree_matrix - adj_matrix

    # Apply the Laplacian as a matrix-vector product (not an elementwise
    # broadcast, which would return a (num_agents, num_agents) array) and
    # subtract it so that each agent moves toward its neighbours' states.
    return agent_states - config.step_size * (laplacian @ agent_states)


def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
    """
    Generates a random adjacency matrix when given
    """