1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
(* Parameter/activation tensors: Rune tensors with OCaml [float] elements.
   ['layout] is the float storage kind and ['dev] the device (presumably
   e.g. float32 / CPU — confirm against Rune). *)
type ('layout, 'dev) tensor = (float, 'layout, 'dev) Rune.t
(* Float dtype witness matching [tensor]'s ['layout]. *)
type 'layout dtype = (float, 'layout) Rune.dtype
(* Alias for a Rune device. *)
type 'dev device = 'dev Rune.device
(* A parameter tree: a single leaf tensor, an ordered list of sub-trees, or
   a record of named sub-trees. Layers in this file look record fields up by
   name (List.assoc_opt), so record field order is not significant to them. *)
type ('layout, 'dev) params =
| Tensor of ('layout, 'dev) tensor
| List of ('layout, 'dev) params list
| Record of (string * ('layout, 'dev) params) list
module Rngs = struct
  (** An RNG key: a plain non-negative integer that fits in 31 bits. *)
  type t = int

  (** [create ~seed ()] turns an arbitrary integer seed into a key by taking
      its absolute value and keeping only the low 31 bits. *)
  let create ~seed () = 0x7FFFFFFF land abs seed

  (** [split t] derives two new keys from [t] by hashing [2t + 1] and
      [2t + 2] with the MurmurHash3 32-bit finalizer (fmix32), keeping 31
      bits of each result so both keys are non-negative. The result is a pure
      function of [t]. *)
  let split t =
    let fmix32 v =
      let xorshift k x = Int32.(logxor x (shift_right_logical x k)) in
      Int32.of_int v
      |> xorshift 16
      |> Int32.mul 0x85ebca6bl
      |> xorshift 13
      |> Int32.mul 0xc2b2ae35l
      |> xorshift 16
      |> Int32.logand 0x7FFFFFFFl
      |> Int32.to_int
    in
    let base = 2 * t in
    (fmix32 (base + 1), fmix32 (base + 2))
end
(* A model bundles two functions stored as polymorphic record fields
   ('layout 'dev. ...), so a single model value can be used with tensors of
   any float layout and on any device.
   - [init ~rngs x]: builds the parameter tree. Layers in this file read the
     device and dtype from the sample input [x]; [sequential] also pushes [x]
     forward so later layers see the activations they will receive.
   - [apply params ~training ?rngs x]: the forward pass. [training] toggles
     stochastic behaviour (e.g. dropout) and [rngs] supplies randomness to
     layers that need it. *)
type model =
| Model : {
init :
'layout 'dev.
rngs:Rngs.t -> ('layout, 'dev) tensor -> ('layout, 'dev) params;
apply :
'layout 'dev.
('layout, 'dev) params ->
training:bool ->
?rngs:Rngs.t ->
('layout, 'dev) tensor ->
('layout, 'dev) tensor;
}
-> model
(** [init model ~rngs x] builds [model]'s parameter tree from the sample
    input [x] and the RNG key [rngs], by delegating to the model's own
    [init] function. *)
let init (Model m) ~rngs x = m.init ~rngs x
(** [apply model params ~training ?rngs x] runs [model]'s forward pass on
    [x] with parameters [params], forwarding [training] and the optional
    RNG key unchanged. *)
let apply model params ~training ?rngs x =
  match model with Model m -> m.apply params ~training ?rngs x
(** [split_at n lst] returns [(prefix, rest)] where [prefix] holds the first
    [n] elements of [lst] and [rest] the remainder. If [lst] has fewer than
    [n] elements, [prefix] is all of [lst] and [rest] is empty. A negative
    [n] never reaches zero, so it also yields [(lst, [])]. Tail-recursive. *)
let split_at n lst =
  let rec take_prefix remaining prefix rest =
    match (remaining, rest) with
    | _, [] -> (List.rev prefix, [])
    | 0, _ -> (List.rev prefix, rest)
    | _, x :: xs -> take_prefix (remaining - 1) (x :: prefix) xs
  in
  take_prefix n [] lst
(** [flatten_ptree p] returns the leaf tensors of [p] together with a
    function that rebuilds a tree of the same shape from replacement leaves.

    Leaves are listed depth-first. [List] children keep their order, and
    [Record] fields are visited in ascending key order ([String.compare]).
    Rebuilt records therefore have their fields sorted by key, which may
    differ from the field order of [p].

    The rebuild function raises [Failure "Invalid number of tensors"] when
    the list it is given does not contain exactly as many tensors as [p] has
    leaves. Both too few and too many are rejected. *)
