Initial commit
This commit is contained in:
		
							
								
								
									
										46
									
								
								consensus.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								consensus.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
				
			|||||||
 | 
					import jax
 | 
				
			||||||
 | 
					import jax.numpy as jnp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ConsensusConfig:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Config class for Consensus dynamics sims
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    num_agents: int # Number of agents in the consensus simulation
 | 
				
			||||||
 | 
					    max_range: float = 100 # Max range of values each agent can take
 | 
				
			||||||
 | 
					    step_size: float = 1 # Target range for length of simulation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Takes a step given the adjacency matrix and the current agent state using consensus dynamics.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    -----------------------------
 | 
				
			||||||
 | 
					    adj_matrix : jax.Array (num_agents, num_agents)
 | 
				
			||||||
 | 
					        A jax array containing the adjacency matrix for the consensus step.
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    agent_states: jax.Array (num_agents)
 | 
				
			||||||
 | 
					        A jax array containing the current agent state
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    config: ConsensusConfig
 | 
				
			||||||
 | 
					        Config class for Consensus Dynamics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns
 | 
				
			||||||
 | 
					    ------------------------------
 | 
				
			||||||
 | 
					    updated_agent_state: jax.Array (num_agents)
 | 
				
			||||||
 | 
					        A jax array containing the updated agent state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    L = jnp.diag(adj_matrix.sum(axis=0)) - adj_matrix
 | 
				
			||||||
 | 
					    return agent_states + config.step_size * L * agent_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Generates a random adjacency matrix when given 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """    
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
@@ -4,4 +4,12 @@ version = "0.1.0"
 | 
				
			|||||||
description = "Add your description here"
 | 
					description = "Add your description here"
 | 
				
			||||||
readme = "README.md"
 | 
					readme = "README.md"
 | 
				
			||||||
requires-python = ">=3.12"
 | 
					requires-python = ">=3.12"
 | 
				
			||||||
dependencies = []
 | 
					dependencies = [
 | 
				
			||||||
 | 
					    "ipykernel>=6.30.0",
 | 
				
			||||||
 | 
					    "ipython>=9.4.0",
 | 
				
			||||||
 | 
					    "jax[cuda12]>=0.7.0",
 | 
				
			||||||
 | 
					    "jupyter>=1.1.1",
 | 
				
			||||||
 | 
					    "matplotlib>=3.10.3",
 | 
				
			||||||
 | 
					    "seaborn>=0.13.2",
 | 
				
			||||||
 | 
					    "tqdm>=4.67.1",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
							
								
								
									
										45
									
								
								test.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								test.ipynb
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,45 @@
 | 
				
			|||||||
 | 
					{
 | 
				
			||||||
 | 
					 "cells": [
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": 4,
 | 
				
			||||||
 | 
					   "id": "c883f5e2",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "import jax"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "id": "ea833558",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "key = jax.random.PRNGKey(0)"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					 ],
 | 
				
			||||||
 | 
					 "metadata": {
 | 
				
			||||||
 | 
					  "kernelspec": {
 | 
				
			||||||
 | 
					   "display_name": "graph-recognition-w-attn",
 | 
				
			||||||
 | 
					   "language": "python",
 | 
				
			||||||
 | 
					   "name": "python3"
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  "language_info": {
 | 
				
			||||||
 | 
					   "codemirror_mode": {
 | 
				
			||||||
 | 
					    "name": "ipython",
 | 
				
			||||||
 | 
					    "version": 3
 | 
				
			||||||
 | 
					   },
 | 
				
			||||||
 | 
					   "file_extension": ".py",
 | 
				
			||||||
 | 
					   "mimetype": "text/x-python",
 | 
				
			||||||
 | 
					   "name": "python",
 | 
				
			||||||
 | 
					   "nbconvert_exporter": "python",
 | 
				
			||||||
 | 
					   "pygments_lexer": "ipython3",
 | 
				
			||||||
 | 
					   "version": "3.12.3"
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					 },
 | 
				
			||||||
 | 
					 "nbformat": 4,
 | 
				
			||||||
 | 
					 "nbformat_minor": 5
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user