"""Consensus-dynamics simulations in JAX.

Utilities to sample random graphs and initial agent states, roll out a
consensus update rule with ``jax.lax.scan``, and plot the trajectories.
"""
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@dataclass(frozen=True)
class ConsensusConfig:
    """
    Config class for Consensus dynamics sims.

    Frozen (immutable, hashed by value) because instances are passed to
    ``jax.jit`` as a static argument. JAX requires static arguments to be
    hashable and immutable. Value hashing means equal configs share one
    compiled function, and a different config triggers recompilation.
    Use ``dataclasses.replace(config, field=value)`` to derive a modified config.
    """

    num_sims: int = 500  # Number of independent initial states simulated as a batch
    num_agents: int = 5  # Number of agents (graph nodes) in each simulation
    max_range: float = 1  # Initial states are drawn uniformly from [-max_range, max_range)
    step_size: float = 0.1  # Step size h of the consensus update (see consensus_step)
    directed: bool = False  # If False, random adjacency matrices are symmetrized
    weighted: bool = False  # If False, random edges are binary 0/1; else weights in [0, 1)
    num_time_steps: int = 100  # Number of consensus steps per simulation


def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig):
    """
    Takes one step of consensus dynamics for a single simulation.

    Let W be ``adj_matrix`` with each row divided by its L1 norm, and let
    h = ``config.step_size``. The update is::

        x_next = (x + h * W @ x) / (1 + h)

    For a non-negative adjacency matrix every row of W sums to 1. Each new
    state is then a convex combination of the current states, so values never
    leave the range spanned by the current states.

    Parameters
    -----------------------------
    adj_matrix : jax.Array (num_agents, num_agents)
        Adjacency matrix for the consensus step.
    agent_states: jax.Array (num_agents)
        Current agent states.
    config: ConsensusConfig
        Config class for Consensus Dynamics (only ``step_size`` is used).

    Returns
    ------------------------------
    updated_agent_state: jax.Array (num_agents)
        The updated agent states.
    """
    # Row-normalize by the L1 norm. An all-zero row would divide by zero and
    # produce NaN. Matrices from generate_random_adjacency_matrix always have
    # a unit diagonal, so their rows are never all zero.
    row_norms = jnp.sum(jnp.abs(adj_matrix), axis=1, keepdims=True)
    weights = adj_matrix / row_norms
    h = config.step_size
    return (agent_states + h * jnp.matmul(weights, agent_states)) / (1 + h)


def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
    """
    Generates a random adjacency matrix in accordance with the config.

    Entries are first drawn uniformly from [0, 1), and the diagonal is then
    set to 1, so every agent is connected to itself. If ``config.weighted``
    is False, entries are thresholded at 0.5 into binary 0/1 edges; the
    diagonal stays 1. If ``config.directed`` is False, the lower triangle is
    mirrored into the upper one, giving a symmetric matrix.

    Parameters
    --------------------
    key: jax.Array
        A key for jax.random operations
    config: ConsensusConfig
        Config for Consensus dynamics

    Returns
    ---------------------
    adj_matrix: jax.Array (num_agents, num_agents)
        Random floating-point adjacency matrix.
    """
    n = config.num_agents
    rand_matrix = jax.random.uniform(key, shape=(n, n))
    diag = jnp.arange(n)
    rand_matrix = rand_matrix.at[diag, diag].set(1.0)  # self-loops on the diagonal
    if not config.weighted:
        # Float 0/1 (not int) so weighted and unweighted modes return the same dtype.
        rand_matrix = jnp.where(rand_matrix > 0.5, 1.0, 0.0)
    if config.directed:
        return rand_matrix
    # Keep the lower triangle (incl. diagonal) and mirror it into the strict upper triangle.
    return jnp.tril(rand_matrix) + jnp.triu(rand_matrix.T, k=1)


