1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
open Poseidon_utils
(* Low-level bindings to the C primitives of the Poseidon permutation.
   Nothing in this module checks its arguments on the OCaml side; [Make]
   checks the input and MDS sizes before calling into these externals. *)
module Stubs = struct
(** Abstract context; within this module, values of this type are only
    produced by [allocate_ctxt]. *)
type ctxt

(** Allocates a context from the permutation parameters, the round
    constants [ark] and the MDS matrix [mds]. [Make.init] passes as [ark]
    the constants returned by [compute_updated_constants] followed by
    [width] zeros, not the raw parameter; presumably that is the layout the
    C stub expects (TODO confirm against the C code). Two primitive names
    are given because this external takes six arguments: OCaml then
    requires a separate bytecode stub (first name) alongside the
    native-code one (second name). *)
external allocate_ctxt :
width:int ->
nb_full_rounds:int ->
nb_partial_rounds:int ->
batch_size:int ->
ark:Bls12_381.Fr.t array ->
mds:Bls12_381.Fr.t array array ->
ctxt
= "caml_bls12_381_hash_poseidon_allocate_ctxt_stubs_bytecode" "caml_bls12_381_hash_poseidon_allocate_ctxt_stubs"

(** Called by [Make.init] with exactly [width] input elements, presumably
    to set the initial state held in the context. *)
external init : ctxt -> width:int -> Bls12_381.Fr.t array -> unit
= "caml_bls12_381_hash_poseidon_init_stubs"

(** Runs the permutation on the state held in the context (as used by
    [Make.apply_permutation]). It returns [unit]; [Make] reads the result
    back with [get_state]. *)
external apply_perm :
ctxt ->
width:int ->
nb_full_rounds:int ->
nb_partial_rounds:int ->
batch_size:int ->
unit = "caml_bls12_381_hash_poseidon_apply_perm_stubs"

(** [get_state dst ctxt n]: [Make.get] calls it with a fresh array [dst] of
    length [n] = [width] and returns [dst] afterwards, so the stub
    presumably writes the first [n] state elements into [dst]. *)
external get_state : Bls12_381.Fr.t array -> ctxt -> int -> unit
= "caml_bls12_381_hash_poseidon_get_state_stubs"
end
(* Builds a Poseidon permutation instance for a fixed set of parameters.
   The resulting module allocates a context through [Stubs], loads an
   initial state into it, applies the permutation and reads the state back. *)
module Make (Parameters : sig
  val nb_full_rounds : int

  val nb_partial_rounds : int

  val batch_size : int

  val width : int

  val ark : Bls12_381.Fr.t array

  val mds : Bls12_381.Fr.t array array
end) =
struct
  open Parameters

  type ctxt = Stubs.ctxt

  (** [init inputs] checks the parameters, allocates a new context and loads
      [inputs] into it as the initial state.

      @raise Failure if [inputs] does not contain exactly [width] elements.
      @raise Failure if [mds] is empty, has a row whose length differs from
      its number of rows, or is not a [width] x [width] matrix. *)
  let init inputs =
    if Array.length inputs <> width then
      failwith (Printf.sprintf "The inputs must be of size %d" width) ;
    (* Check every row rather than only [mds.(0)]. Indexing row 0 of an empty
       matrix would raise [Invalid_argument] instead of a [Failure], and a
       ragged or wrongly-sized matrix would otherwise reach
       [compute_updated_constants] and the stubs, which both also receive
       [width]. The checks run before those calls so that the caller always
       gets the [Failure] documented above. *)
    let mds_nb_rows = Array.length mds in
    if mds_nb_rows = 0
       || Array.exists (fun row -> Array.length row <> mds_nb_rows) mds
    then failwith "The parameter MDS must be a square matrix" ;
    if mds_nb_rows <> width then
      failwith
        (Printf.sprintf "The parameter MDS must be a %d x %d matrix" width width) ;
    (* Round constants in the layout handed to the stubs: the four arrays
       returned by [compute_updated_constants], in this order, followed by
       [width] zeros (presumably padding expected by the C side; TODO
       confirm). [copy zero] gives fresh values instead of sharing
       [Bls12_381.Fr.zero]. *)
    let modified_ark =
      let ( arc_full_round_start_with_first_partial,
            arc_intermediate_state,
            arc_unbatched,
            arc_full_round_end ) =
        compute_updated_constants
          nb_partial_rounds
          nb_full_rounds
          width
          batch_size
          ark
          mds
      in
      Array.concat
        [ arc_full_round_start_with_first_partial;
          arc_intermediate_state;
          arc_unbatched;
          arc_full_round_end;
          Array.init width (fun _ -> Bls12_381.Fr.(copy zero)) ]
    in
    let ctxt =
      Stubs.allocate_ctxt
        ~width
        ~nb_full_rounds
        ~nb_partial_rounds
        ~batch_size
        ~ark:modified_ark
        ~mds
    in
    Stubs.init ctxt ~width inputs ;
    ctxt

  (** [apply_permutation ctxt] runs the permutation on the state held in
      [ctxt]. It returns [unit]; read the resulting state with {!get}. *)
  let apply_permutation ctxt =
    Stubs.apply_perm ctxt ~width ~nb_full_rounds ~nb_partial_rounds ~batch_size

  (** [get ctxt] returns a fresh array of [width] elements, filled by the stub
      from the current state of [ctxt]. *)
  let get ctxt =
    let res = Array.init width (fun _ -> Bls12_381.Fr.(copy zero)) in
    Stubs.get_state res ctxt width ;
    res
end