(* This file is part of the Catala compiler, a specification language for tax and social benefits
computation rules. Copyright (C) 2020-2022 Inria, contributor: Alain Delaët-Tixeuil
<alain.delaet--tixeuil@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License. *)

open Utils

(* [D] abbreviates the Dcalc AST (the input of this pass); [A] abbreviates the AST in which the
   translated expressions are built. *)
module D = Dcalc.Ast
module A = Ast
open Dcalc.Binded_representation

(** The main idea around this pass is to compile Dcalc to Lcalc without using [raise EmptyError] nor
[try _ with EmptyError -> _]. To do so, we use the same technique as in Rust or Erlang to handle
this kind of exception. Each [raise EmptyError] will be translated as [None] and each
[try e1 with EmptyError -> e2] as [match e1 with | None -> e2 | Some x -> x].

When doing this naively, this requires adding matches and [Some] constructors everywhere. We apply
here another technique where we generate what we call `hoists`. Hoists are expressions which
could minimally [raise EmptyError]. For instance in
[let x = <e1, e2, ..., en| e_just :- e_cons> * 3 in x + 1], the sub-expression
[<e1, e2, ..., en| e_just :- e_cons>] can produce an empty error. So we make a hoist with a new
variable [y] linked to the Dcalc expression [<e1, e2, ..., en| e_just :- e_cons>], and we return
as the translated expression [let x = y * 3 in x + 1].

The compilation of expressions is found in the functions [translate_and_hoist ctx e] and
[translate_expr ctx e]. Every option-generating expression met when calling [translate_and_hoist]
will be hoisted and later handled by the [translate_expr] function. Every other case is found
in the translate_and_hoist function. *)

type hoists = D.expr Pos.marked A.VarMap.t
(** Hoists definition. It represents bindings between [A.Var.t] and [D.expr]. *)

type info = { expr : A.expr Pos.marked Bindlib.box; var : A.expr Bindlib.var; is_pure : bool }
(** Information about each encountered Dcalc variable is stored inside a context : what is the
corresponding LCalc variable; an expression corresponding to the variable built correctly using
Bindlib, and a boolean `is_pure` indicating whether the variable can be an EmptyError and hence
should be matched (false) or if it never can be EmptyError (true). *)

(** [pp_info fmt info] prints the Lcalc variable and the purity flag stored in [info]. The boxed
    expression is not printed. *)
let pp_info (fmt : Format.formatter) (info : info) =
  let { var; is_pure; _ } = info in
  Format.fprintf fmt "{var: %a; is_pure: %b}" Print.format_var var is_pure

type ctx = {
  decl_ctx : D.decl_ctx;
  vars : info D.VarMap.t;  (** information context about variables in the current scope *)
}

(** [_pp_ctx fmt ctx] prints every binding of [ctx.vars] as [dcalc_var: info], separated by
    semicolons. Only used for debugging, hence the leading underscore. *)
let _pp_ctx (fmt : Format.formatter) (ctx : ctx) =
  let print_separator out () = Format.pp_print_string out "; " in
  let print_entry (out : Format.formatter) ((dcalc_var, entry) : D.Var.t * info) =
    Format.fprintf out "%a: %a" Dcalc.Print.format_var dcalc_var pp_info entry
  in
  let print_entries = Format.pp_print_list ~pp_sep:print_separator print_entry in
  let entries = D.VarMap.bindings ctx.vars in
  Format.fprintf fmt "@[<2>[%a]@]" print_entries entries

(** [find ~info n ctx] looks up the Dcalc variable [n] in [ctx.vars]. It wraps [D.VarMap.find]:
    instead of letting [Not_found] escape, a missing variable is reported through
    [Errors.raise_spanned_error] (at [Pos.no_pos]) with a message naming [n] and the free-form
    [info] hint (["none"] by default). *)
