Source file dns_resolver_cache.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
(* (c) 2017, 2018 Hannes Mehnert, all rights reserved *)

open Dns

(* Sets of domain names. *)
module N = Domain_name.Set

(* Log source of this module, registered as dns_resolver_cache. *)
let src = Logs.Src.create ~doc:"DNS resolver cache" "dns_resolver_cache"
module Log = (val Logs.src_log src : Logs.LOG)

(* Pretty-printer for the cache lookup errors [`Cache_miss] and [`Cache_drop].
   The leading underscore silences the unused-value warning. *)
let _pp_err ppf err =
  let msg =
    match err with
    | `Cache_miss -> "cache miss"
    | `Cache_drop -> "cache drop"
  in
  Fmt.string ppf msg

(* Print a question as the domain name followed by the query type in
   parentheses. *)
let pp_question ppf (qname, qtype) =
  Fmt.pf ppf "%a (%a)"
    Domain_name.pp qname
    Packet.Question.pp_qtype qtype

(* [nsec_no_ds t ts name] searches the cache for an NSEC record at [name],
   then at each ancestor up to and including the root. The first NSEC record
   found decides: true if its type bitmap lacks the DS bit, false if it has it.
   Returns false when no NSEC record is cached along the way. *)
let nsec_no_ds t ts name =
  let ds_bit = Rr_map.to_int Ds in
  let rec climb n =
    match Dns_cache.get t ts n Nsec |> snd with
    | Ok (`Entry (_, nsec), _) -> not (Bit_map.mem ds_bit nsec.Nsec.types)
    | _ ->
      (* no usable NSEC here: retry at the parent, stop after the root *)
      Domain_name.count_labels n >= 1 && climb (Domain_name.drop_label_exn n)
  in
  climb name

(* [nsec3_covering t ts name] looks up cached NSEC3 records, starting at
   [name] and dropping one label at a time while the cache has none. For the
   first name that has cached NSEC3 data (either [name] itself or an ancestor),
   it returns true if one of those records either covers the hash of that
   name (owner < hash < next owner, using Domain_name.compare) or has exactly
   that hash as owner and no DS bit in its type bitmap.
   Returns false when no NSEC3 data is found before the examined name is
   reduced to a single label (the root is never examined on the way up).
   NOTE(review): the strict owner < hash < next owner test does not handle
   the last NSEC3 record of a chain, whose next owner wraps around to the
   first hash; confirm whether that case is needed. *)
let nsec3_covering t ts name =
  let rec up name =
    match snd (Dns_cache.get_nsec3 t ts name) with
    | Ok nsec3 ->
      (* salt and iterations are taken from the first record only.
         NOTE(review): List.hd raises Failure if the cache returns Ok [];
         confirm Dns_cache.get_nsec3 never does. *)
      let Nsec3.{ iterations ; salt ; _ } = snd (List.hd nsec3) in
      (* the zone is assumed to be the parent of the examined name.
         NOTE(review): drop_label_exn raises if [name] is the root. *)
      let soa_name = Domain_name.drop_label_exn name in
      let hashed_name = Dnssec.nsec3_hashed_name salt iterations ~soa_name name in
      (* inside this function, [name] is the NSEC3 owner name and shadows the
         examined name *)
      List.exists (fun (name, nsec3) ->
          (* turn the raw next hashed owner into a Base32 label below the zone,
             so that it is comparable with owner names *)
          let hashed_next_owner =
            Domain_name.prepend_label_exn soa_name
              (Base32.encode (Cstruct.to_string nsec3.Nsec3.next_owner_hashed))
          in
          (* TODO non-wc-expanded nsec3 only?? *)
          (Domain_name.compare name hashed_name < 0 &&
           Domain_name.compare hashed_name hashed_next_owner < 0) ||
          (* TODO wc nsec3 as well? *)
          (Domain_name.compare name hashed_name = 0 &&
           not (Bit_map.mem (Rr_map.to_int Ds) nsec3.types))
        )
        nsec3
    | Error _ ->
      (* nothing cached here: retry at the parent, but stop at one label *)
      if Domain_name.count_labels name > 1 then
        up (Domain_name.drop_label_exn name)
      else
        false
  in
  up name

(* [find_nearest_ns rng ip_proto dnssec ts t name] walks from [name] towards
   the root and finds the closest zone for which the cache [t] has NS records.
   Returns
   - [`HaveIP (zone, ip)]: [ip] is an address of a name server of [zone]
     (a root server address when the walk reaches the root);
   - [`NeedDnskey (zone, ip)]: as above, but [dnssec] is set and no DNSKEY
     for [zone] is cached;
   - [`NeedAddress host]: the chosen name server [host] of the zone has no
     cached address and is not inside that zone (so no glue is expected).
   When several candidates exist, one is chosen at random using [rng]. *)
