# 1 "src/owl/dense/owl_dense_ndarray_a.ml"
(*
* OWL - an OCaml numerical library for scientific computing
* Copyright (c) 2016-2018 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

(* This module supports N-dimensional arrays of any type. To differentiate it
   from the modules supporting numerical types, I made this module mostly
   self-contained. That means I do not intend the code in this module to be
   reused by other modules, and some functions are copied from other files and
   then adapted locally to provide the needed functionality.
*)

open Owl_types

open Owl_ndarray


(* An n-dimensional array with elements of any type ['a]. [data] holds the
   elements in a flat array. [stride] gives, per dimension, the distance in
   [data] between neighbouring indices. It is computed by
   [Owl_utils.calc_stride], presumably row-major (C layout); verify against
   Owl_utils. *)
type 'a arr = {
  mutable shape  : int array;
  mutable stride : int array;
  mutable data   : 'a array;
}

(* Number of elements implied by shape [s]: the product of its dimensions
   (1 for the empty shape). *)
let _calc_numel_from_shape s = Array.fold_left (fun a b -> a * b) 1 s

(* Build an array record from its components. Nothing is copied. *)
let make_arr shape stride data = { shape; stride; data; }

(* [create d a] makes an array of shape [d] with every element set to [a]. *)
let create d a =
  let n = _calc_numel_from_shape d in
  make_arr d (Owl_utils.calc_stride d) (Array.make n a)

(* [init d f] makes an array of shape [d] whose element at flat position [i]
   is [f i]. *)
let init d f =
  let n = _calc_numel_from_shape d in
  let data = Array.init n (fun i -> f i) in
  make_arr d (Owl_utils.calc_stride d) data

(* [init_nd d f] makes an array of shape [d] whose element at n-d index [j] is
   [f j]. The same index buffer is reused for every call, so [f] must not keep
   a reference to its argument. *)
let init_nd d f =
  let n = _calc_numel_from_shape d in
  let j = Array.copy d in
  let s = Owl_utils.calc_stride d in
  let data = Array.init n (fun i ->
    Owl_utils.index_1d_nd i j s;
    f j
  ) in
  (* [s] is only read above, so it can serve as the result's stride *)
  make_arr d s data

(* [sequential ~a ~step d] makes a float array of shape [d] whose elements, in
   flat order, start at [a] and are obtained by repeatedly adding [step]. *)
let sequential ?(a=0.) ?(step=1.) d =
  let n = _calc_numel_from_shape d in
  let a = ref (a -. step) in
  let data = Array.init n (fun _ -> a := !a +. step; !a) in
  make_arr d (Owl_utils.calc_stride d) data

let zeros d = create d 0.

let ones d = create d 1.

let num_dims x = Array.length x.shape

(* Returns a copy of the shape, so callers may mutate the result. *)
let shape x = Array.copy x.shape

let nth_dim x i = x.shape.(i)

let numel x = _calc_numel_from_shape x.shape

(* Element at n-d index [i]. *)
let get x i = x.data.(Owl_utils.index_nd_1d i x.stride)

(* Overwrite the element at n-d index [i] with [a]. *)
let set x i a = x.data.(Owl_utils.index_nd_1d i x.stride) <- a

(* Convert per-dimension coordinate arrays into one n-d index per point.
   [axis] must have one array per dimension of [x], all of the same length n.
   The k-th result is [| axis.(0).(k); ...; axis.(d-1).(k) |]. *)
let _indices_of_axis x axis =
  let d = num_dims x in
  assert (Array.length axis = d);
  let n = Array.length axis.(0) in
  Array.iter (fun a -> assert (Array.length a = n)) axis;
  Array.init n (fun k -> Array.init d (fun j -> axis.(j).(k)))

(* [get_index x axis] returns the elements at the points described by [axis]
   (one coordinate array per dimension, all of the same length). *)
let get_index x axis =
  let indices = _indices_of_axis x axis in
  Array.map (fun i -> get x i) indices

(* [set_index x axis a] writes [a.(k)] at the k-th point described by [axis]
   (same layout as in [get_index]). *)
let set_index x axis a =
  let indices = _indices_of_axis x axis in
  Array.iteri (fun k i -> set x i a.(k)) indices

(* Placeholder: [slice_left] is not implemented by this module. *)
let slice_left = None

(* Overwrite the elements of [dst] with those of [src]. Both must have the
   same shape. *)
let copy_to src dst =
  assert (src.shape = dst.shape);
  Array.blit src.data 0 dst.data 0 (numel src)

(* Set every element of [x] to [a], in place. *)
let fill x a = Array.fill x.data 0 (numel x) a

(* [reshape x d] views [x] with the new shape [d]. The data array is shared,
   not copied. [d] must describe the same number of elements as [x]. *)
let reshape x d =
  let m = _calc_numel_from_shape x.shape in
  let n = _calc_numel_from_shape d in
  assert (m = n);
  (* the stride is a function of the shape, so recompute it for [d]; reusing
     [x.stride] would make [get]/[set] address the wrong elements *)
  make_arr d (Owl_utils.calc_stride d) x.data

(* One-dimensional view of [x]. Shares the data array. *)
let flatten x = make_arr [|Array.length x.data|] [|1|] x.data

(* Deep copy: shape, stride and data are all fresh arrays. *)
let copy x = {
  shape  = Array.copy x.shape;
  stride = Array.copy x.stride;
  data   = Array.copy x.data;
}

let same_shape x y = x.shape = y.shape

(* [sub_left x i] returns a copy of the sub-array obtained by fixing the
   leading [Array.length i] indices of [x] to [i]. At least one dimension must
   remain. The sub-array is read as one contiguous run of
   [x.stride.(Array.length i - 1)] elements. *)
let sub_left x i =
  let i_len = Array.length i in
  let s_len = x.stride.(i_len - 1) in
  let pad_len = num_dims x - i_len in
  assert (pad_len > 0);
  (* the sub-array starts at index i followed by zeros *)
  let i = Owl_utils.Array.pad `Right i 0 pad_len in
  let start_pos = Owl_utils.index_nd_1d i x.stride in
  let data_y = Array.sub x.data start_pos s_len in
  let shape_y = Array.sub x.shape i_len pad_len in
  let stride_y = Array.sub x.stride i_len pad_len in
  make_arr shape_y stride_y data_y

(* [squeeze ~axis x] drops dimensions of size 1. Only the dimensions listed in
   [axis] are candidates; an empty [axis] (the default) means all of them. *)
let squeeze ?(axis=[||]) x =
  let a = match Array.length axis with
    | 0 -> Array.init (num_dims x) (fun i -> i)
    | _ -> axis
  in
  let s = Owl_utils.Array.filteri (fun i v -> not (v = 1 && Array.mem i a)) (shape x) in
  reshape x s

(* [expand ~hi x d] promotes [x] to [d] dimensions by adding size-1 dimensions
   on the right if [hi] is true, on the left otherwise. [x] is returned
   unchanged if it already has at least [d] dimensions. *)
let expand ?(hi=false) x d =
  let d0 = d - num_dims x in
  if d0 > 0 then (
    if hi then reshape x (Owl_utils.Array.pad `Right (shape x) 1 d0)
    else reshape x (Owl_utils.Array.pad `Left (shape x) 1 d0)
  )
  else x

(* Copy of [x] with its elements in reverse flat order (same shape). *)
let reverse x =
  let y = copy x in
  Owl_utils.array_reverse y.data;
  y


(* iteration functions *)

(* Apply [f] to every element in flat order, discarding results. *)
let iter f x =
  for i = 0 to (numel x) - 1 do
    f x.data.(i) |> ignore
  done

(* Like [iter], but [f] also receives the flat position. *)
let iteri f x =
  for i = 0 to (numel x) - 1 do
    f i x.data.(i) |> ignore
  done

(* New array (same shape) holding [f] of each element. *)
let map f x =
  make_arr x.shape x.stride (Array.init (numel x) (fun i -> f x.data.(i)))

let mapi f x =
  make_arr x.shape x.stride (Array.init (numel x) (fun i -> f i x.data.(i)))

(* Apply [f] to pairs of elements at the same position; shapes must match. *)
let iter2 f x y =
  assert (x.shape = y.shape);
  for i = 0 to (numel x) - 1 do
    f x.data.(i) y.data.(i) |> ignore
  done

let iter2i f x y =
  assert (x.shape = y.shape);
  for i = 0 to (numel x) - 1 do
    f i x.data.(i) y.data.(i) |> ignore
  done

let map2 f x y =
  assert (x.shape = y.shape);
  make_arr x.shape x.stride (Array.init (numel x) (fun i -> f x.data.(i) y.data.(i)))

let map2i f x y =
  assert (x.shape = y.shape);
  make_arr x.shape x.stride (Array.init (numel x) (fun i -> f i x.data.(i) y.data.(i)))

(* Flat positions [i] whose element [y] satisfies [f i y]. *)
let filteri f x =
  let s = Owl_utils.Stack.make () in
  iteri (fun i y ->
    if f i y then Owl_utils.Stack.push s i
  ) x;
  Owl_utils.Stack.to_array s

(* Flat positions whose element satisfies [f]. *)
let filter f x = filteri (fun _ y -> f y) x

let fold f a x = Array.fold_left f a x.data

(* Like [fold], but [f] also receives the flat position first. *)
let foldi f a x =
  let a = ref a in
  for i = 0 to numel x - 1 do
    a := f i !a x.data.(i)
  done;
  !a

let exists f x = Array.exists f x.data

let not_exists f x = not (exists f x)

let for_all f x = Array.for_all f x.data


(* some comparison functions
*)

(* [_for_all2 p x y] is [true] iff [p a b] holds for every pair of elements
   [a], [b] at the same flat position in [x] and [y]. It stops at the first
   failing pair, and exceptions raised by [p] propagate to the caller. [x] and
   [y] must have the same shape. *)
let _for_all2 p x y =
  assert (x.shape = y.shape);
  let n = numel x in
  let rec loop i = i >= n || (p x.data.(i) y.data.(i) && loop (i + 1)) in
  loop 0

(* All comparisons below take an optional [cmp] (default [Pervasives.compare])
   and interpret [cmp a b] only by its sign: negative for a < b, zero for
   a = b, positive for a > b. Any comparator following that contract works,
   including ones that return values other than -1/0/1. *)

(* [true] iff every element of [x] equals its counterpart in [y]. *)
let is_equal ?(cmp=Pervasives.compare) x y =
  _for_all2 (fun a b -> cmp a b = 0) x y

(* [true] iff every element of [x] differs from its counterpart in [y]. *)
let not_equal ?(cmp=Pervasives.compare) x y =
  _for_all2 (fun a b -> cmp a b <> 0) x y

let greater ?(cmp=Pervasives.compare) x y =
  _for_all2 (fun a b -> cmp a b > 0) x y

let less ?(cmp=Pervasives.compare) x y =
  _for_all2 (fun a b -> cmp a b < 0) x y

let greater_equal ?(cmp=Pervasives.compare) x y =
  _for_all2 (fun a b -> cmp a b >= 0) x y

let less_equal ?(cmp=Pervasives.compare) x y =
  _for_all2 (fun a b -> cmp a b <= 0) x y

(* Element-wise comparisons: each returns a bool array of the same shape. *)

let elt_equal ?(cmp=Pervasives.compare) x y = map2 (fun a b -> cmp a b = 0) x y

let elt_not_equal ?(cmp=Pervasives.compare) x y = map2 (fun a b -> cmp a b <> 0) x y

let elt_greater ?(cmp=Pervasives.compare) x y = map2 (fun a b -> cmp a b > 0) x y

let elt_less ?(cmp=Pervasives.compare) x y = map2 (fun a b -> cmp a b < 0) x y

let elt_greater_equal ?(cmp=Pervasives.compare) x y = map2 (fun a b -> cmp a b >= 0) x y

let elt_less_equal ?(cmp=Pervasives.compare) x y = map2 (fun a b -> cmp a b <= 0) x y

(* Element-wise comparisons of every element of [x] against the scalar [b]. *)

let elt_equal_scalar ?(cmp=Pervasives.compare) x b = map (fun a -> cmp a b = 0) x

let elt_not_equal_scalar ?(cmp=Pervasives.compare) x b = map (fun a -> cmp a b <> 0) x

let elt_greater_scalar ?(cmp=Pervasives.compare) x b = map (fun a -> cmp a b > 0) x

let elt_less_scalar ?(cmp=Pervasives.compare) x b = map (fun a -> cmp a b < 0) x

let elt_greater_equal_scalar ?(cmp=Pervasives.compare) x b = map (fun a -> cmp a b >= 0) x

let elt_less_equal_scalar ?(cmp=Pervasives.compare) x b = map (fun a -> cmp a b <= 0) x

(* Sort the elements of [x] in place, in flat order. *)
let sort ?(cmp=Pervasives.compare) x = Array.sort cmp x.data

(* Largest element under [cmp] (the first one on ties). [x] must be
   non-empty. *)
let max ?(cmp=Pervasives.compare) x =
  let r = ref x.data.(0) in
  iter (fun a -> if cmp a !r > 0 then r := a) x;
  !r

(* Smallest element under [cmp] (the first one on ties). [x] must be
   non-empty. *)
let min ?(cmp=Pervasives.compare) x =
  let r = ref x.data.(0) in
  iter (fun a -> if cmp !r a > 0 then r := a) x;
  !r

(* Largest element and its flat position (the first one on ties). *)
let max_i ?(cmp=Pervasives.compare) x =
  let r = ref x.data.(0) in
  let j = ref 0 in
  iteri (fun i a -> if cmp a !r > 0 then (r := a; j := i)) x;
  !r, !j

(* Smallest element and its flat position (the first one on ties). *)
let min_i ?(cmp=Pervasives.compare) x =
  let r = ref x.data.(0) in
  let j = ref 0 in
  iteri (fun i a -> if cmp !r a > 0 then (r := a; j := i)) x;
  !r, !j


(* operational functions *)

(* Assert that [axis] is a permutation of [0 .. d-1]: it has length [d],
   every entry is in range, and there are no duplicates. *)
let _check_transpose_axis axis d =
  assert (Array.length axis = d);
  let h = Hashtbl.create 16 in
  Array.iter (fun x ->
    assert (0 <= x && x < d);
    assert (Hashtbl.mem h x = false);
    Hashtbl.add h x 0
  ) axis

(* [transpose ~axis x] returns a new array whose dimension [k] is dimension
   [axis.(k)] of [x]. By default all dimensions are reversed. *)
let transpose ?axis x =
  let d = num_dims x in
  let a = match axis with
    | Some a -> a
    | None   -> Array.init d (fun i -> d - i - 1)
  in
  (* check if axis is a correct permutation *)
  _check_transpose_axis a d;
  let s0 = shape x in
  let s1 = Array.map (fun j -> s0.(j)) a in
  let i' = Array.make d 0 in
  let i = Array.make d 0 in
  let y = make_arr s1 (Owl_utils.calc_stride s1) (Array.copy x.data) in
  (* decode each source position, permute its index and store it in y *)
  iteri (fun i_1d z ->
    Owl_utils.index_1d_nd i_1d i x.stride;
    Array.iteri (fun k j -> i'.(k) <- i.(j)) a;
    set y i' z
  ) x;
  y

(* [tile x reps] repeats the whole of [x] [reps.(k)] times along dimension
   [k]. If [x] and [reps] differ in length, the shorter one is left-padded
   with 1s. Every entry of [reps] must be >= 1, and [x] must be non-empty. *)
let tile x reps =
  (* check the validity of reps *)
  if Array.exists ((>) 1) reps then
    failwith "tile: repitition must be >= 1";
  (* align and promote the shape *)
  let a = num_dims x in
  let b = Array.length reps in
  let x, reps =
    if a < b then (
      let d = Owl_utils.Array.pad `Left (shape x) 1 (b - a) in
      (reshape x d), reps
    )
    else (
      let r = Owl_utils.Array.pad `Left reps 1 (a - b) in
      x, r
    )
  in
  (* calculate the smallest continuous slice dx: trailing dimensions that are
     not repeated are folded into one contiguous run of [!dx] elements *)
  let i = ref (Array.length reps - 1) in
  let sx = shape x in
  let dx = ref sx.(!i) in
  while reps.(!i) = 1 && !i - 1 >= 0 do
    i := !i - 1;
    dx := !dx * sx.(!i);
  done;
  (* make the array to store the result *)
  let sy = Owl_utils.Array.map2i (fun _ a b -> a * b) sx reps in
  let y_data = Array.make (_calc_numel_from_shape sy) x.data.(0) in
  let y = make_arr sy (Owl_utils.calc_stride sy) y_data in
  (* project x and y to 1-dimensional arrays *)
  let x1 = x.data in
  let y1 = y.data in
  let stride_x = Owl_utils.calc_stride (shape x) in
  let stride_y = Owl_utils.calc_stride (shape y) in
  (* recursively tile the data within y: at level !i, copy the contiguous run
     reps times; above it, fill the first copy dimension by dimension and then
     duplicate the filled region reps.(lvl) - 1 more times *)
  let rec _tile ofsx ofsy lvl =
    if lvl = !i then (
      for k = 0 to reps.(lvl) - 1 do
        let ofsy' = ofsy + (k * !dx) in
        Array.blit x1 ofsx y1 ofsy' !dx;
      done;
    )
    else (
      for j = 0 to sx.(lvl) - 1 do
        let ofsx' = ofsx + j * stride_x.(lvl) in
        let ofsy' = ofsy + j * stride_y.(lvl) in
        _tile ofsx' ofsy' (lvl + 1);
      done;
      let _len = stride_y.(lvl) * sx.(lvl) in
      for k = 1 to reps.(lvl) - 1 do
        let ofsy' = ofsy + (k * _len) in
        Array.blit y1 ofsy y1 ofsy' _len
      done
    )
  in
  _tile 0 0 0;
  y

(* [repeat ~axis x reps] repeats every slice of [x] along [axis] (default: the
   last dimension) [reps] times in a row. For example, repeating [|a; b|]
   twice along its only axis gives [|a; a; b; b|]. [x] must be non-empty. *)
let repeat ?axis x reps =
  let highest_dim = Array.length (shape x) - 1 in
  (* by default, repeat at the highest dimension *)
  let axis = match axis with
    | Some a -> a
    | None   -> highest_dim
  in
  (* calculate the new shape of y based on reps *)
  let _shape_y = shape x in
  _shape_y.(axis) <- _shape_y.(axis) * reps;
  let y_data = Array.make (_calc_numel_from_shape _shape_y) x.data.(0) in
  let y = make_arr _shape_y (Owl_utils.calc_stride _shape_y) y_data in
  (* transform into a flat array first *)
  let x' = x.data in
  let y' = y.data in
  if axis = highest_dim then (
    (* each element is written reps times in a row, starting at i * reps *)
    for i = 0 to numel x - 1 do
      Array.fill y' (i * reps) reps x'.(i)
    done
  )
  else (
    (* assuming row-major strides from Owl_utils.calc_stride, fixing the
       indices of dimensions [0 .. axis] selects a contiguous run of
       [_stride_x.(axis)] elements; each run is copied reps times in a row *)
    let _stride_x = Owl_utils.calc_stride (shape x) in
    let _slice_sz = _stride_x.(axis) in
    for i = 0 to (numel x) / _slice_sz - 1 do
      let ofsx = i * _slice_sz in
      for j = 0 to reps - 1 do
        let ofsy = (i * reps + j) * _slice_sz in
        Array.blit x' ofsx y' ofsy _slice_sz;
      done
    done
  );
  (* all done, return the result *)
  y

(* [concatenate ~axis xs] joins the arrays [xs] along dimension [axis]
   (default 0). All arrays must have the same shape except along [axis], and
   [xs] must be non-empty. *)
let concatenate ?(axis=0) xs =
  (* get the shapes of all inputs, etc. *)
  let shapes = Array.map shape xs in
  let shape0 = Array.copy shapes.(0) in
  shape0.(axis) <- 0;
  let acc_dim = ref 0 in
  (* validate all the input shapes; update step_sz *)
  let step_sz = Array.(make (length xs) 0) in
  Array.iteri (fun i shape1 ->
    step_sz.(i) <- (Owl_utils.calc_slice shape1).(axis);
    acc_dim := !acc_dim + shape1.(axis);
    shape1.(axis) <- 0;
    assert (shape0 = shape1);
  ) shapes;
  (* allocate space for the new array *)
  shape0.(axis) <- !acc_dim;
  let y_data = Array.make (_calc_numel_from_shape shape0) xs.(0).data.(0) in
  let y = make_arr shape0 (Owl_utils.calc_stride shape0) y_data in
  (* flatten y then calculate the number of copies *)
  let z = y.data in
  let slice_sz = (Owl_utils.calc_slice shape0).(axis) in
  let m = numel y / slice_sz in
  let n = Array.length xs in
  (* flatten all the inputs and init the copy location *)
  let x_flt = Array.map (fun x -> x.data) xs in
  let x_ofs = Array.make n 0 in
  (* copy data in the flattened space: for each outer index, append one
     step_sz.(j)-long chunk from every input in turn *)
  let z_ofs = ref 0 in
  for i = 0 to m - 1 do
    for j = 0 to n - 1 do
      Array.blit x_flt.(j) x_ofs.(j) z !z_ofs step_sz.(j);
      x_ofs.(j) <- x_ofs.(j) + step_sz.(j);
      z_ofs := !z_ofs + step_sz.(j);
    done;
  done;
  (* all done, return the combined result *)
  y


(* The following four padding-related functions are simply replicas
from Owl_dense_ndarray_generic module, so please refer to that module. *)let_expand_padding_indexds=letls=Array.lengthsinletld=Array.lengthdinletd=Owl_utils.(Array.pad`Rightd[|0;0|](ls-ld))inArray.map(function|[||]->[|0;0|]|[|x|]->[|x;x|]|x->x)dletrec_copy_to_paddingp1lsl0l1i0i1d0d1s0s1x0x1=ifd0<d1then(fori=0tos0.(d0)-1doi0.(d0)<-i;i1.(d0)<-i+p1.(d0).(0);_copy_to_paddingp1lsl0l1i0i1(d0+1)d1s0s1x0x1;i0.(d0)<-0;i1.(d0)<-p1.(d0).(0);done)else(letj0=Owl_utils.index_nd_1di0l0inletj1=Owl_utils.index_nd_1di1l1inArray.blitx0j0x1j1ls.(d0))let_highest_padding_dimensionp=letl=Array.lengthp-1inletd=reflin(tryfori=ldownto0dod:=i;ifp.(i)<>[|0;0|]thenfailwith"stop"donewithexn->());!dletpadvdx=lets0=shapexinletp1=_expand_padding_index(Owl_utils.llss2aarrd)s0inlets1=Array.map2(funmn->m+n.(0)+n.(1))s0p1in(* create ndarray y for storing the result *)lety_data=Array.make(_calc_numel_from_shapes1)vinlety=make_arrs1(Owl_utils.calc_strides1)y_datain(* prepare variables for block copying *)letls=Owl_utils.calc_slices0inletl0=Owl_utils.calc_strides0inletl1=Owl_utils.calc_strides1inleti0=Array.make(num_dimsx)0inleti1=Array.map(funa->a.(0))p1inletd0=0inletd1=_highest_padding_dimensionp1inletx0=x.datainletx1=y.datain_copy_to_paddingp1lsl0l1i0i1d0d1s0s1x0x1;y(* get_fancy function is adapted from its original implementation in owl_slicing
module; refer to the Owl_slicing module for more information.
*)

(* [get_fancy axis x] returns a fresh array holding the slice of [x] selected
   by the slice definition [axis]. The definition is converted with
   [Owl_slicing.sdlist_to_sdarray] and validated against the shape of [x].
   [x] must be non-empty, because its first element initialises the result.
   Copying is done either with one blit per contiguous block, or element by
   element when the last dimension is a strided range. *)
let get_fancy axis x =
  let raw_def = Owl_slicing.sdlist_to_sdarray axis in
  let src_shape = shape x in
  (* check the definition is within bounds, then re-format it *)
  let slice_def = Owl_slicing.check_slice_definition raw_def src_shape in
  let dst_shape = Owl_slicing.calc_slice_shape slice_def in
  let dst_data = Array.make (_calc_numel_from_shape dst_shape) x.data.(0) in
  let y = make_arr dst_shape (Owl_utils.calc_stride dst_shape) dst_data in
  let src = x.data in
  let dst = y.data in
  let last = Array.length dst_shape - 1 in
  let blk_dim, blk_sz = Owl_slicing.calc_continuous_blksz slice_def src_shape in
  let src_stride = Owl_utils.calc_stride src_shape in
  (* number of blocks written so far; determines the write position in dst *)
  let copied = ref 0 in
  let last_is_list =
    match slice_def.(last) with
    | L_ _ -> true
    | _    -> false
  in
  let contiguous = blk_sz > 1 || src_shape.(last) = 1 || last_is_list in
  if not contiguous then begin
    (* the last dimension is a strided range: copy one element at a time *)
    let len = dst_shape.(last) in
    let step, start =
      match slice_def.(last) with
      | R_ r when r.(2) > 0 -> r.(2), r.(0)
      | R_ r                -> r.(2), r.(0) + (len - 1) * r.(2)
      | _ -> failwith "owl_dense_ndarray_a:slice_array_typ"
    in
    let src_inc = if step > 0 then step else -step in
    let dst_inc = if step > 0 then 1 else -1 in
    Owl_slicing._foreach_continuous_blk slice_def (blk_dim - 1) (fun idx ->
      let src_pos = ref (Owl_utils.index_nd_1d idx src_stride + start) in
      let dst_pos =
        ref (if step > 0 then !copied * len else (!copied + 1) * len - 1)
      in
      for _k = 1 to len do
        dst.(!dst_pos) <- src.(!src_pos);
        src_pos := !src_pos + src_inc;
        dst_pos := !dst_pos + dst_inc
      done;
      copied := !copied + 1);
    y
  end
  else begin
    (* there are contiguous blocks: one blit per block *)
    Owl_slicing._foreach_continuous_blk slice_def blk_dim (fun idx ->
      let src_pos = Owl_utils.index_nd_1d idx src_stride in
      Array.blit src src_pos dst (!copied * blk_sz) blk_sz;
      copied := !copied + 1);
    y
  end


(* The set_fancy function is adapted from its original implementation in the owl_slicing
module; refer to the Owl_slicing module for more information.
*)

(* [set_fancy axis x y] writes the elements of [y] into the slice of [x]
   selected by the slice definition [axis], in place. The definition is
   converted with [Owl_slicing.sdlist_to_sdarray] and validated against the
   shape of [x]. The shape of [y] must equal the slice shape. *)
let set_fancy axis x y =
  let raw_def = Owl_slicing.sdlist_to_sdarray axis in
  let dst_shape = shape x in
  (* check the definition is within bounds, then re-format it *)
  let slice_def = Owl_slicing.check_slice_definition raw_def dst_shape in
  let slice_shape = Owl_slicing.calc_slice_shape slice_def in
  assert (shape y = slice_shape);
  let dst = x.data in
  let src = y.data in
  let last = Array.length slice_shape - 1 in
  let blk_dim, blk_sz = Owl_slicing.calc_continuous_blksz slice_def dst_shape in
  let dst_stride = Owl_utils.calc_stride dst_shape in
  (* number of blocks consumed so far; determines the read position in y *)
  let copied = ref 0 in
  let last_is_list =
    match slice_def.(last) with
    | L_ _ -> true
    | _    -> false
  in
  let contiguous = blk_sz > 1 || dst_shape.(last) = 1 || last_is_list in
  if not contiguous then begin
    (* the last dimension is a strided range: copy one element at a time *)
    let len = slice_shape.(last) in
    let step, start =
      match slice_def.(last) with
      | R_ r when r.(2) > 0 -> r.(2), r.(0)
      | R_ r                -> r.(2), r.(0) + (len - 1) * r.(2)
      | _ -> failwith "owl_dense_ndarray_a:slice_array_typ"
    in
    let dst_inc = if step > 0 then step else -step in
    let src_inc = if step > 0 then 1 else -1 in
    Owl_slicing._foreach_continuous_blk slice_def (blk_dim - 1) (fun idx ->
      let dst_pos = ref (Owl_utils.index_nd_1d idx dst_stride + start) in
      let src_pos =
        ref (if step > 0 then !copied * len else (!copied + 1) * len - 1)
      in
      for _k = 1 to len do
        dst.(!dst_pos) <- src.(!src_pos);
        dst_pos := !dst_pos + dst_inc;
        src_pos := !src_pos + src_inc
      done;
      copied := !copied + 1)
  end
  else
    (* there are contiguous blocks: one blit per block *)
    Owl_slicing._foreach_continuous_blk slice_def blk_dim (fun idx ->
      let dst_pos = Owl_utils.index_nd_1d idx dst_stride in
      Array.blit src (!copied * blk_sz) dst dst_pos blk_sz;
      copied := !copied + 1)


(* Simplified get_fancy function which accepts a list of lists as the slice definition.
Adapted from the owl_slicing module; refer to Owl_slicing for more information.
*)

(* [get_slice axis x] slices [x] using plain ranges. Each int list in [axis]
   is wrapped as a range definition [R], and the work is delegated to
   [get_fancy]. *)
let get_slice axis x =
  get_fancy (List.map (fun r -> R r) axis) x


(* Simplified set_slice function which accepts a list of lists as the slice definition.
Adapted from the owl_slicing module; refer to Owl_slicing for more information.
*)

(* [set_slice axis x y] writes [y] into the slice of [x] described by plain
   ranges: each int list in [axis] becomes a range definition [R]. *)
let set_slice axis x y =
  let axis = List.map (fun i -> R i) axis in
  set_fancy axis x y

(* [swap a0 a1 x] returns a transposed copy of [x] with dimensions [a0] and
   [a1] exchanged. *)
let swap a0 a1 x =
  let d = num_dims x in
  let a = Array.init d (fun i -> i) in
  let t = a.(a0) in
  a.(a0) <- a.(a1);
  a.(a1) <- t;
  transpose ~axis:a x

(* Strides of [x], recomputed from its shape. *)
let strides x = x |> shape |> Owl_utils.calc_stride

(* Slice sizes of [x], computed from its shape by [Owl_utils.calc_slice]. *)
let slice_size x = x |> shape |> Owl_utils.calc_slice

(* Flat position of the n-d index [i_nd] under [_stride]. *)
let index_nd_1d i_nd _stride = Owl_utils.index_nd_1d i_nd _stride

(* n-d index of flat position [i_1d] under [_stride], as a fresh array. *)
let index_1d_nd i_1d _stride =
  let i_nd = Array.copy _stride in
  Owl_utils.index_1d_nd i_1d i_nd _stride;
  i_nd


(* input/output functions *)

(* [of_array x d] wraps the flat array [x] as an array of shape [d] without
   copying it. The length of [x] must match the number of elements of [d];
   otherwise later [get]/[set] calls would address invalid positions. *)
let of_array x d =
  assert (Array.length x = _calc_numel_from_shape d);
  make_arr d (Owl_utils.calc_stride d) x

(* The underlying flat data of [x], not a copy: mutating it mutates [x]. *)
let to_array x = x.data


(* ends here *)