# 1 "src/base/algodiff/owl_algodiff_check.ml"
(* Numerical checks of algorithmic-differentiation results.

   [Reverse.check] compares reverse-mode directional derivatives (from [grad])
   against central finite-difference estimates; [Forward.check] compares
   forward-mode directional derivatives (from [jacobianv]) against the
   reverse-mode gradient (from [grad]). *)
module Make (Algodiff : Owl_algodiff_generic_sig.Sig) = struct
  open Algodiff

  (* [generate_directions (dim1, dim2)] returns [dim1 * dim2] arrays of shape
     [| dim1; dim2 |]. Direction [j] holds 1 at the element whose [A.init]
     index equals [j] and 0 everywhere else (presumably the flat index, i.e.
     the canonical basis of the [dim1 x dim2] space). *)
  let generate_directions (dim1, dim2) =
    (* Built once instead of once per element. *)
    let one = A.(float_to_elt 1.) in
    let zero = A.(float_to_elt 0.) in
    Array.init (dim1 * dim2) (fun j ->
        Arr (A.init [| dim1; dim2 |] (fun i -> if i = j then one else zero)))

  (* [generate_test_samples (dim1, dim2) n_samples] returns a pair:
     [n_samples] random [dim1 x dim2] matrices drawn with [Mat.gaussian], and
     the [dim1 * dim2] directions produced by [generate_directions]. *)
  let generate_test_samples (dim1, dim2) n_samples =
    ( Array.init n_samples (fun _ -> Mat.gaussian dim1 dim2)
    , generate_directions (dim1, dim2) )

  module Reverse = struct
    (* [finite_difference_grad ~order ~f ?eps x d] estimates the derivative
       of [f] at [x] along [d] with a central difference, where h = eps * d:
       - [`second]: (f(x+h) - f(x-h)) / (2 eps)
       - [`fourth]: (8 (f(x+h) - f(x-h)) - (f(x+2h) - f(x-2h))) / (12 eps)
       [eps] defaults to 1e-5. *)
    let finite_difference_grad =
      (* Numeric constants are created once, outside the returned closure. *)
      let two = F A.(float_to_elt 2.) in
      let eight = F A.(float_to_elt 8.) in
      let twelve = F A.(float_to_elt 12.) in
      fun ~order ~f ?(eps = 1E-5) x d ->
        let eps = F A.(float_to_elt eps) in
        let dx = Maths.(eps * d) in
        let df1 = Maths.(f (x + dx) - f (x - dx)) in
        match order with
        | `fourth ->
          let df2 =
            let twodx = Maths.(two * dx) in
            Maths.(f (x + twodx) - f (x - twodx))
          in
          Maths.(((eight * df1) - df2) / (twelve * eps))
        | `second -> Maths.(df1 / (two * eps))

    (* [check ~threshold ~order ?verbose ?eps ~f ~directions samples] checks
       the reverse-mode gradient of [fun x -> sum' (f x)] at every sample.

       For each sample [x] and each direction [d], the AD value
       [sum' (grad f x * d)] is paired with [finite_difference_grad]'s
       estimate. A sample's error is the largest absolute difference between
       the two, divided by (RMS of the finite-difference values + 1e-9). The
       sample passes when that error is strictly below [threshold].

       Returns [(all_passed, n_passed)]. When [verbose] is true (default
       false), prints the pass count and the largest error over all samples.
       [eps] (default 1e-5) is forwarded to [finite_difference_grad]. *)
    let check ~threshold ~order ?(verbose = false) ?(eps = 1E-5) ~f =
      (* [rs] holds the (AD, finite-difference) pairs of one sample;
         returns (passed, error) for that sample. *)
      let compare_sample rs =
        let n_d = Array.length rs in
        let sum_sq =
          Array.fold_left (fun acc (_, r_fd) -> acc +. (r_fd *. r_fd)) 0. rs
        in
        let rms = sqrt (sum_sq /. float n_d) in
        let max_err =
          Array.fold_left
            (fun acc (r_ad, r_fd) ->
              max acc (abs_float (r_ad -. r_fd) /. (rms +. 1E-9)))
            (-1.)
            rs
        in
        max_err < threshold, max_err
      in
      (* Reduce [f] to a scalar so that [grad] applies. *)
      let f x = Maths.(sum' (f x)) in
      let g = grad f in
      fun ~directions samples ->
        let n_samples = Array.length samples in
        let check_sample x =
          directions
          |> Array.map (fun d ->
                 let r_ad = Maths.(sum' (g x * d)) |> unpack_flt in
                 let r_fd =
                   finite_difference_grad ~order ~f ~eps x d |> unpack_flt
                 in
                 r_ad, r_fd)
          |> compare_sample
        in
        let all_passed, max_err, n_passed =
          Array.fold_left
            (fun (ok_so_far, worst, n_ok) x ->
              let ok, err = check_sample x in
              ok_so_far && ok, max worst err, if ok then n_ok + 1 else n_ok)
            (true, -1., 0)
            samples
        in
        (* NOTE(review): the format string contains a line break right after
           "max_err: " (kept byte-for-byte from the original literal); it
           looks unintended — confirm before changing. *)
        if verbose
        then
          Printf.printf
            "adjoints passed: %i/%i | max_err: \n%f.\n%!"
            n_passed
            n_samples
            max_err;
        all_passed, n_passed
  end

  module Forward = struct
    (* Tangent at x should have the same dimension as f x: raises
       [Failure "tangent dimension mismatch"] unless the primals of [f x] and
       of [jacobianv f x x] (tangent along [x] itself) are both scalars ([F])
       or both arrays ([Arr]) with equal [A.shape]. *)
    let check_tangent_dimensions ~f x =
      match primal' (f x), primal' (jacobianv f x x) with
      | F _, F _ -> ()
      | Arr a, Arr b when A.shape a = A.shape b -> ()
      | _ -> failwith "tangent dimension mismatch"

    (* [check ~threshold ~f ~directions samples] compares, at every sample
       [x], the reverse-mode gradient of [fun x -> sum' (f x)] against a
       matrix whose entry [i] is the forward-mode derivative
       [jacobianv f x directions.(i)]. A sample passes when the squared L2
       distance between the two is strictly below [threshold].

       [directions] is expected to hold [dim1 * dim2] directions of shape
       [dim1 x dim2], where [(dim1, dim2) = Mat.shape directions.(0)]
       (e.g. the output of [generate_directions]).

       Returns [(all_passed, n_passed)]. An empty [samples] array returns
       [(true, 0)], the same result [Reverse.check] gives for no samples,
       instead of failing on [samples.(0)]. May raise
       [Failure "tangent dimension mismatch"] via [check_tangent_dimensions]
       on the first sample. *)
    let check ~threshold ~f ~directions samples =
      if Array.length samples = 0
      then (true, 0)
      else (
        check_tangent_dimensions ~f samples.(0);
        let f x = Maths.(sum' (f x)) in
        let reverse_g = grad f in
        let dim1, dim2 = Mat.shape directions.(0) in
        let forward_g x =
          Arr
            (A.init [| dim1; dim2 |] (fun i ->
                 let v = directions.(i) in
                 jacobianv f x v |> unpack_elt))
        in
        Array.fold_left
          (fun (all_ok, n_ok) x ->
            let reverse_g = reverse_g x in
            let forward_g = forward_g x in
            let e =
              Maths.(l2norm_sqr' (reverse_g - forward_g)) |> unpack_flt
            in
            let passed = e < threshold in
            all_ok && passed, if passed then n_ok + 1 else n_ok)
          (true, 0)
          samples)
  end
end