(* Source file owl_algodiff_check.ml *)
# 1 "src/base/algodiff/owl_algodiff_check.ml"
(*
 * OWL - OCaml Scientific Computing
 * Copyright (c) 2016-2022 Liang Wang <liang@ocaml.xyz>
 *)

module Make (Algodiff : Owl_algodiff_generic_sig.Sig) = struct
  open Algodiff

  (** [generate_directions (dim1, dim2)] returns [dim1 * dim2] matrices of
      shape [dim1 x dim2]; the [j]-th one holds a one at flat index [j] and
      zeros elsewhere (i.e. the standard basis). *)
  let generate_directions (dim1, dim2) =
    let n_directions = dim1 * dim2 in
    Array.init n_directions (fun j ->
        Arr
          (A.init [| dim1; dim2 |] (fun i ->
               if i = j then A.(float_to_elt 1.) else A.(float_to_elt 0.))))


  (** [generate_test_samples (dim1, dim2) n_samples] returns [n_samples]
      Gaussian random matrices of shape [dim1 x dim2], paired with the
      standard-basis directions from [generate_directions]. *)
  let generate_test_samples (dim1, dim2) n_samples =
    ( Array.init n_samples (fun _ -> Mat.gaussian dim1 dim2)
    , generate_directions (dim1, dim2) )


  module Reverse = struct
    (** [finite_difference_grad ~order ~f ?eps x d] estimates the directional
        derivative of [f] at [x] along [d] using a central finite-difference
        stencil of the given [order] ([`second], [`fourth] or [`eighth]) with
        step [eps] (default [1E-5]). *)
    let finite_difference_grad =
      let h x = F A.(float_to_elt x) in
      let two = h 2. in
      let eight = h 8. in
      let twelve = h 12. in
      fun ~order ~f ?(eps = 1E-5) x d ->
        let eps = F A.(float_to_elt eps) in
        let dx = Maths.(eps * d) in
        (* df_k below denotes f (x + k dx) - f (x - k dx) *)
        let df1 = Maths.(f (x + dx) - f (x - dx)) in
        match order with
        | `eighth ->
          let twodx = Maths.(two * dx) in
          let threedx = Maths.(h 3. * dx) in
          let fourdx = Maths.(h 4. * dx) in
          let df2 = Maths.(f (x + twodx) - f (x - twodx)) in
          let df3 = Maths.(f (x + threedx) - f (x - threedx)) in
          let df4 = Maths.(f (x + fourdx) - f (x - fourdx)) in
          (* 8th-order central-difference weights 4/5, -1/5, 4/105, -1/280 *)
          Maths.(
            ((h (4. /. 5.) * df1)
            + (h (-1. /. 5.) * df2)
            + (h (4. /. 105.) * df3)
            + (h (-1. /. 280.) * df4))
            / eps)
        | `fourth ->
          let df2 =
            let twodx = Maths.(two * dx) in
            Maths.(f (x + twodx) - f (x - twodx))
          in
          Maths.(((eight * df1) - df2) / (twelve * eps))
        | `second -> Maths.(df1 / (two * eps))


    (** [check ~threshold ~order ?verbose ?eps ~f ~directions samples] checks
        the reverse-mode gradient of [sum' (f x)] against finite differences.
        For every sample [x] and direction [d] it compares
        [sum' (grad x * d)] with the finite-difference estimate. The largest
        absolute discrepancy is divided by the RMS of the finite-difference
        values, and the sample passes when that ratio is below [threshold].

        Returns [(all_passed, n_passed)]. When [verbose] is set it also prints
        a summary line. *)
    let check ~threshold ~order ?(verbose = false) ?(eps = 1E-5) ~f =
      (* [assess rs] takes (AD, FD) directional-derivative pairs for a single
         sample and returns (passed, max normalised error). *)
      let assess rs =
        let n_d = Array.length rs in
        let rms =
          Array.fold_left (fun acc (_, r_fd) -> acc +. (r_fd *. r_fd)) 0. rs
          /. float n_d
          |> sqrt
        in
        let max_err =
          Array.fold_left
            (fun acc (r_ad, r_fd) ->
              (* Float.max propagates NaN (Stdlib.max does not), so a NaN
                 derivative in one direction cannot be masked by a later
                 finite error and wrongly pass the check. *)
              Float.max acc (abs_float (r_ad -. r_fd) /. (rms +. 1E-9)))
            (-1.)
            rs
        in
        max_err < threshold, max_err
      in
      let f x = Maths.(sum' (f x)) in
      let g = grad f in
      fun ~directions samples ->
        let n_samples = Array.length samples in
        let check, max_err, n_passed =
          samples
          |> Array.map (fun x ->
                 (* the AD gradient does not depend on the direction, so do
                    one reverse pass per sample rather than one per direction *)
                 let gx = g x in
                 directions
                 |> Array.map (fun d ->
                        let r_ad = Maths.(sum' (gx * d)) |> unpack_flt in
                        let r_fd =
                          finite_difference_grad ~order ~f ~eps x d |> unpack_flt
                        in
                        r_ad, r_fd)
                 |> assess)
          |> Array.fold_left
               (fun (all_ok, worst, n_ok) (ok, err) ->
                 ( all_ok && ok
                 , Float.max worst err
                 , if ok then succ n_ok else n_ok ))
               (true, -1., 0)
        in
        if verbose
        then
          Printf.printf
            "adjoints passed: %i/%i | max_err: %f.\n%!"
            n_passed
            n_samples
            max_err;
        check, n_passed
  end

  module Forward = struct
    (** [check_tangent_dimensions ~f x] fails with
        [Failure "tangent dimension mismatch"] unless the primal of [f x] and
        the primal of the forward-mode tangent [jacobianv f x x] are both
        scalars or both arrays of the same shape. *)
    let check_tangent_dimensions ~f x =
      (* tangent at x should have the same dimension as f x *)
      match primal' (f x), primal' (jacobianv f x x) with
      | F _, F _     -> ()
      | Arr a, Arr b ->
        if A.shape a <> A.shape b then failwith "tangent dimension mismatch" else ()
      | _            -> failwith "tangent dimension mismatch"


    (** [check ~threshold ~f ~directions samples] compares, for every sample
        [x], the reverse-mode gradient of [sum' (f x)] with a gradient
        assembled from forward-mode directional derivatives: entry [i] is
        [jacobianv f x directions.(i)].

        A sample passes when the squared L2 norm of the difference (an
        absolute, not relative, error) is below [threshold].

        [directions] must have at least [dim1 * dim2] entries, where
        [(dim1, dim2)] is the shape of [directions.(0)]. The comparison is
        only meaningful when they form the standard basis, e.g. as produced
        by [generate_directions].

        Returns [(all_passed, n_passed)].
        @raise Invalid_argument if [samples] or [directions] is empty.
        @raise Failure if the tangent dimensions do not match. *)
    let check ~threshold ~f ~directions samples =
      if Array.length samples = 0
      then invalid_arg "Owl_algodiff_check.Forward.check: empty samples";
      if Array.length directions = 0
      then invalid_arg "Owl_algodiff_check.Forward.check: empty directions";
      check_tangent_dimensions ~f samples.(0);
      let f x = Maths.(sum' (f x)) in
      let reverse_g = grad f in
      let dim1, dim2 = Mat.shape directions.(0) in
      let forward_g x =
        Arr
          (A.init [| dim1; dim2 |] (fun i ->
               let v = directions.(i) in
               jacobianv f x v |> unpack_elt))
      in
      Array.fold_left
        (fun (b, n) x ->
          let reverse_gx = reverse_g x in
          let forward_gx = forward_g x in
          let e = Maths.(l2norm_sqr' (reverse_gx - forward_gx)) |> unpack_flt in
          let b' = e < threshold in
          let n = if b' then n + 1 else n in
          b && b', n)
        (true, 0)
        samples
  end
end