Module Kaun.Initializers

Flax-compatible weight initializers for Kaun

This module provides weight initialization strategies that mirror the API of the Flax/JAX neural network library. Each initializer returns a value of type t, whose field f takes an integer RNG seed, a shape (int array), and a dtype, and returns a newly initialized tensor.

type t = {
  f : 'layout 'dev. int -> int array -> (float, 'layout) Rune.dtype -> (float, 'layout) Rune.t;
}

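Type of initializer functions. The field f is explicitly polymorphic ('layout 'dev.), so a single initializer value can create tensors of any floating-point dtype.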

Basic Initializers

val constant : float -> t

Constant value initializer

val zeros : unit -> t

Zero initializer

val ones : unit -> t

Ones initializer
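The record type t and the basic initializers above can be mimicked with the standard library alone. In this sketch the phantom types f32 and f64, the GADT dtype, and the record tensor are hypothetical stand-ins for Rune.dtype and Rune.t, and defining zeros and ones via constant is an assumption, not Kaun's source. What it illustrates is the effect of the ['layout.] quantifier: one initializer value is applied at two different dtypes.

```ocaml
(* Hypothetical stand-ins for Rune.dtype and Rune.t: phantom-typed dtype tags
   attached to a flat float array. *)
type f32
type f64
type _ dtype = Float32 : f32 dtype | Float64 : f64 dtype
type 'layout tensor = { dtype : 'layout dtype; shape : int array; data : float array }

(* Same shape as Kaun's [t]: the ['layout.] quantifier makes [f] polymorphic,
   so one initializer value works for every dtype. *)
type t = { f : 'layout. int -> int array -> 'layout dtype -> 'layout tensor }

let numel shape = Array.fold_left ( * ) 1 shape

(* Fill every element with [v]; the seed is unused because nothing is random. *)
let constant v =
  { f = (fun _seed shape dtype -> { dtype; shape; data = Array.make (numel shape) v }) }

(* Assumed definitions in terms of [constant]. *)
let zeros () = constant 0.0
let ones () = constant 1.0

let () =
  let init = ones () in
  let w32 : f32 tensor = init.f 0 [| 2; 3 |] Float32 in
  let w64 : f64 tensor = init.f 0 [| 4 |] Float64 in
  Printf.printf "%d and %d elements\n" (Array.length w32.data) (Array.length w64.data)
```

Because the quantifier sits inside the record, t needs no type parameter, so an initializer can be stored in a layer configuration and still be applied at whatever dtype the model ends up using.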

Random Initializers

val uniform : ?scale:float -> unit -> t

Uniform random initializer sampling from the range [0, scale).

val normal : ?stddev:float -> unit -> t

Normal (Gaussian) random initializer

val truncated_normal : ?stddev:float -> ?lower:float -> ?upper:float -> unit -> t

Truncated normal initializer
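The three random initializers can be sketched with Random.State and a Box-Muller normal sampler. This is a stdlib-only illustration over flat arrays, not Kaun's implementation. The zero mean for normal and the reading of lower and upper as bounds on the standard normal (that is, in units of stddev, as in Flax) are assumptions. Rejection sampling is used for the truncated normal; it yields the same distribution as Flax's inverse-CDF approach, only less efficiently.

```ocaml
(* Box-Muller transform: one standard normal sample from two uniforms. *)
let rec std_normal st =
  let u1 = Random.State.float st 1.0 and u2 = Random.State.float st 1.0 in
  if u1 <= 0.0 then std_normal st (* redraw to avoid log 0 *)
  else sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

(* uniform: values in [0, scale]. Random.State.float treats its bound as
   inclusive, whereas the documented range is [0, scale). *)
let uniform_sample ~scale ~seed n =
  let st = Random.State.make [| seed |] in
  Array.init n (fun _ -> Random.State.float st scale)

(* normal: zero mean (assumed), standard deviation [stddev]. *)
let normal_sample ~stddev ~seed n =
  let st = Random.State.make [| seed |] in
  Array.init n (fun _ -> stddev *. std_normal st)

(* truncated_normal: redraw until the standard normal falls in [lower, upper],
   then scale by [stddev] (Flax convention, assumed here). *)
let truncated_normal_sample ~stddev ~lower ~upper ~seed n =
  let st = Random.State.make [| seed |] in
  let rec draw () =
    let z = std_normal st in
    if z >= lower && z <= upper then z else draw ()
  in
  Array.init n (fun _ -> stddev *. draw ())

let () =
  let w = truncated_normal_sample ~stddev:0.02 ~lower:(-2.0) ~upper:2.0 ~seed:0 5 in
  Array.iter (Printf.printf "%.4f\n") w
```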

Variance Scaling Initializers

val variance_scaling : scale:float -> mode:[ `Fan_in | `Fan_out | `Fan_avg ] -> distribution:[ `Normal | `Truncated_normal | `Uniform ] -> in_axis:int -> out_axis:int -> unit -> t

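General variance scaling initializer. Weights have mean zero and variance scale / n, where n is the fan-in, the fan-out, or their average, as selected by mode.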

  • parameter scale

    Scaling factor (positive float)

  • parameter mode

    Which fan n to divide scale by: `Fan_in, `Fan_out, or `Fan_avg (the mean of fan-in and fan-out)

  • parameter distribution

    One of `Normal, `Truncated_normal, `Uniform

  • parameter in_axis

    Axis of the input (fan-in) dimension. The argument is required here; -2 is the conventional value.

  • parameter out_axis

    Axis of the output (fan-out) dimension. The argument is required here; -1 is the conventional value.
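How scale, mode, and distribution combine can be sketched as a pure function that returns the sampling parameter. The fan computation (axes other than in_axis and out_axis count as the receptive field) and the truncation constant 0.8796... (the standard deviation of a standard normal truncated to [-2, 2]) follow Flax's variance_scaling; that Kaun uses exactly the same conventions is an assumption. The helper names compute_fans and variance_scaling_param are hypothetical.

```ocaml
type mode = [ `Fan_in | `Fan_out | `Fan_avg ]
type distribution = [ `Normal | `Truncated_normal | `Uniform ]

(* Resolve a possibly negative axis, Python style: -1 is the last axis. *)
let resolve shape axis = if axis < 0 then Array.length shape + axis else axis

(* Fan-in / fan-out: every axis other than in_axis and out_axis is treated as
   part of the receptive field (e.g. a conv kernel's spatial axes). *)
let compute_fans shape ~in_axis ~out_axis =
  let n_in = shape.(resolve shape in_axis) and n_out = shape.(resolve shape out_axis) in
  let receptive = Array.fold_left ( * ) 1 shape / (n_in * n_out) in
  (float_of_int (n_in * receptive), float_of_int (n_out * receptive))

(* Parameter of the sampling distribution: the half-width a of U(-a, a) for
   `Uniform, otherwise the standard deviation of the (pre-truncation) normal. *)
let variance_scaling_param ~scale ~(mode : mode) ~(distribution : distribution)
    ~in_axis ~out_axis shape =
  let fan_in, fan_out = compute_fans shape ~in_axis ~out_axis in
  let n =
    match mode with
    | `Fan_in -> fan_in
    | `Fan_out -> fan_out
    | `Fan_avg -> (fan_in +. fan_out) /. 2.0
  in
  let variance = scale /. n in
  match distribution with
  | `Normal -> sqrt variance
  (* Dividing by the truncated stddev restores the target variance. *)
  | `Truncated_normal -> sqrt variance /. 0.87962566103423978
  | `Uniform -> sqrt (3.0 *. variance) (* Var U(-a, a) = a^2 / 3 *)

let () =
  (* A 3x3 conv kernel with 16 input and 32 output channels. *)
  let fan_in, fan_out = compute_fans [| 3; 3; 16; 32 |] ~in_axis:(-2) ~out_axis:(-1) in
  Printf.printf "fan_in = %g, fan_out = %g\n" fan_in fan_out (* prints "fan_in = 144, fan_out = 288" *)
```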

Xavier/Glorot Initializers

val glorot_uniform : ?in_axis:int -> ?out_axis:int -> unit -> t

Glorot uniform initializer (aka Xavier uniform)

Uses variance_scaling with scale=1.0, mode=`Fan_avg, distribution=`Uniform

val glorot_normal : ?in_axis:int -> ?out_axis:int -> unit -> t

Glorot normal initializer (aka Xavier normal)

Uses variance_scaling with scale=1.0, mode=`Fan_avg, distribution=`Truncated_normal

val xavier_uniform : ?in_axis:int -> ?out_axis:int -> unit -> t

Alias for glorot_uniform

val xavier_normal : ?in_axis:int -> ?out_axis:int -> unit -> t

Alias for glorot_normal
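Plugging the Glorot settings into variance scaling gives the familiar closed forms. The derivation assumes the Flax conventions for the truncated case (samples cut at two pre-truncation standard deviations, corrected by the constant 0.8796); n_in and n_out denote fan-in and fan-out.

```latex
\operatorname{Var}[W] = \frac{\mathrm{scale}}{(n_{in} + n_{out})/2} = \frac{2}{n_{in} + n_{out}}
\qquad (\mathrm{scale} = 1,\ \text{mode} = \text{Fan\_avg})

\text{glorot\_uniform: } W \sim U(-a, a),\quad \frac{a^2}{3} = \frac{2}{n_{in} + n_{out}}
\;\Longrightarrow\; a = \sqrt{\frac{6}{n_{in} + n_{out}}}

\text{glorot\_normal: } \sigma = \frac{1}{0.8796}\sqrt{\frac{2}{n_{in} + n_{out}}},
\quad \text{samples truncated to } [-2\sigma, 2\sigma]
```

For example, a dense layer mapping 784 inputs to 256 outputs gets a = sqrt(6/1040) ≈ 0.0760 under glorot_uniform.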

LeCun Initializers

val lecun_uniform : ?in_axis:int -> ?out_axis:int -> unit -> t

LeCun uniform initializer

Uses variance_scaling with scale=1.0, mode=`Fan_in, distribution=`Uniform

val lecun_normal : ?in_axis:int -> ?out_axis:int -> unit -> t

LeCun normal initializer

Uses variance_scaling with scale=1.0, mode=`Fan_in, distribution=`Truncated_normal
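With scale=1.0 and `Fan_in, a weight has variance 1 / fan_in, so a pre-activation sum of fan_in unit-variance inputs keeps unit variance. The stdlib-only sketch below draws a lecun_uniform-style dense kernel from U(-a, a) with a = sqrt(3 / fan_in) and compares the empirical variance with 1 / fan_in. The flat row-major layout and the helper names are illustrative assumptions, not Kaun's implementation.

```ocaml
(* LeCun uniform per the docs: variance_scaling with scale = 1, `Fan_in and
   `Uniform, i.e. U(-a, a) with a = sqrt (3 / fan_in), so Var = 1 / fan_in. *)
let lecun_uniform_limit ~fan_in = sqrt (3.0 /. float_of_int fan_in)

(* A fan_in x fan_out dense kernel, flattened row-major. *)
let lecun_uniform_sample ~seed ~fan_in ~fan_out =
  let st = Random.State.make [| seed |] in
  let a = lecun_uniform_limit ~fan_in in
  Array.init (fan_in * fan_out) (fun _ -> Random.State.float st (2.0 *. a) -. a)

(* Population variance of a float array. *)
let variance xs =
  let n = float_of_int (Array.length xs) in
  let mean = Array.fold_left ( +. ) 0.0 xs /. n in
  Array.fold_left (fun acc x -> acc +. ((x -. mean) ** 2.0)) 0.0 xs /. n

let () =
  let fan_in = 300 in
  let w = lecun_uniform_sample ~seed:0 ~fan_in ~fan_out:100 in
  Printf.printf "empirical variance %.6f vs 1/fan_in %.6f\n"
    (variance w) (1.0 /. float_of_int fan_in)
```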

He/Kaiming Initializers

val he_uniform : ?in_axis:int -> ?out_axis:int -> unit -> t

He uniform initializer (aka Kaiming uniform)

Uses variance_scaling with scale=2.0, mode=`Fan_in, distribution=`Uniform. Designed for layers with ReLU activations.

val he_normal : ?in_axis:int -> ?out_axis:int -> unit -> t

He normal initializer (aka Kaiming normal)

Uses variance_scaling with scale=2.0, mode=`Fan_in, distribution=`Truncated_normal. Designed for layers with ReLU activations.

val kaiming_uniform : ?in_axis:int -> ?out_axis:int -> unit -> t

Alias for he_uniform

val kaiming_normal : ?in_axis:int -> ?out_axis:int -> unit -> t

Alias for he_normal
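The factor scale=2.0 compensates for ReLU zeroing about half of each zero-mean pre-activation: for symmetric z, E[relu(z)^2] = Var[z] / 2. The Monte Carlo sketch below pushes a random input through a stack of square dense + ReLU layers and reports the mean square of the final activations. It uses plain (untruncated) normal weights for simplicity, which differs from he_normal's truncated normal, and all helper names are hypothetical. With scale 2 the mean square should stay near 1; with scale 1 (LeCun-style) it should shrink by roughly half per layer.

```ocaml
(* Box-Muller standard normal sample. *)
let rec std_normal st =
  let u1 = Random.State.float st 1.0 and u2 = Random.State.float st 1.0 in
  if u1 <= 0.0 then std_normal st
  else sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

(* Dense layer with freshly drawn weights W_ij ~ N(0, scale / fan_in), then ReLU.
   scale = 2.0 mimics He scaling, scale = 1.0 mimics LeCun scaling. *)
let dense_relu st ~scale x =
  let fan_in = Array.length x in
  let std = sqrt (scale /. float_of_int fan_in) in
  Array.init fan_in (fun _ ->
      let pre = Array.fold_left (fun acc xj -> acc +. (std *. std_normal st *. xj)) 0.0 x in
      Float.max 0.0 pre)

let mean_square x =
  Array.fold_left (fun acc v -> acc +. (v *. v)) 0.0 x /. float_of_int (Array.length x)

(* Mean square of the activations after [depth] layers of width [width]. *)
let propagate ~scale ~depth ~width ~seed =
  let st = Random.State.make [| seed |] in
  let x = ref (Array.init width (fun _ -> std_normal st)) in
  for _layer = 1 to depth do
    x := dense_relu st ~scale !x
  done;
  mean_square !x

let () =
  Printf.printf "He (scale 2):    %.3f\n" (propagate ~scale:2.0 ~depth:5 ~width:512 ~seed:0);
  Printf.printf "LeCun (scale 1): %.3f\n" (propagate ~scale:1.0 ~depth:5 ~width:512 ~seed:0)
```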

Orthogonal Initializers

val orthogonal : ?scale:float -> ?column_axis:int -> unit -> t

Orthogonal matrix initializer

Returns uniformly distributed orthogonal matrices. If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.

  • parameter scale

    Scaling factor (default: 1.0)

  • parameter column_axis

    Axis containing columns that should be orthogonal (default: -1)
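Flax builds orthogonal matrices from the QR decomposition of a Gaussian matrix, with a sign correction that makes the result uniformly (Haar) distributed. The stdlib-only sketch below obtains the same distribution with Gram-Schmidt on Gaussian vectors. It handles only 2-D shapes and ignores column_axis; it is an illustration under those assumptions, not Kaun's implementation.

```ocaml
(* Box-Muller standard normal sample. *)
let rec std_normal st =
  let u1 = Random.State.float st 1.0 and u2 = Random.State.float st 1.0 in
  if u1 <= 0.0 then std_normal st
  else sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

let dot a b =
  let s = ref 0.0 in
  Array.iteri (fun i v -> s := !s +. (v *. b.(i))) a;
  !s

(* Modified Gram-Schmidt: make the rows of [m] orthonormal, in place.
   Requires at most as many rows as columns. *)
let orthonormalize_rows m =
  Array.iteri
    (fun i row ->
      for j = 0 to i - 1 do
        let d = dot row m.(j) in
        Array.iteri (fun t v -> row.(t) <- row.(t) -. (d *. v)) m.(j)
      done;
      let norm = sqrt (dot row row) in
      Array.iteri (fun t v -> row.(t) <- v /. norm) row)
    m

(* A rows x cols matrix whose smaller side is orthonormal, times [scale]. *)
let orthogonal ?(scale = 1.0) ~seed rows cols =
  let st = Random.State.make [| seed |] in
  let k = min rows cols and n = max rows cols in
  let q = Array.init k (fun _ -> Array.init n (fun _ -> std_normal st)) in
  orthonormalize_rows q;
  (* q is k x n with orthonormal rows; transpose it when rows > cols. *)
  Array.init rows (fun i ->
      Array.init cols (fun j -> scale *. (if rows <= cols then q.(i).(j) else q.(j).(i))))

let () =
  let w = orthogonal ~seed:0 3 5 in
  Printf.printf "row0 . row0 = %.6f, row0 . row1 = %.6f\n" (dot w.(0) w.(0)) (dot w.(0) w.(1))
```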

val delta_orthogonal : ?scale:float -> ?column_axis:int -> unit -> t

Delta orthogonal initializer for convolutional layers

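Initializer for convolutional layers that preserves identity in the spatial dimensions: the kernel is zero everywhere except at the spatial center, which holds an orthogonal matrix over the channel axes, so at initialization the convolution applies the same orthogonal channel mixing at every spatial position. Requires a 3D, 4D, or 5D kernel shape whose spatial dimensions are all equal.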

  • parameter scale

    Scaling factor (default: 1.0)

  • parameter column_axis

    Axis containing columns that should be orthogonal (default: -1)
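The delta structure is easiest to see for a 1-D convolution (a 3D kernel shape). The sketch assumes a [ksize][c_in][c_out] layout (spatial axis first, channels last) and places the given orthogonal matrix at index (ksize - 1) / 2, mirroring Flax's choice of center; the helper names are hypothetical, and the orthogonal matrix is passed in rather than sampled. With "same" zero padding, each output position is just x.(t) multiplied by q, so the norm of every channel vector is preserved.

```ocaml
let centre ksize = (ksize - 1) / 2

(* Zero kernel except at the spatial centre, which holds the c_in x c_out matrix q. *)
let delta_kernel ~ksize (q : float array array) =
  let c_in = Array.length q and c_out = Array.length q.(0) in
  Array.init ksize (fun s ->
      if s = centre ksize then Array.map Array.copy q
      else Array.make_matrix c_in c_out 0.0)

(* 1-D convolution (cross-correlation) with zero padding; output length = input length. *)
let conv1d_same (x : float array array) kernel =
  let len = Array.length x and ksize = Array.length kernel in
  let c_in = Array.length kernel.(0) and c_out = Array.length kernel.(0).(0) in
  Array.init len (fun t ->
      Array.init c_out (fun o ->
          let acc = ref 0.0 in
          for s = 0 to ksize - 1 do
            let p = t + s - centre ksize in
            if p >= 0 && p < len then
              for i = 0 to c_in - 1 do
                acc := !acc +. (x.(p).(i) *. kernel.(s).(i).(o))
              done
          done;
          !acc))

let norm v = sqrt (Array.fold_left (fun acc a -> acc +. (a *. a)) 0.0 v)

let () =
  let th = 0.3 in
  let q = [| [| cos th; -.sin th |]; [| sin th; cos th |] |] in (* 2x2 rotation *)
  let x = [| [| 1.0; 2.0 |]; [| 3.0; -1.0 |]; [| 0.5; 0.0 |] |] in
  let y = conv1d_same x (delta_kernel ~ksize:3 q) in
  Array.iteri (fun t yt -> Printf.printf "t=%d: |x|=%.4f |y|=%.4f\n" t (norm x.(t)) (norm yt)) y
```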

Utility Initializers

val uniform_range : low:float -> high:float -> unit -> t

Uniform initializer with explicit range

val normal_range : mean:float -> stddev:float -> unit -> t

Normal initializer with explicit mean and stddev
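Both utility initializers can be understood as affine transforms of the standard samplers: low + (high - low) * u for a unit uniform u, and mean + stddev * z for a standard normal z. The sketch below assumes that standard construction; it is not taken from Kaun's source, and the helper names are hypothetical.

```ocaml
(* Box-Muller standard normal sample. *)
let rec std_normal st =
  let u1 = Random.State.float st 1.0 and u2 = Random.State.float st 1.0 in
  if u1 <= 0.0 then std_normal st
  else sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

(* uniform_range: map a uniform draw onto [low, high]. *)
let uniform_range_sample ~low ~high ~seed n =
  let st = Random.State.make [| seed |] in
  Array.init n (fun _ -> low +. Random.State.float st (high -. low))

(* normal_range: shift and scale a standard normal. *)
let normal_range_sample ~mean ~stddev ~seed n =
  let st = Random.State.make [| seed |] in
  Array.init n (fun _ -> mean +. (stddev *. std_normal st))

let () =
  let b = uniform_range_sample ~low:(-0.1) ~high:0.1 ~seed:0 4 in
  Array.iter (Printf.printf "%.4f ") b;
  print_newline ()
```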