Source file irc_client_lwt_ssl.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

(** Options controlling how the TLS client connection is set up. *)
module Config = struct
  (** [check_certificate]: when [true], the peer certificate is verified.
      [proto]: the protocol handed to [Ssl.create_context]. *)
  type t = {
    check_certificate: bool;
    proto: Ssl.protocol;
  }

  (** Defaults: TLS 1.3, certificate verification disabled.
      NOTE(review): with verification off, the server is not authenticated. *)
  let default = {
    proto = Ssl.TLSv1_3;
    check_certificate = false;
  }
end

(** Lwt-based IO backend over TLS sockets (ocaml-ssl / Lwt_ssl), suitable
    as the argument of [Irc_client.Make]. *)
module Io_lwt_ssl = struct
  type 'a t = 'a Lwt.t
  let (>>=) = Lwt.bind
  let (>|=) = Lwt.(>|=)
  let return = Lwt.return

  (** A connected TLS socket, paired with the SSL context it was created
      from. *)
  type file_descr = {
    ssl: Ssl.context;
    fd: Lwt_ssl.socket;
  }

  type config = Config.t
  type inet_addr = Lwt_unix.inet_addr

  (** [open_socket ?config addr port] opens a TCP connection to
      [addr:port] and performs a TLS client handshake over it.

      The address family (IPv4 or IPv6) is taken from [addr]. If either
      the TCP connect or the TLS handshake fails, the underlying socket is
      closed and the original exception is re-raised in the promise.

      Only the address is available here, not a host name. No SNI host
      name is set and no host-name check is done on the certificate. *)
  let open_socket ?(config=Config.default) addr port : file_descr t =
    let ssl = Ssl.create_context config.Config.proto Ssl.Client_context in
    if config.Config.check_certificate then begin
      (* A fresh OpenSSL context trusts no CA certificates. Without the
         system defaults, peer verification would reject every certificate. *)
      ignore (Ssl.set_default_verify_paths ssl : bool);
      (* from https://github.com/johnelse/ocaml-irc-client/pull/21 *)
      Ssl.set_verify_depth ssl 3;
      Ssl.set_verify ssl [Ssl.Verify_peer] (Some Ssl.client_verify_callback);
      Ssl.set_client_verify_callback_verbose true;
    end;
    let sockaddr = Lwt_unix.ADDR_INET (addr, port) in
    (* Derive the domain from the address so IPv6 addresses work too. *)
    let sock =
      Lwt_unix.socket (Unix.domain_of_sockaddr sockaddr) Lwt_unix.SOCK_STREAM 0
    in
    Lwt.catch
      (fun () ->
        Lwt_unix.connect sock sockaddr >>= fun () ->
        Lwt_ssl.ssl_connect sock ssl >>= fun fd ->
        Lwt.return {fd; ssl})
      (fun e ->
        (* Release the descriptor on failure. Errors from close itself are
           deliberately ignored so the original failure [e] is reported. *)
        Lwt.catch (fun () -> Lwt_unix.close sock) (fun _ -> Lwt.return_unit)
        >>= fun () -> Lwt.fail e)

  (** Close the TLS socket. The SSL context is not freed explicitly. *)
  let close_socket {fd;ssl=_} =
    Lwt_ssl.close fd

  (** [read s buf off len] reads up to [len] bytes into [buf] at [off]. *)
  let read {fd;_} i len = Lwt_ssl.read fd i len

  (** [write s buf off len] writes up to [len] bytes from [buf] at [off]. *)
  let write {fd;_} s i len = Lwt_ssl.write fd s i len

  (** Like [read], but gives up after [timeout] seconds.
      Returns [Some n] if the read completes first, [None] on timeout.
      [Lwt.pick] cancels whichever promise loses the race. *)
  let read_with_timeout ~timeout fd buf off len =
    Lwt.pick
      [ (read fd buf off len >|= fun i -> Some i);
        (Lwt_unix.sleep (float timeout) >|= fun () -> None);
      ]

  (** Resolve [name] to its list of addresses. An unknown host
      ([Not_found]) yields [[]]; any other exception is propagated. *)
  let gethostbyname name =
    Lwt.catch
      (fun () ->
      Lwt_unix.gethostbyname name >>= fun entry ->
      let addrs = Array.to_list entry.Unix.h_addr_list in
      Lwt.return addrs
    ) (function
      | Not_found -> Lwt.return_nil
      | e -> Lwt.fail e
      )

  let iter = Lwt_list.iter_s

  (** Sleep for [d] seconds. *)
  let sleep d = Lwt_unix.sleep (float d)
  let catch = Lwt.catch
  let time = Unix.time

  (* Presumably lets the client race several promises (e.g. for timeouts);
     verify against the Irc_client.Make signature. *)
  let pick = Some Lwt.pick
end

(* Instantiate the generic IRC client over the Lwt/TLS IO backend above and
   re-export the resulting definitions from this module. *)
include Irc_client.Make(Io_lwt_ssl)