# 1 "src/base/optimise/owl_optimise_generic.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*
 * Core optimisation algorithms and API ported from Hype
 * (http://hypelib.github.io/Hype/), copyright (c) 2015-2018 National
 * University of Ireland Maynooth (Barak A. Pearlmutter <barak@pearlmutter.net>),
 * 2015-2016 National University of Ireland Maynooth (Atilim Gunes Baydin),
* 2016-2018 University of Oxford (Atilim Gunes Baydin <gunes@robots.ox.ac.uk>)
 *)

(**
Optimisation engine

This module provides the fundamental support for Owl's regression and neural
network modules. It supports both single- and double-precision float numbers.
*)
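
(* Usage note (illustrative, not part of this module): the functor below is normally
   instantiated with an Algodiff module; the exact module paths depend on how Owl is
   packaged and are given here only as an assumed example.

     module O = Owl_optimise_generic.Make (Owl_algodiff.D)

   The resulting module exposes [Params], [minimise_weight], [minimise_network],
   [minimise_fun] and [minimise_compiled_network], which back Owl's regression and
   neural network modules. *)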
[@@@warning "-45"]

(* Make functor starts *)
module Make (Algodiff : Owl_algodiff_generic_sig.Sig) = struct
  module Algodiff = Algodiff
  open Algodiff

  module Utils = struct
    let sample_num x =
      match x with
      | Arr _ -> Arr.(shape x).(0)
      | x     -> failwith ("Owl_optimise.Utils.sample_num:" ^ type_info x)

    let draw_samples x y n =
      match x, y with
      | Arr x, Arr y ->
        let x, i = A.draw ~axis:0 x n in
        let y = A.rows y i in
        Arr x, Arr y
      | x, _ -> failwith ("Owl_optimise.Utils.draw_samples:" ^ type_info x)

    let get_chunk x y i c =
      match x, y with
      | Arr x, Arr y ->
        let n = A.row_num y in
        let a = i * c mod n in
        let b = a + c - 1 in
        if b < n
        then (
          let x = A.get_slice [ [ a; b ] ] x in
          let y = A.get_slice [ [ a; b ] ] y in
          Arr x, Arr y)
        else (
          let x0 = A.get_slice [ [ a; n - 1 ] ] x in
          let y0 = A.get_slice [ [ a; n - 1 ] ] y in
          let x1 = A.get_slice [ [ 0; b - n ] ] x in
          let y1 = A.get_slice [ [ 0; b - n ] ] y in
          let x = A.concatenate ~axis:0 [| x0; x1 |] in
          let y = A.concatenate ~axis:0 [| y0; y1 |] in
          Arr x, Arr y)
      | x, _ -> failwith ("Owl_optimise.Utils.get_chunk:" ^ type_info x)
  end

  module Learning_Rate = struct
    type typ =
      | Adagrad of float
      | Const of float
      | Decay of float * float
      | Exp_decay of float * float
      | RMSprop of float * float
      | Adam of float * float * float
      | Schedule of float array

    let run = function
      | Adagrad a -> fun _ _ c -> Maths.(_f a / sqrt (c.(0) + _f 1e-32))
      | Const a -> fun _ _ _ -> _f a
      | Decay (a, k) -> fun i _ _ -> Maths.(_f a / (_f 1. + (_f k * _f (float_of_int i))))
      | Exp_decay (a, k) -> fun i _ _ -> Maths.(_f a * exp (neg (_f k) * _f (float_of_int i)))
      | RMSprop (a, _) -> fun _ _ c -> Maths.(_f a / sqrt (c.(0) + _f 1e-32))
      | Adam (a, b1, b2) ->
        fun i g c ->
          Maths.(
            _f a
            * sqrt (_f 1. - (_f b2 ** _f (float_of_int i)))
            / (_f 1. - (_f b1 ** _f (float_of_int i)))
            * c.(0)
            / (sqrt c.(1) + _f 1e-8)
            / (g + _f 1e-32))
      | Schedule a -> fun i _ _ -> _f a.(i mod Array.length a)

    let default = function
      | Adagrad _ -> Adagrad 0.01
      | Const _ -> Const 0.001
      | Decay _ -> Decay (0.1, 0.1)
      | Exp_decay _ -> Exp_decay (1., 0.1)
      | RMSprop _ -> RMSprop (0.001, 0.9)
      | Adam _ -> Adam (0.001, 0.9, 0.999)
      | Schedule _ -> Schedule [| 0.001 |]

    let update_ch typ g c =
      match typ with
      | Adagrad _ -> [| Maths.(c.(0) + (g * g)); c.(1) |]
      | RMSprop (_, k) -> [| Maths.((_f k * c.(0)) + ((_f 1. - _f k) * g * g)); c.(1) |]
      | Adam (_, b1, b2) ->
        let m = Maths.((_f b1 * c.(0)) + ((_f 1. - _f b1) * g)) in
        let v = Maths.((_f b2 * c.(1)) + ((_f 1. - _f b2) * g * g)) in
        [| m; v |]
      | _ -> c

    let to_string = function
      | Adagrad a -> Printf.sprintf "adagrad %g" a
      | Const a -> Printf.sprintf "constant %g" a
      | Decay (a, k) -> Printf.sprintf "decay (%g, %g)" a k
      | Exp_decay (a, k) -> Printf.sprintf "exp_decay (%g, %g)" a k
      | RMSprop (a, k) -> Printf.sprintf "rmsprop (%g, %g)" a k
      | Adam (a, b1, b2) -> Printf.sprintf "adam (%g, %g, %g)" a b1 b2
      | Schedule a -> Printf.sprintf "schedule %i" (Array.length a)
  end

  module Batch = struct
    type typ =
      | Full
      | Mini of int
      | Sample of int
      | Stochastic

    let run typ x y i =
      match typ with
      | Full -> x, y
      | Mini c -> Utils.get_chunk x y i c
      | Sample c -> Utils.draw_samples x y c
      | Stochastic -> Utils.draw_samples x y 1

    let batches typ x =
      match typ with
      | Full -> 1
      | Mini c -> Utils.sample_num x / c
      | Sample c -> Utils.sample_num x / c
      | Stochastic -> Utils.sample_num x

    let to_string = function
      | Full -> Printf.sprintf "%s" "full"
      | Mini c -> Printf.sprintf "mini of %i" c
      | Sample c -> Printf.sprintf "sample of %i" c
      | Stochastic -> Printf.sprintf "%s" "stochastic"
  end

  module Loss = struct
    type typ =
      | Hinge
      | L1norm
      | L2norm
      | Quadratic
      | Cross_entropy
      | Custom of (t -> t -> t)

    let run typ y y' =
      match typ with
      | Hinge -> Maths.(sum' (max2 (_f 0.) (_f 1. - (y * y'))))
      | L1norm -> Maths.(l1norm' (y - y'))
      | L2norm -> Maths.(l2norm' (y - y'))
      | Quadratic -> Maths.(l2norm_sqr' (y - y'))
      | Cross_entropy -> Maths.(cross_entropy y y')
      | Custom f -> f y y' (* y': prediction *)

    let to_string = function
      | Hinge -> "Hinge"
      | L1norm -> "l1norm"
      | L2norm -> "l2norm"
      | Quadratic -> "quadratic"
      | Cross_entropy -> "cross_entropy"
      | Custom _ -> "customise"
  end

  module Gradient = struct
    type typ =
      | GD (* classic gradient descent *)
      | CG (* Hestenes and Stiefel 1952 *)
      | CD (* Fletcher 1987 *)
      | NonlinearCG (* Fletcher and Reeves 1964 *)
      | DaiYuanCG (* Dai and Yuan 1999 *)
      | NewtonCG (* Newton conjugate gradient *)
      | Newton (* Exact Newton *)
    let run = function
      | GD -> fun _ _ _ _ g' -> Maths.neg g'
      | CG ->
        fun _ _ g p g' ->
          let y = Maths.(g' - g) in
          let b = Maths.(sum' (g' * y) / (sum' (p * y) + _f 1e-32)) in
          Maths.(neg g' + (b * p))
      | CD ->
        fun _ _ g p g' ->
          let b = Maths.(l2norm_sqr' g' / sum' (neg p * g)) in
          Maths.(neg g' + (b * p))
      | NonlinearCG ->
        fun _ _ g p g' ->
          let b = Maths.(l2norm_sqr' g' / l2norm_sqr' g) in
          Maths.(neg g' + (b * p))
      | DaiYuanCG ->
        fun _ _ g p g' ->
          let y = Maths.(g' - g) in
          let b = Maths.(l2norm_sqr' g' / sum' (p * y)) in
          Maths.(neg g' + (b * p))
      | NewtonCG ->
        fun f w _ p g' ->
          (* TODO: NOT FINISHED *)
          let hv = hessianv f w p |> Maths.transpose in
          let b = Maths.(hv *@ g' / (hv *@ p)) in
          Maths.(neg g' + (p *@ b))
      | Newton ->
        fun f w _ _ _ ->
          let g', h' = gradhessian f w in
          Maths.(neg (g' *@ inv h'))

    let to_string = function
      | GD -> "gradient descent"
      | CG -> "conjugate gradient"
      | CD -> "conjugate descent"
      | NonlinearCG -> "nonlinear conjugate gradient"
      | DaiYuanCG -> "dai & yuan conjugate gradient"
      | NewtonCG -> "newton conjugate gradient"
      | Newton -> "newton"
  end

  module Momentum = struct
    type typ =
      | Standard of float
      | Nesterov of float
      | None

    let run = function
      | Standard m -> fun u u' -> Maths.((_f m * u) + u')
      | Nesterov m -> fun u u' -> Maths.((_f m * _f m * u) + ((_f m + _f 1.) * u'))
      | None -> fun _ u' -> u'

    let default = function
      | Standard _ -> Standard 0.9
      | Nesterov _ -> Nesterov 0.9
      | None -> None

    let to_string = function
      | Standard m -> Printf.sprintf "standard %g" m
      | Nesterov m -> Printf.sprintf "nesterov %g" m
      | None -> Printf.sprintf "none"
  end

  module Regularisation = struct
    type typ =
      | L1norm of float
      | L2norm of float
      | Elastic_net of float * float
      | None

    let run typ x =
      match typ with
      | L1norm a -> Maths.(_f a * l1norm' x)
      | L2norm a -> Maths.(_f a * l2norm' x)
      | Elastic_net (a, b) -> Maths.((_f a * l1norm' x) + (_f b * l2norm' x))
      | None -> _f 0.

    let to_string = function
      | L1norm a -> Printf.sprintf "l1norm (alpha = %g)" a
      | L2norm a -> Printf.sprintf "l2norm (alpha = %g)" a
      | Elastic_net (a, b) -> Printf.sprintf "elastic net (a = %g, b = %g)" a b
      | None -> "none"
  end

  module Clipping = struct
    type typ =
      | L2norm of float
      | Value of float * float (* min, max *)
      | None

    let run typ x =
      match typ with
      | L2norm t -> clip_by_l2norm (A.float_to_elt t) x
      | Value (a, b) -> clip_by_value ~amin:(A.float_to_elt a) ~amax:(A.float_to_elt b) x
      | None -> x

    let default = function
      | L2norm _ -> L2norm 1.
      | Value _ -> Value (0., 1.)
      | None -> None

    let to_string = function
      | L2norm t -> Printf.sprintf "l2norm (threshold = %g)" t
      | Value (a, b) -> Printf.sprintf "value (min = %g, max = %g)" a b
      | None -> "none"
  end

  module Stopping = struct
    type typ =
      | Const of float
      | Early of int * int (* stagnation patience, overfitting patience *)
      | None

    let run typ x =
      match typ with
      | Const a -> x < a
      | Early (_, _) -> failwith "not implemented" (* TODO *)
      | None -> false

    let default = function
      | Const _ -> Const 1e-6
      | Early _ -> Early (750, 10)
      | None -> None

    let to_string = function
      | Const a -> Printf.sprintf "const (a = %g)" a
      | Early (s, o) -> Printf.sprintf "early (s = %i, o = %i)" s o
      | None -> "none"
  end
  module Checkpoint = struct
    type state =
      { mutable current_batch : int (* current iteration progress in batch *)
      ; mutable batches_per_epoch : int (* number of batches in each epoch *)
      ; mutable epochs : float (* total number of epochs to run *)
      ; mutable batches : int (* total batches = batches_per_epoch * epochs *)
      ; mutable loss : t array (* history of loss value in each iteration *)
      ; mutable start_at : float (* time when the optimisation starts *)
      ; mutable stop : bool (* optimisation stops if true, otherwise false *)
      ; mutable gs : t array array (* gradient of the previous iteration *)
      ; mutable ps : t array array (* direction of the previous iteration *)
      ; mutable us : t array array (* direction update of the previous iteration *)
      ; mutable ch : t array array array (* gcache of the previous iteration *)
      }

    type typ =
      | Batch of int (* default checkpoint at every specified batch interval *)
      | Epoch of float (* default checkpoint at every specified epoch interval *)
      | Custom of (state -> unit) (* customised checkpoint called at every batch *)
      | None (* no checkpoint at all, or interval is infinity *)

    let init_state batches_per_epoch epochs =
      let batches = float_of_int batches_per_epoch *. epochs |> int_of_float in
      { current_batch = 1
      ; batches_per_epoch
      ; epochs
      ; batches
      ; loss = Array.make (batches + 1) (_f 0.)
      ; start_at = Unix.gettimeofday ()
      ; stop = false
      ; gs = [| [| _f 0. |] |]
      ; ps = [| [| _f 0. |] |]
      ; us = [| [| _f 0. |] |]
      ; ch = [| [| [| _f 0.; _f 0. |] |] |]
      }

    let default_checkpoint_fun save_fun =
      let file_name =
        Printf.sprintf "%s/%s.%i" (Sys.getcwd ()) "model" (Unix.time () |> int_of_float)
      in
      Owl_log.info "checkpoint => %s" file_name;
      save_fun file_name

    let print_state_info state =
      let b_i = state.current_batch in
      let b_n = state.batches in
      let e_n = state.epochs in
      let e_i = float_of_int b_i /. (float_of_int b_n /. e_n) in
      let l0 = state.loss.(b_i - 1) |> unpack_flt in
      let l1 = state.loss.(b_i) |> unpack_flt in
      let d = l0 -. l1 in
      let s = if d = 0. then "-" else if d < 0. then "▲" else "▼" in
      let t = Unix.gettimeofday () -. state.start_at |> Owl_utils.format_time in
      Owl_log.info "T: %s | E: %.1f/%g | B: %i/%i | L: %.6f[%s]" t e_i e_n b_i b_n l1 s

    let print_summary state =
      Unix.gettimeofday () -. state.start_at
      |> Owl_utils.format_time
      |> Printf.printf "--- Training summary\n Duration: %s\n"
      |> flush_all

    let run typ save_fun current_batch current_loss state =
      state.loss.(current_batch) <- primal' current_loss;
      state.stop <- state.current_batch >= state.batches;
      let interval =
        match typ with
        | Batch i -> i
        | Epoch i -> i *. float_of_int state.batches_per_epoch |> int_of_float
        | Custom _ -> 1
        | None -> max_int
      in
      if state.current_batch mod interval = 0 && state.current_batch < state.batches
      then (
        match typ with
        | Custom f -> f state
        | _ -> default_checkpoint_fun save_fun)

    let to_string = function
      | Batch i -> Printf.sprintf "per %i batches" i
      | Epoch i -> Printf.sprintf "per %g epochs" i
      | Custom _ -> Printf.sprintf "customised f"
      | None -> Printf.sprintf "none"
  end
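
  (* Illustrative sketch (not part of this file's code): a [Checkpoint.Custom]
     function receives the optimisation [state] at every batch, e.g. to log the
     latest loss value with Owl's logger.

       let chkp =
         Checkpoint.Custom
           (fun s ->
             Checkpoint.(
               Owl_log.info "batch %i, loss %g" s.current_batch
                 (unpack_flt s.loss.(s.current_batch))))
  *)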
  module Params = struct
    type typ =
      { mutable epochs : float
      ; mutable batch : Batch.typ
      ; mutable gradient : Gradient.typ
      ; mutable loss : Loss.typ
      ; mutable learning_rate : Learning_Rate.typ
      ; mutable regularisation : Regularisation.typ
      ; mutable momentum : Momentum.typ
      ; mutable clipping : Clipping.typ
      ; mutable stopping : Stopping.typ
      ; mutable checkpoint : Checkpoint.typ
      ; mutable verbosity : bool
      }

    let default () =
      { epochs = 1.
      ; batch = Batch.Sample 100
      ; gradient = Gradient.GD
      ; loss = Loss.Cross_entropy
      ; learning_rate = Learning_Rate.(default (Const 0.))
      ; regularisation = Regularisation.None
      ; momentum = Momentum.None
      ; clipping = Clipping.None
      ; stopping = Stopping.None
      ; checkpoint = Checkpoint.None
      ; verbosity = true
      }

    let config
        ?batch
        ?gradient
        ?loss
        ?learning_rate
        ?regularisation
        ?momentum
        ?clipping
        ?stopping
        ?checkpoint
        ?verbosity
        epochs
      =
      let p = default () in
      (match batch with
      | Some x -> p.batch <- x
      | None -> ());
      (match gradient with
      | Some x -> p.gradient <- x
      | None -> ());
      (match loss with
      | Some x -> p.loss <- x
      | None -> ());
      (match learning_rate with
      | Some x -> p.learning_rate <- x
      | None -> ());
      (match regularisation with
      | Some x -> p.regularisation <- x
      | None -> ());
      (match momentum with
      | Some x -> p.momentum <- x
      | None -> ());
      (match clipping with
      | Some x -> p.clipping <- x
      | None -> ());
      (match stopping with
      | Some x -> p.stopping <- x
      | None -> ());
      (match checkpoint with
      | Some x -> p.checkpoint <- x
      | None -> ());
      (match verbosity with
      | Some x -> p.verbosity <- x
      | None -> ());
      p.epochs <- epochs;
      p

    let to_string p =
      Printf.sprintf "--- Training config\n"
      ^ Printf.sprintf " epochs : %g\n" p.epochs
      ^ Printf.sprintf " batch : %s\n" (Batch.to_string p.batch)
      ^ Printf.sprintf " method : %s\n" (Gradient.to_string p.gradient)
      ^ Printf.sprintf " loss : %s\n" (Loss.to_string p.loss)
      ^ Printf.sprintf " learning rate : %s\n" (Learning_Rate.to_string p.learning_rate)
      ^ Printf.sprintf " regularisation : %s\n" (Regularisation.to_string p.regularisation)
      ^ Printf.sprintf " momentum : %s\n" (Momentum.to_string p.momentum)
      ^ Printf.sprintf " clipping : %s\n" (Clipping.to_string p.clipping)
      ^ Printf.sprintf " stopping : %s\n" (Stopping.to_string p.stopping)
      ^ Printf.sprintf " checkpoint : %s\n" (Checkpoint.to_string p.checkpoint)
      ^ Printf.sprintf " verbosity : %s\n" (if p.verbosity then "true" else "false")
      ^ "---"
  end
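
  (* Illustrative sketch (not part of this file's code): a typical way to build a
     parameter record with [Params.config]; the numeric values below are arbitrary
     examples.

       let params =
         Params.config
           ~batch:(Batch.Mini 128)
           ~learning_rate:(Learning_Rate.Adam (0.001, 0.9, 0.999))
           ~loss:Loss.Cross_entropy
           ~verbosity:true
           10.
  *)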
  (* core optimisation functions *)

  (* This function minimises the weight [w] of the passed-in function [f].
     [f] is a function [f : w -> x -> y].
     [w] is a row vector but [y] can have any shape. *)
  let minimise_weight ?state params f w x y =
    let open Params in
    if params.verbosity = true && state = None
    then print_endline (Params.to_string params);
    (* make alias functions *)
    let bach_fun = Batch.run params.batch in
    let loss_fun = Loss.run params.loss in
    let grad_fun = Gradient.run params.gradient in
    let rate_fun = Learning_Rate.run params.learning_rate in
    let regl_fun = Regularisation.run params.regularisation in
    let momt_fun = Momentum.run params.momentum in
    let upch_fun = Learning_Rate.update_ch params.learning_rate in
    let clip_fun = Clipping.run params.clipping in
    let stop_fun = Stopping.run params.stopping in
    let chkp_fun = Checkpoint.run params.checkpoint in
    (* make the function to minimise *)
    let optz_fun xi yi wi = Maths.(loss_fun yi (f wi xi) + regl_fun wi) in
    (* operations in the ith iteration *)
    let iterate i w =
      let xi, yi = bach_fun x y i in
      let optz = optz_fun xi yi in
      let loss, g = grad' optz w in
      loss |> primal', g, optz
    in
    (* init new or continue previous state of optimisation process *)
    let state =
      match state with
      | Some state -> state
      | None ->
        let batches_per_epoch = Batch.batches params.batch x in
        let state = Checkpoint.init_state batches_per_epoch params.epochs in
        (* first iteration to bootstrap the optimisation *)
        let loss, _g0, _ = iterate 0 w in
        (* variables used for specific gradient method *)
        Checkpoint.(state.gs <- [| [| _g0 |] |]);
        Checkpoint.(state.ps <- [| [| Maths.(neg _g0) |] |]);
        Checkpoint.(state.us <- [| [| _f 0. |] |]);
        Checkpoint.(state.ch <- [| [| [| _f 0.; _f 0. |] |] |]);
        Checkpoint.(state.loss.(0) <- primal' loss);
        state
    in
    (* try to iterate all batches *)
    let w = ref w in
    while Checkpoint.(state.stop = false) do
      let loss', g', optz' = iterate Checkpoint.(state.current_batch) !w in
      (* check if the stopping criterion is met *)
      Checkpoint.(state.stop <- stop_fun (unpack_flt loss'));
      (* checkpoint of the optimisation if necessary *)
      chkp_fun (fun _ -> ()) Checkpoint.(state.current_batch) loss' state;
      (* print out the current state of optimisation *)
      if params.verbosity = true then Checkpoint.print_state_info state;
      (* clip the gradient if necessary *)
      let g' = clip_fun g' in
      (* calculate gradient descent *)
      let p' = Checkpoint.(grad_fun optz' !w state.gs.(0).(0) state.ps.(0).(0) g') in
      (* update gcache if necessary *)
      Checkpoint.(state.ch.(0).(0) <- upch_fun g' state.ch.(0).(0));
      (* adjust direction based on learning_rate *)
      let u' = Checkpoint.(Maths.(p' * rate_fun state.current_batch g' state.ch.(0).(0))) in
      (* adjust direction based on momentum *)
      let u' = momt_fun Checkpoint.(state.us.(0).(0)) u' in
      (* update the weight *)
      w := Maths.(!w + u') |> primal';
      (* save historical data *)
      (if params.momentum <> Momentum.None then Checkpoint.(state.us.(0).(0) <- u'));
      Checkpoint.(state.gs.(0).(0) <- g');
      Checkpoint.(state.ps.(0).(0) <- p');
      Checkpoint.(state.current_batch <- state.current_batch + 1);
      (* force GC to release bigarray memory *)
      Gc.minor ()
    done;
    (* print optimisation summary *)
    if params.verbosity = true && Checkpoint.(state.current_batch >= state.batches)
    then Checkpoint.print_summary state;
    (* return both loss history and weight *)
    state, !w
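
  (* Illustrative sketch (not part of this file's code): fitting a linear model
     [y = x *@ w] with [minimise_weight], using a quadratic loss and a constant
     learning rate. [x], [y] and the initial [w] are assumed to be Algodiff
     values (e.g. [Arr ...]); the numbers are arbitrary examples.

       let params =
         Params.config ~loss:Loss.Quadratic ~learning_rate:(Learning_Rate.Const 0.01) 50.
       in
       let f w x = Maths.(x *@ w) in
       let _state, w' = minimise_weight params f w x y in
       ...
  *)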
  (* This function is specifically designed for minimising the weights in a
     neural network of graph structure. In Owl's earlier versions, the functions
     in the regression module were actually implemented using this function. *)
  let minimise_network ?state params forward backward update save x y =
    let open Params in
    if params.verbosity = true && state = None
    then print_endline (Params.to_string params);
    (* make alias functions *)
    let bach_fun = Batch.run params.batch in
    let loss_fun = Loss.run params.loss in
    let grad_fun = Gradient.run params.gradient in
    let rate_fun = Learning_Rate.run params.learning_rate in
    let regl_fun = Regularisation.run params.regularisation in
    let momt_fun = Momentum.run params.momentum in
    let upch_fun = Learning_Rate.update_ch params.learning_rate in
    let clip_fun = Clipping.run params.clipping in
    let stop_fun = Stopping.run params.stopping in
    let chkp_fun = Checkpoint.run params.checkpoint in
    (* operations in the ith iteration *)
    let iterate i =
      let xt, yt = bach_fun x y i in
      let yt', ws = forward xt in
      let loss = loss_fun yt yt' in
      (* take the mean of the loss *)
      let loss = Maths.(loss / _f (Mat.row_num yt |> float_of_int)) in
      (* add regularisation term if necessary *)
      let reg =
        match params.regularisation <> Regularisation.None with
        | true -> Owl_utils.aarr_fold (fun a w -> Maths.(a + regl_fun w)) (_f 0.) ws
        | false -> _f 0.
      in
      let loss = Maths.(loss + reg) in
      let ws, gs' = backward loss in
      loss |> primal', ws, gs'
    in
    (* init new or continue previous state of optimisation process *)
    let state =
      match state with
      | Some state -> state
      | None ->
        let batches_per_epoch = Batch.batches params.batch x in
        let state = Checkpoint.init_state batches_per_epoch params.epochs in
        (* first iteration to bootstrap the optimisation *)
        let loss, _ws, _gs = iterate 0 in
        update _ws;
        (* variables used for specific gradient method *)
        Checkpoint.(state.gs <- _gs);
        Checkpoint.(state.ps <- Owl_utils.aarr_map Maths.neg _gs);
        Checkpoint.(state.us <- Owl_utils.aarr_map (fun _ -> _f 0.) _gs);
        Checkpoint.(state.ch <- Owl_utils.aarr_map (fun _ -> [| _f 0.; _f 0. |]) _gs);
        Checkpoint.(state.loss.(0) <- primal' loss);
        state
    in
    (* try to iterate all batches *)
    while Checkpoint.(state.stop = false) do
      let loss', ws, gs' = iterate Checkpoint.(state.current_batch) in
      (* check if the stopping criterion is met *)
      Checkpoint.(state.stop <- stop_fun (unpack_flt loss'));
      (* checkpoint of the optimisation if necessary *)
      chkp_fun save Checkpoint.(state.current_batch) loss' state;
      (* print out the current state of optimisation *)
      if params.verbosity = true then Checkpoint.print_state_info state;
      (* clip the gradient if necessary *)
      let gs' = Owl_utils.aarr_map clip_fun gs' in
      (* calculate gradient descent *)
      let ps' =
        Checkpoint.(Owl_utils.aarr_map4 (grad_fun (fun a -> a)) ws state.gs state.ps gs')
      in
      (* update gcache if necessary *)
      Checkpoint.(state.ch <- Owl_utils.aarr_map2 upch_fun gs' state.ch);
      (* adjust direction based on learning_rate *)
      let us' =
        Checkpoint.(
          Owl_utils.aarr_map3
            (fun p' g' c -> Maths.(p' * rate_fun state.current_batch g' c))
            ps'
            gs'
            state.ch)
      in
      (* adjust direction based on momentum *)
      let us' = Owl_utils.aarr_map2 momt_fun Checkpoint.(state.us) us' in
      (* update the weight *)
      let ws' = Owl_utils.aarr_map2 (fun w u -> Maths.(w + u)) ws us' in
      update ws';
      (* save historical data *)
      (if params.momentum <> Momentum.None then Checkpoint.(state.us <- us'));
      Checkpoint.(state.gs <- gs');
      Checkpoint.(state.ps <- ps');
      Checkpoint.(state.current_batch <- state.current_batch + 1);
      (* force GC to release bigarray memory *)
      Gc.minor ()
    done;
    (* print optimisation summary *)
    if params.verbosity = true && Checkpoint.(state.current_batch >= state.batches)
    then Checkpoint.print_summary state;
    (* return the current state *)
    state
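
  (* Illustrative note (informal, not code from this file): the callbacks passed to
     [minimise_network] are expected to behave roughly as follows, matching how they
     are used above.

       forward  : t -> t * t array array             - prediction and current weights
       backward : t -> t array array * t array array - weights and their gradients
       update   : t array array -> unit              - write adjusted weights back
       save     : string -> unit                     - called by the checkpoint logic
  *)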
  (* This function minimises [f : x -> y] w.r.t. [x].
     [x] is an ndarray; [y] is a scalar value. *)
  let minimise_fun ?state params f x =
    let open Params in
    if params.verbosity = true && state = None
    then print_endline (Params.to_string params);
    (* make alias functions *)
    let grad_fun = Gradient.run params.gradient in
    let rate_fun = Learning_Rate.run params.learning_rate in
    let regl_fun = Regularisation.run params.regularisation in
    let momt_fun = Momentum.run params.momentum in
    let upch_fun = Learning_Rate.update_ch params.learning_rate in
    let clip_fun = Clipping.run params.clipping in
    let stop_fun = Stopping.run params.stopping in
    let chkp_fun = Checkpoint.run params.checkpoint in
    (* make the function to minimise *)
    let optz_fun xi = Maths.(f xi + regl_fun xi) in
    (* operations in the ith iteration *)
    let iterate _ xi =
      let loss, g = grad' optz_fun xi in
      loss |> primal', g, optz_fun
    in
    (* init new or continue previous state of optimisation process *)
    let state =
      match state with
      | Some state -> state
      | None ->
        let state = Checkpoint.init_state 1 params.epochs in
        (* first iteration to bootstrap the optimisation *)
        let loss, _g0, _ = iterate 0 x in
        (* variables used for specific gradient method *)
        Checkpoint.(state.gs <- [| [| _g0 |] |]);
        Checkpoint.(state.ps <- [| [| Maths.(neg _g0) |] |]);
        Checkpoint.(state.us <- [| [| _f 0. |] |]);
        Checkpoint.(state.ch <- [| [| [| _f 0.; _f 0. |] |] |]);
        Checkpoint.(state.loss.(0) <- primal' loss);
        state
    in
    (* try to iterate all batches *)
    let x = ref x in
    while Checkpoint.(state.stop = false) do
      let loss', g', optz' = iterate Checkpoint.(state.current_batch) !x in
      (* check if the stopping criterion is met *)
      Checkpoint.(state.stop <- stop_fun (unpack_flt loss'));
      (* checkpoint of the optimisation if necessary *)
      chkp_fun (fun _ -> ()) Checkpoint.(state.current_batch) loss' state;
      (* print out the current state of optimisation *)
      if params.verbosity = true then Checkpoint.print_state_info state;
      (* clip the gradient if necessary *)
      let g' = clip_fun g' in
      (* calculate gradient descent *)
      let p' = Checkpoint.(grad_fun optz' !x state.gs.(0).(0) state.ps.(0).(0) g') in
      (* update gcache if necessary *)
      Checkpoint.(state.ch.(0).(0) <- upch_fun g' state.ch.(0).(0));
      (* adjust direction based on learning_rate *)
      let u' = Checkpoint.(Maths.(p' * rate_fun state.current_batch g' state.ch.(0).(0))) in
      (* adjust direction based on momentum *)
      let u' = momt_fun Checkpoint.(state.us.(0).(0)) u' in
      (* update the weight *)
      x := Maths.(!x + u') |> primal';
      (* save historical data *)
      (if params.momentum <> Momentum.None then Checkpoint.(state.us.(0).(0) <- u'));
      Checkpoint.(state.gs.(0).(0) <- g');
      Checkpoint.(state.ps.(0).(0) <- p');
      Checkpoint.(state.current_batch <- state.current_batch + 1);
      (* force GC to release bigarray memory *)
      Gc.minor ()
    done;
    (* print optimisation summary *)
    if params.verbosity = true && Checkpoint.(state.current_batch >= state.batches)
    then Checkpoint.print_summary state;
    (* return both loss history and weight *)
    state, !x
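
  (* Illustrative sketch (not part of this file's code): minimising a simple convex
     function with [minimise_fun]. A scalar starting point is used here for brevity;
     the values are arbitrary examples.

       let params = Params.config ~learning_rate:(Learning_Rate.Adagrad 0.1) 100. in
       let f x = Maths.(l2norm_sqr' (x - _f 3.)) in
       let _state, x_min = minimise_fun params f (_f 0.) in
       ...
  *)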
  (* This function minimises a deeply compiled neural network. *)
  let minimise_compiled_network ?state params eval update save x y =
    let open Params in
    if params.verbosity = true && state = None
    then print_endline (Params.to_string params);
    (* make alias functions *)
    let bach_fun = Batch.run params.batch in
    let stop_fun = Stopping.run params.stopping in
    let chkp_fun = Checkpoint.run params.checkpoint in
    (* operations in the ith iteration *)
    let iterate i =
      let xt, yt = bach_fun x y i in
      let loss = eval xt yt in
      loss
    in
    (* init new or continue previous state of optimisation process *)
    let state =
      match state with
      | Some state -> state
      | None ->
        let batches_per_epoch = Batch.batches params.batch x in
        let state = Checkpoint.init_state batches_per_epoch params.epochs in
        (* first iteration to bootstrap the optimisation *)
        let loss = iterate 0 in
        update ();
        (* variables used for specific gradient method *)
        Checkpoint.(state.loss.(0) <- primal' loss);
        state
    in
    (* try to iterate all batches *)
    while Checkpoint.(state.stop = false) do
      let loss' = iterate Checkpoint.(state.current_batch) in
      (* check if the stopping criterion is met *)
      Checkpoint.(state.stop <- stop_fun (unpack_flt loss'));
      (* checkpoint of the optimisation if necessary *)
      chkp_fun save Checkpoint.(state.current_batch) (loss' |> unpack_flt |> pack_flt) state;
      (* print out the current state of optimisation *)
      if params.verbosity = true then Checkpoint.print_state_info state;
      update ();
      (* save historical data *)
      Checkpoint.(state.current_batch <- state.current_batch + 1)
      (* force GC to release bigarray memory *)
      (* FIXME Gc.minor (); *)
    done;
    (* print optimisation summary *)
    if params.verbosity = true && Checkpoint.(state.current_batch >= state.batches)
    then Checkpoint.print_summary state;
    (* return the current state *)
    state
end

(* Make functor ends *)