# 1 "src/owl/ppl/owl_distribution_generic.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

open Bigarray
open Owl_dense_ndarray_generic
open Owl_distribution_common

(* TODO: broadcast_* are most temp solutions atm. *)

(* [_int64_strides s] computes the strides of shape [s] with
   [Owl_utils.calc_stride] and packs them into a one-dimensional int64
   genarray. This is the form in which every stride argument is handed to the
   [op] functions below. Factored out so the conversion is written once
   instead of being repeated in each broadcast helper. *)
let _int64_strides s =
  Owl_utils.calc_stride s
  |> Array.map Int64.of_int
  |> Array1.of_array int64 c_layout
  |> genarray_of_array1

(* TODO: same as that in Owl_dense_ndarray_generic, need to combine soon. *)

(* [broadcast_align_shape x0 x1] brings [x0] and [x1] to a common rank (the
   larger of the two) with [expand]. It then checks that every pair of aligned
   dimensions is either equal or has a 1 on at least one side.

   Returns [(y0, y1, s0, s1, t0, t1)]:
   - [y0], [y1]: the aligned arrays;
   - [s0], [s1]: their shapes;
   - [t0], [t1]: their strides, as produced by [_int64_strides].

   The shape check is an [assert], so incompatible shapes raise
   [Assert_failure] (unless assertions are compiled out). *)
let broadcast_align_shape x0 x1 =
  (* align the rank of inputs *)
  let d = Stdlib.max (num_dims x0) (num_dims x1) in
  let y0 = expand x0 d in
  let y1 = expand x1 d in
  (* check whether the shape is valid *)
  let s0 = shape y0 in
  let s1 = shape y1 in
  Array.iter2 (fun a b -> assert (not (a <> 1 && b <> 1 && a <> b))) s0 s1;
  (* calculate the strides *)
  let t0 = _int64_strides s0 in
  let t1 = _int64_strides s1 in
  (* return aligned arrays, shapes, strides *)
  y0, y1, s0, s1, t0, t1

(* broadcast for [f : x -> y -> z] *)

(* [broadcast_op0 op x0 x1 n] allocates a fresh output of kind [kind x0].
   Its shape is [n] followed by the element-wise maximum of the aligned
   shapes of [x0] and [x1].

   It then calls [op y0 t0 y1 t1 y2 t2]: each aligned input, then the output,
   each followed by its strides. The output is returned. *)
let broadcast_op0 op x0 x1 n =
  (* align the input rank, calculate the output shape and stride *)
  let y0, y1, s0, s1, t0, t1 = broadcast_align_shape x0 x1 in
  (* [Stdlib.max] is spelled out because the opened ndarray module has its
     own [max] *)
  let s2 = Array.append [| n |] (Array.map2 Stdlib.max s0 s1) in
  let t2 = _int64_strides s2 in
  let y2 = empty (kind x0) s2 in
  (* call the specific map function *)
  op y0 t0 y1 t1 y2 t2;
  y2

(* [broadcast_op2 op x0 x1 y2] aligns [x0] with [x1] and copies [y2], so the
   caller's array is never handed to [op]. It then calls
   [op y0 t0 y1 t1 y2' t2]: the two aligned inputs, then the copy [y2'], each
   followed by its strides. The copy is returned. Presumably [op] reads its
   values and overwrites them with the result.

   NOTE(review): [y2] is not aligned with [x0]/[x1], and no broadcast check is
   made against it, so its rank may differ from theirs. Confirm the stubs cope
   with that. *)
let broadcast_op2 op x0 x1 y2 =
  (* align the input rank, calculate the output shape and stride *)
  let y0, y1, _s0, _s1, t0, t1 = broadcast_align_shape x0 x1 in
  let y2 = copy y2 in
  let t2 = _int64_strides (shape y2) in
  (* call the specific map function *)
  op y0 t0 y1 t1 y2 t2;
  y2

(* broadcast for [f : x -> y] *)

(* [broadcast_op1 op x n] allocates an output of kind [kind x] whose shape is
   [n] followed by [shape x], and aligns [x] to it. It then calls
   [op x tx y ty y ty]. The output [y] fills both the second-input slot and
   the output slot, which keeps [op]'s six-argument calling convention. The
   output is returned. *)
