(*
* OWL - OCaml Scientific Computing
* Copyright (c) 2016-2022 Liang Wang <liang@ocaml.xyz>
*)

[@@@warning "-32"]

open Bigarray
open Owl_types

type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t

type ('a, 'b) kind = ('a, 'b) Bigarray.kind

module Scalar = Owl_base_maths

(* Prepend an array with ones to the given length.
   Returns [dims] itself (not a copy) when it is already long enough. *)
let _prepend_dims dims desired_len =
  let dims_len = Array.length dims in
  if dims_len >= desired_len
  then dims
  else Array.append (Array.make (desired_len - dims_len) 1) dims


(* Compute the broadcast of two shapes.
   Returns (dims_a padded with leading ones, dims_b padded likewise, result shape).
   Raises Invalid_argument when a pair of extents differ and neither is 1. *)
let _get_broadcasted_dims dims_a dims_b =
  let len_c = Stdlib.max (Array.length dims_a) (Array.length dims_b) in
  let ext_dims_a = _prepend_dims dims_a len_c in
  let ext_dims_b = _prepend_dims dims_b len_c in
  let dims_c = Array.make len_c 0 in
  for i = 0 to len_c - 1 do
    let val_a = ext_dims_a.(i) in
    let val_b = ext_dims_b.(i) in
    if val_a = val_b
    then dims_c.(i) <- val_a
    else if val_a <> 1 && val_b <> 1
    then raise (Invalid_argument "The arrays cannot be broadcast into the same shape")
    else dims_c.(i) <- Stdlib.max val_a val_b
  done;
  ext_dims_a, ext_dims_b, dims_c


(* Increment the index array, with respect to the dimensions array.
   [ind] is mutated in place, last position varying fastest. Returns [false]
   once every position has wrapped back to zero, i.e. the iteration is over. *)
let _next_index ind dims =
  let num_dims = Array.length ind in
  let p = ref (num_dims - 1) in
  let ok = ref false in
  while !p >= 0 && not !ok do
    if ind.(!p) + 1 < dims.(!p)
    then (
      ind.(!p) <- ind.(!p) + 1;
      ok := true)
    else (
      (* this position overflows: reset it and carry into the previous one *)
      ind.(!p) <- 0;
      p := !p - 1)
  done;
  !ok


(* Map an index of the broadcast result onto an operand of shape [dims]:
   positions where the operand has extent 1 are pinned to 0.
   Raises Invalid_argument when an index exceeds an extent other than 1. *)
let _get_broadcasted_index ind dims =
  let num_dims = Array.length dims in
  let calc_fun i =
    let max_ind = dims.(i) in
    let ind_val = ind.(i) in
    if ind_val < max_ind
    then ind_val
    else if max_ind = 1
    then 0
    else raise (Invalid_argument "not broadcasted correctly")
  in
  Array.init num_dims calc_fun


(* Build a fresh array whose i-th element is arr.(perm.(i)). *)
let _apply_perm arr perm = Array.init (Array.length arr) (fun i -> arr.(perm.(i)))

(* Draw [count] integers from 0 .. range - 1 using a self-seeded generator.
   Without replacement the drawn values are distinct, and Invalid_argument is
   raised when [count] exceeds [range]. *)
let _draw_int_samples replacement range count =
  if (not replacement) && count > range
  then
    raise
      (Invalid_argument
         "cannot draw that many samples from the given range, without replacement")
  else (
    let pop_cnt = ref range in
    let pop = Array.init !pop_cnt (fun i -> i) in
    let rand_gen = Random.State.make_self_init () in
    let draw_fun _ =
      let index = Random.State.int rand_gen !pop_cnt in
      let sample = pop.(index) in
      if replacement
      then sample
      else (
        pop_cnt := !pop_cnt - 1;
        pop.(index) <- pop.(!pop_cnt);
        (* eliminate sample by swapping with last element *)
        sample)
    in
    Array.init count draw_fun)


(* Enumerate the indices start, start + step, ... up to and including [stop]
   when the step lands on it. Negative [start] / [stop] are offset by [dim].
   The default step is 1 or -1 depending on the direction; an explicit step
   pointing the wrong way (or equal to 0) fails the assertion. *)
let _enumerate_slice_def dim ?step start stop =
  let start = if start < 0 then dim + start else start in
  let stop = if stop < 0 then dim + stop else stop in
  let step =
    match step with
    | Some x -> x
    | None   -> if start <= stop then 1 else -1
  in
  assert ((start <= stop && step > 0) || (start > stop && step < 0));
  let step_abs = Stdlib.abs step in
  let len = (Stdlib.abs (stop - start) + step_abs) / step_abs in
  Array.init len (fun i -> start + (i * step))


(* Rewrite the indices s.t. for each dimension they are a list of explicit indices.
   Per-dimension lists are read as: [] = every index, [a] = a only,
   [a; b] = a to b, [a; b; s] = a to b with step s, four or more elements =
   the explicit indices themselves. *)
let _expand_slice_indices index_list dims =
  let rank = Array.length dims in
  let sdef_len = List.length index_list in
  (* the number of dimensions this slice specifies *)
  let _expand_slice_index i ind =
    match ind with
    | []                    -> Array.init dims.(i) (fun i -> i)
    | [ start ]             -> _enumerate_slice_def dims.(i) start start
    | [ start; stop ]       -> _enumerate_slice_def dims.(i) start stop
    | [ start; stop; step ] -> _enumerate_slice_def dims.(i) ~step start stop
    | x                     -> Array.of_list x
  in
  Array.append
    (Array.of_list (List.mapi _expand_slice_index index_list))
    (* for the axis where the index was specified *)
    (Array.init
       (rank - sdef_len)
       (* the rest of the axis is just all of them *)
       (fun p -> Array.init dims.(p + sdef_len) (fun i -> i)))


(* Overwrite every element of [x] with the zero of its kind. *)
let reset x =
  let _kind = Genarray.kind x in
  Genarray.fill x (Owl_const.zero _kind)


(* Allocate an uninitialised C-layout ndarray. *)
let empty kind dims = Genarray.create kind c_layout dims

(* Allocate an ndarray with every element set to [value]. *)
let create kind dims value =
  let x = empty kind dims in
  Genarray.fill x value;
  x


let create_ ~out a = Genarray.fill out a

let zeros kind dims = create kind dims (Owl_const.zero kind)

let zeros_ ~out = reset out

let ones kind dims = create kind dims (Owl_const.one kind)

let ones_ ~out = Genarray.(fill out (Owl_const.one (kind out)))

let shape x = Genarray.dims x

let nth_dim x i = Genarray.nth_dim x i

let num_dims x = Array.length (shape x)

let numel x = Owl_utils.numel x

let kind x = Genarray.kind x

let get x index = Genarray.get x index

let set x index value = Genarray.set x index value

(* n-by-n array of zeros with ones on the main diagonal. *)
let eye kind n =
  let m = zeros kind [| n; n |] in
  for i = 0 to n - 1 do
    set m [| i; i |] (Owl_const.one kind)
  done;
  m


(*TODO: optimise, test
*)

(* Copy the elements selected by [index_list] out of [varr] into a fresh array.
   [index_list] holds one slice definition per leading dimension (see
   _expand_slice_indices); unspecified trailing dimensions are taken whole. *)
let get_slice index_list varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let index_array = _expand_slice_indices index_list dims in
  let slice_dims = Array.map (fun a -> Array.length a) index_array in
  let slice_varr = empty (kind varr) slice_dims in
  let slice_ind = Array.make rank 0 in
  let original_ind = Array.make rank 0 in
  (* nothing to copy when a dimension is empty; entering the loop would
     index an empty index table *)
  let should_stop = ref (Array.exists (fun n -> n = 0) slice_dims) in
  while not !should_stop do
    for i = 0 to rank - 1 do
      original_ind.(i) <- index_array.(i).(slice_ind.(i))
    done;
    Genarray.set slice_varr slice_ind (Genarray.get varr original_ind);
    if not (_next_index slice_ind slice_dims) then should_stop := true
  done;
  slice_varr


(*TODO: optimise, test *)

(* Write [slice_varr] into the region of [varr] selected by [index_list].
   [slice_varr] is first reshaped to the shape of the selected region, so it
   must hold exactly that many elements. *)
let set_slice index_list varr slice_varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let index_array = _expand_slice_indices index_list dims in
  let slice_dims = Array.map (fun a -> Array.length a) index_array in
  (* [reshape] here is still the one from Bigarray; the ndarray version is
     only defined further down *)
  let slice_varr = reshape slice_varr slice_dims in
  let slice_ind = Array.make rank 0 in
  let original_ind = Array.make rank 0 in
  (* same empty-region guard as in get_slice *)
  let should_stop = ref (Array.exists (fun n -> n = 0) slice_dims) in
  while not !should_stop do
    for i = 0 to rank - 1 do
      original_ind.(i) <- index_array.(i).(slice_ind.(i))
    done;
    Genarray.set varr original_ind (Genarray.get slice_varr slice_ind);
    if not (_next_index slice_ind slice_dims) then should_stop := true
  done


(*TODO: optimise, test *)
let get_fancy _indices _varr = raise (Owl_exception.NOT_IMPLEMENTED "base: get_fancy")

(*TODO: optimise, test *)
let set_fancy _indices _target _input =
  raise (Owl_exception.NOT_IMPLEMENTED "base: set_fancy")


(* The result shares the underlying buffer with original, not a copy.
   At most one entry of [d] may be -1; it is replaced by the extent that makes
   the element count match [x]. *)
let reshape x d =
  let minus_one = Owl_utils.Array.count d (-1) in
  assert (minus_one <= 1);
  if minus_one = 0
  then reshape x d
  else (
    let n = numel x in
    (* product of d times -1: the single -1 entry cancels, leaving the
       product of the known extents *)
    let m = Array.fold_right ( * ) d (-1) in
    let e = Array.map (fun a -> if a = -1 then n / m else a) d in
    reshape x e)


(* Return the array as a contiguous block, without copying
*)
let flatten x = reshape x [| numel x |]

let fill x a = Genarray.fill x a

(* Fresh array with the same kind, shape and contents as [x]. *)
let copy x =
  let y = empty (kind x) (shape x) in
  Genarray.blit x y;
  y


(* Copy the elements of [x] into [out]; both are flattened first, so only the
   element counts have to agree. *)
let copy_ ~out x =
  let src = flatten x in
  let dst = flatten out in
  Genarray.blit src dst


let reshape_ ~out x = if not (x == out) then copy_ ~out x

(* Fresh array holding the elements of [x] in reverse flat order. *)
let reverse x =
  let n = numel x in
  let y = empty (kind x) (shape x) in
  let y_flat = reshape y [| n |] in
  let x_flat = reshape x [| n |] in
  for i = 0 to n - 1 do
    set y_flat [| i |] (get x_flat [| n - 1 - i |])
  done;
  y


(* Write the elements of [x], in reverse flat order, into [out].
   [out] may be [x] itself. Only physical identity is detected: two distinct
   handles onto the same buffer are treated as unrelated arrays. *)
let reverse_ ~out x =
  let n = numel x in
  let y_flat = reshape out [| n |] in
  let x_flat = reshape x [| n |] in
  if out == x
  then
    (* in place: swap symmetric pairs, otherwise the second half would be
       read after it has already been overwritten *)
    for i = 0 to (n / 2) - 1 do
      let j = n - 1 - i in
      let front = get x_flat [| i |] in
      set y_flat [| i |] (get x_flat [| j |]);
      set y_flat [| j |] front
    done
  else
    for i = 0 to n - 1 do
      set y_flat [| i |] (get x_flat [| n - 1 - i |])
    done


(* Replace every element of [x] by its image under [f]. *)
let map_ f x =
  let y = flatten x |> array1_of_genarray in
  let length = numel x in
  for i = 0 to length - 1 do
    Array1.unsafe_set y i (f (Array1.unsafe_get y i))
  done


(* Like map_, but [f] also receives the flat index of the element. *)
let mapi_ f x =
  let y = flatten x |> array1_of_genarray in
  let length = numel x in
  for i = 0 to length - 1 do
    Array1.unsafe_set y i (f i (Array1.unsafe_get y i))
  done


(* New array whose element at flat index i is [f i]. *)
let init kind dims f =
  let varr = empty kind dims in
  let varr_flat = flatten varr |> array1_of_genarray in
  let n = numel varr in
  for i = 0 to n - 1 do
    Array1.unsafe_set varr_flat i (f i)
  done;
  varr


(* New array whose elements are computed by [f] from an index array.
   The same index array [j] is refreshed by Owl_utils.index_1d_nd and passed
   to [f] on every call, so [f] should not keep a reference to it. *)
let init_nd k d f =
  let x = empty k d in
  let y = array1_of_genarray (flatten x) in
  let n = numel x in
  let s = Owl_utils.calc_stride d in
  let j = Array.copy s in
  for i = 0 to n - 1 do
    Owl_utils.index_1d_nd i j s;
    Array1.unsafe_set y i (f j)
  done;
  x


(* Map a NDarray from elements x -> f(x), by copying the array *)
let map f x =
  let y = copy x in
  map_ f y;
  y


(* Copying variant of mapi_. *)
let mapi f x =
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for i = 0 to Array1.dim y' - 1 do
    let a = Array1.unsafe_get y' i in
    Array1.unsafe_set y' i (f i a)
  done;
  y


let strides x = x |> shape |> Owl_utils.calc_stride

let slice_size x = x |> shape |> Owl_utils.calc_slice

(* TODO: performance can be optimised by removing embedded loops *)
(* generic fold function
*)
(* [f] receives a running counter, the accumulator and the current element.
   With an axis the result has the shape reported by Owl_utils.reduce_params,
   each cell starting from [a]; without one the whole array is folded in flat
   order into a one-element array. *)
let foldi ?axis f a x =
  let x' = flatten x |> array1_of_genarray in
  match axis with
  | Some axis ->
    let m, n, o, s = Owl_utils.reduce_params axis x in
    let start_x = ref 0 in
    let start_y = ref 0 in
    let incy = ref 0 in
    let k = ref 0 in
    let y = create (kind x) s a in
    let y' = flatten y |> array1_of_genarray in
    for _i = 0 to m - 1 do
      for j = 0 to n - 1 do
        let b = Array1.unsafe_get y' (!start_y + !incy) in
        let c = Array1.unsafe_get x' (!start_x + j) in
        Array1.unsafe_set y' (!start_y + !incy) (f !k b c);
        (* the destination offset cycles through 0 .. o - 1 *)
        if !incy + 1 = o then incy := 0 else incy := !incy + 1;
        k := !k + 1
      done;
      start_x := !start_x + n;
      start_y := !start_y + o
    done;
    y
  | None      ->
    let b = ref a in
    for i = 0 to numel x - 1 do
      let c = Array1.unsafe_get x' i in
      b := f i !b c
    done;
    create (kind x) [| 1 |] !b


let fold ?axis f a x = foldi ?axis (fun _ b c -> f b c) a x

(* generic scan function *)
(* Works on a copy of [x]; the axis defaults to the last one. Each element
   along the axis is replaced by [f k previous current], where [previous] is
   the already-updated neighbour one step back along the axis. *)
let scani ?axis f x =
  let d = num_dims x in
  let a =
    match axis with
    | Some a -> a
    | None   -> d - 1
  in
  assert (0 <= a && a < d);
  let _stride = strides x in
  let _slicez = slice_size x in
  let m = numel x / _slicez.(a) in
  let n = _slicez.(a) - _stride.(a) in
  let incx = _slicez.(a) in
  let incy = _slicez.(a) in
  let start_x = ref 0 in
  let start_y = ref _stride.(a) in
  let k = ref 0 in
  let y = copy x in
  let y' = flatten y |> array1_of_genarray in
  for _i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let b = Array1.unsafe_get y' (!start_x + j) in
      let c = Array1.unsafe_get y' (!start_y + j) in
      Array1.unsafe_set y' (!start_y + j) (f !k b c);
      k := !k + 1
    done;
    start_x := !start_x + incx;
    start_y := !start_y + incy
  done;
  y


let scan ?axis f x = scani ?axis (fun _ a b -> f a b) x

(* Call [f] on every (flat index, element) pair, in flat order. *)
let iteri f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    let a = Array1.unsafe_get x' i in
    f i a
  done


let iter f x =
  let x' = flatten x |> array1_of_genarray in
  for i = 0 to Array1.dim x' - 1 do
    let a = Array1.unsafe_get x' i in
    f a
  done


(* Flat indices of the elements for which [f index element] holds. *)
let filteri f x =
  let s = Owl_utils.Stack.make () in
  iteri (fun i y -> if f i y then Owl_utils.Stack.push s i) x;
  Owl_utils.Stack.to_array s


let filter f x = filteri (fun _ y -> f y) x

(* Fill [out] so that the element at flat index i is a + i * step, where [a]
   defaults to the zero and [step] to the one of the kind of [out]. *)
let sequential_ ?a ?step ~out =
  let k = kind out in
  let a =
    match a with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let step =
    match step with
    | Some step -> step
    | None      -> Owl_const.one k
  in
  let _add = Owl_base_dense_common._add_elt k in
  let _mul = Owl_base_dense_common._mul_elt k in
  let _flt = Owl_base_dense_common._float_typ_elt k in
  mapi_ (fun i _ -> _add a (_mul (_flt (float_of_int i)) step)) out
  [@@warning "-unerasable-optional-argument"]


let sequential k ?a ?step dimension =
  let x = empty k dimension in
  sequential_ ?a ?step ~out:x;
  x


(* New array of shape [dims] filled, in flat order, from the OCaml array
   [arr]. [arr] must provide at least as many elements as the shape needs. *)
let of_array kind arr dims =
  let varr = empty kind dims in
  let flat_varr = flatten varr |> array1_of_genarray in
  let n = numel varr in
  for i = 0 to n - 1 do
    Array1.unsafe_set flat_varr i arr.(i)
  done;
  varr


