feat: make samplers thread-safe + thread safe prng in encryptor

This commit is contained in:
lehugueni
2025-01-24 15:43:09 +01:00
parent bad71e5091
commit 89af606619
4 changed files with 41 additions and 49 deletions

View File

@@ -79,17 +79,6 @@ type baseSampler struct {
baseRing *Ring
}
type randomBuffer struct {
randomBufferN []byte
ptr int
}
func newRandomBuffer() *randomBuffer {
return &randomBuffer{
randomBufferN: make([]byte, 1024),
}
}
// AtLevel returns an instance of the target base sampler that operates at the target level.
// This instance is not thread safe and cannot be used concurrently to the base instance.
func (b baseSampler) AtLevel(level int) *baseSampler {

View File

@@ -16,7 +16,6 @@ const (
// GaussianSampler keeps the state of a truncated Gaussian polynomial sampler.
type GaussianSampler struct {
*baseSampler
*randomBuffer
xe DiscreteGaussian
montgomery bool
}
@@ -28,7 +27,6 @@ func NewGaussianSampler(prng sampling.PRNG, baseRing *Ring, X DiscreteGaussian,
g = new(GaussianSampler)
g.baseSampler = &baseSampler{}
g.prng = prng
g.randomBuffer = newRandomBuffer()
g.baseRing = baseRing
g.xe = X
g.montgomery = montgomery
@@ -39,10 +37,9 @@ func NewGaussianSampler(prng sampling.PRNG, baseRing *Ring, X DiscreteGaussian,
// This instance is not thread safe and cannot be used concurrently to the base instance.
func (g *GaussianSampler) AtLevel(level int) Sampler {
return &GaussianSampler{
baseSampler: g.baseSampler.AtLevel(level),
randomBuffer: g.randomBuffer,
xe: g.xe,
montgomery: g.montgomery,
baseSampler: g.baseSampler.AtLevel(level),
xe: g.xe,
montgomery: g.montgomery,
}
}
@@ -74,9 +71,12 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
r := g.baseRing
randomBufferN := make([]byte, 4*r.N())
var ptr int
level := r.level
if _, err := g.prng.Read(g.randomBufferN); err != nil {
if _, err := g.prng.Read(randomBufferN); err != nil {
// Sanity check, this error should not happen.
panic(err)
}
@@ -120,7 +120,7 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
for {
// Sample norm with sigma = 1 and sign
norm, sign = g.normFloat64()
norm, sign = g.normFloat64(randomBufferN, &ptr)
// Sets normFlo = norm * sigma with precision 53 bits
// and 0.5 for rounding discretization
@@ -158,7 +158,7 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
for i := 0; i < N; i++ {
for {
norm, sign = g.normFloat64()
norm, sign = g.normFloat64(randomBufferN, &ptr)
if v := norm * sigma; v <= bound {
coeffInt = uint64(v + 0.5) // rounding
@@ -187,34 +187,34 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
//
// Algorithm adapted from https://golang.org/src/math/rand/normal.go
// to use a secure PRNG instead of math/rand.
func (g *GaussianSampler) normFloat64() (float64, uint64) {
func (g *GaussianSampler) normFloat64(buffer []byte, ptr *int) (float64, uint64) {
ptr := g.ptr
buff := g.randomBufferN
currPtr := *ptr
buff := buffer
prng := g.prng
buffLen := len(buff)
read := func() {
if ptr == buffLen {
if _, err := prng.Read(buff); err != nil {
if currPtr == buffLen {
if _, err := prng.Read(buff[:]); err != nil {
// Sanity check, this error should not happen.
panic(err)
}
ptr = 0
currPtr = 0
}
}
randU32 := func() (x uint32) {
read()
x = binary.LittleEndian.Uint32(buff[ptr : ptr+4])
ptr += 8 // Avoids buffer misalignment
x = binary.LittleEndian.Uint32(buff[currPtr : currPtr+4])
currPtr += 8 // Avoids buffer misalignment
return
}
randF64 := func() (x float64) {
read()
x = float64(binary.LittleEndian.Uint64(buff[ptr:ptr+8])&0x1fffffffffffff) / float64(0x1fffffffffffff)
ptr += 8
x = float64(binary.LittleEndian.Uint64(buff[currPtr:currPtr+8])&0x1fffffffffffff) / float64(0x1fffffffffffff)
currPtr += 8
return
}
@@ -231,7 +231,7 @@ func (g *GaussianSampler) normFloat64() (float64, uint64) {
// 1 (>99%)
if uint32(j) < kn[i] {
g.ptr = ptr
*ptr = currPtr
return x, sign
}
@@ -249,13 +249,13 @@ func (g *GaussianSampler) normFloat64() (float64, uint64) {
}
}
g.ptr = ptr
*ptr = currPtr
return x + rn, sign
}
// 3
if fn[i]+float32(randF64())*(fn[i-1]-fn[i]) < float32(math.Exp(-0.5*x*x)) {
g.ptr = ptr
*ptr = currPtr
return x, sign
}
}

View File

@@ -9,7 +9,6 @@ import (
// UniformSampler wraps a util.PRNG and represents the state of a sampler of uniform polynomials.
type UniformSampler struct {
*baseSampler
*randomBuffer
}
// NewUniformSampler creates a new instance of UniformSampler from a PRNG and ring definition.
@@ -18,7 +17,6 @@ func NewUniformSampler(prng sampling.PRNG, baseRing *Ring) (u *UniformSampler) {
u.baseSampler = &baseSampler{}
u.baseRing = baseRing
u.prng = prng
u.randomBuffer = newRandomBuffer()
return
}
@@ -26,8 +24,7 @@ func NewUniformSampler(prng sampling.PRNG, baseRing *Ring) (u *UniformSampler) {
// The returned sampler cannot be used concurrently to the original sampler.
func (u *UniformSampler) AtLevel(level int) Sampler {
return &UniformSampler{
baseSampler: u.baseSampler.AtLevel(level),
randomBuffer: u.randomBuffer,
baseSampler: u.baseSampler.AtLevel(level),
}
}
@@ -48,22 +45,18 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
level := u.baseRing.Level()
var randomUint, mask, qi uint64
var buffer [1024]byte
prng := u.prng
N := u.baseRing.N()
byteArrayLength := len(u.randomBufferN)
byteArrayLength := len(buffer)
var ptr int
if ptr = u.ptr; ptr == 0 || ptr == byteArrayLength {
if _, err := prng.Read(u.randomBufferN); err != nil {
// Sanity check, this error should not happen.
panic(err)
}
ptr = 0 // for the case where ptr == byteArrayLength
if _, err := prng.Read(buffer[:]); err != nil {
// Sanity check, this error should not happen.
panic(err)
}
buffer := u.randomBufferN
for j := 0; j < level+1; j++ {
qi = u.baseRing.SubRings[j].Modulus
@@ -81,7 +74,7 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
// Refills the buff if it runs empty
if ptr == byteArrayLength {
if _, err := u.prng.Read(buffer); err != nil {
if _, err := u.prng.Read(buffer[:]); err != nil {
// Sanity check, this error should not happen.
panic(err)
}
@@ -102,7 +95,6 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
}
}
u.ptr = ptr
}
// ReadNew generates a new polynomial with coefficients following a uniform distribution over [0, Qi-1].
@@ -119,7 +111,6 @@ func (u *UniformSampler) WithPRNG(prng sampling.PRNG) *UniformSampler {
prng: prng,
baseRing: u.baseRing,
},
randomBuffer: newRandomBuffer(),
}
}

View File

@@ -22,6 +22,18 @@ type KeyedPRNG struct {
xof blake2b.XOF
}
type ThreadSafePRNG struct {
}
// Read reads bytes from the KeyedPRNG on sum.
func (prng *ThreadSafePRNG) Read(sum []byte) (n int, err error) {
tmpPRNG, err := NewPRNG()
if err != nil {
return 0, fmt.Errorf("crypto rand error: %w", err)
}
return tmpPRNG.Read(sum)
}
// NewKeyedPRNG creates a new instance of KeyedPRNG.
// Accepts an optional key, else set key=nil which is treated as key=[]byte{}
// WARNING: A PRNG INITIALISED WITH key=nil IS INSECURE!