(* Extended Bigarray module with additional element types.

   Re-exports [Stdlib.Bigarray], then shadows the [kind] type, its value
   constructors and [kind_size_in_bytes] so that the extra element types
   declared below can be used wherever a standard kind is expected.
   Arrays of the extra kinds are allocated by the [caml_nx_ba_create_...]
   C primitives declared at the end of this section. *)

(* Re-export all standard bigarray types first *)
include Stdlib.Bigarray

(* Additional element types, following the Bigarray naming convention.
   Each one is only used as the second type parameter of [kind]. *)
type bfloat16_elt = Bfloat16_elt
type bool_elt = Bool_elt
type int4_signed_elt = Int4_signed_elt
type int4_unsigned_elt = Int4_unsigned_elt
type float8_e4m3_elt = Float8_e4m3_elt (* 4 exponent bits, 3 mantissa bits *)
type float8_e5m2_elt = Float8_e5m2_elt (* 5 exponent bits, 2 mantissa bits *)
type complex16_elt = Complex16_elt (* Half-precision complex *)
type qint8_elt = Qint8_elt (* Quantized int8 *)
type quint8_elt = Quint8_elt (* Quantized uint8 *)

(* Shadow the kind type to include the new element types.
   The first fourteen constructors are the standard kinds; the extended
   ones follow. The C primitive [caml_nx_ba_kind] (declared further down)
   returns values of this type, and OCaml encodes constant constructors by
   declaration position, so do not reorder or insert constructors without
   updating the C side. *)
type ('a, 'b) kind =
  | Float32 : (float, float32_elt) kind
  | Float64 : (float, float64_elt) kind
  | Int8_signed : (int, int8_signed_elt) kind
  | Int8_unsigned : (int, int8_unsigned_elt) kind
  | Int16_signed : (int, int16_signed_elt) kind
  | Int16_unsigned : (int, int16_unsigned_elt) kind
  | Int32 : (int32, int32_elt) kind
  | Int64 : (int64, int64_elt) kind
  | Int : (int, int_elt) kind
  | Nativeint : (nativeint, nativeint_elt) kind
  | Complex32 : (Complex.t, complex32_elt) kind
  | Complex64 : (Complex.t, complex64_elt) kind
  | Char : (char, int8_unsigned_elt) kind
  | Float16 : (float, float16_elt) kind
  | Bfloat16 : (float, bfloat16_elt) kind
  | Bool : (bool, bool_elt) kind
  | Int4_signed : (int, int4_signed_elt) kind
  | Int4_unsigned : (int, int4_unsigned_elt) kind
  | Float8_e4m3 : (float, float8_e4m3_elt) kind
  | Float8_e5m2 : (float, float8_e5m2_elt) kind
  | Complex16 : (Complex.t, complex16_elt) kind
  | Qint8 : (int, qint8_elt) kind
  | Quint8 : (int, quint8_elt) kind

(* Shadow the value constructors so that, e.g., [float32] denotes the
   extended [Float32] rather than the standard one. *)
let float32 = Float32
let float64 = Float64
let int8_signed = Int8_signed
let int8_unsigned = Int8_unsigned
let int16_signed = Int16_signed
let int16_unsigned = Int16_unsigned
let int32 = Int32
let int64 = Int64
let int = Int
let nativeint = Nativeint
let complex32 = Complex32
let complex64 = Complex64
let char = Char
let float16 = Float16
let bfloat16 = Bfloat16
let bool = Bool
let int4_signed = Int4_signed
let int4_unsigned = Int4_unsigned
let float8_e4m3 = Float8_e4m3
let float8_e5m2 = Float8_e5m2
let complex16 = Complex16
let qint8 = Qint8
let quint8 = Quint8

(* Shadow kind_size_in_bytes to handle the new types: number of bytes used
   by one element of the given kind. [Int] and [Nativeint] depend on the
   platform word size. *)
let kind_size_in_bytes : type a b. (a, b) kind -> int = function
  | Float16 -> 2
  | Float32 -> 4
  | Float64 -> 8
  | Int8_signed -> 1
  | Int8_unsigned -> 1
  | Int16_signed -> 2
  | Int16_unsigned -> 2
  | Int32 -> 4
  | Int64 -> 8
  | Int -> Sys.word_size / 8
  | Nativeint -> Sys.word_size / 8
  | Complex32 -> 8
  | Complex64 -> 16
  | Char -> 1
  | Bfloat16 -> 2
  | Bool -> 1
  (* NOTE(review): int4 values are described as packed two per byte, yet
     one full byte per element is reported here; confirm against the C
     allocator before relying on this for buffer sizing. *)
  | Int4_signed -> 1
  | Int4_unsigned -> 1
  | Float8_e4m3 -> 1
  | Float8_e5m2 -> 1
  | Complex16 -> 4 (* 2 x float16 *)
  | Qint8 -> 1
  | Quint8 -> 1

