1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
open Bigarray_ext
open Error
(* Re-export of [Packed_nx.t]. The constructor [P] packs an [Nx.t] of any
   element/kind type parameters and hides both of them, so arrays of
   different dtypes can be stored side by side and recovered later with the
   [as_*] conversions. *)
type packed_nx = Packed_nx.t = P : ('a, 'b) Nx.t -> packed_nx
(* A mutable hash table of packed arrays keyed by string. Presumably the keys
   are the entry names of a multi-array file (e.g. .npz members or safetensors
   tensor names) — NOTE(review): confirm against Nx_npy / Nx_safetensors. *)
type archive = (string, packed_nx) Hashtbl.t
(* Result-returning variants of the image, .npy/.npz, dtype-conversion and
   safetensors functions: failures surface as [Error e] instead of being
   raised. *)
module Safe = struct
  (** Re-export of [Error.t] so its constructors are reachable as
      [Safe.Io_error], [Safe.Format_error], etc. *)
  type error = Error.t =
    | Io_error of string
    | Format_error of string
    | Unsupported_dtype
    | Unsupported_shape
    | Missing_entry of string
    | Other of string

  (** Image geometry: [`Gray (height, width)] for rank-2 arrays and
      [`Color (height, width, channels)] for rank-3 arrays. *)
  type nx_dims = [ `Gray of int * int | `Color of int * int * int ]

  (** [get_nx_dims arr] classifies [arr] by rank. Any rank other than 2 or 3
      raises (through [fail_msg]) with a message showing the offending
      shape. *)
  let get_nx_dims arr : nx_dims =
    match Nx.shape arr with
    | [| h; w |] -> `Gray (h, w)
    | [| h; w; c |] -> `Color (h, w, c)
    | s ->
        fail_msg "Invalid nx dimensions: expected 2 or 3, got %d (%s)"
          (Array.length s)
          (Array.to_list s |> List.map string_of_int |> String.concat "x")

  (** [load_image ?grayscale path] decodes the image at [path] with
      stb_image. When [grayscale] is [true] (default [false]) the result has
      shape [[|h; w|]]. Otherwise stb_image converts the pixels to 3
      components and the result has shape [[|h; w; 3|]].

      Returns [Error (Format_error _)] when stb_image cannot decode the file,
      [Error (Io_error _)] on [Sys_error], and [Error (Other _)] for any other
      exception. *)
  let load_image ?(grayscale = false) path =
    let desired_channels = if grayscale then 1 else 3 in
    try
      match Stb_image.load ~channels:desired_channels path with
      | Ok img ->
          let h = Stb_image.height img in
          let w = Stb_image.width img in
          (* stb_image lays the pixels out with the *requested* number of
             components. Its reported component count may instead describe
             the file's native layout. The reshape must therefore follow the
             request, or e.g. an RGBA file loaded as RGB would be reshaped
             with 4 channels over a 3-channel buffer. *)
          let c = desired_channels in
          let nd =
            Nx.of_bigarray_ext (genarray_of_array1 (Stb_image.data img))
          in
          let shape = if c = 1 then [| h; w |] else [| h; w; c |] in
          Ok (Nx.reshape shape nd)
      | Error (`Msg msg) -> Error (Format_error msg)
    with
    | Sys_error msg -> Error (Io_error msg)
    | ex -> Error (Other (Printexc.to_string ex))

  (** [save_image ?overwrite path img] encodes [img] with stb_image_write.
      The format is chosen from [path]'s extension (case-insensitive):
      [.png], [.bmp], [.tga], or [.jpg]/[.jpeg] (JPEG quality 90).

      [img] must be rank 2 (grayscale) or rank 3 with 1 to 4 channels.
      Other channel counts yield [Error Unsupported_shape] and nothing is
      written. An unrecognised extension yields [Error (Format_error _)].
      When [overwrite] is [false] (default [true]) and [path] already exists,
      nothing is written and [Error (Io_error _)] is returned. Exceptions
      raised while converting or writing are mapped to [Io_error] /
      [Other]. *)
  let save_image ?(overwrite = true) path img =
    (* Lowercase extension -> encoder tag; resolved before any conversion
       work so unsupported formats are reported up front. *)
    let format_of_extension = function
      | ".png" -> Some `Png
      | ".bmp" -> Some `Bmp
      | ".tga" -> Some `Tga
      | ".jpg" | ".jpeg" -> Some `Jpg
      | _ -> None
    in
    try
      let extension = Filename.extension path |> String.lowercase_ascii in
      if (not overwrite) && Sys.file_exists path then
        Error (Io_error (Printf.sprintf "File '%s' already exists" path))
      else
        match format_of_extension extension with
        | None ->
            Error
              (Format_error
                 (Printf.sprintf
                    "Unsupported image format: '%s'. Use .png, .bmp, .tga, .jpg"
                    extension))
        | Some format ->
            let h, w, c =
              match get_nx_dims img with
              | `Gray (h, w) -> (h, w, 1)
              | `Color (h, w, c) -> (h, w, c)
            in
            (* stb_image_write only understands 1 (gray), 2 (gray+alpha),
               3 (RGB) and 4 (RGBA) components per pixel; never hand it
               anything else. *)
            if c < 1 || c > 4 then Error Unsupported_shape
            else
              let data_gen = Nx.to_bigarray_ext img in
              (* The lone [Int8_unsigned] case constrains the element type,
                 so (presumably, via [to_bigarray_ext]'s signature) only
                 uint8 images type-check here. *)
              let data =
                match Genarray.kind data_gen with
                | Int8_unsigned -> array1_of_genarray data_gen
              in
              (match format with
               | `Png -> Stb_image_write.png path ~w ~h ~c data
               | `Bmp -> Stb_image_write.bmp path ~w ~h ~c data
               | `Tga -> Stb_image_write.tga path ~w ~h ~c data
               | `Jpg -> Stb_image_write.jpg path ~w ~h ~c ~quality:90 data);
              Ok ()
    with
    | Sys_error msg -> Error (Io_error msg)
    | Invalid_argument msg -> Error (Other msg)
    | Failure msg -> Error (Other msg)
    | ex -> Error (Other (Printexc.to_string ex))

  (** Thin wrapper over [Nx_npy.load_npy]. *)
  let load_npy path = Nx_npy.load_npy path

  (** Thin wrapper over [Nx_npy.save_npy]; [overwrite] defaults to [true]. *)
  let save_npy ?(overwrite = true) path arr =
    Nx_npy.save_npy ~overwrite path arr

  (** Thin wrapper over [Nx_npy.load_npz]. *)
  let load_npz path = Nx_npy.load_npz path

  (** Thin wrapper over [Nx_npy.load_npz_member]. *)
  let load_npz_member ~name path = Nx_npy.load_npz_member ~name path

  (** Thin wrapper over [Nx_npy.save_npz]; [overwrite] defaults to [true]. *)
  let save_npz ?(overwrite = true) path items =
    Nx_npy.save_npz ~overwrite path items

  (* Result-returning dtype conversions, re-exported from [Packed_nx]. *)
  let as_float16 = Packed_nx.as_float16
  let as_float32 = Packed_nx.as_float32
  let as_float64 = Packed_nx.as_float64
  let as_int8 = Packed_nx.as_int8
  let as_int16 = Packed_nx.as_int16
  let as_int32 = Packed_nx.as_int32
  let as_int64 = Packed_nx.as_int64
  let as_uint8 = Packed_nx.as_uint8
  let as_uint16 = Packed_nx.as_uint16
  let as_complex32 = Packed_nx.as_complex32
  let as_complex64 = Packed_nx.as_complex64

  (** Thin wrapper over [Nx_safetensors.load_safetensor]. *)
  let load_safetensor path = Nx_safetensors.load_safetensor path

  (** Thin wrapper over [Nx_safetensors.save_safetensor]. [overwrite] is
      forwarded as-is, so its default is whatever that function uses. *)
  let save_safetensor ?overwrite path items =
    Nx_safetensors.save_safetensor ?overwrite path items
end
(* [unwrap_result r] returns [v] when [r] is [Ok v]. When [r] is [Error err]
   it raises [Failure] with the message [Error.to_string err]. This is the
   bridge from the [Safe] result API to the raising API. *)
let unwrap_result result =
  match result with
  | Error err -> failwith (Error.to_string err)
  | Ok value -> value
(* Raising counterparts of the [Packed_nx.as_*] conversions. Each returns the
   typed array when [Packed_nx] succeeds and raises [Failure] (through
   [unwrap_result]) when it reports an error — presumably a dtype mismatch
   between the packed array and the requested type. *)
let as_float16 p = unwrap_result (Packed_nx.as_float16 p)
let as_float32 p = unwrap_result (Packed_nx.as_float32 p)
let as_float64 p = unwrap_result (Packed_nx.as_float64 p)
let as_int8 p = unwrap_result (Packed_nx.as_int8 p)
let as_int16 p = unwrap_result (Packed_nx.as_int16 p)
let as_int32 p = unwrap_result (Packed_nx.as_int32 p)
let as_int64 p = unwrap_result (Packed_nx.as_int64 p)
let as_uint8 p = unwrap_result (Packed_nx.as_uint8 p)
let as_uint16 p = unwrap_result (Packed_nx.as_uint16 p)
let as_complex32 p = unwrap_result (Packed_nx.as_complex32 p)
let as_complex64 p = unwrap_result (Packed_nx.as_complex64 p)
(* Raising counterparts of the [Safe] I/O functions. Each forwards its
   arguments (including optional ones) unchanged to the [Safe] version and
   turns an [Error] into a [Failure] exception via [unwrap_result]. *)
let load_image ?grayscale path =
  unwrap_result (Safe.load_image ?grayscale path)

let save_image ?overwrite path img =
  unwrap_result (Safe.save_image ?overwrite path img)

let load_npy path = unwrap_result (Safe.load_npy path)

let save_npy ?overwrite path arr =
  unwrap_result (Safe.save_npy ?overwrite path arr)

let load_npz path = unwrap_result (Safe.load_npz path)

let load_npz_member ~name path =
  unwrap_result (Safe.load_npz_member ~name path)

let save_npz ?overwrite path items =
  unwrap_result (Safe.save_npz ?overwrite path items)

let load_safetensor path = unwrap_result (Safe.load_safetensor path)

let save_safetensor ?overwrite path items =
  unwrap_result (Safe.save_safetensor ?overwrite path items)