open Base
open Ppxlib
open Ppx_arrayjit.Ppx_helper

type li = longident

(** Builds a string-constant expression holding [s]. [loc] is used both for the constant and for
    the enclosing expression node, consistently with how [let_opt] passes its location. *)
let string_expr ~loc s = Ast_helper.Exp.constant ~loc @@ Pconst_string (s, loc, None)

(** Renders a pattern as a short string.

    - [open], [lazy] and type-constraint wrappers are looked through.
    - Variables and aliases give the bound name; the wildcard gives the underscore.
    - Polymorphic variants and string / integer / float constants give their textual payload.
    - Tuples and arrays are rendered recursively with a comma-and-space separator.
    - Constructors give the last component of the constructor path; arguments are ignored.
    - Intervals, records, or-patterns, type, unpack, exception and extension patterns give the
      empty string. *)
let string_of_pat pat =
  (* Only the final path component is kept: the module prefix of [Ldot] is dropped. *)
  let rec lident = function Lident s | Ldot (_, s) -> s | Lapply (_, i) -> lident i in
  let rec loop pat =
    match pat.ppat_desc with
    | Ppat_open (_, pat) | Ppat_lazy pat | Ppat_constraint (pat, _) -> loop pat
    | Ppat_alias (_, ident) -> ident.txt
    | Ppat_var ident -> ident.txt
    | Ppat_any -> "_"
    | Ppat_variant (s, _)
    | Ppat_constant (Pconst_string (s, _, _))
    | Ppat_constant (Pconst_integer (s, _))
    | Ppat_constant (Pconst_float (s, _)) ->
        s
    | Ppat_constant (Pconst_char c) -> Char.to_string c
    | Ppat_tuple pats -> "(" ^ String.concat ~sep:", " (List.map ~f:loop pats) ^ ")"
    | Ppat_array pats -> "[|" ^ String.concat ~sep:", " (List.map ~f:loop pats) ^ "|]"
    | Ppat_construct (c, _) -> lident c.txt
    | Ppat_interval (_, _)
    | Ppat_record (_, _)
    | Ppat_or (_, _)
    | Ppat_type _ | Ppat_unpack _ | Ppat_exception _ | Ppat_extension _ ->
        ""
  in
  loop pat

(** [string_of_pat] packaged as a string-constant expression located at the pattern. *)
let pat2string pat =
  let loc = pat.ppat_loc in
  string_expr ~loc @@ string_of_pat pat

(** Collects, as a string set, the identifiers a pattern binds.

    For an alias pattern only the alias name is collected: the aliased sub-pattern is not
    traversed. Constants, intervals, argument-less constructors and variants, and type, unpack,
    exception and extension patterns contribute nothing. *)
let collect_pat_idents pat =
  let one = Set.singleton (module String) in
  let none = Set.empty (module String) in
  let rec loop pat =
    let all pats = Set.union_list (module String) @@ List.map ~f:loop pats in
    match pat.ppat_desc with
    | Ppat_open (_, pat) | Ppat_lazy pat | Ppat_constraint (pat, _) -> loop pat
    | Ppat_alias (_, ident) -> one ident.txt
    | Ppat_var ident -> one ident.txt
    | Ppat_any -> none
    | Ppat_variant (_, None) -> none
    | Ppat_variant (_, Some pat) -> loop pat
    | Ppat_constant _ -> none
    | Ppat_tuple pats | Ppat_array pats -> all pats
    | Ppat_construct (_, None) -> none
    | Ppat_construct (_, Some (_, pat)) -> loop pat
    | Ppat_interval (_, _) -> none
    | Ppat_record (lpats, _) -> all @@ List.map ~f:snd lpats
    | Ppat_or (p1, p2) -> all [ p1; p2 ]
    | Ppat_type _ | Ppat_unpack _ | Ppat_exception _ | Ppat_extension _ -> none
  in
  loop pat

(** Renders an expression as a string-constant expression located at [expr].

    Unlike [string_of_pat], identifier and constructor paths are rendered in full, with
    dot-separated module prefixes. Expression forms other than the ones matched below render as
    the empty string. *)
