12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248# 1 "src/owl/linalg/owl_linalg_generic.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)[@@@warning"-6"]openBigarraytype('a,'b)t=('a,'b)Owl_dense_matrix_generic.t(*
We create a local generic matrix module with basic operators. This is the only
way to let us use operators to write concise math while avoiding a circular
dependency at the same time.
*)

module M = struct
  include Owl_dense_matrix_generic
  include Owl_operator.Make_Basic (Owl_dense_matrix_generic)
  include Owl_operator.Make_Extend (Owl_dense_matrix_generic)
  include Owl_operator.Make_Matrix (Owl_dense_matrix_generic)
end

(* Helper functions *)

(* [is_square x] is true iff [x] has as many rows as columns. *)
let is_square x =
  let rows, cols = M.shape x in
  rows = cols


(* [select_ev keyword ev] returns an int32 matrix with the shape of [ev] whose
   entries are 1 where the element of [ev] satisfies the test chosen by
   [keyword], and 0 elsewhere:
   - `LHP : real part < 0
   - `RHP : real part >= 0
   - `UDI : real part of the absolute value < 1
   - `UDO : real part of the absolute value >= 1 *)
let select_ev keyword ev =
  let k = M.kind ev in
  let rows, cols = M.shape ev in
  let mask = M.zeros int32 rows cols in
  let selected =
    match keyword with
    | `LHP ->
      let re = Owl_base_dense_common._re_elt k in
      fun a -> re a < 0.
    | `RHP ->
      let re = Owl_base_dense_common._re_elt k in
      fun a -> re a >= 0.
    | `UDI -> fun a -> Owl_base_dense_common.(_abs_elt k a |> _re_elt k) < 1.
    | `UDO -> fun a -> Owl_base_dense_common.(_abs_elt k a |> _re_elt k) >= 1.
  in
  M.iteri_2d (fun i j a -> if selected a then M.set mask i j 1l) ev;
  mask


(* LU decomposition *)

(* [lu x] factorises a copy of [x] with getrf and returns [(l, u, ipiv)]:
   [l] is the lower triangle with its first min(m, n) diagonal entries set to
   one, [u] is the upper triangle resized to n x n, and [ipiv] is the pivot
   vector as returned by getrf. *)
let lu x =
  let work = M.copy x in
  let m, n = M.shape work in
  let a, ipiv = Owl_lapacke.getrf ~a:work in
  let l = M.tril a in
  let u = M.resize (M.triu a) [| n; n |] in
  let one = Owl_const.one (M.kind work) in
  for i = 0 to Stdlib.min m n - 1 do
    M.set l i i one
  done;
  l, u, ipiv


(* [lufact x] calls getrf directly on [x] (no copy is taken here) and returns
   the factorised matrix together with the pivot vector. *)
let lufact x =
  let a, ipiv = Owl_lapacke.getrf ~a:x in
  a, ipiv


(* basic functions *)

(* [inv x] inverts a copy of [x] using getrf followed by getri. *)
let inv x =
  let a, ipiv = Owl_lapacke.getrf ~a:(M.copy x) in
  Owl_lapacke.getri ~a ~ipiv


(* [det x] is the product of the diagonal of the LU factorisation of a copy of
   [x], negated when the number of row interchanges is odd. Fails the
   NOT_SQUARE check when [x] is not square. *)
let det x =
  let work = M.copy x in
  let m, n = M.shape work in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  let k = M.kind work in
  let a, ipiv = Owl_lapacke.getrf ~a:work in
  let prod = ref (Owl_const.one k) in
  let swaps = ref 0 in
  let mul = Owl_base_dense_common._mul_elt k in
  for i = 0 to m - 1 do
    prod := mul !prod (M.get a i i);
    (* NOTE: +1 to adjust to Fortran index *)
    if M.get ipiv 0 i <> Int32.of_int (i + 1) then incr swaps
  done;
  if Owl_maths.is_odd !swaps then Owl_base_dense_common._neg_elt k !prod else !prod


(* FIXME: need to check ...
*)

(* [logdet x] sums the logs of the absolute values of the LU diagonal of a copy
   of [x]. A sign flip is counted whenever exactly one of (row was
   interchanged) and (diagonal entry compares below zero) holds; an odd
   number of flips raises [Failure]. Fails the NOT_SQUARE check when [x] is
   not square. *)
let logdet x =
  let work = M.copy x in
  let m, n = M.shape work in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  let k = M.kind work in
  let a, ipiv = Owl_lapacke.getrf ~a:work in
  let acc = ref (Owl_const.zero k) in
  let flips = ref 0 in
  let add = Owl_base_dense_common._add_elt k in
  let log_elt = Owl_base_dense_common._log_elt k in
  let abs_elt = Owl_base_dense_common._abs_elt k in
  for i = 0 to m - 1 do
    let e = M.get a i i in
    acc := add !acc (log_elt (abs_elt e));
    (* NOTE: +1 to adjust to Fortran index *)
    let swapped = M.get ipiv 0 i <> Int32.of_int (i + 1) in
    let negative = e < Owl_const.zero k in
    (* exclusive or of the two flags *)
    if swapped <> negative then incr flips
  done;
  if Owl_maths.is_odd !flips then failwith "logdet: det is negative" else !acc


(* QR decomposition *)

(* Dispatches to orgqr (real kinds) or ungqr (complex kinds) to build Q from
   the reflectors [a] and scalars [tau]. *)
let _get_qr_q : type a b. (a, b) kind -> (a, b) t -> (a, b) t -> (a, b) t =
 fun k a tau ->
  match k with
  | Float32   -> Owl_lapacke.orgqr a tau
  | Float64   -> Owl_lapacke.orgqr a tau
  | Complex32 -> Owl_lapacke.ungqr a tau
  | Complex64 -> Owl_lapacke.ungqr a tau
  | _         -> failwith "owl_linalg:_get_qr_q"


(* [qr ?thin ?pivot x] works on a copy of [x] and returns [(q, r, jpvt)].
   With [pivot] it uses geqp3, otherwise geqrf and an empty 0 x 0 [jpvt].
   [r] has min(m, n) rows when [thin], else m rows. When not [thin] and
   m > n the reflector matrix is padded with m - n zero columns before Q is
   generated. *)
let qr ?(thin = true) ?(pivot = false) x =
  let work = M.copy x in
  let m, n = M.shape work in
  let a, tau, jpvt =
    if pivot
    then Owl_lapacke.geqp3 work
    else (
      let no_pivot = M.empty int32 0 0 in
      let a, tau = Owl_lapacke.geqrf ~a:work in
      a, tau, no_pivot)
  in
  let r_rows = if thin then Stdlib.min m n else m in
  let r = M.resize ~head:true (M.triu a) [| r_rows; n |] in
  let a =
    if thin || m <= n
    then a
    else M.concat_horizontal a (M.zeros (M.kind work) m (m - n))
  in
  let q = _get_qr_q (M.kind work) a tau in
  q, r, jpvt


(* [qrfact ?pivot x] returns the raw [(a, tau, jpvt)] from geqp3 (pivoting)
   or geqrf (no pivoting, empty [jpvt]). [x] is passed to LAPACK as is, no
   copy is taken here. *)