let broadcast_op1 op x n =
  let y = empty (kind x) (Array.append [| n |] (shape x)) in
  let x, y, _sx, _sy, tx, ty = broadcast_align_shape x y in
  (* call the specific map function *)
  op x tx y ty y ty;
  y

(* [broadcast_op3 op x y] copies [y] and aligns [x] with that copy. It then
   calls [op x tx y ty y ty], with the copy in both the second-input and the
   output slots. It returns the aligned copy. *)
let broadcast_op3 op x y =
  let y = copy y in
  let x, y, _sx, _sy, tx, ty = broadcast_align_shape x y in
  (* call the specific map function *)
  op x tx y ty y ty;
  y

(* Each family below follows the same pattern:
   - [*_rvs]: delegates to [broadcast_op0] (two parameters) or [broadcast_op1]
     (one parameter), taking the element kind from the first parameter;
   - every other function: delegates to [broadcast_op2], taking the element
     kind from [x].
   Single-parameter families pass their parameter in both parameter slots of
   [broadcast_op2]. Presumably this lets their stubs share the two-parameter
   calling convention. *)

(* uniform: parameters [a], [b] *)

let uniform_rvs ~a ~b ~n =
  broadcast_op0 (_owl_uniform_rvs (kind a)) a b n

let uniform_pdf ~a ~b x =
  broadcast_op2 (_owl_uniform_pdf (kind x)) a b x

let uniform_logpdf ~a ~b x =
  broadcast_op2 (_owl_uniform_logpdf (kind x)) a b x

let uniform_cdf ~a ~b x =
  broadcast_op2 (_owl_uniform_cdf (kind x)) a b x

let uniform_logcdf ~a ~b x =
  broadcast_op2 (_owl_uniform_logcdf (kind x)) a b x

let uniform_ppf ~a ~b x =
  broadcast_op2 (_owl_uniform_ppf (kind x)) a b x

let uniform_sf ~a ~b x =
  broadcast_op2 (_owl_uniform_sf (kind x)) a b x

let uniform_logsf ~a ~b x =
  broadcast_op2 (_owl_uniform_logsf (kind x)) a b x

let uniform_isf ~a ~b x =
  broadcast_op2 (_owl_uniform_isf (kind x)) a b x

(* gaussian: parameters [mu], [sigma] *)

let gaussian_rvs ~mu ~sigma ~n =
  broadcast_op0 (_owl_gaussian_rvs (kind mu)) mu sigma n

let gaussian_pdf ~mu ~sigma x =
  broadcast_op2 (_owl_gaussian_pdf (kind x)) mu sigma x

let gaussian_logpdf ~mu ~sigma x =
  broadcast_op2 (_owl_gaussian_logpdf (kind x)) mu sigma x

let gaussian_cdf ~mu ~sigma x =
  broadcast_op2 (_owl_gaussian_cdf (kind x)) mu sigma x

let gaussian_logcdf ~mu ~sigma x =
  broadcast_op2 (_owl_gaussian_logcdf (kind x)) mu sigma x

let gaussian_ppf ~mu ~sigma x =
  broadcast_op2 (_owl_gaussian_ppf (kind x)) mu sigma x

let gaussian_sf ~mu ~sigma x =
  broadcast_op2 (_owl_gaussian_sf (kind x)) mu sigma x

let gaussian_logsf ~mu ~sigma x =
  broadcast_op2 (_owl_gaussian_logsf (kind x)) mu sigma x

let gaussian_isf ~mu ~sigma x =
  broadcast_op2 (_owl_gaussian_isf (kind x)) mu sigma x

(* exponential: single parameter [lambda] (passed twice to [broadcast_op2]) *)

let exponential_rvs ~lambda ~n =
  broadcast_op1 (_owl_exponential_rvs (kind lambda)) lambda n

let exponential_pdf ~lambda x =
  broadcast_op2 (_owl_exponential_pdf (kind x)) lambda lambda x

let exponential_logpdf ~lambda x =
  broadcast_op2 (_owl_exponential_logpdf (kind x)) lambda lambda x

let exponential_cdf ~lambda x =
  broadcast_op2 (_owl_exponential_cdf (kind x)) lambda lambda x

