# 1 "src/owl/nlp/owl_nlp_lda0.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2017
 *   Ben Catterall <bpwc2@cam.ac.uk>
 *   Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

(** NLP: LDA module *)

open Bigarray

module MS = Owl_sparse.Dok_matrix
module MD = Owl_dense.Matrix.D

(** Sampling strategies selectable in [train]. *)
type lda_typ = SimpleLDA | FTreeLDA | LightLDA | SparseLDA

type dsmat = (float, float64_elt) Owl_dense_matrix_generic.t

type spmat = (float, float64_elt) Owl_sparse_dok_matrix.t

type model = {
  mutable n_d     : int;                      (* number of documents *)
  mutable n_k     : int;                      (* number of topics *)
  mutable n_v     : int;                      (* number of vocabulary *)
  mutable alpha   : float;                    (* model hyper-parameters *)
  mutable beta    : float;                    (* model hyper-parameters *)
  mutable alpha_k : float;                    (* model hyper-parameters *)
  mutable beta_v  : float;                    (* model hyper-parameters *)
  mutable t_dk    : dsmat;                    (* document-topic table: num of tokens assigned to each topic in each doc *)
  mutable t_wk    : spmat;                    (* word-topic table: num of tokens assigned to each topic for each word *)
  mutable t__k    : dsmat;                    (* number of tokens assigned to a topic: k = sum_w t_wk = sum_d t_dk *)
  mutable t__z    : int array array;          (* table of topic assignment of each token in each document *)
  mutable iter    : int;                      (* number of iterations *)
  mutable data    : int array array;          (* training data, tokenised *)
  mutable vocb    : (string, int) Hashtbl.t;  (* vocabulary, or dictionary if you prefer *)
}

(** [include_token m w d k] counts one occurrence of word [w] in document [d]
    as assigned to topic [k]: increments [t__k], [t_wk] and [t_dk]. *)
let include_token m w d k =
  MD.(set m.t__k 0 k (get m.t__k 0 k +. 1.));
  MS.(set m.t_wk w k (get m.t_wk w k +. 1.));
  MD.(set m.t_dk d k (get m.t_dk d k +. 1.))

(** [exclude_token m w d k] is the inverse of [include_token]: decrements
    [t__k], [t_wk] and [t_dk] for word [w], document [d], topic [k]. *)
let exclude_token m w d k =
  MD.(set m.t__k 0 k (get m.t__k 0 k -. 1.));
  MS.(set m.t_wk w k (get m.t_wk w k -. 1.));
  MD.(set m.t_dk d k (get m.t_dk d k -. 1.))

(** [likelihood m] returns the per-token log2 score of the current topic
    assignment, i.e. the sum over documents of
    [sum_tokens log2 (sum_k ...) - dlen * log2 dlen], divided by the total
    number of tokens.

    Empty documents are skipped: they contribute no tokens, and evaluating
    [0. *. log2 0.] for them would yield NaN and poison the whole sum.
    If the corpus has no token at all the result is [0. /. 0.] (NaN). *)
let likelihood m =
  let _sum = ref 0. in
  let n_token = ref 0 in
  (* every document *)
  for i = 0 to m.n_d - 1 do
    let dlen = Array.length m.t__z.(i) in
    if dlen > 0 then begin
      n_token := !n_token + dlen;
      let dsum = ref 0. in
      (* every token *)
      for j = 0 to dlen - 1 do
        let wsum = ref 0. in
        let w = m.data.(i).(j) in
        (* every topic *)
        for k = 0 to m.n_k - 1 do
          wsum := !wsum +. ((MD.get m.t_dk i k +. m.alpha_k)
                            *. (MS.get m.t_wk w k +. m.beta)
                            /. (MD.get m.t__k 0 k +. m.beta_v))
        done;
        dsum := !dsum +. Owl_maths.log2 !wsum
      done;
      let dlen = float_of_int dlen in
      _sum := !_sum +. !dsum -. (dlen *. Owl_maths.log2 dlen)
    end
  done;
  !_sum /. float_of_int !n_token