let qrfact ?(pivot = false) x =
  if pivot
  then (
    let a, tau, jpvt = Owl_lapacke.geqp3 x in
    a, tau, jpvt)
  else (
    let jpvt = M.empty int32 0 0 in
    let a, tau = Owl_lapacke.geqrf x in
    a, tau, jpvt)


(* Dispatches to orglq (real kinds) or unglq (complex kinds). *)
let _get_lq_q : type a b. (a, b) kind -> (a, b) t -> (a, b) t -> (a, b) t =
 fun k a tau ->
  match k with
  | Float32   -> Owl_lapacke.orglq a tau
  | Float64   -> Owl_lapacke.orglq a tau
  | Complex32 -> Owl_lapacke.unglq a tau
  | Complex64 -> Owl_lapacke.unglq a tau
  | _         -> failwith "owl_linalg:_get_lq_q"


(* [lq ?thin x] works on a copy of [x] and returns [(l, q)]. When [thin] and
   m < n only the first min(m, n) columns of the lower triangle are kept;
   when not [thin] and m < n the reflector matrix is resized to n x n before
   Q is generated. *)
let lq ?(thin = true) x =
  let work = M.copy x in
  let m, n = M.shape work in
  let a, tau = Owl_lapacke.gelqf work in
  let lower = M.tril a in
  let l =
    if thin && m < n
    then M.get_slice [ []; [ 0; Stdlib.min m n - 1 ] ] lower
    else lower
  in
  let a = if thin || m >= n then a else M.resize ~head:true a [| n; n |] in
  let q = _get_lq_q (M.kind work) a tau in
  l, q


(* Singular Value decomposition *)

(* [svd ?thin x] runs gesdd on a copy of [x] with jobz 'S' when [thin] and
   'A' otherwise, returning [(u, s, vt)]. *)
let svd ?(thin = true) x =
  let jobz = if thin then 'S' else 'A' in
  let u, s, vt = Owl_lapacke.gesdd ~jobz ~a:(M.copy x) in
  u, s, vt


(* [svdvals x] returns only the singular values (gesdd with jobz 'N'). *)
let svdvals x =
  let _, s, _ = Owl_lapacke.gesdd ~jobz:'N' ~a:(M.copy x) in
  s


(* [gsvd x y] runs ggsvd3 on copies of [x] and [y] and returns
   [(u, v, q, d1, d2, r)], where [d1] and [d2] are diagonal matrices built
   from the first k + l entries of alpha and the last l of those entries of
   beta, resized to m x (k + l) and p x (k + l). *)
let gsvd x y =
  let x = M.copy x in
  let y = M.copy y in
  let m, _n = M.shape x in
  let p, _ = M.shape y in
  let u, v, q, alpha, beta, k, l, r =
    Owl_lapacke.ggsvd3 ~jobu:'U' ~jobv:'V' ~jobq:'Q' ~a:x ~b:y
  in
  let kl = k + l in
  let alpha = M.resize ~head:true alpha [| 1; kl |] in
  let d1 = M.resize ~head:true (M.diagm alpha) [| m; kl |] in
  let beta = M.resize ~head:true beta [| 1; kl |] in
  let beta = M.resize ~head:false beta [| 1; l |] in
  let d2 = M.resize (M.diagm ~k beta) [| p; kl |] in
  u, v, q, d1, d2, r


(* [gsvdvals x y] returns alpha / beta over the first k + l entries reported
   by ggsvd3 (no vectors computed). *)
let gsvdvals x y =
  let x = M.copy x in
  let y = M.copy y in
  let _, _, _, alpha, beta, k, l, _ =
    Owl_lapacke.ggsvd3 ~jobu:'N' ~jobv:'N' ~jobq:'N' ~a:x ~b:y
  in
  let kl = k + l in
  let alpha = M.resize ~head:true alpha [| 1; kl |] in
  let beta = M.resize ~head:true beta [| 1; kl |] in
  M.(div alpha beta)


(* [rank ?tol x] counts the singular values of [x] greater than [tol].
   The default tolerance is max(m, n) times the Float32 epsilon, whatever the
   kind of [x]. For complex kinds the comparison scalar has imaginary part
   neg_infinity. *)
let rank ?tol x =
  let sv = svdvals x in
  let m, n = M.shape x in
  let tol =
    match tol with
    | Some given -> given
    | None       ->
      (* by default using float32 eps *)
      float_of_int (Stdlib.max m n) *. Owl_utils.eps Float32
  in
  let ztol = Complex.{ re = tol; im = neg_infinity } in
  let _count : type a b. (a, b) kind -> (a, b) t -> int =
   fun _kind sv ->
    match _kind with
    | Float32   -> M.elt_greater_scalar sv tol |> M.sum' |> int_of_float
    | Float64   -> M.elt_greater_scalar sv tol |> M.sum' |> int_of_float
    | Complex32 ->
      let total = M.elt_greater_scalar sv ztol |> M.sum' in
      int_of_float Complex.(total.re)
    | Complex64 ->
      let total = M.elt_greater_scalar sv ztol |> M.sum' in
      int_of_float Complex.(total.re)
    | _         -> failwith "owl_linalg:rank"
  in
  _count (M.kind sv) sv


(* Cholesky Decomposition *)

(* [chol ?upper x] runs potrf on a copy of [x] and keeps the upper triangle
   when [upper] (the default), else the lower triangle. *)
let chol ?(upper = true) x =
  let work = M.copy x in
  if upper
  then Owl_lapacke.potrf 'U' work |> M.triu
  else Owl_lapacke.potrf 'L' work |> M.tril


(* Schur Decomposition
*)

(* Combines [re] and [im] into a matrix of kind [otyp]: real inputs are paired
   into complex numbers, complex inputs return [re] unchanged. Any other
   pairing of input and output kinds raises [Failure]. *)
let _magic_complex : type a b c d. (c, d) kind -> (a, b) t -> (a, b) t -> (c, d) t =
 fun otyp re im ->
  match M.kind re, otyp with
  | Float32, Complex32   -> M.complex float32 complex32 re im
  | Float64, Complex64   -> M.complex float64 complex64 re im
  | Complex32, Complex32 -> re
  | Complex64, Complex64 -> re
  | _                    -> failwith "owl_linalg_generic:_magic_complex"


(* [schur ~otyp x] runs gees (jobvs 'V') on a copy of the square matrix [x]
   and returns [(t, z, w)], with [w] built from the wr and wi outputs of gees
   and converted to kind [otyp]. *)
let schur : type a b c d. otyp:(c, d) kind -> (a, b) t -> (a, b) t * (a, b) t * (c, d) t =
 fun ~otyp x ->
  let m, n = M.shape x in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  let t, z, wr, wi = Owl_lapacke.gees ~jobvs:'V' ~a:(M.copy x) in
  t, z, _magic_complex otyp wr wi


(* [schur_tz x] is [schur] without the eigenvalue output. *)
let schur_tz x =
  let m, n = M.shape x in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  let t, z, _, _ = Owl_lapacke.gees ~jobvs:'V' ~a:(M.copy x) in
  t, z


(* [ordschur ~otyp ~select t q] reorders a Schur factorisation with trsen on
   copies of [t] and [q]. Every entry of [select] is asserted to be 0 or 1. *)