def generate_random_agent_states(key: jax.Array, config: ConsensusConfig):
    """
    Generate random initial states for a batch of simulations.

    Each state is drawn uniformly from [-config.max_range, config.max_range).

    Parameters
    ---------------------
    key: jax.Array
        A key for jax.random operations
    config: ConsensusConfig
        Config for Consensus dynamics

    Returns
    ---------------------
    rand_states: jax.Array (num_sims, num_agents)
        One row of initial agent states per simulation.
    """
    rand_states = jax.random.uniform(
        key,
        shape=(config.num_sims, config.num_agents),
        minval=-config.max_range,
        maxval=config.max_range,
    )
    return rand_states


@partial(jax.jit, static_argnames=["config"])
def run_consensus_sim(adj_mat: jax.Array, initial_agent_state: jax.Array, config: ConsensusConfig):
    """
    Runs the consensus sim and returns a history of agent states.

    Parameters
    -------------------
    adj_mat: jax.Array (num_agents, num_agents)
        Adjacency matrix shared by every consensus step.
    initial_agent_state: jax.Array (num_sims, num_agents) or (num_agents)
        Either a batch of independent initial states (one row per simulation,
        e.g. from generate_random_agent_states) or a single initial state.
    config: ConsensusConfig
        Config for Consensus dynamics

    Returns
    -------------------
    all_states: jax.Array (num_sims, num_time_steps, num_agents) or (num_time_steps, num_agents)
        The state after each of the ``config.num_time_steps`` steps. The
        initial state itself is not included. A batched input gives the first
        shape; a single initial state gives the second.

    Raises
    -------------------
    ValueError
        If ``initial_agent_state`` is neither 1-D nor 2-D.
    """
    ndim = initial_agent_state.ndim  # static under jit, so Python branching is fine
    if ndim not in (1, 2):
        raise ValueError(
            "initial_agent_state must have shape (num_agents,) or "
            f"(num_sims, num_agents), got shape {initial_agent_state.shape}"
        )

    def advance(states):
        return consensus_step(adj_mat, states, config)

    # A 2-D input holds independent simulations along axis 0; map the step over them.
    step_fn = jax.vmap(advance) if ndim == 2 else advance

    def scan_body(x_prev, _):
        x_next = step_fn(x_prev)
        # Carry the new state forward and also record it in the history.
        return x_next, x_next

    _, history = jax.lax.scan(scan_body, initial_agent_state, None, length=config.num_time_steps)
    if ndim == 1:
        return history
    # scan stacks along a new leading time axis:
    # (num_time_steps, num_sims, num_agents) -> (num_sims, num_time_steps, num_agents)
    return history.transpose(1, 0, 2)


def plot_consensus(trajectory, config):
    """
    Plots the trajectory of every agent in a single simulation.

    Parameters
    ---------------------
    trajectory: array-like (num_time_steps, num_agents)
        One simulation's history, e.g. ``run_consensus_sim(...)[i]`` for a
        batched run, or the output of an unbatched run.
    config: ConsensusConfig
        Only ``max_range`` is used, to fix the y-axis limits.

    Raises
    ---------------------
    ValueError
        If ``trajectory`` is not 2-D.
    """
    # Imported lazily so the rest of the module works without matplotlib/seaborn.
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_theme()
    states = np.asarray(trajectory)
    if states.ndim != 2:
        raise ValueError(
            f"trajectory must have shape (num_time_steps, num_agents), got {states.shape}"
        )
    timesteps, _ = states.shape
    time = np.arange(timesteps)
    plt.figure()
    for agent_idx, series in enumerate(states.T):
        plt.plot(time, series, label=f"Agent {agent_idx}")
    plt.xlabel("Timestep")
    plt.ylabel("Agent state")
    # Initial states lie in [-max_range, max_range). For a non-negative adjacency
    # matrix each step is a convex combination, so the trajectory stays in range.
    plt.ylim(-config.max_range, config.max_range)
    plt.title("Consensus simulation trajectories")
    plt.legend()
    plt.show()