repackage minimax

This commit is contained in:
Andrea Caforio
2024-07-04 11:00:20 +02:00
parent 7ecd3beccf
commit 4fa9991b33
3 changed files with 360 additions and 11 deletions

View File

@@ -0,0 +1,258 @@
package minimax
import (
"fmt"
"math"
"math/big"
"github.com/tuneinsight/lattigo/v5/utils"
"github.com/tuneinsight/lattigo/v5/utils/bignum"
)
// MinimaxCompositePolynomial is a struct storing P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x).
type MinimaxCompositePolynomial []bignum.Polynomial
// NewMinimaxCompositePolynomial creates a new MinimaxCompositePolynomial from a list of coefficients.
// Coefficients are expected to be given in the Chebyshev basis.
func NewMinimaxCompositePolynomial(coeffsStr [][]string) MinimaxCompositePolynomial {
polys := make([]bignum.Polynomial, len(coeffsStr))
for i := range coeffsStr {
coeffs := parseCoeffs(coeffsStr[i])
poly := bignum.NewPolynomial(
bignum.Chebyshev,
coeffs,
&bignum.Interval{
A: *bignum.NewFloat(-1, coeffs[0].Prec()),
B: *bignum.NewFloat(1, coeffs[0].Prec()),
},
)
polys[i] = poly
}
return MinimaxCompositePolynomial(polys)
}
func (mcp MinimaxCompositePolynomial) MaxDepth() (depth int) {
for i := range mcp {
depth = utils.Max(depth, mcp[i].Depth())
}
return
}
func (mcp MinimaxCompositePolynomial) Evaluate(x interface{}) (y *bignum.Complex) {
y = mcp[0].Evaluate(x)
for _, p := range mcp[1:] {
y = p.Evaluate(y)
}
return
}
// CoeffsSignX2Cheby (from https://eprint.iacr.org/2019/1234.pdf) are the coefficients
// of 1.5*x - 0.5*x^3 in Chebyshev basis.
// Evaluating this polynomial on values already close to -1, or 1 ~doubles the number of
// of correct digits.
// For example, if x = -0.9993209 then p(x) = -0.999999308
// This polynomial can be composed after the minimax composite polynomial to double the
// output precision (up to the scheme precision) each time it is evaluated.
var CoeffsSignX2Cheby = []string{"0", "1.125", "0", "-0.125"}
// CoeffsSignX4Cheby (from https://eprint.iacr.org/2019/1234.pdf) are the coefficients
// of 35/16 * x - 35/16 * x^3 + 21/16 * x^5 - 5/16 * x^7 in Chebyshev basis.
// Evaluating this polynomial on values already close to -1, or 1 ~quadruples the number of
// of correct digits.
// For example, if x = -0.9993209 then p(x) = -0.9999999999990705
// This polynomial can be composed after the minimax composite polynomial to quadruple the
// output precision (up to the scheme precision) each time it is evaluated.
var CoeffsSignX4Cheby = []string{"0", "1.1962890625", "0", "-0.2392578125", "0", "0.0478515625", "0", "-0.0048828125"}
// GenMinimaxCompositePolynomialForSign generates the minimax composite polynomial
// P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x) of the sign function in their interval
// [min-err, -2^{-alpha}] U [2^{-alpha}, max+err] where alpha is the desired distinguishing
// precision between two values and err an upperbound on the scheme error.
//
// The sign function is defined as: -1 if -1 <= x < 0, 0 if x = 0, 1 if 0 < x <= 1.
//
// See GenMinimaxCompositePolynomial for information about how to instantiate and
// parameterize each input value of the algorithm.
func GenMinimaxCompositePolynomialForSign(prec uint, logalpha, logerr int, deg []int) {
coeffs := GenMinimaxCompositePolynomial(prec, logalpha, logerr, deg, bignum.Sign)
decimals := int(float64(logalpha)/math.Log2(10)+0.5) + 10
fmt.Println("COEFFICIENTS:")
fmt.Printf("{\n")
for i := range coeffs {
PrettyPrintCoefficients(decimals, coeffs[i], true, false, false)
}
fmt.Printf("},\n")
}
// GenMinimaxCompositePolynomial generates the minimax composite polynomial
// P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x) for the provided function in the interval
// in their interval [min-err, -2^{-alpha}] U [2^{-alpha}, max+err] where alpha is
// the desired distinguishing precision between two values and err an upperbound on
// the scheme error.
//
// The user must provide the following inputs:
// - prec: the bit precision of the big.Float values used by the algorithm to compute the polynomials.
// This will impact the speed of the algorithm.
// A too low precision can prevent convergence or induce a slope zero during the zero finding.
// A sign that the precision is too low is when the iteration continue without the error getting smaller.
// - logalpha: log2(alpha)
// - logerr: log2(err), the upperbound on the scheme precision. Usually this value should be smaller or equal to logalpha.
// Correctly setting this value is mandatory for correctness, because if x is outside of the interval
// (i.e. smaller than -1-e or greater than 1+e), then the values will explode during the evaluation.
// Note that it is not required to apply change of interval [-1, 1] -> [-1-e, 1+e] because the function to evaluate
// is the sign (i.e. it will evaluate to the same value).
// - deg: the degree of each polynomial, ordered as follow [deg(p0(x)), deg(p1(x)), ..., deg(pk(x))].
// It is highly recommended that deg(p0) <= deg(p1) <= ... <= deg(pk) for optimal approximation.
//
// The polynomials are returned in the Chebyshev basis and pre-scaled for
// the interval [-1, 1] (no further scaling is required on the ciphertext).
//
// Be aware that finding the minimax polynomials can take a while (in the order of minutes for high precision when using large degree polynomials).
//
// The function will print information about each step of the computation in real time so that it can be monitored.
//
// The underlying algorithm use the multi-interval Remez algorithm of https://eprint.iacr.org/2020/834.pdf.
func GenMinimaxCompositePolynomial(prec uint, logalpha, logerr int, deg []int, f func(*big.Float) *big.Float) (coeffs [][]*big.Float) {
decimals := int(float64(logalpha)/math.Log2(10)+0.5) + 10
// Precision of the output value of the sign polynomial
alpha := math.Exp2(-float64(logalpha))
// Expected upperbound scheme error
e := bignum.NewFloat(math.Exp2(-float64(logerr)), prec)
// Maximum number of iterations
maxIters := 50
// Scan step for finding zeroes of the error function
scanStep := bignum.NewFloat(1e-3, prec)
// Interval [-1, alpha] U [alpha, 1]
intervals := []bignum.Interval{
{A: *bignum.NewFloat(-1, prec), B: *bignum.NewFloat(-alpha, prec), Nodes: 1 + ((deg[0] + 1) >> 1)},
{A: *bignum.NewFloat(alpha, prec), B: *bignum.NewFloat(1, prec), Nodes: 1 + ((deg[0] + 1) >> 1)},
}
// Adds the error to the interval
// [A, -alpha] U [alpha, B] becomes [A-e, -alpha] U [alpha, B+e]
intervals[0].A.Sub(&intervals[0].A, e)
intervals[1].B.Add(&intervals[1].B, e)
// Parameters of the minimax approximation
params := bignum.RemezParameters{
Function: f,
Basis: bignum.Chebyshev,
Intervals: intervals,
ScanStep: scanStep,
Prec: prec,
OptimalScanStep: true,
}
fmt.Printf("P[0]\n")
fmt.Printf("Interval: [%.*f, %.*f] U [%.*f, %.*f]\n", decimals, &intervals[0].A, decimals, &intervals[0].B, decimals, &intervals[1].A, decimals, &intervals[1].B)
r := bignum.NewRemez(params)
r.Approximate(maxIters, alpha)
//r.ShowCoeffs(decimals)
r.ShowError(decimals)
fmt.Println()
coeffs = make([][]*big.Float, len(deg))
for i := 1; i < len(deg); i++ {
// New interval as [-(1+max_err), -(1-min_err)] U [1-min_err, 1+max_err]
maxInterval := bignum.NewFloat(1, prec)
maxInterval.Add(maxInterval, r.MaxErr)
minInterval := bignum.NewFloat(1, prec)
minInterval.Sub(minInterval, r.MinErr)
// Extends the new interval by the scheme error
// [-(1+max_err), -(1-min_err)] U [1-min_err, 1 + max_err] becomes [-(1+max_err+e), -(1-min_err-e)] U [1-min_err-e, 1+max_err+e]
maxInterval.Add(maxInterval, e)
minInterval.Sub(minInterval, e)
intervals = []bignum.Interval{
{A: *new(big.Float).Neg(maxInterval), B: *new(big.Float).Neg(minInterval), Nodes: 1 + ((deg[i] + 1) >> 1)},
{A: *minInterval, B: *maxInterval, Nodes: 1 + ((deg[i] + 1) >> 1)},
}
coeffs[i-1] = make([]*big.Float, deg[i-1]+1)
for j := range coeffs[i-1] {
coeffs[i-1][j] = new(big.Float).Set(r.Coeffs[j])
coeffs[i-1][j].Quo(coeffs[i-1][j], maxInterval) // Interval normalization
}
params := bignum.RemezParameters{
Function: f,
Basis: bignum.Chebyshev,
Intervals: intervals,
ScanStep: scanStep,
Prec: prec,
OptimalScanStep: true,
}
fmt.Printf("P[%d]\n", i)
fmt.Printf("Interval: [%.*f, %.*f] U [%.*f, %.*f]\n", decimals, &intervals[0].A, decimals, &intervals[0].B, decimals, &intervals[1].A, decimals, &intervals[1].B)
r = bignum.NewRemez(params)
r.Approximate(maxIters, alpha)
//r.ShowCoeffs(decimals)
r.ShowError(decimals)
fmt.Println()
}
// Since this is the last polynomial, we can skip the interval scaling.
coeffs[len(deg)-1] = make([]*big.Float, deg[len(deg)-1]+1)
for j := range coeffs[len(deg)-1] {
coeffs[len(deg)-1][j] = new(big.Float).Set(r.Coeffs[j])
}
f64, _ := r.MaxErr.Float64()
fmt.Printf("Output Precision: %f\n", math.Log2(f64))
fmt.Println()
return coeffs
}
// PrettyPrintCoefficients prints the coefficients formatted.
// If odd = true, even coefficients are zeroed.
// If even = true, odd coefficients are zeroed.
func PrettyPrintCoefficients(decimals int, coeffs []*big.Float, odd, even, first bool) {
fmt.Printf("{")
for i, c := range coeffs {
if (i&1 == 1 && odd) || (i&1 == 0 && even) || (i == 0 && first) {
fmt.Printf("\"%.*f\", ", decimals, c)
} else {
fmt.Printf("\"0\", ")
}
}
fmt.Printf("},\n")
}
func parseCoeffs(coeffsStr []string) (coeffs []*big.Float) {
var prec uint
for _, c := range coeffsStr {
prec = utils.Max(prec, uint(len(c)))
}
prec = uint(float64(prec)*3.3219280948873626 + 0.5) // max(float64, digits * log2(10))
coeffs = make([]*big.Float, len(coeffsStr))
for i := range coeffsStr {
coeffs[i], _ = new(big.Float).SetPrec(prec).SetString(coeffsStr[i])
}
return
}

