# 1 "src/base/linalg/owl_base_linalg_generic.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

type ('a, 'b) t = ('a, 'b) Owl_base_dense_ndarray_generic.t

module M = Owl_base_dense_ndarray_generic

(* Check matrix properties *)

(* The predicates below stop at the first offending entry by raising [Exit].
   An explicit test is used rather than [assert], so the checks still work
   when the library is compiled with [-noassert]. *)

(* [is_triu x] is [true] iff every entry of the 2D array [x] strictly below
   the main diagonal (row index > column index) equals zero. All [m] rows are
   scanned, so a tall matrix with a non-zero entry in a row >= [n] is
   correctly rejected. *)
let is_triu x =
  let shp = M.shape x in
  let m, n = shp.(0), shp.(1) in
  let _a0 = Owl_const.zero (M.kind x) in
  try
    for i = 1 to m - 1 do
      (* below-diagonal columns of row i are 0 .. min(i, n) - 1 *)
      for j = 0 to Stdlib.min i n - 1 do
        if M.get x [| i; j |] <> _a0 then raise Exit
      done
    done;
    true
  with
  | Exit -> false


(* [is_tril x] is [true] iff every entry of the 2D array [x] strictly above
   the main diagonal (column index > row index) equals zero. All [n] columns
   are scanned, so a wide matrix with a non-zero entry in a column >= [m] is
   correctly rejected. *)
let is_tril x =
  let shp = M.shape x in
  let m, n = shp.(0), shp.(1) in
  let _a0 = Owl_const.zero (M.kind x) in
  try
    (* rows at or beyond column count n have no above-diagonal entries *)
    for i = 0 to Stdlib.min m n - 1 do
      for j = i + 1 to n - 1 do
        if M.get x [| i; j |] <> _a0 then raise Exit
      done
    done;
    true
  with
  | Exit -> false


(* [is_symmetric x] is [true] iff [x] is square and every entry equals its
   mirror across the main diagonal. *)
let is_symmetric x =
  let shp = M.shape x in
  let m, n = shp.(0), shp.(1) in
  if m <> n
  then false
  else (
    try
      for i = 0 to n - 1 do
        for j = i + 1 to n - 1 do
          if M.get x [| j; i |] <> M.get x [| i; j |] then raise Exit
        done
      done;
      true
    with
    | Exit -> false)


(* [is_hermitian x] is [true] iff the complex matrix [x] is square and every
   entry equals the conjugate of its mirror across the main diagonal. *)
let is_hermitian x =
  let shp = M.shape x in
  let m, n = shp.(0), shp.(1) in
  if m <> n
  then false
  else (
    try
      for i = 0 to n - 1 do
        (* j starts at i so diagonal entries are checked too: each must equal
           its own conjugate, i.e. have a zero imaginary part *)
        for j = i to n - 1 do
          if M.get x [| j; i |] <> Complex.conj (M.get x [| i; j |])
          then raise Exit
        done
      done;
      true
    with
    | Exit -> false)


(* [is_diag x] is [true] iff [x] is both upper and lower triangular. *)
let is_diag x = is_triu x && is_tril x

(* Raise [Invalid_argument] unless [dims] describes a 2D array. *)
let _check_is_matrix dims =
  if Array.length dims <> 2
  then raise (Invalid_argument "The given NDarray is not a matrix!")


(* ======= WARNING: the linalg functions below are experimental. ======= *)
(* ========= Corner cases etc. are not sufficiently tested. ============ *)