let rec flatten_ptree : type layout dev.
    (layout, dev) params ->
    (layout, dev) tensor list
    * ((layout, dev) tensor list -> (layout, dev) params) =
 fun params ->
  (* Feed each child its consecutive slice of [tensors]. [children] pairs
     each child's leaf count with its rebuild function. Leftover tensors are
     an error rather than being silently dropped. *)
  let rebuild_children children tensors =
    let rec go tensors acc = function
      | [] -> (
          match tensors with
          | [] -> List.rev acc
          | _ :: _ -> failwith "Invalid number of tensors")
      | (count, rebuild_child) :: rest ->
          let mine, remaining = split_at count tensors in
          go remaining (rebuild_child mine :: acc) rest
    in
    go tensors [] children
  in
  match params with
  | Tensor t ->
      ( [ t ],
        function
        | [ t' ] -> Tensor t'
        | _ -> failwith "Invalid number of tensors" )
  | List l ->
      let flattened = List.map flatten_ptree l in
      (* Leaf counts are computed once here, not on every rebuild. *)
      let children =
        List.map
          (fun (leaves, rebuild) -> (List.length leaves, rebuild))
          flattened
      in
      ( List.concat_map fst flattened,
        fun ts -> List (rebuild_children children ts) )
  | Record r ->
      let sorted_r =
        List.sort (fun (k1, _) (k2, _) -> String.compare k1 k2) r
      in
      let flattened =
        List.map (fun (k, pt) -> (k, flatten_ptree pt)) sorted_r
      in
      let children =
        List.map
          (fun (k, (leaves, rebuild)) ->
            (List.length leaves, fun ts -> (k, rebuild ts)))
          flattened
      in
      ( List.concat_map (fun (_, (leaves, _)) -> leaves) flattened,
        fun ts -> Record (rebuild_children children ts) )
(** [value_and_grad f params] evaluates [f params] and differentiates it
    with respect to every leaf tensor of [params]. The gradients are returned
    as a tree with [params]' structure, with record fields in sorted key
    order as produced by [flatten_ptree]. [f] presumably must return a scalar
    tensor for [Rune.value_and_grads]; confirm against Rune. *)
let value_and_grad f params =
  let leaves, rebuild = flatten_ptree params in
  let value, leaf_grads =
    Rune.value_and_grads (fun ts -> f (rebuild ts)) leaves
  in
  (value, rebuild leaf_grads)
(** [grad f params] is the gradient tree of [value_and_grad f params],
    discarding the value. *)
let grad f params = snd (value_and_grad f params)
module Metrics = struct
  (** How an accumulated total is reported by [compute]/[get]. [Sum] gives
      the raw total. [Avg] and [Accuracy] divide the total by the number of
      updates, giving 0.0 before any update. *)
  type metric_type = Avg | Sum | Accuracy

  type metric = { name : string; metric_type : metric_type }

  (** Accumulator state. [values] and [counts] hold one entry per metric,
      keyed by name and in the same order as [metrics]. *)
  type t = {
    metrics : metric list;
    mutable values : (string * float) list;
    mutable counts : (string * int) list;
  }

  let avg name = { name; metric_type = Avg }
  let sum name = { name; metric_type = Sum }
  let accuracy name = { name; metric_type = Accuracy }

  (** [create metrics] starts every metric at total 0.0 with count 0. *)
  let create metrics =
    let names = List.map (fun m -> m.name) metrics in
    {
      metrics;
      values = List.map (fun n -> (n, 0.0)) names;
      counts = List.map (fun n -> (n, 0)) names;
    }

  (** [update t ?loss ?logits ?labels ()] folds one batch into [t].

      Only the entries literally named ["loss"] and ["accuracy"] are ever
      touched, whatever their [metric_type]; metrics with other names are
      never updated.
      - [loss] (read as a scalar) is added to the ["loss"] total.
      - When both [logits] and [labels] are given, the mean of
        [argmax logits (last axis) = labels] for the batch is added to the
        ["accuracy"] total.
      Each contribution increments the matching count by one. *)
  let update t ?loss ?logits ?labels () =
    let bump key amount =
      t.values <-
        List.map
          (fun ((n, v) as entry) ->
            if n = key then (n, v +. amount) else entry)
          t.values;
      t.counts <-
        List.map
          (fun ((n, c) as entry) -> if n = key then (n, c + 1) else entry)
          t.counts
    in
    Option.iter
      (fun loss_tensor -> bump "loss" (Rune.unsafe_get [] loss_tensor))
      loss;
    match (logits, labels) with
    | Some logits_tensor, Some labels_tensor ->
        let predictions = Rune.argmax logits_tensor ~axis:(-1) in
        let correct =
          Rune.equal predictions
            (Rune.cast (Rune.dtype predictions) labels_tensor)
        in
        let correct_float = Rune.cast (Rune.dtype logits_tensor) correct in
        bump "accuracy" (Rune.unsafe_get [] (Rune.mean correct_float))
    | _ -> ()

  (** [compute t] reports every metric as [(name, value)], in creation
      order. *)
  let compute t =
    List.map2
      (fun { metric_type; _ } (name, total) ->
        let count = List.assoc name t.counts in
        let value =
          match metric_type with
          | Sum -> total
          | Avg | Accuracy ->
              if count = 0 then 0.0 else total /. float_of_int count
        in
        (name, value))
      t.metrics t.values

  (** [get t name] reports the single metric [name].
      @raise Not_found if no metric is called [name]. *)
  let get t name =
    let total = List.assoc name t.values in
    let count = List.assoc name t.counts in
    let { metric_type; _ } = List.find (fun m -> m.name = name) t.metrics in
    match metric_type with
    | Sum -> total
    | Avg | Accuracy -> if count = 0 then 0.0 else total /. float_of_int count

  (** [reset t] zeroes every total and count, keeping the metric set. *)
  let reset t =
    t.values <- List.map (fun entry -> (fst entry, 0.0)) t.values;
    t.counts <- List.map (fun entry -> (fst entry, 0)) t.counts
end
module Dataset = struct
  (** A dataset is a lazy sequence of elements. *)
  type 'a t = 'a Seq.t

  (** [of_xy (x, y)] is the single-element dataset holding [(x, y)]. *)
  let of_xy (x, y) = Seq.return (x, y)

  (** [map f ds] lazily applies [f] to every element of [ds]. *)
  let map f ds = Seq.map f ds

  (** [batch_xy batch_size ds] splits each [(x, y)] element of [ds] along
      axis 0 into consecutive mini-batches of at most [batch_size] rows. The
      last batch may be shorter. Elements whose [x] has rank <= 1 are passed
      through unbatched.

      Batches are sliced lazily, one at a time, as the sequence is consumed.
      Each slice is made C-contiguous when it is not already.

      @raise Invalid_argument if [batch_size <= 0] (previously this recursed
      without bound). *)
  let batch_xy batch_size ds =
    if batch_size <= 0 then
      invalid_arg "Dataset.batch_xy: batch_size must be positive";
    let ensure_contiguous t =
      if Rune.is_c_contiguous t then t else Rune.contiguous t
    in
    Seq.flat_map
      (fun (x, y) ->
        if Array.length (Rune.shape x) > 1 then (
          let n_samples = (Rune.shape x).(0) in
          (* A Seq node is produced only when forced, so batches are not
             all materialised up front and there is no deep recursion. *)
          let rec batches_from start () =
            if start >= n_samples then Seq.Nil
            else
              let end_idx = min (start + batch_size) n_samples in
              let x_batch =
                ensure_contiguous (Rune.slice [ R [ start; end_idx ] ] x)
              in
              let y_batch =
                ensure_contiguous (Rune.slice [ R [ start; end_idx ] ] y)
              in
              Seq.Cons ((x_batch, y_batch), batches_from end_idx)
          in
          batches_from 0)
        else Seq.return (x, y))
      ds

  (** Placeholder: currently returns [ds] unchanged; [_batch_size] is
      ignored. *)
  let batch _batch_size ds = ds

  (** Placeholder: currently returns [ds] unchanged; [seed] is ignored and
      no shuffling happens. *)
  let shuffle ?seed ds =
    ignore seed;
    ds

  (** [iter f ds] applies [f] to every element of [ds], in order. *)
  let iter f ds = Seq.iter f ds

  (** [length ds] forces the whole sequence and counts its elements. *)
  let length ds = Seq.length ds

  (** [take n ds] returns the first [n] elements of [ds] as a list (fewer if
      [ds] is shorter; [[]] when [n <= 0]). *)
  let take n ds =
    let rec take_aux n acc seq =
      if n <= 0 then List.rev acc
      else
        match seq () with
        | Seq.Nil -> List.rev acc
        | Seq.Cons (x, rest) -> take_aux (n - 1) (x :: acc) rest
    in
    take_aux n [] ds
end
module Loss = struct
  (** [softmax_cross_entropy logits labels] is the mean over examples of
      [-sum(labels * log_softmax(logits))], reducing over the last axis.
      [labels] are per-class target weights (e.g. one-hot). The log-softmax
      subtracts the per-row max first for numerical stability. *)
  let softmax_cross_entropy logits labels =
    Rune.debug_with_context "softmax_cross_entropy" (fun () ->
        let max_logits = Rune.max logits ~axes:[| -1 |] ~keepdims:true in
        let exp_logits = Rune.exp (Rune.sub logits max_logits) in
        let sum_exp = Rune.sum exp_logits ~axes:[| -1 |] ~keepdims:true in
        let log_softmax =
          Rune.sub logits (Rune.add max_logits (Rune.log sum_exp))
        in
        let loss =
          Rune.neg (Rune.sum (Rune.mul labels log_softmax) ~axes:[| -1 |])
        in
        Rune.mean loss)

  (** [softmax_cross_entropy_with_indices logits indices] is
      [softmax_cross_entropy] with integer class targets. [indices] is cast
      to int32 and one-hot encoded over the class axis.

      The number of classes is read from the last axis of [logits], the same
      axis [softmax_cross_entropy] reduces over. For rank-2 logits this is
      axis 1, as before; for higher-rank logits (e.g. [batch; time; classes])
      the old axis-1 choice was wrong.
      @raise Invalid_argument if [logits] is rank 0. *)
  let softmax_cross_entropy_with_indices logits indices =
    let indices_int = Rune.cast Rune.int32 indices in
    let shape = Rune.shape logits in
    let num_classes = shape.(Array.length shape - 1) in
    let one_hot = Rune.one_hot ~num_classes indices_int in
    let one_hot_float = Rune.cast (Rune.dtype logits) one_hot in
    softmax_cross_entropy logits one_hot_float

  (** [binary_cross_entropy logits labels] is the mean of
      [-(y * log sigmoid z + (1 - y) * log sigmoid (-z))]. Despite the name,
      it takes pre-sigmoid [logits], not probabilities. Using [log_sigmoid]
      on both terms avoids [log 0]. *)
  let binary_cross_entropy logits labels =
    Rune.debug_with_context "binary_cross_entropy" (fun () ->
        let dtype = Rune.dtype logits in
        let dev = Rune.device logits in
        let one = Rune.scalar dev dtype 1.0 in
        let log_sig = Rune.log_sigmoid logits in
        let log_sig_neg = Rune.log_sigmoid (Rune.neg logits) in
        let term1 = Rune.mul labels log_sig in
        let term2 = Rune.mul (Rune.sub one labels) log_sig_neg in
        let loss_per_example = Rune.neg (Rune.add term1 term2) in
        Rune.mean loss_per_example)

  (** [sigmoid_binary_cross_entropy logits labels] is the same quantity as
      [binary_cross_entropy] but returned element-wise (no mean). *)
  let sigmoid_binary_cross_entropy logits labels =
    Rune.debug_with_context "sigmoid_binary_cross_entropy" (fun () ->
        let dtype = Rune.dtype logits in
        let dev = Rune.device logits in
        let one = Rune.scalar dev dtype 1.0 in
        let log_sig = Rune.log_sigmoid logits in
        let log_sig_neg = Rune.log_sigmoid (Rune.neg logits) in
        let term1 = Rune.mul labels log_sig in
        let term2 = Rune.mul (Rune.sub one labels) log_sig_neg in
        Rune.neg (Rune.add term1 term2))

  (** [mse predictions targets] is the mean of the squared differences. *)
  let mse predictions targets =
    Rune.debug_with_context "mse" (fun () ->
        let diff = Rune.sub predictions targets in
        let squared = Rune.mul diff diff in
        Rune.mean squared)

  (** [mae predictions targets] is the mean of the absolute differences. *)
  let mae predictions targets =
    Rune.debug_with_context "mae" (fun () ->
        let diff = Rune.sub predictions targets in
        let abs_diff = Rune.abs diff in
        Rune.mean abs_diff)
end
module Initializer = struct
  (** Declarative description of a parameter initialization scheme. It is
      turned into a tensor by [apply], which delegates almost every case to
      the [Initializers] module. *)
  type t =
    | Constant of float
    | Zeros
    | Ones
    | Uniform of { scale : float }
    | Normal of { mean : float; std : float }
    | TruncatedNormal of { stddev : float; lower : float; upper : float }
    | VarianceScaling of {
        scale : float;
        mode : [ `Fan_in | `Fan_out | `Fan_avg ];
        distribution : [ `Normal | `Truncated_normal | `Uniform ];
        in_axis : int;
        out_axis : int;
      }
    | GlorotUniform of { in_axis : int; out_axis : int }
    | GlorotNormal of { in_axis : int; out_axis : int }
    | HeUniform of { in_axis : int; out_axis : int }
    | HeNormal of { in_axis : int; out_axis : int }
    | LecunUniform of { in_axis : int; out_axis : int }
    | LecunNormal of { in_axis : int; out_axis : int }
    | Orthogonal of { scale : float; column_axis : int }
    | DeltaOrthogonal of { scale : float; column_axis : int }
    | UniformRange of { low : float; high : float }
    | NormalRange of { mean : float; stddev : float }

  (* Smart constructors. The fan-based schemes default to in_axis = -2 and
     out_axis = -1, i.e. the last two axes of the shape. *)
  let constant v = Constant v
  let zeros () = Zeros
  let ones () = Ones
  let uniform ?(scale = 0.01) () = Uniform { scale }
  let normal ~mean ~std = Normal { mean; std }

  let truncated_normal ?(stddev = 0.01) ?(lower = -2.0) ?(upper = 2.0) () =
    TruncatedNormal { stddev; lower; upper }

  let variance_scaling ~scale ~mode ~distribution ~in_axis ~out_axis () =
    VarianceScaling { scale; mode; distribution; in_axis; out_axis }

  let glorot_uniform ?(in_axis = -2) ?(out_axis = -1) () =
    GlorotUniform { in_axis; out_axis }

  let glorot_normal ?(in_axis = -2) ?(out_axis = -1) () =
    GlorotNormal { in_axis; out_axis }

  let xavier_uniform = glorot_uniform
  let xavier_normal = glorot_normal

  let he_uniform ?(in_axis = -2) ?(out_axis = -1) () =
    HeUniform { in_axis; out_axis }

  let he_normal ?(in_axis = -2) ?(out_axis = -1) () =
    HeNormal { in_axis; out_axis }

  let kaiming_uniform = he_uniform
  let kaiming_normal = he_normal

  let lecun_uniform ?(in_axis = -2) ?(out_axis = -1) () =
    LecunUniform { in_axis; out_axis }

  let lecun_normal ?(in_axis = -2) ?(out_axis = -1) () =
    LecunNormal { in_axis; out_axis }

  let orthogonal ?(scale = 1.0) ?(column_axis = -1) () =
    Orthogonal { scale; column_axis }

  let delta_orthogonal ?(scale = 1.0) ?(column_axis = -1) () =
    DeltaOrthogonal { scale; column_axis }

  let uniform_range ~low ~high () = UniformRange { low; high }
  let normal_range ~mean ~stddev () = NormalRange { mean; stddev }

  (** [apply init rng shape dev dtype] materialises [init] as a tensor of
      [shape] on [dev] with [dtype], using [rng] as the random key.

      Glorot schemes get special handling for low ranks. A rank-0 shape
      gives zeros. A rank-1 shape of length [n] uses [Initializers.uniform]
      with scale [sqrt (3 / n)] (uniform) or [Initializers.normal] with
      stddev [sqrt (1 / n)] (normal). These appear to treat fan_in = fan_out
      = n; whether [Initializers.uniform ~scale] is symmetric around 0 is
      not visible here — TODO confirm. *)
  let apply (type layout dev) init rng shape (dev : dev Rune.device)
      (dtype : (float, layout) Rune.dtype) =
    let rank = Array.length shape in
    match init with
    | Constant value -> Rune.full dev dtype shape value
    | Zeros -> Initializers.zeros () rng shape dev dtype
    | Ones -> Initializers.ones () rng shape dev dtype
    | Uniform { scale } -> Initializers.uniform ~scale () rng shape dev dtype
    | Normal { mean; std } ->
        Initializers.normal_range ~mean ~stddev:std () rng shape dev dtype
    | TruncatedNormal { stddev; lower; upper } ->
        Initializers.truncated_normal_init ~stddev ~lower ~upper () rng shape
          dev dtype
    | VarianceScaling { scale; mode; distribution; in_axis; out_axis } ->
        Initializers.variance_scaling ~scale ~mode ~distribution ~in_axis
          ~out_axis () rng shape dev dtype
    | GlorotUniform { in_axis; out_axis } -> (
        match rank with
        | 0 -> Rune.zeros dev dtype shape
        | 1 ->
            let scale = sqrt (3.0 /. float_of_int shape.(0)) in
            Initializers.uniform ~scale () rng shape dev dtype
        | _ ->
            Initializers.glorot_uniform ~in_axis ~out_axis () rng shape dev
              dtype)
    | GlorotNormal { in_axis; out_axis } -> (
        match rank with
        | 0 -> Rune.zeros dev dtype shape
        | 1 ->
            let stddev = sqrt (1.0 /. float_of_int shape.(0)) in
            Initializers.normal ~stddev () rng shape dev dtype
        | _ ->
            Initializers.glorot_normal ~in_axis ~out_axis () rng shape dev
              dtype)
    | HeUniform { in_axis; out_axis } ->
        Initializers.he_uniform ~in_axis ~out_axis () rng shape dev dtype
    | HeNormal { in_axis; out_axis } ->
        Initializers.he_normal ~in_axis ~out_axis () rng shape dev dtype
    | LecunUniform { in_axis; out_axis } ->
        Initializers.lecun_uniform ~in_axis ~out_axis () rng shape dev dtype
    | LecunNormal { in_axis; out_axis } ->
        Initializers.lecun_normal ~in_axis ~out_axis () rng shape dev dtype
    | Orthogonal { scale; column_axis } ->
        Initializers.orthogonal ~scale ~column_axis () rng shape dev dtype
    | DeltaOrthogonal { scale; column_axis } ->
        Initializers.delta_orthogonal ~scale ~column_axis () rng shape dev
          dtype
    | UniformRange { low; high } ->
        Initializers.uniform_range ~low ~high () rng shape dev dtype
    | NormalRange { mean; stddev } ->
        Initializers.normal_range ~mean ~stddev () rng shape dev dtype
end
module Layer = struct
  (** [conv2d ~in_channels ~out_channels ?kernel_size ()] is a 2-D
      convolution (stride 1, [`Same] padding) followed by a per-channel bias.

      Parameters are [Record [("weight", w); ("bias", b)]].
      - [w] has shape [|out_channels; in_channels; kh; kw|]. It is sampled
        with [Rune.rand] and mapped affinely to [[-limit, limit)] with
        [limit = sqrt (6 / (fan_in + fan_out))], assuming [Rune.rand] samples
        [[0, 1)].
      - [b] is zeros of shape [|out_channels|]. In [apply] it is reshaped to
        [|1; out_channels; 1; 1|], so the conv output is expected to be
        channels-first (N, C, H, W). *)
  let conv2d ~in_channels ~out_channels ?(kernel_size = (3, 3)) () =
    let kh, kw = kernel_size in
    Model
      {
        init =
          (fun (type l d) ~rngs (x : (l, d) tensor) ->
            Rune.debug_with_context
              (Printf.sprintf "conv2d_%dx%d_%dx%d_init" in_channels
                 out_channels kh kw) (fun () ->
                let rng1, _rng2 = Rngs.split rngs in
                let dev = Rune.device x in
                let dtype = Rune.dtype x in
                let fan_in = in_channels * kh * kw in
                let fan_out = out_channels * kh * kw in
                let limit = sqrt (6.0 /. float_of_int (fan_in + fan_out)) in
                let weight_shape = [| out_channels; in_channels; kh; kw |] in
                let w = Rune.rand dev dtype ~seed:rng1 weight_shape in
                let w =
                  Rune.sub
                    (Rune.mul w (Rune.scalar dev dtype (2.0 *. limit)))
                    (Rune.scalar dev dtype limit)
                in
                let b = Rune.zeros dev dtype [| out_channels |] in
                Record [ ("weight", Tensor w); ("bias", Tensor b) ]));
        apply =
          (fun (type l d)
            (params : (l, d) params)
            ~training:_
            ?rngs:_
            (x : (l, d) tensor)
          ->
            match params with
            | Record fields ->
                let w =
                  match List.assoc_opt "weight" fields with
                  | Some (Tensor t) -> t
                  | _ -> failwith "conv2d: missing or invalid weight parameter"
                in
                let b =
                  match List.assoc_opt "bias" fields with
                  | Some (Tensor t) -> t
                  | _ -> failwith "conv2d: missing or invalid bias parameter"
                in
                Rune.debug_with_context
                  (Printf.sprintf "conv2d_%dx%d_%dx%d" in_channels
                     out_channels kh kw) (fun () ->
                    let conv =
                      Rune.convolve2d x w ~stride:(1, 1) ~padding_mode:`Same
                    in
                    let b_reshaped =
                      Rune.reshape [| 1; out_channels; 1; 1 |] b
                    in
                    Rune.add conv b_reshaped)
            | _ -> failwith "conv2d: invalid params structure");
      }

  (** [linear ~in_features ~out_features ?weight_init ?bias_init ()] computes
      [x @ weight + bias], with [weight] of shape
      [|in_features; out_features|] and [bias] of shape [|out_features|].
      The last axis of [x] must therefore be [in_features]. Defaults are
      Glorot-uniform weights (in_axis 0, out_axis 1) and a zero bias. *)
  let linear ~in_features ~out_features ?weight_init ?bias_init () =
    let weight_init =
      match weight_init with
      | Some init -> init
      | None -> Initializer.glorot_uniform ~in_axis:0 ~out_axis:1 ()
    in
    let bias_init =
      match bias_init with
      | Some init -> init
      | None -> Initializer.constant 0.0
    in
    Model
      {
        init =
          (fun ~rngs x ->
            Rune.debug_with_context
              (Printf.sprintf "linear_%dx%d_init" in_features out_features)
              (fun () ->
                let rng1, rng2 = Rngs.split rngs in
                let dev = Rune.device x in
                let dtype = Rune.dtype x in
                let w =
                  Initializer.apply weight_init rng1
                    [| in_features; out_features |]
                    dev dtype
                in
                let b =
                  Initializer.apply bias_init rng2 [| out_features |] dev dtype
                in
                Record [ ("weight", Tensor w); ("bias", Tensor b) ]));
        apply =
          (fun (type l d)
            (params : (l, d) params)
            ~training:_
            ?rngs:_
            (x : (l, d) tensor)
          ->
            Rune.debug_with_context
              (Printf.sprintf "linear_%dx%d" in_features out_features)
              (fun () ->
                match params with
                | Record fields ->
                    let w =
                      match List.assoc_opt "weight" fields with
                      | Some (Tensor t) -> t
                      | _ ->
                          failwith "linear: missing or invalid weight parameter"
                    in
                    let b =
                      match List.assoc_opt "bias" fields with
                      | Some (Tensor t) -> t
                      | _ ->
                          failwith "linear: missing or invalid bias parameter"
                    in
                    let z = Rune.matmul x w in
                    Rune.add z b
                | _ -> failwith "linear: invalid params structure"));
      }

  (** [dropout ~rate ()] zeroes each element with probability [rate] during
      training and rescales the survivors by [1 / (1 - rate)]. It is the
      identity when [training] is false or [rate <= 0].

      During training, [rngs] is required; without it the layer fails with
      "dropout requires RNG during training". With [rate >= 1] every element
      is dropped and the result is all zeros. The rescaling factor would be
      [1 / 0] there, and [0 * inf] produced NaN. *)
  let dropout ~rate () =
    Model
      {
        init = (fun ~rngs:_ _x -> List []);
        apply =
          (fun _params ~training ?rngs x ->
            if training && rate > 0.0 then (
              match rngs with
              | None -> failwith "dropout requires RNG during training"
              | Some _ when rate >= 1.0 -> Rune.zeros_like x
              | Some rng ->
                  let seed = fst (Rngs.split rng) in
                  let dev = Rune.device x in
                  let dtype = Rune.dtype x in
                  let shape = Rune.shape x in
                  let mask = Rune.rand dev dtype ~seed shape in
                  let keep_prob = 1.0 -. rate in
                  let threshold = Rune.scalar dev dtype keep_prob in
                  (* Keep where mask < keep_prob, presumably with
                     probability keep_prob for a [0, 1) uniform mask. *)
                  let binary_mask = Rune.less mask threshold in
                  let binary_mask_float = Rune.cast dtype binary_mask in
                  let scale = Rune.scalar dev dtype (1.0 /. keep_prob) in
                  Rune.mul x (Rune.mul binary_mask_float scale))
            else x);
      }

  (** [batch_norm ~num_features ()] normalises [x] with the mean and variance
      of the current batch. It reduces over axis 0 for rank-2 inputs and over
      axes 0, 2, 3 for rank-4 inputs, with eps = 1e-5. It then applies a
      learned per-feature [scale] (init 1) and [bias] (init 0).

      NOTE: batch statistics are used even when [training] is false. No
      running averages are kept. *)
  let batch_norm ~num_features () =
    Model
      {
        init =
          (fun ~rngs:_ x ->
            Rune.debug_with_context
              (Printf.sprintf "batch_norm_%d_init" num_features) (fun () ->
                let dev = Rune.device x in
                let dtype = Rune.dtype x in
                let scale = Rune.ones dev dtype [| num_features |] in
                let bias = Rune.zeros dev dtype [| num_features |] in
                Record [ ("scale", Tensor scale); ("bias", Tensor bias) ]));
        apply =
          (fun params ~training:_ ?rngs:_ x ->
            match params with
            | Record fields ->
                let scale =
                  match List.assoc_opt "scale" fields with
                  | Some (Tensor t) -> t
                  | _ ->
                      failwith "batch_norm: missing or invalid scale parameter"
                in
                let bias =
                  match List.assoc_opt "bias" fields with
                  | Some (Tensor t) -> t
                  | _ ->
                      failwith "batch_norm: missing or invalid bias parameter"
                in
                let axes =
                  match Array.length (Rune.shape x) with
                  | 2 -> [| 0 |]
                  | 4 -> [| 0; 2; 3 |]
                  | _ -> [| 0 |]
                in
                Rune.debug_with_context
                  (Printf.sprintf "batch_norm_%d_apply" num_features)
                  (fun () ->
                    let mean = Rune.mean x ~axes ~keepdims:true in
                    let variance = Rune.var x ~axes ~keepdims:true in
                    let eps = 1e-5 in
                    let dtype = Rune.dtype x in
                    let dev = Rune.device x in
                    let epsilon = Rune.scalar dev dtype eps in
                    let x_normalized =
                      Rune.div (Rune.sub x mean)
                        (Rune.sqrt (Rune.add variance epsilon))
                    in
                    let scale_shape =
                      match Array.length (Rune.shape x) with
                      | 2 -> [| 1; num_features |]
                      | 4 -> [| 1; num_features; 1; 1 |]
                      | _ -> [| 1; num_features |]
                    in
                    let scale_reshaped = Rune.reshape scale_shape scale in
                    let bias_reshaped = Rune.reshape scale_shape bias in
                    Rune.add
                      (Rune.mul x_normalized scale_reshaped)
                      bias_reshaped)
            | _ -> failwith "batch_norm: invalid params structure");
      }

  (** [max_pool2d ~kernel_size ?stride ()] is parameter-free 2-D max pooling.
      [stride] defaults to [kernel_size]. *)
  let max_pool2d ~kernel_size ?stride () =
    let stride = match stride with Some s -> s | None -> kernel_size in
    Model
      {
        init = (fun ~rngs:_ _x -> List []);
        apply =
          (fun _params ~training:_ ?rngs:_ x ->
            let pooled, _ = Rune.max_pool2d x ~kernel_size ~stride in
            pooled);
      }

  (** [avg_pool2d ~kernel_size ?stride ()] is parameter-free 2-D average
      pooling. [stride] defaults to [kernel_size]. *)
  let avg_pool2d ~kernel_size ?stride () =
    let stride = match stride with Some s -> s | None -> kernel_size in
    Model
      {
        init = (fun ~rngs:_ _x -> List []);
        apply =
          (fun _params ~training:_ ?rngs:_ x ->
            Rune.avg_pool2d x ~kernel_size ~stride);
      }

  (** [flatten ()] reshapes [x] to [|shape.(0); product of the other axes|],
      copying to a contiguous layout first when needed. *)
  let flatten () =
    Model
      {
        init = (fun ~rngs:_ _x -> List []);
        apply =
          (fun _params ~training:_ ?rngs:_ x ->
            let shape = Rune.shape x in
            let batch_size = shape.(0) in
            let flat_size =
              Array.fold_left ( * ) 1
                (Array.sub shape 1 (Array.length shape - 1))
            in
            let x = if Rune.is_c_contiguous x then x else Rune.contiguous x in
            Rune.reshape [| batch_size; flat_size |] x);
      }

  (** Element-wise ReLU; no parameters. *)
  let relu () =
    Model
      {
        init = (fun ~rngs:_ _x -> List []);
        apply = (fun _params ~training:_ ?rngs:_ x -> Rune.relu x);
      }

  (** Element-wise logistic sigmoid; no parameters. *)
  let sigmoid () =
    Model
      {
        init = (fun ~rngs:_ _x -> List []);
        apply = (fun _params ~training:_ ?rngs:_ x -> Rune.sigmoid x);
      }

  (** Element-wise hyperbolic tangent; no parameters. *)
  let tanh () =
    Model
      {
        init = (fun ~rngs:_ _x -> List []);
        apply = (fun _params ~training:_ ?rngs:_ x -> Rune.tanh x);
      }

  (** [sequential models] chains [models] left to right. Its parameters are
      [List [p1; ...; pn]], one entry per layer.

      [init] splits a fresh key off [rngs] for each layer. It also runs the
      sample input through every layer ([~training:false]) so each layer is
      initialised against the activations it will actually receive.

      [apply] forwards [training] to every layer. When [rngs] is given, each
      layer receives its own key split off it, so stochastic layers such as
      [dropout] work inside a sequential model. Previously [rngs] was
      dropped, and dropout always failed during training. *)
  let sequential models =
    Model
      {
        init =
          (fun ~rngs x ->
            let rec init_layers models x acc rngs_current =
              match models with
              | [] -> List (List.rev acc)
              | Model m :: rest ->
                  let rngs_layer, rngs_rest = Rngs.split rngs_current in
                  let params = m.init ~rngs:rngs_layer x in
                  let x' = m.apply params ~training:false x in
                  init_layers rest x' (params :: acc) rngs_rest
            in
            init_layers models x [] rngs);
        apply =
          (fun params ~training ?rngs x ->
            match params with
            | List param_list ->
                let rec apply_layers models params x rngs =
                  match (models, params) with
                  | [], [] -> x
                  | Model m :: ms, p :: ps ->
                      let layer_rngs, rest_rngs =
                        match rngs with
                        | None -> (None, None)
                        | Some r ->
                            let here, rest = Rngs.split r in
                            (Some here, Some rest)
                      in
                      let x' = m.apply p ~training ?rngs:layer_rngs x in
                      apply_layers ms ps x' rest_rngs
                  | _ -> failwith "sequential: mismatched models and params"
                in
                apply_layers models param_list x rngs
            | _ -> failwith "sequential: invalid params structure");
      }
end
module Optimizer = struct
  (** Optimizer algorithm and hyper-parameters. *)
  type transform =
    | SGD of { lr : float; momentum : float option }
    | Adam of { lr : float; beta1 : float; beta2 : float; eps : float }
    | AdamW of {
        lr : float;
        beta1 : float;
        beta2 : float;
        eps : float;
        weight_decay : float;
      }

  (** Per-leaf optimizer buffers, aligned with the leaf order of
      [flatten_ptree]. SGD with momentum keeps its velocities in [m_tensors]
      and leaves [v_tensors] empty. Adam/AdamW keep first- and second-moment
      estimates in [m_tensors] and [v_tensors]. *)
  type ('layout, 'dev) state = {
    m_tensors : ('layout, 'dev) tensor list;
    v_tensors : ('layout, 'dev) tensor list;
  }

  (** A stateful optimizer. [state] is created lazily on the first [update]
      that needs it, and [step] counts calls to [update]. *)
  type ('layout, 'dev) t = {
    transform : transform;
    mutable state : ('layout, 'dev) state option;
    mutable step : int;
  }

  let sgd ~lr ?momentum () = SGD { lr; momentum }

  let adam ~lr ?(beta1 = 0.9) ?(beta2 = 0.999) ?(eps = 1e-8) () =
    Adam { lr; beta1; beta2; eps }

  let adamw ~lr ?(beta1 = 0.9) ?(beta2 = 0.999) ?(eps = 1e-8)
      ?(weight_decay = 0.01) () =
    AdamW { lr; beta1; beta2; eps; weight_decay }

  let create transform = { transform; state = None; step = 0 }

  (** [apply_updates_inplace params updates] subtracts every leaf of
      [updates] from the matching leaf of [params] in place ([Rune.isub]).
      Record fields are matched after sorting both sides by key.
      @raise Failure if the two trees differ in shape or record keys.
      @raise Invalid_argument if two [List]/[Record] nodes differ in length. *)
  let rec apply_updates_inplace : type a b.
      (a, b) params -> (a, b) params -> unit =
   fun params updates ->
    match (params, updates) with
    | Tensor t, Tensor u -> ignore (Rune.isub t u)
    | List ps, List us -> List.iter2 apply_updates_inplace ps us
    | Record ps, Record us ->
        let by_key l =
          List.sort (fun (k1, _) (k2, _) -> String.compare k1 k2) l
        in
        List.iter2
          (fun (k1, p) (k2, u) ->
            (* An explicit check rather than [assert], so a key mismatch is
               still caught when assertions are compiled out. *)
            if not (String.equal k1 k2) then
              failwith
                (Printf.sprintf "Mismatched parameter structure: %S vs %S" k1
                   k2);
            apply_updates_inplace p u)
          (by_key ps) (by_key us)
    | _ -> failwith "Mismatched parameter structure"

  (** [update opt params grads] increments [opt.step], then updates [params]
      in place using [grads], which must share [params]' structure.
      - SGD: [p <- p - lr * g], or with momentum [v <- momentum * v + g;
        p <- p - lr * v].
      - Adam: bias-corrected moments, [p <- p - lr * m_hat / (sqrt v_hat + eps)].
      - AdamW: the Adam step plus decoupled decay [lr * weight_decay * p]. *)
  let update opt params grads =
    opt.step <- opt.step + 1;
    let params_tensors, rebuild_params = flatten_ptree params in
    let grads_tensors, _ = flatten_ptree grads in
    (* Shared Adam/AdamW step. [weight_decay = None] is plain Adam; [Some wd]
       adds the decoupled term [lr * wd * p] to each leaf's update. *)
    let adam_step ~lr ~beta1 ~beta2 ~eps ~weight_decay =
      let state =
        match opt.state with
        | None ->
            let m_tensors = List.map Rune.zeros_like params_tensors in
            let v_tensors = List.map Rune.zeros_like params_tensors in
            let s = { m_tensors; v_tensors } in
            opt.state <- Some s;
            s
        | Some s -> s
      in
      let t = opt.step in
      let bc1 = 1. -. (beta1 ** float_of_int t) in
      let bc2 = 1. -. (beta2 ** float_of_int t) in
      let new_m_tensors, new_v_tensors, updates =
        List.fold_left2
          (fun (m_acc, v_acc, u_acc) (p, g) (m_old, v_old) ->
            let dev = Rune.device p in
            let dt = Rune.dtype p in
            let m_new =
              Rune.(
                add
                  (mul (scalar dev dt beta1) m_old)
                  (mul (scalar dev dt (1. -. beta1)) g))
            in
            let v_new =
              Rune.(
                add
                  (mul (scalar dev dt beta2) v_old)
                  (mul (scalar dev dt (1. -. beta2)) (mul g g)))
            in
            let m_hat = Rune.div m_new (Rune.scalar dev dt bc1) in
            let v_hat = Rune.div v_new (Rune.scalar dev dt bc2) in
            let adam_update =
              Rune.(
                mul (scalar dev dt lr)
                  (div m_hat (add (sqrt v_hat) (scalar dev dt eps))))
            in
            let update =
              match weight_decay with
              | None -> adam_update
              | Some wd ->
                  Rune.add adam_update
                    (Rune.mul (Rune.scalar dev dt (lr *. wd)) p)
            in
            (m_new :: m_acc, v_new :: v_acc, update :: u_acc))
          ([], [], [])
          (List.combine params_tensors grads_tensors)
          (List.combine state.m_tensors state.v_tensors)
      in
      opt.state <-
        Some
          {
            m_tensors = List.rev new_m_tensors;
            v_tensors = List.rev new_v_tensors;
          };
      apply_updates_inplace params (rebuild_params (List.rev updates))
    in
    match opt.transform with
    | SGD { lr; momentum } ->
        let updates =
          match momentum with
          | None ->
              List.map
                (fun g ->
                  let dev = Rune.device g in
                  let dt = Rune.dtype g in
                  Rune.mul g (Rune.scalar dev dt lr))
                grads_tensors
          | Some momentum_val ->
              let state =
                match opt.state with
                | None ->
                    let v_tensors = List.map Rune.zeros_like params_tensors in
                    let s = { m_tensors = v_tensors; v_tensors = [] } in
                    opt.state <- Some s;
                    s
                | Some s -> s
              in
              let new_velocities, updates =
                List.fold_left2
                  (fun (v_acc, u_acc) g v_old ->
                    let dev = Rune.device g in
                    let dt = Rune.dtype g in
                    let v_new =
                      Rune.(add (mul (scalar dev dt momentum_val) v_old) g)
                    in
                    let update = Rune.mul (Rune.scalar dev dt lr) v_new in
                    (v_new :: v_acc, update :: u_acc))
                  ([], []) grads_tensors state.m_tensors
              in
              opt.state <-
                Some { m_tensors = List.rev new_velocities; v_tensors = [] };
              List.rev updates
        in
        apply_updates_inplace params (rebuild_params updates)
    | Adam { lr; beta1; beta2; eps } ->
        adam_step ~lr ~beta1 ~beta2 ~eps ~weight_decay:None
    | AdamW { lr; beta1; beta2; eps; weight_decay } ->
        adam_step ~lr ~beta1 ~beta2 ~eps ~weight_decay:(Some weight_decay)
end