[dbfv/dckks] two-rounds relinearization key generation

This commit is contained in:
Christian M
2020-09-03 13:54:53 +02:00
parent 456530b1f1
commit 3871936fda
9 changed files with 241 additions and 407 deletions

View File

@@ -32,6 +32,9 @@ test_local:
go run ./examples/bfv/examples_bfv.go > /dev/null
go run ./examples/ckks/euler/euler.go > /dev/null
go run ./examples/ckks/sigmoid/sigmoid.go > /dev/null
go run ./examples/dbfv/pir/pir.go &> /dev/null
go run ./examples/dbfv/psi/psi.go &> /dev/null
@echo ok
test: test_fmt test_local

View File

@@ -84,9 +84,8 @@ func benchRelinKeyGen(b *testing.B) {
*RKGProtocol
u *ring.Poly
s *ring.Poly
share1 RKGShareRoundOne
share2 RKGShareRoundTwo
share3 RKGShareRoundThree
share1 RKGShare
share2 RKGShare
rlk *bfv.EvaluationKey
}
@@ -95,7 +94,7 @@ func benchRelinKeyGen(b *testing.B) {
p.RKGProtocol = NewEkgProtocol(parameters)
p.u = p.RKGProtocol.NewEphemeralKey()
p.s = sk0Shards[0].Get()
p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares()
p.share1, p.share2 = p.RKGProtocol.AllocateShares()
p.rlk = bfv.NewRelinKey(parameters, 2)
prng, err := utils.NewKeyedPRNG(nil)
if err != nil {
@@ -124,7 +123,7 @@ func benchRelinKeyGen(b *testing.B) {
b.Run(testString("Round2/Gen", parties, parameters), func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.GenShareRoundTwo(p.share1, p.s, crp, p.share2)
p.GenShareRoundTwo(p.share1, p.u, p.s, crp, p.share2)
}
})
@@ -134,21 +133,9 @@ func benchRelinKeyGen(b *testing.B) {
}
})
b.Run(testString("Round3/Gen", parties, parameters), func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.GenShareRoundThree(p.share2, p.u, p.s, p.share3)
}
})
b.Run(testString("Round3/Agg", parties, parameters), func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.AggregateShareRoundThree(p.share3, p.share3, p.share3)
}
})
b.Run(testString("Finalize", parties, parameters), func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.GenRelinearizationKey(p.share2, p.share3, p.rlk)
p.GenRelinearizationKey(p.share1, p.share2, p.rlk)
}
})
}

View File

@@ -190,9 +190,8 @@ func testRelinKeyGen(t *testing.T) {
*RKGProtocol
u *ring.Poly
s *ring.Poly
share1 RKGShareRoundOne
share2 RKGShareRoundTwo
share3 RKGShareRoundThree
share1 RKGShare
share2 RKGShare
}
rkgParties := make([]*Party, parties)
@@ -202,7 +201,7 @@ func testRelinKeyGen(t *testing.T) {
p.RKGProtocol = NewEkgProtocol(parameters)
p.u = p.RKGProtocol.NewEphemeralKey()
p.s = sk0Shards[i].Get()
p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares()
p.share1, p.share2 = p.RKGProtocol.AllocateShares()
rkgParties[i] = p
}
@@ -229,22 +228,14 @@ func testRelinKeyGen(t *testing.T) {
//ROUND 2
for i, p := range rkgParties {
p.GenShareRoundTwo(P0.share1, p.s, crp, p.share2)
p.GenShareRoundTwo(P0.share1, p.u, p.s, crp, p.share2)
if i > 0 {
P0.AggregateShareRoundTwo(p.share2, P0.share2, P0.share2)
}
}
// ROUND 3
for i, p := range rkgParties {
p.GenShareRoundThree(P0.share2, p.u, p.s, p.share3)
if i > 0 {
P0.AggregateShareRoundThree(p.share3, P0.share3, P0.share3)
}
}
evk := bfv.NewRelinKey(parameters, 1)
P0.GenRelinearizationKey(P0.share2, P0.share3, evk)
P0.GenRelinearizationKey(P0.share1, P0.share2, evk)
coeffs, _, ciphertext := newTestVectors(testCtx, encryptorPk0, t)
@@ -918,28 +909,28 @@ func Test_Relin_Marshalling(t *testing.T) {
u := rlk.NewEphemeralKey()
sk := bfv.NewKeyGenerator(params).GenSecretKey()
r1, r2, r3 := rlk.AllocateShares()
r1, r2 := rlk.AllocateShares()
rlk.GenShareRoundOne(u, sk.Get(), crp, r1)
data, err := r1.MarshalBinary()
require.NoError(t, err)
r1After := new(RKGShareRoundOne)
r1After := new(RKGShare)
err = r1After.UnmarshalBinary(data)
require.NoError(t, err)
for i := 0; i < (len(r1)); i++ {
a := r1[i]
b := (*r1After)[i]
a := r1[i][0]
b := (*r1After)[i][0]
moduli := a.GetLenModuli()
require.Equal(t, a.Coeffs[:moduli], b.Coeffs[:moduli])
}
rlk.GenShareRoundTwo(r1, sk.Get(), crp, r2)
rlk.GenShareRoundTwo(r1, u, sk.Get(), crp, r2)
data, err = r2.MarshalBinary()
require.NoError(t, err)
r2After := new(RKGShareRoundTwo)
r2After := new(RKGShare)
err = r2After.UnmarshalBinary(data)
require.NoError(t, err)
@@ -952,22 +943,6 @@ func Test_Relin_Marshalling(t *testing.T) {
}
}
rlk.GenShareRoundThree(r2, u, sk.Get(), r3)
data, err = r3.MarshalBinary()
require.NoError(t, err)
r3After := new(RKGShareRoundThree)
err = r3After.UnmarshalBinary(data)
require.NoError(t, err)
for i := 0; i < (len(r3)); i++ {
a := r3[i]
b := (*r3After)[i]
moduli := a.GetLenModuli()
require.Equal(t, a.Coeffs[:moduli], b.Coeffs[:moduli])
}
})
}

