open Base
open Ppxlib
open Ppx_arrayjit.Ppx_helper
open Ppx_shared

(** [ndarray_op ~ident_label ?axis_labels ?label expr] translates the literal array/list [expr] into a call
    to [NTDSL.ndarray], passing the values and the batch/input/output dimensions that [ndarray_constant]
    extracts from the literal. The tensor label is the explicit [label] when given, otherwise the label
    derived from [ident_label] (the pattern the expression is bound to, if any). *)
let ndarray_op ~ident_label ?axis_labels ?label expr =
  let loc = expr.pexp_loc in
  let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
  let edims dims = Ast_builder.Default.elist ~loc dims in
  (* Pass [~label] exactly once: applying the same labeled argument twice does not type-check. *)
  let label =
    match label with Some label -> label | None -> opt_pat2string_list ~loc ident_label
  in
  let op =
    match axis_labels with
    | None -> [%expr NTDSL.ndarray]
    | Some axis_labels -> [%expr NTDSL.ndarray ~axis_labels:[%e axis_labels]]
  in
  [%expr
    [%e op] ~label:[%e label] ~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
      ~output_dims:[%e edims output_dims] [%e values]]

(** What kind of value a translated [%cd] sub-expression denotes. *)
type expr_type = Code | Array | Grad_of_tensor of expression | Tensor | Unknown

let is_unknown = function Unknown -> true | _ -> false

(** Which projection slot of an accumulation/assignment a sub-expression refers to. *)
type projections_slot = LHS | RHS1 | RHS2 | Nonslot | Undet [@@deriving equal, sexp]

(** Maps an assignment operator expression to [(initialize_neutral, accumulation_op)]. The [=:]-prefixed
    variants (other than [=:] itself) set [initialize_neutral] to [true]. Unrecognized operators yield an
    error extension node in place of the accumulation op. *)
let assignment_op expr =
  (* This should stay in sync with Arrayjit.Ops.assign_op_cd_syntax. *)
  let loc = expr.pexp_loc in
  match expr with
  | [%expr ( =: )] -> (false, [%expr Arrayjit.Ops.Arg2])
  | [%expr ( =+ )] -> (false, [%expr Arrayjit.Ops.Add])
  | [%expr ( =- )] -> (false, [%expr Arrayjit.Ops.Sub])
  | [%expr ( =* )] -> (false, [%expr Arrayjit.Ops.Mul])
  | [%expr ( =/ )] -> (false, [%expr Arrayjit.Ops.Div])
  | [%expr ( =** )] -> (false, [%expr Arrayjit.Ops.ToPowOf])
  | [%expr ( =?/ )] -> (false, [%expr Arrayjit.Ops.Relu_gate])
  | [%expr ( =:+ )] -> (true, [%expr Arrayjit.Ops.Add])
  | [%expr ( =:- )] -> (true, [%expr Arrayjit.Ops.Sub])
  | [%expr ( =:* )] -> (true, [%expr Arrayjit.Ops.Mul])
  | [%expr ( =:/ )] -> (true, [%expr Arrayjit.Ops.Div])
  | [%expr ( =:** )] -> (true, [%expr Arrayjit.Ops.ToPowOf])
  | [%expr ( =:?/ )] -> (true, [%expr Arrayjit.Ops.Relu_gate])
  | _ ->
      ( false,
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
             "=+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =: (Arg2), =:+, =:-,"
             " =:*, =:/, =:**, =:?/ (same with initializing the tensor to the neutral value before the start \
              of the calculation)" )

(** Maps a binary operator expression to [(default_shape_logic, op)]. [*] and [/] have no default shape
    logic: an error extension node takes its place, so they need an explicit [~logic]. *)
let binary_op expr =
  (* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
  let loc = expr.pexp_loc in
  match expr with
  | [%expr ( + )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Add])
  | [%expr ( - )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Sub])
  | [%expr ( * )] ->
      ( Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "No default compose type for binary `*`, try e.g. ~logic:\".\" for pointwise, %s"
             "~logic:\"@\" for matrix multiplication",
        [%expr Arrayjit.Ops.Mul] )
  | [%expr ( / )] ->
      ( Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "For clarity, no default compose type for binary `/`, use ~logic:\".\" for pointwise division",
        [%expr Arrayjit.Ops.Div] )
  | [%expr ( ** )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.ToPowOf])
  | [%expr ( -?/ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate])
  | [%expr ( -/> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg2])
  | [%expr ( -@> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg1])
  | _ ->
      ( [%expr Shape.Pointwise_bin],
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected a binary operator, one of: %s"
             "+ (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2), -@> (Arg1)" )

let is_binary_op ident =
  List.mem [ "+"; "-"; "*"; "/"; "**"; "-?/"; "-/>"; "-@>" ] ident ~equal:String.equal

