# 1 "src/base/maths/owl_base_maths.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

(* Scalar maths on floats, plus a few integer number-theory helpers, written
   in plain OCaml on top of [Stdlib].

   NOTE: several definitions below shadow [Stdlib] names ([abs], [sqrt],
   [exp], [log], [floor], ...). Later definitions in this file therefore see
   the local versions, and integer code must write [Stdlib.abs] explicitly.
   Keep the definition order in mind when editing. *)

(* ---- Binary arithmetic ---------------------------------------------------- *)

(* Named wrappers for the float infix operators: [add x y = x +. y], etc. *)
let add x y = x +. y

let sub x y = x -. y

let mul x y = x *. y

let div x y = x /. y

(* [pow x y] is [x] raised to the power [y]. *)
let pow x y = x ** y

(* Float remainder, exactly [Stdlib.mod_float]. *)
let fmod x y = Stdlib.mod_float x y

(* Same argument order as [Stdlib.atan2]: the FIRST argument plays the role
   of the ordinate. *)
let atan2 x y = Stdlib.atan2 x y

(* Euclidean length [sqrt (x*x + y*y)], delegated to [Stdlib.hypot]. *)
let hypot x y = Stdlib.hypot x y

(* ---- Unary functions ------------------------------------------------------ *)

(* Float absolute value (shadows the integer [Stdlib.abs] from here on). *)
let abs x = Stdlib.abs_float x

(* [neg x] is [-. x]. *)
let neg x = ~-.x

(* Reciprocal, [1 / x]. *)
let reci x = 1. /. x

(* Softsign activation, [x / (1 + |x|)]. *)
let softsign x = x /. (1. +. abs x)

(* Softplus, [log (1 + exp x)].
   For positive [x] it is evaluated as [x + log1p (exp (-x))], so [exp] is
   only ever applied to non-positive arguments and cannot overflow; the
   direct formula returns [infinity] once [exp x] overflows (x > ~709). *)
let softplus x =
  if x > 0. then x +. Stdlib.log1p (Stdlib.exp (-.x))
  else Stdlib.log1p (Stdlib.exp x)

(* [sqr x] is [x *. x]. *)
let sqr x = x *. x

let sqrt x = Stdlib.sqrt x

(* Real cube root. [x ** (1/3)] is NaN for negative [x], so the sign is
   factored out first (cube root is odd): [cbrt (-8.)] is about [-2.]. *)
let cbrt x =
  if x < 0. then -.((-.x) ** 0.33333333333333333333)
  else x ** 0.33333333333333333333

let exp x = Stdlib.exp x

(* [exp2 x] is [2 ** x]; [exp10 x] is [10 ** x]. *)
let exp2 x = 2. ** x

let exp10 x = 10. ** x

(* [expm1 x] is [exp x -. 1.], accurate for [x] near zero. *)
let expm1 x = Stdlib.expm1 x

(* Natural logarithm. *)
let log x = Stdlib.log x

(* Base-2 logarithm by change of base.
   NOTE(review): assumes [Owl_const.loge2] is ln 2 — confirm in Owl_const. *)
let log2 x = log x /. Owl_const.loge2

let log10 x = Stdlib.log10 x

(* [log1p x] is [log (1. +. x)], accurate for [x] near zero. *)
let log1p x = Stdlib.log1p x

(* Sign of [x] as a float: [1.], [-1.], or [0.] for both zeros; NaN maps to
   NaN. *)
let signum x =
  if FP_nan = classify_float x then nan
  else if x > 0. then 1.
  else if x < 0. then ~-.1.
  else 0.

(* ---- Rounding ------------------------------------------------------------- *)

let floor x = Stdlib.floor x

let ceil x = Stdlib.ceil x

(* Round to nearest, with halves going towards +infinity:
   [round 2.5 = 3.] but [round (-2.5) = -2.].
   NOTE(review): C's [round] rounds halves away from zero instead; confirm
   which convention callers expect before relying on negative halves. *)
let round x = floor (x +. 0.5)

(* Integral part of [x] (the second component of [modf]), i.e. rounding
   towards zero. *)
let trunc x = modf x |> snd

(* Round towards zero, written with [ceil] / [floor]. *)
let fix x = if x < 0. then ceil x else floor x

(* ---- Trigonometric and hyperbolic functions ------------------------------- *)

let sin x = Stdlib.sin x

let cos x = Stdlib.cos x

