Skip to content
import jax
import jax.numpy as jnp
import jax.random as rng
import matplotlib.pyplot as plt

import stochastix as stx
from stochastix.utils.visualization import plot_abundance_dynamic

# Run all JAX computations in 64-bit precision.
jax.config.update('jax_enable_x64', True)

# Root PRNG key for the whole notebook; each stochastic call below splits off
# its own subkey from it.
key = rng.PRNGKey(10)

# Larger default font for all figures.
plt.rcParams.update({'font.size': 18})
/Users/francesco/Documents/GitHub/stochastix/stochastix/utils/__init__.py:3: FutureWarning: The 'stochastix.utils.optimization' module is experimental and it may change or be removed without notice in future versions.
  from . import nn, optimization, visualization

Chemical Oscillator Model

# Define the oscillator reactions using the new API
reactions = [
    stx.Reaction('X0 + X1 -> 2 X1', stx.kinetics.MassAction(0.01)),
    stx.Reaction('X1 + X2 -> 2 X2', stx.kinetics.MassAction(0.01)),
    stx.Reaction('X0 + X2 -> 2 X0', stx.kinetics.MassAction(0.01)),
    stx.Reaction('2 X0 -> X0', stx.kinetics.MassAction(0.0001)),
    stx.Reaction('2 X1 -> X1', stx.kinetics.MassAction(0.0001)),
    stx.Reaction('2 X2 -> X2', stx.kinetics.MassAction(0.0001)),
]

oscillator = stx.ReactionNetwork(reactions)
print(oscillator)
R0:  X0 + X1 -> 2 X1  |  MassAction
R1:  X1 + X2 -> 2 X2  |  MassAction
R2:  X0 + X2 -> 2 X0  |  MassAction
R3:  2 X0 -> X0       |  MassAction
R4:  2 X1 -> X1       |  MassAction
R5:  2 X2 -> X2       |  MassAction

Stochastic Simulation

# Fresh subkey for this stochastic run.
key, subkey = rng.split(key)

# Initial state (one entry per species in the network) and time horizon.
x0 = jnp.array([100, 50, 70])
T = 100

ssa_results = stx.stochsimsolve(subkey, oscillator, x0, T=T, max_steps=500_000)

# NOTE(review): presumably True when max_steps ran out before reaching T — confirm.
print('Time overflow:\t', ssa_results.time_overflow)
Time overflow:   False
_ = plot_abundance_dynamic(ssa_results)  # discard the return value to keep cell output clean

img

Conversion to ODE Model

from diffrax import Dopri5, ODETerm, SaveAt, diffeqsolve
# Deterministic counterpart of the SSA run: integrate the network's vector
# field with Dopri5 from the same initial state x0 over [0, T].
ode_term = ODETerm(oscillator.vector_field)  # equivalently: oscillator.diffrax_ode_term()

dt0 = 0.1
saveat = SaveAt(ts=jnp.linspace(0.0, T, 1000))  # 1000 evenly spaced output times

ode_results = diffeqsolve(
    ode_term,
    Dopri5(),
    t0=0.0,
    t1=T,
    dt0=dt0,
    y0=x0,
    saveat=saveat,
    max_steps=1_000_000,
)
# One line per state component of the ODE solution.
plt.figure(figsize=(9, 6))
ax = plt.gca()
ax.plot(ode_results.ts, ode_results.ys, alpha=0.5)
ax.set_xlabel('time [s]')
ax.set_ylabel('concentration')
ax.grid(alpha=0.2)

img

Conversion to SDE Formulation

from diffrax import (
    ControlTerm,
    Euler,
    MultiTerm,
    ODETerm,
    SaveAt,
    VirtualBrownianTree,
    diffeqsolve,
)
# SDE formulation of the oscillator (drift + per-reaction diffusion),
# integrated with a fixed-step Euler scheme.
t0, T = 0.0, 100
dt0 = 0.1

# Draw a fresh key: the previous `subkey` already drove the SSA run above, and
# reusing a JAX PRNG key would correlate this noise with that simulation.
key, subkey = rng.split(key)

# One Brownian channel per reaction. `tol` is the resolution the tree
# discretises [t0, T] to; keep it well below the solver step dt0 so every
# Euler increment is resolved by the tree rather than interpolated.
brownian_motion = VirtualBrownianTree(
    t0, T, tol=1e-3, shape=(oscillator.n_reactions,), key=subkey
)
terms = MultiTerm(
    ODETerm(oscillator.drift_fn), ControlTerm(oscillator.diffusion_fn, brownian_motion)
)

# or just:
# terms = oscillator.diffrax_sde_term(T, key=subkey)

solver = Euler()
saveat = SaveAt(steps=True)

