Source file nx_safetensors.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
open Bigarray_ext
open Error
open Packed_nx
(** [load_safetensor path] reads the safetensors file at [path] and decodes
    its tensors into a hash table keyed by tensor name.

    Tensors with dtype F32, F64, I32, F16 and BF16 are decoded from their
    little-endian bytes into freshly allocated bigarrays, wrapped as packed
    Nx arrays and reshaped to the stored shape. Tensors with any other dtype
    are skipped and a warning is printed on stderr.

    @return [Ok table] on success; [Error (Format_error _)] when
    [Safetensors.deserialize] rejects the file contents;
    [Error (Io_error _)] when a [Sys_error] is raised (e.g. the file cannot
    be opened); [Error (Other _)] for any other exception, such as an
    out-of-bounds read on a truncated tensor payload. *)
let load_safetensor path =
  (* Unsigned 16-bit little-endian value stored at byte [pos] of [s]. *)
  let u16_le_at s pos = Char.code s.[pos] lor (Char.code s.[pos + 1] lsl 8) in
  (* 32-bit little-endian value stored at byte [pos] of [s]. *)
  let i32_le_at s pos =
    Int32.logor
      (Int32.shift_left (Int32.of_int (u16_le_at s (pos + 2))) 16)
      (Int32.of_int (u16_le_at s pos))
  in
  (* Decode an IEEE 754 binary16 bit pattern (1 sign, 5 exponent, 10
     fraction bits). The conversion is exact because every binary16 value
     is representable as an OCaml float. *)
  let float_of_f16_bits h =
    let sign = if h land 0x8000 = 0 then 1.0 else -1.0 in
    let exponent = (h lsr 10) land 0x1f in
    let fraction = h land 0x3ff in
    if exponent = 0 then
      (* Zero and subnormals: fraction * 2^-24 (keeps the sign of zero). *)
      sign *. ldexp (float_of_int fraction) (-24)
    else if exponent = 0x1f then
      (if fraction = 0 then sign *. infinity else nan)
    else
      (* Normal: (1024 + fraction) * 2^(exponent - 15 - 10). *)
      sign *. ldexp (float_of_int (fraction lor 0x400)) (exponent - 25)
  in
  (* A bfloat16 is the upper 16 bits of an IEEE 754 binary32. *)
  let float_of_bf16_bits h =
    Int32.float_of_bits (Int32.shift_left (Int32.of_int h) 16)
  in
  try
    let ic = open_in_bin path in
    (* Close the channel even when reading fails part-way through. *)
    let buffer =
      Fun.protect
        ~finally:(fun () -> close_in_noerr ic)
        (fun () -> really_input_string ic (in_channel_length ic))
    in
    match Safetensors.deserialize buffer with
    | Error err -> Error (Format_error (Safetensors.string_of_error err))
    | Ok safetensors ->
        let tensors = Safetensors.tensors safetensors in
        let result = Hashtbl.create (List.length tensors) in
        List.iter
          (fun (name, (view : Safetensors.tensor_view)) ->
            let open Safetensors in
            let shape = Array.of_list view.shape in
            let num_elems = Array.fold_left ( * ) 1 shape in
            let data = view.data and base = view.offset in
            (* Wrap a filled 1-D bigarray as an Nx array of the stored
               shape and register it under the tensor's name. *)
            let store ba =
              let nx_arr = Nx.of_bigarray_ext (genarray_of_array1 ba) in
              Hashtbl.replace result name (P (Nx.reshape shape nx_arr))
            in
            match view.dtype with
            | F32 ->
                let ba = Array1.create Float32 c_layout num_elems in
                for i = 0 to num_elems - 1 do
                  Array1.unsafe_set ba i
                    (Int32.float_of_bits (i32_le_at data (base + (i * 4))))
                done;
                store ba
            | F64 ->
                let ba = Array1.create Float64 c_layout num_elems in
                for i = 0 to num_elems - 1 do
                  Array1.unsafe_set ba i
                    (Int64.float_of_bits
                       (Safetensors.read_u64_le data (base + (i * 8))))
                done;
                store ba
            | I32 ->
                let ba = Array1.create Int32 c_layout num_elems in
                for i = 0 to num_elems - 1 do
                  Array1.unsafe_set ba i (i32_le_at data (base + (i * 4)))
                done;
                store ba
            | F16 ->
                let ba = Array1.create Float16 c_layout num_elems in
                for i = 0 to num_elems - 1 do
                  Array1.unsafe_set ba i
                    (float_of_f16_bits (u16_le_at data (base + (i * 2))))
                done;
                store ba
            | BF16 ->
                let ba = Array1.create Bfloat16 c_layout num_elems in
                for i = 0 to num_elems - 1 do
                  Array1.unsafe_set ba i
                    (float_of_bf16_bits (u16_le_at data (base + (i * 2))))
                done;
                store ba
            | _ ->
                Printf.eprintf
                  "Warning: Skipping tensor '%s' with unsupported dtype %s\n"
                  name
                  (Safetensors.dtype_to_string view.dtype))
          tensors;
        Ok result
  with
  | Sys_error msg -> Error (Io_error msg)
  | ex -> Error (Other (Printexc.to_string ex))
(** [save_safetensor ?overwrite path items] serialises the named packed Nx
    arrays in [items] to a safetensors file at [path].

    Arrays whose bigarray kind is Float32, Float64, Int32, Float16 or
    Bfloat16 are flattened and written as little-endian bytes with the
    matching safetensors dtype (F32, F64, I32, F16, BF16). Any other kind
    aborts the save through [fail_msg].

    @param overwrite when [false] and [path] already exists, nothing is
    written and [Error (Io_error _)] is returned. Defaults to [true].
    @return [Ok ()] on success; [Error (Format_error _)] when
    [Safetensors.serialize_to_file] fails; [Error (Io_error _)] on
    [Sys_error]; [Error (Other _)] for any other exception raised while
    building the tensor views (presumably including those from [fail_msg]
    — confirm against the [Error] module). *)
let save_safetensor ?(overwrite = true) path items =
  (* Encode a float as an IEEE 754 binary16 bit pattern, going through its
     binary32 representation and rounding to nearest, ties to even. *)
  let f16_bits_of_float x =
    let bits = Int32.bits_of_float x in
    let sign = Int32.to_int (Int32.shift_right_logical bits 16) land 0x8000 in
    let exponent = Int32.to_int (Int32.shift_right_logical bits 23) land 0xff in
    let mantissa = Int32.to_int (Int32.logand bits 0x7fffffl) in
    if exponent = 0xff then
      (* Infinity, or NaN (forced quiet so the fraction stays non-zero). *)
      sign lor 0x7c00 lor (if mantissa <> 0 then 0x200 else 0)
    else
      (* Re-bias the exponent from binary32 (127) to binary16 (15). *)
      let e = exponent - 112 in
      if e >= 0x1f then sign lor 0x7c00 (* too large: infinity *)
      else if e <= 0 then
        if e < -10 then sign (* below half the smallest subnormal: zero *)
        else begin
          (* Subnormal result: shift the 24-bit significand into place. *)
          let significand = mantissa lor 0x800000 in
          let shift = 14 - e in
          let kept = significand lsr shift in
          let dropped = significand land ((1 lsl shift) - 1) in
          let halfway = 1 lsl (shift - 1) in
          let round_up =
            dropped > halfway || (dropped = halfway && kept land 1 = 1)
          in
          sign lor (if round_up then kept + 1 else kept)
        end
      else begin
        let kept = (e lsl 10) lor (mantissa lsr 13) in
        let dropped = mantissa land 0x1fff in
        let round_up =
          dropped > 0x1000 || (dropped = 0x1000 && kept land 1 = 1)
        in
        (* A carry out of the fraction bumps the exponent, which is the
           correct rounded encoding (up to and including infinity). *)
        sign lor (if round_up then kept + 1 else kept)
      end
  in
  (* Encode a float as a bfloat16 bit pattern: the upper 16 bits of its
     binary32 representation, rounded to nearest, ties to even. *)
  let bf16_bits_of_float x =
    let bits = Int32.bits_of_float x in
    let hi = Int32.to_int (Int32.shift_right_logical bits 16) land 0xffff in
    let lo = Int32.to_int (Int32.logand bits 0xffffl) in
    if x <> x then hi lor 0x40 (* NaN: keep it a NaN after truncation *)
    else if lo > 0x8000 || (lo = 0x8000 && hi land 1 = 1) then
      (hi + 1) land 0xffff
    else hi
  in
  try
    if (not overwrite) && Sys.file_exists path then
      Error (Io_error (Printf.sprintf "File '%s' already exists" path))
    else
      let tensor_views =
        List.map
          (fun (name, P arr) ->
            let shape = Array.to_list (Nx.shape arr) in
            let ba = Nx.to_bigarray_ext arr in
            let num_elems = Array.fold_left ( * ) 1 (Nx.shape arr) in
            let dtype, data =
              match Genarray.kind ba with
              | Float32 ->
                  let bytes = Bytes.create (num_elems * 4) in
                  let ba1 =
                    array1_of_genarray (Nx.to_bigarray_ext (Nx.flatten arr))
                  in
                  for i = 0 to num_elems - 1 do
                    Bytes.set_int32_le bytes (i * 4)
                      (Int32.bits_of_float (Array1.unsafe_get ba1 i))
                  done;
                  (Safetensors.F32, Bytes.unsafe_to_string bytes)
              | Float64 ->
                  let bytes = Bytes.create (num_elems * 8) in
                  let ba1 =
                    array1_of_genarray (Nx.to_bigarray_ext (Nx.flatten arr))
                  in
                  for i = 0 to num_elems - 1 do
                    Safetensors.write_u64_le bytes (i * 8)
                      (Int64.bits_of_float (Array1.unsafe_get ba1 i))
                  done;
                  (Safetensors.F64, Bytes.unsafe_to_string bytes)
              | Int32 ->
                  let bytes = Bytes.create (num_elems * 4) in
                  let ba1 =
                    array1_of_genarray (Nx.to_bigarray_ext (Nx.flatten arr))
                  in
                  for i = 0 to num_elems - 1 do
                    Bytes.set_int32_le bytes (i * 4) (Array1.unsafe_get ba1 i)
                  done;
                  (Safetensors.I32, Bytes.unsafe_to_string bytes)
              | Float16 ->
                  let bytes = Bytes.create (num_elems * 2) in
                  let ba1 =
                    array1_of_genarray (Nx.to_bigarray_ext (Nx.flatten arr))
                  in
                  for i = 0 to num_elems - 1 do
                    Bytes.set_uint16_le bytes (i * 2)
                      (f16_bits_of_float (Array1.unsafe_get ba1 i))
                  done;
                  (Safetensors.F16, Bytes.unsafe_to_string bytes)
              | Bfloat16 ->
                  let bytes = Bytes.create (num_elems * 2) in
                  let ba1 =
                    array1_of_genarray (Nx.to_bigarray_ext (Nx.flatten arr))
                  in
                  for i = 0 to num_elems - 1 do
                    Bytes.set_uint16_le bytes (i * 2)
                      (bf16_bits_of_float (Array1.unsafe_get ba1 i))
                  done;
                  (Safetensors.BF16, Bytes.unsafe_to_string bytes)
              | _ ->
                  fail_msg "Unsupported dtype for safetensors: %s"
                    (Nx_core.Dtype.of_bigarray_ext_kind (Genarray.kind ba)
                    |> Nx_core.Dtype.to_string)
            in
            match Safetensors.tensor_view_new ~dtype ~shape ~data with
            | Ok view -> (name, view)
            | Error err ->
                fail_msg "Failed to create tensor view for '%s': %s" name
                  (Safetensors.string_of_error err))
          items
      in
      match Safetensors.serialize_to_file tensor_views None path with
      | Ok () -> Ok ()
      | Error err -> Error (Format_error (Safetensors.string_of_error err))
  with
  | Sys_error msg -> Error (Io_error msg)
  | ex -> Error (Other (Printexc.to_string ex))