Source file vmm_tls_lwt.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
open Lwt.Infix
(* [read_tls_chunk t] reads one length-prefixed frame from the TLS session
   [t]. A frame is a 4-byte big-endian signed 32-bit length followed by that
   many payload bytes; the payload is returned as a string. Short reads are
   retried until the requested number of bytes has arrived.

   Errors:
   - [`Tls_eof] if a read returns 0 bytes (end of file) before the
     requested bytes have arrived;
   - [`Eof] if the length prefix is zero or negative;
   - [`Toomuch] if a read reports more bytes than were requested;
   - [`Exception] if the underlying TLS read raises (e.g. a TLS failure or
     a Unix error). The exception is logged and not re-raised. *)
let read_tls_chunk t =
  (* A single TLS read into [buf] starting at [off]. Exceptions are turned
     into [Error `Exception] so that callers always receive a result instead
     of a failed promise, matching how [write_tls_chunk] treats writes. *)
  let read_some buf off =
    Lwt.catch
      (fun () -> Tls_lwt.Unix.read t ~off buf >|= fun n -> Ok n)
      (function
        | Tls_lwt.Tls_failure a ->
          Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ;
          Lwt.return (Error `Exception)
        | e ->
          Logs.err (fun m -> m "TLS read exception %s" (Printexc.to_string e)) ;
          Lwt.return (Error `Exception))
  in
  (* Fill [buf] from offset [off] until [tot] bytes have been read in total. *)
  let rec r_n buf off tot =
    let l = tot - off in
    if l = 0 then
      Lwt.return (Ok ())
    else
      read_some buf off >>= function
      | Error e -> Lwt.return (Error e)
      | Ok 0 ->
        Logs.debug (fun m -> m "TLS: end of file") ;
        Lwt.return (Error `Tls_eof)
      | Ok x when x = l -> Lwt.return (Ok ())
      | Ok x when x < l -> r_n buf (off + x) tot
      | Ok _ ->
        Logs.err (fun m -> m "TLS: read too much, shouldn't happen") ;
        Lwt.return (Error `Toomuch)
  in
  let hdr = Bytes.create 4 in
  r_n hdr 0 4 >>= function
  | Error e -> Lwt.return (Error e)
  | Ok () ->
    let len = Bytes.get_int32_be hdr 0 in
    if len > 0l then begin
      let len = Int32.to_int len in
      let payload = Bytes.create len in
      r_n payload 0 len >|= function
      | Error e -> Error e
      (* [payload] is not mutated after this point, so the zero-copy
         conversion is safe. *)
      | Ok () -> Ok (Bytes.unsafe_to_string payload)
    end else
      Lwt.return (Error `Eof)
(* [read_tls t] reads one frame from [t] with [read_tls_chunk] and decodes it
   using [Vmm_asn.wire_of_str].
   - Errors from [read_tls_chunk] are returned unchanged.
   - A decoding failure is logged and reported as [Error `Exception].
   - A decoded message whose header version is not the current one is still
     returned, after a warning is logged. *)
let read_tls t =
  let decode data =
    match Vmm_asn.wire_of_str data with
    | Error (`Msg msg) ->
      Logs.err (fun m -> m "error %s while parsing data" msg) ;
      Error `Exception
    | Ok ((hdr, _) as wire) ->
      let version = hdr.Vmm_commands.version in
      if not (Vmm_commands.is_current version) then
        Logs.warn (fun m -> m "version mismatch, received %a current %a"
                      Vmm_commands.pp_version version
                      Vmm_commands.pp_version Vmm_commands.current) ;
      Ok wire
  in
  read_tls_chunk t >|= fun chunk -> Result.bind chunk decode
(* [write_tls_chunk s data] sends [data] on the TLS session [s] as one frame:
   a 4-byte big-endian length prefix followed by the bytes of [data].

   Returns [Ok ()] once the write completes. Any exception raised by the TLS
   write is logged and reported as [Error `Exception].

   A payload whose length does not fit in a signed 32-bit integer is rejected
   with [Error `Exception] before anything is written. Its length prefix would
   otherwise wrap silently and corrupt the framing seen by the peer.

   NOTE(review): an empty [data] yields a zero length prefix, which
   [read_tls_chunk] reports as [`Eof] on the receiving side. *)
let write_tls_chunk s data =
  let len = String.length data in
  let len32 = Int32.of_int len in
  (* [Int32.of_int] wraps on overflow. A round-trip mismatch therefore means
     [len] is not representable. This check stays correct on 32-bit hosts,
     where comparing against [Int32.max_int] converted to [int] would not. *)
  if Int32.to_int len32 <> len then begin
    Logs.err (fun m -> m "TLS write: payload of %d bytes too large for a frame" len) ;
    Lwt.return (Error `Exception)
  end else begin
    (* Build prefix and payload in a single buffer. *)
    let buf = Bytes.create (4 + len) in
    Bytes.set_int32_be buf 0 len32 ;
    Bytes.blit_string data 0 buf 4 len ;
    (* [buf] is not mutated after this point, so the zero-copy conversion
       is safe. *)
    let frame = Bytes.unsafe_to_string buf in
    Lwt.catch
      (fun () -> Tls_lwt.Unix.write s frame >|= fun () -> Ok ())
      (function
        | Tls_lwt.Tls_failure a ->
          Logs.err (fun m -> m "tls failure: %s" (Tls.Engine.string_of_failure a)) ;
          Lwt.return (Error `Exception)
        | e ->
          Logs.err (fun m -> m "TLS write exception %s" (Printexc.to_string e)) ;
          Lwt.return (Error `Exception))
  end
(* [write_tls s wire] serialises [wire] with [Vmm_asn.wire_to_str] and sends
   the result on [s] as a single frame via [write_tls_chunk]. *)
let write_tls s wire =
  write_tls_chunk s (Vmm_asn.wire_to_str wire)
(* [close tls] closes the TLS session [tls] on a best-effort basis. Any
   exception raised while closing is deliberately discarded, and the promise
   always resolves to [()]. *)
let close tls =
  let ignore_failure _exn = Lwt.return () in
  Lwt.catch (fun () -> Tls_lwt.Unix.close tls) ignore_failure