let find ?(info : string = "none") (n : D.Var.t) (ctx : ctx) : info =
  (* let _ = Format.asprintf "Searching for variable %a inside context %a" Dcalc.Print.format_var n
     pp_ctx ctx |> Cli.debug_print in *)
  match D.VarMap.find n ctx.vars with
  | found -> found
  | exception Not_found ->
      let message =
        Format.asprintf
          "Internal Error: Variable %a was not found in the current environment. Additional \
           informations : %s."
          Dcalc.Print.format_var n info
      in
      Errors.raise_spanned_error message Pos.no_pos

(** [add_var pos var is_pure ctx] adds to the context [ctx] the Dcalc variable var, creating a unique
corresponding variable in Lcalc, with the corresponding expression, and the boolean is_pure. It
is useful for debugging purposes as it can print each of the Dcalc/Lcalc variable pairs (the debug
print is currently commented out). *)
let add_var (pos : Pos.t) (var : D.Var.t) (is_pure : bool) (ctx : ctx) : ctx =
  (* The fresh Lcalc variable reuses the name of the Dcalc one. *)
  let lcalc_var = A.Var.make (Bindlib.name_of var, pos) in
  let entry = { expr = A.make_var (lcalc_var, pos); var = lcalc_var; is_pure } in
  (* Cli.debug_print @@ Format.asprintf "D.%a |-> A.%a" Dcalc.Print.format_var var Print.format_var
     lcalc_var; *)
  (* Any previous binding of [var] is overwritten. *)
  let vars = D.VarMap.update var (fun _previous -> Some entry) ctx.vars in
  { ctx with vars }

