1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
module Log_space = Dagger.Log_space
(** Signature of the GSL bindings consumed by the [Make] functor below: a
    random number generator module ([Rng]) and a family of samplers with
    their companion density functions ([Randist]). *)
module type GSL_SIG = sig
(** Random number generators. *)
module Rng : sig
(** Kind of generator; a value of this type is handed to [make]. *)
type rng_type
(** A generator instance; first argument of every sampler in [Randist]. *)
type t
(** Generator kind used by [Make] to initialise its [gsl_rng] reference. *)
val default : unit -> rng_type
(** Instantiates a generator of the given kind. *)
val make : rng_type -> t
(** [Make] calls this with bits drawn from a [Random.State.t], i.e. uses it
    to seed the generator. *)
val set : t -> nativeint -> unit
(** [Make.int] treats the result of [uniform_int g b] as lying in
    [0 <= x < b]. *)
val uniform_int : t -> int -> int
end
(** Samplers and densities. Each sampler [foo] takes the generator first and
    its parameters as labelled arguments; the matching [foo_pdf] takes the
    point at which to evaluate first. [Make] converts every [*_pdf] result
    with [Log_space.of_float], whereas [dirichlet_lnpdf] is injected with
    [Log_space.unsafe_cast]. *)
module Randist : sig
val flat : Rng.t -> a:float -> b:float -> float
val flat_pdf : float -> a:float -> b:float -> float
(* [Make] interprets the returned int as a boolean by comparing it to 1. *)
val bernoulli : Rng.t -> p:float -> int
val bernoulli_pdf : int -> p:float -> float
(* No mean parameter here: [Make.gaussian] adds the mean itself. *)
val gaussian : Rng.t -> sigma:float -> float
val gaussian_pdf : float -> sigma:float -> float
val gaussian_tail : Rng.t -> a:float -> sigma:float -> float
val gaussian_tail_pdf : float -> a:float -> sigma:float -> float
val laplace : Rng.t -> a:float -> float
val laplace_pdf : float -> a:float -> float
val exppow : Rng.t -> a:float -> b:float -> float
val exppow_pdf : float -> a:float -> b:float -> float
val cauchy : Rng.t -> a:float -> float
val cauchy_pdf : float -> a:float -> float
val rayleigh : Rng.t -> sigma:float -> float
val rayleigh_pdf : float -> sigma:float -> float
val rayleigh_tail : Rng.t -> a:float -> sigma:float -> float
val rayleigh_tail_pdf : float -> a:float -> sigma:float -> float
val landau : Rng.t -> float
val landau_pdf : float -> float
val gamma : Rng.t -> a:float -> b:float -> float
val gamma_pdf : float -> a:float -> b:float -> float
val weibull : Rng.t -> a:float -> b:float -> float
val weibull_pdf : float -> a:float -> b:float -> float
val binomial : Rng.t -> p:float -> n:int -> int
val binomial_pdf : int -> p:float -> n:int -> float
val geometric : Rng.t -> p:float -> int
val geometric_pdf : int -> p:float -> float
val exponential : Rng.t -> mu:float -> float
val exponential_pdf : float -> mu:float -> float
val poisson : Rng.t -> mu:float -> int
val poisson_pdf : int -> mu:float -> float
(** Preprocessed table built from an array of weights by [discrete_preproc];
    [discrete] samples an index from it and [discrete_pdf] gives the
    probability of an index. *)
