State Manipulation
State conversion
The following functions convert between generic PyTree initial-state formats and the flat array format.
stochastix.pytree_to_state
pytree_to_state(
tree: Any, species: tuple[str, ...]
) -> jaxlib._jax.Array
Converts a PyTree or other initial state formats to a flat JAX array.
This function processes an initial state tree which can be a dictionary, an
object with attributes (like a named tuple or an Equinox module), or an
array-like object, and converts it into a flat JAX array of species counts.
The order of species in the output array is determined by species.
Parameters:
- tree (Any) – The initial state. Can be a PyTree (dictionary, custom object) with leaves named after species, or an array-like object.
- species (tuple[str, ...]) – The species names, in the order they should appear in the output array.
- dtype – The data type for the output array.
Returns:
- Array – A 1D JAX array representing the state vector, ordered according to species.
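The conversion described above can be illustrated with a minimal standalone sketch. It is not the stochastix implementation: the helper name `sketch_pytree_to_state`, the `float32` default, and the use of NumPy in place of JAX arrays are all assumptions made for illustration. It only shows how a dictionary, an attribute container, and an array-like input could each end up as the same flat vector ordered by species.

```python
from typing import Any, NamedTuple

import numpy as np


def sketch_pytree_to_state(
    tree: Any, species: tuple[str, ...], dtype=np.float32
) -> np.ndarray:
    """Hypothetical sketch of the documented behaviour (not library code)."""
    if isinstance(tree, dict):
        # Dictionary: look each species up by key, in the requested order.
        values = [tree[name] for name in species]
    elif all(hasattr(tree, name) for name in species):
        # Attribute container (named tuple, module-like object).
        values = [getattr(tree, name) for name in species]
    else:
        # Array-like: assumed to be ordered like `species` already.
        values = tree
    return np.asarray(values, dtype=dtype).reshape(-1)


class Init(NamedTuple):
    # Field order deliberately differs from the species order.
    B: float
    A: float


species = ("A", "B")
from_dict = sketch_pytree_to_state({"B": 0, "A": 10}, species)
from_tuple = sketch_pytree_to_state(Init(B=0.0, A=10.0), species)
from_array = sketch_pytree_to_state([10, 0], species)
print(from_dict, from_tuple, from_array)  # each is [10.  0.]
```

In this sketch the output order follows species, not the order in which the input container stores its entries.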
stochastix.state_to_pytree
state_to_pytree(
template: Any,
species: tuple[str, ...],
x_trajectory: Array,
) -> typing.Any
Converts a flat state trajectory array back into a PyTree.
This function reconstructs a PyTree with the same structure as the original
initial state template. For each leaf in template that corresponds to a species,
it replaces the initial value with the full time-series trajectory from
x_trajectory. Leaves that do not correspond to species are preserved.
If template was originally an array-like object, the trajectory is returned as is.
Parameters:
- template (Any) – The original initial state PyTree, used as a template for the output.
- species (tuple[str, ...]) – The species names to substitute in the output.
- x_trajectory (Array) – A 2D array of shape (n_timepoints, n_species) representing the state trajectory.
Returns:
- Any – A PyTree of the same structure as template, but with species leaves replaced by their trajectories, or the original x_trajectory if template was array-like.
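The reverse direction can be sketched in the same spirit. Again, this is not the stochastix implementation: `sketch_state_to_pytree` is a hypothetical helper, NumPy stands in for JAX arrays, and only dictionary and array-like templates are handled (attribute containers are omitted for brevity). It shows column i of the trajectory replacing the leaf named species[i], while non-species leaves are left untouched.

```python
from typing import Any

import numpy as np


def sketch_state_to_pytree(
    template: Any, species: tuple[str, ...], x_trajectory: np.ndarray
) -> Any:
    """Hypothetical sketch of the documented behaviour (not library code)."""
    if isinstance(template, dict):
        out = dict(template)  # shallow copy; non-species leaves are kept
        for i, name in enumerate(species):
            if name in out:
                # Column i holds the full time series of species[i].
                out[name] = x_trajectory[:, i]
        return out
    # Array-like template: the trajectory is returned unchanged.
    return x_trajectory


species = ("A", "B")
template = {"A": 10, "B": 0, "label": "run-1"}
# Shape (n_timepoints, n_species) = (3, 2).
trajectory = np.array([[10, 0], [9, 1], [7, 3]])

out = sketch_state_to_pytree(template, species, trajectory)
print(out["A"])      # [10  9  7]
print(out["B"])      # [0 1 3]
print(out["label"])  # run-1

passthrough = sketch_state_to_pytree([10, 0], species, trajectory)
```

With an array-like template, `passthrough` is the same object as `trajectory`, matching the "returned as is" behaviour described above.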