# Basic Neural Networks

## stochastix.utils.nn.Linear
```python
Linear(
    in_features: Union[int, Literal['scalar']],
    out_features: Union[int, Literal['scalar']],
    use_bias: bool = True,
    weight_init: Callable | None = None,
    bias_init: Callable | None = None,
    dtype=None,
    *,
    key: Array,
)
```
Linear layer with custom initializers.
Attributes:

- weights – Weight matrix for the linear transformation.
- bias – Bias vector for the linear transformation.
- in_features – Number of input features, or "scalar".
- out_features – Number of output features, or "scalar".
- use_bias – Whether a bias is used.
Parameters:

- in_features (Union[int, Literal['scalar']]) – The number of input features. Can be an integer or "scalar".
- out_features (Union[int, Literal['scalar']]) – The number of output features. Can be an integer or "scalar".
- use_bias (bool, default: True) – Whether to include a bias term.
- weight_init (Callable | None, default: None) – A function to initialize the weights.
- bias_init (Callable | None, default: None) – A function to initialize the bias.
- dtype – The data type of the weights and bias.
- key (Array) – A JAX random key for initialization. (Keyword-only argument.)
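The layer computes `y = W x + b` (the bias term is dropped when `use_bias=False`). As a minimal sketch of the equivalent computation in plain JAX — the uniform initialization range below is an assumption standing in for the unspecified default initializers, not taken from the source:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
wkey, bkey = jax.random.split(key)

in_features, out_features = 3, 2

# Assumed default: uniform init over [-1/sqrt(in_features), 1/sqrt(in_features)].
lim = 1.0 / jnp.sqrt(in_features)
weights = jax.random.uniform(wkey, (out_features, in_features), minval=-lim, maxval=lim)
bias = jax.random.uniform(bkey, (out_features,), minval=-lim, maxval=lim)

x = jnp.ones((in_features,))
y = weights @ x + bias  # shape (out_features,)
```

Passing custom `weight_init`/`bias_init` callables would replace the uniform draws above.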
## stochastix.utils.nn.MultiLayerPerceptron
```python
MultiLayerPerceptron(
    in_size: Union[int, Literal['scalar']],
    out_size: Union[int, Literal['scalar']],
    hidden_sizes: tuple[int, ...] = (),
    activation: Callable = relu,
    final_activation: Callable = softplus,
    use_bias: bool = True,
    use_final_bias: bool = True,
    weight_init: Callable | None = None,
    bias_init: Callable | None = None,
    dtype=None,
    *,
    key: Array,
)
```
Multi-Layer Perceptron (MLP) for feed-forward neural networks.
This is a simple MLP with a configurable number of hidden layers and neurons, custom activation functions, and initializers.
Attributes:

- layers – Tuple of Linear layers.
- activation – The activation function for the hidden layers.
- final_activation – The activation function for the output layer.
- in_size – The input size of the MLP.
- out_size – The output size of the MLP.
- hidden_sizes – Tuple of integers specifying the size of each hidden layer.
- use_bias – Whether to use a bias in the hidden layers.
- use_final_bias – Whether to use a bias in the final layer.
Parameters:

- in_size (Union[int, Literal['scalar']]) – The input size. The input to the module should be a vector of shape (in_size,). It also supports the string "scalar" as a special value, in which case the input to the module should be of shape ().
- out_size (Union[int, Literal['scalar']]) – The output size. The output from the module will be a vector of shape (out_size,). It also supports the string "scalar" as a special value, in which case the output from the module will have shape ().
- hidden_sizes (tuple[int, ...], default: ()) – The size of each hidden layer.
- activation (Callable, default: relu) – The activation function after each hidden layer.
- final_activation (Callable, default: softplus) – The activation function after the output layer.
- use_bias (bool, default: True) – Whether to add a bias to the internal layers.
- use_final_bias (bool, default: True) – Whether to add a bias to the final layer.
- weight_init (Callable | None, default: None) – The initializer for the weights.
- bias_init (Callable | None, default: None) – The initializer for the biases.
- dtype – The dtype to use for all the weights and biases in this MLP. Defaults to either jax.numpy.float32 or jax.numpy.float64, depending on whether JAX is in 64-bit mode.
- key (Array) – A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
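Putting the pieces together, the forward pass such an MLP performs can be sketched in plain JAX with the documented defaults (relu after each hidden layer, softplus after the output layer). The uniform initialization below is an assumption, not the library's documented default:

```python
import jax
import jax.numpy as jnp

def init_mlp(key, in_size, hidden_sizes, out_size):
    """Draw (weights, bias) pairs for each Linear layer in the stack."""
    sizes = (in_size, *hidden_sizes, out_size)
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        key, wkey, bkey = jax.random.split(key, 3)
        lim = 1.0 / jnp.sqrt(fan_in)  # assumed uniform init range
        w = jax.random.uniform(wkey, (fan_out, fan_in), minval=-lim, maxval=lim)
        b = jax.random.uniform(bkey, (fan_out,), minval=-lim, maxval=lim)
        params.append((w, b))
    return params

def mlp_apply(params, x):
    # relu on hidden layers, softplus on the output layer (the documented defaults)
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)
    w, b = params[-1]
    return jax.nn.softplus(w @ x + b)

key = jax.random.PRNGKey(0)
params = init_mlp(key, in_size=4, hidden_sizes=(8, 8), out_size=2)
y = mlp_apply(params, jnp.ones(4))  # shape (2,); entries positive due to softplus
```

Note how the default `final_activation=softplus` constrains the output to be strictly positive, which is useful when the MLP parameterizes a quantity such as a rate or a variance.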