
(*
 * Copyright (C) Citrix Systems Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation; version 2.1 only. with the special
 * exception on linking described in file LICENSE.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *)

(** A multiplexing xenstore protocol client over a byte-level transport *)

open Lwt
open Xs_protocol

(** The byte-level transport a client is built on. *)
module type IO = sig
  type 'a t = 'a Lwt.t
  val return: 'a -> 'a t
  val ( >>= ): 'a t -> ('a -> 'b t) -> 'b t

  (** Which kind of transport this is. With [`xen] a single client is
      created and shared by every caller of [make]; with [`unix] each call
      to [make] creates a fresh client. *)
  type backend = [ `xen | `unix ]
  val backend : backend

  type channel
  val create: unit -> channel t
  val destroy: channel -> unit t

  (* NOTE(review): the two ints are presumably (offset, length) into the
     buffer; confirm against the backend implementations. *)
  val read: channel -> bytes -> int -> int -> int t
  val write: channel -> bytes -> int -> int -> unit t
end

(** The interface offered by a xenstore client. *)
module type S = sig
  type client

  (** Create (or, on a [`xen] backend, fetch the shared) client. *)
  val make : unit -> client Lwt.t

  (** Block new requests, wait for outstanding ones to complete, then
      cancel all watchers and stop the dispatcher. *)
  val suspend : client -> unit Lwt.t

  (** Unblock requests and start a new dispatcher. *)
  val resume : client -> unit Lwt.t

  type handle

  (** [immediate c f] runs [f] with a handle that is outside any
      transaction. *)
  val immediate : client -> (handle -> 'a Lwt.t) -> 'a Lwt.t

  (** [transaction c f] runs [f] inside a transaction; the whole transaction
      is started again if ending it raises [Eagain]. *)
  val transaction: client -> (handle -> 'a Lwt.t) -> 'a Lwt.t

  (** [wait c f] runs [f] repeatedly for as long as it raises [Eagain],
      blocking between attempts until one of the paths [f] accessed is
      reported as changed. The returned thread is cancellable. *)
  val wait: client -> (handle -> 'a Lwt.t) -> 'a Lwt.t

  val directory : handle -> string -> string list Lwt.t
  val read : handle -> string -> string Lwt.t
  val write : handle -> string -> string -> unit Lwt.t
  val rm : handle -> string -> unit Lwt.t
  val mkdir : handle -> string -> unit Lwt.t
  val setperms : handle -> string -> Xs_protocol.ACL.t -> unit Lwt.t
  val debug : handle -> string list -> string list Lwt.t
  val restrict : handle -> int -> unit Lwt.t
  val getdomainpath : handle -> int -> string Lwt.t
  val watch : handle -> string -> Xs_protocol.Token.t -> unit Lwt.t
  val unwatch : handle -> string -> Xs_protocol.Token.t -> unit Lwt.t
  val introduce : handle -> int -> nativeint -> int -> unit Lwt.t
  val set_target : handle -> int -> int -> unit Lwt.t
end

(** [finally f g] runs [f ()] and then [g ()], whether [f ()] succeeded or
    failed, and finally yields the result (or exception) of [f ()]. *)
let finally f g =
  Lwt.catch
    (fun () ->
       f () >>= fun result ->
       g () >>= fun () ->
       Lwt.return result)
    (fun e ->
       g () >>= fun () ->
       Lwt.fail e)

module StringSet = Xs_handle.StringSet

module Watcher = struct

  (** Someone who is watching paths is represented by one of these: *)
  type t = {
    mutable paths: StringSet.t; (* we never care about events or ordering, only paths *)
    mutable cancelling: bool;   (* we need to stop watching and clean up *)
    c: unit Lwt_condition.t;    (* signalled when [paths] or [cancelling] changes *)
    m: Lwt_mutex.t;             (* protects the mutable fields *)
  }

  let make () = {
    paths = StringSet.empty;
    cancelling = false;
    c = Lwt_condition.create ();
    m = Lwt_mutex.create ();
  }

  (** Register that a watched path has been changed *)
  let put (x: t) path =
    Lwt_mutex.with_lock x.m
      (fun () ->
         x.paths <- StringSet.add path x.paths;
         Lwt_condition.signal x.c ();
         return ())

  (** Return a set of modified paths, or an empty set if we're cancelling *)
  let get (x: t) =
    Lwt_mutex.with_lock x.m
      (fun () ->
         (* Block until there is something to report or we are cancelled. *)
         let rec loop () =
           if StringSet.is_empty x.paths && not x.cancelling
           then begin
             Lwt_condition.wait ~mutex:x.m x.c >>= fun () ->
             loop ()
           end
           else Lwt.return ()
         in
         loop () >>= fun () ->
         let results = x.paths in
         x.paths <- StringSet.empty;
         return results)

  (** Called to shutdown the watcher and trigger an orderly cleanup *)
  let cancel (x: t) =
    (* Fire and forget: the flag is set as soon as the mutex is acquired. *)
    let (_: unit Lwt.t) =
      Lwt_mutex.with_lock x.m
        (fun () ->
           x.cancelling <- true;
           Lwt_condition.signal x.c ();
           return ())
    in
    ()
end

exception Malformed_watch_event
exception Unexpected_rid of int32
exception Dispatcher_failed

module Client = functor(IO: IO with type 'a t = 'a Lwt.t) -> struct
  module PS = PacketStream(IO)

  (* Represents a single active connection to a server *)
  type client = {
    mutable transport: IO.channel;
    ps: PS.stream;
    (* Request id -> wakener of the thread waiting for the matching reply. *)
    rid_to_wakeup: (int32, Xs_protocol.t Lwt.u) Hashtbl.t;
    mutable dispatcher_thread: unit Lwt.t;
    mutable dispatcher_shutting_down: bool;
    (* Watch token -> the watcher to notify when that watch fires. *)
    watchevents: (Token.t, Watcher.t) Hashtbl.t;

    mutable suspended : bool;
    (* [suspended_m] protects [suspended] and [rid_to_wakeup];
       [suspended_c] is broadcast whenever either of them changes. *)
    suspended_m : Lwt_mutex.t;
    suspended_c : unit Lwt_condition.t;
  }

  (* The following values are only used if IO.backend = `xen. *)

  (* The whole application must only use one xenstore client, which will
     multiplex all requests onto the same ring. *)
  let client_cache = ref None

  (* Multiple threads will call 'make' in parallel. We must ensure only
     one client is created. *)
  let client_cache_m = Lwt_mutex.create ()

  (* Receive one packet, turning a reported exception into a failed thread. *)
  let recv_one t =
    PS.recv t.ps >>= function
    | Ok x -> return x
    | Exception e -> Lwt.fail e

  let send_one t = PS.send t.ps

  (* Log [e], mark the dispatcher as shutting down, fail every thread that is
     waiting for a reply with [e], and then fail with [e] ourselves. *)
  let handle_exn t e =
    Printf.fprintf stderr "Caught: %s\n%!" (Printexc.to_string e);
    t.dispatcher_shutting_down <- true; (* no more hashtable entries after this *)
    (* all blocking threads are failed with our exception *)
    Lwt_mutex.with_lock t.suspended_m
      (fun () ->
         Printf.fprintf stderr "Propagating exception to %d threads\n%!"
           (Hashtbl.length t.rid_to_wakeup);
         Hashtbl.iter
           (fun _ u ->
              (* A wakener whose reply was already delivered, but whose
                 owner has not yet removed it from the table, makes
                 [wakeup_later_exn] raise [Invalid_argument]. Skip it so
                 that the remaining waiters are still failed. *)
              try Lwt.wakeup_later_exn u e with Invalid_argument _ -> ())
           t.rid_to_wakeup;
         return ())
    >>= fun () ->
    Lwt.fail e

  (* Read packets for ever: watch events are handed to the registered
     watcher, every other packet wakes the thread waiting on its request id.
     Any failure is routed through [handle_exn], which ends the loop. *)
  let rec dispatcher t =
    Lwt.catch (fun () -> recv_one t) (handle_exn t) >>= fun pkt ->
    match get_ty pkt with
    | Op.Watchevent ->
      begin match Unmarshal.list pkt with
        | Some [path; token] ->
          let token = Token.of_string token in
          (* We may get old watches: silently drop these *)
          if Hashtbl.mem t.watchevents token
          then Watcher.put (Hashtbl.find t.watchevents token) path
          else return ()
        | _ ->
          handle_exn t Malformed_watch_event
      end >>= fun () ->
      (* Single recursion point: the branches above must not recurse
         themselves, otherwise this continuation is never reached and one
         pending bind is left behind per watch event. *)
      dispatcher t
    | _ ->
      let rid = get_rid pkt in
      Lwt_mutex.with_lock t.suspended_m
        (fun () ->
           if Hashtbl.mem t.rid_to_wakeup rid
           then return (Some (Hashtbl.find t.rid_to_wakeup rid))
           else return None)
      >>= function
      | None -> handle_exn t (Unexpected_rid rid)
      | Some thread ->
        Lwt.wakeup_later thread pkt;
        dispatcher t

  (* Create a brand new client together with its dispatcher thread. *)
  let make_unsafe () =
    IO.create () >>= fun transport ->
    let t = {
      transport = transport;
      ps = PS.make transport;
      rid_to_wakeup = Hashtbl.create 10;
      dispatcher_thread = return ();
      dispatcher_shutting_down = false;
      watchevents = Hashtbl.create 10;
      suspended = false;
      suspended_m = Lwt_mutex.create ();
      suspended_c = Lwt_condition.create ();
    } in
    t.dispatcher_thread <- dispatcher t;
    return t

  (* On `unix every call creates a new client. On `xen the first call
     creates the client and later calls return the cached one. *)
  let make () =
    match IO.backend with
    | `unix -> make_unsafe ()
    | `xen ->
      Lwt_mutex.with_lock client_cache_m
        (fun () ->
           match !client_cache with
           | Some c -> return c
           | None ->
             make_unsafe () >>= fun c ->
             client_cache := Some c;
             return c)

  (* Stop accepting new requests, wait until no request is outstanding, then
     cancel every watcher and the dispatcher thread. *)
  let suspend t =
    Lwt_mutex.with_lock t.suspended_m
      (fun () ->
         t.suspended <- true;
         let rec loop () =
           if Hashtbl.length t.rid_to_wakeup > 0
           then begin
             Lwt_condition.wait ~mutex:t.suspended_m t.suspended_c
             >>= fun () ->
             loop ()
           end
           else Lwt.return ()
         in
         loop ())
    >>= fun () ->
    Hashtbl.iter (fun _ watcher -> Watcher.cancel watcher) t.watchevents;
    Lwt.cancel t.dispatcher_thread;
    return ()

  (* Clear the suspended / shutting-down flags, release every thread blocked
     in [rpc] and start a new dispatcher. *)
  let resume_unsafe t =
    Lwt_mutex.with_lock t.suspended_m
      (fun () ->
         t.suspended <- false;
         t.dispatcher_shutting_down <- false;
         Lwt_condition.broadcast t.suspended_c ();
         return ())
    >>= fun () ->
    t.dispatcher_thread <- dispatcher t;
    return ()

  (* NOTE(review): on `xen a new transport is stored in the cached client,
     but [ps] was built from the previous transport and is not rebuilt
     here; confirm that this is intended. *)
  let resume t =
    match IO.backend with
    | `unix -> resume_unsafe t
    | `xen ->
      (match !client_cache with
       | None -> Lwt.return ()
       | Some c ->
         IO.create () >>= fun transport ->
         c.transport <- transport;
         resume_unsafe t)

  type handle = client Xs_handle.t

  (* Hand out consecutive request ids, starting from 0l. *)
  let make_rid =
    let counter = ref 0l in
    fun () ->
      let result = !counter in
      counter := Int32.succ !counter;
      result

  (* [rpc hint h payload unmarshal] sends [payload] using the transaction id
     of [h], waits for the reply carrying the same request id and decodes it
     with [unmarshal]. Fails with [Dispatcher_failed] if the dispatcher is
     shutting down, and blocks while the client is suspended. *)
  let rpc hint h payload unmarshal =
    let open Xs_handle in
    let rid = make_rid () in
    let request = Request.print payload (get_tid h) rid in
    let t, u = wait () in
    let c = get_client h in
    if c.dispatcher_shutting_down
    then Lwt.fail Dispatcher_failed
    else begin
      let register_and_send () =
        Lwt_mutex.with_lock c.suspended_m
          (fun () ->
             let rec loop () =
               if c.suspended
               then begin
                 Lwt_condition.wait ~mutex:c.suspended_m c.suspended_c
                 >>= fun () ->
                 loop ()
               end
               else Lwt.return ()
             in
             loop () >>= fun () ->
             Hashtbl.add c.rid_to_wakeup rid u;
             send_one c request)
      in
      let unregister () =
        Lwt_mutex.with_lock c.suspended_m
          (fun () ->
             Hashtbl.remove c.rid_to_wakeup rid;
             Lwt_condition.broadcast c.suspended_c ();
             return ())
      in
      (* Always unregister the request id, including when sending fails or
         the dispatcher fails us: a stale entry would keep [suspend]
         waiting for ever. *)
      finally
        (fun () -> register_and_send () >>= fun () -> t)
        unregister
      >>= fun res ->
      return (response hint request res unmarshal)
    end

  let directory h path =
    rpc "directory" (Xs_handle.accessed_path h path)
      Request.(PathOp(path, Directory)) Unmarshal.list

  let read h path =
    rpc "read" (Xs_handle.accessed_path h path)
      Request.(PathOp(path, Read)) Unmarshal.string

  let write h path data =
    rpc "write" (Xs_handle.accessed_path h path)
      Request.(PathOp(path, Write data)) Unmarshal.ok

  let rm h path =
    rpc "rm" (Xs_handle.accessed_path h path)
      Request.(PathOp(path, Rm)) Unmarshal.ok

  let mkdir h path =
    rpc "mkdir" (Xs_handle.accessed_path h path)
      Request.(PathOp(path, Mkdir)) Unmarshal.ok

  let setperms h path acl =
    rpc "setperms" (Xs_handle.accessed_path h path)
      Request.(PathOp(path, Setperms acl)) Unmarshal.ok

  let debug h cmd_args =
    rpc "debug" h (Request.Debug cmd_args) Unmarshal.list

  let restrict h domid =
    rpc "restrict" h (Request.Restrict domid) Unmarshal.ok

  let getdomainpath h domid =
    rpc "getdomainpath" h (Request.Getdomainpath domid) Unmarshal.string

  let watch h path token =
    rpc "watch" (Xs_handle.watch h path)
      (Request.Watch(path, Token.to_string token)) Unmarshal.ok

  let unwatch h path token =
    rpc "unwatch" (Xs_handle.unwatch h path)
      (Request.Unwatch(path, Token.to_string token)) Unmarshal.ok

  let introduce h domid store_mfn store_port =
    rpc "introduce" h
      (Request.Introduce(domid, store_mfn, store_port)) Unmarshal.ok

  let set_target h stubdom_domid domid =
    rpc "set_target" h
      (Request.Set_target(stubdom_domid, domid)) Unmarshal.ok

  let immediate client f = f (Xs_handle.no_transaction client)

  (* Used by [wait] to generate a distinct watch token per call. *)
  let counter = ref 0l

  (* Register a watch on each path in turn, stopping at the first failure. *)
  let rec add_watches h token = function
    | [] -> Lwt.return_unit
    | p :: ps ->
      Lwt.try_bind
        (fun () -> watch h p token)
        (fun () -> add_watches h token ps)
        (fun ex ->
           (* If we fail to add all the watches (e.g. because we exceeded our
              watch quota) then mark the remaining paths as unwatched so we
              don't try to release them. *)
           (p :: ps)
           |> List.iter (fun p ->
               let _ : _ Xs_handle.t = Xs_handle.unwatch h p in
               ());
           Lwt.fail ex)

  let wait client f =
    let open StringSet in
    counter := Int32.succ !counter;
    let token =
      Token.of_string (Printf.sprintf "%ld:xs_client.wait" !counter) in
    (* When we register the 'watcher', the dispatcher thread will signal us
       when watches arrive. *)
    let watcher = Watcher.make () in
    Hashtbl.add client.watchevents token watcher;

    (* We signal the caller via this cancellable task: *)
    let result, wakener = Lwt.task () in
    on_cancel result
      (fun () ->
         (* Trigger an orderly cleanup in the background: *)
         Watcher.cancel watcher);
    let h = Xs_handle.watching client in

    (* Adjust the paths we're watching (if necessary) and block (if possible) *)
    let adjust_paths () =
      let current_paths = Xs_handle.get_watched_paths h in
      (* Paths which weren't read don't need to be watched: *)
      let old_paths = diff current_paths (Xs_handle.get_accessed_paths h) in
      Lwt_list.iter_s (fun p -> unwatch h p token) (elements old_paths)
      >>= fun () ->
      (* Paths which were read do need to be watched: *)
      let new_paths = diff (Xs_handle.get_accessed_paths h) current_paths in
      add_watches h token (elements new_paths) >>= fun () ->
      (* If we're watching the correct set of paths already then just block *)
      if is_empty old_paths && is_empty new_paths
      then begin
        Watcher.get watcher >>= fun results ->
        (* an empty results set means we've been cancelled: trigger cleanup *)
        if is_empty results
        then fail (Failure "goodnight")
        else return ()
      end
      else return ()
    in

    (* Main client loop: *)
    let rec loop () =
      Lwt.catch
        (fun () ->
           f h >>= fun result ->
           wakeup wakener result;
           return true)
        (function
          | Eagain -> return false
          | ex -> wakeup_exn wakener ex; return true)
      >>= function
      | true -> return ()
      | false ->
        Lwt.try_bind adjust_paths loop
          (fun ex -> wakeup_exn wakener ex; Lwt.return_unit)
    in
    (* NOTE(review): if the cleanup below fails, the exception goes to Lwt's
       async exception hook rather than to the caller; confirm that this is
       acceptable. *)
    Lwt.async
      (fun () ->
         finally loop
           (fun () ->
              let current_paths = Xs_handle.get_watched_paths h in
              Lwt_list.iter_s (fun p -> unwatch h p token)
                (elements current_paths)
              >>= fun () ->
              Hashtbl.remove client.watchevents token;
              return ()));
    result

  (* NOTE(review): if [f h] fails, no transaction_end request is sent for
     [tid]; confirm whether the transaction should be aborted in that
     case. *)
  let rec transaction client f =
    rpc "transaction_start" (Xs_handle.no_transaction client)
      Request.Transaction_start Unmarshal.int32
    >>= fun tid ->
    let h = Xs_handle.transaction client tid in
    f h >>= fun result ->
    Lwt.catch
      (fun () ->
         rpc "transaction_end" h (Request.Transaction_end true)
           Unmarshal.string
         >>= fun res' ->
         if res' = "OK"
         then return result
         else
           Lwt.fail
             (Error
                (Printf.sprintf "Unexpected transaction result: %s" res')))
      (function
        (* The transaction conflicted: run the whole thing again. *)
        | Eagain -> transaction client f
        | e -> Lwt.fail e)
end