# 1 "src/owl/linalg/owl_linalg_generic.ml"
(*
* OWL - an OCaml numerical library for scientific computing
* Copyright (c) 2016-2017 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

open Bigarray

open Owl_types


(* Matrix type used throughout this module: an alias of the generic dense
   matrix type. *)
type ('a, 'b) t = ('a, 'b) Owl_dense_matrix_generic.t


(*
  We create a local generic matrix module with basic operators. This is the only
  way to let us use operators to write concise math while avoiding a circular
  dependency at the same time.
*)

module M = struct
  include Owl_dense_matrix_generic
  include Owl_operator.Make_Basic (Owl_dense_matrix_generic)
  include Owl_operator.Make_Extend (Owl_dense_matrix_generic)
  include Owl_operator.Make_Matrix (Owl_dense_matrix_generic)
end


(* Helper functions *)

(* [is_square x] is [true] iff [x] has the same number of rows and columns. *)
let is_square x =
  let rows, cols = M.shape x in
  rows = cols


(* [select_ev keyword ev] returns an int32 matrix of the same shape as [ev]
   holding [1l] where the element of [ev] satisfies the predicate chosen by
   [keyword] and [0l] elsewhere:
     `LHP : real part <  0.
     `RHP : real part >= 0.
     `UDI : real part of the absolute value <  1.
     `UDO : real part of the absolute value >= 1. *)
let select_ev keyword ev =
  let k = M.kind ev in
  let m, n = M.shape ev in
  let s = M.zeros int32 m n in
  (* set the flag of every element accepted by [pred] *)
  let flag pred =
    M.iteri_2d (fun i j a -> if pred a then M.set s i j 1l) ev
  in
  let modulus a = Owl_ndarray.(_abs_elt k a |> _re_elt k) in
  begin match keyword with
    | `LHP -> let re = Owl_ndarray._re_elt k in flag (fun a -> re a < 0.)
    | `RHP -> let re = Owl_ndarray._re_elt k in flag (fun a -> re a >= 0.)
    | `UDI -> flag (fun a -> modulus a < 1.)
    | `UDO -> flag (fun a -> modulus a >= 1.)
  end;
  s


(* LU decomposition *)

(* [lu x] factorises a copy of [x] with [Owl_lapacke.getrf] and returns
   [(l, u, ipiv)]: [l] is the lower triangle of the factorised matrix with its
   leading diagonal entries set to one, [u] is the upper triangle resized to
   [n] x [n] where [n] is the column count of [x], and [ipiv] is the pivot
   vector returned by [getrf]. *)
let lu x =
  let x = M.copy x in
  let m, n = M.shape x in
  let diag_len = Pervasives.min m n in
  let fact, ipiv = Owl_lapacke.getrf x in
  let l = M.tril fact in
  let u = M.resize (M.triu fact) [|n; n|] in
  let one = Owl_const.one (M.kind x) in
  for i = 0 to diag_len - 1 do
    M.set l i i one
  done;
  l, u, ipiv


(* [lufact x] passes [x] itself (no copy is made here) to [Owl_lapacke.getrf]
   and returns the factorised matrix together with the pivot vector. *)
let lufact x =
  let fact, ipiv = Owl_lapacke.getrf x in
  fact, ipiv


(* basic functions *)

(* [inv x] inverts a copy of [x]: [getrf] followed by [getri]. *)
let inv x =
  let x = M.copy x in
  let fact, ipiv = Owl_lapacke.getrf x in
  Owl_lapacke.getri fact ipiv


(* [det x] is the determinant of the square matrix [x]: the product of the
   diagonal of its LU factorisation, negated when the number of row
   interchanges recorded in the pivot vector is odd.
   Fails an assertion if [x] is not square. *)
let det x =
  let x = M.copy x in
  let m, n = M.shape x in
  assert (m = n);
  let fact, ipiv = Owl_lapacke.getrf x in
  let prod = ref (Owl_const.one (M.kind x)) in
  let swaps = ref 0 in
  let mul = Owl_ndarray._mul_elt (M.kind x) in
  for i = 0 to m - 1 do
    prod := mul !prod (M.get fact i i);
    (* NOTE: +1 to adjust to Fortran index *)
    if M.get ipiv 0 i <> Int32.of_int (i + 1) then incr swaps
  done;
  if Owl_maths.is_odd !swaps
  then Owl_ndarray._neg_elt (M.kind x) !prod
  else !prod


(* FIXME: need to check ...
*)

(* [logdet x] sums the logarithms of the diagonal of the LU factorisation of
   a copy of the square matrix [x], and negates that sum when the pivot vector
   records an odd number of row interchanges.
   NOTE(review): negating the log-sum is not the logarithm of a negated
   determinant; the sign handling looks suspect and should be confirmed.
   Fails an assertion if [x] is not square. *)
let logdet x =
  let x = M.copy x in
  let m, n = M.shape x in
  assert (m = n);
  let _kind = M.kind x in
  let fact, ipiv = Owl_lapacke.getrf x in
  let acc = ref (Owl_const.zero _kind) in
  let swaps = ref 0 in
  let add = Owl_ndarray._add_elt _kind in
  let log = Owl_ndarray._log_elt _kind in
  let neg = Owl_ndarray._neg_elt _kind in
  for i = 0 to m - 1 do
    acc := add !acc (log (M.get fact i i));
    (* NOTE: +1 to adjust to Fortran index *)
    if M.get ipiv 0 i <> Int32.of_int (i + 1) then incr swaps
  done;
  if Owl_maths.is_odd !swaps then neg !acc else !acc


(* QR decomposition *)

(* Builds the Q factor from the reflectors in [a] and [tau], dispatching on
   the element kind: [orgqr] for real kinds, [ungqr] for complex kinds. *)
let _get_qr_q
  : type a b. (a, b) kind -> (a, b) t -> (a, b) t -> (a, b) t
  = fun kind refl tau ->
  match kind with
  | Float32   -> Owl_lapacke.orgqr refl tau
  | Float64   -> Owl_lapacke.orgqr refl tau
  | Complex32 -> Owl_lapacke.ungqr refl tau
  | Complex64 -> Owl_lapacke.ungqr refl tau
  | _         -> failwith "owl_linalg:_get_qr_q"


(* [qr ~thin ~pivot x] returns [(q, r, jpvt)] for a copy of [x].
   [pivot] selects [geqp3] instead of [geqrf]; without pivoting [jpvt] is an
   empty 0 x 0 int32 matrix.
   [thin] (default [true]) keeps [r] at [min m n] rows; otherwise [r] has [m]
   rows and, when [m > n], the reflector matrix is padded with [m - n] zero
   columns before Q is generated. *)
let qr ?(thin=true) ?(pivot=false) x =
  let x = M.copy x in
  let m, n = M.shape x in
  let minmn = Pervasives.min m n in
  let a, tau, jpvt =
    if pivot then Owl_lapacke.geqp3 x
    else (
      let jpvt = M.empty int32 0 0 in
      let a, tau = Owl_lapacke.geqrf x in
      a, tau, jpvt
    )
  in
  let r_rows = if thin then minmn else m in
  let r = M.resize ~head:true (M.triu a) [|r_rows; n|] in
  let a =
    if thin || m <= n then a
    else (
      let pad = M.zeros (M.kind x) m (m - n) in
      M.concat_horizontal a pad
    )
  in
  let q = _get_qr_q (M.kind x) a tau in
  q, r, jpvt


(* [qrfact ~pivot x] returns the raw factorisation [(a, tau, jpvt)] of [x]
   without copying it first; [jpvt] is an empty 0 x 0 int32 matrix when
   [pivot] is [false]. *)
let qrfact ?(pivot=false) x =
  if pivot then Owl_lapacke.geqp3 x
  else (
    let jpvt = M.empty int32 0 0 in
    let a, tau = Owl_lapacke.geqrf x in
    a, tau, jpvt
  )


(* Builds the Q factor of an LQ factorisation, dispatching on the element
   kind: [orglq] for real kinds, [unglq] for complex kinds. *)
let _get_lq_q
  : type a b. (a, b) kind -> (a, b) t -> (a, b) t -> (a, b) t
  = fun kind refl tau ->
  match kind with
  | Float32   -> Owl_lapacke.orglq refl tau
  | Float64   -> Owl_lapacke.orglq refl tau
  | Complex32 -> Owl_lapacke.unglq refl tau
  | Complex64 -> Owl_lapacke.unglq refl tau
  | _         -> failwith "owl_linalg:_get_lq_q"


(* [lq ~thin x] returns [(l, q)] for a copy of [x] using [gelqf].
   With [thin] (default) and [m < n], [l] keeps only its first [min m n]
   columns. Without [thin] and [m < n], the reflector matrix is resized to
   [n] x [n] before Q is generated. *)
let lq ?(thin=true) x =
  let x = M.copy x in
  let m, n = M.shape x in
  let minmn = Pervasives.min m n in
  let a, tau = Owl_lapacke.gelqf x in
  let l =
    if thin && m < n
    then M.get_slice [[]; [0; minmn - 1]] (M.tril a)
    else M.tril a
  in
  let a =
    if thin || m >= n then a
    else M.resize ~head:true a [|n; n|]
  in
  let q = _get_lq_q (M.kind x) a tau in
  l, q


(* Singular Value decomposition
*)

(* [svd ~thin x] returns [(u, s, vt)] computed by [gesdd] on a copy of [x];
   the job code is 'S' when [thin] (default) and 'A' otherwise. *)
let svd ?(thin=true) x =
  let x = M.copy x in
  let jobz = if thin then 'S' else 'A' in
  let u, s, vt = Owl_lapacke.gesdd ~jobz ~a:x in
  u, s, vt


(* [svdvals x] returns only the singular values of [x] ([gesdd], job 'N'). *)
let svdvals x =
  let x = M.copy x in
  let _, s, _ = Owl_lapacke.gesdd ~jobz:'N' ~a:x in
  s


(* [gsvd x y] runs [ggsvd3] on copies of [x] and [y] and returns
   [(u, v, q, d1, d2, r)]. [d1] and [d2] are assembled from the [alpha] and
   [beta] vectors and the integers [k], [l] returned by the routine. *)
let gsvd x y =
  let x = M.copy x in
  let y = M.copy y in
  let m, n = M.shape x in
  let p, _ = M.shape y in
  let u, v, q, alpha, beta, k, l, r =
    Owl_lapacke.ggsvd3 ~jobu:'U' ~jobv:'V' ~jobq:'Q' ~a:x ~b:y
  in
  let kl = k + l in
  let alpha = M.resize ~head:true alpha [|1; kl|] in
  let d1 = M.resize ~head:true (M.diagm alpha) [|m; kl|] in
  (* keep the first k+l entries of beta, then the last l of those *)
  let beta = M.resize ~head:true beta [|1; kl|] in
  let beta = M.resize ~head:false beta [|1; l|] in
  let d2 = M.resize (M.diagm ~k beta) [|p; kl|] in
  u, v, q, d1, d2, r


(* [gsvdvals x y] returns the element-wise quotient of the first [k + l]
   entries of [alpha] and [beta] as returned by [ggsvd3]. *)
let gsvdvals x y =
  let x = M.copy x in
  let y = M.copy y in
  let _, _, _, alpha, beta, k, l, _ =
    Owl_lapacke.ggsvd3 ~jobu:'N' ~jobv:'N' ~jobq:'N' ~a:x ~b:y
  in
  let kl = k + l in
  let alpha = M.resize ~head:true alpha [|1; kl|] in
  let beta = M.resize ~head:true beta [|1; kl|] in
  M.(div alpha beta)


(* [rank ~tol x] counts the singular values of [x] that exceed [tol].
   When [tol] is omitted it defaults to [max m n] times the float32 epsilon,
   whatever the element kind of [x]. *)
let rank ?tol x =
  let sv = svdvals x in
  let m, n = M.shape x in
  let maxmn = Pervasives.max m n in
  (* by default using float32 eps *)
  let eps = Owl_utils.eps Float32 in
  let tol = match tol with
    | Some given -> given
    | None       -> (float_of_int maxmn) *. eps
  in
  let real_tol = tol in
  let cplx_tol = Complex.({re = tol; im = neg_infinity}) in
  let count_above : type a b. (a, b) kind -> (a, b) t -> int =
    fun kind vals ->
    match kind with
    | Float32   -> M.elt_greater_scalar vals real_tol |> M.sum' |> int_of_float
    | Float64   -> M.elt_greater_scalar vals real_tol |> M.sum' |> int_of_float
    | Complex32 ->
      let total = M.elt_greater_scalar vals cplx_tol |> M.sum' in
      int_of_float Complex.(total.re)
    | Complex64 ->
      let total = M.elt_greater_scalar vals cplx_tol |> M.sum' in
      int_of_float Complex.(total.re)
    | _         -> failwith "owl_linalg:rank"
  in
  count_above (M.kind sv) sv


(* Cholesky Decomposition *)

(* [chol ~upper x] factorises a copy of [x] with [potrf] and returns the
   upper triangle (default) or the lower triangle of the result. *)
let chol ?(upper=true) x =
  let x = M.copy x in
  if upper
  then Owl_lapacke.potrf 'U' x |> M.triu
  else Owl_lapacke.potrf 'L' x |> M.tril


(* Schur Decomposition
*)

(* Converts a pair of matrices into a single matrix of kind [otyp]:
   real inputs are combined with [M.complex] ([re] and [im] become the real
   and imaginary parts); complex inputs return [re] unchanged and ignore
   [im]. Any other combination of input and output kinds fails. *)
let _magic_complex
  : type a b c d. (c, d) kind -> (a, b) t -> (a, b) t -> (c, d) t
  = fun otyp re im ->
  let ityp = M.kind re in
  match ityp, otyp with
  | Float32,   Complex32 -> M.complex float32 complex32 re im
  | Float64,   Complex64 -> M.complex float64 complex64 re im
  | Complex32, Complex32 -> re
  | Complex64, Complex64 -> re
  | _                    -> failwith "owl_linalg_generic:_magic_complex"


(* [schur ~otyp x] returns [(t, z, w)] from [gees] (job 'V') on a copy of the
   square matrix [x]; [w] is built from the two value vectors returned by
   [gees], converted to kind [otyp]. Fails an assertion if [x] is not
   square. *)
let schur
  : type a b c d. otyp:(c, d) kind -> (a, b) t -> (a, b) t * (a, b) t * (c, d) t
  = fun ~otyp x ->
  assert (is_square x);
  let a = M.copy x in
  let t, z, wr, wi = Owl_lapacke.gees ~jobvs:'V' ~a in
  let w = _magic_complex otyp wr wi in
  t, z, w


(* [schur_tz x] is [schur] without the eigenvalue output: returns [(t, z)]. *)
let schur_tz x =
  assert (is_square x);
  let a = M.copy x in
  let t, z, _, _ = Owl_lapacke.gees ~jobvs:'V' ~a in
  t, z


(* [ordschur ~otyp ~select t q] reorders copies of [t] and [q] with [trsen]
   according to [select], whose entries must all be [0l] or [1l] (asserted),
   and returns [(ts, zs, ws)] with [ws] converted to kind [otyp]. *)
let ordschur
  : type a b c d. otyp:(c, d) kind -> select:(int32, int32_elt) t -> (a, b) t -> (a, b) t -> (a, b) t * (a, b) t * (c, d) t
  = fun ~otyp ~select t q ->
  let t = M.copy t in
  let q = M.copy q in
  M.iter (fun flag -> assert (flag = 0l || flag = 1l)) select;
  let ts, zs, wr, wi = Owl_lapacke.trsen ~job:'V' ~compq:'V' ~select ~t ~q in
  let ws = _magic_complex otyp wr wi in
  ts, zs, ws


(* Generalised Schur Decomposition *)

(* [qz ~otyp x y] runs [gges] on copies of the square matrices [x] and [y]
   and returns [(s, t, q, z, e)], where [e] is [alpha / beta] in kind
   [otyp]. Fails an assertion if either input is not square. *)
let qz
  : type a b c d. otyp:(c, d) kind -> (a, b) t -> (a, b) t -> (a, b) t * (a, b) t * (a, b) t * (a, b) t * (c, d) t
  = fun ~otyp x y ->
  assert (is_square x);
  assert (is_square y);
  let a = M.copy x in
  let b = M.copy y in
  let s, t, ar, ai, bt, q, z =
    Owl_lapacke.gges ~jobvsl:'V' ~jobvsr:'V' ~a ~b
  in
  let alpha = _magic_complex otyp ar ai in
  let beta = M.cast otyp bt in
  let e = M.(alpha / beta) in
  s, t, q, z, e


(* [ordqz ~otyp ~select a b q z] reorders copies of the four inputs with
   [tgsen] and returns [(a, b, q, z, e)], where [e] is [alpha / beta] in
   kind [otyp]. *)
let ordqz
  : type a b c d. otyp:(c, d) kind -> select:(int32, int32_elt) t -> (a, b) t -> (a, b) t -> (a, b) t -> (a, b) t -> (a, b) t * (a, b) t * (a, b) t * (a, b) t * (c, d) t
  = fun ~otyp ~select a b q z ->
  let a = M.copy a in
  let b = M.copy b in
  let q = M.copy q in
  let z = M.copy z in
  let a, b, ar, ai, bt, q, z = Owl_lapacke.tgsen ~select ~a ~b ~q ~z in
  let alpha = _magic_complex otyp ar ai in
  let beta = M.cast otyp bt in
  let e = M.(alpha / beta) in
  a, b, q, z, e


(* [qzvals ~otyp x y] returns only [alpha / beta] in kind [otyp], computed
   by [ggev] without vectors. Fails an assertion if either input is not
   square. *)
let qzvals
  : type a b c d. otyp:(c, d) kind -> (a, b) t -> (a, b) t -> (c, d) t
  = fun ~otyp x y ->
  assert (is_square x);
  assert (is_square y);
  let a = M.copy x in
  let b = M.copy y in
  let ar, ai, bt, _, _ = Owl_lapacke.ggev ~jobvl:'N' ~jobvr:'N' ~a ~b in
  let alpha = _magic_complex otyp ar ai in
  let beta = M.cast otyp bt in
  M.(alpha / beta)


(* TODO: RQ Decomposition *)

(* Placeholder: ignores its argument and returns unit. *)
let rq x = ()


(* Eigenvalue problem
*)

(* [eig ~permute ~scale ~otyp x] returns [(v, w)] for a copy of [x] using
   [geevx] with right vectors requested.
   [permute] and [scale] (both default [true]) pick the balancing code:
   'B' both, 'P' permute only, 'S' scale only, 'N' neither.
   For real input the complex vector matrix is assembled from the real
   output [vr]: a column whose [wi] entry is zero is copied with a zero
   imaginary part; otherwise columns [j] and [j+1] are combined into a
   conjugate pair. For complex input [vr] and [wr] are returned as is.
   The results are coerced to the requested output type with [Obj.magic], so
   [otyp] must agree with the kind actually produced. *)
let eig
  : type a b c d. ?permute:bool -> ?scale:bool -> otyp:(a, b) kind -> (c, d) t -> (a, b) t * (a, b) t
  = fun ?(permute=true) ?(scale=true) ~otyp x ->
  let x = M.copy x in
  let balanc =
    if permute
    then (if scale then 'B' else 'P')
    else (if scale then 'S' else 'N')
  in
  let _, wr, wi, _, vr, _, _, _, _, _, _ =
    Owl_lapacke.geevx ~balanc ~jobvl:'N' ~jobvr:'V' ~sense:'N' ~a:x
  in

  (* TODO: optimise the performance by writing in c *)
  (* construct eigen vectors from real wr and wi *)
  let _construct_v
    : type a b. (float, a) kind -> (Complex.t, b) kind -> (float, a) t -> (float, a) t -> (float, a) t -> (Complex.t, b) t -> unit
    = fun k0 k1 wr wi vr v ->
    let zero = Owl_const.zero (M.kind wi) in
    let _, n = M.shape v in
    (* walk the columns, consuming one (real) or two (conjugate pair) *)
    let rec fill j =
      if j < n then (
        if (M.get wi 0 j) = zero then (
          for k = 0 to n - 1 do
            M.set v k j Complex.({re = M.get vr k j; im = 0.})
          done;
          fill (j + 1)
        )
        else (
          for k = 0 to n - 1 do
            M.set v k j Complex.({re = M.get vr k j; im = M.get vr k (j + 1)});
            M.set v k (j + 1) Complex.({re = M.get vr k j; im = 0. -. (M.get vr k (j + 1))})
          done;
          fill (j + 2)
        )
      )
    in
    fill 0
  in

  (* process eigen vectors *)
  let m, n = M.shape vr in
  let v = match M.kind x with
    | Float32   ->
      let cv = M.empty complex32 m n in
      _construct_v float32 complex32 wr wi vr cv;
      Obj.magic cv
    | Float64   ->
      let cv = M.empty complex64 m n in
      _construct_v float64 complex64 wr wi vr cv;
      Obj.magic cv
    | Complex32 -> Obj.magic vr
    | Complex64 -> Obj.magic vr
    | _         -> failwith "owl_linalg_generic:eig"
  in

  (* process eigen values *)
  let w = match M.kind x with
    | Float32   -> M.complex float32 complex32 wr wi |> Obj.magic
    | Float64   -> M.complex float64 complex64 wr wi |> Obj.magic
    | Complex32 -> Obj.magic wr
    | Complex64 -> Obj.magic wr
    | _         -> failwith "owl_linalg_generic:eigvals"
  in
  v, w


(* [eigvals ~permute ~scale ~otyp x] returns only the eigenvalue output of
   [geevx] (no vectors requested) for a copy of [x]. Real input is combined
   into a complex matrix; complex input is returned as is. The balancing
   code is chosen from [permute] and [scale] exactly as in [eig]. *)
let eigvals
  : type a b c d. ?permute:bool -> ?scale:bool -> otyp:(a, b) kind -> (c, d) t -> (a, b) t
  = fun ?(permute=true) ?(scale=true) ~otyp x ->
  let x = M.copy x in
  let balanc =
    if permute
    then (if scale then 'B' else 'P')
    else (if scale then 'S' else 'N')
  in
  let _, wr, wi, _, _, _, _, _, _, _, _ =
    Owl_lapacke.geevx ~balanc ~jobvl:'N' ~jobvr:'N' ~sense:'N' ~a:x
  in
  match M.kind x with
  | Float32   -> M.complex float32 complex32 wr wi |> Obj.magic
  | Float64   -> M.complex float64 complex64 wr wi |> Obj.magic
  | Complex32 -> Obj.magic wr
  | Complex64 -> Obj.magic wr
  | _         -> failwith "owl_linalg_generic:eigvals"


(* Hessenberg form of matrix
*)

(* Builds the Q factor of a Hessenberg reduction, dispatching on the element
   kind: [orghr] for real kinds, [unghr] for complex kinds. *)
let _get_hess_q
  : type a b. (a, b) kind -> int -> int -> (a, b) t -> (a, b) t -> (a, b) t
  = fun kind ilo ihi refl tau ->
  match kind with
  | Float32   -> Owl_lapacke.orghr ilo ihi refl tau
  | Float64   -> Owl_lapacke.orghr ilo ihi refl tau
  | Complex32 -> Owl_lapacke.unghr ilo ihi refl tau
  | Complex64 -> Owl_lapacke.unghr ilo ihi refl tau
  | _         -> failwith "owl_linalg:_get_hess_q"


(* [hess x] reduces a copy of [x] with [gehrd] over the full index range
   (ilo = 1, ihi = number of columns) and returns [(h, q)], where [h] keeps
   the reduced matrix from the first sub-diagonal upwards. *)
let hess x =
  let x = M.copy x in
  let _, n = M.shape x in
  let ilo = 1 in
  let ihi = n in
  let a, tau = Owl_lapacke.gehrd ~ilo ~ihi ~a:x in
  let h = M.triu ~k:(-1) a in
  let q = _get_hess_q (M.kind x) ilo ihi a tau in
  h, q


(* Bunch-Kaufman [Bunch1977] factorization *)

(* [bkfact ~upper ~symmetric ~rook x] factorises a copy of [x] and returns
   [(a, ipiv)]. [upper] selects the 'U' or 'L' triangle; [symmetric] chooses
   the sy* routine over the he* one; [rook] chooses the *_rook variant. The
   third value returned by the routine is discarded. *)
let bkfact ?(upper=true) ?(symmetric=true) ?(rook=false) x =
  let x = M.copy x in
  let uplo = if upper then 'U' else 'L' in
  let a, ipiv, _ = match rook, symmetric with
    | true,  true  -> Owl_lapacke.sytrf_rook uplo x
    | true,  false -> Owl_lapacke.hetrf_rook uplo x
    | false, true  -> Owl_lapacke.sytrf uplo x
    | false, false -> Owl_lapacke.hetrf uplo x
  in
  a, ipiv


(* Check matrix properties *)

(* Each predicate below forwards to the matching [Owl_matrix] check, passing
   the kind of [x] and [x] itself. *)

let is_triu x = Owl_matrix._matrix_is_triu (M.kind x) x

let is_tril x = Owl_matrix._matrix_is_tril (M.kind x) x

let is_symmetric x = Owl_matrix._matrix_is_symmetric (M.kind x) x

let is_hermitian x = Owl_matrix._matrix_is_hermitian (M.kind x) x

let is_diag x = Owl_matrix._matrix_is_diag (M.kind x) x


(* [is_posdef x] is [true] when [chol x] completes, and [false] when it
   raises any exception. *)
let is_posdef x =
  match chol x with
  | _           -> true
  | exception _ -> false


(* [_minmax_real k v] returns the (min, max) pair of [v] as floats; complex
   matrices are reduced to their real part first. The dispatch is on the
   kind of [v]; the [k] argument is not inspected. *)
let _minmax_real
  : type a b. (a, b) kind -> (a, b) t -> float * float
  = fun k v ->
  match M.kind v with
  | Float32   -> M.minmax' v
  | Float64   -> M.minmax' v
  | Complex32 -> M.re_c2s v |> M.minmax'
  | Complex64 -> M.re_z2d v |> M.minmax'
  | _         -> failwith "owl_linalg_generic:_minmax_real"


(* local abs function, bear with obj.magic
*)

(* [_abs k x] returns the element-wise absolute value of [x] as a real
   matrix: [M.abs] for real kinds, [abs_c2s] / [abs_z2d] for complex kinds.
   The result type is forced with [Obj.magic]. *)
let _abs
  : type a b c. (a, b) kind -> (a, b) t -> (float, c) t
  = fun k x ->
  match k with
  | Float32   -> M.abs x |> Obj.magic
  | Float64   -> M.abs x |> Obj.magic
  | Complex32 -> M.abs_c2s x |> Obj.magic
  | Complex64 -> M.abs_z2d x |> Obj.magic
  | _         -> failwith "owl_linalg_generic:_abs"


(* [norm ~p x] is a matrix norm of [x] selected by [p] (default 2.):
     1. / -1.              max / min of [M.sum_rows] of the absolute values
     2. / -2.              largest / smallest singular value
     infinity / neg_infinity   max / min of [M.sum_cols] of the absolute values
   Any other [p] fails. *)
let norm ?(p=2.) x =
  let k = M.kind x in
  if p = 1. then x |> _abs k |> M.sum_rows |> M.max'
  else if p = -1. then x |> _abs k |> M.sum_rows |> M.min'
  else if p = 2. then x |> svdvals |> _minmax_real k |> snd
  else if p = -2. then x |> svdvals |> _minmax_real k |> fst
  else if p = infinity then x |> _abs k |> M.sum_cols |> M.max'
  else if p = neg_infinity then x |> _abs k |> M.sum_cols |> M.min'
  else failwith "owl_linalg_generic:norm:p=±1|±2|±inf"


(* [vecnorm ~p x] treats [x] as a flat vector and returns its p-norm as a
   float (default p = 2.). p = 1. and p = 2. use [l1norm'] and [l2norm'];
   infinity / neg_infinity give the max / min absolute value; any other [p]
   raises the absolute values to [p], sums them and takes the [1/p] power. *)
let vecnorm ?(p=2.) x =
  let k = M.kind x in
  let to_real = Owl_ndarray._re_elt k in
  if p = 1. then M.l1norm' x |> to_real
  else if p = 2. then M.l2norm' x |> to_real
  else (
    let v = M.flatten x |> M.abs in
    if p = infinity then M.max' v |> to_real
    else if p = neg_infinity then M.min' v |> to_real
    else (
      (* [v] is a fresh matrix produced by [M.abs], safe to update in place *)
      M.pow_scalar_ v (Owl_ndarray._float_typ_elt k p);
      let total = M.sum' v |> to_real in
      total ** (1. /. p)
    )
  )


(* [cond ~p x] is the condition number of [x] (default p = 2.).
   p = 2.            ratio of the largest to the smallest singular value;
                     [infinity] when the largest singular value is zero.
   p = 1. | infinity reciprocal of the estimate returned by [gecon] on the
                     LU factors of [x]; [x] must be square (asserted).
   Any other [p] fails.

   The norm handed to [gecon] must be that of the original matrix, so it is
   computed before the factorisation: [lufact] gives its argument directly
   to [getrf], and the defensive copies made throughout this file indicate
   that the routine overwrites its input with the LU factors. *)
let cond ?(p=2.) x =
  if p = 2. then (
    let v = svdvals x in
    let minv, maxv = _minmax_real (M.kind v) v in
    if maxv = 0. then infinity else maxv /. minv
  )
  else if p = 1. || p = infinity then (
    assert (M.row_num x = M.col_num x);
    let anorm = norm ~p x in
    let x = M.copy x in
    let a, _ = lufact x in
    let _norm = if p = 1. then '1' else 'I' in
    let rcond = Owl_lapacke.gecon _norm a anorm in
    1. /. rcond
  )
  else failwith "owl_linalg_generic:cond:p=1|2|inf"


(* [rcond x] is the reciprocal of the 1-norm condition number of [x]. *)
let rcond x = 1. /. (cond ~p:1. x)


(* solve linear system of equations *)

(* [null x] returns the transpose of the trailing rows of [vt] from the full
   SVD of [x]: the rows left after skipping as many leading rows as there are
   singular values above [max m n * largest singular value * eps].
   When [x] has a zero dimension the [n] x [n] identity is returned. *)
let null x =
  let eps = Owl_utils.eps (M.kind x) in
  let m, n = M.shape x in
  if m = 0 || n = 0 then M.eye (M.kind x) n
  else (
    let _, s, vt = svd ~thin:false x in
    let s = _abs (M.kind s) s in
    let maxsv = M.max' s in
    let maxmn = Pervasives.max m n |> float_of_int in
    let i =
      M.elt_greater_scalar s (maxmn *. maxsv *. eps)
      |> M.sum'
      |> int_of_float
    in
    let vt = M.resize ~head:false vt [|M.row_num vt - i; M.col_num vt|] in
    M.transpose vt
  )


(* Transposition code for a kind: 'T' for real kinds, 'C' for complex. *)
let _get_trans_code : type a b. (a, b) kind -> char = function
  | Float32   -> 'T'
  | Float64   -> 'T'
  | Complex32 -> 'C'
  | Complex64 -> 'C'
  | _         -> failwith "owl_linalg_generic:_get_trans_code"


(* TODO: add opt parameter to specify the matrix properties so that we can
choose the best solver for better performance.
*)

(* [linsolve ~trans a b] solves a linear system with coefficient matrix [a]
   and right-hand side [b], working on copies of both.
   Square [a]: LU factorisation followed by [getrs].
   Otherwise:  [gels].
   [trans] (default [false]) requests the transposed system; the code passed
   to LAPACK is then 'T' for real and 'C' for complex kinds.
   Fails an assertion if [a] and [b] have different row counts. *)
let linsolve ?(trans=false) a b =
  let ma, na = M.shape a in
  let mb, nb = M.shape b in
  assert (ma = mb);
  let a = M.copy a in
  let b = M.copy b in
  let trans = match trans with
    | true  -> _get_trans_code (M.kind a)
    | false -> 'N'
  in
  match ma = na with
  | true  -> (
      let a, ipiv = lufact a in
      let x = Owl_lapacke.getrs trans a ipiv b in
      x
    )
  | false -> (
      let _, x, _ = Owl_lapacke.gels trans a b in
      x
    )


(* [linreg x y] fits [y = a + b * x] and returns [(a, b)]: [b] is the
   covariance of [x] and [y] divided by the variance of [x], and [a] is
   [mean y - b * mean x]. Both inputs are reshaped to column vectors and
   must hold the same number of elements (asserted). *)
let linreg x y =
  let nx = M.numel x in
  let ny = M.numel y in
  assert (nx = ny);
  let x = M.reshape x [|nx; 1|] in
  let y = M.reshape y [|ny; 1|] in
  let k = M.kind x in
  let p = M.get (M.cov ~a:x ~b:y) 0 1 in
  let q = M.get (M.var ~axis:0 x) 0 0 in
  let b = Owl_ndarray._div_elt k p q in
  let c = Owl_ndarray._mul_elt k b (M.mean' x) in
  let a = Owl_ndarray._sub_elt k (M.mean' y) c in
  a, b


(* [pinv ~tol x] is the pseudo-inverse of [x] built from its SVD: singular
   values are reciprocated by [M.reci_tol] with tolerance [tol]. When [tol]
   is omitted it defaults to
   float32 eps * [max m n] * largest singular value. *)
let pinv ?tol x =
  let u, s, vt = svd x in
  (* by default using float32 eps *)
  let eps = Owl_utils.eps Float32 in
  let m, n = M.shape x in
  let a = float_of_int (Pervasives.max m n) in
  let b = _minmax_real (M.kind x) s |> snd in
  let t = match tol with
    | Some tol -> tol
    | None     -> eps *. a *. b
  in
  let tol = Owl_ndarray._float_typ_elt (M.kind x) t in
  let s' = M.(reci_tol ~tol s |> diagm) in
  let ut = M.ctranspose u in
  let v = M.ctranspose vt in
  M.(v *@ s' *@ ut)


(* [sylvester a b c]: [a] and [b] are brought to Schur form, the transformed
   right-hand side is passed to [trsyl], and the solution is mapped back and
   divided by the scale factor [s] returned by [trsyl]. *)
let sylvester a b c =
  let ra, qa = schur_tz a in
  let rb, qb = schur_tz b in
  let d = M.((ctranspose qa) *@ (c *@ qb)) in
  let y, s = Owl_lapacke.trsyl 'N' 'N' 1 ra rb d in
  let z = M.(qa *@ (y *@ (ctranspose qb))) in
  M.mul_scalar_ z (Owl_ndarray._float_typ_elt (M.kind c) (1. /. s));
  z


(* [lyapunov a c]: same scheme as [sylvester] with a single Schur form of
   [a] used on both sides; the second operand is passed to [trsyl] with the
   transposition code of the kind of [c]. *)
let lyapunov a c =
  let r, q = schur_tz a in
  let d = M.((ctranspose q) *@ (c *@ q)) in
  let tb = _get_trans_code (M.kind c) in
  let y, s = Owl_lapacke.trsyl 'N' tb 1 r r d in
  let z = M.(q *@ (y *@ (ctranspose q))) in
  M.mul_scalar_ z (Owl_ndarray._float_typ_elt (M.kind c) (1. /. s));
  z


(* [care a b q r]: builds the block matrix [[a, -g], [-q, -a^T]] with
   [g = b * inv r * b^T], takes its Schur form, reorders it so that the
   entries with negative real part are selected, and returns
   [u1 * inv u0], where [u0] and [u1] are the upper-left and lower-left
   quarters of the reordered Schur vectors.
   NOTE(review): relies on [trsen] updating [u] in place, since its return
   value is ignored - confirm. *)
let care a b q r =
  let g = M.(b *@ (inv r) *@ (transpose b)) in
  let z = M.(concat_vh [| [|a; neg g|]; [|neg q; neg (transpose a)|] |]) in
  let t, u, wr, _ = Owl_lapacke.gees ~jobvs:'V' ~a:z in
  let select = M.(zeros int32 (row_num wr) (col_num wr)) in
  M.iteri_2d (fun i j re -> if re < 0. then M.set select i j 1l) wr;
  ignore (Owl_lapacke.trsen ~job:'V' ~compq:'V' ~select ~t ~q:u);
  let m, n = M.shape u in
  let u0 = M.get_slice [[0; m / 2 - 1]; [0; n / 2 - 1]] u in
  let u1 = M.get_slice [[m / 2; m - 1]; [0; n / 2 - 1]] u in
  M.(u1 *@ (inv u0))


(* [dare a b q r]: discrete-time counterpart of [care]. With
   [g = b * inv r * b^T] and [c = (inv a)^T] it builds
   [[a + g c q, -g c], [-c q, c]], selects the entries whose modulus is at
   most one, reorders, and returns [u1 * inv u0] as in [care].
   NOTE(review): relies on [trsen] updating [u] in place - confirm. *)
let dare a b q r =
  let g = M.(b *@ (inv r) *@ (transpose b)) in
  let c = M.transpose (inv a) in
  let z = M.(concat_vh [|
    [|a + g *@ c *@ q; (neg g) *@ c|];
    [|(neg c) *@ q;    c           |];
  |]) in
  let t, u, wr, wi = Owl_lapacke.gees ~jobvs:'V' ~a:z in
  let select = M.(zeros int32 (row_num wr) (col_num wr)) in
  M.iter2i_2d (fun i j re im ->
    if Complex.(norm {re; im}) <= 1. then M.set select i j 1l
  ) wr wi;
  ignore (Owl_lapacke.trsen ~job:'V' ~compq:'V' ~select ~t ~q:u);
  let m, n = M.shape u in
  let u0 = M.get_slice [[0; m / 2 - 1]; [0; n / 2 - 1]] u in
  let u1 = M.get_slice [[m / 2; m - 1]; [0; n / 2 - 1]] u in
  M.(u1 *@ (inv u0))


(* helper functions *)

(* [peakflops ~n ()] times one [n] x [n] float64 [gemm] on matrices of ones
   (default n = 2000) and returns [2 n^3 / elapsed seconds]. *)
let peakflops ?(n=2000) () =
  let x = M.ones float64 n n |> M.flatten |> array1_of_genarray in
  let z = M.ones float64 n n |> M.flatten |> array1_of_genarray in
  let layout = Owl_cblas_basic.CblasRowMajor in
  let transa = Owl_cblas_basic.CblasNoTrans in
  let transb = Owl_cblas_basic.CblasNoTrans in
  let t0 = Unix.gettimeofday () in
  Owl_cblas_basic.gemm layout transa transb n n n 1.0 x n x n 0.0 z n;
  let t1 = Unix.gettimeofday () in
  let flops = 2. *. (float_of_int n) ** 3. /. (t1 -. t0) in
  flops


(* Matrix functions *)

(* [mpow x r] raises the square matrix [x] to the integer power [r], given
   as a float.
     r = 0.  identity of the same kind and size
     r > 0.  repeated squaring of [x]
     r < 0.  repeated squaring of [inv x]
   Fails with "mpow: fractional powers not implemented" when [r] has a
   fractional part, and fails an assertion if [x] is not square.

   The odd-exponent step multiplies by [base], the matrix actually being
   exponentiated. Multiplying by the outer [x] instead would give
   [x * (inv x)^(|r|-1)] for negative odd [r] rather than [(inv x)^|r|]. *)
let mpow x r =
  let frac_part, _ = Pervasives.modf r in
  if frac_part <> 0. then failwith "mpow: fractional powers not implemented";
  let m, n = M.shape x in
  assert (m = n);
  (* integer matrix powers using floats: *)
  let rec _mpow base s =
    if s = 1. then base
    else if mod_float s 2. = 0.   (* exponent is even? *)
    then even_mpow base s
    else M.dot base (even_mpow base (s -. 1.))
  and even_mpow base s =
    let half = _mpow base (s /. 2.) in
    M.dot half half
  in
  (* r is equal to an integer: *)
  if r = 0.0 then M.(eye (kind x)) n
  else if r > 0.0 then _mpow x r
  else _mpow (inv x) (-. r)


(* DEBUG: initial expm implemented with eig, obsoleted *)
(* [expm_eig ~otyp x] computes [v * diag (exp w) * inv v] from the output
   [(v, w)] of [eig]. Raises through [Owl_exception.check] when [x] is not
   square. *)
let expm_eig
  : type a b c d. otyp:(c, d) kind -> (a, b) t -> (c, d) t
  = fun ~otyp x ->
  Owl_exception.(check (is_square x) NOT_SQUARE);
  let v, w = eig ~otyp x in
  let vi = inv v in
  let u = M.(exp w |> diagm) in
  M.(dot (dot v u) vi)


(* [expm x] is the matrix exponential of the square matrix [x], computed by
   Padé approximation.
   - A 1 x 1 matrix is handled with the element-wise [M.exp].
   - When the 1-norm is at most 2.097847961257068, a coefficient table of
     length 10, 8, 6 or 4 is chosen from the norm, the odd part [u] and the
     even part [v] are accumulated, and [(v - u) X = (v + u)] is solved with
     [gesv].
   - Otherwise [x] is divided by [2^t] with [t = ceil (log2 (norm / 5.4))]
     when that logarithm is positive, the 14-coefficient approximant is
     evaluated, and the result is squared [t] times.
   Raises through [Owl_exception.check] when [x] is not square.
   NOTE(review): the value returned is [b] after [gesv a b], which assumes
   [gesv] writes the solution into [b] - confirm. *)
let expm x =
  Owl_exception.(check (is_square x) NOT_SQUARE);
  (* trivial case *)
  if M.shape x = (1, 1) then M.exp x
  else (
    (* TODO: use gebal to balance to improve accuracy, refer to Julia's impl *)
    let xe = M.(eye (kind x) (row_num x)) in
    let norm_x = norm ~p:1. x in
    (* for small norm, use lower order Padé-approximation *)
    if norm_x <= 2.097847961257068 then (
      let c = Array.map (Owl_ndarray._float_typ_elt (M.kind x)) (
        if norm_x > 0.9504178996162932 then
          [|17643225600.; 8821612800.; 2075673600.; 302702400.; 30270240.; 2162160.; 110880.; 3960.; 90.; 1.|]
        else if norm_x > 0.2539398330063230 then
          [|17297280.; 8648640.; 1995840.; 277200.; 25200.; 1512.; 56.; 1.|]
        else if norm_x > 0.01495585217958292 then
          [|30240.; 15120.; 3360.; 420.; 30.; 1.|]
        else
          [|120.; 60.; 12.; 1.|]
      ) in

      let x2 = M.dot x x in
      let p = ref M.(copy xe) in
      let u = M.mul_scalar !p c.(1) in
      let v = M.mul_scalar !p c.(0) in

      (* p runs through x^2, x^4, ...; odd coefficients feed u, even ones v *)
      for i = 1 to Array.(length c / 2 - 1) do
        let j = 2 * i in
        let k = j + 1 in
        p := M.dot !p x2;
        M.(add_ u (mul_scalar !p c.(k)));
        M.(add_ v (mul_scalar !p c.(j)));
      done;

      let u = M.dot x u in
      let a = M.sub v u in
      let b = M.add v u in
      Owl_lapacke.gesv a b |> ignore;
      b
    )
    (* for larger norm, Padé-13 approximation *)
    else (
      let s = Owl_maths.log2 (norm_x /. 5.4) in
      let t = ceil s in
      let x =
        if s > 0.
        then Owl_ndarray._float_typ_elt (M.kind x) (2. ** t) |> M.div_scalar x
        else x
      in

      let c = Array.map (Owl_ndarray._float_typ_elt (M.kind x))
        [|64764752532480000.; 32382376266240000.; 7771770303897600.;
          1187353796428800.; 129060195264000.; 10559470521600.;
          670442572800.; 33522128640.; 1323241920.; 40840800.;
          960960.; 16380.; 182.; 1.|]
      in

      let x2 = M.dot x x in
      let x4 = M.dot x2 x2 in
      let x6 = M.dot x2 x4 in
      let u = M.(
        x *@ (x6 *@ (x6 *$ c.(13) + x4 *$ c.(11) + x2 *$ c.(9)) +
        x6 *$ c.(7) + x4 *$ c.(5) + x2 *$ c.(3) + xe *$ c.(1))
      ) in
      let v = M.(
        x6 *@ (x6 *$ c.(12) + x4 *$ c.(10) + x2 *$ c.(8)) +
        x6 *$ c.(6) + x4 *$ c.(4) + x2 *$ c.(2) + xe *$ c.(0)
      ) in

      let a = M.sub v u in
      let b = M.add v u in
      Owl_lapacke.gesv a b |> ignore;

      (* undo the scaling by repeated squaring *)
      let x = ref b in
      if s > 0. then (
        for i = 1 to int_of_float t do
          x := M.dot !x !x
        done;
      );
      !x
    )
  )


(* [_sinm k x]: matrix sine via [expm].
   Real kinds:    imaginary part of [expm (i x)], computed in complex.
   Complex kinds: [-0.5i * (expm (i x) - expm (-i x))]. *)
let _sinm : type a b. (a, b) kind -> (a, b) t -> (a, b) t =
  fun k x ->
  match k with
  | Float32   -> (
      let a = Complex.({re = 0.; im = 1.}) in
      let x = M.cast_s2c x in
      M.(expm (a $* x) |> im_c2s)
    )
  | Float64   -> (
      let a = Complex.({re = 0.; im = 1.}) in
      let x = M.cast_d2z x in
      M.(expm (a $* x) |> im_z2d)
    )
  | Complex32 -> (
      let a = Complex.({re = 0.; im = (-0.5)}) in
      let b = Complex.({re = 0.; im = 1.}) in
      let c = Complex.({re = 0.; im = (-1.)}) in
      M.(a $* (expm (b $* x) - expm (c $* x)))
    )
  | Complex64 -> (
      let a = Complex.({re = 0.; im = (-0.5)}) in
      let b = Complex.({re = 0.; im = 1.}) in
      let c = Complex.({re = 0.; im = (-1.)}) in
      M.(a $* (expm (b $* x) - expm (c $* x)))
    )
  | _         -> failwith "_sinm: unsupported operation"


(* [sinm x] is the matrix sine of [x]. *)
let sinm x = _sinm (M.kind x) x


(* [_cosm k x]: matrix cosine via [expm].
   Real kinds:    real part of [expm (i x)], computed in complex.
   Complex kinds: [0.5 * (expm (i x) + expm (-i x))]. *)
let _cosm : type a b. (a, b) kind -> (a, b) t -> (a, b) t =
  fun k x ->
  match k with
  | Float32   -> (
      let a = Complex.({re = 0.; im = 1.}) in
      let x = M.cast_s2c x in
      M.(expm (a $* x) |> re_c2s)
    )
  | Float64   -> (
      let a = Complex.({re = 0.; im = 1.}) in
      let x = M.cast_d2z x in
      M.(expm (a $* x) |> re_z2d)
    )
  | Complex32 -> (
      let a = Complex.({re = 0.5; im = 0.}) in
      let b = Complex.({re = 0.; im = 1.}) in
      let c = Complex.({re = 0.; im = (-1.)}) in
      M.(a $* (expm (b $* x) + expm (c $* x)))
    )
  | Complex64 -> (
      let a = Complex.({re = 0.5; im = 0.}) in
      let b = Complex.({re = 0.; im = 1.}) in
      let c = Complex.({re = 0.; im = (-1.)}) in
      M.(a $* (expm (b $* x) + expm (c $* x)))
    )
  | _         -> failwith "_cosm: unsupported operation"


(* [cosm x] is the matrix cosine of [x]. *)
let cosm x = _cosm (M.kind x) x


(* [_sincosm k x] returns [(sine, cosine)] of [x], sharing the [expm]
   evaluations between the two results. *)
let _sincosm : type a b. (a, b) kind -> (a, b) t -> (a, b) t * (a, b) t =
  fun k x ->
  match k with
  | Float32   -> (
      let a = Complex.({re = 0.; im = 1.}) in
      let x = M.cast_s2c x in
      let y = M.(expm (a $* x)) in
      M.(im_c2s y, re_c2s y)
    )
  | Float64   -> (
      let a = Complex.({re = 0.; im = 1.}) in
      let x = M.cast_d2z x in
      let y = M.(expm (a $* x)) in
      M.(im_z2d y, re_z2d y)
    )
  | Complex32 -> (
      let b = Complex.({re = 0.; im = 1.}) in
      let c = Complex.({re = 0.; im = (-1.)}) in
      let x = M.(expm (b $* x)) in
      let y = M.(expm (c $* x)) in
      let _sin = M.(Complex.({re = 0.; im = (-0.5)}) $* (x - y)) in
      let _cos = M.(Complex.({re = 0.5; im = 0.}) $* (x + y)) in
      _sin, _cos
    )
  | Complex64 -> (
      let b = Complex.({re = 0.; im = 1.}) in
      let c = Complex.({re = 0.; im = (-1.)}) in
      let x = M.(expm (b $* x)) in
      let y = M.(expm (c $* x)) in
      let _sin = M.(Complex.({re = 0.; im = (-0.5)}) $* (x - y)) in
      let _cos = M.(Complex.({re = 0.5; im = 0.}) $* (x + y)) in
      _sin, _cos
    )
  | _         -> failwith "_sincosm: unsupported operation"


(* [sincosm x] returns [(sinm x, cosm x)]. *)
let sincosm x = _sincosm (M.kind x) x


(* [tanm x]: solves [cosm x * X = sinm x] with [gesv] and returns the
   right-hand side matrix.
   NOTE(review): assumes [gesv] writes the solution into [s] - confirm. *)
let tanm x =
  let s, c = sincosm x in
  Owl_lapacke.gesv c s |> ignore;
  s


(* [sinhm x] is [0.5 * (expm x - expm (-x))]. *)
let sinhm x =
  let a = Owl_ndarray._float_typ_elt (M.kind x) 0.5 in
  M.(a $* ((expm x) - (expm (neg x))))


(* [coshm x] is [0.5 * (expm x + expm (-x))]. *)
let coshm x =
  let a = Owl_ndarray._float_typ_elt (M.kind x) 0.5 in
  M.(a $* ((expm x) + (expm (neg x))))


(* [sinhcoshm x] returns [(sinhm x, coshm x)], evaluating each exponential
   once. *)
let sinhcoshm x =
  let a = Owl_ndarray._float_typ_elt (M.kind x) 0.5 in
  let b = expm x in
  let c = expm (M.neg x) in
  M.(a $* (b - c)), M.(a $* (b + c))


(* [tanhm x]: solves [coshm x * X = sinhm x] with [gesv] and returns the
   right-hand side matrix.
   NOTE(review): assumes [gesv] writes the solution into [s] - confirm. *)
let tanhm x =
  let s, c = sinhcoshm x in
  Owl_lapacke.gesv c s |> ignore;
  s


(* TODO *)
(* Not implemented: always fails. *)
let logm x = failwith "logm: not implemented"


(* TODO *)
(* Not implemented: always fails. *)
let sqrtm x = failwith "sqrtm: not implemented"


(* ends here *)