View File

@@ -2,6 +2,7 @@ package dbfv
import (
"errors"
"github.com/ldsec/lattigo/bfv"
"github.com/ldsec/lattigo/ring"
"github.com/ldsec/lattigo/utils"
@@ -18,63 +19,15 @@ type RKGProtocol struct {
ternarySamplerMontgomery *ring.TernarySampler
}
// RKGShareRoundOne is a struct storing the round one RKG shares.
type RKGShareRoundOne []*ring.Poly
// RKGShareRoundTwo is a struct storing the round two RKG shares.
type RKGShareRoundTwo [][2]*ring.Poly
// RKGShareRoundThree is a struct storing the round three RKG shares.
type RKGShareRoundThree []*ring.Poly
type RKGShare [][2]*ring.Poly
// MarshalBinary encodes the target element on a slice of bytes.
func (share *RKGShareRoundOne) MarshalBinary() ([]byte, error) {
rLength := (*share)[0].GetDataLen(true)
data := make([]byte, 1+rLength*uint64(len(*share)))
data[0] = uint8(len(*share))
pointer := uint64(1)
for _, s := range *share {
tmp, err := s.WriteTo(data[pointer : pointer+rLength])
if err != nil {
return []byte{}, err
}
pointer += tmp
}
return data, nil
}
// UnmarshalBinary decodes a slice of bytes on the target element.
func (share *RKGShareRoundOne) UnmarshalBinary(data []byte) error {
//share.modulus = data[0]
lenShare := data[0]
rLength := len(data[1:]) / int(lenShare)
if *share == nil {
*share = make([]*ring.Poly, lenShare)
}
ptr := 1
for i := uint8(0); i < lenShare; i++ {
if (*share)[i] == nil {
(*share)[i] = new(ring.Poly)
}
err := (*share)[i].UnmarshalBinary(data[ptr : ptr+rLength])
if err != nil {
return err
}
ptr += rLength
}
return nil
}
// MarshalBinary encodes the target element on a slice of bytes.
func (share *RKGShareRoundTwo) MarshalBinary() ([]byte, error) {
func (share *RKGShare) MarshalBinary() ([]byte, error) {
//we have modulus * bitLog * Len of 1 ring rings
rLength := ((*share)[0])[0].GetDataLen(true)
data := make([]byte, 1+2*rLength*uint64(len(*share)))
if len(*share) > 0xFF {
return []byte{}, errors.New("RKGShareRoundTwo : uint8 overflow on length")
return []byte{}, errors.New("RKGShare : uint8 overflow on length")
}
data[0] = uint8(len(*share))
@@ -99,7 +52,7 @@ func (share *RKGShareRoundTwo) MarshalBinary() ([]byte, error) {
}
// UnmarshalBinary decodes a slice of bytes on the target element.
func (share *RKGShareRoundTwo) UnmarshalBinary(data []byte) error {
func (share *RKGShare) UnmarshalBinary(data []byte) error {
lenShare := data[0]
rLength := (len(data) - 1) / (2 * int(lenShare))
@@ -129,59 +82,16 @@ func (share *RKGShareRoundTwo) UnmarshalBinary(data []byte) error {
return nil
}
// MarshalBinary encodes the target element on a slice of bytes.
func (share *RKGShareRoundThree) MarshalBinary() ([]byte, error) {
rLength := (*share)[0].GetDataLen(true)
data := make([]byte, 1+rLength*uint64(len(*share)))
data[0] = uint8(len(*share))
pointer := uint64(1)
for _, s := range *share {
tmp, err := s.WriteTo(data[pointer : pointer+rLength])
if err != nil {
return []byte{}, err
}
pointer += tmp
}
return data, nil
}
// UnmarshalBinary decodes a slice of bytes on the target element.
func (share *RKGShareRoundThree) UnmarshalBinary(data []byte) error {
//share.modulus = data[0]
lenShare := data[0]
rLength := len(data[1:]) / int(lenShare)
if *share == nil {
*share = make([]*ring.Poly, lenShare)
}
ptr := 1
for i := uint8(0); i < lenShare; i++ {
if (*share)[i] == nil {
(*share)[i] = new(ring.Poly)
}
err := (*share)[i].UnmarshalBinary(data[ptr : ptr+rLength])
if err != nil {
return err
}
ptr += rLength
}
return nil
}
// AllocateShares allocates the shares of the EKG protocol.
func (ekg *RKGProtocol) AllocateShares() (r1 RKGShareRoundOne, r2 RKGShareRoundTwo, r3 RKGShareRoundThree) {
r1 = make([]*ring.Poly, ekg.context.params.Beta)
func (ekg *RKGProtocol) AllocateShares() (r1 RKGShare, r2 RKGShare) {
r1 = make([][2]*ring.Poly, ekg.context.params.Beta)
r2 = make([][2]*ring.Poly, ekg.context.params.Beta)
r3 = make([]*ring.Poly, ekg.context.params.Beta)
for i := uint64(0); i < ekg.context.params.Beta; i++ {
r1[i] = ekg.context.contextQP.NewPoly()
r1[i][0] = ekg.context.contextQP.NewPoly()
r1[i][1] = ekg.context.contextQP.NewPoly()
r2[i][0] = ekg.context.contextQP.NewPoly()
r2[i][1] = ekg.context.contextQP.NewPoly()
r3[i] = ekg.context.contextQP.NewPoly()
}
return
}
@@ -223,10 +133,9 @@ func (ekg *RKGProtocol) NewEphemeralKey() (ephemeralKey *ring.Poly) {
// GenShareRoundOne is the first of three rounds of the RKGProtocol protocol. Each party generates a pseudo encryption of
// its secret share of the key s_i under its ephemeral key u_i : [-u_i*a + s_i*w + e_i] and broadcasts it to the other
// j-1 parties.
func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShareRoundOne) {
func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShare) {
var index uint64
// Given a base decomposition w_i (here the CRT decomposition)
// computes [-u*a_i + P*s_i + e_i]
// where a_i = crp_i
@@ -240,10 +149,9 @@ func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, sha
ringContext.InvMForm(ekg.polypool, ekg.polypool)
for i := uint64(0); i < ekg.context.params.Beta; i++ {
// h = e
ekg.gaussianSampler.Read(shareOut[i])
ringContext.NTT(shareOut[i], shareOut[i])
ekg.gaussianSampler.Read(shareOut[i][0])
ringContext.NTT(shareOut[i][0], shareOut[i][0])
// h = sk*CrtBaseDecompQi + e
for j := uint64(0); j < ekg.context.params.Alpha; j++ {
@@ -251,7 +159,7 @@ func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, sha
index = i*ekg.context.params.Alpha + j
qi := ringContext.Modulus[index]
tmp0 := ekg.polypool.Coeffs[index]
tmp1 := shareOut[i].Coeffs[index]
tmp1 := shareOut[i][0].Coeffs[index]
for w := uint64(0); w < ekg.context.contextQP.N; w++ {
tmp1[w] = ring.CRed(tmp1[w]+tmp0[w], qi)
@@ -262,21 +170,28 @@ func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, sha
break
}
}
// h = sk*CrtBaseDecompQi + -u*a + e
ringContext.MulCoeffsMontgomeryAndSub(u, crp[i], shareOut[i])
ekg.context.contextQP.MulCoeffsMontgomeryAndSub(u, crp[i], shareOut[i][0])
// Second Element
// e_2i
ekg.gaussianSampler.Read(shareOut[i][1])
ringContext.NTT(shareOut[i][1], shareOut[i][1])
// s*a + e_2i
ringContext.MulCoeffsMontgomeryAndAdd(sk, crp[i], shareOut[i][1])
}
ekg.polypool.Zero()
ekg.polypool.Zero() // TODO: check if we can remove this one
return
}
// AggregateShareRoundOne adds share1 and share2 on shareOut.
func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShareRoundOne) {
func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShare) {
for i := uint64(0); i < ekg.context.params.Beta; i++ {
ekg.context.contextQP.Add(share1[i], share2[i], shareOut[i])
ekg.context.contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0])
ekg.context.contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1])
}
}
@@ -288,10 +203,13 @@ func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShare
// = [s_i * (-u*a + s*w + e) + e_i1, s_i*a + e_i2]
//
// and broadcasts both values to the other j-1 parties.
func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShareRoundTwo) {
func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShare, u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShare) {
ringContext := ekg.context.contextQP
// (u_i - s_i)
ringContext.Sub(u, sk, ekg.tmpPoly1)
// Each sample is of the form [-u*a_i + s*w_i + e_i]
// So for each element of the base decomposition w_i :
for i := uint64(0); i < ekg.context.params.Beta; i++ {
@@ -299,19 +217,18 @@ func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly,
// Computes [(sum samples)*sk + e_1i, sk*a + e_2i]
// (AggregateShareRoundTwo samples) * sk
ringContext.MulCoeffsMontgomery(round1[i], sk, shareOut[i][0])
ringContext.MulCoeffsMontgomery(round1[i][0], sk, shareOut[i][0])
// (AggregateShareRoundTwo samples) * sk + e_1i
ekg.gaussianSampler.Read(ekg.tmpPoly1)
ringContext.NTT(ekg.tmpPoly1, ekg.tmpPoly1)
ringContext.Add(shareOut[i][0], ekg.tmpPoly1, shareOut[i][0])
ekg.gaussianSampler.Read(ekg.tmpPoly2)
ringContext.NTT(ekg.tmpPoly2, ekg.tmpPoly2)
ringContext.Add(shareOut[i][0], ekg.tmpPoly2, shareOut[i][0])
// Second Element
// e_2i
// second part
// (u - s) * (sum [x][s*a_i + e_2i]) + e3i
ekg.gaussianSampler.Read(shareOut[i][1])
ringContext.NTT(shareOut[i][1], shareOut[i][1])
// s*a + e_2i
ringContext.MulCoeffsMontgomeryAndAdd(sk, crp[i], shareOut[i][1])
ringContext.MulCoeffsMontgomeryAndAdd(ekg.tmpPoly1, round1[i][1], shareOut[i][1])
}
}
@@ -322,54 +239,24 @@ func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly,
// [sum(s_j * (-u*a + s*w + e) + e_j1), sum(s_j*a + e_j2)]
//
// = [s * (-u*a + s*w + e) + e_1, s*a + e_2].
func (ekg *RKGProtocol) AggregateShareRoundTwo(share1, share2, shareOut RKGShareRoundTwo) {
func (ekg *RKGProtocol) AggregateShareRoundTwo(share1, share2, shareOut RKGShare) {
for i := uint64(0); i < ekg.context.params.Beta; i++ {
ekg.context.contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0])
ekg.context.contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1])
}
}
// GenShareRoundThree is the second pard of the third and last round of the RKGProtocol protocol. Each party operates a key-switch on [s*a + e_2],
// by computing :
//
// [(u_i - s_i)*(s*a + e_2)]
//
// and broadcasts the result to the other j-1 parties.
func (ekg *RKGProtocol) GenShareRoundThree(round2 RKGShareRoundTwo, u, sk *ring.Poly, shareOut RKGShareRoundThree) {
ringContext := ekg.context.contextQP
// (u_i - s_i)
ringContext.Sub(u, sk, ekg.tmpPoly1)
for i := uint64(0); i < ekg.context.params.Beta; i++ {
// (u - s) * (sum [x][s*a_i + e_2i]) + e3i
ekg.gaussianSampler.Read(shareOut[i])
ringContext.NTT(shareOut[i], shareOut[i])
ringContext.MulCoeffsMontgomeryAndAdd(ekg.tmpPoly1, round2[i][1], shareOut[i])
}
}
// AggregateShareRoundThree adds share1 and share2 on shareOut.
func (ekg *RKGProtocol) AggregateShareRoundThree(share1, share2, shareOut RKGShareRoundThree) {
for i := uint64(0); i < ekg.context.params.Beta; i++ {
ekg.context.contextQP.Add(share1[i], share2[i], shareOut[i])
}
}
// GenRelinearizationKey finalizes the protocol and returns the common EvaluationKey.
func (ekg *RKGProtocol) GenRelinearizationKey(round2 RKGShareRoundTwo, round3 RKGShareRoundThree, evalKeyOut *bfv.EvaluationKey) {
func (ekg *RKGProtocol) GenRelinearizationKey(round1 RKGShare, round2 RKGShare, evalKeyOut *bfv.EvaluationKey) {
ringContext := ekg.context.contextQP
key := evalKeyOut.Get()[0].Get()
for i := uint64(0); i < ekg.context.params.Beta; i++ {
ringContext.Add(round2[i][0], round3[i], key[i][0])
key[i][1].Copy(round2[i][1])
ringContext.Add(round2[i][0], round2[i][1], key[i][0])
key[i][1].Copy(round1[i][1])
ringContext.MForm(key[i][0], key[i][0])
ringContext.MForm(key[i][1], key[i][1])

View File

@@ -76,16 +76,15 @@ func benchRelinKeyGen(b *testing.B) {
*RKGProtocol
u *ring.Poly
s *ring.Poly
share1 RKGShareRoundOne
share2 RKGShareRoundTwo
share3 RKGShareRoundThree
share1 RKGShare
share2 RKGShare
}
p := new(Party)
p.RKGProtocol = NewEkgProtocol(parameters)
p.u = p.RKGProtocol.NewEphemeralKey()
p.s = sk0Shards[0].Get()
p.share1, p.share2, p.share3 = p.RKGProtocol.AllocateShares()
p.share1, p.share2 = p.RKGProtocol.AllocateShares()
prng, err := utils.NewKeyedPRNG(nil)
if err != nil {
panic(err)
@@ -114,7 +113,7 @@ func benchRelinKeyGen(b *testing.B) {
b.Run(testString("Round2Gen/", parties, parameters), func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.GenShareRoundTwo(p.share1, p.s, crp, p.share2)
p.GenShareRoundTwo(p.share1, p.u, p.s, crp, p.share2)
}
})
@@ -124,20 +123,6 @@ func benchRelinKeyGen(b *testing.B) {
p.AggregateShareRoundTwo(p.share2, p.share2, p.share2)
}
})
b.Run(testString("Round3Gen/", parties, parameters), func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.GenShareRoundThree(p.share2, p.u, p.s, p.share3)
}
})
b.Run(testString("Round3Agg/", parties, parameters), func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.AggregateShareRoundThree(p.share3, p.share3, p.share3)
}
})
}
}

