(* NOTE(review): extraction artifact — the two lines that stood here were the
   original file's line-number gutter (the integers 1..1560 run together with
   all separators lost), fused into the text by the extractor. They carried no
   source code. The tail of the same digit run still prefixes the next line,
   immediately before the real start of the module (`open Base`); it could not
   be removed here because it shares a physical line with code. *)
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560openBaseopenPpxlibopenPpx_arrayjit.Ppx_helperopenPpx_sharedmoduleA=Ppxlib_ast.Ast_helpertypeexpr_type=|Codeof{is_commented:bool}|Array|Value_of_tensorofexpression|Grad_of_tensorofexpression|Tensor|Unknown|Merge_valueofexpression|Merge_gradofexpression|No_grad_tensor_introof{name:string;name_expr:expression}|Functionletis_unknown=functionUnknown->true|_->falsetypeprojections_slot=LHS|RHS1|RHS2|RHS3|Scalar|Nonslot|Undet[@@derivingequal,sexp]typeresult={vbs:value_bindingMap.M(String).t;(** [vbs] are the bindings introduced by inline tensor declarations (aka. punning). These
bindings are discharged with the whole [%cd] extension scope in scope. *)typ:expr_type;slot:projections_slot;expr:expression;(** Depending on {!field-typ}, of type:
- if [Code]: [Assignments.comp];
- if [Array | Merge_value | Value_of_tensor]: [Tnode.t];
- if [Merge_grad | Grad_of_tensor]: [Tnode.t option];
- if [Tensor | Unknown | No_grad_tensor_intro]: [Tensor.t]. *)array_opt_of_code:expressionoption;(** Of type: [Tnode.t]. It keeps track of which tensor node to use when [typ] is [Code] but
the result is used in an array context. *)}typearray_setup={vb:value_bindingoption;(** This binding is only generated for a tensor expression that is not an identifier, since
recomputing the expression (when copied) would generate a fresh tensor. It is discharged
when an assignment is built. *)slot:projections_slot;filler_typ:expr_type;fwd_code_or_noop:expressionoption;(** Of type: [Assignments.comp]. *)array_opt:expression;(** Of type: if [slot = LHS] then [Tnode.t] else [Assignments.buffer]. *)tensor:expressionoption;(** Of type: [Tensor.t]. *)pun_hint_tnode:(expression*bool)option;(** Of type: [string list]. The tnode, if any, whose label the relevant punned no-gradient
tensor should incorporate in its label. The bool denotes whether this is a preferred (high
quality) guess. *)}letmake_vb~loc~name~name_expr~hint_label=letpat=A.Pat.var~loc{loc=name_expr.pexp_loc;txt=name}inletv=matchhint_labelwith|None->[%exprNTDSL.term~label:[[%ename_expr]]()]|Somehint_label->[%exprNTDSL.term~label:([%ename_expr]::[%ehint_label])()]inletvb=A.Vb.mk~locpatvinvb(** The expression argument is of type: [Assignments.t]. *)letassignment~punned~lhs~rhses?body_for_lhs?raw_body()=letsetups=lhs::rhsesinletbody,is_for_lhs=match(body_for_lhs,raw_body)with|Somebody_for_lhs,None->letloc=body_for_lhs.pexp_locin([%exprOption.value~default:Ir.Assignments.Noop[%ebody_for_lhs]],true)|None,Someraw_body->(raw_body,false)|_->assertfalseinletloc=body.pexp_locinletforward_args=List.filter_mapsetups~f:(fun{fwd_code_or_noop;_}->fwd_code_or_noop)inletvbs,body=matchlhs.filler_typwith|No_grad_tensor_intro{name;name_expr}->(letgood_hints,bad_hints=List.partition_tf~f:snd@@List.filter_maprhses~f:(funsup->sup.pun_hint_tnode)inlethint_data=Option.first_some(List.hdgood_hints)(List.hdbad_hints)inlethint_label=Option.map~f:fsthint_datainletvbs=Map.singleton(moduleString)name@@make_vb~loc~name~name_expr~hint_labelinmatchhint_datawith|None->(vbs,body)|Somedata->(matchHashtbl.addpunned~key:name~datawith|`Ok->(vbs,body)|`Duplicate->(no_vbs,Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: duplicate inline declaration of no-gradient tensor %s"name)))|_->(no_vbs,body)inletbody=(* Note: this is not a binding from an inline declaration, it's a temporary binding. 
*)ifOption.is_somelhs.vbthenAst_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: the assigned-to position cannot be an expression building a new tensor"elsebodyinlettensor_vbs=List.filter_maprhses~f:(funrhs->rhs.vb)inletbody=[%expr{Ir.Assignments.asgns=[%ebody];embedded_nodes=Base.Set.empty(moduleIr.Tnode)}]inletcomps=List.fold(body::List.revforward_args)~init:[%expr[]]~f:(funxsx->[%expr[%ex]::[%exs]])inletbody=[%exprIr.Assignments.sequence[%ecomps]]inletbody=ifList.is_emptytensor_vbsthenbodyelseA.Exp.let_~locNonrecursivetensor_vbsbodyinletexpr=ifis_for_lhsthen[%exprOption.value~default:Ir.Assignments.{asgns=Noop;embedded_nodes=Base.Set.empty(moduleIr.Tnode)}@@Option.map[%elhs.array_opt]~f:(funlhs->[%ebody])]elsebodyin{vbs;typ=Code{is_commented=false};slot=Nonslot;expr;array_opt_of_code=Somelhs.array_opt;}letproject_p_slotdebuglocslot=matchslotwith|LHS->[%exprp.project_lhs]|RHS1->[%exprp.project_rhs.(0)]|RHS2->[%exprp.project_rhs.(1)]|RHS3->[%exprp.project_rhs.(2)]|Scalar->[%expr[|Ir.Indexing.Fixed_idx0|]]|Nonslot->Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s"debug|Undet->Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: insufficient slot filler information at %s %s"debug"(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"letproject_p_dimsdebuglocslot=matchslotwith|LHS->[%exprp.lhs_dims]|RHS1->[%exprp.rhs_dims.(0)]|RHS2->[%exprp.rhs_dims.(1)]|RHS3->[%exprp.rhs_dims.(2)]|Scalar->[%expr[|1|]]|Nonslot->Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s"debug|Undet->Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: insufficient slot filler information at %s %s"debug"(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, 
rhs2)"letguess_pun_hint~no_filler_label~punned~bad_pun_hintsfiller_typfiller=letloc=filler.pexp_locinlethint=[%expr[%efiller].Ir.Tnode.label]inmatch(filler_typ,filler,no_filler_label)with|(Code_|Function),_,_->None|_,{pexp_desc=Pexp_ident{txt=Lidentname;_};_},_whenSet.membad_pun_hintsname->None|Array,_,false->Some(hint,false)|(Tensor|Unknown),{pexp_desc=Pexp_ident{txt=Lidentname;_};_},_whenHashtbl.mempunnedname->Hashtbl.findpunnedname|(Tensor|Unknown),{pexp_desc=Pexp_ident_;_},_->Some(hint,true)|(Tensor|Unknown),_,false->Some(hint,false)|((Value_of_tensor{pexp_desc=Pexp_ident{txt=Lidentname;_};_}|Grad_of_tensor{pexp_desc=Pexp_ident{txt=Lidentname;_};_}|Merge_value{pexp_desc=Pexp_ident{txt=Lidentname;_};_}|Merge_grad{pexp_desc=Pexp_ident{txt=Lidentname;_};_}),_,_)whenSet.membad_pun_hintsname->None|((Value_of_tensor{pexp_desc=Pexp_ident{txt=Lidentname;_};_}|Grad_of_tensor{pexp_desc=Pexp_ident{txt=Lidentname;_};_}|Merge_value{pexp_desc=Pexp_ident{txt=Lidentname;_};_}|Merge_grad{pexp_desc=Pexp_ident{txt=Lidentname;_};_}),_,_)whenHashtbl.mempunnedname->Hashtbl.findpunnedname|(Value_of_tensort|Grad_of_tensort|Merge_valuet|Merge_gradt),_,false->(lethint=[%expr[%et].Tensor.value.Ir.Tnode.label]inmatchtwith{pexp_desc=Pexp_ident_;_}->Some(hint,true)|_->Some(hint,false))|No_grad_tensor_intro{name;_},_,_->Hashtbl.findpunnedname|_,_,true->Noneletempty_tns~loc=[%exprBase.Set.empty(moduleIr.Tnode)]letempty_comp~loc=[%expr{Ir.Assignments.asgns=Ir.Assignments.Noop;embedded_nodes=[%eempty_tns~loc]}]letsetup_array~punned~bad_pun_hints~is_lhs{typ=filler_typ;slot;expr=filler;vbs;array_opt_of_code}=letloc=filler.pexp_locinletopt_buffertn=ifis_lhsthen[%exprSome[%etn]]else[%exprSome(Ir.Assignments.Node[%etn])]inletbufferopt_tn=ifis_lhsthenopt_tnelse[%exprOption.map[%eopt_tn]~f:(funtn->Ir.Assignments.Nodetn)]inletpun_hint_tnodeno_filler_label=guess_pun_hint~no_filler_label~punned~bad_pun_hintsfiller_typfillerinletdefault_setupno_filler_label={vb=None;fwd_code_or_noop=None;filler_typ;slot;ar
ray_opt=opt_buffer[%expr[%efiller].Tensor.value];tensor=None;pun_hint_tnode=pun_hint_tnodeno_filler_label;}inmatch(Map.is_emptyvbs,filler_typ)with|(false,_|_,No_grad_tensor_intro_)whennotis_lhs->{(default_setupfalse)witharray_opt=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: inline tensor declarations are not allowed in assignment \
right-hand side, to prevent over-use in locations with less label information";}|_,(Tensor|Unknown)whenmatchfillerwith{pexp_desc=Pexp_ident_;_}->true|_->false->lett=fillerinletfwd_code_or_noop=Some[%exprifTensor.is_fwd_root[%et]then(Tensor.remove_fwd_root[%et];[%et].Tensor.forward)else[%eempty_comp~loc]]in{(default_setupfalse)withfwd_code_or_noop;tensor=Somet}|_,Value_of_tensor({pexp_desc=Pexp_ident_;_}ast)->letfwd_code_or_noop=Some[%exprifTensor.is_fwd_root[%et]then(Tensor.remove_fwd_root[%et];[%et].Tensor.forward)else[%eempty_comp~loc]]in{(default_setupfalse)withfwd_code_or_noop;array_opt=opt_buffer[%expr[%et].Tensor.value];tensor=Somet;}|_,Value_of_tensort->{(default_setupfalse)witharray_opt=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: the <tensor>.value notation is only supported when <tensor> is an \
identifier";tensor=Somet;}|_,(Tensor|Unknown)->(* Need to bind the expression computing the tensor so we don't recompute it. *)letv=matchslotwith|LHS->[%pat?nondiff__lhs]|RHS1->[%pat?nondiff__rhs1]|RHS2->[%pat?nondiff__rhs2]|RHS3->[%pat?nondiff__rhs3]|Scalar|Nonslot|Undet->[%pat?nondiff__tensor]inlett=pat2exprvinletvb=Some(A.Vb.mk~locvfiller)inletfwd_code_or_noop=Some[%exprifTensor.is_fwd_root[%et]then(Tensor.remove_fwd_root[%et];[%et].Tensor.forward)else[%eempty_comp~loc]]in{(default_setuptrue)withvb;fwd_code_or_noop;array_opt=opt_buffer[%expr[%et].Tensor.value];tensor=Somet;}|_,No_grad_tensor_intro_->(* Inline tensors are guaranteed to be leaf tensors, so they don't have forward code, but they
are embedded. *)letfwd_code_or_noop=Some[%expr{Ir.Assignments.asgns=Ir.Assignments.Noop;embedded_nodes=Base.Set.singleton(moduleIr.Tnode)[%efiller].Tensor.value;}]in{(default_setupfalse)withfwd_code_or_noop;tensor=Somefiller}|_,Function->{(default_setupfalse)withfwd_code_or_noop=Somefiller;array_opt=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: a syntactic function in place of an array is not supported";}|_,Code_whenOption.is_nonearray_opt_of_code->{(default_setupfalse)withfwd_code_or_noop=Somefiller;array_opt=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: could not determine a lead array of provided code";}|_,Code_->{(default_setupfalse)withfwd_code_or_noop=Somefiller;array_opt=buffer(Option.value_exnarray_opt_of_code);}|_,Array->{(default_setupfalse)witharray_opt=opt_bufferfiller}|_,Grad_of_tensor({pexp_desc=Pexp_ident_;_}ast)->{(default_setupfalse)witharray_opt=bufferfiller;tensor=Somet}|_,Grad_of_tensort->{(default_setupfalse)witharray_opt=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: the <tensor>.grad notation is only supported when <tensor> is an \
identifier";tensor=Somet;}|_,(Merge_value_|Merge_grad_)whenis_lhs->{(default_setupfalse)witharray_opt=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: merge buffers cannot be assigned to";}|_,Merge_valuet->{(default_setupfalse)witharray_opt=[%exprSome(Merge_buffer[%efiller])];tensor=Somet;}|_,Merge_gradt->{(default_setupfalse)witharray_opt=[%exprOption.map[%efiller]~f:(funtn->Ir.Assignments.Merge_buffertn)];tensor=Somet;}letargs_for~loc=function|{filler_typ=Merge_grad_;tensor=Somet;_}->(t,[%exprtrue],[%exprtrue])|{filler_typ=Grad_of_tensor_;tensor=Somet;_}->(t,[%exprtrue],[%exprfalse])|{filler_typ=Merge_value_;tensor=Somet;_}->(t,[%exprfalse],[%exprtrue])|{filler_typ=_;tensor=Somet;_}->(t,[%exprfalse],[%exprfalse])|_->(Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: cannot use `~logic` (infer shapes) for arrays, use tensors or \
`.value` or `.grad` notation",[%exprfalse],[%exprfalse])letreduce_res_vbsrs=reduce_vbss@@List.maprs~f:(funr->r.vbs)letcompare_slotsab=match(a,b)with|Nonslot,_->1|_,Nonslot->-1|Undet,_->1|_,Undet->-1|Scalar,_->1|_,Scalar->-1|_->0(** Helper function to handle cases (for Pexp_match, Pexp_function with cases, etc.) *)lethandle_cases~bad_pun_hints~proj_in_scopetranslcases=letfields,transformed_cases=List.unzip@@List.mapcases~f:(fun({pc_rhs;_}asc)->letres=transl~bad_pun_hints~proj_in_scopepc_rhsin((res.vbs,res.typ,res.slot),{cwithpc_rhs=res.expr}))inletvbss,typs,slots=List.unzip3fieldsin(* TODO: make the inference of typ and slot more strict by detecting mismatches. *)lettyp=Option.value~default:Unknown@@List.findtyps~f:(Fn.nonis_unknown)inletslot=List.hd_exn@@List.sortslots~compare:compare_slotsinletloc=(List.hd_exncases).pc_lhs.ppat_locin(transformed_cases,{vbs=reduce_vbssvbss;typ;slot;expr=[%expr()];(* This will be replaced by the caller *)array_opt_of_code=None;})lettranslate?ident_label(expr:expression):result=letpunned=Hashtbl.create(moduleString)inletrectransl~bad_pun_hints~proj_in_scope(expr:expression):result=letloc=expr.pexp_locinletdefault_result={vbs=no_vbs;typ=Tensor;slot=Undet;expr;array_opt_of_code=None}inletloop=transl~bad_pun_hintsinletassignment_opaccu_op=loc|>Option.value_or_thunk(Hashtbl.findassignment_opsaccu_op)~default:(fun()_loc->(false,Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected an assignment operator, one of: %s %s""=+ (Add), =- (Sub), =* (Mul),=/ (Div), =** (ToPowOf), =?/ (Relu_gate), =?^ \
(Satur01_gate), =|| (Or), =&& (And), =@^ (Max), =@- (Min), =^^^^ \
(threefry4x32), =: (Arg2), =:+, =:-,"" =:*, =:/, =:**, =:?/, =:?^, =:||, =:&&, =:@^, =:@-, =:^^^^ (same with \
initializing the tensor to the neutral value before the start of the \
calculation)"))inletunary_opun_op=loc|>Option.value_or_thunk(Hashtbl.findunary_opsun_op)~default:(fun()loc->([%exprShape.Pointwise_un],Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected a unary operator, one of: %s""id, relu, sat01, exp, log, exp2, log2, sin, cos, sqrt, recip, recip_sqrt, \
neg, tanh"))inletvec_unary_opvec_un_op=loc|>Option.value_or_thunk(Hashtbl.findvec_unary_opsvec_un_op)~default:(fun()loc->([%exprShape.Uint4x32_to_prec_uniform],Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected a vector unary operator, one of: \
uint4x32_to_prec_uniform; found: %s"vec_un_op))inletbinary_opbin_op=loc|>Option.value_or_thunk(Hashtbl.findbinary_opsbin_op)~default:(fun()_loc->([%exprShape.Pointwise_bin],Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected a binary operator, one of: %s""+ (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -?^ \
(Satur01_gate), -/> (Arg2), < (Cmplt), = (Cmpeq), <> (Cmpne), || (Or), && \
(And), % (Mod), @^(Max), @- (Min), ^^^^ (threefry4x32)"))inletternary_optern_op=loc|>Option.value_or_thunk(Hashtbl.findternary_opstern_op)~default:(fun()_loc->([%exprShape.Pointwise_tern],Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected a ternary operator, one of: where, fma"))in(* TODO: collapse these (code reuse) *)letprocess_assign_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3?projections~proj_in_scope()=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scope:truelhsinlet_,tern_op=ternary_optern_opinletsetup_r1=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs1inletsetup_r2=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs2inletsetup_r3=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs3inletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inletprojections_lazy,projections_debug=matchprojectionswith|Someprjs->([%expr[%eprjs].Tensor.projections],[%expr[%eprjs].Tensor.projections_debug])|None->letlhs_dims=project_p_dims"LHS"lhs.pexp_locsetup_l.slotinletrhs1_dims=project_p_dims"RHS1"lhs.pexp_locsetup_r1.slotinletrhs2_dims=project_p_dims"RHS2"lhs.pexp_locsetup_r2.slotinletrhs3_dims=project_p_dims"RHS3"lhs.pexp_locsetup_r3.slotinletproject_lhs=project_p_slot"LHS"lhs.pexp_locsetup_l.slotinletproject_rhs1=project_p_slot"RHS1"rhs1.pexp_locsetup_r1.slotinletproject_rhs2=project_p_slot"RHS2"rhs2.pexp_locsetup_r2.slotinletproject_rhs3=project_p_slot"RHS3"rhs3.pexp_locsetup_r3.slotinletproj_lazy=[%exprlazy(letp=Lazy.forceprojections.Tensor.projectionsinIr.Indexing.{product_space=p.product_space;product_iterators=p.product_iterators;lhs_dims=[%elhs_dims];rhs_dims=[|[%erhs1_dims];[%erhs2_dims];[%erhs3_dims]|];project_lhs=[%eproject_lhs];project_rhs=[|[%eproject_rhs1];[%eproject_rhs2];[%eproject_rhs3]|];debug_info={p.debug_infowithtrace=("ppx_cd "^[%eexpr2string_or_emptyaccu_op]^" 
"^[%eexpr2string_or_emptytern_op],Ir.Indexing.unique_debug_id())::p.debug_info.trace;};})]in(proj_lazy,[%exprprojections.Tensor.projections_debug])in(* FIXME: might be better to treat missing [rhs1, rhs2, rhs3] as zeros or errors rather than
eliding the code, only lhs should decide whether to elide the code. *)letbody_for_lhs=[%exprOption.map3[%esetup_r1.array_opt][%esetup_r2.array_opt][%esetup_r3.array_opt]~f:(funrhs1rhs2rhs3->Ir.Assignments.Accum_op{initialize_neutral=[%einitialize_neutral];accum=[%eaccu_op];lhs;rhs=Ternop{op=[%etern_op];rhs1;rhs2;rhs3};projections=[%eprojections_lazy];projections_debug=[%eprojections_debug];})]inassignment~punned~lhs:setup_l~rhses:[setup_r1;setup_r2;setup_r3]~body_for_lhs()inletprocess_assign_binop~accu_op~lhs~bin_op~rhs1~rhs2?projections~proj_in_scope()=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scope:truelhsinlet_,bin_op=binary_opbin_opinletsetup_r1=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs1inletsetup_r2=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs2inletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inletprojections_lazy,projections_debug=matchprojectionswith|Someprjs->([%expr[%eprjs].Tensor.projections],[%expr[%eprjs].Tensor.projections_debug])|None->letlhs_dims=project_p_dims"LHS"lhs.pexp_locsetup_l.slotinletrhs1_dims=project_p_dims"RHS1"lhs.pexp_locsetup_r1.slotinletrhs2_dims=project_p_dims"RHS2"lhs.pexp_locsetup_r2.slotinletproject_lhs=project_p_slot"LHS"lhs.pexp_locsetup_l.slotinletproject_rhs1=project_p_slot"RHS1"rhs1.pexp_locsetup_r1.slotinletproject_rhs2=project_p_slot"RHS2"rhs2.pexp_locsetup_r2.slotinletproj_lazy=[%exprlazy(letp=Lazy.forceprojections.Tensor.projectionsinIr.Indexing.{product_space=p.product_space;product_iterators=p.product_iterators;lhs_dims=[%elhs_dims];rhs_dims=[|[%erhs1_dims];[%erhs2_dims]|];project_lhs=[%eproject_lhs];project_rhs=[|[%eproject_rhs1];[%eproject_rhs2]|];debug_info={p.debug_infowithtrace=("ppx_cd "^[%eexpr2string_or_emptyaccu_op]^" 
"^[%eexpr2string_or_emptybin_op],Ir.Indexing.unique_debug_id())::p.debug_info.trace;};})]in(proj_lazy,[%exprprojections.Tensor.projections_debug])in(* FIXME: might be better to treat missing [rhs1, rhs2] as zeros or errors rather than eliding
the code, only lhs should decide whether to elide the code. *)letbody_for_lhs=[%exprOption.map2[%esetup_r1.array_opt][%esetup_r2.array_opt]~f:(funrhs1rhs2->Ir.Assignments.Accum_op{initialize_neutral=[%einitialize_neutral];accum=[%eaccu_op];lhs;rhs=Binop{op=[%ebin_op];rhs1;rhs2};projections=[%eprojections_lazy];projections_debug=[%eprojections_debug];})]inassignment~punned~lhs:setup_l~rhses:[setup_r1;setup_r2]~body_for_lhs()inletprocess_assign_unop~accu_op~lhs~un_op~rhs?projections~proj_in_scope()=letinitialize_neutral,accum=assignment_opaccu_opinlet_,op=unary_opun_opin(* FIXME: I think this ignores the slot information here! Just assuming [projections] is
as-should-be, but that's not consistent with omitting the projections arg (assuming it
comes from the context). *)letsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhsinletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inletprojections_lazy,projections_debug=matchprojectionswith|Someprjs->([%expr[%eprjs].Tensor.projections],[%expr[%eprjs].Tensor.projections_debug])|None->letlhs_dims=project_p_dims"LHS"lhs.pexp_locsetup_l.slotinletrhs1_dims=project_p_dims"RHS1"lhs.pexp_locsetup_r.slotinletproject_lhs=project_p_slot"LHS"lhs.pexp_locsetup_l.slotinletproject_rhs1=project_p_slot"RHS1"rhs.pexp_locsetup_r.slotinletproj_lazy=[%exprlazy(letp=Lazy.forceprojections.Tensor.projectionsinIr.Indexing.{product_space=p.product_space;product_iterators=p.product_iterators;lhs_dims=[%elhs_dims];rhs_dims=[|[%erhs1_dims]|];project_lhs=[%eproject_lhs];project_rhs=[|[%eproject_rhs1]|];debug_info={p.debug_infowithtrace=("ppx_cd "^[%estring_expr~locaccu_op]^" "^[%estring_expr~locun_op],Ir.Indexing.unique_debug_id())::p.debug_info.trace;};})]in(proj_lazy,[%exprprojections.Tensor.projections_debug])in(* FIXME: might be better to treat missing [rhs] as zeros or errors rather than eliding the
code, only lhs should decide whether to elide the code. *)letbody_for_lhs=[%exprOption.map[%esetup_r.array_opt]~f:(funrhs->Ir.Assignments.Accum_op{initialize_neutral=[%einitialize_neutral];accum=[%eaccum];lhs;rhs=Unop{op=[%eop];rhs};projections=[%eprojections_lazy];projections_debug=[%eprojections_debug];})]inassignment~punned~lhs:setup_l~rhses:[setup_r]~body_for_lhs()inletprocess_vec_unop~lhs~vec_un_op~rhs?projections~proj_in_scope()=(* Vector unary operations do not have accumulation, they directly set values *)let_,op=vec_unary_opvec_un_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhsinletprojections_lazy,projections_debug=matchprojectionswith|Someprjs->([%expr[%eprjs].Tensor.projections],[%expr[%eprjs].Tensor.projections_debug])|None->letlhs_dims=project_p_dims"LHS"lhs.pexp_locsetup_l.slotinletrhs1_dims=project_p_dims"RHS1"lhs.pexp_locsetup_r.slotinletproject_lhs=project_p_slot"LHS"lhs.pexp_locsetup_l.slotinletproject_rhs1=project_p_slot"RHS1"rhs.pexp_locsetup_r.slotinletproj_lazy=[%exprlazy(letp=Lazy.forceprojections.Tensor.projectionsinIr.Indexing.{product_space=p.product_space;product_iterators=p.product_iterators;lhs_dims=[%elhs_dims];rhs_dims=[|[%erhs1_dims]|];project_lhs=[%eproject_lhs];project_rhs=[|[%eproject_rhs1]|];debug_info={p.debug_infowithtrace=("ppx_cd vec 
"^[%estring_expr~locvec_un_op],Ir.Indexing.unique_debug_id())::p.debug_info.trace;};})]in(proj_lazy,[%exprprojections.Tensor.projections_debug])inletbody_for_lhs=[%exprOption.map[%esetup_r.array_opt]~f:(funrhs->Ir.Assignments.Set_vec_unop{lhs;op=[%eop];rhs;projections=[%eprojections_lazy];projections_debug=[%eprojections_debug];})]inassignment~punned~lhs:setup_l~rhses:[setup_r]~body_for_lhs()inletprocess_raw_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~logic=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r1=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs1inletsetup_r2=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs2inletsetup_r3=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs3inletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inlett_expr,lhs_is_grad,_=args_for~locsetup_linlett1_expr,rhs1_is_grad,rhs1_is_merge=args_for~locsetup_r1inlett2_expr,rhs2_is_grad,rhs2_is_merge=args_for~locsetup_r2inlett3_expr,rhs3_is_grad,rhs3_is_merge=args_for~locsetup_r3inletraw_body=[%exprTensor.raw_ternop~initialize_neutral:[%einitialize_neutral]~accum:[%eaccu_op]~t:[%et_expr]~lhs_is_grad:[%elhs_is_grad]~op:[%etern_op]~t1:[%et1_expr]~rhs1_is_grad:[%erhs1_is_grad]~rhs1_is_merge:[%erhs1_is_merge]~t2:[%et2_expr]~rhs2_is_grad:[%erhs2_is_grad]~rhs2_is_merge:[%erhs2_is_merge]~t3:[%et3_expr]~rhs3_is_grad:[%erhs3_is_grad]~rhs3_is_merge:[%erhs3_is_merge]~logic:[%elogic]]inassignment~punned~lhs:setup_l~rhses:[setup_r1;setup_r2;setup_r3]~raw_body()inletprocess_raw_binop~accu_op~lhs~bin_op~rhs1~rhs2~logic=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r1=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs1inletsetup_r2=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs2inletinitialize_neutral=ifini
tialize_neutralthen[%exprtrue]else[%exprfalse]inlett_expr,lhs_is_grad,_=args_for~locsetup_linlett1_expr,rhs1_is_grad,rhs1_is_merge=args_for~locsetup_r1inlett2_expr,rhs2_is_grad,rhs2_is_merge=args_for~locsetup_r2inletraw_body=[%exprTensor.raw_binop~initialize_neutral:[%einitialize_neutral]~accum:[%eaccu_op]~t:[%et_expr]~lhs_is_grad:[%elhs_is_grad]~op:[%ebin_op]~t1:[%et1_expr]~rhs1_is_grad:[%erhs1_is_grad]~rhs1_is_merge:[%erhs1_is_merge]~t2:[%et2_expr]~rhs2_is_grad:[%erhs2_is_grad]~rhs2_is_merge:[%erhs2_is_merge]~logic:[%elogic]]inassignment~punned~lhs:setup_l~rhses:[setup_r1;setup_r2]~raw_body()inletprocess_raw_unop~accu_op~lhs~un_op~rhs~logic=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhsinletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inlett_expr,lhs_is_grad,_=args_for~locsetup_linlett1_expr,rhs_is_grad,rhs_is_merge=args_for~locsetup_rinletraw_body=[%exprTensor.raw_unop~initialize_neutral:[%einitialize_neutral]~accum:[%eaccu_op]~t:[%et_expr]~lhs_is_grad:[%elhs_is_grad]~op:[%eun_op]~t1:[%et1_expr]~rhs_is_grad:[%erhs_is_grad]~rhs_is_merge:[%erhs_is_merge]~logic:[%elogic]]inassignment~punned~lhs:setup_l~rhses:[setup_r]~raw_body()inmatchexprwith|{pexp_desc=Pexp_constant(Pconst_float_);_}->{default_resultwithexpr=[%exprNTDSL.number[%eexpr]];slot=Scalar}|{pexp_desc=Pexp_constant(Pconst_integer(_,Some('L'|'l')));_}->{default_resultwithexpr=[%exprNTDSL.bits[%eexpr]];slot=Scalar}|{pexp_desc=Pexp_constant(Pconst_integer_);_}->{default_resultwithexpr=[%exprNTDSL.number(Float.of_int[%eexpr])];slot=Scalar}|[%expr[%e?{pexp_desc=Pexp_constant(Pconst_charch);pexp_loc;_}][%e?{pexp_desc=Pexp_constant(Pconst_float_);_}asf]]->letaxis=Ast_helper.Exp.constant~loc:pexp_loc(Pconst_string(String.of_charch,pexp_loc,None))in{default_resultwithexpr=[%exprNTDSL.number~axis_label:[%eaxis][%ef]];slot=Scalar;
(* NOTE(review): continuation of the %cd translator match over expression syntax; the enclosing function begins above this chunk. Cases below handle: char-axis literals lowered to NTDSL.bits / NTDSL.number, string literals introducing no-grad tensors via make_vb, array and list literals routed to ndarray_op, and the slot identifiers v, g, t, lhs, rhs1..rhs3, t1..t3, v1..v3, g1..g3 mapped to LHS/RHS slots. *)
}|[%expr[%e?{pexp_desc=Pexp_constant(Pconst_charch);pexp_loc;_}][%e?{pexp_desc=Pexp_constant(Pconst_integer(_,Some('L'|'l')));_}asi]]->letaxis=Ast_helper.Exp.constant~loc:pexp_loc(Pconst_string(String.of_charch,pexp_loc,None))in{default_resultwithexpr=[%exprNTDSL.bits~axis_label:[%eaxis][%ei]];slot=Scalar;}|[%expr[%e?{pexp_desc=Pexp_constant(Pconst_charch);pexp_loc;_}][%e?{pexp_desc=Pexp_constant(Pconst_integer_);_}asi]]->letaxis=Ast_helper.Exp.constant~loc:pexp_loc(Pconst_string(String.of_charch,pexp_loc,None))in{default_resultwithexpr=[%exprNTDSL.number~axis_label:[%eaxis](Float.of_int[%ei])];slot=Scalar;}|{pexp_desc=Pexp_constant(Pconst_string(name,str_loc,_));_}->letvbs=Map.singleton(moduleString)name@@make_vb~loc~name~name_expr:expr~hint_label:(Option.map~f:(funs->[%expr[[%es]]])ident_label)in{vbs;typ=No_grad_tensor_intro{name;name_expr=expr};expr=A.Exp.ident~loc:str_loc{txt=Lidentname;loc=str_loc};array_opt_of_code=None;slot=Undet;}|{pexp_desc=Pexp_array_;_}|{pexp_desc=Pexp_construct({txt=Lident"::";_},_);_}->{default_resultwithexpr=ndarray_opexpr}|{pexp_desc=Pexp_ident{txt=Lident("v"|"lhs");_};_}->{default_resultwithtyp=Array;slot=LHS}|{pexp_desc=Pexp_ident{txt=Lident"g";_};_}->{default_resultwithtyp=Array;slot=LHS}|{pexp_desc=Pexp_ident{txt=Lident"rhs1";_};_}->{default_resultwithtyp=Array;slot=RHS1}|{pexp_desc=Pexp_ident{txt=Lident"t";_};_}->{default_resultwithslot=LHS}|{pexp_desc=Pexp_ident{txt=Lident"t1";_};_}->{default_resultwithslot=RHS1}|{pexp_desc=Pexp_ident{txt=Lident"v1";_};_}->{default_resultwithtyp=Array;slot=RHS1;expr=[%exprt1.Tensor.value]}|{pexp_desc=Pexp_ident{txt=Lident"g1";_};_}->{default_resultwithtyp=Grad_of_tensor[%exprt1];slot=RHS1;expr=[%exprOption.mapt1.Tensor.diff~f:(fund->d.Tensor.grad)];}|{pexp_desc=Pexp_ident{txt=Lident"rhs2";_};_}->{default_resultwithtyp=Array;slot=RHS2}|{pexp_desc=Pexp_ident{txt=Lident"t2";_};_}->{default_resultwithtyp=Tensor;slot=RHS2}|{pexp_desc=Pexp_ident{txt=Lident"v2";_};_}->{default_resultwithtyp=Array;slot=
(* Slot-2 and slot-3 accessors mirror slot 1: vN reads tN.Tensor.value, gN wraps tN.Tensor.diff in Option.map to fetch the gradient. *)
RHS2;expr=[%exprt2.Tensor.value]}|{pexp_desc=Pexp_ident{txt=Lident"g2";_};_}->{default_resultwithtyp=Grad_of_tensor[%exprt2];slot=RHS2;expr=[%exprOption.mapt2.Tensor.diff~f:(fund->d.Tensor.grad)];}|{pexp_desc=Pexp_ident{txt=Lident"rhs3";_};_}->{default_resultwithtyp=Array;slot=RHS3}|{pexp_desc=Pexp_ident{txt=Lident"t3";_};_}->{default_resultwithtyp=Tensor;slot=RHS3}|{pexp_desc=Pexp_ident{txt=Lident"v3";_};_}->{default_resultwithtyp=Array;slot=RHS3;expr=[%exprt3.Tensor.value]}|{pexp_desc=Pexp_ident{txt=Lident"g3";_};_}->{default_resultwithtyp=Grad_of_tensor[%exprt3];slot=RHS3;expr=[%exprOption.mapt3.Tensor.diff~f:(fund->d.Tensor.grad)];}|{pexp_desc=Pexp_ident{txt=Lidentop_ident;_};_}whenis_primitive_opop_ident->default_result|[%expr!.[%e?expr1]]->(* Hardcoding these two patterns (!. and !..) to improve projection derivation expressivity
and avoid treating the constants as already tensors. *){typ=Tensor;slot=Scalar;expr=[%exprNTDSL.O.(!.)[%eexpr1]];array_opt_of_code=None;vbs=no_vbs;}|[%expr!..[%e?expr1]]->{typ=Tensor;slot=Scalar;expr=[%exprNTDSL.O.(!..)[%eexpr1]];array_opt_of_code=None;vbs=no_vbs;}|[%expr[%e?expr1]**.[%e?{pexp_desc=Pexp_constant(Pconst_integer_);_}asi]]->(* We need to hardcode these two patterns (for **. ) to prevent the numbers from being
converted to tensors. *)letres1=loop~proj_in_scopeexpr1in{res1withtyp=Tensor;expr=[%exprNTDSL.O.(**.)[%eres1.expr](Float.of_int[%ei])];}|[%expr[%e?expr1]**.[%e?expr2]]->letres1=loop~proj_in_scopeexpr1in{res1withtyp=Tensor;expr=[%exprNTDSL.O.(**.)[%eres1.expr][%eexpr2]]}|[%expr[%e?expr1]*+[%e?{pexp_desc=Pexp_constant(Pconst_string(spec_str,_,_));_}][%e?expr2]]whenString.containsspec_str'>'->letres1=loop~proj_in_scopeexpr1inletres2=loop~proj_in_scopeexpr2inletspec=substitute_identifiers_in_einsum_spec~locspec_strinletslot=List.hd_exn@@List.sort[res1.slot;res2.slot]~compare:compare_slotsin{vbs=reduce_vbss[res1.vbs;res2.vbs];typ=Tensor;slot;expr=[%expreinsum[%espec][%eres1.expr][%eres2.expr]];array_opt_of_code=None;}|[%expr[%e?expr1]++[%e?{pexp_desc=Pexp_constant(Pconst_string(spec_str,_,_));_}]]whenString.containsspec_str'>'->letres1=loop~proj_in_scopeexpr1inletspec=substitute_identifiers_in_einsum_spec~locspec_strin{res1withtyp=Tensor;expr=[%expreinsum1[%espec][%eres1.expr]]}|[%expr[%e?expr1].grad]->(letres1=loop~proj_in_scopeexpr1inmatchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Grad_of_tensorexpr1;expr=[%exprOption.map[%eres1.expr].Tensor.diff~f:(fund->d.Tensor.grad)];(* It's never a good idea to embed backprop code outside of a proper backprop pass. *)}|Merge_value_->{res1withtyp=Merge_gradexpr1;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: write .grad.merge instead of .merge.grad";}|Function|Code_|Array|Value_of_tensor_|Grad_of_tensor_|Merge_grad_->{res1withtyp=Array;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: only tensors have a gradient";})|[%expr[%e?expr1].value]->(letres1=loop~proj_in_scopeexpr1in(* TODO: maybe this is too permissive? E.g. [t1.grad.value] is accepted. 
*)matchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Value_of_tensorres1.expr;expr=[%expr[%eres1.expr].Tensor.value];}|Function|Code_->{res1withtyp=Array;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: <code>.value notation not supported when <code> is not a \
tensor";}|Array|Value_of_tensor_|Grad_of_tensor_|Merge_value_|Merge_grad_->res1)|[%expr[%e?expr1].merge]->(letres1=loop~proj_in_scopeexpr1inmatchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Merge_valueres1.expr;expr=[%expr[%eres1.expr].Tensor.value]}|Value_of_tensort->{res1withtyp=Merge_valuet;expr=[%expr[%eres1.expr].Tensor.value]}|Function|Array|Code_->{res1withtyp=Array;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: only tensor nodes (e.g. `.value` or `.grad`) can be merged";}|Grad_of_tensort->{res1withtyp=Merge_gradt}|Merge_value_|Merge_grad_->{res1withexpr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: repeated .merge not allowed";})|[%expr[%e?expr1].forward]->(letres1=loop~proj_in_scopeexpr1inmatchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Code{is_commented=false};expr=[%exprTensor.consume_forward_code[%eres1.expr]];}|_->{res1withexpr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: .forward can only be applied to tensors";})|[%expr[%e?expr1].backprop]->(letres1=loop~proj_in_scopeexpr1inmatchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Code{is_commented=false};expr=[%exprTensor.consume_backprop_code[%eres1.expr]];}|_->{res1withexpr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: .backprop can only be applied to tensors";})|[%expr[%e?expr1].zero_grads]->(letres1=loop~proj_in_scopeexpr1inmatchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Code{is_commented=false};expr=[%exprmatch[%eres1.expr].diffwith|None->raise(Invalid_argument"ppx_ocannl %cd: .zero_grads requires a differentiable tensor")|Somediff->Ir.Assignments.to_compdiff.zero_grads];}|_->{res1withexpr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: .zero_grads can only be applied to 
tensors";})|[%expr~~([%e?{pexp_desc=Pexp_constant(Pconst_string_);_}ascomment];[%e?expr2])]->letres2=loop~proj_in_scopeexpr2inletblock=matchres2.typwith|Code_->res2.expr|_->Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: only code can be commented, e.g. assignments or t.forward, \
t.backprop, t.zero_grads"in{res2withexpr=[%exprlet__comment_block=[%eblock]in{Ir.Assignments.asgns=Ir.Assignments.Block_comment([%ecomment],__comment_block.Ir.Assignments.asgns);embedded_nodes=__comment_block.Ir.Assignments.embedded_nodes;}];}|[%expr~~([%e?{pexp_desc=Pexp_apply(expr,exprs);pexp_loc;_}];[%e?expr2])]->letelements=expr::List.map~f:sndexprs|>List.map~f:(function|{pexp_desc=Pexp_constant(Pconst_string_);_}ass->s|[%expr[%e?t].value]->[%exprIr.Tnode.debug_name[%et].value]|[%expr[%e?t].grad]->[%exprIr.Tnode.debug_name[%et].value^".grad"]|t->[%exprIr.Tnode.debug_name[%et].value])inletres2=loop~proj_in_scopeexpr2inletblock=matchres2.typwith|Code_->res2.expr|_->Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: only code can be commented, e.g. assignments or t.forward, \
t.backprop, t.zero_grads"in{res2withexpr=[%exprlet__comment_block=[%eblock]in{Ir.Assignments.asgns=Ir.Assignments.Block_comment(String.concat_array~sep:" "[%eAst_helper.Exp.array~loc:pexp_locelements],__comment_block.Ir.Assignments.asgns);embedded_nodes=__comment_block.Ir.Assignments.embedded_nodes;}];}|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}]([%e?rhs1],[%e?rhs2])~projections:[%e?projections])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1][%e?rhs2]~projections:[%e?projections])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1]([%e?rhs2]~projections:[%e?projections]))]->(* TODO: when clause not needed here and below, it's an error if bin_op is not a primitive
binary op. But this is error-prone with regard to ordering of the clauses. *)process_assign_binop~accu_op~lhs~bin_op~rhs1~rhs2~projections~proj_in_scope:true()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}]([%e?rhs1],[%e?rhs2],[%e?rhs3])~projections:[%e?projections])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}][%e?rhs1][%e?rhs2][%e?rhs3]~projections:[%e?projections])]->process_assign_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~projections~proj_in_scope:true()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentun_op;_};_}][%e?rhs]~projections:[%e?projections])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs](* FIXME: this was never needed as prefix operators bind tighter? *)([%e?{pexp_desc=Pexp_ident{txt=Lidentun_op;_};_}]([%e?rhs]~projections:[%e?projections]))]whenHashtbl.memunary_opsun_op->process_assign_unop~accu_op~lhs~un_op~rhs~projections~proj_in_scope:true()|[%expr[%e?lhs]=:[%e?{pexp_desc=Pexp_ident{txt=Lidentvec_un_op;_};_}][%e?rhs]~projections:[%e?projections]]whenHashtbl.memvec_unary_opsvec_un_op->process_vec_unop~lhs~vec_un_op~rhs~projections~proj_in_scope:true()|[%expr[%e?lhs]=:[%e?{pexp_desc=Pexp_ident{txt=Lidentvec_un_op;_};_}][%e?rhs]]whenHashtbl.memvec_unary_opsvec_un_op&&proj_in_scope->process_vec_unop~lhs~vec_un_op~rhs~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?rhs]~projections:[%e?projections])]->process_assign_unop~accu_op~lhs~un_op:"id"~rhs~projections~proj_in_scope:true()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}]([%e?rhs1],[%e?rhs2])~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1
(* Accumulation with ~logic: a spec of <.> means pointwise, <@> means composition, anything else is einsum notation for binary ops; ternary ops reject einsum specs with an error extension. *)
][%e?rhs2]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1]([%e?rhs2]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic]))]->letlogic=letloc=s_locinifString.equalspec"."then[%exprShape.Pointwise_bin]elseifString.equalspec"@"then[%exprShape.Compose]else[%exprShape.Einsum[%elogic]]inlet_,bin_op=binary_opbin_opinprocess_raw_binop~accu_op~lhs~bin_op~rhs1~rhs2~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}]([%e?rhs1],[%e?rhs2],[%e?rhs3])~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}][%e?rhs1][%e?rhs2][%e?rhs3]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}])]->letlogic=letloc=s_locinifString.equalspec"."then[%exprShape.Pointwise_bin]elseifString.equalspec"@"then[%exprShape.Compose]elseAst_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected <.> or <@>, found <%s> -- einsum notation for ternary \
operators not supported yet, see issue #305"specinlet_,tern_op=ternary_optern_opinprocess_raw_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs](([%e?{pexp_desc=Pexp_ident{txt=Lidentunop_ident;_};_}][%e?rhs])~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentunop_ident;_};_}][%e?rhs]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs](* FIXME: this was never needed as prefix operators bind tighter? *)([%e?{pexp_desc=Pexp_ident{txt=Lidentunop_ident;_};_}]([%e?rhs]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic]))]whenHashtbl.memunary_opsunop_ident->(* Handle both un_op priority levels -- where application binds tighter and less tight. *)letlogic=letloc=s_locinifString.equalspec"."then[%exprShape.Pointwise_un]elseifString.equalspec"T"then[%exprShape.Transpose]else[%exprShape.Permute[%elogic]]inlet_,un_op=Hashtbl.find_exnunary_opsunop_identlocinprocess_raw_unop~accu_op~lhs~un_op~rhs~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}]([%e?rhs1],[%e?rhs2]))]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1][%e?rhs2])]whenis_assignmentaccu_op&&Hashtbl.membinary_opsbin_op&&proj_in_scope->process_assign_binop~accu_op~lhs~bin_op~rhs1~rhs2~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}]([%e?rhs1],[%e?rhs2],[%e?rhs3]))]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}][%e?rhs1][%e?rhs2][%e?rhs3])]whenis_assignmentaccu_op&&Hashtbl.memternary_opstern_op&&proj_in_scope->process_assign_ternop
(* NOTE(review): below are assignment forms without explicit ~projections or ~logic. The variants guarded by proj_in_scope dispatch to process_assign_*; without projections in scope they fall through to the process_raw_* cases, which synthesize the shape logic themselves. After those: generic 3-ary and 2-ary application, function literals (which may bring a projections parameter into scope), while/for loops (rejected with error extensions), sequencing, if-then-else, match, let-in (rejected, see TODO #80), and open/letmodule embeddings. *)
~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentun_op;_};_}][%e?rhs])]whenis_assignmentaccu_op&&Hashtbl.memunary_opsun_op&&proj_in_scope->process_assign_unop~accu_op~lhs~un_op~rhs~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs][%e?rhs]]whenis_assignmentaccu_op&&proj_in_scope->process_assign_unop~accu_op~lhs~un_op:"id"~rhs~proj_in_scope()|[%expr[%e?lhs]=:[%e?{pexp_desc=Pexp_ident{txt=Lidentvec_un_op;_};_}][%e?rhs]]whenHashtbl.memvec_unary_opsvec_un_op&&proj_in_scope->process_vec_unop~lhs~vec_un_op~rhs~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}]([%e?rhs1],[%e?rhs2]))]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1][%e?rhs2])]whenis_assignmentaccu_op&&Hashtbl.membinary_opsbin_op->letlogic,bin_op=binary_opbin_opinprocess_raw_binop~accu_op~lhs~bin_op~rhs1~rhs2~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}]([%e?rhs1],[%e?rhs2],[%e?rhs3]))]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}][%e?rhs1][%e?rhs2][%e?rhs3])]whenis_assignmentaccu_op&&Hashtbl.memternary_opstern_op->letlogic,tern_op=ternary_optern_opinprocess_raw_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentun_op;_};_}][%e?rhs])]whenis_assignmentaccu_op&&Hashtbl.memunary_opsun_op->letlogic,un_op=Hashtbl.find_exnunary_opsun_oplocinprocess_raw_unop~accu_op~lhs~un_op~rhs~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs][%e?rhs]]whenis_assignmentaccu_op->process_raw_unop~accu_op~lhs~un_op:[%exprIr.Ops.Identity]~rhs~logic:[%exprShape.Pointwise_un]|[%expr[%e?expr1][
%e?expr2][%e?expr3]]->letres1=loop~proj_in_scopeexpr1inletres2=loop~proj_in_scopeexpr2inletres3=loop~proj_in_scopeexpr3inletslot=List.hd_exn@@List.sort[res1.slot;res2.slot;res3.slot]~compare:compare_slotsin{vbs=reduce_vbss[res1.vbs;res2.vbs;res3.vbs];typ=res1.typ;slot;expr=[%expr[%eres1.expr][%eres2.expr][%eres3.expr]];array_opt_of_code=None;}|[%expr[%e?expr1][%e?expr2]]->letres1=loop~proj_in_scopeexpr1inletres2=loop~proj_in_scopeexpr2inletslot=List.hd_exn@@List.sort[res1.slot;res2.slot]~compare:compare_slotsin{vbs=reduce_vbss[res1.vbs;res2.vbs];typ=res1.typ;slot;expr=[%expr[%eres1.expr][%eres2.expr]];array_opt_of_code=None;}|{pexp_desc=Pexp_function(args,constr,body);_}asexpr->letproj_in_scope=proj_in_scope||List.existsargs~f:(function|{pparam_desc=Pparam_val((Labelleds|Optionals),_,_);_}whenString.equals"projections"->true|_->false)inletbad_pun_hints=Set.union_list(moduleString)@@bad_pun_hints::List.mapargs~f:(funarg->matcharg.pparam_descwith|Pparam_val(_,_,pat)->collect_pat_identspat|_->Set.empty(moduleString))inletresult=matchbodywith|Pfunction_bodybody->letres=transl~bad_pun_hints~proj_in_scopebodyin{reswithtyp=Function;expr={exprwithpexp_desc=Pexp_function(args,constr,Pfunction_bodyres.expr)};}|Pfunction_cases(cases,loc,attrs)->lettransformed_cases,cases_result=handle_cases~bad_pun_hints~proj_in_scope(fun~bad_pun_hints~proj_in_scope->transl~bad_pun_hints~proj_in_scope)casesin{cases_resultwithtyp=Function;expr={exprwithpexp_desc=Pexp_function(args,constr,Pfunction_cases(transformed_cases,loc,attrs));};}inresult|[%exprwhile[%e?_test_expr]do[%e?_body]done]->{default_resultwithtyp=Code{is_commented=false};slot=Nonslot;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: while: low-level code embeddings not supported 
yet";}|[%exprfor[%p?_pat]=[%e?_init]to[%e?_final]do[%e?_body_expr]done]->{default_resultwithtyp=Code{is_commented=false};slot=Nonslot;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: for-to: low-level code embeddings not supported yet";}|[%exprfor[%p?_pat]=[%e?_init]downto[%e?_final]do[%e?_body_expr]done]->{default_resultwithtyp=Code{is_commented=false};slot=Nonslot;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: for-downto: low-level code embeddings not supported yet";}|[%expr[%e?expr1];[%e?expr2]]->letres1=loop~proj_in_scopeexpr1inletres2=loop~proj_in_scopeexpr2in{vbs=reduce_vbss[res1.vbs;res2.vbs];typ=Code{is_commented=false};slot=Nonslot;expr=[%exprIr.Assignments.sequence[[%eres1.expr];[%eres2.expr]]];array_opt_of_code=res2.array_opt_of_code;}|[%exprif[%e?expr1]then[%e?expr2]else[%e?expr3]]->letres2=loop~proj_in_scopeexpr2inletres3=loop~proj_in_scopeexpr3inlettyp=ifis_unknownres2.typthenres3.typelseres2.typinletslot=List.hd_exn@@List.sort[res2.slot;res3.slot]~compare:compare_slotsin{vbs=reduce_vbss[res2.vbs;res3.vbs];typ;slot;expr=[%exprif[%eexpr1]then[%eres2.expr]else[%eres3.expr]];array_opt_of_code=None;}|[%exprif[%e?expr1]then[%e?expr2]]->letres2=loop~proj_in_scopeexpr2in{vbs=res2.vbs;typ=Code{is_commented=false};slot=Nonslot;expr=[%exprif[%eexpr1]then[%eres2.expr]elseIr.Assignments.empty_comp];array_opt_of_code=res2.array_opt_of_code;}|{pexp_desc=Pexp_match(expr1,cases);_}->lettransformed_cases,cases_result=handle_cases~bad_pun_hints~proj_in_scopetranslcasesin{cases_resultwithexpr={exprwithpexp_desc=Pexp_match(expr1,transformed_cases)}}|{pexp_desc=Pexp_let(_recflag,_bindings,_body);_}->(* TODO(#80): to properly support local bindings, we need to collect the type
environment. *){default_resultwithtyp=Unknown;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: let-in: local let-bindings not implemented yet";}(* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=loop
binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, loop body)} *)|{pexp_desc=Pexp_open(decl,expr1);_}->letres1=loop~proj_in_scopeexpr1in{res1withexpr={exprwithpexp_desc=Pexp_open(decl,res1.expr)}}|{pexp_desc=Pexp_letmodule(name,module_expr,expr1);_}->letres1=loop~proj_in_scopeexpr1in{res1withexpr={exprwithpexp_desc=Pexp_letmodule(name,module_expr,res1.expr)}}|_->{default_resultwithtyp=Unknown}inletres=transl~proj_in_scope:false~bad_pun_hints:(Set.empty(moduleString))exprinmatch(res.typ,ident_label)with|Code{is_commented=false},Somestring_expr->letloc=res.expr.pexp_locin{reswithexpr=[%exprletuncommented_comp=[%eres.expr]in{Ir.Assignments.embedded_nodes=uncommented_comp.Ir.Assignments.embedded_nodes;asgns=Ir.Assignments.Block_comment([%estring_expr],uncommented_comp.Ir.Assignments.asgns);}];typ=Code{is_commented=true};}|_->reslettranslate?ident_labelexpr=letident_label,is_ignore=matchident_labelwith|Some[%pat?_]->(None,true)|Somelabel->(Some(pat2stringlabel),false)|None->(None,false)inletres=translate?ident_labelexprinletloc=res.expr.pexp_locinletexpr=res.exprin(res.vbs,ifis_ignorethen[%exprTensor.with_unchanged_roots~f:(fun()->letopen!NTDSL.Oin[%eexpr])]else[%exprletopen!NTDSL.Oin[%eexpr]])letexpr_expander~loc~path=expr_expander_with_punningtranslate~loc~pathletstr_expander~loc~path=str_expander_with_punningtranslate~loc~path