type discrete
val discrete_preproc : float array -> discrete
val discrete : Rng.t -> discrete -> int
val discrete_pdf : int -> discrete -> float
val beta : Rng.t -> a:float -> b:float -> float
val beta_pdf : float -> a:float -> b:float -> float
(* The sample is written into the caller-supplied [theta] array. *)
val dirichlet : Rng.t -> alpha:float array -> theta:float array -> unit
val dirichlet_pdf : alpha:float array -> theta:float array -> float
val dirichlet_lnpdf : alpha:float array -> theta:float array -> float
val lognormal : Rng.t -> zeta:float -> sigma:float -> float
val lognormal_pdf : float -> zeta:float -> sigma:float -> float
val chisq : Rng.t -> nu:float -> float
val chisq_pdf : float -> nu:float -> float
end
end
module Make (Gsl : GSL_SIG) = struct
open Gsl
(** Kind of GSL generator instantiated by [rng]. Held in a reference so it
    can be replaced after the functor has been applied. *)
let gsl_rng = ref (Rng.default ())

(** [rng s] builds a fresh GSL generator of kind [!gsl_rng] and seeds it with
    native bits drawn from the OCaml random state [s]. A new generator is
    created on every call. *)
let rng (s : Random.State.t) =
  let generator = Rng.make !gsl_rng in
  Rng.set generator (Random.State.nativebits s) ;
  generator
open Dagger.Dist
(** [dist0 sampler log_pdf] wraps a parameterless GSL sampler into a stateless
    distribution. The sampler receives a GSL generator derived from the OCaml
    random state via [rng]; [log_pdf] is used unchanged. *)
let dist0 sampler log_pdf =
  let draw random_state = sampler (rng random_state) in
  stateless draw log_pdf
[@@inline]
(** [dist1 sampler log_pdf arg] is the one-parameter analogue of [dist0]:
    both [sampler] and [log_pdf] receive [arg] as their first argument. *)
let dist1 sampler log_pdf arg =
  let draw random_state = sampler arg (rng random_state) in
  let score x = log_pdf arg x in
  stateless draw score
[@@inline]
(** [dist2 sampler log_pdf arg1 arg2] is the two-parameter analogue of
    [dist0]: both [sampler] and [log_pdf] receive [arg1] then [arg2] before
    their own argument. *)
let dist2 sampler log_pdf arg1 arg2 =
  let draw random_state = sampler arg1 arg2 (rng random_state) in
  let score x = log_pdf arg1 arg2 x in
  stateless draw score
[@@inline]
(** [kernel1 sampler log_pdf start arg] builds a kernel starting at [start].
    The transition sampler is called as [sampler arg x gsl_rng] and the
    transition log-density as [log_pdf arg x y]. *)
let kernel1 sampler log_pdf start arg =
  let step x random_state = sampler arg x (rng random_state) in
  let transition_score x y = log_pdf arg x y in
  kernel start step transition_score
[@@inline]
(** [float bound] is the flat distribution with lower end [0.0] and upper end
    [bound], sampled with [Randist.flat] and scored with
    [Randist.flat_pdf]. *)
let float bound =
  let sample upper gsl_state = Randist.flat gsl_state ~a:0.0 ~b:upper in
  let log_pdf upper x =
    Log_space.of_float (Randist.flat_pdf x ~a:0.0 ~b:upper)
  in
  dist1 sample log_pdf bound
(** [int bound] is a uniform distribution on integers sampled with
    [Rng.uniform_int]. Its density is [1 / bound] for [0 <= x < bound] and
    zero elsewhere. *)
let int bound =
  let sample limit gsl_state = Rng.uniform_int gsl_state limit in
  let log_pdf limit x =
    if 0 <= x && x < limit then Log_space.of_float (1. /. float_of_int limit)
    else Log_space.zero
  in
  dist1 sample log_pdf bound
(** Fair coin: each of [true] and [false] has probability [0.5].

    [Randist.bernoulli] returns an [int]; it is compared to [1] so that this
    distribution produces booleans, in the same way as [bernoulli]. *)
let bool =
  let ll = Log_space.of_float 0.5 in
  dist0 (fun state -> Randist.bernoulli state ~p:0.5 = 1) (fun _ -> ll)
