Source file owl_neural_compiler.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
# 1 "src/base/neural/owl_neural_compiler.ml"
module Make (E : Owl_types_computation_engine.Sig) = struct
module Engine = Owl_computation_engine.Flatten (E)
module Neural = Owl_neural_generic.Make (Engine)
open Neural
open Algodiff
(** [compile_simple network input_shape loss_fun] builds the symbolic forward
    pass, loss and backward pass of [network] with a caller-supplied loss.

    The network is initialised with [Graph.init]; every array returned by
    [Graph.mkpar] is then replaced, through [Graph.update], by an unnamed
    engine variable of the same shape. The input variable is named ["x"] and
    has shape [input_shape]; the target variable is named ["y"] and takes the
    shape of the forward output. [loss_fun] is applied to the target first and
    to the forward output second.

    The loss node is named ["loss"]. The flattened first component of
    [Graph.backward] is labelled [x0, x1, ...] and the flattened second
    component [x0', x1', ...] (presumably the weights and their adjoints).

    Returns [(x, y, pri, adj)], all as unpacked engine arrays. *)
let compile_simple network input_shape loss_fun =
  Graph.init network;
  (* Stand-in for one parameter: same shape, no name, no value. *)
  let as_placeholder p =
    let shape = Engine.shape (unpack_arr p) in
    pack_arr (Engine.var_arr "" ~shape)
  in
  let placeholders = Owl_utils.aarr_map as_placeholder (Graph.mkpar network) in
  Graph.update network placeholders;
  let x = pack_arr (Engine.var_arr "x" ~shape:input_shape) in
  let y', _ = Graph.forward network x in
  let output_shape = Engine.shape (unpack_arr y') in
  let y = pack_arr (Engine.var_arr "y" ~shape:output_shape) in
  let loss = loss_fun y y' in
  let primals, adjoints = Graph.backward network loss in
  let pri = Owl_utils_array.flatten primals in
  let adj = Owl_utils_array.flatten adjoints in
  Owl_graph.set_name (Engine.elt_to_node (unpack_elt loss)) "loss";
  (* Name the graph node behind every element, using its index in the array. *)
  let label_each make_label values =
    Array.iteri
      (fun i v ->
        let node = Engine.arr_to_node (unpack_arr v) in
        Owl_graph.set_name node (make_label i))
      values
  in
  label_each (fun i -> Printf.sprintf "x%i" i) pri;
  label_each (fun i -> Printf.sprintf "x%i'" i) adj;
  unpack_arr x, unpack_arr y, Array.map unpack_arr pri, Array.map unpack_arr adj
(** [compile_shallow params network full_size] builds the loss and gradient
    graph of [network]; no optimiser update is included.

    The batch dimension comes from [params.batch]: [Full] uses [full_size],
    [Mini n] and [Sample n] use [n], [Stochastic] uses [1]. It is prepended to
    [Graph.input_shape network] to form the shape of the input variable ["x"].

    After [Graph.init], each parameter array is evaluated and its value is
    assigned to a fresh unnamed engine variable, which replaces it in the
    network. The loss given by [params.loss] is divided by the row count of
    the target variable ["y"].

    The loss node is named ["loss"]; the flattened components of
    [Graph.backward] are labelled [x0, x1, ...] and [x0', x1', ...]. Finally
    [Engine.freeze_ancestors] is called on the loss node followed by all of
    those nodes.

    Returns [(x, y, pri, adj, loss)] as Algodiff values. *)
let compile_shallow (params : Params.typ) network full_size =
  let loss_fun = Loss.run params.loss in
  let batch =
    match params.batch with
    | Full -> full_size
    | Mini n | Sample n -> n
    | Stochastic -> 1
  in
  let input_shape = Array.append [| batch |] (Graph.input_shape network) in
  Graph.init network;
  (* Copy the current value of a parameter into a new unnamed variable. *)
  let to_variable p =
    let current = Algodiff.unpack_arr p in
    Engine.eval_arr [| current |];
    let var = Engine.var_arr "" ~shape:(Engine.shape current) in
    Engine.assign_arr var (Engine.unpack_arr current);
    Algodiff.pack_arr var
  in
  let variables = Owl_utils.aarr_map to_variable (Graph.mkpar network) in
  Graph.update network variables;
  let x = pack_arr (Engine.var_arr "x" ~shape:input_shape) in
  let y', _ = Graph.forward network x in
  let output_shape = Engine.shape (unpack_arr y') in
  let y = pack_arr (Engine.var_arr "y" ~shape:output_shape) in
  (* The raw loss is built before the row-count constant so that engine nodes
     are created in the same order as before. *)
  let raw_loss = loss_fun y y' in
  let n_rows = float_of_int (Mat.row_num y) in
  let loss = Maths.(raw_loss / _f n_rows) in
  let primals, adjoints = Graph.backward network loss in
  let pri = Owl_utils_array.flatten primals in
  let adj = Owl_utils_array.flatten adjoints in
  let node_of v = Engine.arr_to_node (unpack_arr v) in
  let loss_node () = Engine.elt_to_node (unpack_elt loss) in
  Owl_graph.set_name (loss_node ()) "loss";
  Array.iteri (fun i v -> Owl_graph.set_name (node_of v) (Printf.sprintf "x%i" i)) pri;
  Array.iteri (fun i v -> Owl_graph.set_name (node_of v) (Printf.sprintf "x%i'" i)) adj;
  let loss_nodes = [| loss_node () |] in
  let pri_nodes = Array.map node_of pri in
  let adj_nodes = Array.map node_of adj in
  Engine.freeze_ancestors (Array.concat [ loss_nodes; pri_nodes; adj_nodes ]);
  x, y, pri, adj, loss
(** [compile_deep params network full_size] compiles one complete training
    step of [network] into a static computation graph.

    Besides the loss and the gradients, the graph contains the clipped
    gradients, the descent directions, the learning-rate cache, the momentum
    update and the new weights, all derived from the components selected in
    [params].

    The batch dimension comes from [params.batch]: [Full] uses [full_size],
    [Mini n] and [Sample n] use [n], [Stochastic] uses [1].

    Node labels, where [i] is the index in the flattened array:
    - weights [ws<i>] and raw gradients [gs'<i>];
    - state variables [gs<i>], [ps<i>], [us<i>], [cha<i>] and [chb<i>];
    - step outputs [ws'<i>], [ps'<i>], [us'<i>], [cha'<i>] and [chb'<i>].

    The state variables [gs], [ps], [us] and [ch] are assigned zeros before
    returning.

    Returns [(loss, x, y, cgraph)]: the loss, the input variable ["x"], the
    target variable ["y"] and the graph built by [Engine.make_graph], whose
    outputs are the kept step outputs followed by the loss node. *)
let compile_deep (params : Params.typ) network full_size =
  (* Specialise every optimiser component to the configured parameters. *)
  let loss_fun = Loss.run params.loss in
  let grad_fun = Gradient.run params.gradient in
  let rate_fun = Learning_Rate.run params.learning_rate in
  let regl_fun = Regularisation.run params.regularisation in
  let momt_fun = Momentum.run params.momentum in
  let upch_fun = Learning_Rate.update_ch params.learning_rate in
  let clip_fun = Clipping.run params.clipping in
  (* Graph node behind an Algodiff array value. *)
  let node_of v = unpack_arr v |> Engine.arr_to_node in
  (* Label the node of every element with [prefix] followed by its index. *)
  let name_nodes prefix values =
    Array.iteri
      (fun i v -> Owl_graph.set_name (node_of v) (Printf.sprintf "%s%i" prefix i))
      values
  in
  (* New engine variable called [name] with the same shape as [like]. *)
  let fresh_var name like =
    Engine.var_arr name ~shape:(Engine.shape (unpack_arr like)) |> pack_arr
  in
  (* One new variable per element of [like], named [prefix] plus its index. *)
  let fresh_vars prefix like =
    Array.mapi (fun i w -> fresh_var (Printf.sprintf "%s%i" prefix i) w) like
  in
  let batch =
    match params.batch with
    | Full -> full_size
    | Mini n | Sample n -> n
    | Stochastic -> 1
  in
  let network_shape = Graph.input_shape network in
  let input_shape = Array.append [| batch |] network_shape in
  (* Initialise the network, then swap each parameter for an unnamed variable
     holding its evaluated value. *)
  Graph.init network;
  Graph.mkpar network
  |> Owl_utils.aarr_map (fun p ->
         let current = Algodiff.unpack_arr p in
         Engine.eval_arr [| current |];
         let var = Engine.var_arr "" ~shape:(Engine.shape current) in
         Engine.assign_arr var (Engine.unpack_arr current);
         Algodiff.pack_arr var)
  |> Graph.update network;
  (* Forward pass and loss, divided by the row count of the target. *)
  let x = Engine.var_arr "x" ~shape:input_shape |> pack_arr in
  let y' = Graph.forward network x |> fst in
  let output_shape = unpack_arr y' |> Engine.shape in
  let y = Engine.var_arr "y" ~shape:output_shape |> pack_arr in
  let loss = loss_fun y y' in
  let loss = Maths.(loss / _f (Mat.row_num y |> float_of_int)) in
  (* Add the regularisation term of every weight, if one is configured. *)
  let ws = Owl_utils_array.flatten (Graph.mkpri network) in
  let reg =
    if params.regularisation <> Regularisation.None
    then Array.fold_left (fun acc w -> Maths.(acc + regl_fun w)) (_f 0.) ws
    else _f 0.
  in
  let loss = Maths.(loss + reg) in
  Owl_graph.set_name (unpack_elt loss |> Engine.elt_to_node) "loss";
  (* Backward pass: weights and their raw gradients. *)
  let z = Graph.(backward network loss) in
  let ws = Owl_utils_array.flatten (fst z) in
  let gs' = Owl_utils_array.flatten (snd z) in
  name_nodes "ws" ws;
  (* The label used to carry a stray trailing apostrophe; it now follows the
     same prefix-then-index form as the other primed outputs. *)
  name_nodes "gs'" gs';
  (* State variables carried between steps, one per weight. *)
  let gs = fresh_vars "gs" ws in
  let ps = fresh_vars "ps" ws in
  let us = fresh_vars "us" ws in
  let ch =
    Array.mapi
      (fun i w ->
        let cha = fresh_var (Printf.sprintf "cha%i" i) w in
        let chb = fresh_var (Printf.sprintf "chb%i" i) w in
        [| cha; chb |])
      ws
  in
  (* One optimisation step, expressed symbolically. *)
  let gs' = Array.map clip_fun gs' in
  let ps' = Owl_utils_array.map4 (grad_fun (fun a -> a)) ws gs ps gs' in
  let ch' = Owl_utils_array.map2 upch_fun gs' ch in
  let us' =
    Owl_utils_array.map3
      (fun p' g' c ->
        (* NOTE(review): 999 is a hard-coded first argument to [rate_fun],
           presumably the iteration index. Confirm whether any configured
           learning rate depends on it. *)
        Maths.(p' * rate_fun 999 g' c))
      ps'
      gs'
      ch'
  in
  let us' = Owl_utils_array.map2 momt_fun us us' in
  let ws' = Owl_utils_array.map2 (fun w u -> Maths.(w + u)) ws us' in
  name_nodes "ws'" ws';
  name_nodes "ps'" ps';
  name_nodes "us'" us';
  Array.iteri
    (fun i pair ->
      Owl_graph.set_name (node_of pair.(0)) (Printf.sprintf "cha'%i" i);
      Owl_graph.set_name (node_of pair.(1)) (Printf.sprintf "chb'%i" i))
    ch';
  (* Pair each step input with its output and assemble the graph. *)
  let network_name = Graph.get_network_name network in
  let ch = Owl_utils_array.flatten ch in
  let ch' = Owl_utils_array.flatten ch' in
  let raw_i = Array.concat [ ws; gs; ps; us; ch ] |> Array.map node_of in
  let raw_o = Array.concat [ ws'; gs'; ps'; us'; ch' ] |> Array.map node_of in
  let param_i, param_o = Engine.remove_unused_iopair raw_i raw_o in
  let output = Array.append param_o [| unpack_elt loss |> Engine.elt_to_node |] in
  let cgraph = Engine.make_graph ~input:param_i ~output network_name in
  Engine.make_iopair cgraph param_i param_o;
  (* Give every state variable an all-zero starting value. *)
  Owl_utils.aarr_iter
    (fun v ->
      let var = Algodiff.unpack_arr v in
      Engine.assign_arr var (Engine.A.zeros (Engine.shape var)))
    [| gs; ps; us; ch |];
  loss, x, y, cgraph
(** [make_eval_fun loss xt yt cgraph] returns a function that evaluates
    [cgraph] on one batch.

    The returned function takes an input batch and a target batch, evaluates
    both, assigns their values to the graph variables [xt] and [yt] with
    [Engine.unsafe_assign_arr], runs [Engine.eval_graph] and returns [loss].
    The same [loss] value is returned on every call. *)
let make_eval_fun loss xt yt cgraph =
  let x_var = Algodiff.unpack_arr xt in
  let y_var = Algodiff.unpack_arr yt in
  fun xt' yt' ->
    let x_batch = Algodiff.unpack_arr xt' in
    let y_batch = Algodiff.unpack_arr yt' in
    Engine.eval_arr [| x_batch; y_batch |];
    (* Both values are extracted before either variable is overwritten. *)
    let x_value = Engine.unpack_arr x_batch in
    let y_value = Engine.unpack_arr y_batch in
    Engine.unsafe_assign_arr x_var x_value;
    Engine.unsafe_assign_arr y_var y_value;
    Engine.eval_graph cgraph;
    loss
(** [make_update_fun cgraph] returns a thunk that calls
    [Engine.update_iopair] on [cgraph] each time it is applied to [()]. *)
let make_update_fun cgraph () = Engine.update_iopair cgraph
(** [train ?state ?params network x y] compiles [network] into a static graph
    with [compile_deep] and trains it on inputs [x] and targets [y].

    [params] defaults to [Params.default ()]. The size of the first dimension
    of [x] is passed to [compile_deep] as the full data size.

    Side effects: the graph is saved to [<network name>_raw.cgd] before
    [Engine.optimise] and to [<network name>_opt.cgd] after it. The save
    callback handed to the optimiser does nothing.

    Returns the result of [Optimise.minimise_compiled_network]. *)
let train ?state ?params network x y =
  let params =
    match params with
    | None -> Params.default ()
    | Some given -> given
  in
  let name = Graph.get_network_name network in
  Owl_log.info "compile network %s into static graph ..." name;
  let sample_count = (Engine.shape (unpack_arr x)).(0) in
  let loss, xt, yt, cgraph = compile_deep params network sample_count in
  let eval = make_eval_fun loss xt yt cgraph in
  let update = make_update_fun cgraph in
  (* Checkpoint callback: intentionally a no-op here. *)
  let save _ = () in
  let dump suffix = Engine.save_graph cgraph (name ^ suffix) in
  dump "_raw.cgd";
  Engine.optimise cgraph;
  dump "_opt.cgd";
  Owl_log.info "start training %s ..." name;
  Optimise.minimise_compiled_network ?state params eval update save x y
(** [model_inputs ?optimise ?batch_size network] compiles the forward pass of
    [network] into a static graph and returns an inference function.

    One variable [input_<i>] is created per entry of
    [Graph.input_shapes network], each with [batch_size] (default [1])
    prepended to its shape. The graph is optimised with [Engine.optimise]
    unless [optimise] is [false] (default [true]).

    The returned function takes one Algodiff array per network input. It
    reads a count from the first one with [Optimise.Utils.sample_num]
    (presumably the number of samples), slices the inputs into consecutive
    chunks of at most [batch_size] along the first dimension, evaluates the
    graph on each chunk and concatenates the copied outputs along axis 0.
    The outputs of a shorter final chunk are trimmed to the chunk length.

    @raise Failure if an input is not an [Arr] value. *)
let model_inputs ?(optimise = true) ?(batch_size = 1) network =
  let network_name = Graph.get_network_name network in
  Owl_log.info "compile network %s into static graph ..." network_name;
  let inputs =
    Array.mapi
      (fun idx shape ->
        let name = "input_" ^ string_of_int idx in
        pack_arr (Engine.var_arr name ~shape:(Array.append [| batch_size |] shape)))
      (Graph.input_shapes network)
  in
  let outputs = Graph.run_inputs inputs network in
  let nodes_of values = Array.map (fun v -> Engine.arr_to_node (unpack_arr v)) values in
  let input_nodes = nodes_of inputs in
  let output_nodes = nodes_of outputs in
  let cgraph = Engine.make_graph ~input:input_nodes ~output:output_nodes network_name in
  if optimise then Engine.optimise cgraph;
  (* Feed one batch into the input variables and run the graph once. *)
  let eval batch =
    let slots = Array.map Algodiff.unpack_arr inputs in
    let fed = Array.map Algodiff.unpack_arr batch in
    Engine.eval_arr fed;
    let values = Array.map Engine.unpack_arr fed in
    Array.iter2 Engine.unsafe_assign_arr slots values;
    Engine.eval_graph cgraph;
    outputs
  in
  fun xt ->
    let n = Optimise.Utils.sample_num xt.(0) in
    (* First index, last index (inclusive) and length of chunk [i]. *)
    let chunk_bounds i =
      let first = i * batch_size in
      let last = min n (first + batch_size) - 1 in
      first, last, last - first + 1
    in
    let take_rows first last v =
      match v with
      | Arr data -> Arr (A.get_slice [ [ first; last ] ] data)
      | _ -> failwith ("Owl_neural_compiler.model_inputs: get_chunk: " ^ type_info v)
    in
    let run_chunk i =
      let first, last, len = chunk_bounds i in
      let chunk = Array.map (take_rows first last) xt in
      let copied =
        Array.map
          (fun out ->
            let raw = Algodiff.unpack_arr out in
            (* Keep only the rows that belong to a shorter chunk. *)
            let kept = if len = batch_size then raw else A.get_slice [ [ 0; len - 1 ] ] raw in
            A.copy kept)
          (eval chunk)
      in
      Engine.eval_arr copied;
      copied
    in
    let nb_iterations = ((n - 1) / batch_size) + 1 in
    let per_chunk = Array.init nb_iterations run_chunk in
    (* Gather output number [k] from every chunk, in chunk order. *)
    let column k = Array.map (fun chunk_outputs -> chunk_outputs.(k)) per_chunk in
    let merged =
      Array.init
        (Array.length per_chunk.(0))
        (fun k -> A.concatenate ~axis:0 (column k))
    in
    Engine.eval_arr merged;
    Array.map Algodiff.pack_arr merged
(** [model ?optimise ?batch_size network] is the single-input,
    single-output form of [model_inputs]: the returned function wraps its
    argument in a one-element array and returns the first output. *)
let model ?optimise ?batch_size network =
  let run_all = model_inputs ?optimise ?batch_size network in
  fun xt' ->
    let outputs = run_all [| xt' |] in
    outputs.(0)
end