replicated mecc
This commit is contained in:
2
sims/__init__.py
Normal file
2
sims/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .consensus import *
|
||||
from .kuramoto import *
|
BIN
sims/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
sims/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
sims/__pycache__/consensus.cpython-312.pyc
Normal file
BIN
sims/__pycache__/consensus.cpython-312.pyc
Normal file
Binary file not shown.
BIN
sims/__pycache__/kuramoto.cpython-312.pyc
Normal file
BIN
sims/__pycache__/kuramoto.cpython-312.pyc
Normal file
Binary file not shown.
150
sims/consensus.py
Normal file
150
sims/consensus.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ConsensusConfig:
|
||||
"""
|
||||
Config class for Consensus dynamics sims
|
||||
"""
|
||||
num_sims: int = 500 # Number of consensus sims
|
||||
num_agents: int = 5 # Number of agents in the consensus simulation
|
||||
max_range: float = 1 # Max range of values each agent can take
|
||||
step_size: float = 0.1 # Target range for length of simulation
|
||||
directed: bool = False # Consensus graph directed?
|
||||
weighted: bool = False
|
||||
num_time_steps: int = 100
|
||||
|
||||
|
||||
def consensus_step(adj_matrix: jax.Array, agent_states: jax.Array, config: ConsensusConfig):
|
||||
"""
|
||||
Takes a step given the adjacency matrix and the current agent state using consensus dynamics.
|
||||
|
||||
|
||||
Parameters
|
||||
-----------------------------
|
||||
adj_matrix : jax.Array (num_agents, num_agents)
|
||||
A jax array containing the adjacency matrix for the consensus step.
|
||||
|
||||
|
||||
agent_states: jax.Array (num_agents)
|
||||
A jax array containing the current agent state
|
||||
|
||||
config: ConsensusConfig
|
||||
Config class for Consensus Dynamics
|
||||
|
||||
|
||||
Returns
|
||||
------------------------------
|
||||
updated_agent_state: jax.Array (num_agents)
|
||||
A jax array containing the updated agent state
|
||||
|
||||
"""
|
||||
L = adj_matrix
|
||||
norms = jnp.sum(jnp.abs(L), axis=1, keepdims=True)
|
||||
L = L / (norms)
|
||||
|
||||
return (agent_states + config.step_size * jnp.matmul(L , agent_states))/(1 + config.step_size)
|
||||
|
||||
def generate_random_adjacency_matrix(key: jax.Array, config: ConsensusConfig):
|
||||
"""
|
||||
Generates a random adjacency matrix in accordance with the config.
|
||||
The diagonal of the matrix is ensured to be all ones.
|
||||
|
||||
Parameters
|
||||
--------------------
|
||||
key: jax.Array
|
||||
A key for jax.random operations
|
||||
|
||||
config: ConsensusConfig
|
||||
Config for Consensus dyanmics
|
||||
|
||||
Returns
|
||||
---------------------
|
||||
adj_matrices: jax.Array (num_agents, num_agents)
|
||||
Random matrix
|
||||
"""
|
||||
rand_matrix = jax.random.uniform(key, shape=(config.num_agents, config.num_agents))
|
||||
# idxs = jnp.arange(config.num_agents)
|
||||
# rand_matrix = rand_matrix.at[:, idxs, idxs].set(1)
|
||||
rand_matrix = jnp.fill_diagonal(rand_matrix, 1, inplace=False) # Fill diagonal with ones
|
||||
if not config.weighted:
|
||||
rand_matrix = jnp.where(rand_matrix > 0.5, 1, 0)
|
||||
|
||||
if config.directed:
|
||||
return rand_matrix
|
||||
|
||||
rand_matrix = jnp.tril(rand_matrix) + jnp.triu(rand_matrix.T, 1)
|
||||
return rand_matrix
|
||||
|
||||
|
||||
def generate_random_agent_states(key: jax.Array, config: ConsensusConfig):
|
||||
"""
|
||||
Generate a random initial state for the agents in accordance with the config.
|
||||
|
||||
Parameters
|
||||
---------------------
|
||||
key: jax.Arrray
|
||||
A key for jax.random operations
|
||||
|
||||
config: ConsensusConfig
|
||||
Config for Consensus dynamics
|
||||
|
||||
Returns
|
||||
---------------------
|
||||
rand_states: jax.Array (num_sims, num_agents)
|
||||
|
||||
"""
|
||||
rand_states = jax.random.uniform(key, shape=(config.num_sims, config.num_agents), minval=-config.max_range, maxval=config.max_range)
|
||||
|
||||
return rand_states
|
||||
|
||||
@partial(jax.jit, static_argnames=["config"])
|
||||
def run_consensus_sim(adj_mat: jax.Array, initial_agent_state: jax.Array, config: ConsensusConfig):
|
||||
"""
|
||||
Runs the consensus sim and returns a history of agent states.
|
||||
|
||||
Parameters
|
||||
-------------------
|
||||
adj_mat: jax.Array (num_agents, num_agents)
|
||||
A jax array containing the adjacency matrix for the consensus step.
|
||||
|
||||
initial_agent_state: jax.Array (num_agents)
|
||||
A jax array containing the initial agent state
|
||||
|
||||
config: ConsensusConfig
|
||||
Config for Consensus dynamics
|
||||
"""
|
||||
# batched consensus step (meant for many initial states)
|
||||
batched_consensus_step = jax.vmap(consensus_step, in_axes=(None, 0, None), out_axes=0)
|
||||
|
||||
def step(x_prev, _):
|
||||
x_next = batched_consensus_step(adj_mat, x_prev, config)
|
||||
return x_next, x_next
|
||||
|
||||
_, all_states = jax.lax.scan(step, initial_agent_state, None, config.num_time_steps)
|
||||
|
||||
return all_states.transpose(1, 0, 2)
|
||||
|
||||
|
||||
def plot_consensus(trajectory, config):
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
sns.set_theme()
|
||||
|
||||
states = np.array(trajectory)
|
||||
timesteps, num_agents = states.shape
|
||||
time = np.arange(timesteps)
|
||||
|
||||
plt.figure()
|
||||
for agent_idx in range(num_agents):
|
||||
plt.plot(time ,states[:, agent_idx], label=f"Agent {agent_idx}")
|
||||
|
||||
plt.xlabel("Timestep")
|
||||
plt.ylabel("Agent statue")
|
||||
plt.ylim(-config.max_range, config.max_range)
|
||||
plt.title("Consensus simulation trajectories")
|
||||
plt.legend()
|
||||
plt.show()
|
221
sims/kuramoto.py
Normal file
221
sims/kuramoto.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
class KuramotoConfig:
|
||||
"""Configuration for the Kuramoto model simulation."""
|
||||
num_agents: int = 10 # N: Number of oscillators
|
||||
coupling: float = 1.0 # K: Coupling strength
|
||||
dt: float = 0.01 # Δt: Integration time step
|
||||
T: float = 10.0 # Total simulation time
|
||||
time_steps: int = int(T/dt)
|
||||
normalize_by_degree: bool = False
|
||||
directed: bool = False
|
||||
weighted: bool = False
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnames=("config",))
|
||||
def kuramoto_derivative(theta: jax.Array, # (N,) phase angles
|
||||
omega: jax.Array, # (N,) natural frequencies
|
||||
adj_mat: jax.Array, # (N, N) adjacency matrix
|
||||
config: KuramotoConfig) -> jax.Array:
|
||||
"""
|
||||
Computes the derivative of the phase for each oscillator.
|
||||
dθ_i/dt = ω_i + (K / deg_in_i) * Σ_j A_ji * sin(θ_j - θ_i)
|
||||
"""
|
||||
# Pairwise phase differences: delta[i, j] = θ_j - θ_i
|
||||
delta = theta[jnp.newaxis, :] - theta[:, jnp.newaxis]
|
||||
|
||||
# Weighted sinusodial coupling, summing over incoming edges (hence adj_mat.T)
|
||||
# coupling_effects[i, j] = A_ji * sin(θ_j - θ_i)
|
||||
coupling_effects = adj_mat.T * jnp.sin(delta)
|
||||
|
||||
# Sum contributions from all other oscillators for each oscillator
|
||||
coupling_sum = jnp.sum(coupling_effects, axis=1)
|
||||
|
||||
if config.normalize_by_degree:
|
||||
# Normalize by the in-degree of each node
|
||||
# In-degree for node i is the sum of column i in adj_mat
|
||||
in_degree = jnp.sum(adj_mat, axis=0)
|
||||
# Add a small epsilon to avoid division by zero for isolated nodes
|
||||
coupling_sum = coupling_sum / (in_degree + 1e-8)
|
||||
|
||||
return omega + config.coupling * coupling_sum
|
||||
|
||||
@partial(jax.jit, static_argnames=("config",))
|
||||
def kuramoto_step(theta: jax.Array, # (N,)
|
||||
omega: jax.Array, # (N,)
|
||||
adj_mat: jax.Array, # (N, N)
|
||||
config: KuramotoConfig) -> jax.Array:
|
||||
"""Performs a single Euler integration step of the Kuramoto model."""
|
||||
theta_dot = kuramoto_derivative(theta, omega, adj_mat, config)
|
||||
theta_next = theta + config.dt * theta_dot
|
||||
|
||||
# Wrap phases to the interval [-π, π) for numerical stability
|
||||
return (theta_next + jnp.pi) % (2 * jnp.pi) - jnp.pi
|
||||
|
||||
# -------------------- Simulation Runner --------------------
|
||||
@partial(jax.jit, static_argnames=("config",))
|
||||
def run_kuramoto_simulation(
|
||||
thetas0: jax.Array, # (N,) initial phases
|
||||
omegas: jax.Array, # (N,) natural frequencies
|
||||
adj_mat: jax.Array, # (N, N) adjacency matrix
|
||||
config: KuramotoConfig
|
||||
) -> jax.Array:
|
||||
"""
|
||||
Runs a full Kuramoto simulation for a given initial state.
|
||||
|
||||
Returns:
|
||||
trajectory: (T, N) array of phase angles over time.
|
||||
"""
|
||||
def scan_fn(theta, _):
|
||||
theta_next = kuramoto_step(theta, omegas, adj_mat, config)
|
||||
return theta_next, theta_next
|
||||
|
||||
# jax.lax.scan is a functional loop, efficient for sequential operations
|
||||
_, trajectory = jax.lax.scan(
|
||||
scan_fn,
|
||||
thetas0,
|
||||
None,
|
||||
length=config.time_steps
|
||||
)
|
||||
return trajectory
|
||||
|
||||
# -------------------- Analysis Functions --------------------
|
||||
@jax.jit
|
||||
def phase_coherence(thetas: jax.Array) -> jax.Array:
|
||||
"""
|
||||
Computes the global order parameter R, a measure of phase coherence.
|
||||
R = |(1/N) * Σ_j exp(i * θ_j)|
|
||||
|
||||
Args:
|
||||
thetas: An array of phases, e.g., (T, N) for a trajectory.
|
||||
|
||||
Returns:
|
||||
The order parameter R. If input is a trajectory, returns R at each time step.
|
||||
"""
|
||||
complex_phases = jnp.exp(1j * thetas)
|
||||
# Mean over the agent axis (-1)
|
||||
return jnp.abs(jnp.mean(complex_phases, axis=-1))
|
||||
|
||||
@partial(jax.jit, static_argnames=("config",))
|
||||
def mean_frequency(trajectory: jax.Array, # (T, N)
|
||||
omegas: jax.Array, # (N,)
|
||||
adj_mat: jax.Array, # (N, N)
|
||||
config: KuramotoConfig) -> jax.Array:
|
||||
"""
|
||||
Computes the mean frequency of each oscillator over the simulation.
|
||||
|
||||
Returns:
|
||||
mean_freqs: (N,) array of mean frequencies.
|
||||
"""
|
||||
# To find the mean frequency, we calculate the derivative at each point
|
||||
# in the trajectory and then average over time.
|
||||
# We can use vmap to apply the derivative function over the time axis.
|
||||
vmapped_derivative = jax.vmap(
|
||||
kuramoto_derivative,
|
||||
in_axes=(0, None, None, None) # Map over theta (axis 0), other args are fixed
|
||||
)
|
||||
all_derivatives = vmapped_derivative(trajectory, omegas, adj_mat, config)
|
||||
return jnp.mean(all_derivatives, axis=0)
|
||||
|
||||
# -------------------- Initialization Helpers --------------------
|
||||
def generate_random_adjacency_matrix(key: jax.Array, config: KuramotoConfig) -> jax.Array:
|
||||
"""Generates a single random adjacency matrix (N, N)."""
|
||||
N = config.num_agents
|
||||
shape = (N, N)
|
||||
|
||||
if config.weighted:
|
||||
matrix = jax.random.uniform(key, shape)
|
||||
else:
|
||||
# Binary matrix based on a 50/50 chance
|
||||
matrix = (jax.random.uniform(key, shape) > 0.5).astype(jnp.float32)
|
||||
|
||||
if not config.directed:
|
||||
# Symmetrize the matrix for an undirected graph
|
||||
matrix = jnp.tril(matrix) + jnp.triu(matrix.T, 1)
|
||||
|
||||
# An oscillator is always connected to itself to avoid division by zero
|
||||
# if it has no other connections.
|
||||
matrix = jnp.fill_diagonal(matrix, 1, inplace=False)
|
||||
|
||||
return matrix
|
||||
|
||||
def generate_initial_state(key: jax.Array, config: KuramotoConfig,
|
||||
omega_mean=0.0, omega_std=1.0):
|
||||
"""Generates initial phases and natural frequencies."""
|
||||
key_theta, key_omega = jax.random.split(key)
|
||||
N = config.num_agents
|
||||
|
||||
# Initial phases uniformly distributed in [0, 2π)
|
||||
thetas0 = jax.random.uniform(key_theta, (N,), minval=0, maxval=2 * jnp.pi)
|
||||
|
||||
# Natural frequencies from a normal distribution
|
||||
omegas = omega_mean + omega_std * jax.random.normal(key_omega, (N,))
|
||||
|
||||
return thetas0, omegas
|
||||
|
||||
# -------------------- Plotting --------------------
|
||||
def plot_kuramoto_results(trajectory: np.ndarray, R_t: np.ndarray, config: KuramotoConfig):
|
||||
"""Plots phase trajectories and the global order parameter."""
|
||||
|
||||
T, N = trajectory.shape
|
||||
time = np.linspace(0, config.T, config.num_time_steps)
|
||||
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
|
||||
|
||||
# Plot 1: Phase trajectories (sin(theta) for visualization)
|
||||
for agent_idx in range(N):
|
||||
ax1.plot(time, np.sin(trajectory[:, agent_idx]), lw=1.5, label=f"Agent {agent_idx+1}")
|
||||
ax1.set_title("Kuramoto Oscillator Phase Trajectories")
|
||||
ax1.set_ylabel(r"$\sin(\theta_i)$")
|
||||
ax1.grid(True, linestyle='--', alpha=0.6)
|
||||
if N <= 10:
|
||||
ax1.legend(loc='upper right', fontsize='small')
|
||||
|
||||
# Plot 2: Global order parameter R
|
||||
ax2.plot(time, R_t, color='k', lw=2)
|
||||
ax2.set_title("Global Order Parameter (Phase Coherence)")
|
||||
ax2.set_xlabel("Time (s)")
|
||||
ax2.set_ylabel("R(t)")
|
||||
ax2.set_ylim([0, 1.05])
|
||||
ax2.grid(True, linestyle='--', alpha=0.6)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# -------------------- Main Execution --------------------
|
||||
if __name__ == '__main__':
|
||||
# 1. Setup configuration and random key
|
||||
config = KuramotoConfig(num_agents=20, coupling=0.8, T=20)
|
||||
key = jax.random.PRNGKey(42)
|
||||
key, adj_key, state_key = jax.random.split(key, 3)
|
||||
|
||||
# 2. Generate system components
|
||||
adj_matrix = generate_random_adjacency_matrix(adj_key, config)
|
||||
thetas0, omegas = generate_initial_state(state_key, config)
|
||||
|
||||
# 3. Run the simulation
|
||||
print(f"Running Kuramoto simulation for {config.num_time_steps} steps...")
|
||||
trajectory = run_kuramoto_simulation(thetas0, omegas, adj_matrix, config)
|
||||
# Block until the computation is done to measure time accurately if needed
|
||||
trajectory.block_until_ready()
|
||||
print("Simulation complete.")
|
||||
|
||||
# 4. Analyze the results
|
||||
R_over_time = phase_coherence(trajectory)
|
||||
avg_frequencies = mean_frequency(trajectory, omegas, adj_matrix, config)
|
||||
|
||||
print("\n--- Analysis Results ---")
|
||||
print(f"Initial Coherence R(0): {R_over_time[0]:.4f}")
|
||||
print(f"Final Coherence R(T): {R_over_time[-1]:.4f}")
|
||||
print("\nNatural Frequencies (ω):")
|
||||
print(np.asarray(omegas))
|
||||
print("\nMean Frequencies over Simulation:")
|
||||
print(np.asarray(avg_frequencies))
|
||||
|
||||
# 5. Plot the results
|
||||
plot_kuramoto_results(np.asarray(trajectory), np.asarray(R_over_time), config)
|
Reference in New Issue
Block a user