View File

@@ -0,0 +1,90 @@
package minimax
import (
"fmt"
"github.com/tuneinsight/lattigo/v5/circuits/bootstrapping"
"github.com/tuneinsight/lattigo/v5/circuits/polynomial/polyfloat"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
)
// MinimaxCompositePolynomialEvaluator is an evaluator used to evaluate composite polynomials on ciphertexts.
// All fields of this struct are publics, enabling custom instantiations.
type MinimaxCompositePolynomialEvaluator struct {
*ckks.Evaluator
*polyfloat.PolynomialEvaluator
BtsEval *bootstrapping.Evaluator
Parameters ckks.Parameters
}
// NewMinimaxCompositePolynomialEvaluator instantiates a new MinimaxCompositePolynomialEvaluator.
// The default hefloat.Evaluator is compliant to the EvaluatorForMinimaxCompositePolynomial interface.
// This method is allocation free.
func NewMinimaxCompositePolynomialEvaluator(params ckks.Parameters, eval *ckks.Evaluator, btsEval *bootstrapping.Evaluator) *MinimaxCompositePolynomialEvaluator {
return &MinimaxCompositePolynomialEvaluator{eval, polyfloat.NewPolynomialEvaluator(params, eval), btsEval, params}
}
// Evaluate evaluates the provided MinimaxCompositePolynomial on the input ciphertext.
func (eval MinimaxCompositePolynomialEvaluator) Evaluate(ct *rlwe.Ciphertext, mcp MinimaxCompositePolynomial) (res *rlwe.Ciphertext, err error) {
params := eval.Parameters
btp := eval.BtsEval
levelsConsumedPerRescaling := params.LevelsConsumedPerRescaling()
// Checks that the number of levels available after the bootstrapping is enough to evaluate all polynomials
if maxDepth := mcp.MaxDepth() * levelsConsumedPerRescaling; params.MaxLevel() < maxDepth+btp.MinimumInputLevel() {
return nil, fmt.Errorf("parameters do not enable the evaluation of the minimax composite polynomial, required levels is %d but parameters only provide %d levels", maxDepth+btp.MinimumInputLevel(), params.MaxLevel())
}
res = ct.CopyNew()
for _, poly := range mcp {
// Checks that res has enough level to evaluate the next polynomial, else bootstrap
if res.Level() < poly.Depth()*params.LevelsConsumedPerRescaling()+btp.MinimumInputLevel() {
if res, err = btp.Bootstrap(res); err != nil {
return
}
}
// Define the scale that res must have after the polynomial evaluation.
// If we use the regular CKKS (with complex values), we chose a scale to be
// half of the desired scale, so that (x + conj(x)/2) has the correct scale.
var targetScale rlwe.Scale
if params.RingType() == ring.Standard {
targetScale = params.DefaultScale().Div(rlwe.NewScale(2))
} else {
targetScale = params.DefaultScale()
}
// Evaluate the polynomial
if res, err = eval.PolynomialEvaluator.Evaluate(res, poly, targetScale); err != nil {
return nil, fmt.Errorf("evaluate polynomial: %w", err)
}
// Clean the imaginary part (else it tends to explode)
if params.RingType() == ring.Standard {
// Reassigns the scale back to the original one
res.Scale = res.Scale.Mul(rlwe.NewScale(2))
var resConj *rlwe.Ciphertext
if resConj, err = eval.ConjugateNew(res); err != nil {
return
}
if err = eval.Add(res, resConj, res); err != nil {
return
}
}
}
// Avoids float errors
res.Scale = ct.Scale
return
}

