1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
open! Core
open! Async
open! Import
module Connection = Ssl.Connection
open Require_explicit_time_source
(* Shut down the network side of a connection: close [outer_wr] first, passing a
   [force_close] deferred that fires 30 seconds after this call on [time_source], and
   only once that close has completed, close [outer_rd]. *)
let teardown_connection ~outer_rd ~outer_wr ~time_source =
  let grace_period = Time_ns.Span.of_sec 30. in
  Writer.close ~force_close:(Time_source.after time_source grace_period) outer_wr
  >>= fun () -> Reader.close outer_rd
;;
(* Expose the network reader/writer pair as two pipes.

   - Incoming: bytes are moved from [outer_rd] into a fresh pipe with [Reader.transfer].
     When the transfer ends, the connection is torn down, and only after that teardown
     completes is the pipe closed.
   - Outgoing: [Writer.pipe outer_wr].  Once that pipe is closed, we wait until it has
     been flushed downstream or 30 seconds have elapsed on [time_source], whichever comes
     first.  Then the connection is torn down in the background.
     NOTE(review): this presumably exists because closing the [Writer.pipe] does not by
     itself close [outer_wr]; confirm against the [Writer.pipe] documentation.

   Returns [(incoming reader end, outgoing writer pipe)]. *)
let reader_writer_pipes ~outer_rd ~outer_wr ~time_source =
  let teardown () = teardown_connection ~outer_rd ~outer_wr ~time_source in
  let flushed_or_grace_elapsed outgoing =
    Deferred.choose
      [ Deferred.choice (Time_source.after time_source (Time_ns.Span.of_sec 30.)) Fn.id
      ; Deferred.choice
          (Pipe.downstream_flushed outgoing)
          (fun (_ : Pipe.Flushed_result.t) -> ())
      ]
  in
  let incoming_r, incoming_w = Pipe.create () in
  let outgoing = Writer.pipe outer_wr in
  Reader.transfer outer_rd incoming_w
  >>> (fun () -> teardown () >>> fun () -> Pipe.close incoming_w);
  Pipe.closed outgoing
  >>> (fun () ->
  flushed_or_grace_elapsed outgoing >>> fun () -> don't_wait_for (teardown ()));
  incoming_r, outgoing
;;
(* Turn the application-side pipes into an Async [Reader.t] / [Writer.t] pair.

   When the reader finishes closing, [app_rd] is closed for reading as well.
   NOTE(review): presumably [Reader.of_pipe] does not do this on its own; confirm.
   The writer is told not to raise when its consumer leaves. *)
let reader_writer_of_pipes ~app_rd ~app_wr =
  Reader.of_pipe (Info.of_string "async_ssl_tls_reader") app_rd
  >>= fun inner_rd ->
  upon (Reader.close_finished inner_rd) (fun () -> Pipe.close_read app_rd);
  Writer.of_pipe (Info.of_string "async_ssl_tls_writer") app_wr
  >>| fun (inner_wr, _) ->
  Writer.set_raise_when_consumer_leaves inner_wr false;
  inner_rd, inner_wr
;;
(* Run the user handler [f] under [Monitor.protect] (with [~run:`Now ~rest:`Log]).

   However [f] finishes, the cleanup does the following in order:
   1. wait until [inner_wr] is flushed or has failed;
   2. close [inner_wr];
   3. concurrently close [inner_rd] and wait for [outer_wr] to finish closing.

   [outer_rd] is accepted for symmetry with the other helpers but is not used. *)
let call_handler_and_cleanup ~outer_rd:_ ~outer_wr ~inner_rd ~inner_wr f =
  let cleanup () =
    Writer.flushed_or_failed_unit inner_wr
    >>= fun () ->
    Writer.close inner_wr
    >>= fun () -> Deferred.all_unit [ Reader.close inner_rd; Writer.close_finished outer_wr ]
  in
  Monitor.protect ~run:`Now ~rest:`Log ~finally:cleanup f
;;
(* Core TLS plumbing shared by the client and server paths.

   Builds pipes between the network side ([outer_rd] / [outer_wr]) and the application
   side, then runs [negotiate] over them.  Negotiation is bounded by [timeout] (default
   30 seconds) on [time_source].

   - On negotiation error or timeout, the network connection is torn down and the error
     is raised.  A timeout produces the message "Timeout exceeded".
   - On success, the application pipes are wrapped as a reader/writer pair and
     [f conn inner_rd inner_wr] is run.  [call_handler_and_cleanup] cleans up after it. *)
let wrap_connection
  ?(timeout = Time_ns.Span.of_sec 30.)
  outer_rd
  outer_wr
  ~negotiate
  ~f
  ~time_source
  =
  let net_to_ssl, ssl_to_net = reader_writer_pipes ~outer_rd ~outer_wr ~time_source in
  let app_to_ssl, app_wr = Pipe.create () in
  let app_rd, ssl_to_app = Pipe.create () in
  let to_or_error = function
    | `Result connection -> connection
    | `Timeout -> error_s [%message "Timeout exceeded"]
  in
  Time_source.with_timeout
    time_source
    timeout
    (negotiate ~app_to_ssl ~ssl_to_app ~net_to_ssl ~ssl_to_net)
  >>| to_or_error
  >>= function
  | Error error ->
    teardown_connection ~outer_rd ~outer_wr ~time_source >>| fun () -> Error.raise error
  | Ok conn ->
    reader_writer_of_pipes ~app_rd ~app_wr
    >>= fun (inner_rd, inner_wr) ->
    call_handler_and_cleanup ~outer_rd ~outer_wr ~inner_rd ~inner_wr (fun () ->
      f conn inner_rd inner_wr)
;;
(* Server-side TLS over an already-accepted connection.

   Negotiates with [Ssl.server] using the values in [tls_settings].  Before running
   [f conn r w], it inspects [Ssl.Connection.peer_certificate]:
   - a peer certificate reported as an error raises that error;
   - no certificate, or a valid one, proceeds to [f]. *)
let wrap_server_connection tls_settings outer_rd outer_wr ~f ~time_source =
  let module Settings = Config.Server in
  let check_peer_then_run conn r w =
    match Ssl.Connection.peer_certificate conn with
    | Some (Error error) -> Error.raise error
    | Some (Ok (_ : Ssl.Certificate.t)) | None -> f conn r w
  in
  let negotiate =
    Ssl.server
      ?ca_file:(Settings.ca_file tls_settings)
      ?ca_path:(Settings.ca_path tls_settings)
      ?verify_modes:(Settings.verify_modes tls_settings)
      ?override_security_level:(Settings.override_security_level tls_settings)
      ~version:(Settings.tls_version tls_settings)
      ~options:(Settings.tls_options tls_settings)
      ~crt_file:(Settings.crt_file tls_settings)
      ~key_file:(Settings.key_file tls_settings)
      ~allowed_ciphers:(Settings.allowed_ciphers tls_settings)
      ()
  in
  wrap_connection outer_rd outer_wr ~negotiate ~f:check_peer_then_run ~time_source
;;
(* Start a TCP server on [where_to_listen].  Every accepted connection is wrapped in
   TLS with [wrap_server_connection] and handed to [f sock].

   Per connection, the time source is chosen as follows:
   - [advance_clock_before_tls_negotiation = Some (ts, delay)]: [ts] is first advanced
     by [delay] with [Time_source.advance_by_alarms_by], then a read-only view of [ts]
     is used;
   - otherwise [Time_source.wall_clock ()] is used. *)
let listen
  ?max_connections
  ?backlog
  ?buffer_age_limit
  ?advance_clock_before_tls_negotiation
  ?socket
  tls_settings
  where_to_listen
  ~on_handler_error
  ~f
  =
  let time_source_for_connection () =
    match advance_clock_before_tls_negotiation with
    | Some (time_source, delay) ->
      Time_source.advance_by_alarms_by time_source delay
      >>| fun () -> Time_source.read_only time_source
    | None -> return (Time_source.wall_clock ())
  in
  let handle_connection sock r w =
    time_source_for_connection ()
    >>= fun time_source -> wrap_server_connection tls_settings r w ~f:(f sock) ~time_source
  in
  Tcp.Server.create
    ?max_connections
    ?backlog
    ?buffer_age_limit
    ?socket
    ~on_handler_error
    where_to_listen
    handle_connection
;;
(* Client-side TLS over an already-established connection.

   Negotiates with [Ssl.client] using the values in [tls_settings], bounded by [timeout]
   (see [wrap_connection] for the default).  It then runs the settings' [verify_callback]
   on the connection:
   - if the callback returns an error, it raises "Connection verification failed.";
   - otherwise it runs [f conn inner_rd inner_wr].

   [~time_source] is taken explicitly, just as in [wrap_server_connection].  Previously
   it was accepted only implicitly, through the closure returned by partially applying
   [wrap_connection].  The inferred type is identical, so existing callers that pass
   [~time_source] are unaffected. *)
let wrap_client_connection ?timeout tls_settings outer_rd outer_wr ~f ~time_source =
  let ca_file = Config.Client.ca_file tls_settings in
  let ca_path = Config.Client.ca_path tls_settings in
  let version = Config.Client.tls_version tls_settings in
  let options = Config.Client.tls_options tls_settings in
  let crt_file = Config.Client.crt_file tls_settings in
  let key_file = Config.Client.key_file tls_settings in
  let hostname = Config.Client.remote_hostname tls_settings in
  let allowed_ciphers = Config.Client.allowed_ciphers tls_settings in
  let verify_modes = Config.Client.verify_modes tls_settings in
  let verify_callback = Config.Client.verify_callback tls_settings in
  let session = Config.Client.session tls_settings in
  let connection_name = Config.Client.connection_name tls_settings in
  let override_security_level = Config.Client.override_security_level tls_settings in
  wrap_connection
    ?timeout
    outer_rd
    outer_wr
    ~negotiate:
      (Ssl.client
         ?ca_file
         ?ca_path
         ?crt_file
         ?key_file
         ?hostname
         ?session
         ?name:connection_name
         ?override_security_level
         ~verify_modes
         ~allowed_ciphers
         ~version
         ~options
         ())
    ~f:(fun conn inner_rd inner_wr ->
      (* User-supplied verification runs before the connection is handed to [f]. *)
      match%bind verify_callback conn with
      | Error connection_verification_error ->
        raise_s
          [%message
            "Connection verification failed." (connection_verification_error : Error.t)]
      | Ok () -> f conn inner_rd inner_wr)
    ~time_source
;;
(* Connect over TCP to [where_to_connect], then run TLS client negotiation and call
   [f socket conn r w].

   [timeout], when given, is a single budget covering both phases:
   - the TCP connect receives all of it, converted to a float span;
   - TLS negotiation receives what is left, i.e. [timeout] minus the time elapsed on
     [time_source] since this function was called.
   NOTE(review): the remainder can be zero or negative if the TCP phase used up the
   budget; confirm that this yields an immediate TLS timeout. *)
let with_connection ?interrupt ?timeout tls_settings where_to_connect ~f ~time_source =
  let started_at = Time_source.now time_source in
  let remaining_after_tcp budget =
    let elapsed = Time_ns.diff (Time_source.now time_source) started_at in
    Time_ns.Span.(budget - elapsed)
  in
  let tcp_timeout = Option.map timeout ~f:Time_ns.Span.to_span_float_round_nearest in
  Async.Tcp.with_connection
    ?interrupt
    ?timeout:tcp_timeout
    where_to_connect
    (fun socket outer_rd outer_wr ->
       let tls_timeout = Option.map timeout ~f:remaining_after_tcp in
       wrap_client_connection
         ?timeout:tls_timeout
         tls_settings
         outer_rd
         outer_wr
         ~f:(f socket)
         ~time_source)
;;
(* The generic versions of [with_connection] and [listen], bound here before any
   wall-clock-pinned rebinding.  They still take an explicit [~time_source] (and, for
   [listen], [?advance_clock_before_tls_negotiation]).  Presumably this lets tests drive
   them with a controllable clock. *)
module For_testing = struct
  let with_connection = with_connection
  let listen = listen
end
(* Public entry points.  These rebindings shadow the generic definitions above with
   versions pinned to the wall clock, so callers no longer pass [~time_source].
   [For_testing], defined earlier, captured the unpinned versions before this point. *)
let time_source = Time_source.wall_clock ()

(* With [?advance_clock_before_tls_negotiation:None], [listen] uses
   [Time_source.wall_clock ()] for every accepted connection. *)
let listen = listen ?advance_clock_before_tls_negotiation:None
let wrap_server_connection = wrap_server_connection ~time_source
let with_connection = with_connection ~time_source
let wrap_client_connection = wrap_client_connection ~time_source
(* Lower-level client helpers that hand the connection back to the caller instead of
   scoping it to a callback. *)
module Expert = struct
  (* Connect with the wall-clock [with_connection] and return [(sock, conn, r, w)] as soon
     as the handler is invoked.  The connection stays open until either [r] or [w]
     finishes closing.
     NOTE(review): if connecting or negotiating fails before the handler runs, the
     result is never determined; the error presumably surfaces through the enclosing
     monitor via [don't_wait_for]. *)
  let connect ?interrupt ?timeout tls_settings where_to_connect =
    let connected = Ivar.create () in
    let hold_open sock conn r w =
      Ivar.fill_exn connected (sock, conn, r, w);
      Deferred.any [ Reader.close_finished r; Writer.close_finished w ]
    in
    with_connection ?interrupt ?timeout tls_settings where_to_connect ~f:hold_open
    |> don't_wait_for;
    Ivar.read connected
  ;;

  (* Like the wall-clock [wrap_client_connection], except that [f] returns
     [(res, `Do_not_close_until d)].
     - [res] is returned as soon as [f] produces it.
     - The connection is held open until [d] is determined.
     - [`Connection_closed] carries the deferred of the whole wrapped connection,
       including its cleanup. *)
  let wrap_client_connection_and_stay_open tls_settings outer_rd outer_wr ~f =
    let user_result = Ivar.create () in
    let run_and_publish conn r w =
      f conn r w
      >>= fun (res, `Do_not_close_until close_when) ->
      Ivar.fill_exn user_result res;
      close_when
    in
    let finished =
      wrap_client_connection tls_settings outer_rd outer_wr ~f:run_and_publish
    in
    Ivar.read user_result >>| fun result -> result, `Connection_closed finished
  ;;
end