1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
open Base
open Torch
(** Flat one-dimensional bigarray of unsigned 8-bit values in C layout. This is the
    type of the input and output buffers exchanged with [resize_]. *)
type buffer = (int, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t
(** Binding to a C stub: [ml_stbir_resize_bytecode] is the bytecode entry point and
    [ml_stbir_resize] the native one.

    It receives an input buffer with its width and height, an output buffer with its
    width and height, and one channel count, and returns an integer status. The callers
    in this file allocate [out_data] with [out_w * out_h * nchannels] elements and treat
    a returned [0] as a failure.

    NOTE(review): the stub name suggests this wraps the stb image-resize routine and
    that both buffers hold interleaved, row-major pixels; confirm against the C stub. *)
external resize_
: in_data:buffer
-> in_w:int
-> in_h:int
-> out_data:buffer
-> out_w:int
-> out_h:int
-> nchannels:int
-> int
= "ml_stbir_resize_bytecode" "ml_stbir_resize"
(** [tensor_of_data ~data ~width ~height] wraps the flat buffer [data] in a tensor,
    views it with size [1; height; width; 3] and then permutes the axes with
    [0; 3; 1; 2], so the result has shape [1; 3; height; width]. The channel count is
    fixed to 3. *)
let tensor_of_data ~data ~width ~height =
  let flat = Tensor.of_bigarray (Bigarray.genarray_of_array1 data) in
  let nhwc = Tensor.view flat ~size:[ 1; height; width; 3 ] in
  Tensor.permute nhwc ~dims:[ 0; 3; 1; 2 ]
(** [maybe_crop tensor ~dim ~length ~target_length] keeps a centered window of
    [target_length] elements along dimension [dim], where [length] is the current size
    of that dimension as given by the caller. The tensor is returned unchanged when
    both lengths are equal.

    @raise Failure if [target_length] is greater than [length]. *)
let maybe_crop tensor ~dim ~length ~target_length =
  if length < target_length
  then
    Printf.failwithf "target_length greater than length %d > %d" target_length length ()
  else if length = target_length
  then tensor
  else (
    (* Split the surplus evenly; integer division puts any odd element at the end. *)
    let start = (length - target_length) / 2 in
    Tensor.narrow tensor ~dim ~start ~length:target_length)
(** [load_image ?resize image_file] loads [image_file] with [Stb_image.load] and turns
    it into a tensor of shape [1; 3; height; width] via [tensor_of_data].

    When [resize] is [Some (target_width, target_height)] the image is first scaled by
    the larger of the two width/height ratios, so that it covers the target in both
    directions, and is then center-cropped to exactly the target size.

    Failures are reported through the result rather than raised: a loading error from
    [Stb_image], an image whose channel count is not 3, or a failing status from the
    resize stub all yield [Error]. The resize failure used to raise [Failure], which
    defeated callers that handle the [Error] case (such as [load_images]). *)
let load_image ?resize image_file =
  let to_tensor (image : _ Stb_image.t) =
    if image.channels <> 3
    then Error (`Msg (Printf.sprintf "%d channels <> 3" image.channels))
    else (
      match resize with
      | None ->
        Ok (tensor_of_data ~data:image.data ~width:image.width ~height:image.height)
      | Some (target_width, target_height) ->
        let resize_width, resize_height =
          let ratio_w = Float.of_int target_width /. Float.of_int image.width in
          let ratio_h = Float.of_int target_height /. Float.of_int image.height in
          let r = Float.max ratio_w ratio_h in
          (* Float truncation may land one pixel short of the target, hence the
             clamp to at least the target size. *)
          ( Float.to_int (r *. Float.of_int image.width) |> Int.max target_width
          , Float.to_int (r *. Float.of_int image.height) |> Int.max target_height )
        in
        let out_data =
          Bigarray.Array1.create
            Int8_unsigned
            C_layout
            (resize_width * resize_height * 3)
        in
        let status =
          resize_
            ~in_data:image.data
            ~in_w:image.width
            ~in_h:image.height
            ~out_data
            ~out_w:resize_width
            ~out_h:resize_height
            ~nchannels:3
        in
        if status = 0
        then Error (`Msg (Printf.sprintf "error when resizing %s" image_file))
        else (
          let tensor =
            tensor_of_data ~data:out_data ~width:resize_width ~height:resize_height
          in
          (* Crop the width (dim 3) and then the height (dim 2) down to the target. *)
          maybe_crop tensor ~dim:3 ~length:resize_width ~target_length:target_width
          |> maybe_crop ~dim:2 ~length:resize_height ~target_length:target_height
          |> Result.return))
  in
  Stb_image.load image_file
  |> Result.bind ~f:to_tensor
  |> Result.map_error ~f:(fun (`Msg msg) -> Error.of_string msg)
(** Lower-case file-name suffixes that the directory loaders treat as images. *)
let image_suffixes = [ ".jpg"; ".png" ]

(** [has_image_suffix filename] is [true] when [filename] ends with one of
    [image_suffixes]. The file name is lower-cased before the comparison so that
    upper-case or mixed-case extensions (for instance [.JPG]) are accepted instead of
    being silently skipped. *)
let has_image_suffix filename =
  let lowered = String.lowercase filename in
  List.exists image_suffixes ~f:(fun suffix -> String.is_suffix lowered ~suffix)
(** [load_images ?resize dir] loads every file of [dir] whose name passes
    [has_image_suffix] using [load_image] and concatenates the results along
    dimension 0. Sub-directories are not visited.

    The number of directory entries is printed on stdout. A file that fails to load is
    reported on stdout and skipped.

    @raise Failure if [dir] is not a directory, or if no image at all could be loaded
    from it (concatenating an empty list would otherwise fail inside libtorch with a
    message that does not mention the directory). *)
let load_images ?resize dir =
  if not (Caml.Sys.is_directory dir) then Printf.failwithf "not a directory %s" dir ();
  let files = Caml.Sys.readdir dir |> Array.to_list in
  Stdio.printf "%d files found in %s\n%!" (List.length files) dir;
  let images =
    List.filter_map files ~f:(fun filename ->
      if has_image_suffix filename
      then (
        match load_image (Caml.Filename.concat dir filename) ?resize with
        | Ok image -> Some image
        | Error msg ->
          Stdio.printf "%s: %s\n%!" filename (Error.to_string_hum msg);
          None)
      else None)
  in
  match images with
  | [] -> Printf.failwithf "no image could be loaded from %s" dir ()
  | _ :: _ -> Tensor.cat images ~dim:0
(** [load_dataset ~dir ~classes ~with_cache ~resize] builds a [Dataset_helper] record
    from a directory tree laid out as [dir/train/<class>] and [dir/val/<class>], one
    sub-directory per element of [classes].

    Images of each class are loaded with [load_images] and resized to [resize]. Every
    image is labelled with the index of its class in [classes], stored in an [Int64]
    tensor. The [train] split fills the training fields and the [val] split fills the
    test fields.

    When [with_cache] is [Some cache_file] the work is delegated to
    [Dataset_helper.read_with_cache]; otherwise the directories are read directly. *)
let load_dataset ~dir ~classes ~with_cache ~resize =
  (* Loads one split and returns its concatenated (images, labels) pair. *)
  let load_split split =
    let images, labels =
      List.mapi classes ~f:(fun class_index class_dir ->
        let images =
          load_images (Printf.sprintf "%s/%s/%s" dir split class_dir) ~resize
        in
        (* One label per image: the leading dimension is the image count. *)
        let nimages = List.hd_exn (Tensor.shape images) in
        let labels = Tensor.zeros [ nimages ] ~kind:(T Int64) in
        Tensor.fill_int labels class_index;
        images, labels)
      |> List.unzip
    in
    Tensor.cat images ~dim:0, Tensor.cat labels ~dim:0
  in
  let read () =
    let train_images, train_labels = load_split "train" in
    let test_images, test_labels = load_split "val" in
    { Dataset_helper.train_images; train_labels; test_images; test_labels }
  in
  match with_cache with
  | Some cache_file -> Dataset_helper.read_with_cache ~cache_file ~read
  | None -> read ()
(** [write_image tensor ~filename] writes [tensor] to [filename] as an image.

    Accepted shapes, with [channels] being 1 or 3, tried in this order:
    - [1; channels; height; width]
    - [channels; height; width]
    - [height; width; channels]

    The data is rearranged to height/width/channels order, made contiguous, converted
    to [Uint8] (no rescaling is applied here) and handed to [Stb_image_write].

    The format is chosen from the text after the last dot of [filename]: [jpg]
    (quality 90), [tga], [bmp] or [png]. For any other extension, or none, [.png] is
    appended to [filename] and a PNG is written.

    @raise Failure if the shape of [tensor] is not one of those listed above. *)
let write_image tensor ~filename =
  let hwc, height, width, channels =
    match Tensor.shape tensor with
    | [ 1; channels; h; w ] when channels = 1 || channels = 3 ->
      let chw = Tensor.reshape tensor ~shape:[ channels; h; w ] in
      Tensor.permute chw ~dims:[ 1; 2; 0 ], h, w, channels
    | [ channels; h; w ] when channels = 1 || channels = 3 ->
      Tensor.permute tensor ~dims:[ 1; 2; 0 ], h, w, channels
    | [ h; w; channels ] when channels = 1 || channels = 3 -> tensor, h, w, channels
    | _ -> Printf.failwithf "unexpected shape %s" (Tensor.shape_str tensor) ()
  in
  let bigarray =
    (* [Tensor.view] needs compatible strides. Making the tensor contiguous on every
       path (not only after the permute) also covers a caller passing a
       non-contiguous height/width/channels tensor. *)
    Tensor.contiguous hwc
    |> Tensor.view ~size:[ channels * height * width ]
    |> Tensor.to_type ~type_:(T Uint8)
    |> Tensor.to_bigarray ~kind:Int8_unsigned
    |> Bigarray.array1_of_genarray
  in
  match String.rsplit2 filename ~on:'.' with
  | Some (_, "jpg") ->
    Stb_image_write.jpg filename bigarray ~w:width ~h:height ~c:channels ~quality:90
  | Some (_, "tga") ->
    Stb_image_write.tga filename bigarray ~w:width ~h:height ~c:channels
  | Some (_, "bmp") ->
    Stb_image_write.bmp filename bigarray ~w:width ~h:height ~c:channels
  | Some (_, "png") ->
    Stb_image_write.png filename bigarray ~w:width ~h:height ~c:channels
  | Some _ | None ->
    Stb_image_write.png (filename ^ ".png") bigarray ~w:width ~h:height ~c:channels
(** Batch loader over all the image files found below a directory. File paths are
    collected once at creation; images are read from disk each time a batch is
    requested. *)
module Loader = struct
  type t =
    { dir_and_filenames : (string * string) array
      (** (directory, file name) pairs of the images that were found. *)
    ; mutable current_index : int
      (** Position of the next image returned by [next_batch]. *)
    ; resize : (int * int) option (** Passed as [?resize] to [load_image]. *)
    }

  (** Number of image files known to the loader. *)
  let nsamples t = Array.length t.dir_and_filenames

  (** [create ?resize ~dir ()] walks [dir] recursively and records every file whose
      name passes [has_image_suffix].

      @raise Failure if [dir] is not a directory. *)
  let create ?resize ~dir () =
    if not (Caml.Sys.is_directory dir) then Printf.failwithf "not a directory %s" dir ();
    (* Entries are prepended in traversal order, so the final array holds them in
       reverse traversal order. *)
    let rec collect found dir =
      Array.fold (Caml.Sys.readdir dir) ~init:found ~f:(fun found file ->
        let path = Caml.Filename.concat dir file in
        if Caml.Sys.is_directory path
        then collect found path
        else if has_image_suffix file
        then (dir, file) :: found
        else found)
    in
    { dir_and_filenames = Array.of_list (collect [] dir); current_index = 0; resize }

  (** Rewinds the loader so that [next_batch] starts again from the first image. *)
  let reset t = t.current_index <- 0

  (** [next_batch t ~batch_size] loads up to [batch_size] images starting at the
      current position, concatenates them along dimension 0 and advances the
      position. The last batch may be smaller. Returns [None] when the resulting
      batch size is 0. Raises if an image fails to load. *)
  let next_batch t ~batch_size =
    let first = t.current_index in
    let count = Int.min batch_size (nsamples t - first) in
    if count = 0
    then None
    else (
      let images =
        List.init count ~f:(fun offset ->
          let dir, file = t.dir_and_filenames.(first + offset) in
          Caml.Filename.concat dir file
          |> load_image ?resize:t.resize
          |> Or_error.ok_exn)
      in
      let batch = Tensor.cat images ~dim:0 in
      t.current_index <- first + count;
      Some batch)

  (** [random_batch t ~batch_size] loads [batch_size] images picked with [Random.int]
      over the known files (the same file may be picked more than once) and
      concatenates them along dimension 0. It does not touch the current position.
      Raises if an image fails to load. *)
  let random_batch t ~batch_size =
    let total = nsamples t in
    let pick_one _ =
      let dir, file = t.dir_and_filenames.(Random.int total) in
      Or_error.ok_exn (load_image ?resize:t.resize (Caml.Filename.concat dir file))
    in
    Tensor.cat (List.init batch_size ~f:pick_one) ~dim:0
end
(** [resize tensor ~height ~width] resizes image tensors to [height] x [width] through
    the [resize_] stub.

    Accepted shapes, with [c] being 1 or 3:
    - [c; h; w]: a single image; the result has shape [1; c; height; width].
    - [n; c; h; w]: a batch; each image is resized on its own and the results are
      concatenated along dimension 0.

    Pixel values are converted to [Uint8] before resizing and the result is built from
    an unsigned 8-bit buffer.

    The output tensor is built with the actual channel count [c]. Going through
    [tensor_of_data], which always assumes 3 channels, made single-channel inputs fail.

    @raise Failure if the shape is not one of those above or if the stub reports a
    failure (status [0]). *)
let resize tensor ~height ~width =
  let resize_one tensor =
    match Tensor.shape tensor with
    | [ c; h; w ] when c = 1 || c = 3 ->
      (* The stub is fed height/width/channels data, hence the permute. *)
      let in_data =
        Tensor.permute tensor ~dims:[ 1; 2; 0 ]
        |> Tensor.contiguous
        |> Tensor.view ~size:[ c * h * w ]
        |> Tensor.to_type ~type_:(T Uint8)
        |> Tensor.to_bigarray ~kind:Int8_unsigned
        |> Bigarray.array1_of_genarray
      in
      let out_data = Bigarray.Array1.create Int8_unsigned C_layout (width * height * c) in
      let status =
        resize_ ~in_data ~in_w:w ~in_h:h ~out_data ~out_w:width ~out_h:height ~nchannels:c
      in
      if status = 0 then failwith "error when resizing image";
      (* Same layout conversion as [tensor_of_data], but honouring [c]. *)
      Tensor.of_bigarray (Bigarray.genarray_of_array1 out_data)
      |> Tensor.view ~size:[ 1; height; width; c ]
      |> Tensor.permute ~dims:[ 0; 3; 1; 2 ]
    | _ -> Printf.failwithf "unexpected shape %s" (Tensor.shape_str tensor) ()
  in
  match Tensor.shape tensor with
  | [ c; _; _ ] when c = 1 || c = 3 -> resize_one tensor
  | [ _; c; _; _ ] when c = 1 || c = 3 ->
    Tensor.to_list tensor |> List.map ~f:resize_one |> Tensor.cat ~dim:0
  | _ -> Printf.failwithf "unexpected shape %s" (Tensor.shape_str tensor) ()