From d56f37f35bcc290f2f5e1a56ad9afacc4f071f6a Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 15 Jul 2023 22:52:19 +0200 Subject: [PATCH] [bignum]: easier to user Chebyshev approximation --- ckks/ckks_test.go | 10 +-------- ckks/homomorphic_mod.go | 21 +++++++++---------- examples/ckks/ckks_tutorial/main.go | 17 +++------------- examples/ckks/polyeval/main.go | 21 ++++--------------- utils/bignum/chebyshev_approximation.go | 27 ++++++++++++++++++++++--- utils/bignum/remez_test.go | 10 ++++----- 6 files changed, 47 insertions(+), 59 deletions(-) diff --git a/ckks/ckks_test.go b/ckks/ckks_test.go index 3a032295..b9ff632a 100644 --- a/ckks/ckks_test.go +++ b/ckks/ckks_test.go @@ -935,21 +935,13 @@ func testChebyshevInterpolator(tc *testContext, t *testing.T) { prec := tc.params.PlaintextPrecision() - sin := func(x *bignum.Complex) (y *bignum.Complex) { - xf64, _ := x[0].Float64() - y = bignum.NewComplex() - y.SetPrec(prec) - y[0].SetFloat64(math.Sin(xf64)) - return - } - interval := bignum.Interval{ Nodes: degree, A: *new(big.Float).SetPrec(prec).SetFloat64(-8), B: *new(big.Float).SetPrec(prec).SetFloat64(8), } - poly := rlwe.NewPolynomial(bignum.ChebyshevApproximation(sin, interval)) + poly := rlwe.NewPolynomial(bignum.ChebyshevApproximation(math.Sin, interval)) scalar, constant := poly.ChangeOfBasis() eval.Mul(ciphertext, scalar, ciphertext) diff --git a/ckks/homomorphic_mod.go b/ckks/homomorphic_mod.go index 3be1a444..b91b845b 100644 --- a/ckks/homomorphic_mod.go +++ b/ckks/homomorphic_mod.go @@ -17,19 +17,18 @@ import ( // for the homomorphic modular reduction type SineType uint64 -func sin2pi(x *bignum.Complex) (y *bignum.Complex) { - y = bignum.NewComplex().Set(x) - y[0].Mul(y[0], new(big.Float).SetFloat64(2)) - y[0].Mul(y[0], bignum.Pi(x.Prec())) - y[0] = bignum.Sin(y[0]) - return +func sin2pi(x *big.Float) (y *big.Float) { + y = new(big.Float).Set(x) + y.Mul(y, new(big.Float).SetFloat64(2)) + y.Mul(y, bignum.Pi(x.Prec())) + return bignum.Sin(y) } -func cos2pi(x *bignum.Complex) (y *bignum.Complex) { - y = bignum.NewComplex().Set(x) - y[0].Mul(y[0], new(big.Float).SetFloat64(2)) - y[0].Mul(y[0], bignum.Pi(x.Prec())) - y[0] = bignum.Cos(y[0]) +func cos2pi(x *big.Float) (y *big.Float) { + y = new(big.Float).Set(x) + y.Mul(y, new(big.Float).SetFloat64(2)) + y.Mul(y, bignum.Pi(x.Prec())) + y = bignum.Cos(y) return y } diff --git a/examples/ckks/ckks_tutorial/main.go b/examples/ckks/ckks_tutorial/main.go index 7fd2d22c..598eeddf 100644 --- a/examples/ckks/ckks_tutorial/main.go +++ b/examples/ckks/ckks_tutorial/main.go @@ -531,20 +531,9 @@ func main() { // Let define a function, for example, the SiLU. // The signature needed is `func(x *bignum.Complex) (y *bignum.Complex)` so we must accommodate for it first: - SiLU := func(x *bignum.Complex) (y *bignum.Complex) { - - // Yes sigmoid over the complex! - sigmoid := func(x complex128) (y complex128) { - return 1 / (cmplx.Exp(-x) + 1) - } - - ycmplx128 := x.Complex128() - - ycmplx128 = ycmplx128 * sigmoid(ycmplx128) - - y = bignum.NewComplex().SetPrec(prec).SetComplex128(ycmplx128) - - return + // Yes SiLU over the complex! + SiLU := func(x complex128) (y complex128) { + return x / (cmplx.Exp(-x) + 1) } // We must also give an interval [a, b], for example [-8, 8], in which we approximate SiLU, as well as the degree of approximation. diff --git a/examples/ckks/polyeval/main.go b/examples/ckks/polyeval/main.go index d84b2a35..5ca16237 100644 --- a/examples/ckks/polyeval/main.go +++ b/examples/ckks/polyeval/main.go @@ -98,27 +98,14 @@ func chebyshevinterpolation() { // Evaluation process // We approximate f(x) in the range [-8, 8] with a Chebyshev interpolant of 33 coefficients (degree 32). - approxF := bignum.ChebyshevApproximation(func(x *bignum.Complex) (y *bignum.Complex) { - xf64, _ := x[0].Float64() - y = bignum.NewComplex().SetPrec(53) - y[0].SetFloat64(f(xf64)) - return - }, bignum.Interval{ + interval := bignum.Interval{ Nodes: deg, A: *new(big.Float).SetFloat64(a), B: *new(big.Float).SetFloat64(b), - }) + } - approxG := bignum.ChebyshevApproximation(func(x *bignum.Complex) (y *bignum.Complex) { - xf64, _ := x[0].Float64() - y = bignum.NewComplex().SetPrec(53) - y[0].SetFloat64(g(xf64)) - return - }, bignum.Interval{ - Nodes: deg, - A: *new(big.Float).SetFloat64(a), - B: *new(big.Float).SetFloat64(b), - }) + approxF := bignum.ChebyshevApproximation(f, interval) + approxG := bignum.ChebyshevApproximation(g, interval) // Map storing which polynomial has to be applied to which slot. slotsIndex := make(map[int][]int) diff --git a/utils/bignum/chebyshev_approximation.go b/utils/bignum/chebyshev_approximation.go index 6b49a1b9..86d78a04 100644 --- a/utils/bignum/chebyshev_approximation.go +++ b/utils/bignum/chebyshev_approximation.go @@ -5,13 +5,34 @@ import ( ) // ChebyshevApproximation computes a Chebyshev approximation of the input function, for the range [-a, b] of degree degree. -// function.(type) can be either : +// f.(type) can be either : // - func(Complex128)Complex128 // - func(float64)float64 // - func(*big.Float)*big.Float // - func(*Complex)*Complex // The reference precision is taken from the values stored in the Interval struct. -func ChebyshevApproximation(f func(*Complex) *Complex, interval Interval) (pol Polynomial) { +func ChebyshevApproximation(f interface{}, interval Interval) (pol Polynomial) { + + var fCmplx func(*Complex) *Complex + + switch f := f.(type) { + case func(x complex128) (y complex128): + fCmplx = func(x *Complex) (y *Complex) { + yCmplx := f(x.Complex128()) + return &Complex{new(big.Float).SetFloat64(real(yCmplx)), new(big.Float).SetFloat64(imag(yCmplx))} + } + case func(x float64) (y float64): + fCmplx = func(x *Complex) (y *Complex) { + xf64, _ := x[0].Float64() + return &Complex{new(big.Float).SetFloat64(f(xf64)), new(big.Float)} + } + case func(x *big.Float) (y *big.Float): + fCmplx = func(x *Complex) (y *Complex) { + return &Complex{f(x[0]), new(big.Float)} + } + case func(x *Complex) *Complex: + fCmplx = f + } nodes := chebyshevNodes(interval.Nodes+1, interval) @@ -22,7 +43,7 @@ func ChebyshevApproximation(f func(*Complex) *Complex, interval Interval) (pol P for i := range nodes { x[0].Set(nodes[i]) - fi[i] = f(x) + fi[i] = fCmplx(x) } return NewPolynomial(Chebyshev, chebyCoeffs(nodes, fi, interval), &interval) diff --git a/utils/bignum/remez_test.go b/utils/bignum/remez_test.go index 608bb15b..70e0204b 100644 --- a/utils/bignum/remez_test.go +++ b/utils/bignum/remez_test.go @@ -22,13 +22,13 @@ func TestApproximation(t *testing.T) { t.Run("Chebyshev", func(t *testing.T) { - interval := Interval{A: *NewFloat(-4, prec), B: *NewFloat(4, prec), Nodes: 47} - - f := func(x *Complex) (y *Complex) { - return &Complex{sigmoid(x[0]), new(big.Float)} + interval := Interval{ + Nodes: 47, + A: *NewFloat(-4, prec), + B: *NewFloat(4, prec), } - poly := ChebyshevApproximation(f, interval) + poly := ChebyshevApproximation(sigmoid, interval) xBig := NewFloat(1.4142135623730951, prec)