# 1 "src/base/optimise/owl_algodiff_generic.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2018 Liang Wang <liang.wang@cl.cam.ac.uk>
*)openOwl_types(* Functor of making AD module of different precisions *)moduleMake(A:Owl_types_ndarray_algodiff.Sig)=structmoduleA=A(* type definitions *)typet=|FofA.elt|ArrofA.arr|DFoft*t*int(* primal, tangent, tag *)|DRoft*tref*trace_op*intref*int(* primal, adjoint, op, fanout, tag *)andtrace_op=|Noop|Add_D_Doft*t|Add_D_Coft*t|Add_C_Doft*t|Sub_D_Doft*t|Sub_D_Coft*t|Sub_C_Doft*t|Mul_D_Doft*t|Mul_D_Coft*t|Mul_C_Doft*t|Div_D_Doft*t|Div_D_Coft*t|Div_C_Doft*t|Pow_D_Doft*t|Pow_D_Coft*t|Pow_C_Doft*t|Atan2_D_Doft*t|Atan2_D_Coft*t|Atan2_C_Doft*t|Neg_Doft|Abs_Doft|Signum_Doft|Floor_Doft|Ceil_Doft|Round_Doft|Sqr_Doft|Sqrt_Doft|Log_Doft|Log2_Doft|Log10_Doft|Exp_Doft|Sin_Doft|Cos_Doft|Tan_Doft|Sinh_Doft|Cosh_Doft|Tanh_Doft|Asin_Doft|Acos_Doft|Atan_Doft|Asinh_Doft|Acosh_Doft|Atanh_Doft|Get_Itemoft*int*int|SetI_D_Doft*int*int*t|SetI_D_Coft*int*int*t|SetI_C_Doft*int*int*t|AddI_D_Doft*int*int*t|AddI_D_Coft*int*int*t|AddI_C_Doft*int*int*t|Get_Slice_Doft*intlistlist|Set_Slice_D_Doft*t*intlistlist|Set_Slice_D_Coft*t*intlistlist|Set_Slice_C_Doft*t*intlistlist|Sum_Doft|Sum__Doft*int|Sum___Doft*intarray|Dot_D_Doft*t|Dot_D_Coft*t|Dot_C_Doft*t|Trans_Doft|L1Norm_Doft|L2Norm_Doft|L2NormS_Doft|Sigmoid_Doft|Relu_Doft|Inv_Doft|Add_Row_D_Doft*t*int|Add_Row_D_Coft*t*int|Add_Row_C_Doft*t*int|Get_Row_Doft*int|Of_Rows_Doftarray|Concat_D_Doft*t*int|Concat_D_Coft*t*int|Concat_C_Doft*t*int|Conv1D_D_Doft*t*intarray|Conv1D_D_Coft*t*intarray|Conv1D_C_Doft*t*intarray|Conv2D_D_Doft*t*intarray|Conv2D_D_Coft*t*intarray|Conv2D_C_Doft*t*intarray|Conv3D_D_Doft*t*intarray|Conv3D_D_Coft*t*intarray|Conv3D_C_Doft*t*intarray|Di_Conv1D_D_Doft*t*intarray*intarray|Di_Conv1D_D_Coft*t*intarray*intarray|Di_Conv1D_C_Doft*t*intarray*intarray|Di_Conv2D_D_Doft*t*intarray*intarray|Di_Conv2D_D_Coft*t*intarray*intarray|Di_Conv2D_C_Doft*t*intarray*intarray|Di_Conv3D_D_Doft*t*intarray*intarray|Di_Conv3D_D_Coft*t*intarray*intarray|Di_Conv3D_C_Doft*t*intarray*intarray|Tr_Conv1D_D_Doft*t*intarray|Tr_Conv1D_D_Coft*t*intarray|Tr_Conv1D
_C_Doft*t*intarray|Tr_Conv2D_D_Doft*t*intarray|Tr_Conv2D_D_Coft*t*intarray|Tr_Conv2D_C_Doft*t*intarray|Tr_Conv3D_D_Doft*t*intarray|Tr_Conv3D_D_Coft*t*intarray|Tr_Conv3D_C_Doft*t*intarray|Reshape_Doft|Maxpool1D_Doft*padding*intarray*intarray|Maxpool2D_Doft*padding*intarray*intarray|Maxpool3D_Doft*padding*intarray*intarray|Avgpool1D_Doft*padding*intarray*intarray|Avgpool2D_Doft*padding*intarray*intarray|Avgpool3D_Doft*padding*intarray*intarray(* generate global tags *)let_global_tag=ref0lettag()=_global_tag:=!_global_tag+1;!_global_tag(* hepler functions of the core AD component *)letcmp_tagaibi=ifai>bithen1elseifai<bithen-1else0letreset_zero=function|F_->FA.(float_to_elt0.)|Arrap->A.resetap;Arrap|_->failwith"error: reset_zero"letprimal=function|DF(ap,_,_)->ap|DR(ap,_,_,_,_)->ap|ap->apletrecprimal'=function|DF(ap,_,_)->primal'ap|DR(ap,_,_,_,_)->primal'ap|ap->apletreczero=function|F_->FA.(float_to_elt0.)|Arrap->ArrA.(zeros(shapeap))|DF(ap,_,_)->ap|>primal'|>zero|DR(ap,_,_,_,_)->ap|>primal'|>zerolettangent=function|DF(_,at,_)->at|DR_->failwith"error: no tangent for DR"|ap->zeroapletadjref=function|DF_->failwith"error: no adjref for DF"|DR(_,at,_,_,_)->at|ap->ref(zeroap)letadjval=function|DF_->failwith"error: no adjval for DF"|DR(_,at,_,_,_)->!at|ap->zeroapletshapex=match(primal'x)with|Arrap->A.shapeap|_->failwith"error: AD.shape"letrow_numx=(shapex).(0)letcol_numx=(shapex).(1)letnumelx=matchprimal'xwith|Arrx->A.numelx|_->failwith"error: AD.numel"letclip_by_value~amin~amaxx=matchprimal'xwith|Arrx->ArrA.(clip_by_value~amin~amaxx)|_->failwith"error: AD.clip_by_value"letclip_by_l2normax=matchprimal'xwith|Arrx->ArrA.(clip_by_l2normax)|_->failwith"error: AD.clip_by_l2norm"letcopy_primal'x=match(primal'x)with|Arrap->ArrA.(copyap)|_->failwith"error: AD.copy"lettilexreps=matchprimal'xwith|Arrx->ArrA.(tilexreps)|_->failwith"error: AD.tile"letrepeatxreps=matchprimal'xwith|Arrx->ArrA.(repeatxreps)|_->failwith"error: AD.repeat"(* packing and unpacking functions 
*)letpack_eltx=Fxletunpack_eltx=match(primalx)with|Fx->x|_->failwith"error: AD.unpack_elt"letpack_fltx=FA.(float_to_eltx)let_fx=FA.(float_to_eltx)(* shorcut for type conversion *)letunpack_fltx=match(primalx)with|Fx->A.elt_to_floatx|_->failwith"error: AD.unpack_flt"letpack_arrx=Arrxletunpack_arrx=match(primalx)with|Arrx->x|_->failwith"error: AD.unpack_arr"(* functions to report errors, help in debugging *)letdeep_infox=matchprimal'xwith|Fa->Printf.sprintf"F(%g)"A.(elt_to_floata)|Arra->Printf.sprintf"Arr(%s)"(A.shapea|>Owl_utils_array.to_stringstring_of_int)|_->"you should not have reached here!"lettype_infox=matchxwith|F_a->Printf.sprintf"[%s]"(deep_infox)|DF(ap,_at,ai)->Printf.sprintf"[DF tag:%i ap:%s]"ai(deep_infoap)|DR(ap,_at,_ao,_af,ai)->Printf.sprintf"[DR tag:%i ap:%s]"ai(deep_infoap)|_->Printf.sprintf"[%s]"(deep_infox)leterror_binopopab=lets0="#0:"^(type_infoa)inlets1="#1:"^(type_infob)infailwith(op^" : "^s0^", "^s1)leterror_uniopopa=lets=type_infoainfailwith(op^" : "^s)(* overload operators *)moduleMaths=structletrecnoop_=()andop_d_dafffddfr=matchawith|DF(ap,at,ai)->letcp=fdapinDF(cp,(dfcpapat),ai)|DR(ap,_,_,_,ai)->letcp=fdapinDR(cp,ref(zerocp),ra,ref0,ai)|ap->ffapandop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d=matcha,bwith|F_ap,DF(bp,bt,bi)->letcp=fdabpinDF(cp,(df_dbcpbpbt),bi)|DF(ap,at,ai),F_bp->letcp=fdapbinDF(cp,(df_dacpapat),ai)|Arr_ap,DF(bp,bt,bi)->letcp=fdabpinDF(cp,(df_dbcpbpbt),bi)|DF(ap,at,ai),Arr_bp->letcp=fdapbinDF(cp,(df_dacpapat),ai)|F_ap,DR(bp,_,_,_,bi)->letcp=fdabpinDR(cp,ref(zerocp),r_c_dab,ref0,bi)|DR(ap,_,_,_,ai),F_bp->letcp=fdapbinDR(cp,ref(zerocp),r_d_cab,ref0,ai)|Arr_ap,DR(bp,_,_,_,bi)->letcp=fdabpinDR(cp,ref(zerocp),r_c_dab,ref0,bi)|DR(ap,_,_,_,ai),Arr_bp->letcp=fdapbinDR(cp,ref(zerocp),r_d_cab,ref0,ai)|DF(ap,at,ai),DR(bp,_,_,_,bi)->(matchcmp_tagaibiwith|1->letcp=fdapbinDF(cp,df_dacpapat,ai)|-1->letcp=fdabpinDR(cp,ref(zerocp),r_c_dab,ref0,bi)|_->failwith"error: forward and reverse clash at the same 
level")|DR(ap,_,_,_,ai),DF(bp,bt,bi)->(matchcmp_tagaibiwith|-1->letcp=fdabpinDF(cp,df_dbcpbpbt,bi)|1->letcp=fdapbinDR(cp,ref(zerocp),r_d_cab,ref0,ai)|_->failwith"error: forward and reverse clash at the same level")|DF(ap,at,ai),DF(bp,bt,bi)->(matchcmp_tagaibiwith|0->letcp=fdapbpinDF(cp,(df_dabcpapatbpbt),ai)|1->letcp=fdapbinDF(cp,(df_dacpapat),ai)|_->letcp=fdabpinDF(cp,(df_dbcpbpbt),bi))|DR(ap,_,_,_,ai),DR(bp,_,_,_,bi)->(matchcmp_tagaibiwith|0->letcp=fdapbpinDR(cp,ref(zerocp),r_d_dab,ref0,ai)|1->letcp=fdapbinDR(cp,ref(zerocp),r_d_cab,ref0,ai)|_->letcp=fdabpinDR(cp,ref(zerocp),r_c_dab,ref0,bi))|a,b->ffaband(+)ab=addabandaddab=letffab=matcha,bwith|Fa,Fb->FA.Scalar.(addab)|Fa,Arrb->ArrA.(scalar_addab)|Arra,Fb->ArrA.(add_scalarab)|Arra,Arrb->ArrA.(addab)|_->error_binop"( + )"abinletfdab=a+binletdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Add_D_D(a,b)inletr_d_cab=Add_D_C(a,b)inletr_c_dab=Add_C_D(a,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dand(-)ab=subabandsubab=letffab=matcha,bwith|Fa,Fb->FA.Scalar.(subab)|Fa,Arrb->ArrA.(scalar_subab)|Arra,Fb->ArrA.(sub_scalarab)|Arra,Arrb->ArrA.(subab)|_->error_binop"( - )"abinletfdab=a-binletdf_da_cp_apat=atinletdf_db_cp_bpbt=negbtinletdf_dab_cp_apat_bpbt=at-btinletr_d_dab=Sub_D_D(a,b)inletr_d_cab=Sub_D_C(a,b)inletr_c_dab=Sub_C_D(a,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dand(*)ab=mulabandmulab=letffab=matcha,bwith|Fa,Fb->FA.Scalar.(mulab)|Fa,Arrb->ArrA.(scalar_mulab)|Arra,Fb->ArrA.(mul_scalarab)|Arra,Arrb->ArrA.(mulab)|_->error_binop"( * )"abinletfdab=a*binletdf_da_cp_apat=at*binletdf_db_cp_bpbt=a*btinletdf_dab_cpapatbpbt=(ap*bt)+(at*bp)inletr_d_dab=Mul_D_D(a,b)inletr_d_cab=Mul_D_C(a,b)inletr_c_dab=Mul_C_D(a,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dand(/)ab=divabanddivab=letffab=matcha,bwith|Fa,Fb->FA.Scalar.(divab)|Fa,Arrb->ArrA.(scalar_divab)|Arra,Fb->ArrA.(div_scalarab)|Arra,Arrb->ArrA.(divab)|_->error_binop"( / 
)"abinletfdab=a/binletdf_da_cp_apat=at/binletdf_dbcpbpbt=(negbt)*cp/bpinletdf_dabcp_apatbpbt=(at-bt*cp)/bpinletr_d_dab=Div_D_D(a,b)inletr_d_cab=Div_D_C(a,b)inletr_c_dab=Div_C_D(a,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dand(**)ab=powabandpowab=letffab=matcha,bwith|Fa,Fb->FA.Scalar.(powab)|Fa,Arrb->ArrA.(scalar_powab)|Arra,Fb->ArrA.(pow_scalarab)|Arra,Arrb->ArrA.(powab)|_->error_binop"( ** )"abinletfdab=a**binletdf_da_cpapat=at*(ap**(b-(pack_flt1.)))*binletdf_dbcp_bpbt=bt*cp*(loga)inletdf_dab_cpapatbpbt=(ap**(bp-(pack_flt1.)))*((at*bp)+(ap*bt*logap))inletr_d_dab=Pow_D_D(a,b)inletr_d_cab=Pow_D_C(a,b)inletr_c_dab=Pow_C_D(a,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dandatan2ab=letffab=matcha,bwith|Fa,Fb->FA.Scalar.(atan2ab)|Fa,Arrb->ArrA.(scalar_atan2ab)|Arra,Fb->ArrA.(atan2_scalarab)|Arra,Arrb->ArrA.(atan2ab)|_->error_binop"atan2"abinletfdab=atan2abinletdf_da_cpapat=at*b/((sqrap)+(sqrb))inletdf_db_cpbpbt=(negbt)*a/((sqra)+(sqrbp))inletdf_dab_cpapatbpbt=((at*bp)-(bt*ap))/((sqrap)+(sqrbp))inletr_d_dab=Atan2_D_D(a,b)inletr_d_cab=Atan2_D_C(a,b)inletr_c_dab=Atan2_C_D(a,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dandmin2ab=((a+b)-abs(a-b))/(pack_flt2.)andmax2ab=((a+b)+abs(b-a))/(pack_flt2.)andnega=letff=function|Fa->FA.Scalar.(nega)|Arra->ArrA.(nega)|_->error_uniop"neg"ainletfda=negainletdf_cp_apat=(pack_flt0.)-atinletra=Neg_Dainop_d_dafffddfrandabsa=letff=function|Fa->FA.Scalar.(absa)|Arra->ArrA.(absa)|_->error_uniop"abs"ainletfda=absainletdf_cpapat=at*(signumap)inletra=Abs_Dainop_d_dafffddfrandsignuma=letff=function|Fa->FA.Scalar.(signuma)|Arra->ArrA.(signuma)|_->error_uniop"signum"ainletfda=signumainletdf_cpap_at=zeroapinletra=Signum_Dainop_d_dafffddfrandfloora=letff=function|Fa->FA.Scalar.(floora)|Arra->ArrA.(floora)|_->error_uniop"floor"ainletfda=floorainletdf_cpap_at=zeroapinletra=Floor_Dainop_d_dafffddfrandceila=letff=function|Fa->FA.Scalar.(ceila)|Arra->ArrA.(ceila)|_->error_uniop"ceil"ainletfda=ceilainletdf_cpap_at=zeroapinletra=Ceil_Dainop_d_daff
fddfrandrounda=letff=function|Fa->FA.Scalar.(rounda)|Arra->ArrA.(rounda)|_->error_uniop"round"ainletfda=roundainletdf_cpap_at=zeroapinletra=Round_Dainop_d_dafffddfrandsqra=letff=function|Fa->FA.Scalar.(sqra)|Arra->ArrA.(sqra)|_->error_uniop"sqr"ainletfda=sqrainletdf_cpapat=(pack_flt2.)*at*apinletra=Sqr_Dainop_d_dafffddfrandsqrta=letff=function|Fa->FA.Scalar.(sqrta)|Arra->ArrA.(sqrta)|_->error_uniop"sqrt"ainletfda=sqrtainletdfcp_apat=at/((pack_flt2.)*cp)inletra=Sqrt_Dainop_d_dafffddfrandloga=letff=function|Fa->FA.Scalar.(loga)|Arra->ArrA.(loga)|_->error_uniop"log"ainletfda=logainletdf_cpapat=at/apinletra=Log_Dainop_d_dafffddfrandlog2a=letff=function|Fa->FA.Scalar.(log2a)|Arra->ArrA.(log2a)|_->error_uniop"log2"ainletfda=log2ainletdf_cpapat=at/(ap*(pack_fltOwl_const.log2e))inletra=Log2_Dainop_d_dafffddfrandlog10a=letff=function|Fa->FA.Scalar.(log10a)|Arra->ArrA.(log10a)|_->error_uniop"log10"ainletfda=log10ainletdf_cpapat=at/(ap*(pack_fltOwl_const.log10e))inletra=Log10_Dainop_d_dafffddfrandexpa=letff=function|Fa->FA.Scalar.(expa)|Arra->ArrA.(expa)|_->error_uniop"exp"ainletfda=expainletdfcp_apat=at*cpinletra=Exp_Dainop_d_dafffddfrandsina=letff=function|Fa->FA.Scalar.(sina)|Arra->ArrA.(sina)|_->error_uniop"sin"ainletfda=sinainletdf_cpapat=at*cosapinletra=Sin_Dainop_d_dafffddfrandcosa=letff=function|Fa->FA.Scalar.(cosa)|Arra->ArrA.(cosa)|_->error_uniop"cos"ainletfda=cosainletdf_cpapat=neg(at*sinap)inletra=Cos_Dainop_d_dafffddfrandtana=letff=function|Fa->FA.Scalar.(tana)|Arra->ArrA.(tana)|_->error_uniop"tan"ainletfda=tanainletdf_cpapat=at/(sqr(cosap))inletra=Tan_Dainop_d_dafffddfrandsinha=letff=function|Fa->FA.Scalar.(sinha)|Arra->ArrA.(sinha)|_->error_uniop"sinh"ainletfda=sinhainletdf_cpapat=at*(coshap)inletra=Sinh_Dainop_d_dafffddfrandcosha=letff=function|Fa->FA.Scalar.(cosha)|Arra->ArrA.(cosha)|_->error_uniop"cosh"ainletfda=coshainletdf_cpapat=at*(sinhap)inletra=Cosh_Dainop_d_dafffddfrandtanha=letff=function|Fa->FA.Scalar.(tanha)|Arra->ArrA.(tanha)|_->error_uniop"tanh"ai
nletfda=tanhainletdf_cpapat=at/(sqr(coshap))inletra=Tanh_Dainop_d_dafffddfrandasina=letff=function|Fa->FA.Scalar.(asina)|Arra->ArrA.(asina)|_->error_uniop"asin"ainletfda=asinainletdf_cpapat=at/sqrt((pack_flt1.)-sqrap)inletra=Asin_Dainop_d_dafffddfrandacosa=letff=function|Fa->FA.Scalar.(acosa)|Arra->ArrA.(acosa)|_->error_uniop"acos"ainletfda=acosainletdf_cpapat=(negat)/sqrt((pack_flt1.)-sqrap)inletra=Acos_Dainop_d_dafffddfrandatana=letff=function|Fa->FA.Scalar.(atana)|Arra->ArrA.(atana)|_->error_uniop"atan"ainletfda=atanainletdf_cpapat=at/((pack_flt1.)+sqrap)inletra=Atan_Dainop_d_dafffddfrandasinha=letff=function|Fa->FA.Scalar.(asinha)|Arra->ArrA.(asinha)|_->error_uniop"asinh"ainletfda=asinhainletdf_cpapat=at/sqrt((sqrap)+(pack_flt1.))inletra=Asinh_Dainop_d_dafffddfrandacosha=letff=function|Fa->FA.Scalar.(acosha)|Arra->ArrA.(acosha)|_->error_uniop"acosh"ainletfda=acoshainletdf_cpapat=at/sqrt((sqrap)-(pack_flt1.))inletra=Acosh_Dainop_d_dafffddfrandatanha=letff=function|Fa->FA.Scalar.(atanha)|Arra->ArrA.(atanha)|_->error_uniop"atanh"ainletfda=atanhainletdf_cpapat=at/((pack_flt1.)-sqrap)inletra=Atanh_Dainop_d_dafffddfrandget_itemaij=matchawith|Arrap->F(A.getap[|i;j|])|DF(ap,at,ai)->DF(get_itemapij,get_itematij,ai)|DR(ap,_,_,_,ai)->DR(get_itemapij,ref(pack_flt0.),Get_Item(a,i,j),ref0,ai)|_->error_uniop"get_item"aandset_itemaijb=letffab=matcha,bwith|Arra,Fb->letaa=A.copyainA.setaa[|i;j|]b;Arraa|_->error_uniop"set_item"ainletfdab=set_itemaijbinletdf_da_cp_apat=set_itematij(pack_flt0.)inletdf_db_cp_bpbt=add_item(zeroa)ijbtinletdf_dab_cp_apat_bpbt=set_itematijbtinletr_d_dab=SetI_D_D(a,i,j,b)inletr_d_cab=SetI_D_C(a,i,j,b)inletr_c_dab=SetI_C_D(a,i,j,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dandadd_itemaijb=letffab=matcha,bwith|Arra,Fb->letaa=A.copyainA.setaa[|i;j|]A.Scalar.(add(A.getaa[|i;j|])b);Arraa|_->error_binop"add_item"abinletfdab=add_itemaijbinletdf_da_cp_apat=atinletdf_db_cp_bpbt=add_item(zeroa)ijbtinletdf_dab_cp_apat_bpbt=add_itematijbtinletr_d_dab=AddI_D_D(a,i
,j,b)inletr_d_cab=AddI_D_C(a,i,j,b)inletr_c_dab=AddI_C_D(a,i,j,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dandget_sliceia=letff=function|Arra->ArrA.(get_sliceia)|_->error_uniop"slice"ainletfda=get_sliceiainletdf_cp_apat=get_sliceiatinletra=Get_Slice_D(a,i)inop_d_dafffddfrandset_sliceiab=letffab=matcha,bwith|Arra,Arrb->leta=A.copyainA.(set_sliceiab);Arra|_->error_binop"set_slice"abinletfdab=set_sliceiabinletdf_da_cp_apat=set_sliceiat(zerob)inletdf_db_cp_bpbt=set_slicei(zeroa)btinletdf_dab_cp_apat_bpbt=set_sliceiatbtinletr_d_dab=Set_Slice_D_D(a,b,i)inletr_d_cab=Set_Slice_D_C(a,b,i)inletr_c_dab=Set_Slice_C_D(a,b,i)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dandsum'a=letff=function|Fa->Fa|Arra->FA.(sum'a)|_->error_uniop"sum"ainletfda=sum'ainletdf_cp_apat=sum'atinletra=Sum_Dainop_d_dafffddfrandsum?(axis=(-1))a=letff=function|Fa->Fa|Arra->ArrA.(sum~axisa)|_->error_uniop"sum"ainletfda=sum~axisainletdf_cp_apat=sum~axisatinletra=Sum__D(a,axis)inop_d_dafffddfrandsum_reduce?(axis=[|0|])a=letff=function|Fa->Fa|Arrx->ArrA.(sum_reduce~axisx)|_->error_uniop"sum_reduce"ainletfda=sum_reduce~axisainletdf_cp_apat=sum_reduce~axisatinletra=Sum___D(a,axis)inop_d_dafffddfrandmeana=(sum'a)/F(numela|>float_of_int|>A.float_to_elt)and(*@)ab=dotabanddotab=letffab=matcha,bwith|Arra,Arrb->ArrA.(dotab)|_->error_binop"( *@ 
)"abinletfdab=a*@binletdf_da_cp_apat=at*@binletdf_db_cp_bpbt=a*@btinletdf_dab_cpapatbpbt=(ap*@bt)+(at*@bp)inletr_d_dab=Dot_D_D(a,b)inletr_d_cab=Dot_D_C(a,b)inletr_c_dab=Dot_C_D(a,b)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dandtransposea=letff=function|Arra->ArrA.(transposea)|_->error_uniop"transpose"ainletfda=transposeainletdf_cp_apat=transposeatinletra=Trans_Dainop_d_dafffddfrandl1norm'a=letff=function|Arra->FA.(l1norm'a)|_->error_uniop"l1norm'"ainletfda=l1norm'ainletdf_cpapat=at*(signumap)inletra=L1Norm_Dainop_d_dafffddfrandl2norm'a=letff=function|Arra->FA.(l2norm'a)|_->error_uniop"l2norm'"ainletfda=l2norm'ainletdfcpapat=(ap*at)/cpinletra=L2Norm_Dainop_d_dafffddfrandl2norm_sqr'a=letff=function|Fa->FA.Scalar.(sqra)|Arra->FA.(l2norm_sqr'a)|_->error_uniop"l2norm_sqr'"ainletfda=l2norm_sqr'ainletdf_cpapat=(pack_flt2.)*(ap*at)inletra=L2NormS_Dainop_d_dafffddfrandsigmoida=letff=function|Fa->FA.Scalar.(sigmoida)|Arra->ArrA.(sigmoida)|_->error_uniop"sigmoid"ainletfda=sigmoidainletdfcp_apat=at*cp*((pack_flt1.)-cp)inletra=Sigmoid_Dainop_d_dafffddfrandrelua=letff=function|Fa->FA.Scalar.(relua)|Arra->ArrA.(relua)|_->error_uniop"relu"ainletfda=reluainletdf_cpapat=at*((pack_flt1.)+(signumap))/(pack_flt2.)inletra=Relu_Dainop_d_dafffddfrandinva=letff=function|Arra->ArrA.(inva)|_->error_uniop"inv"ainletfda=invainletdfcp_apat=(negcp)*at*cpinletra=Inv_Dainop_d_dafffddfrandsoftplusx=log((pack_flt1.)+expx)andsoftsignx=x/((pack_flt1.)+absx)andsoftmax?(axis=(-1))x=letc=ArrA.(max~axis(unpack_arrx))inlety=exp(x-c)inleta=sum~axisyiny/aandcross_entropyxy=x*logy|>sum'|>negandadd_rowabi=letffab=matcha,bwith|Arra,Arrb->A.(copy_row_to(add(rowai)b)ai;Arra)|_->error_binop"add_row"abinletfdab=add_rowabiinletdf_da_cp_apat=atinletdf_db_cp_bpbt=add_row(zeroa)btiinletdf_dab_cp_apat_bpbt=add_rowatbtiinletr_d_dab=Add_Row_D_D(a,b,i)inletr_d_cab=Add_Row_D_C(a,b,i)inletr_c_dab=Add_Row_C_D(a,b,i)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dandget_rowai=letff=function|Arra->ArrA.(rowai|>copy)|_->error_
uniop"get_row"ainletfda=get_rowaiinletdf_cp_apat=get_rowatiinletra=Get_Row_D(a,i)inop_d_dafffddfrandto_rowsa=Array.init(row_numa)(funi->get_rowai)andof_rowsa=(* TODO: this can be further optimised by incorporating t array type as t *)matcha.(0)with|Arr_->Array.mapunpack_arra|>A.of_rows|>pack_arr|DF(_,_,ai)->letap=a|>Array.map(funx->x|>primal|>unpack_arr)|>A.of_rows|>pack_arrinletat=a|>Array.map(funx->x|>adjval|>unpack_arr)|>A.of_rows|>pack_arrinDF(ap,at,ai)|DR(_,_,_,_,ai)->letap=a|>Array.map(funx->x|>primal)inletcp=ap|>Array.map(funx->x|>unpack_arr)|>A.of_rows|>pack_arrinDR(cp,ref(zerocp),Of_Rows_Da,ref0,ai)|_->error_uniop"of_rows a.(0)"a.(0)(* NOTE: these fucntions are for neural network. There are many restrictions
at the moment. E.g. they do not support higher-order derivatives, and some
do not support forward mode, so use them when you know what you are doing.
*)(* a:input; b:kernel; s:stride *)andconv1d?paddingabs=letffab=matcha,bwith|Arra,Arrb->ArrA.(conv1d?paddingabs)|_->error_binop"conv1d"abinletfdab=conv1d?paddingabsin(* FIXME: df_da, df_db, df_dab are not correct ... do not use *)letdf_da_cp_ap_at=failwith"conv1d:df_da"inletdf_db_cp_bp_bt=failwith"conv1d:df_db"inletdf_dab_cp_ap_at_bp_bt=failwith"conv1d:df_dab"inletr_d_dab=Conv1D_D_D(a,b,s)inletr_d_cab=Conv1D_D_C(a,b,s)inletr_c_dab=Conv1D_C_D(a,b,s)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; s:stride; o:output' *)andconv1d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv1d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andconv1d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv1d_backward_kernelabso|>pack_arr(* a:input; b:kernel; s:stride *)andconv2d?paddingabs=letffab=matcha,bwith|Arra,Arrb->ArrA.(conv2d?paddingabs)|_->error_binop"conv2d"abinletfdab=conv2d?paddingabsin(* FIXME: df_da, df_db, df_dab are not correct ... do not use *)letdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Conv2D_D_D(a,b,s)inletr_d_cab=Conv2D_D_C(a,b,s)inletr_c_dab=Conv2D_C_D(a,b,s)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; s:stride; o:output' *)andconv2d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv2d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andconv2d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv2d_backward_kernelabso|>pack_arr(* a:input; b:kernel; s:stride *)andconv3d?paddingabs=letffab=matcha,bwith|Arra,Arrb->ArrA.(conv3d?paddingabs)|_->error_binop"conv3d"abinletfdab=conv3d?paddingabsin(* FIXME: df_da, df_db, df_dab are not correct ... 
do not use *)letdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Conv3D_D_D(a,b,s)inletr_d_cab=Conv3D_D_C(a,b,s)inletr_c_dab=Conv3D_C_D(a,b,s)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; s:stride; o:output' *)andconv3d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv3d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andconv3d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv3d_backward_kernelabso|>pack_arr(* a:input; b:kernel; s:stride; r:rate *)anddilated_conv1d?paddingabsr=letffab=matcha,bwith|Arra,Arrb->ArrA.(dilated_conv1d?paddingabsr)|_->error_binop"dilated_conv1d"abinletfdab=dilated_conv1d?paddingabsrin(* FIXME: df_da, df_db, df_dab are not correct ... do not use *)letdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Di_Conv1D_D_D(a,b,s,r)inletr_d_cab=Di_Conv1D_D_C(a,b,s,r)inletr_c_dab=Di_Conv1D_C_D(a,b,s,r)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv1d_backward_inputabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv1d_backward_inputabsro|>pack_arr(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv1d_backward_kernelabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv1d_backward_kernelabsro|>pack_arr(* a:input; b:kernel; s:stride; r:rate *)anddilated_conv2d?paddingabsr=letffab=matcha,bwith|Arra,Arrb->ArrA.(dilated_conv2d?paddingabsr)|_->error_binop"dilated_conv2d"abinletfdab=dilated_conv2d?paddingabsrin(* FIXME: df_da, df_db, df_dab are not correct ... 
do not use *)letdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Di_Conv2D_D_D(a,b,s,r)inletr_d_cab=Di_Conv2D_D_C(a,b,s,r)inletr_c_dab=Di_Conv2D_C_D(a,b,s,r)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv2d_backward_inputabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv2d_backward_inputabsro|>pack_arr(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv2d_backward_kernelabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv2d_backward_kernelabsro|>pack_arr(* a:input; b:kernel; s:stride; r:rate *)anddilated_conv3d?paddingabsr=letffab=matcha,bwith|Arra,Arrb->ArrA.(dilated_conv3d?paddingabsr)|_->error_binop"dilated_conv3d"abinletfdab=dilated_conv3d?paddingabsrin(* FIXME: df_da, df_db, df_dab are not correct ... do not use *)letdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Di_Conv3D_D_D(a,b,s,r)inletr_d_cab=Di_Conv3D_D_C(a,b,s,r)inletr_c_dab=Di_Conv3D_C_D(a,b,s,r)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv3d_backward_inputabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv3d_backward_inputabsro|>pack_arr(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv3d_backward_kernelabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv3d_backward_kernelabsro|>pack_arr(* a:input; b:kernel; s:stride *)andtranspose_conv1d?paddingabs=letffab=matcha,bwith|Arra,Arrb->ArrA.(transpose_conv1d?paddingabs)|_->error_binop"transpose_conv1d"abinletfdab=transpose_conv1d?paddingabsin(* FIXME: df_da, df_db, df_dab are not correct ... 
do not use *)letdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Tr_Conv1D_D_D(a,b,s)inletr_d_cab=Tr_Conv1D_D_C(a,b,s)inletr_c_dab=Tr_Conv1D_C_D(a,b,s)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv1d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv1d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv1d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv1d_backward_kernelabso|>pack_arr(* a:input; b:kernel; s:stride *)andtranspose_conv2d?paddingabs=letffab=matcha,bwith|Arra,Arrb->ArrA.(transpose_conv2d?paddingabs)|_->error_binop"transpose_conv2d"abinletfdab=transpose_conv2d?paddingabsin(* FIXME: df_da, df_db, df_dab are not correct ... do not use *)letdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Tr_Conv2D_D_D(a,b,s)inletr_d_cab=Tr_Conv2D_D_C(a,b,s)inletr_c_dab=Tr_Conv2D_C_D(a,b,s)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv2d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv2d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv2d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv2d_backward_kernelabso|>pack_arr(* a:input; b:kernel; s:stride *)andtranspose_conv3d?paddingabs=letffab=matcha,bwith|Arra,Arrb->ArrA.(transpose_conv3d?paddingabs)|_->error_binop"transpose_conv3d"abinletfdab=transpose_conv3d?paddingabsin(* FIXME: df_da, df_db, df_dab are not correct ... 
do not use *)letdf_da_cp_apat=atinletdf_db_cp_bpbt=btinletdf_dab_cp_apat_bpbt=at+btinletr_d_dab=Tr_Conv3D_D_D(a,b,s)inletr_d_cab=Tr_Conv3D_D_C(a,b,s)inletr_c_dab=Tr_Conv3D_C_D(a,b,s)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_d(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv3d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv3d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv3d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv3d_backward_kernelabso|>pack_arrandreshapeas=letff=function|Arra->ArrA.(reshapeas)|_->error_uniop"reshape"ainletfda=reshapeasinletdf_cp_apat=reshapeatsinletra=Reshape_Dainop_d_dafffddfrandflattena=reshapea[|1;numela|](* a:input; b:kernel; s:stride *)andmax_pool1dpaddingabs=letff=function|Arra->ArrA.(max_pool1d~paddingabs)|_->error_uniop"max_pool1d"ainletfda=max_pool1dpaddingabsinletdf_cp_ap_at=failwith"max_pool1d:df"inletra=Maxpool1D_D(a,padding,b,s)inop_d_dafffddfr(* a:input; p:padding type; b:kernel; s:stride; o:output' *)andmax_pool1d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.max_pool1d_backwardpabso|>pack_arr(* a:input; b:kernel; s:stride *)andmax_pool2dpaddingabs=letff=function|Arra->ArrA.(max_pool2d~paddingabs)|_->error_uniop"max_pool2d"ainletfda=max_pool2dpaddingabsinletdf_cp_ap_at=failwith"max_pool2d:df"inletra=Maxpool2D_D(a,padding,b,s)inop_d_dafffddfr(* a:input; p:padding type; b:kernel; s:stride; o:output' *)andmax_pool2d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.max_pool2d_backwardpabso|>pack_arr(* a:input; b:kernel; s:stride *)andmax_pool3dpaddingabs=letff=function|Arra->ArrA.(max_pool3d~paddingabs)|_->error_uniop"max_pool3d"ainletfda=max_pool3dpaddingabsinletdf_cp_ap_at=failwith"max_pool3d:df"inletra=Maxpool3D_D(a,padding,b,s)inop_d_dafffddfr(* a:input; p:padding type; b:kernel; s:stride; o:output' 
*)andmax_pool3d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.max_pool3d_backwardpabso|>pack_arr(* a:input; b:kernel; s:stride *)andavg_pool1dpaddingabs=letff=function|Arra->ArrA.(avg_pool1d~paddingabs)|_->error_uniop"avg_pool1d"ainletfda=avg_pool1dpaddingabsinletdf_cp_ap_at=failwith"avg_pool1d:df"inletra=Avgpool1D_D(a,padding,b,s)inop_d_dafffddfr(* a:input; p:padding type; b:kernel; s:stride; o:output' *)andavg_pool1d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.avg_pool1d_backwardpabso|>pack_arr(* a:input; b:kernel; s:stride *)andavg_pool2dpaddingabs=letff=function|Arra->ArrA.(avg_pool2d~paddingabs)|_->error_uniop"avg_pool2d"ainletfda=avg_pool2dpaddingabsinletdf_cp_ap_at=failwith"avg_pool2d:df"inletra=Avgpool2D_D(a,padding,b,s)inop_d_dafffddfr(* a:input; p:padding type; b:kernel; s:stride; o:output' *)andavg_pool2d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.avg_pool2d_backwardpabso|>pack_arr(* a:input; b:kernel; s:stride *)andavg_pool3dpaddingabs=letff=function|Arra->ArrA.(avg_pool3d~paddingabs)|_->error_uniop"avg_pool3d"ainletfda=avg_pool3dpaddingabsinletdf_cp_ap_at=failwith"avg_pool3d:df"inletra=Avgpool3D_D(a,padding,b,s)inop_d_dafffddfr(* a:input; p:padding type; b:kernel; s:stride; o:output' *)andavg_pool3d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.avg_pool3d_backwardpabso|>pack_arranddropout?(rate=0.5)a=letp=A.float_to_elt(1.-.rate)inletb=match(primal'a)with|Arra->Arr(A.bernoulli~p(A.shapea))|_->error_uniop"dropout"aina*bandconcataxisab=letffab=matcha,bwith|Arra,Arrb->ArrA.(concatenate~axis[|a;b|])|_->error_binop"concat"abinletfdab=concataxisabinletdf_da_cp_apat=concataxisat(zerob)inletdf_db_cp_bpbt=concataxis(zeroa)btinletdf_dab_cp_apat_bpbt=concataxisatbtinletr_d_dab=Concat_D_D(a,b,axis)inletr_d_cab=Concat_D_C(a,b,axis)inletr_c_dab=Concat_C_D(a,b,axis)inop_d_d_dabfffddf_dadf_dbdf_dabr_d_dr_d_cr_c_dandsplitaxispartsa=letffa=matchawith|Arra->A.(split~axispartsa)|>Array.map(funx->Arrx)|_->error_uniop"split"ainffa(* TODO: trace 
and diag functions ... *)
end

(* core of the reverse mode *)

(* [reverse_reset x] prepares the graph rooted at [x] for a reverse sweep.
   Each time a [DR] node is reached its adjoint is replaced by [reset_zero]
   of its current value and its fan-out counter is incremented. The operands
   of a node are only descended into when the counter becomes 1, i.e. the
   first time the node is reached. Values that are not [DR] are skipped. *)
let reverse_reset x =
  (* operands of an operation that take part in the graph, listed in the
     order in which they are visited *)
  let operands = function
    | Noop -> []
    (* both operands tracked *)
    | Add_D_D (a, b)
    | Sub_D_D (a, b)
    | Mul_D_D (a, b)
    | Div_D_D (a, b)
    | Pow_D_D (a, b)
    | Atan2_D_D (a, b)
    | Dot_D_D (a, b)
    | SetI_D_D (a, _, _, b)
    | AddI_D_D (a, _, _, b)
    | Set_Slice_D_D (a, b, _)
    | Add_Row_D_D (a, b, _)
    | Conv1D_D_D (a, b, _)
    | Conv2D_D_D (a, b, _)
    | Conv3D_D_D (a, b, _)
    | Di_Conv1D_D_D (a, b, _, _)
    | Di_Conv2D_D_D (a, b, _, _)
    | Di_Conv3D_D_D (a, b, _, _)
    | Tr_Conv1D_D_D (a, b, _)
    | Tr_Conv2D_D_D (a, b, _)
    | Tr_Conv3D_D_D (a, b, _)
    | Concat_D_D (a, b, _) -> [ a; b ]
    (* only the first operand tracked *)
    | Add_D_C (a, _)
    | Sub_D_C (a, _)
    | Mul_D_C (a, _)
    | Div_D_C (a, _)
    | Pow_D_C (a, _)
    | Atan2_D_C (a, _)
    | Dot_D_C (a, _)
    | SetI_D_C (a, _, _, _)
    | AddI_D_C (a, _, _, _)
    | Set_Slice_D_C (a, _, _)
    | Add_Row_D_C (a, _, _)
    | Conv1D_D_C (a, _, _)
    | Conv2D_D_C (a, _, _)
    | Conv3D_D_C (a, _, _)
    | Di_Conv1D_D_C (a, _, _, _)
    | Di_Conv2D_D_C (a, _, _, _)
    | Di_Conv3D_D_C (a, _, _, _)
    | Tr_Conv1D_D_C (a, _, _)
    | Tr_Conv2D_D_C (a, _, _)
    | Tr_Conv3D_D_C (a, _, _)
    | Concat_D_C (a, _, _) -> [ a ]
    (* only the second operand tracked *)
    | Add_C_D (_, b)
    | Sub_C_D (_, b)
    | Mul_C_D (_, b)
    | Div_C_D (_, b)
    | Pow_C_D (_, b)
    | Atan2_C_D (_, b)
    | Dot_C_D (_, b)
    | SetI_C_D (_, _, _, b)
    | AddI_C_D (_, _, _, b)
    | Set_Slice_C_D (_, b, _)
    | Add_Row_C_D (_, b, _)
    | Conv1D_C_D (_, b, _)
    | Conv2D_C_D (_, b, _)
    | Conv3D_C_D (_, b, _)
    | Di_Conv1D_C_D (_, b, _, _)
    | Di_Conv2D_C_D (_, b, _, _)
    | Di_Conv3D_C_D (_, b, _, _)
    | Tr_Conv1D_C_D (_, b, _)
    | Tr_Conv2D_C_D (_, b, _)
    | Tr_Conv3D_C_D (_, b, _)
    | Concat_C_D (_, b, _) -> [ b ]
    (* single tracked operand *)
    | Neg_D a
    | Abs_D a
    | Signum_D a
    | Floor_D a
    | Ceil_D a
    | Round_D a
    | Sqr_D a
    | Sqrt_D a
    | Log_D a
    | Log2_D a
    | Log10_D a
    | Exp_D a
    | Sin_D a
    | Cos_D a
    | Tan_D a
    | Sinh_D a
    | Cosh_D a
    | Tanh_D a
    | Asin_D a
    | Acos_D a
    | Atan_D a
    | Asinh_D a
    | Acosh_D a
    | Atanh_D a
    | Get_Item (a, _, _)
    | Get_Slice_D (a, _)
    | Sum_D a
    | Sum__D (a, _)
    | Sum___D (a, _)
    | Trans_D a
    | L1Norm_D a
    | L2Norm_D a
    | L2NormS_D a
    | Sigmoid_D a
    | Relu_D a
    | Inv_D a
    | Get_Row_D (a, _)
    | Reshape_D a
    | Maxpool1D_D (a, _, _, _)
    | Maxpool2D_D (a, _, _, _)
    | Maxpool3D_D (a, _, _, _)
    | Avgpool1D_D (a, _, _, _)
    | Avgpool2D_D (a, _, _, _)
    | Avgpool3D_D (a, _, _, _) -> [ a ]
    (* an array of row operands *)
    | Of_Rows_D rows -> Array.to_list rows
  in
  let rec reset = function
    | [] -> ()
    | DR (_, adjoint, op, fanout, _) :: rest ->
      adjoint := reset_zero !adjoint;
      fanout := !fanout + 1;
      (* expand the operands only on the first visit of this node *)
      if !fanout = 1 then reset (List.append (operands op) rest) else reset rest
    | _ :: rest -> reset rest
  in
  reset [ x ]

(* check adjoint a and its update v, ensure rank a >= rank v. This function
fixes the inconsistent shapes between a and v by performing the inverse
operation of the previous broadcasting function. Note that padding is on
the left due to the expand function called in broadcasting. *)
let _shrink a v =
  match a, v with
  | F _, Arr v -> F (A.sum' v)
  | Arr a, Arr v ->
    let shp_a = A.shape a in
    let shp_v = A.shape v in
    if shp_a <> shp_v
    then (
      let shp_a, shp_v = Owl_utils_array.align `Left 1 shp_a shp_v in
      (* axes on which the two aligned shapes disagree are summed away *)
      let axis = Owl_utils_array.filter2_i ( <> ) shp_a shp_v in
      Arr (A.sum_reduce ~axis v))
    else Arr v
  | _a, v -> v


(* [reverse_push v x] runs the reverse sweep: the adjoint update [v] is
   added to node [x] and, once the fan-out counter of a node has dropped to
   zero (all its pending contributions have arrived), the accumulated
   adjoint [!aa] is pushed on to the operands of the operation that created
   the node. Only [DR] nodes are processed; other values are skipped.
   The counters are expected to have been set up by [reverse_reset]. *)
let reverse_push v x =
  let open Maths in
  let rec push xs =
    match xs with
    | [] -> ()
    | (v, x) :: t ->
      (match x with
       | DR (ap, aa, ao, af, _ai) ->
         let v = _shrink !aa v in
         aa := Maths.(!aa + v);
         (* integer arithmetic: the arithmetic operators of Maths are in scope *)
         af := Pervasives.(!af - 1);
         if !af = 0
         then (
           match ao with
           | Noop -> push t
           | Add_D_D (a, b) -> push ((!aa, a) :: (!aa, b) :: t)
           | Add_D_C (a, _) -> push ((!aa, a) :: t)
           | Add_C_D (_, b) -> push ((!aa, b) :: t)
           | Sub_D_D (a, b) -> push ((!aa, a) :: (neg !aa, b) :: t)
           | Sub_D_C (a, _) -> push ((!aa, a) :: t)
           | Sub_C_D (_, b) -> push ((neg !aa, b) :: t)
           | Mul_D_D (a, b) -> push ((!aa * primal b, a) :: (!aa * primal a, b) :: t)
           | Mul_D_C (a, b) -> push ((!aa * b, a) :: t)
           | Mul_C_D (a, b) -> push ((!aa * a, b) :: t)
           | Div_D_D (a, b) ->
             push
               ((!aa / primal b, a)
                :: (!aa * (neg (primal a) / (primal b * primal b)), b)
                :: t)
           | Div_D_C (a, b) -> push ((!aa / b, a) :: t)
           | Div_C_D (a, b) -> push ((!aa * (neg a / (primal b * primal b)), b) :: t)
           | Pow_D_D (a, b) ->
             push
               ((!aa * (primal a ** (primal b - pack_flt 1.)) * primal b, a)
                :: (!aa * (primal a ** primal b) * log (primal a), b)
                :: t)
           | Pow_D_C (a, b) -> push ((!aa * (primal a ** (b - pack_flt 1.)) * b, a) :: t)
           | Pow_C_D (a, b) -> push ((!aa * (a ** primal b) * log a, b) :: t)
           | Atan2_D_D (a, b) ->
             let d = sqr (primal a) + sqr (primal b) in
             push ((!aa * primal b / d, a) :: (!aa * neg (primal a) / d, b) :: t)
           | Atan2_D_C (a, b) -> push ((!aa * b / (sqr (primal a) + sqr b), a) :: t)
           | Atan2_C_D (a, b) -> push ((!aa * neg a / (sqr a + sqr (primal b)), b) :: t)
           | Neg_D a -> push ((neg !aa, a) :: t)
           | Abs_D a -> push ((!aa * signum (primal a), a) :: t)
           | Signum_D a -> push ((zero a, a) :: t)
           | Floor_D a -> push ((zero a, a) :: t)
           | Ceil_D a -> push ((zero a, a) :: t)
           | Round_D a -> push ((zero a, a) :: t)
           | Sqr_D a -> push ((!aa * primal a * pack_flt 2., a) :: t)
           | Sqrt_D a -> push ((!aa / (pack_flt 2. * ap), a) :: t)
           | Log_D a -> push ((!aa / primal a, a) :: t)
           (* d/dx log2 x = log2(e) / x and d/dx log10 x = log10(e) / x.
              The constants are multiplied in, not divided by.
              NOTE(review): assumes Owl_const.log2e / log10e hold log2(e) and
              log10(e) as their names indicate -- confirm against Owl_const. *)
           | Log2_D a -> push ((!aa * pack_flt Owl_const.log2e / primal a, a) :: t)
           | Log10_D a -> push ((!aa * pack_flt Owl_const.log10e / primal a, a) :: t)
           | Exp_D a -> push ((!aa * ap, a) :: t)
           | Sin_D a -> push ((!aa * cos (primal a), a) :: t)
           | Cos_D a -> push ((!aa * neg (sin (primal a)), a) :: t)
           | Tan_D a -> push ((!aa / sqr (cos (primal a)), a) :: t)
           | Sinh_D a -> push ((!aa * cosh (primal a), a) :: t)
           | Cosh_D a -> push ((!aa * sinh (primal a), a) :: t)
           | Tanh_D a -> push ((!aa / sqr (cosh (primal a)), a) :: t)
           | Asin_D a -> push ((!aa / sqrt (pack_flt 1. - sqr (primal a)), a) :: t)
           | Acos_D a -> push ((neg !aa / sqrt (pack_flt 1. - sqr (primal a)), a) :: t)
           | Atan_D a -> push ((!aa / (pack_flt 1. + sqr (primal a)), a) :: t)
           | Asinh_D a -> push ((!aa / sqrt (sqr (primal a) + pack_flt 1.), a) :: t)
           | Acosh_D a -> push ((!aa / sqrt (sqr (primal a) - pack_flt 1.), a) :: t)
           | Atanh_D a -> push ((!aa / (pack_flt 1. - sqr (primal a)), a) :: t)
           | Get_Item (a, i, j) -> push ((set_item (zero a) i j (sum' !aa), a) :: t)
           | SetI_D_D (a, i, j, b) ->
             push ((set_item !aa i j (pack_flt 0.), a) :: (get_item !aa i j, b) :: t)
           | SetI_D_C (a, i, j, _) -> push ((set_item !aa i j (pack_flt 0.), a) :: t)
           | SetI_C_D (_, i, j, b) -> push ((get_item !aa i j, b) :: t)
           | AddI_D_D (a, i, j, b) -> push ((!aa, a) :: (get_item !aa i j, b) :: t)
           | AddI_D_C (a, _, _, _) -> push ((!aa, a) :: t)
           | AddI_C_D (_, i, j, b) -> push ((get_item !aa i j, b) :: t)
           | Get_Slice_D (a, i) -> push ((set_slice i (zero a) !aa, a) :: t)
           | Set_Slice_D_D (a, b, i) ->
             push ((set_slice i !aa (zero b), a) :: (get_slice i !aa, b) :: t)
           | Set_Slice_D_C (a, b, i) -> push ((set_slice i !aa (zero b), a) :: t)
           | Set_Slice_C_D (_a, b, i) -> push ((get_slice i !aa, b) :: t)
           | Sum_D a -> push ((!aa, a) :: t)
           | Sum__D (a, i) ->
             (* repeat the adjoint along the reduced axis to restore the shape *)
             let s = shape a in
             let reps = Array.(make (length s) 1) in
             reps.(i) <- s.(i);
             push ((repeat !aa reps, a) :: t)
           | Sum___D (a, i) ->
             (* same as above, for several reduced axes *)
             let s = shape a in
             let reps = Array.(make (length s) 1) in
             Array.iter (fun j -> reps.(j) <- s.(j)) i;
             push ((repeat !aa reps, a) :: t)
           | Dot_D_D (a, b) ->
             push
               ((dot !aa (transpose (primal b)), a)
                :: (dot (transpose (primal a)) !aa, b)
                :: t)
           | Dot_D_C (a, b) -> push ((dot !aa (transpose b), a) :: t)
           | Dot_C_D (a, b) -> push ((dot (transpose a) !aa, b) :: t)
           | Trans_D a -> push ((transpose !aa, a) :: t)
           | L1Norm_D a -> push ((!aa * signum (primal a), a) :: t)
           | L2Norm_D a -> push ((!aa / ap * primal a, a) :: t)
           | L2NormS_D a -> push ((!aa * pack_flt 2. * primal a, a) :: t)
           | Sigmoid_D a -> push ((!aa * ap * (pack_flt 1. - ap), a) :: t)
           | Relu_D a ->
             push ((!aa * ((signum (primal a) + pack_flt 1.) / pack_flt 2.), a) :: t)
           | Inv_D a ->
             (* adjoint of C = inv(A) is -C^T . C_bar . C^T with matrix
                products; the element-wise [*] used previously is only right
                for 1x1 inputs *)
             let dpt = transpose ap in
             push ((dot (dot (neg dpt) !aa) dpt, a) :: t)
           | Add_Row_D_D (a, b, i) -> push ((!aa, a) :: (get_row !aa i, b) :: t)
           | Add_Row_D_C (a, _b, _i) -> push ((!aa, a) :: t)
           | Add_Row_C_D (_a, b, i) -> push ((get_row !aa i, b) :: t)
           | Get_Row_D (a, i) ->
             (* the row adjoint is written into a's adjoint directly, so only
                a zero update is pushed on *)
             (adjref a) := add_row (adjval a) !aa i;
             push ((zero a, a) :: t)
           | Of_Rows_D a ->
             let rows =
               a |> Array.to_list |> List.mapi (fun i v -> get_row !aa i, v)
             in
             push (List.append rows t)
           | Conv1D_D_D (a, b, s) ->
             push
               ((conv1d_backward_input a b s !aa, a)
                :: (conv1d_backward_kernel a b s !aa, b)
                :: t)
           | Conv1D_D_C (a, b, s) -> push ((conv1d_backward_input a b s !aa, a) :: t)
           | Conv1D_C_D (a, b, s) -> push ((conv1d_backward_kernel a b s !aa, b) :: t)
           | Conv2D_D_D (a, b, s) ->
             push
               ((conv2d_backward_input a b s !aa, a)
                :: (conv2d_backward_kernel a b s !aa, b)
                :: t)
           | Conv2D_D_C (a, b, s) -> push ((conv2d_backward_input a b s !aa, a) :: t)
           | Conv2D_C_D (a, b, s) -> push ((conv2d_backward_kernel a b s !aa, b) :: t)
           | Conv3D_D_D (a, b, s) ->
             push
               ((conv3d_backward_input a b s !aa, a)
                :: (conv3d_backward_kernel a b s !aa, b)
                :: t)
           | Conv3D_D_C (a, b, s) -> push ((conv3d_backward_input a b s !aa, a) :: t)
           | Conv3D_C_D (a, b, s) -> push ((conv3d_backward_kernel a b s !aa, b) :: t)
           | Di_Conv1D_D_D (a, b, s, r) ->
             push
               ((dilated_conv1d_backward_input a b s r !aa, a)
                :: (dilated_conv1d_backward_kernel a b s r !aa, b)
                :: t)
           | Di_Conv1D_D_C (a, b, s, r) ->
             push ((dilated_conv1d_backward_input a b s r !aa, a) :: t)
           | Di_Conv1D_C_D (a, b, s, r) ->
             push ((dilated_conv1d_backward_kernel a b s r !aa, b) :: t)
           | Di_Conv2D_D_D (a, b, s, r) ->
             push
               ((dilated_conv2d_backward_input a b s r !aa, a)
                :: (dilated_conv2d_backward_kernel a b s r !aa, b)
                :: t)
           | Di_Conv2D_D_C (a, b, s, r) ->
             push ((dilated_conv2d_backward_input a b s r !aa, a) :: t)
           | Di_Conv2D_C_D (a, b, s, r) ->
             push ((dilated_conv2d_backward_kernel a b s r !aa, b) :: t)
           | Di_Conv3D_D_D (a, b, s, r) ->
             push
               ((dilated_conv3d_backward_input a b s r !aa, a)
                :: (dilated_conv3d_backward_kernel a b s r !aa, b)
                :: t)
           | Di_Conv3D_D_C (a, b, s, r) ->
             push ((dilated_conv3d_backward_input a b s r !aa, a) :: t)
           | Di_Conv3D_C_D (a, b, s, r) ->
             push ((dilated_conv3d_backward_kernel a b s r !aa, b) :: t)
           | Tr_Conv1D_D_D (a, b, s) ->
             push
               ((transpose_conv1d_backward_input a b s !aa, a)
                :: (transpose_conv1d_backward_kernel a b s !aa, b)
                :: t)
           | Tr_Conv1D_D_C (a, b, s) ->
             push ((transpose_conv1d_backward_input a b s !aa, a) :: t)
           | Tr_Conv1D_C_D (a, b, s) ->
             push ((transpose_conv1d_backward_kernel a b s !aa, b) :: t)
           | Tr_Conv2D_D_D (a, b, s) ->
             push
               ((transpose_conv2d_backward_input a b s !aa, a)
                :: (transpose_conv2d_backward_kernel a b s !aa, b)
                :: t)
           | Tr_Conv2D_D_C (a, b, s) ->
             push ((transpose_conv2d_backward_input a b s !aa, a) :: t)
           | Tr_Conv2D_C_D (a, b, s) ->
             push ((transpose_conv2d_backward_kernel a b s !aa, b) :: t)
           | Tr_Conv3D_D_D (a, b, s) ->
             push
               ((transpose_conv3d_backward_input a b s !aa, a)
                :: (transpose_conv3d_backward_kernel a b s !aa, b)
                :: t)
           | Tr_Conv3D_D_C (a, b, s) ->
             push ((transpose_conv3d_backward_input a b s !aa, a) :: t)
           | Tr_Conv3D_C_D (a, b, s) ->
             push ((transpose_conv3d_backward_kernel a b s !aa, b) :: t)
           | Reshape_D a -> push ((reshape !aa (shape (primal a)), a) :: t)
           | Maxpool1D_D (a, p, d, s) ->
             push ((max_pool1d_backward p (primal a) d s !aa, a) :: t)
           | Maxpool2D_D (a, p, d, s) ->
             push ((max_pool2d_backward p (primal a) d s !aa, a) :: t)
           | Maxpool3D_D (a, p, d, s) ->
             push ((max_pool3d_backward p (primal a) d s !aa, a) :: t)
           | Avgpool1D_D (a, p, d, s) ->
             push ((avg_pool1d_backward p (primal a) d s !aa, a) :: t)
           | Avgpool2D_D (a, p, d, s) ->
             push ((avg_pool2d_backward p (primal a) d s !aa, a) :: t)
           | Avgpool3D_D (a, p, d, s) ->
             push ((avg_pool3d_backward p (primal a) d s !aa, a) :: t)
           | Concat_D_D (a, b, i) ->
             (* split the adjoint back into the two concatenated pieces *)
             let s = split i [| (shape a).(i); (shape b).(i) |] !aa in
             push ((s.(0), a) :: (s.(1), b) :: t)
           | Concat_D_C (a, b, i) ->
             let s = split i [| (shape a).(i); (shape b).(i) |] !aa in
             push ((s.(0), a) :: t)
           | Concat_C_D (a, b, i) ->
             let s = split i [| (shape a).(i); (shape b).(i) |] !aa in
             push ((s.(1), b) :: t))
         else push t
       | _ -> push t)
  in
  push [ v, x ]


(* reset the graph rooted at x, then propagate the adjoint v through it *)
let reverse_prop v x =
  reverse_reset x;
  reverse_push v x


(* convenient wrappers *)

(* forward-mode value with primal p, tangent t and tag i *)
let make_forward p t i = DF (p, t, i)

(* reverse-mode value with primal p, a zero adjoint, no operation and tag i *)
let make_reverse p i = DR (p, ref (zero p), Noop, ref 0, i)


(* derivative of f (scalar -> scalar) at x, forward ad;
   also returns the original value *)
let diff' f x =
  let x = make_forward x (pack_flt 1.) (tag ()) in
  let y = f x in
  primal y, tangent y


(* derivative of f (scalar -> scalar) at x, forward ad *)
let diff f x = diff' f x |> snd


(* gradient of f (vector -> scalar) at x, reverse ad;
   also returns the original value *)
let grad' f x =
  let x = make_reverse x (tag ()) in
  let y = f x in
  reverse_reset y;
  reverse_push (pack_flt 1.) y;
  primal y, x |> adjval


(* gradient of f (vector -> scalar) at x, reverse ad *)
let grad f x = grad' f x |> snd


(* jacobian vector product of f (vector ->
vector) at x along v, forward ad *)
let jacobianv' f x v =
  let seeded = make_forward x v (tag ()) in
  let out = f seeded in
  primal out, tangent out


(* jacobian vector product of f (vector -> vector) at x along v, forward ad *)
let jacobianv f x v = snd (jacobianv' f x v)


(* transposed jacobian vector product of f (vector -> vector) at x along v,
   backward ad; also returns the original value *)
let jacobianTv' f x v =
  let tracked = make_reverse x (tag ()) in
  let out = f tracked in
  reverse_reset out;
  reverse_push v out;
  primal out, primal (adjval tracked)


(* transposed jacobian vector product of f (vector -> vector) at x along v,
   backward ad *)
let jacobianTv f x v = snd (jacobianTv' f x v)


(* jacobian of f (vector -> vector) at x, both x and y are row vectors, also
   return the original value. The m x n result is filled column by column
   with forward-mode products when m > n, and row by row with reverse-mode
   products otherwise. *)
let jacobian' f x =
  let y = primal (f x) in
  let m = col_num y in
  let n = col_num x in
  let z = A.empty [| m; n |] in
  (* 1 x len row vector with a one at position i and zeros elsewhere *)
  let one_hot len i =
    let e = A.zeros [| 1; len |] in
    A.(set e [| 0; i |] (float_to_elt 1.));
    Arr e
  in
  (* evaluate all [count] products first, then store each result in z *)
  let fill count product store =
    Array.init count (fun i -> product f x (one_hot count i))
    |> Array.iteri (fun i res ->
           match res with
           | Arr r -> store r i
           | _ -> failwith "error: jacobian")
  in
  if m > n
  then fill n jacobianv (fun c i -> A.copy_col_to (A.transpose c) z i)
  else fill m jacobianTv (fun r i -> A.copy_row_to r z i);
  y, Arr z


(* jacobian of f *)
let jacobian f x = snd (jacobian' f x)


(* gradient, hessian of f (vector -> scalar) at [x] *)
let gradhessian f x = jacobian' (grad f) x


(* original value, gradient, and hessian *)
let gradhessian' f x =
  let g, h = gradhessian f x in
  f x, g, h


(* hessian of f *)
let hessian f x = jacobian (grad f) x


(* original value and hessian of f *)
let hessian' f x = f x, hessian f x


(* original value, gradient-vector product, hessian-vector product *)
let gradhessianv' f x v =
  let gv, hv = grad' (fun y -> jacobianv f y v) x in
  f x, gv, hv


(* gradient-vector product and hessian-vector product *)
let gradhessianv f x v =
  match gradhessianv' f x v with
  | _, gv, hv -> gv, hv


(* original value and hessian-vector product *)
let hessianv' f x v =
  match gradhessianv' f x v with
  | fv, _, hv -> fv, hv


(* hessian-vector *)
let hessianv f x v =
  match gradhessianv' f x v with
  | _, _, hv -> hv


(* laplacian of f: trace of the hessian *)
let laplacian f x = F (A.trace (unpack_arr (hessian f x)))


(* original value and laplacian of f *)
let laplacian' f x = f x, laplacian f x


(* Wrapper for the
Mat module *)
module Mat = struct

  (* creation functions: build an m x n array and wrap it with pack_arr *)

  let empty rows cols = pack_arr (A.empty [| rows; cols |])

  let zeros rows cols = pack_arr (A.zeros [| rows; cols |])

  let ones rows cols = pack_arr (A.ones [| rows; cols |])

  let uniform ?a ?b rows cols = pack_arr (A.uniform ?a ?b [| rows; cols |])

  let gaussian ?mu ?sigma rows cols = pack_arr (A.gaussian ?mu ?sigma [| rows; cols |])

  (* apply A.reset to the unpacked array *)
  let reset mat = A.reset (unpack_arr mat)

  let reshape rows cols mat = Maths.reshape mat [| rows; cols |]

  (* first two dimensions of the unpacked array, as a pair *)
  let shape mat =
    let dims = A.shape (unpack_arr mat) in
    dims.(0), dims.(1)

  let row_num mat = (A.shape (unpack_arr mat)).(0)

  let col_num mat = (A.shape (unpack_arr mat)).(1)

  (* refers to the numel defined outside this module *)
  let numel mat = numel mat

  let row mat i = Maths.get_row mat i

  let get mat i j = Maths.get_item mat i j

  let set mat i j elt = Maths.set_item mat i j elt

  (* unary math operators *)

  let mean mat = Maths.mean mat

  (* binary math operators *)

  let add lhs rhs = Maths.add lhs rhs

  let sub lhs rhs = Maths.sub lhs rhs

  let mul lhs rhs = Maths.mul lhs rhs

  let div lhs rhs = Maths.div lhs rhs

  let dot lhs rhs = Maths.dot lhs rhs

  (* split into rows, map f over them, and stack the results again *)
  let map_by_row f mat = Maths.of_rows (Array.map f (Maths.to_rows mat))

  let print mat = A.print (unpack_arr mat)

  let of_arrays data = pack_arr (A.of_arrays data)

end


(* Wrapper for the Arr module *)
module Arr = struct

  (* creation functions: build an array of shape dims and wrap it with pack_arr *)

  let empty dims = pack_arr (A.empty dims)

  let zeros dims = pack_arr (A.zeros dims)

  let ones dims = pack_arr (A.ones dims)

  let uniform ?a ?b dims = pack_arr (A.uniform ?a ?b dims)

  let gaussian ?mu ?sigma dims = pack_arr (A.gaussian ?mu ?sigma dims)

  (* apply A.reset to the unpacked array *)
  let reset arr = A.reset (unpack_arr arr)

  let reshape arr dims = Maths.reshape arr dims

  let shape arr = A.shape (unpack_arr arr)

  (* refers to the numel defined outside this module *)
  let numel arr = numel arr

  (* binary math operators *)

  let add lhs rhs = Maths.add lhs rhs

  let sub lhs rhs = Maths.sub lhs rhs

  let mul lhs rhs = Maths.mul lhs rhs

  let div lhs rhs = Maths.div lhs rhs

  let dot lhs rhs = Maths.dot lhs rhs

end


(* _traverse_trace and its related functions are used to convert the
computation graph generated in backward mode into human-readable format.
You can make your own convert function to generate needed format.
*)
let _traverse_trace x =
  (* visited nodes, each mapped to (visit index, operation label, inputs) *)
  let nodes = Hashtbl.create 512 in
  let index = ref 0 in
  (* operation label and input nodes of a single value *)
  let describe = function
    | F _ | Arr _ -> "Const", []
    | DF (_, _, _) -> "DF", []
    | DR (_, _, op, _, _) ->
      (match op with
       | Noop -> "Noop", []
       | Add_D_D (a, b) -> "Add_D_D", [ a; b ]
       | Add_D_C (a, b) -> "Add_D_C", [ a; b ]
       | Add_C_D (a, b) -> "Add_C_D", [ a; b ]
       | Sub_D_D (a, b) -> "Sub_D_D", [ a; b ]
       | Sub_D_C (a, b) -> "Sub_D_C", [ a; b ]
       | Sub_C_D (a, b) -> "Sub_C_D", [ a; b ]
       | Mul_D_D (a, b) -> "Mul_D_D", [ a; b ]
       | Mul_D_C (a, b) -> "Mul_D_C", [ a; b ]
       | Mul_C_D (a, b) -> "Mul_C_D", [ a; b ]
       | Div_D_D (a, b) -> "Div_D_D", [ a; b ]
       | Div_D_C (a, b) -> "Div_D_C", [ a; b ]
       | Div_C_D (a, b) -> "Div_C_D", [ a; b ]
       | Pow_D_D (a, b) -> "Pow_D_D", [ a; b ]
       | Pow_D_C (a, b) -> "Pow_D_C", [ a; b ]
       | Pow_C_D (a, b) -> "Pow_C_D", [ a; b ]
       | Atan2_D_D (a, b) -> "Atan2_D_D", [ a; b ]
       | Atan2_D_C (a, b) -> "Atan2_D_C", [ a; b ]
       | Atan2_C_D (a, b) -> "Atan2_C_D", [ a; b ]
       | Neg_D a -> "Neg_D", [ a ]
       | Abs_D a -> "Abs_D", [ a ]
       | Signum_D a -> "Signum_D", [ a ]
       | Floor_D a -> "Floor_D", [ a ]
       | Ceil_D a -> "Ceil_D", [ a ]
       | Round_D a -> "Round_D", [ a ]
       | Sqr_D a -> "Sqr_D", [ a ]
       | Sqrt_D a -> "Sqrt_D", [ a ]
       | Log_D a -> "Log_D", [ a ]
       | Log2_D a -> "Log2_D", [ a ]
       | Log10_D a -> "Log10_D", [ a ]
       | Exp_D a -> "Exp_D", [ a ]
       | Sin_D a -> "Sin_D", [ a ]
       | Cos_D a -> "Cos_D", [ a ]
       | Tan_D a -> "Tan_D", [ a ]
       | Sinh_D a -> "Sinh_D", [ a ]
       | Cosh_D a -> "Cosh_D", [ a ]
       | Tanh_D a -> "Tanh_D", [ a ]
       | Asin_D a -> "Asin_D", [ a ]
       | Acos_D a -> "Acos_D", [ a ]
       | Atan_D a -> "Atan_D", [ a ]
       | Asinh_D a -> "Asinh_D", [ a ]
       | Acosh_D a -> "Acosh_D", [ a ]
       | Atanh_D a -> "Atanh_D", [ a ]
       | Get_Item (a, _, _) -> "Get_Item", [ a ]
       | SetI_D_D (a, _, _, b) -> "SetI_D_D", [ a; b ]
       | SetI_D_C (a, _, _, b) -> "SetI_D_C", [ a; b ]
       | SetI_C_D (a, _, _, b) -> "SetI_C_D", [ a; b ]
       | AddI_D_D (a, _, _, b) -> "AddI_D_D", [ a; b ]
       | AddI_D_C (a, _, _, b) -> "AddI_D_C", [ a; b ]
       | AddI_C_D (a, _, _, b) -> "AddI_C_D", [ a; b ]
       | Get_Slice_D (a, _) -> "Get_Slice_D", [ a ]
       | Set_Slice_D_D (a, b, _) -> "Set_Slice_D_D", [ a; b ]
       | Set_Slice_D_C (a, b, _) -> "Set_Slice_D_C", [ a; b ]
       | Set_Slice_C_D (a, b, _) -> "Set_Slice_C_D", [ a; b ]
       | Sum_D a -> "Sum_D", [ a ]
       | Sum__D (a, _) -> "Sum__D", [ a ]
       | Sum___D (a, _) -> "Sum___D", [ a ]
       | Dot_D_D (a, b) -> "Dot_D_D", [ a; b ]
       | Dot_D_C (a, b) -> "Dot_D_C", [ a; b ]
       | Dot_C_D (a, b) -> "Dot_C_D", [ a; b ]
       | Trans_D a -> "Trans_D", [ a ]
       | L1Norm_D a -> "L1Norm_D", [ a ]
       | L2Norm_D a -> "L2Norm_D", [ a ]
       | L2NormS_D a -> "L2NormS_D", [ a ]
       | Sigmoid_D a -> "Sigmoid_D", [ a ]
       | Relu_D a -> "Relu_D", [ a ]
       | Inv_D a -> "Inv_D", [ a ]
       | Add_Row_D_D (a, b, _) -> "Add_Row_D_D", [ a; b ]
       | Add_Row_D_C (a, b, _) -> "Add_Row_D_C", [ a; b ]
       | Add_Row_C_D (a, b, _) -> "Add_Row_C_D", [ a; b ]
       | Get_Row_D (a, _) -> "Get_Row_D", [ a ]
       | Of_Rows_D a -> "Of_Rows_D", Array.to_list a
       | Conv1D_D_D (a, b, _) -> "Conv1D_D_D", [ a; b ]
       | Conv1D_D_C (a, b, _) -> "Conv1D_D_C", [ a; b ]
       | Conv1D_C_D (a, b, _) -> "Conv1D_C_D", [ a; b ]
       | Conv2D_D_D (a, b, _) -> "Conv2D_D_D", [ a; b ]
       | Conv2D_D_C (a, b, _) -> "Conv2D_D_C", [ a; b ]
       | Conv2D_C_D (a, b, _) -> "Conv2D_C_D", [ a; b ]
       | Conv3D_D_D (a, b, _) -> "Conv3D_D_D", [ a; b ]
       | Conv3D_D_C (a, b, _) -> "Conv3D_D_C", [ a; b ]
       | Conv3D_C_D (a, b, _) -> "Conv3D_C_D", [ a; b ]
       | Di_Conv1D_D_D (a, b, _, _) -> "Di_Conv1D_D_D", [ a; b ]
       | Di_Conv1D_D_C (a, b, _, _) -> "Di_Conv1D_D_C", [ a; b ]
       | Di_Conv1D_C_D (a, b, _, _) -> "Di_Conv1D_C_D", [ a; b ]
       | Di_Conv2D_D_D (a, b, _, _) -> "Di_Conv2D_D_D", [ a; b ]
       | Di_Conv2D_D_C (a, b, _, _) -> "Di_Conv2D_D_C", [ a; b ]
       | Di_Conv2D_C_D (a, b, _, _) -> "Di_Conv2D_C_D", [ a; b ]
       | Di_Conv3D_D_D (a, b, _, _) -> "Di_Conv3D_D_D", [ a; b ]
       | Di_Conv3D_D_C (a, b, _, _) -> "Di_Conv3D_D_C", [ a; b ]
       | Di_Conv3D_C_D (a, b, _, _) -> "Di_Conv3D_C_D", [ a; b ]
       | Tr_Conv1D_D_D (a, b, _) -> "Tr_Conv1D_D_D", [ a; b ]
       | Tr_Conv1D_D_C (a, b, _) -> "Tr_Conv1D_D_C", [ a; b ]
       | Tr_Conv1D_C_D (a, b, _) -> "Tr_Conv1D_C_D", [ a; b ]
       | Tr_Conv2D_D_D (a, b, _) -> "Tr_Conv2D_D_D", [ a; b ]
       | Tr_Conv2D_D_C (a, b, _) -> "Tr_Conv2D_D_C", [ a; b ]
       | Tr_Conv2D_C_D (a, b, _) -> "Tr_Conv2D_C_D", [ a; b ]
       | Tr_Conv3D_D_D (a, b, _) -> "Tr_Conv3D_D_D", [ a; b ]
       | Tr_Conv3D_D_C (a, b, _) -> "Tr_Conv3D_D_C", [ a; b ]
       | Tr_Conv3D_C_D (a, b, _) -> "Tr_Conv3D_C_D", [ a; b ]
       | Reshape_D a -> "Reshape_D", [ a ]
       | Maxpool1D_D (a, _, _, _) -> "Maxpool1D_D", [ a ]
       | Maxpool2D_D (a, _, _, _) -> "Maxpool2D_D", [ a ]
       | Maxpool3D_D (a, _, _, _) -> "Maxpool3D_D", [ a ]
       | Avgpool1D_D (a, _, _, _) -> "Avgpool1D_D", [ a ]
       | Avgpool2D_D (a, _, _, _) -> "Avgpool2D_D", [ a ]
       | Avgpool3D_D (a, _, _, _) -> "Avgpool3D_D", [ a ]
       | Concat_D_D (a, b, _) -> "Concat_D_D", [ a; b ]
       | Concat_D_C (a, b, _) -> "Concat_D_C", [ a; b ]
       | Concat_C_D (a, b, _) -> "Concat_C_D", [ a; b ])
  in
  (* work-list traversal; a node already in the table is not visited again *)
  let rec walk = function
    | [] -> ()
    | node :: rest when Hashtbl.mem nodes node -> walk rest
    | node :: rest ->
      let label, inputs = describe node in
      Hashtbl.add nodes node (!index, label, inputs);
      incr index;
      (* inputs are explored before the remaining pending nodes *)
      walk (inputs @ rest)
  in
  (* iterate the graph then return the hash table *)
  walk x;
  nodes


(* convert graph to terminal output: one line per edge, giving index,
   operation and type info of the input node and of the node consuming it *)
let _convert_terminal_output nodes =
  Hashtbl.fold
    (fun v (v_id, v_op, v_prev) acc ->
      let v_ts = type_info v in
      let edge u =
        let u_id, u_op, _ = Hashtbl.find nodes u in
        let u_ts = type_info u in
        Printf.sprintf
          "{ i:%i o:%s t:%s } -> { i:%i o:%s t:%s }\n"
          u_id
          u_op
          u_ts
          v_id
          v_op
          v_ts
      in
      acc ^ String.concat "" (List.map edge v_prev))
    nodes
    ""


(* convert graph to dot file output: the edge list followed by one
   attribute line per node; "Const" nodes are filled in gray *)
let _convert_dot_output nodes =
  let network =
    Hashtbl.fold
      (fun _v (v_id, _v_op, v_prev) acc ->
        let edge u =
          let u_id, _u_op, _ = Hashtbl.find nodes u in
          Printf.sprintf "\t%i -> %i;\n" u_id v_id
        in
        acc ^ String.concat "" (List.map edge v_prev))
      nodes
      ""
  in
  let attrs =
    Hashtbl.fold
      (fun v (v_id, v_op, _v_prev) acc ->
        let line =
          if v_op = "Const"
          then
            Printf.sprintf
              "%i [ label=\"#%i | { %s | %s }\" fillcolor=gray, style=filled ];\n"
              v_id
              v_id
              v_op
              (deep_info v)
          else
            Printf.sprintf
              "%i [ label=\"#%i | { %s | %s }\" ];\n"
              v_id
              v_id
              v_op
              (deep_info v)
        in
        acc ^ line)
      nodes
      ""
  in
  network ^ attrs


(* textual trace of the computation graph reachable from the given nodes *)
let to_trace nodes = _convert_terminal_output (_traverse_trace nodes)


(* dot-format digraph of the computation graph reachable from the given nodes *)
let to_dot nodes =
  let body = _convert_dot_output (_traverse_trace nodes) in
  Printf.sprintf "digraph CG {\nnode [shape=record];\n%s}" body


(* pretty printer: writes the type_info string of x to the formatter *)
let pp_num formatter x = Format.fprintf formatter "%s" (type_info x)


end

(* ends here *)