let ordschur
    : type a b c d.
      otyp:(c, d) kind
      -> select:(int32, int32_elt) t
      -> (a, b) t
      -> (a, b) t
      -> (a, b) t * (a, b) t * (c, d) t
  =
 fun ~otyp ~select t q ->
  let t = M.copy t in
  let q = M.copy q in
  M.iter (fun flag -> assert (flag = 0l || flag = 1l)) select;
  let ts, zs, wr, wi = Owl_lapacke.trsen ~job:'V' ~compq:'V' ~select ~t ~q in
  ts, zs, _magic_complex otyp wr wi


(* Generalised Schur Decomposition *)

(* [qz ~otyp x y] runs gges on copies of the square matrices [x] and [y] and
   returns [(s, t, q, z, e)] where [e] is alpha / beta in kind [otyp]. *)
let qz
    : type a b c d.
      otyp:(c, d) kind
      -> (a, b) t
      -> (a, b) t
      -> (a, b) t * (a, b) t * (a, b) t * (a, b) t * (c, d) t
  =
 fun ~otyp x y ->
  let m, n = M.shape x in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  let u, v = M.shape y in
  Owl_exception.(check (u = v) (NOT_SQUARE [| u; v |]));
  let a = M.copy x in
  let b = M.copy y in
  let s, t, ar, ai, bt, q, z = Owl_lapacke.gges ~jobvsl:'V' ~jobvsr:'V' ~a ~b in
  let alpha = _magic_complex otyp ar ai in
  let beta = M.cast otyp bt in
  s, t, q, z, M.(alpha / beta)


(* [ordqz ~otyp ~select a b q z] reorders a generalised Schur factorisation
   with tgsen on copies of all four matrices and returns the reordered
   matrices plus alpha / beta in kind [otyp]. *)
let ordqz
    : type a b c d.
      otyp:(c, d) kind
      -> select:(int32, int32_elt) t
      -> (a, b) t
      -> (a, b) t
      -> (a, b) t
      -> (a, b) t
      -> (a, b) t * (a, b) t * (a, b) t * (a, b) t * (c, d) t
  =
 fun ~otyp ~select a b q z ->
  let a = M.copy a in
  let b = M.copy b in
  let q = M.copy q in
  let z = M.copy z in
  let a, b, ar, ai, bt, q, z = Owl_lapacke.tgsen ~select ~a ~b ~q ~z in
  let alpha = _magic_complex otyp ar ai in
  let beta = M.cast otyp bt in
  a, b, q, z, M.(alpha / beta)


(* [qzvals ~otyp x y] returns alpha / beta from ggev (no vectors) for the
   square matrices [x] and [y]. *)
let qzvals : type a b c d. otyp:(c, d) kind -> (a, b) t -> (a, b) t -> (c, d) t =
 fun ~otyp x y ->
  let m, n = M.shape x in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  let u, v = M.shape y in
  Owl_exception.(check (u = v) (NOT_SQUARE [| u; v |]));
  let a = M.copy x in
  let b = M.copy y in
  let ar, ai, bt, _, _ = Owl_lapacke.ggev ~jobvl:'N' ~jobvr:'N' ~a ~b in
  let alpha = _magic_complex otyp ar ai in
  let beta = M.cast otyp bt in
  M.(alpha / beta)


(* TODO: RQ Decomposition *)

(* Placeholder: ignores its argument and returns unit. *)
let rq _x = () [@@warning "-32"]

(* Eigenvalue problem *)

(* [eig ?permute ?scale ~otyp x] runs geevx on a copy of [x] and returns
   [(v, w)]: right eigenvectors and eigenvalues. For real inputs the complex
   eigenvectors are assembled from the real output of geevx; for complex
   inputs the LAPACK outputs are returned directly. The results are coerced
   with Obj.magic, so [otyp] is not checked against the kind of [x] here. *)
let eig
    : type a b c d.
      ?permute:bool -> ?scale:bool -> otyp:(a, b) kind -> (c, d) t -> (a, b) t * (a, b) t
  =
 fun ?(permute = true) ?(scale = true) ~otyp x ->
  let x = M.copy x in
  let balanc =
    if permute then if scale then 'B' else 'P' else if scale then 'S' else 'N'
  in
  let _a, wr, wi, _, vr, _, _, _, _, _, _ =
    Owl_lapacke.geevx ~balanc ~jobvl:'N' ~jobvr:'V' ~sense:'N' ~a:x
  in
  (* TODO: optimise the performance by writing in c *)
  (* construct eigen vectors from real wr and wi *)
  let _construct_v
      : type a b.
        (float, a) kind
        -> (Complex.t, b) kind
        -> (float, a) t
        -> (float, a) t
        -> (float, a) t
        -> (Complex.t, b) t
        -> unit
    =
   fun k0 k1 wr wi vr v ->
    let zero = Owl_const.zero (M.kind wi) in
    let _, n = M.shape v in
    let col = ref 0 in
    while !col < n do
      if M.get wi 0 !col = zero
      then
        (* zero imaginary part: the column is copied with im = 0 *)
        for row = 0 to n - 1 do
          M.set v row !col Complex.{ re = M.get vr row !col; im = 0. }
        done
      else (
        (* otherwise columns col and col + 1 of vr hold the real and imaginary
           parts; emit the vector and its conjugate, then skip one column *)
        for row = 0 to n - 1 do
          M.set v row !col Complex.{ re = M.get vr row !col; im = M.get vr row (!col + 1) };
          M.set
            v
            row
            (!col + 1)
            Complex.{ re = M.get vr row !col; im = 0. -. M.get vr row (!col + 1) }
        done;
        col := !col + 1);
      col := !col + 1
    done
  in
  (* process eigen vectors *)
  let m, n = M.shape vr in
  let v =
    match M.kind x with
    | Float32   ->
      let v = M.empty complex32 m n in
      _construct_v float32 complex32 wr wi vr v;
      Obj.magic v
    | Float64   ->
      let v = M.empty complex64 m n in
      _construct_v float64 complex64 wr wi vr v;
      Obj.magic v
    | Complex32 -> Obj.magic vr
    | Complex64 -> Obj.magic vr
    | _         -> failwith "owl_linalg_generic:eig"
  in
  (* process eigen values *)
  let w =
    match M.kind x with
    | Float32   -> M.complex float32 complex32 wr wi |> Obj.magic
    | Float64   -> M.complex float64 complex64 wr wi |> Obj.magic
    | Complex32 -> Obj.magic wr
    | Complex64 -> Obj.magic wr
    | _         -> failwith "owl_linalg_generic:eigvals"
  in
  v, w
  [@@warning "-27"]


(* [eigvals ?permute ?scale ~otyp x] is [eig] restricted to the eigenvalues
   (geevx is called with jobvr 'N'). *)
let eigvals
    : type a b c d.
      ?permute:bool -> ?scale:bool -> otyp:(a, b) kind -> (c, d) t -> (a, b) t
  =
 fun ?(permute = true) ?(scale = true) ~otyp x ->
  let x = M.copy x in
  let balanc =
    if permute then if scale then 'B' else 'P' else if scale then 'S' else 'N'
  in
  let _, wr, wi, _, _, _, _, _, _, _, _ =
    Owl_lapacke.geevx ~balanc ~jobvl:'N' ~jobvr:'N' ~sense:'N' ~a:x
  in
  let w =
    match M.kind x with
    | Float32   -> M.complex float32 complex32 wr wi |> Obj.magic
    | Float64   -> M.complex float64 complex64 wr wi |> Obj.magic
    | Complex32 -> Obj.magic wr
    | Complex64 -> Obj.magic wr
    | _         -> failwith "owl_linalg_generic:eigvals"
  in
  w
  [@@warning "-27"]


