Source file onnx.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2022                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Base
module Format = Caml.Format
module Fun = Caml.Fun
module Oproto = Onnx_protoc (* Autogenerated during compilation *)
module Oprotom = Oproto.Onnx.ModelProto

(** Summary of an ONNX model as extracted by this module: only the flattened
    input and output sizes are kept. *)
type t = {
  n_inputs : int;
      (** Number of inputs, computed as the product of the fixed dimensions of
          the first graph input. *)
  n_outputs : int;
      (** Number of outputs, computed as the product of the fixed dimensions
          of the first graph output. *)
}

(* ONNX format handling. *)

(** [get_nested_dims s] returns the shape carried by the first element of [s]
    when that element is a tensor type with a shape set. In every other case
    (empty list, missing type, non-tensor type, missing shape) it returns the
    empty list. Only the head of [s] is inspected. *)
let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) =
  match s with
  | { type' = Some { value = `Tensor_type { shape = Some dims; _ }; _ }; _ }
    :: _ ->
    dims
  | _ -> []

(** [flattened_dim dim] multiplies together every dimension of [dim] that
    carries a concrete value. Dimensions that are symbolic ([`Dim_param]) or
    unset leave the product unchanged. The product of an empty list is [1]. *)
let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
  let scale product (d : Oproto.Onnx.TensorShapeProto.Dimension.t) =
    match d.value with
    | `Dim_value size -> product * size
    | `Dim_param _ | `not_set -> product
  in
  List.fold dim ~init:1 ~f:scale

(** [get_input_output_dim model] returns the pair
    [(input_flat_dim, output_flat_dim)] for [model]: the flattened dimension
    of the first graph input and of the first graph output. When the model has
    no graph, both shapes are treated as empty, so both components are the
    flattened dimension of an empty shape. *)
let get_input_output_dim (model : Oprotom.t) =
  let input_shape, output_shape =
    match model.graph with
    | None -> ([], [])
    | Some graph -> (get_nested_dims graph.input, get_nested_dims graph.output)
  in
  (* TODO: here we only get the flattened dimension of inputs and outputs, but
     more interesting parsing could be done later on. *)
  (flattened_dim input_shape, flattened_dim output_shape)

(** [parse_in_channel in_channel] reads the whole of [in_channel], decodes it
    as an ONNX [ModelProto] and returns [Ok t] holding the flattened input and
    output dimensions. It returns [Error msg] when protobuf decoding fails, or
    when a [Sys_error] or [Failure] is raised while reading or processing.
    The channel is not closed by this function. *)
let parse_in_channel in_channel =
  let open Result in
  (* Reading and decoding are delayed in a thunk so that they run inside the
     [try] below. *)
  let decode () =
    Stdio.In_channel.input_all in_channel
    |> Ocaml_protoc_plugin.Reader.create
    |> Oprotom.from_proto
  in
  try
    match decode () with
    | Error _ -> Error "Error parsing protobuf"
    | Ok model ->
      let n_inputs, n_outputs = get_input_output_dim model in
      Ok { n_inputs; n_outputs }
  with
  | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
  | Sys_error s -> Error s

(** [parse filename] opens [filename], parses it as a serialized ONNX model
    with [parse_in_channel] and returns its result. The channel is closed
    whether parsing succeeds or not.

    The file is opened in binary mode: ONNX models are protobuf-encoded, so
    the bytes must reach the decoder without any end-of-line translation.

    @raise Sys_error if [filename] cannot be opened; the opening happens
    outside of the error handling done by [parse_in_channel]. *)
let parse filename =
  let in_channel = Stdlib.open_in_bin filename in
  Fun.protect
    ~finally:(fun () -> Stdlib.close_in in_channel)
    (fun () -> parse_in_channel in_channel)