From a0d8e7617e13b3361dc4044111a4b3c490fe7bbe Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 27 Oct 2023 10:53:44 +0200 Subject: [PATCH] [bgv]: added tests for IsBatched = false --- bgv/bgv_test.go | 83 ++++++++++++++++++++++++++++++++---------- utils/sampling/prng.go | 2 +- 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/bgv/bgv_test.go b/bgv/bgv_test.go index eb71b15a..88ce8f2d 100644 --- a/bgv/bgv_test.go +++ b/bgv/bgv_test.go @@ -217,33 +217,76 @@ func testParameters(tc *testContext, t *testing.T) { func testEncoder(tc *testContext, t *testing.T) { for _, lvl := range tc.testLevel { - t.Run(GetTestName("Encoder/Uint", tc.params, lvl), func(t *testing.T) { + t.Run(GetTestName("Encoder/Uint/IsBatched=true", tc.params, lvl), func(t *testing.T) { values, plaintext, _ := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, nil) verifyTestVectors(tc, nil, values, plaintext, t) }) } for _, lvl := range tc.testLevel { - t.Run(GetTestName("Encoder/Int", tc.params, lvl), func(t *testing.T) { + t.Run(GetTestName("Encoder/Int/IsBatched=true", tc.params, lvl), func(t *testing.T) { T := tc.params.PlaintextModulus() THalf := T >> 1 - coeffs := tc.uSampler.ReadNew() - coeffsInt := make([]int64, coeffs.N()) - for i, c := range coeffs.Coeffs[0] { + poly := tc.uSampler.ReadNew() + coeffs := make([]int64, poly.N()) + for i, c := range poly.Coeffs[0] { c %= T if c >= THalf { - coeffsInt[i] = -int64(T - c) + coeffs[i] = -int64(T - c) } else { - coeffsInt[i] = int64(c) + coeffs[i] = int64(c) } } plaintext := NewPlaintext(tc.params, lvl) - tc.encoder.Encode(coeffsInt, plaintext) + tc.encoder.Encode(coeffs, plaintext) have := make([]int64, tc.params.MaxSlots()) tc.encoder.Decode(plaintext, have) - require.True(t, utils.EqualSlice(coeffsInt, have)) + require.True(t, utils.EqualSlice(coeffs, have)) + }) + } + + for _, lvl := range tc.testLevel { + t.Run(GetTestName("Encoder/Uint/IsBatched=false", tc.params, lvl), func(t *testing.T) { + T := tc.params.PlaintextModulus() + poly := tc.uSampler.ReadNew() + coeffs := make([]uint64, poly.N()) + for i, c := range poly.Coeffs[0] { + coeffs[i] = c % T + } + + plaintext := NewPlaintext(tc.params, lvl) + plaintext.IsBatched = false + require.NoError(t, tc.encoder.Encode(coeffs, plaintext)) + have := make([]uint64, tc.params.MaxSlots()) + require.NoError(t, tc.encoder.Decode(plaintext, have)) + require.True(t, utils.EqualSlice(coeffs, have)) + }) + } + + for _, lvl := range tc.testLevel { + t.Run(GetTestName("Encoder/Int/IsBatched=false", tc.params, lvl), func(t *testing.T) { + + T := tc.params.PlaintextModulus() + THalf := T >> 1 + poly := tc.uSampler.ReadNew() + coeffs := make([]int64, poly.N()) + for i, c := range poly.Coeffs[0] { + c %= T + if c >= THalf { + coeffs[i] = -int64(T - c) + } else { + coeffs[i] = int64(c) + } + } + + plaintext := NewPlaintext(tc.params, lvl) + plaintext.IsBatched = false + require.NoError(t, tc.encoder.Encode(coeffs, plaintext)) + have := make([]int64, tc.params.MaxSlots()) + require.NoError(t, tc.encoder.Decode(plaintext, have)) + require.True(t, utils.EqualSlice(coeffs, have)) }) } } @@ -408,7 +451,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/Mul/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) @@ -428,7 +471,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/Mul/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) @@ -448,7 +491,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/Mul/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) @@ -466,7 +509,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/Mul/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values, _, ciphertext := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) @@ -482,7 +525,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/Square/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) @@ -498,7 +541,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/MulRelin/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(3), tc, tc.encryptorSk) @@ -522,7 +565,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) @@ -543,7 +586,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Pt/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) @@ -564,7 +607,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Scalar/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) @@ -585,7 +628,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/MulThenAdd/Ct/Vector/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.NewScale(7), tc, tc.encryptorSk) @@ -609,7 +652,7 @@ func testEvaluator(tc *testContext, t *testing.T) { t.Run(GetTestName("Evaluator/MulRelinThenAdd/Ct/Ct/Inplace", tc.params, lvl), func(t *testing.T) { if lvl == 0 { - t.Skip("Level = 0") + t.Skip("Skipping: Level = 0") } values0, _, ciphertext0 := newTestVectorsLvl(lvl, tc.params.DefaultScale(), tc, tc.encryptorSk) diff --git a/utils/sampling/prng.go b/utils/sampling/prng.go index 0c2695f0..992048fe 100644 --- a/utils/sampling/prng.go +++ b/utils/sampling/prng.go @@ -38,7 +38,7 @@ func NewPRNG() (*KeyedPRNG, error) { prng := new(KeyedPRNG) key := make([]byte, 64) if _, err := rand.Read(key); err != nil { - return fmt.Errorf("crypto rand error: %w", err) + return nil, fmt.Errorf("crypto rand error: %w", err) } prng.key = key prng.xof, err = blake2b.NewXOF(blake2b.OutputLengthUnknown, key)