Kaun.Initializers

Flax-compatible weight initializers for Kaun.
This module provides weight initialization strategies matching the Flax/JAX neural network library API. All initializers return functions that take an RNG seed, a target shape, and a dtype, and return an initialized tensor.
type t = {
  f : 'layout 'dev.
    int ->
    int array ->
    (float, 'layout) Rune.dtype ->
    (float, 'layout) Rune.t;
}

Type for initializer functions. The field f takes an RNG seed, a target shape, and a dtype, and returns the initialized tensor.
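The record wrapper exists because the field f is universally quantified over 'layout and 'dev (first-class polymorphism), which a bare function type cannot express in OCaml. A toy analogue, using a flat float array in place of a Rune tensor to show how such a record field is defined and applied (the names here are illustrative, not part of the Kaun API):

```ocaml
(* Toy stand-in for the [t] record: the field takes a seed and a shape
   and returns a "tensor" (here just a float array). *)
type toy = { f : int -> int array -> float array }

let zeros : toy =
  { f = (fun _seed shape ->
      (* Total element count is the product of the shape dimensions. *)
      let n = Array.fold_left ( * ) 1 shape in
      Array.make n 0.0) }

let () =
  let w = zeros.f 42 [| 2; 3 |] in
  Printf.printf "len = %d\n" (Array.length w)  (* 2 * 3 = 6 elements *)
```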
Truncated normal initializer
val variance_scaling :
  scale:float ->
  mode:[ `Fan_in | `Fan_out | `Fan_avg ] ->
  distribution:[ `Normal | `Truncated_normal | `Uniform ] ->
  in_axis:int ->
  out_axis:int ->
  unit ->
  t

General variance scaling initializer. Draws samples with variance scale / n, where n is the fan-in, the fan-out, or their average according to mode; in_axis and out_axis select which axes of the shape are treated as input and output when computing the fans.
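The arithmetic behind variance scaling can be sketched in plain OCaml, assuming the usual Flax convention: variance = scale / fan, with a uniform limit of sqrt (3 * variance) and a truncation correction of ~0.8796 for the truncated normal. The helper names below are illustrative, not part of the Kaun API:

```ocaml
type mode = Fan_in | Fan_out | Fan_avg

(* Pick the fan that normalizes the variance, per [mode]. *)
let fan_of_mode mode ~fan_in ~fan_out =
  match mode with
  | Fan_in -> fan_in
  | Fan_out -> fan_out
  | Fan_avg -> (fan_in +. fan_out) /. 2.0

(* For `Uniform, samples come from [-limit, limit]; limit = sqrt (3v)
   gives the target variance v. For `Truncated_normal, the raw stddev is
   divided by ~0.8796 to compensate for truncation at two stddevs. *)
let uniform_limit ~scale ~fan = sqrt (3.0 *. scale /. fan)
let truncated_stddev ~scale ~fan = sqrt (scale /. fan) /. 0.87962566103423978

let () =
  (* He uniform on a 784 -> 256 layer: scale = 2.0, mode = Fan_in. *)
  let fan = fan_of_mode Fan_in ~fan_in:784.0 ~fan_out:256.0 in
  Printf.printf "He uniform limit: %.6f\n" (uniform_limit ~scale:2.0 ~fan)
```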
Glorot uniform initializer (aka Xavier uniform)
Uses variance_scaling with scale=1.0, mode=`Fan_avg, distribution=`Uniform
Glorot normal initializer (aka Xavier normal)
Uses variance_scaling with scale=1.0, mode=`Fan_avg, distribution=`Truncated_normal
LeCun uniform initializer
Uses variance_scaling with scale=1.0, mode=`Fan_in, distribution=`Uniform
LeCun normal initializer
Uses variance_scaling with scale=1.0, mode=`Fan_in, distribution=`Truncated_normal
He uniform initializer (aka Kaiming uniform)
Uses variance_scaling with scale=2.0, mode=`Fan_in, distribution=`Uniform. Designed for layers with ReLU activations.
He normal initializer (aka Kaiming normal)
Uses variance_scaling with scale=2.0, mode=`Fan_in, distribution=`Truncated_normal. Designed for layers with ReLU activations.
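The named initializers above are just variance_scaling presets, so their implied spreads differ only through scale and the choice of fan. A quick sketch (illustrative helper, not the Kaun API) comparing the raw stddev, sqrt (scale / fan), of Glorot normal versus He normal on a square layer:

```ocaml
(* Raw stddev implied by a variance_scaling preset, before any
   truncation correction. *)
let stddev ~scale ~fan = sqrt (scale /. fan)

let () =
  let fan_in = 512.0 and fan_out = 512.0 in
  (* Glorot normal: scale = 1.0, fan = average of fan_in and fan_out. *)
  let glorot = stddev ~scale:1.0 ~fan:((fan_in +. fan_out) /. 2.0) in
  (* He normal: scale = 2.0, fan = fan_in. *)
  let he = stddev ~scale:2.0 ~fan:fan_in in
  Printf.printf "glorot: %.5f, he: %.5f\n" glorot he
```

When fan_in = fan_out, the He stddev is exactly sqrt 2 times the Glorot stddev, reflecting the scale=2.0 correction for ReLU zeroing out half the activations.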
Orthogonal matrix initializer
Returns uniformly distributed orthogonal matrices. If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.
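The usual construction behind an orthogonal initializer is to fill a matrix with Gaussian samples and then orthonormalize it; real implementations use a QR decomposition, but the idea can be sketched with Gram-Schmidt in pure OCaml (no Kaun/Rune; all names here are illustrative):

```ocaml
(* Standard normal sample via the Box-Muller transform. *)
let gaussian () =
  let u1 = Random.float 1.0 +. epsilon_float and u2 = Random.float 1.0 in
  sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

let dot a b = Array.fold_left ( +. ) 0.0 (Array.map2 ( *. ) a b)

(* Build an n x m matrix (n <= m) with orthonormal rows: draw Gaussian
   rows, then orthogonalize each row against the previous ones and
   normalize (modified Gram-Schmidt). *)
let orthonormal_rows n m =
  let a = Array.init n (fun _ -> Array.init m (fun _ -> gaussian ())) in
  for i = 0 to n - 1 do
    for j = 0 to i - 1 do
      let c = dot a.(i) a.(j) in
      Array.iteri (fun k x -> a.(i).(k) <- x -. c *. a.(j).(k)) a.(i)
    done;
    let norm = sqrt (dot a.(i) a.(i)) in
    Array.iteri (fun k x -> a.(i).(k) <- x /. norm) a.(i)
  done;
  a

let () =
  let q = orthonormal_rows 3 5 in
  (* Rows are orthonormal: self-dot ~ 1, cross-dot ~ 0. *)
  Printf.printf "q0.q0 = %.3f, q0.q1 = %.3f\n" (dot q.(0) q.(0)) (dot q.(0) q.(1))
```

This matches the doc comment above: for a non-square shape, orthonormality holds along the smaller side (here the rows, since n <= m).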
Delta orthogonal initializer for convolutional layers
Initializer for convolutional layers that preserves identity in the spatial dimensions. Requires 3D, 4D, or 5D tensor shape with square spatial dimensions.
Uniform initializer with explicit range