style: apply feedback from pr

This commit is contained in:
lehugueni
2025-06-20 16:36:48 +02:00
parent 0638ff8636
commit 9ebcc9aa0f
11 changed files with 64 additions and 70 deletions

View File

@@ -11,7 +11,6 @@ import (
"github.com/tuneinsight/lattigo/v6/circuits/ckks/polynomial" "github.com/tuneinsight/lattigo/v6/circuits/ckks/polynomial"
"github.com/tuneinsight/lattigo/v6/core/rlwe" "github.com/tuneinsight/lattigo/v6/core/rlwe"
"github.com/tuneinsight/lattigo/v6/ring" "github.com/tuneinsight/lattigo/v6/ring"
"github.com/tuneinsight/lattigo/v6/ring/ringqp"
"github.com/tuneinsight/lattigo/v6/schemes/ckks" "github.com/tuneinsight/lattigo/v6/schemes/ckks"
"github.com/tuneinsight/lattigo/v6/utils" "github.com/tuneinsight/lattigo/v6/utils"
"github.com/tuneinsight/lattigo/v6/utils/bignum" "github.com/tuneinsight/lattigo/v6/utils/bignum"
@@ -671,13 +670,8 @@ func (eval Evaluator) ModUp(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext, err
ks := eval.Evaluator.Evaluator ks := eval.Evaluator.Evaluator
size := eval.ResidualParameters.BaseRNSDecompositionVectorSize(eval.BootstrappingParameters.MaxLevelQ(), 0) buffDecompQP := poolQP.GetBuffDecompQP(eval.ResidualParameters.Parameters, eval.BootstrappingParameters.MaxLevelQ(), 0)
buffDecompQP := make([]ringqp.Poly, size) defer eval.pool.RecycleBuffDecompQP(buffDecompQP)
for i := 0; i < size; i++ {
buff := poolQP.GetBuffPolyQP()
defer poolQP.RecycleBuffPolyQP(buff)
buffDecompQP[i] = *buff
}
// ModUp q->QP for ctIn[1] centered around q // ModUp q->QP for ctIn[1] centered around q
for j := 0; j < N; j++ { for j := 0; j < N; j++ {

View File

@@ -26,6 +26,7 @@ func BenchmarkConcurrentBootstrap(b *testing.B) {
evk, _, err := btpParams.GenEvaluationKeys(sk) evk, _, err := btpParams.GenEvaluationKeys(sk)
require.NoError(b, err) require.NoError(b, err)
// Benchmark parallel bootstrapping
b.Run(ParamsToString(params, btpParams.LogMaxDimensions().Cols, "Bootstrap/"), func(b *testing.B) { b.Run(ParamsToString(params, btpParams.LogMaxDimensions().Cols, "Bootstrap/"), func(b *testing.B) {
var err error var err error
eval, err := NewEvaluator(btpParams, evk) eval, err := NewEvaluator(btpParams, evk)

View File

@@ -48,13 +48,8 @@ func (eval Evaluator) EvaluateMany(ctIn *rlwe.Ciphertext, linearTransformations
levelQ = utils.Min(levelQ, ctIn.Level()) levelQ = utils.Min(levelQ, ctIn.Level())
poolQP := eval.pool.AtLevel(levelQ, levelP) poolQP := eval.pool.AtLevel(levelQ, levelP)
baseRNSDecompositionVectorSize := eval.Evaluator.GetRLWEParameters().BaseRNSDecompositionVectorSize(levelQ, levelP) buffDecompQP := poolQP.GetBuffDecompQP(*eval.GetRLWEParameters(), levelQ, levelP)
buffDecompQP := make([]ringqp.Poly, baseRNSDecompositionVectorSize) defer eval.pool.RecycleBuffDecompQP(buffDecompQP)
for i := 0; i < len(buffDecompQP); i++ {
buff := poolQP.GetBuffPolyQP()
defer poolQP.RecycleBuffPolyQP(buff)
buffDecompQP[i] = *buff
}
eval.DecomposeNTT(levelQ, levelP, levelP+1, ctIn.Value[1], ctIn.IsNTT, buffDecompQP) eval.DecomposeNTT(levelQ, levelP, levelP+1, ctIn.Value[1], ctIn.IsNTT, buffDecompQP)

View File

@@ -211,13 +211,8 @@ func (eval Evaluator) PartialTracesSum(ctIn *Ciphertext, offset, n int, opOut *C
cQ.MetaData = ctInNTT.MetaData cQ.MetaData = ctInNTT.MetaData
baseRNSDecompositionVectorSize := eval.params.BaseRNSDecompositionVectorSize(levelQ, levelP) buffDecompQP := poolQP.GetBuffDecompQP(eval.params, levelQ, levelP)
buffDecompQP := make([]ringqp.Poly, baseRNSDecompositionVectorSize) defer eval.pool.RecycleBuffDecompQP(buffDecompQP)
for i := 0; i < len(buffDecompQP); i++ {
buff := poolQP.GetBuffPolyQP()
defer poolQP.RecycleBuffPolyQP(buff)
buffDecompQP[i] = *buff
}
state := false state := false
copy := true copy := true

View File

@@ -29,6 +29,15 @@ func NewPool(rqp *ringqp.Ring, pools ...structs.BufferPool[*[]uint64]) *Pool {
return &Pool{ringqpPool} return &Pool{ringqpPool}
} }
// AtLevel returns a new pool from which objects from polynomials at the given levels can be drawn.
// The method accepts up to two levels:
// Zero level: the objects returned are built from polynomials at level 0.
// One level: the objects returned are built from polynomials in RingQ (resp. RingP) at the given level (resp. level 0).
// Two levels: the objects returned are built from polynomials in RingQ (resp. RingP) at levels[0] (resp. levels[1]).
func (pool Pool) AtLevel(levels ...int) *Pool {
return &Pool{pool.Pool.AtLevel(levels...)}
}
// GetBuffCt returns a ciphertext that can be used as a buffer for intermediate computations. // GetBuffCt returns a ciphertext that can be used as a buffer for intermediate computations.
// After use, the ciphertext should be recycled with [Pool.RecycleBuffCt]. // After use, the ciphertext should be recycled with [Pool.RecycleBuffCt].
// The optional dimensions specify the degree and level of the ciphertext (default to 2, eval.params.ringQ.Level()). // The optional dimensions specify the degree and level of the ciphertext (default to 2, eval.params.ringQ.Level()).
@@ -99,3 +108,22 @@ func (pool *Pool) GetBuffPt(level ...int) *Plaintext {
func (pool *Pool) RecycleBuffPt(pt *Plaintext) { func (pool *Pool) RecycleBuffPt(pt *Plaintext) {
pool.RecycleBuffPoly(&pt.Value) pool.RecycleBuffPoly(&pt.Value)
} }
// GetBuffDecompQP returns buffers of polys to be used for RNS decomposition.
// After use, the array of buffers must be recycled with [Pool.RecycleBuffDecompQP].
func (pool *Pool) GetBuffDecompQP(params Parameters, levelQ, levelP int) []ringqp.Poly {
size := params.BaseRNSDecompositionVectorSize(levelQ, levelP)
buffDecompQP := make([]ringqp.Poly, size)
for i := 0; i < size; i++ {
poly := pool.GetBuffPolyQP()
buffDecompQP[i] = *poly
}
return buffDecompQP
}
// RecycleBuffDecompQP recycles a temporary array of polys used for decomposition.
func (pool *Pool) RecycleBuffDecompQP(decomp []ringqp.Poly) {
for _, poly := range decomp {
pool.RecycleBuffPolyQP(&poly)
}
}

View File

@@ -729,13 +729,9 @@ func testGadgetProduct(tc *TestContext, levelQ, bpw2 int, t *testing.T) {
// Setup temporary buffer for decomposition // Setup temporary buffer for decomposition
poolQP := tc.eval.pool.AtLevel(levelQ, levelP) poolQP := tc.eval.pool.AtLevel(levelQ, levelP)
size := params.BaseRNSDecompositionVectorSize(levelQ, 0) buffDecompQP := poolQP.GetBuffDecompQP(params, levelQ, levelP)
buffDecompQP := make([]ringqp.Poly, size) defer poolQP.RecycleBuffDecompQP(buffDecompQP)
for i := 0; i < size; i++ {
poly := poolQP.GetBuffPolyQP()
defer poolQP.RecycleBuffPolyQP(poly)
buffDecompQP[i] = *poly
}
if bpw2 != 0 { if bpw2 != 0 {
t.Skip("method is unsupported for BaseTwoDecomposition != 0") t.Skip("method is unsupported for BaseTwoDecomposition != 0")
} }
@@ -966,13 +962,9 @@ func testAutomorphism(tc *TestContext, level, bpw2 int, t *testing.T) {
// Setup temporary buffer for decomposition // Setup temporary buffer for decomposition
poolQP := tc.eval.pool.AtLevel(level, params.MaxLevelP()) poolQP := tc.eval.pool.AtLevel(level, params.MaxLevelP())
size := params.BaseRNSDecompositionVectorSize(level, 0) buffDecompQP := poolQP.GetBuffDecompQP(params, level, 0)
buffDecompQP := make([]ringqp.Poly, size) defer poolQP.RecycleBuffDecompQP(buffDecompQP)
for i := 0; i < size; i++ {
poly := poolQP.GetBuffPolyQP()
defer poolQP.RecycleBuffPolyQP(poly)
buffDecompQP[i] = *poly
}
// Generate a plaintext with values up to 2^30 // Generate a plaintext with values up to 2^30
pt := genPlaintext(params, level, 1<<30) pt := genPlaintext(params, level, 1<<30)
@@ -1025,13 +1017,8 @@ func testAutomorphism(tc *TestContext, level, bpw2 int, t *testing.T) {
// Setup temporary buffer for decomposition // Setup temporary buffer for decomposition
poolQP := tc.eval.pool.AtLevel(level, params.MaxLevelP()) poolQP := tc.eval.pool.AtLevel(level, params.MaxLevelP())
size := params.BaseRNSDecompositionVectorSize(level, 0) buffDecompQP := poolQP.GetBuffDecompQP(params, level, 0)
buffDecompQP := make([]ringqp.Poly, size) defer poolQP.RecycleBuffDecompQP(buffDecompQP)
for i := 0; i < size; i++ {
poly := poolQP.GetBuffPolyQP()
defer poolQP.RecycleBuffPolyQP(poly)
buffDecompQP[i] = *poly
}
// Generate a plaintext with values up to 2^30 // Generate a plaintext with values up to 2^30
pt := genPlaintext(params, level, 1<<30) pt := genPlaintext(params, level, 1<<30)

View File

@@ -9,7 +9,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"math" "math"
"sync"
"github.com/tuneinsight/lattigo/v6/circuits/ckks/bootstrapping" "github.com/tuneinsight/lattigo/v6/circuits/ckks/bootstrapping"
"github.com/tuneinsight/lattigo/v6/core/rlwe" "github.com/tuneinsight/lattigo/v6/core/rlwe"
@@ -194,22 +193,23 @@ func main() {
// To equalize the scale, the function evaluator.SetScale(ciphertext, parameters.DefaultScale()) can be used at the expense of one level. // To equalize the scale, the function evaluator.SetScale(ciphertext, parameters.DefaultScale()) can be used at the expense of one level.
// If the ciphertext is is at level one or greater when given to the bootstrapper, this equalization is automatically done. // If the ciphertext is is at level one or greater when given to the bootstrapper, this equalization is automatically done.
// Here we bootstrap two ciphertexts in parallel to demonstrate that evaluators are thread-safe. // Here we bootstrap two ciphertexts in parallel to demonstrate that evaluators are thread-safe.
var wg sync.WaitGroup resChan := make(chan *rlwe.Ciphertext, 1)
var res1, res2 *rlwe.Ciphertext
fmt.Println("Bootstrapping...") fmt.Println("Bootstrapping...")
wg.Add(1)
go func() { go func() {
defer wg.Done() res, err := eval.Bootstrap(ciphertext2)
res2, err = eval.Bootstrap(ciphertext2)
if err != nil { if err != nil {
panic(err) panic(err)
} }
resChan <- res
}() }()
res1, err = eval.Bootstrap(ciphertext1)
res1, err := eval.Bootstrap(ciphertext1)
if err != nil { if err != nil {
panic(err) panic(err)
} }
wg.Wait()
res2 := <-resChan
fmt.Println("Done") fmt.Println("Done")
//================== //==================

View File

@@ -759,7 +759,7 @@ func main() {
// CONCURRENCY // CONCURRENCY
// ========== // ==========
// //
// All public structures in Lattigo are thread-safe and can be used concurrently. // In lattigo, a type's methods are thread-safe and can be called concurrently. E.g. one can use an evaluator's methods concurrently.
// //
} }

View File

@@ -72,12 +72,12 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
r := g.baseRing r := g.baseRing
randomBufferN := make([]byte, 1024) var randomBufferN [1024]byte
var ptr int var ptr int
level := r.level level := r.level
if _, err := g.prng.Read(randomBufferN); err != nil { if _, err := g.prng.Read(randomBufferN[:]); err != nil {
// Sanity check, this error should not happen. // Sanity check, this error should not happen.
panic(err) panic(err)
} }
@@ -121,7 +121,7 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
for { for {
// Sample norm with sigma = 1 and sign // Sample norm with sigma = 1 and sign
norm, sign = g.normFloat64(randomBufferN, &ptr) norm, sign = g.normFloat64(randomBufferN[:], &ptr)
// Sets normFlo = norm * sigma with precision 53 bits // Sets normFlo = norm * sigma with precision 53 bits
// and 0.5 for rounding discretization // and 0.5 for rounding discretization
@@ -160,7 +160,7 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
for { for {
norm, sign = g.normFloat64(randomBufferN, &ptr) norm, sign = g.normFloat64(randomBufferN[:], &ptr)
if v := norm * sigma; v <= bound { if v := norm * sigma; v <= bound {
coeffInt = uint64(v + 0.5) // rounding coeffInt = uint64(v + 0.5) // rounding
@@ -189,16 +189,15 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
// //
// Algorithm adapted from https://golang.org/src/math/rand/normal.go // Algorithm adapted from https://golang.org/src/math/rand/normal.go
// to use a secure PRNG instead of math/rand. // to use a secure PRNG instead of math/rand.
func (g *GaussianSampler) normFloat64(buffer []byte, ptr *int) (float64, uint64) { func (g *GaussianSampler) normFloat64(buff []byte, ptr *int) (float64, uint64) {
currPtr := *ptr currPtr := *ptr
buff := buffer
prng := g.prng prng := g.prng
buffLen := len(buff) buffLen := len(buff)
read := func() { read := func() {
if currPtr == buffLen { if currPtr == buffLen {
if _, err := prng.Read(buff[:]); err != nil { if _, err := prng.Read(buff); err != nil {
// Sanity check, this error should not happen. // Sanity check, this error should not happen.
panic(err) panic(err)
} }

View File

@@ -1240,15 +1240,10 @@ func (eval Evaluator) RotateHoistedNew(ctIn *rlwe.Ciphertext, rotations []int) (
func (eval Evaluator) RotateHoisted(ctIn *rlwe.Ciphertext, rotations []int, opOut map[int]*rlwe.Ciphertext) (err error) { func (eval Evaluator) RotateHoisted(ctIn *rlwe.Ciphertext, rotations []int, opOut map[int]*rlwe.Ciphertext) (err error) {
levelQ := ctIn.Level() levelQ := ctIn.Level()
baseRNSDecompositionVectorSize := eval.GetParameters().BaseRNSDecompositionVectorSize(levelQ, eval.GetParameters().MaxLevelP())
buffDecompQP := make([]ringqp.Poly, baseRNSDecompositionVectorSize)
poolQP := eval.pool.AtLevel(levelQ, eval.GetParameters().MaxLevelP()) poolQP := eval.pool.AtLevel(levelQ, eval.GetParameters().MaxLevelP())
for i := 0; i < len(buffDecompQP); i++ {
buff := poolQP.GetBuffPolyQP() buffDecompQP := poolQP.GetBuffDecompQP(*eval.GetRLWEParameters(), levelQ, eval.GetParameters().MaxLevelP())
defer poolQP.RecycleBuffPolyQP(buff) defer poolQP.RecycleBuffDecompQP(buffDecompQP)
buffDecompQP[i] = *buff
}
eval.DecomposeNTT(levelQ, eval.GetParameters().MaxLevelP(), eval.GetParameters().PCount(), ctIn.Value[1], ctIn.IsNTT, buffDecompQP) eval.DecomposeNTT(levelQ, eval.GetParameters().MaxLevelP(), eval.GetParameters().PCount(), ctIn.Value[1], ctIn.IsNTT, buffDecompQP)
for _, i := range rotations { for _, i := range rotations {

View File

@@ -30,8 +30,8 @@ func (prng *ThreadSafePRNG) Read(sum []byte) (n int, err error) {
// sequences of random bytes among different parties using the hash function blake2b. Backward sequence // sequences of random bytes among different parties using the hash function blake2b. Backward sequence
// security (given the digest i, compute the digest i-1) is ensured by default, however forward sequence // security (given the digest i, compute the digest i-1) is ensured by default, however forward sequence
// security (given the digest i, compute the digest i+1) is only ensured if the KeyedPRNG is keyed. // security (given the digest i, compute the digest i+1) is only ensured if the KeyedPRNG is keyed.
// WARNING: KeyedPRNG should NOT be called by multiple threads. It does not make sense to do so as the resulting // WARNING: If KeyedPRNG is called concurrently by multiple threads, the resulting sequences will be independent and no error will be triggered. However, the result will not be deterministic and therefore, in most cases, it does not make sense to use KeyedPRNG in a concurrent setting.
// sequence will not be deterministic for a given key. For a PRNG securely seeded with a private keyuse [ThreadSafePRNG]. // NOTE: For a PRNG securely seeded with a private key use [ThreadSafePRNG].
type KeyedPRNG struct { type KeyedPRNG struct {
mutex sync.Mutex mutex sync.Mutex
key []byte key []byte
@@ -41,7 +41,7 @@ type KeyedPRNG struct {
// NewKeyedPRNG creates a new instance of KeyedPRNG. // NewKeyedPRNG creates a new instance of KeyedPRNG.
// Accepts an optional key, else set key=nil which is treated as key=[]byte{} // Accepts an optional key, else set key=nil which is treated as key=[]byte{}
// WARNING: A PRNG INITIALISED WITH key=nil IS INSECURE! // WARNING: A PRNG INITIALISED WITH key=nil IS INSECURE!
// WARNING: KeyedPRNG should NOT be called by multiple threads. If that occurs, the generated sequence will not be deterministic. // WARNING: KeyedPRNG can be called by multiple threads BUT the generated sequences will not be deterministic.
func NewKeyedPRNG(key []byte) (*KeyedPRNG, error) { func NewKeyedPRNG(key []byte) (*KeyedPRNG, error) {
var err error var err error
prng := new(KeyedPRNG) prng := new(KeyedPRNG)