(** Maps a unary operator expression to [(default_shape_logic, op)]; unrecognized operators yield an error
    extension node in place of the op. *)
let unary_op expr =
  (* This and is_unary_op should stay in sync with Arrayjit.Ops.unop_cd_syntax. *)
  let loc = expr.pexp_loc in
  match expr with
  | [%expr ( ~= )] -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Identity])
  | [%expr ( ?/ )] -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Relu])
  | _ ->
      ( [%expr Shape.Pointwise_un],
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "ppx_ocannl %%cd: expected a unary operator, one of: ~= (Identity), ?/ (Relu)" )

let is_unary_op ident = List.mem [ "~="; "?/" ] ident ~equal:String.equal

(** Generates code that, at run time, extracts the array written by the code value [c]: the [lhs] of an
    accumulation, the [array] of a [Fetch], or (recursively) that of the last element of a [Seq] /
    [Block_comment]. A [Noop] refers to no data and raises [Invalid_argument] at run time.

    The recursion has to happen in the generated code: the ppx cannot see the run-time shape of [c]
    (recursing at ppx time never terminates). *)
let array_of_code c =
  let loc = c.pexp_loc in
  [%expr
    (let rec nondiff___array_of_code = function
       | Arrayjit.Assignments.Accum_binop { lhs; _ } | Accum_unop { lhs; _ } -> lhs
       | Fetch { array; _ } -> array
       | Seq (_, subexpr) | Block_comment (_, subexpr) -> nondiff___array_of_code subexpr
       | Noop -> invalid_arg "ppx_ocannl %cd: Noop code does not refer to any data"
     in
     nondiff___array_of_code)
      [%e c]]

