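(** LeNet-5 (LeCun et al., 1998) implemented as a Kaun module on top of Rune
    tensors: two convolution + average-pooling stages followed by a small
    fully connected classifier. *)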
open Rune
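(** Architecture configuration.

    - [num_classes]: size of the output layer.
    - [input_channels]: number of channels in the input images.
    - [input_size]: expected input height and width.
    - [activation]: activation used after every convolution and hidden fully
      connected layer.
    - [dropout_rate]: when set, dropout with this rate is inserted after the
      first two fully connected layers. *)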
type config = {
num_classes : int;
input_channels : int;
input_size : int * int;
activation : [ `tanh | `relu | `sigmoid ];
dropout_rate : float option;
}
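(** Classic LeNet-5 setup: 10 classes, single-channel 32x32 input, tanh
    activations and no dropout. *)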
let default_config =
{
num_classes = 10;
input_channels = 1;
input_size = (32, 32);
activation = `tanh;
dropout_rate = None;
}
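(** MNIST preset: same as [default_config] but with 28x28 input images. Note
    that [create] hard-codes [in_features:400] (16 * 5 * 5), which corresponds
    to 32x32 inputs under unpadded 5x5 convolutions, so 28x28 MNIST images may
    need to be padded to 32x32 as in the original LeNet-5. *)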
let mnist_config =
{
default_config with
input_size = (28, 28);
}
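(** CIFAR-10 preset: 3-channel 32x32 colour input, ReLU activations and 0.5
    dropout. *)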
let cifar10_config =
{
num_classes = 10;
input_channels = 3;
input_size = (32, 32);
activation = `relu;
dropout_rate = Some 0.5;
}
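(** A LeNet-5 model, represented as a generic Kaun module. *)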
type t = Kaun.module_
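(** [create ?config ()] builds the network described by [config]: two 5x5
    convolution + 2x2 average-pooling stages (6 and 16 feature maps), then
    fully connected layers 400 -> 120 -> 84 -> [num_classes], with optional
    dropout after the first two fully connected layers. *)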
let create ?(config = default_config) () =
let open Kaun.Layer in
let activation_fn =
match config.activation with
| `tanh -> tanh ()
| `relu -> relu ()
| `sigmoid -> sigmoid ()
in
let layers =
[
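      (* C1/S2: 5x5 convolution to 6 feature maps, then 2x2 average pooling. *)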
conv2d ~in_channels:config.input_channels ~out_channels:6
~kernel_size:(5, 5) ();
activation_fn;
avg_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) ();
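      (* C3/S4: 5x5 convolution to 16 feature maps, then 2x2 average pooling. *)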
conv2d ~in_channels:6 ~out_channels:16 ~kernel_size:(5, 5) ();
activation_fn;
avg_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) ();
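      (* Classifier head: flatten to 400 features (16 * 5 * 5 for a 32x32
         input), then 120 -> 84 -> num_classes. *)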
flatten ();
linear ~in_features:400 ~out_features:120 ();
activation_fn;
]
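    (* Optional dropout after the first two fully connected layers. *)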
@ (match config.dropout_rate with
| Some rate -> [ dropout ~rate () ]
| None -> [])
@ [
linear ~in_features:120 ~out_features:84 ();
activation_fn;
]
@ (match config.dropout_rate with
| Some rate -> [ dropout ~rate () ]
| None -> [])
@ [
linear ~in_features:84 ~out_features:config.num_classes ();
]
in
sequential layers
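(** Convenience constructors for the presets above. *)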
let for_mnist () = create ~config:mnist_config ()
let for_cifar10 () = create ~config:cifar10_config ()
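(** [forward ~model ~params ~training ~input] runs a forward pass via
    [Kaun.apply]. [training] is intended to control layers that behave
    differently at training time, such as dropout. *)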
let forward ~model ~params ~training ~input =
Kaun.apply model params ~training input
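(** Intermediate feature extraction; not implemented yet. *)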
let extract_features ~model:_ ~params:_ ~input:_ =
  failwith "extract_features not implemented yet"
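(** Total number of scalar parameters in [params]. *)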
let num_parameters params =
let tensors = Kaun.Ptree.flatten_with_paths params in
List.fold_left
(fun acc (_, tensor) -> acc + Kaun.Ptree.Tensor.numel tensor)
0 tensors
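(** Human-readable per-layer parameter counts. Parameters are grouped by the
    first component of their path; layer order in the output follows
    hash-table iteration order and is therefore unspecified. *)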
let parameter_breakdown params =
let tensors = Kaun.Ptree.flatten_with_paths params in
let breakdown = Buffer.create 256 in
Buffer.add_string breakdown "LeNet-5 Parameter Breakdown:\n";
Buffer.add_string breakdown "============================\n";
let layer_params = Hashtbl.create 10 in
List.iter
(fun (path, tensor) ->
let name = Kaun.Ptree.Path.to_string path in
let layer_name =
try
let idx = String.index name '.' in
String.sub name 0 idx
with Not_found -> name
in
let size = Kaun.Ptree.Tensor.numel tensor in
let current =
try Hashtbl.find layer_params layer_name with Not_found -> 0
in
Hashtbl.replace layer_params layer_name (current + size))
tensors;
Hashtbl.iter
(fun layer count ->
Buffer.add_string breakdown
(Printf.sprintf " %s: %d parameters\n" layer count))
layer_params;
let total = num_parameters params in
Buffer.add_string breakdown
(Printf.sprintf "\nTotal: %d parameters (%.2f MB with float32)\n" total
(float_of_int (total * 4) /. 1024. /. 1024.));
Buffer.contents breakdown
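(** Training hyperparameters for a typical SGD-style setup. *)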
type train_config = {
learning_rate : float;
batch_size : int;
num_epochs : int;
weight_decay : float option;
momentum : float option;
}
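(** Reasonable defaults: learning rate 0.01, batch size 64, 10 epochs, weight
    decay 1e-4 and momentum 0.9. *)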
let default_train_config =
{
learning_rate = 0.01;
batch_size = 64;
num_epochs = 10;
weight_decay = Some 0.0001;
momentum = Some 0.9;
}
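(** [accuracy ~predictions ~labels] returns the fraction of rows of
    [predictions] (shape [batch; num_classes]) whose argmax matches the
    corresponding integer label in [labels]. *)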
let accuracy ~predictions ~labels =
let pred_classes = argmax predictions ~axis:1 in
let labels_int32 = cast Int32 labels in
let correct = equal pred_classes labels_int32 in
let correct_float = cast Float32 correct in
let total = float_of_int (Array.get (shape labels) 0) in
let num_correct = sum correct_float in
let num_correct_scalar =
match shape num_correct with
| [||] ->
let arr = to_array num_correct in
arr.(0)
| _ -> failwith "Expected scalar result from sum"
in
num_correct_scalar /. total