(* src/base/compute/owl_computation_symbol.ml
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

open Owl_graph

(* Functor of making the symbols of a computation graph. *)

module Make (Shape : Owl_computation_shape_sig.Sig) = struct
  module Shape = Shape
  open Shape
  open Shape.Type
  open Shape.Type.Device

  (* string representation of symbols *)

  (* [op_to_str op] returns a short printable name for [op]. Most operators
     map to their constructor name; OneHot, Min, Max, Sum and LogSumExp also
     print their integer/boolean parameters. *)
  let op_to_str = function
    | Noop -> "Noop"
    | Var -> "Var"
    | Const -> "Const"
    | Empty _shape -> "Empty"
    | Zeros _shape -> "Zeros"
    | Ones _shape -> "Ones"
    | Create _shape -> "Create"
    | Sequential _shape -> "Sequential"
    | Uniform _shape -> "Uniform"
    | Gaussian _shape -> "Gaussian"
    | Bernoulli _shape -> "Bernoulli"
    | Init (_shape, _f) -> "Init"
    | Get _i -> "Get"
    | Set _i -> "Set"
    | GetSlice _slice -> "GetSlice"
    | SetSlice _slice -> "SetSlice"
    | GetFancy _ -> "GetFancy"
    | SetFancy _ -> "SetFancy"
    | Copy -> "Copy"
    | Reset -> "Reset"
    | Reshape _shape -> "Reshape"
    | Reverse -> "Reverse"
    | Tile _repeats -> "Tile"
    | Repeat _repeats -> "Repeat"
    | Concatenate _axis -> "Concatenate"
    | Stack _axis -> "Stack"
    | Split (_axis, _parts) -> "Split"
    | Draw (_axis, _n) -> "Draw"
    | Map _f -> "Map"
    | Fold (_axis, _f) -> "Fold"
    | Scan (_axis, _f) -> "Scan"
    | OneHot depth -> Printf.sprintf "OneHot d:%i" depth
    | OfArray _s -> "OfArray"
    | Delay _f -> "Delay"
    | DelayArray (_shape, _f) -> "DelayArray"
    | LazyPrint (_max_row, _max_col, _header, _fmt) -> "LazyPrint"
    | Abs -> "Abs"
    | Neg -> "Neg"
    | Floor -> "Floor"
    | Ceil -> "Ceil"
    | Round -> "Round"
    | Sqr -> "Sqr"
    | Sqrt -> "Sqrt"
    | Log -> "Log"
    | Log2 -> "Log2"
    | Log10 -> "Log10"
    | Exp -> "Exp"
    | Sin -> "Sin"
    | Cos -> "Cos"
    | Tan -> "Tan"
    | Sinh -> "Sinh"
    | Cosh -> "Cosh"
    | Tanh -> "Tanh"
    | Asin -> "Asin"
    | Acos -> "Acos"
    | Atan -> "Atan"
    | Asinh -> "Asinh"
    | Acosh -> "Acosh"
    | Atanh -> "Atanh"
    | Min (keep_dims, axis) ->
      Printf.sprintf "Min keep_dims: %b, axis:%i" keep_dims axis
    | Max (keep_dims, axis) ->
      Printf.sprintf "Max keep_dims: %b, axis:%i" keep_dims axis
    | Sum (keep_dims, axis) ->
      Printf.sprintf "Sum keep_dims: %b, axis:%i" keep_dims axis
    | SumReduce _axis -> "SumReduce"
    | Signum -> "Signum"
    | Sigmoid -> "Sigmoid"
    | Relu -> "Relu"
    | Dawsn -> "Dawsn"
    | Min' -> "Min'"
    | Max' -> "Max'"
    | Sum' -> "Sum'"
    | LogSumExp' -> "LogSumExp'"
    | LogSumExp (keep_dims, axis) ->
      Printf.sprintf "LogSumExp keep_dims: %b, axis:%i" keep_dims axis
    | L1norm' -> "L1norm'"
    | L2norm' -> "L2norm'"
    | L2NormSqr' -> "L2NormSqr'"
    | ClipByValue -> "ClipByValue"
    | ClipByL2norm -> "ClipByL2norm"
    | Pow -> "Pow"
    | ScalarPow -> "ScalarPow"
    | PowScalar -> "PowScalar"
    | Atan2 -> "Atan2"
    | ScalarAtan2 -> "ScalarAtan2"
    | Atan2Scalar -> "Atan2Scalar"
    | Hypot -> "Hypot"
    | Min2 -> "Min2"
    | Max2 -> "Max2"
    | Add -> "Add"
    | Sub -> "Sub"
    | Mul -> "Mul"
    | Div -> "Div"
    | AddScalar -> "AddScalar"
    | SubScalar -> "SubScalar"
    | MulScalar -> "MulScalar"
    | DivScalar -> "DivScalar"
    | ScalarAdd -> "ScalarAdd"
    | ScalarSub -> "ScalarSub"
    | ScalarMul -> "ScalarMul"
    | ScalarDiv -> "ScalarDiv"
    | FMA -> "FMA"
    | EltEqual -> "EltEqual"
    | EltNotEqual -> "EltNotEqual"
    | EltLess -> "EltLess"
    | EltGreater -> "EltGreater"
    | EltLessEqual -> "EltLessEqual"
    | EltGreaterEqual -> "EltGreaterEqual"
    | EltEqualScalar -> "EltEqualScalar"
    | EltNotEqualScalar -> "EltNotEqualScalar"
    | EltLessScalar -> "EltLessScalar"
    | EltGreaterScalar -> "EltGreaterScalar"
    | EltLessEqualScalar -> "EltLessEqualScalar"
    | EltGreaterEqualScalar -> "EltGreaterEqualScalar"
    | Conv1d (_padding, _stride) -> "Conv1d"
    | Conv2d (_padding, _stride) -> "Conv2d"
    | Conv3d (_padding, _stride) -> "Conv3d"
    | TransposeConv1d (_padding, _stride) -> "TransposeConv1d"
    | TransposeConv2d (_padding, _stride) -> "TransposeConv2d"
    | TransposeConv3d (_padding, _stride) -> "TransposeConv3d"
    | DilatedConv1d (_padding, _stride, _rate) -> "DilatedConv1d"
    | DilatedConv2d (_padding, _stride, _rate) -> "DilatedConv2d"
    | DilatedConv3d (_padding, _stride, _rate) -> "DilatedConv3d"
    | MaxPool1d (_padding, _kernel, _stride) -> "MaxPool1d"
    | MaxPool2d (_padding, _kernel, _stride) -> "MaxPool2d"
    | MaxPool3d (_padding, _kernel, _stride) -> "MaxPool3d"
    | AvgPool1d (_padding, _kernel, _stride) -> "AvgPool1d"
    | AvgPool2d (_padding, _kernel, _stride) -> "AvgPool2d"
    | AvgPool3d (_padding, _kernel, _stride) -> "AvgPool3d"
    | UpSampling2d _size -> "UpSampling2d"
    | Conv1dBackwardInput _stride -> "Conv1dBackwardInput"
    | Conv1dBackwardKernel _stride -> "Conv1dBackwardKernel"
    | Conv2dBackwardInput _stride -> "Conv2dBackwardInput"
    | Conv2dBackwardKernel _stride -> "Conv2dBackwardKernel"
    | Conv3dBackwardInput _stride -> "Conv3dBackwardInput"
    | Conv3dBackwardKernel _stride -> "Conv3dBackwardKernel"
    | TransposeConv1dBackwardInput _stride -> "TransposeConv1dBackwardInput"
    | TransposeConv1dBackwardKernel _stride -> "TransposeConv1dBackwardKernel"
    | TransposeConv2dBackwardInput _stride -> "TransposeConv2dBackwardInput"
    | TransposeConv2dBackwardKernel _stride -> "TransposeConv2dBackwardKernel"
    | TransposeConv3dBackwardInput _stride -> "TransposeConv3dBackwardInput"
    | TransposeConv3dBackwardKernel _stride -> "TransposeConv3dBackwardKernel"
    | DilatedConv1dBackwardInput (_stride, _rate) -> "DilatedConv1dBackwardInput"
    | DilatedConv1dBackwardKernel (_stride, _rate) -> "DilatedConv1dBackwardKernel"
    | DilatedConv2dBackwardInput (_stride, _rate) -> "DilatedConv2dBackwardInput"
    | DilatedConv2dBackwardKernel (_stride, _rate) -> "DilatedConv2dBackwardKernel"
    | DilatedConv3dBackwardInput (_stride, _rate) -> "DilatedConv3dBackwardInput"
    | DilatedConv3dBackwardKernel (_stride, _rate) -> "DilatedConv3dBackwardKernel"
    | MaxPool1dBackward (_padding, _kernel, _stride) -> "MaxPool1dBackward"
    | MaxPool2dBackward (_padding, _kernel, _stride) -> "MaxPool2dBackward"
    | MaxPool3dBackward (_padding, _kernel, _stride) -> "MaxPool3dBackward"
    | AvgPool1dBackward (_padding, _kernel, _stride) -> "AvgPool1dBackward"
    | AvgPool2dBackward (_padding, _kernel, _stride) -> "AvgPool2dBackward"
    | AvgPool3dBackward (_padding, _kernel, _stride) -> "AvgPool3dBackward"
    | UpSampling2dBackward _size -> "UpSampling2dBackward"
    | Pad (_v, _padding) -> "Pad"
    | RowNum -> "RowNum"
    | ColNum -> "ColNum"
    | Row -> "Row"
    | Rows _i -> "Rows"
    | CopyRowTo -> "CopyRowTo"
    | CopyColTo -> "CopyColTo"
    | Dot (_transa, _transb, _alpha, _beta) -> "Dot"
    | Inv -> "Inv"
    | Trace -> "Trace"
    | Transpose _i -> "Transpose"
    | ToRows -> "ToRows"
    | OfRows -> "OfRows"
    | Scalar_Add -> "Scalar Add"
    | Scalar_Sub -> "Scalar Sub"
    | Scalar_Mul -> "Scalar Mul"
    | Scalar_Div -> "Scalar Div"
    | Scalar_Pow -> "Scalar Pow"
    | Scalar_Atan2 -> "Scalar Atan2"
    | Scalar_Abs -> "Scalar Abs"
    | Scalar_Neg -> "Scalar Neg"
    | Scalar_Sqr -> "Scalar Sqr"
    | Scalar_Sqrt -> "Scalar Sqrt"
    | Scalar_Exp -> "Scalar Exp"
    | Scalar_Log -> "Scalar Log"
    | Scalar_Log2 -> "Scalar Log2"
    | Scalar_Log10 -> "Scalar Log10"
    | Scalar_Signum -> "Scalar Signum"
    | Scalar_Floor -> "Scalar Floor"
    | Scalar_Ceil -> "Scalar Ceil"
    | Scalar_Round -> "Scalar Round"
    | Scalar_Sin -> "Scalar Sin"
    | Scalar_Cos -> "Scalar Cos"
    | Scalar_Tan -> "Scalar Tan"
    | Scalar_Sinh -> "Scalar Sinh"
    | Scalar_Cosh -> "Scalar Cosh"
    | Scalar_Tanh -> "Scalar Tanh"
    | Scalar_Asin -> "Scalar Asin"
    | Scalar_Acos -> "Scalar Acos"
    | Scalar_Atan -> "Scalar Atan"
    | Scalar_Asinh -> "Scalar Asinh"
    | Scalar_Acosh -> "Scalar Acosh"
    | Scalar_Atanh -> "Scalar Atanh"
    | Scalar_Relu -> "Scalar Relu"
    | Scalar_Dawsn -> "Scalar Dawsn"
    | Scalar_Sigmoid -> "Scalar Sigmoid"
    | Fused_Adagrad (_rate, _eps) -> "Fused_Adagrad"


  (* utility functions *)

  (* [is_random_variable op] is true only for the sampling operators
     Uniform, Gaussian and Bernoulli. *)
  let is_random_variable = function
    | Uniform _shape -> true
    | Gaussian _shape -> true
    | Bernoulli _shape -> true
    | _ -> false


  (* [refnum x] is the out-degree of node [x] in the graph. *)
  let refnum x = Owl_graph.outdegree x


  (* [node_shape x] returns the first entry of the shape array of [x].
     An empty shape array is rejected through [Owl_exception.check] with
     INVALID_ARGUMENT; a shape that is still [None] raises [Failure]. *)
  let node_shape x =
    let x_shape = (attr x).shape in
    let p = Array.length x_shape > 0 in
    let s = "shape information is missing." in
    Owl_exception.(check p (INVALID_ARGUMENT s));
    match x_shape.(0) with
    | Some s -> s
    | None   -> failwith "Owl_computation_symbol:node_shape"


  (* [node_numel x] is the product of the dimensions of [node_shape x]
     (1 for a scalar shape). The multiplication operator is written with
     spaces so that it is not lexed as a comment opener. *)
  let node_numel x = Array.fold_left ( * ) 1 (node_shape x)


  (* NOTE(review): despite its name, this returns true when the first shape
     of [x] is KNOWN (Some) and false when it is None. [infer_shape_graph]
     below relies on exactly this (it tests [= false] to find unknown
     shapes), and callers elsewhere may too, so the semantics are kept;
     consider renaming in a coordinated change. *)
  let is_shape_unknown x =
    let x_shape = (attr x).shape in
    match x_shape.(0) with
    | Some _ -> true
    | None   -> false


  (* [infer_shape_graph xs] walks the graph downstream of [xs] with
     [iter_descendants] and, for every visited node whose first shape is
     still None, fills it in with [infer_shape] applied to the node operator
     and its parents. Nodes that already have a shape are left untouched. *)
  let infer_shape_graph xs =
    iter_descendants
      (fun x ->
        if is_shape_unknown x = false
        then (
          let x_attr = attr x in
          let x_parents = parents x in
          x_attr.shape <- infer_shape x_attr.op x_parents))
      xs


  (* [shape_to_str shp] renders the first shape of [shp] in square brackets,
     or [unknown] in brackets when that shape is None. An empty [shp] is
     rejected through [Owl_exception.check]. *)
  let shape_to_str shp =
    let p = Array.length shp > 0 in
    let s = "shape information is missing." in
    Owl_exception.(check p (INVALID_ARGUMENT s));
    let s =
      match shp.(0) with
      | Some s -> Owl_utils_array.to_string string_of_int s
      | None   -> "unknown"
    in
    Printf.sprintf "[%s]" s


  (* [node_to_str n] is a one-line debug summary of node [n]: id, name,
     operator, validity state, reference count and shape. *)
  let node_to_str n =
    let attr = attr n in
    let shape_s = shape_to_str attr.shape in
    let state_s = if attr.state = Valid then "valid" else "invalid" in
    Printf.sprintf
      "[ #%i | name:%s | op:%s | state:%s | r:%i | s:%s ]"
      (id n)
      (name n)
      (op_to_str attr.op)
      state_s
      (refnum n)
      shape_s


  (* core manipulation functions *)

  (* Wrap / unwrap a graph node as an array or scalar (elt) handle. *)
  let node_to_arr x = Arr x

  let arr_to_node = function
    | Arr x -> x


  let node_to_elt x = Elt x

  let elt_to_node = function
    | Elt x -> x


  (* [new_block_id ()] returns a fresh id from a counter private to this
     closure; the first call returns 1 and each call adds 1. *)
  let new_block_id =
    let _global_block_id = ref 0 in
    fun () ->
      _global_block_id := !_global_block_id + 1;
      !_global_block_id


  (* Meant for reusable nodes.
     [make_empty_block ?block_id size] allocates a one-dimensional array of
     [size] elements with [A.empty] and returns a block holding it, with no
     active node and no attached nodes. A fresh id is drawn from
     [new_block_id] unless [block_id] is given. *)
  let make_empty_block ?block_id size =
    let block_id =
      match block_id with
      | Some block_id -> block_id
      | None          -> new_block_id ()
    in
    (* allocate a one-dimensional array *)
    let memory = arr_to_value (A.empty [| size |]) in
    { size; block_id; active = None; memory; nodes = [] }


  (* This is meant for nodes that are not reusable: memory is not reshaped.
     [make_value_block memory x] wraps [memory] in a new block whose only
     (and active) node is [x], then installs it as the value and block of
     [x]. The block size is 1 for a scalar value, otherwise the element
     count of the array. *)
  let make_value_block memory x =
    let block_id = new_block_id () in
    let size = if is_elt memory then 1 else A.numel (value_to_arr memory) in
    let block = { size; block_id; active = Some x; memory; nodes = [ x ] } in
    (attr x).value <- [| memory |];
    (attr x).block <- Some [| block |]


  (* [make_node ?name ?value ?shape ?freeze ?reuse ?state op] creates a new
     graph node for [op]. Defaults when an argument is omitted:
     shape [[|None|]], state Invalid, reuse true, freeze false, no value.
     When a value is supplied, its first element is placed in a fresh block
     owned by the node (see [make_value_block]). *)
  let make_node ?name ?value ?shape ?freeze ?reuse ?state op =
    let shape =
      match shape with
      | Some s -> s
      | None   -> [| None |]
    in
    let state =
      match state with
      | Some s -> s
      | None   -> Invalid
    in
    let reuse =
      match reuse with
      | Some s -> s
      | None   -> true
    in
    let freeze =
      match freeze with
      | Some s -> s
      | None   -> false
    in
    let value =
      match value with
      | Some v -> v
      | None   -> [||]
    in
    let attr = { op; freeze; reuse; state; shape; value; block = None } in
    let node = Owl_graph.node ?name attr in
    if value <> [||] then make_value_block value.(0) node;
    node


  (* [make_then_connect ?shape op parents] creates a child node for [op]
     (inferring its shape from [op] and [parents] when [shape] is absent)
     and wires it into the graph. Ancestor links keep duplicate parents;
     descendant links are added once per distinct parent and skipped for
     frozen parents. *)
  let make_then_connect ?shape op parents =
    let shape =
      match shape with
      | Some s -> s
      | None   -> infer_shape op parents
    in
    let child = make_node ~shape op in
    (* define the dependency of operation, can have duplicates *)
    connect_ancestors parents [| child |];
    (* define the flow of computation graph, no duplicates *)
    let uniq_parents = Owl_utils_array.unique parents in
    Array.iter
      (fun parent ->
        if (attr parent).freeze = false
        then connect_descendants [| parent |] [| child |])
      uniq_parents;
    child


  (* Variables are created non-reusable; an array variable may have an
     unknown (None) shape until it is assigned. *)
  let var_arr ?shape name =
    make_node ~name ~shape:[| shape |] ~reuse:false Var |> node_to_arr


  let var_elt name =
    make_node ~name ~shape:[| Some [||] |] ~reuse:false Var |> node_to_elt


  (* Constants are created frozen, non-reusable and already Valid, with
     their value and shape taken from [v]. *)
  let const_arr name v =
    let value = [| arr_to_value v |] in
    let shape = [| Some A.(shape v) |] in
    make_node ~name ~value ~shape ~freeze:true ~reuse:false ~state:Valid Const
    |> node_to_arr


  let const_elt name v =
    let value = [| elt_to_value v |] in
    let shape = [| Some [||] |] in
    make_node ~name ~value ~shape ~freeze:true ~reuse:false ~state:Valid Const
    |> node_to_elt


  (* Block accessors. *)
  let get_nodes_using_block b = b.nodes

  let _get_value_block b = b.memory

  let get_block_opt x = (attr x).block

  (* [get_block x] raises [Failure] when no block is assigned to [x]. *)
  let get_block x =
    match get_block_opt x with
    | Some b -> b
    | None   -> failwith "Symbol:get_block_exn: block not assigned"


  let _set_block x b = (attr x).block <- Some b


  (* [add_node_to_block x block] attaches [x] to [block]: the value of [x]
     becomes the first [node_numel x] elements of the block memory,
     reshaped to the shape of [x], and [x] is prepended to the block nodes.
     Whether the slice shares storage with the block depends on
     [A.sub_left] / [A.reshape] — presumably a view; TODO confirm. *)
  let add_node_to_block x block =
    let dst_shp = node_shape x in
    let dst_numel = node_numel x in
    let src_val = value_to_arr (_get_value_block block) in
    (* allocate the first [dst_numel] elements for the memory of the node *)
    let dst_val = arr_to_value (A.reshape (A.sub_left src_val 0 dst_numel) dst_shp) in
    block.nodes <- x :: block.nodes;
    _set_block x [| block |];
    (attr x).value <- [| dst_val |]


  let get_active_node b = b.active

  let set_active_node b x = b.active <- Some x

  (* [get_block_id x] is the id of the first block of [x], or -1 if none. *)
  let get_block_id x =
    match get_block_opt x with
    | Some bs -> bs.(0).block_id
    | None    -> -1


  (* [set_value x v] stores [v.(0)] as the value of [x].
     - Array value, block already assigned: copied into the existing buffer
       of [x] with [A.copy_].
     - Scalar value, block already assigned: the value array of [x] and the
       memory of its first block are both replaced.
     - No block yet: a new value block is created around [v.(0)]. *)
  let set_value x v =
    if is_arr v.(0)
    then (
      match get_block_opt x with
      | Some _ ->
        let xv = value_to_arr (attr x).value.(0) in
        let vv = value_to_arr v.(0) in
        A.copy_ ~out:xv vv
      | None   -> make_value_block v.(0) x)
    else (
      match get_block_opt x with
      | Some bs ->
        (attr x).value <- v;
        bs.(0).memory <- v.(0)
      | None    -> make_value_block v.(0) x)


  let get_value x = (attr x).value

  let set_operator x op = (attr x).op <- op

  let get_operator x = (attr x).op


  (* [set_reuse x reuse] sets the reuse flag of [x]. Var and Const nodes are
     created with reuse:false; a request to change them is ignored with a
     warning. An or-pattern is used because a single operator can never be
     equal to both Var and Const, so a conjunction of the two tests would
     never hold and the guard would be dead. Pattern matching also avoids
     structural equality on operators that carry closures. *)
  let set_reuse x reuse =
    match (attr x).op with
    | Var | Const -> Owl_log.warn "set_reuse: ignored, %s" (node_to_str x)
    | _           -> (attr x).reuse <- reuse


  let get_reuse x = (attr x).reuse


  (* [is_shared x] is true when the first block of [x] is used by at least
     two nodes. *)
  let is_shared x =
    match get_block_opt x with
    | Some bs ->
      (match get_nodes_using_block bs.(0) with
      | _ :: _ :: _ -> true (* at least 2 elements *)
      | _           -> false)
    | None    -> false


  (* contains itself:
     [get_shared_nodes x] lists every node using the first block of [x]
     (which includes [x] once it is attached), or just [x] when [x] has no
     block. *)
  let get_shared_nodes x =
    match get_block_opt x with
    | Some bs -> Array.of_list (get_nodes_using_block bs.(0))
    | None    -> [| x |]


  let is_var x = (attr x).op = Var

  let is_const x = (attr x).op = Const


  (* TODO: change it to rely on the operator.
     Array vs scalar is decided from the first shape: [[||]] means scalar.
     An unknown (None) shape raises [Failure]. *)
  let is_node_arr x =
    match (attr x).shape.(0) with
    | Some [||] -> false
    | Some _    -> true
    | _         -> failwith "Owl_computation_symbol:is_arr"


  let is_node_elt x =
    match (attr x).shape.(0) with
    | Some [||] -> true
    | Some _    -> false
    | _         -> failwith "Owl_computation_symbol:is_elt"


  (* A node is assigned once it has a block. *)
  let is_assigned x =
    match get_block_opt x with
    | Some _ -> true
    | None   -> false


  (* [check_assigned x] logs an error and raises [Failure] if [x] has no
     block. *)
  let check_assigned x =
    if not (is_assigned x)
    then (
      Owl_log.error "value not assigned: %s" (node_to_str x);
      failwith "owl_computation_symbol:check_assigned")


  (* Validity state. [invalidate_graph x] marks [x] and the nodes reached
     from it by [iter_descendants] as Invalid. *)
  let is_valid x = (attr x).state = Valid

  let validate x = (attr x).state <- Valid

  let invalidate x = (attr x).state <- Invalid

  let invalidate_graph x = iter_descendants invalidate [| x |]


  (* Freezing. The [_descendants] / [_ancestors] variants take an array of
     nodes and freeze everything visited by the corresponding iterator. *)
  let is_freeze x = (attr x).freeze

  let freeze x = (attr x).freeze <- true

  let freeze_descendants x = iter_descendants freeze x

  let freeze_ancestors x = iter_ancestors freeze x


  (* [pack_arr arr] wraps a concrete array as an unnamed constant node. *)
  let pack_arr arr = const_arr "" arr


  (* [unpack_arr x] returns the first value of [x] as an array. If [x] has
     no value yet, an error is logged and INVALID_ARGUMENT is raised through
     [Owl_exception.check] with the node description as message. *)
  let unpack_arr x =
    let value = arr_to_node x |> get_value in
    let valen = Array.length value in
    if valen = 0
    then (
      let s = arr_to_node x |> node_to_str in
      Owl_log.error "not evaluated: %s" s;
      Owl_exception.(check (valen > 0) (INVALID_ARGUMENT s)));
    value_to_arr value.(0)


  (* [pack_elt elt] wraps a concrete scalar as an unnamed constant node. *)
  let pack_elt elt = const_elt "" elt


  (* [unpack_elt x] is the scalar counterpart of [unpack_arr]. *)
  let unpack_elt x =
    let value = elt_to_node x |> get_value in
    let valen = Array.length value in
    if valen = 0
    then (
      let s = elt_to_node x |> node_to_str in
      Owl_log.error "not evaluated: %s" s;
      Owl_exception.(check (valen > 0) (INVALID_ARGUMENT s)));
    value_to_elt value.(0)


  (* [unsafe_assign_arr x arr] assigns [arr] to variable [x] and invalidates
     everything downstream. When [x] has no block yet, its shape is set from
     [arr] and propagated, and [arr] is stored without an explicit copy, so
     [x] may alias the caller array (assuming [arr_to_value] does not copy;
     TODO confirm). When [x] already has a block, [arr] is copied into it.
     Raises [Failure] if [x] is not a variable. *)
  let unsafe_assign_arr x arr =
    let node = arr_to_node x in
    if is_var node
    then (
      if is_assigned node = false
      then (
        (attr node).shape <- [| Some A.(shape arr) |];
        infer_shape_graph [| node |]);
      set_value node [| arr_to_value arr |];
      invalidate_graph node)
    else (
      let info = node_to_str node in
      Printf.sprintf "unsafe_assign_arr: const cannot be assigned, %s" info |> failwith)


  (* [assign_arr x arr] assigns [arr] to variable [x] and invalidates
     everything downstream. If [x] already holds a value, [arr] is copied
     into that buffer in place; otherwise a copy of [arr] becomes the value
     of [x] and its shape is propagated through the graph.
     Raises [Failure] if [x] is not a variable. *)
  let assign_arr x arr =
    let node = arr_to_node x in
    if is_var node
    then (
      if is_assigned node
      then (
        let out = unpack_arr x in
        A.copy_ ~out arr)
      else (
        let dst = A.copy arr in
        set_value node [| arr_to_value dst |];
        (* propagate the shape information *)
        (attr node).shape <- [| Some A.(shape dst) |];
        infer_shape_graph [| node |]);
      invalidate_graph node)
    else (
      let info = node_to_str node in
      Printf.sprintf "assign_arr: const cannot be assigned, %s" info |> failwith)


  (* [assign_elt x elt] stores scalar [elt] in variable [x] and invalidates
     everything downstream. Raises [Failure] if [x] is not a variable. *)
  let assign_elt x elt =
    let node = elt_to_node x in
    if is_var node
    then (
      set_value node [| elt_to_value elt |];
      invalidate_graph node)
    else (
      let info = node_to_str node in
      Printf.sprintf "assign_elt: const cannot be assigned, %s" info |> failwith)


  (* Conversions between OCaml floats and scalar constant nodes. *)
  let float_to_elt x = const_elt "" (A.float_to_elt x)

  let elt_to_float x = unpack_elt x |> A.elt_to_float
end

(* Make functor ends *)