2025-07-25 00:52:58 -04:00
|
|
|
import jax
|
|
|
|
import jax.numpy as jnp
|
2025-07-25 21:50:20 -04:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from functools import partial
|
|
|
|
import numpy as np
|
2025-07-25 00:52:58 -04:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class ConsensusConfig:
    """
    Config class for Consensus dynamics sims.

    Declared as a frozen dataclass so instances can be built with keyword
    overrides (e.g. ``ConsensusConfig(num_agents=10)``) and are hashable and
    comparable by value. Instances are passed as a static argument to
    ``jax.jit`` in ``run_consensus_sim``. JAX requires static arguments to be
    hashable and treats them as immutable, so freezing prevents mutating a
    config after it has been traced.
    """

    num_sims: int = 500  # Number of independent consensus sims (batch size of initial states)
    num_agents: int = 5  # Number of agents in the consensus simulation
    max_range: float = 1.0  # Initial agent states are drawn from [-max_range, max_range)
    step_size: float = 0.1  # Step size h in x_next = (x + h * W @ x) / (1 + h)
    directed: bool = False  # If False, the random adjacency matrix is made symmetric
    weighted: bool = False  # If False, the random adjacency matrix is thresholded to 0/1
    num_time_steps: int = 100  # Number of consensus steps taken per simulation
|
2025-07-25 00:52:58 -04:00
|
|
|
|
|
|
|
|
|
|
|
def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig):
    """
    Take one consensus-dynamics step given the adjacency matrix and the current agent state.

    Each row of the adjacency matrix is divided by its L1 norm to give a weight
    matrix W. The state is then updated as

        x_next = (x + step_size * W @ x) / (1 + step_size)

    For non-negative weights, W is row-stochastic. Each agent therefore moves to
    a convex combination of its own state and the weighted average of its
    neighbours' states.

    An agent whose adjacency row is entirely zero has no neighbours. It keeps
    its current state, instead of the 0/0 row normalization producing NaN.

    Parameters
    -----------------------------
    adj_matrix : jax.Array (num_agents, num_agents)
        A jax array containing the adjacency matrix for the consensus step.

    agent_states: jax.Array (num_agents) or (num_agents, ...)
        A jax array containing the current agent state. Any trailing
        dimensions are carried through the matrix product unchanged.

    config: ConsensusConfig
        Config class for Consensus Dynamics (only ``step_size`` is read).

    Returns
    ------------------------------
    updated_agent_state: jax.Array (same shape as agent_states)
        A jax array containing the updated agent state
    """
    row_norms = jnp.sum(jnp.abs(adj_matrix), axis=1)
    has_neighbours = row_norms > 0

    # Divide by 1 for empty rows, so those rows stay all-zero rather than NaN.
    safe_norms = jnp.where(has_neighbours, row_norms, 1)
    weights = adj_matrix / safe_norms[:, None]

    # Broadcast the per-agent mask over any trailing state dimensions.
    mask = has_neighbours.reshape(has_neighbours.shape + (1,) * (agent_states.ndim - 1))
    # Isolated agents "average" only with themselves, which leaves their state unchanged.
    neighbour_avg = jnp.where(mask, jnp.matmul(weights, agent_states), agent_states)

    return (agent_states + config.step_size * neighbour_avg) / (1 + config.step_size)
