Source file checkpointing.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
open Base
(** [latest_index_and_filename ~checkpoint_base] lists the directory that
    contains [checkpoint_base] and looks for files named [<basename>.<N>].
    Here [<basename>] is the last path component of [checkpoint_base] and
    [<N>] is a plain decimal integer (digits only).

    It returns [Some (n, path)] for the largest such [n], where [path] is the
    file name joined onto that directory. It returns [None] when no file
    matches. Files like [<basename>.tmp] or [<basename>.final] are ignored.
    If two files parse to the same index (e.g. [.5] and [.05]), the one whose
    path compares greater wins.

    @raise Sys_error if the directory cannot be read (from [Sys.readdir]). *)
let latest_index_and_filename ~checkpoint_base =
  let dirname = Caml.Filename.dirname checkpoint_base in
  let prefix = Caml.Filename.basename checkpoint_base ^ "." in
  (* Checkpoints written by [loop] use [Int.to_string] of a non-negative
     index. [Int.of_string] alone would also accept forms such as "-1", "+3",
     "0x1f" or "1_0", which would make unrelated files look like
     checkpoints. *)
  let parse_index suffix =
    if String.is_empty suffix || not (String.for_all suffix ~f:Char.is_digit)
    then None
    else (
      (* An all-digit string can still overflow [int]. *)
      match Int.of_string suffix with
      | index -> Some index
      | exception Failure _ -> None)
  in
  (* Total order: by index first, then by path. Equivalent to the
     polymorphic compare on [(int * string)] that was used before. *)
  let compare_entries (index1, path1) (index2, path2) =
    match Int.compare index1 index2 with
    | 0 -> String.compare path1 path2
    | c -> c
  in
  Caml.Sys.readdir dirname
  |> Array.to_list
  |> List.filter_map ~f:(fun filename ->
    match String.chop_prefix filename ~prefix with
    | None -> None
    | Some suffix ->
      Option.map (parse_index suffix) ~f:(fun index ->
        index, Caml.Filename.concat dirname filename))
  |> List.max_elt ~compare:compare_entries
(** [loop ~start_index ~end_index ~var_stores ~checkpoint_base ?only_keep
    ?checkpoint_every f] calls [f ~index] for every [index] from the start
    index up to [end_index] inclusive. Along the way it periodically saves
    all variables of [var_stores] to [<checkpoint_base>.<index>].

    - Resuming: if a [<checkpoint_base>.<N>] file already exists, the one
      with the largest [N] is loaded into the variables and iteration resumes
      at [N + 1]. [start_index] is then ignored. A message is printed on
      stderr.
    - Tensor names: with exactly one var store, tensors are saved under their
      own names. Otherwise each name is prefixed with
      ["<var store name>:"].
    - [checkpoint_every] (default [`seconds 600.]) decides when to save after
      an iteration:
      - [`seconds s] saves once more than [s] seconds have elapsed since the
        previous save, or since the loop started.
      - [`iters n] saves whenever [index] is a multiple of [n].
    - [only_keep], when given, limits how many indexed checkpoints written by
      this call are kept. Once more than [only_keep] exist, the oldest one is
      deleted.
    - After the loop, a last checkpoint is written to
      [<checkpoint_base>.final].

    Each checkpoint is first written to [<checkpoint_base>.tmp] and then
    renamed into place.

    @raise Invalid_argument if [start_index] is negative, if [only_keep] is
    not positive, or if [checkpoint_every] is [`iters n] with [n] not
    positive. *)
let loop
  ~start_index
  ~end_index
  ~var_stores
  ~checkpoint_base
  ?only_keep
  ?(checkpoint_every = `seconds 600.)
  f
  =
  if start_index < 0 then Printf.invalid_argf "negative start_index %d" start_index ();
  Option.iter only_keep ~f:(fun only_keep ->
    if only_keep <= 0 then Printf.invalid_argf "non-positive only_keep %d" only_keep ());
  (* Validate up front. Base's [%] raises on a non-positive modulus, but that
     would only surface after a restore and a first call to [f]. *)
  (match checkpoint_every with
   | `iters iters ->
     if iters <= 0
     then Printf.invalid_argf "non-positive checkpoint_every iters %d" iters ()
   | `seconds _ -> ());
  let checkpoint_filename suffix = Printf.sprintf "%s.%s" checkpoint_base suffix in
  let temp_checkpoint = checkpoint_filename "tmp" in
  let latest_index_and_filename = latest_index_and_filename ~checkpoint_base in
  let named_tensors =
    match var_stores with
    | [ vs ] -> Var_store.all_vars vs
    | var_stores ->
      (* Prefix names with the var store name so that tensors from
         different stores cannot collide in the same checkpoint file. *)
      List.concat_map var_stores ~f:(fun vs ->
        let vs_name = Var_store.name vs in
        Var_store.all_vars vs
        |> List.map ~f:(fun (name, tensor) ->
          Printf.sprintf "%s:%s" vs_name name, tensor))
  in
  Option.iter latest_index_and_filename ~f:(fun (latest_index, filename) ->
    Stdio.eprintf
      "Restoring checkpoint for index %d from '%s'.\n%!"
      latest_index
      filename;
    Serialize.load_multi_ ~named_tensors ~filename);
  let start_index =
    Option.value_map latest_index_and_filename ~default:start_index ~f:(fun (index, _) ->
      index + 1)
  in
  (* NOTE(review): the queue starts empty and only records checkpoints saved
     by this call. Indexed checkpoints left over from a previous (resumed)
     run are therefore never removed by [only_keep]. *)
  let only_keep =
    Option.map only_keep ~f:(fun only_keep -> only_keep, Linked_queue.create ())
  in
  (* Write to a temporary file and then rename it into place. The intent is
     that an interrupted save never leaves a partially written
     [<checkpoint_base>.<suffix>]. This presumes rename is atomic on the
     target filesystem. *)
  let save ~suffix =
    Serialize.save_multi ~named_tensors ~filename:temp_checkpoint;
    Unix.rename temp_checkpoint (checkpoint_filename suffix)
  in
  let save_index index =
    save ~suffix:(Int.to_string index);
    Option.iter only_keep ~f:(fun (only_keep, index_queue) ->
      Linked_queue.enqueue index_queue index;
      (* One index is enqueued per call, so dropping at most one keeps the
         queue at [only_keep] entries. *)
      if Linked_queue.length index_queue > only_keep
      then
        Linked_queue.dequeue_exn index_queue
        |> Int.to_string
        |> checkpoint_filename
        |> Unix.unlink)
  in
  let last_checkpoint_time = ref (Unix.time ()) in
  for index = start_index to end_index do
    f ~index;
    let should_checkpoint =
      match checkpoint_every with
      | `seconds seconds -> Float.( > ) (Unix.time () -. !last_checkpoint_time) seconds
      | `iters iters -> index % iters = 0
    in
    if should_checkpoint
    then (
      save_index index;
      last_checkpoint_time := Unix.time ())
  done;
  save ~suffix:"final"