stochsimsolve
stochastix.stochsimsolve
stochsimsolve(
key: ndarray,
network: ReactionNetwork,
x0: ndarray | dict[str, float] | Any,
t0: float = 0.0,
T: float = 3600.0,
max_steps: int = 100000,
solver: AbstractStochasticSolver = DifferentiableDirect(
is_exact_solver=True,
is_pathwise_differentiable=True,
logits_scale=1.0,
),
controller: AbstractController | None = None,
save_trajectory: bool = True,
save_propensities: bool = True,
checkpoint: bool = False,
) -> SimulationResults
Run a stochastic simulation of a reaction network.
The simulation continues until the final time T is reached, no further reactions can occur, or the maximum number of steps is exceeded.
Parameters:

- key (ndarray) – JAX random number generator key.
- network (ReactionNetwork) – The reaction network object containing reactions and their properties.
- x0 (ndarray | dict[str, float] | Any) – Initial state vector (counts). Can be an array, a dictionary mapping species names to their initial counts, a named tuple, or an Equinox module.
- t0 (float, default: 0.0) – Initial time.
- T (float, default: 3600.0) – Final simulation time in seconds.
- max_steps (int, default: 100000) – Maximum number of simulation steps (needed for JIT compilation).
- solver (AbstractStochasticSolver, default: DifferentiableDirect(is_exact_solver=True, is_pathwise_differentiable=True, logits_scale=1.0)) – The stochastic solver to use.
- controller (AbstractController | None, default: None) – External state control during simulation.
- save_trajectory (bool, default: True) – If True, store the full trajectory. If False, store only the initial and final states (shape (2, ...)).
- save_propensities (bool, default: True) – If True, store propensities at each step. If False, set propensities to None.
- checkpoint (bool, default: False) – If True, apply jax.checkpoint (gradient checkpointing) to the simulation step. This reduces memory usage during backpropagation by recomputing intermediate states, at the cost of increased computation time. It is particularly useful for: 1. very long simulations (large max_steps), where storing the gradient tape for every step would cause out-of-memory errors; 2. simulations with complex kinetics (e.g., neural network propensities), where the computational graph per step is large. When save_trajectory=True, checkpointing is especially efficient because recomputation starts from the stored trajectory states.
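The memory/compute trade-off behind the checkpoint flag can be illustrated with plain JAX, independent of stochastix. The sketch below is generic (a toy decay update, not the library's simulation step): wrapping the scanned step in jax.checkpoint makes backpropagation recompute intermediate states instead of storing them, while the gradients themselves are unchanged.

```python
import jax
import jax.numpy as jnp

# Generic sketch (not stochastix internals): wrapping a scan step in
# jax.checkpoint trades recomputation for memory during backprop.
def simulate(x0, rate, n_steps=100, use_checkpoint=False):
    def step(x, _):
        x_new = x * jnp.exp(-rate)  # toy per-step decay update
        return x_new, x_new

    if use_checkpoint:
        step = jax.checkpoint(step)  # recompute activations in the backward pass
    final, _trajectory = jax.lax.scan(step, x0, None, length=n_steps)
    return final

# Gradients agree; only memory usage and compute differ.
g_plain = jax.grad(simulate, argnums=1)(1.0, 0.01, 100, False)
g_ckpt = jax.grad(simulate, argnums=1)(1.0, 0.01, 100, True)
```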
Returns:

- SimulationResults – An object with the following attributes:
  - x: State trajectory over time. Shape (n_timepoints, n_species) if save_trajectory=True, or (2, n_species) if save_trajectory=False.
  - t: Time points corresponding to state changes. Shape (n_timepoints,) if save_trajectory=True, or (2,) if save_trajectory=False.
  - propensities: Reaction propensities at each time step, or None if save_propensities=False.
  - reactions: Index of the reaction that occurred at each step.
  - time_overflow: True if the simulation stopped due to reaching max_steps before time T.
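The attribute layout above can be mirrored with a small stand-in to show how results are typically consumed. The SimulationResults class below is a hypothetical mock for illustration only; the real class is defined by stochastix.

```python
from typing import NamedTuple, Optional
import numpy as np

# Hypothetical stand-in mirroring the documented attributes; the real
# SimulationResults class is provided by stochastix.
class SimulationResults(NamedTuple):
    x: np.ndarray                       # (n_timepoints, n_species)
    t: np.ndarray                       # (n_timepoints,)
    propensities: Optional[np.ndarray]  # None if save_propensities=False
    reactions: np.ndarray               # reaction index per step
    time_overflow: bool                 # True if max_steps hit before T

results = SimulationResults(
    x=np.array([[100, 0], [99, 1], [98, 2]]),
    t=np.array([0.0, 0.4, 1.1]),
    propensities=None,
    reactions=np.array([0, 0]),
    time_overflow=False,
)

final_state = results.x[-1]  # state at the last recorded time point
```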
Performance considerations

This function is automatically JIT-compiled with equinox.filter_jit.
The save_trajectory and save_propensities parameters are static (compile-time
constants), and changing them will trigger re-compilation.

Both modes support backward differentiation through jax.lax.scan. For maximum
performance without differentiation (early stopping), use faststochsimsolve.
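The recompilation behaviour of static flags can be seen with plain jax.jit (a generic sketch, not stochastix code; equinox.filter_jit similarly treats non-array arguments as static). Because the output shape depends on the flag, it must be a compile-time constant, and each flag value compiles its own version of the function.

```python
from functools import partial
import jax
import jax.numpy as jnp

# Generic sketch: the output shape depends on the flag, so the flag must
# be static; each value of save_trajectory triggers a separate compile.
@partial(jax.jit, static_argnames="save_trajectory")
def summarize(x, save_trajectory):
    if save_trajectory:
        return x                        # full trajectory
    return jnp.stack([x[0], x[-1]])     # endpoints only, shape (2,)

traj = jnp.arange(5.0)
full = summarize(traj, save_trajectory=True)   # compiles one version
ends = summarize(traj, save_trajectory=False)  # compiles another
```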
Basic usage
See Running Simulations for more details.
import jax
import stochastix

key = jax.random.key(0)
network = stochastix.ReactionNetwork(...)
x0 = jax.numpy.array(...)
ssa_results = stochastix.stochsimsolve(key, network, x0, T=1000.0)
Memory-efficient simulation
# Only save initial and final states, no propensities
results = stochastix.stochsimsolve(
key, network, x0, T=1000.0,
save_trajectory=False, save_propensities=False
)