123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315# 1 "src/owl/nlp/owl_nlp_lda.ml"(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2017
* Ben Catterall <bpwc2@cam.ac.uk>
* Liang Wang <liang.wang@cl.cam.ac.uk>
*)[@@@warning"-6"](** NLP: LDA module *)typelda_typ=|SimpleLDA|FTreeLDA|LightLDA|SparseLDAtypemodel={mutablen_d:int;(* number of documents *)mutablen_k:int;(* number of topics *)mutablen_v:int;(* number of vocabulary *)mutablealpha:float;(* model hyper-parameters *)mutablebeta:float;(* model hyper-parameters *)mutablealpha_k:float;(* model hyper-parameters *)mutablebeta_v:float;(* model hyper-parameters *)mutablet_dk:floatarrayarray;(* document-topic table: num of tokens assigned to each topic in each doc *)mutablet_wk:floatarrayarray;(* word-topic table: num of tokens assigned to each topic for each word *)mutablet__k:floatarray;(* number of tokens assigned to a topic: k = sum_w t_wk = sum_d t_dk *)mutablet__z:intarrayarray;(* table of topic assignment of each token in each document *)mutableiter:int;(* number of iterations *)mutabledata:Owl_nlp_corpus.t;(* training data, tokenised*)mutablevocb:(string,int)Hashtbl.t;(* vocabulary, or dictionary if you prefer *)}letinclude_tokenmwdk=m.t__k.(k)<-(m.t__k.(k)+.1.);m.t_wk.(w).(k)<-(m.t_wk.(w).(k)+.1.);m.t_dk.(d).(k)<-(m.t_dk.(d).(k)+.1.)letexclude_tokenmwdk=m.t__k.(k)<-(m.t__k.(k)-.1.);m.t_wk.(w).(k)<-(m.t_wk.(w).(k)-.1.);m.t_dk.(d).(k)<-(m.t_dk.(d).(k)-.1.)letshow_info_mit=Owl_log.info"iter#%i t(s):%.1f t_dk:%.3f t_wk:%.3f"it0.0.(* implement several LDA with specific samplings *)moduleSimpleLDA=structletinit_m=()letsamplingmddoc=letp=Array.makem.n_k0.inArray.iteri(funiw->letk=m.t__z.(d).(i)inexclude_tokenmwdk;(* make cdf function *)letx=ref0.inforj=0tom.n_k-1dox:=!x+.(m.t_dk.(d).(j)+.m.alpha_k)*.(m.t_wk.(w).(j)+.m.beta)/.(m.t__k.(j)+.m.beta_v);p.(j)<-!x;done;(* draw a sample *)letu=Owl_stats.std_uniform_rvs()*.!xinletk=ref0inwhilep.(!k)<udok:=!k+1done;include_tokenmwd!k;m.t__z.(d).(i)<-!k;)docendmoduleSparseLDA=structlets=ref0.(* Cache of s *)letq=ref[||](* Cache of q *)letr_non_zero:(int,float)Hashtbl.tref=ref(Hashtbl.create1)(* *)letq_non_zero:(int,bool)Hashtbl.tref=ref(Hashtbl.create1)(* *)letexclude_token_sparsemwdk~s~r~q=lett__klocal=refm.t__k.(k)in(* Reduce s, r l *)s:=!s-.(m.beta*.m.alpha_k)/.(!t__klocal+.m.beta_v);r:=!r-.(m.beta*.m.t_dk.(d).(k))/.(m.beta_v+.!t__klocal);exclude_tokenmwdk;(* add back in s,r *)t__klocal:=m.t__k.(k);Array.set!qk((m.alpha_k+.m.t_dk.(d).(k))/.(m.beta_v+.!t__klocal));letr_local=m.t_dk.(d).(k)in(matchr_localwith|0.->Hashtbl.remove!r_non_zerok|_->(Hashtbl.replace!r_non_zerokr_local;r:=!r+.(m.beta*.r_local)/.(m.beta_v+.!t__klocal)));s:=!s+.(m.beta*.m.alpha_k)/.(!t__klocal+.m.beta_v)letinclude_token_sparsemwdk~s~r~q=lett__klocal=refm.t__k.(k)in(* Reduce s, r l *)s:=!s-.(m.beta*.m.alpha_k)/.(!t__klocal+.m.beta_v);r:=!r-.(m.beta*.m.t_dk.(d).(k))/.(m.beta_v+.!t__klocal);include_tokenmwdk;(* add back in s, r *)t__klocal:=m.t__k.(k);s:=!s+.(m.beta*.m.alpha_k)/.(!t__klocal+.m.beta_v);letr_local=m.t_dk.(d).(k)in(matchr_localwith|0.->Hashtbl.remove!r_non_zerok|_->(Hashtbl.replace!r_non_zerokr_local;r:=!r+.(m.beta*.r_local)/.(m.beta_v+.!t__klocal)));Array.set!qk((m.alpha_k+.m.t_dk.(d).(k))/.(m.beta_v+.!t__klocal))letinitm=(* reset module parameters, maybe wrap into model? *)s:=0.;q:=[||];Hashtbl.reset!r_non_zero;Hashtbl.reset!q_non_zero;(* s is independent of document *)letk=ref0inwhile!k<m.n_kdolett__klocal=m.t__k.(!k)ins:=!s+.(1./.(m.beta_v+.t__klocal));k:=!k+1;done;q:=(Array.make(m.n_k)0.);r_non_zero:=(Hashtbl.createm.n_k);q_non_zero:=(Hashtbl.createm.n_k);s:=!s*.(m.alpha_k*.m.beta)letsamplingmddoc=letk=ref0inletr=ref0.in(* Cache of r *)(* Calculate r *)Hashtbl.clear!r_non_zero;while!k<m.n_kdolett__klocal=m.t__k.(!k)inletr_local=m.t_dk.(d).(!k)in(* Sparse representation of r *)ifr_local!=0.then(letr_val=r_local/.(m.beta_v+.t__klocal)inr:=!r+.r_val;Hashtbl.add!r_non_zero!kr_val;);(* Build up our q cache *)(* TODO: efficiently handle t_dk = 0 *)Array.set!q!k((m.alpha_k+.m.t_dk.(d).(!k))/.(m.beta_v+.t__klocal));k:=!k+1;done;r:=!r*.m.beta;(* Process the document *)Array.iteri(funiw->letk=m.t__z.(d).(i)inexclude_token_sparsemwdksrq;(* Calculate q *)letqsum=ref0.inletk_q=ref0inHashtbl.clear!q_non_zero;(* This bit makes it (K) rather than O(K_d + K_w) *)while!k_q<m.n_kdoletq_local=m.t_wk.(w).(!k_q)inifq_local!=0.then(qsum:=!qsum+.!q.(!k_q)*.q_local;Hashtbl.add!q_non_zero!k_qtrue;);k_q:=!k_q+1;done;k_q:=0;letu=ref(Owl_stats.std_uniform_rvs()*.(!s+.!r+.!qsum))inletk=ref0in(* Work out which factor to sample from *)if!u<!sthen((* sum up *)u:=!u/.(m.alpha_k*.m.beta);(* Don't need this *)letslocal=ref0.inwhile!slocal<!udoslocal:=!slocal+.(1./.(m.beta_v+.m.t__k.(!k_q)));k_q:=!k_q+1;done;(* Found our topic (we went past it by one) *)k:=!k_q-1;)elseif!u<(!s+.!r)then((* Iterate over set of non-zero r *)u:=(!u-.!s)/.m.beta;(* compare just to r and don't need !beta *)letrlocal=ref0.in(* TODO: pick largest (order by decreasing) for efficiency *)Hashtbl.iter(funkeydata->if!rlocal<!uthen(rlocal:=!rlocal+.(data)/.(m.beta_v+.m.t__k.(key));k:=key))!r_non_zero)else(u:=!u-.(!s+.!r);letqlocal=ref0.in(* Iterate over set of non-zero q *)(* TODO: make descending *)Hashtbl.iter(funkey_->if!qlocal<!uthen(qlocal:=!qlocal+.!q.(key)*.m.t_wk.(w).(key);k:=key))!q_non_zero);include_token_sparsemwd!ksrq;m.t__z.(d).(i)<-!k;)docendmoduleFTreeLDA=structletinit_m=failwith"FTreeLDA: not implemented"letsampling_m_d_doc=failwith"FTreeLDA: not implemented"endmoduleLightLDA=structletinit_m=failwith"LightLDA: not implemented"letsampling_m_d_doc=failwith"LightLDA: not implemented"end(* init the model based on: topics, vocabulary, tokens *)letinit?(iter=100)kvd=Owl_log.info"init the model";(* set basic model stats *)letn_d=Owl_nlp_corpus.lengthdinletn_v=Hashtbl.lengthvinletn_k=kin(* set model hyper-parameters *)letalpha=50.inletbeta=0.1inletalpha_k=alpha/.(float_of_intn_k)inletbeta_v=(float_of_intn_v)*.betain(* init model parameters *)lett_dk=Array.initn_d(fun_->Array.maken_k0.)inlett_wk=Array.initn_v(fun_->Array.maken_k0.)inlett__k=Array.maken_k0.in(* set document data and vocabulary *)letdata=dinletvocb=vin(* init a partial model *)letm={n_d;n_k;n_v;alpha;beta;alpha_k;beta_v;t_dk;t_wk;t__k;t__z=[||];iter;data;vocb;}in(* randomise the topic assignment for each token *)m.t__z<-Owl_nlp_corpus.mapi_tok(funis->Array.init(Array.lengths)(funj->letk'=Owl_stats.uniform_int_rvs~a:0~b:(k-1)ininclude_tokenms.(j)ik';k'))d;m(* general training function *)lettraintypm=letsampling=matchtypwith|SimpleLDA->SimpleLDA.sampling|FTreeLDA->FTreeLDA.sampling|LightLDA->LightLDA.sampling|SparseLDA->SparseLDA.samplinginletinit=matchtypwith|SimpleLDA->SimpleLDA.init|FTreeLDA->FTreeLDA.init|LightLDA->LightLDA.init|SparseLDA->SparseLDA.initininitm;fori=0tom.iter-1dolett0=Unix.gettimeofday()inOwl_nlp_corpus.iteri_tok(funjdoc->(* Owl_log.info "iteration #%i - doc#%i" i j; *)samplingmjdoc)m.data;lett1=Unix.gettimeofday()inshow_infomi(t1-.t0);done