1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
(** [invalid_argf fmt arg1 ... argN] formats a message the way
    [Printf.sprintf] would and raises [Invalid_argument] carrying it.

    @raise Invalid_argument always; this function never returns normally. *)
let invalid_argf fmt =
  Printf.ksprintf (fun message -> raise (Invalid_argument message)) fmt
(** [invalid_argf_fn fn fmt arg1 ... argN] formats a message from [fmt] and
    its arguments, prefixes it with ["Loss.<fn>: "], and raises
    [Invalid_argument] with the result.

    @param fn name of the function reporting the error.
    @raise Invalid_argument always. *)
let invalid_argf_fn fn fmt =
  (* Continuation run once the user-supplied format has been rendered. *)
  let fail detail = invalid_argf "Loss.%s: %s" fn detail in
  Printf.ksprintf fail fmt
(** [check_logits_shape ~fn logits] validates the shape of [logits] and
    returns it.

    The tensor must have at least one axis, and its last axis (treated here as
    the class axis) must have a strictly positive size.

    @param fn name of the calling function, used in error messages.
    @return the shape of [logits] as reported by [Nx.shape].
    @raise Invalid_argument if either condition does not hold. *)
let check_logits_shape ~fn logits =
  let dims = Nx.shape logits in
  match Array.length dims with
  | 0 -> invalid_argf_fn fn "logits must have rank >= 1"
  | rank ->
      (* The class dimension is the trailing axis. *)
      let classes = dims.(rank - 1) in
      if classes > 0 then dims
      else
        invalid_argf_fn fn "logits class dimension must be positive (got %d)"
          classes
(** [check_same_shape ~fn ~rhs_name lhs rhs] checks that the shape [rhs] is
    identical to the reference shape [lhs].

    Ranks are compared first; then axes are compared in increasing order, so
    the first differing axis is the one reported.

    @param fn name of the calling function, used in error messages.
    @param rhs_name label for [rhs] in error messages.
    @raise Invalid_argument on a rank mismatch or on the first axis whose
      sizes differ. *)
let check_same_shape ~fn ~rhs_name lhs rhs =
  let expected_rank = Array.length lhs and actual_rank = Array.length rhs in
  if actual_rank <> expected_rank then
    invalid_argf_fn fn "%s rank mismatch (got %d, expected %d)" rhs_name
      actual_rank expected_rank;
  (* Ranks agree from here on, so indexing [rhs] by an axis of [lhs] is safe. *)
  Array.iteri
    (fun axis expected ->
      let actual = rhs.(axis) in
      if actual <> expected then
        invalid_argf_fn fn "%s shape mismatch at axis %d (got %d, expected %d)"
          rhs_name axis actual expected)
    lhs
(** [check_cross_entropy_shapes logits labels] validates the inputs of
    [cross_entropy]: [logits] must pass [check_logits_shape] and [labels] must
    have exactly the same shape as [logits].

    @raise Invalid_argument if any of these checks fails. *)
let check_cross_entropy_shapes logits labels =
  let fn = "cross_entropy" in
  (* Validate the logits first so their errors take precedence. *)
  let expected = check_logits_shape ~fn logits in
  labels |> Nx.shape |> check_same_shape ~fn ~rhs_name:"labels" expected
(** [cross_entropy logits labels] computes the cross-entropy between [labels]
    and the log-softmax of [logits] taken along the last axis, reduced with
    [Nx.mean].

    [labels] must have the same shape as [logits]; each entry weights the
    corresponding log-probability (presumably one-hot or soft targets —
    confirm against callers).

    @raise Invalid_argument if the shapes are invalid (see
      [check_cross_entropy_shapes]). *)
let cross_entropy logits labels =
  check_cross_entropy_shapes logits labels;
  let class_axes = [ -1 ] in
  (* Subtract the maximum along the class axis before exponentiating. *)
  let peak = Nx.max ~axes:class_axes ~keepdims:true logits in
  let centered = Nx.sub logits peak in
  let log_normalizer =
    let exponentials = Nx.exp centered in
    Nx.log (Nx.sum ~axes:class_axes ~keepdims:true exponentials)
  in
  let log_probs = Nx.sub centered log_normalizer in
  let weighted = Nx.mul labels log_probs in
  Nx.mean (Nx.neg (Nx.sum ~axes:class_axes weighted))
(** [check_sparse_indices_dtype indices] checks that the dtype of [indices]
    is an integer dtype, as decided by [Nx_core.Dtype.is_int].

    @raise Invalid_argument naming the offending dtype otherwise. *)
let check_sparse_indices_dtype indices =
  let label_dtype = Nx.dtype indices in
  match Nx_core.Dtype.is_int label_dtype with
  | true -> ()
  | false ->
      invalid_argf_fn "cross_entropy_sparse" "expected integer labels, got %s"
        (Nx_core.Dtype.to_string label_dtype)
