123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475(*****************************************************************************)(* *)(* Open Source License *)(* Copyright (c) 2020 Nomadic Labs, <contact@nomadic-labs.com> *)(* Copyright (c) 2023 DaiLambda, Inc., <contact@dailambda.jp> *)(* *)(* Permission is hereby granted, free of charge, to any person obtaining a *)(* copy of this software and associated documentation files (the "Software"),*)(* to deal in the Software without restriction, including without limitation *)(* the rights to use, copy, modify, merge, publish, distribute, sublicense, *)(* and/or sell copies of the Software, and to permit persons to whom the *)(* Software is furnished to do so, subject to the 
following conditions: *)(* *)(* The above copyright notice and this permission notice shall be included *)(* in all copies or substantial portions of the Software. *)(* *)(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *)(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *)(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING *)(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER *)(* DEALINGS IN THE SOFTWARE. *)(* *)(*****************************************************************************)(* Transform multiplications by constants in a costlang expression to fixed
point arithmetic. This makes cost functions protocol-compatible. *)

(* Rounding strategy applied when a float is cast to an int. *)
type cast_mode = Ceil | Floor | Round

(* Settings for the conversion to fixed point.
   - [precision]: number of leading mantissa bits kept when a multiplication
     by a float constant is turned into bit shifts.
   - [max_relative_error]: relative error above which an integer cast prints
     a warning on stderr.
   - [cast_mode]: rounding strategy.
   - [inverse_scaling], [resolution]: shape of the grid that integer
     constants are snapped to (see [snap_to_grid]). *)
type options = {
  precision : int;
  max_relative_error : float;
  cast_mode : cast_mode;
  inverse_scaling : int;
  resolution : int;
}

(* Reasons for rejecting a floating point constant. *)
type fp_error = Bad_fpclass of Float.fpclass | Negative_or_zero_fp

(* Reasons for code generation to fail. *)
type fixed_point_transform_error = Term_is_not_closed of Free_variable.t

exception Bad_floating_point_number of fp_error

exception Fixed_point_transform_error of fixed_point_transform_error

(* ------------------------------------------------------------------------- *)

let default_options =
  {
    precision = 6;
    max_relative_error = 0.1;
    cast_mode = Round;
    inverse_scaling = 10;
    resolution = 5;
  }

(* ------------------------------------------------------------------------- *)
(* Printers, encodings, etc. *)

(* Pretty-prints a code generation error on [fmtr]. *)
let pp_fixed_point_transform_error fmtr (err : fixed_point_transform_error) =
  let (Term_is_not_closed var) = err in
  Format.fprintf
    fmtr
    "Fixed_point_transform: Term is not closed (free variable %a \
     encountered)"
    Free_variable.pp
    var

(* Each rounding mode is encoded as a tagged string constant. *)
let cast_mode_encoding =
  Data_encoding.(
    union
      [
        case
          ~title:"Ceil"
          (Tag 0)
          (constant "Ceil")
          (function Ceil -> Some () | Floor | Round -> None)
          (fun () -> Ceil);
        case
          ~title:"Floor"
          (Tag 1)
          (constant "Floor")
          (function Floor -> Some () | Ceil | Round -> None)
          (fun () -> Floor);
        case
          ~title:"Round"
          (Tag 2)
          (constant "Round")
          (function Round -> Some () | Ceil | Floor -> None)
          (fun () -> Round);
      ])

(* Every field is optional and falls back to its [default_options] value. *)
let options_encoding =
  let to_tuple
      {precision; max_relative_error; cast_mode; inverse_scaling; resolution} =
    (precision, max_relative_error, cast_mode, inverse_scaling, resolution)
  in
  let of_tuple
      (precision, max_relative_error, cast_mode, inverse_scaling, resolution) =
    {precision; max_relative_error; cast_mode; inverse_scaling; resolution}
  in
  Data_encoding.(
    conv
      to_tuple
      of_tuple
      (obj5
         (dft "precision" int31 default_options.precision)
         (dft "max_relative_error" float default_options.max_relative_error)
         (dft "cast_mode" cast_mode_encoding default_options.cast_mode)
         (dft "inverse_scaling" int31 default_options.inverse_scaling)
         (dft "resolution" int31 default_options.resolution)))

(* ------------------------------------------------------------------------- *)
(* Error registration *)

