(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2022                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Base
module Format = Caml.Format
module Fun = Caml.Fun
module Oproto = Onnx_protoc (* Autogenerated during compilation *)
module Oprotom = Oproto.Onnx.ModelProto

type t = {
  n_inputs : int;  (* Number of inputs. *)
  n_outputs : int;  (* Number of outputs. *)
}

(* ONNX format handling. *)

(* Returns the shape dimensions of the first value described in [s], provided
   that value is typed as a tensor with a known shape. Returns [[]] when [s]
   is empty, when its first element is not a tensor type, or when that tensor
   carries no shape. Only the first element of [s] is ever inspected. *)
let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) =
  match List.hd s with
  | Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ }
    ->
    v
  | _ -> []

(* Product of all concrete ([`Dim_value]) dimensions in [dim]. Symbolic
   ([`Dim_param]) and unset dimensions contribute a factor of 1, so an empty
   or fully symbolic shape yields 1. *)
let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
  List.fold dim ~init:1 ~f:(fun acc x ->
      match x.value with
      | `Dim_value v -> acc * v
      | `Dim_param _ | `not_set -> acc)

(* Returns the pair (flattened input dimension, flattened output dimension)
   of [model]'s graph, looking only at the first graph input and the first
   graph output. When the model has no graph, both shapes are empty and each
   flattened dimension is therefore 1 (see [flattened_dim]). *)
let get_input_output_dim (model : Oprotom.t) =
  let input_shape, output_shape =
    match model.graph with
    | Some g -> (get_nested_dims g.input, get_nested_dims g.output)
    | None -> ([], [])
  in
  (* TODO: here we only get the flattened dimension of inputs and outputs, but
     more interesting parsing could be done later on. *)
  (flattened_dim input_shape, flattened_dim output_shape)

(* Reads the whole content of [in_channel] and decodes it as an ONNX
   ModelProto.
   Returns [Ok t] holding the flattened input/output dimensions on success.
   Returns [Error msg] when protobuf decoding fails, or when reading raises
   [Sys_error] or [Failure]. *)
let parse_in_channel in_channel =
  try
    let buf = Stdio.In_channel.input_all in_channel in
    let reader = Ocaml_protoc_plugin.Reader.create buf in
    match Oprotom.from_proto reader with
    | Ok r ->
      let n_inputs, n_outputs = get_input_output_dim r in
      Ok { n_inputs; n_outputs }
    | Error _ -> Error "Error parsing protobuf"
  with
  | Sys_error s -> Error s
  | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)

(* Parses the ONNX model stored in the file [filename].
   The file is opened in binary mode: protobuf data is binary and must not
   undergo the newline translation that text mode performs on some
   platforms. The channel is closed whatever the outcome of parsing.
   A failure to open the file is returned as [Error msg] instead of being
   raised, consistently with [parse_in_channel]. *)
let parse filename =
  match Stdlib.open_in_bin filename with
  | exception Sys_error msg -> Error msg
  | in_channel ->
    Fun.protect
      ~finally:(fun () -> Stdlib.close_in_noerr in_channel)
      (fun () -> parse_in_channel in_channel)