Source file polling_state_rpc.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
open! Core
open! Async_kernel
open! Async_rpc_kernel

(* Sequence numbers tag each response that the server caches for a client, so
   that a later request can name the response it wants diffs against. *)
module Seqnum = struct
  include Unique_id.Int ()

  (* Sentinel value: a stable [Query] request whose [last_seqnum] equals
     [forget] is decoded as a [Forget_client] request (see
     [Request.unstable_of_stable]). Presumably [create] never hands out this
     value -- TODO confirm against [Unique_id.Int]. *)
  let forget = of_int_exn (-1)
end

(* Identifies one call to [Client.dispatch] or [Client.forget_on_server]. The
   client stores the id of the most recent call so that an older dispatch can
   notice it has been superseded. *)
module Query_dispatch_id = Unique_id.Int ()

(* A [Client_id.t] is unique per [Client.t] for a given connection, but ids are
   not guaranteed to be unique between connections. We need this because a user
   could have multiple [Client.t]s for the same Rpc; the server-side cache is
   keyed by connection first and by [Client_id.t] second. *)
module Client_id = Unique_id.Int ()

(* Server-side cache of the last response sent to each client. Entries are
   grouped per connection (connections are compared with [phys_equal]) and then
   keyed by [Client_id.t]. *)
module Cache : sig
  type ('a, 'client_state) t
  type ('a, 'client_state) per_client

  (* [on_client_forgotten] is called with a client's state when the client is
     removed by [remove_and_trigger_cancel], and for every client still stored
     for a connection once that connection has finished closing. *)
  val create : on_client_forgotten:('client_state -> unit) -> ('a, 'client_state) t

  (* Returns the entry for [client_id] on [connection]. If the entry does not
     exist yet it is created, calling [create_client_state connection_state]
     to build its client state. *)
  val find
    :  ('a, 'client_state) t
    -> connection_state:'connection_state
    -> connection:Rpc.Connection.t
    -> client_id:Client_id.t
    -> create_client_state:('connection_state -> 'client_state)
    -> ('a, 'client_state) per_client

  (* Drops the entry for [client_id] on [connection], if there is one, calls
     [on_client_forgotten] on its state and resolves its pending
     [wait_for_cancel]. Does nothing when no such entry exists. *)
  val remove_and_trigger_cancel
    :  ('a, 'client_state) t
    -> connection:Rpc.Connection.t
    -> client_id:Client_id.t
    -> unit

  (* The value most recently stored with [set], with the seqnum [set] returned. *)
  val last_response : ('a, 'client_state) per_client -> ('a * Seqnum.t) option

  (* Stores a new value and returns the fresh seqnum assigned to it. *)
  val set : ('a, 'client_state) per_client -> 'a -> Seqnum.t

  (* Determined by the next [trigger_cancel] (or [remove_and_trigger_cancel])
     on this entry. *)
  val wait_for_cancel : ('a, 'client_state) per_client -> unit Deferred.t

  (* Resolves every outstanding [wait_for_cancel] and re-arms the entry so that
     later calls to [wait_for_cancel] wait for the next cancellation. Never
     raises, even if the entry has already been cancelled by
     [remove_and_trigger_cancel]. *)
  val trigger_cancel : ('a, 'client_state) per_client -> unit
  val client_state : ('a, 'client_state) per_client -> 'client_state
end = struct
  type ('a, 'client_state) per_client =
    { mutable last_response : ('a * Seqnum.t) option
    ; mutable cancel : unit Ivar.t
    ; client_state : 'client_state
    }

  type ('a, 'client_state) per_connection =
    ('a, 'client_state) per_client Client_id.Table.t

  type ('a, 'client_state) t =
    { connections : (Rpc.Connection.t * ('a, 'client_state) per_connection) Bag.t
    ; on_client_forgotten : 'client_state -> unit
    }

  let create ~on_client_forgotten = { connections = Bag.create (); on_client_forgotten }

  (* Linear scan of the bag for the table belonging to [connection]. *)
  let find_by_connection connections ~connection =
    Bag.find connections ~f:(fun (conn, _) -> phys_equal connection conn)
    |> Option.map ~f:snd
  ;;

  let find t ~connection_state ~connection ~client_id ~create_client_state =
    let per_connection =
      match find_by_connection t.connections ~connection with
      | Some per_connection -> per_connection
      | None ->
        let result = Client_id.Table.create () in
        let elt = Bag.add t.connections (connection, result) in
        (* Once the connection is closed, forget every client that is still
           registered on it and drop the whole per-connection table. *)
        Deferred.upon (Rpc.Connection.close_finished connection) (fun () ->
          Hashtbl.iter result ~f:(fun { client_state; _ } ->
            t.on_client_forgotten client_state);
          Bag.remove t.connections elt);
        result
    in
    Hashtbl.find_or_add per_connection client_id ~default:(fun () ->
      { last_response = None
      ; cancel = Ivar.create ()
      ; client_state = create_client_state connection_state
      })
  ;;

  let remove_and_trigger_cancel t ~connection ~client_id =
    Option.iter (find_by_connection t.connections ~connection) ~f:(fun per_connection ->
      Option.iter (Hashtbl.find_and_remove per_connection client_id) ~f:(fun per_client ->
        t.on_client_forgotten per_client.client_state;
        (* The ivar of an entry that is still in the table is always empty:
           [trigger_cancel] installs a fresh one every time it fires. It is
           deliberately NOT replaced here, since the entry is being dropped. *)
        Ivar.fill per_client.cancel ()))
  ;;

  let last_response per_client = per_client.last_response

  let set per_client data =
    let seqnum = Seqnum.create () in
    per_client.last_response <- Some (data, seqnum);
    seqnum
  ;;

  let wait_for_cancel per_client = Ivar.read per_client.cancel

  let trigger_cancel per_client =
    (* A request that is still in flight keeps a reference to its entry even
       after [remove_and_trigger_cancel] has dropped it from the table and
       filled [cancel]. When such a request completes it calls
       [trigger_cancel]; a plain [Ivar.fill] would then raise because the ivar
       is already full, so only fill it when it is still empty. *)
    Ivar.fill_if_empty per_client.cancel ();
    per_client.cancel <- Ivar.create ()
  ;;

  let client_state per_client = per_client.client_state
end

module Request = struct
  (* Two flavours of the request type live side by side so that one can evolve
     while the other stays frozen. [Stable] is the wire format and must only be
     used as an intermediate step of (de)serialization; [Unstable] is what the
     rest of this file works with.

     Extending [Unstable] means finding a way to encode/decode the additional
     data into/out of the existing [Stable] shape, via [unstable_of_stable] and
     [stable_of_unstable]. *)

  module Unstable = struct
    type 'query t =
      | Query of
          { last_seqnum : Seqnum.t option
          ; query : 'query
          ; client_id : Client_id.t
          }
      | Cancel_ongoing of Client_id.t
      | Forget_client of
          { query : 'query
          ; client_id : Client_id.t
          }
  end

  module Stable = struct
    type 'query t =
      | Query of
          { last_seqnum : Seqnum.t option
          ; query : 'query
          ; client_id : Client_id.t
          }
      | Cancel_ongoing of Client_id.t
    [@@deriving bin_io]

    let%expect_test _ =
      print_endline [%bin_digest: unit t];
      [%expect {| 4eb554fadd7eded37e4da89efd208c52 |}]
    ;;
  end

  (* Decodes a wire request. A [Query] whose [last_seqnum] is the
     [Seqnum.forget] sentinel is really a [Forget_client] request in disguise. *)
  let unstable_of_stable (stable : 'query Stable.t) : 'query Unstable.t =
    match stable with
    | Cancel_ongoing client_id -> Cancel_ongoing client_id
    | Query { last_seqnum; query; client_id } ->
      let is_forget_marker =
        Option.exists last_seqnum ~f:(fun seqnum -> Seqnum.equal seqnum Seqnum.forget)
      in
      if is_forget_marker
      then Forget_client { query; client_id }
      else Query { last_seqnum; query; client_id }
  ;;

  (* Encodes a request for the wire. [Forget_client] has no constructor of its
     own in [Stable], so it travels as a [Query] carrying the [Seqnum.forget]
     sentinel. *)
  let stable_of_unstable (unstable : 'query Unstable.t) : 'query Stable.t =
    match unstable with
    | Forget_client { query; client_id } ->
      Query { last_seqnum = Some Seqnum.forget; query; client_id }
    | Cancel_ongoing client_id -> Cancel_ongoing client_id
    | Query { last_seqnum; query; client_id } -> Query { last_seqnum; query; client_id }
  ;;
end

module Response = struct
  (* The payload of a successful poll: either a complete value ([Fresh]) or an
     update to be applied to the value the client already holds ([Update]). *)
  type ('response, 'update) t =
    | Fresh of 'response
    | Update of 'update
  [@@deriving bin_io, sexp_of]

  (* Pins the bin_io shape of [t]; a change in the digest means the wire format
     changed. *)
  let%expect_test _ =
    print_endline [%bin_digest: (int, string) t];
    [%expect {| 13ef8c5223a0ea284c72512be32e5c09 |}]
  ;;

  (* What the underlying RPC actually returns. [Response] carries the payload
     together with the seqnum the server assigned to it;
     [Cancellation_successful] is the reply to [Cancel_ongoing] and
     [Forget_client] requests. *)
  type ('response, 'update) pair =
    | Response of
        { new_seqnum : Seqnum.t
        ; response : ('response, 'update) t
        }
    | Cancellation_successful
  [@@deriving bin_io]

  (* Pins the bin_io shape of [pair]. *)
  let%expect_test _ =
    print_endline [%bin_digest: (int, string) pair];
    [%expect {| 8bc63a85561d87b693d15e78c64e1008 |}]
  ;;
end

(* A response type usable with a polling state RPC: it must be serializable
   with bin_io and satisfy [Diffable.S], whose [diffs] and [update] functions
   and [Update] submodule are used elsewhere in this file. *)
module type Diffable = sig
  type t [@@deriving bin_io]

  include Diffable.S with type t := t
end

(* A polling state RPC definition. The diff type of the response is hidden as
   an existential, so users only see the query and response types.
   - [response_module] provides diffing and serialization for the response.
   - [query_equal] decides whether two queries count as the same query.
   - [underlying_rpc] is the plain RPC that carries requests and responses. *)
type ('query, 'response) t =
  | T :
      { response_module :
          (module Diffable with type t = 'response and type Update.Diff.t = 'diff)
      ; query_equal : 'query -> 'query -> bool
      ; underlying_rpc :
          ('query Request.Stable.t, ('response, 'diff list) Response.pair) Rpc.Rpc.t
      }
      -> ('query, 'response) t

(* The name of the RPC that this polling state RPC is carried over. *)
let name (T { underlying_rpc = rpc; _ }) = rpc |> Rpc.Rpc.name

(* Defines a polling state RPC.
   - [name] and [version] identify the underlying RPC.
   - [query_equal] is stored and later used to tell whether a query changed.
   - [bin_query] serializes the user's query; it is wrapped in the stable
     request type before being handed to the underlying RPC.
   - The first-class module supplies serialization and diffing for responses. *)
let create
      (type a)
      ~name
      ~version
      ~query_equal
      ~bin_query
      (module M : Diffable with type t = a)
  =
  let underlying_rpc =
    Rpc.Rpc.create
      ~name
      ~version
      ~bin_query:(Request.Stable.bin_t bin_query)
      ~bin_response:(Response.bin_pair M.bin_t M.Update.bin_t)
  in
  T { query_equal; response_module = (module M); underlying_rpc }
;;

(* Builds the server-side implementation of a polling state RPC in which every
   client gets its own client state.

   - [on_client_and_server_out_of_sync] is called with a diagnostic message
     when the seqnum situation reported by the client does not match what the
     server has cached; a fresh response is sent in that case.
   - [create_client_state] is called with the connection state when a cache
     entry is created for a client.
   - [on_client_forgotten] (default: [ignore]) is handed to the [Cache].
   - [for_first_request] (default: [f]) computes the response when there is no
     usable previous response, or when the query differs from the previous one
     according to [query_equal]; [f] computes it otherwise. *)
let implement_with_client_state
      (type response)
      ~on_client_and_server_out_of_sync
      ~create_client_state
      ?(on_client_forgotten = ignore)
      ?for_first_request
      t
      f
  =
  (* Make a new function to introduce a locally abstract type for diff. *)
  let do_implement
    : type diff.
      response_module:
        (module Diffable with type t = response and type Update.Diff.t = diff)
      -> underlying_rpc:(_, (response, diff list) Response.pair) Rpc.Rpc.t
      -> _
    =
    fun ~response_module ~underlying_rpc ~query_equal ->
      let module M = (val response_module) in
      let for_first_request = Option.value for_first_request ~default:f in
      (* [init] produces a complete response. The value is returned twice: once
         as the payload to send and once as the data to remember in the cache. *)
      let init connection_state client_state query =
        let%map response = for_first_request connection_state client_state query in
        response, response
      in
      (* [updates] produces diffs against the previously cached response, and
         also returns the new full response so it can be cached. *)
      let updates connection_state client_state ~prev:(prev_query, prev_response) query =
        let f = if query_equal prev_query query then f else for_first_request in
        let%map new_ = f connection_state client_state query in
        let diff = M.diffs ~from:prev_response ~to_:new_ in
        Response.Update diff, new_
      in
      let cache = Cache.create ~on_client_forgotten in
      Rpc.Rpc.implement underlying_rpc (fun (connection_state, connection) request ->
        match Request.unstable_of_stable request with
        | Cancel_ongoing client_id ->
          (* Note that [Cache.find] creates the entry if the client is unknown. *)
          let per_client =
            Cache.find cache ~connection_state ~connection ~client_id ~create_client_state
          in
          Cache.trigger_cancel per_client;
          return Response.Cancellation_successful
        | Forget_client { query = _; client_id } ->
          Cache.remove_and_trigger_cancel cache ~connection ~client_id;
          return Response.Cancellation_successful
        | Query { last_seqnum; query; client_id } ->
          let per_client =
            Cache.find cache ~connection_state ~connection ~client_id ~create_client_state
          in
          (* [prev] is the cached (query, response) pair to diff against. It is
             only used when the seqnum sent by the client matches the seqnum of
             the cached response. *)
          let prev =
            match Cache.last_response per_client, last_seqnum with
            | Some (prev, prev_seqnum), Some last_seqnum ->
              (match Seqnum.equal prev_seqnum last_seqnum with
               | true -> Some prev
               | false ->
                 let rpc_name = Rpc.Rpc.name underlying_rpc in
                 let rpc_version = Rpc.Rpc.version underlying_rpc in
                 on_client_and_server_out_of_sync
                   [%message
                     [%here]
                       "A polling state RPC client has requested diffs from a seqnum that \
                        the server does not have, so the server is sending a fresh \
                        response instead. This likely means that the client had trouble \
                        receiving the last RPC response."
                       (rpc_name : string)
                       (rpc_version : int)];
                 None)
            (* Nothing is cached for this client: send a fresh response without
               reporting anything. *)
            | None, Some _ | None, None -> None
            | Some _, None ->
              let rpc_name = Rpc.Rpc.name underlying_rpc in
              let rpc_version = Rpc.Rpc.version underlying_rpc in
              on_client_and_server_out_of_sync
                [%message
                  [%here]
                    "A polling state RPC client has requested a fresh response, but the \
                     server expected it to have the seqnum of the latest diffs. The server \
                     will send a fresh response as requested. This likely means that the \
                     client had trouble receiving the last RPC response."
                    (rpc_name : string)
                    (rpc_version : int)];
              None
          in
          (* The payload to send, paired with the full response to cache. *)
          let response =
            let client_state = Cache.client_state per_client in
            match prev with
            | Some prev ->
              let%map response, userdata =
                updates connection_state client_state ~prev query
              in
              response, userdata
            | None ->
              let%map response, userdata = init connection_state client_state query in
              Response.Fresh response, userdata
          in
          (* Race the computation against a cancellation of this client. *)
          let response_or_cancelled =
            choose
              [ choice response (fun r -> `Response r)
              ; choice (Cache.wait_for_cancel per_client) (fun () -> `Cancelled)
              ]
          in
          (match%map response_or_cancelled with
           | `Response (response, userdata) ->
             let new_seqnum = Cache.set per_client (query, userdata) in
             (* Resolves [wait_for_cancel] for any other request of this client
                that is still waiting in [choose]. *)
             Cache.trigger_cancel per_client;
             Response.Response { new_seqnum; response }
           | `Cancelled ->
             (* A cancelled request is completed by raising. *)
             Exn.raise_without_backtrace (Failure "this request was cancelled")))
  in
  let (T { query_equal; response_module; underlying_rpc }) = t in
  do_implement ~response_module ~underlying_rpc ~query_equal
;;

(* Server-side implementation for users that do not need per-client state: the
   client state is [unit], and both [f] and the optional [for_first_request]
   receive only the connection state and the query. *)
let implement ~on_client_and_server_out_of_sync ?for_first_request t f =
  (* Adapts a stateless handler to the three-argument shape expected by
     [implement_with_client_state] by dropping the client state. *)
  let without_client_state handler connection_state _client_state query =
    handler connection_state query
  in
  implement_with_client_state
    ~on_client_and_server_out_of_sync
    ~create_client_state:(fun _ -> ())
    ?for_first_request:(Option.map for_first_request ~f:without_client_state)
    t
    (without_client_state f)
;;

(* Server-side implementation in which responses are pushed through a bus.

   - [f connection_state client_state query] returns the bus to subscribe to
     for [query]. It is called whenever a response has to be computed from
     scratch (first request, or the query changed).
   - [create_client_state] builds the user-visible part of the client state.
   - [on_client_forgotten] (default: [ignore]) is called with the user-visible
     client state when the client is forgotten.

   Internally each client also owns an mvar holding the most recent value
   published on the bus that has not been sent yet, and a reference to the
   function that removes the current bus subscription. *)
let implement_via_bus
      ~on_client_and_server_out_of_sync
      ~create_client_state
      ?(on_client_forgotten = ignore)
      rpc
      f
  =
  implement_with_client_state
    ~on_client_and_server_out_of_sync
    ~create_client_state:(fun connection_state ->
      Mvar.create (), ref (fun () -> ()), create_client_state connection_state)
    ~on_client_forgotten:
      (fun (_most_recent_unsent_response, unsubscribe, client_state) ->
        (* Always drop the bus subscription of a forgotten client, whether or
           not the caller supplied a callback. Otherwise the subscriber stays
           attached to the bus and keeps writing into an mvar nobody reads. *)
        !unsubscribe ();
        on_client_forgotten client_state)
    rpc
    ~for_first_request:
      (fun
        connection_state (most_recent_unsent_response, unsubscribe, client_state) query ->
        (* Since this is the first response for this query, we need to clean up
           the mvar and the bus subscriber from the previous query. *)
        let (_ : 'response option) = Mvar.take_now most_recent_unsent_response in
        !unsubscribe ();
        (* Then we can set up the bus subscription to the new query. *)
        let bus = f connection_state client_state query in
        let subscriber =
          Bus.subscribe_exn bus [%here] ~f:(fun response ->
            Mvar.set most_recent_unsent_response response)
        in
        (unsubscribe := fun () -> Bus.unsubscribe bus subscriber);
        (* Finally, we'll wait for the bus to publish something to the mvar so
           we can return it as the response. *)
        Mvar.take most_recent_unsent_response)
    (fun _connection_state
      (most_recent_unsent_response, _unsubscribe, _client_state)
      _query ->
      (* For all polls except the first one for each query, we can simply wait
         for the current bus to publish a new response. We ignore the query,
         because we know that the bus subscription only ever corresponds to
         the current query. *)
      Mvar.take most_recent_unsent_response)
;;

(* The client side of a polling state RPC. A client remembers the last query
   and the last response it received, so that the server only needs to send
   diffs on subsequent polls. *)
module Client = struct
  (* Alias for the RPC definition type, which is shadowed by [t] below. *)
  type ('a, 'b) x = ('a, 'b) t

  (* - [last_seqnum]: seqnum of the last response received, sent back with the
       next query.
     - [last_query]: the last query dispatched (or the [initial_query]).
     - [last_query_dispatch_id]: id of the most recent [dispatch] or
       [forget_on_server] call.
     - [out]: the last folded response.
     - [sequencer]: serializes outgoing queries; replaced whenever it is killed.
     - [cleaning_sequencer]: serializes the operations that kill and replace
       [sequencer].
     - [bus]: receives every (query, response) pair produced by a dispatch.
     - [fold]: combines the previous response with what the server sent. *)
  type ('query, 'response, 'diff) unpacked =
    { mutable last_seqnum : Seqnum.t option
    ; mutable last_query : 'query option
    ; mutable last_query_dispatch_id : Query_dispatch_id.t
    ; mutable out : 'response option
    ; mutable sequencer : unit Sequencer.t
    ; cleaning_sequencer : unit Sequencer.t
    ; client_id : Client_id.t
    ; bus : ('query -> 'response -> unit) Bus.Read_write.t
    ; response_module :
        (module Diffable with type t = 'response and type Update.Diff.t = 'diff)
    ; query_equal : 'query -> 'query -> bool
    ; underlying_rpc :
        ('query Request.Stable.t, ('response, 'diff list) Response.pair) Rpc.Rpc.t
    ; fold : 'response option -> 'query -> ('response, 'diff list) Response.t -> 'response
    }

  (* Hides the diff type from users of the client. *)
  type ('query, 'response) t =
    | T : ('query, 'response, 'diff) unpacked -> ('query, 'response) t

  (* Sends [request] over the underlying RPC after converting it to its stable
     wire representation. *)
  let dispatch_underlying t connection request =
    Rpc.Rpc.dispatch t.underlying_rpc connection (Request.stable_of_unstable request)
  ;;

  (* If a request for a different query is currently running, cancels it on the
     server and replaces the sequencer, so that the new query [q] does not have
     to wait behind requests for the old one. *)
  let cancel_current_if_query_changed t q connection =
    (* use this cleaning_sequencer as a lock to prevent two subsequent queries
       from obliterating the same sequencer. *)
    Throttle.enqueue t.cleaning_sequencer (fun () ->
      match t.last_query with
      (* If there was no previous query, we can keep the existing sequencer
         because it is totally empty. *)
      | None -> Deferred.Or_error.ok_unit
      (* If the sequencer isn't running anything right now, then we can co-opt it
         for this query. *)
      | Some _ when Throttle.num_jobs_running t.sequencer = 0 -> Deferred.Or_error.ok_unit
      (* If the current query is the same as the last one, then we're fine
         because the sequencer is already running that kind of query. *)
      | Some q' when t.query_equal q q' -> Deferred.Or_error.ok_unit
      (* Otherwise, we need to cancel the current request, kill every task currently
         in the sequencer, and make a new sequencer for this query. *)
      | _ ->
        Throttle.kill t.sequencer;
        let%bind cancel_response =
          dispatch_underlying t connection (Cancel_ongoing t.client_id)
        in
        let%map () = Throttle.cleaned t.sequencer in
        t.sequencer <- Sequencer.create ~continue_on_error:true ();
        (match cancel_response with
         | Ok (Response _) ->
           [%message "BUG" [%here] "regular response caused by cancellation"]
           |> Or_error.error_s
         | Ok Cancellation_successful -> Ok ()
         | Error e -> Error e))
  ;;

  (* Performs one poll: records [query] as the last query, sends it together
     with the last seqnum, and on success folds the response into [t.out],
     stores the new seqnum and publishes the result on the bus. *)
  let dispatch' t connection query =
    t.last_query <- Some query;
    let%bind response =
      let last_seqnum = t.last_seqnum in
      let client_id = t.client_id in
      dispatch_underlying t connection (Query { query; last_seqnum; client_id })
    in
    match response with
    | Ok (Response { response; new_seqnum }) ->
      let new_out = t.fold t.out query response in
      t.last_seqnum <- Some new_seqnum;
      t.out <- Some new_out;
      Bus.write2 t.bus query new_out;
      return (Ok new_out)
    | Ok Cancellation_successful ->
      [%message "BUG" [%here] "cancellation caused by regular request"]
      |> Or_error.error_s
      |> return
    | Error e -> return (Error e)
  ;;

  (* These are errors defined by sequencer; squish them into Or_error.t *)
  let collapse_sequencer_error = function
    | `Ok result_or_error -> result_or_error
    | `Aborted -> Error (Error.of_string "Request aborted")
    | `Raised exn -> Error (Error.of_exn exn)
  ;;

  (* Use a sequencer to ensure that there aren't any concurrent outgoing
     requests. If another [dispatch] or [forget_on_server] call was made while
     this one was waiting for [cancel_current_if_query_changed], this one is
     abandoned and reported as aborted. *)
  let dispatch (T t) connection query =
    let query_dispatch_id = Query_dispatch_id.create () in
    t.last_query_dispatch_id <- query_dispatch_id;
    let%bind.Deferred.Or_error () = cancel_current_if_query_changed t query connection in
    (if not (Query_dispatch_id.equal t.last_query_dispatch_id query_dispatch_id)
     then return `Aborted
     else Throttle.enqueue' t.sequencer (fun () -> dispatch' t connection query))
    >>| collapse_sequencer_error
  ;;

  (* Polls again with the last query. Returns an error if no query has been
     recorded yet. *)
  let redispatch (T t) connection =
    Throttle.enqueue' t.sequencer (fun () ->
      match t.last_query with
      | Some q -> dispatch' t connection q
      | None ->
        "[redispatch] called before a query was set or a regular dispatch had completed"
        |> Error.of_string
        |> Error
        |> Deferred.return)
    >>| collapse_sequencer_error
  ;;

  (* Asks the server to drop everything it stores for this client. Also kills
     the current sequencer and replaces it with a new one. Bumping
     [last_query_dispatch_id] first makes any [dispatch] that is still waiting
     for its cancellation step report itself as aborted. *)
  let forget_on_server (T t) connection =
    t.last_query_dispatch_id <- Query_dispatch_id.create ();
    match t.last_query with
    (* If there was no previous query, then the server has nothing to forget. *)
    | None -> Deferred.Or_error.ok_unit
    | Some query ->
      Throttle.enqueue' t.cleaning_sequencer (fun () ->
        Throttle.kill t.sequencer;
        let%bind forget_response =
          dispatch_underlying
            t
            connection
            (Forget_client { query; client_id = t.client_id })
        in
        let%map () = Throttle.cleaned t.sequencer in
        t.sequencer <- Sequencer.create ~continue_on_error:true ();
        match forget_response with
        | Ok (Response _) ->
          (* It is possible that new clients of old servers will get
             spammed with this message if [forget_on_server] is called a lot. However,
             this sounds like an unlikely case to me, since old servers probably won't
             exist for much longer; in addition, spamming this message is not terrible. *)
          Or_error.error_s
            [%message
              "BUG"
                [%here]
                {|Regular response caused by forget request. This can also happen if the server is old and does not support forget requests, in which case this is not a bug.|}]
        | Ok Cancellation_successful -> Ok ()
        | Error e -> Error e)
      >>| collapse_sequencer_error
  ;;

  (* The last query recorded by this client, if any. *)
  let query (T { last_query; _ }) = last_query

  (* Read-only view of the bus on which every (query, response) pair is
     published. *)
  let bus (T { bus; _ }) = Bus.read_only bus

  (* Creates a client for the RPC definition [t]. [initial_query], when given,
     becomes the client's last query before anything has been dispatched. *)
  let create (type query response) ?initial_query (t : (query, response) x) =
    let (T { query_equal; underlying_rpc; response_module } : _ x) = t in
    let module M = (val response_module) in
    (* Default fold: a fresh response replaces the previous value, while diffs
       are applied to it. Diffs without a previous value are a bug. *)
    let f prev _query = function
      | Response.Fresh r -> r
      | Update diffs ->
        (match prev with
         | None ->
           raise_s
             [%message
               "BUG" [%here] "received an update without receiving any previous values"]
         | Some prev -> M.update prev diffs)
    in
    let bus =
      Bus.create_exn
        [%here]
        Bus.Callback_arity.Arity2
        ~on_callback_raise:(fun error ->
          (* Exceptions raised by subscribers are tagged and printed with
             [eprint_s]. *)
          let tag = "exception thrown from inside of polling-state-rpc bus handler" in
          error |> Error.tag ~tag |> [%sexp_of: Error.t] |> eprint_s)
        ~on_subscription_after_first_write:Allow_and_send_last_value
    in
    let sequencer = Sequencer.create ~continue_on_error:true () in
    let cleaning_sequencer = Sequencer.create ~continue_on_error:true () in
    T
      { bus
      ; sequencer
      ; cleaning_sequencer
      ; last_seqnum = None
      ; last_query = initial_query
      ; last_query_dispatch_id = Query_dispatch_id.create ()
      ; out = None
      ; response_module
      ; query_equal
      ; underlying_rpc
      ; fold = f
      ; client_id = Client_id.create ()
      }
  ;;
end

(* Hooks that let tests observe what the server actually sent to a client. *)
module Private_for_testing = struct
  module Response' = struct
    (* Like [Response.t], except that the update type is hidden: an [Update]
       carries the update value together with the function that converts it to
       a sexp, so it can be printed without knowing its type. *)
    type 'response t =
      | Fresh : 'response -> 'response t
      | Update :
          { t : 'update
          ; sexp_of : 'update -> Sexp.t
          }
          -> 'response t

    let sexp_of_t sexp_of_a = function
      | Fresh a -> [%sexp Fresh (sexp_of_a a : Sexp.t)]
      | Update { t; sexp_of } -> [%sexp Update (sexp_of t : Sexp.t)]
    ;;
  end

  (* Like [Client.create], but [introspect prev query response] is called for
     every response the client folds, just before the regular fold runs. *)
  let create_client
        (type query response)
        ?initial_query
        (t : (query, response) t)
        ~introspect
    =
    (* Make a new function to introduce a locally abstract type for [diff] *)
    let make_fold
      : type diff.
        fold:(response option -> query -> (response, diff list) Response.t -> response)
        -> response_module:
             (module Diffable with type t = response and type Update.Diff.t = diff)
        -> (response option -> query -> (response, diff list) Response.t -> response)
      =
      fun ~fold ~response_module ->
        let module M = (val response_module) in
        let fold prev query (resp : (response, diff list) Response.t) =
          (* Repackage the response so that [introspect] does not need to know
             the diff type. *)
          let resp' =
            match resp with
            | Response.Fresh r -> Response'.Fresh r
            | Update diffs ->
              Update { t = diffs; sexp_of = [%sexp_of: M.Update.Diff.t list] }
          in
          introspect prev query resp';
          fold prev query resp
        in
        fold
    in
    let (T client) = Client.create ?initial_query t in
    let new_fold = make_fold ~fold:client.fold ~response_module:client.response_module in
    Client.T { client with fold = new_fold }
  ;;

  module Response = Response'
end