let () =
  let string_of_fp_error = function
    | Bad_fpclass FP_subnormal -> "FP_subnormal"
    | Bad_fpclass FP_infinite -> "FP_infinite"
    | Bad_fpclass FP_nan -> "FP_nan"
    | Bad_fpclass (FP_normal | FP_zero) ->
        (* These classes are never reported by [assert_fp_is_correct]. *)
        assert false
    | Negative_or_zero_fp -> "<= 0"
  in
  Printexc.register_printer (function
      | Bad_floating_point_number error ->
          let s = string_of_fp_error error in
          Some
            (Printf.sprintf
               "Fixed_point_transform: Bad floating point number: %s"
               s)
      | Fixed_point_transform_error err ->
          Some (Format.asprintf "%a" pp_fixed_point_transform_error err)
      | _ -> None)

(* ------------------------------------------------------------------------- *)
(* Constant prettification *)

(* Integer pseudo-logarithm in base 10: 1 for [1..10], 2 for [11..100], and
   so on. Raises [Invalid_argument] when [x <= 0]. *)
let log10 x =
  if x <= 0 then invalid_arg "log10"
  else
    let rec count digits v =
      if v <= 10 then digits else count (digits + 1) (v / 10)
    in
    count 1 x

(* [pow x n] is [x] to the power [n]. Raises [Invalid_argument] when
   [n < 0]. *)
let pow x n =
  if n < 0 then invalid_arg "pow"
  else
    let rec go acc k = if k = 0 then acc else go (acc * x) (k - 1) in
    go 1 n

(* Rounds [x] up to the next multiple of the grid step, where the step is
   [resolution * 10^(log10 x / inverse_scaling)]. Zero is left untouched. *)
let snap_to_grid ~inverse_scaling ~resolution x =
  if x = 0 then 0
  else
    let dropped_digits = log10 x / inverse_scaling in
    let step = resolution * pow 10 dropped_digits in
    (x + step - 1) / step * step

(* ------------------------------------------------------------------------- *)
(* Helpers
*)

(* Casts [x] to an int after rounding it as requested by [mode]. *)
let int_of_float mode x =
  let rounded =
    match mode with
    | Ceil -> Float.ceil x
    | Floor -> Float.floor x
    | Round -> Float.round x
  in
  int_of_float rounded

(* Checks that a floating point number is 'good': returns its class when it
   is zero or a strictly positive normal number, and raises
   [Bad_floating_point_number] otherwise. *)
let assert_fp_is_correct (x : float) =
  match Float.classify_float x with
  | (FP_subnormal | FP_infinite | FP_nan) as fpcl ->
      raise (Bad_floating_point_number (Bad_fpclass fpcl))
  | FP_normal when x <= 0.0 ->
      raise (Bad_floating_point_number Negative_or_zero_fp)
  | (FP_normal | FP_zero) as fpcl -> fpcl

(* Casts [f] to an int using [mode]. A warning is printed on stderr when the
   relative error of the cast exceeds [max_relative_error]; the cast value is
   returned in all cases. *)
let cast_to_int max_relative_error mode f : int =
  let casted = int_of_float mode f in
  let relative_error =
    Float.abs (f -. Float.of_int casted) /. Float.abs f
  in
  if relative_error > max_relative_error then
    Format.eprintf
      "Warning: Fixed_point_transform: Imprecise integer cast of %f to %d: %f \
       %% relative error@."
      f
      casted
      (relative_error *. 100.) ;
  casted

(* ------------------------------------------------------------------------- *)
(* A minimal language in which to perform fixed-point multiplication of
a 'size repr' by a float *)

module type Fixed_point_lang_sig = sig
  type 'a repr

  type size

  val shift_right : size repr -> int -> size repr

  val ( + ) : size repr -> size repr -> size repr

  val ( * ) : size repr -> size repr -> size repr

  val int : int -> size repr
end

module Fixed_point_arithmetic (Lang : Fixed_point_lang_sig) : sig
  (** [approx_mult mode precision i f] generates a fixed-precision
      multiplication of [i * f], where [f] is a positive constant.
      [precision] is a parameter to control how many bit shifts are used;
      [mode] selects how the discarded mantissa bits are rounded. *)
  val approx_mult :
    cast_mode -> int -> Lang.size Lang.repr -> float -> Lang.size Lang.repr
