Source file lmh_inference.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
module Counter = struct
let x = ref 0
let gen () =
let v = !x in
incr x ;
v
end
type 'ret sampling_point =
| Sampling_point :
{ uid : int;
ll : Log_space.t;
pos : 'a;
dist : 'a Dist.t;
k : 'a -> unit -> 'ret trace -> 'ret
}
-> 'ret sampling_point
and score = int * Log_space.t
and 'ret trace =
{ scores : score list; samples : (Log_space.t * 'ret sampling_point) list }
let empty_trace = { scores = []; samples = [] }
module Trace = struct
type 'a t = 'a trace
end
module Cps = Cps_monad.Make (Identity_monad) (Trace)
module Syntax = struct
type 'a t = (unit, 'a) Cps.m
include Cps
include Lmh_generic.Make (Identity_monad) (Trace)
end
module Table = struct
let push_sample (s : 'a trace) sp =
match s.samples with
| [] -> { s with samples = [(Log_space.one, sp)] }
| (cumu, Sampling_point { ll; _ }) :: _ ->
{ s with samples = (Log_space.mul cumu ll, sp) :: s.samples }
let push_score (s : 'a trace) sc = { s with scores = sc :: s.scores }
let total_score { scores; samples } =
let total_samples =
match samples with
| [] -> Log_space.one
| (cumu, Sampling_point { ll; _ }) :: _ -> Log_space.mul cumu ll
in
let total_scores =
List.fold_left
(fun acc (_, ll) -> Log_space.mul acc ll)
Log_space.one
scores
in
(total_samples, Log_space.mul total_samples total_scores)
let cardinal s = List.length s.samples
let wipe_after s limit =
{ samples =
List.filter
(fun (_, Sampling_point { uid; _ }) -> uid < limit)
s.samples;
scores = List.filter (fun (uid, _) -> uid < limit) s.scores
}
let handler : RNG.t -> unit Syntax.handler =
let open Syntax in
fun rng_state ->
{ handler =
(fun (type a) (dist : a Cps_monad.effect) k () trace ->
match dist with
| Dist (Stateless { sample; ll } as d) ->
let pos = sample rng_state in
let ll = ll pos in
let uid = Counter.gen () in
let trace =
push_sample
trace
(Sampling_point { uid; ll; dist = d; k; pos })
in
k pos () trace
| Dist (Kernel { start; sample; ll } as d) ->
let pos = sample start rng_state in
let ll = ll start pos in
let uid = Counter.gen () in
let trace =
push_sample
trace
(Sampling_point { uid; ll; dist = d; k; pos })
in
k pos () trace
| Score (m, ll) ->
let uid = Counter.gen () in
let trace = push_score trace (uid, ll) in
k m () trace
| _unknown_effect -> invalid_arg "Lmh_inference: unknown effect")
}
end
let stream_samples (type a) (v : a Syntax.t) rng_state : a Seq.t =
let select_resampling trace rng_state =
let sample_points =
Array.of_list trace.samples
in
let length = Array.length sample_points in
if length = 0 then None else Some sample_points.(RNG.int rng_state length)
in
let exception Return of a * a trace in
let first_run (v : a Syntax.t) rng_state =
let handler = Table.handler rng_state in
(v ~handler).cont (fun x () trace -> raise (Return (x, trace)))
in
let (first_value, first_trace) =
match first_run v rng_state () empty_trace with
| exception Return (x, trace) -> (x, trace)
| _ -> assert false
in
let sample_step (prev_value : a) (prev_trace : a trace) rng_state =
let prev_card = Table.cardinal prev_trace in
let (prev_total_sampling_score, prev_total_score) =
Table.total_score prev_trace
in
match select_resampling prev_trace rng_state with
| None -> (prev_value, prev_trace)
| Some
( prefix_score,
Sampling_point ({ uid; ll = _prev_ll; dist; pos; k } as _payload) ) ->
let propose rng_state =
match dist with
| Stateless d ->
let sample = d.sample rng_state in
let ll = d.ll sample in
(sample, ll)
| Kernel d ->
let sample = d.sample pos rng_state in
let fwd_ll = d.ll pos sample in
(sample, fwd_ll)
in
let (proposed, fwd_ll) = propose rng_state in
let backward_sampling_score =
Log_space.div prev_total_sampling_score prefix_score
in
let trace_prefix =
let trace = Table.wipe_after prev_trace uid in
{ trace with
samples =
( prefix_score,
Sampling_point { uid; ll = fwd_ll; dist; pos = proposed; k } )
:: trace.samples
}
in
let (new_value, new_trace) : a * a trace =
match k proposed () trace_prefix with
| exception Return (x, trace) -> (x, trace)
| _ -> assert false
in
let (new_total_sampling_score, new_total_score) =
Table.total_score new_trace
in
let new_card = Table.cardinal new_trace in
let forward_sampling_score =
Log_space.div new_total_sampling_score prefix_score
in
let forward_flow =
Log_space.(
mul
prev_total_score
(div forward_sampling_score (of_float (float_of_int prev_card))))
in
let backward_flow =
Log_space.(
mul
new_total_score
(div backward_sampling_score (of_float (float_of_int new_card))))
in
let ratio = Log_space.div backward_flow forward_flow in
let acceptance = Log_space.(min one ratio) in
if
Log_space.lt (Log_space.of_float (RNG.float rng_state 1.0)) acceptance
then (new_value, new_trace)
else (prev_value, prev_trace)
in
Seq.unfold
(fun (prev_value, prev_trace) ->
let next = sample_step prev_value prev_trace rng_state in
Some (prev_value, next))
(first_value, first_trace)
include Syntax
module List_ops = Foldable.Make_list (Syntax)
module Seq_ops = Foldable.Make_seq (Syntax)
module Array_ops = Foldable.Make_array (Syntax)