1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
open Bigarray
open Dataset_utils
(** Per-dataset settings: where the four gzipped IDX files are downloaded from,
    which cache sub-directory they live in, and the magic numbers the image and
    label files must start with. *)
module Config = struct
  type t = {
    name : string;  (** Human-readable dataset name, used in log messages. *)
    cache_subdir : string;
        (** Cache sub-directory for this dataset (the values below end in '/'). *)
    train_images_url : string;  (** Gzipped IDX file of training images. *)
    train_labels_url : string;  (** Gzipped IDX file of training labels. *)
    test_images_url : string;  (** Gzipped IDX file of test images. *)
    test_labels_url : string;  (** Gzipped IDX file of test labels. *)
    image_magic_number : int;  (** Required magic number of image files. *)
    label_magic_number : int;  (** Required magic number of label files. *)
  }

  (** Settings for the original MNIST dataset. *)
  let mnist =
    {
      image_magic_number = 2051;
      label_magic_number = 2049;
      train_images_url = "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz";
      train_labels_url = "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz";
      test_images_url = "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz";
      test_labels_url = "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz";
      cache_subdir = "mnist/";
      name = "MNIST";
    }

  (** Settings for Fashion-MNIST. It uses the same IDX layout and magic
      numbers as MNIST. *)
  let fashion_mnist =
    {
      image_magic_number = 2051;
      label_magic_number = 2049;
      (* NOTE(review): these URLs are plain http://, so downloads are not
         transport-encrypted; confirm whether an HTTPS mirror is available. *)
      train_images_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz";
      train_labels_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz";
      test_images_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz";
      test_labels_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz";
      cache_subdir = "fashion-mnist/";
      name = "Fashion-MNIST";
    }
end

(** Top-level aliases for {!Config.mnist} and {!Config.fashion_mnist}. *)
let mnist_config, fashion_mnist_config = (Config.mnist, Config.fashion_mnist)
(** [read_int32_be s pos] decodes the four bytes [s.[pos] .. s.[pos + 3]] as a
    big-endian integer, most significant byte first. On 64-bit platforms the
    result is always non-negative (0 .. 0xFFFFFFFF).
    @raise Invalid_argument if [s] has fewer than [pos + 4] bytes. *)
let read_int32_be s pos =
  List.fold_left
    (fun acc k -> (acc lsl 8) lor Char.code s.[pos + k])
    0 [ 0; 1; 2; 3 ]
(** [ensure_dataset config] makes sure the four decompressed IDX files of
    [config] exist in the dataset's cache directory, creating that directory
    first.

    A file that is already present is only reported. For a missing file,
    [ensure_file url gz_path] is called for its [.gz] archive (presumably
    downloading it when absent; see Dataset_utils). Then
    [ensure_decompressed_gz] expands the archive into place.

    @raise Failure if a file still cannot be obtained in decompressed form. *)
let ensure_dataset config =
  let dir = get_cache_dir config.Config.cache_subdir in
  mkdir_p dir;
  let fetch (base_filename, url) =
    let target = dir ^ base_filename in
    if Sys.file_exists target then
      Printf.printf "Found decompressed file %s.\n%!" target
    else begin
      Printf.printf "File %s not found for %s dataset.\n%!" base_filename
        config.Config.name;
      let gz_path = dir ^ base_filename ^ ".gz" in
      ensure_file url gz_path;
      if not (ensure_decompressed_gz ~gz_path ~target_path:target) then
        failwith (Printf.sprintf "Failed to obtain decompressed file %s" target)
    end
  in
  [
    ("train-images-idx3-ubyte", config.Config.train_images_url);
    ("train-labels-idx1-ubyte", config.Config.train_labels_url);
    ("t10k-images-idx3-ubyte", config.Config.test_images_url);
    ("t10k-labels-idx1-ubyte", config.Config.test_labels_url);
  ]
  |> List.iter fetch
(** [read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
    config filename] loads an IDX file into an array built by the caller.

    Steps:
    - Read the whole file into memory.
    - Check that the leading 4-byte magic number equals [expected_magic].
    - Call [read_header s] for the dimension vector and the byte offset where
      the payload starts.
    - Require 1-D or 3-D data, with the file length equal to
      [offset + product of dimensions].
    - Allocate the result with [create_array dims] and fill it with
      [populate_array arr s offset total_items], where [total_items] is the
      first dimension.

    @raise Sys_error if [filename] cannot be opened.
    @raise Failure if the file cannot be read, is too short to hold its
    header, has the wrong magic number, an unsupported number of
    dimensions, or an unexpected length. *)
