From f0a8e4688caa809a8bfeeffc68b79cb87938cde8 Mon Sep 17 00:00:00 2001 From: Christian Grigis Date: Thu, 28 Nov 2019 21:22:41 +0100 Subject: [PATCH] Add pcksProtocolContext for dbfv --- dbfv/dbfv_benchmark_test.go | 3 +- dbfv/dbfv_test.go | 4 +- dbfv/public_keyswitching.go | 104 +++++++++++++++++++++++++++++------- examples/dbfv/psi/psi.go | 2 +- 4 files changed, 89 insertions(+), 24 deletions(-) diff --git a/dbfv/dbfv_benchmark_test.go b/dbfv/dbfv_benchmark_test.go index 962ba349..26894d3c 100644 --- a/dbfv/dbfv_benchmark_test.go +++ b/dbfv/dbfv_benchmark_test.go @@ -263,7 +263,6 @@ func benchPublicKeySwitching(b *testing.B) { params := genDBFVContext(¶meters) - bfvContext := params.bfvContext sk0Shards := params.sk0Shards pk1 := params.pk1 @@ -278,7 +277,7 @@ func benchPublicKeySwitching(b *testing.B) { } p := new(Party) - p.PCKSProtocol = NewPCKSProtocol(bfvContext, 6.36) + p.PCKSProtocol = NewPCKSProtocol(¶meters, 6.36) p.s = sk0Shards[0].Get() p.share = p.AllocateShares() diff --git a/dbfv/dbfv_test.go b/dbfv/dbfv_test.go index b1afd7e8..c4816955 100644 --- a/dbfv/dbfv_test.go +++ b/dbfv/dbfv_test.go @@ -419,7 +419,7 @@ func testPublicKeySwitching(t *testing.T) { pcksParties := make([]*Party, parties) for i := uint64(0); i < parties; i++ { p := new(Party) - p.PCKSProtocol = NewPCKSProtocol(bfvContext, 6.36) + p.PCKSProtocol = NewPCKSProtocol(¶meters, 6.36) p.s = sk0Shards[i].Get() p.share = p.AllocateShares() pcksParties[i] = p @@ -770,7 +770,7 @@ func Test_Marshalling(t *testing.T) { t.Run(fmt.Sprintf("PCKS/N=%d/limbQ=%d/limbsP=%d", contextQ.N, len(contextQ.Modulus), len(contextPKeys.Modulus)), func(t *testing.T) { //Check marshalling for the PCKS - KeySwitchProtocol := NewPCKSProtocol(bfvCtx, bfvCtx.Sigma()) + KeySwitchProtocol := NewPCKSProtocol(params, bfvCtx.Sigma()) SwitchShare := KeySwitchProtocol.AllocateShares() pk := KeyGenerator.NewPublicKey(sk) KeySwitchProtocol.GenShare(sk.Get(), pk, Ciphertext, SwitchShare) diff --git a/dbfv/public_keyswitching.go b/dbfv/public_keyswitching.go index c5a180e8..d7767214 100644 --- a/dbfv/public_keyswitching.go +++ b/dbfv/public_keyswitching.go @@ -1,13 +1,78 @@ package dbfv import ( + "math/big" + "github.com/ldsec/lattigo/bfv" "github.com/ldsec/lattigo/ring" ) +type pcksProtocolContext struct { + // Polynomial degree + n uint64 + + // Ternary and Gaussian samplers + gaussianSampler *ring.KYSampler + + // Polynomial contexts + contextQ *ring.Context + + contextKeys *ring.Context + specialPrimes []uint64 + rescaleParamsKeys []uint64 // (P^-1) mod each qi +} + +func newPcksProtocolContext(params *bfv.Parameters) *pcksProtocolContext { + n := params.N + + contextQ := ring.NewContext() + contextQ.SetParameters(n, params.Qi) + err := contextQ.GenNTTParams() + if err != nil { + panic(err) + } + + contextKeys := ring.NewContext() + contextKeys.SetParameters(n, append(params.Qi, params.KeySwitchPrimes...)) + err = contextKeys.GenNTTParams() + if err != nil { + panic(err) + } + + specialPrimes := make([]uint64, len(params.KeySwitchPrimes)) + for i := range params.KeySwitchPrimes { + specialPrimes[i] = params.KeySwitchPrimes[i] + } + + rescaleParamsKeys := make([]uint64, len(params.Qi)) + + PBig := ring.NewUint(1) + for _, pj := range specialPrimes { + PBig.Mul(PBig, ring.NewUint(pj)) + } + + tmp := new(big.Int) + bredParams := contextQ.GetBredParams() + for i, Qi := range params.Qi { + tmp.Mod(PBig, ring.NewUint(Qi)) + rescaleParamsKeys[i] = ring.MForm(ring.ModExp(ring.BRedAdd(tmp.Uint64(), Qi, bredParams[i]), Qi-2, Qi), Qi, bredParams[i]) + } + + gaussianSampler := contextKeys.NewKYSampler(params.Sigma, int(6*params.Sigma)) + + return &pcksProtocolContext{ + n: n, + contextQ: contextQ, + contextKeys: contextKeys, + specialPrimes: specialPrimes, + gaussianSampler: gaussianSampler, + rescaleParamsKeys: rescaleParamsKeys, + } +} + // PCKSProtocol is the structure storing the parameters for the collective public key-switching. type PCKSProtocol struct { - bfvContext *bfv.Context + context *pcksProtocolContext sigmaSmudging float64 gaussianSamplerSmudge *ring.KYSampler @@ -73,27 +138,28 @@ func (share *PCKSShare) UnmarshalBinary(data []byte) error { // NewPCKSProtocol creates a new PCKSProtocol object and will be used to re-encrypt a ciphertext ctx encrypted under a secret-shared key among j parties under a new // collective public-key. -func NewPCKSProtocol(bfvContext *bfv.Context, sigmaSmudging float64) *PCKSProtocol { +func NewPCKSProtocol(params *bfv.Parameters, sigmaSmudging float64) *PCKSProtocol { + context := newPcksProtocolContext(params) pcks := new(PCKSProtocol) - pcks.bfvContext = bfvContext + pcks.context = context - pcks.gaussianSamplerSmudge = bfvContext.ContextKeys().NewKYSampler(sigmaSmudging, int(6*sigmaSmudging)) + pcks.gaussianSamplerSmudge = context.contextKeys.NewKYSampler(sigmaSmudging, int(6*sigmaSmudging)) - pcks.tmp = bfvContext.ContextKeys().NewPoly() - pcks.share0tmp = bfvContext.ContextKeys().NewPoly() - pcks.share1tmp = bfvContext.ContextKeys().NewPoly() + pcks.tmp = context.contextKeys.NewPoly() + pcks.share0tmp = context.contextKeys.NewPoly() + pcks.share1tmp = context.contextKeys.NewPoly() - pcks.baseconverter = ring.NewFastBasisExtender(bfvContext.ContextQ().Modulus, bfvContext.KeySwitchPrimes()) + pcks.baseconverter = ring.NewFastBasisExtender(context.contextQ.Modulus, context.specialPrimes) return pcks } // AllocateShares allocates the shares of the PCKS protocol func (pcks *PCKSProtocol) AllocateShares() (s PCKSShare) { - s[0] = pcks.bfvContext.ContextQ().NewPoly() - s[1] = pcks.bfvContext.ContextQ().NewPoly() + s[0] = pcks.context.contextQ.NewPoly() + s[1] = pcks.context.contextQ.NewPoly() return } @@ -104,8 +170,8 @@ func (pcks *PCKSProtocol) AllocateShares() (s PCKSShare) { // and broadcasts the result to the other j-1 parties. func (pcks *PCKSProtocol) GenShare(sk *ring.Poly, pk *bfv.PublicKey, ct *bfv.Ciphertext, shareOut PCKSShare) { - contextQ := pcks.bfvContext.ContextQ() - contextKeys := pcks.bfvContext.ContextKeys() + contextQ := pcks.context.contextQ + contextKeys := pcks.context.contextKeys contextKeys.SampleTernaryMontgomeryNTT(pcks.tmp, 0.5) @@ -120,14 +186,14 @@ func (pcks *PCKSProtocol) GenShare(sk *ring.Poly, pk *bfv.PublicKey, ct *bfv.Cip // h_0 = u_i * pk_0 + e0 pcks.gaussianSamplerSmudge.SampleAndAdd(pcks.share0tmp) // h_1 = u_i * pk_1 + e1 - pcks.bfvContext.GaussianSampler().SampleAndAdd(pcks.share1tmp) + pcks.context.gaussianSampler.SampleAndAdd(pcks.share1tmp) // h_0 = (u_i * pk_0 + e0)/P - pcks.baseconverter.ModDown(contextKeys, pcks.bfvContext.RescaleParamsKeys(), uint64(len(contextQ.Modulus))-1, pcks.share0tmp, shareOut[0], pcks.tmp) + pcks.baseconverter.ModDown(contextKeys, pcks.context.rescaleParamsKeys, uint64(len(contextQ.Modulus))-1, pcks.share0tmp, shareOut[0], pcks.tmp) // h_0 = (u_i * pk_0 + e0)/P // Could be moved to the keyswitch phase, but the second element of the shares will be larger - pcks.baseconverter.ModDown(contextKeys, pcks.bfvContext.RescaleParamsKeys(), uint64(len(contextQ.Modulus))-1, pcks.share1tmp, shareOut[1], pcks.tmp) + pcks.baseconverter.ModDown(contextKeys, pcks.context.rescaleParamsKeys, uint64(len(contextQ.Modulus))-1, pcks.share1tmp, shareOut[1], pcks.tmp) // tmp = s_i*c_1 contextQ.NTT(ct.Value()[1], pcks.tmp) @@ -147,13 +213,13 @@ func (pcks *PCKSProtocol) GenShare(sk *ring.Poly, pk *bfv.PublicKey, ct *bfv.Cip // [ctx[0] + sum(s_i * ctx[0] + u_i * pk[0] + e_0i), sum(u_i * pk[1] + e_1i)] func (pcks *PCKSProtocol) AggregateShares(share1, share2, shareOut PCKSShare) { - pcks.bfvContext.ContextQ().Add(share1[0], share2[0], shareOut[0]) - pcks.bfvContext.ContextQ().Add(share1[1], share2[1], shareOut[1]) + pcks.context.contextQ.Add(share1[0], share2[0], shareOut[0]) + pcks.context.contextQ.Add(share1[1], share2[1], shareOut[1]) } // KeySwitch performs the actual keyswitching operation on a ciphertext ct and put the result in ctOut func (pcks *PCKSProtocol) KeySwitch(combined PCKSShare, ct, ctOut *bfv.Ciphertext) { - pcks.bfvContext.ContextQ().Add(ct.Value()[0], combined[0], ctOut.Value()[0]) - pcks.bfvContext.ContextQ().Copy(combined[1], ctOut.Value()[1]) + pcks.context.contextQ.Add(ct.Value()[0], combined[0], ctOut.Value()[0]) + pcks.context.contextQ.Copy(combined[1], ctOut.Value()[1]) } diff --git a/examples/dbfv/psi/psi.go b/examples/dbfv/psi/psi.go index 96b37c6c..14269928 100644 --- a/examples/dbfv/psi/psi.go +++ b/examples/dbfv/psi/psi.go @@ -83,7 +83,7 @@ func main() { ckg := dbfv.NewCKGProtocol(bfvctx) rkg := dbfv.NewEkgProtocol(params) - pcks := dbfv.NewPCKSProtocol(bfvctx, 3.19) + pcks := dbfv.NewPCKSProtocol(params, 3.19) P := make([]*party, N, N) for i := range P {