hermes-inspired rlwe ring packing

Co-authored-by: Jean-Philippe Bossuat <jean-philippe@tuneinsight.com>
This commit is contained in:
Andrea Caforio
2024-06-19 13:13:13 +02:00
committed by Romain Bouyé
parent 355afc49ea
commit 2db3225d30
11 changed files with 1840 additions and 897 deletions

View File

@@ -1,18 +1,156 @@
package rlwe
import (
"fmt"
"math/big"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/ring/ringqp"
"github.com/tuneinsight/lattigo/v5/utils"
)
// Trace maps X -> sum((-1)^i * X^{i*n+1}) for n <= i < N
// Monomial X^k vanishes if k is not divisible by (N/n), otherwise it is multiplied by (N/n).
// Ciphertext is pre-multiplied by (N/n)^-1 to remove the (N/n) factor.
// Examples of full Trace for [0 + 1X + 2X^2 + 3X^3 + 4X^4 + 5X^5 + 6X^6 + 7X^7]
//
// 1.
//
// [1 + 2X + 3X^2 + 4X^3 + 5X^4 + 6X^5 + 7X^6 + 8X^7]
// + [1 - 6X - 3X^2 + 8X^3 + 5X^4 + 2X^5 - 7X^6 - 4X^7] {X-> X^(i * 5^1)}
// = [2 - 4X + 0X^2 +12X^3 +10X^4 + 8X^5 - 0X^6 + 4X^7]
//
// 2.
//
// [2 - 4X + 0X^2 +12X^3 +10X^4 + 8X^5 - 0X^6 + 4X^7]
// + [2 + 4X + 0X^2 -12X^3 +10X^4 - 8X^5 + 0X^6 - 4X^7] {X-> X^(i * 5^2)}
// = [4 + 0X + 0X^2 - 0X^3 +20X^4 + 0X^5 + 0X^6 - 0X^7]
//
// 3.
//
// [4 + 0X + 0X^2 - 0X^3 +20X^4 + 0X^5 + 0X^6 - 0X^7]
// + [4 + 0X + 0X^2 - 0X^3 -20X^4 + 0X^5 + 0X^6 - 0X^7] {X-> X^(i * -1)}
// = [8 + 0X + 0X^2 - 0X^3 + 0X^4 + 0X^5 + 0X^6 - 0X^7]
//
// The method will return an error if the input and output ciphertexts degree is not one.
func (eval Evaluator) Trace(ctIn *Ciphertext, logN int, opOut *Ciphertext) (err error) {
if ctIn.Degree() != 1 || opOut.Degree() != 1 {
return fmt.Errorf("ctIn.Degree() != 1 or opOut.Degree() != 1")
}
params := eval.GetRLWEParameters()
level := utils.Min(ctIn.Level(), opOut.Level())
opOut.Resize(opOut.Degree(), level)
*opOut.MetaData = *ctIn.MetaData
gap := 1 << (params.LogN() - logN - 1)
if logN == 0 {
gap <<= 1
}
if gap > 1 {
ringQ := params.RingQ().AtLevel(level)
if ringQ.Type() == ring.ConjugateInvariant {
gap >>= 1 // We skip the last step that applies phi(5^{-1})
}
NInv := new(big.Int).SetUint64(uint64(gap))
NInv.ModInverse(NInv, ringQ.ModulusAtLevel[level])
// pre-multiplication by (N/n)^-1
ringQ.MulScalarBigint(ctIn.Value[0], NInv, opOut.Value[0])
ringQ.MulScalarBigint(ctIn.Value[1], NInv, opOut.Value[1])
if !ctIn.IsNTT {
ringQ.NTT(opOut.Value[0], opOut.Value[0])
ringQ.NTT(opOut.Value[1], opOut.Value[1])
opOut.IsNTT = true
}
buff, err := NewCiphertextAtLevelFromPoly(level, []ring.Poly{eval.BuffQP[3].Q, eval.BuffQP[4].Q})
// Sanity check, this error should not happen unless the
// evaluator's buffer thave been improperly tempered with.
if err != nil {
panic(err)
}
buff.IsNTT = true
for i := logN; i < params.LogN()-1; i++ {
if err = eval.Automorphism(opOut, params.GaloisElement(1<<i), buff); err != nil {
return err
}
ringQ.Add(opOut.Value[0], buff.Value[0], opOut.Value[0])
ringQ.Add(opOut.Value[1], buff.Value[1], opOut.Value[1])
}
if logN == 0 && ringQ.Type() == ring.Standard {
if err = eval.Automorphism(opOut, ringQ.NthRoot()-1, buff); err != nil {
return err
}
ringQ.Add(opOut.Value[0], buff.Value[0], opOut.Value[0])
ringQ.Add(opOut.Value[1], buff.Value[1], opOut.Value[1])
}
if !ctIn.IsNTT {
ringQ.INTT(opOut.Value[0], opOut.Value[0])
ringQ.INTT(opOut.Value[1], opOut.Value[1])
opOut.IsNTT = false
}
} else {
if ctIn != opOut {
opOut.Copy(ctIn)
}
}
return
}
// GaloisElementsForTrace returns the list of Galois elements requored for the for the `Trace` operation.
// Trace maps X -> sum((-1)^i * X^{i*n+1}) for 2^{LogN} <= i < N.
func GaloisElementsForTrace(params ParameterProvider, logN int) (galEls []uint64) {
p := params.GetRLWEParameters()
galEls = []uint64{}
for i, j := logN, 0; i < p.LogN()-1; i, j = i+1, j+1 {
galEls = append(galEls, p.GaloisElement(1<<i))
}
if logN == 0 {
switch p.RingType() {
case ring.Standard:
galEls = append(galEls, p.GaloisElementOrderTwoOrthogonalSubgroup())
case ring.ConjugateInvariant:
panic("cannot GaloisElementsForTrace: Galois element GaloisGen^-1 is undefined in ConjugateInvariant Ring")
default:
panic("cannot GaloisElementsForTrace: invalid ring type")
}
}
return
}
// InnerSum applies an optimized inner sum on the Ciphertext (log2(n) + HW(n) rotations with double hoisting).
// The operation assumes that `ctIn` encrypts Slots/`batchSize` sub-vectors of size `batchSize` and will add them together (in parallel) in groups of `n`.
// It outputs in opOut a Ciphertext for which the "leftmost" sub-vector of each group is equal to the sum of the group.
// It outputs in opOut a [Ciphertext] for which the "leftmost" sub-vector of each group is equal to the sum of the group.
//
// The inner sum is computed in a tree fashion. Example for batchSize=2 & n=4 (garbage slots are marked by 'x'):
//
// 1) [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
// 1. [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
//
// 2. [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
// +
@@ -163,16 +301,17 @@ func (eval Evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, opOut *Cipher
return
}
// InnerFunction applies an user defined function on the Ciphertext with a tree-like combination requiring log2(n) + HW(n) rotations.
// InnerFunction applies an user defined function on the [Ciphertext] with a tree-like combination requiring log2(n) + HW(n) rotations.
//
// InnerFunction with f = eval.Add(a, b, c) is equivalent to InnerSum (although slightly slower).
// InnerFunction with f = eval.Add(a, b, c) is equivalent to [Evaluator.InnerSum] (although slightly slower).
//
// The operation assumes that `ctIn` encrypts Slots/`batchSize` sub-vectors of size `batchSize` and will add them together (in parallel) in groups of `n`.
// It outputs in opOut a Ciphertext for which the "leftmost" sub-vector of each group is equal to the pair-wise recursive evaluation of function over the group.
// It outputs in opOut a [Ciphertext] for which the "leftmost" sub-vector of each group is equal to the pair-wise recursive evaluation of
// function over the group.
//
// The inner function is computed in a tree fashion. Example for batchSize=2 & n=4 (garbage slots are marked by 'x'):
//
// 1) [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
// 1. [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
//
// 2. [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
// f
@@ -312,7 +451,7 @@ func (eval Evaluator) InnerFunction(ctIn *Ciphertext, batchSize, n int, f func(a
}
// GaloisElementsForInnerSum returns the list of Galois elements necessary to apply the method
// `InnerSum` operation with parameters `batch` and `n`.
// [Evaluator.InnerSum] operation with parameters batch and n.
func GaloisElementsForInnerSum(params ParameterProvider, batch, n int) (galEls []uint64) {
rotIndex := make(map[int]bool)
@@ -339,19 +478,19 @@ func GaloisElementsForInnerSum(params ParameterProvider, batch, n int) (galEls [
return params.GetRLWEParameters().GaloisElements(rotations)
}
// Replicate applies an optimized replication on the Ciphertext (log2(n) + HW(n) rotations with double hoisting).
// Replicate applies an optimized replication on the [Ciphertext] (log2(n) + HW(n) rotations with double hoisting).
// It acts as the inverse of a inner sum (summing elements from left to right).
// The replication is parameterized by the size of the sub-vectors to replicate "batchSize" and
// the number of times 'n' they need to be replicated.
// The replication is parameterized by the size of the sub-vectors to replicate batchSize and
// the number of times n they need to be replicated.
// To ensure correctness, a gap of zero values of size batchSize * (n-1) must exist between
// two consecutive sub-vectors to replicate.
// This method is faster than Replicate when the number of rotations is large and it uses log2(n) + HW(n) instead of 'n'.
// This method is faster than Replicate when the number of rotations is large and it uses log2(n) + HW(n) instead of n.
func (eval Evaluator) Replicate(ctIn *Ciphertext, batchSize, n int, opOut *Ciphertext) (err error) {
return eval.InnerSum(ctIn, -batchSize, n, opOut)
}
// GaloisElementsForReplicate returns the list of Galois elements necessary to perform the
// `Replicate` operation with parameters `batch` and `n`.
// [Evaluator.Replicate] operation with parameters batch and n.
func GaloisElementsForReplicate(params ParameterProvider, batch, n int) (galEls []uint64) {
return GaloisElementsForInnerSum(params, -batch, n)
}

View File

@@ -1,505 +0,0 @@
package rlwe
import (
"fmt"
"math/big"
"math/bits"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils"
)
// Trace maps X -> sum((-1)^i * X^{i*n+1}) for n <= i < N
// Monomial X^k vanishes if k is not divisible by (N/n), otherwise it is multiplied by (N/n).
// Ciphertext is pre-multiplied by (N/n)^-1 to remove the (N/n) factor.
// Examples of full Trace for [0 + 1X + 2X^2 + 3X^3 + 4X^4 + 5X^5 + 6X^6 + 7X^7]
//
// 1.
//
// [1 + 2X + 3X^2 + 4X^3 + 5X^4 + 6X^5 + 7X^6 + 8X^7]
// + [1 - 6X - 3X^2 + 8X^3 + 5X^4 + 2X^5 - 7X^6 - 4X^7] {X-> X^(i * 5^1)}
// = [2 - 4X + 0X^2 +12X^3 +10X^4 + 8X^5 - 0X^6 + 4X^7]
//
// 2.
//
// [2 - 4X + 0X^2 +12X^3 +10X^4 + 8X^5 - 0X^6 + 4X^7]
// + [2 + 4X + 0X^2 -12X^3 +10X^4 - 8X^5 + 0X^6 - 4X^7] {X-> X^(i * 5^2)}
// = [4 + 0X + 0X^2 - 0X^3 +20X^4 + 0X^5 + 0X^6 - 0X^7]
//
// 3.
//
// [4 + 0X + 0X^2 - 0X^3 +20X^4 + 0X^5 + 0X^6 - 0X^7]
// + [4 + 0X + 0X^2 - 0X^3 -20X^4 + 0X^5 + 0X^6 - 0X^7] {X-> X^(i * -1)}
// = [8 + 0X + 0X^2 - 0X^3 + 0X^4 + 0X^5 + 0X^6 - 0X^7]
//
// The method will return an error if the input and output ciphertexts degree is not one.
func (eval Evaluator) Trace(ctIn *Ciphertext, logN int, opOut *Ciphertext) (err error) {
if ctIn.Degree() != 1 || opOut.Degree() != 1 {
return fmt.Errorf("ctIn.Degree() != 1 or opOut.Degree() != 1")
}
params := eval.GetRLWEParameters()
level := utils.Min(ctIn.Level(), opOut.Level())
opOut.Resize(opOut.Degree(), level)
*opOut.MetaData = *ctIn.MetaData
gap := 1 << (params.LogN() - logN - 1)
if logN == 0 {
gap <<= 1
}
if gap > 1 {
ringQ := params.RingQ().AtLevel(level)
if ringQ.Type() == ring.ConjugateInvariant {
gap >>= 1 // We skip the last step that applies phi(5^{-1})
}
NInv := new(big.Int).SetUint64(uint64(gap))
NInv.ModInverse(NInv, ringQ.ModulusAtLevel[level])
// pre-multiplication by (N/n)^-1
ringQ.MulScalarBigint(ctIn.Value[0], NInv, opOut.Value[0])
ringQ.MulScalarBigint(ctIn.Value[1], NInv, opOut.Value[1])
if !ctIn.IsNTT {
ringQ.NTT(opOut.Value[0], opOut.Value[0])
ringQ.NTT(opOut.Value[1], opOut.Value[1])
opOut.IsNTT = true
}
buff, err := NewCiphertextAtLevelFromPoly(level, []ring.Poly{eval.BuffQP[3].Q, eval.BuffQP[4].Q})
// Sanity check, this error should not happen unless the
// evaluator's buffer thave been improperly tempered with.
if err != nil {
panic(err)
}
buff.IsNTT = true
for i := logN; i < params.LogN()-1; i++ {
if err = eval.Automorphism(opOut, params.GaloisElement(1<<i), buff); err != nil {
return err
}
ringQ.Add(opOut.Value[0], buff.Value[0], opOut.Value[0])
ringQ.Add(opOut.Value[1], buff.Value[1], opOut.Value[1])
}
if logN == 0 && ringQ.Type() == ring.Standard {
if err = eval.Automorphism(opOut, ringQ.NthRoot()-1, buff); err != nil {
return err
}
ringQ.Add(opOut.Value[0], buff.Value[0], opOut.Value[0])
ringQ.Add(opOut.Value[1], buff.Value[1], opOut.Value[1])
}
if !ctIn.IsNTT {
ringQ.INTT(opOut.Value[0], opOut.Value[0])
ringQ.INTT(opOut.Value[1], opOut.Value[1])
opOut.IsNTT = false
}
} else {
if ctIn != opOut {
opOut.Copy(ctIn)
}
}
return
}
// GaloisElementsForTrace returns the list of Galois elements requored for the for the `Trace` operation.
// Trace maps X -> sum((-1)^i * X^{i*n+1}) for 2^{LogN} <= i < N.
func GaloisElementsForTrace(params ParameterProvider, logN int) (galEls []uint64) {
p := params.GetRLWEParameters()
galEls = []uint64{}
for i, j := logN, 0; i < p.LogN()-1; i, j = i+1, j+1 {
galEls = append(galEls, p.GaloisElement(1<<i))
}
if logN == 0 {
switch p.RingType() {
case ring.Standard:
galEls = append(galEls, p.GaloisElementOrderTwoOrthogonalSubgroup())
case ring.ConjugateInvariant:
panic("cannot GaloisElementsForTrace: Galois element GaloisGen^-1 is undefined in ConjugateInvariant Ring")
default:
panic("cannot GaloisElementsForTrace: invalid ring type")
}
}
return
}
// Expand expands a RLWE Ciphertext encrypting sum ai * X^i to 2^logN ciphertexts,
// each encrypting ai * X^0 for 0 <= i < 2^LogN. That is, it extracts the first 2^logN
// coefficients, whose degree is a multiple of 2^logGap, of ctIn and returns an RLWE
// Ciphertext for each coefficient extracted.
//
// The method will return an error if:
// - The input ciphertext degree is not one
// - The ring type is not ring.Standard
func (eval Evaluator) Expand(ctIn *Ciphertext, logN, logGap int) (opOut []*Ciphertext, err error) {
if ctIn.Degree() != 1 {
return nil, fmt.Errorf("cannot Expand: ctIn.Degree() != 1")
}
params := eval.GetRLWEParameters()
if params.RingType() != ring.Standard {
return nil, fmt.Errorf("cannot Expand: method is only supported for ring.Type = ring.Standard (X^{-2^{i}} does not exist in the sub-ring Z[X + X^{-1}])")
}
level := ctIn.Level()
ringQ := params.RingQ().AtLevel(level)
// Compute X^{-2^{i}} from 1 to LogN
xPow2 := GenXPow2(ringQ, logN, true)
opOut = make([]*Ciphertext, 1<<(logN-logGap))
opOut[0] = ctIn.CopyNew()
opOut[0].LogDimensions = ring.Dimensions{Rows: 0, Cols: 0}
if ct := opOut[0]; !ctIn.IsNTT {
ringQ.NTT(ct.Value[0], ct.Value[0])
ringQ.NTT(ct.Value[1], ct.Value[1])
ct.IsNTT = true
}
// Multiplies by 2^{-logN} mod Q
NInv := new(big.Int).SetUint64(1 << logN)
NInv.ModInverse(NInv, ringQ.ModulusAtLevel[level])
ringQ.MulScalarBigint(opOut[0].Value[0], NInv, opOut[0].Value[0])
ringQ.MulScalarBigint(opOut[0].Value[1], NInv, opOut[0].Value[1])
gap := 1 << logGap
tmp, err := NewCiphertextAtLevelFromPoly(level, []ring.Poly{eval.BuffCt.Value[0], eval.BuffCt.Value[1]})
// Sanity check, this error should not happen unless the
// evaluator's buffer thave been improperly tempered with.
if err != nil {
panic(err)
}
tmp.MetaData = ctIn.MetaData
for i := 0; i < logN; i++ {
n := 1 << i
galEl := uint64(ringQ.N()/n + 1)
half := n / gap
for j := 0; j < (n+gap-1)/gap; j++ {
c0 := opOut[j]
// X -> X^{N/n + 1}
//[a, b, c, d] -> [a, -b, c, -d]
if err = eval.Automorphism(c0, galEl, tmp); err != nil {
return
}
if j+half > 0 {
c1 := opOut[j].CopyNew()
// Zeroes odd coeffs: [a, b, c, d] + [a, -b, c, -d] -> [2a, 0, 2b, 0]
ringQ.Add(c0.Value[0], tmp.Value[0], c0.Value[0])
ringQ.Add(c0.Value[1], tmp.Value[1], c0.Value[1])
// Zeroes even coeffs: [a, b, c, d] - [a, -b, c, -d] -> [0, 2b, 0, 2d]
ringQ.Sub(c1.Value[0], tmp.Value[0], c1.Value[0])
ringQ.Sub(c1.Value[1], tmp.Value[1], c1.Value[1])
// c1 * X^{-2^{i}}: [0, 2b, 0, 2d] * X^{-n} -> [2b, 0, 2d, 0]
ringQ.MulCoeffsMontgomery(c1.Value[0], xPow2[i], c1.Value[0])
ringQ.MulCoeffsMontgomery(c1.Value[1], xPow2[i], c1.Value[1])
opOut[j+half] = c1
} else {
// Zeroes odd coeffs: [a, b, c, d] + [a, -b, c, -d] -> [2a, 0, 2b, 0]
ringQ.Add(c0.Value[0], tmp.Value[0], c0.Value[0])
ringQ.Add(c0.Value[1], tmp.Value[1], c0.Value[1])
}
}
}
for _, ct := range opOut {
if ct != nil && !ctIn.IsNTT {
ringQ.INTT(ct.Value[0], ct.Value[0])
ringQ.INTT(ct.Value[1], ct.Value[1])
ct.IsNTT = false
}
}
return
}
// GaloisElementsForExpand returns the list of Galois elements required
// to perform the `Expand` operation with parameter `logN`.
func GaloisElementsForExpand(params ParameterProvider, logN int) (galEls []uint64) {
galEls = make([]uint64, logN)
NthRoot := params.GetRLWEParameters().RingQ().NthRoot()
for i := 0; i < logN; i++ {
galEls[i] = uint64(NthRoot/(2<<i) + 1)
}
return
}
// Pack packs a batch of RLWE ciphertexts, packing the batch of ciphertexts into a single ciphertext.
// The number of key-switching operations is inputLogGap - log2(gap) + len(cts), where log2(gap) is the
// minimum distance between two keys of the map cts[int]*Ciphertext.
//
// Input:
//
// cts: a map of Ciphertext, where the index in the map is the future position of the first coefficient
// of the indexed ciphertext in the final ciphertext (see example). Ciphertexts can be in or out of the NTT domain.
// logGap: all coefficients of the input ciphertexts that are not a multiple of X^{2^{logGap}} will be zeroed
// during the merging (see example). This is equivalent to skipping the first 2^{logGap} steps of the
// algorithm, i.e. having as input ciphertexts that are already partially packed.
// zeroGarbageSlots: if set to true, slots which are not multiples of X^{2^{logGap}} will be zeroed during the procedure.
// this will greatly increase the noise and increase the number of key-switching operations to inputLogGap + len(cts).
//
// Output: a ciphertext packing all input ciphertexts
//
// Example: we want to pack 4 ciphertexts into one, and keep only coefficients which are a multiple of X^{4}.
//
// To do so, we must set logGap = 2.
// Here the `X` slots are treated as garbage slots that we want to discard during the procedure.
//
// input: map[int]{
// 0: [x00, X, X, X, x01, X, X, X], with logGap = 2
// 1: [x10, X, X, X, x11, X, X, X],
// 2: [x20, X, X, X, x21, X, X, X],
// 3: [x30, X, X, X, x31, X, X, X],
// }
//
// Step 1:
// map[0]: 2^{-1} * (map[0] + X^2 * map[2] + phi_{5^2}(map[0] - X^2 * map[2]) = [x00, X, x20, X, x01, X, x21, X]
// map[1]: 2^{-1} * (map[1] + X^2 * map[3] + phi_{5^2}(map[1] - X^2 * map[3]) = [x10, X, x30, X, x11, X, x31, X]
// Step 2:
// map[0]: 2^{-1} * (map[0] + X^1 * map[1] + phi_{5^4}(map[0] - X^1 * map[1]) = [x00, x10, x20, x30, x01, x11, x21, x22]
func (eval Evaluator) Pack(cts map[int]*Ciphertext, inputLogGap int, zeroGarbageSlots bool) (ct *Ciphertext, err error) {
params := eval.GetRLWEParameters()
if params.RingType() != ring.Standard {
return nil, fmt.Errorf("cannot Pack: procedure is only supported for ring.Type = ring.Standard (X^{2^{i}} does not exist in the sub-ring Z[X + X^{-1}])")
}
if len(cts) < 2 {
return nil, fmt.Errorf("cannot Pack: #cts must be at least 2")
}
keys := utils.GetSortedKeys(cts)
gap := keys[1] - keys[0]
level := cts[keys[0]].Level()
for i, key := range keys[1:] {
level = utils.Min(level, cts[key].Level())
if i < len(keys)-1 {
gap = utils.Min(gap, keys[i+1]-keys[i])
}
}
logN := params.LogN()
ringQ := params.RingQ().AtLevel(level)
logStart := logN - inputLogGap
logEnd := logN
if !zeroGarbageSlots {
if gap > 0 {
logEnd -= bits.Len64(uint64(gap - 1))
}
}
if logStart >= logEnd {
return nil, fmt.Errorf("cannot Pack: gaps between ciphertexts is smaller than inputLogGap > N")
}
xPow2 := GenXPow2(ringQ.AtLevel(level), params.LogN(), false) // log(N) polynomial to generate, quick
NInv := new(big.Int).SetUint64(uint64(1 << (logEnd - logStart)))
NInv.ModInverse(NInv, ringQ.ModulusAtLevel[level])
for _, key := range keys {
ct := cts[key]
if ct.Degree() != 1 {
return nil, fmt.Errorf("cannot Pack: cts[%d].Degree() != 1", key)
}
if !ct.IsNTT {
ringQ.NTT(ct.Value[0], ct.Value[0])
ringQ.NTT(ct.Value[1], ct.Value[1])
ct.IsNTT = true
}
ringQ.MulScalarBigint(ct.Value[0], NInv, ct.Value[0])
ringQ.MulScalarBigint(ct.Value[1], NInv, ct.Value[1])
}
tmpa := &Ciphertext{}
tmpa.Value = []ring.Poly{ringQ.NewPoly(), ringQ.NewPoly()}
tmpa.MetaData = &MetaData{}
tmpa.MetaData.IsNTT = true
for i := logStart; i < logEnd; i++ {
t := 1 << (logN - 1 - i)
for jx, jy := 0, t; jx < t; jx, jy = jx+1, jy+1 {
a := cts[jx]
b := cts[jy]
if b != nil {
//X^(N/2^L)
ringQ.MulCoeffsMontgomery(b.Value[0], xPow2[len(xPow2)-i-1], b.Value[0])
ringQ.MulCoeffsMontgomery(b.Value[1], xPow2[len(xPow2)-i-1], b.Value[1])
if a != nil {
// tmpa = phi(a - b * X^{N/2^{i}}, 2^{i-1})
ringQ.Sub(a.Value[0], b.Value[0], tmpa.Value[0])
ringQ.Sub(a.Value[1], b.Value[1], tmpa.Value[1])
// a = a + b * X^{N/2^{i}}
ringQ.Add(a.Value[0], b.Value[0], a.Value[0])
ringQ.Add(a.Value[1], b.Value[1], a.Value[1])
} else {
// if ct[jx] == nil, then simply re-assigns
cts[jx] = cts[jy]
// Required for correctness, since each log step is expected
// to double the values, which are pre-scaled by N^{-1} mod Q
// Maybe this can be omitted by doing an individual pre-scaling.
ringQ.Add(cts[jx].Value[0], cts[jx].Value[0], cts[jx].Value[0])
ringQ.Add(cts[jx].Value[1], cts[jx].Value[1], cts[jx].Value[1])
}
}
if a != nil {
var galEl uint64
if i == 0 {
galEl = ringQ.NthRoot() - 1
} else {
galEl = params.GaloisElement(1 << (i - 1))
}
if b != nil {
if err = eval.Automorphism(tmpa, galEl, tmpa); err != nil {
return
}
} else {
if err = eval.Automorphism(a, galEl, tmpa); err != nil {
return
}
}
// a + b * X^{N/2^{i}} + phi(a - b * X^{N/2^{i}}, 2^{i-1})
ringQ.Add(a.Value[0], tmpa.Value[0], a.Value[0])
ringQ.Add(a.Value[1], tmpa.Value[1], a.Value[1])
}
}
}
return cts[0], nil
}
// GaloisElementsForPack returns the list of Galois elements required to perform the `Pack` operation.
func GaloisElementsForPack(params ParameterProvider, logGap int) (galEls []uint64) {
p := params.GetRLWEParameters()
// Sanity check
if logGap > p.LogN() || logGap < 0 {
panic(fmt.Errorf("cannot GaloisElementsForPack: logGap > logN || logGap < 0"))
}
galEls = make([]uint64, 0, logGap)
for i := 0; i < logGap; i++ {
galEls = append(galEls, p.GaloisElement(1<<i))
}
switch p.RingType() {
case ring.Standard:
if logGap == p.LogN() {
galEls = append(galEls, p.GaloisElementOrderTwoOrthogonalSubgroup())
}
default:
panic("cannot GaloisElementsForPack: invalid ring type")
}
return
}
func GenXPow2(r *ring.Ring, logN int, div bool) (xPow []ring.Poly) {
// Compute X^{-n} from 0 to LogN
xPow = make([]ring.Poly, logN)
moduli := r.ModuliChain()[:r.Level()+1]
BRC := r.BRedConstants()
var idx int
for i := 0; i < logN; i++ {
idx = 1 << i
if div {
idx = r.N() - idx
}
xPow[i] = r.NewPoly()
if i == 0 {
for j := range moduli {
xPow[i].Coeffs[j][idx] = ring.MForm(1, moduli[j], BRC[j])
}
r.NTT(xPow[i], xPow[i])
} else {
r.MulCoeffsMontgomery(xPow[i-1], xPow[i-1], xPow[i]) // X^{n} = X^{1} * X^{n-1}
}
}
if div {
r.Neg(xPow[0], xPow[0])
}
return
}

View File

@@ -149,73 +149,6 @@ func testUserDefinedParameters(t *testing.T) {
require.Equal(t, paramsWithBadDist, Parameters{})
})
// test valid/invalid configurations of prime fields
t.Run("Parameters/NewParametersFromLiteral", func(t *testing.T) {
Q := []uint64{0x200000440001, 0x7fff80001, 0x800280001, 0x7ffd80001, 0x7ffc80001}
P := []uint64{0x3ffffffb80001, 0x4000000800001}
logQ := []int{55, 40, 40, 40, 40}
logP := []int{55, 55}
// both Q and P given (good)
params, err := NewParametersFromLiteral(ParametersLiteral{
LogN: logN, Q: Q, P: P, LogQ: nil, LogP: nil,
})
require.NoError(t, err)
require.Equal(t, params.qi, Q)
require.Equal(t, params.pi, P)
// only Q given (good)
params, err = NewParametersFromLiteral(ParametersLiteral{
LogN: logN, Q: Q, P: nil, LogQ: nil, LogP: nil,
})
require.NoError(t, err)
require.Equal(t, params.qi, Q)
require.Empty(t, params.pi)
// Q and logP given (good)
params, err = NewParametersFromLiteral(ParametersLiteral{
LogN: logN, Q: Q, P: nil, LogQ: nil, LogP: logP,
})
require.NoError(t, err)
require.Equal(t, params.qi, Q)
require.Equal(t, len(params.pi), len(logP))
// logQ and P given (good)
params, err = NewParametersFromLiteral(ParametersLiteral{
LogN: logN, Q: nil, P: P, LogQ: logQ, LogP: nil,
})
require.NoError(t, err)
require.Equal(t, len(params.qi), len(logQ))
require.Equal(t, params.pi, P)
// both LogQ and LogP given (good)
params, err = NewParametersFromLiteral(ParametersLiteral{
LogN: logN, Q: nil, P: nil, LogQ: logQ, LogP: logP,
})
require.NoError(t, err)
require.Equal(t, len(params.qi), len(logQ))
require.Equal(t, len(params.pi), len(logP))
// only LogQ given (good)
params, err = NewParametersFromLiteral(ParametersLiteral{
LogN: logN, Q: nil, P: nil, LogQ: logQ, LogP: nil,
})
require.NoError(t, err)
require.Equal(t, len(params.qi), len(logQ))
require.Empty(t, params.pi)
// empty primes (bad)
_, err = NewParametersFromLiteral(ParametersLiteral{
LogN: logN, Q: nil, P: nil, LogQ: nil, LogP: nil,
})
require.Error(t, err)
// double set log/non-prime (bad)
_, err = NewParametersFromLiteral(ParametersLiteral{
LogN: logN, Q: Q, P: nil, LogQ: logQ, LogP: nil,
})
require.Error(t, err)
})
}
func NewTestContext(params Parameters) (tc *TestContext, err error) {
@@ -257,6 +190,12 @@ func testParameters(tc *TestContext, t *testing.T) {
require.Equal(t, uint64(1), res)
}
})
t.Run(testString(params, params.MaxLevelQ(), params.MaxLevelP(), 0, "Elements"), func(t *testing.T) {
ct := NewCiphertext(tc.params, 1, 0)
require.Equal(t, ct.N(), params.N())
require.Equal(t, ct.LogN(), params.LogN())
})
}
func testKeyGenerator(tc *TestContext, bpw2 int, t *testing.T) {
@@ -928,204 +867,6 @@ func testSlotOperations(tc *TestContext, level, bpw2 int, t *testing.T) {
enc := tc.enc
dec := tc.dec
evkParams := EvaluationKeyParameters{LevelQ: utils.Pointy(level), BaseTwoDecomposition: utils.Pointy(bpw2)}
t.Run(testString(params, level, params.MaxLevelP(), bpw2, "Evaluator/Expand"), func(t *testing.T) {
if params.RingType() != ring.Standard {
t.Skip("Expand not supported for ring.Type = ring.ConjugateInvariant")
}
pt := NewPlaintext(params, level)
ringQ := params.RingQ().AtLevel(level)
logN := 4
logGap := 0
gap := 1 << logGap
values := make([]uint64, params.N())
scale := 1 << 22
for i := 0; i < 1<<logN; i++ { // embeds even coefficients only
values[i] = uint64(i * scale)
}
for i := 0; i < pt.Level()+1; i++ {
copy(pt.Value.Coeffs[i], values)
}
if pt.IsNTT {
ringQ.NTT(pt.Value, pt.Value)
}
ctIn := NewCiphertext(params, 1, level)
enc.Encrypt(pt, ctIn)
// GaloisKeys
evk := NewMemEvaluationKeySet(nil, kgen.GenGaloisKeysNew(GaloisElementsForExpand(params, logN), sk, evkParams)...)
eval := NewEvaluator(params, evk)
ciphertexts, err := eval.WithKey(evk).Expand(ctIn, logN, logGap)
require.NoError(t, err)
Q := ringQ.ModuliChain()
NoiseBound := float64(params.LogN() - logN + bpw2)
if bpw2 != 0 {
NoiseBound += float64(level + 5)
}
for i := range ciphertexts {
dec.Decrypt(ciphertexts[i], pt)
if pt.IsNTT {
ringQ.INTT(pt.Value, pt.Value)
}
for j := 0; j < level+1; j++ {
pt.Value.Coeffs[j][0] = ring.CRed(pt.Value.Coeffs[j][0]+Q[j]-values[i*gap], Q[j])
}
// Logs the noise
require.GreaterOrEqual(t, NoiseBound, ringQ.Log2OfStandardDeviation(pt.Value))
}
})
t.Run(testString(params, level, params.MaxLevelP(), bpw2, "Evaluator/Pack/LogGap=LogN"), func(t *testing.T) {
if params.RingType() != ring.Standard {
t.Skip("Pack not supported for ring.Type = ring.ConjugateInvariant")
}
pt := NewPlaintext(params, level)
N := params.N()
ringQ := tc.params.RingQ().AtLevel(level)
gap := params.N() / 16
ptPacked := NewPlaintext(params, level)
ciphertexts := make(map[int]*Ciphertext)
slotIndex := make(map[int]bool)
for i := 0; i < N; i += gap {
ciphertexts[i] = enc.EncryptZeroNew(level)
scalar := (1 << 30) + uint64(i)*(1<<20)
if ciphertexts[i].IsNTT {
ringQ.AddScalar(ciphertexts[i].Value[0], scalar, ciphertexts[i].Value[0])
} else {
for j := 0; j < level+1; j++ {
ciphertexts[i].Value[0].Coeffs[j][0] = ring.CRed(ciphertexts[i].Value[0].Coeffs[j][0]+scalar, ringQ.SubRings[j].Modulus)
}
}
slotIndex[i] = true
for j := 0; j < level+1; j++ {
ptPacked.Value.Coeffs[j][i] = scalar
}
}
// Galois Keys
evk := NewMemEvaluationKeySet(nil, kgen.GenGaloisKeysNew(GaloisElementsForPack(params, params.LogN()), sk, evkParams)...)
ct, err := eval.WithKey(evk).Pack(ciphertexts, params.LogN(), false)
require.NoError(t, err)
dec.Decrypt(ct, pt)
if pt.IsNTT {
ringQ.INTT(pt.Value, pt.Value)
}
ringQ.Sub(pt.Value, ptPacked.Value, pt.Value)
for i := 0; i < N; i++ {
if i%gap != 0 {
for j := 0; j < level+1; j++ {
pt.Value.Coeffs[j][i] = 0
}
}
}
NoiseBound := 15.0 + float64(bpw2)
if bpw2 != 0 {
NoiseBound += math.Log2(float64(level)+1.0) + 1.0
}
// Logs the noise
require.GreaterOrEqual(t, NoiseBound, ringQ.Log2OfStandardDeviation(pt.Value))
})
t.Run(testString(params, level, params.MaxLevelP(), bpw2, "Evaluator/Pack/LogGap=LogN-1"), func(t *testing.T) {
if params.RingType() != ring.Standard {
t.Skip("Pack not supported for ring.Type = ring.ConjugateInvariant")
}
pt := NewPlaintext(params, level)
N := params.N()
ringQ := tc.params.RingQ().AtLevel(level)
ptPacked := NewPlaintext(params, level)
ciphertexts := make(map[int]*Ciphertext)
slotIndex := make(map[int]bool)
for i := 0; i < N/2; i += params.N() / 16 {
ciphertexts[i] = enc.EncryptZeroNew(level)
scalar := (1 << 30) + uint64(i)*(1<<20)
if ciphertexts[i].IsNTT {
ringQ.INTT(ciphertexts[i].Value[0], ciphertexts[i].Value[0])
}
for j := 0; j < level+1; j++ {
ciphertexts[i].Value[0].Coeffs[j][0] = ring.CRed(ciphertexts[i].Value[0].Coeffs[j][0]+scalar, ringQ.SubRings[j].Modulus)
ciphertexts[i].Value[0].Coeffs[j][N/2] = ring.CRed(ciphertexts[i].Value[0].Coeffs[j][N/2]+scalar, ringQ.SubRings[j].Modulus)
}
if ciphertexts[i].IsNTT {
ringQ.NTT(ciphertexts[i].Value[0], ciphertexts[i].Value[0])
}
slotIndex[i] = true
for j := 0; j < level+1; j++ {
ptPacked.Value.Coeffs[j][i] = scalar
ptPacked.Value.Coeffs[j][i+N/2] = scalar
}
}
// Galois Keys
evk := NewMemEvaluationKeySet(nil, kgen.GenGaloisKeysNew(GaloisElementsForPack(params, params.LogN()-1), sk, evkParams)...)
ct, err := eval.WithKey(evk).Pack(ciphertexts, params.LogN()-1, true)
require.NoError(t, err)
dec.Decrypt(ct, pt)
if pt.IsNTT {
ringQ.INTT(pt.Value, pt.Value)
}
ringQ.Sub(pt.Value, ptPacked.Value, pt.Value)
NoiseBound := 15.0 + float64(bpw2)
if bpw2 != 0 {
NoiseBound += math.Log2(float64(level)+1.0) + 1.0
}
// Logs the noise
require.GreaterOrEqual(t, NoiseBound, ringQ.Log2OfStandardDeviation(pt.Value))
})
t.Run(testString(params, level, params.MaxLevelP(), bpw2, "Evaluator/InnerSum"), func(t *testing.T) {
if params.MaxLevelP() == -1 {
@@ -1303,12 +1044,14 @@ func testWriteAndRead(tc *TestContext, bpw2 int, t *testing.T) {
func testMarshaller(tc *TestContext, t *testing.T) {
params := tc.params
t.Run("Marshaller/Parameters", func(t *testing.T) {
for _, p := range testInsecure {
params, err := NewParametersFromLiteral(p.ParametersLiteral)
require.NoError(t, err)
buffer.RequireSerializerCorrect(t, &params)
}
bytes, err := params.MarshalBinary()
require.Nil(t, err)
var p Parameters
require.Nil(t, p.UnmarshalBinary(bytes))
require.Equal(t, params, p)
})
t.Run("Marshaller/MetaData", func(t *testing.T) {

View File

@@ -14,6 +14,7 @@ import (
"time"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/he"
"github.com/tuneinsight/lattigo/v5/he/hebin"
"github.com/tuneinsight/lattigo/v5/he/hefloat"
"github.com/tuneinsight/lattigo/v5/ring"
@@ -211,12 +212,35 @@ func main() {
fmt.Printf("Evaluating BlindRotations... ")
now = time.Now()
// Extracts & EvalBR(LWEs, indexTestPoly) on the fly -> Repack(LWEs, indexRepack) -> RLWE
ctN12, err = evalBR.EvaluateAndRepack(ctN11, testPolyMap, repackIndex, blindRotateKey, evk)
if err != nil {
// Extracts & EvalBR(LWEs, indexTestPoly)
var ctsN12 = map[int]*rlwe.Ciphertext{}
if ctsN12, err = evalBR.Evaluate(ctN11, testPolyMap, blindRotateKey); err != nil {
panic(err)
}
fmt.Printf("Done (%s)\n", time.Since(now))
// Instantiate the repacking keys
evkRepacking := &he.RingPackingEvaluationKey{
Parameters: map[int]rlwe.ParameterProvider{paramsN12.LogN(): &paramsN12},
RepackKeys: map[int]rlwe.EvaluationKeySet{paramsN12.LogN(): evk},
}
// Instantiate the repacking evaluator from the repacking keys
evalRepack := he.NewRingPackingEvaluator(evkRepacking)
fmt.Printf("Evaluating Ring-Packing... ")
now = time.Now()
// Permutes the ciphertexts according to the repacking map
var ctsN12Permuted = map[int]*rlwe.Ciphertext{}
for i := range ctsN12 {
ctsN12Permuted[repackIndex[i]] = ctsN12[i]
}
// Repacks the ciphertexts
if ctN12, err = evalRepack.Repack(ctsN12Permuted); err != nil {
panic(err)
}
fmt.Printf("Done (%s)\n", time.Since(now))
ctN12.IsBatched = false
ctN12.LogDimensions = paramsN12.LogMaxDimensions()
ctN12.Scale = paramsN12.DefaultScale()

View File

@@ -25,7 +25,7 @@ type Evaluator struct {
galoisGenDiscreteLog map[uint64]int
}
// NewEvaluator instantiates a new Evaluator.
// NewEvaluator instantiates a new [Evaluator].
func NewEvaluator(paramsBR, paramsLWE rlwe.ParameterProvider) (eval *Evaluator) {
eval = new(Evaluator)
eval.Evaluator = rgsw.NewEvaluator(paramsBR, nil)
@@ -43,27 +43,6 @@ func NewEvaluator(paramsBR, paramsLWE rlwe.ParameterProvider) (eval *Evaluator)
return
}
// EvaluateAndRepack extracts on the fly LWE samples, evaluates the provided blind rotations on the LWE and repacks everything into a single rlwe.Ciphertext.
// testPolyWithSlotIndex : a map with [slot_index] -> blind rotation
// repackIndex : a map with [slot_index_have] -> slot_index_want
func (eval *Evaluator) EvaluateAndRepack(ct *rlwe.Ciphertext, testPolyWithSlotIndex map[int]*ring.Poly, repackIndex map[int]int, key BlindRotationEvaluationKeySet, repackKey rlwe.EvaluationKeySet) (res *rlwe.Ciphertext, err error) {
cts, err := eval.Evaluate(ct, testPolyWithSlotIndex, key)
if err != nil {
return nil, err
}
ciphertexts := make(map[int]*rlwe.Ciphertext)
for i := range cts {
ciphertexts[repackIndex[i]] = cts[i]
}
eval.Evaluator = eval.Evaluator.WithKey(repackKey)
return eval.Pack(ciphertexts, eval.paramsBR.LogN(), true)
}
// Evaluate extracts on the fly LWE samples and evaluates the provided blind rotation on the LWE.
// testPolyWithSlotIndex : a map with [slot_index] -> blind rotation
// Returns a map[slot_index] -> BlindRotate(ct[slot_index])

View File

@@ -6,6 +6,7 @@ import (
"math/big"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/he"
"github.com/tuneinsight/lattigo/v5/he/hefloat"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/schemes/ckks"
@@ -37,7 +38,7 @@ type Evaluator struct {
SkDebug *rlwe.SecretKey
}
// NewEvaluator creates a new Evaluator.
// NewEvaluator creates a new [Evaluator].
func NewEvaluator(btpParams Parameters, evk *EvaluationKeys) (eval *Evaluator, err error) {
eval = &Evaluator{}
@@ -67,9 +68,9 @@ func NewEvaluator(btpParams Parameters, evk *EvaluationKeys) (eval *Evaluator, e
eval.Parameters = btpParams
if paramsN1.N() != paramsN2.N() {
eval.xPow2N1 = rlwe.GenXPow2(paramsN1.RingQ().AtLevel(0), paramsN2.LogN(), false)
eval.xPow2N2 = rlwe.GenXPow2(paramsN2.RingQ().AtLevel(0), paramsN2.LogN(), false)
eval.xPow2InvN2 = rlwe.GenXPow2(paramsN2.RingQ(), paramsN2.LogN(), true)
eval.xPow2N1 = he.GenXPow2NTT(paramsN1.RingQ().AtLevel(0), paramsN2.LogN(), false)
eval.xPow2N2 = he.GenXPow2NTT(paramsN2.RingQ().AtLevel(0), paramsN2.LogN(), false)
eval.xPow2InvN2 = he.GenXPow2NTT(paramsN2.RingQ(), paramsN2.LogN(), true)
}
if btpParams.Mod1ParametersLiteral.Mod1Type == hefloat.SinContinuous && btpParams.Mod1ParametersLiteral.DoubleAngle != 0 {
@@ -119,41 +120,23 @@ func NewEvaluator(btpParams Parameters, evk *EvaluationKeys) (eval *Evaluator, e
return
}
// ShallowCopy creates a shallow copy of this Evaluator in which all the read-only data-structures are
// ShallowCopy creates a shallow copy of this [Evaluator] in which all the read-only data-structures are
// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned
// Evaluator can be used concurrently.
func (eval Evaluator) ShallowCopy() *Evaluator {
heEvaluator := eval.Evaluator.ShallowCopy()
paramsN1 := eval.ResidualParameters
paramsN2 := eval.BootstrappingParameters
var DomainSwitcher ckks.DomainSwitcher
if paramsN1.RingType() == ring.ConjugateInvariant {
var err error
if DomainSwitcher, err = ckks.NewDomainSwitcher(paramsN2.Parameters, eval.EvkCmplxToReal, eval.EvkRealToCmplx); err != nil {
panic(fmt.Errorf("cannot NewBootstrapper: ckks.NewDomainSwitcher: %w", err))
}
}
params := eval.BootstrappingParameters
return &Evaluator{
Parameters: eval.Parameters,
EvaluationKeys: eval.EvaluationKeys,
Mod1Parameters: eval.Mod1Parameters,
S2CDFTMatrix: eval.S2CDFTMatrix,
C2SDFTMatrix: eval.C2SDFTMatrix,
Evaluator: heEvaluator,
xPow2N1: eval.xPow2N1,
xPow2N2: eval.xPow2N2,
xPow2InvN2: eval.xPow2InvN2,
DomainSwitcher: DomainSwitcher,
DFTEvaluator: hefloat.NewDFTEvaluator(paramsN2, heEvaluator),
Mod1Evaluator: hefloat.NewMod1Evaluator(heEvaluator, hefloat.NewPolynomialEvaluator(paramsN2, heEvaluator), eval.Mod1Parameters),
SkDebug: eval.SkDebug,
DFTEvaluator: hefloat.NewDFTEvaluator(params, heEvaluator),
Mod1Evaluator: hefloat.NewMod1Evaluator(heEvaluator, hefloat.NewPolynomialEvaluator(params, heEvaluator), eval.Mod1Parameters),
}
}
// CheckKeys checks if all the necessary keys are present in the instantiated Evaluator
// CheckKeys checks if all the necessary keys are present in the instantiated [Evaluator]
func (eval Evaluator) checkKeys(evk *EvaluationKeys) (err error) {
if _, err = evk.GetRelinearizationKey(); err != nil {

871
he/ring_packing.go Normal file
View File

@@ -0,0 +1,871 @@
package he
import (
"fmt"
"math/big"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils"
)
// RingPackingEvaluator is an evaluator for Ring-LWE packing operations.
// All fields of this struct are public, enabling custom instantiations.
type RingPackingEvaluator struct {
*RingPackingEvaluationKey
Evaluators map[int]*rlwe.Evaluator
//XPow2NTT: [1, x, x^2, x^4, ..., x^2^s] / (X^2^s +1)
XPow2NTT map[int][]ring.Poly
//XInvPow2NTT: [1, x^-1, x^-2, x^-4, ..., x^-2^s/2] / (X^2^s +1)
XInvPow2NTT map[int][]ring.Poly
}
// NewRingPackingEvaluator instantiates a new RingPackingEvaluator from a RingPackingEvaluationKey.
func NewRingPackingEvaluator(evk *RingPackingEvaluationKey) *RingPackingEvaluator {
Evaluators := map[int]*rlwe.Evaluator{}
XPow2NTT := map[int][]ring.Poly{}
XInvPow2NTT := map[int][]ring.Poly{}
minLogN := evk.MinLogN()
maxLogN := evk.MaxLogN()
levelQ := evk.Parameters[minLogN].GetRLWEParameters().MaxLevel()
for i := minLogN; i < maxLogN+1; i++ {
pi := evk.Parameters[i].GetRLWEParameters()
Evaluators[i] = rlwe.NewEvaluator(pi, nil)
XPow2NTT[i] = GenXPow2NTT(pi.RingQ().AtLevel(levelQ), pi.LogN(), false)
XInvPow2NTT[i] = GenXPow2NTT(pi.RingQ().AtLevel(levelQ), pi.LogN(), true)
}
return &RingPackingEvaluator{
RingPackingEvaluationKey: evk,
Evaluators: Evaluators,
XPow2NTT: XPow2NTT,
XInvPow2NTT: XInvPow2NTT,
}
}
// ShallowCopy creates a shallow copy of this struct in which all the read-only data-structures are
// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned
// Evaluators can be used concurrently.
func (eval RingPackingEvaluator) ShallowCopy() *RingPackingEvaluator {
Evaluators := map[int]*rlwe.Evaluator{}
for i := range eval.Evaluators {
Evaluators[i] = eval.Evaluators[i].ShallowCopy()
}
return &RingPackingEvaluator{
RingPackingEvaluationKey: eval.RingPackingEvaluationKey,
Evaluators: Evaluators,
XPow2NTT: eval.XPow2NTT,
XInvPow2NTT: eval.XInvPow2NTT,
}
}
// Extract takes as input a ciphertext encrypting P(X) = c[i] * X^i and returns a map of
// ciphertexts of degree eval.MinLogN(), each encrypting P(X) = c[i] * X^{0} for i in idx.
// All non-constant coefficients are zeroed and thus correctness is ensured if this method
// is composed with either Repack or RepackNaive.
func (eval RingPackingEvaluator) Extract(ct *rlwe.Ciphertext, idx map[int]bool) (cts map[int]*rlwe.Ciphertext, err error) {
return eval.extract(ct, idx, false)
}
// ExtractNaive takes as input a ciphertext encrypting P(X) = c[i] * X^i and returns a map of
// ciphertexts of degree eval.MinLogN(), each encrypting P(X) = c[i] * X^{0} for i in idx.
// Non-constant coefficients are NOT zeroed thus correctness is only ensured if this method
// is composed with Repack.
//
// If eval.MinLogN() = eval.MaxLogN(), no evaluation keys are required for this method.
// If eval.MinLogN() < eval.MaxLogN(), only RingSwitchingKeys are required for this method.
func (eval RingPackingEvaluator) ExtractNaive(ct *rlwe.Ciphertext, idx map[int]bool) (cts map[int]*rlwe.Ciphertext, err error) {
return eval.extract(ct, idx, true)
}
// If naive = false, then all non-constant coefficients are zeroed.
func (eval RingPackingEvaluator) extract(ct *rlwe.Ciphertext, idx map[int]bool, naive bool) (cts map[int]*rlwe.Ciphertext, err error) {
logNMax := ct.LogN()
logNMin := eval.MinLogN()
level := ct.Level()
logNFactor := logNMax - logNMin
NFactor := 1 << logNFactor
keys := utils.GetSortedKeys(idx)
_, logGap, err := getMinimumGap(keys)
if err != nil {
return nil, fmt.Errorf("getMinimumGap: %w", err)
}
// First recursively splits the ciphertexts into smaller ciphertexts of half the ring
// degree until the minimum ring degre is reached
tmpCts := make(map[int]*rlwe.Ciphertext)
tmpCts[0] = ct.CopyNew()
for i := 0; i < logNFactor; i++ {
t := 1 << i
// Each split of the ring divides the gap a factor of two
logGap = utils.Max(0, logGap-1)
for j := 0; j < t; j++ {
if tmpCts[j] != nil {
ctEvenNHalf := rlwe.NewCiphertext(eval.Parameters[logNMax-i-1], 1, level)
ctOddNHalf := rlwe.NewCiphertext(eval.Parameters[logNMax-i-1], 1, level)
if err = eval.Split(tmpCts[j], ctEvenNHalf, ctOddNHalf); err != nil {
return nil, fmt.Errorf("eval.split(cts[%d]): %w", j, err)
}
tmpCts[j] = ctEvenNHalf
tmpCts[j+t] = ctOddNHalf
}
}
}
gap := 1 << logGap
// Applies the same split on the index map, but also update the
// indexes to take into account the new ordering
buckets := make(map[int][]int)
for _, i := range keys {
bucket := i & (NFactor - 1)
buckets[bucket] = append(buckets[bucket], i/NFactor)
}
// For each small ciphertext, extracts the relevant values
cts = make(map[int]*rlwe.Ciphertext)
for i := range buckets {
var ciphertexts map[int]*rlwe.Ciphertext
if naive {
ciphertexts = map[int]*rlwe.Ciphertext{}
for _, j := range buckets[i] {
ciphertexts[j] = tmpCts[i].CopyNew()
}
XInvPow2NTT := eval.XInvPow2NTT[logNMin]
ringQ := eval.Parameters[logNMin].GetRLWEParameters().RingQ().AtLevel(level)
// Rotates ciphertexts to move c[i] * X^{i} -> c[i] * X^{0}
// by sequentially multplying with the appropriate X^{-2^{i}}.
for i := 0; i < logNMin; i++ {
for j := range ciphertexts {
if (j>>i)&1 == 1 {
ct := ciphertexts[j]
ringQ.MulCoeffsMontgomery(ct.Value[0], XInvPow2NTT[i], ct.Value[0])
ringQ.MulCoeffsMontgomery(ct.Value[1], XInvPow2NTT[i], ct.Value[1])
}
}
}
} else {
if ciphertexts, err = eval.Expand(tmpCts[i], logGap); err != nil {
return nil, fmt.Errorf("evalN.expand(tmpCt[%d], %d): %w", i, logGap, err)
}
}
for _, j := range buckets[i] {
if ct, ok := ciphertexts[j]; ok {
cts[i+j*NFactor] = ct
} else {
return nil, fmt.Errorf("invalid ciphertexts map: index i+j*(NFactor*gap)=%d is nil", i+j*(NFactor*gap))
}
}
}
return
}
// Split splits a ciphertext of degree N into two ciphertexts of degree N/2:
// ctN[X] = ctEvenNHalf[Y] + X * ctOddNHalf[Y] where Y = X^2.
func (eval RingPackingEvaluator) Split(ctN, ctEvenNHalf, ctOddNHalf *rlwe.Ciphertext) (err error) {
if eval.MinLogN() == eval.MaxLogN() {
return fmt.Errorf("method is not supported when eval.MinLogN() == eval.MaxLogN()")
}
if ctN.LogN() <= eval.MinLogN() {
return fmt.Errorf("ctN.Log() must be greater than eval.MinLogN()")
}
if ctEvenNHalf == nil {
return fmt.Errorf("ctEvenNHalf cannot be nil")
}
if ctEvenNHalf.LogN() != ctN.LogN()-1 {
return fmt.Errorf("ctEvenNHalf.LogN() must be equal to ctN.LogN()-1")
}
LogN := ctN.LogN()
evalN := eval.Evaluators[LogN]
evkNToNHalf := eval.RingSwitchingKeys[LogN][LogN-1]
ctTmp := rlwe.NewCiphertext(eval.Parameters[LogN], 1, ctN.Level())
// SkN -> SkNHalf
if err = evalN.ApplyEvaluationKey(ctN, evkNToNHalf, ctTmp); err != nil {
return fmt.Errorf("ApplyEvaluationKey: %w", err)
}
r := eval.Parameters[LogN].GetRLWEParameters().RingQ().AtLevel(ctN.Level())
// Maps to smaller ring degree X -> Y = X^{2}
*ctEvenNHalf.MetaData = *ctN.MetaData
rlwe.SwitchCiphertextRingDegreeNTT(ctTmp.El(), r, ctEvenNHalf.El())
ctEvenNHalf.LogDimensions.Cols--
// Maps to smaller ring degree X -> Y = X^{2}
if ctOddNHalf != nil {
if ctOddNHalf.LogN() != ctN.LogN()-1 {
return fmt.Errorf("ctOddNHalf.LogN() must be equal to ctN.LogN()-1")
}
*ctOddNHalf.MetaData = *ctN.MetaData
r.MulCoeffsMontgomery(ctTmp.Value[0], eval.XInvPow2NTT[LogN][0], ctTmp.Value[0])
r.MulCoeffsMontgomery(ctTmp.Value[1], eval.XInvPow2NTT[LogN][0], ctTmp.Value[1])
rlwe.SwitchCiphertextRingDegreeNTT(ctTmp.El(), r, ctOddNHalf.El())
ctOddNHalf.LogDimensions.Cols--
}
return
}
// SplitNew splits a ciphertext of degree N into two ciphertexts of degree N/2:
// ctN[X] = ctEvenNHalf[Y] + X * ctOddNHalf[Y] where Y = X^2.
func (eval RingPackingEvaluator) SplitNew(ctN *rlwe.Ciphertext) (ctEvenNHalf, ctOddNHalf *rlwe.Ciphertext, err error) {
if eval.MinLogN() == eval.MaxLogN() {
return nil, nil, fmt.Errorf("method is not supported when eval.MinLogN() == eval.MaxLogN()")
}
LogN := ctN.LogN()
ctEvenNHalf = rlwe.NewCiphertext(eval.Parameters[LogN-1], 1, ctN.Level())
ctOddNHalf = rlwe.NewCiphertext(eval.Parameters[LogN-1], 1, ctN.Level())
return ctEvenNHalf, ctOddNHalf, eval.Split(ctN, ctEvenNHalf, ctOddNHalf)
}
// Repack takes as input a map of ciphertext and repacks the constant coefficient each ciphertext
// into a single ciphertext of degree eval.MaxLogN() following the indexing of the map.
//
// For example, if cts = map[int]*rlwe.Ciphertext{0:ct0, 1:ct1, 4:ct2}, then the method will return
// a ciphertext encrypting P(X) = ct0[0] + ct1[0] * X + ct2[0] * X^4.
//
// The method accepts ciphertexts of a ring degree between eval.MinLogN() and eval.MaxLogN().
//
// All non-constant coefficient are zeroed during the repacking, thus correctness is ensured if this
// method can be composed with either Extract or ExtractNaive.
func (eval RingPackingEvaluator) Repack(cts map[int]*rlwe.Ciphertext) (ct *rlwe.Ciphertext, err error) {
return eval.repack(cts, false)
}
// RepackNaive takes as input a map of ciphertext and repacks the constant coefficient each ciphertext
// into a single ciphertext of degree eval.MaxLogN() following the indexing of the map.
//
// For example, if cts = map[int]*rlwe.Ciphertext{0:ct0, 1:ct1, 4:ct2}, then the method will return
// a ciphertext encrypting P(X) = ct0[0] + ct1[0] * X + ct2[0] * X^4.
//
// The method accepts ciphertexts of a ring degree between eval.MinLogN() and eval.MaxLogN().
//
// If eval.MinLogN() = eval.MaxLogN(), no evaluation keys are required for this method.
// If eval.MinLogN() < eval.MaxLogN(), only RingSwitchingKeys are required for this method.
//
// Unlike Repack, non-constant coefficient are NOT zeroed during the repacking, thus correctness is only
// ensured if this method is composed with either Extract.
func (eval RingPackingEvaluator) RepackNaive(cts map[int]*rlwe.Ciphertext) (ct *rlwe.Ciphertext, err error) {
return eval.repack(cts, true)
}
func (eval RingPackingEvaluator) repack(cts map[int]*rlwe.Ciphertext, naive bool) (ct *rlwe.Ciphertext, err error) {
keys := utils.GetSortedKeys(cts)
logNMin := cts[keys[0]].LogN()
logNMax := eval.MaxLogN()
level := cts[keys[0]].Level()
logNFactor := logNMax - logNMin
NFactor := 1 << logNFactor
// List of map containing the repacking of cts
ctsSmallN := make([]map[int]*rlwe.Ciphertext, NFactor)
for i := range ctsSmallN {
ctsSmallN[i] = map[int]*rlwe.Ciphertext{}
}
// Assigns to each map the corresponding ciphertext.
// This takes into account the future merging, that merges
// ciphertexts in a base-2 tree-like fashion by evaluating
// ctN[X] = ctEvenNHalf[Y] + X * ctOddNHalf[Y] where Y = X^2.
for _, i := range keys {
ctsSmallN[i&(NFactor-1)][i/NFactor] = cts[i]
}
// Map of repacked ciphertext that will then be merged together.
// Each merging takes two ciphertexts, doubles their ring degree
// and adds them together.
ctsLargeN := map[int]*rlwe.Ciphertext{}
for i := 0; i < NFactor; i++ {
if naive {
tmpCts := ctsSmallN[i]
XPow2NTT := eval.XPow2NTT[logNMin]
ringQ := eval.Parameters[logNMin].GetRLWEParameters().RingQ().AtLevel(level)
for i := 0; i < logNMin; i++ {
t := 1 << (logNMin - 1 - i)
for jx, jy := 0, t; jx < t; jx, jy = jx+1, jy+1 {
a := tmpCts[jx]
b := tmpCts[jy]
if b != nil {
//X^(N/2^L)
ringQ.MulCoeffsMontgomery(b.Value[0], XPow2NTT[len(XPow2NTT)-i-1], b.Value[0])
ringQ.MulCoeffsMontgomery(b.Value[1], XPow2NTT[len(XPow2NTT)-i-1], b.Value[1])
if a != nil {
// a = a + b * X^{N/2^{i}}
ringQ.Add(a.Value[0], b.Value[0], a.Value[0])
ringQ.Add(a.Value[1], b.Value[1], a.Value[1])
} else {
// if ct[jx] == nil, then simply re-assigns
tmpCts[jx] = tmpCts[jy]
}
tmpCts[jy] = nil
}
}
}
ctsLargeN[i] = tmpCts[0]
} else {
if len(ctsSmallN[i]) != 0 {
if ctsLargeN[i], err = eval.Pack(ctsSmallN[i], logNMin, true); err != nil {
return nil, fmt.Errorf("eval.pack(ctsSmallN[%d], logGap=%d, true): %w", i, logNMin, err)
}
}
}
}
// Merges the cipehrtexts in a base-2 tree like fashion.
for i := logNFactor - 1; i >= 0; i-- {
t := 1 << i
for j := 0; j < t; j++ {
if ctsLargeN[j] != nil || ctsLargeN[j+1] != nil {
ctN := rlwe.NewCiphertext(eval.Parameters[logNMax-i], 1, level)
if err = eval.Merge(ctsLargeN[j], ctsLargeN[j+t], ctN); err != nil {
return nil, fmt.Errorf("eval.split(cts[%d]): %w", j, err)
}
ctsLargeN[j] = ctN
ctsLargeN[j+t] = nil
}
}
}
return ctsLargeN[0], nil
}
// Merge merges two ciphertexts of degree N/2 into a ciphertext of degre N:
// ctN[X] = ctEvenNHalf[Y] + X * ctOddNHalf[Y] where Y = X^2.
func (eval RingPackingEvaluator) Merge(ctEvenNHalf, ctOddNHalf, ctN *rlwe.Ciphertext) (err error) {
if eval.MinLogN() == eval.MaxLogN() {
return fmt.Errorf("method is not supported when eval.MinLogN() == eval.MaxLogN()")
}
if ctEvenNHalf == nil {
return fmt.Errorf("ctEvenNHalf cannot be nil")
}
if ctEvenNHalf.LogN() >= eval.MaxLogN() {
return fmt.Errorf("ctEvenNHalf.LogN() must be smaller than eval.MaxLogN()")
}
if ctN.LogN() != ctEvenNHalf.LogN()+1 {
return fmt.Errorf("ctN.LogN() must be equal to ctEvenNHalf.LogN()+1")
}
if ctOddNHalf != nil {
if ctEvenNHalf.LogN() != ctOddNHalf.LogN() {
return fmt.Errorf("ctEvenNHalf.LogN() and ctOddNHalf.LogN() must be equal")
}
}
LogN := ctN.LogN()
evalN := eval.Evaluators[LogN]
evkNHalfToN := eval.RingSwitchingKeys[LogN-1][LogN]
r := eval.Parameters[LogN].GetRLWEParameters().RingQ().AtLevel(ctN.Level())
ctTmp := rlwe.NewCiphertext(eval.Parameters[LogN], 1, ctN.Level())
if ctEvenNHalf != nil {
*ctN.MetaData = *ctEvenNHalf.MetaData
rlwe.SwitchCiphertextRingDegreeNTT(ctEvenNHalf.El(), r, ctN.El())
if ctOddNHalf != nil {
rlwe.SwitchCiphertextRingDegreeNTT(ctOddNHalf.El(), r, ctTmp.El())
r.MulCoeffsMontgomeryThenAdd(ctTmp.Value[0], eval.XPow2NTT[LogN][0], ctN.Value[0])
r.MulCoeffsMontgomeryThenAdd(ctTmp.Value[1], eval.XPow2NTT[LogN][0], ctN.Value[1])
}
}
// SkNHalf -> SkN
if err = evalN.ApplyEvaluationKey(ctN, evkNHalfToN, ctN); err != nil {
return fmt.Errorf("evalN.ApplyEvaluationKey(ctN, evkNToNHalf, ctN): %w", err)
}
ctN.LogDimensions.Cols++
return
}
// MergeNew merges two ciphertexts of degree N/2 into a ciphertext of degre N:
// ctN[X] = ctEvenNHalf[Y] + X * ctOddNHalf[Y] where Y = X^2.
func (eval RingPackingEvaluator) MergeNew(ctEvenNHalf, ctOddNHalf *rlwe.Ciphertext) (ctN *rlwe.Ciphertext, err error) {
if eval.MinLogN() == eval.MaxLogN() {
return nil, fmt.Errorf("method is not supported when eval.MinLogN() == eval.MaxLogN()")
}
if ctEvenNHalf == nil {
return nil, fmt.Errorf("ctEvenNHalf cannot be nil")
}
if ctEvenNHalf.LogN() >= eval.MaxLogN() {
return nil, fmt.Errorf("ctEvenNHalf.LogN() must be smaller than eval.MaxLogN()")
}
ctN = rlwe.NewCiphertext(eval.Parameters[ctEvenNHalf.LogN()+1], 1, ctEvenNHalf.Level())
return ctN, eval.Merge(ctEvenNHalf, ctOddNHalf, ctN)
}
// Expand expands a RLWE Ciphertext encrypting P(X) = ci * X^i and returns a map of
// ciphertexts, each encrypting P(X) = ci * X^0, indexed by i, for 0<= i < 2^{logN}
// and i divisible by 2^{logGap}.
//
// This method is a used as a sub-routine of the Extract method.
//
// The method will return an error if:
// - The input ciphertext degree is not one
// - The ring type is not ring.Standard
func (eval RingPackingEvaluator) Expand(ct *rlwe.Ciphertext, logGap int) (cts map[int]*rlwe.Ciphertext, err error) {
if ct.Degree() != 1 {
return nil, fmt.Errorf("ct.Degree() != 1")
}
logN := ct.LogN()
var params rlwe.Parameters
if p, ok := eval.Parameters[logN]; !ok {
return nil, fmt.Errorf("eval.Parameters[%d] is nil", logN)
} else {
params = *p.GetRLWEParameters()
}
if eval.ExtractKeys == nil {
return nil, fmt.Errorf("eval.ExtractKeys is nil")
}
var evk rlwe.EvaluationKeySet
if p, ok := eval.ExtractKeys[params.LogN()]; !ok {
return nil, fmt.Errorf("eval.ExtractKeys[%d] is nil", params.LogN())
} else {
evk = p
}
evalN := eval.Evaluators[params.LogN()].WithKey(evk)
xPow2 := eval.XInvPow2NTT[params.LogN()]
level := ct.Level()
ringQ := params.RingQ().AtLevel(level)
if params.RingType() != ring.Standard {
return nil, fmt.Errorf("method is only supported for ring.Type = ring.Standard (X^{-2^{i}} does not exist in the sub-ring Z[X + X^{-1}])")
}
cts = map[int]*rlwe.Ciphertext{}
cts[0] = ct.CopyNew()
cts[0].LogDimensions = ring.Dimensions{Rows: 0, Cols: 0}
if ct := cts[0]; !ct.IsNTT {
ringQ.NTT(ct.Value[0], ct.Value[0])
ringQ.NTT(ct.Value[1], ct.Value[1])
ct.IsNTT = true
}
// Multiplies by 2^{-logN} mod Q
NInv := new(big.Int).SetUint64(1 << logN)
NInv.ModInverse(NInv, ringQ.ModulusAtLevel[level])
ringQ.MulScalarBigint(cts[0].Value[0], NInv, cts[0].Value[0])
ringQ.MulScalarBigint(cts[0].Value[1], NInv, cts[0].Value[1])
gap := 1 << logGap
tmp, err := rlwe.NewCiphertextAtLevelFromPoly(level, []ring.Poly{evalN.BuffCt.Value[0], evalN.BuffCt.Value[1]})
// Sanity check, this error should not happen unless the
// evaluator's buffer thave been improperly tempered with.
if err != nil {
panic(err)
}
*tmp.MetaData = *ct.MetaData
for i := 0; i < logN; i++ {
n := 1 << i
galEl := uint64(ringQ.N()/n + 1)
for j := 0; j < n; j += gap {
c0 := cts[j]
// X -> X^{N/n + 1}
//[a, b, c, d] -> [a, -b, c, -d]
if err = evalN.Automorphism(c0, galEl, tmp); err != nil {
return nil, fmt.Errorf("evalN.Automorphism(c0, galEl, tmp): %w", err)
}
if j+n/gap > 0 {
c1 := cts[j].CopyNew()
// Zeroes odd coeffs: [a, b, c, d] + [a, -b, c, -d] -> [2a, 0, 2b, 0]
ringQ.Add(c0.Value[0], tmp.Value[0], c0.Value[0])
ringQ.Add(c0.Value[1], tmp.Value[1], c0.Value[1])
// Zeroes even coeffs: [a, b, c, d] - [a, -b, c, -d] -> [0, 2b, 0, 2d]
ringQ.Sub(c1.Value[0], tmp.Value[0], c1.Value[0])
ringQ.Sub(c1.Value[1], tmp.Value[1], c1.Value[1])
// c1 * X^{-2^{i}}: [0, 2b, 0, 2d] * X^{-n} -> [2b, 0, 2d, 0]
ringQ.MulCoeffsMontgomery(c1.Value[0], xPow2[i], c1.Value[0])
ringQ.MulCoeffsMontgomery(c1.Value[1], xPow2[i], c1.Value[1])
cts[j+n] = c1
} else {
// Zeroes odd coeffs: [a, b, c, d] + [a, -b, c, -d] -> [2a, 0, 2b, 0]
ringQ.Add(c0.Value[0], tmp.Value[0], c0.Value[0])
ringQ.Add(c0.Value[1], tmp.Value[1], c0.Value[1])
}
}
}
for _, ct := range cts {
if ct != nil && !ct.IsNTT {
ringQ.INTT(ct.Value[0], ct.Value[0])
ringQ.INTT(ct.Value[1], ct.Value[1])
ct.IsNTT = false
}
}
return
}
// Pack packs a map of of ciphertexts, each encrypting Pi(X) = ci * X^{i} for 0 <= i * 2^{inputLogGap} < 2^{LogN}
// and indexed by j, for 0<= j < 2^{eval.MaxLogN()} and returns ciphertext encrypting P(X) = Pi(X) * X^i.
// zeroGarbageSlots: if set to true, slots which are not multiples of X^{2^{logGap}} will be zeroed during the procedure.
//
// The method will return an error if:
// - The number of ciphertexts is 0
// - Any input ciphertext degree is not one
// - Gaps between ciphertexts is smaller than inputLogGap > N
// - The ring type is not ring.Standard
//
// Example: we want to pack 4 ciphertexts into one, and keep only coefficients which are a multiple of X^{4}.
//
// To do so, we must set logGap = 2.
// Here the `X` slots are treated as garbage slots that we want to discard during the procedure.
//
// input: map[int]{
// 0: [x00, X, X, X, x01, X, X, X], with logGap = 2
// 1: [x10, X, X, X, x11, X, X, X],
// 2: [x20, X, X, X, x21, X, X, X],
// 3: [x30, X, X, X, x31, X, X, X],
// }
//
// Step 1:
// map[0]: 2^{-1} * (map[0] + X^2 * map[2] + phi_{5^2}(map[0] - X^2 * map[2]) = [x00, X, x20, X, x01, X, x21, X]
// map[1]: 2^{-1} * (map[1] + X^2 * map[3] + phi_{5^2}(map[1] - X^2 * map[3]) = [x10, X, x30, X, x11, X, x31, X]
// Step 2:
// map[0]: 2^{-1} * (map[0] + X^1 * map[1] + phi_{5^4}(map[0] - X^1 * map[1]) = [x00, x10, x20, x30, x01, x11, x21, x22]
func (eval RingPackingEvaluator) Pack(cts map[int]*rlwe.Ciphertext, inputLogGap int, zeroGarbageSlots bool) (ct *rlwe.Ciphertext, err error) {
if len(cts) == 0 {
return nil, fmt.Errorf("len(cts) = 0")
}
keys := utils.GetSortedKeys(cts)
logN := cts[keys[0]].LogN()
var params rlwe.Parameters
if p, ok := eval.Parameters[logN]; !ok {
return nil, fmt.Errorf("eval.Parameters[%d] is nil", logN)
} else {
params = *p.GetRLWEParameters()
}
if eval.RepackKeys == nil {
return nil, fmt.Errorf("eval.RepackKeys is nil")
}
var evk rlwe.EvaluationKeySet
if p, ok := eval.RepackKeys[params.LogN()]; !ok {
return nil, fmt.Errorf("eval.RepackKeys[%d] is nil", params.LogN())
} else {
evk = p
}
evalN := eval.Evaluators[params.LogN()].WithKey(evk)
xPow2 := eval.XPow2NTT[params.LogN()]
if params.RingType() != ring.Standard {
return nil, fmt.Errorf("procedure is only supported for ring.Type = ring.Standard (X^{2^{i}} does not exist in the sub-ring Z[X + X^{-1}])")
}
level := cts[keys[0]].Level()
var gap, logGap int
if len(keys) > 1 {
if gap, logGap, err = getMinimumGap(keys); err != nil {
return nil, fmt.Errorf("getMinimumGap: %w", err)
}
} else {
gap = params.N()
logGap = params.LogN()
}
ringQ := params.RingQ().AtLevel(level)
logStart := logN - inputLogGap
logEnd := logN
if !zeroGarbageSlots {
if gap > 0 {
logEnd -= logGap
}
}
if logStart >= logEnd {
return nil, fmt.Errorf("gaps between ciphertexts is smaller than inputLogGap > N")
}
NInv := new(big.Int).SetUint64(uint64(1 << (logEnd - logStart)))
NInv.ModInverse(NInv, ringQ.ModulusAtLevel[level])
for _, key := range keys {
ct := cts[key]
if ct.Degree() != 1 {
return nil, fmt.Errorf("cts[%d].Degree() != 1", key)
}
if !ct.IsNTT {
ringQ.NTT(ct.Value[0], ct.Value[0])
ringQ.NTT(ct.Value[1], ct.Value[1])
ct.IsNTT = true
}
ringQ.MulScalarBigint(ct.Value[0], NInv, ct.Value[0])
ringQ.MulScalarBigint(ct.Value[1], NInv, ct.Value[1])
}
tmpa := &rlwe.Ciphertext{}
tmpa.Value = []ring.Poly{ringQ.NewPoly(), ringQ.NewPoly()}
tmpa.MetaData = &rlwe.MetaData{}
tmpa.MetaData.IsNTT = true
for i := logStart; i < logEnd; i++ {
t := 1 << (logN - 1 - i)
for jx, jy := 0, t; jx < t; jx, jy = jx+1, jy+1 {
a := cts[jx]
b := cts[jy]
if b != nil {
//X^(N/2^L)
ringQ.MulCoeffsMontgomery(b.Value[0], xPow2[len(xPow2)-i-1], b.Value[0])
ringQ.MulCoeffsMontgomery(b.Value[1], xPow2[len(xPow2)-i-1], b.Value[1])
if a != nil {
// tmpa = phi(a - b * X^{N/2^{i}}, 2^{i-1})
ringQ.Sub(a.Value[0], b.Value[0], tmpa.Value[0])
ringQ.Sub(a.Value[1], b.Value[1], tmpa.Value[1])
// a = a + b * X^{N/2^{i}}
ringQ.Add(a.Value[0], b.Value[0], a.Value[0])
ringQ.Add(a.Value[1], b.Value[1], a.Value[1])
} else {
// if ct[jx] == nil, then simply re-assigns
cts[jx] = cts[jy]
}
cts[jy] = nil
}
if a != nil {
var galEl uint64
if i == 0 {
galEl = ringQ.NthRoot() - 1
} else {
galEl = params.GaloisElement(1 << (i - 1))
}
if b != nil {
if err = evalN.Automorphism(tmpa, galEl, tmpa); err != nil {
return nil, fmt.Errorf("evalN.Automorphism(tmpa, galEl, tmpa): %w", err)
}
} else {
if err = evalN.Automorphism(a, galEl, tmpa); err != nil {
return nil, fmt.Errorf("evalN.Automorphism(a, galEl, tmpa): %w", err)
}
}
// a + b * X^{N/2^{i}} + phi(a - b * X^{N/2^{i}}, 2^{i-1})
ringQ.Add(a.Value[0], tmpa.Value[0], a.Value[0])
ringQ.Add(a.Value[1], tmpa.Value[1], a.Value[1])
} else if b != nil {
var galEl uint64
if i == 0 {
galEl = ringQ.NthRoot() - 1
} else {
galEl = params.GaloisElement(1 << (i - 1))
}
if err = evalN.Automorphism(b, galEl, tmpa); err != nil {
return nil, fmt.Errorf("evalN.Automorphism(b, galEl, tmpa): %w", err)
}
// b * X^{N/2^{i}} - phi(b * X^{N/2^{i}}, 2^{i-1}))
ringQ.Sub(b.Value[0], tmpa.Value[0], b.Value[0])
ringQ.Sub(b.Value[1], tmpa.Value[1], b.Value[1])
}
}
}
return cts[0], nil
}
// GenXPow2NTT generates X^({-1 if div else 1} * {2^{0 <= i < LogN}}) in NTT.
func GenXPow2NTT(r *ring.Ring, logN int, div bool) (xPow []ring.Poly) {
// Compute X^{-n} from 0 to LogN
xPow = make([]ring.Poly, logN)
moduli := r.ModuliChain()[:r.Level()+1]
BRC := r.GetBRedConstants()
var idx int
for i := 0; i < logN; i++ {
idx = 1 << i
if div {
idx = r.N() - idx
}
xPow[i] = r.NewPoly()
if i == 0 {
for j := range moduli {
xPow[i].Coeffs[j][idx] = ring.MForm(1, moduli[j], BRC[j])
}
r.NTT(xPow[i], xPow[i])
} else {
r.MulCoeffsMontgomery(xPow[i-1], xPow[i-1], xPow[i]) // X^{n} = X^{1} * X^{n-1}
}
}
if div {
r.Neg(xPow[0], xPow[0])
}
return
}
func getMinimumGap(list []int) (gap, logGap int, err error) {
// The loops over to find the smallest gap
gap = 0x7fffffffffffffff // 2^{63}-1
for i := 1; i < len(list); i++ {
a, b := list[i-1], list[i]
if a > b {
return gap, logGap, fmt.Errorf("invalid index list: element must be sorted from smallest to largest")
} else if a == b {
return gap, logGap, fmt.Errorf("invalid index list: contains duplicated elements")
}
if tmp := b - a; tmp < gap {
gap = tmp
}
if gap == 1 {
break
}
}
// Sets gap to the largest power-of-two that divides it.
// We will then discart all coefficients that are not a
// multiple of this gap (and thus possibly entire ciph-
// ertexts).
for gap&1 == 0 {
logGap++
gap >>= 1
}
return
}

181
he/ring_packing_keys.go Normal file
View File

@@ -0,0 +1,181 @@
package he
import (
"fmt"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils"
)
// RingPackingEvaluationKey is a struct storing the
// ring packing evaluation keys.
// All fields of this struct are public, enabling
// custom instantiations.
type RingPackingEvaluationKey struct {
// Parameters are the different Parameters among
// which a ciphertext will be switched during the
// procedure. These parameters share the same primes
// but support different ring degrees.
Parameters map[int]rlwe.ParameterProvider
// RingSwitchingKeys are the ring degree switching keys
// indexed as map[inputLogN][outputLogN]
RingSwitchingKeys map[int]map[int]*rlwe.EvaluationKey
// RepackKeys are the [rlwe.EvaluationKey] used for the
// RLWE repacking.
RepackKeys map[int]rlwe.EvaluationKeySet
// ExtractKeys are the [rlwe.EvaluationKey] used for the
// RLWE extraction.
ExtractKeys map[int]rlwe.EvaluationKeySet
}
// MinLogN returns the minimum Log(N) among the supported ring degrees.
// This method requires that the field Parameters of [RingPackingEvaluationKey]
// has been populated.
func (rpk RingPackingEvaluationKey) MinLogN() (minLogN int) {
return utils.GetSortedKeys(rpk.Parameters)[0]
}
// MaxLogN returns the maximum Log(N) among the supported ring degrees.
// This method requires that the field Parameters of [RingPackingEvaluationKey]
// has been populated.
func (rpk RingPackingEvaluationKey) MaxLogN() (maxLogN int) {
return utils.GetSortedKeys(rpk.Parameters)[len(rpk.Parameters)-1]
}
// GenRingSwitchingKeys generates the [rlwe.Parameter]s and [rlwe.EvaluationKey]s
// to be able to split an [rlwe.Ciphertext] into two [rlwe.Ciphertext]s of half
// the ring degree and merge two [rlwe.Ciphertext]s into one [rlwe.Ciphertext]
// of twice the ring degree.
//
// The method returns the [rlwe.Parameter]s, [rlwe.EvaluationKey]s and ephemeral
// [rlwe.SecretKey]s used to generate the ring-switching [rlwe.EvaluationKey]s.
//
// See the methods [RingPackingEvaluator.Split] and [RingPackingEvaluator.Repack].
//
// This function will return an error if minLogN >= params.LogN().
func (rpk *RingPackingEvaluationKey) GenRingSwitchingKeys(params rlwe.ParameterProvider, sk *rlwe.SecretKey, minLogN int, evkParams rlwe.EvaluationKeyParameters) (ski map[int]*rlwe.SecretKey, err error) {
p := *params.GetRLWEParameters()
if minLogN >= p.LogN() {
return nil, fmt.Errorf("invalid minLogN: cannot be equal or larger than params.LogN()")
}
LevelQ, LevelP, _ := rlwe.ResolveEvaluationKeyParameters(p, []rlwe.EvaluationKeyParameters{evkParams})
Q := p.Q()
P := p.P()
Parameters := map[int]rlwe.ParameterProvider{}
Parameters[p.LogN()] = &p
ski = map[int]*rlwe.SecretKey{}
ski[p.LogN()] = sk
kgen := map[int]*rlwe.KeyGenerator{}
kgen[p.LogN()] = rlwe.NewKeyGenerator(p)
for i := minLogN; i < p.LogN(); i++ {
var pi rlwe.Parameters
if pi, err = rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{
LogN: i,
Q: Q[:LevelQ+1],
P: P[:LevelP+1],
NTTFlag: p.NTTFlag(),
DefaultScale: p.DefaultScale(),
}); err != nil {
return nil, fmt.Errorf("rlwe.NewParametersFromLiteral: %w", err)
}
kgen[i] = rlwe.NewKeyGenerator(pi)
ski[i] = kgen[i].GenSecretKeyNew()
Parameters[i] = &pi
}
// Ring switching evaluation keys
RingSwitchingKeys := map[int]map[int]*rlwe.EvaluationKey{}
for i := minLogN; i < p.LogN()+1; i++ {
RingSwitchingKeys[i] = map[int]*rlwe.EvaluationKey{}
}
for i := minLogN; i < p.LogN(); i++ {
RingSwitchingKeys[i][i+1] = kgen[i+1].GenEvaluationKeyNew(ski[i], ski[i+1], evkParams)
RingSwitchingKeys[i+1][i] = kgen[i+1].GenEvaluationKeyNew(ski[i+1], ski[i], evkParams)
}
rpk.Parameters = Parameters
rpk.RingSwitchingKeys = RingSwitchingKeys
return ski, nil
}
// GenRepackEvaluationKeys generates the set of params.LogN() [rlwe.EvaluationKey]s necessary to perform the repacking operation.
// See [RingPackingEvaluator.Repack] for additional information.
func (rpk *RingPackingEvaluationKey) GenRepackEvaluationKeys(params rlwe.ParameterProvider, sk *rlwe.SecretKey, evkParams rlwe.EvaluationKeyParameters) {
p := *params.GetRLWEParameters()
if rpk.RepackKeys == nil {
rpk.RepackKeys = map[int]rlwe.EvaluationKeySet{}
}
rpk.RepackKeys[p.LogN()] = rlwe.NewMemEvaluationKeySet(nil, rlwe.NewKeyGenerator(p).GenGaloisKeysNew(GaloisElementsForPack(p, p.LogN()), sk, evkParams)...)
}
// GenExtractEvaluationKeys generates the set of params.LogN() [rlwe.EvaluationKey]s necessary to perform the extraction operation.
// See [RingPackingEvaluator.Extract] for additional information.
func (rpk *RingPackingEvaluationKey) GenExtractEvaluationKeys(params rlwe.ParameterProvider, sk *rlwe.SecretKey, evkParams rlwe.EvaluationKeyParameters) {
p := *params.GetRLWEParameters()
if rpk.ExtractKeys == nil {
rpk.ExtractKeys = map[int]rlwe.EvaluationKeySet{}
}
rpk.ExtractKeys[p.LogN()] = rlwe.NewMemEvaluationKeySet(nil, rlwe.NewKeyGenerator(p).GenGaloisKeysNew(GaloisElementsForExpand(p, p.LogN()), sk, evkParams)...)
}
// GaloisElementsForExpand returns the list of Galois elements required
// to perform the `Expand` operation with parameter `logN`.
func GaloisElementsForExpand(params rlwe.ParameterProvider, logN int) (galEls []uint64) {
galEls = make([]uint64, logN)
NthRoot := params.GetRLWEParameters().RingQ().NthRoot()
for i := 0; i < logN; i++ {
galEls[i] = uint64(NthRoot/(2<<i) + 1)
}
return
}
// GaloisElementsForPack returns the list of Galois elements required to perform the `Pack` operation.
func GaloisElementsForPack(params rlwe.ParameterProvider, logGap int) (galEls []uint64) {
p := params.GetRLWEParameters()
// Sanity check
if logGap > p.LogN() || logGap < 0 {
panic(fmt.Errorf("cannot GaloisElementsForPack: logGap > logN || logGap < 0"))
}
galEls = make([]uint64, 0, logGap)
for i := 0; i < logGap; i++ {
galEls = append(galEls, p.GaloisElement(1<<i))
}
switch p.RingType() {
case ring.Standard:
if logGap == p.LogN() {
galEls = append(galEls, p.GaloisElementOrderTwoOrthogonalSubgroup())
}
default:
panic("cannot GaloisElementsForPack: invalid ring type")
}
return
}

540
he/ring_packing_test.go Normal file
View File

@@ -0,0 +1,540 @@
package he
import (
"fmt"
"math"
"math/big"
"math/bits"
"math/rand"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v5/core/rlwe"
"github.com/tuneinsight/lattigo/v5/ring"
"github.com/tuneinsight/lattigo/v5/utils"
)
const (
LogNLarge = 10
LogNSmall = 8
)
func testString(params rlwe.Parameters, opname string) string {
return fmt.Sprintf("%s/logN=%d/Qi=%d/Pi=%d/NTT=%t",
opname,
params.LogN(),
params.QCount(),
params.PCount(),
params.NTTFlag())
}
func TestRLWE(t *testing.T) {
var err error
paramsLit := rlwe.ParametersLiteral{
LogN: LogNLarge,
LogQ: []int{60},
LogP: []int{60},
NTTFlag: true,
}
for _, NTTFlag := range []bool{true, false} {
paramsLit.NTTFlag = NTTFlag
var params rlwe.Parameters
if params, err = rlwe.NewParametersFromLiteral(paramsLit); err != nil {
t.Fatal(err)
}
tc, err := NewTestContext(params)
require.NoError(t, err)
for _, testSet := range []func(tc *TestContext, t *testing.T){
testRingPacking,
} {
testSet(tc, t)
runtime.GC()
}
}
}
type TestContext struct {
params rlwe.Parameters
kgen *rlwe.KeyGenerator
enc *rlwe.Encryptor
dec *rlwe.Decryptor
sk *rlwe.SecretKey
pk *rlwe.PublicKey
}
func NewTestContext(params rlwe.Parameters) (tc *TestContext, err error) {
kgen := rlwe.NewKeyGenerator(params)
sk := kgen.GenSecretKeyNew()
pk := kgen.GenPublicKeyNew(sk)
enc := rlwe.NewEncryptor(params, sk)
dec := rlwe.NewDecryptor(params, sk)
return &TestContext{
params: params,
kgen: kgen,
sk: sk,
pk: pk,
enc: enc,
dec: dec,
}, nil
}
func testRingPacking(tc *TestContext, t *testing.T) {
params := tc.params
sk := tc.sk
enc := tc.enc
dec := tc.dec
level := params.MaxLevel()
evkParams := rlwe.EvaluationKeyParameters{
LevelQ: utils.Pointy(params.MaxLevelQ()),
LevelP: utils.Pointy(params.MaxLevelP()),
}
evkRP := RingPackingEvaluationKey{}
ski, err := evkRP.GenRingSwitchingKeys(params, sk, LogNSmall, evkParams)
require.NoError(t, err)
evkRP.GenRepackEvaluationKeys(evkRP.Parameters[LogNSmall], ski[LogNSmall], evkParams)
evkRP.GenRepackEvaluationKeys(evkRP.Parameters[params.LogN()], ski[params.LogN()], evkParams)
evkRP.GenExtractEvaluationKeys(evkRP.Parameters[LogNSmall], ski[LogNSmall], evkParams)
eval := NewRingPackingEvaluator(&evkRP)
t.Run(testString(params, "Split"), func(t *testing.T) {
pt := genPlaintextNTT(params, level, 1<<40)
ct, err := enc.EncryptNew(pt)
require.NoError(t, err)
ctEvenNHalf, ctOddNHalf, err := eval.SplitNew(ct)
if eval.MaxLogN() == eval.MinLogN() {
require.Error(t, err)
t.Skip("eval.MaxLogN() = eval.MinLogN()")
} else {
require.NoError(t, err)
paramsNHalf := eval.Parameters[ctEvenNHalf.LogN()].GetRLWEParameters()
r := paramsNHalf.RingQ().AtLevel(ct.Level())
decNHalf := rlwe.NewDecryptor(paramsNHalf, ski[paramsNHalf.LogN()])
ptEve := decNHalf.DecryptNew(ctEvenNHalf)
ptOdd := decNHalf.DecryptNew(ctOddNHalf)
if ptEve.IsNTT {
r.INTT(ptEve.Value, ptEve.Value)
}
if ptOdd.IsNTT {
r.INTT(ptOdd.Value, ptOdd.Value)
}
if pt.IsNTT {
params.RingQ().AtLevel(ct.Level()).INTT(pt.Value, pt.Value)
}
for i := 0; i < level+1; i++ {
Q := r.SubRings[i].Modulus
ref := pt.Value.Coeffs[i]
eve := ptEve.Value.Coeffs[i]
odd := ptOdd.Value.Coeffs[i]
for j := 0; j < paramsNHalf.N(); j++ {
eve[j] = ring.CRed(eve[j]+Q-ref[j*2+0], Q)
odd[j] = ring.CRed(odd[j]+Q-ref[j*2+1], Q)
}
}
require.GreaterOrEqual(t, float64(paramsNHalf.LogN()+1), r.Log2OfStandardDeviation(ptEve.Value))
require.GreaterOrEqual(t, float64(paramsNHalf.LogN()+1), r.Log2OfStandardDeviation(ptOdd.Value))
}
})
t.Run(testString(params, "Merge"), func(t *testing.T) {
if eval.MaxLogN() == eval.MinLogN() {
t.Skip("eval.MaxLogN() = eval.MinLogN()")
}
paramsNHalf := *eval.Parameters[params.LogN()-1].GetRLWEParameters()
encNHalf := rlwe.NewEncryptor(paramsNHalf, ski[paramsNHalf.LogN()])
ptEve := genPlaintextNTT(paramsNHalf, level, 1<<40)
ptOdd := genPlaintextNTT(paramsNHalf, level, 1<<40)
ctEve, err := encNHalf.EncryptNew(ptEve)
require.NoError(t, err)
ctOdd, err := encNHalf.EncryptNew(ptOdd)
require.NoError(t, err)
ct, err := eval.MergeNew(ctEve, ctOdd)
require.NoError(t, err)
pt := dec.DecryptNew(ct)
if pt.IsNTT {
params.RingQ().AtLevel(level).INTT(pt.Value, pt.Value)
}
if ptEve.IsNTT {
paramsNHalf.RingQ().AtLevel(level).INTT(ptEve.Value, ptEve.Value)
}
if ptOdd.IsNTT {
paramsNHalf.RingQ().AtLevel(level).INTT(ptOdd.Value, ptOdd.Value)
}
for i := 0; i < level+1; i++ {
Q := params.RingQ().SubRings[i].Modulus
ref := pt.Value.Coeffs[i]
eve := ptEve.Value.Coeffs[i]
odd := ptOdd.Value.Coeffs[i]
for j := 0; j < paramsNHalf.N(); j++ {
ref[2*j+0] = ring.CRed(ref[2*j+0]+Q-eve[j], Q)
ref[2*j+1] = ring.CRed(ref[2*j+1]+Q-odd[j], Q)
}
}
require.GreaterOrEqual(t, float64(params.LogN()+1), params.RingQ().AtLevel(level).Log2OfStandardDeviation(pt.Value))
})
t.Run(testString(params, "Extract/Naive=False"), func(t *testing.T) {
if params.RingType() != ring.Standard {
t.Skip("Expand not supported for ring.Type = ring.ConjugateInvariant")
}
ringQ := params.RingQ().AtLevel(level)
pt := genPlaintextNTT(params, level, 1<<40)
ct, err := enc.EncryptNew(pt)
require.NoError(t, err)
gap := 17
logGap := bits.Len64(uint64(gap))
idx := map[int]bool{}
for i := 0; i < params.N()/gap; i++ {
idx[i*gap] = true
}
ciphertexts, err := eval.Extract(ct, idx)
require.NoError(t, err)
// Checks that the number of returned ciphertexts is equal
// to the size of the index and that each element in the
// index list has a corresponding extracted ciphertext.
require.Equal(t, len(ciphertexts), len(idx))
for i := range idx {
_, ok := ciphertexts[i]
require.True(t, ok)
}
// Decrypts & Checks
if pt.IsNTT {
ringQ.INTT(pt.Value, pt.Value)
}
paramsSmallN := evkRP.Parameters[ciphertexts[0].LogN()].GetRLWEParameters()
ptDec := rlwe.NewPlaintext(paramsSmallN, level)
ringQSmallN := paramsSmallN.RingQ().AtLevel(level)
Q := ringQSmallN.ModuliChain()
decSmallN := rlwe.NewDecryptor(paramsSmallN, ski[paramsSmallN.LogN()])
for i := range idx {
require.Equal(t, ciphertexts[i].LogN(), paramsSmallN.LogN())
decSmallN.Decrypt(ciphertexts[i], ptDec)
if ptDec.IsNTT {
ringQSmallN.INTT(ptDec.Value, ptDec.Value)
}
for j := 0; j < level+1; j++ {
ptDec.Value.Coeffs[j][0] = ring.CRed(ptDec.Value.Coeffs[j][0]+Q[j]-pt.Value.Coeffs[j][i], Q[j])
}
// Logs the noise
require.GreaterOrEqual(t, float64(params.LogN()+logGap+1), ringQSmallN.Log2OfStandardDeviation(ptDec.Value))
}
})
t.Run(testString(params, "Extract/Naive=True"), func(t *testing.T) {
if params.RingType() != ring.Standard {
t.Skip("Expand not supported for ring.Type = ring.ConjugateInvariant")
}
ringQ := params.RingQ().AtLevel(level)
pt := genPlaintextNTT(params, level, 1<<40)
ct, err := enc.EncryptNew(pt)
require.NoError(t, err)
// Generates some extraction index map that contains
// elements which are both not power and where the
// smallest gap is not a power of two (to test the
// worst case)
gap := 17
idx := map[int]bool{}
for i := 0; i < params.N()/gap; i++ {
idx[i*gap] = true
}
// Extract & returns a map containing the extracted RLWE ciphertexts.
ciphertexts, err := eval.ExtractNaive(ct, idx)
require.NoError(t, err)
// Checks that the number of returned ciphertexts is equal
// to the size of the index and that each element in the
// index list has a corresponding extracted ciphertext.
require.Equal(t, len(ciphertexts), len(idx))
for i := range idx {
_, ok := ciphertexts[i]
require.True(t, ok)
}
// Decrypts & Checks
if pt.IsNTT {
ringQ.INTT(pt.Value, pt.Value)
}
paramsSmallN := evkRP.Parameters[ciphertexts[0].LogN()].GetRLWEParameters()
ptDec := rlwe.NewPlaintext(paramsSmallN, level)
ringQSmallN := paramsSmallN.RingQ().AtLevel(level)
Q := ringQSmallN.ModuliChain()
decSmallN := rlwe.NewDecryptor(paramsSmallN, ski[paramsSmallN.LogN()])
for i := range idx {
require.Equal(t, ciphertexts[i].LogN(), paramsSmallN.LogN())
decSmallN.Decrypt(ciphertexts[i], ptDec)
if ptDec.IsNTT {
ringQSmallN.INTT(ptDec.Value, ptDec.Value)
}
for j := 0; j < level+1; j++ {
ptDec.Value.Coeffs[j][0] = ring.CRed(ptDec.Value.Coeffs[j][0]+Q[j]-pt.Value.Coeffs[j][i], Q[j])
}
// Logs the noise
coeffs := []*big.Int{new(big.Int)}
ringQSmallN.PolyToBigintCentered(ptDec.Value, ringQ.N(), coeffs)
noise := math.Log2(math.Abs(float64(coeffs[0].Int64())))
require.GreaterOrEqual(t, float64(params.LogN()), noise)
}
})
t.Run(testString(params, "Repack"), func(t *testing.T) {
if params.RingType() != ring.Standard {
t.Skip("Pack not supported for ring.Type = ring.ConjugateInvariant")
}
pt := rlwe.NewPlaintext(params, level)
ringQ := tc.params.RingQ().AtLevel(level)
ptPacked := genPlaintextNTT(params, level, 1<<40)
ciphertexts := make(map[int]*rlwe.Ciphertext)
// Generates ciphertexts where the i-th ciphertext
// having as constant coefficients the i-th coefficient
// of the plaintext.
// Generates a list of ciphertexts indexed by non-power-of-two
// and where the smallest gap is not a power of two to test
// the worst case.
XInvNTT := GenXPow2NTT(ringQ, 1, true)[0]
gap := 3
for i := 0; i < params.N(); i++ {
if i%gap == 0 {
if ciphertexts[i], err = enc.EncryptNew(ptPacked); err != nil {
t.Fatal(err)
}
}
ringQ.MulCoeffsMontgomery(ptPacked.Value, XInvNTT, ptPacked.Value)
}
// Resets plaintext as it has been modified by being sequentially multiplied with X^-1
ptPacked = genPlaintextNTT(params, level, 1<<40)
// Repacks the ciphertexts
ct, err := eval.Repack(ciphertexts)
require.NoError(t, err)
// Decrypts & Checks
dec.Decrypt(ct, pt)
if pt.IsNTT {
ringQ.INTT(pt.Value, pt.Value)
}
if ptPacked.IsNTT {
ringQ.INTT(ptPacked.Value, ptPacked.Value)
}
for i := 0; i < level+1; i++ {
Q := ringQ.SubRings[i].Modulus
have := pt.Value.Coeffs[i]
ref := ptPacked.Value.Coeffs[i]
for j := 0; j < params.N(); j += gap {
have[j] = ring.CRed(have[j]+Q-ref[j], Q)
}
}
// Logs the noise
require.GreaterOrEqual(t, float64(params.LogN()+5), ringQ.Log2OfStandardDeviation(pt.Value))
})
t.Run(testString(params, "Extract[naive=false]->Permute->Repack[naive=true]"), func(t *testing.T) {
testExtractPermuteRepack(params, level, enc, dec, eval, false, true, t)
})
t.Run(testString(params, "Extract[naive=true]->Permute->Repack[naive=false]"), func(t *testing.T) {
testExtractPermuteRepack(params, level, enc, dec, eval, true, false, t)
})
}
func testExtractPermuteRepack(params rlwe.Parameters, level int, enc *rlwe.Encryptor, dec *rlwe.Decryptor, eval *RingPackingEvaluator, ExtractNaive, RepackNaive bool, t *testing.T) {
if params.RingType() != ring.Standard {
t.Skip("Expand not supported for ring.Type = ring.ConjugateInvariant")
}
ringQ := params.RingQ().AtLevel(level)
N := params.N()
pt := genPlaintextNTT(params, level, 1<<40)
ct, err := enc.EncryptNew(pt)
require.NoError(t, err)
// Ensures that ct is encrypted at the max
// defined ring degree
require.Equal(t, ct.LogN(), eval.MaxLogN())
// Generates a random index selection
// of size N/2 (to test that omitted
// elements output zero coefficients)
r := rand.New(rand.NewSource(0))
list := make([]int, params.N())
for i := range list {
list[i] = i
}
r.Shuffle(len(list), func(i, j int) { list[i], list[j] = list[j], list[i] })
idx := map[int]bool{}
for _, i := range list[:params.N()>>1] {
idx[i] = true
}
// Extract the coefficients at the given index
var cts map[int]*rlwe.Ciphertext
if ExtractNaive {
cts, err = eval.ExtractNaive(ct, idx)
} else {
cts, err = eval.Extract(ct, idx)
}
require.NoError(t, err)
// Checks that the output ciphertext match the smallest
// defined ring degree
for i := range cts {
require.Equal(t, cts[i].LogN(), eval.MinLogN())
}
// Defines a new mapping
permute := func(x int) (y int) {
return ((x + N/2) & (N - 1))
}
// Applies the mapping
ctsPermute := map[int]*rlwe.Ciphertext{}
for i := range cts {
ctsPermute[permute(i)] = cts[i]
}
// Repacks with the new permutation
if RepackNaive {
ct, err = eval.RepackNaive(ctsPermute)
} else {
ct, err = eval.Repack(ctsPermute)
}
require.NoError(t, err)
// Decrypts & Checks
ptHave := dec.DecryptNew(ct)
if pt.IsNTT {
ringQ.INTT(pt.Value, pt.Value)
}
if ptHave.IsNTT {
ringQ.INTT(ptHave.Value, ptHave.Value)
}
for i := 0; i < level+1; i++ {
Q := ringQ.SubRings[i].Modulus
have := ptHave.Value.Coeffs[i]
ref := pt.Value.Coeffs[i]
for k0 := range idx {
k1 := permute(k0)
have[k1] = ring.CRed(have[k1]+Q-ref[k0], Q)
}
}
// Logs the noise
require.GreaterOrEqual(t, float64(params.LogN()+5), ringQ.Log2OfStandardDeviation(ptHave.Value))
}
func genPlaintextNTT(params rlwe.Parameters, level, max int) (pt *rlwe.Plaintext) {
N := params.N()
step := float64(max) / float64(N)
pt = rlwe.NewPlaintext(params, level)
for i := 0; i < level+1; i++ {
c := pt.Value.Coeffs[i]
for j := 0; j < N; j++ {
c[j] = uint64(float64(j) * step)
}
}
params.RingQ().AtLevel(level).NTT(pt.Value, pt.Value)
pt.IsNTT = true
return
}

View File

@@ -18,7 +18,7 @@ const (
// ParametersLiteral is a literal representation of BGV parameters. It has public
// fields and is used to express unchecked user-defined parameters literally into
// Go programs. The NewParametersFromLiteral function is used to generate the actual
// Go programs. The [NewParametersFromLiteral] function is used to generate the actual
// checked parameters from the literal representation.
//
// Users must set the polynomial degree (LogN) and the coefficient modulus, by either setting
@@ -44,8 +44,8 @@ type ParametersLiteral struct {
PlaintextModulus uint64 // Plaintext modulus
}
// GetRLWEParametersLiteral returns the rlwe.ParametersLiteral from the target bgv.ParametersLiteral.
// See the ParametersLiteral type for details on the BGV parameters.
// GetRLWEParametersLiteral returns the [rlwe.ParametersLiteral] from the target [bgv.ParametersLiteral].
// See the [ParametersLiteral] type for details on the BGV parameters.
func (p ParametersLiteral) GetRLWEParametersLiteral() rlwe.ParametersLiteral {
return rlwe.ParametersLiteral{
LogN: p.LogN,
@@ -63,7 +63,7 @@ func (p ParametersLiteral) GetRLWEParametersLiteral() rlwe.ParametersLiteral {
}
// Parameters represents a parameter set for the BGV cryptosystem. Its fields are private and
// immutable. See ParametersLiteral for user-specified parameters.
// immutable. See [ParametersLiteral] for user-specified parameters.
type Parameters struct {
rlwe.Parameters
ringQMul *ring.Ring
@@ -71,8 +71,8 @@ type Parameters struct {
}
// NewParameters instantiate a set of BGV parameters from the generic RLWE parameters and the BGV-specific ones.
// It returns the empty parameters Parameters{} and a non-nil error if the specified parameters are invalid.
// See the ParametersLiteral type for more details on the BGV parameters.
// It returns the empty parameters [Parameters]{} and a non-nil error if the specified parameters are invalid.
// See the [ParametersLiteral] type for more details on the BGV parameters.
func NewParameters(rlweParams rlwe.Parameters, t uint64) (p Parameters, err error) {
if !rlweParams.NTTFlag() {
@@ -127,10 +127,10 @@ func NewParameters(rlweParams rlwe.Parameters, t uint64) (p Parameters, err erro
}, nil
}
// NewParametersFromLiteral instantiate a set of BGV parameters from a ParametersLiteral specification.
// It returns the empty parameters Parameters{} and a non-nil error if the specified parameters are invalid.
// NewParametersFromLiteral instantiate a set of BGV parameters from a [ParametersLiteral] specification.
// It returns the empty parameters [Parameters]{} and a non-nil error if the specified parameters are invalid.
//
// See `rlwe.NewParametersFromLiteral` for default values of the optional fields and other details on the BGV
// See [rlwe.NewParametersFromLiteral] for default values of the optional fields and other details on the BGV
// parameters.
func NewParametersFromLiteral(pl ParametersLiteral) (Parameters, error) {
rlweParams, err := rlwe.NewParametersFromLiteral(pl.GetRLWEParametersLiteral())
@@ -140,7 +140,7 @@ func NewParametersFromLiteral(pl ParametersLiteral) (Parameters, error) {
return NewParameters(rlweParams, pl.PlaintextModulus)
}
// ParametersLiteral returns the ParametersLiteral of the target Parameters.
// ParametersLiteral returns the [ParametersLiteral] of the target Parameters.
func (p Parameters) ParametersLiteral() ParametersLiteral {
return ParametersLiteral{
LogN: p.LogN(),
@@ -182,14 +182,14 @@ func (p Parameters) LogMaxDimensions() ring.Dimensions {
}
}
// MaxSlots returns the total number of entries (`slots`) that a plaintext can store.
// MaxSlots returns the total number of entries (slots) that a plaintext can store.
// This value is obtained by multiplying all dimensions from MaxDimensions.
func (p Parameters) MaxSlots() int {
dims := p.MaxDimensions()
return dims.Rows * dims.Cols
}
// LogMaxSlots returns the total number of entries (`slots`) that a plaintext can store.
// LogMaxSlots returns the total number of entries (slots) that a plaintext can store.
// This value is obtained by summing all log dimensions from LogDimensions.
func (p Parameters) LogMaxSlots() int {
dims := p.LogMaxDimensions()
@@ -254,7 +254,7 @@ func (p Parameters) GaloisElementForRowRotation() uint64 {
}
// GaloisElementsForInnerSum returns the list of Galois elements necessary to apply the method
// `InnerSum` operation with parameters `batch` and `n`.
// InnerSum operation with parameters batch and n.
func (p Parameters) GaloisElementsForInnerSum(batch, n int) (galEls []uint64) {
galEls = rlwe.GaloisElementsForInnerSum(p, batch, n)
if n > p.N()>>1 {
@@ -264,7 +264,7 @@ func (p Parameters) GaloisElementsForInnerSum(batch, n int) (galEls []uint64) {
}
// GaloisElementsForReplicate returns the list of Galois elements necessary to perform the
// `Replicate` operation with parameters `batch` and `n`.
// Replicate operation with parameters batch and n.
func (p Parameters) GaloisElementsForReplicate(batch, n int) (galEls []uint64) {
galEls = rlwe.GaloisElementsForReplicate(p, batch, n)
if n > p.N()>>1 {
@@ -273,23 +273,12 @@ func (p Parameters) GaloisElementsForReplicate(batch, n int) (galEls []uint64) {
return
}
// GaloisElementsForTrace returns the list of Galois elements requored for the for the `Trace` operation.
// GaloisElementsForTrace returns the list of Galois elements requored for the for the Trace operation.
// Trace maps X -> sum((-1)^i * X^{i*n+1}) for 2^{LogN} <= i < N.
func (p Parameters) GaloisElementsForTrace(logN int) []uint64 {
return rlwe.GaloisElementsForTrace(p, logN)
}
// GaloisElementsForExpand returns the list of Galois elements required
// to perform the `Expand` operation with parameter `logN`.
func (p Parameters) GaloisElementsForExpand(logN int) []uint64 {
return rlwe.GaloisElementsForExpand(p, logN)
}
// GaloisElementsForPack returns the list of Galois elements required to perform the `Pack` operation.
func (p Parameters) GaloisElementsForPack(logN int) []uint64 {
return rlwe.GaloisElementsForPack(p, logN)
}
// Equal compares two sets of parameters for equality.
func (p Parameters) Equal(other *Parameters) bool {
return p.Parameters.Equal(&other.Parameters) && (p.PlaintextModulus() == other.PlaintextModulus())
@@ -307,12 +296,12 @@ func (p *Parameters) UnmarshalBinary(data []byte) (err error) {
return p.UnmarshalJSON(data)
}
// MarshalJSON returns a JSON representation of this parameter set. See `Marshal` from the `encoding/json` package.
// MarshalJSON returns a JSON representation of this parameter set. See Marshal from the [encoding/json] package.
func (p Parameters) MarshalJSON() ([]byte, error) {
return json.Marshal(p.ParametersLiteral())
}
// UnmarshalJSON reads a JSON representation of a parameter set into the receiver Parameter. See `Unmarshal` from the `encoding/json` package.
// UnmarshalJSON reads a JSON representation of a parameter set into the receiver Parameter. See Unmarshal from the [encoding/json] package.
func (p *Parameters) UnmarshalJSON(data []byte) (err error) {
var params ParametersLiteral
if err = json.Unmarshal(data, &params); err != nil {

View File

@@ -16,8 +16,8 @@ import (
// This also sets how many primes are consumed per rescaling.
//
// There are currently two modes supported:
// - PREC64 (one 64 bit word)
// - PREC128 (two 64 bit words)
// - PREC64 (one 64-bit word)
// - PREC128 (two 64-bit words)
//
// PREC64 is the default mode and supports reference plaintext scaling
// factors of up to 2^{64}, while PREC128 scaling factors of up to 2^{128}.
@@ -34,7 +34,7 @@ const (
// ParametersLiteral is a literal representation of CKKS parameters. It has public
// fields and is used to express unchecked user-defined parameters literally into
// Go programs. The NewParametersFromLiteral function is used to generate the actual
// Go programs. The [NewParametersFromLiteral] function is used to generate the actual
// checked parameters from the literal representation.
//
// Users must set the polynomial degree (in log_2, LogN) and the coefficient modulus, by either setting
@@ -57,7 +57,7 @@ type ParametersLiteral struct {
LogDefaultScale int
}
// GetRLWEParametersLiteral returns the rlwe.ParametersLiteral from the target ckks.ParameterLiteral.
// GetRLWEParametersLiteral returns the [rlwe.ParametersLiteral] from the target [ckks.ParameterLiteral].
func (p ParametersLiteral) GetRLWEParametersLiteral() rlwe.ParametersLiteral {
return rlwe.ParametersLiteral{
LogN: p.LogN,
@@ -75,19 +75,19 @@ func (p ParametersLiteral) GetRLWEParametersLiteral() rlwe.ParametersLiteral {
}
// Parameters represents a parameter set for the CKKS cryptosystem. Its fields are private and
// immutable. See ParametersLiteral for user-specified parameters.
// immutable. See [ParametersLiteral] for user-specified parameters.
type Parameters struct {
rlwe.Parameters
precisionMode PrecisionMode
}
// NewParametersFromLiteral instantiate a set of CKKS parameters from a ParametersLiteral specification.
// It returns the empty parameters Parameters{} and a non-nil error if the specified parameters are invalid.
// NewParametersFromLiteral instantiate a set of CKKS parameters from a [ParametersLiteral] specification.
// It returns the empty parameters [Parameters]{} and a non-nil error if the specified parameters are invalid.
//
// If the `LogSlots` field is left unset, its value is set to `LogN-1` for the Standard ring and to `LogN` for
// If the LogSlots field is left unset, its value is set to LogN-1 for the Standard ring and to LogN for
// the conjugate-invariant ring.
//
// See `rlwe.NewParametersFromLiteral` for default values of the other optional fields.
// See [rlwe.NewParametersFromLiteral] for default values of the other optional fields.
func NewParametersFromLiteral(pl ParametersLiteral) (Parameters, error) {
rlweParams, err := rlwe.NewParametersFromLiteral(pl.GetRLWEParametersLiteral())
if err != nil {
@@ -119,7 +119,7 @@ func (p Parameters) StandardParameters() (pckks Parameters, err error) {
return
}
// ParametersLiteral returns the ParametersLiteral of the target Parameters.
// ParametersLiteral returns the [ParametersLiteral] of the target [Parameters].
func (p Parameters) ParametersLiteral() (pLit ParametersLiteral) {
return ParametersLiteral{
LogN: p.LogN(),
@@ -169,14 +169,14 @@ func (p Parameters) LogMaxDimensions() ring.Dimensions {
}
}
// MaxSlots returns the total number of entries (`slots`) that a plaintext can store.
// MaxSlots returns the total number of entries (slots) that a plaintext can store.
// This value is obtained by multiplying all dimensions from MaxDimensions.
func (p Parameters) MaxSlots() int {
dims := p.MaxDimensions()
return dims.Rows * dims.Cols
}
// LogMaxSlots returns the total number of entries (`slots`) that a plaintext can store.
// LogMaxSlots returns the total number of entries (slots) that a plaintext can store.
// This value is obtained by summing all log dimensions from LogDimensions.
func (p Parameters) LogMaxSlots() int {
dims := p.LogMaxDimensions()
@@ -202,7 +202,7 @@ func (p Parameters) EncodingPrecision() (prec uint) {
}
// PrecisionMode returns the precision mode of the parameters.
// This value can be ckks.PREC64 or ckks.PREC128.
// This value can be [ckks.PREC64] or [ckks.PREC128].
func (p Parameters) PrecisionMode() PrecisionMode {
return p.precisionMode
}
@@ -219,6 +219,16 @@ func (p Parameters) LevelsConsumedPerRescaling() int {
}
}
// GetOptimalScalingFactor returns a scaling factor b such that Rescale(a * b) = c
func (p Parameters) GetOptimalScalingFactor(a, c rlwe.Scale, level int) (b rlwe.Scale) {
b = rlwe.NewScale(1)
Q := p.Q()
for i := 0; i < p.LevelsConsumedPerRescaling(); i++ {
b = b.Mul(rlwe.NewScale(Q[level-i]))
}
return
}
// MaxDepth returns the maximum depth enabled by the parameters,
// which is obtained as p.MaxLevel() / p.LevelsConsumedPerRescaling().
func (p Parameters) MaxDepth() int {
@@ -231,7 +241,7 @@ func (p Parameters) LogQLvl(level int) int {
return tmp.BitLen()
}
// QLvl returns the product of the moduli at the given level as a big.Int
// QLvl returns the product of the moduli at the given level as a [big.Int]
func (p Parameters) QLvl(level int) *big.Int {
tmp := bignum.NewInt(1)
for _, qi := range p.Q()[:level+1] {
@@ -284,34 +294,23 @@ func (p Parameters) GaloisElementForComplexConjugation() uint64 {
}
// GaloisElementsForInnerSum returns the list of Galois elements necessary to apply the method
// `InnerSum` operation with parameters `batch` and `n`.
// `InnerSum` operation with parameters batch and n.
func (p Parameters) GaloisElementsForInnerSum(batch, n int) []uint64 {
return rlwe.GaloisElementsForInnerSum(p, batch, n)
}
// GaloisElementsForReplicate returns the list of Galois elements necessary to perform the
// `Replicate` operation with parameters `batch` and `n`.
// `Replicate` operation with parameters batch and n.
func (p Parameters) GaloisElementsForReplicate(batch, n int) []uint64 {
return rlwe.GaloisElementsForReplicate(p, batch, n)
}
// GaloisElementsForTrace returns the list of Galois elements requored for the for the `Trace` operation.
// GaloisElementsForTrace returns the list of Galois elements requored for the for the Trace operation.
// Trace maps X -> sum((-1)^i * X^{i*n+1}) for 2^{LogN} <= i < N.
func (p Parameters) GaloisElementsForTrace(logN int) []uint64 {
return rlwe.GaloisElementsForTrace(p, logN)
}
// GaloisElementsForExpand returns the list of Galois elements required
// to perform the `Expand` operation with parameter `logN`.
func (p Parameters) GaloisElementsForExpand(logN int) []uint64 {
return rlwe.GaloisElementsForExpand(p, logN)
}
// GaloisElementsForPack returns the list of Galois elements required to perform the `Pack` operation.
func (p Parameters) GaloisElementsForPack(logN int) []uint64 {
return rlwe.GaloisElementsForPack(p, logN)
}
// Equal compares two sets of parameters for equality.
func (p Parameters) Equal(other *Parameters) bool {
return p.Parameters.Equal(&other.Parameters) && p.precisionMode == other.precisionMode
@@ -328,12 +327,12 @@ func (p *Parameters) UnmarshalBinary(data []byte) (err error) {
return p.UnmarshalJSON(data)
}
// MarshalJSON returns a JSON representation of this parameter set. See `Marshal` from the `encoding/json` package.
// MarshalJSON returns a JSON representation of this parameter set. See Marshal from the [encoding/json] package.
func (p Parameters) MarshalJSON() ([]byte, error) {
return json.Marshal(p.ParametersLiteral())
}
// UnmarshalJSON reads a JSON representation of a parameter set into the receiver Parameter. See `Unmarshal` from the `encoding/json` package.
// UnmarshalJSON reads a JSON representation of a parameter set into the receiver Parameter. See Unmarshal from the [encoding/json] package.
func (p *Parameters) UnmarshalJSON(data []byte) (err error) {
var params ParametersLiteral
if err = json.Unmarshal(data, &params); err != nil {