faststochsimsolve
This function performs stochastic simulation using jax.lax.while_loop, which allows early termination when the simulation finishes before reaching max_steps. This makes it considerably faster than stochsimsolve for simulations that finish well before max_steps. For the same random key, faststochsimsolve produces exactly the same results as stochsimsolve (the same final state, reached through the same sequence of reactions, although that sequence is not recorded). It does, however, come with some trade-offs:
- It does not support backward differentiation (reverse-mode autodiff), only forward differentiation.
- It does not support saving the full trajectory (only initial and final states are saved).
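The sketch below shows why a while_loop can stop early. It is a toy example in plain JAX, not stochastix's implementation. It runs a Gillespie direct-method simulation of a single decay reaction A -> 0 with propensity k * n. The function name toy_decay_ssa and the model are hypothetical. The loop exits as soon as one of three things happens:

- the time passes T,
- the species is exhausted, or
- max_steps is reached.

Like faststochsimsolve, it keeps only the initial and final states plus a time_overflow flag.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def toy_decay_ssa(key, n0, k, T, max_steps):
    """Hypothetical toy Gillespie simulation of A -> 0 (propensity k * n) with lax.while_loop."""
    t0 = jnp.float32(0.0)
    n_init = jnp.asarray(n0, dtype=jnp.int32)

    def cond(carry):
        _, t, n, step = carry
        # Continue only while time remains, a reaction is possible and the step budget is not spent.
        return (t < T) & (n > 0) & (step < max_steps)

    def body(carry):
        key, t, n, step = carry
        key, sub = jax.random.split(key)
        a = k * n                                    # total propensity (n > 0 inside the loop)
        t_next = t + jax.random.exponential(sub) / a
        fires = t_next <= T                          # an event scheduled after T is discarded
        n = jnp.where(fires, n - 1, n)
        t = jnp.where(fires, t_next, T)              # clamp to T when the next event lies beyond it
        return key, t, n, step + 1

    _, t_end, n_end, steps = lax.while_loop(cond, body, (key, t0, n_init, jnp.int32(0)))
    x = jnp.stack([n_init, n_end])[:, None]          # initial and final states, shape (2, 1)
    t = jnp.stack([t0, t_end])                       # initial and final times, shape (2,)
    # Stopped because the step budget ran out while time remained and reactions were still possible.
    time_overflow = (steps >= max_steps) & (t_end < T) & (n_end > 0)
    return x, t, time_overflow, steps
```

With n0=10 and a generous T, the loop runs only ten iterations even if max_steps is 100000. A fixed-length loop would have to execute every one of the max_steps iterations.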
stochastix.faststochsimsolve
faststochsimsolve(
    key: ndarray,
    network: ReactionNetwork,
    x0: ndarray | dict[str, float] | Any,
    t0: float = 0.0,
    T: float = 3600.0,
    max_steps: int = 100000,
    solver: AbstractStochasticSolver = DifferentiableDirect(
        is_exact_solver=True,
        is_pathwise_differentiable=True,
        logits_scale=1.0,
    ),
    controller: AbstractController | None = None,
    save_propensities: bool = True,
) -> SimulationResults
Fast stochastic simulation using while_loop (stops early, forward differentiation only).
This function performs stochastic simulation using jax.lax.while_loop, which allows
early termination when the simulation finishes before reaching max_steps. This makes
it considerably faster than stochsimsolve for simulations that finish early, but it does NOT
support backward differentiation (reverse-mode autodiff). It also does not support saving the trajectory.
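The forward-only restriction comes from JAX itself. lax.while_loop has a forward-mode (JVP) rule, but JAX cannot transpose a loop with a dynamic number of iterations, so reverse mode is rejected. The toy function below is hypothetical and not part of stochastix. It shows jax.jvp succeeding and jax.grad raising a ValueError on the same loop.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cube_by_loop(theta):
    """Hypothetical toy: computes theta**3 by multiplying inside a lax.while_loop."""
    def cond(carry):
        i, _ = carry
        return i < 3

    def body(carry):
        i, acc = carry
        return i + 1, acc * theta

    _, out = lax.while_loop(cond, body, (0, jnp.ones_like(theta)))
    return out


# Forward mode (JVP) pushes a tangent through the loop: d/dtheta theta**3 = 3 * theta**2.
value, tangent = jax.jvp(cube_by_loop, (2.0,), (1.0,))

# Reverse mode would have to transpose the while_loop, which JAX refuses to do.
try:
    jax.grad(cube_by_loop)(2.0)
    reverse_mode_supported = True
except ValueError:
    reverse_mode_supported = False
```

At theta = 2.0, value is 8 and tangent is 12. The same reasoning explains why a while_loop-based simulator can offer forward differentiation only.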
Parameters:
- key (ndarray) – JAX random number generator key.
- network (ReactionNetwork) – The reaction network object containing reactions and their properties.
- x0 (ndarray | dict[str, float] | Any) – Initial state vector (counts). Can be an array, a dictionary mapping species names to their initial counts, a named tuple, or an Equinox module.
- t0 (float, default: 0.0) – Initial time.
- T (float, default: 3600.0) – Final simulation time in seconds.
- max_steps (int, default: 100000) – Maximum number of simulation steps (safety limit).
- solver (AbstractStochasticSolver, default: DifferentiableDirect(is_exact_solver=True, is_pathwise_differentiable=True, logits_scale=1.0)) – The stochastic solver to use.
- controller (AbstractController | None, default: None) – External state control during simulation.
- save_propensities (bool, default: True) – If True, store the initial and final propensities. If False, set propensities to None.
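When x0 is a dictionary, its entries have to be arranged into an ordered count vector. This is a minimal sketch that assumes the order comes from a list of species names. How stochastix actually derives that order from the ReactionNetwork is not shown here. state_from_dict is a hypothetical helper, not a stochastix function.

```python
import jax.numpy as jnp


def state_from_dict(x0, species):
    """Hypothetical helper: arrange a {species_name: count} dict into a state vector."""
    missing = set(species) - set(x0)
    if missing:
        # Every species in the assumed ordering needs an initial count.
        raise KeyError(f"missing initial counts for: {sorted(missing)}")
    # Position i of the vector holds the count of species[i].
    return jnp.array([x0[name] for name in species], dtype=jnp.float32)
```

For example, state_from_dict({"B": 5.0, "A": 2.0}, ["A", "B"]) yields the vector [2.0, 5.0]. The dictionary's insertion order does not matter, only the species list does.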
Returns:
- SimulationResults – An object with the following attributes:
  - x: Initial and final states only, shape (2, n_species).
  - t: Initial and final times, shape (2,).
  - propensities: Initial and final propensities, shape (2, n_reactions), if save_propensities=True; None otherwise.
  - reactions: None (not tracked in while_loop mode).
  - time_overflow: True if the simulation stopped because it reached max_steps before time T.
This function is automatically JIT-compiled with equinox.filter_jit.
Forward differentiation only.
Supports forward differentiation but does NOT support backward differentiation. Use stochsimsolve if you need backward differentiation.
No trajectory saving.
This function does not save the whole trajectory, only the initial and final states. Use stochsimsolve if you need the full trajectory.
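For contrast, a fixed-length lax.scan can be differentiated in reverse mode and records every step as it goes. This toy sketch assumes that stochsimsolve is built on a fixed-length loop of this kind. The comparison above suggests this, since faststochsimsolve is described as the variant that stops early, but the text does not state it. cube_by_scan and final_only are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cube_by_scan(theta):
    """Hypothetical toy: theta**3 via a fixed-length lax.scan that keeps every intermediate value."""
    def step(acc, _):
        acc = acc * theta
        return acc, acc  # (new carry, value stored in the trajectory)

    final, trajectory = lax.scan(step, jnp.ones_like(theta), xs=None, length=3)
    return final, trajectory


def final_only(theta):
    return cube_by_scan(theta)[0]


grad_value = jax.grad(final_only)(2.0)  # reverse mode works through scan: 3 * 2**2
_, traj = cube_by_scan(2.0)             # the full trajectory [2., 4., 8.] is kept
```

The price of this flexibility is that the scan always runs its full length, even after nothing more can happen.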
Example
import jax
import stochastix
key = jax.random.PRNGKey(0)  # `network` and `x0` are assumed to be defined already
results = stochastix.faststochsimsolve(key, network, x0, T=1000.0)  # fast, no backward differentiation