module UnionFind = UnionFindBasic

(* --------------------------------------------------------- *)
(* Level variables and predefined constants                   *)

(* A level is a display name plus a unique integer id; [equal], [compare]
   and [hash] only look at the id.  [public] and [secret] are the two
   predefined constants. *)
module Vl : sig
  type t
  val compare : t -> t -> int
  val equal : t -> t -> bool
  val hash : t -> int

  val public : t
  val secret : t
  val is_public : t -> bool
  val is_secret : t -> bool

  (* The predefined constants: [public] and [secret]. *)
  val constants : t list
  val is_constant : t -> bool

  (* [true] for levels whose name is "?" (the default name of [fresh]). *)
  val is_uni : t -> bool

  (* [fresh ?name ()] returns a level with a new id (default name "?"). *)
  val fresh : ?name:string -> unit -> t
  val name : t -> string

  (* Levels named "?" print as "?" followed by their id. *)
  val to_string : t -> string
  val pp : Format.formatter -> t -> unit
end = struct
  type t = {
    name : string;
    uid  : int;
  }

  let name t = t.name

  (* Last id handed out; the first [fresh] call yields id 0. *)
  let uid = ref (-1)

  let fresh ?(name = "?") _ =
    incr uid;
    { name; uid = !uid }

  let public = fresh ~name:"public" ()
  let secret = fresh ~name:"secret" ()

  (* Ids are non-negative and allocated one by one, so the subtraction
     does not overflow in practice. *)
  let compare vl1 vl2 = vl1.uid - vl2.uid
  let equal vl1 vl2 = vl1.uid = vl2.uid
  let hash vl = vl.uid

  let is_uni vl = vl.name = "?"

  let to_string vl =
    if is_uni vl then vl.name ^ string_of_int vl.uid
    else vl.name

  let pp fmt vl = Format.fprintf fmt "%s" (to_string vl)

  let constants = [public; secret]
  let is_constant vl = List.exists (equal vl) constants
  let is_public vl = equal public vl
  let is_secret vl = equal secret vl
end

module Hvl = Hashtbl.Make (Vl)
module Mvl = Map.Make (Vl)
module Svl = Set.Make (Vl)

(* --------------------------------------------------------- *)
(* Inequalities *)

(* The inequality graph.  A node holds a level and the list of its
   successors; an edge l -> l' records l <= l'.  Nodes are union-find
   elements so that levels found to be equal can be merged into one node. *)
module Lvl : sig
  type t
  val vlevel : t -> Vl.t
  val successors : t -> t list
  val fresh : Vl.t -> t list -> t
  val equal : t -> t -> bool
  val le : t -> t -> bool

  (* Do not use this function until you are sure that the merge will not
     make the constraints inconsistent *)
  val merge : t -> t -> t

  exception Unsat of (t list * t * t)
  val add_le : t -> t -> unit
  val iter : (t -> unit) -> t -> unit
  (* val clear_successors : t -> unit *)

  (* Warning: use this function only when you are sure not to create a cycle *)
  val add_successors : t -> t list -> unit
  val pp : Format.formatter -> t -> unit
  val pp_s : ?debug:bool -> Format.formatter -> t -> unit
  val is_public : t -> bool
  val is_secret : t -> bool
  val simplify : tokeep:(t -> bool) -> t -> unit