(* Convert an extended kind to the corresponding standard kind, for falling
   back on [Stdlib.Bigarray]. Returns [None] for the extended kinds, which
   have no standard counterpart. *)
let to_stdlib_kind : type a b. (a, b) kind -> (a, b) Stdlib.Bigarray.kind option
    = function
  | Float32 -> Some Stdlib.Bigarray.Float32
  | Float64 -> Some Stdlib.Bigarray.Float64
  | Int8_signed -> Some Stdlib.Bigarray.Int8_signed
  | Int8_unsigned -> Some Stdlib.Bigarray.Int8_unsigned
  | Int16_signed -> Some Stdlib.Bigarray.Int16_signed
  | Int16_unsigned -> Some Stdlib.Bigarray.Int16_unsigned
  | Int32 -> Some Stdlib.Bigarray.Int32
  | Int64 -> Some Stdlib.Bigarray.Int64
  | Int -> Some Stdlib.Bigarray.Int
  | Nativeint -> Some Stdlib.Bigarray.Nativeint
  | Complex32 -> Some Stdlib.Bigarray.Complex32
  | Complex64 -> Some Stdlib.Bigarray.Complex64
  | Char -> Some Stdlib.Bigarray.Char
  | Float16 -> Some Stdlib.Bigarray.Float16
  | Bfloat16 -> None
  | Bool -> None
  | Int4_signed -> None
  | Int4_unsigned -> None
  | Float8_e4m3 -> None
  | Float8_e5m2 -> None
  | Complex16 -> None
  | Qint8 -> None
  | Quint8 -> None

(* C allocators for arrays of the extended kinds. Each takes a layout and
   the dimensions. Their result is unconstrained in the element types, so
   the caller is responsible for pairing each allocator with the matching
   extended kind. *)
external create_bfloat16_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t
  = "caml_nx_ba_create_bfloat16"

external create_bool_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t = "caml_nx_ba_create_bool"

external create_int4_signed_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t
  = "caml_nx_ba_create_int4_signed"

external create_int4_unsigned_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t
  = "caml_nx_ba_create_int4_unsigned"

external create_float8_e4m3_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t
  = "caml_nx_ba_create_float8_e4m3"

external create_float8_e5m2_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t
  = "caml_nx_ba_create_float8_e5m2"

external create_complex16_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t
  = "caml_nx_ba_create_complex16"

external create_qint8_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t = "caml_nx_ba_create_qint8"

external create_quint8_genarray :
  'c layout -> int array -> ('a, 'b, 'c) Genarray.t
  = "caml_nx_ba_create_quint8"

(* External functions for get/set operations on extended types
*)
(* Generic element read/write through the C stubs, which understand both
   the standard and the extended kinds. The index array has one entry per
   dimension. *)
