(* Pure literal elimination: a post-CNF preprocessing pass that removes every
   clause mentioning a "pure" symbol, i.e. a symbol that, among the clauses
   still kept, occurs as a literal head with only one polarity.  Removal is
   iterated to a fixpoint.  Enabled with [--pure-literal-preprocessing]. *)

open Logtk
open Libzipperposition

let section = Util.Section.make ~parent:Const.section "ple"

(* A symbol together with a polarity ([true] = positive occurrence). *)
type id_sgn = ID.t * bool

module IDMap = CCMap.Make(struct
    type t = id_sgn
    let compare (id1, sgn1) (id2, sgn2) =
      let open CCOrd in
      ID.compare id1 id2 <?> (CCBool.compare, sgn1, sgn2)
  end)

module IntSet = Util.Int_set
module TST = TypedSTerm

(* Raised when a literal head is not a symbol; [extension]'s modifier then
   returns its input unchanged. *)
exception AppVarFound

let _enabled = ref false

let total_clauses = Util.mk_stat "total_clauses"
let removed_clauses = Util.mk_stat "removed_clauses"

let pp_key = CCPair.pp ID.pp CCBool.pp

(* Set of every symbol occurring anywhere in the literals [lits]
   (heads and arguments alike). *)
let cl_syms lits =
  let lit_syms lit =
    SLiteral.fold
      (fun acc t -> ID.Set.union (ID.Set.of_iter (TST.Seq.symbols t)) acc)
      ID.Set.empty lit
  in
  List.fold_left (fun acc lit -> ID.Set.union acc (lit_syms lit)) ID.Set.empty lits

(* Scans the clausified statements of [seq] and returns a triple
   [(forbidden, ids2clauses, clauses)]:
   - [forbidden]: symbols that must never be treated as pure, because they
     occur in statements this pass does not analyse (data, lemmas,
     definitions, rewrite rules, and non-negated goals);
   - [ids2clauses]: for each counted head symbol, the set of indices of the
     clauses in which it occurs as a head;
   - [clauses]: one map (symbol, polarity) -> occurrence count per clause,
     in REVERSE order of indices (the most recent clause comes first).
   Type declarations are ignored.
   Raises [AppVarFound] if some counted literal head is not a symbol. *)