let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
    config filename =
  Printf.printf "Reading %s file: %s\n%!" config.Config.name filename;
  let too_short () =
    failwith
      (Printf.sprintf "File %s is too short to contain an IDX header" filename)
  in
  let ic = open_in_bin filename in
  let s =
    match really_input_string ic (in_channel_length ic) with
    | contents ->
        close_in ic;
        contents
    | exception exn ->
        close_in_noerr ic;
        failwith
          (Printf.sprintf "Error reading file %s: %s" filename
             (Printexc.to_string exn))
  in
  (* The magic number occupies bytes 0-3; without them read_int32_be would
     fail with an opaque index-out-of-bounds error. *)
  if String.length s < 4 then too_short ();
  let magic = read_int32_be s 0 in
  if magic <> expected_magic then
    failwith
      (Printf.sprintf "Invalid magic number %d in %s (expected %d)" magic
         filename expected_magic);
  let dimensions, data_offset =
    (* read_header indexes [s] directly, so a truncated header surfaces as
       Invalid_argument; report it as a malformed file instead. *)
    match read_header s with
    | header -> header
    | exception Invalid_argument _ -> too_short ()
  in
  let total_items, data_len =
    match dimensions with
    | [| d1 |] -> (d1, d1)
    | [| d1; d2; d3 |] -> (d1, d1 * d2 * d3)
    | _ -> failwith "Unsupported dimension format"
  in
  let expected_len = data_offset + data_len in
  if String.length s <> expected_len then
    failwith
      (Printf.sprintf
         "File %s has unexpected length: %d vs %d (header offset %d, data len %d)"
         filename (String.length s) expected_len data_offset data_len);
  let arr = create_array dimensions in
  populate_array arr s data_offset total_items;
  arr
(** [read_images config filename] loads an IDX image file as a
    [num_images x num_rows x num_cols] Bigarray of unsigned bytes in C
    layout. The magic number is checked against
    [config.image_magic_number].

    @raise Failure on a malformed file (see [read_idx_file]). *)
let read_images config filename =
  (* Image header: magic (bytes 0-3), then image count, row count and column
     count as big-endian 32-bit integers; pixel bytes start at offset 16. *)
  let read_header s =
    let num_images = read_int32_be s 4 in
    let num_rows = read_int32_be s 8 in
    let num_cols = read_int32_be s 12 in
    ([| num_images; num_rows; num_cols |], 16)
  in
  let create_array dims = Genarray.create int8_unsigned c_layout dims in
  let populate_array arr s offset _ =
    (* In C layout element [| i; r; c |] sits at linear index
       i*rows*cols + r*cols + c, which matches the file's byte order. A flat
       1-D view therefore copies the pixels in one pass without allocating an
       index array per pixel. *)
    let total = Array.fold_left ( * ) 1 (Genarray.dims arr) in
    let flat = reshape_1 arr total in
    for k = 0 to total - 1 do
      Array1.set flat k (Char.code s.[offset + k])
    done
  in
  read_idx_file ~read_header ~create_array ~populate_array
    ~expected_magic:config.Config.image_magic_number config filename
(** [read_labels config filename] loads an IDX label file as a 1-D Bigarray
    of unsigned bytes (one entry per label) in C layout. The magic number is
    checked against [config.label_magic_number].

    @raise Failure on a malformed file (see [read_idx_file]). *)
let read_labels config filename =
  (* Label header: magic (bytes 0-3), then the label count as a big-endian
     32-bit integer; one byte per label follows from offset 8. *)
  let read_header s =
    let num_labels = read_int32_be s 4 in
    ([| num_labels |], 8)
  in
  let create_array dims = Genarray.create int8_unsigned c_layout dims in
  let populate_array arr s offset total_items =
    (* The array is 1-D by construction, so view it as an Array1 and avoid
       allocating an index array for every element. *)
    let labels = array1_of_genarray arr in
    for i = 0 to total_items - 1 do
      Array1.set labels i (Char.code s.[offset + i])
    done
  in
  read_idx_file ~read_header ~create_array ~populate_array
    ~expected_magic:config.Config.label_magic_number config filename
(** [load ~fashion_mnist] makes sure the dataset files are cached and
    decompressed, then returns
    [((train_images, train_labels), (test_images, test_labels))].

    Fashion-MNIST is selected when [fashion_mnist] is [true], MNIST
    otherwise. Images are [n x rows x cols] and labels are [n] unsigned-byte
    Bigarrays in C layout.

    @raise Failure if the dataset cannot be obtained, a file is malformed,
    or a split has a different number of images and labels. *)
let load ~fashion_mnist =
  let config = if fashion_mnist then Config.fashion_mnist else Config.mnist in
  ensure_dataset config;
  let dataset_dir = get_cache_dir config.Config.cache_subdir in
  let train_images_path = dataset_dir ^ "train-images-idx3-ubyte" in
  let train_labels_path = dataset_dir ^ "train-labels-idx1-ubyte" in
  let test_images_path = dataset_dir ^ "t10k-images-idx3-ubyte" in
  let test_labels_path = dataset_dir ^ "t10k-labels-idx1-ubyte" in
  (* Images and labels are paired by index, so a count mismatch (e.g. a
     corrupt or mismatched download) would silently misalign the data. *)
  let check_split split images labels =
    let n_images = (Genarray.dims images).(0) in
    let n_labels = (Genarray.dims labels).(0) in
    if n_images <> n_labels then
      failwith
        (Printf.sprintf "%s %s split has %d images but %d labels"
           config.Config.name split n_images n_labels)
  in
  Printf.printf "Loading %s datasets...\n%!" config.Config.name;
  let train_images = read_images config train_images_path in
  let train_labels = read_labels config train_labels_path in
  check_split "train" train_images train_labels;
  let test_images = read_images config test_images_path in
  let test_labels = read_labels config test_labels_path in
  check_split "test" test_images test_labels;
  Printf.printf "%s loading complete.\n%!" config.Config.name;
  ((train_images, train_labels), (test_images, test_labels))