12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607# 1 "src/base/optimise/owl_algodiff_generic.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
* Core nested automatic differentiation algorithm and differentiation API
* ported from DiffSharp (http://diffsharp.github.io), copyright (c) 2014-2016
* National University of Ireland Maynooth (Atilim Gunes Baydin), 2014-2018
* National University of Ireland Maynooth (Barak A. Pearlmutter
* <barak@pearlmutter.net>), 2016-2018 University of Oxford (Atilim Gunes
* Baydin <gunes@robots.ox.ac.uk>), 2017-2018 Microsoft Research Cambridge
* (Don Syme <dsyme@microsoft.com>)
*)

open Owl_types

(* Functor of making AD module of different precisions *)
module Make (A : Owl_types_ndarray_algodiff.Sig) = struct
  module A = A

  (* type definitions *)

  (* An AD value: a scalar [F], an ndarray [Arr], a forward-mode node [DF]
     or a reverse-mode node [DR]. *)
  type t =
    | F of A.elt
    | Arr of A.arr
    | DF of t * t * int (* primal, tangent, tag *)
    | DR of t * t ref * trace_op * int ref * int * int ref
  (* primal, adjoint, op, fanout, tag, tracker *)

  (* The operation that produced a [DR] node. Suffix convention, as used by
     [op_d_d_d] below: [_D_D] both operands are differentiated, [_D_C] only
     the first one, [_C_D] only the second one. *)
  and trace_op =
    | Noop
    | Add_D_D of t * t
    | Add_D_C of t * t
    | Add_C_D of t * t
    | Sub_D_D of t * t
    | Sub_D_C of t * t
    | Sub_C_D of t * t
    | Mul_D_D of t * t
    | Mul_D_C of t * t
    | Mul_C_D of t * t
    | Div_D_D of t * t
    | Div_D_C of t * t
    | Div_C_D of t * t
    | Pow_D_D of t * t
    | Pow_D_C of t * t
    | Pow_C_D of t * t
    | Atan2_D_D of t * t
    | Atan2_D_C of t * t
    | Atan2_C_D of t * t
    | Neg_D of t
    | Abs_D of t
    | Signum_D of t
    | Floor_D of t
    | Ceil_D of t
    | Round_D of t
    | Sqr_D of t
    | Sqrt_D of t
    | Log_D of t
    | Log2_D of t
    | Log10_D of t
    | Exp_D of t
    | Sin_D of t
    | Cos_D of t
    | Tan_D of t
    | Sinh_D of t
    | Cosh_D of t
    | Tanh_D of t
    | Asin_D of t
    | Acos_D of t
    | Atan_D of t
    | Asinh_D of t
    | Acosh_D of t
    | Atanh_D of t
    | Get_Item of t * int * int
    | SetI_D_D of t * int * int * t
    | SetI_D_C of t * int * int * t
    | SetI_C_D of t * int * int * t
    | AddI_D_D of t * int * int * t
    | AddI_D_C of t * int * int * t
    | AddI_C_D of t * int * int * t
    | Get_Slice_D of t * int list list
    | Set_Slice_D_D of t * t * int list list
    | Set_Slice_D_C of t * t * int list list
    | Set_Slice_C_D of t * t * int list list
    | Sum_D of t
    | Sum__D of t * int
    | Sum___D of t * int array
    | Dot_D_D of t * t
    | Dot_D_C of t * t
    | Dot_C_D of t * t
    | Trans_D of t
    | L1Norm_D of t
    | L2Norm_D of t
    | L2NormS_D of t
    | Sigmoid_D of t
    | Relu_D of t
    | Inv_D of t
    | Logdet_D of t
    | Chol_D of t * bool
    | QR_D of (t * (t ref * t ref) * (t ref * t ref))
    | Svd_D of (t * (t ref * t ref * t ref) * (t ref * t ref * t ref) * bool)
    | Lyapunov_D_D of t * t
    | Lyapunov_C_D of t * t
    | Lyapunov_D_C of t * t
    | Discrete_Lyapunov_D_D of t * t
    | Discrete_Lyapunov_C_D of t * t
    | Discrete_Lyapunov_D_C of t * t
    | Diag_D of int * t
    | Diagm_D of int * t
    | Tril_D of int * t
    | Triu_D of int * t
    | Trace_D of t
    | Add_Row_D_D of t * t * int
    | Add_Row_D_C of t * t * int
    | Add_Row_C_D of t * t * int
    | Get_Row_D of t * int
    | Of_Rows_D of t array
    | Of_Arrays_D of t array array * (int * int) list
    | Concat_D_D of t * t * int
    | Concat_D_C of t * t * int
    | Concat_C_D of t * t * int
    | Split_D of t * int * t ref array
    | Concatenate_D of t array * int * int list
    | Conv1D_D_D of t * t * int array
    | Conv1D_D_C of t * t * int array
    | Conv1D_C_D of t * t * int array
    | Conv2D_D_D of t * t * int array
    | Conv2D_D_C of t * t * int array
    | Conv2D_C_D of t * t * int array
    | Conv3D_D_D of t * t * int array
    | Conv3D_D_C of t * t * int array
    | Conv3D_C_D of t * t * int array
    | Di_Conv1D_D_D of t * t * int array * int array
    | Di_Conv1D_D_C of t * t * int array * int array
    | Di_Conv1D_C_D of t * t * int array * int array
    | Di_Conv2D_D_D of t * t * int array * int array
    | Di_Conv2D_D_C of t * t * int array * int array
    | Di_Conv2D_C_D of t * t * int array * int array
    | Di_Conv3D_D_D of t * t * int array * int array
    | Di_Conv3D_D_C of t * t * int array * int array
    | Di_Conv3D_C_D of t * t * int array * int array
    | Tr_Conv1D_D_D of t * t * int array
    | Tr_Conv1D_D_C of t * t * int array
    | Tr_Conv1D_C_D of t * t * int array
    | Tr_Conv2D_D_D of t * t * int array
    | Tr_Conv2D_D_C of t * t * int array
    | Tr_Conv2D_C_D of t * t * int array
    | Tr_Conv3D_D_D of t * t * int array
    | Tr_Conv3D_D_C of t * t * int array
    | Tr_Conv3D_C_D of t * t * int array
    | Reshape_D of t
    | Maxpool1D_D of t * padding * int array * int array
    | Maxpool2D_D of t * padding * int array * int array
    | Maxpool3D_D of t * padding * int array * int array
    | Avgpool1D_D of t * padding * int array * int array
    | Avgpool2D_D of t * padding * int array * int array
    | Avgpool3D_D of t * padding * int array * int array
    | UpSampling2D_D of t * int array
    | PAD_D of t * int list list

  (* generate global tags *)
  let _global_tag = ref 0

  (* Increments the global counter and returns the new value. *)
  let tag () =
    _global_tag := !_global_tag + 1;
    !_global_tag

  (* helper functions of the core AD component *)

  (* Three-way comparison of two tags: 1, -1 or 0. *)
  let cmp_tag ai bi = if ai > bi then 1 else if ai < bi then -1 else 0

  (* Scalars become a fresh zero; arrays are reset in place and returned. *)
  let reset_zero = function
    | F _ -> F A.(float_to_elt 0.)
    | Arr ap ->
      A.reset ap;
      Arr ap
    | _ -> failwith "error: reset_zero"

  (* Strips one DF/DR layer; plain values are returned unchanged. *)
  let primal = function
    | DF (ap, _, _) -> ap
    | DR (ap, _, _, _, _, _) -> ap
    | ap -> ap

  (* Strips every DF/DR layer down to the underlying [F]/[Arr]. *)
  let rec primal' = function
    | DF (ap, _, _) -> primal' ap
    | DR (ap, _, _, _, _, _) -> primal' ap
    | ap -> ap

  (* A plain zero with the same kind (and shape, for arrays) as the input. *)
  let rec zero = function
    | F _ -> F A.(float_to_elt 0.)
    | Arr ap -> Arr A.(zeros (shape ap))
    | DF (ap, _, _) -> ap |> primal' |> zero
    | DR (ap, _, _, _, _, _) -> ap |> primal' |> zero

  (* Tangent of a DF node; constants have a zero tangent; DR has none. *)
  let tangent = function
    | DF (_, at, _) -> at
    | DR _ -> failwith "error: no tangent for DR"
    | ap -> zero ap

  (* Adjoint reference of a DR node; constants get a fresh zero ref. *)
  let adjref = function
    | DF _ -> failwith "error: no adjref for DF"
    | DR (_, at, _, _, _, _) -> at
    | ap -> ref (zero ap)

  (* Current adjoint value of a DR node; constants yield zero. *)
  let adjval = function
    | DF _ -> failwith "error: no adjval for DF"
    | DR (_, at, _, _, _, _) -> !at
    | ap -> zero ap

  (* Shape of the underlying array; fails for scalars. *)
  let shape x =
    match primal' x with
    | Arr ap -> A.shape ap
    | _ -> failwith "error: AD.shape"

  let row_num x = (shape x).(0)

  let col_num x = (shape x).(1)

  let numel x =
    match primal' x with
    | Arr x -> A.numel x
    | _ -> failwith "error: AD.numel"

  (* The following helpers work on the primal only; the result is a constant. *)
  let clip_by_value ~amin ~amax x =
    match primal' x with
    | Arr x -> Arr A.(clip_by_value ~amin ~amax x)
    | _ -> failwith "error: AD.clip_by_value"

  let clip_by_l2norm a x =
    match primal' x with
    | Arr x -> Arr A.(clip_by_l2norm a x)
    | _ -> failwith "error: AD.clip_by_l2norm"

  let copy_primal' x =
    match primal' x with
    | Arr ap -> Arr A.(copy ap)
    | _ -> failwith "error: AD.copy"

  let tile x reps =
    match primal' x with
    | Arr x -> Arr A.(tile x reps)
    | _ -> failwith "error: AD.tile"

  let repeat x reps =
    match primal' x with
    | Arr x -> Arr A.(repeat x reps)
    | _ -> failwith "error: AD.repeat"

  (* packing and unpacking functions *)
  let pack_elt x = F x

  let unpack_elt x =
    match primal x with
    | F x -> x
    | _ -> failwith "error: AD.unpack_elt"

  let pack_flt x = F A.(float_to_elt x)

  (* shortcut for type conversion *)
  let _f x = F A.(float_to_elt x)

  let unpack_flt x =
    match primal x with
    | F x -> A.elt_to_float x
    | _ -> failwith "error: AD.unpack_flt"

  let pack_arr x = Arr x

  let unpack_arr x =
    match primal x with
    | Arr x -> x
    | _ -> failwith "error: AD.unpack_arr"

  (* functions to report errors, help in debugging *)
  let deep_info x =
    match primal' x with
    | F a -> Printf.sprintf "F(%g)" A.(elt_to_float a)
    | Arr a ->
      Printf.sprintf "Arr(%s)" (A.shape a |> Owl_utils_array.to_string string_of_int)
    | _ -> "you should not have reached here!"

  let type_info x =
    match x with
    | F _a -> Printf.sprintf "[%s]" (deep_info x)
    | DF (ap, _at, ai) -> Printf.sprintf "[DF tag:%i ap:%s]" ai (deep_info ap)
    | DR (ap, _at, _ao, _af, ai, _) -> Printf.sprintf "[DR tag:%i ap:%s]" ai (deep_info ap)
    | _ -> Printf.sprintf "[%s]" (deep_info x)

  let error_binop op a b =
    let s0 = "#0:" ^ type_info a in
    let s1 = "#1:" ^ type_info b in
    failwith (op ^ " : " ^ s0 ^ ", " ^ s1)

  let error_uniop op a =
    let s = type_info a in
    failwith (op ^ " : " ^ s)

  (* overload operators *)
  module Maths = struct
    let rec noop _ = ()

    (* Lifts a unary op over AD values. [ff] handles plain [F]/[Arr] inputs,
       [fd] re-applies the op to a primal (each caller defines it as the op
       itself), [df cp ap at] returns the forward tangent and [r a] builds the
       reverse-mode trace entry. *)
    and op_d_d a ff fd df r =
      match a with
      | DF (ap, at, ai) ->
        let cp = fd ap in
        DF (cp, df cp ap at, ai)
      | DR (ap, _, _, _, ai, _) ->
        (* [fd], not [ff]: [ap] may itself be a DF/DR value of an enclosing
           differentiation level, which [ff] rejects. *)
        let cp = fd ap in
        DR (cp, ref (zero cp), r a, ref 0, ai, ref 0)
      | ap -> ff ap

    (* Like [op_d_d] for ops returning a pair; both DR outputs share the
       same trace entry and tracker. *)
    and pair_op_d_d a ff fd df r =
      match a with
      | DF (ap, at, ai) ->
        let cp1, cp2 = fd ap in
        DF (cp1, df cp1 ap at, ai), DF (cp2, df cp2 ap at, ai)
      | DR (ap, _, _, _, ai, _) ->
        let cp1, cp2 = fd ap in
        let aa1 = ref (zero cp1) in
        let aa2 = ref (zero cp2) in
        let cp1_ref = ref cp1 in
        let cp2_ref = ref cp2 in
        let tracker = ref 0 in
        (* tracker: int reference.
           In reverse_reset, it counts how many times cp1 and cp2 have been
           called, so that reverse_push does not update the adjoint of ap
           before both aa1 and aa2 have been fully updated. *)
        ( DR (cp1, aa1, r (a, (cp1_ref, cp2_ref), (aa1, aa2)), ref 0, ai, tracker)
        , DR (cp2, aa2, r (a, (cp1_ref, cp2_ref), (aa1, aa2)), ref 0, ai, tracker) )
      | ap -> ff ap

    (* Like [pair_op_d_d] for ops returning a triple. *)
    and triple_op_d_d a ff fd df r =
      match a with
      | DF (ap, at, ai) ->
        let cp1, cp2, cp3 = fd ap in
        DF (cp1, df cp1 ap at, ai), DF (cp2, df cp2 ap at, ai), DF (cp3, df cp3 ap at, ai)
      | DR (ap, _, _, _, ai, _) ->
        let cp1, cp2, cp3 = fd ap in
        let aa1 = ref (zero cp1) in
        let aa2 = ref (zero cp2) in
        let aa3 = ref (zero cp3) in
        let cp1_ref = ref cp1 in
        let cp2_ref = ref cp2 in
        let cp3_ref = ref cp3 in
        let tracker = ref 0 in
        ( DR (cp1, aa1, r (a, (cp1_ref, cp2_ref, cp3_ref), (aa1, aa2, aa3)), ref 0, ai, tracker)
        , DR (cp2, aa2, r (a, (cp1_ref, cp2_ref, cp3_ref), (aa1, aa2, aa3)), ref 0, ai, tracker)
        , DR (cp3, aa3, r (a, (cp1_ref, cp2_ref, cp3_ref), (aa1, aa2, aa3)), ref 0, ai, tracker) )
      | ap -> ff ap

    (* Like [pair_op_d_d] for ops returning an array of values. *)
    and array_op_d_d a ff fd df r =
      match a with
      | DF (ap, at, ai) ->
        let cp_arr = fd ap in
        Array.map (fun cp -> DF (cp, df cp ap at, ai)) cp_arr
      | DR (ap, _, _, _, ai, _) ->
        let cp_arr = fd ap in
        let tracker = ref 0 in
        let aa_arr = Array.map (fun cp -> ref (zero cp)) cp_arr in
        Array.map2
          (fun cp aa -> DR (cp, aa, r (a, cp_arr, aa_arr), ref 0, ai, tracker))
          cp_arr aa_arr
      | ap -> ff ap

    (* Lifts a binary op over AD values. When both operands are AD values the
       tags decide which one is differentiated at this level: equal tags use
       [df_dab]/[r_d_d], a higher tag on [a] uses [df_da]/[r_d_c], otherwise
       [df_db]/[r_c_d]. *)
    and op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d =
      match a, b with
      | F _ap, DF (bp, bt, bi) ->
        let cp = fd a bp in
        DF (cp, df_db cp bp bt, bi)
      | DF (ap, at, ai), F _bp ->
        let cp = fd ap b in
        DF (cp, df_da cp ap at, ai)
      | Arr _ap, DF (bp, bt, bi) ->
        let cp = fd a bp in
        DF (cp, df_db cp bp bt, bi)
      | DF (ap, at, ai), Arr _bp ->
        let cp = fd ap b in
        DF (cp, df_da cp ap at, ai)
      | F _ap, DR (bp, _, _, _, bi, _) ->
        let cp = fd a bp in
        DR (cp, ref (zero cp), r_c_d a b, ref 0, bi, ref 0)
      | DR (ap, _, _, _, ai, _), F _bp ->
        let cp = fd ap b in
        DR (cp, ref (zero cp), r_d_c a b, ref 0, ai, ref 0)
      | Arr _ap, DR (bp, _, _, _, bi, _) ->
        let cp = fd a bp in
        DR (cp, ref (zero cp), r_c_d a b, ref 0, bi, ref 0)
      | DR (ap, _, _, _, ai, _), Arr _bp ->
        let cp = fd ap b in
        DR (cp, ref (zero cp), r_d_c a b, ref 0, ai, ref 0)
      | DF (ap, at, ai), DR (bp, _, _, _, bi, _) -> (
        match cmp_tag ai bi with
        | 1 ->
          let cp = fd ap b in
          DF (cp, df_da cp ap at, ai)
        | -1 ->
          let cp = fd a bp in
          DR (cp, ref (zero cp), r_c_d a b, ref 0, bi, ref 0)
        | _ -> failwith "error: forward and reverse clash at the same level")
      | DR (ap, _, _, _, ai, _), DF (bp, bt, bi) -> (
        match cmp_tag ai bi with
        | -1 ->
          let cp = fd a bp in
          DF (cp, df_db cp bp bt, bi)
        | 1 ->
          let cp = fd ap b in
          DR (cp, ref (zero cp), r_d_c a b, ref 0, ai, ref 0)
        | _ -> failwith "error: forward and reverse clash at the same level")
      | DF (ap, at, ai), DF (bp, bt, bi) -> (
        match cmp_tag ai bi with
        | 0 ->
          let cp = fd ap bp in
          DF (cp, df_dab cp ap at bp bt, ai)
        | 1 ->
          let cp = fd ap b in
          DF (cp, df_da cp ap at, ai)
        | _ ->
          let cp = fd a bp in
          DF (cp, df_db cp bp bt, bi))
      | DR (ap, _, _, _, ai, _), DR (bp, _, _, _, bi, _) -> (
        match cmp_tag ai bi with
        | 0 ->
          let cp = fd ap bp in
          DR (cp, ref (zero cp), r_d_d a b, ref 0, ai, ref 0)
        | 1 ->
          let cp = fd ap b in
          DR (cp, ref (zero cp), r_d_c a b, ref 0, ai, ref 0)
        | _ ->
          let cp = fd a bp in
          DR (cp, ref (zero cp), r_c_d a b, ref 0, bi, ref 0))
      | a, b -> ff a b

    and ( + ) a b = add a b

    and add a b =
      let ff a b =
        match a, b with
        | F a, F b -> F A.Scalar.(add a b)
        | F a, Arr b -> Arr A.(scalar_add a b)
        | Arr a, F b -> Arr A.(add_scalar a b)
        | Arr a, Arr b -> Arr A.(add a b)
        | _ -> error_binop "( + )" a b
      in
      let fd a b = a + b in
      let df_da _cp _ap at = at in
      let df_db _cp _bp bt = bt in
      let df_dab _cp _ap at _bp bt = at + bt in
      let r_d_d a b = Add_D_D (a, b) in
      let r_d_c a b = Add_D_C (a, b) in
      let r_c_d a b = Add_C_D (a, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    and ( - ) a b = sub a b

    and sub a b =
      let ff a b =
        match a, b with
        | F a, F b -> F A.Scalar.(sub a b)
        | F a, Arr b -> Arr A.(scalar_sub a b)
        | Arr a, F b -> Arr A.(sub_scalar a b)
        | Arr a, Arr b -> Arr A.(sub a b)
        | _ -> error_binop "( - )" a b
      in
      let fd a b = a - b in
      let df_da _cp _ap at = at in
      let df_db _cp _bp bt = neg bt in
      let df_dab _cp _ap at _bp bt = at - bt in
      let r_d_d a b = Sub_D_D (a, b) in
      let r_d_c a b = Sub_D_C (a, b) in
      let r_c_d a b = Sub_C_D (a, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    and ( * ) a b = mul a b

    and mul a b =
      let ff a b =
        match a, b with
        | F a, F b -> F A.Scalar.(mul a b)
        | F a, Arr b -> Arr A.(scalar_mul a b)
        | Arr a, F b -> Arr A.(mul_scalar a b)
        | Arr a, Arr b -> Arr A.(mul a b)
        | _ -> error_binop "( * )" a b
      in
      let fd a b = a * b in
      let df_da _cp _ap at = at * b in
      let df_db _cp _bp bt = a * bt in
      let df_dab _cp ap at bp bt = (ap * bt) + (at * bp) in
      let r_d_d a b = Mul_D_D (a, b) in
      let r_d_c a b = Mul_D_C (a, b) in
      let r_c_d a b = Mul_C_D (a, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    and ( / ) a b = div a b

    and div a b =
      let ff a b =
        match a, b with
        | F a, F b -> F A.Scalar.(div a b)
        | F a, Arr b -> Arr A.(scalar_div a b)
        | Arr a, F b -> Arr A.(div_scalar a b)
        | Arr a, Arr b -> Arr A.(div a b)
        | _ -> error_binop "( / )" a b
      in
      let fd a b = a / b in
      let df_da _cp _ap at = at / b in
      let df_db cp bp bt = neg bt * cp / bp in
      let df_dab cp _ap at bp bt = (at - (bt * cp)) / bp in
      let r_d_d a b = Div_D_D (a, b) in
      let r_d_c a b = Div_D_C (a, b) in
      let r_c_d a b = Div_C_D (a, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    and ( ** ) a b = pow a b

    and pow a b =
      let ff a b =
        match a, b with
        | F a, F b -> F A.Scalar.(pow a b)
        | F a, Arr b -> Arr A.(scalar_pow a b)
        | Arr a, F b -> Arr A.(pow_scalar a b)
        | Arr a, Arr b -> Arr A.(pow a b)
        | _ -> error_binop "( ** )" a b
      in
      let fd a b = a ** b in
      let df_da _cp ap at = at * (ap ** (b - pack_flt 1.)) * b in
      let df_db cp _bp bt = bt * cp * log a in
      let df_dab _cp ap at bp bt =
        (ap ** (bp - pack_flt 1.)) * ((at * bp) + (ap * bt * log ap))
      in
      let r_d_d a b = Pow_D_D (a, b) in
      let r_d_c a b = Pow_D_C (a, b) in
      let r_c_d a b = Pow_C_D (a, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    and atan2 a b =
      let ff a b =
        match a, b with
        | F a, F b -> F A.Scalar.(atan2 a b)
        | F a, Arr b -> Arr A.(scalar_atan2 a b)
        | Arr a, F b -> Arr A.(atan2_scalar a b)
        | Arr a, Arr b -> Arr A.(atan2 a b)
        | _ -> error_binop "atan2" a b
      in
      let fd a b = atan2 a b in
      let df_da _cp ap at = at * b / (sqr ap + sqr b) in
      let df_db _cp bp bt = neg bt * a / (sqr a + sqr bp) in
      let df_dab _cp ap at bp bt = ((at * bp) - (bt * ap)) / (sqr ap + sqr bp) in
      let r_d_d a b = Atan2_D_D (a, b) in
      let r_d_c a b = Atan2_D_C (a, b) in
      let r_c_d a b = Atan2_C_D (a, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    and min2 a b = ((a + b) - abs (a - b)) / pack_flt 2.

    and max2 a b = ((a + b) + abs (b - a)) / pack_flt 2.

    and neg a =
      let ff = function
        | F a -> F A.Scalar.(neg a)
        | Arr a -> Arr A.(neg a)
        | _ -> error_uniop "neg" a
      in
      let fd a = neg a in
      let df _cp _ap at = pack_flt 0. - at in
      let r a = Neg_D a in
      op_d_d a ff fd df r

    and abs a =
      let ff = function
        | F a -> F A.Scalar.(abs a)
        | Arr a -> Arr A.(abs a)
        | _ -> error_uniop "abs" a
      in
      let fd a = abs a in
      let df _cp ap at = at * signum ap in
      let r a = Abs_D a in
      op_d_d a ff fd df r

    and signum a =
      let ff = function
        | F a -> F A.Scalar.(signum a)
        | Arr a -> Arr A.(signum a)
        | _ -> error_uniop "signum" a
      in
      let fd a = signum a in
      let df _cp ap _at = zero ap in
      let r a = Signum_D a in
      op_d_d a ff fd df r

    and floor a =
      let ff = function
        | F a -> F A.Scalar.(floor a)
        | Arr a -> Arr A.(floor a)
        | _ -> error_uniop "floor" a
      in
      let fd a = floor a in
      let df _cp ap _at = zero ap in
      let r a = Floor_D a in
      op_d_d a ff fd df r

    and ceil a =
      let ff = function
        | F a -> F A.Scalar.(ceil a)
        | Arr a -> Arr A.(ceil a)
        | _ -> error_uniop "ceil" a
      in
      let fd a = ceil a in
      let df _cp ap _at = zero ap in
      let r a = Ceil_D a in
      op_d_d a ff fd df r

    and round a =
      let ff = function
        | F a -> F A.Scalar.(round a)
        | Arr a -> Arr A.(round a)
        | _ -> error_uniop "round" a
      in
      let fd a = round a in
      let df _cp ap _at = zero ap in
      let r a = Round_D a in
      op_d_d a ff fd df r

    and sqr a =
      let ff = function
        | F a -> F A.Scalar.(sqr a)
        | Arr a -> Arr A.(sqr a)
        | _ -> error_uniop "sqr" a
      in
      let fd a = sqr a in
      let df _cp ap at = pack_flt 2. * at * ap in
      let r a = Sqr_D a in
      op_d_d a ff fd df r

    and sqrt a =
      let ff = function
        | F a -> F A.Scalar.(sqrt a)
        | Arr a -> Arr A.(sqrt a)
        | _ -> error_uniop "sqrt" a
      in
      let fd a = sqrt a in
      let df cp _ap at = at / (pack_flt 2. * cp) in
      let r a = Sqrt_D a in
      op_d_d a ff fd df r

    and log a =
      let ff = function
        | F a -> F A.Scalar.(log a)
        | Arr a -> Arr A.(log a)
        | _ -> error_uniop "log" a
      in
      let fd a = log a in
      let df _cp ap at = at / ap in
      let r a = Log_D a in
      op_d_d a ff fd df r

    and log2 a =
      let ff = function
        | F a -> F A.Scalar.(log2 a)
        | Arr a -> Arr A.(log2 a)
        | _ -> error_uniop "log2" a
      in
      let fd a = log2 a in
      (* d log2 x = dx / (x ln 2) = dx * log2(e) / x *)
      let df _cp ap at = at * pack_flt Owl_const.log2e / ap in
      let r a = Log2_D a in
      op_d_d a ff fd df r

    and log10 a =
      let ff = function
        | F a -> F A.Scalar.(log10 a)
        | Arr a -> Arr A.(log10 a)
        | _ -> error_uniop "log10" a
      in
      let fd a = log10 a in
      (* d log10 x = dx / (x ln 10) = dx * log10(e) / x *)
      let df _cp ap at = at * pack_flt Owl_const.log10e / ap in
      let r a = Log10_D a in
      op_d_d a ff fd df r

    and exp a =
      let ff = function
        | F a -> F A.Scalar.(exp a)
        | Arr a -> Arr A.(exp a)
        | _ -> error_uniop "exp" a
      in
      let fd a = exp a in
      let df cp _ap at = at * cp in
      let r a = Exp_D a in
      op_d_d a ff fd df r

    and sin a =
      let ff = function
        | F a -> F A.Scalar.(sin a)
        | Arr a -> Arr A.(sin a)
        | _ -> error_uniop "sin" a
      in
      let fd a = sin a in
      let df _cp ap at = at * cos ap in
      let r a = Sin_D a in
      op_d_d a ff fd df r

    and cos a =
      let ff = function
        | F a -> F A.Scalar.(cos a)
        | Arr a -> Arr A.(cos a)
        | _ -> error_uniop "cos" a
      in
      let fd a = cos a in
      let df _cp ap at = neg (at * sin ap) in
      let r a = Cos_D a in
      op_d_d a ff fd df r

    and tan a =
      let ff = function
        | F a -> F A.Scalar.(tan a)
        | Arr a -> Arr A.(tan a)
        | _ -> error_uniop "tan" a
      in
      let fd a = tan a in
      let df _cp ap at = at / sqr (cos ap) in
      let r a = Tan_D a in
      op_d_d a ff fd df r

    and sinh a =
      let ff = function
        | F a -> F A.Scalar.(sinh a)
        | Arr a -> Arr A.(sinh a)
        | _ -> error_uniop "sinh" a
      in
      let fd a = sinh a in
      let df _cp ap at = at * cosh ap in
      let r a = Sinh_D a in
      op_d_d a ff fd df r

    and cosh a =
      let ff = function
        | F a -> F A.Scalar.(cosh a)
        | Arr a -> Arr A.(cosh a)
        | _ -> error_uniop "cosh" a
      in
      let fd a = cosh a in
      let df _cp ap at = at * sinh ap in
      let r a = Cosh_D a in
      op_d_d a ff fd df r

    and tanh a =
      let ff = function
        | F a -> F A.Scalar.(tanh a)
        | Arr a -> Arr A.(tanh a)
        | _ -> error_uniop "tanh" a
      in
      let fd a = tanh a in
      let df _cp ap at = at / sqr (cosh ap) in
      let r a = Tanh_D a in
      op_d_d a ff fd df r

    and asin a =
      let ff = function
        | F a -> F A.Scalar.(asin a)
        | Arr a -> Arr A.(asin a)
        | _ -> error_uniop "asin" a
      in
      let fd a = asin a in
      let df _cp ap at = at / sqrt (pack_flt 1. - sqr ap) in
      let r a = Asin_D a in
      op_d_d a ff fd df r

    and acos a =
      let ff = function
        | F a -> F A.Scalar.(acos a)
        | Arr a -> Arr A.(acos a)
        | _ -> error_uniop "acos" a
      in
      let fd a = acos a in
      let df _cp ap at = neg at / sqrt (pack_flt 1. - sqr ap) in
      let r a = Acos_D a in
      op_d_d a ff fd df r

    and atan a =
      let ff = function
        | F a -> F A.Scalar.(atan a)
        | Arr a -> Arr A.(atan a)
        | _ -> error_uniop "atan" a
      in
      let fd a = atan a in
      let df _cp ap at = at / (pack_flt 1. + sqr ap) in
      let r a = Atan_D a in
      op_d_d a ff fd df r

    and asinh a =
      let ff = function
        | F a -> F A.Scalar.(asinh a)
        | Arr a -> Arr A.(asinh a)
        | _ -> error_uniop "asinh" a
      in
      let fd a = asinh a in
      let df _cp ap at = at / sqrt (sqr ap + pack_flt 1.) in
      let r a = Asinh_D a in
      op_d_d a ff fd df r

    and acosh a =
      let ff = function
        | F a -> F A.Scalar.(acosh a)
        | Arr a -> Arr A.(acosh a)
        | _ -> error_uniop "acosh" a
      in
      let fd a = acosh a in
      let df _cp ap at = at / sqrt (sqr ap - pack_flt 1.) in
      let r a = Acosh_D a in
      op_d_d a ff fd df r

    and atanh a =
      let ff = function
        | F a -> F A.Scalar.(atanh a)
        | Arr a -> Arr A.(atanh a)
        | _ -> error_uniop "atanh" a
      in
      let fd a = atanh a in
      let df _cp ap at = at / (pack_flt 1. - sqr ap) in
      let r a = Atanh_D a in
      op_d_d a ff fd df r

    (* Element (i, j) of a matrix as a scalar AD value. *)
    and get_item a i j =
      match a with
      | Arr ap -> F (A.get ap [| i; j |])
      | DF (ap, at, ai) -> DF (get_item ap i j, get_item at i j, ai)
      | DR (ap, _, _, _, ai, _) ->
        DR (get_item ap i j, ref (pack_flt 0.), Get_Item (a, i, j), ref 0, ai, ref 0)
      | _ -> error_uniop "get_item" a

    (* Copy of [a] with element (i, j) replaced by the scalar [b]. *)
    and set_item a i j b =
      let ff a b =
        match a, b with
        | Arr a, F b ->
          let aa = A.copy a in
          A.set aa [| i; j |] b;
          Arr aa
        | _ -> error_uniop "set_item" a
      in
      let fd a b = set_item a i j b in
      let df_da _cp _ap at = set_item at i j (pack_flt 0.) in
      let df_db _cp _bp bt = add_item (zero a) i j bt in
      let df_dab _cp _ap at _bp bt = set_item at i j bt in
      let r_d_d a b = SetI_D_D (a, i, j, b) in
      let r_d_c a b = SetI_D_C (a, i, j, b) in
      let r_c_d a b = SetI_C_D (a, i, j, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    (* Copy of [a] with the scalar [b] added to element (i, j). *)
    and add_item a i j b =
      let ff a b =
        match a, b with
        | Arr a, F b ->
          let aa = A.copy a in
          A.set aa [| i; j |] A.Scalar.(add (A.get aa [| i; j |]) b);
          Arr aa
        | _ -> error_binop "add_item" a b
      in
      let fd a b = add_item a i j b in
      let df_da _cp _ap at = at in
      let df_db _cp _bp bt = add_item (zero a) i j bt in
      let df_dab _cp _ap at _bp bt = add_item at i j bt in
      let r_d_d a b = AddI_D_D (a, i, j, b) in
      let r_d_c a b = AddI_D_C (a, i, j, b) in
      let r_c_d a b = AddI_C_D (a, i, j, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    and get_slice i a =
      let ff = function
        | Arr a -> Arr A.(get_slice i a)
        | _ -> error_uniop "slice" a
      in
      let fd a = get_slice i a in
      let df _cp _ap at = get_slice i at in
      let r a = Get_Slice_D (a, i) in
      op_d_d a ff fd df r

    (* Copy of [a] with the slice [i] overwritten by [b]. *)
    and set_slice i a b =
      let ff a b =
        match a, b with
        | Arr a, Arr b ->
          let a = A.copy a in
          A.(set_slice i a b);
          Arr a
        | _ -> error_binop "set_slice" a b
      in
      let fd a b = set_slice i a b in
      let df_da _cp _ap at = set_slice i at (zero b) in
      let df_db _cp _bp bt = set_slice i (zero a) bt in
      let df_dab _cp _ap at _bp bt = set_slice i at bt in
      let r_d_d a b = Set_Slice_D_D (a, b, i) in
      let r_d_c a b = Set_Slice_D_C (a, b, i) in
      let r_c_d a b = Set_Slice_C_D (a, b, i) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    (* Sum of all elements, as a scalar. *)
    and sum' a =
      let ff = function
        | F a -> F a
        | Arr a -> F A.(sum' a)
        | _ -> error_uniop "sum" a
      in
      let fd a = sum' a in
      let df _cp _ap at = sum' at in
      let r a = Sum_D a in
      op_d_d a ff fd df r

    and sum ?(axis = (-1)) a =
      let ff = function
        | F a -> F a
        | Arr a -> Arr A.(sum ~axis a)
        | _ -> error_uniop "sum" a
      in
      let fd a = sum ~axis a in
      let df _cp _ap at = sum ~axis at in
      let r a = Sum__D (a, axis) in
      op_d_d a ff fd df r

    and sum_reduce ?(axis = [| 0 |]) a =
      let ff = function
        | F a -> F a
        | Arr x -> Arr A.(sum_reduce ~axis x)
        | _ -> error_uniop "sum_reduce" a
      in
      let fd a = sum_reduce ~axis a in
      let df _cp _ap at = sum_reduce ~axis at in
      let r a = Sum___D (a, axis) in
      op_d_d a ff fd df r

    and mean a = sum' a / F (numel a |> float_of_int |> A.float_to_elt)

    and ( *@ ) a b = dot a b

    (* Matrix product. *)
    and dot a b =
      let ff a b =
        match a, b with
        | Arr a, Arr b -> Arr A.(dot a b)
        | _ -> error_binop "( *@ )" a b
      in
      let fd a b = a *@ b in
      let df_da _cp _ap at = at *@ b in
      let df_db _cp _bp bt = a *@ bt in
      let df_dab _cp ap at bp bt = (ap *@ bt) + (at *@ bp) in
      let r_d_d a b = Dot_D_D (a, b) in
      let r_d_c a b = Dot_D_C (a, b) in
      let r_c_d a b = Dot_C_D (a, b) in
      op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

    and transpose a =
      let ff = function
        | Arr a -> Arr A.(transpose a)
        | _ -> error_uniop "transpose" a
      in
      let fd a = transpose a in
      let df _cp _ap at = transpose at in
      let r a = Trans_D a in
      op_d_d a ff fd df r

    (* The three norms below have a scalar primal, so their tangents are
       reduced with [sum'] to a scalar as well (like [sum'] and [trace]). *)
    and l1norm' a =
      let ff = function
        | Arr a -> F A.(l1norm' a)
        | _ -> error_uniop "l1norm'" a
      in
      let fd a = l1norm' a in
      (* d |x|_1 = sum (signum x * dx) *)
      let df _cp ap at = sum' (at * signum ap) in
      let r a = L1Norm_D a in
      op_d_d a ff fd df r

    and l2norm' a =
      let ff = function
        | Arr a -> F A.(l2norm' a)
        | _ -> error_uniop "l2norm'" a
      in
      let fd a = l2norm' a in
      (* d |x|_2 = sum (x * dx) / |x|_2 *)
      let df cp ap at = sum' (ap * at) / cp in
      let r a = L2Norm_D a in
      op_d_d a ff fd df r

    and l2norm_sqr' a =
      let ff = function
        | F a -> F A.Scalar.(sqr a)
        | Arr a -> F A.(l2norm_sqr' a)
        | _ -> error_uniop "l2norm_sqr'" a
      in
      let fd a = l2norm_sqr' a in
      (* d |x|_2^2 = 2 * sum (x * dx) *)
      let df _cp ap at = pack_flt 2. * sum' (ap * at) in
      let r a = L2NormS_D a in
      op_d_d a ff fd df r

    and sigmoid a =
      let ff = function
        | F a -> F A.Scalar.(sigmoid a)
        | Arr a -> Arr A.(sigmoid a)
        | _ -> error_uniop "sigmoid" a
      in
      let fd a = sigmoid a in
      let df cp _ap at = at * cp * (pack_flt 1. - cp) in
      let r a = Sigmoid_D a in
      op_d_d a ff fd df r

    and relu a =
      let ff = function
        | F a -> F A.Scalar.(relu a)
        | Arr a -> Arr A.(relu a)
        | _ -> error_uniop "relu" a
      in
      let fd a = relu a in
      (* (1 + signum x) / 2 is 1 for positive inputs and 0 for negative ones *)
      let df _cp ap at = at * (pack_flt 1. + signum ap) / pack_flt 2. in
      let r a = Relu_D a in
      op_d_d a ff fd df r

    and diag ?(k = 0) a =
      let ff = function
        | Arr a -> Arr A.(diag ~k a |> copy)
        | _ -> error_uniop "diag" a
      in
      let fd a = diag ~k a in
      let df _cp _ap at = diag ~k at in
      let r a = Diag_D (k, a) in
      op_d_d a ff fd df r

    and diagm ?(k = 0) a =
      let ff = function
        | Arr a -> Arr A.(diagm ~k a |> copy)
        | _ -> error_uniop "diagm" a
      in
      let fd a = diagm ~k a in
      let df _cp _ap at = diagm ~k at in
      let r a = Diagm_D (k, a) in
      op_d_d a ff fd df r

    and trace a =
      let ff = function
        | Arr a -> F A.(trace a)
        | _ -> error_uniop "trace" a
      in
      let fd a = trace a in
      let df _cp _ap at = trace at in
      let r a = Trace_D a in
      op_d_d a ff fd df r

    and triu ?(k = 0) a =
      let ff = function
        | Arr a -> Arr A.(triu ~k a |> copy)
        | _ -> error_uniop "triu" a
      in
      let fd a = triu ~k a in
      let df _cp _ap at = triu ~k at in
      let r a = Triu_D (k, a) in
      op_d_d a ff fd df r

    and tril ?(k = 0) a =
      let ff = function
        | Arr a -> Arr A.(tril ~k a |> copy)
        | _ -> error_uniop "tril" a
      in
      let fd a = tril ~k a in
      (* tril is linear, so its tangent is tril of the input tangent *)
      let df _cp _ap at = tril ~k at in
      let r a = Tril_D (k, a) in
      op_d_d a ff fd df r
and inv a =
  let ff = function
    | Arr a -> Arr A.(inv a)
    | _ -> error_uniop "inv" a
  in
  let fd a = inv a in
  (* d(A^-1) = -A^-1 dA A^-1, a matrix product (not element-wise) *)
  let df cp _ap at = neg cp *@ at *@ cp in
  let r a = Inv_D a in
  op_d_d a ff fd df r

and logdet a =
  let ff = function
    | Arr a -> F A.(logdet a)
    | _ -> error_uniop "logdet" a
  in
  let fd a = logdet a in
  (* d logdet A = tr (A^-1 dA) *)
  let df _cp ap at = trace (inv ap *@ at) in
  let r a = Logdet_D a in
  op_d_d a ff fd df r

(* Symmetric matrix built from the lower triangle of [x]. *)
and copyltu x = tril x + transpose (tril ~k:(-1) x)

(* Symmetric matrix built from the upper triangle of [x]. *)
and copyutl x = triu x + transpose (triu ~k:1 x)

and chol ?(upper = true) a =
  let ff = function
    | Arr a -> Arr A.(chol ~upper a)
    | _ -> error_uniop "chol" a
  in
  let fd a = chol ~upper a in
  let df cp _ap at = _chol_forward cp at upper in
  let r a = Chol_D (a, upper) in
  op_d_d a ff fd df r

(* Forward tangent of the Cholesky factor [cp] given the input tangent [at].
   With m = half the diagonal of x, the lower branch returns L (m + strict
   lower x); the upper branch is its transpose counterpart (m + strict
   upper x) U, so the result stays upper triangular. *)
and _chol_forward cp at upper =
  let inv_cp = inv cp in
  let tr_inv_cp = transpose inv_cp in
  if upper
  then (
    let x = tr_inv_cp *@ transpose at *@ inv_cp in
    let m = pack_flt 0.5 * tril (triu x) in
    (m + triu ~k:1 x) *@ cp)
  else (
    let x = inv_cp *@ at *@ tr_inv_cp in
    let m = pack_flt 0.5 * tril (triu x) in
    cp *@ (m + tril ~k:(-1) x))

(* Reverse-mode adjoint of the input of [chol], given its output [o] and the
   output adjoint [aa]. *)
and _chol_backward o aa upper =
  let inv_o = inv o in
  let tr_inv_o = transpose inv_o in
  if upper
  then pack_flt 0.5 * inv_o *@ copyutl (aa *@ transpose o) *@ tr_inv_o
  else pack_flt 0.5 * tr_inv_o *@ copyltu (transpose o *@ aa) *@ inv_o

and qr a =
  let ff = function
    | Arr a ->
      let q, r = A.(qr a) in
      Arr q, Arr r
    | _ -> error_uniop "qr" a
  in
  let fd a = qr a in
  let df _cp _ap _at = raise Owl_exception.NOT_IMPLEMENTED in
  let r (a, (cp1, cp2), (aa1, aa2)) = QR_D (a, (cp1, cp2), (aa1, aa2)) in
  pair_op_d_d a ff fd df r

(* Reverse-mode adjoint of the input of [qr] from the outputs (q, r) and
   their adjoints. *)
and _qr_backward (o1, o2) (aa1, aa2) =
  let q = !o1
  and r = !o2
  and qbar = !aa1
  and rbar = !aa2 in
  let qt = transpose q
  and qbart = transpose qbar in
  let rt = transpose r
  and rbart = transpose rbar in
  (* let rinvt = r *@ (inv (rt *@ r)) in (* transpose of the left moore-penrose pseudoinverse *) *)
  let rinvt = transpose (inv r) in
  let middle =
    tril ~k:(-1) ((r *@ rbart) - (rbar *@ rt) + (qt *@ qbar) - (qbart *@ q))
  in
  (q *@ (rbar + (middle *@ rinvt))) + ((qbar - (q *@ (qt *@ qbar))) *@ rinvt)

and svd ?(thin = true) a =
  let ff = function
    | Arr a ->
      let u, s, vt = A.(svd ~thin a) in
      Arr u, Arr s, Arr vt
    | _ -> error_uniop "svd" a
  in
  let fd a = svd ~thin a in
  let df _cp _ap _at = raise Owl_exception.NOT_IMPLEMENTED in
  let r (a, (cp1, cp2, cp3), (aa1, aa2, aa3)) =
    Svd_D (a, (cp1, cp2, cp3), (aa1, aa2, aa3), thin)
  in
  triple_op_d_d a ff fd df r

(* Reverse-mode adjoint of the input of [svd] from the outputs (u, s, vt) and
   their adjoints; only the thin decomposition is supported. *)
and _svd_backward (o1, o2, o3) (aa1, aa2, aa3) thin =
  let u, s, vt = !o1, !o2, !o3
  and ubar, sbar, vbart = !aa1, !aa2, !aa3 in
  let ut = transpose u
  and v = transpose vt in
  let ubart = transpose ubar
  and vbar = transpose vbart in
  let eye n = A.(ones [| 1; n |]) |> pack_arr |> diagm in
  let e_m = eye (row_num u) in
  let e_n = eye (row_num v) in
  let k = row_num vt in
  (* f.(i).(j) = 1 / (s_j^2 - s_i^2) off the diagonal, 0 on it *)
  let f =
    let s2 = sqr s in
    pack_arr
      A.(
        init_nd [| k; k |] (fun idx ->
            let i = idx.(0)
            and j = idx.(1) in
            if i = j
            then float_to_elt 0.
            else begin
              let s2_i = get_item s2 0 i |> unpack_flt in
              let s2_j = get_item s2 0 j |> unpack_flt in
              1. /. (s2_j -. s2_i) |> float_to_elt
            end))
  in
  let inv_s = pack_flt 1. / s in
  if thin
  then begin
    ((u * sbar) *@ vt)
    + (((u *@ (f * ((ut *@ ubar) - (ubart *@ u)))) * s)
       + ((e_m - (u *@ ut)) *@ ubar * inv_s))
      *@ vt
    + (u
      *@ (((transpose s * (f * ((vt *@ vbar) - (vbart *@ v)))) *@ vt)
         + (transpose inv_s * vbart *@ (e_n - (v *@ vt)))))
  end
  else raise Owl_exception.NOT_IMPLEMENTED

(* Solution X of the continuous Lyapunov equation with coefficients a, q.
   Forward tangents follow from differentiating A X + X A^T = Q, which is
   the convention consistent with [_lyapunov_backward_q] below. *)
and lyapunov a q =
  let ff a q =
    match a, q with
    | Arr a, Arr q -> Arr A.(lyapunov a q)
    | _ -> error_binop "lyapunov" a q
  in
  let fd a q = lyapunov a q in
  let df_da cp ap at = lyapunov ap (neg ((at *@ cp) + (cp *@ transpose at))) in
  (* dX = lyap (A, dQ): the dQ term enters with a positive sign *)
  let df_dq _cp _qp qt = lyapunov a qt in
  let df_daq cp ap at _qp qt =
    lyapunov ap (neg ((at *@ cp) + (cp *@ transpose at))) + lyapunov ap qt
  in
  let r_d_d a q = Lyapunov_D_D (a, q) in
  let r_d_c a q = Lyapunov_D_C (a, q) in
  let r_c_d a q = Lyapunov_C_D (a, q) in
  op_d_d_d a q ff fd df_da df_dq df_daq r_d_d r_d_c r_c_d

and _lyapunov_backward_a a aa ap =
  let s = lyapunov (transpose a) (neg aa) in
  pack_flt 2. * s *@ ap

and _lyapunov_backward_q a aa = neg (lyapunov (transpose a) (neg aa))

and _lyapunov_backward_aq a aa ap =
  let s = lyapunov (transpose a) (neg aa) in
  pack_flt 2. * s *@ ap, neg s

and discrete_lyapunov ?(solver = `default) a q =
  let ff a q =
    match a, q with
    | Arr a, Arr q -> Arr A.(discrete_lyapunov ~solver a q)
    | _ -> error_binop "discrete_lyapunov" a q
  in
  let fd a q = discrete_lyapunov ~solver a q in
  let df_da cp ap at =
    discrete_lyapunov ap ((ap *@ cp *@ transpose at) + (at *@ cp *@ transpose a))
  in
  let df_dq _cp _qp qt = discrete_lyapunov a qt in
  let df_daq cp ap at _qp qt =
    discrete_lyapunov ap ((ap *@ cp *@ transpose at) + (at *@ cp *@ transpose a))
    + discrete_lyapunov ap qt
  in
  let r_d_d a q = Discrete_Lyapunov_D_D (a, q) in
  let r_d_c a q = Discrete_Lyapunov_D_C (a, q) in
  let r_c_d a q = Discrete_Lyapunov_C_D (a, q) in
  op_d_d_d a q ff fd df_da df_dq df_daq r_d_d r_d_c r_c_d

and _discrete_lyapunov_backward_a a aa ap =
  let s = discrete_lyapunov (transpose a) aa in
  pack_flt 2. * s *@ a *@ ap

and _discrete_lyapunov_backward_q a aa = discrete_lyapunov (transpose a) aa

and _discrete_lyapunov_backward_aq a aa ap =
  let s = discrete_lyapunov (transpose a) aa in
  pack_flt 2. * s *@ a *@ ap, s

and softplus x = log (pack_flt 1. + exp x)

and softsign x = x / (pack_flt 1. + abs x)

(* Softmax along [axis]; the maximum along [axis] is subtracted first. *)
and softmax ?(axis = (-1)) x =
  let c = Arr A.(max ~axis (unpack_arr x)) in
  let y = exp (x - c) in
  let a = sum ~axis y in
  y / a

and cross_entropy x y = x * log y |> sum' |> neg

(* Adds [b] to row [i] of [a] and returns the result as a new array. *)
and add_row a b i =
  let ff a b =
    match a, b with
    | Arr a, Arr b ->
      (* work on a copy: [copy_row_to] writes in place and must not mutate
         the caller's array *)
      let a = A.copy a in
      A.(copy_row_to (add (row a i) b) a i);
      Arr a
    | _ -> error_binop "add_row" a b
  in
  let fd a b = add_row a b i in
  let df_da _cp _ap at = at in
  let df_db _cp _bp bt = add_row (zero a) bt i in
  let df_dab _cp _ap at _bp bt = add_row at bt i in
  let r_d_d a b = Add_Row_D_D (a, b, i) in
  let r_d_c a b = Add_Row_D_C (a, b, i) in
  let r_c_d a b = Add_Row_C_D (a, b, i) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* Copy of row [i] of [a]. *)
and get_row a i =
  let ff = function
    | Arr a -> Arr A.(row a i |> copy)
    | _ -> error_uniop "get_row" a
  in
  let fd a = get_row a i in
  let df _cp _ap at = get_row at i in
  let r a = Get_Row_D (a, i) in
  op_d_d a ff fd df r

and to_rows a = Array.init (row_num a) (fun i -> get_row a i)

(* Stacks rows into a matrix; the AD kind is taken from the first row. *)
and of_rows a =
  (* TODO: this can be further optimised by incorporating t array type as t *)
  match a.(0) with
  | Arr _ -> Array.map unpack_arr a |> A.of_rows |> pack_arr
  | DF (_, _, ai) ->
    let ap = a |> Array.map (fun x -> x |> primal |> unpack_arr) |> A.of_rows |> pack_arr in
    let at = a |> Array.map (fun x -> x |> tangent |> unpack_arr) |> A.of_rows |> pack_arr in
    DF (ap, at, ai)
  | DR (_, _, _, _, ai, _) ->
    let ap = a |> Array.map (fun x -> x |> primal) in
    let cp = ap |> Array.map (fun x -> x |> unpack_arr) |> A.of_rows |> pack_arr in
    DR (cp, ref (zero cp), Of_Rows_D a, ref 0, ai, ref 0)
  | _ -> error_uniop "of_rows a.(0)" a.(0)

(* Builds a matrix from a 2D array of scalar AD values. Constants may be mixed
   with either DR elements (result is a DR node recording their indices) or
   DF elements (result is a DF node); mixing DR and DF is rejected. *)
and of_arrays a =
  (* mode: 0 constant, 1 reverse, 2 tangent *)
  let mode = ref 0 in
  let idxs = ref [] in
  let ai_ref = ref 0 in
  let cp =
    Array.mapi
      (fun i xs ->
        Array.mapi
          (fun j x ->
            match x, !mode with
            | F _, _ -> unpack_elt x
            | DR (_, _, _, _, ai, _), 0 ->
              ai_ref := ai;
              mode := 1;
              idxs := (i, j) :: !idxs;
              unpack_elt x
            | DR (_, _, _, _, ai, _), 1 ->
              ai_ref := ai;
              idxs := (i, j) :: !idxs;
              unpack_elt x
            | DF (_, _, ai), 0 ->
              (* the first DF element switches to tangent mode *)
              ai_ref := ai;
              mode := 2;
              unpack_elt x
            | DF (_, _, ai), 2 ->
              ai_ref := ai;
              unpack_elt x
            | _, _ -> error_uniop "of_arrays: inconsistent array" x)
          xs)
      a
    |> A.of_arrays
    |> pack_arr
  in
  match !mode with
  | 0 -> cp
  | 1 -> DR (cp, ref (zero cp), Of_Arrays_D (a, List.rev !idxs), ref 0, !ai_ref, ref 0)
  | 2 ->
    let at =
      a
      |> Array.map (Array.map (fun x -> x |> tangent |> unpack_elt))
      |> A.of_arrays
      |> pack_arr
    in
    DF (cp, at, !ai_ref)
  | _ -> error_uniop "of_arrays" a.(0).(0)

and to_arrays a =
  Array.init (row_num a) (fun i -> Array.init (col_num a) (fun j -> get_item a i j))

(* NOTE: these functions are for neural network. There are many restrictions
at the moment. E.g. they do not support higher-order derivatives, and some
   do not support forward mode, so only use them when you know what you are
   doing.
*)

(* a:input; b:kernel; s:stride *)
(* Convolution is bilinear in (input, kernel), so forward-mode tangents follow
   the product rule: d conv(a, b) = conv(da, b) + conv(a, db). The same holds
   for the dilated and transposed variants below. *)
and conv1d ?padding a b s =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(conv1d ?padding a b s)
    | _            -> error_binop "conv1d" a b
  in
  let fd a b = conv1d ?padding a b s in
  (* only [a] carries a tangent; [b] is treated as a constant *)
  let df_da _cp _ap at = conv1d ?padding at b s in
  (* only [b] carries a tangent; [a] is treated as a constant *)
  let df_db _cp _bp bt = conv1d ?padding a bt s in
  let df_dab _cp ap at bp bt =
    (conv1d ?padding at bp s) + (conv1d ?padding ap bt s)
  in
  let r_d_d a b = Conv1D_D_D (a, b, s) in
  let r_d_c a b = Conv1D_D_C (a, b, s) in
  let r_c_d a b = Conv1D_C_D (a, b, s) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; s:stride; o:output' *)
and conv1d_backward_input a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.conv1d_backward_input a b s o |> pack_arr

(* a:input; b:kernel; s:stride; o:output' *)
and conv1d_backward_kernel a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.conv1d_backward_kernel a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and conv2d ?padding a b s =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(conv2d ?padding a b s)
    | _            -> error_binop "conv2d" a b
  in
  let fd a b = conv2d ?padding a b s in
  (* bilinear: product rule on (input, kernel) *)
  let df_da _cp _ap at = conv2d ?padding at b s in
  let df_db _cp _bp bt = conv2d ?padding a bt s in
  let df_dab _cp ap at bp bt =
    (conv2d ?padding at bp s) + (conv2d ?padding ap bt s)
  in
  let r_d_d a b = Conv2D_D_D (a, b, s) in
  let r_d_c a b = Conv2D_D_C (a, b, s) in
  let r_c_d a b = Conv2D_C_D (a, b, s) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; s:stride; o:output' *)
and conv2d_backward_input a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.conv2d_backward_input a b s o |> pack_arr

(* a:input; b:kernel; s:stride; o:output' *)
and conv2d_backward_kernel a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.conv2d_backward_kernel a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and conv3d ?padding a b s =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(conv3d ?padding a b s)
    | _            -> error_binop "conv3d" a b
  in
  let fd a b = conv3d ?padding a b s in
  (* bilinear: product rule on (input, kernel) *)
  let df_da _cp _ap at = conv3d ?padding at b s in
  let df_db _cp _bp bt = conv3d ?padding a bt s in
  let df_dab _cp ap at bp bt =
    (conv3d ?padding at bp s) + (conv3d ?padding ap bt s)
  in
  let r_d_d a b = Conv3D_D_D (a, b, s) in
  let r_d_c a b = Conv3D_D_C (a, b, s) in
  let r_c_d a b = Conv3D_C_D (a, b, s) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; s:stride; o:output' *)
and conv3d_backward_input a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.conv3d_backward_input a b s o |> pack_arr

(* a:input; b:kernel; s:stride; o:output' *)
and conv3d_backward_kernel a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.conv3d_backward_kernel a b s o |> pack_arr

(* a:input; b:kernel; s:stride; r:rate *)
and dilated_conv1d ?padding a b s r =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(dilated_conv1d ?padding a b s r)
    | _            -> error_binop "dilated_conv1d" a b
  in
  let fd a b = dilated_conv1d ?padding a b s r in
  (* bilinear: product rule on (input, kernel) *)
  let df_da _cp _ap at = dilated_conv1d ?padding at b s r in
  let df_db _cp _bp bt = dilated_conv1d ?padding a bt s r in
  let df_dab _cp ap at bp bt =
    (dilated_conv1d ?padding at bp s r) + (dilated_conv1d ?padding ap bt s r)
  in
  let r_d_d a b = Di_Conv1D_D_D (a, b, s, r) in
  let r_d_c a b = Di_Conv1D_D_C (a, b, s, r) in
  let r_c_d a b = Di_Conv1D_C_D (a, b, s, r) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; o:output'; s:stride; r:rate *)
and dilated_conv1d_backward_input a b s r o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.dilated_conv1d_backward_input a b s r o |> pack_arr

(* a:input; b:kernel; o:output'; s:stride; r:rate *)
and dilated_conv1d_backward_kernel a b s r o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.dilated_conv1d_backward_kernel a b s r o |> pack_arr

(* a:input; b:kernel; s:stride; r:rate *)
and dilated_conv2d ?padding a b s r =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(dilated_conv2d ?padding a b s r)
    | _            -> error_binop "dilated_conv2d" a b
  in
  let fd a b = dilated_conv2d ?padding a b s r in
  (* bilinear: product rule on (input, kernel) *)
  let df_da _cp _ap at = dilated_conv2d ?padding at b s r in
  let df_db _cp _bp bt = dilated_conv2d ?padding a bt s r in
  let df_dab _cp ap at bp bt =
    (dilated_conv2d ?padding at bp s r) + (dilated_conv2d ?padding ap bt s r)
  in
  let r_d_d a b = Di_Conv2D_D_D (a, b, s, r) in
  let r_d_c a b = Di_Conv2D_D_C (a, b, s, r) in
  let r_c_d a b = Di_Conv2D_C_D (a, b, s, r) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; o:output'; s:stride; r:rate *)
and dilated_conv2d_backward_input a b s r o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.dilated_conv2d_backward_input a b s r o |> pack_arr

(* a:input; b:kernel; o:output'; s:stride; r:rate *)
and dilated_conv2d_backward_kernel a b s r o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.dilated_conv2d_backward_kernel a b s r o |> pack_arr

(* a:input; b:kernel; s:stride; r:rate *)
and dilated_conv3d ?padding a b s r =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(dilated_conv3d ?padding a b s r)
    | _            -> error_binop "dilated_conv3d" a b
  in
  let fd a b = dilated_conv3d ?padding a b s r in
  (* bilinear: product rule on (input, kernel) *)
  let df_da _cp _ap at = dilated_conv3d ?padding at b s r in
  let df_db _cp _bp bt = dilated_conv3d ?padding a bt s r in
  let df_dab _cp ap at bp bt =
    (dilated_conv3d ?padding at bp s r) + (dilated_conv3d ?padding ap bt s r)
  in
  let r_d_d a b = Di_Conv3D_D_D (a, b, s, r) in
  let r_d_c a b = Di_Conv3D_D_C (a, b, s, r) in
  let r_c_d a b = Di_Conv3D_C_D (a, b, s, r) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; o:output'; s:stride; r:rate *)
and dilated_conv3d_backward_input a b s r o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.dilated_conv3d_backward_input a b s r o |> pack_arr

(* a:input; b:kernel; o:output'; s:stride; r:rate *)
and dilated_conv3d_backward_kernel a b s r o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.dilated_conv3d_backward_kernel a b s r o |> pack_arr

(* a:input; b:kernel; s:stride *)
and transpose_conv1d ?padding a b s =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(transpose_conv1d ?padding a b s)
    | _            -> error_binop "transpose_conv1d" a b
  in
  let fd a b = transpose_conv1d ?padding a b s in
  (* bilinear: product rule on (input, kernel) *)
  let df_da _cp _ap at = transpose_conv1d ?padding at b s in
  let df_db _cp _bp bt = transpose_conv1d ?padding a bt s in
  let df_dab _cp ap at bp bt =
    (transpose_conv1d ?padding at bp s) + (transpose_conv1d ?padding ap bt s)
  in
  let r_d_d a b = Tr_Conv1D_D_D (a, b, s) in
  let r_d_c a b = Tr_Conv1D_D_C (a, b, s) in
  let r_c_d a b = Tr_Conv1D_C_D (a, b, s) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; s:stride; o:output' *)
and transpose_conv1d_backward_input a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.transpose_conv1d_backward_input a b s o |> pack_arr

(* a:input; b:kernel; s:stride; o:output' *)
and transpose_conv1d_backward_kernel a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.transpose_conv1d_backward_kernel a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and transpose_conv2d ?padding a b s =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(transpose_conv2d ?padding a b s)
    | _            -> error_binop "transpose_conv2d" a b
  in
  let fd a b = transpose_conv2d ?padding a b s in
  (* bilinear: product rule on (input, kernel) *)
  let df_da _cp _ap at = transpose_conv2d ?padding at b s in
  let df_db _cp _bp bt = transpose_conv2d ?padding a bt s in
  let df_dab _cp ap at bp bt =
    (transpose_conv2d ?padding at bp s) + (transpose_conv2d ?padding ap bt s)
  in
  let r_d_d a b = Tr_Conv2D_D_D (a, b, s) in
  let r_d_c a b = Tr_Conv2D_D_C (a, b, s) in
  let r_c_d a b = Tr_Conv2D_C_D (a, b, s) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; s:stride; o:output' *)
and transpose_conv2d_backward_input a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.transpose_conv2d_backward_input a b s o |> pack_arr

(* a:input; b:kernel; s:stride; o:output' *)
and transpose_conv2d_backward_kernel a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.transpose_conv2d_backward_kernel a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and transpose_conv3d ?padding a b s =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(transpose_conv3d ?padding a b s)
    | _            -> error_binop "transpose_conv3d" a b
  in
  let fd a b = transpose_conv3d ?padding a b s in
  (* bilinear: product rule on (input, kernel) *)
  let df_da _cp _ap at = transpose_conv3d ?padding at b s in
  let df_db _cp _bp bt = transpose_conv3d ?padding a bt s in
  let df_dab _cp ap at bp bt =
    (transpose_conv3d ?padding at bp s) + (transpose_conv3d ?padding ap bt s)
  in
  let r_d_d a b = Tr_Conv3D_D_D (a, b, s) in
  let r_d_c a b = Tr_Conv3D_D_C (a, b, s) in
  let r_c_d a b = Tr_Conv3D_C_D (a, b, s) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* a:input; b:kernel; s:stride; o:output' *)
and transpose_conv3d_backward_input a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.transpose_conv3d_backward_input a b s o |> pack_arr

(* a:input; b:kernel; s:stride; o:output' *)
and transpose_conv3d_backward_kernel a b s o =
  let a = unpack_arr a in
  let b = unpack_arr b in
  let o = unpack_arr o in
  A.transpose_conv3d_backward_kernel a b s o |> pack_arr

(* reshape [a] to shape [s]; reshaping is linear, so the tangent is reshaped
   the same way *)
and reshape a s =
  let ff = function
    | Arr a -> Arr A.(reshape a s)
    | _     -> error_uniop "reshape" a
  in
  let fd a = reshape a s in
  let df _cp _ap at = reshape at s in
  let r a = Reshape_D a in
  op_d_d a ff fd df r

(* flatten [a] into a 1 x numel row *)
and flatten a = reshape a [|1; numel a|]

(* a:input; b:kernel; s:stride *)
and max_pool1d padding a b s =
  let ff = function
    | Arr a -> Arr A.(max_pool1d ~padding a b s)
    | _     -> error_uniop "max_pool1d" a
  in
  let fd a = max_pool1d padding a b s in
  (* forward mode is not supported for pooling *)
  let df _cp _ap _at = failwith "max_pool1d:df" in
  let r a = Maxpool1D_D (a, padding, b, s) in
  op_d_d a ff fd df r

(* a:input; p:padding type; b:kernel; s:stride; o:output' *)
and max_pool1d_backward p a b s o =
  let a = unpack_arr a in
  let o = unpack_arr o in
  A.max_pool1d_backward p a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and max_pool2d padding a b s =
  let ff = function
    | Arr a -> Arr A.(max_pool2d ~padding a b s)
    | _     -> error_uniop "max_pool2d" a
  in
  let fd a = max_pool2d padding a b s in
  let df _cp _ap _at = failwith "max_pool2d:df" in
  let r a = Maxpool2D_D (a, padding, b, s) in
  op_d_d a ff fd df r

(* a:input; p:padding type; b:kernel; s:stride; o:output' *)
and max_pool2d_backward p a b s o =
  let a = unpack_arr a in
  let o = unpack_arr o in
  A.max_pool2d_backward p a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and max_pool3d padding a b s =
  let ff = function
    | Arr a -> Arr A.(max_pool3d ~padding a b s)
    | _     -> error_uniop "max_pool3d" a
  in
  let fd a = max_pool3d padding a b s in
  let df _cp _ap _at = failwith "max_pool3d:df" in
  let r a = Maxpool3D_D (a, padding, b, s) in
  op_d_d a ff fd df r

(* a:input; p:padding type; b:kernel; s:stride; o:output'
*)
and max_pool3d_backward p a b s o =
  let a = unpack_arr a in
  let o = unpack_arr o in
  A.max_pool3d_backward p a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and avg_pool1d padding a b s =
  let ff = function
    | Arr a -> Arr A.(avg_pool1d ~padding a b s)
    | _     -> error_uniop "avg_pool1d" a
  in
  let fd a = avg_pool1d padding a b s in
  (* forward mode is not supported for pooling *)
  let df _cp _ap _at = failwith "avg_pool1d:df" in
  let r a = Avgpool1D_D (a, padding, b, s) in
  op_d_d a ff fd df r

(* a:input; p:padding type; b:kernel; s:stride; o:output' *)
and avg_pool1d_backward p a b s o =
  let a = unpack_arr a in
  let o = unpack_arr o in
  A.avg_pool1d_backward p a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and avg_pool2d padding a b s =
  let ff = function
    | Arr a -> Arr A.(avg_pool2d ~padding a b s)
    | _     -> error_uniop "avg_pool2d" a
  in
  let fd a = avg_pool2d padding a b s in
  let df _cp _ap _at = failwith "avg_pool2d:df" in
  let r a = Avgpool2D_D (a, padding, b, s) in
  op_d_d a ff fd df r

(* a:input; p:padding type; b:kernel; s:stride; o:output' *)
and avg_pool2d_backward p a b s o =
  let a = unpack_arr a in
  let o = unpack_arr o in
  A.avg_pool2d_backward p a b s o |> pack_arr

(* a:input; b:kernel; s:stride *)
and avg_pool3d padding a b s =
  let ff = function
    | Arr a -> Arr A.(avg_pool3d ~padding a b s)
    | _     -> error_uniop "avg_pool3d" a
  in
  let fd a = avg_pool3d padding a b s in
  let df _cp _ap _at = failwith "avg_pool3d:df" in
  let r a = Avgpool3D_D (a, padding, b, s) in
  op_d_d a ff fd df r

(* a:input; p:padding type; b:kernel; s:stride; o:output' *)
and avg_pool3d_backward p a b s o =
  let a = unpack_arr a in
  let o = unpack_arr o in
  A.avg_pool3d_backward p a b s o |> pack_arr

(* a:input; s:size *)
and upsampling2d a s =
  let ff = function
    | Arr a -> Arr A.(upsampling2d a s)
    | _     -> error_uniop "upsampling2d" a
  in
  let fd a = upsampling2d a s in
  let df _cp _ap _at = failwith "upsampling2d:df" in
  let r a = UpSampling2D_D (a, s) in
  op_d_d a ff fd df r

(* a:input; s:size; o:output' *)
and upsampling2d_backward a s o =
  let a = unpack_arr a in
  let o = unpack_arr o in
  A.upsampling2d_backward a s o |> pack_arr

(* v: padded value; p:padding index; a:input *)
and pad ?v p a =
  let ff = function
    | Arr a -> Arr A.(pad ?v p a)
    | _     -> error_uniop "pad" a
  in
  (* forward [?v] here too: a bare [pad p] erases the optional argument and
     would pad the primal of DF/DR inputs with the default value *)
  let fd a = pad ?v p a in
  let df _cp _ap _at = failwith "pad:df" in
  let r a = PAD_D (a, p) in
  op_d_d a ff fd df r

(* TODO: sources required to confirm this backward op *)
(* o:output'; p: padding index *)
and pad_backward o p =
  (* assume p is full legal index for pad operation *)
  let o = unpack_arr o in
  let os = A.shape o in
  let q = Owl_utils.llss2aarr p in
  (* turn each [before; after] padding pair into the slice
     [before; os.(i) - 1 - after] that recovers the unpadded region *)
  Array.iteri (fun i x -> x.(1) <- Pervasives.(os.(i) - 1 - x.(1))) q;
  let q = Owl_utils.aarr2llss q in
  A.(get_slice q o) |> pack_arr

(* multiply [a] element-wise by a Bernoulli mask drawn with keep-probability
   [1 - rate]; no rescaling is applied here *)
and dropout ?(rate = 0.5) a =
  let p = A.float_to_elt (1. -. rate) in
  let b =
    match primal' a with
    | Arr a -> Arr (A.bernoulli ~p (A.shape a))
    | _     -> error_uniop "dropout" a
  in
  a * b

(* concatenate two arrays along [axis]; linear, so tangents are concatenated
   with zeros standing in for the constant operand *)
and concat axis a b =
  let ff a b =
    match a, b with
    | Arr a, Arr b -> Arr A.(concatenate ~axis [|a; b|])
    | _            -> error_binop "concat" a b
  in
  let fd a b = concat axis a b in
  let df_da _cp _ap at = concat axis at (zero b) in
  let df_db _cp _bp bt = concat axis (zero a) bt in
  let df_dab _cp _ap at _bp bt = concat axis at bt in
  let r_d_d a b = Concat_D_D (a, b, axis) in
  let r_d_c a b = Concat_D_C (a, b, axis) in
  let r_c_d a b = Concat_C_D (a, b, axis) in
  op_d_d_d a b ff fd df_da df_db df_dab r_d_d r_d_c r_c_d

(* split [a] along [axis] into pieces of the given sizes [parts] *)
and split ~axis parts a =
  let ff a =
    match a with
    | Arr a -> A.(split ~axis parts a) |> Array.map (fun x -> Arr x)
    | _     -> error_uniop "split" a
  in
  let fd a = split ~axis parts a in
  let df _cp _ap _at = raise Owl_exception.NOT_IMPLEMENTED in
  let r (a, _cp_arr, aa_arr) = Split_D (a, axis, aa_arr) in
  array_op_d_d a ff fd df r

(* concatenate an array of values along [axis]. All non-constant elements
   must be of the same kind (all DR or all DF). *)
and concatenate ~axis a =
  (* mode: 0 constant, 1 reverse, 2 tangent *)
  let mode = ref 0 in
  let idxs = ref [] in
  let ai_ref = ref 0 in
  let cp =
    Array.mapi
      (fun i x ->
        match x, !mode with
        | Arr _, _ -> unpack_arr x
        | DR (_, _, _, _, ai, _), 0 ->
          ai_ref := ai;
          idxs := i :: !idxs;
          mode := 1;
          unpack_arr x
        | DR (_, _, _, _, ai, _), 1 ->
          ai_ref := ai;
          idxs := i :: !idxs;
          unpack_arr x
        | DF (_, _, ai), 0 ->
          (* switch to tangent mode; previously the mode stayed 0 and the
             result silently lost every tangent *)
          ai_ref := ai;
          mode := 2;
          unpack_arr x
        | DF (_, _, ai), 2 ->
          ai_ref := ai;
          unpack_arr x
        | _ -> error_uniop "concatenate: inconsistent array" x)
      a
    |> A.concatenate ~axis
    |> pack_arr
  in
  match !mode with
  | 0 -> cp
  | 1 ->
    DR (cp, ref (zero cp), Concatenate_D (a, axis, List.rev !idxs), ref 0,
        !ai_ref, ref 0)
  | 2 ->
    (* constant elements contribute zero tangents *)
    let at =
      a
      |> Array.map (fun x -> x |> tangent |> unpack_arr)
      |> A.concatenate ~axis
      |> pack_arr
    in
    DF (cp, at, !ai_ref)
  | _ -> error_uniop "concatenate" a.(0)

(* build an [n_rows] x [n_cols] matrix whose (i, j) element is [f i j] *)
and init_2d n_rows n_cols f =
  Array.init n_rows (fun i -> Array.init n_cols (fun j -> f i j))
  |> of_arrays
end

(* core of the reverse mode
*)

(* [reverse_reset x] prepares the reverse-mode graph rooted at [x]. Each
   reachable DR node has its adjoint reset with [reset_zero] and its fan-out
   counter [af] and [tracker] incremented. The inputs of a node are only
   visited the first time the node is reached ([!af = 1] and [!tracker = 1]).
   Non-DR values are skipped. *)
let reverse_reset x =
  (* Operands recorded by a trace operation that must be visited next, in
     the order they are placed at the front of the work list. *)
  let inputs_of = function
    | Noop -> []
    (* both operands differentiable *)
    | Add_D_D (a, b) | Sub_D_D (a, b) | Mul_D_D (a, b) | Div_D_D (a, b)
    | Pow_D_D (a, b) | Atan2_D_D (a, b) | Dot_D_D (a, b)
    | SetI_D_D (a, _, _, b) | AddI_D_D (a, _, _, b)
    | Set_Slice_D_D (a, b, _)
    | Lyapunov_D_D (a, b) | Discrete_Lyapunov_D_D (a, b)
    | Add_Row_D_D (a, b, _)
    | Conv1D_D_D (a, b, _) | Conv2D_D_D (a, b, _) | Conv3D_D_D (a, b, _)
    | Di_Conv1D_D_D (a, b, _, _) | Di_Conv2D_D_D (a, b, _, _)
    | Di_Conv3D_D_D (a, b, _, _)
    | Tr_Conv1D_D_D (a, b, _) | Tr_Conv2D_D_D (a, b, _)
    | Tr_Conv3D_D_D (a, b, _)
    | Concat_D_D (a, b, _) -> [a; b]
    (* only the first operand differentiable *)
    | Add_D_C (a, _) | Sub_D_C (a, _) | Mul_D_C (a, _) | Div_D_C (a, _)
    | Pow_D_C (a, _) | Atan2_D_C (a, _) | Dot_D_C (a, _)
    | SetI_D_C (a, _, _, _) | AddI_D_C (a, _, _, _)
    | Set_Slice_D_C (a, _, _)
    | Lyapunov_D_C (a, _) | Discrete_Lyapunov_D_C (a, _)
    | Add_Row_D_C (a, _, _)
    | Conv1D_D_C (a, _, _) | Conv2D_D_C (a, _, _) | Conv3D_D_C (a, _, _)
    | Di_Conv1D_D_C (a, _, _, _) | Di_Conv2D_D_C (a, _, _, _)
    | Di_Conv3D_D_C (a, _, _, _)
    | Tr_Conv1D_D_C (a, _, _) | Tr_Conv2D_D_C (a, _, _)
    | Tr_Conv3D_D_C (a, _, _)
    | Concat_D_C (a, _, _)
    (* single-input operations *)
    | Neg_D a | Abs_D a | Signum_D a | Floor_D a | Ceil_D a | Round_D a
    | Sqr_D a | Sqrt_D a | Log_D a | Log2_D a | Log10_D a | Exp_D a
    | Sin_D a | Cos_D a | Tan_D a | Sinh_D a | Cosh_D a | Tanh_D a
    | Asin_D a | Acos_D a | Atan_D a | Asinh_D a | Acosh_D a | Atanh_D a
    | Get_Item (a, _, _) | Get_Slice_D (a, _)
    | Sum_D a | Sum__D (a, _) | Sum___D (a, _)
    | Trans_D a | L1Norm_D a | L2Norm_D a | L2NormS_D a
    | Sigmoid_D a | Relu_D a | Inv_D a | Logdet_D a
    | Chol_D (a, _) | QR_D (a, _, _) | Svd_D (a, _, _, _)
    | Diag_D (_, a) | Diagm_D (_, a) | Triu_D (_, a) | Tril_D (_, a)
    | Trace_D a | Get_Row_D (a, _) | Reshape_D a
    | Maxpool1D_D (a, _, _, _) | Maxpool2D_D (a, _, _, _)
    | Maxpool3D_D (a, _, _, _)
    | Avgpool1D_D (a, _, _, _) | Avgpool2D_D (a, _, _, _)
    | Avgpool3D_D (a, _, _, _)
    | UpSampling2D_D (a, _) | PAD_D (a, _) | Split_D (a, _, _) -> [a]
    (* only the second operand differentiable *)
    | Add_C_D (_, b) | Sub_C_D (_, b) | Mul_C_D (_, b) | Div_C_D (_, b)
    | Pow_C_D (_, b) | Atan2_C_D (_, b) | Dot_C_D (_, b)
    | SetI_C_D (_, _, _, b) | AddI_C_D (_, _, _, b)
    | Set_Slice_C_D (_, b, _)
    | Lyapunov_C_D (_, b) | Discrete_Lyapunov_C_D (_, b)
    | Add_Row_C_D (_, b, _)
    | Conv1D_C_D (_, b, _) | Conv2D_C_D (_, b, _) | Conv3D_C_D (_, b, _)
    | Di_Conv1D_C_D (_, b, _, _) | Di_Conv2D_C_D (_, b, _, _)
    | Di_Conv3D_C_D (_, b, _, _)
    | Tr_Conv1D_C_D (_, b, _) | Tr_Conv2D_C_D (_, b, _)
    | Tr_Conv3D_C_D (_, b, _)
    | Concat_C_D (_, b, _) -> [b]
    (* operations over arrays of values *)
    | Of_Rows_D a -> Array.to_list a
    | Of_Arrays_D (a, idxs) -> List.map (fun (i, j) -> a.(i).(j)) idxs
    | Concatenate_D (a, _, idx) -> List.map (fun i -> a.(i)) idx
  in
  (* depth-first walk over an explicit work list *)
  let rec visit = function
    | [] -> ()
    | DR (_, aa, ao, af, _, tracker) :: rest ->
      aa := reset_zero !aa;
      af := !af + 1;
      tracker := succ !tracker;
      if !af = 1 && !tracker = 1
      then visit (List.append (inputs_of ao) rest)
      else visit rest
    | _ :: rest -> visit rest
  in
  visit [x]

(* check adjoint a and its update v, ensure rank a >= rank v. This function
fixes the inconsistent shapes between a and v by undoing the broadcasting
   done in the forward pass, i.e. by summing v over the axes on which the
   aligned shapes differ. Note that padding is on the left due to the expand
   function called in broadcasting. *)
let _shrink a v =
  match a, v with
  (* scalar adjoint: collapse the whole update into a single element *)
  | F _, Arr v -> F (A.sum' v)
  | Arr a, Arr v -> (
      let shp_a = A.shape a in
      let shp_v = A.shape v in
      if shp_a <> shp_v then (
        let shp_a, shp_v = Owl_utils_array.align `Left 1 shp_a shp_v in
        (* indices where the left-aligned shapes disagree *)
        let axis = Owl_utils_array.filter2_i ( <> ) shp_a shp_v in
        Arr (A.sum_reduce ~axis v))
      else Arr v)
  | _a, v -> v

(* [reverse_push v x] back-propagates the seed adjoint [v] from [x]. Each DR
   node accumulates the (shape-corrected) incoming adjoint into [aa] and
   decrements its fan-out counter [af]. Once every consumer has contributed
   ([!af = 0]) and [!tracker = 1], the node pushes adjoints to its inputs using
   the rule for the recorded operation [ao]. Otherwise [tracker] is decremented
   and the node's inputs are not visited yet. *)
let reverse_push v x =
  let open Maths in
  let rec push xs =
    match xs with
    | [] -> ()
    | (v, x) :: t -> (
        match x with
        | DR (ap, aa, ao, af, _ai, tracker) -> (
            let v = _shrink !aa v in
            aa := Maths.(!aa + v);
            af := Pervasives.(!af - 1);
            if (!af = 0) && (!tracker = 1) then (
              match ao with
              | Noop -> push t
              | Add_D_D (a, b) -> push ((!aa, a) :: (!aa, b) :: t)
              | Add_D_C (a, _) -> push ((!aa, a) :: t)
              | Add_C_D (_, b) -> push ((!aa, b) :: t)
              | Sub_D_D (a, b) -> push ((!aa, a) :: (neg !aa, b) :: t)
              | Sub_D_C (a, _) -> push ((!aa, a) :: t)
              | Sub_C_D (_, b) -> push ((neg !aa, b) :: t)
              | Mul_D_D (a, b) ->
                push (((!aa * primal b), a) :: ((!aa * primal a), b) :: t)
              | Mul_D_C (a, b) -> push (((!aa * b), a) :: t)
              | Mul_C_D (a, b) -> push (((!aa * a), b) :: t)
              | Div_D_D (a, b) ->
                push (((!aa / (primal b)), a)
                      :: ((!aa * ((neg (primal a)) / ((primal b) * (primal b)))), b)
                      :: t)
              | Div_D_C (a, b) -> push (((!aa / b), a) :: t)
              | Div_C_D (a, b) ->
                push (((!aa * ((neg a) / ((primal b) * (primal b)))), b) :: t)
              | Pow_D_D (a, b) ->
                push (((!aa * ((primal a) ** ((primal b) - (pack_flt 1.))) * (primal b)), a)
                      :: ((!aa * ((primal a) ** (primal b)) * log (primal a)), b)
                      :: t)
              | Pow_D_C (a, b) ->
                push (((!aa * ((primal a) ** (b - (pack_flt 1.))) * b), a) :: t)
              | Pow_C_D (a, b) ->
                push (((!aa * (a ** (primal b)) * log a), b) :: t)
              | Atan2_D_D (a, b) ->
                let d = (sqr (primal a)) + (sqr (primal b)) in
                push (((!aa * (primal b) / d), a)
                      :: ((!aa * (neg (primal a)) / d), b)
                      :: t)
              | Atan2_D_C (a, b) ->
                push (((!aa * b / ((sqr (primal a)) + (sqr b))), a) :: t)
              | Atan2_C_D (a, b) ->
                push (((!aa * (neg a) / ((sqr a) + (sqr (primal b)))), b) :: t)
              | Neg_D a -> push ((neg !aa, a) :: t)
              | Abs_D a -> push (((!aa * signum (primal a)), a) :: t)
              | Signum_D a -> push ((zero a, a) :: t)
              | Floor_D a -> push ((zero a, a) :: t)
              | Ceil_D a -> push ((zero a, a) :: t)
              | Round_D a -> push ((zero a, a) :: t)
              | Sqr_D a -> push (((!aa * (primal a) * (pack_flt 2.)), a) :: t)
              (* ap = sqrt a, so d sqrt a = 1 / (2 ap) *)
              | Sqrt_D a -> push (((!aa / ((pack_flt 2.) * ap)), a) :: t)
              | Log_D a -> push (((!aa / (primal a)), a) :: t)
              (* d log2 x / dx = 1 / (x ln 2) = log2e / x; the previous code
                 divided by log2e instead of multiplying by it *)
              | Log2_D a ->
                push ((((!aa * (pack_flt Owl_const.log2e)) / (primal a)), a) :: t)
              (* d log10 x / dx = 1 / (x ln 10) = log10e / x *)
              | Log10_D a ->
                push ((((!aa * (pack_flt Owl_const.log10e)) / (primal a)), a) :: t)
              (* ap = exp a is its own derivative *)
              | Exp_D a -> push (((!aa * ap), a) :: t)
              | Sin_D a -> push (((!aa * cos (primal a)), a) :: t)
              | Cos_D a -> push (((!aa * (neg (sin (primal a)))), a) :: t)
              | Tan_D a -> push (((!aa / (sqr (cos (primal a)))), a) :: t)
              | Sinh_D a -> push (((!aa * (cosh (primal a))), a) :: t)
              | Cosh_D a -> push (((!aa * (sinh (primal a))), a) :: t)
              | Tanh_D a -> push (((!aa / (sqr (cosh (primal a)))), a) :: t)
              | Asin_D a ->
                push (((!aa / sqrt ((pack_flt 1.) - sqr (primal a))), a) :: t)
              | Acos_D a ->
                push ((((neg !aa) / sqrt ((pack_flt 1.) - sqr (primal a))), a) :: t)
              | Atan_D a ->
                push (((!aa / ((pack_flt 1.) + sqr (primal a))), a) :: t)
              | Asinh_D a ->
                push (((!aa / sqrt ((sqr (primal a)) + (pack_flt 1.))), a) :: t)
              | Acosh_D a ->
                push (((!aa / sqrt ((sqr (primal a)) - (pack_flt 1.))), a) :: t)
              | Atanh_D a ->
                push (((!aa / ((pack_flt 1.) - sqr (primal a))), a) :: t)
              | Get_Item (a, i, j) ->
                push ((set_item (zero a) i j (sum' !aa), a) :: t)
              | SetI_D_D (a, i, j, b) ->
                push ((set_item !aa i j (pack_flt 0.), a)
                      :: (get_item !aa i j, b)
                      :: t)
              | SetI_D_C (a, i, j, _) ->
                push ((set_item !aa i j (pack_flt 0.), a) :: t)
              | SetI_C_D (_, i, j, b) -> push ((get_item !aa i j, b) :: t)
              | AddI_D_D (a, i, j, b) ->
                push ((!aa, a) :: (get_item !aa i j, b) :: t)
              | AddI_D_C (a, _, _, _) -> push ((!aa, a) :: t)
              | AddI_C_D (_, i, j, b) -> push ((get_item !aa i j, b) :: t)
              | Get_Slice_D (a, i) -> push ((set_slice i (zero a) !aa, a) :: t)
              | Set_Slice_D_D (a, b, i) ->
                push ((set_slice i !aa (zero b), a) :: (get_slice i !aa, b) :: t)
              | Set_Slice_D_C (a, b, i) ->
                push ((set_slice i !aa (zero b), a) :: t)
              | Set_Slice_C_D (_a, b, i) -> push ((get_slice i !aa, b) :: t)
              | Sum_D a -> push ((!aa, a) :: t)
              | Sum__D (a, i) ->
                let s = shape a in
                let reps = Array.(make (length s) 1) in
                reps.(i) <- s.(i);
                push ((repeat !aa reps, a) :: t)
              | Sum___D (a, i) ->
                let s = shape a in
                let reps = Array.(make (length s) 1) in
                Array.iter (fun j -> reps.(j) <- s.(j)) i;
                push ((repeat !aa reps, a) :: t)
              | Dot_D_D (a, b) ->
                push (((dot !aa (transpose (primal b))), a)
                      :: ((dot (transpose (primal a)) !aa), b)
                      :: t)
              | Dot_D_C (a, b) -> push (((dot !aa (transpose b)), a) :: t)
              | Dot_C_D (a, b) -> push (((dot (transpose a) !aa), b) :: t)
              | Trans_D a -> push (((transpose !aa), a) :: t)
              | L1Norm_D a -> push (((!aa * (signum (primal a))), a) :: t)
              (* ap = ||a||_2, so d ||a|| / da = a / ||a|| *)
              | L2Norm_D a -> push (((!aa / ap * (primal a)), a) :: t)
              | L2NormS_D a ->
                push (((!aa * (pack_flt 2.) * (primal a)), a) :: t)
              (* ap = sigmoid a *)
              | Sigmoid_D a ->
                push (((!aa * ap * ((pack_flt 1.) - ap)), a) :: t)
              | Relu_D a ->
                push (((!aa * ((signum (primal a) + (pack_flt 1.)) / (pack_flt 2.))), a) :: t)
              (* ap = inv a *)
              | Inv_D a ->
                let dpt = transpose ap in
                push ((((neg dpt) *@ !aa *@ dpt), a) :: t)
              | Logdet_D a ->
                push ((!aa * (transpose (inv (primal a))), a) :: t)
              | Chol_D (a, upper) ->
                push ((_chol_backward ap !aa upper, a) :: t)
              (* [aa] here is rebound to the recorded tuple of output adjoints *)
              | QR_D (a, o, aa) -> push ((_qr_backward o aa, a) :: t)
              | Svd_D (a, o, aa, thin) -> push ((_svd_backward o aa thin, a) :: t)
              | Lyapunov_D_D (a, q) ->
                let abar, qbar = _lyapunov_backward_aq (primal a) !aa ap in
                push ((abar, a) :: (qbar, q) :: t)
              | Lyapunov_D_C (a, _) ->
                push (((_lyapunov_backward_a (primal a) !aa ap), a) :: t)
              | Lyapunov_C_D (a, q) ->
                push (((_lyapunov_backward_q (primal a) !aa), q) :: t)
              | Discrete_Lyapunov_D_D (a, q) ->
                let abar, qbar =
                  _discrete_lyapunov_backward_aq (primal a) !aa ap
                in
                push ((abar, a) :: (qbar, q) :: t)
              | Discrete_Lyapunov_D_C (a, _) ->
                push (((_discrete_lyapunov_backward_a (primal a) !aa ap), a) :: t)
              | Discrete_Lyapunov_C_D (a, q) ->
                push (((_discrete_lyapunov_backward_q (primal a) !aa), q) :: t)
              | Diag_D (k, a) ->
                let m = col_num a in
                let l = Pervasives.(m - k) in
                (* scatter the adjoint row back onto the k-th diagonal *)
                let rec accu i a_ =
                  if i < l
                  then accu (succ i) (set_item a_ i Pervasives.(k + i) (get_item !aa 0 i))
                  else a_
                in
                let abar = accu 0 (zero a) in
                push ((abar, a) :: t)
              | Diagm_D (k, a) -> push ((diag ~k !aa, a) :: t)
              | Triu_D (k, a) -> push ((triu ~k !aa, a) :: t)
              | Tril_D (k, a) -> push ((tril ~k !aa, a) :: t)
              | Trace_D a ->
                let m = col_num a in
                let abar = !aa * (diagm (pack_arr A.(ones [|1; m|]))) in
                push ((abar, a) :: t)
              | Add_Row_D_D (a, b, i) ->
                push ((!aa, a) :: (get_row !aa i, b) :: t)
              | Add_Row_D_C (a, _b, _i) -> push ((!aa, a) :: t)
              | Add_Row_C_D (_a, b, i) -> push ((get_row !aa i, b) :: t)
              | Get_Row_D (a, i) ->
                (adjref a) := add_row (adjval a) !aa i;
                push ((zero a, a) :: t)
              | Of_Rows_D a ->
                push (t |> List.append (a |> Array.to_list |> List.mapi (fun i v -> (get_row !aa i, v))))
              | Of_Arrays_D (a, idxs) ->
                let aa_arrays = to_arrays !aa in
                push (t |> List.append (idxs |> List.map (fun (i, j) -> (aa_arrays.(i).(j), a.(i).(j)))))
              | Conv1D_D_D (a, b, s) ->
                push ((conv1d_backward_input a b s !aa, a)
                      :: (conv1d_backward_kernel a b s !aa, b)
                      :: t)
              | Conv1D_D_C (a, b, s) ->
                push ((conv1d_backward_input a b s !aa, a) :: t)
              | Conv1D_C_D (a, b, s) ->
                push ((conv1d_backward_kernel a b s !aa, b) :: t)
              | Conv2D_D_D (a, b, s) ->
                push ((conv2d_backward_input a b s !aa, a)
                      :: (conv2d_backward_kernel a b s !aa, b)
                      :: t)
              | Conv2D_D_C (a, b, s) ->
                push ((conv2d_backward_input a b s !aa, a) :: t)
              | Conv2D_C_D (a, b, s) ->
                push ((conv2d_backward_kernel a b s !aa, b) :: t)
              | Conv3D_D_D (a, b, s) ->
                push ((conv3d_backward_input a b s !aa, a)
                      :: (conv3d_backward_kernel a b s !aa, b)
                      :: t)
              | Conv3D_D_C (a, b, s) ->
                push ((conv3d_backward_input a b s !aa, a) :: t)
              | Conv3D_C_D (a, b, s) ->
                push ((conv3d_backward_kernel a b s !aa, b) :: t)
              | Di_Conv1D_D_D (a, b, s, r) ->
                push ((dilated_conv1d_backward_input a b s r !aa, a)
                      :: (dilated_conv1d_backward_kernel a b s r !aa, b)
                      :: t)
              | Di_Conv1D_D_C (a, b, s, r) ->
                push ((dilated_conv1d_backward_input a b s r !aa, a) :: t)
              | Di_Conv1D_C_D (a, b, s, r) ->
                push ((dilated_conv1d_backward_kernel a b s r !aa, b) :: t)
              | Di_Conv2D_D_D (a, b, s, r) ->
                push ((dilated_conv2d_backward_input a b s r !aa, a)
                      :: (dilated_conv2d_backward_kernel a b s r !aa, b)
                      :: t)
              | Di_Conv2D_D_C (a, b, s, r) ->
                push ((dilated_conv2d_backward_input a b s r !aa, a) :: t)
              | Di_Conv2D_C_D (a, b, s, r) ->
                push ((dilated_conv2d_backward_kernel a b s r !aa, b) :: t)
              | Di_Conv3D_D_D (a, b, s, r) ->
                push ((dilated_conv3d_backward_input a b s r !aa, a)
                      :: (dilated_conv3d_backward_kernel a b s r !aa, b)
                      :: t)
              | Di_Conv3D_D_C (a, b, s, r) ->
                push ((dilated_conv3d_backward_input a b s r !aa, a) :: t)
              | Di_Conv3D_C_D (a, b, s, r) ->
                push ((dilated_conv3d_backward_kernel a b s r !aa, b) :: t)
              | Tr_Conv1D_D_D (a, b, s) ->
                push ((transpose_conv1d_backward_input a b s !aa, a)
                      :: (transpose_conv1d_backward_kernel a b s !aa, b)
                      :: t)
              | Tr_Conv1D_D_C (a, b, s) ->
                push ((transpose_conv1d_backward_input a b s !aa, a) :: t)
              | Tr_Conv1D_C_D (a, b, s) ->
                push ((transpose_conv1d_backward_kernel a b s !aa, b) :: t)
              | Tr_Conv2D_D_D (a, b, s) ->
                push ((transpose_conv2d_backward_input a b s !aa, a)
                      :: (transpose_conv2d_backward_kernel a b s !aa, b)
                      :: t)
              | Tr_Conv2D_D_C (a, b, s) ->
                push ((transpose_conv2d_backward_input a b s !aa, a) :: t)
              | Tr_Conv2D_C_D (a, b, s) ->
                push ((transpose_conv2d_backward_kernel a b s !aa, b) :: t)
              | Tr_Conv3D_D_D (a, b, s) ->
                push ((transpose_conv3d_backward_input a b s !aa, a)
                      :: (transpose_conv3d_backward_kernel a b s !aa, b)
                      :: t)
              | Tr_Conv3D_D_C (a, b, s) ->
                push ((transpose_conv3d_backward_input a b s !aa, a) :: t)
              | Tr_Conv3D_C_D (a, b, s) ->
                push ((transpose_conv3d_backward_kernel a b s !aa, b) :: t)
              | Reshape_D a -> push ((reshape !aa (shape (primal a)), a) :: t)
              | Maxpool1D_D (a, p, d, s) ->
                push ((max_pool1d_backward p (primal a) d s !aa, a) :: t)
              | Maxpool2D_D (a, p, d, s) ->
                push ((max_pool2d_backward p (primal a) d s !aa, a) :: t)
              | Maxpool3D_D (a, p, d, s) ->
                push ((max_pool3d_backward p (primal a) d s !aa, a) :: t)
              | Avgpool1D_D (a, p, d, s) ->
                push ((avg_pool1d_backward p (primal a) d s !aa, a) :: t)
              | Avgpool2D_D (a, p, d, s) ->
                push ((avg_pool2d_backward p (primal a) d s !aa, a) :: t)
              | Avgpool3D_D (a, p, d, s) ->
                push ((avg_pool3d_backward p (primal a) d s !aa, a) :: t)
              | UpSampling2D_D (a, s) ->
                push ((upsampling2d_backward (primal a) s !aa, a) :: t)
              | PAD_D (a, p) -> push ((pad_backward !aa p, a) :: t)
              | Concat_D_D (a, b, i) ->
                let s = split ~axis:i [|(shape a).(i); (shape b).(i)|] !aa in
                push ((s.(0), a) :: (s.(1), b) :: t)
              | Concat_D_C (a, b, i) ->
                let s = split ~axis:i [|(shape a).(i); (shape b).(i)|] !aa in
                push ((s.(0), a) :: t)
              | Concat_C_D (a, b, i) ->
                let s = split ~axis:i [|(shape a).(i); (shape b).(i)|] !aa in
                push ((s.(1), b) :: t)
              | Split_D (a, axis, aa_arr) ->
                push ((concatenate ~axis (Array.map (fun aa -> !aa) aa_arr), a) :: t)
              | Concatenate_D (a, axis, idx) ->
                let aa_arr = split ~axis (Array.map (fun x -> (shape x).(axis)) a) !aa in
                push (t |> List.(append (map (fun i -> aa_arr.(i), a.(i)) idx))))
            else begin
              (* other consumers of this node are still pending *)
              tracker := pred !tracker;
              push t
            end)
        | _ -> push t)
  in
  push [(v, x)]

(* run a full reverse pass: reset the graph rooted at [x], then propagate [v] *)
let reverse_prop v x =
  reverse_reset x;
  reverse_push v x

(* convenient wrappers *)
let make_forward p t i = DF (p, t, i)

let make_reverse p i = DR (p, ref (zero p), Noop, ref 0, i, ref 0)

(* derivative of f (scalar -> scalar) at
x, forward ad; also returns the primal value f x *)
let diff' f x =
  let y = f (make_forward x (pack_flt 1.) (tag ())) in
  (primal y, tangent y)

(* [diff f x]: derivative of the scalar -> scalar function [f] at [x],
   computed in forward mode *)
let diff f x = snd (diff' f x)

(* [grad' f x]: the value of [f x] together with the gradient of the
   vector -> scalar function [f] at [x], computed in reverse mode *)
let grad' f x =
  let xr = make_reverse x (tag ()) in
  let y = f xr in
  reverse_prop (pack_flt 1.) y;
  (primal y, adjval xr)

(* [grad f x]: gradient of [f] (vector -> scalar) at [x], reverse mode *)
let grad f x = snd (grad' f x)

(* [jacobianv' f x v]: the value of [f x] together with the Jacobian-vector
   product of [f] at [x] along [v]; [x] is seeded with tangent [v] *)
let jacobianv' f x v =
  let y = f (make_forward x v (tag ())) in
  (primal y, tangent y)

(* [jacobianv f x v]: Jacobian-vector product of [f] at [x] along [v],
   forward mode *)
let jacobianv f x v = snd (jacobianv' f x v)

(* [jacobianTv' f x v]: the value of [f x] together with the transposed
   Jacobian-vector product of [f] at [x] along [v]; the output is seeded with
   adjoint [v] *)
let jacobianTv' f x v =
  let xr = make_reverse x (tag ()) in
  let y = f xr in
  reverse_prop v y;
  (primal y, primal (adjval xr))

(* [jacobianTv f x v]: transposed Jacobian-vector product of [f] at [x]
   along [v], reverse mode *)
let jacobianTv f x v = snd (jacobianTv' f x v)

(* [jacobian' f x]: the value of [f x] together with the m x n Jacobian of
   [f] at [x]. Both x and y are treated as row vectors (their column counts
   give n and m). Uses n forward passes when m > n, otherwise m reverse
   passes. *)
let jacobian' f x =
  let y = primal (f x) in
  let m = col_num y in
  let n = col_num x in
  let z = A.empty [|m; n|] in
  (* 1 x [len] row vector with a one at column [k] *)
  let basis len k =
    let e = A.zeros [|1; len|] in
    A.(set e [|0; k|] (float_to_elt 1.));
    Arr e
  in
  if m > n then begin
    let columns = Array.init n (fun i -> jacobianv f x (basis n i)) in
    Array.iteri
      (fun i c ->
        match c with
        | Arr c -> A.copy_col_to (A.transpose c) z i
        | _ -> failwith "error: jacobian")
      columns
  end
  else begin
    let rows = Array.init m (fun i -> jacobianTv f x (basis m i)) in
    Array.iteri
      (fun i r ->
        match r with
        | Arr r -> A.copy_row_to r z i
        | _ -> failwith "error: jacobian")
      rows
  end;
  (y, Arr z)

(* [jacobian f x]: Jacobian of [f] at [x] *)
let jacobian f x = snd (jacobian' f x)

(* [gradhessian f x]: gradient and Hessian of the vector -> scalar [f] at
   [x]; the Hessian is the Jacobian of the gradient *)
let gradhessian f x = jacobian' (grad f) x

(* original value, gradient, and hessian *)
let gradhessian' f x =
  let g, h = gradhessian f x in
  (f x, g, h)

(* [hessian f x]: Hessian of [f] at [x] *)
let hessian f x = jacobian (grad f) x

(* original value and hessian of f *)
let hessian' f x = (f x, hessian f x)

(* original value, gradient-vector product, hessian-vector product
*)
let gradhessianv' f x v =
  (* differentiating xi -> J(xi) v in reverse mode yields the
     gradient-vector product as the primal and the hessian-vector product
     as the adjoint *)
  let gv, hv = grad' (fun xi -> jacobianv f xi v) x in
  let fx = f x in
  (fx, gv, hv)

(* gradient-vector product and hessian-vector product *)
let gradhessianv f x v =
  match gradhessianv' f x v with
  | (_, gv, hv) -> (gv, hv)

(* original value and hessian-vector product *)
let hessianv' f x v =
  match gradhessianv' f x v with
  | (fv, _, hv) -> (fv, hv)

(* hessian-vector product only *)
let hessianv f x v =
  match gradhessianv' f x v with
  | (_, _, hv) -> hv

(* laplacian of f: trace of its Hessian at [x] *)
let laplacian f x = F (A.trace (unpack_arr (hessian f x)))

(* original value and laplacian of f *)
let laplacian' f x = (f x, laplacian f x)

(* Wrapper for the Mat module: two-dimensional helpers over packed arrays *)
module Mat = struct
  (* constructors taking (rows, cols) *)
  let empty m n = pack_arr (A.empty [|m; n|])

  let zeros m n = pack_arr (A.zeros [|m; n|])

  let ones m n = pack_arr (A.ones [|m; n|])

  let uniform ?a ?b m n = pack_arr (A.uniform ?a ?b [|m; n|])

  let gaussian ?mu ?sigma m n = pack_arr (A.gaussian ?mu ?sigma [|m; n|])

  let reset x = A.reset (unpack_arr x)

  let reshape m n x = Maths.reshape x [|m; n|]

  (* (rows, cols) taken from the first two dimensions *)
  let shape x =
    let dims = A.shape (unpack_arr x) in
    (dims.(0), dims.(1))

  let row_num x = (A.shape (unpack_arr x)).(0)

  let col_num x = (A.shape (unpack_arr x)).(1)

  let numel x = numel x

  let row x i = Maths.get_row x i

  let get x i j = Maths.get_item x i j

  let set x i j a = Maths.set_item x i j a

  (* unary math operators *)
  let mean x = Maths.mean x

  (* binary math operators *)
  let add x y = Maths.add x y

  let sub x y = Maths.sub x y

  let mul x y = Maths.mul x y

  let div x y = Maths.div x y

  let dot x y = Maths.dot x y

  (* apply [f] to every row and reassemble the rows *)
  let map_by_row f x = Maths.of_rows (Array.map f (Maths.to_rows x))

  let print x = A.print (unpack_arr x)

  let of_arrays x = pack_arr (A.of_arrays x)
end

(* Wrapper for the Arr module: n-dimensional helpers over packed arrays *)
module Arr = struct
  (* constructors taking a shape array *)
  let empty d = pack_arr (A.empty d)

  let zeros d = pack_arr (A.zeros d)

  let ones d = pack_arr (A.ones d)

  let uniform ?a ?b d = pack_arr (A.uniform ?a ?b d)

  let gaussian ?mu ?sigma d = pack_arr (A.gaussian ?mu ?sigma d)

  let reset x = A.reset (unpack_arr x)

  let reshape x s = Maths.reshape x s

  let shape x = A.shape (unpack_arr x)

  let numel x = numel x

  (* binary math operators *)
  let add x y = Maths.add x y

  let sub x y = Maths.sub x y

  let mul x y = Maths.mul x y

  let div x y = Maths.div x y

  let dot x y = Maths.dot x y
end

(* _traverse_trace and its related functions are used to convert the
computation graph generated in backward mode into a human-readable format.
You can write your own conversion function to generate whatever format you need.
*)
  (* [_traverse_trace x]: walk the graph reachable from the node list [x]
     and return a hash table mapping each visited node to
     (index, operation label, parent nodes). Indices are assigned in
     visiting order. *)
  let _traverse_trace x =
    (* init variables for tracking nodes and indices *)
    let nodes = Hashtbl.create 512 in
    let index = ref 0 in
    (* local function to traverse the nodes *)
    let rec push tlist =
      match tlist with
      | [] -> ()
      | hd :: tl ->
        (* only expand nodes that have not been visited before *)
        if not (Hashtbl.mem nodes hd) then begin
          let op, prev =
            match hd with
            | DR (_ap, _aa, ao, _af, _ai, _) ->
              begin match ao with
              | Noop -> "Noop", []
              | Add_D_D (a, b) -> "Add_D_D", [a; b]
              | Add_D_C (a, b) -> "Add_D_C", [a; b]
              | Add_C_D (a, b) -> "Add_C_D", [a; b]
              | Sub_D_D (a, b) -> "Sub_D_D", [a; b]
              | Sub_D_C (a, b) -> "Sub_D_C", [a; b]
              | Sub_C_D (a, b) -> "Sub_C_D", [a; b]
              | Mul_D_D (a, b) -> "Mul_D_D", [a; b]
              | Mul_D_C (a, b) -> "Mul_D_C", [a; b]
              | Mul_C_D (a, b) -> "Mul_C_D", [a; b]
              | Div_D_D (a, b) -> "Div_D_D", [a; b]
              | Div_D_C (a, b) -> "Div_D_C", [a; b]
              | Div_C_D (a, b) -> "Div_C_D", [a; b]
              | Pow_D_D (a, b) -> "Pow_D_D", [a; b]
              | Pow_D_C (a, b) -> "Pow_D_C", [a; b]
              | Pow_C_D (a, b) -> "Pow_C_D", [a; b]
              | Atan2_D_D (a, b) -> "Atan2_D_D", [a; b]
              | Atan2_D_C (a, b) -> "Atan2_D_C", [a; b]
              | Atan2_C_D (a, b) -> "Atan2_C_D", [a; b]
              | Neg_D a -> "Neg_D", [a]
              | Abs_D a -> "Abs_D", [a]
              | Signum_D a -> "Signum_D", [a]
              | Floor_D a -> "Floor_D", [a]
              | Ceil_D a -> "Ceil_D", [a]
              | Round_D a -> "Round_D", [a]
              | Sqr_D a -> "Sqr_D", [a]
              | Sqrt_D a -> "Sqrt_D", [a]
              | Log_D a -> "Log_D", [a]
              | Log2_D a -> "Log2_D", [a]
              | Log10_D a -> "Log10_D", [a]
              | Exp_D a -> "Exp_D", [a]
              | Sin_D a -> "Sin_D", [a]
              | Cos_D a -> "Cos_D", [a]
              | Tan_D a -> "Tan_D", [a]
              | Sinh_D a -> "Sinh_D", [a]
              | Cosh_D a -> "Cosh_D", [a]
              | Tanh_D a -> "Tanh_D", [a]
              | Asin_D a -> "Asin_D", [a]
              | Acos_D a -> "Acos_D", [a]
              | Atan_D a -> "Atan_D", [a]
              | Asinh_D a -> "Asinh_D", [a]
              | Acosh_D a -> "Acosh_D", [a]
              | Atanh_D a -> "Atanh_D", [a]
              | Get_Item (a, _i, _j) -> "Get_Item", [a]
              | SetI_D_D (a, _i, _j, b) -> "SetI_D_D", [a; b]
              | SetI_D_C (a, _i, _j, b) -> "SetI_D_C", [a; b]
              | SetI_C_D (a, _i, _j, b) -> "SetI_C_D", [a; b]
              | AddI_D_D (a, _i, _j, b) -> "AddI_D_D", [a; b]
              | AddI_D_C (a, _i, _j, b) -> "AddI_D_C", [a; b]
              | AddI_C_D (a, _i, _j, b) -> "AddI_C_D", [a; b]
              | Get_Slice_D (a, _i) -> "Get_Slice_D", [a]
              | Set_Slice_D_D (a, b, _i) -> "Set_Slice_D_D", [a; b]
              | Set_Slice_D_C (a, b, _i) -> "Set_Slice_D_C", [a; b]
              | Set_Slice_C_D (a, b, _i) -> "Set_Slice_C_D", [a; b]
              | Sum_D a -> "Sum_D", [a]
              | Sum__D (a, _i) -> "Sum__D", [a]
              | Sum___D (a, _i) -> "Sum___D", [a]
              | Dot_D_D (a, b) -> "Dot_D_D", [a; b]
              | Dot_D_C (a, b) -> "Dot_D_C", [a; b]
              | Dot_C_D (a, b) -> "Dot_C_D", [a; b]
              | Trans_D a -> "Trans_D", [a]
              | L1Norm_D a -> "L1Norm_D", [a]
              | L2Norm_D a -> "L2Norm_D", [a]
              | L2NormS_D a -> "L2NormS_D", [a]
              | Sigmoid_D a -> "Sigmoid_D", [a]
              | Relu_D a -> "Relu_D", [a]
              | Inv_D a -> "Inv_D", [a]
              (* previously mislabelled "Inv_D" *)
              | Logdet_D a -> "Logdet_D", [a]
              | Chol_D (a, _) -> "Chol_D", [a]
              | QR_D (a, _, _) -> "QR_D", [a]
              | Svd_D (a, _, _, _) -> "Svd_D", [a]
              | Lyapunov_D_D (a, q) -> "Lyapunov_D_D", [a; q]
              | Lyapunov_C_D (a, q) -> "Lyapunov_C_D", [a; q]
              | Lyapunov_D_C (a, q) -> "Lyapunov_D_C", [a; q]
              | Discrete_Lyapunov_D_D (a, q) -> "Discrete_Lyapunov_D_D", [a; q]
              | Discrete_Lyapunov_C_D (a, q) -> "Discrete_Lyapunov_C_D", [a; q]
              | Discrete_Lyapunov_D_C (a, q) -> "Discrete_Lyapunov_D_C", [a; q]
              | Diag_D (_, a) -> "Diag_D", [a]
              | Diagm_D (_, a) -> "Diagm_D", [a]
              | Tril_D (_, a) -> "Tril_D", [a]
              | Triu_D (_, a) -> "Triu_D", [a]
              | Trace_D a -> "Trace_D", [a]
              | Add_Row_D_D (a, b, _i) -> "Add_Row_D_D", [a; b]
              | Add_Row_D_C (a, b, _i) -> "Add_Row_D_C", [a; b]
              | Add_Row_C_D (a, b, _i) -> "Add_Row_C_D", [a; b]
              | Get_Row_D (a, _i) -> "Get_Row_D", [a]
              | Of_Rows_D a -> "Of_Rows_D", (Array.to_list a)
              | Of_Arrays_D (a, idxs) ->
                "Of_Arrays_D", List.map (fun (i, j) -> a.(i).(j)) idxs
              | Conv1D_D_D (a, b, _s) -> "Conv1D_D_D", [a; b]
              | Conv1D_D_C (a, b, _s) -> "Conv1D_D_C", [a; b]
              | Conv1D_C_D (a, b, _s) -> "Conv1D_C_D", [a; b]
              | Conv2D_D_D (a, b, _s) -> "Conv2D_D_D", [a; b]
              | Conv2D_D_C (a, b, _s) -> "Conv2D_D_C", [a; b]
              | Conv2D_C_D (a, b, _s) -> "Conv2D_C_D", [a; b]
              | Conv3D_D_D (a, b, _s) -> "Conv3D_D_D", [a; b]
              | Conv3D_D_C (a, b, _s) -> "Conv3D_D_C", [a; b]
              | Conv3D_C_D (a, b, _s) -> "Conv3D_C_D", [a; b]
              | Di_Conv1D_D_D (a, b, _s, _r) -> "Di_Conv1D_D_D", [a; b]
              | Di_Conv1D_D_C (a, b, _s, _r) -> "Di_Conv1D_D_C", [a; b]
              | Di_Conv1D_C_D (a, b, _s, _r) -> "Di_Conv1D_C_D", [a; b]
              | Di_Conv2D_D_D (a, b, _s, _r) -> "Di_Conv2D_D_D", [a; b]
              | Di_Conv2D_D_C (a, b, _s, _r) -> "Di_Conv2D_D_C", [a; b]
              | Di_Conv2D_C_D (a, b, _s, _r) -> "Di_Conv2D_C_D", [a; b]
              | Di_Conv3D_D_D (a, b, _s, _r) -> "Di_Conv3D_D_D", [a; b]
              | Di_Conv3D_D_C (a, b, _s, _r) -> "Di_Conv3D_D_C", [a; b]
              | Di_Conv3D_C_D (a, b, _s, _r) -> "Di_Conv3D_C_D", [a; b]
              | Tr_Conv1D_D_D (a, b, _s) -> "Tr_Conv1D_D_D", [a; b]
              | Tr_Conv1D_D_C (a, b, _s) -> "Tr_Conv1D_D_C", [a; b]
              | Tr_Conv1D_C_D (a, b, _s) -> "Tr_Conv1D_C_D", [a; b]
              | Tr_Conv2D_D_D (a, b, _s) -> "Tr_Conv2D_D_D", [a; b]
              | Tr_Conv2D_D_C (a, b, _s) -> "Tr_Conv2D_D_C", [a; b]
              | Tr_Conv2D_C_D (a, b, _s) -> "Tr_Conv2D_C_D", [a; b]
              | Tr_Conv3D_D_D (a, b, _s) -> "Tr_Conv3D_D_D", [a; b]
              | Tr_Conv3D_D_C (a, b, _s) -> "Tr_Conv3D_D_C", [a; b]
              | Tr_Conv3D_C_D (a, b, _s) -> "Tr_Conv3D_C_D", [a; b]
              | Reshape_D a -> "Reshape_D", [a]
              | Maxpool1D_D (a, _p, _d, _s) -> "Maxpool1D_D", [a]
              | Maxpool2D_D (a, _p, _d, _s) -> "Maxpool2D_D", [a]
              | Maxpool3D_D (a, _p, _d, _s) -> "Maxpool3D_D", [a]
              | Avgpool1D_D (a, _p, _d, _s) -> "Avgpool1D_D", [a]
              | Avgpool2D_D (a, _p, _d, _s) -> "Avgpool2D_D", [a]
              | Avgpool3D_D (a, _p, _d, _s) -> "Avgpool3D_D", [a]
              | UpSampling2D_D (a, _s) -> "UpSampling2D_D", [a]
              | PAD_D (a, _p) -> "PAD_D", [a]
              | Concat_D_D (a, b, _i) -> "Concat_D_D", [a; b]
              | Concat_D_C (a, b, _i) -> "Concat_D_C", [a; b]
              | Concat_C_D (a, b, _i) -> "Concat_C_D", [a; b]
              | Split_D (a, _, _) -> "Split_D", [a]
              | Concatenate_D (a, _, idx) ->
                "Concatenate_D", List.(map (fun i -> a.(i)) idx)
              end
            | F _a -> "Const", []
            | Arr _a -> "Const", []
            | DF (_, _, _) -> "DF", []
          in
          (* record the node with a fresh index, then visit its parents
             before the rest of the pending list *)
          Hashtbl.add nodes hd (!index, op, prev);
          index := !index + 1;
          push (prev @ tl)
        end
        else push tl
    in
    (* iterate the graph then return the hash table *)
    push x;
    nodes

  (* convert graph to terminal output: one line per edge of the form
     "{ i:<id> o:<op> t:<type> } -> { i:<id> o:<op> t:<type> }".
     Output is accumulated in a Buffer to avoid quadratic string
     concatenation on large graphs. *)
  let _convert_terminal_output nodes =
    let buf = Buffer.create 1024 in
    Hashtbl.iter
      (fun v (v_id, v_op, v_prev) ->
        let v_ts = type_info v in
        List.iter
          (fun u ->
            let u_id, u_op, _ = Hashtbl.find nodes u in
            let u_ts = type_info u in
            Printf.bprintf buf "{ i:%i o:%s t:%s } -> { i:%i o:%s t:%s }\n"
              u_id u_op u_ts v_id v_op v_ts)
          v_prev)
      nodes;
    Buffer.contents buf

  (* convert graph to dot file output: first all edges, then one attribute
     line per node (constants are drawn filled gray). Built with a Buffer
     for linear-time output. *)
  let _convert_dot_output nodes =
    let buf = Buffer.create 1024 in
    Hashtbl.iter
      (fun _v (v_id, _v_op, v_prev) ->
        List.iter
          (fun u ->
            let u_id, _u_op, _ = Hashtbl.find nodes u in
            Printf.bprintf buf "\t%i -> %i;\n" u_id v_id)
          v_prev)
      nodes;
    Hashtbl.iter
      (fun v (v_id, v_op, _v_prev) ->
        if v_op = "Const" then
          Printf.bprintf buf
            "%i [ label=\"#%i | { %s | %s }\" fillcolor=gray, style=filled ];\n"
            v_id v_id v_op (deep_info v)
        else
          Printf.bprintf buf "%i [ label=\"#%i | { %s | %s }\" ];\n"
            v_id v_id v_op (deep_info v))
      nodes;
    Buffer.contents buf

  (* human-readable edge list of the graph reachable from [nodes] *)
  let to_trace nodes = _traverse_trace nodes |> _convert_terminal_output

  (* Graphviz "digraph" source for the graph reachable from [nodes] *)
  let to_dot nodes =
    _traverse_trace nodes
    |> _convert_dot_output
    |> Printf.sprintf "digraph CG {\nnode [shape=record];\n%s}"

  (* pretty-printer: prints [type_info x] *)
  let pp_num formatter x = Format.fprintf formatter "%s" (type_info x)

  (* Finite difference gradient test *)
  module FDGrad_test = struct
    (* one one-hot [|dim1; dim2|] direction per element (flat index j) *)
    let generate_directions (dim1, dim2) =
      let n_directions = dim1 * dim2 in
      Array.init n_directions (fun j ->
        Arr (A.init [|dim1; dim2|] (fun i ->
          if i = j then A.(float_to_elt 1.) else A.(float_to_elt 0.))))

    (* [n_samples] Gaussian matrices paired with the one-hot directions *)
    let generate_test_samples (dim1, dim2) n_samples =
      List.init n_samples (fun _ -> Mat.gaussian dim1 dim2),
      generate_directions (dim1, dim2)

    (* central difference (f (x + eps d) - f (x - eps d)) / (2 eps) *)
    let finite_difference_grad ~f ?(eps = 1E-5) x d =
      let dx = Maths.((F A.(float_to_elt eps)) * d) in
      Maths.(((f (x + dx)) - (f (x - dx))) / (F (A.(float_to_elt (2. *. eps)))))

    (* [check_grad ~threshold ?verbose ?eps ~f ~directions samples]:
       compares the reverse-mode directional derivative of sum (f x) with its
       finite-difference estimate along every direction, for every sample.
       A sample passes when its largest error, relative to the RMS of the
       finite-difference values, is below [threshold]. Returns
       (all samples passed, number of samples that passed). *)
    let check_grad ~threshold ?(verbose = false) ?(eps = 1E-5) ~f =
      (* (ad, fd) pairs -> (passes threshold?, max relative error) *)
      let compare_results rs =
        let n_d = Array.length rs in
        let r_fds = Array.map snd rs in
        let rms =
          (Array.fold_left (fun acc r_fd -> acc +. (r_fd *. r_fd)) 0. r_fds)
          /. (float n_d)
          |> sqrt
        in
        let max_err =
          rs
          |> Array.map (fun (r_ad, r_fd) ->
               abs_float (r_ad -. r_fd) /. (rms +. 1E-9))
          |> (Array.fold_left max (-1.))
        in
        max_err < threshold, max_err
      in
      let f x = Maths.(sum' (f x)) in
      let g = grad f in
      fun ~directions samples ->
        let rec __check acc = function
          | [] -> acc
          | hd :: tl ->
            let check, max_err =
              Array.map
                (fun d ->
                  let r_ad = Maths.(sum' ((g hd) * d)) |> unpack_flt in
                  let r_fd = (finite_difference_grad ~f ~eps hd d) |> unpack_flt in
                  r_ad, r_fd)
                directions
              |> compare_results
            in
            __check ((check, max_err) :: acc) tl
        in
        let n_samples = List.length samples in
        let check, max_err, n_passed =
          __check [] samples
          |> List.fold_left
               (fun (check_old, max_err_old, acc) (check, max_err) ->
                 ( check_old && check,
                   max max_err_old max_err,
                   (if check then succ acc else acc) ))
               (true, -1., 0)
        in
        if verbose then
          Printf.printf "adjoints passed: %i/%i | max_err: %f.\n%!"
            n_passed n_samples max_err;
        check, n_passed
  end

end (* ends here *)