# 1 "src/base/compute/owl_computation_symbol.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

open Owl_graph

(* Functor of making the symbols of a computation graph. *)

module Make (Shape : Owl_computation_shape_sig.Sig) = struct
  module Shape = Shape
  open Shape
  open Shape.Type
  open Shape.Type.Device

  (* string representation of symbols *)

  (* [op_to_str op] returns a human-readable name for the operator [op].
     Only [OneHot], [Min], [Max], [Sum] and [LogSumExp] also print their
     parameters; every other operator prints its name only. *)
  let op_to_str = function
    | Noop -> "Noop"
    | Var -> "Var"
    | Const -> "Const"
    | Empty _shape -> "Empty"
    | Zeros _shape -> "Zeros"
    | Ones _shape -> "Ones"
    | Create _shape -> "Create"
    | Sequential _shape -> "Sequential"
    | Uniform _shape -> "Uniform"
    | Gaussian _shape -> "Gaussian"
    | Bernoulli _shape -> "Bernoulli"
    | Init (_shape, _f) -> "Init"
    | Get _i -> "Get"
    | Set _i -> "Set"
    | GetSlice _slice -> "GetSlice"
    | SetSlice _slice -> "SetSlice"
    | GetFancy _ -> "GetFancy"
    | SetFancy _ -> "SetFancy"
    | Copy -> "Copy"
    | Reset -> "Reset"
    | Reshape _shape -> "Reshape"
    | Reverse -> "Reverse"
    | Tile _repeats -> "Tile"
    | Repeat _repeats -> "Repeat"
    | Concatenate _axis -> "Concatenate"
    | Stack _axis -> "Stack"
    | Split (_axis, _parts) -> "Split"
    | Draw (_axis, _n) -> "Draw"
    | Map _f -> "Map"
    | Fold (_axis, _f) -> "Fold"
    | Scan (_axis, _f) -> "Scan"
    | OneHot depth -> Printf.sprintf "OneHot d:%i" depth
    | OfArray _s -> "OfArray"
    | Delay _f -> "Delay"
    | DelayArray (_shape, _f) -> "DelayArray"
    | LazyPrint (_max_row, _max_col, _header, _fmt) -> "LazyPrint"
    | Abs -> "Abs"
    | Neg -> "Neg"
    | Floor -> "Floor"
    | Ceil -> "Ceil"
    | Round -> "Round"
    | Sqr -> "Sqr"
    | Sqrt -> "Sqrt"
    | Log -> "Log"
    | Log2 -> "Log2"
    | Log10 -> "Log10"
    | Exp -> "Exp"
    | Sin -> "Sin"
    | Cos -> "Cos"
    | Tan -> "Tan"
    | Sinh -> "Sinh"
    | Cosh -> "Cosh"
    | Tanh -> "Tanh"
    | Asin -> "Asin"
    | Acos -> "Acos"
    | Atan -> "Atan"
    | Asinh -> "Asinh"
    | Acosh -> "Acosh"
    | Atanh -> "Atanh"
    | Min (keep_dims, axis) ->
      Printf.sprintf "Min keep_dims: %b, axis:%i" keep_dims axis
    | Max (keep_dims, axis) ->
      Printf.sprintf "Max keep_dims: %b, axis:%i" keep_dims axis
    | Sum (keep_dims, axis) ->
      Printf.sprintf "Sum keep_dims: %b, axis:%i" keep_dims axis
    | SumReduce _axis -> "SumReduce"
    | Signum -> "Signum"
    | Sigmoid -> "Sigmoid"
    | Relu -> "Relu"
    | Dawsn -> "Dawsn"
    | Min' -> "Min'"
    | Max' -> "Max'"
    | Sum' -> "Sum'"
    | LogSumExp' -> "LogSumExp'"
    | LogSumExp (keep_dims, axis) ->
      Printf.sprintf "LogSumExp keep_dims: %b, axis:%i" keep_dims axis
    | L1norm' -> "L1norm'"
    | L2norm' -> "L2norm'"
    | L2NormSqr' -> "L2NormSqr'"
    | ClipByValue -> "ClipByValue"
    | ClipByL2norm -> "ClipByL2norm"
    | Pow -> "Pow"
    | ScalarPow -> "ScalarPow"
    | PowScalar -> "PowScalar"
    | Atan2 -> "Atan2"
    | ScalarAtan2 -> "ScalarAtan2"
    | Atan2Scalar -> "Atan2Scalar"
    | Hypot -> "Hypot"
    | Min2 -> "Min2"
    | Max2 -> "Max2"
    | Add -> "Add"
    | Sub -> "Sub"
    | Mul -> "Mul"
    | Div -> "Div"
    | AddScalar -> "AddScalar"
    | SubScalar -> "SubScalar"
    | MulScalar -> "MulScalar"
    | DivScalar -> "DivScalar"
    | ScalarAdd -> "ScalarAdd"
    | ScalarSub -> "ScalarSub"
    | ScalarMul -> "ScalarMul"
    | ScalarDiv -> "ScalarDiv"
    | FMA -> "FMA"
    | EltEqual -> "EltEqual"
    | EltNotEqual -> "EltNotEqual"
    | EltLess -> "EltLess"
    | EltGreater -> "EltGreater"
    | EltLessEqual -> "EltLessEqual"
    | EltGreaterEqual -> "EltGreaterEqual"
    | EltEqualScalar -> "EltEqualScalar"
    | EltNotEqualScalar -> "EltNotEqualScalar"
    | EltLessScalar -> "EltLessScalar"
    | EltGreaterScalar -> "EltGreaterScalar"
    | EltLessEqualScalar -> "EltLessEqualScalar"
    | EltGreaterEqualScalar -> "EltGreaterEqualScalar"
    | Conv1d (_padding, _stride) -> "Conv1d"
    | Conv2d (_padding, _stride) -> "Conv2d"
    | Conv3d (_padding, _stride) -> "Conv3d"
    | TransposeConv1d (_padding, _stride) -> "TransposeConv1d"
    | TransposeConv2d (_padding, _stride) -> "TransposeConv2d"
    | TransposeConv3d (_padding, _stride) -> "TransposeConv3d"
    | DilatedConv1d (_padding, _stride, _rate) -> "DilatedConv1d"
    | DilatedConv2d (_padding, _stride, _rate) -> "DilatedConv2d"
    | DilatedConv3d (_padding, _stride, _rate) -> "DilatedConv3d"
    | MaxPool1d (_padding, _kernel, _stride) -> "MaxPool1d"
    | MaxPool2d (_padding, _kernel, _stride) -> "MaxPool2d"
    | MaxPool3d (_padding, _kernel, _stride) -> "MaxPool3d"
    | AvgPool1d (_padding, _kernel, _stride) -> "AvgPool1d"
    | AvgPool2d (_padding, _kernel, _stride) -> "AvgPool2d"
    | AvgPool3d (_padding, _kernel, _stride) -> "AvgPool3d"
    | UpSampling2d _size -> "UpSampling2d"
    | Conv1dBackwardInput _stride -> "Conv1dBackwardInput"
    | Conv1dBackwardKernel _stride -> "Conv1dBackwardKernel"
    | Conv2dBackwardInput _stride -> "Conv2dBackwardInput"
    | Conv2dBackwardKernel _stride -> "Conv2dBackwardKernel"
    | Conv3dBackwardInput _stride -> "Conv3dBackwardInput"
    | Conv3dBackwardKernel _stride -> "Conv3dBackwardKernel"
    | TransposeConv1dBackwardInput _stride -> "TransposeConv1dBackwardInput"
    | TransposeConv1dBackwardKernel _stride -> "TransposeConv1dBackwardKernel"
    | TransposeConv2dBackwardInput _stride -> "TransposeConv2dBackwardInput"
    | TransposeConv2dBackwardKernel _stride -> "TransposeConv2dBackwardKernel"
    | TransposeConv3dBackwardInput _stride -> "TransposeConv3dBackwardInput"
    | TransposeConv3dBackwardKernel _stride -> "TransposeConv3dBackwardKernel"
    | DilatedConv1dBackwardInput (_stride, _rate) -> "DilatedConv1dBackwardInput"
    | DilatedConv1dBackwardKernel (_stride, _rate) -> "DilatedConv1dBackwardKernel"
    | DilatedConv2dBackwardInput (_stride, _rate) -> "DilatedConv2dBackwardInput"
    | DilatedConv2dBackwardKernel (_stride, _rate) -> "DilatedConv2dBackwardKernel"
    | DilatedConv3dBackwardInput (_stride, _rate) -> "DilatedConv3dBackwardInput"
    | DilatedConv3dBackwardKernel (_stride, _rate) -> "DilatedConv3dBackwardKernel"
    | MaxPool1dBackward (_padding, _kernel, _stride) -> "MaxPool1dBackward"
    | MaxPool2dBackward (_padding, _kernel, _stride) -> "MaxPool2dBackward"
    | MaxPool3dBackward (_padding, _kernel, _stride) -> "MaxPool3dBackward"
    | AvgPool1dBackward (_padding, _kernel, _stride) -> "AvgPool1dBackward"
    | AvgPool2dBackward (_padding, _kernel, _stride) -> "AvgPool2dBackward"
    | AvgPool3dBackward (_padding, _kernel, _stride) -> "AvgPool3dBackward"
    | UpSampling2dBackward _size -> "UpSampling2dBackward"
    | Pad (_v, _padding) -> "Pad"
    | RowNum -> "RowNum"
    | ColNum -> "ColNum"
    | Row -> "Row"
    | Rows _i -> "Rows"
    | CopyRowTo -> "CopyRowTo"
    | CopyColTo -> "CopyColTo"
    | Dot (_transa, _transb, _alpha, _beta) -> "Dot"
    | Inv -> "Inv"
    | Trace -> "Trace"
    | Transpose _i -> "Transpose"
    | ToRows -> "ToRows"
    | OfRows -> "OfRows"
    | Scalar_Add -> "Scalar Add"
    | Scalar_Sub -> "Scalar Sub"
    | Scalar_Mul -> "Scalar Mul"
    | Scalar_Div -> "Scalar Div"
    | Scalar_Pow -> "Scalar Pow"
    | Scalar_Atan2 -> "Scalar Atan2"
    | Scalar_Abs -> "Scalar Abs"
    | Scalar_Neg -> "Scalar Neg"
    | Scalar_Sqr -> "Scalar Sqr"
    | Scalar_Sqrt -> "Scalar Sqrt"
    | Scalar_Exp -> "Scalar Exp"
    | Scalar_Log -> "Scalar Log"
    | Scalar_Log2 -> "Scalar Log2"
    | Scalar_Log10 -> "Scalar Log10"
    | Scalar_Signum -> "Scalar Signum"
    | Scalar_Floor -> "Scalar Floor"
    | Scalar_Ceil -> "Scalar Ceil"
    | Scalar_Round -> "Scalar Round"
    | Scalar_Sin -> "Scalar Sin"
    | Scalar_Cos -> "Scalar Cos"
    | Scalar_Tan -> "Scalar Tan"
    | Scalar_Sinh -> "Scalar Sinh"
    | Scalar_Cosh -> "Scalar Cosh"
    | Scalar_Tanh -> "Scalar Tanh"
    | Scalar_Asin -> "Scalar Asin"
    | Scalar_Acos -> "Scalar Acos"
    | Scalar_Atan -> "Scalar Atan"
    | Scalar_Asinh -> "Scalar Asinh"
    | Scalar_Acosh -> "Scalar Acosh"
    | Scalar_Atanh -> "Scalar Atanh"
    | Scalar_Relu -> "Scalar Relu"
    | Scalar_Dawsn -> "Scalar Dawsn"
    | Scalar_Sigmoid -> "Scalar Sigmoid"
    | Fused_Adagrad (_rate, _eps) -> "Fused_Adagrad"


  (* utility functions *)

  (* [is_random_variable op] is [true] exactly for the [Uniform], [Gaussian]
     and [Bernoulli] operators. *)
  let is_random_variable = function
    | Uniform _shape   -> true
    | Gaussian _shape  -> true
    | Bernoulli _shape -> true
    | _                -> false


  (* [refnum x] is the out-degree of [x] in the graph. *)
  let refnum x = Owl_graph.outdegree x


  (* [node_shape x] returns the shape stored in the first slot of [x]'s shape
     attribute. It fails through [Owl_exception.check] with [INVALID_ARGUMENT]
     when the shape array is empty, and with [Failure] when the shape is still
     unknown ([None]). *)
  let node_shape x =
    let x_shape = (attr x).shape in
    let p = Array.length x_shape > 0 in
    let s = "shape information is missing." in
    Owl_exception.(check p (INVALID_ARGUMENT s));
    match x_shape.(0) with
    | Some s -> s
    | None   -> failwith "Owl_computation_symbol:node_shape"


  (* [node_numel x] is the product of the dimensions in [node_shape x]. For a
     scalar shape [[||]] this is 1. *)
  let node_numel x = Array.fold_left ( * ) 1 (node_shape x)


  (* NOTE(review): despite its name, this returns [true] when the shape IS
     known ([Some _]) and [false] when it is unknown ([None]).
     [infer_shape_graph] below relies on this inverted meaning. The function
     may also be called from outside this file, so its behaviour is kept
     unchanged here. *)
  let is_shape_unknown x =
    let x_shape = (attr x).shape in
    match x_shape.(0) with
    | Some _ -> true
    | None   -> false


  (* [infer_shape_graph xs] visits the nodes reached by [iter_descendants] from
     [xs]. For every visited node whose shape is still [None], it recomputes
     the shape with [infer_shape] from the node's operator and its parents.
     Nodes with a known shape are left untouched. *)
  let infer_shape_graph xs =
    iter_descendants
      (fun x ->
        (* [is_shape_unknown x = false] means the shape is [None]; see the
           note on [is_shape_unknown]. *)
        if is_shape_unknown x = false
        then (
          let x_attr = attr x in
          let x_parents = parents x in
          x_attr.shape <- infer_shape x_attr.op x_parents))
      xs


  (* [shape_to_str shp] formats the first shape in [shp] as ["[d0,d1,...]"],
     or as ["[unknown]"] when that shape is [None]. It fails through
     [Owl_exception.check] when [shp] is empty. *)
  let shape_to_str shp =
    let p = Array.length shp > 0 in
    let s = "shape information is missing." in
    Owl_exception.(check p (INVALID_ARGUMENT s));
    let s =
      match shp.(0) with
      | Some s -> Owl_utils_array.to_string string_of_int s
      | None   -> "unknown"
    in
    Printf.sprintf "[%s]" s


  (* [node_to_str n] returns a one-line debug summary of [n]: its id, name,
     operator, validity state, reference count (out-degree) and shape. *)
  let node_to_str n =
    let attr = attr n in
    let shape_s = shape_to_str attr.shape in
    let state_s = if attr.state = Valid then "valid" else "invalid" in
    Printf.sprintf
      "[ #%i | name:%s | op:%s | state:%s | r:%i | s:%s ]"
      (id n)
      (name n)
      (op_to_str attr.op)
      state_s
      (refnum n)
      shape_s


  (* core manipulation functions *)

  (* Wrap / unwrap graph nodes as the typed [Arr] and [Elt] handles. *)
  let node_to_arr x = Arr x

  let arr_to_node = function
    | Arr x -> x


  let node_to_elt x = Elt x

  let elt_to_node = function
    | Elt x -> x


  (* [new_block_id ()] returns a fresh block id from a module-level counter.
     The first call returns 1 and every later call returns the previous id
     plus one. *)
  let new_block_id =
    let _global_block_id = ref 0 in
    fun () ->
      _global_block_id := !_global_block_id + 1;
      !_global_block_id


  (* Meant for reusable nodes. *)
  (* [make_empty_block ?block_id size] creates a block backed by a fresh
     one-dimensional array of [size] elements, with no active node and no
     attached nodes. When [block_id] is omitted, [new_block_id ()] supplies
     the id. *)
  let make_empty_block ?block_id size =
    let block_id =
      match block_id with
      | Some block_id -> block_id
      | None          -> new_block_id ()
    in
    (* allocate a one-dimensional array *)
    let memory = arr_to_value (A.empty [| size |]) in
    { size; block_id; active = None; memory; nodes = [] }


  (* This is meant for nodes that are not reusable: memory is not reshaped. *)
  (* [make_value_block memory x] wraps [memory] in a new block owned by [x].
     It then sets both [x]'s value and [x]'s block to it. The block size is 1
     for an element, otherwise the element count of the array. *)
  let make_value_block memory x =
    let block_id = new_block_id () in
    let size = if is_elt memory then 1 else A.numel (value_to_arr memory) in
    let block = { size; block_id; active = Some x; memory; nodes = [ x ] } in
    (attr x).value <- [| memory |];
    (attr x).block <- Some [| block |]


  (* [make_node ?name ?value ?shape ?freeze ?reuse ?state op] creates a graph
     node for operator [op]. Omitted arguments default to: shape [[|None|]],
     state [Invalid], reuse [true], freeze [false], value [[||]]. When [value]
     is non-empty, its first element is placed in a fresh block via
     [make_value_block]. *)
  let make_node ?name ?value ?shape ?freeze ?reuse ?state op =
    let shape =
      match shape with
      | Some s -> s
      | None   -> [| None |]
    in
    let state =
      match state with
      | Some s -> s
      | None   -> Invalid
    in
    let reuse =
      match reuse with
      | Some s -> s
      | None   -> true
    in
    let freeze =
      match freeze with
      | Some s -> s
      | None   -> false
    in
    let value =
      match value with
      | Some v -> v
      | None   -> [||]
    in
    let attr = { op; freeze; reuse; state; shape; value; block = None } in
    let node = Owl_graph.node ?name attr in
    if value <> [||] then make_value_block value.(0) node;
    node


  (* [make_then_connect ?shape op parents] creates a node for [op] and wires
     it into the graph below [parents]. When [shape] is omitted it is inferred
     from [op] and [parents]. *)
  let make_then_connect ?shape op parents =
    let shape =
      match shape with
      | Some s -> s
      | None   -> infer_shape op parents
    in
    let child = make_node ~shape op in
    (* define the dependency of operation, can have duplicates *)
    connect_ancestors parents [| child |];
    (* define the flow of computation graph, no duplicates *)
    let uniq_parents = Owl_utils_array.unique parents in
    (* frozen parents do not get [child] registered as a descendant *)
    Array.iter
      (fun parent ->
        if (attr parent).freeze = false
        then connect_descendants [| parent |] [| child |])
      uniq_parents;
    child


  (* [var_arr ?shape name] creates a non-reusable [Var] array node. The shape
     stays [None] (unknown) when [shape] is omitted. *)
  let var_arr ?shape name =
    make_node ~name ~shape:[| shape |] ~reuse:false Var |> node_to_arr


  (* [var_elt name] creates a non-reusable [Var] scalar node (shape [[||]]). *)
  let var_elt name = make_node ~name ~shape:[| Some [||] |] ~reuse:false Var |> node_to_elt


  (* [const_arr name v] creates a frozen, non-reusable, already-valid [Const]
     array node that holds [v] with [v]'s shape. *)
  let const_arr name v =
    let value = [| arr_to_value v |] in
    let shape = [| Some A.(shape v) |] in
    make_node ~name ~value ~shape ~freeze:true ~reuse:false ~state:Valid Const
    |> node_to_arr


  (* [const_elt name v] creates a frozen, non-reusable, already-valid [Const]
     scalar node that holds [v]. *)
  let const_elt name v =
    let value = [| elt_to_value v |] in
    let shape = [| Some [||] |] in
    make_node ~name ~value ~shape ~freeze:true ~reuse:false ~state:Valid Const
    |> node_to_elt


  (* Block accessors. *)
  let get_nodes_using_block b = b.nodes

  let _get_value_block b = b.memory

  let get_block_opt x = (attr x).block


  (* [get_block x] returns [x]'s blocks, or fails with [Failure] when no block
     has been assigned. *)
  let get_block x =
    match get_block_opt x with
    | Some b -> b
    | None   -> failwith "Symbol:get_block_exn: block not assigned"


  let _set_block x b = (attr x).block <- Some b


  (* [add_node_to_block x block] makes [x] share [block]'s memory. [x]'s value
     becomes the first [node_numel x] elements of the block memory, reshaped
     to [x]'s shape. [x] is prepended to the block's node list, and its block
     is set to [block]. *)
  let add_node_to_block x block =
    let dst_shp = node_shape x in
    let dst_numel = node_numel x in
    let src_val = value_to_arr (_get_value_block block) in
    (* allocate the first [dst_numel] elements for the memory of the node *)
    let dst_val = arr_to_value (A.reshape (A.sub_left src_val 0 dst_numel) dst_shp) in
    block.nodes <- x :: block.nodes;
    _set_block x [| block |];
    (attr x).value <- [| dst_val |]


  let get_active_node b = b.active

  let set_active_node b x = b.active <- Some x


  (* [get_block_id x] is the id of [x]'s first block, or [-1] when [x] has no
     block yet. *)
  let get_block_id x =
    match get_block_opt x with
    | Some bs -> bs.(0).block_id
    | None    -> -1


  (* [set_value x v] stores [v.(0)] as [x]'s value.
     - Array with an existing block: the data is copied into the existing
       array value of [x].
     - Element with an existing block: the value and the block memory are
       both replaced by [v].
     - No block yet: a new block is created around [v.(0)]. *)
  let set_value x v =
    if is_arr v.(0)
    then (
      match get_block_opt x with
      | Some _ ->
        let xv = value_to_arr (attr x).value.(0) in
        let vv = value_to_arr v.(0) in
        A.copy_ ~out:xv vv
      | None   -> make_value_block v.(0) x)
    else (
      match get_block_opt x with
      | Some bs ->
        (attr x).value <- v;
        bs.(0).memory <- v.(0)
      | None    -> make_value_block v.(0) x)


  let get_value x = (attr x).value

  let set_operator x op = (attr x).op <- op

  let get_operator x = (attr x).op


  (* [set_reuse x reuse] sets [x]'s reuse flag. For [Var] and [Const] nodes,
     which are created with [~reuse:false], the request is ignored and a
     warning is logged.
     Fix: the previous guard was [op = Var && op = Const], which can never be
     true, so the warning path was unreachable. *)
  let set_reuse x reuse =
    let op = (attr x).op in
    if op = Var || op = Const
    then Owl_log.warn "set_reuse: ignored, %s" (node_to_str x)
    else (attr x).reuse <- reuse


  let get_reuse x = (attr x).reuse


  (* [is_shared x] is [true] when [x]'s first block is used by at least two
     nodes. *)
  let is_shared x =
    match get_block_opt x with
    | Some bs ->
      (match get_nodes_using_block bs.(0) with
      | _ :: _ :: _ -> true (* at least 2 elements *)
      | _           -> false)
    | None    -> false


  (* [get_shared_nodes x] lists the nodes using [x]'s first block. When [x]
     has no block, the result is just [[|x|]]. *)
  let get_shared_nodes x =
    match get_block_opt x with
    | Some bs -> Array.of_list (get_nodes_using_block bs.(0))
    | None    -> [| x |]
  (* contains itself *)


  let is_var x = (attr x).op = Var

  let is_const x = (attr x).op = Const


  (* TODO: change it to rely on the operator. *)
  (* [is_node_arr x] is [false] for a scalar shape [[||]] and [true] for any
     other known shape. It fails with [Failure] when the shape is unknown. *)
  let is_node_arr x =
    match (attr x).shape.(0) with
    | Some [||] -> false
    | Some _    -> true
    | _         -> failwith "Owl_computation_symbol:is_arr"


  (* [is_node_elt x] is [true] for a scalar shape [[||]] and [false] for any
     other known shape. It fails with [Failure] when the shape is unknown. *)
  let is_node_elt x =
    match (attr x).shape.(0) with
    | Some [||] -> true
    | Some _    -> false
    | _         -> failwith "Owl_computation_symbol:is_elt"


  (* [is_assigned x] is [true] once [x] has a block. *)
  let is_assigned x =
    match get_block_opt x with
    | Some _ -> true
    | None   -> false


  (* [check_assigned x] logs an error and fails with [Failure] when [x] has
     no block. *)
  let check_assigned x =
    if not (is_assigned x)
    then (
      Owl_log.error "value not assigned: %s" (node_to_str x);
      failwith "owl_computation_symbol:check_assigned")


  (* Validity state. *)
  let is_valid x = (attr x).state = Valid

  let validate x = (attr x).state <- Valid

  let invalidate x = (attr x).state <- Invalid


  (* [invalidate_graph x] marks the nodes reached by [iter_descendants] from
     [x] as [Invalid]. *)
  let invalidate_graph x = iter_descendants invalidate [| x |]


  (* Freeze flag. *)
  let is_freeze x = (attr x).freeze

  let freeze x = (attr x).freeze <- true

  let freeze_descendants x = iter_descendants freeze x

  let freeze_ancestors x = iter_ancestors freeze x


  (* [pack_arr arr] wraps [arr] as an unnamed constant array node. *)
  let pack_arr arr = const_arr "" arr


  (* [unpack_arr x] returns the array value of [x]. It logs an error and fails
     through [Owl_exception.check] when [x] holds no value yet. *)
  let unpack_arr x =
    let value = arr_to_node x |> get_value in
    let valen = Array.length value in
    if valen = 0
    then (
      Owl_log.error "not evaluated: %s" (arr_to_node x |> node_to_str);
      let s = Printf.sprintf "%s" (arr_to_node x |> node_to_str) in
      Owl_exception.(check (valen > 0) (INVALID_ARGUMENT s)));
    value_to_arr value.(0)


  (* [pack_elt elt] wraps [elt] as an unnamed constant scalar node. *)
  let pack_elt elt = const_elt "" elt


  (* [unpack_elt x] returns the element value of [x]. It logs an error and
     fails through [Owl_exception.check] when [x] holds no value yet. *)
  let unpack_elt x =
    let value = elt_to_node x |> get_value in
    let valen = Array.length value in
    if valen = 0
    then (
      Owl_log.error "not evaluated: %s" (elt_to_node x |> node_to_str);
      let s = Printf.sprintf "%s" (elt_to_node x |> node_to_str) in
      Owl_exception.(check (valen > 0) (INVALID_ARGUMENT s)));
    value_to_elt value.(0)


  (* [unsafe_assign_arr x arr] assigns [arr] to the [Var] node behind [x] and
     then invalidates the nodes reached by [iter_descendants] from it.
     If [x] has no block yet, its shape is first set from [arr] and
     propagated with [infer_shape_graph].
     Unlike [assign_arr], [arr] is not copied before [set_value]. When [x] has
     no block, the node presumably keeps a reference to [arr] itself (this
     depends on [arr_to_value]; TODO confirm).
     Fails with [Failure] when [x] is not a [Var]. *)
  let unsafe_assign_arr x arr =
    let node = arr_to_node x in
    if is_var node
    then (
      if is_assigned node = false
      then (
        (attr node).shape <- [| Some A.(shape arr) |];
        infer_shape_graph [| node |]);
      set_value node [| arr_to_value arr |];
      invalidate_graph node)
    else (
      let info = node_to_str node in
      Printf.sprintf "unsafe_assign_arr: const cannot be assigned, %s" info |> failwith)


  (* [assign_arr x arr] assigns the contents of [arr] to the [Var] node behind
     [x], then invalidates the nodes reached by [iter_descendants] from it.
     - Already assigned: [arr] is copied into the existing value.
     - Not yet assigned: a copy of [arr] becomes the value, and its shape is
       propagated with [infer_shape_graph].
     Fails with [Failure] when [x] is not a [Var]. *)
  let assign_arr x arr =
    let node = arr_to_node x in
    if is_var node
    then (
      if is_assigned node
      then (
        let out = unpack_arr x in
        A.copy_ ~out arr)
      else (
        let dst = A.copy arr in
        set_value node [| arr_to_value dst |];
        (* propagate the shape information *)
        (attr node).shape <- [| Some A.(shape dst) |];
        infer_shape_graph [| node |]);
      invalidate_graph node)
    else (
      let info = node_to_str node in
      Printf.sprintf "assign_arr: const cannot be assigned, %s" info |> failwith)


  (* [assign_elt x elt] stores [elt] in the [Var] node behind [x], then
     invalidates the nodes reached by [iter_descendants] from it. Fails with
     [Failure] when [x] is not a [Var]. *)
  let assign_elt x elt =
    let node = elt_to_node x in
    if is_var node
    then (
      set_value node [| elt_to_value elt |];
      invalidate_graph node)
    else (
      let info = node_to_str node in
      Printf.sprintf "assign_elt: const cannot be assigned, %s" info |> failwith)


  (* [float_to_elt x] wraps the float [x] as an unnamed constant scalar node. *)
  let float_to_elt x = const_elt "" (A.float_to_elt x)


  (* [elt_to_float x] reads the scalar value of [x] back as a float. *)
  let elt_to_float x = unpack_elt x |> A.elt_to_float
end

(* Make functor ends *)