Source file custom_gates.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
open Bls
open Identities
module L = Plompiler.LibCircuit
(** Interface of the gate aggregator: registry of the supported gates and
    aggregation of their equations, circuits and identities. *)
module type S = sig
(** Label of the arithmetic identity. In this file's implementation the
    aggregated arithmetic identity is keyed [proof_prefix (arith_label ^ ".0")]. *)
val arith_label : string
(** Prefix of advice polynomial names; advice [i] is named
    [circuit_prefix qadv_label ^ string_of_int i]. *)
val qadv_label : string
(** Prefix of input-commitment names; commitment [idx] is named
    [proof_prefix com_label ^ string_of_int idx]. *)
val com_label : string
(** Selector labels of all registered gates. *)
val gates_list : string list
(** Number of registered gates (the number of labels in [gates_list]). *)
val nb_custom_gates : int
(** Number of input-commitment gates (selectors ["qcom0"], ["qcom1"], ...). *)
val nb_input_com : int
(** [get_eqs label ~q ~wires ~wires_g ?precomputed_advice ()] evaluates, on
    scalars, the equations of the gate registered under selector [label].
    [wires_g] presumably holds the wire values at the next row (gX) — confirm
    against the gate implementations.
    Raises [Failure] if [label] is not a registered selector. *)
val get_eqs :
string ->
q:Scalar.t ->
wires:Scalar.t array ->
wires_g:Scalar.t array ->
?precomputed_advice:Scalar.t SMap.t ->
unit ->
Scalar.t list
(** Identity descriptor [(string, int)] of the gate registered under the given
    selector. Presumably the identity name and the number of identities it
    contributes — TODO confirm against [Gates_common.Base_sig].
    Raises [Failure] on an unknown selector. *)
val get_ids : string -> string * int
(** Same shape as [get_eqs], but builds the gate's equations as a Plompiler
    circuit over [L.scalar L.repr] values.
    Raises [Failure] on an unknown selector. *)
val get_cs :
string ->
q:L.scalar L.repr ->
wires:L.scalar L.repr array ->
wires_g:L.scalar L.repr array ->
?precomputed_advice:L.scalar L.repr SMap.t ->
unit ->
L.scalar L.repr list L.t
(** Prover identities of every registered gate whose selector is a key of
    [gates] (unknown keys are ignored), merged into a single map. The
    arithmetic identities of all gates are summed into one entry.
    [circuit_prefix] defaults to [Fun.id]. *)
val aggregate_prover_identities :
?circuit_prefix:(string -> string) ->
input_coms_size:int ->
proof_prefix:(string -> string) ->
gates:'a SMap.t ->
public_inputs:Scalar.t array ->
domain:Domain.t ->
unit ->
prover_identities
(** Verifier counterpart of [aggregate_prover_identities]. The arithmetic
    identities are summed, and [input_com_sizes] is summed into the total
    input-commitment size. *)
val aggregate_verifier_identities :
?circuit_prefix:(string -> string) ->
input_com_sizes:int list ->
proof_prefix:(string -> string) ->
gates:'a SMap.t ->
public_inputs:Scalar.t array ->
generator:Scalar.t ->
size_domain:int ->
unit ->
verifier_identities
(** Maximum polynomial degree reported by the registered gates in [gates]
    (0 when none is registered). *)
val aggregate_polynomials_degree : gates:'a SMap.t -> int
(** [true] iff some registered gate in [gates] uses gX composition, i.e.
    needs the wire evaluations at the next point GX. *)
val exists_gx_composition : gates:'a SMap.t -> bool
(** Circuit evaluating the public-input contribution at [x] from the public
    inputs [pi_list], using the domain [generator], its size [n], and [zs]
    (presumably [x^n - 1] — confirm at the call site). Returns zero for an
    empty list. *)