(* Hessenberg form of matrix *)

(* Dispatches to orghr (real kinds) or unghr (complex kinds). *)
let _get_hess_q : type a b. (a, b) kind -> int -> int -> (a, b) t -> (a, b) t -> (a, b) t =
 fun k ilo ihi a tau ->
  match k with
  | Float32   -> Owl_lapacke.orghr ilo ihi a tau
  | Float64   -> Owl_lapacke.orghr ilo ihi a tau
  | Complex32 -> Owl_lapacke.unghr ilo ihi a tau
  | Complex64 -> Owl_lapacke.unghr ilo ihi a tau
  | _         -> failwith "owl_linalg:_get_hess_q"


(* [hess x] reduces a copy of [x] with gehrd (ilo = 1, ihi = number of
   columns) and returns [(h, q)], [h] being the part of the result on and
   above the first sub-diagonal. *)
let hess x =
  let work = M.copy x in
  let _, n = M.shape work in
  let ilo = 1 in
  let ihi = n in
  let a, tau = Owl_lapacke.gehrd ~ilo ~ihi ~a:work in
  let h = M.triu ~k:(-1) a in
  let q = _get_hess_q (M.kind work) ilo ihi a tau in
  h, q


(* Bunch-Kaufman [Bunch1977] factorization *)

(* [bkfact ?upper ?symmetric ?rook x] factorises a copy of [x] with one of
   sytrf, hetrf, sytrf_rook or hetrf_rook, chosen by [symmetric] and [rook],
   and returns the factorised matrix and the pivot vector. *)
let bkfact ?(upper = true) ?(symmetric = true) ?(rook = false) x =
  let work = M.copy x in
  let uplo = if upper then 'U' else 'L' in
  let a, ipiv, _ret =
    match rook, symmetric with
    | true, true   -> Owl_lapacke.sytrf_rook uplo work
    | true, false  -> Owl_lapacke.hetrf_rook uplo work
    | false, true  -> Owl_lapacke.sytrf uplo work
    | false, false -> Owl_lapacke.hetrf uplo work
  in
  a, ipiv


(* Check matrix properties
*)

(* Structural predicates, delegated to Owl_matrix with the kind of [x]. *)
let is_triu x = Owl_matrix._matrix_is_triu (M.kind x) x

let is_tril x = Owl_matrix._matrix_is_tril (M.kind x) x

let is_symmetric x = Owl_matrix._matrix_is_symmetric (M.kind x) x

let is_hermitian x = Owl_matrix._matrix_is_hermitian (M.kind x) x

let is_diag x = Owl_matrix._matrix_is_diag (M.kind x) x

(* [is_posdef x] is true iff [chol x] completes without raising. *)
let is_posdef x =
  try
    ignore (chol x);
    true
  with
  | _exn -> false


(* Returns (min, max) of [v]; for complex kinds only the real parts are
   considered. The kind argument is ignored, the kind of [v] is used. *)
let _minmax_real : type a b. (a, b) kind -> (a, b) t -> float * float =
 fun _k v ->
  match M.kind v with
  | Float32   -> M.minmax' v
  | Float64   -> M.minmax' v
  | Complex32 -> M.re_c2s v |> M.minmax'
  | Complex64 -> M.re_z2d v |> M.minmax'
  | _         -> failwith "owl_linalg_generic:_minmax_real"


(* local abs function, bear with obj.magic *)
let _abs : type a b c. (a, b) kind -> (a, b) t -> (float, c) t =
 fun k x ->
  match k with
  | Float32   -> M.abs x |> Obj.magic
  | Float64   -> M.abs x |> Obj.magic
  | Complex32 -> M.abs_c2s x |> Obj.magic
  | Complex64 -> M.abs_z2d x |> Obj.magic
  | _         -> failwith "owl_linalg_generic:_abs"


(* [norm ?p x] computes a matrix norm selected by [p] (default 2.):
   - 1. / -1.             : max / min of the sum_rows of |x|
   - 2. / -2.             : largest / smallest singular value
   - infinity / neg_infinity : max / min of the sum_cols of |x|
   Any other [p] raises [Failure]. *)
let norm ?(p = 2.) x =
  let k = M.kind x in
  if p = 1.
  then x |> _abs k |> M.sum_rows |> M.max'
  else if p = -1.
  then x |> _abs k |> M.sum_rows |> M.min'
  else if p = 2.
  then x |> svdvals |> _minmax_real k |> snd
  else if p = -2.
  then x |> svdvals |> _minmax_real k |> fst
  else if p = infinity
  then x |> _abs k |> M.sum_cols |> M.max'
  else if p = neg_infinity
  then x |> _abs k |> M.sum_cols |> M.min'
  else failwith "owl_linalg_generic:norm:p=±1|±2|±inf"


(* [vecnorm ?p x] treats [x] as a flat vector. p = 1. and p = 2. use l1norm'
   and l2norm'; infinity / neg_infinity take the max / min of |x|; any other
   [p] raises |x| elementwise to [p], sums, and takes the 1/p-th power. *)
let vecnorm ?(p = 2.) x =
  let k = M.kind x in
  let re = Owl_base_dense_common._re_elt k in
  if p = 1.
  then M.l1norm' x |> re
  else if p = 2.
  then M.l2norm' x |> re
  else (
    let v = M.flatten x |> M.abs in
    if p = infinity
    then M.max' v |> re
    else if p = neg_infinity
    then M.min' v |> re
    else (
      (* [v] is the fresh result of M.abs, so updating it in place is safe *)
      M.pow_scalar_ v (Owl_base_dense_common._float_typ_elt k p);
      let total = M.sum' v |> re in
      total ** (1. /. p)))


(* [cond ?p x] is the condition number of [x].
   - p = 2. (default): ratio of the largest to the smallest singular value,
     or infinity when the largest singular value is zero.
   - p = 1. or infinity: 1 / rcond, with rcond estimated by gecon from an LU
     factorisation; [x] must be square.
   Any other [p] raises [Failure].

   gecon needs the norm of the ORIGINAL matrix. getrf factorises its argument
   in place, so the norm is taken before factorising; previously it was taken
   from the already factorised buffer. *)
let cond ?(p = 2.) x =
  if p = 2.
  then (
    let v = svdvals x in
    let minv, maxv = _minmax_real (M.kind v) v in
    if maxv = 0. then infinity else maxv /. minv)
  else if p = 1. || p = infinity
  then (
    let m, n = M.shape x in
    Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
    (* norm of the untouched input, computed before any factorisation *)
    let anorm = norm ~p x in
    (* factorise a private copy so the argument of the caller is preserved *)
    let a, _ipiv = lufact (M.copy x) in
    let _norm = if p = 1. then '1' else 'I' in
    let rcond = Owl_lapacke.gecon _norm a anorm in
    1. /. rcond)
  else failwith "owl_linalg_generic:cond:p=1|2|inf"


(* [rcond x] is the reciprocal of the 1-norm condition number of [x]. *)
let rcond x = 1. /. cond ~p:1. x

(* solve linear system of
equations *)

