# 1 "src/base/core/owl_base_slicing.ml"
(*
 * OWL - OCaml Scientific Computing
 * Copyright (c) 2016-2022 Liang Wang <liang@ocaml.xyz>
 *)

open Bigarray
open Owl_types

(* convert an array of slice definitions to the array-backed form used
   internally: I i becomes I_ i, while the lists carried by L and R are turned
   into arrays carried by L_ and R_ *)
let sdarray_to_sdarray axis =
  Array.map
    (function
      | I i -> I_ i
      | L i -> L_ (Array.of_list i)
      | R i -> R_ (Array.of_list i))
    axis


(* convert from a list of slice definition to array for internal use; same
   element-wise conversion as sdarray_to_sdarray *)
let sdlist_to_sdarray axis = sdarray_to_sdarray (Array.of_list axis)

(* return true if slicing (all R_) or false if fancy indexing (has L_) *)
let is_basic_slicing =
  Array.for_all (function
      | R_ _ -> true
      | _ -> false)


(* check the validity of the slice definition, also re-format slice definition,
   axis: slice definition;
   shp: shape of the original ndarray;

   The result has one entry per dimension of shp. Dimensions missing from axis
   are selected in full. Negative indices are counted from the end of the
   dimension. Each entry of the result is either
   - R_ [|start; stop; step|], or
   - L_ with all indices resolved.
   An index outside the dimension, a zero step, or a step whose sign
   contradicts the direction from start to stop trips an assertion. An R_ with
   more than three elements raises Failure.
 *)
let check_slice_definition axis shp =
  let axis_len = Array.length axis in
  let shp_len = Array.length shp in
  assert (axis_len <= shp_len);
  (* add missing definition on higher dimensions *)
  let axis =
    if axis_len < shp_len
    then Array.append axis (Array.make (shp_len - axis_len) (R_ [||]))
    else axis
  in
  (* Resolve a possibly negative index against a dimension of size n. The lower
     bound is checked as well as the upper one: an index below -n is still
     negative after adding n and must not be accepted. *)
  let resolve n k =
    let k = if k >= 0 then k else n + k in
    assert (k >= 0 && k < n);
    k
  in
  (* re-format slice definition, note I_ will be replaced with R_ *)
  Array.map2
    (fun def n ->
      match def with
      | I_ x ->
        let x = resolve n x in
        R_ [| x; x; 1 |]
      | L_ x ->
        (* an index list equal to 0, 1, ..., n-1 is the whole dimension and is
           turned into the equivalent range *)
        let is_cont = ref (Array.length x = n) in
        let x =
          Array.mapi
            (fun pos k ->
              let k = resolve n k in
              if pos <> k then is_cont := false;
              k)
            x
        in
        if !is_cont then R_ [| 0; n - 1; 1 |] else L_ x
      | R_ x ->
        (match Array.length x with
        | 0 -> R_ [| 0; n - 1; 1 |]
        | 1 ->
          let a = resolve n x.(0) in
          R_ [| a; a; 1 |]
        | 2 ->
          let a = resolve n x.(0) in
          let b = resolve n x.(1) in
          (* step direction is inferred from the order of the bounds *)
          let c = if a <= b then 1 else -1 in
          R_ [| a; b; c |]
        | 3 ->
          let a = resolve n x.(0) in
          let b = resolve n x.(1) in
          let c = x.(2) in
          assert (c <> 0);
          assert (not ((a < b && c < 0) || (a > b && c > 0)));
          R_ [| a; b; c |]
        | _ -> failwith "check_slice_definition: error"))
    axis
    shp


(* calculate the minimum continuous block size and its corresponding dimension
   axis: slice definition;
   shp: shape of the original ndarray;

   Returns (d, ssz). Dimensions are scanned from the last one towards the
   first for as long as they are selected in full, i.e. R_ [|0; size-1; 1|].
   d is the index of the first dimension (counted from the end) that is not
   fully selected, plus one; it is 0 when every dimension is fully selected,
   in which case the whole array is copied. ssz is the entry of
   Owl_utils.calc_slice shp for the last fully selected dimension visited, or
   1 if there is none.

   axis is expected in the form returned by check_slice_definition: one entry
   per dimension of shp, every R_ holding three elements.
 *)
