Source file rfc7748.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
(* A Diffie-Hellman primitive in the style of RFC 7748: keys are abstract,
   cross the boundary only as strings, and the shared secret is obtained by
   scaling a peer's public key with one's own private key. *)
module type DH = sig
  (** A private scalar. *)
  type private_key

  (** A public point, represented by its u-coordinate. *)
  type public_key

  (** Width of an encoded key. The implementations in this file use 32
      (X25519) and 56 (X448), the RFC 7748 key sizes in bytes. *)
  val key_size : int

  (** The curve's standard base point (u = 9 for X25519, u = 5 for X448). *)
  val base : public_key

  (** Decode a peer's public key from its string encoding. *)
  val public_key_of_string : string -> public_key

  (** Decode a private key. The implementations here clamp the scalar while
      decoding. *)
  val private_key_of_string : string -> private_key

  (** Encode a public key as a string of [key_size] units. *)
  val string_of_public_key : public_key -> string

  (** Encode a private key. The implementations here store the clamped
      scalar, so the result need not equal the string originally decoded. *)
  val string_of_private_key : private_key -> string

  (** [scale k u] multiplies the point [u] by the scalar [k]. *)
  val scale : private_key -> public_key -> public_key

  (** [public_key_of_private_key k] is [scale k base]. *)
  val public_key_of_private_key : private_key -> public_key
end

(* X25519 key agreement over Curve25519 (RFC 7748, section 5). *)
module X25519 : DH = struct
  type private_key = Private_key of Z.t
  type public_key = Public_key of Z.t

  (* Encoded keys are 32 units wide (32 bytes per RFC 7748). *)
  let key_size = 32

  (* The standard base point has u-coordinate 9. *)
  let base = Public_key (Z.of_int 9)

  (* Curve25519 parameters, consumed by [Curve.Make] below. *)
  module A = struct
    type element = Z.t
    type integral = Z.t

    (* Field prime 2^255 - 19. *)
    let p = Z.sub (Z.shift_left Z.one 255) (Z.of_int 19)

    let bits = 255

    (* Ladder constant (486662 - 2) / 4. *)
    let a24 = Z.of_int 121665

    let two = Z.of_int 2

    (* Returns (a, b) when [cond] rem 2 is 0 and (b, a) when it is 1,
       by arithmetic selection instead of a branch. Presumably [cond] is a
       single scalar bit (0 or 1) -- TODO confirm in Curve.
       NOTE(review): Zarith/GMP arithmetic is not guaranteed to run in
       constant time, so this is not a timing-safe swap by itself. *)
    let constant_time_conditional_swap cond a b =
      let swap = Z.rem cond two in
      let keep = Z.sub Z.one swap in
      let mix x y = Z.add (Z.mul keep x) (Z.mul swap y) in
      (mix a b, mix b a)
  end

  module C = Curve.Make(Zfield.Zp(A))(Z)(A)

  (* RFC 7748 decodeScalar25519 clamping: clear bits 0-2 and bit 255,
     then set bit 254. Assuming [Serde.z_of_hex] decodes little-endian as
     the RFC requires, these are the three low bits of the first byte and
     the top two bits of the last byte. *)
  let sanitize_scalar =
    let clear_mask =
      Z.lognot (Z.logor (Z.of_int 7) (Z.shift_left (Z.of_int 128) (8 * 31)))
    in
    let force_bit = Z.shift_left (Z.of_int 64) (8 * 31) in
    fun z -> Z.logor force_bit (Z.logand z clear_mask)

  (* X25519 (unlike X448) must ignore the most significant bit (bit 255)
     of a received u-coordinate. *)
  let public_key_of_string s =
    let u = Serde.z_of_hex s in
    let top_bit = Z.shift_left Z.one 255 in
    Public_key (Z.logand u (Z.lognot top_bit))

  let string_of_public_key (Public_key u) = Serde.hex_of_z key_size u

  (* The stored scalar is already clamped. *)
  let private_key_of_string s =
    Private_key (sanitize_scalar (Serde.z_of_hex s))

  let string_of_private_key (Private_key k) = Serde.hex_of_z key_size k

  let scale priv pub =
    match priv, pub with
    | Private_key k, Public_key u -> Public_key (C.scale k u)

  let public_key_of_private_key k = scale k base
end

(* [x25519 ~priv ~pub] computes the X25519 function for a string-encoded
   private scalar and a string-encoded peer u-coordinate, and returns the
   resulting u-coordinate encoded with [X25519.string_of_public_key]. *)
let x25519 ~priv ~pub =
  (* [pub] is decoded first, matching the right-to-left argument
     evaluation of the original nested call. *)
  let u = X25519.public_key_of_string pub in
  let k = X25519.private_key_of_string priv in
  X25519.string_of_public_key (X25519.scale k u)

(* X448 key agreement over Curve448 (RFC 7748, section 5). *)
module X448 : DH = struct
  type private_key = Private_key of Z.t
  type public_key = Public_key of Z.t

  (* Encoded keys are 56 units wide (56 bytes per RFC 7748). *)
  let key_size = 56

  (* The standard base point has u-coordinate 5. *)
  let base = Public_key (Z.of_int 5)

  (* Curve448 parameters, consumed by [Curve.Make] below. *)
  module A = struct
    type element = Z.t
    type integral = Z.t

    (* Field prime 2^448 - 2^224 - 1. *)
    let p =
      Z.sub (Z.sub (Z.shift_left Z.one 448) (Z.shift_left Z.one 224)) Z.one

    let bits = 448

    (* Ladder constant (156326 - 2) / 4. *)
    let a24 = Z.of_int 39081

    let two = Z.of_int 2

    (* Returns (a, b) when [cond] rem 2 is 0 and (b, a) when it is 1,
       by arithmetic selection instead of a branch. Presumably [cond] is a
       single scalar bit (0 or 1) -- TODO confirm in Curve.
       NOTE(review): Zarith/GMP arithmetic is not guaranteed to run in
       constant time, so this is not a timing-safe swap by itself. *)
    let constant_time_conditional_swap cond a b =
      let swap = Z.rem cond two in
      let keep = Z.sub Z.one swap in
      let mix x y = Z.add (Z.mul keep x) (Z.mul swap y) in
      (mix a b, mix b a)
  end

  module C = Curve.Make(Zfield.Zp(A))(Z)(A)

  (* RFC 7748 decodeScalar448 clamping: clear bits 0-1, then set bit 447.
     Assuming [Serde.z_of_hex] decodes little-endian as the RFC requires,
     these are the two low bits of the first byte and the top bit of the
     last byte. *)
  let sanitize_scalar =
    let clear_mask = Z.lognot (Z.of_int 3) in
    let force_bit = Z.shift_left (Z.of_int 128) (8 * 55) in
    fun z -> Z.logor force_bit (Z.logand z clear_mask)

  (* Unlike X25519, X448 keeps every bit of the received u-coordinate. *)
  let public_key_of_string s = Public_key (Serde.z_of_hex s)

  let string_of_public_key (Public_key u) = Serde.hex_of_z key_size u

  (* The stored scalar is already clamped. *)
  let private_key_of_string s =
    Private_key (sanitize_scalar (Serde.z_of_hex s))

  let string_of_private_key (Private_key k) = Serde.hex_of_z key_size k

  let scale priv pub =
    match priv, pub with
    | Private_key k, Public_key u -> Public_key (C.scale k u)

  let public_key_of_private_key k = scale k base
end

(* [x448 ~priv ~pub] computes the X448 function for a string-encoded
   private scalar and a string-encoded peer u-coordinate, and returns the
   resulting u-coordinate encoded with [X448.string_of_public_key]. *)
let x448 ~priv ~pub =
  (* [pub] is decoded first, matching the right-to-left argument
     evaluation of the original nested call. *)
  let u = X448.public_key_of_string pub in
  let k = X448.private_key_of_string priv in
  X448.string_of_public_key (X448.scale k u)