(* Implementation of the [%cd] ("computation definition") syntax extension.
   Translates assignment/accumulation expressions over tensors into
   [Arrayjit.Assignments] code, supporting punned inline tensor declarations,
   projections, and merge buffers. Reconstructed from an extraction-mangled
   dump: fused line-number residue removed and whitespace restored; tokens and
   runtime strings preserved. *)

open Base
open Ppxlib
open Ppx_arrayjit.Ppx_helper
open Ppx_shared
module A = Ppxlib_ast.Ast_helper

(** [ndarray_op ?axis_labels ?label expr] builds an [NTDSL.ndarray] application
    from a literal array/list expression, passing the dimensions extracted by
    [ndarray_constant]. *)
let ndarray_op ?axis_labels ?label expr =
  let loc = expr.pexp_loc in
  let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
  let edims dims = Ast_builder.Default.elist ~loc dims in
  let op =
    match (axis_labels, label) with
    | None, None -> [%expr NTDSL.ndarray]
    | Some axis_labels, None -> [%expr NTDSL.ndarray ~axis_labels:[%e axis_labels]]
    | None, Some label -> [%expr NTDSL.ndarray ~label:[%e label]]
    | Some axis_labels, Some label ->
        [%expr NTDSL.ndarray ~axis_labels:[%e axis_labels] ~label:[%e label]]
  in
  [%expr
    [%e op] ~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
      ~output_dims:[%e edims output_dims] [%e values]]

(** Classification of a translated subexpression: what kind of value it denotes
    in the [%cd] language. *)
type expr_type =
  | Code
  | Array
  | Value_of_tensor of expression
  | Grad_of_tensor of expression
  | Tensor
  | Unknown
  | Merge_value of expression
  | Merge_grad of expression
  | No_grad_tensor_intro of { name : string; name_expr : expression }

let is_unknown = function Unknown -> true | _ -> false

(** Which slot of a projections record a subexpression fills. *)
type projections_slot = LHS | RHS1 | RHS2 | Nonslot | Undet [@@deriving equal, sexp]

(** [assignment_op expr] maps an assignment operator such as [=+] to
    [(initialize_neutral, accumulation op)]. *)
let assignment_op expr =
  (* This should stay in sync with Arrayjit.Ops.assign_op_cd_syntax. *)
  let loc = expr.pexp_loc in
  match expr with
  | [%expr ( =: )] -> (false, [%expr Arrayjit.Ops.Arg2])
  | [%expr ( =+ )] -> (false, [%expr Arrayjit.Ops.Add])
  | [%expr ( =- )] -> (false, [%expr Arrayjit.Ops.Sub])
  | [%expr ( =* )] -> (false, [%expr Arrayjit.Ops.Mul])
  | [%expr ( =/ )] -> (false, [%expr Arrayjit.Ops.Div])
  | [%expr ( =** )] -> (false, [%expr Arrayjit.Ops.ToPowOf])
  | [%expr ( =?/ )] -> (false, [%expr Arrayjit.Ops.Relu_gate])
  | [%expr ( =:+ )] -> (true, [%expr Arrayjit.Ops.Add])
  | [%expr ( =:- )] -> (true, [%expr Arrayjit.Ops.Sub])
  | [%expr ( =:* )] -> (true, [%expr Arrayjit.Ops.Mul])
  | [%expr ( =:/ )] -> (true, [%expr Arrayjit.Ops.Div])
  | [%expr ( =:** )] -> (true, [%expr Arrayjit.Ops.ToPowOf])
  | [%expr ( =:?/ )] -> (true, [%expr Arrayjit.Ops.Relu_gate])
  | _ ->
      ( false,
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
             "=+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =: (Arg2), \
              =:+, =:-,"
             " =:*, =:/, =:**, =:?/ (same with initializing the tensor to the neutral value before \
              the start of the calculation)" )

(** [binary_op expr] maps a binary operator to [(default compose type, op)].
    For [*] and [/] there is deliberately no default compose type: an error
    extension node is produced instead, prompting the user for [~logic]. *)
let binary_op expr =
  (* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
  let loc = expr.pexp_loc in
  match expr with
  | [%expr ( + )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Add])
  | [%expr ( - )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Sub])
  | [%expr ( * )] ->
      ( Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "No default compose type for binary `*`, try e.g. ~logic:\".\" for pointwise, %s"
             "~logic:\"@\" for matrix multiplication",
        [%expr Arrayjit.Ops.Mul] )
  | [%expr ( / )] ->
      ( Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "For clarity, no default compose type for binary `/`, use ~logic:\".\" for pointwise \
              division",
        [%expr Arrayjit.Ops.Div] )
  | [%expr ( ** )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.ToPowOf])
  | [%expr ( -?/ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate])
  | [%expr ( -/> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg2])
  | [%expr ( -@> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg1])
  | _ ->
      (* NOTE(review): the listing below omits -@> (Arg1) even though it is
         accepted above and by [is_binary_op] — possibly an oversight; kept
         verbatim to preserve behavior. *)
      ( [%expr Shape.Pointwise_bin],
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "ppx_ocannl %%cd: expected a binary operator, one of: %s"
             "+ (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2)" )

let is_binary_op ident =
  List.mem [ "+"; "-"; "*"; "/"; "**"; "-?/"; "-/>"; "-@>" ] ident ~equal:String.equal

(** [unary_op expr] maps a unary operator to [(default compose type, op)]. *)
let unary_op expr =
  (* This and is_unary_op should stay in sync with Arrayjit.Ops.unop_cd_syntax. *)
  let loc = expr.pexp_loc in
  match expr with
  | [%expr ( ~= )] -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Identity])
  | [%expr ( ?/ )] -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Relu])
  | _ ->
      ( [%expr Shape.Pointwise_un],
        Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "ppx_ocannl %%cd: expected a unary operator, one of: = (Identity), ?/ (Relu)" )

