# 1 "src/owl/neural/owl_neural_parallel.ml"
(*
* OWL - OCaml Scientific and Engineering Computing
* Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
*)

(** Neural network: interface of parallel engine *)

open Owl_algodiff.S
open Owl_optimise.S


(* Signature that a model-parallel (parameter server) engine has to provide. *)
module type EngineSig = sig

  (* opaque engine context passed to the barrier and stop callbacks *)
  type param_context

  (* barrier selector accepted by [start] *)
  (* NOTE(review): presumably the usual parameter-server synchronisation
     schemes; the semantics are defined by the engine -- confirm there *)
  type barrier =
    | ASP
    | BSP
    | SSP
    | PSP

  (* key/value access to the engine's store; [get] also returns an int *)
  val get : 'a -> 'b * int

  val set : 'a -> 'b -> unit

  (* number of workers known to the engine *)
  val worker_num : unit -> int

  (* start the engine; [Make.train_generic] passes its [jid] and [url] here *)
  val start : ?barrier:barrier -> string -> string -> unit

  (* callback registration *)
  val register_barrier : (param_context ref -> int * (string list)) -> unit

  val register_schedule : ('a list -> ('a * ('b * 'c) list) list) -> unit

  val register_pull : (('a * 'b) list -> ('a * 'c) list) -> unit

  val register_push : ('a -> ('b * 'c) list -> ('b * 'c) list) -> unit

  val register_stop : (param_context ref -> bool) -> unit

end


(* Signature that a neural network model has to provide. *)
module type ModelSig = sig

  type network

  (* parameters of the network as an array of arrays of algodiff values *)
  val mkpar : network -> t array array

  val init : network -> unit

  (* overwrite the parameters of the network with the given ones *)
  val update : network -> t array array -> unit

  val copy : network -> network

  val train_generic
    :  ?state:Checkpoint.state
    -> ?params:Params.typ
    -> ?init_model:bool
    -> network
    -> t
    -> t
    -> Checkpoint.state

end


(* Parallel training of a model [M] on top of a parameter server engine [E]. *)
module Make (M : ModelSig) (E : EngineSig) = struct

  (* everything one training job needs: the key under which the model is
     stored in the engine, the latest checkpoint state (if any), the training
     parameters, the model itself and the training data *)
  type task =
    { mutable id     : int
    ; mutable state  : Checkpoint.state option
    ; mutable params : Params.typ
    ; mutable model  : M.network
    ; mutable data_x : t
    ; mutable data_y : t
    }


  (* build a task that has no checkpoint state yet *)
  let make_task id params model data_x data_y =
    { id; state = None; params; model; data_x; data_y }


  (* in-place difference: model0 <- model0 - model1, parameter by parameter *)
  let delta_model model0 model1 =
    let lhs = M.mkpar model0 in
    let rhs = M.mkpar model1 in
    let diff = Owl_utils.aarr_map2 (fun p q -> Maths.(p - q)) lhs rhs in
    M.update model0 diff


  (* fetch the model stored under [task.id]; when the engine raises
     [Not_found], initialise [task.model], store it, and fetch again *)
  let local_model task =
    match E.get task.id with
    | (stored, _) -> stored
    | exception Not_found ->
      Owl_log.warn "set up first model";
      M.init task.model;
      E.set task.id task.model;
      fst (E.get task.id)


  (* scheduling callback: every worker is handed the same single pair
     (task id, current model) *)
  let schedule task workers =
    (* looked up once, before the workers are visited *)
    let model = local_model task in
    let assign worker = (worker, [ (task.id, model) ]) in
    List.map assign workers


  (* pull callback: add each incoming model to the stored one, parameter by
     parameter, then write the result back to the task and to the engine *)
  let pull task vars =
    let worker_count = float_of_int (E.worker_num ()) in
    (* there has to be at least one worker *)
    assert (worker_count >= 1.);
    let merge (key, incoming) =
      let current = local_model task in
      let cur_par = M.mkpar current in
      let inc_par = M.mkpar incoming in
      let summed =
        Owl_utils.aarr_map2 (fun p q -> Maths.(p + q)) cur_par inc_par
      in
      M.update current summed;
      task.model <- current;
      E.set task.id task.model;
      (key, current)
    in
    (* the list is expected to hold exactly one item *)
    List.map merge vars


  (* push callback: train each received model locally and answer with the
     difference between the trained model and the one that was received *)
  let push task _id vars =
    let train_one (key, model) =
      (* snapshot of the model as received, taken before training *)
      task.model <- M.copy model;
      let params = task.params in
      let x = task.data_x in
      let y = task.data_y in
      (* the previous checkpoint state is forwarded only when there is one *)
      let state =
        M.train_generic ?state:(task.state) ~params ~init_model:false model x y
      in
      (* clear the stop flag on the state that is kept for the next call *)
      state.Checkpoint.stop <- false;
      task.state <- Some state;
      (* model <- trained model - snapshot, i.e. only the delta is sent *)
      delta_model model task.model;
      (key, M.copy model)
    in
    (* the list is expected to hold exactly one item *)
    List.map train_one vars


  (* stop callback. FIXME: always answers [false], so training never stops *)
  let stop _task _context = false


  (* set up a task for [nn] on data [x]/[y], register the schedule, pull, push
     and stop callbacks with the engine, and start it with the [ASP] barrier;
     no barrier callback is registered here *)
  let train_generic ?params nn x y jid url =
    let params =
      match params with
      | None   -> Params.default ()
      | Some p -> p
    in
    (* random key for the model in the engine's store *)
    let id = Owl_stats.uniform_int_rvs ~a:0 ~b:max_int in
    let task = make_task id params nn x y in
    E.register_schedule (schedule task);
    E.register_pull (pull task);
    E.register_push (push task);
    E.register_stop (stop task);
    E.start ~barrier:E.ASP jid url


  (* same as [train_generic], with [x] and [y] wrapped in [Arr] first *)
  let train ?params nn x y jid url =
    train_generic ?params nn (Arr x) (Arr y) jid url


end