# 1 "src/base/compute/owl_computation_shape.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

open Owl_types

(* Functor of making the shape inference of a computation graph. *)

module Make (Type : Owl_computation_type_sig.Sig) = struct
  module Type = Type
  open Type

  (* Infer the shape of outcome from inputs.

     Every helper below receives [input_shapes], one entry per input node,
     where [input_shapes.(i).(0)] is the optional shape of the first output
     of the i-th input. Every helper returns an array of optional shapes;
     [None] means the shape could not be inferred. The numeric suffixes are
     kept unchanged so that any external reference keeps working. *)

  (* Result is a scalar: always the empty shape, inputs are ignored. *)
  let _infer_shape_00 _input_shapes = [| Some [||] |]

  (* Result has the shape of the first input (returned as a fresh copy). *)
  let _infer_shape_01 input_shapes =
    match input_shapes.(0).(0) with
    | Some s -> [| Some Array.(copy s) |]
    | None   -> [| None |]

  (* Result has the shape of the second input (returned as a fresh copy). *)
  let _infer_shape_02 input_shapes =
    match input_shapes.(1).(0) with
    | Some s -> [| Some Array.(copy s) |]
    | None   -> [| None |]

  (* Result is the broadcast of the first two input shapes; unknown as soon
     as either of them is unknown. *)
  let _infer_shape_03 input_shapes =
    let s0 = input_shapes.(0).(0) in
    let s1 = input_shapes.(1).(0) in
    match s0, s1 with
    | Some s0, Some s1 -> [| Some Owl_utils_infer_shape.(broadcast1 s0 s1) |]
    | _, _             -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.fold] on the first input and [axis]. *)
  let _infer_shape_04 input_shapes axis =
    match input_shapes.(0).(0) with
    | Some s -> [| Some Owl_utils_infer_shape.(fold s axis) |]
    | None   -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.tile] on the first input. *)
  let _infer_shape_05 input_shapes repeats =
    match input_shapes.(0).(0) with
    | Some s -> [| Some Owl_utils_infer_shape.(tile s repeats) |]
    | None   -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.repeat] on the first input. *)
  let _infer_shape_06 input_shapes repeats =
    match input_shapes.(0).(0) with
    | Some s -> [| Some Owl_utils_infer_shape.(repeat s repeats) |]
    | None   -> [| None |]

  (* Concatenation over all inputs: unknown if any input shape is unknown,
     otherwise delegates to [Owl_utils_infer_shape.concatenate]. *)
  let _infer_shape_07 input_shapes axis =
    let s0 = Array.map (fun s -> s.(0)) input_shapes in
    if Array.exists
         (function
           | Some _ -> false
           | None   -> true)
         s0
    then [| None |]
    else (
      let s1 =
        Array.map
          (function
            | Some a -> a
            (* cannot happen: the [None] case was excluded just above *)
            | None   -> failwith "_infer_shape_07")
          s0
      in
      [| Some Owl_utils_infer_shape.(concatenate s1 axis) |])

  (* Split: one output shape per element of [parts]; all of them are unknown
     when the input shape is unknown. *)
  let _infer_shape_08 input_shapes axis parts =
    match input_shapes.(0).(0) with
    | Some s ->
      let s0 = Owl_utils_infer_shape.(split s axis parts) in
      Array.map (fun s -> Some s) s0
    | None   -> Array.(make (length parts) None)

  (* Delegates to [Owl_utils_infer_shape.draw] on the first input. *)
  let _infer_shape_09 input_shapes axis n =
    match input_shapes.(0).(0) with
    | Some s -> [| Some Owl_utils_infer_shape.(draw s axis n) |]
    | None   -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.reduce] on the first input. *)
  let _infer_shape_10 input_shapes axis =
    match input_shapes.(0).(0) with
    | Some s -> [| Some Owl_utils_infer_shape.(reduce s axis) |]
    | None   -> [| None |]

  (* conv1d: first input is the data shape, second input is the kernel shape. *)
  let _infer_shape_11 input_shapes padding stride =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(conv1d input padding kernel stride) |]
    | _, _ -> [| None |]

  (* conv2d: first input is the data shape, second input is the kernel shape. *)
  let _infer_shape_12 input_shapes padding stride =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(conv2d input padding kernel stride) |]
    | _, _ -> [| None |]

  (* conv3d: first input is the data shape, second input is the kernel shape. *)
  let _infer_shape_13 input_shapes padding stride =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(conv3d input padding kernel stride) |]
    | _, _ -> [| None |]

  (* transpose_conv2d: data shape from input 0, kernel shape from input 1. *)
  let _infer_shape_14 input_shapes padding stride =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(transpose_conv2d input padding kernel stride) |]
    | _, _ -> [| None |]

  (* conv1d rule with the kernel given as an explicit argument. *)
  let _infer_shape_15 input_shapes padding kernel stride =
    let input_shape = input_shapes.(0).(0) in
    match input_shape with
    | Some input -> [| Some Owl_utils_infer_shape.(conv1d input padding kernel stride) |]
    | _          -> [| None |]

  (* conv2d rule with the kernel given as an explicit argument.
     Not referenced by [infer_shape] below; kept for compatibility. *)
  let _infer_shape_16 input_shapes padding kernel stride =
    let input_shape = input_shapes.(0).(0) in
    match input_shape with
    | Some input -> [| Some Owl_utils_infer_shape.(conv2d input padding kernel stride) |]
    | _          -> [| None |]

  (* conv3d rule with the kernel given as an explicit argument. *)
  let _infer_shape_17 input_shapes padding kernel stride =
    let input_shape = input_shapes.(0).(0) in
    match input_shape with
    | Some input -> [| Some Owl_utils_infer_shape.(conv3d input padding kernel stride) |]
    | _          -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.transpose] on the first input. *)
  let _infer_shape_18 input_shapes axis =
    match input_shapes.(0).(0) with
    | Some s -> [| Some Owl_utils_infer_shape.(transpose s axis) |]
    | None   -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.dot] on the first two inputs. *)
  let _infer_shape_19 input_shapes =
    let x_shape = input_shapes.(0).(0) in
    let y_shape = input_shapes.(1).(0) in
    match x_shape, y_shape with
    | Some s0, Some s1 -> [| Some Owl_utils_infer_shape.(dot s0 s1) |]
    | _, _             -> [| None |]

  (* Slicing: [axis] is a list of int lists; each one becomes an [R_] slice
     definition which is checked against the input shape before the shape of
     the slice is calculated. *)
  let _infer_shape_20 input_shapes axis =
    match input_shapes.(0).(0) with
    | Some s ->
      let axis = List.map (fun i -> R_ (Array.of_list i)) axis |> Array.of_list in
      let axis = Owl_base_slicing.check_slice_definition axis s in
      [| Some Owl_base_slicing.(calc_slice_shape axis) |]
    | None   -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.pool2d] on the first input. *)
  let _infer_shape_21 input_shapes padding kernel stride =
    let input_shape = input_shapes.(0).(0) in
    match input_shape with
    | Some input -> [| Some Owl_utils_infer_shape.(pool2d input padding kernel stride) |]
    | _          -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.onehot] on the first input. *)
  let _infer_shape_22 input_shapes depth =
    match input_shapes.(0).(0) with
    | Some s -> [| Some Owl_utils_infer_shape.(onehot s depth) |]
    | None   -> [| None |]

  (* Three-way broadcast of the first three input shapes. *)
  let _infer_shape_23 input_shapes =
    let s0 = input_shapes.(0).(0) in
    let s1 = input_shapes.(1).(0) in
    let s2 = input_shapes.(2).(0) in
    match s0, s1, s2 with
    | Some s0, Some s1, Some s2 -> [| Some Owl_utils_infer_shape.(broadcast2 s0 s1 s2) |]
    | _, _, _                   -> [| None |]

  (* transpose_conv1d: data shape from input 0, kernel shape from input 1. *)
  let _infer_shape_24 input_shapes padding stride =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(transpose_conv1d input padding kernel stride) |]
    | _, _ -> [| None |]

  (* transpose_conv3d: data shape from input 0, kernel shape from input 1. *)
  let _infer_shape_25 input_shapes padding stride =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(transpose_conv3d input padding kernel stride) |]
    | _, _ -> [| None |]

  (* dilated_conv1d: data shape from input 0, kernel shape from input 1. *)
  let _infer_shape_26 input_shapes padding stride rate =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(dilated_conv1d input padding kernel stride rate) |]
    | _, _ -> [| None |]

  (* dilated_conv2d: data shape from input 0, kernel shape from input 1. *)
  let _infer_shape_27 input_shapes padding stride rate =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(dilated_conv2d input padding kernel stride rate) |]
    | _, _ -> [| None |]

  (* dilated_conv3d: data shape from input 0, kernel shape from input 1. *)
  let _infer_shape_28 input_shapes padding stride rate =
    let input_shape = input_shapes.(0).(0) in
    let kernel_shape = input_shapes.(1).(0) in
    match input_shape, kernel_shape with
    | Some input, Some kernel ->
      [| Some Owl_utils_infer_shape.(dilated_conv3d input padding kernel stride rate) |]
    | _, _ -> [| None |]

  (* Delegates to [Owl_utils_infer_shape.upsampling2d] on the first input. *)
  let _infer_shape_29 input_shapes size =
    let input_shape = input_shapes.(0).(0) in
    match input_shape with
    | Some input -> [| Some Owl_utils_infer_shape.(upsampling2d input size) |]
    | _          -> [| None |]

  (* Padding: every dimension grows by the two amounts given for it in [pad].
     [pad] is converted with [Owl_utils.llss2aarr] before the input shape is
     inspected, exactly as before. *)
  let _infer_shape_30 input_shapes pad =
    let input_shape = input_shapes.(0).(0) in
    let pad = Owl_utils.llss2aarr pad in
    match input_shape with
    | Some input -> [| Some (Array.map2 (fun m n -> m + n.(0) + n.(1)) input pad) |]
    | _          -> [| None |]

  (* Two-array element-wise operators (Pow, Hypot, Min2, Max2).

     These used to report the shape of the first operand only, unlike Atan2,
     Add, Sub, Mul and Div which use the broadcasting rule. When both operand
     shapes are known the broadcast shape is now reported. In every other
     situation (second shape unknown or absent) the previous result, a copy
     of the first shape, is kept so that no formerly known shape becomes
     unknown.

     NOTE(review): assumes these operators broadcast at evaluation time in the
     same way as Atan2 and Add do - confirm against the Ndarray backend. *)
  let _infer_shape_31 input_shapes =
    let second =
      if Array.length input_shapes > 1 then input_shapes.(1).(0) else None
    in
    match input_shapes.(0).(0), second with
    | Some s0, Some s1 -> [| Some Owl_utils_infer_shape.(broadcast1 s0 s1) |]
    | Some s0, None    -> [| Some Array.(copy s0) |]
    | None, _          -> [| None |]

  (* [infer_shape operator args] infers the output shape(s) of [operator]
     applied to the graph nodes [args].

     The shape attribute of every node in [args] is read first, then the rule
     matching [operator] is applied. The result has one optional shape per
     output; [None] means unknown. Operators without a rule yield [[|None|]].

     Raises [Failure] for [ToRows] and [OfRows], which are not implemented. *)
  let infer_shape operator args =
    let input_shapes = Array.map (fun a -> (Owl_graph.attr a).shape) args in
    match operator with
    (* shape is carried by the operator itself *)
    | Create shape -> [| Some shape |]
    | Reshape shape -> [| Some shape |]
    | OfArray shape -> [| Some shape |]
    | DelayArray (shape, _f) -> [| Some shape |]
    (* scalar results *)
    | Get _ | Min' | Max' | Sum' | L1norm' | L2norm' | L2NormSqr' | Trace
    | Scalar_Add | Scalar_Sub | Scalar_Mul | Scalar_Div | Scalar_Pow | Scalar_Atan2
    | Scalar_Abs | Scalar_Neg | Scalar_Sqr | Scalar_Sqrt | Scalar_Exp | Scalar_Log
    | Scalar_Log2 | Scalar_Log10 | Scalar_Signum | Scalar_Floor | Scalar_Ceil
    | Scalar_Round | Scalar_Sin | Scalar_Cos | Scalar_Tan | Scalar_Sinh | Scalar_Cosh
    | Scalar_Tanh | Scalar_Asin | Scalar_Acos | Scalar_Atan | Scalar_Asinh
    | Scalar_Acosh | Scalar_Atanh | Scalar_Relu | Scalar_Dawsn | Scalar_Sigmoid ->
      _infer_shape_00 input_shapes
    (* same shape as the first input *)
    | Noop | Copy | Reverse | Map _ | Scan (_, _) | Abs | Delay _
    | LazyPrint (_, _, _, _)
    | Neg | Floor | Ceil | Round | Sqr | Sqrt | Log | Log2 | Log10 | Exp
    | Sin | Cos | Tan | Sinh | Cosh | Tanh | Asin | Acos | Atan | Asinh | Acosh | Atanh
    | Signum | Sigmoid | Relu | Dawsn | ClipByValue | ClipByL2norm
    | PowScalar | Atan2Scalar | AddScalar | SubScalar | MulScalar | DivScalar
    (* NOTE(review): the array-array comparisons below still use the first
       operand only; switch them to the broadcasting rule if the backend
       broadcasts them - TODO confirm. *)
    | EltEqual | EltNotEqual | EltLess | EltGreater | EltLessEqual | EltGreaterEqual
    | EltEqualScalar | EltNotEqualScalar | EltLessScalar | EltGreaterScalar
    | EltLessEqualScalar | EltGreaterEqualScalar
    | Conv1dBackwardInput _ | Conv2dBackwardInput _ | Conv3dBackwardInput _
    | TransposeConv1dBackwardInput _ | TransposeConv2dBackwardInput _
    | TransposeConv3dBackwardInput _
    | DilatedConv1dBackwardInput (_, _) | DilatedConv2dBackwardInput (_, _)
    | DilatedConv3dBackwardInput (_, _)
    | MaxPool1dBackward (_, _, _) | MaxPool2dBackward (_, _, _)
    | MaxPool3dBackward (_, _, _)
    | AvgPool1dBackward (_, _, _) | AvgPool2dBackward (_, _, _)
    | AvgPool3dBackward (_, _, _)
    | UpSampling2dBackward _ | Inv | Fused_Adagrad (_, _) ->
      _infer_shape_01 input_shapes
    (* same shape as the second input *)
    | ScalarPow | ScalarAtan2 | ScalarAdd | ScalarSub | ScalarMul | ScalarDiv
    | Conv1dBackwardKernel _ | Conv2dBackwardKernel _ | Conv3dBackwardKernel _
    | TransposeConv1dBackwardKernel _ | TransposeConv2dBackwardKernel _
    | TransposeConv3dBackwardKernel _
    | DilatedConv1dBackwardKernel (_, _) | DilatedConv2dBackwardKernel (_, _)
    | DilatedConv3dBackwardKernel (_, _) ->
      _infer_shape_02 input_shapes
    (* broadcasting binary operators *)
    | Atan2 | Add | Sub | Mul | Div -> _infer_shape_03 input_shapes
    (* fixed: previously the first operand shape was reported unconditionally *)
    | Pow | Hypot | Min2 | Max2 -> _infer_shape_31 input_shapes
    | FMA -> _infer_shape_23 input_shapes
    (* shape manipulation *)
    | GetSlice slice -> _infer_shape_20 input_shapes slice
    | Tile repeats -> _infer_shape_05 input_shapes repeats
    | Repeat repeats -> _infer_shape_06 input_shapes repeats
    | Pad (_v, padding) -> _infer_shape_30 input_shapes padding
    | Concatenate axis -> _infer_shape_07 input_shapes axis
    | Split (axis, parts) -> _infer_shape_08 input_shapes axis parts
    | Draw (axis, n) -> _infer_shape_09 input_shapes axis n
    | OneHot depth -> _infer_shape_22 input_shapes depth
    | Transpose axis -> _infer_shape_18 input_shapes axis
    | Row -> _infer_shape_09 input_shapes 0 1
    | Rows i -> _infer_shape_09 input_shapes 0 Array.(length i)
    (* reductions *)
    | Fold (axis, _f) -> _infer_shape_04 input_shapes axis
    | Min axis -> _infer_shape_04 input_shapes axis
    | Max axis -> _infer_shape_04 input_shapes axis
    | Sum axis -> _infer_shape_04 input_shapes axis
    | SumReduce axis -> _infer_shape_10 input_shapes axis
    (* convolutions *)
    | Conv1d (padding, stride) -> _infer_shape_11 input_shapes padding stride
    | Conv2d (padding, stride) -> _infer_shape_12 input_shapes padding stride
    | Conv3d (padding, stride) -> _infer_shape_13 input_shapes padding stride
    | TransposeConv1d (padding, stride) -> _infer_shape_24 input_shapes padding stride
    | TransposeConv2d (padding, stride) -> _infer_shape_14 input_shapes padding stride
    | TransposeConv3d (padding, stride) -> _infer_shape_25 input_shapes padding stride
    | DilatedConv1d (padding, stride, rate) ->
      _infer_shape_26 input_shapes padding stride rate
    | DilatedConv2d (padding, stride, rate) ->
      _infer_shape_27 input_shapes padding stride rate
    | DilatedConv3d (padding, stride, rate) ->
      _infer_shape_28 input_shapes padding stride rate
    (* pooling and upsampling *)
    | MaxPool1d (padding, kernel, stride) | AvgPool1d (padding, kernel, stride) ->
      _infer_shape_15 input_shapes padding kernel stride
    | MaxPool2d (padding, kernel, stride) | AvgPool2d (padding, kernel, stride) ->
      _infer_shape_21 input_shapes padding kernel stride
    | MaxPool3d (padding, kernel, stride) | AvgPool3d (padding, kernel, stride) ->
      _infer_shape_17 input_shapes padding kernel stride
    | UpSampling2d size -> _infer_shape_29 input_shapes size
    (* linear algebra *)
    (* NOTE(review): the transpose flags of Dot are ignored here - confirm
       that the inferred shape is right when either flag is set. *)
    | Dot (_transa, _transb, _alpha, _beta) -> _infer_shape_19 input_shapes
    | ToRows -> failwith "_infer_shape:ToRows not implemented"
    | OfRows -> failwith "_infer_shape:OfRows not implemented"
    (* anything else: shape unknown *)
    | _ -> [| None |]
end

(* Make functor ends *)