/src/owl/nlp/owl_nlp_lda.ml

https://github.com/ryanrhymes/owl · OCaml · 333 lines · 261 code · 30 blank · 42 comment · 17 complexity · 0cc690240168f45319a31d496a0965c1 MD5 · raw file

  1. (*
  2. * OWL - OCaml Scientific and Engineering Computing
  3. * Copyright (c) 2016-2017
  4. * Ben Catterall <bpwc2@cam.ac.uk>
  5. * Liang Wang <liang.wang@cl.cam.ac.uk>
  6. *)
  7. [@@@warning "-6"]
  8. (** NLP: LDA module *)
  (* Selector for the sampler implementation that [train] dispatches on.
     Each constructor has a same-named module further down in this file. *)
  type lda_typ =
    | SimpleLDA (* dense sampler: rebuilds the full per-token CDF over topics *)
    | FTreeLDA (* placeholder: its module currently raises "not implemented" *)
    | LightLDA (* placeholder: its module currently raises "not implemented" *)
    | SparseLDA (* sampler built on cached s / r / q buckets *)
  (* Complete mutable state of one LDA model. Counts are kept as floats so
     they can be mixed directly with the smoothing terms in the samplers. *)
  type model =
    { mutable n_d : int (* number of documents *)
    ; mutable n_k : int (* number of topics *)
    ; mutable n_v : int (* number of entries in the vocabulary *)
    ; mutable alpha : float (* model hyper-parameter *)
    ; mutable beta : float (* model hyper-parameter *)
    ; mutable alpha_k : float (* hyper-parameter; [init] sets it to alpha / n_k *)
    ; mutable beta_v : float (* hyper-parameter; [init] sets it to n_v * beta *)
    ; mutable t_dk : float array array
          (* document-topic table: t_dk.(d).(k) is the number of tokens of
             document d assigned to topic k *)
    ; mutable t_wk : float array array
          (* word-topic table: t_wk.(w).(k) is the number of tokens of word w
             assigned to topic k *)
    ; mutable t__k : float array
          (* tokens assigned to each topic: t__k.(k) = sum_w t_wk = sum_d t_dk *)
    ; mutable t__z : int array array
          (* t__z.(d).(i) is the topic currently assigned to token i of
             document d *)
    ; mutable iter : int (* number of training iterations *)
    ; mutable data : Owl_nlp_corpus.t (* training data, tokenised *)
    ; mutable vocb : (string, int) Hashtbl.t
          (* vocabulary, or dictionary if you prefer *)
    }
  (* Record one more token of word [w] in document [d] under topic [k]:
     the topic total, the word-topic row and the document-topic row are each
     incremented by one, in that order. *)
  let include_token m w d k =
    let bump counts = counts.(k) <- counts.(k) +. 1. in
    bump m.t__k;
    bump m.t_wk.(w);
    bump m.t_dk.(d)
  (* Inverse of [include_token]: withdraw one token of word [w] in document
     [d] from topic [k] by decrementing the topic total, the word-topic row
     and the document-topic row, in that order. *)
  let exclude_token m w d k =
    let drop counts = counts.(k) <- counts.(k) -. 1. in
    drop m.t__k;
    drop m.t_wk.(w);
    drop m.t_dk.(d)
  (* Log one progress line for iteration [i], which took [t] seconds.
     The model argument is ignored; the t_dk / t_wk figures in the message are
     fixed at zero rather than computed from the model. *)
  let show_info _m i t =
    let t_dk_stat = 0.
    and t_wk_stat = 0. in
    Owl_log.info "iter#%i t(s):%.1f t_dk:%.3f t_wk:%.3f" i t t_dk_stat t_wk_stat
  52. (* implement several LDA with specific samplings *)
  module SimpleLDA = struct
    (* Nothing to prepare for the dense sampler. *)
    let init _m = ()

    (* Resample the topic of every token of document [d].
       [doc] holds the word ids of the document's tokens. For each token the
       current assignment is withdrawn, a cumulative weight table over all
       topics is rebuilt, a topic is drawn from it, and the new assignment is
       recorded in the count tables and in [m.t__z]. *)
    let sampling m d doc =
      (* Scratch buffer reused for every token: cdf.(j) is the running sum of
         the unnormalised topic weights up to and including topic j. *)
      let cdf = Array.make m.n_k 0. in
      let resample i w =
        let previous = m.t__z.(d).(i) in
        exclude_token m w d previous;
        (* accumulate (t_dk + alpha_k) * (t_wk + beta) / (t__k + beta_v) *)
        let total = ref 0. in
        for j = 0 to m.n_k - 1 do
          let weight =
            (m.t_dk.(d).(j) +. m.alpha_k)
            *. (m.t_wk.(w).(j) +. m.beta)
            /. (m.t__k.(j) +. m.beta_v)
          in
          total := !total +. weight;
          cdf.(j) <- !total
        done;
        (* draw a point in the total mass and find the first topic whose
           cumulative weight is not below it *)
        let u = Owl_stats.std_uniform_rvs () *. !total in
        let rec first_reaching j = if cdf.(j) < u then first_reaching (j + 1) else j in
        let chosen = first_reaching 0 in
        include_token m w d chosen;
        m.t__z.(d).(i) <- chosen
      in
      Array.iteri resample doc
  end
  module SparseLDA = struct
    (* Sampler that splits the unnormalised topic weight of a token
         (t_dk + alpha_k) * (t_wk + beta) / (t__k + beta_v)
       into three buckets whose sums are cached and updated incrementally:
         s = sum_k alpha_k * beta            / (beta_v + t__k)
         r = sum_k beta * t_dk.(d).(k)       / (beta_v + t__k)
         q = sum_k (alpha_k + t_dk) * t_wk   / (beta_v + t__k)
       The caches below are module-level state shared by every model; [init]
       must be called before [sampling]. *)

    (* Cached value of the s bucket (independent of the document). *)
    let s = ref 0.

    (* Per-topic coefficient (alpha_k + t_dk.(d).(k)) / (beta_v + t__k.(k))
       for the document currently being sampled. *)
    let q = ref [||]

    (* Topics with a non-zero count in the current document, mapped to that
       raw count t_dk.(d).(k). The division by (beta_v + t__k) happens when
       the table is read. *)
    let r_non_zero : (int, float) Hashtbl.t ref = ref (Hashtbl.create 1)

    (* Topics with a non-zero count for the word currently being sampled. *)
    let q_non_zero : (int, bool) Hashtbl.t ref = ref (Hashtbl.create 1)

    (* Apply [update] (either [exclude_token] or [include_token]) for word
       [w], document [d], topic [k], keeping the s / r sums, the q coefficient
       and the r_non_zero table consistent with the changed counts. *)
    let update_token_sparse update m w d k ~s ~r ~q =
      (* take topic k's old contribution out of s and r *)
      let denom_before = m.beta_v +. m.t__k.(k) in
      s := !s -. (m.beta *. m.alpha_k /. denom_before);
      r := !r -. (m.beta *. m.t_dk.(d).(k) /. denom_before);
      update m w d k;
      (* put the contribution back using the updated counts *)
      let denom_after = m.beta_v +. m.t__k.(k) in
      let n_dk = m.t_dk.(d).(k) in
      s := !s +. (m.beta *. m.alpha_k /. denom_after);
      (!q).(k) <- (m.alpha_k +. n_dk) /. denom_after;
      if n_dk = 0.
      then Hashtbl.remove !r_non_zero k
      else (
        Hashtbl.replace !r_non_zero k n_dk;
        r := !r +. (m.beta *. n_dk /. denom_after))

    (* Withdraw the token (w, d) from topic [k] and refresh the caches. *)
    let exclude_token_sparse m w d k ~s ~r ~q =
      update_token_sparse exclude_token m w d k ~s ~r ~q

    (* Assign the token (w, d) to topic [k] and refresh the caches. *)
    let include_token_sparse m w d k ~s ~r ~q =
      update_token_sparse include_token m w d k ~s ~r ~q

    (* Reset the module-level caches for model [m] and compute the initial
       value of s from the current topic totals.
       (reset module parameters, maybe wrap into model?) *)
    let init m =
      q := Array.make m.n_k 0.;
      r_non_zero := Hashtbl.create m.n_k;
      q_non_zero := Hashtbl.create m.n_k;
      (* s is independent of document *)
      let inv_sum = ref 0. in
      for k = 0 to m.n_k - 1 do
        inv_sum := !inv_sum +. (1. /. (m.beta_v +. m.t__k.(k)))
      done;
      s := !inv_sum *. (m.alpha_k *. m.beta)

    (* Resample the topic of every token of document [d]; [doc] holds the
       word ids of its tokens. The r bucket and the q coefficients are rebuilt
       once per document, the q bucket once per token. *)
    let sampling m d doc =
      let r = ref 0. in
      (* Rebuild r, its sparse table and the q coefficients for this document. *)
      Hashtbl.clear !r_non_zero;
      for k = 0 to m.n_k - 1 do
        let denom = m.beta_v +. m.t__k.(k) in
        let n_dk = m.t_dk.(d).(k) in
        (* Structural comparison: [!=] is physical inequality and does not
           test the float's value. *)
        if n_dk <> 0.
        then (
          r := !r +. (n_dk /. denom);
          (* Store the raw count, matching what [update_token_sparse] stores
             and what the r-bucket scan below expects. *)
          Hashtbl.replace !r_non_zero k n_dk);
        (* TODO: efficiently handle t_dk = 0 *)
        (!q).(k) <- (m.alpha_k +. n_dk) /. denom
      done;
      r := !r *. m.beta;
      (* Process the document *)
      Array.iteri
        (fun i w ->
          let previous = m.t__z.(d).(i) in
          exclude_token_sparse m w d previous ~s ~r ~q;
          (* Rebuild the q bucket for this word.
             This scan makes the step O(K) rather than O(K_d + K_w). *)
          let qsum = ref 0. in
          Hashtbl.clear !q_non_zero;
          for j = 0 to m.n_k - 1 do
            let n_wk = m.t_wk.(w).(j) in
            if n_wk <> 0.
            then (
              qsum := !qsum +. ((!q).(j) *. n_wk);
              Hashtbl.replace !q_non_zero j true)
          done;
          let u = Owl_stats.std_uniform_rvs () *. (!s +. !r +. !qsum) in
          let k = ref 0 in
          (* Work out which bucket the draw falls in *)
          if u < !s
          then (
            (* Scale u so it can be compared against sum_k 1 / (beta_v + t__k). *)
            let target = u /. (m.alpha_k *. m.beta) in
            let acc = ref 0. in
            let j = ref 0 in
            (* The bound on j protects against rounding drift in the cached s. *)
            while !acc < target && !j < m.n_k do
              acc := !acc +. (1. /. (m.beta_v +. m.t__k.(!j)));
              incr j
            done;
            (* j is one past the chosen topic; clamp for the case u = 0 in
               which the loop body never runs. *)
            k := max 0 (!j - 1))
          else if u < !s +. !r
          then (
            (* Compare against r without the common beta factor. *)
            let target = (u -. !s) /. m.beta in
            let acc = ref 0. in
            (* TODO: pick largest (order by decreasing) for efficiency *)
            Hashtbl.iter
              (fun key n_dk ->
                if !acc < target
                then (
                  acc := !acc +. (n_dk /. (m.beta_v +. m.t__k.(key)));
                  k := key))
              !r_non_zero)
          else (
            let target = u -. (!s +. !r) in
            let acc = ref 0. in
            (* Iterate over the topics with a non-zero count for this word *)
            (* TODO: make descending *)
            Hashtbl.iter
              (fun key _ ->
                if !acc < target
                then (
                  acc := !acc +. ((!q).(key) *. m.t_wk.(w).(key));
                  k := key))
              !q_non_zero);
          include_token_sparse m w d !k ~s ~r ~q;
          m.t__z.(d).(i) <- !k)
        doc
  end
  module FTreeLDA = struct
    (* Placeholder sampler: both entry points raise [Failure]. *)
    let init _m = raise (Failure "FTreeLDA: not implemented")

    let sampling _m _d _doc = raise (Failure "FTreeLDA: not implemented")
  end
  module LightLDA = struct
    (* Placeholder sampler: both entry points raise [Failure]. *)
    let init _m = raise (Failure "LightLDA: not implemented")

    let sampling _m _d _doc = raise (Failure "LightLDA: not implemented")
  end
  228. (* init the model based on: topics, vocabulary, tokens *)
  (* Build a model with [k] topics for corpus [d].
     [iter] (default 100) is stored as the number of training iterations.
     The count tables start at zero; every token is then given a uniformly
     drawn topic in [0, k - 1], which is recorded in the tables and in
     [t__z]. Returns the populated model. *)
  let init ?(iter = 100) k d =
    let w2i = Owl_nlp_corpus.get_vocab d |> Owl_nlp_vocabulary.get_w2i in
    Owl_log.info "init the model";
    (* basic sizes *)
    let n_docs = Owl_nlp_corpus.length d in
    let n_words = Hashtbl.length w2i in
    (* hyper-parameters *)
    let alpha = 50.
    and beta = 0.1 in
    let alpha_k = alpha /. float_of_int k in
    let beta_v = float_of_int n_words *. beta in
    (* zero-filled count tables with one row of k counters per entry *)
    let zero_table rows = Array.init rows (fun _ -> Array.make k 0.) in
    let t_dk = zero_table n_docs in
    let t_wk = zero_table n_words in
    let t__k = Array.make k 0. in
    (* the assignment table is filled in below, once the record exists *)
    let m =
      { n_d = n_docs
      ; n_k = k
      ; n_v = n_words
      ; alpha
      ; beta
      ; alpha_k
      ; beta_v
      ; t_dk
      ; t_wk
      ; t__k
      ; t__z = [||]
      ; iter
      ; data = d
      ; vocb = w2i
      }
    in
    (* randomise the topic assignment for each token *)
    let assign_document doc_id tokens =
      Array.map
        (fun w ->
          let topic = Owl_stats.uniform_int_rvs ~a:0 ~b:(k - 1) in
          include_token m w doc_id topic;
          topic)
        tokens
    in
    m.t__z <- Owl_nlp_corpus.mapi_tok assign_document d;
    m
  277. (* general training function *)
  (* Train model [m] in place with the sampler selected by [typ].
     The sampler's [init] runs once, then every document of [m.data] is
     resampled [m.iter] times; one progress line is logged per iteration with
     the wall-clock time it took. *)
  let train typ m =
    let prepare, sample =
      match typ with
      | SimpleLDA -> SimpleLDA.init, SimpleLDA.sampling
      | FTreeLDA -> FTreeLDA.init, FTreeLDA.sampling
      | LightLDA -> LightLDA.init, LightLDA.sampling
      | SparseLDA -> SparseLDA.init, SparseLDA.sampling
    in
    prepare m;
    for round = 0 to m.iter - 1 do
      let started = Unix.gettimeofday () in
      Owl_nlp_corpus.iteri_tok (fun doc_id doc -> sample m doc_id doc) m.data;
      let finished = Unix.gettimeofday () in
      show_info m round (finished -. started)
    done