(* Source file kaun_datasets.ml *)
(** [mnist ?fashion ?normalize ?data_format ()] loads MNIST through
    [Mnist.load], or Fashion-MNIST when [fashion] is [true] (default
    [false]). It returns [((train_x, train_y), (test_x, test_y))].

    - Images: each [Bigarray.Array3] with dimensions [n], [h], [w] gets a
      trailing channel axis of size 1 and is cast to [float32]. When
      [normalize] is [true] (the default), values are divided by [255.0].
      The resulting tensor has shape [[|n; 1; h; w|]] for [`NCHW] (the
      default) or [[|n; h; w; 1|]] for [`NHWC].
    - Labels: the [Bigarray.Array1] is wrapped as a tensor and cast to
      [int32]. *)
let mnist ?(fashion = false) ?(normalize = true) ?(data_format = `NCHW) () =
  let (train_images, train_labels), (test_images, test_labels) =
    Mnist.load ~fashion_mnist:fashion
  in
  (* Optional rescaling of raw pixel values by 255. *)
  let scale t = if normalize then Nx.div_s t 255.0 else t in
  (* The tensor is built channels-last; move the size-1 channel axis to
     position 1 only when channels-first output is requested. *)
  let to_layout t =
    match data_format with
    | `NHWC -> t
    | `NCHW -> Nx.transpose t ~axes:[ 0; 3; 1; 2 ]
  in
  (* Convert one split (images + labels) into a pair of tensors. *)
  let split_to_tensors imgs lbls =
    let shape = Bigarray.Array3.[| dim1 imgs; dim2 imgs; dim3 imgs; 1 |] in
    let images =
      Nx.of_bigarray (Bigarray.genarray_of_array3 imgs)
      |> Nx.reshape shape
      |> Nx.cast Nx.float32
      |> scale
      |> to_layout
    in
    let labels =
      Nx.cast Nx.int32 (Nx.of_bigarray (Bigarray.genarray_of_array1 lbls))
    in
    (images, labels)
  in
  let train_split = split_to_tensors train_images train_labels in
  let test_split = split_to_tensors test_images test_labels in
  (train_split, test_split)
(** [cifar10 ?normalize ?data_format ()] loads CIFAR-10 through
    [Cifar10.load ()] and returns [((train_x, train_y), (test_x, test_y))].

    - Images: the loaded array is wrapped as a tensor and cast to
      [float32]. When [normalize] is [true] (the default), values are
      divided by [255.0]. The loaded layout is kept unchanged for [`NCHW]
      (the default) and permuted with axes [[0; 2; 3; 1]] for [`NHWC].
      In other words, the loaded array is treated as channels-first.
    - Labels: the [Bigarray.Array1] is wrapped as a tensor and cast to
      [int32]. *)
let cifar10 ?(normalize = true) ?(data_format = `NCHW) () =
  let (train_images, train_labels), (test_images, test_labels) =
    Cifar10.load ()
  in
  (* Optional rescaling of raw pixel values by 255. *)
  let scale t = if normalize then Nx.div_s t 255.0 else t in
  (* Only the channels-last request needs a permutation of the loaded data. *)
  let to_layout t =
    match data_format with
    | `NHWC -> Nx.transpose t ~axes:[ 0; 2; 3; 1 ]
    | `NCHW -> t
  in
  (* Convert one split (images + labels) into a pair of tensors. *)
  let split_to_tensors imgs lbls =
    let images = Nx.of_bigarray imgs |> Nx.cast Nx.float32 |> scale |> to_layout in
    let labels =
      Nx.cast Nx.int32 (Nx.of_bigarray (Bigarray.genarray_of_array1 lbls))
    in
    (images, labels)
  in
  let train_split = split_to_tensors train_images train_labels in
  let test_split = split_to_tensors test_images test_labels in
  (train_split, test_split)