Source file mirage_block_safe.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
(*
 * Copyright (C) 2015 David Scott <dave.scott@docker.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 *)

open Lwt.Infix
open Mirage_block_log

(** [Make (B)] wraps the block device [B] so that [read] and [write]
    validate each request against [B.get_info] before forwarding it to
    [B.read] / [B.write]. Invalid requests are rejected with
    [`Unsafe msg] (the message is also logged via [err]); the unchecked
    operations remain available as [unsafe_read] / [unsafe_write]. *)
module Make (B: Mirage_block.S) = struct

  (** The wrapped device is the underlying device itself. *)
  type t = B.t

  (** Read errors: the generic {!Mirage_block.error} cases, [`Unsafe msg]
      when a request fails validation, and [`Private e] for any other
      error reported by [B]. *)
  type error = [
    | Mirage_block.error
    | `Unsafe of string
    | `Private of B.error
  ]

  let pp_error ppf = function
    | #Mirage_block.error as e -> Mirage_block.pp_error ppf e
    | `Unsafe s  -> Fmt.string ppf s
    | `Private b -> B.pp_error ppf b

  (** Write errors: same layout as {!error}, built on the write-error
      types. *)
  type write_error = [
    | Mirage_block.write_error
    | `Unsafe of string
    | `Private of B.write_error
  ]

  let pp_write_error ppf = function
    | #Mirage_block.write_error as e -> Mirage_block.pp_write_error ppf e
    | `Unsafe s  -> Fmt.string ppf s
    | `Private b -> B.pp_write_error ppf b

  let get_info = B.get_info
  let disconnect = B.disconnect

  (** [lift_error r] passes generic {!Mirage_block.error}s through and
      wraps every other error of [B] in [`Private]. *)
  let lift_error = function
    | Ok x -> Ok x
    | Error (#Mirage_block.error as e) -> Error e
    | Error e -> Error (`Private e)

  (** Same as {!lift_error}, for write errors. *)
  let lift_write_error = function
    | Ok x -> Ok x
    | Error (#Mirage_block.write_error as e) -> Error e
    | Error e -> Error (`Private e)

  (** Bind for [(_, _) result Lwt.t]: continue with [f] on [Ok q],
      otherwise propagate the [Error] unchanged. *)
  let (>>*=) x f = x >>= function
    | Ok q    -> f q
    | Error e -> Lwt.return (Error e)

  (** [fatalf fmt ...] formats a message, logs it via [err] and returns
      [Error (`Unsafe msg)]. *)
  let fatalf fmt = Printf.ksprintf (fun s ->
      err (fun f -> f "%s" s);
      Lwt.return (Error (`Unsafe s))
    ) fmt

  (** [check_sector_size op sector_size] fails with [`Unsafe] unless
      [sector_size] is positive. The other checks divide by it, so a
      bogus value reported by the device would otherwise surface as a
      [Division_by_zero] exception instead of an error value. *)
  let check_sector_size op sector_size =
    if sector_size <= 0
    then fatalf "%s: invalid sector_size (%d), must be positive"
        op sector_size
    else Lwt.return (Ok ())

  (** [check_buffer op sector_size b] fails with [`Unsafe] unless the
      length of [b] is a whole number of sectors. *)
  let check_buffer op sector_size b =
    (* TODO: also check that buffers are sector-aligned. *)
    check_sector_size op sector_size
    >>*= fun () ->
    let len = Cstruct.length b in
    if len mod sector_size <> 0
    then fatalf "%s: buffer length (%d) is not a multiple of \
                 sector_size (%d)" op len sector_size
    else Lwt.return (Ok ())

  (** [check_buffers op sector_size bs] applies {!check_buffer} to each
      buffer in order, stopping at the first failure. *)
  let rec check_buffers op sector_size = function
    | [] -> Lwt.return (Ok ())
    | b :: bs ->
      check_buffer op sector_size b
      >>*= fun () ->
      check_buffers op sector_size bs

  (** [check_in_range op size_sectors offset] fails with [`Unsafe] unless
      [0 <= offset < size_sectors]. *)
  let check_in_range op size_sectors offset =
    if offset < 0L || offset >= size_sectors
    then fatalf "%s: sector offset out of range 0 <= %Ld < %Ld"
        op offset size_sectors
    else Lwt.return (Ok ())

  (** [check op sector_size size_sectors offset buffers] validates a
      request of [buffers] starting at sector [offset]: the sector size
      must be positive, every buffer must be a whole number of sectors,
      [offset] must lie on the device, and the request must not extend
      past [size_sectors]. *)
  let check op sector_size size_sectors offset buffers =
    check_sector_size op sector_size
    >>*= fun () ->
    check_buffers op sector_size buffers
    >>*= fun () ->
    check_in_range op size_sectors offset
    >>*= fun () ->
    let length =
      List.fold_left (fun acc b -> acc + Cstruct.length b) 0 buffers in
    let sectors = Int64.of_int (length / sector_size) in
    (* Here 0 <= offset < size_sectors, so [size_sectors - offset] cannot
       overflow, whereas [offset + sectors] could wrap around. *)
    let remaining = Int64.sub size_sectors offset in
    if Int64.compare sectors remaining > 0
    then fatalf "%s: sector offset out of range %Ld > %Ld"
        op (Int64.add offset sectors) size_sectors
    else Lwt.return (Ok ())

  (** Unchecked access to the underlying device. *)
  let unsafe_read = B.read
  let unsafe_write = B.write

  (** [read t offset buffers] queries [B.get_info t], validates the request
      with {!check}, then delegates to {!unsafe_read}, lifting its errors
      with {!lift_error}. *)
  let read t offset buffers =
    B.get_info t
    >>= fun { Mirage_block.sector_size; size_sectors; _ } ->
    check "read" sector_size size_sectors offset buffers
    >>*= fun () ->
    unsafe_read t offset buffers >|= lift_error

  (** [write t offset buffers] queries [B.get_info t], validates the
      request with {!check}, then delegates to {!unsafe_write}, lifting
      its errors with {!lift_write_error}. *)
  let write t offset buffers =
    B.get_info t
    >>= fun { Mirage_block.sector_size; size_sectors; _ } ->
    check "write" sector_size size_sectors offset buffers
    >>*= fun () ->
    unsafe_write t offset buffers >|= lift_write_error

end