let is_unary_op ident = List.mem [ "~="; "?/" ] ident ~equal:String.equal

(** Result of translating a [%cd] subexpression. *)
type result = {
  vbs : value_binding Map.M(String).t;
      (** [vbs] are the bindings introduced by inline tensor declarations (aka. punning). These
          bindings are discharged with the whole [%cd] extension scope in scope. *)
  typ : expr_type;
  slot : projections_slot;
  expr : expression;
  array_opt_of_code : expression option;
      (** [array_opt_of_code] keeps track of which tensor node to use when [typ] is [Code] but the
          result is used in an array context. *)
}

(** Per-argument information prepared by [setup_array] for building an
    assignment. *)
type array_setup = {
  vb : value_binding option;
      (** This binding is only generated for a tensor expression that is not an identifier, since
          recomputing the expression (when copied) would generate a fresh tensor. It is discharged
          when an assignment is built. *)
  slot : projections_slot;
  filler_typ : expr_type;
  fwd_code_or_noop : expression option;
  array_opt : expression;
  tensor : expression option;
  pun_hint_tnode : (expression * bool) option;
      (** The tnode, if any, whose label the relevant punned no-gradient tensor should incorporate
          in its label. The bool denotes whether this is a preferred (high quality) guess. *)
}

(** [make_vb ~loc ~name ~name_expr ~hint_label] builds the binding
    [let <name> = NTDSL.term ~label:... ()] for a punned no-gradient tensor,
    optionally extending the label with [hint_label]. *)
let make_vb ~loc ~name ~name_expr ~hint_label =
  let pat = A.Pat.var ~loc { loc = name_expr.pexp_loc; txt = name } in
  let v =
    match hint_label with
    | None -> [%expr NTDSL.term ~label:[ [%e name_expr] ] ()]
    | Some hint_label -> [%expr NTDSL.term ~label:([%e name_expr] :: [%e hint_label]) ()]
  in
  A.Vb.mk ~loc pat v

(** [assignment ~punned ~lhs ~rhses body] assembles the final assignment code:
    sequences the RHS forward code before [body], discharges tensor bindings,
    and registers/creates the punned binding when the LHS is an inline
    declaration. *)
let assignment ~punned ~lhs ~rhses body =
  let setups = lhs :: rhses in
  let loc = body.pexp_loc in
  let forward_args =
    List.filter_map setups ~f:(fun { fwd_code_or_noop; _ } -> fwd_code_or_noop)
    |> List.reduce ~f:(fun code fwd -> [%expr Arrayjit.Assignments.Seq ([%e code], [%e fwd])])
  in
  let vbs, body =
    match lhs.filler_typ with
    | No_grad_tensor_intro { name; name_expr } -> (
        (* Pick a label hint for the punned tensor, preferring "good" (high
           quality) hints over "bad" ones. *)
        let good_hints, bad_hints =
          List.partition_tf ~f:snd @@ List.filter_map rhses ~f:(fun sup -> sup.pun_hint_tnode)
        in
        let hint_data = Option.first_some (List.hd good_hints) (List.hd bad_hints) in
        let hint_label = Option.map ~f:fst hint_data in
        let vbs =
          Map.singleton (module String) name @@ make_vb ~loc ~name ~name_expr ~hint_label
        in
        match hint_data with
        | None -> (vbs, body)
        | Some data -> (
            match Hashtbl.add punned ~key:name ~data with
            | `Ok -> (vbs, body)
            | `Duplicate ->
                ( no_vbs,
                  Ast_builder.Default.pexp_extension ~loc
                  @@ Location.error_extensionf ~loc
                       "ppx_ocannl %%cd: duplicate inline declaration of no-gradient tensor %s"
                       name )))
    | _ -> (no_vbs, body)
  in
  let body =
    if Option.is_some lhs.vb then
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: the assigned-to position cannot be an expression building a new tensor"
    else body
  in
  let tensor_vbs = List.filter_map rhses ~f:(fun rhs -> rhs.vb) in
  let expr =
    match forward_args with
    | None -> body
    | Some fwd -> [%expr Arrayjit.Assignments.Seq ([%e fwd], [%e body])]
  in
  let expr =
    if List.is_empty tensor_vbs then expr else A.Exp.let_ ~loc Nonrecursive tensor_vbs expr
  in
  { vbs; typ = Code; slot = Nonslot; expr; array_opt_of_code = Some lhs.array_opt }

(** [project_p_slot debug loc slot] selects the projection for [slot] out of a
    projections record [p], or an error extension node for invalid slots. *)
let project_p_slot debug loc slot =
  match slot with
  | LHS -> [%expr p.project_lhs]
  | RHS1 -> [%expr p.project_rhs.(0)]
  | RHS2 -> [%expr p.project_rhs.(1)]
  | Nonslot ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s" debug
  | Undet ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug
           "(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"

(** [project_p_dims debug loc slot] selects the dimensions for [slot] out of a
    projections record [p], or an error extension node for invalid slots. *)
let project_p_dims debug loc slot =
  match slot with
  | LHS -> [%expr p.lhs_dims]
  | RHS1 -> [%expr p.rhs_dims.(0)]
  | RHS2 -> [%expr p.rhs_dims.(1)]
  | Nonslot ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s" debug
  | Undet ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug
           "(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"

(** [guess_pun_hint ~punned ~bad_pun_hints filler_typ filler] guesses a label
    hint expression for a punned tensor from a RHS filler. Returns
    [Some (hint, preferred)] where [preferred] marks a high-quality guess, or
    [None] when the filler should not contribute a hint. *)
