mirror of
https://github.com/tuneinsight/lattigo.git
synced 2025-09-13 03:27:14 +00:00
repackage minimax
This commit is contained in:
258
circuits/minimax/minimax_composite_polynomial.go
Normal file
258
circuits/minimax/minimax_composite_polynomial.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package minimax
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
|
||||
"github.com/tuneinsight/lattigo/v5/utils"
|
||||
"github.com/tuneinsight/lattigo/v5/utils/bignum"
|
||||
)
|
||||
|
||||
// MinimaxCompositePolynomial is a struct storing P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x).
|
||||
type MinimaxCompositePolynomial []bignum.Polynomial
|
||||
|
||||
// NewMinimaxCompositePolynomial creates a new MinimaxCompositePolynomial from a list of coefficients.
|
||||
// Coefficients are expected to be given in the Chebyshev basis.
|
||||
func NewMinimaxCompositePolynomial(coeffsStr [][]string) MinimaxCompositePolynomial {
|
||||
polys := make([]bignum.Polynomial, len(coeffsStr))
|
||||
|
||||
for i := range coeffsStr {
|
||||
|
||||
coeffs := parseCoeffs(coeffsStr[i])
|
||||
|
||||
poly := bignum.NewPolynomial(
|
||||
bignum.Chebyshev,
|
||||
coeffs,
|
||||
&bignum.Interval{
|
||||
A: *bignum.NewFloat(-1, coeffs[0].Prec()),
|
||||
B: *bignum.NewFloat(1, coeffs[0].Prec()),
|
||||
},
|
||||
)
|
||||
|
||||
polys[i] = poly
|
||||
}
|
||||
|
||||
return MinimaxCompositePolynomial(polys)
|
||||
}
|
||||
|
||||
func (mcp MinimaxCompositePolynomial) MaxDepth() (depth int) {
|
||||
for i := range mcp {
|
||||
depth = utils.Max(depth, mcp[i].Depth())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mcp MinimaxCompositePolynomial) Evaluate(x interface{}) (y *bignum.Complex) {
|
||||
y = mcp[0].Evaluate(x)
|
||||
|
||||
for _, p := range mcp[1:] {
|
||||
y = p.Evaluate(y)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// CoeffsSignX2Cheby (from https://eprint.iacr.org/2019/1234.pdf) are the coefficients
|
||||
// of 1.5*x - 0.5*x^3 in Chebyshev basis.
|
||||
// Evaluating this polynomial on values already close to -1, or 1 ~doubles the number of
|
||||
// of correct digits.
|
||||
// For example, if x = -0.9993209 then p(x) = -0.999999308
|
||||
// This polynomial can be composed after the minimax composite polynomial to double the
|
||||
// output precision (up to the scheme precision) each time it is evaluated.
|
||||
var CoeffsSignX2Cheby = []string{"0", "1.125", "0", "-0.125"}
|
||||
|
||||
// CoeffsSignX4Cheby (from https://eprint.iacr.org/2019/1234.pdf) are the coefficients
|
||||
// of 35/16 * x - 35/16 * x^3 + 21/16 * x^5 - 5/16 * x^7 in Chebyshev basis.
|
||||
// Evaluating this polynomial on values already close to -1, or 1 ~quadruples the number of
|
||||
// of correct digits.
|
||||
// For example, if x = -0.9993209 then p(x) = -0.9999999999990705
|
||||
// This polynomial can be composed after the minimax composite polynomial to quadruple the
|
||||
// output precision (up to the scheme precision) each time it is evaluated.
|
||||
var CoeffsSignX4Cheby = []string{"0", "1.1962890625", "0", "-0.2392578125", "0", "0.0478515625", "0", "-0.0048828125"}
|
||||
|
||||
// GenMinimaxCompositePolynomialForSign generates the minimax composite polynomial
|
||||
// P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x) of the sign function in their interval
|
||||
// [min-err, -2^{-alpha}] U [2^{-alpha}, max+err] where alpha is the desired distinguishing
|
||||
// precision between two values and err an upperbound on the scheme error.
|
||||
//
|
||||
// The sign function is defined as: -1 if -1 <= x < 0, 0 if x = 0, 1 if 0 < x <= 1.
|
||||
//
|
||||
// See GenMinimaxCompositePolynomial for information about how to instantiate and
|
||||
// parameterize each input value of the algorithm.
|
||||
func GenMinimaxCompositePolynomialForSign(prec uint, logalpha, logerr int, deg []int) {
|
||||
|
||||
coeffs := GenMinimaxCompositePolynomial(prec, logalpha, logerr, deg, bignum.Sign)
|
||||
|
||||
decimals := int(float64(logalpha)/math.Log2(10)+0.5) + 10
|
||||
|
||||
fmt.Println("COEFFICIENTS:")
|
||||
fmt.Printf("{\n")
|
||||
for i := range coeffs {
|
||||
PrettyPrintCoefficients(decimals, coeffs[i], true, false, false)
|
||||
}
|
||||
fmt.Printf("},\n")
|
||||
}
|
||||
|
||||
// GenMinimaxCompositePolynomial generates the minimax composite polynomial
|
||||
// P(x) = pk(x) o pk-1(x) o ... o p1(x) o p0(x) for the provided function in the interval
|
||||
// in their interval [min-err, -2^{-alpha}] U [2^{-alpha}, max+err] where alpha is
|
||||
// the desired distinguishing precision between two values and err an upperbound on
|
||||
// the scheme error.
|
||||
//
|
||||
// The user must provide the following inputs:
|
||||
// - prec: the bit precision of the big.Float values used by the algorithm to compute the polynomials.
|
||||
// This will impact the speed of the algorithm.
|
||||
// A too low precision can prevent convergence or induce a slope zero during the zero finding.
|
||||
// A sign that the precision is too low is when the iteration continue without the error getting smaller.
|
||||
// - logalpha: log2(alpha)
|
||||
// - logerr: log2(err), the upperbound on the scheme precision. Usually this value should be smaller or equal to logalpha.
|
||||
// Correctly setting this value is mandatory for correctness, because if x is outside of the interval
|
||||
// (i.e. smaller than -1-e or greater than 1+e), then the values will explode during the evaluation.
|
||||
// Note that it is not required to apply change of interval [-1, 1] -> [-1-e, 1+e] because the function to evaluate
|
||||
// is the sign (i.e. it will evaluate to the same value).
|
||||
// - deg: the degree of each polynomial, ordered as follow [deg(p0(x)), deg(p1(x)), ..., deg(pk(x))].
|
||||
// It is highly recommended that deg(p0) <= deg(p1) <= ... <= deg(pk) for optimal approximation.
|
||||
//
|
||||
// The polynomials are returned in the Chebyshev basis and pre-scaled for
|
||||
// the interval [-1, 1] (no further scaling is required on the ciphertext).
|
||||
//
|
||||
// Be aware that finding the minimax polynomials can take a while (in the order of minutes for high precision when using large degree polynomials).
|
||||
//
|
||||
// The function will print information about each step of the computation in real time so that it can be monitored.
|
||||
//
|
||||
// The underlying algorithm use the multi-interval Remez algorithm of https://eprint.iacr.org/2020/834.pdf.
|
||||
func GenMinimaxCompositePolynomial(prec uint, logalpha, logerr int, deg []int, f func(*big.Float) *big.Float) (coeffs [][]*big.Float) {
|
||||
decimals := int(float64(logalpha)/math.Log2(10)+0.5) + 10
|
||||
|
||||
// Precision of the output value of the sign polynomial
|
||||
alpha := math.Exp2(-float64(logalpha))
|
||||
|
||||
// Expected upperbound scheme error
|
||||
e := bignum.NewFloat(math.Exp2(-float64(logerr)), prec)
|
||||
|
||||
// Maximum number of iterations
|
||||
maxIters := 50
|
||||
|
||||
// Scan step for finding zeroes of the error function
|
||||
scanStep := bignum.NewFloat(1e-3, prec)
|
||||
|
||||
// Interval [-1, alpha] U [alpha, 1]
|
||||
intervals := []bignum.Interval{
|
||||
{A: *bignum.NewFloat(-1, prec), B: *bignum.NewFloat(-alpha, prec), Nodes: 1 + ((deg[0] + 1) >> 1)},
|
||||
{A: *bignum.NewFloat(alpha, prec), B: *bignum.NewFloat(1, prec), Nodes: 1 + ((deg[0] + 1) >> 1)},
|
||||
}
|
||||
|
||||
// Adds the error to the interval
|
||||
// [A, -alpha] U [alpha, B] becomes [A-e, -alpha] U [alpha, B+e]
|
||||
intervals[0].A.Sub(&intervals[0].A, e)
|
||||
intervals[1].B.Add(&intervals[1].B, e)
|
||||
|
||||
// Parameters of the minimax approximation
|
||||
params := bignum.RemezParameters{
|
||||
Function: f,
|
||||
Basis: bignum.Chebyshev,
|
||||
Intervals: intervals,
|
||||
ScanStep: scanStep,
|
||||
Prec: prec,
|
||||
OptimalScanStep: true,
|
||||
}
|
||||
|
||||
fmt.Printf("P[0]\n")
|
||||
fmt.Printf("Interval: [%.*f, %.*f] U [%.*f, %.*f]\n", decimals, &intervals[0].A, decimals, &intervals[0].B, decimals, &intervals[1].A, decimals, &intervals[1].B)
|
||||
r := bignum.NewRemez(params)
|
||||
r.Approximate(maxIters, alpha)
|
||||
//r.ShowCoeffs(decimals)
|
||||
r.ShowError(decimals)
|
||||
fmt.Println()
|
||||
|
||||
coeffs = make([][]*big.Float, len(deg))
|
||||
|
||||
for i := 1; i < len(deg); i++ {
|
||||
|
||||
// New interval as [-(1+max_err), -(1-min_err)] U [1-min_err, 1+max_err]
|
||||
maxInterval := bignum.NewFloat(1, prec)
|
||||
maxInterval.Add(maxInterval, r.MaxErr)
|
||||
|
||||
minInterval := bignum.NewFloat(1, prec)
|
||||
minInterval.Sub(minInterval, r.MinErr)
|
||||
|
||||
// Extends the new interval by the scheme error
|
||||
// [-(1+max_err), -(1-min_err)] U [1-min_err, 1 + max_err] becomes [-(1+max_err+e), -(1-min_err-e)] U [1-min_err-e, 1+max_err+e]
|
||||
maxInterval.Add(maxInterval, e)
|
||||
minInterval.Sub(minInterval, e)
|
||||
|
||||
intervals = []bignum.Interval{
|
||||
{A: *new(big.Float).Neg(maxInterval), B: *new(big.Float).Neg(minInterval), Nodes: 1 + ((deg[i] + 1) >> 1)},
|
||||
{A: *minInterval, B: *maxInterval, Nodes: 1 + ((deg[i] + 1) >> 1)},
|
||||
}
|
||||
|
||||
coeffs[i-1] = make([]*big.Float, deg[i-1]+1)
|
||||
for j := range coeffs[i-1] {
|
||||
coeffs[i-1][j] = new(big.Float).Set(r.Coeffs[j])
|
||||
coeffs[i-1][j].Quo(coeffs[i-1][j], maxInterval) // Interval normalization
|
||||
}
|
||||
|
||||
params := bignum.RemezParameters{
|
||||
Function: f,
|
||||
Basis: bignum.Chebyshev,
|
||||
Intervals: intervals,
|
||||
ScanStep: scanStep,
|
||||
Prec: prec,
|
||||
OptimalScanStep: true,
|
||||
}
|
||||
|
||||
fmt.Printf("P[%d]\n", i)
|
||||
fmt.Printf("Interval: [%.*f, %.*f] U [%.*f, %.*f]\n", decimals, &intervals[0].A, decimals, &intervals[0].B, decimals, &intervals[1].A, decimals, &intervals[1].B)
|
||||
r = bignum.NewRemez(params)
|
||||
r.Approximate(maxIters, alpha)
|
||||
//r.ShowCoeffs(decimals)
|
||||
r.ShowError(decimals)
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// Since this is the last polynomial, we can skip the interval scaling.
|
||||
coeffs[len(deg)-1] = make([]*big.Float, deg[len(deg)-1]+1)
|
||||
for j := range coeffs[len(deg)-1] {
|
||||
coeffs[len(deg)-1][j] = new(big.Float).Set(r.Coeffs[j])
|
||||
}
|
||||
|
||||
f64, _ := r.MaxErr.Float64()
|
||||
fmt.Printf("Output Precision: %f\n", math.Log2(f64))
|
||||
fmt.Println()
|
||||
|
||||
return coeffs
|
||||
}
|
||||
|
||||
// PrettyPrintCoefficients prints the coefficients formatted.
|
||||
// If odd = true, even coefficients are zeroed.
|
||||
// If even = true, odd coefficients are zeroed.
|
||||
func PrettyPrintCoefficients(decimals int, coeffs []*big.Float, odd, even, first bool) {
|
||||
fmt.Printf("{")
|
||||
for i, c := range coeffs {
|
||||
if (i&1 == 1 && odd) || (i&1 == 0 && even) || (i == 0 && first) {
|
||||
fmt.Printf("\"%.*f\", ", decimals, c)
|
||||
} else {
|
||||
fmt.Printf("\"0\", ")
|
||||
}
|
||||
|
||||
}
|
||||
fmt.Printf("},\n")
|
||||
}
|
||||
|
||||
func parseCoeffs(coeffsStr []string) (coeffs []*big.Float) {
|
||||
|
||||
var prec uint
|
||||
for _, c := range coeffsStr {
|
||||
prec = utils.Max(prec, uint(len(c)))
|
||||
}
|
||||
|
||||
prec = uint(float64(prec)*3.3219280948873626 + 0.5) // max(float64, digits * log2(10))
|
||||
|
||||
coeffs = make([]*big.Float, len(coeffsStr))
|
||||
for i := range coeffsStr {
|
||||
coeffs[i], _ = new(big.Float).SetPrec(prec).SetString(coeffsStr[i])
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
90
circuits/minimax/minimax_composite_polynomial_evaluator.go
Normal file
90
circuits/minimax/minimax_composite_polynomial_evaluator.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package minimax
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tuneinsight/lattigo/v5/circuits/bootstrapping"
|
||||
"github.com/tuneinsight/lattigo/v5/circuits/polynomial/polyfloat"
|
||||
"github.com/tuneinsight/lattigo/v5/core/rlwe"
|
||||
"github.com/tuneinsight/lattigo/v5/ring"
|
||||
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
|
||||
)
|
||||
|
||||
// MinimaxCompositePolynomialEvaluator is an evaluator used to evaluate composite polynomials on ciphertexts.
|
||||
// All fields of this struct are publics, enabling custom instantiations.
|
||||
type MinimaxCompositePolynomialEvaluator struct {
|
||||
*ckks.Evaluator
|
||||
*polyfloat.PolynomialEvaluator
|
||||
BtsEval *bootstrapping.Evaluator
|
||||
Parameters ckks.Parameters
|
||||
}
|
||||
|
||||
// NewMinimaxCompositePolynomialEvaluator instantiates a new MinimaxCompositePolynomialEvaluator.
|
||||
// The default hefloat.Evaluator is compliant to the EvaluatorForMinimaxCompositePolynomial interface.
|
||||
// This method is allocation free.
|
||||
func NewMinimaxCompositePolynomialEvaluator(params ckks.Parameters, eval *ckks.Evaluator, btsEval *bootstrapping.Evaluator) *MinimaxCompositePolynomialEvaluator {
|
||||
return &MinimaxCompositePolynomialEvaluator{eval, polyfloat.NewPolynomialEvaluator(params, eval), btsEval, params}
|
||||
}
|
||||
|
||||
// Evaluate evaluates the provided MinimaxCompositePolynomial on the input ciphertext.
|
||||
func (eval MinimaxCompositePolynomialEvaluator) Evaluate(ct *rlwe.Ciphertext, mcp MinimaxCompositePolynomial) (res *rlwe.Ciphertext, err error) {
|
||||
|
||||
params := eval.Parameters
|
||||
|
||||
btp := eval.BtsEval
|
||||
|
||||
levelsConsumedPerRescaling := params.LevelsConsumedPerRescaling()
|
||||
|
||||
// Checks that the number of levels available after the bootstrapping is enough to evaluate all polynomials
|
||||
if maxDepth := mcp.MaxDepth() * levelsConsumedPerRescaling; params.MaxLevel() < maxDepth+btp.MinimumInputLevel() {
|
||||
return nil, fmt.Errorf("parameters do not enable the evaluation of the minimax composite polynomial, required levels is %d but parameters only provide %d levels", maxDepth+btp.MinimumInputLevel(), params.MaxLevel())
|
||||
}
|
||||
|
||||
res = ct.CopyNew()
|
||||
|
||||
for _, poly := range mcp {
|
||||
|
||||
// Checks that res has enough level to evaluate the next polynomial, else bootstrap
|
||||
if res.Level() < poly.Depth()*params.LevelsConsumedPerRescaling()+btp.MinimumInputLevel() {
|
||||
if res, err = btp.Bootstrap(res); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Define the scale that res must have after the polynomial evaluation.
|
||||
// If we use the regular CKKS (with complex values), we chose a scale to be
|
||||
// half of the desired scale, so that (x + conj(x)/2) has the correct scale.
|
||||
var targetScale rlwe.Scale
|
||||
if params.RingType() == ring.Standard {
|
||||
targetScale = params.DefaultScale().Div(rlwe.NewScale(2))
|
||||
} else {
|
||||
targetScale = params.DefaultScale()
|
||||
}
|
||||
|
||||
// Evaluate the polynomial
|
||||
if res, err = eval.PolynomialEvaluator.Evaluate(res, poly, targetScale); err != nil {
|
||||
return nil, fmt.Errorf("evaluate polynomial: %w", err)
|
||||
}
|
||||
|
||||
// Clean the imaginary part (else it tends to explode)
|
||||
if params.RingType() == ring.Standard {
|
||||
|
||||
// Reassigns the scale back to the original one
|
||||
res.Scale = res.Scale.Mul(rlwe.NewScale(2))
|
||||
|
||||
var resConj *rlwe.Ciphertext
|
||||
if resConj, err = eval.ConjugateNew(res); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = eval.Add(res, resConj, res); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Avoids float errors
|
||||
res.Scale = ct.Scale
|
||||
|
||||
return
|
||||
}
|
||||
@@ -6,21 +6,22 @@ import (
|
||||
"math"
|
||||
"math/big"
|
||||
|
||||
"github.com/tuneinsight/lattigo/v5/circuits/polynomial/polyfloat"
|
||||
"github.com/tuneinsight/lattigo/v5/core/rlwe"
|
||||
"github.com/tuneinsight/lattigo/v5/he/hefloat"
|
||||
"github.com/tuneinsight/lattigo/v5/ring"
|
||||
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
|
||||
"github.com/tuneinsight/lattigo/v5/utils/bignum"
|
||||
"github.com/tuneinsight/lattigo/v5/utils/sampling"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
var params hefloat.Parameters
|
||||
var params ckks.Parameters
|
||||
|
||||
// 128-bit secure parameters enabling depth-7 circuits.
|
||||
// LogN:14, LogQP: 431.
|
||||
if params, err = hefloat.NewParametersFromLiteral(
|
||||
hefloat.ParametersLiteral{
|
||||
if params, err = ckks.NewParametersFromLiteral(
|
||||
ckks.ParametersLiteral{
|
||||
LogN: 14, // log2(ring degree)
|
||||
LogQ: []int{55, 45, 45, 45, 45, 45, 45, 45}, // log2(primes Q) (ciphertext modulus)
|
||||
LogP: []int{61}, // log2(primes P) (auxiliary modulus)
|
||||
@@ -37,7 +38,7 @@ func main() {
|
||||
sk := kgen.GenSecretKeyNew()
|
||||
|
||||
// Encoder
|
||||
ecd := hefloat.NewEncoder(params)
|
||||
ecd := ckks.NewEncoder(params)
|
||||
|
||||
// Encryptor
|
||||
enc := rlwe.NewEncryptor(params, sk)
|
||||
@@ -52,13 +53,13 @@ func main() {
|
||||
evk := rlwe.NewMemEvaluationKeySet(rlk)
|
||||
|
||||
// Evaluator
|
||||
eval := hefloat.NewEvaluator(params, evk)
|
||||
eval := ckks.NewEvaluator(params, evk)
|
||||
|
||||
// Samples values in [-K, K]
|
||||
K := 25.0
|
||||
|
||||
// Allocates a plaintext at the max level.
|
||||
pt := hefloat.NewPlaintext(params, params.MaxLevel())
|
||||
pt := ckks.NewPlaintext(params, params.MaxLevel())
|
||||
|
||||
// Vector of plaintext values
|
||||
values := make([]float64, pt.Slots())
|
||||
@@ -84,10 +85,10 @@ func main() {
|
||||
}
|
||||
|
||||
// Minimax approximation of the sigmoid in the domain [-K, K] of degree 63.
|
||||
poly := hefloat.NewPolynomial(GetMinimaxPoly(K, 63, sigmoid))
|
||||
poly := polyfloat.NewPolynomial(GetMinimaxPoly(K, 63, sigmoid))
|
||||
|
||||
// Instantiates the polynomial evaluator
|
||||
polyEval := hefloat.NewPolynomialEvaluator(params, eval)
|
||||
polyEval := polyfloat.NewPolynomialEvaluator(params, eval)
|
||||
|
||||
// Retrieves the change of basis y = scalar * x + constant
|
||||
scalar, constant := poly.ChangeOfBasis()
|
||||
@@ -180,7 +181,7 @@ func GetMinimaxPoly(K float64, degree int, f64 func(x float64) (y float64)) bign
|
||||
}
|
||||
|
||||
// PrintPrecisionStats decrypts, decodes and prints the precision stats of a ciphertext.
|
||||
func PrintPrecisionStats(params hefloat.Parameters, ct *rlwe.Ciphertext, want []float64, ecd *hefloat.Encoder, dec *rlwe.Decryptor) {
|
||||
func PrintPrecisionStats(params ckks.Parameters, ct *rlwe.Ciphertext, want []float64, ecd *ckks.Encoder, dec *rlwe.Decryptor) {
|
||||
|
||||
var err error
|
||||
|
||||
@@ -207,5 +208,5 @@ func PrintPrecisionStats(params hefloat.Parameters, ct *rlwe.Ciphertext, want []
|
||||
fmt.Printf("...\n")
|
||||
|
||||
// Pretty prints the precision stats
|
||||
fmt.Println(hefloat.GetPrecisionStats(params, ecd, dec, have, want, 0, false).String())
|
||||
fmt.Println(ckks.GetPrecisionStats(params, ecd, dec, have, want, 0, false).String())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user