Source file owl_algodiff_reverse.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# 1 "src/base/algodiff/owl_algodiff_reverse.ml"
(** Reverse-mode propagation passes, built on top of a core AD module [C].

    Besides the core signature, [C] must supply [reverse_add]. These passes
    use it to accumulate an incoming adjoint contribution into a node's
    adjoint cell. *)
module Make (C : sig
  include Owl_algodiff_core_sig.Sig

  val reverse_add : t -> t -> t
end) =
struct
  open C

  (** [reverse_reset x] prepares the graph reachable from [x] for a backward
      pass.

      The walk uses an explicit worklist seeded with [[x]]. For every [DR]
      node taken from the worklist it:
      - overwrites the adjoint cell with [reset_zero] of its current contents;
      - increments the node's [af] counter and its [tracker] counter;
      - only when both counters now read 1 (i.e. both were 0 before this
        visit), lets the node's [register] function (second component of its
        op triple) extend the remaining worklist. Otherwise the walk simply
        continues with the rest of the worklist.

      Values that are not [DR] are skipped. Every recursive call is a tail
      call, so a long worklist does not grow the OCaml stack. *)
  let reverse_reset x =
    let rec visit = function
      | [] -> ()
      | DR (_, adj, (_, register, _), fanout, _, tracker) :: rest ->
        adj := reset_zero !adj;
        fanout := !fanout + 1;
        tracker := succ !tracker;
        (* Only expand a node's inputs the first time it is counted. *)
        (match !fanout, !tracker with
         | 1, 1 -> visit (register rest)
         | _ -> visit rest)
      | _ :: rest -> visit rest
    in
    visit [ x ]

  (** [reverse_push v x] propagates the adjoint contribution [v] backwards,
      starting from node [x].

      The worklist holds [(contribution, node)] pairs. For every [DR] node it:
      - adds the contribution into the node's adjoint cell with [reverse_add];
      - decrements the node's [af] counter;
      - if [af] is now 0 and [tracker] is 1, calls the node's [adjoint]
        function (first component of its op triple). The call receives the
        node's first field, its adjoint cell and the remaining worklist, and
        the pass continues with the worklist it returns. [tracker] is left
        unchanged on this path;
      - otherwise, decrements [tracker] and continues with the rest of the
        worklist without calling [adjoint].

      Pairs whose node is not [DR] are skipped. Every recursive call is a
      tail call. *)
  let reverse_push =
    let rec drain = function
      | [] -> ()
      | (v, DR (cp, adj, (adjoint, _, _), fanout, _, tracker)) :: rest ->
        adj := reverse_add !adj v;
        fanout := Stdlib.(!fanout - 1);
        (* Hand the node to its [adjoint] only once all counted contributions
           have arrived; otherwise just consume one [tracker] count. *)
        (match !fanout, !tracker with
         | 0, 1 -> drain (adjoint cp adj rest)
         | _ ->
           tracker := pred !tracker;
           drain rest)
      | _ :: rest -> drain rest
    in
    fun v x -> drain [ v, x ]

  (** [reverse_prop v x] runs a full backward pass: [reverse_reset x]
      followed by [reverse_push v x], seeding [x] with the adjoint [v]. *)
  let reverse_prop v x =
    let () = reverse_reset x in
    reverse_push v x
end