let expr2string_or_empty expr =
  let rec lident = function
    | Lident s -> s
    | Ldot (li, s) -> lident li ^ "." ^ s
    | Lapply (_, i) -> lident i
  in
  let rec loop expr =
    match expr.pexp_desc with
    | Pexp_open (_, expr) | Pexp_lazy expr | Pexp_constraint (expr, _) -> loop expr
    | Pexp_ident ident -> lident ident.txt
    | Pexp_variant (s, _)
    | Pexp_constant (Pconst_string (s, _, _))
    | Pexp_constant (Pconst_integer (s, _))
    | Pexp_constant (Pconst_float (s, _)) ->
        s
    | Pexp_constant (Pconst_char c) -> Char.to_string c
    | Pexp_tuple exprs -> "(" ^ String.concat ~sep:", " (List.map ~f:loop exprs) ^ ")"
    | Pexp_array exprs -> "[|" ^ String.concat ~sep:", " (List.map ~f:loop exprs) ^ "|]"
    | Pexp_construct (c, _) -> lident c.txt
    | _ -> ""
  in
  string_expr ~loc:expr.pexp_loc @@ loop expr

(** Generates an option-typed expression: [None], or [Some] of the pattern rendered as a string. *)
let opt_pat2string ~loc = function
  | None -> [%expr None]
  | Some pat -> [%expr Some [%e pat2string pat]]

(** Generates a list expression: empty, or a singleton holding the pattern rendered as a string. *)
let opt_pat2string_list ~loc = function
  | None -> [%expr []]
  | Some pat -> [%expr [ [%e pat2string pat] ]]

(** Lifts an optional expression into a generated [None] / [Some] expression. *)
let opt_expr ~loc = function None -> [%expr None] | Some expr -> [%expr Some [%e expr]]

(** Converts a pattern into the expression of the same shape, located at the pattern.

    Handled: type constraints, aliases and variables (both become the bound identifier),
    polymorphic variants, constants, constructors without existential type variables, closed
    records, tuples and arrays. Any other pattern yields an error extension node carrying the
    message below, instead of raising. *)
let rec pat2expr pat =
  let module Ast = Ast_builder.Default in
  let loc = pat.ppat_loc in
  match pat.ppat_desc with
  | Ppat_constraint (pat', typ) -> Ast.pexp_constraint ~loc (pat2expr pat') typ
  | Ppat_alias (_, ident) | Ppat_var ident ->
      Ast.pexp_ident ~loc { ident with txt = Lident ident.txt }
  | Ppat_variant (ident, e_opt) -> Ast.pexp_variant ~loc ident @@ Option.map e_opt ~f:pat2expr
  | Ppat_constant c -> Ast.pexp_constant ~loc c
  | Ppat_construct (c, None) -> Ast.pexp_construct ~loc c None
  | Ppat_construct (c, Some ([], args)) -> Ast.pexp_construct ~loc c @@ Some (pat2expr args)
  | Ppat_record (fields, Asttypes.Closed) ->
      Ast.pexp_record ~loc
        (List.map fields ~f:(fun (label, field) -> (label, pat2expr field)))
        None
  | Ppat_tuple pats -> Ast.pexp_tuple ~loc @@ List.map pats ~f:pat2expr
  | Ppat_array pats -> Ast.pexp_array ~loc @@ List.map pats ~f:pat2expr
  | _ ->
      Ast.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl does not recognize/support the pattern; maybe try using an `as` alias."

(** Regular expression for a non-empty run of characters none of which is an ASCII letter or
    digit. *)
let non_alphanum_regexp = Str.regexp "^[^a-zA-Z0-9]+$"

(** Whether [ident] matches [non_alphanum_regexp] starting at position 0. *)
let is_operator ident = Str.string_match non_alphanum_regexp ident 0

(** Whether [ident] has the form of an assignment operator: at least two characters, starting
    with the equals sign, and not one of the explicitly excluded comparison / arrow operators. *)
let is_assignment ident =
  String.length ident > 1
  && Char.equal ident.[0] '='
  && (not @@ List.mem [ "=="; "==="; "=>"; "==>"; "=>>" ] ident ~equal:String.equal)

(** Binary primitive ops, both infix operator and function name variants. Each entry maps a
    location to a pair of generated expressions: the compose type and the [Ir.Ops] operation. For
    the infix forms of multiplication and division the first component is an error extension node,
    since no default compose type is provided for them. *)
