# 1 "src/base/stats/owl_base_stats.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

(* PRNG of various distributions *)

(* Re-exports of the random variate generators implemented in the
 * per-distribution modules, so they are reachable from this module. *)

let std_uniform_rvs = Owl_base_stats_dist_uniform.std_uniform_rvs

let uniform_int_rvs = Owl_base_stats_dist_uniform.uniform_int_rvs

let uniform_rvs = Owl_base_stats_dist_uniform.uniform_rvs

let bernoulli_rvs = Owl_base_stats_dist_bernoulli.bernoulli_rvs

let gaussian_rvs = Owl_base_stats_dist_gaussian.gaussian_rvs

let exponential_rvs = Owl_base_stats_dist_exponential.exponential_rvs

let cauchy_rvs = Owl_base_stats_dist_cauchy.cauchy_rvs

let std_gamma_rvs = Owl_base_stats_dist_gamma.std_gamma_rvs

let gamma_rvs = Owl_base_stats_dist_gamma.gamma_rvs

let gumbel1_rvs = Owl_base_stats_dist_gumbel1.gumbel1_rvs

let gumbel2_rvs = Owl_base_stats_dist_gumbel2.gumbel2_rvs

(* Randomisation function *)

(* [shuffle x] returns a permuted copy of [x]; [x] itself is not modified.
 * Walking i from the last index down to 1, element i is swapped with an
 * element at a random index j computed as trunc (u * (i + 1)).
 * NOTE(review): j <= i only holds if std_uniform_rvs never returns exactly
 * 1.0 -- confirm against the uniform generator. *)
let shuffle x =
  let y = Array.copy x in
  let n = Array.length x in
  for i = n - 1 downto 1 do
    let s = float_of_int (i + 1) in
    let j = int_of_float (std_uniform_rvs () *. s) in
    Owl_utils_array.swap y i j
  done;
  y


(* [choose x k] draws [k] elements of [x] without replacement, keeping them
 * in their original relative order (the array is scanned once, left to
 * right). Asserts that [x] has at least [k] elements.
 * [k = 0] returns an empty array without touching [x], so an empty input
 * is accepted in that case. *)
let choose x k =
  let n = Array.length x in
  assert (n >= k);
  if k = 0
  then [||]
  else (
    (* x.(0) is only a filler value; every slot is overwritten below *)
    let y = Array.make k x.(0) in
    let i = ref 0 in
    let j = ref 0 in
    while !i < n && !j < k do
      (* element i is taken with probability (k - j) / (n - i):
       * slots still to fill over elements still to scan *)
      let s = float_of_int (n - !i) in
      let l = int_of_float (s *. std_uniform_rvs ()) in
      if l < k - !j
      then (
        y.(!j) <- x.(!i);
        j := !j + 1);
      i := !i + 1
    done;
    y)


(* [sample x k] draws [k] elements of [x] with replacement: each output slot
 * is filled from an independently drawn index.
 * [k = 0] returns an empty array without touching [x].
 * NOTE(review): assumes uniform_int_rvs n yields a valid index of an array
 * of length n -- confirm its range in the uniform module. *)
let sample x k =
  if k = 0
  then [||]
  else (
    let y = Array.make k x.(0) in
    let n = Array.length x in
    for i = 0 to k - 1 do
      let j = uniform_int_rvs n in
      y.(i) <- x.(j)
    done;
    y)


(* Basic statistical functions
*)

(* Sum of the elements, accumulated left to right starting from 0. *)
let sum x = Array.fold_left (fun acc a -> acc +. a) 0. x


(* Arithmetic mean: the sum divided by the number of elements. *)
let mean x = sum x /. float_of_int (Array.length x)


(* Use the caller-supplied mean when present, otherwise compute it. *)
let _get_mean m x =
  match m with
  | None -> mean x
  | Some v -> v


(* Variance around [mean] (computed when omitted). The sum of squared
 * deviations is divided by (length - 1), except that a single-element
 * input divides by 1. *)