let tan x = Stdlib.tan x

(* Reciprocal functions: cotangent, secant, cosecant. *)
let cot x = 1. /. tan x

let sec x = 1. /. cos x

let csc x = 1. /. sin x

let sinh x = Stdlib.sinh x

let cosh x = Stdlib.cosh x

let tanh x = Stdlib.tanh x

let asin x = Stdlib.asin x

let acos x = Stdlib.acos x

let atan x = Stdlib.atan x

(* Inverse cotangent as [pi/2 - atan x], so results range over [0, pi]
   rather than atan's (-pi/2, pi/2). *)
let acot x = (Owl_const.pi /. 2.) -. atan x

(* Inverse secant and cosecant via [acos (1/x)] and [asin (1/x)]. *)
let asec x = Stdlib.acos (1. /. x)

let acsc x = Stdlib.asin (1. /. x)

(* Inverse hyperbolic sine, [log (x + sqrt (x^2 + 1))].
   Evaluated on [|x|] and then re-signed (asinh is odd). Applying the formula
   directly to a large negative [x] cancels [x] against [sqrt (x^2 + 1)]:
   at [x = -1e10] that gives [log 0. = neg_infinity] instead of about -23.7. *)
let asinh x =
  let ax = abs x in
  let r = log (ax +. sqrt ((ax *. ax) +. 1.)) in
  if x < 0. then ~-.r else r

(* Inverse hyperbolic cosine, [log (x + sqrt (x^2 - 1))]; NaN for [x < 1]. *)
let acosh x = log (x +. sqrt ((x *. x) -. 1.))

(* Inverse hyperbolic tangent, [0.5 * log ((1 + x) / (1 - x))]. *)
let atanh x = 0.5 *. log ((1. +. x) /. (1. -. x))

(* Inverse hyperbolic cotangent, secant and cosecant via the reciprocal
   argument. *)
let acoth x = atanh (1. /. x)

let asech x = acosh (1. /. x)

let acsch x = asinh (1. /. x)

(* ---- Activations and special functions ------------------------------------ *)

(* Rectified linear unit, [max 0. x]. *)
let relu x = Stdlib.max 0. x

(* Dawson's integral: not implemented here; always raises
   [Owl_exception.NOT_IMPLEMENTED]. *)
let dawsn _x = raise (Owl_exception.NOT_IMPLEMENTED "Owl_base_maths.dawsn")

(* Logistic sigmoid, [1 / (1 + exp (-x))] (same formula as [expit]). *)
let sigmoid x = 1. /. (1. +. exp (-.x))

(* [xlogy x y] is [x *. log y], except that it is [0.] whenever [x = 0.] and
   [y] is not NaN — so [xlogy 0. 0.] is [0.] rather than NaN. *)
let xlogy x y = if x = 0. && classify_float y <> FP_nan then 0. else x *. log y

(* As [xlogy], but with [log1p y]. *)
let xlog1py x y =
  if x = 0. && classify_float y <> FP_nan then 0. else x *. log1p y

(* Log-odds, [log (x / (1 - x))]; the inverse of [expit]. *)
let logit x = log (x /. (1. -. x))

(* Logistic function, [1 / (1 + exp (-x))]. *)
let expit x = 1. /. (1. +. exp (-.x))

(* [log1mexp x] is [log (1 - exp x)] (real-valued for [x < 0]). The branch on
   [-x > log 2] selects between [log1p (-exp x)] and [log (-expm1 x)], the
   form that avoids cancellation in each range. *)
let log1mexp x =
  if -.x > log 2. then log1p (-.(exp x)) else log (-.(expm1 x))

(* [log1pexp x] is [log (1 + exp x)], evaluated piecewise: [exp x] when the
   1 is negligible below it, [log1p (exp x)] in the middle, [x + exp (-x)]
   for large [x], and plain [x] once the correction underflows. *)
let log1pexp x =
  if x <= -37. then exp x
  else if x <= 18. then log1p (exp x)
  else if x <= 33.3 then x +. exp (-.x)
  else x

(* Error function via a polynomial-times-Gaussian approximation, evaluated on
   [|x|] and re-signed (erf is odd). The coefficients appear to be
   Abramowitz & Stegun 7.1.26 (absolute error about 1.5e-7). *)
let erf x =
  let a = 0.254829592 in
  let b = -0.284496736 in
  let c = 1.421413741 in
  let d = -1.453152027 in
  let e = 1.061405429 in
  let p = 0.327591100 in
  let sign = signum x in
  let x = abs x in
  let t = 1. /. (1. +. (p *. x)) in
  (* Horner evaluation of t * (a + t * (b + t * (c + t * (d + t * e)))) *)
  let poly = ((((((((e *. t) +. d) *. t) +. c) *. t) +. b) *. t) +. a) *. t in
  let y = 1. -. (poly *. exp (-.x *. x)) in
  sign *. y

(* Complementary error function, [1 - erf x]. Since it inherits erf's
   absolute error, relative accuracy is poor for large positive [x]. *)
let erfc x = 1. -. erf x

(* Scaled complementary error function, [exp (x^2) * erfc x]. *)
let erfcx x = exp (x *. x) *. erfc x

(* ---- Helper functions ----------------------------------------------------- *)

(* Float classification predicates. *)
let is_nan x = FP_nan = classify_float x

let is_inf x = FP_infinite = classify_float x

let is_normal x = FP_normal = classify_float x

let is_subnormal x = FP_subnormal = classify_float x

(* Parity of an [int]. [Stdlib.abs] is required because [abs] is the float
   version in this module, and [mod] keeps the sign of its left operand. *)
let is_odd x = Stdlib.abs x mod 2 = 1

let is_even x = x mod 2 = 0

(* [true] iff [x] is a positive power of two (1, 2, 4, ...). The [x > 0]
   guard matters for [min_int]: [min_int land (min_int - 1)] is also 0. *)
let is_pow2 x = x > 0 && x land (x - 1) = 0

(* [true] when [x] and [y] are both non-negative or both non-positive, so a
   zero is considered to share the sign of any number. NaN never matches. *)
let same_sign x y = (x >= 0. && y >= 0.) || (x <= 0. && y <= 0.)

(* [true] iff every element of the float array [x] is non-negative and the
   elements sum to 1 within [Owl_const.eps]. NaN elements fail [a >= 0.] and
   are rejected; an empty array sums to 0 and is rejected. *)
let is_simplex x =
  Array.for_all (fun a -> a >= 0.) x
  && abs_float (1. -. Array.fold_left ( +. ) 0. x) <= Owl_const.eps

(* [true] iff the float [x] has a zero fractional part according to [modf].
   NaN is rejected.
   NOTE(review): infinities likely pass, since C [modf] reports a zero
   fractional part for them — confirm this is intended. *)
let is_int x = modf x |> fst |> ( = ) 0.

(* [true] iff the integer [x] is a perfect square, tested through a float
   square root; exactness is limited by double precision for very large [x]. *)
let is_sqr x = float_of_int x |> sqrt |> is_int

(* Generic right-to-left binary exponentiation ("square and multiply").
   For each set bit of the non-negative exponent [a], the current power of
   the base [b] is folded into the accumulator (initially [id]) with
   [f _ _ m]. With modular addition and [id = 0] this computes
   [a * b mod m]; with [mulmod] and [id = 1] it computes [b ^ a mod m]. *)
let _binary_exp a b m f id =
  let r = ref id in
  let a = ref a in
  let b = ref b in
  while !a > 0 do
    if !a land 1 = 1 then r := f !r !b m;
    a := !a lsr 1;
    (* skip the last squaring, whose result would never be used *)
    if !a > 0 then b := f !b !b m
  done;
  !r

(* [mulmod a b m] is [(a * b) mod m] without intermediate overflow.
   Requires [a >= 0], [b >= 0] and [m >= 1]; otherwise the [INVALID_ARGUMENT]
   exception built below is handed to [Owl_exception.verify]. When the direct
   product could exceed [max_int], the product is formed by repeated modular
   doubling via [_binary_exp], keeping every intermediate below [m]. *)
let mulmod a b m =
  let error () =
    let s =
      Printf.sprintf
        "mulmod requires a >= 0, b >= 0, m >= 1; however the inputs are a = %i, b = %i, \
         m = %i"
        a b m
    in
    Owl_exception.INVALID_ARGUMENT s
  in
  let predicate = a >= 0 && b >= 0 && m >= 1 in
  Owl_exception.verify predicate error;
  let a = a mod m in
  let b = b mod m in
  if a = 0 || b = 0 then 0
  else if b < max_int / a then a * b mod m
  else (
    (* (a + b) mod m for a, b in [0, m), without forming a + b, which could
       overflow *)
    let _muladd a b m =
      let c = m - b in
      if a >= c then a - c else m - c + a
    in
    _binary_exp a b m _muladd 0)

(* [powmod a b m] is [a ^ b mod m], computed by binary exponentiation on top
   of [mulmod]. Requires [a >= 0], [b >= 0] and [m >= 1]; otherwise the
   [INVALID_ARGUMENT] exception built below is handed to
   [Owl_exception.verify]. *)
let powmod a b m =
  let error () =
    let s =
      Printf.sprintf
        "powmod requires a >= 0, b >= 0, m >= 1; however the inputs are a = %i, b = %i, \
         m = %i"
        a b m
    in
    Owl_exception.INVALID_ARGUMENT s
  in
  let predicate = a >= 0 && b >= 0 && m >= 1 in
  Owl_exception.verify predicate error;
  (* everything is 0 modulo 1, including the empty product a^0 *)
  if m = 1 && b = 0 then 0 else _binary_exp b (a mod m) m mulmod 1

(* Fermat factorisation of a positive odd integer [x]. Searches upward from
   [ceil (sqrt x)] for a [y] such that [y^2 - x] is a perfect square [z^2],
   and returns [(y + z, y - z)], whose product is [x]. A prime input yields
   [(x, 1)].

   Raises [Owl_exception.INVALID_ARGUMENT] (through [Owl_exception.check])
   unless [x] is positive and odd. Positivity matters: for a negative [x],
   [sqrt] yields NaN, which never tests as an integer, so the search loop
   would never terminate. The search runs in floats, so results are only
   exact while the values involved fit in double precision. *)
let fermat_fact x =
  let s = "the input of fermat_fact should be a positive odd number" in
  let valid = x > 0 && is_odd x in
  Owl_exception.(check valid (INVALID_ARGUMENT s));
  let x = float_of_int x in
  let y = ref (ceil (sqrt x)) in
  let z = ref ((!y *. !y) -. x) in
  while not (is_int (sqrt !z)) do
    y := !y +. 1.;
    z := (!y *. !y) -. x
  done;
  let fac0 = int_of_float (!y +. sqrt !z) in
  let fac1 = int_of_float (!y -. sqrt !z) in
  (fac0, fac1)

(* Deterministic primality test for OCaml [int]s.
   Inputs below 30 are looked up in a table (so negatives, 0 and 1 give
   [false]); multiples of 2 and 3 are rejected; everything else runs a
   Miller-Rabin test with a fixed base set from
   https://miller-rabin.appspot.com, chosen by the size of [x]. *)
let is_prime x =
  (* Miller-Rabin strong-probable-prime test of odd [n], where
     [n - 1 = d * 2^s] with [d] odd, to base [a] (0 < a < n): true iff
     [a^d = 1] or [a^(d * 2^r) = n - 1 (mod n)] for some [r] in [0, s]. *)
  let _possible_prime s d n a =
    let p = ref (powmod a d n) in
    if !p = 1 || !p = n - 1 then true
    else (
      let found = ref false in
      let i = ref 0 in
      while (not !found) && !i < s do
        p := mulmod !p !p n;
        if !p = n - 1 then found := true;
        incr i
      done;
      !found)
  in
  (* Splits a positive even [n] into [(s, d)] with [n = d * 2^s], [d] odd. *)
  let _factor_powtwo n =
    let i, r = (ref 0, ref n) in
    while !r land 1 = 0 do
      r := !r lsr 1;
      i := !i + 1
    done;
    (!i, !r)
  in
  if x < 30 then Array.mem x [| 2; 3; 5; 7; 11; 13; 17; 19; 23; 29 |]
  else if x mod 2 = 0 || x mod 3 = 0 then false
  else (
    let bases =
      (* bases from https://miller-rabin.appspot.com *)
      if x <= 1073741823 (* 2^30 - 1 *) then [| 2; 7; 61 |]
      else
        (* 6 * 299210837 = 1795265022.
           NOTE(review): presumably written as a product so the source still
           compiles where that literal exceeds a 31-bit native int. *)
        [| 2; 325; 9375; 28178; 450775; 9780504; 6 * 299210837 |]
    in
    let s, d = _factor_powtwo (x - 1) in
    (* a base that is a multiple of x carries no information; skip it *)
    Array.for_all (fun b -> b mod x = 0 || _possible_prime s d x (b mod x)) bases)