Source file owl_stats_sampler.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# 1 "src/owl/stats/owl_stats_sampler.ml"
(** Result returned by the samplers in this module. *)
type 'a t = {
(* The collected samples. *)
samples : 'a array;
(* Acceptance ratio computed by the sampler that built this record. *)
accept : float;
}
(** [rejection ~m ~proprvs ~proppdf ~pdf nsamples] collects [nsamples] values
    by rejection sampling.

    Each iteration draws a candidate [x] with [proprvs ()] and a uniform
    variate [u] with [Owl_stats_dist.std_uniform_rvs ()]; the candidate is
    kept when [log u < log (pdf x) -. log (proppdf x) -. log m]. The loop only
    stops once [nsamples] candidates have been kept, so it does not terminate
    if no candidate is ever accepted.

    @param m envelope constant; must be strictly positive.
    @param proprvs draws one candidate from the proposal.
    @param proppdf density of the proposal, evaluated at the candidate.
    @param pdf target density, evaluated at the candidate.
    @param nsamples number of samples to collect.
    @return the kept samples together with the ratio of kept candidates to
      drawn candidates ([0.] when no candidate was drawn, i.e. when
      [nsamples = 0]).
    @raise Assert_failure if [m <= 0.], or if a candidate violates the
      envelope condition [pdf x <= m *. proppdf x].
    @raise Invalid_argument if [nsamples] is negative (from [Array.make]). *)
let rejection ~m ~proprvs ~proppdf ~pdf nsamples =
  assert (m > 0.);
  let log_m = log m in
  let accept = ref 0 in
  let total = ref 0 in
  let samples = Array.make nsamples 0. in
  while !accept < nsamples do
    let x = proprvs () in
    let a = log (Owl_stats_dist.std_uniform_rvs ()) in
    let b = log (pdf x) -. log (proppdf x) -. log_m in
    (* The envelope may touch the target: equality is a valid configuration
       (for instance a proposal equal to the target with [m = 1.]), so only a
       strictly positive log-ratio is an error. *)
    assert (b <= 0.);
    if a < b then begin
      samples.(!accept) <- x;
      incr accept
    end;
    incr total
  done;
  (* Avoid reporting [nan] (0 / 0) when nothing had to be drawn. *)
  let accept =
    if !total = 0 then 0.
    else float_of_int !accept /. float_of_int !total
  in
  { samples; accept }
(* Placeholder bound to [None]; no sampler is implemented under this name in
   the visible code.
   NOTE(review): the name suggests adaptive rejection sampling - confirm. *)
let ars = None
(* Placeholder bound to [None]; no sampler is implemented under this name in
   the visible code.
   NOTE(review): presumably adaptive rejection Metropolis sampling - confirm. *)
let arms = None
(** [metropolis ?burnin ?thin ~initial ~proprvs ~proppdf ~pdf nsamples] runs a
    chain of [burnin + thin * nsamples] states starting from a copy of
    [initial] and returns the states whose index [i] satisfies
    [i >= burnin && i mod thin = 0], which yields [nsamples] states.

    At each step a candidate [x'] is drawn with [proprvs x] from the current
    state [x]. The log acceptance value is
    [min 0. (log (pdf x') -. log (pdf x)
             +. log (proppdf x x') -. log (proppdf x' x))];
    the candidate is accepted when that value is [>= 0.] or when the log of a
    draw from [Owl_stats_dist.std_uniform_rvs ()] is smaller than it.
    Otherwise a copy of the current state is stored again.

    @param burnin number of leading chain indices never returned
      (default [1000]).
    @param thin keep only chain indices that are multiples of [thin]
      (default [10]).
    @param initial starting state; it is copied, not stored.
    @param proprvs draws a candidate given the current state.
    @param proppdf proposal density, called as [proppdf x x'] and
      [proppdf x' x] (NOTE(review): which argument is the conditioning state
      is not visible here - confirm against callers).
    @param pdf target density.
    @param nsamples number of states to return.
    @return the retained states and the fraction of proposals that were
      accepted ([0.] when no proposal was made).
    @raise Invalid_argument if [burnin < 0], [thin < 1] or [nsamples < 0]. *)
let metropolis ?(burnin = 1000) ?(thin = 10) ~initial ~proprvs ~proppdf ~pdf
    nsamples =
  if burnin < 0 then
    invalid_arg "Owl_stats_sampler.metropolis: burnin must be >= 0";
  (* A non-positive [thin] would make the chain length and the index filter
     disagree and silently return the wrong number of samples. *)
  if thin < 1 then
    invalid_arg "Owl_stats_sampler.metropolis: thin must be >= 1";
  if nsamples < 0 then
    invalid_arg "Owl_stats_sampler.metropolis: nsamples must be >= 0";
  let niter = burnin + (thin * nsamples) in
  if niter = 0 then
    (* Nothing to simulate: there is not even a slot for the initial state. *)
    { samples = [||]; accept = 0. }
  else begin
    let accepted = ref 0 in
    let chain = Array.make niter [||] in
    chain.(0) <- Array.copy initial;
    for i = 1 to niter - 1 do
      let x = chain.(i - 1) in
      let x' = proprvs x in
      let p = log (pdf x) in
      let p' = log (pdf x') in
      let g = log (proppdf x x') in
      let g' = log (proppdf x' x) in
      let a = min 0. (p' -. p +. g -. g') in
      (* The uniform variate is only drawn when acceptance is not certain. *)
      let ok = a >= 0. || log (Owl_stats_dist.std_uniform_rvs ()) < a in
      if ok then incr accepted;
      (* On rejection store a fresh copy so chain entries never alias. *)
      chain.(i) <- (if ok then x' else Array.copy x)
    done;
    let samples =
      Owl_utils.Array.filteri (fun i _ -> i >= burnin && i mod thin = 0) chain
    in
    (* Index 0 holds the initial state, so only [niter - 1] proposals exist. *)
    let nproposals = niter - 1 in
    let accept =
      if nproposals = 0 then 0.
      else float_of_int !accepted /. float_of_int nproposals
    in
    { samples; accept }
  end
(* The three bindings below are placeholders bound to [None]; no sampler is
   implemented under these names in the visible code.
   NOTE(review): the names suggest Gibbs sampling, slice sampling and adaptive
   rejection sampling respectively - confirm intent. *)
let gibbs = None
let slice = None
let adaptive_rejection = None
(** [to_string t] renders a one-line summary of [t]: the number of stored
    samples followed by the acceptance ratio. *)
let to_string t =
  let count = Array.length t.samples in
  let ratio = t.accept in
  Printf.sprintf "[ samples: %i; accept_ratio: %g ]" count ratio
(** [pp_t formatter t] prints the summary produced by [to_string t] on
    [formatter], wrapped in a pretty-printing box with offset 0.

    The box is opened and closed on [formatter] itself, so the box state of
    the standard formatter is not touched when printing elsewhere. *)
let pp_t formatter t =
  Format.pp_open_box formatter 0;
  Format.fprintf formatter "%s" (to_string t);
  Format.pp_close_box formatter ()