From a315439b5bbf2bfbca3cf0fcb37513a14e2e2dcf Mon Sep 17 00:00:00 2001 From: Andrea Caforio Date: Wed, 3 Jul 2024 16:58:05 +0200 Subject: [PATCH] refactor test utilities Add `BFV`, `BGV` and `CKKS` test utilities that can be reused throughout the entirety of Lattigo. This reduces code duplication and should simplify the creation of future unit tests. --- .../ltfloat/linear_transformation_test.go | 165 +---- .../ltint/linear_transformation_test.go | 184 +---- .../polyfloat/polynomial_evaluator_test.go | 134 +--- .../polyint/polynomial_evaluator_test.go | 164 +---- schemes/bfv/bfv_benchmark_test.go | 91 +-- schemes/bfv/bfv_test.go | 464 +++++------- schemes/bfv/test_parameters.go | 15 - schemes/bfv/test_utils.go | 120 ++++ schemes/bgv/bgv_benchmark_test.go | 104 ++- schemes/bgv/bgv_test.go | 659 ++++++++---------- schemes/bgv/test_parameters.go | 14 - schemes/bgv/test_utils.go | 120 ++++ schemes/ckks/ckks_benchmarks_test.go | 93 +-- schemes/ckks/ckks_test.go | 521 ++++++-------- schemes/ckks/test_params.go | 48 -- schemes/ckks/test_utils.go | 174 +++++ schemes/parameters_test.go | 1 - schemes/test_parameters.go | 39 -- 18 files changed, 1308 insertions(+), 1802 deletions(-) delete mode 100644 schemes/bfv/test_parameters.go create mode 100644 schemes/bfv/test_utils.go delete mode 100644 schemes/bgv/test_parameters.go create mode 100644 schemes/bgv/test_utils.go delete mode 100644 schemes/ckks/test_params.go create mode 100644 schemes/ckks/test_utils.go delete mode 100644 schemes/parameters_test.go delete mode 100644 schemes/test_parameters.go diff --git a/circuits/linear_transformation/ltfloat/linear_transformation_test.go b/circuits/linear_transformation/ltfloat/linear_transformation_test.go index b06d9239..30b13928 100644 --- a/circuits/linear_transformation/ltfloat/linear_transformation_test.go +++ b/circuits/linear_transformation/ltfloat/linear_transformation_test.go @@ -14,7 +14,6 @@ import ( "github.com/tuneinsight/lattigo/v5/core/rlwe" "github.com/tuneinsight/lattigo/v5/ring" - "github.com/tuneinsight/lattigo/v5/schemes" "github.com/tuneinsight/lattigo/v5/schemes/ckks" "github.com/tuneinsight/lattigo/v5/utils" "github.com/tuneinsight/lattigo/v5/utils/bignum" @@ -24,34 +23,18 @@ import ( var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.") var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats") -func GetTestName(params ckks.Parameters, opname string) string { +func name(opname string, tc *ckks.TestContext) string { return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d", opname, - params.RingType(), - params.LogN(), - int(math.Round(params.LogQP())), - params.QCount(), - params.PCount(), - int(math.Log2(params.DefaultScale().Float64()))) -} - -type testContext struct { - params ckks.Parameters - ringQ *ring.Ring - ringP *ring.Ring - prng sampling.PRNG - encoder ckks.Encoder - kgen *rlwe.KeyGenerator - sk *rlwe.SecretKey - pk *rlwe.PublicKey - encryptorPk *rlwe.Encryptor - encryptorSk *rlwe.Encryptor - decryptor *rlwe.Decryptor - evaluator ckks.Evaluator + tc.Params.RingType(), + tc.Params.LogN(), + int(math.Round(tc.Params.LogQP())), + tc.Params.QCount(), + tc.Params.PCount(), + int(math.Log2(tc.Params.DefaultScale().Float64()))) } func TestPolynomialEvaluator(t *testing.T) { - var err error var testParams []ckks.ParametersLiteral @@ -62,7 +45,7 @@ func TestPolynomialEvaluator(t *testing.T) { t.Fatal(err) } default: - testParams = schemes.CkksTestParametersLiteral + testParams = ckks.TestParametersLiteral } for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} { @@ -75,17 +58,9 @@ func TestPolynomialEvaluator(t *testing.T) { paramsLiteral.LogN = 10 } - var params ckks.Parameters - if params, err = ckks.NewParametersFromLiteral(paramsLiteral); err != nil { - t.Fatal(err) - } + tc := ckks.NewTestContext(paramsLiteral) - var tc *testContext - if tc, err = genTestParams(params); err != nil { - t.Fatal(err) - } - - for _, testSet := range []func(tc *testContext, t *testing.T){ + for _, testSet := range []func(tc *ckks.TestContext, t *testing.T){ run, } { testSet(tc, t) @@ -95,9 +70,8 @@ func TestPolynomialEvaluator(t *testing.T) { } } -func run(tc *testContext, t *testing.T) { - - params := tc.params +func run(tc *ckks.TestContext, t *testing.T) { + params := tc.Params mulCmplx := bignum.NewComplexMultiplier().Mul @@ -119,7 +93,7 @@ func run(tc *testContext, t *testing.T) { } } - prec := tc.encoder.Prec() + prec := tc.Ecd.Prec() newVec := func(size int) (vec []*bignum.Complex) { vec = make([]*bignum.Complex, size) @@ -129,9 +103,9 @@ func run(tc *testContext, t *testing.T) { return } - t.Run(GetTestName(params, "Average"), func(t *testing.T) { + t.Run(name("Average", tc), func(t *testing.T) { - values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i) slots := ciphertext.Slots() @@ -139,7 +113,7 @@ func run(tc *testContext, t *testing.T) { batch := 1 << logBatch n := slots / batch - eval := tc.evaluator.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(rlwe.GaloisElementsForInnerSum(params, batch, n), tc.sk)...)) + eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(rlwe.GaloisElementsForInnerSum(params, batch, n), tc.Sk)...)) require.NoError(t, eval.Average(ciphertext, logBatch, ciphertext)) @@ -164,12 +138,12 @@ func run(tc *testContext, t *testing.T) { values[i][1].Quo(values[i][1], nB) } - ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) + ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(params, "LinearTransform/BSGS=True"), func(t *testing.T) { + t.Run(name("LinearTransform/BSGS=True", tc), func(t *testing.T) { - values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i) slots := ciphertext.Slots() @@ -199,24 +173,24 @@ func run(tc *testContext, t *testing.T) { linTransf := NewLinearTransformation(params, ltparams) // Encode on the linear transformation - require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf)) + require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf)) galEls := linTransf.GaloisElements(params) - evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...) + evk := rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...) - ltEval := NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk)) + ltEval := NewLinearTransformationEvaluator(tc.Evl.WithKey(evk)) require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext)) values = diagonals.Evaluate(values, newVec, add, muladd) - ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) + ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(params, "LinearTransform/BSGS=False"), func(t *testing.T) { + t.Run(name("LinearTransform/BSGS=False", tc), func(t *testing.T) { - values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i) slots := ciphertext.Slots() @@ -247,22 +221,22 @@ func run(tc *testContext, t *testing.T) { linTransf := NewLinearTransformation(params, ltparams) // Encode on the linear transformation - require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf)) + require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf)) galEls := linTransf.GaloisElements(params) - evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...) + evk := rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...) - ltEval := NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk)) + ltEval := NewLinearTransformationEvaluator(tc.Evl.WithKey(evk)) require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext)) values = diagonals.Evaluate(values, newVec, add, muladd) - ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) + ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(params, "LinearTransform/Permutation"), func(t *testing.T) { + t.Run(name("LinearTransform/Permutation", tc), func(t *testing.T) { idx := make([]int, params.MaxSlots()) for i := range idx { idx[i] = i @@ -289,7 +263,7 @@ func run(tc *testContext, t *testing.T) { diagonals := Permutation[*bignum.Complex](permutation).GetDiagonals(params.LogMaxSlots()) - values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i) ltparams := LinearTransformationParameters{ DiagonalsIndexList: diagonals.DiagonalsIndexList(), @@ -304,87 +278,18 @@ func run(tc *testContext, t *testing.T) { linTransf := NewLinearTransformation(params, ltparams) // Encode on the linear transformation - require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf)) + require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf)) galEls := linTransf.GaloisElements(params) - evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...) + evk := rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...) - ltEval := NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk)) + ltEval := NewLinearTransformationEvaluator(tc.Evl.WithKey(evk)) require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext)) values = diagonals.Evaluate(values, newVec, add, muladd) - ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) + ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) }) } - -func genTestParams(defaultParam ckks.Parameters) (tc *testContext, err error) { - - tc = new(testContext) - - tc.params = defaultParam - - tc.kgen = rlwe.NewKeyGenerator(tc.params) - - tc.sk, tc.pk = tc.kgen.GenKeyPairNew() - - tc.ringQ = defaultParam.RingQ() - if tc.params.PCount() != 0 { - tc.ringP = defaultParam.RingP() - } - - if tc.prng, err = sampling.NewPRNG(); err != nil { - return nil, err - } - - tc.encoder = *ckks.NewEncoder(tc.params) - - tc.encryptorPk = rlwe.NewEncryptor(tc.params, tc.pk) - tc.encryptorSk = rlwe.NewEncryptor(tc.params, tc.sk) - tc.decryptor = rlwe.NewDecryptor(tc.params, tc.sk) - tc.evaluator = *ckks.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk))) - - return tc, nil - -} - -func newTestVectors(tc *testContext, encryptor *rlwe.Encryptor, a, b complex128, t *testing.T) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) { - - var err error - - prec := tc.params.EncodingPrecision() - - pt = ckks.NewPlaintext(tc.params, tc.params.MaxLevel()) - - values = make([]*bignum.Complex, pt.Slots()) - - switch tc.params.RingType() { - case ring.Standard: - for i := range values { - values[i] = &bignum.Complex{ - bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec), - bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec), - } - } - case ring.ConjugateInvariant: - for i := range values { - values[i] = &bignum.Complex{ - bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec), - new(big.Float), - } - } - default: - t.Fatal("invalid ring type") - } - - tc.encoder.Encode(values, pt) - - if encryptor != nil { - ct, err = encryptor.EncryptNew(pt) - require.NoError(t, err) - } - - return values, pt, ct -} diff --git a/circuits/linear_transformation/ltint/linear_transformation_test.go b/circuits/linear_transformation/ltint/linear_transformation_test.go index 3ab813f8..814e316b 100644 --- a/circuits/linear_transformation/ltint/linear_transformation_test.go +++ b/circuits/linear_transformation/ltint/linear_transformation_test.go @@ -3,86 +3,23 @@ package ltint import ( "encoding/json" "flag" - "fmt" - "math" "math/rand" "runtime" - "slices" "testing" "github.com/stretchr/testify/require" + "github.com/tuneinsight/lattigo/v5/core/rlwe" - "github.com/tuneinsight/lattigo/v5/ring" - "github.com/tuneinsight/lattigo/v5/schemes" "github.com/tuneinsight/lattigo/v5/schemes/bgv" "github.com/tuneinsight/lattigo/v5/utils/sampling" ) -type testContext struct { - params bgv.Parameters - ringQ *ring.Ring - ringT *ring.Ring - prng sampling.PRNG - uSampler *ring.UniformSampler - encoder schemes.Encoder - kgen *rlwe.KeyGenerator - sk *rlwe.SecretKey - pk *rlwe.PublicKey - encryptorPk *rlwe.Encryptor - encryptorSk *rlwe.Encryptor - decryptor *rlwe.Decryptor - evaluator *bgv.Evaluator - testLevel []int -} - -func genTestParams(params bgv.Parameters) (tc *testContext, err error) { - - tc = new(testContext) - tc.params = params - - if tc.prng, err = sampling.NewPRNG(); err != nil { - return nil, err - } - - tc.ringQ = params.RingQ() - tc.ringT = params.RingT() - - tc.uSampler = ring.NewUniformSampler(tc.prng, tc.ringT) - tc.kgen = rlwe.NewKeyGenerator(tc.params) - tc.sk, tc.pk = tc.kgen.GenKeyPairNew() - tc.encoder = bgv.NewEncoder(tc.params) - - tc.encryptorPk = rlwe.NewEncryptor(tc.params, tc.pk) - tc.encryptorSk = rlwe.NewEncryptor(tc.params, tc.sk) - tc.decryptor = rlwe.NewDecryptor(tc.params, tc.sk) - tc.evaluator = bgv.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk))) - - tc.testLevel = []int{0, params.MaxLevel()} - - return -} - -var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise") var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.") -func GetTestName(opname string, p bgv.Parameters, lvl int) string { - return fmt.Sprintf("%s/LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d/lvl=%d", - opname, - p.LogN(), - int(math.Round(p.LogQ())), - int(math.Round(p.LogP())), - p.LogMaxDimensions().Rows, - p.LogMaxDimensions().Cols, - int(math.Round(p.LogT())), - p.QCount(), - p.PCount(), - lvl) -} - -func TestPolynomialEvaluator(t *testing.T) { +func TestLinearTransformation(t *testing.T) { var err error - paramsLiterals := schemes.BgvTestParams + paramsLiterals := bgv.TestParams if *flagParamString != "" { var jsonParams bgv.ParametersLiteral @@ -94,23 +31,12 @@ func TestPolynomialEvaluator(t *testing.T) { for _, p := range paramsLiterals[:] { - for _, plaintextModulus := range schemes.BgvTestPlaintextModulus[:] { - + for _, plaintextModulus := range bgv.TestPlaintextModulus[:] { p.PlaintextModulus = plaintextModulus - var params bgv.Parameters - if params, err = bgv.NewParametersFromLiteral(p); err != nil { - t.Error(err) - t.Fail() - } + tc := bgv.NewTestContext(p) - var tc *testContext - if tc, err = genTestParams(params); err != nil { - t.Error(err) - t.Fail() - } - - for _, testSet := range []func(tc *testContext, t *testing.T){ + for _, testSet := range []func(tc *bgv.TestContext, t *testing.T){ run, } { testSet(tc, t) @@ -120,30 +46,26 @@ func TestPolynomialEvaluator(t *testing.T) { } } -func run(tc *testContext, t *testing.T) { - rT := tc.params.RingT().SubRings[0] +func run(tc *bgv.TestContext, t *testing.T) { + rT := tc.Params.RingT().SubRings[0] add := func(a, b, c []uint64) { rT.Add(a, b, c) } - muladd := func(a, b, c []uint64) { rT.MulCoeffsBarrettThenAdd(a, b, c) } - newVec := func(size int) (vec []uint64) { return make([]uint64, size) } - params := tc.params + params := tc.Params T := params.PlaintextModulus() - level := tc.params.MaxLevel() + t.Run("Evaluator/LinearTransformationBSGS=true/"+tc.String(), func(t *testing.T) { - t.Run(GetTestName("Evaluator/LinearTransformationBSGS=true", params, level), func(t *testing.T) { - - values, _, ciphertext := newTestVectorsLvl(level, params.DefaultScale(), tc, tc.encryptorSk) + values, _, ciphertext := bgv.NewTestVector(params, tc.Ecd, tc.Enc, params.MaxLevel(), params.DefaultScale()) slots := ciphertext.Slots() @@ -161,7 +83,7 @@ func run(tc *testContext, t *testing.T) { DiagonalsIndexList: diagonals.DiagonalsIndexList(), LevelQ: ciphertext.Level(), LevelP: params.MaxLevelP(), - Scale: tc.params.DefaultScale(), + Scale: tc.Params.DefaultScale(), LogDimensions: ciphertext.LogDimensions, LogBabyStepGiantStepRatio: 1, } @@ -170,23 +92,23 @@ func run(tc *testContext, t *testing.T) { linTransf := NewLinearTransformation(params, ltparams) // Encode on the linear transformation - require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf)) + require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf)) galEls := linTransf.GaloisElements(params) - eval := tc.evaluator.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...)) + eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...)) ltEval := NewLinearTransformationEvaluator(eval) require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext)) - values.Coeffs[0] = diagonals.Evaluate(values.Coeffs[0], newVec, add, muladd) + values = diagonals.Evaluate(values, newVec, add, muladd) - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + bgv.VerifyTestVectors(params, tc.Ecd, tc.Dec, ciphertext, values, t) }) - t.Run(GetTestName("Evaluator/LinearTransformationBSGS=false", params, level), func(t *testing.T) { + t.Run("Evaluator/LinearTransformationBSGS=false"+tc.String(), func(t *testing.T) { - values, _, ciphertext := newTestVectorsLvl(level, params.DefaultScale(), tc, tc.encryptorSk) + values, _, ciphertext := bgv.NewTestVector(params, tc.Ecd, tc.Enc, params.MaxLevel(), params.DefaultScale()) slots := ciphertext.Slots() @@ -213,21 +135,21 @@ func run(tc *testContext, t *testing.T) { linTransf := NewLinearTransformation(params, ltparams) // Encode on the linear transformation - require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf)) + require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf)) galEls := linTransf.GaloisElements(params) - eval := tc.evaluator.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...)) + eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...)) ltEval := NewLinearTransformationEvaluator(eval) require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext)) - values.Coeffs[0] = diagonals.Evaluate(values.Coeffs[0], newVec, add, muladd) + values = diagonals.Evaluate(values, newVec, add, muladd) - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + bgv.VerifyTestVectors(params, tc.Ecd, tc.Dec, ciphertext, values, t) }) - t.Run(GetTestName("Evaluator/LinearTransformation/Permutation", params, level), func(t *testing.T) { + t.Run("Evaluator/LinearTransformation/Permutation"+tc.String(), func(t *testing.T) { idx := [2][]int{ make([]int, params.MaxSlots()>>1), @@ -272,7 +194,7 @@ func run(tc *testContext, t *testing.T) { diagonals := Permutation[uint64](permutation).GetDiagonals(params.LogMaxSlots()) - values, _, ciphertext := newTestVectorsLvl(level, tc.params.NewScale(1), tc, tc.encryptorSk) + values, _, ciphertext := bgv.NewTestVector(params, tc.Ecd, tc.Enc, params.MaxLevel(), params.DefaultScale()) ltparams := LinearTransformationParameters{ DiagonalsIndexList: diagonals.DiagonalsIndexList(), @@ -287,66 +209,18 @@ func run(tc *testContext, t *testing.T) { linTransf := NewLinearTransformation(params, ltparams) // Encode on the linear transformation - require.NoError(t, EncodeLinearTransformation(tc.encoder, diagonals, linTransf)) + require.NoError(t, EncodeLinearTransformation(tc.Ecd, diagonals, linTransf)) galEls := linTransf.GaloisElements(params) - evk := rlwe.NewMemEvaluationKeySet(nil, tc.kgen.GenGaloisKeysNew(galEls, tc.sk)...) + evk := rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...) - ltEval := NewLinearTransformationEvaluator(tc.evaluator.WithKey(evk)) + ltEval := NewLinearTransformationEvaluator(tc.Evl.WithKey(evk)) require.NoError(t, ltEval.Evaluate(ciphertext, linTransf, ciphertext)) - values.Coeffs[0] = diagonals.Evaluate(values.Coeffs[0], newVec, add, muladd) + values = diagonals.Evaluate(values, newVec, add, muladd) - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + bgv.VerifyTestVectors(params, tc.Ecd, tc.Dec, ciphertext, values, t) }) } - -func newTestVectorsLvl(level int, scale rlwe.Scale, tc *testContext, encryptor *rlwe.Encryptor) (coeffs ring.Poly, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) { - coeffs = tc.uSampler.ReadNew() - for i := range coeffs.Coeffs[0] { - coeffs.Coeffs[0][i] = uint64(i) - } - - plaintext = bgv.NewPlaintext(tc.params, level) - plaintext.Scale = scale - tc.encoder.Encode(coeffs.Coeffs[0], plaintext) - if encryptor != nil { - var err error - ciphertext, err = encryptor.EncryptNew(plaintext) - if err != nil { - panic(err) - } - } - - return coeffs, plaintext, ciphertext -} - -func verifyTestVectors(tc *testContext, decryptor *rlwe.Decryptor, coeffs ring.Poly, element rlwe.ElementInterface[ring.Poly], t *testing.T) { - - coeffsTest := make([]uint64, tc.params.MaxSlots()) - - switch el := element.(type) { - case *rlwe.Plaintext: - require.NoError(t, tc.encoder.Decode(el, coeffsTest)) - case *rlwe.Ciphertext: - - pt := decryptor.DecryptNew(el) - - require.NoError(t, tc.encoder.Decode(pt, coeffsTest)) - - if *flagPrintNoise { - require.NoError(t, tc.encoder.Encode(coeffsTest, pt)) - ct, err := tc.evaluator.SubNew(el, pt) - require.NoError(t, err) - vartmp, _, _ := rlwe.Norm(ct, decryptor) - t.Logf("STD(noise): %f\n", vartmp) - } - - default: - t.Error("invalid test object to verify") - } - - require.True(t, slices.Equal(coeffs.Coeffs[0], coeffsTest)) -} diff --git a/circuits/polynomial/polyfloat/polynomial_evaluator_test.go b/circuits/polynomial/polyfloat/polynomial_evaluator_test.go index 0a6743f3..e2d7e8c4 100644 --- a/circuits/polynomial/polyfloat/polynomial_evaluator_test.go +++ b/circuits/polynomial/polyfloat/polynomial_evaluator_test.go @@ -11,45 +11,26 @@ import ( "github.com/stretchr/testify/require" - "github.com/tuneinsight/lattigo/v5/core/rlwe" "github.com/tuneinsight/lattigo/v5/ring" - "github.com/tuneinsight/lattigo/v5/schemes" "github.com/tuneinsight/lattigo/v5/schemes/ckks" "github.com/tuneinsight/lattigo/v5/utils/bignum" - "github.com/tuneinsight/lattigo/v5/utils/sampling" ) var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.") var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats") -func GetTestName(params ckks.Parameters, opname string) string { +func name(opname string, tc *ckks.TestContext) string { return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d", opname, - params.RingType(), - params.LogN(), - int(math.Round(params.LogQP())), - params.QCount(), - params.PCount(), - int(math.Log2(params.DefaultScale().Float64()))) -} - -type testContext struct { - params ckks.Parameters - ringQ *ring.Ring - ringP *ring.Ring - prng sampling.PRNG - encoder ckks.Encoder - kgen *rlwe.KeyGenerator - sk *rlwe.SecretKey - pk *rlwe.PublicKey - encryptorPk *rlwe.Encryptor - encryptorSk *rlwe.Encryptor - decryptor *rlwe.Decryptor - evaluator schemes.Evaluator + tc.Params.RingType(), + tc.Params.LogN(), + int(math.Round(tc.Params.LogQP())), + tc.Params.QCount(), + tc.Params.PCount(), + int(math.Log2(tc.Params.DefaultScale().Float64()))) } func TestPolynomialEvaluator(t *testing.T) { - var err error var testParams []ckks.ParametersLiteral @@ -60,7 +41,7 @@ func TestPolynomialEvaluator(t *testing.T) { t.Fatal(err) } default: - testParams = schemes.CkksTestParametersLiteral + testParams = ckks.TestParametersLiteral } for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} { @@ -73,17 +54,9 @@ func TestPolynomialEvaluator(t *testing.T) { paramsLiteral.LogN = 10 } - var params ckks.Parameters - if params, err = ckks.NewParametersFromLiteral(paramsLiteral); err != nil { - t.Fatal(err) - } + tc := ckks.NewTestContext(paramsLiteral) - var tc *testContext - if tc, err = genTestParams(params); err != nil { - t.Fatal(err) - } - - for _, testSet := range []func(tc *testContext, t *testing.T){ + for _, testSet := range []func(tc *ckks.TestContext, t *testing.T){ run, } { testSet(tc, t) @@ -93,21 +66,21 @@ func TestPolynomialEvaluator(t *testing.T) { } } -func run(tc *testContext, t *testing.T) { +func run(tc *ckks.TestContext, t *testing.T) { - params := tc.params + params := tc.Params var err error - polyEval := NewPolynomialEvaluator(params, tc.evaluator) + polyEval := NewPolynomialEvaluator(params, tc.Evl) - t.Run(GetTestName(params, "EvaluatePoly/PolySingle/Exp"), func(t *testing.T) { + t.Run(name("EvaluatePoly/PolySingle/Exp", tc), func(t *testing.T) { if params.MaxLevel() < 3 { t.Skip("skipping test for params max level < 3") } - values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1, 1, t) + values, _, ciphertext := tc.NewTestVector(-1, 1) prec := params.EncodingPrecision() @@ -132,16 +105,16 @@ func run(tc *testContext, t *testing.T) { t.Fatal(err) } - ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) + ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(params, "Polynomial/PolyVector/Exp"), func(t *testing.T) { + t.Run(name("Polynomial/PolyVector/Exp", tc), func(t *testing.T) { if params.MaxLevel() < 3 { t.Skip("skipping test for params max level < 3") } - values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1, 1, t) + values, _, ciphertext := tc.NewTestVector(-1, 1) prec := params.EncodingPrecision() @@ -180,75 +153,6 @@ func run(tc *testContext, t *testing.T) { t.Fatal(err) } - ckks.VerifyTestVectors(params, &tc.encoder, tc.decryptor, valuesWant, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) + ckks.VerifyTestVectors(params, tc.Ecd, tc.Dec, valuesWant, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t) }) } - -func genTestParams(defaultParam ckks.Parameters) (tc *testContext, err error) { - - tc = new(testContext) - - tc.params = defaultParam - - tc.kgen = rlwe.NewKeyGenerator(tc.params) - - tc.sk, tc.pk = tc.kgen.GenKeyPairNew() - - tc.ringQ = defaultParam.RingQ() - if tc.params.PCount() != 0 { - tc.ringP = defaultParam.RingP() - } - - if tc.prng, err = sampling.NewPRNG(); err != nil { - return nil, err - } - - tc.encoder = *ckks.NewEncoder(tc.params) - - tc.encryptorPk = rlwe.NewEncryptor(tc.params, tc.pk) - tc.encryptorSk = rlwe.NewEncryptor(tc.params, tc.sk) - tc.decryptor = rlwe.NewDecryptor(tc.params, tc.sk) - tc.evaluator = ckks.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk))) - - return tc, nil - -} - -func newTestVectors(tc *testContext, encryptor *rlwe.Encryptor, a, b complex128, t *testing.T) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) { - - var err error - - prec := tc.params.EncodingPrecision() - - pt = ckks.NewPlaintext(tc.params, tc.params.MaxLevel()) - - values = make([]*bignum.Complex, pt.Slots()) - - switch tc.params.RingType() { - case ring.Standard: - for i := range values { - values[i] = &bignum.Complex{ - bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec), - bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec), - } - } - case ring.ConjugateInvariant: - for i := range values { - values[i] = &bignum.Complex{ - bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec), - new(big.Float), - } - } - default: - t.Fatal("invalid ring type") - } - - tc.encoder.Encode(values, pt) - - if encryptor != nil { - ct, err = encryptor.EncryptNew(pt) - require.NoError(t, err) - } - - return values, pt, ct -} diff --git a/circuits/polynomial/polyint/polynomial_evaluator_test.go b/circuits/polynomial/polyint/polynomial_evaluator_test.go index d5252f59..e3bf2be4 100644 --- a/circuits/polynomial/polyint/polynomial_evaluator_test.go +++ b/circuits/polynomial/polyint/polynomial_evaluator_test.go @@ -3,87 +3,22 @@ package polyint import ( "encoding/json" "flag" - "fmt" - "math" "runtime" - "slices" "testing" "github.com/stretchr/testify/require" - "github.com/tuneinsight/lattigo/v5/core/rlwe" "github.com/tuneinsight/lattigo/v5/ring" - "github.com/tuneinsight/lattigo/v5/schemes" "github.com/tuneinsight/lattigo/v5/schemes/bgv" "github.com/tuneinsight/lattigo/v5/utils/bignum" - "github.com/tuneinsight/lattigo/v5/utils/sampling" ) -type testContext struct { - params bgv.Parameters - ringQ *ring.Ring - ringT *ring.Ring - prng sampling.PRNG - uSampler *ring.UniformSampler - encoder *bgv.Encoder - kgen *rlwe.KeyGenerator - sk *rlwe.SecretKey - pk *rlwe.PublicKey - encryptorPk *rlwe.Encryptor - encryptorSk *rlwe.Encryptor - decryptor *rlwe.Decryptor - evaluator *bgv.Evaluator - testLevel []int -} - -func genTestParams(params bgv.Parameters) (tc *testContext, err error) { - - tc = new(testContext) - tc.params = params - - if tc.prng, err = sampling.NewPRNG(); err != nil { - return nil, err - } - - tc.ringQ = params.RingQ() - tc.ringT = params.RingT() - - tc.uSampler = ring.NewUniformSampler(tc.prng, tc.ringT) - tc.kgen = rlwe.NewKeyGenerator(tc.params) - tc.sk, tc.pk = tc.kgen.GenKeyPairNew() - tc.encoder = bgv.NewEncoder(tc.params) - - tc.encryptorPk = rlwe.NewEncryptor(tc.params, tc.pk) - tc.encryptorSk = rlwe.NewEncryptor(tc.params, tc.sk) - tc.decryptor = rlwe.NewDecryptor(tc.params, tc.sk) - tc.evaluator = bgv.NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk))) - - tc.testLevel = []int{0, params.MaxLevel()} - - return -} - -var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise") var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.") -func GetTestName(opname string, p bgv.Parameters, lvl int) string { - return fmt.Sprintf("%s/LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d/lvl=%d", - opname, - p.LogN(), - int(math.Round(p.LogQ())), - int(math.Round(p.LogP())), - p.LogMaxDimensions().Rows, - p.LogMaxDimensions().Cols, - int(math.Round(p.LogT())), - p.QCount(), - p.PCount(), - lvl) -} - func TestPolynomialEvaluator(t *testing.T) { var err error - paramsLiterals := schemes.BgvTestParams + paramsLiterals := bgv.TestParams if *flagParamString != "" { var jsonParams bgv.ParametersLiteral @@ -95,23 +30,12 @@ func TestPolynomialEvaluator(t *testing.T) { for _, p := range paramsLiterals[:] { - for _, plaintextModulus := range schemes.BgvTestPlaintextModulus[:] { - + for _, plaintextModulus := range bgv.TestPlaintextModulus[:] { p.PlaintextModulus = plaintextModulus - var params bgv.Parameters - if params, err = bgv.NewParametersFromLiteral(p); err != nil { - t.Error(err) - t.Fail() - } + tc := bgv.NewTestContext(p) - var tc *testContext - if tc, err = genTestParams(params); err != nil { - t.Error(err) - t.Fail() - } - - for _, testSet := range []func(tc *testContext, t *testing.T){ + for _, testSet := range []func(tc *bgv.TestContext, t *testing.T){ run, } { testSet(tc, t) @@ -121,95 +45,47 @@ func TestPolynomialEvaluator(t *testing.T) { } } -func run(tc *testContext, t *testing.T) { +func run(tc *bgv.TestContext, t *testing.T) { t.Run("Single", func(t *testing.T) { - if tc.params.MaxLevel() < 4 { + if tc.Params.MaxLevel() < 4 { t.Skip("MaxLevel() to low") } - values, _, ciphertext := newTestVectorsLvl(tc.params.MaxLevel(), tc.params.NewScale(1), tc, tc.encryptorSk) + values, _, ciphertext := bgv.NewTestVector(tc.Params, tc.Ecd, tc.Enc, tc.Params.MaxLevel(), tc.Params.DefaultScale()) coeffs := []uint64{0, 0, 1} - T := tc.params.PlaintextModulus() - for i := range values.Coeffs[0] { - values.Coeffs[0][i] = ring.EvalPolyModP(values.Coeffs[0][i], coeffs, T) + T := tc.Params.PlaintextModulus() + for i := range values { + values[i] = ring.EvalPolyModP(values[i], coeffs, T) } poly := bignum.NewPolynomial(bignum.Monomial, coeffs, nil) - t.Run(GetTestName("Standard", tc.params, tc.params.MaxLevel()), func(t *testing.T) { + t.Run("Standard"+tc.String(), func(t *testing.T) { - polyEval := NewPolynomialEvaluator(tc.params, tc.evaluator, false) + polyEval := NewPolynomialEvaluator(tc.Params, tc.Evl, false) - res, err := polyEval.Evaluate(ciphertext, poly, tc.params.DefaultScale()) + res, err := polyEval.Evaluate(ciphertext, poly, tc.Params.DefaultScale()) require.NoError(t, err) - require.Equal(t, res.Scale.Cmp(tc.params.DefaultScale()), 0) + require.Equal(t, res.Scale.Cmp(tc.Params.DefaultScale()), 0) - verifyTestVectors(tc, tc.decryptor, values, res, t) + bgv.VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, res, values, t) }) - t.Run(GetTestName("Invariant", tc.params, tc.params.MaxLevel()), func(t *testing.T) { + t.Run("Invariant"+tc.String(), func(t *testing.T) { - polyEval := NewPolynomialEvaluator(tc.params, tc.evaluator, true) + polyEval := NewPolynomialEvaluator(tc.Params, tc.Evl, true) - res, err := polyEval.Evaluate(ciphertext, poly, tc.params.DefaultScale()) + res, err := polyEval.Evaluate(ciphertext, poly, tc.Params.DefaultScale()) require.NoError(t, err) require.Equal(t, res.Level(), ciphertext.Level()) - require.Equal(t, res.Scale.Cmp(tc.params.DefaultScale()), 0) + require.Equal(t, res.Scale.Cmp(tc.Params.DefaultScale()), 0) - verifyTestVectors(tc, tc.decryptor, values, res, t) + bgv.VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, res, values, t) }) }) } - -func newTestVectorsLvl(level int, scale rlwe.Scale, tc *testContext, encryptor *rlwe.Encryptor) (coeffs ring.Poly, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) { - coeffs = tc.uSampler.ReadNew() - for i := range coeffs.Coeffs[0] { - coeffs.Coeffs[0][i] = uint64(i) - } - - plaintext = bgv.NewPlaintext(tc.params, level) - plaintext.Scale = scale - tc.encoder.Encode(coeffs.Coeffs[0], plaintext) - if encryptor != nil { - var err error - ciphertext, err = encryptor.EncryptNew(plaintext) - if err != nil { - panic(err) - } - } - - return coeffs, plaintext, ciphertext -} - -func verifyTestVectors(tc *testContext, decryptor *rlwe.Decryptor, coeffs ring.Poly, element rlwe.ElementInterface[ring.Poly], t *testing.T) { - - coeffsTest := make([]uint64, tc.params.MaxSlots()) - - switch el := element.(type) { - case *rlwe.Plaintext: - require.NoError(t, tc.encoder.Decode(el, coeffsTest)) - case *rlwe.Ciphertext: - - pt := decryptor.DecryptNew(el) - - require.NoError(t, tc.encoder.Decode(pt, coeffsTest)) - - if *flagPrintNoise { - require.NoError(t, tc.encoder.Encode(coeffsTest, pt)) - ct, err := tc.evaluator.SubNew(el, pt) - require.NoError(t, err) - vartmp, _, _ := rlwe.Norm(ct, decryptor) - t.Logf("STD(noise): %f\n", vartmp) - } - - default: - t.Error("invalid test object to verify") - } - - require.True(t, slices.Equal(coeffs.Coeffs[0], coeffsTest)) -} diff --git a/schemes/bfv/bfv_benchmark_test.go b/schemes/bfv/bfv_benchmark_test.go index d5e8873b..f6683ba0 100644 --- a/schemes/bfv/bfv_benchmark_test.go +++ b/schemes/bfv/bfv_benchmark_test.go @@ -2,25 +2,13 @@ package bfv import ( "encoding/json" - "fmt" "runtime" "testing" "github.com/tuneinsight/lattigo/v5/core/rlwe" ) -func GetBenchName(params Parameters, opname string) string { - - return fmt.Sprintf("%s/logN=%d/Qi=%d/Pi=%d/LogSlots=%d", - opname, - params.LogN(), - params.QCount(), - params.PCount(), - params.LogMaxSlots()) -} - func BenchmarkBFV(b *testing.B) { - var err error var testParams []ParametersLiteral @@ -42,20 +30,9 @@ func BenchmarkBFV(b *testing.B) { } for _, paramsLiteral := range testParams { + tc := NewTestContext(paramsLiteral) - var params Parameters - if params, err = NewParametersFromLiteral(paramsLiteral); err != nil { - b.Error(err) - b.Fail() - } - - var tc *testContext - if tc, err = genTestParams(params); err != nil { - b.Error(err) - b.Fail() - } - - for _, testSet := range []func(tc *testContext, b *testing.B){ + for _, testSet := range []func(tc *TestContext, b *testing.B){ benchEncoder, benchEvaluator, } { @@ -65,11 +42,11 @@ func BenchmarkBFV(b *testing.B) { } } -func benchEncoder(tc *testContext, b *testing.B) { +func benchEncoder(tc *TestContext, b *testing.B) { - params := tc.params + params := tc.Params - poly := tc.uSampler.ReadNew() + poly := tc.Sampler.ReadNew() params.RingT().Reduce(poly, poly) coeffsUint64 := poly.Coeffs[0] coeffsInt64 := make([]int64, len(coeffsUint64)) @@ -77,12 +54,12 @@ func benchEncoder(tc *testContext, b *testing.B) { coeffsInt64[i] = int64(coeffsUint64[i]) } - encoder := tc.encoder + encoder := tc.Ecd level := params.MaxLevel() plaintext := NewPlaintext(params, level) - b.Run(GetBenchName(params, "Encoder/Encode/Uint"), func(b *testing.B) { + b.Run(name("Encoder/Encode/Uint", tc, level), func(b *testing.B) { for i := 0; i < b.N; i++ { if err := encoder.Encode(coeffsUint64, plaintext); err != nil { b.Log(err) @@ -91,7 +68,7 @@ func benchEncoder(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Encoder/Encode/Int"), func(b *testing.B) { + b.Run(name("Encoder/Encode/Int", tc, level), func(b *testing.B) { for i := 0; i < b.N; i++ { if err := encoder.Encode(coeffsInt64, plaintext); err != nil { b.Log(err) @@ -100,7 +77,7 @@ func benchEncoder(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Encoder/Decode/Uint"), func(b *testing.B) { + b.Run(name("Encoder/Decode/Uint", tc, level), func(b *testing.B) { for i := 0; i < b.N; i++ { if err := encoder.Decode(plaintext, coeffsUint64); err != nil { b.Log(err) @@ -109,7 +86,7 @@ func benchEncoder(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Encoder/Decode/Int"), func(b *testing.B) { + b.Run(name("Encoder/Decode/Int", tc, level), func(b *testing.B) { for i := 0; i < b.N; i++ { if err := encoder.Decode(plaintext, coeffsInt64); err != nil { b.Log(err) @@ -119,16 +96,18 @@ func benchEncoder(tc *testContext, b *testing.B) { }) } -func benchEvaluator(tc *testContext, b *testing.B) { +func benchEvaluator(tc *TestContext, b *testing.B) { - params := tc.params - eval := tc.evaluator + params := tc.Params + eval := tc.Evl - plaintext := NewPlaintext(params, params.MaxLevel()) - plaintext.Value = rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 0, plaintext.Level()).Value[0] + level := params.MaxLevel() - ciphertext1 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel()) - ciphertext2 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel()) + plaintext := NewPlaintext(params, level) + plaintext.Value = rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 0, plaintext.Level()).Value[0] + + ciphertext1 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, level) + ciphertext2 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, level) scalar := params.PlaintextModulus() >> 1 *ciphertext1.MetaData = *plaintext.MetaData @@ -136,7 +115,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { vector := plaintext.Value.Coeffs[0][:params.MaxSlots()] - b.Run(GetBenchName(params, "Evaluator/Add/Scalar"), func(b *testing.B) { + b.Run(name("Evaluator/Add/Scalar", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -147,7 +126,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Add/Vector"), func(b *testing.B) { + b.Run(name("Evaluator/Add/Vector", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -158,7 +137,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Add/Plaintext"), func(b *testing.B) { + b.Run(name("Evaluator/Add/Plaintext", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -169,7 +148,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Add/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/Add/Ciphertext", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -180,7 +159,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Mul/Scalar"), func(b *testing.B) { + b.Run(name("Evaluator/Mul/Scalar", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -191,7 +170,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Mul/Plaintext"), func(b *testing.B) { + b.Run(name("Evaluator/Mul/Plaintext", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -202,7 +181,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Mul/Vector"), func(b *testing.B) { + b.Run(name("Evaluator/Mul/Vector", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -213,7 +192,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Mul/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/Mul/Ciphertext", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 2, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -224,7 +203,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulRelin/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/MulRelin/Ciphertext", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -235,7 +214,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Scalar"), func(b *testing.B) { + b.Run(name("Evaluator/MulThenAdd/Scalar", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -246,7 +225,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Vector"), func(b *testing.B) { + b.Run(name("Evaluator/MulThenAdd/Vector", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -257,7 +236,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Plaintext"), func(b *testing.B) { + b.Run(name("Evaluator/MulThenAdd/Plaintext", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -268,7 +247,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/MulThenAdd/Ciphertext", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -279,7 +258,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulRelinThenAdd/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/MulRelinThenAdd/Ciphertext", tc, level), func(b *testing.B) { receiver := NewCiphertext(params, 2, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -290,8 +269,8 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Rotate"), func(b *testing.B) { - gk := tc.kgen.GenGaloisKeyNew(5, tc.sk) + b.Run(name("Evaluator/Rotate", tc, level), func(b *testing.B) { + gk := tc.Kgen.GenGaloisKeyNew(5, tc.Sk) evk := rlwe.NewMemEvaluationKeySet(nil, gk) eval := eval.WithKey(evk) receiver := NewCiphertext(params, 1, ciphertext2.Level()) diff --git a/schemes/bfv/bfv_test.go b/schemes/bfv/bfv_test.go index 27130342..071bb130 100644 --- a/schemes/bfv/bfv_test.go +++ b/schemes/bfv/bfv_test.go @@ -4,14 +4,12 @@ import ( "encoding/json" "flag" "fmt" - "math" "runtime" "slices" "testing" "github.com/tuneinsight/lattigo/v5/core/rlwe" "github.com/tuneinsight/lattigo/v5/ring" - "github.com/tuneinsight/lattigo/v5/utils/sampling" "github.com/stretchr/testify/require" ) @@ -19,23 +17,14 @@ import ( var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise") var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.") -func GetTestName(opname string, p Parameters, lvl int) string { - return fmt.Sprintf("%s/LogN=%d/logQ=%d/logP=%d/logT=%d/Qi=%d/Pi=%d/lvl=%d", - opname, - p.LogN(), - int(math.Round(p.LogQ())), - int(math.Round(p.LogP())), - int(math.Round(p.LogT())), - p.QCount(), - p.PCount(), - lvl) +func name(op string, tc *TestContext, lvl int) string { + return fmt.Sprintf("%s/%s/lvl=%d", op, tc, lvl) } func TestBFV(t *testing.T) { - var err error - paramsLiterals := testParams + paramsLiterals := TestParams if *flagParamString != "" { var jsonParams ParametersLiteral @@ -47,17 +36,13 @@ func TestBFV(t *testing.T) { for _, p := range paramsLiterals[:] { - for _, plaintextModulus := range testPlaintextModulus[:] { + for _, plaintextModulus := range TestPlaintextModulus[:] { p.PlaintextModulus = plaintextModulus - params, err := NewParametersFromLiteral(p) - require.NoError(t, err) + tc := NewTestContext(p) - tc, err := genTestParams(params) - require.NoError(t, err) - - for _, testSet := range []func(tc *testContext, t *testing.T){ + for _, testSet := range []func(tc *TestContext, t *testing.T){ testParameters, testEncoder, testEvaluator, @@ -69,110 +54,18 @@ func TestBFV(t *testing.T) { } } -type testContext struct { - params Parameters - ringQ *ring.Ring - ringT *ring.Ring - prng sampling.PRNG - uSampler *ring.UniformSampler - encoder *Encoder - kgen *rlwe.KeyGenerator - sk *rlwe.SecretKey - pk *rlwe.PublicKey - encryptorPk *rlwe.Encryptor - encryptorSk *rlwe.Encryptor - decryptor *rlwe.Decryptor - evaluator *Evaluator - testLevel []int -} - -func genTestParams(params Parameters) (tc *testContext, err error) { - - tc = new(testContext) - tc.params = params - - if tc.prng, err = sampling.NewPRNG(); err != nil { - return nil, err - } - - tc.ringQ = params.RingQ() - tc.ringT = params.RingT() - - tc.uSampler = ring.NewUniformSampler(tc.prng, tc.ringT) - tc.kgen = NewKeyGenerator(tc.params) - tc.sk, tc.pk = tc.kgen.GenKeyPairNew() - tc.encoder = NewEncoder(tc.params) - - tc.encryptorPk = NewEncryptor(tc.params, tc.pk) - tc.encryptorSk = NewEncryptor(tc.params, tc.sk) - tc.decryptor = NewDecryptor(tc.params, tc.sk) - tc.evaluator = NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk))) - - tc.testLevel = []int{0, params.MaxLevel()} - - return -} - -func newTestVectorsLvl(level int, scale rlwe.Scale, tc *testContext, encryptor *rlwe.Encryptor) (coeffs ring.Poly, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) { - coeffs = tc.uSampler.ReadNew() - for i := range coeffs.Coeffs[0] { - coeffs.Coeffs[0][i] = uint64(i) - } - plaintext = NewPlaintext(tc.params, level) - plaintext.Scale = scale - tc.encoder.Encode(coeffs.Coeffs[0], plaintext) - if encryptor != nil { - var err error - ciphertext, err = encryptor.EncryptNew(plaintext) - if err != nil { - panic(err) - } - } - - return coeffs, plaintext, ciphertext -} - -func verifyTestVectors(tc *testContext, decryptor *rlwe.Decryptor, coeffs ring.Poly, element rlwe.ElementInterface[ring.Poly], t *testing.T) { - - coeffsTest := make([]uint64, tc.params.MaxSlots()) - - switch el := element.(type) { - case *rlwe.Plaintext: - require.NoError(t, tc.encoder.Decode(el, coeffsTest)) - case *rlwe.Ciphertext: - - pt := decryptor.DecryptNew(el) - - require.NoError(t, tc.encoder.Decode(pt, coeffsTest)) - - if *flagPrintNoise { - require.NoError(t, tc.encoder.Encode(coeffsTest, pt)) - ct, err := tc.evaluator.SubNew(el, pt) - require.NoError(t, err) - vartmp, _, _ := rlwe.Norm(ct, decryptor) - t.Logf("STD(noise): %f\n", vartmp) - } - - default: - t.Fatal("invalid test object to verify") - } - - require.True(t, slices.Equal(coeffs.Coeffs[0], coeffsTest)) -} - -func testParameters(tc *testContext, t *testing.T) { - t.Run(GetTestName("Parameters/Marshaller/Binary", tc.params, 0), func(t *testing.T) { - - bytes, err := tc.params.MarshalBinary() +func testParameters(tc *TestContext, t *testing.T) { + t.Run(name("Parameters/Marshaller/Binary", tc, 0), func(t *testing.T) { + bytes, err := tc.Params.MarshalBinary() require.Nil(t, err) var p Parameters require.Nil(t, p.UnmarshalBinary(bytes)) - require.True(t, tc.params.Equal(&p)) + require.True(t, tc.Params.Equal(&p)) }) - t.Run(GetTestName("Parameters/Marshaller/JSON", tc.params, 0), func(t *testing.T) { + t.Run(name("Parameters/Marshaller/JSON", tc, 0), func(t *testing.T) { // checks that parameters can be marshalled without error - data, err := json.Marshal(tc.params) + data, err := json.Marshal(tc.Params) require.Nil(t, err) require.NotNil(t, data) @@ -180,10 +73,10 @@ func testParameters(tc *testContext, t *testing.T) { var paramsRec Parameters err = json.Unmarshal(data, ¶msRec) require.Nil(t, err) - require.True(t, tc.params.Equal(¶msRec)) + require.True(t, tc.Params.Equal(¶msRec)) // checks that the Parameters can be unmarshalled with log-moduli definition without error - dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537}`, tc.params.LogN())) + dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537}`, tc.Params.LogN())) var paramsWithLogModuli Parameters err = json.Unmarshal(dataWithLogModuli, ¶msWithLogModuli) require.Nil(t, err) @@ -193,7 +86,7 @@ func testParameters(tc *testContext, t *testing.T) { require.Equal(t, rlwe.DefaultXs, paramsWithLogModuli.Xs()) // Omitting Xe should result in Default being used // checks that one can provide custom parameters for the secret-key and error distributions - dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537, "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.params.LogN())) + dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537, "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.Params.LogN())) var paramsWithCustomSecrets Parameters err = json.Unmarshal(dataWithCustomSecrets, ¶msWithCustomSecrets) require.Nil(t, err) @@ -202,21 +95,22 @@ func testParameters(tc *testContext, t *testing.T) { }) } -func testEncoder(tc *testContext, t *testing.T) { +func testEncoder(tc *TestContext, t *testing.T) { + testLevels := []int{0, tc.Params.MaxLevel()} - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Encoder/Uint", tc.params, lvl), func(t *testing.T) { - values, plaintext, _ := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, nil) - verifyTestVectors(tc, nil, values, plaintext, t) + for _, lvl := range testLevels { + t.Run(name("Encoder/Uint", tc, lvl), func(t *testing.T) { + values, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, plaintext, values, t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Encoder/Int", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevels { + t.Run(name("Encoder/Int", tc, lvl), func(t *testing.T) { - T := tc.params.PlaintextModulus() + T := tc.Params.PlaintextModulus() THalf := T >> 1 - coeffs := tc.uSampler.ReadNew() + coeffs := tc.Sampler.ReadNew() coeffsInt := make([]int64, len(coeffs.Coeffs[0])) for i, c := range coeffs.Coeffs[0] { c %= T @@ -227,317 +121,335 @@ func testEncoder(tc *testContext, t *testing.T) { } } - plaintext := NewPlaintext(tc.params, lvl) - tc.encoder.Encode(coeffsInt, plaintext) - have := make([]int64, tc.params.MaxSlots()) - tc.encoder.Decode(plaintext, have) + plaintext := NewPlaintext(tc.Params, lvl) + tc.Ecd.Encode(coeffsInt, plaintext) + have := make([]int64, tc.Params.MaxSlots()) + tc.Ecd.Decode(plaintext, have) require.True(t, slices.Equal(coeffsInt, have)) }) } } -func testEvaluator(tc *testContext, t *testing.T) { +func testEvaluator(tc *TestContext, t *testing.T) { + testLevels := []int{0, tc.Params.MaxLevel()} t.Run("Evaluator", func(t *testing.T) { - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Add/Ct/Ct/New", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevels { + t.Run(name("Add/Ct/Ct/New", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - ciphertext2, err := tc.evaluator.AddNew(ciphertext0, ciphertext1) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + + ciphertext2, err := tc.Evl.AddNew(ciphertext0, ciphertext1) require.NoError(t, err) - tc.ringT.Add(values0, values1, values0) - - verifyTestVectors(tc, tc.decryptor, values0, ciphertext2, t) + tc.Params.RingT().Add(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Add/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevels { + t.Run(name("Add/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - require.NoError(t, tc.evaluator.Add(ciphertext0, ciphertext1, ciphertext0)) - tc.ringT.Add(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Add(ciphertext0, ciphertext1, ciphertext0)) + tc.Params.RingT().Add(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Add/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevels { + t.Run(name("Add/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0) - require.NoError(t, tc.evaluator.Add(ciphertext0, plaintext, ciphertext0)) - tc.ringT.Add(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Add(ciphertext0, plaintext, ciphertext0)) + tc.Params.RingT().Add(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Add/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevels { + t.Run(name("Add/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) { + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + scalar := tc.Params.PlaintextModulus() >> 1 - scalar := tc.params.PlaintextModulus() >> 1 + p := ring.Poly{Coeffs: [][]uint64{values}} - require.NoError(t, tc.evaluator.Add(ciphertext, scalar, ciphertext)) - tc.ringT.AddScalar(values, scalar, values) - - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + require.NoError(t, tc.Evl.Add(ciphertext, scalar, ciphertext)) + tc.Params.RingT().AddScalar(p, scalar, p) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Sub/Ct/Ct/New", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevels { + t.Run(name("Sub/Ct/Ct/New", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - ciphertext0, err := tc.evaluator.SubNew(ciphertext0, ciphertext1) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + + ciphertext0, err := tc.Evl.SubNew(ciphertext0, ciphertext1) require.NoError(t, err) - tc.ringT.Sub(values0, values1, values0) - - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + tc.Params.RingT().Sub(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Sub/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevels { + t.Run(name("Sub/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - require.NoError(t, tc.evaluator.Sub(ciphertext0, ciphertext1, ciphertext0)) - tc.ringT.Sub(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Sub(ciphertext0, ciphertext1, ciphertext0)) + tc.Params.RingT().Sub(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Sub/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevels { + t.Run(name("Sub/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0) - require.NoError(t, tc.evaluator.Sub(ciphertext0, plaintext, ciphertext0)) - tc.ringT.Sub(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Sub(ciphertext0, plaintext, ciphertext0)) + tc.Params.RingT().Sub(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Mul/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevels { + t.Run(name("Mul/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - require.NoError(t, tc.evaluator.Mul(ciphertext0, ciphertext1, ciphertext0)) - tc.ringT.MulCoeffsBarrett(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Mul(ciphertext0, ciphertext1, ciphertext0)) + tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Mul/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevels { + t.Run(name("Mul/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) - require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0) + require.True(t, ciphertext.Scale.Cmp(plaintext.Scale) != 0) - require.NoError(t, tc.evaluator.Mul(ciphertext0, plaintext, ciphertext0)) - tc.ringT.MulCoeffsBarrett(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Mul(ciphertext, plaintext, ciphertext)) + tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Mul/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevels { + t.Run(name("Mul/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Level = 0") } - values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - scalar := tc.params.PlaintextModulus() >> 1 + scalar := tc.Params.PlaintextModulus() >> 1 - tc.evaluator.Mul(ciphertext, scalar, ciphertext) - tc.ringT.MulScalar(values, scalar, values) + p := ring.Poly{Coeffs: [][]uint64{values}} - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + require.NoError(t, tc.Evl.Mul(ciphertext, scalar, ciphertext)) + tc.Params.RingT().MulScalar(p, scalar, p) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Square/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevels { + t.Run(name("Square/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - require.NoError(t, tc.evaluator.Mul(ciphertext0, ciphertext0, ciphertext0)) - tc.ringT.MulCoeffsBarrett(values0, values0, values0) + p := ring.Poly{Coeffs: [][]uint64{values}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Mul(ciphertext, ciphertext, ciphertext)) + tc.Params.RingT().MulCoeffsBarrett(p, p, p) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("MulRelin/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevels { + t.Run(name("MulRelin/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) - tc.ringT.MulCoeffsBarrett(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + + tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - receiver := NewCiphertext(tc.params, 1, lvl) + receiver := NewCiphertext(tc.Params, 1, lvl) - require.NoError(t, tc.evaluator.MulRelin(ciphertext0, ciphertext1, receiver)) - - require.NoError(t, tc.evaluator.Rescale(receiver, receiver)) - - verifyTestVectors(tc, tc.decryptor, values0, receiver, t) + require.NoError(t, tc.Evl.MulRelin(ciphertext0, ciphertext1, receiver)) + require.NoError(t, tc.Evl.Rescale(receiver, receiver)) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, receiver, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("MulThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevels { + t.Run(name("MulThenAdd/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk) - values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2)) + values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) + + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + p2 := ring.Poly{Coeffs: [][]uint64{values2}} require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0) - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, ciphertext1, ciphertext2)) - tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2) - - verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, ciphertext1, ciphertext2)) + tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("MulThenAdd/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevels { + t.Run(name("MulThenAdd/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) - values1, plaintext1, _ := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk) - values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + values1, plaintext1, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2)) + values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) + + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + p2 := ring.Poly{Coeffs: [][]uint64{values2}} require.True(t, ciphertext0.Scale.Cmp(plaintext1.Scale) != 0) require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0) - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, plaintext1, ciphertext2)) - tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2) - - verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, plaintext1, ciphertext2)) + tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("MulThenAdd/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevels { + t.Run(name("MulThenAdd/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) + + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - scalar := tc.params.PlaintextModulus() >> 1 + scalar := tc.Params.PlaintextModulus() >> 1 - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, scalar, ciphertext1)) - tc.ringT.MulScalarThenAdd(values0, scalar, values1) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, scalar, ciphertext1)) + tc.Params.RingT().MulScalarThenAdd(p0, scalar, p1) - verifyTestVectors(tc, tc.decryptor, values1, ciphertext1, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext1, p1.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("MulRelinThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevels { + t.Run(name("MulRelinThenAdd/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk) - values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2)) + values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0) - require.NoError(t, tc.evaluator.MulRelinThenAdd(ciphertext0, ciphertext1, ciphertext2)) - tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + p2 := ring.Poly{Coeffs: [][]uint64{values2}} - verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t) + require.NoError(t, tc.Evl.MulRelinThenAdd(ciphertext0, ciphertext1, ciphertext2)) + tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t) }) } }) diff --git a/schemes/bfv/test_parameters.go b/schemes/bfv/test_parameters.go deleted file mode 100644 index e7128fa0..00000000 --- a/schemes/bfv/test_parameters.go +++ /dev/null @@ -1,15 +0,0 @@ -package bfv - -var ( - - // testInsecure are insecure parameters used for the sole purpose of fast testing. - testInsecure = ParametersLiteral{ - LogN: 10, - Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001}, - P: []uint64{0x7fffffd8001}, - } - - testPlaintextModulus = []uint64{0x101, 0xffc001} - - testParams = []ParametersLiteral{testInsecure} -) diff --git a/schemes/bfv/test_utils.go b/schemes/bfv/test_utils.go new file mode 100644 index 00000000..b3432f60 --- /dev/null +++ b/schemes/bfv/test_utils.go @@ -0,0 +1,120 @@ +package bfv + +import ( + "fmt" + "math" + "slices" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/tuneinsight/lattigo/v5/core/rlwe" + "github.com/tuneinsight/lattigo/v5/ring" + "github.com/tuneinsight/lattigo/v5/utils/sampling" +) + +type TestContext struct { + Params Parameters + Ecd *Encoder + + Prng sampling.PRNG + Sampler *ring.UniformSampler + + Kgen *rlwe.KeyGenerator + Sk *rlwe.SecretKey + Pk *rlwe.PublicKey + + Enc *rlwe.Encryptor + Dec *rlwe.Decryptor + + Evl *Evaluator +} + +func NewTestContext(params ParametersLiteral) *TestContext { + tc := new(TestContext) + + var err error + + tc.Params, err = NewParametersFromLiteral(params) + if err != nil { + panic(err) + } + tc.Ecd = NewEncoder(tc.Params) + + tc.Prng, err = sampling.NewPRNG() + if err != nil { + panic(err) + } + tc.Sampler = ring.NewUniformSampler(tc.Prng, tc.Params.RingT()) + + tc.Kgen = rlwe.NewKeyGenerator(tc.Params) + tc.Sk, tc.Pk = tc.Kgen.GenKeyPairNew() + + tc.Enc = rlwe.NewEncryptor(tc.Params, tc.Pk) + tc.Dec = rlwe.NewDecryptor(tc.Params, tc.Sk) + + tc.Evl = NewEvaluator(tc.Params, rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk))) + + return tc +} + +func (tc TestContext) String() string { + return fmt.Sprintf("LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d", + tc.Params.LogN(), + int(math.Round(tc.Params.LogQ())), + int(math.Round(tc.Params.LogP())), + tc.Params.LogMaxDimensions().Rows, + tc.Params.LogMaxDimensions().Cols, + int(math.Round(tc.Params.LogT())), + tc.Params.QCount(), + tc.Params.PCount()) +} + +func VerifyTestVectors(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, have interface{}, want []uint64, t *testing.T) { + values := make([]uint64, params.MaxSlots()) + + switch have := have.(type) { + case *rlwe.Plaintext: + require.NoError(t, encoder.Decode(have, values)) + case *rlwe.Ciphertext: + require.NoError(t, encoder.Decode(decryptor.DecryptNew(have), values)) + default: + t.Error("invalid unsupported test object type") + } + + require.True(t, slices.Equal(values, want)) +} + +func NewTestVector(params Parameters, encoder *Encoder, encryptor *rlwe.Encryptor, level int, scale rlwe.Scale) (values []uint64, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) { + values = make([]uint64, params.MaxSlots()) + for i := range values { + values[i] = sampling.RandUint64() % params.PlaintextModulus() + } + + pt = NewPlaintext(params, level) + pt.Scale = scale + if err := encoder.Encode(values, pt); err != nil { + panic(err) + } + if encryptor != nil { + var err error + ct, err = encryptor.EncryptNew(pt) + if err != nil { + panic(err) + } + } + return +} + +var ( + // testInsecure are insecure parameters used for the sole purpose of fast testing. + TestInsecure = ParametersLiteral{ + LogN: 10, + Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001}, + P: []uint64{0x7fffffd8001}, + } + + TestPlaintextModulus = []uint64{0x101, 0xffc001} + + TestParams = []ParametersLiteral{TestInsecure} +) diff --git a/schemes/bgv/bgv_benchmark_test.go b/schemes/bgv/bgv_benchmark_test.go index ff25f247..8eac0608 100644 --- a/schemes/bgv/bgv_benchmark_test.go +++ b/schemes/bgv/bgv_benchmark_test.go @@ -2,24 +2,13 @@ package bgv import ( "encoding/json" - "fmt" "runtime" "testing" "github.com/tuneinsight/lattigo/v5/core/rlwe" ) -func GetBenchName(params Parameters, opname string) string { - return fmt.Sprintf("%s/logN=%d/Qi=%d/Pi=%d/LogSlots=%d", - opname, - params.LogN(), - params.QCount(), - params.PCount(), - params.LogMaxSlots()) -} - func BenchmarkBGV(b *testing.B) { - var err error var testParams []ParametersLiteral @@ -41,20 +30,9 @@ func BenchmarkBGV(b *testing.B) { } for _, paramsLiteral := range testParams { + tc := NewTestContext(paramsLiteral) - var params Parameters - if params, err = NewParametersFromLiteral(paramsLiteral); err != nil { - b.Error(err) - b.Fail() - } - - var tc *testContext - if tc, err = genTestParams(params); err != nil { - b.Error(err) - b.Fail() - } - - for _, testSet := range []func(tc *testContext, b *testing.B){ + for _, testSet := range []func(tc *TestContext, b *testing.B){ benchEncoder, benchEvaluator, } { @@ -64,11 +42,12 @@ func BenchmarkBGV(b *testing.B) { } } -func benchEncoder(tc *testContext, b *testing.B) { +func benchEncoder(tc *TestContext, b *testing.B) { - params := tc.params + params := tc.Params + lvl := params.MaxLevel() - poly := tc.uSampler.ReadNew() + poly := tc.Sampler.ReadNew() params.RingT().Reduce(poly, poly) coeffsUint64 := poly.Coeffs[0] coeffsInt64 := make([]int64, len(coeffsUint64)) @@ -76,10 +55,10 @@ func benchEncoder(tc *testContext, b *testing.B) { coeffsInt64[i] = int64(coeffsUint64[i]) } - encoder := tc.encoder + encoder := tc.Ecd - b.Run(GetBenchName(params, "Encoder/Encode/Uint"), func(b *testing.B) { - plaintext := NewPlaintext(params, params.MaxLevel()) + b.Run(name("Encoder/Encode/Uint", tc, lvl), func(b *testing.B) { + plaintext := NewPlaintext(params, lvl) b.ResetTimer() for i := 0; i < b.N; i++ { if err := encoder.Encode(coeffsUint64, plaintext); err != nil { @@ -89,8 +68,8 @@ func benchEncoder(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Encoder/Encode/Int"), func(b *testing.B) { - plaintext := NewPlaintext(params, params.MaxLevel()) + b.Run(name("Encoder/Encode/Int", tc, lvl), func(b *testing.B) { + plaintext := NewPlaintext(params, lvl) b.ResetTimer() for i := 0; i < b.N; i++ { if err := encoder.Encode(coeffsInt64, plaintext); err != nil { @@ -100,8 +79,8 @@ func benchEncoder(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Encoder/Decode/Uint"), func(b *testing.B) { - plaintext := NewPlaintext(params, params.MaxLevel()) + b.Run(name("Encoder/Decode/Uint", tc, lvl), func(b *testing.B) { + plaintext := NewPlaintext(params, lvl) b.ResetTimer() for i := 0; i < b.N; i++ { if err := encoder.Decode(plaintext, coeffsUint64); err != nil { @@ -111,8 +90,8 @@ func benchEncoder(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Encoder/Decode/Int"), func(b *testing.B) { - plaintext := NewPlaintext(params, params.MaxLevel()) + b.Run(name("Encoder/Decode/Int", tc, lvl), func(b *testing.B) { + plaintext := NewPlaintext(params, lvl) b.ResetTimer() for i := 0; i < b.N; i++ { if err := encoder.Decode(plaintext, coeffsInt64); err != nil { @@ -123,16 +102,17 @@ func benchEncoder(tc *testContext, b *testing.B) { }) } -func benchEvaluator(tc *testContext, b *testing.B) { +func benchEvaluator(tc *TestContext, b *testing.B) { - params := tc.params - eval := tc.evaluator + params := tc.Params + lvl := params.MaxLevel() + eval := tc.Evl - plaintext := NewPlaintext(params, params.MaxLevel()) - plaintext.Value = rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 0, plaintext.Level()).Value[0] + plaintext := NewPlaintext(params, lvl) + plaintext.Value = rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 0, plaintext.Level()).Value[0] - ciphertext1 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel()) - ciphertext2 := rlwe.NewCiphertextRandom(tc.prng, params.Parameters, 1, params.MaxLevel()) + ciphertext1 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, lvl) + ciphertext2 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, lvl) scalar := params.PlaintextModulus() >> 1 *ciphertext1.MetaData = *plaintext.MetaData @@ -140,7 +120,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { vector := plaintext.Value.Coeffs[0][:params.MaxSlots()] - b.Run(GetBenchName(params, "Evaluator/Add/Scalar"), func(b *testing.B) { + b.Run(name("Evaluator/Add/Scalar", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -151,7 +131,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Add/Vector"), func(b *testing.B) { + b.Run(name("Evaluator/Add/Vector", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -162,7 +142,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Add/Plaintext"), func(b *testing.B) { + b.Run(name("Evaluator/Add/Plaintext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -173,7 +153,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Add/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/Add/Ciphertext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -184,7 +164,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Mul/Scalar"), func(b *testing.B) { + b.Run(name("Evaluator/Mul/Scalar", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -195,7 +175,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Mul/Plaintext"), func(b *testing.B) { + b.Run(name("Evaluator/Mul/Plaintext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -206,7 +186,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Mul/Vector"), func(b *testing.B) { + b.Run(name("Evaluator/Mul/Vector", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -217,7 +197,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Mul/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/Mul/Ciphertext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 2, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -228,7 +208,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulRelin/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/MulRelin/Ciphertext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -239,7 +219,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulInvariant/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/MulInvariant/Ciphertext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 2, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -250,7 +230,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulRelinInvariant/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/MulRelinInvariant/Ciphertext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -261,7 +241,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Scalar"), func(b *testing.B) { + b.Run(name("Evaluator/MulThenAdd/Scalar", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -272,7 +252,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Vector"), func(b *testing.B) { + b.Run(name("Evaluator/MulThenAdd/Vector", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -283,7 +263,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Plaintext"), func(b *testing.B) { + b.Run(name("Evaluator/MulThenAdd/Plaintext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -294,7 +274,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulThenAdd/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/MulThenAdd/Ciphertext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -305,7 +285,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/MulRelinThenAdd/Ciphertext"), func(b *testing.B) { + b.Run(name("Evaluator/MulRelinThenAdd/Ciphertext", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 2, ciphertext1.Level()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -316,7 +296,7 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Rescale"), func(b *testing.B) { + b.Run(name("Evaluator/Rescale", tc, lvl), func(b *testing.B) { receiver := NewCiphertext(params, 1, ciphertext1.Level()-1) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -327,8 +307,8 @@ func benchEvaluator(tc *testContext, b *testing.B) { } }) - b.Run(GetBenchName(params, "Evaluator/Rotate"), func(b *testing.B) { - gk := tc.kgen.GenGaloisKeyNew(5, tc.sk) + b.Run(name("Evaluator/Rotate", tc, lvl), func(b *testing.B) { + gk := tc.Kgen.GenGaloisKeyNew(5, tc.Sk) evk := rlwe.NewMemEvaluationKeySet(nil, gk) eval := eval.WithKey(evk) receiver := NewCiphertext(params, 1, ciphertext2.Level()) diff --git a/schemes/bgv/bgv_test.go b/schemes/bgv/bgv_test.go index f4e22917..1b7e9d71 100644 --- a/schemes/bgv/bgv_test.go +++ b/schemes/bgv/bgv_test.go @@ -4,40 +4,28 @@ import ( "encoding/json" "flag" "fmt" - "math" "runtime" "slices" "testing" + "github.com/stretchr/testify/require" + "github.com/tuneinsight/lattigo/v5/core/rlwe" "github.com/tuneinsight/lattigo/v5/ring" - - "github.com/stretchr/testify/require" - "github.com/tuneinsight/lattigo/v5/utils/sampling" ) var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise") var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short.") -func GetTestName(opname string, p Parameters, lvl int) string { - return fmt.Sprintf("%s/LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d/lvl=%d", - opname, - p.LogN(), - int(math.Round(p.LogQ())), - int(math.Round(p.LogP())), - p.LogMaxDimensions().Rows, - p.LogMaxDimensions().Cols, - int(math.Round(p.LogT())), - p.QCount(), - p.PCount(), - lvl) +func name(op string, tc *TestContext, lvl int) string { + return fmt.Sprintf("%s/%s/lvl=%d", op, tc, lvl) } func TestBGV(t *testing.T) { var err error - paramsLiterals := testParams + paramsLiterals := TestParams if *flagParamString != "" { var jsonParams ParametersLiteral @@ -49,23 +37,13 @@ func TestBGV(t *testing.T) { for _, p := range paramsLiterals[:] { - for _, plaintextModulus := range testPlaintextModulus[:] { + for _, plaintextModulus := range TestPlaintextModulus[:] { p.PlaintextModulus = plaintextModulus - var params Parameters - if params, err = NewParametersFromLiteral(p); err != nil { - t.Error(err) - t.Fail() - } + tc := NewTestContext(p) - var tc *testContext - if tc, err = genTestParams(params); err != nil { - t.Error(err) - t.Fail() - } - - for _, testSet := range []func(tc *testContext, t *testing.T){ + for _, testSet := range []func(tc *TestContext, t *testing.T){ testParameters, testEncoder, testEvaluator, @@ -77,114 +55,20 @@ func TestBGV(t *testing.T) { } } -type testContext struct { - params Parameters - ringQ *ring.Ring - ringT *ring.Ring - prng sampling.PRNG - uSampler *ring.UniformSampler - encoder *Encoder - kgen *rlwe.KeyGenerator - sk *rlwe.SecretKey - pk *rlwe.PublicKey - encryptorPk *rlwe.Encryptor - encryptorSk *rlwe.Encryptor - decryptor *rlwe.Decryptor - evaluator *Evaluator - testLevel []int -} +func testParameters(tc *TestContext, t *testing.T) { + t.Run(name("Parameters/Binary", tc, 0), func(t *testing.T) { -func genTestParams(params Parameters) (tc *testContext, err error) { - - tc = new(testContext) - tc.params = params - - if tc.prng, err = sampling.NewPRNG(); err != nil { - return nil, err - } - - tc.ringQ = params.RingQ() - tc.ringT = params.RingT() - - tc.uSampler = ring.NewUniformSampler(tc.prng, tc.ringT) - tc.kgen = NewKeyGenerator(tc.params) - tc.sk, tc.pk = tc.kgen.GenKeyPairNew() - tc.encoder = NewEncoder(tc.params) - - tc.encryptorPk = NewEncryptor(tc.params, tc.pk) - tc.encryptorSk = NewEncryptor(tc.params, tc.sk) - tc.decryptor = NewDecryptor(tc.params, tc.sk) - tc.evaluator = NewEvaluator(tc.params, rlwe.NewMemEvaluationKeySet(tc.kgen.GenRelinearizationKeyNew(tc.sk))) - - tc.testLevel = []int{0, params.MaxLevel()} - - return -} - -func newTestVectorsLvl(level int, scale rlwe.Scale, tc *testContext, encryptor *rlwe.Encryptor) (coeffs ring.Poly, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) { - coeffs = tc.uSampler.ReadNew() - for i := range coeffs.Coeffs[0] { - coeffs.Coeffs[0][i] = uint64(i) - } - - plaintext = NewPlaintext(tc.params, level) - plaintext.Scale = scale - if err := tc.encoder.Encode(coeffs.Coeffs[0], plaintext); err != nil { - panic(err) - } - if encryptor != nil { - var err error - ciphertext, err = encryptor.EncryptNew(plaintext) - if err != nil { - panic(err) - } - } - - return coeffs, plaintext, ciphertext -} - -func verifyTestVectors(tc *testContext, decryptor *rlwe.Decryptor, coeffs ring.Poly, element rlwe.ElementInterface[ring.Poly], t *testing.T) { - - coeffsTest := make([]uint64, tc.params.MaxSlots()) - - switch el := element.(type) { - case *rlwe.Plaintext: - require.NoError(t, tc.encoder.Decode(el, coeffsTest)) - case *rlwe.Ciphertext: - - pt := decryptor.DecryptNew(el) - - require.NoError(t, tc.encoder.Decode(pt, coeffsTest)) - - if *flagPrintNoise { - require.NoError(t, tc.encoder.Encode(coeffsTest, pt)) - ct, err := tc.evaluator.SubNew(el, pt) - require.NoError(t, err) - vartmp, _, _ := rlwe.Norm(ct, decryptor) - t.Logf("STD(noise): %f\n", vartmp) - } - - default: - t.Error("invalid test object to verify") - } - - require.True(t, slices.Equal(coeffs.Coeffs[0], coeffsTest)) -} - -func testParameters(tc *testContext, t *testing.T) { - t.Run(GetTestName("Parameters/Binary", tc.params, 0), func(t *testing.T) { - - bytes, err := tc.params.MarshalBinary() + bytes, err := tc.Params.MarshalBinary() require.Nil(t, err) var p Parameters require.Nil(t, p.UnmarshalBinary(bytes)) - require.True(t, tc.params.Equal(&p)) + require.True(t, tc.Params.Equal(&p)) }) - t.Run(GetTestName("Parameters/JSON", tc.params, 0), func(t *testing.T) { + t.Run(name("Parameters/JSON", tc, 0), func(t *testing.T) { // checks that parameters can be marshalled without error - data, err := json.Marshal(tc.params) + data, err := json.Marshal(tc.Params) require.Nil(t, err) require.NotNil(t, data) @@ -192,10 +76,10 @@ func testParameters(tc *testContext, t *testing.T) { var paramsRec Parameters err = json.Unmarshal(data, ¶msRec) require.Nil(t, err) - require.True(t, tc.params.Equal(¶msRec)) + require.True(t, tc.Params.Equal(¶msRec)) // checks that the Parameters can be unmarshalled with log-moduli definition without error - dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537}`, tc.params.LogN())) + dataWithLogModuli := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537}`, tc.Params.LogN())) var paramsWithLogModuli Parameters err = json.Unmarshal(dataWithLogModuli, ¶msWithLogModuli) require.Nil(t, err) @@ -205,7 +89,7 @@ func testParameters(tc *testContext, t *testing.T) { require.Equal(t, rlwe.DefaultXs, paramsWithLogModuli.Xs()) // Omitting Xe should result in Default being used // checks that one can provide custom parameters for the secret-key and error distributions - dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537, "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.params.LogN())) + dataWithCustomSecrets := []byte(fmt.Sprintf(`{"LogN":%d,"LogQ":[50,50],"LogP":[60], "PlaintextModulus":65537, "Xs": {"Type": "Ternary", "H": 192}, "Xe": {"Type": "DiscreteGaussian", "Sigma": 6.6, "Bound": 39.6}}`, tc.Params.LogN())) var paramsWithCustomSecrets Parameters err = json.Unmarshal(dataWithCustomSecrets, ¶msWithCustomSecrets) require.Nil(t, err) @@ -214,21 +98,22 @@ func testParameters(tc *testContext, t *testing.T) { }) } -func testEncoder(tc *testContext, t *testing.T) { +func testEncoder(tc *TestContext, t *testing.T) { + testLevel := [2]int{0, tc.Params.MaxLevel()} - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Encoder/Uint/IsBatched=true", tc.params, lvl), func(t *testing.T) { - values, plaintext, _ := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, nil) - verifyTestVectors(tc, nil, values, plaintext, t) + for _, lvl := range testLevel { + t.Run(name("Encoder/Uint/IsBatched=true", tc, lvl), func(t *testing.T) { + values, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, plaintext, values, t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Encoder/Int/IsBatched=true", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevel { + t.Run(name("Encoder/Int/IsBatched=true", tc, lvl), func(t *testing.T) { - T := tc.params.PlaintextModulus() + T := tc.Params.PlaintextModulus() THalf := T >> 1 - poly := tc.uSampler.ReadNew() + poly := tc.Sampler.ReadNew() coeffs := make([]int64, poly.N()) for i, c := range poly.Coeffs[0] { c %= T @@ -239,38 +124,38 @@ func testEncoder(tc *testContext, t *testing.T) { } } - plaintext := NewPlaintext(tc.params, lvl) - tc.encoder.Encode(coeffs, plaintext) - have := make([]int64, tc.params.MaxSlots()) - tc.encoder.Decode(plaintext, have) + plaintext := NewPlaintext(tc.Params, lvl) + tc.Ecd.Encode(coeffs, plaintext) + have := make([]int64, tc.Params.MaxSlots()) + tc.Ecd.Decode(plaintext, have) require.True(t, slices.Equal(coeffs, have)) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Encoder/Uint/IsBatched=false", tc.params, lvl), func(t *testing.T) { - T := tc.params.PlaintextModulus() - poly := tc.uSampler.ReadNew() + for _, lvl := range testLevel { + t.Run(name("Encoder/Uint/IsBatched=false", tc, lvl), func(t *testing.T) { + T := tc.Params.PlaintextModulus() + poly := tc.Sampler.ReadNew() coeffs := make([]uint64, poly.N()) for i, c := range poly.Coeffs[0] { coeffs[i] = c % T } - plaintext := NewPlaintext(tc.params, lvl) + plaintext := NewPlaintext(tc.Params, lvl) plaintext.IsBatched = false - require.NoError(t, tc.encoder.Encode(coeffs, plaintext)) - have := make([]uint64, tc.params.MaxSlots()) - require.NoError(t, tc.encoder.Decode(plaintext, have)) + require.NoError(t, tc.Ecd.Encode(coeffs, plaintext)) + have := make([]uint64, tc.Params.MaxSlots()) + require.NoError(t, tc.Ecd.Decode(plaintext, have)) require.True(t, slices.Equal(coeffs, have)) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Encoder/Int/IsBatched=false", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevel { + t.Run(name("Encoder/Int/IsBatched=false", tc, lvl), func(t *testing.T) { - T := tc.params.PlaintextModulus() + T := tc.Params.PlaintextModulus() THalf := T >> 1 - poly := tc.uSampler.ReadNew() + poly := tc.Sampler.ReadNew() coeffs := make([]int64, poly.N()) for i, c := range poly.Coeffs[0] { c %= T @@ -281,438 +166,468 @@ func testEncoder(tc *testContext, t *testing.T) { } } - plaintext := NewPlaintext(tc.params, lvl) + plaintext := NewPlaintext(tc.Params, lvl) plaintext.IsBatched = false - require.NoError(t, tc.encoder.Encode(coeffs, plaintext)) - have := make([]int64, tc.params.MaxSlots()) - require.NoError(t, tc.encoder.Decode(plaintext, have)) + require.NoError(t, tc.Ecd.Encode(coeffs, plaintext)) + have := make([]int64, tc.Params.MaxSlots()) + require.NoError(t, tc.Ecd.Decode(plaintext, have)) require.True(t, slices.Equal(coeffs, have)) }) } } -func testEvaluator(tc *testContext, t *testing.T) { +func testEvaluator(tc *TestContext, t *testing.T) { + testLevel := [2]int{0, tc.Params.MaxLevel()} - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Add/Ct/Ct/New", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevel { + t.Run(name("Evaluator/Add/Ct/Ct/New", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - ciphertext2, err := tc.evaluator.AddNew(ciphertext0, ciphertext1) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + + ciphertext2, err := tc.Evl.AddNew(ciphertext0, ciphertext1) require.NoError(t, err) - tc.ringT.Add(values0, values1, values0) - - verifyTestVectors(tc, tc.decryptor, values0, ciphertext2, t) + tc.Params.RingT().Add(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Add/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevel { + t.Run(name("Evaluator/Add/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - require.NoError(t, tc.evaluator.Add(ciphertext0, ciphertext1, ciphertext0)) - tc.ringT.Add(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Add(ciphertext0, ciphertext1, ciphertext0)) + tc.Params.RingT().Add(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Add/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevel { + t.Run(name("Evaluator/Add/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0) - require.NoError(t, tc.evaluator.Add(ciphertext0, plaintext, ciphertext0)) - tc.ringT.Add(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Add(ciphertext0, plaintext, ciphertext0)) + tc.Params.RingT().Add(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Add/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevel { + t.Run(name("Evaluator/Add/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) { + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + scalar := tc.Params.PlaintextModulus() >> 1 - scalar := tc.params.PlaintextModulus() >> 1 + p := ring.Poly{Coeffs: [][]uint64{values}} - require.NoError(t, tc.evaluator.Add(ciphertext, scalar, ciphertext)) - tc.ringT.AddScalar(values, scalar, values) - - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + require.NoError(t, tc.Evl.Add(ciphertext, scalar, ciphertext)) + tc.Params.RingT().AddScalar(p, scalar, p) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Add/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevel { + t.Run(name("Evaluator/Add/Ct/Vector/Inplace", tc, lvl), func(t *testing.T) { + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + p := ring.Poly{Coeffs: [][]uint64{values}} - require.NoError(t, tc.evaluator.Add(ciphertext, values.Coeffs[0], ciphertext)) - tc.ringT.Add(values, values, values) - - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + require.NoError(t, tc.Evl.Add(ciphertext, values, ciphertext)) + tc.Params.RingT().Add(p, p, p) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Sub/Ct/Ct/New", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevel { + t.Run(name("Evaluator/Sub/Ct/Ct/New", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - ciphertext0, err := tc.evaluator.SubNew(ciphertext0, ciphertext1) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + + ciphertext0, err := tc.Evl.SubNew(ciphertext0, ciphertext1) require.NoError(t, err) - tc.ringT.Sub(values0, values1, values0) - - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + tc.Params.RingT().Sub(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Sub/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevel { + t.Run(name("Evaluator/Sub/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - require.NoError(t, tc.evaluator.Sub(ciphertext0, ciphertext1, ciphertext0)) - tc.ringT.Sub(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Sub(ciphertext0, ciphertext1, ciphertext0)) + tc.Params.RingT().Sub(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Sub/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + for _, lvl := range testLevel { + t.Run(name("Evaluator/Sub/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) { + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0) - require.NoError(t, tc.evaluator.Sub(ciphertext0, plaintext, ciphertext0)) - tc.ringT.Sub(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Sub(ciphertext0, plaintext, ciphertext0)) + tc.Params.RingT().Sub(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Sub/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevel { + t.Run(name("Evaluator/Sub/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) { + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + scalar := tc.Params.PlaintextModulus() >> 1 - scalar := tc.params.PlaintextModulus() >> 1 + p := ring.Poly{Coeffs: [][]uint64{values}} - require.NoError(t, tc.evaluator.Sub(ciphertext, scalar, ciphertext)) - tc.ringT.SubScalar(values, scalar, values) - - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + require.NoError(t, tc.Evl.Sub(ciphertext, scalar, ciphertext)) + tc.Params.RingT().SubScalar(p, scalar, p) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Sub/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevel { + t.Run(name("Evaluator/Sub/Ct/Vector/Inplace", tc, lvl), func(t *testing.T) { + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + p := ring.Poly{Coeffs: [][]uint64{values}} - require.NoError(t, tc.evaluator.Sub(ciphertext, values.Coeffs[0], ciphertext)) - tc.ringT.Sub(values, values, values) - - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + require.NoError(t, tc.Evl.Sub(ciphertext, values, ciphertext)) + tc.Params.RingT().Sub(p, p, p) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Mul/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevel { + t.Run(name("Evaluator/Mul/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - require.NoError(t, tc.evaluator.Mul(ciphertext0, ciphertext1, ciphertext0)) - tc.ringT.MulCoeffsBarrett(values0, values1, values0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Mul(ciphertext0, ciphertext1, ciphertext0)) + tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Mul/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevel { + t.Run(name("Evaluator/Mul/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) { + if lvl == 0 { + t.Skip("Skipping: Level = 0") + } + + values0, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, plaintext, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) + + require.True(t, ciphertext.Scale.Cmp(plaintext.Scale) != 0) + + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + + require.NoError(t, tc.Evl.Mul(ciphertext, plaintext, ciphertext)) + tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0) + + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p0.Coeffs[0], t) + }) + } + + for _, lvl := range testLevel { + t.Run(name("Evaluator/Mul/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) { + if lvl == 0 { + t.Skip("Skipping: Level = 0") + } + + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + + scalar := tc.Params.PlaintextModulus() >> 1 + + p := ring.Poly{Coeffs: [][]uint64{values}} + + require.NoError(t, tc.Evl.Mul(ciphertext, scalar, ciphertext)) + tc.Params.RingT().MulScalar(p, scalar, p) + + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) + }) + } + + for _, lvl := range testLevel { + t.Run(name("Evaluator/Mul/Ct/Vector/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, plaintext, _ := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - require.True(t, ciphertext0.Scale.Cmp(plaintext.Scale) != 0) + p := ring.Poly{Coeffs: [][]uint64{values}} - require.NoError(t, tc.evaluator.Mul(ciphertext0, plaintext, ciphertext0)) - tc.ringT.MulCoeffsBarrett(values0, values1, values0) - - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + require.NoError(t, tc.Evl.Mul(ciphertext, values, ciphertext)) + tc.Params.RingT().MulCoeffsBarrett(p, p, p) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Mul/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevel { + t.Run(name("Evaluator/Square/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + values, _, ciphertext := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) - scalar := tc.params.PlaintextModulus() >> 1 + p := ring.Poly{Coeffs: [][]uint64{values}} - require.NoError(t, tc.evaluator.Mul(ciphertext, scalar, ciphertext)) - tc.ringT.MulScalar(values, scalar, values) + require.NoError(t, tc.Evl.Mul(ciphertext, ciphertext, ciphertext)) + tc.Params.RingT().MulCoeffsBarrett(p, p, p) - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext, p.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Mul/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevel { + t.Run(name("Evaluator/MulRelin/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) - require.NoError(t, tc.evaluator.Mul(ciphertext, values.Coeffs[0], ciphertext)) - tc.ringT.MulCoeffsBarrett(values, values, values) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - verifyTestVectors(tc, tc.decryptor, values, ciphertext, t) - }) - } - - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/Square/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - - if lvl == 0 { - t.Skip("Skipping: Level = 0") - } - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - - require.NoError(t, tc.evaluator.Mul(ciphertext0, ciphertext0, ciphertext0)) - tc.ringT.MulCoeffsBarrett(values0, values0, values0) - - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) - }) - } - - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/MulRelin/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - - if lvl == 0 { - t.Skip("Skipping: Level = 0") - } - - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) - - tc.ringT.MulCoeffsBarrett(values0, values1, values0) + tc.Params.RingT().MulCoeffsBarrett(p0, p1, p0) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - receiver := NewCiphertext(tc.params, 1, lvl) + receiver := NewCiphertext(tc.Params, 1, lvl) - require.NoError(t, tc.evaluator.MulRelin(ciphertext0, ciphertext1, receiver)) + require.NoError(t, tc.Evl.MulRelin(ciphertext0, ciphertext1, receiver)) + require.NoError(t, tc.Evl.Rescale(receiver, receiver)) - require.NoError(t, tc.evaluator.Rescale(receiver, receiver)) - - verifyTestVectors(tc, tc.decryptor, values0, receiver, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, receiver, p0.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevel { + t.Run(name("Evaluator/MulThenAdd/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk) - values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2)) + values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) + + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + p2 := ring.Poly{Coeffs: [][]uint64{values2}} require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0) - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, ciphertext1, ciphertext2)) - tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, ciphertext1, ciphertext2)) + tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2) - verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevel { + t.Run(name("Evaluator/MulThenAdd/Ct/Pt/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) - values1, plaintext1, _ := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk) - values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + values1, plaintext1, _ := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2)) + values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) + + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + p2 := ring.Poly{Coeffs: [][]uint64{values2}} require.True(t, ciphertext0.Scale.Cmp(plaintext1.Scale) != 0) require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0) - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, plaintext1, ciphertext2)) - tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, plaintext1, ciphertext2)) + tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2) - verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevel { + t.Run(name("Evaluator/MulThenAdd/Ct/Scalar/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) + + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) - scalar := tc.params.PlaintextModulus() >> 1 + scalar := tc.Params.PlaintextModulus() >> 1 - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, scalar, ciphertext1)) - tc.ringT.MulScalarThenAdd(values0, scalar, values1) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, scalar, ciphertext1)) + tc.Params.RingT().MulScalarThenAdd(p0, scalar, p1) - verifyTestVectors(tc, tc.decryptor, values1, ciphertext1, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext1, p1.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevel { + t.Run(name("Evaluator/MulThenAdd/Ct/Vector/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + scale := ciphertext1.Scale - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext0, values1.Coeffs[0], ciphertext1)) - tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values1) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext0, values1, ciphertext1)) + tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p1) // Checks that output scale isn't changed require.True(t, scale.Equal(ciphertext1.Scale)) - verifyTestVectors(tc, tc.decryptor, values1, ciphertext1, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext1, p1.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel { - t.Run(GetTestName("Evaluator/MulRelinThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { - + for _, lvl := range testLevel { + t.Run(name("Evaluator/MulRelinThenAdd/Ct/Ct/Inplace", tc, lvl), func(t *testing.T) { if lvl == 0 { t.Skip("Skipping: Level = 0") } - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) - values1, _, ciphertext1 := newTestVectorsLvl(lvl, rlwe.NewScale(2), tc, tc.encryptorSk) - values2, _, ciphertext2 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(2)) + values2, _, ciphertext2 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(7)) require.True(t, ciphertext0.Scale.Cmp(ciphertext1.Scale) != 0) require.True(t, ciphertext0.Scale.Cmp(ciphertext2.Scale) != 0) - require.NoError(t, tc.evaluator.MulRelinThenAdd(ciphertext0, ciphertext1, ciphertext2)) - tc.ringT.MulCoeffsBarrettThenAdd(values0, values1, values2) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} + p2 := ring.Poly{Coeffs: [][]uint64{values2}} - verifyTestVectors(tc, tc.decryptor, values2, ciphertext2, t) + require.NoError(t, tc.Evl.MulRelinThenAdd(ciphertext0, ciphertext1, ciphertext2)) + tc.Params.RingT().MulCoeffsBarrettThenAdd(p0, p1, p2) + + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext2, p2.Coeffs[0], t) }) } - for _, lvl := range tc.testLevel[:] { - t.Run(GetTestName("Evaluator/Rescale", tc.params, lvl), func(t *testing.T) { + for _, lvl := range testLevel[:] { + t.Run(name("Evaluator/Rescale", tc, lvl), func(t *testing.T) { - ringT := tc.params.RingT() + ringT := tc.Params.RingT() - values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorPk) + values0, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) printNoise := func(msg string, values []uint64, ct *rlwe.Ciphertext) { - pt := NewPlaintext(tc.params, ct.Level()) + pt := NewPlaintext(tc.Params, ct.Level()) pt.MetaData = ciphertext0.MetaData - require.NoError(t, tc.encoder.Encode(values0.Coeffs[0], pt)) - ct, err := tc.evaluator.SubNew(ct, pt) + require.NoError(t, tc.Ecd.Encode(values0, pt)) + ct, err := tc.Evl.SubNew(ct, pt) require.NoError(t, err) - vartmp, _, _ := rlwe.Norm(ct, tc.decryptor) + vartmp, _, _ := rlwe.Norm(ct, tc.Dec) t.Logf("STD(noise) %s: %f\n", msg, vartmp) } if lvl != 0 { - values1, _, ciphertext1 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) + values1, _, ciphertext1 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.DefaultScale()) if *flagPrintNoise { - printNoise("0x", values0.Coeffs[0], ciphertext0) + printNoise("0x", values0, ciphertext0) } - for i := 0; i < lvl; i++ { - tc.evaluator.MulRelin(ciphertext0, ciphertext1, ciphertext0) + p0 := ring.Poly{Coeffs: [][]uint64{values0}} + p1 := ring.Poly{Coeffs: [][]uint64{values1}} - ringT.MulCoeffsBarrett(values0, values1, values0) + for i := 0; i < lvl; i++ { + tc.Evl.MulRelin(ciphertext0, ciphertext1, ciphertext0) + + ringT.MulCoeffsBarrett(p0, p1, p0) if *flagPrintNoise { - printNoise(fmt.Sprintf("%dx", i+1), values0.Coeffs[0], ciphertext0) + printNoise(fmt.Sprintf("%dx", i+1), p0.Coeffs[0], ciphertext0) } } - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) - require.Nil(t, tc.evaluator.Rescale(ciphertext0, ciphertext0)) + require.Nil(t, tc.Evl.Rescale(ciphertext0, ciphertext0)) - verifyTestVectors(tc, tc.decryptor, values0, ciphertext0, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, ciphertext0, p0.Coeffs[0], t) } else { - require.NotNil(t, tc.evaluator.Rescale(ciphertext0, ciphertext0)) + require.NotNil(t, tc.Evl.Rescale(ciphertext0, ciphertext0)) } }) } diff --git a/schemes/bgv/test_parameters.go b/schemes/bgv/test_parameters.go deleted file mode 100644 index 378f0aa5..00000000 --- a/schemes/bgv/test_parameters.go +++ /dev/null @@ -1,14 +0,0 @@ -package bgv - -var ( - // testInsecure are insecure parameters used for the sole purpose of fast testing. - testInsecure = ParametersLiteral{ - LogN: 10, - Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001}, - P: []uint64{0x7fffffd8001}, - } - - testPlaintextModulus = []uint64{0x101, 0xffc001} - - testParams = []ParametersLiteral{testInsecure} -) diff --git a/schemes/bgv/test_utils.go b/schemes/bgv/test_utils.go new file mode 100644 index 00000000..b9cf183f --- /dev/null +++ b/schemes/bgv/test_utils.go @@ -0,0 +1,120 @@ +package bgv + +import ( + "fmt" + "math" + "slices" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/tuneinsight/lattigo/v5/core/rlwe" + "github.com/tuneinsight/lattigo/v5/ring" + "github.com/tuneinsight/lattigo/v5/utils/sampling" +) + +type TestContext struct { + Params Parameters + Ecd *Encoder + + Prng sampling.PRNG + Sampler *ring.UniformSampler + + Kgen *rlwe.KeyGenerator + Sk *rlwe.SecretKey + Pk *rlwe.PublicKey + + Enc *rlwe.Encryptor + Dec *rlwe.Decryptor + + Evl *Evaluator +} + +func NewTestContext(params ParametersLiteral) *TestContext { + tc := new(TestContext) + + var err error + + tc.Params, err = NewParametersFromLiteral(params) + if err != nil { + panic(err) + } + tc.Ecd = NewEncoder(tc.Params) + + tc.Prng, err = sampling.NewPRNG() + if err != nil { + panic(err) + } + tc.Sampler = ring.NewUniformSampler(tc.Prng, tc.Params.RingT()) + + tc.Kgen = rlwe.NewKeyGenerator(tc.Params) + tc.Sk, tc.Pk = tc.Kgen.GenKeyPairNew() + + tc.Enc = rlwe.NewEncryptor(tc.Params, tc.Pk) + tc.Dec = rlwe.NewDecryptor(tc.Params, tc.Sk) + + tc.Evl = NewEvaluator(tc.Params, rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk))) + + return tc +} + +func (tc TestContext) String() string { + return fmt.Sprintf("LogN=%d/logQ=%d/logP=%d/LogSlots=%dx%d/logT=%d/Qi=%d/Pi=%d", + tc.Params.LogN(), + int(math.Round(tc.Params.LogQ())), + int(math.Round(tc.Params.LogP())), + tc.Params.LogMaxDimensions().Rows, + tc.Params.LogMaxDimensions().Cols, + int(math.Round(tc.Params.LogT())), + tc.Params.QCount(), + tc.Params.PCount()) +} + +func VerifyTestVectors(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, have interface{}, want []uint64, t *testing.T) { + values := make([]uint64, params.MaxSlots()) + + switch have := have.(type) { + case *rlwe.Plaintext: + require.NoError(t, encoder.Decode(have, values)) + case *rlwe.Ciphertext: + require.NoError(t, encoder.Decode(decryptor.DecryptNew(have), values)) + default: + t.Error("invalid unsupported test object type") + } + + require.True(t, slices.Equal(values, want)) +} + +func NewTestVector(params Parameters, encoder *Encoder, encryptor *rlwe.Encryptor, level int, scale rlwe.Scale) (values []uint64, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) { + values = make([]uint64, params.MaxSlots()) + for i := range values { + values[i] = sampling.RandUint64() % params.PlaintextModulus() + } + + pt = NewPlaintext(params, level) + pt.Scale = scale + if err := encoder.Encode(values, pt); err != nil { + panic(err) + } + if encryptor != nil { + var err error + ct, err = encryptor.EncryptNew(pt) + if err != nil { + panic(err) + } + } + return +} + +var ( + // BgvTestInsecure are insecure parameters used for the sole purpose of fast testing. + TestInsecure = ParametersLiteral{ + LogN: 10, + Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001}, + P: []uint64{0x7fffffd8001}, + } + + TestPlaintextModulus = []uint64{0x101, 0xffc001} + + TestParams = []ParametersLiteral{TestInsecure} +) diff --git a/schemes/ckks/ckks_benchmarks_test.go b/schemes/ckks/ckks_benchmarks_test.go index f65b0f71..d87fedfd 100644 --- a/schemes/ckks/ckks_benchmarks_test.go +++ b/schemes/ckks/ckks_benchmarks_test.go @@ -2,7 +2,6 @@ package ckks import ( "encoding/json" - "fmt" "runtime" "testing" @@ -11,28 +10,7 @@ import ( "github.com/tuneinsight/lattigo/v5/utils/sampling" ) -func GetBenchName(params Parameters, opname string) string { - - var PrecisionMod string - switch params.precisionMode { - case PREC64: - PrecisionMod = "PREC64" - case PREC128: - PrecisionMod = "PREC128" - } - - return fmt.Sprintf("%s/RingType=%s/logN=%d/Qi=%d/Pi=%d/LogSlots=%d/%s", - opname, - params.RingType(), - params.LogN(), - params.QCount(), - params.PCount(), - params.LogMaxSlots(), - PrecisionMod) -} - func BenchmarkCKKS(b *testing.B) { - var err error var testParams []ParametersLiteral @@ -56,18 +34,9 @@ func BenchmarkCKKS(b *testing.B) { for _, paramsLiteral := range testParams { - var params Parameters - if params, err = NewParametersFromLiteral(paramsLiteral); err != nil { - b.Error(err) - b.Fail() - } + tc := NewTestContext(paramsLiteral) - var tc *testContext - if tc, err = genTestParams(params); err != nil { - b.Fatal(err) - } - - for _, testSet := range []func(tc *testContext, b *testing.B){ + for _, testSet := range []func(tc *TestContext, b *testing.B){ benchEncoder, benchEvaluator, } { @@ -77,13 +46,13 @@ func BenchmarkCKKS(b *testing.B) { } } -func benchEncoder(tc *testContext, b *testing.B) { +func benchEncoder(tc *TestContext, b *testing.B) { - encoder := tc.encoder + encoder := tc.Ecd - b.Run(GetBenchName(tc.params, "Encoder/Encode"), func(b *testing.B) { + b.Run(name("Encoder/Encode", tc), func(b *testing.B) { - pt := NewPlaintext(tc.params, tc.params.MaxLevel()) + pt := NewPlaintext(tc.Params, tc.Params.MaxLevel()) values := make([]complex128, 1< 5 { nbRescales = 5 } for i := 0; i < nbRescales; i++ { - constant := tc.ringQ.SubRings[ciphertext.Level()-i].Modulus - require.NoError(t, tc.evaluator.Mul(ciphertext, constant, ciphertext)) + constant := tc.Params.RingQ().SubRings[ciphertext.Level()-i].Modulus + require.NoError(t, tc.Evl.Mul(ciphertext, constant, ciphertext)) ciphertext.Scale = ciphertext.Scale.Mul(rlwe.NewScale(constant)) } - if err := tc.evaluator.RescaleTo(ciphertext, tc.params.DefaultScale(), ciphertext); err != nil { + if err := tc.Evl.RescaleTo(ciphertext, tc.Params.DefaultScale(), ciphertext); err != nil { t.Fatal(err) } - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) } -func testEvaluatorMul(tc *testContext, t *testing.T) { +func testEvaluatorMul(tc *TestContext, t *testing.T) { - t.Run(GetTestName(tc.params, "Evaluator/MulNew/Ct/Pt"), func(t *testing.T) { + t.Run(name("Evaluator/MulNew/Ct/Pt", tc), func(t *testing.T) { - values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, plaintext1, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i) mul := bignum.NewComplexMultiplier() @@ -637,17 +511,17 @@ func testEvaluatorMul(tc *testContext, t *testing.T) { mul.Mul(values1[i], values1[i], values1[i]) } - ciphertext2, err := tc.evaluator.MulNew(ciphertext1, plaintext1) + ciphertext2, err := tc.Evl.MulNew(ciphertext1, plaintext1) require.NoError(t, err) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext2, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Scalar"), func(t *testing.T) { + t.Run(name("Evaluator/Mul/Ct/Scalar", tc), func(t *testing.T) { - values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i) - constant := randomConst(tc.params.RingType(), tc.encoder.Prec(), -1+1i, -1+1i) + constant := randomConst(tc.Params.RingType(), tc.Ecd.Prec(), -1+1i, -1+1i) mul := bignum.NewComplexMultiplier() @@ -655,15 +529,15 @@ func testEvaluatorMul(tc *testContext, t *testing.T) { mul.Mul(values[i], constant, values[i]) } - require.NoError(t, tc.evaluator.Mul(ciphertext, constant, ciphertext)) + require.NoError(t, tc.Evl.Mul(ciphertext, constant, ciphertext)) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Vector"), func(t *testing.T) { + t.Run(name("Evaluator/Mul/Ct/Vector", tc), func(t *testing.T) { - values1, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, _, ciphertext := tc.NewTestVector(-1-1i, 1+1i) + values2, _, _ := tc.NewTestVector(-1-1i, 1+1i) mul := bignum.NewComplexMultiplier() @@ -671,14 +545,14 @@ func testEvaluatorMul(tc *testContext, t *testing.T) { mul.Mul(values1[i], values2[i], values1[i]) } - tc.evaluator.Mul(ciphertext, values2, ciphertext) + tc.Evl.Mul(ciphertext, values2, ciphertext) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Pt"), func(t *testing.T) { + t.Run(name("Evaluator/Mul/Ct/Pt", tc), func(t *testing.T) { - values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, plaintext1, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i) mul := bignum.NewComplexMultiplier() @@ -686,15 +560,15 @@ func testEvaluatorMul(tc *testContext, t *testing.T) { mul.Mul(values1[i], values1[i], values1[i]) } - require.NoError(t, tc.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1)) + require.NoError(t, tc.Evl.MulRelin(ciphertext1, plaintext1, ciphertext1)) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Ct/Degree0"), func(t *testing.T) { + t.Run(name("Evaluator/Mul/Ct/Ct/Degree0", tc), func(t *testing.T) { - values1, plaintext1, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, plaintext1, _ := tc.NewTestVector(-1-1i, 1+1i) + values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i) mul := bignum.NewComplexMultiplier() @@ -706,16 +580,16 @@ func testEvaluatorMul(tc *testContext, t *testing.T) { ciphertext1.Value = []ring.Poly{plaintext1.Value} ciphertext1.MetaData = plaintext1.MetaData.CopyNew() - require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1)) + require.NoError(t, tc.Evl.MulRelin(ciphertext1, ciphertext2, ciphertext1)) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values2, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(tc.params, "Evaluator/MulRelin/Ct/Ct"), func(t *testing.T) { + t.Run(name("Evaluator/MulRelin/Ct/Ct", tc), func(t *testing.T) { // op0 <- op0 * op1 - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i) + values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i) mul := bignum.NewComplexMultiplier() @@ -723,44 +597,44 @@ func testEvaluatorMul(tc *testContext, t *testing.T) { mul.Mul(values1[i], values2[i], values1[i]) } - require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1)) + require.NoError(t, tc.Evl.MulRelin(ciphertext1, ciphertext2, ciphertext1)) require.Equal(t, ciphertext1.Degree(), 1) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) // op1 <- op0 * op1 - values1, _, ciphertext1 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, _, ciphertext1 = tc.NewTestVector(-1-1i, 1+1i) + values2, _, ciphertext2 = tc.NewTestVector(-1-1i, 1+1i) for i := range values1 { mul.Mul(values2[i], values1[i], values2[i]) } - require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext2)) + require.NoError(t, tc.Evl.MulRelin(ciphertext1, ciphertext2, ciphertext2)) require.Equal(t, ciphertext2.Degree(), 1) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values2, ciphertext2, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) // op0 <- op0 * op0 for i := range values1 { mul.Mul(values1[i], values1[i], values1[i]) } - require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext1, ciphertext1)) + require.NoError(t, tc.Evl.MulRelin(ciphertext1, ciphertext1, ciphertext1)) require.Equal(t, ciphertext1.Degree(), 1) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) } -func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) { +func testEvaluatorMulThenAdd(tc *TestContext, t *testing.T) { - t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Scalar"), func(t *testing.T) { + t.Run(name("Evaluator/MulThenAdd/Scalar", tc), func(t *testing.T) { - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i) + values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i) - constant := randomConst(tc.params.RingType(), tc.encoder.Prec(), -1+1i, -1+1i) + constant := randomConst(tc.Params.RingType(), tc.Ecd.Prec(), -1+1i, -1+1i) mul := bignum.NewComplexMultiplier() @@ -773,17 +647,17 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) { values2[i].Add(values2[i], tmp) } - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext1, constant, ciphertext2)) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext1, constant, ciphertext2)) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values2, ciphertext2, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Vector"), func(t *testing.T) { + t.Run(name("Evaluator/MulThenAdd/Vector", tc), func(t *testing.T) { - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1, 1, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1, 1, t) + values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i) + values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i) - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext2, values1, ciphertext1)) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext2, values1, ciphertext1)) mul := bignum.NewComplexMultiplier() @@ -798,13 +672,13 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) { require.Equal(t, ciphertext1.Degree(), 1) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Pt"), func(t *testing.T) { + t.Run(name("Evaluator/MulThenAdd/Pt", tc), func(t *testing.T) { - values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1, 1, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1, 1, t) + values1, plaintext1, ciphertext1 := tc.NewTestVector(-1, 1) + values2, _, ciphertext2 := tc.NewTestVector(-1, 1) mul := bignum.NewComplexMultiplier() @@ -817,18 +691,18 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) { values1[i].Add(values1[i], tmp) } - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext2, plaintext1, ciphertext1)) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext2, plaintext1, ciphertext1)) require.Equal(t, ciphertext1.Degree(), 1) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) - t.Run(GetTestName(tc.params, "Evaluator/MulRelinThenAdd/Ct"), func(t *testing.T) { + t.Run(name("Evaluator/MulRelinThenAdd/Ct", tc), func(t *testing.T) { // opOut = opOut + op1 * op0 - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, _, ciphertext1 := tc.NewTestVector(-1-1i, 1+1i) + values2, _, ciphertext2 := tc.NewTestVector(-1-1i, 1+1i) mul := bignum.NewComplexMultiplier() @@ -836,23 +710,23 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) { mul.Mul(values1[i], values2[i], values2[i]) } - ciphertext3 := NewCiphertext(tc.params, 2, ciphertext1.Level()) + ciphertext3 := NewCiphertext(tc.Params, 2, ciphertext1.Level()) ciphertext3.Scale = ciphertext1.Scale.Mul(ciphertext2.Scale) - require.NoError(t, tc.evaluator.MulThenAdd(ciphertext1, ciphertext2, ciphertext3)) + require.NoError(t, tc.Evl.MulThenAdd(ciphertext1, ciphertext2, ciphertext3)) require.Equal(t, ciphertext3.Degree(), 2) - require.NoError(t, tc.evaluator.Relinearize(ciphertext3, ciphertext3)) + require.NoError(t, tc.Evl.Relinearize(ciphertext3, ciphertext3)) require.Equal(t, ciphertext3.Degree(), 1) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext3, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values2, ciphertext3, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) // op1 = op1 + op0*op0 - values1, _, ciphertext1 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values1, _, ciphertext1 = tc.NewTestVector(-1-1i, 1+1i) + values2, _, ciphertext2 = tc.NewTestVector(-1-1i, 1+1i) tmp := bignum.NewComplex() for i := range values1 { @@ -860,23 +734,23 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) { values1[i].Add(values1[i], tmp) } - require.NoError(t, tc.evaluator.MulRelinThenAdd(ciphertext2, ciphertext2, ciphertext1)) + require.NoError(t, tc.Evl.MulRelinThenAdd(ciphertext2, ciphertext2, ciphertext1)) require.Equal(t, ciphertext1.Degree(), 1) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values1, ciphertext1, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) } -func testBridge(tc *testContext, t *testing.T) { +func testBridge(tc *TestContext, t *testing.T) { - t.Run(GetTestName(tc.params, "Bridge"), func(t *testing.T) { + t.Run(name("Bridge", tc), func(t *testing.T) { - if tc.params.RingType() != ring.ConjugateInvariant { + if tc.Params.RingType() != ring.ConjugateInvariant { t.Skip("only tested for params.RingType() == ring.ConjugateInvariant") } - ciParams := tc.params + ciParams := tc.Params var err error if _, err = ciParams.StandardParameters(); err != nil { t.Fatalf("all Conjugate Invariant parameters should have a standard counterpart but got: %f", err) @@ -896,7 +770,7 @@ func testBridge(tc *testContext, t *testing.T) { stdEncoder := NewEncoder(stdParams) stdEvaluator := NewEvaluator(stdParams, nil) - evkCtR, evkRtC := stdKeyGen.GenEvaluationKeysForRingSwapNew(stdSK, tc.sk) + evkCtR, evkRtC := stdKeyGen.GenEvaluationKeysForRingSwapNew(stdSK, tc.Sk) switcher, err := NewDomainSwitcher(stdParams, evkCtR, evkRtC) if err != nil { @@ -905,13 +779,13 @@ func testBridge(tc *testContext, t *testing.T) { evalStandar := NewEvaluator(stdParams, nil) - values, _, ctCI := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ctCI := tc.NewTestVector(-1-1i, 1+1i) stdCTHave := NewCiphertext(stdParams, ctCI.Degree(), ctCI.Level()) switcher.RealToComplex(evalStandar, ctCI, stdCTHave) - VerifyTestVectors(stdParams, stdEncoder, stdDecryptor, values, stdCTHave, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(stdParams, stdEncoder, stdDecryptor, values, stdCTHave, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) stdCTImag, err := stdEvaluator.MulNew(stdCTHave, 1i) require.NoError(t, err) @@ -920,6 +794,27 @@ func testBridge(tc *testContext, t *testing.T) { ciCTHave := NewCiphertext(ciParams, 1, stdCTHave.Level()) switcher.ComplexToReal(evalStandar, stdCTHave, ciCTHave) - VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciCTHave, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t) + VerifyTestVectors(tc.Params, tc.Ecd, tc.Dec, values, ciCTHave, tc.Params.LogDefaultScale(), 0, *printPrecisionStats, t) }) } + +func name(opname string, tc *TestContext) string { + + var precMode string + switch tc.Params.precisionMode { + case PREC64: + precMode = "PREC64" + case PREC128: + precMode = "PREC128" + } + + return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d/PrecMode=%s", + opname, + tc.Params.RingType(), + tc.Params.LogN(), + int(math.Round(tc.Params.LogQP())), + tc.Params.QCount(), + tc.Params.PCount(), + int(math.Log2(tc.Params.DefaultScale().Float64())), + precMode) +} diff --git a/schemes/ckks/test_params.go b/schemes/ckks/test_params.go deleted file mode 100644 index dd5c782f..00000000 --- a/schemes/ckks/test_params.go +++ /dev/null @@ -1,48 +0,0 @@ -package ckks - -var ( - // testInsecurePrec45 are insecure parameters used for the sole purpose of fast testing. - testInsecurePrec45 = ParametersLiteral{ - LogN: 10, - Q: []uint64{ - 0x80000000080001, - 0x2000000a0001, - 0x2000000e0001, - 0x2000001d0001, - 0x1fffffcf0001, - 0x1fffffc20001, - 0x200000440001, - }, - P: []uint64{ - 0x80000000130001, - 0x7fffffffe90001, - }, - LogDefaultScale: 45, - } - - // testInsecurePrec90 are insecure parameters used for the sole purpose of fast testing. - testInsecurePrec90 = ParametersLiteral{ - LogN: 10, - Q: []uint64{ - 0x80000000080001, - 0x80000000440001, - 0x2000000a0001, - 0x2000000e0001, - 0x1fffffc20001, - 0x200000440001, - 0x200000500001, - 0x200000620001, - 0x1fffff980001, - 0x2000006a0001, - 0x1fffff7e0001, - 0x200000860001, - }, - P: []uint64{ - 0xffffffffffc0001, - 0x10000000006e0001, - }, - LogDefaultScale: 90, - } - - testParamsLiteral = []ParametersLiteral{testInsecurePrec45, testInsecurePrec90} -) diff --git a/schemes/ckks/test_utils.go b/schemes/ckks/test_utils.go new file mode 100644 index 00000000..1156eff6 --- /dev/null +++ b/schemes/ckks/test_utils.go @@ -0,0 +1,174 @@ +package ckks + +import ( + "math/big" + + "github.com/tuneinsight/lattigo/v5/core/rlwe" + "github.com/tuneinsight/lattigo/v5/ring" + "github.com/tuneinsight/lattigo/v5/utils/bignum" + "github.com/tuneinsight/lattigo/v5/utils/sampling" +) + +type TestContext struct { + Params Parameters + Ecd *Encoder + + Prng sampling.PRNG + + Kgen *rlwe.KeyGenerator + Sk *rlwe.SecretKey + Pk *rlwe.PublicKey + + Enc *rlwe.Encryptor + Dec *rlwe.Decryptor + + Evl *Evaluator +} + +func NewTestContext(params ParametersLiteral) *TestContext { + tc := new(TestContext) + + var err error + + tc.Params, err = NewParametersFromLiteral(params) + if err != nil { + panic(err) + } + tc.Ecd = NewEncoder(tc.Params) + + tc.Prng, err = sampling.NewPRNG() + if err != nil { + panic(err) + } + + tc.Kgen = rlwe.NewKeyGenerator(tc.Params) + tc.Sk, tc.Pk = tc.Kgen.GenKeyPairNew() + + tc.Enc = rlwe.NewEncryptor(tc.Params, tc.Pk) + tc.Dec = rlwe.NewDecryptor(tc.Params, tc.Sk) + + tc.Evl = NewEvaluator(tc.Params, rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk))) + + return tc +} + +func (tc *TestContext) NewTestVector(a, b complex128) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) { + prec := tc.Ecd.Prec() + + pt = NewPlaintext(tc.Params, tc.Params.MaxLevel()) + + values = make([]*bignum.Complex, pt.Slots()) + + switch tc.Params.RingType() { + case ring.Standard: + for i := range values { + values[i] = &bignum.Complex{ + bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec), + bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec), + } + } + case ring.ConjugateInvariant: + for i := range values { + values[i] = &bignum.Complex{ + bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec), + new(big.Float), + } + } + default: + panic("unsupported ring type") + } + + var err error + + if err = tc.Ecd.Encode(values, pt); err != nil { + panic(err) + } + + ct, err = tc.Enc.EncryptNew(pt) + if err != nil { + panic(err) + } + + return values, pt, ct +} + +func randomConst(tp ring.Type, prec uint, a, b complex128) (constant *bignum.Complex) { + switch tp { + case ring.Standard: + constant = &bignum.Complex{ + bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec), + bignum.NewFloat(sampling.RandFloat64(imag(a), imag(b)), prec), + } + case ring.ConjugateInvariant: + constant = &bignum.Complex{ + bignum.NewFloat(sampling.RandFloat64(real(a), real(b)), prec), + new(big.Float), + } + default: + panic("invalid ring type") + } + return +} + +var ( + // testInsecurePrec45 are insecure parameters used for the sole purpose of fast testing. + TestInsecurePrec45 = ParametersLiteral{ + LogN: 10, + LogQ: []int{55, 45, 45, 45, 45, 45, 45}, + LogP: []int{60}, + LogDefaultScale: 45, + } + + // testInsecurePrec90 are insecure parameters used for the sole purpose of fast testing. + TestInsecurePrec90 = ParametersLiteral{ + LogN: 10, + LogQ: []int{55, 55, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45}, + LogP: []int{60, 60}, + LogDefaultScale: 90, + } + + TestParametersLiteral = []ParametersLiteral{TestInsecurePrec45, TestInsecurePrec90} +) + +// testInsecurePrec45 = ParametersLiteral{ +// LogN: 10, +// Q: []uint64{ +// 0x80000000080001, +// 0x2000000a0001, +// 0x2000000e0001, +// 0x2000001d0001, +// 0x1fffffcf0001, +// 0x1fffffc20001, +// 0x200000440001, +// }, +// P: []uint64{ +// 0x80000000130001, +// 0x7fffffffe90001, +// }, +// LogDefaultScale: 45, +// } + +// testInsecurePrec90 = ParametersLiteral{ +// LogN: 10, +// Q: []uint64{ +// 0x80000000080001, +// 0x80000000440001, +// 0x2000000a0001, +// 0x2000000e0001, +// 0x1fffffc20001, +// 0x200000440001, +// 0x200000500001, +// 0x200000620001, +// 0x1fffff980001, +// 0x2000006a0001, +// 0x1fffff7e0001, +// 0x200000860001, +// }, +// P: []uint64{ +// 0xffffffffffc0001, +// 0x10000000006e0001, +// }, +// LogDefaultScale: 90, +// } + +// testParamsLiteral = []ParametersLiteral{testInsecurePrec45, testInsecurePrec90} diff --git a/schemes/parameters_test.go b/schemes/parameters_test.go deleted file mode 100644 index 9b254d91..00000000 --- a/schemes/parameters_test.go +++ /dev/null @@ -1 +0,0 @@ -package schemes diff --git a/schemes/test_parameters.go b/schemes/test_parameters.go deleted file mode 100644 index 92a1e461..00000000 --- a/schemes/test_parameters.go +++ /dev/null @@ -1,39 +0,0 @@ -package schemes - -import ( - "github.com/tuneinsight/lattigo/v5/schemes/bgv" - "github.com/tuneinsight/lattigo/v5/schemes/ckks" -) - -var ( - // BgvTestInsecure are insecure parameters used for the sole purpose of fast testing. - BgvTestInsecure = bgv.ParametersLiteral{ - LogN: 10, - Q: []uint64{0x3fffffa8001, 0x1000090001, 0x10000c8001, 0x10000f0001, 0xffff00001}, - P: []uint64{0x7fffffd8001}, - } - - BgvTestPlaintextModulus = []uint64{0x101, 0xffc001} - - BgvTestParams = []bgv.ParametersLiteral{BgvTestInsecure} -) - -var ( - // BgvTestInsecurePrec45 are insecure parameters used for the sole purpose of fast testing. - CkksTestInsecurePrec45 = ckks.ParametersLiteral{ - LogN: 10, - LogQ: []int{55, 45, 45, 45, 45, 45, 45}, - LogP: []int{60}, - LogDefaultScale: 45, - } - - // BgvTestInsecurePrec90 are insecure parameters used for the sole purpose of fast testing. - CkksTestInsecurePrec90 = ckks.ParametersLiteral{ - LogN: 10, - LogQ: []int{55, 55, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45}, - LogP: []int{60, 60}, - LogDefaultScale: 90, - } - - CkksTestParametersLiteral = []ckks.ParametersLiteral{CkksTestInsecurePrec45, CkksTestInsecurePrec90} -)