1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
open Base
module Format = Stdlib.Format
module Sys = Stdlib.Sys
module Filename = Stdlib.Filename
module Fun = Stdlib.Fun
(** Contents of an NNet file as read by the parsing functions of this module. *)
type t = {
n_layers : int;
(** First integer of the basic-information line. *)
n_inputs : int;
(** Second integer of the basic-information line; the per-input lines below
    are checked against this count. *)
n_outputs : int;
(** Third integer of the basic-information line. *)
max_layer_size : int;
(** Fourth integer of the basic-information line. *)
layer_sizes : int list;
(** Integers of the layer-sizes line; the parser requires exactly
    [n_layers + 1] of them. *)
min_input_values : float list option;
(** [n_inputs] floats, or [None] when the line could not be read and parsing
    was permissive. *)
max_input_values : float list option;
(** [n_inputs] floats, or [None] when the line could not be read and parsing
    was permissive. *)
mean_values : (float list * float) option;
(** The first [n_inputs] floats of the line paired with the one remaining
    float (named "output" by the parser); [None] when the line could not be
    read and parsing was permissive. *)
range_values : (float list * float) option;
(** Same layout as [mean_values], read from the following line. *)
weights_biases : float list list;
(** One list of floats per remaining CSV record, in file order; fields that
    do not parse as floats are dropped. *)
}
(** [nnet_format_error condition] is the [Error] carrying the standard
    message for a violated NNet format condition, where [condition] names the
    condition (e.g. ["second"]). *)
let nnet_format_error condition =
  let message =
    Format.sprintf "NNet format error: %s condition not satisfied." condition
  in
  Error message
(** [handle_nnet_line ~f in_channel] reads the next CSV record from
    [in_channel] and converts each whitespace-stripped field with [f].

    Fields on which [f] raises are silently dropped. Any exception raised by
    [Csv.next] itself (such as [End_of_file]) is not caught here. *)
let handle_nnet_line ~f in_channel =
  let convert field =
    match f (String.strip field) with
    | value -> Some value
    | exception _ -> None
  in
  in_channel |> Csv.next |> List.filter_map ~f:convert
(** [skip_nnet_header filename in_channel] consumes the leading lines of
    [in_channel] that start with ["//"].

    On success the channel is left positioned at the beginning of the first
    line that does not start with ["//"], and [Ok ()] is returned. If the end
    of the file is reached before such a line is found, an [Error] mentioning
    [filename] is returned; [filename] is used for that message only. *)
let skip_nnet_header filename in_channel =
  (* Compile the comment-prefix pattern once rather than once per line. *)
  let comment_prefix = Str.regexp "//" in
  let rec skip () =
    (* Remember where the upcoming line starts so that the channel can be
       rewound if that line turns out not to be part of the header. *)
    let line_start = Stdlib.pos_in in_channel in
    match Stdlib.input_line in_channel with
    | line ->
      if Str.string_match comment_prefix line 0
      then skip ()
      else (
        Stdlib.seek_in in_channel line_start;
        Ok ())
    | exception End_of_file ->
      Error (Format.sprintf "NNet model not found in file '%s'." filename)
  in
  skip ()
(** [handle_nnet_basic_info in_channel] reads the next record as integers and
    expects exactly four of them.

    @return
      [Ok (n_layers, n_inputs, n_outputs, max_layer_size)], or the "second"
      format error when the record does not hold exactly four integers or the
      end of the input is reached. *)