let binary_ops =
  Hashtbl.of_alist_exn
    (module String)
    [
      ("-@>", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Arg1]));
      ("fst", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Arg1]));
      ("-/>", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Arg2]));
      ("snd", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Arg2]));
      ("+", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Add]));
      ("add", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Add]));
      ("-", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Sub]));
      ("sub", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Sub]));
      ( "*",
        fun loc ->
          ( Ast_builder.Default.pexp_extension ~loc
            @@ Location.error_extensionf ~loc
                 "No default compose type for binary `*`, try e.g. ~logic:\".\" for pointwise, %s"
                 "~logic:\"@\" for matrix multiplication",
            [%expr Ir.Ops.Mul] ) );
      ("mul", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Mul]));
      ( "/",
        fun loc ->
          ( Ast_builder.Default.pexp_extension ~loc
            @@ Location.error_extensionf ~loc
                 "For clarity, no default compose type for binary `/`, use ~logic:\".\" for \
                  pointwise division",
            [%expr Ir.Ops.Div] ) );
      ("div", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Div]));
      ("**", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.ToPowOf]));
      ("pow", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.ToPowOf]));
      ("-?/", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Relu_gate]));
      ("relu_gate", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Relu_gate]));
      ("-?^", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Satur01_gate]));
      ("sat01_gate", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Satur01_gate]));
      ("<", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Cmplt]));
      ("lt", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Cmplt]));
      ("=", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Cmpeq]));
      ("eq", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Cmpeq]));
      ("<>", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Cmpne]));
      ("ne", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Cmpne]));
      ("||", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Or]));
      ("or_", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Or]));
      ("&&", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.And]));
      ("and_", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.And]));
      ("%", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Mod]));
      ("mod_", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Mod]));
      ("@^", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Max]));
      ("max", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Max]));
      ("@-", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Min]));
      ("min", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Min]));
      ("^^^^", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Threefry4x32_crypto]));
      ( "threefry4x32_crypto",
        fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Threefry4x32_crypto]) );
      ("^^", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Threefry4x32_light]));
      ( "threefry4x32_light",
        fun loc -> ([%expr Shape.Pointwise_bin], [%expr Ir.Ops.Threefry4x32_light]) );
    ]

(** Unary primitive ops. Each entry maps a location to the pair of generated expressions: the
    transpose type and the [Ir.Ops] operation. *)
let unary_ops =
  Hashtbl.of_alist_exn
    (module String)
    [
      ("id", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Identity]));
      ("relu", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Relu]));
      ("sat01", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Satur01]));
      ("exp", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Exp]));
      ("log", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Log]));
      ("exp2", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Exp2]));
      ("log2", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Log2]));
      ("sin", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Sin]));
      ("cos", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Cos]));
      ("sqrt", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Sqrt]));
      ("recip", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Recip]));
      ("recip_sqrt", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Recip_sqrt]));
      ("neg", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Neg]));
      ("tanh", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Tanh_approx]));
      ("not", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Not]));
      ( "uint4x32_to_prec_uniform1",
        fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Uint4x32_to_prec_uniform1]) );
    ]

(** Vector unary primitive ops. *)
let vec_unary_ops =
  Hashtbl.of_alist_exn
    (module String)
    [
      ( "uint4x32_to_prec_uniform",
        fun loc -> ([%expr Shape.Uint4x32_to_prec], [%expr Ir.Ops.Uint4x32_to_prec_uniform]) );
    ]

(** Ternary primitive ops. *)
let ternary_ops =
  Hashtbl.of_alist_exn
    (module String)
    [
      ("where", fun loc -> ([%expr Shape.Pointwise_tern], [%expr Ir.Ops.Where]));
      ("fma", fun loc -> ([%expr Shape.Compose_accumulate], [%expr Ir.Ops.FMA]));
    ]

