From 3871936fdac64b35f5e7a7a585dc7c209b6e82c5 Mon Sep 17 00:00:00 2001 From: Christian M Date: Thu, 3 Sep 2020 13:54:53 +0200 Subject: [PATCH] [dbfv/dckks] two-rounds relinearization key generation --- Makefile | 3 + dbfv/dbfv_benchmark_test.go | 23 +-- dbfv/dbfv_test.go | 47 ++----- dbfv/relinkey_gen.go | 193 ++++++------------------- dckks/dckks_benchmark_test.go | 23 +-- dckks/dckks_test.go | 19 +-- dckks/relinkey_gen.go | 256 +++++++++++++++++++--------------- examples/dbfv/pir/pir.go | 43 +++--- examples/dbfv/psi/psi.go | 41 ++---- 9 files changed, 241 insertions(+), 407 deletions(-) diff --git a/Makefile b/Makefile index a7eae1f4..cbd6df0b 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,9 @@ test_local: go run ./examples/bfv/examples_bfv.go > /dev/null go run ./examples/ckks/euler/euler.go > /dev/null go run ./examples/ckks/sigmoid/sigmoid.go > /dev/null + go run ./examples/dbfv/pir/pir.go &> /dev/null + go run ./examples/dbfv/psi/psi.go &> /dev/null + @echo ok test: test_fmt test_local diff --git a/dbfv/dbfv_benchmark_test.go b/dbfv/dbfv_benchmark_test.go index bb53d746..ad71b488 100644 --- a/dbfv/dbfv_benchmark_test.go +++ b/dbfv/dbfv_benchmark_test.go @@ -84,9 +84,8 @@ func benchRelinKeyGen(b *testing.B) { *RKGProtocol u *ring.Poly s *ring.Poly - share1 RKGShareRoundOne - share2 RKGShareRoundTwo - share3 RKGShareRoundThree + share1 RKGShare + share2 RKGShare rlk *bfv.EvaluationKey } @@ -95,7 +94,7 @@ func benchRelinKeyGen(b *testing.B) { p.RKGProtocol = NewEkgProtocol(parameters) p.u = p.RKGProtocol.NewEphemeralKey() p.s = sk0Shards[0].Get() - p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares() + p.share1, p.share2 = p.RKGProtocol.AllocateShares() p.rlk = bfv.NewRelinKey(parameters, 2) prng, err := utils.NewKeyedPRNG(nil) if err != nil { @@ -124,7 +123,7 @@ func benchRelinKeyGen(b *testing.B) { b.Run(testString("Round2/Gen", parties, parameters), func(b *testing.B) { for i := 0; i < b.N; i++ { - p.GenShareRoundTwo(p.share1, p.s, crp, p.share2) + p.GenShareRoundTwo(p.share1, p.u, p.s, crp, p.share2) } }) @@ -134,21 +133,9 @@ func benchRelinKeyGen(b *testing.B) { } }) - b.Run(testString("Round3/Gen", parties, parameters), func(b *testing.B) { - for i := 0; i < b.N; i++ { - p.GenShareRoundThree(p.share2, p.u, p.s, p.share3) - } - }) - - b.Run(testString("Round3/Agg", parties, parameters), func(b *testing.B) { - for i := 0; i < b.N; i++ { - p.AggregateShareRoundThree(p.share3, p.share3, p.share3) - } - }) - b.Run(testString("Finalize", parties, parameters), func(b *testing.B) { for i := 0; i < b.N; i++ { - p.GenRelinearizationKey(p.share2, p.share3, p.rlk) + p.GenRelinearizationKey(p.share1, p.share2, p.rlk) } }) } diff --git a/dbfv/dbfv_test.go b/dbfv/dbfv_test.go index bf4f193b..76721f3f 100644 --- a/dbfv/dbfv_test.go +++ b/dbfv/dbfv_test.go @@ -190,9 +190,8 @@ func testRelinKeyGen(t *testing.T) { *RKGProtocol u *ring.Poly s *ring.Poly - share1 RKGShareRoundOne - share2 RKGShareRoundTwo - share3 RKGShareRoundThree + share1 RKGShare + share2 RKGShare } rkgParties := make([]*Party, parties) @@ -202,7 +201,7 @@ func testRelinKeyGen(t *testing.T) { p.RKGProtocol = NewEkgProtocol(parameters) p.u = p.RKGProtocol.NewEphemeralKey() p.s = sk0Shards[i].Get() - p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares() + p.share1, p.share2 = p.RKGProtocol.AllocateShares() rkgParties[i] = p } @@ -229,22 +228,14 @@ func testRelinKeyGen(t *testing.T) { //ROUND 2 for i, p := range rkgParties { - p.GenShareRoundTwo(P0.share1, p.s, crp, p.share2) + p.GenShareRoundTwo(P0.share1, p.u, p.s, crp, p.share2) if i > 0 { P0.AggregateShareRoundTwo(p.share2, P0.share2, P0.share2) } } - // ROUND 3 - for i, p := range rkgParties { - p.GenShareRoundThree(P0.share2, p.u, p.s, p.share3) - if i > 0 { - P0.AggregateShareRoundThree(p.share3, P0.share3, P0.share3) - } - } - evk := bfv.NewRelinKey(parameters, 1) - P0.GenRelinearizationKey(P0.share2, P0.share3, evk) + P0.GenRelinearizationKey(P0.share1, P0.share2, evk) coeffs, _, ciphertext := newTestVectors(testCtx, encryptorPk0, t) @@ -918,28 +909,28 @@ func Test_Relin_Marshalling(t *testing.T) { u := rlk.NewEphemeralKey() sk := bfv.NewKeyGenerator(params).GenSecretKey() - r1, r2, r3 := rlk.AllocateShares() + r1, r2 := rlk.AllocateShares() rlk.GenShareRoundOne(u, sk.Get(), crp, r1) data, err := r1.MarshalBinary() require.NoError(t, err) - r1After := new(RKGShareRoundOne) + r1After := new(RKGShare) err = r1After.UnmarshalBinary(data) require.NoError(t, err) for i := 0; i < (len(r1)); i++ { - a := r1[i] - b := (*r1After)[i] + a := r1[i][0] + b := (*r1After)[i][0] moduli := a.GetLenModuli() require.Equal(t, a.Coeffs[:moduli], b.Coeffs[:moduli]) } - rlk.GenShareRoundTwo(r1, sk.Get(), crp, r2) + rlk.GenShareRoundTwo(r1, u, sk.Get(), crp, r2) data, err = r2.MarshalBinary() require.NoError(t, err) - r2After := new(RKGShareRoundTwo) + r2After := new(RKGShare) err = r2After.UnmarshalBinary(data) require.NoError(t, err) @@ -952,22 +943,6 @@ func Test_Relin_Marshalling(t *testing.T) { } } - - rlk.GenShareRoundThree(r2, u, sk.Get(), r3) - - data, err = r3.MarshalBinary() - require.NoError(t, err) - - r3After := new(RKGShareRoundThree) - err = r3After.UnmarshalBinary(data) - require.NoError(t, err) - - for i := 0; i < (len(r3)); i++ { - a := r3[i] - b := (*r3After)[i] - moduli := a.GetLenModuli() - require.Equal(t, a.Coeffs[:moduli], b.Coeffs[:moduli]) - } }) } diff --git a/dbfv/relinkey_gen.go b/dbfv/relinkey_gen.go index 1b1a9a77..e17b483c 100644 --- a/dbfv/relinkey_gen.go +++ b/dbfv/relinkey_gen.go @@ -2,6 +2,7 @@ package dbfv import ( "errors" + "github.com/ldsec/lattigo/bfv" "github.com/ldsec/lattigo/ring" "github.com/ldsec/lattigo/utils" @@ -18,63 +19,15 @@ type RKGProtocol struct { ternarySamplerMontgomery *ring.TernarySampler } -// RKGShareRoundOne is a struct storing the round one RKG shares. -type RKGShareRoundOne []*ring.Poly - -// RKGShareRoundTwo is a struct storing the round two RKG shares. -type RKGShareRoundTwo [][2]*ring.Poly - -// RKGShareRoundThree is a struct storing the round three RKG shares. -type RKGShareRoundThree []*ring.Poly +type RKGShare [][2]*ring.Poly // MarshalBinary encodes the target element on a slice of bytes. -func (share *RKGShareRoundOne) MarshalBinary() ([]byte, error) { - rLength := (*share)[0].GetDataLen(true) - data := make([]byte, 1+rLength*uint64(len(*share))) - data[0] = uint8(len(*share)) - - pointer := uint64(1) - for _, s := range *share { - tmp, err := s.WriteTo(data[pointer : pointer+rLength]) - if err != nil { - return []byte{}, err - } - pointer += tmp - } - - return data, nil -} - -// UnmarshalBinary decodes a slice of bytes on the target element. -func (share *RKGShareRoundOne) UnmarshalBinary(data []byte) error { - //share.modulus = data[0] - lenShare := data[0] - rLength := len(data[1:]) / int(lenShare) - if *share == nil { - *share = make([]*ring.Poly, lenShare) - } - ptr := 1 - for i := uint8(0); i < lenShare; i++ { - if (*share)[i] == nil { - (*share)[i] = new(ring.Poly) - } - err := (*share)[i].UnmarshalBinary(data[ptr : ptr+rLength]) - if err != nil { - return err - } - ptr += rLength - } - - return nil -} - -// MarshalBinary encodes the target element on a slice of bytes. -func (share *RKGShareRoundTwo) MarshalBinary() ([]byte, error) { +func (share *RKGShare) MarshalBinary() ([]byte, error) { //we have modulus * bitLog * Len of 1 ring rings rLength := ((*share)[0])[0].GetDataLen(true) data := make([]byte, 1+2*rLength*uint64(len(*share))) if len(*share) > 0xFF { - return []byte{}, errors.New("RKGShareRoundTwo : uint8 overflow on length") + return []byte{}, errors.New("RKGShare : uint8 overflow on length") } data[0] = uint8(len(*share)) @@ -99,7 +52,7 @@ func (share *RKGShareRoundTwo) MarshalBinary() ([]byte, error) { } // UnmarshalBinary decodes a slice of bytes on the target element. -func (share *RKGShareRoundTwo) UnmarshalBinary(data []byte) error { +func (share *RKGShare) UnmarshalBinary(data []byte) error { lenShare := data[0] rLength := (len(data) - 1) / (2 * int(lenShare)) @@ -129,59 +82,16 @@ func (share *RKGShareRoundTwo) UnmarshalBinary(data []byte) error { return nil } -// MarshalBinary encodes the target element on a slice of bytes. -func (share *RKGShareRoundThree) MarshalBinary() ([]byte, error) { - rLength := (*share)[0].GetDataLen(true) - data := make([]byte, 1+rLength*uint64(len(*share))) - data[0] = uint8(len(*share)) - - pointer := uint64(1) - for _, s := range *share { - tmp, err := s.WriteTo(data[pointer : pointer+rLength]) - if err != nil { - return []byte{}, err - } - pointer += tmp - } - - return data, nil -} - -// UnmarshalBinary decodes a slice of bytes on the target element. -func (share *RKGShareRoundThree) UnmarshalBinary(data []byte) error { - //share.modulus = data[0] - lenShare := data[0] - rLength := len(data[1:]) / int(lenShare) - if *share == nil { - *share = make([]*ring.Poly, lenShare) - } - ptr := 1 - for i := uint8(0); i < lenShare; i++ { - if (*share)[i] == nil { - (*share)[i] = new(ring.Poly) - } - err := (*share)[i].UnmarshalBinary(data[ptr : ptr+rLength]) - if err != nil { - return err - } - ptr += rLength - } - - return nil -} - // AllocateShares allocates the shares of the EKG protocol. -func (ekg *RKGProtocol) AllocateShares() (r1 RKGShareRoundOne, r2 RKGShareRoundTwo, r3 RKGShareRoundThree) { - r1 = make([]*ring.Poly, ekg.context.params.Beta) +func (ekg *RKGProtocol) AllocateShares() (r1 RKGShare, r2 RKGShare) { + r1 = make([][2]*ring.Poly, ekg.context.params.Beta) r2 = make([][2]*ring.Poly, ekg.context.params.Beta) - r3 = make([]*ring.Poly, ekg.context.params.Beta) for i := uint64(0); i < ekg.context.params.Beta; i++ { - r1[i] = ekg.context.contextQP.NewPoly() + r1[i][0] = ekg.context.contextQP.NewPoly() + r1[i][1] = ekg.context.contextQP.NewPoly() r2[i][0] = ekg.context.contextQP.NewPoly() r2[i][1] = ekg.context.contextQP.NewPoly() - r3[i] = ekg.context.contextQP.NewPoly() } - return } @@ -223,10 +133,9 @@ func (ekg *RKGProtocol) NewEphemeralKey() (ephemeralKey *ring.Poly) { // GenShareRoundOne is the first of three rounds of the RKGProtocol protocol. Each party generates a pseudo encryption of // its secret share of the key s_i under its ephemeral key u_i : [-u_i*a + s_i*w + e_i] and broadcasts it to the other // j-1 parties. -func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShareRoundOne) { +func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShare) { var index uint64 - // Given a base decomposition w_i (here the CRT decomposition) // computes [-u*a_i + P*s_i + e_i] // where a_i = crp_i @@ -240,10 +149,9 @@ func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, sha ringContext.InvMForm(ekg.polypool, ekg.polypool) for i := uint64(0); i < ekg.context.params.Beta; i++ { - // h = e - ekg.gaussianSampler.Read(shareOut[i]) - ringContext.NTT(shareOut[i], shareOut[i]) + ekg.gaussianSampler.Read(shareOut[i][0]) + ringContext.NTT(shareOut[i][0], shareOut[i][0]) // h = sk*CrtBaseDecompQi + e for j := uint64(0); j < ekg.context.params.Alpha; j++ { @@ -251,7 +159,7 @@ func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, sha index = i*ekg.context.params.Alpha + j qi := ringContext.Modulus[index] tmp0 := ekg.polypool.Coeffs[index] - tmp1 := shareOut[i].Coeffs[index] + tmp1 := shareOut[i][0].Coeffs[index] for w := uint64(0); w < ekg.context.contextQP.N; w++ { tmp1[w] = ring.CRed(tmp1[w]+tmp0[w], qi) @@ -262,21 +170,28 @@ func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, sha break } } - // h = sk*CrtBaseDecompQi + -u*a + e - ringContext.MulCoeffsMontgomeryAndSub(u, crp[i], shareOut[i]) + ekg.context.contextQP.MulCoeffsMontgomeryAndSub(u, crp[i], shareOut[i][0]) + + // Second Element + // e_2i + ekg.gaussianSampler.Read(shareOut[i][1]) + ringContext.NTT(shareOut[i][1], shareOut[i][1]) + // s*a + e_2i + ringContext.MulCoeffsMontgomeryAndAdd(sk, crp[i], shareOut[i][1]) } - ekg.polypool.Zero() + ekg.polypool.Zero() // TODO: check if we can remove this one return } // AggregateShareRoundOne adds share1 and share2 on shareOut. -func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShareRoundOne) { +func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShare) { for i := uint64(0); i < ekg.context.params.Beta; i++ { - ekg.context.contextQP.Add(share1[i], share2[i], shareOut[i]) + ekg.context.contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0]) + ekg.context.contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1]) } } @@ -288,10 +203,13 @@ func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShare // = [s_i * (-u*a + s*w + e) + e_i1, s_i*a + e_i2] // // and broadcasts both values to the other j-1 parties. -func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShareRoundTwo) { +func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShare, u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShare) { ringContext := ekg.context.contextQP + // (u_i - s_i) + ringContext.Sub(u, sk, ekg.tmpPoly1) + // Each sample is of the form [-u*a_i + s*w_i + e_i] // So for each element of the base decomposition w_i : for i := uint64(0); i < ekg.context.params.Beta; i++ { @@ -299,19 +217,18 @@ func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly, // Computes [(sum samples)*sk + e_1i, sk*a + e_2i] // (AggregateShareRoundTwo samples) * sk - ringContext.MulCoeffsMontgomery(round1[i], sk, shareOut[i][0]) + ringContext.MulCoeffsMontgomery(round1[i][0], sk, shareOut[i][0]) // (AggregateShareRoundTwo samples) * sk + e_1i - ekg.gaussianSampler.Read(ekg.tmpPoly1) - ringContext.NTT(ekg.tmpPoly1, ekg.tmpPoly1) - ringContext.Add(shareOut[i][0], ekg.tmpPoly1, shareOut[i][0]) + ekg.gaussianSampler.Read(ekg.tmpPoly2) + ringContext.NTT(ekg.tmpPoly2, ekg.tmpPoly2) + ringContext.Add(shareOut[i][0], ekg.tmpPoly2, shareOut[i][0]) - // Second Element - // e_2i + // second part + // (u - s) * (sum [x][s*a_i + e_2i]) + e3i ekg.gaussianSampler.Read(shareOut[i][1]) ringContext.NTT(shareOut[i][1], shareOut[i][1]) - // s*a + e_2i - ringContext.MulCoeffsMontgomeryAndAdd(sk, crp[i], shareOut[i][1]) + ringContext.MulCoeffsMontgomeryAndAdd(ekg.tmpPoly1, round1[i][1], shareOut[i][1]) } } @@ -322,54 +239,24 @@ func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly, // [sum(s_j * (-u*a + s*w + e) + e_j1), sum(s_j*a + e_j2)] // // = [s * (-u*a + s*w + e) + e_1, s*a + e_2]. -func (ekg *RKGProtocol) AggregateShareRoundTwo(share1, share2, shareOut RKGShareRoundTwo) { +func (ekg *RKGProtocol) AggregateShareRoundTwo(share1, share2, shareOut RKGShare) { for i := uint64(0); i < ekg.context.params.Beta; i++ { ekg.context.contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0]) ekg.context.contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1]) } - -} - -// GenShareRoundThree is the second pard of the third and last round of the RKGProtocol protocol. Each party operates a key-switch on [s*a + e_2], -// by computing : -// -// [(u_i - s_i)*(s*a + e_2)] -// -// and broadcasts the result to the other j-1 parties. -func (ekg *RKGProtocol) GenShareRoundThree(round2 RKGShareRoundTwo, u, sk *ring.Poly, shareOut RKGShareRoundThree) { - - ringContext := ekg.context.contextQP - - // (u_i - s_i) - ringContext.Sub(u, sk, ekg.tmpPoly1) - - for i := uint64(0); i < ekg.context.params.Beta; i++ { - - // (u - s) * (sum [x][s*a_i + e_2i]) + e3i - ekg.gaussianSampler.Read(shareOut[i]) - ringContext.NTT(shareOut[i], shareOut[i]) - ringContext.MulCoeffsMontgomeryAndAdd(ekg.tmpPoly1, round2[i][1], shareOut[i]) - } -} - -// AggregateShareRoundThree adds share1 and share2 on shareOut. -func (ekg *RKGProtocol) AggregateShareRoundThree(share1, share2, shareOut RKGShareRoundThree) { - for i := uint64(0); i < ekg.context.params.Beta; i++ { - ekg.context.contextQP.Add(share1[i], share2[i], shareOut[i]) - } } // GenRelinearizationKey finalizes the protocol and returns the common EvaluationKey. -func (ekg *RKGProtocol) GenRelinearizationKey(round2 RKGShareRoundTwo, round3 RKGShareRoundThree, evalKeyOut *bfv.EvaluationKey) { +func (ekg *RKGProtocol) GenRelinearizationKey(round1 RKGShare, round2 RKGShare, evalKeyOut *bfv.EvaluationKey) { ringContext := ekg.context.contextQP key := evalKeyOut.Get()[0].Get() for i := uint64(0); i < ekg.context.params.Beta; i++ { - ringContext.Add(round2[i][0], round3[i], key[i][0]) - key[i][1].Copy(round2[i][1]) + ringContext.Add(round2[i][0], round2[i][1], key[i][0]) + key[i][1].Copy(round1[i][1]) ringContext.MForm(key[i][0], key[i][0]) ringContext.MForm(key[i][1], key[i][1]) diff --git a/dckks/dckks_benchmark_test.go b/dckks/dckks_benchmark_test.go index a585be39..20a69f2f 100644 --- a/dckks/dckks_benchmark_test.go +++ b/dckks/dckks_benchmark_test.go @@ -76,16 +76,15 @@ func benchRelinKeyGen(b *testing.B) { *RKGProtocol u *ring.Poly s *ring.Poly - share1 RKGShareRoundOne - share2 RKGShareRoundTwo - share3 RKGShareRoundThree + share1 RKGShare + share2 RKGShare } p := new(Party) p.RKGProtocol = NewEkgProtocol(parameters) p.u = p.RKGProtocol.NewEphemeralKey() p.s = sk0Shards[0].Get() - p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares() + p.share1, p.share2 = p.RKGProtocol.AllocateShares() prng, err := utils.NewKeyedPRNG(nil) if err != nil { panic(err) @@ -114,7 +113,7 @@ func benchRelinKeyGen(b *testing.B) { b.Run(testString("Round2Gen/", parties, parameters), func(b *testing.B) { for i := 0; i < b.N; i++ { - p.GenShareRoundTwo(p.share1, p.s, crp, p.share2) + p.GenShareRoundTwo(p.share1, p.u, p.s, crp, p.share2) } }) @@ -124,20 +123,6 @@ func benchRelinKeyGen(b *testing.B) { p.AggregateShareRoundTwo(p.share2, p.share2, p.share2) } }) - - b.Run(testString("Round3Gen/", parties, parameters), func(b *testing.B) { - - for i := 0; i < b.N; i++ { - p.GenShareRoundThree(p.share2, p.u, p.s, p.share3) - } - }) - - b.Run(testString("Round3Agg/", parties, parameters), func(b *testing.B) { - - for i := 0; i < b.N; i++ { - p.AggregateShareRoundThree(p.share3, p.share3, p.share3) - } - }) } } diff --git a/dckks/dckks_test.go b/dckks/dckks_test.go index 5cbd7b3f..58fc7c21 100644 --- a/dckks/dckks_test.go +++ b/dckks/dckks_test.go @@ -209,9 +209,8 @@ func testRelinKeyGen(t *testing.T) { *RKGProtocol u *ring.Poly s *ring.Poly - share1 RKGShareRoundOne - share2 RKGShareRoundTwo - share3 RKGShareRoundThree + share1 RKGShare + share2 RKGShare } rkgParties := make([]*Party, parties) @@ -221,7 +220,7 @@ func testRelinKeyGen(t *testing.T) { p.RKGProtocol = NewEkgProtocol(parameters) p.u = p.NewEphemeralKey() p.s = sk0Shards[i].Get() - p.share1, p.share2, p.share3 = p.AllocateShares() + p.share1, p.share2 = p.AllocateShares() rkgParties[i] = p } @@ -248,22 +247,14 @@ func testRelinKeyGen(t *testing.T) { //ROUND 2 for i, p := range rkgParties { - p.GenShareRoundTwo(P0.share1, p.s, crp, p.share2) + p.GenShareRoundTwo(P0.share1, p.u, p.s, crp, p.share2) if i > 0 { P0.AggregateShareRoundTwo(p.share2, P0.share2, P0.share2) } } - // ROUND 3 - for i, p := range rkgParties { - p.GenShareRoundThree(P0.share2, p.u, p.s, p.share3) - if i > 0 { - P0.AggregateShareRoundThree(p.share3, P0.share3, P0.share3) - } - } - evk := ckks.NewRelinKey(parameters) - P0.GenRelinearizationKey(P0.share2, P0.share3, evk) + P0.GenRelinearizationKey(P0.share1, P0.share2, evk) coeffs, _, ciphertext := newTestVectors(params, encryptorPk0, 1, t) diff --git a/dckks/relinkey_gen.go b/dckks/relinkey_gen.go index 638f98fd..3e6dedb3 100644 --- a/dckks/relinkey_gen.go +++ b/dckks/relinkey_gen.go @@ -1,62 +1,123 @@ package dckks import ( + "errors" + "github.com/ldsec/lattigo/ckks" "github.com/ldsec/lattigo/ring" "github.com/ldsec/lattigo/utils" ) -// RKGProtocol is a structure storing the parameters for the collective evaluation-key generation. +// RKGProtocol is the structure storing the parameters and state for a party in the collective relinearization key +// generation protocol. type RKGProtocol struct { - dckksContext *dckksContext + context *dckksContext + tmpPoly1 *ring.Poly + tmpPoly2 *ring.Poly polypool *ring.Poly gaussianSampler *ring.GaussianSampler ternarySamplerMontgomery *ring.TernarySampler } -// RKGShareRoundOne is a struct storing the round one share of the RKG protocol. -type RKGShareRoundOne []*ring.Poly +type RKGShare [][2]*ring.Poly -// RKGShareRoundTwo is a struct storing the round two share of the RKG protocol. -type RKGShareRoundTwo [][2]*ring.Poly +// MarshalBinary encodes the target element on a slice of bytes. +func (share *RKGShare) MarshalBinary() ([]byte, error) { + //we have modulus * bitLog * Len of 1 ring rings + rLength := ((*share)[0])[0].GetDataLen(true) + data := make([]byte, 1+2*rLength*uint64(len(*share))) + if len(*share) > 0xFF { + return []byte{}, errors.New("RKGShare : uint8 overflow on length") + } + data[0] = uint8(len(*share)) -// RKGShareRoundThree is a struct storing the round three share of the RKG protocol. -type RKGShareRoundThree []*ring.Poly + //write all of our rings in the data. + //write all the polys + ptr := uint64(1) + for _, elem := range *share { + _, err := elem[0].WriteTo(data[ptr : ptr+rLength]) + if err != nil { + return []byte{}, err + } + ptr += rLength + _, err = elem[1].WriteTo(data[ptr : ptr+rLength]) + if err != nil { + return []byte{}, err + } + ptr += rLength + } -// AllocateShares allocates the shares of the RKG protocol. -func (ekg *RKGProtocol) AllocateShares() (r1 RKGShareRoundOne, r2 RKGShareRoundTwo, r3 RKGShareRoundThree) { + return data, nil - contextQP := ekg.dckksContext.contextQP +} - r1 = make([]*ring.Poly, ekg.dckksContext.beta) - r2 = make([][2]*ring.Poly, ekg.dckksContext.beta) - r3 = make([]*ring.Poly, ekg.dckksContext.beta) - for i := uint64(0); i < ekg.dckksContext.beta; i++ { - r1[i] = contextQP.NewPoly() - r2[i][0] = contextQP.NewPoly() - r2[i][1] = contextQP.NewPoly() - r3[i] = contextQP.NewPoly() +// UnmarshalBinary decodes a slice of bytes on the target element. +func (share *RKGShare) UnmarshalBinary(data []byte) error { + lenShare := data[0] + rLength := (len(data) - 1) / (2 * int(lenShare)) + + if *share == nil { + *share = make([][2]*ring.Poly, lenShare) + } + ptr := (1) + for i := (0); i < int(lenShare); i++ { + if (*share)[i][0] == nil || (*share)[i][1] == nil { + (*share)[i][0] = new(ring.Poly) + (*share)[i][1] = new(ring.Poly) + } + + err := (*share)[i][0].UnmarshalBinary(data[ptr : ptr+rLength]) + if err != nil { + return err + } + ptr += rLength + err = (*share)[i][1].UnmarshalBinary(data[ptr : ptr+rLength]) + if err != nil { + return err + } + ptr += rLength + + } + + return nil +} + +// AllocateShares allocates the shares of the EKG protocol. +func (ekg *RKGProtocol) AllocateShares() (r1 RKGShare, r2 RKGShare) { + r1 = make([][2]*ring.Poly, ekg.context.params.Beta) + r2 = make([][2]*ring.Poly, ekg.context.params.Beta) + for i := uint64(0); i < ekg.context.params.Beta; i++ { + r1[i][0] = ekg.context.contextQP.NewPoly() + r1[i][1] = ekg.context.contextQP.NewPoly() + r2[i][0] = ekg.context.contextQP.NewPoly() + r2[i][1] = ekg.context.contextQP.NewPoly() } return } -// NewEkgProtocol creates a new RKGProtocol object that will be used to generate a collective evaluation-key. +// NewEkgProtocol creates a new RKGProtocol object that will be used to generate a collective evaluation-key +// among j parties in the given context with the given bit-decomposition. func NewEkgProtocol(params *ckks.Parameters) *RKGProtocol { if !params.IsValid() { panic("cannot NewEkgProtocol : params not valid (check if they where generated properly)") } + context := newDckksContext(params) + ekg := new(RKGProtocol) - dckksContext := newDckksContext(params) - ekg.dckksContext = dckksContext - ekg.polypool = ekg.dckksContext.contextQP.NewPoly() + ekg.context = context + + ekg.tmpPoly1 = ekg.context.contextQP.NewPoly() + ekg.tmpPoly2 = ekg.context.contextQP.NewPoly() + ekg.polypool = ekg.context.contextQP.NewPoly() prng, err := utils.NewPRNG() if err != nil { panic(err) } - ekg.gaussianSampler = ring.NewGaussianSampler(prng, dckksContext.contextQP, params.Sigma, uint64(6*params.Sigma)) - ekg.ternarySamplerMontgomery = ring.NewTernarySampler(prng, dckksContext.contextQP, 0.5, true) + ekg.ternarySamplerMontgomery = ring.NewTernarySampler(prng, ekg.context.contextQP, 0.5, true) + ekg.gaussianSampler = ring.NewGaussianSampler(prng, ekg.context.contextQP, params.Sigma, uint64(6*params.Sigma)) + return ekg } @@ -65,70 +126,72 @@ func NewEkgProtocol(params *ckks.Parameters) *RKGProtocol { // of the collective secret-key. func (ekg *RKGProtocol) NewEphemeralKey() (ephemeralKey *ring.Poly) { ephemeralKey = ekg.ternarySamplerMontgomery.ReadNew() - ekg.dckksContext.contextQP.NTT(ephemeralKey, ephemeralKey) + ekg.context.contextQP.NTT(ephemeralKey, ephemeralKey) return ephemeralKey } // GenShareRoundOne is the first of three rounds of the RKGProtocol protocol. Each party generates a pseudo encryption of -// its secret share of the key s_i under its ephemeral key u_i : [-u_i*a + P*s_i + e_i] and broadcasts it to the other +// its secret share of the key s_i under its ephemeral key u_i : [-u_i*a + s_i*w + e_i] and broadcasts it to the other // j-1 parties. -func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShareRoundOne) { - - contextQP := ekg.dckksContext.contextQP +func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShare) { var index uint64 - // Given a base decomposition w_i (here the CRT decomposition) // computes [-u*a_i + P*s_i + e_i] // where a_i = crp_i + ringContext := ekg.context.contextQP + ekg.polypool.Copy(sk) - contextQP.MulScalarBigint(ekg.polypool, ekg.dckksContext.contextP.ModulusBigint, ekg.polypool) + ringContext.MulScalarBigint(ekg.polypool, ekg.context.contextP.ModulusBigint, ekg.polypool) - contextQP.InvMForm(ekg.polypool, ekg.polypool) - - for i := uint64(0); i < ekg.dckksContext.beta; i++ { + ringContext.InvMForm(ekg.polypool, ekg.polypool) + for i := uint64(0); i < ekg.context.params.Beta; i++ { // h = e - ekg.gaussianSampler.Read(shareOut[i]) - contextQP.NTT(shareOut[i], shareOut[i]) + ekg.gaussianSampler.Read(shareOut[i][0]) + ringContext.NTT(shareOut[i][0], shareOut[i][0]) // h = sk*CrtBaseDecompQi + e - for j := uint64(0); j < ekg.dckksContext.alpha; j++ { + for j := uint64(0); j < ekg.context.params.Alpha; j++ { - index = i*ekg.dckksContext.alpha + j - - qi := contextQP.Modulus[index] + index = i*ekg.context.params.Alpha + j + qi := ringContext.Modulus[index] tmp0 := ekg.polypool.Coeffs[index] - tmp1 := shareOut[i].Coeffs[index] + tmp1 := shareOut[i][0].Coeffs[index] - for w := uint64(0); w < contextQP.N; w++ { + for w := uint64(0); w < ekg.context.contextQP.N; w++ { tmp1[w] = ring.CRed(tmp1[w]+tmp0[w], qi) } // Handles the case where nb pj does not divides nb qi - if index >= uint64(len(ekg.dckksContext.params.Qi)-1) { + if index == uint64(len(ekg.context.params.LogQi)-1) { break } } - // h = sk*CrtBaseDecompQi + -u*a + e - contextQP.MulCoeffsMontgomeryAndSub(u, crp[i], shareOut[i]) + ekg.context.contextQP.MulCoeffsMontgomeryAndSub(u, crp[i], shareOut[i][0]) + + // Second Element + // e_2i + ekg.gaussianSampler.Read(shareOut[i][1]) + ringContext.NTT(shareOut[i][1], shareOut[i][1]) + // s*a + e_2i + ringContext.MulCoeffsMontgomeryAndAdd(sk, crp[i], shareOut[i][1]) } - ekg.polypool.Zero() + ekg.polypool.Zero() // TODO: check if we can remove this one return } -// AggregateShareRoundOne sums share1 with share2 on shareOut. -func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShareRoundOne) { +// AggregateShareRoundOne adds share1 and share2 on shareOut. +func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShare) { - contextQP := ekg.dckksContext.contextQP - - for i := uint64(0); i < ekg.dckksContext.beta; i++ { - contextQP.Add(share1[i], share2[i], shareOut[i]) + for i := uint64(0); i < ekg.context.params.Beta; i++ { + ekg.context.contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0]) + ekg.context.contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1]) } } @@ -140,34 +203,34 @@ func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShare // = [s_i * (-u*a + s*w + e) + e_i1, s_i*a + e_i2] // // and broadcasts both values to the other j-1 parties. -func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShareRoundTwo) { +func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShare, u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShare) { - contextQP := ekg.dckksContext.contextQP + ringContext := ekg.context.contextQP + + // (u_i - s_i) + ringContext.Sub(u, sk, ekg.tmpPoly1) // Each sample is of the form [-u*a_i + s*w_i + e_i] // So for each element of the base decomposition w_i : - for i := uint64(0); i < ekg.dckksContext.beta; i++ { + for i := uint64(0); i < ekg.context.params.Beta; i++ { // Computes [(sum samples)*sk + e_1i, sk*a + e_2i] // (AggregateShareRoundTwo samples) * sk - contextQP.MulCoeffsMontgomery(round1[i], sk, shareOut[i][0]) + ringContext.MulCoeffsMontgomery(round1[i][0], sk, shareOut[i][0]) // (AggregateShareRoundTwo samples) * sk + e_1i - ekg.gaussianSampler.Read(ekg.polypool) - contextQP.NTT(ekg.polypool, ekg.polypool) - contextQP.Add(shareOut[i][0], ekg.polypool, shareOut[i][0]) + ekg.gaussianSampler.Read(ekg.tmpPoly2) + ringContext.NTT(ekg.tmpPoly2, ekg.tmpPoly2) + ringContext.Add(shareOut[i][0], ekg.tmpPoly2, shareOut[i][0]) - // Second Element - // e_2i + // second part + // (u - s) * (sum [x][s*a_i + e_2i]) + e3i ekg.gaussianSampler.Read(shareOut[i][1]) - contextQP.NTT(shareOut[i][1], shareOut[i][1]) - // s*a + e_2i - contextQP.MulCoeffsMontgomeryAndAdd(sk, crp[i], shareOut[i][1]) + ringContext.NTT(shareOut[i][1], shareOut[i][1]) + ringContext.MulCoeffsMontgomeryAndAdd(ekg.tmpPoly1, round1[i][1], shareOut[i][1]) } - ekg.polypool.Zero() - } // AggregateShareRoundTwo is the first part of the third and last round of the RKGProtocol protocol. Upon receiving the j-1 elements, each party @@ -176,62 +239,27 @@ func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly, // [sum(s_j * (-u*a + s*w + e) + e_j1), sum(s_j*a + e_j2)] // // = [s * (-u*a + s*w + e) + e_1, s*a + e_2]. -func (ekg *RKGProtocol) AggregateShareRoundTwo(share1, share2, shareOut RKGShareRoundTwo) { +func (ekg *RKGProtocol) AggregateShareRoundTwo(share1, share2, shareOut RKGShare) { - contextQP := ekg.dckksContext.contextQP - - for i := uint64(0); i < ekg.dckksContext.beta; i++ { - contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0]) - contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1]) - } - -} - -// GenShareRoundThree is the second pard of the third and last round of the RKGProtocol protocol. Each party operates a key-switch on [s*a + e_2], -// by computing : -// -// [(u_i - s_i)*(s*a + e_2)] -// -// and broadcasts the result to the other j-1 parties. -func (ekg *RKGProtocol) GenShareRoundThree(round2 RKGShareRoundTwo, u, sk *ring.Poly, shareOut RKGShareRoundThree) { - - contextQP := ekg.dckksContext.contextQP - - // (u_i - s_i) - contextQP.Sub(u, sk, ekg.polypool) - - for i := uint64(0); i < ekg.dckksContext.beta; i++ { - - // (u - s) * (sum [x][s*a_i + e_2i]) + e3i - ekg.gaussianSampler.Read(shareOut[i]) - contextQP.NTT(shareOut[i], shareOut[i]) - contextQP.MulCoeffsMontgomeryAndAdd(ekg.polypool, round2[i][1], shareOut[i]) + for i := uint64(0); i < ekg.context.params.Beta; i++ { + ekg.context.contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0]) + ekg.context.contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1]) } } -// AggregateShareRoundThree sums share1 with share2 on shareOut. -func (ekg *RKGProtocol) AggregateShareRoundThree(share1, share2, shareOut RKGShareRoundThree) { +// GenRelinearizationKey finalizes the protocol and returns the common EvaluationKey. +func (ekg *RKGProtocol) GenRelinearizationKey(round1 RKGShare, round2 RKGShare, evalKeyOut *ckks.EvaluationKey) { - contextQP := ekg.dckksContext.contextQP - - for i := uint64(0); i < ekg.dckksContext.beta; i++ { - contextQP.Add(share1[i], share2[i], shareOut[i]) - } -} - -// GenRelinearizationKey finalizes the protocol and returns the collective EvalutionKey. -func (ekg *RKGProtocol) GenRelinearizationKey(round2 RKGShareRoundTwo, round3 RKGShareRoundThree, evalKeyOut *ckks.EvaluationKey) { - - contextQP := ekg.dckksContext.contextQP + ringContext := ekg.context.contextQP key := evalKeyOut.Get().Get() - for i := uint64(0); i < ekg.dckksContext.beta; i++ { + for i := uint64(0); i < ekg.context.params.Beta; i++ { - contextQP.Add(round2[i][0], round3[i], key[i][0]) - key[i][1].Copy(round2[i][1]) + ringContext.Add(round2[i][0], round2[i][1], key[i][0]) + key[i][1].Copy(round1[i][1]) - contextQP.MForm(key[i][0], key[i][0]) - contextQP.MForm(key[i][1], key[i][1]) + ringContext.MForm(key[i][0], key[i][0]) + ringContext.MForm(key[i][1], key[i][1]) } } diff --git a/examples/dbfv/pir/pir.go b/examples/dbfv/pir/pir.go index f22192d1..be5a09e9 100644 --- a/examples/dbfv/pir/pir.go +++ b/examples/dbfv/pir/pir.go @@ -1,15 +1,16 @@ package main import ( - "github.com/ldsec/lattigo/bfv" - "github.com/ldsec/lattigo/dbfv" - "github.com/ldsec/lattigo/ring" - "github.com/ldsec/lattigo/utils" "log" "os" "strconv" "sync" "time" + + "github.com/ldsec/lattigo/bfv" + "github.com/ldsec/lattigo/dbfv" + "github.com/ldsec/lattigo/ring" + "github.com/ldsec/lattigo/utils" ) func check(err error) { @@ -68,12 +69,11 @@ func main() { sk *bfv.SecretKey rlkEphemSk *ring.Poly - ckgShare dbfv.CKGShare - rkgShareOne dbfv.RKGShareRoundOne - rkgShareTwo dbfv.RKGShareRoundTwo - rkgShareThree dbfv.RKGShareRoundThree - rtgShare dbfv.RTGShare - cksShare dbfv.CKSShare + ckgShare dbfv.CKGShare + rkgShareOne dbfv.RKGShare + rkgShareTwo dbfv.RKGShare + rtgShare dbfv.RTGShare + cksShare dbfv.CKSShare input []uint64 } @@ -128,7 +128,7 @@ func main() { contextKeys.Add(colSk.Get(), pi.sk.Get(), colSk.Get()) //TODO: doc says "return" pi.ckgShare = ckg.AllocateShares() - pi.rkgShareOne, pi.rkgShareTwo, pi.rkgShareThree = rkg.AllocateShares() + pi.rkgShareOne, pi.rkgShareTwo = rkg.AllocateShares() pi.rtgShare = rtg.AllocateShare() pi.cksShare = cks.AllocateShare() @@ -167,7 +167,7 @@ func main() { } }, N) - rkgCombined1, rkgCombined2, rkgCombined3 := rkg.AllocateShares() + rkgCombined1, rkgCombined2 := rkg.AllocateShares() elapsedRKGCloud = runTimed(func() { for _, pi := range P { @@ -177,29 +177,18 @@ func main() { elapsedRKGParty += runTimedParty(func() { for _, pi := range P { - rkg.GenShareRoundTwo(rkgCombined1, pi.sk.Get(), crp, pi.rkgShareTwo) - } - }, N) - - elapsedRKGCloud += runTimed(func() { - for _, pi := range P { - rkg.AggregateShareRoundTwo(pi.rkgShareTwo, rkgCombined2, rkgCombined2) - } - }) - - elapsedRKGParty += runTimedParty(func() { - for _, pi := range P { - rkg.GenShareRoundThree(rkgCombined2, pi.rlkEphemSk, pi.sk.Get(), pi.rkgShareThree) + rkg.GenShareRoundTwo(rkgCombined1, pi.rlkEphemSk, pi.sk.Get(), crp, pi.rkgShareTwo) } }, N) rlk := bfv.NewRelinKey(params, 1) elapsedRKGCloud += runTimed(func() { for _, pi := range P { - rkg.AggregateShareRoundThree(pi.rkgShareThree, rkgCombined3, rkgCombined3) + rkg.AggregateShareRoundTwo(pi.rkgShareTwo, rkgCombined2, rkgCombined2) } - rkg.GenRelinearizationKey(rkgCombined2, rkgCombined3, rlk) + rkg.GenRelinearizationKey(rkgCombined1, rkgCombined2, rlk) }) + l.Printf("\tdone (cloud: %s, party: %s)\n", elapsedRKGCloud, elapsedRKGParty) // 3) Collective rotation keys geneneration diff --git a/examples/dbfv/psi/psi.go b/examples/dbfv/psi/psi.go index 50dd4bbf..d3353cb1 100644 --- a/examples/dbfv/psi/psi.go +++ b/examples/dbfv/psi/psi.go @@ -1,16 +1,17 @@ package main import ( - "github.com/ldsec/lattigo/bfv" - "github.com/ldsec/lattigo/dbfv" - "github.com/ldsec/lattigo/ring" - "github.com/ldsec/lattigo/utils" "log" "math/rand" "os" "strconv" "sync" "time" + + "github.com/ldsec/lattigo/bfv" + "github.com/ldsec/lattigo/dbfv" + "github.com/ldsec/lattigo/ring" + "github.com/ldsec/lattigo/utils" ) func check(err error) { @@ -53,11 +54,10 @@ func main() { sk *bfv.SecretKey rlkEphemSk *ring.Poly - ckgShare dbfv.CKGShare - rkgShareOne dbfv.RKGShareRoundOne - rkgShareTwo dbfv.RKGShareRoundTwo - rkgShareThree dbfv.RKGShareRoundThree - pcksShare dbfv.PCKSShare + ckgShare dbfv.CKGShare + rkgShareOne dbfv.RKGShare + rkgShareTwo dbfv.RKGShare + pcksShare dbfv.PCKSShare input []uint64 } @@ -112,7 +112,7 @@ func main() { contextKeys.Add(colSk.Get(), pi.sk.Get(), colSk.Get()) //TODO: doc says "return" pi.ckgShare = ckg.AllocateShares() - pi.rkgShareOne, pi.rkgShareTwo, pi.rkgShareThree = rkg.AllocateShares() + pi.rkgShareOne, pi.rkgShareTwo = rkg.AllocateShares() pi.pcksShare = pcks.AllocateShares() P[i] = pi @@ -146,7 +146,7 @@ func main() { } }, N) - rkgCombined1, rkgCombined2, rkgCombined3 := rkg.AllocateShares() + rkgCombined1, rkgCombined2 := rkg.AllocateShares() elapsedRKGCloud = runTimed(func() { for _, pi := range P { @@ -156,29 +156,18 @@ func main() { elapsedRKGParty += runTimedParty(func() { for _, pi := range P { - rkg.GenShareRoundTwo(rkgCombined1, pi.sk.Get(), crp, pi.rkgShareTwo) - } - }, N) - - elapsedRKGCloud += runTimed(func() { - for _, pi := range P { - rkg.AggregateShareRoundTwo(pi.rkgShareTwo, rkgCombined2, rkgCombined2) - } - }) - - elapsedRKGParty += runTimedParty(func() { - for _, pi := range P { - rkg.GenShareRoundThree(rkgCombined2, pi.rlkEphemSk, pi.sk.Get(), pi.rkgShareThree) + rkg.GenShareRoundTwo(rkgCombined1, pi.rlkEphemSk, pi.sk.Get(), crp, pi.rkgShareTwo) } }, N) rlk := bfv.NewRelinKey(params, 1) elapsedRKGCloud += runTimed(func() { for _, pi := range P { - rkg.AggregateShareRoundThree(pi.rkgShareThree, rkgCombined3, rkgCombined3) + rkg.AggregateShareRoundTwo(pi.rkgShareTwo, rkgCombined2, rkgCombined2) } - rkg.GenRelinearizationKey(rkgCombined2, rkgCombined3, rlk) + rkg.GenRelinearizationKey(rkgCombined1, rkgCombined2, rlk) }) + l.Printf("\tdone (cloud: %s, party: %s)\n", elapsedRKGCloud, elapsedRKGParty) l.Printf("\tSetup done (cloud: %s, party: %s)\n",