Source file pbkdf.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
module type S = sig
  (** [pbkdf1 ~password ~salt ~count ~dk_len] derives a [dk_len]-byte key
      with PBKDF1: [password ^ salt] is hashed [count] times and the result
      is truncated to [dk_len] bytes.
      @raise Invalid_argument (in the [Make] implementation) if [salt] is not
      8 bytes long, if [count] or [dk_len] is not positive, or if [dk_len]
      exceeds the digest size of the underlying hash. *)
  val pbkdf1 :
    password:Cstruct.t -> salt:Cstruct.t -> count:int -> dk_len:int ->
    Cstruct.t

  (** [pbkdf2 ~password ~salt ~count ~dk_len] derives a [dk_len]-byte key
      with PBKDF2, using HMAC over the underlying hash as the PRF and
      [count] iterations per output block.
      @raise Invalid_argument (in the [Make] implementation) if [count] or
      [dk_len] is not positive. *)
  val pbkdf2 :
    password:Cstruct.t -> salt:Cstruct.t -> count:int -> dk_len:int32 ->
    Cstruct.t
end

let cdiv x y =
  (* Ceiling division: for positive [x] this is the smallest [q] with
     [q * y >= x]; a non-positive [x] yields [0], and [y < 1] raises
     [Division_by_zero].
     Copied from Nocrypto.Uncommon.(//) (formerly known as [cdiv]). That
     helper lives in an internal utility library which is publicly exposed
     but not meant for public consumption; an API break there prompted this
     local copy. *)
  if y >= 1 then (if x <= 0 then 0 else succ ((x - 1) / y))
  else raise Division_by_zero
  [@@inline]

module Make (H : Mirage_crypto.Hash.S) : S = struct
  (* PBKDF1 (RFC 8018, section 5.1): T_1 = H(password || salt),
     T_k = H(T_(k-1)); the derived key is the first [dk_len] bytes of
     T_count. Only an 8-byte salt and a key no longer than one digest are
     accepted. *)
  let pbkdf1 ~password ~salt ~count ~dk_len =
    if Cstruct.length salt <> 8 then invalid_arg "salt should be 8 bytes";
    if count <= 0 then invalid_arg "count must be a positive integer";
    if dk_len <= 0 then
      invalid_arg "derived key length must be a positive integer";
    if dk_len > H.digest_size then invalid_arg "derived key too long";
    let t = ref (Cstruct.append password salt) in
    for _round = 1 to count do
      t := H.digest !t
    done;
    Cstruct.sub !t 0 dk_len

  (* PBKDF2 (RFC 8018, section 5.2) with HMAC-[H] as the PRF: the key is the
     concatenation of blocks T_1 .. T_l, with the last block truncated so the
     total is exactly [dk_len] bytes. *)
  let pbkdf2 ~password ~salt ~count ~dk_len =
    if count <= 0 then invalid_arg "count must be a positive integer";
    if Int32.compare dk_len 0l <= 0 then
      invalid_arg "derived key length must be a positive integer";
    let dk_len = Int32.to_int dk_len in
    let h_len = H.digest_size in
    (* Number of digest-sized blocks needed, and the byte length kept from
       the final one. *)
    let n_blocks = cdiv dk_len h_len in
    let tail_len = dk_len - ((n_blocks - 1) * h_len) in
    let prf msg = H.hmac ~key:password msg in
    (* T_i = U_1 xor U_2 xor ... xor U_count, where
       U_1 = PRF(salt || INT_BE32(i)) and U_j = PRF(U_(j-1)). *)
    let t_block i =
      let index = Cstruct.create 4 in
      Cstruct.BE.set_uint32 index 0 (Int32.of_int i);
      let u = ref (prf (Cstruct.append salt index)) in
      let acc = ref !u in
      for _round = 2 to count do
        u := prf !u;
        acc := Mirage_crypto.Uncommon.Cs.xor !acc !u
      done;
      !acc
    in
    List.init n_blocks (fun k ->
        let i = k + 1 in
        let t = t_block i in
        if i = n_blocks then Cstruct.sub t 0 tail_len else t)
    |> Cstruct.concat
end

(** [pbkdf1 ~hash ~password ~salt ~count ~dk_len] resolves [hash] to its
    implementation via [Mirage_crypto.Hash.module_of] and runs
    [Make]'s PBKDF1 with it; all arguments are forwarded unchanged. *)
let pbkdf1 ~hash ~password ~salt ~count ~dk_len =
  let module Hash_impl = (val Mirage_crypto.Hash.module_of hash) in
  let module Kdf = Make (Hash_impl) in
  Kdf.pbkdf1 ~salt ~password ~dk_len ~count

(** [pbkdf2 ~prf ~password ~salt ~count ~dk_len] resolves the hash named by
    [prf] via [Mirage_crypto.Hash.module_of] and runs [Make]'s PBKDF2 with
    it (HMAC over that hash is the PRF); other arguments are forwarded
    unchanged. *)
let pbkdf2 ~prf ~password ~salt ~count ~dk_len =
  let module Hash_impl = (val Mirage_crypto.Hash.module_of prf) in
  let module Kdf = Make (Hash_impl) in
  Kdf.pbkdf2 ~salt ~password ~dk_len ~count