(** Assignment binary ops, and whether assignment reduction is zero-initialized. *)
let assignment_ops =
  (* This should stay in sync with Ir.Ops.assign_op_cd_syntax. *)
  (* NOTE(review): the two threefry entries below name Ir.Ops.Threefry4x32, whereas binary_ops
     uses Threefry4x32_crypto and Threefry4x32_light. Confirm against Ir.Ops that a constructor
     called Threefry4x32 exists. *)
  Hashtbl.of_alist_exn
    (module String)
    [
      ("=:", fun loc -> (false, [%expr Ir.Ops.Arg2]));
      ("=+", fun loc -> (false, [%expr Ir.Ops.Add]));
      ("=-", fun loc -> (false, [%expr Ir.Ops.Sub]));
      ("=*", fun loc -> (false, [%expr Ir.Ops.Mul]));
      ("=/", fun loc -> (false, [%expr Ir.Ops.Div]));
      ("=**", fun loc -> (false, [%expr Ir.Ops.ToPowOf]));
      ("=?/", fun loc -> (false, [%expr Ir.Ops.Relu_gate]));
      ("=?^", fun loc -> (false, [%expr Ir.Ops.Satur01_gate]));
      ("=||", fun loc -> (false, [%expr Ir.Ops.Or]));
      ("=&&", fun loc -> (false, [%expr Ir.Ops.And]));
      ("=@^", fun loc -> (false, [%expr Ir.Ops.Max]));
      ("=@-", fun loc -> (false, [%expr Ir.Ops.Min]));
      ("=^^^^", fun loc -> (false, [%expr Ir.Ops.Threefry4x32]));
      ("=:+", fun loc -> (true, [%expr Ir.Ops.Add]));
      ("=:-", fun loc -> (true, [%expr Ir.Ops.Sub]));
      ("=:*", fun loc -> (true, [%expr Ir.Ops.Mul]));
      ("=:/", fun loc -> (true, [%expr Ir.Ops.Div]));
      ("=:**", fun loc -> (true, [%expr Ir.Ops.ToPowOf]));
      ("=:?/", fun loc -> (true, [%expr Ir.Ops.Relu_gate]));
      ("=:?^", fun loc -> (true, [%expr Ir.Ops.Satur01_gate]));
      ("=:||", fun loc -> (true, [%expr Ir.Ops.Or]));
      ("=:&&", fun loc -> (true, [%expr Ir.Ops.And]));
      ("=:@^", fun loc -> (true, [%expr Ir.Ops.Max]));
      ("=:@-", fun loc -> (true, [%expr Ir.Ops.Min]));
      ("=:^^^^", fun loc -> (true, [%expr Ir.Ops.Threefry4x32]));
    ]

(** Binary einsum-style operators, mapped to the identifier expression they expand to. *)
let einsum_binary_ops =
  Hashtbl.of_alist_exn
    (module String)
    [
      ("+*", fun loc -> [%expr einsum]);
      ("@^+", fun loc -> [%expr tropical]);
      ("+++", fun loc -> [%expr outer_sum]);
    ]

(** Unary einsum-style operators, mapped to the identifier expression they expand to. *)
let einsum_unary_ops =
  Hashtbl.of_alist_exn
    (module String)
    [ ("++", fun loc -> [%expr einsum1]); ("@^^", fun loc -> [%expr einmax1]) ]

(** Whether [op_ident] is a key of [ternary_ops], [unary_ops] or [binary_ops]. *)
let is_primitive_op op_ident =
  (* NOTE(review): vec_unary_ops is not consulted here; confirm whether that is intended. *)
  List.exists ~f:(Fn.flip Hashtbl.mem op_ident) [ ternary_ops; unary_ops; binary_ops ]

(** Wraps [expr] in a non-recursive [let] that introduces the value bindings stored in the map
    [vbs]; returns [expr] unchanged when the map is empty. *)
let let_opt ~loc vbs expr =
  if Map.is_empty vbs then expr else Ast_helper.Exp.let_ ~loc Nonrecursive (Map.data vbs) expr

(** The empty string-keyed map of value bindings. *)
let no_vbs = Map.empty (module String)

(** Merges a list of binding maps into one. On a duplicate key the value coming from the map
    later in the list is kept. Raises on the empty list, as it relies on [List.reduce_exn]. *)
let reduce_vbss = List.reduce_exn ~f:(Map.merge_skewed ~combine:(fun ~key:_ _v1 v2 -> v2))

(** Expression expander with label punning.

    When the payload is a [let] expression, each bound expression is translated with its binding
    pattern passed as [ident_label], the body is left untouched, and the auxiliary bindings
    returned by the translations are introduced around the whole [let]. Any other payload is
    translated as a whole, without an identifier label. *)
