(* NOTE(review): extraction artifact removed here — the original file's line-number gutter
   (concatenated digits 1..1277) was fused into the source text and is not OCaml code. *)
open Base
open Ppxlib
open Ppx_arrayjit.Ppx_helper
open Ppx_shared
module A = Ppxlib_ast.Ast_helper

(* Builds an [NTDSL.ndarray] application from a literal array/list expression, threading the
   optional [?axis_labels] and [?label] arguments and the dimensions recovered by
   [ndarray_constant]. *)
let ndarray_op ?axis_labels ?label expr =
  let loc = expr.pexp_loc in
  let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
  let edims dims = Ast_builder.Default.elist ~loc dims in
  let op =
    match (axis_labels, label) with
    | None, None -> [%expr NTDSL.ndarray]
    | Some axis_labels, None -> [%expr NTDSL.ndarray ~axis_labels:[%e axis_labels]]
    | None, Some label -> [%expr NTDSL.ndarray ~label:[%e label]]
    | Some axis_labels, Some label ->
        [%expr NTDSL.ndarray ~axis_labels:[%e axis_labels] ~label:[%e label]]
  in
  [%expr
    [%e op] ~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
      ~output_dims:[%e edims output_dims] [%e values]]

(* Classification of a translated sub-expression of the [%cd] DSL. *)
type expr_type =
  | Code
  | Array
  | Value_of_tensor of expression
  | Grad_of_tensor of expression
  | Tensor
  | Unknown
  | Merge_value of expression
  | Merge_grad of expression
  | No_grad_tensor_intro of { name : string; name_expr : expression }

let is_unknown = function Unknown -> true | _ -> false

(* Which slot of a projections record a sub-expression fills. *)
type projections_slot = LHS | RHS1 | RHS2 | RHS3 | Scalar | Nonslot | Undet
[@@deriving equal, sexp]

type result = {
  vbs : value_binding Map.M(String).t;
      (** [vbs] are the bindings introduced by inline tensor declarations (aka. punning). These
          bindings are discharged with the whole [%cd] extension scope in scope. *)
  typ : expr_type;
  slot : projections_slot;
  expr : expression;
      (** Depending on {!field-typ}, of type:
          - if [Code]: [Assignments.comp];
          - if [Array | Merge_value | Value_of_tensor]: [Tnode.t];
          - if [Merge_grad | Grad_of_tensor]: [Tnode.t option];
          - if [Tensor | Unknown | No_grad_tensor_intro]: [Tensor.t]. *)
  array_opt_of_code : expression option;
      (** Of type: [Tnode.t]. It keeps track of which tensor node to use when [typ] is [Code] but
          the result is used in an array context. *)
}

type array_setup = {
  vb : value_binding option;
      (** This binding is only generated for a tensor expression that is not an identifier, since
          recomputing the expression (when copied) would generate a fresh tensor. It is discharged
          when an assignment is built. *)
  slot : projections_slot;
  filler_typ : expr_type;
  fwd_code_or_noop : expression option;  (** Of type: [Assignments.comp]. *)
  array_opt : expression;
      (** Of type: if [slot = LHS] then [Tnode.t] else [Assignments.buffer]. *)
  tensor : expression option;  (** Of type: [Tensor.t]. *)
  pun_hint_tnode : (expression * bool) option;
      (** Of type: [string list]. The tnode, if any, whose label the relevant punned no-gradient
          tensor should incorporate in its label. The bool denotes whether this is a preferred
          (high quality) guess. *)
}

(* Creates the value binding for a punned (inline-declared) no-gradient tensor, optionally
   extending its label with a hint label list taken from a related tensor node. *)
let make_vb ~loc ~name ~name_expr ~hint_label =
  let pat = A.Pat.var ~loc { loc = name_expr.pexp_loc; txt = name } in
  let v =
    match hint_label with
    | None -> [%expr NTDSL.term ~label:[ [%e name_expr] ] ()]
    | Some hint_label -> [%expr NTDSL.term ~label:([%e name_expr] :: [%e hint_label]) ()]
  in
  let vb = A.Vb.mk ~loc pat v in
  vb

(* Folds the embedded-node sets of the setups' forward code into a single union expression;
   [None] when no setup carries forward code. *)
let reduce_embs_arr ~loc (rs : array_setup list) =
  List.filter_map rs ~f:(fun hs -> hs.fwd_code_or_noop)
  |> List.reduce ~f:(fun embs comp -> [%expr Base.Set.union [%e embs] [%e comp].embedded_nodes])

(** The expression argument is of type: [Assignments.t]. *)
let assignment ~punned ~lhs ~rhses body =
  let setups = lhs :: rhses in
  let loc = body.pexp_loc in
  let forward_args = List.filter_map setups ~f:(fun { fwd_code_or_noop; _ } -> fwd_code_or_noop) in
  (* If the LHS introduces a punned tensor, register it (rejecting duplicates) and pick the best
     available label hint from the RHS setups. *)
  let vbs, body =
    match lhs.filler_typ with
    | No_grad_tensor_intro { name; name_expr } -> (
        let good_hints, bad_hints =
          List.partition_tf ~f:snd @@ List.filter_map rhses ~f:(fun sup -> sup.pun_hint_tnode)
        in
        let hint_data = Option.first_some (List.hd good_hints) (List.hd bad_hints) in
        let hint_label = Option.map ~f:fst hint_data in
        let vbs =
          Map.singleton (module String) name @@ make_vb ~loc ~name ~name_expr ~hint_label
        in
        match hint_data with
        | None -> (vbs, body)
        | Some data -> (
            match Hashtbl.add punned ~key:name ~data with
            | `Ok -> (vbs, body)
            | `Duplicate ->
                ( no_vbs,
                  Ast_builder.Default.pexp_extension ~loc
                  @@ Location.error_extensionf ~loc
                       "ppx_ocannl %%cd: duplicate inline declaration of no-gradient tensor %s"
                       name )))
    | _ -> (no_vbs, body)
  in
  let body =
    if Option.is_some lhs.vb then
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: the assigned-to position cannot be an expression building a new \
            tensor"
    else body
  in
  let tensor_vbs = List.filter_map rhses ~f:(fun rhs -> rhs.vb) in
  let body =
    [%expr { asgns = [%e body]; embedded_nodes = Base.Set.empty (module Arrayjit.Tnode) }]
  in
  (* Sequence the forward code of all setups before the assignment itself. *)
  let comps =
    List.fold (body :: List.rev forward_args) ~init:[%expr []] ~f:(fun xs x ->
        [%expr [%e x] :: [%e xs]])
  in
  let expr = [%expr Arrayjit.Assignments.sequence [%e comps]] in
  let expr =
    if List.is_empty tensor_vbs then expr else A.Exp.let_ ~loc Nonrecursive tensor_vbs expr
  in
  { vbs; typ = Code; slot = Nonslot; expr; array_opt_of_code = Some lhs.array_opt }