(** [tau' = translate_typ tau] translates a Dcalc type into an Lcalc type.
The positions where there are thunked expressions are exactly where we will put option
expressions. Hence, the transformation simply reduces [unit -> 'a] into ['a option] recursively.
There is no polymorphism inside catala. *)
let rec translate_typ (tau : D.typ Pos.marked) : D.typ Pos.marked =
  let translated =
    match Pos.unmark tau with
    | D.TAny -> D.TAny
    | D.TLit lit -> D.TLit lit
    | D.TArray element -> D.TArray (translate_typ element)
    | D.TTuple (fields, struct_name) -> D.TTuple (List.map translate_typ fields, struct_name)
    | D.TEnum (cases, enum_name) -> D.TEnum (List.map translate_typ cases, enum_name)
    (* A thunk [unit -> t] becomes the option enum over the translation of [t]. This case must stay
       before the generic arrow case below. *)
    | D.TArrow ((D.TLit D.TUnit, pos_unit), codomain) ->
        D.TEnum ([ (D.TLit D.TUnit, pos_unit); translate_typ codomain ], A.option_enum)
    | D.TArrow (domain, codomain) -> D.TArrow (translate_typ domain, translate_typ codomain)
  in
  (* The translated type keeps the position of the original one. *)
  Pos.same_pos_as translated tau

(** [translate_lit l pos] maps a Dcalc literal to the Lcalc literal of the same shape. The empty
    error has no Lcalc counterpart: meeting it here is reported as an internal error at [pos]. *)
let translate_lit (l : D.lit) (pos : Pos.t) : A.lit =
  match l with
  | D.LEmptyError ->
      Errors.raise_spanned_error
        "Internal Error: An empty error was found in a place that shouldn't be possible." pos
  | D.LUnit -> A.LUnit
  | D.LBool b -> A.LBool b
  | D.LInt n -> A.LInt n
  | D.LRat q -> A.LRat q
  | D.LMoney amount -> A.LMoney amount
  | D.LDate date -> A.LDate date
  | D.LDuration duration -> A.LDuration duration

(** [c = disjoint_union_maps pos cs] Compute the disjoint union of multiple maps. Raises an internal
error (reported at [pos]) if two of the maps share a key. *)
let disjoint_union_maps (pos : Pos.t) (cs : 'a A.VarMap.t list) : 'a A.VarMap.t =
  (* Merge function handed to [A.VarMap.union]: it is only called on a key present in both maps,
     which is exactly the situation that must never happen. *)
  let fail_on_shared_key _key _left _right =
    Errors.raise_spanned_error
      "Internal Error: Two supposed to be disjoints maps have one shared key." pos
  in
  List.fold_left
    (fun merged next_map -> A.VarMap.union fail_on_shared_key merged next_map)
    A.VarMap.empty cs

(** [e' = translate_and_hoist ctx e ] Translate the Dcalc expression e into an expression in Lcalc,
given we translate each hoist correctly. It ensures that the execution of e
and the execution of e' are equivalent in an environment where each variable v, where (v, e_v)
is in hoists, has the non-empty value in e_v. *)
let rec translate_and_hoist (ctx : ctx) (e : D.expr Pos.marked) :
    A.expr Pos.marked Bindlib.box * hoists =
  let pos = Pos.get_position e in
  match Pos.unmark e with
  (* empty-producing/using terms. We hoist those. (D.EVar in some cases, EApp(D.EVar _, [ELit
     LUnit]), EDefault _, ELit LEmptyDefault) I'm unsure about assert. *)
  | D.EVar v ->
      (* todo: for now, every unpure (such that [is_pure] is [false] in the current context) is
         thunked, hence matched in the next case. This assumption can change in the future, and this
         case is here for this reason. *)
      let v, pos_v = v in
      if not (find ~info:"search for a variable" v ctx).is_pure then
        (* Impure variable: replace it by a fresh Lcalc variable and record the original expression
           as the hoist bound to that fresh variable. *)
        let v' = A.Var.make (Bindlib.name_of v, pos_v) in
        (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to
           replace it" Dcalc.Print.format_var v Print.format_var v'; *)
        (A.make_var (v', pos), A.VarMap.singleton v' e)
      else ((find ~info:"should never happend" v ctx).expr, A.VarMap.empty)
  | D.EApp ((D.EVar (v, pos_v), p), [ (D.ELit D.LUnit, _) ]) ->
      (* Application of a variable to unit, i.e. the forcing of a thunk. The hoist stores the bare
         variable (without the application). *)
      if not (find ~info:"search for a variable" v ctx).is_pure then
        let v' = A.Var.make (Bindlib.name_of v, pos_v) in
        (* Cli.debug_print @@ Format.asprintf "Found an unpure variable %a, created a variable %a to
           replace it" Dcalc.Print.format_var v Print.format_var v'; *)
        (A.make_var (v', pos), A.VarMap.singleton v' (D.EVar (v, pos_v), p))
      else
        Errors.raise_spanned_error
          "Internal error: an pure variable was found in an unpure environment." pos
  | D.EDefault (_exceptions, _just, _cons) ->
      (* The whole default term is hoisted; its sub-terms are translated later, by
         [translate_expr]. *)
      let v' = A.Var.make ("default_term", pos) in
      (A.make_var (v', pos), A.VarMap.singleton v' e)
  | D.ELit D.LEmptyError ->
      let v' = A.Var.make ("empty_litteral", pos) in
      (A.make_var (v', pos), A.VarMap.singleton v' e)
  (* This one is a very special case. It transforms an unpure expression environment to a pure
     expression. *)
  | ErrorOnEmpty arg ->
      (* [ match arg with | None -> raise NoValueProvided | Some v -> {{ v }} ] *)
      let silent_var = A.Var.make ("_", pos) in
      let x = A.Var.make ("non_empty_argument", pos) in

      (* [arg] is translated with [translate_expr], so its own hoists are resolved here and nothing
         is hoisted out of this term. *)
      let arg' = translate_expr ctx arg in

      ( A.make_matchopt_with_abs_arms arg'
          (A.make_abs [| silent_var |]
             (Bindlib.box (A.ERaise A.NoValueProvided, pos))
             pos
             [ (D.TAny, pos) ]
             pos)
          (A.make_abs [| x |] (A.make_var (x, pos)) pos [ (D.TAny, pos) ] pos),
        A.VarMap.empty )
  (* pure terms *)
  | D.ELit l -> (Bindlib.box (A.ELit (translate_lit l pos), pos), A.VarMap.empty)
  | D.EIfThenElse (e1, e2, e3) ->
      let e1', h1 = translate_and_hoist ctx e1 in
      let e2', h2 = translate_and_hoist ctx e2 in
      let e3', h3 = translate_and_hoist ctx e3 in

      let e' =
        Bindlib.box_apply3 (fun e1' e2' e3' -> (A.EIfThenElse (e1', e2', e3'), pos)) e1' e2' e3'
      in

      (*(* equivalent code : *) let e' = let+ e1' = e1' and+ e2' = e2' and+ e3' = e3' in
        (A.EIfThenElse (e1', e2', e3'), pos) in *)
      (e', disjoint_union_maps pos [ h1; h2; h3 ])
  | D.EAssert e1 ->
      (* same behavior as in the ICFP paper: if e1 is empty, then no error is raised. *)
      let e1', h1 = translate_and_hoist ctx e1 in
      (Bindlib.box_apply (fun e1' -> (A.EAssert e1', pos)) e1', h1)
  | D.EAbs ((binder, pos_binder), ts) ->
      let vars, body = Bindlib.unmbind binder in
      (* Register every bound variable in the context and collect the matching Lcalc variables, in
         the same order as [vars]. *)
      let ctx, lc_vars =
        ArrayLabels.fold_right vars ~init:(ctx, []) ~f:(fun var (ctx, lc_vars) ->
            (* we suppose the invariant that when applying a function, its arguments cannot be of
               the type "option".

               The code should behave correctly without this assumption if we put here an
               is_pure=false, but the types are more complicated. (unimplemented for now) *)
            let ctx = add_var pos var true ctx in
            let lc_var = (find var ctx).var in
            (ctx, lc_var :: lc_vars))
      in
      let lc_vars = Array.of_list lc_vars in

      (* here we take the guess that if we cannot build the closure because one of the variable is
         empty, then we cannot build the function. *)
      let new_body, hoists = translate_and_hoist ctx body in
      let new_binder = Bindlib.bind_mvar lc_vars new_body in

      ( Bindlib.box_apply
          (fun new_binder -> (A.EAbs ((new_binder, pos_binder), List.map translate_typ ts), pos))
          new_binder,
        hoists )
  (* The remaining constructors are translated structurally: sub-expressions are translated
     recursively and their hoists are merged with [disjoint_union_maps]. *)
  | EApp (e1, args) ->
      let e1', h1 = translate_and_hoist ctx e1 in
      let args', h_args = args |> List.map (translate_and_hoist ctx) |> List.split in

      let hoists = disjoint_union_maps pos (h1 :: h_args) in
      let e' =
        Bindlib.box_apply2
          (fun e1' args' -> (A.EApp (e1', args'), pos))
          e1' (Bindlib.box_list args')
      in
      (e', hoists)
  | ETuple (args, s) ->
      let args', h_args = args |> List.map (translate_and_hoist ctx) |> List.split in

      let hoists = disjoint_union_maps pos h_args in
      (Bindlib.box_apply (fun args' -> (A.ETuple (args', s), pos)) (Bindlib.box_list args'), hoists)
  | ETupleAccess (e1, i, s, ts) ->
      let e1', hoists = translate_and_hoist ctx e1 in
      let e1' = Bindlib.box_apply (fun e1' -> (A.ETupleAccess (e1', i, s, ts), pos)) e1' in
      (e1', hoists)
  | EInj (e1, i, en, ts) ->
      let e1', hoists = translate_and_hoist ctx e1 in
      let e1' = Bindlib.box_apply (fun e1' -> (A.EInj (e1', i, en, ts), pos)) e1' in
      (e1', hoists)
  | EMatch (e1, cases, en) ->
      let e1', h1 = translate_and_hoist ctx e1 in
      let cases', h_cases = cases |> List.map (translate_and_hoist ctx) |> List.split in

      let hoists = disjoint_union_maps pos (h1 :: h_cases) in
      let e' =
        Bindlib.box_apply2
          (fun e1' cases' -> (A.EMatch (e1', cases', en), pos))
          e1' (Bindlib.box_list cases')
      in
      (e', hoists)
  | EArray es ->
      let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in

      ( Bindlib.box_apply (fun es' -> (A.EArray es', pos)) (Bindlib.box_list es'),
        disjoint_union_maps pos hoists )
  | EOp op -> (Bindlib.box (A.EOp op, pos), A.VarMap.empty)

(** [translate_expr ?append_esome ctx e] translates [e] with [translate_and_hoist], then resolves
    every hoist it returned: for each hoisted variable [v], the expression built so far is wrapped
    in an option match (built by [A.make_matchopt]) on the translation of the hoisted term, with
    [A.make_none] given for the empty case. When [append_esome] is [true] (the default), the
    translated expression is first wrapped with [A.make_some]; otherwise it is used as is. *)
and translate_expr ?(append_esome = true) (ctx : ctx) (e : D.expr Pos.marked) :
    A.expr Pos.marked Bindlib.box =
  let e', hoists = translate_and_hoist ctx e in
  let hoists = A.VarMap.bindings hoists in

  (* Unused: kept under an underscore-prefixed name. *)
  let _pos = Pos.get_position e in

  (* build the hoists *)
  (* Cli.debug_print @@ Format.asprintf "hoist for the expression: [%a]" (Format.pp_print_list
     Print.format_var) (List.map fst hoists); *)
  ListLabels.fold_left hoists
    ~init:(if append_esome then A.make_some e' else e')
    ~f:(fun acc (v, (hoist, pos_hoist)) ->
      (* Cli.debug_print @@ Format.asprintf "hoist using A.%a" Print.format_var v; *)
      (* [c'] is the option-valued Lcalc expression standing for the hoisted Dcalc term. *)
      let c' : A.expr Pos.marked Bindlib.box =
        match hoist with
        (* Here we have to handle only the cases appearing in hoists, as defined in the
           [translate_and_hoist] function. *)
        | D.EVar v -> (find ~info:"should never happend" (Pos.unmark v) ctx).expr
        | D.EDefault (excep, just, cons) ->
            let excep' = List.map (translate_expr ctx) excep in
            let just' = translate_expr ctx just in
            let cons' = translate_expr ctx cons in
            (* calls handle_option. *)
            A.make_app
              (A.make_var (A.handle_default_opt, pos_hoist))
              [
                Bindlib.box_apply
                  (fun excep' -> (A.EArray excep', pos_hoist))
                  (Bindlib.box_list excep');
                just';
                cons';
              ]
              pos_hoist
        | D.ELit D.LEmptyError -> A.make_none pos_hoist
        (* NOTE(review): no case of [translate_and_hoist] visible here puts a [D.EAssert] in the
           hoists, so this branch looks unreachable -- confirm before relying on it. *)
        | D.EAssert arg ->
            let arg' = translate_expr ctx arg in

            (* [ match arg with | None -> raise NoValueProvided | Some v -> assert {{ v }} ] *)
            let silent_var = A.Var.make ("_", pos_hoist) in
            let x = A.Var.make ("assertion_argument", pos_hoist) in

            A.make_matchopt_with_abs_arms arg'
              (A.make_abs [| silent_var |]
                 (Bindlib.box (A.ERaise A.NoValueProvided, pos_hoist))
                 pos_hoist
                 [ (D.TAny, pos_hoist) ]
                 pos_hoist)
              (A.make_abs [| x |]
                 (Bindlib.box_apply
                    (fun arg -> (A.EAssert arg, pos_hoist))
                    (A.make_var (x, pos_hoist)))
                 pos_hoist
                 [ (D.TAny, pos_hoist) ]
                 pos_hoist)
        | _ ->
            Errors.raise_spanned_error
              "Internal Error: An term was found in a position where it should not be" pos_hoist
      in

      (* [ match {{ c' }} with | None -> None | Some {{ v }} -> {{ acc }} end ] *)
      (* Cli.debug_print @@ Format.asprintf "build matchopt using %a" Print.format_var v; *)
      A.make_matchopt pos_hoist v (D.TAny, pos_hoist) c' (A.make_none pos_hoist) acc)

(** [translate_scope_let ctx lets] translates a chain of scope lets into nested [A.make_let_in]
    bindings, ending with the translation of the final [Result] expression. For each let, the bound
    Dcalc variable is registered with [add_var], the bound expression is translated with
    [translate_expr ~append_esome:false] in the context that does not yet contain that variable,
    and the rest of the chain is translated in the extended context. *)
let rec translate_scope_let (ctx : ctx) (lets : scope_lets) =
  match lets with
  | Result e -> translate_expr ~append_esome:false ctx e
  | ScopeLet
      {
        scope_let_kind = SubScopeVarDefinition;
        scope_let_typ = typ;
        scope_let_expr = D.EAbs ((binder, _), _), _pos;
        scope_let_next = next;
        scope_let_pos = pos;
      } ->
      (* special case : the subscope variable is thunked (context i/o). We remove this thunking. *)
      let _, expr = Bindlib.unmbind binder in

      let var_is_pure = true in
      let var, next = Bindlib.unbind next in
      (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
      let ctx' = add_var pos var var_is_pure ctx in
      let new_var = (find ~info:"variable that was just created" var ctx').var in
      A.make_let_in new_var (translate_typ typ)
        (translate_expr ctx ~append_esome:false expr)
        (translate_scope_let ctx' next)
  | ScopeLet
      {
        scope_let_kind = SubScopeVarDefinition;
        scope_let_typ = typ;
        scope_let_expr = (D.ErrorOnEmpty _, _) as expr;
        scope_let_next = next;
        scope_let_pos = pos;
      } ->
      (* special case: regular input to the subscope *)
      let var_is_pure = true in
      let var, next = Bindlib.unbind next in
      (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
      let ctx' = add_var pos var var_is_pure ctx in
      let new_var = (find ~info:"variable that was just created" var ctx').var in
      A.make_let_in new_var (translate_typ typ)
        (translate_expr ctx ~append_esome:false expr)
        (translate_scope_let ctx' next)
  (* Any other shape of subscope variable definition breaks the invariants this pass relies on. *)
  | ScopeLet
      { scope_let_kind = SubScopeVarDefinition; scope_let_pos = pos; scope_let_expr = expr; _ } ->
      Errors.raise_spanned_error
        (Format.asprintf
           "Internal Error: found an SubScopeVarDefinition that does not satisfy the invariants \
            when translating Dcalc to Lcalc without exceptions: @[<hov 2>%a@]"
           (Dcalc.Print.format_expr ctx.decl_ctx)
           expr)
        pos
  | ScopeLet
      {
        scope_let_kind = kind;
        scope_let_typ = typ;
        scope_let_expr = expr;
        scope_let_next = next;
        scope_let_pos = pos;
      } ->
      (* Only a thunked field destructured from the input struct is registered as impure. *)
      let var_is_pure =
        match kind with
        | DestructuringInputStruct -> (
            (* Here, we have to distinguish between context and input variables. We can do so by
               looking at the typ of the destructuring: if it's thunked, then the variable is
               context. If it's not thunked, it's a regular input. *)
            match Pos.unmark typ with D.TArrow ((D.TLit D.TUnit, _), _) -> false | _ -> true)
        | ScopeVarDefinition | SubScopeVarDefinition | CallingSubScope
        | DestructuringSubScopeResults | Assertion ->
            true
      in
      let var, next = Bindlib.unbind next in
      (* Cli.debug_print @@ Format.asprintf "unbinding %a" Dcalc.Print.format_var var; *)
      let ctx' = add_var pos var var_is_pure ctx in
      let new_var = (find ~info:"variable that was just created" var ctx').var in
      A.make_let_in new_var (translate_typ typ)
        (translate_expr ctx ~append_esome:false expr)
        (translate_scope_let ctx' next)

(** [translate_scope_body scope_pos ctx body] builds, with [A.make_abs], a one-argument abstraction
    over the variable bound by [body.scope_body_result]. Its body is the translation of the scope
    lets and its argument type is the tuple type tagged with the input struct of the scope. The
    bound variable is registered as pure, at position [scope_pos]. *)
let translate_scope_body (scope_pos : Pos.t) (ctx : ctx) (body : scope_body) :
    A.expr Pos.marked Bindlib.box =
  match body with
  | {
   scope_body_result = result;
   scope_body_input_struct = input_struct;
   scope_body_output_struct = _output_struct;
  } ->
      let v, lets = Bindlib.unbind result in
      let ctx' = add_var scope_pos v true ctx in
      let v' = (find ~info:"variable that was just created" v ctx').var in
      A.make_abs [| v' |] (translate_scope_let ctx' lets) Pos.no_pos
        [ (D.TTuple ([], Some input_struct), Pos.no_pos) ]
        Pos.no_pos

(** [translate_scopes ctx scopes] translates the chain of scope definitions into a boxed list of
    scope bodies, in the same order. Each scope variable is registered as pure in the context used
    for the scopes that follow it. *)
let rec translate_scopes (ctx : ctx) (scopes : scopes) : Ast.scope_body list Bindlib.box =
  match scopes with
  | Nil -> Bindlib.box []
  | ScopeDef { scope_name; scope_body; scope_next } ->
      let scope_var, next = Bindlib.unbind scope_next in
      let new_ctx = add_var Pos.no_pos scope_var true ctx in
      let new_scope_name = (find ~info:"variable that was just created" scope_var new_ctx).var in

      let scope_pos = Pos.get_position (D.ScopeName.get_info scope_name) in

      (* NOTE(review): the body is translated in [ctx], not [new_ctx], so the variable of the scope
         being defined is not visible in its own body -- presumably intended; confirm. *)
      let new_body = translate_scope_body scope_pos ctx scope_body in
      let tail = translate_scopes new_ctx next in

      Bindlib.box_apply2
        (fun body tail ->
          {
            Ast.scope_body_var = new_scope_name;
            scope_body_name = scope_name;
            scope_body_expr = body;
          }
          :: tail)
        new_body tail

(** Unboxed version of the function above, which it shadows for the rest of the file. *)
let translate_scopes (ctx : ctx) (scopes : scopes) : Ast.scope_body list =
  Bindlib.unbox (translate_scopes ctx scopes)

(** [translate_program prgm] is the entry point of the pass. It adds the option enum to the enum
    declarations, rewrites with [translate_typ] the field types of every struct used as a scope
    input, and translates all the scopes starting from an empty variable context. *)
let translate_program (prgm : D.program) : A.program =
  (* Input structs of all the scopes: only their field types are rewritten below. *)
  let inputs_structs =
    ListLabels.fold_left prgm.scopes ~init:[] ~f:(fun acc (_, _, body) ->
        body.D.scope_body_input_struct :: acc)
  in

  (* Cli.debug_print @@ Format.asprintf "List of structs to modify: [%a]" (Format.pp_print_list
     D.StructName.format_t) inputs_structs; *)
  let decl_ctx =
    {
      prgm.decl_ctx with
      D.ctx_enums = prgm.decl_ctx.ctx_enums |> D.EnumMap.add A.option_enum A.option_enum_config;
    }
  in
  let decl_ctx =
    {
      decl_ctx with
      D.ctx_structs =
        prgm.decl_ctx.ctx_structs
        |> D.StructMap.mapi (fun n l ->
               if List.mem n inputs_structs then
                 ListLabels.map l ~f:(fun (n, tau) ->
                     (* Cli.debug_print @@ Format.asprintf "Input type: %a" (Dcalc.Print.format_typ
                        decl_ctx) tau; Cli.debug_print @@ Format.asprintf "Output type: %a"
                        (Dcalc.Print.format_typ decl_ctx) (translate_typ tau); *)
                     (n, translate_typ tau))
               else l);
    }
  in

  let scopes =
    prgm.scopes |> bind_scopes |> Bindlib.unbox
    |> translate_scopes { decl_ctx; vars = D.VarMap.empty }
  in

  { scopes; decl_ctx }