external nx_ba_get_generic : ('a, 'b, 'c) Genarray.t -> int array -> 'a
  = "caml_nx_ba_get_generic"

external nx_ba_set_generic : ('a, 'b, 'c) Genarray.t -> int array -> 'a -> unit
  = "caml_nx_ba_set_generic"

(* Read the extended kind of an array - needs C stub implementation *)
external nx_ba_kind : ('a, 'b, 'c) Genarray.t -> ('a, 'b) kind
  = "caml_nx_ba_kind"

(* Shadow the Genarray module *)
module Genarray = struct
  include Stdlib.Bigarray.Genarray

  (* Allocate an array of any kind: extended kinds use their dedicated C
     allocator, standard kinds fall back to [Stdlib.Bigarray.Genarray]. *)
  let create : type a b c. (a, b) kind -> c layout -> int array -> (a, b, c) t
      =
   fun kind layout dims ->
    match kind with
    | Bfloat16 -> create_bfloat16_genarray layout dims
    | Bool -> create_bool_genarray layout dims
    | Int4_signed -> create_int4_signed_genarray layout dims
    | Int4_unsigned -> create_int4_unsigned_genarray layout dims
    | Float8_e4m3 -> create_float8_e4m3_genarray layout dims
    | Float8_e5m2 -> create_float8_e5m2_genarray layout dims
    | Complex16 -> create_complex16_genarray layout dims
    | Qint8 -> create_qint8_genarray layout dims
    | Quint8 -> create_quint8_genarray layout dims
    | _ -> (
        match to_stdlib_kind kind with
        | Some k -> Stdlib.Bigarray.Genarray.create k layout dims
        | None -> failwith "Internal error: unhandled kind")

  (* Override kind to return the extended [kind] type. *)
  let kind : type a b c. (a, b, c) t -> (a, b) kind = nx_ba_kind

  (* Shadow get/set so that extended element types are handled. *)
  let get arr idx = nx_ba_get_generic arr idx
  let set arr idx value = nx_ba_set_generic arr idx value

  (* [init kind layout dims f] allocates an array and sets every element to
     [f idx]. C layout walks indices from 0 with the last dimension varying
     fastest; Fortran layout walks from 1 with the first dimension varying
     fastest. The same mutable index array is passed to every call of [f]. *)
  let init (type t) kind (layout : t layout) dims f =
    let arr = create kind layout dims in
    let dlen = Array.length dims in
    match layout with
    | C_layout ->
        let rec cloop arr idx f col max =
          if col = Array.length idx then set arr idx (f idx)
          else
            for j = 0 to pred max.(col) do
              idx.(col) <- j;
              cloop arr idx f (succ col) max
            done
        in
        cloop arr (Array.make dlen 0) f 0 dims;
        arr
    | Fortran_layout ->
        let rec floop arr idx f col max =
          if col < 0 then set arr idx (f idx)
          else
            for j = 1 to max.(col) do
              idx.(col) <- j;
              floop arr idx f (pred col) max
            done
        in
        floop arr (Array.make dlen 1) f (pred dlen) dims;
        arr

  (* Total size in bytes, computed from the extended kind returned by [kind]
     above. The inherited [Stdlib.Bigarray.Genarray.size_in_bytes] consults
     the standard kind primitive, which knows nothing about the extended
     element types. *)
  let size_in_bytes arr =
    kind_size_in_bytes (kind arr) * Array.fold_left ( * ) 1 (dims arr)

  (* Override blit to handle extended types *)
  external nx_ba_blit_genarray : ('a, 'b, 'c) t -> ('a, 'b, 'c) t -> unit
    = "caml_nx_ba_blit"

  let blit = nx_ba_blit_genarray

  (* Override fill for extended types *)
  external nx_ba_fill : ('a, 'b, 'c) t -> 'a -> unit = "caml_nx_ba_fill"

  let fill = nx_ba_fill
end

(* Shadow Array0 module: every operation that depends on the element kind
   goes through the extended-kind aware helpers, as in Array1-Array3. *)
module Array0 = struct
  include Stdlib.Bigarray.Array0

  let create : type a b c. (a, b) kind -> c layout -> (a, b, c) t =
   fun kind layout -> array0_of_genarray (Genarray.create kind layout [||])

  (* Override kind to return the extended [kind] type; the inherited one
     consults the standard kind primitive. *)
  let kind : type a b c. (a, b, c) t -> (a, b) kind =
   fun arr -> Genarray.kind (genarray_of_array0 arr)

  (* Size of the single element, from the extended kind. *)
  let size_in_bytes arr = kind_size_in_bytes (kind arr)

  (* Override get and set to use our extended functions *)
  let get arr = Genarray.get (genarray_of_array0 arr) [||]
  let set arr v = Genarray.set (genarray_of_array0 arr) [||] v

  (* Override blit to handle extended types *)
  external nx_ba_blit : ('a, 'b, 'c) t -> ('a, 'b, 'c) t -> unit
    = "caml_nx_ba_blit"

  let blit = nx_ba_blit

  (* Override fill for extended types *)
  external nx_ba_fill : ('a, 'b, 'c) t -> 'a -> unit = "caml_nx_ba_fill"

  let fill = nx_ba_fill

  (* [init kind layout f] allocates a 0-d array holding [f ()].
     NOTE(review): unlike [Stdlib.Bigarray.Array0.init]/[of_value], which take
     the value itself, both functions here take a thunk; kept as-is for
     compatibility with existing callers. *)
  let init kind layout f =
    let a = create kind layout in
    Genarray.set (genarray_of_array0 a) [||] (f ());
    a

  let of_value = init
end

(* Shadow Array1 module *)
module Array1 = struct
  include Stdlib.Bigarray.Array1

  let create : type a b c. (a, b) kind -> c layout -> int -> (a, b, c) t =
   fun kind layout dim ->
    array1_of_genarray (Genarray.create kind layout [| dim |])

  (* Override kind to return the extended [kind] type *)
  let kind : type a b c. (a, b, c) t -> (a, b) kind =
   fun arr -> Genarray.kind (genarray_of_array1 arr)

  (* Total size in bytes from the extended kind; the inherited version
     consults the standard kind primitive. *)
  let size_in_bytes arr = kind_size_in_bytes (kind arr) * dim arr

  (* Override get and set to use our extended functions *)
  let get arr i = Genarray.get (genarray_of_array1 arr) [| i |]
  let set arr i v = Genarray.set (genarray_of_array1 arr) [| i |] v

  (* The unsafe variants simply delegate to [get]/[set] for extended types *)
  let unsafe_get arr i = get arr i
  let unsafe_set arr i v = set arr i v

  (* Override blit to handle extended types *)
  external nx_ba_blit : ('a, 'b, 'c) t -> ('a, 'b, 'c) t -> unit
    = "caml_nx_ba_blit"

  let blit = nx_ba_blit

  (* Override fill for extended types *)
  external nx_ba_fill : ('a, 'b, 'c) t -> 'a -> unit = "caml_nx_ba_fill"

  let fill = nx_ba_fill

  (* [init kind layout dim f] sets element [i] to [f i], with [i] ranging
     over 0..dim-1 (C layout) or 1..dim (Fortran layout). *)
  let init (type t) kind (layout : t layout) dim f =
    let arr = create kind layout dim in
    match layout with
    | C_layout ->
        for i = 0 to pred dim do
          unsafe_set arr i (f i)
        done;
        arr
    | Fortran_layout ->
        for i = 1 to dim do
          unsafe_set arr i (f i)
        done;
        arr

  (* Copy an OCaml array into a fresh 1-d bigarray; element [i] of [data]
     lands at index [i] (C layout) or [i + 1] (Fortran layout). *)
  let of_array (type t) kind (layout : t layout) data =
    let ba = create kind layout (Array.length data) in
    let ofs = match layout with C_layout -> 0 | Fortran_layout -> 1 in
    for i = 0 to Array.length data - 1 do
      unsafe_set ba (i + ofs) data.(i)
    done;
    ba
end

(*
Shadow Array2 module: creation, element access, blit, fill and size all
   go through the extended-kind aware helpers. *)
module Array2 = struct
  include Stdlib.Bigarray.Array2

  let create : type a b c. (a, b) kind -> c layout -> int -> int -> (a, b, c) t
      =
   fun kind layout dim1 dim2 ->
    array2_of_genarray (Genarray.create kind layout [| dim1; dim2 |])

  (* Override kind to return the extended [kind] type *)
  let kind : type a b c. (a, b, c) t -> (a, b) kind =
   fun arr -> Genarray.kind (genarray_of_array2 arr)

  (* Total size in bytes from the extended kind; the inherited version
     consults the standard kind primitive. *)
  let size_in_bytes arr = kind_size_in_bytes (kind arr) * dim1 arr * dim2 arr

  (* Override get and set to use our extended functions *)
  let get arr i j = Genarray.get (genarray_of_array2 arr) [| i; j |]
  let set arr i j v = Genarray.set (genarray_of_array2 arr) [| i; j |] v

  (* The unsafe variants simply delegate to [get]/[set] for extended types *)
  let unsafe_get arr i j = get arr i j
  let unsafe_set arr i j v = set arr i j v

  (* Override blit to handle extended types *)
  external nx_ba_blit : ('a, 'b, 'c) t -> ('a, 'b, 'c) t -> unit
    = "caml_nx_ba_blit"

  let blit = nx_ba_blit

  (* Override fill for extended types *)
  external nx_ba_fill : ('a, 'b, 'c) t -> 'a -> unit = "caml_nx_ba_fill"

  let fill = nx_ba_fill

  (* [init kind layout dim1 dim2 f] sets element (i, j) to [f i j]. C layout
     iterates rows outermost from index 0; Fortran layout iterates columns
     outermost from index 1. *)
  let init (type t) kind (layout : t layout) dim1 dim2 f =
    let arr = create kind layout dim1 dim2 in
    match layout with
    | C_layout ->
        for i = 0 to pred dim1 do
          for j = 0 to pred dim2 do
            unsafe_set arr i j (f i j)
          done
        done;
        arr
    | Fortran_layout ->
        for j = 1 to dim2 do
          for i = 1 to dim1 do
            unsafe_set arr i j (f i j)
          done
        done;
        arr

  (* Copy a rectangular array of arrays into a fresh 2-d bigarray.
     Raises [Invalid_argument] if the rows do not all have the length of
     the first row. *)
  let of_array (type t) kind (layout : t layout) data =
    let dim1 = Array.length data in
    let dim2 = if dim1 = 0 then 0 else Array.length data.(0) in
    let ba = create kind layout dim1 dim2 in
    let ofs = match layout with C_layout -> 0 | Fortran_layout -> 1 in
    for i = 0 to dim1 - 1 do
      let row = data.(i) in
      if Array.length row <> dim2 then
        invalid_arg "Bigarray_ext.Array2.of_array: non-rectangular data";
      for j = 0 to dim2 - 1 do
        unsafe_set ba (i + ofs) (j + ofs) row.(j)
      done
    done;
    ba
end

(* Shadow Array3 module *)
module Array3 = struct
  include Stdlib.Bigarray.Array3

  let create :
      type a b c. (a, b) kind -> c layout -> int -> int -> int -> (a, b, c) t =
   fun kind layout dim1 dim2 dim3 ->
    array3_of_genarray (Genarray.create kind layout [| dim1; dim2; dim3 |])

  (* Override kind to return the extended [kind] type *)
  let kind : type a b c. (a, b, c) t -> (a, b) kind =
   fun arr -> Genarray.kind (genarray_of_array3 arr)

  (* Total size in bytes from the extended kind; the inherited version
     consults the standard kind primitive. *)
  let size_in_bytes arr =
    kind_size_in_bytes (kind arr) * dim1 arr * dim2 arr * dim3 arr

  (* Override get and set to use our extended functions *)
  let get arr i j k = Genarray.get (genarray_of_array3 arr) [| i; j; k |]
  let set arr i j k v = Genarray.set (genarray_of_array3 arr) [| i; j; k |] v

  (* The unsafe variants simply delegate to [get]/[set] for extended types *)
  let unsafe_get arr i j k = get arr i j k
  let unsafe_set arr i j k v = set arr i j k v

  (* Override blit to handle extended types *)
  external nx_ba_blit : ('a, 'b, 'c) t -> ('a, 'b, 'c) t -> unit
    = "caml_nx_ba_blit"

  let blit = nx_ba_blit

  (* Override fill for extended types *)
  external nx_ba_fill : ('a, 'b, 'c) t -> 'a -> unit = "caml_nx_ba_fill"

  let fill = nx_ba_fill

  (* [init kind layout dim1 dim2 dim3 f] sets element (i, j, k) to
     [f i j k]. C layout iterates the first index outermost from 0; Fortran
     layout iterates the last index outermost from 1. *)
  let init (type t) kind (layout : t layout) dim1 dim2 dim3 f =
    let arr = create kind layout dim1 dim2 dim3 in
    match layout with
    | C_layout ->
        for i = 0 to pred dim1 do
          for j = 0 to pred dim2 do
            for k = 0 to pred dim3 do
              unsafe_set arr i j k (f i j k)
            done
          done
        done;
        arr
    | Fortran_layout ->
        for k = 1 to dim3 do
          for j = 1 to dim2 do
            for i = 1 to dim1 do
              unsafe_set arr i j k (f i j k)
            done
          done
        done;
        arr

  (* Copy a cuboid array of arrays of arrays into a fresh 3-d bigarray.
     The dimensions are taken from [data.(0)] and [data.(0).(0)]; raises
     [Invalid_argument] if any row or column has a different length. *)
  let of_array (type t) kind (layout : t layout) data =
    let dim1 = Array.length data in
    let dim2 = if dim1 = 0 then 0 else Array.length data.(0) in
    let dim3 = if dim2 = 0 then 0 else Array.length data.(0).(0) in
    let ba = create kind layout dim1 dim2 dim3 in
    let ofs = match layout with C_layout -> 0 | Fortran_layout -> 1 in
    for i = 0 to dim1 - 1 do
      let row = data.(i) in
      if Array.length row <> dim2 then
        invalid_arg "Bigarray_ext.Array3.of_array: non-cubic data";
      for j = 0 to dim2 - 1 do
        let col = row.(j) in
        if Array.length col <> dim3 then
          invalid_arg "Bigarray_ext.Array3.of_array: non-cubic data";
        for k = 0 to dim3 - 1 do
          unsafe_set ba (i + ofs) (j + ofs) (k + ofs) col.(k)
        done
      done
    done;
    ba
end