
stochastix basic usage

This guide will walk you through a simple example of how to use stochastix to define a model, run a stochastic simulation, and visualize the results.

1. Imports

import jax
import jax.numpy as jnp

import stochastix as stx

key = jax.random.PRNGKey(42)

2. Defining a model

In stochastix, a model is defined by creating a ReactionNetwork from a list of Reaction objects. Let's create a Lotka–Volterra predator–prey model with species A (prey) and B (predator).

The reactions are:

  • Prey reproduction: A -> 2 A (rate constant alpha)
  • Predation: A + B -> 2 B (rate constant beta)
  • Predator death: B -> 0 (rate constant gamma; 0 denotes no products)
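For reference, the standard stochastic mass-action propensities of these three reactions are given below, with A and B denoting molecule counts. The mean-field rate equations they reduce to for large copy numbers follow. We assume MassAction follows this textbook convention. No combinatorial factors appear here because no reaction consumes two molecules of the same species.

```latex
\begin{aligned}
a_1(A, B) &= \alpha A, \qquad a_2(A, B) = \beta A B, \qquad a_3(A, B) = \gamma B,\\
\frac{dA}{dt} &= \alpha A - \beta A B, \qquad \frac{dB}{dt} = \beta A B - \gamma B.
\end{aligned}
```

The ODEs are only an approximation for intuition. A stochastic simulation samples the jump process defined by the propensities a_1, a_2, a_3.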

We can define this model by creating a Reaction for each process with MassAction kinetics.

from stochastix.kinetics import MassAction
from stochastix import Reaction, ReactionNetwork

alpha = 1.1
beta = 0.4
gamma = 0.4

reactions = [
    Reaction("A -> 2 A", MassAction(k=alpha), name="prey_reproduction"),
    Reaction("A + B -> 2 B", MassAction(k=beta), name="predation"),
    Reaction("B -> 0", MassAction(k=gamma), name="predator_death"),
]

network = ReactionNetwork(reactions)

The species A and B are automatically discovered from the reaction strings.
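To make species discovery concrete, here is a minimal sketch that parses reaction strings in this arrow syntax into a species list and a stoichiometry matrix. It is an illustration only, not stochastix's parser. It assumes that terms are separated by `+`, that a coefficient is separated from its species name by a space, and that `0` denotes the empty set. Species are sorted lexicographically, matching the array order stochastix uses for x0.

```python
from collections import Counter


def parse_side(side: str) -> Counter:
    """Parse one side of a reaction string, e.g. "A + B", "2 B" or "0"."""
    counts = Counter()
    side = side.strip()
    if side == "0":  # "0" denotes no reactants/products
        return counts
    for term in side.split("+"):
        parts = term.split()
        if len(parts) == 2:  # "2 B" -> coefficient 2, species "B"
            coeff, name = int(parts[0]), parts[1]
        else:  # "B" -> implicit coefficient 1
            coeff, name = 1, parts[0]
        counts[name] += coeff
    return counts


def parse_reaction(rxn: str) -> tuple[Counter, Counter]:
    reactants, products = rxn.split("->")
    return parse_side(reactants), parse_side(products)


def discover_species(rxns: list[str]) -> list[str]:
    """Collect every species name appearing in the reactions, sorted."""
    names = set()
    for rxn in rxns:
        reac, prod = parse_reaction(rxn)
        names |= set(reac) | set(prod)
    return sorted(names)


def stoichiometry(rxns: list[str], species: list[str]) -> list[list[int]]:
    """Net change of each species (columns) caused by each reaction (rows)."""
    matrix = []
    for rxn in rxns:
        reac, prod = parse_reaction(rxn)
        matrix.append([prod[s] - reac[s] for s in species])  # missing keys count as 0
    return matrix


rxns = ["A -> 2 A", "A + B -> 2 B", "B -> 0"]
species = discover_species(rxns)
print(species)                       # ['A', 'B']
print(stoichiometry(rxns, species))  # [[1, 0], [-1, 1], [0, -1]]
```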

3. Running a simulation

Now that we have a network, we can run a stochastic simulation with the stochsimsolve function. At minimum, it takes a JAX random key, the network, an initial state x0, and the total simulation time T. If not specified, the initial time defaults to t0 = 0.0.

key, subkey = jax.random.split(key)

# If passing an array, use lexicographic order of species names: ('A', 'B')
x0 = jnp.array([50.0, 5.0])
T = 100.0  # total simulation time

ssa_results = stx.stochsimsolve(subkey, network, x0, T=T)

Initial state

Instead of an array, you can pass a PyTree (e.g., a dict) whose leaves are named after the species. Extra leaves are ignored, and ssa_results.x mirrors the structure of x0. If you pass an array, its entries must follow the lexicographic order of the species names (here ('A', 'B')).
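The matching rule can be mimicked in plain Python. The helper `to_array` below is hypothetical and not part of stochastix. It shows three things: named leaves are matched to species by name, extra leaves are dropped, and the result follows the lexicographic species order regardless of the order in which fields were declared.

```python
from typing import NamedTuple


class State(NamedTuple):
    B: float  # declared out of sorted order on purpose
    A: float


def to_array(x0, species: list[str]) -> list[float]:
    """Flatten a dict or NamedTuple into a list ordered like `species`.

    Leaves whose names are not in `species` are ignored.
    """
    mapping = x0._asdict() if hasattr(x0, "_asdict") else dict(x0)
    return [float(mapping[name]) for name in species]


species = sorted(["B", "A"])  # lexicographic order -> ['A', 'B']
print(to_array({"A": 50.0, "B": 5.0, "extra": 1.0}, species))  # [50.0, 5.0]
print(to_array(State(B=5.0, A=50.0), species))                 # [50.0, 5.0]
```

Python's sorted compares code points, so uppercase names sort before lowercase ones (for example, "B" comes before "a"). We assume stochastix orders species the same way.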

For example, you can use a dict:

x0 = {"A": 50.0, "B": 5.0}
ssa_results = stx.stochsimsolve(subkey, network, x0, T=T)
# ssa_results.x is a dict with the same keys as x0

or a NamedTuple:

from typing import NamedTuple

class State(NamedTuple):
    A: float
    B: float

x0 = State(A=50.0, B=5.0)
ssa_results = stx.stochsimsolve(subkey, network, x0, T=T)

Overflow handling

For JIT compilation, the maximum number of steps must be fixed in advance via max_steps. If a simulation hits max_steps before reaching T, ssa_results.time_overflow is True. In that case, increase max_steps and rerun.

Once the total simulation time T is reached, the remaining steps are no-ops and the results are padded with the final state. Setting max_steps much larger than necessary can hurt performance, because memory is allocated for every step.
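The fixed step budget can be illustrated with a pure-Python Gillespie simulation of a single decay reaction, B -> 0. This is a hypothetical sketch, not stochastix code. The function `simulate_death` is made up for this guide, and its padding convention (repeating the last recorded time and state) is our assumption about what "padded with the final state" looks like.

```python
import random


def simulate_death(b0, gamma, T, max_steps, seed=0):
    """Gillespie simulation of B -> 0 into fixed-size buffers.

    Returns (t, x, time_overflow). t and x always hold max_steps + 1
    entries. Once the next event would fall after T (or no reaction can
    fire), the remaining slots repeat the last recorded time and state.
    """
    rng = random.Random(seed)
    t, x = [0.0], [b0]
    done = False
    for _ in range(max_steps):
        cur_t, cur_x = t[-1], x[-1]
        if not done:
            rate = gamma * cur_x
            if rate == 0.0:
                done = True  # extinct: the state can no longer change
            else:
                tau = rng.expovariate(rate)  # waiting time to the next decay
                done = cur_t + tau > T  # next event would land after T
        if done:
            t.append(cur_t)  # padding step: a no-op
            x.append(cur_x)
        else:
            t.append(cur_t + tau)
            x.append(cur_x - 1)
    # Still not done after the whole budget: T was never reached.
    time_overflow = not done
    return t, x, time_overflow


t, x, overflow = simulate_death(b0=1000, gamma=1.0, T=100.0, max_steps=10)
print(overflow)  # True: 10 steps are far too few for 1000 molecules
```

In this sketch, when the flag is True you would simply call it again with a larger max_steps, which mirrors the advice above.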

4. Choosing a solver

stochastix provides multiple solvers. You can pass a solver to stochsimsolve via the solver argument. DirectMethod is used by default.

  • Exact (event-by-event): DirectMethod() (default), FirstReactionMethod()
  • Approximate: TauLeaping(epsilon=0.03)

# Use FirstReactionMethod explicitly
ssa_results = stx.stochsimsolve(subkey, network, x0, T=T, solver=stx.FirstReactionMethod())

# Or an approximate solver
ssa_results = stx.stochsimsolve(subkey, network, x0, T=T, solver=stx.TauLeaping(epsilon=0.03))
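To see how the solvers differ, here is a pure-Python sketch of one step of each algorithm for the Lotka–Volterra model above. It is not stochastix's implementation, and all function names are hypothetical. The tau-leap uses a fixed step `tau`. We assume that TauLeaping instead chooses its step from the `epsilon` error parameter, but that is not shown here. The Direct Method and First Reaction Method sample the same distribution of (waiting time, reaction). The Direct Method uses two random numbers per step, while the First Reaction Method uses one per reaction.

```python
import math
import random

# Lotka-Volterra model from above; columns follow the species order (A, B).
STOICH = [(1, 0), (-1, 1), (0, -1)]


def propensities(x, alpha=1.1, beta=0.4, gamma=0.4):
    a, b = x
    return [alpha * a, beta * a * b, gamma * b]


def direct_method_step(x, rng):
    """Direct Method: draw one waiting time, then pick a reaction."""
    props = propensities(x)
    total = sum(props)
    if total == 0.0:
        return math.inf, None  # absorbing state: nothing can fire
    tau = rng.expovariate(total)  # Exp(total) waiting time
    threshold = rng.random() * total
    cumulative = 0.0
    for j, p in enumerate(props):
        cumulative += p
        if threshold < cumulative:
            return tau, j  # reaction j chosen with probability p / total
    # Floating-point round-off guard: fall back to the last possible reaction.
    return tau, max(j for j, p in enumerate(props) if p > 0)


def first_reaction_step(x, rng):
    """First Reaction Method: one tentative time per reaction, keep the earliest."""
    times = [rng.expovariate(p) if p > 0 else math.inf for p in propensities(x)]
    j = min(range(len(times)), key=times.__getitem__)
    return (times[j], j) if times[j] < math.inf else (math.inf, None)


def poisson(lam, rng):
    """Knuth's Poisson sampler; adequate for the small means of a short leap."""
    limit, k, prod = math.exp(-lam), 0, rng.random()
    while prod > limit:
        k += 1
        prod *= rng.random()
    return k


def tau_leap_step(x, tau, rng):
    """Fixed-size tau-leap: fire each reaction a Poisson number of times."""
    counts = [poisson(p * tau, rng) for p in propensities(x)]
    new_x = [x[i] + sum(c * s[i] for c, s in zip(counts, STOICH)) for i in range(len(x))]
    return [max(v, 0) for v in new_x]  # crude clipping to keep counts non-negative


rng = random.Random(0)
x = [50, 5]
tau, j = direct_method_step(x, rng)
x = [xi + dxi for xi, dxi in zip(x, STOICH[j])]  # apply the chosen reaction
```

The exact methods advance one event at a time. The tau-leap fires many reactions per step, so it trades accuracy for speed when propensities are large.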

5. SimulationResults essentials

stochsimsolve returns a SimulationResults object with:

  • x: state trajectory over time.
  • t: time points corresponding to state changes.
  • reactions: index of the reaction that occurred at each step.
  • propensities: reaction propensities at each time step.
  • time_overflow: True if the simulation stopped because it reached max_steps before time T.
  • species: names of the species in the simulation.
  • clean(): removes the padded steps.
  • interpolate(t_vec): re-samples the states on a provided time grid (useful for plotting or downstream analysis).
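Because a stochastic trajectory is piecewise constant between events, re-sampling it onto a grid amounts to looking up the most recent event. The sketch below mimics the behavior described for clean() and interpolate(t_vec) on plain Python lists. Two parts are our assumptions, not a description of stochastix internals: detecting padding through repeated trailing entries, and using previous-value (step) interpolation.

```python
from bisect import bisect_right


def clean(t, x):
    """Drop trailing padded entries, i.e. repeats of the final (time, state)."""
    end = len(t)
    while end > 1 and t[end - 1] == t[end - 2] and x[end - 1] == x[end - 2]:
        end -= 1
    return t[:end], x[:end]


def interpolate(t, x, t_vec):
    """Previous-value (step) interpolation of the trajectory onto t_vec."""
    out = []
    for tq in t_vec:
        i = bisect_right(t, tq) - 1  # index of the last event at or before tq
        out.append(x[max(i, 0)])
    return out


t = [0.0, 0.4, 1.3, 1.3, 1.3]  # the last two entries are padding
x = [10, 9, 8, 8, 8]
tc, xc = clean(t, x)
print(tc, xc)                                     # [0.0, 0.4, 1.3] [10, 9, 8]
print(interpolate(tc, xc, [0.0, 0.5, 1.0, 2.0]))  # [10, 9, 9, 8]
```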

Batched results generated by vmapping over stochsimsolve are indexable directly: results[i].
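Vmapping adds a leading batch axis to every array in the results. Indexing a batch therefore means indexing each field with the same i. The toy container below illustrates this idea with plain lists. `Results` is hypothetical and much simpler than stochastix's SimulationResults.

```python
from dataclasses import dataclass, fields


@dataclass
class Results:
    t: list  # nested as (batch, steps)
    x: list  # nested as (batch, steps)

    def __getitem__(self, i):
        # Select trajectory i from every field at once.
        return Results(**{f.name: getattr(self, f.name)[i] for f in fields(self)})


batched = Results(
    t=[[0.0, 0.5, 1.0], [0.0, 0.2, 0.9]],
    x=[[10, 9, 8], [10, 11, 10]],
)
first = batched[0]
print(first.t, first.x)  # [0.0, 0.5, 1.0] [10, 9, 8]
```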

6. Visualizing the results

We can now visualize the results of the simulation. The plot_abundance_dynamic function is a handy utility for this. It automatically discovers the species in the results and plots them as a function of time.

stx.plot_abundance_dynamic(ssa_results)

This generates a plot of the number of molecules of each species over time. Rates are assumed to be in units of 1/s, so the default base time unit (the base_time_unit argument) is seconds. To display time in a different unit, pass time_unit, e.g. plot_abundance_dynamic(..., time_unit='m'), and the time axis is converted automatically. If your rates are expressed in other units, set base_time_unit to match, e.g. plot_abundance_dynamic(..., base_time_unit='m'). Available time units are 'ms', 's', 'm', 'h', and 'd'.
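Converting between base_time_unit and time_unit is a fixed rescaling of the time axis. Below is a hypothetical helper, not the stochastix function, that assumes 'm' means minutes and 'd' means days.

```python
# Seconds per unit; 'm' is taken to mean minutes.
SECONDS_PER_UNIT = {"ms": 1e-3, "s": 1.0, "m": 60.0, "h": 3600.0, "d": 86400.0}


def convert_time(t: float, base_time_unit: str = "s", time_unit: str = "s") -> float:
    """Express a time measured in base_time_unit in time_unit instead."""
    return t * SECONDS_PER_UNIT[base_time_unit] / SECONDS_PER_UNIT[time_unit]


print(convert_time(2.0, "h", "m"))    # 120.0
print(convert_time(100.0, "s", "m"))  # about 1.67
```

With rates in 1/s, the T = 100.0 used earlier corresponds to 100 seconds, or about 1.67 minutes on a time_unit='m' axis.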