Kaun.Optimizer

Optax-inspired gradient processing and optimization library.
type label_tree =
  | LabelTensor of int
  | LabelList of label_tree list
  | LabelRecord of (string * label_tree) list

Type for labeling parameters for multi_transform.
type 'layout gradient_transformation = {
  init : 'layout Ptree.t -> 'layout opt_state;
  update :
    'layout opt_state ->
    'layout Ptree.t ->
    'layout Ptree.t ->
    'layout Ptree.t * 'layout opt_state;
}

Core gradient transformation type.
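The record above is the whole optimizer protocol: init builds the optimizer state from a parameter tree, and update turns gradients into updates while threading that state. Below is a minimal sketch of one step; the assumption that update takes the state, then the gradients, then the parameters (in that order) is not fixed by the signature alone and is only illustrative, as is the helper name one_step.

  (* Sketch only: assumes [update]'s two Ptree.t arguments are the gradients
     and the current parameters, in that order. *)
  let one_step (opt : _ Kaun.Optimizer.gradient_transformation)
      ~params ~grads state =
    let updates, state' = opt.Kaun.Optimizer.update state grads params in
    (* [updates] is then subtracted from the parameters
       (params <- params - updates) by the caller. *)
    (updates, state')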
val identity : unit -> 'layout gradient_transformation
Identity transformation: returns gradients unchanged.

val scale : float -> 'layout gradient_transformation
Scale gradients by a constant factor.

val scale_by_neg_one : unit -> 'layout gradient_transformation
Scale gradients by -1 (for gradient descent).

val add_decayed_weights : float -> 'layout gradient_transformation
Add decayed value of parameters to updates (weight decay).

val clip_by_global_norm : float -> 'layout gradient_transformation
Clip gradients by global norm.

val clip : float -> 'layout gradient_transformation
Clip gradients element-wise to [-max_delta, max_delta].
val trace :
  decay:float ->
  ?nesterov:bool ->
  unit ->
  'layout gradient_transformation
Trace momentum: maintains an exponential moving average of gradients.

val scale_by_rms :
  ?decay:float ->
  ?eps:float ->
  unit ->
  'layout gradient_transformation
Scale by RMS of gradients (used in RMSProp, Adam).

val scale_by_adam :
  ?b1:float ->
  ?b2:float ->
  ?eps:float ->
  unit ->
  'layout gradient_transformation
Scale by Adam-style second moment estimate.

val scale_by_belief :
  ?b1:float ->
  ?b2:float ->
  ?eps:float ->
  unit ->
  'layout gradient_transformation
Scale by belief (used in AdaBelief).
module Schedule : sig ... end

val scale_by_schedule : Schedule.t -> 'layout gradient_transformation
Scale updates by a learning rate schedule.
val chain :
  'layout gradient_transformation list ->
  'layout gradient_transformation
Chain multiple transformations together; see the sketch below.
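As an illustration (not the library's own definition of adamw), an AdamW-style optimizer can be assembled by chaining the primitives above; each stage transforms the updates produced by the one before it.

  open Kaun.Optimizer

  (* Illustrative composition; the built-in [adamw] may be assembled differently. *)
  let my_adamw ~lr ~weight_decay =
    chain
      [
        scale_by_adam ();                 (* adaptive moment scaling *)
        add_decayed_weights weight_decay; (* weight decay on the updates *)
        scale_by_neg_one ();              (* descend rather than ascend *)
        scale lr;                         (* apply the learning rate *)
      ]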
val multi_transform :
  transforms:'layout gradient_transformation array ->
  labels:('layout Ptree.t -> label_tree) ->
  'layout gradient_transformation
Apply different transformations to different parameters. The labels function maps parameters to integer labels; the transforms array maps each label to a transformation. See the example below.
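A hypothetical use of multi_transform, assuming the parameter tree is a record with fields "weight" and "bias"; those field names and the fixed label tree are purely illustrative. Weights get Adam, biases get plain SGD.

  open Kaun.Optimizer

  (* Label 0 selects transforms.(0) (adam); label 1 selects transforms.(1) (sgd). *)
  let per_param_opt =
    multi_transform
      ~transforms:[| adam ~lr:1e-3 (); sgd ~lr:1e-2 () |]
      ~labels:(fun _params ->
        LabelRecord [ ("weight", LabelTensor 0); ("bias", LabelTensor 1) ])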
val masked :
  mask:('layout Ptree.t -> mask_tree) ->
  inner:'layout gradient_transformation ->
  'layout gradient_transformation
Apply a transformation only to masked parameters.
val sgd :
  lr:float ->
  ?momentum:float ->
  ?nesterov:bool ->
  unit ->
  'layout gradient_transformation
Stochastic Gradient Descent.
val adam :
  lr:float ->
  ?b1:float ->
  ?b2:float ->
  ?eps:float ->
  unit ->
  'layout gradient_transformation
Adam optimizer.
val adamw :
  lr:float ->
  ?b1:float ->
  ?b2:float ->
  ?eps:float ->
  ?weight_decay:float ->
  unit ->
  'layout gradient_transformation
AdamW optimizer (Adam with weight decay); see the training-step sketch below.
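A sketch of how a ready-made optimizer plugs into a training loop. compute_grads and apply_updates are hypothetical helpers supplied by the caller (computing gradients and writing updates back into the parameter tree are outside the signatures shown here); only adamw, init and update come from this interface, and the argument order of update is assumed as before.

  let optimizer = Kaun.Optimizer.adamw ~lr:1e-3 ~weight_decay:0.01 ()

  (* [compute_grads] and [apply_updates] are hypothetical stand-ins supplied
     by the caller; they are not part of this interface. *)
  let train_step ~compute_grads ~apply_updates state params batch =
    let grads = compute_grads params batch in
    let updates, state = optimizer.Kaun.Optimizer.update state grads params in
    (apply_updates params updates, state)

  let initial_state params = optimizer.Kaun.Optimizer.init params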
val rmsprop :
  lr:float ->
  ?decay:float ->
  ?eps:float ->
  ?momentum:float ->
  unit ->
  'layout gradient_transformation
RMSProp optimizer.

val adagrad : lr:float -> ?eps:float -> unit -> 'layout gradient_transformation
AdaGrad optimizer.
val adabelief :
  lr:float ->
  ?b1:float ->
  ?b2:float ->
  ?eps:float ->
  unit ->
  'layout gradient_transformation
AdaBelief optimizer.

val lamb :
  lr:float ->
  ?b1:float ->
  ?b2:float ->
  ?eps:float ->
  ?weight_decay:float ->
  unit ->
  'layout gradient_transformation
LAMB optimizer (Layer-wise Adaptive Moments).

val radam :
  lr:float ->
  ?b1:float ->
  ?b2:float ->
  ?eps:float ->
  unit ->
  'layout gradient_transformation
RAdam optimizer (Rectified Adam).

val yogi :
  lr:float ->
  ?b1:float ->
  ?b2:float ->
  ?eps:float ->
  unit ->
  'layout gradient_transformation
Yogi optimizer.
Apply updates to parameters: params = params - updates
Apply updates to parameters in place (mutates first argument)
val global_norm : 'layout Ptree.t -> float
Compute the global norm of gradients.
Set step count (for schedules and bias correction)
val multi_steps :
  every:int ->
  'layout gradient_transformation ->
  'layout gradient_transformation
Accumulate gradients over multiple steps before applying them; see the example below.
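For example, an Adam step applied only every four micro-batches, with gradients accumulated in between (an illustrative combination of the functions above):

  let accumulating_adam =
    Kaun.Optimizer.(multi_steps ~every:4 (adam ~lr:1e-3 ()))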
val with_gradient_stats :
  ?prefix:string ->
  'layout gradient_transformation ->
  'layout gradient_transformation
Add gradient statistics logging; see the example below.
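For instance, wrapping an optimizer so its gradient statistics are reported under a prefix (where and how the statistics are emitted is up to the implementation):

  let logged_sgd =
    Kaun.Optimizer.(with_gradient_stats ~prefix:"encoder" (sgd ~lr:0.01 ()))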