(** Reading and writing of NumPy [.npy] / [.npz] files backed by
    [Bigarray_ext] arrays. *)

(** Raised when fewer bytes than requested could be written to a file. *)
exception Cannot_write

(** Raised when a file cannot be parsed as a [.npy] file; the payload is a
    human-readable reason. *)
exception Read_error of string

(** [read_error fmt ...] formats a message and raises {!Read_error} with it. *)
let read_error fmt = Printf.ksprintf (fun s -> raise (Read_error s)) fmt

(** Magic prefix found at the start of every [.npy] file. *)
let magic_string = "\147NUMPY"

let magic_string_len = String.length magic_string

(** A bigarray kind with its type parameters hidden. *)
type packed_kind = P : (_, _) Bigarray_ext.kind -> packed_kind

(** [dtype ~packed_kind] is the NumPy [descr] string (endianness marker
    followed by a type code) for [packed_kind].

    @raise Failure for [Int] and [Nativeint]. *)
let dtype ~packed_kind =
  let endianness =
    match packed_kind with
    | P Bigarray_ext.Char -> "|"
    | P _ -> if Sys.big_endian then ">" else "<"
  in
  let kind =
    match packed_kind with
    | P Bigarray_ext.Int32 -> "i4"
    | P Bigarray_ext.Int64 -> "i8"
    | P Bigarray_ext.Float16 -> "f2"
    | P Bigarray_ext.Float32 -> "f4"
    | P Bigarray_ext.Float64 -> "f8"
    | P Bigarray_ext.Int8_unsigned -> "u1"
    | P Bigarray_ext.Int8_signed -> "i1"
    | P Bigarray_ext.Int16_unsigned -> "u2"
    | P Bigarray_ext.Int16_signed -> "i2"
    | P Bigarray_ext.Char -> "S1"
    | P Bigarray_ext.Complex32 -> "c8" (* 2 32bits float. *)
    | P Bigarray_ext.Complex64 -> "c16" (* 2 64bits float. *)
    | P Bigarray_ext.Int -> failwith "Int is not supported"
    | P Bigarray_ext.Nativeint -> failwith "Nativeint is not supported."
  in
  endianness ^ kind

(* For extended types, we can't use Unix.map_file, but we can still do file I/O
   by creating the array in memory and using C stubs. *)