let guess_pun_hint ~punned ~bad_pun_hints filler_typ filler =
  let loc = filler.pexp_loc in
  let hint = [%expr [%e filler].Arrayjit.Tnode.label] in
  match (filler_typ, filler) with
  | Code, _ -> None
  | _, { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } when Set.mem bad_pun_hints name ->
      None
  | Array, _ -> Some (hint, false)
  | (Tensor | Unknown), { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
    when Hashtbl.mem punned name ->
      Hashtbl.find punned name
  | (Tensor | Unknown), { pexp_desc = Pexp_ident _; _ } -> Some (hint, true)
  | (Tensor | Unknown), _ -> Some (hint, false)
  | ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Merge_grad { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } ),
      _ )
    when Set.mem bad_pun_hints name ->
      None
  | ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Merge_grad { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } ),
      _ )
    when Hashtbl.mem punned name ->
      Hashtbl.find punned name
  | (Value_of_tensor t | Grad_of_tensor t | Merge_value t | Merge_grad t), _ -> (
      let hint = [%expr [%e t].Tensor.value.Arrayjit.Tnode.label] in
      match t with
      | { pexp_desc = Pexp_ident _; _ } -> Some (hint, true)
      | _ -> Some (hint, false))
  | No_grad_tensor_intro { name; _ }, _ -> Hashtbl.find punned name

(** [setup_array ~punned ~bad_pun_hints ~is_lhs result] converts a translated
    subexpression into an {!array_setup}: resolves the tensor-node expression
    ([array_opt]), attaches forward code for root tensors, and binds non-ident
    tensor expressions so they are not recomputed. [is_lhs] selects the
    assigned-to variant (no [Node] wrapping, punning allowed, merge buffers
    rejected). *)
