Source file unix_flow.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
open Lwt.Infix

(* Installing a signal handler from library code is a bit impolite, but the
   default SIGPIPE action (killing the process on a write to a closed socket)
   is never what a modern application wants; write errors are reported as
   EPIPE instead. Windows has no SIGPIPE, so it is skipped there. *)
let () =
  match Sys.win32 with
  | true -> ()
  | false -> Sys.set_signal Sys.sigpipe Sys.Signal_ignore

(** A byte stream over an [Lwt_unix] file descriptor. *)
type flow = {
  fd : Lwt_unix.file_descr;  (** The wrapped descriptor. *)
  mutable current_write : int Lwt.t option;
  (** The write in progress, if any, kept so [close] can cancel it. *)
  mutable current_read : int Lwt.t option;
  (** The read in progress, if any, kept so [close] can cancel it. *)
  mutable closed : bool;  (** Set once [close] has been called. *)
}

(** Errors reported when reading. *)
type error = [ `Exception of exn ]

(** Errors reported when writing: everything in [error], plus [`Closed]
    when the flow or the peer has gone away. *)
type write_error = [ error | `Closed ]

(** [opt_cancel pending] cancels the promise in [pending], if there is one. *)
let opt_cancel pending =
  match pending with
  | Some p -> Lwt.cancel p
  | None -> ()

(** [close t] marks [t] as closed, cancels any pending read and write, and
    closes the underlying descriptor. Calling it on an already-closed flow
    does nothing and returns [Lwt.return_unit]. *)
let close t =
  if not t.closed then begin
    t.closed <- true;
    (* Abort in-flight operations before the descriptor goes away. *)
    opt_cancel t.current_read;
    opt_cancel t.current_write;
    Lwt_unix.close t.fd
  end
  else Lwt.return_unit

(** [pp_error ppf err] prints either error variant: the exception for
    [`Exception], or the text [Closed] for [`Closed]. *)
let pp_error ppf err =
  match err with
  | `Closed -> Fmt.string ppf "Closed"
  | `Exception ex -> Fmt.exn ppf ex

(** Write errors use the same printer; [pp_error] already covers [`Closed]. *)
let pp_write_error = pp_error

(** [write t buf] writes all of [buf] to [t], looping over short writes.

    Returns [Error `Closed] if [t] is (or becomes, between partial writes)
    closed, or if the write fails with [ECONNRESET], [ENOTCONN] (macOS) or
    [EPIPE]. Any other exception, including a violation of the
    one-write-at-a-time assertion, is returned as [Error (`Exception ex)]. *)
let write t buf =
  let rec aux buf =
    if t.closed then Lwt.return (Error `Closed)
    else (
      assert (t.current_write = None);
      let write_thread = Lwt_cstruct.write t.fd buf in
      t.current_write <- Some write_thread;
      (* Clear [current_write] however the write ends (success, failure or
         cancellation by [close]). Clearing it only on success would leave a
         stale promise behind after an error, and the next [write] would then
         fail the assertion above instead of reporting a real result. *)
      Lwt.finalize
        (fun () -> write_thread)
        (fun () -> t.current_write <- None; Lwt.return_unit)
      >>= fun wrote ->
      if wrote = Cstruct.length buf then Lwt.return (Ok ())
      else aux (Cstruct.shift buf wrote)
    )
  in
  Lwt.catch
    (fun () -> aux buf)
    (function
      | Unix.Unix_error (Unix.ECONNRESET, _, _)
      | Unix.Unix_error (Unix.ENOTCONN, _, _)           (* macos *)
      | Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return @@ Error `Closed
      | ex -> Lwt.return @@ Error (`Exception ex))

(** [writev t bufs] writes each buffer of [bufs] in order using [write],
    stopping at the first error and returning it. *)
let writev t bufs =
  let rec loop = function
    | [] -> Lwt.return (Ok ())
    | buf :: rest ->
      write t buf >>= function
      | Error _ as err -> Lwt.return err
      | Ok () -> loop rest
  in
  loop bufs

(** [read t] reads up to 4096 bytes from [t].

    Returns [Ok (`Data buf)] with the bytes read. Returns [Ok `Eof] in any of
    these cases:
    - the end of the stream is reached;
    - [t] has been closed, including a read cancelled by [close];
    - the read fails with [EPIPE] or [ECONNRESET].

    Any other exception, including a violation of the one-read-at-a-time
    assertion, is returned as [Error (`Exception ex)]. *)
let read t =
  let len = 4096 in
  let buf = Cstruct.create_unsafe len in
  Lwt.try_bind
    (fun () ->
       (* Test [closed] first so reading a closed flow always yields [`Eof]. *)
       if t.closed then raise Lwt.Canceled;
       assert (t.current_read = None);
       let read_thread = Lwt_cstruct.read t.fd buf in
       t.current_read <- Some read_thread;
       (* Clear [current_read] on every outcome: data, EOF, error or
          cancellation. Otherwise it would stay set after EOF or a failure,
          and the next [read] would trip the assertion instead of reporting
          [`Eof]. This wraps only the promise started here, so a concurrent
          call that fails the assertion cannot clear another read's entry. *)
       Lwt.finalize
         (fun () -> read_thread)
         (fun () -> t.current_read <- None; Lwt.return_unit)
    )
    (function
      | 0 -> Lwt.return @@ Ok `Eof
      | got -> Lwt.return @@ Ok (`Data (Cstruct.sub buf 0 got))
    )
    (function
      | Lwt.Canceled
      | Unix.Unix_error (Unix.EPIPE, _, _)
      | Unix.Unix_error (Unix.ECONNRESET, _, _) -> Lwt_result.return `Eof
      | ex -> Lwt.return @@ Error (`Exception ex)
    )

(** [connect ?switch fd] wraps [fd] as a new, open flow. [close] on the new
    flow is registered as a hook on [switch] via [Lwt_switch.add_hook]. *)
let connect ?switch fd =
  let flow =
    { fd; current_write = None; current_read = None; closed = false }
  in
  let on_switch_off () = close flow in
  Lwt_switch.add_hook switch on_switch_off;
  flow

(** [socketpair ?switch ()] creates a connected pair of Unix-domain stream
    sockets and wraps each end with [connect ?switch]. *)
let socketpair ?switch () =
  let (left, right) =
    Lwt_unix.socketpair Lwt_unix.PF_UNIX Lwt_unix.SOCK_STREAM 0
  in
  (connect ?switch left, connect ?switch right)