(* Source file caqti_async.ml *)
open Async_kernel
open Async_unix
open Caqti_platform
open Core
(* A Unix error raised during I/O: the error code, the name of the failing
   function and the argument it was called with. *)
type Caqti_error.msg += Msg_unix of Core_unix.Error.t * string * string

(* Register a printer rendering [Msg_unix] as
   "<error message> in <function>(<quoted argument>)". *)
let () =
  let pp ppf msg =
    match msg with
    | Msg_unix (err, func, arg) ->
      let reason = Core_unix.Error.message err in
      Format.fprintf ppf "%s in %s(%S)" reason func arg
    | _ ->
      (* presumably unreachable: the printer is registered for Msg_unix only *)
      assert false
  in
  Caqti_error.define_msg ~pp [%extension_constructor Msg_unix]
module Fiber = struct
  type 'a t = 'a Deferred.t

  (** Monadic operators on deferreds: [>>=] binds, [>|=] maps. *)
  module Infix = struct
    let ( >>= ) m f = Deferred.bind m ~f
    let ( >|= ) = Deferred.( >>| )
  end

  open Infix

  let return = Deferred.return

  (** [catch f g] evaluates [f ()]. If it raises, the monitor-extracted
      exception is passed to [g]; otherwise the value of [f ()] is returned. *)
  let catch f g =
    try_with ~extract_exn:true f >>= fun outcome ->
    match outcome with
    | Error exn -> g exn
    | Ok value -> return value

  (** [finally f g] evaluates [f ()], then always runs [g ()]. Once [g] is
      done, the value of [f] is returned, or its exception is raised again. *)
  let finally f g =
    try_with ~extract_exn:true f >>= fun outcome ->
    g () >|= fun () ->
    match outcome with
    | Ok value -> value
    | Error exn -> Error.raise (Error.of_exn exn)

  (** [cleanup f g] evaluates [f ()]. Only if it raises is [g ()] run, after
      which the exception is raised again. *)
  let cleanup f g =
    try_with ~extract_exn:true f >>= fun outcome ->
    match outcome with
    | Ok value -> return value
    | Error exn -> g () >|= fun () -> Error.raise (Error.of_exn exn)
end

open Fiber.Infix

(* Caqti's generic stream implementation instantiated for deferreds. *)
module Stream = Caqti_platform.Stream.Make (Fiber)
module System_core = struct
  module Fiber = Fiber

  (* This backend needs no environment value. *)
  type stdenv = unit

  (* Start [f ()] in the background without waiting for it; the switch is
     not used. *)
  let async ~sw:_ f =
    let background = f () in
    don't_wait_for background

  module Stream = Stream

  (* Backed by a [unit Ivar.t]: [release] fills the ivar and [acquire]
     returns a deferred that is determined once it has been filled.
     NOTE(review): [Ivar.fill] on an already-filled ivar raises, so this
     presumably supports a single release per semaphore; confirm with callers. *)
  module Semaphore = struct
    type t = unit Ivar.t
    let create = Ivar.create
    let acquire ivar = Ivar.read ivar
    let release ivar = Ivar.fill ivar ()
  end

  module Switch = Caqti_platform.Switch.Make (Fiber)

  module Log = struct
    type 'a log = ('a, unit Deferred.t) Logs.msgf -> unit Deferred.t

    (* Report [msgf] on [src] at [level] through [Logs.report]. The result is
       determined when the reporter calls [over]. It is determined at once if
       [src] has no level or its level filters the message out. The
       error/warning counters are bumped whenever [src] has a level set,
       even when the message itself is filtered. *)
    let kmsg ?(src = Logs.default) level msgf =
      let bump_counter () =
        match level with
        | Logs.Error -> Logs.incr_err_count ()
        | Logs.Warning -> Logs.incr_warn_count ()
        | _ -> ()
      in
      match Logs.Src.level src with
      | None -> return ()
      | Some threshold ->
        bump_counter ();
        if Poly.(level > threshold) then return ()
        else begin
          let reported = Ivar.create () in
          Logs.report src level
            ~over:(fun () -> Ivar.fill reported ())
            (fun () -> Ivar.read reported)
            msgf
        end

    let err ?(src = Logging.default_log_src) msgf = kmsg ~src Logs.Error msgf
    let warn ?(src = Logging.default_log_src) msgf = kmsg ~src Logs.Warning msgf
    let info ?(src = Logging.default_log_src) msgf = kmsg ~src Logs.Info msgf
    let debug ?(src = Logging.default_log_src) msgf = kmsg ~src Logs.Debug msgf
  end

  (* Built on Async's [Sequencer]. With [continue_on_error] set, a failing
     job does not prevent later queued jobs from running. *)
  module Sequencer = struct
    type 'a t = 'a Sequencer.t
    let create state = Sequencer.create ~continue_on_error:true state
    let enqueue = Throttle.enqueue
  end
