Optimization#

Experimental Module

This module is experimental and may change or be removed in future versions. Some functionality is not well tested and may not be entirely reliable.

Losses#

stochastix.utils.optimization.reinforce_loss #

reinforce_loss(stop_returns_grad: bool = True)

Create a loss function for the REINFORCE algorithm.

Parameters:

  • stop_returns_grad (bool, default: True ) –

    If True, gradients will not be computed for the returns. This is the standard formulation of REINFORCE.

Returns:

  • A loss function that computes the REINFORCE loss.

The returned loss function has the signature: loss(model, ssa_results, returns, baseline=0.0)

Parameters:

  • model

    The model being trained. It must have a log_prob method (e.g., ReactionNetwork) or a .network attribute with a log_prob method (e.g., StochasticModel).

  • ssa_results

    The output from a stochastic simulation.

  • returns

    The returns for each step of the trajectory, typically discounted cumulative rewards.

  • baseline

    A baseline to subtract from the returns to reduce variance. It should not have gradients with respect to the policy parameters.

Returns:

  • The computed REINFORCE loss as a scalar.

Rewards#

stochastix.utils.optimization.neg_final_state_distance #

neg_final_state_distance(species=0, distance='L1')

Compute the negative distance from a target final state.

This is useful for creating a reward function where the goal is to minimize the distance to a target state at the end of the simulation.

Parameters:

  • species

    The index or name of the species to track.

  • distance

    The distance metric, either 'L1' or 'L2'.

Returns:

  • A function that takes SimulationResults and a target state value and returns a reward vector. The reward is non-zero only at the final step of the simulation.
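The returned reward vector is zero everywhere except the last step, where it holds the negative distance to the target. A minimal sketch of that behavior for a single-species trajectory, assuming a plain array of states (`neg_final_state_distance_sketch` is illustrative, not the library implementation):

```python
import numpy as np

def neg_final_state_distance_sketch(states, target, distance="L1"):
    # states: 1D trajectory of one species; reward only at the final step.
    diff = states[-1] - target
    final = -abs(diff) if distance == "L1" else -diff**2
    rewards = np.zeros_like(states, dtype=float)
    rewards[-1] = final
    return rewards

r = neg_final_state_distance_sketch(np.array([0.0, 2.0, 5.0]), target=3.0)
# r is zero everywhere except the last entry, which is -2.0
```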

stochastix.utils.optimization.steady_state_distance #

steady_state_distance(from_t=0.0, species=0)

Compute the distance from a target steady state.

Parameters:

  • from_t

    The time from which to start computing the distance.

  • species

    The index or name of the species to track.

Returns:

  • A function that takes SimulationResults and a target steady-state value and returns the absolute distance from the target.
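The core operation is restricting the trajectory to times at or after `from_t` and measuring the absolute deviation from the target. A minimal sketch on plain arrays (the name and exact return shape are illustrative assumptions, not the library API):

```python
import numpy as np

def steady_state_distance_sketch(times, states, target, from_t=0.0):
    # Absolute distance from the target, restricted to times >= from_t.
    mask = times >= from_t
    return np.abs(states[mask] - target)

d = steady_state_distance_sketch(
    times=np.array([0.0, 1.0, 2.0, 3.0]),
    states=np.array([0.0, 1.0, 4.0, 5.0]),
    target=4.0,
    from_t=1.5,
)
# → [0.0, 1.0]
```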

stochastix.utils.optimization.rewards_from_state_metric #

rewards_from_state_metric(metric_fn, metric_type='cost')

Generate a reward function from the differences of a state metric.

This is a general-purpose utility to create reward functions. It takes a metric function that computes a value at each time step and returns the difference between consecutive values as the reward.

Parameters:

  • metric_fn

    A function that takes SimulationResults and returns a vector of metric values, one for each time step.

  • metric_type

    Either 'cost' or 'reward'. If 'cost', the reward is the negative difference of the metric. If 'reward', it is the positive difference.

Returns:

  • A function that takes SimulationResults and computes a reward vector.
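The reward at each step is the change in the metric between consecutive steps, negated for costs so that decreasing a cost yields positive reward. A minimal sketch on a precomputed metric vector (padding the first step with a zero difference is an assumption of this sketch, not necessarily the library's convention):

```python
import numpy as np

def rewards_from_metric_sketch(metric_values, metric_type="cost"):
    # Reward at step t is the change in the metric from step t-1 to t,
    # negated when the metric is a cost (lower is better).
    diffs = np.diff(metric_values, prepend=metric_values[0])
    return -diffs if metric_type == "cost" else diffs

rewards = rewards_from_metric_sketch(np.array([3.0, 2.0, 2.0, 0.0]))
# → [0.0, 1.0, 0.0, 2.0]  (cost dropped by 1, then by 2)
```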

Gradients#

stochastix.utils.optimization.grad.gradfd #

gradfd(
    func: Callable[..., ndarray], epsilon: float = 1e-05
) -> Callable[..., Any]

Compute the gradient of a function using first-order central finite differences.

NOTE: only supports differentiation with respect to the first argument of func.

Parameters:

  • func (Callable[..., ndarray]) –

    The function to differentiate. Should take a PyTree of arrays as input and return a scalar.

  • epsilon (float, default: 1e-05 ) –

    The step size to use for finite differences.

Returns:

  • Callable[..., Any]

    A function that takes the same arguments as func and returns the gradient of func at that point as a PyTree with the same structure as the input.
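Central differences approximate each partial derivative as (f(x + εeᵢ) − f(x − εeᵢ)) / 2ε, one coordinate at a time. A minimal numpy sketch of the technique for a flat array input (the library version additionally handles PyTrees and extra arguments):

```python
import numpy as np

def gradfd_sketch(func, epsilon=1e-5):
    # Central differences: df/dx_i ≈ (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps).
    def grad(x, *args):
        x = np.asarray(x, dtype=float)
        g = np.zeros_like(x)
        for i in range(x.size):
            e = np.zeros_like(x)
            e.flat[i] = epsilon
            g.flat[i] = (func(x + e, *args) - func(x - e, *args)) / (2 * epsilon)
        return g
    return grad

grad_f = gradfd_sketch(lambda x: np.sum(x**2))
g = grad_f(np.array([1.0, -2.0]))  # ≈ [2.0, -4.0]
```

Note the cost: one coordinate-pair of function evaluations per parameter, which is why SPSA (below) is preferred when the parameter count is large.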

stochastix.utils.optimization.grad.gradspsa #

gradspsa(
    func: Callable[..., ndarray],
    epsilon: float = 0.001,
    num_samples: int = 20,
    *,
    split_first_arg_key: bool = True,
) -> Callable[..., Any]

Compute the gradient of a function using SPSA (Simultaneous Perturbation Stochastic Approximation).

NOTE: only supports differentiation with respect to the first argument of func.

Parameters:

  • func (Callable[..., ndarray]) –

    The function to differentiate. Should take a PyTree of arrays as input and return a scalar.

  • epsilon (float, default: 0.001 ) –

    The perturbation size to use for SPSA.

  • num_samples (int, default: 20 ) –

    Number of SPSA samples to average over.

  • split_first_arg_key (bool, default: True ) –

    If True, assumes the first positional argument passed to the returned gradient function is a PRNGKey and splits it into num_samples independent keys, using one per SPSA sample. The same per-sample key is used for both + and - perturbations (common random numbers) to reduce variance. If False, the exact same arguments are used for all SPSA samples.

Returns:

  • Callable[..., Any]

    A function that takes the same arguments as func (plus a PRNGKey) and returns the gradient of func at that point as a PyTree with the same structure as the input.
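Unlike coordinate-wise finite differences, SPSA perturbs all parameters simultaneously along a random ±1 direction, so each sample costs only two function evaluations regardless of dimension. A minimal numpy sketch of the estimator (the library version works on PyTrees and threads a JAX PRNGKey per sample; the seeded `default_rng` here is a stand-in for that):

```python
import numpy as np

def gradspsa_sketch(func, epsilon=1e-3, num_samples=20, seed=0):
    # SPSA: perturb all coordinates at once with a random +/-1 direction delta;
    # each sample estimates the gradient as
    # (f(x + eps*delta) - f(x - eps*delta)) / (2*eps*delta), averaged over samples.
    rng = np.random.default_rng(seed)
    def grad(x, *args):
        x = np.asarray(x, dtype=float)
        est = np.zeros_like(x)
        for _ in range(num_samples):
            delta = rng.choice([-1.0, 1.0], size=x.shape)
            diff = func(x + epsilon * delta, *args) - func(x - epsilon * delta, *args)
            est += diff / (2 * epsilon * delta)
        return est / num_samples
    return grad

grad_f = gradspsa_sketch(lambda x: np.sum(x**2), num_samples=2000)
g = grad_f(np.array([1.0, -2.0]))  # ≈ [2.0, -4.0] in expectation
```

Reusing the same `delta` (and, in the library, the same per-sample key) for the + and − evaluations is the common-random-numbers trick mentioned under `split_first_arg_key`: it cancels shared simulation noise between the two evaluations.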

Utils#

stochastix.utils.optimization.dataloader #

dataloader(
    key: ndarray, arrays: list[ndarray], batch_size: int
)

Create a generator function that yields batches of data.

This function creates an infinite generator that repeatedly yields batches of data from the provided arrays.

Parameters:

  • key (ndarray) –

    A JAX random key for shuffling the data.

  • arrays (list[ndarray]) –

    A list of arrays, each with the same size in the first dimension.

  • batch_size (int) –

    The size of each batch.

Yields:

  • A tuple of batched arrays.

Raises:

  • ValueError

    If the arrays do not all have the same size in the first dimension.
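The generator shuffles once per pass and then slices aligned batches from every array. A minimal numpy sketch of the same loop, assuming a numpy `Generator` in place of a JAX key (`dataloader_sketch` is illustrative, not the library implementation):

```python
import numpy as np

def dataloader_sketch(rng, arrays, batch_size):
    # Infinite generator: reshuffle each epoch, then yield aligned batches.
    n = arrays[0].shape[0]
    if any(a.shape[0] != n for a in arrays):
        raise ValueError("all arrays must share the same first dimension")
    while True:
        perm = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            yield tuple(a[idx] for a in arrays)

rng = np.random.default_rng(0)
x = np.arange(10)
y = np.arange(10) * 2
loader = dataloader_sketch(rng, [x, y], batch_size=4)
bx, by = next(loader)  # batches stay aligned: by == bx * 2
```

When `n` is not a multiple of `batch_size`, the last batch of each epoch is smaller; dropping or padding it is a design choice this sketch leaves out.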

stochastix.utils.optimization.discounted_returns #

discounted_returns(
    rewards: ndarray, GAMMA: float = 0.9
) -> jnp.ndarray

Calculate the discounted returns for a sequence of rewards.

Parameters:

  • rewards (ndarray) –

    A 1D array of rewards.

  • GAMMA (float, default: 0.9 ) –

    The discount factor.

Returns:

  • ndarray

    A 1D array of discounted returns, where each element represents the cumulative discounted return from that point forward.
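The returns satisfy the recursion G_t = r_t + γ·G_{t+1}, computed with a backward pass over the rewards. A minimal numpy sketch of that scan (the library version presumably uses a JAX scan; this loop shows the arithmetic):

```python
import numpy as np

def discounted_returns_sketch(rewards, gamma=0.9):
    # G_t = r_t + gamma * G_{t+1}, computed by scanning backwards.
    returns = np.zeros_like(rewards, dtype=float)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

g = discounted_returns_sketch(np.array([1.0, 0.0, 1.0]), gamma=0.5)
# → [1.25, 0.5, 1.0]
```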