(** [gaussian ~mean ~std] is the Gaussian distribution with the given mean and
    standard deviation.

    [Randist.gaussian] and [Randist.gaussian_pdf] take no mean parameter, so
    the sampler shifts the draw by [mean] and the density is evaluated at the
    recentred point [x -. mean]. *)
let gaussian ~mean ~std =
  dist2
    (fun mean sigma state -> mean +. Randist.gaussian state ~sigma)
    (fun mean sigma x ->
      (* Undo the shift applied by the sampler before scoring. *)
      Log_space.of_float (Randist.gaussian_pdf (x -. mean) ~sigma))
    mean
    std
(** [gaussian_tail ~a ~std] wraps [Randist.gaussian_tail] and
    [Randist.gaussian_tail_pdf]; [a] is forwarded as [~a] and [std] as
    [~sigma]. *)
let gaussian_tail ~a ~std =
  let sample lower sigma gsl_state =
    Randist.gaussian_tail gsl_state ~a:lower ~sigma
  in
  let log_pdf lower sigma x =
    Log_space.of_float (Randist.gaussian_tail_pdf x ~a:lower ~sigma)
  in
  dist2 sample log_pdf a std
(** [laplace ~a] wraps [Randist.laplace] and [Randist.laplace_pdf], forwarding
    [a] unchanged. *)
let laplace ~a =
  let sample width gsl_state = Randist.laplace gsl_state ~a:width in
  let log_pdf width x = Log_space.of_float (Randist.laplace_pdf x ~a:width) in
  dist1 sample log_pdf a
(** [exppow ~a ~b] wraps [Randist.exppow] and [Randist.exppow_pdf], forwarding
    both parameters unchanged. *)
let exppow ~a ~b =
  let sample p q gsl_state = Randist.exppow gsl_state ~a:p ~b:q in
  let log_pdf p q x = Log_space.of_float (Randist.exppow_pdf x ~a:p ~b:q) in
  dist2 sample log_pdf a b
(** [cauchy ~a] wraps [Randist.cauchy] and [Randist.cauchy_pdf], forwarding
    [a] unchanged. *)
let cauchy ~a =
  let sample scale gsl_state = Randist.cauchy gsl_state ~a:scale in
  let log_pdf scale x = Log_space.of_float (Randist.cauchy_pdf x ~a:scale) in
  dist1 sample log_pdf a
(** [rayleigh ~sigma] wraps [Randist.rayleigh] and [Randist.rayleigh_pdf],
    forwarding [sigma] unchanged. *)
let rayleigh ~sigma =
  let sample s gsl_state = Randist.rayleigh gsl_state ~sigma:s in
  let log_pdf s x = Log_space.of_float (Randist.rayleigh_pdf x ~sigma:s) in
  dist1 sample log_pdf sigma
(** [rayleigh_tail ~a ~sigma] wraps [Randist.rayleigh_tail] and
    [Randist.rayleigh_tail_pdf], forwarding both parameters unchanged. *)
let rayleigh_tail ~a ~sigma =
  let sample lower s gsl_state =
    Randist.rayleigh_tail gsl_state ~a:lower ~sigma:s
  in
  let log_pdf lower s x =
    Log_space.of_float (Randist.rayleigh_tail_pdf x ~a:lower ~sigma:s)
  in
  dist2 sample log_pdf a sigma
(** Parameterless distribution wrapping [Randist.landau] and
    [Randist.landau_pdf]. *)
let landau =
  let log_pdf x = Log_space.of_float (Randist.landau_pdf x) in
  dist0 Randist.landau log_pdf
(** [gamma ~a ~b] wraps [Randist.gamma] and [Randist.gamma_pdf], forwarding
    both parameters unchanged. *)
let gamma ~a ~b =
  let sample p q gsl_state = Randist.gamma gsl_state ~a:p ~b:q in
  let log_pdf p q x = Log_space.of_float (Randist.gamma_pdf x ~a:p ~b:q) in
  dist2 sample log_pdf a b
