From b926532a8df2f4b096a699fe4e7de29a66b59554 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 29 Nov 2019 13:49:07 +0100 Subject: [PATCH] BFV : params revamp --- bfv/bfv.go | 17 ++---- bfv/bfv_benchmark_test.go | 50 ++++++++-------- bfv/bfv_test.go | 10 ++-- bfv/ciphertext.go | 4 +- bfv/decryptor.go | 2 +- bfv/encryptor.go | 4 +- bfv/keygen.go | 14 ++--- bfv/operand.go | 2 +- bfv/params.go | 122 +++++++++++++++++++------------------- bfv/params_test.go | 4 +- bfv/plaintext.go | 2 +- bfv/utils.go | 110 ++++++++++++++++++++++++++++++++++ 12 files changed, 221 insertions(+), 120 deletions(-) create mode 100644 bfv/utils.go diff --git a/bfv/bfv.go b/bfv/bfv.go index 1e2a1e93..3a33c8ee 100644 --- a/bfv/bfv.go +++ b/bfv/bfv.go @@ -21,7 +21,6 @@ type Context struct { t uint64 logQ uint64 - logP uint64 // floor(Q/T) mod each Qi in Montgomery form deltaMont []uint64 @@ -88,11 +87,11 @@ func NewContext(params *Parameters) (newContext *Context) { // - sigma : the variance of the gaussian sampling. func (context *Context) SetParameters(params *Parameters) { - N := params.N + LogN := params.LogN + N := uint64(1 << LogN) t := params.T - ModuliQ1 := params.Q1 - ModuliQ2 := params.Q2 - ModuliP := params.P + + ModuliQ1, ModuliP, ModuliQ2 := genModuli(params) sigma := params.Sigma context.n = N @@ -146,8 +145,7 @@ func (context *Context) SetParameters(params *Parameters) { context.qHalf = new(big.Int).Rsh(context.contextQ1.ModulusBigint, 1) context.pHalf = new(big.Int).Rsh(context.contextQ2.ModulusBigint, 1) - context.logQ = uint64(context.contextP.ModulusBigint.BitLen()) - context.logP = uint64(context.contextQ2.ModulusBigint.BitLen()) + context.logQ = uint64(context.contextQ1P.ModulusBigint.BitLen()) delta := new(big.Int).Quo(context.contextQ1.ModulusBigint, ring.NewUint(t)) tmpBig := new(big.Int) @@ -177,11 +175,6 @@ func (context *Context) LogQ() uint64 { return context.logQ } -// LogP returns logQ which is the total bitzise of the secondary ciphertext modulus. -func (context *Context) LogP() uint64 { - return context.logP -} - // T returns the plaintext modulus of the target context. func (context *Context) T() uint64 { return context.t diff --git a/bfv/bfv_benchmark_test.go b/bfv/bfv_benchmark_test.go index 966714e1..544170bd 100755 --- a/bfv/bfv_benchmark_test.go +++ b/bfv/bfv_benchmark_test.go @@ -11,24 +11,24 @@ func Benchmark_BFV(b *testing.B) { for _, params := range paramSets { - bfvContext := NewContext(¶ms) + bfvContext := NewContext(params) var sk *SecretKey var pk *PublicKey var err error - kgen := NewKeyGenerator(¶ms) + kgen := NewKeyGenerator(params) - encoder := NewEncoder(¶ms) + encoder := NewEncoder(params) coeffs := bfvContext.contextT.NewUniformPoly() - plaintext := NewPlaintextFromParams(¶ms) + plaintext := NewPlaintextFromParams(params) encoder.EncodeUint(coeffs.Coeffs[0], plaintext) // Public Key Generation - b.Run(testString("KeyGen", ¶ms), func(b *testing.B) { + b.Run(testString("KeyGen", params), func(b *testing.B) { for i := 0; i < b.N; i++ { sk, pk = kgen.NewKeyPair() if err != nil { @@ -38,61 +38,61 @@ func Benchmark_BFV(b *testing.B) { }) // Encryption - encryptorPk := NewEncryptorFromPk(¶ms, pk) - encryptorSk := NewEncryptorFromSk(¶ms, sk) + encryptorPk := NewEncryptorFromPk(params, pk) + encryptorSk := NewEncryptorFromSk(params, sk) - ctd1 := NewCiphertextFromParams(¶ms, 1) - b.Run(testString("EncryptFromPk", ¶ms), func(b *testing.B) { + ctd1 := NewCiphertextFromParams(params, 1) + b.Run(testString("EncryptFromPk", params), func(b *testing.B) { for i := 0; i < b.N; i++ { encryptorPk.Encrypt(plaintext, ctd1) } }) - b.Run(testString("EncryptFromSk", ¶ms), func(b *testing.B) { + b.Run(testString("EncryptFromSk", params), func(b *testing.B) { for i := 0; i < b.N; i++ { encryptorSk.Encrypt(plaintext, ctd1) } }) // Decryption - decryptor := NewDecryptor(¶ms, sk) - ptp := NewPlaintextFromParams(¶ms) - b.Run(testString("Decrypt", ¶ms), func(b *testing.B) { + decryptor := NewDecryptor(params, sk) + ptp := NewPlaintextFromParams(params) + b.Run(testString("Decrypt", params), func(b *testing.B) { for i := 0; i < b.N; i++ { decryptor.Decrypt(ctd1, ptp) } _ = ptp }) - evaluator := NewEvaluator(¶ms) + evaluator := NewEvaluator(params) - ct1 := NewRandomCiphertextFromParams(¶ms, 1) - ct2 := NewRandomCiphertextFromParams(¶ms, 1) + ct1 := NewRandomCiphertextFromParams(params, 1) + ct2 := NewRandomCiphertextFromParams(params, 1) // Addition - b.Run(testString("Add", ¶ms), func(b *testing.B) { + b.Run(testString("Add", params), func(b *testing.B) { for i := 0; i < b.N; i++ { evaluator.Add(ct1, ct2, ctd1) } }) // Subtraction - b.Run(testString("Sub", ¶ms), func(b *testing.B) { + b.Run(testString("Sub", params), func(b *testing.B) { for i := 0; i < b.N; i++ { evaluator.Sub(ct1, ct2, ctd1) } }) // Multiplication - receiver := NewCiphertextFromParams(¶ms, 2) - b.Run(testString("Multiply", ¶ms), func(b *testing.B) { + receiver := NewCiphertextFromParams(params, 2) + b.Run(testString("Multiply", params), func(b *testing.B) { for i := 0; i < b.N; i++ { evaluator.Mul(ct1, ct2, receiver) } }) // Square is Mul(ct, ct) for now - b.Run(testString("Square", ¶ms), func(b *testing.B) { + b.Run(testString("Square", params), func(b *testing.B) { for i := 0; i < b.N; i++ { evaluator.Mul(ct1, ct1, receiver) } @@ -102,7 +102,7 @@ func Benchmark_BFV(b *testing.B) { rlk := kgen.NewRelinKey(sk, 2) // Relinearization - b.Run(testString("Relin", ¶ms), func(b *testing.B) { + b.Run(testString("Relin", params), func(b *testing.B) { for i := 0; i < b.N; i++ { evaluator.Relinearize(receiver, rlk, ctd1) } @@ -112,14 +112,14 @@ func Benchmark_BFV(b *testing.B) { rtk := kgen.NewRotationKeysPow2(sk) // Rotation Rows - b.Run(testString("RotateRows", ¶ms), func(b *testing.B) { + b.Run(testString("RotateRows", params), func(b *testing.B) { for i := 0; i < b.N; i++ { evaluator.RotateRows(ct1, rtk, ctd1) } }) // Rotation Cols - b.Run(testString("RotateCols", ¶ms), func(b *testing.B) { + b.Run(testString("RotateCols", params), func(b *testing.B) { for i := 0; i < b.N; i++ { evaluator.RotateColumns(ct1, 1, rtk, ctd1) } @@ -129,5 +129,5 @@ func Benchmark_BFV(b *testing.B) { } func testString(opname string, params *Parameters) string { - return fmt.Sprintf("%s/params=%d", opname, params.N) + return fmt.Sprintf("%s/params=%d", opname, params.LogN) } diff --git a/bfv/bfv_test.go b/bfv/bfv_test.go index bbea0f84..eb8b10f0 100644 --- a/bfv/bfv_test.go +++ b/bfv/bfv_test.go @@ -18,7 +18,7 @@ func check(t *testing.T, err error) { } func testString2(opname string, params *bfvParams) string { - return fmt.Sprintf("%sparams=%d", opname, params.bfvContext.N()) + return fmt.Sprintf("%sparams=%d/logQ=%d", opname, params.bfvContext.n, params.bfvContext.logQ) } type bfvParams struct { @@ -32,7 +32,6 @@ type bfvParams struct { encryptorSk *Encryptor decryptor *Decryptor evaluator *Evaluator - ringCtx *ring.Context } type bfvTestParameters struct { @@ -46,9 +45,10 @@ func init() { rand.Seed(time.Now().UnixNano()) testParams.bfvParameters = []*Parameters{ - &DefaultParams[0], - &DefaultParams[1], - &DefaultParams[2], + DefaultParams[12], + DefaultParams[13], + DefaultParams[14], + DefaultParams[15], } } diff --git a/bfv/ciphertext.go b/bfv/ciphertext.go index bfdbcbeb..b2f2b76e 100644 --- a/bfv/ciphertext.go +++ b/bfv/ciphertext.go @@ -20,7 +20,7 @@ func NewCiphertextFromParams(params *Parameters, degree uint64) (ciphertext *Cip ciphertext.value = make([]*ring.Poly, degree+1) for i := uint64(0); i < degree+1; i++ { - ciphertext.value[i] = ring.NewPoly(params.N, uint64(len(params.Q1))) + ciphertext.value[i] = ring.NewPoly(1< MaxN { + p.LogN = b.ReadUint8() + if p.LogN > MaxLogN { return errors.New("polynomial degree is too large") } - lenQ1 := uint64(b.ReadUint8()) + lenQ1 := b.ReadUint8() if lenQ1 > MaxModuliCount { return fmt.Errorf("len(Q1) is larger than %d", MaxModuliCount) } - lenP := uint64(b.ReadUint8()) + lenP := b.ReadUint8() if lenP > MaxModuliCount { return fmt.Errorf("len(lenP) is larger than %d", MaxModuliCount) } - lenQ2 := uint64(b.ReadUint8()) + lenQ2 := b.ReadUint8() if lenQ2 > MaxModuliCount { return fmt.Errorf("len(Q2) is larger than %d", MaxModuliCount) } p.T = b.ReadUint64() p.Sigma = math.Round((float64(b.ReadUint64())/float64(1<<32))*100) / 100 - p.Q1 = make([]uint64, lenQ1, lenQ1) - p.P = make([]uint64, lenP, lenP) - p.Q2 = make([]uint64, lenQ2, lenQ2) + p.Q1 = make([]uint8, lenQ1, lenQ1) + p.P = make([]uint8, lenP, lenP) + p.Q2 = make([]uint8, lenQ2, lenQ2) - b.ReadUint64Slice(p.Q1) - b.ReadUint64Slice(p.P) - b.ReadUint64Slice(p.Q2) + b.ReadUint8Slice(p.Q1) + b.ReadUint8Slice(p.P) + b.ReadUint8Slice(p.Q2) return nil } diff --git a/bfv/params_test.go b/bfv/params_test.go index 015bca58..4e96b295 100644 --- a/bfv/params_test.go +++ b/bfv/params_test.go @@ -10,7 +10,7 @@ func TestParams_BinaryMarshaller(t *testing.T) { bytes, err := (&Parameters{}).MarshalBinary() assert.Nil(t, err) assert.Equal(t, []byte{}, bytes) - var p Parameters + p := new(Parameters) err = p.UnmarshalBinary(bytes) assert.NotNil(t, err) }) @@ -18,7 +18,7 @@ func TestParams_BinaryMarshaller(t *testing.T) { for _, params := range DefaultParams { bytes, err := params.MarshalBinary() assert.Nil(t, err) - var p Parameters + p := new(Parameters) err = p.UnmarshalBinary(bytes) assert.Nil(t, err) assert.Equal(t, params, p) diff --git a/bfv/plaintext.go b/bfv/plaintext.go index b4b200f1..aebd0c87 100644 --- a/bfv/plaintext.go +++ b/bfv/plaintext.go @@ -14,7 +14,7 @@ type Plaintext struct { func NewPlaintextFromParams(params *Parameters) *Plaintext { plaintext := &Plaintext{&bfvElement{}, nil} - plaintext.bfvElement.value = []*ring.Poly{ring.NewPoly(params.N, uint64(len(params.Q1)))} + plaintext.bfvElement.value = []*ring.Poly{ring.NewPoly(uint64(1< 60 { + panic("provided moduli must be smaller than 61") + } + } + + for _, pj := range params.P { + primesbitlen[uint64(pj)]++ + + if uint64(pj) > 60 { + panic("provided P must be smaller than 61") + } + } + + for i, qi := range params.Q2 { + + primesbitlen[uint64(qi)]++ + + if uint64(params.Q2[i]) > 60 { + panic("provided moduli must be smaller than 61") + } + } + + // For each bitsize, finds that many primes + primes := make(map[uint64][]uint64) + for key, value := range primesbitlen { + primes[key] = generateNTTPrimes(key, uint64(params.LogN), value) + } + + // Assigns the primes to the ckks moduli chain + Q1 = make([]uint64, len(params.Q1)) + for i, qi := range params.Q1 { + Q1[i] = primes[uint64(params.Q1[i])][0] + primes[uint64(qi)] = primes[uint64(qi)][1:] + } + + // Assigns the primes to the special primes list for the the keyscontext + P = make([]uint64, len(params.P)) + for i, pj := range params.P { + P[i] = primes[uint64(pj)][0] + primes[uint64(pj)] = primes[uint64(pj)][1:] + } + + Q2 = make([]uint64, len(params.Q2)) + for i, qi := range params.Q2 { + Q2[i] = primes[uint64(params.Q2[i])][0] + primes[uint64(qi)] = primes[uint64(qi)][1:] + } + + return Q1, P, Q2 +} + +func generateNTTPrimes(logQ, logN, levels uint64) (primes []uint64) { + + // generateCKKSPrimes generates primes given logQ = size of the primes, logN = size of N and level, the number + // of levels required. Will return all the appropriate primes, up to the number of level, with the + // best avaliable deviation from the base power of 2 for the given level. + + if logQ > 60 { + panic("logQ must be between 1 and 60") + } + + var x, y, Qpow2, _2N uint64 + + primes = []uint64{} + + Qpow2 = 1 << logQ + + _2N = 2 << logN + + x = Qpow2 + 1 + y = Qpow2 + 1 + + for true { + + if ring.IsPrime(y) { + primes = append(primes, y) + if uint64(len(primes)) == levels { + return primes + } + } + + y -= _2N + + if ring.IsPrime(x) { + primes = append(primes, x) + if uint64(len(primes)) == levels { + return primes + } + } + + x += _2N + } + + return +}