Added Consensus sim
This commit is contained in:
		
							
								
								
									
										
											BIN
										
									
								
								__pycache__/consensus.cpython-312.pyc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								__pycache__/consensus.cpython-312.pyc
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										114
									
								
								consensus.py
									
									
									
									
									
								
							
							
						
						
									
										114
									
								
								consensus.py
									
									
									
									
									
								
							@@ -1,15 +1,22 @@
 | 
				
			|||||||
import jax
 | 
					import jax
 | 
				
			||||||
import jax.numpy as jnp
 | 
					import jax.numpy as jnp
 | 
				
			||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ConsensusConfig:
 | 
					class ConsensusConfig:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Config class for Consensus dynamics sims
 | 
					    Config class for Consensus dynamics sims
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					    num_sims: int = 500 # Number of consensus sims
 | 
				
			||||||
 | 
					    num_agents: int = 5# Number of agents in the consensus simulation
 | 
				
			||||||
 | 
					    max_range: float = 1 # Max range of values each agent can take
 | 
				
			||||||
 | 
					    step_size: float = 0.1 # Target range for length of simulation
 | 
				
			||||||
 | 
					    directed: bool = False # Consensus graph directed?
 | 
				
			||||||
 | 
					    weighted: bool = False
 | 
				
			||||||
 | 
					    num_time_steps: int = 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    num_agents: int # Number of agents in the consensus simulation
 | 
					 | 
				
			||||||
    max_range: float = 100 # Max range of values each agent can take
 | 
					 | 
				
			||||||
    step_size: float = 1 # Target range for length of simulation
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig):
 | 
					def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@@ -35,12 +42,105 @@ def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: Conse
 | 
				
			|||||||
        A jax array containing the updated agent state
 | 
					        A jax array containing the updated agent state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    L = jnp.diag(adj_matrix.sum(axis=0)) - adj_matrix
 | 
					    L = adj_matrix
 | 
				
			||||||
    return agent_states + config.step_size * L * agent_states
 | 
					    norms = jnp.sum(jnp.abs(L), axis=1, keepdims=True)
 | 
				
			||||||
 | 
					    L = L / (norms)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return (agent_states + config.step_size * jnp.matmul(L , agent_states))/(1 + config.step_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
 | 
					def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Generates a random adjacency matrix when given 
 | 
					    Generates a random adjacency matrix in accordance with the config.
 | 
				
			||||||
 | 
					    The diagonal of the matrix is ensured to be all ones.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    --------------------
 | 
				
			||||||
 | 
					        key: jax.Array
 | 
				
			||||||
 | 
					            A key for jax.random operations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        config: ConsensusConfig
 | 
				
			||||||
 | 
					            Config for Consensus dyanmics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    """    
 | 
					    """    
 | 
				
			||||||
    
 | 
					    rand_matrix =  jax.random.uniform(key, shape=(config.num_agents, config.num_agents))
 | 
				
			||||||
 | 
					    # idxs = jnp.arange(config.num_agents)
 | 
				
			||||||
 | 
					    # rand_matrix = rand_matrix.at[:, idxs, idxs].set(1)
 | 
				
			||||||
 | 
					    rand_matrix = jnp.fill_diagonal(rand_matrix, 1, inplace=False) # Fill diagonal with ones
 | 
				
			||||||
 | 
					    if not config.weighted:
 | 
				
			||||||
 | 
					        rand_matrix = jnp.where(rand_matrix > 0.5, 1, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if config.directed:
 | 
				
			||||||
 | 
					        return rand_matrix
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rand_matrix = jnp.tril(rand_matrix) + jnp.triu(rand_matrix.T, 1)
 | 
				
			||||||
 | 
					    return rand_matrix
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def generate_random_agent_states(key: jax.Array, config: ConsensusConfig):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Generate a random initial state for the agents in accordance with the config.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ---------------------
 | 
				
			||||||
 | 
					        key: jax.Arrray
 | 
				
			||||||
 | 
					            A key for jax.random operations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        config: ConsensusConfig
 | 
				
			||||||
 | 
					            Config for Consensus dynamics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns
 | 
				
			||||||
 | 
					    ---------------------
 | 
				
			||||||
 | 
					        rand_states: jax.Array (num_agents, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    rand_states = jax.random.uniform(key, shape=(config.num_sims, config.num_agents), minval=-config.max_range, maxval=config.max_range)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return rand_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@partial(jax.jit, static_argnames=["config"])
 | 
				
			||||||
 | 
					def run_consensus_sim(adj_mat: jax.Array, initial_agent_state: jax.Array, config: ConsensusConfig):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Runs the consensus sim and returns a history of agent states.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    -------------------
 | 
				
			||||||
 | 
					        adj_mat: jax.Array (num_agents, num_agents)
 | 
				
			||||||
 | 
					            A jax array containing the adjacency matrix for the consensus step.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        initial_agent_state: jax.Array (num_agents)
 | 
				
			||||||
 | 
					            A jax array containing the initial agent state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        config: ConsensusConfig
 | 
				
			||||||
 | 
					            Config for Consensus dynamics
 | 
				
			||||||
 | 
					    """ 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    batched_consensus_step = jax.vmap(consensus_step, in_axes=(None, 0, None), out_axes=0)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def step(x_prev, _):
 | 
				
			||||||
 | 
					        x_next = batched_consensus_step(adj_mat, x_prev, config)
 | 
				
			||||||
 | 
					        return x_next, x_next
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _, all_states = jax.lax.scan(step, initial_agent_state, None, config.num_time_steps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return all_states.transpose(1, 0, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def plot_consensus(trajectory, config):
 | 
				
			||||||
 | 
					    import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					    import seaborn as sns
 | 
				
			||||||
 | 
					    sns.set_theme()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    states = np.array(trajectory)
 | 
				
			||||||
 | 
					    timesteps, num_agents = states.shape
 | 
				
			||||||
 | 
					    time = np.arange(timesteps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    plt.figure()
 | 
				
			||||||
 | 
					    for agent_idx in range(num_agents):
 | 
				
			||||||
 | 
					        plt.plot(time ,states[:, agent_idx], label=f"Agent {agent_idx}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    plt.xlabel("Timestep")
 | 
				
			||||||
 | 
					    plt.ylabel("Agent statue")
 | 
				
			||||||
 | 
					    plt.ylim(-config.max_range, config.max_range)
 | 
				
			||||||
 | 
					    plt.title("Consensus simulation trajectories")
 | 
				
			||||||
 | 
					    plt.legend()
 | 
				
			||||||
 | 
					    plt.show()
 | 
				
			||||||
							
								
								
									
										344
									
								
								test.ipynb
									
									
									
									
									
								
							
							
						
						
									
										344
									
								
								test.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Reference in New Issue
	
	Block a user