Source file checkpoint.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
(** [invalid_argf fmt ...] formats its arguments with the [Printf] format
    [fmt], then raises [Invalid_argument] carrying the formatted message.
    It never returns normally, so it can be used in any expression
    position. *)
let invalid_argf fmt =
  (* Keep [fmt] as an explicit parameter so the binding stays fully
     polymorphic (no value restriction). *)
  Printf.ksprintf (fun message -> raise (Invalid_argument message)) fmt
(** [shape_to_string s] renders the dimensions of shape [s] as a bracketed,
    ["; "]-separated list. For example, dims 2 and 3 become ["[2; 3]"], and
    the empty shape becomes ["[]"]. *)
let shape_to_string s =
  let dims = s |> Array.to_list |> List.map string_of_int in
  Printf.sprintf "[%s]" (String.concat "; " dims)
(** [save path tree] writes [tree] to the safetensors file at [path].

    The steps are:
    - flatten [tree] into (name, leaf) pairs with [Ptree.flatten_with_paths];
    - wrap each leaf's tensor in an [Nx_io.P] value;
    - write all entries with [Nx_io.save_safetensors]. *)
let save path tree =
  (* Convert one (name, leaf) pair into a (name, packed tensor) entry. The
     record literal stays inline in the call so the [run] field is resolved
     from [Ptree.with_tensor]'s expected argument type. *)
  let pack_leaf (name, leaf) =
    let packed =
      Ptree.with_tensor leaf { run = (fun tensor -> Nx_io.P tensor) }
    in
    (name, packed)
  in
  let entries = List.map pack_leaf (Ptree.flatten_with_paths tree) in
  Nx_io.save_safetensors path entries
(** [load path ~like] reads the safetensors file at [path] and rebuilds a
    tree with the same structure as [like].

    Each leaf of [like] is looked up by its path name (as produced by
    [Ptree.flatten_with_paths]). The stored tensor must have exactly the
    same shape as the template leaf. It is then cast to the template
    leaf's dtype. Archive entries that [like] does not name are ignored.

    @raise Invalid_argument if a key is missing from the file or a stored
    shape differs from the template's. *)
let load path ~like =
  let archive = Nx_io.load_safetensors path in
  (* NOTE(review): assumes [Ptree.flatten] and [Ptree.flatten_with_paths]
     enumerate leaves in the same order, since [rebuild] consumes the list
     produced from the latter. Confirm in Ptree. *)
  let _, rebuild = Ptree.flatten like in
  (* Reject any stored tensor whose shape is not exactly the template's. *)
  let check_shape name ~expected ~actual =
    if expected <> actual then
      invalid_argf
        "Checkpoint.load: shape mismatch for %S: expected %s, got %s" name
        (shape_to_string expected) (shape_to_string actual)
  in
  (* Turn one (path name, template leaf) pair into the restored leaf. *)
  let restore_leaf (name, template) =
    match Hashtbl.find_opt archive name with
    | Some (Nx_io.P stored) ->
        Ptree.with_tensor template
          {
            run =
              (fun tmpl ->
                check_shape name ~expected:(Nx.shape tmpl)
                  ~actual:(Nx.shape stored);
                Ptree.P (Nx.cast (Nx.dtype tmpl) stored));
          }
    | None -> invalid_argf "Checkpoint.load: missing key %S" name
  in
  rebuild (List.map restore_leaf (Ptree.flatten_with_paths like))