Getting Started
stochastix is a JAX-based library for stochastic simulation of chemical reaction networks.
It provides a simple and flexible API for defining models, simulating them using the Gillespie algorithm (and its variants), and optimizing their parameters with modern gradient-based methods.
Key features
- JAX-powered: Built on JAX and Equinox. Models are PyTrees, enabling jit, vmap, grad, and GPU/TPU acceleration.
- Flexible model definition: Compose networks from Reaction and ReactionNetwork with built-in kinetics (MassAction, Hill, Michaelis–Menten) and NN-based kinetics.
- Exact, approximate, and differentiable SSA: DirectMethod, FirstReactionMethod, TauLeaping, plus differentiable variants (DifferentiableDirect, DifferentiableFirstReaction, DGA).
- Deterministic and CLE support: ODE (vector_field, diffrax_ode_term) and CLE/SDE (drift_fn, noise_coupling, diffrax_sde_term) compatible with Diffrax.
- Controllers: Time-based interventions (e.g., Timer) to manipulate species during simulations.
- Likelihood and RL utilities: ReactionNetwork.log_prob for exact trajectories and helpers for REINFORCE-style training.
- Analysis and visualization: Differentiable autocorrelation, cross-correlation, differentiable histograms/MI, and plotting utilities.
Installation
GPU support
Like every JAX-based library, stochastix relies on JAX for hardware acceleration. To run on a GPU or another accelerator, install the matching JAX build first. If JAX is not already present, the standard stochastix installation pulls in the CPU-only version.
Please refer to the JAX installation guide for the latest guidelines.
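A quick way to confirm which backend JAX picked up after installation is to query it directly. The sketch below relies on jax.default_backend() and jax.devices(), both part of JAX's public API. The helper name jax_backend_summary is hypothetical, and the import guard is only there so the snippet also runs where JAX is absent.

```python
import importlib.util


def jax_backend_summary():
    """Return a one-line description of the JAX backend, if JAX is installed."""
    # Hypothetical helper: guard the import so the check never raises ImportError.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    # default_backend() is a string such as "cpu", "gpu", or "tpu";
    # devices() lists the devices visible to that backend.
    return f"backend={jax.default_backend()}, devices={jax.devices()}"


print(jax_backend_summary())
```

If the reported backend is cpu on a machine with an accelerator, the installed JAX build is most likely the CPU-only one.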
- pip
To install the package and core dependencies:
pip install stochastix
or directly from the repository:
pip install git+https://github.com/fmottes/stochastix.git
Note: to run the example Jupyter notebooks, install the optional dependencies (the quotes keep shells such as zsh from interpreting the brackets):
pip install "stochastix[notebooks]"
- uv
You can add the package to your project dependencies with:
uv add stochastix
For all other uv installation options, see the uv docs.
Quick Example
A basic simulation of a simple reaction chain with Gillespie's direct method:
import jax
import jax.numpy as jnp
import stochastix as stx
from stochastix.kinetics import MassAction

# simple reaction chain with mass-action rates
network = stx.ReactionNetwork([
    stx.Reaction("0 -> X", MassAction(k=0.01)),
    stx.Reaction("X -> Y", MassAction(k=0.002)),
])

x0 = jnp.array([0, 0])  # initial conditions [X, Y]
sim_key = jax.random.PRNGKey(0)  # key for the JAX random number generator

# solve with the direct method from t0=0s to t1=100s
sim_results = stx.stochsimsolve(sim_key, network, x0, T=100.0)
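To make concrete what the direct method does for this network, here is a library-independent sketch in plain Python (standard library only). It is not stochastix's implementation and does not use its API: the function name direct_method and its parameters are hypothetical, and it assumes unit volume, so the propensity of the zero-order reaction equals its rate constant.

```python
import random


def direct_method(k_birth, k_convert, x0, t_end, seed=0):
    """Simulate 0 -> X (propensity k_birth) and X -> Y (propensity k_convert * X)."""
    rng = random.Random(seed)
    t = 0.0
    x, y = x0
    times, states = [t], [(x, y)]
    while True:
        # Propensities of the two reactions in the current state.
        a_birth = k_birth
        a_convert = k_convert * x
        a_total = a_birth + a_convert
        if a_total <= 0.0:
            break  # no reaction can fire any more
        # The waiting time to the next event is exponential with rate a_total.
        t += rng.expovariate(a_total)
        if t > t_end:
            break  # the next event would fall outside the simulated window
        # Choose the reaction with probability proportional to its propensity.
        if rng.random() * a_total < a_birth:
            x += 1  # 0 -> X
        else:
            x -= 1  # X -> Y
            y += 1
        times.append(t)
        states.append((x, y))
    return times, states


times, states = direct_method(k_birth=0.01, k_convert=0.002, x0=(0, 0), t_end=100.0)
print(f"{len(times) - 1} reactions fired; final state [X, Y] = {list(states[-1])}")
```

Because both reactions are at most first order, the mean follows the rate equation dX/dt = k_birth - k_convert * X exactly. With the rates above, the expected X at t = 100 is about 0.9 and the expected X + Y is k_birth * t = 1, so a single trajectory typically contains only a handful of events.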
Diving deeper
- Basic Usage: quick examples of basic library functionalities.
- User Guide: more detailed examples and explanations of the library features.
- Example notebooks: Miscellaneous examples in Jupyter notebook format.
API Reference
The full API reference is available here.
Citation
If you use this software, please cite the paper:
Gradient-based optimization of exact stochastic kinetic models. Francesco Mottes, Qian-Ze Zhu, Michael P. Brenner. arXiv:2601.14183 (2026).
@article{mottes2026gradient,
  title={Gradient-based optimization of exact stochastic kinetic models},
  author={Mottes, Francesco and Zhu, Qian-Ze and Brenner, Michael P.},
  journal={arXiv preprint arXiv:2601.14183},
  year={2026}
}
You can also use the "Cite this repository" button on GitHub to get the citation in various formats.