# 1 "src/base/misc/owl_utils_conv.ml"
(*
* OWL - an OCaml numerical library for scientific computing
* Copyright (c) 2016-2018 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

open Owl_types


(* This module contains helper functions for the ndarray convolution
   functions: output-shape and padding calculations. *)


(* calculate the output shape of [conv2d] given input and kernel and stride.

   [padding] selects the output size along each axis:
   - [SAME]:  ceil (input / stride)
   - [VALID]: ceil ((input - kernel + 1) / stride)

   Returns [(output_cols, output_rows)]. The ceiling is computed in exact
   integer arithmetic. With [VALID] and a kernel larger than the input the
   result is zero or negative; it is returned as is and not validated here.

   Raises [Invalid_argument] if [row_stride] or [col_stride] is not positive,
   since dividing by such a stride has no meaningful result. *)
let calc_conv2d_output_shape
  padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  =
  (* output size along a single axis *)
  let output_dim input kernel stride =
    if stride <= 0 then
      invalid_arg "calc_conv2d_output_shape: stride must be positive";
    let span = match padding with
      | SAME  -> input
      | VALID -> input - kernel + 1
    in
    (* ceil (span / stride) for stride > 0: OCaml's [/] truncates toward
       zero, which is already the ceiling when span <= 0; for span > 0 the
       form (span - 1) / stride + 1 avoids overflow near max_int *)
    if span <= 0 then span / stride else (span - 1) / stride + 1
  in
  let output_cols = output_dim input_cols kernel_cols col_stride in
  let output_rows = output_dim input_rows kernel_rows row_stride in
  (output_cols, output_rows)


(* calculate the padding size along width and height.

   Along each axis the total padding is the number of elements by which
   [output] windows of size [kernel] placed [stride] apart, i.e.
   [(output - 1) * stride + kernel] elements, exceed the input, clamped at
   zero. The total is split with [total / 2] (rounded down) before the data
   (top/left) and the remainder after (bottom/right).

   Returns [(pad_top, pad_left, pad_bottom, pad_right)]. *)
let calc_conv2d_padding
  input_cols input_rows kernel_cols kernel_rows output_cols output_rows
  row_stride col_stride
  =
  (* total padding needed along one axis, never negative *)
  let total_pad input kernel output stride =
    Pervasives.max ((output - 1) * stride + kernel - input) 0
  in
  (* (before, after) halves of a total; an odd total puts the extra
     element after *)
  let split total =
    let before = total / 2 in
    (before, total - before)
  in
  let pad_top, pad_bottom =
    split (total_pad input_rows kernel_rows output_rows row_stride)
  in
  let pad_left, pad_right =
    split (total_pad input_cols kernel_cols output_cols col_stride)
  in
  pad_top, pad_left, pad_bottom, pad_right


(* calculate the output length of [conv1d] by delegating to
   [calc_conv2d_output_shape] with a single input row, a kernel of height 1
   and a row stride of 1, keeping only the column count.
   Raises [Invalid_argument] if [col_stride] is not positive. *)
let calc_conv1d_output_shape padding input_cols kernel_cols col_stride =
  let input_rows = 1 in
  let kernel_rows = 1 in
  let row_stride = 1 in
  calc_conv2d_output_shape
    padding input_cols input_rows kernel_cols kernel_rows row_stride col_stride
  |> fst


(* calculate the output shape of [conv3d] given input and kernel and stride
*)
(* [padding] selects the output size along each axis:
   - [SAME]:  ceil (input / stride)
   - [VALID]: ceil ((input - kernel + 1) / stride)

   Returns [(output_cols, output_rows, output_dpts)]. The ceiling is computed
   in exact integer arithmetic. With [VALID] and a kernel larger than the
   input the result is zero or negative; it is returned as is and not
   validated here.

   Raises [Invalid_argument] if [row_stride], [col_stride] or [dpt_stride] is
   not positive, since dividing by such a stride has no meaningful result. *)
let calc_conv3d_output_shape
  padding input_cols input_rows input_dpts kernel_cols kernel_rows kernel_dpts
  row_stride col_stride dpt_stride
  =
  (* output size along a single axis *)
  let output_dim input kernel stride =
    if stride <= 0 then
      invalid_arg "calc_conv3d_output_shape: stride must be positive";
    let span = match padding with
      | SAME  -> input
      | VALID -> input - kernel + 1
    in
    (* ceil (span / stride) for stride > 0: OCaml's [/] truncates toward
       zero, which is already the ceiling when span <= 0; for span > 0 the
       form (span - 1) / stride + 1 avoids overflow near max_int *)
    if span <= 0 then span / stride else (span - 1) / stride + 1
  in
  let output_cols = output_dim input_cols kernel_cols col_stride in
  let output_rows = output_dim input_rows kernel_rows row_stride in
  let output_dpts = output_dim input_dpts kernel_dpts dpt_stride in
  (output_cols, output_rows, output_dpts)


(* calculate the padding size along width, height, and depth.

   Along each axis the total padding is the number of elements by which
   [output] windows of size [kernel] placed [stride] apart, i.e.
   [(output - 1) * stride + kernel] elements, exceed the input, clamped at
   zero. The total is split with [total / 2] (rounded down) before the data
   (top/left/shallow) and the remainder after (bottom/right/deep).

   Returns
   [(pad_top, pad_left, pad_shallow, pad_bottom, pad_right, pad_deep)]. *)
let calc_conv3d_padding
  input_cols input_rows input_depth kernel_cols kernel_rows kernel_depth
  output_cols output_rows output_depth row_stride col_stride depth_stride
  =
  (* total padding needed along one axis, never negative *)
  let total_pad input kernel output stride =
    Pervasives.max ((output - 1) * stride + kernel - input) 0
  in
  (* (before, after) halves of a total; an odd total puts the extra
     element after *)
  let split total =
    let before = total / 2 in
    (before, total - before)
  in
  let pad_top, pad_bottom =
    split (total_pad input_rows kernel_rows output_rows row_stride)
  in
  let pad_left, pad_right =
    split (total_pad input_cols kernel_cols output_cols col_stride)
  in
  let pad_shallow, pad_deep =
    split (total_pad input_depth kernel_depth output_depth depth_stride)
  in
  pad_top, pad_left, pad_shallow, pad_bottom, pad_right, pad_deep