1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
open Lwt.Infix
(** [pp_sockaddr ppf sockaddr] prints a human-readable description of
    [sockaddr] on [ppf]: the path of a Unix domain socket, or the address and
    port of an internet socket. *)
let pp_sockaddr ppf sockaddr =
  match sockaddr with
  | Lwt_unix.ADDR_INET (ip, port) ->
    let host = Unix.string_of_inet_addr ip in
    Fmt.pf ppf "TCP %s:%d" host port
  | Lwt_unix.ADDR_UNIX path ->
    Fmt.pf ppf "unix domain socket %s" path
(** [safe_close fd] closes [fd] and discards any exception raised while doing
    so; the returned promise always resolves to [()]. *)
let safe_close fd =
  let close () = Lwt_unix.close fd in
  (* Best-effort close: a failure here is deliberately ignored. *)
  let ignore_failure _exn = Lwt.return_unit in
  Lwt.catch close ignore_failure
(** [server_socket ~systemd sock] returns a listening socket.

    - When [systemd] is [true], the socket is the single file descriptor
      reported by [Vmm_unix.sd_listen_fds ()].
    - Otherwise a Unix domain stream socket is created at
      [Vmm_core.socket_path sock] (any existing file at that path is unlinked
      first), marked close-on-exec, bound, and put into listening mode with a
      backlog of 1.

    While binding, the group bits of the process umask are cleared so that
    group permissions on the socket file are not masked out. The previous
    umask is restored whether or not [bind] succeeds. If setting up the
    socket fails, the socket is closed and the promise is rejected with the
    original exception.

    @raise Failure if [systemd] is [true] and socket activation did not
    provide exactly one file descriptor. *)
let server_socket ~systemd sock =
  if systemd then begin
    match Vmm_unix.sd_listen_fds () with
    | Some [fd] -> Lwt.return (Lwt_unix.of_unix_file_descr fd)
    | _ -> failwith "Systemd socket activation error"
  end else
    let name = Vmm_core.socket_path sock in
    (Lwt_unix.file_exists name >>= function
      | true -> Lwt_unix.unlink name
      | false -> Lwt.return_unit) >>= fun () ->
    let s = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in
    Lwt.catch
      (fun () ->
         Lwt_unix.set_close_on_exec s ;
         (* [Unix.umask] returns the previous mask, so the first call is only
            there to read it; the second installs the mask used for bind. *)
         let old_umask = Unix.umask 0 in
         ignore (Unix.umask (old_umask land 0o707) : int) ;
         Lwt.finalize
           (fun () -> Lwt_unix.(bind s (ADDR_UNIX name)))
           (fun () ->
              (* Restore the process-wide umask even when bind fails. *)
              ignore (Unix.umask old_umask : int) ;
              Lwt.return_unit)
         >|= fun () ->
         Logs.app (fun m -> m "listening on %s" name) ;
         Lwt_unix.listen s 1 ;
         s)
      (fun e ->
         (* Do not leak the descriptor when bind or listen failed. *)
         safe_close s >>= fun () ->
         Lwt.fail e)
(** [connect addrtype sockaddr] creates a stream socket in the [addrtype]
    domain, marks it close-on-exec and connects it to [sockaddr].

    Resolves to [Some socket] on success. If connecting fails, a warning is
    logged, the socket is closed and the result is [None]. *)
let connect addrtype sockaddr =
  let sock = Lwt_unix.(socket addrtype SOCK_STREAM 0) in
  Lwt_unix.set_close_on_exec sock ;
  let attempt () =
    Lwt_unix.connect sock sockaddr >>= fun () ->
    Lwt.return (Some sock)
  in
  let on_failure exn =
    Logs.warn (fun m -> m "error %s connecting to socket %a"
                  (Printexc.to_string exn) pp_sockaddr sockaddr);
    safe_close sock >>= fun () ->
    Lwt.return None
  in
  Lwt.catch attempt on_failure
(** [pp_process_status ppf status] prints on [ppf] how a process ended or was
    stopped: its exit code, or the signal that killed or stopped it. *)
