open Base
(** Operation types shared by all backends; and precision types. *)

module Lazy = Utils.Lazy

(** {2 *** Precision ***} *)

type uint8_elt = Bigarray.int8_unsigned_elt

(* FIXME: Upcoming in OCaml 5.2.0. See:
   https://github.com/ocaml/ocaml/pull/10775/commits/ba6a2c378056c8669fb1bb99bf07b12d69bd4a12 *)
type float16_elt = Bigarray.float32_elt
type float32_elt = Bigarray.float32_elt
type float64_elt = Bigarray.float64_elt

(** The Bigarray kind used for half precision. It is currently [Bigarray.float32], because
    [float16_elt] is an alias of [Bigarray.float32_elt] (see the FIXME above). *)
let float16 : (float, float16_elt) Bigarray.kind = Bigarray.float32

(** A precision witness: relates the OCaml-side value type ['ocaml] to the Bigarray element type
    ['impl]. *)
type ('ocaml, 'impl) precision =
  | Byte : (char, uint8_elt) precision
  | Half : (float, float16_elt) precision
  | Single : (float, float32_elt) precision
  | Double : (float, float64_elt) precision
[@@deriving sexp_of]

(** A precision with its type parameters hidden, plus [Void_prec] for "no precision". *)
type prec =
  | Void_prec
  | Byte_prec of (char, uint8_elt) precision
  | Half_prec of (float, float16_elt) precision
  | Single_prec of (float, float32_elt) precision
  | Double_prec of (float, float64_elt) precision

let byte = Byte_prec Byte
let half = Half_prec Half
let single = Single_prec Single
let double = Double_prec Double

(** Serializes a precision as an atom named after its constructor. *)
let sexp_of_prec = function
  | Void_prec -> Sexp.Atom "Void_prec"
  | Byte_prec _ -> Sexp.Atom "Byte_prec"
  | Half_prec _ -> Sexp.Atom "Half_prec"
  | Single_prec _ -> Sexp.Atom "Single_prec"
  | Double_prec _ -> Sexp.Atom "Double_prec"

(** Inverse of {!sexp_of_prec}.
    @raise Invalid_argument if the s-expression is a list or an unknown atom. *)
let prec_of_sexp = function
  | Sexp.Atom "Void_prec" -> Void_prec
  | Sexp.Atom "Byte_prec" -> byte
  | Sexp.Atom "Half_prec" -> half
  | Sexp.Atom "Single_prec" -> single
  | Sexp.Atom "Double_prec" -> double
  | Sexp.List _ -> invalid_arg "prec_of_sexp: expected atom, found list"
  | Sexp.Atom s -> invalid_arg @@ "prec_of_sexp: unknown precision " ^ s

(** Lower-case name of a typed precision witness. *)
let precision_to_string (type ocaml elt_t) (prec : (ocaml, elt_t) precision) =
  match prec with Byte -> "byte" | Half -> "half" | Single -> "single" | Double -> "double"

(** Lower-case name of a packed precision; ["void"] for [Void_prec]. *)
let prec_string = function
  | Void_prec -> "void"
  | Byte_prec _ -> "byte"
  | Half_prec _ -> "half"
  | Single_prec _ -> "single"
  | Double_prec _ -> "double"

(** Two packed precisions are equal when they have the same constructor; the payload is not
    inspected. *)
let equal_prec p1 p2 =
  match (p1, p2) with
  | Void_prec, Void_prec -> true
  | Byte_prec _, Byte_prec _ -> true
  | Half_prec _, Half_prec _ -> true
  | Single_prec _, Single_prec _ -> true
  | Double_prec _, Double_prec _ -> true
  | Void_prec, _ | Byte_prec _, _ | Half_prec _, _ | Single_prec _, _ | Double_prec _, _ -> false

(** Nominal element size in bytes; [0] for [Void_prec]. Note that [Half_prec] reports 2 even though
    [float16_elt] is currently an alias of the 32-bit element type. *)
let prec_in_bytes = function
  | Void_prec -> 0
  | Byte_prec _ -> 1
  | Half_prec _ -> 2
  | Single_prec _ -> 4
  | Double_prec _ -> 8

(** Returns whichever argument ranks higher in the order
    [Double > Single > Half > Byte > Void]; on a tie, the first argument. *)
let promote_prec p1 p2 =
  match (p1, p2) with
  | Double_prec _, _ -> p1
  | _, Double_prec _ -> p2
  | Single_prec _, _ -> p1
  | _, Single_prec _ -> p2
  | Half_prec _, _ -> p1
  | _, Half_prec _ -> p2
  | Byte_prec _, _ -> p1
  | _, Byte_prec _ -> p2
  | Void_prec, Void_prec -> Void_prec

(** Hides the type parameters of a precision witness. *)
let pack_prec (type ocaml elt_t) (prec : (ocaml, elt_t) precision) =
  match prec with Byte -> byte | Half -> half | Single -> single | Double -> double

(** A function that is polymorphic in both type parameters of the precision witness. *)
type 'r map_prec = { f : 'ocaml 'elt_t. ('ocaml, 'elt_t) precision -> 'r }

(** Applies [f] to the witness selected by the constructor of the packed precision. For
    [Void_prec], returns [default] when given.

    Because [float16_elt] and [float32_elt] are the same type, a [Half_prec] may carry [Single]
    (and vice versa); the outer constructor decides which witness is passed to [f].
    @raise Invalid_argument on [Void_prec] when [default] is not provided. *)
let map_prec ?default { f } = function
  | Void_prec ->
      Option.value_or_thunk default ~default:(fun () -> invalid_arg "map_prec: Void_prec")
  | Byte_prec Byte -> f Byte
  | Half_prec (Half | Single) -> f Half
  | Single_prec (Single | Half) -> f Single
  | Double_prec Double -> f Double
  | _ -> .

(** The C/CUDA type name used for the given precision. *)
let cuda_typ_of_prec = function
  | Byte_prec _ -> "unsigned char" (* TODO: or should it be uint8, or uint8_t? *)
  | Half_prec _ -> (* FIXME: *) "float"
  | Single_prec _ -> "float"
  | Double_prec _ -> "double"
  | Void_prec -> "void"

(** {2 *** Operations ***} *)

(** Initializes or resets an array by filling in the corresponding numbers, at the appropriate
    precision. *)
type init_op =
  | Constant_fill of { values : float array; strict : bool }
      (** Fills in the numbers where the rightmost axis is contiguous. If [strict=true], loops over
          the provided values. *)
  | Range_over_offsets
      (** Fills in the offset number of each cell (i.e. how many cells away it is from the
          beginning). *)
  | Standard_uniform  (** Draws the values from U(0,1). *)
  | File_mapped of string * prec  (** Reads the data using [Unix.openfile] and [Unix.map_file]. *)
[@@deriving equal, sexp]

type binop = Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Arg1
[@@deriving sexp, compare, equal]

type unop = Identity | Relu [@@deriving sexp, compare, equal]

(** Either the left-neutral or right-neutral element of the operation. Unspecified if the operation
    does not have a neutral element. *)
let neutral_elem = function
  | Add | Sub -> 0.
  | Mul | Div -> 1.
  | ToPowOf -> 1.
  | Relu_gate -> 1.
  | Arg2 -> 0.
  | Arg1 -> 0.

(** Evaluates a binary operation on floats. [Arg1] and [Arg2] select the first and the second
    argument respectively; [Relu_gate] yields [v2] when [v1 > 0.0] and [0.0] otherwise. *)
let interpret_binop op v1 v2 =
  let open Float in
  match op with
  | Arg1 -> v1
  | Arg2 -> v2
  | Add -> v1 + v2
  | Sub -> v1 - v2
  | Mul -> v1 * v2
  | Div -> v1 / v2
  | ToPowOf -> (
      (* Use integer exponentiation only when the exponent is integral AND fits in an [int]:
         [iround_towards_zero] returns [None] for out-of-range values (e.g. [1e30]), where a plain
         [to_int] would raise. Such exponents fall back to [**]. *)
      match if is_integer v2 then iround_towards_zero v2 else None with
      | Some exponent -> int_pow v1 exponent
      | None -> v1 ** v2)
  | Relu_gate -> if v1 > 0.0 then v2 else 0.0

(** Evaluates a unary operation on a float. [Relu] maps negative inputs to [0.]. *)
let interpret_unop op v =
  let open Float in
  match op with Identity -> v | Relu when v >= 0. -> v | Relu -> 0.

(** Returns a triple of C source fragments for the operation at the given precision; presumably
    they are emitted before, between, and after the two operands (confirm against the callers).
    @raise Invalid_argument
      for [Arg1], [Arg2], for [Void_prec], and for [ToPowOf] at byte precision. *)
let binop_C_syntax prec v =
  match (v, prec) with
  | Arg1, _ -> invalid_arg "Ops.binop_C_syntax: Arg1 is not a C operator"
  | Arg2, _ -> invalid_arg "Ops.binop_C_syntax: Arg2 is not a C operator"
  | _, Void_prec -> invalid_arg "Ops.binop_C_syntax: Void precision"
  | Add, _ -> ("(", " +", ")")
  | Sub, _ -> ("(", " -", ")")
  | Mul, _ -> ("(", " *", ")")
  | Div, _ -> ("(", " /", ")")
  | ToPowOf, Double_prec _ -> ("pow(", ",", ")")
  | ToPowOf, Single_prec _ -> ("powf(", ",", ")")
  | ToPowOf, Half_prec _ -> ("powf(", ",", ")")
  | ToPowOf, Byte_prec _ ->
      invalid_arg "Ops.binop_C_syntax: ToPowOf not supported for byte/integer precisions"
  | Relu_gate, Byte_prec _ -> ("(", " > 0 ?", " : 0)")
  | Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)")
(* "((int)(", "> 0.0) *", ")" *)

(** The [%cd] syntax spelling of a binary operator. *)
let binop_cd_syntax = function
  | Arg1 -> "-@>"
  | Arg2 -> "-/>"
  | Add -> "+"
  | Sub -> "-"
  | Mul -> "*"
  | Div -> "/"
  | ToPowOf -> "**"
  | Relu_gate -> "-?/"

(** The C assignment operator that accumulates with the given operation; [Arg2] is plain ["="].
    @raise Invalid_argument for [Arg1], [ToPowOf] and [Relu_gate]. *)
let assign_op_C_syntax = function
  | Arg1 -> invalid_arg "Ops.assign_op_C_syntax: Arg1 is not a C assignment operator"
  | Arg2 -> "="
  | Add -> "+="
  | Sub -> "-="
  | Mul -> "*="
  | Div -> "/="
  | ToPowOf -> invalid_arg "Ops.assign_op_C_syntax: ToPowOf function is not a C assignment operator"
  | Relu_gate -> invalid_arg "Ops.assign_op_C_syntax: Relu_gate is not a C assignment operator"

(** The [%cd] syntax spelling of an accumulating assignment. When [initialize_neutral] is true the
    [=:]-prefixed variant is returned; [Arg2] is always ["=:"].
    @raise Invalid_argument for [Arg1]. *)
let assign_op_cd_syntax ~initialize_neutral = function
  | Arg1 -> invalid_arg "Ops.assign_op_cd_syntax: Arg1 is not a %cd assignment operator"
  | Arg2 -> "=:"
  | Add when initialize_neutral -> "=:+"
  | Sub when initialize_neutral -> "=:-"
  | Mul when initialize_neutral -> "=:*"
  | Div when initialize_neutral -> "=:/"
  | ToPowOf when initialize_neutral -> "=:**"
  | Relu_gate when initialize_neutral -> "=:?/"
  | Add -> "=+"
  | Sub -> "=-"
  | Mul -> "=*"
  | Div -> "=/"
  | ToPowOf -> "=**"
  | Relu_gate -> "=?/"

(** The [%cd] syntax spelling of a unary operator. *)
let unop_cd_syntax = function Identity -> "~=" | Relu -> "?/"

(** {2 *** Global references ***} *)

type voidptr = unit Ctypes.ptr

let sexp_of_voidptr p = Sexp.Atom Ctypes.(string_of (ptr void) p)
let compare_voidptr = Ctypes.ptr_compare

(** Pointer equality, defined via {!compare_voidptr} so that [equal] and [compare] always agree.
    Physical equality would instead compare the OCaml values wrapping the pointers, and could
    report two wrappers of the same address as different. *)
let equal_voidptr (p1 : voidptr) (p2 : voidptr) : bool = compare_voidptr p1 p2 = 0

(** Renders a pointer as a C-style cast, ["(<type>*)"] with the type from {!cuda_typ_of_prec},
    followed by the hexadecimal rendering of the pointer's raw address. *)
let ptr_to_string (type elem) (ptr : elem Ctypes.ptr) prec =
  "(" ^ cuda_typ_of_prec prec ^ "*)"
  ^ Nativeint.Hex.to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr)

type global_identifier =
  | C_function of string  (** Calls a no-argument or indices-arguments C function. *)
  | External_unsafe of {
      ptr : voidptr;
      prec : (prec[@equal.ignore] [@compare.ignore]);
      dims : int array Lazy.t;
    }
  | Merge_buffer of { source_node_id : int }
      (** Each device has at most one merge buffer, which is re-used, and re-allocated as needed, by
          merge operations. The merge buffer is associated with the source node of the device's most
          recent [device_to_device ~into_merge_buffer:true] operation. *)
[@@deriving sexp_of, equal, compare]