(* Every element is produced by Owl_base_dense_common._uniform_elt k a b;
   [a] defaults to the zero and [b] to the one of kind [k]. *)
let uniform k ?a ?b dims =
  let a =
    match a with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let b =
    match b with
    | Some b -> b
    | None   -> Owl_const.one k
  in
  let uniform_fun = Owl_base_dense_common._uniform_elt k a b in
  let x = empty k dims in
  map_ uniform_fun x;
  x


let uniform_ ?a ?b ~out =
  let k = kind out in
  let a =
    match a with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let b =
    match b with
    | Some b -> b
    | None   -> Owl_const.one k
  in
  let uniform_fun = Owl_base_dense_common._uniform_elt k a b in
  map_ uniform_fun out
  [@@warning "-unerasable-optional-argument"]


(* Every element is an independent Owl_base_stats.bernoulli_rvs ~p draw,
   converted to kind [k]. *)
let bernoulli k ?(p = 0.5) dims =
  let bernoulli_fun _ =
    let a = Owl_base_stats.bernoulli_rvs ~p in
    Owl_base_dense_common._float_typ_elt k a
  in
  let x = empty k dims in
  map_ bernoulli_fun x;
  x


let bernoulli_ ?(p = 0.5) ~out =
  let k = kind out in
  let bernoulli_fun _ =
    let a = Owl_base_stats.bernoulli_rvs ~p in
    Owl_base_dense_common._float_typ_elt k a
  in
  map_ bernoulli_fun out
  [@@warning "-unerasable-optional-argument"]


(* Every element is produced by Owl_base_dense_common._gaussian_elt k mu sigma;
   [mu] defaults to the zero and [sigma] to the one of kind [k]. *)
let gaussian k ?mu ?sigma dims =
  let mu =
    match mu with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let sigma =
    match sigma with
    | Some a -> a
    | None   -> Owl_const.one k
  in
  let gaussian_fun = Owl_base_dense_common._gaussian_elt k mu sigma in
  let x = empty k dims in
  map_ gaussian_fun x;
  x


let gaussian_ ?mu ?sigma ~out =
  let k = kind out in
  let mu =
    match mu with
    | Some a -> a
    | None   -> Owl_const.zero k
  in
  let sigma =
    match sigma with
    | Some a -> a
    | None   -> Owl_const.one k
  in
  let gaussian_fun = Owl_base_dense_common._gaussian_elt k mu sigma in
  map_ gaussian_fun out
  [@@warning "-unerasable-optional-argument"]


(* Pretty-print [x] through Owl_pretty.print_dsnda. By default all rows and
   all columns are shown, a column being the innermost dimension. *)
let print ?max_row ?max_col ?header ?fmt x =
  let dims = shape x in
  let rank = Array.length dims in
  (* a rank-0 array has no innermost dimension: treat it as one column *)
  let n = if rank = 0 then 1 else dims.(rank - 1) in
  (* avoid dividing by a zero-width innermost dimension *)
  let rows = if n = 0 then 0 else numel x / n in
  let max_row =
    match max_row with
    | Some a -> Some a
    | None   -> Some rows
  in
  let max_col =
    match max_col with
    | Some a -> Some a
    | None   -> Some n
  in
  Owl_pretty.print_dsnda ?max_row ?max_col ?header ?elt_to_str_fun:fmt x


(* TODO: optimise *)
let tile varr reps =
  (* First ensure len(reps) = num_dims(varr) *)
  let dims = shape varr in
  let result_rank = Stdlib.max (Array.length dims) (Array.length reps) in
  let dims = _prepend_dims dims result_rank in
  let reps = _prepend_dims reps result_rank in
  let varr = reshape varr dims in
  (* now len(reps) = num_dims(varr) *)
  let result_dims = Array.map2 (fun a b -> a * b) dims reps in
  let result_varr = empty (kind varr) result_dims in
  let result_ind = Array.make result_rank 0 in
  let original_ind = Array.make result_rank 0 in
  (* an empty result needs no copying, and entering the loop could take a
     remainder modulo a zero extent *)
  let should_stop = ref (Array.exists (fun n -> n = 0) result_dims) in
  while not !should_stop do
    for i = 0 to result_rank - 1 do
      original_ind.(i) <- Stdlib.( mod ) result_ind.(i) dims.(i)
    done;
    Genarray.set result_varr result_ind (Genarray.get varr original_ind);
    if not (_next_index result_ind result_dims) then should_stop := true
  done;
  result_varr


(* TODO: optimise *)
(* Cut [varr] along [axis] into consecutive pieces whose extents along that
   axis are given by [parts]; each piece is a copy made by get_slice. *)
let split ?(axis = 0) parts varr =
  let dims = shape varr in
  let rank = Array.length dims in
  let pos = ref 0 in
  let axis_indices =
    Array.map
      (fun d ->
        pos := !pos + d;
        [ !pos - d; !pos - 1 ])
      parts
  in
  let slices_defs =
    Array.map
      (fun ind -> Array.to_list (Array.init rank (fun i -> if i = axis then ind else [])))
      axis_indices
  in
  Array.map (fun def -> get_slice def varr) slices_defs


(* Drop the dimensions of extent 1 listed in [axis], or all of them when
   [axis] is empty. *)
let squeeze ?(axis = [||]) x =
  let a =
    match Array.length axis with
    | 0 -> Array.init (num_dims x) (fun i -> i)
    | _ -> axis
  in
  let s = Owl_utils.Array.filteri (fun i v -> not (v = 1 && Array.mem i a)) (shape x) in
  reshape x s


(* Pad the shape of [x] with ones up to rank [d], on the right when [hi] is
   true and on the left otherwise; [x] is returned unchanged when its rank is
   already at least [d]. *)
let expand ?(hi = false) x d =
  let d0 = d - num_dims x in
  if d0 > 0
  then
    if hi
    then Owl_utils.Array.pad `Right 1 d0 (shape x) |> reshape x
    else Owl_utils.Array.pad `Left 1 d0 (shape x) |> reshape x
  else x


(* TODO : ensure this is desired behaviour *)
(* Similar to draw rows for matrices
*)
(* Pick [count] distinct positions along [axis] at random and return the
   sub-array made of them together with the chosen positions.
   The copy is driven by explicit index tables rather than by get_slice:
   handed to get_slice, a list of 0, 2 or 3 sampled positions would be read as
   every index / [start; stop] / [start; stop; step]. *)
let draw ?(axis = 0) varr count =
  let dims = shape varr in
  let rank = Array.length dims in
  let indices = _draw_int_samples false dims.(axis) count in
  let index_array =
    Array.init rank (fun i ->
      if i = axis then indices else Array.init dims.(i) (fun j -> j))
  in
  let slice_dims = Array.map Array.length index_array in
  let slice_varr = empty (kind varr) slice_dims in
  let slice_ind = Array.make rank 0 in
  let original_ind = Array.make rank 0 in
  (* skip the copy entirely when the selection is empty *)
  let pending = ref (Array.for_all (fun n -> n > 0) slice_dims) in
  while !pending do
    for i = 0 to rank - 1 do
      original_ind.(i) <- index_array.(i).(slice_ind.(i))
    done;
    Genarray.set slice_varr slice_ind (Genarray.get varr original_ind);
    pending := _next_index slice_ind slice_dims
  done;
  slice_varr, indices


(* Normalise a padding spec [d] against shape [s]: missing trailing entries
   and empty entries become [|0; 0|], a single value x becomes [|x; x|]. *)
