# 1 "src/base/compute/owl_computation_operator.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

open Owl_types

(* Functor of making Lazy module of different number types *)

module Make (Symbol : Owl_computation_symbol_sig.Sig) = struct
  module Symbol = Symbol
  open Symbol
  open Symbol.Shape.Type
  open Symbol.Shape.Type.Device

  (* Every operator in this functor follows the same pattern: its array and
     element arguments are wrapped with [arr_to_node] / [elt_to_node], handed to
     [make_then_connect] (or to [make_node] for operators without parents)
     together with the matching operator constructor, and the resulting node is
     unwrapped again with [node_to_arr] / [node_to_elt]. No numerical work is
     done in this file. NOTE(review): actual evaluation presumably happens later
     in the computation engine, in line with the lazy design of this functor. *)

  (* mathematical functions *)

  let noop x = make_then_connect Noop [| arr_to_node x |] |> node_to_arr

  (* Creation operators: the output shape is passed explicitly via [~shape]. *)

  let empty shape = make_node ~shape:[| Some shape |] (Empty shape) |> node_to_arr

  let zeros shape = make_node ~shape:[| Some shape |] (Zeros shape) |> node_to_arr

  let ones shape = make_node ~shape:[| Some shape |] (Ones shape) |> node_to_arr

  let create shape v =
    make_then_connect ~shape:[| Some shape |] (Create shape) [| elt_to_node v |]
    |> node_to_arr

  (** [sequential ?a ?step shape]: [a] defaults to a constant 0 and [step] to a
      constant 1; both become element parents of the node. *)
  let sequential ?a ?step shape =
    let a =
      match a with
      | Some a -> a
      | None   -> const_elt "sequential_a" (A.float_to_elt 0.)
    in
    let b =
      match step with
      | Some b -> b
      | None   -> const_elt "sequential_step" (A.float_to_elt 1.)
    in
    make_then_connect
      ~shape:[| Some shape |]
      (Sequential shape)
      [| elt_to_node a; elt_to_node b |]
    |> node_to_arr

  (** [uniform ?a ?b shape]: [a] defaults to a constant 0 and [b] to a
      constant 1. *)
  let uniform ?a ?b shape =
    let a =
      match a with
      | Some a -> a
      | None   -> const_elt "uniform_a" (A.float_to_elt 0.)
    in
    let b =
      match b with
      | Some b -> b
      | None   -> const_elt "uniform_b" (A.float_to_elt 1.)
    in
    make_then_connect
      ~shape:[| Some shape |]
      (Uniform shape)
      [| elt_to_node a; elt_to_node b |]
    |> node_to_arr

  (** [gaussian ?mu ?sigma shape]: [mu] defaults to a constant 0 and [sigma] to
      a constant 1. The default constants carry gaussian-specific names rather
      than the names used by [sequential]. *)
  let gaussian ?mu ?sigma shape =
    let a =
      match mu with
      | Some a -> a
      | None   -> const_elt "gaussian_mu" (A.float_to_elt 0.)
    in
    let b =
      match sigma with
      | Some b -> b
      | None   -> const_elt "gaussian_sigma" (A.float_to_elt 1.)
    in
    make_then_connect
      ~shape:[| Some shape |]
      (Gaussian shape)
      [| elt_to_node a; elt_to_node b |]
    |> node_to_arr

  (** [bernoulli ?p shape]: [p] defaults to a constant 0.5. *)
  let bernoulli ?p shape =
    let p =
      match p with
      | Some p -> p
      | None   -> const_elt "bernoulli_p" (A.float_to_elt 0.5)
    in
    make_then_connect ~shape:[| Some shape |] (Bernoulli shape) [| elt_to_node p |]
    |> node_to_arr

  let init shape f = make_node ~shape:[| Some shape |] (Init (shape, f)) |> node_to_arr

  (** Not implemented: always raises [Owl_exception.NOT_IMPLEMENTED]. *)
  let init_nd _shape _f =
    raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.init_nd")

  (* Shape queries, answered from the node shape without adding a node. *)

  let shape x = arr_to_node x |> node_shape

  let numel x = Array.fold_left ( * ) 1 (shape x)

  (* Indexing and slicing. The set_* variants build the node, discard it and
     return unit. *)

  let get x i = make_then_connect (Get i) [| arr_to_node x |] |> node_to_elt

  let set x i v = make_then_connect (Set i) [| arr_to_node x; elt_to_node v |] |> ignore

  let get_slice slice x =
    make_then_connect (GetSlice slice) [| arr_to_node x |] |> node_to_arr

  let set_slice slice x y =
    make_then_connect (SetSlice slice) [| arr_to_node x; arr_to_node y |] |> ignore

  let get_fancy indices x =
    make_then_connect (GetFancy indices) [| arr_to_node x |] |> node_to_arr

  let set_fancy indices x y =
    make_then_connect (SetFancy indices) [| arr_to_node x; arr_to_node y |] |> ignore

  let copy x = make_then_connect Copy [| arr_to_node x |] |> node_to_arr

  (** Not implemented: always fails with [Failure]. *)
  let copy_ ~out _x = failwith "Owl_computation_operator:copy_: not implemented"
    [@@warning "-27"]

  let reset x = make_then_connect Reset [| arr_to_node x |] |> node_to_arr |> ignore

  (** [reshape x shape] checks eagerly that the element count is preserved and
      raises [Owl_exception.DIFFERENT_SIZE] otherwise. *)
  let reshape x shape =
    let n_old = numel x in
    let n_new = Array.fold_left ( * ) 1 shape in
    let exn = Owl_exception.DIFFERENT_SIZE (n_old, n_new) in
    Owl_exception.check (n_old = n_new) exn;
    make_then_connect (Reshape shape) [| arr_to_node x |] |> node_to_arr

  let reverse x = make_then_connect Reverse [| arr_to_node x |] |> node_to_arr

  let tile x axises = make_then_connect (Tile axises) [| arr_to_node x |] |> node_to_arr

  let repeat x repeats =
    make_then_connect (Repeat repeats) [| arr_to_node x |] |> node_to_arr

  (** [pad ?v padding x]: the fill value [v] defaults to a constant 0. *)
  let pad ?v padding x =
    let v =
      match v with
      | Some v -> v
      | None   -> const_elt "pad_v" (A.float_to_elt 0.)
    in
    make_then_connect (Pad (v, padding)) [| arr_to_node x |] |> node_to_arr

  (** Not implemented: always fails with [Failure]. *)
  let expand ?(hi = false) _x _d =
    ignore hi;
    failwith "expand: not implemented"

  (** Not implemented: always fails with [Failure]. *)
  let squeeze ?(axis = [||]) _x =
    ignore axis;
    failwith "squeeze: not implemented"

  let concatenate ?(axis = 0) xs =
    make_then_connect (Concatenate axis) (Array.map arr_to_node xs) |> node_to_arr

  let stack ?(axis = 0) xs =
    make_then_connect (Stack axis) (Array.map arr_to_node xs) |> node_to_arr

  (** Not implemented: fails with [Failure] as soon as [~axis] is supplied. *)
  let concat ~axis =
    ignore axis;
    failwith "concat: not implemented"

  (** Not implemented: always fails with [Failure]. *)
  let split ?(axis = 0) _parts _x = failwith "split: not implemented"
    [@@warning "-27"]

  (** [draw ?axis x n] returns the Draw node and an empty array.
      NOTE(review): the second component is always [[||]] here; callers that
      expect the drawn indices get nothing - confirm this is intended. *)
  let draw ?(axis = 0) x n =
    let y = make_then_connect (Draw (axis, n)) [| arr_to_node x |] |> node_to_arr in
    y, [||]

  (* Higher-order operators: the OCaml function is stored in the operator. *)

  let map f x = make_then_connect (Map f) [| arr_to_node x |] |> node_to_arr

  let fold ?(axis = -1) f a x =
    make_then_connect (Fold (axis, f)) [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let scan ?(axis = -1) f x =
    make_then_connect (Scan (axis, f)) [| arr_to_node x |] |> node_to_arr

  let one_hot depth x = make_then_connect (OneHot depth) [| arr_to_node x |] |> node_to_arr

  let delay f x = make_then_connect (Delay f) [| arr_to_node x |] |> node_to_arr

  let delay_array shape f x =
    make_then_connect
      ~shape:[| Some shape |]
      (DelayArray (shape, f))
      (Array.map arr_to_node x)
    |> node_to_arr

  let lazy_print ?max_row ?max_col ?header ?fmt x =
    make_then_connect (LazyPrint (max_row, max_col, header, fmt)) [| arr_to_node x |]
    |> node_to_arr

  (** No-op: all arguments are ignored and nothing is printed; [lazy_print]
      builds a printing node instead. *)
  let print ?max_row ?max_col ?header ?fmt x = () [@@warning "-27"]

  (* Element-wise unary operators. *)

  let abs x = make_then_connect Abs [| arr_to_node x |] |> node_to_arr

  let neg x = make_then_connect Neg [| arr_to_node x |] |> node_to_arr

  let floor x = make_then_connect Floor [| arr_to_node x |] |> node_to_arr

  let ceil x = make_then_connect Ceil [| arr_to_node x |] |> node_to_arr

  let round x = make_then_connect Round [| arr_to_node x |] |> node_to_arr

  let sqr x = make_then_connect Sqr [| arr_to_node x |] |> node_to_arr

  let sqrt x = make_then_connect Sqrt [| arr_to_node x |] |> node_to_arr

  let log x = make_then_connect Log [| arr_to_node x |] |> node_to_arr

  let log2 x = make_then_connect Log2 [| arr_to_node x |] |> node_to_arr

  let log10 x = make_then_connect Log10 [| arr_to_node x |] |> node_to_arr

  let exp x = make_then_connect Exp [| arr_to_node x |] |> node_to_arr

  let sin x = make_then_connect Sin [| arr_to_node x |] |> node_to_arr

  let cos x = make_then_connect Cos [| arr_to_node x |] |> node_to_arr

  let tan x = make_then_connect Tan [| arr_to_node x |] |> node_to_arr

  let sinh x = make_then_connect Sinh [| arr_to_node x |] |> node_to_arr

  let cosh x = make_then_connect Cosh [| arr_to_node x |] |> node_to_arr

  let tanh x = make_then_connect Tanh [| arr_to_node x |] |> node_to_arr

  let asin x = make_then_connect Asin [| arr_to_node x |] |> node_to_arr

  let acos x = make_then_connect Acos [| arr_to_node x |] |> node_to_arr

  let atan x = make_then_connect Atan [| arr_to_node x |] |> node_to_arr

  let asinh x = make_then_connect Asinh [| arr_to_node x |] |> node_to_arr

  let acosh x = make_then_connect Acosh [| arr_to_node x |] |> node_to_arr

  let atanh x = make_then_connect Atanh [| arr_to_node x |] |> node_to_arr

  (* Reductions along [axis] (default -1); [keep_dims] defaults to true and is
     recorded in the operator. *)

  let min ?(axis = -1) ?(keep_dims = true) x =
    make_then_connect (Min (keep_dims, axis)) [| arr_to_node x |] |> node_to_arr

  let max ?(axis = -1) ?(keep_dims = true) x =
    make_then_connect (Max (keep_dims, axis)) [| arr_to_node x |] |> node_to_arr

  let sum ?(axis = -1) ?(keep_dims = true) x =
    make_then_connect (Sum (keep_dims, axis)) [| arr_to_node x |] |> node_to_arr

  let sum_reduce ?(axis = [| -1 |]) x =
    make_then_connect (SumReduce axis) [| arr_to_node x |] |> node_to_arr

  let signum x = make_then_connect Signum [| arr_to_node x |] |> node_to_arr

  let sigmoid x = make_then_connect Sigmoid [| arr_to_node x |] |> node_to_arr

  let relu x = make_then_connect Relu [| arr_to_node x |] |> node_to_arr

  let dawsn x = make_then_connect Dawsn [| arr_to_node x |] |> node_to_arr

  (* Whole-array reductions returning a single element node. *)

  let min' x = make_then_connect Min' [| arr_to_node x |] |> node_to_elt

  let max' x = make_then_connect Max' [| arr_to_node x |] |> node_to_elt

  let sum' x = make_then_connect Sum' [| arr_to_node x |] |> node_to_elt

  let log_sum_exp' x = make_then_connect LogSumExp' [| arr_to_node x |] |> node_to_elt

  (** Note: unlike [min]/[max]/[sum], the default axis here is 0. *)
  let log_sum_exp ?(axis = 0) ?(keep_dims = true) x =
    make_then_connect (LogSumExp (keep_dims, axis)) [| arr_to_node x |] |> node_to_arr

  let l1norm' x = make_then_connect L1norm' [| arr_to_node x |] |> node_to_elt

  let l2norm' x = make_then_connect L2norm' [| arr_to_node x |] |> node_to_elt

  let l2norm_sqr' x = make_then_connect L2NormSqr' [| arr_to_node x |] |> node_to_elt

  (** [clip_by_value ?amin ?amax x]: the bounds default to constants holding
      [neg_infinity] and [infinity] respectively. *)
  let clip_by_value ?amin ?amax x =
    let a =
      match amin with
      | Some a -> a
      | None   -> const_elt "clip_by_value_amin" (A.float_to_elt neg_infinity)
    in
    let b =
      match amax with
      | Some b -> b
      | None   -> const_elt "clip_by_value_amax" (A.float_to_elt infinity)
    in
    make_then_connect ClipByValue [| arr_to_node x; elt_to_node a; elt_to_node b |]
    |> node_to_arr

  let clip_by_l2norm a x =
    make_then_connect ClipByL2norm [| arr_to_node x; elt_to_node a |] |> node_to_arr

  (* Binary operators. Naming convention: [op x y] takes two arrays,
     [scalar_op a x] takes the element first, [op_scalar x a] the array first;
     the parent order always follows the argument order. *)

  let pow x y = make_then_connect Pow [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let scalar_pow a x =
    make_then_connect ScalarPow [| elt_to_node a; arr_to_node x |] |> node_to_arr

  let pow_scalar x a =
    make_then_connect PowScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let atan2 x y = make_then_connect Atan2 [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let scalar_atan2 a x =
    make_then_connect ScalarAtan2 [| elt_to_node a; arr_to_node x |] |> node_to_arr

  let atan2_scalar x a =
    make_then_connect Atan2Scalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let hypot x y = make_then_connect Hypot [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let min2 x y = make_then_connect Min2 [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let max2 x y = make_then_connect Max2 [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let add x y = make_then_connect Add [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let sub x y = make_then_connect Sub [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let mul x y = make_then_connect Mul [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let div x y = make_then_connect Div [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let add_scalar x a =
    make_then_connect AddScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let sub_scalar x a =
    make_then_connect SubScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let mul_scalar x a =
    make_then_connect MulScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let div_scalar x a =
    make_then_connect DivScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let scalar_add a x =
    make_then_connect ScalarAdd [| elt_to_node a; arr_to_node x |] |> node_to_arr

  let scalar_sub a x =
    make_then_connect ScalarSub [| elt_to_node a; arr_to_node x |] |> node_to_arr

  let scalar_mul a x =
    make_then_connect ScalarMul [| elt_to_node a; arr_to_node x |] |> node_to_arr

  let scalar_div a x =
    make_then_connect ScalarDiv [| elt_to_node a; arr_to_node x |] |> node_to_arr

  let fma x y z =
    make_then_connect FMA [| arr_to_node x; arr_to_node y; arr_to_node z |]
    |> node_to_arr

  (* Element-wise comparisons. *)

  let elt_equal x y =
    make_then_connect EltEqual [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let elt_not_equal x y =
    make_then_connect EltNotEqual [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let elt_less x y =
    make_then_connect EltLess [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let elt_greater x y =
    make_then_connect EltGreater [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let elt_less_equal x y =
    make_then_connect EltLessEqual [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let elt_greater_equal x y =
    make_then_connect EltGreaterEqual [| arr_to_node x; arr_to_node y |] |> node_to_arr

  let elt_equal_scalar x a =
    make_then_connect EltEqualScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let elt_not_equal_scalar x a =
    make_then_connect EltNotEqualScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let elt_less_scalar x a =
    make_then_connect EltLessScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let elt_greater_scalar x a =
    make_then_connect EltGreaterScalar [| arr_to_node x; elt_to_node a |] |> node_to_arr

  let elt_less_equal_scalar x a =
    make_then_connect EltLessEqualScalar [| arr_to_node x; elt_to_node a |]
    |> node_to_arr

  let elt_greater_equal_scalar x a =
    make_then_connect EltGreaterEqualScalar [| arr_to_node x; elt_to_node a |]
    |> node_to_arr

  (* Convolutions: parents are [input; kernel]; [padding] defaults to SAME. *)

  let conv1d ?(padding = SAME) input kernel stride =
    make_then_connect (Conv1d (padding, stride)) [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  let conv2d ?(padding = SAME) input kernel stride =
    make_then_connect (Conv2d (padding, stride)) [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  let conv3d ?(padding = SAME) input kernel stride =
    make_then_connect (Conv3d (padding, stride)) [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  let transpose_conv1d ?(padding = SAME) input kernel stride =
    make_then_connect
      (TransposeConv1d (padding, stride))
      [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  let transpose_conv2d ?(padding = SAME) input kernel stride =
    make_then_connect
      (TransposeConv2d (padding, stride))
      [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  let transpose_conv3d ?(padding = SAME) input kernel stride =
    make_then_connect
      (TransposeConv3d (padding, stride))
      [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  let dilated_conv1d ?(padding = SAME) input kernel stride rate =
    make_then_connect
      (DilatedConv1d (padding, stride, rate))
      [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  let dilated_conv2d ?(padding = SAME) input kernel stride rate =
    make_then_connect
      (DilatedConv2d (padding, stride, rate))
      [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  let dilated_conv3d ?(padding = SAME) input kernel stride rate =
    make_then_connect
      (DilatedConv3d (padding, stride, rate))
      [| arr_to_node input; arr_to_node kernel |]
    |> node_to_arr

  (* Pooling: the kernel size is stored in the operator; [input] is the only
     parent. *)

  let max_pool1d ?(padding = SAME) input kernel stride =
    make_then_connect (MaxPool1d (padding, kernel, stride)) [| arr_to_node input |]
    |> node_to_arr

  let max_pool2d ?(padding = SAME) input kernel stride =
    make_then_connect (MaxPool2d (padding, kernel, stride)) [| arr_to_node input |]
    |> node_to_arr

  let max_pool3d ?(padding = SAME) input kernel stride =
    make_then_connect (MaxPool3d (padding, kernel, stride)) [| arr_to_node input |]
    |> node_to_arr

  let avg_pool1d ?(padding = SAME) input kernel stride =
    make_then_connect (AvgPool1d (padding, kernel, stride)) [| arr_to_node input |]
    |> node_to_arr

  let avg_pool2d ?(padding = SAME) input kernel stride =
    make_then_connect (AvgPool2d (padding, kernel, stride)) [| arr_to_node input |]
    |> node_to_arr

  let avg_pool3d ?(padding = SAME) input kernel stride =
    make_then_connect (AvgPool3d (padding, kernel, stride)) [| arr_to_node input |]
    |> node_to_arr

  let upsampling2d input size =
    make_then_connect (UpSampling2d size) [| arr_to_node input |] |> node_to_arr

  (* Convolution backward passes: parents are [input; kernel; output_grad]. *)

  let conv1d_backward_input input kernel stride output' =
    make_then_connect
      (Conv1dBackwardInput stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let conv1d_backward_kernel input kernel stride output' =
    make_then_connect
      (Conv1dBackwardKernel stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let conv2d_backward_input input kernel stride output' =
    make_then_connect
      (Conv2dBackwardInput stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let conv2d_backward_kernel input kernel stride output' =
    make_then_connect
      (Conv2dBackwardKernel stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let conv3d_backward_input input kernel stride output' =
    make_then_connect
      (Conv3dBackwardInput stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let conv3d_backward_kernel input kernel stride output' =
    make_then_connect
      (Conv3dBackwardKernel stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let transpose_conv1d_backward_input input kernel stride output' =
    make_then_connect
      (TransposeConv1dBackwardInput stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let transpose_conv1d_backward_kernel input kernel stride output' =
    make_then_connect
      (TransposeConv1dBackwardKernel stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let transpose_conv2d_backward_input input kernel stride output' =
    make_then_connect
      (TransposeConv2dBackwardInput stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let transpose_conv2d_backward_kernel input kernel stride output' =
    make_then_connect
      (TransposeConv2dBackwardKernel stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let transpose_conv3d_backward_input input kernel stride output' =
    make_then_connect
      (TransposeConv3dBackwardInput stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let transpose_conv3d_backward_kernel input kernel stride output' =
    make_then_connect
      (TransposeConv3dBackwardKernel stride)
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let dilated_conv1d_backward_input input kernel stride rate output' =
    make_then_connect
      (DilatedConv1dBackwardInput (stride, rate))
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let dilated_conv1d_backward_kernel input kernel stride rate output' =
    make_then_connect
      (DilatedConv1dBackwardKernel (stride, rate))
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let dilated_conv2d_backward_input input kernel stride rate output' =
    make_then_connect
      (DilatedConv2dBackwardInput (stride, rate))
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let dilated_conv2d_backward_kernel input kernel stride rate output' =
    make_then_connect
      (DilatedConv2dBackwardKernel (stride, rate))
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let dilated_conv3d_backward_input input kernel stride rate output' =
    make_then_connect
      (DilatedConv3dBackwardInput (stride, rate))
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  let dilated_conv3d_backward_kernel input kernel stride rate output' =
    make_then_connect
      (DilatedConv3dBackwardKernel (stride, rate))
      [| arr_to_node input; arr_to_node kernel; arr_to_node output' |]
    |> node_to_arr

  (* Pooling backward passes: parents are [input; output_grad]. *)

  let max_pool1d_backward padding input kernel stride output' =
    make_then_connect
      (MaxPool1dBackward (padding, kernel, stride))
      [| arr_to_node input; arr_to_node output' |]
    |> node_to_arr

  let max_pool2d_backward padding input kernel stride output' =
    make_then_connect
      (MaxPool2dBackward (padding, kernel, stride))
      [| arr_to_node input; arr_to_node output' |]
    |> node_to_arr

  let max_pool3d_backward padding input kernel stride output' =
    make_then_connect
      (MaxPool3dBackward (padding, kernel, stride))
      [| arr_to_node input; arr_to_node output' |]
    |> node_to_arr

  let avg_pool1d_backward padding input kernel stride output' =
    make_then_connect
      (AvgPool1dBackward (padding, kernel, stride))
      [| arr_to_node input; arr_to_node output' |]
    |> node_to_arr

  let avg_pool2d_backward padding input kernel stride output' =
    make_then_connect
      (AvgPool2dBackward (padding, kernel, stride))
      [| arr_to_node input; arr_to_node output' |]
    |> node_to_arr

  let avg_pool3d_backward padding input kernel stride output' =
    make_then_connect
      (AvgPool3dBackward (padding, kernel, stride))
      [| arr_to_node input; arr_to_node output' |]
    |> node_to_arr

  let upsampling2d_backward input size output' =
    make_then_connect
      (UpSampling2dBackward size)
      [| arr_to_node input; arr_to_node output' |]
    |> node_to_arr

  (* Matrix helpers. [row_num] / [col_num] raise [Owl_exception.NOT_MATRIX]
     unless [x] has exactly two dimensions. *)

  let row_num x =
    let s = shape x in
    let exn = Owl_exception.NOT_MATRIX s in
    Owl_exception.check (Array.length s = 2) exn;
    s.(0)

  let col_num x =
    let s = shape x in
    let exn = Owl_exception.NOT_MATRIX s in
    Owl_exception.check (Array.length s = 2) exn;
    s.(1)

  (** NOTE(review): the row index [_i] is ignored and [Row] carries no index,
      so every call builds the same operator - confirm with the engine. *)
  let row x _i = make_then_connect Row [| arr_to_node x |] |> node_to_arr

  let rows x i = make_then_connect (Rows i) [| arr_to_node x |] |> node_to_arr

  (** NOTE(review): only [x] is connected; the destination [_y] and the index
      [_i] are ignored and the node is discarded. *)
  let copy_row_to x _y _i = make_then_connect CopyRowTo [| arr_to_node x |] |> ignore

  (** NOTE(review): only [x] is connected; the destination [_y] and the index
      [_j] are ignored and the node is discarded. *)
  let copy_col_to x _y _j = make_then_connect CopyColTo [| arr_to_node x |] |> ignore

  let trace x = make_then_connect Trace [| arr_to_node x |] |> node_to_elt

  (** Not implemented: always raises [Owl_exception.NOT_IMPLEMENTED]. *)
  let diag ?k _x =
    ignore k;
    raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.diag")

  (** [dot x y] builds a Dot node with no transposition, alpha = 1 and
      beta = 0. *)
  let dot x y =
    let transa = false in
    let transb = false in
    let alpha = A.float_to_elt 1. |> pack_elt in
    let beta = A.float_to_elt 0. |> pack_elt in
    let op = Dot (transa, transb, alpha, beta) in
    make_then_connect op [| arr_to_node x; arr_to_node y |] |> node_to_arr

  (** [transpose ?axis x]: without [axis], the dimension order is reversed. *)
  let transpose ?axis x =
    let d = Array.length (shape x) in
    let axis =
      match axis with
      | Some a -> a
      | None   -> Array.init d (fun i -> d - i - 1)
    in
    make_then_connect (Transpose axis) [| arr_to_node x |] |> node_to_arr

  (** NOTE(review): builds a ToRows node, discards it and always returns an
      empty array (see FIXME). *)
  let to_rows x =
    let _ = make_then_connect ToRows [| arr_to_node x |] in
    (* FIXME: wrong shape *)
    [||]

  let of_rows xs = make_then_connect OfRows (Array.map arr_to_node xs) |> node_to_arr

  (** [of_array x shape] builds an array node whose parents are the elements of
      [x]. Raises [Owl_exception.DIFFERENT_SIZE] when the length of [x] differs
      from the number of elements implied by [shape], mirroring [reshape]. *)
  let of_array x shape =
    let n_elt = Array.length x in
    let n_shape = Array.fold_left ( * ) 1 shape in
    let exn = Owl_exception.DIFFERENT_SIZE (n_shape, n_elt) in
    Owl_exception.check (n_elt = n_shape) exn;
    let parents = Array.map elt_to_node x in
    make_then_connect (OfArray shape) parents |> node_to_arr

  (** Not implemented: always raises [Owl_exception.NOT_IMPLEMENTED]. *)
  let of_cols _xs = raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.of_cols")

  (** Not implemented: always raises [Owl_exception.NOT_IMPLEMENTED]. *)
  let to_cols _xs = raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.to_cols")

  (** [of_arrays x] builds a matrix node of shape [rows x cols] whose parents
      are the elements of [x] in row-major order. Every row must have the same
      length as the first one, otherwise [Owl_exception.DIFFERENT_SIZE] is
      raised; an empty [x] still raises [Invalid_argument] from [x.(0)]. *)
  let of_arrays x =
    let n_rows = Array.length x in
    let n_cols = Array.length x.(0) in
    (* reject ragged input: the node shape assumes n_rows * n_cols parents *)
    Array.iter
      (fun r ->
        let n = Array.length r in
        Owl_exception.check (n = n_cols) (Owl_exception.DIFFERENT_SIZE (n_cols, n)))
      x;
    let shape = [| n_rows; n_cols |] in
    let parents = Array.to_list x |> List.map (Array.map elt_to_node) |> Array.concat in
    make_then_connect (OfArray shape) parents |> node_to_arr

  (** Not implemented: always raises [Owl_exception.NOT_IMPLEMENTED]. *)
  let to_arrays _x =
    raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.to_arrays")

  (** Scalar maths *)

  module Scalar = struct
    (* Element-to-element operators: parents and result are element nodes. *)

    let add x y = make_then_connect Scalar_Add [| elt_to_node x; elt_to_node y |] |> node_to_elt

    let sub x y = make_then_connect Scalar_Sub [| elt_to_node x; elt_to_node y |] |> node_to_elt

    let mul x y = make_then_connect Scalar_Mul [| elt_to_node x; elt_to_node y |] |> node_to_elt

    let div x y = make_then_connect Scalar_Div [| elt_to_node x; elt_to_node y |] |> node_to_elt

    let pow x y = make_then_connect Scalar_Pow [| elt_to_node x; elt_to_node y |] |> node_to_elt

    let atan2 x y =
      make_then_connect Scalar_Atan2 [| elt_to_node x; elt_to_node y |] |> node_to_elt

    let abs x = make_then_connect Scalar_Abs [| elt_to_node x |] |> node_to_elt

    let neg x = make_then_connect Scalar_Neg [| elt_to_node x |] |> node_to_elt

    let sqr x = make_then_connect Scalar_Sqr [| elt_to_node x |] |> node_to_elt

    let sqrt x = make_then_connect Scalar_Sqrt [| elt_to_node x |] |> node_to_elt

    let exp x = make_then_connect Scalar_Exp [| elt_to_node x |] |> node_to_elt

    let log x = make_then_connect Scalar_Log [| elt_to_node x |] |> node_to_elt

    let log2 x = make_then_connect Scalar_Log2 [| elt_to_node x |] |> node_to_elt

    let log10 x = make_then_connect Scalar_Log10 [| elt_to_node x |] |> node_to_elt

    let signum x = make_then_connect Scalar_Signum [| elt_to_node x |] |> node_to_elt

    let floor x = make_then_connect Scalar_Floor [| elt_to_node x |] |> node_to_elt

    let ceil x = make_then_connect Scalar_Ceil [| elt_to_node x |] |> node_to_elt

    let round x = make_then_connect Scalar_Round [| elt_to_node x |] |> node_to_elt

    let sin x = make_then_connect Scalar_Sin [| elt_to_node x |] |> node_to_elt

    let cos x = make_then_connect Scalar_Cos [| elt_to_node x |] |> node_to_elt

    let tan x = make_then_connect Scalar_Tan [| elt_to_node x |] |> node_to_elt

    let sinh x = make_then_connect Scalar_Sinh [| elt_to_node x |] |> node_to_elt

    let cosh x = make_then_connect Scalar_Cosh [| elt_to_node x |] |> node_to_elt

    let tanh x = make_then_connect Scalar_Tanh [| elt_to_node x |] |> node_to_elt

    let asin x = make_then_connect Scalar_Asin [| elt_to_node x |] |> node_to_elt

    let acos x = make_then_connect Scalar_Acos [| elt_to_node x |] |> node_to_elt

    let atan x = make_then_connect Scalar_Atan [| elt_to_node x |] |> node_to_elt

    let asinh x = make_then_connect Scalar_Asinh [| elt_to_node x |] |> node_to_elt

    let acosh x = make_then_connect Scalar_Acosh [| elt_to_node x |] |> node_to_elt

    let atanh x = make_then_connect Scalar_Atanh [| elt_to_node x |] |> node_to_elt

    let relu x = make_then_connect Scalar_Relu [| elt_to_node x |] |> node_to_elt

    let dawsn x = make_then_connect Scalar_Dawsn [| elt_to_node x |] |> node_to_elt

    let sigmoid x = make_then_connect Scalar_Sigmoid [| elt_to_node x |] |> node_to_elt
  end

  (* Matrix constructors: none are implemented; each always raises
     [Owl_exception.NOT_IMPLEMENTED]. *)
  module Mat = struct
    let eye _n = raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.eye")

    let diagm ?k _x =
      ignore k;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.diagm")

    let tril ?k _x =
      ignore k;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.tril")

    let triu ?k _x =
      ignore k;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.triu")
  end

  (* Linear algebra: only [inv] builds a node; every other function always
     raises [Owl_exception.NOT_IMPLEMENTED]. *)
  module Linalg = struct
    let inv x = make_then_connect Inv [| arr_to_node x |] |> node_to_arr

    let logdet _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.logdet")

    let chol ?(upper = true) _x =
      ignore upper;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.chol")

    let svd ?(thin = true) _x =
      ignore thin;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.svd")

    let qr _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.qr")

    let lq _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.lq")

    let sylvester _a _b _c =
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.sylvester")

    let lyapunov _a _q =
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.lyapunov")

    let discrete_lyapunov ?(solver = `default) _a _q =
      ignore solver;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.discrete_lyapunov")

    let linsolve ?trans ?(typ = `n) _a _b =
      ignore trans;
      ignore typ;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.linsolve")

    let care ?(diag_r = false) _a _b _q _r =
      ignore diag_r;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.care")

    let dare ?(diag_r = false) _a _b _q _r =
      ignore diag_r;
      raise (Owl_exception.NOT_IMPLEMENTED "owl_computation_operator.dare")
  end
end
(* Make functor ends *)