Source file poseidon_core.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
(** Parameters of a Poseidon instance, consumed by {!Make}. *)
module type PARAMETERS = sig
(** Number of field elements in the permutation state. {!Make} asserts that
    [mds_matrix] is a [width] x [width] matrix. *)
val width : int
(** Number of full rounds. {!Make} runs [full_rounds / 2] full rounds before
    the partial rounds and [full_rounds / 2] after them, so with an odd value
    one round is silently dropped. *)
val full_rounds : int
(** Number of partial rounds, run between the two halves of full rounds. *)
val partial_rounds : int
(** Round constants, parsed with [Scalar.of_string] by {!Make}. Every round
    (full or partial) consumes [width] of them, in order.
    NOTE(review): the accepted string format is whatever [Scalar.of_string]
    accepts (presumably decimal) — confirm. *)
val round_constants : string array
(** MDS matrix, indexed as [mds_matrix.(row).(col)] (the permutation computes
    [state := mds_matrix * state]), parsed with [Scalar.of_string]. *)
val mds_matrix : string array array
(** Index of the single state element that goes through the S-box during a
    partial round. *)
val partial_round_idx_to_permute : int
end
(** Signature of a permutation strategy: a state made of [scalar] elements
    that can be permuted. {!Make} defines a [Strategy] module providing these
    operations. *)
module type STRATEGY = sig
(** Type of the field elements held in the state. *)
type scalar
(** Permutation state. *)
type state
(** [init ?input_length elems] creates a state from [elems], recording the
    optional [input_length]. The {!Make} implementation copies [elems]. *)
val init : ?input_length:int -> scalar array -> state
(** Applies the permutation; since it returns [unit], the state is updated
    in place. *)
val apply_perm : state -> unit
(** Returns the state elements (a fresh copy in the {!Make} implementation). *)
val get : state -> scalar array
(** Returns the [input_length] recorded at creation, if any (as implemented
    in {!Make}). *)
val input_length : state -> int option
end
(** Signature of a hash function built from a permutation. {!Make} defines a
    [Hash] module providing these operations. *)
module type HASH = sig
(** Type of the field elements being hashed. *)
type scalar
(** Hashing context. *)
type ctxt
(** [init ?input_length ()] creates a fresh context. In the {!Make}
    implementation, [~input_length:n] fixes the input length to [n] and
    disables padding, while omitting it enables padding. *)
val init : ?input_length:int -> unit -> ctxt
(** [digest ctxt data] absorbs [data] and returns the resulting context. In
    the {!Make} implementation [ctxt] is mutated in place and returned, and
    [Invalid_argument] is raised when a fixed [input_length] differs from
    [Array.length data]. *)
val digest : ctxt -> scalar array -> ctxt
(** Returns the hash output held by a context. *)
val get : ctxt -> scalar
end
(** [Make (C) (Scalar)] builds the Poseidon permutation ([Strategy]) and a
    sponge-style hash ([Hash]) over the field [Scalar], from the parameters
    in [C]. All constants of [C] are parsed with [Scalar.of_string] once, when
    the functor is applied.

    Requirements on [C] (only the first one is checked eagerly):
    - [mds_matrix] is [width] x [width] ([Assert_failure] otherwise);
    - [round_constants] has at least
      [width * (2 * (full_rounds / 2) + partial_rounds)] elements, the number
      consumed by one [Strategy.apply_perm];
    - [0 <= partial_round_idx_to_permute < width] whenever
      [partial_rounds > 0];
    - [width >= 2] for [Hash.digest], which absorbs [width - 1] elements per
      permutation. *)
