(* Source file: polling_state_rpc.ml *)
open! Core
open! Async_kernel
open! Async_rpc_kernel
(* Sequence numbers. [Cache.set] creates a new one for every response it stores,
   and the client sends the last one it received back with its next query. *)
module Seqnum = struct
include Unique_id.Int ()
(* Sentinel built from [-1]. [Request.stable_of_unstable] puts it in the
   [last_seqnum] field to encode a [Forget_client] request, and
   [Request.unstable_of_stable] decodes it again.
   NOTE(review): this assumes [create] never returns -1 - confirm. *)
let forget = of_int_exn (-1)
end
(* Identifies one call to [Client.dispatch]. The client stores the id of the most
   recent call (or a new id on [Client.forget_on_server]) and aborts a dispatch
   whose id is no longer the stored one. *)
module Query_dispatch_id = Unique_id.Int ()
(* Identifies a client. [Client.create] allocates one per client, and the server
   cache keys its per-connection table on it. *)
module Client_id = Unique_id.Int ()
(* Server-side cache of polling state, one entry per client of each RPC
   connection. *)
module Cache : sig
  (** A cache holding, for every client, a payload of type ['a] and a user-defined
      ['client_state]. *)
  type ('a, 'client_state) t

  (** The cache entry of a single client. *)
  type ('a, 'client_state) per_client

  (** [create ~on_client_forgotten] returns an empty cache. [on_client_forgotten] is
      called with the state of a client when the client is removed with
      [remove_and_trigger_cancel], and for every remaining client of a connection
      once that connection has finished closing. *)
  val create : on_client_forgotten:('client_state -> unit) -> ('a, 'client_state) t

  (** [find] returns the entry of [client_id] on [connection]. A missing entry is
      created with [create_client_state connection_state]. *)
  val find
    :  ('a, 'client_state) t
    -> connection_state:'connection_state
    -> connection:Rpc.Connection.t
    -> client_id:Client_id.t
    -> create_client_state:('connection_state -> 'client_state)
    -> ('a, 'client_state) per_client

  (** Removes the entry of [client_id] on [connection], if there is one, runs
      [on_client_forgotten] on its state and determines its cancellation
      deferred. *)
  val remove_and_trigger_cancel
    :  ('a, 'client_state) t
    -> connection:Rpc.Connection.t
    -> client_id:Client_id.t
    -> unit

  (** The payload most recently stored with [set], paired with the seqnum that
      [set] returned for it. *)
  val last_response : ('a, 'client_state) per_client -> ('a * Seqnum.t) option

  (** [set per_client data] stores [data] under a newly created seqnum and returns
      that seqnum. *)
  val set : ('a, 'client_state) per_client -> 'a -> Seqnum.t

  (** A deferred that becomes determined at the next cancellation of this entry. *)
  val wait_for_cancel : ('a, 'client_state) per_client -> unit Deferred.t

  (** Determines the deferreds previously returned by [wait_for_cancel] and installs
      a new, undetermined one. Does not raise when the entry was already cancelled
      by [remove_and_trigger_cancel]. *)
  val trigger_cancel : ('a, 'client_state) per_client -> unit

  (** The user-defined state that [find] created for this client. *)
  val client_state : ('a, 'client_state) per_client -> 'client_state
end = struct
  type ('a, 'client_state) per_client =
    { mutable last_response : ('a * Seqnum.t) option
    ; mutable cancel : unit Ivar.t
    ; client_state : 'client_state
    }

  type ('a, 'client_state) per_connection =
    ('a, 'client_state) per_client Client_id.Table.t

  type ('a, 'client_state) t =
    { connections : (Rpc.Connection.t * ('a, 'client_state) per_connection) Bag.t
    ; on_client_forgotten : 'client_state -> unit
    }

  let create ~on_client_forgotten = { connections = Bag.create (); on_client_forgotten }

  (* Connections are compared by physical identity. *)
  let find_by_connection t ~connection =
    Bag.find t ~f:(fun (conn, _) -> phys_equal connection conn) |> Option.map ~f:snd
  ;;

  let find t ~connection_state ~connection ~client_id ~create_client_state =
    let per_connection =
      match find_by_connection t.connections ~connection with
      | Some per_connection -> per_connection
      | None ->
        let result = Client_id.Table.create () in
        let elt = Bag.add t.connections (connection, result) in
        (* Once the connection is closed, forget all of its clients and drop the
           whole table from the bag. *)
        Deferred.upon (Rpc.Connection.close_finished connection) (fun () ->
          Hashtbl.iter result ~f:(fun { client_state; _ } ->
            t.on_client_forgotten client_state);
          Bag.remove t.connections elt);
        result
    in
    Hashtbl.find_or_add per_connection client_id ~default:(fun () ->
      { last_response = None
      ; cancel = Ivar.create ()
      ; client_state = create_client_state connection_state
      })
  ;;

  let remove_and_trigger_cancel t ~connection ~client_id =
    Option.iter (find_by_connection t.connections ~connection) ~f:(fun per_connection ->
      Option.iter (Hashtbl.find_and_remove per_connection client_id) ~f:(fun per_client ->
        t.on_client_forgotten per_client.client_state;
        (* The ivar is not replaced here: the entry has left the table. An
           in-flight request handler may still hold the entry, though. *)
        Ivar.fill per_client.cancel ()))
  ;;

  let last_response per_client = per_client.last_response

  let set per_client data =
    let seqnum = Seqnum.create () in
    per_client.last_response <- Some (data, seqnum);
    seqnum
  ;;

  let wait_for_cancel per_client = Ivar.read per_client.cancel

  let trigger_cancel per_client =
    (* [fill_if_empty] rather than [fill]: an entry removed by
       [remove_and_trigger_cancel] keeps its filled ivar, and a request handler
       that still holds the entry may call this function afterwards. [Ivar.fill]
       would raise on the full ivar. *)
    Ivar.fill_if_empty per_client.cancel ();
    per_client.cancel <- Ivar.create ()
  ;;

  let client_state per_client = per_client.client_state
end
(* Requests of a polling state RPC. [Unstable] is the form used inside this file;
   [Stable] is the form that is serialized with bin_io. *)
module Request = struct
  module Unstable = struct
    type 'query t =
      | Query of
          { last_seqnum : Seqnum.t option
          ; query : 'query
          ; client_id : Client_id.t
          }
      | Cancel_ongoing of Client_id.t
      | Forget_client of
          { query : 'query
          ; client_id : Client_id.t
          }
  end

  module Stable = struct
    type 'query t =
      | Query of
          { last_seqnum : Seqnum.t option
          ; query : 'query
          ; client_id : Client_id.t
          }
      | Cancel_ongoing of Client_id.t
    [@@deriving bin_io]

    (* Pins the bin_io shape digest of the serialized type. *)
    let%expect_test _ =
      print_endline [%bin_digest: unit t];
      [%expect {| 4eb554fadd7eded37e4da89efd208c52 |}]
    ;;
  end

  (* Decodes a serialized request. A [Query] whose [last_seqnum] is
     [Seqnum.forget] stands for [Forget_client]; everything else maps to the
     constructor of the same name. *)
  let unstable_of_stable (stable : 'query Stable.t) : 'query Unstable.t =
    match stable with
    | Cancel_ongoing client_id -> Cancel_ongoing client_id
    | Query { last_seqnum; query; client_id } ->
      (match last_seqnum with
       | Some seqnum when [%equal: Seqnum.t] seqnum Seqnum.forget ->
         Forget_client { query; client_id }
       | Some _ | None -> Query { last_seqnum; query; client_id })
  ;;

  (* Encodes a request for serialization. [Forget_client] has no constructor in
     [Stable.t], so it is sent as a [Query] carrying the [Seqnum.forget]
     sentinel. *)
  let stable_of_unstable (unstable : 'query Unstable.t) : 'query Stable.t =
    match unstable with
    | Forget_client { query; client_id } ->
      Query { last_seqnum = Some Seqnum.forget; query; client_id }
    | Cancel_ongoing client_id -> Cancel_ongoing client_id
    | Query { last_seqnum; query; client_id } -> Query { last_seqnum; query; client_id }
  ;;
end
(* Responses of a polling state RPC. *)
module Response = struct
(* The payload of a regular response: either the complete value ([Fresh]) or
   an update ([Update]) that the client applies to the value it already has. *)
type ('response, 'update) t =
| Fresh of 'response
| Update of 'update
[@@deriving bin_io, sexp_of]
(* Pins the bin_io shape digest of [t]. *)
let%expect_test _ =
print_endline [%bin_digest: (int, string) t];
[%expect {| 13ef8c5223a0ea284c72512be32e5c09 |}]
;;
(* The response type of the underlying RPC. [Response] answers a query and
   carries the seqnum under which the server cached the new value.
   [Cancellation_successful] is what the server returns for [Cancel_ongoing]
   and [Forget_client] requests. *)
type ('response, 'update) pair =
| Response of
{ new_seqnum : Seqnum.t
; response : ('response, 'update) t
}
| Cancellation_successful
[@@deriving bin_io]
(* Pins the bin_io shape digest of [pair]. *)
let%expect_test _ =
print_endline [%bin_digest: (int, string) pair];
[%expect {| 8bc63a85561d87b693d15e78c64e1008 |}]
;;
end
(* The requirements on a response type: it must be serializable with bin_io and
   satisfy [Diffable.S]. The code in this file relies on [diffs], [update],
   [Update.bin_t] and [Update.Diff.t] from that signature. *)
module type Diffable = sig
type t [@@deriving bin_io]
include Diffable.S with type t := t
end
(* A polling state RPC. The diff type of the response module is existential, so
   only the query and response types show up in [t]. *)
type ('query, 'response) t =
| T :
{ response_module :
(module Diffable with type t = 'response and type Update.Diff.t = 'diff)
(* Used by both the server and the client to decide whether two queries are
   the same. *)
; query_equal : 'query -> 'query -> bool
(* The plain RPC that transports [Request.Stable.t] and [Response.pair]. *)
; underlying_rpc :
('query Request.Stable.t, ('response, 'diff list) Response.pair) Rpc.Rpc.t
}
-> ('query, 'response) t
(* Returns the name of the plain RPC underneath the polling state RPC. *)
let name t =
  let (T { underlying_rpc; _ }) = t in
  Rpc.Rpc.name underlying_rpc
;;
(* Builds a polling state RPC called [name] at [version]. [bin_query] serializes
   the query, [query_equal] compares two queries, and the first-class module
   supplies serialization and diffing for the response. *)
let create
  (type a)
  ~name
  ~version
  ~query_equal
  ~bin_query
  (module M : Diffable with type t = a)
  =
  let underlying_rpc =
    Rpc.Rpc.create
      ~name
      ~version
      ~bin_query:(Request.Stable.bin_t bin_query)
      ~bin_response:(Response.bin_pair M.bin_t M.Update.bin_t)
  in
  T { query_equal; response_module = (module M); underlying_rpc }
;;
(* Implements the polling state RPC [t] on the server, keeping one client state
   per client in a [Cache].

   - [on_client_and_server_out_of_sync] receives a message when the seqnum sent
     by the client does not line up with the response cached for that client.
   - [create_client_state] builds the state of a client from the connection
     state when the cache first creates an entry for that client.
   - [on_client_forgotten] (default [ignore]) is passed on to [Cache.create].
   - [for_first_request] (default [f]) computes the response when no previous
     response can be used, and also when the query differs from the cached one.
   - [f] computes the response when the query equals the cached query. *)
let implement_with_client_state
(type response)
~on_client_and_server_out_of_sync
~create_client_state
?(on_client_forgotten = ignore)
?for_first_request
t
f
=
(* The locally abstract [diff] names the existential diff type packed in [t]. *)
let do_implement
: type diff.
response_module:
(module Diffable with type t = response and type Update.Diff.t = diff)
-> underlying_rpc:(_, (response, diff list) Response.pair) Rpc.Rpc.t
-> _
=
fun ~response_module ~underlying_rpc ~query_equal ->
let module M = (val response_module) in
let for_first_request = Option.value for_first_request ~default:f in
(* Computes a fresh response. The value is returned twice: once to send to the
   client and once to store in the cache. *)
let init connection_state client_state query =
let%map response = for_first_request connection_state client_state query in
response, response
in
(* Computes a new response and sends only the diffs from the cached response.
   [f] is used when the query is unchanged, [for_first_request] otherwise. *)
let updates connection_state client_state ~prev:(prev_query, prev_response) query =
let f = if query_equal prev_query query then f else for_first_request in
let%map new_ = f connection_state client_state query in
let diff = M.diffs ~from:prev_response ~to_:new_ in
Response.Update diff, new_
in
let cache = Cache.create ~on_client_forgotten in
Rpc.Rpc.implement underlying_rpc (fun (connection_state, connection) request ->
match Request.unstable_of_stable request with
(* Cancels the requests of this client that are currently waiting on
   [Cache.wait_for_cancel]. *)
| Cancel_ongoing client_id ->
let per_client =
Cache.find cache ~connection_state ~connection ~client_id ~create_client_state
in
Cache.trigger_cancel per_client;
return Response.Cancellation_successful
(* Drops the cache entry of this client and cancels its waiting requests. *)
| Forget_client { query = _; client_id } ->
Cache.remove_and_trigger_cancel cache ~connection ~client_id;
return Response.Cancellation_successful
| Query { last_seqnum; query; client_id } ->
let per_client =
Cache.find cache ~connection_state ~connection ~client_id ~create_client_state
in
(* [prev] is the cached (query, response) pair, but only when the seqnum sent
   by the client equals the cached seqnum. In every other case the client
   gets a fresh response. *)
let prev =
match Cache.last_response per_client, last_seqnum with
| Some (prev, prev_seqnum), Some last_seqnum ->
(match Seqnum.equal prev_seqnum last_seqnum with
| true -> Some prev
| false ->
let rpc_name = Rpc.Rpc.name underlying_rpc in
let rpc_version = Rpc.Rpc.version underlying_rpc in
on_client_and_server_out_of_sync
[%message
[%here]
"A polling state RPC client has requested diffs from a seqnum that \
the server does not have, so the server is sending a fresh \
response instead. This likely means that the client had trouble \
receiving the last RPC response."
(rpc_name : string)
(rpc_version : int)];
None)
| None, Some _ | None, None -> None
| Some _, None ->
let rpc_name = Rpc.Rpc.name underlying_rpc in
let rpc_version = Rpc.Rpc.version underlying_rpc in
on_client_and_server_out_of_sync
[%message
[%here]
"A polling state RPC client has requested a fresh response, but the \
server expected it to have the seqnum of the latest diffs. The server \
will send a fresh response as requested. This likely means that the \
client had trouble receiving the last RPC response."
(rpc_name : string)
(rpc_version : int)];
None
in
let response =
let client_state = Cache.client_state per_client in
match prev with
| Some prev ->
let%map response, userdata =
updates connection_state client_state ~prev query
in
response, userdata
| None ->
let%map response, userdata = init connection_state client_state query in
Response.Fresh response, userdata
in
(* Race the computation against a cancellation of this client. *)
let response_or_cancelled =
choose
[ choice response (fun r -> `Response r)
; choice (Cache.wait_for_cancel per_client) (fun () -> `Cancelled)
]
in
(match%map response_or_cancelled with
| `Response (response, userdata) ->
(* Cache the new value, then cancel the other requests of this client that
   are still waiting. *)
let new_seqnum = Cache.set per_client (query, userdata) in
Cache.trigger_cancel per_client;
Response.Response { new_seqnum; response }
| `Cancelled ->
Exn.raise_without_backtrace (Failure "this request was cancelled")))
in
let (T { query_equal; response_module; underlying_rpc }) = t in
do_implement ~response_module ~underlying_rpc ~query_equal
;;
(* Implements the polling state RPC [t] without per-client state: the client
   state is [unit], and both [f] and the optional [for_first_request] are called
   with the connection state and the query only. *)
let implement ~on_client_and_server_out_of_sync ?for_first_request t f =
  (* Adapts a two-argument callback to the three-argument shape expected by
     [implement_with_client_state]. *)
  let ignore_client_state callback connection_state _client_state query =
    callback connection_state query
  in
  implement_with_client_state
    ~on_client_and_server_out_of_sync
    ~create_client_state:(fun _ -> ())
    ?for_first_request:(Option.map for_first_request ~f:ignore_client_state)
    t
    (ignore_client_state f)
;;
(* Implements the polling state RPC [rpc] from buses.

   [f connection_state client_state query] returns the bus that publishes the
   responses for [query]. For each client the server keeps an mvar holding the
   most recent response that has not been sent yet, together with a function
   that cancels the current bus subscription.

   - On a first request or a changed query, the pending response is discarded,
     the previous subscription is cancelled, the bus for the new query is
     subscribed to, and the request waits for a value in the mvar.
   - On a repeated query, the request only waits for a value in the mvar.
   - When a client is forgotten, its bus subscription is cancelled before the
     optional [on_client_forgotten] runs on the user-level client state.
     Without that step the subscription would outlive the client. *)
let implement_via_bus
  ~on_client_and_server_out_of_sync
  ~create_client_state
  ?(on_client_forgotten = ignore)
  rpc
  f
  =
  implement_with_client_state
    ~on_client_and_server_out_of_sync
    ~create_client_state:(fun connection_state ->
      Mvar.create (), ref (fun () -> ()), create_client_state connection_state)
    ~on_client_forgotten:(fun (_most_recent_unsent_response, unsubscribe, client_state) ->
      (* Release the subscription first, so it is released even when the user
         callback raises. *)
      !unsubscribe ();
      on_client_forgotten client_state)
    rpc
    ~for_first_request:
      (fun
        connection_state (most_recent_unsent_response, unsubscribe, client_state) query ->
        (* Discard a response that belongs to the previous query. *)
        let (_ : 'response option) = Mvar.take_now most_recent_unsent_response in
        !unsubscribe ();
        let bus = f connection_state client_state query in
        let subscriber =
          Bus.subscribe_exn bus [%here] ~f:(fun response ->
            Mvar.set most_recent_unsent_response response)
        in
        (unsubscribe := fun () -> Bus.unsubscribe bus subscriber);
        Mvar.take most_recent_unsent_response)
    (fun _connection_state
         (most_recent_unsent_response, _unsubscribe, _client_state)
         _query -> Mvar.take most_recent_unsent_response)
;;
(* The client side of a polling state RPC. *)
module Client = struct
  (* Alias for the RPC type defined above, which the [t] of this module
     shadows. *)
  type ('a, 'b) x = ('a, 'b) t

  type ('query, 'response, 'diff) unpacked =
    { mutable last_seqnum : Seqnum.t option
    ; mutable last_query : 'query option
    ; mutable last_query_dispatch_id : Query_dispatch_id.t
    ; mutable out : 'response option
    ; mutable sequencer : unit Sequencer.t
    ; cleaning_sequencer : unit Sequencer.t
    ; client_id : Client_id.t
    ; bus : ('query -> 'response -> unit) Bus.Read_write.t
    ; response_module :
        (module Diffable with type t = 'response and type Update.Diff.t = 'diff)
    ; query_equal : 'query -> 'query -> bool
    ; underlying_rpc :
        ('query Request.Stable.t, ('response, 'diff list) Response.pair) Rpc.Rpc.t
    ; fold : 'response option -> 'query -> ('response, 'diff list) Response.t -> 'response
    }

  type ('query, 'response) t =
    | T : ('query, 'response, 'diff) unpacked -> ('query, 'response) t

  (* Sends [request] over the underlying RPC in its serialized form. *)
  let dispatch_underlying t connection request =
    let stable_request = Request.stable_of_unstable request in
    Rpc.Rpc.dispatch t.underlying_rpc connection stable_request
  ;;

  (* Runs on the cleaning sequencer. When a job is running on the dispatch
     sequencer and it was started for a query different from [q], the dispatch
     sequencer is killed, a [Cancel_ongoing] request is sent, and a new dispatch
     sequencer is installed once the old one is cleaned. *)
  let cancel_current_if_query_changed t q connection =
    Throttle.enqueue t.cleaning_sequencer (fun () ->
      let nothing_to_cancel =
        match t.last_query with
        | None -> true
        | Some previous_query ->
          Throttle.num_jobs_running t.sequencer = 0 || t.query_equal q previous_query
      in
      if nothing_to_cancel
      then Deferred.Or_error.ok_unit
      else (
        Throttle.kill t.sequencer;
        let%bind cancel_response =
          dispatch_underlying t connection (Cancel_ongoing t.client_id)
        in
        let%map () = Throttle.cleaned t.sequencer in
        t.sequencer <- Sequencer.create ~continue_on_error:true ();
        match cancel_response with
        | Ok Cancellation_successful -> Ok ()
        | Ok (Response _) ->
          Or_error.error_s
            [%message "BUG" [%here] "regular response caused by cancellation"]
        | Error e -> Error e))
  ;;

  (* Sends [query] with the last seqnum this client received. On success the
     response is folded into the current value, the new seqnum and value are
     recorded, and the (query, value) pair is written to the bus. *)
  let dispatch' t connection query =
    t.last_query <- Some query;
    let request =
      Request.Unstable.Query
        { query; last_seqnum = t.last_seqnum; client_id = t.client_id }
    in
    match%bind dispatch_underlying t connection request with
    | Ok (Response { response; new_seqnum }) ->
      let folded = t.fold t.out query response in
      t.last_seqnum <- Some new_seqnum;
      t.out <- Some folded;
      Bus.write2 t.bus query folded;
      Deferred.Or_error.return folded
    | Ok Cancellation_successful ->
      return
        (Or_error.error_s
           [%message "BUG" [%here] "cancellation caused by regular request"])
    | Error e -> return (Error e)
  ;;

  (* Flattens the outcome of [Throttle.enqueue'] into an [Or_error.t]. *)
  let collapse_sequencer_error outcome =
    match outcome with
    | `Aborted -> Or_error.error_string "Request aborted"
    | `Raised exn -> Or_error.of_exn exn
    | `Ok result_or_error -> result_or_error
  ;;

  (* Dispatches [query]. A request still running for a different query is
     cancelled first. The dispatch is aborted when another [dispatch] or a
     [forget_on_server] was started while that cancellation was in progress. *)
  let dispatch (T t) connection query =
    let query_dispatch_id = Query_dispatch_id.create () in
    t.last_query_dispatch_id <- query_dispatch_id;
    let%bind.Deferred.Or_error () = cancel_current_if_query_changed t query connection in
    let superseded =
      not (Query_dispatch_id.equal t.last_query_dispatch_id query_dispatch_id)
    in
    let%map outcome =
      if superseded
      then return `Aborted
      else Throttle.enqueue' t.sequencer (fun () -> dispatch' t connection query)
    in
    collapse_sequencer_error outcome
  ;;

  (* Dispatches the most recently recorded query again, or fails when no query
     has been recorded. *)
  let redispatch (T t) connection =
    let%map outcome =
      Throttle.enqueue' t.sequencer (fun () ->
        match t.last_query with
        | None ->
          Deferred.Or_error.error_string
            "[redispatch] called before a query was set or a regular dispatch had completed"
        | Some q -> dispatch' t connection q)
    in
    collapse_sequencer_error outcome
  ;;

  (* Asks the server to drop its state for this client. Does nothing when no
     query has been recorded. Pending dispatches are invalidated by a new
     dispatch id and by killing the dispatch sequencer. *)
  let forget_on_server (T t) connection =
    t.last_query_dispatch_id <- Query_dispatch_id.create ();
    match t.last_query with
    | None -> Deferred.Or_error.ok_unit
    | Some query ->
      let%map outcome =
        Throttle.enqueue' t.cleaning_sequencer (fun () ->
          Throttle.kill t.sequencer;
          let%bind forget_response =
            dispatch_underlying
              t
              connection
              (Forget_client { query; client_id = t.client_id })
          in
          let%map () = Throttle.cleaned t.sequencer in
          t.sequencer <- Sequencer.create ~continue_on_error:true ();
          match forget_response with
          | Ok Cancellation_successful -> Ok ()
          | Ok (Response _) ->
            Or_error.error_s
              [%message
                "BUG"
                  [%here]
                  {|Regular response caused by forget request. This can also happen if the server is old and does not support forget requests, in which case this is not a bug.|}]
          | Error e -> Error e)
      in
      collapse_sequencer_error outcome
  ;;

  (* The most recently recorded query, or the initial query given to [create]. *)
  let query (T t) = t.last_query

  (* A read-only view of the bus that receives every (query, value) pair. *)
  let bus (T t) = Bus.read_only t.bus

  (* Creates a client for the polling state RPC [t] with a new client id. *)
  let create (type query response) ?initial_query (t : (query, response) x) =
    let (T { query_equal; underlying_rpc; response_module } : _ x) = t in
    let module M = (val response_module) in
    (* A fresh response replaces the current value; an update is applied to it. *)
    let fold previous _query response =
      match response with
      | Response.Fresh fresh -> fresh
      | Update diffs ->
        (match previous with
         | Some previous -> M.update previous diffs
         | None ->
           raise_s
             [%message
               "BUG" [%here] "received an update without receiving any previous values"])
    in
    let bus =
      Bus.create_exn
        [%here]
        Bus.Callback_arity.Arity2
        ~on_callback_raise:(fun error ->
          let tag = "exception thrown from inside of polling-state-rpc bus handler" in
          eprint_s [%sexp (Error.tag error ~tag : Error.t)])
        ~on_subscription_after_first_write:Allow_and_send_last_value
    in
    let sequencer = Sequencer.create ~continue_on_error:true () in
    let cleaning_sequencer = Sequencer.create ~continue_on_error:true () in
    T
      { last_seqnum = None
      ; last_query = initial_query
      ; last_query_dispatch_id = Query_dispatch_id.create ()
      ; out = None
      ; sequencer
      ; cleaning_sequencer
      ; client_id = Client_id.create ()
      ; bus
      ; response_module
      ; query_equal
      ; underlying_rpc
      ; fold
      }
  ;;
end
(* Test-only helpers that let a test observe each response a client receives. *)
module Private_for_testing = struct
  module Response' = struct
    (* A response whose update type is hidden. An [Update] carries its own
       sexp converter, so it can be printed without knowing the diff type. *)
    type 'response t =
      | Fresh : 'response -> 'response t
      | Update :
          { t : 'update
          ; sexp_of : 'update -> Sexp.t
          }
          -> 'response t

    let sexp_of_t sexp_of_response response =
      match response with
      | Fresh fresh -> [%sexp Fresh (sexp_of_response fresh : Sexp.t)]
      | Update { t = update; sexp_of = sexp_of_update } ->
        [%sexp Update (sexp_of_update update : Sexp.t)]
    ;;
  end

  (* Like [Client.create], except that [introspect] is called with the previous
     value, the query and the received response before each response is
     folded. *)
  let create_client
    (type query response)
    ?initial_query
    (t : (query, response) t)
    ~introspect
    =
    let make_fold
      : type diff.
        fold:(response option -> query -> (response, diff list) Response.t -> response)
        -> response_module:
             (module Diffable with type t = response and type Update.Diff.t = diff)
        -> (response option -> query -> (response, diff list) Response.t -> response)
      =
      fun ~fold ~response_module ->
      let module M = (val response_module) in
      fun previous query (received : (response, diff list) Response.t) ->
        let introspectable =
          match received with
          | Response.Fresh fresh -> Response'.Fresh fresh
          | Update diffs ->
            Update { t = diffs; sexp_of = [%sexp_of: M.Update.Diff.t list] }
        in
        introspect previous query introspectable;
        fold previous query received
    in
    let (T client) = Client.create ?initial_query t in
    let fold = make_fold ~fold:client.fold ~response_module:client.response_module in
    Client.T { client with fold }
  ;;

  module Response = Response'
end