Source file tcpv4_socket.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
(* Log source for this module; enable it via [Logs] to see accept/callback
   warnings emitted by [listen]. *)
let src =
  Logs.Src.create ~doc:"TCP socket v4 (platform native)" "tcpv4-socket"

(* Logging functions bound to [src]. *)
module Log = (val Logs.src_log src : Logs.LOG)
open Lwt.Infix
(* Addresses handled by this stack are IPv4 addresses. *)
type ipaddr = Ipaddr.V4.t
(* A flow is the raw Lwt socket of a TCP connection (returned by
   [create_connection] and passed to [listen] callbacks). *)
type flow = Lwt_unix.file_descr
(* State of one socket-backed TCPv4 stack instance. *)
type t = {
(* Local address that sockets opened by [create_connection] and [listen]
   are bound to. *)
interface: Unix.inet_addr;
(* Sockets opened by [create_connection] or accepted by [listen]; closed by
   [disconnect]. No code in this module removes entries when a single flow
   is closed. *)
mutable active_connections : Lwt_unix.file_descr list;
(* Listening socket per local port, maintained by [listen]/[unlisten]. *)
listen_sockets : (int, Lwt_unix.file_descr) Hashtbl.t;
(* Accept loops in [listen] stop once this promise is no longer pending;
   [disconnect] calls [Lwt.cancel] on it and [set_switched_off] extends it. *)
mutable switched_off : unit Lwt.t;
}
include Tcp_socket
(* [connect addr] creates a stack instance whose sockets bind to the address
   part of the IPv4 prefix [addr]. No socket is opened here.

   [switched_off] starts as a pending promise made with [Lwt.task] rather
   than [Lwt.wait]: promises from [Lwt.wait] are not cancelable, so the
   [Lwt.cancel t.switched_off] performed by [disconnect] would have no effect.
   The shutdown check in the [listen] accept loop could then never observe
   the disconnect. *)
let connect addr =
  let t = {
    interface = Ipaddr_unix.V4.to_inet_addr (Ipaddr.V4.Prefix.address addr);
    active_connections = [];
    listen_sockets = Hashtbl.create 7;
    switched_off = fst (Lwt.task ());
  } in
  Lwt.return t
(* [set_switched_off t p] registers [p] as an additional shutdown trigger.
   The new [t.switched_off] is [Lwt.pick] of [p] and the previous value, so
   it completes with whichever of the two finishes first. *)
let set_switched_off t switched_off =
  let previous = t.switched_off in
  t.switched_off <- Lwt.pick [ switched_off; previous ]
(* [disconnect t] closes every tracked connection, then every listening
   socket, then calls [Lwt.cancel] on [t.switched_off]. Cancellation only
   takes effect if that promise is cancelable.

   Each collection is detached from [t] before its sockets are closed. Left
   in place, the handle would keep referring to released descriptors:
   - a later [unlisten]/[listen] on the same port would [Unix.close] a
     descriptor number the OS may already have handed to an unrelated socket;
   - a repeated [disconnect] would try to close everything again. *)
let disconnect t =
  let connections = t.active_connections in
  t.active_connections <- [];
  Lwt_list.iter_p close connections >>= fun () ->
  (* Snapshot listeners only now, as before, so ones registered while the
     connections were closing are included. *)
  let listeners =
    Hashtbl.fold (fun _ fd acc -> fd :: acc) t.listen_sockets []
  in
  Hashtbl.reset t.listen_sockets;
  Lwt_list.iter_p close listeners >>= fun () ->
  Lwt.cancel t.switched_off;
  Lwt.return_unit
(* [dst fd] returns the remote [(ip, port)] of the connected socket [fd].
   @raise Failure if the peer address is a Unix-domain address, or an inet
   address that is not IPv4.
   Any [Unix.Unix_error] raised by [getpeername] propagates unchanged. *)
let dst fd =
  let peer = Lwt_unix.getpeername fd in
  match peer with
  | Unix.ADDR_INET (inet_addr, port) ->
    let ip =
      match Ipaddr_unix.V4.of_inet_addr inet_addr with
      | Some v4 -> v4
      | None -> raise (Failure "got a ipv6 sock instead of a tcpv4 one")
    in
    (ip, port)
  | Unix.ADDR_UNIX _ ->
    raise (Failure "unexpected: got a unix instead of tcp sock")
