From 9ebcc9aa0f477202bc24ae8b70392b56a242d7b9 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Fri, 20 Jun 2025 16:36:48 +0200 Subject: [PATCH] style: apply feedback from pr --- circuits/ckks/bootstrapping/evaluator.go | 10 ++----- .../evaluator_benchmarks_test.go | 1 + .../common/lintrans/lintrans_evaluator.go | 9 ++---- core/rlwe/inner_sum.go | 9 ++---- core/rlwe/pool.go | 28 ++++++++++++++++++ core/rlwe/rlwe_test.go | 29 +++++-------------- .../ckks_bootstrapping/basics/main.go | 16 +++++----- examples/singleparty/tutorials/ckks/main.go | 2 +- ring/sampler_gaussian.go | 13 ++++----- schemes/ckks/evaluator.go | 11 ++----- utils/sampling/prng.go | 6 ++-- 11 files changed, 64 insertions(+), 70 deletions(-) diff --git a/circuits/ckks/bootstrapping/evaluator.go b/circuits/ckks/bootstrapping/evaluator.go index c48efadf..989810a8 100644 --- a/circuits/ckks/bootstrapping/evaluator.go +++ b/circuits/ckks/bootstrapping/evaluator.go @@ -11,7 +11,6 @@ import ( "github.com/tuneinsight/lattigo/v6/circuits/ckks/polynomial" "github.com/tuneinsight/lattigo/v6/core/rlwe" "github.com/tuneinsight/lattigo/v6/ring" - "github.com/tuneinsight/lattigo/v6/ring/ringqp" "github.com/tuneinsight/lattigo/v6/schemes/ckks" "github.com/tuneinsight/lattigo/v6/utils" "github.com/tuneinsight/lattigo/v6/utils/bignum" @@ -671,13 +670,8 @@ func (eval Evaluator) ModUp(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext, err ks := eval.Evaluator.Evaluator - size := eval.ResidualParameters.BaseRNSDecompositionVectorSize(eval.BootstrappingParameters.MaxLevelQ(), 0) - buffDecompQP := make([]ringqp.Poly, size) - for i := 0; i < size; i++ { - buff := poolQP.GetBuffPolyQP() - defer poolQP.RecycleBuffPolyQP(buff) - buffDecompQP[i] = *buff - } + buffDecompQP := poolQP.GetBuffDecompQP(eval.ResidualParameters.Parameters, eval.BootstrappingParameters.MaxLevelQ(), 0) + defer eval.pool.RecycleBuffDecompQP(buffDecompQP) // ModUp q->QP for ctIn[1] centered around q for j := 0; j < N; j++ { diff --git a/circuits/ckks/bootstrapping/evaluator_benchmarks_test.go b/circuits/ckks/bootstrapping/evaluator_benchmarks_test.go index cd4193ba..8c9d52c2 100644 --- a/circuits/ckks/bootstrapping/evaluator_benchmarks_test.go +++ b/circuits/ckks/bootstrapping/evaluator_benchmarks_test.go @@ -26,6 +26,7 @@ func BenchmarkConcurrentBootstrap(b *testing.B) { evk, _, err := btpParams.GenEvaluationKeys(sk) require.NoError(b, err) + // Benchmark parallel bootstrapping b.Run(ParamsToString(params, btpParams.LogMaxDimensions().Cols, "Bootstrap/"), func(b *testing.B) { var err error eval, err := NewEvaluator(btpParams, evk) diff --git a/circuits/common/lintrans/lintrans_evaluator.go b/circuits/common/lintrans/lintrans_evaluator.go index 4100294b..1845246d 100644 --- a/circuits/common/lintrans/lintrans_evaluator.go +++ b/circuits/common/lintrans/lintrans_evaluator.go @@ -48,13 +48,8 @@ func (eval Evaluator) EvaluateMany(ctIn *rlwe.Ciphertext, linearTransformations levelQ = utils.Min(levelQ, ctIn.Level()) poolQP := eval.pool.AtLevel(levelQ, levelP) - baseRNSDecompositionVectorSize := eval.Evaluator.GetRLWEParameters().BaseRNSDecompositionVectorSize(levelQ, levelP) - buffDecompQP := make([]ringqp.Poly, baseRNSDecompositionVectorSize) - for i := 0; i < len(buffDecompQP); i++ { - buff := poolQP.GetBuffPolyQP() - defer poolQP.RecycleBuffPolyQP(buff) - buffDecompQP[i] = *buff - } + buffDecompQP := poolQP.GetBuffDecompQP(*eval.GetRLWEParameters(), levelQ, levelP) + defer eval.pool.RecycleBuffDecompQP(buffDecompQP) eval.DecomposeNTT(levelQ, levelP, levelP+1, ctIn.Value[1], ctIn.IsNTT, buffDecompQP) diff --git a/core/rlwe/inner_sum.go b/core/rlwe/inner_sum.go index 913f61b9..4e5ab372 100644 --- a/core/rlwe/inner_sum.go +++ b/core/rlwe/inner_sum.go @@ -211,13 +211,8 @@ func (eval Evaluator) PartialTracesSum(ctIn *Ciphertext, offset, n int, opOut *C cQ.MetaData = ctInNTT.MetaData - baseRNSDecompositionVectorSize := eval.params.BaseRNSDecompositionVectorSize(levelQ, levelP) - buffDecompQP := make([]ringqp.Poly, baseRNSDecompositionVectorSize) - for i := 0; i < len(buffDecompQP); i++ { - buff := poolQP.GetBuffPolyQP() - defer poolQP.RecycleBuffPolyQP(buff) - buffDecompQP[i] = *buff - } + buffDecompQP := poolQP.GetBuffDecompQP(eval.params, levelQ, levelP) + defer eval.pool.RecycleBuffDecompQP(buffDecompQP) state := false copy := true diff --git a/core/rlwe/pool.go b/core/rlwe/pool.go index 7338d9bb..ee3737c3 100644 --- a/core/rlwe/pool.go +++ b/core/rlwe/pool.go @@ -29,6 +29,15 @@ func NewPool(rqp *ringqp.Ring, pools ...structs.BufferPool[*[]uint64]) *Pool { return &Pool{ringqpPool} } +// AtLevel returns a new pool from which objects from polynomials at the given levels can be drawn. +// The method accepts up to two levels: +// Zero level: the objects returned are built from polynomials at level 0. +// One level: the objects returned are built from polynomials in RingQ (resp. RingP) at the given level (resp. level 0). +// Two levels: the objects returned are built from polynomials in RingQ (resp. RingP) at levels[0] (resp. levels[1]). +func (pool Pool) AtLevel(levels ...int) *Pool { + return &Pool{pool.Pool.AtLevel(levels...)} +} + // GetBuffCt returns a ciphertext that can be used as a buffer for intermediate computations. // After use, the ciphertext should be recycled with [Pool.RecycleBuffCt]. // The optional dimensions specify the degree and level of the ciphertext (default to 2, eval.params.ringQ.Level()). @@ -99,3 +108,22 @@ func (pool *Pool) GetBuffPt(level ...int) *Plaintext { func (pool *Pool) RecycleBuffPt(pt *Plaintext) { pool.RecycleBuffPoly(&pt.Value) } + +// GetBuffDecompQP returns buffers of polys to be used for RNS decomposition. +// After use, the array of buffers must be recycled with [Pool.RecycleBuffDecompQP]. +func (pool *Pool) GetBuffDecompQP(params Parameters, levelQ, levelP int) []ringqp.Poly { + size := params.BaseRNSDecompositionVectorSize(levelQ, levelP) + buffDecompQP := make([]ringqp.Poly, size) + for i := 0; i < size; i++ { + poly := pool.GetBuffPolyQP() + buffDecompQP[i] = *poly + } + return buffDecompQP +} + +// RecycleBuffDecompQP recycles a temporary array of polys used for decomposition. +func (pool *Pool) RecycleBuffDecompQP(decomp []ringqp.Poly) { + for _, poly := range decomp { + pool.RecycleBuffPolyQP(&poly) + } +} diff --git a/core/rlwe/rlwe_test.go b/core/rlwe/rlwe_test.go index 80c1f651..b2afe736 100644 --- a/core/rlwe/rlwe_test.go +++ b/core/rlwe/rlwe_test.go @@ -729,13 +729,9 @@ func testGadgetProduct(tc *TestContext, levelQ, bpw2 int, t *testing.T) { // Setup temporary buffer for decomposition poolQP := tc.eval.pool.AtLevel(levelQ, levelP) - size := params.BaseRNSDecompositionVectorSize(levelQ, 0) - buffDecompQP := make([]ringqp.Poly, size) - for i := 0; i < size; i++ { - poly := poolQP.GetBuffPolyQP() - defer poolQP.RecycleBuffPolyQP(poly) - buffDecompQP[i] = *poly - } + buffDecompQP := poolQP.GetBuffDecompQP(params, levelQ, levelP) + defer poolQP.RecycleBuffDecompQP(buffDecompQP) + if bpw2 != 0 { t.Skip("method is unsupported for BaseTwoDecomposition != 0") } @@ -966,13 +962,9 @@ func testAutomorphism(tc *TestContext, level, bpw2 int, t *testing.T) { // Setup temporary buffer for decomposition poolQP := tc.eval.pool.AtLevel(level, params.MaxLevelP()) - size := params.BaseRNSDecompositionVectorSize(level, 0) - buffDecompQP := make([]ringqp.Poly, size) - for i := 0; i < size; i++ { - poly := poolQP.GetBuffPolyQP() - defer poolQP.RecycleBuffPolyQP(poly) - buffDecompQP[i] = *poly - } + buffDecompQP := poolQP.GetBuffDecompQP(params, level, 0) + defer poolQP.RecycleBuffDecompQP(buffDecompQP) + // Generate a plaintext with values up to 2^30 pt := genPlaintext(params, level, 1<<30) @@ -1025,13 +1017,8 @@ func testAutomorphism(tc *TestContext, level, bpw2 int, t *testing.T) { // Setup temporary buffer for decomposition poolQP := tc.eval.pool.AtLevel(level, params.MaxLevelP()) - size := params.BaseRNSDecompositionVectorSize(level, 0) - buffDecompQP := make([]ringqp.Poly, size) - for i := 0; i < size; i++ { - poly := poolQP.GetBuffPolyQP() - defer poolQP.RecycleBuffPolyQP(poly) - buffDecompQP[i] = *poly - } + buffDecompQP := poolQP.GetBuffDecompQP(params, level, 0) + defer poolQP.RecycleBuffDecompQP(buffDecompQP) // Generate a plaintext with values up to 2^30 pt := genPlaintext(params, level, 1<<30) diff --git a/examples/singleparty/ckks_bootstrapping/basics/main.go b/examples/singleparty/ckks_bootstrapping/basics/main.go index a8b0b8e6..aafe86fd 100644 --- a/examples/singleparty/ckks_bootstrapping/basics/main.go +++ b/examples/singleparty/ckks_bootstrapping/basics/main.go @@ -9,7 +9,6 @@ import ( "flag" "fmt" "math" - "sync" "github.com/tuneinsight/lattigo/v6/circuits/ckks/bootstrapping" "github.com/tuneinsight/lattigo/v6/core/rlwe" @@ -194,22 +193,23 @@ func main() { // To equalize the scale, the function evaluator.SetScale(ciphertext, parameters.DefaultScale()) can be used at the expense of one level. // If the ciphertext is is at level one or greater when given to the bootstrapper, this equalization is automatically done. // Here we bootstrap two ciphertexts in parallel to demonstrate that evaluators are thread-safe. - var wg sync.WaitGroup - var res1, res2 *rlwe.Ciphertext + resChan := make(chan *rlwe.Ciphertext, 1) fmt.Println("Bootstrapping...") - wg.Add(1) + go func() { - defer wg.Done() - res2, err = eval.Bootstrap(ciphertext2) + res, err := eval.Bootstrap(ciphertext2) if err != nil { panic(err) } + resChan <- res }() - res1, err = eval.Bootstrap(ciphertext1) + + res1, err := eval.Bootstrap(ciphertext1) if err != nil { panic(err) } - wg.Wait() + + res2 := <-resChan fmt.Println("Done") //================== diff --git a/examples/singleparty/tutorials/ckks/main.go b/examples/singleparty/tutorials/ckks/main.go index 6da6da8b..336a2966 100644 --- a/examples/singleparty/tutorials/ckks/main.go +++ b/examples/singleparty/tutorials/ckks/main.go @@ -759,7 +759,7 @@ func main() { // CONCURRENCY // ========== // - // All public structures in Lattigo are thread-safe and can be used concurrently. + // In lattigo, a type's methods are thread-safe and can be called concurrently. E.g. one can use an evaluator's methods concurrently. // } diff --git a/ring/sampler_gaussian.go b/ring/sampler_gaussian.go index 61ab19b6..bc103923 100644 --- a/ring/sampler_gaussian.go +++ b/ring/sampler_gaussian.go @@ -72,12 +72,12 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) { r := g.baseRing - randomBufferN := make([]byte, 1024) + var randomBufferN [1024]byte var ptr int level := r.level - if _, err := g.prng.Read(randomBufferN); err != nil { + if _, err := g.prng.Read(randomBufferN[:]); err != nil { // Sanity check, this error should not happen. panic(err) } @@ -121,7 +121,7 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) { for { // Sample norm with sigma = 1 and sign - norm, sign = g.normFloat64(randomBufferN, &ptr) + norm, sign = g.normFloat64(randomBufferN[:], &ptr) // Sets normFlo = norm * sigma with precision 53 bits // and 0.5 for rounding discretization @@ -160,7 +160,7 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) { for i := 0; i < N; i++ { for { - norm, sign = g.normFloat64(randomBufferN, &ptr) + norm, sign = g.normFloat64(randomBufferN[:], &ptr) if v := norm * sigma; v <= bound { coeffInt = uint64(v + 0.5) // rounding @@ -189,16 +189,15 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) { // // Algorithm adapted from https://golang.org/src/math/rand/normal.go // to use a secure PRNG instead of math/rand. -func (g *GaussianSampler) normFloat64(buffer []byte, ptr *int) (float64, uint64) { +func (g *GaussianSampler) normFloat64(buff []byte, ptr *int) (float64, uint64) { currPtr := *ptr - buff := buffer prng := g.prng buffLen := len(buff) read := func() { if currPtr == buffLen { - if _, err := prng.Read(buff[:]); err != nil { + if _, err := prng.Read(buff); err != nil { // Sanity check, this error should not happen. panic(err) } diff --git a/schemes/ckks/evaluator.go b/schemes/ckks/evaluator.go index 8df78789..d3bc417a 100644 --- a/schemes/ckks/evaluator.go +++ b/schemes/ckks/evaluator.go @@ -1240,15 +1240,10 @@ func (eval Evaluator) RotateHoistedNew(ctIn *rlwe.Ciphertext, rotations []int) ( func (eval Evaluator) RotateHoisted(ctIn *rlwe.Ciphertext, rotations []int, opOut map[int]*rlwe.Ciphertext) (err error) { levelQ := ctIn.Level() - baseRNSDecompositionVectorSize := eval.GetParameters().BaseRNSDecompositionVectorSize(levelQ, eval.GetParameters().MaxLevelP()) - buffDecompQP := make([]ringqp.Poly, baseRNSDecompositionVectorSize) - poolQP := eval.pool.AtLevel(levelQ, eval.GetParameters().MaxLevelP()) - for i := 0; i < len(buffDecompQP); i++ { - buff := poolQP.GetBuffPolyQP() - defer poolQP.RecycleBuffPolyQP(buff) - buffDecompQP[i] = *buff - } + + buffDecompQP := poolQP.GetBuffDecompQP(*eval.GetRLWEParameters(), levelQ, eval.GetParameters().MaxLevelP()) + defer poolQP.RecycleBuffDecompQP(buffDecompQP) eval.DecomposeNTT(levelQ, eval.GetParameters().MaxLevelP(), eval.GetParameters().PCount(), ctIn.Value[1], ctIn.IsNTT, buffDecompQP) for _, i := range rotations { diff --git a/utils/sampling/prng.go b/utils/sampling/prng.go index 800cbb18..99338f13 100644 --- a/utils/sampling/prng.go +++ b/utils/sampling/prng.go @@ -30,8 +30,8 @@ func (prng *ThreadSafePRNG) Read(sum []byte) (n int, err error) { // sequences of random bytes among different parties using the hash function blake2b. Backward sequence // security (given the digest i, compute the digest i-1) is ensured by default, however forward sequence // security (given the digest i, compute the digest i+1) is only ensured if the KeyedPRNG is keyed. -// WARNING: KeyedPRNG should NOT be called by multiple threads. It does not make sense to do so as the resulting -// sequence will not be deterministic for a given key. For a PRNG securely seeded with a private keyuse [ThreadSafePRNG]. +// WARNING: If KeyedPRNG is called concurrently by multiple threads, the resulting sequences will be independent and no error will be triggered. However, the result will not be deterministic and therefore, in most cases, it does not make sense to use KeyedPRNG in a concurrent setting. +// NOTE: For a PRNG securely seeded with a private key use [ThreadSafePRNG]. type KeyedPRNG struct { mutex sync.Mutex key []byte @@ -41,7 +41,7 @@ type KeyedPRNG struct { // NewKeyedPRNG creates a new instance of KeyedPRNG. // Accepts an optional key, else set key=nil which is treated as key=[]byte{} // WARNING: A PRNG INITIALISED WITH key=nil IS INSECURE! -// WARNING: KeyedPRNG should NOT be called by multiple threads. If that occurs, the generated sequence will not be deterministic. +// WARNING: KeyedPRNG can be called by multiple threads BUT the generated sequences will not be deterministic. func NewKeyedPRNG(key []byte) (*KeyedPRNG, error) { var err error prng := new(KeyedPRNG)