# 1 "src/base/algodiff/owl_algodiff_core.ml"
(*
* OWL - OCaml Scientific Computing
* Copyright (c) 2016-2022 Liang Wang <liang@ocaml.xyz>
*)

module Make (A : Owl_types_ndarray_algodiff.Sig) = struct
  include Owl_algodiff_types.Make (A)
  module A = A

  (* generate global tags *)

  (* Counter backing [tag]; one instance per application of [Make]. *)
  let _global_tag = ref 0

  (** [tag ()] increments the counter and returns its new value, so successive
      calls within one instantiation of [Make] yield 1, 2, 3, ... *)
  let tag () =
    _global_tag := !_global_tag + 1;
    !_global_tag


  (* helper functions of the core AD component *)

  (** [reset_zero x] returns a fresh scalar zero for [F _]. For [Arr ap] it calls
      [A.reset ap] (presumably zero-filling [ap] in place -- confirm against [A])
      and returns the same, now mutated, array. Any other constructor raises
      [Failure]. *)
  let reset_zero = function
    | F _ -> F A.(float_to_elt 0.)
    | Arr ap ->
      A.reset ap;
      Arr ap
    | _ -> failwith "error: reset_zero"


  (** [primal x] strips exactly one [DF]/[DR] wrapper; constants are returned
      unchanged. *)
  let primal = function
    | DF (ap, _, _) -> ap
    | DR (ap, _, _, _, _, _) -> ap
    | ap -> ap


  (** [primal' x] strips every nested [DF]/[DR] wrapper, so the result is always
      a bare constant ([F _] or [Arr _]). *)
  let rec primal' = function
    | DF (ap, _, _) -> primal' ap
    | DR (ap, _, _, _, _, _) -> primal' ap
    | ap -> ap


  (** [zero x] builds a new zero of the same kind as the innermost primal of [x]:
      a scalar zero for scalars, a zero array of the same shape for arrays. *)
  let rec zero = function
    | F _ -> F A.(float_to_elt 0.)
    | Arr ap -> Arr A.(zeros (shape ap))
    | DF (ap, _, _) -> ap |> primal' |> zero
    | DR (ap, _, _, _, _, _) -> ap |> primal' |> zero


  (** [tangent x] is the tangent stored in a [DF] node. It is a fresh zero for
      constants and raises [Failure] for [DR] nodes. *)
  let tangent = function
    | DF (_, at, _) -> at
    | DR _ -> failwith "error: no tangent for DR"
    | ap -> zero ap


  (** [adjref x] is the mutable adjoint reference stored in a [DR] node. It is a
      new reference to a zero for constants and raises [Failure] for [DF] nodes. *)
  let adjref = function
    | DF _ -> failwith "error: no adjref for DF"
    | DR (_, at, _, _, _, _) -> at
    | ap -> ref (zero ap)


  (** [adjval x] is the current value of the adjoint stored in a [DR] node. It is
      a fresh zero for constants and raises [Failure] for [DF] nodes. *)
  let adjval = function
    | DF _ -> failwith "error: no adjval for DF"
    | DR (_, at, _, _, _, _) -> !at
    | ap -> zero ap


  (** [shape x] is the shape of the innermost primal: [[||]] for a scalar, and
      the array's shape for [Arr]. *)
  let shape x =
    match primal' x with
    | F _ -> [||]
    | Arr ap -> A.shape ap
    | _ -> failwith "error: AD.shape"


  (** [is_float x] is [true] iff the innermost primal of [x] is a scalar [F _]. *)
  let is_float x =
    match primal' x with
    | F _ -> true
    | _ -> false


  (** [is_arr x] is [true] iff the innermost primal of [x] is an [Arr _].
      The previous version had these cases inverted: it returned [true] for
      scalars and [false] for arrays. *)
  let is_arr x =
    match primal' x with
    | Arr _ -> true
    | _ -> false


  (** Size of the first dimension of [shape x]. Raises [Invalid_argument] when
      the shape is empty (scalars). *)
  let row_num x = (shape x).(0)

  (** Size of the second dimension of [shape x]. Raises [Invalid_argument] when
      the shape has fewer than two dimensions. *)
  let col_num x = (shape x).(1)


  (** Number of elements of the innermost primal array. Raises [Failure] for
      scalars. *)
  let numel x =
    match primal' x with
    | Arr x -> A.numel x
    | _ -> failwith "error: AD.numel"


  (** Applies [A.clip_by_value ~amin ~amax] to the innermost primal array. The
      result is a plain [Arr] with no derivative information attached. Raises
      [Failure] for scalars. *)
  let clip_by_value ~amin ~amax x =
    match primal' x with
    | Arr x -> Arr A.(clip_by_value ~amin ~amax x)
    | _ -> failwith "error: AD.clip_by_value"


  (** Applies [A.clip_by_l2norm a] to the innermost primal array. [a] is passed
      through unchanged (presumably the norm threshold). The result is a plain
      [Arr] with no derivative information. Raises [Failure] for scalars. *)
  let clip_by_l2norm a x =
    match primal' x with
    | Arr x -> Arr A.(clip_by_l2norm a x)
    | _ -> failwith "error: AD.clip_by_l2norm"


  (** Copies the innermost primal array with [A.copy] and returns it as a plain
      [Arr]. Raises [Failure] for scalars. *)
  let copy_primal' x =
    match primal' x with
    | Arr ap -> Arr A.(copy ap)
    | _ -> failwith "error: AD.copy"


  (** Applies [A.tile] with [reps] to the innermost primal array. The result is
      a plain [Arr]. Raises [Failure] for scalars. *)
  let tile x reps =
    match primal' x with
    | Arr x -> Arr A.(tile x reps)
    | _ -> failwith "error: AD.tile"


  (** Applies [A.repeat] with [reps] to the innermost primal array. The result
      is a plain [Arr]. Raises [Failure] for scalars. *)
  let repeat x reps =
    match primal' x with
    | Arr x -> Arr A.(repeat x reps)
    | _ -> failwith "error: AD.repeat"


  (* packing and unpacking functions *)

  (** Wraps an element as a scalar constant. *)
  let pack_elt x = F x

  (** Extracts the element from a scalar, looking through one [DF]/[DR] layer
      (via [primal]). Raises [Failure] otherwise. *)
  let unpack_elt x =
    match primal x with
    | F x -> x
    | _ -> failwith "error: AD.unpack_elt"


  (** Converts an OCaml [float] to an element and wraps it as a scalar. *)
  let pack_flt x = F A.(float_to_elt x)

  (* shortcut for type conversion: same as [pack_flt] *)
  let _f x = F A.(float_to_elt x)


  (** Extracts a scalar as an OCaml [float], looking through one [DF]/[DR]
      layer (via [primal]). Raises [Failure] otherwise. *)
  let unpack_flt x =
    match primal x with
    | F x -> A.elt_to_float x
    | _ -> failwith "error: AD.unpack_flt"


  (** Wraps an ndarray as an array constant. *)
  let pack_arr x = Arr x

  (** Extracts the ndarray, looking through one [DF]/[DR] layer (via [primal]).
      Raises [Failure] otherwise. *)
  let unpack_arr x =
    match primal x with
    | Arr x -> x
    | _ -> failwith "error: AD.unpack_arr"


  (* functions to report errors, help in debugging *)

  (** Describes the innermost primal of [x]: the scalar value for [F], the
      shape for [Arr]. *)
  let deep_info x =
    match primal' x with
    | F a -> Printf.sprintf "F(%g)" A.(elt_to_float a)
    | Arr a ->
      Printf.sprintf "Arr(%s)" (A.shape a |> Owl_utils_array.to_string string_of_int)
    | _ -> "you should not have reached here!"


  (** Describes the outermost node of [x]. For [DF] and [DR] nodes the tag is
      included, followed by [deep_info] of the wrapped primal. *)
  let type_info x =
    match x with
    | F _a -> Printf.sprintf "[%s]" (deep_info x)
    | DF (ap, _at, ai) -> Printf.sprintf "[DF tag:%i ap:%s]" ai (deep_info ap)
    | DR (ap, _at, _ao, _af, ai, _) ->
      Printf.sprintf "[DR tag:%i ap:%s]" ai (deep_info ap)
    | _ -> Printf.sprintf "[%s]" (deep_info x)


  (** Raises [Failure] naming binary operator [op] and describing both operands. *)
  let error_binop op a b =
    let s0 = "#0:" ^ type_info a in
    let s1 = "#1:" ^ type_info b in
    failwith (op ^ " : " ^ s0 ^ ", " ^ s1)


  (** Raises [Failure] naming unary operator [op] and describing its operand. *)
  let error_uniop op a =
    let s = type_info a in
    failwith (op ^ " : " ^ s)
end