From 589fcbec4a3ef8efd09e6817ed5d479a4b5bb7fc Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sat, 28 Sep 2019 13:43:07 +0200 Subject: [PATCH] bfv/ckks keygen --- bfv/bfv_benchmark_test.go | 2 +- bfv/bfv_test.go | 8 ++-- bfv/encryptor.go | 22 +++++----- bfv/keygen.go | 58 ++++++++++++------------- ckks/ckks_benchmarks_test.go | 9 +--- ckks/ckks_test.go | 8 ++-- ckks/keygen.go | 22 +++++----- dbfv/dbfv_benchmark_test.go | 11 +---- dbfv/dbfv_test.go | 4 +- dckks/dckks_benchmark_test.go | 11 +---- dckks/dckks_test.go | 4 +- examples/bfv/examples_bfv.go | 5 +-- examples/ckks/examples_ckks.go | 4 +- ring/sampler.go | 77 +++++----------------------------- 14 files changed, 81 insertions(+), 164 deletions(-) diff --git a/bfv/bfv_benchmark_test.go b/bfv/bfv_benchmark_test.go index f1fe9760..c8230985 100755 --- a/bfv/bfv_benchmark_test.go +++ b/bfv/bfv_benchmark_test.go @@ -27,7 +27,7 @@ func BenchmarkBFVScheme(b *testing.B) { // Public Key Generation b.Run(fmt.Sprintf("params=%d/KeyGen", params.N), func(b *testing.B) { for i := 0; i < b.N; i++ { - sk, pk, err = kgen.NewKeyPair(1.0 / 3) + sk, pk = kgen.NewKeyPair() if err != nil { b.Error(err) } diff --git a/bfv/bfv_test.go b/bfv/bfv_test.go index 843bb8cd..892d9865 100644 --- a/bfv/bfv_test.go +++ b/bfv/bfv_test.go @@ -10,7 +10,7 @@ import ( type BFVTESTPARAMS struct { bfvcontext *BfvContext batchencoder *BatchEncoder - kgen *keygenerator + kgen *KeyGenerator sk *SecretKey pk *PublicKey encryptorSk *Encryptor @@ -52,9 +52,7 @@ func Test_BFV(t *testing.T) { log.Fatal(err) } - if bfvTest.sk, bfvTest.pk, err = bfvTest.kgen.NewKeyPair(1.0 / 3.0); err != nil { - log.Fatal(err) - } + bfvTest.sk, bfvTest.pk = bfvTest.kgen.NewKeyPair() if bfvTest.decryptor, err = bfvTest.bfvcontext.NewDecryptor(bfvTest.sk); err != nil { log.Fatal(err) @@ -1009,7 +1007,7 @@ func test_KeySwitching(bfvTest *BFVTESTPARAMS, bitDecomps []uint64, t *testing.T Sk := bfvTest.sk evaluator := bfvTest.evaluator - SkNew, _ := kgen.NewSecretKey(1.0 / 3) + SkNew := kgen.NewSecretKey() decryptor_SkNew, err := bfvContext.NewDecryptor(SkNew) if err != nil { diff --git a/bfv/encryptor.go b/bfv/encryptor.go index 8a0e8e87..1353f552 100644 --- a/bfv/encryptor.go +++ b/bfv/encryptor.go @@ -16,34 +16,32 @@ type Encryptor struct { // NewEncryptorFromPk creates a new Encryptor with the provided public-key. // This encryptor can be used to encrypt plaintexts, using the stored key. func (bfvcontext *BfvContext) NewEncryptorFromPk(pk *PublicKey) (*Encryptor, error) { - - if uint64(pk.pk[0].GetDegree()+pk.pk[1].GetDegree())>>1 != bfvcontext.n { - return nil, errors.New("error : pk ring degree doesn't match bfvcontext ring degree") - } - - return bfvcontext.newEncryptor(pk, nil), nil + return bfvcontext.newEncryptor(pk, nil) } // NewEncryptorFromSk creates a new Encryptor with the provided secret-key. // This encryptor can be used to encrypt plaintexts, using the stored key. func (bfvcontext *BfvContext) NewEncryptorFromSk(sk *SecretKey) (*Encryptor, error) { + return bfvcontext.newEncryptor(nil, sk) +} + +func (bfvcontext *BfvContext) newEncryptor(pk *PublicKey, sk *SecretKey) (encryptor *Encryptor, err error) { + + if pk != nil && uint64(pk.pk[0].GetDegree()+pk.pk[1].GetDegree())>>1 != bfvcontext.n { + return nil, errors.New("error : pk ring degree doesn't match bfvcontext ring degree") + } if sk != nil && uint64(sk.sk.GetDegree()) != bfvcontext.n { return nil, errors.New("error : sk ring degree doesn't match bfvcontext ring degree") } - return bfvcontext.newEncryptor(nil, sk), nil -} - -func (bfvcontext *BfvContext) newEncryptor(pk *PublicKey, sk *SecretKey) (encryptor *Encryptor) { - encryptor = new(Encryptor) encryptor.bfvcontext = bfvcontext encryptor.pk = pk encryptor.sk = sk encryptor.polypool = bfvcontext.contextQ.NewPoly() - return encryptor + return encryptor, nil } // EncryptFromPkNew encrypts the input plaintext using the stored public-key and returns diff --git a/bfv/keygen.go b/bfv/keygen.go index c791ae67..dcab44ce 100644 --- a/bfv/keygen.go +++ b/bfv/keygen.go @@ -8,9 +8,9 @@ import ( "math/bits" ) -// Keygenerator is a structure that stores the elements required to create new keys, +// KeyGenerator is a structure that stores the elements required to create new keys, // as well as a small memory pool for intermediate values. -type keygenerator struct { +type KeyGenerator struct { bfvcontext *BfvContext context *ring.Context polypool *ring.Poly @@ -46,18 +46,24 @@ type SwitchingKey struct { evakey [][][2]*ring.Poly } -// NewKeyGenerator creates a new keygenerator, from which the secret and public keys, as well as the evaluation, +// NewKeyGenerator creates a new KeyGenerator, from which the secret and public keys, as well as the evaluation, // rotation and switching keys can be generated. -func (bfvcontext *BfvContext) NewKeyGenerator() (keygen *keygenerator) { - keygen = new(keygenerator) +func (bfvcontext *BfvContext) NewKeyGenerator() (keygen *KeyGenerator) { + keygen = new(KeyGenerator) keygen.bfvcontext = bfvcontext keygen.context = bfvcontext.contextQ keygen.polypool = keygen.context.NewPoly() return } -// Newsecretkey creates a new SecretKey with uniform distribution in [-1, 0, 1]. -func (keygen *keygenerator) NewSecretKey(p float64) (sk *SecretKey, err error) { +// Newsecretkey creates a new SecretKey with the distribution [1/3, 1/3, 1/3] +func (keygen *KeyGenerator) NewSecretKey() (sk *SecretKey) { + sk, _ = keygen.NewSecretkeyWithDistrib(1.0 / 3) + return sk +} + +// Newsecretkey creates a new SecretKey with the distribution [(p-1)/2, p, (p-1)/2] +func (keygen *KeyGenerator) NewSecretkeyWithDistrib(p float64) (sk *SecretKey, err error) { sk = new(SecretKey) if sk.sk, err = keygen.bfvcontext.ternarySampler.SampleMontgomeryNTTNew(p); err != nil { @@ -68,7 +74,7 @@ func (keygen *keygenerator) NewSecretKey(p float64) (sk *SecretKey, err error) { } // NewSecretKeyEmpty creates a new SecretKey with all coeffcients set to zero, ready to received a marshaled SecretKey. -func (keygen *keygenerator) NewSecretKeyEmpty() *SecretKey { +func (keygen *KeyGenerator) NewSecretKeyEmpty() *SecretKey { sk := new(SecretKey) sk.sk = keygen.context.NewPoly() return sk @@ -84,8 +90,8 @@ func (sk *SecretKey) Set(poly *ring.Poly) { sk.sk = poly.CopyNew() } -// check_sk checks if the input secret-key complies with the keygenerator context. -func (keygen *keygenerator) check_sk(sk_output *SecretKey) (err error) { +// check_sk checks if the input secret-key complies with the KeyGenerator context. +func (keygen *KeyGenerator) check_sk(sk_output *SecretKey) (err error) { if sk_output.Get().GetDegree() != int(keygen.context.N) { return errors.New("error : pol degree sk != bfvcontext.n") @@ -143,7 +149,7 @@ func (sk *SecretKey) UnMarshalBinary(data []byte) error { } // Newpublickey generates a new publickkey from the provided secret-key -func (keygen *keygenerator) NewPublicKey(sk *SecretKey) (pk *PublicKey, err error) { +func (keygen *KeyGenerator) NewPublicKey(sk *SecretKey) (pk *PublicKey, err error) { if err = keygen.check_sk(sk); err != nil { return nil, err @@ -162,7 +168,7 @@ func (keygen *keygenerator) NewPublicKey(sk *SecretKey) (pk *PublicKey, err erro return pk, nil } -func (keygen *keygenerator) NewPublicKeyEmpty() (pk *PublicKey) { +func (keygen *KeyGenerator) NewPublicKeyEmpty() (pk *PublicKey) { pk = new(PublicKey) pk.pk[0] = keygen.context.NewPoly() @@ -232,21 +238,17 @@ func (pk *PublicKey) UnMarshalBinary(data []byte) error { return nil } -// NewKeyPair generates a new (secret-key, public-key) pair. -func (keygen *keygenerator) NewKeyPair(p float64) (sk *SecretKey, pk *PublicKey, err error) { - if sk, err = keygen.NewSecretKey(p); err != nil { - return nil, nil, err - } - if pk, err = keygen.NewPublicKey(sk); err != nil { - return nil, nil, err - } +// NewKeyPair generates a new secret-key with distribution [1/3, 1/3, 1/3] and a corresponding public-key. +func (keygen *KeyGenerator) NewKeyPair() (sk *SecretKey, pk *PublicKey) { + sk = keygen.NewSecretKey() + pk, _ = keygen.NewPublicKey(sk) return } // NewRelinKey generates a new evaluation key from the provided secret-key. It will be used to relinearize a ciphertext (encrypted under a public-key generated from the provided secret-key) // of degree > 1 to a ciphertext of degree 1. Max degree is the maximum degree of the ciphertext allowed to relinearize. Bitdecomp is the power of two binary decomposition of the key. // A higher bigdecomp will induce smaller keys, faster key-switching, but at the cost of more noise. -func (keygen *keygenerator) NewRelinKey(sk *SecretKey, maxDegree, bitDecomp uint64) (newEvakey *EvaluationKey, err error) { +func (keygen *KeyGenerator) NewRelinKey(sk *SecretKey, maxDegree, bitDecomp uint64) (newEvakey *EvaluationKey, err error) { newEvakey = new(EvaluationKey) newEvakey.evakey = make([]*SwitchingKey, maxDegree) @@ -260,7 +262,7 @@ func (keygen *keygenerator) NewRelinKey(sk *SecretKey, maxDegree, bitDecomp uint return newEvakey, nil } -func (keygen *keygenerator) NewRelinKeyEmpty(maxDegree, bitDecomp uint64) (evakey *EvaluationKey) { +func (keygen *KeyGenerator) NewRelinKeyEmpty(maxDegree, bitDecomp uint64) (evakey *EvaluationKey) { evakey = new(EvaluationKey) if bitDecomp > keygen.bfvcontext.maxBit || bitDecomp == 0 { @@ -321,7 +323,7 @@ func (newevakey *EvaluationKey) SetRelinKeys(rlk [][][][2]*ring.Poly, bitDecomp // Newswitchintkey generates a new key-switching key, that will allow to re-encrypt under the output-key a ciphertext encrypted under the input-key. Bitdecomp // is the power of two binary decomposition of the key. A higher bigdecomp will induce smaller keys, faster key-switching, but at the cost of more noise. -func (keygen *keygenerator) NewSwitchingKey(sk_input, sk_output *SecretKey, bitDecomp uint64) (newevakey *SwitchingKey, err error) { +func (keygen *KeyGenerator) NewSwitchingKey(sk_input, sk_output *SecretKey, bitDecomp uint64) (newevakey *SwitchingKey, err error) { if err = keygen.check_sk(sk_input); err != nil { return nil, err @@ -338,7 +340,7 @@ func (keygen *keygenerator) NewSwitchingKey(sk_input, sk_output *SecretKey, bitD return } -func (keygen *keygenerator) NewSwitchingKeyEmpty(bitDecomp uint64) (evakey *SwitchingKey) { +func (keygen *KeyGenerator) NewSwitchingKeyEmpty(bitDecomp uint64) (evakey *SwitchingKey) { evakey = new(SwitchingKey) if bitDecomp > keygen.bfvcontext.maxBit || bitDecomp == 0 { @@ -372,7 +374,7 @@ func (keygen *keygenerator) NewSwitchingKeyEmpty(bitDecomp uint64) (evakey *Swit // Newrotationkeys generates a new struct of rotationkeys storing the keys for the specified rotations. The provided secret-key must be the secret-key used to generate the public-key under // which the ciphertexts to rotate are encrypted under. Bitdecomp is the power of two binary decomposition of the key. A higher bigdecomp will induce smaller keys, faster key-switching, // but at the cost of more noise. rotLeft and rotRight must be a slice of uint64 rotations, row is a boolean value indicating if the key for the row rotation must be generated. -func (keygen *keygenerator) NewRotationKeys(sk *SecretKey, bitDecomp uint64, rotLeft []uint64, rotRight []uint64, row bool) (rotKey *RotationKeys, err error) { +func (keygen *KeyGenerator) NewRotationKeys(sk *SecretKey, bitDecomp uint64, rotLeft []uint64, rotRight []uint64, row bool) (rotKey *RotationKeys, err error) { if err = keygen.check_sk(sk); err != nil { return nil, err @@ -408,7 +410,7 @@ func (keygen *keygenerator) NewRotationKeys(sk *SecretKey, bitDecomp uint64, rot } -func (keygen *keygenerator) NewRotationKeysEmpty() (rotKey *RotationKeys) { +func (keygen *KeyGenerator) NewRotationKeysEmpty() (rotKey *RotationKeys) { rotKey = new(RotationKeys) rotKey.bfvcontext = keygen.bfvcontext @@ -420,7 +422,7 @@ func (keygen *keygenerator) NewRotationKeysEmpty() (rotKey *RotationKeys) { // Newrotationkeys generates a new struct of rotationkeys storing the keys of all the left and right powers of two rotations. The provided secret-key must be the secret-key used to generate the public-key under // which the ciphertexts to rotate are encrypted under. rows is a boolean value indicatig if the keys for the row rotation have to be generated. Bitdecomp is the power of two binary decomposition of the key. // A higher bigdecomp will induce smaller keys, faster key-switching, but at the cost of more noise. -func (keygen *keygenerator) NewRotationKeysPow2(sk *SecretKey, bitDecomp uint64, row bool) (rotKey *RotationKeys, err error) { +func (keygen *KeyGenerator) NewRotationKeysPow2(sk *SecretKey, bitDecomp uint64, row bool) (rotKey *RotationKeys, err error) { if err = keygen.check_sk(sk); err != nil { return nil, err @@ -447,7 +449,7 @@ func (keygen *keygenerator) NewRotationKeysPow2(sk *SecretKey, bitDecomp uint64, } // genrotkey is a methode used in the rotation-keys generation. -func genrotkey(keygen *keygenerator, sk *ring.Poly, gen, bitDecomp uint64) (switchkey *SwitchingKey) { +func genrotkey(keygen *KeyGenerator, sk *ring.Poly, gen, bitDecomp uint64) (switchkey *SwitchingKey) { ring.PermuteNTT(sk, gen, keygen.polypool) keygen.context.Sub(keygen.polypool, sk, keygen.polypool) diff --git a/ckks/ckks_benchmarks_test.go b/ckks/ckks_benchmarks_test.go index 62ab6472..382ca3e9 100755 --- a/ckks/ckks_benchmarks_test.go +++ b/ckks/ckks_benchmarks_test.go @@ -54,9 +54,7 @@ func BenchmarkCKKSScheme(b *testing.B) { kgen = ckkscontext.NewKeyGenerator() - if sk, pk, err = kgen.NewKeyPair(1.0 / 3); err != nil { - b.Error(err) - } + sk, pk = kgen.NewKeyPair() if rlk, err = kgen.NewRelinKey(sk, bdc); err != nil { b.Error(err) @@ -113,10 +111,7 @@ func BenchmarkCKKSScheme(b *testing.B) { // Key Pair Generation b.Run(fmt.Sprintf("logN=%d/logQ=%d/levels=%d/decomp=%d/sigma=%.2f/KeyPairGen", logN, ckkscontext.LogQ(), levels, bdc, sigma), func(b *testing.B) { for i := 0; i < b.N; i++ { - sk, pk, err = kgen.NewKeyPair(1.0 / 3) - if err != nil { - b.Error(err) - } + sk, pk = kgen.NewKeyPair() } }) diff --git a/ckks/ckks_test.go b/ckks/ckks_test.go index 99a55d50..43cff1ae 100644 --- a/ckks/ckks_test.go +++ b/ckks/ckks_test.go @@ -52,9 +52,7 @@ func Test_CKKS(t *testing.T) { ckksTest.kgen = ckksTest.ckkscontext.NewKeyGenerator() - if ckksTest.sk, ckksTest.pk, err = ckksTest.kgen.NewKeyPair(1.0 / 3); err != nil { - t.Error(err) - } + ckksTest.sk, ckksTest.pk = ckksTest.kgen.NewKeyPair() log.Printf("Generating relinearization keys") if ckksTest.rlk, err = ckksTest.kgen.NewRelinKey(ckksTest.sk, ckksTest.ckkscontext.Scale()); err != nil { @@ -1240,7 +1238,7 @@ func test_SwitchKeys(params *CKKSTESTPARAMS, t *testing.T) { t.Error(err) } - sk2, _ := params.kgen.NewSecretKey(1.0 / 3) + sk2 := params.kgen.NewSecretKey() switchingkeys, err := params.kgen.NewSwitchingKey(params.sk, sk2, 10) if err != nil { @@ -1546,7 +1544,7 @@ func test_MarshalSwitchingKey(params *CKKSTESTPARAMS, t *testing.T) { bitDecomp := uint64(15) - s1, _ := params.kgen.NewSecretKey(1.0 / 3) + s1 := params.kgen.NewSecretKey() switchkey, _ := params.kgen.NewSwitchingKey(params.sk, s1, bitDecomp) diff --git a/ckks/keygen.go b/ckks/keygen.go index 8379862a..6b3ef000 100644 --- a/ckks/keygen.go +++ b/ckks/keygen.go @@ -70,8 +70,14 @@ func (keygen *KeyGenerator) check_sk(sk_output *SecretKey) error { return nil } -// NewSecretKey generates a new secret key. -func (keygen *KeyGenerator) NewSecretKey(p float64) (sk *SecretKey, err error) { +// NewSecretKey generates a new secret key with the distribution [1/3, 1/3, 1/3]. +func (keygen *KeyGenerator) NewSecretKey() (sk *SecretKey) { + sk, _ = keygen.NewSecretKeyWithDistrib(1.0 / 3) + return sk +} + +// NewSecretKey generates a new secret key with the distribution [(p-1)/2, p, (p-1)/2]. +func (keygen *KeyGenerator) NewSecretKeyWithDistrib(p float64) (sk *SecretKey, err error) { sk = new(SecretKey) if sk.sk, err = keygen.ckkscontext.ternarySampler.SampleMontgomeryNTTNew(p); err != nil { return nil, err @@ -135,14 +141,10 @@ func (pk *PublicKey) Set(poly [2]*ring.Poly) { pk.pk[1] = poly[1].CopyNew() } -// NewKeyPair generates a new secretkey and a corresponding public key. -func (keygen *KeyGenerator) NewKeyPair(p float64) (sk *SecretKey, pk *PublicKey, err error) { - if sk, err = keygen.NewSecretKey(p); err != nil { - return nil, nil, err - } - if pk, err = keygen.NewPublicKey(sk); err != nil { - return nil, nil, err - } +// NewKeyPair generates a new secretkey with distribution [1/3, 1/3, 1/3] and a corresponding public key. +func (keygen *KeyGenerator) NewKeyPair() (sk *SecretKey, pk *PublicKey) { + sk = keygen.NewSecretKey() + pk, _ = keygen.NewPublicKey(sk) return } diff --git a/dbfv/dbfv_benchmark_test.go b/dbfv/dbfv_benchmark_test.go index 97a20e6a..86fb52b3 100644 --- a/dbfv/dbfv_benchmark_test.go +++ b/dbfv/dbfv_benchmark_test.go @@ -25,15 +25,8 @@ func Benchmark_DBFVScheme(b *testing.B) { kgen := bfvContext.NewKeyGenerator() - sk0, pk0, err := kgen.NewKeyPair(1.0 / 3) - if err != nil { - log.Fatal(err) - } - - sk1, pk1, err := kgen.NewKeyPair(1.0 / 3) - if err != nil { - log.Fatal(err) - } + sk0, pk0 := kgen.NewKeyPair() + sk1, pk1 := kgen.NewKeyPair() crpGenerator, err := NewCRPGenerator(nil, context) if err != nil { diff --git a/dbfv/dbfv_test.go b/dbfv/dbfv_test.go index 040f1867..f1f745b1 100644 --- a/dbfv/dbfv_test.go +++ b/dbfv/dbfv_test.go @@ -61,8 +61,8 @@ func Test_DBFVScheme(t *testing.T) { tmp1 := context.NewPoly() for i := 0; i < parties; i++ { - sk0_shards[i], _ = kgen.NewSecretKey(1.0 / 3) - sk1_shards[i], _ = kgen.NewSecretKey(1.0 / 3) + sk0_shards[i] = kgen.NewSecretKey() + sk1_shards[i] = kgen.NewSecretKey() context.Add(tmp0, sk0_shards[i].Get(), tmp0) context.Add(tmp1, sk1_shards[i].Get(), tmp1) } diff --git a/dckks/dckks_benchmark_test.go b/dckks/dckks_benchmark_test.go index 830dbd43..7d1c35ef 100644 --- a/dckks/dckks_benchmark_test.go +++ b/dckks/dckks_benchmark_test.go @@ -54,15 +54,8 @@ func Benchmark_DCKKSScheme(b *testing.B) { kgen := benchcontext.ckkscontext.NewKeyGenerator() - benchcontext.sk0, benchcontext.pk0, err = kgen.NewKeyPair(1.0 / 3) - if err != nil { - log.Fatal(err) - } - - benchcontext.sk1, benchcontext.pk1, err = kgen.NewKeyPair(1.0 / 3) - if err != nil { - log.Fatal(err) - } + benchcontext.sk0, benchcontext.pk0 = kgen.NewKeyPair() + benchcontext.sk1, benchcontext.pk1 = kgen.NewKeyPair() benchcontext.cprng, err = NewCRPGenerator(nil, benchcontext.ckkscontext.ContextKeys()) if err != nil { diff --git a/dckks/dckks_test.go b/dckks/dckks_test.go index 83cee300..184249ee 100644 --- a/dckks/dckks_test.go +++ b/dckks/dckks_test.go @@ -78,8 +78,8 @@ func Test_DBFVScheme(t *testing.T) { tmp1 := context.NewPoly() for i := uint64(0); i < parties; i++ { - sk0_shards[i], _ = kgen.NewSecretKey(1.0 / 3) - sk1_shards[i], _ = kgen.NewSecretKey(1.0 / 3) + sk0_shards[i] = kgen.NewSecretKey() + sk1_shards[i] = kgen.NewSecretKey() context.Add(tmp0, sk0_shards[i].Get(), tmp0) context.Add(tmp1, sk1_shards[i].Get(), tmp1) } diff --git a/examples/bfv/examples_bfv.go b/examples/bfv/examples_bfv.go index c8fff463..e859c432 100644 --- a/examples/bfv/examples_bfv.go +++ b/examples/bfv/examples_bfv.go @@ -31,10 +31,7 @@ func ObliviousRiding() { kgen := bfvContext.NewKeyGenerator() - Sk, Pk, err := kgen.NewKeyPair(1.0 / 3) - if err != nil { - log.Fatal(err) - } + Sk, Pk := kgen.NewKeyPair() Decryptor, err := bfvContext.NewDecryptor(Sk) if err != nil { diff --git a/examples/ckks/examples_ckks.go b/examples/ckks/examples_ckks.go index a34b8ef6..71a3480d 100644 --- a/examples/ckks/examples_ckks.go +++ b/examples/ckks/examples_ckks.go @@ -36,9 +36,7 @@ func chebyshevinterpolation() { // Keys var sk *ckks.SecretKey var pk *ckks.PublicKey - if sk, pk, err = kgen.NewKeyPair(1.0 / 3); err != nil { - log.Fatal(err) - } + sk, pk = kgen.NewKeyPair() // Relinearization key var rlk *ckks.EvaluationKey diff --git a/ring/sampler.go b/ring/sampler.go index bbe8c434..4a56a856 100644 --- a/ring/sampler.go +++ b/ring/sampler.go @@ -251,7 +251,7 @@ func computeMatrixTernary(p float64) (M [][]uint8) { } // SampleMontgomeryNew samples coefficients with ternary distribution in montgomery form on the target polynomial. -func sampleOnPol(context *Context, samplerMatrix [][]uint64, p float64, pol *Poly) (err error) { +func (sampler *TernarySampler) sample(samplerMatrix [][]uint64, p float64, pol *Poly) (err error) { if p == 0 { return errors.New("cannot sample -> p = 0") @@ -263,8 +263,8 @@ func sampleOnPol(context *Context, samplerMatrix [][]uint64, p float64, pol *Pol if p == 0.5 { - randomBytesCoeffs := make([]byte, context.N>>3) - randomBytesSign := make([]byte, context.N>>3) + randomBytesCoeffs := make([]byte, sampler.context.N>>3) + randomBytesSign := make([]byte, sampler.context.N>>3) if _, err := rand.Read(randomBytesCoeffs); err != nil { panic("crypto rand error") @@ -274,13 +274,13 @@ func sampleOnPol(context *Context, samplerMatrix [][]uint64, p float64, pol *Pol panic("crypto rand error") } - for i := uint64(0); i < context.N; i++ { + for i := uint64(0); i < sampler.context.N; i++ { coeff = uint64(uint8(randomBytesCoeffs[i>>3])>>(i&7)) & 1 sign = uint64(uint8(randomBytesSign[i>>3])>>(i&7)) & 1 index = (coeff & (sign ^ 1)) | ((sign & coeff) << 1) - for j := range context.Modulus { + for j := range sampler.context.Modulus { pol.Coeffs[j][i] = samplerMatrix[j][index] //(coeff & (sign^1)) | (qi - 1) * (sign & coeff) } } @@ -297,13 +297,13 @@ func sampleOnPol(context *Context, samplerMatrix [][]uint64, p float64, pol *Pol panic("crypto rand error") } - for i := uint64(0); i < context.N; i++ { + for i := uint64(0); i < sampler.context.N; i++ { coeff, sign, randomBytes, pointer = kysampling(matrix, randomBytes, pointer) index = (coeff & (sign ^ 1)) | ((sign & coeff) << 1) - for j := range context.Modulus { + for j := range sampler.context.Modulus { pol.Coeffs[j][i] = samplerMatrix[j][index] //(coeff & (sign^1)) | (qi - 1) * (sign & coeff) } } @@ -312,76 +312,19 @@ func sampleOnPol(context *Context, samplerMatrix [][]uint64, p float64, pol *Pol return nil } -func sampleOnArray(values []uint64, p float64) (err error) { - - if p == 0 { - return errors.New("cannot sample -> p = 0") - } - - var coeff uint64 - var sign uint64 - - if p == 0.5 { - - randomBytesCoeffs := make([]byte, len(values)>>3) - randomBytesSign := make([]byte, len(values)>>3) - - if _, err := rand.Read(randomBytesCoeffs); err != nil { - panic("crypto rand error") - } - - if _, err := rand.Read(randomBytesSign); err != nil { - panic("crypto rand error") - } - - for i := uint64(0); i < uint64(len(values)); i++ { - - coeff = uint64(uint8(randomBytesCoeffs[i>>3])>>(i&7)) & 1 - sign = uint64(uint8(randomBytesSign[i>>3])>>(i&7)) & 1 - - values[i] = (coeff & (sign ^ 1)) | ((sign & coeff) << 1) - } - - } else { - - matrix := computeMatrixTernary(p) - - randomBytes := make([]byte, 8) - - pointer := uint8(0) - - if _, err := rand.Read(randomBytes); err != nil { - panic("crypto rand error") - } - - for i := 0; i < len(values); i++ { - - coeff, sign, randomBytes, pointer = kysampling(matrix, randomBytes, pointer) - - values[i] = (coeff & (sign ^ 1)) | ((sign & coeff) << 1) - } - } - - return nil -} - -func (sampler *TernarySampler) SampleOnArray(values []uint64, p float64) (err error) { - return sampleOnArray(values, p) -} - func (sampler *TernarySampler) SampleUniform(pol *Poly) { - _ = sampleOnPol(sampler.context, sampler.Matrix, 1.0/3.0, pol) + _ = sampler.sample(sampler.Matrix, 1.0/3.0, pol) } func (sampler *TernarySampler) Sample(p float64, pol *Poly) (err error) { - if err = sampleOnPol(sampler.context, sampler.Matrix, p, pol); err != nil { + if err = sampler.sample(sampler.Matrix, p, pol); err != nil { return err } return nil } func (sampler *TernarySampler) SampleMontgomery(p float64, pol *Poly) (err error) { - if err = sampleOnPol(sampler.context, sampler.MatrixMontgomery, p, pol); err != nil { + if err = sampler.sample(sampler.MatrixMontgomery, p, pol); err != nil { return err } return nil