(** [show_info m i t] logs one line for iteration [i]: elapsed time [t] in
    seconds, the densities of [t_dk] and [t_wk], and the likelihood. *)
let show_info m i t =
  (* [i mod 1 = 0] always holds, so the likelihood is reported on every
     iteration; the modulus is the knob to report it less often. *)
  let s = match i mod 1 = 0 with
    | true  -> Printf.sprintf "likelihood:%.3f" (likelihood m)
    | false -> ""
  in
  (* keep the format on a single line so each record is one log line *)
  Owl_log.info "iter#%i t(s):%.1f t_dk:%.3f t_wk:%.3f %s"
    i t (MD.density m.t_dk) (MS.density m.t_wk) s

(* implement several LDA with specific samplings *)

module SimpleLDA = struct

  (** Nothing to prepare for this sampler. *)
  let init _m = ()

  (** [sampling m d] resamples the topic of every token of document [d]
      by drawing from the cumulative (unnormalised) topic distribution. *)
  let sampling m d =
    let p = MD.zeros 1 m.n_k in
    Array.iteri (fun i w ->
      let k = m.t__z.(d).(i) in
      exclude_token m w d k;
      (* make cdf function *)
      let x = ref 0. in
      for j = 0 to m.n_k - 1 do
        x := !x +. ((MD.get m.t_dk d j +. m.alpha_k)
                    *. (MS.get m.t_wk w j +. m.beta)
                    /. (MD.get m.t__k 0 j +. m.beta_v));
        MD.set p 0 j !x
      done;
      (* draw a sample *)
      let u = Owl_stats.std_uniform_rvs () *. !x in
      let k = ref 0 in
      while MD.get p 0 !k < u do k := !k + 1 done;
      include_token m w d !k;
      m.t__z.(d).(i) <- !k
    ) m.data.(d)

end


