(* Source file server_impl.ml *)
(* Short alias for Mehari's internal API (IO, Handler, Logger_impl, Cert,
   Protocol, ...), used by the signature and the functor below. *)
module Private = Mehari.Private
(* Interface of a Gemini server running on top of a MirageOS TCP/IP stack. *)
module type S = sig
(* The network stack the server listens on. *)
type stack
(* Concurrency monad in which handlers run and [run] returns. *)
module IO : Private.IO
(* Request handler, instantiated with [Ipaddr.t] as its address type
   (presumably the address attached to incoming requests -- confirm against
   Mehari's Handler module). *)
type handler = Ipaddr.t Private.Handler.Make(IO).t
(* [run ?port ?verify_url_host ?config ?timeout ~certchains stack handler]
   serves Gemini requests on [stack] with [handler].
   - [port]: TCP port to listen on.
   - [verify_url_host]: whether the host in a request URL is checked against
     the server's own address.
   - [config]: TLS server configuration; when absent, one is built from
     [certchains].
   - [timeout]: NOTE(review): presumably a limit, in seconds, on reading a
     client's request; check the implementation actually honours it.
   - [certchains]: certificate chains served over TLS. *)
val run :
?port:int ->
?verify_url_host:bool ->
?config:Tls.Config.server ->
?timeout:float ->
certchains:Tls.Config.certchain list ->
stack ->
handler ->
unit IO.t
end
module Make
    (Stack : Tcpip.Stack.V4V6)
    (Time : Mirage_time.S)
    (Logger : Private.Logger_impl.S) :
  S with module IO = Lwt and type stack := Stack.t = struct
  module IO = Lwt

  type handler = Ipaddr.t Private.Handler.Make(IO).t

  module TLS = Tls_mirage.Make (Stack.TCP)
  module Channel = Mirage_channel.Make (TLS)
  module Protocol = Private.Protocol
  open Lwt.Infix
  open Lwt.Syntax

  (* Per-server settings shared by every connection handler. *)
  type config = {
    addr : Ipaddr.t;  (* local address, passed to request validation *)
    port : int;  (* listening port, passed to request validation *)
    timeout : float option;  (* seconds allowed to read a request, if any *)
    tls_config : Tls.Config.server;
    verify_url_host : bool;
  }

  let make_config ~addr ~port ~timeout ~tls_config ~verify_url_host =
    { addr; port; timeout; tls_config; verify_url_host }

  let src = Logs.Src.create "mehari.mirage"

  module Log = (val Logs.src_log src)

  (* Largest request line, in bytes and excluding the terminating CRLF, that
     [read_client_req] accepts. *)
  let max_request_len = 1024

  (* Flush [chan], tagging any write error as [`ChannelWriteErr]. *)
  let flush_channel chan =
    Channel.flush chan |> Lwt_result.map_error (fun e -> `ChannelWriteErr e)

  (* Write the whole of [resp] to [chan], then flush. An [Immediate] response
     is a list of ready-made chunks; a [Delayed] one is produced by handing
     its [body] a function that writes one chunk to [chan]. *)
  let write_resp chan resp =
    let write buf = Channel.write_string chan buf 0 (String.length buf) in
    match Private.view_of_resp resp with
    | Immediate bufs ->
        List.iter write bufs;
        flush_channel chan
    | Delayed { body; _ } ->
        body write;
        flush_channel chan

  (* Read the client's request line from [chan] and return it without the
     terminating CRLF.
     Errors:
     - [`BufferLimitExceeded] when more than [max_request_len] bytes arrive
       before a CRLF;
     - [`Eof] when the client closes the connection first;
     - [`ChannelErr] on a read failure. *)
  let read_client_req chan =
    let buf = Buffer.create max_request_len in
    (* [cr] is true when the previous byte was a CR that has not been stored
       yet: it is discarded only if the next byte completes a CRLF. *)
    let rec loop cr =
      Channel.read_char chan >>= function
      | Ok (`Data _) when Buffer.length buf > max_request_len ->
          Lwt.return_error `BufferLimitExceeded
      | Ok (`Data '\n') when cr -> Buffer.contents buf |> Lwt.return_ok
      | Ok (`Data c) ->
          (* A CR that is not part of a CRLF belongs to the request: keep it
             instead of silently dropping it. This also makes a stream of
             bare CRs count towards the size limit. *)
          if cr then Buffer.add_char buf '\r';
          if Char.equal c '\r' then loop true
          else (
            Buffer.add_char buf c;
            loop false)
      | Ok `Eof -> Lwt.return_error `Eof
      | Error err -> `ChannelErr err |> Lwt.return_error
    in
    loop false

  (* Raised by [with_timeout] when its deadline expires first. *)
  exception Timeout

  (* Run [f ()]. When [timeout] is [Some s], the returned promise fails with
     [Timeout] if [f ()] has not resolved within [s] seconds; [Lwt.pick]
     then cancels the pending [f ()]. With [None] there is no deadline. *)
  let with_timeout timeout f =
    match timeout with
    | None -> f ()
    | Some duration ->
        let deadline =
          let* () = Time.sleep_ns (Duration.of_f duration) in
          Lwt.fail Timeout
        in
        Lwt.pick [ f (); deadline ]

  (* Serve a single request on the established TLS [flow]. [epoch] is the
     flow's TLS session data, used to build the request.
     - A request that is too long is answered with an [AboveMaxSize]
       response.
     - A request that fails validation is answered with the response
       [Protocol.to_response] gives for it.
     - A valid request is passed to [callback].
     [flow] is closed on every exit path: success, error result, or
     exception (e.g. [Timeout]). *)
  let handle_client config callback flow epoch =
    let chan = Channel.create flow in
    let serve () =
      with_timeout config.timeout (fun () -> read_client_req chan) >>= function
      | Ok client_req -> (
          match epoch with
          | Ok ep ->
              let* resp =
                match
                  Protocol.make_request
                    (module Ipaddr)
                    ~port:config.port ~addr:config.addr
                    ~verify_url_host:config.verify_url_host ep client_req
                with
                | Ok req -> callback req
                | Error err -> Protocol.to_response err |> Lwt.return
              in
              write_resp chan resp
          | Error () -> Lwt.return_error `ConnectionClosed)
      | Error `BufferLimitExceeded ->
          Protocol.to_response AboveMaxSize |> write_resp chan
      | Error err -> Lwt.return_error err
    in
    Lwt.finalize serve (fun () -> TLS.close flow)

  (* Perform the TLS handshake on the accepted TCP [flow], then serve one
     request. A handshake failure is reported as [`TLSWriteErr]. *)
  let handler config callback flow =
    TLS.server_of_flow config.tls_config flow >>= function
    | Ok server -> TLS.epoch server |> handle_client config callback server
    | Error err -> `TLSWriteErr err |> Lwt.return_error

  (* Log a per-connection failure as a warning. [`BufferLimitExceeded] is
     answered with a response by [handle_client] and never reaches here. *)
  let log_err = function
    | `BufferLimitExceeded -> assert false
    | `ConnectionClosed ->
        Log.warn (fun log -> log "Connection has been closed prematurly")
    | `Eof -> Log.warn (fun log -> log "EOF encountered prematurly")
    | `ChannelWriteErr err ->
        Log.warn (fun log ->
            log "ChannelWriteErr: %a" Channel.pp_write_error err)
    | `ChannelErr err -> Log.warn (fun log -> log "%a" Channel.pp_error err)
    | `Timeout ->
        Log.warn (fun log -> log "Timeout while reading client request")
    | `TLSWriteErr err ->
        Log.warn (fun log -> log "TLSWriteErr: %a" TLS.pp_write_error err)

  (* Serve Gemini over TLS on [port] of [stack]. The server's address is the
     first IP address the stack reports. [config] overrides the TLS
     configuration built from [certchains] (which accepts any client
     certificate). Errors and timeouts on a connection are logged and
     dropped; any other exception is re-raised in the listener.
     Raises [Failure] if the stack has no IP address. *)
  let run ?(port = 1965) ?(verify_url_host = true) ?config ?timeout ~certchains
      stack callback =
    let certificates = Private.Cert.get_certs ~exn_msg:"run_lwt" certchains in
    let addr =
      match Stack.ip stack |> Stack.IP.get_ip with
      | addr :: _ -> addr
      | [] -> failwith "run: the network stack has no IP address"
    in
    let tls_config =
      match config with
      | Some c -> c
      | None ->
          Tls.Config.server ~certificates
            ~authenticator:(fun ?ip:_ ~host:_ _ -> Ok None)
            ()
    in
    let config =
      make_config ~addr ~port ~timeout ~tls_config ~verify_url_host
    in
    Logger.info (fun log -> log "Listening on port %i" port);
    Stack.TCP.listen (Stack.tcp stack) ~port (fun flow ->
        Lwt.catch
          (fun () -> handler config callback flow >|= Result.iter_error log_err)
          (function
            | Timeout ->
                log_err `Timeout;
                Lwt.return_unit
            | exn -> Lwt.fail exn));
    Stack.listen stack
end