Initial commit
This commit is contained in:
46
consensus.py
Normal file
46
consensus.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
|
class ConsensusConfig:
|
||||||
|
"""
|
||||||
|
Config class for Consensus dynamics sims
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_agents: int # Number of agents in the consensus simulation
|
||||||
|
max_range: float = 100 # Max range of values each agent can take
|
||||||
|
step_size: float = 1 # Target range for length of simulation
|
||||||
|
|
||||||
|
def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig):
|
||||||
|
"""
|
||||||
|
Takes a step given the adjacency matrix and the current agent state using consensus dynamics.
|
||||||
|
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------------------------
|
||||||
|
adj_matrix : jax.Array (num_agents, num_agents)
|
||||||
|
A jax array containing the adjacency matrix for the consensus step.
|
||||||
|
|
||||||
|
|
||||||
|
agent_states: jax.Array (num_agents)
|
||||||
|
A jax array containing the current agent state
|
||||||
|
|
||||||
|
config: ConsensusConfig
|
||||||
|
Config class for Consensus Dynamics
|
||||||
|
|
||||||
|
|
||||||
|
Returns
|
||||||
|
------------------------------
|
||||||
|
updated_agent_state: jax.Array (num_agents)
|
||||||
|
A jax array containing the updated agent state
|
||||||
|
|
||||||
|
"""
|
||||||
|
L = jnp.diag(adj_matrix.sum(axis=0)) - adj_matrix
|
||||||
|
return agent_states + config.step_size * L * agent_states
|
||||||
|
|
||||||
|
def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
|
||||||
|
"""
|
||||||
|
Generates a random adjacency matrix when given
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
@@ -4,4 +4,12 @@ version = "0.1.0"
|
|||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = []
|
dependencies = [
|
||||||
|
"ipykernel>=6.30.0",
|
||||||
|
"ipython>=9.4.0",
|
||||||
|
"jax[cuda12]>=0.7.0",
|
||||||
|
"jupyter>=1.1.1",
|
||||||
|
"matplotlib>=3.10.3",
|
||||||
|
"seaborn>=0.13.2",
|
||||||
|
"tqdm>=4.67.1",
|
||||||
|
]
|
45
test.ipynb
Normal file
45
test.ipynb
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "c883f5e2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import jax"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ea833558",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"key = jax.random.PRNGKey(0)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "graph-recognition-w-attn",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.12.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Reference in New Issue
Block a user