(*****************************************************************************)
(*                                                                           *)
(* MIT License                                                               *)
(* Copyright (c) 2022 Nomadic Labs <contact@nomadic-labs.com>                *)
(*                                                                           *)
(* Permission is hereby granted, free of charge, to any person obtaining a   *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense,  *)
(* and/or sell copies of the Software, and to permit persons to whom the     *)
(* Software is furnished to do so, subject to the following conditions:      *)
(*                                                                           *)
(* The above copyright notice and this permission notice shall be included   *)
(* in all copies or substantial portions of the Software.                    *)
(*                                                                           *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL   *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING   *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER       *)
(* DEALINGS IN THE SOFTWARE.                                                 *)
(*                                                                           *)
(*****************************************************************************)

open Lang_core

(** [COMMON] extended with monadic list combinators and derived gadgets
    (adders, comparisons, switches, bit-list helpers). *)
module type LIB = sig
  include COMMON

  (** [foldiM f e n] folds [f] over the indices [0; 1; ...; n-1],
      starting from [e]. *)
  val foldiM : ('a -> int -> 'a t) -> 'a -> int -> 'a t

  (** [fold2M f acc ls rs] folds [f] over the pairs of [ls] and [rs].
      Raises [Invalid_argument] if the lists have different lengths. *)
  val fold2M : ('a -> 'b -> 'c -> 'a t) -> 'a -> 'b list -> 'c list -> 'a t

  (** [mapM f l] applies [f] to each element of [l], in order. *)
  val mapM : ('a -> 'b t) -> 'a list -> 'b list t

  (** [map2M f ls rs] applies [f] pointwise to [ls] and [rs].
      Raises [Invalid_argument] if the lists have different lengths. *)
  val map2M : ('a -> 'b -> 'c t) -> 'a list -> 'b list -> 'c list t

  (** [iterM f l] applies [f] to each element of [l], in order. *)
  val iterM : ('a -> unit repr t) -> 'a list -> unit repr t

  (** [iter2M f ls rs] applies [f] pointwise to [ls] and [rs].
      Raises [Invalid_argument] if the lists have different lengths. *)
  val iter2M : ('a -> 'b -> unit repr t) -> 'a list -> 'b list -> unit repr t

  module Bool : sig
    include BOOL

    (** [full_adder a b c_in] returns the pair (s, c_out), where
        s = a xor b xor c_in and c_out = ((a xor b) and c_in) or (a and b),
        as per https://en.wikipedia.org/wiki/Adder_(electronics)#Full_adder *)
    val full_adder :
      bool repr -> bool repr -> bool repr -> (bool * bool) repr t
  end
  with type scalar = scalar
   and type 'a repr = 'a repr
   and type 'a t = 'a t

  module Num : sig
    include NUM

    (** [square x] returns [x * x]. *)
    val square : scalar repr -> scalar repr t

    (** [pow x bits] returns [x^n], where [bits] is the binary decomposition
        of [n], least significant bit first. *)
    val pow : scalar repr -> bool repr list -> scalar repr t

    (** [add_list ~qc ~coeffs l] returns the linear combination of [l] with
        coefficients [coeffs] (all ones by default) plus the constant [qc]
        (zero by default). [coeffs], when given, must have the length of [l]. *)
    val add_list :
      ?qc:S.t -> ?coeffs:S.t list -> scalar list repr -> scalar repr t

    (** [mul_list l] returns the product of the elements of the non-empty
        list [l]. *)
    val mul_list : scalar list repr -> scalar repr t

    (** [mul_by_constant s x] returns [s * x]. *)
    val mul_by_constant : S.t -> scalar repr -> scalar repr t

    (** [scalar_of_bytes b] returns the scalar whose binary decomposition is
        [b], least significant bit first. *)
    val scalar_of_bytes : bool list repr -> scalar repr t

    (** [is_eq_const x s] returns whether [x = s]. *)
    val is_eq_const : scalar repr -> S.t -> bool repr t

    (** [assert_eq_const x s] constrains [x] to be equal to [s]. *)
    val assert_eq_const : scalar repr -> S.t -> unit repr t

    (** [is_upper_bounded ~bound x] returns whether the scalar [x] is
        strictly lower than [bound] when [x] is interpreted as an integer
        from [0] to [p-1] (being [p] the scalar field order).
        This circuit is total (and more expensive than
        [is_upper_bounded_unsafe] below).
        Requires [0 <= bound < p]. *)
    val is_upper_bounded : bound:Z.t -> scalar repr -> bool repr t

    (** Same as [is_upper_bounded] but cheaper and partial.
        [is_upper_bounded_unsafe ?nb_bits ~bound x] is unsatisfiable if [x]
        cannot be represented in binary with [nb_bits] bits, where [nb_bits]
        defaults to [Z.numbits bound] (or [1] when [bound] is zero).
        Requires [0 <= bound < p]. With [bound = 0] the predicate [x < 0]
        cannot hold and circuit construction raises [Invalid_argument]. *)
    val is_upper_bounded_unsafe :
      ?nb_bits:int -> bound:Z.t -> scalar repr -> bool repr t

    (** [geq (a, bound_a) (b, bound_b)] returns the boolean wire representing
        a >= b.
        Pre-condition: a ∈ [0, bound_a) ∧ b ∈ [0, bound_b) *)
    val geq : scalar repr * Z.t -> scalar repr * Z.t -> bool repr t
  end
  with type scalar = scalar
   and type 'a repr = 'a repr
   and type 'a t = 'a t

  module Enum (N : sig
    val n : int
  end) : sig
    (** [switch_case k l] returns the k-th element of the list [l] if
        k ∈ [0,n), or the first element of [l] otherwise.
        [l] must have exactly [n] elements. *)
    val switch_case : scalar repr -> 'a list repr -> 'a repr t
  end

  module Bytes : sig
    (** Bit lists, processed head first. *)
    type bl = bool list

    (** [add ?ignore_carry a b] adds two equally long, non-empty bit lists,
        head first. The final carry is appended at the end of the result
        unless [ignore_carry] is [true] (default [false]). *)
    val add : ?ignore_carry:bool -> bl repr -> bl repr -> bl repr t

    (** [xor a b] is the pointwise xor of two equally long bit lists. *)
    val xor : bl repr -> bl repr -> bl repr t

    (** [rotate a i] moves the first [i] elements of [a] to its end.
        Raises [Invalid_argument] if [i] is negative or exceeds the length
        of [a]. *)
    val rotate : bl repr -> int -> bl repr
  end

  (** [add2 (x1, y1) (x2, y2)] returns [(x1 + x2, y1 + y2)]. *)
  val add2 :
    (scalar * scalar) repr -> (scalar * scalar) repr -> (scalar * scalar) repr t

  (** [constant_bool b] returns a wire holding the constant [b]. *)
  val constant_bool : bool -> bool repr t

  (** [constant_bytes ?le b] returns the bits of [b] (as given by
      [Utils.bitlist ~le]) as constant wires. *)
  val constant_bytes : ?le:bool -> bytes -> Bytes.bl repr t

  (** [constant_uint32 ?le u] serializes [u] in big-endian byte order and
      calls [constant_bytes ?le] on the result. *)
  val constant_uint32 : ?le:bool -> Stdint.uint32 -> Bytes.bl repr t
end

module Lib (C : COMMON) = struct
  include C

  let foldiM : ('a -> int -> 'a t) -> 'a -> int -> 'a t =
   fun f e n -> foldM f e (List.init n (fun i -> i))

  let fold2M f acc ls rs =
    foldM (fun acc (l, r) -> f acc l r) acc (List.combine ls rs)

  let mapM f l =
    (* Accumulate in reverse, then restore the original order. *)
    let* l =
      foldM
        (fun acc e ->
          let* e = f e in
          ret @@ (e :: acc))
        [] l
    in
    ret @@ List.rev l

  let map2M f ls rs = mapM (fun (l, r) -> f l r) (List.combine ls rs)

  let iterM f l = foldM (fun _ a -> f a) unit l

  let iter2M f l r = iterM (fun (l, r) -> f l r) (List.combine l r)

  module Bool = struct
    include Bool

    let full_adder a b c_in =
      let* a_xor_b = xor a b in
      let* a_xor_b_xor_c = xor a_xor_b c_in in
      let* a_xor_b_and_c = band a_xor_b c_in in
      let* a_and_b = band a b in
      let* c = bor a_xor_b_and_c a_and_b in
      ret (pair a_xor_b_xor_c c)
  end

  module Num = struct
    include Num

    let square l = mul l l

    (* Square-and-multiply over the bits of the exponent, least significant
       bit first: [acc] holds x^(2^i) at step i and is multiplied into [res]
       whenever bit i is set. *)
    let pow x n_list =
      let* init =
        let* left = constant_scalar S.one in
        ret (left, x)
      in
      let* res, _acc =
        foldM
          (fun (res, acc) bool ->
            let* res_true = mul res acc in
            let* res = Bool.ifthenelse bool res_true res in
            let* acc = mul acc acc in
            ret (res, acc))
          init n_list
      in
      ret res

    (* Linear combination of [l] with coefficients [coeffs] (all ones by
       default) plus [qc]. Assumes [Num.add ~qc ~ql ~qr x y] computes
       [ql * x + qr * y + qc]; confirm against NUM. *)
    let add_list ?(qc = S.zero) ?(coeffs = []) l =
      let l = of_list l in
      let q =
        match coeffs with
        | [] -> List.init (List.length l) (fun _ -> S.one)
        | _ -> coeffs
      in
      assert (List.compare_lengths q l = 0);
      match (l, q) with
      | x1 :: x2 :: xs, ql :: qr :: qs ->
          (* The constant is folded into the first gate only. *)
          let* res = Num.add ~qc ~ql ~qr x1 x2 in
          fold2M (fun acc x ql -> Num.add ~ql x acc) res xs qs
      | [ x ], [ ql ] -> Num.add_constant ~ql qc x
      | [], [] -> constant_scalar qc
      | _, _ -> assert false

    (* Product of a non-empty list; an empty list is a programming error. *)
    let mul_list l =
      match of_list l with [] -> assert false | x :: xs -> foldM Num.mul x xs

    let mul_by_constant s x = Num.add_constant ~ql:s S.zero x

    (* Evaluates P(X) = \sum_i bᵢ Xⁱ at 2 with Horner's method:
       P(2) = b₀ + 2 (b₁ + 2 (b₂ + 2(…))).
       The list is reversed so that the fold starts from the most
       significant bit. *)
    let scalar_of_bytes b =
      let* zero = constant_scalar S.zero in
      foldM
        (fun acc b -> add acc (scalar_of_bool b) ~ql:S.(one + one) ~qr:S.one)
        zero
        (List.rev (of_list b))

    let assert_eq_const l s = Num.assert_custom ~ql:S.mone ~qc:s l l l

    let is_eq_const l s =
      (* diff = s - l *)
      let* diff = add_constant ~ql:S.mone s l in
      is_zero diff

    (* Function used by [is_upper_bounded], [is_upper_bounded_unsafe] and
       [geq] (through [bits_geq_bound]).
       Pairs each bit of [xbits] with the corresponding bit of the
       [nb_bits]-bit little-endian decomposition of [bound]. It then drops
       the pairs before the first set bit of [bound]. It returns the x-bit at
       that position together with the remaining pairs. *)
    let ignore_leading_zeros ~nb_bits ~bound xbits =
      (* The zeros at the head of the bound's little-endian decomposition
         (i.e. its least significant zero bits) can be ignored: starting from
         [true], [x_i bor true] stays [true], and at the first set bit
         [x_i band true] is just [x_i]. The assertion cannot be satisfied if
         all bits are zero. *)
      let rec shave_zeros = function
        | [] ->
            raise
              (Invalid_argument
                 "is_upper_bounded cannot be satisfied on bound = 0")
        | (x, true) :: tl -> (x, tl)
        | (_, false) :: tl -> shave_zeros tl
      in
      List.combine (of_list xbits) (Utils.bool_list_of_z ~nb_bits bound)
      |> shave_zeros

    (* [bits_geq_bound ~nb_bits ~bound xbits] returns the boolean [x >= bound],
       where [xbits] are the [nb_bits] bits of [x], least significant first.

       Let [(bn,...,b0)] and [(xn,...,x0)] be binary representations
       of [bound] and [x] respectively where the least significant bit is
       indexed by [0]. Let [op_i = if b_i = 1 then band else bor] for all [i].
       Predicate [x] >= [bound] can be expressed as:
       [ xn op_n (... (x1 op_1 (x0 op_0 true))) ].
       Intuitively we carry through a flag that indicates whether, up to
       step i, x[0,i] is greater than or equal to b[0,i] (it starts as
       [true], since empty prefixes are equal). For that to hold at step i:
       - if b_i = one then x_i will need to match it and the flag must be true.
       - if b_i = zero then x_i needs to be one if the flag is false, or it can
         have any value if the flag is already true. *)
    let bits_geq_bound ~nb_bits ~bound xbits =
      let init, xibi = ignore_leading_zeros ~nb_bits ~bound xbits in
      foldM
        (fun acc (xi, bi) ->
          let op = if bi then Bool.band else Bool.bor in
          op xi acc)
        init xibi

    let is_upper_bounded_unsafe ?nb_bits ~bound x =
      assert (Z.leq Z.zero bound && Z.lt bound S.order);
      let nb_bits =
        Option.value
          ~default:(if Z.(equal bound zero) then 1 else Z.numbits bound)
          nb_bits
      in
      let* xbits = bits_of_scalar ~nb_bits x in
      (* x < bound  <=>  not (x >= bound) *)
      let* geq = bits_geq_bound ~nb_bits ~bound xbits in
      Bool.bnot geq

    let is_upper_bounded ~bound x =
      assert (Z.leq Z.zero bound && Z.lt bound S.order);
      let nb_bits = Z.numbits S.order in
      (* Both sides are shifted by [Utils.alpha] before the comparison: the
         decomposition below is of [x] shifted by [Utils.alpha], which is
         compared against [bound + Utils.alpha]. *)
      let bound_plus_alpha = Z.(bound + Utils.alpha) in
      let* xbits = bits_of_scalar ~shift:Utils.alpha ~nb_bits x in
      let* geq = bits_geq_bound ~nb_bits ~bound:bound_plus_alpha xbits in
      Bool.bnot geq

    let geq (a, bound_a) (b, bound_b) =
      (* (a - b) + bound_b - 1 ∈ [0, bound_a + bound_b - 1),
         and a >= b  <=>  (a - b) + bound_b - 1 >= bound_b - 1 *)
      let* shifted_diff =
        Num.add ~qr:S.mone ~qc:(S.of_z Z.(pred bound_b)) a b
      in
      let nb_bits = Z.(numbits @@ pred (add bound_a bound_b)) in
      let* bits = bits_of_scalar ~nb_bits shifted_diff in
      if Z.equal bound_b Z.one then
        (* b ∈ [0, 1) forces b = 0, so a >= b always holds. Comparing
           against [pred bound_b = 0] would make [ignore_leading_zeros]
           raise, so return the constant [true] instead. The decomposition
           above is kept as in the general case. *)
        let* one = constant_scalar S.one in
        ret (unsafe_bool_of_scalar one)
      else bits_geq_bound ~nb_bits ~bound:Z.(pred bound_b) bits
  end

  module Enum (N : sig
    val n : int
  end) =
  struct
    let switch_case k cases =
      let cases = of_list cases in
      assert (List.compare_length_with cases N.n = 0);
      let indexed = List.mapi (fun i x -> (i, x)) cases in
      (* Start from the first case, so that it is the result when [k]
         matches no index. *)
      foldM
        (fun c (i, ci) ->
          let* f = Num.is_eq_const k (S.of_z (Z.of_int i)) in
          Bool.ifthenelse f ci c)
        (snd @@ List.hd indexed)
        (List.tl indexed)
  end

  module Bytes = struct
    type bl = bool list

    let add ?(ignore_carry = false) a b =
      let ha, ta = (List.hd (of_list a), List.tl (of_list a)) in
      let hb, tb = (List.hd (of_list b), List.tl (of_list b)) in
      (* The first position has no incoming carry: use a half adder. *)
      let* a_xor_b = Bool.xor ha hb in
      let* a_and_b = Bool.band ha hb in
      let* res, carry =
        fold2M
          (fun (res, c) a b ->
            let* p = Bool.full_adder a b c in
            let s, c = of_pair p in
            ret (s :: res, c))
          ([ a_xor_b ], a_and_b)
          ta tb
      in
      (* [res] is accumulated in reverse order. *)
      ret @@ to_list @@ List.rev (if ignore_carry then res else carry :: res)

    let xor a b =
      let* l = map2M Bool.xor (of_list a) (of_list b) in
      ret @@ to_list l

    let rotate a i =
      (* [split_n n l] returns the first [n] elements of [l] and the rest. *)
      let split_n n l =
        let rec aux acc k l =
          if k = n then (List.rev acc, l)
          else
            match l with
            | h :: t -> aux (h :: acc) (k + 1) t
            | [] ->
                raise
                  (Invalid_argument
                     (Printf.sprintf "split_n: n=%d >= List.length l=%d" n k))
        in
        aux [] 0 l
      in
      let head, tail = split_n i (of_list a) in
      to_list @@ tail @ head
  end

  let add2 p1 p2 =
    let x1, y1 = of_pair p1 in
    let x2, y2 = of_pair p2 in
    let* x3 = Num.add x1 x2 in
    let* y3 = Num.add y1 y2 in
    ret (pair x3 y3)

  let constant_bool b =
    let* bit = constant_scalar (if b then S.one else S.zero) in
    ret @@ unsafe_bool_of_scalar bit

  let constant_bytes ?(le = false) b =
    let bl = Utils.bitlist ~le b in
    (* One constant wire per bit, in the order given by [Utils.bitlist]. *)
    let* ws = mapM constant_bool bl in
    ret @@ to_list ws

  let constant_uint32 ?(le = false) u32 =
    (* Stdlib.Bytes: the local [Bytes] module shadows the standard one. *)
    let b = Stdlib.Bytes.create 4 in
    Stdint.Uint32.to_bytes_big_endian u32 b 0;
    constant_bytes ~le b
end