@@ -30,19 +30,20 @@ module Transport : Dns_client.S
30
30
type nonrec stack = stack
31
31
type +'a io = 'a
32
32
33
- type t = {
33
+ type t = {
34
34
nameservers : Dns .proto * nameservers ;
35
35
stack : stack ;
36
36
timeout : Eio.Time.Timeout .t ;
37
37
mutable ns_connection_condition : Eio.Condition .t option ;
38
38
mutable ctx : (Dns .proto * context ) option ;
39
39
}
40
40
41
- and context = {
41
+ and context = {
42
42
t : t ;
43
43
mutable requests : Cstruct .t Eio.Promise .u IM .t ;
44
44
mutable ns_connection : < Eio.Flow .two_way > ;
45
- mutable buf : Cstruct .t ;
45
+ mutable recv_buf : Cstruct .t ;
46
+ mutable closed : bool ;
46
47
}
47
48
48
49
(* DNS nameservers. *)
@@ -161,10 +162,7 @@ module Transport : Dns_client.S
161
162
let he, actions = Happy_eyeballs. event he (clock () ) event in
162
163
he_handle_actions t he actions
163
164
end
164
- | Connect_failed _ ->
165
- fun () ->
166
- Log. debug (fun m -> m " [he_handle_actions] connection failed" );
167
- None
165
+ | Connect_failed _ -> fun () -> None
168
166
| Connect_cancelled _ | Resolve_a _ | Resolve_aaaa _ as a ->
169
167
fun () ->
170
168
Log. warn (fun m -> m " [he_handle_actions] ignoring action %a" Happy_eyeballs. pp_action a);
@@ -185,7 +183,6 @@ module Transport : Dns_client.S
185
183
| Error `Msg m -> invalid_arg (" failed to load trust anchors: " ^ m)
186
184
187
185
let rec connect t =
188
- Log. debug (fun m -> m " connect : establishing connection to nameservers" );
189
186
match t.ctx, t.ns_connection_condition with
190
187
| Some ctx , _ -> Ok ctx
191
188
| None , Some condition ->
@@ -209,16 +206,18 @@ module Transport : Dns_client.S
209
206
let config = Tls.Config. (client ~authenticator () ) in
210
207
(Tls_eio. client_of_flow config conn :> Eio.Flow.two_way )
211
208
in
212
- let context =
209
+ let ctx =
213
210
{ t = t
214
211
; requests = IM. empty
215
212
; ns_connection = conn
216
- ; buf = Cstruct. empty
213
+ ; recv_buf = Cstruct. create 2048
214
+ ; closed = false
217
215
}
218
216
in
219
- t.ctx < - Some (`Tcp , context);
217
+ t.ctx < - Some (`Tcp , ctx);
218
+ Eio.Fiber. fork ~sw: t.stack.sw ( fun () -> recv_dns_packets ctx );
220
219
Eio.Condition. broadcast ns_connection_condition;
221
- Ok (`Tcp , context )
220
+ Ok (`Tcp , ctx )
222
221
| None ->
223
222
t.ns_connection_condition < - None ;
224
223
Eio.Condition. broadcast ns_connection_condition;
@@ -231,72 +230,67 @@ module Transport : Dns_client.S
231
230
Error (`Msg error_msg)
232
231
end
233
232
234
- let recv_data t flow id : unit =
235
- let buf = Cstruct. create 512 in
236
- Log. debug (fun m -> m " recv_data (%X): t.buf.len %d" id (Cstruct. length t.buf));
237
- let got = Eio.Flow. single_read flow buf in
238
- Log. debug (fun m -> m " recv_data (%X): got %d" id got);
239
- let buf = Cstruct. sub buf 0 got in
240
- t.buf < - if Cstruct. length t.buf = 0 then buf else Cstruct. append t.buf buf;
241
- Log. debug (fun m -> m " recv_data (%X): t.buf.len %d" id (Cstruct. length t.buf))
233
+ and recv_dns_packets ?(recv_data = Cstruct. empty) (ctx : context ) =
242
234
243
- let rec recv_packet t ns_connection request_id =
244
- Log. debug (fun m -> m " recv_packet (%X)" request_id);
245
- let buf_len = Cstruct. length t.buf in
246
- if buf_len > 2 then (
247
- let packet_len = Cstruct.BE. get_uint16 t.buf 0 in
248
- Log. debug (fun m -> m " recv_packet (%X): packet_len %d" request_id (Cstruct. length t.buf));
249
- if buf_len - 2 > = packet_len then
250
- let packet, rest =
251
- if buf_len - 2 = packet_len
252
- then t.buf, Cstruct. empty
253
- else Cstruct. split t.buf (packet_len + 2 )
254
- in
255
- t.buf < - rest;
256
- let response_id = Cstruct.BE. get_uint16 packet 2 in
257
- Log. debug (fun m -> m " recv_packet (%X): got response %X" request_id response_id);
258
- if response_id = request_id
259
- then packet
260
- else begin
261
- (match IM. find response_id t.requests with
262
- | r -> Eio.Promise. resolve r packet
263
- | exception Not_found -> () );
264
- recv_packet t ns_connection request_id
265
- end
266
- else begin
267
- recv_data t ns_connection request_id;
268
- recv_packet t ns_connection request_id
269
- end
270
- )
271
- else begin
272
- recv_data t ns_connection request_id;
273
- recv_packet t ns_connection request_id
274
- end
235
+ let append_recv_buf ctx got recv_data =
236
+ let buf = Cstruct. sub ctx.recv_buf 0 got in
237
+ if Cstruct. is_empty recv_data
238
+ then buf
239
+ else Cstruct. append recv_data buf
240
+ in
241
+
242
+ let rec handle_data recv_data =
243
+ let recv_data_len = Cstruct. length recv_data in
244
+ if recv_data_len < 2
245
+ then recv_dns_packets ~recv_data ctx
246
+ else
247
+ match Cstruct.BE. get_uint16 recv_data 0 with
248
+ | packet_len when recv_data_len - 2 > = packet_len ->
249
+ let packet, recv_data = Cstruct. split recv_data @@ packet_len + 2 in
250
+ let response_id = Cstruct.BE. get_uint16 packet 2 in
251
+ (match IM. find response_id ctx.requests with
252
+ | r ->
253
+ ctx.requests < - IM. remove response_id ctx.requests ;
254
+ Eio.Promise. resolve r packet
255
+ | exception Not_found -> () (* spurious data, ignore *)
256
+ );
257
+ if not @@ IM. is_empty ctx.requests then handle_data recv_data else ()
258
+ | _ -> recv_dns_packets ~recv_data ctx
259
+ in
260
+
261
+ match Eio.Flow. single_read ctx.ns_connection ctx.recv_buf with
262
+ | got ->
263
+ let recv_data = append_recv_buf ctx got recv_data in
264
+ handle_data recv_data
265
+ | exception End_of_file ->
266
+ ctx.t.ns_connection_condition < - None ;
267
+ ctx.t.ctx < - None ;
268
+ ctx.closed < - true ;
269
+ if not @@ IM. is_empty ctx.requests then
270
+ (match connect ctx.t with
271
+ | Ok _ -> recv_dns_packets ~recv_data ctx
272
+ | Error _ -> Log. warn (fun m -> m " [recv_dns_packets] connection closed while processing dns requests" ) )
273
+ else ()
275
274
276
275
let validate_query_packet tx =
277
276
if Cstruct. length tx > 4 then Ok () else
278
277
Error (`Msg " Invalid DNS query packet (data length <= 4)" )
279
278
280
279
let send_recv ctx packet =
281
- let * () = validate_query_packet packet in
282
- try
283
- let request_id = Cstruct.BE. get_uint16 packet 2 in
284
- Eio.Time.Timeout. run_exn ctx.t.timeout (fun () ->
285
- Eio.Flow. write ctx.ns_connection [packet];
286
- Log. debug (fun m -> m " send_recv (%X): wrote request" request_id);
287
- let response_p, response_r = Eio.Promise. create () in
288
- ctx.requests < - IM. add request_id response_r ctx.requests;
289
- let response =
290
- Eio.Fiber. first
291
- (fun () -> recv_packet ctx ctx.ns_connection request_id)
292
- (fun () -> Eio.Promise. await response_p)
293
- in
294
- Log. debug (fun m -> m " send_recv (%X): got response" request_id);
295
- Ok response
296
- )
297
- with
298
- | Eio.Time. Timeout -> Error (`Msg " DNS request timeout" )
299
- (* | exn -> Error (`Msg (Printexc.to_string exn)) *)
280
+ if not ctx.closed then
281
+ let * () = validate_query_packet packet in
282
+ try
283
+ let request_id = Cstruct.BE. get_uint16 packet 2 in
284
+ let response_p, response_r = Eio.Promise. create () in
285
+ ctx.requests < - IM. add request_id response_r ctx.requests;
286
+ Eio.Time.Timeout. run_exn ctx.t.timeout (fun () ->
287
+ Eio.Flow. write ctx.ns_connection [packet];
288
+ let response = Eio.Promise. await response_p in
289
+ Ok response
290
+ )
291
+ with Eio.Time. Timeout -> Error (`Msg " DNS request timeout" )
292
+ else
293
+ Error (`Msg " Nameserver closed connection" )
300
294
301
295
let close _ = ()
302
296
let bind a f = f a
0 commit comments