(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
*)
(** Neural network: Graphical neural network *)

open Owl_types

(* Make functor starts *)

module Make (Neuron : Owl_neural_neuron_sig.Sig) = struct
  module Neuron = Neuron
  open Neuron
  open Neuron.Optimise.Algodiff

  (* graph network and node definition *)

  (* NOTE: [save]/[load] marshal these records, so the field order below is
     part of the serialised format and must not be changed. *)
  type node =
    { mutable name : string (* name of a node *)
    ; mutable prev : node array (* parents of a node *)
    ; mutable next : node array (* children of a node *)
    ; mutable neuron : neuron (* neuron contained in a node *)
    ; mutable output : t option (* output of a node *)
    ; mutable network : network (* network a node belongs to *)
    ; mutable train : bool (* specify if a node is only for training *)
    }

  and network =
    { mutable nnid : string (* name of the graph network *)
    ; mutable size : int (* size of the graph network *)
    ; mutable roots : node array (* roots of the graph network, i.e. inputs *)
    ; mutable outputs : node array (* outputs of the graph network *)
    ; mutable topo : node array (* nodes sorted in topological order *)
    }

  (* functions to manipulate the network *)

  (* Build a network record. [nnid] defaults to "Graphical network"; the
     outputs start empty and are set later by [get_network] or [outputs]. *)
  let make_network ?nnid size roots topo =
    let nnid =
      match nnid with
      | Some s -> s
      | None -> "Graphical network"
    in
    { nnid; size; roots; topo; outputs = [||] }

  (* Build a node. Without [name] the node is called "<neuron name>_<size>",
     where <size> is the current [network.size]. Callers creating several
     nodes must therefore add each one to the network (which bumps [size])
     before making the next, otherwise the default names collide. *)
  let make_node ?name ?(train = false) prev next neuron output network =
    let name =
      match name with
      | Some s -> s
      | None -> Printf.sprintf "%s_%i" (to_name neuron) network.size
    in
    { name; prev; next; neuron; output; network; train }

  (* Root (input) nodes of [nn]; fails if there are none. *)
  let get_roots nn =
    match nn.roots with
    | [||] -> failwith "Owl_neural_graph:get_roots"
    | x -> x

  let get_outputs nn = nn.outputs

  (* First node of [nn.topo] named [name]; fails if there is none. The scan
     stops at the first match instead of filtering the whole array. *)
  let get_node nn name =
    let topo = nn.topo in
    let rec find i =
      if i >= Array.length topo
      then failwith "Owl_neural_graph:get_node"
      else if topo.(i).name = name
      then topo.(i)
      else find (i + 1)
    in
    find 0

  (* Return the network [n] belongs to, after renaming it to [name] (a random
     integer string by default) and making [n] its single output. *)
  let get_network ?name n =
    let name =
      match name with
      | Some s -> s
      | None -> Random.int 65535 |> string_of_int
    in
    n.network.nnid <- name;
    (* if run and run_inputs are merged, the next line is necessary. *)
    n.network.outputs <- [| n |];
    n.network

  (* Like [get_network] but for several output [nodes] (must be non-empty);
     the network is taken from the first node. *)
  let outputs ?name nodes =
    assert (Array.length nodes > 0);
    let name =
      match name with
      | Some s -> s
      | None -> Random.int 65535 |> string_of_int
    in
    (* assumes that all the outputs are part of the same network *)
    let nn = nodes.(0).network in
    nn.nnid <- name;
    nn.outputs <- nodes;
    nn

  let get_network_name n = n.nnid

  let set_network_name n name = n.nnid <- name

  (* input shape of the first root *)
  let input_shape n = (get_roots n).(0).neuron |> Neuron.get_in_shape

  (* input shapes of all roots, in root order *)
  let input_shapes n = Array.map (fun r -> r.neuron |> Neuron.get_in_shape) (get_roots n)

  (* collect the outputs of a given set of nodes; fails if a node has no
     cached output yet *)
  let collect_output nodes =
    Array.map
      (fun n ->
        match n.output with
        | Some o -> o
        | None -> failwith "Owl_neural_graph:collect_output")
      nodes

  (* Link [prev] -> [next] in both directions unless they are already linked.
     Membership is tested by physical identity: nodes are cyclic records
     (parents, children and the owning network all point back at each other)
     and a neuron may hold a closure (e.g. Lambda), so structural equality
     could traverse deeply, raise on functional values, or fail to terminate
     when two distinct nodes share a name. *)
  let connect_pair prev next =
    if not (Array.memq prev next.prev) then next.prev <- Array.append next.prev [| prev |];
    if not (Array.memq next prev.next) then prev.next <- Array.append prev.next [| next |]

  (* Infer [child]'s shapes from its parents' output shapes (in order), then
     link every parent to [child]. *)
  let connect_to_parents parents child =
    (* update the child's input and output shape *)
    if Array.length parents > 0
    then (
      let out_shapes = Array.map (fun n -> n.neuron |> get_out_shape) parents in
      connect out_shapes child.neuron);
    (* connect the child to the parents *)
    Array.iter (fun p -> connect_pair p child) parents

  (* Append [child] to [nn] and connect it to [parents]. When [act_typ] is
     given, an activation node is appended after [child] and that node is
     returned instead of [child]. *)
  let rec add_node ?act_typ nn parents child =
    nn.size <- nn.size + 1;
    connect_to_parents parents child;
    nn.topo <- Array.append nn.topo [| child |];
    child.network <- nn;
    (* if activation is specified, recursively add_node *)
    match act_typ with
    | Some act ->
      let neuron = Activation (Activation.create act) in
      let child_of_child = make_node [||] [||] neuron None nn in
      add_node nn [| child |] child_of_child
    | None -> child

  (* functions to interface to optimisation engine *)

  (* Each of these lifts the neuron-level function of the same name over
     [nn.topo]. They are deliberately non-recursive [let]s, so the inner call
     refers to the neuron-level function, not to the one being defined. *)

  let init nn = Array.iter (fun n -> init n.neuron) nn.topo

  let reset nn = Array.iter (fun n -> reset n.neuron) nn.topo

  let mktag t nn = Array.iter (fun n -> mktag t n.neuron) nn.topo

  let mkpar nn = Array.map (fun n -> mkpar n.neuron) nn.topo

  let mkpri nn = Array.map (fun n -> mkpri n.neuron) nn.topo

  let mkadj nn = Array.map (fun n -> mkadj n.neuron) nn.topo

  let update nn us = Array.iter2 (fun n u -> update n.neuron u) nn.topo us

  (* Forward pass with one input per root: an Input node named like the i-th
     root receives [inputs.(i)]. Every node caches its result in [output];
     the outputs of [nn.outputs] are returned in order. *)
  let run_inputs inputs nn =
    let roots = get_roots nn in
    assert (Array.length inputs = Array.length roots);
    (* the root names do not change during the pass: compute them once *)
    let root_names = Array.map (fun r -> r.name) roots in
    Array.iter
      (fun n ->
        (* collect the inputs from parents' output *)
        let input =
          match n.neuron with
          | Input _ ->
            let index = Owl_utils.Array.index_of root_names n.name in
            [| inputs.(index) |]
          | _ -> collect_output n.prev
        in
        (* process the current neuron, save output *)
        let output = run input n.neuron in
        n.output <- Some output)
      nn.topo;
    (* collect the final outputs *)
    collect_output nn.outputs

  (* Forward pass feeding [x] to every Input node; returns the output of the
     last node in topological order. *)
  let run x nn =
    Array.iter
      (fun n ->
        (* collect the inputs from parents' output *)
        let input =
          match n.neuron with
          | Input _ -> [| x |]
          | _ -> collect_output n.prev
        in
        (* process the current neuron, save output *)
        let output = run input n.neuron in
        n.output <- Some output)
      nn.topo;
    (* collect the final output from the tail *)
    let sink = [| nn.topo.(Array.length nn.topo - 1) |] in
    (collect_output sink).(0)

  (* Tag the parameters with a fresh AD tag, then run the network; returns the
     output together with the parameters. *)
  let forward nn x =
    mktag (tag ()) nn;
    run x nn, mkpar nn

  (* [forward] for networks with several inputs/outputs. *)
  let forward_inputs nn x =
    mktag (tag ()) nn;
    run_inputs x nn, mkpar nn

  (* Back-propagate from [y]; returns the primal values and the adjoints of
     the parameters. *)
  let backward nn y =
    reverse_prop (_f 1.) y;
    mkpri nn, mkadj nn

  (* Deep copy: neurons are copied, the node graph is rebuilt with the same
     names and shapes are re-inferred. *)
  let copy nn =
    let nn' = make_network ~nnid:nn.nnid nn.size [||] [||] in
    (* first iteration to copy the neurons *)
    nn'.topo <-
      Array.map
        (fun node ->
          let neuron' = copy node.neuron in
          make_node ~name:node.name ~train:node.train [||] [||] neuron' None nn')
        nn.topo;
    (* index the copies by name once instead of scanning [nn'.topo] for every
       lookup; the first node with a given name wins, as in [get_node] *)
    let by_name = Hashtbl.create (Array.length nn'.topo) in
    Array.iter
      (fun n -> if not (Hashtbl.mem by_name n.name) then Hashtbl.add by_name n.name n)
      nn'.topo;
    let find n =
      try Hashtbl.find by_name n.name with
      | Not_found -> failwith "Owl_neural_graph:get_node"
    in
    (* second iteration to re-construct the structure and infer the shape *)
    Array.iter2
      (fun node node' ->
        node'.prev <- Array.map find node.prev;
        node'.next <- Array.map find node.next;
        connect_to_parents node'.prev node')
      nn.topo
      nn'.topo;
    (* set roots and outputs to finalise the structure *)
    nn'.roots <- Array.map find (get_roots nn);
    nn'.outputs <- Array.map find (get_outputs nn);
    nn'

  (* Remove every node flagged [train] from [nn] in place, wiring its parents
     directly to its children. *)
  let _remove_training_nodes nn =
    (* replace each occurrence of node [n] in [arr] by the nodes of [repl] at
       the same position, keeping the order of the other elements *)
    let splice n repl arr =
      Array.map (fun x -> if x == n then repl else [| x |]) arr
      |> Array.to_list
      |> Array.concat
    in
    (* drop physical duplicates, keeping the first occurrence; like
       [connect_pair], the graph never links the same pair twice *)
    let dedup arr =
      Array.fold_left
        (fun acc x -> if Array.memq x acc then acc else Array.append acc [| x |])
        [||]
        arr
    in
    let topo' =
      Owl_utils.Array.filter
        (fun n ->
          if n.train
          then (
            (* my parents now point to my children, at my former position *)
            Array.iter (fun p -> p.next <- dedup (splice n n.next p.next)) n.prev;
            (* my children now read from my parents at my former position, so
               multi-input neurons keep their input order; then re-infer
               their shapes *)
            Array.iter
              (fun c ->
                c.prev <- dedup (splice n n.prev c.prev);
                connect_to_parents c.prev c)
              n.next;
            (* an output that disappears is replaced by its parent(s) *)
            nn.outputs <- splice n n.prev nn.outputs);
          not n.train)
        nn.topo
    in
    nn.topo <- topo'

  (* Inference function over a copy of [nn] without training-only nodes; [nn]
     itself is left untouched. *)
  let model nn =
    let nn = copy nn in
    _remove_training_nodes nn;
    let inference x =
      match run (Arr x) nn with
      | Arr y -> y
      | _ -> failwith "Owl_neural_graph:model"
    in
    inference

  (* [model] for networks with several inputs/outputs. *)
  let model_inputs nn =
    let nn = copy nn in
    _remove_training_nodes nn;
    let inference inputs =
      let outputs = run_inputs (Array.map (fun x -> Arr x) inputs) nn in
      Array.map unpack_arr outputs
    in
    inference

  (* functions to create functional nodes *)

  (* Wrap [neuron] in a new node fed by the single parent [input_node], add it
     to the parent's network (followed by an activation node when [act_typ]
     is given) and return the last node added. [train] marks nodes that are
     only used during training. *)
  let _add_unary ?name ?act_typ ?(train = false) neuron input_node =
    let nn = get_network input_node in
    let n = make_node ?name ~train [||] [||] neuron None nn in
    add_node ?act_typ nn [| input_node |] n

  (* Same as [_add_unary] for a neuron fed by several parents, in the given
     order; the network is taken from the first parent. *)
  let _add_nary ?name ?act_typ neuron input_nodes =
    let nn = get_network input_nodes.(0) in
    let n = make_node ?name [||] [||] neuron None nn in
    add_node ?act_typ nn input_nodes n

  (* Start a new network with a single root of the given input shape. *)
  let input ?name inputs =
    let neuron = Input (Input.create inputs) in
    let nn = make_network 0 [||] [||] in
    let n = make_node ?name [||] [||] neuron None nn in
    nn.roots <- [| n |];
    add_node nn [||] n

  (* Start a new network with one root per input shape; [names], if given,
     must have the same length as [input_shapes]. *)
  let inputs ?names input_shapes =
    let names =
      match names with
      | Some x ->
        assert (Array.(length x = length input_shapes));
        Array.map (fun name -> Some name) x
      | None -> Array.(make (length input_shapes) None)
    in
    let neurons = Array.map (fun s -> Input (Input.create s)) input_shapes in
    let nn = make_network 0 [||] [||] in
    (* add each node right after creating it: [add_node] bumps [nn.size],
       which [make_node] uses to give unnamed roots distinct default names
       ([Array.init] applies its function in index order) *)
    let ns =
      Array.init (Array.length neurons) (fun i ->
          let n = make_node ?name:(names.(i)) [||] [||] neurons.(i) None nn in
          add_node nn [||] n)
    in
    nn.roots <- ns;
    ns

  (* Single-input layers: each builds its neuron and attaches it after
     [input_node], returning the new tail node. *)

  let activation ?name act_typ input_node =
    let neuron = Activation (Activation.create act_typ) in
    _add_unary ?name neuron input_node

  let linear ?name ?(init_typ = Init.Standard) ?act_typ outputs input_node =
    let neuron = Linear (Linear.create outputs init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let linear_nobias ?name ?(init_typ = Init.Standard) ?act_typ outputs input_node =
    let neuron = LinearNoBias (LinearNoBias.create outputs init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let embedding ?name ?(init_typ = Init.Standard) ?act_typ in_dim out_dim input_node =
    let neuron = Embedding (Embedding.create in_dim out_dim init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let recurrent ?name ?(init_typ = Init.Standard) ~act_typ outputs hiddens input_node =
    let neuron = Recurrent (Recurrent.create hiddens outputs act_typ init_typ) in
    _add_unary ?name neuron input_node

  let lstm ?name ?(init_typ = Init.Tanh) cells input_node =
    let neuron = LSTM (LSTM.create cells init_typ) in
    _add_unary ?name neuron input_node

  let gru ?name ?(init_typ = Init.Tanh) cells input_node =
    let neuron = GRU (GRU.create cells init_typ) in
    _add_unary ?name neuron input_node

  let conv1d ?name ?(padding = SAME) ?(init_typ = Init.Tanh) ?act_typ kernel stride input_node =
    let neuron = Conv1D (Conv1D.create padding kernel stride init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let conv2d ?name ?(padding = SAME) ?(init_typ = Init.Tanh) ?act_typ kernel stride input_node =
    let neuron = Conv2D (Conv2D.create padding kernel stride init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let conv3d ?name ?(padding = SAME) ?(init_typ = Init.Tanh) ?act_typ kernel stride input_node =
    let neuron = Conv3D (Conv3D.create padding kernel stride init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let dilated_conv1d
      ?name
      ?(padding = SAME)
      ?(init_typ = Init.Tanh)
      ?act_typ
      kernel
      stride
      rate
      input_node
    =
    let neuron = DilatedConv1D (DilatedConv1D.create padding kernel stride rate init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let dilated_conv2d
      ?name
      ?(padding = SAME)
      ?(init_typ = Init.Tanh)
      ?act_typ
      kernel
      stride
      rate
      input_node
    =
    let neuron = DilatedConv2D (DilatedConv2D.create padding kernel stride rate init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let dilated_conv3d
      ?name
      ?(padding = SAME)
      ?(init_typ = Init.Tanh)
      ?act_typ
      kernel
      stride
      rate
      input_node
    =
    let neuron = DilatedConv3D (DilatedConv3D.create padding kernel stride rate init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let transpose_conv1d
      ?name
      ?(padding = SAME)
      ?(init_typ = Init.Tanh)
      ?act_typ
      kernel
      stride
      input_node
    =
    let neuron = TransposeConv1D (TransposeConv1D.create padding kernel stride init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let transpose_conv2d
      ?name
      ?(padding = SAME)
      ?(init_typ = Init.Tanh)
      ?act_typ
      kernel
      stride
      input_node
    =
    let neuron = TransposeConv2D (TransposeConv2D.create padding kernel stride init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let transpose_conv3d
      ?name
      ?(padding = SAME)
      ?(init_typ = Init.Tanh)
      ?act_typ
      kernel
      stride
      input_node
    =
    let neuron = TransposeConv3D (TransposeConv3D.create padding kernel stride init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let fully_connected ?name ?(init_typ = Init.Standard) ?act_typ outputs input_node =
    let neuron = FullyConnected (FullyConnected.create outputs init_typ) in
    _add_unary ?name ?act_typ neuron input_node

  let max_pool1d ?name ?(padding = SAME) ?act_typ kernel stride input_node =
    let neuron = MaxPool1D (MaxPool1D.create padding kernel stride) in
    _add_unary ?name ?act_typ neuron input_node

  let max_pool2d ?name ?(padding = SAME) ?act_typ kernel stride input_node =
    let neuron = MaxPool2D (MaxPool2D.create padding kernel stride) in
    _add_unary ?name ?act_typ neuron input_node

  let avg_pool1d ?name ?(padding = SAME) ?act_typ kernel stride input_node =
    let neuron = AvgPool1D (AvgPool1D.create padding kernel stride) in
    _add_unary ?name ?act_typ neuron input_node

  let avg_pool2d ?name ?(padding = SAME) ?act_typ kernel stride input_node =
    let neuron = AvgPool2D (AvgPool2D.create padding kernel stride) in
    _add_unary ?name ?act_typ neuron input_node

  let global_max_pool1d ?name ?act_typ input_node =
    let neuron = GlobalMaxPool1D (GlobalMaxPool1D.create ()) in
    _add_unary ?name ?act_typ neuron input_node

  let global_max_pool2d ?name ?act_typ input_node =
    let neuron = GlobalMaxPool2D (GlobalMaxPool2D.create ()) in
    _add_unary ?name ?act_typ neuron input_node

  let global_avg_pool1d ?name ?act_typ input_node =
    let neuron = GlobalAvgPool1D (GlobalAvgPool1D.create ()) in
    _add_unary ?name ?act_typ neuron input_node

  let global_avg_pool2d ?name ?act_typ input_node =
    let neuron = GlobalAvgPool2D (GlobalAvgPool2D.create ()) in
    _add_unary ?name ?act_typ neuron input_node

  let upsampling2d ?name ?act_typ size input_node =
    let neuron = UpSampling2D (UpSampling2D.create size) in
    _add_unary ?name ?act_typ neuron input_node

  let padding2d ?name ?act_typ padding input_node =
    let neuron = Padding2D (Padding2D.create padding) in
    _add_unary ?name ?act_typ neuron input_node

  (* training-only nodes: flagged [train] and dropped by [model] *)

  let dropout ?name rate input_node =
    let neuron = Dropout (Dropout.create rate) in
    _add_unary ?name ~train:true neuron input_node

  let gaussian_noise ?name sigma input_node =
    let neuron = GaussianNoise (GaussianNoise.create sigma) in
    _add_unary ?name ~train:true neuron input_node

  let gaussian_dropout ?name rate input_node =
    let neuron = GaussianDropout (GaussianDropout.create rate) in
    _add_unary ?name ~train:true neuron input_node

  let alpha_dropout ?name rate input_node =
    let neuron = AlphaDropout (AlphaDropout.create rate) in
    _add_unary ?name ~train:true neuron input_node

  let normalisation ?name ?(axis = -1) ?training ?decay ?mu ?var input_node =
    let neuron = Normalisation (Normalisation.create ?training ?decay ?mu ?var axis) in
    _add_unary ?name neuron input_node

  let reshape ?name outputs input_node =
    let neuron = Reshape (Reshape.create outputs) in
    _add_unary ?name neuron input_node

  let flatten ?name input_node =
    let neuron = Flatten (Flatten.create ()) in
    _add_unary ?name neuron input_node

  let lambda ?name ?act_typ lambda input_node =
    let neuron = Lambda (Lambda.create lambda) in
    _add_unary ?name ?act_typ neuron input_node

  (* Multi-input layers: [input_node] is an array of parents, used in order. *)

  let lambda_array ?name ?act_typ out_shape lambda input_node =
    let neuron = LambdaArray (LambdaArray.create out_shape lambda) in
    _add_nary ?name ?act_typ neuron input_node

  let add ?name ?act_typ input_node =
    let neuron = Add (Add.create ()) in
    _add_nary ?name ?act_typ neuron input_node

  let mul ?name ?act_typ input_node =
    let neuron = Mul (Mul.create ()) in
    _add_nary ?name ?act_typ neuron input_node

  let dot ?name ?act_typ input_node =
    let neuron = Dot (Dot.create ()) in
    _add_nary ?name ?act_typ neuron input_node

  let max ?name ?act_typ input_node =
    let neuron = Max (Max.create ()) in
    _add_nary ?name ?act_typ neuron input_node

  let average ?name ?act_typ input_node =
    let neuron = Average (Average.create ()) in
    _add_nary ?name ?act_typ neuron input_node

  let concatenate ?name ?act_typ axis input_node =
    let neuron = Concatenate (Concatenate.create axis) in
    _add_nary ?name ?act_typ neuron input_node

  (* I/O functions *)

  (* Human-readable dump: the network name, then each node in topological
     order with its neuron and the names of its parents and children. Built
     in a buffer to avoid quadratic string concatenation. *)
  let to_string nn =
    let buf = Buffer.create 256 in
    Buffer.add_string buf (nn.nnid ^ "\n\n");
    let names nodes =
      Array.map (fun n -> n.name) nodes |> Owl_utils_array.to_string (fun s -> s)
    in
    Array.iter
      (fun n ->
        Printf.bprintf buf "\x1b[31m[ Node %s ]:\x1b[0m\n" n.name;
        (* non-recursive [let]: this is the neuron-level [to_string] *)
        Buffer.add_string buf (to_string n.neuron);
        Printf.bprintf buf " prev:[%s] next:[%s]\n\n" (names n.prev) (names n.next))
      nn.topo;
    Buffer.contents buf

  (* Print [nn] on [formatter]; the box is opened and closed on that same
     formatter (not on [Format.std_formatter]). *)
  let pp_network formatter nn =
    Format.pp_open_box formatter 0;
    Format.fprintf formatter "%s" (to_string nn);
    Format.pp_close_box formatter ()

  let print nn = pp_network Format.std_formatter nn

  (* Marshal a copy of [nn] to file [f]. With [unsafe], closures are
     marshalled too (needed e.g. for Lambda neurons), which ties the file to
     the exact OCaml/Owl version. *)
  let save ?(unsafe = false) nn f =
    if unsafe = true
    then (
      Owl_log.warn
        "Unsafely saved network can only be loaded back in exactly the same version of OCaml and Owl.";
      Owl_io.marshal_to_file ~flags:[ Marshal.Closures ] (copy nn) f)
    else Owl_io.marshal_to_file (copy nn) f

  (* Unmarshal a network previously written by [save]. *)
  let load f : network = Owl_io.marshal_from_file f

  (* Save the weights of every node, keyed by node name, to file [f]. *)
  let save_weights nn f =
    let h = Hashtbl.create nn.size in
    Array.iter
      (fun n ->
        let ws = Neuron.save_weights n.neuron in
        Hashtbl.add h n.name ws)
      nn.topo;
    Owl_io.marshal_to_file h f

  (* Load weights saved by [save_weights]; raises [Not_found] if a node name
     of [nn] is missing from the file. *)
  let load_weights nn f =
    let h = Owl_io.marshal_from_file f in
    Array.iter
      (fun n ->
        let ws = Hashtbl.find h n.name in
        Neuron.load_weights n.neuron ws)
      nn.topo

  (* training functions *)

  (* Generic minimisation: hands the optimisation engine
     - forward: function to run the forward pass
     - backward: function to run the backward pass
     - update: function to update the weights according to the gradient
     - save: function to save the model for checkpoint
     The parameters are initialised first unless [init_model] is false;
     [params] defaults to [Optimise.Params.default ()]. *)
  let train_generic ?state ?params ?(init_model = true) nn x y =
    if init_model = true then init nn;
    let f = forward nn in
    let b = backward nn in
    let u = update nn in
    let s = save nn in
    let p =
      match params with
      | Some p -> p
      | None -> Optimise.Params.default ()
    in
    Optimise.minimise_network ?state p f b u s x y

  (* [train_generic] on plain ndarrays, wrapped as [Arr]. *)
  let train ?state ?params ?init_model nn x y =
    train_generic ?state ?params ?init_model nn (Arr x) (Arr y)
end

(* Make functor ends *)