Source file owl_stats_sampler.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# 1 "src/owl/stats/owl_stats_sampler.ml"
(** Outcome of a sampler run. *)
type 'a t = {
  samples : 'a array;  (** values produced by the sampler *)
  accept : float;  (** fraction of proposals that were accepted *)
}
(** [rejection ~m ~proprvs ~proppdf ~pdf nsamples] draws [nsamples] float
    values by rejection sampling.

    Each round draws a candidate [x] from [proprvs ()] and accepts it with
    probability [pdf x /. (m *. proppdf x)]; the comparison is done in log
    space against [log u] for [u] drawn from a standard uniform. The
    envelope condition [pdf x <= m *. proppdf x] must hold for every
    candidate. Equality is allowed: it simply means the candidate is
    accepted with probability 1 (e.g. when the proposal equals the target
    and [m = 1.]).

    @param m envelope constant; must be strictly positive.
    @param proprvs draws a candidate (must return a [float]).
    @param proppdf density of the proposal distribution.
    @param pdf (possibly unnormalised) target density.
    @return the accepted samples together with the fraction of proposals
      that were accepted ([0.] when no proposal was made, i.e.
      [nsamples = 0]).
    @raise Assert_failure if [m <= 0.], or if a candidate violates the
      envelope (this includes a NaN log-ratio, e.g. when both densities
      are [0.] at the candidate). *)
let rejection ~m ~proprvs ~proppdf ~pdf nsamples =
  assert (m > 0.);
  let log_m = log m in
  let accepted = ref 0 in
  let total = ref 0 in
  let samples = Array.make nsamples 0. in
  while !accepted < nsamples do
    let x = proprvs () in
    let log_u = log (Owl_stats_dist.std_uniform_rvs ()) in
    let log_ratio = log (pdf x) -. log (proppdf x) -. log_m in
    (* Envelope check. [log_ratio = 0.] is legitimate (a tight envelope);
       only a strictly positive ratio (or NaN) indicates a bad [m]. *)
    assert (log_ratio <= 0.);
    if log_u < log_ratio
    then (
      samples.(!accepted) <- x;
      incr accepted);
    incr total
  done;
  (* Avoid 0/0 = nan when no proposal was made. *)
  let accept =
    if !total = 0 then 0. else float_of_int !accepted /. float_of_int !total
  in
  { samples; accept }
(* Placeholder bound to [None]; presumably reserved for adaptive rejection
   sampling (ARS), which is not implemented here. *)
let ars : 'a option = None

(* Placeholder bound to [None]; presumably reserved for adaptive rejection
   Metropolis sampling (ARMS), which is not implemented here. *)
let arms : 'a option = None
(** [metropolis ?burnin ?thin ~initial ~proprvs ~proppdf ~pdf nsamples] runs
    a Metropolis-Hastings chain starting from [initial] and returns the
    retained states.

    The chain holds [burnin + thin * nsamples] states, with state 0 being a
    copy of [initial]. States with index [i >= burnin] and
    [i mod thin = 0] are kept. For positive [thin] and non-negative
    [burnin], that window contains exactly [nsamples] such indices.

    At each step a candidate [x'] is drawn with [proprvs x] and accepted
    with probability
    [min 1 (pdf x' *. proppdf x x' /. (pdf x *. proppdf x' x))], computed
    in log space. On rejection the current state is repeated as a fresh
    copy, so no two stored states share the same array.

    @param burnin number of leading states to discard (default [1000]).
    @param thin keep every [thin]-th state after burn-in (default [10]).
    @return the kept states and the fraction of the [niter - 1] proposals
      that were accepted ([0.] when the chain makes no proposal). *)
let metropolis ?burnin ?thin ~initial ~proprvs ~proppdf ~pdf nsamples =
  let burnin =
    match burnin with
    | Some a -> a
    | None -> 1000
  in
  let thin =
    match thin with
    | Some a -> a
    | None -> 10
  in
  let niter = burnin + (thin * nsamples) in
  if niter = 0
  then (* nothing to run: avoid writing to index 0 of an empty array *)
    { samples = [||]; accept = 0. }
  else (
    let accepted = ref 0 in
    let samples = Array.make niter [||] in
    samples.(0) <- Array.copy initial;
    (* Log density of the current state, cached so [pdf] is evaluated once
       per proposal instead of twice. *)
    let log_p = ref (log (pdf samples.(0))) in
    for i = 1 to niter - 1 do
      let x = samples.(i - 1) in
      let x' = proprvs x in
      let log_p' = log (pdf x') in
      let log_g = log (proppdf x x') in
      let log_g' = log (proppdf x' x) in
      let a = min 0. (log_p' -. !log_p +. log_g -. log_g') in
      (* The uniform draw only happens when [a < 0.], as before. *)
      let take = a >= 0. || log (Owl_stats_dist.std_uniform_rvs ()) < a in
      if take
      then (
        incr accepted;
        log_p := log_p';
        samples.(i) <- x')
      else samples.(i) <- Array.copy x
    done;
    let samples =
      Owl_utils.Array.filteri (fun i _ -> i >= burnin && i mod thin = 0) samples
    in
    (* State 0 is the initial point, so only [niter - 1] proposals were made. *)
    let proposals = niter - 1 in
    let accept =
      if proposals = 0
      then 0.
      else float_of_int !accepted /. float_of_int proposals
    in
    { samples; accept })
(* The bindings below are placeholders bound to [None]; the corresponding
   samplers (Gibbs, slice, adaptive rejection) are presumably not
   implemented yet. *)
let gibbs : 'a option = None

let slice : 'a option = None

let adaptive_rejection : 'a option = None
(** [to_string t] renders a one-line summary of [t]: how many samples it
    holds and its acceptance ratio (formatted with [%g]). *)
let to_string { samples; accept } =
  let count = Array.length samples in
  Printf.sprintf "[ samples: %i; accept_ratio: %g ]" count accept
(** [pp_t formatter t] prints [to_string t] on [formatter], wrapped in a
    pretty-printing box with offset 0. *)
let pp_t formatter t =
  (* Use the [pp_] variants so the box is opened and closed on [formatter]
     itself; [Format.open_box]/[Format.close_box] would act on
     [Format.std_formatter] and leave it unbalanced. *)
  Format.pp_open_box formatter 0;
  Format.fprintf formatter "%s" (to_string t);
  Format.pp_close_box formatter ()