(** [weibull ~a ~b] wraps [Randist.weibull] and [Randist.weibull_pdf],
    forwarding both parameters unchanged. *)
let weibull ~a ~b =
  let sample p q gsl_state = Randist.weibull gsl_state ~a:p ~b:q in
  let log_pdf p q x = Log_space.of_float (Randist.weibull_pdf x ~a:p ~b:q) in
  dist2 sample log_pdf a b
(** [flat a b] wraps [Randist.flat] and [Randist.flat_pdf] with bounds [a]
    and [b] (positional arguments). *)
let flat a b =
  let sample lo hi gsl_state = Randist.flat gsl_state ~a:lo ~b:hi in
  let log_pdf lo hi x = Log_space.of_float (Randist.flat_pdf x ~a:lo ~b:hi) in
  dist2 sample log_pdf a b
(** [bernoulli ~bias] is a boolean distribution: [true] has probability
    [bias] and [false] has probability [1 - bias]. The integer returned by
    [Randist.bernoulli] is mapped to [true] exactly when it equals [1]. *)
let bernoulli ~bias =
  let sample p gsl_state = Randist.bernoulli gsl_state ~p = 1 in
  let log_pdf p outcome =
    match outcome with
    | true -> Log_space.of_float p
    | false -> Log_space.of_float (1. -. p)
  in
  dist1 sample log_pdf bias
(** [binomial p n] wraps [Randist.binomial] and [Randist.binomial_pdf]; [p]
    and [n] are positional here and forwarded as [~p] and [~n]. *)
let binomial p n =
  let sample prob trials gsl_state =
    Randist.binomial gsl_state ~p:prob ~n:trials
  in
  let log_pdf prob trials k =
    Log_space.of_float (Randist.binomial_pdf k ~p:prob ~n:trials)
  in
  dist2 sample log_pdf p n
(** [geometric ~p] wraps [Randist.geometric] and [Randist.geometric_pdf],
    forwarding [p] unchanged. *)
let geometric ~p =
  let sample prob gsl_state = Randist.geometric gsl_state ~p:prob in
  let log_pdf prob k = Log_space.of_float (Randist.geometric_pdf k ~p:prob) in
  dist1 sample log_pdf p
(** [exponential ~rate] is the exponential distribution with the given rate.

    GSL parameterises the exponential distribution by its mean [mu], with
    density [(1/mu) exp (-x/mu)], so the rate is converted to a mean
    ([1 / rate]) before calling [Randist.exponential] and
    [Randist.exponential_pdf]. *)
let exponential ~rate =
  dist1
    (fun rate state -> Randist.exponential state ~mu:(1. /. rate))
    (fun rate x ->
      Log_space.of_float (Randist.exponential_pdf x ~mu:(1. /. rate)))
    rate
(** [poisson ~rate] wraps [Randist.poisson] and [Randist.poisson_pdf]; [rate]
    is forwarded as the [~mu] argument. *)
let poisson ~rate =
  let sample mean gsl_state = Randist.poisson gsl_state ~mu:mean in
  let log_pdf mean k = Log_space.of_float (Randist.poisson_pdf k ~mu:mean) in
  dist1 sample log_pdf rate
(** [categorical (module H) cases] is the distribution over the first
    components of [cases], weighted by the second components.

    [H] provides hashing and equality on the elements; it is used to map an
    element back to its index in [cases] when evaluating the density. The
    weights are handed to [Randist.discrete_preproc].

    The density of an element that does not appear in [cases] is
    [Log_space.zero]. *)
let categorical (type a) (module H : Hashtbl.S with type key = a)
    (cases : (a * float) array) =
  let xs = Array.map fst cases in
  let ps = Array.map snd cases in
  (* Reverse lookup: element -> index into [xs] / [ps]. *)
  let table = H.of_seq (Array.to_seq (Array.mapi (fun i x -> (x, i)) xs)) in
  let sampler = Randist.discrete_preproc ps in
  dist0
    (fun state -> xs.(Randist.discrete state sampler))
    (fun elt ->
      match H.find_opt table elt with
      | None ->
          (* Outside the support: probability zero rather than a crash. *)
          Log_space.zero
      | Some i -> Log_space.of_float (Randist.discrete_pdf i sampler))
