1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
open Error
open Packed_nx
(** [strf fmt ...] is a short local alias for [Printf.sprintf]: it formats
    its arguments according to [fmt] and returns the resulting string. *)
let strf fmt = Printf.sprintf fmt
(** [npy_to_nx packed] turns an array produced by the [Npy] reader into a
    packed Nx tensor ([P ...]).

    The genarray is first converted to C layout with
    [Nx_buffer.genarray_change_layout]. The dimensions of that C-layout
    array become the tensor's shape, and its storage becomes the tensor's
    buffer.

    NOTE(review): if [Nx_buffer.genarray_change_layout] wraps
    [Bigarray.Genarray.change_layout], a Fortran-ordered input keeps its data
    in place and only has its dimensions reversed. The resulting tensor would
    then be the transpose of the stored array — confirm how Fortran-ordered
    .npy files are meant to be handled. *)
let npy_to_nx (Npy.P raw) =
  let c_order = Nx_buffer.genarray_change_layout raw Bigarray.C_layout in
  let dims = Nx_buffer.genarray_dims c_order in
  let storage = Nx_buffer.of_genarray c_order in
  P (Nx.of_buffer storage ~shape:dims)
(** [wrap_exn f] runs [f ()] and turns the exceptions it raises into an
    [Error] result:
    - [Npy.Read_error] and [Failure] become [Format_error];
    - [Zip.Error], [Unix.Unix_error] and [Sys_error] become [Io_error];
    - any other exception becomes [Other], carrying its printed form.

    [Out_of_memory], [Stack_overflow] and [Sys.Break] are re-raised instead.
    They describe the state of the runtime or a user interrupt, not a problem
    with the file being read or written, so reporting them as an ordinary
    [Error] would hide them from the caller's normal crash/interrupt
    handling. *)
let wrap_exn f =
  try f () with
  | (Out_of_memory | Stack_overflow | Sys.Break) as fatal -> raise fatal
  | Npy.Read_error msg -> Error (Format_error msg)
  | Zip.Error (archive, "", msg) ->
      (* camlzip's payload is (archive name, entry name, message); the entry
         name is empty for archive-level errors, so omit it entirely. *)
      Error (Io_error (strf "zip: %s: %s" archive msg))
  | Zip.Error (archive, entry, msg) ->
      Error (Io_error (strf "zip: %s (entry %s): %s" archive entry msg))
  | Unix.Unix_error (e, _, _) -> Error (Io_error (Unix.error_message e))
  | Sys_error msg -> Error (Io_error msg)
  | Failure msg -> Error (Format_error msg)
  | ex -> Error (Other (Printexc.to_string ex))
(** [check_overwrite overwrite path] does nothing when [overwrite] is [true]
    or when [path] does not exist. Otherwise it raises
    [Failure "file already exists: <path>"].

    [Sys.file_exists] is consulted only when [overwrite] is [false].
    NOTE(review): the check and the later open are not atomic, so a file
    created in between will still be overwritten. *)
let check_overwrite overwrite path =
  if overwrite || not (Sys.file_exists path) then ()
  else failwith (strf "file already exists: %s" path)
(** [load_npy path] reads the [.npy] file at [path] with [Npy.read_copy] and
    returns it as a packed Nx tensor. Exceptions raised while reading are
    converted to [Error] by [wrap_exn]. *)
let load_npy path =
  wrap_exn (fun () ->
      let packed = Npy.read_copy path in
      Ok (npy_to_nx packed))
(** [save_npy ?overwrite path arr] writes [arr] to [path] in [.npy] format.

    @param overwrite defaults to [true]. When [false] and [path] already
    exists, nothing is written and [Error] is returned.
    Exceptions raised during the write are converted to [Error] by
    [wrap_exn]. *)
let save_npy ?(overwrite = true) path arr =
  let write () =
    check_overwrite overwrite path;
    let data = Nx.to_buffer arr in
    let dims = Nx.shape arr in
    let genarray = Nx_buffer.to_genarray data dims in
    Npy.write genarray path;
    Ok ()
  in
  wrap_exn write
(** [load_npz path] reads every entry of the [.npz] archive at [path]. The
    result is a hash table mapping each name returned by [Npy.Npz.entries] to
    the corresponding packed Nx tensor.

    The archive is closed whether reading succeeds or fails. Exceptions are
    converted to [Error] by [wrap_exn]. *)
let load_npz path =
  wrap_exn (fun () ->
      let archive_in = Npy.Npz.open_in path in
      let read_all () =
        let names = Npy.Npz.entries archive_in in
        let table = Hashtbl.create (List.length names) in
        let add_entry name =
          let tensor = npy_to_nx (Npy.Npz.read archive_in name) in
          Hashtbl.add table name tensor
        in
        List.iter add_entry names;
        Ok table
      in
      Fun.protect ~finally:(fun () -> Npy.Npz.close_in archive_in) read_all)
(** [load_npz_entry ~name path] reads the single entry [name] from the [.npz]
    archive at [path] and returns it as a packed Nx tensor.

    Returns [Error (Missing_entry name)] when [Npy.Npz.read] raises
    [Not_found] for [name]. The archive is closed in every case, and other
    exceptions are converted to [Error] by [wrap_exn]. *)
let load_npz_entry ~name path =
  wrap_exn (fun () ->
      let archive_in = Npy.Npz.open_in path in
      let lookup () =
        (* Only the read itself is guarded for [Not_found]; the conversion
           runs in the value branch, outside that handler. *)
        match Npy.Npz.read archive_in name with
        | exception Not_found -> Error (Missing_entry name)
        | entry -> Ok (npy_to_nx entry)
      in
      Fun.protect ~finally:(fun () -> Npy.Npz.close_in archive_in) lookup)
(** [save_npz ?overwrite path items] writes each [(name, P tensor)] pair of
    [items], in list order, as an entry of a [.npz] archive at [path].

    @param overwrite defaults to [true]. When [false] and [path] already
    exists, nothing is written and [Error] is returned.
    The output archive is closed even if writing an entry fails. Exceptions
    are converted to [Error] by [wrap_exn].
    NOTE(review): because the archive is still closed after a failure, a
    partially written archive is presumably left at [path] — confirm whether
    callers expect that. *)
let save_npz ?(overwrite = true) path items =
  wrap_exn (fun () ->
      check_overwrite overwrite path;
      let archive_out = Npy.Npz.open_out path in
      let write_entry (name, P tensor) =
        let data = Nx.to_buffer tensor in
        let dims = Nx.shape tensor in
        Npy.Npz.write archive_out name (Nx_buffer.to_genarray data dims)
      in
      Fun.protect
        ~finally:(fun () -> Npy.Npz.close_out archive_out)
        (fun () ->
          List.iter write_entry items;
          Ok ()))