let _expand_padding_index d s =
  let ls = Array.length s in
  let ld = Array.length d in
  let d = Owl_utils.Array.pad `Right [| 0; 0 |] (ls - ld) d in
  Array.map
    (function
      | [||]    -> [| 0; 0 |]
      | [| x |] -> [| x; x |]
      | x       -> x)
    d


(* Recursive worker of pad. Walks dimensions d0 .. d1 - 1 of the source shape
   [s0], keeping the source index [i0] and the destination index [i1] (shifted
   by the leading padding from [p1]) in step, and blits one contiguous block
   of ls.(d0) elements from [x0] to [x1] once depth [d1] is reached. *)
let rec _copy_to_padding p1 ls l0 l1 i0 i1 d0 d1 s0 s1 x0 x1 =
  if d0 < d1
  then
    for i = 0 to s0.(d0) - 1 do
      i0.(d0) <- i;
      i1.(d0) <- i + p1.(d0).(0);
      _copy_to_padding p1 ls l0 l1 i0 i1 (d0 + 1) d1 s0 s1 x0 x1;
      (* restore both indices for the caller's next iteration *)
      i0.(d0) <- 0;
      i1.(d0) <- p1.(d0).(0)
    done
  else (
    let j0 = Owl_utils.index_nd_1d i0 l0 in
    let j1 = Owl_utils.index_nd_1d i1 l1 in
    let subx = Genarray.sub_left x0 j0 ls.(d0) in
    let suby = Genarray.sub_left x1 j1 ls.(d0) in
    Genarray.blit subx suby)


(* Highest dimension whose padding differs from [|0; 0|]; 0 when none does,
   and -1 for an empty spec. *)
let _highest_padding_dimension p =
  let d = ref (Array.length p - 1) in
  while !d > 0 && p.(!d) = [| 0; 0 |] do
    d := !d - 1
  done;
  !d


(* Surround [x] with the value [v] (default: the zero of its kind).
   [d] lists the padding per dimension; it is converted with
   Owl_utils.llss2aarr and normalised by _expand_padding_index. *)
let pad ?v d x =
  let k = kind x in
  let v =
    match v with
    | Some v -> v
    | None   -> Owl_const.zero k
  in
  let s0 = shape x in
  let x' = flatten x in
  let p1 = _expand_padding_index (Owl_utils.llss2aarr d) s0 in
  (* padded shape: extent plus leading and trailing padding *)
  let s1 = Array.map2 (fun m n -> m + n.(0) + n.(1)) s0 p1 in
  let s' = Owl_utils_array.fold_right ( * ) s1 1 in
  let y' = create k [| s' |] v in
  let ls = Owl_utils.calc_slice s0 in
  let l0 = Owl_utils.calc_stride s0 in
  let l1 = Owl_utils.calc_stride s1 in
  let i0 = Array.make (num_dims x) 0 in
  let i1 = Array.map (fun a -> a.(0)) p1 in
  let d0 = 0 in
  let d1 = _highest_padding_dimension p1 in
  _copy_to_padding p1 ls l0 l1 i0 i1 d0 d1 s0 s1 x' y';
  reshape y' s1


(* TODO: optimise?
*)
(* Join the arrays of [varrs] along [axis]. The result takes its kind and its
   extents on every other axis from the first array; each input is viewed as
   a matrix with one row per combination of the leading indices, and the rows
   are laid side by side in the result. *)
let concatenate ?(axis = 0) varrs =
  let count = Array.length varrs in
  let shapes = Array.map shape varrs in
  let first = shapes.(0) in
  (* extents before the axis, summed extent along it, extents after it *)
  let leading = Array.sub first 0 axis in
  let axis_total = Array.fold_left (fun acc s -> acc + s.(axis)) 0 shapes in
  let trailing = Array.sub first (axis + 1) (Array.length first - axis - 1) in
  let result = empty (kind varrs.(0)) (Array.concat [ leading; [| axis_total |]; trailing ]) in
  let outer = Array.fold_left ( * ) 1 leading in
  let inner = Array.fold_left ( * ) 1 trailing in
  (* view an array as [outer, extent along axis * inner] *)
  let as_rows a =
    let shp = shape a in
    reshape a [| outer; shp.(axis) * inner |]
  in
  let result_rows = as_rows result in
  let source_rows = Array.map as_rows varrs in
  for row = 0 to outer - 1 do
    let offset = ref 0 in
    let dst_row = Genarray.slice_left result_rows [| row |] in
    for j = 0 to count - 1 do
      let src_row = Genarray.slice_left source_rows.(j) [| row |] in
      let width = shapes.(j).(axis) * inner in
      let dst_block = Genarray.sub_left dst_row !offset width in
      Genarray.blit src_row dst_block;
      offset := !offset + width
    done
  done;
  result


(* Stack equally-shaped arrays along a new axis: every input gets an extra
   dimension of extent 1 at [axis] (normalised by Owl_utils.adjust_index) and
   the results are concatenated along it. Fails when the shapes differ. *)
let stack ?(axis = 0) xs =
  let base_shape = shape xs.(0) in
  let ndim = Array.length base_shape + 1 in
  let axis = Owl_utils.adjust_index axis ndim in
  let expanded_shape =
    Array.init ndim (fun i ->
      if i < axis then base_shape.(i) else if i = axis then 1 else base_shape.(i - 1))
  in
  let expand_one x =
    if shape x <> base_shape
    then failwith "stack: ndarrays in [xs] must all have the same shape";
    reshape x expanded_shape
  in
  let expanded = Array.map expand_one xs in
  concatenate ~axis expanded


(* TODO: is there a more efficient way to do copy?
*)
(* Repeat every element of [x] reps.(d) times along each dimension d.
   [reps] needs one entry >= 1 per dimension. The result is a fresh array
   whose extents are those of [x] multiplied by [reps]. *)
let repeat x reps =
  (* check the validity of reps *)
  if Array.exists (( > ) 1) reps then failwith "repeat: repetition must be >= 1";
  let x_dims = num_dims x in
  assert (Array.length reps = x_dims);
  if Array.for_all (( = ) 1) reps
  then copy x
  else (
    let _kind = kind x in
    let x' = flatten x in
    let x_shape = shape x in
    let y_shape = Array.map2 ( * ) x_shape reps in
    let num = Owl_utils_array.fold_right ( * ) y_shape 1 in
    let y' = empty _kind [| num |] in
    if x_dims = 1
    then (
      let ofsy = ref 0 in
      for i = 0 to numel x - 1 do
        let elemx = get x' [| i |] in
        for _j = 0 to reps.(0) - 1 do
          set y' [| !ofsy |] elemx;
          ofsy := !ofsy + 1
        done
      done)
    else (
      let highest_dim = x_dims - 1 in
      let slice_x = Owl_utils.calc_slice x_shape in
      let stride_y = Owl_utils.calc_stride y_shape in
      (* skip the trailing dimensions that are not repeated *)
      let hd = ref (highest_dim + 1) in
      while !hd > 1 && reps.(!hd - 1) = 1 do
        hd := !hd - 1
      done;
      let hd = if !hd = highest_dim + 1 then highest_dim else !hd in
      (* Copy the HD dimension from x to y *)
      let block_num = Array.make hd 0 in
      for i = 0 to hd - 1 do
        block_num.(i) <- slice_x.(i) / slice_x.(hd)
      done;
      let counter = Array.make hd 0 in
      let ofsx = ref 0 in
      let ofsy = ref 0 in
      let block_sz = reps.(hd) in
      for _i = 0 to block_num.(0) - 1 do
        let ofsy_sub = ref !ofsy in
        if block_sz = 1
        then (
          let subx = Genarray.sub_left x' !ofsx slice_x.(hd) in
          let suby = Genarray.sub_left y' !ofsy_sub slice_x.(hd) in
          Genarray.blit subx suby)
        else
          for j = 0 to slice_x.(hd) - 1 do
            let elemx = get x' [| !ofsx + j |] in
            for k = 0 to block_sz - 1 do
              set y' [| !ofsy_sub + k |] elemx
            done;
            ofsy_sub := !ofsy_sub + block_sz
          done;
        ofsx := !ofsx + slice_x.(hd);
        ofsy := !ofsy + (stride_y.(hd - 1) * reps.(hd - 1));
        for j = hd - 1 downto 1 do
          let c = counter.(j) in
          if c + 1 = block_num.(j)
          then ofsy := !ofsy + (stride_y.(j - 1) * (reps.(j - 1) - 1));
          counter.(j) <- (if c + 1 = block_num.(j) then 0 else c + 1)
        done
      done;
      (* Copy the lower dimensions within y *)
      for d = hd - 1 downto 0 do
        let block_num = Array.make (d + 1) 0 in
        for i = 0 to d do
          block_num.(i) <- slice_x.(i) / slice_x.(d + 1)
        done;
        let ofsy = ref 0 in
        let block_sz = stride_y.(d) in
        let counter = Array.make hd 0 in
        for _i = 0 to block_num.(0) - 1 do
          let ofsy_sub = ref (!ofsy + block_sz) in
          for _j = 1 to reps.(d) - 1 do
            let subx = Genarray.sub_left y' !ofsy block_sz in
            let suby = Genarray.sub_left y' !ofsy_sub block_sz in
            Genarray.blit subx suby;
            ofsy_sub := !ofsy_sub + block_sz
          done;
          ofsy := !ofsy + (stride_y.(d) * reps.(d));
          for j = d - 1 downto 0 do
            let c = counter.(j) in
            if c + 1 = block_num.(j + 1)
            then ofsy := !ofsy + (stride_y.(j) * (reps.(j) - 1));
            counter.(j) <- (if c + 1 = block_num.(j + 1) then 0 else c + 1)
          done
        done
      done);
    reshape y' y_shape)


(* mathematical functions *)

(* Shared driver of the in-place unary operations below: applies [f] to every
   element of [out] when it is supplied, and of [x] itself otherwise.
   NOTE(review): when [out] is supplied the contents of [x] are never read, so
   the values already stored in [out] are transformed - confirm that callers
   expect this rather than the image of [x] written into [out]. *)
let _unary_inplace ?out f x =
  let target =
    match out with
    | Some o -> o
    | None   -> x
  in
  map_ f target


(* Each operation comes as a copying version and an in-place version with a
   trailing underscore; the element function is picked from
   Owl_base_dense_common according to the kind of [x]. *)

let abs x = map (Owl_base_dense_common._abs_elt (kind x)) x

let abs_ ?out x = _unary_inplace ?out (Owl_base_dense_common._abs_elt (kind x)) x

let conj x = map (Owl_base_dense_common._conj_elt (kind x)) x

(* NOTE(review): unlike its siblings this calls [map], not [map_]: it returns
   a fresh array and leaves both [out] and [x] untouched. Kept as is because
   switching to [map_] would change the return type - confirm the intent. *)
let conj_ ?out x =
  let _func = Owl_base_dense_common._conj_elt (kind x) in
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  map _func out


let neg x = map (Owl_base_dense_common._neg_elt (kind x)) x

let neg_ ?out x = _unary_inplace ?out (Owl_base_dense_common._neg_elt (kind x)) x

let reci x = map (Owl_base_dense_common._inv_elt (kind x)) x

let reci_ ?out x = _unary_inplace ?out (Owl_base_dense_common._inv_elt (kind x)) x

let floor x = map (Owl_base_dense_common._floor_elt (kind x)) x

let floor_ ?out x = _unary_inplace ?out (Owl_base_dense_common._floor_elt (kind x)) x

let ceil x = map (Owl_base_dense_common._ceil_elt (kind x)) x

let ceil_ ?out x = _unary_inplace ?out (Owl_base_dense_common._ceil_elt (kind x)) x

let round x = map (Owl_base_dense_common._round_elt (kind x)) x

let round_ ?out x = _unary_inplace ?out (Owl_base_dense_common._round_elt (kind x)) x

let trunc x = map (Owl_base_dense_common._trunc_elt (kind x)) x

let trunc_ ?out x = _unary_inplace ?out (Owl_base_dense_common._trunc_elt (kind x)) x

let fix x = map (Owl_base_dense_common._fix_elt (kind x)) x

let fix_ ?out x = _unary_inplace ?out (Owl_base_dense_common._fix_elt (kind x)) x

let erf _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.erf")

let erf_ ?_out _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.erf_")

let erfc _x = raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.erfc")

let erfc_ ?_out _x =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.erfc_")


let sqr x = map (Owl_base_dense_common._sqr_elt (kind x)) x

let sqr_ ?out x = _unary_inplace ?out (Owl_base_dense_common._sqr_elt (kind x)) x

let sqrt x = map (Owl_base_dense_common._sqrt_elt (kind x)) x

let sqrt_ ?out x = _unary_inplace ?out (Owl_base_dense_common._sqrt_elt (kind x)) x

(* Cube root, computed as a power with exponent 1/3. *)
let cbrt x =
  let _kind = kind x in
  let b = Owl_base_dense_common._float_typ_elt _kind (1. /. 3.) in
  let _func a = Owl_base_dense_common._pow_elt _kind a b in
  map _func x


let cbrt_ ?out x =
  let _kind = kind x in
  let b = Owl_base_dense_common._float_typ_elt _kind (1. /. 3.) in
  let _func a = Owl_base_dense_common._pow_elt _kind a b in
  _unary_inplace ?out _func x


let log x = map (Owl_base_dense_common._log_elt (kind x)) x

(* Uses the kind-specific element function, like [log] and every other
   in-place operation here, instead of the float-only Scalar.log. *)
let log_ ?out x = _unary_inplace ?out (Owl_base_dense_common._log_elt (kind x)) x

let log2 x = map (Owl_base_dense_common._log2_elt (kind x)) x

let log2_ ?out x = _unary_inplace ?out (Owl_base_dense_common._log2_elt (kind x)) x

let log10 x = map (Owl_base_dense_common._log10_elt (kind x)) x

let log10_ ?out x = _unary_inplace ?out (Owl_base_dense_common._log10_elt (kind x)) x

let log1p x = map (Owl_base_dense_common._log1p_elt (kind x)) x

let log1p_ ?out x = _unary_inplace ?out (Owl_base_dense_common._log1p_elt (kind x)) x

let exp x = map (Owl_base_dense_common._exp_elt (kind x)) x

let exp_ ?out x = _unary_inplace ?out (Owl_base_dense_common._exp_elt (kind x)) x

let exp2 x = map (Owl_base_dense_common._exp2_elt (kind x)) x

let exp2_ ?out x = _unary_inplace ?out (Owl_base_dense_common._exp2_elt (kind x)) x

let exp10 x = map (Owl_base_dense_common._exp10_elt (kind x)) x

let exp10_ ?out x = _unary_inplace ?out (Owl_base_dense_common._exp10_elt (kind x)) x

let expm1 x = map (Owl_base_dense_common._expm1_elt (kind x)) x

let expm1_ ?out x = _unary_inplace ?out (Owl_base_dense_common._expm1_elt (kind x)) x

let sin x = map (Owl_base_dense_common._sin_elt (kind x)) x

let sin_ ?out x = _unary_inplace ?out (Owl_base_dense_common._sin_elt (kind x)) x

let cos x = map (Owl_base_dense_common._cos_elt (kind x)) x

let cos_ ?out x = _unary_inplace ?out (Owl_base_dense_common._cos_elt (kind x)) x

let tan x = map (Owl_base_dense_common._tan_elt (kind x)) x

let tan_ ?out x = _unary_inplace ?out (Owl_base_dense_common._tan_elt (kind x)) x

let sinh x = map (Owl_base_dense_common._sinh_elt (kind x)) x

let sinh_ ?out x = _unary_inplace ?out (Owl_base_dense_common._sinh_elt (kind x)) x

let cosh x = map (Owl_base_dense_common._cosh_elt (kind x)) x

let cosh_ ?out x = _unary_inplace ?out (Owl_base_dense_common._cosh_elt (kind x)) x

let tanh x = map (Owl_base_dense_common._tanh_elt (kind x)) x

let tanh_ ?out x = _unary_inplace ?out (Owl_base_dense_common._tanh_elt (kind x)) x

let asin x = map (Owl_base_dense_common._asin_elt (kind x)) x

let asin_ ?out x = _unary_inplace ?out (Owl_base_dense_common._asin_elt (kind x)) x

let acos x = map (Owl_base_dense_common._acos_elt (kind x)) x

let acos_ ?out x = _unary_inplace ?out (Owl_base_dense_common._acos_elt (kind x)) x

let atan x = map (Owl_base_dense_common._atan_elt (kind x)) x

let atan_ ?out x = _unary_inplace ?out (Owl_base_dense_common._atan_elt (kind x)) x

let asinh x = map (Owl_base_dense_common._asinh_elt (kind x)) x

let asinh_ ?out x = _unary_inplace ?out (Owl_base_dense_common._asinh_elt (kind x)) x

let acosh x = map (Owl_base_dense_common._acosh_elt (kind x)) x

let acosh_ ?out x = _unary_inplace ?out (Owl_base_dense_common._acosh_elt (kind x)) x

let atanh x = map (Owl_base_dense_common._atanh_elt (kind x)) x

let atanh_ ?out x = _unary_inplace ?out (Owl_base_dense_common._atanh_elt (kind x)) x

(* Sum a float array over its dimensions 0 .. axis, leaving an array shaped
   like the remaining trailing dimensions. *)
let sum_slices ?(axis = 0) varr =
  let dims = shape varr in
  let rank = Array.length dims in
  (* reshape into 2d matrix *)
  let num_rows = Array.fold_left ( * ) 1 (Array.sub dims 0 (axis + 1)) in
  let num_cols = numel varr / num_rows in
  let varr_mat = reshape varr [| num_rows; num_cols |] in
  let result_vec = empty (kind varr) [| num_cols |] in
  (* same buffer as result_vec, seen with the shape of the trailing dimensions *)
  let result_varr = reshape result_vec (Array.sub dims (axis + 1) (rank - axis - 1)) in
  let row_sum = ref 0. in
  for j = 0 to num_cols - 1 do
    row_sum := 0.;
    for i = 0 to num_rows - 1 do
      row_sum := !row_sum +. Genarray.get varr_mat [| i; j |]
    done;
    Genarray.set result_vec [| j |] !row_sum
  done;
  result_varr


(* -1. for negative numbers, 0 or (-0) for 0,
1 for positive numbers, nan for nan*)
let signum x = map Scalar.signum x

(* The in-place unary operations below transform [out] when it is supplied
   and [x] itself otherwise.
   NOTE(review): with [out] supplied the contents of [x] are never read -
   confirm that this is what callers expect. *)
let signum_ ?out x =
  let dst =
    match out with
    | Some o -> o
    | None   -> x
  in
  map_ Scalar.signum dst


(* Apply 1 / (1 + exp (-x)) for each element x *)
let sigmoid x = map Scalar.sigmoid x

let sigmoid_ ?out x =
  let dst =
    match out with
    | Some o -> o
    | None   -> x
  in
  map_ Scalar.sigmoid dst


let relu x = map Scalar.relu x

let relu_ ?out x =
  let dst =
    match out with
    | Some o -> o
    | None   -> x
  in
  map_ Scalar.relu dst


let dawsn x = map Scalar.dawsn x

let softsign x = map Scalar.softsign x

let softsign_ ?out x =
  let dst =
    match out with
    | Some o -> o
    | None   -> x
  in
  map_ Scalar.softsign dst


let softplus x = map Scalar.softplus x

let softplus_ ?out x =
  let dst =
    match out with
    | Some o -> o
    | None   -> x
  in
  map_ Scalar.softplus dst


(* Left fold of [f] over the elements of [varr] in flat order, from [a]. *)
let _fold_left f a varr =
  let aref = ref a in
  let varr_linear = flatten varr |> array1_of_genarray in
  let length = numel varr in
  for i = 0 to length - 1 do
    aref := f !aref (Array1.unsafe_get varr_linear i)
  done;
  !aref


(* Min of all elements in the NDarray *)
let min' x =
  let _kind = kind x in
  let _max_val = Owl_base_dense_common._max_val_elt _kind in
  _fold_left (Owl_base_dense_common._min_elt _kind) _max_val x


(* Max of all elements in the NDarray *)
let max' x =
  let _kind = kind x in
  let _min_val = Owl_base_dense_common._min_val_elt _kind in
  _fold_left (Owl_base_dense_common._max_elt _kind) _min_val x


(* Sum of all elements *)
let sum' x =
  let _kind = kind x in
  _fold_left (Owl_base_dense_common._add_elt _kind) (Owl_const.zero _kind) x


(* log sum of exp all elements *)
let log_sum_exp' _ = raise (Owl_exception.NOT_IMPLEMENTED "base ndarray: log_sum_exp'")

(* log sum of exp all elements *)
let log_sum_exp ?axis:_ ?(keep_dims = true) _ =
  ignore keep_dims;
  raise (Owl_exception.NOT_IMPLEMENTED "base ndarray: log_sum_exp")


(* Folding along a specified axis, aka reduction. The
   f: function of type 'a -> 'a -> 'a.
   m: number of slices.
   n: x's slice size.
   o: x's strides, also y's slice size.
   x: source; y: shape of destination. Note that o <= n.
   When [out] is given it is used as the accumulator as it stands; otherwise
   a fresh array of shape [ys] filled with [nelem] is allocated. *)
let _fold_along ?out f m n o x ys nelem =
  let x = flatten x in
  let y =
    match out with
    | Some o -> o |> flatten
    | None   -> create (kind x) ys nelem |> flatten
  in
  let idx = ref 0 in
  let idy = ref 0 in
  let incy = ref 0 in
  for _i = 0 to m - 1 do
    for j = 0 to n - 1 do
      let addon = Genarray.get x [| !idx + j |] in
      let orig = Genarray.get y [| !idy + !incy |] in
      Genarray.set y [| !idy + !incy |] (f orig addon);
      (* the destination offset cycles through 0 .. o - 1 *)
      incy := if !incy + 1 = o then 0 else !incy + 1
    done;
    idx := !idx + n;
    idy := !idy + o
  done;
  reshape y ys


(* Sum along [axis], or over everything (one-element result) without one.
   With [keep_dims] false the reduced axis is squeezed away. *)
let sum ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  let zero = Owl_const.zero _kind in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let _op = Owl_base_dense_common._add_elt _kind in
    let x = _fold_along _op m n o x s zero in
    if keep_dims then x else squeeze ~axis:[| a |] x
  | None   -> create (kind x) (Array.make 1 1) (sum' x)


(* In-place sum: [out] is zeroed, then receives the sum along [axis]; without
   an axis only its first element is set, to the overall sum. *)
let sum_ ~out ~axis x =
  let _kind = kind x in
  let zero = Owl_const.zero _kind in
  Genarray.fill out zero;
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let _op = Owl_base_dense_common._add_elt _kind in
    _fold_along _op ~out m n o x s zero |> ignore
  | None   ->
    let y = flatten out in
    set y [| 0 |] (sum' x)


(* Sum over several axes at once; the reduced axes are kept with extent 1.
   The shape is first regrouped by Owl_utils.squeeze_continuous_dims, then
   every other group is folded, starting with the first one when axis 0 is
   among those to reduce. *)
let sum_reduce ?axis x =
  let _kind = kind x in
  let _dims = num_dims x in
  let zero = Owl_const.zero _kind in
  match axis with
  | Some a ->
    let x_shape = shape x in
    let dims' = Owl_utils.squeeze_continuous_dims x_shape a in
    if Array.length dims' = 1
    then create (kind x) (Array.make _dims 1) (sum' x)
    else (
      let y = ref (reshape x dims') in
      let flag = ref (Array.mem 0 a) in
      for i = 0 to Array.length dims' - 1 do
        if !flag
        then (
          let m, n, o, s = Owl_utils.reduce_params i !y in
          y := _fold_along (Owl_base_dense_common._add_elt _kind) m n o !y s zero);
        flag := not !flag
      done;
      let y_shape = Array.copy x_shape in
      Array.iter (fun j -> y_shape.(j) <- 1) a;
      reshape !y y_shape)
  | None   -> create (kind x) (Array.make _dims 1) (sum' x)


(* Minimum along [axis], or over everything (one-element result). *)
let min ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  let max_val = Owl_base_dense_common._max_val_elt _kind in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let x = _fold_along (Owl_base_dense_common._min_elt _kind) m n o x s max_val in
    if keep_dims then x else squeeze ~axis:[| a |] x
  | None   -> min' x |> create _kind [| 1 |]


let min_ ~out ~axis x =
  let _kind = kind x in
  let max_val = Owl_base_dense_common._max_val_elt _kind in
  Genarray.fill out max_val;
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let _op = Owl_base_dense_common._min_elt _kind in
    _fold_along ~out _op m n o x s max_val |> ignore
  | None   ->
    let y = flatten out in
    set y [| 0 |] (min' x)


(* Maximum along [axis], or over everything (one-element result). *)
let max ?axis ?(keep_dims = true) x =
  let _kind = kind x in
  let min_val = Owl_base_dense_common._min_val_elt _kind in
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    let x = _fold_along (Owl_base_dense_common._max_elt _kind) m n o x s min_val in
    if keep_dims then x else squeeze ~axis:[| a |] x
  | None   -> max' x |> create _kind [| 1 |]


let max_ ~out ~axis x =
  let _kind = kind x in
  let min_val = Owl_base_dense_common._min_val_elt _kind in
  Genarray.fill out min_val;
  match axis with
  | Some a ->
    let m, n, o, s = Owl_utils.reduce_params a x in
    _fold_along ~out (Owl_base_dense_common._max_elt _kind) m n o x s min_val |> ignore
  | None   ->
    let y = flatten out in
    set y [| 0 |] (max' x)


(* Sum of absolute values (float arrays). *)
let l1norm' varr =
  let l1norm_fun aggregate elem = aggregate +. Scalar.abs elem in
  _fold_left l1norm_fun 0. varr


(* Sum of squares (float arrays). *)
let l2norm_sqr' varr =
  let l2norm_sqr_fun aggregate elem = aggregate +. (elem *. elem) in
  _fold_left l2norm_sqr_fun 0. varr


let l2norm' varr =
  let l2norm_sqr_val = l2norm_sqr' varr in
  Scalar.sqrt l2norm_sqr_val


(* Apply the binary [op_fun] element-wise to [varr_a] and [varr_b] after
   broadcasting them to a common shape. The result goes into [out] when it is
   given, otherwise into a fresh array of the kind of [varr_a]; either way the
   array holding the result is returned. *)
let _broadcasted_op ?out varr_a varr_b op_fun =
  let dims_a, dims_b, dims_c = _get_broadcasted_dims (shape varr_a) (shape varr_b) in
  let _kind = kind varr_a in
  let varr_a = reshape varr_a dims_a in
  let varr_b = reshape varr_b dims_b in
  let varr_c =
    match out with
    | Some out -> out
    | None     -> empty _kind dims_c
  in
  let ind = Array.make (Array.length dims_c) 0 in
  let should_stop = ref false in
  while not !should_stop do
    let ind_a = _get_broadcasted_index ind dims_a in
    let ind_b = _get_broadcasted_index ind dims_b in
    Genarray.set varr_c ind (op_fun (Genarray.get varr_a ind_a) (Genarray.get varr_b ind_b));
    if not (_next_index ind dims_c) then should_stop := true
  done;
  varr_c


(* Shared driver of the in-place binary operations: the destination is [out],
   or [x] when [out] is absent. Its shape must equal the broadcast shape of
   [x] and [y], otherwise Owl_exception.DIFFERENT_SHAPE is reported through
   Owl_exception.check. The result of [op] is then written into it. *)
let _broadcasted_inplace ?out x y op =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let so = Owl_utils_infer_shape.broadcast1 (shape x) (shape y) in
  let st = shape out in
  Owl_exception.check (so = st) (Owl_exception.DIFFERENT_SHAPE (so, st));
  _broadcasted_op ~out x y op |> ignore


let add x y = _broadcasted_op x y (Owl_base_dense_common._add_elt (kind x))

let add_ ?out x y = _broadcasted_inplace ?out x y (Owl_base_dense_common._add_elt (kind x))

let sub x y = _broadcasted_op x y (Owl_base_dense_common._sub_elt (kind x))

let sub_ ?out x y = _broadcasted_inplace ?out x y (Owl_base_dense_common._sub_elt (kind x))

let mul x y = _broadcasted_op x y (Owl_base_dense_common._mul_elt (kind x))

let mul_ ?out x y = _broadcasted_inplace ?out x y (Owl_base_dense_common._mul_elt (kind x))

let div x y = _broadcasted_op x y (Owl_base_dense_common._div_elt (kind x))

let div_ ?out x y = _broadcasted_inplace ?out x y (Owl_base_dense_common._div_elt (kind x))

let atan2 x y = _broadcasted_op x y Scalar.atan2

(* The result is written into the destination, as for the other in-place
   binary operations, instead of being computed and thrown away. *)
let atan2_ ?out x y = _broadcasted_inplace ?out x y Scalar.atan2

let hypot x y = _broadcasted_op x y Scalar.hypot

(* Writes into the destination; see atan2_. *)
let hypot_ ?out x y = _broadcasted_inplace ?out x y Scalar.hypot

let pow x y = _broadcasted_op x y (Owl_base_dense_common._pow_elt (kind x))

let pow_ ?out x y = _broadcasted_inplace ?out x y (Owl_base_dense_common._pow_elt (kind x))

let fmod x y = _broadcasted_op x y Scalar.fmod

(* Writes into the destination; see atan2_. *)
let fmod_ ?out x y = _broadcasted_inplace ?out x y Scalar.fmod

let min2 x y = _broadcasted_op x y (Owl_base_dense_common._min_elt (kind x))

let min2_ ?out x y = _broadcasted_inplace ?out x y (Owl_base_dense_common._min_elt (kind x))

let max2 x y = _broadcasted_op x y (Owl_base_dense_common._max_elt (kind x))

let max2_ ?out x y = _broadcasted_inplace ?out x y (Owl_base_dense_common._max_elt (kind x))

(* Scalar variants: the scalar [a] is always the right-hand operand. The
   in-place forms transform [out] when given and [x] otherwise (see the note
   on signum_). *)
let add_scalar x a =
  let _op = Owl_base_dense_common._add_elt (kind x) in
  map (fun y -> _op y a) x


let add_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._add_elt (kind x) in
  map_ (fun y -> _op y a) out


let sub_scalar x a =
  let _op = Owl_base_dense_common._sub_elt (kind x) in
  map (fun y -> _op y a) x


let sub_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._sub_elt (kind x) in
  map_ (fun y -> _op y a) out


let mul_scalar x a =
  let _op = Owl_base_dense_common._mul_elt (kind x) in
  map (fun y -> _op y a) x


let mul_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._mul_elt (kind x) in
  map_ (fun y -> _op y a) out


let div_scalar x a =
  let _op = Owl_base_dense_common._div_elt (kind x) in
  map (fun y -> _op y a) x


let div_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._div_elt (kind x) in
  map_ (fun y -> _op y a) out


let pow_scalar x a =
  let _op = Owl_base_dense_common._pow_elt (kind x) in
  map (fun y -> _op y a) x


let pow_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._pow_elt (kind x) in
map_ (fun y -> _op y a) out


(* [atan2_scalar x a] maps [fun y -> Scalar.atan2 y a] over [x]. *)
let atan2_scalar x a =
  let _op = Scalar.atan2 in
  map (fun y -> _op y a) x


(* NOTE(review): as with the other in-place scalar operators, [map_] is
   applied to [out] only -- confirm the intended behaviour when [out] is not
   [x]. The same remark applies to every [*_scalar_] / [scalar_*_] below. *)
let atan2_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Scalar.atan2 in
  map_ (fun y -> _op y a) out


(* [fmod_scalar x a] maps [fun y -> Scalar.fmod y a] over [x]. *)
let fmod_scalar x a =
  let _op = Scalar.fmod in
  map (fun y -> _op y a) x


let fmod_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Scalar.fmod in
  map_ (fun y -> _op y a) out


(* TODO *)
let fma _x _y _z = failwith "Owl_base_dense_ndarray_generic:fma: not implemented"


(* Scalar-with-array operators: [scalar_op a x] maps [fun y -> op a y] over
   [x], i.e. the scalar is the left operand. *)

let scalar_add a x =
  let _op = Owl_base_dense_common._add_elt (kind x) in
  map (fun y -> _op a y) x


let scalar_add_ ?out a x =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._add_elt (kind x) in
  map_ (fun y -> _op a y) out


let scalar_sub a x =
  let _op = Owl_base_dense_common._sub_elt (kind x) in
  map (fun y -> _op a y) x


let scalar_sub_ ?out a x =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._sub_elt (kind x) in
  map_ (fun y -> _op a y) out


let scalar_mul a x =
  let _op = Owl_base_dense_common._mul_elt (kind x) in
  map (fun y -> _op a y) x


let scalar_mul_ ?out a x =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._mul_elt (kind x) in
  map_ (fun y -> _op a y) out


let scalar_div a x =
  let _op = Owl_base_dense_common._div_elt (kind x) in
  map (fun y -> _op a y) x


let scalar_div_ ?out a x =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._div_elt (kind x) in
  map_ (fun y -> _op a y) out


let scalar_pow a x =
  let _op = Owl_base_dense_common._pow_elt (kind x) in
  map (fun y -> _op a y) x


let scalar_pow_ ?out a x =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Owl_base_dense_common._pow_elt (kind x) in
  map_ (fun y -> _op a y) out


let scalar_atan2 a x =
  let _op = Scalar.atan2 in
  map (fun y -> _op a y) x


let scalar_atan2_ ?out a x =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Scalar.atan2 in
  map_ (fun y -> _op a y) out


let scalar_fmod a x =
  let _op = Scalar.fmod in
  map (fun y -> _op a y) x


let scalar_fmod_ ?out a x =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let _op = Scalar.fmod in
  map_ (fun y -> _op a y) out


(* [clip_by_value ?amin ?amax x] maps every element [y] of [x] to
   [min amax (max amin y)].

   The lower default used to be [Stdlib.min_float], which is the smallest
   POSITIVE normalised float, so with no [~amin] every negative element was
   silently clipped to roughly zero. The default is now [-. max_float],
   symmetric with the default upper bound. *)
let clip_by_value ?(amin = -.Stdlib.max_float) ?(amax = Stdlib.max_float) x =
  let _op y = Stdlib.min amax (Stdlib.max amin y) in
  map _op x


(* [clip_by_l2norm clip_norm x] returns [x] scaled by
   [clip_norm /. l2norm' x] when [l2norm' x] exceeds [clip_norm], and [x]
   itself (not a copy) otherwise. *)
let clip_by_l2norm clip_norm x =
  let l2norm_val = l2norm' x in
  if l2norm_val > clip_norm then mul_scalar x (clip_norm /. l2norm_val) else x


(* [softmax ?axis x] works on a copy of [x]: subtracts the maximum along
   [axis], exponentiates, then divides by the sum along [axis]. [axis]
   defaults to -1 and is normalised with [Owl_utils.adjust_index]. *)
let softmax ?(axis = -1) x =
  let x = copy x in
  let axis = Owl_utils.adjust_index axis (num_dims x) in
  sub_ ~out:x x (max ~axis x);
  exp_ ~out:x x;
  let a = sum ~axis x in
  div_ ~out:x x a;
  x


(* In-place variant of [softmax]; the result goes to [out] (default [x]).

   After the first step every intermediate value lives in [out], so the
   remaining steps must read from [out]. The previous version kept reading
   [x], which only worked when [out] was [x]; with a distinct [out] it ended
   up storing [x / sum x]. *)
let softmax_ ?out ?(axis = -1) x =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let axis = Owl_utils.adjust_index axis (num_dims x) in
  sub_ ~out x (max ~axis x);
  exp_ ~out out;
  let a = sum ~axis out in
  div_ ~out out a


(* Comparison functions *)

(** Return true if for all elements comp_fun (xa, xb) == true, false otherwise.
Returns false as soon as it finds a counterexample. (NOT broadcasted) *)
let _compare_util_shortcircuit varr_a varr_b comp_fun =
  let n = numel varr_a in
  let m = numel varr_b in
  (* arrays holding a different number of elements never compare as true *)
  if n <> m
  then false
  else (
    let flat_a = array1_of_genarray (flatten varr_a) in
    let flat_b = array1_of_genarray (flatten varr_b) in
    (* walk both flattened arrays in step, stopping at the first failure *)
    let rec holds_from i =
      i >= n
      || (comp_fun (Array1.unsafe_get flat_a i) (Array1.unsafe_get flat_b i)
         && holds_from (i + 1))
    in
    holds_from 0)


(* True when every pair of elements differs by less than [eps] in absolute
   value; [eps] defaults to [Owl_utils.eps Float32]. *)
let approx_equal ?eps varr_a varr_b =
  let tolerance =
    match eps with
    | None   -> Owl_utils.eps Float32
    | Some e -> e
  in
  _compare_util_shortcircuit varr_a varr_b (fun x y ->
      Scalar.abs (Scalar.sub x y) < tolerance)


(* Whole-array predicates: each is true only if the comparison holds for
   every pair of corresponding elements (this includes [not_equal]). *)

let equal x y = _compare_util_shortcircuit x y Stdlib.( = )

let not_equal x y = _compare_util_shortcircuit x y Stdlib.( <> )

let less x y = _compare_util_shortcircuit x y Stdlib.( < )

let greater x y = _compare_util_shortcircuit x y Stdlib.( > )

let less_equal x y = _compare_util_shortcircuit x y Stdlib.( <= )

let greater_equal x y = _compare_util_shortcircuit x y Stdlib.( >= )


(** Return true if for all elements of a comp_fun (xa, bb) == true, false otherwise.
Returns false as soon as it finds a counterexample. (NOT broadcasted) *)
let _compare_util_shortcircuit_scalar varr_a b comp_fun =
  let n = numel varr_a in
  let flat_a = array1_of_genarray (flatten varr_a) in
  (* compare each element against the scalar, stopping at the first failure *)
  let rec holds_from i =
    i >= n || (comp_fun (Array1.unsafe_get flat_a i) b && holds_from (i + 1))
  in
  holds_from 0


(* True when every element is within [eps] of the scalar [b]; [eps] defaults
   to [Owl_utils.eps Float32]. *)
let approx_equal_scalar ?eps varr_a b =
  let tolerance =
    match eps with
    | None   -> Owl_utils.eps Float32
    | Some e -> e
  in
  _compare_util_shortcircuit_scalar varr_a b (fun x y ->
      Scalar.abs (Scalar.sub x y) < tolerance)


let equal_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( = )

let not_equal_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( <> )

let less_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( < )

let greater_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( > )

let less_equal_scalar varr_a b = _compare_util_shortcircuit_scalar varr_a b Stdlib.( <= )

let greater_equal_scalar x a = _compare_util_shortcircuit_scalar x a Stdlib.( >= )


(* Broadcasted operation, return an array with values of 1
if (one_fun elem_from_a elem_from_b) == true, 0 otherwise *)
let _make_elt_compare_fun kind cmp_fun =
  let c0 = Owl_const.zero kind in
  let c1 = Owl_const.one kind in
  (* the returned closure maps a pair of elements to the kind's one or zero *)
  fun a b -> if cmp_fun a b then c1 else c0


(* Element-wise comparisons, applied through [_broadcasted_op]. The forms
   ending in an underscore pass [~out] (default [x]) and return whatever
   [_broadcasted_op] returns; unlike the arithmetic in-place operators the
   result is not discarded. *)

let elt_equal x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( = ))


let elt_equal_ ?out x y =
  let out =
    match out with
    | None   -> x
    | Some o -> o
  in
  _broadcasted_op ~out x y (_make_elt_compare_fun (kind x) Stdlib.( = ))


(* Like [elt_equal], but two elements count as equal when they differ by
   less than [eps] (default [Owl_utils.eps Float32]) in absolute value. *)
let approx_elt_equal ?eps x y =
  let tolerance =
    match eps with
    | None   -> Owl_utils.eps Float32
    | Some e -> e
  in
  let within_tolerance a b = Scalar.abs (Scalar.sub a b) < tolerance in
  _broadcasted_op x y (_make_elt_compare_fun (kind x) within_tolerance)


let elt_not_equal x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( <> ))


let elt_not_equal_ ?out x y =
  let out =
    match out with
    | None   -> x
    | Some o -> o
  in
  _broadcasted_op ~out x y (_make_elt_compare_fun (kind x) Stdlib.( <> ))


let elt_less x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( < ))


let elt_less_ ?out x y =
  let out =
    match out with
    | None   -> x
    | Some o -> o
  in
  _broadcasted_op ~out x y (_make_elt_compare_fun (kind x) Stdlib.( < ))


let elt_greater x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( > ))


let elt_greater_ ?out x y =
  let out =
    match out with
    | None   -> x
    | Some o -> o
  in
  _broadcasted_op ~out x y (_make_elt_compare_fun (kind x) Stdlib.( > ))


let elt_less_equal x y = _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( <= ))


let elt_less_equal_ ?out x y =
  let out =
    match out with
    | None   -> x
    | Some o -> o
  in
  _broadcasted_op ~out x y (_make_elt_compare_fun (kind x) Stdlib.( <= ))


let elt_greater_equal x y =
  _broadcasted_op x y (_make_elt_compare_fun (kind x) Stdlib.( >= ))


let elt_greater_equal_ ?out x y =
  let out =
    match out with
    | None   -> x
    | Some o -> o
  in
  _broadcasted_op ~out x y (_make_elt_compare_fun (kind x) Stdlib.( >= ))


(* Util function, return an array with values of 1
if (one_fun elem_from_a b) == true, 0 otherwise *)
let _make_elt_compare_scalar x cmp_fun =
  let _kind = kind x in
  let c0 = Owl_const.zero _kind in
  let c1 = Owl_const.one _kind in
  let _func a = if cmp_fun a then c1 else c0 in
  _func


(* Element-vs-scalar comparisons. The plain forms build a new array with
   [map]; the forms ending in an underscore run [map_] on [out] (default [x]).

   NOTE(review): the in-place forms read their input from [out], not from
   [x], so passing a distinct [out] compares the old contents of [out] --
   confirm that this is intended. *)

let elt_equal_scalar x a =
  let cmp_fun y = y = a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map _func x


let elt_equal_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let cmp_fun y = y = a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map_ _func out


(* Like [elt_equal_scalar], but an element matches when it is within [eps]
   (default [Owl_utils.eps Float32]) of [a]. *)
let approx_elt_equal_scalar ?eps x a =
  let eps =
    match eps with
    | Some eps -> eps
    | None     -> Owl_utils.eps Float32
  in
  let cmp_fun y = Scalar.abs (Scalar.sub y a) < eps in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map _func x


let elt_not_equal_scalar x a =
  let cmp_fun y = y <> a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map _func x


let elt_not_equal_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let cmp_fun y = y <> a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map_ _func out


let elt_less_scalar x a =
  let cmp_fun y = y < a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map _func x


let elt_less_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let cmp_fun y = y < a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map_ _func out


let elt_greater_scalar x a =
  let cmp_fun y = y > a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map _func x


let elt_greater_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let cmp_fun y = y > a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map_ _func out


let elt_less_equal_scalar x a =
  let cmp_fun y = y <= a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map _func x


let elt_less_equal_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let cmp_fun y = y <= a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map_ _func out


let elt_greater_equal_scalar x a =
  let cmp_fun y = y >= a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map _func x


(* This variant used to call [map] (allocating and returning a fresh array)
   instead of [map_]; it now updates [out] like every sibling above. *)
let elt_greater_equal_scalar_ ?out x a =
  let out =
    match out with
    | Some o -> o
    | None   -> x
  in
  let cmp_fun y = y >= a in
  let _func = _make_elt_compare_scalar x cmp_fun in
  map_ _func out


(* [exists f x] is true when [f] holds for at least one element of [x];
   scanning stops at the first element that satisfies [f]. *)
let exists f x =
  let n = numel x in
  let x = flatten x |> array1_of_genarray in
  let found = ref false in
  let i = ref 0 in
  while !i < n && not !found do
    let a = Array1.unsafe_get x !i in
    if f a then found := true;
    i := !i + 1
  done;
  !found


(* [not_exists f varr] is the negation of [exists f varr]. *)
let not_exists f varr = not (exists f varr)
(* [for_all f varr] holds when no element of [varr] falsifies [f]. *)
let for_all f varr = not_exists (fun x -> not (f x)) varr


(* The sign predicates below compare every element with the zero of the
   array's kind, and hold when no element violates the property. *)

let is_zero varr =
  let zero = Owl_const.zero (kind varr) in
  not_exists (fun x -> x <> zero) varr


let is_positive varr =
  let zero = Owl_const.zero (kind varr) in
  not_exists (fun x -> x <= zero) varr


let is_negative varr =
  let zero = Owl_const.zero (kind varr) in
  not_exists (fun x -> x >= zero) varr


let is_nonpositive varr =
  let zero = Owl_const.zero (kind varr) in
  not_exists (fun x -> x > zero) varr


let is_nonnegative varr =
  let zero = Owl_const.zero (kind varr) in
  not_exists (fun x -> x < zero) varr


(* True when [Owl_base_dense_common._is_normal_elt] accepts every element. *)
let is_normal x = for_all (Owl_base_dense_common._is_normal_elt (kind x)) x


(* True when [Owl_base_dense_common._is_nan_elt] rejects every element. *)
let not_nan x = not_exists (Owl_base_dense_common._is_nan_elt (kind x)) x


(* True when [Owl_base_dense_common._is_inf_elt] rejects every element. *)
let not_inf x = not_exists (Owl_base_dense_common._is_inf_elt (kind x)) x


(* Neural network related functions *)

(*TODO: optimise *)
(* conv2d: 4d input and 4d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)letconv2d?(padding=SAME)inputkernelstride=letp0=num_dimsinput=4inletp1=num_dimskernel=4inletp2=Array.lengthstride=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 4)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets3=Printf.sprintf"conv2d: %s; %s; %s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletout_channel=kernel_shp.(3)inletp3=in_channel=kernel_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"conv2d: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp3error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletoutput_cols,output_rows=Owl_utils_infer_shape.calc_conv2d_output_shapepaddinginput_colsinput_rowskernel_colskernel_rowsrow_stridecol_strideinlet_kind=kindinputinletoutput=empty_kind[|batches;output_cols;output_rows;out_channel|]inletpad_top,pad_left,_,_=Owl_utils_infer_shape.calc_conv2d_paddinginput_colsinput_rowskernel_colskernel_rowsoutput_colsoutput_rowsrow_stridecol_strideinletsum=ref0.inforb=0tobatches-1dofori=0tooutput_cols-1doforj=0tooutput_rows-1dofork=0toout_channel-1dosum:=0.;fordi=0tokernel_cols-1dofordj=0tokernel_rows-1doforq=0toin_channel-1doletin_col=(i*col_stride)+di-pad_leftinletin_row=(j*row_stride)+dj-pad_topinletin_val=if0<=in_col&&in_col<input_cols&&0<=in_row&&in_row<input_rowsthengetinput[|b;in_col;in_row;q|]else0.insum:=!sum+.(in_val*.getkernel[|di;dj;q;k|])done(*q*)done(*dj*)done;(*di*)setoutput[|b;i;j;k|]!sumdone(*k*)done(*j*)done(*i*)done;(*b*)output(* conv1d: 3d input and 3d kernel, refer to tensorlfow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)letconv1d?(padding=SAME)inputkernelstride=letp0=num_dimsinput=3inletp1=num_dimskernel=3inletp2=Array.lengthstride=1inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 3)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 1)"(Array.lengthstride)inlets3=Printf.sprintf"conv1d: %s; %s; %s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletin_channel=input_shp.(2)inletinput=reshapeinput[|batches;1;input_cols;in_channel|]inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletout_channel=kernel_shp.(2)inletp3=in_channel=kernel_shp.(1)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
shape"inlets5=Printf.sprintf"conv1d: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp3error;letkernel=reshapekernel[|1;kernel_cols;in_channel;out_channel|]inletcol_stride=stride.(0)inletstride=[|1;col_stride|]inletoutput=conv2d~paddinginputkernelstrideinletoutput_shp=shapeoutputinletoutput_cols=output_shp.(2)inletoutput=reshapeoutput[|batches;output_cols;out_channel|]inoutput(* TODO: optimise *)(* conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)letconv3d?(padding=SAME)inputkernelstride=letp0=num_dimsinput=5inletp1=num_dimskernel=5inletp2=Array.lengthstride=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 5)"(num_dimskernel)inlets2=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets3=Printf.sprintf"conv3d: %s; %s; %s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletkernel_dpts=kernel_shp.(2)inletout_channel=kernel_shp.(4)inletp3=in_channel=kernel_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 5th dimension of input shape should be equal to the 4th dimension of kernel \
shape"inlets5=Printf.sprintf"conv3d: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp3error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletoutput_cols,output_rows,output_dpts=Owl_utils_infer_shape.calc_conv3d_output_shapepaddinginput_colsinput_rowsinput_dptskernel_colskernel_rowskernel_dptsrow_stridecol_stridedpt_strideinlet_kind=kindinputinletoutput=empty_kind[|batches;output_cols;output_rows;output_dpts;out_channel|]inletpad_top,pad_left,pad_shallow,_,_,_=Owl_utils_infer_shape.calc_conv3d_paddinginput_colsinput_rowsinput_dptskernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsrow_stridecol_stridedpt_strideinletsum=ref0.inforb=0tobatches-1dofori=0tooutput_cols-1doforj=0tooutput_rows-1dofordpt=0tooutput_dpts-1dofork=0toout_channel-1dosum:=0.;fordi=0tokernel_cols-1dofordj=0tokernel_rows-1doford_dpt=0tokernel_dpts-1doforq=0toin_channel-1doletin_col=(i*col_stride)+di-pad_leftinletin_row=(j*row_stride)+dj-pad_topinletin_dpt=(dpt*dpt_stride)+d_dpt-pad_shallowinletin_val=if0<=in_col&&in_col<input_cols&&0<=in_row&&in_row<input_rows&&0<=in_dpt&&in_dpt<input_dptsthengetinput[|b;in_col;in_row;in_dpt;q|]else0.insum:=!sum+.(in_val*.getkernel[|di;dj;d_dpt;q;k|])done(*q*)done(*d_dpt*)done(*dj*)done;(*di*)setoutput[|b;i;j;dpt;k|]!sumdone(*k*)done(*dpt*)done(*j*)done(*i*)done;(*b*)output(* General function for avg_pool2d and max_pool2d *)let_pool2d?(padding=SAME)inputkernelstrideinit_pool_funadd_val_pool_funend_pool_fun=letp0=num_dimsinput=4inletp1=Array.lengthkernel=2inletp2=Array.lengthstride=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 2)"(Array.lengthkernel)inlets2=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets3=Printf.sprintf"_pool2d: %s; %s; 
%s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_cols=kernel.(0)inletkernel_rows=kernel.(1)inletcol_stride=stride.(0)inletrow_stride=stride.(1)inletoutput_cols,output_rows=Owl_utils_infer_shape.calc_conv2d_output_shapepaddinginput_colsinput_rowskernel_colskernel_rowsrow_stridecol_strideinlet_kind=kindinputinletoutput=empty_kind[|batches;output_cols;output_rows;in_channel|]inletpad_top,pad_left,_,_=Owl_utils_infer_shape.calc_conv2d_paddinginput_colsinput_rowskernel_colskernel_rowsoutput_colsoutput_rowsrow_stridecol_strideinforb=0tobatches-1dofori=0tooutput_cols-1doforj=0tooutput_rows-1dofork=0toin_channel-1doinit_pool_fun();fordi=0tokernel_cols-1dofordj=0tokernel_rows-1doletin_col=(i*col_stride)+di-pad_leftinletin_row=(j*row_stride)+dj-pad_topinif0<=in_col&&in_col<input_cols&&0<=in_row&&in_row<input_rowsthenadd_val_pool_fun(getinput[|b;in_col;in_row;k|])done(*dj*)done;(*di*)setoutput[|b;i;j;k|](end_pool_fun())done(*k*)done(*j*)done(*i*)done;(*b*)outputlet_pool3d?(padding=SAME)inputkernelstrideinit_pool_funadd_val_pool_funend_pool_fun=letp0=num_dimsinput=5inletp1=Array.lengthkernel=3inletp2=Array.lengthstride=3inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 5)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 3)"(Array.lengthkernel)inlets2=Printf.sprintf"stride dimension = %i (should be 3)"(Array.lengthstride)inlets3=Printf.sprintf"_pool3d: %s; %s; 
%s."s0s1s2inOwl_exception.INVALID_ARGUMENTs3inOwl_exception.verify(p0&&p1&&p2)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletinput_dpts=input_shp.(3)inletin_channel=input_shp.(4)inletkernel_cols=kernel.(0)inletkernel_rows=kernel.(1)inletkernel_dpts=kernel.(2)inletcol_stride=stride.(0)inletrow_stride=stride.(1)inletdpt_stride=stride.(2)inletoutput_cols,output_rows,output_dpts=Owl_utils_infer_shape.calc_conv3d_output_shapepaddinginput_colsinput_rowsinput_dptskernel_colskernel_rowskernel_dptsrow_stridecol_stridedpt_strideinlet_kind=kindinputinletoutput=empty_kind[|batches;output_cols;output_rows;output_dpts;in_channel|]inletpad_top,pad_left,pad_shallow,_,_,_=Owl_utils_infer_shape.calc_conv3d_paddinginput_colsinput_rowsinput_dptskernel_colskernel_rowskernel_dptsoutput_colsoutput_rowsoutput_dptsrow_stridecol_stridedpt_strideinforb=0tobatches-1dofori=0tooutput_cols-1doforj=0tooutput_rows-1dofordpt=0tooutput_dpts-1dofork=0toin_channel-1doinit_pool_fun();fordi=0tokernel_cols-1dofordj=0tokernel_rows-1doford_dpt=0tokernel_dpts-1doletin_col=(i*col_stride)+di-pad_leftinletin_row=(j*row_stride)+dj-pad_topinletin_dpt=(dpt*dpt_stride)+d_dpt-pad_shallowinif0<=in_col&&in_col<input_cols&&0<=in_row&&in_row<input_rows&&0<=in_dpt&&in_dpt<input_dptsthenadd_val_pool_fun(getinput[|b;in_col;in_row;in_dpt;k|])done(*d_dpt*)done(*dj*)done;(*di*)setoutput[|b;i;j;dpt;k|](end_pool_fun())done(*k*)done(*dpt*)done(*j*)done(*i*)done;(*b*)output(* max_pool2d: 4d input and 2d kernel, refer to tensorlfow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; input_channel]
*)
let max_pool2d ?(padding = SAME) input kernel stride =
  let max_pool = ref 0. in
  (* Each window starts from negative infinity. The previous seed,
     [Stdlib.min_float], is the smallest POSITIVE float, so a window holding
     only negative values produced that tiny positive number instead of its
     true maximum. *)
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  _pool2d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun


(* max_pool1d: 3d input and 1d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column]
stride: [column_stride]
output: [batch; output_column; input_channel]
*)
let max_pool1d ?(padding = SAME) input kernel stride =
  let input_rank = num_dims input in
  let kernel_len = Array.length kernel in
  let stride_len = Array.length stride in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" input_rank in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" kernel_len in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" stride_len in
    let s3 = Printf.sprintf "max_pool1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (input_rank = 3 && kernel_len = 1 && stride_len = 1) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  (* insert a unit axis in second position, pool with max_pool2d, then drop
     that axis again *)
  let input_4d = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_2d = [| 1; kernel.(0) |] in
  let stride_2d = [| 1; stride.(0) |] in
  let output_4d = max_pool2d ~padding input_4d kernel_2d stride_2d in
  let output_cols = (shape output_4d).(2) in
  reshape output_4d [| batches; output_cols; in_channel |]


(* max_pool3d: 5d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; input_channel]
*)
let max_pool3d ?(padding = SAME) input kernel stride =
  let max_pool = ref 0. in
  (* Each window starts from negative infinity. The previous seed,
     [Stdlib.min_float], is the smallest POSITIVE float, so a window holding
     only negative values produced that tiny positive number instead of its
     true maximum. *)
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  _pool3d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun


(* similar to max_pool2d *)
let avg_pool2d ?(padding = SAME) input kernel stride =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  (* the mean is taken over the values actually added to the window; a window
     that received no value would yield 0. /. 0. = nan *)
  let end_pool_fun () = !sum_pool /. !cnt in
  _pool2d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun


(* similar to max_pool1d *)
let avg_pool1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "avg_pool1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  (* insert a unit axis in second position so that avg_pool2d can be reused *)
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel = [| 1; kernel_cols |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let output = avg_pool2d ~padding input kernel stride in
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  (* drop the unit axis again *)
  let output = reshape output [| batches; output_cols; in_channel |] in
  output


(* similar to max_pool3d *)
let avg_pool3d ?(padding = SAME) input kernel stride =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  _pool3d ~padding input kernel stride init_pool_fun add_val_pool_fun end_pool_fun


(*TODO: optimise *)
(* gradient of conv2d w.r.t the input 
*)letconv2d_backward_inputinputkernelstrideoutput'=letp0=num_dimsinput=4inletp1=num_dimskernel=4inletp2=num_dimsoutput'=4inletp3=Array.lengthstride=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 4)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 4)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets4=Printf.sprintf"conv2d_backward_input: %s; %s; %s; %s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletout_channel=kernel_shp.(3)inletp4=in_channel=kernel_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"conv2d_backward_input: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 4th dimension of kernel shape should be equal to the 4th dimension of \
output' shape"inlets8=Printf.sprintf"conv2d_backward_input: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletinput'=empty(kindinput)(shapeinput)inletpad_top,pad_left,_,_=Owl_utils_infer_shape.calc_conv2d_paddinginput_colsinput_rowskernel_colskernel_rowsoutput_colsoutput_rowsrow_stridecol_strideinforb=0tobatches-1doforin_i=0toinput_cols-1doforin_j=0toinput_rows-1doforq=0toin_channel-1doletsum=ref0.infordi=0tokernel_cols-1dofordj=0tokernel_rows-1doifStdlib.(mod)(in_i+pad_left-di)col_stride=0&&Stdlib.(mod)(in_j+pad_top-dj)row_stride=0then(letout_col=(in_i+pad_left-di)/col_strideinletout_row=(in_j+pad_top-dj)/row_strideinif0<=out_col&&out_col<output_cols&&0<=out_row&&out_row<output_rowsthenfork=0toout_channel-1doletout_grad=getoutput'[|b;out_col;out_row;k|]inletkernel_val=getkernel[|di;dj;q;k|]insum:=!sum+.(out_grad*.kernel_val)done(*k*))done(*dj*)done;(*di*)setinput'[|b;in_i;in_j;q|]!sumdone(*q*)done(*in_j*)done(*in_i*)done;(*b*)input'(* gradient of conv2d w.r.t the kernel *)letconv2d_backward_kernelinputkernelstrideoutput'=letp0=num_dimsinput=4inletp1=num_dimskernel=4inletp2=num_dimsoutput'=4inletp3=Array.lengthstride=2inleterror()=lets0=Printf.sprintf"input dimension = %i (should be 4)"(num_dimsinput)inlets1=Printf.sprintf"kernel dimension = %i (should be 4)"(num_dimskernel)inlets2=Printf.sprintf"output' dimension = %i (should be 4)"(num_dimsoutput')inlets3=Printf.sprintf"stride dimension = %i (should be 2)"(Array.lengthstride)inlets4=Printf.sprintf"conv2d_backward_kernel: %s; %s; %s; 
%s."s0s1s2s3inOwl_exception.INVALID_ARGUMENTs4inOwl_exception.verify(p0&&p1&&p2&&p3)error;letinput_shp=shapeinputinletbatches=input_shp.(0)inletinput_cols=input_shp.(1)inletinput_rows=input_shp.(2)inletin_channel=input_shp.(3)inletkernel_shp=shapekernelinletkernel_cols=kernel_shp.(0)inletkernel_rows=kernel_shp.(1)inletout_channel=kernel_shp.(3)inletp4=in_channel=kernel_shp.(2)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Printf.sprintf"input shape is [%s]"s0inlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"kernel shape is [%s]"s2inlets4=Printf.sprintf"the 4th dimension of input shape should be equal to the 3rd dimension of kernel \
shape"inlets5=Printf.sprintf"conv2d_backward_kernel: %s, %s, %s."s1s3s4inOwl_exception.INVALID_ARGUMENTs5inOwl_exception.verifyp4error;letoutput_shp=shapeoutput'inletoutput_cols=output_shp.(1)inletoutput_rows=output_shp.(2)inletp5=batches=output_shp.(0)inletp6=out_channel=output_shp.(3)inleterror()=lets0=Owl_utils_array.to_stringstring_of_intinput_shpinlets1=Owl_utils_array.to_stringstring_of_intoutput_shpinlets2=Owl_utils_array.to_stringstring_of_intkernel_shpinlets3=Printf.sprintf"input shape is [%s]"s0inlets4=Printf.sprintf"output' shape is [%s]"s1inlets5=Printf.sprintf"kernel shape is [%s]"s2inlets6=Printf.sprintf"the 1st dimension of input shape should be equal to the 1st dimension of \
output' shape"inlets7=Printf.sprintf"the 4th dimension of kernel shape should be equal to the 4th dimension of \
output' shape"inlets8=Printf.sprintf"conv2d_backward_kernel: %s; %s; %s; %s; %s."s3s4s5s6s7inOwl_exception.INVALID_ARGUMENTs8inOwl_exception.verify(p5&&p6)error;letcol_stride=stride.(0)inletrow_stride=stride.(1)inletkernel'=empty(kindkernel)(shapekernel)inletpad_top,pad_left,_,_=Owl_utils_infer_shape.calc_conv2d_paddinginput_colsinput_rowskernel_colskernel_rowsoutput_colsoutput_rowsrow_stridecol_strideinfordi=0tokernel_cols-1dofordj=0tokernel_rows-1doforq=0toin_channel-1dofork=0toout_channel-1doletsum=ref0.inforb=0tobatches-1dofori=0tooutput_cols-1doforj=0tooutput_rows-1doletin_col=(i*col_stride)+di-pad_leftinletin_row=(j*row_stride)+dj-pad_topinif0<=in_col&&in_col<input_cols&&0<=in_row&&in_row<input_rowsthen(letout_grad=getoutput'[|b;i;j;k|]inletinput_val=getinput[|b;in_col;in_row;q|]insum:=!sum+.(out_grad*.input_val))done(*j*)done(*i*)done;(*b*)setkernel'[|di;dj;q;k|]!sumdone(*k*)done(*q*)done(*dj*)done;(*di*)kernel'lettranspose?axisvarr=letdims=shapevarrinletrank=Array.lengthdimsinletaxis_perm=matchaxiswith|Someperm->perm|None->Array.initrank(funi->rank-i-1)inletnew_dims=_apply_permdimsaxis_perminletnew_varr=empty(kindvarr)new_dimsinletind=Array.makerank0inletshould_stop=reffalseinwhilenot!should_stopdoGenarray.setnew_varr(_apply_permindaxis_perm)(Genarray.getvarrind);ifnot(_next_indexinddims)thenshould_stop:=truedone;new_varr(* transpose_conv2d: 4d input and 4d kernel, refer to tensorlfow doc
input : [batch; input_column; input_row; input_channel]
kernel: [kernel_column; kernel_row; input_channel; output_channel]
stride: [column_stride; row_stride]
output: [batch; output_column; output_row; output_channel]
*)
(* The result is computed as the input-gradient of an ordinary conv2d whose
   kernel has its two channel axes swapped; [padding] defaults to SAME. *)
let transpose_conv2d ?(padding = SAME) input kernel stride =
  let show shp = Owl_utils_array.to_string string_of_int shp in
  (* ranks are checked first: the shape lookups below index positions 0..3 *)
  let input_rank = num_dims input in
  let kernel_rank = num_dims kernel in
  let stride_len = Array.length stride in
  let rank_error () =
    let m_input = Printf.sprintf "input dimension = %i (should be 4)" input_rank in
    let m_kernel = Printf.sprintf "kernel dimension = %i (should be 4)" kernel_rank in
    let m_stride = Printf.sprintf "stride dimension = %i (should be 2)" stride_len in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "transpose_conv2d: %s; %s; %s." m_input m_kernel m_stride)
  in
  Owl_exception.verify (input_rank = 4 && kernel_rank = 4 && stride_len = 2) rank_error;
  let in_shape = shape input in
  let k_shape = shape kernel in
  let channel_error () =
    let m_input = Printf.sprintf "input shape is [%s]" (show in_shape) in
    let m_kernel = Printf.sprintf "kernel shape is [%s]" (show k_shape) in
    let m_rule =
      "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "transpose_conv2d: %s, %s, %s." m_input m_kernel m_rule)
  in
  Owl_exception.verify (in_shape.(3) = k_shape.(2)) channel_error;
  (* stride is [| column_stride; row_stride |]; the shape helper takes the
     row stride before the column stride *)
  let out_cols, out_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape
      padding
      in_shape.(1)
      in_shape.(2)
      k_shape.(0)
      k_shape.(1)
      stride.(1)
      stride.(0)
  in
  (* only the shape and kind of this array matter to conv2d_backward_input *)
  let shape_carrier =
    empty (kind input) [| in_shape.(0); out_cols; out_rows; k_shape.(3) |]
  in
  let swapped_kernel = transpose ~axis:[| 0; 1; 3; 2 |] kernel in
  conv2d_backward_input shape_carrier swapped_kernel stride input


(* gradient of transpose_conv2d w.r.t the input

   Implemented as a forward conv2d of [output'] with the channel-swapped
   kernel. The padding mode is not a parameter: it is recovered by checking
   whether the spatial size of [output'] matches what SAME padding would give;
   otherwise VALID is assumed. *)
let transpose_conv2d_backward_input input kernel stride output' =
  let show shp = Owl_utils_array.to_string string_of_int shp in
  let input_rank = num_dims input in
  let kernel_rank = num_dims kernel in
  let grad_rank = num_dims output' in
  let stride_len = Array.length stride in
  let rank_error () =
    let m_input = Printf.sprintf "input dimension = %i (should be 4)" input_rank in
    let m_kernel = Printf.sprintf "kernel dimension = %i (should be 4)" kernel_rank in
    let m_grad = Printf.sprintf "output' dimension = %i (should be 4)" grad_rank in
    let m_stride = Printf.sprintf "stride dimension = %i (should be 2)" stride_len in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "transpose_conv2d_backward_input: %s; %s; %s; %s."
         m_input
         m_kernel
         m_grad
         m_stride)
  in
  Owl_exception.verify
    (input_rank = 4 && kernel_rank = 4 && grad_rank = 4 && stride_len = 2)
    rank_error;
  let in_shape = shape input in
  let k_shape = shape kernel in
  let channel_error () =
    let m_input = Printf.sprintf "input shape is [%s]" (show in_shape) in
    let m_kernel = Printf.sprintf "kernel shape is [%s]" (show k_shape) in
    let m_rule =
      "the 4th dimension of input shape should be equal to the 3rd dimension of kernel shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "transpose_conv2d_backward_input: %s, %s, %s." m_input m_kernel m_rule)
  in
  Owl_exception.verify (in_shape.(3) = k_shape.(2)) channel_error;
  let grad_shape = shape output' in
  let batch_ok = in_shape.(0) = grad_shape.(0) in
  let channel_ok = k_shape.(3) = grad_shape.(3) in
  let grad_error () =
    let m_input = Printf.sprintf "input shape is [%s]" (show in_shape) in
    let m_grad = Printf.sprintf "output' shape is [%s]" (show grad_shape) in
    let m_kernel = Printf.sprintf "kernel shape is [%s]" (show k_shape) in
    let m_batch =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let m_chan =
      "the 4th dimension of kernel shape should be equal to the 4th dimension of output' shape"
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf
         "transpose_conv2d_backward_input: %s; %s; %s; %s; %s."
         m_input
         m_grad
         m_kernel
         m_batch
         m_chan)
  in
  Owl_exception.verify (batch_ok && channel_ok) grad_error;
  let same_cols, same_rows =
    Owl_utils_infer_shape.calc_transpose_conv2d_output_shape
      SAME
      in_shape.(1)
      in_shape.(2)
      k_shape.(0)
      k_shape.(1)
      stride.(1)
      stride.(0)
  in
  let inferred_padding =
    if same_cols = grad_shape.(1) && same_rows = grad_shape.(2) then SAME else VALID
  in
  let swapped_kernel = transpose ~axis:[| 0; 1; 3; 2 |] kernel in
  conv2d ~padding:inferred_padding output' swapped_kernel stride


(* gradient of transpose_conv2d w.r.t the kernel: conv2d_backward_kernel with
   the roles of [input] and [output'] exchanged *)
let transpose_conv2d_backward_kernel input kernel stride output' =
  conv2d_backward_kernel output' kernel stride input


(* transpose_conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)
(* Implemented by viewing the 3d tensors as 4d ones with a singleton spatial
   axis and delegating to transpose_conv2d; [padding] defaults to SAME. *)
let transpose_conv1d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "transpose_conv1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "transpose_conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let output = transpose_conv2d ~padding input kernel stride in
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  reshape output [| batches; output_cols; out_channel |]


(* Shared implementation of the four 1d backward functions below.

   name        : public function name, used as the prefix of error messages
   backward_2d : the 2d backward function to delegate to; it is called as
                 [backward_2d input kernel stride output'] on the 4d views
   returns     : (4d gradient, original input shape, original kernel shape);
                 the caller reshapes the gradient to whichever shape applies.

   The 3d tensors are viewed as 4d by inserting a singleton axis:
   input [b; 1; cols; c_in], kernel [1; k_cols; c_in; c_out],
   output' [b; 1; o_cols; c_out], stride [| 1; col_stride |]. *)
let _conv1d_backward_via_2d name backward_2d input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "%s: %s; %s; %s; %s." name s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel shape"
    in
    let s5 = Printf.sprintf "%s: %s, %s, %s." name s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 3rd dimension of kernel shape should be equal to the 3rd dimension of output' shape"
    in
    let s8 = Printf.sprintf "%s: %s; %s; %s; %s; %s." name s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  backward_2d input kernel stride output', input_shp, kernel_shp


(* gradient of conv1d w.r.t the input *)
let conv1d_backward_input input kernel stride output' =
  let input', input_shp, _ =
    _conv1d_backward_via_2d
      "conv1d_backward_input"
      conv2d_backward_input
      input
      kernel
      stride
      output'
  in
  reshape input' input_shp


(* gradient of conv1d w.r.t the kernel *)
let conv1d_backward_kernel input kernel stride output' =
  let kernel', _, kernel_shp =
    _conv1d_backward_via_2d
      "conv1d_backward_kernel"
      conv2d_backward_kernel
      input
      kernel
      stride
      output'
  in
  reshape kernel' kernel_shp


(* gradient of transpose_conv1d w.r.t the input *)
let transpose_conv1d_backward_input input kernel stride output' =
  let input', input_shp, _ =
    _conv1d_backward_via_2d
      "transpose_conv1d_backward_input"
      transpose_conv2d_backward_input
      input
      kernel
      stride
      output'
  in
  reshape input' input_shp


(* gradient of transpose_conv1d w.r.t the kernel *)
let transpose_conv1d_backward_kernel input kernel stride output' =
  let kernel', _, kernel_shp =
    _conv1d_backward_via_2d
      "transpose_conv1d_backward_kernel"
      transpose_conv2d_backward_kernel
      input
      kernel
      stride
      output'
  in
  reshape kernel' kernel_shp


(* Argument validation shared by conv3d_backward_input and
   conv3d_backward_kernel.

   name    : public function name, used as the prefix of every error message
             (passing it once keeps the messages from naming the wrong
             function)
   returns : (input shape, kernel shape, output' shape) once all checks pass.

   Checks, in order: ranks (5, 5, 5) and stride length 3; input channel count
   against kernel dimension 3; batch and output-channel counts against
   output'. *)
let _verify_conv3d_backward name input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "%s: %s; %s; %s; %s." name s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let kernel_shp = shape kernel in
  let p4 = input_shp.(4) = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "%s: %s, %s, %s." name s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let p5 = input_shp.(0) = output_shp.(0) in
  let p6 = kernel_shp.(4) = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"
    in
    let s8 = Printf.sprintf "%s: %s; %s; %s; %s; %s." name s3 s4 s5 s6 s7 in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  input_shp, kernel_shp, output_shp


(*TODO: optimise *)
(* gradient of conv3d w.r.t the input

   For every input position, sums [output' * kernel] over all kernel offsets
   and output channels whose output position maps exactly onto that input
   position under the given strides and inferred padding. Returns a new array
   with the shape and kind of [input]. *)
let conv3d_backward_input input kernel stride output' =
  let input_shp, kernel_shp, output_shp =
    _verify_conv3d_backward "conv3d_backward_input" input kernel stride output'
  in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let input' = empty (kind input) (shape input) in
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      output_cols
      output_rows
      output_dpts
      row_stride
      col_stride
      dpt_stride
  in
  for b = 0 to batches - 1 do
    for in_i = 0 to input_cols - 1 do
      for in_j = 0 to input_rows - 1 do
        for in_dpt = 0 to input_dpts - 1 do
          for q = 0 to in_channel - 1 do
            let sum = ref 0. in
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for d_dpt = 0 to kernel_dpts - 1 do
                  (* only offsets that land exactly on a strided output
                     position contribute *)
                  if Stdlib.( mod ) (in_i + pad_left - di) col_stride = 0
                     && Stdlib.( mod ) (in_j + pad_top - dj) row_stride = 0
                     && Stdlib.( mod ) (in_dpt + pad_shallow - d_dpt) dpt_stride = 0
                  then (
                    let out_col = (in_i + pad_left - di) / col_stride in
                    let out_row = (in_j + pad_top - dj) / row_stride in
                    let out_dpt = (in_dpt + pad_shallow - d_dpt) / dpt_stride in
                    if 0 <= out_col
                       && out_col < output_cols
                       && 0 <= out_row
                       && out_row < output_rows
                       && 0 <= out_dpt
                       && out_dpt < output_dpts
                    then
                      for k = 0 to out_channel - 1 do
                        let out_grad = get output' [| b; out_col; out_row; out_dpt; k |] in
                        let kernel_val = get kernel [| di; dj; d_dpt; q; k |] in
                        sum := !sum +. (out_grad *. kernel_val)
                      done (*k*))
                done (*d_dpt*)
              done (*dj*)
            done; (*di*)
            set input' [| b; in_i; in_j; in_dpt; q |] !sum
          done (*q*)
        done (*in_dpt*)
      done (*in_j*)
    done (*in_i*)
  done; (*b*)
  input'


(* gradient of conv3d w.r.t the kernel

   For every kernel entry, sums [output' * input] over all batches and output
   positions whose corresponding input position lies inside the input.
   Returns a new array with the shape and kind of [kernel]. *)
let conv3d_backward_kernel input kernel stride output' =
  let input_shp, kernel_shp, output_shp =
    _verify_conv3d_backward "conv3d_backward_kernel" input kernel stride output'
  in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let kernel' = empty (kind kernel) (shape kernel) in
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      output_cols
      output_rows
      output_dpts
      row_stride
      col_stride
      dpt_stride
  in
  for di = 0 to kernel_cols - 1 do
    for dj = 0 to kernel_rows - 1 do
      for d_dpt = 0 to kernel_dpts - 1 do
        for q = 0 to in_channel - 1 do
          for k = 0 to out_channel - 1 do
            let sum = ref 0. in
            for b = 0 to batches - 1 do
              for i = 0 to output_cols - 1 do
                for j = 0 to output_rows - 1 do
                  for dpt = 0 to output_dpts - 1 do
                    let in_col = (i * col_stride) + di - pad_left in
                    let in_row = (j * row_stride) + dj - pad_top in
                    let in_dpt = (dpt * dpt_stride) + d_dpt - pad_shallow in
                    if 0 <= in_col
                       && in_col < input_cols
                       && 0 <= in_row
                       && in_row < input_rows
                       && 0 <= in_dpt
                       && in_dpt < input_dpts
                    then (
                      let out_grad = get output' [| b; i; j; dpt; k |] in
                      let input_val = get input [| b; in_col; in_row; in_dpt; q |] in
                      sum := !sum +. (out_grad *. input_val))
                  done (*dpt*)
                done (*j*)
              done (*i*)
            done; (*b*)
            set kernel' [| di; dj; d_dpt; q; k |] !sum
          done (*k*)
        done (*q*)
      done (*d_dpt*)
    done (*dj*)
  done; (*di*)
  kernel'


(* transpose_conv3d: 5d input and 5d kernel, refer to tensorflow doc
input : [batch; input_column; input_row; input_depth; input_channel]
kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
stride: [column_stride; row_stride; depth_stride]
output: [batch; output_column; output_row; output_dpts; output_channel]
*)
(* The result is computed as the input-gradient of an ordinary conv3d whose
   kernel has its two channel axes swapped; [padding] defaults to SAME.

   NOTE(review): the output shape is inferred with calc_conv3d_output_shape,
   whereas transpose_conv2d uses calc_transpose_conv2d_output_shape. Confirm
   whether a transpose-specific 3d helper should be used here. *)
let transpose_conv3d ?(padding = SAME) input kernel stride =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "transpose_conv3d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p3 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "transpose_conv3d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_cols, output_rows, output_dpts =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      row_stride
      col_stride
      dpt_stride
  in
  (* only the shape and kind of this array matter to conv3d_backward_input *)
  let output =
    empty (kind input) [| batches; output_cols; output_rows; output_dpts; out_channel |]
  in
  let kernel = transpose ~axis:[| 0; 1; 2; 4; 3 |] kernel in
  conv3d_backward_input output kernel stride input


(* gradient of transpose_conv3d w.r.t the input

   Implemented as a forward conv3d of [output'] with the channel-swapped
   kernel. The padding mode is recovered by comparing the spatial size of
   [output'] with the size SAME padding would give; otherwise VALID is used.

   NOTE(review): same remark as transpose_conv3d about the use of
   calc_conv3d_output_shape for the SAME-size estimate. *)
let transpose_conv3d_backward_input input kernel stride output' =
  let p0 = num_dims input = 5 in
  let p1 = num_dims kernel = 5 in
  let p2 = num_dims output' = 5 in
  let p3 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 5)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 5)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s4 = Printf.sprintf "transpose_conv3d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let kernel_rows = kernel_shp.(1) in
  let kernel_dpts = kernel_shp.(2) in
  let out_channel = kernel_shp.(4) in
  let p4 = in_channel = kernel_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      "the 5th dimension of input shape should be equal to the 4th dimension of kernel shape"
    in
    let s5 = Printf.sprintf "transpose_conv3d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p5 = batches = output_shp.(0) in
  let p6 = out_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s7 =
      "the 5th dimension of kernel shape should be equal to the 5th dimension of output' shape"
    in
    let s8 =
      Printf.sprintf "transpose_conv3d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let padding = SAME in
  let output_cols_same, output_rows_same, output_dpts_same =
    Owl_utils_infer_shape.calc_conv3d_output_shape
      padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      row_stride
      col_stride
      dpt_stride
  in
  let p =
    if output_cols_same = output_cols
       && output_rows_same = output_rows
       && output_dpts_same = output_dpts
    then SAME
    else VALID
  in
  let kernel = transpose ~axis:[| 0; 1; 2; 4; 3 |] kernel in
  conv3d ~padding:p output' kernel stride


(* gradient of transpose_conv3d w.r.t the kernel: conv3d_backward_kernel with
   the roles of [input] and [output'] exchanged *)
let transpose_conv3d_backward_kernel input kernel stride output' =
  conv3d_backward_kernel output' kernel stride input


(* TODO: definitely optimise *)
(* General function for avg_pool2d and max_pool2d

   _padding         : ignored; padding is re-derived from the shapes
   kernel, stride   : int arrays of length 2, [| cols; rows |]
   init_pool_fun    : called once before each pooling window is scanned
   add_val_pool_fun : called with every in-bounds input value of the window
   end_pool_fun     : returns the pooled value of the window just scanned
   compute_grad_fun : [input_val input_grad output_val output_grad] -> the new
                      accumulated gradient for that input position

   Returns a new array shaped like [input], initialised to zero and then
   updated window by window (overlapping windows accumulate). *)
let _pool2d_backward
    _padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun
  =
  let p0 = num_dims input = 4 in
  let p1 = Array.length kernel = 2 in
  let p2 = Array.length stride = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 2)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 2)" (Array.length stride) in
    (* kept on one line: a raw line break inside the literal would end up in
       the error message *)
    let s3 = Printf.sprintf "_pool2d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let p5 = batches = output_shp.(0) in
  let p6 = in_channel = output_shp.(3) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "input shape is [%s]" s0 in
    let s3 = Printf.sprintf "output' shape is [%s]" s1 in
    let s4 =
      "the 1st dimension of input shape should be equal to the 1st dimension of output' shape"
    in
    let s5 =
      "the 4th dimension of input shape should be equal to the 4th dimension of output' shape"
    in
    let s6 = Printf.sprintf "_pool2d_backward: %s; %s; %s; %s." s2 s3 s4 s5 in
    Owl_exception.INVALID_ARGUMENT s6
  in
  Owl_exception.verify (p5 && p6) error;
  let pad_top, pad_left, _, _ =
    Owl_utils_infer_shape.calc_conv2d_padding
      input_cols
      input_rows
      kernel_cols
      kernel_rows
      output_cols
      output_rows
      row_stride
      col_stride
  in
  let input' = zeros (kind input) (shape input) in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for k = 0 to in_channel - 1 do
          (* first pass over the window: compute the pooled value *)
          init_pool_fun ();
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              let in_col = (i * col_stride) + di - pad_left in
              let in_row = (j * row_stride) + dj - pad_top in
              if 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
              then add_val_pool_fun (get input [| b; in_col; in_row; k |])
            done (*dj*)
          done; (*di*)
          let output_val = end_pool_fun () in
          let output_grad = get output' [| b; i; j; k |] in
          (* second pass: distribute the output gradient over the window *)
          for di = 0 to kernel_cols - 1 do
            for dj = 0 to kernel_rows - 1 do
              let in_col = (i * col_stride) + di - pad_left in
              let in_row = (j * row_stride) + dj - pad_top in
              if 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows
              then (
                let input_val = get input [| b; in_col; in_row; k |] in
                let input_grad = get input' [| b; in_col; in_row; k |] in
                set
                  input'
                  [| b; in_col; in_row; k |]
                  (compute_grad_fun input_val input_grad output_val output_grad))
            done (*dj*)
          done (*di*)
        done (*k*)
      done (*j*)
    done (*i*)
  done; (*b*)
  input'


(* calculate the gradient of max_pool2d

   Each output gradient is added to every input of its window whose value is
   within 1e-8 of the window maximum. *)
let max_pool2d_backward padding input kernel stride output' =
  let max_pool = ref 0. in
  (* The running maximum must start below every possible input.
     Stdlib.min_float is the smallest *positive* normalised float, so seeding
     with it made the maximum of an all-negative window wrong and its gradient
     was silently dropped; neg_infinity is the correct identity for max. *)
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  let compute_grad_fun input_val input_grad output_val output_grad =
    if Scalar.abs (input_val -. output_val) < 1e-8 (*TODO: change comparison here *)
    then input_grad +. output_grad
    else input_grad
  in
  _pool2d_backward
    padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun


(* calculate the gradient of avg_pool2d
*)
(* The gradient arriving at one output cell is divided by the number of
   in-bounds input cells in its window and added to each of them. *)
let avg_pool2d_backward padding input kernel stride output' =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  (* cnt still holds the element count of the window that was just pooled *)
  let compute_grad_fun _input_val input_grad _output_val output_grad =
    input_grad +. (output_grad /. !cnt)
  in
  _pool2d_backward
    padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun


(* TODO: definitely optimise *)
(* General function for avg_pool3d and max_pool3d

   input and output' are indexed as [batch; column; row; depth; channel];
   kernel and stride are [column; row; depth]. For every output cell the
   window is walked twice: once to accumulate the pooled value through
   init_pool_fun / add_val_pool_fun / end_pool_fun, and once to update the
   gradient of every in-bounds input cell through compute_grad_fun. Returns a
   fresh array with the shape of input. *)
let _pool3d_backward
    _padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun
  =
  let p0 = num_dims input = 5 in
  let p1 = Array.length kernel = 3 in
  let p2 = Array.length stride = 3 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 5)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 3)" (Array.length stride) in
    let s3 = Printf.sprintf "_pool3d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let input_dpts = input_shp.(3) in
  let in_channel = input_shp.(4) in
  let kernel_cols = kernel.(0) in
  let kernel_rows = kernel.(1) in
  let kernel_dpts = kernel.(2) in
  let col_stride = stride.(0) in
  let row_stride = stride.(1) in
  let dpt_stride = stride.(2) in
  let output_shp = shape output' in
  let output_cols = output_shp.(1) in
  let output_rows = output_shp.(2) in
  let output_dpts = output_shp.(3) in
  let p5 = batches = output_shp.(0) in
  let p6 = in_channel = output_shp.(4) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "input shape is [%s]" s0 in
    let s3 = Printf.sprintf "output' shape is [%s]" s1 in
    let s4 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s5 =
      Printf.sprintf
        "the 5th dimension of input shape should be equal to the 5th dimension of \
         output' shape"
    in
    let s6 = Printf.sprintf "_pool3d_backward: %s; %s; %s; %s." s2 s3 s4 s5 in
    Owl_exception.INVALID_ARGUMENT s6
  in
  Owl_exception.verify (p5 && p6) error;
  let pad_top, pad_left, pad_shallow, _, _, _ =
    Owl_utils_infer_shape.calc_conv3d_padding
      input_cols
      input_rows
      input_dpts
      kernel_cols
      kernel_rows
      kernel_dpts
      output_cols
      output_rows
      output_dpts
      row_stride
      col_stride
      dpt_stride
  in
  let input' = zeros (kind input) (shape input) in
  for b = 0 to batches - 1 do
    for i = 0 to output_cols - 1 do
      for j = 0 to output_rows - 1 do
        for dpt = 0 to output_dpts - 1 do
          for k = 0 to in_channel - 1 do
            (* first pass over the window: accumulate the pooled value *)
            init_pool_fun ();
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for dk = 0 to kernel_dpts - 1 do
                  let in_col = (i * col_stride) + di - pad_left in
                  let in_row = (j * row_stride) + dj - pad_top in
                  let in_dpt = (dpt * dpt_stride) + dk - pad_shallow in
                  if 0 <= in_col
                     && in_col < input_cols
                     && 0 <= in_row
                     && in_row < input_rows
                     && 0 <= in_dpt
                     && in_dpt < input_dpts
                  then add_val_pool_fun (get input [| b; in_col; in_row; in_dpt; k |])
                done (*dk*)
              done (*dj*)
            done;
            (*di*)
            let output_val = end_pool_fun () in
            let output_grad = get output' [| b; i; j; dpt; k |] in
            (* second pass over the same window: distribute the gradient *)
            for di = 0 to kernel_cols - 1 do
              for dj = 0 to kernel_rows - 1 do
                for dk = 0 to kernel_dpts - 1 do
                  let in_col = (i * col_stride) + di - pad_left in
                  let in_row = (j * row_stride) + dj - pad_top in
                  let in_dpt = (dpt * dpt_stride) + dk - pad_shallow in
                  if 0 <= in_col
                     && in_col < input_cols
                     && 0 <= in_row
                     && in_row < input_rows
                     && 0 <= in_dpt
                     && in_dpt < input_dpts
                  then (
                    let input_val = get input [| b; in_col; in_row; in_dpt; k |] in
                    let input_grad = get input' [| b; in_col; in_row; in_dpt; k |] in
                    set
                      input'
                      [| b; in_col; in_row; in_dpt; k |]
                      (compute_grad_fun input_val input_grad output_val output_grad))
                done (*dk*)
              done (*dj*)
            done (*di*)
          done (*k*)
        done (*dpt*)
      done (*j*)
    done (*i*)
  done;
  (*b*)
  input'


(* calculate the gradient of max_pool3d

   The maximum of every pooling window is recomputed from input, and the
   incoming gradient is added to each input cell whose value lies within 1e-8
   of that maximum. *)
let max_pool3d_backward padding input kernel stride output' =
  let max_pool = ref 0. in
  (* Seed the running maximum with negative infinity. Stdlib.min_float is the
     smallest POSITIVE normalised float, so seeding with it made every window
     of negative inputs report a maximum of about 0 and receive no gradient. *)
  let init_pool_fun () = max_pool := Stdlib.neg_infinity in
  let add_val_pool_fun v = max_pool := Stdlib.max !max_pool v in
  let end_pool_fun () = !max_pool in
  let compute_grad_fun input_val input_grad output_val output_grad =
    if Scalar.abs (input_val -. output_val) < 1e-8 (*TODO: change comparison here *)
    then input_grad +. output_grad
    else input_grad
  in
  _pool3d_backward
    padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun


(* calculate the gradient of avg_pool3d

   The gradient arriving at one output cell is divided by the number of
   in-bounds input cells in its window and added to each of them. *)
let avg_pool3d_backward padding input kernel stride output' =
  let sum_pool = ref 0. in
  let cnt = ref 0. in
  let init_pool_fun () =
    sum_pool := 0.;
    cnt := 0.
  in
  let add_val_pool_fun v =
    sum_pool := !sum_pool +. v;
    cnt := !cnt +. 1.
  in
  let end_pool_fun () = !sum_pool /. !cnt in
  (* cnt still holds the element count of the window that was just pooled *)
  let compute_grad_fun _input_val input_grad _output_val output_grad =
    input_grad +. (output_grad /. !cnt)
  in
  _pool3d_backward
    padding
    input
    kernel
    stride
    output'
    init_pool_fun
    add_val_pool_fun
    end_pool_fun
    compute_grad_fun


(* calculate the gradient of max_pool1d

   The 3d operands [batch; column; channel] are reshaped to 4d with a unit
   second dimension, handed to max_pool2d_backward with kernel and stride
   extended to [| 1; k |], and the result is reshaped back to the input
   shape. *)
let max_pool1d_backward padding input kernel stride output' =
  let p0 = num_dims input = 3 in
  let p1 = Array.length kernel = 1 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "max_pool1d_backward: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = 1 in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_cols = kernel.(0) in
  let kernel_rows = 1 in
  let kernel = [| kernel_rows; kernel_cols |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let output_rows = 1 in
  let out_channel = output'_shp.(2) in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let input' = max_pool2d_backward padding input kernel stride output' in
  reshape input' input_shp


(* calculate the gradient of avg_pool1d
*)
(* The 3d operands [batch; column; channel] are reshaped to 4d with a unit
   second dimension, handed to avg_pool2d_backward with kernel and stride
   extended to [| 1; k |], and the result is reshaped back to the input
   shape. *)
let avg_pool1d_backward padding input kernel stride output' =
  let input_is_3d = num_dims input = 3 in
  let kernel_is_1d = Array.length kernel = 1 in
  let stride_is_1d = Array.length stride = 1 in
  let error () =
    let input_msg = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let kernel_msg =
      Printf.sprintf "kernel dimension = %i (should be 1)" (Array.length kernel)
    in
    let stride_msg =
      Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride)
    in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "avg_pool1d_backward: %s; %s; %s." input_msg kernel_msg stride_msg)
  in
  Owl_exception.verify (input_is_3d && kernel_is_1d && stride_is_1d) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let width = input_shp.(1) in
  let channels = input_shp.(2) in
  let input_4d = reshape input [| batches; 1; width; channels |] in
  let kernel_2d = [| 1; kernel.(0) |] in
  let stride_2d = [| 1; stride.(0) |] in
  let grad_shp = shape output' in
  let grad_width = grad_shp.(1) in
  let grad_channels = grad_shp.(2) in
  let grad_4d = reshape output' [| batches; 1; grad_width; grad_channels |] in
  let input'_4d = avg_pool2d_backward padding input_4d kernel_2d stride_2d grad_4d in
  reshape input'_4d input_shp


(* create a dilated 2d kernel

   kernel is indexed as [column; row; input_channel; output_channel] and rate
   as [| col_rate; row_rate |]. A rate of [| 1; 1 |] returns kernel itself (no
   copy). Otherwise a zero-filled kernel is allocated and entry (c, r) of the
   original is written to (c * col_rate, r * row_rate). *)
let upsample_kernel2d kernel rate =
  match rate with
  | [| 1; 1 |] -> kernel
  | _ ->
    let dims = shape kernel in
    let n_cols = dims.(0) in
    let n_rows = dims.(1) in
    let n_in = dims.(2) in
    let n_out = dims.(3) in
    let col_rate = rate.(0) in
    let row_rate = rate.(1) in
    let dilated_cols = n_cols + ((n_cols - 1) * (col_rate - 1)) in
    let dilated_rows = n_rows + ((n_rows - 1) * (row_rate - 1)) in
    let dilated = zeros (kind kernel) [| dilated_cols; dilated_rows; n_in; n_out |] in
    for c = 0 to n_cols - 1 do
      let dst_c = c * col_rate in
      for r = 0 to n_rows - 1 do
        let dst_r = r * row_rate in
        for i = 0 to n_in - 1 do
          for o = 0 to n_out - 1 do
            set dilated [| dst_c; dst_r; i; o |] (get kernel [| c; r; i; o |])
          done
        done
      done
    done;
    dilated


(* change a dilated 2d kernel back to normal
*)
(* A rate of [| 1; 1 |] returns kernel itself. Otherwise every col_rate-th
   column and row_rate-th row of kernel is copied into a fresh kernel whose
   spatial extents are the dilated extents divided by the rate, rounded up
   (for positive rates). *)
let downsample_kernel2d kernel rate =
  match rate with
  | [| 1; 1 |] -> kernel
  | _ ->
    let dims = shape kernel in
    let dilated_cols = dims.(0) in
    let dilated_rows = dims.(1) in
    let n_in = dims.(2) in
    let n_out = dims.(3) in
    let col_rate = rate.(0) in
    let row_rate = rate.(1) in
    let n_cols = (dilated_cols + (col_rate - 1)) / col_rate in
    let n_rows = (dilated_rows + (row_rate - 1)) / row_rate in
    let compact = zeros (kind kernel) [| n_cols; n_rows; n_in; n_out |] in
    for c = 0 to n_cols - 1 do
      let src_c = c * col_rate in
      for r = 0 to n_rows - 1 do
        let src_r = r * row_rate in
        for i = 0 to n_in - 1 do
          for o = 0 to n_out - 1 do
            set compact [| c; r; i; o |] (get kernel [| src_c; src_r; i; o |])
          done
        done
      done
    done;
    compact


(* dilated_conv2d: 4d input and 4d kernel, refer to tensorflow doc
   input : [batch; input_column; input_row; input_channel]
   kernel: [kernel_column; kernel_row; input_channel; output_channel]
   stride: [column_stride; row_stride]
   rate  : [col_dilation_rate; row_dilation_rate]
   output: [batch; output_column; output_row; output_channel]

   Implemented as a plain conv2d with the kernel dilated by
   upsample_kernel2d; rate must have exactly two entries.
*)
let dilated_conv2d ?(padding = SAME) input kernel stride rate =
  let rate_len = Array.length rate in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 2)" rate_len in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "dilated_conv2d: %s." detail)
  in
  Owl_exception.verify (rate_len = 2) error;
  let dilated = upsample_kernel2d kernel rate in
  conv2d ~padding input dilated stride


(* gradient of dilated_conv2d w.r.t the input: conv2d_backward_input applied
   with the dilated kernel *)
let dilated_conv2d_backward_input input kernel stride rate output' =
  let rate_len = Array.length rate in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 2)" rate_len in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "dilated_conv2d_backward_input: %s." detail)
  in
  Owl_exception.verify (rate_len = 2) error;
  let dilated = upsample_kernel2d kernel rate in
  conv2d_backward_input input dilated stride output'


(* gradient of dilated_conv2d w.r.t the kernel: the gradient is computed for
   the dilated kernel and then sampled back down with downsample_kernel2d *)
let dilated_conv2d_backward_kernel input kernel stride rate output' =
  let rate_len = Array.length rate in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 2)" rate_len in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "dilated_conv2d_backward_kernel: %s." detail)
  in
  Owl_exception.verify (rate_len = 2) error;
  let dilated = upsample_kernel2d kernel rate in
  let dilated' = conv2d_backward_kernel input dilated stride output' in
  downsample_kernel2d dilated' rate


(* create a dilated 3d kernel
*)
(* kernel is indexed as [column; row; depth; input_channel; output_channel]
   and rate as [| col_rate; row_rate; dpt_rate |]. A rate of [| 1; 1; 1 |]
   returns kernel itself (no copy). Otherwise a zero-filled kernel is
   allocated and entry (c, r, d) of the original is written to
   (c * col_rate, r * row_rate, d * dpt_rate). *)
let upsample_kernel3d kernel rate =
  match rate with
  | [| 1; 1; 1 |] -> kernel
  | _ ->
    let dims = shape kernel in
    let n_cols = dims.(0) in
    let n_rows = dims.(1) in
    let n_dpts = dims.(2) in
    let n_in = dims.(3) in
    let n_out = dims.(4) in
    let col_rate = rate.(0) in
    let row_rate = rate.(1) in
    let dpt_rate = rate.(2) in
    let dilated_cols = n_cols + ((n_cols - 1) * (col_rate - 1)) in
    let dilated_rows = n_rows + ((n_rows - 1) * (row_rate - 1)) in
    let dilated_dpts = n_dpts + ((n_dpts - 1) * (dpt_rate - 1)) in
    let dilated =
      zeros (kind kernel) [| dilated_cols; dilated_rows; dilated_dpts; n_in; n_out |]
    in
    for c = 0 to n_cols - 1 do
      let dst_c = c * col_rate in
      for r = 0 to n_rows - 1 do
        let dst_r = r * row_rate in
        for d = 0 to n_dpts - 1 do
          let dst_d = d * dpt_rate in
          for i = 0 to n_in - 1 do
            for o = 0 to n_out - 1 do
              set dilated [| dst_c; dst_r; dst_d; i; o |] (get kernel [| c; r; d; i; o |])
            done
          done
        done
      done
    done;
    dilated


(* change a dilated 3d kernel back to normal

   A rate of [| 1; 1; 1 |] returns kernel itself. Otherwise every rate-th
   entry along each spatial axis is copied into a fresh kernel whose spatial
   extents are the dilated extents divided by the rate, rounded up (for
   positive rates). *)
let downsample_kernel3d kernel rate =
  match rate with
  | [| 1; 1; 1 |] -> kernel
  | _ ->
    let dims = shape kernel in
    let dilated_cols = dims.(0) in
    let dilated_rows = dims.(1) in
    let dilated_dpts = dims.(2) in
    let n_in = dims.(3) in
    let n_out = dims.(4) in
    let col_rate = rate.(0) in
    let row_rate = rate.(1) in
    let dpt_rate = rate.(2) in
    let n_cols = (dilated_cols + (col_rate - 1)) / col_rate in
    let n_rows = (dilated_rows + (row_rate - 1)) / row_rate in
    let n_dpts = (dilated_dpts + (dpt_rate - 1)) / dpt_rate in
    let compact = zeros (kind kernel) [| n_cols; n_rows; n_dpts; n_in; n_out |] in
    for c = 0 to n_cols - 1 do
      let src_c = c * col_rate in
      for r = 0 to n_rows - 1 do
        let src_r = r * row_rate in
        for d = 0 to n_dpts - 1 do
          let src_d = d * dpt_rate in
          for i = 0 to n_in - 1 do
            for o = 0 to n_out - 1 do
              set compact [| c; r; d; i; o |] (get kernel [| src_c; src_r; src_d; i; o |])
            done
          done
        done
      done
    done;
    compact


(* dilated_conv3d: 5d input and 5d kernel, refer to tensorflow doc
   input : [batch; input_column; input_row; input_depth; input_channel]
   kernel: [kernel_column; kernel_row; kernel_depth; input_channel; output_channel]
   stride: [column_stride; row_stride; depth_stride]
   rate  : [col_dilation_rate; row_dilation_rate; depth_dilation_rate]
   output: [batch; output_column; output_row; output_dpts; output_channel]

   Implemented as a plain conv3d with the kernel dilated by
   upsample_kernel3d; rate must have exactly three entries.
*)
let dilated_conv3d ?(padding = SAME) input kernel stride rate =
  let rate_len = Array.length rate in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 3)" rate_len in
    Owl_exception.INVALID_ARGUMENT (Printf.sprintf "dilated_conv3d: %s." detail)
  in
  Owl_exception.verify (rate_len = 3) error;
  let dilated = upsample_kernel3d kernel rate in
  conv3d ~padding input dilated stride


(* gradient of dilated_conv3d w.r.t the input: conv3d_backward_input applied
   with the dilated kernel *)
let dilated_conv3d_backward_input input kernel stride rate output' =
  let rate_len = Array.length rate in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 3)" rate_len in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "dilated_conv3d_backward_input: %s." detail)
  in
  Owl_exception.verify (rate_len = 3) error;
  let dilated = upsample_kernel3d kernel rate in
  conv3d_backward_input input dilated stride output'


(* gradient of dilated_conv3d w.r.t the kernel: the gradient is computed for
   the dilated kernel and then sampled back down with downsample_kernel3d *)
let dilated_conv3d_backward_kernel input kernel stride rate output' =
  let rate_len = Array.length rate in
  let error () =
    let detail = Printf.sprintf "rate dimension = %i (should be 3)" rate_len in
    Owl_exception.INVALID_ARGUMENT
      (Printf.sprintf "dilated_conv3d_backward_kernel: %s." detail)
  in
  Owl_exception.verify (rate_len = 3) error;
  let dilated = upsample_kernel3d kernel rate in
  let dilated' = conv3d_backward_kernel input dilated stride output' in
  downsample_kernel3d dilated' rate


(* dilated_conv1d: 3d input and 3d kernel, refer to tensorflow doc
input : [batch; input_column; input_channel]
kernel: [kernel_column; input_channel; output_channel]
stride: [column_stride]
output: [batch; output_column; output_channel]
*)
(* Implemented by reshaping input to [batch; 1; column; channel] and kernel to
   [1; column; in_channel; out_channel] and calling dilated_conv2d with stride
   [| 1; col_stride |]. rate is forwarded unchanged to dilated_conv2d. *)
let dilated_conv1d ?(padding = SAME) input kernel stride rate =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s3 = Printf.sprintf "dilated_conv1d: %s; %s; %s." s0 s1 s2 in
    Owl_exception.INVALID_ARGUMENT s3
  in
  Owl_exception.verify (p0 && p1 && p2) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input = reshape input [| batches; 1; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p3 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3rd dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p3 error;
  let kernel = reshape kernel [| 1; kernel_cols; in_channel; out_channel |] in
  let col_stride = stride.(0) in
  let stride = [| 1; col_stride |] in
  let output = dilated_conv2d ~padding input kernel stride rate in
  let output_shp = shape output in
  let output_cols = output_shp.(2) in
  let output = reshape output [| batches; output_cols; out_channel |] in
  output


(* gradient of dilated_conv1d w.r.t the input

   All operands are lifted to 4d with a unit second dimension and passed to
   dilated_conv2d_backward_input; the result is reshaped to the shape of
   input. *)
let dilated_conv1d_backward_input input kernel stride rate output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 = Printf.sprintf "dilated_conv1d_backward_input: %s; %s; %s; %s." s0 s1 s2 s3 in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_backward_input: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv1d_backward_input: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let input' = dilated_conv2d_backward_input input kernel stride rate output' in
  reshape input' input_shp


(* gradient of dilated_conv1d w.r.t the kernel

   All operands are lifted to 4d with a unit leading spatial dimension and
   passed to dilated_conv2d_backward_kernel; the result is reshaped to the
   shape of kernel. *)
let dilated_conv1d_backward_kernel input kernel stride rate output' =
  let p0 = num_dims input = 3 in
  let p1 = num_dims kernel = 3 in
  let p2 = num_dims output' = 3 in
  let p3 = Array.length stride = 1 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 3)" (num_dims input) in
    let s1 = Printf.sprintf "kernel dimension = %i (should be 3)" (num_dims kernel) in
    let s2 = Printf.sprintf "output' dimension = %i (should be 3)" (num_dims output') in
    let s3 = Printf.sprintf "stride dimension = %i (should be 1)" (Array.length stride) in
    let s4 =
      Printf.sprintf "dilated_conv1d_backward_kernel: %s; %s; %s; %s." s0 s1 s2 s3
    in
    Owl_exception.INVALID_ARGUMENT s4
  in
  Owl_exception.verify (p0 && p1 && p2 && p3) error;
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let in_channel = input_shp.(2) in
  let input_rows = 1 in
  let input = reshape input [| batches; input_rows; input_cols; in_channel |] in
  let kernel_shp = shape kernel in
  let kernel_cols = kernel_shp.(0) in
  let out_channel = kernel_shp.(2) in
  let p4 = in_channel = kernel_shp.(1) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Printf.sprintf "input shape is [%s]" s0 in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s4 =
      Printf.sprintf
        "the 3th dimension of input shape should be equal to the 2nd dimension of kernel \
         shape"
    in
    let s5 = Printf.sprintf "dilated_conv1d_backward_kernel: %s, %s, %s." s1 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify p4 error;
  let kernel_rows = 1 in
  let kernel = reshape kernel [| kernel_rows; kernel_cols; in_channel; out_channel |] in
  let output'_shp = shape output' in
  let output_cols = output'_shp.(1) in
  let p5 = batches = output'_shp.(0) in
  let p6 = out_channel = output'_shp.(2) in
  let error () =
    let s0 = Owl_utils_array.to_string string_of_int input_shp in
    let s1 = Owl_utils_array.to_string string_of_int output'_shp in
    let s2 = Owl_utils_array.to_string string_of_int kernel_shp in
    let s3 = Printf.sprintf "input shape is [%s]" s0 in
    let s4 = Printf.sprintf "output' shape is [%s]" s1 in
    let s5 = Printf.sprintf "kernel shape is [%s]" s2 in
    let s6 =
      Printf.sprintf
        "the 1st dimension of input shape should be equal to the 1st dimension of \
         output' shape"
    in
    let s7 =
      Printf.sprintf
        "the 3rd dimension of kernel shape should be equal to the 3rd dimension of \
         output' shape"
    in
    let s8 =
      Printf.sprintf "dilated_conv1d_backward_kernel: %s; %s; %s; %s; %s." s3 s4 s5 s6 s7
    in
    Owl_exception.INVALID_ARGUMENT s8
  in
  Owl_exception.verify (p5 && p6) error;
  let output_rows = 1 in
  let output' = reshape output' [| batches; output_rows; output_cols; out_channel |] in
  let col_stride = stride.(0) in
  let row_stride = 1 in
  let stride = [| row_stride; col_stride |] in
  let kernel' = dilated_conv2d_backward_kernel input kernel stride rate output' in
  reshape kernel' kernel_shp


(* Repeat a 4d input size.(0) times along its 2nd dimension and size.(1)
   times along its 3rd dimension. *)
let upsampling2d input size =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  repeat input [| 1; size.(0); size.(1); 1 |]


(* Gradient of upsampling2d: every output cell (c, r) adds its value to the
   input cell (c / col_scale, r / row_scale), clamped to the last valid
   index. Returns a fresh array with the shape of input. *)
let upsampling2d_backward input size output =
  let p0 = num_dims input = 4 in
  let p1 = Array.length size = 2 in
  let error () =
    let s0 = Printf.sprintf "input dimension = %i (should be 4)" (num_dims input) in
    let s1 = Printf.sprintf "size dimension = %i (should be 2)" (Array.length size) in
    let s2 = Printf.sprintf "upsampling2d_backward: %s; %s." s0 s1 in
    Owl_exception.INVALID_ARGUMENT s2
  in
  Owl_exception.verify (p0 && p1) error;
  let _kind = kind input in
  let input_shp = shape input in
  let batches = input_shp.(0) in
  let input_cols = input_shp.(1) in
  let input_rows = input_shp.(2) in
  let in_channel = input_shp.(3) in
  let col_scale = size.(0) in
  let row_scale = size.(1) in
  let output_shp = shape output in
  let output_cols = input_cols * col_scale in
  let output_rows = input_rows * row_scale in
  let p2 = output_cols = output_shp.(1) in
  let p3 = output_rows = output_shp.(2) in
  let error () =
    let s1 = Owl_utils_array.to_string string_of_int output_shp in
    let s2 = Printf.sprintf "output shape is [%s]" s1 in
    let s3 =
      Printf.sprintf
        "scaled output cols is %i, should be equal to the 2nd dimension of output shape"
        output_cols
    in
    let s4 =
      Printf.sprintf
        "scaled output rows is %i, should be equal to the 3rd dimension of output shape"
        output_rows
    in
    (* the message is kept on a single line, like every sibling message *)
    let s5 = Printf.sprintf "upsampling2d_backward: %s; %s; %s." s2 s3 s4 in
    Owl_exception.INVALID_ARGUMENT s5
  in
  Owl_exception.verify (p2 && p3) error;
  let input' = zeros _kind input_shp in
  for b = 0 to batches - 1 do
    for c = 0 to output_cols - 1 do
      let in_c = c / col_scale in
      let in_c = Stdlib.min in_c (input_cols - 1) in
      for r = 0 to output_rows - 1 do
        let in_r = r / row_scale in
        let in_r = Stdlib.min in_r (input_rows - 1) in
        for i = 0 to in_channel - 1 do
          let in_val = get input' [| b; in_c; in_r; i |] in
          let out_val = get output [| b; c; r; i |] in
          set input' [| b; in_c; in_r; i |] (in_val +. out_val)
        done
      done
    done
  done;
  input'


(* matrix functions *)

(* Drop every dimension of extent <= 1; an empty result becomes [| 1 |]. *)
let _remove_unit_dims dims =
  let removed_ones_list = List.filter (fun x -> x > 1) (Array.to_list dims) in
  let not_empty_list =
    match removed_ones_list with
    | [] -> [ 1 ]
    | _ -> removed_ones_list
  in
  Array.of_list not_empty_list


(* Raise Invalid_argument unless dims describes exactly two dimensions. *)
let _check_is_matrix dims =
  if Array.length dims != 2
  then raise (Invalid_argument "The given NDarray is not a matrix!")
  else ()


let row_num varr =
  let dims = shape varr in
  _check_is_matrix dims;
  dims.(0)


let col_num varr =
  let dims = shape varr in
  _check_is_matrix dims;
  dims.(1)


(* NOTE: this is a view into the original array *)
let row varr ind =
  let dims = shape varr in
  _check_is_matrix dims;
  Genarray.slice_left varr [| ind |]


(* Copy the rows listed in indices, in that order, into a new matrix. *)
let rows varr indices =
  let dims = shape varr in
  let _ = _check_is_matrix dims in
  let new_rownum = Array.length indices in
  let new_colnum = dims.(1) in
  let new_varr = empty (kind varr) [| new_rownum; new_colnum |] in
  for i = 0 to new_rownum - 1 do
    Genarray.blit
      (Genarray.slice_left varr [| indices.(i) |]) (* indices[i] row of the original *)
      (Genarray.slice_left new_varr [| i |]) (* i-th row of the new matrix *)
  done;
  new_varr


(* Overwrite row ind of the matrix varr with the contents of vec. *)
let copy_row_to vec varr ind =
  let dims = shape varr in
  let _ = _check_is_matrix dims in
  Genarray.blit vec (Genarray.slice_left varr [| ind |])


(* Overwrite column ind of the matrix varr with the elements of vec, which
   must have exactly one non-unit dimension equal to the row count of varr. *)
let copy_col_to vec varr ind =
  let dims = shape varr in
  let _ = _check_is_matrix dims in
  let vec_dims = _remove_unit_dims (shape vec) in
  let vec_len =
    if Array.length vec_dims = 1
    then vec_dims.(0)
    else raise (Invalid_argument "Vector is not a column vector")
  in
  let num_rows = dims.(0) in
  let vec_linear = flatten vec |> array1_of_genarray in
  if num_rows != vec_len
  then
    raise
      (Invalid_argument
         "Column vector does not have the same length as the number of rows in the matrix")
  else
    for i = 0 to num_rows - 1 do
      (* unsafe_get is in bounds: vec_len = num_rows was checked above *)
      Genarray.set varr [| i; ind |] (Array1.unsafe_get vec_linear i)
    done


(* Naive triple-loop matrix product of two float matrices. *)
let dot varr_a varr_b =
  let dims_a, dims_b = shape varr_a, shape varr_b in
  let _, _ = _check_is_matrix dims_a, _check_is_matrix dims_b in
  let m = dims_a.(0) in
  let cdim = dims_a.(1) in
  let n = dims_b.(1) in
  if dims_b.(0) != cdim
  then raise (Invalid_argument "Matrices cannot be multiplied")
  else (
    let varr_c = empty (kind varr_a) [| m; n |] in
    let sum = ref 0. in
    for i = 0 to m - 1 do
      for j = 0 to n - 1 do
        sum := 0.;
        for k = 0 to cdim - 1 do
          sum := !sum +. (Genarray.get varr_a [| i; k |] *. Genarray.get varr_b [| k; j |])
        done;
        Genarray.set varr_c [| i; j |] !sum
      done
    done;
    varr_c)


(* Sum of the diagonal of a square float matrix. *)
let trace varr =
  let dims = shape varr in
  let _ = _check_is_matrix dims in
  let n = dims.(0) in
  if dims.(1) != n
  then raise (Invalid_argument "Argument is not a square matrix")
  else (
    let sum = ref 0. in
    for i = 0 to n - 1 do
      sum := !sum +. Genarray.get varr [| i; i |]
    done;
    !sum)


(* NOTE: each row is actually a view in the original matrix, no copying involved *)
let to_rows varr =
  let dims = shape varr in
  let _ = _check_is_matrix dims in
  let m = dims.(0) in
  Array.init m (fun i -> Genarray.slice_left varr [| i |])


let to_cols _harr =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.to_cols")


(* Stack the given arrays along a new leading dimension; the kind and the
   shape of rows.(0) are used for the result. *)
let of_rows rows =
  let m = Array.length rows in
  let row_dim = shape rows.(0) in
  let dims = Array.append [| m |] row_dim in
  let varr = empty (kind rows.(0)) dims in
  for i = 0 to m - 1 do
    Genarray.blit rows.(i) (Genarray.slice_left varr [| i |])
  done;
  varr


let of_cols _cols =
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.of_cols")


(* Build an m x n matrix from an array of arrays; n is taken from arrays.(0).
   NOTE(review): elements are read with Array.unsafe_get, so every inner
   array is presumably expected to have at least n elements - confirm. *)
let of_arrays kind arrays =
  let m = Array.length arrays in
  let n = Array.length arrays.(0) in
  let varr = empty kind [| m; n |] in
  for i = 0 to m - 1 do
    for j = 0 to n - 1 do
      Genarray.set varr [| i; j |] (Array.unsafe_get arrays.(i) j)
    done
  done;
  varr


(* Draw count rows of the matrix varr at random, with or without
   replacement. Returns the extracted rows and the indices that were drawn. *)
let draw_rows ?(replacement = true) varr count =
  let dims = shape varr in
  _check_is_matrix dims;
  (* Sample from the number of rows. The previous code passed
     Array.length dims, i.e. the number of dimensions (always 2 for a
     matrix), so only rows 0 and 1 could ever be drawn. *)
  let indices = _draw_int_samples replacement dims.(0) count in
  let extracted = rows varr indices in
  extracted, indices


(* Draw the same randomly chosen rows from two matrices. *)
let draw_rows2 ?(replacement = true) varr_a varr_b count =
  let extracted_a, indices = draw_rows ~replacement varr_a count in
  let extracted_b = rows varr_b indices in
  extracted_a, extracted_b, indices


let diag ?(k = 0) _x =
  k |> ignore;
  raise (Owl_exception.NOT_IMPLEMENTED "owl_base_dense_ndarray_generic.diag")


(* TODO: here k is not used, but neither is it in nonbase dense array? - investigate *)

let load _k f = Owl_io.marshal_from_file f

(* For every row of a float matrix return (maximum, row index, column index
   of the maximum). The column index stays -1 when no element compares
   greater than negative infinity (e.g. a matrix with zero columns). *)
let max_rows varr =
  let dims = shape varr in
  let _ = _check_is_matrix dims in
  let r, c = dims.(0), dims.(1) in
  let result = Array.make r (0., 0, 0) in
  for i = 0 to r - 1 do
    (* Stdlib.min_float is the smallest POSITIVE float; seeding with it made
       rows of negative (or zero) values report position -1. *)
    let best = ref Stdlib.neg_infinity in
    let best_pos = ref ~-1 in
    for j = 0 to c - 1 do
      let x = get varr [| i; j |] in
      if x > !best
      then (
        best := x;
        best_pos := j)
    done;
    result.(i) <- !best, i, !best_pos
  done;
  result


let one_hot _depth _x = failwith "Owl_base_dense_ndarray_generic:one_hot: not implemented"

(* Helper functions *)

let float_to_elt x = x

let elt_to_float x = x

(* ends here *)