refactor test utilities

Add `BFV`, `BGV` and `CKKS` test utilities that can be reused
throughout the entirety of Lattigo. This reduces code duplication and
should simplify the creation of future unit tests.
This commit is contained in:
Andrea Caforio
2024-07-03 16:58:05 +02:00
parent 8378a3b4ae
commit a315439b5b
18 changed files with 1308 additions and 1802 deletions

View File

@@ -14,7 +14,6 @@ import (
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/schemes"
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
"github.com/tuneinsight/lattigo/v5/utils"
"github.com/tuneinsight/lattigo/v5/utils/bignum"
@@ -24,34 +23,18 @@ import (
var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.")
var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats")
func GetTestName(params ckks.Parameters, opname string) string {
func name(opname string, tc *ckks.TestContext) string {
return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d",
opname,
params.RingType(),
params.LogN(),
int(math.Round(params.LogQP())),
params.QCount(),
params.PCount(),
int(math.Log2(params.DefaultScale().Float64())))
}
type testContext struct {
params ckks.Parameters
ringQ *ring.Ring
ringP *ring.Ring
prng sampling.PRNG
encoder ckks.Encoder
kgen *rlwe.KeyGenerator
sk *rlwe.SecretKey
pk *rlwe.PublicKey
encryptorPk *rlwe.Encryptor
encryptorSk *rlwe.Encryptor
decryptor *rlwe.Decryptor
evaluator ckks.Evaluator
tc.Params.RingType(),
tc.Params.LogN(),
int(math.Round(tc.Params.LogQP())),
tc.Params.QCount(),
tc.Params.PCount(),
int(math.Log2(tc.Params.DefaultScale().Float64())))
}
func TestPolynomialEvaluator(t *testing.T) {
var err error
var testParams []ckks.ParametersLiteral
@@ -62,7 +45,7 @@ func TestPolynomialEvaluator(t *testing.T) {
t.Fatal(err)
}
default:
testParams = schemes.CkksTestParametersLiteral
testParams = ckks.TestParametersLiteral
}
for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} {
@@ -75,17 +58,9 @@ func TestPolynomialEvaluator(t *testing.T) {
paramsLiteral.LogN = 10
}
var params ckks.Parameters
if params, err = ckks.NewParametersFromLiteral(paramsLiteral); err != nil {
t.Fatal(err)
}
tc := ckks.NewTestContext(paramsLiteral)
var tc *testContext
if tc, err = genTestParams(params); err != nil {
t.Fatal(err)
}
for _, testSet := range []func(tc *testContext, t *testing.T){
for _, testSet := range []func(tc *ckks.TestContext, t *testing.T){
run,
} {
testSet(tc, t)
@@ -95,9 +70,8 @@ func TestPolynomialEvaluator(t *testing.T) {
}
}
func run(tc *testContext, t *testing.T) {
params := tc.params
func run(tc *ckks.TestContext, t *testing.T) {
params := tc.Params
mulCmplx := bignum.NewComplexMultiplier().Mul
@@ -119,7 +93,7 @@ func run(tc *testContext, t *testing.T) {
}
}
prec := tc.encoder.Prec()
prec := tc.Ecd.Prec()
newVec := func(size int) (vec []*bignum.Complex) {
vec = make([]*bignum.Complex, size)
@@ -129,9 +103,9 @@ func run(tc *testContext, t *testing.T) {
return
}
t.Run(GetTestName(params, "Average"), func(t *testing.T) {
t.Run(name("Average", tc), func(t *testing.T) {
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
slots := ciphertext.Slots()
@@ -139,7 +113,7 @@ func run(tc *testContext, t *testing.T) {
batch := 1 << logBatch
n := slots / batch
eval := tc.evaluator.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(rlwe.GaloisElementsForInnerSum(params, batch, n), tc.sk)...))
eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(rlwe.GaloisElementsForInnerSum(params, batch, n), tc.Sk)...))
require.NoError(t, eval.Average(ciphertext, logBatch, ciphertext))
@@ -164,12 +138,12 @@ func run(tc *testContext, t *testing.T) {
values[i][1].Quo(values[i][1], nB)
}
ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(params, "LinearTransform/BSGS=True"), func(t *testing.T) {
t.Run(name("LinearTransform/BSGS=True", tc), func(t *testing.T) {
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
slots := ciphertext.Slots()
@@ -199,24 +173,24 @@ func run(tc *testContext, t *testing.T) {
linTransf := NewLinearTransformation(params, ltparams)
// Encode on the linear transformation
require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf))
require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf))
galEls := linTransf.GaloisElements(params)
evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...)
evk := rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...)
ltEval := NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk))
ltEval := NewLinearTransformationEvaluator(tc.Evl.WithKey(evk))
require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext))
values = diagonals.Evaluate(values, newVec, add, muladd)
ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(params, "LinearTransform/BSGS=False"), func(t *testing.T) {
t.Run(name("LinearTransform/BSGS=False", tc), func(t *testing.T) {
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
slots := ciphertext.Slots()
@@ -247,22 +221,22 @@ func run(tc *testContext, t *testing.T) {
linTransf := NewLinearTransformation(params, ltparams)
// Encode on the linear transformation
require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf))
require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf))
galEls := linTransf.GaloisElements(params)
evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...)
evk := rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...)
ltEval := NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk))
ltEval := NewLinearTransformationEvaluator(tc.Evl.WithKey(evk))
require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext))
values = diagonals.Evaluate(values, newVec, add, muladd)
ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(params, "LinearTransform/Permutation"), func(t *testing.T) {
t.Run(name("LinearTransform/Permutation", tc), func(t *testing.T) {
idx := make([]int, params.MaxSlots())
for i := range idx {
idx[i] = i
@@ -289,7 +263,7 @@ func run(tc *testContext, t *testing.T) {
diagonals := Permutation[*bignum.Complex](permutation).GetDiagonals(params.LogMaxSlots())
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
ltparams := LinearTransformationParameters{
DiagonalsIndexList: diagonals.DiagonalsIndexList(),
@@ -304,87 +278,18 @@ func run(tc *testContext, t *testing.T) {
linTransf := NewLinearTransformation(params, ltparams)
// Encode on the linear transformation
require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf))
require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf))
galEls := linTransf.GaloisElements(params)
evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...)
evk := rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...)
ltEval := NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk))
ltEval := NewLinearTransformationEvaluator(tc.Evl.WithKey(evk))
require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext))
values = diagonals.Evaluate(values, newVec, add, muladd)
ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
}
func genTestParams(defaultParam ckks.Parameters) (tc *testContext, err error) {
tc = new(testContext)
tc.params = defaultParam
tc.kgen = rlwe.NewKeyGenerator(tc.params)
tc.sk, tc.pk = tc.kgen.GenKeyPairNew()
tc.ringQ = defaultParam.RingQ()
if tc.params.PCount() != 0 {
tc.ringP = defaultParam.RingP()
}
if tc.prng, err = sampling.NewPRNG(); err != nil {
return nil, err
}
tc.encoder = *ckks.NewEncoder(tc.params)
tc.encryptorPk = rlwe.NewEncryptor(tc.params, tc.pk)
tc.encryptorSk = rlwe.NewEncryptor(tc.params, tc.sk)
tc.decryptor = rlwe.NewDecryptor(tc.params, tc.sk)
tc.evaluator = *ckks.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
return tc, nil
}
func newTestVectors(tc *testContext, encryptor *rlwe.Encryptor, a, b complex128, t *testing.T) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) {
var err error
prec := tc.params.EncodingPrecision()
pt = ckks.NewPlaintext(tc.params, tc.params.MaxLevel())
values = make([]*bignum.Complex, pt.Slots())
switch tc.params.RingType() {
case ring.Standard:
for i := range values {
values[i] = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec),
}
}
case ring.ConjugateInvariant:
for i := range values {
values[i] = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
new(big.Float),
}
}
default:
t.Fatal("invalid ring type")
}
tc.encoder.Encode(values, pt)
if encryptor != nil {
ct, err = encryptor.EncryptNew(pt)
require.NoError(t, err)
}
return values, pt, ct
}

View File

@@ -3,86 +3,23 @@ package ltint
import (
"encoding/json"
"flag"
"fmt"
"math"
"math/rand"
"runtime"
"slices"
"testing"
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/schemes"
"github.com/tuneinsight/lattigo/v5/schemes/bgv"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
type testContext struct {
params bgv.Parameters
ringQ *ring.Ring
ringT *ring.Ring
prng sampling.PRNG
uSampler *ring.UniformSampler
encoder schemes.Encoder
kgen *rlwe.KeyGenerator
sk *rlwe.SecretKey
pk *rlwe.PublicKey
encryptorPk *rlwe.Encryptor
encryptorSk *rlwe.Encryptor
decryptor *rlwe.Decryptor
evaluator *bgv.Evaluator
testLevel []int
}
func genTestParams(params bgv.Parameters) (tc *testContext, err error) {
tc = new(testContext)
tc.params = params
if tc.prng, err = sampling.NewPRNG(); err != nil {
return nil, err
}
tc.ringQ = params.RingQ()
tc.ringT = params.RingT()
tc.uSampler = ring.NewUniformSampler(tc.prng, tc.ringT)
tc.kgen = rlwe.NewKeyGenerator(tc.params)
tc.sk, tc.pk = tc.kgen.GenKeyPairNew()
tc.encoder = bgv.NewEncoder(tc.params)
tc.encryptorPk = rlwe.NewEncryptor(tc.params, tc.pk)
tc.encryptorSk = rlwe.NewEncryptor(tc.params, tc.sk)
tc.decryptor = rlwe.NewDecryptor(tc.params, tc.sk)
tc.evaluator = bgv.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
tc.testLevel = []int{0, params.MaxLevel()}
return
}
var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise")
var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.")
func GetTestName(opname string, p bgv.Parameters, lvl int) string {
return fmt.Sprintf("%s/LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d/lvl=%d",
opname,
p.LogN(),
int(math.Round(p.LogQ())),
int(math.Round(p.LogP())),
p.LogMaxDimensions().Rows,
p.LogMaxDimensions().Cols,
int(math.Round(p.LogT())),
p.QCount(),
p.PCount(),
lvl)
}
func TestPolynomialEvaluator(t *testing.T) {
func TestLinearTransformation(t *testing.T) {
var err error
paramsLiterals := schemes.BgvTestParams
paramsLiterals := bgv.TestParams
if *flagParamString != "" {
var jsonParams bgv.ParametersLiteral
@@ -94,23 +31,12 @@ func TestPolynomialEvaluator(t *testing.T) {
for _, p := range paramsLiterals[:] {
for _, plaintextModulus := range schemes.BgvTestPlaintextModulus[:] {
for _, plaintextModulus := range bgv.TestPlaintextModulus[:] {
p.PlaintextModulus = plaintextModulus
var params bgv.Parameters
if params, err = bgv.NewParametersFromLiteral(p); err != nil {
t.Error(err)
t.Fail()
}
tc := bgv.NewTestContext(p)
var tc *testContext
if tc, err = genTestParams(params); err != nil {
t.Error(err)
t.Fail()
}
for _, testSet := range []func(tc *testContext, t *testing.T){
for _, testSet := range []func(tc *bgv.TestContext, t *testing.T){
run,
} {
testSet(tc, t)
@@ -120,30 +46,26 @@ func TestPolynomialEvaluator(t *testing.T) {
}
}
func run(tc *testContext, t *testing.T) {
rT := tc.params.RingT().SubRings[0]
func run(tc *bgv.TestContext, t *testing.T) {
rT := tc.Params.RingT().SubRings[0]
add := func(a, b, c []uint64) {
rT.Add(a, b, c)
}
muladd := func(a, b, c []uint64) {
rT.MulCoeffsBarrettThenAdd(a, b, c)
}
newVec := func(size int) (vec []uint64) {
return make([]uint64, size)
}
params := tc.params
params := tc.Params
T := params.PlaintextModulus()
level := tc.params.MaxLevel()
t.Run("Evaluator/LinearTransformationBSGS=true/"+tc.String(), func(t *testing.T) {
t.Run(GetTestName("Evaluator/LinearTransformationBSGS=true", params, level), func(t *testing.T) {
values, _, ciphertext := newTestVectorsLvl(level, params.DefaultScale(), tc, tc.encryptorSk)
values, _, ciphertext := bgv.NewTestVector(params, tc.Ecd, tc.Enc, params.MaxLevel(), params.DefaultScale())
slots := ciphertext.Slots()
@@ -161,7 +83,7 @@ func run(tc *testContext, t *testing.T) {
DiagonalsIndexList: diagonals.DiagonalsIndexList(),
LevelQ: ciphertext.Level(),
LevelP: params.MaxLevelP(),
Scale: tc.params.DefaultScale(),
Scale: tc.Params.DefaultScale(),
LogDimensions: ciphertext.LogDimensions,
LogBabyStepGiantStepRatio: 1,
}
@@ -170,23 +92,23 @@ func run(tc *testContext, t *testing.T) {
linTransf := NewLinearTransformation(params, ltparams)
// Encode on the linear transformation
require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf))
require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf))
galEls := linTransf.GaloisElements(params)
eval := tc.evaluator.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...))
eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...))
ltEval := NewLinearTransformationEvaluator(eval)
require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext))
values.Coeffs[0] = diagonals.Evaluate(values.Coeffs[0], newVec, add, muladd)
values = diagonals.Evaluate(values, newVec, add, muladd)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
bgv.VerifyTestVectors(params, tc.Ecd, tc.Dec, ciphertext, values, t)
})
t.Run(GetTestName("Evaluator/LinearTransformationBSGS=false", params, level), func(t *testing.T) {
t.Run("Evaluator/LinearTransformationBSGS=false"+tc.String(), func(t *testing.T) {
values, _, ciphertext := newTestVectorsLvl(level, params.DefaultScale(), tc, tc.encryptorSk)
values, _, ciphertext := bgv.NewTestVector(params, tc.Ecd, tc.Enc, params.MaxLevel(), params.DefaultScale())
slots := ciphertext.Slots()
@@ -213,21 +135,21 @@ func run(tc *testContext, t *testing.T) {
linTransf := NewLinearTransformation(params, ltparams)
// Encode on the linear transformation
require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf))
require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf))
galEls := linTransf.GaloisElements(params)
eval := tc.evaluator.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...))
eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...))
ltEval := NewLinearTransformationEvaluator(eval)
require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext))
values.Coeffs[0] = diagonals.Evaluate(values.Coeffs[0], newVec, add, muladd)
values = diagonals.Evaluate(values, newVec, add, muladd)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
bgv.VerifyTestVectors(params, tc.Ecd, tc.Dec, ciphertext, values, t)
})
t.Run(GetTestName("Evaluator/LinearTransformation/Permutation", params, level), func(t *testing.T) {
t.Run("Evaluator/LinearTransformation/Permutation"+tc.String(), func(t *testing.T) {
idx := [2][]int{
make([]int, params.MaxSlots()>>1),
@@ -272,7 +194,7 @@ func run(tc *testContext, t *testing.T) {
diagonals := Permutation[uint64](permutation).GetDiagonals(params.LogMaxSlots())
values, _, ciphertext := newTestVectorsLvl(level, tc.params.NewScale(1), tc, tc.encryptorSk)
values, _, ciphertext := bgv.NewTestVector(params, tc.Ecd, tc.Enc, params.MaxLevel(), params.DefaultScale())
ltparams := LinearTransformationParameters{
DiagonalsIndexList: diagonals.DiagonalsIndexList(),
@@ -287,66 +209,18 @@ func run(tc *testContext, t *testing.T) {
linTransf := NewLinearTransformation(params, ltparams)
// Encode on the linear transformation
require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf))
require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf))
galEls := linTransf.GaloisElements(params)
evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...)
evk := rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...)
ltEval := NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk))
ltEval := NewLinearTransformationEvaluator(tc.Evl.WithKey(evk))
require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext))
values.Coeffs[0] = diagonals.Evaluate(values.Coeffs[0], newVec, add, muladd)
values = diagonals.Evaluate(values, newVec, add, muladd)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
bgv.VerifyTestVectors(params, tc.Ecd, tc.Dec, ciphertext, values, t)
})
}
func newTestVectorsLvl(level int, scale rlwe.Scale, tc *testContext, encryptor *rlwe.Encryptor) (coeffs ring.Poly, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) {
coeffs = tc.uSampler.ReadNew()
for i := range coeffs.Coeffs[0] {
coeffs.Coeffs[0][i] = uint64(i)
}
plaintext = bgv.NewPlaintext(tc.params, level)
plaintext.Scale = scale
tc.encoder.Encode(coeffs.Coeffs[0], plaintext)
if encryptor != nil {
var err error
ciphertext, err = encryptor.EncryptNew(plaintext)
if err != nil {
panic(err)
}
}
return coeffs, plaintext, ciphertext
}
func verifyTestVectors(tc *testContext, decryptor *rlwe.Decryptor, coeffs ring.Poly, element rlwe.ElementInterface[ring.Poly], t *testing.T) {
coeffsTest := make([]uint64, tc.params.MaxSlots())
switch el := element.(type) {
case *rlwe.Plaintext:
require.NoError(t, tc.encoder.Decode(el, coeffsTest))
case *rlwe.Ciphertext:
pt := decryptor.DecryptNew(el)
require.NoError(t, tc.encoder.Decode(pt, coeffsTest))
if *flagPrintNoise {
require.NoError(t, tc.encoder.Encode(coeffsTest, pt))
ct, err := tc.evaluator.SubNew(el, pt)
require.NoError(t, err)
vartmp, _, _ := rlwe.Norm(ct, decryptor)
t.Logf("STD(noise): %f\n", vartmp)
}
default:
t.Error("invalid test object to verify")
}
require.True(t, slices.Equal(coeffs.Coeffs[0], coeffsTest))
}