module SparseLDA = struct

  let s = ref 0.        (* Cache of s *)
  let q = ref [| |]     (* Cache of q *)

  (* topic -> raw t_dk count, for topics with a non-zero count in the
     document being processed *)
  let r_non_zero : (int, float) Hashtbl.t ref = ref (Hashtbl.create 1)

  (* set of topics with a non-zero t_wk count for the word being processed *)
  let q_non_zero : (int, bool) Hashtbl.t ref = ref (Hashtbl.create 1)

  (** [exclude_token_sparse m w d k ~s ~r ~q] removes the token from the
      count tables and keeps the cached masses [s], [r], the coefficient
      [q.(k)] and the [r_non_zero] table consistent with the new counts. *)
  let exclude_token_sparse m w d k ~s ~r ~q =
    let t__klocal = ref (MD.get m.t__k 0 k) in
    (* Reduce s, r l *)
    s := !s -. (m.beta *. m.alpha_k /. (!t__klocal +. m.beta_v));
    r := !r -. (m.beta *. MD.get m.t_dk d k /. (m.beta_v +. !t__klocal));
    exclude_token m w d k;
    (* add back in s,r *)
    t__klocal := MD.get m.t__k 0 k;
    !q.(k) <- (m.alpha_k +. MD.get m.t_dk d k) /. (m.beta_v +. !t__klocal);
    let r_local = MD.get m.t_dk d k in
    (match r_local with
     | 0. -> Hashtbl.remove !r_non_zero k
     | _  ->
       Hashtbl.replace !r_non_zero k r_local;
       r := !r +. (m.beta *. r_local /. (m.beta_v +. !t__klocal)));
    s := !s +. (m.beta *. m.alpha_k /. (!t__klocal +. m.beta_v))

  (** [include_token_sparse m w d k ~s ~r ~q] adds the token to the count
      tables and keeps the cached masses [s], [r], the coefficient [q.(k)]
      and the [r_non_zero] table consistent with the new counts. *)
  let include_token_sparse m w d k ~s ~r ~q =
    let t__klocal = ref (MD.get m.t__k 0 k) in
    (* Reduce s, r l *)
    s := !s -. (m.beta *. m.alpha_k /. (!t__klocal +. m.beta_v));
    r := !r -. (m.beta *. MD.get m.t_dk d k /. (m.beta_v +. !t__klocal));
    include_token m w d k;
    (* add back in s, r *)
    t__klocal := MD.get m.t__k 0 k;
    s := !s +. (m.beta *. m.alpha_k /. (!t__klocal +. m.beta_v));
    let r_local = MD.get m.t_dk d k in
    (match r_local with
     | 0. -> Hashtbl.remove !r_non_zero k
     | _  ->
       Hashtbl.replace !r_non_zero k r_local;
       r := !r +. (m.beta *. r_local /. (m.beta_v +. !t__klocal)));
    !q.(k) <- (m.alpha_k +. MD.get m.t_dk d k) /. (m.beta_v +. !t__klocal)

  (** [init m] resets the module-level caches and computes [s], which does
      not depend on the document. *)
  let init m =
    (* reset module parameters, maybe wrap into model? *)
    s := 0.;
    q := [| |];
    Hashtbl.reset !r_non_zero;
    Hashtbl.reset !q_non_zero;
    (* s is independent of document *)
    for k = 0 to m.n_k - 1 do
      let t__klocal = MD.get m.t__k 0 k in
      s := !s +. (1. /. (m.beta_v +. t__klocal))
    done;
    q := Array.make m.n_k 0.;
    r_non_zero := Hashtbl.create m.n_k;
    q_non_zero := Hashtbl.create m.n_k;
    s := !s *. (m.alpha_k *. m.beta)

  (** [sampling m d] resamples the topic of every token of document [d],
      splitting the sampling mass into the three buckets [s], [r] and the
      per-word sum of [q.(k) * t_wk(w,k)]. *)
  let sampling m d =
    let r = ref 0. in   (* Cache of r *)
    (* Calculate r *)
    Hashtbl.clear !r_non_zero;
    for k = 0 to m.n_k - 1 do
      let t__klocal = MD.get m.t__k 0 k in
      let r_local = MD.get m.t_dk d k in
      (* Sparse representation of r.
         Structural [<>] is required here: [!=] is physical inequality and
         does not test the float's value. *)
      if r_local <> 0. then (
        r := !r +. (r_local /. (m.beta_v +. t__klocal));
        (* store the raw count: the r-bucket draw below divides by
           (beta_v + t__k) itself, and the include/exclude helpers store
           raw counts too *)
        Hashtbl.add !r_non_zero k r_local);
      (* Build up our q cache *)
      (* TODO: efficiently handle t_dk = 0 *)
      !q.(k) <- (m.alpha_k +. MD.get m.t_dk d k) /. (m.beta_v +. t__klocal)
    done;
    r := !r *. m.beta;
    (* Process the document *)
    Array.iteri (fun i w ->
      let old_k = m.t__z.(d).(i) in
      exclude_token_sparse m w d old_k ~s ~r ~q;
      (* Calculate q *)
      let qsum = ref 0. in
      Hashtbl.clear !q_non_zero;
      (* This bit makes it O(K) rather than O(K_d + K_w) *)
      for kq = 0 to m.n_k - 1 do
        let q_local = MS.get m.t_wk w kq in
        if q_local <> 0. then (
          qsum := !qsum +. (!q.(kq) *. q_local);
          Hashtbl.add !q_non_zero kq true)
      done;
      let u = ref (Owl_stats.std_uniform_rvs () *. (!s +. !r +. !qsum)) in
      let k = ref 0 in
      (* Work out which factor to sample from *)
      if !u < !s then (
        (* sum up *)
        u := !u /. (m.alpha_k *. m.beta);
        (* Don't need this *)
        let slocal = ref 0. in
        let k_q = ref 0 in
        (* bounded by n_k so rounding can never walk past the last topic *)
        while !slocal < !u && !k_q < m.n_k do
          slocal := !slocal +. (1. /. (m.beta_v +. MD.get m.t__k 0 !k_q));
          k_q := !k_q + 1
        done;
        (* Found our topic (we went past it by one); clamp at 0 for the
           case u = 0. where the loop body never runs *)
        k := max 0 (!k_q - 1)
      )
      else if !u < !s +. !r then (
        (* Iterate over set of non-zero r *)
        u := (!u -. !s) /. m.beta;
        (* compare just to r and don't need !beta *)
        let rlocal = ref 0. in
        (* TODO: pick largest (order by decreasing) for efficiency *)
        Hashtbl.iter (fun key data ->
          if !rlocal < !u then (
            rlocal := !rlocal +. (data /. (m.beta_v +. MD.get m.t__k 0 key));
            k := key)
        ) !r_non_zero
      )
      else (
        u := !u -. (!s +. !r);
        let qlocal = ref 0. in
        (* Iterate over set of non-zero q *)
        (* TODO: make descending *)
        Hashtbl.iter (fun key _ ->
          if !qlocal < !u then (
            qlocal := !qlocal +. (!q.(key) *. MS.get m.t_wk w key);
            k := key)
        ) !q_non_zero
      );
      include_token_sparse m w d !k ~s ~r ~q;
      m.t__z.(d).(i) <- !k
    ) m.data.(d)

end


(* Placeholder: no sampler implemented yet, both functions are no-ops. *)
module FTreeLDA = struct

  let init _m = ()

  let sampling _m _d = ()

end


(* Placeholder: no sampler implemented yet, both functions are no-ops. *)
module LightLDA = struct

  let init _m = ()

  let sampling _m _d = ()

end


(** [init ?iter k v d] builds a model with [k] topics from vocabulary [v]
    and tokenised documents [d], then assigns every token a topic drawn
    with [Owl_stats.uniform_int_rvs ~a:0 ~b:(k - 1)]. [iter] (default 100)
    is the number of passes [train] will run. *)
(* init the model based on: topics, vocabulary, tokens *)
let init ?(iter=100) k v d =
  Owl_log.info "init the model";
  (* set basic model stats *)
  let n_d = Array.length d in
  let n_v = Hashtbl.length v in
  let n_k = k in
  (* set model hyper-parameters *)
  let alpha = 50. in
  let beta = 0.1 in
  let alpha_k = alpha /. float_of_int n_k in
  let beta_v = float_of_int n_v *. beta in
  (* init model parameters *)
  let t_dk = MD.zeros n_d n_k in
  let t_wk = MS.zeros float64 n_v n_k in
  let t__k = MD.zeros 1 n_k in
  (* set document data and vocabulary *)
  let data = d in
  let vocb = v in
  (* init a partial model *)
  let m = {
    n_d; n_k; n_v;
    alpha; beta; alpha_k; beta_v;
    t_dk; t_wk; t__k;
    t__z = [| |];
    iter; data; vocb;
  }
  in
  (* randomise the topic assignment for each token *)
  m.t__z <- Array.mapi (fun i s ->
    Array.init (Array.length s) (fun j ->
      let k' = Owl_stats.uniform_int_rvs ~a:0 ~b:(k - 1) in
      include_token m s.(j) i k';
      k'
    )
  ) d;
  m

(** [train typ m] runs [m.iter] passes over all documents with the sampler
    selected by [typ], logging timing and statistics after each pass. *)
(* general training function *)
let train typ m =
  let sampling = match typ with
    | SimpleLDA -> SimpleLDA.sampling
    | FTreeLDA  -> FTreeLDA.sampling
    | LightLDA  -> LightLDA.sampling
    | SparseLDA -> SparseLDA.sampling
  in
  let init = match typ with
    | SimpleLDA -> SimpleLDA.init
    | FTreeLDA  -> FTreeLDA.init
    | LightLDA  -> LightLDA.init
    | SparseLDA -> SparseLDA.init
  in
  init m;
  for i = 0 to m.iter - 1 do
    let t0 = Unix.gettimeofday () in
    for j = 0 to m.n_d - 1 do
      (* Owl_log.info "iteration #%i - doc#%i" i j; *)
      sampling m j
    done;
    let t1 = Unix.gettimeofday () in
    show_info m i (t1 -. t0)
  done