diff --git a/dbfv/dbfv_benchmark_test.go b/dbfv/dbfv_benchmark_test.go index ec2cda8a..cf5d3f9a 100644 --- a/dbfv/dbfv_benchmark_test.go +++ b/dbfv/dbfv_benchmark_test.go @@ -88,7 +88,7 @@ func benchRelinKeyGen(b *testing.B) { } p := new(Party) - p.RKGProtocol = NewEkgProtocol(bfvContext) + p.RKGProtocol = NewEkgProtocol(¶meters) p.u = p.RKGProtocol.NewEphemeralKey(1.0 / 3.0) p.s = sk0Shards[0].Get() p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares() diff --git a/dbfv/dbfv_test.go b/dbfv/dbfv_test.go index a69197e9..504f0a36 100644 --- a/dbfv/dbfv_test.go +++ b/dbfv/dbfv_test.go @@ -195,7 +195,7 @@ func testRelinKeyGen(t *testing.T) { for i := range rkgParties { p := new(Party) - p.RKGProtocol = NewEkgProtocol(bfvContext) + p.RKGProtocol = NewEkgProtocol(¶meters) p.u = p.RKGProtocol.NewEphemeralKey(1.0 / 3.0) p.s = sk0Shards[i].Get() p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares() @@ -946,7 +946,7 @@ func Test_Relin_Marshalling(t *testing.T) { t.Run(fmt.Sprintf("RLKG/N=%d/limbQ=%d/limbsP=%d", contextQ.N, len(contextQ.Modulus), len(contextPKeys.Modulus)), func(t *testing.T) { - rlk := NewEkgProtocol(bfvCtx) + rlk := NewEkgProtocol(params) u := rlk.NewEphemeralKey(1 / 3.0) sk := bfv.NewKeyGenerator(params).NewSecretKey() log.Print("Starting to test marshalling for share one") diff --git a/dbfv/relinkey_gen.go b/dbfv/relinkey_gen.go index f52abab8..2fda2e56 100644 --- a/dbfv/relinkey_gen.go +++ b/dbfv/relinkey_gen.go @@ -7,11 +7,56 @@ import ( "math" ) +type pkgProtocolContext struct { + // Polynomial degree + n uint64 + + // Ternary and Gaussian samplers + gaussianSampler *ring.KYSampler + + contextKeys *ring.Context + contextPKeys *ring.Context + specialPrimes []uint64 +} + +func newPkgProtocolContext(params *bfv.Parameters) *pkgProtocolContext { + n := params.N + + contextKeys := ring.NewContext() + contextKeys.SetParameters(n, append(params.Qi, params.KeySwitchPrimes...)) + err := contextKeys.GenNTTParams() + if err != nil { + panic(err) + } + + contextPKeys := ring.NewContext() + contextPKeys.SetParameters(n, params.KeySwitchPrimes) + err = contextPKeys.GenNTTParams() + if err != nil { + panic(err) + } + + specialPrimes := make([]uint64, len(params.KeySwitchPrimes)) + for i := range params.KeySwitchPrimes { + specialPrimes[i] = params.KeySwitchPrimes[i] + } + + gaussianSampler := contextKeys.NewKYSampler(params.Sigma, int(6*params.Sigma)) + + return &pkgProtocolContext{ + n: n, + gaussianSampler: gaussianSampler, + contextKeys: contextKeys, + contextPKeys: contextPKeys, + specialPrimes: specialPrimes, + } +} + // RKGProtocol is the structure storing the parameters and state for a party in the collective relinearization key // generation protocol. type RKGProtocol struct { ringContext *ring.Context - bfvContext *bfv.Context + context *pkgProtocolContext keyswitchprimes []uint64 alpha uint64 beta uint64 @@ -190,21 +235,22 @@ func (ekg *RKGProtocol) AllocateShares() (r1 RKGShareRoundOne, r2 RKGShareRoundT // NewEkgProtocol creates a new RKGProtocol object that will be used to generate a collective evaluation-key // among j parties in the given context with the given bit-decomposition. -func NewEkgProtocol(context *bfv.Context) *RKGProtocol { +func NewEkgProtocol(params *bfv.Parameters) *RKGProtocol { + context := newPkgProtocolContext(params) ekg := new(RKGProtocol) - ekg.ringContext = context.ContextKeys() - ekg.bfvContext = context + ekg.ringContext = context.contextKeys + ekg.context = context - ekg.keyswitchprimes = make([]uint64, len(context.KeySwitchPrimes())) - for i, pi := range context.KeySwitchPrimes() { + ekg.keyswitchprimes = make([]uint64, len(context.specialPrimes)) + for i, pi := range context.specialPrimes { ekg.keyswitchprimes[i] = pi } ekg.alpha = uint64(len(ekg.keyswitchprimes)) ekg.beta = uint64(math.Ceil(float64(len(ekg.ringContext.Modulus)-len(ekg.keyswitchprimes)) / float64(ekg.alpha))) - ekg.gaussianSampler = context.GaussianSampler() + ekg.gaussianSampler = context.gaussianSampler ekg.tmpPoly1 = ekg.ringContext.NewPoly() ekg.tmpPoly2 = ekg.ringContext.NewPoly() @@ -233,7 +279,7 @@ func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, sha ekg.polypool.Copy(sk) - ekg.ringContext.MulScalarBigint(ekg.polypool, ekg.bfvContext.ContextPKeys().ModulusBigint, ekg.polypool) + ekg.ringContext.MulScalarBigint(ekg.polypool, ekg.context.contextPKeys.ModulusBigint, ekg.polypool) ekg.ringContext.InvMForm(ekg.polypool, ekg.polypool) diff --git a/examples/dbfv/pir/pir.go b/examples/dbfv/pir/pir.go index 4d7d452c..ad9a250c 100644 --- a/examples/dbfv/pir/pir.go +++ b/examples/dbfv/pir/pir.go @@ -103,7 +103,7 @@ func main() { // Instantiation of each of the protocols needed for the pir example ckg := dbfv.NewCKGProtocol(bfvctx) // public key generation - rkg := dbfv.NewEkgProtocol(bfvctx) // relineariation key generation + rkg := dbfv.NewEkgProtocol(params) // relineariation key generation rtg := dbfv.NewRotKGProtocol(params) // rotation keys generation cks := dbfv.NewCKSProtocol(params, 3.19) // collective public-key re-encryption diff --git a/examples/dbfv/psi/psi.go b/examples/dbfv/psi/psi.go index 1353921c..96b37c6c 100644 --- a/examples/dbfv/psi/psi.go +++ b/examples/dbfv/psi/psi.go @@ -82,7 +82,7 @@ func main() { } ckg := dbfv.NewCKGProtocol(bfvctx) - rkg := dbfv.NewEkgProtocol(bfvctx) + rkg := dbfv.NewEkgProtocol(params) pcks := dbfv.NewPCKSProtocol(bfvctx, 3.19) P := make([]*party, N, N)