(*
* Copyright (c) 2012-2018 Vincent Bernardoff <vb@luminar.eu.org>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*
*)openAstringopenWebsocketopenLwt.InfixincludeWebsocket.Make(Cohttp_lwt_unix.IO)letsection=Lwt_log.Section.make"websocket_lwt_unix"exceptionHTTP_Errorofstringlethttp_errormsg=Lwt.fail(HTTP_Errormsg)letprotocol_errormsg=Lwt.fail(Protocol_errormsg)letset_tcp_nodelayflow=letopenConduit_lwt_unixinmatchflowwith|TCP{fd;_}->Lwt_unix.setsockoptfdLwt_unix.TCP_NODELAYtrue|_->()letfail_unlesseqf=ifnoteqthenf()elseLwt.return_unitletfail_ifeqf=ifeqthenf()elseLwt.return_unitletdrain_handshakereqicocnonce=Request.write(fun_writer->Lwt.return())reqoc>>=fun()->Response.readic>>=(function|`Okr->Lwt.returnr|`Eof->Lwt.failEnd_of_file|`Invalids->Lwt.fail@@Failures)>>=funresponse->letopenCohttpinletstatus=Response.statusresponseinletheaders=Response.headersresponseinfail_ifCode.(is_error@@code_of_statusstatus)(fun()->http_errorCode.(string_of_statusstatus))>>=fun()->fail_unless(Response.versionresponse=`HTTP_1_1)(fun()->protocol_error"wrong http version")>>=fun()->fail_unless(status=`Switching_protocols)(fun()->protocol_error"wrong status")>>=fun()->(matchHeader.getheaders"upgrade"with|SomeawhenString.Ascii.lowercasea="websocket"->Lwt.return_unit|_->protocol_error"wrong upgrade")>>=fun()->fail_unless(upgrade_presentheaders)(fun()->protocol_error"upgrade header not present")>>=fun()->matchHeader.getheaders"sec-websocket-accept"with|Some acceptwhenaccept=b64_encoded_sha1sum(nonce^websocket_uuid)->Lwt.return_unit|_->protocol_error"wrong accept"letconnectctxclienturlnonceextra_headers=letopenCohttpinletheaders=Header.add_listextra_headers[("Upgrade","websocket");("Connection","Upgrade");("Sec-WebSocket-Key",nonce);("Sec-WebSocket-Version","13")]inletreq=Request.make~headersurlinConduit_lwt_unix.connect~ctxclient>>=fun(flow,ic,oc)->set_tcp_nodelayflow;Lwt.catch(fun()->drain_handshakereqicocnonce)(funexn->Lwt_io.closeic>>=fun()->Lwt.failexn)>>=fun()->Lwt_log.info_f~section"Connected to 
%s"(Uri.to_stringurl)>>=fun()->Lwt.return(ic,oc)typeconn={read_frame:unit->Frame.tLwt.t;write_frame:Websocket.Frame.t->unitLwt.t;oc:Lwt_io.output_channel}letread{read_frame;_}=read_frame()letwrite{write_frame;_}frame=write_frameframeletclose_transport{oc;_}=Lwt_io.closeocletconnect?(extra_headers=Cohttp.Header.init())?(random_string=Websocket.Rng.init())?(ctx=Lazy.forceConduit_lwt_unix.default_ctx)?bufclienturl=letnonce=Base64.encode_exn(random_string16)inconnectctxclienturlnonceextra_headers>|=fun(ic,oc)->letread_frame=make_read_frame?buf~mode:(Clientrandom_string)icocinletread_frame()=Lwt.catchread_frame(funexn->Lwt.async(fun()->Lwt_io.closeic);Lwt.failexn)inletbuf=Buffer.create128inletwrite_frameframe=Buffer.clearbuf;Lwt.wrap2(write_frame_to_buf~mode:(Clientrandom_string))bufframe>>=fun()->Lwt.catch(fun()->Lwt_io.writeoc(Buffer.contentsbuf))(funexn->Lwt.async(fun()->Lwt_io.closeoc);Lwt.failexn)in{read_frame;write_frame;oc}letwrite_failed_responseoc=letbody="403 Forbidden"inletbody_len=String.lengthbody|>Int64.of_intinletresponse=Cohttp.Response.make~status:`Forbidden~encoding:(Cohttp.Transfer.Fixedbody_len)()inletopenResponseinwrite~flush:true(funwriter->write_bodywriterbody)responseocletserver_fun?read_buf?write_bufcheck_requestflowicocreact=letread=function|`Okr->Lwt.returnr|`Eof->(* Remote endpoint closed connection. No further action necessary here. 
*)Lwt_log.info~section"Remote endpoint closed connection">>=fun()->Lwt.failEnd_of_file|`Invalidreason->Lwt_log.info_f~section"Invalid input from remote endpoint: %s"reason>>=fun()->Lwt.fail@@HTTP_ErrorreasoninRequest.readic>>=read>>=funrequest->letmeth=Cohttp.Request.methrequestinletversion=Cohttp.Request.versionrequestinletheaders=Cohttp.Request.headersrequestinletkey=Cohttp.Header.getheaders"sec-websocket-key"in(match(version,meth,Cohttp.Header.getheaders"upgrade",key,upgrade_presentheaders,check_requestrequest)with|`HTTP_1_1,`GET,Someup,Somekey,true,truewhenString.Ascii.lowercaseup="websocket"->Lwt.returnkey|_->write_failed_responseoc>>=fun()->Lwt.fail(Protocol_error"Bad headers"))>>=funkey->lethash=key^websocket_uuid|>b64_encoded_sha1suminletresponse_headers=Cohttp.Header.of_list[("Upgrade","websocket");("Connection","Upgrade");("Sec-WebSocket-Accept",hash)]inletresponse=Cohttp.Response.make~status:`Switching_protocols~encoding:Cohttp.Transfer.Unknown~headers:response_headers()inResponse.write(fun_writer->Lwt.return_unit)responseoc>>=fun()->letclient=Connected_client.create?read_buf?write_bufrequestflowicocinreactclientletestablish_server?read_buf?write_buf?timeout?stop?(on_exn=funexn->!Lwt.async_exception_hookexn)?(check_request=check_origin_with_host)?(ctx=Lazy.forceConduit_lwt_unix.default_ctx)~modereact=letmoduleC=CohttpinConduit_lwt_unix.serve~on_exn?timeout?stop~ctx~mode(funflowicoc->set_tcp_nodelayflow;Lwt.catch(fun()->server_fun?read_buf?write_bufcheck_request(Conduit_lwt_unix.endp_of_flowflow)icocreact)(function|End_of_file->Lwt.return_unit|HTTP_Error_->Lwt.return_unit|exn->Lwt.failexn))letmk_frame_streamrecv=letf()=recv()>>=funfr->matchfr.Frame.opcodewith|Frame.Opcode.Close->Lwt.return_none|_->Lwt.return(Somefr)inLwt_stream.fromfletestablish_standard_server?read_buf?write_buf?timeout?stop?on_exn?check_request?(ctx=Lazy.forceConduit_lwt_unix.default_ctx)~modereact=letfclient=react(Connected_client.make_standardclient)inestablish_server?read_buf?write_bu
f?timeout?stop?on_exn?check_request~ctx~modef