State Manipulation

State conversion

The following functions convert between generic PyTree initial-state formats and the flat array format.


stochastix.pytree_to_state

pytree_to_state(
    tree: Any, species: tuple[str, ...]
) -> jax.Array

Converts a PyTree or another initial-state format to a flat JAX array.

This function processes an initial state tree which can be a dictionary, an object with attributes (like a named tuple or an Equinox module), or an array-like object, and converts it into a flat JAX array of species counts. The order of species in the output array is determined by species.

Parameters:

  • tree (Any) –

    The initial state. Can be a PyTree (dictionary, custom object) with leaves named after species, or an array-like object.

  • species (tuple[str, ...]) –

    The species names, in the order they should appear in the output array.

  • dtype

    The data type for the output array.

Returns:

  • Array

    A 1D JAX array representing the state vector, ordered according to species.
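
The ordering rule described above can be illustrated with a small pure-Python stand-in. This is a sketch only, not the stochastix implementation: `pytree_to_state_sketch` is a hypothetical helper, it returns a plain list where the real function returns a JAX array, and it assumes that array-like input is already in species order.

```python
from collections import namedtuple
from typing import Any


def pytree_to_state_sketch(tree: Any, species: tuple) -> list:
    """Hypothetical stand-in showing how species order drives the output order."""
    if isinstance(tree, dict):
        # Dictionary: look up each species by key, in the requested order.
        values = [tree[name] for name in species]
    elif all(hasattr(tree, name) for name in species):
        # Object with attributes named after species (e.g. a named tuple).
        values = [getattr(tree, name) for name in species]
    else:
        # Array-like: assumed to be in species order already.
        values = list(tree)
    return [float(v) for v in values]


species = ("A", "B")
State = namedtuple("State", ["B", "A"])  # field order differs from species order

from_dict = pytree_to_state_sketch({"B": 5, "A": 10}, species)
from_tuple = pytree_to_state_sketch(State(B=5, A=10), species)
from_list = pytree_to_state_sketch([10, 5], species)
```

All three inputs yield the same vector, with A first and B second, because the output order follows `species` rather than the order in which the input stores its values.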


stochastix.state_to_pytree

state_to_pytree(
    template: Any,
    species: tuple[str, ...],
    x_trajectory: Array,
) -> typing.Any

Converts a flat state trajectory array back into a PyTree.

This function reconstructs a PyTree with the same structure as the original initial state template. For each leaf in template that corresponds to a species, it replaces the initial value with the full time-series trajectory from x_trajectory. Leaves that do not correspond to species are preserved.

If template was originally an array-like object, the trajectory is returned as is.

Parameters:

  • template (Any) –

    The original initial state PyTree, used as a template for the output.

  • species (tuple[str, ...]) –

    The species names to substitute in the output.

  • x_trajectory (Array) –

    A 2D array of shape (n_timepoints, n_species) representing the state trajectory.

Returns:

  • Any

    A PyTree of the same structure as template, but with species leaves replaced by their trajectories, or the original x_trajectory if template was array-like.
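
The reverse mapping can be sketched the same way. The helper below, `state_to_pytree_sketch`, is hypothetical and is not the stochastix implementation: it handles only dictionary templates, treats every other template as array-like, and uses nested lists in place of a 2D JAX array.

```python
from typing import Any


def state_to_pytree_sketch(template: Any, species: tuple, x_trajectory: list) -> Any:
    """Hypothetical stand-in: put each species column back under its name."""
    if not isinstance(template, dict):
        # Array-like template: the trajectory is returned as is.
        return x_trajectory
    out = dict(template)  # shallow copy, so non-species leaves are preserved
    for column, name in enumerate(species):
        if name in out:
            # Replace the initial value with the full time series for this species.
            out[name] = [row[column] for row in x_trajectory]
    return out


species = ("A", "B")
template = {"A": 10, "B": 5, "volume": 1.0}  # "volume" is not a species
x_trajectory = [[10, 5], [9, 6], [8, 7]]     # shape (n_timepoints, n_species)

result = state_to_pytree_sketch(template, species, x_trajectory)
passthrough = state_to_pytree_sketch([10, 5], species, x_trajectory)
```

Here `result["A"]` and `result["B"]` each hold one column of the trajectory, `result["volume"]` keeps its original value, and `passthrough` is the unchanged trajectory because the template was array-like.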