1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
module S = Csir.Scalar

(** [alpha] is [2^(numbits S.order) - S.order]: the distance between the
    scalar field order and the smallest power of two with at least as many
    bits. *)
let alpha =
  let nb_bits = Z.numbits S.order in
  Z.(sub (shift_left one nb_bits) S.order)
(** [bitlist ~le b] expands [b] into a list of [8 * Bytes.length b] bits.
    Within each byte the least-significant bit comes first. Bytes are visited
    from index [0] upward when [le] is [true], and from the last index
    downward when [le] is [false]. *)
let bitlist : le:bool -> bytes -> bool list =
 fun ~le b ->
  let nb_bytes = Bytes.length b in
  List.init (8 * nb_bytes) (fun i ->
      let byte_index = if le then i / 8 else nb_bytes - 1 - (i / 8) in
      let byte = Bytes.get_uint8 b byte_index in
      (byte lsr (i mod 8)) land 1 = 1)
(** [of_bitlist ~le bl] packs [bl] into bytes: each run of 8 consecutive bits
    becomes one byte, least-significant bit first. When [le] is [true] the
    [k]-th run becomes byte [k]. Otherwise byte order is reversed, so the
    first run becomes the last byte. This is the inverse of [bitlist ~le].
    @raise Assert_failure if [List.length bl] is not a multiple of 8. *)
let of_bitlist : le:bool -> bool list -> bytes =
 fun ~le bl ->
  assert (List.length bl mod 8 = 0) ;
  let bits = Array.of_list bl in
  let nb_bytes = Array.length bits / 8 in
  (* Assemble run [k] into an int, most-significant bit (position 7) first. *)
  let byte_of_chunk k =
    let v = ref 0 in
    for pos = 7 downto 0 do
      v := (!v lsl 1) lor Bool.to_int bits.((8 * k) + pos)
    done ;
    Char.chr !v
  in
  Bytes.init nb_bytes (fun i ->
      byte_of_chunk (if le then i else nb_bytes - 1 - i))
(** [bytes_of_hex hs] decodes the hexadecimal string [hs] into raw bytes
    using the [Hex] library. *)
