1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
(** Short alias for [Printf.sprintf]: [strf fmt args...] renders [fmt]
    with [args] into a fresh string. *)
let strf fmt = Printf.sprintf fmt
(** Existentially packed array. [P nx] wraps an [('a, 'b) Nx.t] and hides
    both of its type parameters. Arrays with different element types can
    therefore be stored side by side, as in [archive]. Use [to_typed] to
    get back a typed view. *)
type t = P : ('a, 'b) Nx.t -> t
(** Mutable hash table from string keys to packed arrays.
    NOTE(review): keys presumably name the individual arrays of an archive;
    confirm against callers. *)
type archive = (string, t) Hashtbl.t
(** [err_dtype_mismatch ~expected ~got] builds the message
    ["dtype mismatch: expected <expected>, got <got>"], where both
    arguments are already-rendered dtype names. *)
let err_dtype_mismatch ~expected ~got =
  let template = format_of_string "dtype mismatch: expected %s, got %s" in
  strf template expected got
(** [to_typed wanted packed] unpacks [packed] and returns the wrapped array
    at the statically known type [(a, b) Nx.t]. The check goes through
    [Nx_core.Dtype.equal_witness]: a type-equality witness comes back only
    when the runtime dtype of the array matches [wanted].
    @raise Failure with the [err_dtype_mismatch] message when the dtypes
    differ. *)
let to_typed : type a b. (a, b) Nx.dtype -> t -> (a, b) Nx.t =
 fun wanted packed ->
  match packed with
  | P arr -> (
      let actual = Nx.dtype arr in
      match Nx_core.Dtype.equal_witness actual wanted with
      | None ->
          (* Render the requested dtype first, then the stored one. *)
          let expected = Nx_core.Dtype.to_string wanted in
          let got = Nx_core.Dtype.to_string actual in
          failwith (err_dtype_mismatch ~expected ~got)
      | Some Type.Equal ->
          (* The witness refines the hidden parameters to [a] and [b]. *)
          (arr : (a, b) Nx.t))
(** [packed_shape p] returns [Nx.shape] of the array held inside [p],
    without needing its element type. *)
let packed_shape = function P arr -> Nx.shape arr