(************************************************************************)
(*         *   The Coq Proof Assistant / The Coq Development Team       *)
(*  v      *         Copyright INRIA, CNRS and contributors             *)
(* <O___,, * (see version control and CREDITS file for authors & dates) *)
(*   \VV/  **************************************************************)
(*    //   *    This file is distributed under the terms of the         *)
(*         *     GNU Lesser General Public License Version 2.1          *)
(*         *     (see LICENSE file for the text of the license)         *)
(************************************************************************)

(** An equation is of the following form a1.x1 + ... + an.xn = c
*)

type var = int

type id = int

(** Closed integer intervals [(lb, ub)], i.e. the integers x with
    lb <= x <= ub. *)
module Itv = struct
  let debug = false

  type t = int * int (* We only consider closed intervals *)

  (** Raised when an operation would produce an empty interval. *)
  exception Empty

  let output o (lb, ub) = Printf.fprintf o "[%i,%i]" lb ub

  (** [wf o (lb,ub)] reports a malformed interval (lb > ub) on [o];
      it never raises. *)
  let wf o (lb, ub) =
    if lb > ub then Printf.fprintf o "error %a\n" output (lb, ub)

  (** [mul_cst c (lb,ub)] is the interval of [c.x] for [x] in [(lb,ub)].
      For [c >= 0] this is [(c*lb, c*ub)]; for [c < 0] the bounds are
      swapped so that the result stays well-formed. *)
  let mul_cst c (lb, ub) =
    if c >= 0 then (c * lb, c * ub) else (c * ub, c * lb)

  (** [opp (lb,ub)] is multiplication by -1 *)
  let opp (lb, ub) = (-ub, -lb)

  let opp i1 =
    let i = opp i1 in
    if debug then Printf.printf "opp %a -> %a\n" output i1 output i;
    i

  (* [floor_div a b] and [ceil_div a b] round the quotient a / b down and
     up respectively, for b > 0. OCaml's division truncates towards zero
     and [mod] has the sign of the dividend, so only the side on which
     truncation rounds the wrong way needs a correction. *)
  let floor_div a b =
    let q = a / b in
    if a mod b < 0 then q - 1 else q

  let ceil_div a b =
    let q = a / b in
    if a mod b > 0 then q + 1 else q

  (** [div (lb,ub) c] is the interval of the integers [x] such that [c.x]
      lies in [(lb,ub)]. Requires [c > 0]; bounds of either sign are
      handled (for non-negative bounds this is the usual truncating
      division, rounded up on the lower bound).
      @raise Empty if there is no such integer. *)
  let div (lb, ub) c =
    let lb = ceil_div lb c in
    let ub = floor_div ub c in
    if lb <= ub then (lb, ub) else raise Empty

  let div i c =
    try
      let r = div i c in
      if debug then Printf.printf "%a div %i -> %a\n" output i c output r;
      r
    with Empty ->
      if debug then Printf.printf "%a div %i -> Empty \n" output i c;
      raise Empty

  (** Interval sum. *)
  let add (lb1, ub1) (lb2, ub2) = (lb1 + lb2, ub1 + ub2)

  let add i1 i2 =
    let i = add i1 i2 in
    if debug then Printf.printf "%a add %a -> %a\n" output i1 output i2 output i;
    i

  (** [inter i1 i2] is the intersection of [i1] and [i2].
      @raise Empty if the intersection is empty. *)
  let inter : t -> t -> t =
   fun (lb1, ub1) (lb2, ub2) ->
    let lb = max lb1 lb2 in
    let ub = min ub1 ub2 in
    if lb <= ub then (lb, ub) else raise Empty

  let inter i1 i2 =
    try
      let i = inter i1 i2 in
      if debug then
        Printf.printf "%a inter %a -> %a\n" output i1 output i2 output i;
      i
    with Empty ->
      if debug then Printf.printf "%a inter %a -> Empty\n" output i1 output i2;
      raise Empty

  (** [enum (lb,ub)] splits off the smallest element: it returns [lb]
      together with [Some (lb+1, ub)], or [None] when the interval is the
      singleton [lb]. Only defined for finite intervals. *)
  let enum (lb, ub) = if lb = ub then (lb, None) else (lb, Some (lb + 1, ub))

  (** Number of integers in a finite well-formed interval.
      Note: this overflows for [top]. *)
  let range (lb, ub) = ub - lb + 1

  (** The interval of all representable integers. *)
  let top = (min_int, max_int)

  (** [lt i1 i2] holds when [i1] contains fewer integers than [i2]. *)
  let lt i1 i2 = range i1 < range i2
end

(** Maps from variables to their current interval. *)
module ItvMap = struct
  module M = Map.Make (Int)
  include M

  (** [refine_with v i m] intersects the interval of [v] in [m] with [i],
      or binds [v] to [i] if [v] is unbound. It returns
      [(progress, i', m')] where [i'] is the new interval of [v], [m'] the
      updated map, and [progress] is [true] when [v] was unbound or its
      interval strictly shrank.
      @raise Itv.Empty if the intersection is empty. *)
  let refine_with v i m =
    match M.find_opt v m with
    | None -> (true, i, M.add v i m)
    | Some i0 ->
      let i' = Itv.inter i i0 in
      (Itv.lt i' i0, i', M.add v i' m)

  (** [pick m] returns a binding of [m] whose interval has the smallest
      range; on ties, the one with the smallest variable.
      @raise Not_found if [m] is empty. *)
  let pick m =
    let best =
      fold
        (fun v i acc ->
          let r = Itv.range i in
          match acc with
          | Some (_, _, r') when r' <= r -> acc
          | _ -> Some (v, i, r))
        m None
    in
    match best with None -> raise Not_found | Some (x, i, _) -> (x, i)

  (** Prints every binding of [m] on [o]. *)
  let output o m =
    Printf.fprintf o "[";
    iter (fun k (lb, ub) -> Printf.fprintf o "x%i -> [%i,%i] \n" k lb ub) m;
    Printf.fprintf o "]"
end

(** Raised when a system of equations is detected to have no solution. *)
exception Unsat

(** Linear equations a1.x1 + ... + an.xn = c, represented as the list of
    monomials [(xi, ai)] paired with the constant [c]. *)
module Eqn = struct
  type t = (var * int) list * int

  let empty = ([], 0)

  let rec output_lin o l =
    match l with
    | [] -> Printf.fprintf o "0"
    | [(x, v)] -> Printf.fprintf o "%i.x%i" v x
    | (x, v) :: l' -> Printf.fprintf o "%i.x%i + %a" v x output_lin l'

  (** [normalise e] is [None] for the trivial equation 0 = 0 and [Some e]
      when [e] has at least one monomial.
      @raise Unsat for 0 = c with c <> 0. *)
  let normalise (l, c) =
    match l with
    | [] -> if c = 0 then None else raise Unsat
    | _ -> Some (l, c)

  (** [no_dup l] holds when no variable occurs twice in [l]. *)
  let rec no_dup l =
    match l with
    | [] -> true
    | (x, _) :: l -> (not (List.mem_assoc x l)) && no_dup l

  (* [merge_duplicates l] sums the coefficients of repeated variables,
     keeping each variable at the position of its first occurrence, and
     drops the variables whose coefficients cancel out. *)
  let merge_duplicates l =
    let absorb acc (x, a) =
      if List.mem_assoc x acc then
        List.map (fun ((y, b) as mon) -> if y = x then (y, a + b) else mon) acc
      else acc @ [(x, a)]
    in
    List.filter (fun (_, a) -> a <> 0) (List.fold_left absorb [] l)

  (** [add e1 e2] is the equation e1 + e2. Monomials over the same variable
      are merged so that the result never mentions a variable twice:
      [get_remove] only removes the first occurrence of a variable, so a
      repeated variable would survive substitution. *)
  let add (l1, c1) (l2, c2) =
    let l = l1 @ l2 in
    ((if no_dup l then l else merge_duplicates l), c1 + c2)

  (** [itv_of_ax m (x, a)] is the interval of a.x given the bounds [m].
      @raise Not_found if [x] is unbound in [m]. *)
  let itv_of_ax m (var, coe) = Itv.mul_cst coe (ItvMap.find var m)

  (** [itv_list m l] is the interval of the sum of the monomials [l]. *)
  let itv_list m l =
    List.fold_left (fun i ax -> Itv.add i (itv_of_ax m ax)) (0, 0) l

  (** [get_remove x l] returns the coefficient of [x] in [l] (0 if absent)
      together with [l] without [x]. *)
  let get_remove x l =
    let c = match List.assoc_opt x l with Some c -> c | None -> 0 in
    (c, List.remove_assoc x l)
end

type eqn = Eqn.t

open Eqn

let debug = false

(** Given an equation a1.x1 + ... + an.xn = c,
bound all the variables xi in [0; c/ai].
    With a single variable the bound is exact, x1 = c/a1, and
    [Itv.Empty] is raised when a1 does not divide c (for a1 > 0) or when
    the intersection with a bound already in [m] is empty.
    The equation 0 = c raises [Unsat] unless c = 0. *)
let init_bound m (v, c) =
  (* Intersect the bound of [var] in [map] with [itv], keeping the map. *)
  let narrow var itv map =
    let _, _, map' = ItvMap.refine_with var itv map in
    map'
  in
  match v with
  | [] -> if c <> 0 then raise Unsat else m
  | [(x, a)] -> narrow x (Itv.div (c, c) a) m
  | _ :: _ :: _ ->
    List.fold_left (fun acc (x, a) -> narrow x (0, c / a) acc) m v

(** Initial bounds for every variable of the system [sys], obtained by
    folding [init_bound] over its equations from the empty map. *)
let init_bounds sys =
  let bounds = List.fold_left init_bound ItvMap.empty sys in
  if debug then Printf.printf "init_bound : %a\n" ItvMap.output bounds;
  bounds

(* [refine_bound p m acc (v,c)]
   improves the bounds of the equation v + acc = c, where [acc] is the
   interval of the monomials already traversed and [p] records whether
   some bound has already been improved
*)
let rec refine_bound p m acc (v, c) =
  (* Sanity check: reports a malformed accumulator on stdout. *)
  Itv.wf stdout acc;
  match v with
  | [] -> (m, p)
  | (var, coe) :: v' ->
    if debug then
      Printf.printf "Refining %i.x%i + %a + %a = %i\n" coe var Itv.output acc
        output_lin v' c;
    (* Interval of everything except coe.var, restricted to [0, c].
       NOTE(review): the restriction presumes every monomial is
       non-negative; confirm against callers. *)
    let itv_acc_l = Itv.inter (0, c) (Itv.add acc (itv_list m v')) in
    (* coe.var = c - (rest) *)
    let itv_coe_var = Itv.add (c, c) (Itv.opp itv_acc_l) in
    let i = Itv.div itv_coe_var coe in
    let b, i', m = ItvMap.refine_with var i m in
    refine_bound (p || b) m (Itv.add (Itv.mul_cst coe i') acc) (v', c)

(** [refine_bounds p m l] applies [refine_bound] to every equation of [l],
    threading the bounds; returns the new bounds and whether some bound
    improved (or [p] was already [true]). *)
let refine_bounds p m l =
  List.fold_left (fun (m, p) eqn -> refine_bound p m (0, 0) eqn) (m, p) l

(** [refine_until_fix m l] repeats [refine_bounds] until no bound improves.
    @raise Itv.Empty if some interval becomes empty. *)
let refine_until_fix m l =
  let rec iter_refine m =
    let m', b = refine_bounds false m l in
    if b then iter_refine m' else m'
  in
  iter_refine m

(** [subst x a l] replaces [x] by [a] in every equation of [l], dropping
    the equations that become 0 = 0. The result is in reverse order.
    @raise Unsat if some equation becomes 0 = c with c <> 0. *)
let subst x a l =
  let subst_eqn acc (v, c) =
    let coe, v' = Eqn.get_remove x v in
    let c' = c - (coe * a) in
    match v' with
    | [] -> if c' = 0 then acc else raise Unsat
    | _ -> (v', c') :: acc
  in
  List.fold_left subst_eqn [] l

let output_list elt o l =
  Printf.fprintf o "[";
  List.iter (fun e -> Printf.fprintf o "%a; " elt e) l;
  Printf.fprintf o "]"

let output_equations o l =
  let output_equation o (l, c) = Printf.fprintf o "%a = %i" output_lin l c in
  output_list output_equation o l

let output_intervals o m =
  ItvMap.iter (fun k v -> Printf.fprintf o "x%i:%a " k Itv.output v) m

type solution = (var * int) list

(** [solve_system l] enumerates the solutions of the system [l] by interval
    propagation and case splitting. Each solution is a map binding the
    variables to singleton intervals; the result is [[]] when the system
    is unsatisfiable. *)
let solve_system l =
  (* Run [f ()], turning a detected inconsistency into no solution. *)
  let or_unsat f =
    try f ()
    with (Unsat | Itv.Empty) as e ->
      if debug then Printf.printf "Unsat detected \n%s\n" (Printexc.to_string e);
      []
  in
  let rec solve m l =
    if debug then Printf.printf "Solve %a\n" output_equations l;
    match l with
    | [] -> [m] (* we have a solution *)
    | _ -> or_unsat (fun () -> split m l)
  (* Refine [m], then branch on the variable k with the smallest interval
     [lb, ub]: either k = lb, or k lies in [lb+1, ub]. Both branches start
     from the refined bounds [m'], which are implied by [m] and [l]. *)
  and split m l =
    let m' = refine_until_fix m l in
    try
      if debug then Printf.printf "Refined %a\n" ItvMap.output m';
      let k, i = ItvMap.pick m' in
      let v, itv' = Itv.enum i in
      (* We recursively solve using k = v. This branch is guarded on its
         own so that an inconsistency raised by [subst] does not discard
         the other branch. *)
      let sol1 =
        or_unsat (fun () ->
            List.map (ItvMap.add k (v, v))
              (solve (ItvMap.remove k m') (subst k v l)))
      in
      let sol2 =
        match itv' with
        | None -> []
        | Some itv' ->
          (* We recursively solve with a smaller interval *)
          solve (ItvMap.add k itv' m') l
      in
      sol1 @ sol2
    with Not_found ->
      Printf.printf "NOT FOUND %a %a\n" output_equations l output_intervals m';
      raise Not_found
  in
  try
    let l = CList.map_filter Eqn.normalise l in
    solve (init_bounds l) l
  with Itv.Empty | Unsat -> []

(** [enum_sol m] expands a map of intervals into every assignment it
    contains; variables assigned 0 are left out of an assignment. *)
let enum_sol m =
  let rec augment_sols x (lb, ub) s =
    let slb = if lb = 0 then s else List.rev_map (fun s -> (x, lb) :: s) s in
    if lb = ub then slb
    else
      let sl = augment_sols x (lb + 1, ub) s in
      List.rev_append slb sl
  in
  ItvMap.fold augment_sols m [[]]

(** Concatenates the expansions of all the maps of [l] (order is not
    preserved). *)
let enum_sols l =
  List.fold_left (fun s m -> List.rev_append (enum_sol m) s) [] l

let solve_and_enum l = enum_sols (solve_system l)

let output_solution o s =
  let output_var_coef o (x, v) = Printf.fprintf o "x%i:%i" x v in
  output_list output_var_coef o s;
  Printf.fprintf o "\n"

let output_solutions o l = output_list output_solution o l

(** Incremental construction of systems of equations *)

open Mutils

type system = Eqn.t IMap.t

let empty : system = IMap.empty

(** [set_constant idx c s] is equation [idx] of [s] (or the empty equation
    if absent) with its constant replaced by [c].
    NOTE(review): this returns the equation, not an updated system; [s]
    is left unchanged. Confirm callers insert the result back. *)
let set_constant (idx : int) (c : int) (s : system) : Eqn.t =
  let e = try IMap.find idx s with Not_found -> Eqn.empty in
  (fst e, c)

(** [make_mon idx v c s] binds [idx] in [s] to the equation c.v = 0. *)
let make_mon (idx : int) (v : var) (c : int) (s : system) : system =
  IMap.add idx ([(v, c)], 0) s

(** [merge s1 s2] unions two systems; equations sharing an index are
    summed with [Eqn.add]. *)
let merge (s1 : system) (s2 : system) : system =
  IMap.merge
    (fun _ e1 e2 ->
      match (e1, e2) with
      | None, None -> None
      | None, Some e | Some e, None -> Some e
      | Some e1, Some e2 -> Some (Eqn.add e1 e2))
    s1 s2