let bytes_of_hex hs = Hex.to_bytes (`Hex hs)
(** [hex_of_bytes bs] renders [bs] as a hexadecimal string via [Hex]. *)
let hex_of_bytes bs = Hex.show (Hex.of_bytes bs)
(** [bool_list_to_scalar bits] interprets [bits] as a little-endian binary
    number (first element is the least-significant bit). It returns the
    corresponding scalar, computed with scalar-field arithmetic. *)
let bool_list_to_scalar : bool list -> S.t =
 fun b_list ->
  (* Horner's scheme: walk from the most-significant bit down, doubling the
     accumulator and adding one for every set bit. *)
  List.fold_left
    (fun acc bit ->
      let acc = S.double acc in
      if bit then S.(acc + one) else acc)
    S.zero
    (List.rev b_list)
(** [bool_list_to_z bits] interprets [bits] as a little-endian binary number
    (first element is the least-significant bit) and returns it as a
    non-negative [Z.t]. The empty list maps to zero. *)
let bool_list_to_z : bool list -> Z.t =
 fun b_list ->
  (* Horner's scheme from the most-significant bit down. *)
  List.fold_left
    (fun acc bit ->
      let acc = Z.shift_left acc 1 in
      if bit then Z.succ acc else acc)
    Z.zero
    (List.rev b_list)
(** [bool_list_of_z ?nb_bits z] returns the [nb_bits] lowest bits of [z],
    least-significant bit first. [nb_bits] defaults to [Z.numbits z]. When
    [nb_bits] exceeds the bit length of [z], the list is padded with [false].
    NOTE(review): bits are extracted with truncated remainder/division, so a
    negative [z] yields all-[false] bits rather than a two's-complement view;
    callers presumably pass non-negative values.
    @raise Invalid_argument if [nb_bits] is negative. *)
let bool_list_of_z : ?nb_bits:int -> Z.t -> bool list =
 fun ?nb_bits z ->
  let two = Z.of_int 2 in
  let nb_bits = Option.value ~default:(Z.numbits z) nb_bits in
  (* A negative count would never reach the [0] case below and the loop
     would run forever while allocating, so reject it up front. *)
  if nb_bits < 0 then
    invalid_arg "bool_list_of_z: nb_bits must be non-negative" ;
  let rec aux bits z = function
    | 0 -> List.rev bits
    | n ->
        let b = Z.(equal (z mod two) one) in
        aux (b :: bits) (Z.div z two) (n - 1)
  in
  aux [] z nb_bits
(** [split_exactly list size_chunk] cuts [list] into consecutive chunks of
    exactly [size_chunk] elements, preserving order.
    @raise Assert_failure if [List.length list] is not a multiple of
    [size_chunk].
    @raise Division_by_zero if [size_chunk = 0]. *)
let split_exactly list size_chunk =
  let len = List.length list in
  let nb_chunks = len / size_chunk in
  assert (len = size_chunk * nb_chunks) ;
  (* Convert to an array once, so each chunk is an O(size_chunk) slice
     instead of a fresh O(len) conversion per chunk. *)
  let array = Array.of_list list in
  List.init nb_chunks (fun i ->
      Array.to_list (Array.sub array (i * size_chunk) size_chunk))
(** [bool_list_change_endianness b] reverses the order of the 8-bit groups of
    [b], keeping the bit order inside each group.
    @raise Assert_failure if [List.length b] is not a multiple of 8. *)
let bool_list_change_endianness b =
  let nb_bits = List.length b in
  assert (nb_bits mod 8 = 0) ;
  let bytes = split_exactly b 8 in
  List.concat (List.rev bytes)
(** [limbs_of_bool_list ~nb_bits bl] splits [bl] into chunks of [nb_bits]
    bits and converts each chunk to a native [int]. The first bit of a chunk
    is its least-significant bit. The chunk size must divide the list length
    (see [split_exactly]). *)
let limbs_of_bool_list ~nb_bits bl =
  let limb_of_bits bits =
    List.fold_right (fun bit acc -> (acc lsl 1) + Bool.to_int bit) bits 0
  in
  split_exactly bl nb_bits |> List.map limb_of_bits
(** Zarith's [Z], extended with a [Repr] representation [t]. *)
module Z = struct
  include Z

  (** Binary [Repr] for [Z.t]: values are carried as the bytes produced by
      [to_bits] and decoded with [of_bits].
      NOTE(review): Zarith documents that [to_bits] encodes the absolute
      value only, so negative integers presumably do not round-trip —
      confirm callers only serialize non-negative values. *)
  let t : t Repr.t =
    let decode bs = of_bits (Bytes.unsafe_to_string bs) in
    let encode z = Bytes.of_string (to_bits z) in
    Repr.map Repr.bytes decode encode
end
(** [a %! b] is Zarith's truncated remainder [Z.rem a b]: the result takes
    the sign of [a]. *)
let ( %! ) a b = Z.rem a b
(** [next_multiple_of k n] rounds the positive integer [n] up to a multiple
    of [k]. Integer division truncates toward zero, so [n = 0] yields [k]
    rather than [0]. *)
let next_multiple_of k n =
  let quotient = ((n - 1) / k) + 1 in
  quotient * k
(** [is_power_of_2 n] holds when the rounded-down and rounded-up base-2
    logarithms of [n] agree, i.e. when [n] is an exact power of two.
    Zarith's [log2]/[log2up] reject non-positive arguments. *)
let is_power_of_2 n = Int.equal (Z.log2 n) (Z.log2up n)
(** [min_nb_limbs ~modulus ~base] is the smallest [k >= 1] such that
    [base^k >= modulus], i.e. the number of base-[base] limbs needed to hold
    any value in [\[0, modulus)].
    @raise Assert_failure unless [modulus > 1] and [base > 1]. *)
let min_nb_limbs ~modulus ~base =
  assert (Z.(modulus > one)) ;
  assert (Z.(base > one)) ;
  (* Invariant: [acc = base^k]. Compare with Zarith's own ordering instead of
     the polymorphic [>=], which walks the runtime representation. *)
  let rec aux acc k =
    if Z.(acc >= modulus) then k else aux Z.(acc * base) (k + 1)
  in
  aux base 1
(** [z_to_limbs ~len ~base n] decomposes [n] into exactly [len] base-[base]
    digits, most-significant limb first, left-padded with zeros.
    @raise Failure if [n] is negative or needs more than [len] limbs.
    @raise Invalid_argument if [base <= 1] (the digit loop would not
    terminate or would divide by zero). *)
let z_to_limbs ~len ~base n =
  let rec aux output n =
    let q, r = Z.div_rem n base in
    if Z.(q = zero) then r :: output else aux (r :: output) q
  in
  (* Use Zarith's comparison rather than polymorphic [<] on [Z.t]. *)
  if Z.(n < zero) then
    failwith "z_to_limbs: n must be greater than or equal to zero" ;
  if Z.(base <= one) then
    invalid_arg "z_to_limbs: base must be greater than one" ;
  let limbs = aux [] n in
  let nb_limbs = List.length limbs in
  if nb_limbs > len then
    failwith "z_to_limbs: n must be strictly lower than base^len"
  else List.init (len - nb_limbs) (Fun.const Z.zero) @ limbs
(** [z_of_limbs ~base limbs] evaluates [limbs] as base-[base] digits, most
    significant first. It is the inverse of [z_to_limbs]; the empty list
    gives zero. *)
let z_of_limbs ~base limbs =
  let step acc limb = Z.add (Z.mul base acc) limb in
  List.fold_left step Z.zero limbs
(** [mod_add_limbs ~modulus ~base xs ys] adds two limb vectors of equal
    length modulo [modulus]. The result is re-encoded with the same number of
    limbs.
    @raise Assert_failure if [xs] and [ys] differ in length. *)
let mod_add_limbs ~modulus ~base xs ys =
  let nb_limbs = List.length xs in
  assert (List.compare_length_with ys nb_limbs = 0) ;
  let x = z_of_limbs ~base xs in
  let y = z_of_limbs ~base ys in
  let z = Z.((x + y) %! modulus) in
  (* [%!] keeps the dividend's sign; lift negatives back into range. Use
     Zarith's comparison instead of polymorphic [<] on [Z.t]. *)
  let z = if Z.(z < zero) then Z.(z + modulus) else z in
  z_to_limbs ~len:nb_limbs ~base z
(** [mod_sub_limbs ~modulus ~base xs ys] computes [xs - ys] modulo [modulus].
    It negates every limb of [ys] and delegates to [mod_add_limbs]. *)
let mod_sub_limbs ~modulus ~base xs ys =
  let neg_ys = List.map Z.neg ys in
  mod_add_limbs ~modulus ~base xs neg_ys
(** [mod_mul_limbs ~modulus ~base xs ys] multiplies two limb vectors of equal
    length modulo [modulus]. The result is re-encoded with the same number
    of limbs.
    @raise Assert_failure if [xs] and [ys] differ in length. *)
let mod_mul_limbs ~modulus ~base xs ys =
  let nb_limbs = List.length xs in
  assert (List.compare_length_with ys nb_limbs = 0) ;
  let x = z_of_limbs ~base xs in
  let y = z_of_limbs ~base ys in
  let z = Z.(x * y %! modulus) in
  (* Lift a negative remainder into range using Zarith's comparison. *)
  let z = if Z.(z < zero) then Z.(z + modulus) else z in
  z_to_limbs ~len:nb_limbs ~base z
(** [mod_div_limbs ~modulus ~base xs ys] returns a [z] with
    [ys * z = xs (mod modulus)], as limbs of the same length as [xs].
    With [d = gcd(y, modulus)] and Bezout coefficient [y_inv], it computes
    [z = (x / d) * y_inv mod modulus]. This requires [d] to divide [x].
    @raise Failure if [x] is not divisible by [gcd(y, modulus)].
    @raise Assert_failure if [xs] and [ys] differ in length. *)
let mod_div_limbs ~modulus ~base xs ys =
  let nb_limbs = List.length xs in
  assert (List.compare_length_with ys nb_limbs = 0) ;
  let x = z_of_limbs ~base xs in
  let y = z_of_limbs ~base ys in
  let d, y_inv, _v = Z.gcdext y modulus in
  if Z.(rem x d <> zero) then
    raise
    @@ Failure
         (Format.sprintf
            "mod_div_limbs: %s is not divisible by %s (modulo %s)"
            (Z.to_string x)
            (Z.to_string y)
            (Z.to_string modulus)) ;
  let z = Z.(divexact x d * y_inv %! modulus) in
  (* Lift a negative remainder into range using Zarith's comparison. *)
  let z = if Z.(z < zero) then Z.(z + modulus) else z in
  z_to_limbs ~len:nb_limbs ~base z
(** [transpose rows] turns a list of rows into a list of columns. It stops
    as soon as the first row is exhausted.
    @raise Failure if a later row is shorter than the first. *)
let rec transpose rows =
  match rows with
  | [] | [] :: _ -> []
  | _ ->
      (* Recurse on the tails before taking the heads, matching the
         evaluation order of [heads :: transpose tails]. *)
      let rest = transpose (List.map List.tl rows) in
      let heads = List.map List.hd rows in
      heads :: rest
(** [of_bytes repr bs] decodes [bs] with [repr]'s binary format.
    @raise Invalid_argument (from [Result.get_ok]) if decoding fails. *)
let of_bytes repr bs =
  let decode = Repr.(unstage (of_bin_string repr)) in
  Bytes.unsafe_to_string bs |> decode |> Stdlib.Result.get_ok
(** [to_bytes repr e] encodes [e] with [repr]'s binary format. *)
let to_bytes repr e =
  let encode = Repr.(unstage (to_bin_string repr)) in
  Bytes.unsafe_of_string (encode e)
(** [Repr] for the pair (table names, constraint system) stored on disk. *)
let tables_cs_encoding_t : (string list * Csir.CS.t) Repr.t =
  Repr.(pair (list string) Csir.CS.t)
(** [save_cs_to_file path tables cs] writes [(tables, cs)] as JSON (via
    [tables_cs_encoding_t]) to [path], truncating any existing file. The
    channel is closed even if writing raises. *)
let save_cs_to_file path tables cs =
  let s = Repr.to_json_string tables_cs_encoding_t (tables, cs) in
  let outc = open_out path in
  (* [close_out] in the body reports flush errors; [close_out_noerr] in
     [finally] only guarantees release on the exceptional path (closing an
     already-closed channel is a no-op). *)
  Fun.protect
    ~finally:(fun () -> close_out_noerr outc)
    (fun () ->
      output_string outc s ;
      close_out outc)
(** [load_cs_from_file path] reads a [(tables, cs)] pair previously written
    by [save_cs_to_file].
    @raise Invalid_argument if [path] does not exist, or (from
    [Result.get_ok]) if the contents do not decode. *)
let load_cs_from_file path =
  if not (Sys.file_exists path) then
    raise
    @@ Invalid_argument
         (Printf.sprintf "load_cs_from_file: %s does not exist." path) ;
  (* Binary mode: [in_channel_length] counts raw bytes, which only matches
     what [really_input_string] can read when no newline translation
     happens. *)
  let inc = open_in_bin path in
  let content =
    Fun.protect
      ~finally:(fun () -> close_in_noerr inc)
      (fun () -> really_input_string inc (in_channel_length inc))
  in
  (* Decode after the channel is closed so a parse error cannot leak it. *)
  Repr.of_json_string tables_cs_encoding_t content |> Stdlib.Result.get_ok
(** [get_circuit_id cs] is the hex-encoded 32-byte Blake2b digest of the
    binary serialization of [cs]. *)
let get_circuit_id cs =
  let serialized = to_bytes Csir.CS.t cs in
  let digest = Hacl_star.Hacl.Blake2b_32.hash serialized 32 in
  Hex.show (Hex.of_bytes digest)
(** Directory for cached circuits: ["$TMPDIR/plompiler"], or
    ["/tmp/plompiler"] when [TMPDIR] is unset. Evaluated once, at module
    initialization. *)
let circuit_dir =
  let base = Option.value ~default:"/tmp" (Sys.getenv_opt "TMPDIR") in
  base ^ "/plompiler"
(** [circuit_path s] is the path of [s] inside [circuit_dir]. It creates
    [circuit_dir] (mode 0o755) first if it does not exist yet. *)
let circuit_path s =
  (* Another process may create the directory between the existence check
     and [mkdir]; treat that race as success instead of failing. *)
  (if not (Sys.file_exists circuit_dir) then
     try Sys.mkdir circuit_dir 0o755
     with Sys_error _ when Sys.file_exists circuit_dir -> ()) ;
  circuit_dir ^ "/" ^ s
(** [dump_label_traces path cs] writes one line per constraint of [cs] to
    [path]. Each line is the constraint's [label] list, reversed and joined
    with ["; "], followed by [" 1"]. The file is truncated first, and the
    channel is closed even if writing raises. *)
let dump_label_traces path (cs : Csir.CS.t) =
  let outc = open_out path in
  Fun.protect
    ~finally:(fun () -> close_out_noerr outc)
    (fun () ->
      List.iter
        Csir.CS.(
          Array.iter (fun c ->
              Printf.fprintf outc "%s 1\n"
              @@ String.concat "; " (List.rev c.label)))
        cs ;
      close_out outc)
(** [dump_label_range_checks_traces path fg] writes one line per
    [(label, nb)] pair of [fg] to [path]: [label] joined with ["; "], a
    space, then [nb]. The file is truncated first, and the channel is closed
    even if writing raises. *)
let dump_label_range_checks_traces path fg =
  let outc = open_out path in
  Fun.protect
    ~finally:(fun () -> close_out_noerr outc)
    (fun () ->
      List.iter
        (fun (label, nb) ->
          Printf.fprintf outc "%s %d\n" (String.concat "; " label) nb)
        fg ;
      close_out outc)