1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
let strf = Printf.sprintf
exception Read_error of string
let read_error fmt = Printf.ksprintf (fun s -> raise (Read_error s)) fmt
let magic = "\147NUMPY"
let magic_len = String.length magic
let really_write_fd fd buf off len =
let rec loop off remaining =
if remaining > 0 then
let w = Unix.write fd buf off remaining in
loop (off + w) (remaining - w)
in
loop off len
let as_flat_c ga =
let n = Array.fold_left ( * ) 1 (Nx_buffer.genarray_dims ga) in
let ga = Nx_buffer.genarray_change_layout ga Bigarray.C_layout in
(n, Nx_buffer.of_genarray (Bigarray.reshape ga [| n |]))
let write_genarray_to_fd fd ga =
let n, buf = as_flat_c ga in
let byte_size =
n * Nx_buffer.kind_size_in_bytes (Nx_buffer.genarray_kind ga)
in
let bytes = Bytes.create byte_size in
Nx_buffer.blit_to_bytes ~src_off:0 ~dst_off:0 ~len:n buf bytes;
really_write_fd fd bytes 0 byte_size
let read_fd_to_genarray fd ga =
let n, buf = as_flat_c ga in
let byte_size =
n * Nx_buffer.kind_size_in_bytes (Nx_buffer.genarray_kind ga)
in
let bytes = Bytes.create byte_size in
let rec loop off =
if off < byte_size then (
let r = Unix.read fd bytes off (byte_size - off) in
if r = 0 then read_error "unexpected eof reading tensor data";
loop (off + r))
in
loop 0;
Nx_buffer.blit_from_bytes ~src_off:0 ~dst_off:0 ~len:n bytes buf
type packed_kind = K : (_, _) Nx_buffer.kind -> packed_kind
let dtype_string (K kind) =
let endian =
match kind with
| Nx_buffer.Int8_signed | Int8_unsigned | Bool -> "|"
| _ -> if Sys.big_endian then ">" else "<"
in
let descr =
match kind with
| Nx_buffer.Float16 -> "f2"
| Float32 -> "f4"
| Float64 -> "f8"
| Bfloat16 -> "f2"
| Float8_e4m3 -> "f1"
| Float8_e5m2 -> "f1"
| Int8_signed -> "i1"
| Int8_unsigned -> "u1"
| Int16_signed -> "i2"
| Int16_unsigned -> "u2"
| Int32 -> "i4"
| Int64 -> "i8"
| Uint32 -> "u4"
| Uint64 -> "u8"
| Int4_signed -> "i1"
| Int4_unsigned -> "u1"
| Complex32 -> "c8"
| Complex64 -> "c16"
| Bool -> "b1"
in
endian ^ descr
let kind_of_descr = function
| "f4" -> K Float32
| "f8" -> K Float64
| "i4" -> K Int32
| "i8" -> K Int64
| "u4" -> K Uint32
| "u8" -> K Uint64
| "u1" -> K Int8_unsigned
| "i1" -> K Int8_signed
| "u2" -> K Int16_unsigned
| "i2" -> K Int16_signed
| "c8" -> K Complex32
| "c16" -> K Complex64
| "b1" -> K Bool
| s -> read_error "unsupported dtype descriptor %s" s
let str ~on =
let parens = ref 0 in
let cuts = ref [] in
for i = 0 to String.length str - 1 do
match str.[i] with
| '(' -> incr parens
| ')' -> decr parens
| c when !parens = 0 && c = on -> cuts := i :: !cuts
| _ -> ()
done;
List.fold_left
(fun (prev, acc) i -> (i, String.sub str (i + 1) (prev - i - 1) :: acc))
(String.length str, [])
!cuts
|> fun (first, acc) -> String.sub str 0 first :: acc
let str ~on =
let len = String.length str in
let rec scan_left i =
if i >= len then i else if List.mem str.[i] on then scan_left (i + 1) else i
in
let rec scan_right j =
if j <= 0 then j
else if List.mem str.[j - 1] on then scan_right (j - 1)
else j
in
let l = scan_left 0 in
let r = scan_right len in
if l >= r then "" else String.sub str l (r - l)
let s =
let s = header_trim s ~on:[ '{'; ' '; '}'; '\n' ] in
let fields =
header_split s ~on:',' |> List.map String.trim
|> List.filter (fun s -> String.length s > 0)
|> List.map (fun field ->
match header_split field ~on:':' with
| [ name; value ] ->
( header_trim name ~on:[ '\''; ' ' ],
header_trim value ~on:[ '\''; ' '; '('; ')' ] )
| _ -> read_error "unable to parse header field %s" field)
in
let find name =
try List.assoc name fields
with Not_found -> read_error "missing header field %s" name
in
let kind =
let descr = find "descr" in
(match descr.[0] with
| '|' | '=' -> ()
| '>' ->
if not Sys.big_endian then
read_error "big endian data on little endian arch"
| '<' ->
if Sys.big_endian then
read_error "little endian data on big endian arch"
| c -> read_error "unknown endianness marker %c" c);
kind_of_descr (String.sub descr 1 (String.length descr - 1))
in
let fortran_order =
match find "fortran_order" with
| "False" -> false
| "True" -> true
| s -> read_error "invalid fortran_order %s" s
in
let shape =
find "shape" |> header_split ~on:',' |> List.map String.trim
|> List.filter (fun s -> String.length s > 0)
|> List.map int_of_string |> Array.of_list
in
{ kind; fortran_order; shape }
let shape_string dims =
match dims with
| [| n |] -> strf "%d," n
| _ -> Array.to_list dims |> List.map string_of_int |> String.concat ", "
let fortran_string (type a) (layout : a Bigarray.layout) =
match layout with
| Bigarray.C_layout -> "False"
| Bigarray.Fortran_layout -> "True"
let ~layout ~packed_kind ~dims =
let =
strf "{'descr': '%s', 'fortran_order': %s, 'shape': (%s), }"
(dtype_string packed_kind) (fortran_string layout) (shape_string dims)
in
let total_len = String.length header + magic_len + 4 + 1 in
let pad = if total_len mod 16 = 0 then 0 else 16 - (total_len mod 16) in
let = String.length header + pad + 1 in
strf "%s\001\000%c%c%s%s\n" magic
(header_len mod 256 |> Char.chr)
(header_len / 256 |> Char.chr)
header (String.make pad ' ')
let with_fd path flags perm f =
let fd = Unix.openfile path flags perm in
Fun.protect ~finally:(fun () -> Unix.close fd) (fun () -> f fd)
let really_read_fd fd n =
let buf = Bytes.create n in
let rec loop off =
if off >= n then ()
else
let r = Unix.read fd buf off (n - off) in
if r = 0 then read_error "unexpected eof";
loop (off + r)
in
loop 0;
Bytes.to_string buf
let map_or_read fd ~pos kind layout shape =
let is_scalar = Array.length shape = 0 in
let actual = if is_scalar then [| 1 |] else shape in
let ga =
match Nx_buffer.to_stdlib_kind kind with
| Some std_kind -> Unix.map_file fd ~pos std_kind layout false actual
| None ->
let ga = Nx_buffer.genarray_create kind layout actual in
ignore (Unix.lseek fd (Int64.to_int pos) Unix.SEEK_SET);
read_fd_to_genarray fd ga;
ga
in
if is_scalar then Bigarray.reshape ga [||] else ga
type packed = P : (_, _, _) Bigarray.Genarray.t -> packed
let read_copy path =
with_fd path [ O_RDONLY ] 0 @@ fun fd ->
let magic' = really_read_fd fd magic_len in
if magic <> magic' then read_error "not a .npy file (bad magic)";
let version = Char.code (really_read_fd fd 2).[0] in
let hdr_len_bytes =
match version with
| 1 -> 2
| 2 -> 4
| v -> read_error "unsupported npy version %d" v
in
let hdr_len_str = really_read_fd fd hdr_len_bytes in
let hdr_len = ref 0 in
for i = String.length hdr_len_str - 1 downto 0 do
hdr_len := (256 * !hdr_len) + Char.code hdr_len_str.[i]
done;
let hdr = parse_header (really_read_fd fd !hdr_len) in
let pos = Int64.of_int (!hdr_len + hdr_len_bytes + magic_len + 2) in
let (K kind) = hdr.kind in
let build layout =
let src = map_or_read fd ~pos kind layout hdr.shape in
let dst =
Nx_buffer.genarray_create kind layout (Nx_buffer.genarray_dims src)
in
Nx_buffer.genarray_blit src dst;
P dst
in
if hdr.fortran_order then build Bigarray.Fortran_layout
else build Bigarray.C_layout
let write ga path =
with_fd path [ O_CREAT; O_TRUNC; O_RDWR ] 0o640 @@ fun fd ->
let kind = Nx_buffer.genarray_kind ga in
let dims = Nx_buffer.genarray_dims ga in
let layout = Bigarray.Genarray.layout ga in
let hdr = encode_header ~layout ~packed_kind:(K kind) ~dims in
let hdr_len = String.length hdr in
if Unix.write_substring fd hdr 0 hdr_len <> hdr_len then
failwith "npy: incomplete header write";
match Nx_buffer.to_stdlib_kind kind with
| Some std_kind ->
let dst =
Unix.map_file fd ~pos:(Int64.of_int hdr_len) std_kind layout true dims
in
Bigarray.Genarray.blit ga dst
| None ->
ignore (Unix.lseek fd hdr_len Unix.SEEK_SET);
write_genarray_to_fd fd ga
module Npz = struct
let npy_suffix = ".npy"
type in_file = Zip.in_file
type out_file = Zip.out_file
let open_in = Zip.open_in
let close_in = Zip.close_in
let open_out = Zip.open_out
let close_out = Zip.close_out
let entries t =
List.map
(fun (entry : Zip.entry) ->
let name = entry.Zip.filename in
let suf_len = String.length npy_suffix in
if
String.length name >= suf_len
&& String.sub name (String.length name - suf_len) suf_len = npy_suffix
then String.sub name 0 (String.length name - suf_len)
else name)
(Zip.entries t)
let read t name =
let entry_name = name ^ npy_suffix in
let entry =
try Zip.find_entry t entry_name with Not_found -> raise Not_found
in
let tmp = Filename.temp_file "npy" ".tmp" in
Fun.protect ~finally:(fun () -> Sys.remove tmp) @@ fun () ->
Zip.copy_entry_to_file t entry tmp;
read_copy tmp
let write t name ga =
let entry_name = name ^ npy_suffix in
let tmp = Filename.temp_file "npy" ".tmp" in
Fun.protect ~finally:(fun () -> Sys.remove tmp) @@ fun () ->
write ga tmp;
Zip.copy_file_to_entry tmp t entry_name
end