1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
open Bigarray_ext
(** Errors reported by the text reader/writer. This is a re-export of
    [Error.t], so the constructors are interchangeable with that module's. *)
type error = Error.t =
| Io_error of string  (** The file could not be opened, read or written; carries the system message. *)
| Format_error of string  (** Malformed text content or an invalid argument; carries a description. *)
| Unsupported_dtype  (** No text codec is available for the requested dtype. *)
| Unsupported_shape  (** The tensor has a rank that cannot be written as text. *)
| Missing_entry of string  (** NOTE(review): not constructed in the visible part of this file; meaning is defined by [Error]. *)
| Other of string  (** NOTE(review): not constructed in the visible part of this file; meaning is defined by [Error]. *)
(** Tensor layouts that can be written as text: a rank-0 [Scalar], a rank-1
    [Vector n] of length [n], or a rank-2 [Matrix (rows, cols)]. *)
type layout = Scalar | Vector of int | Matrix of int * int
(** [layout_of_shape shape] classifies a shape array by its rank: rank 0 is
    [Scalar], rank 1 is [Vector], rank 2 is [Matrix]. Any higher rank yields
    [None]. *)
let layout_of_shape shape =
  match Array.length shape with
  | 0 -> Some Scalar
  | 1 -> Some (Vector shape.(0))
  | 2 -> Some (Matrix (shape.(0), shape.(1)))
  | _ -> None
(** [option_exists pred opt] is [pred x] when [opt] is [Some x], and [false]
    when [opt] is [None]. *)
let option_exists pred opt = Option.fold ~none:false ~some:pred opt
(** [split_lines_opt text] splits the optional [text] on ['\n']. [None] gives
    the empty list; [Some ""] gives a single empty line. *)
let split_lines_opt text =
  text
  |> Option.map (String.split_on_char '\n')
  |> Option.value ~default:[]
(** [trim_trailing_whitespace s] removes trailing spaces and tabs from [s].
    Other characters (including newlines) are left in place. The input string
    is returned unchanged when there is nothing to trim.

    The end position is found with a single backwards scan and at most one
    copy is made, instead of allocating one substring per removed character. *)
let trim_trailing_whitespace s =
  let total = String.length s in
  (* [keep_until i] is the length of the prefix that survives trimming, given
     that everything from index [i] onwards is already known to be blank. *)
  let rec keep_until i =
    if i > 0 && (s.[i - 1] = ' ' || s.[i - 1] = '\t') then keep_until (i - 1)
    else i
  in
  let kept = keep_until total in
  if kept = total then s else String.sub s 0 kept
(** [try_parse f s] is [Some (f s)], or [None] when [f s] raises [Failure].

    Only [Failure] is intercepted: that is the exception raised by the
    standard [*_of_string] converters on malformed input. Any other exception
    (for example [Out_of_memory], [Stack_overflow] or [Sys.Break]) propagates
    to the caller instead of being silently turned into [None]. *)
let try_parse f s = match f s with v -> Some v | exception Failure _ -> None
(* Option-returning numeric converters built on [try_parse]. The first two
   shadow the standard library functions of the same name for the rest of
   this file. *)
let float_of_string_opt text = try_parse float_of_string text
let int_of_string_opt text = try_parse int_of_string text
let int32_of_string_opt text = try_parse Int32.of_string text
let int64_of_string_opt text = try_parse Int64.of_string text
let nativeint_of_string_opt text = try_parse Nativeint.of_string text
(** Per-dtype text codec: how one element is stored, printed and parsed. *)
module type SPEC = sig
(** OCaml type of a single element. *)
type elt
(** Bigarray element-kind witness type paired with [elt]. *)
type kind
(** Bigarray kind used when allocating or viewing storage for this dtype. *)
val kind : (elt, kind) Bigarray_ext.kind
(** [print oc v] writes the textual form of [v] to [oc]. *)
val print : out_channel -> elt -> unit
(** [parse token] converts one text field to an element, or explains why it
    cannot be converted. *)
val parse : string -> (elt, error) result
end
(** [invalid_literal dtype_name token] builds the [Format_error] reported when
    [token] is not a valid literal for [dtype_name]. Trailing spaces and tabs
    are removed from [token] before it is quoted in the message. *)
let invalid_literal dtype_name token =
  let shown = trim_trailing_whitespace token in
  let message = Printf.sprintf "Invalid %s literal: %S" dtype_name shown in
  Format_error message
(** [out_of_range dtype_name token] builds the [Format_error] reported when
    [token] parses as a number but does not fit in [dtype_name]. Trailing
    spaces and tabs are removed from [token] before it is quoted. *)
let out_of_range dtype_name token =
  let shown = trim_trailing_whitespace token in
  let message =
    Printf.sprintf "Value %S is out of range for %s" shown dtype_name
  in
  Format_error message
(** [parse_float dtype token] parses [token] as a float. On failure the error
    is an invalid-literal [Format_error] that names [dtype]. *)
let parse_float dtype token =
  match float_of_string_opt token with
  | None -> Error (invalid_literal dtype token)
  | Some value -> Ok value
(** [parse_bool token] parses a boolean field, ignoring surrounding whitespace
    and letter case. Accepted spellings are [true]/[t]/[yes]/[y] and
    [false]/[f]/[no]/[n]. Otherwise the token is read as an integer (zero is
    false, anything else true) and, failing that, as a float (true when its
    absolute value is greater than zero). Anything else is an invalid-literal
    [Format_error]. *)
let parse_bool token =
  let normalized = token |> String.trim |> String.lowercase_ascii in
  (* Numeric fallback, tried only when no keyword matched. *)
  let of_number () =
    match int_of_string_opt normalized with
    | Some n -> Ok (n <> 0)
    | None -> (
        match float_of_string_opt normalized with
        | Some x -> Ok (abs_float x > 0.0)
        | None -> Error (invalid_literal "bool" token))
  in
  match normalized with
  | "true" | "t" | "yes" | "y" -> Ok true
  | "false" | "f" | "no" | "n" -> Ok false
  | _ -> of_number ()
(** [parse_int_with_bounds dtype token ~min ~max] parses [token] as an [int]
    and accepts it only when it lies in the inclusive range [min..max].
    A token that is not an integer gives an invalid-literal error; an integer
    outside the range gives an out-of-range error. Both name [dtype]. *)
let parse_int_with_bounds dtype token ~min ~max =
  match int_of_string_opt token with
  | None -> Error (invalid_literal dtype token)
  | Some value ->
      if value < min || value > max then Error (out_of_range dtype token)
      else Ok value
(** [spec_of_dtype dtype] returns the text codec for [dtype] packed as a
    first-class [SPEC] module, or [None] for any dtype not listed in the match
    below.

    Floating-point dtypes print with ["%.18e"]; integer dtypes print in
    decimal; booleans print as ["1"] / ["0"]. The 8- and 16-bit integer dtypes
    are range-checked on parse. *)
let spec_of_dtype (type a) (type b) (dtype : (a, b) Nx.dtype) :
(module SPEC with type elt = a and type kind = b) option =
(* Human-readable dtype name, used only in parse error messages. *)
let dtype_name = Nx_core.Dtype.to_string dtype in
(* [M] re-exports its argument unchanged ([include X]); it only gives each
   anonymous structure below a module name that can be packed. *)
let module M (X : sig
type elt
type kind
val kind : (elt, kind) Bigarray_ext.kind
val print : out_channel -> elt -> unit
val parse : string -> (elt, error) result
end) =
struct
include X
end in
let open Nx_core.Dtype in
(* Each branch refines [a] and [b] to the concrete element and kind types of
   the matched dtype, which is what lets the packed module type-check. *)
match dtype with
(* Floating-point dtypes: all are represented as OCaml [float]. *)
| Float16 ->
let module S = M (struct
type elt = float
type kind = Bigarray_ext.float16_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = Printf.fprintf oc "%.18e" v
let parse token = parse_float dtype_name token
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| Float32 ->
let module S = M (struct
type elt = float
type kind = Bigarray_ext.float32_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = Printf.fprintf oc "%.18e" v
let parse token = parse_float dtype_name token
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| Float64 ->
let module S = M (struct
type elt = float
type kind = Bigarray_ext.float64_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = Printf.fprintf oc "%.18e" v
let parse token = parse_float dtype_name token
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| BFloat16 ->
let module S = M (struct
type elt = float
type kind = Bigarray_ext.bfloat16_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = Printf.fprintf oc "%.18e" v
let parse token = parse_float dtype_name token
end) in
Some (module S : SPEC with type elt = a and type kind = b)
(* Narrow integer dtypes: stored as OCaml [int] and range-checked on parse. *)
| Int8 ->
let module S = M (struct
type elt = int
type kind = Bigarray_ext.int8_signed_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (string_of_int v)
let parse token =
parse_int_with_bounds dtype_name token ~min:(-128) ~max:127
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| UInt8 ->
let module S = M (struct
type elt = int
type kind = Bigarray_ext.int8_unsigned_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (string_of_int v)
let parse token = parse_int_with_bounds dtype_name token ~min:0 ~max:255
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| Int16 ->
let module S = M (struct
type elt = int
type kind = Bigarray_ext.int16_signed_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (string_of_int v)
let parse token =
parse_int_with_bounds dtype_name token ~min:(-32768) ~max:32767
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| UInt16 ->
let module S = M (struct
type elt = int
type kind = Bigarray_ext.int16_unsigned_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (string_of_int v)
let parse token =
parse_int_with_bounds dtype_name token ~min:0 ~max:65535
end) in
Some (module S : SPEC with type elt = a and type kind = b)
(* Wide integer dtypes: parsed with the matching [of_string] converter; any
   conversion failure is reported as an invalid literal. *)
| Int32 ->
let module S = M (struct
type elt = int32
type kind = Bigarray_ext.int32_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (Int32.to_string v)
let parse token =
match int32_of_string_opt token with
| Some v -> Ok v
| None -> Error (invalid_literal dtype_name token)
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| Int64 ->
let module S = M (struct
type elt = int64
type kind = Bigarray_ext.int64_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (Int64.to_string v)
let parse token =
match int64_of_string_opt token with
| Some v -> Ok v
| None -> Error (invalid_literal dtype_name token)
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| Int ->
let module S = M (struct
type elt = int
type kind = Bigarray_ext.int_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (string_of_int v)
let parse token =
match int_of_string_opt token with
| Some v -> Ok v
| None -> Error (invalid_literal dtype_name token)
end) in
Some (module S : SPEC with type elt = a and type kind = b)
| NativeInt ->
let module S = M (struct
type elt = nativeint
type kind = Bigarray_ext.nativeint_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (Nativeint.to_string v)
let parse token =
match nativeint_of_string_opt token with
| Some v -> Ok v
| None -> Error (invalid_literal dtype_name token)
end) in
Some (module S : SPEC with type elt = a and type kind = b)
(* Booleans are written as 1/0 and read back through [parse_bool]. *)
| Bool ->
let module S = M (struct
type elt = bool
type kind = Bigarray_ext.bool_elt
let kind = Nx_core.Dtype.to_bigarray_ext_kind dtype
let print oc v = output_string oc (if v then "1" else "0")
let parse = parse_bool
end) in
Some (module S : SPEC with type elt = a and type kind = b)
(* Every remaining dtype has no text codec. *)
| _ -> None
(** [save ?sep ?append ?newline ?header ?footer ?comments ~out arr] writes
    [arr] to the file [out] as delimited text.

    - [sep] separates the values on one line (default: a single space).
    - [append] appends to [out] instead of truncating it (default: [false]).
      The file is created if needed.
    - [newline] terminates every line written (default: ["\n"]).
    - [header] and [footer] are optional texts written before and after the
      data. Each is split on ['\n'] and every resulting line is prefixed
      with [comments].
    - [comments] is the prefix for header and footer lines (default: ["# "]).
      An empty string writes no prefix.

    A rank-0 tensor is written as one value on one line, a rank-1 tensor as
    one line of values, and a rank-2 tensor as one line per row.

    Returns [Error Unsupported_shape] for any other rank,
    [Error Unsupported_dtype] when the dtype has no text codec, and
    [Error (Io_error _)] when the file cannot be opened or written. *)
let save ?(sep = " ") ?(append = false) ?(newline = "\n") ?header ?footer
    ?(comments = "# ") ~out (type a) (type b) (arr : (a, b) Nx.t) =
  match layout_of_shape (Nx.shape arr) with
  | None -> Error Unsupported_shape
  | Some layout -> (
      match spec_of_dtype (Nx.dtype arr) with
      | None -> Error Unsupported_dtype
      | Some spec_module -> (
          let module S =
            (val spec_module : SPEC with type elt = a and type kind = b)
          in
          let perm = 0o666 in
          let flags =
            if append then [ Open_wronly; Open_creat; Open_append; Open_text ]
            else [ Open_wronly; Open_creat; Open_trunc; Open_text ]
          in
          try
            let oc = open_out_gen flags perm out in
            (* [finally] must not raise: an exception there would surface as
               [Fun.Finally_raised] and bypass the handlers below. The channel
               is closed explicitly on the success path instead, so that a
               failing final flush is reported as [Io_error]. *)
            Fun.protect
              ~finally:(fun () -> close_out_noerr oc)
              (fun () ->
                let write_prefixed line =
                  if comments <> "" then output_string oc comments;
                  output_string oc line;
                  output_string oc newline
                in
                List.iter write_prefixed (split_lines_opt header);
                let data =
                  (Nx.to_bigarray_ext arr
                    : (S.elt, S.kind, Bigarray.c_layout) Genarray.t)
                in
                (match layout with
                | Scalar ->
                    let value = Genarray.get data [||] in
                    S.print oc value;
                    output_string oc newline
                | Vector length ->
                    let view = array1_of_genarray data in
                    for j = 0 to length - 1 do
                      if j > 0 then output_string oc sep;
                      S.print oc (Array1.unsafe_get view j)
                    done;
                    output_string oc newline
                | Matrix (rows, cols) ->
                    let view = array2_of_genarray data in
                    for i = 0 to rows - 1 do
                      for j = 0 to cols - 1 do
                        if j > 0 then output_string oc sep;
                        S.print oc (Array2.unsafe_get view i j)
                      done;
                      output_string oc newline
                    done);
                List.iter write_prefixed (split_lines_opt footer);
                (* Flush and close now; a write error raises [Sys_error]. *)
                close_out oc;
                Ok ())
          with
          | Sys_error msg -> Error (Io_error msg)
          | Unix.Unix_error (e, _, _) ->
              Error (Io_error (Unix.error_message e))))
(** [load ?sep ?comments ?skiprows ?max_rows dtype path] reads delimited text
    from [path] into a tensor of [dtype].

    - [sep] is the field separator (default: a single space). Consecutive
      separators produce no empty fields. An empty [sep] treats each line as
      a single field.
    - [comments] is the comment prefix (default: ["#"]). After trimming, a
      line starting with it is ignored. An empty prefix disables this.
    - [skiprows] is the number of physical lines dropped from the start of
      the file, counted before comment and blank-line filtering (default 0).
    - [max_rows] caps the number of data rows read.

    Blank lines are ignored. Every data row must have the same number of
    fields. The result is rank-1 when there is a single row or a single
    column, and rank-2 otherwise.

    Returns [Error (Format_error _)] for invalid arguments, inconsistent
    column counts, unparsable fields or a file with no data;
    [Error Unsupported_dtype] when the dtype has no text codec; and
    [Error (Io_error _)] when the file cannot be opened or read. *)
let load ?(sep = " ") ?(comments = "#") ?(skiprows = 0) ?max_rows (type a)
    (type b) (dtype : (a, b) Nx.dtype) path =
  if skiprows < 0 then Error (Format_error "skiprows must be non-negative")
  else if option_exists (fun rows -> rows <= 0) max_rows then
    Error (Format_error "max_rows must be strictly positive")
  else
    match spec_of_dtype dtype with
    | None -> Error Unsupported_dtype
    | Some spec_module -> (
        let module S =
          (val spec_module : SPEC with type elt = a and type kind = b)
        in
        try
          let ic = open_in path in
          (* [close_in_noerr] keeps [finally] from raising, which would
             otherwise escape as [Fun.Finally_raised]. *)
          Fun.protect
            ~finally:(fun () -> close_in_noerr ic)
            (fun () ->
              let comment_prefix = String.trim comments in
              let is_comment_line line =
                if comment_prefix = "" then false
                else
                  let trimmed = String.trim line in
                  let len = String.length comment_prefix in
                  String.length trimmed >= len
                  && String.sub trimmed 0 len = comment_prefix
              in
              let split_fields line =
                let trimmed = String.trim line in
                if trimmed = "" then [||]
                else if sep = "" then [| trimmed |]
                else if String.length sep = 1 then
                  trimmed
                  |> String.split_on_char sep.[0]
                  |> List.filter (fun s -> s <> "")
                  |> Array.of_list
                else
                  let len_sep = String.length sep in
                  let len = String.length trimmed in
                  (* [field_start] is where the current field begins and
                     [search_from] is where to look for the next separator.
                     They differ after a partial match (first character of
                     [sep] without the rest), so that the text before it
                     stays part of the current field. *)
                  let rec aux acc field_start search_from =
                    match
                      String.index_from_opt trimmed search_from sep.[0]
                    with
                    | None ->
                        let part =
                          String.sub trimmed field_start (len - field_start)
                        in
                        if part = "" then List.rev acc
                        else List.rev (part :: acc)
                    | Some idx ->
                        if
                          idx + len_sep <= len
                          && String.sub trimmed idx len_sep = sep
                        then
                          let part =
                            String.sub trimmed field_start (idx - field_start)
                          in
                          let acc = if part = "" then acc else part :: acc in
                          let next = idx + len_sep in
                          aux acc next next
                        else aux acc field_start (idx + 1)
                  in
                  aux [] 0 0 |> Array.of_list
              in
              let rows_rev = ref [] in
              let cols = ref None in
              let rows_read = ref 0 in
              let read_error = ref None in
              (* Reads lines until end of file, the first error, or
                 [max_rows] data rows. *)
              let rec loop skip_remaining =
                if Option.is_some !read_error then ()
                else if
                  option_exists (fun rows -> !rows_read >= rows) max_rows
                then ()
                else
                  match input_line ic with
                  | line ->
                      if skip_remaining > 0 then loop (skip_remaining - 1)
                      else if is_comment_line line then loop 0
                      else
                        let fields = split_fields line in
                        if Array.length fields = 0 then loop 0
                        else (
                          (match !cols with
                          | None -> cols := Some (Array.length fields)
                          | Some expected ->
                              if Array.length fields <> expected then
                                read_error :=
                                  Some
                                    (Format_error
                                       "Inconsistent number of columns"));
                          if Option.is_none !read_error then (
                            rows_rev := fields :: !rows_rev;
                            incr rows_read);
                          loop 0)
                  | exception End_of_file -> ()
              in
              loop skiprows;
              match (!read_error, !cols, !rows_rev) with
              | Some err, _, _ -> Error err
              | _, None, _ -> Error (Format_error "No data found")
              | _, _, [] -> Error (Format_error "No data found")
              | _, Some col_count, rows_rev_list -> (
                  let rows = List.rev rows_rev_list |> Array.of_list in
                  let row_count = Array.length rows in
                  let dims = [| row_count; col_count |] in
                  let ba = Genarray.create S.kind Bigarray.c_layout dims in
                  (* Only the first parse error is kept; later fields are
                     skipped once it is set. *)
                  let parse_error = ref None in
                  for i = 0 to row_count - 1 do
                    let row = rows.(i) in
                    for j = 0 to col_count - 1 do
                      if Option.is_none !parse_error then (
                        match S.parse row.(j) with
                        | Ok value -> Genarray.set ba [| i; j |] value
                        | Error err -> parse_error := Some err)
                    done
                  done;
                  match !parse_error with
                  | Some err -> Error err
                  | None ->
                      let tensor = Nx.of_bigarray_ext ba in
                      (* A single row or a single column is returned as a
                         rank-1 tensor. *)
                      let result =
                        if row_count = 1 then
                          Nx.reshape [| col_count |] tensor
                        else if col_count = 1 then
                          Nx.reshape [| row_count |] tensor
                        else tensor
                      in
                      Ok result))
        with
        | Sys_error msg -> Error (Io_error msg)
        | Unix.Unix_error (e, _, _) ->
            Error (Io_error (Unix.error_message e)))