(* Linear equation solution by Gauss-Jordan elimination.
* Input matrix: a[n][n], b[n][m];
* Output: ``ainv``, the inverse of a; ``x``, such that ax = b.
* TODO: Extend to multiple types: double, complex; unify with existing owl
* structures e.g. naming.
* Test: https://github.com/scipy/scipy/blob/master/scipy/linalg/tests/test_basic.py#L496 *)

(* Implementation notes: Gauss-Jordan elimination with full pivoting. [a] and
   [b] are copied first, so the arguments are left untouched; the returned
   pair is (inverse of a, solution x). Only float element kinds are supported
   (the code uses [abs_float] and float arithmetic).
   Raises [Invalid_argument] if [a] or [b] is not a matrix, if [a] is not
   square, or if [b] does not have as many rows as [a]; raises
   [Owl_exception.SINGULAR] when the chosen pivot is zero. *)
let linsolve_gauss a b =
  let dims_a, dims_b = M.shape a, M.shape b in
  _check_is_matrix dims_a;
  _check_is_matrix dims_b;
  let n = dims_a.(0) in
  let m = dims_b.(1) in
  if dims_a.(1) <> n
  then raise (Invalid_argument "linsolve_gauss: matrix a must be square");
  if dims_b.(0) <> n
  then
    raise (Invalid_argument "linsolve_gauss: a and b must have the same number of rows");
  let a = M.copy a in
  let b = M.copy b in
  (* exchange entry (r1, c1) with entry (r2, c2) of x *)
  let swap x r1 c1 r2 c2 =
    let u = M.get x [| r1; c1 |] in
    M.set x [| r1; c1 |] (M.get x [| r2; c2 |]);
    M.set x [| r2; c2 |] u
  in
  (* pivot position found by the search; kept across iterations *)
  let irow = ref 0 in
  let icol = ref 0 in
  (* row/column of the pivot used at each step, for the final unscrambling *)
  let indxr = Array.make n 0 in
  let indxc = Array.make n 0 in
  (* ipiv.(k) counts how often column k has been chosen as pivot column *)
  let ipiv = Array.make n 0 in
  (* main loop over the columns to be reduced *)
  for i = 0 to n - 1 do
    (* search the not-yet-pivoted rows and columns for the largest |a_jk| *)
    let big = ref 0.0 in
    for j = 0 to n - 1 do
      if ipiv.(j) <> 1
      then
        for k = 0 to n - 1 do
          if ipiv.(k) = 0
          then (
            let v = abs_float (M.get a [| j; k |]) in
            if v >= !big
            then (
              big := v;
              irow := j;
              icol := k))
        done
    done;
    let r = !irow
    and c = !icol in
    ipiv.(c) <- ipiv.(c) + 1;
    (* bring the pivot onto the diagonal by interchanging rows r and c *)
    if r <> c
    then (
      for l = 0 to n - 1 do
        swap a r l c l
      done;
      for l = 0 to m - 1 do
        swap b r l c l
      done);
    indxr.(i) <- r;
    indxc.(i) <- c;
    let p = M.get a [| c; c |] in
    if p = 0.0 then raise Owl_exception.SINGULAR;
    let pivinv = 1.0 /. p in
    M.set a [| c; c |] 1.0;
    for l = 0 to n - 1 do
      M.set a [| c; l |] (M.get a [| c; l |] *. pivinv)
    done;
    for l = 0 to m - 1 do
      M.set b [| c; l |] (M.get b [| c; l |] *. pivinv)
    done;
    (* eliminate column c from every other row *)
    for ll = 0 to n - 1 do
      if ll <> c
      then (
        let dum = M.get a [| ll; c |] in
        M.set a [| ll; c |] 0.0;
        for l = 0 to n - 1 do
          M.set a [| ll; l |] (M.get a [| ll; l |] -. (M.get a [| c; l |] *. dum))
        done;
        for l = 0 to m - 1 do
          M.set b [| ll; l |] (M.get b [| ll; l |] -. (M.get b [| c; l |] *. dum))
        done)
    done
  done;
  (* undo the row interchanges by swapping the matching columns of the
     inverse, most recent interchange first *)
  for l = n - 1 downto 0 do
    if indxr.(l) <> indxc.(l)
    then
      for k = 0 to n - 1 do
        swap a k indxr.(l) k indxc.(l)
      done
  done;
  a, b


(* LU decomposition.
* Input matrix: a[n][n]; returns L and U packed into one matrix, the row permutation vector, and the permutation sign (1.0 or -1.0).
* Test: https://github.com/scipy/scipy/blob/master/scipy/linalg/tests/test_decomp.py
*)

(* Crout-style LU with implicit row scaling and partial pivoting.
   Returns [(lu, indx, d)]: [lu] holds the unit-lower factor L (below the
   diagonal) and U (on and above it); [indx.(k)] is the row exchanged with row
   k at step k; [d] is +1.0 or -1.0 depending on the number of exchanges.
   Raises [Owl_exception.SINGULAR] if some row is entirely zero. An exactly
   zero pivot is replaced by a tiny value so the factorisation can finish. *)
let _lu_base a =
  let k = M.kind a in
  let _abs = Owl_base_dense_common._abs_elt k in
  let _mul = Owl_base_dense_common._mul_elt k in
  let _div = Owl_base_dense_common._div_elt k in
  let _sub = Owl_base_dense_common._sub_elt k in
  let _flt = Owl_base_dense_common._float_typ_elt k in
  let _zero = Owl_const.zero k in
  let _one = Owl_const.one k in
  let lu = M.copy a in
  let n = (M.shape a).(0) in
  let m = (M.shape a).(1) in
  assert (n = m);
  let indx = Array.make n 0 in
  (* implicit scaling of each row: vv.(i) = 1 / max_j |a_ij| *)
  let vv = Array.make n _zero in
  let tiny = _flt 1.0e-40 in
  (* sign of the row permutation, flipped on every interchange *)
  let d = ref 1.0 in
  for i = 0 to n - 1 do
    let big = ref _zero in
    for j = 0 to n - 1 do
      let temp = _abs (M.get lu [| i; j |]) in
      if temp > !big then big := temp
    done;
    if !big = _zero then raise Owl_exception.SINGULAR;
    vv.(i) <- _div _one !big
  done;
  for col = 0 to n - 1 do
    (* choose the row with the largest scaled magnitude in this column. The
       search starts from the diagonal row, so a column that is all zero at
       and below the diagonal keeps its own row instead of reusing the pivot
       row of an earlier column. *)
    let big = ref _zero in
    let imax = ref col in
    for i = col to n - 1 do
      let temp = _mul (_abs (M.get lu [| i; col |])) vv.(i) in
      if temp > !big
      then (
        big := temp;
        imax := i)
    done;
    let imax = !imax in
    (* interchange rows *)
    if imax <> col
    then (
      for j = 0 to n - 1 do
        let t = M.get lu [| imax; j |] in
        M.set lu [| imax; j |] (M.get lu [| col; j |]);
        M.set lu [| col; j |] t
      done;
      d := -. !d;
      vv.(imax) <- vv.(col));
    indx.(col) <- imax;
    if M.get lu [| col; col |] = _zero then M.set lu [| col; col |] tiny;
    (* the pivot is not modified by the elimination below *)
    let pivot = M.get lu [| col; col |] in
    for i = col + 1 to n - 1 do
      let factor = _div (M.get lu [| i; col |]) pivot in
      M.set lu [| i; col |] factor;
      for j = col + 1 to n - 1 do
        M.set lu [| i; j |] (_sub (M.get lu [| i; j |]) (_mul factor (M.get lu [| col; j |])))
      done
    done
  done;
  lu, indx, !d


(* LU decomposition, return L, U, and permutation vector *)
(* [lu a] splits the packed factors of [_lu_base a]: L gets a unit diagonal
   plus the strict lower triangle, which is then zeroed in the packed matrix
   to leave U. Any square matrix size accepted by [_lu_base] works,
   including 1x1. *)
let lu a =
  let k = M.kind a in
  let _zero = Owl_const.zero k in
  let lu, indx, _ = _lu_base a in
  let n = (M.shape lu).(0) in
  let l = M.eye k n in
  for r = 1 to n - 1 do
    for c = 0 to r - 1 do
      M.set l [| r; c |] (M.get lu [| r; c |]);
      M.set lu [| r; c |] _zero
    done
  done;
  l, lu, indx


(* Solve [a x = b] for a vector [b] using the LU factors of [a].
   Raises [Failure] if the length of [b] differs from the order of [a]. *)
let _lu_solve_vec a b =
  let _k = M.kind a in
  let _mul = Owl_base_dense_common._mul_elt _k in
  let _div = Owl_base_dense_common._div_elt _k in
  let _sub = Owl_base_dense_common._sub_elt _k in
  let _zero = Owl_const.zero _k in
  assert (Array.length (M.shape b) = 1);
  let n = (M.shape a).(0) in
  if (M.shape b).(0) <> n then failwith "LUdcmp::solve bad sizes";
  let x = M.copy b in
  let lu, indx, _ = _lu_base a in
  (* forward substitution with L, applying the row permutation on the fly.
     [first] is 1 + index of the first non-zero entry met (0 = none yet);
     entries before it are zero, so they are skipped in the inner sum. *)
  let first = ref 0 in
  for i = 0 to n - 1 do
    let ip = indx.(i) in
    let sum = ref (M.get x [| ip |]) in
    M.set x [| ip |] (M.get x [| i |]);
    if !first <> 0
    then
      for j = !first - 1 to i - 1 do
        sum := _sub !sum (_mul (M.get lu [| i; j |]) (M.get x [| j |]))
      done
    else if !sum <> _zero
    then first := i + 1;
    M.set x [| i |] !sum
  done;
  (* back substitution with U *)
  for i = n - 1 downto 0 do
    let sum = ref (M.get x [| i |]) in
    for j = i + 1 to n - 1 do
      sum := _sub !sum (_mul (M.get lu [| i; j |]) (M.get x [| j |]))
    done;
    M.set x [| i |] (_div !sum (M.get lu [| i; i |]))
  done;
  x


(* Linear equation solution by LU decomposition.
* Input matrix: a[n][n], b[n][m];
* Output: ``x``, so that ax = b. *)

(* Each column of [b] is solved independently with [_lu_solve_vec], which
   factorises [a] again for every column; a copy of [b] is overwritten with
   the solutions and returned. Asserts that [a] is square. *)
let linsolve_lu a b =
  let dims_a, dims_b = M.shape a, M.shape b in
  _check_is_matrix dims_a;
  _check_is_matrix dims_b;
  assert (dims_a.(0) = dims_a.(1));
  let b = M.copy b in
  for j = 0 to dims_b.(1) - 1 do
    let col = M.flatten (M.get_slice [ []; [ j ] ] b) in
    M.set_slice [ []; [ j ] ] b (_lu_solve_vec a col)
  done;
  b


(* Determinant of matrix a *)
(* Computed as the permutation sign times the product of the diagonal of the
   packed LU factors from [_lu_base]. A matrix with an all-zero row makes
   [_lu_base] raise [Owl_exception.SINGULAR]; such a matrix has determinant
   zero, so zero is returned instead of propagating the exception.
   NOTE: [_lu_base] replaces an exactly zero pivot with a tiny value, so other
   singular matrices yield a tiny non-zero result rather than exact zero. *)
let det a =
  let k = M.kind a in
  let _mul = Owl_base_dense_common._mul_elt k in
  let _flt = Owl_base_dense_common._float_typ_elt k in
  let dims_a = M.shape a in
  _check_is_matrix dims_a;
  assert (dims_a.(0) = dims_a.(1));
  let n = dims_a.(0) in
  match _lu_base a with
  | exception Owl_exception.SINGULAR -> Owl_const.zero k
  | lu, _, sign ->
    let acc = ref (_flt sign) in
    for i = 0 to n - 1 do
      acc := _mul !acc (M.get lu [| i; i |])
    done;
    !acc


(* Solver for tridiagonal matrix
* Input: a[n], b[n], c[n], which together make up the tridiagonal matrix A (a: sub-diagonal, a.(0) unused; b: main diagonal; c: super-diagonal, c.(n-1) unused), and the right-hand side vector r[n]. Return: x[n].
*)

(* Thomas algorithm: one forward sweep that eliminates the sub-diagonal,
   then back substitution. An empty system yields an empty solution.
   Asserts that a, b and c have equal length; raises [Invalid_argument] if r
   has a different length, if b.(0) is zero, or if a zero pivot appears
   during the sweep (no pivoting is performed). *)
let tridiag_solve_vec a b c r =
  let n = Array.length a in
  let n1 = Array.length b in
  let n2 = Array.length c in
  assert (n = n1 && n = n2);
  if Array.length r <> n
  then
    raise
      (Invalid_argument
         "tridiag_solve_vec: right-hand side must have the same length as the diagonals");
  if n = 0
  then [||]
  else (
    if b.(0) = 0.
    then raise (Invalid_argument "tridiag_solve_vec: 0 at the beginning of diagonal vector");
    let bet = ref b.(0) in
    let gam = Array.make n 0. in
    let x = Array.make n 0. in
    x.(0) <- r.(0) /. !bet;
    (* forward sweep: decompose and substitute *)
    for j = 1 to n - 1 do
      gam.(j) <- c.(j - 1) /. !bet;
      bet := b.(j) -. (a.(j) *. gam.(j));
      if !bet = 0. then raise (Invalid_argument "tridiag_solve_vec: algorithm fails");
      x.(j) <- (r.(j) -. (a.(j) *. x.(j - 1))) /. !bet
    done;
    (* back substitution *)
    for j = n - 2 downto 0 do
      x.(j) <- x.(j) -. (gam.(j + 1) *. x.(j + 1))
    done;
    x)


(* TODO: optimise and test *)

(* Implementing the following algorithm:
http://www.irma-international.org/viewtitle/41011/ *)

(* [inv varr] returns the inverse of the square matrix [varr], working on a
   copy so [varr] is not modified. Each step applies the in-place Gauss-Jordan
   pivot transformation at diagonal position (p, p). If that diagonal entry is
   zero, the first row below it with a non-zero entry in column p is swapped
   in first; these row swaps are undone at the end by swapping the matching
   columns of the result, most recent swap first. When no swap is needed the
   arithmetic is identical to a plain sweep without pivoting.
   Raises [Invalid_argument] if [varr] is not a matrix, and [Failure] if it is
   not square or has no inverse. *)
let inv varr =
  let _k = M.kind varr in
  let _add = Owl_base_dense_common._add_elt _k in
  let _mul = Owl_base_dense_common._mul_elt _k in
  let _div = Owl_base_dense_common._div_elt _k in
  let _neg = Owl_base_dense_common._neg_elt _k in
  let _zero = Owl_const.zero _k in
  let _one = Owl_const.one _k in
  let dims = M.shape varr in
  _check_is_matrix dims;
  let n = dims.(0) in
  if dims.(1) <> n
  then failwith "no inverse - the matrix is not square"
  else (
    let pivot_row = Array.make n _zero in
    let result_varr = M.copy varr in
    (* swaps.(p) is the row exchanged with row p before pivoting at (p, p) *)
    let swaps = Array.init n (fun p -> p) in
    let swap_rows r1 r2 =
      for j = 0 to n - 1 do
        let t = M.get result_varr [| r1; j |] in
        M.set result_varr [| r1; j |] (M.get result_varr [| r2; j |]);
        M.set result_varr [| r2; j |] t
      done
    in
    let swap_cols c1 c2 =
      for i = 0 to n - 1 do
        let t = M.get result_varr [| i; c1 |] in
        M.set result_varr [| i; c1 |] (M.get result_varr [| i; c2 |]);
        M.set result_varr [| i; c2 |] t
      done
    in
    for p = 0 to n - 1 do
      (* first row at or below p with a non-zero entry in column p *)
      let q = ref p in
      while !q < n && M.get result_varr [| !q; p |] = _zero do
        incr q
      done;
      if !q = n then failwith "the matrix does not have an inverse";
      if !q <> p
      then (
        swap_rows p !q;
        swaps.(p) <- !q);
      let pivot_elem = M.get result_varr [| p; p |] in
      (* update elements of the pivot row, save old vals *)
      for j = 0 to n - 1 do
        pivot_row.(j) <- M.get result_varr [| p; j |];
        if j <> p then M.set result_varr [| p; j |] (_div pivot_row.(j) pivot_elem)
      done;
      (* update elements of the pivot col *)
      for i = 0 to n - 1 do
        if i <> p
        then M.set result_varr [| i; p |] (_div (M.get result_varr [| i; p |]) (_neg pivot_elem))
      done;
      (* update the rest of the matrix: old pivot-row value times the new
         pivot-column value is added to each remaining entry *)
      for i = 0 to n - 1 do
        if i <> p
        then (
          let pivot_col_elem = M.get result_varr [| i; p |] in
          for j = 0 to n - 1 do
            if j <> p
            then (
              let new_val =
                _add (M.get result_varr [| i; j |]) (_mul pivot_row.(j) pivot_col_elem)
              in
              M.set result_varr [| i; j |] new_val)
          done)
      done;
      (* update the pivot element *)
      M.set result_varr [| p; p |] (_div _one pivot_elem)
    done;
    (* undo the row interchanges as column interchanges, last one first *)
    for p = n - 1 downto 0 do
      if swaps.(p) <> p then swap_cols p swaps.(p)
    done;
    result_varr)


let logdet _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.logdet")

let qr ?(thin = true) ?(pivot = false) _x =
  ignore thin;
  ignore pivot;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.qr")


let lq ?(thin = true) _x =
  ignore thin;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.lq")


let chol ?(upper = true) _x =
  upper |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.chol")


let svd ?(thin = true) _x =
  thin |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.svd")


let sylvester _a _b _c =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.sylvester")


let lyapunov _a _q =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.lyapunov")


let discrete_lyapunov ?(solver = `default) _a _q =
  solver |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.discrete_lyapunov")


let linsolve ?(trans = false) ?(typ = `n) _a _b =
  trans |> ignore;
  typ |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.linsolve")


let care ?(diag_r = false) _a _b _q _r =
  diag_r |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.care")

(* ends here *)