end
module Alarm = struct
  (* Alarms cannot be cancelled, so no handle is kept. *)
  type t = unit

  (** [schedule ~sw ~stdenv t f] arranges for [f ()] to run at the monotonic
      time [t]. If [t] is not in the future, [f] runs with a zero delay.
      [sw] is ignored.
      @raise Failure if the delay in nanoseconds does not fit in an
      [Int63.t]. *)
  let schedule ~sw:_ ~stdenv:() t f =
    let t_now = Mtime_clock.now () in
    (* Only an alarm lying in the future needs a positive delay. Past or
       present alarms fire right away. *)
    let dt_ns =
      if Mtime.is_later t ~than:t_now then
        Mtime.Span.to_uint64_ns (Mtime.span t t_now)
      else 0L
    in
    (match Int63.of_int64 dt_ns with
     | None -> failwith "Arithmetic overflow while computing scheduling time."
     | Some dt_ns -> Clock_ns.run_after (Time_ns.Span.of_int63_ns dt_ns) f ())

  (* Nothing to cancel: see [type t]. *)
  let unschedule () = ()
end

module Pool = Caqti_platform.Pool.Make (System_core) (Alarm)
module System = struct
  include System_core
  module Stream = Stream

  module Net = struct
    (* Async socket addresses built from Caqti's address representations. *)
    module Sockaddr = struct
      type t = Async_unix.Unix.sockaddr

      let unix s = Async_unix.Unix.ADDR_UNIX s

      (* The [Ipaddr.t] is passed to Core through its textual form. *)
      let tcp (addr, port) =
        Async_unix.Unix.ADDR_INET
          (Core_unix.Inet_addr.of_string (Ipaddr.to_string addr), port)
    end

    (* Resolve [host] and [port] to stream-socket addresses with
       [Addr_info.get]. This function never produces [Error] itself; any
       failure inside [Addr_info.get] is not intercepted here. *)
    let getaddrinfo ~stdenv:() host port =
      let module Ai = Async_unix.Unix.Addr_info in
      (* Keep only the socket address of each result record. *)
      let extract ai = ai.Ai.ai_addr in
      Ai.get ~host:(Domain_name.to_string host) ~service:(string_of_int port)
        Ai.[AI_SOCKTYPE SOCK_STREAM]
      >|= List.map ~f:extract >|= (fun addrs -> Ok addrs)

    (* A connection represented as an Async reader/writer pair. *)
    module Socket = struct
      type t = Reader.t * Writer.t

      (* The [output_*] functions hand data to the writer and return at once;
         [flush] waits until the writer reports the data flushed. *)
      let output_char (_, oc) c = return (Writer.write_char oc c)
      let output_string (_, oc) s = return (Writer.write oc s)
      let flush (_, oc) = Writer.flushed oc

      (* End of input is reported by raising [End_of_file]. *)
      let input_char (ic, _) =
        Reader.read_char ic
        >|= function `Ok c -> c | `Eof -> raise End_of_file

      let really_input (ic, _) s pos len =
        Reader.really_read ic ~pos ~len s
        >|= function `Ok -> () | `Eof _ -> raise End_of_file

      (* NOTE(review): only the writer is closed here; confirm that this also
         releases the reader side of the connection. *)
      let close (_ic, oc) = Writer.close oc
    end

    type tcp_flow = Socket.t
    type tls_flow = Socket.t

    (* Map a Unix error, possibly wrapped by a monitor, to [Msg_unix];
       any other exception yields [None]. *)
    let convert_io_exception exn =
      (match Async_kernel.Monitor.extract_exn exn with
       | Core_unix.Unix_error (err, func, arg) ->
          Some (Msg_unix (err, func, arg))
       | _ -> None)

    (* Run [f]. Unix errors become [Error msg]; every other exception is
       re-raised. *)
    let intercept_exceptions f =
      Async_kernel.Monitor.try_with f >|= function
       | Ok _ as r -> r
       | Error exn ->
          (match convert_io_exception exn with
           | Some msg -> Error msg
           | None -> raise exn)

    (* Connect to an Internet or Unix-domain address; [sw] is ignored. *)
    let connect_tcp ~sw:_ ~stdenv:() addr =
      intercept_exceptions @@ fun () ->
      (match addr with
       | Async_unix.Unix.ADDR_INET (addr, port) ->
          Async_unix.Tcp.connect
            (Tcp.Where_to_connect.of_inet_address (`Inet (addr, port)))
          >|= fun (_, ic, oc) -> (ic, oc)
       | Async_unix.Unix.ADDR_UNIX path ->
          Async_unix.Tcp.connect
            (Tcp.Where_to_connect.of_unix_address (`Unix path))
          >|= fun (_, ic, oc) -> (ic, oc))

    (* TCP and TLS flows share one representation, so these are identities. *)
    let tcp_flow_of_socket socket = Some socket
    let socket_of_tls_flow ~sw:_ socket = socket

    module type TLS_PROVIDER = Caqti_platform.System_sig.TLS_PROVIDER
      with type 'a fiber := 'a Deferred.t
       and type tcp_flow := tcp_flow
       and type tls_flow := tls_flow

    (* Providers added through [register_tls_provider], most recent first. *)
    let tls_providers_r : (module TLS_PROVIDER) list ref = ref []

    (* When [config] names "tls", try to load caqti-tls-async before
       returning the registered providers. Presumably that library registers
       itself via [register_tls_provider]; verify. A load failure is only
       logged as a warning. *)
    let tls_providers config =
      if Caqti_connect_config.mem_name "tls" config then
        (match Caqti_platform.Connector.load_library "caqti-tls-async" with
         | Ok () -> ()
         | Error msg ->
            Logs.warn ~src:Logging.default_log_src (fun p ->
              p "TLS configured but caqti-tls-async not available: %s" msg));
      !tls_providers_r

    let register_tls_provider p = tls_providers_r := p :: !tls_providers_r
  end
end
module System_unix = struct
  module Unix = struct
    type file_descr = Async_unix.Fd.t

    let fdinfo = Info.of_string "Caqti_async file descriptor"

    (** [wrap_fd f ufd] wraps the Unix descriptor [ufd] as an active Async
        socket [Fd.t] and runs [f] on it. Afterwards it closes the Async
        wrapper with [Do_not_close_file_descriptor], leaving [ufd] itself
        open for the caller. The wrapper is closed whether or not [f]
        raises; an exception from [f] is re-raised once the close finishes. *)
    let wrap_fd f ufd =
      let fd = Fd.create (Fd.Kind.Socket `Active) ufd fdinfo in
      (* Without [finally], an exception from [f] would skip [Fd.close] and
         leak the Async wrapper. *)
      Fiber.finally
        (fun () -> f fd)
        (fun () ->
           Fd.(close ~file_descriptor_handling:Do_not_close_file_descriptor) fd)

    (** [poll ~stdenv ?read ?write ?timeout fd] waits until [fd] becomes ready
        for one of the requested directions, or until [timeout] seconds pass.
        It returns [(readable, writable, timed_out)]. A [`Bad_fd] or
        [`Closed] outcome counts as not ready. Directions not requested
        (and a missing timeout) never fire. *)
    let poll ~stdenv:() ?(read = false) ?(write = false) ?timeout fd =
      let wait_read =
        if read then Async_unix.Fd.ready_to fd `Read else Deferred.never () in
      let wait_write =
        if write then Async_unix.Fd.ready_to fd `Write else Deferred.never () in
      let wait_timeout =
        (match timeout with
         | Some t -> Clock.after (Time_float.Span.of_sec t)
         | None -> Deferred.never ()) in
      let did_read, did_write, did_timeout = ref false, ref false, ref false in
      let is_ready = function
       | `Ready -> true
       | `Bad_fd | `Closed -> false in
      Deferred.enabled [
        Deferred.choice wait_read (fun st -> did_read := is_ready st);
        Deferred.choice wait_write (fun st -> did_write := is_ready st);
        Deferred.choice wait_timeout (fun () -> did_timeout := true);
      ] >>|
      (fun f ->
        (* Running [f] applies the callbacks of every choice determined so
           far, which sets the flags read below. *)
        ignore (f ());
        (!did_read, !did_write, !did_timeout))
  end

  module Preemptive = struct
    (* Run [f x] on a separate thread via [In_thread.run]. *)
    let detach f x = In_thread.run (fun () -> f x)

    (* Block the calling thread until the Async computation [f ()] is done. *)
    let run_in_main f = Thread_safe.block_on_async_exn f
  end
end
(* Driver loader instantiated for this Async system and its Unix support. *)
module Loader = Caqti_platform_unix.Driver_loader.Make (System) (System_unix)
(* Brings in the generic connection entry points; [connect],
   [with_connection] and [connect_pool] are re-bound (shadowed) below. *)
include Connector.Make (System) (Pool) (Loader)
open System
module type CONNECTION = Caqti_connection_sig.S
with type 'a fiber := 'a Deferred.t
and type ('a, 'e) stream := ('a, 'e) Stream.t
type connection = (module CONNECTION)
(* Public entry points. [~sw] is fixed to [Switch.eternal], so the caller
   has no switch to manage. [~stdenv] is fixed to [()], since this backend's
   [stdenv] is [unit]. *)
let connect = connect ~sw:Switch.eternal ~stdenv:()
let with_connection = with_connection ~stdenv:()
let connect_pool = connect_pool~sw:Switch.eternal ~stdenv:()