(** A lazily-evaluated binding introduced for a slot filler, together with the code that has to run before
    the filler's data is used. *)
type binding_setup = { var : pattern; lazy_bind_to : expression; fwd_code_or_noop : expression }

(** Wraps [body] with the [setups] bindings (each [var] bound to [Lazy.force lazy_bind_to]) and sequences
    all the setups' forward code before [body]. Returns [body] unchanged when there are no setups. The
    result is always classified as [(Code, Nonslot, _)]. *)
let with_forward_args setups body =
  let loc = body.pexp_loc in
  let bindings =
    List.map setups ~f:(fun { var; lazy_bind_to; _ } ->
        Ast_helper.Vb.mk ~loc var [%expr Lazy.force [%e lazy_bind_to]])
  in
  let forward_args =
    List.map setups ~f:(fun { fwd_code_or_noop; _ } -> fwd_code_or_noop)
    |> List.reduce ~f:(fun code fwd -> [%expr Arrayjit.Assignments.Seq ([%e code], [%e fwd])])
  in
  ( Code,
    Nonslot,
    match forward_args with
    | None -> body
    | Some fwd ->
        [%expr
          (* FIXME: we do not want to force the computation unnecessarily, but we want the bindings? *)
          (*if Arrayjit.Assignments.is_noop [%e body] then Arrayjit.Assignments.Noop else*)
          [%e
            Ast_helper.Exp.let_ ~loc Nonrecursive bindings
              [%expr Arrayjit.Assignments.Seq ([%e fwd], [%e body])]]] )

(** The index projection for [slot], read from the projections record [p] that must be in scope in the
    generated code. [debug] names the slot position in error messages. *)
let project_p_slot debug loc slot =
  match slot with
  | LHS -> [%expr p.project_lhs]
  | RHS1 -> [%expr p.project_rhs.(0)]
  | RHS2 -> [%expr p.project_rhs.(1)]
  | Nonslot ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s"
           debug
  | Undet ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug
           "(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"

(** The dimensions for [slot], read from the projections record [p] that must be in scope in the generated
    code. [debug] names the slot position in error messages. *)
let project_p_dims debug loc slot =
  match slot with
  | LHS -> [%expr p.lhs_dims]
  | RHS1 -> [%expr p.rhs_dims.(0)]
  | RHS2 -> [%expr p.rhs_dims.(1)]
  | Nonslot ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s"
           debug
  | Undet ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug
           "(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"

(** Everything the code generator needs to know about one operand (slot filler) of an assignment. *)
type array_setup = {
  slot : projections_slot;
  filler_typ : expr_type;
  binding : binding_setup option;
  array_opt : expression;  (** Generated code of option type: the array the filler refers to, if any. *)
  tensor : expression option;  (** The tensor the filler belongs to, when known. *)
}

(** Builds the [array_setup] for a translated filler. Tensor and code fillers are bound lazily to
    [filler_pat] (see [with_forward_args]), and their array is read through that binding so the filler is
    evaluated only once. *)
let setup_array filler_pat (filler_typ, slot, filler) =
  let loc = filler.pexp_loc in
  match filler_typ with
  | Tensor | Unknown ->
      let t = pat2expr filler_pat in
      let fwd_code_or_noop =
        [%expr
          if Tensor.is_fwd_root [%e t] then (
            Tensor.remove_fwd_root [%e t];
            [%e t].Tensor.forward)
          else Arrayjit.Assignments.Noop]
      in
      {
        binding = Some { var = filler_pat; lazy_bind_to = [%expr lazy [%e filler]]; fwd_code_or_noop };
        filler_typ;
        slot;
        array_opt = [%expr Some [%e t].value];
        tensor = Some t;
      }
  | Code ->
      let code = pat2expr filler_pat in
      {
        binding =
          Some { var = filler_pat; lazy_bind_to = [%expr lazy [%e filler]]; fwd_code_or_noop = code };
        filler_typ;
        slot;
        (* Use the bound variable rather than [filler]: re-evaluating the filler would build a second,
           distinct code value whose array need not match the one that is forwarded. *)
        array_opt = [%expr Some [%e array_of_code code]];
        tensor = None;
      }
  | Array -> { binding = None; filler_typ; slot; array_opt = [%expr Some [%e filler]]; tensor = None }
  | Grad_of_tensor t -> { binding = None; filler_typ; slot; array_opt = filler; tensor = Some t }

(** For the shape-inferring ([Tensor.raw_binop] / [Tensor.raw_unop]) forms: the tensor a filler belongs to,
    and whether the filler denotes that tensor's gradient. Fillers without a tensor (plain arrays) yield an
    error extension node. *)
let args_for ~loc setup =
  match setup with
  | { filler_typ = Grad_of_tensor _; tensor = Some t; _ } -> (t, [%expr true])
  | { tensor = Some t; _ } -> (t, [%expr false])
  | { tensor = None; _ } ->
      ( Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "ppx_ocannl %%cd: cannot use `~logic` (infer shapes) for arrays, use tensors or `.grad` \
              notation",
        [%expr false] )

(** The first slot in [slots] that is not [Undet], or [Undet] if there is none. *)
let first_determined_slot slots =
  Option.value ~default:Undet @@ List.find slots ~f:(function Undet -> false | _ -> true)

(** Translates a [%cd] expression. Returns the kind of value it denotes, the projection slot it refers to
    (if any), and the generated code. [ident_label] is the pattern the whole expression is bound to (used
    for tensor labels). [proj_in_scope] tells whether a [projections] value is in scope, which enables the
    explicit-projection forms of assignments. *)
let rec translate ?ident_label ~proj_in_scope (expr : expression) :
    expr_type * projections_slot * expression =
  let loc = expr.pexp_loc in
  match expr with
  | { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
      (Tensor, Undet, [%expr NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] [%e expr]])
  | { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
      ( Tensor,
        Undet,
        [%expr NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] (Float.of_int [%e expr])] )
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
        [%e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] ->
      let axis =
        Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
      in
      ( Tensor,
        Undet,
        [%expr NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] [%e f]] )
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
        [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
      let axis =
        Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
      in
      ( Tensor,
        Undet,
        [%expr
          NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis]
            (Float.of_int [%e i])] )
  | { pexp_desc = Pexp_array _; _ } | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
      (Tensor, Undet, ndarray_op ~ident_label expr)
  | { pexp_desc = Pexp_ident { txt = Lident ("v" | "lhs"); _ }; _ } -> (Array, LHS, expr)
  | { pexp_desc = Pexp_ident { txt = Lident "g"; _ }; _ } -> (Array, LHS, expr)
  | { pexp_desc = Pexp_ident { txt = Lident "rhs1"; _ }; _ } -> (Array, RHS1, expr)
  | { pexp_desc = Pexp_ident { txt = Lident "t1"; _ }; _ } -> (Tensor, RHS1, expr)
  | { pexp_desc = Pexp_ident { txt = Lident "v1"; _ }; _ } -> (Array, RHS1, [%expr t1.Tensor.value])
  | { pexp_desc = Pexp_ident { txt = Lident "g1"; _ }; _ } ->
      ( Grad_of_tensor [%expr t1],
        RHS1,
        [%expr Option.map t1.Tensor.diff ~f:(fun d -> d.Tensor.grad)] )
  | { pexp_desc = Pexp_ident { txt = Lident "rhs2"; _ }; _ } -> (Array, RHS2, expr)
  | { pexp_desc = Pexp_ident { txt = Lident "t2"; _ }; _ } -> (Tensor, RHS2, expr)
  | { pexp_desc = Pexp_ident { txt = Lident "v2"; _ }; _ } -> (Array, RHS2, [%expr t2.Tensor.value])
  | { pexp_desc = Pexp_ident { txt = Lident "g2"; _ }; _ } ->
      ( Grad_of_tensor [%expr t2],
        RHS2,
        [%expr Option.map t2.Tensor.diff ~f:(fun d -> d.Tensor.grad)] )
  | { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
      (Tensor, Undet, expr)
  | [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
      (* FIXME: `**.` should take a tensor and require that it's a literal. *)
      (* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *)
      let _typ1, slot1, e1 = translate ~proj_in_scope expr1 in
      ( Tensor,
        slot1,
        [%expr
          NTDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] (Float.of_int [%e i])] )
  | [%expr [%e? expr1] **. [%e? expr2]] ->
      let _typ1, slot1, e1 = translate ~proj_in_scope expr1 in
      ( Tensor,
        slot1,
        [%expr NTDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] [%e expr2]] )
  | [%expr
      [%e? expr1]
      *+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec] [%e? expr2]]
    when String.contains spec_str '>' ->
      let _typ1, slot1, expr1 = translate ~proj_in_scope expr1 in
      let _typ2, slot2, expr2 = translate ~proj_in_scope expr2 in
      let slot = first_determined_slot [ slot1; slot2 ] in
      ( Tensor,
        slot,
        [%expr
          NTDSL.einsum ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e expr1] [%e expr2]] )
  | [%expr [%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
    when String.contains spec_str '>' ->
      let _typ1, slot1, expr1 = translate ~proj_in_scope expr1 in
      ( Tensor,
        slot1,
        [%expr NTDSL.einsum1 ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e expr1]] )
  | [%expr [%e? expr1].grad] -> (
      let typ1, slot1, expr1 = translate ?ident_label ~proj_in_scope expr1 in
      match typ1 with
      | Unknown | Tensor ->
          ( Grad_of_tensor expr1,
            slot1,
            [%expr Option.map [%e expr1].Tensor.diff ~f:(fun d -> d.Tensor.grad)] )
      | Code | Array | Grad_of_tensor _ ->
          ( Array,
            slot1,
            Ast_builder.Default.pexp_extension ~loc
            @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: only tensors have a gradient" ))
  | [%expr [%e? expr1].value] -> (
      let ((typ1, slot1, expr1) as result) = translate ?ident_label ~proj_in_scope expr1 in
      (* TODO: maybe this is too permissive? E.g. [t1.grad.value] is accepted. *)
      match typ1 with
      | Unknown | Tensor -> (Array, slot1, [%expr [%e expr1].Tensor.value])
      | Code -> (Array, slot1, array_of_code expr1)
      | Array | Grad_of_tensor _ -> result)
  | [%expr
      [%e? accu_op] [%e? lhs] ([%e? bin_op] [%e? rhs1] ([%e? rhs2] ~projections:[%e? projections]))] ->
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope:true lhs in
      let _, bin_op = binary_op bin_op in
      let setup_r1 = setup_array [%pat? nondiff___rhs1] @@ translate ~proj_in_scope:true rhs1 in
      let setup_r2 = setup_array [%pat? nondiff___rhs2] @@ translate ~proj_in_scope:true rhs2 in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      (* TODO: might be better to treat missing [rhs1, rhs2] as zeros rather than eliding the code. *)
      let body =
        [%expr
          Option.value ~default:Arrayjit.Assignments.Noop
          @@ Option.map3 [%e setup_l.array_opt] [%e setup_r1.array_opt] [%e setup_r2.array_opt]
               ~f:(fun lhs rhs1 rhs2 ->
                 Arrayjit.Assignments.Accum_binop
                   {
                     initialize_neutral = [%e initialize_neutral];
                     accum = [%e accu_op];
                     lhs;
                     op = [%e bin_op];
                     rhs1;
                     rhs2;
                     projections = [%e projections];
                   })]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in
      with_forward_args setups body
  | [%expr [%e? accu_op] [%e? lhs] (([%e? un_op] [%e? rhs]) ~projections:[%e? projections])]
  | [%expr [%e? accu_op] [%e? lhs] ([%e? un_op] ([%e? rhs] ~projections:[%e? projections]))] ->
      (* Handle both un_op priority levels -- where application binds tighter and less tight. *)
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope:true lhs in
      let _, un_op = unary_op un_op in
      let setup_r = setup_array [%pat? nondiff___rhs] @@ translate ~proj_in_scope:true rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      (* TODO: might be better to treat missing [rhs] as zeros rather than eliding the code. *)
      let body =
        [%expr
          Option.value ~default:Arrayjit.Assignments.Noop
          @@ Option.map2 [%e setup_l.array_opt] [%e setup_r.array_opt] ~f:(fun lhs rhs ->
                 Arrayjit.Assignments.Accum_unop
                   {
                     initialize_neutral = [%e initialize_neutral];
                     accum = [%e accu_op];
                     lhs;
                     op = [%e un_op];
                     rhs;
                     projections = [%e projections];
                   })]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r ] in
      with_forward_args setups body
  | [%expr [%e? accu_op] [%e? lhs] ([%e? rhs] ~projections:[%e? projections])] ->
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope:true lhs in
      let setup_r = setup_array [%pat? nondiff___rhs] @@ translate ~proj_in_scope:true rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let body =
        [%expr
          Option.value ~default:Arrayjit.Assignments.Noop
          @@ Option.map2 [%e setup_l.array_opt] [%e setup_r.array_opt] ~f:(fun lhs rhs ->
                 Arrayjit.Assignments.Accum_unop
                   {
                     initialize_neutral = [%e initialize_neutral];
                     accum = [%e accu_op];
                     lhs;
                     op = Arrayjit.Ops.Identity;
                     rhs;
                     projections = [%e projections];
                   })]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r ] in
      with_forward_args setups body
  | [%expr
      [%e? accu_op] [%e? lhs]
        ([%e? bin_op] [%e? rhs1]
           ([%e? rhs2] ~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))]
    ->
      let logic =
        let loc = s_loc in
        if String.equal spec "." then [%expr Shape.Pointwise_bin]
        else if String.equal spec "@" then [%expr Shape.Compose]
        else [%expr Shape.Einsum [%e logic]]
      in
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in
      let _, bin_op = binary_op bin_op in
      let setup_r1 = setup_array [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs1 in
      let setup_r2 = setup_array [%pat? nondiff___rhs2] @@ translate ~proj_in_scope rhs2 in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let t_expr, lhs_is_grad = args_for ~loc setup_l in
      let t1_expr, rhs1_is_grad = args_for ~loc setup_r1 in
      let t2_expr, rhs2_is_grad = args_for ~loc setup_r2 in
      let body =
        [%expr
          Tensor.raw_binop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr]
            ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e bin_op] ~t1:[%e t1_expr] ~rhs1_is_grad:[%e rhs1_is_grad]
            ~t2:[%e t2_expr] ~rhs2_is_grad:[%e rhs2_is_grad] ~logic:[%e logic]]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in
      with_forward_args setups body
  | [%expr
      [%e? accu_op] [%e? lhs]
        (([%e? un_op] [%e? rhs])
           ~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
  | [%expr
      [%e? accu_op] [%e? lhs]
        ([%e? un_op]
           ([%e? rhs] ~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))]
    ->
      (* Handle both un_op priority levels -- where application binds tighter and less tight. *)
      let logic =
        let loc = s_loc in
        if String.equal spec "." then [%expr Shape.Pointwise_un]
        else if String.equal spec "T" then [%expr Shape.Transpose]
        else [%expr Shape.Permute [%e logic]]
      in
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in
      let _, un_op = unary_op un_op in
      let setup_r = setup_array [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let t_expr, lhs_is_grad = args_for ~loc setup_l in
      let t1_expr, rhs_is_grad = args_for ~loc setup_r in
      let body =
        [%expr
          Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr]
            ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e un_op] ~t1:[%e t1_expr] ~rhs_is_grad:[%e rhs_is_grad]
            ~logic:[%e logic]]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r ] in
      with_forward_args setups body
  | [%expr
      [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; loc = accu_loc }; _ } as accu_op]
        [%e? lhs]
        ([%e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; loc = op_loc }; _ } as bin_op]
           [%e? rhs1]
           [%e? rhs2])]
    when is_assignment accu_ident && is_binary_op binop_ident && proj_in_scope ->
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in
      let _, bin_op = binary_op bin_op in
      let setup_r1 = setup_array [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs1 in
      let setup_r2 = setup_array [%pat? nondiff___rhs2] @@ translate ~proj_in_scope rhs2 in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let projections =
        (* Errors about a slot are reported at the location of the operand occupying that slot. *)
        let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in
        let rhs1_dims = project_p_dims "RHS1" rhs1.pexp_loc setup_r1.slot in
        let rhs2_dims = project_p_dims "RHS2" rhs2.pexp_loc setup_r2.slot in
        let project_lhs = project_p_slot "LHS" lhs.pexp_loc setup_l.slot in
        let project_rhs1 = project_p_slot "RHS1" rhs1.pexp_loc setup_r1.slot in
        let project_rhs2 = project_p_slot "RHS2" rhs2.pexp_loc setup_r2.slot in
        [%expr
          lazy
            (let p = Lazy.force projections in
             Arrayjit.Indexing.
               {
                 product_space = p.product_space;
                 product_iterators = p.product_iterators;
                 lhs_dims = [%e lhs_dims];
                 rhs_dims = [| [%e rhs1_dims]; [%e rhs2_dims] |];
                 project_lhs = [%e project_lhs];
                 project_rhs = [| [%e project_rhs1]; [%e project_rhs2] |];
                 debug_info =
                   {
                     p.debug_info with
                     trace =
                       ( "ppx_cd " ^ [%e string_expr ~loc:accu_loc accu_ident] ^ " "
                         ^ [%e string_expr ~loc:op_loc binop_ident],
                         Arrayjit.Indexing.unique_debug_id () )
                       :: p.debug_info.trace;
                   };
               })]
      in
      let body =
        [%expr
          Option.value ~default:Arrayjit.Assignments.Noop
          @@ Option.map3 [%e setup_l.array_opt] [%e setup_r1.array_opt] [%e setup_r2.array_opt]
               ~f:(fun lhs rhs1 rhs2 ->
                 Arrayjit.Assignments.Accum_binop
                   {
                     initialize_neutral = [%e initialize_neutral];
                     accum = [%e accu_op];
                     lhs;
                     op = [%e bin_op];
                     rhs1;
                     rhs2;
                     projections = [%e projections];
                   })]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in
      with_forward_args setups body
  | [%expr
      [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; loc = accu_loc }; _ } as accu_op]
        [%e? lhs]
        ([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; loc = op_loc }; _ } as un_op] [%e? rhs])]
    when is_assignment accu_ident && is_unary_op unop_ident && proj_in_scope ->
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in
      let _, un_op = unary_op un_op in
      let setup_r1 = setup_array [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let projections =
        let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in
        let rhs1_dims = project_p_dims "RHS1" rhs.pexp_loc setup_r1.slot in
        let project_lhs = project_p_slot "LHS" lhs.pexp_loc setup_l.slot in
        let project_rhs1 = project_p_slot "RHS1" rhs.pexp_loc setup_r1.slot in
        [%expr
          lazy
            (let p = Lazy.force projections in
             Arrayjit.Indexing.
               {
                 product_space = p.product_space;
                 product_iterators = p.product_iterators;
                 lhs_dims = [%e lhs_dims];
                 rhs_dims = [| [%e rhs1_dims] |];
                 project_lhs = [%e project_lhs];
                 project_rhs = [| [%e project_rhs1] |];
                 debug_info =
                   {
                     p.debug_info with
                     trace =
                       ( "ppx_cd " ^ [%e string_expr ~loc:accu_loc accu_ident] ^ " "
                         ^ [%e string_expr ~loc:op_loc unop_ident],
                         Arrayjit.Indexing.unique_debug_id () )
                       :: p.debug_info.trace;
                   };
               })]
      in
      let body =
        [%expr
          Option.value ~default:Arrayjit.Assignments.Noop
          @@ Option.map2 [%e setup_l.array_opt] [%e setup_r1.array_opt] ~f:(fun lhs rhs ->
                 Arrayjit.Assignments.Accum_unop
                   {
                     initialize_neutral = [%e initialize_neutral];
                     accum = [%e accu_op];
                     lhs;
                     op = [%e un_op];
                     rhs;
                     projections = [%e projections];
                   })]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1 ] in
      with_forward_args setups body
  | [%expr
      [%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; loc = op_loc }; _ } as accu_op] [%e? lhs] [%e? rhs]]
    when is_assignment op_ident && proj_in_scope ->
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in
      let setup_r1 = setup_array [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let projections =
        let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in
        let rhs1_dims = project_p_dims "RHS1" rhs.pexp_loc setup_r1.slot in
        let project_lhs = project_p_slot "LHS" lhs.pexp_loc setup_l.slot in
        let project_rhs1 = project_p_slot "RHS1" rhs.pexp_loc setup_r1.slot in
        [%expr
          lazy
            (let p = Lazy.force projections in
             Arrayjit.Indexing.
               {
                 product_space = p.product_space;
                 product_iterators = p.product_iterators;
                 lhs_dims = [%e lhs_dims];
                 rhs_dims = [| [%e rhs1_dims] |];
                 project_lhs = [%e project_lhs];
                 project_rhs = [| [%e project_rhs1] |];
                 debug_info =
                   {
                     p.debug_info with
                     trace =
                       ( "ppx_cd " ^ [%e string_expr ~loc:op_loc op_ident],
                         Arrayjit.Indexing.unique_debug_id () )
                       :: p.debug_info.trace;
                   };
               })]
      in
      let body =
        [%expr
          Option.value ~default:Arrayjit.Assignments.Noop
          @@ Option.map2 [%e setup_l.array_opt] [%e setup_r1.array_opt] ~f:(fun lhs rhs ->
                 Arrayjit.Assignments.Accum_unop
                   {
                     initialize_neutral = [%e initialize_neutral];
                     accum = [%e accu_op];
                     lhs;
                     op = Arrayjit.Ops.Identity;
                     rhs;
                     projections = [%e projections];
                   })]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1 ] in
      with_forward_args setups body
  | [%expr
      [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
        [%e? lhs]
        ([%e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op] [%e? rhs1] [%e? rhs2])]
    when is_assignment accu_ident && is_binary_op binop_ident ->
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in
      let logic, bin_op = binary_op bin_op in
      let setup_r1 = setup_array [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs1 in
      let setup_r2 = setup_array [%pat? nondiff___rhs2] @@ translate ~proj_in_scope rhs2 in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let t_expr, lhs_is_grad = args_for ~loc setup_l in
      let t1_expr, rhs1_is_grad = args_for ~loc setup_r1 in
      let t2_expr, rhs2_is_grad = args_for ~loc setup_r2 in
      let body =
        [%expr
          Tensor.raw_binop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr]
            ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e bin_op] ~t1:[%e t1_expr] ~rhs1_is_grad:[%e rhs1_is_grad]
            ~t2:[%e t2_expr] ~rhs2_is_grad:[%e rhs2_is_grad] ~logic:[%e logic]]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in
      with_forward_args setups body
  | [%expr
      [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
        [%e? lhs]
        ([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } as un_op] [%e? rhs])]
    when is_assignment accu_ident && is_unary_op unop_ident ->
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in
      let logic, un_op = unary_op un_op in
      let setup_r = setup_array [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let t_expr, lhs_is_grad = args_for ~loc setup_l in
      let t1_expr, rhs_is_grad = args_for ~loc setup_r in
      let body =
        [%expr
          Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr]
            ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e un_op] ~t1:[%e t1_expr] ~rhs_is_grad:[%e rhs_is_grad]
            ~logic:[%e logic]]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r ] in
      with_forward_args setups body
  | [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op] [%e? lhs] [%e? rhs]]
    when is_assignment op_ident ->
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in
      let setup_r = setup_array [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let t_expr, lhs_is_grad = args_for ~loc setup_l in
      let t1_expr, rhs_is_grad = args_for ~loc setup_r in
      let body =
        [%expr
          Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr]
            ~lhs_is_grad:[%e lhs_is_grad] ~op:Arrayjit.Ops.Identity ~t1:[%e t1_expr]
            ~rhs_is_grad:[%e rhs_is_grad] ~logic:Shape.Pointwise_un]
      in
      let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r ] in
      with_forward_args setups body
  | [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
      let typ1, slot1, expr1 = translate ?ident_label ~proj_in_scope expr1 in
      let _typ2, slot2, expr2 = translate ~proj_in_scope expr2 in
      let _typ3, slot3, expr3 = translate ~proj_in_scope expr3 in
      let slot = first_determined_slot [ slot1; slot2; slot3 ] in
      (typ1, slot, [%expr [%e expr1] [%e expr2] [%e expr3]])
  | [%expr [%e? expr1] [%e? expr2]] ->
      let typ1, slot1, expr1 = translate ?ident_label ~proj_in_scope expr1 in
      let _typ2, slot2, expr2 = translate ~proj_in_scope expr2 in
      let slot = first_determined_slot [ slot1; slot2 ] in
      (typ1, slot, [%expr [%e expr1] [%e expr2]])
  | { pexp_desc = Pexp_fun ((arg_label : arg_label), arg, opt_val, body); _ } as expr ->
      (* A [~projections] / [?projections] parameter brings projections into scope for the body. *)
      let proj_in_scope =
        proj_in_scope
        ||
        match arg_label with
        | (Labelled s | Optional s) when String.equal s "projections" -> true
        | _ -> false
      in
      let typ, slot, body = translate ?ident_label ~proj_in_scope body in
      (typ, slot, { expr with pexp_desc = Pexp_fun (arg_label, arg, opt_val, body) })
  | [%expr while [%e? _test_expr] do [%e? _body] done] ->
      ( Code,
        Nonslot,
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: while: low-level code embeddings not supported yet"
      )
  | [%expr for [%p? _pat] = [%e? _init] to [%e? _final] do [%e? _body_expr] done] ->
      ( Code,
        Nonslot,
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: for-to: low-level code embeddings not supported yet"
      )
  | [%expr for [%p? _pat] = [%e? _init] downto [%e? _final] do [%e? _body_expr] done] ->
      ( Code,
        Nonslot,
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "ppx_ocannl %%cd: for-downto: low-level code embeddings not supported yet" )
  | [%expr
      [%e? expr1];
      [%e? expr2]] ->
      let _typ1, _slot1, expr1 = translate ~proj_in_scope expr1 in
      let _typ2, _slot2, expr2 = translate ?ident_label ~proj_in_scope expr2 in
      (Code, Nonslot, [%expr Arrayjit.Assignments.Seq ([%e expr1], [%e expr2])])
  | [%expr if [%e? expr1] then [%e? expr2] else [%e? expr3]] ->
      let typ2, slot2, expr2 = translate ?ident_label ~proj_in_scope expr2 in
      let typ3, slot3, expr3 = translate ?ident_label ~proj_in_scope expr3 in
      let typ = if is_unknown typ2 then typ3 else typ2 in
      let slot = first_determined_slot [ slot2; slot3 ] in
      (typ, slot, [%expr if [%e expr1] then [%e expr2] else [%e expr3]])
  | [%expr if [%e? expr1] then [%e? expr2]] ->
      let _typ2, _slot2, expr2 = translate ?ident_label ~proj_in_scope expr2 in
      (Code, Nonslot, [%expr if [%e expr1] then [%e expr2] else Arrayjit.Assignments.Noop])
  | { pexp_desc = Pexp_match (expr1, cases); _ } ->
      let typs, slots, cases =
        List.unzip3
        @@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
               let typ, slot, pc_rhs = translate ?ident_label ~proj_in_scope pc_rhs in
               (typ, slot, { c with pc_rhs }))
      in
      let typ = Option.value ~default:Unknown @@ List.find typs ~f:(Fn.non is_unknown) in
      let slot = first_determined_slot slots in
      (typ, slot, { expr with pexp_desc = Pexp_match (expr1, cases) })
  | { pexp_desc = Pexp_let (_recflag, _bindings, _body); _ } ->
      (* TODO(80): to properly support local bindings, we need to collect the type environment. *)
      ( Unknown,
        Undet,
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: let-in: local let-bindings not implemented yet" )
  (* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=translate binding.pvb_expr})
     in {expr with pexp_desc=Pexp_let (recflag, bindings, translate body)} *)
  | { pexp_desc = Pexp_open (decl, body); _ } ->
      let typ, slot, body = translate ?ident_label ~proj_in_scope body in
      (typ, slot, { expr with pexp_desc = Pexp_open (decl, body) })
  | { pexp_desc = Pexp_letmodule (name, module_expr, body); _ } ->
      let typ, slot, body = translate ?ident_label ~proj_in_scope body in
      (typ, slot, { expr with pexp_desc = Pexp_letmodule (name, module_expr, body) })
  | { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
      (* NOTE(review): unreachable -- an identical pattern with the same guard appears earlier in this
         match and always wins. Decide which behavior is intended (bare operator vs. labeled operator)
         and delete the other case. *)
      (Unknown, Undet, [%expr [%e expr] ~label:[%e opt_pat2string_list ~loc ident_label]])
  | _ -> (Unknown, Undet, expr)

(** Entry point for translating a [%cd] expression bound to [ident_label] (if any). When the binding
    pattern is [_], the generated code is wrapped in [Tensor.with_unchanged_roots]. *)
let translate ?ident_label (expr : expression) : expression =
  let _, _, v = translate ?ident_label ~proj_in_scope:false expr in
  match ident_label with
  | Some [%pat? _] ->
      let loc = v.pexp_loc in
      [%expr Tensor.with_unchanged_roots ~f:(fun () -> [%e v])]
  | _ -> v

type extension = Cd | Dt | Rs [@@deriving equal, variants]

(** Expression-level [%cd] expander. For a [let ... in] payload only the bound expressions are translated
    (each with its binding pattern as the label source); any other payload is translated as a whole. The
    results are evaluated with [NTDSL.O] opened. *)
let expr_expander ~loc ~path:_ payload =
  match payload with
  | { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
      (* We are at the %cd annotation level: do not translate the body. *)
      let bindings =
        List.map bindings ~f:(fun vb ->
            let v = translate ~ident_label:vb.pvb_pat vb.pvb_expr in
            { vb with pvb_expr = [%expr let open! NTDSL.O in [%e v]] })
      in
      { payload with pexp_desc = Pexp_let (recflag, bindings, body) }
  | expr ->
      let expr = translate expr in
      [%expr let open! NTDSL.O in [%e expr]]

(** Returns a single structure item unchanged; otherwise wraps the items in [include struct ... end]. *)
let flatten_str ~loc ~path:_ items =
  match items with
  | [ x ] -> x
  | _ ->
      Ast_helper.Str.include_
        { pincl_mod = Ast_helper.Mod.structure items; pincl_loc = loc; pincl_attributes = [] }

(** Translates top-level evaluations and value bindings; other structure items are returned unchanged. *)
let translate_str ({ pstr_desc; _ } as str) =
  match pstr_desc with
  | Pstr_eval (expr, attrs) ->
      let expr = translate expr in
      let loc = expr.pexp_loc in
      { str with pstr_desc = Pstr_eval ([%expr let open! NTDSL.O in [%e expr]], attrs) }
  | Pstr_value (recf, bindings) ->
      let f vb =
        let loc = vb.pvb_loc in
        let v = translate ~ident_label:vb.pvb_pat vb.pvb_expr in
        { vb with pvb_expr = [%expr let open! NTDSL.O in [%e v]] }
      in
      { str with pstr_desc = Pstr_value (recf, List.map bindings ~f) }
  | _ -> str

(** Structure-level [%cd] expander: translates each item and flattens the result into one item. *)
let str_expander ~loc ~path (payload : structure_item list) =
  flatten_str ~loc ~path @@ List.map payload ~f:translate_str