(* Source file: owl_base_stats_dist_gamma.ml *)
# 1 "src/base/stats/owl_base_stats_dist_gamma.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

open Owl_base_stats_dist_exponential
open Owl_base_stats_dist_gaussian

(** [std_gamma_rvs ~shape] draws one random variate from the gamma
    distribution with the given [shape] and unit scale.

    Three cases are distinguished:
    - [shape = 1.]: a single call to [std_exponential_rvs ()] is returned.
    - [shape < 1.]: rejection sampling driven by one [Random.float 1.] draw
      and one [std_exponential_rvs ()] draw per attempt.
    - otherwise: rejection sampling driven by [std_gaussian_rvs ()] and
      [Random.float 1.]; the constants ([shape - 1/3], [1 / sqrt (9 b)], the
      cubed candidate and the [0.0331 x^4] squeeze) look like the
      Marsaglia-Tsang method -- NOTE(review): confirm against the reference.

    [shape] is expected to be positive; non-positive values are not
    validated here. A NaN [shape] yields [nan], because no acceptance test
    could ever succeed for it.

    Randomness comes from the global [Random] state and from the samplers
    brought into scope by this file's [open]s. *)
let std_gamma_rvs ~shape =
  if classify_float shape = FP_nan
  then
    (* Every comparison against NaN is false, so the rejection loops below
       would never accept a candidate; propagate NaN instead of hanging. *)
    nan
  else if shape = 1.
  then std_exponential_rvs ()
  else if shape < 1.
  then (
    (* Each attempt draws [u] first and [v] second; the accepted candidate
       is returned directly instead of being signalled via an exception, so
       unrelated exceptions from the samplers are no longer swallowed. *)
    let rec draw () =
      let u = Random.float 1. in
      let v = std_exponential_rvs () in
      if u <= 1. -. shape
      then (
        let x = u ** (1. /. shape) in
        if x <= v then x else draw ())
      else (
        let y = -.log ((1. -. u) /. shape) in
        let x = (1. -. shape +. (shape *. y)) ** (1. /. shape) in
        if x <= v +. y then x else draw ())
    in
    draw ())
  else (
    let b = shape -. (1. /. 3.) in
    let c = 1. /. sqrt (9. *. b) in
    (* One attempt: draw a Gaussian [x]; retry immediately while
       [1 + c * x] is not strictly positive, otherwise cube it and run the
       two acceptance tests (cheap squeeze first, then the log test). *)
    let rec draw () =
      let x = std_gaussian_rvs () in
      let v = 1. +. (c *. x) in
      if v <= 0.
      then draw ()
      else (
        let v = v *. v *. v in
        let u = Random.float 1. in
        if u < 1. -. (0.0331 *. x *. x *. x *. x)
           || log u < (0.5 *. x *. x) +. (b *. (1. -. v +. log v))
        then b *. v
        else draw ())
    in
    draw ())


(** [gamma_rvs ~shape ~scale] draws one value with [std_gamma_rvs ~shape] and
    multiplies it by [scale]. Neither argument is validated here. *)
let gamma_rvs ~shape ~scale =
  let unit_scale_draw = std_gamma_rvs ~shape in
  scale *. unit_scale_draw