mirror of
https://github.com/tuneinsight/lattigo.git
synced 2025-09-13 03:27:14 +00:00
421 lines
12 KiB
Go
421 lines
12 KiB
Go
package float_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"math"
|
|
"math/big"
|
|
"runtime"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tuneinsight/lattigo/v4/circuits/float"
|
|
"github.com/tuneinsight/lattigo/v4/ckks"
|
|
"github.com/tuneinsight/lattigo/v4/ring"
|
|
"github.com/tuneinsight/lattigo/v4/rlwe"
|
|
"github.com/tuneinsight/lattigo/v4/utils"
|
|
"github.com/tuneinsight/lattigo/v4/utils/bignum"
|
|
"github.com/tuneinsight/lattigo/v4/utils/sampling"
|
|
)
|
|
|
|
var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.")
|
|
var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats")
|
|
|
|
func GetTestName(params ckks.Parameters, opname string) string {
|
|
return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d",
|
|
opname,
|
|
params.RingType(),
|
|
params.LogN(),
|
|
int(math.Round(params.LogQP())),
|
|
params.QCount(),
|
|
params.PCount(),
|
|
int(math.Log2(params.DefaultScale().Float64())))
|
|
}
|
|
|
|
type ckksTestContext struct {
|
|
params ckks.Parameters
|
|
ringQ *ring.Ring
|
|
ringP *ring.Ring
|
|
prng sampling.PRNG
|
|
encoder *ckks.Encoder
|
|
kgen *rlwe.KeyGenerator
|
|
sk *rlwe.SecretKey
|
|
pk *rlwe.PublicKey
|
|
encryptorPk *rlwe.Encryptor
|
|
encryptorSk *rlwe.Encryptor
|
|
decryptor *rlwe.Decryptor
|
|
evaluator *ckks.Evaluator
|
|
}
|
|
|
|
func TestFloat(t *testing.T) {
|
|
|
|
var err error
|
|
|
|
var testParams []ckks.ParametersLiteral
|
|
switch {
|
|
case *flagParamString != "": // the custom test suite reads the parameters from the -params flag
|
|
testParams = append(testParams, ckks.ParametersLiteral{})
|
|
if err = json.Unmarshal([]byte(*flagParamString), &testParams[0]); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
default:
|
|
testParams = testParametersLiteral
|
|
}
|
|
|
|
for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} {
|
|
|
|
for _, paramsLiteral := range testParams {
|
|
|
|
paramsLiteral.RingType = ringType
|
|
|
|
if testing.Short() {
|
|
paramsLiteral.LogN = 10
|
|
}
|
|
|
|
var params ckks.Parameters
|
|
if params, err = ckks.NewParametersFromLiteral(paramsLiteral); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var tc *ckksTestContext
|
|
if tc, err = genCKKSTestParams(params); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
for _, testSet := range []func(tc *ckksTestContext, t *testing.T){
|
|
testCKKSLinearTransformation,
|
|
testEvaluatePolynomial,
|
|
} {
|
|
testSet(tc, t)
|
|
runtime.GC()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func genCKKSTestParams(defaultParam ckks.Parameters) (tc *ckksTestContext, err error) {
|
|
|
|
tc = new(ckksTestContext)
|
|
|
|
tc.params = defaultParam
|
|
|
|
tc.kgen = ckks.NewKeyGenerator(tc.params)
|
|
|
|
tc.sk, tc.pk = tc.kgen.GenKeyPairNew()
|
|
|
|
tc.ringQ = defaultParam.RingQ()
|
|
if tc.params.PCount() != 0 {
|
|
tc.ringP = defaultParam.RingP()
|
|
}
|
|
|
|
if tc.prng, err = sampling.NewPRNG(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tc.encoder = ckks.NewEncoder(tc.params)
|
|
|
|
tc.encryptorPk = ckks.NewEncryptor(tc.params, tc.pk)
|
|
tc.encryptorSk = ckks.NewEncryptor(tc.params, tc.sk)
|
|
tc.decryptor = ckks.NewDecryptor(tc.params, tc.sk)
|
|
tc.evaluator = ckks.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
|
|
|
|
return tc, nil
|
|
|
|
}
|
|
|
|
func newCKKSTestVectors(tc *ckksTestContext, encryptor *rlwe.Encryptor, a, b complex128, t *testing.T) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) {
|
|
|
|
var err error
|
|
|
|
prec := tc.encoder.Prec()
|
|
|
|
pt = ckks.NewPlaintext(tc.params, tc.params.MaxLevel())
|
|
|
|
values = make([]*bignum.Complex, pt.Slots())
|
|
|
|
switch tc.params.RingType() {
|
|
case ring.Standard:
|
|
for i := range values {
|
|
values[i] = &bignum.Complex{
|
|
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
|
|
bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec),
|
|
}
|
|
}
|
|
case ring.ConjugateInvariant:
|
|
for i := range values {
|
|
values[i] = &bignum.Complex{
|
|
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
|
|
new(big.Float),
|
|
}
|
|
}
|
|
default:
|
|
t.Fatal("invalid ring type")
|
|
}
|
|
|
|
tc.encoder.Encode(values, pt)
|
|
|
|
if encryptor != nil {
|
|
ct, err = encryptor.EncryptNew(pt)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
return values, pt, ct
|
|
}
|
|
|
|
func testCKKSLinearTransformation(tc *ckksTestContext, t *testing.T) {
|
|
|
|
params := tc.params
|
|
|
|
t.Run(GetTestName(params, "Average"), func(t *testing.T) {
|
|
|
|
values, _, ciphertext := newCKKSTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
|
|
|
|
slots := ciphertext.Slots()
|
|
|
|
logBatch := 9
|
|
batch := 1 << logBatch
|
|
n := slots / batch
|
|
|
|
eval := tc.evaluator.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(rlwe.GaloisElementsForInnerSum(params, batch, n), tc.sk)...))
|
|
|
|
require.NoError(t, eval.Average(ciphertext, logBatch, ciphertext))
|
|
|
|
tmp0 := make([]*bignum.Complex, len(values))
|
|
for i := range tmp0 {
|
|
tmp0[i] = values[i].Clone()
|
|
}
|
|
|
|
for i := 1; i < n; i++ {
|
|
|
|
tmp1 := utils.RotateSlice(tmp0, i*batch)
|
|
|
|
for j := range values {
|
|
values[j].Add(values[j], tmp1[j])
|
|
}
|
|
}
|
|
|
|
nB := new(big.Float).SetFloat64(float64(n))
|
|
|
|
for i := range values {
|
|
values[i][0].Quo(values[i][0], nB)
|
|
values[i][1].Quo(values[i][1], nB)
|
|
}
|
|
|
|
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
|
})
|
|
|
|
t.Run(GetTestName(params, "LinearTransform/BSGS=True"), func(t *testing.T) {
|
|
|
|
values, _, ciphertext := newCKKSTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
|
|
|
|
slots := ciphertext.Slots()
|
|
|
|
nonZeroDiags := []int{-15, -4, -1, 0, 1, 2, 3, 4, 15}
|
|
|
|
one := new(big.Float).SetInt64(1)
|
|
zero := new(big.Float)
|
|
|
|
diagonals := make(float.Diagonals[*bignum.Complex])
|
|
for _, i := range nonZeroDiags {
|
|
diagonals[i] = make([]*bignum.Complex, slots)
|
|
|
|
for j := 0; j < slots; j++ {
|
|
diagonals[i][j] = &bignum.Complex{one, zero}
|
|
}
|
|
}
|
|
|
|
ltparams := float.LinearTransformationParameters{
|
|
DiagonalsIndexList: nonZeroDiags,
|
|
Level: ciphertext.Level(),
|
|
Scale: rlwe.NewScale(params.Q()[ciphertext.Level()]),
|
|
LogDimensions: ciphertext.LogDimensions,
|
|
LogBabyStepGianStepRatio: 1,
|
|
}
|
|
|
|
// Allocate the linear transformation
|
|
linTransf := float.NewLinearTransformation(params, ltparams)
|
|
|
|
// Encode on the linear transformation
|
|
require.NoError(t, float.EncodeLinearTransformation[*bignum.Complex](tc.encoder, diagonals, linTransf))
|
|
|
|
galEls := float.GaloisElementsForLinearTransformation(params, ltparams)
|
|
|
|
evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...)
|
|
|
|
ltEval := float.NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk))
|
|
|
|
require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext))
|
|
|
|
tmp := make([]*bignum.Complex, len(values))
|
|
for i := range tmp {
|
|
tmp[i] = values[i].Clone()
|
|
}
|
|
|
|
for i := 0; i < slots; i++ {
|
|
values[i].Add(values[i], tmp[(i-15+slots)%slots])
|
|
values[i].Add(values[i], tmp[(i-4+slots)%slots])
|
|
values[i].Add(values[i], tmp[(i-1+slots)%slots])
|
|
values[i].Add(values[i], tmp[(i+1)%slots])
|
|
values[i].Add(values[i], tmp[(i+2)%slots])
|
|
values[i].Add(values[i], tmp[(i+3)%slots])
|
|
values[i].Add(values[i], tmp[(i+4)%slots])
|
|
values[i].Add(values[i], tmp[(i+15)%slots])
|
|
}
|
|
|
|
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
|
})
|
|
|
|
t.Run(GetTestName(params, "LinearTransform/BSGS=False"), func(t *testing.T) {
|
|
|
|
values, _, ciphertext := newCKKSTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
|
|
|
|
slots := ciphertext.Slots()
|
|
|
|
nonZeroDiags := []int{-15, -4, -1, 0, 1, 2, 3, 4, 15}
|
|
|
|
one := new(big.Float).SetInt64(1)
|
|
zero := new(big.Float)
|
|
|
|
diagonals := make(float.Diagonals[*bignum.Complex])
|
|
for _, i := range nonZeroDiags {
|
|
diagonals[i] = make([]*bignum.Complex, slots)
|
|
|
|
for j := 0; j < slots; j++ {
|
|
diagonals[i][j] = &bignum.Complex{one, zero}
|
|
}
|
|
}
|
|
|
|
ltparams := float.LinearTransformationParameters{
|
|
DiagonalsIndexList: nonZeroDiags,
|
|
Level: ciphertext.Level(),
|
|
Scale: rlwe.NewScale(params.Q()[ciphertext.Level()]),
|
|
LogDimensions: ciphertext.LogDimensions,
|
|
LogBabyStepGianStepRatio: -1,
|
|
}
|
|
|
|
// Allocate the linear transformation
|
|
linTransf := float.NewLinearTransformation(params, ltparams)
|
|
|
|
// Encode on the linear transformation
|
|
require.NoError(t, float.EncodeLinearTransformation[*bignum.Complex](tc.encoder, diagonals, linTransf))
|
|
|
|
galEls := float.GaloisElementsForLinearTransformation(params, ltparams)
|
|
|
|
evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...)
|
|
|
|
ltEval := float.NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk))
|
|
|
|
require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext))
|
|
|
|
tmp := make([]*bignum.Complex, len(values))
|
|
for i := range tmp {
|
|
tmp[i] = values[i].Clone()
|
|
}
|
|
|
|
for i := 0; i < slots; i++ {
|
|
values[i].Add(values[i], tmp[(i-15+slots)%slots])
|
|
values[i].Add(values[i], tmp[(i-4+slots)%slots])
|
|
values[i].Add(values[i], tmp[(i-1+slots)%slots])
|
|
values[i].Add(values[i], tmp[(i+1)%slots])
|
|
values[i].Add(values[i], tmp[(i+2)%slots])
|
|
values[i].Add(values[i], tmp[(i+3)%slots])
|
|
values[i].Add(values[i], tmp[(i+4)%slots])
|
|
values[i].Add(values[i], tmp[(i+15)%slots])
|
|
}
|
|
|
|
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
|
})
|
|
}
|
|
|
|
func testEvaluatePolynomial(tc *ckksTestContext, t *testing.T) {
|
|
|
|
params := tc.params
|
|
|
|
var err error
|
|
|
|
polyEval := float.NewPolynomialEvaluator(params, tc.evaluator)
|
|
|
|
t.Run(GetTestName(params, "EvaluatePoly/PolySingle/Exp"), func(t *testing.T) {
|
|
|
|
if params.MaxLevel() < 3 {
|
|
t.Skip("skipping test for params max level < 3")
|
|
}
|
|
|
|
values, _, ciphertext := newCKKSTestVectors(tc, tc.encryptorSk, -1, 1, t)
|
|
|
|
prec := tc.encoder.Prec()
|
|
|
|
coeffs := []*big.Float{
|
|
bignum.NewFloat(1, prec),
|
|
bignum.NewFloat(1, prec),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(2, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(6, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(24, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(120, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(720, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(5040, prec)),
|
|
}
|
|
|
|
poly := bignum.NewPolynomial(bignum.Monomial, coeffs, nil)
|
|
|
|
for i := range values {
|
|
values[i] = poly.Evaluate(values[i])
|
|
}
|
|
|
|
if ciphertext, err = polyEval.Evaluate(ciphertext, poly, ciphertext.Scale); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
|
})
|
|
|
|
t.Run(GetTestName(params, "Polynomial/PolyVector/Exp"), func(t *testing.T) {
|
|
|
|
if params.MaxLevel() < 3 {
|
|
t.Skip("skipping test for params max level < 3")
|
|
}
|
|
|
|
values, _, ciphertext := newCKKSTestVectors(tc, tc.encryptorSk, -1, 1, t)
|
|
|
|
prec := tc.encoder.Prec()
|
|
|
|
coeffs := []*big.Float{
|
|
bignum.NewFloat(1, prec),
|
|
bignum.NewFloat(1, prec),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(2, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(6, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(24, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(120, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(720, prec)),
|
|
new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(5040, prec)),
|
|
}
|
|
|
|
poly := bignum.NewPolynomial(bignum.Monomial, coeffs, nil)
|
|
|
|
slots := ciphertext.Slots()
|
|
|
|
slotIndex := make(map[int][]int)
|
|
idx := make([]int, slots>>1)
|
|
for i := 0; i < slots>>1; i++ {
|
|
idx[i] = 2 * i
|
|
}
|
|
|
|
slotIndex[0] = idx
|
|
|
|
valuesWant := make([]*bignum.Complex, slots)
|
|
for _, j := range idx {
|
|
valuesWant[j] = poly.Evaluate(values[j])
|
|
}
|
|
|
|
polyVector, err := float.NewPolynomialVector([]bignum.Polynomial{poly}, slotIndex)
|
|
require.NoError(t, err)
|
|
|
|
if ciphertext, err = polyEval.Evaluate(ciphertext, polyVector, ciphertext.Scale); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, valuesWant, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
|
})
|
|
}
|