View File

@@ -209,9 +209,8 @@ func testRelinKeyGen(t *testing.T) {
*RKGProtocol
u *ring.Poly
s *ring.Poly
share1 RKGShareRoundOne
share2 RKGShareRoundTwo
share3 RKGShareRoundThree
share1 RKGShare
share2 RKGShare
}
rkgParties := make([]*Party, parties)
@@ -221,7 +220,7 @@ func testRelinKeyGen(t *testing.T) {
p.RKGProtocol = NewEkgProtocol(parameters)
p.u = p.NewEphemeralKey()
p.s = sk0Shards[i].Get()
p.share1, p.share2, p.share3 = p.AllocateShares()
p.share1, p.share2 = p.AllocateShares()
rkgParties[i] = p
}
@@ -248,22 +247,14 @@ func testRelinKeyGen(t *testing.T) {
//ROUND 2
for i, p := range rkgParties {
p.GenShareRoundTwo(P0.share1, p.s, crp, p.share2)
p.GenShareRoundTwo(P0.share1, p.u, p.s, crp, p.share2)
if i > 0 {
P0.AggregateShareRoundTwo(p.share2, P0.share2, P0.share2)
}
}
// ROUND 3
for i, p := range rkgParties {
p.GenShareRoundThree(P0.share2, p.u, p.s, p.share3)
if i > 0 {
P0.AggregateShareRoundThree(p.share3, P0.share3, P0.share3)
}
}
evk := ckks.NewRelinKey(parameters)
P0.GenRelinearizationKey(P0.share2, P0.share3, evk)
P0.GenRelinearizationKey(P0.share1, P0.share2, evk)
coeffs, _, ciphertext := newTestVectors(params, encryptorPk0, 1, t)

View File

@@ -1,62 +1,123 @@
package dckks
import (
"errors"
"github.com/ldsec/lattigo/ckks"
"github.com/ldsec/lattigo/ring"
"github.com/ldsec/lattigo/utils"
)
// RKGProtocol is a structure storing the parameters for the collective evaluation-key generation.
// RKGProtocol is the structure storing the parameters and state for a party in the collective relinearization key
// generation protocol.
type RKGProtocol struct {
dckksContext *dckksContext
context *dckksContext
tmpPoly1 *ring.Poly
tmpPoly2 *ring.Poly
polypool *ring.Poly
gaussianSampler *ring.GaussianSampler
ternarySamplerMontgomery *ring.TernarySampler
}
// RKGShareRoundOne is a struct storing the round one share of the RKG protocol.
type RKGShareRoundOne []*ring.Poly
type RKGShare [][2]*ring.Poly
// RKGShareRoundTwo is a struct storing the round two share of the RKG protocol.
type RKGShareRoundTwo [][2]*ring.Poly
// MarshalBinary encodes the target element on a slice of bytes.
func (share *RKGShare) MarshalBinary() ([]byte, error) {
//we have modulus * bitLog * Len of 1 ring rings
rLength := ((*share)[0])[0].GetDataLen(true)
data := make([]byte, 1+2*rLength*uint64(len(*share)))
if len(*share) > 0xFF {
return []byte{}, errors.New("RKGShare : uint8 overflow on length")
}
data[0] = uint8(len(*share))
// RKGShareRoundThree is a struct storing the round three share of the RKG protocol.
type RKGShareRoundThree []*ring.Poly
//write all of our rings in the data.
//write all the polys
ptr := uint64(1)
for _, elem := range *share {
_, err := elem[0].WriteTo(data[ptr : ptr+rLength])
if err != nil {
return []byte{}, err
}
ptr += rLength
_, err = elem[1].WriteTo(data[ptr : ptr+rLength])
if err != nil {
return []byte{}, err
}
ptr += rLength
}
// AllocateShares allocates the shares of the RKG protocol.
func (ekg *RKGProtocol) AllocateShares() (r1 RKGShareRoundOne, r2 RKGShareRoundTwo, r3 RKGShareRoundThree) {
return data, nil
contextQP := ekg.dckksContext.contextQP
}
r1 = make([]*ring.Poly, ekg.dckksContext.beta)
r2 = make([][2]*ring.Poly, ekg.dckksContext.beta)
r3 = make([]*ring.Poly, ekg.dckksContext.beta)
for i := uint64(0); i < ekg.dckksContext.beta; i++ {
r1[i] = contextQP.NewPoly()
r2[i][0] = contextQP.NewPoly()
r2[i][1] = contextQP.NewPoly()
r3[i] = contextQP.NewPoly()
// UnmarshalBinary decodes a slice of bytes on the target element.
func (share *RKGShare) UnmarshalBinary(data []byte) error {
lenShare := data[0]
rLength := (len(data) - 1) / (2 * int(lenShare))
if *share == nil {
*share = make([][2]*ring.Poly, lenShare)
}
ptr := (1)
for i := (0); i < int(lenShare); i++ {
if (*share)[i][0] == nil || (*share)[i][1] == nil {
(*share)[i][0] = new(ring.Poly)
(*share)[i][1] = new(ring.Poly)
}
err := (*share)[i][0].UnmarshalBinary(data[ptr : ptr+rLength])
if err != nil {
return err
}
ptr += rLength
err = (*share)[i][1].UnmarshalBinary(data[ptr : ptr+rLength])
if err != nil {
return err
}
ptr += rLength
}
return nil
}
// AllocateShares allocates the shares of the EKG protocol.
func (ekg *RKGProtocol) AllocateShares() (r1 RKGShare, r2 RKGShare) {
r1 = make([][2]*ring.Poly, ekg.context.params.Beta)
r2 = make([][2]*ring.Poly, ekg.context.params.Beta)
for i := uint64(0); i < ekg.context.params.Beta; i++ {
r1[i][0] = ekg.context.contextQP.NewPoly()
r1[i][1] = ekg.context.contextQP.NewPoly()
r2[i][0] = ekg.context.contextQP.NewPoly()
r2[i][1] = ekg.context.contextQP.NewPoly()
}
return
}
// NewEkgProtocol creates a new RKGProtocol object that will be used to generate a collective evaluation-key.
// NewEkgProtocol creates a new RKGProtocol object that will be used to generate a collective evaluation-key
// among j parties in the given context with the given bit-decomposition.
func NewEkgProtocol(params *ckks.Parameters) *RKGProtocol {
if !params.IsValid() {
panic("cannot NewEkgProtocol : params not valid (check if they where generated properly)")
}
context := newDckksContext(params)
ekg := new(RKGProtocol)
dckksContext := newDckksContext(params)
ekg.dckksContext = dckksContext
ekg.polypool = ekg.dckksContext.contextQP.NewPoly()
ekg.context = context
ekg.tmpPoly1 = ekg.context.contextQP.NewPoly()
ekg.tmpPoly2 = ekg.context.contextQP.NewPoly()
ekg.polypool = ekg.context.contextQP.NewPoly()
prng, err := utils.NewPRNG()
if err != nil {
panic(err)
}
ekg.gaussianSampler = ring.NewGaussianSampler(prng, dckksContext.contextQP, params.Sigma, uint64(6*params.Sigma))
ekg.ternarySamplerMontgomery = ring.NewTernarySampler(prng, dckksContext.contextQP, 0.5, true)
ekg.ternarySamplerMontgomery = ring.NewTernarySampler(prng, ekg.context.contextQP, 0.5, true)
ekg.gaussianSampler = ring.NewGaussianSampler(prng, ekg.context.contextQP, params.Sigma, uint64(6*params.Sigma))
return ekg
}
@@ -65,70 +126,72 @@ func NewEkgProtocol(params *ckks.Parameters) *RKGProtocol {
// of the collective secret-key.
func (ekg *RKGProtocol) NewEphemeralKey() (ephemeralKey *ring.Poly) {
ephemeralKey = ekg.ternarySamplerMontgomery.ReadNew()
ekg.dckksContext.contextQP.NTT(ephemeralKey, ephemeralKey)
ekg.context.contextQP.NTT(ephemeralKey, ephemeralKey)
return ephemeralKey
}
// GenShareRoundOne is the first of three rounds of the RKGProtocol protocol. Each party generates a pseudo encryption of
// its secret share of the key s_i under its ephemeral key u_i : [-u_i*a + P*s_i + e_i] and broadcasts it to the other
// its secret share of the key s_i under its ephemeral key u_i : [-u_i*a + s_i*w + e_i] and broadcasts it to the other
// j-1 parties.
func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShareRoundOne) {
contextQP := ekg.dckksContext.contextQP
func (ekg *RKGProtocol) GenShareRoundOne(u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShare) {
var index uint64
// Given a base decomposition w_i (here the CRT decomposition)
// computes [-u*a_i + P*s_i + e_i]
// where a_i = crp_i
ringContext := ekg.context.contextQP
ekg.polypool.Copy(sk)
contextQP.MulScalarBigint(ekg.polypool, ekg.dckksContext.contextP.ModulusBigint, ekg.polypool)
ringContext.MulScalarBigint(ekg.polypool, ekg.context.contextP.ModulusBigint, ekg.polypool)
contextQP.InvMForm(ekg.polypool, ekg.polypool)
for i := uint64(0); i < ekg.dckksContext.beta; i++ {
ringContext.InvMForm(ekg.polypool, ekg.polypool)
for i := uint64(0); i < ekg.context.params.Beta; i++ {
// h = e
ekg.gaussianSampler.Read(shareOut[i])
contextQP.NTT(shareOut[i], shareOut[i])
ekg.gaussianSampler.Read(shareOut[i][0])
ringContext.NTT(shareOut[i][0], shareOut[i][0])
// h = sk*CrtBaseDecompQi + e
for j := uint64(0); j < ekg.dckksContext.alpha; j++ {
for j := uint64(0); j < ekg.context.params.Alpha; j++ {
index = i*ekg.dckksContext.alpha + j
qi := contextQP.Modulus[index]
index = i*ekg.context.params.Alpha + j
qi := ringContext.Modulus[index]
tmp0 := ekg.polypool.Coeffs[index]
tmp1 := shareOut[i].Coeffs[index]
tmp1 := shareOut[i][0].Coeffs[index]
for w := uint64(0); w < contextQP.N; w++ {
for w := uint64(0); w < ekg.context.contextQP.N; w++ {
tmp1[w] = ring.CRed(tmp1[w]+tmp0[w], qi)
}
// Handles the case where nb pj does not divides nb qi
if index >= uint64(len(ekg.dckksContext.params.Qi)-1) {
if index == uint64(len(ekg.context.params.LogQi)-1) {
break
}
}
// h = sk*CrtBaseDecompQi + -u*a + e
contextQP.MulCoeffsMontgomeryAndSub(u, crp[i], shareOut[i])
ekg.context.contextQP.MulCoeffsMontgomeryAndSub(u, crp[i], shareOut[i][0])
// Second Element
// e_2i
ekg.gaussianSampler.Read(shareOut[i][1])
ringContext.NTT(shareOut[i][1], shareOut[i][1])
// s*a + e_2i
ringContext.MulCoeffsMontgomeryAndAdd(sk, crp[i], shareOut[i][1])
}
ekg.polypool.Zero()
ekg.polypool.Zero() // TODO: check if we can remove this one
return
}
// AggregateShareRoundOne sums share1 with share2 on shareOut.
func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShareRoundOne) {
// AggregateShareRoundOne adds share1 and share2 on shareOut.
func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShare) {
contextQP := ekg.dckksContext.contextQP
for i := uint64(0); i < ekg.dckksContext.beta; i++ {
contextQP.Add(share1[i], share2[i], shareOut[i])
for i := uint64(0); i < ekg.context.params.Beta; i++ {
ekg.context.contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0])
ekg.context.contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1])
}
}
@@ -140,34 +203,34 @@ func (ekg *RKGProtocol) AggregateShareRoundOne(share1, share2, shareOut RKGShare
// = [s_i * (-u*a + s*w + e) + e_i1, s_i*a + e_i2]
//
// and broadcasts both values to the other j-1 parties.
func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShareRoundTwo) {
func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShare, u, sk *ring.Poly, crp []*ring.Poly, shareOut RKGShare) {
contextQP := ekg.dckksContext.contextQP
ringContext := ekg.context.contextQP
// (u_i - s_i)
ringContext.Sub(u, sk, ekg.tmpPoly1)
// Each sample is of the form [-u*a_i + s*w_i + e_i]
// So for each element of the base decomposition w_i :
for i := uint64(0); i < ekg.dckksContext.beta; i++ {
for i := uint64(0); i < ekg.context.params.Beta; i++ {
// Computes [(sum samples)*sk + e_1i, sk*a + e_2i]
// (AggregateShareRoundTwo samples) * sk
contextQP.MulCoeffsMontgomery(round1[i], sk, shareOut[i][0])
ringContext.MulCoeffsMontgomery(round1[i][0], sk, shareOut[i][0])
// (AggregateShareRoundTwo samples) * sk + e_1i
ekg.gaussianSampler.Read(ekg.polypool)
contextQP.NTT(ekg.polypool, ekg.polypool)
contextQP.Add(shareOut[i][0], ekg.polypool, shareOut[i][0])
ekg.gaussianSampler.Read(ekg.tmpPoly2)
ringContext.NTT(ekg.tmpPoly2, ekg.tmpPoly2)
ringContext.Add(shareOut[i][0], ekg.tmpPoly2, shareOut[i][0])
// Second Element
// e_2i
// second part
// (u - s) * (sum [x][s*a_i + e_2i]) + e3i
ekg.gaussianSampler.Read(shareOut[i][1])
contextQP.NTT(shareOut[i][1], shareOut[i][1])
// s*a + e_2i
contextQP.MulCoeffsMontgomeryAndAdd(sk, crp[i], shareOut[i][1])
ringContext.NTT(shareOut[i][1], shareOut[i][1])
ringContext.MulCoeffsMontgomeryAndAdd(ekg.tmpPoly1, round1[i][1], shareOut[i][1])
}
ekg.polypool.Zero()
}
// AggregateShareRoundTwo is the first part of the third and last round of the RKGProtocol protocol. Upon receiving the j-1 elements, each party
@@ -176,62 +239,27 @@ func (ekg *RKGProtocol) GenShareRoundTwo(round1 RKGShareRoundOne, sk *ring.Poly,
// [sum(s_j * (-u*a + s*w + e) + e_j1), sum(s_j*a + e_j2)]
//
// = [s * (-u*a + s*w + e) + e_1, s*a + e_2].
func (ekg *RKGProtocol) AggregateShareRoundTwo(share1, share2, shareOut RKGShareRoundTwo) {
func (ekg *RKGProtocol) AggregateShareRoundTwo(share1, share2, shareOut RKGShare) {
contextQP := ekg.dckksContext.contextQP
for i := uint64(0); i < ekg.dckksContext.beta; i++ {
contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0])
contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1])
}
}
// GenShareRoundThree is the second pard of the third and last round of the RKGProtocol protocol. Each party operates a key-switch on [s*a + e_2],
// by computing :
//
// [(u_i - s_i)*(s*a + e_2)]
//
// and broadcasts the result to the other j-1 parties.
func (ekg *RKGProtocol) GenShareRoundThree(round2 RKGShareRoundTwo, u, sk *ring.Poly, shareOut RKGShareRoundThree) {
contextQP := ekg.dckksContext.contextQP
// (u_i - s_i)
contextQP.Sub(u, sk, ekg.polypool)
for i := uint64(0); i < ekg.dckksContext.beta; i++ {
// (u - s) * (sum [x][s*a_i + e_2i]) + e3i
ekg.gaussianSampler.Read(shareOut[i])
contextQP.NTT(shareOut[i], shareOut[i])
contextQP.MulCoeffsMontgomeryAndAdd(ekg.polypool, round2[i][1], shareOut[i])
for i := uint64(0); i < ekg.context.params.Beta; i++ {
ekg.context.contextQP.Add(share1[i][0], share2[i][0], shareOut[i][0])
ekg.context.contextQP.Add(share1[i][1], share2[i][1], shareOut[i][1])
}
}
// AggregateShareRoundThree sums share1 with share2 on shareOut.
func (ekg *RKGProtocol) AggregateShareRoundThree(share1, share2, shareOut RKGShareRoundThree) {
// GenRelinearizationKey finalizes the protocol and returns the common EvaluationKey.
func (ekg *RKGProtocol) GenRelinearizationKey(round1 RKGShare, round2 RKGShare, evalKeyOut *ckks.EvaluationKey) {
contextQP := ekg.dckksContext.contextQP
for i := uint64(0); i < ekg.dckksContext.beta; i++ {
contextQP.Add(share1[i], share2[i], shareOut[i])
}
}
// GenRelinearizationKey finalizes the protocol and returns the collective EvalutionKey.
func (ekg *RKGProtocol) GenRelinearizationKey(round2 RKGShareRoundTwo, round3 RKGShareRoundThree, evalKeyOut *ckks.EvaluationKey) {
contextQP := ekg.dckksContext.contextQP
ringContext := ekg.context.contextQP
key := evalKeyOut.Get().Get()
for i := uint64(0); i < ekg.dckksContext.beta; i++ {
for i := uint64(0); i < ekg.context.params.Beta; i++ {
contextQP.Add(round2[i][0], round3[i], key[i][0])
key[i][1].Copy(round2[i][1])
ringContext.Add(round2[i][0], round2[i][1], key[i][0])
key[i][1].Copy(round1[i][1])
contextQP.MForm(key[i][0], key[i][0])
contextQP.MForm(key[i][1], key[i][1])
ringContext.MForm(key[i][0], key[i][0])
ringContext.MForm(key[i][1], key[i][1])
}
}

View File

@@ -1,15 +1,16 @@
package main
import (
"github.com/ldsec/lattigo/bfv"
"github.com/ldsec/lattigo/dbfv"
"github.com/ldsec/lattigo/ring"
"github.com/ldsec/lattigo/utils"
"log"
"os"
"strconv"
"sync"
"time"
"github.com/ldsec/lattigo/bfv"
"github.com/ldsec/lattigo/dbfv"
"github.com/ldsec/lattigo/ring"
"github.com/ldsec/lattigo/utils"
)
func check(err error) {
@@ -68,12 +69,11 @@ func main() {
sk *bfv.SecretKey
rlkEphemSk *ring.Poly
ckgShare dbfv.CKGShare
rkgShareOne dbfv.RKGShareRoundOne
rkgShareTwo dbfv.RKGShareRoundTwo
rkgShareThree dbfv.RKGShareRoundThree
rtgShare dbfv.RTGShare
cksShare dbfv.CKSShare
ckgShare dbfv.CKGShare
rkgShareOne dbfv.RKGShare
rkgShareTwo dbfv.RKGShare
rtgShare dbfv.RTGShare
cksShare dbfv.CKSShare
input []uint64
}
@@ -128,7 +128,7 @@ func main() {
contextKeys.Add(colSk.Get(), pi.sk.Get(), colSk.Get()) //TODO: doc says "return"
pi.ckgShare = ckg.AllocateShares()
pi.rkgShareOne, pi.rkgShareTwo, pi.rkgShareThree = rkg.AllocateShares()
pi.rkgShareOne, pi.rkgShareTwo = rkg.AllocateShares()
pi.rtgShare = rtg.AllocateShare()
pi.cksShare = cks.AllocateShare()
@@ -167,7 +167,7 @@ func main() {
}
}, N)
rkgCombined1, rkgCombined2, rkgCombined3 := rkg.AllocateShares()
rkgCombined1, rkgCombined2 := rkg.AllocateShares()
elapsedRKGCloud = runTimed(func() {
for _, pi := range P {
@@ -177,29 +177,18 @@ func main() {
elapsedRKGParty += runTimedParty(func() {
for _, pi := range P {
rkg.GenShareRoundTwo(rkgCombined1, pi.sk.Get(), crp, pi.rkgShareTwo)
}
}, N)
elapsedRKGCloud += runTimed(func() {
for _, pi := range P {
rkg.AggregateShareRoundTwo(pi.rkgShareTwo, rkgCombined2, rkgCombined2)
}
})
elapsedRKGParty += runTimedParty(func() {
for _, pi := range P {
rkg.GenShareRoundThree(rkgCombined2, pi.rlkEphemSk, pi.sk.Get(), pi.rkgShareThree)
rkg.GenShareRoundTwo(rkgCombined1, pi.rlkEphemSk, pi.sk.Get(), crp, pi.rkgShareTwo)
}
}, N)
rlk := bfv.NewRelinKey(params, 1)
elapsedRKGCloud += runTimed(func() {
for _, pi := range P {
rkg.AggregateShareRoundThree(pi.rkgShareThree, rkgCombined3, rkgCombined3)
rkg.AggregateShareRoundTwo(pi.rkgShareTwo, rkgCombined2, rkgCombined2)
}
rkg.GenRelinearizationKey(rkgCombined2, rkgCombined3, rlk)
rkg.GenRelinearizationKey(rkgCombined1, rkgCombined2, rlk)
})
l.Printf("\tdone (cloud: %s, party: %s)\n", elapsedRKGCloud, elapsedRKGParty)
// 3) Collective rotation keys geneneration

View File

@@ -1,16 +1,17 @@
package main
import (
"github.com/ldsec/lattigo/bfv"
"github.com/ldsec/lattigo/dbfv"
"github.com/ldsec/lattigo/ring"
"github.com/ldsec/lattigo/utils"
"log"
"math/rand"
"os"
"strconv"
"sync"
"time"
"github.com/ldsec/lattigo/bfv"
"github.com/ldsec/lattigo/dbfv"
"github.com/ldsec/lattigo/ring"
"github.com/ldsec/lattigo/utils"
)
func check(err error) {
@@ -53,11 +54,10 @@ func main() {
sk *bfv.SecretKey
rlkEphemSk *ring.Poly
ckgShare dbfv.CKGShare
rkgShareOne dbfv.RKGShareRoundOne
rkgShareTwo dbfv.RKGShareRoundTwo
rkgShareThree dbfv.RKGShareRoundThree
pcksShare dbfv.PCKSShare
ckgShare dbfv.CKGShare
rkgShareOne dbfv.RKGShare
rkgShareTwo dbfv.RKGShare
pcksShare dbfv.PCKSShare
input []uint64
}
@@ -112,7 +112,7 @@ func main() {
contextKeys.Add(colSk.Get(), pi.sk.Get(), colSk.Get()) //TODO: doc says "return"
pi.ckgShare = ckg.AllocateShares()
pi.rkgShareOne, pi.rkgShareTwo, pi.rkgShareThree = rkg.AllocateShares()
pi.rkgShareOne, pi.rkgShareTwo = rkg.AllocateShares()
pi.pcksShare = pcks.AllocateShares()
P[i] = pi
@@ -146,7 +146,7 @@ func main() {
}
}, N)
rkgCombined1, rkgCombined2, rkgCombined3 := rkg.AllocateShares()
rkgCombined1, rkgCombined2 := rkg.AllocateShares()
elapsedRKGCloud = runTimed(func() {
for _, pi := range P {
@@ -156,29 +156,18 @@ func main() {
elapsedRKGParty += runTimedParty(func() {
for _, pi := range P {
rkg.GenShareRoundTwo(rkgCombined1, pi.sk.Get(), crp, pi.rkgShareTwo)
}
}, N)
elapsedRKGCloud += runTimed(func() {
for _, pi := range P {
rkg.AggregateShareRoundTwo(pi.rkgShareTwo, rkgCombined2, rkgCombined2)
}
})
elapsedRKGParty += runTimedParty(func() {
for _, pi := range P {
rkg.GenShareRoundThree(rkgCombined2, pi.rlkEphemSk, pi.sk.Get(), pi.rkgShareThree)
rkg.GenShareRoundTwo(rkgCombined1, pi.rlkEphemSk, pi.sk.Get(), crp, pi.rkgShareTwo)
}
}, N)
rlk := bfv.NewRelinKey(params, 1)
elapsedRKGCloud += runTimed(func() {
for _, pi := range P {
rkg.AggregateShareRoundThree(pi.rkgShareThree, rkgCombined3, rkgCombined3)
rkg.AggregateShareRoundTwo(pi.rkgShareTwo, rkgCombined2, rkgCombined2)
}
rkg.GenRelinearizationKey(rkgCombined2, rkgCombined3, rlk)
rkg.GenRelinearizationKey(rkgCombined1, rkgCombined2, rlk)
})
l.Printf("\tdone (cloud: %s, party: %s)\n",
elapsedRKGCloud, elapsedRKGParty)
l.Printf("\tSetup done (cloud: %s, party: %s)\n",