View File

@@ -11,45 +11,26 @@ import (
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/schemes"
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
"github.com/tuneinsight/lattigo/v5/utils/bignum"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.")
var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats")
func GetTestName(params ckks.Parameters, opname string) string {
func name(opname string, tc *ckks.TestContext) string {
return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d",
opname,
params.RingType(),
params.LogN(),
int(math.Round(params.LogQP())),
params.QCount(),
params.PCount(),
int(math.Log2(params.DefaultScale().Float64())))
}
type testContext struct {
params ckks.Parameters
ringQ *ring.Ring
ringP *ring.Ring
prng sampling.PRNG
encoder ckks.Encoder
kgen *rlwe.KeyGenerator
sk *rlwe.SecretKey
pk *rlwe.PublicKey
encryptorPk *rlwe.Encryptor
encryptorSk *rlwe.Encryptor
decryptor *rlwe.Decryptor
evaluator schemes.Evaluator
tc.Params.RingType(),
tc.Params.LogN(),
int(math.Round(tc.Params.LogQP())),
tc.Params.QCount(),
tc.Params.PCount(),
int(math.Log2(tc.Params.DefaultScale().Float64())))
}
func TestPolynomialEvaluator(t *testing.T) {
var err error
var testParams []ckks.ParametersLiteral
@@ -60,7 +41,7 @@ func TestPolynomialEvaluator(t *testing.T) {
t.Fatal(err)
}
default:
testParams = schemes.CkksTestParametersLiteral
testParams = ckks.TestParametersLiteral
}
for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} {
@@ -73,17 +54,9 @@ func TestPolynomialEvaluator(t *testing.T) {
paramsLiteral.LogN = 10
}
var params ckks.Parameters
if params, err = ckks.NewParametersFromLiteral(paramsLiteral); err != nil {
t.Fatal(err)
}
tc := ckks.NewTestContext(paramsLiteral)
var tc *testContext
if tc, err = genTestParams(params); err != nil {
t.Fatal(err)
}
for _, testSet := range []func(tc *testContext, t *testing.T){
for _, testSet := range []func(tc *ckks.TestContext, t *testing.T){
run,
} {
testSet(tc, t)
@@ -93,21 +66,21 @@ func TestPolynomialEvaluator(t *testing.T) {
}
}
func run(tc *testContext, t *testing.T) {
func run(tc *ckks.TestContext, t *testing.T) {
params := tc.params
params := tc.Params
var err error
polyEval := NewPolynomialEvaluator(params, tc.evaluator)
polyEval := NewPolynomialEvaluator(params, tc.Evl)
t.Run(GetTestName(params, "EvaluatePoly/PolySingle/Exp"), func(t *testing.T) {
t.Run(name("EvaluatePoly/PolySingle/Exp", tc), func(t *testing.T) {
if params.MaxLevel() < 3 {
t.Skip("skipping test for params max level < 3")
}
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1, 1, t)
values, _, ciphertext := tc.NewTestVector(-1, 1)
prec := params.EncodingPrecision()
@@ -132,16 +105,16 @@ func run(tc *testContext, t *testing.T) {
t.Fatal(err)
}
ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(params, "Polynomial/PolyVector/Exp"), func(t *testing.T) {
t.Run(name("Polynomial/PolyVector/Exp", tc), func(t *testing.T) {
if params.MaxLevel() < 3 {
t.Skip("skipping test for params max level < 3")
}
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1, 1, t)
values, _, ciphertext := tc.NewTestVector(-1, 1)
prec := params.EncodingPrecision()
@@ -180,75 +153,6 @@ func run(tc *testContext, t *testing.T) {
t.Fatal(err)
}
ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, valuesWant, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, valuesWant, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
}
func genTestParams(defaultParam ckks.Parameters) (tc *testContext, err error) {
tc = new(testContext)
tc.params = defaultParam
tc.kgen = rlwe.NewKeyGenerator(tc.params)
tc.sk, tc.pk = tc.kgen.GenKeyPairNew()
tc.ringQ = defaultParam.RingQ()
if tc.params.PCount() != 0 {
tc.ringP = defaultParam.RingP()
}
if tc.prng, err = sampling.NewPRNG(); err != nil {
return nil, err
}
tc.encoder = *ckks.NewEncoder(tc.params)
tc.encryptorPk = rlwe.NewEncryptor(tc.params, tc.pk)
tc.encryptorSk = rlwe.NewEncryptor(tc.params, tc.sk)
tc.decryptor = rlwe.NewDecryptor(tc.params, tc.sk)
tc.evaluator = ckks.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
return tc, nil
}
func newTestVectors(tc *testContext, encryptor *rlwe.Encryptor, a, b complex128, t *testing.T) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) {
var err error
prec := tc.params.EncodingPrecision()
pt = ckks.NewPlaintext(tc.params, tc.params.MaxLevel())
values = make([]*bignum.Complex, pt.Slots())
switch tc.params.RingType() {
case ring.Standard:
for i := range values {
values[i] = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec),
}
}
case ring.ConjugateInvariant:
for i := range values {
values[i] = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
new(big.Float),
}
}
default:
t.Fatal("invalid ring type")
}
tc.encoder.Encode(values, pt)
if encryptor != nil {
ct, err = encryptor.EncryptNew(pt)
require.NoError(t, err)
}
return values, pt, ct
}

View File

@@ -3,87 +3,22 @@ package polyint
import (
"encoding/json"
"flag"
"fmt"
"math"
"runtime"
"slices"
"testing"
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/schemes"
"github.com/tuneinsight/lattigo/v5/schemes/bgv"
"github.com/tuneinsight/lattigo/v5/utils/bignum"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
type testContext struct {
params bgv.Parameters
ringQ *ring.Ring
ringT *ring.Ring
prng sampling.PRNG
uSampler *ring.UniformSampler
encoder *bgv.Encoder
kgen *rlwe.KeyGenerator
sk *rlwe.SecretKey
pk *rlwe.PublicKey
encryptorPk *rlwe.Encryptor
encryptorSk *rlwe.Encryptor
decryptor *rlwe.Decryptor
evaluator *bgv.Evaluator
testLevel []int
}
func genTestParams(params bgv.Parameters) (tc *testContext, err error) {
tc = new(testContext)
tc.params = params
if tc.prng, err = sampling.NewPRNG(); err != nil {
return nil, err
}
tc.ringQ = params.RingQ()
tc.ringT = params.RingT()
tc.uSampler = ring.NewUniformSampler(tc.prng, tc.ringT)
tc.kgen = rlwe.NewKeyGenerator(tc.params)
tc.sk, tc.pk = tc.kgen.GenKeyPairNew()
tc.encoder = bgv.NewEncoder(tc.params)
tc.encryptorPk = rlwe.NewEncryptor(tc.params, tc.pk)
tc.encryptorSk = rlwe.NewEncryptor(tc.params, tc.sk)
tc.decryptor = rlwe.NewDecryptor(tc.params, tc.sk)
tc.evaluator = bgv.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
tc.testLevel = []int{0, params.MaxLevel()}
return
}
var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise")
var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.")
func GetTestName(opname string, p bgv.Parameters, lvl int) string {
return fmt.Sprintf("%s/LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d/lvl=%d",
opname,
p.LogN(),
int(math.Round(p.LogQ())),
int(math.Round(p.LogP())),
p.LogMaxDimensions().Rows,
p.LogMaxDimensions().Cols,
int(math.Round(p.LogT())),
p.QCount(),
p.PCount(),
lvl)
}
func TestPolynomialEvaluator(t *testing.T) {
var err error
paramsLiterals := schemes.BgvTestParams
paramsLiterals := bgv.TestParams
if *flagParamString != "" {
var jsonParams bgv.ParametersLiteral
@@ -95,23 +30,12 @@ func TestPolynomialEvaluator(t *testing.T) {
for _, p := range paramsLiterals[:] {
for _, plaintextModulus := range schemes.BgvTestPlaintextModulus[:] {
for _, plaintextModulus := range bgv.TestPlaintextModulus[:] {
p.PlaintextModulus = plaintextModulus
var params bgv.Parameters
if params, err = bgv.NewParametersFromLiteral(p); err != nil {
t.Error(err)
t.Fail()
}
tc := bgv.NewTestContext(p)
var tc *testContext
if tc, err = genTestParams(params); err != nil {
t.Error(err)
t.Fail()
}
for _, testSet := range []func(tc *testContext, t *testing.T){
for _, testSet := range []func(tc *bgv.TestContext, t *testing.T){
run,
} {
testSet(tc, t)
@@ -121,95 +45,47 @@ func TestPolynomialEvaluator(t *testing.T) {
}
}
func run(tc *testContext, t *testing.T) {
func run(tc *bgv.TestContext, t *testing.T) {
t.Run("Single", func(t *testing.T) {
if tc.params.MaxLevel() < 4 {
if tc.Params.MaxLevel() < 4 {
t.Skip("MaxLevel() to low")
}
values, _, ciphertext := newTestVectorsLvl(tc.params.MaxLevel(), tc.params.NewScale(1), tc, tc.encryptorSk)
values, _, ciphertext := bgv.NewTestVector(tc.Params, tc.Ecd, tc.Enc, tc.Params.MaxLevel(), tc.Params.DefaultScale())
coeffs := []uint64{0, 0, 1}
T := tc.params.PlaintextModulus()
for i := range values.Coeffs[0] {
values.Coeffs[0][i] = ring.EvalPolyModP(values.Coeffs[0][i], coeffs, T)
T := tc.Params.PlaintextModulus()
for i := range values {
values[i] = ring.EvalPolyModP(values[i], coeffs, T)
}
poly := bignum.NewPolynomial(bignum.Monomial, coeffs, nil)
t.Run(GetTestName("Standard", tc.params, tc.params.MaxLevel()), func(t *testing.T) {
t.Run("Standard"+tc.String(), func(t *testing.T) {
polyEval := NewPolynomialEvaluator(tc.params, tc.evaluator, false)
polyEval := NewPolynomialEvaluator(tc.Params, tc.Evl, false)
res, err := polyEval.Evaluate(ciphertext, poly, tc.params.DefaultScale())
res, err := polyEval.Evaluate(ciphertext, poly, tc.Params.DefaultScale())
require.NoError(t, err)
require.Equal(t, res.Scale.Cmp(tc.params.DefaultScale()), 0)
require.Equal(t, res.Scale.Cmp(tc.Params.DefaultScale()), 0)
verifyTestVectors(tc, tc.decryptor, values, res, t)
bgv.VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, res, values, t)
})
t.Run(GetTestName("Invariant", tc.params, tc.params.MaxLevel()), func(t *testing.T) {
t.Run("Invariant"+tc.String(), func(t *testing.T) {
polyEval := NewPolynomialEvaluator(tc.params, tc.evaluator, true)
polyEval := NewPolynomialEvaluator(tc.Params, tc.Evl, true)
res, err := polyEval.Evaluate(ciphertext, poly, tc.params.DefaultScale())
res, err := polyEval.Evaluate(ciphertext, poly, tc.Params.DefaultScale())
require.NoError(t, err)
require.Equal(t, res.Level(), ciphertext.Level())
require.Equal(t, res.Scale.Cmp(tc.params.DefaultScale()), 0)
require.Equal(t, res.Scale.Cmp(tc.Params.DefaultScale()), 0)
verifyTestVectors(tc, tc.decryptor, values, res, t)
bgv.VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, res, values, t)
})
})
}
func newTestVectorsLvl(level int, scale rlwe.Scale, tc *testContext, encryptor *rlwe.Encryptor) (coeffs ring.Poly, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) {
coeffs = tc.uSampler.ReadNew()
for i := range coeffs.Coeffs[0] {
coeffs.Coeffs[0][i] = uint64(i)
}
plaintext = bgv.NewPlaintext(tc.params, level)
plaintext.Scale = scale
tc.encoder.Encode(coeffs.Coeffs[0], plaintext)
if encryptor != nil {
var err error
ciphertext, err = encryptor.EncryptNew(plaintext)
if err != nil {
panic(err)
}
}
return coeffs, plaintext, ciphertext
}
func verifyTestVectors(tc *testContext, decryptor *rlwe.Decryptor, coeffs ring.Poly, element rlwe.ElementInterface[ring.Poly], t *testing.T) {
coeffsTest := make([]uint64, tc.params.MaxSlots())
switch el := element.(type) {
case *rlwe.Plaintext:
require.NoError(t, tc.encoder.Decode(el, coeffsTest))
case *rlwe.Ciphertext:
pt := decryptor.DecryptNew(el)
require.NoError(t, tc.encoder.Decode(pt, coeffsTest))
if *flagPrintNoise {
require.NoError(t, tc.encoder.Encode(coeffsTest, pt))
ct, err := tc.evaluator.SubNew(el, pt)
require.NoError(t, err)
vartmp, _, _ := rlwe.Norm(ct, decryptor)
t.Logf("STD(noise): %f\n", vartmp)
}
default:
t.Error("invalid test object to verify")
}
require.True(t, slices.Equal(coeffs.Coeffs[0], coeffsTest))
}

View File

@@ -2,25 +2,13 @@ package bfv
import (
"encoding/json"
"fmt"
"runtime"
"testing"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
)
func GetBenchName(params Parameters, opname string) string {
return fmt.Sprintf("%s/logN=%d/Qi=%d/Pi=%d/LogSlots=%d",
opname,
params.LogN(),
params.QCount(),
params.PCount(),
params.LogMaxSlots())
}
func BenchmarkBFV(b *testing.B) {
var err error
var testParams []ParametersLiteral
@@ -42,20 +30,9 @@ func BenchmarkBFV(b *testing.B) {
}
for _, paramsLiteral := range testParams {
tc := NewTestContext(paramsLiteral)
var params Parameters
if params, err = NewParametersFromLiteral(paramsLiteral); err != nil {
b.Error(err)
b.Fail()
}
var tc *testContext
if tc, err = genTestParams(params); err != nil {
b.Error(err)
b.Fail()
}
for _, testSet := range []func(tc *testContext, b *testing.B){
for _, testSet := range []func(tc *TestContext, b *testing.B){
benchEncoder,
benchEvaluator,
} {
@@ -65,11 +42,11 @@ func BenchmarkBFV(b *testing.B) {
}
}
func benchEncoder(tc *testContext, b *testing.B) {
func benchEncoder(tc *TestContext, b *testing.B) {
params := tc.params
params := tc.Params
poly := tc.uSampler.ReadNew()
poly := tc.Sampler.ReadNew()
params.RingT().Reduce(poly, poly)
coeffsUint64 := poly.Coeffs[0]
coeffsInt64 := make([]int64, len(coeffsUint64))
@@ -77,12 +54,12 @@ func benchEncoder(tc *testContext, b *testing.B) {
coeffsInt64[i] = int64(coeffsUint64[i])
}
encoder := tc.encoder
encoder := tc.Ecd
level := params.MaxLevel()
plaintext := NewPlaintext(params, level)
b.Run(GetBenchName(params, "Encoder/Encode/Uint"), func(b *testing.B) {
b.Run(name("Encoder/Encode/Uint", tc, level), func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err := encoder.Encode(coeffsUint64, plaintext); err != nil {
b.Log(err)
@@ -91,7 +68,7 @@ func benchEncoder(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Encoder/Encode/Int"), func(b *testing.B) {
b.Run(name("Encoder/Encode/Int", tc, level), func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err := encoder.Encode(coeffsInt64, plaintext); err != nil {
b.Log(err)
@@ -100,7 +77,7 @@ func benchEncoder(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Encoder/Decode/Uint"), func(b *testing.B) {
b.Run(name("Encoder/Decode/Uint", tc, level), func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err := encoder.Decode(plaintext, coeffsUint64); err != nil {
b.Log(err)
@@ -109,7 +86,7 @@ func benchEncoder(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Encoder/Decode/Int"), func(b *testing.B) {
b.Run(name("Encoder/Decode/Int", tc, level), func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err := encoder.Decode(plaintext, coeffsInt64); err != nil {
b.Log(err)
@@ -119,16 +96,18 @@ func benchEncoder(tc *testContext, b *testing.B) {
})
}
func benchEvaluator(tc *testContext, b *testing.B) {
func benchEvaluator(tc *TestContext, b *testing.B) {
params := tc.params
eval := tc.evaluator
params := tc.Params
eval := tc.Evl
plaintext := NewPlaintext(params, params.MaxLevel())
plaintext.Value = rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 0, plaintext.Level()).Value[0]
level := params.MaxLevel()
ciphertext1 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel())
ciphertext2 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel())
plaintext := NewPlaintext(params, level)
plaintext.Value = rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 0, plaintext.Level()).Value[0]
ciphertext1 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, level)
ciphertext2 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, level)
scalar := params.PlaintextModulus() >> 1
*ciphertext1.MetaData = *plaintext.MetaData
@@ -136,7 +115,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
vector := plaintext.Value.Coeffs[0][:params.MaxSlots()]
b.Run(GetBenchName(params, "Evaluator/Add/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Scalar", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -147,7 +126,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Vector", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -158,7 +137,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Plaintext", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -169,7 +148,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Ciphertext", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -180,7 +159,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Scalar", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -191,7 +170,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Plaintext", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -202,7 +181,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Vector", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -213,7 +192,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Ciphertext", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -224,7 +203,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulRelin/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulRelin/Ciphertext", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -235,7 +214,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Scalar", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -246,7 +225,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Vector", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -257,7 +236,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Plaintext", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -268,7 +247,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Ciphertext", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -279,7 +258,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulRelinThenAdd/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulRelinThenAdd/Ciphertext", tc, level), func(b *testing.B) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -290,8 +269,8 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Rotate"), func(b *testing.B) {
gk := tc.kgen.GenGaloisKeyNew(5, tc.sk)
b.Run(name("Evaluator/Rotate", tc, level), func(b *testing.B) {
gk := tc.Kgen.GenGaloisKeyNew(5, tc.Sk)
evk := rlwe.NewMemEvaluationKeySet(nil, gk)
eval := eval.WithKey(evk)
receiver := NewCiphertext(params, 1, ciphertext2.Level())

View File

@@ -4,14 +4,12 @@ import (
"encoding/json"
"flag"
"fmt"
"math"
"runtime"
"slices"
"testing"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
"github.com/stretchr/testify/require"
)
@@ -19,23 +17,14 @@ import (
var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise")
var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.")
func GetTestName(opname string, p Parameters, lvl int) string {
return fmt.Sprintf("%s/LogN=%d/logQ=%d/logP=%d/logT=%d/Qi=%d/Pi=%d/lvl=%d",
opname,
p.LogN(),
int(math.Round(p.LogQ())),
int(math.Round(p.LogP())),
int(math.Round(p.LogT())),
p.QCount(),
p.PCount(),
lvl)
func name(op string, tc *TestContext, lvl int) string {
return fmt.Sprintf("%s/%s/lvl=%d", op, tc, lvl)
}
func TestBFV(t *testing.T) {
var err error
paramsLiterals := testParams
paramsLiterals := TestParams
if *flagParamString != "" {
var jsonParams ParametersLiteral
@@ -47,17 +36,13 @@ func TestBFV(t *testing.T) {
for _, p := range paramsLiterals[:] {
for _, plaintextModulus := range testPlaintextModulus[:] {
for _, plaintextModulus := range TestPlaintextModulus[:] {
p.PlaintextModulus = plaintextModulus
params, err := NewParametersFromLiteral(p)
require.NoError(t, err)
tc := NewTestContext(p)
tc, err := genTestParams(params)
require.NoError(t, err)
for _, testSet := range []func(tc *testContext, t *testing.T){
for _, testSet := range []func(tc *TestContext, t *testing.T){
testParameters,
testEncoder,
testEvaluator,
@@ -69,110 +54,18 @@ func TestBFV(t *testing.T) {
}
}
type testContext struct {
params Parameters
ringQ *ring.Ring
ringT *ring.Ring
prng sampling.PRNG
uSampler *ring.UniformSampler
encoder *Encoder
kgen *rlwe.KeyGenerator
sk *rlwe.SecretKey
pk *rlwe.PublicKey
encryptorPk *rlwe.Encryptor
encryptorSk *rlwe.Encryptor
decryptor *rlwe.Decryptor
evaluator *Evaluator
testLevel []int
}
func genTestParams(params Parameters) (tc *testContext, err error) {
tc = new(testContext)
tc.params = params
if tc.prng, err = sampling.NewPRNG(); err != nil {
return nil, err
}
tc.ringQ = params.RingQ()
tc.ringT = params.RingT()
tc.uSampler = ring.NewUniformSampler(tc.prng, tc.ringT)
tc.kgen = NewKeyGenerator(tc.params)
tc.sk, tc.pk = tc.kgen.GenKeyPairNew()
tc.encoder = NewEncoder(tc.params)
tc.encryptorPk = NewEncryptor(tc.params, tc.pk)
tc.encryptorSk = NewEncryptor(tc.params, tc.sk)
tc.decryptor = NewDecryptor(tc.params, tc.sk)
tc.evaluator = NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
tc.testLevel = []int{0, params.MaxLevel()}
return
}
func newTestVectorsLvl(level int, scale rlwe.Scale, tc *testContext, encryptor *rlwe.Encryptor) (coeffs ring.Poly, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) {
coeffs = tc.uSampler.ReadNew()
for i := range coeffs.Coeffs[0] {
coeffs.Coeffs[0][i] = uint64(i)
}
plaintext = NewPlaintext(tc.params, level)
plaintext.Scale = scale
tc.encoder.Encode(coeffs.Coeffs[0], plaintext)
if encryptor != nil {
var err error
ciphertext, err = encryptor.EncryptNew(plaintext)
if err != nil {
panic(err)
}
}
return coeffs, plaintext, ciphertext
}
func verifyTestVectors(tc *testContext, decryptor *rlwe.Decryptor, coeffs ring.Poly, element rlwe.ElementInterface[ring.Poly], t *testing.T) {
coeffsTest := make([]uint64, tc.params.MaxSlots())
switch el := element.(type) {
case *rlwe.Plaintext:
require.NoError(t, tc.encoder.Decode(el, coeffsTest))
case *rlwe.Ciphertext:
pt := decryptor.DecryptNew(el)
require.NoError(t, tc.encoder.Decode(pt, coeffsTest))
if *flagPrintNoise {
require.NoError(t, tc.encoder.Encode(coeffsTest, pt))
ct, err := tc.evaluator.SubNew(el, pt)
require.NoError(t, err)
vartmp, _, _ := rlwe.Norm(ct, decryptor)
t.Logf("STD(noise): %f\n", vartmp)
}
default:
t.Fatal("invalid test object to verify")
}
require.True(t, slices.Equal(coeffs.Coeffs[0], coeffsTest))
}
func testParameters(tc *testContext, t *testing.T) {
t.Run(GetTestName("Parameters/Marshaller/Binary", tc.params, 0), func(t *testing.T) {
bytes, err := tc.params.MarshalBinary()
func testParameters(tc *TestContext, t *testing.T) {
t.Run(name("Parameters/Marshaller/Binary", tc, 0), func(t *testing.T) {
bytes, err := tc.Params.MarshalBinary()
require.Nil(t, err)
var p Parameters
require.Nil(t, p.UnmarshalBinary(bytes))
require.True(t, tc.params.Equal(&p))
require.True(t, tc.Params.Equal(&p))
})
t.Run(GetTestName("Parameters/Marshaller/JSON", tc.params, 0), func(t *testing.T) {
t.Run(name("Parameters/Marshaller/JSON", tc, 0), func(t *testing.T) {
// checks that parameters can be marshalled without error
data, err := json.Marshal(tc.params)
data, err := json.Marshal(tc.Params)
require.Nil(t, err)
require.NotNil(t, data)
@@ -180,10 +73,10 @@ func testParameters(tc *testContext, t *testing.T) {
var paramsRec Parameters
err = json.Unmarshal(data, &paramsRec)
require.Nil(t, err)
require.True(t, tc.params.Equal(&paramsRec))
require.True(t, tc.Params.Equal(&paramsRec))
// checks that the Parameters can be unmarshalled with log-moduli definition without error
dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537}`, tc.params.LogN()))
dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537}`, tc.Params.LogN()))
var paramsWithLogModuli Parameters
err = json.Unmarshal(dataWithLogModuli, &paramsWithLogModuli)
require.Nil(t, err)
@@ -193,7 +86,7 @@ func testParameters(tc *testContext, t *testing.T) {
require.Equal(t, rlwe.DefaultXs, paramsWithLogModuli.Xs()) // Omitting Xe should result in Default being used
// checks that one can provide custom parameters for the secret-key and error distributions
dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537, "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.params.LogN()))
dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537, "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.Params.LogN()))
var paramsWithCustomSecrets Parameters
err = json.Unmarshal(dataWithCustomSecrets, &paramsWithCustomSecrets)
require.Nil(t, err)
@@ -202,21 +95,22 @@ func testParameters(tc *testContext, t *testing.T) {
})
}
func testEncoder(tc *testContext, t *testing.T) {
func testEncoder(tc *TestContext, t *testing.T) {
testLevels := []int{0, tc.Params.MaxLevel()}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Encoder/Uint", tc.params, lvl), func(t *testing.T) {
values, plaintext, _ := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, nil)
verifyTestVectors(tc, nil, values, plaintext, t)
for _, lvl := range testLevels {
t.Run(name("Encoder/Uint", tc, lvl), func(t *testing.T) {
values, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, plaintext, values, t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Encoder/Int", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("Encoder/Int", tc, lvl), func(t *testing.T) {
T := tc.params.PlaintextModulus()
T := tc.Params.PlaintextModulus()
THalf := T >> 1
coeffs := tc.uSampler.ReadNew()
coeffs := tc.Sampler.ReadNew()
coeffsInt := make([]int64, len(coeffs.Coeffs[0]))
for i, c := range coeffs.Coeffs[0] {
c %= T
@@ -227,317 +121,335 @@ func testEncoder(tc *testContext, t *testing.T) {
}
}
plaintext := NewPlaintext(tc.params, lvl)
tc.encoder.Encode(coeffsInt, plaintext)
have := make([]int64, tc.params.MaxSlots())
tc.encoder.Decode(plaintext, have)
plaintext := NewPlaintext(tc.Params, lvl)
tc.Ecd.Encode(coeffsInt, plaintext)
have := make([]int64, tc.Params.MaxSlots())
tc.Ecd.Decode(plaintext, have)
require.True(t, slices.Equal(coeffsInt, have))
})
}
}
func testEvaluator(tc *testContext, t *testing.T) {
func testEvaluator(tc *TestContext, t *testing.T) {
testLevels := []int{0, tc.Params.MaxLevel()}
t.Run("Evaluator", func(t *testing.T) {
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Add/Ct/Ct/New", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevels {
t.Run(name("Add/Ct/Ct/New", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
ciphertext2, err := tc.evaluator.AddNew(ciphertext0, ciphertext1)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
ciphertext2, err := tc.Evl.AddNew(ciphertext0, ciphertext1)
require.NoError(t, err)
tc.ringT.Add(values0, values1, values0)
verifyTestVectors(tc, tc.decryptor, values0, ciphertext2, t)
tc.Params.RingT().Add(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Add/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevels {
t.Run(name("Add/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.NoError(t, tc.evaluator.Add(ciphertext0, ciphertext1, ciphertext0))
tc.ringT.Add(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Add(ciphertext0, ciphertext1, ciphertext0))
tc.Params.RingT().Add(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Add/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevels {
t.Run(name("Add/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0)
require.NoError(t, tc.evaluator.Add(ciphertext0, plaintext, ciphertext0))
tc.ringT.Add(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Add(ciphertext0, plaintext, ciphertext0))
tc.Params.RingT().Add(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Add/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("Add/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) {
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
scalar := tc.Params.PlaintextModulus() >> 1
scalar := tc.params.PlaintextModulus() >> 1
p := ring.Poly{Coeffs: [][]uint64{values}}
require.NoError(t, tc.evaluator.Add(ciphertext, scalar, ciphertext))
tc.ringT.AddScalar(values, scalar, values)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
require.NoError(t, tc.Evl.Add(ciphertext, scalar, ciphertext))
tc.Params.RingT().AddScalar(p, scalar, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Sub/Ct/Ct/New", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevels {
t.Run(name("Sub/Ct/Ct/New", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
ciphertext0, err := tc.evaluator.SubNew(ciphertext0, ciphertext1)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
ciphertext0, err := tc.Evl.SubNew(ciphertext0, ciphertext1)
require.NoError(t, err)
tc.ringT.Sub(values0, values1, values0)
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
tc.Params.RingT().Sub(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Sub/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevels {
t.Run(name("Sub/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.NoError(t, tc.evaluator.Sub(ciphertext0, ciphertext1, ciphertext0))
tc.ringT.Sub(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Sub(ciphertext0, ciphertext1, ciphertext0))
tc.Params.RingT().Sub(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Sub/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevels {
t.Run(name("Sub/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0)
require.NoError(t, tc.evaluator.Sub(ciphertext0, plaintext, ciphertext0))
tc.ringT.Sub(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Sub(ciphertext0, plaintext, ciphertext0))
tc.Params.RingT().Sub(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Mul/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("Mul/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.NoError(t, tc.evaluator.Mul(ciphertext0, ciphertext1, ciphertext0))
tc.ringT.MulCoeffsBarrett(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Mul(ciphertext0, ciphertext1, ciphertext0))
tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Mul/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("Mul/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0)
require.True(t, ciphertext.Scale.Cmp(plaintext.Scale) != 0)
require.NoError(t, tc.evaluator.Mul(ciphertext0, plaintext, ciphertext0))
tc.ringT.MulCoeffsBarrett(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Mul(ciphertext, plaintext, ciphertext))
tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Mul/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("Mul/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
}
values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
scalar := tc.params.PlaintextModulus() >> 1
scalar := tc.Params.PlaintextModulus() >> 1
tc.evaluator.Mul(ciphertext, scalar, ciphertext)
tc.ringT.MulScalar(values, scalar, values)
p := ring.Poly{Coeffs: [][]uint64{values}}
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
require.NoError(t, tc.Evl.Mul(ciphertext, scalar, ciphertext))
tc.Params.RingT().MulScalar(p, scalar, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Square/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("Square/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
require.NoError(t, tc.evaluator.Mul(ciphertext0, ciphertext0, ciphertext0))
tc.ringT.MulCoeffsBarrett(values0, values0, values0)
p := ring.Poly{Coeffs: [][]uint64{values}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Mul(ciphertext, ciphertext, ciphertext))
tc.Params.RingT().MulCoeffsBarrett(p, p, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("MulRelin/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("MulRelin/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
tc.ringT.MulCoeffsBarrett(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0)
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
receiver := NewCiphertext(tc.params, 1, lvl)
receiver := NewCiphertext(tc.Params, 1, lvl)
require.NoError(t, tc.evaluator.MulRelin(ciphertext0, ciphertext1, receiver))
require.NoError(t, tc.evaluator.Rescale(receiver, receiver))
verifyTestVectors(tc, tc.decryptor, values0, receiver, t)
require.NoError(t, tc.Evl.MulRelin(ciphertext0, ciphertext1, receiver))
require.NoError(t, tc.Evl.Rescale(receiver, receiver))
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, receiver, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("MulThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("MulThenAdd/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk)
values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2))
values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
p2 := ring.Poly{Coeffs: [][]uint64{values2}}
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0)
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, ciphertext1, ciphertext2))
tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2)
verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t)
require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, ciphertext1, ciphertext2))
tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("MulThenAdd/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("MulThenAdd/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values1, plaintext1, _ := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk)
values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values1, plaintext1, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2))
values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
p2 := ring.Poly{Coeffs: [][]uint64{values2}}
require.True(t, ciphertext0.Scale.Cmp(plaintext1.Scale) != 0)
require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0)
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, plaintext1, ciphertext2))
tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2)
verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t)
require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, plaintext1, ciphertext2))
tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("MulThenAdd/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("MulThenAdd/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
scalar := tc.params.PlaintextModulus() >> 1
scalar := tc.Params.PlaintextModulus() >> 1
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, scalar, ciphertext1))
tc.ringT.MulScalarThenAdd(values0, scalar, values1)
require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, scalar, ciphertext1))
tc.Params.RingT().MulScalarThenAdd(p0, scalar, p1)
verifyTestVectors(tc, tc.decryptor, values1, ciphertext1, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext1, p1.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("MulRelinThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevels {
t.Run(name("MulRelinThenAdd/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk)
values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2))
values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0)
require.NoError(t, tc.evaluator.MulRelinThenAdd(ciphertext0, ciphertext1, ciphertext2))
tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
p2 := ring.Poly{Coeffs: [][]uint64{values2}}
verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t)
require.NoError(t, tc.Evl.MulRelinThenAdd(ciphertext0, ciphertext1, ciphertext2))
tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t)
})
}
})

View File

@@ -1,15 +0,0 @@
package bfv
var (
// testInsecure are insecure parameters used for the sole purpose of fast testing.
testInsecure = ParametersLiteral{
LogN: 10,
Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001},
P: []uint64{0x7fffffd8001},
}
testPlaintextModulus = []uint64{0x101, 0xffc001}
testParams = []ParametersLiteral{testInsecure}
)

120
schemes/bfv/test_utils.go Normal file
View File

@@ -0,0 +1,120 @@
package bfv
import (
"fmt"
"math"
"slices"
"testing"
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
type TestContext struct {
Params Parameters
Ecd *Encoder
Prng sampling.PRNG
Sampler *ring.UniformSampler
Kgen *rlwe.KeyGenerator
Sk *rlwe.SecretKey
Pk *rlwe.PublicKey
Enc *rlwe.Encryptor
Dec *rlwe.Decryptor
Evl *Evaluator
}
func NewTestContext(params ParametersLiteral) *TestContext {
tc := new(TestContext)
var err error
tc.Params, err = NewParametersFromLiteral(params)
if err != nil {
panic(err)
}
tc.Ecd = NewEncoder(tc.Params)
tc.Prng, err = sampling.NewPRNG()
if err != nil {
panic(err)
}
tc.Sampler = ring.NewUniformSampler(tc.Prng, tc.Params.RingT())
tc.Kgen = rlwe.NewKeyGenerator(tc.Params)
tc.Sk, tc.Pk = tc.Kgen.GenKeyPairNew()
tc.Enc = rlwe.NewEncryptor(tc.Params, tc.Pk)
tc.Dec = rlwe.NewDecryptor(tc.Params, tc.Sk)
tc.Evl = NewEvaluator(tc.Params, rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk)))
return tc
}
func (tc TestContext) String() string {
return fmt.Sprintf("LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d",
tc.Params.LogN(),
int(math.Round(tc.Params.LogQ())),
int(math.Round(tc.Params.LogP())),
tc.Params.LogMaxDimensions().Rows,
tc.Params.LogMaxDimensions().Cols,
int(math.Round(tc.Params.LogT())),
tc.Params.QCount(),
tc.Params.PCount())
}
func VerifyTestVectors(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, have interface{}, want []uint64, t *testing.T) {
values := make([]uint64, params.MaxSlots())
switch have := have.(type) {
case *rlwe.Plaintext:
require.NoError(t, encoder.Decode(have, values))
case *rlwe.Ciphertext:
require.NoError(t, encoder.Decode(decryptor.DecryptNew(have), values))
default:
t.Error("invalid unsupported test object type")
}
require.True(t, slices.Equal(values, want))
}
func NewTestVector(params Parameters, encoder *Encoder, encryptor *rlwe.Encryptor, level int, scale rlwe.Scale) (values []uint64, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) {
values = make([]uint64, params.MaxSlots())
for i := range values {
values[i] = sampling.RandUint64() % params.PlaintextModulus()
}
pt = NewPlaintext(params, level)
pt.Scale = scale
if err := encoder.Encode(values, pt); err != nil {
panic(err)
}
if encryptor != nil {
var err error
ct, err = encryptor.EncryptNew(pt)
if err != nil {
panic(err)
}
}
return
}
var (
// testInsecure are insecure parameters used for the sole purpose of fast testing.
TestInsecure = ParametersLiteral{
LogN: 10,
Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001},
P: []uint64{0x7fffffd8001},
}
TestPlaintextModulus = []uint64{0x101, 0xffc001}
TestParams = []ParametersLiteral{TestInsecure}
)

View File

@@ -2,24 +2,13 @@ package bgv
import (
"encoding/json"
"fmt"
"runtime"
"testing"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
)
func GetBenchName(params Parameters, opname string) string {
return fmt.Sprintf("%s/logN=%d/Qi=%d/Pi=%d/LogSlots=%d",
opname,
params.LogN(),
params.QCount(),
params.PCount(),
params.LogMaxSlots())
}
func BenchmarkBGV(b *testing.B) {
var err error
var testParams []ParametersLiteral
@@ -41,20 +30,9 @@ func BenchmarkBGV(b *testing.B) {
}
for _, paramsLiteral := range testParams {
tc := NewTestContext(paramsLiteral)
var params Parameters
if params, err = NewParametersFromLiteral(paramsLiteral); err != nil {
b.Error(err)
b.Fail()
}
var tc *testContext
if tc, err = genTestParams(params); err != nil {
b.Error(err)
b.Fail()
}
for _, testSet := range []func(tc *testContext, b *testing.B){
for _, testSet := range []func(tc *TestContext, b *testing.B){
benchEncoder,
benchEvaluator,
} {
@@ -64,11 +42,12 @@ func BenchmarkBGV(b *testing.B) {
}
}
func benchEncoder(tc *testContext, b *testing.B) {
func benchEncoder(tc *TestContext, b *testing.B) {
params := tc.params
params := tc.Params
lvl := params.MaxLevel()
poly := tc.uSampler.ReadNew()
poly := tc.Sampler.ReadNew()
params.RingT().Reduce(poly, poly)
coeffsUint64 := poly.Coeffs[0]
coeffsInt64 := make([]int64, len(coeffsUint64))
@@ -76,10 +55,10 @@ func benchEncoder(tc *testContext, b *testing.B) {
coeffsInt64[i] = int64(coeffsUint64[i])
}
encoder := tc.encoder
encoder := tc.Ecd
b.Run(GetBenchName(params, "Encoder/Encode/Uint"), func(b *testing.B) {
plaintext := NewPlaintext(params, params.MaxLevel())
b.Run(name("Encoder/Encode/Uint", tc, lvl), func(b *testing.B) {
plaintext := NewPlaintext(params, lvl)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := encoder.Encode(coeffsUint64, plaintext); err != nil {
@@ -89,8 +68,8 @@ func benchEncoder(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Encoder/Encode/Int"), func(b *testing.B) {
plaintext := NewPlaintext(params, params.MaxLevel())
b.Run(name("Encoder/Encode/Int", tc, lvl), func(b *testing.B) {
plaintext := NewPlaintext(params, lvl)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := encoder.Encode(coeffsInt64, plaintext); err != nil {
@@ -100,8 +79,8 @@ func benchEncoder(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Encoder/Decode/Uint"), func(b *testing.B) {
plaintext := NewPlaintext(params, params.MaxLevel())
b.Run(name("Encoder/Decode/Uint", tc, lvl), func(b *testing.B) {
plaintext := NewPlaintext(params, lvl)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := encoder.Decode(plaintext, coeffsUint64); err != nil {
@@ -111,8 +90,8 @@ func benchEncoder(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Encoder/Decode/Int"), func(b *testing.B) {
plaintext := NewPlaintext(params, params.MaxLevel())
b.Run(name("Encoder/Decode/Int", tc, lvl), func(b *testing.B) {
plaintext := NewPlaintext(params, lvl)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := encoder.Decode(plaintext, coeffsInt64); err != nil {
@@ -123,16 +102,17 @@ func benchEncoder(tc *testContext, b *testing.B) {
})
}
func benchEvaluator(tc *testContext, b *testing.B) {
func benchEvaluator(tc *TestContext, b *testing.B) {
params := tc.params
eval := tc.evaluator
params := tc.Params
lvl := params.MaxLevel()
eval := tc.Evl
plaintext := NewPlaintext(params, params.MaxLevel())
plaintext.Value = rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 0, plaintext.Level()).Value[0]
plaintext := NewPlaintext(params, lvl)
plaintext.Value = rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 0, plaintext.Level()).Value[0]
ciphertext1 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel())
ciphertext2 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel())
ciphertext1 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, lvl)
ciphertext2 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, lvl)
scalar := params.PlaintextModulus() >> 1
*ciphertext1.MetaData = *plaintext.MetaData
@@ -140,7 +120,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
vector := plaintext.Value.Coeffs[0][:params.MaxSlots()]
b.Run(GetBenchName(params, "Evaluator/Add/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Scalar", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -151,7 +131,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Vector", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -162,7 +142,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Plaintext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -173,7 +153,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Ciphertext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -184,7 +164,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Scalar", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -195,7 +175,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Plaintext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -206,7 +186,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Vector", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -217,7 +197,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Ciphertext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -228,7 +208,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulRelin/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulRelin/Ciphertext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -239,7 +219,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulInvariant/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulInvariant/Ciphertext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -250,7 +230,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulRelinInvariant/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulRelinInvariant/Ciphertext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -261,7 +241,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Scalar", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -272,7 +252,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Vector", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -283,7 +263,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Plaintext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -294,7 +274,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Ciphertext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -305,7 +285,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulRelinThenAdd/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulRelinThenAdd/Ciphertext", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -316,7 +296,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Rescale"), func(b *testing.B) {
b.Run(name("Evaluator/Rescale", tc, lvl), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level()-1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -327,8 +307,8 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Rotate"), func(b *testing.B) {
gk := tc.kgen.GenGaloisKeyNew(5, tc.sk)
b.Run(name("Evaluator/Rotate", tc, lvl), func(b *testing.B) {
gk := tc.Kgen.GenGaloisKeyNew(5, tc.Sk)
evk := rlwe.NewMemEvaluationKeySet(nil, gk)
eval := eval.WithKey(evk)
receiver := NewCiphertext(params, 1, ciphertext2.Level())

View File

@@ -4,40 +4,28 @@ import (
"encoding/json"
"flag"
"fmt"
"math"
"runtime"
"slices"
"testing"
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise")
var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.")
func GetTestName(opname string, p Parameters, lvl int) string {
return fmt.Sprintf("%s/LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d/lvl=%d",
opname,
p.LogN(),
int(math.Round(p.LogQ())),
int(math.Round(p.LogP())),
p.LogMaxDimensions().Rows,
p.LogMaxDimensions().Cols,
int(math.Round(p.LogT())),
p.QCount(),
p.PCount(),
lvl)
func name(op string, tc *TestContext, lvl int) string {
return fmt.Sprintf("%s/%s/lvl=%d", op, tc, lvl)
}
func TestBGV(t *testing.T) {
var err error
paramsLiterals := testParams
paramsLiterals := TestParams
if *flagParamString != "" {
var jsonParams ParametersLiteral
@@ -49,23 +37,13 @@ func TestBGV(t *testing.T) {
for _, p := range paramsLiterals[:] {
for _, plaintextModulus := range testPlaintextModulus[:] {
for _, plaintextModulus := range TestPlaintextModulus[:] {
p.PlaintextModulus = plaintextModulus
var params Parameters
if params, err = NewParametersFromLiteral(p); err != nil {
t.Error(err)
t.Fail()
}
tc := NewTestContext(p)
var tc *testContext
if tc, err = genTestParams(params); err != nil {
t.Error(err)
t.Fail()
}
for _, testSet := range []func(tc *testContext, t *testing.T){
for _, testSet := range []func(tc *TestContext, t *testing.T){
testParameters,
testEncoder,
testEvaluator,
@@ -77,114 +55,20 @@ func TestBGV(t *testing.T) {
}
}
type testContext struct {
params Parameters
ringQ *ring.Ring
ringT *ring.Ring
prng sampling.PRNG
uSampler *ring.UniformSampler
encoder *Encoder
kgen *rlwe.KeyGenerator
sk *rlwe.SecretKey
pk *rlwe.PublicKey
encryptorPk *rlwe.Encryptor
encryptorSk *rlwe.Encryptor
decryptor *rlwe.Decryptor
evaluator *Evaluator
testLevel []int
}
func testParameters(tc *TestContext, t *testing.T) {
t.Run(name("Parameters/Binary", tc, 0), func(t *testing.T) {
func genTestParams(params Parameters) (tc *testContext, err error) {
tc = new(testContext)
tc.params = params
if tc.prng, err = sampling.NewPRNG(); err != nil {
return nil, err
}
tc.ringQ = params.RingQ()
tc.ringT = params.RingT()
tc.uSampler = ring.NewUniformSampler(tc.prng, tc.ringT)
tc.kgen = NewKeyGenerator(tc.params)
tc.sk, tc.pk = tc.kgen.GenKeyPairNew()
tc.encoder = NewEncoder(tc.params)
tc.encryptorPk = NewEncryptor(tc.params, tc.pk)
tc.encryptorSk = NewEncryptor(tc.params, tc.sk)
tc.decryptor = NewDecryptor(tc.params, tc.sk)
tc.evaluator = NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
tc.testLevel = []int{0, params.MaxLevel()}
return
}
func newTestVectorsLvl(level int, scale rlwe.Scale, tc *testContext, encryptor *rlwe.Encryptor) (coeffs ring.Poly, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) {
coeffs = tc.uSampler.ReadNew()
for i := range coeffs.Coeffs[0] {
coeffs.Coeffs[0][i] = uint64(i)
}
plaintext = NewPlaintext(tc.params, level)
plaintext.Scale = scale
if err := tc.encoder.Encode(coeffs.Coeffs[0], plaintext); err != nil {
panic(err)
}
if encryptor != nil {
var err error
ciphertext, err = encryptor.EncryptNew(plaintext)
if err != nil {
panic(err)
}
}
return coeffs, plaintext, ciphertext
}
func verifyTestVectors(tc *testContext, decryptor *rlwe.Decryptor, coeffs ring.Poly, element rlwe.ElementInterface[ring.Poly], t *testing.T) {
coeffsTest := make([]uint64, tc.params.MaxSlots())
switch el := element.(type) {
case *rlwe.Plaintext:
require.NoError(t, tc.encoder.Decode(el, coeffsTest))
case *rlwe.Ciphertext:
pt := decryptor.DecryptNew(el)
require.NoError(t, tc.encoder.Decode(pt, coeffsTest))
if *flagPrintNoise {
require.NoError(t, tc.encoder.Encode(coeffsTest, pt))
ct, err := tc.evaluator.SubNew(el, pt)
require.NoError(t, err)
vartmp, _, _ := rlwe.Norm(ct, decryptor)
t.Logf("STD(noise): %f\n", vartmp)
}
default:
t.Error("invalid test object to verify")
}
require.True(t, slices.Equal(coeffs.Coeffs[0], coeffsTest))
}
func testParameters(tc *testContext, t *testing.T) {
t.Run(GetTestName("Parameters/Binary", tc.params, 0), func(t *testing.T) {
bytes, err := tc.params.MarshalBinary()
bytes, err := tc.Params.MarshalBinary()
require.Nil(t, err)
var p Parameters
require.Nil(t, p.UnmarshalBinary(bytes))
require.True(t, tc.params.Equal(&p))
require.True(t, tc.Params.Equal(&p))
})
t.Run(GetTestName("Parameters/JSON", tc.params, 0), func(t *testing.T) {
t.Run(name("Parameters/JSON", tc, 0), func(t *testing.T) {
// checks that parameters can be marshalled without error
data, err := json.Marshal(tc.params)
data, err := json.Marshal(tc.Params)
require.Nil(t, err)
require.NotNil(t, data)
@@ -192,10 +76,10 @@ func testParameters(tc *testContext, t *testing.T) {
var paramsRec Parameters
err = json.Unmarshal(data, &paramsRec)
require.Nil(t, err)
require.True(t, tc.params.Equal(&paramsRec))
require.True(t, tc.Params.Equal(&paramsRec))
// checks that the Parameters can be unmarshalled with log-moduli definition without error
dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537}`, tc.params.LogN()))
dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537}`, tc.Params.LogN()))
var paramsWithLogModuli Parameters
err = json.Unmarshal(dataWithLogModuli, &paramsWithLogModuli)
require.Nil(t, err)
@@ -205,7 +89,7 @@ func testParameters(tc *testContext, t *testing.T) {
require.Equal(t, rlwe.DefaultXs, paramsWithLogModuli.Xs()) // Omitting Xe should result in Default being used
// checks that one can provide custom parameters for the secret-key and error distributions
dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537, "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.params.LogN()))
dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537, "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.Params.LogN()))
var paramsWithCustomSecrets Parameters
err = json.Unmarshal(dataWithCustomSecrets, &paramsWithCustomSecrets)
require.Nil(t, err)
@@ -214,21 +98,22 @@ func testParameters(tc *testContext, t *testing.T) {
})
}
func testEncoder(tc *testContext, t *testing.T) {
func testEncoder(tc *TestContext, t *testing.T) {
testLevel := [2]int{0, tc.Params.MaxLevel()}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Encoder/Uint/IsBatched=true", tc.params, lvl), func(t *testing.T) {
values, plaintext, _ := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, nil)
verifyTestVectors(tc, nil, values, plaintext, t)
for _, lvl := range testLevel {
t.Run(name("Encoder/Uint/IsBatched=true", tc, lvl), func(t *testing.T) {
values, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, plaintext, values, t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Encoder/Int/IsBatched=true", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Encoder/Int/IsBatched=true", tc, lvl), func(t *testing.T) {
T := tc.params.PlaintextModulus()
T := tc.Params.PlaintextModulus()
THalf := T >> 1
poly := tc.uSampler.ReadNew()
poly := tc.Sampler.ReadNew()
coeffs := make([]int64, poly.N())
for i, c := range poly.Coeffs[0] {
c %= T
@@ -239,38 +124,38 @@ func testEncoder(tc *testContext, t *testing.T) {
}
}
plaintext := NewPlaintext(tc.params, lvl)
tc.encoder.Encode(coeffs, plaintext)
have := make([]int64, tc.params.MaxSlots())
tc.encoder.Decode(plaintext, have)
plaintext := NewPlaintext(tc.Params, lvl)
tc.Ecd.Encode(coeffs, plaintext)
have := make([]int64, tc.Params.MaxSlots())
tc.Ecd.Decode(plaintext, have)
require.True(t, slices.Equal(coeffs, have))
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Encoder/Uint/IsBatched=false", tc.params, lvl), func(t *testing.T) {
T := tc.params.PlaintextModulus()
poly := tc.uSampler.ReadNew()
for _, lvl := range testLevel {
t.Run(name("Encoder/Uint/IsBatched=false", tc, lvl), func(t *testing.T) {
T := tc.Params.PlaintextModulus()
poly := tc.Sampler.ReadNew()
coeffs := make([]uint64, poly.N())
for i, c := range poly.Coeffs[0] {
coeffs[i] = c % T
}
plaintext := NewPlaintext(tc.params, lvl)
plaintext := NewPlaintext(tc.Params, lvl)
plaintext.IsBatched = false
require.NoError(t, tc.encoder.Encode(coeffs, plaintext))
have := make([]uint64, tc.params.MaxSlots())
require.NoError(t, tc.encoder.Decode(plaintext, have))
require.NoError(t, tc.Ecd.Encode(coeffs, plaintext))
have := make([]uint64, tc.Params.MaxSlots())
require.NoError(t, tc.Ecd.Decode(plaintext, have))
require.True(t, slices.Equal(coeffs, have))
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Encoder/Int/IsBatched=false", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Encoder/Int/IsBatched=false", tc, lvl), func(t *testing.T) {
T := tc.params.PlaintextModulus()
T := tc.Params.PlaintextModulus()
THalf := T >> 1
poly := tc.uSampler.ReadNew()
poly := tc.Sampler.ReadNew()
coeffs := make([]int64, poly.N())
for i, c := range poly.Coeffs[0] {
c %= T
@@ -281,438 +166,468 @@ func testEncoder(tc *testContext, t *testing.T) {
}
}
plaintext := NewPlaintext(tc.params, lvl)
plaintext := NewPlaintext(tc.Params, lvl)
plaintext.IsBatched = false
require.NoError(t, tc.encoder.Encode(coeffs, plaintext))
have := make([]int64, tc.params.MaxSlots())
require.NoError(t, tc.encoder.Decode(plaintext, have))
require.NoError(t, tc.Ecd.Encode(coeffs, plaintext))
have := make([]int64, tc.Params.MaxSlots())
require.NoError(t, tc.Ecd.Decode(plaintext, have))
require.True(t, slices.Equal(coeffs, have))
})
}
}
func testEvaluator(tc *testContext, t *testing.T) {
func testEvaluator(tc *TestContext, t *testing.T) {
testLevel := [2]int{0, tc.Params.MaxLevel()}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Add/Ct/Ct/New", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevel {
t.Run(name("Evaluator/Add/Ct/Ct/New", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
ciphertext2, err := tc.evaluator.AddNew(ciphertext0, ciphertext1)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
ciphertext2, err := tc.Evl.AddNew(ciphertext0, ciphertext1)
require.NoError(t, err)
tc.ringT.Add(values0, values1, values0)
verifyTestVectors(tc, tc.decryptor, values0, ciphertext2, t)
tc.Params.RingT().Add(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Add/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevel {
t.Run(name("Evaluator/Add/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.NoError(t, tc.evaluator.Add(ciphertext0, ciphertext1, ciphertext0))
tc.ringT.Add(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Add(ciphertext0, ciphertext1, ciphertext0))
tc.Params.RingT().Add(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Add/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevel {
t.Run(name("Evaluator/Add/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0)
require.NoError(t, tc.evaluator.Add(ciphertext0, plaintext, ciphertext0))
tc.ringT.Add(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Add(ciphertext0, plaintext, ciphertext0))
tc.Params.RingT().Add(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Add/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/Add/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) {
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
scalar := tc.Params.PlaintextModulus() >> 1
scalar := tc.params.PlaintextModulus() >> 1
p := ring.Poly{Coeffs: [][]uint64{values}}
require.NoError(t, tc.evaluator.Add(ciphertext, scalar, ciphertext))
tc.ringT.AddScalar(values, scalar, values)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
require.NoError(t, tc.Evl.Add(ciphertext, scalar, ciphertext))
tc.Params.RingT().AddScalar(p, scalar, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Add/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/Add/Ct/Vector/Inplace", tc, lvl), func(t *testing.T) {
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
p := ring.Poly{Coeffs: [][]uint64{values}}
require.NoError(t, tc.evaluator.Add(ciphertext, values.Coeffs[0], ciphertext))
tc.ringT.Add(values, values, values)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
require.NoError(t, tc.Evl.Add(ciphertext, values, ciphertext))
tc.Params.RingT().Add(p, p, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Sub/Ct/Ct/New", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevel {
t.Run(name("Evaluator/Sub/Ct/Ct/New", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
ciphertext0, err := tc.evaluator.SubNew(ciphertext0, ciphertext1)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
ciphertext0, err := tc.Evl.SubNew(ciphertext0, ciphertext1)
require.NoError(t, err)
tc.ringT.Sub(values0, values1, values0)
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
tc.Params.RingT().Sub(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Sub/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevel {
t.Run(name("Evaluator/Sub/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.NoError(t, tc.evaluator.Sub(ciphertext0, ciphertext1, ciphertext0))
tc.ringT.Sub(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Sub(ciphertext0, ciphertext1, ciphertext0))
tc.Params.RingT().Sub(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Sub/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) {
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
for _, lvl := range testLevel {
t.Run(name("Evaluator/Sub/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) {
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0)
require.NoError(t, tc.evaluator.Sub(ciphertext0, plaintext, ciphertext0))
tc.ringT.Sub(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Sub(ciphertext0, plaintext, ciphertext0))
tc.Params.RingT().Sub(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Sub/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/Sub/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) {
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
scalar := tc.Params.PlaintextModulus() >> 1
scalar := tc.params.PlaintextModulus() >> 1
p := ring.Poly{Coeffs: [][]uint64{values}}
require.NoError(t, tc.evaluator.Sub(ciphertext, scalar, ciphertext))
tc.ringT.SubScalar(values, scalar, values)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
require.NoError(t, tc.Evl.Sub(ciphertext, scalar, ciphertext))
tc.Params.RingT().SubScalar(p, scalar, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Sub/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/Sub/Ct/Vector/Inplace", tc, lvl), func(t *testing.T) {
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
p := ring.Poly{Coeffs: [][]uint64{values}}
require.NoError(t, tc.evaluator.Sub(ciphertext, values.Coeffs[0], ciphertext))
tc.ringT.Sub(values, values, values)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
require.NoError(t, tc.Evl.Sub(ciphertext, values, ciphertext))
tc.Params.RingT().Sub(p, p, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Mul/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/Mul/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.NoError(t, tc.evaluator.Mul(ciphertext0, ciphertext1, ciphertext0))
tc.ringT.MulCoeffsBarrett(values0, values1, values0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Mul(ciphertext0, ciphertext1, ciphertext0))
tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Mul/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/Mul/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext.Scale.Cmp(plaintext.Scale) != 0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
require.NoError(t, tc.Evl.Mul(ciphertext, plaintext, ciphertext))
tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p0.Coeffs[0], t)
})
}
for _, lvl := range testLevel {
t.Run(name("Evaluator/Mul/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
scalar := tc.Params.PlaintextModulus() >> 1
p := ring.Poly{Coeffs: [][]uint64{values}}
require.NoError(t, tc.Evl.Mul(ciphertext, scalar, ciphertext))
tc.Params.RingT().MulScalar(p, scalar, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range testLevel {
t.Run(name("Evaluator/Mul/Ct/Vector/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0)
p := ring.Poly{Coeffs: [][]uint64{values}}
require.NoError(t, tc.evaluator.Mul(ciphertext0, plaintext, ciphertext0))
tc.ringT.MulCoeffsBarrett(values0, values1, values0)
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
require.NoError(t, tc.Evl.Mul(ciphertext, values, ciphertext))
tc.Params.RingT().MulCoeffsBarrett(p, p, p)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Mul/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/Square/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
scalar := tc.params.PlaintextModulus() >> 1
p := ring.Poly{Coeffs: [][]uint64{values}}
require.NoError(t, tc.evaluator.Mul(ciphertext, scalar, ciphertext))
tc.ringT.MulScalar(values, scalar, values)
require.NoError(t, tc.Evl.Mul(ciphertext, ciphertext, ciphertext))
tc.Params.RingT().MulCoeffsBarrett(p, p, p)
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Mul/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/MulRelin/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.NoError(t, tc.evaluator.Mul(ciphertext, values.Coeffs[0], ciphertext))
tc.ringT.MulCoeffsBarrett(values, values, values)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
verifyTestVectors(tc, tc.decryptor, values, ciphertext, t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/Square/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
require.NoError(t, tc.evaluator.Mul(ciphertext0, ciphertext0, ciphertext0))
tc.ringT.MulCoeffsBarrett(values0, values0, values0)
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/MulRelin/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
tc.ringT.MulCoeffsBarrett(values0, values1, values0)
tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0)
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
receiver := NewCiphertext(tc.params, 1, lvl)
receiver := NewCiphertext(tc.Params, 1, lvl)
require.NoError(t, tc.evaluator.MulRelin(ciphertext0, ciphertext1, receiver))
require.NoError(t, tc.Evl.MulRelin(ciphertext0, ciphertext1, receiver))
require.NoError(t, tc.Evl.Rescale(receiver, receiver))
require.NoError(t, tc.evaluator.Rescale(receiver, receiver))
verifyTestVectors(tc, tc.decryptor, values0, receiver, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, receiver, p0.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/MulThenAdd/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk)
values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2))
values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
p2 := ring.Poly{Coeffs: [][]uint64{values2}}
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0)
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, ciphertext1, ciphertext2))
tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2)
require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, ciphertext1, ciphertext2))
tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2)
verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/MulThenAdd/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values1, plaintext1, _ := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk)
values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values1, plaintext1, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2))
values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
p2 := ring.Poly{Coeffs: [][]uint64{values2}}
require.True(t, ciphertext0.Scale.Cmp(plaintext1.Scale) != 0)
require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0)
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, plaintext1, ciphertext2))
tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2)
require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, plaintext1, ciphertext2))
tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2)
verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/MulThenAdd/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
scalar := tc.params.PlaintextModulus() >> 1
scalar := tc.Params.PlaintextModulus() >> 1
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, scalar, ciphertext1))
tc.ringT.MulScalarThenAdd(values0, scalar, values1)
require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, scalar, ciphertext1))
tc.Params.RingT().MulScalarThenAdd(p0, scalar, p1)
verifyTestVectors(tc, tc.decryptor, values1, ciphertext1, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext1, p1.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/MulThenAdd/Ct/Vector/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
scale := ciphertext1.Scale
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, values1.Coeffs[0], ciphertext1))
tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values1)
require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, values1, ciphertext1))
tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p1)
// Checks that output scale isn't changed
require.True(t, scale.Equal(ciphertext1.Scale))
verifyTestVectors(tc, tc.decryptor, values1, ciphertext1, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext1, p1.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel {
t.Run(GetTestName("Evaluator/MulRelinThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel {
t.Run(name("Evaluator/MulRelinThenAdd/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values1, _, ciphertext1 := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk)
values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2))
values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7))
require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0)
require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0)
require.NoError(t, tc.evaluator.MulRelinThenAdd(ciphertext0, ciphertext1, ciphertext2))
tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
p2 := ring.Poly{Coeffs: [][]uint64{values2}}
verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t)
require.NoError(t, tc.Evl.MulRelinThenAdd(ciphertext0, ciphertext1, ciphertext2))
tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t)
})
}
for _, lvl := range tc.testLevel[:] {
t.Run(GetTestName("Evaluator/Rescale", tc.params, lvl), func(t *testing.T) {
for _, lvl := range testLevel[:] {
t.Run(name("Evaluator/Rescale", tc, lvl), func(t *testing.T) {
ringT := tc.params.RingT()
ringT := tc.Params.RingT()
values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorPk)
values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
printNoise := func(msg string, values []uint64, ct *rlwe.Ciphertext) {
pt := NewPlaintext(tc.params, ct.Level())
pt := NewPlaintext(tc.Params, ct.Level())
pt.MetaData = ciphertext0.MetaData
require.NoError(t, tc.encoder.Encode(values0.Coeffs[0], pt))
ct, err := tc.evaluator.SubNew(ct, pt)
require.NoError(t, tc.Ecd.Encode(values0, pt))
ct, err := tc.Evl.SubNew(ct, pt)
require.NoError(t, err)
vartmp, _, _ := rlwe.Norm(ct, tc.decryptor)
vartmp, _, _ := rlwe.Norm(ct, tc.Dec)
t.Logf("STD(noise) %s: %f\n", msg, vartmp)
}
if lvl != 0 {
values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk)
values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale())
if *flagPrintNoise {
printNoise("0x", values0.Coeffs[0], ciphertext0)
printNoise("0x", values0, ciphertext0)
}
for i := 0; i < lvl; i++ {
tc.evaluator.MulRelin(ciphertext0, ciphertext1, ciphertext0)
p0 := ring.Poly{Coeffs: [][]uint64{values0}}
p1 := ring.Poly{Coeffs: [][]uint64{values1}}
ringT.MulCoeffsBarrett(values0, values1, values0)
for i := 0; i < lvl; i++ {
tc.Evl.MulRelin(ciphertext0, ciphertext1, ciphertext0)
ringT.MulCoeffsBarrett(p0, p1, p0)
if *flagPrintNoise {
printNoise(fmt.Sprintf("%dx", i+1), values0.Coeffs[0], ciphertext0)
printNoise(fmt.Sprintf("%dx", i+1), p0.Coeffs[0], ciphertext0)
}
}
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
require.Nil(t, tc.evaluator.Rescale(ciphertext0, ciphertext0))
require.Nil(t, tc.Evl.Rescale(ciphertext0, ciphertext0))
verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t)
} else {
require.NotNil(t, tc.evaluator.Rescale(ciphertext0, ciphertext0))
require.NotNil(t, tc.Evl.Rescale(ciphertext0, ciphertext0))
}
})
}

View File

@@ -1,14 +0,0 @@
package bgv
var (
// testInsecure are insecure parameters used for the sole purpose of fast testing.
testInsecure = ParametersLiteral{
LogN: 10,
Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001},
P: []uint64{0x7fffffd8001},
}
testPlaintextModulus = []uint64{0x101, 0xffc001}
testParams = []ParametersLiteral{testInsecure}
)

120
schemes/bgv/test_utils.go Normal file
View File

@@ -0,0 +1,120 @@
package bgv
import (
"fmt"
"math"
"slices"
"testing"
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
type TestContext struct {
Params Parameters
Ecd *Encoder
Prng sampling.PRNG
Sampler *ring.UniformSampler
Kgen *rlwe.KeyGenerator
Sk *rlwe.SecretKey
Pk *rlwe.PublicKey
Enc *rlwe.Encryptor
Dec *rlwe.Decryptor
Evl *Evaluator
}
func NewTestContext(params ParametersLiteral) *TestContext {
tc := new(TestContext)
var err error
tc.Params, err = NewParametersFromLiteral(params)
if err != nil {
panic(err)
}
tc.Ecd = NewEncoder(tc.Params)
tc.Prng, err = sampling.NewPRNG()
if err != nil {
panic(err)
}
tc.Sampler = ring.NewUniformSampler(tc.Prng, tc.Params.RingT())
tc.Kgen = rlwe.NewKeyGenerator(tc.Params)
tc.Sk, tc.Pk = tc.Kgen.GenKeyPairNew()
tc.Enc = rlwe.NewEncryptor(tc.Params, tc.Pk)
tc.Dec = rlwe.NewDecryptor(tc.Params, tc.Sk)
tc.Evl = NewEvaluator(tc.Params, rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk)))
return tc
}
func (tc TestContext) String() string {
return fmt.Sprintf("LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d",
tc.Params.LogN(),
int(math.Round(tc.Params.LogQ())),
int(math.Round(tc.Params.LogP())),
tc.Params.LogMaxDimensions().Rows,
tc.Params.LogMaxDimensions().Cols,
int(math.Round(tc.Params.LogT())),
tc.Params.QCount(),
tc.Params.PCount())
}
func VerifyTestVectors(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, have interface{}, want []uint64, t *testing.T) {
values := make([]uint64, params.MaxSlots())
switch have := have.(type) {
case *rlwe.Plaintext:
require.NoError(t, encoder.Decode(have, values))
case *rlwe.Ciphertext:
require.NoError(t, encoder.Decode(decryptor.DecryptNew(have), values))
default:
t.Error("invalid unsupported test object type")
}
require.True(t, slices.Equal(values, want))
}
func NewTestVector(params Parameters, encoder *Encoder, encryptor *rlwe.Encryptor, level int, scale rlwe.Scale) (values []uint64, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) {
values = make([]uint64, params.MaxSlots())
for i := range values {
values[i] = sampling.RandUint64() % params.PlaintextModulus()
}
pt = NewPlaintext(params, level)
pt.Scale = scale
if err := encoder.Encode(values, pt); err != nil {
panic(err)
}
if encryptor != nil {
var err error
ct, err = encryptor.EncryptNew(pt)
if err != nil {
panic(err)
}
}
return
}
var (
// BgvTestInsecure are insecure parameters used for the sole purpose of fast testing.
TestInsecure = ParametersLiteral{
LogN: 10,
Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001},
P: []uint64{0x7fffffd8001},
}
TestPlaintextModulus = []uint64{0x101, 0xffc001}
TestParams = []ParametersLiteral{TestInsecure}
)

View File

@@ -2,7 +2,6 @@ package ckks
import (
"encoding/json"
"fmt"
"runtime"
"testing"
@@ -11,28 +10,7 @@ import (
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
func GetBenchName(params Parameters, opname string) string {
var PrecisionMod string
switch params.precisionMode {
case PREC64:
PrecisionMod = "PREC64"
case PREC128:
PrecisionMod = "PREC128"
}
return fmt.Sprintf("%s/RingType=%s/logN=%d/Qi=%d/Pi=%d/LogSlots=%d/%s",
opname,
params.RingType(),
params.LogN(),
params.QCount(),
params.PCount(),
params.LogMaxSlots(),
PrecisionMod)
}
func BenchmarkCKKS(b *testing.B) {
var err error
var testParams []ParametersLiteral
@@ -56,18 +34,9 @@ func BenchmarkCKKS(b *testing.B) {
for _, paramsLiteral := range testParams {
var params Parameters
if params, err = NewParametersFromLiteral(paramsLiteral); err != nil {
b.Error(err)
b.Fail()
}
tc := NewTestContext(paramsLiteral)
var tc *testContext
if tc, err = genTestParams(params); err != nil {
b.Fatal(err)
}
for _, testSet := range []func(tc *testContext, b *testing.B){
for _, testSet := range []func(tc *TestContext, b *testing.B){
benchEncoder,
benchEvaluator,
} {
@@ -77,13 +46,13 @@ func BenchmarkCKKS(b *testing.B) {
}
}
func benchEncoder(tc *testContext, b *testing.B) {
func benchEncoder(tc *TestContext, b *testing.B) {
encoder := tc.encoder
encoder := tc.Ecd
b.Run(GetBenchName(tc.params, "Encoder/Encode"), func(b *testing.B) {
b.Run(name("Encoder/Encode", tc), func(b *testing.B) {
pt := NewPlaintext(tc.params, tc.params.MaxLevel())
pt := NewPlaintext(tc.Params, tc.Params.MaxLevel())
values := make([]complex128, 1<<pt.LogDimensions.Cols)
for i := range values {
@@ -100,9 +69,9 @@ func benchEncoder(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(tc.params, "Encoder/Decode"), func(b *testing.B) {
b.Run(name("Encoder/Decode", tc), func(b *testing.B) {
pt := NewPlaintext(tc.params, tc.params.MaxLevel())
pt := NewPlaintext(tc.Params, tc.Params.MaxLevel())
values := make([]complex128, 1<<pt.LogDimensions.Cols)
for i := range values {
@@ -122,26 +91,26 @@ func benchEncoder(tc *testContext, b *testing.B) {
})
}
func benchEvaluator(tc *testContext, b *testing.B) {
func benchEvaluator(tc *TestContext, b *testing.B) {
params := tc.params
params := tc.Params
plaintext := NewPlaintext(params, params.MaxLevel())
plaintext.Value = rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 0, plaintext.Level()).Value[0]
plaintext.Value = rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 0, plaintext.Level()).Value[0]
vector := make([]float64, params.MaxSlots())
for i := range vector {
vector[i] = 1
}
ciphertext1 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel())
ciphertext2 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel())
ciphertext1 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, params.MaxLevel())
ciphertext2 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, params.MaxLevel())
*ciphertext1.MetaData = *plaintext.MetaData
*ciphertext2.MetaData = *plaintext.MetaData
eval := tc.evaluator.WithKey(rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk)))
b.Run(GetBenchName(params, "Evaluator/Add/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Scalar", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -152,7 +121,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Vector", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -163,7 +132,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Plaintext", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -174,7 +143,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Add/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/Add/Ciphertext", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -185,7 +154,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Scalar", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -196,7 +165,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Vector", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -207,7 +176,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Plaintext", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -218,7 +187,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Mul/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/Mul/Ciphertext", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -229,7 +198,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulRelin/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulRelin/Ciphertext", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -240,7 +209,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Scalar"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Scalar", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -251,7 +220,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Vector"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Vector", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -262,7 +231,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Plaintext"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Plaintext", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -273,7 +242,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulThenAdd/Ciphertext", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -284,7 +253,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/MulRelinThenAdd/Ciphertext"), func(b *testing.B) {
b.Run(name("Evaluator/MulRelinThenAdd/Ciphertext", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -295,7 +264,7 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Rescale"), func(b *testing.B) {
b.Run(name("Evaluator/Rescale", tc), func(b *testing.B) {
receiver := NewCiphertext(params, 1, ciphertext1.Level()-1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -306,8 +275,8 @@ func benchEvaluator(tc *testContext, b *testing.B) {
}
})
b.Run(GetBenchName(params, "Evaluator/Rotate"), func(b *testing.B) {
gk := tc.kgen.GenGaloisKeyNew(5, tc.sk)
b.Run(name("Evaluator/Rotate", tc), func(b *testing.B) {
gk := tc.Kgen.GenGaloisKeyNew(5, tc.Sk)
evk := rlwe.NewMemEvaluationKeySet(nil, gk)
eval := eval.WithKey(evk)
receiver := NewCiphertext(params, 1, ciphertext1.Level())

View File

@@ -10,9 +10,9 @@ import (
"testing"
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils/bignum"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
@@ -20,32 +20,6 @@ import (
var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short and -long.")
var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats")
func GetTestName(params Parameters, opname string) string {
return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d",
opname,
params.RingType(),
params.LogN(),
int(math.Round(params.LogQP())),
params.QCount(),
params.PCount(),
int(math.Log2(params.DefaultScale().Float64())))
}
type testContext struct {
params Parameters
ringQ *ring.Ring
ringP *ring.Ring
prng sampling.PRNG
encoder *Encoder
kgen *rlwe.KeyGenerator
sk *rlwe.SecretKey
pk *rlwe.PublicKey
encryptorPk *rlwe.Encryptor
encryptorSk *rlwe.Encryptor
decryptor *rlwe.Decryptor
evaluator *Evaluator
}
func TestCKKS(t *testing.T) {
var err error
@@ -58,7 +32,7 @@ func TestCKKS(t *testing.T) {
t.Fatal(err)
}
default:
testParams = testParamsLiteral
testParams = TestParametersLiteral
}
for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} {
@@ -71,17 +45,9 @@ func TestCKKS(t *testing.T) {
paramsLiteral.LogN = 10
}
var params Parameters
if params, err = NewParametersFromLiteral(paramsLiteral); err != nil {
t.Fatal(err)
}
tc := NewTestContext(paramsLiteral)
var tc *testContext
if tc, err = genTestParams(params); err != nil {
t.Fatal(err)
}
for _, testSet := range []func(tc *testContext, t *testing.T){
for _, testSet := range []func(tc *TestContext, t *testing.T){
testParameters,
testEncoder,
testEvaluatorAdd,
@@ -96,103 +62,11 @@ func TestCKKS(t *testing.T) {
}
}
}
}
func genTestParams(defaultParam Parameters) (tc *testContext, err error) {
func testParameters(tc *TestContext, t *testing.T) {
tc = new(testContext)
tc.params = defaultParam
tc.kgen = NewKeyGenerator(tc.params)
tc.sk, tc.pk = tc.kgen.GenKeyPairNew()
tc.ringQ = defaultParam.RingQ()
if tc.params.PCount() != 0 {
tc.ringP = defaultParam.RingP()
}
if tc.prng, err = sampling.NewPRNG(); err != nil {
return nil, err
}
tc.encoder = NewEncoder(tc.params)
tc.encryptorPk = NewEncryptor(tc.params, tc.pk)
tc.encryptorSk = NewEncryptor(tc.params, tc.sk)
tc.decryptor = NewDecryptor(tc.params, tc.sk)
tc.evaluator = NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk)))
return tc, nil
}
func newTestVectors(tc *testContext, encryptor *rlwe.Encryptor, a, b complex128, t *testing.T) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) {
prec := tc.encoder.Prec()
pt = NewPlaintext(tc.params, tc.params.MaxLevel())
values = make([]*bignum.Complex, pt.Slots())
switch tc.params.RingType() {
case ring.Standard:
for i := range values {
values[i] = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec),
}
}
case ring.ConjugateInvariant:
for i := range values {
values[i] = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
new(big.Float),
}
}
default:
t.Fatal("invalid ring type")
}
tc.encoder.Encode(values, pt)
if encryptor != nil {
var err error
ct, err = encryptor.EncryptNew(pt)
if err != nil {
panic(err)
}
}
return values, pt, ct
}
func randomConst(tp ring.Type, prec uint, a, b complex128) (constant *bignum.Complex) {
switch tp {
case ring.Standard:
constant = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec),
}
case ring.ConjugateInvariant:
constant = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
new(big.Float),
}
default:
panic("invalid ring type")
}
return
}
func testParameters(tc *testContext, t *testing.T) {
t.Run(GetTestName(tc.params, "Parameters/NewParameters"), func(t *testing.T) {
t.Run(name("Parameters/NewParameters", tc), func(t *testing.T) {
params, err := NewParametersFromLiteral(ParametersLiteral{
LogN: 4,
LogQ: []int{60, 60},
@@ -205,32 +79,32 @@ func testParameters(tc *testContext, t *testing.T) {
require.Equal(t, rlwe.DefaultXs, params.Xs())
})
t.Run(GetTestName(tc.params, "Parameters/StandardRing"), func(t *testing.T) {
params, err := tc.params.StandardParameters()
switch tc.params.RingType() {
t.Run(name("Parameters/StandardRing", tc), func(t *testing.T) {
params, err := tc.Params.StandardParameters()
switch tc.Params.RingType() {
case ring.Standard:
require.True(t, params.Equal(&tc.params))
require.True(t, params.Equal(&tc.Params))
require.NoError(t, err)
case ring.ConjugateInvariant:
require.Equal(t, params.LogN(), tc.params.LogN()+1)
require.Equal(t, params.LogN(), tc.Params.LogN()+1)
require.NoError(t, err)
default:
t.Fatal("invalid RingType")
}
})
t.Run(GetTestName(tc.params, "Parameters/Marshaller/Binary"), func(t *testing.T) {
t.Run(name("Parameters/Marshaller/Binary", tc), func(t *testing.T) {
bytes, err := tc.params.MarshalBinary()
bytes, err := tc.Params.MarshalBinary()
require.Nil(t, err)
var p Parameters
require.Nil(t, p.UnmarshalBinary(bytes))
require.True(t, tc.params.Equal(&p))
require.True(t, tc.Params.Equal(&p))
})
t.Run(GetTestName(tc.params, "Parameters/Marshaller/JSON"), func(t *testing.T) {
t.Run(name("Parameters/Marshaller/JSON", tc), func(t *testing.T) {
// checks that parameters can be marshalled without error
data, err := json.Marshal(tc.params)
data, err := json.Marshal(tc.Params)
require.Nil(t, err)
require.NotNil(t, data)
@@ -238,10 +112,10 @@ func testParameters(tc *testContext, t *testing.T) {
var paramsRec Parameters
err = json.Unmarshal(data, &paramsRec)
require.Nil(t, err)
require.True(t, tc.params.Equal(&paramsRec))
require.True(t, tc.Params.Equal(&paramsRec))
// checks that ckks.Parameters can be unmarshalled with log-moduli definition without error
dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "LogDefaultScale":30}`, tc.params.LogN()))
dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "LogDefaultScale":30}`, tc.Params.LogN()))
var paramsWithLogModuli Parameters
err = json.Unmarshal(dataWithLogModuli, &paramsWithLogModuli)
require.Nil(t, err)
@@ -252,7 +126,7 @@ func testParameters(tc *testContext, t *testing.T) {
require.Equal(t, float64(1<<30), paramsWithLogModuli.DefaultScale().Float64())
// checks that ckks.Parameters can be unmarshalled with log-moduli definition with empty P without error
dataWithLogModuliNoP := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[], "RingType": "ConjugateInvariant"}`, tc.params.LogN()))
dataWithLogModuliNoP := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[], "RingType": "ConjugateInvariant"}`, tc.Params.LogN()))
var paramsWithLogModuliNoP Parameters
err = json.Unmarshal(dataWithLogModuliNoP, &paramsWithLogModuliNoP)
require.Nil(t, err)
@@ -261,7 +135,7 @@ func testParameters(tc *testContext, t *testing.T) {
require.Equal(t, ring.ConjugateInvariant, paramsWithLogModuliNoP.RingType())
// checks that one can provide custom parameters for the secret-key and error distributions
dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.params.LogN()))
dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.Params.LogN()))
var paramsWithCustomSecrets Parameters
err = json.Unmarshal(dataWithCustomSecrets, &paramsWithCustomSecrets)
require.Nil(t, err)
@@ -270,24 +144,24 @@ func testParameters(tc *testContext, t *testing.T) {
})
}
func testEncoder(tc *testContext, t *testing.T) {
func testEncoder(tc *TestContext, t *testing.T) {
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true"), func(t *testing.T) {
t.Run(name("Encoder/IsBatched=true", tc), func(t *testing.T) {
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
values, plaintext, _ := tc.NewTestVector(-1-1i, 1+1i)
VerifyTestVectors(tc.params, tc.encoder, nil, values, plaintext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, nil, values, plaintext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
logprec := float64(tc.params.LogDefaultScale()) / 2
logprec := float64(tc.Params.LogDefaultScale()) / 2
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true/DecodePublic/[]float64"), func(t *testing.T) {
t.Run(name("Encoder/IsBatched=true/DecodePublic/[]float64", tc), func(t *testing.T) {
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
values, plaintext, _ := tc.NewTestVector(-1-1i, 1+1i)
have := make([]float64, len(values))
require.NoError(t, tc.encoder.DecodePublic(plaintext, have, logprec))
require.NoError(t, tc.Ecd.DecodePublic(plaintext, have, logprec))
want := make([]float64, len(values))
for i := range want {
@@ -299,15 +173,15 @@ func testEncoder(tc *testContext, t *testing.T) {
require.GreaterOrEqual(t, StandardDeviation(want, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
})
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true/DecodePublic/[]complex128"), func(t *testing.T) {
t.Run(name("Encoder/IsBatched=true/DecodePublic/[]complex128", tc), func(t *testing.T) {
if tc.params.RingType() == ring.ConjugateInvariant {
if tc.Params.RingType() == ring.ConjugateInvariant {
t.Skip()
}
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
values, plaintext, _ := tc.NewTestVector(-1-1i, 1+1i)
have := make([]complex128, len(values))
require.NoError(t, tc.encoder.DecodePublic(plaintext, have, logprec))
require.NoError(t, tc.Ecd.DecodePublic(plaintext, have, logprec))
wantReal := make([]float64, len(values))
wantImag := make([]float64, len(values))
@@ -325,10 +199,10 @@ func testEncoder(tc *testContext, t *testing.T) {
require.GreaterOrEqual(t, StandardDeviation(wantImag, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
})
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true/DecodePublic/[]big.Float"), func(t *testing.T) {
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
t.Run(name("Encoder/IsBatched=true/DecodePublic/[]big.Float", tc), func(t *testing.T) {
values, plaintext, _ := tc.NewTestVector(-1-1i, 1+1i)
have := make([]*big.Float, len(values))
require.NoError(t, tc.encoder.DecodePublic(plaintext, have, logprec))
require.NoError(t, tc.Ecd.DecodePublic(plaintext, have, logprec))
want := make([]*big.Float, len(values))
for i := range want {
@@ -339,13 +213,13 @@ func testEncoder(tc *testContext, t *testing.T) {
require.GreaterOrEqual(t, StandardDeviation(want, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
})
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true/DecodePublic/[]bignum.Complex"), func(t *testing.T) {
if tc.params.RingType() == ring.ConjugateInvariant {
t.Run(name("Encoder/IsBatched=true/DecodePublic/[]bignum.Complex", tc), func(t *testing.T) {
if tc.Params.RingType() == ring.ConjugateInvariant {
t.Skip()
}
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
values, plaintext, _ := tc.NewTestVector(-1-1i, 1+1i)
have := make([]*bignum.Complex, len(values))
require.NoError(t, tc.encoder.DecodePublic(plaintext, have, logprec))
require.NoError(t, tc.Ecd.DecodePublic(plaintext, have, logprec))
wantReal := make([]*big.Float, len(values))
wantImag := make([]*big.Float, len(values))
@@ -360,9 +234,9 @@ func testEncoder(tc *testContext, t *testing.T) {
require.GreaterOrEqual(t, StandardDeviation(wantImag, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
})
t.Run(GetTestName(tc.params, "Encoder/IsBatched=false"), func(t *testing.T) {
t.Run(name("Encoder/IsBatched=false", tc), func(t *testing.T) {
slots := tc.params.N()
slots := tc.Params.N()
valuesWant := make([]float64, slots)
@@ -372,14 +246,14 @@ func testEncoder(tc *testContext, t *testing.T) {
valuesWant[0] = 0.607538
pt := NewPlaintext(tc.params, tc.params.MaxLevel())
pt := NewPlaintext(tc.Params, tc.Params.MaxLevel())
pt.IsBatched = false
tc.encoder.Encode(valuesWant, pt)
tc.Ecd.Encode(valuesWant, pt)
valuesTest := make([]float64, len(valuesWant))
tc.encoder.Decode(pt, valuesTest)
tc.Ecd.Decode(pt, valuesTest)
var meanprec float64
@@ -393,7 +267,7 @@ func testEncoder(tc *testContext, t *testing.T) {
t.Logf("\nMean precision : %.2f \n", math.Log2(1/meanprec))
}
minPrec := math.Log2(tc.params.DefaultScale().Float64()) - float64(tc.params.LogN()+2)
minPrec := math.Log2(tc.Params.DefaultScale().Float64()) - float64(tc.Params.LogN()+2)
if minPrec < 0 {
minPrec = 0
}
@@ -401,12 +275,12 @@ func testEncoder(tc *testContext, t *testing.T) {
require.GreaterOrEqual(t, math.Log2(1/meanprec), minPrec)
// Also tests at level 0
pt = NewPlaintext(tc.params, tc.params.LevelsConsumedPerRescaling()-1)
pt = NewPlaintext(tc.Params, tc.Params.LevelsConsumedPerRescaling()-1)
pt.IsBatched = false
tc.encoder.Encode(valuesWant, pt)
tc.Ecd.Encode(valuesWant, pt)
tc.encoder.Decode(pt, valuesTest)
tc.Ecd.Decode(pt, valuesTest)
meanprec = 0
for i := range valuesWant {
@@ -423,116 +297,116 @@ func testEncoder(tc *testContext, t *testing.T) {
})
}
func testEvaluatorAdd(tc *testContext, t *testing.T) {
func testEvaluatorAdd(tc *TestContext, t *testing.T) {
t.Run(GetTestName(tc.params, "Evaluator/AddNew/Ct"), func(t *testing.T) {
t.Run(name("Evaluator/AddNew/Ct", tc), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
for i := range values1 {
values1[i].Add(values1[i], values2[i])
}
ciphertext3, err := tc.evaluator.AddNew(ciphertext1, ciphertext2)
ciphertext3, err := tc.Evl.AddNew(ciphertext1, ciphertext2)
require.NoError(t, err)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext3, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext3, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Add/Ct"), func(t *testing.T) {
t.Run(name("Evaluator/Add/Ct", tc), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
for i := range values1 {
values1[i].Add(values1[i], values2[i])
}
require.NoError(t, tc.evaluator.Add(ciphertext1, ciphertext2, ciphertext1))
require.NoError(t, tc.Evl.Add(ciphertext1, ciphertext2, ciphertext1))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Add/Pt"), func(t *testing.T) {
t.Run(name("Evaluator/Add/Pt", tc), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, plaintext2, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, plaintext2, _ := tc.NewTestVector(-1-1i, 1+1i)
for i := range values1 {
values1[i].Add(values1[i], values2[i])
}
require.NoError(t, tc.evaluator.Add(ciphertext1, plaintext2, ciphertext1))
require.NoError(t, tc.Evl.Add(ciphertext1, plaintext2, ciphertext1))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Add/Scalar"), func(t *testing.T) {
t.Run(name("Evaluator/Add/Scalar", tc), func(t *testing.T) {
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
constant := randomConst(tc.params.RingType(), tc.encoder.Prec(), -1+1i, -1+1i)
constant := randomConst(tc.Params.RingType(), tc.Ecd.Prec(), -1+1i, -1+1i)
for i := range values {
values[i].Add(values[i], constant)
}
require.NoError(t, tc.evaluator.Add(ciphertext, constant, ciphertext))
require.NoError(t, tc.Evl.Add(ciphertext, constant, ciphertext))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Add/Vector"), func(t *testing.T) {
t.Run(name("Evaluator/Add/Vector", tc), func(t *testing.T) {
values1, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
values2, _, _ := tc.NewTestVector(-1-1i, 1+1i)
for i := range values1 {
values1[i].Add(values1[i], values2[i])
}
require.NoError(t, tc.evaluator.Add(ciphertext, values2, ciphertext))
require.NoError(t, tc.Evl.Add(ciphertext, values2, ciphertext))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
}
func testEvaluatorSub(tc *testContext, t *testing.T) {
func testEvaluatorSub(tc *TestContext, t *testing.T) {
t.Run(GetTestName(tc.params, "Evaluator/SubNew/Ct"), func(t *testing.T) {
t.Run(name("Evaluator/SubNew/Ct", tc), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
for i := range values1 {
values1[i].Sub(values1[i], values2[i])
}
ciphertext3, err := tc.evaluator.SubNew(ciphertext1, ciphertext2)
ciphertext3, err := tc.Evl.SubNew(ciphertext1, ciphertext2)
require.NoError(t, err)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext3, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext3, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Sub/Ct"), func(t *testing.T) {
t.Run(name("Evaluator/Sub/Ct", tc), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
for i := range values1 {
values1[i].Sub(values1[i], values2[i])
}
require.NoError(t, tc.evaluator.Sub(ciphertext1, ciphertext2, ciphertext1))
require.NoError(t, tc.Evl.Sub(ciphertext1, ciphertext2, ciphertext1))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Sub/Pt"), func(t *testing.T) {
t.Run(name("Evaluator/Sub/Pt", tc), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, plaintext2, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, plaintext2, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
valuesTest := make([]*bignum.Complex, len(values1))
for i := range values1 {
@@ -540,96 +414,96 @@ func testEvaluatorSub(tc *testContext, t *testing.T) {
valuesTest[i].Sub(values1[i], values2[i])
}
require.NoError(t, tc.evaluator.Sub(ciphertext1, plaintext2, ciphertext2))
require.NoError(t, tc.Evl.Sub(ciphertext1, plaintext2, ciphertext2))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, valuesTest, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, valuesTest, ciphertext2, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Sub/Scalar"), func(t *testing.T) {
t.Run(name("Evaluator/Sub/Scalar", tc), func(t *testing.T) {
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
constant := randomConst(tc.params.RingType(), tc.encoder.Prec(), -1+1i, -1+1i)
constant := randomConst(tc.Params.RingType(), tc.Ecd.Prec(), -1+1i, -1+1i)
for i := range values {
values[i].Sub(values[i], constant)
}
require.NoError(t, tc.evaluator.Sub(ciphertext, constant, ciphertext))
require.NoError(t, tc.Evl.Sub(ciphertext, constant, ciphertext))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Sub/Vector"), func(t *testing.T) {
t.Run(name("Evaluator/Sub/Vector", tc), func(t *testing.T) {
values1, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
values2, _, _ := tc.NewTestVector(-1-1i, 1+1i)
for i := range values1 {
values1[i].Sub(values1[i], values2[i])
}
require.NoError(t, tc.evaluator.Sub(ciphertext, values2, ciphertext))
require.NoError(t, tc.Evl.Sub(ciphertext, values2, ciphertext))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
}
func testEvaluatorRescale(tc *testContext, t *testing.T) {
func testEvaluatorRescale(tc *TestContext, t *testing.T) {
t.Run(GetTestName(tc.params, "Evaluator/RescaleTo/Single"), func(t *testing.T) {
t.Run(name("Evaluator/RescaleTo/Single", tc), func(t *testing.T) {
if tc.params.MaxLevel() < 2 {
if tc.Params.MaxLevel() < 2 {
t.Skip("skipping test for params max level < 2")
}
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
constant := tc.ringQ.SubRings[ciphertext.Level()].Modulus
constant := tc.Params.RingQ().SubRings[ciphertext.Level()].Modulus
require.NoError(t, tc.evaluator.Mul(ciphertext, constant, ciphertext))
require.NoError(t, tc.Evl.Mul(ciphertext, constant, ciphertext))
ciphertext.Scale = ciphertext.Scale.Mul(rlwe.NewScale(constant))
if err := tc.evaluator.RescaleTo(ciphertext, tc.params.DefaultScale(), ciphertext); err != nil {
if err := tc.Evl.RescaleTo(ciphertext, tc.Params.DefaultScale(), ciphertext); err != nil {
t.Fatal(err)
}
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/RescaleTo/Many"), func(t *testing.T) {
t.Run(name("Evaluator/RescaleTo/Many", tc), func(t *testing.T) {
if tc.params.MaxLevel() < 2 {
if tc.Params.MaxLevel() < 2 {
t.Skip("skipping test for params max level < 2")
}
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
nbRescales := tc.params.MaxLevel()
nbRescales := tc.Params.MaxLevel()
if nbRescales > 5 {
nbRescales = 5
}
for i := 0; i < nbRescales; i++ {
constant := tc.ringQ.SubRings[ciphertext.Level()-i].Modulus
require.NoError(t, tc.evaluator.Mul(ciphertext, constant, ciphertext))
constant := tc.Params.RingQ().SubRings[ciphertext.Level()-i].Modulus
require.NoError(t, tc.Evl.Mul(ciphertext, constant, ciphertext))
ciphertext.Scale = ciphertext.Scale.Mul(rlwe.NewScale(constant))
}
if err := tc.evaluator.RescaleTo(ciphertext, tc.params.DefaultScale(), ciphertext); err != nil {
if err := tc.Evl.RescaleTo(ciphertext, tc.Params.DefaultScale(), ciphertext); err != nil {
t.Fatal(err)
}
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
}
func testEvaluatorMul(tc *testContext, t *testing.T) {
func testEvaluatorMul(tc *TestContext, t *testing.T) {
t.Run(GetTestName(tc.params, "Evaluator/MulNew/Ct/Pt"), func(t *testing.T) {
t.Run(name("Evaluator/MulNew/Ct/Pt", tc), func(t *testing.T) {
values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, plaintext1, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
mul := bignum.NewComplexMultiplier()
@@ -637,17 +511,17 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
mul.Mul(values1[i], values1[i], values1[i])
}
ciphertext2, err := tc.evaluator.MulNew(ciphertext1, plaintext1)
ciphertext2, err := tc.Evl.MulNew(ciphertext1, plaintext1)
require.NoError(t, err)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext2, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Scalar"), func(t *testing.T) {
t.Run(name("Evaluator/Mul/Ct/Scalar", tc), func(t *testing.T) {
values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
constant := randomConst(tc.params.RingType(), tc.encoder.Prec(), -1+1i, -1+1i)
constant := randomConst(tc.Params.RingType(), tc.Ecd.Prec(), -1+1i, -1+1i)
mul := bignum.NewComplexMultiplier()
@@ -655,15 +529,15 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
mul.Mul(values[i], constant, values[i])
}
require.NoError(t, tc.evaluator.Mul(ciphertext, constant, ciphertext))
require.NoError(t, tc.Evl.Mul(ciphertext, constant, ciphertext))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Vector"), func(t *testing.T) {
t.Run(name("Evaluator/Mul/Ct/Vector", tc), func(t *testing.T) {
values1, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i)
values2, _, _ := tc.NewTestVector(-1-1i, 1+1i)
mul := bignum.NewComplexMultiplier()
@@ -671,14 +545,14 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
mul.Mul(values1[i], values2[i], values1[i])
}
tc.evaluator.Mul(ciphertext, values2, ciphertext)
tc.Evl.Mul(ciphertext, values2, ciphertext)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Pt"), func(t *testing.T) {
t.Run(name("Evaluator/Mul/Ct/Pt", tc), func(t *testing.T) {
values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, plaintext1, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
mul := bignum.NewComplexMultiplier()
@@ -686,15 +560,15 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
mul.Mul(values1[i], values1[i], values1[i])
}
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1))
require.NoError(t, tc.Evl.MulRelin(ciphertext1, plaintext1, ciphertext1))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Ct/Degree0"), func(t *testing.T) {
t.Run(name("Evaluator/Mul/Ct/Ct/Degree0", tc), func(t *testing.T) {
values1, plaintext1, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, plaintext1, _ := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
mul := bignum.NewComplexMultiplier()
@@ -706,16 +580,16 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
ciphertext1.Value = []ring.Poly{plaintext1.Value}
ciphertext1.MetaData = plaintext1.MetaData.CopyNew()
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1))
require.NoError(t, tc.Evl.MulRelin(ciphertext1, ciphertext2, ciphertext1))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values2, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/MulRelin/Ct/Ct"), func(t *testing.T) {
t.Run(name("Evaluator/MulRelin/Ct/Ct", tc), func(t *testing.T) {
// op0 <- op0 * op1
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
mul := bignum.NewComplexMultiplier()
@@ -723,44 +597,44 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
mul.Mul(values1[i], values2[i], values1[i])
}
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1))
require.NoError(t, tc.Evl.MulRelin(ciphertext1, ciphertext2, ciphertext1))
require.Equal(t, ciphertext1.Degree(), 1)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
// op1 <- op0 * op1
values1, _, ciphertext1 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 = tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 = tc.NewTestVector(-1-1i, 1+1i)
for i := range values1 {
mul.Mul(values2[i], values1[i], values2[i])
}
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext2))
require.NoError(t, tc.Evl.MulRelin(ciphertext1, ciphertext2, ciphertext2))
require.Equal(t, ciphertext2.Degree(), 1)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values2, ciphertext2, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
// op0 <- op0 * op0
for i := range values1 {
mul.Mul(values1[i], values1[i], values1[i])
}
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext1, ciphertext1))
require.NoError(t, tc.Evl.MulRelin(ciphertext1, ciphertext1, ciphertext1))
require.Equal(t, ciphertext1.Degree(), 1)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
}
func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
func testEvaluatorMulThenAdd(tc *TestContext, t *testing.T) {
t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Scalar"), func(t *testing.T) {
t.Run(name("Evaluator/MulThenAdd/Scalar", tc), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
constant := randomConst(tc.params.RingType(), tc.encoder.Prec(), -1+1i, -1+1i)
constant := randomConst(tc.Params.RingType(), tc.Ecd.Prec(), -1+1i, -1+1i)
mul := bignum.NewComplexMultiplier()
@@ -773,17 +647,17 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
values2[i].Add(values2[i], tmp)
}
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext1, constant, ciphertext2))
require.NoError(t, tc.Evl.MulThenAdd(ciphertext1, constant, ciphertext2))
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values2, ciphertext2, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Vector"), func(t *testing.T) {
t.Run(name("Evaluator/MulThenAdd/Vector", tc), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1, 1, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1, 1, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext2, values1, ciphertext1))
require.NoError(t, tc.Evl.MulThenAdd(ciphertext2, values1, ciphertext1))
mul := bignum.NewComplexMultiplier()
@@ -798,13 +672,13 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
require.Equal(t, ciphertext1.Degree(), 1)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Pt"), func(t *testing.T) {
t.Run(name("Evaluator/MulThenAdd/Pt", tc), func(t *testing.T) {
values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1, 1, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1, 1, t)
values1, plaintext1, ciphertext1 := tc.NewTestVector(-1, 1)
values2, _, ciphertext2 := tc.NewTestVector(-1, 1)
mul := bignum.NewComplexMultiplier()
@@ -817,18 +691,18 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
values1[i].Add(values1[i], tmp)
}
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext2, plaintext1, ciphertext1))
require.NoError(t, tc.Evl.MulThenAdd(ciphertext2, plaintext1, ciphertext1))
require.Equal(t, ciphertext1.Degree(), 1)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
t.Run(GetTestName(tc.params, "Evaluator/MulRelinThenAdd/Ct"), func(t *testing.T) {
t.Run(name("Evaluator/MulRelinThenAdd/Ct", tc), func(t *testing.T) {
// opOut = opOut + op1 * op0
values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i)
mul := bignum.NewComplexMultiplier()
@@ -836,23 +710,23 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
mul.Mul(values1[i], values2[i], values2[i])
}
ciphertext3 := NewCiphertext(tc.params, 2, ciphertext1.Level())
ciphertext3 := NewCiphertext(tc.Params, 2, ciphertext1.Level())
ciphertext3.Scale = ciphertext1.Scale.Mul(ciphertext2.Scale)
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext1, ciphertext2, ciphertext3))
require.NoError(t, tc.Evl.MulThenAdd(ciphertext1, ciphertext2, ciphertext3))
require.Equal(t, ciphertext3.Degree(), 2)
require.NoError(t, tc.evaluator.Relinearize(ciphertext3, ciphertext3))
require.NoError(t, tc.Evl.Relinearize(ciphertext3, ciphertext3))
require.Equal(t, ciphertext3.Degree(), 1)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext3, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values2, ciphertext3, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
// op1 = op1 + op0*op0
values1, _, ciphertext1 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values2, _, ciphertext2 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values1, _, ciphertext1 = tc.NewTestVector(-1-1i, 1+1i)
values2, _, ciphertext2 = tc.NewTestVector(-1-1i, 1+1i)
tmp := bignum.NewComplex()
for i := range values1 {
@@ -860,23 +734,23 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
values1[i].Add(values1[i], tmp)
}
require.NoError(t, tc.evaluator.MulRelinThenAdd(ciphertext2, ciphertext2, ciphertext1))
require.NoError(t, tc.Evl.MulRelinThenAdd(ciphertext2, ciphertext2, ciphertext1))
require.Equal(t, ciphertext1.Degree(), 1)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
}
func testBridge(tc *testContext, t *testing.T) {
func testBridge(tc *TestContext, t *testing.T) {
t.Run(GetTestName(tc.params, "Bridge"), func(t *testing.T) {
t.Run(name("Bridge", tc), func(t *testing.T) {
if tc.params.RingType() != ring.ConjugateInvariant {
if tc.Params.RingType() != ring.ConjugateInvariant {
t.Skip("only tested for params.RingType() == ring.ConjugateInvariant")
}
ciParams := tc.params
ciParams := tc.Params
var err error
if _, err = ciParams.StandardParameters(); err != nil {
t.Fatalf("all Conjugate Invariant parameters should have a standard counterpart but got: %f", err)
@@ -896,7 +770,7 @@ func testBridge(tc *testContext, t *testing.T) {
stdEncoder := NewEncoder(stdParams)
stdEvaluator := NewEvaluator(stdParams, nil)
evkCtR, evkRtC := stdKeyGen.GenEvaluationKeysForRingSwapNew(stdSK, tc.sk)
evkCtR, evkRtC := stdKeyGen.GenEvaluationKeysForRingSwapNew(stdSK, tc.Sk)
switcher, err := NewDomainSwitcher(stdParams, evkCtR, evkRtC)
if err != nil {
@@ -905,13 +779,13 @@ func testBridge(tc *testContext, t *testing.T) {
evalStandar := NewEvaluator(stdParams, nil)
values, _, ctCI := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
values, _, ctCI := tc.NewTestVector(-1-1i, 1+1i)
stdCTHave := NewCiphertext(stdParams, ctCI.Degree(), ctCI.Level())
switcher.RealToComplex(evalStandar, ctCI, stdCTHave)
VerifyTestVectors(stdParams, stdEncoder, stdDecryptor, values, stdCTHave, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(stdParams, stdEncoder, stdDecryptor, values, stdCTHave, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
stdCTImag, err := stdEvaluator.MulNew(stdCTHave, 1i)
require.NoError(t, err)
@@ -920,6 +794,27 @@ func testBridge(tc *testContext, t *testing.T) {
ciCTHave := NewCiphertext(ciParams, 1, stdCTHave.Level())
switcher.ComplexToReal(evalStandar, stdCTHave, ciCTHave)
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciCTHave, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciCTHave, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t)
})
}
func name(opname string, tc *TestContext) string {
var precMode string
switch tc.Params.precisionMode {
case PREC64:
precMode = "PREC64"
case PREC128:
precMode = "PREC128"
}
return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d/PrecMode=%s",
opname,
tc.Params.RingType(),
tc.Params.LogN(),
int(math.Round(tc.Params.LogQP())),
tc.Params.QCount(),
tc.Params.PCount(),
int(math.Log2(tc.Params.DefaultScale().Float64())),
precMode)
}

View File

@@ -1,48 +0,0 @@
package ckks
var (
// testInsecurePrec45 are insecure parameters used for the sole purpose of fast testing.
testInsecurePrec45 = ParametersLiteral{
LogN: 10,
Q: []uint64{
0x80000000080001,
0x2000000a0001,
0x2000000e0001,
0x2000001d0001,
0x1fffffcf0001,
0x1fffffc20001,
0x200000440001,
},
P: []uint64{
0x80000000130001,
0x7fffffffe90001,
},
LogDefaultScale: 45,
}
// testInsecurePrec90 are insecure parameters used for the sole purpose of fast testing.
testInsecurePrec90 = ParametersLiteral{
LogN: 10,
Q: []uint64{
0x80000000080001,
0x80000000440001,
0x2000000a0001,
0x2000000e0001,
0x1fffffc20001,
0x200000440001,
0x200000500001,
0x200000620001,
0x1fffff980001,
0x2000006a0001,
0x1fffff7e0001,
0x200000860001,
},
P: []uint64{
0xffffffffffc0001,
0x10000000006e0001,
},
LogDefaultScale: 90,
}
testParamsLiteral = []ParametersLiteral{testInsecurePrec45, testInsecurePrec90}
)

174
schemes/ckks/test_utils.go Normal file
View File

@@ -0,0 +1,174 @@
package ckks
import (
"math/big"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils/bignum"
"github.com/tuneinsight/lattigo/v5/utils/sampling"
)
type TestContext struct {
Params Parameters
Ecd *Encoder
Prng sampling.PRNG
Kgen *rlwe.KeyGenerator
Sk *rlwe.SecretKey
Pk *rlwe.PublicKey
Enc *rlwe.Encryptor
Dec *rlwe.Decryptor
Evl *Evaluator
}
func NewTestContext(params ParametersLiteral) *TestContext {
tc := new(TestContext)
var err error
tc.Params, err = NewParametersFromLiteral(params)
if err != nil {
panic(err)
}
tc.Ecd = NewEncoder(tc.Params)
tc.Prng, err = sampling.NewPRNG()
if err != nil {
panic(err)
}
tc.Kgen = rlwe.NewKeyGenerator(tc.Params)
tc.Sk, tc.Pk = tc.Kgen.GenKeyPairNew()
tc.Enc = rlwe.NewEncryptor(tc.Params, tc.Pk)
tc.Dec = rlwe.NewDecryptor(tc.Params, tc.Sk)
tc.Evl = NewEvaluator(tc.Params, rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk)))
return tc
}
func (tc *TestContext) NewTestVector(a, b complex128) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) {
prec := tc.Ecd.Prec()
pt = NewPlaintext(tc.Params, tc.Params.MaxLevel())
values = make([]*bignum.Complex, pt.Slots())
switch tc.Params.RingType() {
case ring.Standard:
for i := range values {
values[i] = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec),
}
}
case ring.ConjugateInvariant:
for i := range values {
values[i] = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
new(big.Float),
}
}
default:
panic("unsupported ring type")
}
var err error
if err = tc.Ecd.Encode(values, pt); err != nil {
panic(err)
}
ct, err = tc.Enc.EncryptNew(pt)
if err != nil {
panic(err)
}
return values, pt, ct
}
func randomConst(tp ring.Type, prec uint, a, b complex128) (constant *bignum.Complex) {
switch tp {
case ring.Standard:
constant = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec),
}
case ring.ConjugateInvariant:
constant = &bignum.Complex{
bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec),
new(big.Float),
}
default:
panic("invalid ring type")
}
return
}
var (
// testInsecurePrec45 are insecure parameters used for the sole purpose of fast testing.
TestInsecurePrec45 = ParametersLiteral{
LogN: 10,
LogQ: []int{55, 45, 45, 45, 45, 45, 45},
LogP: []int{60},
LogDefaultScale: 45,
}
// testInsecurePrec90 are insecure parameters used for the sole purpose of fast testing.
TestInsecurePrec90 = ParametersLiteral{
LogN: 10,
LogQ: []int{55, 55, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45},
LogP: []int{60, 60},
LogDefaultScale: 90,
}
TestParametersLiteral = []ParametersLiteral{TestInsecurePrec45, TestInsecurePrec90}
)
// testInsecurePrec45 = ParametersLiteral{
// LogN: 10,
// Q: []uint64{
// 0x80000000080001,
// 0x2000000a0001,
// 0x2000000e0001,
// 0x2000001d0001,
// 0x1fffffcf0001,
// 0x1fffffc20001,
// 0x200000440001,
// },
// P: []uint64{
// 0x80000000130001,
// 0x7fffffffe90001,
// },
// LogDefaultScale: 45,
// }
// testInsecurePrec90 = ParametersLiteral{
// LogN: 10,
// Q: []uint64{
// 0x80000000080001,
// 0x80000000440001,
// 0x2000000a0001,
// 0x2000000e0001,
// 0x1fffffc20001,
// 0x200000440001,
// 0x200000500001,
// 0x200000620001,
// 0x1fffff980001,
// 0x2000006a0001,
// 0x1fffff7e0001,
// 0x200000860001,
// },
// P: []uint64{
// 0xffffffffffc0001,
// 0x10000000006e0001,
// },
// LogDefaultScale: 90,
// }
// testParamsLiteral = []ParametersLiteral{testInsecurePrec45, testInsecurePrec90}

View File

@@ -1 +0,0 @@
package schemes

View File

@@ -1,39 +0,0 @@
package schemes
import (
"github.com/tuneinsight/lattigo/v5/schemes/bgv"
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
)
var (
// BgvTestInsecure are insecure parameters used for the sole purpose of fast testing.
BgvTestInsecure = bgv.ParametersLiteral{
LogN: 10,
Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001},
P: []uint64{0x7fffffd8001},
}
BgvTestPlaintextModulus = []uint64{0x101, 0xffc001}
BgvTestParams = []bgv.ParametersLiteral{BgvTestInsecure}
)
var (
// BgvTestInsecurePrec45 are insecure parameters used for the sole purpose of fast testing.
CkksTestInsecurePrec45 = ckks.ParametersLiteral{
LogN: 10,
LogQ: []int{55, 45, 45, 45, 45, 45, 45},
LogP: []int{60},
LogDefaultScale: 45,
}
// BgvTestInsecurePrec90 are insecure parameters used for the sole purpose of fast testing.
CkksTestInsecurePrec90 = ckks.ParametersLiteral{
LogN: 10,
LogQ: []int{55, 55, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45},
LogP: []int{60, 60},
LogDefaultScale: 90,
}
CkksTestParametersLiteral = []ckks.ParametersLiteral{CkksTestInsecurePrec45, CkksTestInsecurePrec90}
)