Source file dist.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
(** A stateless distribution over values of type ['a].

    - [sample] produces a value of type ['a] from an [RNG.t].
    - [ll] maps a value of type ['a] to a [Log_space.t]; presumably the
      likelihood of that value in log-space (to be confirmed against
      [Log_space]). *)
type 'a dist = { sample : RNG.t -> 'a; ll : 'a -> Log_space.t }

(** A kernel over values of type ['a]: a distribution whose sampler and
    scoring function both take an extra value of type ['a].

    - [start] is a value of type ['a]; presumably the initial state (to be
      confirmed against the users of this module).
    - [sample x rng] produces a value of type ['a] from the value [x] and an
      [RNG.t].
    - [ll x y] maps a pair of values to a [Log_space.t]; presumably the
      likelihood of reaching [y] from [x] (to be confirmed). *)
type 'a kernel =
  { start : 'a; sample : 'a -> RNG.t -> 'a; ll : 'a -> 'a -> Log_space.t }

(** A distribution is either a [Stateless] one (see {!dist}) or a [Kernel]
    (see {!kernel}). *)
type 'a t = Stateless of 'a dist | Kernel of 'a kernel

(** [stateless sample ll] packs a sampler and a scoring function into a
    [Stateless] distribution. *)
let stateless sample ll =
  let d : _ dist = { sample; ll } in
  Stateless d
  [@@inline]

(** [kernel start sample ll] packs a start value, a sampler and a scoring
    function into a [Kernel] distribution. *)
let kernel start sample ll =
  let k : _ kernel = { start; sample; ll } in
  Kernel k
  [@@inline]

(** [dist0 sampler log_pdf] is the [Stateless] distribution whose sampler is
    [sampler] and whose scoring function is [log_pdf]; neither takes any
    extra parameter. *)
let dist0 sampler log_pdf =
  Stateless { sample = sampler; ll = log_pdf }
  [@@inline]

(** [dist1 sampler log_pdf arg] is the [Stateless] distribution obtained by
    fixing the single parameter [arg] of both [sampler] and [log_pdf].

    The parameter is only passed on when the resulting sampler or scoring
    function is actually called. *)
let dist1 sampler log_pdf arg =
  let sample rng_state = sampler arg rng_state in
  let ll value = log_pdf arg value in
  stateless sample ll
  [@@inline]

(** [dist2 sampler log_pdf arg1 arg2] is the [Stateless] distribution
    obtained by fixing the two parameters [arg1] and [arg2] of both
    [sampler] and [log_pdf].

    The parameters are only passed on when the resulting sampler or scoring
    function is actually called. *)
let dist2 sampler log_pdf arg1 arg2 =
  let sample rng_state = sampler arg1 arg2 rng_state in
  let ll value = log_pdf arg1 arg2 value in
  stateless sample ll
  [@@inline]

(** [kernel1 sampler log_pdf start arg] is the [Kernel] with start value
    [start], obtained by fixing the single parameter [arg] of both [sampler]
    and [log_pdf]. The parameter comes first, before the value(s) the kernel
    operates on.

    The parameter is only passed on when the resulting sampler or scoring
    function is actually called. *)
let kernel1 sampler log_pdf start arg =
  let sample current rng_state = sampler arg current rng_state in
  let ll current next = log_pdf arg current next in
  kernel start sample ll
  [@@inline]

(** [iid n dist] lifts [dist] to arrays of [n] components that are sampled
    and scored component by component.

    - For a [Stateless] distribution, sampling draws [n] values in index
      order using the same RNG state. The score of an array is the
      [Log_space.mul]-product of the scores of all its cells, starting from
      [Log_space.one]; the length of the scored array is not compared
      with [n].
    - For a [Kernel], the start value is an array of [n] cells that all hold
      [k.start] (the same value, not copies). Sampling applies the kernel to
      cells [0 .. n - 1] of the given array, in index order. The score of a
      pair of arrays is the product of the per-index scores; the two arrays
      must have the same length, which is checked with [assert]. *)
let iid n dist =
  match dist with
  | Kernel k ->
      (* Built eagerly, once per call to [iid]. *)
      let start = Array.make n k.start in
      let step state rng_state =
        Array.init n (fun i -> k.sample state.(i) rng_state)
      in
      let score prev next =
        assert (Array.length prev = Array.length next) ;
        let len = Array.length prev in
        let rec go i acc =
          if i >= len then acc
          else go (i + 1) (Log_space.mul (k.ll prev.(i) next.(i)) acc)
        in
        go 0 Log_space.one
      in
      Kernel { start; sample = step; ll = score }
  | Stateless d ->
      let draw_all rng_state = Array.init n (fun _ -> d.sample rng_state) in
      let score_all values =
        Array.fold_left
          (fun acc v -> Log_space.mul (d.ll v) acc)
          Log_space.one
          values
      in
      Stateless { sample = draw_all; ll = score_all }

(** [conv f g s] turns a distribution over [a] into a distribution over [b]
    using the conversions [f : a -> b] and [g : b -> a].

    Values produced by the underlying sampler are mapped through [f]; values
    handed to the scoring function (and, for kernels, the current value
    handed to the sampler) are first mapped back through [g]. For a
    [Kernel], the new start value [f k.start] is computed immediately.

    [f] and [g] are presumably meant to be mutually inverse; nothing here
    checks it. *)
let conv (type a b) (f : a -> b) (g : b -> a) (s : a t) : b t =
  match s with
  | Stateless { sample = draw; ll = score } ->
      let sample rng = f (draw rng) in
      let ll x = score (g x) in
      Stateless { sample; ll }
  | Kernel { start = init; sample = step; ll = score } ->
      let start = f init in
      let sample x rng = f (step (g x) rng) in
      let ll x y = score (g x) (g y) in
      Kernel { start; sample; ll }