Source file lang_stdlib.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
(*****************************************************************************)
(*                                                                           *)
(* MIT License                                                               *)
(* Copyright (c) 2022 Nomadic Labs <contact@nomadic-labs.com>                *)
(*                                                                           *)
(* Permission is hereby granted, free of charge, to any person obtaining a   *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense,  *)
(* and/or sell copies of the Software, and to permit persons to whom the     *)
(* Software is furnished to do so, subject to the following conditions:      *)
(*                                                                           *)
(* The above copyright notice and this permission notice shall be included   *)
(* in all copies or substantial portions of the Software.                    *)
(*                                                                           *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL   *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING   *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER       *)
(* DEALINGS IN THE SOFTWARE.                                                 *)
(*                                                                           *)
(*****************************************************************************)

open Lang_core

(** Interface of the library of gadgets built by [Lib] on top of a [COMMON]
    implementation: monadic list combinators plus derived gadgets on
    booleans, scalars and bit lists. *)
module type LIB = sig
  include COMMON

  (** [foldiM f e n] folds [f] over the indices [0], ..., [n-1], starting
      from [e]. *)
  val foldiM : ('a -> int -> 'a t) -> 'a -> int -> 'a t

  (** [fold2M f acc ls rs] folds [f] over the pairs formed by [ls] and
      [rs], which must have the same length. *)
  val fold2M : ('a -> 'b -> 'c -> 'a t) -> 'a -> 'b list -> 'c list -> 'a t

  (** Monadic map; the results keep the order of the input list. *)
  val mapM : ('a -> 'b t) -> 'a list -> 'b list t

  (** Monadic map over two lists of the same length. *)
  val map2M : ('a -> 'b -> 'c t) -> 'a list -> 'b list -> 'c list t

  (** [iterM f l] runs [f] on each element of [l] in order and returns the
      result of the last call ([unit] when [l] is empty). *)
  val iterM : ('a -> unit repr t) -> 'a list -> unit repr t

  (** Same as [iterM], over the pairs formed by two lists of the same
      length. *)
  val iter2M : ('a -> 'b -> unit repr t) -> 'a list -> 'b list -> unit repr t

  module Bool : sig
    include BOOL

    (** [full_adder a b c_in] returns the pair (s, c_out) with
        s = a xor b xor c_in and c_out = ((a xor b) and c_in) or (a and b),
        as per https://en.wikipedia.org/wiki/Adder_(electronics)#Full_adder *)
    val full_adder : bool repr -> bool repr -> bool repr -> (bool * bool) repr t
  end
  with type scalar = scalar
   and type 'a repr = 'a repr
   and type 'a t = 'a t

  module Num : sig
    include NUM

    (** [square x] returns [x * x]. *)
    val square : scalar repr -> scalar repr t

    (** [pow x bits] returns [x] raised to the power [n], where [bits] is
        the binary decomposition of [n], least significant bit first. *)
    val pow : scalar repr -> bool repr list -> scalar repr t

    (** [add_list ~qc ~coeffs l] returns the linear combination
        [qc + Σ coeffs_i · l_i] (assuming the usual [ql·x + qr·y + qc]
        semantics of [Num.add]). [qc] defaults to zero. An empty [coeffs]
        (the default) stands for all-ones coefficients; otherwise [coeffs]
        must have the same length as [l]. *)
    val add_list :
      ?qc:S.t -> ?coeffs:S.t list -> scalar list repr -> scalar repr t

    (** [mul_list l] returns the product of the elements of [l]. *)
    val mul_list : scalar list repr -> scalar repr t

    (** [mul_by_constant s x] returns [s · x]. *)
    val mul_by_constant : S.t -> scalar repr -> scalar repr t

    (** [scalar_of_bytes b] returns [Σ b_i · 2^i], reading the bit list [b]
        least significant bit first. *)
    val scalar_of_bytes : bool list repr -> scalar repr t

    (** [is_eq_const x s] returns whether [x] equals the constant [s]. *)
    val is_eq_const : scalar repr -> S.t -> bool repr t

    (** [assert_eq_const x s] constrains [x] to be equal to the constant
        [s]. *)
    val assert_eq_const : scalar repr -> S.t -> unit repr t

    (** [is_upper_bounded ~bound x] returns whether the scalar [x] is
        strictly lower than [bound] when [x] is interpreted as an integer
        from [0] to [p-1] (being [p] the scalar field order).
        This circuit is total (and more expensive than
        [is_upper_bounded_unsafe]). *)
    val is_upper_bounded : bound:Z.t -> scalar repr -> bool repr t

    (** Same as [is_upper_bounded] but cheaper and partial.
        [is_upper_bounded_unsafe ?nb_bits ~bound x] is unsatisfiable if [x]
        cannot be represented in binary with [nb_bits] bits, which defaults
        to [Z.numbits bound] (or [1] when [bound] is zero).
        Raises [Invalid_argument] when [bound] is zero. *)
    val is_upper_bounded_unsafe :
      ?nb_bits:int -> bound:Z.t -> scalar repr -> bool repr t

    (** [geq (a, bound_a) (b, bound_b)] returns the boolean wire representing
        a >= b.
        Pre-condition: a ∈ [0, bound_a) ∧ b ∈ [0, bound_b) *)
    val geq : scalar repr * Z.t -> scalar repr * Z.t -> bool repr t
  end
  with type scalar = scalar
   and type 'a repr = 'a repr
   and type 'a t = 'a t

  (** Selection among [N.n] cases indexed by a scalar. *)
  module Enum (N : sig
    val n : int
  end) : sig
    (** [switch_case k l] returns the k-th element of the list [l] if
        k ∈ [0,n) or the first element of [l] otherwise. [l] must contain
        exactly [n] elements. *)
    val switch_case : scalar repr -> 'a list repr -> 'a repr t
  end

  (** Gadgets on lists of bits. *)
  module Bytes : sig
    type bl = bool list

    (** [add ?ignore_carry a b] adds two bit lists of the same length, the
        carry flowing from the head of the list towards its tail (the head
        is the least significant bit). The result has one extra bit, the
        final carry, unless [ignore_carry] is [true] (default: [false]). *)
    val add : ?ignore_carry:bool -> bl repr -> bl repr -> bl repr t

    (** Bitwise xor of two bit lists of the same length. *)
    val xor : bl repr -> bl repr -> bl repr t

    (** [rotate a i] moves the first [i] elements of [a] to its end. Raises
        [Invalid_argument] if [i] is negative or larger than the length of
        [a]. *)
    val rotate : bl repr -> int -> bl repr
  end

  (** Component-wise sum of two pairs of scalars. *)
  val add2 :
    (scalar * scalar) repr -> (scalar * scalar) repr -> (scalar * scalar) repr t

  (** [constant_bool b] returns a boolean wire fixed to [b]. *)
  val constant_bool : bool -> bool repr t

  (** [constant_bytes ?le b] turns the bits of [b], as listed by
      [Utils.bitlist ~le] ([le] defaults to [false]), into constant boolean
      wires. NOTE(review): [le] presumably selects little-endian bit order;
      confirm against [Utils.bitlist]. *)
  val constant_bytes : ?le:bool -> bytes -> Bytes.bl repr -> Bytes.bl repr t

  (** [constant_uint32 ?le u] writes [u] as 4 big-endian bytes and passes
      them to [constant_bytes ~le]. *)
  val constant_uint32 : ?le:bool -> Stdint.uint32 -> Bytes.bl repr t
end

(** [Lib (C)] extends the [COMMON] primitives of [C] with monadic list
    combinators and derived gadgets: full adder, exponentiation, linear
    combinations, range checks, bit-list arithmetic and constants. *)
module Lib (C : COMMON) = struct
  include C

  (** [foldiM f e n] folds [f] over the indices [0, 1, ..., n-1], starting
      from [e]. *)
  let foldiM : ('a -> int -> 'a t) -> 'a -> int -> 'a t =
   fun f e n -> foldM f e (List.init n Fun.id)

  (** [fold2M f acc ls rs] folds [f] over the pairs of [ls] and [rs].
      @raise Invalid_argument if the lists have different lengths
      (raised by [List.combine]). *)
  let fold2M f acc ls rs =
    foldM (fun acc (l, r) -> f acc l r) acc (List.combine ls rs)

  (** Monadic [List.map]: applies [f] from left to right and returns the
      results in the order of [l]. *)
  let mapM f l =
    let* l =
      foldM
        (fun acc e ->
          let* e = f e in
          ret @@ (e :: acc))
        []
        l
    in
    (* The results were accumulated in reverse order. *)
    ret @@ List.rev l

  (** Monadic [List.map2].
      @raise Invalid_argument if the lists have different lengths. *)
  let map2M f ls rs = mapM (fun (l, r) -> f l r) (List.combine ls rs)

  (** [iterM f l] runs [f] on every element of [l] in order and returns the
      result of the last call, or [unit] when [l] is empty. *)
  let iterM f l = foldM (fun _ a -> f a) unit l

  (** Same as [iterM] on the pairs of [l] and [r].
      @raise Invalid_argument if the lists have different lengths. *)
  let iter2M f l r = iterM (fun (l, r) -> f l r) (List.combine l r)

  module Bool = struct
    include Bool

    (** Full adder: returns the pair [(a ⊕ b ⊕ c_in, c_out)] where
        [c_out = ((a ⊕ b) ∧ c_in) ∨ (a ∧ b)]. *)
    let full_adder a b c_in =
      let* a_xor_b = xor a b in
      let* a_xor_b_xor_c = xor a_xor_b c_in in
      let* a_xor_b_and_c = band a_xor_b c_in in
      let* a_and_b = band a b in
      let* c = bor a_xor_b_and_c a_and_b in
      ret (pair a_xor_b_xor_c c)
  end

  module Num = struct
    include Num

    (** [square x] returns [x * x]. *)
    let square l = mul l l

    (** [pow x bits] returns [x] to the power [n] by square-and-multiply,
        where [bits] is the binary decomposition of [n], least significant
        bit first. [res] accumulates the result while [acc] runs through
        [x], [x^2], [x^4], ... *)
    let pow x n_list =
      let* init =
        let* left = constant_scalar S.one in
        ret (left, x)
      in
      let* res, _acc =
        foldM
          (fun (res, acc) bool ->
            let* res_true = mul res acc in
            let* res = Bool.ifthenelse bool res_true res in
            let* acc = mul acc acc in
            ret (res, acc))
          init
          n_list
      in
      ret res

    (** [add_list ~qc ~coeffs l] returns the linear combination
        [qc + Σ coeffs_i · l_i] (assuming the usual [ql·x + qr·y + qc]
        semantics of [Num.add]). An empty [coeffs] (the default) stands for
        all-ones coefficients; otherwise [coeffs] must have the length of
        [l]. On an empty [l] the result is the constant [qc]. *)
    let add_list ?(qc = S.zero) ?(coeffs = []) l =
      let l = of_list l in
      (* Structural match instead of the physical test [coeffs != []]. *)
      let q =
        match coeffs with [] -> List.map (fun _ -> S.one) l | _ -> coeffs
      in
      assert (List.compare_lengths q l = 0) ;
      match (l, q) with
      | x1 :: x2 :: xs, ql :: qr :: qs ->
          let* res = Num.add ~qc ~ql ~qr x1 x2 in
          fold2M (fun acc x ql -> Num.add ~ql x acc) res xs qs
      | [x], [ql] -> Num.add_constant ~ql qc x
      | [], [] -> constant_scalar qc
      | _, _ -> assert false

    (** [mul_list l] returns the product of the elements of [l]. The empty
        product is the constant one, mirroring [add_list] which returns the
        constant [qc] on an empty list. *)
    let mul_list l =
      match of_list l with
      | [] -> constant_scalar S.one
      | x :: xs -> foldM Num.mul x xs

    (** [mul_by_constant s x] returns [s · x]. *)
    let mul_by_constant s x = Num.add_constant ~ql:s S.zero x

    (** [scalar_of_bytes b] returns [Σ b_i · 2^i], reading the bit list [b]
        least significant bit first. It evaluates P(X) = Σ_i b_i X^i at 2
        with Horner method: P(2) = b_0 + 2 (b_1 + 2 (b_2 + 2 (…))). *)
    let scalar_of_bytes b =
      let* zero = constant_scalar S.zero in
      foldM
        (fun acc b -> add acc (scalar_of_bool b) ~ql:S.(one + one) ~qr:S.one)
        zero
        (List.rev (of_list b))

    (** [assert_eq_const x s] constrains [x] to equal the constant [s]
        through a custom gate with [ql = -1] and [qc = s]. *)
    let assert_eq_const l s = Num.assert_custom ~ql:S.mone ~qc:s l l l

    (** [is_eq_const x s] returns whether [x = s], by testing [s - x] for
        zero. *)
    let is_eq_const l s =
      let* diff = add_constant ~ql:S.mone s l in
      is_zero diff

    (* Shared by [is_upper_bounded], [is_upper_bounded_unsafe] and [geq].
       Pairs the bits of [xbits] with the [nb_bits]-bit little-endian binary
       decomposition of [bound] and drops the pairs below the least
       significant set bit of [bound]: in the comparison fold below those
       steps compute [x_i ∨ true = true], so the first kept [x_k] can be used
       directly as the initial accumulator. Returns [(x_k, remaining_pairs)].
       Raises [Invalid_argument] if [bound] has no set bit (bound = 0). *)
    let ignore_leading_zeros ~nb_bits ~bound xbits =
      let rec shave_zeros = function
        | [] ->
            raise
              (Invalid_argument
                 "is_upper_bounded cannot be satisfied on bound = 0")
        | (x, true) :: tl -> (x, tl)
        | (_, false) :: tl -> shave_zeros tl
      in
      List.combine (of_list xbits) (Utils.bool_list_of_z ~nb_bits bound)
      |> shave_zeros

    (* Let [(bn,...,b0)] and [(xn,...,x0)] be the binary representations of
       [bound] and [x], bit [0] being the least significant one, and let
       [op_i = if b_i = 1 then band else bor]. The predicate [x >= bound] is
       [xn op_n (... (x1 op_1 (x0 op_0 true)))]: scanning from the least
       significant bit, we carry a flag telling whether [x] restricted to
       bits [0..i] is greater than or equal to [bound] restricted to the
       same bits.
       - if b_i = 1, then x_i must be 1 as well and the flag must be true;
       - if b_i = 0, then x_i = 1 suffices, otherwise the flag is kept.
       [geq_of_bits] folds this over the pairs returned by
       [ignore_leading_zeros]; [x < bound] is its negation. *)
    let geq_of_bits (init, xibi) =
      foldM
        (fun acc (xi, bi) -> if bi then Bool.band xi acc else Bool.bor xi acc)
        init
        xibi

    (** Cheaper, partial version of [is_upper_bounded]: [x] is decomposed on
        [nb_bits] bits (default [Z.numbits bound], or [1] when [bound = 0]),
        so the circuit is unsatisfiable if [x] does not fit in that many
        bits. Returns whether [x < bound].
        @raise Invalid_argument if [bound = 0]. *)
    let is_upper_bounded_unsafe ?nb_bits ~bound x =
      assert (Z.leq Z.zero bound && Z.lt bound S.order) ;
      let nb_bits =
        Option.value
          ~default:(if Z.(equal bound zero) then 1 else Z.numbits bound)
          nb_bits
      in
      let* xbits = bits_of_scalar ~nb_bits x in
      let* geq = geq_of_bits (ignore_leading_zeros ~nb_bits ~bound xbits) in
      Bool.bnot geq

    (** Total version: [x] is decomposed on [Z.numbits S.order] bits with
        [~shift:Utils.alpha] (presumably decomposing [x + alpha]) and compared
        against [bound + Utils.alpha]. Returns whether [x < bound]. *)
    let is_upper_bounded ~bound x =
      assert (Z.leq Z.zero bound && Z.lt bound S.order) ;
      let nb_bits = Z.numbits S.order in
      let bound_plus_alpha = Z.(bound + Utils.alpha) in
      let* xbits = bits_of_scalar ~shift:Utils.alpha ~nb_bits x in
      let* geq =
        geq_of_bits
          (ignore_leading_zeros ~nb_bits ~bound:bound_plus_alpha xbits)
      in
      Bool.bnot geq

    (** [geq (a, bound_a) (b, bound_b)] returns whether [a >= b], assuming
        a ∈ [0, bound_a) and b ∈ [0, bound_b). It decomposes
        [(a - b) + bound_b - 1], which lies in [0, bound_a + bound_b - 1),
        and compares it against [bound_b - 1]. When [bound_b = 1], [b] can
        only be zero and the result is the constant [true]. *)
    let geq (a, bound_a) (b, bound_b) =
      if Z.equal bound_b Z.one then (
        (* Comparing against bound_b - 1 = 0 would make
           [ignore_leading_zeros] raise, although a >= 0 trivially holds. *)
        let* one = constant_scalar S.one in
        ret (unsafe_bool_of_scalar one))
      else
        let* shifted_diff =
          Num.add ~qr:S.mone ~qc:(S.of_z Z.(pred bound_b)) a b
        in
        let nb_bits = Z.(numbits @@ pred (add bound_a bound_b)) in
        let* bits = bits_of_scalar ~nb_bits shifted_diff in
        geq_of_bits
          (ignore_leading_zeros ~nb_bits ~bound:Z.(pred bound_b) bits)
  end

  module Enum (N : sig
    val n : int
  end) =
  struct
    (** [switch_case k cases] returns the [k]-th element of [cases] if
        k ∈ [0, N.n), and its first element otherwise. [cases] must contain
        exactly [N.n] elements.
        @raise Invalid_argument if [cases] is empty. *)
    let switch_case k cases =
      let cases = of_list cases in
      assert (List.compare_length_with cases N.n = 0) ;
      match cases with
      | [] -> invalid_arg "Enum.switch_case: empty list of cases"
      | first :: rest ->
          (* Start from case 0 and, for each i >= 1, select case i when
             k = i. *)
          let indexed = List.mapi (fun i ci -> (i + 1, ci)) rest in
          foldM
            (fun c (i, ci) ->
              let* f = Num.is_eq_const k (S.of_z (Z.of_int i)) in
              Bool.ifthenelse f ci c)
            first
            indexed
  end

  module Bytes = struct
    type bl = bool list

    (** [add ?ignore_carry a b] adds two bit lists of the same length, the
        head being the least significant bit. The result carries one extra
        bit (the final carry) unless [ignore_carry] is [true].
        @raise Failure if the lists are empty.
        @raise Invalid_argument if the lists have different lengths. *)
    let add ?(ignore_carry = false) a b =
      let la = of_list a and lb = of_list b in
      (* The least significant bits only need a half adder. *)
      let ha, ta = (List.hd la, List.tl la) in
      let hb, tb = (List.hd lb, List.tl lb) in
      let* a_xor_b = Bool.xor ha hb in
      let* a_and_b = Bool.band ha hb in
      let* res, carry =
        fold2M
          (fun (res, c) a b ->
            let* p = Bool.full_adder a b c in
            let s, c = of_pair p in
            ret (s :: res, c))
          ([a_xor_b], a_and_b)
          ta
          tb
      in
      ret @@ to_list @@ List.rev (if ignore_carry then res else carry :: res)

    (** Bitwise xor of two bit lists of the same length. *)
    let xor a b =
      let* l = map2M Bool.xor (of_list a) (of_list b) in
      ret @@ to_list l

    (** [rotate a i] moves the first [i] elements of [a] to its end.
        @raise Invalid_argument if [i] is negative or larger than the
        length of [a]. *)
    let rotate a i =
      let split_n n l =
        let rec aux acc k l =
          if k = n then (List.rev acc, l)
          else
            match l with
            | h :: t -> aux (h :: acc) (k + 1) t
            | [] ->
                raise
                  (Invalid_argument
                     (Printf.sprintf "split_n: n=%d >= List.length l=%d" n k))
        in
        aux [] 0 l
      in
      (* A negative count would walk the whole list and then report a
         misleading length error. *)
      if i < 0 then invalid_arg "Bytes.rotate: negative rotation" ;
      let head, tail = split_n i (of_list a) in
      to_list @@ tail @ head
  end

  (** Component-wise sum of two pairs of scalars. *)
  let add2 p1 p2 =
    let x1, y1 = of_pair p1 in
    let x2, y2 = of_pair p2 in
    let* x3 = Num.add x1 x2 in
    let* y3 = Num.add y1 y2 in
    ret (pair x3 y3)

  (** [constant_bool b] returns a boolean wire fixed to [b]. *)
  let constant_bool b =
    let* bit = constant_scalar (if b then S.one else S.zero) in
    ret @@ unsafe_bool_of_scalar bit

  (** [constant_bytes ?le b] turns each bit of [b], in the order given by
      [Utils.bitlist ~le], into a constant boolean wire. *)
  let constant_bytes ?(le = false) b =
    let* ws = mapM constant_bool (Utils.bitlist ~le b) in
    ret @@ to_list ws

  (** [constant_uint32 ?le u32] encodes [u32] as 4 big-endian bytes and
      passes them to [constant_bytes ~le]. *)
  let constant_uint32 ?(le = false) u32 =
    let b = Stdlib.Bytes.create 4 in
    Stdint.Uint32.to_bytes_big_endian u32 b 0 ;
    constant_bytes ~le b
end