let exponential_logcdf ~lambda x =
  broadcast_op2 (_owl_exponential_logcdf (kind x)) lambda lambda x

let exponential_ppf ~lambda x =
  broadcast_op2 (_owl_exponential_ppf (kind x)) lambda lambda x

let exponential_sf ~lambda x =
  broadcast_op2 (_owl_exponential_sf (kind x)) lambda lambda x

let exponential_logsf ~lambda x =
  broadcast_op2 (_owl_exponential_logsf (kind x)) lambda lambda x

let exponential_isf ~lambda x =
  broadcast_op2 (_owl_exponential_isf (kind x)) lambda lambda x

(* poisson: single parameter [mu]; only sampling is provided here *)

let poisson_rvs ~mu ~n =
  broadcast_op1 (_owl_poisson_rvs (kind mu)) mu n

(* gamma: parameters [shape], [scale] (the [~shape] label shadows the
   ndarray [shape] function inside these wrappers; only the label is used) *)

let gamma_rvs ~shape ~scale ~n =
  broadcast_op0 (_owl_gamma_rvs (kind shape)) shape scale n

let gamma_pdf ~shape ~scale x =
  broadcast_op2 (_owl_gamma_pdf (kind x)) shape scale x

let gamma_logpdf ~shape ~scale x =
  broadcast_op2 (_owl_gamma_logpdf (kind x)) shape scale x

let gamma_cdf ~shape ~scale x =
  broadcast_op2 (_owl_gamma_cdf (kind x)) shape scale x

let gamma_logcdf ~shape ~scale x =
  broadcast_op2 (_owl_gamma_logcdf (kind x)) shape scale x

let gamma_ppf ~shape ~scale x =
  broadcast_op2 (_owl_gamma_ppf (kind x)) shape scale x

let gamma_sf ~shape ~scale x =
  broadcast_op2 (_owl_gamma_sf (kind x)) shape scale x

let gamma_logsf ~shape ~scale x =
  broadcast_op2 (_owl_gamma_logsf (kind x)) shape scale x

let gamma_isf ~shape ~scale x =
  broadcast_op2 (_owl_gamma_isf (kind x)) shape scale x

(* beta: parameters [a], [b] *)

let beta_rvs ~a ~b ~n =
  broadcast_op0 (_owl_beta_rvs (kind a)) a b n

let beta_pdf ~a ~b x =
  broadcast_op2 (_owl_beta_pdf (kind x)) a b x

let beta_logpdf ~a ~b x =
  broadcast_op2 (_owl_beta_logpdf (kind x)) a b x

let beta_cdf ~a ~b x =
  broadcast_op2 (_owl_beta_cdf (kind x)) a b x

let beta_logcdf ~a ~b x =
  broadcast_op2 (_owl_beta_logcdf (kind x)) a b x

let beta_ppf ~a ~b x =
  broadcast_op2 (_owl_beta_ppf (kind x)) a b x

let beta_sf ~a ~b x =
  broadcast_op2 (_owl_beta_sf (kind x)) a b x

let beta_logsf ~a ~b x =
  broadcast_op2 (_owl_beta_logsf (kind x)) a b x

let beta_isf ~a ~b x =
  broadcast_op2 (_owl_beta_isf (kind x)) a b x

(* chi2: single parameter [df] (passed twice to [broadcast_op2]) *)

let chi2_rvs ~df ~n =
  broadcast_op1 (_owl_chi2_rvs (kind df)) df n

let chi2_pdf ~df x =
  broadcast_op2 (_owl_chi2_pdf (kind x)) df df x

let chi2_logpdf ~df x =
  broadcast_op2 (_owl_chi2_logpdf (kind x)) df df x

let chi2_cdf ~df x =
  broadcast_op2 (_owl_chi2_cdf (kind x)) df df x

let chi2_logcdf ~df x =
  broadcast_op2 (_owl_chi2_logcdf (kind x)) df df x

let chi2_ppf ~df x =
  broadcast_op2 (_owl_chi2_ppf (kind x)) df df x

let chi2_sf ~df x =
  broadcast_op2 (_owl_chi2_sf (kind x)) df df x

let chi2_logsf ~df x =
  broadcast_op2 (_owl_chi2_logsf (kind x)) df df x

let chi2_isf ~df x =
  broadcast_op2 (_owl_chi2_isf (kind x)) df df x

