[bignum]: easier to user Chebyshev approximation

This commit is contained in:
Jean-Philippe Bossuat
2023-07-15 22:52:19 +02:00
parent 3591d28763
commit d56f37f35b
6 changed files with 47 additions and 59 deletions

View File

@@ -935,21 +935,13 @@ func testChebyshevInterpolator(tc *testContext, t *testing.T) {
prec := tc.params.PlaintextPrecision()
sin := func(x *bignum.Complex) (y *bignum.Complex) {
xf64, _ := x[0].Float64()
y = bignum.NewComplex()
y.SetPrec(prec)
y[0].SetFloat64(math.Sin(xf64))
return
}
interval := bignum.Interval{
Nodes: degree,
A: *new(big.Float).SetPrec(prec).SetFloat64(-8),
B: *new(big.Float).SetPrec(prec).SetFloat64(8),
}
poly := rlwe.NewPolynomial(bignum.ChebyshevApproximation(sin, interval))
poly := rlwe.NewPolynomial(bignum.ChebyshevApproximation(math.Sin, interval))
scalar, constant := poly.ChangeOfBasis()
eval.Mul(ciphertext, scalar, ciphertext)

View File

@@ -17,19 +17,18 @@ import (
// for the homomorphic modular reduction
type SineType uint64
func sin2pi(x *bignum.Complex) (y *bignum.Complex) {
y = bignum.NewComplex().Set(x)
y[0].Mul(y[0], new(big.Float).SetFloat64(2))
y[0].Mul(y[0], bignum.Pi(x.Prec()))
y[0] = bignum.Sin(y[0])
return
func sin2pi(x *big.Float) (y *big.Float) {
y = new(big.Float).Set(x)
y.Mul(y, new(big.Float).SetFloat64(2))
y.Mul(y, bignum.Pi(x.Prec()))
return bignum.Sin(y)
}
func cos2pi(x *bignum.Complex) (y *bignum.Complex) {
y = bignum.NewComplex().Set(x)
y[0].Mul(y[0], new(big.Float).SetFloat64(2))
y[0].Mul(y[0], bignum.Pi(x.Prec()))
y[0] = bignum.Cos(y[0])
func cos2pi(x *big.Float) (y *big.Float) {
y = new(big.Float).Set(x)
y.Mul(y, new(big.Float).SetFloat64(2))
y.Mul(y, bignum.Pi(x.Prec()))
y = bignum.Cos(y)
return y
}

View File

@@ -531,20 +531,9 @@ func main() {
// Let define a function, for example, the SiLU.
// The signature needed is `func(x *bignum.Complex) (y *bignum.Complex)` so we must accommodate for it first:
SiLU := func(x *bignum.Complex) (y *bignum.Complex) {
// Yes sigmoid over the complex!
sigmoid := func(x complex128) (y complex128) {
return 1 / (cmplx.Exp(-x) + 1)
}
ycmplx128 := x.Complex128()
ycmplx128 = ycmplx128 * sigmoid(ycmplx128)
y = bignum.NewComplex().SetPrec(prec).SetComplex128(ycmplx128)
return
// Yes SiLU over the complex!
SiLU := func(x complex128) (y complex128) {
return x / (cmplx.Exp(-x) + 1)
}
// We must also give an interval [a, b], for example [-8, 8], in which we approximate SiLU, as well as the degree of approximation.

View File

@@ -98,27 +98,14 @@ func chebyshevinterpolation() {
// Evaluation process
// We approximate f(x) in the range [-8, 8] with a Chebyshev interpolant of 33 coefficients (degree 32).
approxF := bignum.ChebyshevApproximation(func(x *bignum.Complex) (y *bignum.Complex) {
xf64, _ := x[0].Float64()
y = bignum.NewComplex().SetPrec(53)
y[0].SetFloat64(f(xf64))
return
}, bignum.Interval{
interval := bignum.Interval{
Nodes: deg,
A: *new(big.Float).SetFloat64(a),
B: *new(big.Float).SetFloat64(b),
})
}
approxG := bignum.ChebyshevApproximation(func(x *bignum.Complex) (y *bignum.Complex) {
xf64, _ := x[0].Float64()
y = bignum.NewComplex().SetPrec(53)
y[0].SetFloat64(g(xf64))
return
}, bignum.Interval{
Nodes: deg,
A: *new(big.Float).SetFloat64(a),
B: *new(big.Float).SetFloat64(b),
})
approxF := bignum.ChebyshevApproximation(f, interval)
approxG := bignum.ChebyshevApproximation(g, interval)
// Map storing which polynomial has to be applied to which slot.
slotsIndex := make(map[int][]int)

View File

@@ -5,13 +5,34 @@ import (
)
// ChebyshevApproximation computes a Chebyshev approximation of the input function, for the range [-a, b] of degree degree.
// function.(type) can be either :
// f.(type) can be either :
// - func(Complex128)Complex128
// - func(float64)float64
// - func(*big.Float)*big.Float
// - func(*Complex)*Complex
// The reference precision is taken from the values stored in the Interval struct.
func ChebyshevApproximation(f func(*Complex) *Complex, interval Interval) (pol Polynomial) {
func ChebyshevApproximation(f interface{}, interval Interval) (pol Polynomial) {
var fCmplx func(*Complex) *Complex
switch f := f.(type) {
case func(x complex128) (y complex128):
fCmplx = func(x *Complex) (y *Complex) {
yCmplx := f(x.Complex128())
return &Complex{new(big.Float).SetFloat64(real(yCmplx)), new(big.Float).SetFloat64(imag(yCmplx))}
}
case func(x float64) (y float64):
fCmplx = func(x *Complex) (y *Complex) {
xf64, _ := x[0].Float64()
return &Complex{new(big.Float).SetFloat64(f(xf64)), new(big.Float)}
}
case func(x *big.Float) (y *big.Float):
fCmplx = func(x *Complex) (y *Complex) {
return &Complex{f(x[0]), new(big.Float)}
}
case func(x *Complex) *Complex:
fCmplx = f
}
nodes := chebyshevNodes(interval.Nodes+1, interval)
@@ -22,7 +43,7 @@ func ChebyshevApproximation(f func(*Complex) *Complex, interval Interval) (pol P
for i := range nodes {
x[0].Set(nodes[i])
fi[i] = f(x)
fi[i] = fCmplx(x)
}
return NewPolynomial(Chebyshev, chebyCoeffs(nodes, fi, interval), &interval)

View File

@@ -22,13 +22,13 @@ func TestApproximation(t *testing.T) {
t.Run("Chebyshev", func(t *testing.T) {
interval := Interval{A: *NewFloat(-4, prec), B: *NewFloat(4, prec), Nodes: 47}
f := func(x *Complex) (y *Complex) {
return &Complex{sigmoid(x[0]), new(big.Float)}
interval := Interval{
Nodes: 47,
A: *NewFloat(-4, prec),
B: *NewFloat(4, prec),
}
poly := ChebyshevApproximation(f, interval)
poly := ChebyshevApproximation(sigmoid, interval)
xBig := NewFloat(1.4142135623730951, prec)