Source file discrete_pd.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
open Core
type t = {
shift : int ;
weights : float array ;
}
let is_leaf dpd i = i >= dpd.shift
let init n ~f =
let shift = Float.(to_int (2. ** round_up (log (float n) /. log 2.))) - 1 in
let m = shift + n in
let weights = Array.create ~len:m 0. in
for i = 0 to n - 1 do
weights.(shift + i) <- f (i)
done ;
for i = shift - 1 downto 0 do
if 2 * i + 1 < m then weights.(i) <- weights.(2 * i + 1) ;
if 2 * i + 2 < m then weights.(i) <- weights.(i) +. weights.(2 * i + 2)
done ;
{ shift ; weights }
let draw dpd rng =
let x = dpd.weights.(0) *. Gsl.Rng.uniform rng in
let rec loop acc i =
if is_leaf dpd i then i
else if Float.( >= ) (acc +. dpd.weights.(2 * i + 1)) x then
loop acc (2 * i + 1)
else loop (acc +. dpd.weights.(2 * i + 1)) (2 * i + 2)
in
loop 0. 0 - dpd.shift
let update dpd i w_i =
let m = Array.length dpd.weights in
let j = i + dpd.shift in
dpd.weights.(j) <- w_i ;
let rec loop k =
dpd.weights.(k) <- dpd.weights.(2 * k + 1) ;
if 2 * k + 2 < m then dpd.weights.(k) <- dpd.weights.(k) +. dpd.weights.(2 * k + 2) ;
if k > 0 then loop ((k - 1) / 2)
in
loop ((j - 1) / 2)
let total_weight dpd = dpd.weights.(0)
let demo ~n ~ncat =
let rng = Gsl.Rng.(make (default ())) in
let probs = Array.init ncat ~f:(fun _ -> Gsl.Rng.uniform rng) in
let sum = Array.fold probs ~init:0. ~f:( +. ) in
let pd = init ncat ~f:(fun _ -> 0.) in
let counts = Array.create ~len:ncat 0 in
Array.iteri probs ~f:(update pd) ;
for _ = 1 to n do
let k = draw pd rng in
counts.(k) <- counts.(k) + 1
done ;
Array.map probs ~f:(fun x -> x /. sum),
Array.map counts ~f:(fun k -> float k /. float n)