let compute_occurence_map (seq : (TST.t SLiteral.t list, TST.t, TST.t) Statement.t Iter.t) =
  (* Given [ids2clauses] and a clause [(cl_idx, lits)], returns [ids2clauses]
     extended with [cl_idx] for every head symbol of [lits], together with the
     clause's (symbol, polarity) -> occurrence-count map. *)
  let process_literals ids2clauses (cl_idx, lits) =
    let inc_map hd sign map =
      let prev = IDMap.get_or (hd, sign) map ~default:0 in
      IDMap.add (hd, sign) (prev + 1) map
    in
    let count_literal map lit =
      match lit with
      | SLiteral.Atom (pred, sign) -> inc_map (TST.head_exn pred) sign map
      | SLiteral.Neq (lhs, rhs)
      | SLiteral.Eq (lhs, rhs) when TST.Ty.is_prop (TST.ty_exn lhs) ->
        let (l_hd, r_hd) = CCPair.map_same TST.head_exn (lhs, rhs) in
        (* an (in)equation between propositions counts each side's head
           once positively and once negatively *)
        map
        |> inc_map l_hd true |> inc_map l_hd false
        |> inc_map r_hd true |> inc_map r_hd false
      | _ -> map
    in
    let all_symbol_occurences =
      (* Invalid_argument is raised if a head is not a symbol *)
      try List.fold_left count_literal IDMap.empty lits
      with Invalid_argument _ -> raise AppVarFound
    in
    let head_syms =
      IDMap.keys all_symbol_occurences |> Iter.map fst |> ID.Set.of_iter
    in
    let ids2clauses' =
      ID.Set.fold
        (fun sym acc ->
           let prev = ID.Map.get_or sym acc ~default:IntSet.empty in
           ID.Map.add sym (IntSet.add cl_idx prev) acc)
        head_syms ids2clauses
    in
    ids2clauses', all_symbol_occurences
  in
  (* The fourth component counts the clauses seen so far; it is the index of
     the next clause (avoids recomputing [List.length] for every clause). *)
  let forbidden, ids2clauses, clauses, _n_clauses =
    Iter.fold
      (fun (forbidden, ids2clauses, clauses, n)
        (stmt : (TST.t SLiteral.t list, TST.t, TST.t) Statement.t) ->
        match Statement.view stmt with
        (* Ignoring type declarations *)
        | Statement.TyDecl _ -> forbidden, ids2clauses, clauses, n
        | Statement.Data _
        | Statement.Lemma _
        | Statement.Def _
        | Statement.Rewrite _ ->
          let f_syms =
            Statement.Seq.forms stmt
            |> Iter.fold (fun acc lits -> ID.Set.union acc (cl_syms lits)) ID.Set.empty
          in
          ID.Set.union f_syms forbidden, ids2clauses, clauses, n
        | Statement.Assert lits ->
          (* normal clause *)
          let ids2clauses', new_cl = process_literals ids2clauses (n, lits) in
          forbidden, ids2clauses', new_cl :: clauses, n + 1
        | Statement.NegatedGoal (_skolems, list_of_lits) ->
          (* clauses stemming from the negated goal *)
          let ids2clauses', clauses', n' =
            List.fold_left
              (fun (ids2cls, cls, i) lits ->
                 let ids2cls', new_cl = process_literals ids2cls (i, lits) in
                 ids2cls', new_cl :: cls, i + 1)
              (ids2clauses, clauses, n) list_of_lits
          in
          forbidden, ids2clauses', clauses', n'
        | Statement.Goal lits ->
          (* After CNF all goals should be negated and clausified.  A leftover
             goal is kept untouched by the filtering step, so its symbols are
             forbidden (never considered pure) instead of aborting. *)
          ID.Set.union (cl_syms lits) forbidden, ids2clauses, clauses, n)
      (ID.Set.empty, ID.Map.empty, [], 0) seq
  in
  forbidden, ids2clauses, clauses

(* Computes the fixpoint set of pure symbols.
   [forbidden], [ids2clauses] and [clauses] are as returned by
   [compute_occurence_map].  A non-forbidden symbol is initially pure if its
   total count is zero for one of the two polarities.  Processing a pure
   symbol marks every not-yet-removed clause it heads as removed and subtracts
   that clause's counts; any non-forbidden, not-yet-seen symbol whose count for
   some polarity thereby reaches zero is queued as pure in turn.  Returns the
   set of symbols processed as pure.
   May raise [Not_found] if the maps are inconsistent. *)
let get_pure_symbols forbidden ids2clauses clauses =
  let calculate_pure init_pure cl_status all_clauses =
    (* index [i] of the array is the clause with index [i] *)
    let clauses = CCArray.of_list (List.rev clauses) in
    let rec aux processed symbol_occurences = function
      | [] -> processed
      | sym :: syms when ID.Set.mem sym processed || ID.Set.mem sym forbidden ->
        aux processed symbol_occurences syms
      | (sym :: syms) as all_syms ->
        (* clauses headed by [sym] that have not been removed yet *)
        let clauses_to_remove =
          CCList.filter_map
            (fun idx ->
               if CCBV.get cl_status idx then None
               else (
                 CCBV.set cl_status idx;
                 Some clauses.(idx)
               ))
            (IntSet.to_list @@ ID.Map.find sym ids2clauses)
        in
        let symbol_occurences', next_to_process =
          List.fold_left
            (fun acc cl ->
               IDMap.fold
                 (fun ((key_sym, _) as key) occ (sym_occs, pending) ->
                    let new_ = IDMap.find key sym_occs - occ in
                    let sym_occs' = IDMap.add key new_ sym_occs in
                    (* [key_sym] just lost its last occurrence with this
                       polarity; queue it unless it is already known *)
                    if new_ = 0
                    && not (ID.Set.mem key_sym processed
                            || ID.Set.mem key_sym forbidden
                            || CCList.mem ~eq:ID.equal key_sym all_syms
                            || CCList.mem ~eq:ID.equal key_sym pending)
                    then sym_occs', key_sym :: pending
                    else sym_occs', pending)
                 cl acc)
            (symbol_occurences, []) clauses_to_remove
        in
        Util.debugf ~section 1 "became pure: @[%a@]@."
          (fun k -> k (CCList.pp ID.pp) next_to_process);
        aux (ID.Set.add sym processed) symbol_occurences' (next_to_process @ syms)
    in
    aux ID.Set.empty all_clauses (ID.Set.to_list init_pure)
  in
  (* joins all clauses in one map with the total occurrences of each key *)
  let all_clauses =
    List.fold_left
      (fun acc cl -> IDMap.union (fun _ a b -> Some (a + b)) acc cl)
      IDMap.empty clauses
  in
  let all_symbols = IDMap.keys all_clauses |> Iter.map fst |> ID.Set.of_iter in
  let count key = IDMap.get_or ~default:0 key all_clauses in
  let init_pure =
    ID.Set.filter
      (fun sym ->
         not (ID.Set.mem sym forbidden)
         && (count (sym, true) = 0 || count (sym, false) = 0))
      all_symbols
  in
  Util.debugf ~section 1 "initially pure: \n@[%a@]@."
    (fun k -> k (ID.Set.pp ID.pp) init_pure);
  let clause_status = CCBV.create ~size:(List.length clauses) false in
  calculate_pure init_pure clause_status all_clauses

(* Removes from [seq] every clause (assertion or negated-goal clause) that
   mentions a pure symbol.  A negated goal whose clauses are all removed is
   dropped; a partially affected one is rebuilt from its surviving clauses.
   All other statements are kept unchanged.
   The occurrence analysis runs eagerly when this function is called (and may
   raise [AppVarFound] or [Not_found]); the filtering itself is done by
   [Iter.filter_map], so the statistics are updated each time the returned
   iterator is traversed.
   NOTE(review): [cl_syms] also collects symbols in argument positions, while
   purity is computed from literal heads only, so a clause where a pure symbol
   occurs only as an argument is removed as well — confirm this is intended,
   in particular for higher-order problems. *)
let remove_pure_clauses (seq : (TST.t SLiteral.t list, TST.t, TST.t) Statement.t Iter.t) =
  let forbidden, ids2cls, cls = compute_occurence_map seq in
  let pure_syms = ID.Set.diff (get_pure_symbols forbidden ids2cls cls) forbidden in
  (* [Some stmt] if [lits] mentions no pure symbol, [None] otherwise *)
  let filter_if_has_pure stmt lits =
    Util.incr_stat total_clauses;
    if ID.Set.is_empty (ID.Set.inter pure_syms (cl_syms lits)) then Some stmt
    else (
      Util.incr_stat removed_clauses;
      Util.debugf ~section 2 "removed: @[%a@]@."
        (fun k -> k (CCList.pp (SLiteral.pp TST.pp)) lits);
      None
    )
  in
  Iter.filter_map
    (fun stmt ->
       match Statement.view stmt with
       | Statement.TyDecl _
       | Statement.Data _
       | Statement.Lemma _
       | Statement.Def _
       | Statement.Rewrite _ -> Some stmt
       | Statement.Assert lits ->
         (* normal clause *)
         filter_if_has_pure stmt lits
       | Statement.NegatedGoal (skolems, list_of_lits) ->
         (* clauses stemming from the negated goal *)
         let new_cls = CCList.filter_map (fun x -> filter_if_has_pure x x) list_of_lits in
         if CCList.is_empty new_cls then None
         else if List.length new_cls = List.length list_of_lits then Some stmt
         else
           Some (Statement.neg_goal
                   ~attrs:(Statement.attrs stmt)
                   ~proof:(Statement.proof_step stmt)
                   ~skolems new_cls)
       | Statement.Goal _ ->
         (* after CNFing a 'normal' problem all goals should
            be negated and clausified *)
         Some stmt)
    seq

let extension =
  (* Runs the pass only when enabled; if the analysis raises [AppVarFound] or
     [Not_found], the statements are returned unchanged. *)
  let modifier (seq : (TST.t SLiteral.t list, TST.t, TST.t) Statement.t Iter.t) =
    if !_enabled then (
      try remove_pure_clauses seq
      with AppVarFound | Not_found -> seq
    ) else seq
  in
  let print_stats _ =
    if !_enabled then
      CCFormat.printf "%%%a@.%%%a@." Util.pp_stat removed_clauses Util.pp_stat total_clauses
  in
  Extensions.({ default with
                name = "pure_literal_elimination";
                post_cnf_modifiers = [modifier];
                env_actions = [print_stats];
              })

let () =
  Options.add_opts
    [ "--pure-literal-preprocessing",
      Arg.Bool (fun v -> _enabled := v),
      " remove all pure literals in fixpoint" ];
  Extensions.register extension