let setup_array ~punned ~bad_pun_hints ~is_lhs
    { typ = filler_typ; slot; expr = filler; vbs; array_opt_of_code } =
  assert (Map.is_empty vbs);
  let loc = filler.pexp_loc in
  let opt_buffer tn =
    if is_lhs then [%expr Some [%e tn]] else [%expr Some (Arrayjit.Assignments.Node [%e tn])]
  in
  let buffer opt_tn =
    if is_lhs then opt_tn
    else [%expr Option.map [%e opt_tn] ~f:(fun tn -> Arrayjit.Assignments.Node tn)]
  in
  let pun_hint_tnode = guess_pun_hint ~punned ~bad_pun_hints filler_typ filler in
  let default_setup =
    {
      vb = None;
      fwd_code_or_noop = None;
      filler_typ;
      slot;
      array_opt = opt_buffer [%expr [%e filler].Tensor.value];
      tensor = None;
      pun_hint_tnode;
    }
  in
  match filler_typ with
  | No_grad_tensor_intro _ when not is_lhs ->
      {
        default_setup with
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: punning is only allowed in the assigned-to position";
      }
  | (Tensor | Unknown)
    when (match filler with { pexp_desc = Pexp_ident _; _ } -> true | _ -> false) ->
      let t = filler in
      let fwd_code_or_noop =
        Some
          [%expr
            if Tensor.is_fwd_root [%e t] then (
              Tensor.remove_fwd_root [%e t];
              [%e t].Tensor.forward)
            else Arrayjit.Assignments.Noop]
      in
      { default_setup with fwd_code_or_noop; tensor = Some t }
  | Value_of_tensor ({ pexp_desc = Pexp_ident _; _ } as t) ->
      let fwd_code_or_noop =
        Some
          [%expr
            if Tensor.is_fwd_root [%e t] then (
              Tensor.remove_fwd_root [%e t];
              [%e t].Tensor.forward)
            else Arrayjit.Assignments.Noop]
      in
      {
        default_setup with
        fwd_code_or_noop;
        array_opt = opt_buffer [%expr [%e t].Tensor.value];
        tensor = Some t;
      }
  | Value_of_tensor t ->
      {
        default_setup with
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: the <tensor>.value notation is only supported when <tensor> is an \
                identifier";
        tensor = Some t;
      }
  | Tensor | Unknown ->
      (* Need to bind the expression computing the tensor so we don't recompute it. *)
      let v =
        match slot with
        | LHS -> [%pat? nondiff__lhs]
        | RHS1 -> [%pat? nondiff__rhs1]
        | RHS2 -> [%pat? nondiff__rhs2]
        | Nonslot | Undet -> [%pat? nondiff__tensor]
      in
      let t = pat2expr v in
      let vb = Some (A.Vb.mk ~loc v filler) in
      let fwd_code_or_noop =
        Some
          [%expr
            if Tensor.is_fwd_root [%e t] then (
              Tensor.remove_fwd_root [%e t];
              [%e t].Tensor.forward)
            else Arrayjit.Assignments.Noop]
      in
      {
        default_setup with
        vb;
        fwd_code_or_noop;
        array_opt = opt_buffer [%expr [%e t].Tensor.value];
        tensor = Some t;
      }
  | No_grad_tensor_intro _ ->
      (* Inline tensors are guaranteed to be leaf tensors, so they don't have forward code. *)
      { default_setup with tensor = Some filler }
  | Code when Option.is_none array_opt_of_code ->
      {
        default_setup with
        fwd_code_or_noop = Some filler;
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: could not determine a lead array of provided code";
      }
  | Code ->
      {
        default_setup with
        fwd_code_or_noop = Some filler;
        array_opt = buffer (Option.value_exn array_opt_of_code);
      }
  | Array -> { default_setup with array_opt = opt_buffer filler }
  | Grad_of_tensor ({ pexp_desc = Pexp_ident _; _ } as t) ->
      { default_setup with array_opt = buffer filler; tensor = Some t }
  | Grad_of_tensor t ->
      {
        default_setup with
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: the <tensor>.grad notation is only supported when <tensor> is an \
                identifier";
        tensor = Some t;
      }
  | (Merge_value _ | Merge_grad _) when is_lhs ->
      {
        default_setup with
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: merge buffers cannot be assigned to";
      }
  | Merge_value t ->
      { default_setup with array_opt = [%expr Some (Merge_buffer [%e filler])]; tensor = Some t }
  | Merge_grad t ->
      {
        default_setup with
        array_opt =
          [%expr Option.map [%e filler] ~f:(fun tn -> Arrayjit.Assignments.Merge_buffer tn)];
        tensor = Some t;
      }

(** [args_for ~loc setup] extracts [(tensor, is_grad, is_merge)] flag
    expressions from an {!array_setup}, for calling [Tensor.raw_binop] /
    [Tensor.raw_unop]. *)
let args_for ~loc = function
  | { filler_typ = Merge_grad _; tensor = Some t; _ } -> (t, [%expr true], [%expr true])
  | { filler_typ = Grad_of_tensor _; tensor = Some t; _ } -> (t, [%expr true], [%expr false])
  | { filler_typ = Merge_value _; tensor = Some t; _ } -> (t, [%expr false], [%expr true])
  | { filler_typ = _; tensor = Some t; _ } -> (t, [%expr false], [%expr false])
  | _ ->
      ( Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "ppx_ocannl %%cd: cannot use `~logic` (infer shapes) for arrays, use tensors or \
              `.value` or `.grad` notation",
        [%expr false],
        [%expr false] )

(** [translate expr] recursively translates a [%cd] body. [proj_in_scope]
    tracks whether a [projections] variable is available from an enclosing
    [fun ~projections ...]; [bad_pun_hints] tracks identifiers bound by
    enclosing function parameters (unsuitable as pun label hints). *)
let translate (expr : expression) : result =
  let punned = Hashtbl.create (module String) in
  let rec transl ~bad_pun_hints ~proj_in_scope (expr : expression) : result =
    let loc = expr.pexp_loc in
    let default_result =
      { vbs = no_vbs; typ = Tensor; slot = Undet; expr; array_opt_of_code = None }
    in
    let loop = transl ~bad_pun_hints in
    (* Assignment with explicit or context projections: lhs <op>= rhs1 <binop> rhs2. *)
    let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l =
        setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope:true lhs
      in
      let _, bin_op = binary_op bin_op in
      let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1 in
      let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2 in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let projections =
        match projections with
        | Some prjs -> prjs
        | None ->
            (* Derive the projections from the in-scope [projections] variable,
               routed through the slots detected for each filler. *)
            let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in
            let rhs1_dims = project_p_dims "RHS1" lhs.pexp_loc setup_r1.slot in
            let rhs2_dims = project_p_dims "RHS2" lhs.pexp_loc setup_r2.slot in
            let project_lhs = project_p_slot "LHS" lhs.pexp_loc setup_l.slot in
            let project_rhs1 = project_p_slot "RHS1" rhs1.pexp_loc setup_r1.slot in
            let project_rhs2 = project_p_slot "RHS2" rhs2.pexp_loc setup_r2.slot in
            [%expr
              lazy
                (let p = Lazy.force projections in
                 Arrayjit.Indexing.
                   {
                     product_space = p.product_space;
                     product_iterators = p.product_iterators;
                     lhs_dims = [%e lhs_dims];
                     rhs_dims = [| [%e rhs1_dims]; [%e rhs2_dims] |];
                     project_lhs = [%e project_lhs];
                     project_rhs = [| [%e project_rhs1]; [%e project_rhs2] |];
                     debug_info =
                       {
                         p.debug_info with
                         trace =
                           ( "ppx_cd " ^ [%e expr2string_or_empty accu_op] ^ " "
                             ^ [%e expr2string_or_empty bin_op],
                             Arrayjit.Indexing.unique_debug_id () )
                           :: p.debug_info.trace;
                       };
                   })]
      in
      (* TODO: might be better to treat missing [rhs1, rhs2] as zeros or errors rather than
         eliding the code. *)
      let body =
        [%expr
          Option.value ~default:Arrayjit.Assignments.Noop
          @@ Option.map3 [%e setup_l.array_opt] [%e setup_r1.array_opt] [%e setup_r2.array_opt]
               ~f:(fun lhs rhs1 rhs2 ->
                 Arrayjit.Assignments.Accum_binop
                   {
                     initialize_neutral = [%e initialize_neutral];
                     accum = [%e accu_op];
                     lhs;
                     op = [%e bin_op];
                     rhs1;
                     rhs2;
                     projections = [%e projections];
                   })]
      in
      assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2 ] body
    in
    (* Assignment with a unary (or identity) op: lhs <op>= <unop> rhs. *)
    let process_assign_unop ~accu_op ~lhs ~un_op ~rhs ?projections ~proj_in_scope () =
      let initialize_neutral, accu_op = assignment_op accu_op in
      (* FIXME: I think this ignores the slot information here! Just assuming [projections] is
         as-should-be, but that's not consistent with omitting the projections arg (assuming it
         comes from the context). *)
      let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
      let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let projections =
        match projections with
        | Some prjs -> prjs
        | None ->
            let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in
            let rhs1_dims = project_p_dims "RHS1" lhs.pexp_loc setup_r.slot in
            let project_lhs = project_p_slot "LHS" lhs.pexp_loc setup_l.slot in
            let project_rhs1 = project_p_slot "RHS1" rhs.pexp_loc setup_r.slot in
            [%expr
              lazy
                (let p = Lazy.force projections in
                 Arrayjit.Indexing.
                   {
                     product_space = p.product_space;
                     product_iterators = p.product_iterators;
                     lhs_dims = [%e lhs_dims];
                     rhs_dims = [| [%e rhs1_dims] |];
                     project_lhs = [%e project_lhs];
                     project_rhs = [| [%e project_rhs1] |];
                     debug_info =
                       {
                         p.debug_info with
                         trace =
                           ( "ppx_cd " ^ [%e expr2string_or_empty accu_op] ^ " "
                             ^ [%e expr2string_or_empty un_op],
                             Arrayjit.Indexing.unique_debug_id () )
                           :: p.debug_info.trace;
                       };
                   })]
      in
      (* TODO: might be better to treat missing [rhs] as zeros or errors rather than eliding the
         code. *)
      let body =
        [%expr
          Option.value ~default:Arrayjit.Assignments.Noop
          @@ Option.map2 [%e setup_l.array_opt] [%e setup_r.array_opt] ~f:(fun lhs rhs ->
                 Arrayjit.Assignments.Accum_unop
                   {
                     initialize_neutral = [%e initialize_neutral];
                     accum = [%e accu_op];
                     lhs;
                     op = [%e un_op];
                     rhs;
                     projections = [%e projections];
                   })]
      in
      assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] body
    in
    (* Shape-inferring assignment (explicit ~logic): delegates to Tensor.raw_binop. *)
    let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
      let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1 in
      let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2 in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
      let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
      let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in
      let body =
        [%expr
          Tensor.raw_binop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op]
            ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e bin_op] ~t1:[%e t1_expr]
            ~rhs1_is_grad:[%e rhs1_is_grad] ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr]
            ~rhs2_is_grad:[%e rhs2_is_grad] ~rhs2_is_merge:[%e rhs2_is_merge] ~logic:[%e logic]]
      in
      assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2 ] body
    in
    (* Shape-inferring unary assignment: delegates to Tensor.raw_unop. *)
    let process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic =
      let initialize_neutral, accu_op = assignment_op accu_op in
      let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
      let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs in
      let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
      let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
      let t1_expr, rhs_is_grad, rhs_is_merge = args_for ~loc setup_r in
      let body =
        [%expr
          Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op]
            ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e un_op] ~t1:[%e t1_expr]
            ~rhs_is_grad:[%e rhs_is_grad] ~rhs_is_merge:[%e rhs_is_merge] ~logic:[%e logic]]
      in
      assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] body
    in
    match expr with
    | { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
        { default_result with expr = [%expr NTDSL.number [%e expr]] }
    | { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
        { default_result with expr = [%expr NTDSL.number (Float.of_int [%e expr])] }
    | [%expr
        [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
          [%e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] ->
        (* 'c' 1.0 — a number with an axis label. *)
        let axis =
          Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
        in
        { default_result with expr = [%expr NTDSL.number ~axis_label:[%e axis] [%e f]] }
    | [%expr
        [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
          [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
        let axis =
          Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
        in
        {
          default_result with
          expr = [%expr NTDSL.number ~axis_label:[%e axis] (Float.of_int [%e i])];
        }
    | { pexp_desc = Pexp_constant (Pconst_string (name, str_loc, _)); _ } ->
        (* A string literal introduces (puns) a no-gradient tensor named [name]. *)
        {
          default_result with
          typ = No_grad_tensor_intro { name; name_expr = expr };
          expr = A.Exp.ident ~loc:str_loc { txt = Lident name; loc = str_loc };
        }
    | { pexp_desc = Pexp_array _; _ }
    | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
        { default_result with expr = ndarray_op expr }
    | { pexp_desc = Pexp_ident { txt = Lident ("v" | "lhs"); _ }; _ } ->
        { default_result with typ = Array; slot = LHS }
    | { pexp_desc = Pexp_ident { txt = Lident "g"; _ }; _ } ->
        { default_result with typ = Array; slot = LHS }
    | { pexp_desc = Pexp_ident { txt = Lident "rhs1"; _ }; _ } ->
        { default_result with typ = Array; slot = RHS1 }
    | { pexp_desc = Pexp_ident { txt = Lident "t1"; _ }; _ } ->
        { default_result with slot = RHS1; expr }
    | { pexp_desc = Pexp_ident { txt = Lident "v1"; _ }; _ } ->
        { default_result with typ = Array; slot = RHS1; expr = [%expr t1.Tensor.value] }
    | { pexp_desc = Pexp_ident { txt = Lident "g1"; _ }; _ } ->
        {
          default_result with
          typ = Grad_of_tensor [%expr t1];
          slot = RHS1;
          expr = [%expr Option.map t1.Tensor.diff ~f:(fun d -> d.Tensor.grad)];
        }
    | { pexp_desc = Pexp_ident { txt = Lident "rhs2"; _ }; _ } ->
        { default_result with typ = Array; slot = RHS2 }
    | { pexp_desc = Pexp_ident { txt = Lident "t2"; _ }; _ } ->
        { default_result with typ = Tensor; slot = RHS2 }
    | { pexp_desc = Pexp_ident { txt = Lident "v2"; _ }; _ } ->
        { default_result with typ = Array; slot = RHS2; expr = [%expr t2.Tensor.value] }
    | { pexp_desc = Pexp_ident { txt = Lident "g2"; _ }; _ } ->
        {
          default_result with
          typ = Grad_of_tensor [%expr t2];
          slot = RHS2;
          expr = [%expr Option.map t2.Tensor.diff ~f:(fun d -> d.Tensor.grad)];
        }
    | { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
        default_result
    | [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
        (* FIXME: `**.` should take a tensor and require that it's a literal. *)
        (* We need to hardcode these two patterns to prevent the numbers from being converted to
           tensors. *)
        let res1 = loop ~proj_in_scope expr1 in
        {
          res1 with
          typ = Tensor;
          expr = [%expr NTDSL.O.( **. ) [%e res1.expr] (Float.of_int [%e i])];
        }
    | [%expr [%e? expr1] **. [%e? expr2]] ->
        let res1 = loop ~proj_in_scope expr1 in
        { res1 with typ = Tensor; expr = [%expr NTDSL.O.( **. ) [%e res1.expr] [%e expr2]] }
    | [%expr
        [%e? expr1]
        *+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]
             [%e? expr2]]
      when String.contains spec_str '>' ->
        (* Binary einsum with an explicit spec string. *)
        let res1 = loop ~proj_in_scope expr1 in
        let res2 = loop ~proj_in_scope expr2 in
        let slot =
          Option.value ~default:Undet
          @@ List.find ~f:(function Undet -> false | _ -> true) [ res1.slot; res2.slot ]
        in
        {
          vbs = reduce_vbss [ res1.vbs; res2.vbs ];
          typ = Tensor;
          slot;
          expr = [%expr NTDSL.einsum [%e spec] [%e res1.expr] [%e res2.expr]];
          array_opt_of_code = None;
        }
    | [%expr
        [%e? expr1]
        ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
      when String.contains spec_str '>' ->
        let res1 = loop ~proj_in_scope expr1 in
        { res1 with typ = Tensor; expr = [%expr NTDSL.einsum1 [%e spec] [%e res1.expr]] }
    | [%expr [%e? expr1].grad] -> (
        let res1 = loop ~proj_in_scope expr1 in
        match res1.typ with
        | Unknown | Tensor | No_grad_tensor_intro _ ->
            {
              res1 with
              typ = Grad_of_tensor expr1;
              expr = [%expr Option.map [%e res1.expr].Tensor.diff ~f:(fun d -> d.Tensor.grad)];
            }
        | Merge_value _ ->
            {
              res1 with
              typ = Merge_grad expr1;
              expr =
                Ast_builder.Default.pexp_extension ~loc
                @@ Location.error_extensionf ~loc
                     "ppx_ocannl %%cd: write .grad.merge instead of .merge.grad";
            }
        | Code | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_grad _ ->
            {
              res1 with
              typ = Array;
              expr =
                Ast_builder.Default.pexp_extension ~loc
                @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: only tensors have a gradient";
            })
    | [%expr [%e? expr1].value] -> (
        let res1 = loop ~proj_in_scope expr1 in
        (* TODO: maybe this is too permissive? E.g. [t1.grad.value] is accepted. *)
        match res1.typ with
        | Unknown | Tensor | No_grad_tensor_intro _ ->
            {
              res1 with
              typ = Value_of_tensor res1.expr;
              expr = [%expr [%e res1.expr].Tensor.value];
            }
        | Code ->
            {
              default_result with
              typ = Array;
              slot = res1.slot;
              expr =
                Ast_builder.Default.pexp_extension ~loc
                @@ Location.error_extensionf ~loc
                     "ppx_ocannl %%cd: <code>.value notation not supported when <code> is not a \
                      tensor";
            }
        | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_value _ | Merge_grad _ -> res1)
    | [%expr [%e? expr1].merge] -> (
        let res1 = loop ~proj_in_scope expr1 in
        match res1.typ with
        | Unknown | Tensor | No_grad_tensor_intro _ ->
            { res1 with typ = Merge_value res1.expr; expr = [%expr [%e res1.expr].Tensor.value] }
        | Value_of_tensor t ->
            { res1 with typ = Merge_value t; expr = [%expr [%e res1.expr].Tensor.value] }
        | Array | Code ->
            {
              res1 with
              typ = Array;
              expr =
                Ast_builder.Default.pexp_extension ~loc
                @@ Location.error_extensionf ~loc
                     "ppx_ocannl %%cd: only tensor nodes (e.g. `.value` or `.grad`) can be merged";
            }
        | Grad_of_tensor t -> { res1 with vbs = no_vbs; typ = Merge_grad t }
        | Merge_value _ | Merge_grad _ ->
            {
              res1 with
              expr =
                Ast_builder.Default.pexp_extension ~loc
                @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: repeated .merge not allowed";
            })
    | [%expr
        [%e? accu_op] [%e? lhs]
          ([%e? bin_op] [%e? rhs1] ([%e? rhs2] ~projections:[%e? projections]))] ->
        process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~projections ~proj_in_scope:true ()
    | [%expr [%e? accu_op] [%e? lhs] (([%e? un_op] [%e? rhs]) ~projections:[%e? projections])]
    | [%expr [%e? accu_op] [%e? lhs] ([%e? un_op] ([%e? rhs] ~projections:[%e? projections]))] ->
        let _, un_op = unary_op un_op in
        (* Handle both un_op priority levels -- where application binds tighter and less tight. *)
        process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~projections ~proj_in_scope:true ()
    | [%expr [%e? accu_op] [%e? lhs] ([%e? rhs] ~projections:[%e? projections])] ->
        process_assign_unop ~accu_op ~lhs ~un_op:[%expr Arrayjit.Ops.Identity] ~rhs ~projections
          ~proj_in_scope:true ()
    | [%expr
        [%e? accu_op] [%e? lhs]
          ([%e? bin_op] [%e? rhs1]
             ([%e? rhs2]
                ~logic:
                  [%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))] ->
        let logic =
          let loc = s_loc in
          if String.equal spec "." then [%expr Shape.Pointwise_bin]
          else if String.equal spec "@" then [%expr Shape.Compose]
          else [%expr Shape.Einsum [%e logic]]
        in
        let _, bin_op = binary_op bin_op in
        process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
    | [%expr
        [%e? accu_op] [%e? lhs]
          (([%e? un_op] [%e? rhs])
             ~logic:
               [%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
    | [%expr
        [%e? accu_op] [%e? lhs]
          ([%e? un_op]
             ([%e? rhs]
                ~logic:
                  [%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))] ->
        (* Handle both un_op priority levels -- where application binds tighter and less tight. *)
        let logic =
          let loc = s_loc in
          if String.equal spec "." then [%expr Shape.Pointwise_un]
          else if String.equal spec "T" then [%expr Shape.Transpose]
          else [%expr Shape.Permute [%e logic]]
        in
        let _, un_op = unary_op un_op in
        process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
    | [%expr
        [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op] [%e? lhs]
          ([%e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op]
             [%e? rhs1] [%e? rhs2])]
      when is_assignment accu_ident && is_binary_op binop_ident && proj_in_scope ->
        process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~proj_in_scope ()
    | [%expr
        [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op] [%e? lhs]
          ([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } as un_op] [%e? rhs])]
      when is_assignment accu_ident && is_unary_op unop_ident && proj_in_scope ->
        let _, un_op = unary_op un_op in
        process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~proj_in_scope ()
    | [%expr
        [%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op] [%e? lhs]
          [%e? rhs]]
      when is_assignment op_ident && proj_in_scope ->
        process_assign_unop ~accu_op ~lhs ~un_op:[%expr Arrayjit.Ops.Identity] ~rhs ~proj_in_scope
          ()
    | [%expr
        [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op] [%e? lhs]
          ([%e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op]
             [%e? rhs1] [%e? rhs2])]
      when is_assignment accu_ident && is_binary_op binop_ident ->
        (* No projections in scope: fall back to shape inference. *)
        let logic, bin_op = binary_op bin_op in
        process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
    | [%expr
        [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op] [%e? lhs]
          ([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } as un_op] [%e? rhs])]
      when is_assignment accu_ident && is_unary_op unop_ident ->
        let logic, un_op = unary_op un_op in
        process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
    | [%expr
        [%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op] [%e? lhs]
          [%e? rhs]]
      when is_assignment op_ident ->
        process_raw_unop ~accu_op ~lhs ~un_op:[%expr Arrayjit.Ops.Identity] ~rhs
          ~logic:[%expr Shape.Pointwise_un]
    | [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
        (* Generic 3-argument application: translate the parts, combine slots/vbs. *)
        let res1 = loop ~proj_in_scope expr1 in
        let res2 = loop ~proj_in_scope expr2 in
        let res3 = loop ~proj_in_scope expr3 in
        let slot =
          Option.value ~default:Undet
          @@ List.find ~f:(function Undet -> false | _ -> true)
               [ res1.slot; res2.slot; res3.slot ]
        in
        {
          vbs = reduce_vbss [ res1.vbs; res2.vbs; res3.vbs ];
          typ = res1.typ;
          slot;
          expr = [%expr [%e res1.expr] [%e res2.expr] [%e res3.expr]];
          array_opt_of_code = None;
        }
    | [%expr [%e? expr1] [%e? expr2]] ->
        let res1 = loop ~proj_in_scope expr1 in
        let res2 = loop ~proj_in_scope expr2 in
        let slot =
          Option.value ~default:Undet
          @@ List.find ~f:(function Undet -> false | _ -> true) [ res1.slot; res2.slot ]
        in
        {
          vbs = reduce_vbss [ res1.vbs; res2.vbs ];
          typ = res1.typ;
          slot;
          expr = [%expr [%e res1.expr] [%e res2.expr]];
          array_opt_of_code = None;
        }
    | { pexp_desc = Pexp_fun ((arg_label : arg_label), arg, pat, expr1); _ } as expr ->
        (* A ~projections (or ?projections) parameter brings projections into scope;
           parameter identifiers become bad pun-label hints. *)
        let proj_in_scope =
          proj_in_scope
          ||
          match arg_label with
          | (Labelled s | Optional s) when String.equal s "projections" -> true
          | _ -> false
        in
        let bad_pun_hints = Set.union bad_pun_hints @@ collect_pat_idents pat in
        let res1 = transl ~bad_pun_hints ~proj_in_scope expr1 in
        { res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, pat, res1.expr) } }
    | [%expr
        while [%e? _test_expr] do
          [%e? _body]
        done] ->
        {
          default_result with
          typ = Code;
          slot = Nonslot;
          expr =
            Ast_builder.Default.pexp_extension ~loc
            @@ Location.error_extensionf ~loc
                 "ppx_ocannl %%cd: while: low-level code embeddings not supported yet";
        }
    | [%expr
        for [%p? _pat] = [%e? _init] to [%e? _final] do
          [%e? _body_expr]
        done] ->
        {
          default_result with
          typ = Code;
          slot = Nonslot;
          expr =
            Ast_builder.Default.pexp_extension ~loc
            @@ Location.error_extensionf ~loc
                 "ppx_ocannl %%cd: for-to: low-level code embeddings not supported yet";
        }
    | [%expr
        for [%p? _pat] = [%e? _init] downto [%e? _final] do
          [%e? _body_expr]
        done] ->
        {
          default_result with
          typ = Code;
          slot = Nonslot;
          expr =
            Ast_builder.Default.pexp_extension ~loc
            @@ Location.error_extensionf ~loc
                 "ppx_ocannl %%cd: for-downto: low-level code embeddings not supported yet";
        }
    | [%expr
        ~~[%e? { pexp_desc = Pexp_apply (expr, exprs); pexp_loc; _ }];
        [%e? expr2]] ->
        (* ~~(...) introduces a block comment: string pieces are kept, tensor
           pieces contribute their debug names. *)
        let elements =
          expr :: List.map ~f:snd exprs
          |> List.map ~f:(function
               | { pexp_desc = Pexp_constant (Pconst_string _); _ } as s -> s
               | [%expr [%e? t].value] -> [%expr Arrayjit.Tnode.debug_name [%e t].value]
               | [%expr [%e? t].grad] -> [%expr Arrayjit.Tnode.debug_name [%e t].value ^ ".grad"]
               | t -> [%expr Arrayjit.Tnode.debug_name [%e t].value])
        in
        let res2 = loop ~proj_in_scope expr2 in
        {
          res2 with
          expr =
            [%expr
              Arrayjit.Assignments.Block_comment
                ( String.concat_array ~sep:" " [%e Ast_helper.Exp.array ~loc:pexp_loc elements],
                  [%e res2.expr] )];
        }
    | [%expr
        [%e? expr1];
        [%e? expr2]] ->
        let res1 = loop ~proj_in_scope expr1 in
        let res2 = loop ~proj_in_scope expr2 in
        {
          vbs = reduce_vbss [ res1.vbs; res2.vbs ];
          typ = Code;
          slot = Nonslot;
          expr = [%expr Arrayjit.Assignments.Seq ([%e res1.expr], [%e res2.expr])];
          array_opt_of_code = res2.array_opt_of_code;
        }
    | [%expr if [%e? expr1] then [%e? expr2] else [%e? expr3]] ->
        let res2 = loop ~proj_in_scope expr2 in
        let res3 = loop ~proj_in_scope expr3 in
        let typ = if is_unknown res2.typ then res3.typ else res2.typ in
        let slot =
          Option.value ~default:Undet
          @@ List.find ~f:(function Undet -> false | _ -> true) [ res2.slot; res3.slot ]
        in
        {
          vbs = reduce_vbss [ res2.vbs; res3.vbs ];
          typ;
          slot;
          expr = [%expr if [%e expr1] then [%e res2.expr] else [%e res3.expr]];
          array_opt_of_code = None;
        }
    | [%expr if [%e? expr1] then [%e? expr2]] ->
        let res2 = loop ~proj_in_scope expr2 in
        {
          vbs = res2.vbs;
          typ = Code;
          slot = Nonslot;
          expr = [%expr if [%e expr1] then [%e res2.expr] else Arrayjit.Assignments.Noop];
          array_opt_of_code = res2.array_opt_of_code;
        }
    | { pexp_desc = Pexp_match (expr1, cases); _ } ->
        (* Translate each branch; combine their punned bindings, first known
           type, and first determined slot. *)
        let fields, cases =
          List.unzip
          @@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
                 let res = loop ~proj_in_scope pc_rhs in
                 ((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
        in
        let vbss, typs, slots = List.unzip3 fields in
        let typ = Option.value ~default:Unknown @@ List.find typs ~f:(Fn.non is_unknown) in
        let slot =
          Option.value ~default:Undet
          @@ List.find ~f:(function Undet -> false | _ -> true) slots
        in
        {
          vbs = reduce_vbss vbss;
          typ;
          slot;
          expr = { expr with pexp_desc = Pexp_match (expr1, cases) };
          array_opt_of_code = None;
        }
    | { pexp_desc = Pexp_let (_recflag, _bindings, _body); _ } ->
        (* TODO(80): to properly support local bindings, we need to collect the type environment. *)
        {
          default_result with
          typ = Unknown;
          expr =
            Ast_builder.Default.pexp_extension ~loc
            @@ Location.error_extensionf ~loc
                 "ppx_ocannl %%cd: let-in: local let-bindings not implemented yet";
        }
        (* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=loop
           binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, loop body)} *)
    | { pexp_desc = Pexp_open (decl, expr1); _ } ->
        let res1 = loop ~proj_in_scope expr1 in
        { res1 with expr = { expr with pexp_desc = Pexp_open (decl, res1.expr) } }
    | { pexp_desc = Pexp_letmodule (name, module_expr, expr1); _ } ->
        let res1 = loop ~proj_in_scope expr1 in
        { res1 with expr = { expr with pexp_desc = Pexp_letmodule (name, module_expr, res1.expr) } }
    | { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
        (* NOTE(review): this case looks shadowed by the identical guarded case
           earlier in the match — likely dead; kept verbatim to preserve the
           original structure. *)
        { default_result with typ = Unknown; expr = [%expr [%e expr]] }
    | _ -> { default_result with typ = Unknown }
  in
  transl ~proj_in_scope:false ~bad_pun_hints:(Set.empty (module String)) expr

(** [translate ?ident_label expr] is the entry point used by the expanders:
    returns the punned bindings and the translated body wrapped in
    [open! NTDSL.O] (and, for named bindings, [Tensor.with_unchanged_roots]). *)
let translate ?ident_label expr =
  let res = translate expr in
  let expr = res.expr in
  let loc = res.expr.pexp_loc in
  ( res.vbs,
    match ident_label with
    | Some [%pat? _] ->
        [%expr
          Tensor.with_unchanged_roots ~f:(fun () ->
              let open! NTDSL.O in
              [%e expr])]
    | _ ->
        [%expr
          let open! NTDSL.O in
          [%e expr]] )

let expr_expander ~loc ~path = expr_expander_with_punning translate ~loc ~path
let str_expander ~loc ~path = str_expander_with_punning translate ~loc ~path