(* This file is part of asak.
*
* Copyright (C) 2019 IRIF / OCaml Software Foundation.
*
* asak is distributed under the terms of the MIT license. See the
* included LICENSE file for details. *)

open Wtree

(* Distances between trees: either a regular non-negative integer, or
   [Infinity] for "no known path" (incomparable nodes). *)
module Distance = struct
  type t = Regular of int | Infinity

  (* Total order with [Infinity] as the greatest element.
     Fixed: the previous version returned 1 for
     [compare Infinity Infinity], which broke the comparator contract
     (equal values must compare as 0). [lt], [max] and [min] are
     unaffected by this change for every input. *)
  let compare x y =
    match x, y with
    | Infinity, Infinity -> 0
    | Infinity, Regular _ -> 1
    | Regular _, Infinity -> -1
    | Regular x', Regular y' -> compare x' y'

  let lt x y = compare x y < 0
  let max x y = if compare x y <= 0 then y else x
  let min x y = if compare x y <= 0 then x else y
end

(* A hash is a weight paired with its digest string. *)
module Hash = struct
  type t = int * string
  let compare = compare
end

module HashPairs = struct
  type t = Hash.t * Hash.t
  let compare = compare
end

module HPMap = Map.Make (HashPairs)
module HMap = Map.Make (Hash)
module HSet = Set.Make (Hash)

(* [hashtbl_update table ~default k f] rebinds [k] in [table] to
   [f v] where [v] is the current binding, or [f default] when [k] is
   unbound. Note: [default] is evaluated by the caller at each call. *)
let hashtbl_update table ~default k f =
  match Hashtbl.find table k with
  | exception Not_found -> Hashtbl.add table k (f default)
  | v -> Hashtbl.replace table k (f v)

(* Add [value] to the integer counter bound to [key] (0 when unbound). *)
let increment_key table key value =
  hashtbl_update table ~default:0 key (fun n -> n + value)