let var ?mean x =
  let centre = _get_mean mean x in
  let sum_sq =
    Array.fold_left
      (fun acc a ->
        let dev = a -. centre in
        acc +. (dev *. dev))
      0.
      x
  in
  let len = float_of_int (Array.length x) in
  let denom = if len = 1. then 1. else len -. 1. in
  sum_sq /. denom


(* Standard deviation: square root of [var]. *)
let std ?mean x = sqrt (var ?mean x)


(* Standard error of the mean: [std] divided by sqrt of the length. *)
let sem ?mean x =
  let sd = std ?mean x in
  sd /. sqrt (float_of_int (Array.length x))


(* Mean absolute deviation around [mean] (computed when omitted). *)
let absdev ?mean x =
  let centre = _get_mean mean x in
  let total = Array.fold_left (fun acc a -> acc +. abs_float (a -. centre)) 0. x in
  total /. float_of_int (Array.length x)


(* Mean of the cubed standardised values (a - mean) / sd.
 * [sd] defaults to [std ~mean x]. *)
let skew ?mean ?sd x =
  let centre = _get_mean mean x in
  let scale =
    match sd with
    | None -> std ~mean:centre x
    | Some v -> v
  in
  let total =
    Array.fold_left
      (fun acc a ->
        let z = (a -. centre) /. scale in
        acc +. (z *. z *. z))
      0.
      x
  in
  total /. float_of_int (Array.length x)


(* Mean of the fourth power of the standardised values (a - mean) / sd.
 * Nothing is subtracted from the result. *)
let kurtosis ?mean ?sd x =
  let centre = _get_mean mean x in
  let scale =
    match sd with
    | None -> std ~mean:centre x
    | Some v -> v
  in
  let total =
    Array.fold_left
      (fun acc a ->
        let z = (a -. centre) /. scale in
        let z2 = z *. z in
        acc +. (z2 *. z2))
      0.
      x
  in
  total /. float_of_int (Array.length x)


(* [central_moment n x]: mean of (x_i - mean x) raised to the power n. *)
let central_moment n x =
  let power = float_of_int n in
  let centre = mean x in
  let total = Array.fold_left (fun acc v -> acc +. ((v -. centre) ** power)) 0. x in
  total /. float_of_int (Array.length x)


(* Covariance of two equal-length samples: the sum of products of
 * deviations divided by the length. [m0] and [m1] are the means of [x0]
 * and [x1], computed when omitted. Asserts equal lengths. *)
let cov ?m0 ?m1 x0 x1 =
  let len0 = Array.length x0 in
  let len1 = Array.length x1 in
  assert (len0 = len1);
  let c0 = _get_mean m0 x0 in
  let c1 = _get_mean m1 x1 in
  let acc = ref 0. in
  Array.iter2 (fun a0 a1 -> acc := !acc +. ((a0 -. c0) *. (a1 -. c1))) x0 x1;
  !acc /. float_of_int len0


(* Number of index pairs i < j ordered the same way (strictly) in both
 * arrays. *)
let concordant x0 x1 =
  let len = Array.length x0 in
  let count = ref 0 in
  for i = 0 to len - 2 do
    for j = i + 1 to len - 1 do
      (* j > i always holds here, so no separate i <> j test is needed *)
      if (x0.(i) < x0.(j) && x1.(i) < x1.(j)) || (x0.(i) > x0.(j) && x1.(i) > x1.(j))
      then incr count
    done
  done;
  !count


(* Number of index pairs i < j ordered in opposite ways (strictly) in the
 * two arrays. *)
let discordant x0 x1 =
  let len = Array.length x0 in
  let count = ref 0 in
  for i = 0 to len - 2 do
    for j = i + 1 to len - 1 do
      if (x0.(i) < x0.(j) && x1.(i) > x1.(j)) || (x0.(i) > x0.(j) && x1.(i) < x1.(j))
      then incr count
    done
  done;
  !count


(* 2 * (concordant - discordant) / (n * (n - 1)), n being the length of
 * [x0]. *)
