# 1 "src/base/optimise/owl_optimise_generic_sig.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

(** Signature of the generic optimisation engine, parameterised over an
    algorithmic-differentiation module. All values of type ``t`` below come
    from the opened ``Algodiff`` module. *)
module type Sig = sig
  module Algodiff : Owl_algodiff_generic_sig.Sig

  open Algodiff

  (** {6 Utils module} *)

  module Utils : sig
    val sample_num : t -> int
    (** Return the total number of samples in the passed-in ndarray. *)

    val draw_samples : t -> t -> int -> t * t
    (**
``draw_samples x y n`` draws samples from both ``x`` (observations) and ``y``
(labels). The samples are drawn along axis 0, so ``x`` and ``y`` must agree
along axis 0.

NOTE(review): ``n`` is presumably the number of samples to draw; the original
documentation did not mention this argument. Confirm against the
implementation.
     *)

    val get_chunk : t -> t -> int -> int -> t * t
    (**
``get_chunk x y i c`` gets a contiguous chunk of ``c`` samples starting from
position ``i`` from ``x`` (observations) and ``y`` (labels).
     *)
  end

  (** {6 Learning_Rate module} *)

  module Learning_Rate : sig
    type typ =
      | Adagrad of float
      | Const of float
      | Decay of float * float
      | Exp_decay of float * float
      | RMSprop of float * float
      | Adam of float * float * float
      | Schedule of float array
    (** Types of learning rate. *)

    val run : typ -> int -> t -> t array -> t
    (** Execute the computation selected by the given ``typ`` value. *)

    val default : typ -> typ
    (** Create a ``typ`` value with default values. *)

    val update_ch : typ -> t -> t array -> t array
    (** Update the cache of gradients. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Batch module} *)

  module Batch : sig
    type typ =
      | Full
      | Mini of int
      | Sample of int
      | Stochastic
    (** Types of batches. *)

    val run : typ -> t -> t -> int -> t * t
    (** Execute the computation selected by the given ``typ`` value. *)

    val batches : typ -> t -> int
    (** Return the total number of batches given a batch ``typ``. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Loss module} *)

  module Loss : sig
    type typ =
      | Hinge
      | L1norm
      | L2norm
      | Quadratic
      | Cross_entropy
      | Custom of (t -> t -> t)
    (** Types of loss functions. *)

    val run : typ -> t -> t -> t
    (** Execute the computation selected by the given ``typ`` value. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Gradient module} *)

  module Gradient : sig
    type typ =
      | GD
      | CG
      | CD
      | NonlinearCG
      | DaiYuanCG
      | NewtonCG
      | Newton
    (** Types of gradient functions. *)

    val run : typ -> (t -> t) -> t -> t -> t -> t -> t
    (** Execute the computation selected by the given ``typ`` value. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Momentum module} *)

  module Momentum : sig
    type typ =
      | Standard of float
      | Nesterov of float
      | None
    (** Types of momentum functions. *)

    val run : typ -> t -> t -> t
    (** Execute the computation selected by the given ``typ`` value. *)

    val default : typ -> typ
    (** Create a ``typ`` value with default values. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Regularisation module} *)

  module Regularisation : sig
    type typ =
      | L1norm of float
      | L2norm of float
      | Elastic_net of float * float
      | None
    (** Types of regularisation functions. *)

    val run : typ -> t -> t
    (** Execute the computation selected by the given ``typ`` value. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Clipping module} *)

  module Clipping : sig
    type typ =
      | L2norm of float
      | Value of float * float
      | None
    (** Types of clipping functions. *)

    val run : typ -> t -> t
    (** Execute the computation selected by the given ``typ`` value. *)

    val default : typ -> typ
    (** Create a ``typ`` value with default values. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Stopping module} *)

  module Stopping : sig
    type typ =
      | Const of float
      | Early of int * int
      | None
    (** Types of stopping functions. *)

    val run : typ -> float -> bool
    (** Execute the computation selected by the given ``typ`` value. *)

    val default : typ -> typ
    (** Create a ``typ`` value with default values. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Checkpoint module} *)

  module Checkpoint : sig
    type state =
      { mutable current_batch : int
      ; mutable batches_per_epoch : int
      ; mutable epochs : float
      ; mutable batches : int
      ; mutable loss : t array
      ; mutable start_at : float
      ; mutable stop : bool
      ; mutable gs : t array array
      ; mutable ps : t array array
      ; mutable us : t array array
      ; mutable ch : t array array array
      }
    (** Type definition of the checkpoint state. Every field is mutable. *)

    type typ =
      | Batch of int
      | Epoch of float
      | Custom of (state -> unit)
      | None
    (** Types of checkpoint. *)

    val init_state : int -> float -> state
    (**
``init_state batches_per_epoch epochs`` initialises a state by specifying the
number of batches per epoch and the number of epochs in total.
     *)

    val default_checkpoint_fun : (string -> 'a) -> 'a
    (** This function is used for saving intermediate files during optimisation. *)

    val print_state_info : state -> unit
    (** Print out the detailed information of the current ``state``. *)

    val print_summary : state -> unit
    (** Print out the summary of the current ``state``. *)

    val run : typ -> (string -> unit) -> int -> t -> state -> unit
    (** Execute the computation selected by the given ``typ`` value. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Params module} *)

  module Params : sig
    type typ =
      { mutable epochs : float
      ; mutable batch : Batch.typ
      ; mutable gradient : Gradient.typ
      ; mutable loss : Loss.typ
      ; mutable learning_rate : Learning_Rate.typ
      ; mutable regularisation : Regularisation.typ
      ; mutable momentum : Momentum.typ
      ; mutable clipping : Clipping.typ
      ; mutable stopping : Stopping.typ
      ; mutable checkpoint : Checkpoint.typ
      ; mutable verbosity : bool
      }
    (** Type definition of the parameters. *)

    val default : unit -> typ
    (** Create a ``typ`` value with default values. *)

    val config
      :  ?batch:Batch.typ
      -> ?gradient:Gradient.typ
      -> ?loss:Loss.typ
      -> ?learning_rate:Learning_Rate.typ
      -> ?regularisation:Regularisation.typ
      -> ?momentum:Momentum.typ
      -> ?clipping:Clipping.typ
      -> ?stopping:Stopping.typ
      -> ?checkpoint:Checkpoint.typ
      -> ?verbosity:bool
      -> float
      -> typ
    (**
This function creates a parameter object with many configurations.

NOTE(review): the final positional ``float`` is presumably the number of
epochs, since ``epochs`` is the only field of ``typ`` without a matching
optional argument. Confirm against the implementation.
     *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Core functions} *)

  val minimise_weight
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t -> t)
    -> t
    -> t
    -> t
    -> Checkpoint.state * t
  (**
This function minimises the weight ``w`` of passed-in function ``f``.

* ``f`` is a function ``f : w -> x -> y``.
* ``w`` is a row vector but ``y`` can have any shape.

The result pairs a checkpoint state with a value of type ``t``.
   *)

  val minimise_network
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t * t array array)
    -> (t -> t array array * t array array)
    -> (t array array -> unit)
    -> (string -> unit)
    -> t
    -> t
    -> Checkpoint.state
  (**
This function is specifically designed for minimising the weights in a neural
network of graph structure. In Owl's earlier versions, the functions in the
regression module were actually implemented using this function.
   *)

  val minimise_fun
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t)
    -> t
    -> Checkpoint.state * t
  (**
This function minimises ``f : x -> y`` w.r.t ``x``.

``x`` is an ndarray; and ``y`` is a scalar value.
   *)

  val minimise_compiled_network
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t -> t)
    -> (unit -> unit)
    -> (string -> unit)
    -> t
    -> t
    -> Checkpoint.state
  (**
TODO: this function is not documented yet.

NOTE(review): the signature parallels ``minimise_network`` (optional
checkpoint state, parameters, callbacks, two values of type ``t``, and a
checkpoint state as the result). The role of each argument must be confirmed
against the implementation.
   *)
end