(* External functions for file I/O with bigarrays. *)
external write_bigarray_to_fd :
  Unix.file_descr -> ('a, 'b, 'c) Bigarray_ext.Genarray.t -> unit
  = "caml_npy_write_bigarray"

external read_fd_to_bigarray :
  Unix.file_descr -> ('a, 'b, 'c) Bigarray_ext.Genarray.t -> unit
  = "caml_npy_read_bigarray"

(** [extended_file_blit src dst file_descr pos] seeks [file_descr] to the
    absolute offset [pos] and writes [src] there through the C stub.

    This is called when an extended bigarray has to be written to a file: the
    [dst] argument is the dummy in-memory array produced by {!map_file} and is
    ignored, the data goes to [file_descr]. *)
let extended_file_blit src _dst file_descr pos =
  ignore (Unix.lseek file_descr pos Unix.SEEK_SET);
  write_bigarray_to_fd file_descr src

(** [map_file file_descr ~pos kind layout shared shape] returns an array of the
    given [kind], [layout] and [shape] for the data located at byte offset
    [pos] of [file_descr].

    - When [Bigarray_ext.to_stdlib_kind kind] is [Some _], the file is memory
      mapped with [Unix.map_file].
    - Otherwise a fresh in-memory array is created; when [shared] is [false]
      the file content is read into it, when [shared] is [true] it is left
      as created and the caller is expected to write through
      {!extended_file_blit}.

    An empty [shape] (scalar) is handled as a one-element array that is
    reshaped to [[||]] before being returned. *)
let map_file file_descr ~pos kind layout shared shape =
  let is_scalar = Array.length shape = 0 in
  let actual_shape = if is_scalar then [| 1 |] else shape in
  (* Create the array first. *)
  let array =
    match Bigarray_ext.to_stdlib_kind kind with
    | Some std_kind ->
        (* Standard bigarray type - use Unix.map_file for efficiency. *)
        Unix.map_file file_descr ~pos std_kind layout shared actual_shape
    | None ->
        (* Extended type - create in memory and read if needed. *)
        let arr = Bigarray_ext.Genarray.create kind layout actual_shape in
        if not shared then begin
          (* Reading mode - read the file data into the array. *)
          ignore (Unix.lseek file_descr (Int64.to_int pos) Unix.SEEK_SET);
          read_fd_to_bigarray file_descr arr
        end;
        (* For writing mode (shared = true), the write happens when the blit
           is requested. *)
        arr
  in
  if is_scalar then Bigarray_ext.reshape array [||] else array

(** [fortran_order ~layout] is the Python boolean literal stored in the
    [fortran_order] header field for [layout]. *)
let fortran_order (type a) ~(layout : a Bigarray_ext.layout) =
  match layout with
  | Bigarray_ext.C_layout -> "False"
  | Bigarray_ext.Fortran_layout -> "True"

(** [shape ~dims] renders [dims] as the inside of a Python tuple; a single
    dimension keeps its trailing comma (["3,"]). *)
let shape ~dims =
  match dims with
  | [| dim1 |] -> Printf.sprintf "%d," dim1
  | dims -> Array.to_list dims |> List.map string_of_int |> String.concat ", "

(** [full_header ?header_len ~layout ~packed_kind ~dims ()] builds the complete
    file prefix: magic string, version bytes [1.0], two-byte little-endian
    length, the dictionary text, space padding and a final newline.

    Without [header_len] the result is padded to a multiple of 16 bytes; with
    [header_len] it is padded to exactly that many bytes.

    @raise Failure if [header_len] is not a multiple of 16 or is too small. *)
let full_header ?header_len ~layout ~packed_kind ~dims () =
  let header =
    Printf.sprintf "{'descr': '%s', 'fortran_order': %s, 'shape': (%s), }"
      (dtype ~packed_kind) (fortran_order ~layout) (shape ~dims)
  in
  let padding_len =
    (* Magic string + 2 version bytes + 2 length bytes + trailing newline. *)
    let total_len = String.length header + magic_string_len + 4 + 1 in
    match header_len with
    | None -> if total_len mod 16 = 0 then 0 else 16 - (total_len mod 16)
    | Some header_len ->
        if header_len mod 16 <> 0 then
          failwith "header_len has to be divisible by 16";
        if header_len < total_len then
          failwith "header_len is smaller than total_len";
        header_len - total_len
  in
  let total_header_len = String.length header + padding_len + 1 in
  Printf.sprintf "%s\001\000%c%c%s%s\n" magic_string
    (total_header_len mod 256 |> Char.chr)
    (total_header_len / 256 |> Char.chr)
    header
    (String.make padding_len ' ')

(** [with_file filename flags mask ~f] opens [filename], applies [f] to the
    descriptor and closes it exactly once, whether [f] returns or raises. *)
let with_file filename flags mask ~f =
  let file_descr = Unix.openfile filename flags mask in
  match f file_descr with
  | result ->
      Unix.close file_descr;
      result
  | exception exn ->
      Unix.close file_descr;
      raise exn

(** [write ?header_len bigarray filename] creates (or truncates) [filename] and
    stores [bigarray] in it in [.npy] format.

    @raise Cannot_write if the header could not be written in full. *)
let write ?header_len bigarray filename =
  with_file filename [ Unix.O_CREAT; O_TRUNC; O_RDWR ] 0o640
    ~f:(fun file_descr ->
      let full_header =
        full_header () ?header_len
          ~layout:(Bigarray_ext.Genarray.layout bigarray)
          ~packed_kind:(P (Bigarray_ext.Genarray.kind bigarray))
          ~dims:(Bigarray_ext.Genarray.dims bigarray)
      in
      let full_header_len = String.length full_header in
      if
        Unix.write_substring file_descr full_header 0 full_header_len
        <> full_header_len
      then raise Cannot_write;
      let kind = Bigarray_ext.Genarray.kind bigarray in
      let file_array =
        map_file
          ~pos:(Int64.of_int full_header_len)
          file_descr kind
          (Bigarray_ext.Genarray.layout bigarray)
          true
          (Bigarray_ext.Genarray.dims bigarray)
      in
      match Bigarray_ext.to_stdlib_kind kind with
      | Some _ ->
          (* Standard type - normal blit works with memory mapping. *)
          Bigarray_ext.Genarray.blit bigarray file_array
      | None ->
          (* Extended type - [file_array] is a dummy array, the data needs a
             custom file write. *)
          extended_file_blit bigarray file_array file_descr full_header_len)

(** Same as {!write} for one-dimensional arrays. *)
let write1 array1 filename =
  write (Bigarray_ext.genarray_of_array1 array1) filename

(** Same as {!write} for two-dimensional arrays. *)
let write2 array2 filename =
  write (Bigarray_ext.genarray_of_array2 array2) filename

(** Same as {!write} for three-dimensional arrays. *)
let write3 array3 filename =
  write (Bigarray_ext.genarray_of_array3 array3) filename

(** Incremental writer: arrays are appended one after the other along their
    first dimension and the header is written when the writer is closed. *)
module Batch_writer = struct
  (** Number of bytes reserved at the start of the file for the header. *)
  let header_len = 128

  type t = {
    file_descr : Unix.file_descr;
    mutable bytes_written_so_far : int;
    mutable dims_and_packed_kind : (int array * packed_kind) option;
  }

  (** [append t bigarray] writes [bigarray] after the data already written and
      adds its first dimension to the running total.

      @raise Failure if the dimensions other than the first differ from the
      ones of the first appended array (the check runs after the data has been
      written). *)
  let append t bigarray =
    let kind = Bigarray_ext.Genarray.kind bigarray in
    let size_in_bytes = Bigarray_ext.Genarray.size_in_bytes bigarray in
    let file_array =
      map_file
        ~pos:(Int64.of_int t.bytes_written_so_far)
        t.file_descr kind
        (Bigarray_ext.Genarray.layout bigarray)
        true
        (Bigarray_ext.Genarray.dims bigarray)
    in
    (match Bigarray_ext.to_stdlib_kind kind with
    | Some _ ->
        (* Standard type - normal blit works. *)
        Bigarray_ext.Genarray.blit bigarray file_array
    | None ->
        (* Extended type - need custom file write. *)
        extended_file_blit bigarray file_array t.file_descr
          t.bytes_written_so_far);
    t.bytes_written_so_far <- t.bytes_written_so_far + size_in_bytes;
    match t.dims_and_packed_kind with
    | None ->
        let dims = Bigarray_ext.Genarray.dims bigarray in
        let kind = Bigarray_ext.Genarray.kind bigarray in
        t.dims_and_packed_kind <- Some (dims, P kind)
    | Some (dims, _kind) ->
        let dims' = Bigarray_ext.Genarray.dims bigarray in
        let incorrect_dimensions =
          match (Array.to_list dims, Array.to_list dims') with
          | [], _ | _, [] -> true
          | _ :: d, _ :: d' -> d <> d'
        in
        if incorrect_dimensions then
          Printf.sprintf "Incorrect dimensions %s vs %s." (shape ~dims)
            (shape ~dims:dims')
          |> failwith;
        dims.(0) <- dims.(0) + dims'.(0)

  (** [create filename] creates (or truncates) [filename] and returns a writer
      positioned just after the reserved header area. *)
  let create filename =
    let file_descr =
      Unix.openfile filename [ Unix.O_CREAT; O_TRUNC; O_RDWR ] 0o640
    in
    {
      file_descr;
      bytes_written_so_far = header_len;
      dims_and_packed_kind = None;
    }

  (** [close t] writes the header at the start of the file and closes the
      descriptor. The descriptor is closed even when writing the header fails.

      @raise Failure if nothing was appended.
      @raise Cannot_write if the header could not be written in full. *)
  let close t =
    let write_header () =
      (* The seek is performed outside of [assert] so that it still happens
         when assertions are compiled out. *)
      let offset = Unix.lseek t.file_descr 0 Unix.SEEK_SET in
      assert (offset = 0);
      let header =
        match t.dims_and_packed_kind with
        | None -> failwith "Nothing to write"
        | Some (dims, packed_kind) ->
            full_header ~header_len ~layout:C_layout ~dims ~packed_kind ()
      in
      if Unix.write_substring t.file_descr header 0 header_len <> header_len
      then raise Cannot_write
    in
    match write_header () with
    | () -> Unix.close t.file_descr
    | exception exn ->
        Unix.close t.file_descr;
        raise exn
end

(** [really_read fd len] reads exactly [len] bytes from [fd] and returns them
    as a string ([""] when [len] is [0]).

    @raise Read_error if the end of file is reached before [len] bytes were
    read. *)
let really_read fd len =
  let buffer = Bytes.create len in
  let rec loop offset =
    if offset < len then begin
      let read = Unix.read fd buffer offset (len - offset) in
      (* A zero-length read with bytes still missing means end of file;
         looping again would never terminate. *)
      if read = 0 then read_error "unexpected eof";
      loop (offset + read)
    end
  in
  loop 0;
  Bytes.to_string buffer

(** Parsing of the textual dictionary stored in a [.npy] header. *)
module Header = struct
  type packed_kind = P : (_, _) Bigarray_ext.kind -> packed_kind
  type t = { kind : packed_kind; fortran_order : bool; shape : int array }

  (** [split str ~on] splits [str] on every occurrence of the character [on]
      that is not enclosed in parentheses. *)
  let split str ~on =
    let parens = ref 0 in
    let indexes = ref [] in
    for i = 0 to String.length str - 1 do
      match str.[i] with
      | '(' -> incr parens
      | ')' -> decr parens
      | c when !parens = 0 && c = on -> indexes := i :: !indexes
      | _ -> ()
    done;
    (* [!indexes] holds the separator positions from last to first, so the
       pieces are accumulated from the end of the string backwards. *)
    let first_pos, acc =
      List.fold_left
        (fun (prev_p, acc) index ->
          (index, String.sub str (index + 1) (prev_p - index - 1) :: acc))
        (String.length str, [])
        !indexes
    in
    String.sub str 0 first_pos :: acc

  (** [trim str ~on] removes from both ends of [str] every character that
      belongs to the list [on]. *)
  let trim str ~on =
    let rec loopr start len =
      if len = 0 then (start, len)
      else if List.mem str.[start + len - 1] on then loopr start (len - 1)
      else (start, len)
    in
    let rec loopl start len =
      if len = 0 then (start, len)
      else if List.mem str.[start] on then loopl (start + 1) (len - 1)
      else loopr start len
    in
    let start, len = loopl 0 (String.length str) in
    String.sub str start len

  (** [parse header] extracts the [descr], [fortran_order] and [shape] fields
      from the header dictionary text.

      @raise Read_error if a field is missing or malformed, if the type code
      is not supported, or if the data endianness differs from the one of the
      running architecture. *)
  let parse header =
    let header_fields =
      trim header ~on:[ '{'; ' '; '}'; '\n' ]
      |> split ~on:','
      |> List.map String.trim
      |> List.filter (fun s -> String.length s > 0)
      |> List.map (fun header_field ->
             match split header_field ~on:':' with
             | [ name; value ] ->
                 ( trim name ~on:[ '\''; ' ' ],
                   trim value ~on:[ '\''; ' '; '('; ')' ] )
             | _ -> read_error "unable to parse field %s" header_field)
    in
    let find_field field =
      try List.assoc field header_fields
      with Not_found -> read_error "cannot find field %s" field
    in
    let kind =
      let kind = find_field "descr" in
      if String.length kind = 0 then read_error "incorrect descr %s" kind;
      (match kind.[0] with
      | '|' | '=' -> ()
      | '>' ->
          if not Sys.big_endian then
            read_error "big endian data but arch is little endian"
      | '<' ->
          if Sys.big_endian then
            read_error "little endian data but arch is big endian"
      | otherwise -> read_error "incorrect endianness %c" otherwise);
      match String.sub kind 1 (String.length kind - 1) with
      (* "f2" is what [dtype] emits for Float16; accept it so that such files
         can be read back. *)
      | "f2" -> P Float16
      | "f4" -> P Float32
      | "f8" -> P Float64
      | "i4" -> P Int32
      | "i8" -> P Int64
      | "u1" -> P Int8_unsigned
      | "i1" -> P Int8_signed
      | "u2" -> P Int16_unsigned
      | "i2" -> P Int16_signed
      | "S1" -> P Char
      | "c8" -> P Complex32
      | "c16" -> P Complex64
      | otherwise -> read_error "incorrect descr %s" otherwise
    in
    let fortran_order =
      match find_field "fortran_order" with
      | "False" -> false
      | "True" -> true
      | otherwise -> read_error "incorrect fortran_order %s" otherwise
    in
    let shape =
      find_field "shape"
      |> split ~on:','
      |> List.map String.trim
      |> List.filter (fun s -> String.length s > 0)
      |> List.map int_of_string
      |> Array.of_list
    in
    { kind; fortran_order; shape }
end

(** Arrays with their kind and layout type parameters hidden; use the
    [to_bigarray*] functions to recover a typed array. *)
type packed_array = P : (_, _, _) Bigarray_ext.Genarray.t -> packed_array

type packed_array1 = P1 : (_, _, _) Bigarray_ext.Array1.t -> packed_array1
type packed_array2 = P2 : (_, _, _) Bigarray_ext.Array2.t -> packed_array2
type packed_array3 = P3 : (_, _, _) Bigarray_ext.Array3.t -> packed_array3

(** [read_mmap filename ~shared] parses the header of [filename] and returns
    its data through {!map_file}. The file is opened read-write when [shared]
    is [true] and read-only otherwise. The descriptor is closed by a finaliser
    attached to the returned array, or immediately if reading fails.

    @raise Read_error if the file is not a supported [.npy] file. *)
let read_mmap filename ~shared =
  let access = if shared then Unix.O_RDWR else O_RDONLY in
  let file_descr = Unix.openfile filename [ access ] 0 in
  let pos, header =
    try
      let magic_string' = really_read file_descr magic_string_len in
      if magic_string <> magic_string' then read_error "magic string mismatch";
      (* Two version bytes follow the magic string; only the first one (the
         major version) is inspected. *)
      let version = Char.code (really_read file_descr 2).[0] in
      let header_len_len =
        match version with
        | 1 -> 2
        | 2 -> 4
        | _ -> read_error "unsupported version %d" version
      in
      (* The header length is stored least significant byte first. *)
      let header_len =
        let str = really_read file_descr header_len_len in
        let header_len = ref 0 in
        for i = String.length str - 1 downto 0 do
          header_len := (256 * !header_len) + Char.code str.[i]
        done;
        !header_len
      in
      let header = really_read file_descr header_len in
      let header = Header.parse header in
      ( Int64.of_int (header_len + header_len_len + magic_string_len + 2),
        header )
    with exn ->
      Unix.close file_descr;
      raise exn
  in
  let (Header.P kind) = header.kind in
  let build layout =
    let array =
      (* Do not leak the descriptor when the data cannot be obtained. *)
      try map_file file_descr ~pos kind layout shared header.shape
      with exn ->
        Unix.close file_descr;
        raise exn
    in
    Gc.finalise (fun _ -> Unix.close file_descr) array;
    P array
  in
  if header.fortran_order then build Fortran_layout else build C_layout

(** Same as {!read_mmap}, converting the result to a one-dimensional array. *)
let read_mmap1 filename ~shared =
  let (P array) = read_mmap filename ~shared in
  P1 (Bigarray_ext.array1_of_genarray array)

(** Same as {!read_mmap}, converting the result to a two-dimensional array. *)
let read_mmap2 filename ~shared =
  let (P array) = read_mmap filename ~shared in
  P2 (Bigarray_ext.array2_of_genarray array)

(** Same as {!read_mmap}, converting the result to a three-dimensional
    array. *)
let read_mmap3 filename ~shared =
  let (P array) = read_mmap filename ~shared in
  P3 (Bigarray_ext.array3_of_genarray array)

(** [read_copy filename] reads [filename] with {!read_mmap} (non shared) and
    returns a freshly allocated copy of its content. *)
let read_copy filename =
  let (P array) = read_mmap filename ~shared:false in
  let result =
    Bigarray_ext.Genarray.create
      (Bigarray_ext.Genarray.kind array)
      (Bigarray_ext.Genarray.layout array)
      (Bigarray_ext.Genarray.dims array)
  in
  Bigarray_ext.Genarray.blit array result;
  P result

(** Same as {!read_copy}, converting the result to a one-dimensional array. *)
let read_copy1 filename =
  let (P array) = read_copy filename in
  P1 (Bigarray_ext.array1_of_genarray array)

(** Same as {!read_copy}, converting the result to a two-dimensional array. *)
let read_copy2 filename =
  let (P array) = read_copy filename in
  P2 (Bigarray_ext.array2_of_genarray array)

(** Same as {!read_copy}, converting the result to a three-dimensional
    array. *)
let read_copy3 filename =
  let (P array) = read_copy filename in
  P3 (Bigarray_ext.array3_of_genarray array)

(** [.npz] archives: zip files whose entries are [.npy] files. *)
module Npz = struct
  let npy_suffix = ".npy"

  (** [maybe_add_suffix array_name ~suffix] appends [suffix] to [array_name],
      using {!npy_suffix} when [suffix] is [None]. *)
  let maybe_add_suffix array_name ~suffix =
    let suffix =
      match suffix with None -> npy_suffix | Some suffix -> suffix
    in
    array_name ^ suffix

  type in_file = Zip.in_file

  let open_in = Zip.open_in

  (** [entries t] lists the entry names of [t], with a trailing {!npy_suffix}
      removed when present. *)
  let entries t =
    let suffix_len = String.length npy_suffix in
    Zip.entries t
    |> List.map (fun entry ->
           let filename = entry.Zip.filename in
           if String.length filename < suffix_len then filename
           else
             let start_pos = String.length filename - suffix_len in
             if String.sub filename start_pos suffix_len = npy_suffix then
               String.sub filename 0 start_pos
             else filename)

  let close_in = Zip.close_in

  (** [read ?suffix t array_name] extracts the entry named [array_name]
      (with the suffix added) to a temporary file and reads it with
      {!read_copy}. The temporary file is removed whether reading succeeds
      or fails.

      @raise Invalid_argument if there is no such entry. *)
  let read ?suffix t array_name =
    let array_name = maybe_add_suffix array_name ~suffix in
    let entry =
      try Zip.find_entry t array_name
      with Not_found ->
        raise (Invalid_argument ("unable to find " ^ array_name))
    in
    let tmp_file = Filename.temp_file "ocaml-npz" ".tmp" in
    match
      Zip.copy_entry_to_file t entry tmp_file;
      read_copy tmp_file
    with
    | data ->
        Sys.remove tmp_file;
        data
    | exception exn ->
        (* Best-effort cleanup: keep the original error visible. *)
        (try Sys.remove tmp_file with Sys_error _ -> ());
        raise exn

  type out_file = Zip.out_file

  let open_out filename = Zip.open_out filename
  let close_out = Zip.close_out

  (** [write ?suffix t array_name array] stores [array] in the archive [t]
      under [array_name] (with the suffix added), going through a temporary
      file that is removed whether writing succeeds or fails. *)
  let write ?suffix t array_name array =
    let array_name = maybe_add_suffix array_name ~suffix in
    let tmp_file = Filename.temp_file "ocaml-npz" ".tmp" in
    match
      write array tmp_file;
      Zip.copy_file_to_entry tmp_file t array_name
    with
    | () -> Sys.remove tmp_file
    | exception exn ->
        (* Best-effort cleanup: keep the original error visible. *)
        (try Sys.remove tmp_file with Sys_error _ -> ());
        raise exn
end

(** Type equalities module, used in conversion function *)
module Eq = struct
  (** An equality type to extract type equalities *)
  type ('a, 'b) t = W : ('a, 'a) t

  open Bigarray_ext

  (** Type equalities for bigarray kinds *)
  module Kind = struct
    let ( === ) : type a b c d.
        (a, b) kind -> (c, d) kind -> ((a, b) kind, (c, d) kind) t option =
     fun x y ->
      match (x, y) with
      | Float16, Float16 -> Some W
      | Float32, Float32 -> Some W
      | Float64, Float64 -> Some W
      | Int8_signed, Int8_signed -> Some W
      | Int8_unsigned, Int8_unsigned -> Some W
      | Int16_signed, Int16_signed -> Some W
      | Int16_unsigned, Int16_unsigned -> Some W
      | Int32, Int32 -> Some W
      | Int64, Int64 -> Some W
      | Int, Int -> Some W
      | Nativeint, Nativeint -> Some W
      | Complex32, Complex32 -> Some W
      | Complex64, Complex64 -> Some W
      | Char, Char -> Some W
      | _ -> None
  end

  (** Type equalities for layout *)
  module Layout = struct
    let ( === ) : type a b.
        a layout -> b layout -> (a layout, b layout) t option =
     fun x y ->
      match (x, y) with
      | Fortran_layout, Fortran_layout -> Some W
      | C_layout, C_layout -> Some W
      | _, _ -> None
  end
end

(** Conversion functions from packed arrays to bigarrays *)

(** [to_bigarray layout kind packed] is [Some array] when the packed array has
    exactly this [layout] and [kind], [None] otherwise. *)
let to_bigarray (type a b c) (layout : c Bigarray_ext.layout)
    (kind : (a, b) Bigarray_ext.kind) (P x) =
  match Eq.Layout.(Bigarray_ext.Genarray.layout x === layout) with
  | None -> None
  | Some Eq.W -> (
      match Eq.Kind.(Bigarray_ext.Genarray.kind x === kind) with
      | None -> None
      | Some Eq.W -> Some (x : (a, b, c) Bigarray_ext.Genarray.t))

(** Same as {!to_bigarray} for one-dimensional arrays. *)
let to_bigarray1 (type a b c) (layout : c Bigarray_ext.layout)
    (kind : (a, b) Bigarray_ext.kind) (P1 x) =
  match Eq.Layout.(Bigarray_ext.Array1.layout x === layout) with
  | None -> None
  | Some Eq.W -> (
      match Eq.Kind.(Bigarray_ext.Array1.kind x === kind) with
      | None -> None
      | Some Eq.W -> Some (x : (a, b, c) Bigarray_ext.Array1.t))

(** Same as {!to_bigarray} for two-dimensional arrays. *)
let to_bigarray2 (type a b c) (layout : c Bigarray_ext.layout)
    (kind : (a, b) Bigarray_ext.kind) (P2 x) =
  match Eq.Layout.(Bigarray_ext.Array2.layout x === layout) with
  | None -> None
  | Some Eq.W -> (
      match Eq.Kind.(Bigarray_ext.Array2.kind x === kind) with
      | None -> None
      | Some Eq.W -> Some (x : (a, b, c) Bigarray_ext.Array2.t))

(** Same as {!to_bigarray} for three-dimensional arrays. *)
let to_bigarray3 (type a b c) (layout : c Bigarray_ext.layout)
    (kind : (a, b) Bigarray_ext.kind) (P3 x) =
  match Eq.Layout.(Bigarray_ext.Array3.layout x === layout) with
  | None -> None
  | Some Eq.W -> (
      match Eq.Kind.(Bigarray_ext.Array3.kind x === kind) with
      | None -> None
      | Some Eq.W -> Some (x : (a, b, c) Bigarray_ext.Array3.t))