added kuramoto

This commit is contained in:
2025-07-26 20:07:01 -04:00
parent 8227d11844
commit 1a0425d549
4 changed files with 390 additions and 2 deletions

Binary file not shown.

229
kuramoto.py Normal file
View File

@@ -0,0 +1,229 @@
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
# -------------------- Configuration --------------------
@dataclass(frozen=True)
class KuramotoConfig:
"""Configuration for the Kuramoto model simulation."""
num_agents: int = 10 # N: Number of oscillators
coupling: float = 1.0 # K: Coupling strength
dt: float = 0.01 # Δt: Integration time step
T: float = 10.0 # Total simulation time
# Adjacency matrix properties
normalize_by_degree: bool = True
directed: bool = False
weighted: bool = False
@property
def num_time_steps(self) -> int:
"""Total number of simulation steps."""
return int(self.T / self.dt)
# -------------------- Core Dynamics --------------------
@partial(jax.jit, static_argnames=("config",))
def kuramoto_derivative(theta: jax.Array, # (N,) phase angles
omega: jax.Array, # (N,) natural frequencies
adj_mat: jax.Array, # (N, N) adjacency matrix
config: KuramotoConfig) -> jax.Array:
"""
Computes the derivative of the phase for each oscillator.
dθ_i/dt = ω_i + (K / deg_in_i) * Σ_j A_ji * sin(θ_j - θ_i)
"""
# Pairwise phase differences: delta[i, j] = θ_j - θ_i
delta = theta[jnp.newaxis, :] - theta[:, jnp.newaxis]
# Weighted sinusodial coupling, summing over incoming edges (hence adj_mat.T)
# coupling_effects[i, j] = A_ji * sin(θ_j - θ_i)
coupling_effects = adj_mat.T * jnp.sin(delta)
# Sum contributions from all other oscillators for each oscillator
coupling_sum = jnp.sum(coupling_effects, axis=1)
if config.normalize_by_degree:
# Normalize by the in-degree of each node
# In-degree for node i is the sum of column i in adj_mat
in_degree = jnp.sum(adj_mat, axis=0)
# Add a small epsilon to avoid division by zero for isolated nodes
coupling_sum = coupling_sum / (in_degree + 1e-8)
return omega + config.coupling * coupling_sum
@partial(jax.jit, static_argnames=("config",))
def kuramoto_step(theta: jax.Array, # (N,)
omega: jax.Array, # (N,)
adj_mat: jax.Array, # (N, N)
config: KuramotoConfig) -> jax.Array:
"""Performs a single Euler integration step of the Kuramoto model."""
theta_dot = kuramoto_derivative(theta, omega, adj_mat, config)
theta_next = theta + config.dt * theta_dot
# Wrap phases to the interval [-π, π) for numerical stability
return (theta_next + jnp.pi) % (2 * jnp.pi) - jnp.pi
# -------------------- Simulation Runner --------------------
@partial(jax.jit, static_argnames=("config",))
def run_kuramoto_simulation(
thetas0: jax.Array, # (N,) initial phases
omegas: jax.Array, # (N,) natural frequencies
adj_mat: jax.Array, # (N, N) adjacency matrix
config: KuramotoConfig
) -> jax.Array:
"""
Runs a full Kuramoto simulation for a given initial state.
Returns:
trajectory: (T, N) array of phase angles over time.
"""
def scan_fn(theta, _):
theta_next = kuramoto_step(theta, omegas, adj_mat, config)
return theta_next, theta_next
# jax.lax.scan is a functional loop, efficient for sequential operations
_, trajectory = jax.lax.scan(
scan_fn,
thetas0,
None,
length=config.num_time_steps
)
return trajectory
# -------------------- Analysis Functions --------------------
@jax.jit
def phase_coherence(thetas: jax.Array) -> jax.Array:
"""
Computes the global order parameter R, a measure of phase coherence.
R = |(1/N) * Σ_j exp(i * θ_j)|
Args:
thetas: An array of phases, e.g., (T, N) for a trajectory.
Returns:
The order parameter R. If input is a trajectory, returns R at each time step.
"""
complex_phases = jnp.exp(1j * thetas)
# Mean over the agent axis (-1)
return jnp.abs(jnp.mean(complex_phases, axis=-1))
@partial(jax.jit, static_argnames=("config",))
def mean_frequency(trajectory: jax.Array, # (T, N)
omegas: jax.Array, # (N,)
adj_mat: jax.Array, # (N, N)
config: KuramotoConfig) -> jax.Array:
"""
Computes the mean frequency of each oscillator over the simulation.
Returns:
mean_freqs: (N,) array of mean frequencies.
"""
# To find the mean frequency, we calculate the derivative at each point
# in the trajectory and then average over time.
# We can use vmap to apply the derivative function over the time axis.
vmapped_derivative = jax.vmap(
kuramoto_derivative,
in_axes=(0, None, None, None) # Map over theta (axis 0), other args are fixed
)
all_derivatives = vmapped_derivative(trajectory, omegas, adj_mat, config)
return jnp.mean(all_derivatives, axis=0)
# -------------------- Initialization Helpers --------------------
def generate_random_adjacency_matrix(key: jax.Array, config: KuramotoConfig) -> jax.Array:
"""Generates a single random adjacency matrix (N, N)."""
N = config.num_agents
shape = (N, N)
if config.weighted:
matrix = jax.random.uniform(key, shape)
else:
# Binary matrix based on a 50/50 chance
matrix = (jax.random.uniform(key, shape) > 0.5).astype(jnp.float32)
if not config.directed:
# Symmetrize the matrix for an undirected graph
matrix = jnp.tril(matrix) + jnp.triu(matrix.T, 1)
# An oscillator is always connected to itself to avoid division by zero
# if it has no other connections.
matrix = jnp.fill_diagonal(matrix, 1, inplace=False)
return matrix
def generate_initial_state(key: jax.Array, config: KuramotoConfig,
omega_mean=0.0, omega_std=1.0):
"""Generates initial phases and natural frequencies."""
key_theta, key_omega = jax.random.split(key)
N = config.num_agents
# Initial phases uniformly distributed in [0, 2π)
thetas0 = jax.random.uniform(key_theta, (N,), minval=0, maxval=2 * jnp.pi)
# Natural frequencies from a normal distribution
omegas = omega_mean + omega_std * jax.random.normal(key_omega, (N,))
return thetas0, omegas
# -------------------- Plotting --------------------
def plot_kuramoto_results(trajectory: np.ndarray, R_t: np.ndarray, config: KuramotoConfig):
"""Plots phase trajectories and the global order parameter."""
T, N = trajectory.shape
time = np.linspace(0, config.T, config.num_time_steps)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
# Plot 1: Phase trajectories (sin(theta) for visualization)
for agent_idx in range(N):
ax1.plot(time, np.sin(trajectory[:, agent_idx]), lw=1.5, label=f"Agent {agent_idx+1}")
ax1.set_title("Kuramoto Oscillator Phase Trajectories")
ax1.set_ylabel(r"$\sin(\theta_i)$")
ax1.grid(True, linestyle='--', alpha=0.6)
if N <= 10:
ax1.legend(loc='upper right', fontsize='small')
# Plot 2: Global order parameter R
ax2.plot(time, R_t, color='k', lw=2)
ax2.set_title("Global Order Parameter (Phase Coherence)")
ax2.set_xlabel("Time (s)")
ax2.set_ylabel("R(t)")
ax2.set_ylim([0, 1.05])
ax2.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()
# -------------------- Main Execution --------------------
if __name__ == '__main__':
# 1. Setup configuration and random key
config = KuramotoConfig(num_agents=20, coupling=0.8, T=20)
key = jax.random.PRNGKey(42)
key, adj_key, state_key = jax.random.split(key, 3)
# 2. Generate system components
adj_matrix = generate_random_adjacency_matrix(adj_key, config)
thetas0, omegas = generate_initial_state(state_key, config)
# 3. Run the simulation
print(f"Running Kuramoto simulation for {config.num_time_steps} steps...")
trajectory = run_kuramoto_simulation(thetas0, omegas, adj_matrix, config)
# Block until the computation is done to measure time accurately if needed
trajectory.block_until_ready()
print("Simulation complete.")
# 4. Analyze the results
R_over_time = phase_coherence(trajectory)
avg_frequencies = mean_frequency(trajectory, omegas, adj_matrix, config)
print("\n--- Analysis Results ---")
print(f"Initial Coherence R(0): {R_over_time[0]:.4f}")
print(f"Final Coherence R(T): {R_over_time[-1]:.4f}")
print("\nNatural Frequencies (ω):")
print(np.asarray(omegas))
print("\nMean Frequencies over Simulation:")
print(np.asarray(avg_frequencies))
# 5. Plot the results
plot_kuramoto_results(np.asarray(trajectory), np.asarray(R_over_time), config)

File diff suppressed because one or more lines are too long