# With SaveAt(steps=True) the output arrays have length max_steps (unused rows
# are inf-padded), so size the budget to the constant-step count plus a 2x
# safety margin instead of a blanket 1e6.
n_steps = int((T - t0) / dt0)

sde_results = diffeqsolve(
    terms, solver, t0=t0, t1=T, dt0=dt0, y0=x0, saveat=saveat, max_steps=2 * n_steps
)
# One line per state component of the SDE solution.
plt.figure(figsize=(9, 6))
ax = plt.gca()
ax.plot(sde_results.ts, sde_results.ys, alpha=0.5)
ax.set_xlabel('time [s]')
ax.set_ylabel('concentration')

# Optional log scale: ax.set_yscale('log')

ax.grid(alpha=0.2)

img

Lotka-Volterra Model (Predator-Prey)

# Predator-prey network from the library's Lotka-Volterra generator.
lv_rates = {
    'alpha': 0.1,  # prey reproduction rate
    'beta': 0.005,  # predation rate
    'gamma': 0.05,  # predator death rate
}
lv_model = stx.generators.lotka_volterra_model(**lv_rates)
print(lv_model)
R0 (prey_reproduction):  prey -> 2 prey                 |  MassAction
R1 (predation):          predator + prey -> 2 predator  |  MassAction
R2 (predator_death):     predator -> 0                  |  MassAction

Stochastic Simulation

# Fresh subkey for this stochastic run.
key, subkey = rng.split(key)

# Initial state (one entry per species in lv_model) and time horizon.
x0 = jnp.array([30, 30])
T = 500

ssa_results = stx.stochsimsolve(subkey, lv_model, x0, T=T, max_steps=50_000)

# NOTE(review): presumably True when max_steps ran out before reaching T — confirm.
print('Time overflow:\t', ssa_results.time_overflow)
Time overflow:   False
_ = plot_abundance_dynamic(ssa_results, time_unit='s')  # discard the return value to keep cell output clean

img

Conversion to ODE Model

from diffrax import Dopri5, ODETerm, SaveAt, diffeqsolve
# Deterministic counterpart of the LV SSA run, integrated with Dopri5 from x0.
ode_term = lv_model.diffrax_ode_term()  # equivalently: ODETerm(lv_model.vector_field)

dt0 = 0.1
saveat = SaveAt(ts=jnp.linspace(0.0, T, 1000))  # 1000 evenly spaced output times

ode_results = diffeqsolve(
    ode_term,
    Dopri5(),
    t0=0.0,
    t1=T,
    dt0=dt0,
    y0=x0,
    saveat=saveat,
    max_steps=100_000,
)
# One line per state component of the LV ODE solution.
plt.figure(figsize=(9, 6))
ax = plt.gca()
ax.plot(ode_results.ts, ode_results.ys, alpha=0.5)
ax.set_xlabel('time [s]')
ax.set_ylabel('concentration')
ax.grid(alpha=0.2)

img

Conversion to SDE Formulation

from diffrax import (
    ControlTerm,
    Euler,
    MultiTerm,
    ODETerm,
    SaveAt,
    VirtualBrownianTree,
    diffeqsolve,
)
# SDE formulation of the Lotka-Volterra model, integrated with a fixed-step
# Euler scheme.
t0, T = 0.0, 500
dt0 = 0.1

# Draw a fresh key: the previous `subkey` already drove the LV SSA run above,
# and reusing a JAX PRNG key would correlate this noise with that simulation.
key, subkey = rng.split(key)

# Equivalent manual construction of the SDE terms:
# brownian_motion = VirtualBrownianTree(
#     t0, T, tol=1e-3, shape=(lv_model.n_reactions,), key=subkey
# )
# terms = MultiTerm(
#     ODETerm(lv_model.drift_fn), ControlTerm(lv_model.diffusion_fn, brownian_motion)
# )
terms = lv_model.diffrax_sde_term(T, key=subkey)

solver = Euler()
saveat = SaveAt(steps=True)

# With SaveAt(steps=True) the output arrays have length max_steps (unused rows
# are inf-padded), so size the budget to the constant-step count plus a 2x
# safety margin instead of a blanket 1e5.
n_steps = int((T - t0) / dt0)

sde_results = diffeqsolve(
    terms, solver, t0=t0, t1=T, dt0=dt0, y0=x0, saveat=saveat, max_steps=2 * n_steps
)
# One line per state component of the LV SDE solution.
plt.figure(figsize=(9, 6))
ax = plt.gca()
ax.plot(sde_results.ts, sde_results.ys, alpha=0.5)
ax.set_xlabel('time [s]')
ax.set_ylabel('concentration')

# Optional log scale: ax.set_yscale('log')

ax.grid(alpha=0.2)

img