1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
open Types_
module A = Atomic_
module WSQ = Ws_deque_
module WL = Worker_loop_
include Runner
(** Binding operator for continuation-taking functions:
    [let@ x = f in body] is the same as [f (fun x -> body)]. *)
let ( let@ ) f k = f k
module Id = struct
  type t = unit ref
  (** Identity token of a pool. Tokens are compared physically, and each call
      to {!create} allocates a new cell, so every token is distinct from all
      others. *)

  (** Allocate a fresh token. *)
  let create () : t =
    let cell = ref () in
    Sys.opaque_identity cell

  (** Physical equality between two tokens. *)
  let equal (a : t) (b : t) : bool = a == b
end
type state = {
id_: Id.t;
(** Unique to this pool. Used to make sure tasks stay within the same pool. *)
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
mutable workers: worker_state array; (** Fixed set of workers. *)
(* [workers] starts as an empty array and is filled in once by [create]. *)
main_q: WL.task_full Queue.t;
(** Main queue for tasks coming from the outside *)
(* [main_q] is only pushed to / popped from while [mutex] is held. *)
mutable n_waiting: int;
(* Number of workers currently inside [Condition.wait] on [cond]. It is
updated by [wait_for_condition_], whose precondition is that [mutex] is
held. *)
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
(* Cached flag that [try_wake_someone_] reads without taking [mutex].
NOTE(review): [create] initialises it to [true] although [n_waiting] is 0
at that point. *)
mutex: Mutex.t;
(* Locked around accesses to [main_q] and around signals on [cond]. *)
cond: Condition.t;
(* Signalled when a task is pushed while a waiter is recorded, and
broadcast by [shutdown_]. *)
mutable as_runner: t;
(* Runner handle for this pool. [create] first stores [Runner.dummy] and
then replaces it with [as_runner_ pool]. *)
around_task: WL.around_task;
(* Handed to the worker loop through [worker_ops.around_task]. *)
name: string option;
(* When set, [before_start] names each worker thread "<name>.worker.<idx>". *)
on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
(* Called by [before_start] on each worker thread. *)
on_exit_thread: dom_id:int -> t_id:int -> unit -> unit;
(* Called by [cleanup] on each worker thread. *)
on_exn: exn -> Printexc.raw_backtrace -> unit;
(* Called through [worker_ops.on_exn] with the exception and its backtrace. *)
}
(** internal state *)
and worker_state = {
mutable thread: Thread.t;
(* Placeholder ([Thread.self ()] of the creating thread) until [create]
stores the real worker thread here. *)
idx: int;
(* Index of this worker in [st.workers]. *)
dom_id: int;
(* Domain index passed to [Domain_pool_.run_on] when spawning the thread and
to [Domain_pool_.decr_on] in [cleanup]. *)
st: state;
(* The pool this worker belongs to. *)
q: WL.task_full WSQ.t; (** Work stealing queue *)
rng: Random.State.t;
(* Seeded from [idx]; picks the starting victim in
[try_to_steal_work_once_]. *)
}
(** State for a given worker. Only this worker is
allowed to push into the queue, but other workers
can come and steal from it if they're idle. *)
(** Number of workers in the pool. *)
let[@inline] size_ ({ workers; _ } : state) = Array.length workers
(** Count of pending tasks: the length of the main queue plus the size of
    every worker's local deque. No lock is taken while counting. *)
let num_tasks_ (self : state) : int =
  let in_main_queue = Queue.length self.main_q in
  Array.fold_left
    (fun total (w : worker_state) -> total + WSQ.size w.q)
    in_main_queue self.workers
(** TLS key under which each worker thread stores its own {!worker_state}
(this is done in [before_start]), so that code running on a worker can
retrieve that state later, e.g. when scheduling new sub-tasks. *)
let k_worker_state : worker_state TLS.t = TLS.create ()
(** Look up the calling thread's worker state in {!k_worker_state} using
    [TLS.get_opt]. *)
let[@inline] get_current_worker_ () : worker_state option =
  k_worker_state |> TLS.get_opt
(** Look up the calling thread's worker state in {!k_worker_state} using
    [TLS.get_exn].
    @raise Failure if [TLS.get_exn] raises [TLS.Not_set]. *)
let[@inline] get_current_worker_exn () : worker_state =
  try TLS.get_exn k_worker_state
  with TLS.Not_set ->
    failwith "Moonpool: get_current_runner was called from outside a pool."
(** Signal the pool's condition variable if some worker is recorded as
    waiting. The flag is read before the mutex is taken; the signal itself is
    sent with the mutex held. *)
let[@inline] try_wake_someone_ (self : state) : unit =
  match self.n_waiting_nonzero with
  | false -> ()
  | true ->
    let { mutex; cond; _ } = self in
    Mutex.lock mutex;
    Condition.signal cond;
    Mutex.unlock mutex
(** Push [task] onto the local deque of [self], where other workers may steal
    it. If [WSQ.push] returns [false] (presumably because the deque is full —
    confirm against [Ws_deque_]), the task is put in the pool's main queue
    instead.
    Precondition: this runs on the worker thread whose state is [self]. *)
let schedule_on_current_worker (self : worker_state) task : unit =
  let pool = self.st in
  match WSQ.push self.q task with
  | true -> try_wake_someone_ pool
  | false ->
    (* the local deque refused the task: fall back to the shared queue,
       under the pool mutex *)
    Mutex.lock pool.mutex;
    Queue.push task pool.main_q;
    if pool.n_waiting_nonzero then Condition.signal pool.cond;
    Mutex.unlock pool.mutex
(** Push [task] into the shared queue of this pool, signalling the condition
    variable if some worker is recorded as waiting.
    @raise Shutdown if the pool is no longer active. *)
let schedule_in_main_queue (self : state) task : unit =
  (* scheduling is refused once the pool has been shut down *)
  if not (A.get self.active) then raise Shutdown;
  let { mutex; main_q; cond; _ } = self in
  Mutex.lock mutex;
  Queue.push task main_q;
  if self.n_waiting_nonzero then Condition.signal cond;
  Mutex.unlock mutex
(** Schedule [task] on the pool that [self] belongs to. When the calling
    thread is itself a worker of that same pool (equal [id_]), the task goes to
    that worker's local deque; in every other case it goes to the pool's main
    queue. *)
let schedule_from_w (self : worker_state) (task : WL.task_full) : unit =
  let pool = self.st in
  match get_current_worker_ () with
  | None -> schedule_in_main_queue pool task
  | Some current ->
    if Id.equal pool.id_ current.st.id_ then
      schedule_on_current_worker current task
    else
      (* we are on a worker of some other pool *)
      schedule_in_main_queue pool task
exception Got_task of WL.task_full

(** Try to steal a task from another worker's deque. Every worker is visited
    at most once, starting from a random index; the caller's own deque is
    skipped.
    @raise Got_task if it finds one. *)
let try_to_steal_work_once_ (self : worker_state) : unit =
  let workers = self.st.workers in
  let n_workers = Array.length workers in
  let start = Random.State.int self.rng n_workers in
  for offset = 0 to n_workers - 1 do
    (* the index is reduced modulo [n_workers], hence always in bounds *)
    let victim = Array.unsafe_get workers ((start + offset) mod n_workers) in
    if victim != self then (
      match WSQ.steal victim.q with
      | None -> ()
      | Some task -> raise_notrace (Got_task task)
    )
  done
(** Block on the pool's condition variable, keeping [n_waiting] and
    [n_waiting_nonzero] up to date before and after the wait.
    Precondition: we hold the mutex. *)
let[@inline] wait_for_condition_ (self : state) : unit =
  let now_waiting = self.n_waiting + 1 in
  self.n_waiting <- now_waiting;
  if now_waiting = 1 then self.n_waiting_nonzero <- true;
  Condition.wait self.cond self.mutex;
  (* the counter may have changed while we were blocked: read it again *)
  let still_waiting = self.n_waiting - 1 in
  self.n_waiting <- still_waiting;
  if still_waiting = 0 then self.n_waiting_nonzero <- false
(** Find the next task for worker [self]: first its own deque, then the other
    workers' deques, then the pool's main queue (blocking on the condition
    variable if that queue is empty and the pool is still active).
    @raise WL.No_more_tasks if the main queue is empty and the pool is no
    longer active. *)
let rec get_next_task (self : worker_state) : WL.task_full =
  match WSQ.pop_exn self.q with
  | exception WSQ.Empty -> try_to_steal_from_other_workers_ self
  | task ->
    try_wake_someone_ self.st;
    task

(** One round of stealing; falls back to the main queue if nothing could be
    stolen. *)
and try_to_steal_from_other_workers_ (self : worker_state) : WL.task_full =
  match try_to_steal_work_once_ self with
  | () -> wait_on_main_queue self
  | exception Got_task stolen -> stolen

(** Pop from the main queue under the pool mutex. If the queue is empty and
    the pool is active, wait once on the condition variable, retry the pop, and
    go back to stealing if the queue is still empty. *)
and wait_on_main_queue (self : worker_state) : WL.task_full =
  let pool = self.st in
  Mutex.lock pool.mutex;
  match Queue.take_opt pool.main_q with
  | Some task ->
    Mutex.unlock pool.mutex;
    task
  | None ->
    if not (A.get pool.active) then (
      (* the mutex is released before raising *)
      Mutex.unlock pool.mutex;
      raise WL.No_more_tasks
    ) else (
      wait_for_condition_ pool;
      (* see whether a task showed up in the main queue meanwhile *)
      let found = Queue.take_opt pool.main_q in
      Mutex.unlock pool.mutex;
      match found with
      | Some task -> task
      | None -> try_to_steal_from_other_workers_ self
    )
(** Per-thread setup for worker [self]: call the pool's [on_init_thread] hook,
    fill the thread-local slots (current fiber, current runner, worker state),
    then name the thread if the pool has a name. *)
let before_start (self : worker_state) : unit =
  let pool = self.st in
  let t_id = Thread.id (Thread.self ()) in
  pool.on_init_thread ~dom_id:self.dom_id ~t_id ();
  TLS.set k_cur_fiber _dummy_fiber;
  TLS.set Runner.For_runner_implementors.k_cur_runner pool.as_runner;
  TLS.set k_worker_state self;
  match pool.name with
  | None -> ()
  | Some name ->
    Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx)
(** Per-thread teardown for worker [self]: call [Domain_pool_.decr_on] for the
    worker's domain, then the pool's [on_exit_thread] hook. *)
let cleanup (self : worker_state) : unit =
  let { dom_id; st = pool; _ } = self in
  Domain_pool_.decr_on dom_id;
  let t_id = Thread.self () |> Thread.id in
  pool.on_exit_thread ~dom_id ~t_id ()
(** Operations handed to [WL.worker_loop] for the workers of this pool. Each
    field delegates to the corresponding function of this module or to a field
    of the worker's pool state. *)
let worker_ops : worker_state WL.ops =
  {
    WL.schedule = schedule_from_w;
    runner = (fun (w : worker_state) -> w.st.as_runner);
    get_next_task;
    get_thread_state = get_current_worker_exn;
    around_task = (fun (w : worker_state) -> w.st.around_task);
    on_exn =
      (fun (w : worker_state) (ebt : Exn_bt.t) ->
        (* unpack into the (exn, backtrace) pair the pool's handler expects *)
        w.st.on_exn (Exn_bt.exn ebt) (Exn_bt.bt ebt));
    before_start;
    cleanup;
  }
(** Default thread hook: ignores both identifiers and does nothing. *)
let default_thread_init_exit_ ~dom_id ~t_id () =
  ignore dom_id;
  ignore t_id
(** Shut the pool down. Only the first call has any effect: it flips [active]
    to [false] and wakes every worker blocked on the condition variable.

    @param wait if [true], also join the worker threads before returning.
    The calling thread itself is never joined: when [shutdown_ ~wait:true] is
    invoked from a task running on one of this pool's workers, joining that
    worker would block forever, since a thread cannot terminate while it is
    waiting for its own termination. *)
let shutdown_ ~wait (self : state) : unit =
  if A.exchange self.active false then (
    (* broadcast with the mutex held so waiters re-check [active] *)
    Mutex.lock self.mutex;
    Condition.broadcast self.cond;
    Mutex.unlock self.mutex;
    if wait then (
      let caller = Thread.id (Thread.self ()) in
      Array.iter
        (fun (w : worker_state) ->
          if not (Int.equal (Thread.id w.thread) caller) then
            Thread.join w.thread)
        self.workers
    )
  )
(** Build the [Runner] handle for the pool [self]. Tasks submitted through the
    handle always go to the pool's main queue. *)
let as_runner_ (self : state) : t =
  let shutdown ~wait () = shutdown_ self ~wait in
  let run_async ~fiber f =
    schedule_in_main_queue self (T_start { fiber; f })
  in
  let size () = size_ self in
  let num_tasks () = num_tasks_ self in
  Runner.For_runner_implementors.create ~shutdown ~run_async ~size ~num_tasks
    ()
type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'b) * (t -> 'b -> unit) ->
?num_threads:int ->
?name:string ->
'a
(** Arguments used in {!create}. See {!create} for explanations.

['a] is what remains of the function type once the optional arguments have
been supplied. ['b] is the type of the value returned by the first component
of [around_task] and received by its second component. *)
(** Create a pool and start its worker threads; returns the pool's runner
    handle once every worker thread has been created.

    @param on_init_thread called by each worker thread in [before_start], with
    the worker's domain index and thread id. Does nothing by default.
    @param on_exit_thread called by each worker thread in [cleanup]. Does
    nothing by default.
    @param on_exn handler made available to the worker loop through
    [worker_ops.on_exn]. Ignores its arguments by default.
    @param around_task pair of functions wrapped into [WL.AT_pair]; defaults
    to a pair of no-ops.
    @param num_threads forwarded to [Util_pool_.num_threads], whose result is
    the number of workers.
    @param name if provided, used to name the worker threads. *)
let create ?(on_init_thread = default_thread_init_exit_)
    ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
    ?around_task ?num_threads ?name () : t =
  let pool_id_ = Id.create () in
  let around_task =
    match around_task with
    | None -> WL.AT_pair (ignore, fun _ _ -> ())
    | Some (before, after) -> WL.AT_pair (before, after)
  in

  let num_domains = Domain_pool_.max_number_of_domains () in
  let num_threads = Util_pool_.num_threads ?num_threads () in
  (* random starting domain, so that worker [idx] lands on domain
     [(offset + idx) mod num_domains] *)
  let offset = Random.int num_domains in

  let pool =
    {
      id_ = pool_id_;
      active = A.make true;
      workers = [||];
      main_q = Queue.create ();
      n_waiting = 0;
      (* NOTE(review): starts as [true] although [n_waiting] is 0; kept as
         found — confirm whether this is intentional. *)
      n_waiting_nonzero = true;
      mutex = Mutex.create ();
      cond = Condition.create ();
      around_task;
      on_exn;
      on_init_thread;
      on_exit_thread;
      name;
      as_runner = Runner.dummy;
    }
  in
  (* the runner handle needs the pool, so it is installed afterwards *)
  pool.as_runner <- as_runner_ pool;

  (* worker threads are created from other domains; each one is sent back
     through this queue together with the index of its worker *)
  let receive_threads = Bb_queue.create () in

  pool.workers <-
    Array.init num_threads (fun idx ->
        {
          st = pool;
          (* placeholder, replaced below once the real thread exists *)
          thread = Thread.self ();
          q = WSQ.create ~dummy:WL._dummy_task ();
          rng = Random.State.make [| idx |];
          dom_id = (offset + idx) mod num_domains;
          idx;
        });

  let spawn idx (w : worker_state) =
    Domain_pool_.run_on w.dom_id (fun () ->
        let thread =
          Thread.create (WL.worker_loop ~block_signals:true ~ops:worker_ops) w
        in
        Bb_queue.push receive_threads (idx, thread))
  in
  Array.iteri spawn pool.workers;

  (* collect exactly one thread per worker, in whatever order they arrive *)
  for _j = 1 to num_threads do
    let idx, thread = Bb_queue.pop receive_threads in
    pool.workers.(idx).thread <- thread
  done;

  pool.as_runner
(** [with_ () f] creates a pool with the given options, calls [f pool], and
    shuts the pool down when [f] returns or raises. Returns the result of
    [f]. *)
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
    ?name () f =
  let pool =
    create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
      ?name ()
  in
  let finally () = shutdown pool in
  Fun.protect ~finally (fun () -> f pool)