Files
graph_recognition_w_attn/consensus.py

46 lines
1.3 KiB
Python
Raw Normal View History

2025-07-25 00:52:58 -04:00
from dataclasses import dataclass

import jax
import jax.numpy as jnp
@dataclass(eq=False)
class ConsensusConfig:
    """
    Config class for consensus dynamics simulations.

    Declared as a dataclass so the fields can be passed to the constructor,
    e.g. ``ConsensusConfig(num_agents=10, step_size=0.1)``; without the
    decorator the class has no ``__init__`` accepting arguments and
    ``num_agents`` is never set. ``eq=False`` keeps the identity-based
    ``__eq__``/``__hash__`` a plain class has, so instances stay hashable as
    before (JAX requires hashability for static ``jax.jit`` arguments).
    """

    num_agents: int  # Number of agents in the consensus simulation
    max_range: float = 100  # Max range of values each agent can take
    step_size: float = 1  # Step size multiplying the Laplacian term in consensus_step
def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig):
    """
    Takes one discrete-time consensus step given the adjacency matrix and the current agent states.

    Implements ``x_{k+1} = x_k - step_size * L @ x_k`` with the graph Laplacian
    ``L = D - A``, where ``D = diag(row sums of A)``. Per agent this is
    ``x_i <- x_i + step_size * sum_j A[i, j] * (x_j - x_i)``, i.e. every agent
    moves toward the states of the agents it is connected to.

    For nonnegative weights, ``0 < step_size <= 1 / max_i sum_j A[i, j]`` makes
    ``I - step_size * L`` nonnegative and row-stochastic, so each update is a
    convex combination of the current states; larger steps can oscillate or
    diverge. For a symmetric ``adj_matrix`` the mean state is also preserved.

    Parameters
    -----------------------------
    adj_matrix : jax.Array (num_agents, num_agents)
        A jax array containing the adjacency matrix for the consensus step.
        ``adj_matrix[i, j]`` is the weight with which agent ``i`` is pulled
        toward agent ``j``. For a symmetric (undirected) matrix the
        orientation does not matter.
    agent_states: jax.Array (num_agents,) or (num_agents, state_dim)
        A jax array containing the current agent states (one row per agent).
    config: ConsensusConfig
        Config class for Consensus Dynamics; only ``config.step_size`` is read.

    Returns
    ------------------------------
    updated_agent_state: jax.Array, same shape as ``agent_states``
        A jax array containing the updated agent states.
    """
    # Row sums as degrees guarantee L @ ones == 0, so an agent whose
    # neighbours already share its value does not move (agreement is a fixed
    # point) even for directed graphs.
    laplacian = jnp.diag(adj_matrix.sum(axis=1)) - adj_matrix
    # Matrix product, not elementwise `*`: the elementwise form broadcasts the
    # (num_agents,) state against the (num_agents, num_agents) Laplacian and
    # returns a square matrix. The minus sign drives agents toward agreement.
    return agent_states - config.step_size * (laplacian @ agent_states)
def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
"""
Generates a random adjacency matrix when given
"""