end = struct
  (* IEEE754 redux
     -------------
     Format of a (full-precision, ie 64 bit) floating point number:
     . 1 bit of sign
     . 11 bits of exponent, implicitly centered on 0 (-1022 to 1023)
     . 52 bits of mantissa (+ 1 implicit)
     value of a fp number = sign * mantissa * 2^{exponent - 1023}
     (with the exceptions of nans, infinities and denormalised numbers) *)

  (* Extract ith bit from a float (bit 0 is the least significant one). *)
  let bit (x : float) (i : int) =
    assert (i >= 0 && i <= 63) ;
    let raw = Int64.bits_of_float x in
    Int64.(logand (shift_right raw i) one)

  (* All bits of a float, most significant first:
     all_bits x = [sign] @ exponent @ mantissa *)
  let all_bits (x : float) : int64 list =
    (* Consing bit 0 first and bit 63 last leaves bit 63 at the head. *)
    let rec collect i acc =
      if i > 63 then acc else collect (i + 1) (bit x i :: acc)
    in
    collect 0 []

  (* [take n l] splits [l] into its [n] first elements and the remainder.
     Fails when [l] has fewer than [n] elements. *)
  let take n l =
    let rec go remaining acc rest =
      if remaining <= 0 then (List.rev acc, rest)
      else
        match rest with
        | [] -> Stdlib.failwith "take"
        | hd :: tl -> go (remaining - 1) (hd :: acc) tl
    in
    go n [] l

  (* Split the bits of a float into sign/exponent/mantissa *)
  let split bits =
    let sign, after_sign = take 1 bits in
    let expo, after_expo = take 11 after_sign in
    let mant, _ = take 52 after_expo in
    (sign, expo, mant)

  (* Convert the bits of the exponent (most significant first) to an
     integer, removing the bias of 1023. *)
  let exponent_bits_to_int (l : int64 list) =
    let biased =
      List.fold_left (fun acc b -> Int64.(add (shift_left acc 1) b)) 0L l
    in
    Int64.sub biased 1023L

  (* Decompose a float into sign/exponent/mantissa *)
  let decompose (x : float) = split (all_bits x)

  (* Adds one to the number written by [bits] (most significant first).
     When the carry overflows, the exponent is bumped and a leading one is
     prepended, so the result may be one bit longer than the input. *)
  let increment_bits exp bits =
    let carry, incremented =
      List.fold_left
        (fun (carry, acc) b ->
          match (b, carry) with
          | 0L, false -> (false, 0L :: acc)
          | 0L, true -> (false, 1L :: acc)
          | 1L, false -> (false, 1L :: acc)
          | 1L, true -> (true, 0L :: acc)
          | _ -> assert false)
        (true, [])
        (List.rev bits)
    in
    if carry then (exp + 1, 1L :: incremented) else (exp, incremented)

  (* Generate fixed-precision multiplication by positive constants. *)
  let approx_mult mode (precision : int) (i : Lang.size Lang.repr) (x : float)
      : Lang.size Lang.repr =
    assert (precision > 0) ;
    match assert_fp_is_correct x with
    | FP_zero -> Lang.int 0
    | _ ->
        let _sign, expo_bits, mant = decompose x in
        let exp = Int64.to_int (exponent_bits_to_int expo_bits) in
        (* The mantissa is always implicitly prefixed by one (except for
           denormalized numbers, excluded here). Keep only its top
           [precision] bits. *)
        let kept, dropped = take precision (1L :: mant) in
        (* Rounding. [bits_rounded] has [precision+1] bits at most.
           The number of ones in it is at most [precision] *)
        let exp, bits_rounded =
          match mode with
          | Floor -> (exp, kept)
          | Ceil ->
              if List.for_all (fun b -> b = 0L) dropped then (exp, kept)
              else increment_bits exp kept
          | Round -> (
              match dropped with
              | [] | 0L :: _ -> (exp, kept)
              | 1L :: _ -> increment_bits exp kept
              | _ -> assert false)
        in
        (* The bit at position [k] weighs 2^(exp - k): non-negative powers are
           summed into one integer factor, negative ones become right
           shifts of [i]. *)
        let step (k, integer, fracs) b =
          let integer, fracs =
            if b = 1L then
              if exp - k < 0 then
                (integer, Lang.shift_right i (k - exp) :: fracs)
              else (integer + (1 lsl (exp - k)), fracs)
            else (integer, fracs)
          in
          (k + 1, integer, fracs)
        in
        let _, integer, fracs = List.fold_left step (0, 0, []) bits_rounded in
        let fracs = List.rev fracs in
        if integer = 0 then
          match fracs with
          | [] -> assert false
          | first :: others ->
              List.fold_left (fun sum t -> Lang.(sum + t)) first others
        else
          let whole = if integer = 1 then i else Lang.(i * int integer) in
          (* Each shifted term is placed on the left of the running sum. *)
          List.fold_left (fun acc frac -> Lang.(frac + acc)) whole fracs
end

(* [Convert_mult] approximates [float] values to integers:
- The multiplications of the form [float * term] or [term * float]
are rewritten to integer-only expressions, rounded according to the
[cast_mode] found in the options.
- [float] constants are converted to their nearest grid point.
It is assumed that the term is _closed_, i.e. contains no free variables.
*)
module Convert_mult (P : sig
  val options : options
end)
(X : Costlang.S) : sig
  include Costlang.S with type size = X.size

  val prj : 'a repr -> 'a X.repr
