import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial
import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # Plotting is optional; the simulation core needs only JAX/NumPy.
    plt = None


# -------------------- Configuration --------------------
@dataclass(frozen=True)
class KuramotoConfig:
    """Configuration for the Kuramoto model simulation.

    The dataclass is frozen (and therefore hashable) so instances can be
    passed to the jitted functions below as a static argument.
    """

    num_agents: int = 10  # N: Number of oscillators
    coupling: float = 1.0  # K: Coupling strength
    dt: float = 0.01  # Δt: Integration time step
    T: float = 10.0  # Total simulation time

    # Adjacency matrix properties
    normalize_by_degree: bool = True
    directed: bool = False
    weighted: bool = False

    @property
    def num_time_steps(self) -> int:
        """Total number of simulation steps, ``floor(T / dt)``.

        The quotient is first rounded to 9 decimals so floating-point
        representation error does not drop a step
        (e.g. ``0.3 / 0.1 == 2.9999999999999996`` must give 3, not 2).
        """
        return int(round(self.T / self.dt, 9))


# -------------------- Core Dynamics --------------------
@partial(jax.jit, static_argnames=("config",))
def kuramoto_derivative(theta: jax.Array,  # (N,) phase angles
                        omega: jax.Array,  # (N,) natural frequencies
                        adj_mat: jax.Array,  # (N, N) adjacency matrix
                        config: KuramotoConfig) -> jax.Array:
    """Compute the phase velocity of every oscillator.

    dθ_i/dt = ω_i + (K / deg_in_i) * Σ_j A_ji * sin(θ_j - θ_i)

    ``A_ji`` is the weight of the edge j -> i, so each node is driven by its
    incoming edges. The division by ``deg_in_i`` is applied only when
    ``config.normalize_by_degree`` is True.

    Returns:
        (N,) array of phase derivatives.
    """
    # Pairwise phase differences: delta[i, j] = θ_j - θ_i
    delta = theta[jnp.newaxis, :] - theta[:, jnp.newaxis]

    # Weighted sinusoidal coupling, summing over incoming edges (hence adj_mat.T)
    # coupling_effects[i, j] = A_ji * sin(θ_j - θ_i)
    coupling_effects = adj_mat.T * jnp.sin(delta)

    # Sum contributions from all other oscillators for each oscillator
    coupling_sum = jnp.sum(coupling_effects, axis=1)

    if config.normalize_by_degree:
        # In-degree of node i is the sum of column i of adj_mat.
        in_degree = jnp.sum(adj_mat, axis=0)
        # Small epsilon avoids division by zero for isolated nodes.
        coupling_sum = coupling_sum / (in_degree + 1e-8)

    return omega + config.coupling * coupling_sum


@partial(jax.jit, static_argnames=("config",))
def kuramoto_step(theta: jax.Array,  # (N,)
                  omega: jax.Array,  # (N,)
                  adj_mat: jax.Array,  # (N, N)
                  config: KuramotoConfig) -> jax.Array:
    """Advance the phases by one explicit Euler step of size ``config.dt``.

    Returns:
        (N,) next phases, wrapped to the interval [-π, π).
    """
    theta_dot = kuramoto_derivative(theta, omega, adj_mat, config)
    theta_next = theta + config.dt * theta_dot
    # Wrap phases to [-π, π) so they stay bounded over long runs.
    return (theta_next + jnp.pi) % (2 * jnp.pi) - jnp.pi


# -------------------- Simulation Runner --------------------
@partial(jax.jit, static_argnames=("config",))
def run_kuramoto_simulation(
    thetas0: jax.Array,  # (N,) initial phases
    omegas: jax.Array,  # (N,) natural frequencies
    adj_mat: jax.Array,  # (N, N) adjacency matrix
    config: KuramotoConfig,
) -> jax.Array:
    """Run a full Kuramoto simulation from the given initial state.

    Returns:
        trajectory: (config.num_time_steps, N) array of phase angles.
        ``trajectory[k]`` is the state after ``k + 1`` Euler steps, i.e. at
        time ``(k + 1) * config.dt``; the initial state is NOT included.
    """

    def scan_fn(theta, _):
        theta_next = kuramoto_step(theta, omegas, adj_mat, config)
        return theta_next, theta_next

    # jax.lax.scan is a functional loop, efficient for sequential operations
    _, trajectory = jax.lax.scan(
        scan_fn, thetas0, None, length=config.num_time_steps
    )
    return trajectory


# -------------------- Analysis Functions --------------------
@jax.jit
def phase_coherence(thetas: jax.Array) -> jax.Array:
    """Compute the global order parameter R, a measure of phase coherence.

    R = |(1/N) * Σ_j exp(i * θ_j)|

    Args:
        thetas: Phases with agents on the last axis, e.g. (N,) for one state
            or (T, N) for a trajectory.

    Returns:
        R for a single state, or R at each time step for a trajectory.
    """
    complex_phases = jnp.exp(1j * thetas)
    # Mean over the agent axis (-1)
    return jnp.abs(jnp.mean(complex_phases, axis=-1))


@partial(jax.jit, static_argnames=("config",))
def mean_frequency(trajectory: jax.Array,  # (T, N)
                   omegas: jax.Array,  # (N,)
                   adj_mat: jax.Array,  # (N, N)
                   config: KuramotoConfig) -> jax.Array:
    """Compute each oscillator's mean instantaneous frequency over a trajectory.

    The phase derivative is evaluated at every stored state and averaged over
    time.

    Returns:
        mean_freqs: (N,) array of mean frequencies.
    """

    def derivative_at(theta: jax.Array) -> jax.Array:
        # Only the phases vary across time; every other argument is closed over,
        # so vmap never has to handle the static (non-array) config.
        return kuramoto_derivative(theta, omegas, adj_mat, config)

    all_derivatives = jax.vmap(derivative_at)(trajectory)
    return jnp.mean(all_derivatives, axis=0)