let handle_nnet_basic_info in_channel =
  let record =
    try Some (handle_nnet_line ~f:Int.of_string in_channel)
    with End_of_file -> None
  in
  match record with
  | Some [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
    Ok (n_layers, n_inputs, n_outputs, max_layer_size)
  | Some _ | None -> nnet_format_error "second"
(** [handle_nnet_layer_sizes n_layers in_channel] reads the next record as
    integers and checks that it holds [n_layers + 1] values.

    @return
      [Ok sizes] on success, or the "third" format error when the count is
      wrong or the end of the input is reached. *)
let handle_nnet_layer_sizes n_layers in_channel =
  match handle_nnet_line ~f:Int.of_string in_channel with
  | exception End_of_file -> nnet_format_error "third"
  | sizes ->
    let expected = n_layers + 1 in
    if List.length sizes = expected
    then Ok sizes
    else nnet_format_error "third"
(** [handle_nnet_unused_flag in_channel] reads the next record and discards
    its content.

    @return
      [Ok ()] when a record could be read, or the "fourth" format error when
      the end of the input is reached. *)
let handle_nnet_unused_flag in_channel =
  match Csv.next in_channel with
  (* The type annotation makes sure a whole record, and nothing else, is what
     gets thrown away. *)
  | (_ : string list) -> Ok ()
  | exception End_of_file -> nnet_format_error "fourth"
(** [handle_nnet_min_input_values n_inputs in_channel] reads the next record
    as floats and checks that it holds exactly [n_inputs] values.

    @return
      [Ok values] on success, or the "fifth" format error when the count is
      wrong or the end of the input is reached. *)
let handle_nnet_min_input_values n_inputs in_channel =
  match handle_nnet_line ~f:Float.of_string in_channel with
  | exception End_of_file -> nnet_format_error "fifth"
  | lower_bounds ->
    if List.length lower_bounds = n_inputs
    then Ok lower_bounds
    else nnet_format_error "fifth"
(** [handle_nnet_max_input_values n_inputs in_channel] reads the next record
    as floats and checks that it holds exactly [n_inputs] values.

    @return
      [Ok values] on success, or the "sixth" format error when the count is
      wrong or the end of the input is reached. *)
let handle_nnet_max_input_values n_inputs in_channel =
  let checked upper_bounds =
    match List.length upper_bounds = n_inputs with
    | true -> Ok upper_bounds
    | false -> nnet_format_error "sixth"
  in
  match handle_nnet_line ~f:Float.of_string in_channel with
  | upper_bounds -> checked upper_bounds
  | exception End_of_file -> nnet_format_error "sixth"
(** [handle_nnet_mean_values n_inputs in_channel] reads the next record as
    floats and checks that it holds [n_inputs + 1] values.

    @return
      [Ok (mean_input_values, mean_output_value)], where [mean_input_values]
      are the first [n_inputs] floats and [mean_output_value] is the
      remaining one; or the "seventh" format error when the record does not
      have that shape or the end of the input is reached. *)
let handle_nnet_mean_values n_inputs in_channel =
  match handle_nnet_line ~f:Float.of_string in_channel with
  | exception End_of_file -> nnet_format_error "seventh"
  | mean_values ->
    if List.length mean_values <> n_inputs + 1
    then nnet_format_error "seventh"
    else (
      (* Destructure the split instead of taking the head of the remainder:
         with a malformed header giving [n_inputs = -1] and an empty record
         the remainder is empty, which must be reported as a format error
         rather than raise. *)
      match List.split_n mean_values n_inputs with
      | mean_input_values, [ mean_output_value ] ->
        Ok (mean_input_values, mean_output_value)
      | _ -> nnet_format_error "seventh")
(** [handle_nnet_range_values n_inputs in_channel] reads the next record as
    floats and checks that it holds [n_inputs + 1] values.

    @return
      [Ok (range_input_values, range_output_value)], where
      [range_input_values] are the first [n_inputs] floats and
      [range_output_value] is the remaining one; or the "eighth" format error
      when the record does not have that shape or the end of the input is
      reached. *)
let handle_nnet_range_values n_inputs in_channel =
  match handle_nnet_line ~f:Float.of_string in_channel with
  | exception End_of_file -> nnet_format_error "eighth"
  | range_values ->
    if List.length range_values <> n_inputs + 1
    then nnet_format_error "eighth"
    else (
      (* Destructure the split instead of taking the head of the remainder:
         with a malformed header giving [n_inputs = -1] and an empty record
         the remainder is empty, which must be reported as a format error
         rather than raise. *)
      match List.split_n range_values n_inputs with
      | range_input_values, [ range_output_value ] ->
        Ok (range_input_values, range_output_value)
      | _ -> nnet_format_error "eighth")
(** [handle_nnet_weights_and_biases in_channel] reads every remaining CSV
    record of [in_channel] and returns them, in file order, as lists of
    floats. Fields that cannot be converted to a float are dropped. *)
let handle_nnet_weights_and_biases in_channel =
  let float_of_field field =
    match Float.of_string (String.strip field) with
    | value -> Some value
    | exception _ -> None
  in
  (* Rows are accumulated in reverse and put back in order at the end. *)
  let prepend_row rows fields =
    List.filter_map ~f:float_of_field fields :: rows
  in
  let reversed_rows = Csv.fold_left ~init:[] ~f:prepend_row in_channel in
  List.rev reversed_rows
(** [parse_in_channel ?permissive filename in_channel] parses an NNet model
    from [in_channel].

    The ["//"] header is skipped first, then the remaining content is read as
    CSV: basic information, layer sizes, an ignored record, minimum and
    maximum input values, mean values, range values, and finally all
    remaining records as weights and biases.

    @param permissive
      when [true], a failure to read the minimum, maximum, mean or range
      record yields [None] for that field instead of aborting (default:
      [false]).
    @param filename used for error reporting by the header-skipping step.
    @return
      [Ok model] on success (the CSV channel is closed in that case), or
      [Error message] on a format error, a [Csv.Failure], a [Sys_error] or a
      [Failure]. *)
let parse_in_channel ?(permissive = false) filename in_channel =
  let ( let* ) = Result.( >>= ) in
  (* Turns a failed optional section into [None] in permissive mode. *)
  let optional = function
    | Ok value -> Ok (Some value)
    | Error _ as failure -> if permissive then Ok None else failure
  in
  try
    let* () = skip_nnet_header filename in_channel in
    let csv = Csv.of_channel in_channel in
    let* n_layers, n_inputs, n_outputs, max_layer_size =
      handle_nnet_basic_info csv
    in
    let* layer_sizes = handle_nnet_layer_sizes n_layers csv in
    let* () = handle_nnet_unused_flag csv in
    let* min_input_values =
      optional (handle_nnet_min_input_values n_inputs csv)
    in
    let* max_input_values =
      optional (handle_nnet_max_input_values n_inputs csv)
    in
    let* mean_values = optional (handle_nnet_mean_values n_inputs csv) in
    let* range_values = optional (handle_nnet_range_values n_inputs csv) in
    let weights_biases = handle_nnet_weights_and_biases csv in
    Csv.close_in csv;
    Ok
      {
        n_layers;
        n_inputs;
        n_outputs;
        max_layer_size;
        layer_sizes;
        min_input_values;
        max_input_values;
        mean_values;
        range_values;
        weights_biases;
      }
  with
  | Csv.Failure (_nrecord, _nfield, msg) -> Error msg
  | Sys_error s -> Error s
  | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)
(** [parse ?permissive filename] opens [filename] and parses it with
    [parse_in_channel], forwarding [permissive] (default: [false]).

    The channel is closed whether parsing returns or raises. The file is
    opened outside of any handler, so an exception raised by [Stdlib.open_in]
    propagates to the caller. *)
let parse ?(permissive = false) filename =
  let in_channel = Stdlib.open_in filename in
  let release () = Stdlib.close_in in_channel in
  let run () = parse_in_channel ~permissive filename in_channel in
  Fun.protect ~finally:release run