(* f: parameters [dfnum], [dfden] *)

let f_rvs ~dfnum ~dfden ~n =
  broadcast_op0 (_owl_f_rvs (kind dfnum)) dfnum dfden n

let f_pdf ~dfnum ~dfden x =
  broadcast_op2 (_owl_f_pdf (kind x)) dfnum dfden x

let f_logpdf ~dfnum ~dfden x =
  broadcast_op2 (_owl_f_logpdf (kind x)) dfnum dfden x

let f_cdf ~dfnum ~dfden x =
  broadcast_op2 (_owl_f_cdf (kind x)) dfnum dfden x

let f_logcdf ~dfnum ~dfden x =
  broadcast_op2 (_owl_f_logcdf (kind x)) dfnum dfden x

let f_ppf ~dfnum ~dfden x =
  broadcast_op2 (_owl_f_ppf (kind x)) dfnum dfden x

let f_sf ~dfnum ~dfden x =
  broadcast_op2 (_owl_f_sf (kind x)) dfnum dfden x

let f_logsf ~dfnum ~dfden x =
  broadcast_op2 (_owl_f_logsf (kind x)) dfnum dfden x

let f_isf ~dfnum ~dfden x =
  broadcast_op2 (_owl_f_isf (kind x)) dfnum dfden x

(* cauchy: parameters [loc], [scale] *)

let cauchy_rvs ~loc ~scale ~n =
  broadcast_op0 (_owl_cauchy_rvs (kind loc)) loc scale n

let cauchy_pdf ~loc ~scale x =
  broadcast_op2 (_owl_cauchy_pdf (kind x)) loc scale x

let cauchy_logpdf ~loc ~scale x =
  broadcast_op2 (_owl_cauchy_logpdf (kind x)) loc scale x

let cauchy_cdf ~loc ~scale x =
  broadcast_op2 (_owl_cauchy_cdf (kind x)) loc scale x

let cauchy_logcdf ~loc ~scale x =
  broadcast_op2 (_owl_cauchy_logcdf (kind x)) loc scale x

let cauchy_ppf ~loc ~scale x =
  broadcast_op2 (_owl_cauchy_ppf (kind x)) loc scale x

let cauchy_sf ~loc ~scale x =
  broadcast_op2 (_owl_cauchy_sf (kind x)) loc scale x

let cauchy_logsf ~loc ~scale x =
  broadcast_op2 (_owl_cauchy_logsf (kind x)) loc scale x

let cauchy_isf ~loc ~scale x =
  broadcast_op2 (_owl_cauchy_isf (kind x)) loc scale x

(* lomax: parameters [shape], [scale] *)

let lomax_rvs ~shape ~scale ~n =
  broadcast_op0 (_owl_lomax_rvs (kind shape)) shape scale n

let lomax_pdf ~shape ~scale x =
  broadcast_op2 (_owl_lomax_pdf (kind x)) shape scale x

let lomax_logpdf ~shape ~scale x =
  broadcast_op2 (_owl_lomax_logpdf (kind x)) shape scale x

let lomax_cdf ~shape ~scale x =
  broadcast_op2 (_owl_lomax_cdf (kind x)) shape scale x

let lomax_logcdf ~shape ~scale x =
  broadcast_op2 (_owl_lomax_logcdf (kind x)) shape scale x

let lomax_ppf ~shape ~scale x =
  broadcast_op2 (_owl_lomax_ppf (kind x)) shape scale x

let lomax_sf ~shape ~scale x =
  broadcast_op2 (_owl_lomax_sf (kind x)) shape scale x

let lomax_logsf ~shape ~scale x =
  broadcast_op2 (_owl_lomax_logsf (kind x)) shape scale x

let lomax_isf ~shape ~scale x =
  broadcast_op2 (_owl_lomax_isf (kind x)) shape scale x

(* weibull: parameters [shape], [scale] *)

let weibull_rvs ~shape ~scale ~n =
  broadcast_op0 (_owl_weibull_rvs (kind shape)) shape scale n

let weibull_pdf ~shape ~scale x =
  broadcast_op2 (_owl_weibull_pdf (kind x)) shape scale x