(* [null x] returns the transpose of the trailing rows of vt from a full SVD
   of [x]. The number of leading rows dropped is the count of singular values
   greater than max(m, n) * (largest singular value) * eps, with eps taken
   for the kind of [x]. When either dimension is zero the n x n identity is
   returned. *)
let null x =
  let eps = Owl_utils.eps (M.kind x) in
  let m, n = M.shape x in
  if m = 0 || n = 0
  then M.eye (M.kind x) n
  else (
    let _, s, vt = svd ~thin:false x in
    let s = _abs (M.kind s) s in
    let largest = M.max' s in
    let dim = float_of_int (Stdlib.max m n) in
    let kept = M.elt_greater_scalar s (dim *. largest *. eps) |> M.sum' |> int_of_float in
    let tail = M.resize ~head:false vt [| M.row_num vt - kept; M.col_num vt |] in
    M.transpose tail)


(* LAPACK transpose flag for a kind: 'T' for real kinds, 'C' for complex. *)
let _get_trans_code : type a b. (a, b) kind -> char = function
  | Float32   -> 'T'
  | Float64   -> 'T'
  | Complex32 -> 'C'
  | Complex64 -> 'C'
  | _         -> failwith "owl_linalg_generic:_get_trans_code"


(* [triangular_solve ~upper ?trans a b] calls cblas trsm (row-major, left
   side, non-unit diagonal) with the square triangular matrix [a] and a copy
   of [b], and returns that copy. With [trans], a transpose (real kinds) or
   conjugate transpose (complex kinds) flag is passed for [a]. Asserts that
   [a] is square and has as many rows as [b]. *)
let triangular_solve
    : type c d. upper:bool -> ?trans:bool -> (c, d) t -> (c, d) t -> (c, d) t
  =
 fun ~upper ?(trans = false) a b ->
  let b = M.copy b in
  let ma, _na = M.shape a in
  let mb, nb = M.shape b in
  assert (ma = mb && ma = _na);
  (* flat views handed to cblas; the result is read back through [b] *)
  let _a = M.flatten a |> Bigarray.array1_of_genarray in
  let _b = M.flatten b |> Bigarray.array1_of_genarray in
  let k = M.kind a in
  let alpha = Owl_const.one k in
  let transa =
    if not trans
    then Owl_cblas_basic.CblasNoTrans
    else (
      match k with
      | Float32   -> Owl_cblas_basic.CblasTrans
      | Float64   -> Owl_cblas_basic.CblasTrans
      | Complex32 -> Owl_cblas_basic.CblasConjTrans
      | Complex64 -> Owl_cblas_basic.CblasConjTrans
      | _         -> failwith "owl_linalg:triangular_solve")
  in
  let uplo = if upper then Owl_cblas_basic.CblasUpper else Owl_cblas_basic.CblasLower in
  Owl_cblas_basic.trsm
    Owl_cblas_basic.CblasRowMajor
    Owl_cblas_basic.CblasLeft
    uplo
    transa
    Owl_cblas_basic.CblasNonUnit
    mb
    nb
    alpha
    _a
    ma
    _b
    nb;
  b


(* TODO: add opt parameter to specify the matrix properties so that we can
choose the best solver for better performance.
*)

(* [linsolve ?trans ?typ a b] solves a linear system with coefficient matrix
   [a] and right-hand side [b]; both must have the same number of rows.
   - square [a], typ `n (default): getrf on a copy of [a], then getrs
   - square [a], typ `u / `l: [triangular_solve] with upper / lower
   - non-square [a]: gels on copies of [a] and [b]
   With [trans] the LAPACK transpose code for the kind of [a] is used. *)
let linsolve ?(trans = false) ?(typ = `n) a b =
  let ma, na = M.shape a in
  let mb, _nb = M.shape b in
  assert (ma = mb);
  let trans_ = if trans then _get_trans_code (M.kind a) else 'N' in
  if ma <> na
  then (
    let a = M.copy a in
    let b = M.copy b in
    let _, x, _ = Owl_lapacke.gels trans_ a b in
    x)
  else (
    match typ with
    (* normal *)
    | `n ->
      let a = M.copy a in
      let b = M.copy b in
      let a, ipiv = lufact a in
      Owl_lapacke.getrs trans_ a ipiv b
    (* upper triangular *)
    | `u -> triangular_solve ~trans ~upper:true a b
    (* lower triangular *)
    | `l -> triangular_solve ~trans ~upper:false a b)


(* [linreg x y] fits y = a + b * x and returns [(a, b)], where b is
   cov(x, y) / var(x) and a is mean(y) - b * mean(x). [x] and [y] are reshaped
   to column vectors and must have the same number of elements, otherwise
   INVALID_ARGUMENT is reported through Owl_exception.verify. *)
let linreg x y =
  let nx = M.numel x in
  let ny = M.numel y in
  let error () =
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "x length is %i, and y length is %i. However, they must be the same."
         nx
         ny)
  in
  Owl_exception.verify (nx = ny) error;
  let x = M.reshape x [| nx; 1 |] in
  let y = M.reshape y [| ny; 1 |] in
  let k = M.kind x in
  let covariance = M.get (M.cov ~a:x ~b:y) 0 1 in
  let variance = M.get (M.var ~axis:0 x) 0 0 in
  let slope = Owl_base_dense_common._div_elt k covariance variance in
  let offset = Owl_base_dense_common._mul_elt k slope (M.mean' x) in
  let intercept = Owl_base_dense_common._sub_elt k (M.mean' y) offset in
  intercept, slope


(* [pinv ?tol x] builds a pseudo-inverse from the thin SVD of [x] as
   v * diagm (reci_tol s) * u^H. The default tolerance is the Float32 epsilon
   times max(m, n) times the largest singular value. *)
let pinv ?tol x =
  let u, s, vt = svd x in
  let m, n = M.shape x in
  let t =
    match tol with
    | Some given -> given
    | None       ->
      (* by default using float32 eps *)
      let eps = Owl_utils.eps Float32 in
      let dim = float_of_int (Stdlib.max m n) in
      let largest = _minmax_real (M.kind x) s |> snd in
      eps *. dim *. largest
  in
  let tol = Owl_base_dense_common._float_typ_elt (M.kind x) t in
  let s' = M.(reci_tol ~tol s |> diagm) in
  let ut = M.ctranspose u in
  let v = M.ctranspose vt in
  M.(v *@ s' *@ ut)


(* [sylvester a b c] reduces [a] and [b] to Schur form, solves the resulting
   triangular equation with trsyl, transforms the solution back, and divides
   it by the scale factor returned by trsyl. *)
let sylvester a b c =
  let ra, qa = schur_tz a in
  let rb, qb = schur_tz b in
  let d = M.(ctranspose qa *@ (c *@ qb)) in
  let y, s = Owl_lapacke.trsyl 'N' 'N' 1 ra rb d in
  let z = M.(qa *@ (y *@ ctranspose qb)) in
  let inv_scale = Owl_base_dense_common._float_typ_elt (M.kind c) (1. /. s) in
  M.mul_scalar_ z inv_scale;
  z


(* [lyapunov a c] follows the same scheme as [sylvester] with a single Schur
   factorisation of [a], used for both triangular operands; the second one is
   passed to trsyl with the transpose code for the kind of [c]. *)
let lyapunov a c =
  let r, q = schur_tz a in
  let d = M.(ctranspose q *@ (c *@ q)) in
  let tb = _get_trans_code (M.kind c) in
  let y, s = Owl_lapacke.trsyl 'N' tb 1 r r d in
  let z = M.(q *@ (y *@ ctranspose q)) in
  let inv_scale = Owl_base_dense_common._float_typ_elt (M.kind c) (1. /. s) in
  M.mul_scalar_ z inv_scale;
  z


(* Direct solver: solves (I - kron a (conj a)) * vec x = vec q with
   [linsolve] and reshapes the solution to n x n, n being the row count of
   [q]. *)
let _discrete_lyapunov_direct a q =
  let n = M.row_num q in
  let lhs = M.kron a M.(conj a) in
  let lhs = M.(eye (kind a) (row_num lhs) - lhs) in
  let rhs = M.(reshape q [| -1; 1 |]) in
  M.reshape (linsolve lhs rhs) [| n; n |]


(* bilinear transform reference
* https://old.control.ee.ethz.ch/info/people/mansour/pdf/168--1993-Schur-Cohn,%20Nour%20Eldin-Markov%20Matrices%20and%20the%20Controllability%20Gramians--.pdf *)

(* Bilinear solver: with inv_al = inv (a - I), forms
   a' = inv_al * (a + I) and q' = 2 * inv_al * q * transpose inv_al, then
   calls [lyapunov a' (neg q')]. *)
let _discrete_lyapunov_bilinear a q =
  let n = M.row_num a in
  let identity = M.(eye (kind a) n) in
  let inv_al = inv M.(a - identity) in
  let a' = M.(inv_al *@ (a + identity)) in
  let q' = M.(inv_al *@ q *@ transpose inv_al) in
  let two = Owl_base_dense_common._float_typ_elt (M.kind a) 2. in
  M.mul_scalar_ q' two;
  lyapunov a' M.(neg q')


(* [discrete_lyapunov ?solver a q] dispatches to the direct or the bilinear
   solver. With `default the direct solver is used when [a] has at most 10
   rows and the bilinear one otherwise. *)