View File

@@ -6,21 +6,22 @@ import (
"math"
"math/big"
"github.com/tuneinsight/lattigo/v5/circuits/polynomial/polyfloat"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/he/hefloat"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
"github.com/tuneinsight/lattigo/v5/utils/bignum"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
func main() {
var err error
var params hefloat.Parameters
var params ckks.Parameters
// 128-bit secure parameters enabling depth-7 circuits.
// LogN:14, LogQP: 431.
if params, err = hefloat.NewParametersFromLiteral(
hefloat.ParametersLiteral{
if params, err = ckks.NewParametersFromLiteral(
ckks.ParametersLiteral{
LogN: 14, // log2(ring degree)
LogQ: []int{55, 45, 45, 45, 45, 45, 45, 45}, // log2(primes Q) (ciphertext modulus)
LogP: []int{61}, // log2(primes P) (auxiliary modulus)
@@ -37,7 +38,7 @@ func main() {
sk := kgen.GenSecretKeyNew()
// Encoder
ecd := hefloat.NewEncoder(params)
ecd := ckks.NewEncoder(params)
// Encryptor
enc := rlwe.NewEncryptor(params, sk)
@@ -52,13 +53,13 @@ func main() {
evk := rlwe.NewMemEvaluationKeySet(rlk)
// Evaluator
eval := hefloat.NewEvaluator(params, evk)
eval := ckks.NewEvaluator(params, evk)
// Samples values in [-K, K]
K := 25.0
// Allocates a plaintext at the max level.
pt := hefloat.NewPlaintext(params, params.MaxLevel())
pt := ckks.NewPlaintext(params, params.MaxLevel())
// Vector of plaintext values
values := make([]float64, pt.Slots())
@@ -84,10 +85,10 @@ func main() {
}
// Minimax approximation of the sigmoid in the domain [-K, K] of degree 63.
poly := hefloat.NewPolynomial(GetMinimaxPoly(K, 63, sigmoid))
poly := polyfloat.NewPolynomial(GetMinimaxPoly(K, 63, sigmoid))
// Instantiates the polynomial evaluator
polyEval := hefloat.NewPolynomialEvaluator(params, eval)
polyEval := polyfloat.NewPolynomialEvaluator(params, eval)
// Retrieves the change of basis y = scalar * x + constant
scalar, constant := poly.ChangeOfBasis()
@@ -180,7 +181,7 @@ func GetMinimaxPoly(K float64, degree int, f64 func(x float64) (y float64)) bign
}
// PrintPrecisionStats decrypts, decodes and prints the precision stats of a ciphertext.
func PrintPrecisionStats(params hefloat.Parameters, ct *rlwe.Ciphertext, want []float64, ecd *hefloat.Encoder, dec *rlwe.Decryptor) {
func PrintPrecisionStats(params ckks.Parameters, ct *rlwe.Ciphertext, want []float64, ecd *ckks.Encoder, dec *rlwe.Decryptor) {
var err error
@@ -207,5 +208,5 @@ func PrintPrecisionStats(params hefloat.Parameters, ct *rlwe.Ciphertext, want []
fmt.Printf("...\n")
// Pretty prints the precision stats
fmt.Println(hefloat.GetPrecisionStats(params, ecd, dec, have, want, 0, false).String())
fmt.Println(ckks.GetPrecisionStats(params, ecd, dec, have, want, 0, false).String())
}