let pp_process_status ppf status =
  let pp_signal = Fmt.Dump.signal in
  match status with
  | Unix.WSTOPPED signal -> Fmt.pf ppf "stopped by signal %a" pp_signal signal
  | Unix.WSIGNALED signal -> Fmt.pf ppf "killed by signal %a" pp_signal signal
  | Unix.WEXITED code -> Fmt.pf ppf "exited with %d" code
(** [ret status] converts a [Unix.process_status] into the matching
    polymorphic variant: [`Exit code], [`Signal signal] or [`Stop signal]. *)
let ret status =
  match status with
  | Unix.WSTOPPED signal -> `Stop signal
  | Unix.WSIGNALED signal -> `Signal signal
  | Unix.WEXITED code -> `Exit code
(** [waitpid pid] waits for process [pid] with [Lwt_unix.waitpid] and no
    flags. Resolves to [Ok (pid, status)] on success. An [EINTR] error is
    logged at debug level and the wait is retried; any other exception is
    logged as an error and yields [Error ()]. *)
let rec waitpid pid =
  let wait () =
    Lwt_unix.waitpid [] pid >>= fun result ->
    Lwt.return (Ok result)
  in
  let recover exn =
    match exn with
    | Unix.Unix_error (Unix.EINTR, _, _) ->
      Logs.debug (fun m -> m "EINTR in waitpid(), %d retrying" pid) ;
      waitpid pid
    | other ->
      Logs.err (fun m -> m "error %s in waitpid() %d"
                   (Printexc.to_string other) pid) ;
      Lwt.return (Error ())
  in
  Lwt.catch wait recover
(** [wait_and_clear pid] waits for process [pid] and reports how it ended,
    converted with [ret]. If waiting itself failed, the error is logged and
    [`Exit 23] is returned instead. *)
let wait_and_clear pid =
  Logs.debug (fun m -> m "waitpid() for pid %d" pid) ;
  waitpid pid >>= function
  | Ok (_, status) ->
    Logs.debug (fun m -> m "pid %d exited: %a" pid pp_process_status status) ;
    Lwt.return (ret status)
  | Error () ->
    Logs.err (fun m -> m "waitpid() for %d returned error" pid) ;
    Lwt.return (`Exit 23)
(** [read_wire s] reads one length-prefixed message from [s]: a 4-byte
    big-endian length followed by that many payload bytes, which are decoded
    with [Vmm_asn.wire_of_cstruct].

    Resolves to:
    - [Ok wire] when a message was read and decoded (a warning is logged if
      its header version is not the current one);
    - [Error `Eof] if a read returned no data, or the announced length is
      zero;
    - [Error `Toomuch] if a read returned more bytes than requested, or the
      announced length has its high bit set and so cannot be represented as
      a non-negative 32-bit integer;
    - [Error `Exception] if reading raised an exception (in which case [s]
      is closed) or the payload could not be decoded. *)
let read_wire s =
  let header = Bytes.create 4 in
  (* Fill [buf] with exactly [len] bytes starting at [off], looping over
     short reads. *)
  let rec read_exactly buf off len =
    Lwt.catch (fun () ->
        Lwt_unix.read s buf off len >>= function
        | 0 ->
          Logs.debug (fun m -> m "end of file while reading") ;
          Lwt.return (Error `Eof)
        | n when n = len -> Lwt.return (Ok ())
        | n when n < len -> read_exactly buf (off + n) (len - n)
        | _ ->
          Logs.err (fun m -> m "read too much, shouldn't happen)") ;
          Lwt.return (Error `Toomuch))
      (fun e ->
         let err = Printexc.to_string e in
         Logs.err (fun m -> m "exception %s while reading" err) ;
         safe_close s >|= fun () ->
         Error `Exception)
  in
  read_exactly header 0 4 >>= function
  | Error e -> Lwt.return (Error e)
  | Ok () ->
    let len = Cstruct.BE.get_uint32 (Cstruct.of_bytes header) 0 in
    let sign = Int32.compare len 0l in
    if sign = 0 then
      Lwt.return (Error `Eof)
    else if sign < 0 then begin
      (* The unsigned 32-bit length is carried in a signed int32, so lengths
         of 2^31 and above show up as negative values here. *)
      Logs.err (fun m -> m "announced message length %lu is too large" len) ;
      Lwt.return (Error `Toomuch)
    end else begin
      let size = Int32.to_int len in
      let payload = Bytes.create size in
      read_exactly payload 0 size >|= function
      | Error e -> Error e
      | Ok () ->
        match Vmm_asn.wire_of_cstruct (Cstruct.of_bytes payload) with
        | Error (`Msg msg) ->
          Logs.err (fun m -> m "error %s while parsing data" msg) ;
          Error `Exception
        | (Ok (hdr, _)) as w ->
          if not Vmm_commands.(is_current hdr.version) then
            Logs.warn (fun m -> m "version mismatch, received %a current %a"
                          Vmm_commands.pp_version hdr.Vmm_commands.version
                          Vmm_commands.pp_version Vmm_commands.current);
          w
    end
(** [write_raw s buf] sends the whole of [buf] on [s], repeating
    [Lwt_unix.send] (with no flags) until every byte has been written.
    Resolves to [Ok ()] on success. If sending raises, the exception is
    logged, [s] is closed and the result is [Error `Exception]. *)
let write_raw s buf =
  let on_error exn =
    Logs.err (fun m -> m "exception %s while writing" (Printexc.to_string exn)) ;
    safe_close s >>= fun () ->
    Lwt.return (Error `Exception)
  in
  (* Send [remaining] bytes of [buf] starting at [pos]. *)
  let rec send_from pos remaining =
    Lwt.catch
      (fun () ->
         Lwt_unix.send s buf pos remaining [] >>= fun sent ->
         match sent = remaining with
         | true -> Lwt.return (Ok ())
         | false -> send_from (pos + sent) (remaining - sent))
      on_error
  in
  send_from 0 (Bytes.length buf)
(** [write_wire s wire] encodes [wire] with [Vmm_asn.wire_to_cstruct],
    prefixes it with its length as a 4-byte big-endian integer, and sends the
    result on [s] using [write_raw]. *)
let write_wire s wire =
  let payload = Vmm_asn.wire_to_cstruct wire in
  let header = Cstruct.create 4 in
  let payload_len = Int32.of_int (Cstruct.length payload) in
  Cstruct.BE.set_uint32 header 0 payload_len ;
  Cstruct.append header payload |> Cstruct.to_bytes |> write_raw s
(** [read_from_file file] reads the whole content of [file] into a
    [Cstruct.t]. The size is taken from [Lwt_unix.stat] and the file is read
    until that many bytes have been received.

    Any failure (stat, open, read, or the file turning out to be shorter than
    the size reported by stat) is logged and yields [Cstruct.empty]. The file
    descriptor is closed on both the success and the failure path. *)
let read_from_file file =
  Lwt.catch (fun () ->
      Lwt_unix.stat file >>= fun stat ->
      let size = stat.Lwt_unix.st_size in
      Lwt_unix.openfile file Lwt_unix.[O_RDONLY] 0 >>= fun fd ->
      Lwt.catch (fun () ->
          let buf = Bytes.create size in
          let rec read off =
            if off = size then
              Lwt.return_unit
            else
              Lwt_unix.read fd buf off (size - off) >>= function
              | 0 ->
                (* End of file before [size] bytes were read: without this
                   case the loop would call read forever. *)
                Lwt.fail End_of_file
              | n -> read (off + n)
          in
          read 0 >>= fun () ->
          safe_close fd >|= fun () ->
          Cstruct.of_bytes buf)
        (fun e ->
           Logs.err (fun m -> m "exception %s while reading %s" (Printexc.to_string e) file) ;
           safe_close fd >|= fun () ->
           Cstruct.empty))
    (fun e ->
       Logs.err (fun m -> m "exception %s while reading %s" (Printexc.to_string e) file) ;
       Lwt.return Cstruct.empty)