Source file qcow_block_cache.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
(** Log source used by this module; its level defaults to [Logs.Info]. *)
let src =
  let s = Logs.Src.create "qcow" ~doc:"qcow2-formatted BLOCK device" in
  Logs.Src.set_level s (Some Logs.Info) ;
  s

(** Logging functions bound to [src]. *)
module Log = (val Logs.src_log src : Logs.LOG)
(** Byte-size helpers, as [int64]. *)
let kib = 1024L

let mib = Int64.mul 1024L kib
open Qcow_types
module Cstructs = Qcow_cstructs
module RangeLocks = struct
  (** A set of exclusively locked intervals. A caller that asks for an
      interval overlapping one already held waits on [c] until a release
      broadcasts on it. *)
  type t = {mutable locked: Int64.IntervalSet.t; c: unit Lwt_condition.t}

  (** [create ()] returns a lock set with no intervals held. *)
  let create () =
    let locked = Int64.IntervalSet.empty in
    let c = Lwt_condition.create () in
    {locked; c}

  (** [with_lock t i f] waits until no held interval overlaps [i], marks [i]
      as held, runs [f ()], and releases [i] once [f]'s promise resolves or
      fails (via [Lwt.finalize]). *)
  let with_lock t i f =
    let open Lwt.Infix in
    let set = Int64.IntervalSet.(add i empty) in
    let rec get_lock () =
      if Int64.IntervalSet.(is_empty @@ inter t.locked set) then (
        t.locked <- Int64.IntervalSet.(union t.locked set) ;
        Lwt.return_unit
      ) else
        (* Re-check after every wake-up: another waiter may have grabbed an
           overlapping interval first. *)
        Lwt_condition.wait t.c >>= fun () -> get_lock ()
    in
    let put_lock () =
      t.locked <- Int64.IntervalSet.(diff t.locked set) ;
      (* Wake all waiters so those blocked on an overlapping interval can
         retry; without this signal they would wait forever. *)
      Lwt_condition.broadcast t.c () ;
      Lwt.return_unit
    in
    get_lock () >>= fun () -> Lwt.finalize f put_lock
