(* File: fitc_gp.ml
OCaml-GPR - Gaussian Processes for OCaml
Copyright (C) 2009- Markus Mottl
email: markus.mottl@gmail.com
WWW: http://www.ocaml.info
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License
along with this library; if not, write to the Free Software Foundation,
Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*)

open Utils
open Interfaces
open Core
open Bigarray
open Lacaml.D

(* Type of functors that map an evaluation specification to an evaluation
   module built on that same specification *)
module type Sig = functor (Spec : Specs.Eval) ->
  Sigs.Eval with module Spec = Spec

(* Computations shared by FIC and FITC, and standard and variational
   version *)

module Make_common (Spec : Specs.Eval) = struct
  module Spec = Spec

  open Spec

  (* Constant added to matrix diagonals before they are factorized with
     [potrf]; taken from [cholesky_jitter] *)
  let jitter = !cholesky_jitter

  (* Evaluation of inducing points *)

  module Inducing = struct
    (* [km] is the matrix computed by [Spec.Inducing.calc_upper], [chol_km]
       its jittered [potrf] factor and [log_det_km] the value of [log_det]
       on that factor *)
    type t = {
      kernel : Kernel.t;
      points : Spec.Inducing.t;
      km : mat;
      chol_km : mat;
      log_det_km : float;
    }

    (* Checks that the requested number of inducing points is usable with
       the given inputs.
       @raise Failure unless [1 <= n_inducing <= n_inputs] *)
    let check_n_inducing ~n_inducing inputs =
      let n_inputs = Spec.Inputs.get_n_points inputs in
      (* The lower bound has to be tested on [n_inducing], exactly as the
         error message states.  Empty inputs are still rejected, because
         any [n_inducing >= 1] then exceeds [n_inputs]. *)
      if n_inducing < 1 || n_inducing > n_inputs then
        failwithf
          "Gpr.Fitc_gp.Make_common.check_n_inducing: \
          violating 1 <= n_inducing (%d) <= n_inputs (%d)"
          n_inducing n_inputs ()

    (* Builds the record from an already computed [km].  [km] itself is
       left untouched: its upper triangle is copied, [jitter] is added to
       the diagonal of the copy, and the copy is factorized in place. *)
    let calc_internal kernel points km =
      let chol_km = lacpy ~uplo:`U km in
      Mat.add_const_diag jitter chol_km;
      potrf chol_km;
      { kernel; points; km; chol_km; log_det_km = log_det chol_km }

    (* Evaluates the kernel on the inducing points and factorizes the
       result *)
    let calc kernel points =
      calc_internal kernel points (Spec.Inducing.calc_upper kernel points)

    (* Creates inducing points from the inputs selected by [indexes] *)
    let choose kernel inputs indexes =
      let chosen_inputs = Spec.Inputs.choose_subset inputs indexes in
      Spec.Inputs.create_inducing kernel chosen_inputs

    (* Uses the inputs with (one-based) indexes [1 .. n_inducing].
       @raise Failure if [check_n_inducing] rejects [n_inducing] *)
    let choose_n_first_inputs kernel inputs ~n_inducing =
      check_n_inducing ~n_inducing inputs;
      let indexes = Int_vec.create n_inducing in
      for i = 1 to n_inducing do indexes.{i} <- i done;
      choose kernel inputs indexes

    (* Uses [n_inducing] distinct inputs drawn uniformly at random with
       [rnd_state].
       @raise Failure if [check_n_inducing] rejects [n_inducing] *)
    let choose_n_random_inputs
          ?(rnd_state = Random.State.default) kernel inputs ~n_inducing =
      check_n_inducing ~n_inducing inputs;
      let n_inputs = Spec.Inputs.get_n_points inputs in
      let indexes = Int_vec.create n_inputs in
      for i = 1 to n_inputs do indexes.{i} <- i done;
      for i = 1 to n_inducing do
        (* Partial Fisher-Yates shuffle: the swap partner for position [i]
           must come from the not yet fixed positions [i .. n_inputs].
           Drawing from [1 .. n_inputs - i + 1] instead still yields
           distinct indexes, but makes some subsets unreachable. *)
        let rnd_index = i + Random.State.int rnd_state (n_inputs - i + 1) in
        let tmp = indexes.{rnd_index} in
        indexes.{rnd_index} <- indexes.{i};
        indexes.{i} <- tmp
      done;
      let indexes = Array1.sub indexes 1 n_inducing in
      choose kernel inputs indexes

    let get_kernel inducing = inducing.kernel
    let get_points inducing = inducing.points
  end

  (* Evaluation of one input point *)

  module Input = struct
    (* [k_m] is the result of [Spec.Input.eval] for [point] against the
       inducing points *)
    type t = { inducing : Inducing.t; point : Spec.Input.t; k_m : vec }

    let calc inducing point =
      let { Inducing.kernel; points = inducing_points } = inducing in
      { inducing; point; k_m = Spec.Input.eval kernel point inducing_points }
  end

  (* Evaluation of input points
*)

  module Inputs = struct
    (* [knm] is the cross matrix computed by [Inputs.calc_cross] between
       the input points and the inducing points.  Inside this module the
       unqualified name [Inputs] still denotes the specification's module. *)
    type t = { inducing : Inducing.t; points : Inputs.t; knm : mat }

    (* Wraps an already computed cross matrix *)
    let calc_internal points inducing knm = { inducing; points; knm }

    (* Evaluates the kernel between [points] and the inducing points *)
    let calc points inducing =
      let knm =
        Inputs.calc_cross inducing.Inducing.kernel
          ~inputs:points ~inducing:inducing.Inducing.points
      in
      calc_internal points inducing knm

    let get_kernel t = t.inducing.Inducing.kernel

    let calc_diag inputs =
      let kernel = get_kernel inputs in
      Inputs.calc_diag kernel inputs.points

    let calc_upper inputs =
      let kernel = get_kernel inputs in
      Inputs.calc_upper kernel inputs.points

    let create_default_kernel inputs ~n_inducing =
      let params =
        Spec.Inputs.create_default_kernel_params inputs ~n_inducing
      in
      Kernel.create params

    let get_chol_km t = t.inducing.Inducing.chol_km
    let get_log_det_km t = t.inducing.Inducing.log_det_km
    let get_knm t = t.knm
    let get_points t = t.points
  end

  (* Model computations shared by standard and variational version *)

  module Common_model = struct
    (* [is_vec] holds the element-wise values [1 / (r_vec + sigma2)] and
       [sqrt_is_vec] their square roots.  [q_mat] and [r_mat] come from the
       QR factorization performed in [calc_internal].  [l1] is the part of
       the log evidence that is computed without any targets. *)
    type t = {
      sigma2 : float;
      inputs : Inputs.t;
      kn_diag : vec;
      v_mat : mat;
      r_vec : vec;
      is_vec : vec;
      sqrt_is_vec : vec;
      q_mat : mat;
      r_mat : mat;
      l1 : float;
    }

    type co_variance_coeffs = mat * mat

    (* @raise Failure if [sigma2] is negative *)
    let check_sigma2 sigma2 =
      if sigma2 < 0. then failwith "Model.check_sigma2: sigma2 < 0"

    (* Computes everything that depends on [sigma2] from the
       [sigma2]-independent quantities [kn_diag], [v_mat] and [r_vec] *)
    let calc_internal inputs sigma2 ~kn_diag ~v_mat ~r_vec =
      check_sigma2 sigma2;
      let n = Mat.dim1 v_mat in
      let m = Mat.dim2 v_mat in
      let n1 = n + 1 in
      let nm = n + m in
      let is_vec = Vec.create n in
      (* Fills [is_vec] and sums [log (r_vec.{i} + sigma2)], going from the
         last element to the first *)
      let log_det_s_vec =
        let acc = ref 0. in
        for i = n downto 1 do
          let s_i = r_vec.{i} +. sigma2 in
          is_vec.{i} <- 1. /. s_i;
          acc := !acc +. log s_i
        done;
        !acc
      in
      let sqrt_is_vec = Vec.sqrt is_vec in
      (* [q_mat] stacks the row-scaled [knm] (rows 1 .. n) on top of the
         upper triangle of [chol_km] (rows n + 1 .. n + m); the strictly
         lower part of the bottom block is cleared first because [lacpy]
         only writes its upper triangle *)
      let q_mat = Mat.create nm m in
      for c = 1 to m do
        for r = n1 + c to nm do q_mat.{r, c} <- 0. done
      done;
      ignore (lacpy (Inputs.get_knm inputs) ~b:q_mat);
      Mat.scal_rows ~m:n sqrt_is_vec q_mat;
      ignore (lacpy ~uplo:`U (Inputs.get_chol_km inputs) ~br:n1 ~b:q_mat);
      (* QR factorization: the triangular factor is copied out before
         [orgqr] overwrites [q_mat] with the orthogonal factor *)
      let tau = geqrf q_mat in
      let r_mat = lacpy ~m ~n:m ~uplo:`U q_mat in
      orgqr ~tau q_mat;
      (* Twice the sum of the logs of the diagonal of [r_mat], visiting the
         diagonal from the last element to the first *)
      let log_det_r =
        let half = ref 0. in
        for r = m downto 1 do
          let diag_el = r_mat.{r, r} in
          let el =
            if diag_el > 0. then diag_el
            else begin
              (* Cannot happen with LAPACK version 3.2 and greater *)
              (* Flip the sign of row [r] of [r_mat] and of the first [n]
                 entries of column [r] of [q_mat] *)
              let neg_el = -. diag_el in
              r_mat.{r, r} <- neg_el;
              for c = r + 1 to m do r_mat.{r, c} <- -. r_mat.{r, c} done;
              Mat.scal ~m:n ~n:1 ~ac:r ~-.1. q_mat;
              neg_el
            end
          in
          half := !half +. log el
        done;
        !half +. !half
      in
      let l1 =
        let log_det_km = Inputs.get_log_det_km inputs in
        -0.5 *.
          (log_det_r -. log_det_km +. log_det_s_vec +. float n *. log_2pi)
      in
      {
        inputs; sigma2; kn_diag; v_mat; r_vec; is_vec; sqrt_is_vec;
        q_mat; r_mat; l1;
      }

    (* [kn_diag] minus the result of [Mat.syrk_diag] on [v_mat]; [kn_diag]
       is copied, not modified *)
    let calc_r_vec ~kn_diag ~v_mat =
      let y = copy kn_diag in
      Mat.syrk_diag ~alpha:~-.1. v_mat ~beta:1. ~y

    (* Like [calc], for callers that already have the kernel diagonal *)
    let calc_with_kn_diag inputs sigma2 kn_diag =
      (* [v_mat] is a copy of [knm] solved in place against [chol_km] *)
      let v_mat = lacpy inputs.Inputs.knm in
      trsm ~side:`R (Inputs.get_chol_km inputs) v_mat;
      let r_vec = calc_r_vec ~kn_diag ~v_mat in
      calc_internal inputs sigma2 ~kn_diag ~v_mat ~r_vec

    let calc inputs ~sigma2 =
      let kn_diag = Inputs.calc_diag inputs in
      calc_with_kn_diag inputs sigma2 kn_diag

    (* Recomputes a model for a new [sigma2], reusing the [sigma2]-free
       parts of the given one *)
    let update_sigma2 { inputs; kn_diag; v_mat; r_vec; _ } sigma2 =
      check_sigma2 sigma2;
      calc_internal inputs sigma2 ~kn_diag ~v_mat ~r_vec

    let calc_log_evidence model = model.l1

    let get_v_mat model = model.v_mat
    let get_inducing model = model.inputs.Inputs.inducing
    let get_inducing_points model = Inducing.get_points (get_inducing model)
    let get_inputs model = model.inputs
    let get_input_points model = model.inputs.Inputs.points
    let get_kernel model = (get_inducing model).Inducing.kernel
    let get_sigma2 model = model.sigma2
    let get_chol_km model = Inputs.get_chol_km model.inputs
    let get_r_vec model = model.r_vec
    let get_is_vec model = model.is_vec
    let get_sqrt_is_vec model = model.sqrt_is_vec
    let get_km model = (get_inducing model).Inducing.km
    let get_knm model = model.inputs.Inputs.knm
    let get_kn_diag model = model.kn_diag
    let get_q_mat model = model.q_mat
    let get_r_mat model = model.r_mat

    let calc_co_variance_coeffs model = get_chol_km model, model.r_mat
  end

  (* Model computation (variational version) *)

  module Variational_model = struct
    include Common_model

    (* Adds [-0.5 * dot is_vec r_vec] to the log evidence of a common
       model *)
    let from_common ({ r_vec; is_vec; l1; _ } as model) =
      let correction = -0.5 *. dot is_vec r_vec in
      { model with l1 = l1 +. correction }

    let calc_with_kn_diag inputs sigma2 kn_diag =
      from_common (calc_with_kn_diag inputs sigma2 kn_diag)

    let calc inputs ~sigma2 = from_common (calc inputs ~sigma2)

    let update_sigma2 model sigma2 =
      from_common (update_sigma2 model sigma2)
  end

  (* Trained models
*)

  module Trained = struct
    (* [l] is the log evidence of the model including the part that
       depends on the targets [y] *)
    type t = {
      model : Common_model.t;
      y : vec;
      coeffs : vec;
      l : float;
    }

    (* Assembles a trained model from precomputed coefficients and the
       target-dependent evidence term [l2] *)
    let calc_internal model ~y ~coeffs ~l2 =
      { model; y; coeffs; l = model.Common_model.l1 +. l2 }

    (* Returns the targets scaled element-wise by [sqrt_is_vec], together
       with the transposed first [n] rows of [q_mat] applied to them.
       @raise Failure if the number of targets differs from the length of
       [sqrt_is_vec] *)
    let prepare_internal model ~y =
      let { Common_model.sqrt_is_vec; q_mat; _ } = model in
      let n = Vec.dim sqrt_is_vec in
      let n_y = Vec.dim y in
      if n_y <> n then
        failwithf "Trained.calc: Vec.dim targets (%d) <> n (%d)" n_y n ();
      let y_ = Vec.mul y sqrt_is_vec in
      let qt_y_ = gemv ~m:n ~trans:`T q_mat y_ in
      y_, qt_y_

    let calc model ~targets:y =
      let y_, qt_y_ = prepare_internal model ~y in
      let l2 = -0.5 *. (Vec.sqr_nrm2 y_ -. Vec.sqr_nrm2 qt_y_) in
      (* [qt_y_] is solved in place against [r_mat] and then serves as the
         coefficient vector *)
      trsv model.Common_model.r_mat qt_y_;
      calc_internal model ~y ~coeffs:qt_y_ ~l2

    let calc_mean_coeffs trained = trained.coeffs
    let calc_log_evidence trained = trained.l

    let calc_means trained =
      let knm = Common_model.get_knm trained.model in
      gemv knm trained.coeffs

    let get_inducing trained = Common_model.get_inducing trained.model
    let get_targets (trained : t) = trained.y
    let get_model trained = trained.model
  end

  module Stats = struct
    type t = {
      n_samples : int;
      target_variance : float;
      sse : float;
      mse : float;
      rmse : float;
      smse : float;
      msll : float;
      mad : float;
      maxad : float;
    }

    let calc_n_samples { Trained.y; _ } = Vec.dim y

    (* Mean of the squared targets; no mean is subtracted here *)
    let calc_target_variance { Trained.y; _ } =
      let n = Vec.dim y in
      Vec.sqr_nrm2 y /. float n

    (* Sum of squared differences between targets and predicted means *)
    let calc_sse trained =
      Vec.ssqr_diff trained.Trained.y (Trained.calc_means trained)

    let calc_mse trained =
      calc_sse trained /. float (calc_n_samples trained)

    let calc_rmse trained = sqrt (calc_mse trained)

    let calc_smse trained =
      calc_mse trained /. calc_target_variance trained

    let calc_prior_l target_variance =
      -0.5 *. log (2. *. pi *. target_variance) -. 0.5

    let calc_msll trained =
      let prior_l = calc_prior_l (calc_target_variance trained) in
      prior_l -. trained.Trained.l /. float (calc_n_samples trained)

    (* Mean absolute difference between targets and means; the sum runs
       from the last sample to the first *)
    let calc_mad ({ Trained.y; _ } as trained) =
      let n_samples = Vec.dim y in
      let means = Trained.calc_means trained in
      let madsum = ref 0. in
      for i = n_samples downto 1 do
        madsum := !madsum +. Float.abs (y.{i} -. means.{i})
      done;
      !madsum /. float n_samples

    (* Largest absolute difference between targets and means, starting
       from [0.] *)
    let calc_maxad ({ Trained.y; _ } as trained) =
      let means = Trained.calc_means trained in
      let maxad = ref 0. in
      for i = Vec.dim y downto 1 do
        maxad := max !maxad (Float.abs (y.{i} -. means.{i}))
      done;
      !maxad

    (* Computes all statistics at once, evaluating the means only once *)
    let calc ({ Trained.y; l; _ } as trained) =
      let n_samples = Vec.dim y in
      let f_samples = float n_samples in
      let target_variance = calc_target_variance trained in
      let means = Trained.calc_means trained in
      let sse = Vec.ssqr_diff y means in
      let mse = sse /. f_samples in
      let rmse = sqrt mse in
      let smse = mse /. target_variance in
      let msll = calc_prior_l target_variance -. l /. f_samples in
      let madsum = ref 0. in
      let maxad = ref 0. in
      for i = n_samples downto 1 do
        let ad = Float.abs (y.{i} -. means.{i}) in
        madsum := !madsum +. ad;
        maxad := max !maxad ad
      done;
      let mad = !madsum /. f_samples in
      let maxad = !maxad in
      { n_samples; target_variance; sse; mse; rmse; smse; msll; mad; maxad }
  end

  module Mean_predictor = struct
    type t = { inducing : Spec.Inducing.t; coeffs : vec }

    let calc_trained trained =
      let inducing = Inducing.get_points (Trained.get_inducing trained) in
      { inducing; coeffs = Trained.calc_mean_coeffs trained }

    (* @raise Failure if there is not exactly one coefficient per inducing
       point *)
    let calc inducing ~coeffs =
      if Spec.Inducing.get_n_points inducing = Vec.dim coeffs then
        { inducing; coeffs }
      else
        failwith
          "Mean_predictor.calc: number of inducing points disagrees with \
          dimension of coefficients"

    let get_inducing t = t.inducing
    let get_coeffs t = t.coeffs
  end

  (* Prediction of mean for one input point *)

  module Mean = struct
    type t = { point : Spec.Input.t; value : float }

    (* @raise Failure if predictor and input do not share the physically
       same inducing points *)
    let calc mean_predictor { Input.inducing = input_inducing; k_m; point } =
      let { Mean_predictor.inducing; coeffs } = mean_predictor in
      if phys_equal inducing (Inducing.get_points input_inducing) then
        { point; value = dot k_m coeffs }
      else
        failwith
          "Mean.calc: mean predictor and input disagree about inducing points"

    let get mean = mean.value
  end

  (* Prediction of means for several input points *)

  module Means = struct
    type t = { points : Spec.Inputs.t; values : vec }

    (* @raise Failure if predictor and inputs do not share the physically
       same inducing points *)
    let calc mean_predictor { Inputs.points; inducing; knm } =
      let expected = mean_predictor.Mean_predictor.inducing in
      if phys_equal expected (Inducing.get_points inducing) then
        { points; values = gemv knm mean_predictor.Mean_predictor.coeffs }
      else
        failwith
          "Means.calc: trained and inputs disagree about inducing points"

    let get means = means.values
  end

  module Co_variance_predictor = struct
    type t = {
      kernel : Spec.Kernel.t;
      inducing : Spec.Inducing.t;
      chol_km : mat;
      r_mat : mat;
    }

    let calc_model model =
      let inducing = Inducing.get_points (Common_model.get_inducing model) in
      {
        kernel = Common_model.get_kernel model;
        inducing;
        chol_km = Common_model.get_chol_km model;
        r_mat = Common_model.get_r_mat model;
      }

    let calc kernel inducing (chol_km, r_mat) =
      { kernel; inducing; chol_km; r_mat }
  end

  (* Prediction of variance for one input point *)

  module Variance = struct
    type t = { point : Spec.Input.t; variance : float; sigma2 : float }

    (* @raise Failure if predictor and input do not share the physically
       same inducing points *)
    let calc co_variance_predictor ~sigma2 { Input.inducing; point; k_m } =
      let { Co_variance_predictor.kernel; chol_km; r_mat; inducing = expected } =
        co_variance_predictor
      in
      if not (phys_equal expected (Inducing.get_points inducing)) then
        failwith
          "Variance.calc: \
          co-variance predictor and input disagree about inducing points"
      else begin
        (* One work vector is used for both triangular solves *)
        let tmp = copy k_m in
        trsv ~trans:`T chol_km tmp;
        let k = Vec.sqr_nrm2 tmp in
        let tmp = copy k_m ~y:tmp in
        trsv ~trans:`T r_mat tmp;
        let b = Vec.sqr_nrm2 tmp in
        let prior_variance = Spec.Input.eval_one kernel point in
        { point; variance = prior_variance -. (k -. b); sigma2 }
      end

    (* [sigma2] is added unless [~predictive:false] is passed *)
    let get ?predictive t =
      match predictive with
      | Some false -> t.variance
      | None | Some true -> t.variance +. t.sigma2
  end

  (* Prediction of variance for several input points *)

  module Variances = struct
    type t = { points : Spec.Inputs.t; variances : vec; sigma2 : float }

    (* Variances at the input points the model was computed with *)
    let calc_model_inputs model =
      let tmp = lacpy (Common_model.get_knm model) in
      trsm ~side:`R model.Common_model.r_mat tmp;
      let variances =
        let y = copy (Common_model.get_r_vec model) in
        Mat.syrk_diag tmp ~beta:1. ~y
      in
      {
        points = Common_model.get_input_points model;
        variances;
        sigma2 = Common_model.get_sigma2 model;
      }

    (* @raise Failure if predictor and inputs do not share the physically
       same inducing points *)
    let calc cvp ~sigma2 inputs =
      let expected = cvp.Co_variance_predictor.inducing in
      if not (phys_equal expected (Inducing.get_points inputs.Inputs.inducing))
      then
        failwith
          "Variances.calc: \
          co-variance predictor and inputs disagree about inducing points"
      else begin
        let { Inputs.points; knm = ktm; _ } = inputs in
        (* [acc] starts as the kernel diagonal and is updated in place by
           both [syrk_diag] calls; [tmp] is the work matrix for both
           solves *)
        let acc = Inputs.calc_diag inputs in
        let tmp = lacpy ktm in
        trsm ~side:`R cvp.Co_variance_predictor.chol_km tmp;
        let acc = Mat.syrk_diag ~alpha:~-.1. tmp ~beta:1. ~y:acc in
        let tmp = lacpy ktm ~b:tmp in
        trsm ~side:`R cvp.Co_variance_predictor.r_mat tmp;
        let variances = Mat.syrk_diag tmp ~beta:1. ~y:acc in
        { points; variances; sigma2 }
      end

    (* Returns a fresh vector with [sigma2] added to every variance unless
       [~predictive:false] is passed, in which case [variances] itself is
       returned *)
    let get_common ?predictive ~variances ~sigma2 =
      match predictive with
      | Some false -> variances
      | None | Some true ->
          let predictive_variances = Vec.make (Vec.dim variances) sigma2 in
          axpy variances predictive_variances;
          predictive_variances

    let get ?predictive { variances; sigma2; _ } =
      get_common ?predictive ~variances ~sigma2
  end

  (* Computations for predicting covariances shared by FIC and
FITC, and standard and variational version *)

  module Common_covariances = struct
    type t = { points : Spec.Inputs.t; covariances : mat; sigma2 : float }

    (* @raise Failure if predictor and inputs do not share the physically
       same inducing points; [loc] prefixes the error message *)
    let check_inducing ~loc co_variance_predictor inputs =
      if
        not
          (phys_equal
             co_variance_predictor.Co_variance_predictor.inducing
             (Inducing.get_points inputs.Inputs.inducing))
      then
        failwithf
          "%s_covariances.calc: \
          co-variance predictor and inputs disagree about inducing points"
          loc ()

    (* Returns a copy of the upper triangle with [sigma2] added to the
       diagonal unless [~predictive:false] is passed, in which case
       [covariances] itself is returned *)
    let get_common ?predictive ~covariances ~sigma2 =
      match predictive with
      | None | Some true ->
          let res = lacpy ~uplo:`U covariances in
          for i = 1 to Mat.dim1 res do
            res.{i, i} <- res.{i, i} +. sigma2
          done;
          res
      | Some false -> covariances

    let get ?predictive { covariances; sigma2; _ } =
      get_common ?predictive ~covariances ~sigma2

    (* Variances are the diagonal of the covariance matrix *)
    let get_variances { points; covariances; sigma2 } =
      { Variances.points; variances = Mat.copy_diag covariances; sigma2 }
  end

  (* Predicting covariances with FITC (standard or variational) *)

  module FITC_covariances = struct
    include Common_covariances

    (* Covariances at the input points the model was computed with.  Both
       rank-k updates accumulate into [covariances] ([~beta:1.]). *)
    let calc_model_inputs model =
      let covariances = Inputs.calc_upper model.Common_model.inputs in
      let v_mat = model.Common_model.v_mat in
      ignore (syrk ~alpha:~-.1. v_mat ~beta:1. ~c:covariances);
      let q_mat = model.Common_model.q_mat in
      (* Only the first [n] rows of [q_mat] take part in the update *)
      let n = Mat.dim1 v_mat in
      ignore (syrk ~n q_mat ~beta:1. ~c:covariances);
      let points = Common_model.get_input_points model in
      let sigma2 = Common_model.get_sigma2 model in
      { points; covariances; sigma2 }

    (* Covariances at arbitrary inputs.
       @raise Failure if predictor and inputs do not share the physically
       same inducing points *)
    let calc co_variance_predictor ~sigma2 inputs =
      check_inducing ~loc:"FITC" co_variance_predictor inputs;
      let { Co_variance_predictor.chol_km; r_mat; _ } =
        co_variance_predictor
      in
      let covariances = Inputs.calc_upper inputs in
      let { Inputs.points; knm = ktm; _ } = inputs in
      let covariances =
        let tmp = lacpy ktm in
        trsm ~side:`R chol_km tmp;
        (* [~beta:1.] is required on both updates: [syrk] defaults to
           [beta = 0.], which would overwrite the prior covariances (and
           then the first correction) instead of accumulating, unlike
           [calc_model_inputs] and [Variances.calc] *)
        ignore (syrk ~alpha:~-.1. tmp ~beta:1. ~c:covariances);
        let tmp = lacpy ktm ~b:tmp in
        trsm ~side:`R r_mat tmp;
        syrk tmp ~beta:1. ~c:covariances
      in
      { points; covariances; sigma2 }
  end

  (* Predicting covariances with FIC (standard or variational)
*)

  module FIC_covariances = struct
    include Common_covariances

    (* Result of [syrk] on the first [Vec.dim r_vec] rows of [q_mat], with
       [r_vec] added to its diagonal *)
    let calc_common ~points ~sigma2 ~q_mat ~r_vec =
      let n = Vec.dim r_vec in
      let covariances = syrk ~n q_mat in
      for i = 1 to n do
        covariances.{i, i} <- covariances.{i, i} +. r_vec.{i}
      done;
      { points; covariances; sigma2 }

    (* Covariances at the input points the model was computed with *)
    let calc_model_inputs model =
      let { Common_model.r_vec; q_mat; _ } = model in
      calc_common
        ~points:(Common_model.get_input_points model)
        ~sigma2:(Common_model.get_sigma2 model)
        ~q_mat ~r_vec

    (* Covariances at arbitrary inputs.
       @raise Failure if predictor and inputs do not share the physically
       same inducing points *)
    let calc co_variance_predictor ~sigma2 ({ Inputs.knm = ktm; _ } as inputs) =
      check_inducing ~loc:"FIC" co_variance_predictor inputs;
      (* The freshly computed diagonal is updated in place and becomes
         [r_vec].
         NOTE(review): [r_vec] is formed from [ktm] directly, whereas
         [Variances.calc] first solves against [chol_km] -- confirm that
         this difference is intended. *)
      let kt_diag = Inputs.calc_diag inputs in
      let r_vec = Mat.syrk_diag ~alpha:~-.1. ktm ~beta:1. ~y:kt_diag in
      let q_mat = lacpy ktm in
      trsm ~side:`R co_variance_predictor.Co_variance_predictor.r_mat q_mat;
      let points = Inputs.get_points inputs in
      calc_common ~points ~sigma2 ~q_mat ~r_vec
  end

  (* Computations for sampling the marginal posterior GP distribution
shared by standard and variational version *)

  module Common_sampler = struct
    type t = { mean : float; stddev : float }

    (* Builds a sampler for one input point.  [sigma2] is included in the
       variance unless [~predictive:false] is passed.
       @raise Failure if mean and variance do not refer to the physically
       same input point; [loc] prefixes the error message *)
    let calc ~loc ?predictive mean variance =
      if not (phys_equal mean.Mean.point variance.Variance.point) then
        failwith
          (loc ^ ".Sampler: mean and variance disagree about input point");
      (* [Variance.get] makes the same choice on [predictive] *)
      let used_variance = Variance.get ?predictive variance in
      { mean = mean.Mean.value; stddev = sqrt used_variance }

    (* One draw: the mean plus Gaussian noise of the sampler's standard
       deviation *)
    let sample ?(rng = default_rng) sampler =
      sampler.mean +. Gsl.Randist.gaussian_ziggurat rng ~sigma:sampler.stddev

    (* [n] draws collected in a vector *)
    let samples ?(rng = default_rng) sampler ~n =
      let draw _ = sample ~rng sampler in
      Vec.init n draw
  end

  (* Computations for sampling the posterior GP distribution shared
by FIC and FITC, and standard and variational version *)

  module Common_cov_sampler = struct
    (* [cov_chol] is the [potrf] factor of the (jittered) covariances *)
    type t = { means : vec; cov_chol : mat }

    (* Builds a joint sampler.  [sigma2] is added to the diagonal unless
       [~predictive:false] is passed; [jitter] is always added before the
       factorization.
       @raise Failure if means and covariances do not refer to the
       physically same input points; [loc] prefixes the error message *)
    let calc ~loc ?predictive means covariances =
      let { Common_covariances.points; covariances = cov_mat; sigma2 } =
        covariances
      in
      if not (phys_equal means.Means.points points) then
        failwith (
          loc ^
          ".Cov_sampler: means and covariances disagree about input points");
      (* Work on a copy of the upper triangle *)
      let cov_chol = lacpy ~uplo:`U cov_mat in
      begin match predictive with
      | Some false -> ()
      | None | Some true ->
          for i = 1 to Mat.dim1 cov_chol do
            cov_chol.{i, i} <- cov_chol.{i, i} +. sigma2
          done
      end;
      Mat.add_const_diag jitter cov_chol;
      potrf cov_chol;
      { means = means.Means.values; cov_chol }

    (* One joint draw: standard normal deviates multiplied by the
       transposed factor, shifted by the means *)
    let sample ?(rng = default_rng) samplers =
      let { means; cov_chol } = samplers in
      let draw _ = Gsl.Randist.gaussian_ziggurat rng ~sigma:1. in
      let sample = Vec.init (Vec.dim means) draw in
      trmv ~trans:`T cov_chol sample;
      axpy means sample;
      sample

    (* [n] joint draws, one per column *)
    let samples ?(rng = default_rng) { means; cov_chol } ~n =
      let n_means = Vec.dim means in
      let draw _ _ = Gsl.Randist.gaussian_ziggurat rng ~sigma:1. in
      let samples = Mat.init_cols n_means n draw in
      trmm ~transa:`T cov_chol samples;
      for col = 1 to n do
        for row = 1 to n_means do
          samples.{row, col} <- samples.{row, col} +. means.{row}
        done
      done;
      samples
  end
end

(* Prefixes used in the error messages of the samplers *)
let fitc_loc = "FITC"
let fic_loc = "FIC"
let variational_fitc_loc = "Variational_FITC"
let variational_fic_loc = "Variational_FIC"

(* FITC with the standard model *)
module Make_FITC (Spec : Specs.Eval) = struct
  include Make_common (Spec)

  module Model = Common_model
  module Covariances = FITC_covariances

  module Sampler = struct
    include Common_sampler
    let calc = calc ~loc:fitc_loc
  end

  module Cov_sampler = struct
    include Common_cov_sampler
    let calc = calc ~loc:fitc_loc
  end
end

(* FIC with the standard model *)
module Make_FIC (Spec : Specs.Eval) = struct
  include Make_common (Spec)

  module Model = Common_model
  module Covariances = FIC_covariances

  module Sampler = struct
    include Common_sampler
    let calc = calc ~loc:fic_loc
  end

  module Cov_sampler = struct
    include Common_cov_sampler
    let calc = calc ~loc:fic_loc
  end
end

(* FITC with the variational model *)
module Make_variational_FITC (Spec : Specs.Eval) = struct
  include Make_common (Spec)

  module Model = Variational_model
  module Covariances = FITC_covariances

  module Sampler = struct
    include Common_sampler
    let calc = calc ~loc:variational_fitc_loc
  end

  module Cov_sampler = struct
    include Common_cov_sampler
    let calc = calc ~loc:variational_fitc_loc
  end
end

(* FIC with the variational model *)
module Make_variational_FIC (Spec : Specs.Eval) = struct
  include Make_common (Spec)

  module Model = Variational_model
  module Covariances = FIC_covariances

  module Sampler = struct
    include Common_sampler
    let calc = calc ~loc:variational_fic_loc
  end

  module Cov_sampler = struct
    include Common_cov_sampler
    let calc = calc ~loc:variational_fic_loc
  end
end

(* All four variants built on one shared application of [Make_common] *)
module Make (Spec : Specs.Eval) = struct
  module type Sig = Sigs.Eval with module Spec = Spec

  module Common = Make_common (Spec)

  module FITC = struct
    include Common

    module Model = Common_model
    module Covariances = FITC_covariances

    module Sampler = struct
      include Common_sampler
      let calc = calc ~loc:fitc_loc
    end

    module Cov_sampler = struct
      include Common_cov_sampler
      let calc = calc ~loc:fitc_loc
    end
  end

  module FIC = struct
    include Common

    module Model = Common_model
    module Covariances = FIC_covariances

    module Sampler = struct
      include Common_sampler
      let calc = calc ~loc:fic_loc
    end

    module Cov_sampler = struct
      include Common_cov_sampler
      let calc = calc ~loc:fic_loc
    end
  end

  module Variational_FITC = struct
    include Common

    module Model = Variational_model
    module Covariances = FITC_covariances

    module Sampler = struct
      include Common_sampler
      let calc = calc ~loc:variational_fitc_loc
    end

    module Cov_sampler = struct
      include Common_cov_sampler
      let calc = calc ~loc:variational_fitc_loc
    end
  end

  module Variational_FIC = struct
    include Common

    module Model = Variational_model
    module Covariances = FIC_covariances

    module Sampler = struct
      include Common_sampler
      let calc = calc ~loc:variational_fic_loc
    end

    module Cov_sampler = struct
      include Common_cov_sampler
      let calc = calc ~loc:variational_fic_loc
    end
  end
end

(* Handling derivatives *)

(* Type of functors that map a derivative specification to a derivative
   module built on that same specification *)
module type Deriv_sig = functor (Spec : Specs.Deriv) ->
  Sigs.Deriv
    with module Eval.Spec = Spec.Eval
    with module Deriv.Spec = Spec

(* Computations shared by FIC and FITC, and standard and variational
version for derivatives *)

module Make_common_deriv (Spec : Specs.Deriv) = struct
  (* Eval modules *)

  module Eval_common = Make_common (Spec.Eval)
  module Eval_inducing = Eval_common.Inducing
  module Eval_inputs = Eval_common.Inputs
  module Eval_model = Eval_common.Common_model
  module Eval_trained = Eval_common.Trained

  (* Kind of model *)
  type model_kind = Standard | Variational

  module Deriv_common = struct
    (* Derivative modules *)

    module Spec = Spec

    (* Derivative of inducing points *)

    module Inducing = struct
      (* [eval] is the plain evaluation record; [shared_upper] is the data
         returned alongside the kernel matrix by
         [Spec.Inducing.calc_shared_upper] *)
      type t = {
        eval : Eval_inducing.t;
        shared_upper : Spec.Inducing.upper;
      }

      let calc kernel eval_inducing =
        let km, shared_upper =
          Spec.Inducing.calc_shared_upper kernel eval_inducing
        in
        let eval = Eval_inducing.calc_internal kernel eval_inducing km in
        { eval; shared_upper }

      let calc_eval inducing = inducing.eval
      let get_kernel inducing = Eval_inducing.get_kernel inducing.eval
    end

    (* Derivative of inputs *)

    module Inputs = struct
      (* [eval] is the plain evaluation record; [shared_cross] is the data
         returned alongside the cross matrix by
         [Spec.Inputs.calc_shared_cross] *)
      type t = {
        inducing : Inducing.t;
        eval : Eval_inputs.t;
        shared_cross : Spec.Inputs.cross;
      }

      let calc inducing points =
        let eval_inducing = inducing.Inducing.eval in
        let knm, shared_cross =
          Spec.Inputs.calc_shared_cross (Inducing.get_kernel inducing)
            ~inputs:points
            ~inducing:eval_inducing.Eval_inducing.points
        in
        let eval = Eval_inputs.calc_internal points eval_inducing knm in
        { inducing; eval; shared_cross }

      let calc_eval t = t.eval
      let get_kernel inputs = Inducing.get_kernel inputs.inducing
    end

    (* Derivative of hyper parameters
*)

    module Shared = struct
      (* Kernel evaluations and the specification's shared data needed to
         differentiate them *)
      type shared = {
        km : mat;
        knm : mat;
        kn_diag : vec;
        shared_upper : Spec.Inducing.upper;
        shared_cross : Spec.Inputs.cross;
        shared_diag : Spec.Inputs.diag;
      }

      (* Factors combined with the derivatives of [kn_diag], [km] and
         [knm] respectively in [calc_log_evidence] *)
      type dfacts = { v_vec : vec; w_mat : mat; x_mat : mat }

      type hyper_t = { shared : shared; dfacts : dfacts }

      (* Returns two fresh matrices: a copy of [v_mat] solved against the
         transposed [chol_km], and a copy of the first [n] rows of [q_mat]
         solved against the transposed [r_mat] and row-scaled by
         [sqrt_is_vec] *)
      let calc_us_mat eval_model =
        let u_mat = lacpy (Eval_model.get_v_mat eval_model) in
        trsm ~side:`R ~transa:`T (Eval_model.get_chol_km eval_model) u_mat;
        let n = Mat.dim1 u_mat in
        let s_mat = lacpy ~m:n (Eval_model.get_q_mat eval_model) in
        trsm ~side:`R ~transa:`T eval_model.Eval_model.r_mat s_mat;
        Mat.scal_rows (Eval_model.get_sqrt_is_vec eval_model) s_mat;
        u_mat, s_mat

      (* Adds [v] to the accumulator cell *)
      let update_tmp tmp v = tmp.x <- tmp.x +. v

      (* Contribution of the derivative of the kernel diagonal; the
         zero-constant cases are matched before the general ones *)
      let calc_dkn_diag_term ~v_vec ~kn_diag deriv =
        match deriv with
        | `Vec dkn_diag -> dot v_vec dkn_diag
        | `Sparse_vec (svec, rows) ->
            check_sparse_vec_sane ~real_n:(Vec.dim v_vec) ~svec ~rows;
            let acc = { x = 0. } in
            for i = 1 to Vec.dim svec do
              update_tmp acc (v_vec.{rows.{i}} *. svec.{i})
            done;
            acc.x
        | `Const 0. | `Factor 0. -> 0.
        | `Const c -> c *. Vec.sum v_vec
        | `Factor c -> c *. dot kn_diag v_vec

      (* Contribution of the derivative of the inducing kernel matrix *)
      let calc_dkm_term ~w_mat ~km deriv =
        match deriv with
        | `Dense dkm -> Mat.symm2_trace w_mat dkm
        | `Sparse_rows (smat, rows) ->
            symm2_sparse_trace ~mat:w_mat ~smat ~rows
        | `Const 0. | `Factor 0. | `Diag_const 0. -> 0.
        | `Const c -> c *. sum_symm_mat w_mat
        | `Factor c -> c *. Mat.symm2_trace w_mat km
        | `Diag_vec ddkm ->
            let acc = { x = 0. } in
            for i = 1 to Mat.dim1 w_mat do
              update_tmp acc (ddkm.{i} *. w_mat.{i, i})
            done;
            acc.x
        | `Diag_const c ->
            let acc = { x = 0. } in
            for i = 1 to Mat.dim1 w_mat do
              update_tmp acc (c *. w_mat.{i, i})
            done;
            acc.x

      (* Contribution of the derivative of the cross matrix *)
      let calc_dknm_term ~x_mat ~knm deriv =
        match deriv with
        | `Dense dknm -> Mat.gemm_trace ~transa:`T x_mat dknm
        | `Sparse_cols (sdknm, cols) ->
            let real_n = Mat.dim2 x_mat in
            check_sparse_col_mat_sane ~real_n ~smat:sdknm ~cols;
            let m = Mat.dim1 sdknm in
            let acc = { x = 0. } in
            for c = 1 to Int_vec.dim cols do
              (* Column [c] of the sparse matrix is column [cols.{c}] of
                 [x_mat] *)
              let real_c = cols.{c} in
              for r = 1 to m do
                update_tmp acc (x_mat.{r, real_c} *. sdknm.{r, c})
              done
            done;
            acc.x
        | `Const 0. | `Factor 0. -> 0.
        | `Const c -> c *. sum_mat x_mat
        | `Factor c -> c *. Mat.gemm_trace ~transa:`T x_mat knm
        | `Sparse_rows (sdknm, rows) ->
            let real_m = Mat.dim1 x_mat in
            check_sparse_row_mat_sane ~real_m ~smat:sdknm ~rows;
            let n = Mat.dim2 sdknm in
            let acc = { x = 0. } in
            for r = 1 to Int_vec.dim rows do
              (* Row [r] of the sparse matrix is row [rows.{r}] of
                 [x_mat] *)
              let real_r = rows.{r} in
              for c = 1 to n do
                update_tmp acc (x_mat.{real_r, c} *. sdknm.{r, c})
              done
            done;
            acc.x

      (* Derivative of the log evidence with respect to [hyper], combined
         from the three terms above *)
      let calc_log_evidence { shared; dfacts = { v_vec; w_mat; x_mat } } hyper =
        let dkn_diag_term =
          calc_dkn_diag_term ~v_vec ~kn_diag:shared.kn_diag
            (Spec.Inputs.calc_deriv_diag shared.shared_diag hyper)
        in
        let dkm_term =
          calc_dkm_term ~w_mat ~km:shared.km
            (Spec.Inducing.calc_deriv_upper shared.shared_upper hyper)
        in
        let dknm_term =
          calc_dknm_term ~x_mat ~knm:shared.knm
            (Spec.Inputs.calc_deriv_cross shared.shared_cross hyper)
        in
        -0.5 *. (dkn_diag_term -. dkm_term) -. dknm_term
    end

    (* Derivative of models *)

    module Common_model = struct
      (* Precomputations for all derivatives *)
      type t = {
        model_kind : model_kind;
        model_shared : Shared.shared;
        eval_model : Eval_model.t;
        inv_km : mat;
        q_diag : vec;
        t_mat : mat;
      }

      (* [t_mat] is [inv_km] minus [ichol r_mat]; [q_diag] is the result of
         [Mat.syrk_diag] on the first [n] rows of [q_mat], where [n] is the
         row count of [q_mat] minus its column count *)
      let calc_internal model_kind model_shared eval_model inv_km =
        let q_mat = Eval_model.get_q_mat eval_model in
        let n = Mat.dim1 q_mat - Mat.dim2 q_mat in
        let t_mat = lacpy ~uplo:`U inv_km in
        Mat.axpy ~alpha:~-.1. (ichol eval_model.Eval_model.r_mat) t_mat;
        let q_diag = Mat.syrk_diag ~n q_mat in
        { model_kind; model_shared; eval_model; inv_km; t_mat; q_diag }

      (* Builds the evaluation model of the requested kind and collects
         the data needed for derivatives *)
      let calc_common model_kind inputs sigma2 =
        let kernel = Inputs.get_kernel inputs in
        let eval_inputs = inputs.Inputs.eval in
        let kn_diag, shared_diag =
          Spec.Inputs.calc_shared_diag kernel eval_inputs.Eval_inputs.points
        in
        let eval_model =
          match model_kind with
          | Standard ->
              Eval_model.calc_with_kn_diag eval_inputs sigma2 kn_diag
          | Variational ->
              Eval_common.Variational_model.calc_with_kn_diag
                eval_inputs sigma2 kn_diag
        in
        let inv_km = ichol (Eval_model.get_chol_km eval_model) in
        let model_shared =
          {
            Shared.
            km = Eval_model.get_km eval_model;
            knm = Eval_model.get_knm eval_model;
            kn_diag = Eval_model.get_kn_diag eval_model;
            shared_diag;
            shared_upper = inputs.Inputs.inducing.Inducing.shared_upper;
            shared_cross = inputs.Inputs.shared_cross;
          }
        in
        calc_internal model_kind model_shared eval_model inv_km

      let calc_eval model = model.eval_model

      let calc inputs ~sigma2 = calc_common Standard inputs sigma2

      (* Recomputes the model for a new [sigma2], keeping its kind, shared
         data and [inv_km] *)
      let update_sigma2 ({ model_kind; _ } as model) sigma2 =
        let eval_model =
          match model_kind with
          | Standard -> Eval_model.update_sigma2 model.eval_model sigma2
          | Variational ->
              Eval_common.Variational_model.update_sigma2
                model.eval_model sigma2
        in
        calc_internal model_kind model.model_shared eval_model model.inv_km

      (* Returns a fresh vector whose entries depend on the model kind *)
      let calc_v1_vec { q_diag; eval_model; model_kind; _ } =
        let is_vec = Eval_model.get_is_vec eval_model in
        let n = Vec.dim is_vec in
        let v1_vec = Vec.create n in
        begin match model_kind with
        | Standard ->
            for i = 1 to n do
              v1_vec.{i} <- is_vec.{i} *. (1. -. q_diag.{i})
            done
        | Variational ->
            let r_vec = Eval_model.get_r_vec eval_model in
            for i = 1 to n do
              v1_vec.{i} <-
                is_vec.{i} *. (2. -. is_vec.{i} *. r_vec.{i} -. q_diag.{i})
            done
        end;
        v1_vec

      (* Derivative of sigma2 *)

      let common_calc_log_evidence_sigma2 { eval_model; model_kind; _ } v_vec =
        let sum_v_vec = Vec.sum v_vec in
        match model_kind with
        | Standard -> -0.5 *. sum_v_vec
        | Variational ->
            -0.5 *. (sum_v_vec -. Vec.sum eval_model.Eval_model.is_vec)

      let calc_log_evidence_sigma2 model =
        common_calc_log_evidence_sigma2 model (calc_v1_vec model)

      (* Prepare derivative of general hyper-parameters *)

      let prepare_hyper ({ eval_model; t_mat; model_shared; _ } as model) =
        let v_vec = calc_v1_vec model in
        let sqrt_v_vec = Vec.sqrt v_vec in
        let u_mat, x_mat = Shared.calc_us_mat eval_model in
        (* First scaling by [sqrt_v_vec] feeds the rank-k update of
           [w_mat] ... *)
        Mat.scal_rows sqrt_v_vec u_mat;
        let w_mat =
          let c = lacpy ~uplo:`U t_mat in
          syrk ~trans:`T ~alpha:~-.1. u_mat ~beta:1. ~c
        in
        (* ... the second one completes the scaling of [u_mat] before it
           is subtracted from [x_mat] *)
        Mat.scal_rows sqrt_v_vec u_mat;
        Mat.axpy ~alpha:~-.1. u_mat x_mat;
        {
          Shared.shared = model_shared;
          dfacts = { Shared.v_vec; w_mat; x_mat };
        }

      include Shared
    end

    module Cm = Common_model

    module Variational_model = struct
      include Cm

      let calc inputs ~sigma2 = calc_common Variational inputs sigma2
    end

    (* Derivative of trained models
*)moduleTrained=structtypet={common_model:Cm.t;eval_trained:Eval_trained.t;w_vec:vec;v_vec:vec;}letcalccommon_model~targets:y=leteval_model=common_model.Cm.eval_modelinlety_,qt_y_=Eval_trained.prepare_internaleval_model~yinletu_vec=copyy_inletn=Vec.dimy_inletq_mat=Eval_model.get_q_mateval_modelinignore(gemv~m:n~alpha:~-.1.q_matqt_y_~beta:1.~y:u_vec);letl2=-0.5*.dotu_vecy_inletcoeffs=qt_y_intrsveval_model.Eval_model.r_matcoeffs;letw_vec=u_vecinletsqrt_is_vec=Eval_model.get_sqrt_is_veceval_modelinfori=1tondow_vec.{i}<-w_vec.{i}*.sqrt_is_vec.{i}done;letv2_vec=Vec.sqrw_vecinletv_vec=Cm.calc_v1_veccommon_modelinaxpy~alpha:~-.1.v2_vecv_vec;{common_model;w_vec;v_vec;eval_trained=Eval_trained.calc_internaleval_model~y~coeffs~l2}letcalc_evaltrained=trained.eval_trained(* Derivative of sigma2 *)letcalc_log_evidence_sigma2{common_model;v_vec}=Cm.common_calc_log_evidence_sigma2common_modelv_vec(* Derivative of general hyper-parameters *)letprepare_hyper{common_model;eval_trained;w_vec;v_vec}=let{Cm.eval_model;t_mat;model_shared=shared}=common_modelinletu_mat,x_mat=Shared.calc_us_mateval_modelinlett_vec=eval_trained.Eval_trained.coeffsinletw_mat=lacpy~uplo:`Ut_matinsyr~alpha:~-.1.t_vecw_mat;letu1_mat=lacpyu_matinMat.scal_rows(Vec.sqrt(Cm.calc_v1_veccommon_model))u1_mat;letw_mat=syrk~trans:`T~alpha:~-.1.u1_mat~beta:1.~c:w_matinletu2_mat=lacpyu_mat~b:u1_matinMat.scal_rowsw_vecu2_mat;ignore(syrk~trans:`Tu2_mat~beta:1.~c:w_mat);Mat.scal_rowsv_vecu_mat;Mat.axpy~alpha:~-.1.u_matx_mat;ger~alpha:~-.1.w_vect_vecx_mat;{Shared.shared;dfacts={Shared.v_vec;w_mat;x_mat}}includeSharedendmoduleTest=structletupdate_hyperkernelinducing_pointspointshyper~eps=letvalue=Spec.Hyper.get_valuekernelinducing_pointspointshyperinletvalue_eps=value+.epsinSpec.Hyper.set_valueskernelinducing_pointspoints[|hyper|](Vec.make1value_eps)letis_bad_deriv~finite_el~deriv~tol=Float.is_nanfinite_el||Float.is_nanderiv||Float.abs(finite_el-.deriv)>tolletcheck_deriv_hyper?(eps=1e-8)?(tol=1e-2)kernel1inducing_points1points
1hyper=letkernel2,inducing_points2,points2=update_hyperkernel1inducing_points1points1hyper~epsinleteval_inducing1=Eval_inducing.calckernel1inducing_points1inleteval_cross1=Eval_inputs.calcpoints1eval_inducing1inleteval_inducing2=Eval_inducing.calckernel2inducing_points2inleteval_cross2=Eval_inputs.calcpoints2eval_inducing2inletmake_finite~mat1~mat2=letres=lacpymat2inMat.axpy~alpha:~-.1.mat1res;Mat.scal(1./.eps)res;resinletkm1=eval_inducing1.Eval_inducing.kminletfinite_dkm=make_finite~mat1:km1~mat2:eval_inducing2.Eval_inducing.kminletinducing1=Inducing.calckernel1inducing_points1inletcheck_mat~name~deriv~finite~r~c=letfinite_el=finite.{r,c}inifis_bad_deriv~finite_el~deriv~tolthenfailwithf"Gpr.Fitc_gp.Make_deriv.Test.check_deriv_hyper: \
finite difference (%f) and derivative (%f) differ \
by more than %f on %s.{%d, %d}"finite_elderivtolnamerc()in(* Check dkm *)beginletcheck=check_mat~name:"dkm"~finite:finite_dkminmatchSpec.Inducing.calc_deriv_upperinducing1.Inducing.shared_upperhyperwith|`Densedkm->letm=Mat.dim1dkminforc=1tomdoforr=1tocdocheck~deriv:dkm.{r,c}~r~cdonedone|`Sparse_rows(sdkm,rows)->letm=Int_vec.dimrowsinletn=Mat.dim2sdkminletrows_ix_ref=ref1inforsparse_r=1tomdoletc=rows.{sparse_r}inforr=1tondoletmat_r,mat_c=ifr>cthenc,relser,cinletrows_ix=!rows_ix_refinletderiv=ifrows_ix>m||letrows_el=rows.{rows_ix}inr<rows_el||c<rows_elthensdkm.{sparse_r,r}elsebeginincrrows_ix_ref;sdkm.{rows_ix,c}endincheck~deriv~r:mat_r~c:mat_cdone;rows_ix_ref:=1done|`Constconst->letm=Mat.dim1km1inforc=1tomdoforr=1tocdocheck~deriv:const~r~cdonedone|`Factorconst->letm=Mat.dim1km1inforc=1tomdoforr=1tocdocheck~deriv:(const*.km1.{r,c})~r~cdonedone|`Diag_vecdiag->letm=Mat.dim1km1inforc=1tomdocheck~deriv:diag.{c}~r:c~cdone|`Diag_constconst->letm=Mat.dim1km1inforc=1tomdocheck~deriv:const~r:c~cdoneend;(* Check dknm *)letinputs=Inputs.calcinducing1points1inbeginletknm1=eval_cross1.Eval_inputs.knminletfinite_dknm=make_finite~mat1:knm1~mat2:eval_cross2.Eval_inputs.knminletcheck=check_mat~name:"dknm"~finite:finite_dknminmatchSpec.Inputs.calc_deriv_crossinputs.Inputs.shared_crosshyperwith|`Densedknm->letm=Mat.dim1knm1inforc=1toMat.dim2knm1doforr=1tomdocheck~deriv:dknm.{r,c}~r~cdonedone|`Sparse_cols(sdknm,cols)->letm=Mat.dim1sdknminforc=1toInt_vec.dimcolsdoletreal_c=cols.{c}inforr=1tomdocheck~deriv:sdknm.{r,c}~r~c:real_cdonedone|`Constconst->letm=Mat.dim1knm1inforc=1toMat.dim2knm1doforr=1tomdocheck~deriv:const~r~cdonedone|`Factorconst->letm=Mat.dim1knm1inforc=1toMat.dim2knm1doforr=1tomdocheck~deriv:(const*.knm1.{r,c})~r~cdonedone|`Sparse_rows(sdknm,rows)->letn=Mat.dim2sdknminforr=1toInt_vec.dimrowsdoletreal_r=rows.{r}inforc=1tondocheck~deriv:sdknm.{r,c}~r:real_r~cdonedoneend;(* Check dkn diag 
*)beginletkn_diag1,shared_diag=Spec.Inputs.calc_shared_diagkernel1points1inletkn_diag2=Spec.Eval.Inputs.calc_diagkernel2points2inletfinite_dkn_diag=letres=copykn_diag2inaxpy~alpha:~-.1.kn_diag1res;scal(1./.eps)res;resinletcheck~deriv~r=letfinite_el=finite_dkn_diag.{r}inifis_bad_deriv~finite_el~deriv~tolthenfailwithf"Gpr.Fitc_gp.Make_deriv.Test.check_deriv_hyper: \
finite difference (%f) and derivative (%f) differ \
by more than %f on dkn_diag.{%d}" finite_el deriv tol r ()
          in
          match Spec.Inputs.calc_deriv_diag shared_diag hyper with
          | `Vec dkn_diag ->
              for r = 1 to Vec.dim dkn_diag do
                check ~deriv:dkn_diag.{r} ~r
              done
          | `Sparse_vec (sdkn_diag, cols) ->
              let n = Int_vec.dim cols in
              for r = 1 to n do
                check ~deriv:sdkn_diag.{r} ~r:cols.{r}
              done
          | `Const const ->
              for r = 1 to Vec.dim kn_diag1 do
                check ~deriv:const ~r
              done
          | `Factor const ->
              for r = 1 to Vec.dim kn_diag1 do
                check ~deriv:(const *. kn_diag1.{r}) ~r
              done
        end

      (* Finite-difference self test of the log-evidence derivatives.

         The log evidence of both the model and the trained model is
         computed once for the given arguments and once more after the
         parameter selected by [hyper] has been shifted by [eps].  The
         resulting difference quotient is compared with the analytical
         derivative via [is_bad_deriv] using tolerance [tol].

         [hyper] is either [`Sigma2] (noise variance) or [`Hyper h]
         (a kernel / inducing / input hyper parameter).

         Raises [Failure] (through [failwithf]) naming the quantity whose
         derivative disagrees with its finite difference. *)
      let self_test
            ?(eps = 1e-8) ?(tol = 1e-2)
            kernel1 inducing_points1 points1 ~sigma2 ~targets hyper =
        (* Reference values at the unperturbed parameters *)
        let base_inducing = Inducing.calc kernel1 inducing_points1 in
        let base_inputs = Inputs.calc base_inducing points1 in
        let deriv_model = Cm.calc base_inputs ~sigma2 in
        let model_le1 =
          Eval_model.calc_log_evidence (Cm.calc_eval deriv_model)
        in
        let deriv_trained = Trained.calc deriv_model ~targets in
        let trained_le1 =
          Eval_trained.calc_log_evidence (Trained.calc_eval deriv_trained)
        in
        let check ~name ~before ~after ~deriv =
          let finite_el = (after -. before) /. eps in
          if is_bad_deriv ~finite_el ~deriv ~tol then
            failwithf
              "Gpr.Fitc_gp.Make_deriv.Test.self_test: \
               finite difference (%f) and derivative (%f) differ \
               by more than %f on %s" finite_el deriv tol name ()
        in
        (* Checks the model first and the trained model second against the
           perturbed evaluation model.  The analytical derivatives are
           passed as thunks so that each one is only computed right before
           the check that needs it. *)
        let compare_with
              eval_model2
              ~model_name ~model_deriv ~trained_name ~trained_deriv =
          let model_le2 = Eval_model.calc_log_evidence eval_model2 in
          let model_d = model_deriv () in
          check
            ~name:model_name ~before:model_le1 ~after:model_le2
            ~deriv:model_d;
          let eval_trained2 = Eval_trained.calc eval_model2 ~targets in
          let trained_le2 = Eval_trained.calc_log_evidence eval_trained2 in
          let trained_d = trained_deriv () in
          check
            ~name:trained_name ~before:trained_le1 ~after:trained_le2
            ~deriv:trained_d
        in
        match hyper with
        | `Sigma2 ->
            let eval_model2 =
              Eval_model.calc base_inputs.Inputs.eval ~sigma2:(sigma2 +. eps)
            in
            compare_with eval_model2
              ~model_name:"sigma2(model)"
              ~model_deriv:(fun () ->
                Cm.calc_log_evidence_sigma2 deriv_model)
              ~trained_name:"sigma2(trained)"
              ~trained_deriv:(fun () ->
                Trained.calc_log_evidence_sigma2 deriv_trained)
        | `Hyper hyper ->
            let kernel2, inducing_points2, points2 =
              update_hyper kernel1 inducing_points1 points1 hyper ~eps
            in
            let eval_inducing2 =
              Eval_inducing.calc kernel2 inducing_points2
            in
            let eval_inputs2 = Eval_inputs.calc points2 eval_inducing2 in
            let eval_model2 = Eval_model.calc eval_inputs2 ~sigma2 in
            compare_with eval_model2
              ~model_name:"hyper(model)"
              ~model_deriv:(fun () ->
                let model_hyper_t = Cm.prepare_hyper deriv_model in
                Cm.calc_log_evidence model_hyper_t hyper)
              ~trained_name:"hyper(trained)"
              ~trained_deriv:(fun () ->
                let trained_hyper_t = Trained.prepare_hyper deriv_trained in
                Trained.calc_log_evidence trained_hyper_t hyper)
    end

    (* Hyper parameter optimization by evidence maximization
(type II maximum likelihood) *)
    module Optim = struct
      (* Resolves the initial noise variance.  Without a user-supplied
         value the mean of the squared targets is used.  Negative values
         are rejected with [Failure]; zero is accepted. *)
      let get_sigma2 targets = function
        | None -> Vec.sqr_nrm2 targets /. float (Vec.dim targets)
        | Some sigma2 when sigma2 < 0. ->
            failwithf "Optim.get_sigma2: sigma2 < 0: %f" sigma2 ()
        | Some sigma2 -> sigma2

      (* Resolves the kernel and the inducing points to start from.

         The last (positional, optional) argument is the inducing point
         set supplied by the caller:

         - [None]: [n_rand_inducing] random inputs are chosen as inducing
           points.  When [n_rand_inducing] is absent as well, a tenth of
           the number of inputs is used, limited to the range 1..1000.
           The lower limit keeps the default consistent with the
           1 <= n_inducing requirement stated by [check_n_inducing]; with
           fewer than ten inputs the plain quotient would be zero.
         - [Some inducing]: used as is.

         A missing [kernel] is replaced by the default kernel for the
         resolved number of inducing points.

         Raises [Failure] if [n_rand_inducing] is smaller than one or
         larger than the number of inputs. *)
      let get_kernel_inducing ?kernel ?n_rand_inducing ~inputs = function
        | None ->
            let n_inducing =
              let n_inputs = Spec.Eval.Inputs.get_n_points inputs in
              match n_rand_inducing with
              | None -> max 1 (min (n_inputs / 10) 1000)
              | Some n_rand_inducing ->
                  if n_rand_inducing < 1 then
                    failwithf
                      "Gpr.Fitc_gp.Optim.get_kernel_inducing: \
                       n_rand_inducing (%d) < 1" n_rand_inducing ()
                  else if n_rand_inducing > n_inputs then
                    failwithf
                      "Gpr.Fitc_gp.Optim.get_kernel_inducing: \
                       n_rand_inducing (%d) > n_inputs (%d)"
                      n_rand_inducing n_inputs ()
                  else n_rand_inducing
            in
            let kernel =
              match kernel with
              | None -> Eval_inputs.create_default_kernel ~n_inducing inputs
              | Some kernel -> kernel
            in
            (
              kernel,
              Eval_inducing.choose_n_random_inputs kernel ~n_inducing inputs
            )
        | Some inducing ->
            match kernel with
            | None ->
                let n_inducing = Spec.Eval.Inducing.get_n_points inducing in
                (
                  Eval_inputs.create_default_kernel ~n_inducing inputs,
                  inducing
                )
            | Some kernel -> kernel, inducing

      (* Returns the hyper parameters to optimize together with a
         one-based vector of their current values.  When [hypers] is
         [None], all hyper parameters reported by [Spec.Hyper.get_all]
         are used. *)
      let get_hypers_vals kernel inducing points hypers =
        let hypers =
          match hypers with
          | None -> Spec.Hyper.get_all kernel inducing points
          | Some hypers -> hypers
        in
        let n_hypers = Array.length hypers in
        let hyper_vals =
          Vec.init n_hypers (fun i1 ->
            Spec.Hyper.get_value kernel inducing points hypers.(i1 - 1))
        in
        hypers, hyper_vals

      (* Batch optimization with the GSL multidimensional minimizer *)
      module Gsl = struct
        (* Wraps exceptions that escaped from the objective function or
           from a report callback. *)
        exception Optim_exception of exn

        (* Turns a NaN objective value into an exception.  If a callback
           raised earlier, that exception is re-raised wrapped in
           [Optim_exception]; otherwise a plain [Failure] is raised. *)
        let check_exception seen_exception_ref res =
          if Float.classify res = Float.Class.Nan then
            match !seen_exception_ref with
            | None ->
                failwith "Gpr.Optim.Gsl: optimization function returned nan"
            | Some exc -> raise (Optim_exception exc)

        (* Default report callback: does nothing *)
        let ignore_report ~iter:_ _ = ()

        (* Maximizes the log evidence with the BFGS2 minimizer and returns
           the best trained evaluation model seen.

           [step] and [tol] are handed to the minimizer constructor.
           [epsabs] is the gradient norm below which iteration stops.
           [report_trained_model] is called whenever a model with a better
           log evidence is found; [report_gradient_norm] is called once
           per iteration.
           [learn_sigma2] selects whether the noise variance is optimized
           too; it is then parameterized by its logarithm and stored at
           index zero of the parameter vector.
           [kernel], [sigma2], [inducing], [n_rand_inducing] and [hypers]
           are resolved by [get_sigma2], [get_kernel_inducing] and
           [get_hypers_vals].

           Raises [Optim_exception] or [Failure] as described for
           [check_exception]; exceptions of [report_gradient_norm] are
           wrapped in [Optim_exception]. *)
        let train
              ?(step = 1e-1) ?(tol = 1e-1) ?(epsabs = 1e-1)
              ?(report_trained_model = ignore_report)
              ?(report_gradient_norm = ignore_report)
              ?kernel ?sigma2 ?inducing ?n_rand_inducing
              ?(learn_sigma2 = true) ?hypers ~inputs ~targets () =
          let sigma2 = get_sigma2 targets sigma2 in
          let kernel, inducing =
            get_kernel_inducing ?kernel ?n_rand_inducing ~inputs inducing
          in
          let hypers, hyper_vals =
            get_hypers_vals kernel inducing inputs hypers
          in
          let n_hypers = Array.length hypers in
          (* GSL vectors are zero-based, Lacaml vectors one-based *)
          let n_gsl_hypers, gsl_hypers =
            if learn_sigma2 then begin
              let n_gsl_hypers = 1 + n_hypers in
              let gsl_hypers = Gsl.Vector.create n_gsl_hypers in
              gsl_hypers.{0} <- log sigma2;
              for i = 1 to n_hypers do
                gsl_hypers.{i} <- hyper_vals.{i}
              done;
              n_gsl_hypers, gsl_hypers
            end else begin
              let gsl_hypers = Gsl.Vector.create n_hypers in
              for i = 1 to n_hypers do
                gsl_hypers.{i - 1} <- hyper_vals.{i}
              done;
              n_hypers, gsl_hypers
            end
          in
          let module Gd = Gsl.Multimin.Deriv in
          let sigma2_ref = ref sigma2 in
          (* Copies the minimizer parameters back into the kernel,
             inducing points and inputs (and into [sigma2_ref]). *)
          let update_hypers =
            if learn_sigma2 then
              (fun ~gsl_hypers ->
                sigma2_ref := exp gsl_hypers.{0};
                let hyper_vals = Vec.create n_hypers in
                for i = 1 to n_hypers do
                  hyper_vals.{i} <- gsl_hypers.{i}
                done;
                Spec.Hyper.set_values kernel inducing inputs hypers hyper_vals)
            else
              (fun ~gsl_hypers ->
                let hyper_vals = Vec.create n_hypers in
                for i = 1 to n_hypers do
                  hyper_vals.{i} <- gsl_hypers.{i - 1}
                done;
                Spec.Hyper.set_values kernel inducing inputs hypers hyper_vals)
          in
          (* Remembers an exception raised inside a minimizer callback so
             that [check_exception] can report it later. *)
          let seen_exception_ref = ref None in
          let wrap_seen_exception f =
            try f () with exc -> seen_exception_ref := Some exc; raise exc
          in
          let best_model_ref = ref None in
          let get_best_model () =
            match !best_model_ref with
            | None -> assert false  (* impossible *)
            | Some (trained, _) -> trained
          in
          let iter_count = ref 1 in
          let update_best_model trained log_evidence =
            match !best_model_ref with
            | Some (_, old_log_evidence)
              when old_log_evidence >= log_evidence -> ()
            | _ ->
                report_trained_model ~iter:!iter_count trained;
                best_model_ref := Some (trained, log_evidence)
          in
          (* Objective: negative log evidence *)
          let multim_f ~x:gsl_hypers =
            let kernel, inducing, inputs = update_hypers ~gsl_hypers in
            let eval_inducing = Eval_inducing.calc kernel inducing in
            let eval_inputs = Eval_inputs.calc inputs eval_inducing in
            let model = Eval_model.calc eval_inputs ~sigma2:!sigma2_ref in
            let trained = Eval_trained.calc model ~targets in
            let log_evidence = Eval_trained.calc_log_evidence trained in
            update_best_model trained log_evidence;
            -. log_evidence
          in
          let multim_f ~x = wrap_seen_exception (fun () -> multim_f ~x) in
          (* Fills [gradient] with the gradient of the negative log
             evidence and returns the trained derivative model. *)
          let multim_dcommon ~x:gsl_hypers ~g:gradient =
            let kernel, inducing, inputs = update_hypers ~gsl_hypers in
            let deriv_inducing = Inducing.calc kernel inducing in
            let deriv_inputs = Inputs.calc deriv_inducing inputs in
            let dmodel = Cm.calc ~sigma2:!sigma2_ref deriv_inputs in
            let trained = Trained.calc dmodel ~targets in
            if learn_sigma2 then begin
              let dlog_evidence_dsigma2 =
                Trained.calc_log_evidence_sigma2 trained
              in
              (* Chain rule for the log parameterization of sigma2 *)
              gradient.{0} <- -. dlog_evidence_dsigma2 *. !sigma2_ref;
              if n_hypers <> 0 then begin
                let hyper_t = Trained.prepare_hyper trained in
                for i1 = 1 to n_hypers do
                  gradient.{i1} <-
                    -. Trained.calc_log_evidence hyper_t hypers.(i1 - 1)
                done
              end
            end else begin
              let hyper_t = Trained.prepare_hyper trained in
              for i = 0 to n_hypers - 1 do
                gradient.{i} <-
                  -. Trained.calc_log_evidence hyper_t hypers.(i)
              done
            end;
            trained
          in
          let multim_df ~x ~g = ignore (multim_dcommon ~x ~g) in
          let multim_df ~x ~g =
            wrap_seen_exception (fun () -> multim_df ~x ~g)
          in
          let multim_fdf ~x ~g =
            let deriv_trained = multim_dcommon ~x ~g in
            let trained = Trained.calc_eval deriv_trained in
            let log_evidence = Eval_trained.calc_log_evidence trained in
            update_best_model trained log_evidence;
            -. log_evidence
          in
          let multim_fdf ~x ~g =
            wrap_seen_exception (fun () -> multim_fdf ~x ~g)
          in
          let multim_fun_fdf =
            { Gsl.Fun.multim_f; multim_df; multim_fdf }
          in
          let mumin =
            Gd.make Gd.VECTOR_BFGS2 n_gsl_hypers
              multim_fun_fdf ~x:gsl_hypers ~step ~tol
          in
          let gsl_dhypers = Gsl.Vector.create n_gsl_hypers in
          let rec loop () =
            let neg_log_likelihood =
              Gd.minimum ~x:gsl_hypers ~g:gsl_dhypers mumin
            in
            check_exception seen_exception_ref neg_log_likelihood;
            let gnorm = Gsl.Blas.nrm2 gsl_dhypers in
            begin
              try report_gradient_norm ~iter:!iter_count gnorm
              with exc -> raise (Optim_exception exc)
            end;
            if gnorm < epsabs then get_best_model ()
            else begin
              incr iter_count;
              Gd.iterate mumin;
              loop ()
            end
          in
          loop ()
      end

      (* Gradient of the log evidence as a one-based vector.  With
         [learn_sigma2] the first entry is the derivative with respect to
         log sigma2 and the hyper parameters follow from index two;
         otherwise the hyper parameters start at index one. *)
      let calc_gradient ~learn_sigma2 ~sigma2 ~hypers ~trained =
        let n_hypers = Array.length hypers in
        let n_all_hypers = if learn_sigma2 then n_hypers + 1 else n_hypers in
        let gradient = Vec.create n_all_hypers in
        if learn_sigma2 then begin
          let dlog_evidence_dsigma2 =
            Trained.calc_log_evidence_sigma2 trained
          in
          gradient.{1} <- dlog_evidence_dsigma2 *. sigma2;
          if n_hypers <> 0 then begin
            let hyper_t = Trained.prepare_hyper trained in
            for i = 0 to n_hypers - 1 do
              gradient.{i + 2} <-
                Trained.calc_log_evidence hyper_t hypers.(i)
            done
          end
        end else begin
          let hyper_t = Trained.prepare_hyper trained in
          for i1 = 1 to n_hypers do
            gradient.{i1} <-
              Trained.calc_log_evidence hyper_t hypers.(i1 - 1)
          done
        end;
        gradient

      (* Builds a driver that repeats [step] until the gradient norm drops
         below [epsabs] or [max_iter] steps have been taken, and returns
         the state with the highest log evidence.  [report] is called for
         every state that improves on the best one so far.  Without
         [max_iter] the number of steps is unbounded (the counter starts
         at -1 and hence never reaches zero).

         Raises [Failure] if [max_iter] is negative. *)
      let make_test step gradient_norm get_trained
            ?(epsabs = 0.1) ?max_iter ?(report = ignore) t =
        let max_iter =
          match max_iter with
          | None -> -1
          | Some max_iter when max_iter < 0 ->
              failwith "Optim.SMD.test: max_iter < 0"
          | Some max_iter -> max_iter
        in
        let rec loop n ~best_le ~best ~t =
          if n = 0 || gradient_norm t < epsabs then best
          else
            let new_t = step t in
            let best_le, best =
              let new_trained = get_trained new_t in
              let new_log_evidence =
                Eval_trained.calc_log_evidence new_trained
              in
              if new_log_evidence <= best_le then best_le, best
              else begin
                report new_t;
                new_log_evidence, new_t
              end
            in
            loop (n - 1) ~best_le ~best ~t:new_t
        in
        let best_le = Eval_trained.calc_log_evidence (get_trained t) in
        loop max_iter ~best_le ~best:t ~t

      (* Gradient ascent on the log evidence with a decaying scalar
         learning rate *)
      module SGD = struct
        type t = {
          learn_sigma2 : bool;
          hypers : Spec.Hyper.t array;
          tau : float;
          eta : float;
          step : int;
          hyper_vals : vec;
          trained : Trained.t;
          gradient : vec;
          gradient_norm : float;
        }

        (* Creates the initial optimizer state.  [tau] and [eta0] must be
           strictly positive and [step] non-negative, otherwise [Failure]
           is raised.  The remaining arguments are resolved as in
           [Gsl.train]. *)
        let create
              ?(tau = 100.) ?eta0:(eta = 1e-3) ?(step = 0)
              ?kernel ?sigma2 ?inducing ?n_rand_inducing
              ?(learn_sigma2 = true) ?hypers ~inputs ~targets () =
          let loc = "Gpr.Fitc_gp.Optim.SGD.create" in
          let fail_neg0 what v =
            if v <= 0. then failwithf "%s: %s (%f) <= 0" loc what v ()
          in
          fail_neg0 "tau" tau;
          fail_neg0 "eta0" eta;
          if step < 0 then failwithf "%s: step (%d) < 0" loc step ();
          let sigma2 = get_sigma2 targets sigma2 in
          let kernel, inducing =
            get_kernel_inducing ?kernel ?n_rand_inducing ~inputs inducing
          in
          let hypers, hyper_vals =
            get_hypers_vals kernel inducing inputs hypers
          in
          let trained =
            let deriv_inducing = Inducing.calc kernel inducing in
            let deriv_inputs = Inputs.calc deriv_inducing inputs in
            let dmodel = Cm.calc ~sigma2 deriv_inputs in
            Trained.calc dmodel ~targets
          in
          let gradient =
            calc_gradient ~learn_sigma2 ~sigma2 ~hypers ~trained
          in
          let gradient_norm = nrm2 gradient in
          {
            learn_sigma2; hypers; tau; eta; step; hyper_vals;
            trained; gradient; gradient_norm;
          }

        (* Performs one ascent step along the stored gradient, retrains
           the model at the new parameters and decays the learning rate.

           NOTE(review): the decay factor tau / (tau + step) multiplies
           the already decayed [eta], so the factors compound from step to
           step - confirm that this is the intended schedule. *)
        let step t =
          let
            {
              learn_sigma2; hypers; tau; eta; step;
              hyper_vals = old_hyper_vals;
              trained = old_trained;
              gradient = old_gradient;
              gradient_norm = _gradient_norm;
            } = t
          in
          let old_sigma2, old_input_points, old_inducing, old_kernel =
            let eval_trained = Trained.calc_eval old_trained in
            let eval_model = Eval_trained.get_model eval_trained in
            let input_points = Eval_model.get_input_points eval_model in
            let sigma2 = Eval_model.get_sigma2 eval_model in
            let inducing = Eval_model.get_inducing_points eval_model in
            let kernel = Eval_model.get_kernel eval_model in
            sigma2, input_points, inducing, kernel
          in
          (* [hyper_ix] is the gradient index of the first hyper
             parameter; index one belongs to log sigma2 if it is learnt. *)
          let sigma2, hyper_ix =
            if learn_sigma2 then
              exp (log old_sigma2 +. eta *. old_gradient.{1}), 2
            else old_sigma2, 1
          in
          let hyper_vals = copy old_hyper_vals in
          axpy ~alpha:eta ~ofsx:hyper_ix old_gradient hyper_vals;
          let trained =
            let kernel, inducing, input_points =
              Spec.Hyper.set_values
                old_kernel old_inducing old_input_points hypers hyper_vals
            in
            let deriv_inducing = Inducing.calc kernel inducing in
            let deriv_inputs = Inputs.calc deriv_inducing input_points in
            let dmodel = Cm.calc ~sigma2 deriv_inputs in
            let eval_trained = Trained.calc_eval old_trained in
            let targets = Eval_trained.get_targets eval_trained in
            Trained.calc dmodel ~targets
          in
          let gradient =
            calc_gradient ~learn_sigma2 ~sigma2 ~hypers ~trained
          in
          let gradient_norm = nrm2 gradient in
          {
            t with
            hyper_vals; trained; gradient; gradient_norm;
            eta = tau /. (tau +. float step) *. eta;
            step = step + 1;
          }

        let gradient_norm t = t.gradient_norm
        let get_trained t = Trained.calc_eval t.trained
        let get_eta t = t.eta
        let get_step t = t.step

        let test = make_test step gradient_norm get_trained
      end

      (* Gradient ascent on the log evidence with per-parameter learning
         rates [eta] that are adapted using the auxiliary vector [nu] *)
      module SMD = struct
        type t = {
          learn_sigma2 : bool;
          hypers : Spec.Hyper.t array;
          eps : float;
          lambda : float;
          mu : float;
          eta : vec;
          nu : vec;
          hyper_vals : vec;
          trained : Trained.t;
          gradient : vec;
          gradient_norm : float;
        }

        (* Creates the initial optimizer state.

           [lambda] must lie in 0..1 (default 0.1), [mu] must be
           non-negative (default 1e-3).  [eta0] and [nu0] must have one
           entry per optimized parameter (sigma2 first if it is learnt)
           and default to vectors filled with 1e-3; every entry of [eta0]
           must be strictly positive.  [eps] is the step width of the
           finite difference used in [step].  Violations raise
           [Failure].  The remaining arguments are resolved as in
           [Gsl.train]. *)
        let create
              ?(eps = 1e-8) ?lambda ?mu ?eta0 ?nu0
              ?kernel ?sigma2 ?inducing ?n_rand_inducing
              ?(learn_sigma2 = true) ?hypers ~inputs ~targets () =
          let loc = "Gpr.Fitc_gp.Optim.SMD.create" in
          let lambda =
            match lambda with
            | None -> 0.1
            | Some lambda when lambda < 0. || lambda > 1. ->
                failwithf
                  "%s: violating 0 <= lambda(%f) <= 1" loc lambda ()
            | Some lambda -> lambda
          in
          let mu =
            match mu with
            | None -> 1e-3
            | Some mu when mu < 0. ->
                failwithf "%s: violating 0 <= mu(%f)" loc mu ()
            | Some mu -> mu
          in
          let sigma2 = get_sigma2 targets sigma2 in
          let kernel, inducing =
            get_kernel_inducing ?kernel ?n_rand_inducing ~inputs inducing
          in
          let hypers, hyper_vals =
            get_hypers_vals kernel inducing inputs hypers
          in
          let n_all_hypers =
            let n_hypers = Array.length hypers in
            if learn_sigma2 then n_hypers + 1 else n_hypers
          in
          let eta =
            match eta0 with
            | None -> Vec.make n_all_hypers 1e-3
            | Some eta0 ->
                let dim_eta0 = Vec.dim eta0 in
                if dim_eta0 <> n_all_hypers then
                  failwithf
                    "%s: dim(eta0) = %d <> n_all_hypers(%d)"
                    loc dim_eta0 n_all_hypers ()
                else begin
                  for i = 1 to n_all_hypers do
                    let eta0_i = eta0.{i} in
                    (* The message states the condition that is actually
                       tested: zero is rejected as well. *)
                    if eta0_i <= 0. then
                      failwithf "%s: eta0.{%d} <= 0: %f" loc i eta0_i ()
                  done;
                  eta0
                end
          in
          let nu =
            match nu0 with
            | None -> Vec.make n_all_hypers 1e-3
            | Some nu0 ->
                let dim_nu0 = Vec.dim nu0 in
                if dim_nu0 <> n_all_hypers then
                  failwithf
                    "%s: dim(nu0) = %d <> n_all_hypers(%d)"
                    loc dim_nu0 n_all_hypers ()
                else nu0
          in
          let trained =
            let deriv_inducing = Inducing.calc kernel inducing in
            let deriv_inputs = Inputs.calc deriv_inducing inputs in
            let dmodel = Cm.calc ~sigma2 deriv_inputs in
            Trained.calc dmodel ~targets
          in
          let gradient =
            calc_gradient ~learn_sigma2 ~sigma2 ~hypers ~trained
          in
          let gradient_norm = nrm2 gradient in
          {
            learn_sigma2; hypers; eps; lambda; mu; eta; nu; hyper_vals;
            trained; gradient; gradient_norm;
          }

        (* Performs one step: adapts the learning rates, moves the
           parameters along the stored gradient, updates [nu] and retrains
           the model at the new parameters. *)
        let step t =
          let
            {
              learn_sigma2; hypers; eps; lambda; mu;
              eta = old_eta;
              nu = old_nu;
              hyper_vals = old_hyper_vals;
              trained = old_trained;
              gradient = old_gradient;
              gradient_norm = _gradient_norm;
            } = t
          in
          let eval_trained = Trained.calc_eval old_trained in
          let targets = Eval_trained.get_targets eval_trained in
          let eval_model = Eval_trained.get_model eval_trained in
          let old_input_points = Eval_model.get_input_points eval_model in
          let old_sigma2 = Eval_model.get_sigma2 eval_model in
          let log_old_sigma2 = log old_sigma2 in
          let old_inducing = Eval_model.get_inducing_points eval_model in
          let old_kernel = Eval_model.get_kernel eval_model in
          let n_hypers = Array.length hypers in
          let lambda_hessian_nu =
            (* Just an approximation.  Would require algorithmic
               differentiation for practical use if exact Hessian-vector
               product is required. *)
            (* Gradient at the parameters shifted by [eps] along [old_nu] *)
            let calc_grad eps =
              let sigma2, hyper_ofs =
                if learn_sigma2 then
                  exp (log_old_sigma2 +. eps *. old_nu.{1}), 1
                else old_sigma2, 0
              in
              let kernel, inducing, input_points =
                let hyper_vals = Vec.create n_hypers in
                for i1 = 1 to n_hypers do
                  hyper_vals.{i1} <-
                    old_hyper_vals.{i1} +. eps *. old_nu.{i1 + hyper_ofs}
                done;
                Spec.Hyper.set_values
                  old_kernel old_inducing old_input_points hypers hyper_vals
              in
              let deriv_inducing = Inducing.calc kernel inducing in
              let deriv_inputs = Inputs.calc deriv_inducing input_points in
              let dmodel = Cm.calc ~sigma2 deriv_inputs in
              let trained = Trained.calc dmodel ~targets in
              calc_gradient ~learn_sigma2 ~sigma2 ~hypers ~trained
            in
            (* Central difference, scaled by lambda *)
            let res = Vec.sub (calc_grad eps) (calc_grad (-. eps)) in
            scal (lambda /. (2. *. eps)) res;
            res
          in
          let n_all_hypers = Vec.dim old_gradient in
          let eta = Vec.create n_all_hypers in
          for i = 1 to n_all_hypers do
            eta.{i} <-
              old_eta.{i}
                *. max 0.5 (1. +. mu *. old_gradient.{i} *. old_nu.{i})
          done;
          (* [hyper_ix] is the index of the first hyper parameter in
             [eta] and in the gradient; index one belongs to log sigma2
             if it is learnt. *)
          let sigma2, hyper_ix =
            if learn_sigma2 then
              exp (log_old_sigma2 +. eta.{1} *. old_gradient.{1}), 2
            else old_sigma2, 1
          in
          let hyper_vals =
            (* Both [eta] and the gradient are read from [hyper_ix] so
               that every hyper parameter is scaled by its own learning
               rate rather than by the one of its predecessor. *)
            Vec.add old_hyper_vals
              (Vec.mul
                 ~n:n_hypers ~ofsx:hyper_ix eta ~ofsy:hyper_ix old_gradient)
          in
          let nu =
            Vec.mul old_eta (Vec.add old_gradient lambda_hessian_nu)
          in
          axpy ~alpha:lambda old_nu nu;
          let trained =
            let kernel, inducing, input_points =
              Spec.Hyper.set_values
                old_kernel old_inducing old_input_points hypers hyper_vals
            in
            let deriv_inducing = Inducing.calc kernel inducing in
            let deriv_inputs = Inputs.calc deriv_inducing input_points in
            let dmodel = Cm.calc ~sigma2 deriv_inputs in
            Trained.calc dmodel ~targets
          in
          let gradient =
            calc_gradient ~learn_sigma2 ~sigma2 ~hypers ~trained
          in
          let gradient_norm = nrm2 gradient in
          { t with eta; nu; hyper_vals; trained; gradient; gradient_norm }

        let gradient_norm t = t.gradient_norm
        let get_trained t = Trained.calc_eval t.trained
        let get_eta t = t.eta
        let get_nu t = t.nu

        let test = make_test step gradient_norm get_trained
      end
    end

(*
module Online = struct
module SGD = struct
type foo =
| Under_capacity of Trained.t
| Max_capacity of fdsa
type t = {
kernel : Spec.Eval.Kernel.t;
eta : float;
tau : float;
foo : foo;
}
end
module SMD = struct
type t
end
type t =
| Untrained of Spec.Eval.Kernel.t
| SGD of SGD.t
| SMD of SMD.t
let sgd
?(capacity = 10)
?mean_predictor:_ ?(eta0 = 0.1) ?(tau = 0.) kernel =
(* TODO: parameter checks *)
ignore (capacity);
{
SGD.
kernel;
eta = eta0;
tau;
mean_predictor;
}
let smd ?(capacity = 10) ?eta0 ?(mu = 0.1) ?(lam = 0.1) kernel =
(* TODO: parameter checks *)
ignore (capacity, eta0, mu, lam, kernel);
(assert false (* XXX *))
let train_sgd _sgd _input ~target:_ =
(assert false (* XXX *))
let train_smd _smd _input ~target:_ =
(assert false (* XXX *))
let train t input ~target =
match t with
| SGD sgd -> train_sgd sgd input ~target
| SMD smd -> train_smd smd input ~target
let calc_mean_predictor = function
| Untrained kernel ->
let inputs = Eval.Inputs.create [||] in
let inducing = Eval.Inputs.create_inducing kernel inputs in
{
Eval_common.Mean_predictor.
inducing;
coeffs = Vec.empty;
}
| SGD _ ->
(assert false (* XXX *))
| SMD _ ->
(assert false (* XXX *))
let calc_co_variance_predictor _t = (assert false (* XXX *))
end
*)
  end
end

(* The functors below all instantiate [Make_common_deriv] and differ only
   in three choices: the model module (common or variational), the
   covariance module (FITC or FIC), and the [loc] value handed to the
   samplers.  The same model module is bound in both [Eval] and [Deriv]. *)

(* FITC with the common model *)
module Make_FITC_deriv (Spec : Specs.Deriv) = struct
  include Make_common_deriv (Spec)

  module Eval = struct
    include Eval_common

    module Model = Common_model
    module Covariances = FITC_covariances

    module Sampler = struct
      include Common_sampler
      let calc = calc ~loc:fitc_loc
    end

    module Cov_sampler = struct
      include Common_cov_sampler
      let calc = calc ~loc:fitc_loc
    end
  end

  module Deriv = struct
    include Deriv_common

    module Model = Common_model
  end
end

(* FIC with the common model *)
module Make_FIC_deriv (Spec : Specs.Deriv) = struct
  include Make_common_deriv (Spec)

  module Eval = struct
    include Eval_common

    module Model = Common_model
    module Covariances = FIC_covariances

    module Sampler = struct
      include Common_sampler
      let calc = calc ~loc:fic_loc
    end

    module Cov_sampler = struct
      include Common_cov_sampler
      let calc = calc ~loc:fic_loc
    end
  end

  module Deriv = struct
    include Deriv_common

    module Model = Common_model
  end
end

(* FITC with the variational model *)
module Make_variational_FITC_deriv (Spec : Specs.Deriv) = struct
  include Make_common_deriv (Spec)

  module Eval = struct
    include Eval_common

    module Model = Variational_model
    module Covariances = FITC_covariances

    module Sampler = struct
      include Common_sampler
      let calc = calc ~loc:fitc_loc
    end

    module Cov_sampler = struct
      include Common_cov_sampler
      let calc = calc ~loc:fitc_loc
    end
  end

  module Deriv = struct
    include Deriv_common

    module Model = Variational_model
  end
end

(* FIC with the variational model *)
module Make_variational_FIC_deriv (Spec : Specs.Deriv) = struct
  include Make_common_deriv (Spec)

  module Eval = struct
    include Eval_common

    module Model = Variational_model
    module Covariances = FIC_covariances

    module Sampler = struct
      include Common_sampler
      let calc = calc ~loc:fic_loc
    end

    module Cov_sampler = struct
      include Common_cov_sampler
      let calc = calc ~loc:fic_loc
    end
  end

  module Deriv = struct
    include Deriv_common

    module Model = Variational_model
  end
end

(* Bundles all four variants over a single shared instantiation of
   [Make_common_deriv]. *)
module Make_deriv (Spec : Specs.Deriv) = struct
  module type Sig = Sigs.Deriv
    with module Eval.Spec = Spec.Eval
    with module Deriv.Spec = Spec

  module Common_deriv = Make_common_deriv (Spec)

  module FITC = struct
    include Common_deriv

    module Eval = struct
      include Eval_common

      module Model = Common_model
      module Covariances = FITC_covariances

      module Sampler = struct
        include Common_sampler
        let calc = calc ~loc:fitc_loc
      end

      module Cov_sampler = struct
        include Common_cov_sampler
        let calc = calc ~loc:fitc_loc
      end
    end

    module Deriv = struct
      include Deriv_common

      module Model = Common_model
    end
  end

  (* Uses the common model and [fic_loc], matching [Make_FIC_deriv].
     Binding the variational model and [fitc_loc] here would make this
     module a duplicate of [Variational_FIC] with a mislabelled [loc]. *)
  module FIC = struct
    include Common_deriv

    module Eval = struct
      include Eval_common

      module Model = Common_model
      module Covariances = FIC_covariances

      module Sampler = struct
        include Common_sampler
        let calc = calc ~loc:fic_loc
      end

      module Cov_sampler = struct
        include Common_cov_sampler
        let calc = calc ~loc:fic_loc
      end
    end

    module Deriv = struct
      include Deriv_common

      module Model = Common_model
    end
  end

  module Variational_FITC = struct
    include Common_deriv

    module Eval = struct
      include Eval_common

      module Model = Variational_model
      module Covariances = FITC_covariances

      module Sampler = struct
        include Common_sampler
        let calc = calc ~loc:fitc_loc
      end

      module Cov_sampler = struct
        include Common_cov_sampler
        let calc = calc ~loc:fitc_loc
      end
    end

    module Deriv = struct
      include Deriv_common

      module Model = Variational_model
    end
  end

  module Variational_FIC = struct
    include Common_deriv

    module Eval = struct
      include Eval_common

      module Model = Variational_model
      module Covariances = FIC_covariances

      module Sampler = struct
        include Common_sampler
        let calc = calc ~loc:fic_loc
      end

      module Cov_sampler = struct
        include Common_cov_sampler
        let calc = calc ~loc:fic_loc
      end
    end

    module Deriv = struct
      include Deriv_common

      module Model = Variational_model
    end
  end
end