# 1 "src/base/algodiff/owl_algodiff_ops_builder.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

(* Builders for differentiable primitive operations.

   Each [build_*] value takes a first-class module describing one operation
   (how to evaluate it on plain [F] / [Arr] payloads, plus its forward-mode
   and reverse-mode rules) and returns a function over [t] that dispatches on
   the constructor of its input(s):
   - [DF (primal, tangent, tag)]: the result is a [DF] node carrying the tag
     of the differentiated input, whose tangent comes from the forward rule;
   - [DR (primal, adjoint ref, (adjoint, register, label), _, tag, tracker)]:
     the result is a fresh [DR] node carrying the tag of the differentiated
     input, whose adjoint ref starts at [zero cp] and whose closure triple is
     built by a local [r] function;
   - anything else: evaluated directly on the [F] / [Arr] payload, or passed
     to [error_uniop] / [error_binop].
   When several inputs are [DF] / [DR] nodes, their integer tags are compared
   and only the input(s) carrying the largest tag are differentiated at this
   level; the other inputs are handed to the recursive call unchanged.  A [DF]
   and a [DR] input carrying the same tag raise [Failure].
   NOTE(review): the fourth [DR] field is always created as [ref 0] in this
   file; its meaning is defined in [Core] -- confirm there. *)
module Make (Core : Owl_algodiff_core_sig.Sig) = struct
  open Core

  (* Three-way comparison of two tags: [1] if [ai > bi], [-1] if [ai < bi],
     [0] otherwise.  The dispatchers below match on exactly these values. *)
  let cmp_tag ai bi = if ai > bi then 1 else if ai < bi then -1 else 0

  (* Single input, single output operation. *)
  module type Siso = sig
    (* Operation name, used by [error_uniop] and in reverse-mode node labels. *)
    val label : string

    (* Evaluates the operation on the payload of an [F] input. *)
    val ff_f : A.elt -> t

    (* Evaluates the operation on the payload of an [Arr] input. *)
    val ff_arr : A.arr -> t

    (* Forward rule, called as [df cp ap at] with the output primal [cp], the
       input primal [ap] and the input tangent [at]; the result becomes the
       tangent of the output [DF] node. *)
    val df : t -> t -> t -> t

    (* Reverse rule, called as [dr (primal a) cp ca] from the node's adjoint
       closure; the result is recorded as the adjoint contribution for the
       input [a].  [cp] / [ca] are whatever the reverse pass hands to that
       closure -- presumably the output primal and its adjoint ref. *)
    val dr : t -> t -> t ref -> t
  end

  (* [build_siso (module S)] returns the differentiable unary function
     defined by [S]. *)
  let build_siso =
    (* single input single output operation *)
    let op_siso ~ff ~fd ~df ~r a =
      match a with
      | DF (ap, at, ai) ->
        (* forward mode: recurse on the primal, then derive the tangent *)
        let cp = fd ap in
        DF (cp, df cp ap at, ai)
      | DR (ap, _, _, _, ai, _) ->
        (* reverse mode: recurse on the primal, then start a new node with
           adjoint [zero cp] and the closure triple [r a] *)
        let cp = fd ap in
        DR (cp, ref (zero cp), r a, ref 0, ai, ref 0)
      | ap -> ff ap
    in
    fun (module S : Siso) ->
      let rec f a =
        let open S in
        (* direct evaluation on F / Arr payloads; any other input goes to
           [error_uniop] *)
        let ff = function
          | F a -> S.ff_f a
          | Arr a -> S.ff_arr a
          | _ -> error_uniop label a
        in
        let fd a = f a in
        (* Reverse-mode closures for a node computed from input [a]:
           [adjoint] prepends the pair (contribution, a) to an accumulator
           list, [register] prepends [a] itself, and [label] pairs the
           operation name with its input list. *)
        let r a =
          let adjoint cp ca t = (S.dr (primal a) cp ca, a) :: t in
          let register t = a :: t in
          let label = S.label, [ a ] in
          adjoint, register, label
        in
        op_siso ~ff ~fd ~df:S.df ~r a
      in
      f

  (* Single input, pair of outputs operation. *)
  module type Sipo = sig
    (* Operation name, used by [error_uniop] and in reverse-mode node labels. *)
    val label : string

    (* Evaluates the operation on the payload of an [F] input. *)
    val ff_f : A.elt -> t * t

    (* Evaluates the operation on the payload of an [Arr] input. *)
    val ff_arr : A.arr -> t * t

    (* Forward rule, called once per output: [df cp1 ap at] and
       [df cp2 ap at]. *)
    val df : t -> t -> t -> t

    (* Reverse rule, called as
       [dr (primal a) cp (cp1_ref, cp2_ref) (ca1_ref, ca2_ref)]: it is given
       refs to the primals and adjoints of both outputs. *)
    val dr : t -> t -> t ref * t ref -> t ref * t ref -> t
  end

  (* [build_sipo (module S)] returns the differentiable function defined by
     [S], producing a pair of outputs. *)
  let build_sipo =
    (* single input pair outputs operation *)
    let op_sipo ~ff ~fd ~df ~r a =
      match a with
      | DF (ap, at, ai) ->
        let cp1, cp2 = fd ap in
        DF (cp1, df cp1 ap at, ai), DF (cp2, df cp2 ap at, ai)
      | DR (ap, _, _, _, ai, _) ->
        let cp1, cp2 = fd ap in
        let ca1_ref = ref (zero cp1) in
        let ca2_ref = ref (zero cp2) in
        let cp1_ref = ref cp1 in
        let cp2_ref = ref cp2 in
        let tracker = ref 0 in
        (* Both output nodes share the single int ref [tracker], and the [r]
           closures of each output receive the refs of both outputs.  The
           reverse pass (reverse_reset / reverse_push, defined outside this
           file) is meant to use [tracker] to count how many times cp1 and
           cp2 have been reached, so that the adjoint of [ap] is not updated
           before both ca1 and ca2 are complete. *)
        ( DR (cp1, ca1_ref, r (a, (cp1_ref, cp2_ref), (ca1_ref, ca2_ref)), ref 0, ai, tracker)
        , DR (cp2, ca2_ref, r (a, (cp1_ref, cp2_ref), (ca1_ref, ca2_ref)), ref 0, ai, tracker) )
      | ap -> ff ap
    in
    fun (module S : Sipo) ->
      let rec f a =
        let open S in
        let ff = function
          | F a -> S.ff_f a
          | Arr a -> S.ff_arr a
          | _ -> error_uniop label a
        in
        let fd = f in
        (* The per-node adjoint [_ca] is ignored: [S.dr] is given the refs of
           both outputs' adjoints ([ca_ref]) instead. *)
        let r (a, cp_ref, ca_ref) =
          let adjoint cp _ca t = (S.dr (primal a) cp cp_ref ca_ref, a) :: t in
          let register t = a :: t in
          let label = S.label, [ a ] in
          adjoint, register, label
        in
        (* [~df] is [S.df], in scope through [let open S] *)
        op_sipo ~ff ~fd ~df ~r a
      in
      f

  (* Single input, three outputs operation. *)
  module type Sito = sig
    (* Operation name, used by [error_uniop] and in reverse-mode node labels. *)
    val label : string

    (* Evaluates the operation on the payload of an [F] input. *)
    val ff_f : A.elt -> t * t * t

    (* Evaluates the operation on the payload of an [Arr] input. *)
    val ff_arr : A.arr -> t * t * t

    (* Forward rule, called once per output as [df cpN ap at]. *)
    val df : t -> t -> t -> t

    (* Reverse rule, called as [dr (primal a) cp cp_refs ca_refs] with refs
       to the primals and adjoints of all three outputs. *)
    val dr : t -> t -> t ref * t ref * t ref -> t ref * t ref * t ref -> t
  end

  (* [build_sito (module S)] returns the differentiable function defined by
     [S], producing three outputs. *)
  let build_sito =
    (* single input three outputs operation *)
    let op_sito ~ff ~fd ~df ~r a =
      match a with
      | DF (ap, at, ai) ->
        let cp1, cp2, cp3 = fd ap in
        DF (cp1, df cp1 ap at, ai), DF (cp2, df cp2 ap at, ai), DF (cp3, df cp3 ap at, ai)
      | DR (ap, _, _, _, ai, _) ->
        let cp1, cp2, cp3 = fd ap in
        let ca1_ref = ref (zero cp1) in
        let ca2_ref = ref (zero cp2) in
        let ca3_ref = ref (zero cp3) in
        let cp1_ref = ref cp1 in
        let cp2_ref = ref cp2 in
        let cp3_ref = ref cp3 in
        (* single tracker shared by the three output nodes, as for the pair
           case in build_sipo *)
        let tracker = ref 0 in
        ( DR
            ( cp1
            , ca1_ref
            , r (a, (cp1_ref, cp2_ref, cp3_ref), (ca1_ref, ca2_ref, ca3_ref))
            , ref 0
            , ai
            , tracker )
        , DR
            ( cp2
            , ca2_ref
            , r (a, (cp1_ref, cp2_ref, cp3_ref), (ca1_ref, ca2_ref, ca3_ref))
            , ref 0
            , ai
            , tracker )
        , DR
            ( cp3
            , ca3_ref
            , r (a, (cp1_ref, cp2_ref, cp3_ref), (ca1_ref, ca2_ref, ca3_ref))
            , ref 0
            , ai
            , tracker ) )
      | ap -> ff ap
    in
    fun (module S : Sito) ->
      let rec f a =
        let open S in
        let ff = function
          | F a -> S.ff_f a
          | Arr a -> S.ff_arr a
          | _ -> error_uniop label a
        in
        let fd = f in
        (* as in build_sipo: the per-node adjoint is ignored and [S.dr] is
           given the refs of all outputs *)
        let r (a, cp_ref, ca_ref) =
          let adjoint cp _ca t = (S.dr (primal a) cp cp_ref ca_ref, a) :: t in
          let register t = a :: t in
          let label = S.label, [ a ] in
          adjoint, register, label
        in
        (* [~df] is [S.df], in scope through [let open S] *)
        op_sito ~ff ~fd ~df ~r a
      in
      f

  (* Single input, array of outputs operation. *)
  module type Siao = sig
    (* Operation name, used by [error_uniop] and in reverse-mode node labels. *)
    val label : string

    (* Evaluates the operation on the payload of an [F] input. *)
    val ff_f : A.elt -> t array

    (* Evaluates the operation on the payload of an [Arr] input. *)
    val ff_arr : A.arr -> t array

    (* Forward rule, called as [df cp_arr ap at]; expected to return one
       tangent per output ([Array.map2] raises [Invalid_argument] on a length
       mismatch). *)
    val df : t array -> t -> t -> t array

    (* Reverse rule, called as [dr (primal a) cp cp_arr_ref ca_arr_ref] with
       refs to every output's primal and adjoint. *)
    val dr : t -> t -> t ref array -> t ref array -> t
  end

  (* [build_siao (module S)] returns the differentiable function defined by
     [S], producing an array of outputs. *)
  let build_siao =
    (* single input array outputs operation *)
    let op_siao ~ff ~fd ~df ~r a =
      match a with
      | DF (ap, at, ai) ->
        let cp_arr = fd ap in
        let ct_arr = df cp_arr ap at in
        Array.map2 (fun cp ct -> DF (cp, ct, ai)) cp_arr ct_arr
      | DR (ap, _, _, _, ai, _) ->
        let cp_arr = fd ap in
        let cp_arr_ref = Array.map (fun cp -> ref cp) cp_arr in
        (* single tracker shared by every output node *)
        let tracker = ref 0 in
        let ca_ref_arr = Array.map (fun cp -> ref (zero cp)) cp_arr in
        Array.map2
          (fun cp ca_ref -> DR (cp, ca_ref, r (a, cp_arr_ref, ca_ref_arr), ref 0, ai, tracker))
          cp_arr
          ca_ref_arr
      | ap -> ff ap
    in
    fun (module S : Siao) ->
      let rec f a =
        let open S in
        let ff = function
          | F a -> S.ff_f a
          | Arr a -> S.ff_arr a
          | _ -> error_uniop label a
        in
        let fd = f in
        (* the per-node adjoint [_ca_ref] is ignored; [S.dr] is given the
           refs of every output instead *)
        let r (a, cp_arr_ref, ca_arr_ref) =
          let adjoint cp _ca_ref t = (S.dr (primal a) cp cp_arr_ref ca_arr_ref, a) :: t in
          let register t = a :: t in
          let label = S.label, [ a ] in
          adjoint, register, label
        in
        (* [~df] is [S.df], in scope through [let open S] *)
        op_siao ~ff ~fd ~df ~r a
      in
      f

  (* Pair input, single output operation. *)
  module type Piso = sig
    (* Operation name; used by [error_binop] and, suffixed with _d_d, _d_c
       or _c_d, in reverse-mode node labels. *)
    val label : string

    (* Direct evaluation on plain inputs (F, F). *)
    val ff_aa : A.elt -> A.elt -> t

    (* Direct evaluation on plain inputs (F, Arr). *)
    val ff_ab : A.elt -> A.arr -> t

    (* Direct evaluation on plain inputs (Arr, F). *)
    val ff_ba : A.arr -> A.elt -> t

    (* Direct evaluation on plain inputs (Arr, Arr). *)
    val ff_bb : A.arr -> A.arr -> t

    (* Forward rule when only [a] is differentiated, called as
       [df_da cp ap at b]: output primal, primal and tangent of [a], and [b]
       unchanged. *)
    val df_da : t -> t -> t -> t -> t

    (* Forward rule when only [b] is differentiated, called as
       [df_db cp a bp bt]. *)
    val df_db : t -> t -> t -> t -> t

    (* Forward rule when both inputs carry the same tag, called as
       [df_dab cp ap at bp bt]. *)
    val df_dab : t -> t -> t -> t -> t -> t

    (* Reverse rule when both inputs are differentiated, called as
       [dr_ab (primal a) (primal b) cp ca_ref]; returns the adjoint
       contributions for [a] and [b]. *)
    val dr_ab : t -> t -> t -> t ref -> t * t

    (* Reverse rule when only [a] is differentiated, called as
       [dr_a (primal a) b cp ca_ref]. *)
    val dr_a : t -> t -> t -> t ref -> t

    (* Reverse rule when only [b] is differentiated, called as
       [dr_b a (primal b) cp ca_ref]. *)
    val dr_b : t -> t -> t -> t ref -> t
  end

  (* [build_piso (module S)] returns the differentiable binary function
     defined by [S]. *)
  let build_piso =
    (* pair input single output operation *)
    (* Dispatch on the pair of inputs.  With one DF / DR input, the other is
       passed unchanged to [fd].  With two, [cmp_tag] picks the input with
       the larger tag; equal tags use [df_dab] / [r_d_d], and a DF and a DR
       with equal tags raise [Failure].  [r_d_d], [r_d_c] and [r_c_d] build
       the reverse-mode closures for both inputs, only [a], and only [b]
       being differentiated, respectively. *)
    let op_piso ~ff ~fd ~df_da ~df_db ~df_dab ~r_d_d ~r_d_c ~r_c_d a b =
      match a, b with
      | F _ap, DF (bp, bt, bi) ->
        let cp = fd a bp in
        DF (cp, df_db cp a bp bt, bi)
      | DF (ap, at, ai), F _bp ->
        let cp = fd ap b in
        DF (cp, df_da cp ap at b, ai)
      | Arr _ap, DF (bp, bt, bi) ->
        let cp = fd a bp in
        DF (cp, df_db cp a bp bt, bi)
      | DF (ap, at, ai), Arr _bp ->
        let cp = fd ap b in
        DF (cp, df_da cp ap at b, ai)
      | F _ap, DR (bp, _, _, _, bi, _) ->
        let cp = fd a bp in
        DR (cp, ref (zero cp), r_c_d a b, ref 0, bi, ref 0)
      | DR (ap, _, _, _, ai, _), F _bp ->
        let cp = fd ap b in
        DR (cp, ref (zero cp), r_d_c a b, ref 0, ai, ref 0)
      | Arr _ap, DR (bp, _, _, _, bi, _) ->
        let cp = fd a bp in
        DR (cp, ref (zero cp), r_c_d a b, ref 0, bi, ref 0)
      | DR (ap, _, _, _, ai, _), Arr _bp ->
        let cp = fd ap b in
        DR (cp, ref (zero cp), r_d_c a b, ref 0, ai, ref 0)
      | DF (ap, at, ai), DR (bp, _, _, _, bi, _) ->
        (match cmp_tag ai bi with
         | 1 ->
           (* DF has the larger tag: differentiate [a]; [b] passes through *)
           let cp = fd ap b in
           DF (cp, df_da cp ap at b, ai)
         | -1 ->
           (* DR has the larger tag: differentiate [b]; [a] passes through *)
           let cp = fd a bp in
           DR (cp, ref (zero cp), r_c_d a b, ref 0, bi, ref 0)
         (* NOTE(review): in the source as received, this message literal
            contains a line break after the word same, whereas the DR, DF
            case below uses a single-line message.  Kept unchanged here;
            likely a reformatting artifact -- confirm. *)
         | _ -> failwith "error: forward and reverse clash at the same 
level")
      | DR (ap, _, _, _, ai, _), DF (bp, bt, bi) ->
        (match cmp_tag ai bi with
         | -1 ->
           let cp = fd a bp in
           DF (cp, df_db cp a bp bt, bi)
         | 1 ->
           let cp = fd ap b in
           DR (cp, ref (zero cp), r_d_c a b, ref 0, ai, ref 0)
         | _ -> failwith "error: forward and reverse clash at the same level")
      | DF (ap, at, ai), DF (bp, bt, bi) ->
        (match cmp_tag ai bi with
         | 0 ->
           let cp = fd ap bp in
           DF (cp, df_dab cp ap at bp bt, ai)
         | 1 ->
           let cp = fd ap b in
           DF (cp, df_da cp ap at b, ai)
         | _ ->
           let cp = fd a bp in
           DF (cp, df_db cp a bp bt, bi))
      | DR (ap, _, _, _, ai, _), DR (bp, _, _, _, bi, _) ->
        (match cmp_tag ai bi with
         | 0 ->
           let cp = fd ap bp in
           DR (cp, ref (zero cp), r_d_d a b, ref 0, ai, ref 0)
         | 1 ->
           let cp = fd ap b in
           DR (cp, ref (zero cp), r_d_c a b, ref 0, ai, ref 0)
         | _ ->
           let cp = fd a bp in
           DR (cp, ref (zero cp), r_c_d a b, ref 0, bi, ref 0))
      | a, b -> ff a b
    in
    fun (module S : Piso) ->
      let rec f a b =
        (* direct evaluation on F / Arr payloads; any other pair goes to
           [error_binop] *)
        let ff a b =
          match a, b with
          | F a, F b -> S.ff_aa a b
          | F a, Arr b -> S.ff_ab a b
          | Arr a, F b -> S.ff_ba a b
          | Arr a, Arr b -> S.ff_bb a b
          | _ -> error_binop S.label a b
        in
        let fd = f in
        (* both inputs differentiated: contributions for [a] and [b] *)
        let r_d_d a b =
          let adjoint cp ca_ref t =
            let abar, bbar = S.dr_ab (primal a) (primal b) cp ca_ref in
            (abar, a) :: (bbar, b) :: t
          in
          let register t = a :: b :: t in
          let label = S.label ^ "_d_d", [ a; b ] in
          adjoint, register, label
        in
        (* only [a] differentiated; [b] is given to [S.dr_a] as is *)
        let r_d_c a b =
          let adjoint cp ca_ref t = (S.dr_a (primal a) b cp ca_ref, a) :: t in
          let register t = a :: t in
          let label = S.label ^ "_d_c", [ a; b ] in
          adjoint, register, label
        in
        (* only [b] differentiated; [a] is given to [S.dr_b] as is *)
        let r_c_d a b =
          let adjoint cp ca_ref t = (S.dr_b a (primal b) cp ca_ref, b) :: t in
          let register t = b :: t in
          let label = S.label ^ "_c_d", [ a; b ] in
          adjoint, register, label
        in
        op_piso
          ~ff
          ~fd
          ~df_da:S.df_da
          ~df_db:S.df_db
          ~df_dab:S.df_dab
          ~r_d_d
          ~r_d_c
          ~r_c_d
          a
          b
      in
      f

  (* Array input, single output operation. *)
  module type Aiso = sig
    (* Operation name, used in reverse-mode node labels. *)
    val label : string

    (* Evaluates the operation when none of the inputs is a DF / DR node. *)
    val ff : t array -> t

    (* Forward rule, called as [df idxs cp ap at]: [idxs] are the ascending
       indices of the differentiated inputs, [cp] the output primal, [ap] the
       inputs with those entries replaced by their primals, and [at] an array
       holding each differentiated input's tangent at its index and [zero] of
       the input elsewhere.  The result is the output tangent. *)
    val df : int list -> t -> t array -> t array -> t

    (* Reverse rule, called as [dr idxs ap cp ca]; expected to return one
       adjoint contribution per index in [idxs], in the same order. *)
    val dr : int list -> t array -> t -> t ref -> t list
  end

  (* [build_aiso (module S)] returns the differentiable function over an
     array of inputs defined by [S]. *)
  let build_aiso =
    (* [build_info a] folds over the inputs [a] with accumulator
       [(i, t, m, idxs)]: [i] is the current index, [t] the largest tag seen
       so far, [m] the kind of node carrying that tag (`normal while only
       F / Arr inputs have been seen, otherwise `forward or `reverse) and
       [idxs] the indices, most recent first, of the inputs carrying tag [t].
       A DF and a DR with the same tag raise [Failure]; inputs with a smaller
       tag are skipped.  The initial tag -50000 is never compared: the first
       DF / DR met in `normal mode replaces it. *)
    let build_info =
      Array.fold_left
        (fun (i, t, m, idxs) x ->
          match m, x with
          | _, F _ | _, Arr _ -> succ i, t, m, idxs
          | `normal, DR (_, _, _, _, t', _) -> succ i, t', `reverse, [ i ]
          | `forward, DR (_, _, _, _, t', _) ->
            if t' > t
            then succ i, t', `reverse, [ i ]
            else if t' = t
            then
              (* NOTE(review): in the source as received, this message
                 literal contains a line break after the word same, whereas
                 the `reverse, DF case below uses a single-line message.
                 Kept unchanged here; likely a reformatting artifact --
                 confirm. *)
              failwith "error: forward and reverse clash on the same 
level"
            else succ i, t, `forward, idxs
          | `reverse, DR (_, _, _, _, t', _) ->
            if t' > t
            then succ i, t', `reverse, [ i ]
            else if t' = t
            then succ i, t', `reverse, i :: idxs
            else succ i, t, m, idxs
          | `normal, DF (_, _, t') -> succ i, t', `forward, [ i ]
          | `forward, DF (_, _, t') ->
            if t' > t
            then succ i, t', `forward, [ i ]
            else if t' = t
            then succ i, t', `forward, i :: idxs
            else succ i, t, `forward, idxs
          | `reverse, DF (_, _, t') ->
            if t' > t
            then succ i, t', `forward, [ i ]
            else if t' = t
            then failwith "error: forward and reverse clash on the same level"
            else succ i, t, `reverse, idxs)
        (0, -50000, `normal, [])
    in
    fun (module S : Aiso) ->
      (* In `forward / `reverse mode the inputs carrying [max_t] are replaced
         by their primals, [f] is re-applied to get the output primal [cp],
         and the result is wrapped in a DF / DR node tagged [max_t]. *)
      let rec f a =
        let _, max_t, mode, idxs = build_info a in
        (* restore ascending index order *)
        let idxs = idxs |> List.rev in
        match mode with
        | `normal -> S.ff a
        | `forward ->
          let ap =
            Array.map
              (fun x ->
                match x with
                | DF (p, _, t') ->
                  if max_t = t'
                  then p
                  else if t' > max_t
                  (* defensive: [max_t] is the largest tag build_info saw *)
                  then failwith "no tags should be higher than max_t"
                  else x
                | x -> x)
              a
          in
          let cp = f ap in
          let at =
            let at = a |> Array.map zero in
            List.iter (fun k -> at.(k) <- tangent a.(k)) idxs;
            S.df idxs cp ap at
          in
          DF (cp, at, max_t)
        | `reverse ->
          let ap =
            Array.map
              (fun x ->
                match x with
                | DR (p, _, _, _, t', _) ->
                  if max_t = t'
                  then p
                  else if t' > max_t
                  (* defensive: [max_t] is the largest tag build_info saw *)
                  then failwith "no tags should be higher than max_t"
                  else x
                | x -> x)
              a
          in
          let cp = f ap in
          (* Pairs the i-th element returned by [S.dr] with input
             [a.(List.nth idxs i)] and prepends these pairs to [t]. *)
          let adjoint cp ca t =
            (* use primal of inputs to calculate adjoint *)
            let ar = S.dr idxs ap cp ca |> Array.of_list in
            List.append List.(mapi (fun i k -> ar.(i), a.(k)) idxs) t
          in
          (* prepends the differentiated inputs only *)
          let register t = List.fold_left (fun t i -> a.(i) :: t) t idxs in
          let label = S.label, List.(map (fun i -> a.(i)) idxs) in
          DR (cp, ref (zero cp), (adjoint, register, label), ref 0, max_t, ref 0)
      in
      f
end