module Make (C : PARAMETERS) (Scalar : Bls12_381.Ff_sig.PRIME) = struct
  open C

  let () =
    assert (Array.length mds_matrix = width) ;
    assert (Array.for_all (fun line -> Array.length line = width) mds_matrix)

  (* Indexed as [mds_matrix.(row).(col)]. *)
  let mds_matrix = Array.map (Array.map Scalar.of_string) mds_matrix

  let round_constants = Array.map Scalar.of_string round_constants

  (* Scratch vector for [Strategy.apply_eval_matrix]; it is all-zero between
     calls. It is shared by every state of this functor application, so two
     permutations of the same instance must not run concurrently (e.g. on
     different domains). *)
  let res = Array.make width Scalar.zero

  (** The Poseidon permutation, operating in place on a mutable state. *)
  module Strategy = struct
    type scalar = Scalar.t

    (** Permutation state.
        - [i_round_key]: index in [round_constants] of the next constant to
          add; reset to [0] by [apply_perm].
        - [state]: the field elements being permuted, mutated in place.
        - [input_length]: expected input length recorded by [init]; [None]
          makes [Hash.digest] pad its input. *)
    type state = {
      mutable i_round_key : int;
      state : Scalar.t array;
      input_length : int option;
    }

    (** [init ?input_length state] wraps a copy of [state] (the caller's array
        is never mutated) with the round-constant cursor at [0]. *)
    let init ?input_length state =
      {i_round_key = 0; state = Array.copy state; input_length}

    (* Returns the round constant under the cursor and advances the cursor. *)
    let get_next_round_key s =
      let v = round_constants.(s.i_round_key) in
      s.i_round_key <- s.i_round_key + 1 ;
      v

    (* S-box x -> x^5, computed as (x^2)^2 * x. *)
    let s_box x = Scalar.(square (square x) * x)

    (* Adds one fresh round constant to every state element. *)
    let apply_round_key s =
      let state = s.state in
      for i = 0 to Array.length state - 1 do
        state.(i) <- Scalar.(get_next_round_key s + state.(i))
      done

    (* Partial-round S-box layer: only the element at
       [partial_round_idx_to_permute] goes through the S-box. Despite the
       name, that element need not be the last one. *)
    let apply_s_box_last_elem s =
      let s = s.state in
      s.(partial_round_idx_to_permute) <- s_box s.(partial_round_idx_to_permute)

    (* Full-round S-box layer: every element goes through the S-box. *)
    let apply_s_box s =
      let s = s.state in
      for i = 0 to Array.length s - 1 do
        s.(i) <- s_box s.(i)
      done

    (* Replaces the first [width] elements of [v.state] with [m * v.state],
       where [m] is indexed as [m.(row).(col)]. The product is accumulated in
       the shared [res] buffer, which is zeroed again before returning. *)
    let apply_eval_matrix m v =
      let v = v.state in
      for j = 0 to width - 1 do
        for k = 0 to width - 1 do
          res.(k) <- Scalar.(res.(k) + (m.(k).(j) * v.(j)))
        done
      done ;
      for j = 0 to width - 1 do
        v.(j) <- res.(j) ;
        res.(j) <- Scalar.zero
      done

    (* Round constants, S-box on a single element, MDS mix. *)
    let apply_partial_round s =
      apply_round_key s ;
      apply_s_box_last_elem s ;
      apply_eval_matrix mds_matrix s

    (* Round constants, S-box on every element, MDS mix. *)
    let apply_full_round s =
      apply_round_key s ;
      apply_s_box s ;
      apply_eval_matrix mds_matrix s

    (** [apply_perm s] permutes [s] in place: [full_rounds / 2] full rounds,
        [partial_rounds] partial rounds, then [full_rounds / 2] full rounds.
        The round-constant cursor is reset first, so every call uses the same
        constants. *)
    let apply_perm s =
      s.i_round_key <- 0 ;
      for _i = 0 to (full_rounds / 2) - 1 do
        apply_full_round s
      done ;
      for _i = 0 to partial_rounds - 1 do
        apply_partial_round s
      done ;
      for _i = 0 to (full_rounds / 2) - 1 do
        apply_full_round s
      done

    (** Returns a copy of the current state elements. *)
    let get s = Array.copy s.state

    (** [add_cst s idx v] adds [v] to the element at index [idx] of the state.
        @raise Assert_failure if [idx] is outside [[0, width)]. *)
    let add_cst s idx v =
      (* Valid indices are [0 .. width - 1]; [idx = width] must be rejected. *)
      assert (0 <= idx && idx < width) ;
      s.state.(idx) <- Scalar.(s.state.(idx) + v)

    (** Returns the [input_length] recorded by [init], if any. *)
    let input_length s = s.input_length
  end

  (** Sponge-style hash on top of [Strategy]. Input is added to state elements
      [1 .. width - 1] ([width - 1] elements per permutation), element [0]
      never receives input directly, and the output is element [1]. *)
  module Hash = struct
    type scalar = Scalar.t

    type ctxt = Strategy.state

    (** [init ?input_length ()] returns a fresh all-zero context. With
        [~input_length:n], [digest] only accepts inputs of exactly [n]
        elements and adds no padding; without it, [digest] pads. *)
    let init ?input_length () =
      Strategy.init ?input_length (Array.make width Scalar.zero)

    (** [digest ctxt data] absorbs [data] into [ctxt], mutating it, and
        returns [ctxt].

        Without [input_length], a [Scalar.one] is placed right after the last
        input element, so the final block always has room for it. With
        [input_length], no padding is added and the final block may be
        completely filled. In both modes exactly one permutation follows the
        last absorbed element.

        @raise Invalid_argument if [ctxt] was created with [~input_length:n]
        and [Array.length data <> n].
        @raise Division_by_zero if [width = 1]. *)
    let digest state data =
      let l = Array.length data in
      let input_length_opt = Strategy.input_length state in
      Option.iter
        (fun expected ->
          if l <> expected then
            invalid_arg
              (Format.sprintf "digest expects data of length %d, %d given"
                 expected l))
        input_length_opt ;
      let with_padding = Option.is_none input_length_opt in
      let chunk_size = width - 1 in
      (* Number of chunks absorbed before the final block. With padding,
         every complete chunk goes here so the padding element always fits in
         the final block. Without padding, the last element is held back so
         the final block is non-empty. [max 0] keeps an empty fixed-length
         input from yielding [-1] when [chunk_size = 1]. *)
      let nb_full_chunk =
        if with_padding then l / chunk_size else max 0 (l - 1) / chunk_size
      in
      for i = 0 to nb_full_chunk - 1 do
        let ofs = i * chunk_size in
        for j = 0 to chunk_size - 1 do
          Strategy.add_cst state (1 + j) data.(ofs + j)
        done ;
        Strategy.apply_perm state
      done ;
      (* Final block: the remaining [r] elements. [r < chunk_size] with
         padding, [r <= chunk_size] without. *)
      let ofs = nb_full_chunk * chunk_size in
      let r = l - ofs in
      for j = 0 to r - 1 do
        Strategy.add_cst state (1 + j) data.(ofs + j)
      done ;
      if with_padding then Strategy.add_cst state (r + 1) Scalar.one ;
      Strategy.apply_perm state ;
      state

    (** Returns the hash output: element [1] of the context's state. *)
    let get (ctxt : ctxt) = ctxt.state.(1)
  end
end