Skip to content

Commit 51807a3

Browse files
committed
dns-client(eio): update to latest happy-eyeballs
1 parent d98235a commit 51807a3

File tree

3 files changed

+169
-153
lines changed

3 files changed

+169
-153
lines changed

eio/client/dns_client_eio.ml

Lines changed: 131 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ type 'a env = <
77
..
88
> as 'a
99

10-
type io_addr = [`Plaintext of Ipaddr.t * int | `Tls of Tls.Config.client * Ipaddr.t * int]
10+
type io_addr = [`Plaintext of Ipaddr.t * int | `Tls of Ipaddr.t * int]
1111
type stack = {
1212
fs : Eio.Fs.dir Eio.Path.t;
1313
sw : Eio.Switch.t;
@@ -30,20 +30,20 @@ module Transport : Dns_client.S
3030
type nonrec stack = stack
3131
type +'a io = 'a
3232

33-
type t =
34-
{ nameservers : Dns.proto * nameservers
35-
; stack : stack
36-
; timeout : Eio.Time.Timeout.t
37-
; mutable ns_connection_condition : Eio.Condition.t option
38-
; mutable ctx : (Dns.proto * context) option
39-
}
33+
type t = {
34+
nameservers : Dns.proto * nameservers ;
35+
stack : stack ;
36+
timeout : Eio.Time.Timeout.t ;
37+
mutable ns_connection_condition : Eio.Condition.t option ;
38+
mutable ctx : (Dns.proto * context) option ;
39+
}
4040

41-
and context =
42-
{ t : t
43-
; mutable requests : Cstruct.t Eio.Promise.u IM.t
44-
; mutable ns_connection: <Eio.Flow.two_way>
45-
; mutable buf : Cstruct.t
46-
}
41+
and context = {
42+
t : t ;
43+
mutable requests : Cstruct.t Eio.Promise.u IM.t ;
44+
mutable ns_connection: <Eio.Flow.two_way> ;
45+
mutable buf : Cstruct.t ;
46+
}
4747

4848
(* DNS nameservers. *)
4949
and nameservers =
@@ -65,54 +65,36 @@ module Transport : Dns_client.S
6565
let ( let* ) = Result.bind
6666
let ( let+ ) r f = Result.map f r
6767

68-
let authenticator =
69-
let authenticator_ref = ref None in
70-
fun () ->
71-
match !authenticator_ref with
72-
| Some x -> x
73-
| None -> match Ca_certs.authenticator () with
74-
| Ok a -> authenticator_ref := Some a ; a
75-
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)
76-
7768
let decode_resolv_conf data =
7869
let* ips = Dns_resolvconf.parse data in
79-
let authenticator = authenticator () in
8070
match ips with
8171
| [] -> Error (`Msg "empty nameservers from resolv.conf")
8272
| ips ->
83-
List.map
84-
(function `Nameserver ip ->
85-
let tls_config = Tls.Config.client ~authenticator ~ip () in
86-
[`Plaintext (ip, 53); `Tls (tls_config, ip, 853)]
87-
)
88-
ips
73+
List.map (function `Nameserver ip -> [`Plaintext (ip, 53); `Tls (ip, 853)]) ips
8974
|> List.flatten
9075
|> Result.ok
9176

9277
let default_resolvers () =
93-
let authenticator = authenticator () in
94-
let peer_name = Dns_client.default_resolver_hostname in
95-
let tls_config = Tls.Config.client ~authenticator ~peer_name () in
96-
List.map (fun ip -> `Tls (tls_config, ip, 853)) Dns_client.default_resolvers
78+
List.map (fun ip -> `Tls (ip, 853)) Dns_client.default_resolvers
9779

9880
let rng = Mirage_crypto_rng.generate ?g:None
9981
let clock = Mtime_clock.elapsed_ns
10082

10183
let create ?nameservers ~timeout stack =
10284
{ nameservers =
10385
(match nameservers with
104-
| Some (`Udp,_) -> invalid_arg "UDP is not supported"
105-
| Some (proto, []) -> proto, Static (default_resolvers ())
106-
| Some (`Tcp, ns) -> `Tcp, Static ns
107-
| None ->
108-
(let* data = read_resolv_conf stack in
109-
let+ ips = decode_resolv_conf data in
110-
(ips, Some (Digest.string data)))
111-
|> function
112-
| Error (`Msg e) ->
113-
Log.warn (fun m -> m "failed to decode %s - %s" stack.resolv_conf e);
114-
(`Tcp, Resolv_conf { ips = default_resolvers (); digest = None})
115-
| Ok(ips, digest) -> `Tcp, Resolv_conf {ips; digest})
86+
| Some (`Udp,_) -> invalid_arg "UDP is not supported"
87+
| Some (proto, []) -> proto, Static (default_resolvers ())
88+
| Some (`Tcp, ns) -> `Tcp, Static ns
89+
| None ->
90+
(let* data = read_resolv_conf stack in
91+
let+ ips = decode_resolv_conf data in
92+
(ips, Some (Digest.string data)))
93+
|> function
94+
| Error (`Msg e) ->
95+
Log.warn (fun m -> m "failed to decode %s - %s" stack.resolv_conf e);
96+
(`Tcp, Resolv_conf { ips = default_resolvers (); digest = None})
97+
| Ok(ips, digest) -> `Tcp, Resolv_conf {ips; digest})
11698
; stack
11799
; timeout = Eio.Time.Timeout.v stack.mono_clock @@ Mtime.Span.of_uint64_ns timeout
118100
; ns_connection_condition = None
@@ -140,65 +122,67 @@ module Transport : Dns_client.S
140122
| _, Static _ -> ()
141123
| _, Resolv_conf resolv_conf ->
142124
(match read_resolv_conf t.stack, resolv_conf.digest with
143-
| Ok data, Some d ->
144-
let digest = Digest.string data in
145-
if Digest.equal digest d then () else update_resolv_conf resolv_conf data digest
146-
| Ok data, None -> update_resolv_conf resolv_conf data (Digest.string data)
147-
| Error _, None -> ()
148-
| Error _, Some _ ->
149-
resolv_conf.digest <- None;
150-
resolv_conf.ips <- default_resolvers ())
125+
| Ok data, Some d ->
126+
let digest = Digest.string data in
127+
if Digest.equal digest d then () else update_resolv_conf resolv_conf data digest
128+
| Ok data, None -> update_resolv_conf resolv_conf data (Digest.string data)
129+
| Error _, None -> ()
130+
| Error _, Some _ ->
131+
resolv_conf.digest <- None;
132+
resolv_conf.ips <- default_resolvers ())
151133

152134
let find_ns t (ip, port) =
153135
List.find
154-
(function `Plaintext (ip', p)
155-
| `Tls (_, ip', p) -> Ipaddr.compare ip ip' = 0 && p = port
156-
)
136+
(function `Plaintext (ip', p) | `Tls (ip', p) -> Ipaddr.compare ip ip' = 0 && p = port)
157137
(nameserver_ips t)
158138

159-
let rec he_handle_actions t he actions : #Eio.Flow.two_way option =
139+
let rec he_handle_actions t he actions =
160140
let fiber_of_action = function
161141
| Happy_eyeballs.Connect (host, id, (ip, port)) ->
162142
fun () ->
163143
let ip' =
164-
begin match ip with
165-
| Ipaddr.V4 ip -> Ipaddr.V4.to_octets ip
166-
| Ipaddr.V6 ip -> Ipaddr.V6.to_octets ip
167-
end
144+
(match ip with
145+
| Ipaddr.V4 ip -> Ipaddr.V4.to_octets ip
146+
| Ipaddr.V6 ip -> Ipaddr.V6.to_octets ip)
168147
|> Eio.Net.Ipaddr.of_raw
169148
in
170149
let stream = `Tcp (ip', port) in
171150
begin try
172-
Eio.Time.Timeout.run_exn t.timeout (fun () ->
173-
let flow = Eio.Net.connect ~sw:t.stack.sw t.stack.net stream in
174-
Log.debug (fun m -> m "he_handle_actions: connected to nameserver (%a)"
175-
Fmt.(pair ~sep:comma Ipaddr.pp int) (ip, port));
176-
let flow =
177-
match find_ns t (ip, port) with
178-
| `Plaintext _ -> (flow :> Eio.Flow.two_way)
179-
| `Tls (config, _,_) -> (Tls_eio.client_of_flow config flow :> Eio.Flow.two_way)
180-
in
181-
Some flow)
182-
with Eio.Time.Timeout ->
183-
Log.debug (fun m -> m "he_handle_actions: connection to nameserver (%a) timed out"
184-
Fmt.(pair ~sep:comma Ipaddr.pp int) (ip, port));
185-
let event = Happy_eyeballs.Connection_failed (host, id, (ip, port)) in
186-
let he, actions = Happy_eyeballs.event he (clock ()) event in
187-
he_handle_actions t he actions
151+
Eio.Time.Timeout.run_exn t.timeout (fun () ->
152+
let flow = Eio.Net.connect ~sw:t.stack.sw t.stack.net stream in
153+
Log.debug (fun m -> m "[he_handle_actions] connected to nameserver (%a)"
154+
Fmt.(pair ~sep:comma Ipaddr.pp int) (ip, port));
155+
Some (ip, port, flow))
156+
with Eio.Time.Timeout as ex ->
157+
Log.debug (fun m -> m "[he_handle_actions] connection to nameserver (%a) timed out"
158+
Fmt.(pair ~sep:comma Ipaddr.pp int) (ip, port));
159+
let err = Printexc.to_string ex in
160+
let event = Happy_eyeballs.Connection_failed (host, id, (ip, port), err) in
161+
let he, actions = Happy_eyeballs.event he (clock ()) event in
162+
he_handle_actions t he actions
188163
end
189-
| Happy_eyeballs.Connect_failed (_host, id) ->
164+
| Connect_failed _ ->
190165
fun () ->
191-
Logs.debug (fun m -> m "he_handle_actions: connection failed %d" id);
166+
Log.debug (fun m -> m "[he_handle_actions] connection failed");
192167
None
193-
| a ->
168+
| Connect_cancelled _ | Resolve_a _ | Resolve_aaaa _ as a ->
194169
fun () ->
195-
Log.warn (fun m -> m "he_handle_actions: ignoring action %a" Happy_eyeballs.pp_action a);
170+
Log.warn (fun m -> m "[he_handle_actions] ignoring action %a" Happy_eyeballs.pp_action a);
196171
None
197172
in
198173
Eio.Fiber.any (List.map fiber_of_action actions)
199174

200175
let to_ip_port =
201-
List.map (function `Plaintext (ip, port) -> (ip, port) | `Tls (_, ip, port) -> (ip, port))
176+
List.map (function `Plaintext (ip, port) -> (ip, port) | `Tls (ip, port) -> (ip, port))
177+
178+
let authenticator =
179+
let authenticator_ref = ref None in
180+
fun () ->
181+
match !authenticator_ref with
182+
| Some x -> x
183+
| None -> match Ca_certs.authenticator () with
184+
| Ok a -> authenticator_ref := Some a ; a
185+
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)
202186

203187
let rec connect t =
204188
Log.debug (fun m -> m "connect : establishing connection to nameservers");
@@ -212,47 +196,56 @@ module Transport : Dns_client.S
212196
t.ns_connection_condition <- Some ns_connection_condition;
213197
maybe_update_nameservers t;
214198
let ns = to_ip_port @@ nameserver_ips t in
199+
let _waiters, id = Happy_eyeballs.Waiter_map.(register () empty) in
215200
let he = Happy_eyeballs.create (clock ()) in
216-
let he, actions = Happy_eyeballs.connect_ip he (clock ()) ~id:1 ns in
201+
let he, actions = Happy_eyeballs.connect_ip he (clock ()) ~id ns in
217202
begin match he_handle_actions t he actions with
218-
| Some conn ->
219-
let context =
220-
{ t = t
221-
; requests = IM.empty
222-
; ns_connection = conn
223-
; buf = Cstruct.empty
224-
}
225-
in
226-
t.ctx <- Some (`Tcp, context);
227-
Eio.Condition.broadcast ns_connection_condition;
228-
Ok (`Tcp, context)
229-
| None ->
230-
t.ns_connection_condition <- None;
231-
Eio.Condition.broadcast ns_connection_condition;
232-
let error_msg =
233-
Fmt.str "unable to connect to nameservers %a"
234-
Fmt.(list ~sep:(any ", ") (pair ~sep:(any ":") Ipaddr.pp int))
235-
(to_ip_port @@ nameserver_ips t)
236-
in
237-
Logs.debug (fun m -> m "connect : %s" error_msg);
238-
Error (`Msg error_msg)
203+
| Some (ip, port, conn) ->
204+
let conn =
205+
match find_ns t (ip, port) with
206+
| `Plaintext _ -> (conn :> Eio.Flow.two_way)
207+
| `Tls (_,_) ->
208+
let authenticator = authenticator () in
209+
let config = Tls.Config.(client ~authenticator ()) in
210+
(Tls_eio.client_of_flow config conn :> Eio.Flow.two_way)
211+
in
212+
let context =
213+
{ t = t
214+
; requests = IM.empty
215+
; ns_connection = conn
216+
; buf = Cstruct.empty
217+
}
218+
in
219+
t.ctx <- Some (`Tcp, context);
220+
Eio.Condition.broadcast ns_connection_condition;
221+
Ok (`Tcp, context)
222+
| None ->
223+
t.ns_connection_condition <- None;
224+
Eio.Condition.broadcast ns_connection_condition;
225+
let error_msg =
226+
Fmt.str "unable to connect to nameservers %a"
227+
Fmt.(list ~sep:(any ", ") (pair ~sep:(any ":") Ipaddr.pp int))
228+
(to_ip_port @@ nameserver_ips t)
229+
in
230+
Log.debug (fun m -> m "connect : %s" error_msg);
231+
Error (`Msg error_msg)
239232
end
240233

241234
let recv_data t flow id : unit =
242235
let buf = Cstruct.create 512 in
243-
Logs.debug (fun m -> m "recv_data (%d): t.buf.len %d" id (Cstruct.length t.buf));
236+
Log.debug (fun m -> m "recv_data (%X): t.buf.len %d" id (Cstruct.length t.buf));
244237
let got = Eio.Flow.single_read flow buf in
245-
Logs.debug (fun m -> m "recv_data (%d): got %d" id got);
238+
Log.debug (fun m -> m "recv_data (%X): got %d" id got);
246239
let buf = Cstruct.sub buf 0 got in
247240
t.buf <- if Cstruct.length t.buf = 0 then buf else Cstruct.append t.buf buf;
248-
Logs.debug (fun m -> m "recv_data (%d): t.buf.len %d" id (Cstruct.length t.buf))
241+
Log.debug (fun m -> m "recv_data (%X): t.buf.len %d" id (Cstruct.length t.buf))
249242

250243
let rec recv_packet t ns_connection request_id =
251-
Logs.debug (fun m -> m "recv_packet (%d): recv_packet" request_id);
244+
Log.debug (fun m -> m "recv_packet (%X)" request_id);
252245
let buf_len = Cstruct.length t.buf in
253246
if buf_len > 2 then (
254247
let packet_len = Cstruct.BE.get_uint16 t.buf 0 in
255-
Logs.debug (fun m -> m "recv_packet (%d): packet_len %d" request_id (Cstruct.length t.buf));
248+
Log.debug (fun m -> m "recv_packet (%X): packet_len %d" request_id (Cstruct.length t.buf));
256249
if buf_len - 2 >= packet_len then
257250
let packet, rest =
258251
if buf_len - 2 = packet_len
@@ -261,13 +254,13 @@ module Transport : Dns_client.S
261254
in
262255
t.buf <- rest;
263256
let response_id = Cstruct.BE.get_uint16 packet 2 in
264-
Logs.debug (fun m -> m "recv_packet (%d): response %d" request_id response_id);
257+
Log.debug (fun m -> m "recv_packet (%X): got response %X" request_id response_id);
265258
if response_id = request_id
266259
then packet
267260
else begin
268261
(match IM.find response_id t.requests with
269-
| r -> Eio.Promise.resolve r packet
270-
| exception Not_found -> ());
262+
| r -> Eio.Promise.resolve r packet
263+
| exception Not_found -> ());
271264
recv_packet t ns_connection request_id
272265
end
273266
else begin
@@ -282,28 +275,28 @@ module Transport : Dns_client.S
282275

283276
let validate_query_packet tx =
284277
if Cstruct.length tx > 4 then Ok () else
285-
Error (`Msg "Invalid DNS query packet (data length <= 4)")
278+
Error (`Msg "Invalid DNS query packet (data length <= 4)")
286279

287280
let send_recv ctx packet =
288281
let* () = validate_query_packet packet in
289282
try
290283
let request_id = Cstruct.BE.get_uint16 packet 2 in
291284
Eio.Time.Timeout.run_exn ctx.t.timeout (fun () ->
292-
Eio.Flow.write ctx.ns_connection [packet];
293-
Logs.debug (fun m -> m "send_recv (%d): request" request_id);
294-
let response_p, response_r = Eio.Promise.create () in
295-
ctx.requests <- IM.add request_id response_r ctx.requests;
296-
let response =
297-
Eio.Fiber.first
298-
(fun () -> recv_packet ctx ctx.ns_connection request_id)
299-
(fun () -> Eio.Promise.await response_p)
300-
in
301-
Logs.debug (fun m -> m "send_recv (%d): got response" request_id);
302-
Ok response
303-
)
285+
Eio.Flow.write ctx.ns_connection [packet];
286+
Log.debug (fun m -> m "send_recv (%X): wrote request" request_id);
287+
let response_p, response_r = Eio.Promise.create () in
288+
ctx.requests <- IM.add request_id response_r ctx.requests;
289+
let response =
290+
Eio.Fiber.first
291+
(fun () -> recv_packet ctx ctx.ns_connection request_id)
292+
(fun () -> Eio.Promise.await response_p)
293+
in
294+
Log.debug (fun m -> m "send_recv (%X): got response" request_id);
295+
Ok response
296+
)
304297
with
305298
| Eio.Time.Timeout -> Error (`Msg "DNS request timeout")
306-
| exn -> Error (`Msg (Printexc.to_string_default exn))
299+
(* | exn -> Error (`Msg (Printexc.to_string exn)) *)
307300

308301
let close _ = ()
309302
let bind a f = f a
@@ -314,15 +307,15 @@ include Dns_client.Make(Transport)
314307

315308
let run ?(resolv_conf = "/etc/resolv.conf") (env: _ env) f =
316309
Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env (fun () ->
317-
Eio.Switch.run (fun sw ->
318-
let stack =
319-
{ sw
320-
; mono_clock = env#mono_clock
321-
; net = env#net
322-
; resolv_conf
323-
; fs = env#fs
324-
}
325-
in
326-
f stack
310+
Eio.Switch.run (fun sw ->
311+
let stack =
312+
{ sw
313+
; mono_clock = env#mono_clock
314+
; net = env#net
315+
; resolv_conf
316+
; fs = env#fs
317+
}
318+
in
319+
f stack
320+
)
327321
)
328-
)

0 commit comments

Comments
 (0)