Skip to content

Simulation Examples

import equinox as eqx
import jax.numpy as jnp
import jax.random as rng
import matplotlib.pyplot as plt


import stochastix as stx
from stochastix.utils.visualization import plot_abundance_dynamic

# Double precision is off by default (JAX arrays are float32). To enable it,
# uncomment both lines below -- the name `jax` itself is not imported above --
# and run them at startup, before any JAX computation.
# import jax
# jax.config.update('jax_enable_x64', True)

# Root PRNG key; subkeys for the simulations below are split from it.
key = rng.PRNGKey(42)

# Larger default font for every figure in this notebook.
plt.rcParams.update({'font.size': 18})
/Users/francesco/Documents/GitHub/stochastix/stochastix/utils/__init__.py:3: FutureWarning: The 'stochastix.utils.optimization' module is experimental and it may change or be removed without notice in future versions.
  from . import nn, optimization, visualization

Oscillator

# Cyclic competition among X0, X1, X2: X1 converts X0, X2 converts X1 and
# X0 converts X2 (rate 0.01), plus slow self-limitation 2 Xi -> Xi (rate 1e-4).
reactions = [
    stx.Reaction(equation, stx.kinetics.MassAction(rate))
    for equation, rate in (
        ('X0 + X1 -> 2 X1', 0.01),
        ('X1 + X2 -> 2 X2', 0.01),
        ('X0 + X2 -> 2 X0', 0.01),
        ('2 X0 -> X0', 0.0001),
        ('2 X1 -> X1', 0.0001),
        ('2 X2 -> X2', 0.0001),
    )
]

model = stx.ReactionNetwork(reactions)
print(model)
R0:  X0 + X1 -> 2 X1  |  MassAction
R1:  X1 + X2 -> 2 X2  |  MassAction
R2:  X0 + X2 -> 2 X0  |  MassAction
R3:  2 X0 -> X0       |  MassAction
R4:  2 X1 -> X1       |  MassAction
R5:  2 X2 -> X2       |  MassAction
# Advance the root key and take a fresh subkey for this run.
key, subkey = rng.split(key)

# Initial copy numbers (assumed species order X0, X1, X2) and time horizon.
x0 = jnp.array([100, 50, 70])
T = 100

sim_results = stx.stochsimsolve(subkey, model, x0, T=T, max_steps=500_000)

# Report the solver's time_overflow flag for this trajectory.
print('Time overflow:\t', sim_results.time_overflow)
Time overflow:   False
plot_abundance_dynamic(
    sim_results,
    time_unit='s',
);  # trailing ';' stops the notebook from echoing the return value

img

Lotka-Volterra Model (Predator-Prey)

# Built-in Lotka-Volterra (predator-prey) generator.
# NOTE(review): alpha is passed as a JAX array while beta/gamma are Python
# floats -- presumably so this rate is a differentiable leaf for
# eqx.filter_grad in the differentiability test; confirm before changing.
lv_model = stx.generators.lotka_volterra_model(
    alpha=jnp.array(0.1),  # prey reproduction rate   (prey -> 2 prey)
    beta=0.001,            # predation rate           (predator + prey -> 2 predator)
    gamma=0.1,             # predator death rate      (predator -> 0)
)
print(lv_model)
R0 (prey_reproduction):  prey -> 2 prey                 |  MassAction
R1 (predation):          predator + prey -> 2 predator  |  MassAction
R2 (predator_death):     predator -> 0                  |  MassAction
# Fresh subkey for the predator-prey run.
key, subkey = rng.split(key)

# Initial populations given by species name.
x0 = {'predator': 100, 'prey': 50}

sim_results = stx.stochsimsolve(subkey, lv_model, x0, T=500, max_steps=50_000)

# Report the solver's time_overflow flag for this trajectory.
print('Time overflow:\t', sim_results.time_overflow)
Time overflow:   False
plot_abundance_dynamic(
    sim_results,
    time_unit='s',
);  # trailing ';' stops the notebook from echoing the return value

img

# Autocorrelation of each species' trajectory, sampled on 1000 lag points.
lags, autocorr1 = stx.analysis.autocorrelation(
    sim_results, species='prey', n_points=1000
)
lags, autocorr2 = stx.analysis.autocorrelation(
    sim_results, species='predator', n_points=1000
)

# Both curves share the lag grid returned by the last call; labels are
# attached to each line and collected by ax.legend().
_, ax = plt.subplots(figsize=(10, 5))
ax.plot(
    lags, autocorr1, label='$R_{X_{\\mathrm{prey}},X_{\\mathrm{prey}}}(\\tau)$'
)
ax.plot(
    lags,
    autocorr2,
    label='$R_{X_{\\mathrm{predator}},X_{\\mathrm{predator}}}(\\tau)$',
)
ax.set(
    xlabel='Lag $\\tau$',
    ylabel='Autocorrelation $R_{X_i,X_i}(\\tau)$',
)
ax.legend()
ax.grid(True)
plt.show()

img

# Cross-correlation in both directions: prey->predator and predator->prey.
lags, cross_corr1 = stx.analysis.cross_correlation(
    sim_results, species1='prey', species2='predator', n_points=1000
)
lags, cross_corr2 = stx.analysis.cross_correlation(
    sim_results, species1='predator', species2='prey', n_points=1000
)

# Both curves share the lag grid returned by the last call; labels are
# attached to each line and collected by ax.legend().
_, ax = plt.subplots(figsize=(10, 5))
ax.plot(
    lags,
    cross_corr1,
    label='$R_{X_{\\mathrm{prey}},X_{\\mathrm{predator}}}(\\tau)$',
)
ax.plot(
    lags,
    cross_corr2,
    label='$R_{X_{\\mathrm{predator}},X_{\\mathrm{prey}}}(\\tau)$',
)
ax.set(
    xlabel='Lag $\\tau$',
    ylabel='Cross-correlation $R_{X_i,X_j}(\\tau)$',
)
ax.legend()
ax.grid(True)
plt.show()

img

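Test differentiability: the loss below simulates the model with the `DifferentiableDirect` solver and returns the mean squared prey-predator cross-correlation; `eqx.filter_grad` then differentiates it with respect to the model's parameters.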

@eqx.filter_jit
def test_loss(network, x0, key):
    """Mean squared prey/predator cross-correlation of one simulated run.

    Simulates ``network`` from ``x0`` with the ``DifferentiableDirect``
    solver (T=500, at most 50,000 steps), computes the prey/predator
    cross-correlation on 1000 lag points, and returns the mean of its
    squared values. Intended to be differentiated with ``eqx.filter_grad``
    with respect to ``network``.

    Args:
        network: reaction network containing 'prey' and 'predator' species.
        x0: initial state, passed straight through to ``stx.stochsimsolve``.
        key: JAX PRNG key for the simulation.

    Returns:
        Scalar JAX array.
    """
    solver = stx.solvers.DifferentiableDirect()
    trajectory = stx.stochsimsolve(
        key, network, x0, T=500, max_steps=50_000, solver=solver
    )

    # Only the correlation values matter here; the lag grid is discarded.
    _, xcorr = stx.analysis.cross_correlation(
        trajectory, species1='prey', species2='predator', n_points=1000
    )
    return jnp.mean(xcorr**2)
# Split the root `key` (not `subkey`, which the predator-prey simulation above
# already consumed) and hand the fresh subkey to the loss. Passing `key` itself
# would reuse it, because the next cell splits `key` again.
key, subkey = rng.split(key)

# Gradient of the loss w.r.t. the whole LV model (x0 is the predator/prey dict
# defined for the LV run); display the component for the rate constant of the
# prey_reproduction reaction.
eqx.filter_grad(test_loss)(lv_model, x0, subkey).prey_reproduction.kinetics.k
Array(0.67341566, dtype=float32, weak_type=True)

SIRS Model

# Built-in SIRS generator. Recovered individuals lose immunity at rate `nu`
# (R -> S), so this is an SIRS model; setting nu=0 disables that reaction and
# reduces it to a plain SIR model.
sir_model = stx.generators.sirs_model(
    beta=0.005,  # transmission rate (I + S -> 2 I)
    gamma=0.02,  # recovery rate (I -> R)
    nu=0.01,  # loss of immunity rate (R -> S)
)
print(sir_model)
R0 (infection):         I + S -> 2 I  |  MassAction
R1 (recovery):          I -> R        |  MassAction
R2 (loss_of_immunity):  R -> S        |  MassAction
# Fresh subkey for the epidemic run.
key, subkey = rng.split(key)

# Initial counts in the model's internal species order (I, R, S), i.e.
# one infected, nobody recovered, 100 susceptible.
x0 = jnp.array([1, 0, 100])

sim_results = stx.stochsimsolve(
    subkey, sir_model, x0, T=300, max_steps=500_000
)

# Report the solver's time_overflow flag for this trajectory.
print('Time overflow:\t', sim_results.time_overflow)
Time overflow:   False
# Display the model's species order; the array-valued x0 used for the SIRS run
# follows this order.
sir_model.species
('I', 'R', 'S')
# Logarithmic x (time) axis.
plot_abundance_dynamic(
    sim_results,
    time_unit='s',
    log_x_scale=True,
);  # trailing ';' stops the notebook from echoing the return value

img

Repressilator

# Built-in three-gene repressilator: each protein (A, B, C) is produced through
# a Hill-repressor term and degraded by mass action.
repressilator = stx.generators.repressilator_model(
    alpha=30.0,   # maximum expression rate
    alpha0=0.0,   # leaky (basal) expression rate
    K=30.0,       # half-saturation constant
    beta=0.1,     # degradation rate
    n=3.0,        # Hill coefficient (cooperativity)
)
print(repressilator)
R0 (gene_A_expression):      0 -> A  |  HillRepressor
R1 (gene_B_expression):      0 -> B  |  HillRepressor
R2 (gene_C_expression):      0 -> C  |  HillRepressor
R3 (protein_A_degradation):  A -> 0  |  MassAction
R4 (protein_B_degradation):  B -> 0  |  MassAction
R5 (protein_C_degradation):  C -> 0  |  MassAction
# Fresh subkey for the repressilator run.
key, subkey = rng.split(key)

# Every species starts at 10 copies.
x0 = 10 * jnp.ones(repressilator.n_species)

sim_results = stx.stochsimsolve(
    subkey,
    repressilator,
    x0,
    T=300.0,
    max_steps=500_000,
)

# Report the solver's time_overflow flag for this trajectory.
print('Time overflow:\t', sim_results.time_overflow)
Time overflow:   False
plot_abundance_dynamic(
    sim_results,
    time_unit='s',
);  # trailing ';' stops the notebook from echoing the return value

img

C2-FFL: Coherent Type 2 Feed-Forward Loop

from stochastix import Reaction, ReactionNetwork
from stochastix.kinetics import HillActivator, HillRR, MassAction
# X regulates Y through a Hill activator; Z's production depends jointly on X
# and Y (HillRR, AND logic). X has no reactions of its own, so its copy number
# stays at whatever the initial state or a controller sets.
c2_ffl_reactions = [
    Reaction(
        '0 -> Y',
        HillActivator(regulator='X', v=50.0, K=20.0, n=3.0),
        name='Y_production',
    ),
    Reaction(
        '0 -> Z',
        HillRR('X', 'Y', v=50.0, K1=40.0, K2=5.0, n1=3.0, n2=3.0, logic='and'),
        name='Z_production',
    ),
    # First-order decay of both regulated species, in the order Y then Z.
    *(
        Reaction(equation, MassAction(k=1.0), name=label)
        for equation, label in (
            ('Y -> 0', 'Y_degradation'),
            ('Z -> 0', 'Z_degradation'),
        )
    ),
]

c2_ffl = ReactionNetwork(c2_ffl_reactions)

print(c2_ffl)
R0 (Y_production):   0 -> Y  |  HillActivator
R1 (Z_production):   0 -> Z  |  HillRR
R2 (Y_degradation):  Y -> 0  |  MassAction
R3 (Z_degradation):  Z -> 0  |  MassAction

Static Inputs

# Three keys: the advanced root key plus one independent subkey per run.
key, subkey_high, subkey_low = rng.split(key, num=3)

# The two runs differ only in the (constant) input X; Y and Z start empty.
# Array form, assuming X, Y, Z species order:
# jnp.array([50.0, 0.0, 0.0]) and jnp.array([5.0, 0.0, 0.0]).
x0_high = {'X': 50.0, 'Y': 0.0, 'Z': 0.0}
x0_low = {'X': 5.0, 'Y': 0.0, 'Z': 0.0}

T = 100.0

sim_high = stx.stochsimsolve(
    subkey_high, c2_ffl, x0_high, T=T, max_steps=500_000
)
sim_low = stx.stochsimsolve(
    subkey_low, c2_ffl, x0_low, T=T, max_steps=500_000
)

# One figure per input level.
_, ax = plt.subplots(figsize=(10, 5))
plot_abundance_dynamic(sim_high, time_unit='s', ax=ax)
ax.set_title('High Input Abundance')

_, ax = plt.subplots(figsize=(10, 5))
plot_abundance_dynamic(sim_low, time_unit='s', ax=ax)
ax.set_title('Low Input Abundance');

img

img

Dynamic Input (Time-Controlled Species)

from stochastix.controllers import Timer

# Timer controller on species X. Each trigger time is paired with one row of
# species_t (see `species_at_triggers` in the repr); presumably X is set to
# that value at that time -- confirm in the stochastix Timer docs.
controlled_species = 'X'
time_triggers = jnp.array([20.0, 60.0, 80.0])

# One row per trigger, one column per controlled species -> shape (3, 1).
species_t = jnp.array([[0.0], [50.0], [15.0]])

timed_controller = Timer(controlled_species, time_triggers, species_t)
timed_controller
Timer(
  controlled_species=('X',),
  time_triggers=(20.0, 60.0, 80.0),
  species_at_triggers=((0.0,), (50.0,), (15.0,))
)
# Fresh subkey for the controlled run. Use it rather than `subkey_high`, which
# the high-input static simulation already consumed; reusing it would make
# this run share randomness with that one.
key, subkey = rng.split(key)

# Same high-input initial state as the static example, given by species name
# so it does not depend on the network's internal species order.
x0_high = dict(X=50.0, Y=0.0, Z=0.0)

# The Timer controller defined above drives X at its trigger times.
sim_ctrl = stx.stochsimsolve(
    subkey,
    c2_ffl,
    x0_high,
    T=100.0,
    controller=timed_controller,
    max_steps=int(5e5),
)
plot_abundance_dynamic(sim_ctrl, time_unit='s');

img