Files
lattigo/bfv/encoder.go
2020-10-09 12:30:31 +02:00

228 lines
6.3 KiB
Go

// Package bfv implements a RNS-accelerated Fan-Vercauteren version of Brakerski's scale invariant homomorphic encryption scheme. It provides modular arithmetic over the integers.
package bfv
import (
"math/big"
"unsafe"
"github.com/ldsec/lattigo/v2/ring"
"github.com/ldsec/lattigo/v2/utils"
)
// GaloisGen is an integer of order N/2 modulo M and that spans Z_M with the integer -1.
// The j-th ring automorphism takes the root zeta to zeta^(5j).
const GaloisGen uint64 = 5
// Encoder is an interface implementing the encoder.
type Encoder interface {
EncodeUint(coeffs []uint64, plaintext *Plaintext)
EncodeInt(coeffs []int64, plaintext *Plaintext)
DecodeUint(plaintext *Plaintext) (coeffs []uint64)
DecodeInt(plaintext *Plaintext) (coeffs []int64)
}
// Encoder is a structure that stores the parameters to encode values on a plaintext in a SIMD (Single-Instruction Multiple-Data) fashion.
type encoder struct {
params *Parameters
ringQ *ring.Ring
ringT *ring.Ring
indexMatrix []uint64
scaler ring.Scaler
polypool *ring.Poly
deltaMont []uint64
}
// NewEncoder creates a new encoder from the provided parameters.
func NewEncoder(params *Parameters) Encoder {
var ringQ, ringT *ring.Ring
var err error
if ringQ, err = ring.NewRing(params.N(), params.qi); err != nil {
panic(err)
}
if ringT, err = ring.NewRing(params.N(), []uint64{params.t}); err != nil {
panic(err)
}
var m, pos, index1, index2 uint64
slots := params.N()
indexMatrix := make([]uint64, slots)
logN := params.LogN()
rowSize := params.N() >> 1
m = (params.N() << 1)
pos = 1
for i := uint64(0); i < rowSize; i++ {
index1 = (pos - 1) >> 1
index2 = (m - pos - 1) >> 1
indexMatrix[i] = utils.BitReverse64(index1, logN)
indexMatrix[i|rowSize] = utils.BitReverse64(index2, logN)
pos *= GaloisGen
pos &= (m - 1)
}
return &encoder{
params: params.Copy(),
ringQ: ringQ,
ringT: ringT,
indexMatrix: indexMatrix,
deltaMont: GenLiftParams(ringQ, params.t),
scaler: ring.NewRNSScaler(params.t, ringQ),
polypool: ringT.NewPoly(),
}
}
// GenLiftParams generates the lifting parameters.
func GenLiftParams(ringQ *ring.Ring, t uint64) (deltaMont []uint64) {
delta := new(big.Int).Quo(ringQ.ModulusBigint, ring.NewUint(t))
deltaMont = make([]uint64, len(ringQ.Modulus))
tmp := new(big.Int)
bredParams := ringQ.GetBredParams()
for i, Qi := range ringQ.Modulus {
deltaMont[i] = tmp.Mod(delta, ring.NewUint(Qi)).Uint64()
deltaMont[i] = ring.MForm(deltaMont[i], Qi, bredParams[i])
}
return
}
// EncodeUint encodes an uint64 slice of size at most N on a plaintext.
func (encoder *encoder) EncodeUint(coeffs []uint64, plaintext *Plaintext) {
if len(coeffs) > len(encoder.indexMatrix) {
panic("invalid input to encode: number of coefficients must be smaller or equal to the ring degree")
}
if len(plaintext.value.Coeffs[0]) != len(encoder.indexMatrix) {
panic("invalid plaintext to receive encoding: number of coefficients does not match the ring degree")
}
for i := 0; i < len(coeffs); i++ {
plaintext.value.Coeffs[0][encoder.indexMatrix[i]] = coeffs[i]
}
for i := len(coeffs); i < len(encoder.indexMatrix); i++ {
plaintext.value.Coeffs[0][encoder.indexMatrix[i]] = 0
}
encoder.encodePlaintext(plaintext)
}
// EncodeInt encodes an int64 slice of size at most N on a plaintext. It also encodes the sign of the given integer (as its inverse modulo the plaintext modulus).
// The sign will correctly decode as long as the absolute value of the coefficient does not exceed half of the plaintext modulus.
func (encoder *encoder) EncodeInt(coeffs []int64, plaintext *Plaintext) {
if len(coeffs) > len(encoder.indexMatrix) {
panic("invalid input to encode: number of coefficients must be smaller or equal to the ring degree")
}
if len(plaintext.value.Coeffs[0]) != len(encoder.indexMatrix) {
panic("invalid plaintext to receive encoding: number of coefficients does not match the ring degree")
}
for i := 0; i < len(coeffs); i++ {
if coeffs[i] < 0 {
plaintext.value.Coeffs[0][encoder.indexMatrix[i]] = uint64(int64(encoder.params.t) + coeffs[i])
} else {
plaintext.value.Coeffs[0][encoder.indexMatrix[i]] = uint64(coeffs[i])
}
}
for i := len(coeffs); i < len(encoder.indexMatrix); i++ {
plaintext.value.Coeffs[0][encoder.indexMatrix[i]] = 0
}
encoder.encodePlaintext(plaintext)
}
func (encoder *encoder) encodePlaintext(p *Plaintext) {
encoder.ringT.InvNTT(p.value, p.value)
ringQ := encoder.ringQ
for i := len(ringQ.Modulus) - 1; i >= 0; i-- {
tmp1 := p.value.Coeffs[i]
tmp2 := p.value.Coeffs[0]
deltaMont := encoder.deltaMont[i]
qi := ringQ.Modulus[i]
bredParams := ringQ.GetMredParams()[i]
for j := uint64(0); j < ringQ.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&tmp2[j]))
z := (*[8]uint64)(unsafe.Pointer(&tmp1[j]))
z[0] = ring.MRed(x[0], deltaMont, qi, bredParams)
z[1] = ring.MRed(x[1], deltaMont, qi, bredParams)
z[2] = ring.MRed(x[2], deltaMont, qi, bredParams)
z[3] = ring.MRed(x[3], deltaMont, qi, bredParams)
z[4] = ring.MRed(x[4], deltaMont, qi, bredParams)
z[5] = ring.MRed(x[5], deltaMont, qi, bredParams)
z[6] = ring.MRed(x[6], deltaMont, qi, bredParams)
z[7] = ring.MRed(x[7], deltaMont, qi, bredParams)
}
}
}
// DecodeUint decodes a batched plaintext and returns the coefficients in a uint64 slice.
func (encoder *encoder) DecodeUint(plaintext *Plaintext) (coeffs []uint64) {
encoder.scaler.DivByQOverTRounded(plaintext.value, encoder.polypool)
encoder.ringT.NTT(encoder.polypool, encoder.polypool)
coeffs = make([]uint64, encoder.ringQ.N)
for i := uint64(0); i < encoder.ringQ.N; i++ {
coeffs[i] = encoder.polypool.Coeffs[0][encoder.indexMatrix[i]]
}
return
}
// DecodeInt decodes a batched plaintext and returns the coefficients in an int64 slice. It also decodes the sign (by centering the values around the plaintext
// modulus).
func (encoder *encoder) DecodeInt(plaintext *Plaintext) (coeffs []int64) {
var value int64
encoder.scaler.DivByQOverTRounded(plaintext.value, encoder.polypool)
encoder.ringT.NTT(encoder.polypool, encoder.polypool)
coeffs = make([]int64, encoder.ringQ.N)
modulus := int64(encoder.params.t)
for i := uint64(0); i < encoder.ringQ.N; i++ {
value = int64(encoder.polypool.Coeffs[0][encoder.indexMatrix[i]])
coeffs[i] = value
if value > modulus>>1 {
coeffs[i] -= modulus
}
}
return coeffs
}