let weibull_logpdf ~shape ~scale x =
  broadcast_op2 (_owl_weibull_logpdf (kind x)) shape scale x

let weibull_cdf ~shape ~scale x =
  broadcast_op2 (_owl_weibull_cdf (kind x)) shape scale x

let weibull_logcdf ~shape ~scale x =
  broadcast_op2 (_owl_weibull_logcdf (kind x)) shape scale x

let weibull_ppf ~shape ~scale x =
  broadcast_op2 (_owl_weibull_ppf (kind x)) shape scale x

let weibull_sf ~shape ~scale x =
  broadcast_op2 (_owl_weibull_sf (kind x)) shape scale x

let weibull_logsf ~shape ~scale x =
  broadcast_op2 (_owl_weibull_logsf (kind x)) shape scale x

let weibull_isf ~shape ~scale x =
  broadcast_op2 (_owl_weibull_isf (kind x)) shape scale x

(* laplace: parameters [loc], [scale] *)

let laplace_rvs ~loc ~scale ~n =
  broadcast_op0 (_owl_laplace_rvs (kind loc)) loc scale n

let laplace_pdf ~loc ~scale x =
  broadcast_op2 (_owl_laplace_pdf (kind x)) loc scale x

let laplace_logpdf ~loc ~scale x =
  broadcast_op2 (_owl_laplace_logpdf (kind x)) loc scale x

let laplace_cdf ~loc ~scale x =
  broadcast_op2 (_owl_laplace_cdf (kind x)) loc scale x

let laplace_logcdf ~loc ~scale x =
  broadcast_op2 (_owl_laplace_logcdf (kind x)) loc scale x

let laplace_ppf ~loc ~scale x =
  broadcast_op2 (_owl_laplace_ppf (kind x)) loc scale x

let laplace_sf ~loc ~scale x =
  broadcast_op2 (_owl_laplace_sf (kind x)) loc scale x

let laplace_logsf ~loc ~scale x =
  broadcast_op2 (_owl_laplace_logsf (kind x)) loc scale x

let laplace_isf ~loc ~scale x =
  broadcast_op2 (_owl_laplace_isf (kind x)) loc scale x

(* gumbel1: parameters [a], [b] *)

let gumbel1_rvs ~a ~b ~n =
  broadcast_op0 (_owl_gumbel1_rvs (kind a)) a b n

let gumbel1_pdf ~a ~b x =
  broadcast_op2 (_owl_gumbel1_pdf (kind x)) a b x

let gumbel1_logpdf ~a ~b x =
  broadcast_op2 (_owl_gumbel1_logpdf (kind x)) a b x

let gumbel1_cdf ~a ~b x =
  broadcast_op2 (_owl_gumbel1_cdf (kind x)) a b x

let gumbel1_logcdf ~a ~b x =
  broadcast_op2 (_owl_gumbel1_logcdf (kind x)) a b x

let gumbel1_ppf ~a ~b x =
  broadcast_op2 (_owl_gumbel1_ppf (kind x)) a b x

let gumbel1_sf ~a ~b x =
  broadcast_op2 (_owl_gumbel1_sf (kind x)) a b x

let gumbel1_logsf ~a ~b x =
  broadcast_op2 (_owl_gumbel1_logsf (kind x)) a b x

let gumbel1_isf ~a ~b x =
  broadcast_op2 (_owl_gumbel1_isf (kind x)) a b x

(* gumbel2: parameters [a], [b] *)

let gumbel2_rvs ~a ~b ~n =
  broadcast_op0 (_owl_gumbel2_rvs (kind a)) a b n

let gumbel2_pdf ~a ~b x =
  broadcast_op2 (_owl_gumbel2_pdf (kind x)) a b x

let gumbel2_logpdf ~a ~b x =
  broadcast_op2 (_owl_gumbel2_logpdf (kind x)) a b x

let gumbel2_cdf ~a ~b x =
  broadcast_op2 (_owl_gumbel2_cdf (kind x)) a b x

let gumbel2_logcdf ~a ~b x =
  broadcast_op2 (_owl_gumbel2_logcdf (kind x)) a b x

let gumbel2_ppf ~a ~b x =
  broadcast_op2 (_owl_gumbel2_ppf (kind x)) a b x

let gumbel2_sf ~a ~b x =
  broadcast_op2 (_owl_gumbel2_sf (kind x)) a b x

