# 1 "src/base/compute/owl_computation_cpu_eval.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

open Owl_graph

(* Functor that builds the evaluator of a CPU-based engine from a computation
   graph module. *)
module Make (Graph : Owl_computation_graph_sig.Sig) = struct
  open Graph.Optimiser.Operator.Symbol
  open Graph.Optimiser.Operator.Symbol.Shape.Type.Device

  (* ----- utility functions ----- *)

  (* Invalidate the wrapped node when there is one; do nothing for [None]. *)
  let invalidate_opt node_opt =
    match node_opt with
    | None      -> ()
    | Some node -> invalidate node


  (* Make [x] the active node of block [b]: the node previously recorded as
     active for [b] (if any) is invalidated, [x] is recorded as the active
     node, and [x] is marked valid. *)
  let update_validity x b =
    invalidate_opt (get_active_node b);
    set_active_node b x;
    validate x


  (* ----- core evaluation function ----- *)

  (* Evaluate every node of [nodes], in array order.

     For each node a debug line is logged first. A node that is already valid
     is then skipped. Otherwise the node operator selects one of the
     [_eval_map_NN] helpers below (each of which evaluates the parents of the
     node before computing it), and afterwards [update_validity] is applied
     to every block returned by [get_block].

     Operators without an implementation here raise [Failure]. Any exception
     raised while evaluating a node is logged through [Owl_log.error] and then
     re-raised unchanged. *)
  let rec _eval_terms nodes =
    let eval node =
      Owl_log.debug "eval %s ..." (node_to_str node);
      if not (is_valid node)
      then begin
        let () =
          try
            begin match get_operator node with
            (* creation, initialisation and slicing *)
            | Noop -> _eval_map_00 node (fun x -> x.(0))
            | Var -> check_assigned node
            | Const -> check_assigned node
            | Empty _shape -> (_eval_map_01 node (fun ~out _x -> ()) [@warning "-27"])
            | Zeros _shape -> _eval_map_01 node (fun ~out _x -> A.zeros_ ~out)
            | Ones _shape -> _eval_map_01 node (fun ~out _x -> A.ones_ ~out)
            | Create _shape -> _eval_map_02 node (fun ~out x -> A.create_ ~out x.(0))
            | Sequential _shape ->
              _eval_map_02 node (fun ~out x -> A.sequential_ ~a:x.(0) ~step:x.(1) ~out)
            | Uniform _shape ->
              _eval_map_02 node (fun ~out x -> A.uniform_ ~a:x.(0) ~b:x.(1) ~out)
            | Gaussian _shape ->
              _eval_map_02 node (fun ~out x -> A.gaussian_ ~mu:x.(0) ~sigma:x.(1) ~out)
            | Bernoulli _shape -> _eval_map_02 node (fun ~out x -> A.bernoulli_ ~p:x.(0) ~out)
            | Init (_shape, _f) -> failwith "Init"
            | Get i -> _eval_map_06 node (fun x -> A.get x i)
            | Set _i -> failwith "Set"
            | GetSlice slice -> _eval_map_01 node (fun ~out x -> A.get_slice_ ~out slice x.(0))
            | SetSlice slice ->
              _eval_map_01 node (fun ~out x -> A.set_slice_ ~out slice x.(0) x.(1))
            | Copy -> (_eval_map_01 node (fun ~out x -> A.copy_ ~out x.(0)) [@warning "-27"])
            | Reset -> _eval_map_01 node (fun ~out _x -> A.reset out)
            | Reshape _shape -> _eval_map_01 node (fun ~out x -> A.reshape_ ~out x.(0))
            | Reverse -> _eval_map_01 node (fun ~out x -> A.reverse_ ~out x.(0))
            | Tile repeats -> _eval_map_01 node (fun ~out x -> A.tile_ ~out x.(0) repeats)
            | Repeat repeats -> _eval_map_01 node (fun ~out x -> A.repeat_ ~out x.(0) repeats)
            | Pad (v, padding) ->
              _eval_map_01 node (fun ~out x -> A.pad_ ~out ~v:(unpack_elt v) padding x.(0))
            | Concatenate axis -> _eval_map_00 node A.(concatenate ~axis)
            | Split (_axis, _parts) -> failwith "Split"
            | Draw (_axis, _n) -> failwith "Draw"
            | Map _f -> failwith "Map"
            | Fold (_axis, _f) -> failwith "Fold"
            | Scan (_axis, _f) -> failwith "Scan"
            | OneHot depth -> _eval_map_01 node (fun ~out x -> A.one_hot_ ~out depth x.(0))
            | OfArray s -> _eval_map_09 node (fun x -> A.of_array x s)
            | Delay f -> _eval_map_08 node f
            | DelayArray (_shape, f) -> _eval_map_00 node f
            (* NOTE(review): the payload is destructured as (max_col, max_row, ...);
               confirm that this matches the field order in the LazyPrint
               constructor declaration, which is not visible in this file. *)
            | LazyPrint (max_col, max_row, header, fmt) ->
              _eval_map_00 node (fun x ->
                  A.print ?max_col ?max_row ?header ?fmt x.(0);
                  x.(0))
            (* element-wise unary maths *)
            | Abs -> _eval_map_01 node (fun ~out x -> A.abs_ ~out x.(0))
            | Neg -> _eval_map_01 node (fun ~out x -> A.neg_ ~out x.(0))
            | Floor -> _eval_map_01 node (fun ~out x -> A.floor_ ~out x.(0))
            | Ceil -> _eval_map_01 node (fun ~out x -> A.ceil_ ~out x.(0))
            | Round -> _eval_map_01 node (fun ~out x -> A.round_ ~out x.(0))
            | Sqr -> _eval_map_01 node (fun ~out x -> A.sqr_ ~out x.(0))
            | Sqrt -> _eval_map_01 node (fun ~out x -> A.sqrt_ ~out x.(0))
            | Log -> _eval_map_01 node (fun ~out x -> A.log_ ~out x.(0))
            | Log2 -> _eval_map_01 node (fun ~out x -> A.log2_ ~out x.(0))
            | Log10 -> _eval_map_01 node (fun ~out x -> A.log10_ ~out x.(0))
            | Exp -> _eval_map_01 node (fun ~out x -> A.exp_ ~out x.(0))
            | Sin -> _eval_map_01 node (fun ~out x -> A.sin_ ~out x.(0))
            | Cos -> _eval_map_01 node (fun ~out x -> A.cos_ ~out x.(0))
            | Tan -> _eval_map_01 node (fun ~out x -> A.tan_ ~out x.(0))
            | Sinh -> _eval_map_01 node (fun ~out x -> A.sinh_ ~out x.(0))
            | Cosh -> _eval_map_01 node (fun ~out x -> A.cosh_ ~out x.(0))
            | Tanh -> _eval_map_01 node (fun ~out x -> A.tanh_ ~out x.(0))
            | Asin -> _eval_map_01 node (fun ~out x -> A.asin_ ~out x.(0))
            | Acos -> _eval_map_01 node (fun ~out x -> A.acos_ ~out x.(0))
            | Atan -> _eval_map_01 node (fun ~out x -> A.atan_ ~out x.(0))
            | Asinh -> _eval_map_01 node (fun ~out x -> A.asinh_ ~out x.(0))
            | Acosh -> _eval_map_01 node (fun ~out x -> A.acosh_ ~out x.(0))
            | Atanh -> _eval_map_01 node (fun ~out x -> A.atanh_ ~out x.(0))
            (* reductions *)
            | Min (keep_dims, axis) ->
              (* With keep_dims the existing output buffer is reused through
                 _eval_map_01; otherwise a fresh value is created through
                 _eval_map_00. *)
              (* TODO: implement keep_dims for A.min_ in Ndarray to reuse memory *)
              if keep_dims
              then _eval_map_01 node (fun ~out x -> A.min_ ~out ~axis x.(0))
              else _eval_map_00 node (fun x -> A.min ~keep_dims ~axis x.(0))
            | Max (keep_dims, axis) ->
              (* With keep_dims the existing output buffer is reused through
                 _eval_map_01; otherwise a fresh value is created through
                 _eval_map_00. *)
              (* TODO: implement keep_dims for A.max_ in Ndarray to reuse memory *)
              if keep_dims
              then _eval_map_01 node (fun ~out x -> A.max_ ~out ~axis x.(0))
              else _eval_map_00 node (fun x -> A.max ~keep_dims ~axis x.(0))
            | Sum (keep_dims, axis) ->
              (* With keep_dims the existing output buffer is reused through
                 _eval_map_01; otherwise a fresh value is created through
                 _eval_map_00. *)
              (* TODO: implement keep_dims for A.sum_ in Ndarray to reuse memory *)
              if keep_dims
              then _eval_map_01 node (fun ~out x -> A.sum_ ~out ~axis x.(0))
              else _eval_map_00 node (fun x -> A.sum ~keep_dims ~axis x.(0))
            | SumReduce axis -> _eval_map_00 node (fun x -> A.sum_reduce ~axis x.(0))
            | Signum -> _eval_map_01 node (fun ~out x -> A.signum_ ~out x.(0))
            | Sigmoid -> _eval_map_01 node (fun ~out x -> A.sigmoid_ ~out x.(0))
            | Relu -> _eval_map_01 node (fun ~out x -> A.relu_ ~out x.(0))
            | Min' -> _eval_map_06 node A.min'
            | Max' -> _eval_map_06 node A.max'
            | Sum' -> _eval_map_06 node A.sum'
            | L1norm' -> _eval_map_06 node A.l1norm'
            | L2norm' -> _eval_map_06 node A.l2norm'
            | L2NormSqr' -> _eval_map_06 node A.l2norm_sqr'
            | ClipByValue ->
              _eval_map_07 node (fun ~out a e ->
                  A.clip_by_value_ ~out ~amin:e.(0) ~amax:e.(1) a.(0))
            | ClipByL2norm ->
              _eval_map_07 node (fun ~out a e -> A.clip_by_l2norm_ ~out e.(0) a.(0))
            (* element-wise binary maths *)
            | Pow -> _eval_map_01 node (fun ~out x -> A.pow_ ~out x.(0) x.(1))
            | ScalarPow -> _eval_map_05 node A.scalar_pow_
            | PowScalar -> _eval_map_04 node A.pow_scalar_
            | Atan2 -> _eval_map_01 node (fun ~out x -> A.atan2_ ~out x.(0) x.(1))
            | ScalarAtan2 -> _eval_map_05 node A.scalar_atan2_
            | Atan2Scalar -> _eval_map_04 node A.atan2_scalar_
            | Hypot -> _eval_map_01 node (fun ~out x -> A.hypot_ ~out x.(0) x.(1))
            | Min2 -> _eval_map_01 node (fun ~out x -> A.min2_ ~out x.(0) x.(1))
            | Max2 -> _eval_map_01 node (fun ~out x -> A.max2_ ~out x.(0) x.(1))
            | Add -> _eval_map_01 node (fun ~out x -> A.add_ ~out x.(0) x.(1))
            | Sub -> _eval_map_01 node (fun ~out x -> A.sub_ ~out x.(0) x.(1))
            | Mul -> _eval_map_01 node (fun ~out x -> A.mul_ ~out x.(0) x.(1))
            | Div -> _eval_map_01 node (fun ~out x -> A.div_ ~out x.(0) x.(1))
            | AddScalar -> _eval_map_04 node A.add_scalar_
            | SubScalar -> _eval_map_04 node A.sub_scalar_
            | MulScalar -> _eval_map_04 node A.mul_scalar_
            | DivScalar -> _eval_map_04 node A.div_scalar_
            | ScalarAdd -> _eval_map_05 node A.scalar_add_
            | ScalarSub -> _eval_map_05 node A.scalar_sub_
            | ScalarMul -> _eval_map_05 node A.scalar_mul_
            | ScalarDiv -> _eval_map_05 node A.scalar_div_
            | FMA -> _eval_map_01 node (fun ~out x -> A.fma_ ~out x.(0) x.(1) x.(2))
            (* element-wise comparisons *)
            | EltEqual -> _eval_map_01 node (fun ~out x -> A.elt_equal_ ~out x.(0) x.(1))
            | EltNotEqual ->
              _eval_map_01 node (fun ~out x -> A.elt_not_equal_ ~out x.(0) x.(1))
            | EltLess -> _eval_map_01 node (fun ~out x -> A.elt_less_ ~out x.(0) x.(1))
            | EltGreater -> _eval_map_01 node (fun ~out x -> A.elt_greater_ ~out x.(0) x.(1))
            | EltLessEqual ->
              _eval_map_01 node (fun ~out x -> A.elt_less_equal_ ~out x.(0) x.(1))
            | EltGreaterEqual ->
              _eval_map_01 node (fun ~out x -> A.elt_greater_equal_ ~out x.(0) x.(1))
            | EltEqualScalar -> _eval_map_04 node A.elt_equal_scalar_
            | EltNotEqualScalar -> _eval_map_04 node A.elt_not_equal_scalar_
            | EltLessScalar -> _eval_map_04 node A.elt_less_scalar_
            | EltGreaterScalar -> _eval_map_04 node A.elt_greater_scalar_
            | EltLessEqualScalar -> _eval_map_04 node A.elt_less_equal_scalar_
            | EltGreaterEqualScalar -> _eval_map_04 node A.elt_greater_equal_scalar_
            (* convolution, pooling and upsampling: forward *)
            | Conv1d (padding, stride) ->
              _eval_map_01 node (fun ~out x -> A.conv1d_ ~out ~padding x.(0) x.(1) stride)
            | Conv2d (padding, stride) ->
              _eval_map_01 node (fun ~out x -> A.conv2d_ ~out ~padding x.(0) x.(1) stride)
            | Conv3d (padding, stride) ->
              _eval_map_01 node (fun ~out x -> A.conv3d_ ~out ~padding x.(0) x.(1) stride)
            | TransposeConv1d (padding, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv1d_ ~out ~padding x.(0) x.(1) stride)
            | TransposeConv2d (padding, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv2d_ ~out ~padding x.(0) x.(1) stride)
            | TransposeConv3d (padding, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv3d_ ~out ~padding x.(0) x.(1) stride)
            | DilatedConv1d (padding, stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv1d_ ~out ~padding x.(0) x.(1) stride rate)
            | DilatedConv2d (padding, stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv2d_ ~out ~padding x.(0) x.(1) stride rate)
            | DilatedConv3d (padding, stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv3d_ ~out ~padding x.(0) x.(1) stride rate)
            | MaxPool1d (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x -> A.max_pool1d_ ~out ~padding x.(0) kernel stride)
            | MaxPool2d (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x -> A.max_pool2d_ ~out ~padding x.(0) kernel stride)
            | MaxPool3d (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x -> A.max_pool3d_ ~out ~padding x.(0) kernel stride)
            | AvgPool1d (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x -> A.avg_pool1d_ ~out ~padding x.(0) kernel stride)
            | AvgPool2d (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x -> A.avg_pool2d_ ~out ~padding x.(0) kernel stride)
            | AvgPool3d (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x -> A.avg_pool3d_ ~out ~padding x.(0) kernel stride)
            | UpSampling2d size ->
              _eval_map_01 node (fun ~out x -> A.upsampling2d_ ~out x.(0) size)
            (* convolution, pooling and upsampling: backward *)
            | Conv1dBackwardInput stride ->
              _eval_map_01 node (fun ~out x ->
                  A.conv1d_backward_input_ ~out x.(0) x.(1) stride x.(2))
            | Conv1dBackwardKernel stride ->
              _eval_map_01 node (fun ~out x ->
                  A.conv1d_backward_kernel_ ~out x.(0) x.(1) stride x.(2))
            | Conv2dBackwardInput stride ->
              _eval_map_01 node (fun ~out x ->
                  A.conv2d_backward_input_ ~out x.(0) x.(1) stride x.(2))
            | Conv2dBackwardKernel stride ->
              _eval_map_01 node (fun ~out x ->
                  A.conv2d_backward_kernel_ ~out x.(0) x.(1) stride x.(2))
            | Conv3dBackwardInput stride ->
              _eval_map_01 node (fun ~out x ->
                  A.conv3d_backward_input_ ~out x.(0) x.(1) stride x.(2))
            | Conv3dBackwardKernel stride ->
              _eval_map_01 node (fun ~out x ->
                  A.conv3d_backward_kernel_ ~out x.(0) x.(1) stride x.(2))
            | TransposeConv1dBackwardInput stride ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv1d_backward_input_ ~out x.(0) x.(1) stride x.(2))
            | TransposeConv1dBackwardKernel stride ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv1d_backward_kernel_ ~out x.(0) x.(1) stride x.(2))
            | TransposeConv2dBackwardInput stride ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv2d_backward_input_ ~out x.(0) x.(1) stride x.(2))
            | TransposeConv2dBackwardKernel stride ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv2d_backward_kernel_ ~out x.(0) x.(1) stride x.(2))
            | TransposeConv3dBackwardInput stride ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv3d_backward_input_ ~out x.(0) x.(1) stride x.(2))
            | TransposeConv3dBackwardKernel stride ->
              _eval_map_01 node (fun ~out x ->
                  A.transpose_conv3d_backward_kernel_ ~out x.(0) x.(1) stride x.(2))
            | DilatedConv1dBackwardInput (stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv1d_backward_input_ ~out x.(0) x.(1) stride rate x.(2))
            | DilatedConv1dBackwardKernel (stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv1d_backward_kernel_ ~out x.(0) x.(1) stride rate x.(2))
            | DilatedConv2dBackwardInput (stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv2d_backward_input_ ~out x.(0) x.(1) stride rate x.(2))
            | DilatedConv2dBackwardKernel (stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv2d_backward_kernel_ ~out x.(0) x.(1) stride rate x.(2))
            | DilatedConv3dBackwardInput (stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv3d_backward_input_ ~out x.(0) x.(1) stride rate x.(2))
            | DilatedConv3dBackwardKernel (stride, rate) ->
              _eval_map_01 node (fun ~out x ->
                  A.dilated_conv3d_backward_kernel_ ~out x.(0) x.(1) stride rate x.(2))
            | MaxPool1dBackward (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.max_pool1d_backward_ ~out padding x.(0) kernel stride x.(1))
            | MaxPool2dBackward (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.max_pool2d_backward_ ~out padding x.(0) kernel stride x.(1))
            | MaxPool3dBackward (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.max_pool3d_backward_ ~out padding x.(0) kernel stride x.(1))
            | AvgPool1dBackward (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.avg_pool1d_backward_ ~out padding x.(0) kernel stride x.(1))
            | AvgPool2dBackward (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.avg_pool2d_backward_ ~out padding x.(0) kernel stride x.(1))
            | AvgPool3dBackward (padding, kernel, stride) ->
              _eval_map_01 node (fun ~out x ->
                  A.avg_pool3d_backward_ ~out padding x.(0) kernel stride x.(1))
            | UpSampling2dBackward size ->
              _eval_map_01 node (fun ~out x -> A.upsampling2d_backward_ ~out x.(0) size x.(1))
            (* matrix operations *)
            | Row -> failwith "Row"
            | Rows _i -> failwith "Rows"
            | CopyRowTo -> failwith "CopyRowTo"
            | CopyColTo -> failwith "CopyColTo"
            | Dot (transa, transb, alpha, beta) ->
              _eval_map_01 node (fun ~out x ->
                  A.dot_
                    ~transa
                    ~transb
                    ~alpha:(unpack_elt alpha)
                    ~beta:(unpack_elt beta)
                    ~c:out
                    x.(0)
                    x.(1))
            | Inv -> _eval_map_00 node (fun x -> A.Linalg.inv x.(0))
            | Trace -> _eval_map_06 node A.trace
            | Transpose axis -> _eval_map_01 node (fun ~out x -> A.transpose_ ~out ~axis x.(0))
            | ToRows -> failwith "ToRows"
            | OfRows -> failwith "OfRows"
            (* scalar maths *)
            | Scalar_Add -> _eval_map_03 node (fun x -> A.Scalar.add x.(0) x.(1))
            | Scalar_Sub -> _eval_map_03 node (fun x -> A.Scalar.sub x.(0) x.(1))
            | Scalar_Mul -> _eval_map_03 node (fun x -> A.Scalar.mul x.(0) x.(1))
            | Scalar_Div -> _eval_map_03 node (fun x -> A.Scalar.div x.(0) x.(1))
            | Scalar_Pow -> _eval_map_03 node (fun x -> A.Scalar.pow x.(0) x.(1))
            | Scalar_Atan2 -> _eval_map_03 node (fun x -> A.Scalar.atan2 x.(0) x.(1))
            | Scalar_Abs -> _eval_map_03 node (fun x -> A.Scalar.abs x.(0))
            | Scalar_Neg -> _eval_map_03 node (fun x -> A.Scalar.neg x.(0))
            | Scalar_Sqr -> _eval_map_03 node (fun x -> A.Scalar.sqr x.(0))
            | Scalar_Sqrt -> _eval_map_03 node (fun x -> A.Scalar.sqrt x.(0))
            | Scalar_Exp -> _eval_map_03 node (fun x -> A.Scalar.exp x.(0))
            | Scalar_Log -> _eval_map_03 node (fun x -> A.Scalar.log x.(0))
            | Scalar_Log2 -> _eval_map_03 node (fun x -> A.Scalar.log2 x.(0))
            | Scalar_Log10 -> _eval_map_03 node (fun x -> A.Scalar.log10 x.(0))
            | Scalar_Signum -> _eval_map_03 node (fun x -> A.Scalar.signum x.(0))
            | Scalar_Floor -> _eval_map_03 node (fun x -> A.Scalar.floor x.(0))
            | Scalar_Ceil -> _eval_map_03 node (fun x -> A.Scalar.ceil x.(0))
            | Scalar_Round -> _eval_map_03 node (fun x -> A.Scalar.round x.(0))
            | Scalar_Sin -> _eval_map_03 node (fun x -> A.Scalar.sin x.(0))
            | Scalar_Cos -> _eval_map_03 node (fun x -> A.Scalar.cos x.(0))
            | Scalar_Tan -> _eval_map_03 node (fun x -> A.Scalar.tan x.(0))
            | Scalar_Sinh -> _eval_map_03 node (fun x -> A.Scalar.sinh x.(0))
            | Scalar_Cosh -> _eval_map_03 node (fun x -> A.Scalar.cosh x.(0))
            | Scalar_Tanh -> _eval_map_03 node (fun x -> A.Scalar.tanh x.(0))
            | Scalar_Asin -> _eval_map_03 node (fun x -> A.Scalar.asin x.(0))
            | Scalar_Acos -> _eval_map_03 node (fun x -> A.Scalar.acos x.(0))
            | Scalar_Atan -> _eval_map_03 node (fun x -> A.Scalar.atan x.(0))
            | Scalar_Asinh -> _eval_map_03 node (fun x -> A.Scalar.asinh x.(0))
            | Scalar_Acosh -> _eval_map_03 node (fun x -> A.Scalar.acosh x.(0))
            | Scalar_Atanh -> _eval_map_03 node (fun x -> A.Scalar.atanh x.(0))
            | Scalar_Relu -> _eval_map_03 node (fun x -> A.Scalar.relu x.(0))
            | Scalar_Sigmoid -> _eval_map_03 node (fun x -> A.Scalar.sigmoid x.(0))
            (* fused operators *)
            | Fused_Adagrad (rate, eps) ->
              _eval_map_01 node (fun ~out x -> A.fused_adagrad_ ~out ~rate ~eps x.(0))
            | _ -> failwith "owl_computation_eval:_eval_term"
            end
          with
          | exn ->
            (* report which node failed, then propagate the same exception *)
            Owl_log.error "evaluating %s" (node_to_str node);
            raise exn
        in
        Array.iter (update_validity node) (get_block node)
      end
    in
    Array.iter eval nodes


  (* Pure [f : arr array -> arr]. The parents of [x] are evaluated, slot 0 of
     each parent value is read as an arr, and the result of [f] is stored as
     the new value of [x]. *)
  and _eval_map_00 x f =
    _eval_terms (parents x);
    let arr_args =
      parents x |> Array.map (fun p -> (get_value p).(0) |> value_to_arr)
    in
    set_value x [| arr_to_value (f arr_args) |]


  (* Impure [f : out:arr -> arr array -> _]. Same inputs as [_eval_map_00],
     but [f] receives slot 0 of the current value of [x] as [~out] and the
     value of [x] is not reassigned here. *)
  and _eval_map_01 x f =
    _eval_terms (parents x);
    let arr_args =
      parents x |> Array.map (fun p -> (get_value p).(0) |> value_to_arr)
    in
    let out = (get_value x).(0) |> value_to_arr in
    f ~out arr_args


  (* Impure [f : out:arr -> elt array -> _]. Like [_eval_map_01], with slot 0
     of each parent value read as an elt. *)
  and _eval_map_02 x f =
    _eval_terms (parents x);
    let elt_args =
      parents x |> Array.map (fun p -> (get_value p).(0) |> value_to_elt)
    in
    let out = (get_value x).(0) |> value_to_arr in
    f ~out elt_args


  (* Pure [f : elt array -> elt]. Parent values are read as elts and the
     result of [f] is stored as the new value of [x]. *)
  and _eval_map_03 x f =
    _eval_terms (parents x);
    let elt_args =
      parents x |> Array.map (fun p -> (get_value p).(0) |> value_to_elt)
    in
    set_value x [| elt_to_value (f elt_args) |]


  (* Impure [f : out:arr -> arr -> elt -> _]. Parent 0 supplies the arr and
     parent 1 the elt; both parents are looked up before any evaluation. *)
  and _eval_map_04 x f =
    let arr_parent = (parents x).(0) in
    let elt_parent = (parents x).(1) in
    _eval_terms (parents x);
    let lhs = (get_value arr_parent).(0) |> value_to_arr in
    let rhs = (get_value elt_parent).(0) |> value_to_elt in
    let out = (get_value x).(0) |> value_to_arr in
    f ~out lhs rhs


  (* Impure [f : out:arr -> elt -> arr -> _]. Parent 0 supplies the elt and
     parent 1 the arr; both parents are looked up before any evaluation. *)
  and _eval_map_05 x f =
    let elt_parent = (parents x).(0) in
    let arr_parent = (parents x).(1) in
    _eval_terms (parents x);
    let lhs = (get_value elt_parent).(0) |> value_to_elt in
    let rhs = (get_value arr_parent).(0) |> value_to_arr in
    let out = (get_value x).(0) |> value_to_arr in
    f ~out lhs rhs


  (* Pure [f : arr -> elt]. Only parent 0 is read; the elt returned by [f]
     is stored as the new value of [x]. *)
  and _eval_map_06 x f =
    let src = (parents x).(0) in
    _eval_terms (parents x);
    let result = f (value_to_arr (get_value src).(0)) in
    set_value x [| elt_to_value result |]


  (* Impure [f : out:arr -> arr array -> elt array -> _]. The parents are
     split with [is_node_arr] and [is_node_elt]; each group is passed to [f]
     as its own array. *)
  and _eval_map_07 x f =
    let ps = parents x in
    _eval_terms ps;
    let arr_args =
      Array.map
        (fun p -> value_to_arr (get_value p).(0))
        (Owl_utils_array.filter is_node_arr ps)
    in
    let elt_args =
      Array.map
        (fun p -> value_to_elt (get_value p).(0))
        (Owl_utils_array.filter is_node_elt ps)
    in
    let out = (get_value x).(0) |> value_to_arr in
    f ~out arr_args elt_args


  (* Pure [f : arr -> arr]. Only parent 0 is read; the arr returned by [f]
     is stored as the new value of [x]. *)
  and _eval_map_08 x f =
    let src = (parents x).(0) in
    _eval_terms (parents x);
    let result = f (value_to_arr (get_value src).(0)) in
    set_value x [| arr_to_value result |]


  (* Pure [f : elt array -> arr]. Parent values are read as elts and the arr
     returned by [f] is stored as the new value of [x]. *)
  and _eval_map_09 x f =
    _eval_terms (parents x);
    let elt_args =
      parents x |> Array.map (fun p -> (get_value p).(0) |> value_to_elt)
    in
    set_value x [| arr_to_value (f elt_args) |]
end

(* Make functor ends *)