(* NB: the returned hashtable contains only keys (x,y) where x < y.
This is not a problem since the distance is symmetric. *)
let compute_all_sym_diff children xs =
  (* For each pair of parent nodes, compute the summed weight of the
     children in their symmetric difference.
     The idea: for a given parent, one traversal of its children is
     enough to find every "neighbor" (a parent sharing at least one
     child) together with the total weight of the children they have
     in common. The symmetric-difference weight then follows from the
     two parents' total child weights minus twice the common weight. *)
  let parents =
    (* Maps each child hash to its occurrence counts: a table from
       parent to the number of times the child occurs under it. *)
    let parents = Hashtbl.create 42 in
    let record_child ~parent (_weight, child) =
      hashtbl_update parents ~default:(Hashtbl.create 5) child
        (fun occurrence_map ->
           increment_key occurrence_map parent 1;
           occurrence_map)
    in
    let record_children parent =
      List.iter (record_child ~parent) (Hashtbl.find children parent)
    in
    List.iter record_children xs;
    parents
  in
  let total_weights =
    (* Maps each tree to the sum of its children's weights. *)
    let weights = Hashtbl.create 42 in
    let add_weight ~parent (weight, _child) =
      increment_key weights parent weight
    in
    let sum_children parent =
      List.iter (add_weight ~parent) (Hashtbl.find children parent)
    in
    List.iter sum_children xs;
    weights
  in
  let node_neighbors x =
    (* Table from each neighbor of [x] to their common child weight.
       Only neighbors strictly greater than [x] are kept, so each
       unordered pair is produced exactly once. *)
    let common_weights = Hashtbl.create 42 in
    let visit_child ~parent (weight, child) =
      let child_occurrences = Hashtbl.find parents child in
      let occurrences_in_parent = Hashtbl.find child_occurrences parent in
      let add_common_weight neighbor occurrences_in_neighbor =
        if parent < neighbor then begin
          let common_occurrences =
            min occurrences_in_parent occurrences_in_neighbor in
          increment_key common_weights neighbor (weight * common_occurrences)
        end
      in
      Hashtbl.iter add_common_weight child_occurrences
    in
    (* Deduplicate the child list so each distinct child is visited once. *)
    Hashtbl.find children x |> HSet.of_list |> HSet.iter (visit_child ~parent:x);
    common_weights
  in
  let nodes = ref HSet.empty in
  let dists = Hashtbl.create 42 in
  let treat_node x =
    let add_neighbor y common_weight =
      let dist =
        Hashtbl.find total_weights x
        + Hashtbl.find total_weights y
        - 2 * common_weight
      in
      nodes := HSet.add x (HSet.add y !nodes);
      Hashtbl.add dists (x, y) dist
    in
    Hashtbl.iter add_neighbor (node_neighbors x)
  in
  List.iter treat_node xs;
  !nodes, dists

(* Apply [f] to every ordered pair drawn from [xs]. *)
let iter_on_cart_prod f xs =
  List.iter (fun a -> List.iter (f a) xs) xs

(* Build undirected adjacency lists: [x] and [y] are linked iff the
   (ordered) pair appears as a key of [tbl]. *)
let adjacency_lists tbl xs =
  let neighbors = Hashtbl.create (List.length xs) in
  let add_link x y =
    if x < y && Hashtbl.mem tbl (x, y) then begin
      Hashtbl.add neighbors x y;
      Hashtbl.add neighbors y x;
    end
  in
  iter_on_cart_prod add_link xs;
  neighbors

(* Surapproximate classes by using the transitive closure of xRy <=> dist x y < Infinity.
If not (xRy) => dist x y = Infinity, thus x and y cannot be in the same class.
*)letsurapproximate_classesneighborsxs=letclasses=ref[]inletmissing_nodes=ref(List.fold_rightHSet.addxsHSet.empty)inwhilenot(HSet.is_empty!missing_nodes)doletstart=HSet.choose!missing_nodesinletto_visit=Stack.create()inStack.pushstartto_visit;letnew_class=refHSet.emptyinwhilenot(Stack.is_emptyto_visit)doletcur=Stack.popto_visitinifHSet.memcur !new_classthen()elsebeginnew_class:=HSet.addcur!new_class;List.iter(fun next->Stack.pushnextto_visit)(Hashtbl.find_allneighborscur)enddone;missing_nodes:=HSet.diff!missing_nodes!new_class;classes:=HSet.elements!new_class ::!classes;done;!classes(* code from the OCaml manual (tutorial on modules) *)modulePrioQueue=structtypepriority=inttype'aqueue=Empty|Nodeofpriority*'a*'aqueue*'aqueueletempty=Emptyletis_empty=function|Empty->true|Node_->falseletrecinsertqueueprioelt=match queue with|Empty->Node(prio,elt,Empty,Empty)|Node(p,e,left,right)->ifprio<=pthenNode(prio,elt,insertrightpe,left)elseNode(p,e,insertrightprioelt,left)exceptionQueue_is_emptyletrecremove_top=function|Empty->raiseQueue_is_empty|Node(_prio,_elt,left,Empty)->left|Node(_prio,_elt,Empty,right)->right|Node(_prio,_elt,(Node(lprio,lelt,_,_)asleft),(Node(rprio,relt,_,_)asright))->iflprio<=rpriothenNode(lprio,lelt,remove_topleft,right)elseNode(rprio,relt,left,remove_topright)let extract=function|Empty->raiseQueue_is_empty|Node(prio,elt,_,_)asqueue->(prio,elt,remove_topqueue)endletinitlenf=#ifOCAML_VERSION>=(4,06,0)List.initlenf#elseletrecauxinf=ifi>=nthen[]elseletr=fiinr::aux(i+1)nfinaux0lenf#endiflethierarchical_clusteringtblclasses=letclusterconnex_class=(* Within each connected component, we can expect most nodes to
have a distance to each other. We store nodes as dense arrays
by indexing the input leaves from 0 to N-1, and we store
distances between nodes in a dense matrix.
At most N-1 non-leaf nodes will be created by the clustering process.
(One way to prove this is to look at the number of nodes that are "roots"
(do not have a parent) at any point in the construction. At the beginning,
the N leaves are all roots. Each time we add a new node, its children were
roots and are not anymore, so have one less root. At the end, there is at least
one root.)
So we index all intermediate trees in the construction from 0 to 2N-2, starting
with the initial leaves, and filling the end of those arrays/matrices with nodes
as they constructed.
The fast algorithm keeps the following state:
- [nodes]: a mapping from indices to trees
- [roots]: a boolean array to know which trees are roots at the current step
- [distances]: a matrix of distances between all known nodes
- [next_slot]: the index of the next node to be constructed
- [queue]: a priority queue of pairs of nodes,
At a given step, only the indices from [0] to [!next_slot - 1] represent valid trees,
so the values in [roots], [distances] and [nodes] above [!next_slot] are invalid.
Because distances are symmetric, we store the distance between indices [i] and [j]
in [distances.(min i j).(max i j)].
(We could store only a diagonal matrix to save space.)
*)letsize=List.lengthconnex_classin(* N *)letclass_array=Array.of_list connex_classinletnodes=List.map(funx->Leafx)connex_class@init(size-1)(fun_->Leaf(List.hdconnex_class))|> Array.of_listinletroots=Array.make (2*size-1)trueinletqueue=refPrioQueue.emptyinletdistances=(* construct the initial distance matrix between leaves,
and initialize [queue] with the N^2 pairs of leaves *)letdist_leavesij=letx,y=class_array.(i),class_array.(j)inmatchHashtbl.findtbl(ifx<ythen(x,y)else (y,x))with|dist->queue:=PrioQueue.insert!queuedist(i,j);Distance.Regulardist|exceptionNot_found->Distance.InfinityinArray.init(2*size-1)(funi->Array.init(2*size-1)(funj->ifi<size&&j<size&&i<jthen dist_leaves ijelseDistance.Regular(-1)))inletnext_slot=refsizeinwhile not(PrioQueue.is_empty!queue||!next_slot =2*size-1)dolet(dist,(i,j),rest)=PrioQueue.extract!queuein(* We use the priority queue to select the pair of closest nodes.
We can only pair those nodes if they are not already roots,
we skip all the pairs that contain a non-root node.*)queue:=rest;assert(i<>j);if not(roots.(i)&&roots.(j))then()else beginletx,y=nodes.(i),nodes.(j)inletk=!next_slotinassert(k<2*size-1);nodes.(k)<-Node(dist,x,y);incrnext_slot;roots.(i)<-false;roots.(j)<-false;assertroots.(k);(* We create a new node [k], whose children [i] and [j] are not roots anymore.
(this may invalidate many pairs in [queue], but they stay there and
will be discarded when picked; [queue] is large so cleaning it up
would be too costly)
We now compute the distance between [k] and all older root
nodes (indices upto [k - 1]). At the same time we add each
pair ([k], [other_root]) into the priority queue.
*)forn=0tok-1doifnotroots.(n)then()else beginletdist=Distance.maxdistances.(minin).(maxin)distances.(minjn).(maxjn)inassert(distances.(n).(k)=Distance.Regular(-1));distances.(n).(k)<-dist;matchdistwith|Distance.Infinity ->()|Distance.Regulardist->queue:=PrioQueue.insert!queuedist(k,n)enddone;end;done;(* Complexity analysis: the priority queue is large, of the order
of O(N^2) entries during the traversal -- at a maintenance cost
of O(N^2 log N) with our implementation.
Each root-pair in the queue requires O(N) work to update the
distance matrix and the priority queue.
Naively this gives a O(N^3) bound, but in fact we encounter
much fewer than N^2 root-pairs: each time we see a minimal
root-pair, we create a new node, and we know that we create
a most N nodes. So among the O(N^2) elements of the queue,
only N require O(N) work, giving a total complexity of O(N^2 logN).
*)List.concat(init!next_slot (funi->ifroots.(i)then[nodes.(i)]else[]))inList.concat(List.mapclusterclasses)letadd_in_clustermap(x,h)=matchHMap.find_opt hmapwith|None->HMap.addh[x]map|Someys->HMap.addh(x::ys)mapletinitial_clusterhash_list=List.fold_leftadd_in_clusterHMap.emptyhash_listletcreate_start_clusterhash_list=letcluster =initial_clusterhash_listinHMap.fold(funkxsacc->(k,xs)::acc)cluster[]letcompare_size_of_trees xy=compare (size_of_treeList.lengthx)(size_of_treeList.lengthy)letadd_in_assoctbl(_,(h,xs))=ifnot (Hashtbl.memtblh)thenHashtbl.addtblh(List.sortcomparexs)letcluster?filter_small_trees(hash_list:('a*(Hash.t*Hash.tlist))list)=lethash_list=matchfilter_small_treeswith|None ->hash_list|Somet->List.filter(fun(_,((p,_),_))->p>=t)hash_list inletassoc_main_subs=Hashtbl.create(List.length hash_list)inList.iter(add_in_assocassoc_main_subs)hash_list;letstart=create_start_cluster(List.rev_map(fun(k,(h,_))->k,h)hash_list)inletstart,assoc_hash_ident_list=List.fold_left(fun(acc,m)(main_hash,xs)->main_hash::acc,HMap.addmain_hashxsm)([],HMap.empty)startinletcreate_leaf k=Leaf(HMap.findkassoc_hash_ident_list)inletwas_seen,hdistance_matrix=compute_all_sym_diffassoc_main_subsstartinletlst,alone=List.partition(funx->HSet.memxwas_seen)startinletneighobrs=adjacency_lists hdistance_matrixlstinletsurapprox=surapproximate_classesneighobrslstinletdendrogram_list =hierarchical_clustering hdistance_matrixsurapproxinletcluster=List.sort(funxy->-(compare_size_of_treesxy))(List.map(fold_tree(funabc->Node (a,b,c))create_leaf)dendrogram_list)incluster@(List.rev_mapcreate_leafalone)letprint_clustershowcluster=letrecauxi=function|Leaf x->beginprint_string(String.makei' ');List.iter(funx->print_string(showx^" "))x;print_endline "";end|Node(w,x,y)->beginprint_string(String.makei' ');print_string("Node "^string_of_int w^":");print_endline"";aux(i+1)x;aux(i+1)yendinletpclassix=print_endline ("Class "^string_of_inti^":");aux1xinList.iteripclasscluster