let gumbel2_logsf ~a ~b x =
  broadcast_op2 (_owl_gumbel2_logsf (kind x)) a b x

let gumbel2_isf ~a ~b x =
  broadcast_op2 (_owl_gumbel2_isf (kind x)) a b x

(* logistic: parameters [loc], [scale] *)

let logistic_rvs ~loc ~scale ~n =
  broadcast_op0 (_owl_logistic_rvs (kind loc)) loc scale n

let logistic_pdf ~loc ~scale x =
  broadcast_op2 (_owl_logistic_pdf (kind x)) loc scale x

let logistic_logpdf ~loc ~scale x =
  broadcast_op2 (_owl_logistic_logpdf (kind x)) loc scale x

let logistic_cdf ~loc ~scale x =
  broadcast_op2 (_owl_logistic_cdf (kind x)) loc scale x

let logistic_logcdf ~loc ~scale x =
  broadcast_op2 (_owl_logistic_logcdf (kind x)) loc scale x

let logistic_ppf ~loc ~scale x =
  broadcast_op2 (_owl_logistic_ppf (kind x)) loc scale x

let logistic_sf ~loc ~scale x =
  broadcast_op2 (_owl_logistic_sf (kind x)) loc scale x

let logistic_logsf ~loc ~scale x =
  broadcast_op2 (_owl_logistic_logsf (kind x)) loc scale x

let logistic_isf ~loc ~scale x =
  broadcast_op2 (_owl_logistic_isf (kind x)) loc scale x

(* lognormal: parameters [mu], [sigma] *)

let lognormal_rvs ~mu ~sigma ~n =
  broadcast_op0 (_owl_lognormal_rvs (kind mu)) mu sigma n

let lognormal_pdf ~mu ~sigma x =
  broadcast_op2 (_owl_lognormal_pdf (kind x)) mu sigma x

let lognormal_logpdf ~mu ~sigma x =
  broadcast_op2 (_owl_lognormal_logpdf (kind x)) mu sigma x

let lognormal_cdf ~mu ~sigma x =
  broadcast_op2 (_owl_lognormal_cdf (kind x)) mu sigma x

let lognormal_logcdf ~mu ~sigma x =
  broadcast_op2 (_owl_lognormal_logcdf (kind x)) mu sigma x

let lognormal_ppf ~mu ~sigma x =
  broadcast_op2 (_owl_lognormal_ppf (kind x)) mu sigma x

let lognormal_sf ~mu ~sigma x =
  broadcast_op2 (_owl_lognormal_sf (kind x)) mu sigma x

let lognormal_logsf ~mu ~sigma x =
  broadcast_op2 (_owl_lognormal_logsf (kind x)) mu sigma x

let lognormal_isf ~mu ~sigma x =
  broadcast_op2 (_owl_lognormal_isf (kind x)) mu sigma x

(* rayleigh: single parameter [sigma] (passed twice to [broadcast_op2]) *)

let rayleigh_rvs ~sigma ~n =
  broadcast_op1 (_owl_rayleigh_rvs (kind sigma)) sigma n

let rayleigh_pdf ~sigma x =
  broadcast_op2 (_owl_rayleigh_pdf (kind x)) sigma sigma x

let rayleigh_logpdf ~sigma x =
  broadcast_op2 (_owl_rayleigh_logpdf (kind x)) sigma sigma x

let rayleigh_cdf ~sigma x =
  broadcast_op2 (_owl_rayleigh_cdf (kind x)) sigma sigma x

let rayleigh_logcdf ~sigma x =
  broadcast_op2 (_owl_rayleigh_logcdf (kind x)) sigma sigma x

let rayleigh_ppf ~sigma x =
  broadcast_op2 (_owl_rayleigh_ppf (kind x)) sigma sigma x

let rayleigh_sf ~sigma x =
  broadcast_op2 (_owl_rayleigh_sf (kind x)) sigma sigma x

let rayleigh_logsf ~sigma x =
  broadcast_op2 (_owl_rayleigh_logsf (kind x)) sigma sigma x

let rayleigh_isf ~sigma x =
  broadcast_op2 (_owl_rayleigh_isf (kind x)) sigma sigma x

(* ends here *)