|
2025-07-25 00:52:58 -04:00
|
|
|
|
|
|
|
def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
    """
    Generate a random adjacency matrix in accordance with the config.

    Entries are drawn uniformly from [0, 1). The diagonal is then set to one, so
    every agent is connected to itself and no row is all zeros.

    - If ``config.weighted`` is False, every entry is thresholded at 0.5 into 0/1.
    - If ``config.directed`` is False, the lower triangle (including the
      diagonal) is kept and mirrored into the upper triangle, so the matrix
      is symmetric.

    Parameters
    --------------------
    key: jax.Array
        A key for jax.random operations

    config: ConsensusConfig
        Config for Consensus dynamics (reads ``num_agents``, ``weighted``, ``directed``)

    Returns
    ---------------------
    adj_matrix: jax.Array (num_agents, num_agents)
        Random adjacency matrix. Entries are integer 0/1 when unweighted. When
        weighted, entries are floats, with off-diagonal values in [0, 1).
    """
    n = config.num_agents
    adj_matrix = jax.random.uniform(key, shape=(n, n))
    adj_matrix = jnp.fill_diagonal(adj_matrix, 1, inplace=False)  # self-loops on the diagonal

    if not config.weighted:
        # Each off-diagonal edge is present with probability 0.5; the diagonal (1 > 0.5) survives.
        adj_matrix = jnp.where(adj_matrix > 0.5, 1, 0)

    if config.directed:
        return adj_matrix

    # Symmetrize: tril keeps the diagonal, triu(., 1) adds the strictly-upper mirror of the lower part.
    return jnp.tril(adj_matrix) + jnp.triu(adj_matrix.T, 1)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_random_agent_states(key: jax.Array, config: ConsensusConfig):
    """
    Draw random initial agent states for every simulation in the batch.

    Each state is sampled uniformly from [-config.max_range, config.max_range).

    Parameters
    ---------------------
    key: jax.Array
        A key for jax.random operations

    config: ConsensusConfig
        Config for Consensus dynamics (reads ``num_sims``, ``num_agents``, ``max_range``)

    Returns
    ---------------------
    rand_states: jax.Array (num_sims, num_agents)
        One row of initial agent states per simulation
    """
    bound = config.max_range
    batch_shape = (config.num_sims, config.num_agents)
    return jax.random.uniform(key, shape=batch_shape, minval=-bound, maxval=bound)
|
|
|
|
|
|
|
|
@partial(jax.jit, static_argnames=["config"])
def run_consensus_sim(adj_mat: jax.Array, initial_agent_state: jax.Array, config: ConsensusConfig):
    """
    Run the consensus sim and return a history of agent states.

    All simulations in a batch share the same adjacency matrix. The step
    function is vmapped over the leading (simulation) axis of the states and
    iterated ``config.num_time_steps`` times with ``jax.lax.scan``.

    Parameters
    -------------------
    adj_mat: jax.Array (num_agents, num_agents)
        A jax array containing the adjacency matrix for the consensus step.

    initial_agent_state: jax.Array (num_sims, num_agents) or (num_agents)
        Initial agent states. A 2-D array is a batch of independent sims; a
        1-D array is treated as a single sim.

    config: ConsensusConfig
        Config for Consensus dynamics. It is a static jit argument, so it must
        be hashable.

    Returns
    -------------------
    history: jax.Array (num_sims, num_time_steps, num_agents)
        States after each of the ``num_time_steps`` steps; the initial state is
        not included. For 1-D input the shape is (num_time_steps, num_agents).
    """
    # ndim is static under jit, so this is an ordinary Python branch at trace time.
    single_sim = initial_agent_state.ndim == 1
    batch_init = initial_agent_state[None, :] if single_sim else initial_agent_state

    # batched consensus step (meant for many initial states)
    batched_consensus_step = jax.vmap(consensus_step, in_axes=(None, 0, None), out_axes=0)

    def step(x_prev, _):
        x_next = batched_consensus_step(adj_mat, x_prev, config)
        return x_next, x_next

    _, all_states = jax.lax.scan(step, batch_init, None, config.num_time_steps)

    # scan stacks outputs on a new leading time axis: (T, sims, agents) -> (sims, T, agents)
    history = all_states.transpose(1, 0, 2)
    return history[0] if single_sim else history
|
|
|
|
|
|
|
|
|
|
|
|
def plot_consensus(trajectory, config):
    """
    Plot each agent's state over time for a single consensus simulation.

    Parameters
    ---------------------
    trajectory: array-like (timesteps, num_agents)
        State history of one simulation, e.g. one entry along the first axis
        of the output of ``run_consensus_sim``. Must be 2-D.

    config: ConsensusConfig
        Config for Consensus dynamics; ``config.max_range`` sets the y-axis limits.
    """
    # Imported inside the function, so importing this module does not require matplotlib/seaborn.
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_theme()

    states = np.asarray(trajectory)
    timesteps, _ = states.shape  # raises ValueError if the trajectory is not 2-D
    time = np.arange(timesteps)

    plt.figure()
    for agent_idx, agent_trace in enumerate(states.T):
        plt.plot(time, agent_trace, label=f"Agent {agent_idx}")

    plt.xlabel("Timestep")
    plt.ylabel("Agent state")
    plt.ylim(-config.max_range, config.max_range)
    plt.title("Consensus simulation trajectories")
    plt.legend()
    plt.show()
|