From 6991072a2ff00c91ea014c90be2e2c16b6faa7bc Mon Sep 17 00:00:00 2001 From: Christian Grigis Date: Wed, 27 Nov 2019 19:34:53 +0100 Subject: [PATCH] Added keygen context for bfv --- bfv/bfv_benchmark_test.go | 2 +- bfv/bfv_test.go | 6 +- bfv/keygen.go | 110 ++++++++++++++++++++++++++++++----- dbfv/dbfv_benchmark_test.go | 8 +-- dbfv/dbfv_test.go | 24 ++++---- examples/bfv/examples_bfv.go | 2 +- examples/dbfv/pir/pir.go | 8 +-- examples/dbfv/psi/psi.go | 8 +-- 8 files changed, 127 insertions(+), 41 deletions(-) diff --git a/bfv/bfv_benchmark_test.go b/bfv/bfv_benchmark_test.go index 014ff973..5ce40291 100755 --- a/bfv/bfv_benchmark_test.go +++ b/bfv/bfv_benchmark_test.go @@ -18,7 +18,7 @@ func Benchmark_BFV(b *testing.B) { var pk *PublicKey var err error - kgen := bfvContext.NewKeyGenerator() + kgen := NewKeyGenerator(¶ms) // Public Key Generation b.Run(testString("KeyGen", ¶ms), func(b *testing.B) { diff --git a/bfv/bfv_test.go b/bfv/bfv_test.go index 93729290..0d241a97 100644 --- a/bfv/bfv_test.go +++ b/bfv/bfv_test.go @@ -174,7 +174,7 @@ func testMarshaller(t *testing.T) { t.Run(testString2("RotationKey/", params), func(t *testing.T) { - rotationKey := params.bfvContext.NewRotationKeys() + rotationKey := NewRotationKeys() params.kgen.GenRot(RotationRow, params.sk, 0, rotationKey) params.kgen.GenRot(RotationLeft, params.sk, 1, rotationKey) @@ -246,7 +246,7 @@ func genBfvParams(contextParameters *Parameters) (params *bfvParams) { params.bfvContext = NewContextWithParam(contextParameters) - params.kgen = params.bfvContext.NewKeyGenerator() + params.kgen = NewKeyGenerator(contextParameters) params.sk, params.pk = params.kgen.NewKeyPair() @@ -512,7 +512,7 @@ func testRotateRows(t *testing.T) { params := genBfvParams(parameters) - rotkey := params.bfvContext.NewRotationKeys() + rotkey := NewRotationKeys() params.kgen.GenRot(RotationRow, params.sk, 0, rotkey) t.Run(testString2("InPlace/", params), func(t *testing.T) { diff --git a/bfv/keygen.go b/bfv/keygen.go index abcb7b71..d16f87de 100644 --- a/bfv/keygen.go +++ b/bfv/keygen.go @@ -1,13 +1,88 @@ package bfv import ( + "math" + "github.com/ldsec/lattigo/ring" ) +type keyGeneratorContext struct { + // Polynomial degree + n uint64 + + // Polynomial contexts + contextKeys *ring.Context + specialPrimes []uint64 + alpha uint64 + beta uint64 + + // Ternary and Gaussian samplers + gaussianSampler *ring.KYSampler + + // Galois elements used to permute the batched plaintext in the encrypted domain + galElRotRow uint64 + galElRotColLeft []uint64 + galElRotColRight []uint64 +} + +func newKeyGeneratorContext(params *Parameters) *keyGeneratorContext { + n := params.N + + contextKeys := ring.NewContext() + + contextKeys.SetParameters(params.N, append(params.Qi, params.KeySwitchPrimes...)) + + err := contextKeys.GenNTTParams() + if err != nil { + panic(err) + } + + specialPrimes := make([]uint64, len(params.KeySwitchPrimes)) + for i := range params.KeySwitchPrimes { + specialPrimes[i] = params.KeySwitchPrimes[i] + } + + alpha := uint64(len(specialPrimes)) + beta := uint64(math.Ceil(float64(len(params.Qi)) / float64(alpha))) + + gaussianSampler := contextKeys.NewKYSampler(params.Sigma, int(6*params.Sigma)) + + gen := GaloisGen + genInv := ring.ModExp(gen, (n<<1)-1, n<<1) + + mask := (n << 1) - 1 + + galElRotColLeft := make([]uint64, n>>1) + galElRotColRight := make([]uint64, n>>1) + + galElRotColRight[0] = 1 + galElRotColLeft[0] = 1 + + for i := uint64(1); i < n>>1; i++ { + galElRotColLeft[i] = (galElRotColLeft[i-1] * gen) & mask + galElRotColRight[i] = (galElRotColRight[i-1] * genInv) & mask + + } + + galElRotRow := (n << 1) - 1 + + return &keyGeneratorContext{ + n: n, + contextKeys: contextKeys, + specialPrimes: specialPrimes, + alpha: alpha, + beta: beta, + gaussianSampler: gaussianSampler, + galElRotRow: galElRotRow, + galElRotColLeft: galElRotColLeft, + galElRotColRight: galElRotColRight, + } +} + // KeyGenerator is a structure that stores the elements required to create new keys, // as well as a small memory pool for intermediate values. type KeyGenerator struct { - context *Context + context *keyGeneratorContext polypool *ring.Poly } @@ -54,10 +129,12 @@ func (swk *SwitchingKey) Get() [][2]*ring.Poly { // NewKeyGenerator creates a new KeyGenerator, from which the secret and public keys, as well as the evaluation, // rotation and switching keys can be generated. -func (context *Context) NewKeyGenerator() (keygen *KeyGenerator) { +func NewKeyGenerator(params *Parameters) (keygen *KeyGenerator) { + context := newKeyGeneratorContext(params) + keygen = new(KeyGenerator) keygen.context = context - keygen.polypool = keygen.context.ContextKeys().NewPoly() + keygen.polypool = keygen.context.contextKeys.NewPoly() return } @@ -74,7 +151,9 @@ func (keygen *KeyGenerator) NewSecretkeyWithDistrib(p float64) (sk *SecretKey) { } // NewSecretKey creates a new SecretKey zero values. -func (context *Context) NewSecretKey() *SecretKey { +func NewSecretKey(params *Parameters) *SecretKey { + context := newKeyGeneratorContext(params) + sk := new(SecretKey) sk.sk = context.contextKeys.NewPoly() return sk @@ -95,7 +174,7 @@ func (keygen *KeyGenerator) NewPublicKey(sk *SecretKey) (pk *PublicKey) { pk = new(PublicKey) - ringContext := keygen.context.ContextKeys() + ringContext := keygen.context.contextKeys //pk[0] = [-(a*s + e)] //pk[1] = [a] @@ -109,7 +188,9 @@ func (keygen *KeyGenerator) NewPublicKey(sk *SecretKey) (pk *PublicKey) { } // NewPublicKey creates a new PublicKey with zero values. -func (context *Context) NewPublicKey() (pk *PublicKey) { +func NewPublicKey(params *Parameters) (pk *PublicKey) { + context := newKeyGeneratorContext(params) + pk = new(PublicKey) pk.pk[0] = context.contextKeys.NewPoly() pk.pk[1] = context.contextKeys.NewPoly() @@ -145,7 +226,7 @@ func (keygen *KeyGenerator) NewRelinKey(sk *SecretKey, maxDegree uint64) (evk *E ringContext := keygen.context.contextKeys - for _, pj := range keygen.context.specialprimes { + for _, pj := range keygen.context.specialPrimes { ringContext.MulScalar(keygen.polypool, pj, keygen.polypool) } @@ -160,7 +241,8 @@ func (keygen *KeyGenerator) NewRelinKey(sk *SecretKey, maxDegree uint64) (evk *E } // NewRelinKey creates a new EvaluationKey with zero values. -func (context *Context) NewRelinKey(maxDegree uint64) (evakey *EvaluationKey) { +func NewRelinKey(maxDegree uint64, params *Parameters) (evakey *EvaluationKey) { + context := newKeyGeneratorContext(params) evakey = new(EvaluationKey) @@ -211,7 +293,7 @@ func (keygen *KeyGenerator) NewSwitchingKey(skIn, skOut *SecretKey) (evk *Switch ringContext.Sub(skIn.Get(), skOut.Get(), keygen.polypool) - for _, pj := range keygen.context.specialprimes { + for _, pj := range keygen.context.specialPrimes { ringContext.MulScalar(keygen.polypool, pj, keygen.polypool) } @@ -222,7 +304,9 @@ func (keygen *KeyGenerator) NewSwitchingKey(skIn, skOut *SecretKey) (evk *Switch } // NewSwitchingKey creates a new SwitchingKey with zero values. -func (context *Context) NewSwitchingKey() (evakey *SwitchingKey) { +func NewSwitchingKey(params *Parameters) (evakey *SwitchingKey) { + context := newKeyGeneratorContext(params) + evakey = new(SwitchingKey) // delta_sk = skIn - skOut = GaloisEnd(skOut, rotation) - skOut @@ -236,7 +320,7 @@ func (context *Context) NewSwitchingKey() (evakey *SwitchingKey) { return } -func newswitchintkey(context *Context, skIn, skOut *ring.Poly) (switchkey *SwitchingKey) { +func newswitchintkey(context *keyGeneratorContext, skIn, skOut *ring.Poly) (switchkey *SwitchingKey) { switchkey = new(SwitchingKey) @@ -286,7 +370,7 @@ func newswitchintkey(context *Context, skIn, skOut *ring.Poly) (switchkey *Switc } // NewRotationKeys returns a new empty RotationKeys struct. -func (context *Context) NewRotationKeys() (rotKey *RotationKeys) { +func NewRotationKeys() (rotKey *RotationKeys) { rotKey = new(RotationKeys) return } @@ -382,7 +466,7 @@ func genrotkey(keygen *KeyGenerator, sk *ring.Poly, gen uint64) (switchkey *Swit ring.PermuteNTT(sk, gen, keygen.polypool) ringContext.Sub(keygen.polypool, sk, keygen.polypool) - for _, pj := range keygen.context.specialprimes { + for _, pj := range keygen.context.specialPrimes { ringContext.MulScalar(keygen.polypool, pj, keygen.polypool) } diff --git a/dbfv/dbfv_benchmark_test.go b/dbfv/dbfv_benchmark_test.go index 35e4148d..a8047475 100644 --- a/dbfv/dbfv_benchmark_test.go +++ b/dbfv/dbfv_benchmark_test.go @@ -57,7 +57,7 @@ func benchPublicKeyGen(b *testing.B) { } }) - pk := bfvContext.NewPublicKey() + pk := bfv.NewPublicKey(¶meters) b.Run(testString("Finalize", ¶meters), func(b *testing.B) { for i := 0; i < b.N; i++ { p.GenPublicKey(p.s1, crp, pk) @@ -92,7 +92,7 @@ func benchRelinKeyGen(b *testing.B) { p.u = p.RKGProtocol.NewEphemeralKey(1.0 / 3.0) p.s = sk0Shards[0].Get() p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares() - p.rlk = bfvContext.NewRelinKey(2) + p.rlk = bfv.NewRelinKey(2, ¶meters) crpGenerator := ring.NewCRPGenerator(nil, bfvContext.ContextKeys()) @@ -171,7 +171,7 @@ func benchRelinKeyGenNaive(b *testing.B) { p.RKGProtocolNaive = NewRKGProtocolNaive(bfvContext) p.s = sk0Shards[0].Get() p.share1, p.share2 = p.AllocateShares() - p.rlk = bfvContext.NewRelinKey(2) + p.rlk = bfv.NewRelinKey(2, ¶meters) b.Run(testString("Round1/Gen", ¶meters), func(b *testing.B) { @@ -353,7 +353,7 @@ func benchRotKeyGen(b *testing.B) { } }) - rotKey := bfvContext.NewRotationKeys() + rotKey := bfv.NewRotationKeys() b.Run(testString("Finalize", ¶meters), func(b *testing.B) { for i := 0; i < b.N; i++ { diff --git a/dbfv/dbfv_test.go b/dbfv/dbfv_test.go index e42ab529..a36600c2 100644 --- a/dbfv/dbfv_test.go +++ b/dbfv/dbfv_test.go @@ -74,7 +74,7 @@ func genDBFVContext(contextParameters *bfv.Parameters) (params *dbfvContext) { params.encoder = bfv.NewEncoder(contextParameters) params.evaluator = bfv.NewEvaluator(contextParameters) - kgen := params.bfvContext.NewKeyGenerator() + kgen := bfv.NewKeyGenerator(contextParameters) // SecretKeys params.sk0Shards = make([]*bfv.SecretKey, testParams.parties) @@ -236,7 +236,7 @@ func testRelinKeyGen(t *testing.T) { } } - evk := bfvContext.NewRelinKey(1) + evk := bfv.NewRelinKey(1, ¶meters) P0.GenRelinearizationKey(P0.share2, P0.share3, evk) coeffs, _, ciphertext := newTestVectors(params, encryptorPk0, t) @@ -312,7 +312,7 @@ func testRelinKeyGenNaive(t *testing.T) { } } - evk := bfvContext.NewRelinKey(1) + evk := bfv.NewRelinKey(1, ¶meters) P0.GenRelinearizationKey(P0.share2, evk) coeffs, _, ciphertext := newTestVectors(params, encryptorPk0, t) @@ -492,7 +492,7 @@ func testRotKeyGenRotRows(t *testing.T) { } } - rotkey := bfvContext.NewRotationKeys() + rotkey := bfv.NewRotationKeys() P0.Finalize(P0.share, crp, rotkey) coeffs, _, ciphertext := newTestVectors(params, encryptorPk0, t) @@ -566,7 +566,7 @@ func testRotKeyGenRotCols(t *testing.T) { } } - rotkey := bfvContext.NewRotationKeys() + rotkey := bfv.NewRotationKeys() P0.Finalize(P0.share, crp, rotkey) evaluator.RotateColumns(ciphertext, k, rotkey, receiver) @@ -600,7 +600,7 @@ func testRefresh(t *testing.T) { encoder := params.encoder decryptorSk0 := params.decryptorSk0 - kgen := bfvContext.NewKeyGenerator() + kgen := bfv.NewKeyGenerator(¶meters) rlk := kgen.NewRelinKey(params.sk0, 2) @@ -717,17 +717,18 @@ func verifyTestVectors(contextParams *dbfvContext, decryptor *bfv.Decryptor, coe } func Test_Marshalling(t *testing.T) { + params := &bfv.DefaultParams[1] //verify if the un.marshalling works properly - bfvCtx := bfv.NewContextWithParam(&bfv.DefaultParams[1]) - KeyGenerator := bfvCtx.NewKeyGenerator() + bfvCtx := bfv.NewContextWithParam(params) + KeyGenerator := bfv.NewKeyGenerator(params) crsGen := ring.NewCRPGenerator([]byte{'l', 'a', 't', 't', 'i', 'g', 'o'}, bfvCtx.ContextKeys()) sk := KeyGenerator.NewSecretKey() crs := crsGen.ClockNew() contextQ := bfvCtx.ContextQ() contextPKeys := bfvCtx.ContextPKeys() - ringCtx := bfv.NewRingContext(&bfv.DefaultParams[1]) + ringCtx := bfv.NewRingContext(params) Ciphertext := bfv.NewRandomCiphertext(1, ringCtx) t.Run(fmt.Sprintf("CPK/N=%d/limbQ=%d/limbsP=%d", contextQ.N, len(contextQ.Modulus), len(contextPKeys.Modulus)), func(t *testing.T) { @@ -926,8 +927,9 @@ func Test_Marshalling(t *testing.T) { } func Test_Relin_Marshalling(t *testing.T) { + params := &bfv.DefaultParams[1] - bfvCtx := bfv.NewContextWithParam(&bfv.DefaultParams[1]) + bfvCtx := bfv.NewContextWithParam(params) contextQ := bfvCtx.ContextQ() contextPKeys := bfvCtx.ContextPKeys() modulus := bfvCtx.ContextQ().Modulus @@ -946,7 +948,7 @@ func Test_Relin_Marshalling(t *testing.T) { rlk := NewEkgProtocol(bfvCtx) u := rlk.NewEphemeralKey(1 / 3.0) - sk := bfvCtx.NewKeyGenerator().NewSecretKey() + sk := bfv.NewKeyGenerator(params).NewSecretKey() log.Print("Starting to test marshalling for share one") r1, r2, r3 := rlk.AllocateShares() diff --git a/examples/bfv/examples_bfv.go b/examples/bfv/examples_bfv.go index 9fe281b3..f7e73ecc 100644 --- a/examples/bfv/examples_bfv.go +++ b/examples/bfv/examples_bfv.go @@ -57,7 +57,7 @@ func obliviousRiding() { bfvContext := bfv.NewContextWithParam(¶ms) // Rider's keygen - kgen := bfvContext.NewKeyGenerator() + kgen := bfv.NewKeyGenerator(¶ms) riderSk, riderPk := kgen.NewKeyPair() diff --git a/examples/dbfv/pir/pir.go b/examples/dbfv/pir/pir.go index 9aa7b491..af6d043b 100644 --- a/examples/dbfv/pir/pir.go +++ b/examples/dbfv/pir/pir.go @@ -111,7 +111,7 @@ func main() { P := make([]*party, N, N) for i := range P { pi := &party{} - pi.sk = bfvctx.NewKeyGenerator().NewSecretKey() + pi.sk = bfv.NewKeyGenerator(params).NewSecretKey() pi.rlkEphemSk = bfvctx.ContextKeys().SampleTernaryMontgomeryNTTNew(1.0 / 3) pi.input = make([]uint64, params.N, params.N) for j := range pi.input { @@ -136,7 +136,7 @@ func main() { // 1) Collective public key generation l.Println("> CKG Phase") - pk := bfvctx.NewPublicKey() + pk := bfv.NewPublicKey(params) elapsedCKGParty = runTimedParty(func() { for _, pi := range P { ckg.GenShare(pi.sk.Get(), crs, pi.ckgShare) @@ -185,7 +185,7 @@ func main() { } }, N) - rlk := bfvctx.NewRelinKey(1) + rlk := bfv.NewRelinKey(1, params) elapsedRKGCloud += runTimed(func() { for _, pi := range P { rkg.AggregateShareRoundThree(pi.rkgShareThree, rkgCombined3, rkgCombined3) @@ -196,7 +196,7 @@ func main() { // 3) Collective rotation keys geneneration l.Println("> RTG Phase") - rtk := bfvctx.NewRotationKeys() + rtk := bfv.NewRotationKeys() for _, rot := range []bfv.Rotation{bfv.RotationRight, bfv.RotationLeft, bfv.RotationRow} { for k := uint64(1); (rot == bfv.RotationRow && k == 1) || (rot != bfv.RotationRow && k < bfvctx.ContextKeys().N>>1); k <<= 1 { diff --git a/examples/dbfv/psi/psi.go b/examples/dbfv/psi/psi.go index 04bcabf9..1353921c 100644 --- a/examples/dbfv/psi/psi.go +++ b/examples/dbfv/psi/psi.go @@ -72,7 +72,7 @@ func main() { crp[i] = crsGen.ClockNew() } - tsk, tpk := bfvctx.NewKeyGenerator().NewKeyPair() + tsk, tpk := bfv.NewKeyGenerator(params).NewKeyPair() colSk := &bfv.SecretKey{} colSk.Set(bfvctx.ContextKeys().NewPoly()) @@ -88,7 +88,7 @@ func main() { P := make([]*party, N, N) for i := range P { pi := &party{} - pi.sk = bfvctx.NewKeyGenerator().NewSecretKey() + pi.sk = bfv.NewKeyGenerator(params).NewSecretKey() pi.rlkEphemSk = bfvctx.ContextKeys().SampleTernaryMontgomeryNTTNew(1.0 / 3) pi.input = make([]uint64, params.N, params.N) for i := range pi.input { @@ -112,7 +112,7 @@ func main() { var elapsedRKGParty time.Duration l.Println("> CKG Phase") - pk := bfvctx.NewPublicKey() + pk := bfv.NewPublicKey(params) elapsedCKGParty = runTimedParty(func() { for _, pi := range P { ckg.GenShare(pi.sk.Get(), crs, pi.ckgShare) @@ -160,7 +160,7 @@ func main() { } }, N) - rlk := bfvctx.NewRelinKey(1) + rlk := bfv.NewRelinKey(1, params) elapsedRKGCloud += runTimed(func() { for _, pi := range P { rkg.AggregateShareRoundThree(pi.rkgShareThree, rkgCombined3, rkgCombined3)