bufferpool uint64 in ring

This commit is contained in:
lehugueni
2024-12-12 11:22:39 +01:00
parent 5e6a37a7e4
commit 77c62e6421
10 changed files with 122 additions and 66 deletions

View File

@@ -5,7 +5,6 @@ import (
"github.com/tuneinsight/lattigo/v6/ring"
"github.com/tuneinsight/lattigo/v6/utils/sampling"
"github.com/tuneinsight/lattigo/v6/utils/structs"
)
// Ciphertext is a generic type for RLWE ciphertexts.
@@ -71,12 +70,14 @@ func (ct Ciphertext) Equal(other *Ciphertext) bool {
return ct.Element.Equal(&other.Element)
}
func NewCiphertextFromUintPool(pool structs.BufferPool[*[]uint64], params ParameterProvider, degree int, level int) *Ciphertext {
func NewCiphertextFromUintPool(params ParameterProvider, degree int, levelQ int) *Ciphertext {
p := params.GetRLWEParameters()
ringQ := p.RingQ().AtLevel(levelQ)
Value := make([]ring.Poly, degree+1)
for i := range Value {
Value[i] = *ring.NewPolyFromUintPool(pool, p.N(), level)
Value[i] = *ringQ.NewPolyFromUintPool()
}
el := Element[ring.Poly]{
@@ -90,9 +91,10 @@ func NewCiphertextFromUintPool(pool structs.BufferPool[*[]uint64], params Parame
return &Ciphertext{el}
}
func RecycleCiphertextInUintPool(pool structs.BufferPool[*[]uint64], ct *Ciphertext) {
func RecycleCiphertextInUintPool(params ParameterProvider, ct *Ciphertext) {
ringQ := params.GetRLWEParameters().ringQ
for i := range ct.Value {
ring.RecyclePolyInUintPool(pool, &ct.Value[i])
ringQ.RecyclePolyInUintPool(&ct.Value[i])
}
ct = nil
return

View File

@@ -38,37 +38,33 @@ func newBuffer[T any](f func() T) structs.BufferPool[T] {
func NewEvaluatorBuffersWithUintPool(params Parameters) *EvaluatorBuffers {
buff := new(EvaluatorBuffers)
ringQP := params.RingQP()
ringQ := params.ringQ
buffUint := newBuffer(func() *[]uint64 {
buff := make([]uint64, params.RingQ().N())
return &buff
})
buff.BuffQPPool = structs.NewBuffFromUintPool(buffUint,
func(bp structs.BufferPool[*[]uint64]) *ringqp.Poly {
return ringQP.NewPolyQPFromUintPool(bp)
buff.BuffQPPool = structs.NewBuffFromUintPool(
func() *ringqp.Poly {
return ringQP.NewPolyQPFromUintPool()
},
func(bp structs.BufferPool[*[]uint64], poly *ringqp.Poly) {
ringqp.RecyclePolyQPFromUintPool(bp, poly)
func(poly *ringqp.Poly) {
ringQP.RecyclePolyQPFromUintPool(poly)
},
)
buff.BuffQPool = structs.NewBuffFromUintPool(buffUint,
func(bp structs.BufferPool[*[]uint64]) *ring.Poly {
return ring.NewPolyFromUintPool(bp, params.ringQ.N(), params.ringQ.Level())
buff.BuffQPool = structs.NewBuffFromUintPool(
func() *ring.Poly {
return ringQ.NewPolyFromUintPool()
},
func(bp structs.BufferPool[*[]uint64], poly *ring.Poly) {
ring.RecyclePolyInUintPool(bp, poly)
func(poly *ring.Poly) {
ringQ.RecyclePolyInUintPool(poly)
},
)
buff.BuffCtPool = structs.NewBuffFromUintPool(buffUint,
func(bp structs.BufferPool[*[]uint64]) *Ciphertext {
return NewCiphertextFromUintPool(bp, params, 2, params.MaxLevel())
buff.BuffCtPool = structs.NewBuffFromUintPool(
func() *Ciphertext {
return NewCiphertextFromUintPool(params, 2, params.MaxLevel())
},
func(bp structs.BufferPool[*[]uint64], ct *Ciphertext) {
RecycleCiphertextInUintPool(bp, ct)
func(ct *Ciphertext) {
RecycleCiphertextInUintPool(params, ct)
},
)
buff.BuffBitPool = buffUint
buff.BuffBitPool = ringQ.BufferPool()
return buff
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/tuneinsight/lattigo/v6/ring/ringqp"
"github.com/tuneinsight/lattigo/v6/utils"
"github.com/tuneinsight/lattigo/v6/utils/buffer"
"github.com/tuneinsight/lattigo/v6/utils/structs"
)
// MaxLogN is the log2 of the largest supported polynomial modulus degree.
@@ -855,11 +856,16 @@ func GenModuli(LogNthRoot int, logQ, logP []int) (q, p []uint64, err error) {
}
func (p *Parameters) initRings() (err error) {
if p.ringQ, err = ring.NewRingFromType(1<<p.logN, p.qi, p.ringType); err != nil {
N := 1 << p.logN
buffPool := structs.NewSyncPool(func() *[]uint64 {
buff := make([]uint64, N)
return &buff
})
if p.ringQ, err = ring.NewRingFromType(1<<p.logN, p.qi, p.ringType, buffPool); err != nil {
return fmt.Errorf("initRings/ringQ: %w", err)
}
if len(p.pi) != 0 {
if p.ringP, err = ring.NewRingFromType(1<<p.logN, p.pi, p.ringType); err != nil {
if p.ringP, err = ring.NewRingFromType(1<<p.logN, p.pi, p.ringType, buffPool); err != nil {
return fmt.Errorf("initRings/ringP: %w", err)
}
}

View File

@@ -73,14 +73,22 @@ func NewBasisExtender(ringQ, ringP *Ring) (be *BasisExtender) {
be.modDownConstantsPtoQ = genmodDownConstants(ringQ, ringP)
be.modDownConstantsQtoP = genmodDownConstants(ringP, ringQ)
be.buffQPool = structs.NewSyncPool(func() *Poly {
polyQ := ringQ.NewPoly()
return &polyQ
})
be.buffPPool = structs.NewSyncPool(func() *Poly {
polyP := ringP.NewPoly()
return &polyP
})
be.buffQPool = structs.NewBuffFromUintPool(
func() *Poly {
return ringQ.NewPolyFromUintPool()
},
func(poly *Poly) {
ringQ.RecyclePolyInUintPool(poly)
},
)
be.buffPPool = structs.NewBuffFromUintPool(
func() *Poly {
return ringP.NewPolyFromUintPool()
},
func(poly *Poly) {
ringP.RecyclePolyInUintPool(poly)
},
)
return
}

View File

@@ -14,7 +14,7 @@ type Poly struct {
Coeffs structs.Matrix[uint64]
}
func NewPolyFromUintPool(pool structs.BufferPool[*[]uint64], N, level int) (pol *Poly) {
func NewPolyFromUintPool(pool structs.BufferPool[*[]uint64], level int) (pol *Poly) {
coeffs := make([][]uint64, level+1)
for i := range coeffs {
coeffs[i] = *pool.Get()

View File

@@ -11,6 +11,7 @@ import (
"github.com/tuneinsight/lattigo/v6/utils"
"github.com/tuneinsight/lattigo/v6/utils/bignum"
"github.com/tuneinsight/lattigo/v6/utils/structs"
)
const (
@@ -78,6 +79,8 @@ type Ring struct {
RescaleConstants [][]uint64
level int
bufferPool structs.BufferPool[*[]uint64]
}
// ConjugateInvariantRing returns the conjugate invariant ring of the receiver ring.
@@ -173,6 +176,11 @@ func (r Ring) Level() int {
return r.level
}
// BufferPool returns the pool of *[]uint64
func (r Ring) BufferPool() structs.BufferPool[*[]uint64] {
return r.bufferPool
}
// AtLevel returns an instance of the target ring that operates at the target level.
// This instance is thread safe and can be use concurrently with the base ring.
func (r Ring) AtLevel(level int) *Ring {
@@ -192,6 +200,7 @@ func (r Ring) AtLevel(level int) *Ring {
ModulusAtLevel: r.ModulusAtLevel,
RescaleConstants: r.RescaleConstants,
level: level,
bufferPool: r.bufferPool,
}
}
@@ -241,15 +250,31 @@ func (r Ring) BRedConstants() (BRC [][2]uint64) {
// NewRing creates a new RNS Ring with degree N and coefficient moduli Moduli with Standard NTT. N must be a power of two larger than 8. Moduli should be
// a non-empty []uint64 with distinct prime elements. All moduli must also be equal to 1 modulo 2*N.
// An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRing(N int, Moduli []uint64) (r *Ring, err error) {
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N)
func NewRing(N int, Moduli []uint64, pool ...structs.BufferPool[*[]uint64]) (r *Ring, err error) {
var bp structs.BufferPool[*[]uint64]
switch len(pool) {
case 0:
case 1:
bp = pool[0]
default:
return nil, fmt.Errorf("cannot create new ring: more than 1 buffer pools provided")
}
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N, bp)
}
// NewRingConjugateInvariant creates a new RNS Ring with degree N and coefficient moduli Moduli with Conjugate Invariant NTT. N must be a power of two larger than 8. Moduli should be
// a non-empty []uint64 with distinct prime elements. All moduli must also be equal to 1 modulo 4*N.
// An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRingConjugateInvariant(N int, Moduli []uint64) (r *Ring, err error) {
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerConjugateInvariant, 4*N)
func NewRingConjugateInvariant(N int, Moduli []uint64, pool ...structs.BufferPool[*[]uint64]) (r *Ring, err error) {
var bp structs.BufferPool[*[]uint64]
switch len(pool) {
case 0:
case 1:
bp = pool[0]
default:
return nil, fmt.Errorf("cannot create new ring: more than 1 buffer pools provided")
}
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerConjugateInvariant, 4*N, bp)
}
// NewRingFromType creates a new RNS Ring with degree N and coefficient moduli Moduli for which the type of NTT is determined by the ringType argument.
@@ -257,12 +282,12 @@ func NewRingConjugateInvariant(N int, Moduli []uint64) (r *Ring, err error) {
// is instantiated with a ConjugateInvariant NTT with Nth root of unity 4*N. N must be a power of two larger than 8.
// Moduli should be a non-empty []uint64 with distinct prime elements. All moduli must also be equal to 1 modulo the root of unity.
// An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRingFromType(N int, Moduli []uint64, ringType Type) (r *Ring, err error) {
func NewRingFromType(N int, Moduli []uint64, ringType Type, pool structs.BufferPool[*[]uint64]) (r *Ring, err error) {
switch ringType {
case Standard:
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N)
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N, pool)
case ConjugateInvariant:
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerConjugateInvariant, 4*N)
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerConjugateInvariant, 4*N, pool)
default:
return nil, fmt.Errorf("invalid ring type")
}
@@ -272,7 +297,7 @@ func NewRingFromType(N int, Moduli []uint64, ringType Type) (r *Ring, err error)
// ModuliChain should be a non-empty []uint64 with distinct prime elements.
// All moduli must also be equal to 1 modulo the root of unity.
// N must be a power of two larger than 8. An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRingWithCustomNTT(N int, ModuliChain []uint64, ntt func(*SubRing, int) NumberTheoreticTransformer, NthRoot int) (r *Ring, err error) {
func NewRingWithCustomNTT(N int, ModuliChain []uint64, ntt func(*SubRing, int) NumberTheoreticTransformer, NthRoot int, pool structs.BufferPool[*[]uint64]) (r *Ring, err error) {
r = new(Ring)
// Checks if N is a power of 2
@@ -307,6 +332,8 @@ func NewRingWithCustomNTT(N int, ModuliChain []uint64, ntt func(*SubRing, int) N
r.level = len(ModuliChain) - 1
r.bufferPool = pool
return r, r.generateNTTConstants(nil, nil)
}
@@ -359,6 +386,14 @@ func (r Ring) NewPoly() Poly {
return NewPoly(r.N(), r.level)
}
func (r Ring) NewPolyFromUintPool() (p *Poly) {
return NewPolyFromUintPool(r.bufferPool, r.level)
}
func (r Ring) RecyclePolyInUintPool(pol *Poly) {
RecyclePolyInUintPool(r.bufferPool, pol)
}
// NewMonomialXi returns a polynomial X^{i}.
func (r Ring) NewMonomialXi(i int) (p Poly) {

View File

@@ -35,10 +35,14 @@ func genTestParams(defaultParams Parameters) (tc *testParams, err error) {
tc = new(testParams)
if tc.ringQ, err = NewRing(1<<defaultParams.logN, defaultParams.qi); err != nil {
pool := structs.NewSyncPool(func() *[]uint64 {
buff := make([]uint64, 1<<defaultParams.logN)
return &buff
})
if tc.ringQ, err = NewRing(1<<defaultParams.logN, defaultParams.qi, pool); err != nil {
return nil, err
}
if tc.ringP, err = NewRing(1<<defaultParams.logN, defaultParams.pi); err != nil {
if tc.ringP, err = NewRing(1<<defaultParams.logN, defaultParams.pi, pool); err != nil {
return nil, err
}
if tc.prng, err = sampling.NewPRNG(); err != nil {
@@ -90,7 +94,7 @@ func testNTTConjugateInvariant(tc *testParams, t *testing.T) {
Q := ringQ.ModuliChain()
N := ringQ.N()
ringQ2N, _ := NewRing(N<<1, Q)
ringQConjugateInvariant, _ := NewRingFromType(N, Q, ConjugateInvariant)
ringQConjugateInvariant, _ := NewRingFromType(N, Q, ConjugateInvariant, nil)
sampler := NewUniformSampler(tc.prng, ringQ)
p1 := sampler.ReadNew()

View File

@@ -7,7 +7,6 @@ import (
"github.com/tuneinsight/lattigo/v6/ring"
"github.com/tuneinsight/lattigo/v6/utils/bignum"
"github.com/tuneinsight/lattigo/v6/utils/structs"
)
// Ring is a structure that implements the operation in the ring R_QP.
@@ -221,18 +220,19 @@ func (r Ring) NewPoly() Poly {
return Poly{Q, P}
}
func (r Ring) NewPolyQPFromUintPool(pool structs.BufferPool[*[]uint64]) *Poly {
// NewPolyQPFromUintPool creates a new polynomial using the *[]uint64 BufferPool for backing arrays.
func (r Ring) NewPolyQPFromUintPool() *Poly {
var Q, P *ring.Poly
if r.RingQ != nil {
Q = ring.NewPolyFromUintPool(pool, r.RingQ.N(), r.RingQ.Level())
Q = r.RingQ.NewPolyFromUintPool()
}
if r.RingP != nil {
P = ring.NewPolyFromUintPool(pool, r.RingP.N(), r.RingP.Level())
P = r.RingP.NewPolyFromUintPool()
}
return &Poly{*Q, *P}
}
func RecyclePolyQPFromUintPool(pool structs.BufferPool[*[]uint64], poly *Poly) {
ring.RecyclePolyInUintPool(pool, &poly.Q)
ring.RecyclePolyInUintPool(pool, &poly.P)
func (r Ring) RecyclePolyQPFromUintPool(poly *Poly) {
r.RingQ.RecyclePolyInUintPool(&poly.Q)
r.RingP.RecyclePolyInUintPool(&poly.P)
}

View File

@@ -110,10 +110,17 @@ func NewEncoder(parameters Parameters, precision ...uint) (ecd *Encoder) {
buff := make([]*big.Int, m>>1)
return &buff
})
ecd.BuffPolyPool = structs.NewSyncPool(func() *ring.Poly {
poly := parameters.RingQ().NewPoly()
return &poly
})
ringQ := parameters.RingQ()
ecd.BuffPolyPool = structs.NewBuffFromUintPool(
func() *ring.Poly {
return ringQ.NewPolyFromUintPool()
},
func(poly *ring.Poly) {
ringQ.RecyclePolyInUintPool(poly)
},
)
if prec <= 53 {

View File

@@ -29,25 +29,23 @@ func (spool *SyncPool[T]) Put(buff T) {
}
type BuffFromUintPool[T any] struct {
uintPool BufferPool[*[]uint64] // Pool that must contain *[]uint64 objects
createObject func(BufferPool[*[]uint64]) T
recycleObject func(BufferPool[*[]uint64], T)
createObject func() T
recycleObject func(T)
}
func NewBuffFromUintPool[T any](pool BufferPool[*[]uint64], create func(BufferPool[*[]uint64]) T, recycle func(BufferPool[*[]uint64], T)) *BuffFromUintPool[T] {
func NewBuffFromUintPool[T any](create func() T, recycle func(T)) *BuffFromUintPool[T] {
return &BuffFromUintPool[T]{
uintPool: pool,
createObject: create,
recycleObject: recycle,
}
}
func (bu *BuffFromUintPool[T]) Get() T {
return bu.createObject(bu.uintPool)
return bu.createObject()
}
func (bu *BuffFromUintPool[T]) Put(obj T) {
bu.recycleObject(bu.uintPool, obj)
bu.recycleObject(obj)
}
type FreeList[T any] struct {