From 4fa9991b33b992c81bc3f6ce2b059e0e5835fc0f Mon Sep 17 00:00:00 2001 From: Andrea Caforio Date: Thu, 4 Jul 2024 11:00:20 +0200 Subject: [PATCH] repackage minimax --- .../minimax/minimax_composite_polynomial.go | 258 ++++++++++++++++++ .../minimax_composite_polynomial_evaluator.go | 90 ++++++ .../reals_sigmoid_minimax/main.go | 23 +- 3 files changed, 360 insertions(+), 11 deletions(-) create mode 100644 circuits/minimax/minimax_composite_polynomial.go create mode 100644 circuits/minimax/minimax_composite_polynomial_evaluator.go diff --git a/circuits/minimax/minimax_composite_polynomial.go b/circuits/minimax/minimax_composite_polynomial.go new file mode 100644 index 00000000..c9eecbd0 --- /dev/null +++ b/circuits/minimax/minimax_composite_polynomial.go @@ -0,0 +1,258 @@ +package minimax + +import ( + "fmt" + "math" + "math/big" + + "github.com/tuneinsight/lattigo/v5/utils" + "github.com/tuneinsight/lattigo/v5/utils/bignum" +) + +// MinimaxCompositePolynomial is a struct storing P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x). +type MinimaxCompositePolynomial []bignum.Polynomial + +// NewMinimaxCompositePolynomial creates a new MinimaxCompositePolynomial from a list of coefficients. +// Coefficients are expected to be given in the Chebyshev basis. +func NewMinimaxCompositePolynomial(coeffsStr [][]string) MinimaxCompositePolynomial { + polys := make([]bignum.Polynomial, len(coeffsStr)) + + for i := range coeffsStr { + + coeffs := parseCoeffs(coeffsStr[i]) + + poly := bignum.NewPolynomial( + bignum.Chebyshev, + coeffs, + &bignum.Interval{ + A: *bignum.NewFloat(-1, coeffs[0].Prec()), + B: *bignum.NewFloat(1, coeffs[0].Prec()), + }, + ) + + polys[i] = poly + } + + return MinimaxCompositePolynomial(polys) +} + +func (mcp MinimaxCompositePolynomial) MaxDepth() (depth int) { + for i := range mcp { + depth = utils.Max(depth, mcp[i].Depth()) + } + return +} + +func (mcp MinimaxCompositePolynomial) Evaluate(x interface{}) (y *bignum.Complex) { + y = mcp[0].Evaluate(x) + + for _, p := range mcp[1:] { + y = p.Evaluate(y) + } + + return +} + +// CoeffsSignX2Cheby (from https://eprint.iacr.org/2019/1234.pdf) are the coefficients +// of 1.5*x - 0.5*x^3 in Chebyshev basis. +// Evaluating this polynomial on values already close to -1, or 1 ~doubles the number of +// of correct digits. +// For example, if x = -0.9993209 then p(x) = -0.999999308 +// This polynomial can be composed after the minimax composite polynomial to double the +// output precision (up to the scheme precision) each time it is evaluated. +var CoeffsSignX2Cheby = []string{"0", "1.125", "0", "-0.125"} + +// CoeffsSignX4Cheby (from https://eprint.iacr.org/2019/1234.pdf) are the coefficients +// of 35/16 * x - 35/16 * x^3 + 21/16 * x^5 - 5/16 * x^7 in Chebyshev basis. +// Evaluating this polynomial on values already close to -1, or 1 ~quadruples the number of +// of correct digits. +// For example, if x = -0.9993209 then p(x) = -0.9999999999990705 +// This polynomial can be composed after the minimax composite polynomial to quadruple the +// output precision (up to the scheme precision) each time it is evaluated. +var CoeffsSignX4Cheby = []string{"0", "1.1962890625", "0", "-0.2392578125", "0", "0.0478515625", "0", "-0.0048828125"} + +// GenMinimaxCompositePolynomialForSign generates the minimax composite polynomial +// P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x) of the sign function in their interval +// [min-err, -2^{-alpha}] U [2^{-alpha}, max+err] where alpha is the desired distinguishing +// precision between two values and err an upperbound on the scheme error. +// +// The sign function is defined as: -1 if -1 <= x < 0, 0 if x = 0, 1 if 0 < x <= 1. +// +// See GenMinimaxCompositePolynomial for information about how to instantiate and +// parameterize each input value of the algorithm. +func GenMinimaxCompositePolynomialForSign(prec uint, logalpha, logerr int, deg []int) { + + coeffs := GenMinimaxCompositePolynomial(prec, logalpha, logerr, deg, bignum.Sign) + + decimals := int(float64(logalpha)/math.Log2(10)+0.5) + 10 + + fmt.Println("COEFFICIENTS:") + fmt.Printf("{\n") + for i := range coeffs { + PrettyPrintCoefficients(decimals, coeffs[i], true, false, false) + } + fmt.Printf("},\n") +} + +// GenMinimaxCompositePolynomial generates the minimax composite polynomial +// P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x) for the provided function in the interval +// in their interval [min-err, -2^{-alpha}] U [2^{-alpha}, max+err] where alpha is +// the desired distinguishing precision between two values and err an upperbound on +// the scheme error. +// +// The user must provide the following inputs: +// - prec: the bit precision of the big.Float values used by the algorithm to compute the polynomials. +// This will impact the speed of the algorithm. +// A too low precision can prevent convergence or induce a slope zero during the zero finding. +// A sign that the precision is too low is when the iteration continue without the error getting smaller. +// - logalpha: log2(alpha) +// - logerr: log2(err), the upperbound on the scheme precision. Usually this value should be smaller or equal to logalpha. +// Correctly setting this value is mandatory for correctness, because if x is outside of the interval +// (i.e. smaller than -1-e or greater than 1+e), then the values will explode during the evaluation. +// Note that it is not required to apply change of interval [-1, 1] -> [-1-e, 1+e] because the function to evaluate +// is the sign (i.e. it will evaluate to the same value). +// - deg: the degree of each polynomial, ordered as follow [deg(p0(x)), deg(p1(x)), ..., deg(pk(x))]. +// It is highly recommended that deg(p0) <= deg(p1) <= ... <= deg(pk) for optimal approximation. +// +// The polynomials are returned in the Chebyshev basis and pre-scaled for +// the interval [-1, 1] (no further scaling is required on the ciphertext). +// +// Be aware that finding the minimax polynomials can take a while (in the order of minutes for high precision when using large degree polynomials). +// +// The function will print information about each step of the computation in real time so that it can be monitored. +// +// The underlying algorithm use the multi-interval Remez algorithm of https://eprint.iacr.org/2020/834.pdf. +func GenMinimaxCompositePolynomial(prec uint, logalpha, logerr int, deg []int, f func(*big.Float) *big.Float) (coeffs [][]*big.Float) { + decimals := int(float64(logalpha)/math.Log2(10)+0.5) + 10 + + // Precision of the output value of the sign polynomial + alpha := math.Exp2(-float64(logalpha)) + + // Expected upperbound scheme error + e := bignum.NewFloat(math.Exp2(-float64(logerr)), prec) + + // Maximum number of iterations + maxIters := 50 + + // Scan step for finding zeroes of the error function + scanStep := bignum.NewFloat(1e-3, prec) + + // Interval [-1, alpha] U [alpha, 1] + intervals := []bignum.Interval{ + {A: *bignum.NewFloat(-1, prec), B: *bignum.NewFloat(-alpha, prec), Nodes: 1 + ((deg[0] + 1) >> 1)}, + {A: *bignum.NewFloat(alpha, prec), B: *bignum.NewFloat(1, prec), Nodes: 1 + ((deg[0] + 1) >> 1)}, + } + + // Adds the error to the interval + // [A, -alpha] U [alpha, B] becomes [A-e, -alpha] U [alpha, B+e] + intervals[0].A.Sub(&intervals[0].A, e) + intervals[1].B.Add(&intervals[1].B, e) + + // Parameters of the minimax approximation + params := bignum.RemezParameters{ + Function: f, + Basis: bignum.Chebyshev, + Intervals: intervals, + ScanStep: scanStep, + Prec: prec, + OptimalScanStep: true, + } + + fmt.Printf("P[0]\n") + fmt.Printf("Interval: [%.*f, %.*f] U [%.*f, %.*f]\n", decimals, &intervals[0].A, decimals, &intervals[0].B, decimals, &intervals[1].A, decimals, &intervals[1].B) + r := bignum.NewRemez(params) + r.Approximate(maxIters, alpha) + //r.ShowCoeffs(decimals) + r.ShowError(decimals) + fmt.Println() + + coeffs = make([][]*big.Float, len(deg)) + + for i := 1; i < len(deg); i++ { + + // New interval as [-(1+max_err), -(1-min_err)] U [1-min_err, 1+max_err] + maxInterval := bignum.NewFloat(1, prec) + maxInterval.Add(maxInterval, r.MaxErr) + + minInterval := bignum.NewFloat(1, prec) + minInterval.Sub(minInterval, r.MinErr) + + // Extends the new interval by the scheme error + // [-(1+max_err), -(1-min_err)] U [1-min_err, 1 + max_err] becomes [-(1+max_err+e), -(1-min_err-e)] U [1-min_err-e, 1+max_err+e] + maxInterval.Add(maxInterval, e) + minInterval.Sub(minInterval, e) + + intervals = []bignum.Interval{ + {A: *new(big.Float).Neg(maxInterval), B: *new(big.Float).Neg(minInterval), Nodes: 1 + ((deg[i] + 1) >> 1)}, + {A: *minInterval, B: *maxInterval, Nodes: 1 + ((deg[i] + 1) >> 1)}, + } + + coeffs[i-1] = make([]*big.Float, deg[i-1]+1) + for j := range coeffs[i-1] { + coeffs[i-1][j] = new(big.Float).Set(r.Coeffs[j]) + coeffs[i-1][j].Quo(coeffs[i-1][j], maxInterval) // Interval normalization + } + + params := bignum.RemezParameters{ + Function: f, + Basis: bignum.Chebyshev, + Intervals: intervals, + ScanStep: scanStep, + Prec: prec, + OptimalScanStep: true, + } + + fmt.Printf("P[%d]\n", i) + fmt.Printf("Interval: [%.*f, %.*f] U [%.*f, %.*f]\n", decimals, &intervals[0].A, decimals, &intervals[0].B, decimals, &intervals[1].A, decimals, &intervals[1].B) + r = bignum.NewRemez(params) + r.Approximate(maxIters, alpha) + //r.ShowCoeffs(decimals) + r.ShowError(decimals) + fmt.Println() + } + + // Since this is the last polynomial, we can skip the interval scaling. + coeffs[len(deg)-1] = make([]*big.Float, deg[len(deg)-1]+1) + for j := range coeffs[len(deg)-1] { + coeffs[len(deg)-1][j] = new(big.Float).Set(r.Coeffs[j]) + } + + f64, _ := r.MaxErr.Float64() + fmt.Printf("Output Precision: %f\n", math.Log2(f64)) + fmt.Println() + + return coeffs +} + +// PrettyPrintCoefficients prints the coefficients formatted. +// If odd = true, even coefficients are zeroed. +// If even = true, odd coefficients are zeroed. +func PrettyPrintCoefficients(decimals int, coeffs []*big.Float, odd, even, first bool) { + fmt.Printf("{") + for i, c := range coeffs { + if (i&1 == 1 && odd) || (i&1 == 0 && even) || (i == 0 && first) { + fmt.Printf("\"%.*f\", ", decimals, c) + } else { + fmt.Printf("\"0\", ") + } + + } + fmt.Printf("},\n") +} + +func parseCoeffs(coeffsStr []string) (coeffs []*big.Float) { + + var prec uint + for _, c := range coeffsStr { + prec = utils.Max(prec, uint(len(c))) + } + + prec = uint(float64(prec)*3.3219280948873626 + 0.5) // max(float64, digits * log2(10)) + + coeffs = make([]*big.Float, len(coeffsStr)) + for i := range coeffsStr { + coeffs[i], _ = new(big.Float).SetPrec(prec).SetString(coeffsStr[i]) + } + + return +} diff --git a/circuits/minimax/minimax_composite_polynomial_evaluator.go b/circuits/minimax/minimax_composite_polynomial_evaluator.go new file mode 100644 index 00000000..de9b3fbd --- /dev/null +++ b/circuits/minimax/minimax_composite_polynomial_evaluator.go @@ -0,0 +1,90 @@ +package minimax + +import ( + "fmt" + + "github.com/tuneinsight/lattigo/v5/circuits/bootstrapping" + "github.com/tuneinsight/lattigo/v5/circuits/polynomial/polyfloat" + "github.com/tuneinsight/lattigo/v5/core/rlwe" + "github.com/tuneinsight/lattigo/v5/ring" + "github.com/tuneinsight/lattigo/v5/schemes/ckks" +) + +// MinimaxCompositePolynomialEvaluator is an evaluator used to evaluate composite polynomials on ciphertexts. +// All fields of this struct are publics, enabling custom instantiations. +type MinimaxCompositePolynomialEvaluator struct { + *ckks.Evaluator + *polyfloat.PolynomialEvaluator + BtsEval *bootstrapping.Evaluator + Parameters ckks.Parameters +} + +// NewMinimaxCompositePolynomialEvaluator instantiates a new MinimaxCompositePolynomialEvaluator. +// The default hefloat.Evaluator is compliant to the EvaluatorForMinimaxCompositePolynomial interface. +// This method is allocation free. +func NewMinimaxCompositePolynomialEvaluator(params ckks.Parameters, eval *ckks.Evaluator, btsEval *bootstrapping.Evaluator) *MinimaxCompositePolynomialEvaluator { + return &MinimaxCompositePolynomialEvaluator{eval, polyfloat.NewPolynomialEvaluator(params, eval), btsEval, params} +} + +// Evaluate evaluates the provided MinimaxCompositePolynomial on the input ciphertext. +func (eval MinimaxCompositePolynomialEvaluator) Evaluate(ct *rlwe.Ciphertext, mcp MinimaxCompositePolynomial) (res *rlwe.Ciphertext, err error) { + + params := eval.Parameters + + btp := eval.BtsEval + + levelsConsumedPerRescaling := params.LevelsConsumedPerRescaling() + + // Checks that the number of levels available after the bootstrapping is enough to evaluate all polynomials + if maxDepth := mcp.MaxDepth() * levelsConsumedPerRescaling; params.MaxLevel() < maxDepth+btp.MinimumInputLevel() { + return nil, fmt.Errorf("parameters do not enable the evaluation of the minimax composite polynomial, required levels is %d but parameters only provide %d levels", maxDepth+btp.MinimumInputLevel(), params.MaxLevel()) + } + + res = ct.CopyNew() + + for _, poly := range mcp { + + // Checks that res has enough level to evaluate the next polynomial, else bootstrap + if res.Level() < poly.Depth()*params.LevelsConsumedPerRescaling()+btp.MinimumInputLevel() { + if res, err = btp.Bootstrap(res); err != nil { + return + } + } + + // Define the scale that res must have after the polynomial evaluation. + // If we use the regular CKKS (with complex values), we chose a scale to be + // half of the desired scale, so that (x + conj(x)/2) has the correct scale. + var targetScale rlwe.Scale + if params.RingType() == ring.Standard { + targetScale = params.DefaultScale().Div(rlwe.NewScale(2)) + } else { + targetScale = params.DefaultScale() + } + + // Evaluate the polynomial + if res, err = eval.PolynomialEvaluator.Evaluate(res, poly, targetScale); err != nil { + return nil, fmt.Errorf("evaluate polynomial: %w", err) + } + + // Clean the imaginary part (else it tends to explode) + if params.RingType() == ring.Standard { + + // Reassigns the scale back to the original one + res.Scale = res.Scale.Mul(rlwe.NewScale(2)) + + var resConj *rlwe.Ciphertext + if resConj, err = eval.ConjugateNew(res); err != nil { + return + } + + if err = eval.Add(res, resConj, res); err != nil { + return + } + } + } + + // Avoids float errors + res.Scale = ct.Scale + + return +} diff --git a/examples/single_party/applications/reals_sigmoid_minimax/main.go b/examples/single_party/applications/reals_sigmoid_minimax/main.go index 89533a93..632eaa44 100644 --- a/examples/single_party/applications/reals_sigmoid_minimax/main.go +++ b/examples/single_party/applications/reals_sigmoid_minimax/main.go @@ -6,21 +6,22 @@ import ( "math" "math/big" + "github.com/tuneinsight/lattigo/v5/circuits/polynomial/polyfloat" "github.com/tuneinsight/lattigo/v5/core/rlwe" - "github.com/tuneinsight/lattigo/v5/he/hefloat" "github.com/tuneinsight/lattigo/v5/ring" + "github.com/tuneinsight/lattigo/v5/schemes/ckks" "github.com/tuneinsight/lattigo/v5/utils/bignum" "github.com/tuneinsight/lattigo/v5/utils/sampling" ) func main() { var err error - var params hefloat.Parameters + var params ckks.Parameters // 128-bit secure parameters enabling depth-7 circuits. // LogN:14, LogQP: 431. - if params, err = hefloat.NewParametersFromLiteral( - hefloat.ParametersLiteral{ + if params, err = ckks.NewParametersFromLiteral( + ckks.ParametersLiteral{ LogN: 14, // log2(ring degree) LogQ: []int{55, 45, 45, 45, 45, 45, 45, 45}, // log2(primes Q) (ciphertext modulus) LogP: []int{61}, // log2(primes P) (auxiliary modulus) @@ -37,7 +38,7 @@ func main() { sk := kgen.GenSecretKeyNew() // Encoder - ecd := hefloat.NewEncoder(params) + ecd := ckks.NewEncoder(params) // Encryptor enc := rlwe.NewEncryptor(params, sk) @@ -52,13 +53,13 @@ func main() { evk := rlwe.NewMemEvaluationKeySet(rlk) // Evaluator - eval := hefloat.NewEvaluator(params, evk) + eval := ckks.NewEvaluator(params, evk) // Samples values in [-K, K] K := 25.0 // Allocates a plaintext at the max level. - pt := hefloat.NewPlaintext(params, params.MaxLevel()) + pt := ckks.NewPlaintext(params, params.MaxLevel()) // Vector of plaintext values values := make([]float64, pt.Slots()) @@ -84,10 +85,10 @@ func main() { } // Minimax approximation of the sigmoid in the domain [-K, K] of degree 63. - poly := hefloat.NewPolynomial(GetMinimaxPoly(K, 63, sigmoid)) + poly := polyfloat.NewPolynomial(GetMinimaxPoly(K, 63, sigmoid)) // Instantiates the polynomial evaluator - polyEval := hefloat.NewPolynomialEvaluator(params, eval) + polyEval := polyfloat.NewPolynomialEvaluator(params, eval) // Retrieves the change of basis y = scalar * x + constant scalar, constant := poly.ChangeOfBasis() @@ -180,7 +181,7 @@ func GetMinimaxPoly(K float64, degree int, f64 func(x float64) (y float64)) bign } // PrintPrecisionStats decrypts, decodes and prints the precision stats of a ciphertext. -func PrintPrecisionStats(params hefloat.Parameters, ct *rlwe.Ciphertext, want []float64, ecd *hefloat.Encoder, dec *rlwe.Decryptor) { +func PrintPrecisionStats(params ckks.Parameters, ct *rlwe.Ciphertext, want []float64, ecd *ckks.Encoder, dec *rlwe.Decryptor) { var err error @@ -207,5 +208,5 @@ func PrintPrecisionStats(params hefloat.Parameters, ct *rlwe.Ciphertext, want [] fmt.Printf("...\n") // Pretty prints the precision stats - fmt.Println(hefloat.GetPrecisionStats(params, ecd, dec, have, want, 0, false).String()) + fmt.Println(ckks.GetPrecisionStats(params, ecd, dec, have, want, 0, false).String()) }