Basic Neural Networks#

stochastix.utils.nn.Linear #

Linear(
    in_features: Union[int, Literal['scalar']],
    out_features: Union[int, Literal['scalar']],
    use_bias: bool = True,
    weight_init: Callable | None = None,
    bias_init: Callable | None = None,
    dtype=None,
    *,
    key: Array,
)

Linear layer with custom initializers.

Attributes:

  • weights

    Weight matrix for the linear transformation.

  • bias

    Bias vector for the linear transformation.

  • in_features

    Number of input features or 'scalar'.

  • out_features

    Number of output features or 'scalar'.

  • use_bias

    Boolean flag indicating whether bias is used.

Parameters:

  • in_features (Union[int, Literal['scalar']]) –

    The number of input features. Can be an integer or "scalar".

  • out_features (Union[int, Literal['scalar']]) –

    The number of output features. Can be an integer or "scalar".

  • use_bias (bool, default: True ) –

    Whether to include a bias term.

  • weight_init (Callable | None, default: None ) –

    A function to initialize the weights.

  • bias_init (Callable | None, default: None ) –

    A function to initialize the bias.

  • dtype

    The data type of the weights and bias.

  • key (Array) –

    A JAX random key for initialization.
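Functionally, the layer applies the affine map `y = W @ x + b` (or just `W @ x` when `use_bias=False`). The sketch below mirrors that computation in plain NumPy rather than the actual stochastix implementation; the random initializer and zero bias here are illustrative assumptions, not the library's defaults.

```python
import numpy as np

# Minimal sketch of what a Linear layer computes (plain NumPy, not the
# actual stochastix implementation). Integer in/out features correspond
# to 1-D vectors; the special value "scalar" corresponds to shape ().
rng = np.random.default_rng(0)

in_features, out_features = 3, 2
weights = rng.normal(size=(out_features, in_features))  # weight matrix
bias = np.zeros(out_features)                           # bias vector (illustrative zero init)

x = np.array([1.0, -0.5, 2.0])  # input of shape (in_features,)
y = weights @ x + bias          # output of shape (out_features,)
```

With custom `weight_init` / `bias_init` callables, only the construction of `weights` and `bias` changes; the forward map stays the same.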

stochastix.utils.nn.MultiLayerPerceptron #

MultiLayerPerceptron(
    in_size: Union[int, Literal['scalar']],
    out_size: Union[int, Literal['scalar']],
    hidden_sizes: tuple[int, ...] = (),
    activation: Callable = relu,
    final_activation: Callable = softplus,
    use_bias: bool = True,
    use_final_bias: bool = True,
    weight_init: Callable | None = None,
    bias_init: Callable | None = None,
    dtype=None,
    *,
    key: Array,
)

Multi-Layer Perceptron (MLP) for feed-forward neural networks.

This is a simple MLP with a configurable number of hidden layers and neurons, custom activation functions, and initializers.

Attributes:

  • layers

    Tuple of Linear layers.

  • activation

    The activation function for the hidden layers.

  • final_activation

    The activation function for the output layer.

  • in_size

    The input size of the MLP.

  • out_size

    The output size of the MLP.

  • hidden_sizes

    Tuple of integers specifying the size of each hidden layer.

  • use_bias

    Whether to use a bias in the hidden layers.

  • use_final_bias

    Whether to use a bias in the final layer.

Parameters:

  • in_size (Union[int, Literal['scalar']]) –

    The input size. The input to the module should be a vector of shape (in_size,). It also supports the string "scalar" as a special value, in which case the input to the module should be of shape ().

  • out_size (Union[int, Literal['scalar']]) –

    The output size. The output from the module will be a vector of shape (out_size,). It also supports the string "scalar" as a special value, in which case the output from the module will have shape ().

  • hidden_sizes (tuple[int, ...], default: () ) –

    The size of each hidden layer.

  • activation (Callable, default: relu ) –

    The activation function after each hidden layer.

  • final_activation (Callable, default: softplus ) –

    The activation function after the output layer.

  • use_bias (bool, default: True ) –

    Whether to add a bias to the internal layers.

  • use_final_bias (bool, default: True ) –

    Whether to add a bias to the final layer.

  • weight_init (Callable | None, default: None ) –

    The initializer for the weights.

  • bias_init (Callable | None, default: None ) –

    The initializer for the biases.

  • dtype

    The dtype to use for all the weights and biases in this MLP. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

  • key (Array) –

    A jax.random.PRNGKey used to provide randomness for parameter initialization. (Keyword-only argument.)
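Putting the pieces together: an MLP with hidden_sizes=(h1, …, hn) chains n + 1 Linear layers, applying `activation` after each hidden layer and `final_activation` after the output layer. The following is a NumPy sketch of that forward pass (again, not the library's actual code), with the default `relu` and `softplus` written out explicitly:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def softplus(x):
    return np.log1p(np.exp(x))

def mlp_forward(x, layers):
    """Sketch of the MLP forward pass.

    `layers` is a list of (W, b) pairs; relu follows every hidden
    layer and softplus follows the final (output) layer.
    """
    *hidden, final = layers
    for W, b in hidden:
        x = relu(W @ x + b)
    W, b = final
    return softplus(W @ x + b)

# Example shapes: in_size=3, hidden_sizes=(8, 8), out_size=1.
rng = np.random.default_rng(0)
sizes = [3, 8, 8, 1]
layers = [
    (rng.normal(size=(m, n)), np.zeros(m))  # illustrative init, not the library default
    for n, m in zip(sizes[:-1], sizes[1:])
]

y = mlp_forward(np.array([0.5, -1.0, 2.0]), layers)
```

Note that with the default `softplus` final activation the output is strictly positive; pass a different `final_activation` (e.g. the identity) for unconstrained outputs.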