let kendall_tau x0 x1 =
  let conc = float_of_int (concordant x0 x1) in
  let disc = float_of_int (discordant x0 x1) in
  let len = float_of_int (Array.length x0) in
  2. *. (conc -. disc) /. (len *. (len -. 1.))


(* Sorted copy of [x]; increasing by default, decreasing when [inc] is
 * false. The input is left untouched. *)
let sort ?(inc = true) x =
  let sign = if inc then 1 else -1 in
  let cmp a b = if a < b then -sign else if a > b then sign else 0 in
  let sorted = Array.copy x in
  Array.sort cmp sorted;
  sorted


(* Indices of [x] arranged so that the referenced values are sorted
 * (increasing by default, decreasing when [inc] is false). *)
let argsort ?(inc = true) x =
  let dir = if inc then 1 else -1 in
  let order = Array.init (Array.length x) (fun i -> i) in
  Array.sort (fun i j -> dir * compare x.(i) x.(j)) order;
  order


(* Rank given to a group of tied values. [next] is the 1-based position of
 * the last member of the group in sorted order and [d] the number of
 * additional members before it. *)
let _resolve_ties next d strategy =
  match strategy with
  | `Average -> float_of_int next -. (float_of_int d /. 2.)
  | `Min -> float_of_int (next - d)
  | `Max -> float_of_int next


(* 1-based ranks of the elements of [vs]. Equal values share one rank,
 * chosen by [ties_strategy] (default `Average). *)
