# 1 "src/base/optimise/owl_optimise_generic_sig.ml"
(*
* OWL - OCaml Scientific Computing
* Copyright (c) 2016-2022 Liang Wang <liang@ocaml.xyz>
 *)

module type Sig = sig
  module Algodiff : Owl_algodiff_generic_sig.Sig

  open Algodiff

  (** Utils module *)
  module Utils : sig
    val sample_num : t -> int
    (** Return the total number of samples in the passed-in ndarray. *)

    val draw_samples : t -> t -> int -> t * t
    (**
       [draw_samples x y n] draws [n] samples from both [x] (observations) and
       [y] (labels). The samples are drawn along axis 0, so [x] and [y] must
       agree along axis 0.
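
       For illustration, a minimal sketch (the shapes below are assumptions
       made for the example, not requirements beyond agreement along axis 0):
       {[
         (* x : 100 x 10 observations, y : 100 x 1 labels *)
         let xs, ys = draw_samples x y 16
         (* xs and ys now hold the same 16 randomly drawn rows of x and y *)
       ]}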
    *)

    val get_chunk : t -> t -> int -> int -> t * t
    (**
       [get_chunk x y i c] gets a contiguous chunk of [c] samples starting at
       position [i] from [x] (observations) and [y] (labels).
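
       For example (a sketch; [x] and [y] are assumed to be as above):
       {[
         (* take the 10 samples starting at index 30, i.e. roughly rows 30-39 *)
         let xc, yc = get_chunk x y 30 10
       ]}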
    *)
  end

  (** Strategies for learning rate update *)
  module Learning_Rate : sig
    (** Representation of learning rate update strategies. Possible values include:
        - [Adam (alpha, beta1, beta2)], see {{:https://arxiv.org/abs/1412.6980}the Adam paper}
          for the meaning of the parameters.
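
        As a sketch of how such values are written (the numeric settings are
        illustrative only, not recommended defaults):
        {[
          let adam = Learning_Rate.Adam (0.001, 0.9, 0.999)  (* alpha, beta1, beta2 *)
          let sgd  = Learning_Rate.Const 0.01                (* fixed step size *)
        ]}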
    *)
    type typ =
      | Adagrad of float
      | Const of float
      | Decay of float * float
      | Exp_decay of float * float
      | RMSprop of float * float
      | Adam of float * float * float
      | Schedule of float array

    val run : typ -> int -> t -> t array -> t
    (** Execute the computations defined in module [typ]. *)

    val default : typ -> typ
    (** Create module [typ] with default values. *)

    val update_ch : typ -> t -> t array -> t array
    (** Update the cache of gradients. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Batch module *)
  module Batch : sig
    (** Types of batches. *)
    type typ =
      | Full
      | Mini of int
      | Sample of int
      | Stochastic

    val run : typ -> t -> t -> int -> t * t
    (** Execute the computations defined in module [typ]. *)

    val batches : typ -> t -> int
    (** Return the total number of batches given a batch [typ]. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Loss module *)
  module Loss : sig
    (** Types of loss functions. *)
    type typ =
      | Hinge
      | L1norm
      | L2norm
      | Quadratic
      | Cross_entropy
      | Custom of (t -> t -> t)

    val run : typ -> t -> t -> t
    (** Execute the computations defined in module [typ]. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Gradient module *)
  module Gradient : sig
    (** Types of gradient functions. *)
    type typ =
      | GD
      | CG
      | CD
      | NonlinearCG
      | DaiYuanCG
      | NewtonCG
      | Newton

    val run : typ -> (t -> t) -> t -> t -> t -> t -> t
    (** Execute the computations defined in module [typ]. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Momentum module *)
  module Momentum : sig
    (** Types of momentum functions. *)
    type typ =
      | Standard of float
      | Nesterov of float
      | None

    val run : typ -> t -> t -> t
    (** Execute the computations defined in module [typ]. *)

    val default : typ -> typ
    (** Create module [typ] with default values. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Regularisation module *)
  module Regularisation : sig
    (** Types of regularisation functions. *)
    type typ =
      | L1norm of float
      | L2norm of float
      | Elastic_net of float * float
      | None

    val run : typ -> t -> t
    (** Execute the computations defined in module [typ]. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Clipping module *)
  module Clipping : sig
    (** Types of clipping functions. *)
    type typ =
      | L2norm of float
      | Value of float * float
      | None

    val run : typ -> t -> t
    (** Execute the computations defined in module [typ]. *)

    val default : typ -> typ
    (** Create module [typ] with default values. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Stopping module *)
  module Stopping : sig
    (** Types of stopping functions. *)
    type typ =
      | Const of float
      | Early of int * int
      | None

    val run : typ -> float -> bool
    (** Execute the computations defined in module [typ]. *)

    val default : typ -> typ
    (** Create module [typ] with default values. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Checkpoint module *)
  module Checkpoint : sig
    type state =
      { mutable current_batch : int
      ; mutable batches_per_epoch : int
      ; mutable epochs : float
      ; mutable batches : int
      ; mutable loss : t array
      ; mutable start_at : float
      ; mutable stop : bool
      ; mutable gs : t array array
      ; mutable ps : t array array
      ; mutable us : t array array
      ; mutable ch : t array array array
      }
    (** Type definition of checkpoint state. *)

    (** Checkpoint type. *)
    type typ =
      | Batch of int
      | Epoch of float
      | Custom of (state -> unit)
      | None

    val init_state : int -> float -> state
    (**
       [init_state batches_per_epoch epochs] initialises a state by specifying
       the number of batches per epoch and the total number of epochs.
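
       For instance (a sketch; the numbers are arbitrary):
       {[
         (* 50 batches per epoch, trained for 10 epochs in total *)
         let state = Checkpoint.init_state 50 10.
       ]}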
    *)

    val default_checkpoint_fun : (string -> 'a) -> 'a
    (** This function is used for saving intermediate files during optimisation. *)

    val print_state_info : state -> unit
    (** Print out detailed information about the current [state]. *)

    val print_summary : state -> unit
    (** Print out a summary of the current [state]. *)

    val run : typ -> (string -> unit) -> int -> t -> state -> unit
    (** Execute the computations defined in module [typ]. *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** Params module *)
  module Params : sig
    type typ =
      { mutable epochs : float
      ; mutable batch : Batch.typ
      ; mutable gradient : Gradient.typ
      ; mutable loss : Loss.typ
      ; mutable learning_rate : Learning_Rate.typ
      ; mutable regularisation : Regularisation.typ
      ; mutable momentum : Momentum.typ
      ; mutable clipping : Clipping.typ
      ; mutable stopping : Stopping.typ
      ; mutable checkpoint : Checkpoint.typ
      ; mutable verbosity : bool
      }
    (** Type definition of the optimisation parameters. *)

    val default : unit -> typ
    (** Create module [typ] with default values. *)

    val config
      :  ?batch:Batch.typ
      -> ?gradient:Gradient.typ
      -> ?loss:Loss.typ
      -> ?learning_rate:Learning_Rate.typ
      -> ?regularisation:Regularisation.typ
      -> ?momentum:Momentum.typ
      -> ?clipping:Clipping.typ
      -> ?stopping:Stopping.typ
      -> ?checkpoint:Checkpoint.typ
      -> ?verbosity:bool
      -> float
      -> typ
    (**
       This function creates a parameter record from the given optional
       settings; fields that are not specified keep their default values. The
       final [float] argument is the number of epochs.
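
       For illustration only (the values below are arbitrary, not recommended
       defaults), a configuration that trains for 50 epochs with mini-batches
       of 128 samples, Adam updates and L2-norm gradient clipping:
       {[
         let params =
           Params.config
             ~batch:(Batch.Mini 128)
             ~learning_rate:(Learning_Rate.Adam (0.001, 0.9, 0.999))
             ~loss:Loss.Cross_entropy
             ~clipping:(Clipping.L2norm 5.)
             ~verbosity:true
             50.
       ]}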
    *)

    val to_string : typ -> string
    (** Convert the module [typ] to its string representation. *)
  end

  (** {4 Core functions} *)

  val minimise_weight
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t -> t)
    -> t
    -> t
    -> t
    -> Checkpoint.state * t
  (**
     This function minimises the weight [w] of the passed-in function [f].
     - [f] is a function [f : w -> x -> y].
     - [w] is a row vector, but [y] can have any shape.
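
     As a sketch (the shapes and settings below are assumptions made for the
     example only): with [x] of shape [n x d], [y] of shape [n x 1] and an
     initial weight [w0] of shape [1 x d], a linear model can be fitted as
     {[
       (* prediction of the model: y_hat = x * w^T *)
       let f w x = Maths.(x *@ transpose w)

       let params =
         Params.config
           ~loss:Loss.Quadratic
           ~learning_rate:(Learning_Rate.Adam (0.001, 0.9, 0.999))
           100.

       let _state, w' = minimise_weight params f w0 x y
     ]}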
  *)

  val minimise_network
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t * t array array)
    -> (t -> t array array * t array array)
    -> (t array array -> unit)
    -> (string -> unit)
    -> t
    -> t
    -> Checkpoint.state
  (**
     This function is specifically designed for minimising the weights of a
     neural network with a graph structure. In Owl's earlier versions, the
     functions in the regression module were implemented on top of it.
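
     Schematically (a sketch only; the precise behaviour of each callback is
     fixed by the network implementation that supplies it, typically Owl's
     neural network graph module):
     {[
       let state =
         minimise_network params
           forward  (* t -> t * t array array:
                       evaluate the network on a batch and return its output
                       together with its parameters *)
           backward (* t -> t array array * t array array:
                       back-propagate from the output and return the parameter
                       values and their gradients *)
           update   (* t array array -> unit: write updated parameters back *)
           save     (* string -> unit: persist the network to a file *)
           x y
     ]}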
  *)

  val minimise_fun
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t)
    -> t
    -> Checkpoint.state * t
  (**
     This function minimises [f : x -> y] w.r.t. [x]; [x] is an ndarray and
     [y] is a scalar value.
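
     A minimal sketch (the target value [3.], the settings and the initial
     point [x0] are arbitrary assumptions for the example):
     {[
       (* f x = sum ((x - 3)^2), minimised over the ndarray x *)
       let f x = Maths.(sum' (sqr (x - pack_flt 3.)))

       let params = Params.config ~learning_rate:(Learning_Rate.Const 0.1) 100.

       let _state, x' = minimise_fun params f x0
       (* x' should now be close to an ndarray filled with 3. *)
     ]}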
  *)

  val minimise_compiled_network
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t -> t)
    -> (unit -> unit)
    -> (string -> unit)
    -> t
    -> t
    -> Checkpoint.state
  (** This function minimises the weights of a compiled neural network with a graph structure. *)
end