end = struct
  type size = X.size

  let size_ty = X.size_ty

  (* A value is either a term of the target language or a float constant
     whose conversion to an integer is still pending. *)
  type 'a repr = Term : 'a X.repr -> 'a repr | Float : float -> X.size repr

  let {precision; max_relative_error; cast_mode; inverse_scaling; resolution} =
    P.options

  module FPA = Fixed_point_arithmetic (X)

  (* Cast to int then snap to the nearest grid point *)
  let cast_and_snap f =
    X.int
    @@ snap_to_grid ~inverse_scaling ~resolution
    @@ cast_to_int max_relative_error cast_mode f

  (* Any float left is converted to the nearest grid point *)
  let prj (type a) (term : a repr) : a X.repr =
    match term with Term t -> t | Float f -> cast_and_snap f

  (* By default, any float is converted to the nearest grid point *)
  let lift_unop op x =
    match x with
    | Term x -> Term (op x)
    | Float x -> Term (op @@ cast_and_snap x)

  (* By default, any float is converted to the nearest grid point *)
  let lift_binop op x y =
    match (x, y) with
    | Term x, Term y -> Term (op x y)
    | Term x, Float y -> Term (op x (cast_and_snap y))
    | Float x, Term y -> Term (op (cast_and_snap x) y)
    | Float x, Float y -> Term (op (cast_and_snap x) (cast_and_snap y))

  (* Produces the names v0, v1, ... used for the let-bindings introduced by
     the multiplication below. *)
  let gensym : unit -> string =
    let counter = ref 0 in
    fun () ->
      let v = !counter in
      incr counter ;
      "v" ^ string_of_int v

  let false_ = Term X.false_

  let true_ = Term X.true_

  (* Floats are kept symbolic until they meet a term. *)
  let float f = Float f

  (* Integers are kept as they are *)
  let int i = Term (X.int i)

  let ( + ) = lift_binop X.( + )

  let sat_sub = lift_binop X.sat_sub

  let ( * ) x y =
    match (x, y) with
    | Term x, Term y -> Term X.(x * y)
    | Term x, Float y | Float y, Term x ->
        (* let-bind the non-constant term [x] to avoid copying it.
           The rounding mode comes from the options, consistently with
           [cast_and_snap], instead of being forced to [Ceil]. *)
        Term
          (X.let_ ~name:(gensym ()) x (fun x ->
               FPA.approx_mult cast_mode precision x y))
    | Float x, Float y ->
        (* Constant folding: the product stays a float constant. *)
        Float (x *. y)

  let ( / ) = lift_binop X.( / )

  let max = lift_binop X.max

  let min = lift_binop X.min

  let shift_left x i = lift_unop (fun x -> X.shift_left x i) x

  let shift_right x i = lift_unop (fun x -> X.shift_right x i) x

  let log2 = lift_unop X.log2

  let sqrt = lift_unop X.sqrt

  (* Free variables are rejected: the term must be closed.
     Raises [Fixed_point_transform_error]. *)
  let free ~name = raise (Fixed_point_transform_error (Term_is_not_closed name))

  let lt = lift_binop X.lt

  let eq = lift_binop X.eq

  (* The bound variable is always a term; a float body is snapped by
     [prj]. *)
  let lam' (type a b) ~name (ty : a Costlang.Ty.t) (f : a repr -> b repr) :
      (a -> b) repr =
    Term (X.lam' ~name ty (fun x -> prj (f (Term x))))

  let lam ~name = lam' ~name size_ty

  let app (type a b) (fn : (a -> b) repr) (arg : a repr) : b repr =
    match (fn, arg) with
    | Term fn, Term arg -> Term (X.app fn arg)
    | Term fn, Float f -> Term (X.app fn (cast_and_snap f))
    | Float _, _ ->
        (* A float constant cannot have a function type. *)
        assert false

  (* A float bound by [let_] is snapped before being bound; a float body is
     snapped by [prj]. *)
  let let_ (type a b) ~name (m : a repr) (fn : a repr -> b repr) : b repr =
    match m with
    | Term m -> Term (X.let_ ~name m (fun x -> prj (fn (Term x))))
    | Float f ->
        Term (X.let_ ~name (cast_and_snap f) (fun x -> prj (fn (Term x))))

  let if_ cond ift iff = Term (X.if_ (prj cond) (prj ift) (prj iff))
end

(* Packages [Convert_mult] as a [Costlang.Transform] for the given options. *)
module Apply (P : sig
  val options : options
end) : Costlang.Transform =
functor (X : Costlang.S) -> Convert_mult (P) (X)