mirror of
https://github.com/tuneinsight/lattigo.git
synced 2025-09-13 03:27:14 +00:00
Add cksProtocolContext for dbfv
This commit is contained in:
@@ -215,7 +215,6 @@ func benchKeyswitching(b *testing.B) {
|
||||
|
||||
params := genDBFVContext(¶meters)
|
||||
|
||||
bfvContext := params.bfvContext
|
||||
sk0Shards := params.sk0Shards
|
||||
sk1Shards := params.sk1Shards
|
||||
|
||||
@@ -229,7 +228,7 @@ func benchKeyswitching(b *testing.B) {
|
||||
}
|
||||
|
||||
p := new(Party)
|
||||
p.CKSProtocol = NewCKSProtocol(bfvContext, 6.36)
|
||||
p.CKSProtocol = NewCKSProtocol(¶meters, 6.36)
|
||||
p.s0 = sk0Shards[0].Get()
|
||||
p.s1 = sk1Shards[0].Get()
|
||||
p.share = p.AllocateShare()
|
||||
|
||||
@@ -361,7 +361,7 @@ func testKeyswitching(t *testing.T) {
|
||||
cksParties := make([]*Party, parties)
|
||||
for i := uint64(0); i < parties; i++ {
|
||||
p := new(Party)
|
||||
p.CKSProtocol = NewCKSProtocol(bfvContext, 6.36)
|
||||
p.CKSProtocol = NewCKSProtocol(¶meters, 6.36)
|
||||
p.s0 = sk0Shards[i].Get()
|
||||
p.s1 = sk1Shards[i].Get()
|
||||
p.share = p.AllocateShare()
|
||||
@@ -809,7 +809,7 @@ func Test_Marshalling(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("CKS/N=%d/limbQ=%d/limbsP=%d", contextQ.N, len(contextQ.Modulus), len(contextPKeys.Modulus)), func(t *testing.T) {
|
||||
|
||||
//Now for CKSShare ~ its similar to PKSShare
|
||||
cksp := NewCKSProtocol(bfvCtx, bfvCtx.Sigma())
|
||||
cksp := NewCKSProtocol(params, bfvCtx.Sigma())
|
||||
cksshare := cksp.AllocateShare()
|
||||
skIn := KeyGenerator.NewSecretKey()
|
||||
skOut := KeyGenerator.NewSecretKey()
|
||||
|
||||
@@ -1,13 +1,77 @@
|
||||
package dbfv
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
|
||||
"github.com/ldsec/lattigo/bfv"
|
||||
"github.com/ldsec/lattigo/ring"
|
||||
)
|
||||
|
||||
type cksProtocolContext struct {
|
||||
// Polynomial contexts
|
||||
contextQ *ring.Context
|
||||
|
||||
contextKeys *ring.Context
|
||||
contextPKeys *ring.Context
|
||||
specialPrimes []uint64
|
||||
rescaleParamsKeys []uint64 // (P^-1) mod each qi
|
||||
}
|
||||
|
||||
func newCksProtocolContext(params *bfv.Parameters) *cksProtocolContext {
|
||||
n := params.N
|
||||
|
||||
contextQ := ring.NewContext()
|
||||
contextQ.SetParameters(n, params.Qi)
|
||||
err := contextQ.GenNTTParams()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
contextKeys := ring.NewContext()
|
||||
contextKeys.SetParameters(n, append(params.Qi, params.KeySwitchPrimes...))
|
||||
err = contextKeys.GenNTTParams()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
contextPKeys := ring.NewContext()
|
||||
contextPKeys.SetParameters(n, params.KeySwitchPrimes)
|
||||
err = contextPKeys.GenNTTParams()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
specialPrimes := make([]uint64, len(params.KeySwitchPrimes))
|
||||
for i := range params.KeySwitchPrimes {
|
||||
specialPrimes[i] = params.KeySwitchPrimes[i]
|
||||
}
|
||||
|
||||
rescaleParamsKeys := make([]uint64, len(params.Qi))
|
||||
|
||||
PBig := ring.NewUint(1)
|
||||
for _, pj := range specialPrimes {
|
||||
PBig.Mul(PBig, ring.NewUint(pj))
|
||||
}
|
||||
|
||||
tmp := new(big.Int)
|
||||
bredParams := contextQ.GetBredParams()
|
||||
for i, Qi := range params.Qi {
|
||||
tmp.Mod(PBig, ring.NewUint(Qi))
|
||||
rescaleParamsKeys[i] = ring.MForm(ring.ModExp(ring.BRedAdd(tmp.Uint64(), Qi, bredParams[i]), Qi-2, Qi), Qi, bredParams[i])
|
||||
}
|
||||
|
||||
return &cksProtocolContext{
|
||||
contextQ: contextQ,
|
||||
contextKeys: contextKeys,
|
||||
contextPKeys: contextPKeys,
|
||||
specialPrimes: specialPrimes,
|
||||
rescaleParamsKeys: rescaleParamsKeys,
|
||||
}
|
||||
}
|
||||
|
||||
// CKSProtocol is a structure storing the parameters for the collective key-switching protocol.
|
||||
type CKSProtocol struct {
|
||||
bfvContext *bfv.Context
|
||||
context *cksProtocolContext
|
||||
|
||||
sigmaSmudging float64
|
||||
gaussianSamplerSmudge *ring.KYSampler
|
||||
@@ -35,19 +99,20 @@ func (share *CKSShare) UnmarshalBinary(data []byte) error {
|
||||
// NewCKSProtocol creates a new CKSProtocol that will be used to operate a collective key-switching on a ciphertext encrypted under a collective public-key, whose
|
||||
// secret-shares are distributed among j parties, re-encrypting the ciphertext under another public-key, whose secret-shares are also known to the
|
||||
// parties.
|
||||
func NewCKSProtocol(bfvContext *bfv.Context, sigmaSmudging float64) *CKSProtocol {
|
||||
func NewCKSProtocol(params *bfv.Parameters, sigmaSmudging float64) *CKSProtocol {
|
||||
context := newCksProtocolContext(params)
|
||||
|
||||
cks := new(CKSProtocol)
|
||||
|
||||
cks.bfvContext = bfvContext
|
||||
cks.context = context
|
||||
|
||||
cks.gaussianSamplerSmudge = bfvContext.ContextKeys().NewKYSampler(sigmaSmudging, int(6*sigmaSmudging))
|
||||
cks.gaussianSamplerSmudge = context.contextKeys.NewKYSampler(sigmaSmudging, int(6*sigmaSmudging))
|
||||
|
||||
cks.tmpNtt = cks.bfvContext.ContextKeys().NewPoly()
|
||||
cks.tmpDelta = cks.bfvContext.ContextQ().NewPoly()
|
||||
cks.hP = cks.bfvContext.ContextPKeys().NewPoly()
|
||||
cks.tmpNtt = cks.context.contextKeys.NewPoly()
|
||||
cks.tmpDelta = cks.context.contextQ.NewPoly()
|
||||
cks.hP = cks.context.contextPKeys.NewPoly()
|
||||
|
||||
cks.baseconverter = ring.NewFastBasisExtender(cks.bfvContext.ContextQ().Modulus, cks.bfvContext.KeySwitchPrimes())
|
||||
cks.baseconverter = ring.NewFastBasisExtender(cks.context.contextQ.Modulus, cks.context.specialPrimes)
|
||||
|
||||
return cks
|
||||
}
|
||||
@@ -55,8 +120,7 @@ func NewCKSProtocol(bfvContext *bfv.Context, sigmaSmudging float64) *CKSProtocol
|
||||
// AllocateShare allocates the shares of the CKSProtocol
|
||||
func (cks *CKSProtocol) AllocateShare() CKSShare {
|
||||
|
||||
//return cks.bfvContext.ContextQ().NewPoly()
|
||||
return CKSShare{cks.bfvContext.ContextQ().NewPoly()}
|
||||
return CKSShare{cks.context.contextQ.NewPoly()}
|
||||
|
||||
}
|
||||
|
||||
@@ -68,7 +132,7 @@ func (cks *CKSProtocol) AllocateShare() CKSShare {
|
||||
// Each party then broadcast the result of this computation to the other j-1 parties.
|
||||
func (cks *CKSProtocol) GenShare(skInput, skOutput *ring.Poly, ct *bfv.Ciphertext, shareOut CKSShare) {
|
||||
|
||||
cks.bfvContext.ContextQ().Sub(skInput, skOutput, cks.tmpDelta)
|
||||
cks.context.contextQ.Sub(skInput, skOutput, cks.tmpDelta)
|
||||
|
||||
cks.genShareDelta(cks.tmpDelta, ct, shareOut)
|
||||
}
|
||||
@@ -77,8 +141,8 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
|
||||
|
||||
level := uint64(len(ct.Value()[1].Coeffs) - 1)
|
||||
|
||||
contextQ := cks.bfvContext.ContextQ()
|
||||
contextP := cks.bfvContext.ContextPKeys()
|
||||
contextQ := cks.context.contextQ
|
||||
contextP := cks.context.contextPKeys
|
||||
|
||||
contextQ.NTT(ct.Value()[1], cks.tmpNtt)
|
||||
contextQ.MulCoeffsMontgomery(cks.tmpNtt, skDelta, shareOut.Poly)
|
||||
@@ -89,7 +153,7 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
|
||||
cks.gaussianSamplerSmudge.Sample(cks.tmpNtt)
|
||||
contextQ.Add(shareOut.Poly, cks.tmpNtt, shareOut.Poly)
|
||||
|
||||
for x, i := 0, uint64(len(contextQ.Modulus)); i < uint64(len(cks.bfvContext.ContextKeys().Modulus)); x, i = x+1, i+1 {
|
||||
for x, i := 0, uint64(len(contextQ.Modulus)); i < uint64(len(cks.context.contextKeys.Modulus)); x, i = x+1, i+1 {
|
||||
tmphP := cks.hP.Coeffs[x]
|
||||
tmpNTT := cks.tmpNtt.Coeffs[i]
|
||||
for j := uint64(0); j < contextQ.N; j++ {
|
||||
@@ -97,7 +161,7 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
|
||||
}
|
||||
}
|
||||
|
||||
cks.baseconverter.ModDownSplited(contextQ, contextP, cks.bfvContext.RescaleParamsKeys(), level, shareOut.Poly, cks.hP, shareOut.Poly, cks.tmpNtt)
|
||||
cks.baseconverter.ModDownSplited(contextQ, contextP, cks.context.rescaleParamsKeys, level, shareOut.Poly, cks.hP, shareOut.Poly, cks.tmpNtt)
|
||||
|
||||
cks.tmpNtt.Zero()
|
||||
cks.hP.Zero()
|
||||
@@ -107,11 +171,11 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
|
||||
//
|
||||
// [ctx[0] + sum((skInput_i - skOutput_i) * ctx[0] + e_i), ctx[1]]
|
||||
func (cks *CKSProtocol) AggregateShares(share1, share2, shareOut CKSShare) {
|
||||
cks.bfvContext.ContextQ().Add(share1.Poly, share2.Poly, shareOut.Poly)
|
||||
cks.context.contextQ.Add(share1.Poly, share2.Poly, shareOut.Poly)
|
||||
}
|
||||
|
||||
// KeySwitch performs the actual keyswitching operation on a ciphertext ct and put the result in ctOut
|
||||
func (cks *CKSProtocol) KeySwitch(combined CKSShare, ct *bfv.Ciphertext, ctOut *bfv.Ciphertext) {
|
||||
cks.bfvContext.ContextQ().Add(ct.Value()[0], combined.Poly, ctOut.Value()[0])
|
||||
cks.bfvContext.ContextQ().Copy(ct.Value()[1], ctOut.Value()[1])
|
||||
cks.context.contextQ.Add(ct.Value()[0], combined.Poly, ctOut.Value()[0])
|
||||
cks.context.contextQ.Copy(ct.Value()[1], ctOut.Value()[1])
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func main() {
|
||||
ckg := dbfv.NewCKGProtocol(bfvctx) // public key generation
|
||||
rkg := dbfv.NewEkgProtocol(bfvctx) // relineariation key generation
|
||||
rtg := dbfv.NewRotKGProtocol(bfvctx) // rotation keys generation
|
||||
cks := dbfv.NewCKSProtocol(bfvctx, 3.19) // collective public-key re-encryption
|
||||
cks := dbfv.NewCKSProtocol(params, 3.19) // collective public-key re-encryption
|
||||
|
||||
// Creates each party, and allocates the memory for all the shares that the protocols will need
|
||||
P := make([]*party, N, N)
|
||||
|
||||
Reference in New Issue
Block a user