val cs_pi :
generator:Scalar.t ->
n:Scalar.t ->
x:L.scalar L.repr ->
zs:L.scalar L.repr ->
L.scalar L.repr list ->
L.scalar L.repr L.t
end
module Aggregator = struct
  open Gates_common

  (* Labels re-exported from [Gates_common] so that they appear in [S]. *)
  let arith_label = arith

  let com_label = com_label

  let qadv_label = qadv_label

  (* Number of input-commitment gates registered in [gates_map], under the
     selector names "qcom0" ... "qcom<nb_input_com - 1>". *)
  let nb_input_com = 3

  (* Registry of every supported gate, keyed by selector label. It holds:
     - the hand-written custom gates,
     - one linear-monomial gate per wire, plus one per next-wire (gX) variant
       of each linear selector,
     - [nb_input_com] input-commitment gates. *)
  let gates_map =
    let open Arithmetic_gates in
    let open Boolean_gates in
    let open Hash_gates in
    let open Ecc_gates in
    let open Mod_arith_gates in
    (* (linear selector name, wire index) for every wire of the architecture. *)
    let linear_monomials =
      let open Plompiler.Csir in
      List.init nb_wires_arch (fun i -> (linear_selector_name i, i))
    in
    SMap.of_list
      ([
         (Public.q_label, (module Public : Base_sig));
         (Constant.q_label, (module Constant));
         (Multiplication.q_label, (module Multiplication));
         (X5A.q_label, (module X5A));
         (X5C.q_label, (module X5C));
         (X2B.q_label, (module X2B));
         (AddWeierstrass.q_label, (module AddWeierstrass));
         (AddEdwards.q_label, (module AddEdwards));
         (ConditionalAddEdwards.q_label, (module ConditionalAddEdwards));
         (BoolCheck.q_label, (module BoolCheck));
         (CondSwap.q_label, (module CondSwap));
         (AnemoiDouble.q_label, (module AnemoiDouble));
         (AddMod25519.q_label, (module AddMod25519));
         (MulMod25519.q_label, (module MulMod25519));
         (AddMod64.q_label, (module AddMod64));
         (MulMod64.q_label, (module MulMod64));
       ]
      @ List.map (fun (q, i) -> (q, linear_monomial i q)) linear_monomials
      @ List.map
          (fun (q, i) ->
            let q = Plompiler.Csir.add_next_wire_suffix q in
            (q, linear_monomial ~is_next:true i q))
          linear_monomials
      @ List.init nb_input_com (fun i ->
            ( "qcom" ^ string_of_int i,
              (module InputCom (struct
                let idx = i
              end) : Base_sig) )))

  (* Selector labels of every registered gate. *)
  let gates_list = SMap.keys gates_map

  let nb_custom_gates = SMap.cardinal gates_map

  (* Keep only the entries of [gates] whose selector is registered in
     [gates_map]; unknown selectors are silently dropped. *)
  let filter_gates gates = SMap.(filter (fun q _ -> mem q gates_map) gates)

  (* Gate module registered under selector [q].
     @raise Failure if [q] is not a registered selector. *)
  let find_gate q =
    match SMap.find_opt q gates_map with
    | Some gate -> gate
    | None ->
        failwith
          (Printf.sprintf "\nCustom_gates.find_gate : unknown selector %s." q)

  (* Accessors projecting one component of the gate registered under [q].
     All of them raise [Failure] (through [find_gate]) on an unknown
     selector. *)
  let get_prover_identities q =
    let module M = (val find_gate q : Base_sig) in
    M.prover_identities

  let get_coms q =
    let module M = (val find_gate q : Base_sig) in
    M.index_com

  let get_nb_advs q =
    let module M = (val find_gate q : Base_sig) in
    M.nb_advs

  let get_nb_buffers q =
    let module M = (val find_gate q : Base_sig) in
    M.nb_buffers

  let get_verifier_identities q =
    let module M = (val find_gate q : Base_sig) in
    M.verifier_identities

  let get_polynomials_degree q =
    let module M = (val find_gate q : Base_sig) in
    M.polynomials_degree

  let get_eqs q =
    let module M = (val find_gate q : Base_sig) in
    M.equations

  let get_ids q =
    let module M = (val find_gate q : Base_sig) in
    M.identity

  let get_cs q =
    let module M = (val find_gate q : Base_sig) in
    M.cs

  let get_gx_composition q =
    let module M = (val find_gate q : Base_sig) in
    M.gx_composition

  (* [true] iff some registered gate of [gates] reports [gx_composition]; in
     that case [filter_answers] also keeps the wire evaluations at GX. *)
  let exists_gx_composition ~gates =
    SMap.exists (fun q _ -> get_gx_composition q) (filter_gates gates)

  (* Map [w -> get w] over the prefixed name [w = prefix (wire_name i)] of
     every wire index of the architecture. *)
  let wires_map prefix get =
    List.init Plompiler.Csir.nb_wires_arch (fun i ->
        let w = prefix @@ wire_name i in
        (w, get w))
    |> SMap.of_list

  (* Restrict [evaluations] to what [gates] (a list of selector labels) needs:
     "X", every prefixed wire and, for each gate, its input commitment (if
     any), its advice polynomials and its own (circuit-prefixed) selector.
     The public-input selector "qpub" is skipped. *)
  let filter_evaluations ~evaluations ~prefix ~circuit_prefix gates =
    let get_eval = Evaluations.find_evaluation evaluations in
    let base = SMap.singleton "X" (get_eval "X") in
    let wires = SMap.union_disjoint base (wires_map prefix get_eval) in
    List.fold_left
      (fun acc gate ->
        let acc =
          match get_coms gate with
          | None -> acc
          | Some idx ->
              let com_i = prefix com_label ^ string_of_int idx in
              SMap.add com_i (get_eval com_i) acc
        in
        let acc =
          List.init (get_nb_advs gate) (fun i ->
              circuit_prefix qadv_label ^ string_of_int i)
          |> List.fold_left
               (fun acc2 adv_i -> SMap.add adv_i (get_eval adv_i) acc2)
               acc
        in
        let gate = circuit_prefix gate in
        if gate = circuit_prefix "qpub" then acc
        else SMap.add gate (get_eval gate) acc)
      wires
      gates

  (* Verifier-side analogue of [filter_evaluations]. It returns a map from
     evaluation point to (name -> value):
     - under X: the prefixed wires plus, for each gate of [gates] (an SMap
       keyed by selector), its commitment, its advice values and its selector
       ("qpub" excepted);
     - under GX: the prefixed wires when some gate uses gX composition, and an
       empty map otherwise. *)
  let filter_answers ~answers ~prefix ~circuit_prefix gates =
    let x, gx = (string_of_eval_point X, string_of_eval_point GX) in
    let get_x, get_gx = (get_answer answers X, get_answer answers GX) in
    (* Add [k -> v] to the inner map bound to [k'] (no-op if [k'] is
       unbound). *)
    let add_mapmap k' k v mm =
      SMap.(update k' (fun m -> Option.bind m (fun m -> Some (add k v m)))) mm
    in
    let add_x = add_mapmap x in
    let wires =
      let base = SMap.singleton x (wires_map prefix get_x) in
      let wires_g =
        if exists_gx_composition ~gates then wires_map prefix get_gx
        else SMap.empty
      in
      SMap.add_unique gx wires_g base
    in
    SMap.fold
      (fun gate _ acc ->
        let acc =
          match get_coms gate with
          | None -> acc
          | Some idx ->
              let com_i = prefix com_label ^ string_of_int idx in
              add_x com_i (get_x com_i) acc
        in
        let acc =
          List.init (get_nb_advs gate) (fun i ->
              circuit_prefix qadv_label ^ string_of_int i)
          |> List.fold_left
               (fun acc2 adv_i -> add_x adv_i (get_x adv_i) acc2)
               acc
        in
        let gate = circuit_prefix gate in
        if gate = circuit_prefix "qpub" then acc
        else add_x gate (get_x gate) acc)
      gates
      wires

  (* Prover identities of all registered gates in [gates], merged into one
     map. Only the arithmetic identity [prefix (arith ^ ".0")] may be produced
     by several gates; those contributions are added together. Any other key
     collision trips the assertion in [union]. *)
  let aggregate_prover_identities ?(circuit_prefix = Fun.id) ~input_coms_size
      ~proof_prefix:prefix ~gates ~public_inputs ~domain () : prover_identities
      =
   fun evaluations ->
    let size_eval = Evaluations.size_evaluations evaluations in
    (* Registered gates, sorted by DEcreasing number of temporary buffers so
       that the head of the list is the most demanding gate. *)
    let gates =
      List.sort
        (fun gate1 gate2 ->
          Int.compare (get_nb_buffers gate2) (get_nb_buffers gate1))
        (filter_gates gates |> SMap.keys)
    in
    (* [tmp_buffers] is shared by every gate, so it must hold as many buffers
       as the most demanding gate needs (and at least one as soon as there is
       any gate). With an ascending sort the head would be the least
       demanding gate, and the array would be too small for the others. *)
    let nb_min_buffers =
      match gates with
      | [] -> 0
      | hd_gates :: _ -> max 1 (get_nb_buffers hd_gates)
    in
    tmp_buffers :=
      Array.init nb_min_buffers (fun _ -> Evaluations.create size_eval) ;
    let evaluations =
      filter_evaluations ~evaluations ~prefix ~circuit_prefix gates
    in
    (* Fresh accumulator for the arithmetic identity. *)
    let init_ids =
      let arith_acc_evaluation = Evaluations.create size_eval in
      SMap.singleton (prefix @@ arith ^ ".0") arith_acc_evaluation
    in
    (* [e1] comes from the accumulated map, so the sum is written into the
       accumulator ([~res:e1]). *)
    let union key e1 e2 =
      assert (key = prefix @@ arith ^ ".0") ;
      Some (Evaluations.add ~res:e1 e1 e2)
    in
    let public = {public_inputs; input_coms_size} in
    List.fold_left
      (fun accumulated_ids gate ->
        SMap.union
          union
          accumulated_ids
          (get_prover_identities
             gate
             ~prefix_common:circuit_prefix
             ~prefix
             ~public
             ~domain
             evaluations))
      init_ids
      gates

  (* Verifier identities of all registered gates in [gates], evaluated at [x]
     on the filtered [answers]. Arithmetic identities of all gates are summed
     (starting from zero); any other key collision trips the assertion.
     [input_com_sizes] is summed into the total input-commitment size. *)
  let aggregate_verifier_identities ?(circuit_prefix = Fun.id) ~input_com_sizes
      ~proof_prefix:prefix ~gates ~public_inputs ~generator ~size_domain () :
      verifier_identities =
   fun x answers ->
    let gates = filter_gates gates in
    let answers = filter_answers ~answers ~prefix ~circuit_prefix gates in
    let arith_id = SMap.singleton (prefix @@ arith ^ ".0") Scalar.zero in
    let union key s1 s2 =
      assert (key = prefix @@ arith ^ ".0") ;
      Some (Scalar.add s1 s2)
    in
    let public =
      {public_inputs; input_coms_size = List.fold_left ( + ) 0 input_com_sizes}
    in
    SMap.fold
      (fun gate _ accumulated_ids ->
        let gate_ids =
          get_verifier_identities
            gate
            ~prefix_common:circuit_prefix
            ~prefix
            ~public
            ~generator
            ~size_domain
        in
        SMap.union union accumulated_ids (gate_ids x answers))
      gates
      arith_id

  (* Maximum degree over all polynomials reported by the registered gates of
     [gates]; 0 when no registered gate is selected. *)
  let aggregate_polynomials_degree ~gates =
    SMap.fold
      (fun gate _ degree ->
        let map = get_polynomials_degree gate in
        SMap.fold (fun _ d acc -> max d acc) map degree)
      (filter_gates gates)
      0

  (* [cs_pi ~generator ~n ~x ~zs pi_list] builds a circuit computing
       (zs / (-n)) * sum_k pi_k / (generator^{-k} * x - 1)
     for [pi_list = [pi_0; pi_1; ...]], and zero for an empty list. This
     assumes [add_constant ~ql c x] computes [ql * x + c]. [zs] is presumably
     [x^n - 1] — confirm at the call site. *)
  let cs_pi ~generator ~n ~x ~zs pi_list =
    let open L in
    let open Num in
    let mone = Scalar.(negate one) in
    let n_inv = Scalar.(inverse_exn (negate n)) in
    let g_inv = Scalar.inverse_exn generator in
    match pi_list with
    | [] -> Num.zero
    | hd :: tl_pi ->
        let* left_term = mul_by_constant n_inv zs in
        (* k = 0 term: pi_0 / (x - 1). *)
        let* init_value =
          let* x_minus_one = add_constant mone x in
          div hd x_minus_one
        in
        (* Remaining terms; [g_inv_k] tracks generator^{-k}. *)
        let* sum, _g_inv_k =
          foldM
            (fun (res, g_inv_k) w_k ->
              let* x_g_inv_k_minus_one = add_constant ~ql:g_inv_k mone x in
              let* to_add = div w_k x_g_inv_k_minus_one in
              let* res = add res to_add in
              ret (res, Scalar.mul g_inv_k g_inv))
            (init_value, g_inv)
            tl_pi
        in
        mul left_term sum
end
(* Export only the [S] interface of [Aggregator]. Internal helpers such as
   [gates_map], [find_gate] and [filter_gates] are not part of [S] and are
   therefore hidden from users of this module. *)
include (Aggregator : S)