(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Alain Delaët-Tixeuil <alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)openUtilsmoduleD=Dcalc.AstmoduleA=Ast(** The main idea around this pass is to compile Dcalc to Lcalc without using
[raise EmptyError] nor [try _ with EmptyError -> _]. To do so, we use the
same technique as in rust or erlang to handle this kind of exceptions. Each
[raise EmptyError] will be translated as [None] and each
[try e1 with EmptyError -> e2] as
[match e1 with | None -> e2 | Some x -> x].
When doing this naively, this requires to add matches and Some constructor
everywhere. We apply here another technique where we generate what we call
`hoists`. Hoists are expressions that could minimally [raise EmptyError]. For
instance in [let x = <e1, e2, ..., en| e_just :- e_cons> * 3 in x + 1], the
sub-expression [<e1, e2, ..., en| e_just :- e_cons>] can produce an empty
error. So we make a hoist with a new variable [y] linked to the Dcalc
expression [<e1, e2, ..., en| e_just :- e_cons>], and we return as the
translated expression [let x = y * 3 in x + 1].
The compilation of expressions is found in the functions
[translate_and_hoist ctx e] and [translate_expr ctx e]. Every
option-generating expression when calling [translate_and_hoist] will be
hoisted and later handled by the [translate_expr] function. Every other
case is found in the [translate_and_hoist] function. *)

type 'm hoists = 'm D.marked_expr A.VarMap.t
(** Hoists definition: a map from Lcalc variables ([A.Var.t] keys) to the Dcalc
    expressions ([D.marked_expr]) they stand for. *)

type 'm info = {
  expr : 'm A.marked_expr Bindlib.box;
  var : 'm A.expr Bindlib.var;
  is_pure : bool;
}
(** Information about each encountered Dcalc variable is stored inside a
    context: the corresponding Lcalc variable ([var]); an expression for that
    variable, built correctly using Bindlib ([expr]); and a boolean [is_pure]
    telling whether the variable can be an EmptyError and hence should be
    matched (false) or can never be an EmptyError (true). *)

(** [pp_info fmt info] prints the Lcalc variable and the purity flag of [info]
    on [fmt]; the boxed expression is not printed. *)
let pp_info (fmt : Format.formatter) (info : 'm info) =
  let { var; is_pure; _ } = info in
  Format.fprintf fmt "{var: %a; is_pure: %b}" Print.format_var var is_pure

type 'm ctx = {
  decl_ctx : D.decl_ctx;
  vars : 'm info D.VarMap.t;
      (** information context about variables in the current scope *)
}

(** [_pp_ctx fmt ctx] prints every binding of [ctx.vars] on [fmt] as a
    ["; "]-separated list enclosed in brackets. Debugging helper, only referred
    to from commented-out traces. *)
