# 1 "src/base/compute/owl_computation_cpu_init.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

open Owl_graph

(* Functor of making initialisor of a CPU-based engine. *)

module Make (Graph : Owl_computation_graph_sig.Sig) = struct
  open Graph.Optimiser.Operator.Symbol
  open Graph.Optimiser.Operator.Symbol.Shape.Type

  (* Multimap with integer keys; used below to index reusable blocks by their
   * size (number of elements). *)
  module MultiMap = Owl_utils_multimap.Make (struct
    type t = int

    let compare : int -> int -> int = compare
  end)

  (* utility functions *)

  (* Each [split_*] helper returns a pair [(overwritable, others)] of parent
   * arrays. *)

  (* cannot overwrite parents *)
  let split_00 p = [||], p

  (* can overwrite parents *)
  let split_01 p = p, [||]

  (* broadcasting nodes can overwrite their parents iff they have the same
   * shape *)
  let split_02 x p =
    let shape_x = node_shape x in
    ( Owl_utils.Array.filter (fun q -> node_shape q = shape_x) p
    , Owl_utils.Array.filter (fun q -> node_shape q <> shape_x) p )


  (* concatenate: can overwrite first parent if axis = 0 *)
  let split_03 p axis =
    if axis = 0
    then [| p.(0) |], Array.sub p 1 (Array.length p - 1)
    else split_00 p


  (* return a partition of the parents in two arrays: the parents that the node
   * can safely overwrite during its computation and the others.
   * Written to be safe, but can probably make it more fine-grained for some
   * operations.
   *
   * The parents are de-duplicated before being partitioned. Operators marked
   * with [(* ? *)] are conservative choices that might be relaxed.
   * Raises [Failure] with the operator name for the operators that are not
   * handled ([Split], [Draw], [Row], [Rows], [CopyRowTo], [CopyColTo],
   * [ToRows], [OfRows]). *)
  let split_parents x =
    let p = Owl_utils.Array.unique (parents x) in
    match get_operator x with
    (* -- every parent may be overwritten -- *)
    | Noop
    | Var
    | Const
    | Empty _
    | Zeros _
    | Ones _
    | Init _
    | Get _
    | Set _
    | Copy
    | Reset
    | Reshape _
    | Map _
    | Delay _
    | DelayArray _
    | LazyPrint _
    | Abs
    | Neg
    | Floor
    | Ceil
    | Round
    | Sqr
    | Sqrt
    | Log
    | Log2
    | Log10
    | Exp
    | Sin
    | Cos
    | Tan
    | Sinh
    | Cosh
    | Tanh
    | Asin
    | Acos
    | Atan
    | Asinh
    | Acosh
    | Atanh
    | Signum
    | Sigmoid
    | Relu
    | Dawsn
    | Min'
    | Max'
    | Sum'
    | LogSumExp'
    | L1norm'
    | L2norm'
    | L2NormSqr'
    | ClipByValue
    | ClipByL2norm
    | ScalarPow
    | PowScalar
    | ScalarAtan2
    | Atan2Scalar
    | AddScalar
    | SubScalar
    | MulScalar
    | DivScalar
    | ScalarAdd
    | ScalarSub
    | ScalarMul
    | ScalarDiv
    | EltEqualScalar
    | EltNotEqualScalar
    | EltLessScalar
    | EltGreaterScalar
    | EltLessEqualScalar
    | EltGreaterEqualScalar
    | RowNum
    | ColNum
    | Trace
    | Scalar_Add
    | Scalar_Sub
    | Scalar_Mul
    | Scalar_Div
    | Scalar_Pow
    | Scalar_Atan2
    | Scalar_Abs
    | Scalar_Neg
    | Scalar_Sqr
    | Scalar_Sqrt
    | Scalar_Exp
    | Scalar_Log
    | Scalar_Log2
    | Scalar_Log10
    | Scalar_Signum
    | Scalar_Floor
    | Scalar_Ceil
    | Scalar_Round
    | Scalar_Sin
    | Scalar_Cos
    | Scalar_Tan
    | Scalar_Sinh
    | Scalar_Cosh
    | Scalar_Tanh
    | Scalar_Asin
    | Scalar_Acos
    | Scalar_Atan
    | Scalar_Asinh
    | Scalar_Acosh
    | Scalar_Atanh
    | Scalar_Relu
    | Scalar_Dawsn
    | Scalar_Sigmoid -> split_01 p
    (* -- no parent may be overwritten -- *)
    | Create _
    | Sequential _
    | Uniform _
    | Gaussian _
    | Bernoulli _
    | GetSlice _ (* ? *)
    | SetSlice _ (* ? *)
    | Reverse
    | Tile _
    | Repeat _ (* ? *)
    | Pad _
    | Fold _ (* ? *)
    | Scan _ (* ? *)
    | OneHot _ (* ? *)
    | OfArray _ (* ? *)
    | Min _ (* ? *)
    | Max _ (* ? *)
    | Sum _ (* ? *)
    | SumReduce _ (* ? *)
    | LogSumExp _
    (* condition on pad, ker and str for conv ops? *)
    | Conv1d _
    | Conv2d _
    | Conv3d _
    | TransposeConv1d _
    | TransposeConv2d _
    | TransposeConv3d _
    | DilatedConv1d _
    | DilatedConv2d _
    | DilatedConv3d _
    (* pool ops? depends on pad? *)
    | MaxPool1d _
    | MaxPool2d _
    | MaxPool3d _
    | AvgPool1d _
    | AvgPool2d _
    | AvgPool3d _
    | UpSampling2d _
    | Conv1dBackwardInput _
    | Conv1dBackwardKernel _
    | Conv2dBackwardInput _
    | Conv2dBackwardKernel _
    | Conv3dBackwardInput _
    | Conv3dBackwardKernel _
    | TransposeConv1dBackwardInput _
    | TransposeConv1dBackwardKernel _
    | TransposeConv2dBackwardInput _
    | TransposeConv2dBackwardKernel _
    | TransposeConv3dBackwardInput _
    | TransposeConv3dBackwardKernel _
    | DilatedConv1dBackwardInput _
    | DilatedConv1dBackwardKernel _
    | DilatedConv2dBackwardInput _
    | DilatedConv2dBackwardKernel _
    | DilatedConv3dBackwardInput _
    | DilatedConv3dBackwardKernel _
    | MaxPool1dBackward _
    | MaxPool2dBackward _
    | MaxPool3dBackward _
    | AvgPool1dBackward _
    | AvgPool2dBackward _
    | AvgPool3dBackward _
    | UpSampling2dBackward _
    | Dot _
    | Inv
    | Transpose _
    | Fused_Adagrad _ (* ? *) -> split_00 p
    (* -- broadcasting ops: only parents with the node's own shape -- *)
    | Pow
    | Atan2
    | Hypot
    | Min2
    | Max2
    | Add
    | Sub
    | Mul
    | Div
    | FMA
    | EltEqual
    | EltNotEqual
    | EltLess
    | EltGreater
    | EltLessEqual
    | EltGreaterEqual -> split_02 x p
    (* -- first parent only, and only along axis 0 -- *)
    | Concatenate axis | Stack axis -> split_03 p axis
    (* -- operators that are not handled -- *)
    | Split _ -> failwith "Split"
    | Draw _ -> failwith "Draw"
    | Row -> failwith "Row"
    | Rows _ -> failwith "Rows"
    | CopyRowTo -> failwith "CopyRowTo"
    | CopyColTo -> failwith "CopyColTo"
    | ToRows -> failwith "ToRows"
    | OfRows -> failwith "OfRows"


  (* Core initialisation function. Inspired by
   * https://mxnet.incubator.apache.org/architecture/note_memory.html.
   *
   * Walks [nodes] and their parents, links every node that is not yet
   * initialised to a block id, then creates the blocks and attaches the nodes
   * to them. *)
  let _init_terms nodes =
    (* hashtable: node -> its number of references left to use *)
    let refs = Hashtbl.create 256 in
    (* number of elements -> id of a reusable block of corresponding size *)
    let reusable = ref MultiMap.empty in
    (* node id -> id of a block that was assigned to it *)
    let node_to_block = Hashtbl.create 256 in
    (* block id -> its size *)
    let block_to_size = Hashtbl.create 16 in
    (* node id -> the corresponding node *)
    let id_to_node = Hashtbl.create 256 in
    (* already has a block or is already associated to a block id during the
     * execution of the algorithm *)
    let is_initialised x = is_assigned x || Hashtbl.mem node_to_block (id x) in
    (* Notifies a node that it has been used by one of its children.
     * If no more children have to use the node, assumes that the memory of the
     * node can be reused by another node.
     * Nodes that are already assigned, or that have no entry in [refs], are
     * left untouched. *)
    let update_parent p =
      let id_p = id p in
      if (not (is_assigned p)) && Hashtbl.mem refs id_p
      then (
        let remaining = Hashtbl.find refs id_p in
        assert (remaining > 0);
        if remaining = 1
        then (
          (* last use: the block of [p] can be reused *)
          Hashtbl.remove refs id_p;
          let block_id = Hashtbl.find node_to_block id_p in
          let block_size = Hashtbl.find block_to_size block_id in
          reusable := MultiMap.add block_size block_id !reusable)
        else Hashtbl.replace refs id_p (remaining - 1))
    in
    (* Heuristic: return the smallest block whose size is at least [numel].
     * If no such block exists, return the biggest one and make it bigger.
     * The chosen block is removed from [reusable].
     * Time complexity: [O(log b)] where [b] is the size of [reusable]. *)
    let best_block_to_reuse numel =
      if MultiMap.is_empty !reusable
      then None
      else (
        let size, b_id =
          match MultiMap.find_first_opt (fun k -> k >= numel) !reusable with
          | Some binding -> binding
          | None -> MultiMap.max_binding !reusable
        in
        reusable := MultiMap.remove size !reusable;
        (* the block was too small: record its enlarged size *)
        if size < numel then Hashtbl.replace block_to_size b_id numel;
        Some b_id)
    in
    (* Links node [x] to a new block. *)
    let allocate_new x =
      let numel_x = node_numel x in
      let b_id = new_block_id () in
      Hashtbl.add node_to_block (id x) b_id;
      Hashtbl.add block_to_size b_id numel_x
    in
    (* Links the node [x] to the best reusable block if such a block exists.
     * Otherwise, links [x] to a new block. *)
    let allocate x =
      match best_block_to_reuse (node_numel x) with
      | Some b_id -> Hashtbl.add node_to_block (id x) b_id
      | None -> allocate_new x
    in
    (* assume the parents of an initialised node are always initialised *)
    let rec init x =
      Owl_log.debug "init %s ..." (node_to_str x);
      if not (is_initialised x)
      then (
        Hashtbl.add id_to_node (id x) x;
        Array.iter init (parents x);
        let pre_par, post_par = split_parents x in
        (* overwritable parents are released before [x] gets its block, so
         * that [x] may pick up one of their blocks *)
        Array.iter update_parent pre_par;
        (* do not bother sharing the memory of single elements *)
        if get_reuse x && not (is_node_elt x)
        then (
          Hashtbl.add refs (id x) (refnum x);
          allocate x)
        else
          (* a node that cannot be reused cannot reuse either *)
          allocate_new x;
        (* the other parents are released only after [x] has its block *)
        Array.iter update_parent post_par)
    in
    (* link all the nodes to a block id and all the blocks to a size *)
    Array.iter init nodes;
    (* create the blocks and initialise the relevant attributes of the nodes *)
    let id_to_block = Hashtbl.create 16 in
    Hashtbl.iter
      (fun x_id b_id ->
        let x = Hashtbl.find id_to_node x_id in
        let block =
          match Hashtbl.find_opt id_to_block b_id with
          | Some block -> block
          | None ->
            (* first node seen for this block id: create the block *)
            let size = Hashtbl.find block_to_size b_id in
            let block = make_empty_block ~block_id:b_id size in
            Hashtbl.add id_to_block b_id block;
            block
        in
        add_node_to_block x block)
      node_to_block


  (* display some statistics about the number of blocks and the number of
   * allocated elements *)
  let init_stats nodes =
    let total_elt = ref 0 in
    let shared_elt = ref 0 in
    let non_shared_elt = ref 0 in
    let total_nodes = ref 0 in
    let reusable_nodes = ref 0 in
    let non_reusable_nodes = ref 0 in
    (* blocks already counted, so that each one is counted only once *)
    let blocks_seen = Hashtbl.create 256 in
    let reusable_blocks = ref 0 in
    let alloc_reusable = ref 0 in
    let update_stats x =
      let numel_x = node_numel x in
      let reuse_x = get_reuse x in
      incr total_nodes;
      total_elt := !total_elt + numel_x;
      if reuse_x
      then (
        incr reusable_nodes;
        shared_elt := !shared_elt + numel_x)
      else (
        incr non_reusable_nodes;
        non_shared_elt := !non_shared_elt + numel_x);
      (* only the first block of the node is taken into account *)
      let block_x = (get_block x).(0) in
      if not (Hashtbl.mem blocks_seen block_x)
      then (
        Hashtbl.add blocks_seen block_x None;
        if reuse_x
        then (
          incr reusable_blocks;
          alloc_reusable := !alloc_reusable + block_x.size))
    in
    Owl_graph.iter_ancestors update_stats nodes;
    let b = Buffer.create 170 in
    Buffer.add_string b "*** INITIALISATION STATISTICS ***\n";
    Printf.bprintf b " %d nodes, %d elements\n" !total_nodes !total_elt;
    Printf.bprintf b " %d reusable nodes, %d elements\n" !reusable_nodes !shared_elt;
    Printf.bprintf
      b
      " %d non-reusable nodes, %d elements\n"
      !non_reusable_nodes
      !non_shared_elt;
    Printf.bprintf
      b
      " %d shared blocks, %d elements\n"
      !reusable_blocks
      !alloc_reusable;
    Printf.bprintf
      b
      " TOTAL NUMBER OF ALLOCATED ELEMENTS: %d\n"
      (!alloc_reusable + !non_shared_elt);
    Owl_log.info "%s" (Buffer.contents b)
end

(* Make functor ends *)