end
module Make (B : Qcow_s.RESIZABLE_BLOCK) = struct
  type error = B.error

  type write_error = B.write_error

  let pp_error = B.pp_error

  let pp_write_error = B.pp_write_error

  (** Write-back cache state layered over [base].
      - [in_cache]: sectors whose latest contents are held in [cache] and
        have not yet been written to [base].
      - [zeros]: sectors to be written as zeros at the next write-back
        ([resize] adds the newly exposed range when growing).
      - [cache]: per-sector buffers keyed by sector number; each buffer is a
        private copy owned by the cache.
      - [locks]: sector intervals currently being read, written or written
        back.
      - [disconnect_request]: set by [disconnect]; later reads and writes
        fail with [`Disconnected].
      - [zero]: one sector of zero bytes, written for [zeros] sectors. *)
  type t = {
      base: B.t
    ; mutable info: Mirage_block.info
    ; sector_size: int
    ; max_size_bytes: int64
    ; mutable in_cache: Int64.IntervalSet.t
    ; mutable zeros: Int64.IntervalSet.t
    ; mutable cache: Cstruct.t Int64.Map.t
    ; locks: RangeLocks.t
    ; mutable disconnect_request: bool
    ; disconnect_m: Lwt_mutex.t
    ; write_back_m: Lwt_mutex.t
    ; zero: Cstruct.t
  }

  (** Returns the current device info, as last updated by [resize]. *)
  let get_info t = Lwt.return t.info

  (** [lazy_write_back t] writes every cached sector, and every pending
      zero sector, to [t.base]. Writes are issued in chunks of about 1 MiB
      (at least one sector each). Each interval is held in [t.locks] while
      it is written. A chunk is dropped from the cache only after [B.write]
      succeeds for it, so on [Error] the unwritten data is still cached.
      Stops at the first error. *)
  let lazy_write_back t =
    let open Lwt.Infix in
    Lwt_mutex.with_lock t.write_back_m (fun () ->
        Log.debug (fun f ->
            f "lazy_write_back cached sectors = %Ld zeros = %Ld"
              (Int64.IntervalSet.cardinal t.in_cache)
              (Int64.IntervalSet.cardinal t.zeros)
        ) ;
        assert (Int64.IntervalSet.(is_empty @@ inter t.in_cache t.zeros)) ;
        let all = Int64.IntervalSet.union t.in_cache t.zeros in
        (* Sectors per [B.write] call: 1 MiB worth, but never zero even if a
           sector is larger than 1 MiB. *)
        let chunk_sectors = max 1L (Int64.div mib (Int64.of_int t.sector_size)) in
        (* Buffers for sectors [first..last]: cached data, otherwise zeros. *)
        let buffers first last =
          let rec collect acc sector =
            if sector < first then
              acc
            else
              let buf =
                match Int64.Map.find sector t.cache with
                | cached -> cached
                | exception Not_found -> t.zero
              in
              collect (buf :: acc) (Int64.pred sector)
          in
          collect [] last
        in
        (* Forget sectors [first..last]; only called once they are on disk. *)
        let forget first last =
          let written = Int64.IntervalSet.Interval.make first last in
          t.in_cache <- Int64.IntervalSet.remove written t.in_cache ;
          t.zeros <- Int64.IntervalSet.remove written t.zeros ;
          let rec drop cache sector =
            if sector > last then
              cache
            else
              drop (Int64.Map.remove sector cache) (Int64.succ sector)
          in
          t.cache <- drop t.cache first
        in
        let write_interval i =
          let first, last = Int64.IntervalSet.Interval.(x i, y i) in
          let rec loop x =
            if x > last then
              Lwt.return (Ok ())
            else
              let y = min (Int64.add x (Int64.pred chunk_sectors)) last in
              B.write t.base x (buffers x y) >>= function
              | Error e ->
                  Lwt.return (Error e)
              | Ok () ->
                  forget x y ;
                  loop (Int64.succ y)
          in
          loop first
        in
        Int64.diet_fold_s
          (fun i acc ->
            match acc with
            | Error e ->
                Lwt.return (Error e)
            | Ok () ->
                RangeLocks.with_lock t.locks i (fun () -> write_interval i)
          )
          all (Ok ())
    )

  (** [flush t] writes back all cached and pending-zero sectors, then
      flushes [t.base]. Returns the first error encountered. *)
  let flush t =
    let open Lwt.Infix in
    lazy_write_back t >>= function
    | Error e ->
        Lwt.return (Error e)
    | Ok () ->
        B.flush t.base

  (** [connect ?max_size_bytes base] wraps [base] in an initially empty
      write-back cache. When a write would push the cached data past
      [max_size_bytes] (default 100 MiB), the cache is written back first. *)
  let connect ?(max_size_bytes = Int64.mul 100L mib) base =
    let open Lwt.Infix in
    B.get_info base >>= fun info ->
    let sector_size = info.Mirage_block.sector_size in
    let in_cache = Int64.IntervalSet.empty in
    let zeros = Int64.IntervalSet.empty in
    let cache = Int64.Map.empty in
    let locks = RangeLocks.create () in
    let disconnect_request = false in
    let disconnect_m = Lwt_mutex.create () in
    let write_back_m = Lwt_mutex.create () in
    let zero = Cstruct.create sector_size in
    Cstruct.memset zero 0 ;
    let t =
      {
        base
      ; info
      ; sector_size
      ; max_size_bytes
      ; in_cache
      ; cache
      ; zeros
      ; locks
      ; disconnect_request
      ; disconnect_m
      ; write_back_m
      ; zero
      }
    in
    Lwt.return t

  (** [disconnect t] marks [t] as disconnected, so later reads and writes
      return [`Disconnected]. It then flushes the cache and disconnects
      [t.base]. A flush error cannot be returned from here, so it is logged
      instead. *)
  let disconnect t =
    let open Lwt.Infix in
    Lwt_mutex.with_lock t.disconnect_m (fun () ->
        t.disconnect_request <- true ;
        Lwt.return_unit
    )
    >>= fun () ->
    flush t
    >>= (function
          | Ok () ->
              Lwt.return_unit
          | Error e ->
              Log.err (fun f ->
                  f "flush failed during disconnect: %a" pp_write_error e
              ) ;
              Lwt.return_unit
          )
    >>= fun () -> B.disconnect t.base

  (** [per_sector sector_size start bufs f] calls [f sector slice] for each
      [sector_size]-byte slice of [bufs], in order. Slices are numbered
      consecutively from [start]. Slices share memory with [bufs]. Stops at
      the first [Error]. Each buffer's length must be a multiple of
      [sector_size] (checked by [assert]). *)
  let rec per_sector sector_size start bufs f =
    match bufs with
    | [] ->
        Lwt.return (Ok ())
    | b :: bs -> (
        let open Lwt.Infix in
        let rec loop sector remaining =
          if Cstruct.length remaining = 0 then
            Lwt.return (Ok sector)
          else (
            assert (Cstruct.length remaining >= sector_size) ;
            let first = Cstruct.sub remaining 0 sector_size in
            f sector first >>= function
            | Error e ->
                Lwt.return (Error e)
            | Ok () ->
                loop (Int64.succ sector) (Cstruct.shift remaining sector_size)
          )
        in
        loop start b >>= function
        | Error e ->
            Lwt.return (Error e)
        | Ok start' ->
            per_sector sector_size start' bs f
      )

  (** [read t start bufs] fills [bufs] with sectors starting at [start].
      If no requested sector is cached, the whole request goes to [t.base];
      otherwise cached sectors are copied from the cache and the others are
      read one sector at a time. Fails with [`Disconnected] after
      [disconnect]. *)
  let read t start bufs =
    let len = Int64.of_int @@ Cstructs.len bufs in
    if t.disconnect_request then
      Lwt.return (Error `Disconnected)
    else if len = 0L then
      (* Nothing to read; an empty request would otherwise produce an
         interval whose end precedes its start. *)
      Lwt.return (Ok ())
    else
      let i =
        Int64.IntervalSet.Interval.make start
          Int64.(pred @@ add start (div len (of_int t.sector_size)))
      in
      let set = Int64.IntervalSet.(add i empty) in
      RangeLocks.with_lock t.locks i (fun () ->
          if Int64.IntervalSet.(is_empty @@ inter t.in_cache set) then
            B.read t.base start bufs
          else
            per_sector t.sector_size start bufs (fun sector buf ->
                match Int64.Map.find sector t.cache with
                | cached ->
                    Cstruct.blit cached 0 buf 0 t.sector_size ;
                    Lwt.return (Ok ())
                | exception Not_found ->
                    B.read t.base sector [buf]
            )
      )

  (** [write t start bufs] stores the sectors of [bufs], starting at
      [start], in the cache. If the cache would exceed [t.max_size_bytes],
      it is written back first. Fails with [`Disconnected] after
      [disconnect]. *)
  let write t start bufs =
    let open Lwt.Infix in
    let len = Int64.of_int @@ Cstructs.len bufs in
    if len = 0L then
      (* Nothing to write; an empty request would otherwise produce an
         interval whose end precedes its start. *)
      Lwt.return (if t.disconnect_request then Error `Disconnected else Ok ())
    else
      let current_size_bytes =
        Int64.(mul (IntervalSet.cardinal t.in_cache) (of_int t.sector_size))
      in
      ( if Int64.(add current_size_bytes len) > t.max_size_bytes then
          lazy_write_back t
        else
          Lwt.return (Ok ())
      )
      >>= function
      | Error e ->
          Lwt.return (Error e)
      | Ok () ->
          let i =
            Int64.IntervalSet.Interval.make start
              Int64.(pred @@ add start (div len (of_int t.sector_size)))
          in
          Lwt_mutex.with_lock t.disconnect_m (fun () ->
              if t.disconnect_request then
                Lwt.return (Error `Disconnected)
              else
                RangeLocks.with_lock t.locks i (fun () ->
                    per_sector t.sector_size start bufs (fun sector buf ->
                        assert (Cstruct.length buf = t.sector_size) ;
                        if not (Int64.Map.mem sector t.cache) then (
                          t.in_cache <- Int64.IntervalSet.(add i t.in_cache) ;
                          t.zeros <- Int64.IntervalSet.(remove i t.zeros)
                        ) ;
                        (* [buf] is a slice of the caller's buffer, which the
                           caller may reuse once this write returns, so the
                           cache keeps its own copy. *)
                        let copy = Cstruct.create t.sector_size in
                        Cstruct.blit buf 0 copy 0 t.sector_size ;
                        t.cache <- Int64.Map.add sector copy t.cache ;
                        Lwt.return (Ok ())
                    )
                )
          )

  (** [resize t new_size] resizes [t.base] to [new_size] sectors.
      - On shrink, cached sectors at or beyond the new end are dropped, and
        so are pending zero sectors in that range.
      - On growth, the newly exposed range is recorded in [t.zeros].
      - In both cases [t.info] is updated. *)
  let resize t new_size =
    let open Lwt.Infix in
    B.resize t.base new_size >>= function
    | Error e ->
        Lwt.return (Error e)
    | Ok () ->
        let old_size = t.info.Mirage_block.size_sectors in
        if new_size < old_size then (
          let still_ok, to_drop =
            Int64.Map.partition (fun sector _ -> sector < new_size) t.cache
          in
          let dropped =
            Int64.Map.fold
              (fun sector _ set ->
                let i = Int64.IntervalSet.Interval.make sector sector in
                Int64.IntervalSet.(add i set)
              )
              to_drop Int64.IntervalSet.empty
          in
          t.cache <- still_ok ;
          t.in_cache <- Int64.IntervalSet.diff t.in_cache dropped ;
          (* Pending zero-fills past the new end must be forgotten too, or a
             later write-back would write beyond the device. *)
          let truncated =
            Int64.IntervalSet.Interval.make new_size (Int64.pred old_size)
          in
          t.zeros <- Int64.IntervalSet.remove truncated t.zeros
        ) ;
        ( if new_size > old_size then
            let i =
              Int64.IntervalSet.Interval.make old_size (Int64.pred new_size)
            in
            t.zeros <- Int64.IntervalSet.add i t.zeros
        ) ;
        t.info <- {t.info with Mirage_block.size_sectors= new_size} ;
        Lwt.return (Ok ())
end