# -------------------- Initialization Helpers --------------------
def generate_random_adjacency_matrix(key: jax.Array,
                                     config: KuramotoConfig,
                                     edge_prob: float = 0.5) -> jax.Array:
    """Generate a single random (N, N) adjacency matrix.

    Args:
        key: JAX PRNG key.
        config: Simulation configuration (size, directed/weighted flags).
        edge_prob: Probability that an edge exists in the binary
            (unweighted) case. Ignored when ``config.weighted`` is True.

    Returns:
        (N, N) float matrix: symmetric when ``config.directed`` is False, and
        always with ones on the diagonal.
    """
    N = config.num_agents
    shape = (N, N)

    if config.weighted:
        matrix = jax.random.uniform(key, shape)
    else:
        # Binary matrix: an edge exists when the uniform draw exceeds
        # 1 - edge_prob (exactly the original `> 0.5` for the default).
        matrix = (jax.random.uniform(key, shape) > (1.0 - edge_prob)).astype(jnp.float32)

    if not config.directed:
        # Symmetrize: keep the lower triangle and mirror it above the diagonal.
        matrix = jnp.tril(matrix) + jnp.triu(matrix.T, 1)

    # An oscillator is always connected to itself to avoid division by zero
    # if it has no other connections.
    matrix = jnp.fill_diagonal(matrix, 1, inplace=False)
    return matrix


def generate_initial_state(key: jax.Array,
                           config: KuramotoConfig,
                           omega_mean: float = 0.0,
                           omega_std: float = 1.0) -> tuple[jax.Array, jax.Array]:
    """Generate initial phases and natural frequencies.

    Returns:
        thetas0: (N,) phases drawn uniformly from [0, 2π).
        omegas: (N,) frequencies drawn from Normal(omega_mean, omega_std**2).
    """
    key_theta, key_omega = jax.random.split(key)
    N = config.num_agents

    # Initial phases uniformly distributed in [0, 2π)
    thetas0 = jax.random.uniform(key_theta, (N,), minval=0, maxval=2 * jnp.pi)

    # Natural frequencies from a normal distribution
    omegas = omega_mean + omega_std * jax.random.normal(key_omega, (N,))

    return thetas0, omegas


# -------------------- Plotting --------------------
def plot_kuramoto_results(trajectory: np.ndarray,
                          R_t: np.ndarray,
                          config: KuramotoConfig):
    """Plot phase trajectories and the global order parameter.

    Args:
        trajectory: (num_steps, N) phases as returned by
            ``run_kuramoto_simulation``.
        R_t: (num_steps,) order parameter at each stored step.
        config: Simulation configuration (provides ``dt``).

    Raises:
        ImportError: If matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required for plot_kuramoto_results")

    num_steps, N = trajectory.shape
    # trajectory[k] is the state after k + 1 Euler steps, i.e. at t = (k + 1) * dt.
    # Building the axis from dt (rather than linspace(0, T, n)) keeps the
    # spacing equal to dt and matches the trajectory length exactly.
    time = config.dt * np.arange(1, num_steps + 1)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    # Plot 1: Phase trajectories (sin(theta) for visualization)
    for agent_idx in range(N):
        ax1.plot(time, np.sin(trajectory[:, agent_idx]), lw=1.5, label=f"Agent {agent_idx+1}")

    ax1.set_title("Kuramoto Oscillator Phase Trajectories")
    ax1.set_ylabel(r"$\sin(\theta_i)$")
    ax1.grid(True, linestyle='--', alpha=0.6)
    if N <= 10:
        ax1.legend(loc='upper right', fontsize='small')

    # Plot 2: Global order parameter R
    ax2.plot(time, R_t, color='k', lw=2)
    ax2.set_title("Global Order Parameter (Phase Coherence)")
    ax2.set_xlabel("Time (s)")
    ax2.set_ylabel("R(t)")
    ax2.set_ylim([0, 1.05])
    ax2.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.show()


# -------------------- Main Execution --------------------
if __name__ == '__main__':
    # 1. Setup configuration and random key
    config = KuramotoConfig(num_agents=20, coupling=0.8, T=20)
    key = jax.random.PRNGKey(42)
    key, adj_key, state_key = jax.random.split(key, 3)

    # 2. Generate system components
    adj_matrix = generate_random_adjacency_matrix(adj_key, config)
    thetas0, omegas = generate_initial_state(state_key, config)

    # 3. Run the simulation
    print(f"Running Kuramoto simulation for {config.num_time_steps} steps...")
    trajectory = run_kuramoto_simulation(thetas0, omegas, adj_matrix, config)
    # Block until the computation is done to measure time accurately if needed
    trajectory.block_until_ready()
    print("Simulation complete.")

    # 4. Analyze the results
    # The trajectory excludes the initial state, so R(0) comes from thetas0.
    R_initial = phase_coherence(thetas0)
    R_over_time = phase_coherence(trajectory)
    avg_frequencies = mean_frequency(trajectory, omegas, adj_matrix, config)

    print("\n--- Analysis Results ---")
    print(f"Initial Coherence R(0): {float(R_initial):.4f}")
    print(f"Final Coherence R(T): {float(R_over_time[-1]):.4f}")
    print("\nNatural Frequencies (ω):")
    print(np.asarray(omegas))
    print("\nMean Frequencies over Simulation:")
    print(np.asarray(avg_frequencies))

    # 5. Plot the results
    plot_kuramoto_results(np.asarray(trajectory), np.asarray(R_over_time), config)