(* This file is free software, part of Zipperposition. See file "license" for more details. *)

(** {1 Imperative Union-Find structure} *)

(** We need to be able to hash and compare keys, and values need to form
    a monoid ([merge] is the monoid operation, [zero] its neutral element). *)
module type PAIR = sig
  type key
  type value

  val hash : key -> int
  val equal : key -> key -> bool
  val merge : value -> value -> value
  val zero : value
end

(** Build a union-find module from a key/value specification *)
module Make(P : PAIR) = struct
  type key = P.key
  (** Elements that can be compared *)

  type value = P.value
  (** Values associated with elements *)

  type node = {
    n_key : key;
    mutable n_repr : node option; (* parent link; [None] for a representative (root) *)
    mutable n_value : value; (* value (only up-to-date for representative) *)
  }

  module H = CCHashtbl.Make(struct include P type t = P.key end)

  (** The union-find imperative structure itself: maps each key to its node. *)
  type t = node H.t

  (** Fresh singleton node for [key], holding [P.zero]. *)
  let mk_node key = {
    n_repr = None;
    n_key = key;
    n_value = P.zero;
  }

  (** [create keys] builds a structure in which every key of [keys] is its own
      singleton class with value [P.zero]. A key listed several times gets a
      single node (later occurrences replace earlier ones). *)
  let create keys =
    let t = H.create 8 in
    List.iter (fun key -> H.replace t key (mk_node key)) keys;
    t

  (** [mem t key] is [true] iff [key] has been registered in [t]. *)
  let mem t key = H.mem t key

  (* Field accessors, convenient in pipelines. *)
  let n_repr n = n.n_repr
  let n_key n = n.n_key
  let n_value n = n.n_value

  (** Representative (root) of the class containing [node], with path
      compression: every node on the path is re-pointed directly to the root.
      Implemented as two tail-recursive passes because [union] does not
      balance by rank, so parent chains can grow linearly long and a plain
      recursive walk could overflow the stack. *)
  let find_root (node:node) : node =
    (* first pass: follow parent links up to the root *)
    let rec climb n =
      match n.n_repr with
      | None -> n
      | Some parent ->
        assert (n != parent);
        climb parent
    in
    let root = climb node in
    (* second pass: path compression; nodes already pointing at [root] are left alone *)
    let rec compress n =
      match n.n_repr with
      | Some parent when parent != root ->
        n.n_repr <- Some root;
        compress parent
      | _ -> ()
    in
    compress node;
    root

  (** Node registered for [key].
      @raise Not_found if [key] is not in [t]. *)
  let find_node t key = H.find t key

  (** Key of the representative of [key]'s class.
      @raise Not_found if [key] is not in [t]. *)
  let find t key = find_node t key |> find_root |> n_key

  (** Get value of the root for this key.
      @raise Not_found if [key] is not in [t]. *)
  let find_value t key = find_node t key |> find_root |> n_value

  (** Link two representatives: [n2] is attached under [n1], and [n1]'s value
      becomes [P.merge n1.n_value n2.n_value]. Does nothing when [n1 == n2].
      Both arguments must be roots (checked by assertion). *)
  let union_node n1 n2 =
    assert (n1.n_repr = None && n2.n_repr = None);
    if n1 != n2 then (
      n1.n_value <- P.merge n1.n_value n2.n_value;
      n2.n_repr <- Some n1;
    )

  (** Merge the classes of [k1] and [k2]. The representative of [k1]'s class
      becomes the representative of the union, and the two values are
      combined with [P.merge].
      @raise Not_found if [k1] or [k2] is not in [t]. *)
  let union t k1 k2 =
    let n1 = find_node t k1 |> find_root in
    let n2 = find_node t k2 |> find_root in
    union_node n1 n2

  (** Add the given value to the key (monoid): [value] is merged into the
      value of [key]'s representative with [P.merge]. If [key] is not yet in
      [t], it is registered as a new singleton class holding [value].
      Only the table lookup is guarded against [Not_found], so an exception
      escaping [P.merge] propagates instead of silently re-adding [key]. *)
  let add t key value =
    match find_node t key with
    | node ->
      let root = find_root node in
      root.n_value <- P.merge root.n_value value
    | exception Not_found ->
      H.add t key { (mk_node key) with n_value = value }

  (** Iterate on representatives and their values. *)
  let iter t f =
    H.iter
      (fun key node ->
         (* only roots carry an up-to-date value *)
         if P.equal key node.n_key && node.n_repr = None
         then f key node.n_value)
      t

  (** Same as [iter], but the callback receives [(key, value)] pairs. *)
  let to_iter t f =
    iter t (fun x y -> f (x,y))
end