(* [create_connection ?keepalive t (dst, dst_port)] opens a TCP connection
   from [t.interface] (ephemeral local port) to [dst]:[dst_port].

   On success:
   - the socket is recorded in [t.active_connections];
   - TCP keepalive is enabled when [keepalive] is given;
   - the socket is returned as [Ok fd].

   Every failure is reported as [Error (`Exn exn)], including failure to
   allocate the socket itself (e.g. EMFILE). If a socket was created, it is
   closed before the error is returned. *)
let create_connection ?keepalive t (dst, dst_port) =
  match Lwt_unix.(socket PF_INET SOCK_STREAM 0) with
  | exception exn -> Lwt.return (Error (`Exn exn))
  | fd ->
    Lwt.catch
      (fun () ->
        Lwt_unix.bind fd (Lwt_unix.ADDR_INET (t.interface, 0)) >>= fun () ->
        Lwt_unix.connect fd
          (Lwt_unix.ADDR_INET (Ipaddr_unix.V4.to_inet_addr dst, dst_port))
        >>= fun () ->
        (match keepalive with
         | None -> ()
         | Some { Tcpip.Tcp.Keepalive.after; interval; probes } ->
           Tcp_socket_options.enable_keepalive ~fd ~after ~interval ~probes);
        t.active_connections <- fd :: t.active_connections;
        Lwt.return (Ok fd))
      (fun exn -> close fd >|= fun () -> Error (`Exn exn))
(* [unlisten t ~port] stops listening on [port]: the listening socket is
   removed from [t.listen_sockets] and its descriptor closed synchronously
   with [Unix.close]. It does nothing if [port] has no listener.

   The [listen] accept loop stops when [accept] fails with EBADF, presumably
   what closing the socket here triggers (TODO confirm per platform).

   Close errors are ignored on purpose (best effort). Only [Unix.Unix_error]
   is caught: anything else is not a close failure and must propagate. *)
let unlisten t ~port =
  match Hashtbl.find_opt t.listen_sockets port with
  | None -> ()
  | Some fd ->
    Hashtbl.remove t.listen_sockets port;
    (try Unix.close (Lwt_unix.unix_file_descr fd)
     with Unix.Unix_error _ -> ())
(* [listen t ~port ?keepalive callback] starts accepting TCP connections on
   [t.interface]:[port], replacing any listener already registered for
   [port].

   Each accepted socket:
   - is recorded in [t.active_connections];
   - gets keepalive enabled when [keepalive] is given;
   - is passed to [callback] in its own Lwt thread. If [callback] raises,
     the error is logged and the socket closed.

   The accept loop runs in the background ([Lwt.async]). It stops once
   [t.switched_off] is no longer pending, or when [accept] fails with EBADF,
   and then closes the listening socket.

   @raise Invalid_argument if [port] is not in [0, 65535].

   Errors from socket setup ([setsockopt], [bind], [listen], e.g.
   EADDRINUSE) are re-raised after the new socket has been closed. A failed
   call therefore leaks no descriptor and registers no listener. *)
let listen t ~port ?keepalive callback =
  if port < 0 || port > 65535 then
    raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port));
  unlisten t ~port;
  let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in
  (* Until the socket is registered in [t.listen_sockets], neither [unlisten]
     nor [disconnect] can release it, so close it here if setup fails. *)
  (try
     Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true;
     Unix.bind (Lwt_unix.unix_file_descr fd)
       (Unix.ADDR_INET (t.interface, port));
     Lwt_unix.listen fd 10
   with exn ->
     let bt = Printexc.get_raw_backtrace () in
     (try Unix.close (Lwt_unix.unix_file_descr fd)
      with Unix.Unix_error _ -> ());
     Printexc.raise_with_backtrace exn bt);
  Hashtbl.replace t.listen_sockets port fd;
  Lwt.async (fun () ->
      let rec loop () =
        (* Stop accepting once the stack has been switched off. *)
        if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled;
        Lwt.catch
          (fun () ->
            Lwt_unix.accept fd >|= fun (afd, _) ->
            t.active_connections <- afd :: t.active_connections;
            (match keepalive with
             | None -> ()
             | Some { Tcpip.Tcp.Keepalive.after; interval; probes } ->
               Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval
                 ~probes);
            (* Not awaited: a slow callback must not delay further accepts. *)
            Lwt.async (fun () ->
                Lwt.catch
                  (fun () -> callback afd)
                  (fun exn ->
                    Log.warn (fun m ->
                        m "error %s in callback" (Printexc.to_string exn));
                    close afd));
            `Continue)
          (function
            | Unix.Unix_error (Unix.EBADF, _, _) ->
              (* Presumably the listening socket was closed underneath us
                 (e.g. by [unlisten] or [disconnect]). *)
              Log.warn (fun m -> m "error bad file descriptor in accept");
              Lwt.return `Stop
            | exn ->
              Log.warn (fun m ->
                  m "error %s in accept" (Printexc.to_string exn));
              Lwt.return `Continue)
        >>= function
        | `Continue -> loop ()
        | `Stop -> Lwt.return_unit
      in
      Lwt.catch loop ignore_canceled >>= fun () -> close fd)