Source file dns_client_lwt.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
open Lwt.Infix
(** Lwt/Unix transport for {!Dns_client.Make}.

    Talks to a single nameserver over a TCP or UDP socket. Each connection
    gets one overall timeout budget (from [create]'s [~timeout]) that is
    shared by [connect], [send] and [recv]: every operation deducts the time
    it took from what is left. *)
module Transport : Dns_client.S
  with type io_addr = Lwt_unix.inet_addr * int
   and type +'a io = 'a Lwt.t
   and type stack = unit
= struct
  type io_addr = Lwt_unix.inet_addr * int
  type ns_addr = [`TCP | `UDP] * io_addr
  type +'a io = 'a Lwt.t
  type stack = unit

  (* Client configuration: default nameserver and the total timeout budget
     (nanoseconds) granted to each connection. *)
  type t = {
    nameserver : ns_addr ;
    timeout_ns : int64 ;
  }

  (* Per-connection state; [timeout_ns] holds the *remaining* budget and is
     decremented by [with_timeout]. *)
  type context = { t : t ; fd : Lwt_unix.file_descr ; timeout_ns : int64 ref }

  (* [read_file file] returns the full contents of [file], or an error
     message when it cannot be opened or read. Only the exceptions the
     channel functions are documented to raise are caught. *)
  let read_file file =
    match open_in file with
    | exception Sys_error _ -> Error (`Msg ("Error opening file " ^ file))
    | fh ->
      match really_input_string fh (in_channel_length fh) with
      | content ->
        close_in_noerr fh ;
        Ok content
      | exception (Sys_error _ | End_of_file | Invalid_argument _) ->
        close_in_noerr fh ;
        Error (`Msg ("Error reading file: " ^ file))

  (* Address of the first IPv4 [nameserver] entry in /etc/resolv.conf.
     Falls back to [Dns_client.default_resolver] when the file is missing,
     unparsable, or lists no IPv4 nameserver. *)
  let default_nameserver_ip () =
    let first_v4 =
      Rresult.R.(
        read_file "/etc/resolv.conf" >>= fun data ->
        Dns_resolvconf.parse data >>= fun nameservers ->
        List.fold_left (fun acc ns ->
            match acc, ns with
            | Ok _, _ -> acc (* keep the first match *)
            | Error _, `Nameserver (Ipaddr.V4 ip) -> Ok ip
            | Error _, _ -> acc)
          (Error (`Msg "no nameserver")) nameservers)
    in
    match first_v4 with
    | Ok ip -> Ipaddr_unix.V4.to_inet_addr ip
    | Error _ -> Unix.inet_addr_of_string Dns_client.default_resolver

  (* [create ?nameserver ~timeout ()] builds a client configuration. When
     [nameserver] is omitted, TCP port 53 on the system's first IPv4
     resolver is used; resolv.conf is only read in that case. *)
  let create ?nameserver ~timeout () =
    let nameserver =
      match nameserver with
      | Some ns -> ns
      | None -> (`TCP, (default_nameserver_ip (), 53))
    in
    { nameserver ; timeout_ns = timeout }

  let nameserver { nameserver ; _ } = nameserver

  let rng = Mirage_crypto_rng.generate ?g:None

  let clock = Mtime_clock.elapsed_ns

  (* Clamp a nanosecond budget at zero. [Duration] interprets int64 values
     as unsigned, so a negative budget would otherwise become a sleep of
     roughly 584 years, i.e. no timeout at all. *)
  let non_negative ns = if Int64.compare ns 0L < 0 then 0L else ns

  (* [with_timeout ctx f] races [f] against the remaining budget of [ctx],
     yielding [Error (`Msg "DNS request timeout")] if the budget runs out
     first, and deducts the elapsed time from the budget either way. *)
  let with_timeout ctx f =
    let remaining = non_negative !(ctx.timeout_ns) in
    let timeout =
      Lwt_unix.sleep (Duration.to_f remaining) >|= fun () ->
      Error (`Msg "DNS request timeout")
    in
    let start = clock () in
    Lwt.pick [ f ; timeout ] >|= fun result ->
    let stop = clock () in
    ctx.timeout_ns := non_negative (Int64.sub remaining (Int64.sub stop start));
    result

  (* Best-effort close: errors from closing are deliberately ignored. *)
  let close { fd ; _ } =
    Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit)

  (* [send ctx tx] writes all of [tx] to the socket within the remaining
     budget. A stream (TCP) socket may accept only part of the buffer per
     call, so writing continues until everything is out; a write that makes
     no progress is reported as an error instead of looping forever. *)
  let send ctx tx =
    let buf = Cstruct.to_bytes tx in
    let len = Bytes.length buf in
    let rec send_from off =
      Lwt_unix.send ctx.fd buf off (len - off) [] >>= fun n ->
      let off = off + n in
      if off >= len then Lwt_result.return ()
      else if n = 0 then
        Lwt_result.fail
          (`Msg (Printf.sprintf "DNS send: wrote %d of %d bytes" off len))
      else send_from off
    in
    Lwt.catch
      (fun () -> with_timeout ctx (send_from 0))
      (fun e -> Lwt_result.fail (`Msg (Printexc.to_string e)))

  (* [recv ctx] performs a single read of at most 2048 bytes within the
     remaining budget; a zero-length read is reported as "Empty response".
     NOTE(review): presumably the generic client reassembles TCP messages
     across several calls; a UDP datagram larger than 2048 bytes would be
     truncated here — confirm the advertised payload size. *)
  let recv ctx =
    let recv_buffer = Bytes.make 2048 '\000' in
    Lwt.catch
      (fun () ->
         with_timeout ctx
           (Lwt_unix.recv ctx.fd recv_buffer 0 (Bytes.length recv_buffer) []
            >>= fun read_len ->
            if read_len > 0 then
              Lwt_result.return (Cstruct.of_bytes ~len:read_len recv_buffer)
            else
              Lwt_result.fail (`Msg "Empty response")))
      (fun e -> Lwt_result.fail (`Msg (Printexc.to_string e)))

  let bind = Lwt.bind

  let lift = Lwt.return

  (* [connect ?nameserver t] opens a socket to [nameserver] (default: the
     one stored in [t]) and returns a context carrying [t]'s full timeout
     budget. The socket domain is derived from the address, so both IPv4
     and IPv6 nameservers work. On failure the socket is closed and an
     error is returned. *)
  let connect ?nameserver:ns t =
    let (proto, (server, port)) =
      match ns with None -> nameserver t | Some x -> x
    in
    Lwt.catch
      (fun () ->
         let proto_name, socket_type =
           match proto with
           | `UDP -> "udp", Lwt_unix.SOCK_DGRAM
           | `TCP -> "tcp", Lwt_unix.SOCK_STREAM
         in
         Lwt_unix.getprotobyname proto_name >>= fun pe ->
         let addr = Lwt_unix.ADDR_INET (server, port) in
         let socket =
           Lwt_unix.socket (Unix.domain_of_sockaddr addr) socket_type
             pe.Lwt_unix.p_proto
         in
         let ctx = { t ; fd = socket ; timeout_ns = ref t.timeout_ns } in
         Lwt.catch
           (fun () ->
              with_timeout ctx
                (Lwt_unix.connect socket addr >|= fun () -> Ok ()) >>= function
              | Ok () -> Lwt_result.return ctx
              | Error e -> close ctx >|= fun () -> Error e)
           (fun e ->
              close ctx >|= fun () ->
              Error (`Msg (Printexc.to_string e))))
      (fun e -> Lwt_result.fail (`Msg (Printexc.to_string e)))
end
(* Instantiate the generic DNS client over the Lwt/Unix [Transport] above
   and re-export everything it defines from this module. *)
include Dns_client.Make(Transport)
(* Initialise the Lwt-driven RNG as a side effect of loading this module.
   NOTE(review): presumably required because [Transport.rng] calls
   [Mirage_crypto_rng.generate] without an explicit generator (so it draws
   from the default one) — confirm against mirage-crypto-rng-lwt. *)
let () = Mirage_crypto_rng_lwt.initialize ()