(* Maps a slot to the matching projection field of a projections record bound to [p]. *)
let project_p_slot debug loc slot =
  match slot with
  | LHS -> [%expr p.project_lhs]
  | RHS1 -> [%expr p.project_rhs.(0)]
  | RHS2 -> [%expr p.project_rhs.(1)]
  | RHS3 -> [%expr p.project_rhs.(2)]
  | Scalar -> [%expr [| Arrayjit.Indexing.Fixed_idx 0 |]]
  | Nonslot ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s" debug
  | Undet ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug
           "(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"

(* Maps a slot to the matching dimensions field of a projections record bound to [p]. *)
let project_p_dims debug loc slot =
  match slot with
  | LHS -> [%expr p.lhs_dims]
  | RHS1 -> [%expr p.rhs_dims.(0)]
  | RHS2 -> [%expr p.rhs_dims.(1)]
  | RHS3 -> [%expr p.rhs_dims.(2)]
  | Scalar -> [%expr [| 1 |]]
  | Nonslot ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s" debug
  | Undet ->
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug
           "(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"

(* Guesses which tnode's label a punned no-gradient tensor should incorporate. Returns the label
   expression paired with a flag meaning "high-quality guess"; [None] when no hint is usable.
   Identifiers listed in [bad_pun_hints] are never used as hints. *)
let guess_pun_hint ~punned ~bad_pun_hints filler_typ filler =
  let loc = filler.pexp_loc in
  let hint = [%expr [%e filler].Arrayjit.Tnode.label] in
  match (filler_typ, filler) with
  | Code, _ -> None
  | _, { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } when Set.mem bad_pun_hints name ->
      None
  | Array, _ -> Some (hint, false)
  | (Tensor | Unknown), { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
    when Hashtbl.mem punned name ->
      Hashtbl.find punned name
  | (Tensor | Unknown), { pexp_desc = Pexp_ident _; _ } -> Some (hint, true)
  | (Tensor | Unknown), _ -> Some (hint, false)
  | ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Merge_grad { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } ),
      _ )
    when Set.mem bad_pun_hints name ->
      None
  | ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
      | Merge_grad { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } ),
      _ )
    when Hashtbl.mem punned name ->
      Hashtbl.find punned name
  | (Value_of_tensor t | Grad_of_tensor t | Merge_value t | Merge_grad t), _ -> (
      let hint = [%expr [%e t].Tensor.value.Arrayjit.Tnode.label] in
      match t with
      | { pexp_desc = Pexp_ident _; _ } -> Some (hint, true)
      | _ -> Some (hint, false))
  | No_grad_tensor_intro { name; _ }, _ -> Hashtbl.find punned name

let empty_tns ~loc = [%expr Base.Set.empty (module Arrayjit.Tnode)]

let empty_comp ~loc =
  [%expr { asgns = Arrayjit.Assignments.Noop; embedded_nodes = [%e empty_tns ~loc] }]

(* Normalizes a translated sub-expression into an [array_setup]: decides the forward code to
   splice in (if any), the array/buffer expression for the assignment position, the tensor it
   originates from, and the pun-label hint. [is_lhs] selects between the bare [Tnode.t] form and
   the [Assignments.buffer] form of the array expression. *)