(** [check_sparse_shapes logits indices] validates the shapes used by
    [cross_entropy_sparse].

    [logits] must pass [check_logits_shape]; [indices] must have one axis
    fewer than [logits], and each of its axes must match the corresponding
    leading axis of [logits].

    @return the size of the last axis of [logits].
    @raise Invalid_argument if any of these checks fails. *)
let check_sparse_shapes logits indices =
  let fn = "cross_entropy_sparse" in
  let logit_dims = check_logits_shape ~fn logits in
  let label_dims = Nx.shape indices in
  (* The index of the class axis is also the rank expected of the labels. *)
  let class_axis = Array.length logit_dims - 1 in
  let label_rank = Array.length label_dims in
  if label_rank <> class_axis then
    invalid_argf_fn fn "labels rank mismatch (got %d, expected %d)" label_rank
      class_axis;
  Array.iteri
    (fun axis actual ->
      let expected = logit_dims.(axis) in
      if actual <> expected then
        invalid_argf_fn fn
          "labels shape mismatch at axis %d (got %d, expected %d)" axis actual
          expected)
    label_dims;
  logit_dims.(class_axis)
(** [cross_entropy_sparse logits indices] computes cross-entropy from integer
    class labels: for each position it takes the log-softmax of [logits]
    along the last axis at the label given by [indices], negates it, and
    reduces the result with [Nx.mean].

    [indices] must have an integer dtype and the shape of [logits] without
    its last axis; it is cast to [Nx.int32] before gathering.

    NOTE(review): label values are not range-checked against the number of
    classes here; out-of-range behaviour depends on [Nx.take_along_axis].

    @raise Invalid_argument if the dtype or shape checks fail. *)
let cross_entropy_sparse logits indices =
  check_sparse_indices_dtype indices;
  (* Only the validation matters here; the class count is not needed. *)
  let (_ : int) = check_sparse_shapes logits indices in
  let last = -1 in
  let class_axes = [ last ] in
  let drop_class_axis t = Nx.squeeze ~axes:class_axes t in
  let label_ids = Nx.cast Nx.int32 indices in
  (* Subtract the maximum along the class axis before exponentiating. *)
  let peak = Nx.max ~axes:class_axes ~keepdims:true logits in
  let centered = Nx.sub logits peak in
  let log_normalizer =
    Nx.log (Nx.sum ~axes:class_axes ~keepdims:true (Nx.exp centered))
  in
  (* Give the labels a trailing axis so they can index the class axis. *)
  let gather_ids = Nx.expand_dims class_axes label_ids in
  let picked = Nx.take_along_axis ~axis:last gather_ids centered in
  let log_likelihood =
    Nx.sub (drop_class_axis picked) (drop_class_axis log_normalizer)
  in
  Nx.mean (Nx.neg log_likelihood)
(** [binary_cross_entropy logits labels] computes binary cross-entropy from
    raw [logits], reduced with [Nx.mean].

    Per element the loss is
    [-(labels * log_sigmoid logits + (1 - labels) * log_sigmoid (-logits))],
    where both log-probabilities come from [Activation.log_sigmoid] applied
    directly to the logits.

    [labels] must have exactly the same shape as [logits].

    @raise Invalid_argument if the shapes differ. *)
let binary_cross_entropy logits labels =
  check_same_shape ~fn:"binary_cross_entropy" ~rhs_name:"labels"
    (Nx.shape logits) (Nx.shape labels);
  (* Constant 1 in the dtype of the logits, used to form (1 - labels). *)
  let unit_value = Nx.scalar (Nx.dtype logits) 1.0 in
  let log_prob_pos = Activation.log_sigmoid logits in
  let log_prob_neg = Activation.log_sigmoid (Nx.neg logits) in
  let pos_term = Nx.mul labels log_prob_pos in
  let neg_term = Nx.mul (Nx.sub unit_value labels) log_prob_neg in
  Nx.mean (Nx.neg (Nx.add pos_term neg_term))
(** [mse predictions targets] is the mean squared error: the [Nx.mean] of the
    elementwise square of [predictions - targets]. No shape validation is
    performed here beyond what [Nx.sub] itself does. *)
let mse predictions targets =
  let residual = Nx.sub predictions targets in
  let squared = Nx.mul residual residual in
  Nx.mean squared
(** [mae predictions targets] is the mean absolute error: the [Nx.mean] of the
    elementwise absolute value of [predictions - targets]. No shape
    validation is performed here beyond what [Nx.sub] itself does. *)
let mae predictions targets =
  let residual = Nx.sub predictions targets in
  Nx.mean (Nx.abs residual)