let calc_continuous_blksz axis shp =
  let slice_sz = Owl_utils.calc_slice shp in
  let rec scan l ssz =
    if l < 0
    then 0, ssz
    else (
      match axis.(l) with
      | R_ x when x.(0) = 0 && x.(1) = shp.(l) - 1 && x.(2) = 1 ->
        scan (l - 1) slice_sz.(l)
      | _ -> l + 1, ssz)
  in
  scan (Array.length shp - 1) 1


(* calculate the shape according to the slice definition
   axis: slice definition

   Each L_ contributes the number of listed indices, each R_ [|a; b; c|]
   contributes abs ((b - a) / c) + 1.
 *)
let calc_slice_shape axis =
  Array.map
    (function
      | I_ _x -> 1 (* never reached *)
      | L_ x -> Array.length x
      | R_ x ->
        let a, b, c = x.(0), x.(1), x.(2) in
        Stdlib.(abs ((b - a) / c)) + 1)
    axis


(* recursively copy the continuous block, stop at its corresponding dimension d
   a: slice definition
   d: the corresponding dimension of continuous block + 1
   j: current dimension index
   i: current index of the data for copying
   f: copy function of the continuous block

   i is updated in place: i.(j) is set to every index selected by a.(j) in
   turn, and f is called with i once j reaches d.
 *)
let rec __foreach_continuous_blk a d j i f =
  if j = d
  then f i
  else (
    match a.(j) with
    | I_ _ -> ( (* never reached *) )
    | L_ x ->
      Array.iter
        (fun k ->
          i.(j) <- k;
          __foreach_continuous_blk a d (j + 1) i f)
        x
    | R_ x ->
      let k = ref x.(0) in
      if x.(2) > 0
      then
        (* ascending range, both bounds inclusive *)
        while !k <= x.(1) do
          i.(j) <- !k;
          k := !k + x.(2);
          __foreach_continuous_blk a d (j + 1) i f
        done
      else
        (* descending range, both bounds inclusive *)
        while !k >= x.(1) do
          i.(j) <- !k;
          k := !k + x.(2);
          __foreach_continuous_blk a d (j + 1) i f
        done)


(* a : slice definition, same rank as original ndarray
   d : the corresponding dimension of the continuous block +1
   f : the copy function for the continuous block

   Allocates the index array, all zeros, and starts the recursion at
   dimension 0.
 *)
let _foreach_continuous_blk a d f =
  let i = Array.(make (length a) 0) in
  __foreach_continuous_blk a d 0 i f


(* reshape inputs in order to optimise the slicing performance
   axis: slice definition, indexed by the dimensions of x;
   x, y: the two ndarrays to reshape (presumably the source and the
         destination of the slicing, to be confirmed against the caller);

   Trailing dimensions that are selected in full, i.e. R_ [|0; size-1; 1|],
   are merged into a single dimension. When at least two of them can be
   merged, returns the shortened slice definition together with x and y
   reshaped accordingly; otherwise returns the three inputs unchanged.
 *)
let optimise_input_shape axis x y =
  let n = Genarray.num_dims x in
  let sx = Genarray.dims x in
  let sy = Genarray.dims y in
  let is_full i =
    match axis.(i) with
    | R_ a -> a.(0) = 0 && a.(1) = sx.(i) - 1 && a.(2) = 1
    | _ -> false
  in
  (* walk from the last dimension towards the first while dimensions are fully
     selected; dim is the lowest dimension merged so far, acx and acy are the
     products of the merged sizes in x and y *)
  let rec merge i dim acx acy =
    if i >= 0 && is_full i
    then merge (i - 1) i (acx * sx.(i)) (acy * sy.(i))
    else dim, acx, acy
  in
  let dim, acx, acy = merge (n - 1) (n - 1) 1 1 in
  if n - dim > 1
  then (
    (* can be optimised *)
    let axis' = Array.sub axis 0 (dim + 1) in
    let sx' = Array.sub sx 0 (dim + 1) in
    let sy' = Array.sub sy 0 (dim + 1) in
    (* the merged dimension is still selected in full, so its range has to
       cover the merged size rather than the size of the original dimension *)
    axis'.(dim) <- R_ [| 0; acx - 1; 1 |];
    sx'.(dim) <- acx;
    sy'.(dim) <- acy;
    let x' = reshape x sx' in
    let y' = reshape y sy' in
    axis', x', y')
  else (* cannot be optimised *)
    axis, x, y

(* ends here *)