let discrete_lyapunov ?(solver = `default) a q =
  match solver with
  | `default ->
    if M.(row_num a) <= 10
    then _discrete_lyapunov_direct a q
    else _discrete_lyapunov_bilinear a q
  | `bilinear -> _discrete_lyapunov_bilinear a q
  | `direct -> _discrete_lyapunov_direct a q


(* Checks that, for [b] of shape n x m, [a] and [q] are n x n and [r] is
   m x m; raises [Failure] with a message prefixed by [label] otherwise. *)
let _check_are_shape ~label a b q r =
  let n, m = M.shape b in
  let an, am = M.shape a in
  let qn, qm = M.shape q in
  let rn, rm = M.shape r in
  let a_ok = (an, am) = (n, n) in
  let q_ok = (qn, qm) = (n, n) in
  let r_ok = (rn, rm) = (m, m) in
  if not (a_ok && q_ok && r_ok)
  then
    failwith
      (Printf.sprintf
         "%s dims mismatch: a (%i, %i), b (%i, %i), q (%i, %i), r (%i, %i)"
         label
         an
         am
         n
         m
         qn
         qm
         rn
         rm)


(* Input validation for [care]: shapes, [q] hermitian, [r] hermitian (skipped
   when [diag_r]) and [r] positive definite (for [diag_r]: positive
   diagonal). *)
let _validate_care ~diag_r a b q r =
  _check_are_shape ~label:"CARE" a b q r;
  (* check that q is Hermitian *)
  if not (is_hermitian q) then failwith "CARE: q is not hermitian";
  (* check that r is Hermitian *)
  if (not diag_r) && not (is_hermitian r) then failwith "CARE: r is not hermitian";
  (* check that r is posdef *)
  let r_posdef = if diag_r then M.(is_positive (diag r)) else is_posdef r in
  if not r_posdef then failwith "CARE: r is not posdef"


(* [care ?diag_r a b q r] solves a continuous-time algebraic Riccati equation.
   It assembles the block matrix [[a, -g], [-q, -a^T]] with g = b r^-1 b^T,
   takes its Schur form, reorders it so that eigenvalues with negative real
   part are selected, and recovers the solution from the left half of the
   Schur vectors. With [diag_r] only the diagonal of [r] is used. Raises
   [Failure] on invalid input, when no finite solution is found, or when the
   symmetry check on the Schur vectors fails. *)
let care ?(diag_r = false) a b q r =
  _validate_care ~diag_r a b q r;
  let g =
    if not diag_r
    then M.(b *@ inv r *@ transpose b)
    else (
      let r = M.diag r in
      let inv_r = M.reci r in
      M.(b * inv_r *@ transpose b))
  in
  let z = M.(concat_vh [| [| a; neg g |]; [| neg q; neg (transpose a) |] |]) in
  let t, u, wr, _ = Owl_lapacke.gees ~jobvs:'V' ~a:z in
  let select = M.(zeros int32 (row_num wr) (col_num wr)) in
  M.iteri_2d (fun i j re -> if re < 0. then M.set select i j 1l) wr;
  (* the return value is dropped: [u] is read again below *)
  ignore (Owl_lapacke.trsen ~job:'V' ~compq:'V' ~select ~t ~q:u);
  let m, n = M.shape u in
  let half_m = m / 2 in
  let half_n = n / 2 in
  let u00 = M.get_slice [ [ 0; half_m - 1 ]; [ 0; half_n - 1 ] ] u in
  let u10 = M.get_slice [ [ half_m; m - 1 ]; [ 0; half_n - 1 ] ] u in
  (* check solution *)
  let a, ipiv = lufact M.(copy u00) in
  if rcond a < Owl_utils.(eps (M.kind a))
  then failwith "CARE: failed to find a finite solution";
  (* check that the solution is symmetric: slightly less stringent condition than is_symmetric *)
  let u_sym = M.(transpose u00 *@ u10) in
  let n_u_sym = norm u_sym in
  let u_sym = M.(u_sym - transpose u_sym) in
  let thres = max (0.1 *. n_u_sym) (Owl_utils.eps M.(kind u)) in
  if norm u_sym > thres
  then
    failwith
      "CARE: associated symplectic pencil has eigenvalues too close to the imaginary axis";
  let x = Owl_lapacke.getrs (_get_trans_code M.(kind a)) a ipiv M.(transpose u10) in
  (* symmetrise again for numerical stability *)
  M.(0.5 $* x + transpose x)


(* Input validation for [dare]: shapes, [q] hermitian and [r] hermitian
   (skipped when [diag_r]). *)
let _validate_dare ~diag_r a b q r =
  _check_are_shape ~label:"DARE" a b q r;
  (* check that q is Hermitian *)
  if not (is_hermitian q) then failwith "DARE: q is not hermitian";
  (* check that r is Hermitian *)
  if (not diag_r) && not (is_hermitian r) then failwith "DARE: r is not hermitian"


(* [dare ?diag_r a b q r] solves a discrete-time algebraic Riccati equation.
   Same scheme as [care], with the block matrix built from
   c = transpose (inv a) and with eigenvalues of modulus <= 1 selected.
   Any exception raised while inverting [a] is reported as a [Failure]
   saying that singular A is not supported. *)
let dare ?(diag_r = false) a b q r =
  _validate_dare ~diag_r a b q r;
  let g =
    if not diag_r
    then M.(b *@ inv r *@ transpose b)
    else (
      let r = M.diag r in
      let inv_r = M.reci r in
      M.(b * inv_r *@ transpose b))
  in
  let c =
    try M.transpose (inv a) with
    | _ -> failwith "DARE: currently does not support singular A"
  in
  let z = M.(concat_vh [| [| a + (g *@ c *@ q); neg g *@ c |]; [| neg c *@ q; c |] |]) in
  let t, u, wr, wi = Owl_lapacke.gees ~jobvs:'V' ~a:z in
  let select = M.(zeros int32 (row_num wr) (col_num wr)) in
  M.iter2i_2d
    (fun i j re im -> if Complex.(norm { re; im }) <= 1. then M.set select i j 1l)
    wr
    wi;
  (* the return value is dropped: [u] is read again below *)
  ignore (Owl_lapacke.trsen ~job:'V' ~compq:'V' ~select ~t ~q:u);
  let m, n = M.shape u in
  let half_m = m / 2 in
  let half_n = n / 2 in
  let u00 = M.get_slice [ [ 0; half_m - 1 ]; [ 0; half_n - 1 ] ] u in
  let u10 = M.get_slice [ [ half_m; m - 1 ]; [ 0; half_n - 1 ] ] u in
  (* check solution *)
  let a, ipiv = lufact M.(copy u00) in
  if rcond a < Owl_utils.(eps (M.kind a))
  then failwith "DARE: failed to find a finite solution";
  (* check that the solution is symmetric: slightly less stringent condition than is_symmetric *)
  let u_sym = M.(transpose u00 *@ u10) in
  let n_u_sym = norm u_sym in
  let u_sym = M.(u_sym - transpose u_sym) in
  let thres = max (0.1 *. n_u_sym) (Owl_utils.eps M.(kind u)) in
  if norm u_sym > thres
  then
    failwith "DARE: associated symplectic pencil has eigenvalues too close to the unit circle";
  let x = Owl_lapacke.getrs (_get_trans_code M.(kind a)) a ipiv M.(transpose u10) in
  (* symmetrise again for numerical stability *)
  M.(0.5 $* x + transpose x)


(* helper functions
*)

(* [peakflops ?n ()] times one float64 n x n by n x n gemm call and returns
   2 * n^3 divided by the elapsed wall-clock seconds. *)
let peakflops ?(n = 2000) () =
  let x = M.ones float64 n n |> M.flatten |> array1_of_genarray in
  let z = M.ones float64 n n |> M.flatten |> array1_of_genarray in
  let layout = Owl_cblas_basic.CblasRowMajor in
  let transa = Owl_cblas_basic.CblasNoTrans in
  let transb = Owl_cblas_basic.CblasNoTrans in
  let t0 = Unix.gettimeofday () in
  Owl_cblas_basic.gemm layout transa transb n n n 1.0 x n x n 0.0 z n;
  let t1 = Unix.gettimeofday () in
  2. *. (float_of_int n ** 3.) /. (t1 -. t0)


(* Matrix functions *)

(* [mpow x r] raises the square matrix [x] to the integer power [r], given as
   a float. r = 0. gives the identity; r < 0. gives the power of [inv x].
   Raises [Failure] when [r] has a fractional part, and fails the NOT_SQUARE
   check when [x] is not square.

   The odd-exponent step multiplies by the running base. It used to multiply
   by [x], which is only the base for positive [r], so negative odd powers
   were wrong (for example r = -3. produced [inv x]). *)
let mpow x r =
  let frac_part, _ = Stdlib.modf r in
  if frac_part <> 0. then failwith "mpow: fractional powers not implemented";
  let m, n = M.shape x in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  (* integer matrix powers using floats, by repeated squaring of [base] *)
  let rec _mpow base s =
    if s = 1.
    then base
    else if mod_float s 2. = 0. (* exponent is even? *)
    then even_mpow base s
    else M.dot base (even_mpow base (s -. 1.))
  and even_mpow base s =
    let half = _mpow base (s /. 2.) in
    M.dot half half
  in
  (* r is equal to an integer: *)
  if r = 0.0
  then M.(eye (kind x)) n
  else if r > 0.0
  then _mpow x r
  else _mpow (inv x) (-.r)


(* DEBUG: initial expm implemented with eig, obsoleted *)
let expm_eig : type a b c d. otyp:(c, d) kind -> (a, b) t -> (c, d) t =
 fun ~otyp x ->
  let m, n = M.shape x in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  let v, w = eig ~otyp x in
  let vi = inv v in
  let u = M.(exp w |> diagm) in
  M.(dot (dot v u) vi)
  [@@warning "-32"]


(* [expm x] is the matrix exponential of the square matrix [x], computed with
   Padé approximants. The approximant order is chosen from the 1-norm of [x];
   for norms above the last threshold [x] is first divided by a power of two
   and the result is squared the same number of times. A 1 x 1 input is
   handled by elementwise exp. *)
let expm x =
  let m, n = M.shape x in
  Owl_exception.(check (m = n) (NOT_SQUARE [| m; n |]));
  (* trivial case *)
  if M.shape x = (1, 1)
  then M.exp x
  else (
    (* TODO: use gebal to balance to improve accuracy, refer to Julia's impl *)
    let xe = M.(eye (kind x) (row_num x)) in
    let norm_x = norm ~p:1. x in
    let to_elt = Owl_base_dense_common._float_typ_elt (M.kind x) in
    (* for small norm, use lower order Padé-approximation *)
    if norm_x <= 2.097847961257068
    then (
      let c =
        Array.map
          to_elt
          (if norm_x > 0.9504178996162932
          then
            [| 17643225600.
             ; 8821612800.
             ; 2075673600.
             ; 302702400.
             ; 30270240.
             ; 2162160.
             ; 110880.
             ; 3960.
             ; 90.
             ; 1.
            |]
          else if norm_x > 0.2539398330063230
          then [| 17297280.; 8648640.; 1995840.; 277200.; 25200.; 1512.; 56.; 1. |]
          else if norm_x > 0.01495585217958292
          then [| 30240.; 15120.; 3360.; 420.; 30.; 1. |]
          else [| 120.; 60.; 12.; 1. |])
      in
      let x2 = M.dot x x in
      (* [p] walks through the even powers of x; u and v accumulate the terms
         with odd and even coefficient indices *)
      let p = ref M.(copy xe) in
      let u = M.mul_scalar !p c.(1) in
      let v = M.mul_scalar !p c.(0) in
      for i = 1 to Array.((length c / 2) - 1) do
        let j = 2 * i in
        let k = j + 1 in
        p := M.dot !p x2;
        M.(add_ ~out:u u (mul_scalar !p c.(k)));
        M.(add_ ~out:v v (mul_scalar !p c.(j)))
      done;
      let u = M.dot x u in
      let a = M.sub v u in
      let b = M.add v u in
      (* solve (v - u) * result = (v + u); the solution is read from [b] *)
      Owl_lapacke.gesv a b |> ignore;
      b)
    else (
      (* for larger norm, Padé-13 approximation *)
      let s = Owl_maths.log2 (norm_x /. 5.4) in
      let t = ceil s in
      let x = if s > 0. then to_elt (2. ** t) |> M.div_scalar x else x in
      let c =
        Array.map
          to_elt
          [| 64764752532480000.
           ; 32382376266240000.
           ; 7771770303897600.
           ; 1187353796428800.
           ; 129060195264000.
           ; 10559470521600.
           ; 670442572800.
           ; 33522128640.
           ; 1323241920.
           ; 40840800.
           ; 960960.
           ; 16380.
           ; 182.
           ; 1.
          |]
      in
      let x2 = M.dot x x in
      let x4 = M.dot x2 x2 in
      let x6 = M.dot x2 x4 in
      let u =
        M.(
          x
          *@ ((x6 *@ ((x6 *$ c.(13)) + (x4 *$ c.(11)) + (x2 *$ c.(9))))
             + (x6 *$ c.(7))
             + (x4 *$ c.(5))
             + (x2 *$ c.(3))
             + (xe *$ c.(1))))
      in
      let v =
        M.(
          (x6 *@ ((x6 *$ c.(12)) + (x4 *$ c.(10)) + (x2 *$ c.(8))))
          + (x6 *$ c.(6))
          + (x4 *$ c.(4))
          + (x2 *$ c.(2))
          + (xe *$ c.(0)))
      in
      let a = M.sub v u in
      let b = M.add v u in
      Owl_lapacke.gesv a b |> ignore;
      (* undo the scaling by repeated squaring *)
      let x = ref b in
      if s > 0.
      then
        for _i = 1 to int_of_float t do
          x := M.dot !x !x
        done;
      !x))


(* Matrix sine. Real kinds: imaginary part of expm (i x) computed in the
   matching complex kind. Complex kinds: -0.5i * (expm (i x) - expm (-i x)). *)
let _sinm : type a b. (a, b) kind -> (a, b) t -> (a, b) t =
 fun k x ->
  match k with
  | Float32   ->
    let a = Complex.{ re = 0.; im = 1. } in
    let x = M.cast_s2c x in
    M.(expm (a $* x) |> im_c2s)
  | Float64   ->
    let a = Complex.{ re = 0.; im = 1. } in
    let x = M.cast_d2z x in
    M.(expm (a $* x) |> im_z2d)
  | Complex32 ->
    let a = Complex.{ re = 0.; im = -0.5 } in
    let b = Complex.{ re = 0.; im = 1. } in
    let c = Complex.{ re = 0.; im = -1. } in
    M.(a $* expm (b $* x) - expm (c $* x))
  | Complex64 ->
    let a = Complex.{ re = 0.; im = -0.5 } in
    let b = Complex.{ re = 0.; im = 1. } in
    let c = Complex.{ re = 0.; im = -1. } in
    M.(a $* expm (b $* x) - expm (c $* x))
  | _         -> failwith "_sinm: unsupported operation"


let sinm x = _sinm (M.kind x) x

(* Matrix cosine. Real kinds: real part of expm (i x). Complex kinds:
   0.5 * (expm (i x) + expm (-i x)). *)
let _cosm : type a b. (a, b) kind -> (a, b) t -> (a, b) t =
 fun k x ->
  match k with
  | Float32   ->
    let a = Complex.{ re = 0.; im = 1. } in
    let x = M.cast_s2c x in
    M.(expm (a $* x) |> re_c2s)
  | Float64   ->
    let a = Complex.{ re = 0.; im = 1. } in
    let x = M.cast_d2z x in
    M.(expm (a $* x) |> re_z2d)
  | Complex32 ->
    let a = Complex.{ re = 0.5; im = 0. } in
    let b = Complex.{ re = 0.; im = 1. } in
    let c = Complex.{ re = 0.; im = -1. } in
    M.(a $* expm (b $* x) + expm (c $* x))
  | Complex64 ->
    let a = Complex.{ re = 0.5; im = 0. } in
    let b = Complex.{ re = 0.; im = 1. } in
    let c = Complex.{ re = 0.; im = -1. } in
    M.(a $* expm (b $* x) + expm (c $* x))
  | _         -> failwith "_cosm: unsupported operation"


let cosm x = _cosm (M.kind x) x

(* Matrix sine and cosine from shared exponentials, returned as (sin, cos).
   In the complex branches both expm (i x) and expm (-i x) are taken of the
   input [x]. The first result used to be bound to the name [x] as well, so
   the second exponential was computed from expm (i x) instead of [x]. *)
let _sincosm : type a b. (a, b) kind -> (a, b) t -> (a, b) t * (a, b) t =
 fun k x ->
  match k with
  | Float32   ->
    let a = Complex.{ re = 0.; im = 1. } in
    let x = M.cast_s2c x in
    let y = M.(expm (a $* x)) in
    M.(im_c2s y, re_c2s y)
  | Float64   ->
    let a = Complex.{ re = 0.; im = 1. } in
    let x = M.cast_d2z x in
    let y = M.(expm (a $* x)) in
    M.(im_z2d y, re_z2d y)
  | Complex32 ->
    let b = Complex.{ re = 0.; im = 1. } in
    let c = Complex.{ re = 0.; im = -1. } in
    let e_pos = M.(expm (b $* x)) in
    let e_neg = M.(expm (c $* x)) in
    let _sin = M.(Complex.{ re = 0.; im = -0.5 } $* e_pos - e_neg) in
    let _cos = M.(Complex.{ re = 0.5; im = 0. } $* e_pos + e_neg) in
    _sin, _cos
  | Complex64 ->
    let b = Complex.{ re = 0.; im = 1. } in
    let c = Complex.{ re = 0.; im = -1. } in
    let e_pos = M.(expm (b $* x)) in
    let e_neg = M.(expm (c $* x)) in
    let _sin = M.(Complex.{ re = 0.; im = -0.5 } $* e_pos - e_neg) in
    let _cos = M.(Complex.{ re = 0.5; im = 0. } $* e_pos + e_neg) in
    _sin, _cos
  | _         -> failwith "_sincosm: unsupported operation"


let sincosm x = _sincosm (M.kind x) x

(* [tanm x] solves cos * t = sin with gesv and returns the buffer that held
   sin. *)
let tanm x =
  let s, c = sincosm x in
  Owl_lapacke.gesv c s |> ignore;
  s


(* [sinhm x] is 0.5 * (expm x - expm (-x)). *)
let sinhm x =
  let a = Owl_base_dense_common._float_typ_elt (M.kind x) 0.5 in
  M.(a $* expm x - expm (neg x))


(* [coshm x] is 0.5 * (expm x + expm (-x)). *)
let coshm x =
  let a = Owl_base_dense_common._float_typ_elt (M.kind x) 0.5 in
  M.(a $* expm x + expm (neg x))


(* [sinhcoshm x] returns (sinhm x, coshm x) from one pair of exponentials. *)
let sinhcoshm x =
  let a = Owl_base_dense_common._float_typ_elt (M.kind x) 0.5 in
  let b = expm x in
  let c = expm (M.neg x) in
  M.(a $* b - c), M.(a $* b + c)


(* [tanhm x] solves cosh * t = sinh with gesv and returns the buffer that held
   sinh. *)
let tanhm x =
  let s, c = sinhcoshm x in
  Owl_lapacke.gesv c s |> ignore;
  s


(* TODO *)
let logm _x = failwith "logm: not implemented" [@@warning "-32"]

(* TODO *)
let sqrtm _x = failwith "sqrtm: not implemented" [@@warning "-32"]

(* ends here *)