let setup_array ~punned ~bad_pun_hints ~is_lhs
    { typ = filler_typ; slot; expr = filler; vbs; array_opt_of_code } =
  assert (Map.is_empty vbs);
  let loc = filler.pexp_loc in
  let opt_buffer tn =
    if is_lhs then [%expr Some [%e tn]] else [%expr Some (Arrayjit.Assignments.Node [%e tn])]
  in
  let buffer opt_tn =
    if is_lhs then opt_tn
    else [%expr Option.map [%e opt_tn] ~f:(fun tn -> Arrayjit.Assignments.Node tn)]
  in
  let pun_hint_tnode = guess_pun_hint ~punned ~bad_pun_hints filler_typ filler in
  let default_setup =
    {
      vb = None;
      fwd_code_or_noop = None;
      filler_typ;
      slot;
      array_opt = opt_buffer [%expr [%e filler].Tensor.value];
      tensor = None;
      pun_hint_tnode;
    }
  in
  match filler_typ with
  | No_grad_tensor_intro _ when not is_lhs ->
      {
        default_setup with
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: punning is only allowed in the assigned-to position";
      }
  | (Tensor | Unknown)
    when (match filler with { pexp_desc = Pexp_ident _; _ } -> true | _ -> false) ->
      let t = filler in
      let fwd_code_or_noop =
        Some
          [%expr
            if Tensor.is_fwd_root [%e t] then (
              Tensor.remove_fwd_root [%e t];
              [%e t].Tensor.forward)
            else [%e empty_comp ~loc]]
      in
      { default_setup with fwd_code_or_noop; tensor = Some t }
  | Value_of_tensor ({ pexp_desc = Pexp_ident _; _ } as t) ->
      let fwd_code_or_noop =
        Some
          [%expr
            if Tensor.is_fwd_root [%e t] then (
              Tensor.remove_fwd_root [%e t];
              [%e t].Tensor.forward)
            else [%e empty_comp ~loc]]
      in
      {
        default_setup with
        fwd_code_or_noop;
        array_opt = opt_buffer [%expr [%e t].Tensor.value];
        tensor = Some t;
      }
  | Value_of_tensor t ->
      {
        default_setup with
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: the <tensor>.value notation is only supported when <tensor> is \
                an identifier";
        tensor = Some t;
      }
  | Tensor | Unknown ->
      (* Need to bind the expression computing the tensor so we don't recompute it. *)
      let v =
        match slot with
        | LHS -> [%pat? nondiff__lhs]
        | RHS1 -> [%pat? nondiff__rhs1]
        | RHS2 -> [%pat? nondiff__rhs2]
        | RHS3 -> [%pat? nondiff__rhs3]
        | Scalar | Nonslot | Undet -> [%pat? nondiff__tensor]
      in
      let t = pat2expr v in
      let vb = Some (A.Vb.mk ~loc v filler) in
      let fwd_code_or_noop =
        Some
          [%expr
            if Tensor.is_fwd_root [%e t] then (
              Tensor.remove_fwd_root [%e t];
              [%e t].Tensor.forward)
            else [%e empty_comp ~loc]]
      in
      {
        default_setup with
        vb;
        fwd_code_or_noop;
        array_opt = opt_buffer [%expr [%e t].Tensor.value];
        tensor = Some t;
      }
  | No_grad_tensor_intro _ ->
      (* Inline tensors are guaranteed to be leaf tensors, so they don't have forward code, but
         they are embedded. *)
      let fwd_code_or_noop =
        Some
          [%expr
            Arrayjit.Assignments.
              {
                asgns = Noop;
                embedded_nodes =
                  Base.Set.singleton (module Arrayjit.Tnode) [%e filler].Tensor.value;
              }]
      in
      { default_setup with fwd_code_or_noop; tensor = Some filler }
  | Code when Option.is_none array_opt_of_code ->
      {
        default_setup with
        fwd_code_or_noop = Some filler;
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: could not determine a lead array of provided code";
      }
  | Code ->
      {
        default_setup with
        fwd_code_or_noop = Some filler;
        array_opt = buffer (Option.value_exn array_opt_of_code);
      }
  | Array -> { default_setup with array_opt = opt_buffer filler }
  | Grad_of_tensor ({ pexp_desc = Pexp_ident _; _ } as t) ->
      { default_setup with array_opt = buffer filler; tensor = Some t }
  | Grad_of_tensor t ->
      {
        default_setup with
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "ppx_ocannl %%cd: the <tensor>.grad notation is only supported when <tensor> is \
                an identifier";
        tensor = Some t;
      }
  | (Merge_value _ | Merge_grad _) when is_lhs ->
      {
        default_setup with
        array_opt =
          Ast_builder.Default.pexp_extension ~loc
          @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: merge buffers cannot be assigned to";
      }
  | Merge_value t ->
      { default_setup with array_opt = [%expr Some (Merge_buffer [%e filler])]; tensor = Some t }
  | Merge_grad t ->
      {
        default_setup with
        array_opt =
          [%expr Option.map [%e filler] ~f:(fun tn -> Arrayjit.Assignments.Merge_buffer tn)];
        tensor = Some t;
      }

(* For the [Tensor.raw_*op] code paths: extracts from a setup the tensor expression plus the
   [*_is_grad] and [*_is_merge] flags it implies. *)
let args_for ~loc = function
  | { filler_typ = Merge_grad _; tensor = Some t; _ } -> (t, [%expr true], [%expr true])
  | { filler_typ = Grad_of_tensor _; tensor = Some t; _ } -> (t, [%expr true], [%expr false])
  | { filler_typ = Merge_value _; tensor = Some t; _ } -> (t, [%expr false], [%expr true])
  | { filler_typ = _; tensor = Some t; _ } -> (t, [%expr false], [%expr false])
  | _ ->
      ( Ast_builder.Default.pexp_extension ~loc
        @@ Location.error_extensionf ~loc
             "ppx_ocannl %%cd: cannot use `~logic` (infer shapes) for arrays, use tensors or \
              `.value` or `.grad` notation",
        [%expr false],
        [%expr false] )
`.value` or `.grad` notation",[%exprfalse],[%exprfalse])letreduce_res_vbsrs=reduce_vbss@@List.maprs~f:(funr->r.vbs)lettranslate(expr:expression):result=letpunned=Hashtbl.create(moduleString)inletrectransl~bad_pun_hints~proj_in_scope(expr:expression):result=letloc=expr.pexp_locinletdefault_result={vbs=no_vbs;typ=Tensor;slot=Undet;expr;array_opt_of_code=None}inletloop=transl~bad_pun_hintsinletassignment_opaccu_op=loc|>Option.value_or_thunk(Hashtbl.findassignment_opsaccu_op)~default:(fun()_loc->(false,Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected an assignment operator, one of: %s %s""=+ (Add), =- (Sub), =* (Mul),=/ (Div), =** (ToPowOf), =?/ (Relu_gate), =?^ \
(Satur01_gate), =|| (Or), =&& (And), =@^ (Max), =^^ (Min), =: (Arg2),=:+, \
=:-,"" =:*, =:/, =:**, =:?/, =:?^, =:||, =:&&, =:@^, =:^^ (same with initializing \
the tensor to the neutral value before the start of the calculation)"))inletunary_opun_op=loc|>Option.value_or_thunk(Hashtbl.findunary_opsun_op)~default:(fun()loc->([%exprShape.Pointwise_un],Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected an assignment operator, one of: %s""id, relu, sat01, exp, log, exp2, log2, sin, cos, sqrt, recip, recip_sqrt, \
neg, tanh"))inletbinary_opbin_op=loc|>Option.value_or_thunk(Hashtbl.findbinary_opsbin_op)~default:(fun()_loc->([%exprShape.Pointwise_bin],Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected a binary operator, one of: %s""+ (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -?^ \
(Satur01_gate), -/> (Arg2), < (Cmplt), = (Cmpeq), <> (Cmpne), || (Or), && \
(And), % (Mod), @^(Max), ^^ (Min)"))inletternary_optern_op=loc|>Option.value_or_thunk(Hashtbl.findternary_opstern_op)~default:(fun()_loc->([%exprShape.Pointwise_tern],Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected a ternary operator, one of: where, fma"))in(* FIXME: collapse these (code reuse) *)letprocess_assign_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3?projections~proj_in_scope()=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scope:truelhsinlet_,tern_op=ternary_optern_opinletsetup_r1=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs1inletsetup_r2=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs2inletsetup_r3=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs3inletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inletprojections=matchprojectionswith|Someprjs->prjs|None->letlhs_dims=project_p_dims"LHS"lhs.pexp_locsetup_l.slotinletrhs1_dims=project_p_dims"RHS1"lhs.pexp_locsetup_r1.slotinletrhs2_dims=project_p_dims"RHS2"lhs.pexp_locsetup_r2.slotinletrhs3_dims=project_p_dims"RHS3"lhs.pexp_locsetup_r3.slotinletproject_lhs=project_p_slot"LHS"lhs.pexp_locsetup_l.slotinletproject_rhs1=project_p_slot"RHS1"rhs1.pexp_locsetup_r1.slotinletproject_rhs2=project_p_slot"RHS2"rhs2.pexp_locsetup_r2.slotinletproject_rhs3=project_p_slot"RHS3"rhs3.pexp_locsetup_r3.slotin[%exprlazy(letp=Lazy.forceprojectionsinArrayjit.Indexing.{product_space=p.product_space;product_iterators=p.product_iterators;lhs_dims=[%elhs_dims];rhs_dims=[|[%erhs1_dims];[%erhs2_dims];[%erhs3_dims]|];project_lhs=[%eproject_lhs];project_rhs=[|[%eproject_rhs1];[%eproject_rhs2];[%eproject_rhs3]|];debug_info={p.debug_infowithtrace=("ppx_cd "^[%eexpr2string_or_emptyaccu_op]^" "^[%eexpr2string_or_emptytern_op],Arrayjit.Indexing.unique_debug_id())::p.debug_info.trace;};})]in(* FIXME: might be better to treat missing [rhs1, 
rhs2, rhs3] as zeros or errors rather than
eliding the code. *)letbody=[%exprOption.value~default:Arrayjit.Assignments.Noop@@Option.map3[%esetup_r1.array_opt][%esetup_r2.array_opt][%esetup_r3.array_opt]~f:(funrhs1rhs2rhs3->Arrayjit.Assignments.Accum_ternop{initialize_neutral=[%einitialize_neutral];accum=[%eaccu_op];lhs=Option.value_exn[%esetup_l.array_opt];op=[%etern_op];rhs1;rhs2;rhs3;projections=[%eprojections];})]inassignment~punned~lhs:setup_l~rhses:[setup_r1;setup_r2;setup_r3]bodyinletprocess_assign_binop~accu_op~lhs~bin_op~rhs1~rhs2?projections~proj_in_scope()=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scope:truelhsinlet_,bin_op=binary_opbin_opinletsetup_r1=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs1inletsetup_r2=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs2inletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inletprojections=matchprojectionswith|Someprjs->prjs|None->letlhs_dims=project_p_dims"LHS"lhs.pexp_locsetup_l.slotinletrhs1_dims=project_p_dims"RHS1"lhs.pexp_locsetup_r1.slotinletrhs2_dims=project_p_dims"RHS2"lhs.pexp_locsetup_r2.slotinletproject_lhs=project_p_slot"LHS"lhs.pexp_locsetup_l.slotinletproject_rhs1=project_p_slot"RHS1"rhs1.pexp_locsetup_r1.slotinletproject_rhs2=project_p_slot"RHS2"rhs2.pexp_locsetup_r2.slotin[%exprlazy(letp=Lazy.forceprojectionsinArrayjit.Indexing.{product_space=p.product_space;product_iterators=p.product_iterators;lhs_dims=[%elhs_dims];rhs_dims=[|[%erhs1_dims];[%erhs2_dims]|];project_lhs=[%eproject_lhs];project_rhs=[|[%eproject_rhs1];[%eproject_rhs2]|];debug_info={p.debug_infowithtrace=("ppx_cd "^[%eexpr2string_or_emptyaccu_op]^" "^[%eexpr2string_or_emptybin_op],Arrayjit.Indexing.unique_debug_id())::p.debug_info.trace;};})]in(* TODO: might be better to treat missing [rhs1, rhs2] as zeros or errors rather than eliding
the code. *)letbody=[%exprOption.value~default:Arrayjit.Assignments.Noop@@Option.map3[%esetup_l.array_opt][%esetup_r1.array_opt][%esetup_r2.array_opt]~f:(funlhsrhs1rhs2->Arrayjit.Assignments.Accum_binop{initialize_neutral=[%einitialize_neutral];accum=[%eaccu_op];lhs;op=[%ebin_op];rhs1;rhs2;projections=[%eprojections];})]inassignment~punned~lhs:setup_l~rhses:[setup_r1;setup_r2]bodyinletprocess_assign_unop~accu_op~lhs~un_op~rhs?projections~proj_in_scope()=letinitialize_neutral,accum=assignment_opaccu_opinlet_,op=unary_opun_opin(* FIXME: I think this ignores the slot information here! Just assuming [projections] is
as-should-be, but that's not consistent with omitting the projections arg (assuming it
comes from the context). *)letsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhsinletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inletprojections=matchprojectionswith|Someprjs->prjs|None->letlhs_dims=project_p_dims"LHS"lhs.pexp_locsetup_l.slotinletrhs1_dims=project_p_dims"RHS1"lhs.pexp_locsetup_r.slotinletproject_lhs=project_p_slot"LHS"lhs.pexp_locsetup_l.slotinletproject_rhs1=project_p_slot"RHS1"rhs.pexp_locsetup_r.slotin[%exprlazy(letp=Lazy.forceprojectionsinArrayjit.Indexing.{product_space=p.product_space;product_iterators=p.product_iterators;lhs_dims=[%elhs_dims];rhs_dims=[|[%erhs1_dims]|];project_lhs=[%eproject_lhs];project_rhs=[|[%eproject_rhs1]|];debug_info={p.debug_infowithtrace=("ppx_cd "^[%estring_expr~locaccu_op]^" "^[%estring_expr~locun_op],Arrayjit.Indexing.unique_debug_id())::p.debug_info.trace;};})]in(* TODO: might be better to treat missing [rhs] as zeros or errors rather than eliding the
code. *)letbody=[%exprOption.value~default:Arrayjit.Assignments.Noop@@Option.map2[%esetup_l.array_opt][%esetup_r.array_opt]~f:(funlhsrhs->Arrayjit.Assignments.Accum_unop{initialize_neutral=[%einitialize_neutral];accum=[%eaccum];lhs;op=[%eop];rhs;projections=[%eprojections];})]inassignment~punned~lhs:setup_l~rhses:[setup_r]bodyinletprocess_raw_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~logic=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r1=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs1inletsetup_r2=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs2inletsetup_r3=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs3inletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inlett_expr,lhs_is_grad,_=args_for~locsetup_linlett1_expr,rhs1_is_grad,rhs1_is_merge=args_for~locsetup_r1inlett2_expr,rhs2_is_grad,rhs2_is_merge=args_for~locsetup_r2inlett3_expr,rhs3_is_grad,rhs3_is_merge=args_for~locsetup_r3inletbody=[%exprTensor.raw_ternop~initialize_neutral:[%einitialize_neutral]~accum:[%eaccu_op]~t:[%et_expr]~lhs_is_grad:[%elhs_is_grad]~op:[%etern_op]~t1:[%et1_expr]~rhs1_is_grad:[%erhs1_is_grad]~rhs1_is_merge:[%erhs1_is_merge]~t2:[%et2_expr]~rhs2_is_grad:[%erhs2_is_grad]~rhs2_is_merge:[%erhs2_is_merge]~t3:[%et3_expr]~rhs3_is_grad:[%erhs3_is_grad]~rhs3_is_merge:[%erhs3_is_merge]~logic:[%elogic]]inassignment~punned~lhs:setup_l~rhses:[setup_r1;setup_r2;setup_r3]bodyinletprocess_raw_binop~accu_op~lhs~bin_op~rhs1~rhs2~logic=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r1=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs1inletsetup_r2=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhs2inletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inlett_expr,lhs_is_grad,_=args_f
or~locsetup_linlett1_expr,rhs1_is_grad,rhs1_is_merge=args_for~locsetup_r1inlett2_expr,rhs2_is_grad,rhs2_is_merge=args_for~locsetup_r2inletbody=[%exprTensor.raw_binop~initialize_neutral:[%einitialize_neutral]~accum:[%eaccu_op]~t:[%et_expr]~lhs_is_grad:[%elhs_is_grad]~op:[%ebin_op]~t1:[%et1_expr]~rhs1_is_grad:[%erhs1_is_grad]~rhs1_is_merge:[%erhs1_is_merge]~t2:[%et2_expr]~rhs2_is_grad:[%erhs2_is_grad]~rhs2_is_merge:[%erhs2_is_merge]~logic:[%elogic]]inassignment~punned~lhs:setup_l~rhses:[setup_r1;setup_r2]bodyinletprocess_raw_unop~accu_op~lhs~un_op~rhs~logic=letinitialize_neutral,accu_op=assignment_opaccu_opinletsetup_l=setup_array~punned~bad_pun_hints~is_lhs:true@@loop~proj_in_scopelhsinletsetup_r=setup_array~punned~bad_pun_hints~is_lhs:false@@loop~proj_in_scoperhsinletinitialize_neutral=ifinitialize_neutralthen[%exprtrue]else[%exprfalse]inlett_expr,lhs_is_grad,_=args_for~locsetup_linlett1_expr,rhs_is_grad,rhs_is_merge=args_for~locsetup_rinletbody=[%exprTensor.raw_unop~initialize_neutral:[%einitialize_neutral]~accum:[%eaccu_op]~t:[%et_expr]~lhs_is_grad:[%elhs_is_grad]~op:[%eun_op]~t1:[%et1_expr]~rhs_is_grad:[%erhs_is_grad]~rhs_is_merge:[%erhs_is_merge]~logic:[%elogic]]inassignment~punned~lhs:setup_l~rhses:[setup_r]bodyinmatchexprwith|{pexp_desc=Pexp_constant(Pconst_float_);_}->{default_resultwithexpr=[%exprNTDSL.number[%eexpr]];slot=Scalar}|{pexp_desc=Pexp_constant(Pconst_integer_);_}->{default_resultwithexpr=[%exprNTDSL.number(Float.of_int[%eexpr])];slot=Scalar}|[%expr[%e?{pexp_desc=Pexp_constant(Pconst_charch);pexp_loc;_}][%e?{pexp_desc=Pexp_constant(Pconst_float_);_}asf]]->letaxis=Ast_helper.Exp.constant~loc:pexp_loc(Pconst_string(String.of_charch,pexp_loc,None))in{default_resultwithexpr=[%exprNTDSL.number~axis_label:[%eaxis][%ef]];slot=Scalar;}|[%expr[%e?{pexp_desc=Pexp_constant(Pconst_charch);pexp_loc;_}][%e?{pexp_desc=Pexp_constant(Pconst_integer_);_}asi]]->letaxis=Ast_helper.Exp.constant~loc:pexp_loc(Pconst_string(String.of_charch,pexp_loc,None))in{default_resu
ltwithexpr=[%exprNTDSL.number~axis_label:[%eaxis](Float.of_int[%ei])];slot=Scalar;}|{pexp_desc=Pexp_constant(Pconst_string(name,str_loc,_));_}->{default_resultwithtyp=No_grad_tensor_intro{name;name_expr=expr};expr=A.Exp.ident~loc:str_loc{txt=Lidentname;loc=str_loc};}|{pexp_desc=Pexp_array_;_}|{pexp_desc=Pexp_construct({txt=Lident"::";_},_);_}->{default_resultwithexpr=ndarray_opexpr}|{pexp_desc=Pexp_ident{txt=Lident("v"|"lhs");_};_}->{default_resultwithtyp=Array;slot=LHS}|{pexp_desc=Pexp_ident{txt=Lident"g";_};_}->{default_resultwithtyp=Array;slot=LHS}|{pexp_desc=Pexp_ident{txt=Lident"rhs1";_};_}->{default_resultwithtyp=Array;slot=RHS1}|{pexp_desc=Pexp_ident{txt=Lident"t";_};_}->{default_resultwithslot=LHS}|{pexp_desc=Pexp_ident{txt=Lident"t1";_};_}->{default_resultwithslot=RHS1}|{pexp_desc=Pexp_ident{txt=Lident"v1";_};_}->{default_resultwithtyp=Array;slot=RHS1;expr=[%exprt1.Tensor.value]}|{pexp_desc=Pexp_ident{txt=Lident"g1";_};_}->{default_resultwithtyp=Grad_of_tensor[%exprt1];slot=RHS1;expr=[%exprOption.mapt1.Tensor.diff~f:(fund->d.Tensor.grad)];}|{pexp_desc=Pexp_ident{txt=Lident"rhs2";_};_}->{default_resultwithtyp=Array;slot=RHS2}|{pexp_desc=Pexp_ident{txt=Lident"t2";_};_}->{default_resultwithtyp=Tensor;slot=RHS2}|{pexp_desc=Pexp_ident{txt=Lident"v2";_};_}->{default_resultwithtyp=Array;slot=RHS2;expr=[%exprt2.Tensor.value]}|{pexp_desc=Pexp_ident{txt=Lident"g2";_};_}->{default_resultwithtyp=Grad_of_tensor[%exprt2];slot=RHS2;expr=[%exprOption.mapt2.Tensor.diff~f:(fund->d.Tensor.grad)];}|{pexp_desc=Pexp_ident{txt=Lident"rhs3";_};_}->{default_resultwithtyp=Array;slot=RHS3}|{pexp_desc=Pexp_ident{txt=Lident"t3";_};_}->{default_resultwithtyp=Tensor;slot=RHS3}|{pexp_desc=Pexp_ident{txt=Lident"v3";_};_}->{default_resultwithtyp=Array;slot=RHS3;expr=[%exprt3.Tensor.value]}|{pexp_desc=Pexp_ident{txt=Lident"g3";_};_}->{default_resultwithtyp=Grad_of_tensor[%exprt3];slot=RHS3;expr=[%exprOption.mapt3.Tensor.diff~f:(fund->d.Tensor.grad)];}|{pexp_desc=Pexp_ident{txt=Lidentop_ident
;_};_}whenis_primitive_opop_ident->default_result|[%expr!.[%e?expr1]]->(* Hardcoding these two patterns to improve projection derivation expressivity. *)letres1=loop~proj_in_scopeexpr1in{res1withtyp=Tensor;slot=Scalar;expr=[%exprNTDSL.O.(!.)[%eres1.expr]]}|[%expr!..[%e?expr1]]->letres1=loop~proj_in_scopeexpr1in{res1withtyp=Tensor;slot=Scalar;expr=[%exprNTDSL.O.(!..)[%eres1.expr]]}|[%expr[%e?expr1]**.[%e?{pexp_desc=Pexp_constant(Pconst_integer_);_}asi]]->(* FIXME: `**.` should take a tensor and require that it's a literal. *)(* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *)letres1=loop~proj_in_scopeexpr1in{res1withtyp=Tensor;expr=[%exprNTDSL.O.(**.)[%eres1.expr](Float.of_int[%ei])];}|[%expr[%e?expr1]**.[%e?expr2]]->letres1=loop~proj_in_scopeexpr1in{res1withtyp=Tensor;expr=[%exprNTDSL.O.(**.)[%eres1.expr][%eexpr2]]}|[%expr[%e?expr1]*+[%e?{pexp_desc=Pexp_constant(Pconst_string(spec_str,_,_));_}asspec][%e?expr2]]whenString.containsspec_str'>'->letres1=loop~proj_in_scopeexpr1inletres2=loop~proj_in_scopeexpr2inletslot=Option.value~default:Undet@@List.find~f:(functionUndet->false|_->true)[res1.slot;res2.slot]in{vbs=reduce_vbss[res1.vbs;res2.vbs];typ=Tensor;slot;expr=[%exprNTDSL.einsum[%espec][%eres1.expr][%eres2.expr]];array_opt_of_code=None;}|[%expr[%e?expr1]++[%e?{pexp_desc=Pexp_constant(Pconst_string(spec_str,_,_));_}asspec]]whenString.containsspec_str'>'->letres1=loop~proj_in_scopeexpr1in{res1withtyp=Tensor;expr=[%exprNTDSL.einsum1[%espec][%eres1.expr]]}|[%expr[%e?expr1].grad]->(letres1=loop~proj_in_scopeexpr1inmatchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Grad_of_tensorexpr1;expr=[%exprOption.map[%eres1.expr].Tensor.diff~f:(fund->d.Tensor.grad)];(* It's never a good idea to embed backprop code outside of a proper backprop pass. 
*)}|Merge_value_->{res1withtyp=Merge_gradexpr1;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: write .grad.merge instead of .merge.grad";}|Code|Array|Value_of_tensor_|Grad_of_tensor_|Merge_grad_->{res1withtyp=Array;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: only tensors have a gradient";})|[%expr[%e?expr1].value]->(letres1=loop~proj_in_scopeexpr1in(* TODO: maybe this is too permissive? E.g. [t1.grad.value] is accepted. *)matchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Value_of_tensorres1.expr;expr=[%expr[%eres1.expr].Tensor.value];}|Code->{default_resultwithtyp=Array;slot=res1.slot;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: <code>.value notation not supported when <code> is not a \
tensor";}|Array|Value_of_tensor_|Grad_of_tensor_|Merge_value_|Merge_grad_->res1)|[%expr[%e?expr1].merge]->(letres1=loop~proj_in_scopeexpr1inmatchres1.typwith|Unknown|Tensor|No_grad_tensor_intro_->{res1withtyp=Merge_valueres1.expr;expr=[%expr[%eres1.expr].Tensor.value]}|Value_of_tensort->{res1withtyp=Merge_valuet;expr=[%expr[%eres1.expr].Tensor.value]}|Array|Code->{res1withtyp=Array;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: only tensor nodes (e.g. `.value` or `.grad`) can be merged";}|Grad_of_tensort->{res1withvbs=no_vbs;typ=Merge_gradt}|Merge_value_|Merge_grad_->{res1withexpr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: repeated .merge not allowed";})|[%expr~~([%e?{pexp_desc=Pexp_constant(Pconst_string_);_}ascomment];[%e?expr2])]->letres2=loop~proj_in_scopeexpr2in{res2withexpr=[%exprlet__comment_block=[%eres2.expr]in{Arrayjit.Assignments.asgns=Arrayjit.Assignments.Block_comment([%ecomment],__comment_block.Arrayjit.Assignments.asgns);embedded_nodes=__comment_block.Arrayjit.Assignments.embedded_nodes;}];}|[%expr~~([%e?{pexp_desc=Pexp_apply(expr,exprs);pexp_loc;_}];[%e?expr2])]->letelements=expr::List.map~f:sndexprs|>List.map~f:(function|{pexp_desc=Pexp_constant(Pconst_string_);_}ass->s|[%expr[%e?t].value]->[%exprArrayjit.Tnode.debug_name[%et].value]|[%expr[%e?t].grad]->[%exprArrayjit.Tnode.debug_name[%et].value^".grad"]|t->[%exprArrayjit.Tnode.debug_name[%et].value])inletres2=loop~proj_in_scopeexpr2in{res2withexpr=[%exprlet__comment_block=[%eres2.expr]in{Arrayjit.Assignments.asgns=Arrayjit.Assignments.Block_comment(String.concat_array~sep:" 
"[%eAst_helper.Exp.array~loc:pexp_locelements],__comment_block.Arrayjit.Assignments.asgns);embedded_nodes=__comment_block.Arrayjit.Assignments.embedded_nodes;}];}|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}]([%e?rhs1],[%e?rhs2])~projections:[%e?projections])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1][%e?rhs2]~projections:[%e?projections])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1]([%e?rhs2]~projections:[%e?projections]))]->(* Note: when clause not needed here and below, it's an error if bin_op is not a primitive
binary op. *)process_assign_binop~accu_op~lhs~bin_op~rhs1~rhs2~projections~proj_in_scope:true()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}]([%e?rhs1],[%e?rhs2],[%e?rhs3])~projections:[%e?projections])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}][%e?rhs1][%e?rhs2][%e?rhs3]~projections:[%e?projections])]->process_assign_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~projections~proj_in_scope:true()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentun_op;_};_}][%e?rhs]~projections:[%e?projections])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs](* FIXME: this was never needed as prefix operators bind tighter? *)([%e?{pexp_desc=Pexp_ident{txt=Lidentun_op;_};_}]([%e?rhs]~projections:[%e?projections]))]whenHashtbl.memunary_opsun_op->process_assign_unop~accu_op~lhs~un_op~rhs~projections~proj_in_scope:true()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?rhs]~projections:[%e?projections])]->process_assign_unop~accu_op~lhs~un_op:"id"~rhs~projections~proj_in_scope:true()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}]([%e?rhs1],[%e?rhs2])~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1][%e?rhs2]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1]([%e?rhs2]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic]))]->letlogic=letloc=s_locinifString.equalspec"."then[%exprShape.Pointwise_bin]elseifString.equalspec"@"then[%exprShape.Compose]else[%exprShape.Einsum[%elogic]]inl
et_,bin_op=binary_opbin_opinprocess_raw_binop~accu_op~lhs~bin_op~rhs1~rhs2~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}]([%e?rhs1],[%e?rhs2],[%e?rhs3])~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}][%e?rhs1][%e?rhs2][%e?rhs3]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}])]->letlogic=letloc=s_locinifString.equalspec"."then[%exprShape.Pointwise_bin]elseifString.equalspec"@"then[%exprShape.Compose]elseAst_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: expected <.> or <@>, found <%s> -- einsum notation for ternary \
operators not supported yet, see issue #305"specinlet_,tern_op=ternary_optern_opinprocess_raw_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs](([%e?{pexp_desc=Pexp_ident{txt=Lidentunop_ident;_};_}][%e?rhs])~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentunop_ident;_};_}][%e?rhs]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic])]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs](* FIXME: this was never needed as prefix operators bind tighter? *)([%e?{pexp_desc=Pexp_ident{txt=Lidentunop_ident;_};_}]([%e?rhs]~logic:[%e?{pexp_desc=Pexp_constant(Pconst_string(spec,s_loc,_));_}aslogic]))]whenHashtbl.memunary_opsunop_ident->(* Handle both un_op priority levels -- where application binds tighter and less tight. *)letlogic=letloc=s_locinifString.equalspec"."then[%exprShape.Pointwise_un]elseifString.equalspec"T"then[%exprShape.Transpose]else[%exprShape.Permute[%elogic]]inlet_,un_op=Hashtbl.find_exnunary_opsunop_identlocinprocess_raw_unop~accu_op~lhs~un_op~rhs~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}]([%e?rhs1],[%e?rhs2]))]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1][%e?rhs2])]whenis_assignmentaccu_op&&Hashtbl.membinary_opsbin_op&&proj_in_scope->process_assign_binop~accu_op~lhs~bin_op~rhs1~rhs2~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}]([%e?rhs1],[%e?rhs2],[%e?rhs3]))]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}][%e?rhs1][%e?rhs2][%e?rhs3])]whenis_assignmentaccu_op&&Hashtbl.memternary_opstern_op&&proj_in_scope->process_assign_ternop
~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentun_op;_};_}][%e?rhs])]whenis_assignmentaccu_op&&Hashtbl.memunary_opsun_op&&proj_in_scope->process_assign_unop~accu_op~lhs~un_op~rhs~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs][%e?rhs]]whenis_assignmentaccu_op&&proj_in_scope->process_assign_unop~accu_op~lhs~un_op:"id"~rhs~proj_in_scope()|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}]([%e?rhs1],[%e?rhs2]))]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentbin_op;_};_}][%e?rhs1][%e?rhs2])]whenis_assignmentaccu_op&&Hashtbl.membinary_opsbin_op->letlogic,bin_op=binary_opbin_opinprocess_raw_binop~accu_op~lhs~bin_op~rhs1~rhs2~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}]([%e?rhs1],[%e?rhs2],[%e?rhs3]))]|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidenttern_op;_};_}][%e?rhs1][%e?rhs2][%e?rhs3])]whenis_assignmentaccu_op&&Hashtbl.memternary_opstern_op->letlogic,tern_op=ternary_optern_opinprocess_raw_ternop~accu_op~lhs~tern_op~rhs1~rhs2~rhs3~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs]([%e?{pexp_desc=Pexp_ident{txt=Lidentun_op;_};_}][%e?rhs])]whenis_assignmentaccu_op&&Hashtbl.memunary_opsun_op->letlogic,un_op=Hashtbl.find_exnunary_opsun_oplocinprocess_raw_unop~accu_op~lhs~un_op~rhs~logic|[%expr[%e?{pexp_desc=Pexp_ident{txt=Lidentaccu_op;_};_}][%e?lhs][%e?rhs]]whenis_assignmentaccu_op->process_raw_unop~accu_op~lhs~un_op:[%exprArrayjit.Ops.Identity]~rhs~logic:[%exprShape.Pointwise_un]|[%expr[%e?expr1][%e?expr2][%e?expr3]]->letres1=loop~proj_in_scopeexpr1inletres2=loop~proj_in_scopeexpr2inletres3=loop~proj_in_scopeexpr3inletslot=Option.value~default:Undet@@List.find~f:(functi
onUndet->false|_->true)[res1.slot;res2.slot;res3.slot]in{vbs=reduce_vbss[res1.vbs;res2.vbs;res3.vbs];typ=res1.typ;slot;expr=[%expr[%eres1.expr][%eres2.expr][%eres3.expr]];array_opt_of_code=None;}|[%expr[%e?expr1][%e?expr2]]->letres1=loop~proj_in_scopeexpr1inletres2=loop~proj_in_scopeexpr2inletslot=Option.value~default:Undet@@List.find~f:(functionUndet->false|_->true)[res1.slot;res2.slot]in{vbs=reduce_vbss[res1.vbs;res2.vbs];typ=res1.typ;slot;expr=[%expr[%eres1.expr][%eres2.expr]];array_opt_of_code=None;}|{pexp_desc=Pexp_fun((arg_label:arg_label),arg,pat,expr1);_}asexpr->letproj_in_scope=proj_in_scope||matcharg_labelwith|(Labelleds|Optionals)whenString.equals"projections"->true|_->falseinletbad_pun_hints=Set.unionbad_pun_hints@@collect_pat_identspatinletres1=transl~bad_pun_hints~proj_in_scopeexpr1in{res1withexpr={exprwithpexp_desc=Pexp_fun(arg_label,arg,pat,res1.expr)}}|[%exprwhile[%e?_test_expr]do[%e?_body]done]->{default_resultwithtyp=Code;slot=Nonslot;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: while: low-level code embeddings not supported yet";}|[%exprfor[%p?_pat]=[%e?_init]to[%e?_final]do[%e?_body_expr]done]->{default_resultwithtyp=Code;slot=Nonslot;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: for-to: low-level code embeddings not supported yet";}|[%exprfor[%p?_pat]=[%e?_init]downto[%e?_final]do[%e?_body_expr]done]->{default_resultwithtyp=Code;slot=Nonslot;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: for-downto: low-level code embeddings not supported 
yet";}|[%expr[%e?expr1];[%e?expr2]]->letres1=loop~proj_in_scopeexpr1inletres2=loop~proj_in_scopeexpr2in{vbs=reduce_vbss[res1.vbs;res2.vbs];typ=Code;slot=Nonslot;expr=[%exprArrayjit.Assignments.sequence[[%eres1.expr];[%eres2.expr]]];array_opt_of_code=res2.array_opt_of_code;}|[%exprif[%e?expr1]then[%e?expr2]else[%e?expr3]]->letres2=loop~proj_in_scopeexpr2inletres3=loop~proj_in_scopeexpr3inlettyp=ifis_unknownres2.typthenres3.typelseres2.typinletslot=Option.value~default:Undet@@List.find~f:(functionUndet->false|_->true)[res2.slot;res3.slot]in{vbs=reduce_vbss[res2.vbs;res3.vbs];typ;slot;expr=[%exprif[%eexpr1]then[%eres2.expr]else[%eres3.expr]];array_opt_of_code=None;}|[%exprif[%e?expr1]then[%e?expr2]]->letres2=loop~proj_in_scopeexpr2in{vbs=res2.vbs;typ=Code;slot=Nonslot;expr=[%exprif[%eexpr1]then[%eres2.expr]elseArrayjit.Assignments.empty_comp];array_opt_of_code=res2.array_opt_of_code;}|{pexp_desc=Pexp_match(expr1,cases);_}->letfields,cases=List.unzip@@List.mapcases~f:(fun({pc_rhs;_}asc)->letres=loop~proj_in_scopepc_rhsin((res.vbs,res.typ,res.slot),{cwithpc_rhs=res.expr}))inletvbss,typs,slots=List.unzip3fieldsinlettyp=Option.value~default:Unknown@@List.findtyps~f:(Fn.nonis_unknown)inletslot=Option.value~default:Undet@@List.find~f:(functionUndet->false|_->true)slotsin{vbs=reduce_vbssvbss;typ;slot;expr={exprwithpexp_desc=Pexp_match(expr1,cases)};array_opt_of_code=None;}|{pexp_desc=Pexp_let(_recflag,_bindings,_body);_}->(* TODO(80): to properly support local bindings, we need to collect the type environment. *){default_resultwithtyp=Unknown;expr=Ast_builder.Default.pexp_extension~loc@@Location.error_extensionf~loc"ppx_ocannl %%cd: let-in: local let-bindings not implemented yet";}(* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=loop
binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, loop body)} *)
    | { pexp_desc = Pexp_open (decl, expr1); _ } ->
        (* Recurse under a local open; re-wrap the translated body in the same open. *)
        let res1 = loop ~proj_in_scope expr1 in
        { res1 with expr = { expr with pexp_desc = Pexp_open (decl, res1.expr) } }
    | { pexp_desc = Pexp_letmodule (name, module_expr, expr1); _ } ->
        (* Recurse under a local module binding; preserve the binding around the result. *)
        let res1 = loop ~proj_in_scope expr1 in
        { res1 with expr = { expr with pexp_desc = Pexp_letmodule (name, module_expr, res1.expr) } }
    | _ ->
        (* Anything unrecognized is left as-is with an unknown translation type. *)
        { default_result with typ = Unknown }
  in
  (* Kick off the traversal: no projections in scope yet, no punning hints collected. *)
  transl ~proj_in_scope:false ~bad_pun_hints:(Set.empty (module String)) expr

(** [translate ?ident_label expr] runs the [%cd] translation (the [translate] above --
    this binding shadows it) and returns the collected value bindings paired with the
    translated expression, evaluated under [open! NTDSL.O]. When [ident_label] is
    [Some] of a wildcard pattern, the body is additionally wrapped in
    [Tensor.with_unchanged_roots]. *)
let translate ?ident_label expr =
  let res = translate expr in
  let loc = res.expr.pexp_loc in
  let expr = res.expr in
  let body =
    match ident_label with
    | Some [%pat? _] ->
        [%expr
          Tensor.with_unchanged_roots ~f:(fun () ->
              let open! NTDSL.O in
              [%e expr])]
    | _ ->
        [%expr
          let open! NTDSL.O in
          [%e expr]]
  in
  (res.vbs, body)

(* Extension-point entry hooks: delegate to the punning-aware expanders. *)
let expr_expander ~loc ~path = expr_expander_with_punning translate ~loc ~path
let str_expander ~loc ~path = str_expander_with_punning translate ~loc ~path