Files
lattigo/ckks/plaintext.go
Jean-Philippe Bossuat 4421f325b0 Initial commit
2019-08-12 21:15:01 +02:00

359 lines
9.5 KiB
Go

package ckks
import (
"errors"
"github.com/lca1/lattigo/ring"
"math/bits"
"math/cmplx"
)
// Plaintext is the plaintext element of the ckks package. It is declared with
// BigPoly as its underlying type, so it has the same fields; the ones used in
// this file are ckkscontext, value, scale, currentModulus and isNTT.
// As a distinct defined type it carries its own method set, declared below.
type Plaintext BigPoly
// NewPlaintext allocates a plaintext bound to this context. Its value holds a
// single polynomial created by the ring context of the requested level, its
// scale is stored as given, its current modulus is a copy of that level's
// ModulusBigint, and it is flagged as being in the NTT domain.
func (ckkscontext *CkksContext) NewPlaintext(level uint64, scale uint64) *Plaintext {
	levelContext := ckkscontext.contextLevel[level]

	pt := new(Plaintext)
	pt.ckkscontext = ckkscontext
	pt.scale = scale
	pt.isNTT = true
	pt.value = []*ring.Poly{levelContext.NewPoly()}
	pt.currentModulus = ring.Copy(levelContext.ModulusBigint)
	return pt
}
// Value returns the slice of polynomials backing the plaintext. The slice is
// handed out directly, without copying.
func (p *Plaintext) Value() []*ring.Poly {
	return p.value
}
// SetValue replaces the slice of polynomials backing the plaintext with the
// provided one. The slice is stored as-is, without copying.
func (p *Plaintext) SetValue(value []*ring.Poly) {
	p.value = value
}
// Resize is a no-op on a plaintext: the degree argument is ignored and the
// plaintext is left untouched.
func (p *Plaintext) Resize(degree uint64) {
	// Nothing to resize.
}
// CkksContext returns the ckkscontext the plaintext is currently bound to.
func (p *Plaintext) CkksContext() *CkksContext {
	return p.ckkscontext
}
// SetCkksContext binds the plaintext to the provided ckkscontext. Only the
// reference is replaced; the value and parameters are left unchanged.
func (p *Plaintext) SetCkksContext(ckkscontext *CkksContext) {
	p.ckkscontext = ckkscontext
}
// CurrentModulus returns the current modulus stored in the plaintext.
// In this file the value is consumed by the decoding routines.
func (p *Plaintext) CurrentModulus() *ring.Int {
	return p.currentModulus
}
// SetCurrentModulus stores a copy (made with ring.Copy) of the provided
// modulus as the plaintext's current modulus.
// In this file the value is consumed by the decoding routines.
func (p *Plaintext) SetCurrentModulus(modulus *ring.Int) {
	p.currentModulus = ring.Copy(modulus)
}
// Degree returns the number of polynomials held by the plaintext minus one.
// NewPlaintext and CopyNew both build a single-polynomial value, for which
// the degree is zero.
func (p *Plaintext) Degree() uint64 {
	polyCount := len(p.value)
	return uint64(polyCount - 1)
}
// Level returns the current level of the plaintext, computed as the number of
// coefficient rows (Coeffs) of its first polynomial minus one.
func (p *Plaintext) Level() uint64 {
	rows := len(p.value[0].Coeffs)
	return uint64(rows - 1)
}
// Scale returns the scale currently stored in the plaintext.
// NOTE(review): the scale is described as a log2 value; confirm against
// scaleUp/scaleDown.
func (p *Plaintext) Scale() uint64 {
	return p.scale
}
// SetScale overwrites the scale stored in the plaintext with the provided
// value.
// NOTE(review): the scale is described as a log2 value; confirm against
// scaleUp/scaleDown.
func (p *Plaintext) SetScale(scale uint64) {
	p.scale = scale
}
// IsNTT reports whether the plaintext is flagged as being in the NTT domain.
func (p *Plaintext) IsNTT() bool {
	return p.isNTT
}
// SetIsNTT overwrites the plaintext's NTT-domain flag with the provided value.
// Only the flag changes; the polynomial itself is not transformed.
func (p *Plaintext) SetIsNTT(isNTT bool) {
	p.isNTT = isNTT
}
// NTT applies the forward NTT to the polynomials of the plaintext, using the
// ring context of the plaintext's level, and writes each result into the
// polynomial of ct0 at the same index. ct0 is then flagged as being in the
// NTT domain. The call does nothing when the plaintext is already flagged as
// NTT. The number of transformed polynomials is driven by ct0's value.
func (p *Plaintext) NTT(ct0 CkksElement) {
	if p.isNTT {
		return
	}
	out := ct0.Value()
	for i := range out {
		p.ckkscontext.contextLevel[p.Level()].NTT(p.value[i], out[i])
	}
	ct0.SetIsNTT(true)
}
// InvNTT applies the inverse NTT to the polynomials of the plaintext, using
// the ring context of the plaintext's level, and writes each result into the
// polynomial of ct0 at the same index. ct0 is then flagged as not being in
// the NTT domain. The call does nothing when the plaintext is not flagged as
// NTT. The number of transformed polynomials is driven by ct0's value.
func (p *Plaintext) InvNTT(ct0 CkksElement) {
	if !p.isNTT {
		return
	}
	out := ct0.Value()
	for i := range out {
		p.ckkscontext.contextLevel[p.Level()].InvNTT(p.value[i], out[i])
	}
	ct0.SetIsNTT(false)
}
// CopyNew returns a freshly allocated plaintext that shares the receiver's
// ckkscontext, holds a new copy (CopyNew) of the receiver's first polynomial,
// and receives the receiver's current modulus, scale and NTT flag through
// CopyParams.
func (p *Plaintext) CopyNew() CkksElement {
	clone := new(Plaintext)
	clone.ckkscontext = p.ckkscontext
	clone.value = []*ring.Poly{p.value[0].CopyNew()}
	p.CopyParams(clone)
	return clone
}
// Copy calls Copy on the receiver's first polynomial with PCopy's first
// polynomial as argument, then pushes the receiver's current modulus, scale
// and NTT flag onto PCopy through CopyParams. An error is returned, and
// nothing is copied, when checkContext rejects the pair of elements.
// NOTE(review): this reads as "receiver is the source, PCopy the
// destination" — confirm against ring.Poly.Copy.
func (p *Plaintext) Copy(PCopy CkksElement) error {
	if sameContext := checkContext([]CkksElement{p, PCopy}); !sameContext {
		return errors.New("input ciphertext are not using the same ckkscontext")
	}
	other := PCopy.Value()[0]
	p.value[0].Copy(other)
	p.CopyParams(PCopy)
	return nil
}
// CopyParams pushes the receiver's parameters onto the provided element: its
// current modulus (SetCurrentModulus), its scale (SetScale) and its NTT flag
// (SetIsNTT). The polynomial values are not touched.
func (p *Plaintext) CopyParams(ckkselement CkksElement) {
	ckkselement.SetCurrentModulus(p.currentModulus)
	ckkselement.SetScale(p.scale)
	ckkselement.SetIsNTT(p.isNTT)
}
// EncodeFloat encodes a float64 slice on the plaintext.
//
// Each input value is placed, as a complex number with a zero imaginary part,
// at every gap-th position of an intermediate slice of len(coeffs)*gap
// entries; the positions in between stay zero. The slice is then handed to
// encodeFromComplex, which writes the encoding into the plaintext.
//
// An error is returned, and the plaintext is left untouched, when:
//   - len(coeffs) exceeds (len(indexMatrix)/2)/gap, or
//   - the first coefficient row of the plaintext does not have
//     len(indexMatrix) entries.
func (plaintext *Plaintext) EncodeFloat(coeffs []float64) error {
	ckksCtx := plaintext.ckkscontext
	gap := int(ckksCtx.gap)
	indexLen := len(ckksCtx.indexMatrix)

	if len(coeffs) > (indexLen>>1)/gap {
		return errors.New("error : invalid input to encode (number of coefficients must be smaller or equal to the context)")
	}
	if len(plaintext.value[0].Coeffs[0]) != indexLen {
		return errors.New("error : invalid plaintext to receive encoding (number of coefficients does not match the context of the encoder")
	}

	// make returns a zeroed slice, so only the first entry of each group of
	// gap entries needs to be written; the padding is already complex(0, 0).
	values := make([]complex128, len(coeffs)*gap)
	for i, c := range coeffs {
		values[i*gap] = complex(c, 0)
	}

	encodeFromComplex(values, plaintext)
	return nil
}
// EncodeComplex encodes a complex128 slice on the plaintext.
//
// Each input value is placed at every gap-th position of an intermediate
// slice of len(coeffs)*gap entries; the positions in between stay zero. The
// slice is then handed to encodeFromComplex, which writes the encoding into
// the plaintext.
//
// An error is returned, and the plaintext is left untouched, when:
//   - len(coeffs) exceeds (len(indexMatrix)/2)/gap, or
//   - the first coefficient row of the plaintext does not have
//     len(indexMatrix) entries.
func (plaintext *Plaintext) EncodeComplex(coeffs []complex128) error {
	ckksCtx := plaintext.ckkscontext
	gap := int(ckksCtx.gap)
	indexLen := len(ckksCtx.indexMatrix)

	if len(coeffs) > (indexLen>>1)/gap {
		return errors.New("error : invalid input to encode (number of coefficients must be smaller or equal to the context)")
	}
	if len(plaintext.value[0].Coeffs[0]) != indexLen {
		return errors.New("error : invalid plaintext to receive encoding (number of coefficients does not match the context of the encoder")
	}

	// make returns a zeroed slice, so only the first entry of each group of
	// gap entries needs to be written; the padding is already complex(0, 0).
	values := make([]complex128, len(coeffs)*gap)
	for i, c := range coeffs {
		values[i*gap] = c
	}

	encodeFromComplex(values, plaintext)
	return nil
}
// DecodeFloat decodes the plaintext into a slice of slots/gap float64 values.
// The plaintext's first polynomial is decoded with decodeToComplex, and the
// real part of the entry found at indexMatrix[i*gap] becomes the i-th output.
func (plaintext *Plaintext) DecodeFloat() (res []float64) {
	ckksCtx := plaintext.ckkscontext
	decoded := decodeToComplex(plaintext.value[0], plaintext.currentModulus, ckksCtx.contextLevel[plaintext.Level()], ckksCtx.roots, plaintext.scale)

	gap := int(ckksCtx.gap)
	res = make([]float64, int(ckksCtx.slots)/gap)
	for i := range res {
		slot := ckksCtx.indexMatrix[i*gap]
		res[i] = real(decoded[slot])
	}
	return res
}
// DecodeComplex decodes the plaintext into a slice of slots/gap complex128
// values. The plaintext's first polynomial is decoded with decodeToComplex,
// and the entry found at indexMatrix[i*gap] becomes the i-th output.
func (plaintext *Plaintext) DecodeComplex() (res []complex128) {
	ckksCtx := plaintext.ckkscontext
	decoded := decodeToComplex(plaintext.value[0], plaintext.currentModulus, ckksCtx.contextLevel[plaintext.Level()], ckksCtx.roots, plaintext.scale)

	gap := int(ckksCtx.gap)
	res = make([]complex128, int(ckksCtx.slots)/gap)
	for i := range res {
		slot := ckksCtx.indexMatrix[i*gap]
		res[i] = decoded[slot]
	}
	return res
}
// encodeFromComplex writes the encoding of coeffs into the first polynomial
// of plaintext.
//
// Steps:
//  1. Build a slice of n complex values in which coeffs[i] is stored at
//     indexMatrix[i] and its complex conjugate at indexMatrix[i+slots].
//  2. Run invfft on that slice with the context's inv_roots.
//  3. For each modulus qi, divide the real part of every entry by n and store
//     scaleUp(value, scale, qi) in the matching coefficient row; exact zeros
//     are stored as 0 without calling scaleUp.
//  4. Apply the NTT of the plaintext's level to the polynomial, in place.
//
// Only the coefficient rows actually present in the plaintext are written: a
// plaintext with fewer rows than the context has moduli would otherwise be
// indexed out of range.
func encodeFromComplex(coeffs []complex128, plaintext *Plaintext) {
	ckksCtx := plaintext.ckkscontext
	n := ckksCtx.n
	slots := int(ckksCtx.slots)

	values := make([]complex128, n)
	for i, c := range coeffs {
		values[ckksCtx.indexMatrix[i]] = c
		values[ckksCtx.indexMatrix[i+slots]] = cmplx.Conj(c)
	}

	invfft(values, ckksCtx.inv_roots)

	rows := plaintext.value[0].Coeffs
	for i, qi := range ckksCtx.modulie {
		// Stop once the plaintext has no coefficient row left for this modulus.
		if i >= len(rows) {
			break
		}
		for j := uint64(0); j < n; j++ {
			scaled := real(values[j]) / float64(n)
			if scaled == 0 {
				rows[i][j] = 0
				continue
			}
			rows[i][j] = scaleUp(scaled, plaintext.scale, qi)
		}
	}

	ckksCtx.contextLevel[plaintext.Level()].NTT(plaintext.value[0], plaintext.value[0])
}
// decodeToComplex turns a polynomial into a slice of context.N complex values.
//
// The polynomial is first taken out of the NTT domain into a temporary
// polynomial and converted to big integers with PolyToBigint. Each
// coefficient is reduced modulo Q and, when it is greater than or equal to
// Q/2 (Q shifted right by one bit), has Q subtracted from it so that it is
// centered around zero. The result of scaleDown(coefficient, scale) becomes
// the real part of the output entry. Finally fft is run on the output slice
// with the provided roots.
func decodeToComplex(pol *ring.Poly, Q *ring.Int, context *ring.Context, roots []complex128, scale uint64) (values []complex128) {
	coeffPoly := context.NewPoly()
	context.InvNTT(pol, coeffPoly)
	bigCoeffs := context.PolyToBigint(coeffPoly)

	halfQ := new(ring.Int)
	halfQ.SetBigInt(Q)
	halfQ.Rsh(halfQ, 1)

	values = make([]complex128, context.N)
	for i := range bigCoeffs {
		// Center the coefficient around zero with respect to the current modulus.
		bigCoeffs[i].Mod(bigCoeffs[i], Q)
		if cmp := bigCoeffs[i].Compare(halfQ); cmp == 0 || cmp == 1 {
			bigCoeffs[i].Sub(bigCoeffs[i], Q)
		}
		values[i] = complex(scaleDown(bigCoeffs[i], scale), 0)
	}

	fft(values, roots)
	return values
}
// invfft runs, in place on values, the butterfly network used by the encoder.
//
// The number of stages is floor(log2(len(values))). At stage i the slice is
// processed in h = 2^(logN-i-1) groups; group j pairs values[k] with
// values[k+t] (t doubles at every stage, starting at 1) and replaces them
// with u+v and (u-v)*inv_roots[h+j]. No final scaling is applied here.
//
// inv_roots must be indexable up to len(values)-1. An empty values slice is
// returned untouched: without that guard the stage count would wrap around
// to 2^64-1 and the loop would effectively never terminate.
func invfft(values, inv_roots []complex128) {
	if len(values) == 0 {
		return
	}

	logN := uint64(bits.Len64(uint64(len(values))) - 1)
	t := uint64(1)

	for i := uint64(0); i < logN; i++ {
		h := uint64(1) << (logN - i - 1)
		start := uint64(0)
		for j := uint64(0); j < h; j++ {
			psi := inv_roots[h+j]
			for k := start; k < start+t; k++ {
				u, v := values[k], values[k+t]
				values[k] = u + v
				values[k+t] = (u - v) * psi
			}
			start += t << 1
		}
		t <<= 1
	}
}
// fft runs, in place on values, the butterfly network used by the decoder.
//
// The number of stages is floor(log2(len(values))). At stage i the slice is
// split into mm = 2^i groups of 2*t consecutive entries (t halves at every
// stage, starting at len(values)/2); within group j, values[k] and
// values[k+t] are replaced with u+v and u-v, where v is values[k+t]
// multiplied by roots[mm+j].
//
// roots must be indexable up to len(values)-1. An empty values slice is
// returned untouched: without that guard the stage count would wrap around
// to 2^64-1 and roots[1] would be read.
func fft(values, roots []complex128) {
	if len(values) == 0 {
		return
	}

	span := uint64(len(values))
	logN := uint64(bits.Len64(span) - 1)

	for i := uint64(0); i < logN; i++ {
		groups := uint64(1) << i
		span >>= 1
		for j := uint64(0); j < groups; j++ {
			first := 2 * j * span
			psi := roots[groups+j]
			for k := first; k < first+span; k++ {
				u := values[k]
				v := values[k+span] * psi
				values[k] = u + v
				values[k+span] = u - v
			}
		}
	}
}