freelist + benchmarking

This commit is contained in:
lehugueni
2024-12-02 12:26:42 +01:00
parent 6cb65c300a
commit 0504cfa1a6
4 changed files with 283 additions and 7 deletions

View File

@@ -29,23 +29,30 @@ type EvaluatorBuffers struct {
BuffCtPool structs.BufferPool[*Ciphertext]
}
func newBuffer[T any](f func() T) structs.BufferPool[T] {
// Uncomment to try with free lists instead of sync pool:
// nbItemsInPool := 10
// return structs.NewFreeList(nbItemsInPool, f)
return structs.NewSyncPool(f)
}
func NewEvaluatorBuffers(params Parameters) *EvaluatorBuffers {
buff := new(EvaluatorBuffers)
ringQP := params.RingQP()
buff.BuffQPPool = structs.NewSyncPool(func() *ringqp.Poly {
buff.BuffQPPool = newBuffer(func() *ringqp.Poly {
poly := ringQP.NewPoly()
return &poly
})
buff.BuffQPool = structs.NewSyncPool(func() *ring.Poly {
buff.BuffQPool = newBuffer(func() *ring.Poly {
poly := params.RingQ().NewPoly()
return &poly
})
buff.BuffCtPool = structs.NewSyncPool(func() *Ciphertext {
buff.BuffCtPool = newBuffer(func() *Ciphertext {
return NewCiphertext(params, 2, params.MaxLevel())
})
buff.BuffBitPool = structs.NewSyncPool(func() *[]uint64 {
buff.BuffBitPool = newBuffer(func() *[]uint64 {
buff := make([]uint64, params.RingQ().N())
return &buff
})

View File

@@ -37,8 +37,9 @@ func BenchmarkCKKS(b *testing.B) {
tc := NewTestContext(paramsLiteral)
for _, testSet := range []func(tc *TestContext, b *testing.B){
benchEncoder,
benchEvaluator,
// benchEncoder,
// benchEvaluator,
benchEvaluatorParallel,
} {
testSet(tc, b)
runtime.GC()
@@ -91,6 +92,238 @@ func benchEncoder(tc *TestContext, b *testing.B) {
})
}
func benchEvaluatorParallel(tc *TestContext, b *testing.B) {
params := tc.Params
plaintext := NewPlaintext(params, params.MaxLevel())
plaintext.Value = rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 0, plaintext.Level()).Value[0]
vector := make([]float64, params.MaxSlots())
for i := range vector {
vector[i] = 1
}
ciphertext1 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, params.MaxLevel())
ciphertext2 := rlwe.NewCiphertextRandom(tc.Prng, params.Parameters, 1, params.MaxLevel())
*ciphertext1.MetaData = *plaintext.MetaData
*ciphertext2.MetaData = *plaintext.MetaData
eval := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(tc.Kgen.GenRelinearizationKeyNew(tc.Sk)))
b.Run(name("EvaluatorParallel/Add/Scalar", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Add(ciphertext1, 3.1415-1.4142i, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Add/Scalar", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Add(ciphertext1, vector, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Add/Plaintext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Add(ciphertext1, plaintext, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Add/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Add(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Mul/Scalar", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Mul(ciphertext1, 3.1415-1.4142i, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Mul/Vector", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Mul(ciphertext1, vector, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Mul/Plaintext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.Mul(ciphertext1, plaintext, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Mul/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
for pb.Next() {
if err := eval.Mul(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/MulRelin/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulRelin(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/MulThenAdd/Scalar", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulThenAdd(ciphertext1, 3.1415-1.4142i, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/MulThenAdd/Vector", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulThenAdd(ciphertext1, vector, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/MulThenAdd/Plaintext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulThenAdd(ciphertext1, plaintext, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/MulThenAdd/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 2, ciphertext1.Level())
for pb.Next() {
if err := eval.MulThenAdd(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/MulRelinThenAdd/Ciphertext", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
for pb.Next() {
if err := eval.MulRelinThenAdd(ciphertext1, ciphertext2, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Rescale", tc), func(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level()-1)
for pb.Next() {
if err := eval.Rescale(ciphertext1, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
b.Run(name("EvaluatorParallel/Rotate", tc), func(b *testing.B) {
b.ResetTimer()
gk := tc.Kgen.GenGaloisKeyNew(5, tc.Sk)
evk := rlwe.NewMemEvaluationKeySet(nil, gk)
eval := eval.WithKey(evk)
b.RunParallel(func(pb *testing.PB) {
receiver := NewCiphertext(params, 1, ciphertext1.Level())
b.ResetTimer()
for pb.Next() {
if err := eval.Rotate(ciphertext1, 1, receiver); err != nil {
b.Log(err)
b.Fail()
}
}
})
})
}
func benchEvaluator(tc *TestContext, b *testing.B) {
params := tc.Params

View File

@@ -1052,7 +1052,6 @@ func (ecd *Encoder) polyToFloatCRT(p ring.Poly, values FloatSlice, scale rlwe.Sc
defer ecd.BuffBigIntPool.Put(buffRef)
bigintCoeffs := *buffRef
// TODO: Double check, was using ecd.buff instead of p, but they are equal?
ecd.parameters.RingQ().PolyToBigint(p, 1, bigintCoeffs)
Q := r.ModulusAtLevel[r.Level()]

View File

@@ -6,6 +6,7 @@ type BufferPool[T any] interface {
Get() T
Put(T)
}
type SyncPool[T any] struct {
pool *sync.Pool
}
@@ -26,3 +27,39 @@ func (spool *SyncPool[T]) Get() T {
func (spool *SyncPool[T]) Put(buff T) {
spool.pool.Put(buff)
}
type FreeList[T any] struct {
pool chan T
newObject func() T
capacity int
}
func NewFreeList[T any](capacity int, f func() T) *FreeList[T] {
pool := make(chan T, capacity)
for i := 0; i < capacity; i++ {
pool <- f()
}
return &FreeList[T]{
pool: pool,
newObject: f,
capacity: capacity,
}
}
func (fl *FreeList[T]) Get() T {
var obj T
select {
case obj = <-fl.pool:
default:
obj = fl.newObject()
}
return obj
}
func (fl *FreeList[T]) Put(obj T) {
select {
case fl.pool <- obj:
default:
}
}