From fc2ecb87d6bc683dfdacfc0779aa7639016e033e Mon Sep 17 00:00:00 2001 From: Christian Grigis Date: Thu, 28 Nov 2019 19:10:01 +0100 Subject: [PATCH] Add cksProtocolContext for dbfv --- dbfv/dbfv_benchmark_test.go | 3 +- dbfv/dbfv_test.go | 4 +- dbfv/keyswitching.go | 100 +++++++++++++++++++++++++++++------- examples/dbfv/pir/pir.go | 2 +- 4 files changed, 86 insertions(+), 23 deletions(-) diff --git a/dbfv/dbfv_benchmark_test.go b/dbfv/dbfv_benchmark_test.go index a8047475..4eb996d7 100644 --- a/dbfv/dbfv_benchmark_test.go +++ b/dbfv/dbfv_benchmark_test.go @@ -215,7 +215,6 @@ func benchKeyswitching(b *testing.B) { params := genDBFVContext(¶meters) - bfvContext := params.bfvContext sk0Shards := params.sk0Shards sk1Shards := params.sk1Shards @@ -229,7 +228,7 @@ func benchKeyswitching(b *testing.B) { } p := new(Party) - p.CKSProtocol = NewCKSProtocol(bfvContext, 6.36) + p.CKSProtocol = NewCKSProtocol(¶meters, 6.36) p.s0 = sk0Shards[0].Get() p.s1 = sk1Shards[0].Get() p.share = p.AllocateShare() diff --git a/dbfv/dbfv_test.go b/dbfv/dbfv_test.go index a36600c2..2587900e 100644 --- a/dbfv/dbfv_test.go +++ b/dbfv/dbfv_test.go @@ -361,7 +361,7 @@ func testKeyswitching(t *testing.T) { cksParties := make([]*Party, parties) for i := uint64(0); i < parties; i++ { p := new(Party) - p.CKSProtocol = NewCKSProtocol(bfvContext, 6.36) + p.CKSProtocol = NewCKSProtocol(¶meters, 6.36) p.s0 = sk0Shards[i].Get() p.s1 = sk1Shards[i].Get() p.share = p.AllocateShare() @@ -809,7 +809,7 @@ func Test_Marshalling(t *testing.T) { t.Run(fmt.Sprintf("CKS/N=%d/limbQ=%d/limbsP=%d", contextQ.N, len(contextQ.Modulus), len(contextPKeys.Modulus)), func(t *testing.T) { //Now for CKSShare ~ its similar to PKSShare - cksp := NewCKSProtocol(bfvCtx, bfvCtx.Sigma()) + cksp := NewCKSProtocol(params, bfvCtx.Sigma()) cksshare := cksp.AllocateShare() skIn := KeyGenerator.NewSecretKey() skOut := KeyGenerator.NewSecretKey() diff --git a/dbfv/keyswitching.go b/dbfv/keyswitching.go index 4ef35f4c..98d6cdff 100644 --- a/dbfv/keyswitching.go +++ b/dbfv/keyswitching.go @@ -1,13 +1,77 @@ package dbfv import ( + "math/big" + "github.com/ldsec/lattigo/bfv" "github.com/ldsec/lattigo/ring" ) +type cksProtocolContext struct { + // Polynomial contexts + contextQ *ring.Context + + contextKeys *ring.Context + contextPKeys *ring.Context + specialPrimes []uint64 + rescaleParamsKeys []uint64 // (P^-1) mod each qi +} + +func newCksProtocolContext(params *bfv.Parameters) *cksProtocolContext { + n := params.N + + contextQ := ring.NewContext() + contextQ.SetParameters(n, params.Qi) + err := contextQ.GenNTTParams() + if err != nil { + panic(err) + } + + contextKeys := ring.NewContext() + contextKeys.SetParameters(n, append(params.Qi, params.KeySwitchPrimes...)) + err = contextKeys.GenNTTParams() + if err != nil { + panic(err) + } + + contextPKeys := ring.NewContext() + contextPKeys.SetParameters(n, params.KeySwitchPrimes) + err = contextPKeys.GenNTTParams() + if err != nil { + panic(err) + } + + specialPrimes := make([]uint64, len(params.KeySwitchPrimes)) + for i := range params.KeySwitchPrimes { + specialPrimes[i] = params.KeySwitchPrimes[i] + } + + rescaleParamsKeys := make([]uint64, len(params.Qi)) + + PBig := ring.NewUint(1) + for _, pj := range specialPrimes { + PBig.Mul(PBig, ring.NewUint(pj)) + } + + tmp := new(big.Int) + bredParams := contextQ.GetBredParams() + for i, Qi := range params.Qi { + tmp.Mod(PBig, ring.NewUint(Qi)) + rescaleParamsKeys[i] = ring.MForm(ring.ModExp(ring.BRedAdd(tmp.Uint64(), Qi, bredParams[i]), Qi-2, Qi), Qi, bredParams[i]) + } + + return &cksProtocolContext{ + contextQ: contextQ, + contextKeys: contextKeys, + contextPKeys: contextPKeys, + specialPrimes: specialPrimes, + rescaleParamsKeys: rescaleParamsKeys, + } +} + // CKSProtocol is a structure storing the parameters for the collective key-switching protocol. type CKSProtocol struct { - bfvContext *bfv.Context + context *cksProtocolContext sigmaSmudging float64 gaussianSamplerSmudge *ring.KYSampler @@ -35,19 +99,20 @@ func (share *CKSShare) UnmarshalBinary(data []byte) error { // NewCKSProtocol creates a new CKSProtocol that will be used to operate a collective key-switching on a ciphertext encrypted under a collective public-key, whose // secret-shares are distributed among j parties, re-encrypting the ciphertext under another public-key, whose secret-shares are also known to the // parties. -func NewCKSProtocol(bfvContext *bfv.Context, sigmaSmudging float64) *CKSProtocol { +func NewCKSProtocol(params *bfv.Parameters, sigmaSmudging float64) *CKSProtocol { + context := newCksProtocolContext(params) cks := new(CKSProtocol) - cks.bfvContext = bfvContext + cks.context = context - cks.gaussianSamplerSmudge = bfvContext.ContextKeys().NewKYSampler(sigmaSmudging, int(6*sigmaSmudging)) + cks.gaussianSamplerSmudge = context.contextKeys.NewKYSampler(sigmaSmudging, int(6*sigmaSmudging)) - cks.tmpNtt = cks.bfvContext.ContextKeys().NewPoly() - cks.tmpDelta = cks.bfvContext.ContextQ().NewPoly() - cks.hP = cks.bfvContext.ContextPKeys().NewPoly() + cks.tmpNtt = cks.context.contextKeys.NewPoly() + cks.tmpDelta = cks.context.contextQ.NewPoly() + cks.hP = cks.context.contextPKeys.NewPoly() - cks.baseconverter = ring.NewFastBasisExtender(cks.bfvContext.ContextQ().Modulus, cks.bfvContext.KeySwitchPrimes()) + cks.baseconverter = ring.NewFastBasisExtender(cks.context.contextQ.Modulus, cks.context.specialPrimes) return cks } @@ -55,8 +120,7 @@ func NewCKSProtocol(bfvContext *bfv.Context, sigmaSmudging float64) *CKSProtocol // AllocateShare allocates the shares of the CKSProtocol func (cks *CKSProtocol) AllocateShare() CKSShare { - //return cks.bfvContext.ContextQ().NewPoly() - return CKSShare{cks.bfvContext.ContextQ().NewPoly()} + return CKSShare{cks.context.contextQ.NewPoly()} } @@ -68,7 +132,7 @@ func (cks *CKSProtocol) AllocateShare() CKSShare { // Each party then broadcast the result of this computation to the other j-1 parties. func (cks *CKSProtocol) GenShare(skInput, skOutput *ring.Poly, ct *bfv.Ciphertext, shareOut CKSShare) { - cks.bfvContext.ContextQ().Sub(skInput, skOutput, cks.tmpDelta) + cks.context.contextQ.Sub(skInput, skOutput, cks.tmpDelta) cks.genShareDelta(cks.tmpDelta, ct, shareOut) } @@ -77,8 +141,8 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh level := uint64(len(ct.Value()[1].Coeffs) - 1) - contextQ := cks.bfvContext.ContextQ() - contextP := cks.bfvContext.ContextPKeys() + contextQ := cks.context.contextQ + contextP := cks.context.contextPKeys contextQ.NTT(ct.Value()[1], cks.tmpNtt) contextQ.MulCoeffsMontgomery(cks.tmpNtt, skDelta, shareOut.Poly) @@ -89,7 +153,7 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh cks.gaussianSamplerSmudge.Sample(cks.tmpNtt) contextQ.Add(shareOut.Poly, cks.tmpNtt, shareOut.Poly) - for x, i := 0, uint64(len(contextQ.Modulus)); i < uint64(len(cks.bfvContext.ContextKeys().Modulus)); x, i = x+1, i+1 { + for x, i := 0, uint64(len(contextQ.Modulus)); i < uint64(len(cks.context.contextKeys.Modulus)); x, i = x+1, i+1 { tmphP := cks.hP.Coeffs[x] tmpNTT := cks.tmpNtt.Coeffs[i] for j := uint64(0); j < contextQ.N; j++ { @@ -97,7 +161,7 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh } } - cks.baseconverter.ModDownSplited(contextQ, contextP, cks.bfvContext.RescaleParamsKeys(), level, shareOut.Poly, cks.hP, shareOut.Poly, cks.tmpNtt) + cks.baseconverter.ModDownSplited(contextQ, contextP, cks.context.rescaleParamsKeys, level, shareOut.Poly, cks.hP, shareOut.Poly, cks.tmpNtt) cks.tmpNtt.Zero() cks.hP.Zero() @@ -107,11 +171,11 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh // // [ctx[0] + sum((skInput_i - skOutput_i) * ctx[0] + e_i), ctx[1]] func (cks *CKSProtocol) AggregateShares(share1, share2, shareOut CKSShare) { - cks.bfvContext.ContextQ().Add(share1.Poly, share2.Poly, shareOut.Poly) + cks.context.contextQ.Add(share1.Poly, share2.Poly, shareOut.Poly) } // KeySwitch performs the actual keyswitching operation on a ciphertext ct and put the result in ctOut func (cks *CKSProtocol) KeySwitch(combined CKSShare, ct *bfv.Ciphertext, ctOut *bfv.Ciphertext) { - cks.bfvContext.ContextQ().Add(ct.Value()[0], combined.Poly, ctOut.Value()[0]) - cks.bfvContext.ContextQ().Copy(ct.Value()[1], ctOut.Value()[1]) + cks.context.contextQ.Add(ct.Value()[0], combined.Poly, ctOut.Value()[0]) + cks.context.contextQ.Copy(ct.Value()[1], ctOut.Value()[1]) } diff --git a/examples/dbfv/pir/pir.go b/examples/dbfv/pir/pir.go index af6d043b..9bb1bb30 100644 --- a/examples/dbfv/pir/pir.go +++ b/examples/dbfv/pir/pir.go @@ -105,7 +105,7 @@ func main() { ckg := dbfv.NewCKGProtocol(bfvctx) // public key generation rkg := dbfv.NewEkgProtocol(bfvctx) // relineariation key generation rtg := dbfv.NewRotKGProtocol(bfvctx) // rotation keys generation - cks := dbfv.NewCKSProtocol(bfvctx, 3.19) // collective public-key re-encryption + cks := dbfv.NewCKSProtocol(params, 3.19) // collective public-key re-encryption // Creates each party, and allocates the memory for all the shares that the protocols will need P := make([]*party, N, N)