let _pp_ctx (fmt : Format.formatter) (ctx : 'm ctx) =
  let print_separator ppf () = Format.pp_print_string ppf "; " in
  let print_entry (ppf : Format.formatter) ((key, entry) : D.Var.t * 'm info) =
    Format.fprintf ppf "%a: %a" Dcalc.Print.format_var (D.Var.get key) pp_info
      entry
  in
  ctx.vars
  |> D.VarMap.bindings
  |> Format.fprintf fmt "@[<2>[%a]@]"
       (Format.pp_print_list ~pp_sep:print_separator print_entry)

(** [find ~info n ctx] is a wrapper around OCaml's Map.find that handles errors in a
slightly better way: when [n] has no entry in [ctx.vars], an internal error
    is raised through [Errors.raise_spanned_error] (at [Pos.no_pos]) that
    prints the variable together with the free-form [info] hint, which
    defaults to ["none"]. *)
let find ?(info : string = "none") (n : 'm D.var) (ctx : 'm ctx) : 'm info =
  (* Debugging aid, kept disabled:
     let _ = Format.asprintf "Searching for variable %a inside context %a"
     Dcalc.Print.format_var n pp_ctx ctx |> Cli.debug_print in *)
  match D.VarMap.find (D.Var.t n) ctx.vars with
  | entry -> entry
  | exception Not_found ->
    Errors.raise_spanned_error Pos.no_pos
      "Internal Error: Variable %a was not found in the current environment. \
       Additional informations : %s."
      Dcalc.Print.format_var n info

(** [add_var mark var is_pure ctx] adds to the context [ctx] the Dcalc variable
var, creating a fresh corresponding variable in Lcalc (named after [var]),
    with the corresponding boxed expression carrying [mark], and the boolean
    [is_pure]. Any entry already recorded for [var] is replaced. A disabled
    debug trace can print each Dcalc/Lcalc variable pair. *)
let add_var
    (mark : 'm D.mark)
    (var : 'm D.var)
    (is_pure : bool)
    (ctx : 'm ctx) : 'm ctx =
  let lcalc_var = A.new_var (Bindlib.name_of var) in
  let entry =
    { expr = A.make_var (lcalc_var, mark); var = lcalc_var; is_pure }
  in
  (* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var
     var Print.format_var lcalc_var; *)
  let overwrite _previous = Some entry in
  { ctx with vars = D.VarMap.update (D.Var.t var) overwrite ctx.vars }

(** [tau' = translate_typ tau] translates a dcalc type into a lcalc type.
Since the positions of thunked expressions are exactly where we will
    put option expressions, the transformation simply rewrites [unit -> 'a]
    into ['a option] recursively. There is no polymorphism inside catala. The
    mark of [tau] is carried over to the result. *)
let rec translate_typ (tau : D.typ Marked.pos) : D.typ Marked.pos =
  let translated =
    match Marked.unmark tau with
    | (D.TLit _ | D.TAny) as unchanged -> unchanged
    | D.TTuple (components, struct_name) ->
      D.TTuple (List.map translate_typ components, struct_name)
    | D.TEnum (cases, enum_name) ->
      D.TEnum (List.map translate_typ cases, enum_name)
    | D.TArray element -> D.TArray (translate_typ element)
    (* A thunk [unit -> t] becomes the option enum over [unit] and the
       translation of [t]; catala is not polymorphic. *)
    | D.TArrow ((D.TLit D.TUnit, pos_unit), result) ->
      D.TEnum ([D.TLit D.TUnit, pos_unit; translate_typ result], A.option_enum)
    | D.TArrow (t1, t2) -> D.TArrow (translate_typ t1, translate_typ t2)
  in
  Marked.same_mark_as translated tau

(** [translate_lit l pos] maps a Dcalc literal onto the Lcalc literal of the
    same shape. [D.LEmptyError] has no Lcalc counterpart: meeting it here
    raises an internal error located at [pos]. *)
let translate_lit (l : D.lit) (pos : Pos.t) : A.lit =
  match l with
  | D.LEmptyError ->
    Errors.raise_spanned_error pos
      "Internal Error: An empty error was found in a place that shouldn't be \
       possible."
  | D.LUnit -> A.LUnit
  | D.LBool b -> A.LBool b
  | D.LInt n -> A.LInt n
  | D.LRat q -> A.LRat q
  | D.LMoney amount -> A.LMoney amount
  | D.LDate date -> A.LDate date
  | D.LDuration duration -> A.LDuration duration

(** [c = disjoint_union_maps pos cs] computes the disjoint union of multiple maps.
Raises an internal error, located at [pos], if two of the maps share a key.
    The union of the empty list is the empty map. *)
let disjoint_union_maps (pos : Pos.t) (cs : 'a A.VarMap.t list) :
    'a A.VarMap.t =
  (* Called by [A.VarMap.union] only for a key present on both sides. *)
  let on_shared_key _key _left _right =
    Errors.raise_spanned_error pos
      "Internal Error: Two supposed to be disjoints maps have one shared \
       key."
  in
  ListLabels.fold_left cs ~init:A.VarMap.empty ~f:(fun merged next ->
      A.VarMap.union on_shared_key merged next)

(** [e' = translate_and_hoist ctx e] translates the Dcalc expression e into an
expression in Lcalc, given we translate each hoists correctly. It ensures
    that the execution of e and the execution of e' are equivalent in an
    environement where each variable v, where (v, e_v) is in hoists, has the
    non-empty value in e_v.

    Returns the boxed Lcalc expression together with the hoists it relies on.
    Note that the local [pos] below is the whole mark of [e]
    ([Marked.get_mark e]); the bare position is obtained with [D.pos e]. *)
let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.marked_expr) :
    'm A.marked_expr Bindlib.box * 'm hoists =
  let pos = Marked.get_mark e in
  match Marked.unmark e with
  (* empty-producing/using terms. We hoist those. (D.EVar in some cases,
     EApp(D.EVar _, [ELit LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure
     about assert. *)
  | D.EVar v ->
    (* todo: for now, every unpure (such that [is_pure] is [false] in the
       current context) is thunked, hence matched in the next case. This
       assumption can change in the future, and this case is here for this
       reason. *)
    (* A single lookup serves both the purity test and the pure branch. *)
    let v_info = find ~info:"search for a variable" v ctx in
    if not v_info.is_pure then
      let v' = A.new_var (Bindlib.name_of v) in
      (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
         created a variable %a to replace it" Dcalc.Print.format_var v
         Print.format_var v'; *)
      (A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e)
    else (v_info.expr, A.VarMap.empty)
  | D.EApp ((D.EVar v, p), [(D.ELit D.LUnit, _)]) ->
    (* Forcing a thunked variable: the hoist records the bare variable, the
       application to unit disappears. *)
    if not (find ~info:"search for a variable" v ctx).is_pure then
      let v' = A.new_var (Bindlib.name_of v) in
      (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a,
         created a variable %a to replace it" Dcalc.Print.format_var v
         Print.format_var v'; *)
      (A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') (D.EVar v, p))
    else
      Errors.raise_spanned_error (D.pos e)
        "Internal error: an pure variable was found in an unpure environment."
  | D.EDefault (_exceptions, _just, _cons) ->
    let v' = A.new_var "default_term" in
    (A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e)
  | D.ELit D.LEmptyError ->
    let v' = A.new_var "empty_litteral" in
    (A.make_var (v', pos), A.VarMap.singleton (A.Var.t v') e)
  (* This one is a very special case. It transform an unpure expression
     environement to a pure expression. *)
  | ErrorOnEmpty arg ->
    (* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *)
    let silent_var = A.new_var "_" in
    let x = A.new_var "non_empty_argument" in
    let arg' = translate_expr ctx arg in
    ( A.make_matchopt_with_abs_arms arg'
        (A.make_abs [| silent_var |]
           (Bindlib.box (A.ERaise A.NoValueProvided, pos))
           [D.TAny, D.pos e]
           pos)
        (A.make_abs [| x |] (A.make_var (x, pos)) [D.TAny, D.pos e] pos),
      A.VarMap.empty )
  (* pure terms *)
  | D.ELit l -> (A.elit (translate_lit l (D.pos e)) pos, A.VarMap.empty)
  | D.EIfThenElse (e1, e2, e3) ->
    let e1', h1 = translate_and_hoist ctx e1 in
    let e2', h2 = translate_and_hoist ctx e2 in
    let e3', h3 = translate_and_hoist ctx e3 in
    let e' = A.eifthenelse e1' e2' e3' pos in
    (e', disjoint_union_maps (D.pos e) [h1; h2; h3])
  | D.EAssert e1 ->
    (* same behavior as in the ICFP paper: if e1 is empty, then no error is
       raised. *)
    let e1', h1 = translate_and_hoist ctx e1 in
    (A.eassert e1' pos, h1)
  | D.EAbs (binder, ts) ->
    let vars, body = Bindlib.unmbind binder in
    let ctx, lc_vars =
      ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) ->
          (* we suppose the invariant that when applying a function, its
             arguments cannot be of the type "option".

             The code should behave correctly without this assumption if we
             put here an is_pure=false, but the types are more complicated.
             (unimplemented for now) *)
          let ctx = add_var pos var true ctx in
          let lc_var = (find var ctx).var in
          (ctx, lc_var :: lc_vars))
    in
    let lc_vars = Array.of_list lc_vars in
    (* here we take the guess that if we cannot build the closure because one of
       the variable is empty, then we cannot build the function. *)
    let new_body, hoists = translate_and_hoist ctx body in
    let new_binder = Bindlib.bind_mvar lc_vars new_body in
    ( Bindlib.box_apply
        (fun new_binder -> A.EAbs (new_binder, List.map translate_typ ts), pos)
        new_binder,
      hoists )
  | EApp (e1, args) ->
    let e1', h1 = translate_and_hoist ctx e1 in
    let args', h_args =
      args |> List.map (translate_and_hoist ctx) |> List.split
    in
    let hoists = disjoint_union_maps (D.pos e) (h1 :: h_args) in
    let e' = A.eapp e1' args' pos in
    (e', hoists)
  | ETuple (args, s) ->
    let args', h_args =
      args |> List.map (translate_and_hoist ctx) |> List.split
    in
    let hoists = disjoint_union_maps (D.pos e) h_args in
    (A.etuple args' s pos, hoists)
  | ETupleAccess (e1, i, s, ts) ->
    let e1', hoists = translate_and_hoist ctx e1 in
    let e1' = A.etupleaccess e1' i s ts pos in
    (e1', hoists)
  | EInj (e1, i, en, ts) ->
    let e1', hoists = translate_and_hoist ctx e1 in
    let e1' = A.einj e1' i en ts pos in
    (e1', hoists)
  | EMatch (e1, cases, en) ->
    let e1', h1 = translate_and_hoist ctx e1 in
    let cases', h_cases =
      cases |> List.map (translate_and_hoist ctx) |> List.split
    in
    let hoists = disjoint_union_maps (D.pos e) (h1 :: h_cases) in
    let e' = A.ematch e1' cases' en pos in
    (e', hoists)
  | EArray es ->
    let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in
    (A.earray es' pos, disjoint_union_maps (D.pos e) hoists)
  | EOp op -> (Bindlib.box (A.EOp op, pos), A.VarMap.empty)

(** [translate_expr ?append_esome ctx e] translates [e] with
    [translate_and_hoist] and then discharges every hoist it produced. Each
    hoisted Dcalc term is translated to an option-valued Lcalc expression [c']
    and wrapped around the accumulated result as
    [match c' with None -> None | Some v -> acc], where [v] is the variable
    that stands for the hoist. When [append_esome] is [true] (the default) the
    innermost expression is first wrapped with [A.make_some].

    Raises an internal error if a hoist is not one of the term shapes handled
    below. *)
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.marked_expr)
    : 'm A.marked_expr Bindlib.box =
  let e', hoists = translate_and_hoist ctx e in
  let hoists = A.VarMap.bindings hoists in
  (* build the hoists *)
  (* Cli.debug_print @@ Format.asprintf "hoist for the expression: [%a]"
     (Format.pp_print_list Print.format_var) (List.map fst hoists); *)
  ListLabels.fold_left hoists
    ~init:(if append_esome then A.make_some e' else e')
    ~f:(fun acc (v, (hoist, mark_hoist)) ->
      (* Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.format_var
         v; *)
      let c' : 'm A.marked_expr Bindlib.box =
        match hoist with
        (* Here we have to handle only the cases appearing in hoists, as defined
           the [translate_and_hoist] function. *)
        | D.EVar v -> (find ~info:"should never happend" v ctx).expr
        | D.EDefault (excep, just, cons) ->
          let excep' = List.map (translate_expr ctx) excep in
          let just' = translate_expr ctx just in
          let cons' = translate_expr ctx cons in
          (* calls handle_option. *)
          A.make_app
            (A.make_var (A.Var.get A.handle_default_opt, mark_hoist))
            [
              Bindlib.box_apply
                (fun excep' -> A.EArray excep', mark_hoist)
                (Bindlib.box_list excep');
              just';
              cons';
            ]
            mark_hoist
        | D.ELit D.LEmptyError -> A.make_none mark_hoist
        | D.EAssert arg ->
          let arg' = translate_expr ctx arg in
          (* [ match arg with | None -> raise NoValueProvided | Some v -> assert
             {{ v }} ] *)
          let silent_var = A.new_var "_" in
          let x = A.new_var "assertion_argument" in
          A.make_matchopt_with_abs_arms arg'
            (A.make_abs [| silent_var |]
               (Bindlib.box (A.ERaise A.NoValueProvided, mark_hoist))
               [D.TAny, D.mark_pos mark_hoist]
               mark_hoist)
            (A.make_abs [| x |]
               (Bindlib.box_apply
                  (fun arg -> A.EAssert arg, mark_hoist)
                  (A.make_var (x, mark_hoist)))
               [D.TAny, D.mark_pos mark_hoist]
               mark_hoist)
        | _ ->
          Errors.raise_spanned_error (D.mark_pos mark_hoist)
            "Internal Error: An term was found in a position where it should \
             not be"
      in
      (* [ match {{ c' }} with | None -> None | Some {{ v }} -> {{ acc }} end
         ] *)
      (* Cli.debug_print @@ Format.asprintf "build matchopt using %a"
         Print.format_var v; *)
      A.make_matchopt mark_hoist (A.Var.get v)
        (D.TAny, D.mark_pos mark_hoist)
        c' (A.make_none mark_hoist) acc)

(** [translate_scope_let ctx lets] translates a chain of scope lets. For each
    [ScopeLet], the bound Dcalc variable is registered in the context with
    [add_var] (so the rest of the chain sees its Lcalc counterpart), the bound
    expression is translated with [translate_expr ~append_esome:false], the
    declared type goes through [translate_typ], and the continuation is
    translated recursively and re-bound with [Bindlib.bind_var].

    A [SubScopeVarDefinition] must bind either an [EAbs] (thunked context
    variable, whose thunk is removed) or an [ErrorOnEmpty] (regular input); any
    other expression raises an internal error. A [DestructuringInputStruct]
    variable is recorded as impure exactly when its type is a [unit -> _]
    arrow; every other variable is recorded as pure. *)
let rec translate_scope_let
    (ctx : 'm ctx)
    (lets : ('m D.expr, 'm) D.scope_body_expr) :
    ('m A.expr, 'm) D.scope_body_expr Bindlib.box =
  match lets with
  | Result e ->
    Bindlib.box_apply
      (fun e -> D.Result e)
      (translate_expr ~append_esome:false ctx e)
  | ScopeLet
      {
        scope_let_kind = SubScopeVarDefinition;
        scope_let_typ = typ;
        scope_let_expr = D.EAbs (binder, _), emark;
        scope_let_next = next;
        scope_let_pos = pos;
      } ->
    (* special case : the subscope variable is thunked (context i/o). We remove
       this thunking. *)
    let _, expr = Bindlib.unmbind binder in

    let var_is_pure = true in
    let var, next = Bindlib.unbind next in
    (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var
       var; *)
    let vmark = D.map_mark (fun _ -> pos) (fun _ -> typ) emark in
    let ctx' = add_var vmark var var_is_pure ctx in
    let new_var = (find ~info:"variable that was just created" var ctx').var in
    let new_next = translate_scope_let ctx' next in
    Bindlib.box_apply2
      (fun new_expr new_next ->
        D.ScopeLet
          {
            scope_let_kind = SubScopeVarDefinition;
            scope_let_typ = translate_typ typ;
            scope_let_expr = new_expr;
            scope_let_next = new_next;
            scope_let_pos = pos;
          })
      (translate_expr ctx ~append_esome:false expr)
      (Bindlib.bind_var new_var new_next)
  | ScopeLet
      {
        scope_let_kind = SubScopeVarDefinition;
        scope_let_typ = typ;
        scope_let_expr = (D.ErrorOnEmpty _, emark) as expr;
        scope_let_next = next;
        scope_let_pos = pos;
      } ->
    (* special case: regular input to the subscope *)
    let var_is_pure = true in
    let var, next = Bindlib.unbind next in
    (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var
       var; *)
    let vmark = D.map_mark (fun _ -> pos) (fun _ -> typ) emark in
    let ctx' = add_var vmark var var_is_pure ctx in
    let new_var = (find ~info:"variable that was just created" var ctx').var in
    Bindlib.box_apply2
      (fun new_expr new_next ->
        D.ScopeLet
          {
            scope_let_kind = SubScopeVarDefinition;
            scope_let_typ = translate_typ typ;
            scope_let_expr = new_expr;
            scope_let_next = new_next;
            scope_let_pos = pos;
          })
      (translate_expr ctx ~append_esome:false expr)
      (Bindlib.bind_var new_var (translate_scope_let ctx' next))
  | ScopeLet
      {
        scope_let_kind = SubScopeVarDefinition;
        scope_let_pos = pos;
        scope_let_expr = expr;
        _;
      } ->
    Errors.raise_spanned_error pos
      "Internal Error: found an SubScopeVarDefinition that does not satisfy \
       the invariants when translating Dcalc to Lcalc without exceptions: \
       @[<hov 2>%a@]"
      (Dcalc.Print.format_expr ctx.decl_ctx)
      expr
  | ScopeLet
      {
        scope_let_kind = kind;
        scope_let_typ = typ;
        scope_let_expr = expr;
        scope_let_next = next;
        scope_let_pos = pos;
      } ->
    let var_is_pure =
      match kind with
      | DestructuringInputStruct -> (
        (* Here, we have to distinguish between context and input variables. We
           can do so by looking at the typ of the destructuring: if it's
           thunked, then the variable is context. If it's not thunked, it's a
           regular input. *)
        match Marked.unmark typ with
        | D.TArrow ((D.TLit D.TUnit, _), _) -> false
        | _ -> true)
      | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
      | DestructuringSubScopeResults | Assertion ->
        true
    in
    let var, next = Bindlib.unbind next in
    (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var
       var; *)
    let vmark =
      D.map_mark (fun _ -> pos) (fun _ -> typ) (Marked.get_mark expr)
    in
    let ctx' = add_var vmark var var_is_pure ctx in
    let new_var = (find ~info:"variable that was just created" var ctx').var in
    Bindlib.box_apply2
      (fun new_expr new_next ->
        D.ScopeLet
          {
            scope_let_kind = kind;
            scope_let_typ = translate_typ typ;
            scope_let_expr = new_expr;
            scope_let_next = new_next;
            scope_let_pos = pos;
          })
      (translate_expr ctx ~append_esome:false expr)
      (Bindlib.bind_var new_var (translate_scope_let ctx' next))

(** [translate_scope_body scope_pos ctx body] translates the body of a scope.
    The variable bound by [body.scope_body_expr] is registered as pure, with a
    mark built from the mark of the first expression of the body in which the
    position is replaced by [scope_pos]. The lets are then translated with
    [translate_scope_let] and re-bound under the new variable. The input and
    output struct names are passed through unchanged. *)
let translate_scope_body
    (scope_pos : Pos.t)
    (ctx : 'm ctx)
    (body : ('m D.expr, 'm) D.scope_body) :
    ('m A.expr, 'm) D.scope_body Bindlib.box =
  match body with
  | {
   scope_body_expr = result;
   scope_body_input_struct = input_struct;
   scope_body_output_struct = output_struct;
  } ->
    let v, lets = Bindlib.unbind result in
    let vmark =
      let m =
        match lets with
        | Result e | ScopeLet { scope_let_expr = e; _ } -> Marked.get_mark e
      in
      D.map_mark (fun _ -> scope_pos) (fun ty -> ty) m
    in
    let ctx' = add_var vmark v true ctx in
    let v' = (find ~info:"variable that was just created" v ctx').var in
    Bindlib.box_apply
      (fun new_expr ->
        {
          D.scope_body_expr = new_expr;
          scope_body_input_struct = input_struct;
          scope_body_output_struct = output_struct;
        })
      (Bindlib.bind_var v' (translate_scope_let ctx' lets))

(** [translate_scopes ctx scopes] translates a list of scope definitions in
    order. The variable naming a scope is registered as pure in the context
    used for the following scopes ([new_ctx]); the body of the scope itself is
    translated in the incoming [ctx], which does not contain that variable. *)
let rec translate_scopes (ctx : 'm ctx) (scopes : ('m D.expr, 'm) D.scopes) :
    ('m A.expr, 'm) D.scopes Bindlib.box =
  match scopes with
  | Nil -> Bindlib.box D.Nil
  | ScopeDef { scope_name; scope_body; scope_next } ->
    let scope_var, next = Bindlib.unbind scope_next in
    let vmark =
      match Bindlib.unbind scope_body.scope_body_expr with
      | _, (Result e | ScopeLet { scope_let_expr = e; _ }) -> Marked.get_mark e
    in
    let new_ctx = add_var vmark scope_var true ctx in
    (* The hint is spelled on a single line, like at the other call sites: the
       literal used to contain a raw line break. *)
    let new_scope_name =
      (find ~info:"variable that was just created" scope_var new_ctx).var
    in
    let scope_pos = Marked.get_mark (D.ScopeName.get_info scope_name) in
    let new_body = translate_scope_body scope_pos ctx scope_body in
    let tail = translate_scopes new_ctx next in
    Bindlib.box_apply2
      (fun body tail ->
        D.ScopeDef { scope_name; scope_body = body; scope_next = tail })
      new_body
      (Bindlib.bind_var new_scope_name tail)

(** [translate_program prgm] is the entry point of the pass. It

    - collects the input struct name of every scope of [prgm];
    - adds the option enum ([A.option_enum]) to the enum declarations;
    - applies [translate_typ] to the field types of the collected input
      structs, leaving the other structs untouched;
    - translates the scopes starting from an empty variable context and
      unboxes the result. *)
let translate_program (prgm : 'm D.program) : 'm A.program =
  let inputs_structs =
    D.fold_left_scope_defs prgm.scopes ~init:[] ~f:(fun acc scope_def _ ->
        scope_def.D.scope_body.scope_body_input_struct :: acc)
  in
  (* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]"
     (Format.pp_print_list D.StructName.format_t) inputs_structs; *)
  let decl_ctx =
    {
      prgm.decl_ctx with
      D.ctx_enums =
        prgm.decl_ctx.ctx_enums
        |> D.EnumMap.add A.option_enum A.option_enum_config;
    }
  in
  let decl_ctx =
    {
      decl_ctx with
      D.ctx_structs =
        prgm.decl_ctx.ctx_structs
        |> D.StructMap.mapi (fun n l ->
               if List.mem n inputs_structs then
                 ListLabels.map l ~f:(fun (n, tau) ->
                     (* Cli.debug_print @@ Format.asprintf "Input type: %a"
                        (Dcalc.Print.format_typ decl_ctx) tau; Cli.debug_print
                        @@ Format.asprintf "Output type: %a"
                        (Dcalc.Print.format_typ decl_ctx) (translate_typ
                        tau); *)
                     n, translate_typ tau)
               else l);
    }
  in
  let scopes =
    Bindlib.unbox
      (translate_scopes { decl_ctx; vars = D.VarMap.empty } prgm.scopes)
  in
  { scopes; decl_ctx }