let expr_expander_with_punning translate ~loc ~path:_ payload =
  match payload with
  | { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
      (* We are at the %op/%cd annotation level: do not translate the body. *)
      let vbss, bindings =
        List.unzip
        @@ List.map bindings ~f:(fun vb ->
               let vbs, v = translate ?ident_label:(Some vb.pvb_pat) vb.pvb_expr in
               (vbs, { vb with pvb_expr = v }))
      in
      let expr = { payload with pexp_desc = Pexp_let (recflag, bindings, body) } in
      let_opt ~loc (reduce_vbss vbss) expr
  | expr ->
      let vbs, expr = translate ?ident_label:None expr in
      let_opt ~loc vbs expr

(** Collapses a list of structure items into a single item: a lone item is returned as is,
    anything else is wrapped as an [include] of an anonymous structure. [loc] is attached to the
    generated nodes. *)
let flatten_str ~loc ~path:_ items =
  match items with
  | [ x ] -> x
  | _ ->
      Ast_helper.Str.include_ ~loc
        {
          pincl_mod = Ast_helper.Mod.structure ~loc items;
          pincl_loc = loc;
          pincl_attributes = [];
        }

(** Translates one structure item. Evaluated expressions go through
    [expr_expander_with_punning]. For value definitions, every bound expression is translated with
    its pattern as [ident_label] and the resulting auxiliary bindings are introduced locally
    around that bound expression. Other structure items are returned unchanged. *)
let translate_str translate ({ pstr_desc; pstr_loc = loc; _ } as str) =
  match pstr_desc with
  | Pstr_eval (expr, attrs) ->
      let expr = expr_expander_with_punning translate ~loc ~path:() expr in
      { str with pstr_desc = Pstr_eval (expr, attrs) }
  | Pstr_value (recf, bindings) ->
      let f vb =
        let loc = vb.pvb_loc in
        let vbs, v = translate ?ident_label:(Some vb.pvb_pat) vb.pvb_expr in
        let v = let_opt ~loc vbs v in
        { vb with pvb_expr = v }
      in
      { str with pstr_desc = Pstr_value (recf, List.map bindings ~f) }
  | _ -> str

(** Structure expander with label punning: translates every item of the payload with
    [translate_str] and flattens the results into a single structure item. *)
let str_expander_with_punning translate ~loc ~path (payload : structure_item list) =
  flatten_str ~loc ~path @@ List.map payload ~f:(translate_str translate)

(** Generates the application of [ndarray_fn] to the values and the batch / input / output
    dimension lists that [ndarray_constant] extracts from [expr], followed by a unit argument.
    The optional [axis_labels] and [label] expressions are passed as labeled arguments when
    present. *)
let ndarray_op ?axis_labels ?label ~ndarray_fn expr =
  let loc = expr.pexp_loc in
  let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
  let edims dims = Ast_builder.Default.elist ~loc dims in
  let w_val = [%expr [%e ndarray_fn] [%e values]] in
  let op =
    match (axis_labels, label) with
    | None, None -> w_val
    | Some axis_labels, None -> [%expr [%e w_val] ~axis_labels:[%e axis_labels]]
    | None, Some label -> [%expr [%e w_val] ~label:[%e label]]
    | Some axis_labels, Some label ->
        [%expr [%e w_val] ~axis_labels:[%e axis_labels] ~label:[%e label]]
  in
  [%expr
    [%e op] ~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
      ~output_dims:[%e edims output_dims] ()]

(** Processes the capture labels [head] and the ones gathered from [rest] by [collect_list].

    Returns a pair:
    - a map from each string-literal label to a value binding that binds a variable of that name
      to a [Shape.get_variable_ref] call on the label;
    - a list expression holding first an error extension node for every label that is not a
      string literal, then a reference to each bound variable.

    Raises if the same label occurs more than once, as the map is built with [Map.of_alist_exn]. *)
let collect_capture_labels ~loc head rest =
  let capture_labels = head :: collect_list [] rest in
  let capture_labels, errors =
    List.partition_map capture_labels ~f:(function
      | { pexp_desc = Pexp_constant (Pconst_string (label, _, _)); pexp_loc; _ } ->
          Either.First (pexp_loc, label)
      | expr ->
          Either.Second
            (Ast_builder.Default.pexp_extension ~loc:expr.pexp_loc
            @@ Location.error_extensionf ~loc:expr.pexp_loc
                 "ppx_ocannl %%op: expected a string literal"))
  in
  let capture_refs, capture_bindings =
    List.map capture_labels ~f:(fun (loc, label) ->
        (* Within this function [loc] is the location of the individual label literal. *)
        let ref_expr =
          [%expr Shape.get_variable_ref [%e Ast_builder.Default.estring ~loc label]]
        in
        let binding =
          Ast_builder.Default.value_binding ~loc
            ~pat:(Ast_builder.Default.pvar ~loc label)
            ~expr:ref_expr
        in
        (Ast_builder.Default.evar ~loc label, (label, binding)))
    |> List.unzip
  in
  let capture_dims_expr = Ast_builder.Default.elist ~loc (errors @ capture_refs) in
  let capture_vbs = Map.of_alist_exn (module String) capture_bindings in
  (capture_vbs, capture_dims_expr)