12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798# 1 "src/base/algodiff/owl_algodiff_ops.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)(* Below the variable naming convention is based on c = f(a), where f is the operation we
are defining. Therefore we use `cp` to denote the primal of the output, `ca` as the
adjoint of the output, `ap` as the primal of the input, and `at` as the tangent at the
input. *)moduleMake(Core:Owl_algodiff_core_sig.Sig)=structopenCoremoduleBuilder=Owl_algodiff_ops_builder.Make(Core)openBuildermoduleMaths=struct(* squeeze x so that it has shape s *)letrec_squeeze_broadcastxs=letshp_x=shapexinletdim_x=Array.lengthshp_xinletdim=Array.lengthsinifshp_x=sthenxelseifdim_x<dimthenPrintf.sprintf"_squeeze_broadcast: x must have dimension greater than %i, instead has \
dimension %i"dimdim_x|>failwithelseifdim=0thensum'xelse(lets,shp_x=Owl_utils_array.align`Left1sshp_xinletfold=Array.fold_left(fun(k,accu)shp_x->ifs.(k)=shp_xthensucck,accuelseifs.(k)=1thensucck,k::accuelsefailwithPrintf.(sprintf"_squeeze_broadcast: there ought to have been a broadcasting error \
in the forward pass"))inlet_,axis=fold(0,[])shp_xinletidxs=Array.of_listaxisinsum_reduce~axis:idxsx)(* single input single output operations *)and_neg=lazy(build_siso(modulestructletlabel="neg"letff_fa=FA.Scalar.(nega)letff_arra=ArrA.(nega)letdf_cp_apat=pack_flt0.-atletdr_a_cpca=neg!caend:Siso))andnega=Lazy.force_negaand_abs=lazy(build_siso(modulestructletlabel="abs"letff_fa=FA.Scalar.(absa)letff_arra=ArrA.(absa)letdf_cpapat=at*signumapletdra_cpca=!ca*signum(primala)end:Siso))andabsa=Lazy.force_absaand_signum=lazy(build_siso(modulestructletlabel="signum"letff_fa=FA.Scalar.(signuma)letff_arra=ArrA.(signuma)letdf_cpap_at=zeroapletdra_cp_ca=zeroaend:Siso))andsignuma=Lazy.force_signumaand_floor=lazy(build_siso(modulestructletlabel="floor"letff_fa=FA.Scalar.(floora)letff_arra=ArrA.(floora)letdf_cpap_at=zeroapletdra_cp_ca=zeroaend:Siso))andfloora=Lazy.force_flooraand_ceil=lazy(build_siso(modulestructletlabel="ceil"letff_fa=FA.Scalar.(ceila)letff_arra=ArrA.(ceila)letdf_cpap_at=zeroapletdra_cp_ca=zeroaend:Siso))andceila=Lazy.force_ceilaand_round=lazy(build_siso(modulestructletlabel="round"letff_fa=FA.Scalar.(rounda)letff_arra=ArrA.(rounda)letdf_cpap_at=zeroapletdra_cp_ca=zeroaend:Siso))androunda=Lazy.force_roundaand_sqr=lazy(build_siso(modulestructletlabel="sqr"letff_fa=FA.Scalar.(sqra)letff_arra=ArrA.(sqra)letdf_cpapat=pack_flt2.*at*apletdra_cpca=!ca*primala*pack_flt2.end:Siso))andsqra=Lazy.force_sqraand_sqrt=lazy(build_siso(modulestructletlabel="sqrt"letff_fa=FA.Scalar.(sqrta)letff_arra=ArrA.(sqrta)letdfcp_apat=at/(pack_flt2.*cp)letdr_acpca=!ca/(pack_flt2.*cp)end:Siso))andsqrta=Lazy.force_sqrtaand_log=lazy(build_siso(modulestructletlabel="log"letff_fa=FA.Scalar.(loga)letff_arra=ArrA.(loga)letdf_cpapat=at/apletdra_cpca=!ca/primalaend:Siso))andloga=Lazy.force_logaand_log2=lazy(build_siso(modulestructletlabel="log2"letff_fa=FA.Scalar.(log2a)letff_arra=ArrA.(log2a)letdf_cpapat=at/(ap*pack_fltOwl_const.log2e)letdra_cpca=!ca/(primala*pack_fltOwl_const.log2e)end:Siso))andlog2a=L
azy.force_log2aand_log10=lazy(build_siso(modulestructletlabel="log10"letff_fa=FA.Scalar.(log10a)letff_arra=ArrA.(log10a)letdf_cpapat=at/(ap*pack_fltOwl_const.log10e)letdra_cpca=!ca/(primala*pack_fltOwl_const.log10e)end:Siso))andlog10a=Lazy.force_log10aand_exp=lazy(build_siso(modulestructletlabel="exp"letff_fa=FA.Scalar.(expa)letff_arra=ArrA.(expa)letdfcp_apat=at*cpletdr_acpca=!ca*cpend:Siso))andexpa=Lazy.force_expaand_sin=lazy(build_siso(modulestructletlabel="sin"letff_fa=FA.Scalar.(sina)letff_arra=ArrA.(sina)letdf_cpapat=at*cosapletdra_cpca=!ca*cos(primala)end:Siso))andsina=Lazy.force_sinaand_cos=lazy(build_siso(modulestructletlabel="cos"letff_fa=FA.Scalar.(cosa)letff_arra=ArrA.(cosa)letdf_cpapat=neg(at*sinap)letdra_cpca=!ca*neg(sin(primala))end:Siso))andcosa=Lazy.force_cosaand_tan=lazy(build_siso(modulestructletlabel="tan"letff_fa=FA.Scalar.(tana)letff_arra=ArrA.(tana)letdf_cpapat=at/sqr(cosap)letdra_cpca=!ca/sqr(cos(primala))end:Siso))andtana=Lazy.force_tanaand_sinh=lazy(build_siso(modulestructletlabel="sinh"letff_fa=FA.Scalar.(sinha)letff_arra=ArrA.(sinha)letdf_cpapat=at*coshapletdra_cpca=!ca*cosh(primala)end:Siso))andsinha=Lazy.force_sinhaand_cosh=lazy(build_siso(modulestructletlabel="cosh"letff_fa=FA.Scalar.(cosha)letff_arra=ArrA.(cosha)letdf_cpapat=at*sinhapletdra_cpca=!ca*sinh(primala)end:Siso))andcosha=Lazy.force_coshaand_tanh=lazy(build_siso(modulestructletlabel="tanh"letff_fa=FA.Scalar.(tanha)letff_arra=ArrA.(tanha)letdf_cpapat=at/sqr(coshap)letdra_cpca=!ca/sqr(cosh(primala))end:Siso))andtanha=Lazy.force_tanhaand_asin=lazy(build_siso(modulestructletlabel="asin"letff_fa=FA.Scalar.(asina)letff_arra=ArrA.(asina)letdf_cpapat=at/sqrt(pack_flt1.-sqrap)letdra_cpca=!ca/sqrt(pack_flt1.-sqr(primala))end:Siso))andasina=Lazy.force_asinaand_acos=lazy(build_siso(modulestructletlabel="acos"letff_fa=FA.Scalar.(acosa)letff_arra=ArrA.(acosa)letdf_cpapat=negat/sqrt(pack_flt1.-sqrap)letdra_cpca=neg!ca/sqrt(pack_flt1.-sqr(primala))end:Siso))andacosa=Lazy.force_acosaand_atan=l
azy(build_siso(modulestructletlabel="atan"letff_fa=FA.Scalar.(atana)letff_arra=ArrA.(atana)letdf_cpapat=at/(pack_flt1.+sqrap)letdra_cpca=!ca/(pack_flt1.+sqr(primala))end:Siso))andatana=Lazy.force_atanaand_asinh=lazy(build_siso(modulestructletlabel="asinh"letff_fa=FA.Scalar.(asinha)letff_arra=ArrA.(asinha)letdf_cpapat=at/sqrt(sqrap+pack_flt1.)letdra_cpca=!ca/sqrt(sqr(primala)+pack_flt1.)end:Siso))andasinha=Lazy.force_asinhaand_acosh=lazy(build_siso(modulestructletlabel="acosh"letff_fa=FA.Scalar.(acosha)letff_arra=ArrA.(acosha)letdf_cpapat=at/sqrt(sqrap-pack_flt1.)letdra_cpca=!ca/sqrt(sqr(primala)-pack_flt1.)end:Siso))andacosha=Lazy.force_acoshaand_atanh=lazy(build_siso(modulestructletlabel="atanh"letff_fa=FA.Scalar.(atanha)letff_arra=ArrA.(atanha)letdf_cpapat=at/(pack_flt1.-sqrap)letdra_cpca=!ca/(pack_flt1.-sqr(primala))end:Siso))andatanha=Lazy.force_atanhaand_get_slice=lazy(funi->build_siso(modulestructletlabel="get_slice"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(get_sliceia)letdf_cp_apat=get_sliceiatletdra_cpca=set_slicei(zeroa)!caend:Siso))andget_slicei=Lazy.force_get_sliceiand_sum'=lazy(build_siso(modulestructletlabel="sum'"letff_fa=Faletff_arra=FA.(sum'a)letdf_cp_apat=sum'atletdr_a_cpca=!caend:Siso))andsum'a=Lazy.force_sum'aand_sum=lazy(fun?axis->build_siso(modulestructletlabel="sum 
axis"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(sum?axisa)letdf_cp_apat=sum?axisatletdra_cpca=matchaxiswith|Someaxis->lets=shapeainletndim=Array.lengthsinletreps=Array.(makendim1)inletaxis=Owl_utils.adjust_indexaxisndiminreps.(axis)<-s.(axis);repeat!careps|None->!caend:Siso))andsum?axis=Lazy.force_sum?axisand_sum_reduce=lazy(fun~axis->build_siso(modulestructletlabel="sum_reduce"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(sum_reduce~axisa)letdf_cp_apat=sum_reduce~axisatletdra_cpca=lets=shapeainletreps=Array.(make(lengths)1)inArray.iter(funj->reps.(j)<-s.(j))axis;repeat!carepsend:Siso))andsum_reduce?(axis=[|0|])=Lazy.force_sum_reduce~axisandmeana=sum'a/F(numela|>float_of_int|>A.float_to_elt)and_transpose=lazy(build_siso(modulestructletlabel="transpose"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(transposea)letdf_cp_apat=transposeatletdr_a_cpca=transpose!caend:Siso))andtransposea=Lazy.force_transposeaand_l1norm'=lazy(build_siso(modulestructletlabel="l1norm'"letff_fa=error_unioplabel(pack_elta)letff_arra=FA.(l1norm'a)letdf_cpapat=sum'(at*signumap)letdra_cpca=!ca*signum(primala)end:Siso))andl1norm'a=Lazy.force_l1norm'aand_l2norm'=lazy(build_siso(modulestructletlabel="l2norm'"letff_fa=error_unioplabel(pack_elta)letff_arra=FA.(l2norm'a)letdfcpapat=sum'(ap*at)/cpletdracpca=!ca/cp*primalaend:Siso))andl2norm'a=Lazy.force_l2norm'aand_l2norm_sqr'=lazy(build_siso(modulestructletlabel="l2norm_sqr'"letff_fa=error_unioplabel(pack_elta)letff_arra=FA.(l2norm_sqr'a)letdf_cpapat=pack_flt2.*sum'(ap*at)letdra_cpca=!ca*pack_flt2.*primalaend:Siso))andl2norm_sqr'a=Lazy.force_l2norm_sqr'aand_sigmoid=lazy(build_siso(modulestructletlabel="sigmoid"letff_fa=FA.Scalar.(sigmoida)letff_arra=ArrA.(sigmoida)letdfcp_apat=at*cp*(pack_flt1.-cp)letdr_acpca=!ca*cp*(pack_flt1.-cp)end:Siso))andsigmoida=Lazy.force_sigmoidaand_relu=lazy(build_siso(modulestructletlabel="relu"letff_fa=FA.Scalar.(relua)letff_arra=ArrA.(relua)letdf_cpapat=at*(pack_flt1.+signumap)/pack_flt2.letdra_cpca=!
ca*((signum(primala)+pack_flt1.)/pack_flt2.)end:Siso))andrelua=Lazy.force_reluaand_dawsn=lazy(build_siso(modulestructletlabel="dawsn"letff_fa=FA.Scalar.(dawsna)letff_arra=ArrA.(dawsna)letdfcpapat=at*(pack_flt1.-(pack_flt2.*ap*cp))letdracpca=!ca*(pack_flt1.-(pack_flt2.*primala*cp))end:Siso))anddawsna=Lazy.force_dawsnaand_diag=lazy(fun~k->build_siso(modulestructletlabel="diag"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(diag~ka|>copy)letdf_cp_apat=diag~katletdra_cpca=letm=col_numainletl=Stdlib.(m-k)inletrecaccuia_=ifi<lthenaccu(succi)(set_itema_iStdlib.(k+i)(get_item!ca0i))elsea_inaccu0(zeroa)end:Siso))anddiag?(k=0)=Lazy.force_diag~kand_diagm=lazy(fun~k->build_siso(modulestructletlabel="diagm"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(Mat.diagm~ka|>copy)letdf_cp_apat=diagm~katletdr_a_cpca=diag~k!caend:Siso))anddiagm?(k=0)=Lazy.force_diagm~kand_trace=lazy(build_siso(modulestructletlabel="trace"letff_fa=error_unioplabel(pack_elta)letff_arra=FA.(tracea)letdf_cp_apat=traceatletdra_cpca=letm=col_numain!ca*diagm(pack_arrA.(ones[|1;m|]))end:Siso))andtracea=Lazy.force_traceaand_triu=lazy(fun~k->build_siso(modulestructletlabel="triu"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(Mat.triu~ka)letdf_cp_apat=triu~katletdr_a_cpca=triu~k!caend:Siso))andtriu?(k=0)=Lazy.force_triu~kand_tril=lazy(fun~k->build_siso(modulestructletlabel="tril"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(Mat.tril~ka)letdf_cp_apat=tril~katletdr_a_cpca=tril~k!caend:Siso))andtril?(k=0)=Lazy.force_tril~kand_inv=lazy(build_siso(modulestructletlabel="inv"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(Linalg.inva)letdfcp_apat=negcp*@at*@cpletdr_acpca=letdpt=transposecpinnegdpt*@!ca*@dptend:Siso))andinva=Lazy.force_invaandsoftplusx=log(pack_flt1.+expx)andsoftsignx=x/(pack_flt1.+absx)andsoftmax?(axis=-1)x=letc=ArrA.(max~axis(unpack_arrx))inlety=exp(x-c)inleta=sum~axisyiny/aand_reshape=lazy(funas->build_siso(modulestructletlabel="reshape"letff_fa=error_unioplabel(pack_elta)letf
f_arra=ArrA.(reshapeas)letdf_cp_apat=reshapeatsletdra_cpca=reshape!ca(shape(primala))end:Siso)a)andreshapea=Lazy.force_reshapeaandflattena=reshapea[|1;numela|]andget_itemaij=matchawith|Arrap->F(A.getap[|i;j|])|DF(ap,at,ai)->DF(get_itemapij,get_itematij,ai)|DR(ap,_,_,_,ai,_)->letreverse_apcat=(set_item(zeroa)ij(sum'!ca),a)::tinletinputt=a::tinletlabel="Get_Item",[a]inDR(get_itemapij,ref(pack_flt0.),(reverse,input,label),ref0,ai,ref0)|_->error_uniop"get_item"aand_get_row=lazy(funai->build_siso(modulestructletlabel="get_row"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(rowai|>copy)letdf_cp_apat=get_rowatiletdra_cpca=adjrefa:=add_row(adjvala)!cai;zeroaend:Siso)a)andget_rowa=Lazy.force_get_rowa(* pair inputs single output operations *)and(+)ab=addaband_add=lazy(build_piso(modulestructletlabel="add"letff_aaab=FA.Scalar.(addab)letff_abab=ArrA.(scalar_addab)letff_baab=ArrA.(add_scalarab)letff_bbab=ArrA.(addab)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=_squeeze_broadcast!ca(shapea),_squeeze_broadcast!ca(shapeb)letdr_aa_b_cpca=_squeeze_broadcast!ca(shapea)letdr_b_ab_cpca=_squeeze_broadcast!ca(shapeb)end:Piso))andadda=Lazy.force_addaand(-)ab=subaband_sub=lazy(build_piso(modulestructletlabel="sub"letff_aaab=FA.Scalar.(subab)letff_abab=ArrA.(scalar_subab)letff_baab=ArrA.(sub_scalarab)letff_bbab=ArrA.(subab)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at-btletdr_abab_cpca=_squeeze_broadcast!ca(shapea),neg(_squeeze_broadcast!ca(shapeb))letdr_aa_b_cpca=_squeeze_broadcast!ca(shapea)letdr_b_ab_cpca=neg(_squeeze_broadcast!ca(shapeb))end:Piso))andsuba=Lazy.force_subaand(*)ab=mulaband_mul=lazy(build_piso(modulestructletlabel="mul"letff_aaab=FA.Scalar.(mulab)letff_abab=ArrA.(scalar_mulab)letff_baab=ArrA.(mul_scalarab)letff_bbab=ArrA.(mulab)letdf_da_cp_apatbp=at*bpletdf_db_cpap_bpbt=ap*btletdf_dab_cpapatbpbt=(ap*bt)+(at*bp)letdr_abab_cpca=(_squeeze_broadcast(!ca*primalb)(shapea),_squeeze_broadcast(!ca*primala)(shape
b))letdr_aab_cpca=_squeeze_broadcast(!ca*b)(shapea)letdr_bab_cpca=_squeeze_broadcast(!ca*a)(shapeb)end:Piso))andmula=Lazy.force_mulaand(/)ab=divaband_div=lazy(build_piso(modulestructletlabel="div"letff_aaab=FA.Scalar.(divab)letff_abab=ArrA.(scalar_divab)letff_baab=ArrA.(div_scalarab)letff_bbab=ArrA.(divab)letdf_da_cp_apatbp=at/bpletdf_dbcp_apbpbt=negbt*cp/bpletdf_dabcp_apatbpbt=(at-(bt*cp))/bpletdr_abab_cpca=(_squeeze_broadcast(!ca/primalb)(shapea),_squeeze_broadcast(!ca*(neg(primala)/(primalb*primalb)))(shapeb))letdr_aab_cpca=_squeeze_broadcast(!ca/b)(shapea)letdr_bab_cpca=_squeeze_broadcast(!ca*(nega/(primalb*primalb)))(shapeb)end:Piso))anddiva=Lazy.force_divaand(**)ab=powaband_pow=lazy(build_piso(modulestructletlabel="pow"letff_aaab=FA.Scalar.(powab)letff_abab=ArrA.(scalar_powab)letff_baab=ArrA.(pow_scalarab)letff_bbab=ArrA.(powab)letdf_da_cpapatbp=at*(ap**(bp-pack_flt1.))*bpletdf_dbcpap_bpbt=bt*cp*logapletdf_dabcpapatbpbt=((ap**(bp-pack_flt1.))*(at*bp))+(cp*bt*logap)letdr_ababcpca=(_squeeze_broadcast(!ca*(primala**(primalb-pack_flt1.))*primalb)(shapea),_squeeze_broadcast(!ca*cp*log(primala))(shapeb))letdr_aab_cpca=_squeeze_broadcast(!ca*(primala**(primalb-pack_flt1.))*b)(shapea)letdr_babcpca=_squeeze_broadcast(!ca*cp*log(primala))(shapeb)end:Piso))andpowa=Lazy.force_powaand_atan2=lazy(build_piso(modulestructletlabel="atan2"letff_aaab=FA.Scalar.(atan2ab)letff_abab=ArrA.(scalar_atan2ab)letff_baab=ArrA.(atan2_scalarab)letff_bbab=ArrA.(atan2ab)letdf_da_cpapatbp=at*bp/(sqrap+sqrbp)letdf_db_cpapbpbt=negbt*ap/(sqrap+sqrbp)letdf_dab_cpapatbpbt=((at*bp)-(bt*ap))/(sqrap+sqrbp)letdr_abab_cpca=letd=sqr(primala)+sqr(primalb)in!ca*primalb/d,!ca*neg(primala)/dletdr_aab_cpca=letd=sqr(primala)+sqr(primalb)in!ca*primalb/dletdr_bab_cpca=letd=sqr(primala)+sqr(primalb)in!ca*neg(primala)/dend:Piso))andatan2a=Lazy.force_atan2aandmin2ab=(a+b-abs(a-b))/pack_flt2.andmax2ab=(a+b+abs(b-a))/pack_flt2.and_set_item=lazy(funaijb->build_piso(modulestructletlabel="set_item"letff_aaa_b=error_unio
plabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_baab=letaa=A.copyainA.setaa[|i;j|]b;Arraaletff_bba_b=error_unioplabel(pack_arra)letdf_da_cp_apat_bp=set_itematij(pack_flt0.)letdf_db_cp_ap_bpbt=add_item(zeroa)ijbtletdf_dab_cp_apat_bpbt=set_itematijbtletdr_ab_a_b_cpca=set_item!caij(pack_flt0.),get_item!caijletdr_a_a_b_cpca=set_item!caij(pack_flt0.)letdr_b_a_b_cpca=get_item!caijend:Piso)ab)andset_itema=Lazy.force_set_itemaand_add_item=lazy(funaijb->build_piso(modulestructletlabel="add_item"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_baab=letaa=A.copyainA.setaa[|i;j|]A.Scalar.(add(A.getaa[|i;j|])b);Arraaletff_bba_b=error_unioplabel(pack_arra)letdf_da_cp_apat_bp=atletdf_db_cpap_bpbt=add_item(zeroap)ijbtletdf_dab_cp_apat_bpbt=add_itematijbtletdr_ab_a_b_cpca=!ca,get_item!caijletdr_a_a_b_cpca=!caletdr_b_a_b_cpca=get_item!caijend:Piso)ab)andadd_itema=Lazy.force_add_itemaand_set_slice=lazy(funi->build_piso(modulestructletlabel="set_slice"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=leta=A.copyainA.(set_sliceiab);Arraletdf_da_cp_apatbp=set_sliceiat(zerobp)letdf_db_cpap_bpbt=set_slicei(zeroap)btletdf_dab_cp_apat_bpbt=set_sliceiatbtletdr_ab_ab_cpca=set_slicei!ca(zerob),get_slicei!caletdr_a_ab_cpca=set_slicei!ca(zerob)letdr_b_a_b_cpca=get_slicei!caend:Piso))andset_slicei=Lazy.force_set_sliceiand(*@)ab=dotaband_dot=lazy(build_piso(modulestructletlabel="dot"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(dotab)letdf_da_cp_apatbp=at*@bpletdf_db_cpap_bpbt=ap*@btletdf_dab_cpapatbpbt=(ap*@bt)+(at*@bp)letdr_abab_cpca=dot!ca(transpose(primalb)),dot(transpose(primala))!caletdr_a_ab_cpca=dot!ca(transpose(primalb))letdr_ba_b_cpca=dot(transpose(primala))!caend:Piso))anddota=Lazy.force_dotaandcross_entropyxy=x*logy|>sum'|>negand_add_row=lazy(funabi->build_piso(modulestructle
tlabel="add_row"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=A.(copy_row_to(add(rowai)b)ai;Arra)letdf_da_cp_apat_bp=atletdf_db_cpap_bpbt=add_row(zeroap)btiletdf_dab_cp_apat_bpbt=add_rowatbtiletdr_ab_a_b_cpca=!ca,get_row!cailetdr_a_a_b_cpca=!caletdr_b_a_b_cpca=get_row!caiend:Piso)ab)andadd_rowa=Lazy.force_add_rowaand_concataxis=lazy(build_piso(modulestructletlabel="concat"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(concatenate~axis[|a;b|])letdf_da_cp_apatbp=concat~axisat(zerobp)letdf_db_cpap_bpbt=concat~axis(zeroap)btletdf_dab_cp_apat_bpbt=concat~axisatbtletdr_aba_b_cpca=letsa=shapeainletl=sa.(axis)inletdim=Array.lengthsain(get_slice(List.initdim(funi->ifi=axisthen[0;predl]else[0;-1]))!ca,get_slice(List.initdim(funi->ifi=axisthen[l;-1]else[0;-1]))!ca)letdr_aa_b_cpca=letsa=shapeainletl=sa.(axis)inletdim=Array.lengthsainget_slice(List.initdim(funi->ifi=axisthen[0;predl]else[0;-1]))!caletdr_ba_b_cpca=letsa=shapeainletl=sa.(axis)inletdim=Array.lengthsainget_slice(List.initdim(funi->ifi=axisthen[l;-1]else[0;-1]))!caend:Piso))andconcat~axis=Lazy.force(_concataxis)andto_rowsa=Array.init(row_numa)(funi->get_rowai)andof_rowsa=(* TODO: this can be further optimised by incorporating t array type as t *)matcha.(0)with|Arr_->Array.mapunpack_arra|>A.of_rows|>pack_arr|DF(_,_,ai)->letap=a|>Array.map(funx->x|>primal|>unpack_arr)|>A.of_rows|>pack_arrinletat=a|>Array.map(funx->x|>tangent|>unpack_arr)|>A.of_rows|>pack_arrinDF(ap,at,ai)|DR(_,_,_,_,ai,_)->letap=a|>Array.map(funx->x|>primal)inletcp=ap|>Array.map(funx->x|>unpack_arr)|>A.of_rows|>pack_arrinletreverse_apcat=t|>List.append(a|>Array.to_list|>List.mapi(funiv->get_row!cai,v))inletinputt=List.append(Array.to_lista)tinletlabel="Of_Rows_D",Array.to_listainDR(cp,ref(zerocp),(reverse,input,label),ref0,ai,ref0)|_->error_uniop"of_rows 
a.(0)"a.(0)and_of_arrays=lazy(funa->(* mode: `normal , `reverse, `forward *)letmode=ref`normalinletidxs=ref[]inletai_ref=ref0ina(* TODO: the following checks can probably be refactored into ops_builder *)|>Array.iteri(funixs->Array.iteri(funjx->matchx,!modewith|F_,_->()|Arr_,_->error_uniop"of_arrays: array elements should be F not Arr"x|DR(_,_,_,_,ai,_),`normal->mode:=`reverse;ai_ref:=ai;idxs:=[i,j]|DR(_,_,_,_,ai,_),`reverse->ifai>!ai_refthen(idxs:=[i,j];ai_ref:=ai)elseifai=!ai_refthenidxs:=(i,j)::!idxselse()|DR(_,_,_,_,ai,_),`forward->ifai>!ai_refthen(mode:=`reverse;idxs:=[i,j];ai_ref:=ai)elseifai=!ai_refthenfailwith"error: forward and reverse clash on the same level"else()|DF(_,_,ai),`normal->mode:=`forward;ai_ref:=ai;idxs:=[i,j]|DF(_,_,ai),`reverse->ifai>!ai_refthen(mode:=`forward;idxs:=[i,j];ai_ref:=ai)elseifai=!ai_refthenfailwith"error: forward and reverse clash on the same level"else()|DF(_,_,ai),`forward->ifai>!ai_refthen(idxs:=[i,j];ai_ref:=ai)elseifai=!ai_refthenidxs:=(i,j)::!idxselse())xs);match!modewith|`normal->Array.map(Array.mapunpack_elt)a|>A.of_arrays|>pack_arr|`reverse->letcp=Array.map(Array.map(funx->matchxwith|DR(p,_,_,_,ai,_)->ifai=!ai_refthenpelsex|x->x))a|>of_arraysinletidxs=List.rev!idxsinletreverse_cpcat=letca_arrays=to_arrays!caint|>List.append(idxs|>List.map(fun(i,j)->ca_arrays.(i).(j),a.(i).(j)))inletinputt=List.(append(map(fun(i,j)->a.(i).(j))idxs)t)inletlabel="Of_Arrays_D",List.map(fun(i,j)->a.(i).(j))idxsinDR(cp,ref(zerocp),(reverse,input,label),ref0,!ai_ref,ref0)|`forward->letcp=Array.map(Array.map(funx->matchxwith|DF(p,_,ai)->ifai=!ai_refthenpelsex|x->x))a|>of_arraysinletat=letat=Array.map(Array.mapzero)ainList.iter(fun(i,j)->at.(i).(j)<-tangenta.(i).(j))!idxs;at|>of_arraysinDF(cp,at,!ai_ref))andof_arraysa=Lazy.force_of_arraysaandto_arraysa=Array.init(row_numa)(funi->Array.init(col_numa)(funj->get_itemaij))and_split=lazy(fun~axisparts->build_siao(modulestructletlabel="split"letff_fa=error_uniop"label"(pack_elta)letff_arra=A.(split~axi
spartsa)|>Array.map(funx->Arrx)letdf_cp_ap_at=raise(Owl_exception.NOT_IMPLEMENTED"owl_algodiff_ops.split")letdr_a_cp_cp_ref_arrca_ref_arr=concatenate~axis(Array.map(funca->!ca)ca_ref_arr)end:Siao))andsplit~axisparts=Lazy.force_split~axispartsand_concatenate=lazy(fun~axis->build_aiso(modulestructletlabel="Concatenate_D"letffa=Array.mapunpack_arra|>A.concatenate~axis|>pack_arrletdf___tangents=concatenate~axistangentsletdridxsap_ca=letca=split~axis(Array.map(funx->(shapex).(axis))ap)!cainList.map(funk->ca.(k))idxsend:Aiso))andconcatenate~axis=Lazy.force_concatenate~axisendmoduleLinalg=structopenMaths(* single input single output *)letrecinv=Maths.invand_logdet=lazy(build_siso(modulestructletlabel="logdet"letff_fa=error_unioplabel(pack_elta)letff_arra=FA.(Linalg.logdeta)letdf_cpapat=trace(transpose(invap)*@at)letdra_cpca=!ca*transpose(inv(primala))end:Siso))andlogdeta=Lazy.force_logdetaandcopyltux=trilx+transpose(tril~k:(-1)x)andcopyutlx=triux+transpose(triu~k:1x)and_chol=let_chol_forwardcpatupper=letinv_cp=invcpinlettr_inv_cp=transposeinv_cpinifupperthen(letx=tr_inv_cp*@transposeat*@inv_cpinletm=pack_flt0.5*tril(triux)intransposecp*@(m+triu~k:1x))else(letx=inv_cp*@at*@tr_inv_cpinletm=pack_flt0.5*tril(triux)incp*@(m+tril~k:(-1)x))inlet_chol_backwardocaupper=ifupperthen(letx=linsolve~typ:`uo(copyutl(ca*@transposeo))inletx=linsolve~typ:`uo(transposex)inpack_flt0.5*transposex)else(letx=linsolve~trans:true~typ:`lo(copyltu(transposeo*@ca))inletx=linsolve~trans:true~typ:`lo(transposex)inpack_flt0.5*transposex)inlazy(fun~upper->build_siso(modulestructletlabel="chol"letff_fa=error_uniop"chol"(pack_elta)letff_arra=ArrA.(Linalg.chol~uppera)letdfcp_apat=_chol_forwardcpatupperletdr_acpca=_chol_backwardcp!caupperend:Siso))andchol?(upper=true)=Lazy.force_chol~upper(* single input pair outputs 
*)and_qr=let_qr_backward(cp1,cp2)(ca1,ca2)=letq=!cp1andr=!cp2andqbar=!ca1andrbar=!ca2inletm=(rbar*@transposer)-(transposeq*@qbar)inlinsolver(transpose(qbar+(q*@copyutlm)))|>transposeinlazy(build_sipo(modulestructletlabel="qr"letff_fa=error_uniop"qr"(pack_elta)letff_arra=letq,r=A.(Linalg.qra)inArrq,Arrrletdf_cp_ap_at=raise(Owl_exception.NOT_IMPLEMENTED"owl_algodiff_ops.qr")letdr_a_cpcp_refca_ref=_qr_backwardcp_refca_refend:Sipo))andqra=Lazy.force_qraand_lq=let_lq_backward(o1,o2)(ca1,ca2)=letl=!o1andq=!o2andlbar=!ca1andqbar=!ca2inletm=(transposel*@lbar)-(qbar*@transposeq)inlinsolve~trans:true~typ:`ll(qbar+(copyltum*@q))inlazy(build_sipo(modulestructletlabel="lq"letff_fa=error_uniop"lq"(pack_elta)letff_arra=letl,q=A.(Linalg.lqa)inArrl,Arrqletdf_cp_ap_at=raise(Owl_exception.NOT_IMPLEMENTED"owl_algodiff_ops.lq")letdr_a_cpoca=_lq_backwardocaend:Sipo))andlqa=Lazy.force_lqa(* single input triple outputs *)and_svd=let_svd_backward(o1,o2,o3)(ca1,ca2,ca3)thin=letu,s,vt=!o1,!o2,!o3andubar,sbar,vbart=!ca1,!ca2,!ca3inletut=transposeuandv=transposevtinletubart=transposeubarandvbar=transposevbartinleteyen=A.(ones[|1;n|])|>pack_arr|>diagminlete_m=eye(row_numu)inlete_n=eye(row_numv)inletk=row_numvtinletf=lets2=sqrsinpack_arrA.(init_nd[|k;k|](funidx->leti=idx.(0)andj=idx.(1)inifi=jthenfloat_to_elt0.else(lets2_i=get_items20i|>unpack_fltinlets2_j=get_items20j|>unpack_fltin1./.(s2_j-.s2_i)|>float_to_elt)))inletinv_s=pack_flt1./sinifthinthen(u*sbar*@vt)+(((u*@(f*((ut*@ubar)-(ubart*@u)))*s)+((e_m-(u*@ut))*@ubar*inv_s))*@vt)+(u*@((transposes*(f*((vt*@vbar)-(vbart*@v)))*@vt)+(transposeinv_s*vbart*@(e_n-(v*@vt)))))elseraise(Owl_exception.NOT_IMPLEMENTED"owl_algodiff_ops.svd")inlazy(fun~thin->build_sito(modulestructletlabel="svd"letff_fa=error_uniop"svd"(pack_elta)letff_arra=letu,s,vt=A.(Linalg.svd~thina)inArru,Arrs,Arrvtletdf_cp_ap_at=raise(Owl_exception.NOT_IMPLEMENTED"owl_algodiff_ops.svd")letdr_a_cpoca=_svd_backwardocathinend:Sito))andsvd?(thin=true)=Lazy.force_svd~thinand_sylvester=lazy(l
etunpacka=a.(0),a.(1),a.(2)inletsylv_forwardpab_catbtct=letdp_da()=sylvesterab(negat*@p)inletdp_db()=sylvesterab(negp*@bt)inletdp_dc()=sylvesterabctin[|dp_da;dp_db;dp_dc|]inletsylv_backwardab_cppbar=letst=sylvester(transposea)(transposeb)(negpbar)in(* the following calculations are not calculated unless needed *)letabar()=st*@transposepinletbbar()=transposep*@stinletcbar()=negstin[|abar;bbar;cbar|]inbuild_aiso(modulestructletlabel="sylvester"letffa=matchunpackawith|Arra,Arrb,Arrc->A.Linalg.sylvesterabc|>pack_arr|_->error_uniop"sylvester"a.(0)letdfidxspinptangents=leta,b,c=unpackinpinletat,bt,ct=unpacktangentsinletdp=sylv_forwardpabcatbtctinList.map(funk->dp.(k)())idxs|>List.fold_left(+)(pack_flt0.)letdridxsinpppbar_ref=letpbar=!pbar_refinletbars=leta,b,c=unpackinpinsylv_backwardabcppbarinList.map(funk->letbar=bars.(k)inbar())idxsend:Aiso))andsylvesterabc=Lazy.force_sylvester[|a;b;c|](* pair outputs single input *)and_lyapunov=let_lyapunov_backward_aacacp=lets=lyapunov(transposea)(negca)inpack_flt2.*s*@cpinlet_lyapunov_backward_qaca=neg(lyapunov(transposea)(negca))inlet_lyapunov_backward_aqacacp=lets=lyapunov(transposea)(negca)inpack_flt2.*s*@cp,negsinlazy(build_piso(modulestructletlabel="lyapunov"letff_aaa_q=error_unioplabel(pack_elta)letff_aba_q=error_unioplabel(pack_elta)letff_ba_aq=error_unioplabel(pack_eltq)letff_bbaq=ArrA.(Linalg.lyapunovaq)letdf_dacpapat_qp=lyapunovap(neg((at*@cp)+(cp*@transposeat)))letdf_db_cpap_qpqt=lyapunovap(negqt)letdf_dabcpapat_qpqt=lyapunovap(neg((at*@cp)+(cp*@transposeat)))+lyapunovap(negqt)letdr_aba_bcpca=letabar,qbar=_lyapunov_backward_aq(primala)!cacpinabar,qbarletdr_aa_qcpca=_lyapunov_backward_a(primala)!cacpletdr_ba_q_cpca=_lyapunov_backward_q(primala)!caend:Piso))andlyapunova=Lazy.force_lyapunovaand_discrete_lyapunov=let_discrete_lyapunov_backward_aacacp=lets=discrete_lyapunov(transposea)cainpack_flt2.*s*@a*@cpinlet_discrete_lyapunov_backward_qaca=discrete_lyapunov(transposea)cainlet_discrete_lyapunov_backward_aqacacp=lets=discre
te_lyapunov(transposea)cainpack_flt2.*s*@a*@cp,sinlazy(fun~solver->build_piso(modulestructletlabel="discrete_lyapunov"letff_aaa_q=error_unioplabel(pack_elta)letff_aba_q=error_unioplabel(pack_elta)letff_ba_aq=error_unioplabel(pack_eltq)letff_bbaq=ArrA.(Linalg.discrete_lyapunov~solveraq)letdf_dacpapat_qp=letg=ap*@cp*@transposeatindiscrete_lyapunovap(g+transposeg)letdf_db_cpap_qpqt=discrete_lyapunovapqtletdf_dabcpapat_qpqt=letg=ap*@cp*@transposeatindiscrete_lyapunovap(g+transposeg)+discrete_lyapunovapqtletdr_aba_bcpca=letabar,qbar=_discrete_lyapunov_backward_aq(primala)!cacpinabar,qbarletdr_aa_qcpca=_discrete_lyapunov_backward_a(primala)!cacpletdr_ba_q_cpca=_discrete_lyapunov_backward_q(primala)!caend:Piso))anddiscrete_lyapunov?(solver=`default)=Lazy.force_discrete_lyapunov~solverand(/@)ab=linsolve~trans:false~typ:`naband_linsolve=let_linsolve_backward_btranstypacbar=linsolve~trans:(nottrans)~typ(primala)cbarinlet_linsolve_backward_atranstypcpbbar=letabar=negbbar*@transposecpinletabar=iftransthentransposeabarelseabarinmatchtypwith|`n->abar|`u->triuabar|`l->trilabarinlazy(fun~trans~typ->build_piso(modulestructletlabel="linsolve"letff_aaa_q=error_unioplabel(pack_elta)letff_aba_q=error_unioplabel(pack_elta)letff_ba_aq=error_unioplabel(pack_eltq)letff_bbaq=ArrA.(Linalg.linsolve~trans~typaq)letdf_dacpapat_qp=linsolve~transap(iftransthenneg(transposeat)*@cpelsenegat*@cp)letdf_db_cpap_bpbt=linsolve~transapbtletdf_dabcpapat_bpbt=linsolve~transap(iftransthenbt-(transposeat*@cp)elsebt-(at*@cp))letdr_aba_bcpca=letbbar=_linsolve_backward_btranstypa!cainletabar=_linsolve_backward_atranstypcpbbarinabar,bbarletdr_aa_bcpca=letbbar=_linsolve_backward_btranstypa!cainletabar=_linsolve_backward_atranstypcpbbarinabarletdr_ba_b_cpca=_linsolve_backward_btranstypa!caend:Piso))andlinsolve?(trans=false)?(typ=`n)=Lazy.force_linsolve~trans~typand_care=lazy(letunpacka=a.(0),a.(1),a.(2),a.(3)inletcare_forward~diag_rpabratbtqtrt=lettr_b=transposebinletr=ifdiag_rthendiagrelserinletinv_r=ifdiag_rthenp
ack_flt1./relseinvrinletatilde=ifdiag_rthena-(b*inv_r*@tr_b*@p)elsea-(b*@inv_r*@tr_b*@p)inlettr_atilde=transposeatildeinletdp_da()=letpat=p*@atinlyapunovtr_atilde(neg(transposepat)-pat)inletdp_dq()=lyapunovtr_atilde(negqt)inletdp_dr()=letpbinv_r=ifdiag_rthenp*@(b*inv_r)elsep*@b*@inv_rinlyapunovtr_atilde(neg(pbinv_r*@rt*@transposepbinv_r))inletdp_db()=letx=ifdiag_rthenp*@(bt*inv_r)*@tr_b*@pelsep*@bt*@inv_r*@tr_b*@pinlyapunovtr_atilde(x+transposex)in[|dp_da;dp_db;dp_dq;dp_dr|]inletcare_backward~diag_rab_qrppbar=lettr_b=transposebinletinv_r=ifdiag_rthenpack_flt1./diagrelseinvrinletatilde=ifdiag_rthena-(b*inv_r*@tr_b*@p)elsea-(b*@inv_r*@tr_b*@p)inlets=lyapunovatilde(negpbar)in(* the following calculations are not calculated unless needed *)letqbar()=sinletrbar()=letpbinv_r=ifdiag_rthenp*@(b*inv_r)elsep*@b*@inv_rintransposepbinv_r*@s*@pbinv_rinletabar()=pack_flt2.*p*@sinletbbar()=ifdiag_rthenneg(pack_flt2.)*p*@s*@p*@(b*inv_r)elseneg(pack_flt2.)*p*@s*@p*@b*@inv_rin[|abar;bbar;qbar;rbar|]infun~diag_r->build_aiso(modulestructletlabel="care"letffa=matchunpackawith|Arra,Arrb,Arrq,Arrr->A.Linalg.care~diag_rabqr|>pack_arr|_->error_uniop"care"a.(0)letdfidxspinptangents=leta,b,_,r=unpackinpinletat,bt,qt,rt=unpacktangentsinletdp=care_forward~diag_rpabratbtqtrtinList.map(funk->dp.(k)())idxs|>List.fold_left(+)(pack_flt0.)letdridxsinpppbar_ref=letpbar=!pbar_refinletbars=leta,b,q,r=unpackinpincare_backward~diag_rabqrppbarinList.map(funk->bars.(k)())idxsend:Aiso))andcare?(diag_r=false)abqr=Lazy.force_care~diag_r[|a;b;q;r|]end(* neural network module: for specialised neural network operations *)moduleNN=structopenMaths(* NOTE: these fucntions are for neural network. There are many restrictions at the
moment. E.g. they do not support higher-order derivatives, and some do not support
forward mode, so use them when you know what you are doing. *)letdropout?(rate=0.5)a=letp=A.float_to_elt(1.-.rate)inletb=matchprimal'awith|Arra->Arr(A.bernoulli~p(A.shapea))|_->error_uniop"dropout"aina*b(* a:input; b:kernel; s:stride *)let_conv1d=(* a:input; b:kernel; s:stride; o:output' *)letconv1d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv1d_backward_inputabso|>pack_arrin(* a:input; b:kernel; s:stride; o:output' *)letconv1d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv1d_backward_kernelabso|>pack_arrinlazy(fun~paddingabs->build_piso(modulestructletlabel="conv1d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(conv1d?paddingabs)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=conv1d_backward_inputabs!ca,conv1d_backward_kernelabs!caletdr_aab_cpca=conv1d_backward_inputabs!caletdr_bab_cpca=conv1d_backward_kernelabs!caend:Piso)ab)letconv1d?padding=Lazy.force_conv1d~padding(* a:input; b:kernel; s:stride *)let_conv2d=(* a:input; b:kernel; s:stride; o:output' *)letconv2d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv2d_backward_inputabso|>pack_arrin(* a:input; b:kernel; s:stride; o:output' *)letconv2d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv2d_backward_kernelabso|>pack_arrinlazy(fun~paddingabs->build_piso(modulestructletlabel="conv2d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(conv2d?paddingabs)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=conv2d_backward_inputabs!ca,conv2d_backward_kernelabs!caletdr_aab_cpca=conv2d_backward_inputabs!caletdr_bab_cpca=conv2d_backward_kernelabs!caend:Piso)ab)letconv2d?padding=Lazy.force_conv2d~padding(* a:input; b:kernel; s:stride *)let_conv3d=(* 
a:input; b:kernel; s:stride; o:output' *)letconv3d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv3d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andconv3d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.conv3d_backward_kernelabso|>pack_arrinlazy(fun~paddingabs->build_piso(modulestructletlabel="conv3d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(conv3d?paddingabs)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=conv3d_backward_inputabs!ca,conv3d_backward_kernelabs!caletdr_aab_cpca=conv3d_backward_inputabs!caletdr_bab_cpca=conv3d_backward_kernelabs!caend:Piso)ab)letconv3d?padding=Lazy.force_conv3d~padding(* a:input; b:kernel; s:stride; r:rate *)let_dilated_conv1d=(* a:input; b:kernel; o:output'; s:stride; r:rate *)letdilated_conv1d_backward_inputabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv1d_backward_inputabsro|>pack_arr(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv1d_backward_kernelabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv1d_backward_kernelabsro|>pack_arrinlazy(fun~paddingabsr->build_piso(modulestructletlabel="dilated_conv1d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(dilated_conv1d?paddingabsr)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=(dilated_conv1d_backward_inputabsr!ca,dilated_conv1d_backward_kernelabsr!ca)letdr_aab_cpca=dilated_conv1d_backward_inputabsr!caletdr_bab_cpca=dilated_conv1d_backward_kernelabsr!caend:Piso)ab)letdilated_conv1d?padding=Lazy.force_dilated_conv1d~padding(* a:input; b:kernel; s:stride; r:rate *)let_dilated_conv2d=(* a:input; b:kernel; o:output'; s:stride; r:rate 
*)letdilated_conv2d_backward_inputabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv2d_backward_inputabsro|>pack_arr(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv2d_backward_kernelabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv2d_backward_kernelabsro|>pack_arrinlazy(fun~paddingabsr->build_piso(modulestructletlabel="dilated_conv2d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(dilated_conv2d?paddingabsr)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=(dilated_conv2d_backward_inputabsr!ca,dilated_conv2d_backward_kernelabsr!ca)letdr_aab_cpca=dilated_conv2d_backward_inputabsr!caletdr_bab_cpca=dilated_conv2d_backward_kernelabsr!caend:Piso)ab)letdilated_conv2d?padding=Lazy.force_dilated_conv2d~padding(* a:input; b:kernel; s:stride; r:rate *)let_dilated_conv3d=(* a:input; b:kernel; o:output'; s:stride; r:rate *)letdilated_conv3d_backward_inputabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv3d_backward_inputabsro|>pack_arr(* a:input; b:kernel; o:output'; s:stride; r:rate *)anddilated_conv3d_backward_kernelabsro=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.dilated_conv3d_backward_kernelabsro|>pack_arrinlazy(fun~paddingabsr->build_piso(modulestructletlabel="dilated_conv3d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(dilated_conv3d?paddingabsr)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=(dilated_conv3d_backward_inputabsr!ca,dilated_conv3d_backward_kernelabsr!ca)letdr_aab_cpca=dilated_conv3d_backward_inputabsr!caletdr_bab_cpca=dilated_conv3d_backward_kernelabsr!caend:Piso)ab)letdilated_conv3d?padding=Lazy.force_dilated_conv3d~padding(* a:input; b:kernel; s:stride *)let_transpose_conv1d=(* a:input; b:kernel; 
s:stride; o:output' *)lettranspose_conv1d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv1d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv1d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv1d_backward_kernelabso|>pack_arrinlazy(fun~paddingabs->build_piso(modulestructletlabel="transpose_conv1d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(transpose_conv1d?paddingabs)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=(transpose_conv1d_backward_inputabs!ca,transpose_conv1d_backward_kernelabs!ca)letdr_aab_cpca=transpose_conv1d_backward_inputabs!caletdr_bab_cpca=transpose_conv1d_backward_kernelabs!caend:Piso)ab)lettranspose_conv1d?padding=Lazy.force_transpose_conv1d~padding(* a:input; b:kernel; s:stride *)and_transpose_conv2d=(* a:input; b:kernel; s:stride; o:output' *)lettranspose_conv2d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv2d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv2d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv2d_backward_kernelabso|>pack_arrinlazy(fun~paddingabs->build_piso(modulestructletlabel="transpose_conv2d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(transpose_conv2d?paddingabs)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=(transpose_conv2d_backward_inputabs!ca,transpose_conv2d_backward_kernelabs!ca)letdr_aab_cpca=transpose_conv2d_backward_inputabs!caletdr_bab_cpca=transpose_conv2d_backward_kernelabs!caend:Piso)ab)lettranspose_conv2d?padding=Lazy.force_transpose_conv2d~padding(* a:input; b:kernel; s:stride *)let_transpose_conv3d=(* a:input; 
b:kernel; s:stride; o:output' *)lettranspose_conv3d_backward_inputabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv3d_backward_inputabso|>pack_arr(* a:input; b:kernel; s:stride; o:output' *)andtranspose_conv3d_backward_kernelabso=leta=unpack_arrainletb=unpack_arrbinleto=unpack_arroinA.transpose_conv3d_backward_kernelabso|>pack_arrinlazy(fun~paddingabs->build_piso(modulestructletlabel="transpose_conv3d"letff_aaa_b=error_unioplabel(pack_elta)letff_aba_b=error_unioplabel(pack_elta)letff_ba_ab=error_unioplabel(pack_eltb)letff_bbab=ArrA.(transpose_conv3d?paddingabs)letdf_da_cp_apat_bp=atletdf_db_cp_ap_bpbt=btletdf_dab_cp_apat_bpbt=at+btletdr_abab_cpca=(transpose_conv3d_backward_inputabs!ca,transpose_conv3d_backward_kernelabs!ca)letdr_aab_cpca=transpose_conv3d_backward_inputabs!caletdr_bab_cpca=transpose_conv3d_backward_kernelabs!caend:Piso)ab)lettranspose_conv3d?padding=Lazy.force_transpose_conv3d~padding(* a:input; b:kernel; s:stride *)let_max_pool1d=(* a:input; p:padding type; b:kernel; s:stride; o:output' *)letmax_pool1d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.max_pool1d_backwardpabso|>pack_arrinlazy(funpaddingabs->build_siso(modulestructletlabel="max_pool1d"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(max_pool1d~paddingabs)letdf_cp_ap_at=failwith"max_pool1d:df"letdra_cpca=max_pool1d_backwardpadding(primala)bs!caend:Siso)a)letmax_pool1dpadding=Lazy.force_max_pool1dpadding(* a:input; b:kernel; s:stride *)let_max_pool2d=(* a:input; p:padding type; b:kernel; s:stride; o:output' *)letmax_pool2d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.max_pool2d_backwardpabso|>pack_arrinlazy(funpaddingabs->build_siso(modulestructletlabel="max_pool2d"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(max_pool2d~paddingabs)letdf_cp_ap_at=failwith"max_pool2d:df"letdra_cpca=max_pool2d_backwardpadding(primala)bs!caend:Siso)a)letmax_pool2dpadding=Lazy.force_max_pool2dpadding(* a:input; b:kernel; s:stride *)let_max_pool3d=(* a:input; 
p:padding type; b:kernel; s:stride; o:output' *)letmax_pool3d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.max_pool3d_backwardpabso|>pack_arrinlazy(funpaddingabs->build_siso(modulestructletlabel="max_pool3d"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(max_pool3d~paddingabs)letdf_cp_ap_at=failwith"max_pool3d:df"letdra_cpca=max_pool3d_backwardpadding(primala)bs!caend:Siso)a)letmax_pool3dpadding=Lazy.force_max_pool3dpadding(* a:input; b:kernel; s:stride *)let_avg_pool1d=(* a:input; p:padding type; b:kernel; s:stride; o:output' *)letavg_pool1d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.avg_pool1d_backwardpabso|>pack_arrinlazy(funpaddingabs->build_siso(modulestructletlabel="avg_pool1d"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(avg_pool1d~paddingabs)letdf_cp_ap_at=failwith"avg_pool1d:df"letdra_cpca=avg_pool1d_backwardpadding(primala)bs!caend:Siso)a)letavg_pool1dpadding=Lazy.force_avg_pool1dpadding(* a:input; b:kernel; s:stride *)and_avg_pool2d=(* a:input; p:padding type; b:kernel; s:stride; o:output' *)letavg_pool2d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.avg_pool2d_backwardpabso|>pack_arrinlazy(funpaddingabs->build_siso(modulestructletlabel="avg_pool2d"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(avg_pool2d~paddingabs)letdf_cp_ap_at=failwith"avg_pool2d:df"letdra_cpca=avg_pool2d_backwardpadding(primala)bs!caend:Siso)a)letavg_pool2dpadding=Lazy.force_avg_pool2dpadding(* a:input; b:kernel; s:stride *)let_avg_pool3d=(* a:input; p:padding type; b:kernel; s:stride; o:output' *)letavg_pool3d_backwardpabso=leta=unpack_arrainleto=unpack_arroinA.avg_pool3d_backwardpabso|>pack_arrinlazy(funpaddingabs->build_siso(modulestructletlabel="avg_pool3d"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(avg_pool3d~paddingabs)letdf_cp_ap_at=failwith"avg_pool3d:df"letdra_cpca=avg_pool3d_backwardpadding(primala)bs!caend:Siso)a)letavg_pool3dpadding=Lazy.force_avg_pool3dpadding(* a:input; s:size *)let_upsampling2d=(* a:input; s:size; 
o:output' *)letupsampling2d_backwardaso=leta=unpack_arrainleto=unpack_arroinA.upsampling2d_backwardaso|>pack_arrinlazy(funas->build_siso(modulestructletlabel="upsampling2d"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(upsampling2das)letdf_cp_ap_at=failwith"upsampling2d:df"letdra_cpca=upsampling2d_backward(primala)s!caend:Siso)a)letupsampling2da=Lazy.force_upsampling2da(* v: padded value; p:padding index; a:input *)let_pad=(* TODO: sources required to confirm this backward op *)(* o:outut'; p: padding index *)letpad_backwardop=(* assume p is full legal index for pad operation *)leto=unpack_arroinletos=A.shapeoinletq=Owl_utils.llss2aarrpinArray.iteri(funix->x.(1)<-Stdlib.(os.(i)-1-x.(1)))q;letq=Owl_utils.aarr2llssqinA.(get_sliceqo)|>pack_arrinlazy(fun~vpa->build_siso(modulestructletlabel="pad"letff_fa=error_unioplabel(pack_elta)letff_arra=ArrA.(pad?vpa)letdf_cp_ap_at=failwith"pad:df"letdr_a_cpca=pad_backward!capend:Siso)a)letpad?v=Lazy.force_pad~vendmoduleMat=structletemptymn=A.empty[|m;n|]|>pack_arrletzerosmn=A.zeros[|m;n|]|>pack_arrleteyen=A.Mat.eyen|>pack_arrletonesmn=A.ones[|m;n|]|>pack_arrletuniform?a?bmn=A.uniform?a?b[|m;n|]|>pack_arrletgaussian?mu?sigmamn=A.gaussian?mu?sigma[|m;n|]|>pack_arrletresetx=x|>unpack_arr|>A.resetletreshapemnx=Maths.reshapex[|m;n|]letshapex=lets=A.shape(unpack_arrx)ins.(0),s.(1)letrow_numx=(unpack_arrx|>A.shape).(0)letcol_numx=(unpack_arrx|>A.shape).(1)letnumelx=numelxletrowxi=Maths.get_rowxiletgetxij=Maths.get_itemxijletsetxija=Maths.set_itemxija(* unary math operators *)letmeanx=Maths.meanx(* binary math operators 
*)letaddxy=Maths.addxyletsubxy=Maths.subxyletmulxy=Maths.mulxyletdivxy=Maths.divxyletdotxy=Maths.dotxyletmap_by_rowfx=x|>Maths.to_rows|>Array.mapf|>Maths.of_rowsletprintx=A.print(unpack_arrx)letof_arraysx=A.of_arraysx|>pack_arrletinit_2dn_rowsn_colsf=Array.initn_rows(funi->Array.initn_cols(funj->fij))|>Maths.of_arraysendmoduleArr=structletemptyd=A.emptyd|>pack_arrletzerosd=A.zerosd|>pack_arrletonesd=A.onesd|>pack_arrletuniform?a?bd=A.uniform?a?bd|>pack_arrletgaussian?mu?sigmad=A.gaussian?mu?sigmad|>pack_arrletresetx=x|>unpack_arr|>A.resetletreshapexs=Maths.reshapexsletshapex=A.shape(unpack_arrx)letnumelx=numelx(* binary math operators *)letaddxy=Maths.addxyletsubxy=Maths.subxyletmulxy=Maths.mulxyletdivxy=Maths.divxyletdotxy=Maths.dotxyendend