Source file moonpool_forkjoin.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
module A = Moonpool.Atomic
module Suspend_ = Moonpool.Private.Suspend_
module Domain_ = Moonpool_private.Domain_
module State_ = struct
type error = exn * Printexc.raw_backtrace
type 'a or_error = ('a, error) result
type ('a, 'b) t =
| Init
| Left_solved of 'a or_error
| Right_solved of 'b or_error * Suspend_.suspension
| Both_solved of 'a or_error * 'b or_error
let get_exn_ (self : _ t A.t) =
match A.get self with
| Both_solved (Ok a, Ok b) -> a, b
| Both_solved (Error (exn, bt), _) | Both_solved (_, Error (exn, bt)) ->
Printexc.raise_with_backtrace exn bt
| _ -> assert false
let rec set_left_ (self : _ t A.t) (left : _ or_error) =
let old_st = A.get self in
match old_st with
| Init ->
let new_st = Left_solved left in
if not (A.compare_and_set self old_st new_st) then (
Domain_.relax ();
set_left_ self left
)
| Right_solved (right, cont) ->
let new_st = Both_solved (left, right) in
if not (A.compare_and_set self old_st new_st) then (
Domain_.relax ();
set_left_ self left
) else
cont (Ok ())
| Left_solved _ | Both_solved _ -> assert false
let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit =
let old_st = A.get self in
match old_st with
| Left_solved left ->
let new_st = Both_solved (left, right) in
if not (A.compare_and_set self old_st new_st) then set_right_ self right
| Init ->
Suspend_.suspend
{
Suspend_.handle =
(fun ~run:_ ~resume suspension ->
while
let old_st = A.get self in
match old_st with
| Init ->
not
(A.compare_and_set self old_st
(Right_solved (right, suspension)))
| Left_solved left ->
A.set self (Both_solved (left, right));
resume suspension (Ok ());
false
| Right_solved _ | Both_solved _ -> assert false
do
()
done);
}
| Right_solved _ | Both_solved _ -> assert false
end
let both f g : _ * _ =
let module ST = State_ in
let st = A.make ST.Init in
let runner =
match Runner.get_current_runner () with
| None -> invalid_arg "Fork_join.both must be run from within a runner"
| Some r -> r
in
Runner.run_async runner (fun () ->
try
let res = f () in
ST.set_left_ st (Ok res)
with exn ->
let bt = Printexc.get_raw_backtrace () in
ST.set_left_ st (Error (exn, bt)));
let res_right =
try Ok (g ())
with exn ->
let bt = Printexc.get_raw_backtrace () in
Error (exn, bt)
in
ST.set_right_ st res_right;
ST.get_exn_ st
let both_ignore f g = ignore (both f g : _ * _)
let for_ ?chunk_size n (f : int -> int -> unit) : unit =
if n > 0 then (
let has_failed = A.make false in
let missing = A.make n in
let chunk_size =
match chunk_size with
| Some cs -> max 1 (min n cs)
| None ->
max 1 (1 + (n / Moonpool.Private.num_domains ()))
in
let start_tasks ~run ~resume (suspension : Suspend_.suspension) =
let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with
| () ->
if A.fetch_and_add missing (-len_range) = len_range then
resume suspension (Ok ())
| exception exn ->
let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then
resume suspension (Error (exn, bt))
in
let i = ref 0 in
while !i < n do
let offset = !i in
let len_range = min chunk_size (n - offset) in
assert (offset + len_range <= n);
run (fun () -> task_for ~offset ~len_range);
i := !i + len_range
done
in
Suspend_.suspend
{
Suspend_.handle =
(fun ~run ~resume suspension ->
start_tasks ~run ~resume suspension);
}
)
let all_array ?chunk_size (fs : _ array) : _ array =
let len = Array.length fs in
let arr = Array.make len None in
for_ ?chunk_size len (fun low high ->
for i = low to high do
let x = fs.(i) () in
arr.(i) <- Some x
done);
Array.map
(function
| None -> assert false
| Some x -> x)
arr
let all_list ?chunk_size fs : _ list =
Array.to_list @@ all_array ?chunk_size @@ Array.of_list fs
let all_init ?chunk_size n f : _ list =
let arr = Array.make n None in
for_ ?chunk_size n (fun low high ->
for i = low to high do
let x = f i in
arr.(i) <- Some x
done);
List.init n (fun i ->
match arr.(i) with
| None -> assert false
| Some x -> x)
let map_array ?chunk_size f arr : _ array =
let n = Array.length arr in
let res = Array.make n None in
for_ ?chunk_size n (fun low high ->
for i = low to high do
res.(i) <- Some (f arr.(i))
done);
Array.map
(function
| None -> assert false
| Some x -> x)
res
let map_list ?chunk_size f (l : _ list) : _ list =
let arr = Array.of_list l in
let n = Array.length arr in
let res = Array.make n None in
for_ ?chunk_size n (fun low high ->
for i = low to high do
res.(i) <- Some (f arr.(i))
done);
List.init n (fun i ->
match res.(i) with
| None -> assert false
| Some x -> x)