From 89af606619a6c5e8eac336769ffdafc955ab1561 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Fri, 24 Jan 2025 15:43:09 +0100 Subject: [PATCH] feat: make samplers thread-safe + thread safe prng in encryptor --- ring/sampler.go | 11 ---------- ring/sampler_gaussian.go | 44 ++++++++++++++++++++-------------------- ring/sampler_uniform.go | 23 +++++++-------------- utils/sampling/prng.go | 12 +++++++++++ 4 files changed, 41 insertions(+), 49 deletions(-) diff --git a/ring/sampler.go b/ring/sampler.go index 353f0dc7..ee6c3b7d 100644 --- a/ring/sampler.go +++ b/ring/sampler.go @@ -79,17 +79,6 @@ type baseSampler struct { baseRing *Ring } -type randomBuffer struct { - randomBufferN []byte - ptr int -} - -func newRandomBuffer() *randomBuffer { - return &randomBuffer{ - randomBufferN: make([]byte, 1024), - } -} - // AtLevel returns an instance of the target base sampler that operates at the target level. // This instance is not thread safe and cannot be used concurrently to the base instance. func (b baseSampler) AtLevel(level int) *baseSampler { diff --git a/ring/sampler_gaussian.go b/ring/sampler_gaussian.go index 1e186194..e721ec62 100644 --- a/ring/sampler_gaussian.go +++ b/ring/sampler_gaussian.go @@ -16,7 +16,6 @@ const ( // GaussianSampler keeps the state of a truncated Gaussian polynomial sampler. type GaussianSampler struct { *baseSampler - *randomBuffer xe DiscreteGaussian montgomery bool } @@ -28,7 +27,6 @@ func NewGaussianSampler(prng sampling.PRNG, baseRing *Ring, X DiscreteGaussian, g = new(GaussianSampler) g.baseSampler = &baseSampler{} g.prng = prng - g.randomBuffer = newRandomBuffer() g.baseRing = baseRing g.xe = X g.montgomery = montgomery @@ -39,10 +37,9 @@ func NewGaussianSampler(prng sampling.PRNG, baseRing *Ring, X DiscreteGaussian, // This instance is not thread safe and cannot be used concurrently to the base instance. func (g *GaussianSampler) AtLevel(level int) Sampler { return &GaussianSampler{ - baseSampler: g.baseSampler.AtLevel(level), - randomBuffer: g.randomBuffer, - xe: g.xe, - montgomery: g.montgomery, + baseSampler: g.baseSampler.AtLevel(level), + xe: g.xe, + montgomery: g.montgomery, } } @@ -74,9 +71,12 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) { r := g.baseRing + randomBufferN := make([]byte, 4*r.N()) + var ptr int + level := r.level - if _, err := g.prng.Read(g.randomBufferN); err != nil { + if _, err := g.prng.Read(randomBufferN); err != nil { // Sanity check, this error should not happen. panic(err) } @@ -120,7 +120,7 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) { for { // Sample norm with sigma = 1 and sign - norm, sign = g.normFloat64() + norm, sign = g.normFloat64(randomBufferN, &ptr) // Sets normFlo = norm * sigma with precision 53 bits // and 0.5 for rounding discretization @@ -158,7 +158,7 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) { for i := 0; i < N; i++ { for { - norm, sign = g.normFloat64() + norm, sign = g.normFloat64(randomBufferN, &ptr) if v := norm * sigma; v <= bound { coeffInt = uint64(v + 0.5) // rounding @@ -187,34 +187,34 @@ func (g *GaussianSampler) read(pol Poly, f func(a, b, c uint64) uint64) { // // Algorithm adapted from https://golang.org/src/math/rand/normal.go // to use a secure PRNG instead of math/rand. -func (g *GaussianSampler) normFloat64() (float64, uint64) { +func (g *GaussianSampler) normFloat64(buffer []byte, ptr *int) (float64, uint64) { - ptr := g.ptr - buff := g.randomBufferN + currPtr := *ptr + buff := buffer prng := g.prng buffLen := len(buff) read := func() { - if ptr == buffLen { - if _, err := prng.Read(buff); err != nil { + if currPtr == buffLen { + if _, err := prng.Read(buff[:]); err != nil { // Sanity check, this error should not happen. panic(err) } - ptr = 0 + currPtr = 0 } } randU32 := func() (x uint32) { read() - x = binary.LittleEndian.Uint32(buff[ptr : ptr+4]) - ptr += 8 // Avoids buffer misalignment + x = binary.LittleEndian.Uint32(buff[currPtr : currPtr+4]) + currPtr += 8 // Avoids buffer misalignment return } randF64 := func() (x float64) { read() - x = float64(binary.LittleEndian.Uint64(buff[ptr:ptr+8])&0x1fffffffffffff) / float64(0x1fffffffffffff) - ptr += 8 + x = float64(binary.LittleEndian.Uint64(buff[currPtr:currPtr+8])&0x1fffffffffffff) / float64(0x1fffffffffffff) + currPtr += 8 return } @@ -231,7 +231,7 @@ func (g *GaussianSampler) normFloat64() (float64, uint64) { // 1 (>99%) if uint32(j) < kn[i] { - g.ptr = ptr + *ptr = currPtr return x, sign } @@ -249,13 +249,13 @@ func (g *GaussianSampler) normFloat64() (float64, uint64) { } } - g.ptr = ptr + *ptr = currPtr return x + rn, sign } // 3 if fn[i]+float32(randF64())*(fn[i-1]-fn[i]) < float32(math.Exp(-0.5*x*x)) { - g.ptr = ptr + *ptr = currPtr return x, sign } } diff --git a/ring/sampler_uniform.go b/ring/sampler_uniform.go index 3837ea03..b4df1523 100644 --- a/ring/sampler_uniform.go +++ b/ring/sampler_uniform.go @@ -9,7 +9,6 @@ import ( // UniformSampler wraps a util.PRNG and represents the state of a sampler of uniform polynomials. type UniformSampler struct { *baseSampler - *randomBuffer } // NewUniformSampler creates a new instance of UniformSampler from a PRNG and ring definition. @@ -18,7 +17,6 @@ func NewUniformSampler(prng sampling.PRNG, baseRing *Ring) (u *UniformSampler) { u.baseSampler = &baseSampler{} u.baseRing = baseRing u.prng = prng - u.randomBuffer = newRandomBuffer() return } @@ -26,8 +24,7 @@ func NewUniformSampler(prng sampling.PRNG, baseRing *Ring) (u *UniformSampler) { // The returned sampler cannot be used concurrently to the original sampler. func (u *UniformSampler) AtLevel(level int) Sampler { return &UniformSampler{ - baseSampler: u.baseSampler.AtLevel(level), - randomBuffer: u.randomBuffer, + baseSampler: u.baseSampler.AtLevel(level), } } @@ -48,22 +45,18 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) { level := u.baseRing.Level() var randomUint, mask, qi uint64 + var buffer [1024]byte prng := u.prng N := u.baseRing.N() - byteArrayLength := len(u.randomBufferN) + byteArrayLength := len(buffer) var ptr int - if ptr = u.ptr; ptr == 0 || ptr == byteArrayLength { - if _, err := prng.Read(u.randomBufferN); err != nil { - // Sanity check, this error should not happen. - panic(err) - } - ptr = 0 // for the case where ptr == byteArrayLength + if _, err := prng.Read(buffer[:]); err != nil { + // Sanity check, this error should not happen. + panic(err) } - buffer := u.randomBufferN - for j := 0; j < level+1; j++ { qi = u.baseRing.SubRings[j].Modulus @@ -81,7 +74,7 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) { // Refills the buff if it runs empty if ptr == byteArrayLength { - if _, err := u.prng.Read(buffer); err != nil { + if _, err := u.prng.Read(buffer[:]); err != nil { // Sanity check, this error should not happen. panic(err) } @@ -102,7 +95,6 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) { } } - u.ptr = ptr } // ReadNew generates a new polynomial with coefficients following a uniform distribution over [0, Qi-1]. @@ -119,7 +111,6 @@ func (u *UniformSampler) WithPRNG(prng sampling.PRNG) *UniformSampler { prng: prng, baseRing: u.baseRing, }, - randomBuffer: newRandomBuffer(), } } diff --git a/utils/sampling/prng.go b/utils/sampling/prng.go index 992048fe..435077f7 100644 --- a/utils/sampling/prng.go +++ b/utils/sampling/prng.go @@ -22,6 +22,18 @@ type KeyedPRNG struct { xof blake2b.XOF } +type ThreadSafePRNG struct { +} + +// Read reads bytes from the KeyedPRNG on sum. +func (prng *ThreadSafePRNG) Read(sum []byte) (n int, err error) { + tmpPRNG, err := NewPRNG() + if err != nil { + return 0, fmt.Errorf("crypto rand error: %w", err) + } + return tmpPRNG.Read(sum) +} + // NewKeyedPRNG creates a new instance of KeyedPRNG. // Accepts an optional key, else set key=nil which is treated as key=[]byte{} // WARNING: A PRNG INITIALISED WITH key=nil IS INSECURE!