Module Kaun.Optim
Optimizers and learning-rate schedules.
An algorithm combines a learning-rate schedule and an update rule. state stores optimizer-specific accumulators and the step count.
state is the type for optimizer states.
algorithm is the type for optimization algorithms.
init algo params is the initial state of algo for params.
step algo state params grads is (updates, state') where updates are additive parameter deltas.
The step count is incremented before the learning-rate schedule is evaluated. Use apply_updates to apply updates to params.
apply_updates params updates is params + updates element-wise.
update algo st params grads is let u, st' = step algo st params grads in (apply_updates params u, st').
Convenience for the common case where you want updated parameters directly rather than additive deltas.
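A minimal sketch of how init, step, apply_updates, and update fit together, assuming Optim is in scope (e.g. via open Kaun) and that params and grads are parameter trees produced elsewhere:

let train_step algo state params grads =
  (* step returns additive deltas together with the updated state. *)
  let updates, state' = Optim.step algo state params grads in
  let params' = Optim.apply_updates params updates in
  (params', state')

(* Equivalently, using the convenience wrapper: *)
let train_step' algo state params grads =
  Optim.update algo state params grads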
sgd ~lr ?momentum ?nesterov () is stochastic gradient descent.
momentum defaults to 0. (no momentum). nesterov defaults to false. Nesterov mode is ignored when momentum = 0.
Raises Invalid_argument if momentum is not in 0.0 <= momentum < 1.0.
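For example, SGD with Nesterov momentum. Schedule.constant is assumed here as a stand-in for whatever constant-rate constructor the Schedule module provides:

let sgd_nesterov =
  Optim.sgd ~lr:(Schedule.constant 0.1) ~momentum:0.9 ~nesterov:true ()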
adam ~lr ?b1 ?b2 ?eps () is Adam with bias correction.
b1 defaults to 0.9. b2 defaults to 0.999. eps defaults to 1e-8.
Raises Invalid_argument if b1 or b2 is not in 0.0 <= b < 1.0, or if eps <= 0.0.
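The defaults can also be written out explicitly; the 1e-3 rate and Schedule.constant are assumptions for illustration, not library defaults:

(* Equivalent to Optim.adam ~lr:(Schedule.constant 1e-3) () *)
let adam =
  Optim.adam ~lr:(Schedule.constant 1e-3) ~b1:0.9 ~b2:0.999 ~eps:1e-8 ()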
val adamw :
lr:Schedule.t ->
?b1:float ->
?b2:float ->
?eps:float ->
?weight_decay:float ->
unit ->
algorithm
adamw ~lr ?b1 ?b2 ?eps ?weight_decay () is AdamW.
b1 defaults to 0.9. b2 defaults to 0.999. eps defaults to 1e-8. weight_decay defaults to 0.01.
Weight decay is decoupled from the Adam moment estimates.
Raises Invalid_argument if b1 or b2 is not in 0.0 <= b < 1.0, if eps <= 0.0, or if weight_decay < 0.0.
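For instance, AdamW with a stronger decoupled weight decay than the default; Schedule.constant remains an assumed constructor:

let adamw =
  Optim.adamw ~lr:(Schedule.constant 1e-3) ~weight_decay:0.1 ()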
val rmsprop :
lr:Schedule.t ->
?decay:float ->
?eps:float ->
?momentum:float ->
unit ->
algorithm
rmsprop ~lr ?decay ?eps ?momentum () is RMSprop.
decay defaults to 0.9. eps defaults to 1e-8. momentum defaults to 0. (no momentum).
Raises Invalid_argument if decay or momentum is not in 0.0 <= x < 1.0, or if eps <= 0.0.
adagrad ~lr ?eps () is Adagrad.
eps defaults to 1e-8.
Raises Invalid_argument if eps <= 0.0.
state_to_trees st is (count, trees) where count is the optimizer step count and trees are the internal state as parameter trees.
SGD without momentum returns an empty list. Adam returns [mu; nu].
state_of_trees algo ~count trees reconstructs optimizer state from an algorithm, step count, and serialized trees.
Raises Invalid_argument if the number of trees does not match the algorithm's expectation.
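A sketch of the round trip these two functions guarantee, with the persistence step elided; in practice count and trees would be saved and reloaded between the two calls:

let roundtrip algo state =
  let count, trees = Optim.state_to_trees state in
  (* ...persist count and trees here, reload them later... *)
  Optim.state_of_trees algo ~count trees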
clip_by_global_norm max_norm grads rescales grads so their global L2 norm does not exceed max_norm. Returns grads unchanged if the norm is already within bounds.
Raises Invalid_argument if a leaf tensor is not floating point.
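Clipping composes naturally with the update step, since both operate on gradient trees. A sketch, under the same scope assumptions as above and with 1.0 as an arbitrary example bound:

let clipped_step algo state params grads =
  let grads = Optim.clip_by_global_norm 1.0 grads in
  Optim.update algo state params grads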