let rank ?(ties_strategy = `Average) vs =
  let n = Array.length vs in
  let order = argsort vs in
  let ranks = Array.make n 0. in
  (* number of elements seen so far that tie with the current one *)
  let pending = ref 0 in
  for i = 0 to n - 1 do
    if i = n - 1 || compare vs.(order.(i)) vs.(order.(i + 1)) <> 0
    then (
      (* end of a run of equal values: assign the shared rank to the run *)
      let tie_rank = _resolve_ties (i + 1) !pending ties_strategy in
      for j = i - !pending to i do
        ranks.(order.(j)) <- tie_rank
      done;
      pending := 0)
    else incr pending
  done;
  ranks


(* Indices of the smallest and largest element, as (min_index, max_index).
 * On ties the first occurrence is kept. Asserts a non-empty input. *)
let minmax_i x =
  assert (Array.length x > 0);
  let lo = ref x.(0) in
  let hi = ref x.(0) in
  let lo_idx = ref 0 in
  let hi_idx = ref 0 in
  for i = 0 to Array.length x - 1 do
    let a = x.(i) in
    if a < !lo
    then (
      lo := a;
      lo_idx := i)
    else if a > !hi
    then (
      hi := a;
      hi_idx := i)
  done;
  !lo_idx, !hi_idx


(* Index of the smallest element. *)
let min_i x = fst (minmax_i x)


(* Index of the largest element. *)
let max_i x = snd (minmax_i x)


(* Smallest element; [infinity] for an empty array. From here on [min]
 * names this array function rather than the two-argument one. *)
let min x = Array.fold_left min infinity x


(* Largest element; [neg_infinity] for an empty array. From here on [max]
 * names this array function rather than the two-argument one. *)
let max x = Array.fold_left max neg_infinity x


(* Smallest and largest element in one pass, as (min, max);
 * (infinity, neg_infinity) for an empty array. *)
let minmax x =
  let lo = ref infinity in
  let hi = ref neg_infinity in
  for i = 0 to Array.length x - 1 do
    let a = x.(i) in
    if a < !lo then lo := a;
    if a > !hi then hi := a
  done;
  !lo, !hi


(* Histogram record: bin boundaries, per-bin counts, and optional
 * weighted, normalised and density views of the counts. *)
type histogram =
  { bins : float array
  ; counts : int array
  ; weighted_counts : float array option
  ; normalised_counts : float array option
  ; density : float array option
  }

(* fast unsafe array access!
*)

(* Indexing operators without bounds checks. Callers must guarantee that
 * every index is in range. *)
let ( .%() ) = Array.unsafe_get
and ( .%()<- ) = Array.unsafe_set

(*let (.%()) = Array.get and (.%()<-) = Array.set*)

(* [compare] restricted to floats. *)
let fcmp = (compare :> float -> _ -> _)

(* fast direct binning for uniform bins *)

(* [setup_uniform_binning n x] builds [n] equal-width bins spanning the
 * minimum and maximum of [x]. Returns (bmin, bmax, width, boundaries,
 * get_bin), where boundaries has n + 1 entries and [get_bin y] is the bin
 * index of [y].
 * Raises Failure when [n < 1]. *)
let setup_uniform_binning n x =
  if n < 1 then failwith "Need at least one bin!";
  let bmin, bmax = minmax x in
  let db = (bmax -. bmin) /. float_of_int n in
  let bins = Array.init (n + 1) (fun i -> bmin +. (float_of_int i *. db)) in
  let get_bin y =
    let t = (y -. bmin) /. db in
    (* The result indexes arrays through the unchecked operators, so it
     * must always lie in [0, n - 1]. t is NaN when all data are equal
     * (db = 0) or y itself is NaN; int_of_float is unspecified there, so
     * map it to bin 0 explicitly. *)
    if t <> t
    then 0
    else (
      let i = int_of_float t in
      (* y = bmax lands on index n: the last bin is right-inclusive *)
      if i >= n then n - 1 else if i < 0 then 0 else i)
  in
  bmin, bmax, db, bins, get_bin


(* Counts of [x] in [n] uniform bins; returns (boundaries, counts). *)
let hist_uniform n x =
  let _bmin, _bmax, _db, bins, get_bin = setup_uniform_binning n x in
  let c = Array.make n 0 in
  Array.iter
    (fun y ->
      let i = get_bin y in
      c.%(i) <- c.%(i) + 1)
    x;
  bins, c


(* As [hist_uniform], additionally summing the weights [w] per bin.
 * Argument order is bins, weights, data.
 * Returns ((boundaries, counts), Some weighted_counts).
 * Raises Failure when [x] and [w] differ in length. *)
let hist_weighted_uniform n w x =
  if Array.(length x <> length w)
  then failwith "Data and weights must have the same length.";
  let _bmin, _bmax, _db, bins, get_bin = setup_uniform_binning n x in
  let c = Array.make n 0 in
  let wc = Array.make n 0. in
  Array.iteri
    (fun j y ->
      let i = get_bin y in
      c.%(i) <- c.%(i) + 1;
      (* j < length w is guaranteed by the length check above *)
      wc.%(i) <- wc.%(i) +. w.%(j))
    x;
  (bins, c), Some wc


(* nonuniform binning with binary search *)

(* Returns (number of bins, get_bin) for the boundaries [bins].
 * [get_bin y] is the bsearch result, except that a value equal to the
 * last boundary is moved into the last bin.
 * Raises Failure with fewer than two boundaries.
 * NOTE(review): the index returned for out-of-range values depends on
 * Owl_utils_array.bsearch -- callers range-check the result. *)
let setup_nonuniform_binning bins =
  let n = Array.length bins - 1 in
  if n < 1 then failwith "Need at least two bin boundaries!";
  let get_bin y =
    let i = Owl_utils_array.bsearch ~cmp:fcmp y bins in
    if i = n && fcmp y bins.(n) = 0 then i - 1 else i
  in
  n, get_bin


(* Counts of [x] in the bins delimited by [bins]; values whose bin index is
 * out of range are ignored. Returns (bins, counts). *)
let hist_nonuniform bins x =
  let n, get_bin = setup_nonuniform_binning bins in
  let c = Array.make n 0 in
  Array.iter
    (fun y ->
      let i = get_bin y in
      (* drop out-of-bounds without exception.
       * compatible with -unsafe compiler option. *)
      if 0 <= i && i < n then c.(i) <- c.(i) + 1)
    x;
  bins, c


(* As [hist_nonuniform], additionally summing the weights [w] per bin.
 * Argument order is bins, weights, data.
 * Returns ((bins, counts), Some weighted_counts).
 * Raises Failure when [x] and [w] differ in length. *)
let hist_weighted_nonuniform bins w x =
  if Array.(length x <> length w)
  then failwith "Data and weights must have the same length.";
  let n, get_bin = setup_nonuniform_binning bins in
  let c = Array.make n 0 in
  let wc = Array.make n 0. in
  Array.iteri
    (fun j y ->
      let i = get_bin y in
      if 0 <= i && i < n
      then (
        c.(i) <- c.(i) + 1;
        wc.(i) <- wc.(i) +. w.(j)))
    x;
  (bins, c), Some wc


(* binning of sorted arrays by accumulating bin by bin *)

(* n + 1 equally spaced boundaries from the minimum to the maximum of [x]. *)
let make_uniform_bins_from_sorted n x =
  let bmin, bmax = minmax x in
  let db = (bmax -. bmin) /. float_of_int n in
  Array.init (n + 1) (fun i -> bmin +. (float_of_int i *. db))


(* find start of upward iteration *)

(* Returns (number of bins, length of x, ref start bin, ref start datum):
 * the positions from which the sorted sweep begins.
 * Raises Failure with fewer than two boundaries; asserts that both
 * starting positions are in range.
 * NOTE(review): relies on Owl_utils_array.bsearch returning the index of
 * the interval that contains the key -- confirm against its definition. *)
let init_sorted bins x =
  (* Number of bins (not boundaries) *)
  let n = Array.length bins - 1 in
  if n < 1 then failwith "Need at least two bin boundaries!";
  let m = Array.length x in
  let bs = Owl_utils_array.bsearch ~cmp:fcmp in
  let i, j =
    if fcmp bins.(0) x.(0) < 1
    then
      (* first boundary <= first datum: start at the bin holding x.(0) *)
      bs x.(0) bins, 0
    else
      (* data begin below the first boundary: skip the leading data *)
      0, 1 + bs bins.(0) x
  in
  assert (0 <= i && i <= n);
  assert (0 <= j && j < m);
  n, m, ref i, ref j


(* Counts of the sorted data [x] in [bins], found in one joint upward
 * sweep over data and boundaries. Returns (bins, counts). *)
let hist_sorted bins x =
  let n, m, bin_i, x_j = init_sorted bins x in
  let c = Array.make n 0 in
  (* now iterate up the bins.
   * to the second-to-last bin since we look ahead by one. *)
  while !bin_i <= n - 1 do
    while !x_j < m && fcmp x.%(!x_j) bins.%(!bin_i + 1) < 0 do
      c.%(!bin_i) <- c.%(!bin_i) + 1;
      incr x_j
    done;
    incr bin_i
  done;
  (* last bin is right-inclusive... *)
  while !x_j < m && fcmp x.%(!x_j) bins.%(n) = 0 do
    c.%(n - 1) <- c.%(n - 1) + 1;
    incr x_j
  done;
  bins, c


(* As [hist_sorted], additionally summing the weights [w] per bin.
 * Argument order is bins, weights, data.
 * Returns ((bins, counts), Some weighted_counts).
 * Raises Failure when [x] and [w] differ in length. *)
let hist_weighted_sorted bins w x =
  if Array.(length x <> length w)
  then failwith "Data and weights must have the same length.";
  let n, m, bin_i, x_j = init_sorted bins x in
  let c = Array.make n 0 in
  let wc = Array.make n 0. in
  while !bin_i <= n - 1 do
    while !x_j < m && fcmp x.%(!x_j) bins.%(!bin_i + 1) < 0 do
      c.%(!bin_i) <- c.%(!bin_i) + 1;
      wc.%(!bin_i) <- wc.%(!bin_i) +. w.%(!x_j);
      incr x_j
    done;
    incr bin_i
  done;
  (* last bin is right-inclusive... *)
  while !x_j < m && fcmp x.%(!x_j) bins.%(n) = 0 do
    c.%(n - 1) <- c.%(n - 1) + 1;
    wc.%(n - 1) <- wc.%(n - 1) +. w.%(!x_j);
    incr x_j
  done;
  (bins, c), Some wc


(* One-line summary of a histogram: total count, number of bins, total
 * weight when weighted, and markers for the normalised and density
 * fields when they are present. *)
let hist_to_string { bins; counts; weighted_counts; normalised_counts; density } =
  let n_counts = Array.fold_left ( + ) 0 counts in
  let n = Array.length bins - 1 in
  let w =
    match weighted_counts with
    | None -> ""
    | Some wc ->
      let tot = Array.fold_left ( +. ) 0. wc in
      Printf.sprintf "; tot. weight: %g" tot
  in
  let norm =
    match normalised_counts with
    | None -> ""
    | Some _ -> "; normalised"
  in
  let pdf =
    match density with
    | None -> ""
    | Some _ -> "; density"
  in
  Printf.sprintf "[ N: %i; N_bins: %i%s%s%s ]" n_counts n w norm pdf


(* now the external api *)

(* [histogram bins ?weights x] bins the data [x].
 * [bins] is either `N n (n uniform bins over the data range) or `Bins b
 * (explicit boundaries). With [weights], per-bin weight sums are stored in
 * [weighted_counts]. The normalised and density fields start as None. *)
let histogram (bins : [ `N of int | `Bins of float array ]) ?weights x =
  let (bins, counts), weighted_counts =
    match bins, weights with
    | `N n, None -> hist_uniform n x, None
    | `Bins b, None -> hist_nonuniform b x, None
    (* the weighted helpers take the weights BEFORE the data *)
    | `N n, Some w -> hist_weighted_uniform n w x
    | `Bins b, Some w -> hist_weighted_nonuniform b w x
  in
  { bins; counts; weighted_counts; normalised_counts = None; density = None }


(* As [histogram], using the sweep-based routines written for data that is
 * already sorted in increasing order. *)
let histogram_sorted (bins : [ `N of int | `Bins of float array ]) ?weights x =
  let bins =
    match bins with
    | `N n -> make_uniform_bins_from_sorted n x
    | `Bins b -> b
  in
  let (bins, counts), weighted_counts =
    match weights with
    | None -> hist_sorted bins x, None
    | Some w -> hist_weighted_sorted bins w x
  in
  { bins; counts; weighted_counts; normalised_counts = None; density = None }


(* Fills [normalised_counts] with each bin divided by the total, using the
 * weighted counts when present and the plain counts otherwise. *)
let normalise ({ counts; weighted_counts; _ } as h) =
  let nc =
    match weighted_counts with
    | None ->
      let total = Array.fold_left ( + ) 0 counts |> float_of_int in
      Array.map (fun c -> float_of_int c /. total) counts
    | Some wcounts ->
      let total = Array.fold_left ( +. ) 0. wcounts in
      Array.map (fun wc -> wc /. total) wcounts
  in
  { h with normalised_counts = Some nc }


(* Fills [density] with each bin divided by its width and by the total,
 * using the weighted counts when present and the plain counts otherwise. *)
let normalise_density ({ bins; counts; weighted_counts; _ } as h) =
  let ds =
    match weighted_counts with
    | None ->
      let total = Array.fold_left ( + ) 0 counts |> float_of_int in
      Array.mapi (fun i c -> float_of_int c /. (bins.(i + 1) -. bins.(i)) /. total) counts
    | Some wcounts ->
      let total = Array.fold_left ( +. ) 0. wcounts in
      Array.mapi (fun i wc -> wc /. (bins.(i + 1) -. bins.(i)) /. total) wcounts
  in
  { h with density = Some ds }


(* Pretty-printer for histograms. The box is opened and closed on
 * [formatter] itself, not on the standard formatter. *)
let pp_hist formatter hist =
  Format.pp_open_box formatter 0;
  Format.fprintf formatter "%s" (hist_to_string hist);
  Format.pp_close_box formatter ()


(* [quantile x] sorts [x] once and returns a function from p in [0, 1] to
 * the p-quantile, linearly interpolated between neighbouring order
 * statistics. The returned function gives 0. for an empty [x] and raises
 * Invalid_argument when p is below 0 or above 1. *)
let quantile x =
  let y = sort ~inc:true x in
  let n = Array.length y in
  fun p ->
    if p < 0. || p > 1.
    then raise (Invalid_argument "Owl_base_stats.quantile: expected float between 0 and 1")
    else (
      (* fractional position within the sorted sample *)
      let index = p *. float_of_int (n - 1) in
      let lhs = int_of_float index in
      let delta = index -. float_of_int lhs in
      if n = 0
      then 0.
      else if lhs = n - 1
      then y.(lhs)
      else ((1. -. delta) *. y.(lhs)) +. (delta *. y.(lhs + 1)))


(* [percentile x p] is the quantile at p / 100. *)
let percentile x p = quantile x (p /. 100.)

(* 50th percentile. *)
let median x = percentile x 50.

(* 25th percentile. *)
let first_quartile x = percentile x 25.

(* 75th percentile. *)
let third_quartile x = percentile x 75.

(* Third quartile minus first quartile. *)
let interquartile x = third_quartile x -. first_quartile x


(* (Q1 - k * IQR, Q3 + k * IQR), with [k] defaulting to 1.5. *)
let tukey_fences ?(k = 1.5) arr =
  let first_quartile = first_quartile arr in
  let third_quartile = third_quartile arr in
  let offset = k *. (third_quartile -. first_quartile) in
  first_quartile -. offset, third_quartile +. offset


(* Kernel function for the given kernel tag. The Gaussian kernel, for
 * bandwidth [h], evaluation point [p] and sample [v], returns
 * exp (-u^2 / 2) / (sqrt2 * sqrtpi) with u = (v - p) / h. *)
let build_kernel = function
  | `Gaussian ->
    fun h p v ->
      let u = (v -. p) /. h in
      1. /. Owl_const.(sqrt2 *. sqrtpi) *. exp (-.(Owl_base_maths.sqr u) /. 2.)


(* [n_points] evaluation points starting 3 bandwidths below the minimum of
 * [vs], spaced (range + 6h) / n_points apart. *)
let build_points n_points h kernel vs =
  let vmin, vmax = minmax vs in
  let a, b =
    match kernel with
    | `Gaussian -> vmin -. (3. *. h), vmax +. (3. *. h)
  in
  let points = Array.make n_points 0. in
  let step = (b -. a) /. float_of_int n_points in
  for i = 0 to n_points - 1 do
    Array.unsafe_set points i (a +. (float_of_int i *. step))
  done;
  points


(* Gaussian kernel density estimate of the sample [vs].
 * Returns (points, pdf): [n_points] evaluation points (default 512) and
 * the estimated density at each of them.
 * The bandwidth is c * s * n^(-0.2), with s the smaller of the standard
 * deviation and IQR / 1.34, and c = 0.90 for `Silverman or 1.06 for
 * `Scott (the default).
 * Raises Invalid_argument when [vs] has fewer than two elements. *)
let gaussian_kde ?(bandwidth = `Scott) ?(n_points = 512) vs =
  if Array.length vs < 2
  then invalid_arg "estimate_pdf: sample should have multiple elements";
  let n = float_of_int (Array.length vs) in
  (* [min] is this file's array minimum. 1.34 approximates the
   * interquartile range of a standard normal, which puts IQR / 1.34 on
   * the same scale as the standard deviation. *)
  let s = min [| std vs; interquartile vs /. 1.34 |] in
  let h =
    match bandwidth with
    | `Silverman -> 0.90 *. s *. (n ** (-0.2))
    | `Scott -> 1.06 *. s *. (n ** (-0.2))
  in
  let kernel = `Gaussian in
  let points = build_points n_points h kernel vs in
  let k = build_kernel kernel in
  let f = 1. /. (h *. n) in
  let pdf = Array.make n_points 0. in
  for i = 0 to n_points - 1 do
    let p = Array.unsafe_get points i in
    Array.unsafe_set pdf i (f *. Array.fold_left (fun acc v -> acc +. k h p v) 0. vs)
  done;
  points, pdf

(* ends here *)