let find_nearest_ns rng ip_proto dnssec ts t name =
  (* random element of a list; no random draw for 0 or 1 elements *)
  let pick = function
    | [] -> None
    | [ x ] -> Some x
    | xs -> Some (List.nth xs (Randomconv.int ~bound:(List.length xs) rng))
  in
  (* name server host names cached for [name], or [] *)
  let find_ns name = match snd (Dns_cache.get t ts name Ns) with
    | Ok (`Entry (_, names), _) -> Domain_name.Host_set.elements names
    | _ -> []
  (* true when a DNSKEY lookup for [name] in the cache returns Ok *)
  and find_dnskey name = match snd (Dns_cache.get t ts name Dnskey) with
    | Ok _ -> true
    | _ -> false
  (* true when the DS status of [name] is known from the cache: a DS lookup
     returns Ok, or cached NSEC / NSEC3 data indicates there is no DS.
     Despite the name, true means the zone may be used as is. *)
  and need_to_query_for_ds name = match snd (Dns_cache.get t ts name Ds) with
    | Ok _ -> true
    | Error _ -> nsec_no_ds t ts name || nsec3_covering t ts name
  (* cached addresses of [name], filtered by [ip_proto]. Both the A and the
     AAAA lookups are always performed. *)
  and find_address name =
    let ip4s =
      Result.fold
        ~ok:(function
            | `Entry (_, ips), _ ->
              List.map (fun ip -> Ipaddr.V4 ip) (Ipaddr.V4.Set.elements ips)
            | _ -> [])
        ~error:(fun _ -> [])
        (snd (Dns_cache.get t ts name A))
    and ip6s =
      Result.fold
        ~ok:(function
            | `Entry (_, ips), _ ->
              List.map (fun ip -> Ipaddr.V6 ip) (Ipaddr.V6.Set.elements ips)
            | _ -> [])
        ~error:(fun _ -> [])
        (snd (Dns_cache.get t ts name Aaaa))
    in
    match ip_proto with
    | `Both -> ip4s @ ip6s
    | `Ipv4_only -> ip4s
    | `Ipv6_only -> ip6s
  in
  (* with DNSSEC, the zone DNSKEY must be cached before the server is used *)
  let have_ip_or_dnskey name ip =
    if dnssec && not (find_dnskey name) then
      `NeedDnskey (name, ip)
    else
      `HaveIP (name, ip)
  in
  (* at the root, use a random root server address; otherwise continue with
     [f] on the parent of [nam] *)
  let or_root f nam =
    if Domain_name.(equal root nam) then
      match pick (Dns_resolver_root.ips ip_proto) with
      | None -> assert false
      | Some ip -> have_ip_or_dnskey nam ip
    else
      f (Domain_name.drop_label_exn nam)
  in
  let rec go nam =
    (* Log.warn (fun m -> m "go %a" Domain_name.pp nam); *)
    match pick (find_ns nam) with
    | None ->
      (* Log.warn (fun m -> m "go no NS for %a" Domain_name.pp nam); *)
      or_root go nam
    (* with DNSSEC, a zone whose DS status is unknown is skipped in favour of
       its parent, presumably so that the DS can be requested there. Note that
       pick has already drawn from [rng] at this point. *)
    | Some _ when dnssec && not (need_to_query_for_ds nam) ->
      or_root go nam
    | Some ns ->
      let host = Domain_name.raw ns in
      match pick (find_address host) with
      | None ->
        (* Log.warn (fun m -> m "go no address for NS %a (for %a)"
                      Domain_name.pp host
                      Domain_name.pp nam); *)
        if Domain_name.is_subdomain ~subdomain:ns ~domain:nam then
          (* we actually need glue *)
          or_root go nam
        else
          `NeedAddress host
      | Some ip ->
        (* Log.warn (fun m -> m "go address for NS %a (for %a): %a (dnskey %B)"
                      Domain_name.pp host
                      Domain_name.pp nam
                      Ipaddr.pp ip
                      (find_dnskey nam)); *)
        have_ip_or_dnskey nam ip
  in
  go name

(* [resolve t ~dnssec ~rng ip_proto ts name typ] decides which question to
   send to which name server to make progress on resolving [(name, typ)],
   using the delegations cached in [t].
   Returns [(zone, qname, qtypes, ip, t)]: send [qname] with [qtypes] to [ip],
   a server for [zone]. This is one of the following:
   - the original question;
   - a DNSKEY query for [zone], when [dnssec] is set and that DNSKEY is not
     cached;
   - an address query (A and/or AAAA, depending on [ip_proto]) for a name
     server whose address is missing.
   If looking up name server addresses leads back to a name server already
   visited (e.g. two zones whose name servers live in each other, with no
   glue cached), the address query is sent to a root server instead of
   looping forever. *)
let resolve t ~dnssec ~rng ip_proto ts name typ =
  (* the standard recursive algorithm *)
  let addresses = match ip_proto with
    | `Both -> [`K (Rr_map.K A); `K (Rr_map.K Aaaa)]
    | `Ipv4_only -> [`K (Rr_map.K A)]
    | `Ipv6_only -> [`K (Rr_map.K Aaaa)]
  in
  (* random root server address, picked like in find_nearest_ns *)
  let root_ip () =
    match Dns_resolver_root.ips ip_proto with
    | [] -> assert false
    | [ ip ] -> ip
    | ips -> List.nth ips (Randomconv.int ~bound:(List.length ips) rng)
  in
  (* with DNSSec:
     - input is qname and qtyp
     - (a) we have (validated) NS record (+DNSKEY) for zone -> move along
     - (b) we miss a NS entry -> drop label and find one
     ---> we also want to collect DS and DNSKEY entries (or non-existence of DS)
     ---> we get DS by dnssec ok in EDNS
     ---> we may have unsigned NS (+ glue), and need to ask the NS for NS (+dnssec)
     ---> we may have unsigned glue, and need to go down for signed A/AAAA
  *)
  (* [visited] holds the name server hosts whose address we already tried to
     find; it only grows, so the recursion terminates *)
  let rec go t visited types name =
    Log.debug (fun m -> m "go %a" Domain_name.pp name) ;
    match find_nearest_ns rng ip_proto dnssec ts t (Domain_name.raw name) with
    | `NeedAddress ns when N.mem ns visited ->
      Log.warn (fun m -> m "resolve: cycle while looking for the address of %a, asking a root server"
                   Domain_name.pp ns);
      Domain_name.root, ns, addresses, root_ip (), t
    | `NeedAddress ns -> go t (N.add ns visited) addresses ns
    | `NeedDnskey (zone, ip) -> zone, zone, [`K (Rr_map.K Dnskey)], ip, t
    | `HaveIP (zone, ip) -> zone, name, types, ip, t
  in
  go t N.empty [typ] name

(* Whether a cache rank carries the signed flag. Only the authoritative answer
   and authoritative authority ranks hold such a flag; every other rank counts
   as unsigned. *)
let is_signed rank =
  match rank with
  | Dns_cache.AuthoritativeAnswer flag
  | Dns_cache.AuthoritativeAuthority flag -> flag
  | _ -> false

(* Wrap an (owner, SOA) pair into a name map holding only that SOA record. *)
let to_map (owner, soa) =
  Name_rr_map.singleton owner Soa soa

(* [follow_cname t ts typ ~name ttl ~alias] follows, through the cache, the
   CNAME chain starting with the record [name CNAME alias] (TTL [ttl]),
   looking for records of type [typ] at its end.
   Returns [`Out (rcode, signed, answer, authority), t] when the chain ends in
   the cache (an entry, NXDomain, NoData, ServFail, or a detected cycle).
   [answer] holds every CNAME of the chain, each with its own cached TTL, plus
   the final entry if any. [signed] is true only if every record looked up
   while following the chain is signed; the signedness of the initial
   [name CNAME alias] record is not known here and is not included.
   Returns [`Query n, t] when [n] is missing from the cache. *)
let follow_cname t ts typ ~name ttl ~alias =
  let rec follow t ~signed acc name =
    let t, r = Dns_cache.get_or_cname t ts name typ in
    match r with
    | Error _ ->
      Log.debug (fun m -> m "follow_cname: cache miss, need to query %a"
                     Domain_name.pp name);
      `Query name, t
    | Ok (`Alias (link_ttl, alias), r) ->
      let signed = signed && is_signed r in
      (* each link keeps the TTL of its own cached CNAME record *)
      let acc' = Domain_name.Map.add name (Rr_map.singleton Cname (link_ttl, alias)) acc in
      if Domain_name.Map.mem alias acc' then begin
        Log.warn (fun m -> m "follow_cname: cycle detected") ;
        `Out (Rcode.NoError, signed, acc', Name_rr_map.empty), t
      end else begin
        Log.debug (fun m -> m "follow_cname: alias to %a, follow again"
                       Domain_name.pp alias);
        follow t ~signed acc' alias
      end
    | Ok (`Entry v, r) ->
      let acc' = Domain_name.Map.add name Rr_map.(singleton typ v) acc in
      Log.debug (fun m -> m "follow_cname: entry found, returning");
      `Out (Rcode.NoError, signed && is_signed r, acc', Name_rr_map.empty), t
    | Ok (`No_domain res, r) ->
      Log.debug (fun m -> m "follow_cname: nodom");
      `Out (Rcode.NXDomain, signed && is_signed r, acc, to_map res), t
    | Ok (`No_data res, r) ->
      Log.debug (fun m -> m "follow_cname: nodata");
      `Out (Rcode.NoError, signed && is_signed r, acc, to_map res), t
    | Ok (`Serv_fail res, r) ->
      Log.debug (fun m -> m "follow_cname: servfail") ;
      `Out (Rcode.ServFail, signed && is_signed r, acc, to_map res), t
  in
  let initial = Name_rr_map.singleton name Cname (ttl, alias) in
  follow t ~signed:true initial alias

(* [answer t ts name typ] tries to answer the question [(name, typ)] from the
   cache [t] at timestamp [ts].
   Returns [`Packet (flags, data), t] when the cache holds enough to reply
   (positive answer, NXDomain, NoData, ServFail, or a followed CNAME chain),
   or [`Query n, t] when [n] is missing from the cache. The returned [t] is
   the cache after the lookups.
   The Authentic_data flag follows the signedness of the records used. For a
   followed CNAME chain, the initial alias record must be signed as well as
   whatever follow_cname reports for the rest of the chain. *)
let answer t ts name typ =
  (* Build the reply. [flags] always carry Recursion_available and
     Recursion_desired, plus Authentic_data when [signed]. A NoError reply
     becomes an [`Answer]; any other rcode becomes an [`Rcode_error], which
     carries the answer and authority sections only when they are not empty.
     [_t] and [_add] are unused. *)
  let packet _t _add rcode ~signed answer authority =
    let data = (answer, authority) in
    let flags =
      let f = Packet.Flags.(add `Recursion_available (singleton `Recursion_desired)) in
      if signed then
        Packet.Flags.add `Authentic_data f
      else
        f
    (* XXX: we should look for a fixpoint here ;) *)
    (*    and additional, t = if add then additionals t ts answer else [], t *)
    and data = match rcode with
      | Rcode.NoError -> `Answer data
      | x ->
        let data = if Packet.Answer.is_empty data then None else Some data in
        `Rcode_error (x, Opcode.Query, data)
    in
    flags, data
  in
  match typ with
  | `Any ->
    let t, r = Dns_cache.get_any t ts name in
    begin match r with
      | Error _e ->
        (* Log.warn (fun m -> m "error %a while looking up %a, query"
                      pp_err e pp_question (name, typ)); *)
        `Query name, t
      | Ok (`No_domain res, r) ->
        Log.debug (fun m -> m "no domain while looking up %a" pp_question (name, typ));
        `Packet (packet t false Rcode.NXDomain ~signed:(is_signed r) Domain_name.Map.empty (to_map res)), t
      | Ok (`Entries rr_map, r) ->
        Log.debug (fun m -> m "entries while looking up %a" pp_question (name, typ));
        let data = Domain_name.Map.singleton name rr_map in
        `Packet (packet t true Rcode.NoError ~signed:(is_signed r) data Domain_name.Map.empty), t
    end
  | `K (Rr_map.K ty) ->
    let t, r = Dns_cache.get_or_cname t ts name ty in
    match r with
    | Error _e ->
      (* Log.warn (fun m -> m "error %a while looking up %a, query"
                    pp_err e pp_question (name, typ)); *)
      `Query name, t
    | Ok (`No_domain res, r) ->
      Log.debug (fun m -> m "no domain while looking up %a" pp_question (name, typ));
      `Packet (packet t false Rcode.NXDomain ~signed:(is_signed r) Domain_name.Map.empty (to_map res)), t
    | Ok (`No_data res, r) ->
      Log.debug (fun m -> m "no data while looking up %a" pp_question (name, typ));
      `Packet (packet t false Rcode.NoError ~signed:(is_signed r) Domain_name.Map.empty (to_map res)), t
    | Ok (`Serv_fail res, _) ->
      Log.debug (fun m -> m "serv fail while looking up %a" pp_question (name, typ));
      `Packet (packet t false Rcode.ServFail ~signed:false Domain_name.Map.empty (to_map res)), t
    | Ok (`Alias (ttl, alias), r) ->
      begin
        Log.debug (fun m -> m "alias while looking up %a" pp_question (name, typ));
        match ty with
        | Cname ->
          let data = Name_rr_map.singleton name Cname (ttl, alias) in
          `Packet (packet t false Rcode.NoError ~signed:(is_signed r) data Domain_name.Map.empty), t
        | ty ->
          match follow_cname t ts ty ~name ttl ~alias with
          | `Out (rcode, chain_signed, an, au), t ->
            (* the initial CNAME is part of the answer, so it must be signed
               too before the reply may be flagged as authentic *)
            let signed = is_signed r && chain_signed in
            `Packet (packet t true rcode ~signed an au), t
          | `Query n, t -> `Query n, t
      end
    | Ok (`Entry v, r) ->
      Log.debug (fun m -> m "entry while looking up %a" pp_question (name, typ));
      let data = Name_rr_map.singleton name ty v in
      `Packet (packet t true Rcode.NoError ~signed:(is_signed r) data Domain_name.Map.empty), t

(* [handle_query t ~dnssec ~rng ip_proto ts (qname, qtype)] answers
   [(qname, qtype)] from the cache when possible, returning
   [`Reply (flags, data), t]. Otherwise it returns
   [`Query (zone, (name, types), ip), t]: the question to send next and the
   address [ip] of a server of [zone] to send it to.
   For a DS question, the delegation is searched from the parent of the
   missing name (when that name has more than one label), and the question
   is still asked for the missing name itself. *)
let handle_query t ~dnssec ~rng ip_proto ts (qname, qtype) =
  Log.info (fun m -> m "handle query %a (%a)"
                Domain_name.pp qname Packet.Question.pp_qtype qtype);
  match answer t ts qname qtype with
  | `Packet (flags, data), t ->
    Log.info (fun m -> m "reply for %a (%a)" Domain_name.pp qname Packet.Question.pp_qtype qtype);
    `Reply (flags, data), t
  | `Query name, t ->
    Log.info (fun m -> m "query for %a (%a): %a" Domain_name.pp qname Packet.Question.pp_qtype qtype Domain_name.pp name);
    (* DS should be requested at the parent *)
    let ds_at_parent =
      Domain_name.count_labels name > 1 && qtype = `K (Rr_map.K Ds)
    in
    let name' = if ds_at_parent then Domain_name.drop_label_exn name else name in
    let zone, name'', types, ip, t = resolve t ~dnssec ~rng ip_proto ts name' qtype in
    (* Only map the original DS question (returned as [name'] with [qtype])
       back to the child name. A DNSKEY query for the parent zone, or an
       address query for a name server, keeps its own name even when that
       name equals [name']. *)
    let name'' =
      if ds_at_parent && types = [ qtype ] && Domain_name.equal name'' name' then
        name
      else
        name''
    in
    Log.info (fun m -> m "resolve returned zone %a query %a (%a), ip %a"
                   Domain_name.pp zone Domain_name.pp name''
                   Fmt.(list ~sep:(any ", ") Packet.Question.pp_qtype) types
                   Ipaddr.pp ip);
    `Query (zone, (name'', types), ip), t