Add cksProtocolContext for dbfv

This commit is contained in:
Christian Grigis
2019-11-28 19:10:01 +01:00
parent c0c089575c
commit fc2ecb87d6
4 changed files with 86 additions and 23 deletions

View File

@@ -215,7 +215,6 @@ func benchKeyswitching(b *testing.B) {
params := genDBFVContext(&parameters)
bfvContext := params.bfvContext
sk0Shards := params.sk0Shards
sk1Shards := params.sk1Shards
@@ -229,7 +228,7 @@ func benchKeyswitching(b *testing.B) {
}
p := new(Party)
p.CKSProtocol = NewCKSProtocol(bfvContext, 6.36)
p.CKSProtocol = NewCKSProtocol(&parameters, 6.36)
p.s0 = sk0Shards[0].Get()
p.s1 = sk1Shards[0].Get()
p.share = p.AllocateShare()

View File

@@ -361,7 +361,7 @@ func testKeyswitching(t *testing.T) {
cksParties := make([]*Party, parties)
for i := uint64(0); i < parties; i++ {
p := new(Party)
p.CKSProtocol = NewCKSProtocol(bfvContext, 6.36)
p.CKSProtocol = NewCKSProtocol(&parameters, 6.36)
p.s0 = sk0Shards[i].Get()
p.s1 = sk1Shards[i].Get()
p.share = p.AllocateShare()
@@ -809,7 +809,7 @@ func Test_Marshalling(t *testing.T) {
t.Run(fmt.Sprintf("CKS/N=%d/limbQ=%d/limbsP=%d", contextQ.N, len(contextQ.Modulus), len(contextPKeys.Modulus)), func(t *testing.T) {
//Now for CKSShare ~ its similar to PKSShare
cksp := NewCKSProtocol(bfvCtx, bfvCtx.Sigma())
cksp := NewCKSProtocol(params, bfvCtx.Sigma())
cksshare := cksp.AllocateShare()
skIn := KeyGenerator.NewSecretKey()
skOut := KeyGenerator.NewSecretKey()

View File

@@ -1,13 +1,77 @@
package dbfv
import (
"math/big"
"github.com/ldsec/lattigo/bfv"
"github.com/ldsec/lattigo/ring"
)
type cksProtocolContext struct {
// Polynomial contexts
contextQ *ring.Context
contextKeys *ring.Context
contextPKeys *ring.Context
specialPrimes []uint64
rescaleParamsKeys []uint64 // (P^-1) mod each qi
}
func newCksProtocolContext(params *bfv.Parameters) *cksProtocolContext {
n := params.N
contextQ := ring.NewContext()
contextQ.SetParameters(n, params.Qi)
err := contextQ.GenNTTParams()
if err != nil {
panic(err)
}
contextKeys := ring.NewContext()
contextKeys.SetParameters(n, append(params.Qi, params.KeySwitchPrimes...))
err = contextKeys.GenNTTParams()
if err != nil {
panic(err)
}
contextPKeys := ring.NewContext()
contextPKeys.SetParameters(n, params.KeySwitchPrimes)
err = contextPKeys.GenNTTParams()
if err != nil {
panic(err)
}
specialPrimes := make([]uint64, len(params.KeySwitchPrimes))
for i := range params.KeySwitchPrimes {
specialPrimes[i] = params.KeySwitchPrimes[i]
}
rescaleParamsKeys := make([]uint64, len(params.Qi))
PBig := ring.NewUint(1)
for _, pj := range specialPrimes {
PBig.Mul(PBig, ring.NewUint(pj))
}
tmp := new(big.Int)
bredParams := contextQ.GetBredParams()
for i, Qi := range params.Qi {
tmp.Mod(PBig, ring.NewUint(Qi))
rescaleParamsKeys[i] = ring.MForm(ring.ModExp(ring.BRedAdd(tmp.Uint64(), Qi, bredParams[i]), Qi-2, Qi), Qi, bredParams[i])
}
return &cksProtocolContext{
contextQ: contextQ,
contextKeys: contextKeys,
contextPKeys: contextPKeys,
specialPrimes: specialPrimes,
rescaleParamsKeys: rescaleParamsKeys,
}
}
// CKSProtocol is a structure storing the parameters for the collective key-switching protocol.
type CKSProtocol struct {
bfvContext *bfv.Context
context *cksProtocolContext
sigmaSmudging float64
gaussianSamplerSmudge *ring.KYSampler
@@ -35,19 +99,20 @@ func (share *CKSShare) UnmarshalBinary(data []byte) error {
// NewCKSProtocol creates a new CKSProtocol that will be used to operate a collective key-switching on a ciphertext encrypted under a collective public-key, whose
// secret-shares are distributed among j parties, re-encrypting the ciphertext under another public-key, whose secret-shares are also known to the
// parties.
func NewCKSProtocol(bfvContext *bfv.Context, sigmaSmudging float64) *CKSProtocol {
func NewCKSProtocol(params *bfv.Parameters, sigmaSmudging float64) *CKSProtocol {
context := newCksProtocolContext(params)
cks := new(CKSProtocol)
cks.bfvContext = bfvContext
cks.context = context
cks.gaussianSamplerSmudge = bfvContext.ContextKeys().NewKYSampler(sigmaSmudging, int(6*sigmaSmudging))
cks.gaussianSamplerSmudge = context.contextKeys.NewKYSampler(sigmaSmudging, int(6*sigmaSmudging))
cks.tmpNtt = cks.bfvContext.ContextKeys().NewPoly()
cks.tmpDelta = cks.bfvContext.ContextQ().NewPoly()
cks.hP = cks.bfvContext.ContextPKeys().NewPoly()
cks.tmpNtt = cks.context.contextKeys.NewPoly()
cks.tmpDelta = cks.context.contextQ.NewPoly()
cks.hP = cks.context.contextPKeys.NewPoly()
cks.baseconverter = ring.NewFastBasisExtender(cks.bfvContext.ContextQ().Modulus, cks.bfvContext.KeySwitchPrimes())
cks.baseconverter = ring.NewFastBasisExtender(cks.context.contextQ.Modulus, cks.context.specialPrimes)
return cks
}
@@ -55,8 +120,7 @@ func NewCKSProtocol(bfvContext *bfv.Context, sigmaSmudging float64) *CKSProtocol
// AllocateShare allocates the shares of the CKSProtocol
func (cks *CKSProtocol) AllocateShare() CKSShare {
//return cks.bfvContext.ContextQ().NewPoly()
return CKSShare{cks.bfvContext.ContextQ().NewPoly()}
return CKSShare{cks.context.contextQ.NewPoly()}
}
@@ -68,7 +132,7 @@ func (cks *CKSProtocol) AllocateShare() CKSShare {
// Each party then broadcast the result of this computation to the other j-1 parties.
func (cks *CKSProtocol) GenShare(skInput, skOutput *ring.Poly, ct *bfv.Ciphertext, shareOut CKSShare) {
cks.bfvContext.ContextQ().Sub(skInput, skOutput, cks.tmpDelta)
cks.context.contextQ.Sub(skInput, skOutput, cks.tmpDelta)
cks.genShareDelta(cks.tmpDelta, ct, shareOut)
}
@@ -77,8 +141,8 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
level := uint64(len(ct.Value()[1].Coeffs) - 1)
contextQ := cks.bfvContext.ContextQ()
contextP := cks.bfvContext.ContextPKeys()
contextQ := cks.context.contextQ
contextP := cks.context.contextPKeys
contextQ.NTT(ct.Value()[1], cks.tmpNtt)
contextQ.MulCoeffsMontgomery(cks.tmpNtt, skDelta, shareOut.Poly)
@@ -89,7 +153,7 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
cks.gaussianSamplerSmudge.Sample(cks.tmpNtt)
contextQ.Add(shareOut.Poly, cks.tmpNtt, shareOut.Poly)
for x, i := 0, uint64(len(contextQ.Modulus)); i < uint64(len(cks.bfvContext.ContextKeys().Modulus)); x, i = x+1, i+1 {
for x, i := 0, uint64(len(contextQ.Modulus)); i < uint64(len(cks.context.contextKeys.Modulus)); x, i = x+1, i+1 {
tmphP := cks.hP.Coeffs[x]
tmpNTT := cks.tmpNtt.Coeffs[i]
for j := uint64(0); j < contextQ.N; j++ {
@@ -97,7 +161,7 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
}
}
cks.baseconverter.ModDownSplited(contextQ, contextP, cks.bfvContext.RescaleParamsKeys(), level, shareOut.Poly, cks.hP, shareOut.Poly, cks.tmpNtt)
cks.baseconverter.ModDownSplited(contextQ, contextP, cks.context.rescaleParamsKeys, level, shareOut.Poly, cks.hP, shareOut.Poly, cks.tmpNtt)
cks.tmpNtt.Zero()
cks.hP.Zero()
@@ -107,11 +171,11 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
//
// [ctx[0] + sum((skInput_i - skOutput_i) * ctx[0] + e_i), ctx[1]]
func (cks *CKSProtocol) AggregateShares(share1, share2, shareOut CKSShare) {
cks.bfvContext.ContextQ().Add(share1.Poly, share2.Poly, shareOut.Poly)
cks.context.contextQ.Add(share1.Poly, share2.Poly, shareOut.Poly)
}
// KeySwitch performs the actual keyswitching operation on a ciphertext ct and put the result in ctOut
func (cks *CKSProtocol) KeySwitch(combined CKSShare, ct *bfv.Ciphertext, ctOut *bfv.Ciphertext) {
cks.bfvContext.ContextQ().Add(ct.Value()[0], combined.Poly, ctOut.Value()[0])
cks.bfvContext.ContextQ().Copy(ct.Value()[1], ctOut.Value()[1])
cks.context.contextQ.Add(ct.Value()[0], combined.Poly, ctOut.Value()[0])
cks.context.contextQ.Copy(ct.Value()[1], ctOut.Value()[1])
}

View File

@@ -105,7 +105,7 @@ func main() {
ckg := dbfv.NewCKGProtocol(bfvctx) // public key generation
rkg := dbfv.NewEkgProtocol(bfvctx) // relineariation key generation
rtg := dbfv.NewRotKGProtocol(bfvctx) // rotation keys generation
cks := dbfv.NewCKSProtocol(bfvctx, 3.19) // collective public-key re-encryption
cks := dbfv.NewCKSProtocol(params, 3.19) // collective public-key re-encryption
// Creates each party, and allocates the memory for all the shares that the protocols will need
P := make([]*party, N, N)