1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
(** Existentially packed array: the single constructor [P] wraps an
    [('a, 'b) Nx.t] and hides both of its type parameters, so arrays whose
    static types differ can all be handled as values of type [t]. *)
type t = P : ('a, 'b) Nx.t -> t
(** Hash table from string keys to packed arrays.
    NOTE(review): the keys are presumably the entry names of a multi-array
    archive file -- confirm against the code that fills the table. *)
type archive = (string, t) Hashtbl.t
(** [convert_result_with_error target_dtype packed] tries to recover the array
    wrapped in [packed] at the static type described by [target_dtype].

    The bigarray kind derived from the array's own dtype is compared with the
    one derived from [target_dtype] using [Npy.Eq.Kind.( === )]. When that
    comparison yields the witness [Npy.Eq.W], the very same array is returned
    under [Ok]; this function performs no conversion or copy of its own.
    Otherwise the result is [Error Unsupported_dtype]. *)
let convert_result_with_error : type a b.
    (a, b) Nx.dtype -> t -> ((a, b) Nx.t, Error.t) result =
 fun target_dtype (P nx) ->
  (* Kinds are computed in the same order as before: array first, target
     second. *)
  let actual_kind = Nx_core.Dtype.to_bigarray_ext_kind (Nx.dtype nx) in
  let wanted_kind = Nx_core.Dtype.to_bigarray_ext_kind target_dtype in
  match Npy.Eq.Kind.( === ) actual_kind wanted_kind with
  | None -> Error Unsupported_dtype
  | Some Npy.Eq.W ->
      (* Matching on the witness is what lets the type checker accept [nx]
         at type [(a, b) Nx.t]. *)
      Ok nx
(** [as_float16 packed] is [convert_result_with_error Nx.float16 packed]:
    [Ok] of the wrapped array when it checks against [Nx.float16], an [Error]
    otherwise. *)
let as_float16 = convert_result_with_error Nx.float16

(** [as_float32 packed] is [convert_result_with_error Nx.float32 packed]:
    [Ok] of the wrapped array when it checks against [Nx.float32], an [Error]
    otherwise. *)
let as_float32 = convert_result_with_error Nx.float32

(** [as_float64 packed] is [convert_result_with_error Nx.float64 packed]:
    [Ok] of the wrapped array when it checks against [Nx.float64], an [Error]
    otherwise. *)
let as_float64 = convert_result_with_error Nx.float64
(** [as_int8 packed] is [convert_result_with_error Nx.int8 packed]: [Ok] of
    the wrapped array when it checks against [Nx.int8], an [Error] otherwise. *)
let as_int8 = convert_result_with_error Nx.int8

(** [as_int16 packed] is [convert_result_with_error Nx.int16 packed]: [Ok] of
    the wrapped array when it checks against [Nx.int16], an [Error]
    otherwise. *)
let as_int16 = convert_result_with_error Nx.int16

(** [as_int32 packed] is [convert_result_with_error Nx.int32 packed]: [Ok] of
    the wrapped array when it checks against [Nx.int32], an [Error]
    otherwise. *)
let as_int32 = convert_result_with_error Nx.int32

(** [as_int64 packed] is [convert_result_with_error Nx.int64 packed]: [Ok] of
    the wrapped array when it checks against [Nx.int64], an [Error]
    otherwise. *)
let as_int64 = convert_result_with_error Nx.int64
(** [as_uint8 packed] is [convert_result_with_error Nx.uint8 packed]: [Ok] of
    the wrapped array when it checks against [Nx.uint8], an [Error]
    otherwise. *)
let as_uint8 = convert_result_with_error Nx.uint8

(** [as_uint16 packed] is [convert_result_with_error Nx.uint16 packed]: [Ok]
    of the wrapped array when it checks against [Nx.uint16], an [Error]
    otherwise. *)
let as_uint16 = convert_result_with_error Nx.uint16
(** [as_complex32 packed] is [convert_result_with_error Nx.complex32 packed]:
    [Ok] of the wrapped array when it checks against [Nx.complex32], an
    [Error] otherwise. *)
let as_complex32 = convert_result_with_error Nx.complex32

(** [as_complex64 packed] is [convert_result_with_error Nx.complex64 packed]:
    [Ok] of the wrapped array when it checks against [Nx.complex64], an
    [Error] otherwise. *)
let as_complex64 = convert_result_with_error Nx.complex64