Source file server_impl.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
(* Short alias for [Mehari.Private], used by the signature and the functor
   below. *)
module Private = Mehari.Private

(** Interface of a Mehari server backend built on top of a network stack. *)
module type S = sig
  type stack
  (** The network stack the server listens on (substituted with the
      functor's [Stack.t] in [Make]). *)

  module IO : Private.IO
  (** Concurrency monad in which handlers and [run] execute ([Lwt] in
      [Make]). *)

  type handler = Ipaddr.t Private.Handler.Make(IO).t
  (** A request handler whose requests carry [Ipaddr.t] addresses. *)

  val run :
    ?port:int ->
    ?verify_url_host:bool ->
    ?config:Tls.Config.server ->
    ?timeout:float ->
    certchains:Tls.Config.certchain list ->
    stack ->
    handler ->
    unit IO.t
  (** [run ?port ?verify_url_host ?config ?timeout ~certchains stack handler]
      accepts TLS connections on [port] of [stack] and answers each request
      with [handler].

      - [port]: TCP port to listen on ([Make] defaults it to 1965).
      - [verify_url_host]: forwarded to the request parser; presumably
        checks that the requested URL host matches the server. TODO confirm.
        ([Make] defaults it to [true].)
      - [config]: TLS server configuration. When absent, [Make] builds one
        from [certchains] with an authenticator that accepts any client
        certificate.
      - [timeout]: timeout, in seconds, for reading a client request.
      - [certchains]: certificate chains served to clients. *)
end

module Make
    (Stack : Tcpip.Stack.V4V6)
    (Time : Mirage_time.S)
    (Logger : Private.Logger_impl.S) :
  S with module IO = Lwt and type stack := Stack.t = struct
  module IO = Lwt

  type handler = Ipaddr.t Private.Handler.Make(IO).t

  module TLS = Tls_mirage.Make (Stack.TCP)
  module Channel = Mirage_channel.Make (TLS)
  module Protocol = Mehari.Private.Protocol
  open Lwt.Infix
  open Lwt.Syntax

  (* Per-server settings shared by every connection handler. *)
  type config = {
    addr : Ipaddr.t;  (* first address of the stack, given to the parser *)
    port : int;
    timeout : float option;  (* request-read timeout in seconds; None = off *)
    tls_config : Tls.Config.server;
    verify_url_host : bool;
  }

  let make_config ~addr ~port ~timeout ~tls_config ~verify_url_host =
    { addr; port; timeout; tls_config; verify_url_host }

  let src = Logs.Src.create "mehari.mirage"

  module Log = (val Logs.src_log src)

  (* Number of request bytes (CR excluded) accepted before the request is
     rejected with [`BufferLimitExceeded]. *)
  let max_request_size = 1024

  (* Flushes [chan], tagging a failure as [`ChannelWriteErr]. *)
  let flush_channel chan =
    Channel.flush chan |> Lwt_result.map_error (fun e -> `ChannelWriteErr e)

  (* Writes every chunk of [resp] to [chan], then flushes it. *)
  let write_resp chan resp =
    let write buf = Channel.write_string chan buf 0 (String.length buf) in
    (match Mehari.Private.view_of_resp resp with
    | Immediate bufs -> List.iter write bufs
    | Delayed { body; _ } -> body write);
    flush_channel chan

  (* Reads one request line from [chan] and returns it without its
     terminating CRLF. Bytes equal to '\r' are never stored; a '\n' not
     preceded by '\r' is kept as data. Fails with [`BufferLimitExceeded]
     when another byte arrives after more than [max_request_size] bytes
     have been stored, with [`Eof] on end of stream, and with
     [`ChannelErr] on a read error. *)
  let read_client_req chan =
    let buf = Buffer.create max_request_size in
    let rec loop n cr =
      Channel.read_char chan >>= function
      | Ok (`Data _) when n > max_request_size ->
          Lwt.return_error `BufferLimitExceeded
      | Ok (`Data '\n') when cr -> Lwt.return_ok (Buffer.contents buf)
      | Ok (`Data '\r') -> loop n true
      | Ok (`Data c) ->
          Buffer.add_char buf c;
          loop (n + 1) false
      | Ok `Eof -> Lwt.return_error `Eof
      | Error err -> Lwt.return_error (`ChannelErr err)
    in
    loop 0 false

  (* Runs [f ()]. When [timeout] is [Some seconds], races it against a timer
     of that length with [Lwt.pick]. If the timer wins, [f ()] is cancelled
     and [Error `Timeout] is returned. [None] disables the timer. *)
  let with_timeout timeout f =
    match timeout with
    | None -> f ()
    | Some duration ->
        let expired =
          let+ () = Time.sleep_ns (Duration.of_f duration) in
          Error `Timeout
        in
        Lwt.pick [ f (); expired ]

  (* Serves a single request on the established TLS [flow]:
     - reads the request line, bounded by [config.timeout];
     - parses it with [Protocol.make_request];
     - answers with [callback], or with the protocol error response;
     - writes the response.
     [flow] is closed on every exit path (success, error result or
     exception), so no connection is leaked. *)
  let handle_client config callback flow epoch =
    let chan = Channel.create flow in
    let serve () =
      with_timeout config.timeout (fun () -> read_client_req chan) >>= function
      | Ok client_req -> (
          match epoch with
          | Ok ep ->
              let* resp =
                match
                  Protocol.make_request
                    (module Ipaddr)
                    ~port:config.port ~addr:config.addr
                    ~verify_url_host:config.verify_url_host ep client_req
                with
                | Ok req -> callback req
                | Error err -> Protocol.to_response err |> Lwt.return
              in
              write_resp chan resp
          | Error () -> Lwt.return_error `ConnectionClosed)
      | Error `BufferLimitExceeded ->
          Protocol.to_response AboveMaxSize |> write_resp chan
      | Error err -> Lwt.return_error err
    in
    Lwt.finalize serve (fun () -> TLS.close flow)

  (* Performs the TLS handshake on an accepted TCP [flow], then serves the
     request over the resulting TLS flow. On handshake failure the TCP flow
     is closed and the error is returned as [`TLSWriteErr]. *)
  let handler config callback flow =
    TLS.server_of_flow config.tls_config flow >>= function
    | Ok server -> TLS.epoch server |> handle_client config callback server
    | Error err ->
        let+ () = Stack.TCP.close flow in
        Error (`TLSWriteErr err)

  (* Logs a per-connection error at warning level. *)
  let log_err = function
    | `BufferLimitExceeded ->
        (* [handle_client] answers this case itself. Logging here, instead
           of asserting, keeps an unexpected path from crashing the
           listener callback. *)
        Log.warn (fun log -> log "Client request exceeded the size limit")
    | `ConnectionClosed ->
        Log.warn (fun log -> log "Connection has been closed prematurely")
    | `Eof -> Log.warn (fun log -> log "EOF encountered prematurely")
    | `ChannelWriteErr err ->
        Log.warn (fun log ->
            log "ChannelWriteErr: %a" Channel.pp_write_error err)
    | `ChannelErr err ->
        Log.warn (fun log -> log "ChannelErr: %a" Channel.pp_error err)
    | `Timeout ->
        Log.warn (fun log -> log "Timeout while reading client request")
    | `TLSWriteErr err ->
        Log.warn (fun log -> log "TLSWriteErr: %a" TLS.pp_write_error err)

  let run ?(port = 1965) ?(verify_url_host = true) ?config ?timeout ~certchains
      stack callback =
    (* Evaluated first, even when [config] is given, as before. *)
    let certificates = Private.Cert.get_certs ~exn_msg:"run_lwt" certchains in
    let addr =
      match Stack.IP.get_ip (Stack.ip stack) with
      | addr :: _ -> addr
      | [] -> failwith "Mehari_mirage.run: the IP stack has no address"
    in
    let tls_config =
      match config with
      | Some c -> c
      | None ->
          (* Client certificates are requested but never rejected. *)
          Tls.Config.server ~certificates
            ~authenticator:(fun ?ip:_ ~host:_ _ -> Ok None)
            ()
    in
    let config =
      make_config ~addr ~port ~timeout ~tls_config ~verify_url_host
    in
    Logger.info (fun log -> log "Listening on port %i" port);
    Stack.TCP.listen (Stack.tcp stack) ~port (fun flow ->
        handler config callback flow >|= Result.iter_error log_err);
    Stack.listen stack
end