end = struct
  type level = {
    vlevel : Vl.t;   (* l *)
    succ   : t list; (* l' \in succ => l <= l' *)
  }
  and t = level UnionFind.elem

  let repr (e : t) = UnionFind.find e

  let fresh vlevel succ = UnionFind.make { vlevel; succ }

  let vlevel (e : t) = (UnionFind.get e).vlevel
  let successors (e : t) = (UnionFind.get e).succ

  let is_public (e : t) = Vl.is_public (vlevel e)
  let is_secret (e : t) = Vl.is_secret (vlevel e)

  (* Scratch visited set for [le]; cleared on every call, so [le] must not
     be re-entered. *)
  let visited = Hvl.create 97
  let clear_visited () = Hvl.clear visited
  let is_visited vl = Hvl.mem visited vl
  let set_visited vl = Hvl.add visited vl ()

  let equal (e1 : t) (e2 : t) = Vl.equal (vlevel e1) (vlevel e2)

  (* [le e1 e2] holds when e1 is public, when e2 is secret, or when the
     level of e2 is reachable from e1 by following successor edges. *)
  let le (e1 : t) (e2 : t) =
    let vl1 = vlevel e1 in
    let vl2 = vlevel e2 in
    if Vl.is_public vl1 || Vl.is_secret vl2 then true
    else begin
      clear_visited ();
      let rec find e =
        let vl = vlevel e in
        Vl.equal vl vl2
        || (not (is_visited vl)
            && (set_visited vl; List.exists find (successors e))) in
      find e1
    end

  (* [iter f e] calls [f] on every node reachable from [e] (including [e])
     in depth-first preorder, skipping levels already seen during this call.
     The seen-set is private to each call so that [f] may itself call [le],
     which clears the shared [visited] table. *)
  let iter (f : t -> unit) e =
    let seen = Hvl.create 97 in
    let rec go e =
      let vl = vlevel e in
      if not (Hvl.mem seen vl) then begin
        f e;
        Hvl.add seen vl ();
        List.iter go (successors e)
      end in
    go e

  (* Memo table for [between]: [Some repr] when the level lies on a path to
     the target, [None] when it does not.  Cleared on every call. *)
  let path = Hvl.create 97
  let clear_path () = Hvl.clear path
  let add_path vl e = Hvl.add path vl (Some (repr e)); true
  let add_nopath vl = Hvl.add path vl None; false

  (* [between l l' = s ]
     forall l1 \in s, l <= l1 <= l' *)
  let between (e : t) (e' : t) =
    clear_path ();
    let vl' = vlevel e' in
    (* Invariant: e <= e1 *)
    let rec find e1 =
      let vl1 = vlevel e1 in
      match Hvl.find path vl1 with
      | None -> false   (* no path from e1 to e' *)
      | Some _ -> true  (* e1 <= e' *)
      | exception Not_found ->
        if Vl.equal vl1 vl' then (* e <= e1 = e' *)
          add_path vl1 e1
        else
          (* Explore every successor (no short-circuit) so that each node
             lying on some path to e' gets recorded. *)
          let found =
            List.fold_left (fun b e2 -> find e2 || b) false (successors e1) in
          if found then add_path vl1 e1 (* e1 <= s <= e' *)
          else add_nopath vl1 in
    ignore (find e);
    Hvl.fold (fun _ e1 l -> match e1 with None -> l | Some e1 -> e1 :: l) path []

  (* Scratch table for [merge]: successors of the merged node keyed by level,
     which removes duplicates.  Cleared on every call. *)
  let succ = Hvl.create 97
  let clear_succ () = Hvl.clear succ
  let add_succ vl e = Hvl.replace succ vl (UnionFind.find e)

  (* [merge e1 e2] identifies the classes of [e1] and [e2].  The merged node
     keeps l2's level when it is a constant, or when l1's level is unnamed and
     l2's is named; otherwise it keeps l1's level.  Its successors are the
     union of both successor lists minus the edges pointing to either merged
     level.  The edge to secret is dropped unless it is the only successor
     (it is implied anyway, since [le] treats anything <= secret as true). *)
  let merge (e1 : t) (e2 : t) =
    clear_succ ();
    let merge_level (l1 : level) (l2 : level) =
      let add e =
        let vl = vlevel e in
        if not (Vl.equal l1.vlevel vl || Vl.equal l2.vlevel vl) then
          add_succ vl e in
      List.iter add l1.succ;
      List.iter add l2.succ;
      let vlevel =
        if Vl.is_constant l2.vlevel
           || (Vl.is_uni l1.vlevel && not (Vl.is_uni l2.vlevel))
        then l2.vlevel
        else l1.vlevel in
      let succ =
        if Hvl.length succ = 1 then Hvl.fold (fun _ e s -> e :: s) succ []
        else (* remove secret *)
          Hvl.fold (fun vl e s -> if Vl.equal vl Vl.secret then s else e :: s) succ [] in
      { vlevel; succ } in
    UnionFind.merge merge_level e1 e2

  (* let clear_successors (e:t) =
       clear_succ ();
       let lvl = UnionFind.get e in
       List.iter (fun s -> add_succ (vlevel s) s) lvl.succ;
       UnionFind.set e { lvl with succ = Hvl.fold (fun _ e s -> e :: s) succ [] } *)

  (* Puts the representatives of [succ] in front of the successors of [e];
     no cycle check is performed. *)
  let add_successors (e : t) succ =
    let lvl = UnionFind.get e in
    UnionFind.set e
      { lvl with succ = List.fold_left (fun ss s -> repr s :: ss) lvl.succ succ }

  exception Unsat of (t list * t * t)

  (* Add the constraint l1 <= l2.  Nothing happens when [le l1 l2] already
     holds.  Otherwise, if l1 is reachable from l2, the nodes between l2 and
     l1 form a cycle and are merged, unless that cycle contains both public
     and secret, in which case [Unsat (nodes, l1, l2)] is raised. *)
  let add_le (l1 : t) (l2 : t) =
    if le l1 l2 then () (* constraint already present, do nothing *)
    else
      match between l2 l1 with
      | [] -> (* no cycle, add the constraint *)
        let lvl = UnionFind.get l1 in
        UnionFind.set l1 { lvl with succ = l2 :: lvl.succ }
      | e :: es as ees -> (* cycle found *)
        (* Check that we do not merge two constants *)
        if List.exists is_public ees && List.exists is_secret ees then
          raise (Unsat (ees, l1, l2));
        ignore (List.fold_left merge e es)

  let pp fmt l = Vl.pp fmt (vlevel l)

  (* Prints the edges leaving [l] as "l <= s1 s2 ...".  Unless [debug] is
     set, edges to secret and all edges leaving public are hidden, and
     nothing is printed when no edge remains. *)
  let pp_s ?(debug = false) fmt l =
    let vl = vlevel l in
    let ppsucc succ =
      Format.fprintf fmt "%a <= @[%a@],@ "
        Vl.pp vl
        (Format.pp_print_list
           ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
           (fun fmt s -> Vl.pp fmt (vlevel s)))
        succ in
    let succ = successors l in
    if debug then ppsucc succ
    else
      let succ =
        if Vl.is_public vl then []
        else List.filter (fun s -> not (is_secret s)) succ in
      match succ with
      | [] -> ()
      | _ -> ppsucc succ

  (* [simplify ~tokeep l] rewrites the successor list of every node reachable
     from [l] so that it only contains nodes satisfying [tokeep] that are
     reached without passing through another such node: transitive edges are
     removed and nodes failing [tokeep] are bypassed.
     [visit x] returns (memoized in [_R]) a table mapping every kept level
     reachable from x to its node and a flag: [short] when it becomes a direct
     successor of x, [long] when some path reaches it through a kept node. *)
  let simplify ~(tokeep : t -> bool) (l : t) =
    let long = true and short = false in
    let _R = Hvl.create 97 in
    let rec visit x =
      let vx = vlevel x in
      try Hvl.find _R vx with Not_found ->
        let _M : (t * bool) Hvl.t = Hvl.create 23 in
        (* [long] wins over [short] when a level is reached several ways. *)
        let add_M z (ez, p) =
          let p' = p || (try snd (Hvl.find _M z) with Not_found -> short) in
          Hvl.replace _M z (ez, p') in
        let do_s y =
          let yin = tokeep y in
          Hvl.iter
            (if yin then (fun z (ez, _p) -> Hvl.replace _M z (ez, long))
             else (fun z ep -> add_M z ep))
            (visit y);
          if yin then add_M (vlevel y) (y, short) in
        List.iter do_s (successors x);
        Hvl.add _R vx _M;
        (* now rewrite the successors of x: keep only the [short] entries *)
        let succ =
          Hvl.fold (fun _ (s, p) succ -> if p = short then s :: succ else succ) _M [] in
        let lvl = UnionFind.get x in
        UnionFind.set x { lvl with succ };
        _M in
    ignore (visit l)
end

(* ----------------------------------------------------------- *)
(* paired types.
Essentially a shorthand for adding inequalities *)
module VlPairs = struct
  type t = Lvl.t * Lvl.t

  (* [add_le (n1, s1) (n2, s2)] adds n1 <= n2 and s1 <= s2. *)
  let add_le (n1, s1) (n2, s2) =
    Lvl.add_le n1 n2;
    Lvl.add_le s1 s2

  (* [add_le_speculative s' (_, s)] adds s' <= s on the second component only. *)
  let add_le_speculative s' (_, s) = Lvl.add_le s' s

  (* [normalise (l, _)] duplicates the first component. *)
  let normalise (l, _) = (l, l)
end

(* ----------------------------------------------------------- *)

(* A constraint set: an inequality graph reachable from its [public] node,
   plus a table from levels to the nodes created for them. *)
module C : sig
  type constraints
  val init : unit -> constraints
  val public : constraints -> Lvl.t
  val secret : constraints -> Lvl.t
  val fresh : ?name:string -> constraints -> Lvl.t
  val pp_debug : Format.formatter -> constraints -> unit
  val pp : Format.formatter -> constraints -> unit
  val simplify : constraints -> unit
  val prune : constraints -> Lvl.t list -> unit
  val optimize : constraints -> tomax:Lvl.t list -> tomin:Lvl.t list -> unit
  val clone : constraints -> constraints -> (Lvl.t -> Lvl.t)
  val is_instance : (Lvl.t -> Lvl.t) -> constraints -> constraints -> bool
end = struct
  type constraints = {
    (* Every level created through [add_vl], mapped to its node.  Read by
       [pp_debug] to print equalities and by [optimize] to turn a level back
       into a node. *)
    repr : Lvl.t Hvl.t;
    public : Lvl.t;
    secret : Lvl.t;
  }

  let public c = c.public
  let secret c = c.secret

  (* Creates a node for [vl] with the given successors and records it in [tbl]. *)
  let add_vl tbl vl successors =
    let l = Lvl.fresh vl successors in
    Hvl.add tbl vl l;
    l

  (* A fresh constraint set holding only the public and secret nodes, with no
     edge between them. *)
  let init () =
    let repr = Hvl.create 257 in
    let secret = add_vl repr Vl.secret [] in
    let public = add_vl repr Vl.public [] in
    { repr; public; secret }

  (* [fresh ?name c] creates a new level l in [c] with public <= l <= secret. *)
  let fresh ?name c =
    (* FIXME: can we remove secret if we do a special case for "add_le secret l" ?
       i.e. set l and all variable greater than l to secret *)
    let l = add_vl c.repr (Vl.fresh ?name ()) [c.secret] in
    Lvl.add_successors c.public [l];
    l

  (* Prints, in a vertical box, every merged level as "vl = repr" followed by
     every edge reachable from public (edges to secret included). *)
  let pp_debug fmt c =
    Format.fprintf fmt "@[<v>";
    (* print equalities *)
    Hvl.iter (fun vl l ->
        let vl' = Lvl.vlevel l in
        if not (Vl.equal vl vl') then
          Format.fprintf fmt "%a = %a@ " Vl.pp vl Vl.pp vl')
      c.repr;
    (* print inequalities *)
    Lvl.iter (Lvl.pp_s ~debug:true fmt) c.public;
    Format.fprintf fmt "@]"

  (* Do not print equalities or public <= l or l <= secret *)
  let pp fmt c =
    Format.fprintf fmt "@[";
    Lvl.iter (Lvl.pp_s ~debug:false fmt) c.public;
    Format.fprintf fmt "@]"

  (* [simplify c] simplifies the graph by removing transitive edges *)
  let simplify (c : constraints) =
    Lvl.simplify ~tokeep:(fun _ -> true) c.public

  (* [prune c ltokeep] unlinks from the graph every node that is not public,
     secret or in [ltokeep]; edges between the remaining nodes are kept
     (transitive ones are dropped). *)
  let prune (c : constraints) (ltokeep : Lvl.t list) =
    let tokeep = Hvl.create 31 in
    List.iter (fun l -> Hvl.replace tokeep (Lvl.vlevel l) ())
      (c.public :: c.secret :: ltokeep);
    Lvl.simplify ~tokeep:(fun l -> Hvl.mem tokeep (Lvl.vlevel l)) c.public

  type minmax =
    | Minimize
    | Maximize
    | MinMax

  (*
  let pp_minmax fmt = function
    | Maximize -> Format.fprintf fmt "Maximize"
    | MinMax -> Format.fprintf fmt "MinMax"
    | Minimize -> Format.fprintf fmt "Minimize"
  *)

  (* [optimize c ~tomax ~tomin] repeatedly merges nodes:
     - a [Maximize] node with exactly one successor is merged with it;
     - only when no such merge happened in the pass, a [Minimize] node with
       exactly one predecessor is merged with it.
     Public, secret, unlisted levels and levels listed in both [tomax] and
     [tomin] are [MinMax] and never drive a merge; merging two nodes with
     different marks yields [MinMax].  The graph is simplified after every
     pass that made progress.
     NOTE(review): the one-successor / one-predecessor tests look at the raw
     successor lists, which may still hold transitive edges (e.g. the edge to
     secret added by [fresh]); presumably callers simplify or prune [c]
     first -- confirm. *)
  let optimize (c : constraints) ~(tomax : Lvl.t list) ~(tomin : Lvl.t list) =
    let minmax = Hvl.create 97 in
    let add mm l =
      let vl = Lvl.vlevel l in
      match Hvl.find minmax vl with
      | mm' -> if mm <> mm' then Hvl.replace minmax vl MinMax
      | exception Not_found -> Hvl.add minmax vl mm in
    add MinMax (public c);
    add MinMax (secret c);
    List.iter (add Maximize) tomax;
    List.iter (add Minimize) tomin;
    let get_minmax l =
      try Hvl.find minmax (Lvl.vlevel l) with Not_found -> MinMax in
    let merge_minmax l s =
      let lmm = get_minmax l in
      let smm = get_minmax s in
      let mm = if lmm = smm then lmm else MinMax in
      add mm l;
      add mm s in
    let progress = ref true in
    while !progress do
      progress := false;
      (* try to maximize first *)
      Lvl.iter (fun l ->
          if get_minmax l = Maximize then
            match Lvl.successors l with
            | [s] ->
              progress := true;
              merge_minmax l s;
              ignore (Lvl.merge l s)
            | _ -> ())
        (public c);
      if not !progress then begin
        (* Compute the table of predecessors *)
        let pred = Hvl.create 97 in
        let get_pred l =
          try Hvl.find pred (Lvl.vlevel l) with Not_found -> Svl.empty in
        let add_pred p s =
          Hvl.replace pred (Lvl.vlevel s) (Svl.add (Lvl.vlevel p) (get_pred s)) in
        Lvl.iter (fun l -> List.iter (add_pred l) (Lvl.successors l)) (public c);
        (* minimize *)
        Lvl.iter (fun l ->
            if get_minmax l = Minimize then
              let p = get_pred l in
              if Svl.cardinal p = 1 then begin
                let p = Hvl.find c.repr (Svl.choose p) in
                progress := true;
                merge_minmax p l;
                ignore (Lvl.merge p l)
              end)
          (public c)
      end;
      if !progress then simplify c
    done

  (* let norm (c:constraints) =
       Lvl.iter Lvl.clear_successors c.public *)

  (* [clone c c'] copies the graph of [c] reachable from its public node into
     [c'] and returns the renaming from nodes of [c] to nodes of [c'].  Public
     and secret map to those of [c'], whose successors are extended; every
     other level gets a fresh level with the same name.  The returned function
     is memoized and copies on demand any node it has not seen yet. *)
  let clone (c : constraints) (c' : constraints) =
    let subst = Hvl.create 31 in
    let rec do_l l =
      let vl = Lvl.vlevel l in
      try Hvl.find subst vl with Not_found ->
        let successors = List.map do_l (Lvl.successors l) in
        let l =
          if Vl.is_public vl then c'.public
          else if Vl.is_secret vl then c'.secret
          else begin
            assert (not (Vl.is_constant vl));
            add_vl c'.repr (Vl.fresh ~name:(Vl.name vl) ()) []
          end in
        Lvl.add_successors l successors;
        Hvl.add subst vl l;
        l in
    ignore (do_l c.public);
    do_l

  (* [is_instance rho cu c] holds when every edge l -> s of [c] reachable from
     its public node is implied after renaming, i.e. [Lvl.le (rho l) (rho s)].
     [cu] itself is not inspected.  A [Not_found] raised while checking
     (including one escaping [rho]) makes the result [false].
     The reachable nodes are collected before any [Lvl.le] call, so that no
     [Lvl.le] runs in the middle of an [Lvl.iter] traversal (both use
     traversal scratch state inside [Lvl]); the check order is unchanged. *)
  let is_instance (rho : Lvl.t (* c *) -> Lvl.t (* cu *)) _cu c =
    try
      let nodes = ref [] in
      Lvl.iter (fun l -> nodes := l :: !nodes) c.public;
      List.for_all
        (fun l ->
           let lu = rho l in
           List.for_all (fun s -> Lvl.le lu (rho s)) (Lvl.successors l))
        (List.rev !nodes)
    with Not_found -> false
end