(** [beta ~a ~b] wraps [Randist.beta] and [Randist.beta_pdf], forwarding both
    parameters unchanged. *)
let beta ~a ~b =
  let sample p q gsl_state = Randist.beta gsl_state ~a:p ~b:q in
  let log_pdf p q x = Log_space.of_float (Randist.beta_pdf x ~a:p ~b:q) in
  dist2 sample log_pdf a b
(** [dirichlet ~alpha] wraps [Randist.dirichlet]. Each draw allocates a fresh
    array of the same length as [alpha], which GSL fills in place.

    The density uses [Randist.dirichlet_lnpdf], whose result is injected with
    [Log_space.unsafe_cast] instead of [Log_space.of_float] (presumably
    because it is already a logarithm). *)
let dirichlet ~alpha =
  let sample concentration gsl_state =
    let draw = Array.make (Array.length concentration) 0.0 in
    Randist.dirichlet gsl_state ~alpha:concentration ~theta:draw ;
    draw
  in
  let log_pdf concentration point =
    Log_space.unsafe_cast
      (Randist.dirichlet_lnpdf ~alpha:concentration ~theta:point)
  in
  dist1 sample log_pdf alpha
(** [lognormal ~zeta ~sigma] wraps [Randist.lognormal] and
    [Randist.lognormal_pdf], forwarding both parameters unchanged. *)
let lognormal ~zeta ~sigma =
  let sample z s gsl_state = Randist.lognormal gsl_state ~zeta:z ~sigma:s in
  let log_pdf z s x =
    Log_space.of_float (Randist.lognormal_pdf x ~zeta:z ~sigma:s)
  in
  dist2 sample log_pdf zeta sigma
(** [chi_squared ~nu] wraps [Randist.chisq] and [Randist.chisq_pdf],
    forwarding [nu] unchanged. *)
let chi_squared ~nu =
  let sample dof gsl_state = Randist.chisq gsl_state ~nu:dof in
  let log_pdf dof x = Log_space.of_float (Randist.chisq_pdf x ~nu:dof) in
  dist1 sample log_pdf nu
(** [mixture coeffs dists] is the mixture of the stateless distributions
    [dists], where component [i] is selected with weight [coeffs.(i)].

    Sampling picks a component index with [Randist.discrete] and then samples
    from that component. The density at [x] is the weighted sum
    [sum_i coeffs.(i) * pdf_i x].

    NOTE(review): the density uses [coeffs] as given; they are presumably
    expected to sum to one — confirm against callers.

    @raise Invalid_argument
      if some element of [dists] is a kernel, if [coeffs] and [dists] have
      different lengths, or if they are empty. *)
let mixture coeffs (dists : 'a t array) =
  let components =
    Array.map
      (function Stateless d -> d | Kernel _ -> invalid_arg "mixture")
      dists
  in
  let count = Array.length components in
  if Array.length coeffs <> count then invalid_arg "mixture" ;
  if count = 0 then invalid_arg "mixture" ;
  let table = Randist.discrete_preproc coeffs in
  let log_pdf x =
    (* A mixture density is a sum over components, which cannot be expressed
       with [Log_space.mul]; accumulate in linear space and convert once. *)
    let total = ref 0.0 in
    for i = 0 to count - 1 do
      total := !total +. (coeffs.(i) *. Log_space.to_float (components.(i).ll x))
    done ;
    Log_space.of_float !total
  in
  let sampler rng_state =
    let case = Randist.discrete (rng rng_state) table in
    components.(case).sample rng_state
  in
  stateless sampler log_pdf
end