diff --git a/bgv/encoder.go b/bgv/encoder.go index 0fcae698..d6f45582 100644 --- a/bgv/encoder.go +++ b/bgv/encoder.go @@ -6,6 +6,7 @@ import ( "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/utils" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // GaloisGen is an integer of order N=2^d modulo M=2N and that spans Z_M with the integer -1. @@ -92,7 +93,7 @@ func NewEncoder(params Parameters) Encoder { tInvModQ := make([]*big.Int, ringQ.ModuliChainLength()) for i := range moduli { - tInvModQ[i] = ring.NewUint(T) + tInvModQ[i] = bignum.NewInt(T) tInvModQ[i].ModInverse(tInvModQ[i], ringQ.ModulusAtLevel[i]) } diff --git a/bgv/evaluator.go b/bgv/evaluator.go index f9460f7b..618ac86f 100644 --- a/bgv/evaluator.go +++ b/bgv/evaluator.go @@ -9,6 +9,7 @@ import ( "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/rlwe/ringqp" "github.com/tuneinsight/lattigo/v4/utils" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // Evaluator is an interface implementing the public methods of the eval. @@ -109,7 +110,7 @@ func newEvaluatorPrecomp(params Parameters) *evaluatorBase { tInvModQ := make([]*big.Int, ringQ.ModuliChainLength()) for i := range tInvModQ { - tInvModQ[i] = ring.NewUint(t) + tInvModQ[i] = bignum.NewInt(t) tInvModQ[i].ModInverse(tInvModQ[i], ringQ.ModulusAtLevel[i]) } @@ -275,8 +276,8 @@ func (eval *evaluator) Add(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciph _, level := eval.CheckBinary(op0, op1, op2, utils.Max(op0.Degree(), op1.Degree())) - if op0.Scale.Cmp(op1.GetScale()) == 0 { - eval.evaluateInPlace(level, op0.El(), op1.El(), op2.El(), ringQ.AtLevel(level).Add) + if op0.Scale.Cmp(op1.GetMetaData().Scale) == 0 { + eval.evaluateInPlace(level, op0, op1.El(), op2, ringQ.AtLevel(level).Add) } else { eval.matchScaleThenEvaluateInPlace(level, op0.El(), op1.El(), op2.El(), ringQ.AtLevel(level).MulScalarThenAdd) } @@ -336,8 +337,8 @@ func (eval *evaluator) Sub(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciph ringQ := eval.params.RingQ() - if op0.Scale.Cmp(op1.GetScale()) == 0 { - eval.evaluateInPlace(level, op0.El(), op1.El(), op2.El(), ringQ.AtLevel(level).Sub) + if op0.Scale.Cmp(op1.GetMetaData().Scale) == 0 { + eval.evaluateInPlace(level, op0, op1.El(), op2, ringQ.AtLevel(level).Sub) } else { eval.matchScaleThenEvaluateInPlace(level, op0.El(), op1.El(), op2.El(), ringQ.AtLevel(level).MulScalarThenSub) } @@ -504,7 +505,7 @@ func (eval *evaluator) tensorStandard(op0 *rlwe.Ciphertext, op1 *rlwe.OperandQ, } op2.MetaData = op0.MetaData - op2.Scale = op0.Scale.Mul(op1.GetScale()) + op2.Scale = op0.Scale.Mul(op1.GetMetaData().Scale) ringQ := eval.params.RingQ().AtLevel(level) @@ -898,7 +899,7 @@ func (eval *evaluator) mulRelinThenAdd(op0 *rlwe.Ciphertext, op1 rlwe.Operand, r tmp0, tmp1 := op0.El(), op1.El() var r0 uint64 = 1 - if targetScale := ring.BRed(op0.Scale.Uint64(), op1.GetScale().Uint64(), sT.Modulus, sT.BRedConstant); op2.Scale.Cmp(eval.params.NewScale(targetScale)) != 0 { + if targetScale := ring.BRed(op0.Scale.Uint64(), op1.GetMetaData().Scale.Uint64(), sT.Modulus, sT.BRedConstant); op2.Scale.Cmp(eval.params.NewScale(targetScale)) != 0 { var r1 uint64 r0, r1, _ = eval.matchScalesBinary(targetScale, op2.Scale.Uint64()) @@ -960,7 +961,7 @@ func (eval *evaluator) mulRelinThenAdd(op0 *rlwe.Ciphertext, op1 rlwe.Operand, r ringQ.MulScalar(c00, eval.params.T(), c00) var r0 = uint64(1) - if targetScale := ring.BRed(op0.Scale.Uint64(), op1.GetScale().Uint64(), sT.Modulus, sT.BRedConstant); op2.Scale.Cmp(eval.params.NewScale(targetScale)) != 0 { + if targetScale := ring.BRed(op0.Scale.Uint64(), op1.GetMetaData().Scale.Uint64(), sT.Modulus, sT.BRedConstant); op2.Scale.Cmp(eval.params.NewScale(targetScale)) != 0 { var r1 uint64 r0, r1, _ = eval.matchScalesBinary(targetScale, op2.Scale.Uint64()) diff --git a/ckks/advanced/evaluator.go b/ckks/advanced/evaluator.go index 42685bb1..2631b1ab 100644 --- a/ckks/advanced/evaluator.go +++ b/ckks/advanced/evaluator.go @@ -2,12 +2,13 @@ package advanced import ( - "math" + "math/big" "github.com/tuneinsight/lattigo/v4/ckks" "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/rlwe/ringqp" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // Evaluator is an interface embedding the ckks.Evaluator interface with @@ -18,48 +19,87 @@ type Evaluator interface { // === Original ckks.Evaluator methods === // ======================================= - Add(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - AddNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) - Sub(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - SubNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) - Neg(ctIn *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) - NegNew(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) - AddConstNew(ctIn *rlwe.Ciphertext, constant interface{}) (ctOut *rlwe.Ciphertext) - AddConst(ctIn *rlwe.Ciphertext, constant interface{}, ctOut *rlwe.Ciphertext) - MultByConstNew(ctIn *rlwe.Ciphertext, constant interface{}) (ctOut *rlwe.Ciphertext) - MultByConst(ctIn *rlwe.Ciphertext, constant interface{}, ctOut *rlwe.Ciphertext) - MultByConstThenAdd(ctIn *rlwe.Ciphertext, constant interface{}, ctOut *rlwe.Ciphertext) - ConjugateNew(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) - Conjugate(ctIn *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) - Mul(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - MulNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) - MulRelin(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - MulRelinNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) - RotateNew(ctIn *rlwe.Ciphertext, k int) (ctOut *rlwe.Ciphertext) - Rotate(ctIn *rlwe.Ciphertext, k int, ctOut *rlwe.Ciphertext) - RotateHoistedNew(ctIn *rlwe.Ciphertext, rotations []int) (ctOut map[int]*rlwe.Ciphertext) - RotateHoisted(ctIn *rlwe.Ciphertext, rotations []int, ctOut map[int]*rlwe.Ciphertext) - EvaluatePoly(input interface{}, pol *ckks.Polynomial, targetscale rlwe.Scale) (ctOut *rlwe.Ciphertext, err error) - EvaluatePolyVector(input interface{}, pols []*ckks.Polynomial, encoder ckks.Encoder, slotIndex map[int][]int, targetscale rlwe.Scale) (ctOut *rlwe.Ciphertext, err error) - InverseNew(ctIn *rlwe.Ciphertext, steps int) (ctOut *rlwe.Ciphertext, err error) - LinearTransformNew(ctIn *rlwe.Ciphertext, linearTransform interface{}) (ctOut []*rlwe.Ciphertext) - LinearTransform(ctIn *rlwe.Ciphertext, linearTransform interface{}, ctOut []*rlwe.Ciphertext) - MultiplyByDiagMatrix(ctIn *rlwe.Ciphertext, matrix ckks.LinearTransform, c2DecompQP []ringqp.Poly, ctOut *rlwe.Ciphertext) - MultiplyByDiagMatrixBSGS(ctIn *rlwe.Ciphertext, matrix ckks.LinearTransform, c2DecompQP []ringqp.Poly, ctOut *rlwe.Ciphertext) - InnerSum(ctIn *rlwe.Ciphertext, batch, n int, ctOut *rlwe.Ciphertext) - Replicate(ctIn *rlwe.Ciphertext, batch, n int, ctOut *rlwe.Ciphertext) - TraceNew(ctIn *rlwe.Ciphertext, logSlots int) *rlwe.Ciphertext - Trace(ctIn *rlwe.Ciphertext, logSlots int, ctOut *rlwe.Ciphertext) - ApplyEvaluationKeyNew(ctIn *rlwe.Ciphertext, evk *rlwe.EvaluationKey) (ctOut *rlwe.Ciphertext) - ApplyEvaluationKey(ctIn *rlwe.Ciphertext, evk *rlwe.EvaluationKey, ctOut *rlwe.Ciphertext) - RelinearizeNew(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) - Relinearize(ctIn *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) - ScaleUpNew(ctIn *rlwe.Ciphertext, scale rlwe.Scale) (ctOut *rlwe.Ciphertext) - ScaleUp(ctIn *rlwe.Ciphertext, scale rlwe.Scale, ctOut *rlwe.Ciphertext) - SetScale(ctIn *rlwe.Ciphertext, scale rlwe.Scale) - Rescale(ctIn *rlwe.Ciphertext, minscale rlwe.Scale, ctOut *rlwe.Ciphertext) (err error) - DropLevelNew(ctIn *rlwe.Ciphertext, levels int) (ctOut *rlwe.Ciphertext) - DropLevel(ctIn *rlwe.Ciphertext, levels int) + // ======================== + // === Basic Arithmetic === + // ======================== + + // Addition + Add(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciphertext) + AddNew(op0 *rlwe.Ciphertext, op1 interface{}) (op2 *rlwe.Ciphertext) + + // Subtraction + Sub(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciphertext) + SubNew(op0 *rlwe.Ciphertext, op1 interface{}) (op2 *rlwe.Ciphertext) + + // Complex Conjugation + ConjugateNew(op0 *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) + Conjugate(op0 *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) + + // Multiplication + Mul(op0 *rlwe.Ciphertext, op1 interface{}, ctOut *rlwe.Ciphertext) + MulNew(op0 *rlwe.Ciphertext, op1 interface{}) (ctOut *rlwe.Ciphertext) + MulRelin(op0 *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) + MulRelinNew(op0 *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) + + MulThenAdd(op0 *rlwe.Ciphertext, op1 interface{}, ctOut *rlwe.Ciphertext) + MulRelinThenAdd(op0 *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) + + // Slot Rotations + RotateNew(op0 *rlwe.Ciphertext, k int) (ctOut *rlwe.Ciphertext) + Rotate(op0 *rlwe.Ciphertext, k int, ctOut *rlwe.Ciphertext) + RotateHoistedNew(op0 *rlwe.Ciphertext, rotations []int) (ctOut map[int]*rlwe.Ciphertext) + RotateHoisted(op0 *rlwe.Ciphertext, rotations []int, ctOut map[int]*rlwe.Ciphertext) + RotateHoistedLazyNew(level int, rotations []int, ct *rlwe.Ciphertext, c2DecompQP []ringqp.Poly) (cOut map[int]rlwe.CiphertextQP) + + // =========================== + // === Advanced Arithmetic === + // =========================== + + // Polynomial evaluation + EvaluatePoly(input interface{}, pol *bignum.Polynomial, targetScale rlwe.Scale) (ctOut *rlwe.Ciphertext, err error) + EvaluatePolyVector(input interface{}, pols []*bignum.Polynomial, encoder *ckks.Encoder, slotIndex map[int][]int, targetScale rlwe.Scale) (ctOut *rlwe.Ciphertext, err error) + + // GoldschmidtDivision + GoldschmidtDivisionNew(ct *rlwe.Ciphertext, minValue, log2Targetprecision float64, btp rlwe.Bootstrapper) (ctInv *rlwe.Ciphertext, err error) + + // Linear Transformations + LinearTransformNew(op0 *rlwe.Ciphertext, linearTransform interface{}) (ctOut []*rlwe.Ciphertext) + LinearTransform(op0 *rlwe.Ciphertext, linearTransform interface{}, ctOut []*rlwe.Ciphertext) + MultiplyByDiagMatrix(op0 *rlwe.Ciphertext, matrix ckks.LinearTransform, c2DecompQP []ringqp.Poly, ctOut *rlwe.Ciphertext) + MultiplyByDiagMatrixBSGS(op0 *rlwe.Ciphertext, matrix ckks.LinearTransform, c2DecompQP []ringqp.Poly, ctOut *rlwe.Ciphertext) + + // Inner sum + InnerSum(op0 *rlwe.Ciphertext, batch, n int, ctOut *rlwe.Ciphertext) + Average(op0 *rlwe.Ciphertext, batch int, ctOut *rlwe.Ciphertext) + + // Replication (inverse of Inner sum) + Replicate(op0 *rlwe.Ciphertext, batch, n int, ctOut *rlwe.Ciphertext) + + // Trace + Trace(op0 *rlwe.Ciphertext, logSlots int, ctOut *rlwe.Ciphertext) + TraceNew(op0 *rlwe.Ciphertext, logSlots int) (ctOut *rlwe.Ciphertext) + + // ============================= + // === Ciphertext Management === + // ============================= + + // Generic EvaluationKeys + ApplyEvaluationKeyNew(op0 *rlwe.Ciphertext, evk *rlwe.EvaluationKey) (ctOut *rlwe.Ciphertext) + ApplyEvaluationKey(op0 *rlwe.Ciphertext, evk *rlwe.EvaluationKey, ctOut *rlwe.Ciphertext) + + // Degree Management + RelinearizeNew(op0 *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) + Relinearize(op0 *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) + + // Scale Management + ScaleUpNew(op0 *rlwe.Ciphertext, scale rlwe.Scale) (ctOut *rlwe.Ciphertext) + ScaleUp(op0 *rlwe.Ciphertext, scale rlwe.Scale, ctOut *rlwe.Ciphertext) + SetScale(op0 *rlwe.Ciphertext, scale rlwe.Scale) + Rescale(op0 *rlwe.Ciphertext, minScale rlwe.Scale, ctOut *rlwe.Ciphertext) (err error) + + // Level Management + DropLevelNew(op0 *rlwe.Ciphertext, levels int) (ctOut *rlwe.Ciphertext) + DropLevel(op0 *rlwe.Ciphertext, levels int) // ====================================== // === advanced.Evaluator new methods === @@ -75,12 +115,13 @@ type Evaluator interface { // === original ckks.Evaluator redefined methods === // ================================================= - Parameters() ckks.Parameters + CheckBinary(op0, op1, opOut rlwe.Operand, opOutMinDegree int) (degree, level int) + CheckUnary(op0, opOut rlwe.Operand) (degree, level int) GetRLWEEvaluator() *rlwe.Evaluator BuffQ() [3]*ring.Poly BuffCt() *rlwe.Ciphertext ShallowCopy() Evaluator - WithKey(rlwe.EvaluationKeySetInterface) Evaluator + WithKey(evk rlwe.EvaluationKeySetInterface) Evaluator } type evaluator struct { @@ -118,7 +159,7 @@ func (eval *evaluator) WithKey(evk rlwe.EvaluationKeySetInterface) Evaluator { func (eval *evaluator) CoeffsToSlotsNew(ctIn *rlwe.Ciphertext, ctsMatrices HomomorphicDFTMatrix) (ctReal, ctImag *rlwe.Ciphertext) { ctReal = ckks.NewCiphertext(eval.params, 1, ctsMatrices.LevelStart) - if eval.params.LogSlots() == eval.params.LogN()-1 { + if ctsMatrices.LogSlots == eval.params.MaxLogSlots() { ctImag = ckks.NewCiphertext(eval.params, 1, ctsMatrices.LevelStart) } @@ -150,18 +191,19 @@ func (eval *evaluator) CoeffsToSlots(ctIn *rlwe.Ciphertext, ctsMatrices Homomorp // Imag part eval.Sub(zV, ctReal, tmp) - eval.MultByConst(tmp, -1i, tmp) + eval.Mul(tmp, -1i, tmp) // Real part eval.Add(ctReal, zV, ctReal) // If repacking, then ct0 and ct1 right n/2 slots are zero. - if eval.params.LogSlots() < eval.params.LogN()-1 { - eval.Rotate(tmp, eval.params.Slots(), tmp) + if ctsMatrices.LogSlots < eval.params.MaxLogSlots() { + eval.Rotate(tmp, ctIn.Slots(), tmp) eval.Add(ctReal, tmp, ctReal) } zV = nil + } else { eval.dft(ctIn, ctsMatrices.Matrices, ctReal) } @@ -190,7 +232,7 @@ func (eval *evaluator) SlotsToCoeffsNew(ctReal, ctImag *rlwe.Ciphertext, stcMatr func (eval *evaluator) SlotsToCoeffs(ctReal, ctImag *rlwe.Ciphertext, stcMatrices HomomorphicDFTMatrix, ctOut *rlwe.Ciphertext) { // If full packing, the repacking can be done directly using ct0 and ct1. if ctImag != nil { - eval.MultByConst(ctImag, 1i, ctOut) + eval.Mul(ctImag, 1i, ctOut) eval.Add(ctOut, ctReal, ctOut) eval.dft(ctOut, stcMatrices.Matrices, ctOut) } else { @@ -200,6 +242,8 @@ func (eval *evaluator) SlotsToCoeffs(ctReal, ctImag *rlwe.Ciphertext, stcMatrice func (eval *evaluator) dft(ctIn *rlwe.Ciphertext, plainVectors []ckks.LinearTransform, ctOut *rlwe.Ciphertext) { + inputLogSlots := ctIn.LogSlots + // Sequentially multiplies w with the provided dft matrices. scale := ctIn.Scale var in, out *rlwe.Ciphertext @@ -208,12 +252,18 @@ func (eval *evaluator) dft(ctIn *rlwe.Ciphertext, plainVectors []ckks.LinearTran if i == 0 { in, out = ctIn, ctOut } + eval.LinearTransform(in, plainVector, []*rlwe.Ciphertext{out}) if err := eval.Rescale(out, scale, out); err != nil { panic(err) } } + + // Encoding matrices are a special case of `fractal` linear transform + // that doesn't change the underlying plaintext polynomial Y = X^{N/n} + // of the input ciphertext. + ctOut.LogSlots = inputLogSlots } // EvalModNew applies a homomorphic mod Q on a vector scaled by Delta, scaled down to mod 1 : @@ -252,14 +302,19 @@ func (eval *evaluator) EvalModNew(ct *rlwe.Ciphertext, evalModPoly EvalModPoly) // formula such that after it it has the scale it had before the polynomial // evaluation - targetScale := ct.Scale.Float64() + targetScale := ct.Scale for i := 0; i < evalModPoly.doubleAngle; i++ { - targetScale = math.Sqrt(targetScale * eval.params.QiFloat64(evalModPoly.levelStart-evalModPoly.sinePoly.Depth()-evalModPoly.doubleAngle+i+1)) + qi := eval.params.Q()[evalModPoly.levelStart-evalModPoly.sinePoly.Depth()-evalModPoly.doubleAngle+i+1] + targetScale = targetScale.Mul(rlwe.NewScale(qi)) + targetScale.Value.Sqrt(&targetScale.Value) } // Division by 1/2^r and change of variable for the Chebyshev evaluation if evalModPoly.sineType == CosDiscrete || evalModPoly.sineType == CosContinuous { - eval.AddConst(ct, -0.5/(evalModPoly.scFac*(evalModPoly.sinePoly.B-evalModPoly.sinePoly.A)), ct) + offset := new(big.Float).Sub(evalModPoly.sinePoly.B, evalModPoly.sinePoly.A) + offset.Mul(offset, new(big.Float).SetFloat64(evalModPoly.scFac)) + offset.Quo(new(big.Float).SetFloat64(-0.5), offset) + eval.Add(ct, offset, ct) } // Chebyshev evaluation @@ -273,7 +328,7 @@ func (eval *evaluator) EvalModNew(ct *rlwe.Ciphertext, evalModPoly EvalModPoly) sqrt2pi *= sqrt2pi eval.MulRelin(ct, ct, ct) eval.Add(ct, ct, ct) - eval.AddConst(ct, -sqrt2pi, ct) + eval.Add(ct, -sqrt2pi, ct) if err := eval.Rescale(ct, rlwe.NewScale(targetScale), ct); err != nil { panic(err) } diff --git a/ckks/advanced/homomorphic_DFT.go b/ckks/advanced/homomorphic_DFT.go index fb7e356b..8aa1e922 100644 --- a/ckks/advanced/homomorphic_DFT.go +++ b/ckks/advanced/homomorphic_DFT.go @@ -2,10 +2,12 @@ package advanced import ( "math" + "math/big" "github.com/tuneinsight/lattigo/v4/ckks" "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/utils" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // DFTType is a type used to distinguish different linear transformations. @@ -42,7 +44,6 @@ type HomomorphicDFTMatrix struct { type HomomorphicDFTMatrixLiteral struct { // Mandatory Type DFTType - LogN int LogSlots int LevelStart int Levels []int @@ -72,7 +73,7 @@ func (d *HomomorphicDFTMatrixLiteral) GaloisElements(params ckks.Parameters) (ga rotations := []int{} logSlots := d.LogSlots - logN := d.LogN + logN := params.LogN() slots := 1 << logSlots dslots := slots if logSlots < logN-1 && d.RepackImag2Real { @@ -82,7 +83,7 @@ func (d *HomomorphicDFTMatrixLiteral) GaloisElements(params ckks.Parameters) (ga } } - indexCtS := d.computeBootstrappingDFTIndexMap() + indexCtS := d.computeBootstrappingDFTIndexMap(logN) // Coeffs to Slots rotations for i, pVec := range indexCtS { @@ -94,25 +95,32 @@ func (d *HomomorphicDFTMatrixLiteral) GaloisElements(params ckks.Parameters) (ga } // NewHomomorphicDFTMatrixFromLiteral generates the factorized DFT/IDFT matrices for the homomorphic encoding/decoding. -func NewHomomorphicDFTMatrixFromLiteral(d HomomorphicDFTMatrixLiteral, encoder ckks.Encoder) HomomorphicDFTMatrix { - - logSlots := d.LogSlots - logdSlots := logSlots - if logdSlots < d.LogN-1 && d.RepackImag2Real { - logdSlots++ - } +func NewHomomorphicDFTMatrixFromLiteral(d HomomorphicDFTMatrixLiteral, encoder *ckks.Encoder) HomomorphicDFTMatrix { params := encoder.Parameters() - // DFT vectors + logSlots := d.LogSlots + logdSlots := logSlots + if logdSlots < params.MaxLogSlots() && d.RepackImag2Real { + logdSlots++ + } + + // CoeffsToSlots vectors matrices := []ckks.LinearTransform{} - pVecDFT := d.GenMatrices() + pVecDFT := d.GenMatrices(params.LogN()) level := d.LevelStart var idx int for i := range d.Levels { - scale := rlwe.NewScale(math.Pow(params.QiFloat64(level), 1.0/float64(d.Levels[i]))) + scale := rlwe.NewScale(params.Q()[level]) + + if d.Levels[i] > 1 { + y := new(big.Float).SetPrec(scale.Value.Prec()).SetInt64(1) + y.Quo(y, new(big.Float).SetPrec(scale.Value.Prec()).SetInt64(int64(d.Levels[i]))) + + scale.Value = *bignum.Pow(&scale.Value, y) + } for j := 0; j < d.Levels[i]; j++ { matrices = append(matrices, ckks.GenLinearTransformBSGS(encoder, pVecDFT[idx], level, scale, d.LogBSGSRatio, logdSlots)) @@ -247,7 +255,7 @@ func addMatrixRotToList(pVec map[int]bool, rotations []int, N1, slots int, repac index = (j / N1) * N1 if repack { - // Sparse repacking, occurring during the first IDFT matrix. + // Sparse repacking, occurring during the first DFT matrix of the CoeffsToSlots. index &= (2*slots - 1) } else { // Other cases @@ -269,9 +277,8 @@ func addMatrixRotToList(pVec map[int]bool, rotations []int, N1, slots int, repac return rotations } -func (d *HomomorphicDFTMatrixLiteral) computeBootstrappingDFTIndexMap() (rotationMap []map[int]bool) { +func (d *HomomorphicDFTMatrixLiteral) computeBootstrappingDFTIndexMap(logN int) (rotationMap []map[int]bool) { - logN := d.LogN logSlots := d.LogSlots ltType := d.Type repacki2r := d.RepackImag2Real @@ -308,10 +315,10 @@ func (d *HomomorphicDFTMatrixLiteral) computeBootstrappingDFTIndexMap() (rotatio if logSlots < logN-1 && ltType == Decode && i == 0 && repacki2r { - // Special initial matrix for the repacking before DFT + // Special initial matrix for the repacking before Decode rotationMap[i] = genWfftRepackIndexMap(logSlots, level) - // Merges this special initial matrix with the first layer of DFT + // Merges this special initial matrix with the first layer of Decode DFT rotationMap[i] = nextLevelfftIndexMap(rotationMap[i], logSlots, 2<>1, 0; i < params.Slots(); i, jdx, idx = i+1, jdx+gap, idx+gap { + gap := params.N() / (2 * slots) + for i, jdx, idx := 0, params.N()>>1, 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap { valuesFloat[idx] = real(values[i]) valuesFloat[jdx] = imag(values[i]) } // Encodes coefficient-wise and encrypts the test vector - plaintext := ckks.NewPlaintext(params, params.MaxLevel()) - encoder.EncodeCoeffs(valuesFloat, plaintext) - ciphertext := encryptor.EncryptNew(plaintext) + pt := ckks.NewPlaintext(params, params.MaxLevel()) + pt.LogSlots = LogSlots + + pt.EncodingDomain = rlwe.CoefficientsDomain + encoder.Encode(valuesFloat, pt) + pt.EncodingDomain = rlwe.SlotsDomain + + ct := encryptor.EncryptNew(pt) // Applies the homomorphic DFT - ct0, ct1 := eval.CoeffsToSlotsNew(ciphertext, CoeffsToSlotMatrices) + ct0, ct1 := eval.CoeffsToSlotsNew(ct, CoeffsToSlotMatrices) // Checks against the original coefficients if sparse { - coeffsReal := encoder.DecodeCoeffs(decryptor.DecryptNew(ct0)) + ct0.EncodingDomain = rlwe.CoefficientsDomain + + coeffsReal := make([]float64, params.N()) + + encoder.Decode(decryptor.DecryptNew(ct0), coeffsReal) // Plaintext circuit - vec := make([]complex128, 2*params.Slots()) + vec := make([]complex128, 2*slots) // Embed real vector into the complex vector (trivial) - for i, j := 0, params.Slots(); i < params.Slots(); i, j = i+1, j+1 { + for i, j := 0, slots; i < slots; i, j = i+1, j+1 { vec[i] = complex(valuesReal[i], 0) vec[j] = complex(valuesImag[i], 0) } // IFFT - encoder.IFFT(vec, params.LogSlots()+1) + encoder.IFFT(vec, LogSlots+1) // Extract complex vector into real vector vecReal := make([]float64, params.N()) - for i, idx, jdx := 0, 0, params.N()>>1; i < 2*params.Slots(); i, idx, jdx = i+1, idx+gap/2, jdx+gap/2 { + for i, idx, jdx := 0, 0, params.N()>>1; i < 2*slots; i, idx, jdx = i+1, idx+gap/2, jdx+gap/2 { vecReal[idx] = real(vec[i]) vecReal[jdx] = imag(vec[i]) } // Compares - verifyTestVectors(params, ecd2N, nil, vecReal, coeffsReal, params.LogSlots(), t) + verifyTestVectors(params, ecd2N, nil, vecReal, coeffsReal, t) } else { - coeffsReal := encoder.DecodeCoeffs(decryptor.DecryptNew(ct0)) - coeffsImag := encoder.DecodeCoeffs(decryptor.DecryptNew(ct1)) - vec0 := make([]complex128, params.Slots()) - vec1 := make([]complex128, params.Slots()) + ct0.EncodingDomain = rlwe.CoefficientsDomain + ct1.EncodingDomain = rlwe.CoefficientsDomain + + coeffsReal := make([]float64, params.N()) + coeffsImag := make([]float64, params.N()) + + encoder.Decode(decryptor.DecryptNew(ct0), coeffsReal) + encoder.Decode(decryptor.DecryptNew(ct1), coeffsImag) + + vec0 := make([]complex128, slots) + vec1 := make([]complex128, slots) // Embed real vector into the complex vector (trivial) - for i := 0; i < params.Slots(); i++ { + for i := 0; i < slots; i++ { vec0[i] = complex(valuesReal[i], 0) vec1[i] = complex(valuesImag[i], 0) } // IFFT - encoder.IFFT(vec0, params.LogSlots()) - encoder.IFFT(vec1, params.LogSlots()) + encoder.IFFT(vec0, LogSlots) + encoder.IFFT(vec1, LogSlots) // Extract complex vectors into real vectors vecReal := make([]float64, params.N()) vecImag := make([]float64, params.N()) - for i, j := 0, params.Slots(); i < params.Slots(); i, j = i+1, j+1 { + for i, j := 0, slots; i < slots; i, j = i+1, j+1 { vecReal[i], vecReal[j] = real(vec0[i]), imag(vec0[i]) vecImag[i], vecImag[j] = real(vec1[i]), imag(vec1[i]) } - verifyTestVectors(params, ecd2N, nil, vecReal, coeffsReal, params.LogSlots(), t) - verifyTestVectors(params, ecd2N, nil, vecImag, coeffsImag, params.LogSlots(), t) + verifyTestVectors(params, ecd2N, nil, vecReal, coeffsReal, t) + verifyTestVectors(params, ecd2N, nil, vecImag, coeffsImag, t) } }) } -func testSlotsToCoeffs(params ckks.Parameters, t *testing.T) { +func testSlotsToCoeffs(params ckks.Parameters, LogSlots int, t *testing.T) { - var sparse bool = params.LogSlots() < params.LogN()-1 + slots := 1 << LogSlots + + var sparse bool = LogSlots < params.LogN()-1 packing := "FullPacking" - if params.LogSlots() < params.LogN()-1 { + if LogSlots < params.LogN()-1 { packing = "SparsePacking" } @@ -310,9 +316,8 @@ func testSlotsToCoeffs(params ckks.Parameters, t *testing.T) { } SlotsToCoeffsParametersLiteral := HomomorphicDFTMatrixLiteral{ + LogSlots: LogSlots, Type: Decode, - LogN: params.LogN(), - LogSlots: params.LogSlots(), RepackImag2Real: true, LevelStart: params.MaxLevel(), Levels: Levels, @@ -345,33 +350,33 @@ func testSlotsToCoeffs(params ckks.Parameters, t *testing.T) { eval := NewEvaluator(params, evk) // Generates the n first slots of the test vector (real part to encode) - valuesReal := make([]complex128, params.Slots()) + valuesReal := make([]complex128, slots) for i := range valuesReal { valuesReal[i] = complex(sampling.RandFloat64(-1, 1), 0) } // Generates the n first slots of the test vector (imaginary part to encode) - valuesImag := make([]complex128, params.Slots()) + valuesImag := make([]complex128, slots) for i := range valuesImag { valuesImag[i] = complex(sampling.RandFloat64(-1, 1), 0) } // If sparse, there there is the space to store both vectors in one - logSlots := params.LogSlots() if sparse { for i := range valuesReal { valuesReal[i] = complex(real(valuesReal[i]), real(valuesImag[i])) } - logSlots++ + LogSlots++ } // Encodes and encrypts the test vectors plaintext := ckks.NewPlaintext(params, params.MaxLevel()) - encoder.Encode(valuesReal, plaintext, logSlots) + plaintext.LogSlots = LogSlots + encoder.Encode(valuesReal, plaintext) ct0 := encryptor.EncryptNew(plaintext) var ct1 *rlwe.Ciphertext if !sparse { - encoder.Encode(valuesImag, plaintext, logSlots) + encoder.Encode(valuesImag, plaintext) ct1 = encryptor.EncryptNew(plaintext) } @@ -379,13 +384,16 @@ func testSlotsToCoeffs(params ckks.Parameters, t *testing.T) { res := eval.SlotsToCoeffsNew(ct0, ct1, SlotsToCoeffsMatrix) // Decrypt and decode in the coefficient domain - coeffsFloat := encoder.DecodeCoeffs(decryptor.DecryptNew(res)) + coeffsFloat := make([]float64, params.N()) + res.EncodingDomain = rlwe.CoefficientsDomain + + encoder.Decode(decryptor.DecryptNew(res), coeffsFloat) // Extracts the coefficients and construct the complex vector // This is simply coefficient ordering - valuesTest := make([]complex128, params.Slots()) - gap := params.N() / (2 * params.Slots()) - for i, idx := 0, 0; i < params.Slots(); i, idx = i+1, idx+gap { + valuesTest := make([]complex128, slots) + gap := params.N() / (2 * slots) + for i, idx := 0, 0; i < slots; i, idx = i+1, idx+gap { valuesTest[i] = complex(coeffsFloat[idx], coeffsFloat[idx+(params.N()>>1)]) } @@ -400,16 +408,21 @@ func testSlotsToCoeffs(params ckks.Parameters, t *testing.T) { // Result is bit-reversed, so applies the bit-reverse permutation on the reference vector utils.BitReverseInPlaceSlice(valuesReal, params.Slots()) - verifyTestVectors(params, encoder, decryptor, valuesReal, valuesTest, params.LogSlots(), t) + verifyTestVectors(params, encoder, decryptor, valuesReal, valuesTest, t) }) } -func verifyTestVectors(params ckks.Parameters, encoder ckks.Encoder, decryptor rlwe.Decryptor, valuesWant, element interface{}, logSlots int, t *testing.T) { +func verifyTestVectors(params ckks.Parameters, encoder *ckks.Encoder, decryptor rlwe.Decryptor, valuesWant, element interface{}, t *testing.T) { + + precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, element, nil, false) - precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, element, logSlots, nil) if *printPrecisionStats { t.Log(precStats.String()) } - require.GreaterOrEqual(t, precStats.MeanPrecision.Real, minPrec) - require.GreaterOrEqual(t, precStats.MeanPrecision.Imag, minPrec) + + rf64, _ := precStats.MeanPrecision.Real.Float64() + if64, _ := precStats.MeanPrecision.Imag.Float64() + + require.GreaterOrEqual(t, rf64, minPrec) + require.GreaterOrEqual(t, if64, minPrec) } diff --git a/ckks/advanced/homomorphic_mod.go b/ckks/advanced/homomorphic_mod.go index 782a7922..6d78c75a 100644 --- a/ckks/advanced/homomorphic_mod.go +++ b/ckks/advanced/homomorphic_mod.go @@ -1,15 +1,15 @@ package advanced import ( - "fmt" "math" + "math/big" "math/bits" "math/cmplx" "github.com/tuneinsight/lattigo/v4/ckks" "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/utils" - "github.com/tuneinsight/lattigo/v4/utils/bignum/polynomial" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // SineType is the type of function used during the bootstrapping @@ -57,8 +57,8 @@ type EvalModPoly struct { qDiff float64 scFac float64 sqrt2Pi float64 - sinePoly *ckks.Polynomial - arcSinePoly *ckks.Polynomial + sinePoly *bignum.Polynomial + arcSinePoly *bignum.Polynomial k float64 } @@ -98,8 +98,8 @@ func (evp *EvalModPoly) QDiff() float64 { // homomorphically evaluates x mod Q[0] (the first prime of the moduli chain) on the ciphertext. func NewEvalModPolyFromLiteral(params ckks.Parameters, evm EvalModLiteral) EvalModPoly { - var arcSinePoly *ckks.Polynomial - var sinePoly *ckks.Polynomial + var arcSinePoly *bignum.Polynomial + var sinePoly *bignum.Polynomial var sqrt2pi float64 doubleAngle := evm.DoubleAngle @@ -126,7 +126,14 @@ func NewEvalModPolyFromLiteral(params ckks.Parameters, evm EvalModLiteral) EvalM coeffs[i] = coeffs[i-2] * complex(float64(i*i-4*i+4)/float64(i*i-i), 0) } - arcSinePoly = ckks.NewPoly(coeffs) + arcSinePoly = bignum.NewPolynomial(bignum.Monomial, coeffs, nil) + arcSinePoly.IsEven = false + + for i := range arcSinePoly.Coeffs { + if i&1 == 0 { + arcSinePoly.Coeffs[i] = nil + } + } } else { sqrt2pi = math.Pow(0.15915494309189535*qDiff, 1.0/scFac) @@ -136,26 +143,44 @@ func NewEvalModPolyFromLiteral(params ckks.Parameters, evm EvalModLiteral) EvalM case SinContinuous: sinePoly = ckks.Approximate(sin2pi2pi, -K, K, evm.SineDegree) - case CosDiscrete: + sinePoly.IsEven = false - sinePoly = new(ckks.Polynomial) - sinePoly.Coeffs = ApproximateCos(evm.K, evm.SineDegree, float64(uint(1< the bit-precision doubles after each iteration. +// The method automatically estimates how many iterations are needed to achieve the desired precision, and returns an error if the input ciphertext +// does not have enough remaining level and if no bootstrapper was given. +// Note that the desired precision will never exceed log2(ct.Scale) - logN + 1. +func (eval *evaluator) GoldschmidtDivisionNew(ct *rlwe.Ciphertext, minValue, log2Targetprecision float64, btp rlwe.Bootstrapper) (ctInv *rlwe.Ciphertext, err error) { - eval.AddConst(cbar, 1, cbar) + params := eval.params - tmp := eval.AddConstNew(cbar, 1) - opOut = tmp.CopyNew() - - for i := 1; i < steps; i++ { - - eval.MulRelin(cbar, cbar, cbar) - - if err = eval.Rescale(cbar, op.Scale, cbar); err != nil { - return - } - - tmp = eval.AddConstNew(cbar, 1) - - eval.MulRelin(tmp, opOut, tmp) - - if err = eval.Rescale(tmp, op.Scale, tmp); err != nil { - return - } - - opOut = tmp.CopyNew() + start := math.Log2(1 - minValue) + var iters int + for start+log2Targetprecision > 0.5 { + start *= 2 // Doubles the bit-precision at each iteration + iters++ } - return + if depth := iters * params.DefaultScaleModuliRatio(); btp == nil && depth > ct.Level() { + return nil, fmt.Errorf("cannot GoldschmidtDivisionNew: ct.Level()=%d < depth=%d", ct.Level(), depth) + } + + a := eval.MulNew(ct, -1) + b := a.CopyNew() + eval.Add(a, 2, a) + eval.Add(b, 1, b) + + for i := 1; i < iters; i++ { + + if btp != nil && (b.Level() == btp.MinimumInputLevel() || b.Level() == params.DefaultScaleModuliRatio()-1) { + if b, err = btp.Bootstrap(b); err != nil { + return nil, err + } + } + + if btp != nil && (a.Level() == btp.MinimumInputLevel() || a.Level() == params.DefaultScaleModuliRatio()-1) { + if a, err = btp.Bootstrap(a); err != nil { + return nil, err + } + } + + eval.MulRelin(b, b, b) + if err = eval.Rescale(b, params.DefaultScale(), b); err != nil { + return nil, err + } + + if btp != nil && (b.Level() == btp.MinimumInputLevel() || b.Level() == params.DefaultScaleModuliRatio()-1) { + if b, err = btp.Bootstrap(b); err != nil { + return nil, err + } + } + + tmp := eval.MulRelinNew(a, b) + if err = eval.Rescale(tmp, params.DefaultScale(), tmp); err != nil { + return nil, err + } + + eval.SetScale(a, tmp.Scale) + + eval.Add(a, tmp, a) + } + + return a, nil } diff --git a/ckks/bootstrapping/bootstrapper.go b/ckks/bootstrapping/bootstrapper.go index e2bfe571..a41fb660 100644 --- a/ckks/bootstrapping/bootstrapper.go +++ b/ckks/bootstrapping/bootstrapper.go @@ -159,9 +159,9 @@ func newBootstrapperBase(params ckks.Parameters, btpParams Parameters, btpKey *E bb.params = params bb.Parameters = btpParams - bb.dslots = params.Slots() - bb.logdslots = params.LogSlots() - if params.LogSlots() < params.MaxLogSlots() { + bb.logdslots = btpParams.LogSlots() + bb.dslots = 1 << bb.logdslots + if bb.dslots < params.MaxLogSlots() { bb.dslots <<= 1 bb.logdslots++ } @@ -175,13 +175,15 @@ func newBootstrapperBase(params ckks.Parameters, btpParams Parameters, btpKey *E // The second correcting factor for approximate multiplication by Q is included in the coefficients of the EvalMod polynomials qDiff := bb.evalModPoly.QDiff() + Q0 := params.Q()[0] + // Q0/|m| - bb.q0OverMessageRatio = math.Exp2(math.Round(math.Log2(params.QiFloat64(0) / bb.evalModPoly.MessageRatio()))) + bb.q0OverMessageRatio = math.Exp2(math.Round(math.Log2(float64(Q0) / bb.evalModPoly.MessageRatio()))) // If the scale used during the EvalMod step is smaller than Q0, then we cannot increase the scale during // the EvalMod step to get a free division by MessageRatio, and we need to do this division (totally or partly) // during the CoeffstoSlots step - qDiv := bb.evalModPoly.ScalingFactor().Float64() / math.Exp2(math.Round(math.Log2(params.QiFloat64(0)))) + qDiv := bb.evalModPoly.ScalingFactor().Float64() / math.Exp2(math.Round(math.Log2(float64(Q0)))) // Sets qDiv to 1 if there is enough room for the division to happen using scale manipulation. if qDiv > 1 { @@ -192,8 +194,6 @@ func newBootstrapperBase(params ckks.Parameters, btpParams Parameters, btpKey *E // CoeffsToSlots vectors // Change of variable for the evaluation of the Chebyshev polynomial + cancelling factor for the DFT and SubSum + eventual scaling factor for the double angle formula - bb.CoeffsToSlotsParameters.LogN = params.LogN() - bb.CoeffsToSlotsParameters.LogSlots = params.LogSlots() if bb.CoeffsToSlotsParameters.Scaling == 0 { bb.CoeffsToSlotsParameters.Scaling = qDiv / (K * scFac * qDiff) @@ -205,8 +205,6 @@ func newBootstrapperBase(params ckks.Parameters, btpParams Parameters, btpKey *E // SlotsToCoeffs vectors // Rescaling factor to set the final ciphertext to the desired scale - bb.SlotsToCoeffsParameters.LogN = params.LogN() - bb.SlotsToCoeffsParameters.LogSlots = params.LogSlots() if bb.SlotsToCoeffsParameters.Scaling == 0 { bb.SlotsToCoeffsParameters.Scaling = bb.params.DefaultScale().Float64() / (bb.evalModPoly.ScalingFactor().Float64() / bb.evalModPoly.MessageRatio()) diff --git a/ckks/bootstrapping/bootstrapping.go b/ckks/bootstrapping/bootstrapping.go index eb71fdc9..90e1f63e 100644 --- a/ckks/bootstrapping/bootstrapping.go +++ b/ckks/bootstrapping/bootstrapping.go @@ -49,7 +49,7 @@ func (btp *Bootstrapper) Bootstrap(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertex } // Scales the message to Q0/|m|, which is the maximum possible before ModRaise to avoid plaintext overflow. - if scale := math.Round((btp.params.QiFloat64(0) / btp.evalModPoly.MessageRatio()) / ctDiff.Scale.Float64()); scale > 1 { + if scale := math.Round((float64(btp.params.Q()[0]) / btp.evalModPoly.MessageRatio()) / ctDiff.Scale.Float64()); scale > 1 { btp.ScaleUp(ctDiff, rlwe.NewScale(scale), ctDiff) } @@ -61,14 +61,13 @@ func (btp *Bootstrapper) Bootstrap(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertex tmp := btp.SubNew(ctDiff, ctOut) // 2^d * e - btp.MultByConst(tmp, 1<<16, tmp) + btp.Mul(tmp, 1<<16, tmp) // 2^d * e + 2^(d-n) * e' tmp = btp.bootstrap(tmp) // 2^(d-n) * e + 2^(d-2n) * e' - btp.MultByConst(tmp, btp.params.QiFloat64(tmp.Level())/float64(uint64(1<<16)), tmp) - + btp.Mul(tmp, float64(btp.params.Q()[tmp.Level()])/float64(uint64(1<<16)), tmp) tmp.Scale = tmp.Scale.Mul(rlwe.NewScale(btp.params.Q()[tmp.Level()])) if err := btp.Rescale(tmp, btp.params.DefaultScale(), tmp); err != nil { @@ -93,7 +92,7 @@ func (btp *Bootstrapper) bootstrap(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertex } //SubSum X -> (N/dslots) * Y^dslots - btp.Trace(ctOut, btp.params.LogSlots(), ctOut) + btp.Trace(ctOut, ctOut.LogSlots, ctOut) // Step 2 : CoeffsToSlots (Homomorphic encoding) ctReal, ctImag := btp.CoeffsToSlotsNew(ctOut, btp.ctsMatrices) diff --git a/ckks/bootstrapping/bootstrapping_bench_test.go b/ckks/bootstrapping/bootstrapping_bench_test.go index 85ae7db1..8e88e79e 100644 --- a/ckks/bootstrapping/bootstrapping_bench_test.go +++ b/ckks/bootstrapping/bootstrapping_bench_test.go @@ -34,10 +34,10 @@ func BenchmarkBootstrap(b *testing.B) { panic(err) } - b.Run(ParamsToString(params, "Bootstrap/"), func(b *testing.B) { + b.Run(ParamsToString(params, btpParams.LogSlots(), "Bootstrap/"), func(b *testing.B) { for i := 0; i < b.N; i++ { - bootstrappingScale := rlwe.NewScale(math.Exp2(math.Round(math.Log2(btp.params.QiFloat64(0) / btp.evalModPoly.MessageRatio())))) + bootstrappingScale := rlwe.NewScale(math.Exp2(math.Round(math.Log2(float64(btp.params.Q()[0]) / btp.evalModPoly.MessageRatio())))) b.StopTimer() ct := ckks.NewCiphertext(params, 1, 0) @@ -54,7 +54,7 @@ func BenchmarkBootstrap(b *testing.B) { //SubSum X -> (N/dslots) * Y^dslots t = time.Now() - btp.Trace(ct, btp.params.LogSlots(), ct) + btp.Trace(ct, ct.LogSlots, ct) b.Log("After SubSum :", time.Since(t), ct.Level(), ct.Scale.Float64()) // Part 1 : Coeffs to slots diff --git a/ckks/bootstrapping/bootstrapping_test.go b/ckks/bootstrapping/bootstrapping_test.go index 3fcb9e75..4462292f 100644 --- a/ckks/bootstrapping/bootstrapping_test.go +++ b/ckks/bootstrapping/bootstrapping_test.go @@ -20,11 +20,11 @@ var minPrec float64 = 12.0 var flagLongTest = flag.Bool("long", false, "run the long test suite (all parameters + secure bootstrapping). Overrides -short and requires -timeout=0.") var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats") -func ParamsToString(params ckks.Parameters, opname string) string { +func ParamsToString(params ckks.Parameters, LogSlots int, opname string) string { return fmt.Sprintf("%slogN=%d/LogSlots=%d/logQP=%f/levels=%d/a=%d/b=%d", opname, params.LogN(), - params.LogSlots(), + LogSlots, params.LogQP(), params.MaxLevel()+1, params.PCount(), @@ -80,58 +80,49 @@ func TestBootstrap(t *testing.T) { paramSet := DefaultParametersSparse[0] - ckksParamsLit, btpParams, err := NewParametersFromLiteral(paramSet.SchemeParams, paramSet.BootstrappingParams) - require.Nil(t, err) - - // Insecure params for fast testing only if !*flagLongTest { - ckksParamsLit.LogN = 13 - - // Corrects the message ratio to take into account the smaller number of slots and keep the same precision - btpParams.EvalModParameters.LogMessageRatio += paramSet.SchemeParams.LogN - 1 - ckksParamsLit.LogN - 1 - - ckksParamsLit.LogSlots = ckksParamsLit.LogN - 1 + paramSet.SchemeParams.LogN -= 3 } - Xs := ckksParamsLit.Xs + for _, LogSlots := range []int{1, paramSet.SchemeParams.LogN - 2, paramSet.SchemeParams.LogN - 1} { + for _, encapsulation := range []bool{true, false} { - EphemeralSecretWeight := btpParams.EphemeralSecretWeight + paramSet.BootstrappingParams.LogSlots = &LogSlots - for _, testSet := range [][]bool{{false, false}, {true, false}, {false, true}, {true, true}} { + ckksParamsLit, btpParams, err := NewParametersFromLiteral(paramSet.SchemeParams, paramSet.BootstrappingParams) + require.Nil(t, err) - if testSet[0] { - ckksParamsLit.Xs = &distribution.Ternary{H: EphemeralSecretWeight} - btpParams.EphemeralSecretWeight = 0 - } else { - ckksParamsLit.Xs = Xs - btpParams.EphemeralSecretWeight = EphemeralSecretWeight + // Insecure params for fast testing only + if !*flagLongTest { + // Corrects the message ratio to take into account the smaller number of slots and keep the same precision + btpParams.EvalModParameters.LogMessageRatio += utils.MinInt(utils.MaxInt(15-LogSlots, 0), 8) + } + + if !encapsulation { + ckksParamsLit.Xs = &distribution.Ternary{H: btpParams.EphemeralSecretWeight} + btpParams.EphemeralSecretWeight = 0 + } + + params, err := ckks.NewParametersFromLiteral(ckksParamsLit) + if err != nil { + panic(err) + } + + testbootstrap(params, btpParams, t) + runtime.GC() } - - if testSet[1] { - ckksParamsLit.LogSlots = ckksParamsLit.LogN - 2 - } else { - ckksParamsLit.LogSlots = ckksParamsLit.LogN - 1 - } - - params, err := ckks.NewParametersFromLiteral(ckksParamsLit) - if err != nil { - panic(err) - } - - testbootstrap(params, testSet[0], btpParams, t) - runtime.GC() } } -func testbootstrap(params ckks.Parameters, original bool, btpParams Parameters, t *testing.T) { +func testbootstrap(params ckks.Parameters, btpParams Parameters, t *testing.T) { btpType := "Encapsulation/" - if original { + if btpParams.EphemeralSecretWeight == 0 { btpType = "Original/" } - t.Run(ParamsToString(params, "Bootstrapping/FullCircuit/"+btpType), func(t *testing.T) { + t.Run(ParamsToString(params, btpParams.LogSlots(), "Bootstrapping/FullCircuit/"+btpType), func(t *testing.T) { kgen := ckks.NewKeyGenerator(params) sk := kgen.GenSecretKeyNew() @@ -146,22 +137,24 @@ func testbootstrap(params ckks.Parameters, original bool, btpParams Parameters, panic(err) } - values := make([]complex128, 1< 2 { + + if btpParams.LogSlots() > 1 { values[2] = complex(0.9238795325112867, 0.3826834323650898) values[3] = complex(0.9238795325112867, 0.3826834323650898) } plaintext := ckks.NewPlaintext(params, 0) - encoder.Encode(values, plaintext, params.LogSlots()) + plaintext.LogSlots = btpParams.LogSlots() + encoder.Encode(values, plaintext) - n := 2 + n := 1 ciphertexts := make([]*rlwe.Ciphertext, n) bootstrappers := make([]*Bootstrapper, n) @@ -183,17 +176,20 @@ func testbootstrap(params ckks.Parameters, original bool, btpParams Parameters, wg.Wait() for i := range ciphertexts { - verifyTestVectors(params, encoder, decryptor, values, ciphertexts[i], params.LogSlots(), t) + verifyTestVectors(params, encoder, decryptor, values, ciphertexts[i], t) } }) } -func verifyTestVectors(params ckks.Parameters, encoder ckks.Encoder, decryptor rlwe.Decryptor, valuesWant []complex128, element interface{}, logSlots int, t *testing.T) { - precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, element, logSlots, nil) +func verifyTestVectors(params ckks.Parameters, encoder *ckks.Encoder, decryptor rlwe.Decryptor, valuesWant, valuesHave interface{}, t *testing.T) { + precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, valuesHave, nil, false) if *printPrecisionStats { t.Log(precStats.String()) } - require.GreaterOrEqual(t, precStats.MeanPrecision.Real, minPrec) - require.GreaterOrEqual(t, precStats.MeanPrecision.Imag, minPrec) + rf64, _ := precStats.MeanPrecision.Real.Float64() + if64, _ := precStats.MeanPrecision.Imag.Float64() + + require.GreaterOrEqual(t, rf64, minPrec) + require.GreaterOrEqual(t, if64, minPrec) } diff --git a/ckks/bootstrapping/parameters.go b/ckks/bootstrapping/parameters.go index 44e8d5fc..d5aedcc1 100644 --- a/ckks/bootstrapping/parameters.go +++ b/ckks/bootstrapping/parameters.go @@ -22,16 +22,21 @@ type Parameters struct { // NewParametersFromLiteral takes as input a ckks.ParametersLiteral and a bootstrapping.ParametersLiteral structs and returns the // appropriate ckks.ParametersLiteral for the bootstrapping circuit as well as the instantiated bootstrapping.Parameters. // The returned ckks.ParametersLiteral contains allocated primes. -func NewParametersFromLiteral(paramsCKKS ckks.ParametersLiteral, paramsBootstrap ParametersLiteral) (ckks.ParametersLiteral, Parameters, error) { +func NewParametersFromLiteral(ckksLit ckks.ParametersLiteral, btpLit ParametersLiteral) (ckks.ParametersLiteral, Parameters, error) { var err error - if paramsCKKS.RingType != ring.Standard { + if ckksLit.RingType != ring.Standard { return ckks.ParametersLiteral{}, Parameters{}, fmt.Errorf("NewParametersFromLiteral: invalid ring.RingType: must be ring.Standard") } - CoeffsToSlotsFactorizationDepthAndLogScales := paramsBootstrap.GetCoeffsToSlotsFactorizationDepthAndLogScales() - SlotsToCoeffsFactorizationDepthAndLogScales := paramsBootstrap.GetSlotsToCoeffsFactorizationDepthAndLogScales() + var LogSlots int + if LogSlots, err = btpLit.GetLogSlots(ckksLit.LogN); err != nil { + return ckks.ParametersLiteral{}, Parameters{}, err + } + + CoeffsToSlotsFactorizationDepthAndLogScales := btpLit.GetCoeffsToSlotsFactorizationDepthAndLogScales(LogSlots) + SlotsToCoeffsFactorizationDepthAndLogScales := btpLit.GetSlotsToCoeffsFactorizationDepthAndLogScales(LogSlots) // Slots To Coeffs params SlotsToCoeffsLevels := make([]int, len(SlotsToCoeffsFactorizationDepthAndLogScales)) @@ -40,47 +45,48 @@ func NewParametersFromLiteral(paramsCKKS ckks.ParametersLiteral, paramsBootstrap } var Iterations int - if Iterations, err = paramsBootstrap.GetIterations(); err != nil { + if Iterations, err = btpLit.GetIterations(); err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } S2CParams := advanced.HomomorphicDFTMatrixLiteral{ Type: advanced.Decode, + LogSlots: LogSlots, RepackImag2Real: true, - LevelStart: len(paramsCKKS.LogQ) - 1 + len(SlotsToCoeffsFactorizationDepthAndLogScales) + Iterations - 1, + LevelStart: len(ckksLit.LogQ) - 1 + len(SlotsToCoeffsFactorizationDepthAndLogScales) + Iterations - 1, LogBSGSRatio: 1, Levels: SlotsToCoeffsLevels, } var EvalModLogScale int - if EvalModLogScale, err = paramsBootstrap.GetEvalModLogScale(); err != nil { + if EvalModLogScale, err = btpLit.GetEvalModLogScale(); err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } - SineType := paramsBootstrap.GetSineType() + SineType := btpLit.GetSineType() var ArcSineDegree int - if ArcSineDegree, err = paramsBootstrap.GetArcSineDegree(); err != nil { + if ArcSineDegree, err = btpLit.GetArcSineDegree(); err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } var LogMessageRatio int - if LogMessageRatio, err = paramsBootstrap.GetLogMessageRatio(); err != nil { + if LogMessageRatio, err = btpLit.GetLogMessageRatio(); err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } var K int - if K, err = paramsBootstrap.GetK(); err != nil { + if K, err = btpLit.GetK(); err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } var DoubleAngle int - if DoubleAngle, err = paramsBootstrap.GetDoubleAngle(); err != nil { + if DoubleAngle, err = btpLit.GetDoubleAngle(); err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } var SineDegree int - if SineDegree, err = paramsBootstrap.GetSineDegree(); err != nil { + if SineDegree, err = btpLit.GetSineDegree(); err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } @@ -95,7 +101,7 @@ func NewParametersFromLiteral(paramsCKKS ckks.ParametersLiteral, paramsBootstrap } var EphemeralSecretWeight int - if EphemeralSecretWeight, err = paramsBootstrap.GetEphemeralSecretWeight(); err != nil { + if EphemeralSecretWeight, err = btpLit.GetEphemeralSecretWeight(); err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } @@ -109,14 +115,15 @@ func NewParametersFromLiteral(paramsCKKS ckks.ParametersLiteral, paramsBootstrap C2SParams := advanced.HomomorphicDFTMatrixLiteral{ Type: advanced.Encode, + LogSlots: LogSlots, RepackImag2Real: true, LevelStart: EvalModParams.LevelStart + len(CoeffsToSlotsFactorizationDepthAndLogScales), LogBSGSRatio: 1, Levels: CoeffsToSlotsLevels, } - LogQ := make([]int, len(paramsCKKS.LogQ)) - copy(LogQ, paramsCKKS.LogQ) + LogQ := make([]int, len(ckksLit.LogQ)) + copy(LogQ, ckksLit.LogQ) for i := 0; i < Iterations-1; i++ { LogQ = append(LogQ, DefaultIterationsLogScale) @@ -128,8 +135,8 @@ func NewParametersFromLiteral(paramsCKKS ckks.ParametersLiteral, paramsBootstrap qi += SlotsToCoeffsFactorizationDepthAndLogScales[i][j] } - if qi+paramsCKKS.LogScale < 61 { - qi += paramsCKKS.LogScale + if qi+ckksLit.LogScale < 61 { + qi += ckksLit.LogScale } LogQ = append(LogQ, qi) @@ -147,23 +154,22 @@ func NewParametersFromLiteral(paramsCKKS ckks.ParametersLiteral, paramsBootstrap LogQ = append(LogQ, qi) } - LogP := make([]int, len(paramsCKKS.LogP)) - copy(LogP, paramsCKKS.LogP) + LogP := make([]int, len(ckksLit.LogP)) + copy(LogP, ckksLit.LogP) - Q, P, err := rlwe.GenModuli(paramsCKKS.LogN, LogQ, LogP) + Q, P, err := rlwe.GenModuli(ckksLit.LogN, LogQ, LogP) if err != nil { return ckks.ParametersLiteral{}, Parameters{}, err } return ckks.ParametersLiteral{ - LogN: paramsCKKS.LogN, + LogN: ckksLit.LogN, Q: Q, P: P, - LogSlots: paramsCKKS.LogSlots, - LogScale: paramsCKKS.LogScale, - Xe: paramsCKKS.Xe, - Xs: paramsCKKS.Xs, + LogScale: ckksLit.LogScale, + Xe: ckksLit.Xe, + Xs: ckksLit.Xs, }, Parameters{ EphemeralSecretWeight: EphemeralSecretWeight, @@ -174,6 +180,11 @@ func NewParametersFromLiteral(paramsCKKS ckks.ParametersLiteral, paramsBootstrap }, nil } +// LogSlots returns the LogSlots of the target Parameters. +func (p *Parameters) LogSlots() int { + return p.SlotsToCoeffsParameters.LogSlots +} + // DepthCoeffsToSlots returns the depth of the Coeffs to Slots of the CKKS bootstrapping. func (p *Parameters) DepthCoeffsToSlots() (depth int) { return p.SlotsToCoeffsParameters.Depth(true) @@ -210,22 +221,15 @@ func (p *Parameters) UnmarshalBinary(data []byte) (err error) { func (p *Parameters) GaloisElements(params ckks.Parameters) (galEls []uint64) { logN := params.LogN() - logSlots := params.LogSlots() // List of the rotation key values to needed for the bootstrapp keys := make(map[uint64]bool) //SubSum rotation needed X -> Y^slots rotations - for i := logSlots; i < logN-1; i++ { + for i := p.LogSlots(); i < logN-1; i++ { keys[params.GaloisElementForColumnRotationBy(1< LogN-1 { + return LogSlots, fmt.Errorf("field LogSlots cannot be smaller than 1 or greater than LogN-1") + } + } + + return +} + // GetCoeffsToSlotsFactorizationDepthAndLogScales returns a copy of the CoeffsToSlotsFactorizationDepthAndLogScales field of the target ParametersLiteral. // The default value constructed from DefaultC2SFactorization and DefaultC2SLogScale is returned if the field is nil. -func (p *ParametersLiteral) GetCoeffsToSlotsFactorizationDepthAndLogScales() (CoeffsToSlotsFactorizationDepthAndLogScales [][]int) { +func (p *ParametersLiteral) GetCoeffsToSlotsFactorizationDepthAndLogScales(LogSlots int) (CoeffsToSlotsFactorizationDepthAndLogScales [][]int) { if p.CoeffsToSlotsFactorizationDepthAndLogScales == nil { - CoeffsToSlotsFactorizationDepthAndLogScales = make([][]int, DefaultCoeffsToSlotsFactorizationDepth) + CoeffsToSlotsFactorizationDepthAndLogScales = make([][]int, utils.MinInt(DefaultCoeffsToSlotsFactorizationDepth, utils.MaxInt(LogSlots, 1))) for i := range CoeffsToSlotsFactorizationDepthAndLogScales { CoeffsToSlotsFactorizationDepthAndLogScales[i] = []int{DefaultCoeffsToSlotsLogScale} } @@ -139,9 +160,9 @@ func (p *ParametersLiteral) GetCoeffsToSlotsFactorizationDepthAndLogScales() (Co // GetSlotsToCoeffsFactorizationDepthAndLogScales returns a copy of the SlotsToCoeffsFactorizationDepthAndLogScales field of the target ParametersLiteral. // The default value constructed from DefaultS2CFactorization and DefaultS2CLogScale is returned if the field is nil. -func (p *ParametersLiteral) GetSlotsToCoeffsFactorizationDepthAndLogScales() (SlotsToCoeffsFactorizationDepthAndLogScales [][]int) { +func (p *ParametersLiteral) GetSlotsToCoeffsFactorizationDepthAndLogScales(LogSlots int) (SlotsToCoeffsFactorizationDepthAndLogScales [][]int) { if p.SlotsToCoeffsFactorizationDepthAndLogScales == nil { - SlotsToCoeffsFactorizationDepthAndLogScales = make([][]int, DefaultSlotsToCoeffsFactorizationDepth) + SlotsToCoeffsFactorizationDepthAndLogScales = make([][]int, utils.MinInt(DefaultSlotsToCoeffsFactorizationDepth, utils.MaxInt(LogSlots, 1))) for i := range SlotsToCoeffsFactorizationDepthAndLogScales { SlotsToCoeffsFactorizationDepthAndLogScales[i] = []int{DefaultSlotsToCoeffsLogScale} } @@ -294,16 +315,16 @@ func (p *ParametersLiteral) GetEphemeralSecretWeight() (EphemeralSecretWeight in // BitConsumption returns the expected consumption in bits of // bootstrapping circuit of the target ParametersLiteral. // The value is rounded up and thus will overestimate the value by up to 1 bit. -func (p *ParametersLiteral) BitConsumption() (logQ int, err error) { +func (p *ParametersLiteral) BitComsumption(LogSlots int) (logQ int, err error) { - C2SLogScale := p.GetCoeffsToSlotsFactorizationDepthAndLogScales() + C2SLogScale := p.GetCoeffsToSlotsFactorizationDepthAndLogScales(LogSlots) for i := range C2SLogScale { for _, logQi := range C2SLogScale[i] { logQ += logQi } } - S2CLogScale := p.GetSlotsToCoeffsFactorizationDepthAndLogScales() + S2CLogScale := p.GetSlotsToCoeffsFactorizationDepthAndLogScales(LogSlots) for i := range S2CLogScale { for _, logQi := range S2CLogScale[i] { logQ += logQi diff --git a/ckks/chebyshev_interpolation.go b/ckks/chebyshev_interpolation.go index ff9f9f91..3cf65420 100644 --- a/ckks/chebyshev_interpolation.go +++ b/ckks/chebyshev_interpolation.go @@ -3,13 +3,13 @@ package ckks import ( "math" - "github.com/tuneinsight/lattigo/v4/utils/bignum/polynomial" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // Approximate computes a Chebyshev approximation of the input function, for the range [-a, b] of degree degree. // function.(type) can be either func(complex128)complex128 or func(float64)float64 // To be used in conjunction with the function EvaluateCheby. -func Approximate(function interface{}, a, b float64, degree int) (pol *Polynomial) { +func Approximate(function interface{}, a, b float64, degree int) (pol *bignum.Polynomial) { nodes := chebyshevNodes(degree+1, a, b) @@ -28,14 +28,7 @@ func Approximate(function interface{}, a, b float64, degree int) (pol *Polynomia panic("function must be either func(complex128)complex128 or func(float64)float64") } - pol = NewPoly(chebyCoeffs(nodes, fi, a, b)) - pol.A = a - pol.B = b - pol.MaxDeg = degree - pol.Lead = true - pol.Basis = polynomial.Chebyshev - - return + return bignum.NewPolynomial(bignum.Chebyshev, chebyCoeffs(nodes, fi, a, b), [2]float64{a, b}) } func chebyshevNodes(n int, a, b float64) (u []float64) { diff --git a/ckks/ckks.go b/ckks/ckks.go index 35d222e2..8bf1c92e 100644 --- a/ckks/ckks.go +++ b/ckks/ckks.go @@ -7,11 +7,15 @@ import ( ) func NewPlaintext(params Parameters, level int) (pt *rlwe.Plaintext) { - return rlwe.NewPlaintext(params.Parameters, level) + pt = rlwe.NewPlaintext(params.Parameters, level) + pt.LogSlots = params.MaxLogSlots() + return } func NewCiphertext(params Parameters, degree, level int) (ct *rlwe.Ciphertext) { - return rlwe.NewCiphertext(params.Parameters, degree, level) + ct = rlwe.NewCiphertext(params.Parameters, degree, level) + ct.LogSlots = params.MaxLogSlots() + return } func NewEncryptor(params Parameters, key interface{}) rlwe.Encryptor { diff --git a/ckks/ckks_benchmarks_test.go b/ckks/ckks_benchmarks_test.go index a60b0c39..d5521fde 100644 --- a/ckks/ckks_benchmarks_test.go +++ b/ckks/ckks_benchmarks_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "testing" + "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/utils/sampling" ) @@ -12,66 +13,74 @@ func BenchmarkCKKSScheme(b *testing.B) { var err error - defaultParams := append(DefaultParams, DefaultConjugateInvariantParams...) - if testing.Short() { - defaultParams = DefaultParams[:2] + var testParams []ParametersLiteral + switch { + case *flagParamString != "": // the custom test suite reads the parameters from the -params flag + testParams = append(testParams, ParametersLiteral{}) + if err = json.Unmarshal([]byte(*flagParamString), &testParams[0]); err != nil { + b.Fatal(err) + } + default: + testParams = TestParamsLiteral } - if *flagParamString != "" { - var jsonParams ParametersLiteral - if err = json.Unmarshal([]byte(*flagParamString), &jsonParams); err != nil { - b.Fatal(err) - } - defaultParams = []ParametersLiteral{jsonParams} // the custom test suite reads the parameters from the -params flag - } + for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} { - for _, defaultParams := range defaultParams { - var params Parameters - if params, err = NewParametersFromLiteral(defaultParams); err != nil { - b.Fatal(err) - } + for _, paramsLiteral := range testParams { - var tc *testContext - if tc, err = genTestParams(params); err != nil { - b.Fatal(err) - } + paramsLiteral.RingType = ringType - benchEncoder(tc, b) - benchEvaluator(tc, b) + var params Parameters + if params, err = NewParametersFromLiteral(paramsLiteral); err != nil { + b.Fatal(err) + } + + var tc *testContext + if tc, err = genTestParams(params); err != nil { + b.Fatal(err) + } + + benchEncoder(tc, b) + benchEvaluator(tc, b) + } } } func benchEncoder(tc *testContext, b *testing.B) { encoder := tc.encoder - logSlots := tc.params.LogSlots() b.Run(GetTestName(tc.params, "Encoder/Encode"), func(b *testing.B) { - values := make([]complex128, 1<ct0"), func(t *testing.T) { + values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + mul := bignum.NewComplexMultiplier() - for i := range values1 { - values1[i] *= values1[i] - } + for i := range values1 { + mul.Mul(values1[i], values1[i], values1[i]) + } - tc.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1) + ciphertext2 := tc.evaluator.MulNew(ciphertext1, plaintext1) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) - }) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext2, nil, t) + }) - t.Run(GetTestName(tc.params, "pt*ct0->ct0"), func(t *testing.T) { + t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Scalar"), func(t *testing.T) { - values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - for i := range values1 { - values1[i] *= values1[i] - } + constant := randomConst(tc.params.RingType(), tc.encoder.Prec(), -1+1i, -1+1i) - tc.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1) + mul := bignum.NewComplexMultiplier() - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) - }) + for i := range values { + mul.Mul(values[i], constant, values[i]) + } - t.Run(GetTestName(tc.params, "ct0*pt->ct1"), func(t *testing.T) { + tc.evaluator.Mul(ciphertext, constant, ciphertext) - values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, nil, t) + }) - for i := range values1 { - values1[i] *= values1[i] - } + t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Pt"), func(t *testing.T) { - ciphertext2 := tc.evaluator.MulRelinNew(ciphertext1, plaintext1) + values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext2, tc.params.LogSlots(), nil, t) - }) + mul := bignum.NewComplexMultiplier() - t.Run(GetTestName(tc.params, "ct0*ct1->ct0"), func(t *testing.T) { + for i := range values1 { + mul.Mul(values1[i], values1[i], values1[i]) + } - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + tc.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1) - for i := range values1 { - values2[i] *= values1[i] - } + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, nil, t) + }) - tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1) + t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Ct/Degree0"), func(t *testing.T) { - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext1, tc.params.LogSlots(), nil, t) - }) + values1, plaintext1, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - t.Run(GetTestName(tc.params, "ct0*ct1->ct0 (degree 0)"), func(t *testing.T) { + mul := bignum.NewComplexMultiplier() - values1, plaintext1, _ := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + for i := range values1 { + mul.Mul(values2[i], values1[i], values2[i]) + } - for i := range values1 { - values2[i] *= values1[i] - } + ciphertext1 := &rlwe.Ciphertext{Value: []*ring.Poly{plaintext1.Value}, MetaData: plaintext1.MetaData} - ciphertext1 := &rlwe.Ciphertext{} - ciphertext1.Value = []*ring.Poly{plaintext1.Value} - ciphertext1.MetaData = plaintext1.MetaData + tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1) - tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext1, nil, t) + }) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext1, tc.params.LogSlots(), nil, t) - }) + t.Run(GetTestName(tc.params, "Evaluator/MulRelin/Ct/Ct"), func(t *testing.T) { - t.Run(GetTestName(tc.params, "ct0*ct1->ct1"), func(t *testing.T) { + // op0 <- op0 * op1 + values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + mul := bignum.NewComplexMultiplier() - for i := range values1 { - values2[i] *= values1[i] - } + for i := range values1 { + mul.Mul(values1[i], values2[i], values1[i]) + } - tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext2) + tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1) + require.Equal(t, ciphertext1.Degree(), 1) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogSlots(), nil, t) - }) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, nil, t) - t.Run(GetTestName(tc.params, "ct0*ct1->ct2"), func(t *testing.T) { + // op1 <- op0 * op1 + values1, _, ciphertext1 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values2, _, ciphertext2 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + for i := range values1 { + mul.Mul(values2[i], values1[i], values2[i]) + } - for i := range values1 { - values2[i] *= values1[i] - } + tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext2) + require.Equal(t, ciphertext2.Degree(), 1) - ciphertext3 := tc.evaluator.MulRelinNew(ciphertext1, ciphertext2) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, nil, t) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext3, tc.params.LogSlots(), nil, t) - }) + // op0 <- op0 * op0 + for i := range values1 { + mul.Mul(values1[i], values1[i], values1[i]) + } - t.Run(GetTestName(tc.params, "ct0*ct0->ct0"), func(t *testing.T) { + tc.evaluator.MulRelin(ciphertext1, ciphertext1, ciphertext1) + require.Equal(t, ciphertext1.Degree(), 1) - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - - for i := range values1 { - values1[i] *= values1[i] - } - - tc.evaluator.MulRelin(ciphertext1, ciphertext1, ciphertext1) - - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) - }) - - t.Run(GetTestName(tc.params, "ct0*ct0->ct1"), func(t *testing.T) { - - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - - for i := range values1 { - values1[i] *= values1[i] - } - - ciphertext2 := tc.evaluator.MulRelinNew(ciphertext1, ciphertext1) - - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext2, tc.params.LogSlots(), nil, t) - }) - - t.Run(GetTestName(tc.params, "MulRelin(ct0*ct1->ct0)"), func(t *testing.T) { - - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - - for i := range values1 { - values1[i] *= values2[i] - } - - tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1) - require.Equal(t, ciphertext1.Degree(), 1) - - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) - }) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, nil, t) }) } func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) { - t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/ct1*pt0->ct0"), func(t *testing.T) { - - values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - - for i := range values1 { - values1[i] += values1[i] * values2[i] - } - - tc.evaluator.MulRelinThenAdd(ciphertext2, plaintext1, ciphertext1) - - require.Equal(t, ciphertext1.Degree(), 1) - - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) - }) - - t.Run(GetTestName(tc.params, "Evaluator/MulRelinThenAdd/ct1*ct1->ct0"), func(t *testing.T) { + t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Scalar"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + constant := randomConst(tc.params.RingType(), tc.encoder.Prec(), -1+1i, -1+1i) + + mul := bignum.NewComplexMultiplier() + + tmp := new(bignum.Complex) + tmp[0] = new(big.Float) + tmp[1] = new(big.Float) + for i := range values1 { - values1[i] += values2[i] * values2[i] + mul.Mul(values1[i], constant, tmp) + values2[i].Add(values2[i], tmp) } - tc.evaluator.MulRelinThenAdd(ciphertext2, ciphertext2, ciphertext1) + tc.evaluator.MulThenAdd(ciphertext1, constant, ciphertext2) + + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, nil, t) + }) + + t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Pt"), func(t *testing.T) { + + values1, plaintext1, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1, 1, t) + values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1, 1, t) + + mul := bignum.NewComplexMultiplier() + + tmp := new(bignum.Complex) + tmp[0] = new(big.Float) + tmp[1] = new(big.Float) + + for i := range values1 { + mul.Mul(values2[i], values1[i], tmp) + values1[i].Add(values1[i], tmp) + } + + tc.evaluator.MulThenAdd(ciphertext2, plaintext1, ciphertext1) require.Equal(t, ciphertext1.Degree(), 1) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, nil, t) }) - t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/ct0*ct1->ct2"), func(t *testing.T) { + t.Run(GetTestName(tc.params, "Evaluator/MulRelinThenAdd/Ct"), func(t *testing.T) { + // op2 = op2 + op1 * op0 values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + mul := bignum.NewComplexMultiplier() + for i := range values1 { - values1[i] = values1[i] * values2[i] + mul.Mul(values1[i], values2[i], values2[i]) } ciphertext3 := NewCiphertext(tc.params, 2, ciphertext1.Level()) @@ -705,52 +643,47 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) { require.Equal(t, ciphertext3.Degree(), 1) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext3, tc.params.LogSlots(), nil, t) - }) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext3, nil, t) - t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/ct1*ct1->ct0"), func(t *testing.T) { - - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - values2, _, ciphertext2 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + // op1 = op1 + op0*op0 + values1, _, ciphertext1 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values2, _, ciphertext2 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + tmp := bignum.NewComplex() for i := range values1 { - values1[i] += values2[i] * values2[i] + mul.Mul(values2[i], values2[i], tmp) + values1[i].Add(values1[i], tmp) } - tc.evaluator.MulThenAdd(ciphertext2, ciphertext2, ciphertext1) - - require.Equal(t, ciphertext1.Degree(), 2) - - tc.evaluator.Relinearize(ciphertext1, ciphertext1) + tc.evaluator.MulRelinThenAdd(ciphertext2, ciphertext2, ciphertext1) require.Equal(t, ciphertext1.Degree(), 1) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, nil, t) }) } func testFunctions(tc *testContext, t *testing.T) { - t.Run(GetTestName(tc.params, "Evaluator/Inverse"), func(t *testing.T) { + t.Run(GetTestName(tc.params, "Evaluator/GoldschmidtDivisionNew"), func(t *testing.T) { - if tc.params.MaxLevel() < 7 { - t.Skip("skipping test for params max level < 7") - } + min := 0.1 - values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, 0.1+0i, 1+0i, t) - - n := 7 + values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, complex(min, 0), 1+0i, t) + one := new(big.Float).SetInt64(1) for i := range values { - values[i] = 1.0 / values[i] + values[i][0].Quo(one, values[i][0]) } + log2Targetprecision := math.Log2(tc.params.DefaultScale().Float64()) - float64(tc.params.LogN()) + var err error - if ciphertext, err = tc.evaluator.InverseNew(ciphertext, n); err != nil { + if ciphertext, err = tc.evaluator.GoldschmidtDivisionNew(ciphertext, min, log2Targetprecision, NewSimpleBootstrapper(tc.params, tc.sk)); err != nil { t.Fatal(err) } - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogSlots(), nil, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, nil, t) }) } @@ -766,28 +699,30 @@ func testEvaluatePoly(tc *testContext, t *testing.T) { values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1, 1, t) - coeffs := []complex128{ - 1.0, - 1.0, - 1.0 / 2, - 1.0 / 6, - 1.0 / 24, - 1.0 / 120, - 1.0 / 720, - 1.0 / 5040, + prec := tc.encoder.Prec() + + coeffs := []*big.Float{ + bignum.NewFloat(1, prec), + bignum.NewFloat(1, prec), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(2, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(6, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(24, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(120, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(720, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(5040, prec)), } - poly := NewPoly(coeffs) + poly := bignum.NewPolynomial(bignum.Monomial, coeffs, nil) for i := range values { - values[i] = cmplx.Exp(values[i]) + values[i] = poly.Evaluate(values[i]) } if ciphertext, err = tc.evaluator.EvaluatePoly(ciphertext, poly, ciphertext.Scale); err != nil { t.Fatal(err) } - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogSlots(), nil, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, nil, t) }) t.Run(GetTestName(tc.params, "EvaluatePoly/PolyVector/Exp"), func(t *testing.T) { @@ -798,37 +733,41 @@ func testEvaluatePoly(tc *testContext, t *testing.T) { values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1, 1, t) - coeffs := []complex128{ - 1.0, - 1.0, - 1.0 / 2, - 1.0 / 6, - 1.0 / 24, - 1.0 / 120, - 1.0 / 720, - 1.0 / 5040, + prec := tc.encoder.Prec() + + coeffs := []*big.Float{ + bignum.NewFloat(1, prec), + bignum.NewFloat(1, prec), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(2, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(6, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(24, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(120, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(720, prec)), + new(big.Float).Quo(bignum.NewFloat(1, prec), bignum.NewFloat(5040, prec)), } - poly := NewPoly(coeffs) + poly := bignum.NewPolynomial(bignum.Monomial, coeffs, nil) + + slots := ciphertext.Slots() slotIndex := make(map[int][]int) - idx := make([]int, tc.params.Slots()>>1) - for i := 0; i < tc.params.Slots()>>1; i++ { + idx := make([]int, slots>>1) + for i := 0; i < slots>>1; i++ { idx[i] = 2 * i } slotIndex[0] = idx - valuesWant := make([]complex128, tc.params.Slots()) + valuesWant := make([]*bignum.Complex, slots) for _, j := range idx { - valuesWant[j] = cmplx.Exp(values[j]) + valuesWant[j] = poly.Evaluate(values[j]) } - if ciphertext, err = tc.evaluator.EvaluatePolyVector(ciphertext, []*Polynomial{poly}, tc.encoder, slotIndex, ciphertext.Scale); err != nil { + if ciphertext, err = tc.evaluator.EvaluatePolyVector(ciphertext, []*bignum.Polynomial{poly}, tc.encoder, slotIndex, ciphertext.Scale); err != nil { t.Fatal(err) } - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, valuesWant, ciphertext, tc.params.LogSlots(), nil, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, valuesWant, ciphertext, nil, t) }) } @@ -838,7 +777,7 @@ func testChebyshevInterpolator(tc *testContext, t *testing.T) { t.Run(GetTestName(tc.params, "ChebyshevInterpolator/Sin"), func(t *testing.T) { - if tc.params.MaxLevel() < 5 { + if tc.params.MaxDepth() < 5 { t.Skip("skipping test for params max level < 5") } @@ -846,10 +785,11 @@ func testChebyshevInterpolator(tc *testContext, t *testing.T) { values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1, 1, t) - poly := Approximate(cmplx.Sin, -1.5, 1.5, 15) + poly := Approximate(cmplx.Sin, -1.5, 1.5, 7) - eval.MultByConst(ciphertext, 2/(poly.B-poly.A), ciphertext) - eval.AddConst(ciphertext, (-poly.A-poly.B)/(poly.B-poly.A), ciphertext) + scalar, constant := poly.ChangeOfBasis() + eval.Mul(ciphertext, scalar, ciphertext) + eval.Add(ciphertext, constant, ciphertext) if err = eval.Rescale(ciphertext, tc.params.DefaultScale(), ciphertext); err != nil { t.Fatal(err) @@ -857,14 +797,13 @@ func testChebyshevInterpolator(tc *testContext, t *testing.T) { if ciphertext, err = eval.EvaluatePoly(ciphertext, poly, ciphertext.Scale); err != nil { t.Fatal(err) - } for i := range values { - values[i] = cmplx.Sin(values[i]) + values[i] = poly.Evaluate(values[i]) } - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogSlots(), nil, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, nil, t) }) } @@ -874,7 +813,9 @@ func testDecryptPublic(tc *testContext, t *testing.T) { t.Run(GetTestName(tc.params, "DecryptPublic/Sin"), func(t *testing.T) { - if tc.params.MaxLevel() < 5 { + degree := 7 + + if tc.params.MaxDepth() < bits.Len64(uint64(degree)) { t.Skip("skipping test for params max level < 5") } @@ -882,14 +823,16 @@ func testDecryptPublic(tc *testContext, t *testing.T) { values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1, 1, t) - poly := Approximate(cmplx.Sin, -1.5, 1.5, 15) + poly := Approximate(cmplx.Sin, -1.5, 1.5, degree) for i := range values { - values[i] = cmplx.Sin(values[i]) + values[i] = poly.Evaluate(values[i]) } - eval.MultByConst(ciphertext, 2/(poly.B-poly.A), ciphertext) - eval.AddConst(ciphertext, (-poly.A-poly.B)/(poly.B-poly.A), ciphertext) + scalar, constant := poly.ChangeOfBasis() + + eval.Mul(ciphertext, scalar, ciphertext) + eval.Add(ciphertext, constant, ciphertext) if err := eval.Rescale(ciphertext, tc.params.DefaultScale(), ciphertext); err != nil { t.Fatal(err) @@ -897,20 +840,26 @@ func testDecryptPublic(tc *testContext, t *testing.T) { if ciphertext, err = eval.EvaluatePoly(ciphertext, poly, ciphertext.Scale); err != nil { t.Fatal(err) - } plaintext := tc.decryptor.DecryptNew(ciphertext) - valuesHave := tc.encoder.Decode(plaintext, tc.params.LogSlots()) + valuesHave := make([]*big.Float, plaintext.Slots()) - verifyTestVectors(tc.params, tc.encoder, nil, values, valuesHave, tc.params.LogSlots(), nil, t) + tc.encoder.Decode(plaintext, valuesHave) - sigma := tc.encoder.GetErrSTDCoeffDomain(values, valuesHave, plaintext.Scale) + verifyTestVectors(tc.params, tc.encoder, nil, values, valuesHave, nil, t) - valuesHave = tc.encoder.DecodePublic(plaintext, tc.params.LogSlots(), &distribution.DiscreteGaussian{Sigma: sigma, Bound: 2.5066282746310002 * sigma}) + for i := range valuesHave { + valuesHave[i].Sub(valuesHave[i], values[i][0]) + } - verifyTestVectors(tc.params, tc.encoder, nil, values, valuesHave, tc.params.LogSlots(), nil, t) + // This should make it lose at most ~0.5 bit or precision. + sigma := StandardDeviation(valuesHave, rlwe.NewScale(plaintext.Scale.Float64()/math.Sqrt(float64(len(values))))) + + tc.encoder.DecodePublic(plaintext, valuesHave, &distribution.DiscreteGaussian{Sigma: sigma, Bound: 2.5066282746310002 * sigma}) + + verifyTestVectors(tc.params, tc.encoder, nil, values, valuesHave, nil, t) }) } @@ -957,15 +906,15 @@ func testBridge(tc *testContext, t *testing.T) { switcher.RealToComplex(evalStandar, ctCI, stdCTHave) - verifyTestVectors(stdParams, stdEncoder, stdDecryptor, values, stdCTHave, stdParams.LogSlots(), nil, t) + verifyTestVectors(stdParams, stdEncoder, stdDecryptor, values, stdCTHave, nil, t) - stdCTImag := stdEvaluator.MultByConstNew(stdCTHave, 1i) + stdCTImag := stdEvaluator.MulNew(stdCTHave, 1i) stdEvaluator.Add(stdCTHave, stdCTImag, stdCTHave) ciCTHave := NewCiphertext(ciParams, 1, stdCTHave.Level()) switcher.ComplexToReal(evalStandar, stdCTHave, ciCTHave) - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciCTHave, ciParams.LogSlots(), nil, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciCTHave, nil, t) }) } @@ -973,9 +922,13 @@ func testLinearTransform(tc *testContext, t *testing.T) { t.Run(GetTestName(tc.params, "Average"), func(t *testing.T) { + values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + + slots := ciphertext.Slots() + logBatch := 9 batch := 1 << logBatch - n := tc.params.Slots() / batch + n := slots / batch evk := rlwe.NewEvaluationKeySet() for _, galEl := range tc.params.GaloisElementsForInnerSum(batch, n) { @@ -984,60 +937,79 @@ func testLinearTransform(tc *testContext, t *testing.T) { eval := tc.evaluator.WithKey(evk) - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + eval.Average(ciphertext, logBatch, ciphertext) - eval.Average(ciphertext1, logBatch, ciphertext1) + tmp0 := make([]*bignum.Complex, len(values)) + for i := range tmp0 { + tmp0[i] = values[i].Copy() + } - tmp0 := make([]complex128, len(values1)) - copy(tmp0, values1) + rotatebignumslice := func(s []*bignum.Complex, k int) []*bignum.Complex { + if k == 0 || len(s) == 0 { + return s + } + r := k % len(s) + if r < 0 { + r = r + len(s) + } + return append(s[r:], s[:r]...) + } for i := 1; i < n; i++ { tmp1 := utils.RotateSlice(tmp0, i*batch) - for j := range values1 { - values1[j] += tmp1[j] + for j := range values { + values[j].Add(values[j], tmp1[j]) } } - for i := range values1 { - values1[i] /= complex(float64(n), 0) + nB := new(big.Float).SetFloat64(float64(n)) + + for i := range values { + values[i][0].Quo(values[i][0], nB) + values[i][1].Quo(values[i][1], nB) } - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, nil, t) }) t.Run(GetTestName(tc.params, "LinearTransform/BSGS"), func(t *testing.T) { params := tc.params - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - diagMatrix := make(map[int][]complex128) + slots := ciphertext.Slots() - diagMatrix[-15] = make([]complex128, params.Slots()) - diagMatrix[-4] = make([]complex128, params.Slots()) - diagMatrix[-1] = make([]complex128, params.Slots()) - diagMatrix[0] = make([]complex128, params.Slots()) - diagMatrix[1] = make([]complex128, params.Slots()) - diagMatrix[2] = make([]complex128, params.Slots()) - diagMatrix[3] = make([]complex128, params.Slots()) - diagMatrix[4] = make([]complex128, params.Slots()) - diagMatrix[15] = make([]complex128, params.Slots()) + diagMatrix := make(map[int][]*bignum.Complex) - for i := 0; i < params.Slots(); i++ { - diagMatrix[-15][i] = 1 - diagMatrix[-4][i] = 1 - diagMatrix[-1][i] = 1 - diagMatrix[0][i] = 1 - diagMatrix[1][i] = 1 - diagMatrix[2][i] = 1 - diagMatrix[3][i] = 1 - diagMatrix[4][i] = 1 - diagMatrix[15][i] = 1 + diagMatrix[-15] = make([]*bignum.Complex, slots) + diagMatrix[-4] = make([]*bignum.Complex, slots) + diagMatrix[-1] = make([]*bignum.Complex, slots) + diagMatrix[0] = make([]*bignum.Complex, slots) + diagMatrix[1] = make([]*bignum.Complex, slots) + diagMatrix[2] = make([]*bignum.Complex, slots) + diagMatrix[3] = make([]*bignum.Complex, slots) + diagMatrix[4] = make([]*bignum.Complex, slots) + diagMatrix[15] = make([]*bignum.Complex, slots) + + one := new(big.Float).SetInt64(1) + zero := new(big.Float) + + for i := 0; i < slots; i++ { + diagMatrix[-15][i] = &bignum.Complex{one, zero} + diagMatrix[-4][i] = &bignum.Complex{one, zero} + diagMatrix[-1][i] = &bignum.Complex{one, zero} + diagMatrix[0][i] = &bignum.Complex{one, zero} + diagMatrix[1][i] = &bignum.Complex{one, zero} + diagMatrix[2][i] = &bignum.Complex{one, zero} + diagMatrix[3][i] = &bignum.Complex{one, zero} + diagMatrix[4][i] = &bignum.Complex{one, zero} + diagMatrix[15][i] = &bignum.Complex{one, zero} } - linTransf := GenLinearTransformBSGS(tc.encoder, diagMatrix, params.MaxLevel(), rlwe.NewScale(params.Q()[params.MaxLevel()]), 2.0, params.logSlots) + linTransf := GenLinearTransformBSGS(tc.encoder, diagMatrix, params.MaxLevel(), rlwe.NewScale(params.Q()[params.MaxLevel()]), 2.0, ciphertext.LogSlots) evk := rlwe.NewEvaluationKeySet() for _, galEl := range linTransf.GaloisElements(params) { @@ -1046,42 +1018,49 @@ func testLinearTransform(tc *testContext, t *testing.T) { eval := tc.evaluator.WithKey(evk) - eval.LinearTransform(ciphertext1, linTransf, []*rlwe.Ciphertext{ciphertext1}) + eval.LinearTransform(ciphertext, linTransf, []*rlwe.Ciphertext{ciphertext}) - tmp := make([]complex128, params.Slots()) - copy(tmp, values1) - - for i := 0; i < params.Slots(); i++ { - values1[i] += tmp[(i-15+params.Slots())%params.Slots()] - values1[i] += tmp[(i-4+params.Slots())%params.Slots()] - values1[i] += tmp[(i-1+params.Slots())%params.Slots()] - values1[i] += tmp[(i+1)%params.Slots()] - values1[i] += tmp[(i+2)%params.Slots()] - values1[i] += tmp[(i+3)%params.Slots()] - values1[i] += tmp[(i+4)%params.Slots()] - values1[i] += tmp[(i+15)%params.Slots()] + tmp := make([]*bignum.Complex, len(values)) + for i := range tmp { + tmp[i] = values[i].Copy() } - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) + for i := 0; i < slots; i++ { + values[i].Add(values[i], tmp[(i-15+slots)%slots]) + values[i].Add(values[i], tmp[(i-4+slots)%slots]) + values[i].Add(values[i], tmp[(i-1+slots)%slots]) + values[i].Add(values[i], tmp[(i+1)%slots]) + values[i].Add(values[i], tmp[(i+2)%slots]) + values[i].Add(values[i], tmp[(i+3)%slots]) + values[i].Add(values[i], tmp[(i+4)%slots]) + values[i].Add(values[i], tmp[(i+15)%slots]) + } + + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, nil, t) }) t.Run(GetTestName(tc.params, "LinearTransform/Naive"), func(t *testing.T) { params := tc.params - values1, _, ciphertext1 := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) + values, _, ciphertext := newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t) - diagMatrix := make(map[int][]complex128) + slots := ciphertext.Slots() - diagMatrix[-1] = make([]complex128, params.Slots()) - diagMatrix[0] = make([]complex128, params.Slots()) + diagMatrix := make(map[int][]*bignum.Complex) - for i := 0; i < params.Slots(); i++ { - diagMatrix[-1][i] = 1 - diagMatrix[0][i] = 1 + diagMatrix[-1] = make([]*bignum.Complex, slots) + diagMatrix[0] = make([]*bignum.Complex, slots) + + one := new(big.Float).SetInt64(1) + zero := new(big.Float) + + for i := 0; i < slots; i++ { + diagMatrix[-1][i] = &bignum.Complex{one, zero} + diagMatrix[0][i] = &bignum.Complex{one, zero} } - linTransf := GenLinearTransform(tc.encoder, diagMatrix, params.MaxLevel(), rlwe.NewScale(params.Q()[params.MaxLevel()]), params.LogSlots()) + linTransf := GenLinearTransform(tc.encoder, diagMatrix, params.MaxLevel(), rlwe.NewScale(params.Q()[params.MaxLevel()]), ciphertext.LogSlots) evk := rlwe.NewEvaluationKeySet() for _, galEl := range linTransf.GaloisElements(params) { @@ -1090,16 +1069,18 @@ func testLinearTransform(tc *testContext, t *testing.T) { eval := tc.evaluator.WithKey(evk) - eval.LinearTransform(ciphertext1, linTransf, []*rlwe.Ciphertext{ciphertext1}) + eval.LinearTransform(ciphertext, linTransf, []*rlwe.Ciphertext{ciphertext}) - tmp := make([]complex128, params.Slots()) - copy(tmp, values1) - - for i := 0; i < params.Slots(); i++ { - values1[i] += tmp[(i-1+params.Slots())%params.Slots()] + tmp := make([]*bignum.Complex, slots) + for i := range tmp { + tmp[i] = values[i].Copy() } - verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogSlots(), nil, t) + for i := 0; i < slots; i++ { + values[i].Add(values[i], tmp[(i-1+slots)%slots]) + } + + verifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, nil, t) }) } diff --git a/ckks/ckks_vector_ops.go b/ckks/ckks_vector_ops.go index 14fdf9cf..0b1dba3e 100644 --- a/ckks/ckks_vector_ops.go +++ b/ckks/ckks_vector_ops.go @@ -1,6 +1,7 @@ package ckks import ( + "math/big" "fmt" "math/bits" "unsafe" @@ -8,17 +9,14 @@ import ( "github.com/tuneinsight/lattigo/v4/utils" ) -const ( - minVecLenForLoopUnrolling = 16 + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) -// SpecialiFFTVec performs the CKKS special inverse FFT transform in place. -func SpecialiFFTVec(values []complex128, N, M int, rotGroup []int, roots []complex128) { - if len(values) < N || len(rotGroup) < N || len(roots) < M+1 { panic(fmt.Sprintf("invalid call of SpecialiFFTVec: len(values)=%d or len(rotGroup)=%d < N=%d or len(roots)=%d < M+1=%d", len(values), len(rotGroup), N, len(roots), M)) } - +// SpecialIFFTDouble performs the CKKS special inverse FFT transform in place. +func SpecialIFFTDouble(values []complex128, N, M int, rotGroup []int, roots []complex128) { logN := int(bits.Len64(uint64(N))) - 1 logM := int(bits.Len64(uint64(M))) - 1 for loglen := logN; loglen > 0; loglen-- { @@ -41,9 +39,8 @@ func SpecialiFFTVec(values []complex128, N, M int, rotGroup []int, roots []compl utils.BitReverseInPlaceSlice(values, N) } -// SpecialFFTVec performs the CKKS special FFT transform in place. -func SpecialFFTVec(values []complex128, N, M int, rotGroup []int, roots []complex128) { - +// SpecialFFTDouble performs the CKKS special FFT transform in place. +func SpecialFFTDouble(values []complex128, N, M int, rotGroup []int, roots []complex128) { if len(values) < N || len(rotGroup) < N || len(roots) < M+1 { panic(fmt.Sprintf("invalid call of SpecialFFTVec: len(values)=%d or len(rotGroup)=%d < N=%d or len(roots)=%d < M+1=%d", len(values), len(rotGroup), N, len(roots), M)) } @@ -66,8 +63,75 @@ func SpecialFFTVec(values []complex128, N, M int, rotGroup []int, roots []comple } } -// SpecialFFTUL8Vec performs the CKKS special FFT transform in place with unrolled loops of size 8. -func SpecialFFTUL8Vec(values []complex128, N, M int, rotGroup []int, roots []complex128) { +// SpecialFFTArbitrary evaluates the decoding matrix on a slice of ring.Complex values. +func SpecialFFTArbitrary(values []*bignum.Complex, N, M int, rotGroup []int, roots []*bignum.Complex) { + + u := &bignum.Complex{new(big.Float), new(big.Float)} + v := &bignum.Complex{new(big.Float), new(big.Float)} + + SliceBitReverseInPlaceRingComplex(values, N) + + cMul := bignum.NewComplexMultiplier() + + logN := int(bits.Len64(uint64(N))) - 1 + logM := int(bits.Len64(uint64(M))) - 1 + for loglen := 1; loglen <= logN; loglen++ { + len := 1 << loglen + lenh := len >> 1 + lenq := len << 2 + logGap := logM - 2 - loglen + mask := lenq - 1 + for i := 0; i < N; i += len { + for j, k := 0, i; j < lenh; j, k = j+1, k+1 { + u.Set(values[i+j]) + v.Set(values[i+j+lenh]) + cMul.Mul(v, roots[(rotGroup[j]&mask)< 0; loglen-- { + len := 1 << loglen + lenh := len >> 1 + lenq := len << 2 + logGap := logM - 2 - loglen + mask := lenq - 1 + for i := 0; i < N; i += len { + for j, k := 0, i; j < lenh; j, k = j+1, k+1 { + u.Add(values[i+j], values[i+j+lenh]) + v.Sub(values[i+j], values[i+j+lenh]) + cMul.Mul(v, roots[(lenq-(rotGroup[j]&mask))<[]float64 ---------> Plaintext -// | -// Complex^{N/2} | -// EncodeSlots: []complex128/[]float64 -> iDFT ---┘ -type Encoder interface { +// Z_Q[X]/(X^N+1) +// Coefficients: ---------------> Real^{N} ---------> Plaintext +// | +// | +// Slots: Complex^{N/2} -> iDFT -----┘ +type Encoder struct { + prec uint - // Slots Encoding - Encode(values interface{}, plaintext *rlwe.Plaintext, logSlots int) - EncodeNew(values interface{}, level int, scale rlwe.Scale, logSlots int) (plaintext *rlwe.Plaintext) - EncodeSlots(values interface{}, plaintext *rlwe.Plaintext, logSlots int) - EncodeSlotsNew(values interface{}, level int, scale rlwe.Scale, logSlots int) (plaintext *rlwe.Plaintext) - Decode(plaintext *rlwe.Plaintext, logSlots int) (res []complex128) - DecodeSlots(plaintext *rlwe.Plaintext, logSlots int) (res []complex128) - DecodePublic(plaintext *rlwe.Plaintext, logSlots int, noise distribution.Distribution) []complex128 - DecodeSlotsPublic(plaintext *rlwe.Plaintext, logSlots int, noise distribution.Distribution) []complex128 - - FFT(values []complex128, N int) - IFFT(values []complex128, N int) - - // Coeffs Encoding - EncodeCoeffs(values []float64, plaintext *rlwe.Plaintext) - EncodeCoeffsNew(values []float64, level int, scale rlwe.Scale) (plaintext *rlwe.Plaintext) - DecodeCoeffs(plaintext *rlwe.Plaintext) (res []float64) - DecodeCoeffsPublic(plaintext *rlwe.Plaintext, noise distribution.Distribution) (res []float64) - - // Utility - Embed(values interface{}, logSlots int, scale rlwe.Scale, montgomery bool, polyOut interface{}) - GetErrSTDCoeffDomain(valuesWant, valuesHave []complex128, scale rlwe.Scale) (std float64) - GetErrSTDSlotDomain(valuesWant, valuesHave []complex128, scale rlwe.Scale) (std float64) - ShallowCopy() Encoder - Parameters() Parameters -} - -// encoder is a struct storing the necessary parameters to encode a slice of complex number on a Plaintext. -type encoder struct { params Parameters bigintCoeffs []*big.Int qHalf *big.Int @@ -79,43 +51,52 @@ type encoder struct { m int rotGroup []int - prng sampling.PRNG + prng utils.PRNG + + roots interface{} + buffCmplx interface{} } -type encoderComplex128 struct { - encoder - values []complex128 - valuesFloat []float64 - roots []complex128 -} - -// ShallowCopy creates a shallow copy of encoder in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// encoder can be used concurrently. -func (ecd *encoder) ShallowCopy() *encoder { +func (ecd *Encoder) ShallowCopy() *Encoder { prng, err := sampling.NewPRNG() if err != nil { panic(err) } - return &encoder{ + var buffCmplx interface{} + + if prec := ecd.prec; prec <= 53 { + buffCmplx = make([]complex128, ecd.m>>1) + } else { + tmp := make([]*bignum.Complex, ecd.m>>2) + + for i := 0; i < ecd.m>>2; i++ { + tmp[i] = &bignum.Complex{bignum.NewFloat(0, prec), bignum.NewFloat(0, prec)} + } + + buffCmplx = tmp + } + + return &Encoder{ + prec: ecd.prec, params: ecd.params, - bigintCoeffs: make([]*big.Int, ecd.m>>1), - qHalf: ring.NewUint(0), - buff: ecd.params.RingQ().NewPoly(), + bigintCoeffs: make([]*big.Int, len(ecd.bigintCoeffs)), + qHalf: new(big.Int), + buff: ecd.buff.CopyNew(), m: ecd.m, rotGroup: ecd.rotGroup, prng: prng, + roots: ecd.roots, + buffCmplx: buffCmplx, } } -// Parameters returns the parameters used by the encoder. -func (ecd *encoder) Parameters() Parameters { - return ecd.params -} - -func newEncoder(params Parameters) encoder { +// NewEncoder creates a new Encoder from the target parameters. +// Optional field `precision` can be given. If precision is empty +// or <= 53, then float64 and complex128 types will be used to +// perform the encoding. Else *big.Float and *bignum.Complex will be used. +func NewEncoder(params Parameters, precision ...uint) (ecd *Encoder) { m := int(params.RingQ().NthRoot()) @@ -132,192 +113,121 @@ func newEncoder(params Parameters) encoder { panic(err) } - return encoder{ + var prec uint + if len(precision) != 0 && precision[0] != 0 { + prec = precision[0] + } else { + prec = params.DefaultPrecision() + } + + ecd = &Encoder{ + prec: prec, params: params, bigintCoeffs: make([]*big.Int, m>>1), - qHalf: ring.NewUint(0), + qHalf: bignum.NewInt(0), buff: params.RingQ().NewPoly(), m: m, rotGroup: rotGroup, prng: prng, } -} -// NewEncoder creates a new Encoder that is used to encode a slice of complex values of size at most N/2 (the number of slots) on a Plaintext. -func NewEncoder(params Parameters) Encoder { + if prec <= 53 { - ecd := newEncoder(params) + ecd.roots = GetRootsFloat64(ecd.m) + ecd.buffCmplx = make([]complex128, ecd.m>>2) - return &encoderComplex128{ - encoder: ecd, - roots: GetRootsFloat64(ecd.m), - values: make([]complex128, ecd.m>>2), - valuesFloat: make([]float64, ecd.m>>1), + } else { + + tmp := make([]*bignum.Complex, ecd.m>>2) + + for i := 0; i < ecd.m>>2; i++ { + tmp[i] = &bignum.Complex{bignum.NewFloat(0, prec), bignum.NewFloat(0, prec)} + } + + ecd.roots = GetRootsbigFloat(ecd.m, prec) + ecd.buffCmplx = tmp } -} -// Encode encodes a set of values on the target plaintext. -// This method is identical to "EncodeSlots". -// Encoding is done at the level and scale of the plaintext. -// User must ensure that 1 <= len(values) <= 2^logSlots < 2^logN and that logSlots >= 3. -// values.(type) can be either []complex128 of []float64. -// The imaginary part of []complex128 will be discarded if ringType == ring.ConjugateInvariant. -// Returned plaintext is always in the NTT domain. -func (ecd *encoderComplex128) Encode(values interface{}, plaintext *rlwe.Plaintext, logSlots int) { - ecd.Embed(values, logSlots, plaintext.Scale, false, plaintext.Value) -} - -// EncodeNew encodes a set of values on a new plaintext. -// This method is identical to "EncodeSlotsNew". -// Encoding is done at the provided level and with the provided scale. -// User must ensure that 1 <= len(values) <= 2^logSlots < 2^logN and that logSlots >= 3. -// values.(type) can be either []complex128 of []float64. -// The imaginary part of []complex128 will be discarded if ringType == ring.ConjugateInvariant. -// Returned plaintext is always in the NTT domain. -func (ecd *encoderComplex128) EncodeNew(values interface{}, level int, scale rlwe.Scale, logSlots int) (plaintext *rlwe.Plaintext) { - plaintext = NewPlaintext(ecd.params, level) - plaintext.Scale = scale - ecd.Encode(values, plaintext, logSlots) return } -// EncodeSlots encodes a set of values on the target plaintext. -// Encoding is done at the level and scale of the plaintext. -// User must ensure that 1 <= len(values) <= 2^logSlots < 2^logN and that logSlots >= 3. -// values.(type) can be either []complex128 of []float64. -// The imaginary part of []complex128 will be discarded if ringType == ring.ConjugateInvariant. -// Returned plaintext is always in the NTT domain. -func (ecd *encoderComplex128) EncodeSlots(values interface{}, plaintext *rlwe.Plaintext, logSlots int) { - ecd.Encode(values, plaintext, logSlots) +// Prec returns the precision in bits used by the target Encoder. +// A precision <= 53 will use float64, else *big.Float. +func (ecd *Encoder) Prec() uint { + return ecd.prec } -// EncodeSlotsNew encodes a set of values on a new plaintext. -// Encoding is done at the provided level and with the provided scale. -// User must ensure that 1 <= len(values) <= 2^logSlots < 2^logN and that logSlots >= 3. -// values.(type) can be either []complex128 of []float64. +// Parameters returns the Parameters used by the target Encoder. +func (ecd *Encoder) Parameters() Parameters { + return ecd.params +} + +// Encode encodes a set of values on the target plaintext. +// Encoding is done at the level and scale of the plaintext. +// Encoding domain is done according to the metadata of the plaintext. +// User must ensure that 1 <= len(values) <= 2^pt.LogSlots < 2^logN. +// Accepted values.(type) for `rlwe.EncodingDomain = rlwe.SlotsDomain` is []complex128 of []float64. +// Accepted values.(type) for `rlwe.EncodingDomain = rlwe.CoefficientDomain` is []float64. // The imaginary part of []complex128 will be discarded if ringType == ring.ConjugateInvariant. -// Returned plaintext is always in the NTT domain. -func (ecd *encoderComplex128) EncodeSlotsNew(values interface{}, level int, scale rlwe.Scale, logSlots int) (plaintext *rlwe.Plaintext) { - return ecd.EncodeNew(values, level, scale, logSlots) +func (ecd *Encoder) Encode(values interface{}, pt *rlwe.Plaintext) (err error) { + + switch pt.EncodingDomain { + case rlwe.SlotsDomain: + + return ecd.Embed(values, pt.LogSlots, pt.Scale, false, pt.Value) + + case rlwe.CoefficientsDomain: + + switch values := values.(type) { + case []float64: + + if len(values) > ecd.params.N() { + return fmt.Errorf("cannot Encode: maximum number of values is %d but len(values) is %d", ecd.params.N(), len(values)) + } + + Float64ToFixedPointCRT(ecd.params.RingQ().AtLevel(pt.Level()), values, pt.Scale.Float64(), pt.Value.Coeffs) + + case []*big.Float: + + if len(values) > ecd.params.N() { + return fmt.Errorf("cannot Encode: maximum number of values is %d but len(values) is %d", ecd.params.N(), len(values)) + } + + BigFloatToFixedPointCRT(ecd.params.RingQ().AtLevel(pt.Level()), values, &pt.Scale.Value, pt.Value.Coeffs) + + default: + return fmt.Errorf("cannot Encode: supported values.(type) for %T encoding domain is []float64 or []*big.Float, but %T was given", rlwe.CoefficientsDomain, values) + } + + ecd.params.RingQ().AtLevel(pt.Level()).NTT(pt.Value, pt.Value) + + default: + return fmt.Errorf("cannot Encode: invalid rlwe.EncodingType, accepted types are rlwe.SlotsDomain and rlwe.CoefficientsDomain but is %T", pt.EncodingDomain) + } + + return } // Decode decodes the input plaintext on a new slice of complex128. // This method is the same as .DecodeSlots(*). -func (ecd *encoderComplex128) Decode(plaintext *rlwe.Plaintext, logSlots int) (res []complex128) { - return ecd.DecodeSlotsPublic(plaintext, logSlots, nil) -} - -// DecodeSlots decodes the input plaintext on a new slice of complex128. -func (ecd *encoderComplex128) DecodeSlots(plaintext *rlwe.Plaintext, logSlots int) (res []complex128) { - return ecd.decodePublic(plaintext, logSlots, nil) +func (ecd *Encoder) Decode(pt *rlwe.Plaintext, values interface{}) (err error) { + return ecd.DecodePublic(pt, values, nil) } // DecodePublic decodes the input plaintext on a new slice of complex128. -// This method is the same as .DecodeSlotsPublic(*). // Adds, before the decoding step, noise following the given distribution. // If the underlying ringType is ConjugateInvariant, the imaginary part (and its related error) are zero. -func (ecd *encoderComplex128) DecodePublic(plaintext *rlwe.Plaintext, logSlots int, noise distribution.Distribution) (res []complex128) { - return ecd.DecodeSlotsPublic(plaintext, logSlots, noise) -} - -// DecodeSlotsPublic decodes the input plaintext on a new slice of complex128. -// Adds, before the decoding step, noise following the given distribution. -// If the underlying ringType is ConjugateInvariant, the imaginary part (and its related error) are zero. -func (ecd *encoderComplex128) DecodeSlotsPublic(plaintext *rlwe.Plaintext, logSlots int, noise distribution.Distribution) (res []complex128) { - return ecd.decodePublic(plaintext, logSlots, noise) -} - -// EncodeCoeffs encodes the values on the coefficient of the plaintext polynomial. -// Encoding is done at the level and scale of the plaintext. -// User must ensure that 1<= len(values) <= 2^LogN -func (ecd *encoderComplex128) EncodeCoeffs(values []float64, plaintext *rlwe.Plaintext) { - - if len(values) > ecd.params.N() { - panic("cannot EncodeCoeffs: too many values (maximum is N)") - } - - FloatToFixedPointCRT(ecd.params.RingQ().AtLevel(plaintext.Level()), values, plaintext.Scale.Float64(), plaintext.Value.Coeffs) - ecd.params.RingQ().AtLevel(plaintext.Level()).NTT(plaintext.Value, plaintext.Value) -} - -// EncodeCoeffsNew encodes the values on the coefficient of a new plaintext. -// Encoding is done at the provided level and with the provided scale. -// User must ensure that 1<= len(values) <= 2^LogN -func (ecd *encoderComplex128) EncodeCoeffsNew(values []float64, level int, scale rlwe.Scale) (plaintext *rlwe.Plaintext) { - plaintext = NewPlaintext(ecd.params, level) - plaintext.Scale = scale - ecd.EncodeCoeffs(values, plaintext) - return -} - -// DecodeCoeffs reconstructs the RNS coefficients of the plaintext on a slice of float64. -func (ecd *encoderComplex128) DecodeCoeffs(plaintext *rlwe.Plaintext) (res []float64) { - return ecd.decodeCoeffsPublic(plaintext, nil) -} - -// DecodeCoeffsPublic reconstructs the RNS coefficients of the plaintext on a slice of float64. -// Adds noise following the given distribution to the decoding output. -func (ecd *encoderComplex128) DecodeCoeffsPublic(plaintext *rlwe.Plaintext, noise distribution.Distribution) (res []float64) { - return ecd.decodeCoeffsPublic(plaintext, noise) -} - -// GetErrSTDCoeffDomain returns StandardDeviation(Encode(valuesWant-valuesHave))*scale -// which is the scaled standard deviation in the coefficient domain of the difference -// of two complex vector in the slot domain. -func (ecd *encoderComplex128) GetErrSTDCoeffDomain(valuesWant, valuesHave []complex128, scale rlwe.Scale) (std float64) { - - for i := range valuesHave { - ecd.values[i] = (valuesWant[i] - valuesHave[i]) - } - - for i := len(valuesHave); i < len(ecd.values); i++ { - ecd.values[i] = complex(0, 0) - } - - logSlots := bits.Len64(uint64(len(valuesHave) - 1)) - - ecd.IFFT(ecd.values, logSlots) - - for i := range valuesWant { - ecd.valuesFloat[2*i] = real(ecd.values[i]) - ecd.valuesFloat[2*i+1] = imag(ecd.values[i]) - } - - return StandardDeviation(ecd.valuesFloat[:len(valuesWant)*2], scale.Float64()) -} - -// GetErrSTDSlotDomain returns StandardDeviation(valuesWant-valuesHave)*scale -// which is the scaled standard deviation of two complex vectors. -func (ecd *encoderComplex128) GetErrSTDSlotDomain(valuesWant, valuesHave []complex128, scale rlwe.Scale) (std float64) { - var err complex128 - for i := range valuesWant { - err = valuesWant[i] - valuesHave[i] - ecd.valuesFloat[2*i] = real(err) - ecd.valuesFloat[2*i+1] = imag(err) - } - - return StandardDeviation(ecd.valuesFloat[:len(valuesWant)*2], scale.Float64()) -} - -// ShallowCopy creates a shallow copy of this encoderComplex128 in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// Encoder can be used concurrently. -func (ecd *encoderComplex128) ShallowCopy() Encoder { - return &encoderComplex128{ - encoder: *ecd.encoder.ShallowCopy(), - values: make([]complex128, len(ecd.values)), - valuesFloat: make([]float64, len(ecd.valuesFloat)), - roots: ecd.roots, - } +func (ecd *Encoder) DecodePublic(pt *rlwe.Plaintext, values interface{}, noise distribution.Distribution) (err error) { + return ecd.decodePublic(pt, values, noise) } // Embed is a generic method to encode a set of values on the target polyOut interface. // This method it as the core of the slot encoding. -// values: values.(type) can be either []complex128 of []float64. +// values: values.(type) can be either []complex128, []*bignum.Complex, []float64 or []*big.Float. // -// The imaginary part of []complex128 will be discarded if ringType == ring.ConjugateInvariant. +// The imaginary part of []complex128 or []*bignum.Complex will be discarded if ringType == ring.ConjugateInvariant. // -// logslots: user must ensure that 1 <= len(values) <= 2^logSlots < 2^logN and that logSlots >= 3. +// logslots: user must ensure that 1 <= len(values) <= 2^logSlots < 2^logN. // scale: the scaling factor used do discretize float64 to fixed point integers. // montgomery: if true then the value written on polyOut are put in the Montgomery domain. // polyOut: polyOut.(type) can be either ringqp.Poly or *ring.Poly. @@ -325,113 +235,583 @@ func (ecd *encoderComplex128) ShallowCopy() Encoder { // The encoding encoding is done at the level of polyOut. // // Values written on polyOut are always in the NTT domain. -func (ecd *encoderComplex128) Embed(values interface{}, logSlots int, scale rlwe.Scale, montgomery bool, polyOut interface{}) { +func (ecd *Encoder) Embed(values interface{}, logSlots int, scale rlwe.Scale, montgomery bool, polyOut interface{}) (err error) { + if ecd.prec <= 53 { + return ecd.embedDouble(values, logSlots, scale, montgomery, polyOut) + } - if logSlots < minLogSlots || logSlots > ecd.params.MaxLogSlots() { - panic(fmt.Sprintf("cannot Embed: logSlots (%d) must be greater or equal to %d and smaller than %d\n", logSlots, minLogSlots, ecd.params.MaxLogSlots())) + return ecd.embedArbitrary(values, logSlots, scale, montgomery, polyOut) +} + +func (ecd *Encoder) embedDouble(values interface{}, logSlots int, scale rlwe.Scale, montgomery bool, polyOut interface{}) (err error) { + + if logSlots < 0 || logSlots > ecd.params.MaxLogSlots() { + return fmt.Errorf("cannot Embed: logSlots (%d) must be greater or equal to %d and smaller than %d", logSlots, 0, ecd.params.MaxLogSlots()) } slots := 1 << logSlots var lenValues int - // First checks the type of input values + buffCmplx := ecd.buffCmplx.([]complex128) + switch values := values.(type) { - // If complex case []complex128: - // Checks that the number of values is with the possible range - if len(values) > ecd.params.MaxSlots() || len(values) > slots { - panic(fmt.Sprintf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)\n", len(values), slots, ecd.params.MaxSlots())) - } lenValues = len(values) - switch ecd.params.RingType() { + if lenValues > ecd.params.MaxSlots() || lenValues > slots { + return fmt.Errorf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)", len(values), slots, ecd.params.MaxSlots()) + } - case ring.Standard: - copy(ecd.values[:len(values)], values) - - case ring.ConjugateInvariant: - // Discards the imaginary part + if ecd.params.RingType() == ring.ConjugateInvariant { for i := range values { - ecd.values[i] = complex(real(values[i]), 0) + buffCmplx[i] = complex(real(values[i]), 0) } - - // Else panics - default: - panic("cannot Embed: ringType must be ring.Standard or ring.ConjugateInvariant") + } else { + copy(buffCmplx[:len(values)], values) } - // If floats only - case []float64: - if len(values) > ecd.params.MaxSlots() || len(values) > slots { - panic(fmt.Sprintf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)\n", len(values), slots, ecd.params.MaxSlots())) - } + case []*bignum.Complex: lenValues = len(values) + if lenValues > ecd.params.MaxSlots() || lenValues > slots { + return fmt.Errorf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)", len(values), slots, ecd.params.MaxSlots()) + } + + if ecd.params.RingType() == ring.ConjugateInvariant { + for i := range values { + if values[i] != nil { + f64, _ := values[i][0].Float64() + buffCmplx[i] = complex(f64, 0) + } else { + buffCmplx[i] = 0 + } + } + } else { + for i := range values { + if values[i] != nil { + buffCmplx[i] = values[i].Complex128() + } else { + buffCmplx[i] = 0 + } + } + } + + case []float64: + + lenValues = len(values) + + if lenValues > ecd.params.MaxSlots() || lenValues > slots { + return fmt.Errorf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)", len(values), slots, ecd.params.MaxSlots()) + } + for i := range values { - ecd.values[i] = complex(values[i], 0) + buffCmplx[i] = complex(values[i], 0) } + case []*big.Float: + + lenValues = len(values) + + if lenValues > ecd.params.MaxSlots() || lenValues > slots { + return fmt.Errorf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)", len(values), slots, ecd.params.MaxSlots()) + } + + for i := range values { + if values[i] != nil { + f64, _ := values[i].Float64() + buffCmplx[i] = complex(f64, 0) + } else { + buffCmplx[i] = 0 + } + } default: - panic("cannot Embed: values.(Type) must be []complex128 or []float64") + return fmt.Errorf("cannot Embed: values.(Type) must be []complex128, []*bignum.Complex, []float64 or []*big.Float, but is %T", values) } + // Zeroes all other values for i := lenValues; i < slots; i++ { - ecd.values[i] = 0 + buffCmplx[i] = 0 } - ecd.IFFT(ecd.values, logSlots) + // IFFT + ecd.IFFT(buffCmplx[:slots], logSlots) + // Maps Y = X^{N/n} -> X and quantizes. switch p := polyOut.(type) { case ringqp.Poly: - ComplexToFixedPointCRT(ecd.params.RingQ().AtLevel(p.Q.Level()), ecd.values[:slots], scale.Float64(), p.Q.Coeffs) + Complex128ToFixedPointCRT(ecd.params.RingQ().AtLevel(p.Q.Level()), buffCmplx[:slots], scale.Float64(), p.Q.Coeffs) NttSparseAndMontgomery(ecd.params.RingQ().AtLevel(p.Q.Level()), logSlots, montgomery, p.Q) if p.P != nil { - ComplexToFixedPointCRT(ecd.params.RingP().AtLevel(p.P.Level()), ecd.values[:slots], scale.Float64(), p.P.Coeffs) + Complex128ToFixedPointCRT(ecd.params.RingP().AtLevel(p.P.Level()), buffCmplx[:slots], scale.Float64(), p.P.Coeffs) NttSparseAndMontgomery(ecd.params.RingP().AtLevel(p.P.Level()), logSlots, montgomery, p.P) } case *ring.Poly: - ComplexToFixedPointCRT(ecd.params.RingQ().AtLevel(p.Level()), ecd.values[:slots], scale.Float64(), p.Coeffs) + Complex128ToFixedPointCRT(ecd.params.RingQ().AtLevel(p.Level()), buffCmplx[:slots], scale.Float64(), p.Coeffs) NttSparseAndMontgomery(ecd.params.RingQ().AtLevel(p.Level()), logSlots, montgomery, p) default: - panic("cannot Embed: invalid polyOut.(Type) must be ringqp.Poly or *ring.Poly") + return fmt.Errorf("cannot Embed: invalid polyOut.(Type) must be ringqp.Poly or *ring.Poly") + } + + return +} + +func (ecd *Encoder) embedArbitrary(values interface{}, logSlots int, scale rlwe.Scale, montgomery bool, polyOut interface{}) (err error) { + if logSlots < 0 || logSlots > ecd.params.MaxLogSlots() { + return fmt.Errorf("cannot Embed: logSlots (%d) must be greater or equal to %d and smaller than %d", logSlots, 0, ecd.params.MaxLogSlots()) + } + + slots := 1 << logSlots + var lenValues int + + buffCmplx := ecd.buffCmplx.([]*bignum.Complex) + + switch values := values.(type) { + + case []complex128: + + lenValues = len(values) + + if lenValues > ecd.params.MaxSlots() || lenValues > slots { + return fmt.Errorf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)", len(values), slots, ecd.params.MaxSlots()) + } + + if ecd.params.RingType() == ring.ConjugateInvariant { + for i := range values { + buffCmplx[i][0].SetFloat64(real(values[i])) + buffCmplx[i][1].SetFloat64(0) + } + } else { + for i := range values { + buffCmplx[i][0].SetFloat64(real(values[i])) + buffCmplx[i][1].SetFloat64(imag(values[i])) + } + } + + case []*bignum.Complex: + + lenValues = len(values) + + if lenValues > ecd.params.MaxSlots() || lenValues > slots { + return fmt.Errorf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)", len(values), slots, ecd.params.MaxSlots()) + } + + if ecd.params.RingType() == ring.ConjugateInvariant { + for i := range values { + if values[i] != nil { + buffCmplx[i][0].Set(values[i][0]) + } else { + buffCmplx[i][0].SetFloat64(0) + } + + buffCmplx[i][1].SetFloat64(0) + } + } else { + for i := range values { + if values[i] != nil { + buffCmplx[i].Set(values[i]) + } else { + buffCmplx[i][0].SetFloat64(0) + buffCmplx[i][1].SetFloat64(0) + } + } + } + + case []float64: + + lenValues = len(values) + + if lenValues > ecd.params.MaxSlots() || lenValues > slots { + return fmt.Errorf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)", len(values), slots, ecd.params.MaxSlots()) + } + + for i := range values { + buffCmplx[i][0].SetFloat64(values[i]) + buffCmplx[i][1].SetFloat64(0) + } + + case []*big.Float: + + lenValues = len(values) + + if lenValues > ecd.params.MaxSlots() || lenValues > slots { + return fmt.Errorf("cannot Embed: ensure that #values (%d) <= slots (%d) <= maxSlots (%d)", len(values), slots, ecd.params.MaxSlots()) + } + + for i := range values { + if values[i] != nil { + buffCmplx[i][0].Set(values[i]) + } else { + buffCmplx[i][0].SetFloat64(0) + } + + buffCmplx[i][1].SetFloat64(0) + } + default: + return fmt.Errorf("cannot Embed: values.(Type) must be []complex128, []*bignum.Complex, []float64 or []*big.Float, but is %T", values) + } + + // Zeroes all other values + for i := lenValues; i < slots; i++ { + buffCmplx[i][0].SetFloat64(0) + buffCmplx[i][1].SetFloat64(0) + } + + ecd.IFFT(buffCmplx[:slots], logSlots) + + // Maps Y = X^{N/n} -> X and quantizes. + switch p := polyOut.(type) { + + case *ring.Poly: + + ComplexArbitraryToFixedPointCRT(ecd.params.RingQ().AtLevel(p.Level()), buffCmplx[:slots], &scale.Value, p.Coeffs) + NttSparseAndMontgomery(ecd.params.RingQ().AtLevel(p.Level()), logSlots, montgomery, p) + + case ringqp.Poly: + + ComplexArbitraryToFixedPointCRT(ecd.params.RingQ().AtLevel(p.Q.Level()), buffCmplx[:slots], &scale.Value, p.Q.Coeffs) + NttSparseAndMontgomery(ecd.params.RingQ().AtLevel(p.Q.Level()), logSlots, montgomery, p.Q) + + if p.P != nil { + ComplexArbitraryToFixedPointCRT(ecd.params.RingP().AtLevel(p.P.Level()), buffCmplx[:slots], &scale.Value, p.P.Coeffs) + NttSparseAndMontgomery(ecd.params.RingP().AtLevel(p.P.Level()), logSlots, montgomery, p.P) + } + + default: + return fmt.Errorf("cannot Embed: invalid polyOut.(Type) must be ringqp.Poly or *ring.Poly") + } + + return +} + +func (ecd *Encoder) plaintextToComplex(level int, scale rlwe.Scale, logSlots int, p *ring.Poly, values interface{}) { + + isreal := ecd.params.RingType() == ring.ConjugateInvariant + if level == 0 { + polyToComplexNoCRT(p.Coeffs[0], values, scale, logSlots, isreal, ecd.params.RingQ().AtLevel(level)) + } else { + polyToComplexCRT(p, ecd.bigintCoeffs, values, scale, logSlots, isreal, ecd.params.RingQ().AtLevel(level)) } } -func polyToComplexNoCRT(coeffs []uint64, values []complex128, scale rlwe.Scale, logSlots int, isreal bool, ringQ *ring.Ring) { +func (ecd *Encoder) plaintextToFloat(level int, scale rlwe.Scale, logSlots int, p *ring.Poly, values interface{}) { + if level == 0 { + ecd.polyToFloatNoCRT(p.Coeffs[0], values, scale, logSlots, ecd.params.RingQ().AtLevel(level)) + } else { + ecd.polyToFloatCRT(p, values, scale, logSlots, ecd.params.RingQ().AtLevel(level)) + } +} + +func (ecd *Encoder) decodePublic(pt *rlwe.Plaintext, values interface{}, noise distribution.Distribution) (err error) { + + logSlots := pt.LogSlots + slots := 1 << logSlots + + if logSlots > ecd.params.MaxLogSlots() || logSlots < 0 { + return fmt.Errorf("cannot Decode: ensure that %d <= logSlots (%d) <= %d", 0, logSlots, ecd.params.MaxLogSlots()) + } + + if pt.IsNTT { + ecd.params.RingQ().AtLevel(pt.Level()).INTT(pt.Value, ecd.buff) + } else { + ring.CopyLvl(pt.Level(), pt.Value, ecd.buff) + } + + if noise != nil { + ring.NewSampler(ecd.prng, ecd.params.RingQ(), noise, pt.IsMontgomery).AtLevel(pt.Level()).ReadAndAdd(ecd.buff) + } + + switch values.(type) { + case []complex128, []float64, []*bignum.Complex, []*big.Float: + default: + return fmt.Errorf("cannot decode: values.(type) accepted are []complex128, []float64, []*bignum.Complex, []*big.Float but is %T", values) + } + + switch pt.EncodingDomain { + case rlwe.SlotsDomain: + + if ecd.prec <= 53 { + + buffCmplx := ecd.buffCmplx.([]complex128) + + ecd.plaintextToComplex(pt.Level(), pt.Scale, logSlots, ecd.buff, buffCmplx) + + ecd.FFT(buffCmplx[:slots], logSlots) + + switch values := values.(type) { + case []float64: + + slots := utils.MinInt(len(values), slots) + + for i := 0; i < slots; i++ { + values[i] = real(buffCmplx[i]) + } + case []complex128: + copy(values, buffCmplx) + + case []*big.Float: + slots := utils.MinInt(len(values), slots) + + for i := 0; i < slots; i++ { + + if values[i] == nil { + values[i] = new(big.Float) + } + + values[i].SetFloat64(real(buffCmplx[i])) + } + + case []*bignum.Complex: + + slots := utils.MinInt(len(values), slots) + + for i := 0; i < slots; i++ { + + if values[i] == nil { + values[i] = &bignum.Complex{ + new(big.Float), + new(big.Float), + } + } else { + if values[i][0] == nil { + values[i][0] = new(big.Float) + } + + if values[i][1] == nil { + values[i][1] = new(big.Float) + } + } + + values[i][0].SetFloat64(real(buffCmplx[i])) + values[i][1].SetFloat64(imag(buffCmplx[i])) + } + } + } else { + + buffCmplx := ecd.buffCmplx.([]*bignum.Complex) + + ecd.plaintextToComplex(pt.Level(), pt.Scale, logSlots, ecd.buff, buffCmplx[:slots]) + + ecd.FFT(buffCmplx[:slots], logSlots) + + switch values := values.(type) { + case []float64: + + slots := utils.MinInt(len(values), slots) + + for i := 0; i < slots; i++ { + values[i], _ = buffCmplx[i][0].Float64() + } + + case []complex128: + + slots := utils.MinInt(len(values), slots) + + for i := 0; i < slots; i++ { + values[i] = buffCmplx[i].Complex128() + } + + case []*big.Float: + slots := utils.MinInt(len(values), slots) + + for i := 0; i < slots; i++ { + + if values[i] == nil { + values[i] = new(big.Float) + } + + values[i].Set(buffCmplx[i][0]) + } + + case []*bignum.Complex: + + slots := utils.MinInt(len(values), slots) + + for i := 0; i < slots; i++ { + + if values[i] == nil { + values[i] = &bignum.Complex{ + new(big.Float), + new(big.Float), + } + } else { + if values[i][0] == nil { + values[i][0] = new(big.Float) + } + + if values[i][1] == nil { + values[i][1] = new(big.Float) + } + } + + values[i][0].Set(buffCmplx[i][0]) + values[i][1].Set(buffCmplx[i][1]) + } + } + } + + case rlwe.CoefficientsDomain: + ecd.plaintextToFloat(pt.Level(), pt.Scale, logSlots, ecd.buff, values) + default: + return fmt.Errorf("cannot decode: invalid rlwe.EncodingType, accepted types are rlwe.SlotsDomain and rlwe.CoefficientsDomain but is %T", pt.EncodingDomain) + } + + return +} + +func (ecd *Encoder) IFFT(values interface{}, logN int) (err error) { + switch values := values.(type) { + case []complex128: + switch roots := ecd.roots.(type) { + case []complex128: + if true { + SpecialIFFTDouble(values, 1<> 2) gap := maxSlots / slots Q := ringQ.SubRings[0].Modulus var c uint64 - for i, idx := 0, 0; i < slots; i, idx = i+1, idx+gap { - c = coeffs[idx] - if c >= Q>>1 { - values[i] = complex(-float64(Q-c), 0) - } else { - values[i] = complex(float64(c), 0) - } - } - if !isreal { - for i, idx := 0, maxSlots; i < slots; i, idx = i+1, idx+gap { + switch values := values.(type) { + case []complex128: + for i, idx := 0, 0; i < slots; i, idx = i+1, idx+gap { c = coeffs[idx] if c >= Q>>1 { - values[i] += complex(0, -float64(Q-c)) + values[i] = complex(-float64(Q-c), 0) } else { - values[i] += complex(0, float64(c)) + values[i] = complex(float64(c), 0) } } - } - divideComplex128SliceVec(values, complex(scale.Float64(), 0)) + if !isreal { + for i, idx := 0, maxSlots; i < slots; i, idx = i+1, idx+gap { + c = coeffs[idx] + if c >= Q>>1 { + values[i] += complex(0, -float64(Q-c)) + } else { + values[i] += complex(0, float64(c)) + } + } + } else { + // [X]/(X^N+1) to [X+X^-1]/(X^N+1) + slots := 1 << logSlots + for i := 1; i < slots; i++ { + values[i] -= complex(0, real(values[slots-i])) + } + } + + DivideComplex128SliceUnrolled8(values, complex(scale.Float64(), 0)) + + case []*bignum.Complex: + + for i, idx := 0, 0; i < slots; i, idx = i+1, idx+gap { + + if values[i] == nil { + values[i] = &bignum.Complex{ + new(big.Float), + nil, + } + } else { + if values[i][0] == nil { + values[i][0] = new(big.Float) + } + } + + if c = coeffs[idx]; c >= Q>>1 { + values[i][0].SetInt64(-int64(Q - c)) + } else { + values[i][0].SetInt64(int64(c)) + } + } + + if !isreal { + for i, idx := 0, maxSlots; i < slots; i, idx = i+1, idx+gap { + + if values[i][1] == nil { + values[i][1] = new(big.Float) + } + + if c = coeffs[idx]; c >= Q>>1 { + values[i][1].SetInt64(-int64(Q - c)) + } else { + values[i][1].SetInt64(int64(c)) + } + } + } else { + slots := 1 << logSlots + + for i := 1; i < slots; i++ { + values[i][1].Sub(values[i][1], values[slots-i][0]) + } + } + + s := &scale.Value + + for i := range values { + values[i][0].Quo(values[i][0], s) + values[i][1].Quo(values[i][1], s) + } + + default: + panic(fmt.Errorf("cannot polyToComplexNoCRT: values.(Type) must be []complex128 or []*bignum.Complex but is %T", values)) + } } -func polyToComplexCRT(poly *ring.Poly, bigintCoeffs []*big.Int, values []complex128, scale rlwe.Scale, logSlots int, isreal bool, ringQ *ring.Ring, Q *big.Int) { +func polyToComplexCRT(poly *ring.Poly, bigintCoeffs []*big.Int, values interface{}, scale rlwe.Scale, logSlots int, isreal bool, ringQ *ring.Ring) { maxSlots := int(ringQ.NthRoot() >> 2) slots := 1 << logSlots @@ -439,417 +819,273 @@ func polyToComplexCRT(poly *ring.Poly, bigintCoeffs []*big.Int, values []complex ringQ.PolyToBigint(poly, gap, bigintCoeffs) + Q := ringQ.ModulusAtLevel[ringQ.Level()] + qHalf := new(big.Int) qHalf.Set(Q) qHalf.Rsh(qHalf, 1) var sign int - scalef64 := scale.Float64() + switch values := values.(type) { - var c *big.Int - for i := 0; i < slots; i++ { - c = bigintCoeffs[i] - c.Mod(c, Q) - if sign = c.Cmp(qHalf); sign == 1 || sign == 0 { - c.Sub(c, Q) - } - values[i] = complex(scaleDown(c, scalef64), 0) - } + case []complex128: + scalef64 := scale.Float64() - if !isreal { - for i, j := 0, slots; i < slots; i, j = i+1, j+1 { - c = bigintCoeffs[j] + var c *big.Int + for i := 0; i < slots; i++ { + c = bigintCoeffs[i] c.Mod(c, Q) if sign = c.Cmp(qHalf); sign == 1 || sign == 0 { c.Sub(c, Q) } - values[i] += complex(0, scaleDown(c, scalef64)) + values[i] = complex(scaleDown(c, scalef64), 0) } - } -} -func (ecd *encoderComplex128) plaintextToComplex(level int, scale rlwe.Scale, logSlots int, p *ring.Poly, values []complex128) { - - isreal := ecd.params.RingType() == ring.ConjugateInvariant - if level == 0 { - polyToComplexNoCRT(p.Coeffs[0], values, scale, logSlots, isreal, ecd.params.RingQ()) - } else { - polyToComplexCRT(p, ecd.bigintCoeffs, values, scale, logSlots, isreal, ecd.params.RingQ(), ecd.params.RingQ().ModulusAtLevel[level]) - } - - if isreal { // [X]/(X^N+1) to [X+X^-1]/(X^N+1) - tmp := ecd.values - slots := 1 << logSlots - for i := 1; i < slots; i++ { - tmp[i] -= complex(0, real(tmp[slots-i])) + if !isreal { + for i, j := 0, slots; i < slots; i, j = i+1, j+1 { + c = bigintCoeffs[j] + c.Mod(c, Q) + if sign = c.Cmp(qHalf); sign == 1 || sign == 0 { + c.Sub(c, Q) + } + values[i] += complex(0, scaleDown(c, scalef64)) + } + } else { + // [X]/(X^N+1) to [X+X^-1]/(X^N+1) + slots := 1 << logSlots + for i := 1; i < slots; i++ { + values[i] -= complex(0, real(values[slots-i])) + } } - } -} + case []*bignum.Complex: -func (ecd *encoderComplex128) decodePublic(plaintext *rlwe.Plaintext, logSlots int, noise distribution.Distribution) (res []complex128) { - - if logSlots > ecd.params.MaxLogSlots() || logSlots < minLogSlots { - panic(fmt.Sprintf("cannot Decode: ensure that %d <= logSlots (%d) <= %d", minLogSlots, logSlots, ecd.params.MaxLogSlots())) - } - - if plaintext.IsNTT { - ecd.params.RingQ().AtLevel(plaintext.Level()).INTT(plaintext.Value, ecd.buff) - } else { - ring.CopyLvl(plaintext.Level(), plaintext.Value, ecd.buff) - } - - // B = floor(sigma * sqrt(2*pi)) - if noise != nil { - ring.NewSampler(ecd.prng, ecd.params.RingQ(), noise, plaintext.IsMontgomery).AtLevel(plaintext.Level()).ReadAndAdd(ecd.buff) - } - - ecd.plaintextToComplex(plaintext.Level(), plaintext.Scale, logSlots, ecd.buff, ecd.values) - - ecd.FFT(ecd.values, logSlots) - - res = make([]complex128, 1< 0 { - - ecd.params.RingQ().PolyToBigint(ecd.buff, 1, ecd.bigintCoeffs) - - Q := ecd.params.RingQ().ModulusAtLevel[plaintext.Level()] - - ecd.qHalf.Set(Q) - ecd.qHalf.Rsh(ecd.qHalf, 1) - - var sign int - - for i := range res { - - // Centers the value around the current modulus - ecd.bigintCoeffs[i].Mod(ecd.bigintCoeffs[i], Q) - - sign = ecd.bigintCoeffs[i].Cmp(ecd.qHalf) - if sign == 1 || sign == 0 { - ecd.bigintCoeffs[i].Sub(ecd.bigintCoeffs[i], Q) + var c *big.Int + for i := 0; i < slots; i++ { + c = bigintCoeffs[i] + c.Mod(c, Q) + if sign = c.Cmp(qHalf); sign == 1 || sign == 0 { + c.Sub(c, Q) } - res[i] = scaleDown(ecd.bigintCoeffs[i], sf64) - } - // We can directly get the coefficients - } else { - - Q := ecd.params.RingQ().SubRings[0].Modulus - coeffs := ecd.buff.Coeffs[0] - - for i := range res { - - if coeffs[i] >= Q>>1 { - res[i] = -float64(Q - coeffs[i]) + if values[i] == nil { + values[i] = &bignum.Complex{ + new(big.Float), + nil, + } } else { - res[i] = float64(coeffs[i]) + if values[i][0] == nil { + values[i][0] = new(big.Float) + } } - res[i] /= sf64 + values[i][0].SetInt(c) } - } - return -} + if !isreal { + for i, j := 0, slots; i < slots; i, j = i+1, j+1 { + c = bigintCoeffs[j] + c.Mod(c, Q) + if sign = c.Cmp(qHalf); sign == 1 || sign == 0 { + c.Sub(c, Q) + } -func (ecd *encoderComplex128) IFFT(values []complex128, logN int) { - if logN < 3 { - SpecialiFFTVec(values, 1<>2) - valuesfloat := make([]*big.Float, ecd.m>>1) - - for i := 0; i < ecd.m>>2; i++ { - - values[i] = ring.NewComplex(ring.NewFloat(0, prec), ring.NewFloat(0, prec)) - valuesfloat[i*2] = ring.NewFloat(0, prec) - valuesfloat[(i*2)+1] = ring.NewFloat(0, prec) - } - - return &encoderBigComplex{ - encoder: ecd, - zero: ring.NewFloat(0, prec), - cMul: ring.NewComplexMultiplier(), - prec: prec, - roots: GetRootsbigFloat(ecd.m, prec), - values: values, - valuesfloat: valuesfloat, - } -} - -// Encode encodes a set of values on the target plaintext. -// Encoding is done at the level and scale of the plaintext. -// User must ensure that 1 <= len(values) <= 2^logSlots < 2^LogN. -func (ecd *encoderBigComplex) Encode(values []*ring.Complex, plaintext *rlwe.Plaintext, logSlots int) { - - slots := 1 << logSlots - N := ecd.params.N() - - if len(values) > ecd.params.N()/2 || len(values) > slots || logSlots > ecd.params.LogN()-1 { - panic("cannot Encode: too many values/slots for the given ring degree") - } - - if len(values) != slots { - panic("cannot Encode: number of values must be equal to slots") - } - - for i := 0; i < slots; i++ { - ecd.values[i].Set(values[i]) - } - - ecd.InvFFT(ecd.values, slots) - - gap := (ecd.params.RingQ().N() >> 1) / slots - - for i, jdx, idx := 0, N>>1, 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap { - ecd.valuesfloat[idx].Set(ecd.values[i].Real()) - ecd.valuesfloat[jdx].Set(ecd.values[i].Imag()) - } - - scaleUpVecExactBigFloat(ecd.valuesfloat, plaintext.Scale.Float64(), ecd.params.RingQ().ModuliChain()[:plaintext.Level()+1], plaintext.Value.Coeffs) - - halfN := N >> 1 - for i := 0; i < halfN; i++ { - ecd.values[i].Real().Set(ecd.zero) - ecd.values[i].Imag().Set(ecd.zero) - } - - for i := 0; i < N; i++ { - ecd.valuesfloat[i].Set(ecd.zero) - } - - ecd.params.RingQ().AtLevel(plaintext.Level()).NTT(plaintext.Value, plaintext.Value) -} - -// EncodeNew encodes a set of values on a new plaintext. -// Encoding is done at the provided level and with the provided scale. -// User must ensure that 1 <= len(values) <= 2^logSlots < 2^LogN. -func (ecd *encoderBigComplex) EncodeNew(values []*ring.Complex, level int, scale rlwe.Scale, logSlots int) (plaintext *rlwe.Plaintext) { - plaintext = NewPlaintext(ecd.params, level) - plaintext.Scale = scale - ecd.Encode(values, plaintext, logSlots) - return -} - -// Decode decodes the input plaintext on a new slice of ring.Complex. -func (ecd *encoderBigComplex) Decode(plaintext *rlwe.Plaintext, logSlots int) (res []*ring.Complex) { - return ecd.decodePublic(plaintext, logSlots, nil) -} - -func (ecd *encoderBigComplex) DecodePublic(plaintext *rlwe.Plaintext, logSlots int, noise distribution.Distribution) (res []*ring.Complex) { - return ecd.decodePublic(plaintext, logSlots, noise) -} - -// FFT evaluates the decoding matrix on a slice of ring.Complex values. -func (ecd *encoderBigComplex) FFT(values []*ring.Complex, N int) { - - var lenh, lenq, gap, idx int - - u := ring.NewComplex(nil, nil) - v := ring.NewComplex(nil, nil) - - utils.BitReverseInPlaceSlice(values, N) - - for len := 2; len <= N; len <<= 1 { - for i := 0; i < N; i += len { - lenh = len >> 1 - lenq = len << 2 - gap = ecd.m / lenq - for j := 0; j < lenh; j++ { - idx = (ecd.rotGroup[j] % lenq) * gap - u.Set(values[i+j]) - v.Set(values[i+j+lenh]) - ecd.cMul.Mul(v, ecd.roots[idx], v) - values[i+j].Add(u, v) - values[i+j+lenh].Sub(u, v) + values[i][1].SetInt(c) + } + } else { + // [X]/(X^N+1) to [X+X^-1]/(X^N+1) + slots := 1 << logSlots + for i := 1; i < slots; i++ { + values[i][1].Sub(values[i][1], values[slots-i][0]) } } - } -} -// InvFFT evaluates the encoding matrix on a slice of ring.Complex values. -func (ecd *encoderBigComplex) InvFFT(values []*ring.Complex, N int) { + s := &scale.Value - var lenh, lenq, gap, idx int - u := ring.NewComplex(nil, nil) - v := ring.NewComplex(nil, nil) - - for len := N; len >= 1; len >>= 1 { - for i := 0; i < N; i += len { - lenh = len >> 1 - lenq = len << 2 - gap = ecd.m / lenq - for j := 0; j < lenh; j++ { - idx = (lenq - (ecd.rotGroup[j] % lenq)) * gap - u.Add(values[i+j], values[i+j+lenh]) - v.Sub(values[i+j], values[i+j+lenh]) - ecd.cMul.Mul(v, ecd.roots[idx], v) - values[i+j].Set(u) - values[i+j+lenh].Set(v) - } + for i := range values { + values[i][0].Quo(values[i][0], s) + values[i][1].Quo(values[i][1], s) } - } - NBig := ring.NewFloat(float64(N), ecd.prec) - for i := range values { - values[i][0].Quo(values[i][0], NBig) - values[i][1].Quo(values[i][1], NBig) - } - - utils.BitReverseInPlaceSlice(values, N) -} - -// ShallowCopy creates a shallow copy of this encoderBigComplex in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// EncoderBigComplex can be used concurrently. -func (ecd *encoderBigComplex) ShallowCopy() EncoderBigComplex { - - values := make([]*ring.Complex, ecd.m>>2) - valuesfloat := make([]*big.Float, ecd.m>>1) - - for i := 0; i < ecd.m>>2; i++ { - - values[i] = ring.NewComplex(ring.NewFloat(0, ecd.prec), ring.NewFloat(0, ecd.prec)) - valuesfloat[i*2] = ring.NewFloat(0, ecd.prec) - valuesfloat[(i*2)+1] = ring.NewFloat(0, ecd.prec) - } - - return &encoderBigComplex{ - encoder: *ecd.encoder.ShallowCopy(), - zero: ring.NewFloat(0, ecd.prec), - cMul: ring.NewComplexMultiplier(), - prec: ecd.prec, - values: values, - valuesfloat: valuesfloat, - roots: ecd.roots, + default: + panic(fmt.Errorf("cannot polyToComplexNoCRT: values.(Type) must be []complex128 or []*bignum.Complex but is %T", values)) } } -func (ecd *encoderBigComplex) decodePublic(plaintext *rlwe.Plaintext, logSlots int, noise distribution.Distribution) (res []*ring.Complex) { +func (ecd *Encoder) polyToFloatCRT(p *ring.Poly, values interface{}, scale rlwe.Scale, logSlots int, r *ring.Ring) { - slots := 1 << logSlots - - if logSlots > ecd.params.LogN()-1 { - panic("cannot Decode: too many slots for the given ring degree") + var slots int + switch values := values.(type) { + case []float64: + slots = utils.MinInt(len(p.Coeffs[0]), len(values)) + case []complex128: + slots = utils.MinInt(len(p.Coeffs[0]), len(values)) + case []*big.Float: + slots = utils.MinInt(len(p.Coeffs[0]), len(values)) + case []*bignum.Complex: + slots = utils.MinInt(len(p.Coeffs[0]), len(values)) } - ecd.params.RingQ().AtLevel(plaintext.Level()).INTT(plaintext.Value, ecd.buff) + bigintCoeffs := ecd.bigintCoeffs - if noise != nil { - ring.NewSampler(ecd.prng, ecd.params.RingQ(), noise, plaintext.IsMontgomery).AtLevel(plaintext.Level()).ReadAndAdd(ecd.buff) - } + ecd.params.RingQ().PolyToBigint(ecd.buff, 1, bigintCoeffs) - Q := ecd.params.RingQ().ModulusAtLevel[plaintext.Level()] - - maxSlots := ecd.params.N() >> 1 - - scaleFlo := plaintext.Scale.Value + Q := r.ModulusAtLevel[r.Level()] ecd.qHalf.Set(Q) ecd.qHalf.Rsh(ecd.qHalf, 1) - gap := maxSlots / slots - - ecd.params.RingQ().PolyToBigint(ecd.buff, gap, ecd.bigintCoeffs) - var sign int - - for i, j := 0, slots; i < slots; i, j = i+1, j+1 { - + for i := 0; i < slots; i++ { // Centers the value around the current modulus - ecd.bigintCoeffs[i].Mod(ecd.bigintCoeffs[i], Q) - sign = ecd.bigintCoeffs[i].Cmp(ecd.qHalf) + bigintCoeffs[i].Mod(bigintCoeffs[i], Q) + + sign = bigintCoeffs[i].Cmp(ecd.qHalf) if sign == 1 || sign == 0 { - ecd.bigintCoeffs[i].Sub(ecd.bigintCoeffs[i], Q) + bigintCoeffs[i].Sub(bigintCoeffs[i], Q) } + } - // Centers the value around the current modulus - ecd.bigintCoeffs[j].Mod(ecd.bigintCoeffs[j], Q) - sign = ecd.bigintCoeffs[j].Cmp(ecd.qHalf) - if sign == 1 || sign == 0 { - ecd.bigintCoeffs[j].Sub(ecd.bigintCoeffs[j], Q) + switch values := values.(type) { + + case []float64: + sf64 := scale.Float64() + for i := 0; i < slots; i++ { + values[i] = scaleDown(bigintCoeffs[i], sf64) } + case []complex128: + sf64 := scale.Float64() + for i := 0; i < slots; i++ { + values[i] = complex(scaleDown(bigintCoeffs[i], sf64), 0) + } + case []*big.Float: + s := &scale.Value + for i := 0; i < slots; i++ { - ecd.values[i].Real().SetInt(ecd.bigintCoeffs[i]) - ecd.values[i].Real().Quo(ecd.values[i].Real(), &scaleFlo) + if values[i] == nil { + values[i] = new(big.Float) + } + + values[i].SetInt(bigintCoeffs[i]) + values[i].Quo(values[i], s) + } + case []*bignum.Complex: + s := &scale.Value + for i := 0; i < slots; i++ { + + if values[i] == nil { + values[i] = &bignum.Complex{ + new(big.Float), + new(big.Float), + } + } else { + if values[i][0] == nil { + values[i][0] = new(big.Float) + } + } + + values[i][0].SetInt(bigintCoeffs[i]) + values[i][0].Quo(values[i][0], s) + } + default: + panic(fmt.Errorf("cannot polyToComplexNoCRT: values.(Type) must be []complex128, []*bignum.Complex, []float64 or []*big.Float but is %T", values)) + + } +} + +func (ecd *Encoder) polyToFloatNoCRT(coeffs []uint64, values interface{}, scale rlwe.Scale, logSlots int, r *ring.Ring) { + + Q := r.SubRings[0].Modulus + + var slots int + switch values := values.(type) { + case []float64: + slots = utils.MinInt(len(coeffs), len(values)) + case []complex128: + slots = utils.MinInt(len(coeffs), len(values)) + case []*big.Float: + slots = utils.MinInt(len(coeffs), len(values)) + case []*bignum.Complex: + slots = utils.MinInt(len(coeffs), len(values)) + } + + switch values := values.(type) { + + case []float64: + + sf64 := scale.Float64() + + for i := 0; i < slots; i++ { + if coeffs[i] >= Q>>1 { + values[i] = -float64(Q-coeffs[i]) / sf64 + } else { + values[i] = float64(coeffs[i]) / sf64 + } + } + + case []complex128: + + sf64 := scale.Float64() + + for i := 0; i < slots; i++ { + if coeffs[i] >= Q>>1 { + values[i] = complex(-float64(Q-coeffs[i])/sf64, 0) + } else { + values[i] = complex(float64(coeffs[i])/sf64, 0) + } + } + + case []*big.Float: + + s := &scale.Value + + for i := 0; i < slots; i++ { + + if values[i] == nil { + values[i] = new(big.Float) + } + + if coeffs[i] >= Q>>1 { + values[i].SetInt64(-int64(Q - coeffs[i])) + } else { + values[i].SetInt64(int64(coeffs[i])) + } + + values[i].Quo(values[i], s) + } + + case []*bignum.Complex: + + s := &scale.Value + + for i := 0; i < slots; i++ { + + if values[i] == nil { + values[i] = &bignum.Complex{ + new(big.Float), + nil, + } + } else { + if values[i][0] == nil { + values[i][0] = new(big.Float) + } + } + + if coeffs[i] >= Q>>1 { + values[i][0].SetInt64(-int64(Q - coeffs[i])) + } else { + values[i][0].SetInt64(int64(coeffs[i])) + } + + values[i][0].Quo(values[i][0], s) + } + + default: + panic(fmt.Errorf("cannot polyToComplexNoCRT: values.(Type) must be []complex128, []*bignum.Complex, []float64 or []*big.Float but is %T", values)) - ecd.values[i].Imag().SetInt(ecd.bigintCoeffs[j]) - ecd.values[i].Imag().Quo(ecd.values[i].Imag(), &scaleFlo) } - - ecd.FFT(ecd.values, slots) - - res = make([]*ring.Complex, slots) - - for i := range res { - res[i] = ecd.values[i].Copy() - } - - for i := 0; i < maxSlots; i++ { - ecd.values[i].Real().Set(ecd.zero) - ecd.values[i].Imag().Set(ecd.zero) - } - - return } diff --git a/ckks/evaluator.go b/ckks/evaluator.go index adcd104f..8448f91a 100644 --- a/ckks/evaluator.go +++ b/ckks/evaluator.go @@ -3,13 +3,13 @@ package ckks import ( "errors" "fmt" - "math" "math/big" "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/rlwe/ringqp" "github.com/tuneinsight/lattigo/v4/utils" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // Evaluator is an interface implementing the methods to conduct homomorphic operations between ciphertext and/or plaintexts. @@ -19,97 +19,82 @@ type Evaluator interface { // ======================== // Addition - Add(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - AddNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) + Add(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciphertext) + AddNew(op0 *rlwe.Ciphertext, op1 interface{}) (op2 *rlwe.Ciphertext) // Subtraction - Sub(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - SubNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) - - // Negation - Neg(ctIn *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) - NegNew(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) - - // Constant Addition - AddConstNew(ctIn *rlwe.Ciphertext, constant interface{}) (ctOut *rlwe.Ciphertext) - AddConst(ctIn *rlwe.Ciphertext, constant interface{}, ctOut *rlwe.Ciphertext) - - // Constant Multiplication - MultByConstNew(ctIn *rlwe.Ciphertext, constant interface{}) (ctOut *rlwe.Ciphertext) - MultByConst(ctIn *rlwe.Ciphertext, constant interface{}, ctOut *rlwe.Ciphertext) - - // Constant Multiplication followed by Addition - MultByConstThenAdd(ctIn *rlwe.Ciphertext, constant interface{}, ctOut *rlwe.Ciphertext) + Sub(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciphertext) + SubNew(op0 *rlwe.Ciphertext, op1 interface{}) (op2 *rlwe.Ciphertext) // Complex Conjugation - ConjugateNew(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) - Conjugate(ctIn *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) + ConjugateNew(op0 *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) + Conjugate(op0 *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) // Multiplication - Mul(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - MulNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) - MulRelin(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - MulRelinNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) + Mul(op0 *rlwe.Ciphertext, op1 interface{}, ctOut *rlwe.Ciphertext) + MulNew(op0 *rlwe.Ciphertext, op1 interface{}) (ctOut *rlwe.Ciphertext) + MulRelin(op0 *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) + MulRelinNew(op0 *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) - MulThenAdd(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) - MulRelinThenAdd(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) + MulThenAdd(op0 *rlwe.Ciphertext, op1 interface{}, ctOut *rlwe.Ciphertext) + MulRelinThenAdd(op0 *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) // Slot Rotations - RotateNew(ctIn *rlwe.Ciphertext, k int) (ctOut *rlwe.Ciphertext) - Rotate(ctIn *rlwe.Ciphertext, k int, ctOut *rlwe.Ciphertext) - RotateHoistedNew(ctIn *rlwe.Ciphertext, rotations []int) (ctOut map[int]*rlwe.Ciphertext) - RotateHoisted(ctIn *rlwe.Ciphertext, rotations []int, ctOut map[int]*rlwe.Ciphertext) - RotateHoistedLazyNew(level int, rotations []int, ct *rlwe.Ciphertext, c2DecompQP []ringqp.Poly) (cOut map[int]*rlwe.OperandQP) + RotateNew(op0 *rlwe.Ciphertext, k int) (ctOut *rlwe.Ciphertext) + Rotate(op0 *rlwe.Ciphertext, k int, ctOut *rlwe.Ciphertext) + RotateHoistedNew(op0 *rlwe.Ciphertext, rotations []int) (ctOut map[int]*rlwe.Ciphertext) + RotateHoisted(op0 *rlwe.Ciphertext, rotations []int, ctOut map[int]*rlwe.Ciphertext) + RotateHoistedLazyNew(level int, rotations []int, ct *rlwe.Ciphertext, c2DecompQP []ringqp.Poly) (cOut map[int]rlwe.CiphertextQP) // =========================== // === Advanced Arithmetic === // =========================== // Polynomial evaluation - EvaluatePoly(input interface{}, pol *Polynomial, targetScale rlwe.Scale) (ctOut *rlwe.Ciphertext, err error) - EvaluatePolyVector(input interface{}, pols []*Polynomial, encoder Encoder, slotIndex map[int][]int, targetScale rlwe.Scale) (ctOut *rlwe.Ciphertext, err error) + EvaluatePoly(input interface{}, pol *bignum.Polynomial, targetScale rlwe.Scale) (ctOut *rlwe.Ciphertext, err error) + EvaluatePolyVector(input interface{}, pols []*bignum.Polynomial, encoder *Encoder, slotIndex map[int][]int, targetScale rlwe.Scale) (ctOut *rlwe.Ciphertext, err error) - // Inversion - InverseNew(ctIn *rlwe.Ciphertext, steps int) (ctOut *rlwe.Ciphertext, err error) + // GoldschmidtDivision + GoldschmidtDivisionNew(ct *rlwe.Ciphertext, minValue, log2Targetprecision float64, btp rlwe.Bootstrapper) (ctInv *rlwe.Ciphertext, err error) // Linear Transformations - LinearTransformNew(ctIn *rlwe.Ciphertext, linearTransform interface{}) (ctOut []*rlwe.Ciphertext) - LinearTransform(ctIn *rlwe.Ciphertext, linearTransform interface{}, ctOut []*rlwe.Ciphertext) - MultiplyByDiagMatrix(ctIn *rlwe.Ciphertext, matrix LinearTransform, c2DecompQP []ringqp.Poly, ctOut *rlwe.Ciphertext) - MultiplyByDiagMatrixBSGS(ctIn *rlwe.Ciphertext, matrix LinearTransform, c2DecompQP []ringqp.Poly, ctOut *rlwe.Ciphertext) + LinearTransformNew(op0 *rlwe.Ciphertext, linearTransform interface{}) (ctOut []*rlwe.Ciphertext) + LinearTransform(op0 *rlwe.Ciphertext, linearTransform interface{}, ctOut []*rlwe.Ciphertext) + MultiplyByDiagMatrix(op0 *rlwe.Ciphertext, matrix LinearTransform, c2DecompQP []ringqp.Poly, ctOut *rlwe.Ciphertext) + MultiplyByDiagMatrixBSGS(op0 *rlwe.Ciphertext, matrix LinearTransform, c2DecompQP []ringqp.Poly, ctOut *rlwe.Ciphertext) // Inner sum - InnerSum(ctIn *rlwe.Ciphertext, batch, n int, ctOut *rlwe.Ciphertext) - Average(ctIn *rlwe.Ciphertext, batch int, ctOut *rlwe.Ciphertext) + InnerSum(op0 *rlwe.Ciphertext, batch, n int, ctOut *rlwe.Ciphertext) + Average(op0 *rlwe.Ciphertext, batch int, ctOut *rlwe.Ciphertext) // Replication (inverse of Inner sum) - Replicate(ctIn *rlwe.Ciphertext, batch, n int, ctOut *rlwe.Ciphertext) + Replicate(op0 *rlwe.Ciphertext, batch, n int, ctOut *rlwe.Ciphertext) // Trace - Trace(ctIn *rlwe.Ciphertext, logSlots int, ctOut *rlwe.Ciphertext) - TraceNew(ctIn *rlwe.Ciphertext, logSlots int) (ctOut *rlwe.Ciphertext) + Trace(op0 *rlwe.Ciphertext, logSlots int, ctOut *rlwe.Ciphertext) + TraceNew(op0 *rlwe.Ciphertext, logSlots int) (ctOut *rlwe.Ciphertext) // ============================= // === Ciphertext Management === // ============================= // Generic EvaluationKeys - ApplyEvaluationKeyNew(ctIn *rlwe.Ciphertext, evk *rlwe.EvaluationKey) (ctOut *rlwe.Ciphertext) - ApplyEvaluationKey(ctIn *rlwe.Ciphertext, evk *rlwe.EvaluationKey, ctOut *rlwe.Ciphertext) + ApplyEvaluationKeyNew(op0 *rlwe.Ciphertext, evk *rlwe.EvaluationKey) (ctOut *rlwe.Ciphertext) + ApplyEvaluationKey(op0 *rlwe.Ciphertext, evk *rlwe.EvaluationKey, ctOut *rlwe.Ciphertext) // Degree Management - RelinearizeNew(ctIn *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) - Relinearize(ctIn *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) + RelinearizeNew(op0 *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) + Relinearize(op0 *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) // Scale Management - ScaleUpNew(ctIn *rlwe.Ciphertext, scale rlwe.Scale) (ctOut *rlwe.Ciphertext) - ScaleUp(ctIn *rlwe.Ciphertext, scale rlwe.Scale, ctOut *rlwe.Ciphertext) - SetScale(ctIn *rlwe.Ciphertext, scale rlwe.Scale) - Rescale(ctIn *rlwe.Ciphertext, minScale rlwe.Scale, ctOut *rlwe.Ciphertext) (err error) + ScaleUpNew(op0 *rlwe.Ciphertext, scale rlwe.Scale) (ctOut *rlwe.Ciphertext) + ScaleUp(op0 *rlwe.Ciphertext, scale rlwe.Scale, ctOut *rlwe.Ciphertext) + SetScale(op0 *rlwe.Ciphertext, scale rlwe.Scale) + Rescale(op0 *rlwe.Ciphertext, minScale rlwe.Scale, ctOut *rlwe.Ciphertext) (err error) // Level Management - DropLevelNew(ctIn *rlwe.Ciphertext, levels int) (ctOut *rlwe.Ciphertext) - DropLevel(ctIn *rlwe.Ciphertext, levels int) + DropLevelNew(op0 *rlwe.Ciphertext, levels int) (ctOut *rlwe.Ciphertext) + DropLevel(op0 *rlwe.Ciphertext, levels int) // ============== // === Others === @@ -182,46 +167,54 @@ func (eval *evaluator) GetRLWEEvaluator() *rlwe.Evaluator { return eval.Evaluator } -func (eval *evaluator) newCiphertextBinary(op0, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) { +// Add adds op1 to op0 and returns the result in op2. +func (eval *evaluator) Add(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciphertext) { - maxDegree := utils.Max(op0.Degree(), op1.Degree()) - minLevel := utils.Min(op0.Level(), op1.Level()) - - return NewCiphertext(eval.params, maxDegree, minLevel) + switch op1 := op1.(type) { + case rlwe.Operand: + _, level := eval.CheckBinary(op0, op1, op2, utils.MaxInt(op0.Degree(), op1.Degree())) + eval.evaluateInPlace(level, op0, op1, op2, eval.params.RingQ().AtLevel(level).Add) + default: + level := utils.MinInt(op0.Level(), op2.Level()) + RNSReal, RNSImag := bigComplexToRNSScalar(eval.params.RingQ().AtLevel(level), &op0.Scale.Value, bignum.ToComplex(op1, eval.params.DefaultPrecision())) + op2.Resize(op0.Degree(), level) + eval.evaluateWithScalar(level, op0.Value[:1], RNSReal, RNSImag, op2.Value[:1], eval.params.RingQ().AtLevel(level).AddDoubleRNSScalar) + } } -// Add adds op1 to ctIn and returns the result in ctOut. -func (eval *evaluator) Add(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) { - _, level := eval.CheckBinary(ctIn, op1, ctOut, utils.Max(ctIn.Degree(), op1.Degree())) - eval.evaluateInPlace(level, ctIn, op1, ctOut, eval.params.RingQ().AtLevel(level).Add) -} - -// AddNew adds op1 to ctIn and returns the result in a newly created element. -func (eval *evaluator) AddNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) { - ctOut = eval.newCiphertextBinary(ctIn, op1) - eval.Add(ctIn, op1, ctOut) +// AddNew adds op1 to op0 and returns the result in a newly created element op2. +func (eval *evaluator) AddNew(op0 *rlwe.Ciphertext, op1 interface{}) (op2 *rlwe.Ciphertext) { + op2 = op0.CopyNew() + eval.Add(op2, op1, op2) return } -// Sub subtracts op1 from ctIn and returns the result in ctOut. -func (eval *evaluator) Sub(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) { +// Sub subtracts op1 from op0 and returns the result in op2. +func (eval *evaluator) Sub(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciphertext) { - _, level := eval.CheckBinary(ctIn, op1, ctOut, utils.Max(ctIn.Degree(), op1.Degree())) + switch op1 := op1.(type) { + case rlwe.Operand: + _, level := eval.CheckBinary(op0, op1, op2, utils.MaxInt(op0.Degree(), op1.Degree())) - eval.evaluateInPlace(level, ctIn, op1, ctOut, eval.params.RingQ().AtLevel(level).Sub) + eval.evaluateInPlace(level, op0, op1, op2, eval.params.RingQ().AtLevel(level).Sub) - if ctIn.Degree() < op1.Degree() { - for i := ctIn.Degree() + 1; i < op1.Degree()+1; i++ { - eval.params.RingQ().AtLevel(level).Neg(ctOut.Value[i], ctOut.Value[i]) + if op0.Degree() < op1.Degree() { + for i := op0.Degree() + 1; i < op1.Degree()+1; i++ { + eval.params.RingQ().AtLevel(level).Neg(op2.Value[i], op2.Value[i]) + } } + default: + level := utils.MinInt(op0.Level(), op2.Level()) + RNSReal, RNSImag := bigComplexToRNSScalar(eval.params.RingQ().AtLevel(level), &op0.Scale.Value, bignum.ToComplex(op1, eval.params.DefaultPrecision())) + op2.Resize(op0.Degree(), level) + eval.evaluateWithScalar(level, op0.Value[:1], RNSReal, RNSImag, op2.Value[:1], eval.params.RingQ().AtLevel(level).SubDoubleRNSScalar) } - } -// SubNew subtracts op1 from ctIn and returns the result in a newly created element. -func (eval *evaluator) SubNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) { - ctOut = eval.newCiphertextBinary(ctIn, op1) - eval.Sub(ctIn, op1, ctOut) +// SubNew subtracts op1 from op0 and returns the result in a newly created element op2. +func (eval *evaluator) SubNew(op0 *rlwe.Ciphertext, op1 interface{}) (op2 *rlwe.Ciphertext) { + op2 = op0.CopyNew() + eval.Sub(op2, op1, op2) return } @@ -235,34 +228,49 @@ func (eval *evaluator) evaluateInPlace(level int, c0 *rlwe.Ciphertext, c1 rlwe.O // Else resizes the receiver element ctOut.El().Resize(maxDegree, ctOut.Level()) - c0Scale := c0.GetScale().Float64() - c1Scale := c1.GetScale().Float64() + c0Scale := c0.GetMetaData().Scale + c1Scale := c1.GetMetaData().Scale if ctOut.Level() > level { eval.DropLevel(ctOut, ctOut.Level()-utils.Min(c0.Level(), c1.Level())) } - cmp := c0.GetScale().Cmp(c1.GetScale()) + cmp := c0.GetMetaData().Scale.Cmp(c1.GetMetaData().Scale) // Checks whether or not the receiver element is the same as one of the input elements // and acts accordingly to avoid unnecessary element creation or element overwriting, // and scales properly the element before the evaluation. if ctOut == c0 { - if cmp == 1 && math.Floor(c0Scale/c1Scale) > 1 { + if cmp == 1 { - tmp1 = rlwe.NewCiphertextAtLevelFromPoly(level, eval.buffCt.Value[:c1.Degree()+1]) - tmp1.MetaData = ctOut.MetaData + ratioFlo := c0Scale.Div(c1Scale).Value - eval.MultByConst(&rlwe.Ciphertext{OperandQ: *c1.El()}, math.Floor(c0Scale/c1Scale), tmp1) + ratioInt, _ := ratioFlo.Int(nil) - } else if cmp == -1 && math.Floor(c1Scale/c0Scale) > 1 { + if ratioInt.Cmp(new(big.Int).SetUint64(0)) == 1 { - eval.MultByConst(c0, math.Floor(c1Scale/c0Scale), c0) + tmp1 = rlwe.NewCiphertextAtLevelFromPoly(level, eval.buffCt.Value[:c1.Degree()+1]) + tmp1.MetaData = ctOut.MetaData - ctOut.Scale = c1.GetScale() + eval.Mul(c1.El(), ratioInt, tmp1) + } + + } else if cmp == -1 { + + ratioFlo := c1Scale.Div(c0Scale).Value + + ratioInt, _ := ratioFlo.Int(nil) + + if ratioInt.Cmp(new(big.Int).SetUint64(0)) == 1 { + + eval.Mul(c0, ratioInt, c0) + + ctOut.Scale = c1.GetMetaData().Scale + + tmp1 = c1.El() + } - tmp1 = &rlwe.Ciphertext{OperandQ: *c1.El()} } else { tmp1 = &rlwe.Ciphertext{OperandQ: *c1.El()} } @@ -271,21 +279,34 @@ func (eval *evaluator) evaluateInPlace(level int, c0 *rlwe.Ciphertext, c1 rlwe.O } else if ctOut == c1 { - if cmp == 1 && math.Floor(c0Scale/c1Scale) > 1 { + if cmp == 1 { - eval.MultByConst(&rlwe.Ciphertext{OperandQ: *c1.El()}, math.Floor(c0Scale/c1Scale), ctOut) + ratioFlo := c0Scale.Div(c1Scale).Value - ctOut.Scale = c0.Scale + ratioInt, _ := ratioFlo.Int(nil) - tmp0 = &rlwe.Ciphertext{OperandQ: *c0.El()} + if ratioInt.Cmp(new(big.Int).SetUint64(0)) == 1 { + eval.Mul(c1.El(), ratioInt, ctOut) - } else if cmp == -1 && math.Floor(c1Scale/c0Scale) > 1 { + ctOut.Scale = c0.Scale - // Will avoid resizing on the output - tmp0 = rlwe.NewCiphertextAtLevelFromPoly(level, eval.buffCt.Value[:c0.Degree()+1]) - tmp0.MetaData = ctOut.MetaData + tmp0 = c0.El() + } + + } else if cmp == -1 { + + ratioFlo := c1Scale.Div(c0Scale).Value + + ratioInt, _ := ratioFlo.Int(nil) + + if ratioInt.Cmp(new(big.Int).SetUint64(0)) == 1 { + // Will avoid resizing on the output + tmp0 = rlwe.NewCiphertextAtLevelFromPoly(level, eval.buffCt.Value[:c0.Degree()+1]) + tmp0.MetaData = ctOut.MetaData + + eval.Mul(c0, ratioInt, tmp0) + } - eval.MultByConst(c0, math.Floor(c1Scale/c0Scale), tmp0) } else { tmp0 = &rlwe.Ciphertext{OperandQ: *c0.El()} } @@ -294,24 +315,38 @@ func (eval *evaluator) evaluateInPlace(level int, c0 *rlwe.Ciphertext, c1 rlwe.O } else { - if cmp == 1 && math.Floor(c0Scale/c1Scale) > 1 { + if cmp == 1 { - // Will avoid resizing on the output - tmp1 = rlwe.NewCiphertextAtLevelFromPoly(level, eval.buffCt.Value[:c1.Degree()+1]) - tmp1.MetaData = ctOut.MetaData + ratioFlo := c0Scale.Div(c1Scale).Value - eval.MultByConst(&rlwe.Ciphertext{OperandQ: *c1.El()}, math.Floor(c0Scale/c1Scale), tmp1) + ratioInt, _ := ratioFlo.Int(nil) - tmp0 = &rlwe.Ciphertext{OperandQ: *c0.El()} + if ratioInt.Cmp(new(big.Int).SetUint64(0)) == 1 { + // Will avoid resizing on the output + tmp1 = rlwe.NewCiphertextAtLevelFromPoly(level, eval.buffCt.Value[:c1.Degree()+1]) + tmp1.MetaData = ctOut.MetaData - } else if cmp == -1 && math.Floor(c1Scale/c0Scale) > 1 { + eval.Mul(c1.El(), ratioInt, tmp1) - tmp0 = rlwe.NewCiphertextAtLevelFromPoly(level, eval.buffCt.Value[:c0.Degree()+1]) - tmp0.MetaData = ctOut.MetaData + tmp0 = c0.El() + } - eval.MultByConst(c0, math.Floor(c1Scale/c0Scale), tmp0) + } else if cmp == -1 { - tmp1 = &rlwe.Ciphertext{OperandQ: *c1.El()} + ratioFlo := c1Scale.Div(c0Scale).Value + + ratioInt, _ := ratioFlo.Int(nil) + + if ratioInt.Cmp(new(big.Int).SetUint64(0)) == 1 { + + tmp0 = rlwe.NewCiphertextAtLevelFromPoly(level, eval.buffCt.Value[:c0.Degree()+1]) + tmp0.MetaData = ctOut.MetaData + + eval.Mul(c0, ratioInt, tmp0) + + tmp1 = c1.El() + + } } else { tmp0 = &rlwe.Ciphertext{OperandQ: *c0.El()} @@ -323,7 +358,7 @@ func (eval *evaluator) evaluateInPlace(level int, c0 *rlwe.Ciphertext, c1 rlwe.O evaluate(tmp0.Value[i], tmp1.Value[i], ctOut.El().Value[i]) } - scale := c0.Scale.Max(c1.GetScale()) + scale := c0.Scale.Max(c1.GetMetaData().Scale) ctOut.MetaData = c0.MetaData ctOut.Scale = scale @@ -342,97 +377,6 @@ func (eval *evaluator) evaluateInPlace(level int, c0 *rlwe.Ciphertext, c1 rlwe.O } } -// Neg negates the value of ct0 and returns the result in ctOut. -func (eval *evaluator) Neg(ct0 *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) { - - level := utils.Min(ct0.Level(), ctOut.Level()) - - if ct0.Degree() != ctOut.Degree() { - panic("cannot Negate: invalid receiver Ciphertext does not match input Ciphertext degree") - } - - for i := range ct0.Value { - eval.params.RingQ().AtLevel(level).Neg(ct0.Value[i], ctOut.Value[i]) - } - - ctOut.MetaData = ct0.MetaData -} - -// NegNew negates ct0 and returns the result in a newly created element. -func (eval *evaluator) NegNew(ct0 *rlwe.Ciphertext) (ctOut *rlwe.Ciphertext) { - ctOut = NewCiphertext(eval.params, ct0.Degree(), ct0.Level()) - eval.Neg(ct0, ctOut) - return -} - -// AddConst adds the input constant to ct0 and returns the result in ctOut. -// The constant can be a complex128, float64, int, int64, uint64. *big.Float, *big.Int or *ring.Complex. -func (eval *evaluator) AddConst(ct0 *rlwe.Ciphertext, constant interface{}, ct1 *rlwe.Ciphertext) { - level := utils.Min(ct0.Level(), ct1.Level()) - ct1.Resize(ct0.Degree(), level) - RNSReal, RNSImag := bigComplexToRNSScalar(eval.params.RingQ().AtLevel(level), &ct0.Scale.Value, valueToBigComplex(constant, scalingPrecision)) - eval.evaluateWithScalar(level, ct0.Value[:1], RNSReal, RNSImag, ct1.Value[:1], eval.params.RingQ().AtLevel(level).AddDoubleRNSScalar) -} - -// AddConstNew adds the input constant to ct0 and returns the result in a new element. -// The constant can be a complex128, float64, int, int64, uint64. *big.Float, *big.Int or *ring.Complex. -func (eval *evaluator) AddConstNew(ct0 *rlwe.Ciphertext, constant interface{}) (ctOut *rlwe.Ciphertext) { - ctOut = ct0.CopyNew() - eval.AddConst(ct0, constant, ctOut) - return -} - -// MultByConstThenAdd multiplies ctIn by the input constant, and adds it to the receiver element, -// e.g., ctOut(x) = ctOut(x) + ctIn(x) * (a+bi). This functions removes the need of storing the intermediate value c(x) * (a+bi). -// -// This function will not modify ctIn but will multiply ctOut by Q[min(ctIn.Level(), ctOut.Level())] if: -// - ctIn.Scale == ctOut.Scale -// - constant is not a Gaussian integer. -// -// If ctIn.Scale == ctOut.Scale, and constant is not a Gaussian integer, then the constant will be scaled by -// Q[min(ctIn.Level(), ctOut.Level())] else if ctOut.Scale > ctIn.Scale, the constant will be scaled by ctOut.Scale/ctIn.Scale. -// -// To correctly use this function, make sure that either ctIn.Scale == ctOut.Scale or -// ctOut.Scale = ctIn.Scale * Q[min(ctIn.Level(), ctOut.Level())]. -// -// This function will panic if ctIn.Scale > ctOut.Scale. -func (eval *evaluator) MultByConstThenAdd(ctIn *rlwe.Ciphertext, constant interface{}, ctOut *rlwe.Ciphertext) { - - var level = utils.Min(ctIn.Level(), ctOut.Level()) - - ringQ := eval.params.RingQ().AtLevel(level) - - ctOut.Resize(ctOut.Degree(), level) - - cmplxBig := valueToBigComplex(constant, scalingPrecision) - - var scaleRLWE rlwe.Scale - - // If ctIn and ctOut scales are identical, but the constant is not a Gaussian integer then multiplies ctOut by scaleRLWE. - // This ensures noiseless addition with ctOut = scaleRLWE * ctOut + ctIn * round(scalar * scaleRLWE). - if cmp := ctIn.Scale.Cmp(ctOut.Scale); cmp == 0 { - - if cmplxBig.IsInt() { - scaleRLWE = rlwe.NewScale(1) - } else { - scaleRLWE = rlwe.NewScale(ringQ.SubRings[level].Modulus) - scaleInt := new(big.Int) - scaleRLWE.Value.Int(scaleInt) - eval.MultByConst(ctOut, scaleInt, ctOut) - ctOut.Scale = ctOut.Scale.Mul(scaleRLWE) - } - - } else if cmp == -1 { // ctOut.Scale > ctIn.Scale then the scaling factor for the constant becomes the quotient between the two scales - scaleRLWE = ctOut.Scale.Div(ctIn.Scale) - } else { - panic("MultByConstThenAdd: ctIn.Scale > ctOut.Scale is not supported") - } - - RNSReal, RNSImag := bigComplexToRNSScalar(ringQ, &scaleRLWE.Value, cmplxBig) - - eval.evaluateWithScalar(level, ctIn.Value, RNSReal, RNSImag, ctOut.Value, ringQ.MulDoubleRNSScalarThenAdd) -} - func (eval *evaluator) evaluateWithScalar(level int, p0 []*ring.Poly, RNSReal, RNSImag ring.RNSScalar, p1 []*ring.Poly, evaluate func(*ring.Poly, ring.RNSScalar, ring.RNSScalar, *ring.Poly)) { // Component wise operation with the following vector: @@ -440,8 +384,8 @@ func (eval *evaluator) evaluateWithScalar(level int, p0 []*ring.Poly, RNSReal, R // [{ N/2 }{ N/2 }] // Which is equivalent outside of the NTT domain to evaluating a to the first coefficient of ct0 and b to the N/2-th coefficient of ct0. for i, s := range eval.params.RingQ().SubRings[:level+1] { - RNSImag[i] = ring.MRedLazy(RNSImag[i], s.RootsForward[1], s.Modulus, s.MRedConstant) - RNSReal[i], RNSImag[i] = RNSReal[i]+RNSImag[i], RNSReal[i]+2*s.Modulus-RNSImag[i] + RNSImag[i] = ring.MRed(RNSImag[i], s.RootsForward[1], s.Modulus, s.MRedConstant) + RNSReal[i], RNSImag[i] = ring.CRed(RNSReal[i]+RNSImag[i], s.Modulus), ring.CRed(RNSReal[i]+s.Modulus-RNSImag[i], s.Modulus) } for i := range p0 { @@ -449,44 +393,6 @@ func (eval *evaluator) evaluateWithScalar(level int, p0 []*ring.Poly, RNSReal, R } } -// MultByConstNew multiplies ct0 by the input constant and returns the result in a newly created element. -// The scale of the output element will depend on the scale of the input element and the constant (if the constant -// needs to be scaled (its rational part is not zero)). -// The constant can be a complex128, float64, int, int64, uint64. *big.Float, *big.Int or *ring.Complex. -func (eval *evaluator) MultByConstNew(ct0 *rlwe.Ciphertext, constant interface{}) (ctOut *rlwe.Ciphertext) { - ctOut = NewCiphertext(eval.params, ct0.Degree(), ct0.Level()) - eval.MultByConst(ct0, constant, ctOut) - return -} - -// MultByConst multiplies ct0 by the input constant and returns the result in ctOut. -// The scale of the output element will depend on the scale of the input element and the constant (if the constant -// needs to be scaled (its rational part is not zero)). -// The constant can be a complex128, float64, int, int64, uint64. *big.Float, *big.Int or *ring.Complex. -func (eval *evaluator) MultByConst(ct0 *rlwe.Ciphertext, constant interface{}, ctOut *rlwe.Ciphertext) { - - level := utils.Min(ct0.Level(), ctOut.Level()) - ctOut.Resize(ct0.Degree(), level) - - ringQ := eval.params.RingQ().AtLevel(level) - - cmplxBig := valueToBigComplex(constant, scalingPrecision) - - var scale rlwe.Scale - - if cmplxBig.IsInt() { - scale = rlwe.NewScale(1) - } else { - scale = rlwe.NewScale(ringQ.SubRings[level].Modulus) - } - - RNSReal, RNSImag := bigComplexToRNSScalar(ringQ, &scale.Value, cmplxBig) - - eval.evaluateWithScalar(level, ct0.Value, RNSReal, RNSImag, ctOut.Value, ringQ.MulDoubleRNSScalar) - ctOut.MetaData = ct0.MetaData - ctOut.Scale = ct0.Scale.Mul(scale) -} - // ScaleUpNew multiplies ct0 by scale and sets its scale to its previous scale times scale returns the result in ctOut. func (eval *evaluator) ScaleUpNew(ct0 *rlwe.Ciphertext, scale rlwe.Scale) (ctOut *rlwe.Ciphertext) { ctOut = NewCiphertext(eval.params, ct0.Degree(), ct0.Level()) @@ -496,15 +402,15 @@ func (eval *evaluator) ScaleUpNew(ct0 *rlwe.Ciphertext, scale rlwe.Scale) (ctOut // ScaleUp multiplies ct0 by scale and sets its scale to its previous scale times scale returns the result in ctOut. func (eval *evaluator) ScaleUp(ct0 *rlwe.Ciphertext, scale rlwe.Scale, ctOut *rlwe.Ciphertext) { - eval.MultByConst(ct0, scale.Uint64(), ctOut) + eval.Mul(ct0, scale.Uint64(), ctOut) ctOut.MetaData = ct0.MetaData ctOut.Scale = ct0.Scale.Mul(scale) } // SetScale sets the scale of the ciphertext to the input scale (consumes a level). func (eval *evaluator) SetScale(ct *rlwe.Ciphertext, scale rlwe.Scale) { - - eval.MultByConst(ct, scale.Float64()/ct.Scale.Float64(), ct) + ratioFlo := scale.Div(ct.Scale).Value + eval.Mul(ct, &ratioFlo, ct) if err := eval.Rescale(ct, scale, ct); err != nil { panic(err) } @@ -543,8 +449,8 @@ func (eval *evaluator) RescaleNew(ct0 *rlwe.Ciphertext, minScale rlwe.Scale) (ct // in ctOut. Since all the moduli in the moduli chain are generated to be close to the // original scale, this procedure is equivalent to dividing the input element by the scale and adding // some error. -// Returns an error if "minScale <= 0", ct.scale = 0, ct.Level() = 0, ct.IsNTT() != true or if ct.Level() != ctOut.Level() -func (eval *evaluator) Rescale(ctIn *rlwe.Ciphertext, minScale rlwe.Scale, ctOut *rlwe.Ciphertext) (err error) { +// Returns an error if "minScale <= 0", ct.scale = 0, ct.Level() = 0, ct.IsNTT() != true or if ct.Leve() != ctOut.Level() +func (eval *evaluator) Rescale(op0 *rlwe.Ciphertext, minScale rlwe.Scale, ctOut *rlwe.Ciphertext) (err error) { if minScale.Cmp(rlwe.NewScale(0)) != 1 { return errors.New("cannot Rescale: minScale is <0") @@ -552,23 +458,23 @@ func (eval *evaluator) Rescale(ctIn *rlwe.Ciphertext, minScale rlwe.Scale, ctOut minScale = minScale.Div(rlwe.NewScale(2)) - if ctIn.Scale.Cmp(rlwe.NewScale(0)) != 1 { + if op0.Scale.Cmp(rlwe.NewScale(0)) != 1 { return errors.New("cannot Rescale: ciphertext scale is <0") } - if ctIn.Level() == 0 { + if op0.Level() == 0 { return errors.New("cannot Rescale: input Ciphertext already at level 0") } - if ctOut.Degree() != ctIn.Degree() { - return errors.New("cannot Rescale: ctIn.Degree() != ctOut.Degree()") + if ctOut.Degree() != op0.Degree() { + return errors.New("cannot Rescale: op0.Degree() != ctOut.Degree()") } - ctOut.MetaData = ctIn.MetaData + ctOut.MetaData = op0.MetaData - newLevel := ctIn.Level() + newLevel := op0.Level() - ringQ := eval.params.RingQ().AtLevel(ctIn.Level()) + ringQ := eval.params.RingQ().AtLevel(op0.Level()) // Divides the scale by each moduli of the modulus chain as long as the scale isn't smaller than minScale/2 // or until the output Level() would be zero @@ -589,65 +495,101 @@ func (eval *evaluator) Rescale(ctIn *rlwe.Ciphertext, minScale rlwe.Scale, ctOut if nbRescales > 0 { for i := range ctOut.Value { - ringQ.DivRoundByLastModulusManyNTT(nbRescales, ctIn.Value[i], eval.buffQ[0], ctOut.Value[i]) + ringQ.DivRoundByLastModulusManyNTT(nbRescales, op0.Value[i], eval.buffQ[0], ctOut.Value[i]) } ctOut.Resize(ctOut.Degree(), newLevel) } else { - if ctIn != ctOut { - ctOut.Copy(ctIn) + if op0 != ctOut { + ctOut.Copy(op0) } } return nil } -// MulNew multiplies ctIn with op1 without relinearization and returns the result in a newly created element. -// The procedure will panic if either ctIn.Degree or op1.Degree > 1. -func (eval *evaluator) MulNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) { - ctOut = NewCiphertext(eval.params, ctIn.Degree()+op1.Degree(), utils.Min(ctIn.Level(), op1.Level())) - eval.mulRelin(ctIn, op1, false, ctOut) +// MulNew multiplies op0 with op1 without relinearization and returns the result in a newly created element op2. +// +// op1.(type) can be rlwe.Operand, complex128, float64, int, int64, uint64. *big.Float, *big.Int or *ring.Complex. +// +// If op1.(type) == rlwe.Operand: +// - The procedure will panic if either op0.Degree or op1.Degree > 1. +func (eval *evaluator) MulNew(op0 *rlwe.Ciphertext, op1 interface{}) (op2 *rlwe.Ciphertext) { + op2 = op0.CopyNew() + eval.Mul(op2, op1, op2) return } -// Mul multiplies ctIn with op1 without relinearization and returns the result in ctOut. -// The procedure will panic if either ctIn or op1 are have a degree higher than 1. -// The procedure will panic if ctOut.Degree != ctIn.Degree + op1.Degree. -func (eval *evaluator) Mul(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) { - eval.mulRelin(ctIn, op1, false, ctOut) +// Mul multiplies op0 with op1 without relinearization and returns the result in ctOut. +// +// op1.(type) can be rlwe.Operand, complex128, float64, int, int64, uint64. *big.Float, *big.Int or *ring.Complex. +// +// If op1.(type) == rlwe.Operand: +// - The procedure will panic if either op0 or op1 are have a degree higher than 1. +// - The procedure will panic if op2.Degree != op0.Degree + op1.Degree. +func (eval *evaluator) Mul(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciphertext) { + switch op1 := op1.(type) { + case rlwe.Operand: + eval.mulRelin(op0, op1, false, op2) + default: + level := utils.MinInt(op0.Level(), op2.Level()) + op2.Resize(op0.Degree(), level) + + ringQ := eval.params.RingQ().AtLevel(level) + + cmplxBig := bignum.ToComplex(op1, eval.params.DefaultPrecision()) + + var scale rlwe.Scale + + if cmplxBig.IsInt() { + scale = rlwe.NewScale(1) + } else { + scale = rlwe.NewScale(ringQ.SubRings[level].Modulus) + + for i := 1; i < eval.params.DefaultScaleModuliRatio(); i++ { + scale = scale.Mul(rlwe.NewScale(ringQ.SubRings[level-i].Modulus)) + } + } + + RNSReal, RNSImag := bigComplexToRNSScalar(ringQ, &scale.Value, cmplxBig) + + eval.evaluateWithScalar(level, op0.Value, RNSReal, RNSImag, op2.Value, ringQ.MulDoubleRNSScalar) + op2.MetaData = op0.MetaData + op2.Scale = op0.Scale.Mul(scale) + } } -// MulRelinNew multiplies ctIn with op1 with relinearization and returns the result in a newly created element. -// The procedure will panic if either ctIn.Degree or op1.Degree > 1. +// MulRelinNew multiplies op0 with op1 with relinearization and returns the result in a newly created element. +// The procedure will panic if either op0.Degree or op1.Degree > 1. // The procedure will panic if the evaluator was not created with an relinearization key. -func (eval *evaluator) MulRelinNew(ctIn *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) { - ctOut = NewCiphertext(eval.params, 1, utils.Min(ctIn.Level(), op1.Level())) - eval.mulRelin(ctIn, op1, true, ctOut) +func (eval *evaluator) MulRelinNew(op0 *rlwe.Ciphertext, op1 rlwe.Operand) (ctOut *rlwe.Ciphertext) { + ctOut = NewCiphertext(eval.params, 1, utils.MinInt(op0.Level(), op1.Level())) + eval.mulRelin(op0, op1, true, ctOut) return } -// MulRelin multiplies ctIn with op1 with relinearization and returns the result in ctOut. -// The procedure will panic if either ctIn.Degree or op1.Degree > 1. -// The procedure will panic if ctOut.Degree != ctIn.Degree + op1.Degree. +// MulRelin multiplies op0 with op1 with relinearization and returns the result in ctOut. +// The procedure will panic if either op0.Degree or op1.Degree > 1. +// The procedure will panic if ctOut.Degree != op0.Degree + op1.Degree. // The procedure will panic if the evaluator was not created with an relinearization key. -func (eval *evaluator) MulRelin(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) { - eval.mulRelin(ctIn, op1, true, ctOut) +func (eval *evaluator) MulRelin(op0 *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) { + eval.mulRelin(op0, op1, true, ctOut) } -func (eval *evaluator) mulRelin(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, relin bool, ctOut *rlwe.Ciphertext) { +func (eval *evaluator) mulRelin(op0 *rlwe.Ciphertext, op1 rlwe.Operand, relin bool, ctOut *rlwe.Ciphertext) { - if ctIn.Degree()+op1.Degree() > 2 { + if op0.Degree()+op1.Degree() > 2 { panic("cannot MulRelin: the sum of the input elements' total degree cannot be larger than 2") } - ctOut.MetaData = ctIn.MetaData - ctOut.Scale = ctIn.Scale.Mul(op1.GetScale()) + ctOut.MetaData = op0.MetaData + ctOut.Scale = op0.Scale.Mul(op1.GetMetaData().Scale) var c00, c01, c0, c1, c2 *ring.Poly // Case Ciphertext (x) Ciphertext - if ctIn.Degree() == 1 && op1.Degree() == 1 { + if op0.Degree() == 1 && op1.Degree() == 1 { - _, level := eval.CheckBinary(ctIn, op1, ctOut, ctOut.Degree()) + _, level := eval.CheckBinary(op0, op1, ctOut, ctOut.Degree()) ringQ := eval.params.RingQ().AtLevel(level) @@ -668,15 +610,15 @@ func (eval *evaluator) mulRelin(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, relin b // Avoid overwriting if the second input is the output var tmp0, tmp1 *rlwe.OperandQ if op1.El() == ctOut.El() { - tmp0, tmp1 = op1.El(), ctIn.El() + tmp0, tmp1 = op1.El(), op0.El() } else { - tmp0, tmp1 = ctIn.El(), op1.El() + tmp0, tmp1 = op0.El(), op1.El() } ringQ.MForm(tmp0.Value[0], c00) ringQ.MForm(tmp0.Value[1], c01) - if ctIn.El() == op1.El() { // squaring case + if op0.El() == op1.El() { // squaring case ringQ.MulCoeffsMontgomery(c00, tmp1.Value[0], c0) // c0 = c[0]*c[0] ringQ.MulCoeffsMontgomery(c01, tmp1.Value[1], c2) // c2 = c[1]*c[1] ringQ.MulCoeffsMontgomery(c00, tmp1.Value[1], c1) // c1 = 2*c[0]*c[1] @@ -709,24 +651,24 @@ func (eval *evaluator) mulRelin(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, relin b // Case Plaintext (x) Ciphertext or Ciphertext (x) Plaintext } else { - _, level := eval.CheckBinary(ctIn, op1, ctOut, ctOut.Degree()) + _, level := eval.CheckBinary(op0, op1, ctOut, ctOut.Degree()) ringQ := eval.params.RingQ().AtLevel(level) var c0 *ring.Poly var c1 []*ring.Poly - if ctIn.Degree() == 0 { + if op0.Degree() == 0 { c0 = eval.buffQ[0] - ringQ.MForm(ctIn.Value[0], c0) + ringQ.MForm(op0.Value[0], c0) c1 = op1.El().Value } else { c0 = eval.buffQ[0] ringQ.MForm(op1.El().Value[0], c0) - c1 = ctIn.Value + c1 = op0.Value } - ctOut.El().Resize(ctIn.Degree()+op1.Degree(), level) + ctOut.El().Resize(op0.Degree()+op1.Degree(), level) for i := range c1 { ringQ.MulCoeffsMontgomery(c0, c1[i], ctOut.Value[i]) @@ -734,48 +676,109 @@ func (eval *evaluator) mulRelin(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, relin b } } -// MulThenAdd multiplies ctIn with op1 without relinearization and adds the result on ctOut. -// User must ensure that ctOut.scale <= ctIn.scale * op1.scale. -// If ctOut.scale < ctIn.scale * op1.scale, then scales up ctOut before adding the result. -// The procedure will panic if either ctIn or op1 are have a degree higher than 1. -// The procedure will panic if ctOut.Degree != ctIn.Degree + op1.Degree. -// The procedure will panic if ctOut = ctIn or op1. -func (eval *evaluator) MulThenAdd(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) { - eval.mulRelinThenAdd(ctIn, op1, false, ctOut) +// MulThenAdd evaluate op2 = op2 + op0 * op1. +// +// op1.(type) can be rlwe.Operand, complex128, float64, int, int64, uint64. *big.Float, *big.Int or *ring.Complex. +// +// If op1.(type) is complex128, float64, int, int64, uint64. *big.Float, *big.Int or *ring.Complex: +// +// This function will not modify op0 but will multiply op2 by Q[min(op0.Level(), op2.Level())] if: +// - op0.Scale == op2.Scale +// - constant is not a Gaussian integer. +// +// If op0.Scale == op2.Scale, and constant is not a Gaussian integer, then the constant will be scaled by +// Q[min(op0.Level(), op2.Level())] else if op2.Scale > op0.Scale, the constant will be scaled by op2.Scale/op0.Scale. +// +// To correctly use this function, make sure that either op0.Scale == op2.Scale or +// op2.Scale = op0.Scale * Q[min(op0.Level(), op2.Level())]. +// +// If op1.(type) is rlwe.Operand, the multiplication is carried outwithout relinearization and: +// +// This function will panic if op0.Scale > op2.Scale. +// User must ensure that op2.scale <= op0.scale * op1.scale. +// If op2.scale < op0.scale * op1.scale, then scales up op2 before adding the result. +// Additionally, the procedure will panic if: +// - either op0 or op1 are have a degree higher than 1. +// - op2.Degree != op0.Degree + op1.Degree. +// - op2 = op0 or op1. +func (eval *evaluator) MulThenAdd(op0 *rlwe.Ciphertext, op1 interface{}, op2 *rlwe.Ciphertext) { + switch op1 := op1.(type) { + case rlwe.Operand: + eval.mulRelinThenAdd(op0, op1, false, op2) + default: + var level = utils.MinInt(op0.Level(), op2.Level()) + + ringQ := eval.params.RingQ().AtLevel(level) + + op2.Resize(op2.Degree(), level) + + cmplxBig := bignum.ToComplex(op1, eval.params.DefaultPrecision()) + + var scaleRLWE rlwe.Scale + + // If op0 and op2 scales are identical, but the op1 is not a Gaussian integer then multiplies op2 by scaleRLWE. + // This ensures noiseless addition with op2 = scaleRLWE * op2 + op0 * round(scalar * scaleRLWE). + if cmp := op0.Scale.Cmp(op2.Scale); cmp == 0 { + + if cmplxBig.IsInt() { + scaleRLWE = rlwe.NewScale(1) + } else { + scaleRLWE = rlwe.NewScale(ringQ.SubRings[level].Modulus) + + for i := 1; i < eval.params.DefaultScaleModuliRatio(); i++ { + scaleRLWE = scaleRLWE.Mul(rlwe.NewScale(ringQ.SubRings[level-i].Modulus)) + } + + scaleInt := new(big.Int) + scaleRLWE.Value.Int(scaleInt) + eval.Mul(op2, scaleInt, op2) + op2.Scale = op2.Scale.Mul(scaleRLWE) + } + + } else if cmp == -1 { // op2.Scale > op0.Scale then the scaling factor for op1 becomes the quotient between the two scales + scaleRLWE = op2.Scale.Div(op0.Scale) + } else { + panic("MulThenAdd: op0.Scale > op2.Scale is not supported") + } + + RNSReal, RNSImag := bigComplexToRNSScalar(ringQ, &scaleRLWE.Value, cmplxBig) + + eval.evaluateWithScalar(level, op0.Value, RNSReal, RNSImag, op2.Value, ringQ.MulDoubleRNSScalarThenAdd) + } } -// MulRelinThenAdd multiplies ctIn with op1 with relinearization and adds the result on ctOut. -// User must ensure that ctOut.scale <= ctIn.scale * op1.scale. -// If ctOut.scale < ctIn.scale * op1.scale, then scales up ctOut before adding the result. -// The procedure will panic if either ctIn.Degree or op1.Degree > 1. -// The procedure will panic if ctOut.Degree != ctIn.Degree + op1.Degree. +// MulRelinThenAdd multiplies op0 with op1 with relinearization and adds the result on op2. +// User must ensure that op2.scale <= op0.scale * op1.scale. +// If op2.scale < op0.scale * op1.scale, then scales up op2 before adding the result. +// The procedure will panic if either op0.Degree or op1.Degree > 1. +// The procedure will panic if op2.Degree != op0.Degree + op1.Degree. // The procedure will panic if the evaluator was not created with an relinearization key. -// The procedure will panic if ctOut = ctIn or op1. -func (eval *evaluator) MulRelinThenAdd(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, ctOut *rlwe.Ciphertext) { - eval.mulRelinThenAdd(ctIn, op1, true, ctOut) +// The procedure will panic if op2 = op0 or op1. +func (eval *evaluator) MulRelinThenAdd(op0 *rlwe.Ciphertext, op1 rlwe.Operand, op2 *rlwe.Ciphertext) { + eval.mulRelinThenAdd(op0, op1, true, op2) } -func (eval *evaluator) mulRelinThenAdd(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, relin bool, ctOut *rlwe.Ciphertext) { +func (eval *evaluator) mulRelinThenAdd(op0 *rlwe.Ciphertext, op1 rlwe.Operand, relin bool, op2 *rlwe.Ciphertext) { - _, level := eval.CheckBinary(ctIn, op1, ctOut, utils.Max(ctIn.Degree(), op1.Degree())) + _, level := eval.CheckBinary(op0, op1, op2, utils.MaxInt(op0.Degree(), op1.Degree())) - if ctIn.Degree()+op1.Degree() > 2 { + if op0.Degree()+op1.Degree() > 2 { panic("cannot MulRelinThenAdd: the sum of the input elements' degree cannot be larger than 2") } - if ctIn.El() == ctOut.El() || op1.El() == ctOut.El() { - panic("cannot MulRelinThenAdd: ctOut must be different from op0 and op1") + if op0.El() == op2.El() || op1.El() == op2.El() { + panic("cannot MulRelinThenAdd: op2 must be different from op0 and op1") } - c0f64 := ctIn.Scale.Float64() - c1f64 := op1.GetScale().Float64() - c2f64 := ctOut.Scale.Float64() + resScale := op0.Scale.Mul(op1.GetMetaData().Scale) - resScale := c0f64 * c1f64 - - if c2f64 < resScale { - eval.MultByConst(ctOut, math.Round(resScale/c2f64), ctOut) - ctOut.Scale = rlwe.NewScale(resScale) + if op2.Scale.Cmp(resScale) == -1 { + ratio := resScale.Div(op2.Scale) + // Only scales up if int(ratio) >= 2 + if ratio.Float64() >= 2.0 { + eval.Mul(op2, &ratio.Value, op2) + op2.Scale = resScale + } } ringQ := eval.params.RingQ().AtLevel(level) @@ -783,23 +786,23 @@ func (eval *evaluator) mulRelinThenAdd(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, var c00, c01, c0, c1, c2 *ring.Poly // Case Ciphertext (x) Ciphertext - if ctIn.Degree() == 1 && op1.Degree() == 1 { + if op0.Degree() == 1 && op1.Degree() == 1 { c00 = eval.buffQ[0] c01 = eval.buffQ[1] - c0 = ctOut.Value[0] - c1 = ctOut.Value[1] + c0 = op2.Value[0] + c1 = op2.Value[1] if !relin { - ctOut.El().Resize(2, level) - c2 = ctOut.Value[2] + op2.El().Resize(2, level) + c2 = op2.Value[2] } else { - // No resize here since we add on ctOut + // No resize here since we add on op2 c2 = eval.buffQ[2] } - tmp0, tmp1 := ctIn.El(), op1.El() + tmp0, tmp1 := op0.El(), op1.El() ringQ.MForm(tmp0.Value[0], c00) ringQ.MForm(tmp0.Value[1], c01) @@ -832,15 +835,15 @@ func (eval *evaluator) mulRelinThenAdd(ctIn *rlwe.Ciphertext, op1 rlwe.Operand, // Case Plaintext (x) Ciphertext or Ciphertext (x) Plaintext } else { - if ctOut.Degree() < ctIn.Degree() { - ctOut.Resize(ctIn.Degree(), level) + if op2.Degree() < op0.Degree() { + op2.Resize(op0.Degree(), level) } c00 := eval.buffQ[0] ringQ.MForm(op1.El().Value[0], c00) - for i := range ctIn.Value { - ringQ.MulCoeffsMontgomeryThenAdd(ctIn.Value[i], c00, ctOut.Value[i]) + for i := range op0.Value { + ringQ.MulCoeffsMontgomeryThenAdd(op0.Value[i], c00, op2.Value[i]) } } } diff --git a/ckks/linear_transform.go b/ckks/linear_transform.go index ea9e5063..ae2cf1c4 100644 --- a/ckks/linear_transform.go +++ b/ckks/linear_transform.go @@ -2,12 +2,14 @@ package ckks import ( "fmt" + "math/big" "runtime" "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/rlwe/ringqp" "github.com/tuneinsight/lattigo/v4/utils" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // TraceNew maps X -> sum((-1)^i * X^{i*n+1}) for 0 <= i < N and returns the result on a new ciphertext. @@ -30,7 +32,7 @@ func (eval *evaluator) Average(ctIn *rlwe.Ciphertext, logBatchSize int, ctOut *r panic("ctIn.Degree() != 1 or ctOut.Degree() != 1") } - if logBatchSize > eval.params.LogSlots() { + if logBatchSize > ctIn.LogSlots { panic("cannot Average: batchSize must be smaller or equal to the number of slots") } @@ -38,7 +40,7 @@ func (eval *evaluator) Average(ctIn *rlwe.Ciphertext, logBatchSize int, ctOut *r level := utils.Min(ctIn.Level(), ctOut.Level()) - n := eval.params.Slots() / (1 << logBatchSize) + n := 1 << (ctIn.LogSlots - logBatchSize) // pre-multiplication by n^-1 for i, s := range ringQ.SubRings[:level+1] { @@ -162,12 +164,7 @@ func (LT *LinearTransform) GaloisElements(params Parameters) (galEls []uint64) { // It can then be evaluated on a ciphertext using evaluator.LinearTransform. // Evaluation will use the naive approach (single hoisting and no baby-step giant-step). // Faster if there is only a few non-zero diagonals but uses more keys. -func (LT *LinearTransform) Encode(encoder Encoder, value interface{}, scale rlwe.Scale) { - - enc, ok := encoder.(*encoderComplex128) - if !ok { - panic("cannot Encode: encoder should be an encoderComplex128") - } +func (LT *LinearTransform) Encode(ecd *Encoder, value interface{}, scale rlwe.Scale) { dMat := interfaceMapToMapOfInterface(value) slots := 1 << LT.LogSlots @@ -184,7 +181,7 @@ func (LT *LinearTransform) Encode(encoder Encoder, value interface{}, scale rlwe panic("cannot Encode: error encoding on LinearTransform: input does not match the same non-zero diagonals") } - enc.Embed(dMat[i], LT.LogSlots, scale, true, LT.Vec[idx]) + ecd.Embed(dMat[i], LT.LogSlots, scale, true, LT.Vec[idx]) } } else { @@ -196,6 +193,10 @@ func (LT *LinearTransform) Encode(encoder Encoder, value interface{}, scale rlwe values = make([]complex128, slots) case map[int][]float64: values = make([]float64, slots) + case map[int][]*big.Float: + values = make([]*big.Float, slots) + case map[int][]*bignum.Complex: + values = make([]*bignum.Complex, slots) } for j := range index { @@ -215,7 +216,7 @@ func (LT *LinearTransform) Encode(encoder Encoder, value interface{}, scale rlwe copyRotInterface(values, v, rot) - enc.Embed(values, LT.LogSlots, scale, true, LT.Vec[j+i]) + ecd.Embed(values, LT.LogSlots, scale, true, LT.Vec[j+i]) } } } @@ -229,14 +230,9 @@ func (LT *LinearTransform) Encode(encoder Encoder, value interface{}, scale rlwe // It can then be evaluated on a ciphertext using evaluator.LinearTransform. // Evaluation will use the naive approach (single hoisting and no baby-step giant-step). // Faster if there is only a few non-zero diagonals but uses more keys. -func GenLinearTransform(encoder Encoder, value interface{}, level int, scale rlwe.Scale, logslots int) LinearTransform { +func GenLinearTransform(ecd *Encoder, value interface{}, level int, scale rlwe.Scale, logslots int) LinearTransform { - enc, ok := encoder.(*encoderComplex128) - if !ok { - panic("cannot GenLinearTransform: encoder should be an encoderComplex128") - } - - params := enc.params + params := ecd.params dMat := interfaceMapToMapOfInterface(value) vec := make(map[int]ringqp.Poly) slots := 1 << logslots @@ -249,8 +245,8 @@ func GenLinearTransform(encoder Encoder, value interface{}, level int, scale rlw if idx < 0 { idx += slots } - vec[idx] = *ringQP.NewPoly() - enc.Embed(dMat[i], logslots, scale, true, vec[idx]) + vec[idx] = ringQP.NewPoly() + ecd.Embed(dMat[i], logslots, scale, true, vec[idx]) } return LinearTransform{LogSlots: logslots, N1: 0, Vec: vec, Level: level, Scale: scale} @@ -264,14 +260,9 @@ func GenLinearTransform(encoder Encoder, value interface{}, level int, scale rlw // Faster if there is more than a few non-zero diagonals. // LogBSGSRatio is the log of the maximum ratio between the inner and outer loop of the baby-step giant-step algorithm used in evaluator.LinearTransform. // Optimal LogBSGSRatio value is between 0 and 4 depending on the sparsity of the matrix. -func GenLinearTransformBSGS(encoder Encoder, value interface{}, level int, scale rlwe.Scale, LogBSGSRatio int, logSlots int) (LT LinearTransform) { +func GenLinearTransformBSGS(ecd *Encoder, value interface{}, level int, scale rlwe.Scale, LogBSGSRatio int, logSlots int) (LT LinearTransform) { - enc, ok := encoder.(*encoderComplex128) - if !ok { - panic("cannot GenLinearTransformBSGS: encoder should be an encoderComplex128") - } - - params := enc.params + params := ecd.params slots := 1 << logSlots @@ -294,6 +285,10 @@ func GenLinearTransformBSGS(encoder Encoder, value interface{}, level int, scale values = make([]complex128, slots) case map[int][]float64: values = make([]float64, slots) + case map[int][]*big.Float: + values = make([]*big.Float, slots) + case map[int][]*bignum.Complex: + values = make([]*bignum.Complex, slots) } for j := range index { @@ -311,7 +306,7 @@ func GenLinearTransformBSGS(encoder Encoder, value interface{}, level int, scale copyRotInterface(values, v, rot) - enc.Embed(values, logSlots, scale, true, vec[j+i]) + ecd.Embed(values, logSlots, scale, true, vec[j+i]) } } @@ -346,6 +341,32 @@ func copyRotInterface(a, b interface{}, rot int) { } else { copy(af64[n-rot:], bf64) } + case []*big.Float: + + aF := a.([]*big.Float) + bF := b.([]*big.Float) + + n := len(aF) + + if len(bF) >= rot { + copy(aF[:n-rot], bF[rot:]) + copy(aF[n-rot:], bF[:rot]) + } else { + copy(aF[n-rot:], bF) + } + case []*bignum.Complex: + + aC := a.([]*bignum.Complex) + bC := b.([]*bignum.Complex) + + n := len(aC) + + if len(bC) >= rot { + copy(aC[:n-rot], bC[rot:]) + copy(aC[n-rot:], bC[:rot]) + } else { + copy(aC[n-rot:], bC) + } } } @@ -384,6 +405,20 @@ func BSGSIndex(el interface{}, slots, N1 int) (index map[int][]int, rotN1, rotN2 nonZeroDiags[i] = key i++ } + case map[int][]*big.Float: + nonZeroDiags = make([]int, len(element)) + var i int + for key := range element { + nonZeroDiags[i] = key + i++ + } + case map[int][]*bignum.Complex: + nonZeroDiags = make([]int, len(element)) + var i int + for key := range element { + nonZeroDiags[i] = key + i++ + } case []int: nonZeroDiags = element } @@ -425,8 +460,16 @@ func interfaceMapToMapOfInterface(m interface{}) map[int]interface{} { for i := range el { d[i] = el[i] } + case map[int][]*big.Float: + for i := range el { + d[i] = el[i] + } + case map[int][]*bignum.Complex: + for i := range el { + d[i] = el[i] + } default: - panic("cannot interfaceMapToMapOfInterface: invalid input, must be map[int][]complex128 or map[int][]float64") + panic("cannot interfaceMapToMapOfInterface: invalid input, must be map[int]{[]complex128, []float64, []*big.Float or []*bignum.Complex}") } return d } @@ -531,6 +574,7 @@ func (eval *evaluator) LinearTransform(ctIn *rlwe.Ciphertext, linearTransform in ctOut[i].MetaData = ctIn.MetaData ctOut[i].Scale = ctIn.Scale.Mul(LT.Scale) + ctOut[i].LogSlots = utils.MaxInt(ctOut[i].LogSlots, LT.LogSlots) } case LinearTransform: @@ -544,6 +588,7 @@ func (eval *evaluator) LinearTransform(ctIn *rlwe.Ciphertext, linearTransform in ctOut[0].MetaData = ctIn.MetaData ctOut[0].Scale = ctIn.Scale.Mul(LTs.Scale) + ctOut[0].LogSlots = utils.MaxInt(ctOut[0].LogSlots, LTs.LogSlots) } } diff --git a/ckks/params.go b/ckks/params.go index 89267596..7f1c8a6c 100644 --- a/ckks/params.go +++ b/ckks/params.go @@ -5,255 +5,19 @@ import ( "fmt" "math" "math/big" - "math/bits" "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/ring/distribution" "github.com/tuneinsight/lattigo/v4/rlwe" - "github.com/tuneinsight/lattigo/v4/utils" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) const ( - minLogSlots = 0 DefaultNTTFlag = true ) // Name of the different default parameter sets -var ( - - // PN12QP109 is a default parameter set for logN=12 and logQP=109 - PN12QP109 = ParametersLiteral{ - LogN: 12, - Q: []uint64{0x200000e001, 0x100006001}, // 37 + 32}, - P: []uint64{0x3ffffea001}, // 38 - LogScale: 32, - } - - // PN13QP218 is a default parameter set for logN=13 and logQP=218 - PN13QP218 = ParametersLiteral{ - LogN: 13, - Q: []uint64{0x1fffec001, // 33 + 5 x 30 - 0x3fff4001, - 0x3ffe8001, - 0x40020001, - 0x40038001, - 0x3ffc0001}, - P: []uint64{0x800004001}, // 35 - LogScale: 30, - } - // PN14QP438 is a default parameter set for logN=14 and logQP=438 - PN14QP438 = ParametersLiteral{ - LogN: 14, - Q: []uint64{0x200000008001, 0x400018001, // 45 + 9 x 34 - 0x3fffd0001, 0x400060001, - 0x400068001, 0x3fff90001, - 0x400080001, 0x4000a8001, - 0x400108001, 0x3ffeb8001}, - P: []uint64{0x7fffffd8001, 0x7fffffc8001}, // 43, 43 - LogScale: 34, - } - - // PN15QP880 is a default parameter set for logN=15 and logQP=880 - PN15QP880 = ParametersLiteral{ - LogN: 15, - Q: []uint64{0x4000000120001, 0x10000140001, 0xffffe80001, // 50 + 17 x 40 - 0x10000290001, 0xffffc40001, 0x100003e0001, - 0x10000470001, 0x100004b0001, 0xffffb20001, - 0x10000500001, 0x10000650001, 0xffff940001, - 0xffff8a0001, 0xffff820001, 0xffff780001, - 0x10000890001, 0xffff750001, 0x10000960001}, - P: []uint64{0x40000001b0001, 0x3ffffffdf0001, 0x4000000270001}, // 50, 50, 50 - LogScale: 40, - } - // PN16QP1761 is a default parameter set for logN=16 and logQP = 1761 - PN16QP1761 = ParametersLiteral{ - LogN: 16, - Q: []uint64{0x80000000080001, 0x2000000a0001, 0x2000000e0001, 0x1fffffc20001, // 55 + 33 x 45 - 0x200000440001, 0x200000500001, 0x200000620001, 0x1fffff980001, - 0x2000006a0001, 0x1fffff7e0001, 0x200000860001, 0x200000a60001, - 0x200000aa0001, 0x200000b20001, 0x200000c80001, 0x1fffff360001, - 0x200000e20001, 0x1fffff060001, 0x200000fe0001, 0x1ffffede0001, - 0x1ffffeca0001, 0x1ffffeb40001, 0x200001520001, 0x1ffffe760001, - 0x2000019a0001, 0x1ffffe640001, 0x200001a00001, 0x1ffffe520001, - 0x200001e80001, 0x1ffffe0c0001, 0x1ffffdee0001, 0x200002480001, - 0x1ffffdb60001, 0x200002560001}, - P: []uint64{0x80000000440001, 0x7fffffffba0001, 0x80000000500001, 0x7fffffffaa0001}, // 4 x 55 - LogScale: 45, - } - - // PN12QP109CI is a default parameter set for logN=12 and logQP=109 - PN12QP109CI = ParametersLiteral{ - LogN: 12, - Q: []uint64{0x1ffffe0001, 0x100014001}, // 37 + 32 - P: []uint64{0x4000038001}, // 38 - RingType: ring.ConjugateInvariant, - LogScale: 32, - } - - // PN13QP218CI is a default parameter set for logN=13 and logQP=218 - PN13QP218CI = ParametersLiteral{ - LogN: 13, - Q: []uint64{0x200038001, // 33 + 5 x 30 - 0x3ffe8001, - 0x40020001, - 0x40038001, - 0x3ffc0001, - 0x40080001}, - P: []uint64{0x800008001}, // 35 - RingType: ring.ConjugateInvariant, - LogScale: 30, - } - // PN14QP438CI is a default parameter set for logN=14 and logQP=438 - PN14QP438CI = ParametersLiteral{ - LogN: 14, - Q: []uint64{0x2000000a0001, 0x3fffd0001, // 45 + 9*34 - 0x400060001, 0x3fff90001, - 0x400080001, 0x400180001, - 0x3ffd20001, 0x400300001, - 0x400360001, 0x4003e0001}, - P: []uint64{0x80000050001, 0x7ffffdb0001}, // 43, 43 - RingType: ring.ConjugateInvariant, - LogScale: 34, - } - - // PN15QP880CI is a default parameter set for logN=15 and logQP=880 - PN15QP880CI = ParametersLiteral{ - LogN: 15, - Q: []uint64{0x4000000120001, // 50 + 17 x 40 - 0x10000140001, 0xffffe80001, 0xffffc40001, - 0x100003e0001, 0xffffb20001, 0x10000500001, - 0xffff940001, 0xffff8a0001, 0xffff820001, - 0xffff780001, 0x10000960001, 0x10000a40001, - 0xffff580001, 0x10000b60001, 0xffff480001, - 0xffff420001, 0xffff340001}, - P: []uint64{0x3ffffffd20001, 0x4000000420001, 0x3ffffffb80001}, // 50, 50, 50 - RingType: ring.ConjugateInvariant, - LogScale: 40, - } - // PN16QP1761CI is a default parameter set for logN=16 and logQP = 1761 - PN16QP1761CI = ParametersLiteral{ - LogN: 16, - Q: []uint64{0x80000000080001, // 55 + 33 x 45 - 0x200000440001, 0x200000500001, 0x1fffff980001, 0x200000c80001, - 0x1ffffeb40001, 0x1ffffe640001, 0x200001a00001, 0x200001e80001, - 0x1ffffe0c0001, 0x200002480001, 0x200002800001, 0x1ffffd800001, - 0x200002900001, 0x1ffffd700001, 0x2000029c0001, 0x1ffffcf00001, - 0x200003140001, 0x1ffffcc80001, 0x1ffffcb40001, 0x1ffffc980001, - 0x200003740001, 0x200003800001, 0x200003d40001, 0x1ffffc200001, - 0x1ffffc140001, 0x200004100001, 0x200004180001, 0x1ffffbc40001, - 0x200004700001, 0x1ffffb900001, 0x200004cc0001, 0x1ffffb240001, - 0x200004e80001}, - P: []uint64{0x80000000440001, 0x80000000500001, 0x7fffffff380001, 0x80000000e00001}, // 4 x 55 - RingType: ring.ConjugateInvariant, - LogScale: 45, - } - - // PN12QP101pq is a default (post quantum) parameter set for logN=12 and logQP=101 - PN12QP101pq = ParametersLiteral{ - LogN: 12, - Q: []uint64{0x800004001, 0x40002001}, // 35 + 30 - P: []uint64{0x1000002001}, // 36 - LogScale: 30, - } - // PN13QP202pq is a default (post quantum) parameter set for logN=13 and logQP=202 - PN13QP202pq = ParametersLiteral{ - LogN: 13, - Q: []uint64{0x1fffec001, 0x8008001, 0x8020001, 0x802c001, 0x7fa8001, 0x7f74001}, // 33 + 5 x 27 - P: []uint64{0x400018001}, // 34 - LogScale: 27, - } - - // PN14QP411pq is a default (post quantum) parameter set for logN=14 and logQP=411 - PN14QP411pq = ParametersLiteral{ - LogN: 14, - Q: []uint64{0x10000048001, 0x200038001, 0x1fff90001, 0x200080001, 0x1fff60001, - 0x2000b8001, 0x200100001, 0x1fff00001, 0x1ffef0001, 0x200128001}, // 40 + 9 x 33 - P: []uint64{0x1ffffe0001, 0x1ffffc0001}, // 37, 37 - LogScale: 33, - } - - // PN15QP827pq is a default (post quantum) parameter set for logN=15 and logQP=827 - PN15QP827pq = ParametersLiteral{ - LogN: 15, - Q: []uint64{0x400000060001, 0x4000170001, 0x3fffe80001, 0x40002f0001, 0x4000300001, - 0x3fffcf0001, 0x40003f0001, 0x3fffc10001, 0x4000450001, 0x3fffb80001, - 0x3fffb70001, 0x40004a0001, 0x3fffb20001, 0x4000510001, 0x3fffaf0001, - 0x4000540001, 0x4000560001, 0x4000590001}, // 46 + 17 x 38 - P: []uint64{0x2000000a0001, 0x2000000e0001, 0x2000001d0001}, // 3 x 45 - LogScale: 38, - } - // PN16QP1654pq is a default (post quantum) parameter set for logN=16 and logQP=1654 - PN16QP1654pq = ParametersLiteral{ - LogN: 16, - Q: []uint64{0x80000000080001, 0x2000000a0001, 0x2000000e0001, 0x1fffffc20001, 0x200000440001, - 0x200000500001, 0x200000620001, 0x1fffff980001, 0x2000006a0001, 0x1fffff7e0001, - 0x200000860001, 0x200000a60001, 0x200000aa0001, 0x200000b20001, 0x200000c80001, - 0x1fffff360001, 0x200000e20001, 0x1fffff060001, 0x200000fe0001, 0x1ffffede0001, - 0x1ffffeca0001, 0x1ffffeb40001, 0x200001520001, 0x1ffffe760001, 0x2000019a0001, - 0x1ffffe640001, 0x200001a00001, 0x1ffffe520001, 0x200001e80001, 0x1ffffe0c0001, - 0x1ffffdee0001, 0x200002480001}, // 55 + 31 x 45 - P: []uint64{0x7fffffffe0001, 0x80000001c0001, 0x80000002c0001, 0x7ffffffd20001}, // 4 x 51 - LogScale: 45, - } - - // PN12QP101pq is a default (post quantum) parameter set for logN=12 and logQP=101 - PN12QP101CIpq = ParametersLiteral{ - LogN: 12, - Q: []uint64{0x800004001, 0x3fff4001}, // 35 + 30 - P: []uint64{0xffffc4001}, // 36 - RingType: ring.ConjugateInvariant, - LogScale: 30, - } - // PN13QP202CIpq is a default (post quantum) parameter set for logN=13 and logQP=202 - PN13QP202CIpq = ParametersLiteral{ - LogN: 13, - Q: []uint64{0x1ffffe0001, 0x100050001, 0xfff88001, 0x100098001, 0x1000b0001}, // 37 + 4 x 32 - P: []uint64{0x1ffffc0001}, // 37 - RingType: ring.ConjugateInvariant, - LogScale: 32, - } - - // PN14QP411CIpq is a default (post quantum) parameter set for logN=14 and logQP=411 - PN14QP411CIpq = ParametersLiteral{ - LogN: 14, - Q: []uint64{0x10000140001, 0x1fff90001, 0x200080001, - 0x1fff60001, 0x200100001, 0x1fff00001, - 0x1ffef0001, 0x1ffe60001, 0x2001d0001, - 0x2002e0001}, // 40 + 9 x 33 - - P: []uint64{0x1ffffe0001, 0x1ffffc0001}, // 37, 37 - RingType: ring.ConjugateInvariant, - LogScale: 33, - } - - // PN15QP827CIpq is a default (post quantum) parameter set for logN=15 and logQP=827 - PN15QP827CIpq = ParametersLiteral{ - LogN: 15, - Q: []uint64{0x400000060001, 0x3fffe80001, 0x4000300001, 0x3fffb80001, - 0x40004a0001, 0x3fffb20001, 0x4000540001, 0x4000560001, - 0x3fff900001, 0x4000720001, 0x3fff8e0001, 0x4000800001, - 0x40008a0001, 0x3fff6c0001, 0x40009e0001, 0x3fff300001, - 0x3fff1c0001, 0x4000fc0001}, // 46 + 17 x 38 - P: []uint64{0x2000000a0001, 0x2000000e0001, 0x1fffffc20001}, // 3 x 45 - RingType: ring.ConjugateInvariant, - LogScale: 38, - } - // PN16QP1654CIpq is a default (post quantum) parameter set for logN=16 and logQP=1654 - PN16QP1654CIpq = ParametersLiteral{ - LogN: 16, - Q: []uint64{0x80000000080001, 0x200000440001, 0x200000500001, 0x1fffff980001, - 0x200000c80001, 0x1ffffeb40001, 0x1ffffe640001, 0x200001a00001, - 0x200001e80001, 0x1ffffe0c0001, 0x200002480001, 0x200002800001, - 0x1ffffd800001, 0x200002900001, 0x1ffffd700001, 0x2000029c0001, - 0x1ffffcf00001, 0x200003140001, 0x1ffffcc80001, 0x1ffffcb40001, - 0x1ffffc980001, 0x200003740001, 0x200003800001, 0x200003d40001, - 0x1ffffc200001, 0x1ffffc140001, 0x200004100001, 0x200004180001, - 0x1ffffbc40001, 0x200004700001, 0x1ffffb900001, 0x200004cc0001}, // 55 + 31 x 45 - P: []uint64{0x80000001c0001, 0x80000002c0001, 0x8000000500001, 0x7ffffff9c0001}, // 4 x 51 - RingType: ring.ConjugateInvariant, - LogScale: 45, - } -) +var () // ParametersLiteral is a literal representation of CKKS parameters. It has public // fields and is used to express unchecked user-defined parameters literally into @@ -277,7 +41,6 @@ type ParametersLiteral struct { Xe distribution.Distribution Xs distribution.Distribution RingType ring.Type - LogSlots int LogScale int } @@ -298,28 +61,15 @@ func (p ParametersLiteral) RLWEParametersLiteral() rlwe.ParametersLiteral { } } -// DefaultParams is a set of default CKKS parameters ensuring 128 bit security in a classic setting. -var DefaultParams = []ParametersLiteral{PN12QP109, PN13QP218, PN14QP438, PN15QP880, PN16QP1761} - -// DefaultConjugateInvariantParams is a set of default conjugate invariant parameters for encrypting real values and ensuring 128 bit security in a classic setting. -var DefaultConjugateInvariantParams = []ParametersLiteral{PN12QP109CI, PN13QP218CI, PN14QP438CI, PN15QP880CI, PN16QP1761CI} - -// DefaultPostQuantumParams is a set of default CKKS parameters ensuring 128 bit security in a post-quantum setting. -var DefaultPostQuantumParams = []ParametersLiteral{PN12QP101pq, PN13QP202pq, PN14QP411pq, PN15QP827pq, PN16QP1654pq} - -// DefaultPostQuantumConjugateInvariantParams is a set of default conjugate invariant parameters for encrypting real values and ensuring 128 bit security in a post-quantum setting. -var DefaultPostQuantumConjugateInvariantParams = []ParametersLiteral{PN12QP101CIpq, PN13QP202CIpq, PN14QP411CIpq, PN15QP827CIpq, PN16QP1654CIpq} - // Parameters represents a parameter set for the CKKS cryptosystem. Its fields are private and // immutable. See ParametersLiteral for user-specified parameters. type Parameters struct { rlwe.Parameters - logSlots int } // NewParameters instantiate a set of CKKS parameters from the generic RLWE parameters and the CKKS-specific ones. // It returns the empty parameters Parameters{} and a non-nil error if the specified parameters are invalid. -func NewParameters(rlweParams rlwe.Parameters, logSlots int) (p Parameters, err error) { +func NewParameters(rlweParams rlwe.Parameters) (p Parameters, err error) { if !rlweParams.DefaultNTTFlag() { return Parameters{}, fmt.Errorf("provided RLWE parameters are invalid for CKKS scheme (DefaultNTTFlag must be true)") @@ -329,11 +79,7 @@ func NewParameters(rlweParams rlwe.Parameters, logSlots int) (p Parameters, err return Parameters{}, fmt.Errorf("provided RLWE parameters are invalid") } - if maxLogSlots := bits.Len64(rlweParams.RingQ().NthRoot()) - 3; logSlots > maxLogSlots || logSlots < minLogSlots { - return Parameters{}, fmt.Errorf("logSlot=%d is larger than the logN-1=%d or smaller than %d", logSlots, maxLogSlots, minLogSlots) - } - - return Parameters{rlweParams, logSlots}, nil + return Parameters{rlweParams}, nil } // NewParametersFromLiteral instantiate a set of CKKS parameters from a ParametersLiteral specification. @@ -349,16 +95,7 @@ func NewParametersFromLiteral(pl ParametersLiteral) (Parameters, error) { return Parameters{}, err } - if pl.LogSlots == 0 { - switch pl.RingType { - case ring.Standard: - pl.LogSlots = pl.LogN - 1 - case ring.ConjugateInvariant: - pl.LogSlots = pl.LogN - } - } - - return NewParameters(rlweParams, pl.LogSlots) + return NewParameters(rlweParams) } // StandardParameters returns the CKKS parameters corresponding to the receiver @@ -384,13 +121,31 @@ func (p Parameters) ParametersLiteral() (pLit ParametersLiteral) { Xs: p.Xs(), RingType: p.RingType(), LogScale: int(math.Round(math.Log2(p.DefaultScale().Float64()))), - LogSlots: p.LogSlots(), } } -// LogSlots returns the log of the number of slots -func (p Parameters) LogSlots() int { - return p.logSlots +// DefaultPrecision returns the default precision in bits of the plaintext values which +// is max(53, log2(DefaultScale)). +func (p Parameters) DefaultPrecision() (prec uint) { + if log2scale := math.Log2(p.DefaultScale().Float64()); log2scale <= 53 { + prec = 53 + } else { + prec = uint(log2scale) + } + + return +} + +// MaxDepth returns MaxLevel / DefaultScaleModuliRatio which is the maximum number of multiplicaitons +// followed by a rescaling that can be carried out with on a ciphertext with the DefaultScale. +func (p Parameters) MaxDepth() int { + return p.MaxLevel() / p.DefaultScaleModuliRatio() +} + +// DefaultScaleModuliRatio returns the default ratio between the scaling factor and moduli. +// This default ratio is computed as ceil(DefaultScalingFactor/2^{60}). +func (p Parameters) DefaultScaleModuliRatio() int { + return int(math.Ceil(math.Log2(p.DefaultScale().Float64()) / 60.0)) } // MaxLevel returns the maximum ciphertext level @@ -398,11 +153,6 @@ func (p Parameters) MaxLevel() int { return p.QCount() - 1 } -// Slots returns number of available plaintext slots -func (p Parameters) Slots() int { - return 1 << p.logSlots -} - // MaxSlots returns the theoretical maximum of plaintext slots allowed by the ring degree func (p Parameters) MaxSlots() int { switch p.RingType() { @@ -435,9 +185,9 @@ func (p Parameters) LogQLvl(level int) int { // QLvl returns the product of the moduli at the given level as a big.Int func (p Parameters) QLvl(level int) *big.Int { - tmp := ring.NewUint(1) + tmp := bignum.NewInt(1) for _, qi := range p.Q()[:level+1] { - tmp.Mul(tmp, ring.NewUint(qi)) + tmp.Mul(tmp, bignum.NewInt(qi)) } return tmp } @@ -471,10 +221,9 @@ func (p Parameters) GaloisElementsForLinearTransform(nonZeroDiags interface{}, l return } -// Equal compares two sets of parameters for equality. -func (p Parameters) Equal(other Parameters) bool { - res := p.Parameters.Equal(other.Parameters) - res = res && (p.logSlots == other.LogSlots()) +// Equals compares two sets of parameters for equality. +func (p Parameters) Equals(other Parameters) bool { + res := p.Parameters.Equals(other.Parameters) return res } diff --git a/ckks/polynomial_evaluation.go b/ckks/polynomial_evaluation.go index 62c01530..648b8c1a 100644 --- a/ckks/polynomial_evaluation.go +++ b/ckks/polynomial_evaluation.go @@ -9,50 +9,36 @@ import ( "github.com/tuneinsight/lattigo/v4/rlwe" "github.com/tuneinsight/lattigo/v4/utils" - "github.com/tuneinsight/lattigo/v4/utils/bignum/polynomial" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) -// Polynomial is a struct storing the coefficients of a polynomial -// that then can be evaluated on the ciphertext -type Polynomial struct { - polynomial.Basis // Either `Monomial` or `Chebyshev` - MaxDeg int // Always set to len(Coeffs)-1 - Coeffs []complex128 // List of coefficients - Lead bool // Always set to true - A float64 // Bound A of the interval [A, B] - B float64 // Bound B of the interval [A, B] - Lazy bool // Flag for lazy-relinearization +type polynomial struct { + *bignum.Polynomial + Prec uint + MaxDeg int // Always set to len(Coeffs)-1 + Lead bool // Always set to true + Lazy bool // Flag for lazy-relinearization } -// IsNegligibleThreshold : threshold under which a coefficient -// of a polynomial is ignored. -const IsNegligibleThreshold float64 = 1e-14 - -// Depth returns the number of levels needed to evaluate the polynomial. -func (p *Polynomial) Depth() int { - return int(math.Ceil(math.Log2(float64(len(p.Coeffs))))) +func newPolynomial(poly *bignum.Polynomial, prec uint) (p *polynomial) { + return &polynomial{ + Polynomial: poly, + MaxDeg: poly.Degree(), + Lead: true, + Prec: prec, + } } -// Degree returns the degree of the polynomial -func (p *Polynomial) Degree() int { - return len(p.Coeffs) - 1 -} - -// NewPoly creates a new Poly from the input coefficients -func NewPoly(coeffs []complex128) (p *Polynomial) { - c := make([]complex128, len(coeffs)) - copy(c, coeffs) - return &Polynomial{Coeffs: c, MaxDeg: len(c) - 1, Lead: true} +type polynomialVector struct { + Encoder *Encoder + Value []*polynomial + SlotsIndex map[int][]int } // checkEnoughLevels checks that enough levels are available to evaluate the polynomial. // Also checks if c is a Gaussian integer or not. If not, then one more level is needed // to evaluate the polynomial. -func checkEnoughLevels(levels, depth int, c complex128) (err error) { - - if real(c) != float64(int64(real(c))) || imag(c) != float64(int64(imag(c))) { - depth++ - } +func checkEnoughLevels(levels, depth int) (err error) { if levels < depth { return fmt.Errorf("%d levels < %d log(d) -> cannot evaluate", levels, depth) @@ -63,8 +49,8 @@ func checkEnoughLevels(levels, depth int, c complex128) (err error) { type polynomialEvaluator struct { Evaluator - Encoder - PowerBasis + *Encoder + PolynomialBasis slotsIndex map[int][]int logDegree int logSplit int @@ -77,21 +63,12 @@ type polynomialEvaluator struct { // Returns an error if something is wrong with the scale. // If the polynomial is given in Chebyshev basis, then a change of basis ct' = (2/(b-a)) * (ct + (-a-b)/(b-a)) // is necessary before the polynomial evaluation to ensure correctness. -// Coefficients of the polynomial with an absolute value smaller than "IsNegligibleThreshold" will automatically be set to zero -// if the polynomial is "even" or "odd" (to ensure that the even or odd property remains valid -// after the "splitCoeffs" polynomial decomposition). -// input must be either *rlwe.Ciphertext or *PowerBasis. +// input must be either *rlwe.Ciphertext or *PolynomialBasis. // pol: a *Polynomial // targetScale: the desired output scale. This value shouldn't differ too much from the original ciphertext scale. It can // for example be used to correct small deviations in the ciphertext scale and reset it to the default scale. -func (eval *evaluator) EvaluatePoly(input interface{}, pol *Polynomial, targetScale rlwe.Scale) (opOut *rlwe.Ciphertext, err error) { - return eval.evaluatePolyVector(input, polynomialVector{Value: []*Polynomial{pol}}, targetScale) -} - -type polynomialVector struct { - Encoder Encoder - Value []*Polynomial - SlotsIndex map[int][]int +func (eval *evaluator) EvaluatePoly(input interface{}, poly *bignum.Polynomial, targetScale rlwe.Scale) (opOut *rlwe.Ciphertext, err error) { + return eval.evaluatePolyVector(input, polynomialVector{Value: []*polynomial{newPolynomial(poly, eval.params.DefaultPrecision())}}, targetScale) } // EvaluatePolyVector evaluates a vector of Polynomials on the input Ciphertext in ceil(log2(deg+1)) levels. @@ -101,10 +78,7 @@ type polynomialVector struct { // Returns an error if polynomials do not all have the same degree. // If the polynomials are given in Chebyshev basis, then a change of basis ct' = (2/(b-a)) * (ct + (-a-b)/(b-a)) // is necessary before the polynomial evaluation to ensure correctness. -// Coefficients of the polynomial with an absolute value smaller than "IsNegligibleThreshold" will automatically be set to zero -// if the polynomial is "even" or "odd" (to ensure that the even or odd property remains valid -// after the "splitCoeffs" polynomial decomposition). -// input: must be either *rlwe.Ciphertext or *PowerBasis. +// input: must be either *rlwe.Ciphertext or *PolynomialBasis. // pols: a slice of up to 'n' *Polynomial ('n' being the maximum number of slots), indexed from 0 to n-1. // encoder: an Encoder. // slotsIndex: a map[int][]int indexing as key the polynomial to evaluate and as value the index of the slots on which to evaluate the polynomial indexed by the key. @@ -113,25 +87,32 @@ type polynomialVector struct { // // Example: if pols = []*Polynomial{pol0, pol1} and slotsIndex = map[int][]int:{0:[1, 2, 4, 5, 7], 1:[0, 3]}, // then pol0 will be applied to slots [1, 2, 4, 5, 7], pol1 to slots [0, 3] and the slot 6 will be zero-ed. -func (eval *evaluator) EvaluatePolyVector(input interface{}, pols []*Polynomial, encoder Encoder, slotsIndex map[int][]int, targetScale rlwe.Scale) (opOut *rlwe.Ciphertext, err error) { +func (eval *evaluator) EvaluatePolyVector(input interface{}, polys []*bignum.Polynomial, encoder *Encoder, slotsIndex map[int][]int, targetScale rlwe.Scale) (opOut *rlwe.Ciphertext, err error) { var maxDeg int - var basis polynomial.Basis - for i := range pols { - maxDeg = utils.Max(maxDeg, pols[i].MaxDeg) - basis = pols[i].Basis + var basis bignum.BasisType + for i := range polys { + maxDeg = utils.MaxInt(maxDeg, polys[i].Degree()) + basis = polys[i].BasisType } - for i := range pols { - if basis != pols[i].Basis { + for i := range polys { + if basis != polys[i].BasisType { return nil, fmt.Errorf("polynomial basis must be the same for all polynomials in a polynomial vector") } - if maxDeg != pols[i].MaxDeg { + if maxDeg != polys[i].Degree() { return nil, fmt.Errorf("polynomial degree must all be the same") } } - return eval.evaluatePolyVector(input, polynomialVector{Encoder: encoder, Value: pols, SlotsIndex: slotsIndex}, targetScale) + polyvec := make([]*polynomial, len(polys)) + + prec := eval.params.DefaultPrecision() + for i := range polys { + polyvec[i] = newPolynomial(polys[i], prec) + } + + return eval.evaluatePolyVector(input, polynomialVector{Encoder: encoder, Value: polyvec, SlotsIndex: slotsIndex}, targetScale) } func optimalSplit(logDegree int) (logSplit int) { @@ -164,7 +145,9 @@ func (eval *evaluator) evaluatePolyVector(input interface{}, pol polynomialVecto return nil, fmt.Errorf("cannot evaluatePolyVector: invalid input, must be either *rlwe.Ciphertext or *PowerBasis") } - if err := checkEnoughLevels(monomialBasis.Value[1].Level(), pol.Value[0].Depth(), 1); err != nil { + nbModuliPerRescale := eval.params.DefaultScaleModuliRatio() + + if err := checkEnoughLevels(monomialBasis.Value[1].Level(), nbModuliPerRescale*pol.Value[0].Depth()); err != nil { return nil, err } @@ -173,8 +156,7 @@ func (eval *evaluator) evaluatePolyVector(input interface{}, pol polynomialVecto var odd, even bool = true, true for _, p := range pol.Value { - tmp0, tmp1 := isOddOrEvenPolynomial(p.Coeffs) - odd, even = odd && tmp0, even && tmp1 + odd, even = odd && p.IsOdd, even && p.IsEven } // Computes all the powers of two with relinearization @@ -202,11 +184,13 @@ func (eval *evaluator) evaluatePolyVector(input interface{}, pol polynomialVecto polyEval.isOdd = odd polyEval.isEven = even - if opOut, err = polyEval.recurse(monomialBasis.Value[1].Level()-logDegree+1, targetScale, pol); err != nil { + if opOut, err = polyEval.recurse(monomialBasis.Value[1].Level()-nbModuliPerRescale*(logDegree-1), targetScale, pol); err != nil { return nil, err } - polyEval.Relinearize(opOut, opOut) + if opOut.Degree() == 2 { + polyEval.Relinearize(opOut, opOut) + } if err = polyEval.Rescale(opOut, targetScale, opOut); err != nil { return nil, err @@ -219,11 +203,170 @@ func (eval *evaluator) evaluatePolyVector(input interface{}, pol polynomialVecto return opOut, err } -func splitCoeffs(coeffs *Polynomial, split int) (coeffsq, coeffsr *Polynomial) { +// PolynomialBasis is a struct storing powers of a ciphertext. +type PolynomialBasis struct { + bignum.BasisType + Value map[int]*rlwe.Ciphertext +} + +// NewPolynomialBasis creates a new PolynomialBasis. It takes as input a ciphertext +// and a basistype. The struct treats the input ciphertext as a monomial X and +// can be used to generates power of this monomial X^{n} in the given BasisType. +func NewPolynomialBasis(ct *rlwe.Ciphertext, basistype bignum.BasisType) (p *PolynomialBasis) { + p = new(PolynomialBasis) + p.Value = make(map[int]*rlwe.Ciphertext) + p.Value[1] = ct.CopyNew() + p.BasisType = basistype + return +} + +// GenPower recursively computes X^{n}. +// If lazy = true, the final X^{n} will not be relinearized. +// Previous non-relinearized X^{n} that are required to compute the target X^{n} are automatically relinearized. +// Scale sets the threshold for rescaling (ciphertext won't be rescaled if the rescaling operation would make the scale go under this threshold). +func (p *PolynomialBasis) GenPower(n int, lazy bool, scale rlwe.Scale, eval Evaluator) (err error) { + + if p.Value[n] == nil { + if err = p.genPower(n, lazy, scale, eval); err != nil { + return + } + + if err = eval.Rescale(p.Value[n], scale, p.Value[n]); err != nil { + return + } + } + + return nil +} + +func (p *PolynomialBasis) genPower(n int, lazy bool, scale rlwe.Scale, eval Evaluator) (err error) { + + if p.Value[n] == nil { + + isPow2 := n&(n-1) == 0 + + // Computes the index required to compute the asked ring evaluation + var a, b, c int + if isPow2 { + a, b = n/2, n/2 //Necessary for optimal depth + } else { + // [Lee et al. 2020] : High-Precision and Low-Complexity Approximate Homomorphic Encryption by Error Variance Minimization + // Maximize the number of odd terms of Chebyshev basis + k := int(math.Ceil(math.Log2(float64(n)))) - 1 + a = (1 << k) - 1 + b = n + 1 - (1 << k) + + if p.BasisType == bignum.Chebyshev { + c = int(math.Abs(float64(a) - float64(b))) // Cn = 2*Ca*Cb - Cc, n = a+b and c = abs(a-b) + } + } + + // Recurses on the given indexes + if err = p.genPower(a, lazy && !isPow2, scale, eval); err != nil { + return err + } + if err = p.genPower(b, lazy && !isPow2, scale, eval); err != nil { + return err + } + + // Computes C[n] = C[a]*C[b] + if lazy { + if p.Value[a].Degree() == 2 { + eval.Relinearize(p.Value[a], p.Value[a]) + } + + if p.Value[b].Degree() == 2 { + eval.Relinearize(p.Value[b], p.Value[b]) + } + + if err = eval.Rescale(p.Value[a], scale, p.Value[a]); err != nil { + return err + } + + if err = eval.Rescale(p.Value[b], scale, p.Value[b]); err != nil { + return err + } + + p.Value[n] = eval.MulNew(p.Value[a], p.Value[b]) + + } else { + + if err = eval.Rescale(p.Value[a], scale, p.Value[a]); err != nil { + return err + } + + if err = eval.Rescale(p.Value[b], scale, p.Value[b]); err != nil { + return err + } + + p.Value[n] = eval.MulRelinNew(p.Value[a], p.Value[b]) + } + + if p.BasisType == bignum.Chebyshev { + + // Computes C[n] = 2*C[a]*C[b] + eval.Add(p.Value[n], p.Value[n], p.Value[n]) + + // Computes C[n] = 2*C[a]*C[b] - C[c] + if c == 0 { + eval.Add(p.Value[n], -1, p.Value[n]) + } else { + // Since C[0] is not stored (but rather seen as the constant 1), only recurses on c if c!= 0 + if err = p.GenPower(c, lazy, scale, eval); err != nil { + return err + } + eval.Sub(p.Value[n], p.Value[c], p.Value[n]) + } + } + } + return +} + +// MarshalBinary encodes the target on a slice of bytes. +func (p *PolynomialBasis) MarshalBinary() (data []byte, err error) { + data = make([]byte, 16) + binary.LittleEndian.PutUint64(data[0:8], uint64(len(p.Value))) + binary.LittleEndian.PutUint64(data[8:16], uint64(p.Value[1].MarshalBinarySize())) + for key, ct := range p.Value { + keyBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(keyBytes, uint64(key)) + data = append(data, keyBytes...) + ctBytes, err := ct.MarshalBinary() + if err != nil { + return []byte{}, err + } + data = append(data, ctBytes...) + } + return +} + +// UnmarshalBinary decodes a slice of bytes on the target. +func (p *PolynomialBasis) UnmarshalBinary(data []byte) (err error) { + p.Value = make(map[int]*rlwe.Ciphertext) + nbct := int(binary.LittleEndian.Uint64(data[0:8])) + dtLen := int(binary.LittleEndian.Uint64(data[8:16])) + ptr := 16 + for i := 0; i < nbct; i++ { + idx := int(binary.LittleEndian.Uint64(data[ptr : ptr+8])) + ptr += 8 + p.Value[idx] = new(rlwe.Ciphertext) + if err = p.Value[idx].UnmarshalBinary(data[ptr : ptr+dtLen]); err != nil { + return + } + ptr += dtLen + } + return +} + +// splitCoeffs splits coeffs as X^{2n} * coeffsq + coeffsr. +// This function is sensitive to the precision of the coefficients. +func splitCoeffs(coeffs *polynomial, split int) (coeffsq, coeffsr *polynomial) { + + prec := coeffs.Prec // Splits a polynomial p such that p = q*C^degree + r. - coeffsr = &Polynomial{} - coeffsr.Coeffs = make([]complex128, split) + coeffsr = &polynomial{Polynomial: &bignum.Polynomial{}} + coeffsr.Coeffs = make([]*bignum.Complex, split) if coeffs.MaxDeg == coeffs.Degree() { coeffsr.MaxDeg = split - 1 } else { @@ -231,23 +374,49 @@ func splitCoeffs(coeffs *Polynomial, split int) (coeffsq, coeffsr *Polynomial) { } for i := 0; i < split; i++ { - coeffsr.Coeffs[i] = coeffs.Coeffs[i] + if coeffs.Coeffs[i] != nil { + coeffsr.Coeffs[i] = coeffs.Coeffs[i].Copy() + coeffsr.Coeffs[i].SetPrec(prec) + } + } - coeffsq = &Polynomial{} - coeffsq.Coeffs = make([]complex128, coeffs.Degree()-split+1) + coeffsq = &polynomial{Polynomial: &bignum.Polynomial{}} + coeffsq.Coeffs = make([]*bignum.Complex, coeffs.Degree()-split+1) coeffsq.MaxDeg = coeffs.MaxDeg - coeffsq.Coeffs[0] = coeffs.Coeffs[split] + if coeffs.Coeffs[split] != nil { + coeffsq.Coeffs[0] = coeffs.Coeffs[split].Copy() + } - if coeffs.Basis == polynomial.Monomial { + odd := coeffs.IsOdd + even := coeffs.IsEven + + switch coeffs.BasisType { + case bignum.Monomial: for i := split + 1; i < coeffs.Degree()+1; i++ { - coeffsq.Coeffs[i-split] = coeffs.Coeffs[i] + if coeffs.Coeffs[i] != nil && (!(even || odd) || (i&1 == 0 && even) || (i&1 == 1 && odd)) { + coeffsq.Coeffs[i-split] = coeffs.Coeffs[i].Copy() + coeffsr.Coeffs[i-split].SetPrec(prec) + } } - } else if coeffs.Basis == polynomial.Chebyshev { + case bignum.Chebyshev: + for i, j := split+1, 1; i < coeffs.Degree()+1; i, j = i+1, j+1 { - coeffsq.Coeffs[i-split] = 2 * coeffs.Coeffs[i] - coeffsr.Coeffs[split-j] -= coeffs.Coeffs[i] + if coeffs.Coeffs[i] != nil && (!(even || odd) || (i&1 == 0 && even) || (i&1 == 1 && odd)) { + coeffsq.Coeffs[i-split] = coeffs.Coeffs[i].Copy() + coeffsr.Coeffs[i-split].SetPrec(prec) + coeffsq.Coeffs[i-split].Add(coeffsq.Coeffs[i-split], coeffsq.Coeffs[i-split]) + + if coeffsr.Coeffs[split-j] != nil { + coeffsr.Coeffs[split-j].Sub(coeffsr.Coeffs[split-j], coeffs.Coeffs[i]) + } else { + coeffsr.Coeffs[split-j] = coeffs.Coeffs[i].Copy() + coeffsr.Coeffs[split-j].SetPrec(prec) + coeffsr.Coeffs[split-j][0].Neg(coeffsr.Coeffs[split-j][0]) + coeffsr.Coeffs[split-j][1].Neg(coeffsr.Coeffs[split-j][1]) + } + } } } @@ -255,14 +424,17 @@ func splitCoeffs(coeffs *Polynomial, split int) (coeffsq, coeffsr *Polynomial) { coeffsq.Lead = true } - coeffsq.Basis, coeffsr.Basis = coeffs.Basis, coeffs.Basis + coeffsq.BasisType, coeffsr.BasisType = coeffs.BasisType, coeffs.BasisType + coeffsq.IsOdd, coeffsr.IsOdd = coeffs.IsOdd, coeffs.IsOdd + coeffsq.IsEven, coeffsr.IsEven = coeffs.IsEven, coeffs.IsEven + coeffsq.Prec, coeffsr.Prec = prec, prec return } func splitCoeffsPolyVector(poly polynomialVector, split int) (polyq, polyr polynomialVector) { - coeffsq := make([]*Polynomial, len(poly.Value)) - coeffsr := make([]*Polynomial, len(poly.Value)) + coeffsq := make([]*polynomial, len(poly.Value)) + coeffsr := make([]*polynomial, len(poly.Value)) for i, p := range poly.Value { coeffsq[i], coeffsr[i] = splitCoeffs(p, split) } @@ -276,6 +448,8 @@ func (polyEval *polynomialEvaluator) recurse(targetLevel int, targetScale rlwe.S logSplit := polyEval.logSplit + nbModuliPerRescale := params.DefaultScaleModuliRatio() + // Recursively computes the evaluation of the Chebyshev polynomial using a baby-set giant-step algorithm. if pol.Value[0].Degree() < (1 << logSplit) { @@ -298,7 +472,12 @@ func (polyEval *polynomialEvaluator) recurse(targetLevel int, targetScale rlwe.S } if pol.Value[0].Lead { - targetScale = targetScale.Mul(rlwe.NewScale(params.QiFloat64(targetLevel))) + + targetScale = targetScale.Mul(rlwe.NewScale(params.Q()[targetLevel])) + + for i := 1; i < nbModuliPerRescale; i++ { + targetScale = targetScale.Mul(rlwe.NewScale(params.Q()[targetLevel-i])) + } } return polyEval.evaluatePolyFromPowerBasis(targetScale, targetLevel, pol) @@ -315,20 +494,25 @@ func (polyEval *polynomialEvaluator) recurse(targetLevel int, targetScale rlwe.S level := targetLevel - var currentQi float64 + var qi *big.Int if pol.Value[0].Lead { - currentQi = params.QiFloat64(level) + qi = bignum.NewInt(params.Q()[level]) + for i := 1; i < nbModuliPerRescale; i++ { + qi.Mul(qi, bignum.NewInt(params.Q()[level-i])) + } } else { - currentQi = params.QiFloat64(level + 1) + qi = bignum.NewInt(params.Q()[level+nbModuliPerRescale]) + for i := 1; i < nbModuliPerRescale; i++ { + qi.Mul(qi, bignum.NewInt(params.Q()[level+nbModuliPerRescale-i])) + } } - targetScale = targetScale.Mul(rlwe.NewScale(currentQi)) + targetScale = targetScale.Mul(rlwe.NewScale(qi)) targetScale = targetScale.Div(XPow.Scale) - if res, err = polyEval.recurse(targetLevel+1, targetScale, coeffsq); err != nil { + if res, err = polyEval.recurse(targetLevel+nbModuliPerRescale, targetScale, coeffsq); err != nil { return nil, err } - if res.Degree() == 2 { polyEval.Relinearize(res, res) } @@ -353,18 +537,25 @@ func (polyEval *polynomialEvaluator) recurse(targetLevel int, targetScale rlwe.S func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe.Scale, level int, pol polynomialVector) (res *rlwe.Ciphertext, err error) { - X := polyEval.PowerBasis.Value + // Map[int] of the powers [X^{0}, X^{1}, X^{2}, ...] + X := polyEval.PolynomialBasis.Value + + // Retrieve the number of slots + logSlots := X[1].LogSlots + slots := 1 << X[1].LogSlots params := polyEval.Evaluator.(*evaluator).params slotsIndex := polyEval.slotsIndex + // Retrieve the degree of the highest degree non-zero coefficient + // TODO: optimize for nil/zero coefficients minimumDegreeNonZeroCoefficient := len(pol.Value[0].Coeffs) - 1 - - if polyEval.isEven { + if polyEval.isEven && !polyEval.isOdd { minimumDegreeNonZeroCoefficient-- } - // Get the minimum non-zero degree coefficient + // Gets the maximum degree of the ciphertexts among the power basis + // TODO: optimize for nil/zero coefficients, odd/even polynomial maximumCiphertextDegree := 0 for i := pol.Value[0].Degree(); i > 0; i-- { if x, ok := X[i]; ok { @@ -372,13 +563,17 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe } } + // Retrieve flags for even/odd + even := polyEval.isEven + odd := polyEval.isOdd + // If an index slot is given (either multiply polynomials or masking) if slotsIndex != nil { var toEncode bool // Allocates temporary buffer for coefficients encoding - values := make([]complex128, params.Slots()) + values := make([]*bignum.Complex, slots) // If the degree of the poly is zero if minimumDegreeNonZeroCoefficient == 0 { @@ -386,10 +581,11 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe // Allocates the output ciphertext res = NewCiphertext(params, 1, level) res.Scale = targetScale + res.LogSlots = logSlots // Looks for non-zero coefficients among the degree 0 coefficients of the polynomials for i, p := range pol.Value { - if isNotNegligible(p.Coeffs[0]) { + if !isZero(p.Coeffs[0]) { toEncode = true for _, j := range slotsIndex[i] { values[j] = p.Coeffs[0] @@ -400,9 +596,10 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe // If a non-zero coefficient was found, encode the values, adds on the ciphertext, and returns if toEncode { pt := rlwe.NewPlaintextAtLevelFromPoly(level, res.Value[0]) + pt.LogSlots = logSlots pt.IsNTT = true pt.Scale = targetScale - polyEval.EncodeSlots(values, pt, params.LogSlots()) + polyEval.Encode(values, pt) } return @@ -411,14 +608,16 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe // Allocates the output ciphertext res = NewCiphertext(params, maximumCiphertextDegree, level) res.Scale = targetScale + res.LogSlots = logSlots // Allocates a temporary plaintext to encode the values pt := rlwe.NewPlaintextAtLevelFromPoly(level, polyEval.Evaluator.BuffCt().Value[0]) pt.IsNTT = true + pt.LogSlots = logSlots // Looks for a non-zero coefficient among the degree zero coefficient of the polynomials for i, p := range pol.Value { - if isNotNegligible(p.Coeffs[0]) { + if !isZero(p.Coeffs[0]) { toEncode = true for _, j := range slotsIndex[i] { values[j] = p.Coeffs[0] @@ -430,7 +629,7 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe // ciphertext if toEncode { pt.Scale = targetScale - polyEval.EncodeSlots(values, pt, params.LogSlots()) + polyEval.Encode(values, pt) polyEval.Add(res, pt, res) toEncode = false } @@ -443,7 +642,7 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe for i, p := range pol.Value { // Looks for a non-zero coefficient - if isNotNegligible(p.Coeffs[key]) { + if !isZero(p.Coeffs[key]) { toEncode = true // Resets the temporary array to zero @@ -452,7 +651,10 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe // coefficient if !reset { for j := range values { - values[j] = 0 + if values[j] != nil { + values[j][0].SetFloat64(0) + values[j][1].SetFloat64(0) + } } reset = true } @@ -469,7 +671,7 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe // ciphertext if toEncode { pt.Scale = targetScale.Div(X[key].Scale) - polyEval.EncodeSlots(values, pt, params.LogSlots()) + polyEval.Encode(values, pt) polyEval.MulThenAdd(X[key], pt, res) toEncode = false } @@ -477,15 +679,19 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe } else { - c := pol.Value[0].Coeffs[0] + var c *bignum.Complex + if polyEval.isEven && !isZero(pol.Value[0].Coeffs[0]) { + c = pol.Value[0].Coeffs[0] + } if minimumDegreeNonZeroCoefficient == 0 { res = NewCiphertext(params, 1, level) res.Scale = targetScale + res.LogSlots = logSlots - if isNotNegligible(c) { - polyEval.AddConst(res, c, res) + if !isZero(c) { + polyEval.Add(res, c, res) } return @@ -493,28 +699,25 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe res = NewCiphertext(params, maximumCiphertextDegree, level) res.Scale = targetScale + res.LogSlots = logSlots - if isNotNegligible(c) { - polyEval.AddConst(res, c, res) + if c != nil { + polyEval.Add(res, c, res) } - constScale := new(big.Float).SetPrec(scalingPrecision) + constScale := new(big.Float).SetPrec(pol.Value[0].Prec) ringQ := params.RingQ().AtLevel(level) for key := pol.Value[0].Degree(); key > 0; key-- { - c = pol.Value[0].Coeffs[key] - - if key != 0 && isNotNegligible(c) { + if c = pol.Value[0].Coeffs[key]; key != 0 && !isZero(c) && (!(even || odd) || (key&1 == 0 && even) || (key&1 == 1 && odd)) { XScale := X[key].Scale.Value tgScale := targetScale.Value constScale.Quo(&tgScale, &XScale) - cmplxBig := valueToBigComplex(c, scalingPrecision) - - RNSReal, RNSImag := bigComplexToRNSScalar(ringQ, constScale, cmplxBig) + RNSReal, RNSImag := bigComplexToRNSScalar(ringQ, constScale, bignum.ToComplex(c, pol.Value[0].Prec)) polyEval.Evaluator.(*evaluator).evaluateWithScalar(level, X[key].Value, RNSReal, RNSImag, res.Value, ringQ.MulDoubleRNSScalarThenAdd) } @@ -524,32 +727,7 @@ func (polyEval *polynomialEvaluator) evaluatePolyFromPowerBasis(targetScale rlwe return } -func isNotNegligible(c complex128) bool { - return (math.Abs(real(c)) > IsNegligibleThreshold || math.Abs(imag(c)) > IsNegligibleThreshold) -} - -func isOddOrEvenPolynomial(coeffs []complex128) (odd, even bool) { - even = true - odd = true - for i, c := range coeffs { - isnotnegligible := isNotNegligible(c) - odd = odd && !(i&1 == 0 && isnotnegligible) - even = even && !(i&1 == 1 && isnotnegligible) - if !odd && !even { - break - } - } - - // If even or odd, then sets the expected zero coefficients to zero - if even || odd { - var start int - if even { - start = 1 - } - for i := start; i < len(coeffs); i += 2 { - coeffs[i] = complex(0, 0) - } - } - - return +func isZero(c *bignum.Complex) bool { + zero := new(big.Float) + return c == nil || (c[0].Cmp(zero) == 0 && c[1].Cmp(zero) == 0) } diff --git a/ckks/precision.go b/ckks/precision.go index e7a97225..d861866b 100644 --- a/ckks/precision.go +++ b/ckks/precision.go @@ -3,10 +3,12 @@ package ckks import ( "fmt" "math" + "math/big" "sort" "github.com/tuneinsight/lattigo/v4/ring/distribution" "github.com/tuneinsight/lattigo/v4/rlwe" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // PrecisionStats is a struct storing statistic about the precision of a CKKS plaintext @@ -19,11 +21,9 @@ type PrecisionStats struct { MeanPrecision Stats MedianDelta Stats MedianPrecision Stats - STDFreq float64 - STDTime float64 RealDist, ImagDist, L2Dist []struct { - Prec float64 + Prec *big.Float Count int } @@ -33,7 +33,7 @@ type PrecisionStats struct { // Stats is a struct storing the real, imaginary and L2 norm (modulus) // about the precision of a complex value. type Stats struct { - Real, Imag, L2 float64 + Real, Imag, L2 *big.Float } func (prec PrecisionStats) String() string { @@ -46,143 +46,191 @@ func (prec PrecisionStats) String() string { │AVG Prec │ %5.2f │ %5.2f │ %5.2f │ │MED Prec │ %5.2f │ %5.2f │ %5.2f │ └─────────┴───────┴───────┴───────┘ -Err STD Slots : %5.2f Log2 -Err STD Coeffs : %5.2f Log2 `, prec.MinPrecision.Real, prec.MinPrecision.Imag, prec.MinPrecision.L2, prec.MaxPrecision.Real, prec.MaxPrecision.Imag, prec.MaxPrecision.L2, prec.MeanPrecision.Real, prec.MeanPrecision.Imag, prec.MeanPrecision.L2, - prec.MedianPrecision.Real, prec.MedianPrecision.Imag, prec.MedianPrecision.L2, - math.Log2(prec.STDFreq), - math.Log2(prec.STDTime)) + prec.MedianPrecision.Real, prec.MedianPrecision.Imag, prec.MedianPrecision.L2) } // GetPrecisionStats generates a PrecisionStats struct from the reference values and the decrypted values // vWant.(type) must be either []complex128 or []float64 // element.(type) must be either *Plaintext, *Ciphertext, []complex128 or []float64. If not *Ciphertext, then decryptor can be nil. -func GetPrecisionStats(params Parameters, encoder Encoder, decryptor rlwe.Decryptor, vWant, element interface{}, logSlots int, noise distribution.Distribution) (prec PrecisionStats) { +func GetPrecisionStats(params Parameters, encoder *Encoder, decryptor rlwe.Decryptor, want, have interface{}, noise distribution.Distribution, computeDCF bool) (prec PrecisionStats) { - var valuesTest []complex128 - - switch element := element.(type) { - case *rlwe.Ciphertext: - valuesTest = encoder.DecodePublic(decryptor.DecryptNew(element), logSlots, noise) - case *rlwe.Plaintext: - valuesTest = encoder.DecodePublic(element, logSlots, noise) - case []complex128: - valuesTest = element - case []float64: - valuesTest = make([]complex128, len(element)) - for i := range element { - valuesTest[i] = complex(element[i], 0) - } + if encoder.Prec() <= 53 { + return getPrecisionStatsF64(params, encoder, decryptor, want, have, noise, computeDCF) } + return getPrecisionStatsF128(params, encoder, decryptor, want, have, noise, computeDCF) +} + +func getPrecisionStatsF64(params Parameters, encoder *Encoder, decryptor rlwe.Decryptor, want, have interface{}, noise distribution.Distribution, computeDCF bool) (prec PrecisionStats) { + + precision := encoder.Prec() + var valuesWant []complex128 - switch element := vWant.(type) { + switch want := want.(type) { case []complex128: - valuesWant = element + valuesWant = make([]complex128, len(want)) + copy(valuesWant, want) case []float64: - valuesWant = make([]complex128, len(element)) - for i := range element { - valuesWant[i] = complex(element[i], 0) + valuesWant = make([]complex128, len(want)) + for i := range want { + valuesWant[i] = complex(want[i], 0) + } + case []*big.Float: + valuesWant = make([]complex128, len(want)) + for i := range want { + if want[i] != nil { + f64, _ := want[i].Float64() + valuesWant[i] = complex(f64, 0) + } + } + case []*bignum.Complex: + valuesWant = make([]complex128, len(want)) + for i := range want { + if want[i] != nil { + valuesWant[i] = want[i].Complex128() + } + } } - var deltaReal, deltaImag, deltaL2 float64 + var valuesHave []complex128 + + switch have := have.(type) { + case *rlwe.Ciphertext: + valuesHave = make([]complex128, len(valuesWant)) + encoder.DecodePublic(decryptor.DecryptNew(have), valuesHave, noise) + case *rlwe.Plaintext: + valuesHave = make([]complex128, len(valuesWant)) + encoder.DecodePublic(have, valuesHave, noise) + case []complex128: + valuesHave = make([]complex128, len(valuesWant)) + copy(valuesHave, have) + case []float64: + valuesHave = make([]complex128, len(valuesWant)) + for i := range have { + valuesHave[i] = complex(have[i], 0) + } + case []*big.Float: + valuesHave = make([]complex128, len(valuesWant)) + for i := range have { + if have[i] != nil { + f64, _ := have[i].Float64() + valuesHave[i] = complex(f64, 0) + } + } + case []*bignum.Complex: + valuesHave = make([]complex128, len(valuesWant)) + for i := range have { + if have[i] != nil { + valuesHave[i] = have[i].Complex128() + } + } + } slots := len(valuesWant) - diff := make([]Stats, slots) - - prec.MaxDelta = Stats{0, 0, 0} - prec.MinDelta = Stats{1, 1, 1} - prec.MeanDelta = Stats{0, 0, 0} - - prec.cdfResol = 500 - - prec.RealDist = make([]struct { - Prec float64 - Count int - }, prec.cdfResol) - prec.ImagDist = make([]struct { - Prec float64 - Count int - }, prec.cdfResol) - prec.L2Dist = make([]struct { - Prec float64 - Count int - }, prec.cdfResol) + diff := make([]struct{ Real, Imag, L2 float64 }, slots) precReal := make([]float64, len(valuesWant)) precImag := make([]float64, len(valuesWant)) precL2 := make([]float64, len(valuesWant)) + var deltaReal, deltaImag, deltaL2 float64 + var MeanDeltaReal, MeanDeltaImag, MeanDeltaL2 float64 + var MaxDeltaReal, MaxDeltaImag, MaxDeltaL2 float64 + var MinDeltaReal, MinDeltaImag, MinDeltaL2 float64 = 1, 1, 1 + for i := range valuesWant { - deltaReal = math.Abs(real(valuesTest[i]) - real(valuesWant[i])) - deltaImag = math.Abs(imag(valuesTest[i]) - imag(valuesWant[i])) - deltaL2 = math.Sqrt(deltaReal*deltaReal + deltaImag*deltaImag) - precReal[i] = math.Log2(1 / deltaReal) - precImag[i] = math.Log2(1 / deltaImag) - precL2[i] = math.Log2(1 / deltaL2) + deltaReal = math.Abs(real(valuesHave[i]) - real(valuesWant[i])) + deltaImag = math.Abs(imag(valuesHave[i]) - imag(valuesWant[i])) + deltaL2 = math.Sqrt(deltaReal*deltaReal + deltaReal*deltaReal) + + precReal[i] = -math.Log2(deltaReal) + precImag[i] = -math.Log2(deltaImag) + precL2[i] = -math.Log2(deltaL2) diff[i].Real = deltaReal diff[i].Imag = deltaImag diff[i].L2 = deltaL2 - prec.MeanDelta.Real += deltaReal - prec.MeanDelta.Imag += deltaImag - prec.MeanDelta.L2 += deltaL2 + MeanDeltaReal += deltaReal + MeanDeltaImag += deltaImag + MeanDeltaL2 += deltaL2 - if deltaReal > prec.MaxDelta.Real { - prec.MaxDelta.Real = deltaReal + if deltaReal > MaxDeltaReal { + MaxDeltaReal = deltaReal } - if deltaImag > prec.MaxDelta.Imag { - prec.MaxDelta.Imag = deltaImag + if deltaImag < MaxDeltaImag { + MaxDeltaImag = deltaImag } - if deltaL2 > prec.MaxDelta.L2 { - prec.MaxDelta.L2 = deltaL2 + if deltaL2 < MaxDeltaL2 { + MaxDeltaL2 = deltaL2 } - if deltaReal < prec.MinDelta.Real { - prec.MinDelta.Real = deltaReal + if deltaReal < MinDeltaReal { + MinDeltaReal = deltaReal } - if deltaImag < prec.MinDelta.Imag { - prec.MinDelta.Imag = deltaImag + if deltaImag < MinDeltaImag { + MinDeltaImag = deltaImag } - if deltaL2 < prec.MinDelta.L2 { - prec.MinDelta.L2 = deltaL2 + if deltaL2 < MinDeltaL2 { + MinDeltaL2 = deltaL2 } } - prec.calcCDF(precReal, prec.RealDist) - prec.calcCDF(precImag, prec.ImagDist) - prec.calcCDF(precL2, prec.L2Dist) + if computeDCF { - prec.MinPrecision = deltaToPrecision(prec.MaxDelta) - prec.MaxPrecision = deltaToPrecision(prec.MinDelta) - prec.MeanDelta.Real /= float64(slots) - prec.MeanDelta.Imag /= float64(slots) - prec.MeanDelta.L2 /= float64(slots) - prec.MeanPrecision = deltaToPrecision(prec.MeanDelta) - prec.MedianDelta = calcmedian(diff) - prec.MedianPrecision = deltaToPrecision(prec.MedianDelta) - prec.STDFreq = encoder.GetErrSTDSlotDomain(valuesWant[:], valuesTest[:], params.DefaultScale()) - prec.STDTime = encoder.GetErrSTDCoeffDomain(valuesWant, valuesTest, params.DefaultScale()) + prec.cdfResol = 500 + + prec.RealDist = make([]struct { + Prec *big.Float + Count int + }, prec.cdfResol) + prec.ImagDist = make([]struct { + Prec *big.Float + Count int + }, prec.cdfResol) + prec.L2Dist = make([]struct { + Prec *big.Float + Count int + }, prec.cdfResol) + + prec.calcCDFF64(precReal, prec.RealDist) + prec.calcCDFF64(precImag, prec.ImagDist) + prec.calcCDFF64(precL2, prec.L2Dist) + } + + prec.MinPrecision = deltaToPrecisionF64(struct{ Real, Imag, L2 float64 }{Real: MaxDeltaReal, Imag: MaxDeltaImag, L2: MaxDeltaL2}) + prec.MaxPrecision = deltaToPrecisionF64(struct{ Real, Imag, L2 float64 }{Real: MinDeltaReal, Imag: MinDeltaImag, L2: MinDeltaL2}) + prec.MeanDelta.Real = new(big.Float).SetFloat64(MeanDeltaReal / float64(slots)) + prec.MeanDelta.Imag = new(big.Float).SetFloat64(MeanDeltaImag / float64(slots)) + prec.MeanDelta.L2 = new(big.Float).SetFloat64(MeanDeltaL2 / float64(slots)) + prec.MeanPrecision = deltaToPrecisionF64(struct{ Real, Imag, L2 float64 }{Real: MeanDeltaReal / float64(slots), Imag: MeanDeltaImag / float64(slots), L2: MeanDeltaL2 / float64(slots)}) + prec.MedianDelta = calcmedianF64(diff) + prec.MedianPrecision = deltaToPrecisionF128(prec.MedianDelta, bignum.Log(new(big.Float).SetPrec(precision).SetInt64(2))) return prec } -func deltaToPrecision(c Stats) Stats { - return Stats{math.Log2(1 / c.Real), math.Log2(1 / c.Imag), math.Log2(1 / c.L2)} +func deltaToPrecisionF64(c struct{ Real, Imag, L2 float64 }) Stats { + + return Stats{ + new(big.Float).SetFloat64(-math.Log2(c.Real)), + new(big.Float).SetFloat64(-math.Log2(c.Imag)), + new(big.Float).SetFloat64(-math.Log2(c.L2)), + } } -func (prec *PrecisionStats) calcCDF(precs []float64, res []struct { - Prec float64 +func (prec *PrecisionStats) calcCDFF64(precs []float64, res []struct { + Prec *big.Float Count int }) { sortedPrecs := make([]float64, len(precs)) @@ -194,7 +242,7 @@ func (prec *PrecisionStats) calcCDF(precs []float64, res []struct { curPrec := minPrec + float64(i)*(maxPrec-minPrec)/float64(prec.cdfResol) for countSmaller, p := range sortedPrecs { if p >= curPrec { - res[i].Prec = curPrec + res[i].Prec = new(big.Float).SetFloat64(curPrec) res[i].Count = countSmaller break } @@ -202,7 +250,7 @@ func (prec *PrecisionStats) calcCDF(precs []float64, res []struct { } } -func calcmedian(values []Stats) (median Stats) { +func calcmedianF64(values []struct{ Real, Imag, L2 float64 }) (median Stats) { tmp := make([]float64, len(values)) @@ -238,11 +286,339 @@ func calcmedian(values []Stats) (median Stats) { index := len(values) / 2 + if len(values)&1 == 1 || index+1 == len(values) { + return Stats{ + new(big.Float).SetFloat64(values[index].Real), + new(big.Float).SetFloat64(values[index].Imag), + new(big.Float).SetFloat64(values[index].L2), + } + } + + return Stats{ + new(big.Float).SetFloat64((values[index-1].Real + values[index].Real) / 2), + new(big.Float).SetFloat64((values[index-1].Imag + values[index].Imag) / 2), + new(big.Float).SetFloat64((values[index-1].L2 + values[index].L2) / 2), + } +} + +func getPrecisionStatsF128(params Parameters, encoder *Encoder, decryptor rlwe.Decryptor, want, have interface{}, noise distribution.Distribution, computeDCF bool) (prec PrecisionStats) { + precision := encoder.Prec() + + var valuesWant []*bignum.Complex + switch want := want.(type) { + case []complex128: + valuesWant = make([]*bignum.Complex, len(want)) + for i := range want { + valuesWant[i] = &bignum.Complex{ + new(big.Float).SetPrec(precision).SetFloat64(real(want[i])), + new(big.Float).SetPrec(precision).SetFloat64(imag(want[i])), + } + } + case []float64: + valuesWant = make([]*bignum.Complex, len(want)) + for i := range want { + valuesWant[i] = &bignum.Complex{ + new(big.Float).SetPrec(precision).SetFloat64(want[i]), + new(big.Float).SetPrec(precision), + } + } + case []*big.Float: + valuesWant = make([]*bignum.Complex, len(want)) + for i := range want { + valuesWant[i] = &bignum.Complex{ + want[i], + new(big.Float).SetPrec(precision), + } + } + case []*bignum.Complex: + valuesWant = want + + for i := range valuesWant { + if valuesWant[i] == nil { + valuesWant[i] = &bignum.Complex{new(big.Float), new(big.Float)} + } + } + } + + var valuesHave []*bignum.Complex + + switch have := have.(type) { + case *rlwe.Ciphertext: + valuesHave = make([]*bignum.Complex, len(valuesWant)) + encoder.DecodePublic(decryptor.DecryptNew(have), valuesHave, noise) + case *rlwe.Plaintext: + valuesHave = make([]*bignum.Complex, len(valuesWant)) + encoder.DecodePublic(have, valuesHave, noise) + case []complex128: + valuesHave = make([]*bignum.Complex, len(have)) + for i := range have { + valuesHave[i] = &bignum.Complex{ + new(big.Float).SetPrec(precision).SetFloat64(real(have[i])), + new(big.Float).SetPrec(precision).SetFloat64(imag(have[i])), + } + } + case []float64: + valuesHave = make([]*bignum.Complex, len(have)) + for i := range have { + valuesHave[i] = &bignum.Complex{ + new(big.Float).SetPrec(precision).SetFloat64(have[i]), + new(big.Float).SetPrec(precision), + } + } + case []*big.Float: + valuesHave = make([]*bignum.Complex, len(have)) + for i := range have { + valuesHave[i] = &bignum.Complex{ + have[i], + new(big.Float).SetPrec(precision), + } + } + case []*bignum.Complex: + valuesHave = have + for i := range valuesHave { + if valuesHave[i] == nil { + valuesHave[i] = &bignum.Complex{new(big.Float), new(big.Float)} + } + } + } + + slots := len(valuesWant) + + diff := make([]Stats, slots) + + prec.MaxDelta = Stats{ + new(big.Float).SetPrec(precision), + new(big.Float).SetPrec(precision), + new(big.Float).SetPrec(precision), + } + prec.MinDelta = Stats{ + new(big.Float).SetPrec(precision).SetInt64(1), + new(big.Float).SetPrec(precision).SetInt64(1), + new(big.Float).SetPrec(precision).SetInt64(1), + } + prec.MeanDelta = Stats{ + new(big.Float).SetPrec(precision), + new(big.Float).SetPrec(precision), + new(big.Float).SetPrec(precision), + } + + precReal := make([]*big.Float, len(valuesWant)) + precImag := make([]*big.Float, len(valuesWant)) + precL2 := make([]*big.Float, len(valuesWant)) + + deltaReal := new(big.Float) + deltaImag := new(big.Float) + deltaL2 := new(big.Float) + + tmp := new(big.Float) + + ln2 := bignum.Log(new(big.Float).SetPrec(precision).SetInt64(2)) + + for i := range valuesWant { + + deltaReal.Sub(valuesHave[i][0], valuesWant[i][0]) + deltaReal.Abs(deltaReal) + + deltaImag.Sub(valuesHave[i][1], valuesWant[i][1]) + deltaImag.Abs(deltaImag) + + deltaL2.Mul(deltaReal, deltaReal) + deltaL2.Add(deltaL2, tmp.Mul(deltaImag, deltaImag)) + deltaL2.Sqrt(deltaL2) + + precReal[i] = bignum.Log(deltaReal) + precReal[i].Quo(precReal[i], ln2) + precReal[i].Neg(precReal[i]) + + precImag[i] = bignum.Log(deltaImag) + precImag[i].Quo(precImag[i], ln2) + precImag[i].Neg(precImag[i]) + + precL2[i] = bignum.Log(deltaL2) + precL2[i].Quo(precL2[i], ln2) + precL2[i].Neg(precL2[i]) + + diff[i].Real = new(big.Float).Set(deltaReal) + diff[i].Imag = new(big.Float).Set(deltaImag) + diff[i].L2 = new(big.Float).Set(deltaL2) + + prec.MeanDelta.Real.Add(prec.MeanDelta.Real, deltaReal) + prec.MeanDelta.Imag.Add(prec.MeanDelta.Imag, deltaImag) + prec.MeanDelta.L2.Add(prec.MeanDelta.L2, deltaL2) + + if deltaReal.Cmp(prec.MaxDelta.Real) == 1 { + prec.MaxDelta.Real.Set(deltaReal) + } + + if deltaImag.Cmp(prec.MaxDelta.Imag) == 1 { + prec.MaxDelta.Imag.Set(deltaImag) + } + + if deltaL2.Cmp(prec.MaxDelta.L2) == 1 { + prec.MaxDelta.L2.Set(deltaL2) + } + + if deltaReal.Cmp(prec.MinDelta.Real) == -1 { + prec.MinDelta.Real.Set(deltaReal) + } + + if deltaImag.Cmp(prec.MinDelta.Imag) == -1 { + prec.MinDelta.Imag.Set(deltaImag) + } + + if deltaL2.Cmp(prec.MinDelta.L2) == -1 { + prec.MinDelta.L2.Set(deltaL2) + } + } + + if computeDCF { + + prec.cdfResol = 500 + + prec.RealDist = make([]struct { + Prec *big.Float + Count int + }, prec.cdfResol) + prec.ImagDist = make([]struct { + Prec *big.Float + Count int + }, prec.cdfResol) + prec.L2Dist = make([]struct { + Prec *big.Float + Count int + }, prec.cdfResol) + + prec.calcCDFF128(precReal, prec.RealDist) + prec.calcCDFF128(precImag, prec.ImagDist) + prec.calcCDFF128(precL2, prec.L2Dist) + } + + prec.MinPrecision = deltaToPrecisionF128(prec.MaxDelta, ln2) + prec.MaxPrecision = deltaToPrecisionF128(prec.MinDelta, ln2) + prec.MeanDelta.Real.Quo(prec.MeanDelta.Real, new(big.Float).SetPrec(precision).SetInt64(int64(slots))) + prec.MeanDelta.Imag.Quo(prec.MeanDelta.Imag, new(big.Float).SetPrec(precision).SetInt64(int64(slots))) + prec.MeanDelta.L2.Quo(prec.MeanDelta.L2, new(big.Float).SetPrec(precision).SetInt64(int64(slots))) + prec.MeanPrecision = deltaToPrecisionF128(prec.MeanDelta, ln2) + prec.MedianDelta = calcmedianF128(diff) + prec.MedianPrecision = deltaToPrecisionF128(prec.MedianDelta, ln2) + return prec +} + +func deltaToPrecisionF128(c Stats, ln2 *big.Float) Stats { + + real := bignum.Log(c.Real) + real.Quo(real, ln2) + real.Neg(real) + + imag := bignum.Log(c.Imag) + imag.Quo(imag, ln2) + imag.Neg(imag) + + l2 := bignum.Log(c.L2) + l2.Quo(l2, ln2) + l2.Neg(l2) + + return Stats{ + real, + imag, + l2, + } +} + +func (prec *PrecisionStats) calcCDFF128(precs []*big.Float, res []struct { + Prec *big.Float + Count int +}) { + sortedPrecs := make([]*big.Float, len(precs)) + copy(sortedPrecs, precs) + + sort.Slice(sortedPrecs, func(i, j int) bool { + return sortedPrecs[i].Cmp(sortedPrecs[j]) > 0 + }) + + minPrec := sortedPrecs[0] + maxPrec := sortedPrecs[len(sortedPrecs)-1] + + curPrec := new(big.Float) + + a := new(big.Float).Sub(maxPrec, minPrec) + a.Quo(a, new(big.Float).SetInt64(int64(prec.cdfResol))) + + b := new(big.Float).Quo(minPrec, new(big.Float).SetInt64(int64(prec.cdfResol))) + + for i := 0; i < prec.cdfResol; i++ { + + curPrec.Mul(new(big.Float).SetInt64(int64(i)), a) + curPrec.Add(curPrec, b) + + for countSmaller, p := range sortedPrecs { + if p.Cmp(curPrec) >= 0 { + res[i].Prec = new(big.Float).Set(curPrec) + res[i].Count = countSmaller + break + } + } + } +} + +func calcmedianF128(values []Stats) (median Stats) { + + tmp := make([]*big.Float, len(values)) + + for i := range values { + tmp[i] = values[i].Real + } + + sort.Slice(tmp, func(i, j int) bool { + return tmp[i].Cmp(tmp[j]) > 0 + }) + + for i := range values { + values[i].Real.Set(tmp[i]) + } + + for i := range values { + tmp[i] = values[i].Imag + } + + sort.Slice(tmp, func(i, j int) bool { + return tmp[i].Cmp(tmp[j]) > 0 + }) + + for i := range values { + values[i].Imag = tmp[i] + } + + for i := range values { + tmp[i] = values[i].L2 + } + + sort.Slice(tmp, func(i, j int) bool { + return tmp[i].Cmp(tmp[j]) > 0 + }) + + for i := range values { + values[i].L2 = tmp[i] + } + + index := len(values) / 2 + if len(values)&1 == 1 || index+1 == len(values) { return Stats{values[index].Real, values[index].Imag, values[index].L2} } - return Stats{(values[index-1].Real + values[index].Real) / 2, - (values[index-1].Imag + values[index].Imag) / 2, - (values[index-1].L2 + values[index].L2) / 2} + real := new(big.Float).Add(values[index-1].Real, values[index].Real) + real.Quo(real, new(big.Float).SetInt64(2)) + + imag := new(big.Float).Add(values[index-1].Imag, values[index].Imag) + imag.Quo(imag, new(big.Float).SetInt64(2)) + + l2 := new(big.Float).Add(values[index-1].L2, values[index].L2) + l2.Quo(l2, new(big.Float).SetInt64(2)) + + return Stats{ + real, + imag, + l2, + } } diff --git a/ckks/scaling.go b/ckks/scaling.go index d9c1c6a2..cc085427 100644 --- a/ckks/scaling.go +++ b/ckks/scaling.go @@ -1,58 +1,13 @@ package ckks import ( - "fmt" "math/big" "github.com/tuneinsight/lattigo/v4/ring" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) -const ( - scalingPrecision = uint(128) -) - -func valueToBigComplex(value interface{}, prec uint) (cmplx *ring.Complex) { - - cmplx = new(ring.Complex) - - switch value := value.(type) { - case complex128: - - if v := real(value); v != 0 { - cmplx[0] = new(big.Float).SetPrec(prec) - cmplx[0].SetFloat64(v) - } - - if v := imag(value); v != 0 { - cmplx[1] = new(big.Float).SetPrec(prec) - cmplx[1].SetFloat64(v) - } - - case float64: - return valueToBigComplex(complex(value, 0), prec) - case int: - return valueToBigComplex(new(big.Int).SetInt64(int64(value)), prec) - case int64: - return valueToBigComplex(new(big.Int).SetInt64(value), prec) - case uint64: - return valueToBigComplex(new(big.Int).SetUint64(value), prec) - case *big.Float: - cmplx[0] = new(big.Float).SetPrec(prec) - cmplx[0].Set(value) - case *big.Int: - cmplx[0] = new(big.Float).SetPrec(prec) - cmplx[0].SetInt(value) - case *ring.Complex: - cmplx[0] = new(big.Float).Set(value[0]) - cmplx[1] = new(big.Float).Set(value[1]) - default: - panic(fmt.Errorf("invalid value.(type): must be int, int64, uint64, float64, complex128, *big.Int, *big.Float or *ring.Complex but is %T", value)) - } - - return -} - -func bigComplexToRNSScalar(r *ring.Ring, scale *big.Float, cmplx *ring.Complex) (RNSReal, RNSImag ring.RNSScalar) { +func bigComplexToRNSScalar(r *ring.Ring, scale *big.Float, cmplx *bignum.Complex) (RNSReal, RNSImag ring.RNSScalar) { if scale == nil { scale = new(big.Float).SetFloat64(1) diff --git a/ckks/simple_bootstrapper.go b/ckks/simple_bootstrapper.go new file mode 100644 index 00000000..29df1b05 --- /dev/null +++ b/ckks/simple_bootstrapper.go @@ -0,0 +1,64 @@ +package ckks + +import ( + "github.com/tuneinsight/lattigo/v4/rlwe" + "github.com/tuneinsight/lattigo/v4/utils/bignum" +) + +// SimpleBootstrapper is an implementation of the rlwe.Bootstrapping interface that +// uses the secret-key to decrypt and re-encrypt the bootstrapped ciphertext. +type SimpleBootstrapper struct { + Parameters + *Encoder + rlwe.Decryptor + rlwe.Encryptor + sk *rlwe.SecretKey + Values []*bignum.Complex + Counter int // records the number of bootstrapping +} + +func NewSimpleBootstrapper(params Parameters, sk *rlwe.SecretKey) rlwe.Bootstrapper { + return &SimpleBootstrapper{ + params, + NewEncoder(params), + NewDecryptor(params, sk), + NewEncryptor(params, sk), + sk, + make([]*bignum.Complex, params.N()), + 0} +} + +func (d *SimpleBootstrapper) Bootstrap(ct *rlwe.Ciphertext) (*rlwe.Ciphertext, error) { + values := d.Values[:1<> 2 - var PI = new(big.Float) - PI.SetPrec(prec) - PI.SetString(pi) + Pi := bignum.Pi(prec) - e2ipi := ring.NewFloat(2, prec) - e2ipi.Mul(e2ipi, PI) - e2ipi.Quo(e2ipi, ring.NewFloat(float64(NthRoot), prec)) + e2ipi := bignum.NewFloat(2, prec) + e2ipi.Mul(e2ipi, Pi) + e2ipi.Quo(e2ipi, bignum.NewFloat(float64(NthRoot), prec)) angle := new(big.Float).SetPrec(prec) - roots[0] = &ring.Complex{ring.NewFloat(1, prec), ring.NewFloat(0, prec)} + roots[0] = &bignum.Complex{bignum.NewFloat(1, prec), bignum.NewFloat(0, prec)} for i := 1; i < quarm; i++ { - angle.Mul(e2ipi, ring.NewFloat(float64(i), prec)) - roots[i] = &ring.Complex{ring.Cos(angle), nil} + angle.Mul(e2ipi, bignum.NewFloat(float64(i), prec)) + roots[i] = &bignum.Complex{bignum.Cos(angle), nil} } for i := 1; i < quarm; i++ { roots[quarm-i][1] = new(big.Float).Set(roots[i].Real()) } - roots[quarm] = &ring.Complex{ring.NewFloat(0, prec), ring.NewFloat(1, prec)} + roots[quarm] = &bignum.Complex{bignum.NewFloat(0, prec), bignum.NewFloat(1, prec)} for i := 1; i < quarm+1; i++ { - roots[i+1*quarm] = &ring.Complex{new(big.Float).Neg(roots[quarm-i].Real()), new(big.Float).Set(roots[quarm-i].Imag())} - roots[i+2*quarm] = &ring.Complex{new(big.Float).Neg(roots[i].Real()), new(big.Float).Neg(roots[i].Imag())} - roots[i+3*quarm] = &ring.Complex{new(big.Float).Set(roots[quarm-i].Real()), new(big.Float).Neg(roots[quarm-i].Imag())} + roots[i+1*quarm] = &bignum.Complex{new(big.Float).Neg(roots[quarm-i].Real()), new(big.Float).Set(roots[quarm-i].Imag())} + roots[i+2*quarm] = &bignum.Complex{new(big.Float).Neg(roots[i].Real()), new(big.Float).Neg(roots[i].Imag())} + roots[i+3*quarm] = &bignum.Complex{new(big.Float).Set(roots[quarm-i].Real()), new(big.Float).Neg(roots[quarm-i].Imag())} } roots[NthRoot] = roots[0] @@ -77,24 +77,52 @@ func GetRootsFloat64(NthRoot int) (roots []complex128) { } // StandardDeviation computes the scaled standard deviation of the input vector. -func StandardDeviation(vec []float64, scale float64) (std float64) { - // We assume that the error is centered around zero - var err, tmp, mean, n float64 +func StandardDeviation(vec interface{}, scale rlwe.Scale) (std float64) { - n = float64(len(vec)) + switch vec := vec.(type) { + case []float64: + // We assume that the error is centered around zero + var err, tmp, mean, n float64 - for _, c := range vec { - mean += c + n = float64(len(vec)) + + for _, c := range vec { + mean += c + } + + mean /= n + + for _, c := range vec { + tmp = c - mean + err += tmp * tmp + } + + std = math.Sqrt(err/(n-1)) * scale.Float64() + case []*big.Float: + mean := new(big.Float) + + for _, c := range vec { + mean.Add(mean, c) + } + + mean.Quo(mean, new(big.Float).SetInt64(int64(len(vec)))) + + err := new(big.Float) + tmp := new(big.Float) + for _, c := range vec { + tmp.Sub(c, mean) + tmp.Mul(tmp, tmp) + err.Add(err, tmp) + } + + err.Quo(err, new(big.Float).SetInt64(int64(len(vec)-1))) + err.Sqrt(err) + err.Mul(err, &scale.Value) + + std, _ = err.Float64() } - mean /= n - - for _, c := range vec { - tmp = c - mean - err += tmp * tmp - } - - return math.Sqrt(err/(n-1)) * scale + return } // NttSparseAndMontgomery takes the polynomial polIn Z[Y] outside of the NTT domain to the polynomial Z[X] in the NTT domain where Y = X^(gap). @@ -144,19 +172,19 @@ func NttSparseAndMontgomery(r *ring.Ring, logSlots int, montgomery bool, pol *ri } } -// ComplexToFixedPointCRT encodes a vector of complex on a CRT polynomial. +// Complex128ToFixedPointCRT encodes a vector of complex128 on a CRT polynomial. // The real part is put in a left N/2 coefficient and the imaginary in the right N/2 coefficients. -func ComplexToFixedPointCRT(r *ring.Ring, values []complex128, scale float64, coeffs [][]uint64) { +func Complex128ToFixedPointCRT(r *ring.Ring, values []complex128, scale float64, coeffs [][]uint64) { for i, v := range values { - SingleFloatToFixedPointCRT(r, i, real(v), scale, coeffs) + SingleFloat64ToFixedPointCRT(r, i, real(v), scale, coeffs) } var start int if r.Type() == ring.Standard { slots := len(values) for i, v := range values { - SingleFloatToFixedPointCRT(r, i+slots, imag(v), scale, coeffs) + SingleFloat64ToFixedPointCRT(r, i+slots, imag(v), scale, coeffs) } start = 2 * len(values) @@ -186,8 +214,8 @@ func FloatToFixedPointCRT(r *ring.Ring, values []float64, scale float64, coeffs } } -// SingleFloatToFixedPointCRT encodes a single float on a CRT polynomial in the i-th coefficient. -func SingleFloatToFixedPointCRT(r *ring.Ring, i int, value float64, scale float64, coeffs [][]uint64) { +// SingleFloat64ToFixedPointCRT encodes a single float64 on a CRT polynomialon in the i-th coefficient. +func SingleFloat64ToFixedPointCRT(r *ring.Ring, i int, value float64, scale float64, coeffs [][]uint64) { if value == 0 { for j := range coeffs { @@ -220,7 +248,7 @@ func SingleFloatToFixedPointCRT(r *ring.Ring, i int, value float64, scale float6 xInt = new(big.Int) xFlo.Int(xInt) for j := range moduli { - tmp.Mod(xInt, ring.NewUint(moduli[j])) + tmp.Mod(xInt, bignum.NewInt(moduli[j])) if isNegative { coeffs[j][i] = moduli[j] - tmp.Uint64() } else { @@ -252,42 +280,91 @@ func SingleFloatToFixedPointCRT(r *ring.Ring, i int, value float64, scale float6 } } -func scaleUpVecExactBigFloat(values []*big.Float, scale float64, moduli []uint64, coeffs [][]uint64) { +// Float64ToFixedPointCRT encodes a vector of floats on a CRT polynomial. +func Float64ToFixedPointCRT(r *ring.Ring, values []float64, scale float64, coeffs [][]uint64) { + for i, v := range values { + SingleFloat64ToFixedPointCRT(r, i, v, scale, coeffs) + } - prec := values[0].Prec() + for i := 0; i < len(coeffs); i++ { + tmp := coeffs[i] + for j := len(values); j < len(coeffs[0]); j++ { + tmp[j] = 0 + } + } +} - xFlo := ring.NewFloat(0, prec) +func ComplexArbitraryToFixedPointCRT(r *ring.Ring, values []*bignum.Complex, scale *big.Float, coeffs [][]uint64) { + + xFlo := new(big.Float) xInt := new(big.Int) tmp := new(big.Int) - zero := ring.NewFloat(0, prec) + zero := new(big.Float) - scaleFlo := ring.NewFloat(scale, prec) - half := ring.NewFloat(0.5, prec) + half := new(big.Float).SetFloat64(0.5) + + moduli := r.ModuliChain()[:r.Level()+1] + + var negative bool for i := range values { - xFlo.Mul(scaleFlo, values[i]) + xFlo.Mul(scale, values[i][0]) - if values[i].Cmp(zero) < 0 { + if values[i][0].Cmp(zero) < 0 { xFlo.Sub(xFlo, half) + negative = true } else { xFlo.Add(xFlo, half) + negative = false } xFlo.Int(xInt) for j := range moduli { - Q := ring.NewUint(moduli[j]) + Q := bignum.NewInt(moduli[j]) tmp.Mod(xInt, Q) - if values[i].Cmp(zero) < 0 { + if negative { tmp.Add(tmp, Q) } coeffs[j][i] = tmp.Uint64() } } + + if r.Type() == ring.Standard { + + slots := len(values) + + for i := range values { + + xFlo.Mul(scale, values[i][1]) + + if values[i][1].Cmp(zero) < 0 { + xFlo.Sub(xFlo, half) + negative = true + } else { + xFlo.Add(xFlo, half) + negative = false + } + + xFlo.Int(xInt) + + for j := range moduli { + + Q := bignum.NewInt(moduli[j]) + + tmp.Mod(xInt, Q) + + if negative { + tmp.Add(tmp, Q) + } + coeffs[j][i+slots] = tmp.Uint64() + } + } + } } diff --git a/dckks/dckks_benchmark_test.go b/dckks/dckks_benchmark_test.go index d8f19f81..36607df0 100644 --- a/dckks/dckks_benchmark_test.go +++ b/dckks/dckks_benchmark_test.go @@ -8,40 +8,43 @@ import ( "github.com/tuneinsight/lattigo/v4/drlwe" "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/rlwe" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) func BenchmarkDCKKS(b *testing.B) { var err error - defaultParams := ckks.DefaultParams - if testing.Short() { - defaultParams = ckks.DefaultParams[:2] - } - if *flagParamString != "" { - var jsonParams ckks.ParametersLiteral - if err = json.Unmarshal([]byte(*flagParamString), &jsonParams); err != nil { + var testParams []ckks.ParametersLiteral + switch { + case *flagParamString != "": // the custom test suite reads the parameters from the -params flag + testParams = append(testParams, ckks.ParametersLiteral{}) + if err = json.Unmarshal([]byte(*flagParamString), &testParams[0]); err != nil { b.Fatal(err) } - defaultParams = []ckks.ParametersLiteral{jsonParams} // the custom test suite reads the parameters from the -params flag + default: + testParams = ckks.TestParamsLiteral } - parties := 3 + for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} { - for _, p := range defaultParams { + for _, paramsLiteral := range testParams { - var params ckks.Parameters - if params, err = ckks.NewParametersFromLiteral(p); err != nil { - b.Fatal(err) + paramsLiteral.RingType = ringType + + var params ckks.Parameters + if params, err = ckks.NewParametersFromLiteral(paramsLiteral); err != nil { + b.Fatal(err) + } + N := 3 + var tc *testContext + if tc, err = genTestParams(params, N); err != nil { + b.Fatal(err) + } + + benchRefresh(tc, b) + benchMaskedTransform(tc, b) } - - var tc *testContext - if tc, err = genTestParams(params, parties); err != nil { - b.Fatal(err) - } - - benchRefresh(tc, b) - benchMaskedTransform(tc, b) } } @@ -70,24 +73,24 @@ func benchRefresh(tc *testContext, b *testing.B) { crp := p.SampleCRP(params.MaxLevel(), tc.crs) - b.Run(testString("Refresh/Round1/Gen", tc.NParties, params), func(b *testing.B) { + b.Run(GetTestName("Refresh/Round1/Gen", tc.NParties, params), func(b *testing.B) { for i := 0; i < b.N; i++ { - p.GenShare(p.s, logBound, params.LogSlots(), ciphertext, crp, p.share) + p.GenShare(p.s, logBound, ciphertext, crp, p.share) } }) - b.Run(testString("Refresh/Round1/Agg", tc.NParties, params), func(b *testing.B) { + b.Run(GetTestName("Refresh/Round1/Agg", tc.NParties, params), func(b *testing.B) { for i := 0; i < b.N; i++ { p.AggregateShares(p.share, p.share, p.share) } }) - b.Run(testString("Refresh/Finalize", tc.NParties, params), func(b *testing.B) { + b.Run(GetTestName("Refresh/Finalize", tc.NParties, params), func(b *testing.B) { ctOut := ckks.NewCiphertext(params, 1, params.MaxLevel()) for i := 0; i < b.N; i++ { - p.Finalize(ciphertext, params.LogSlots(), crp, p.share, ctOut) + p.Finalize(ciphertext, crp, p.share, ctOut) } }) @@ -123,33 +126,33 @@ func benchMaskedTransform(tc *testContext, b *testing.B) { transform := &MaskedTransformFunc{ Decode: true, - Func: func(coeffs []*ring.Complex) { + Func: func(coeffs []*bignum.Complex) { for i := range coeffs { - coeffs[i][0].Mul(coeffs[i][0], ring.NewFloat(0.9238795325112867, logBound)) - coeffs[i][1].Mul(coeffs[i][1], ring.NewFloat(0.7071067811865476, logBound)) + coeffs[i][0].Mul(coeffs[i][0], bignum.NewFloat(0.9238795325112867, logBound)) + coeffs[i][1].Mul(coeffs[i][1], bignum.NewFloat(0.7071067811865476, logBound)) } }, Encode: true, } - b.Run(testString("Refresh&Transform/Round1/Gen", tc.NParties, params), func(b *testing.B) { + b.Run(GetTestName("Refresh&Transform/Round1/Gen", tc.NParties, params), func(b *testing.B) { for i := 0; i < b.N; i++ { - p.GenShare(p.s, p.s, logBound, params.LogSlots(), ciphertext, crp, transform, p.share) + p.GenShare(p.s, p.s, logBound, ciphertext, crp, transform, p.share) } }) - b.Run(testString("Refresh&Transform/Round1/Agg", tc.NParties, params), func(b *testing.B) { + b.Run(GetTestName("Refresh&Transform/Round1/Agg", tc.NParties, params), func(b *testing.B) { for i := 0; i < b.N; i++ { p.AggregateShares(p.share, p.share, p.share) } }) - b.Run(testString("Refresh&Transform/Transform", tc.NParties, params), func(b *testing.B) { + b.Run(GetTestName("Refresh&Transform/Transform", tc.NParties, params), func(b *testing.B) { ctOut := ckks.NewCiphertext(params, 1, params.MaxLevel()) for i := 0; i < b.N; i++ { - p.Transform(ciphertext, params.LogSlots(), transform, crp, p.share, ctOut) + p.Transform(ciphertext, transform, crp, p.share, ctOut) } }) diff --git a/dckks/dckks_test.go b/dckks/dckks_test.go index de6847b1..6bda127f 100644 --- a/dckks/dckks_test.go +++ b/dckks/dckks_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "flag" "fmt" + "math" + "math/big" "runtime" "testing" @@ -13,26 +15,23 @@ import ( "github.com/tuneinsight/lattigo/v4/drlwe" "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/rlwe" + "github.com/tuneinsight/lattigo/v4/utils" + "github.com/tuneinsight/lattigo/v4/utils/bignum" "github.com/tuneinsight/lattigo/v4/utils/sampling" ) -var flagLongTest = flag.Bool("long", false, "run the long test suite (all parameters + secure refresh). Overrides -short and requires -timeout=0.") -var flagPostQuantum = flag.Bool("pq", false, "run post quantum test suite (does not run non-PQ parameters).") var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short and -long.") var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats") -var minPrec float64 = 15.0 -func testString(opname string, parties int, params ckks.Parameters) string { - return fmt.Sprintf("%s/RingType=%s/logN=%d/logSlots=%d/logQ=%f/LogP=%f/levels=%d/#Pi=%d/Decomp=%d/parties=%d", +func GetTestName(opname string, parties int, params ckks.Parameters) string { + return fmt.Sprintf("%s/RingType=%s/logN=%d/logQP=%d/Qi=%d/Pi=%d/LogScale=%d/Parties=%d", opname, params.RingType(), params.LogN(), - params.LogSlots(), - params.LogQ(), - params.LogP(), - params.MaxLevel()+1, + int(math.Round(params.LogQP())), + params.QCount(), params.PCount(), - params.DecompRNS(params.MaxLevelQ(), params.MaxLevelP()), + int(math.Log2(params.DefaultScale().Float64())), parties) } @@ -43,7 +42,7 @@ type testContext struct { ringQ *ring.Ring ringP *ring.Ring - encoder ckks.Encoder + encoder *ckks.Encoder evaluator ckks.Evaluator encryptorPk0 rlwe.Encryptor @@ -74,36 +73,36 @@ func TestDCKKS(t *testing.T) { if err = json.Unmarshal([]byte(*flagParamString), &testParams[0]); err != nil { t.Fatal(err) } - case *flagLongTest: - for _, pls := range [][]ckks.ParametersLiteral{ - ckks.DefaultParams, - ckks.DefaultConjugateInvariantParams, - ckks.DefaultPostQuantumParams, - ckks.DefaultPostQuantumConjugateInvariantParams} { - testParams = append(testParams, pls...) - } - case *flagPostQuantum && testing.Short(): - testParams = append(ckks.DefaultPostQuantumParams[:2], ckks.DefaultPostQuantumConjugateInvariantParams[:2]...) - case *flagPostQuantum: - testParams = append(ckks.DefaultPostQuantumParams[:4], ckks.DefaultPostQuantumConjugateInvariantParams[:4]...) - case testing.Short(): - testParams = append(ckks.DefaultParams[:2], ckks.DefaultConjugateInvariantParams[:2]...) default: - testParams = append(ckks.DefaultParams[:4], ckks.DefaultConjugateInvariantParams[:4]...) + testParams = ckks.TestParamsLiteral } - for _, paramsLiteral := range testParams[:] { + for _, ringType := range []ring.Type{ring.Standard, ring.ConjugateInvariant} { - var params ckks.Parameters - if params, err = ckks.NewParametersFromLiteral(paramsLiteral); err != nil { - t.Fatal(err) - } - N := 3 - var tc *testContext - if tc, err = genTestParams(params, N); err != nil { - t.Fatal(err) - } + for _, paramsLiteral := range testParams { + paramsLiteral.RingType = ringType + + var params ckks.Parameters + if params, err = ckks.NewParametersFromLiteral(paramsLiteral); err != nil { + t.Fatal(err) + } + N := 3 + var tc *testContext + if tc, err = genTestParams(params, N); err != nil { + t.Fatal(err) + } + + for _, testSet := range []func(tc *testContext, t *testing.T){ + testE2SProtocol, + testRefresh, + testRefreshAndTransform, + testRefreshAndTransformSwitchParams, + testMarshalling, + } { + testSet(tc, t) + runtime.GC() + } for _, testSet := range []func(tc *testContext, t *testing.T){ testE2SProtocol, testRefresh, @@ -165,7 +164,7 @@ func testE2SProtocol(tc *testContext, t *testing.T) { params := tc.params - t.Run(testString("E2SProtocol", tc.NParties, params), func(t *testing.T) { + t.Run(GetTestName("E2SProtocol", tc.NParties, params), func(t *testing.T) { var minLevel int var logBound uint @@ -195,12 +194,12 @@ func testE2SProtocol(tc *testContext, t *testing.T) { P[i].sk = tc.sk0Shards[i] P[i].publicShareE2S = P[i].e2s.AllocateShare(minLevel) P[i].publicShareS2E = P[i].s2e.AllocateShare(params.MaxLevel()) - P[i].secretShare = drlwe.NewAdditiveShareBigint(params.Parameters, params.LogSlots()) + P[i].secretShare = NewAdditiveShareBigint(params, ciphertext.LogSlots) } for i, p := range P { // Enc(-M_i) - p.e2s.GenShare(p.sk, logBound, params.LogSlots(), ciphertext, p.secretShare, p.publicShareE2S) + p.e2s.GenShare(p.sk, logBound, ciphertext, p.secretShare, p.publicShareE2S) if i > 0 { // Enc(sum(-M_i)) @@ -209,10 +208,10 @@ func testE2SProtocol(tc *testContext, t *testing.T) { } // sum(-M_i) + x - P[0].e2s.GetShare(P[0].secretShare, P[0].publicShareE2S, params.LogSlots(), ciphertext, P[0].secretShare) + P[0].e2s.GetShare(P[0].secretShare, P[0].publicShareE2S, ciphertext, P[0].secretShare) // sum(-M_i) + x + sum(M_i) = x - rec := drlwe.NewAdditiveShareBigint(params.Parameters, params.LogSlots()) + rec := NewAdditiveShareBigint(params, ciphertext.LogSlots) for _, p := range P { a := rec.Value b := p.secretShare.Value @@ -232,7 +231,7 @@ func testE2SProtocol(tc *testContext, t *testing.T) { crp := P[0].s2e.SampleCRP(params.MaxLevel(), tc.crs) for i, p := range P { - p.s2e.GenShare(p.sk, crp, params.LogSlots(), p.secretShare, p.publicShareS2E) + p.s2e.GenShare(p.sk, crp, ciphertext.LogSlots, p.secretShare, p.publicShareS2E) if i > 0 { p.s2e.AggregateShares(P[0].publicShareS2E, p.publicShareS2E, P[0].publicShareS2E) } @@ -254,7 +253,7 @@ func testRefresh(tc *testContext, t *testing.T) { decryptorSk0 := tc.decryptorSk0 params := tc.params - t.Run(testString("Refresh", tc.NParties, params), func(t *testing.T) { + t.Run(GetTestName("Refresh", tc.NParties, params), func(t *testing.T) { var minLevel int var logBound uint @@ -289,7 +288,7 @@ func testRefresh(tc *testContext, t *testing.T) { P0 := RefreshParties[0] for _, scale := range []float64{params.DefaultScale().Float64(), params.DefaultScale().Float64() * 128} { - t.Run(fmt.Sprintf("atScale=%f", scale), func(t *testing.T) { + t.Run(fmt.Sprintf("AtScale=%d", int(math.Round(math.Log2(scale)))), func(t *testing.T) { coeffs, _, ciphertext := newTestVectorsAtScale(tc, encryptorPk0, -1, 1, rlwe.NewScale(scale)) // Brings ciphertext to minLevel + 1 @@ -299,14 +298,14 @@ func testRefresh(tc *testContext, t *testing.T) { for i, p := range RefreshParties { - p.GenShare(p.s, logBound, params.LogSlots(), ciphertext, crp, p.share) + p.GenShare(p.s, logBound, ciphertext, crp, p.share) if i > 0 { P0.AggregateShares(p.share, P0.share, P0.share) } } - P0.Finalize(ciphertext, params.LogSlots(), crp, P0.share, ciphertext) + P0.Finalize(ciphertext, crp, P0.share, ciphertext) verifyTestVectors(tc, decryptorSk0, coeffs, ciphertext, t) }) @@ -323,7 +322,7 @@ func testRefreshAndTransform(tc *testContext, t *testing.T) { params := tc.params decryptorSk0 := tc.decryptorSk0 - t.Run(testString("RefreshAndTransform", tc.NParties, params), func(t *testing.T) { + t.Run(GetTestName("RefreshAndTransform", tc.NParties, params), func(t *testing.T) { var minLevel int var logBound uint @@ -369,27 +368,28 @@ func testRefreshAndTransform(tc *testContext, t *testing.T) { transform := &MaskedTransformFunc{ Decode: true, - Func: func(coeffs []*ring.Complex) { + Func: func(coeffs []*bignum.Complex) { for i := range coeffs { - coeffs[i][0].Mul(coeffs[i][0], ring.NewFloat(0.9238795325112867, logBound)) - coeffs[i][1].Mul(coeffs[i][1], ring.NewFloat(0.7071067811865476, logBound)) + coeffs[i][0].Mul(coeffs[i][0], bignum.NewFloat(0.9238795325112867, logBound)) + coeffs[i][1].Mul(coeffs[i][1], bignum.NewFloat(0.7071067811865476, logBound)) } }, Encode: true, } for i, p := range RefreshParties { - p.GenShare(p.s, p.s, logBound, params.LogSlots(), ciphertext, crp, transform, p.share) + p.GenShare(p.s, p.s, logBound, ciphertext, crp, transform, p.share) if i > 0 { P0.AggregateShares(p.share, P0.share, P0.share) } } - P0.Transform(ciphertext, tc.params.LogSlots(), transform, crp, P0.share, ciphertext) + P0.Transform(ciphertext, transform, crp, P0.share, ciphertext) for i := range coeffs { - coeffs[i] = complex(real(coeffs[i])*0.9238795325112867, imag(coeffs[i])*0.7071067811865476) + coeffs[i][0].Mul(coeffs[i][0], bignum.NewFloat(0.9238795325112867, logBound)) + coeffs[i][1].Mul(coeffs[i][1], bignum.NewFloat(0.7071067811865476, logBound)) } verifyTestVectors(tc, decryptorSk0, coeffs, ciphertext, t) @@ -404,7 +404,7 @@ func testRefreshAndTransformSwitchParams(tc *testContext, t *testing.T) { sk0Shards := tc.sk0Shards params := tc.params - t.Run(testString("RefreshAndTransformAndSwitchParams", tc.NParties, params), func(t *testing.T) { + t.Run(GetTestName("RefreshAndTransformAndSwitchParams", tc.NParties, params), func(t *testing.T) { var minLevel int var logBound uint @@ -434,7 +434,6 @@ func testRefreshAndTransformSwitchParams(tc *testContext, t *testing.T) { LogQ: []int{54, 49, 49, 49, 49, 49, 49}, LogP: []int{52, 52}, RingType: params.RingType(), - LogSlots: params.MaxLogSlots() + 1, LogScale: 49, }) @@ -472,73 +471,145 @@ func testRefreshAndTransformSwitchParams(tc *testContext, t *testing.T) { transform := &MaskedTransformFunc{ Decode: true, - Func: func(coeffs []*ring.Complex) { + Func: func(coeffs []*bignum.Complex) { for i := range coeffs { - coeffs[i][0].Mul(coeffs[i][0], ring.NewFloat(0.9238795325112867, logBound)) - coeffs[i][1].Mul(coeffs[i][1], ring.NewFloat(0.7071067811865476, logBound)) + coeffs[i][0].Mul(coeffs[i][0], bignum.NewFloat(0.9238795325112867, logBound)) + coeffs[i][1].Mul(coeffs[i][1], bignum.NewFloat(0.7071067811865476, logBound)) } }, Encode: true, } for i, p := range RefreshParties { - p.GenShare(p.sIn, p.sOut, logBound, params.LogSlots(), ciphertext, crp, transform, p.share) + p.GenShare(p.sIn, p.sOut, logBound, ciphertext, crp, transform, p.share) if i > 0 { P0.AggregateShares(p.share, P0.share, P0.share) } } - P0.Transform(ciphertext, tc.params.LogSlots(), transform, crp, P0.share, ciphertext) + P0.Transform(ciphertext, transform, crp, P0.share, ciphertext) for i := range coeffs { - coeffs[i] = complex(real(coeffs[i])*0.9238795325112867, imag(coeffs[i])*0.7071067811865476) + coeffs[i][0].Mul(coeffs[i][0], bignum.NewFloat(0.9238795325112867, logBound)) + coeffs[i][1].Mul(coeffs[i][1], bignum.NewFloat(0.7071067811865476, logBound)) } - precStats := ckks.GetPrecisionStats(paramsOut, ckks.NewEncoder(paramsOut), nil, coeffs, ckks.NewDecryptor(paramsOut, skIdealOut).DecryptNew(ciphertext), params.LogSlots(), nil) + precStats := ckks.GetPrecisionStats(paramsOut, ckks.NewEncoder(paramsOut), nil, coeffs, ckks.NewDecryptor(paramsOut, skIdealOut).DecryptNew(ciphertext), nil, false) if *printPrecisionStats { t.Log(precStats.String()) } - require.GreaterOrEqual(t, precStats.MeanPrecision.Real, minPrec) - require.GreaterOrEqual(t, precStats.MeanPrecision.Imag, minPrec) + rf64, _ := precStats.MeanPrecision.Real.Float64() + if64, _ := precStats.MeanPrecision.Imag.Float64() + + minPrec := math.Log2(paramsOut.DefaultScale().Float64()) - float64(paramsOut.LogN()+2) + if minPrec < 0 { + minPrec = 0 + } + + require.GreaterOrEqual(t, rf64, minPrec) + require.GreaterOrEqual(t, if64, minPrec) }) } +func testMarshalling(tc *testContext, t *testing.T) { + params := tc.params + + t.Run(GetTestName("Marshalling/Refresh", tc.NParties, params), func(t *testing.T) { + + var minLevel int + var logBound uint + var ok bool + if minLevel, logBound, ok = GetMinimumLevelForRefresh(128, params.DefaultScale(), tc.NParties, params.Q()); ok != true { + t.Skip("Not enough levels to ensure correctness and 128 security") + } + + ciphertext := ckks.NewCiphertext(params, 1, minLevel) + ciphertext.Scale = params.DefaultScale() + tc.uniformSampler.AtLevel(minLevel).Read(ciphertext.Value[0]) + tc.uniformSampler.AtLevel(minLevel).Read(ciphertext.Value[1]) + + // Testing refresh shares + refreshproto := NewRefreshProtocol(tc.params, logBound, params.Xe()) + refreshshare := refreshproto.AllocateShare(ciphertext.Level(), params.MaxLevel()) + + crp := refreshproto.SampleCRP(params.MaxLevel(), tc.crs) + + refreshproto.GenShare(tc.sk0, logBound, ciphertext, crp, refreshshare) + + data, err := refreshshare.MarshalBinary() + + if err != nil { + t.Fatal("Could not marshal RefreshShare", err) + } + + resRefreshShare := new(MaskedTransformShare) + err = resRefreshShare.UnmarshalBinary(data) + + if err != nil { + t.Fatal("Could not unmarshal RefreshShare", err) + } + + for i, r := range refreshshare.e2sShare.Value.Coeffs { + if !utils.EqualSlice(resRefreshShare.e2sShare.Value.Coeffs[i], r) { + t.Fatal("Result of marshalling not the same as original : RefreshShare") + } + + } + for i, r := range refreshshare.s2eShare.Value.Coeffs { + if !utils.EqualSlice(resRefreshShare.s2eShare.Value.Coeffs[i], r) { + t.Fatal("Result of marshalling not the same as original : RefreshShare") + } + + } + }) +} + +func newTestVectors(testContext *testContext, encryptor rlwe.Encryptor, a, b complex128) (values []*bignum.Complex, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) { func newTestVectors(testContext *testContext, encryptor rlwe.Encryptor, a, b complex128) (values []complex128, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) { return newTestVectorsAtScale(testContext, encryptor, a, b, testContext.params.DefaultScale()) } -func newTestVectorsAtScale(testContext *testContext, encryptor rlwe.Encryptor, a, b complex128, scale rlwe.Scale) (values []complex128, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) { +func newTestVectorsAtScale(tc *testContext, encryptor rlwe.Encryptor, a, b complex128, scale rlwe.Scale) (values []*bignum.Complex, pt *rlwe.Plaintext, ct *rlwe.Ciphertext) { - params := testContext.params + prec := tc.encoder.Prec() - logSlots := params.LogSlots() + pt = ckks.NewPlaintext(tc.params, tc.params.MaxLevel()) + pt.Scale = scale - values = make([]complex128, 1<= 2^{128+logbound} - bound := ring.NewUint(1) + bound := bignum.NewInt(1) bound.Lsh(bound, uint(logBound)) boundMax := new(big.Int).Set(ringQ.ModulusAtLevel[levelQ]) @@ -95,7 +95,7 @@ func (e2s *E2SProtocol) GenShare(sk *rlwe.SecretKey, logBound uint, logSlots int boundHalf := new(big.Int).Rsh(bound, 1) - dslots := 1 << logSlots + dslots := 1 << ct.LogSlots if ringQ.Type() == ring.Standard { dslots *= 2 } @@ -104,7 +104,7 @@ func (e2s *E2SProtocol) GenShare(sk *rlwe.SecretKey, logBound uint, logSlots int // Generate the mask in Z[Y] for Y = X^{N/(2*slots)} for i := 0; i < dslots; i++ { - e2s.maskBigint[i] = ring.RandInt(prng, bound) + e2s.maskBigint[i] = bignum.RandInt(prng, bound) sign = e2s.maskBigint[i].Cmp(boundHalf) if sign == 1 || sign == 0 { e2s.maskBigint[i].Sub(e2s.maskBigint[i], bound) @@ -120,7 +120,7 @@ func (e2s *E2SProtocol) GenShare(sk *rlwe.SecretKey, logBound uint, logSlots int ringQ.SetCoefficientsBigint(secretShareOut.Value[:dslots], e2s.buff) // Maps Y^{N/n} -> X^{N} in Montgomery and NTT - ckks.NttSparseAndMontgomery(ringQ, logSlots, false, e2s.buff) + ckks.NttSparseAndMontgomery(ringQ, ct.LogSlots, false, e2s.buff) // Subtracts the mask to the encryption of zero ringQ.Sub(publicShareOut.Value, e2s.buff, publicShareOut.Value) @@ -131,7 +131,7 @@ func (e2s *E2SProtocol) GenShare(sk *rlwe.SecretKey, logBound uint, logSlots int // If the caller is not secret-key-share holder (i.e., didn't generate a decryption share), `secretShare` can be set to nil. // Therefore, in order to obtain an additive sharing of the message, only one party should call this method, and the other parties should use // the secretShareOut output of the GenShare method. -func (e2s *E2SProtocol) GetShare(secretShare *drlwe.AdditiveShareBigint, aggregatePublicShare *drlwe.CKSShare, logSlots int, ct *rlwe.Ciphertext, secretShareOut *drlwe.AdditiveShareBigint) { +func (e2s *E2SProtocol) GetShare(secretShare *rlwe.AdditiveShareBigint, aggregatePublicShare *drlwe.CKSShare, ct *rlwe.Ciphertext, secretShareOut *rlwe.AdditiveShareBigint) { levelQ := utils.Min(ct.Level(), aggregatePublicShare.Value.Level()) @@ -143,7 +143,7 @@ func (e2s *E2SProtocol) GetShare(secretShare *drlwe.AdditiveShareBigint, aggrega // Switches the LSSS RNS NTT ciphertext outside of the NTT domain ringQ.INTT(e2s.buff, e2s.buff) - dslots := 1 << logSlots + dslots := 1 << ct.LogSlots if ringQ.Type() == ring.Standard { dslots *= 2 } diff --git a/dckks/transform.go b/dckks/transform.go index fb0dd331..5aa80552 100644 --- a/dckks/transform.go +++ b/dckks/transform.go @@ -1,7 +1,6 @@ package dckks import ( - "fmt" "math/big" "github.com/tuneinsight/lattigo/v4/ckks" @@ -9,6 +8,7 @@ import ( "github.com/tuneinsight/lattigo/v4/ring" "github.com/tuneinsight/lattigo/v4/ring/distribution" "github.com/tuneinsight/lattigo/v4/rlwe" + "github.com/tuneinsight/lattigo/v4/utils/bignum" "github.com/tuneinsight/lattigo/v4/utils/sampling" ) @@ -23,7 +23,7 @@ type MaskedTransformProtocol struct { prec uint tmpMask []*big.Int - encoder ckks.EncoderBigComplex + encoder *ckks.Encoder } // ShallowCopy creates a shallow copy of MaskedTransformProtocol in which all the read-only data-structures are @@ -69,7 +69,7 @@ func (rfp *MaskedTransformProtocol) WithParams(paramsOut ckks.Parameters) *Maske // MaskedTransformFunc represents a user-defined in-place function that can be evaluated on masked CKKS plaintexts, as a part of the // Masked Transform Protocol. -// The function is called with a vector of *ring.Complex modulo ckks.Parameters.Slots() as input, and must write +// The function is called with a vector of *Complex modulo ckks.Parameters.Slots() as input, and must write // its output on the same buffer. // Transform can be the identity. // Decode: if true, then the masked CKKS plaintext will be decoded before applying Transform. @@ -77,7 +77,7 @@ func (rfp *MaskedTransformProtocol) WithParams(paramsOut ckks.Parameters) *Maske // i.e. : Decode (true/false) -> Transform -> Recode (true/false). type MaskedTransformFunc struct { Decode bool - Func func(coeffs []*ring.Complex) + Func func(coeffs []*bignum.Complex) Encode bool } @@ -88,10 +88,6 @@ type MaskedTransformFunc struct { // The method will return an error if the maximum number of slots of the output parameters is smaller than the number of slots of the input ciphertext. func NewMaskedTransformProtocol(paramsIn, paramsOut ckks.Parameters, prec uint, noise distribution.Distribution) (rfp *MaskedTransformProtocol, err error) { - if paramsIn.Slots() > paramsOut.MaxSlots() { - return nil, fmt.Errorf("newMaskedTransformProtocol: paramsOut.N()/2 < paramsIn.Slots()") - } - rfp = new(MaskedTransformProtocol) rfp.noise = noise.CopyNew() @@ -103,13 +99,14 @@ func NewMaskedTransformProtocol(paramsIn, paramsOut ckks.Parameters, prec uint, scale := paramsOut.DefaultScale().Value - rfp.defaultScale, _ = new(big.Float).SetPrec(256).Set(&scale).Int(nil) + rfp.defaultScale, _ = new(big.Float).SetPrec(prec).Set(&scale).Int(nil) rfp.tmpMask = make([]*big.Int, paramsIn.N()) for i := range rfp.tmpMask { rfp.tmpMask[i] = new(big.Int) } - rfp.encoder = ckks.NewEncoderBigComplex(paramsIn, prec) + + rfp.encoder = ckks.NewEncoder(paramsIn, prec) return } @@ -129,11 +126,11 @@ func (rfp *MaskedTransformProtocol) SampleCRP(level int, crs sampling.PRNG) drlw // skIn : the secret-key if the input ciphertext. // skOut : the secret-key of the output ciphertext. // logBound : the bit length of the masks. -// logSlots : the bit length of the number of slots. // ct1 : the degree 1 element the ciphertext to refresh, i.e. ct1 = ckk.Ciphetext.Value[1]. // scale : the scale of the ciphertext when entering the refresh. // The method "GetMinimumLevelForBootstrapping" should be used to get the minimum level at which the masked transform can be called while still ensure 128-bits of security, as well as the // value for logBound. +func (rfp *MaskedTransformProtocol) GenShare(skIn, skOut *rlwe.SecretKey, logBound uint, ct *rlwe.Ciphertext, crs drlwe.CKSCRP, transform *MaskedTransformFunc, shareOut *MaskedTransformShare) { func (rfp *MaskedTransformProtocol) GenShare(skIn, skOut *rlwe.SecretKey, logBound uint, logSlots int, ct *rlwe.Ciphertext, crs drlwe.CKSCRP, transform *MaskedTransformFunc, shareOut *drlwe.RefreshShare) { ringQ := rfp.s2e.params.RingQ() @@ -148,7 +145,7 @@ func (rfp *MaskedTransformProtocol) GenShare(skIn, skOut *rlwe.SecretKey, logBou panic("cannot GenShare: crs level must be equal to S2EShare") } - slots := 1 << logSlots + slots := 1 << ct.LogSlots dslots := slots if ringQ.Type() == ring.Standard { @@ -156,16 +153,20 @@ func (rfp *MaskedTransformProtocol) GenShare(skIn, skOut *rlwe.SecretKey, logBou } // Generates the decryption share + // Returns [M_i] on rfp.tmpMask and [a*s_i -M_i + e] on e2sShare + rfp.e2s.GenShare(skIn, logBound, ct, &rlwe.AdditiveShareBigint{Value: rfp.tmpMask}, &shareOut.e2sShare) // Returns [M_i] on rfp.tmpMask and [a*s_i -M_i + e] on E2SShare rfp.e2s.GenShare(skIn, logBound, logSlots, ct, &drlwe.AdditiveShareBigint{Value: rfp.tmpMask}, &shareOut.E2SShare) // Applies LT(M_i) if transform != nil { - bigComplex := make([]*ring.Complex, slots) + bigComplex := make([]*bignum.Complex, slots) for i := range bigComplex { - bigComplex[i] = ring.NewComplex(ring.NewFloat(0, rfp.prec), ring.NewFloat(0, rfp.prec)) + bigComplex[i] = bignum.NewComplex() + bigComplex[i][0].SetPrec(rfp.prec) + bigComplex[i][1].SetPrec(rfp.prec) } // Extracts sparse coefficients @@ -188,7 +189,7 @@ func (rfp *MaskedTransformProtocol) GenShare(skIn, skOut *rlwe.SecretKey, logBou // Decodes if asked to if transform.Decode { - rfp.encoder.FFT(bigComplex, 1< [a, c, b, d] ctN12 = evalCKKS.SlotsToCoeffsNew(ctN12, nil, SlotsToCoeffsMatrix) + ctN12.EncodingDomain = rlwe.CoefficientsDomain // Key-Switch from LogN = 12 to LogN = 11 ctN11 := rlwe.NewCiphertext(paramsN11.Parameters, 1, paramsN11.MaxLevel()) @@ -193,6 +196,7 @@ func main() { // Extracts & EvalLUT(LWEs, indexLUT) on the fly -> Repack(LWEs, indexRepack) -> RLWE ctN12 = evalLUT.EvaluateAndRepack(ctN11, lutPolyMap, repackIndex, LUTKEY) fmt.Printf("Done (%s)\n", time.Since(now)) + ctN12.EncodingDomain = rlwe.CoefficientsDomain fmt.Printf("Homomorphic Encoding... ") now = time.Now() @@ -200,7 +204,11 @@ func main() { ctN12, _ = evalCKKS.CoeffsToSlotsNew(ctN12, CoeffsToSlotsMatrix) fmt.Printf("Done (%s)\n", time.Since(now)) - for i, v := range encoderN12.Decode(decryptorN12.DecryptNew(ctN12), paramsN12.LogSlots()) { + res := make([]float64, slots) + ctN12.EncodingDomain = rlwe.SlotsDomain + ctN12.LogSlots = LogSlots + encoderN12.Decode(decryptorN12.DecryptNew(ctN12), res) + for i, v := range res { fmt.Printf("%7.4f -> %7.4f\n", values[i], v) } } diff --git a/examples/ckks/bootstrapping/main.go b/examples/ckks/bootstrapping/main.go index a5392ff3..d4bdc666 100644 --- a/examples/ckks/bootstrapping/main.go +++ b/examples/ckks/bootstrapping/main.go @@ -28,13 +28,19 @@ func main() { // bootstrapping circuit on top of the residual moduli that we defined. ckksParamsResidualLit := ckks.ParametersLiteral{ LogN: 16, // Log2 of the ringdegree - LogSlots: 15, // Log2 of the number of slots LogQ: []int{55, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40}, // Log2 of the ciphertext prime moduli LogP: []int{61, 61, 61, 61}, // Log2 of the key-switch auxiliary prime moduli LogScale: 40, // Log2 of the scale Xs: &distribution.Ternary{H: 192}, // Hamming weight of the secret } + LogSlots := ckksParamsResidualLit.LogN - 2 + + if *flagShort { + ckksParamsResidualLit.LogN -= 3 + LogSlots -= 3 + } + // Note that with H=192 and LogN=16, parameters are at least 128-bit if LogQP <= 1550. // Our default parameters have an expected logQP of 55 + 10*40 + 4*61 = 699, meaning // that the depth of the bootstrapping shouldn't be larger than 1550-699 = 851. @@ -43,12 +49,18 @@ func main() { // Thus we expect the bootstrapping to give a precision of 27.25 bits with H=192 (and 23.8 with H=N/2) // if the plaintext values are uniformly distributed in [-1, 1] for both the real and imaginary part. // See `/ckks/bootstrapping/parameters.go` for information about the optional fields. - btpParametersLit := bootstrapping.ParametersLiteral{} + btpParametersLit := bootstrapping.ParametersLiteral{ + // Since a ciphertext with message m and LogSlots = x is equivalent to a ciphertext with message m|m and LogSlots = x+1 + // it is possible to run the bootstrapping on any ciphertext with LogSlots <= bootstrapping.LogSlots, however doing so + // will increase the runtime, so it is recommanded to have the LogSlots of the ciphertext and bootstrapping parameters + // be the same. + LogSlots: &LogSlots, + } // The default bootstrapping parameters consume 822 bits which is smaller than the maximum // allowed of 851 in our example, so the target security is easily met. // We can print and verify the expected bit consumption of bootstrapping parameters with: - bits, err := btpParametersLit.BitConsumption() + bits, err := btpParametersLit.BitComsumption(LogSlots) if err != nil { panic(err) } @@ -63,15 +75,8 @@ func main() { } if *flagShort { - - prevLogSlots := ckksParamsLit.LogSlots - - ckksParamsLit.LogN = 13 - // Corrects the message ratio to take into account the smaller number of slots and keep the same precision - btpParams.EvalModParameters.LogMessageRatio += prevLogSlots - ckksParamsLit.LogN - 1 - - ckksParamsLit.LogSlots = ckksParamsLit.LogN - 1 + btpParams.EvalModParameters.LogMessageRatio += 3 } // This generate ckks.Parameters, with the NTT tables and other pre-computations from the ckks.ParametersLiteral (which is only a template). @@ -83,7 +88,7 @@ func main() { // Here we print some information about the generated ckks.Parameters // We can notably check that the LogQP of the generated ckks.Parameters is equal to 699 + 822 = 1521. // Not that this value can be overestimated by one bit. - fmt.Printf("CKKS parameters: logN=%d, logSlots=%d, H(%d; %d), sigma=%f, logQP=%f, levels=%d, scale=2^%f\n", params.LogN(), params.LogSlots(), params.XsHammingWeight(), btpParams.EphemeralSecretWeight, params.Xe(), params.LogQP(), params.QCount(), math.Log2(params.DefaultScale().Float64())) + fmt.Printf("CKKS parameters: logN=%d, logSlots=%d, H(%d; %d), sigma=%f, logQP=%f, levels=%d, scale=2^%f\n", params.LogN(), LogSlots, params.XsHammingWeight(), btpParams.EphemeralSecretWeight, params.Xe(), params.LogQP(), params.QCount(), math.Log2(params.DefaultScale().Float64())) // Scheme context and keys kgen := ckks.NewKeyGenerator(params) @@ -104,13 +109,15 @@ func main() { panic(err) } - // Generate a random plaintext with values uniformly distributed in [-1, 1] for the real and imaginary part. - valuesWant := make([]complex128, params.Slots()) + // Generate a random plaintext with values uniformely distributed in [-1, 1] for the real and imaginary part. + valuesWant := make([]complex128, 1<>1) - idxG := make([]int, params.Slots()>>1) - for i := 0; i < params.Slots()>>1; i++ { + idxF := make([]int, slots>>1) + idxG := make([]int, slots>>1) + for i := 0; i < slots>>1; i++ { idxF[i] = i * 2 // Index with all even slots idxG[i] = i*2 + 1 // Index with all odd slots } @@ -89,21 +99,21 @@ func chebyshevinterpolation() { slotsIndex[1] = idxG // Assigns index of all odd slots to poly[1] = g(x) // Change of variable - evaluator.MultByConst(ciphertext, 2/(b-a), ciphertext) - evaluator.AddConst(ciphertext, (-a-b)/(b-a), ciphertext) + evaluator.Mul(ciphertext, 2/(b-a), ciphertext) + evaluator.Add(ciphertext, (-a-b)/(b-a), ciphertext) if err := evaluator.Rescale(ciphertext, params.DefaultScale(), ciphertext); err != nil { panic(err) } // We evaluate the interpolated Chebyshev interpolant on the ciphertext - if ciphertext, err = evaluator.EvaluatePolyVector(ciphertext, []*ckks.Polynomial{approxF, approxG}, encoder, slotsIndex, ciphertext.Scale); err != nil { + if ciphertext, err = evaluator.EvaluatePolyVector(ciphertext, []*bignum.Polynomial{approxF, approxG}, encoder, slotsIndex, ciphertext.Scale); err != nil { panic(err) } fmt.Println("Done... Consumed levels:", params.MaxLevel()-ciphertext.Level()) // Computation of the reference values - for i := 0; i < params.Slots()>>1; i++ { + for i := 0; i < slots>>1; i++ { values[i*2] = f(values[i*2]) values[i*2+1] = g(values[i*2+1]) } @@ -125,14 +135,11 @@ func round(x float64) float64 { return math.Round(x*100000000) / 100000000 } -func printDebug(params ckks.Parameters, ciphertext *rlwe.Ciphertext, valuesWant []float64, decryptor rlwe.Decryptor, encoder ckks.Encoder) (valuesTest []float64) { +func printDebug(params ckks.Parameters, ciphertext *rlwe.Ciphertext, valuesWant []float64, decryptor rlwe.Decryptor, encoder *ckks.Encoder) (valuesTest []float64) { - tmp := encoder.Decode(decryptor.DecryptNew(ciphertext), params.LogSlots()) + valuesTest = make([]float64, 1<> 1 for i, s := range r.SubRings[:r.level+1] { @@ -172,6 +172,16 @@ func (r *Ring) AddDoubleRNSScalar(p1 *Poly, scalar0, scalar1 RNSScalar, p2 *Poly } } +// SubDoubleRNSScalar evaluates p2 = p1[:N/2] - scalar0 || p1[N/2] - scalar1 coefficient-wise in the ring, +// with the scalar values expressed in the CRT decomposition at a given level. +func (r *Ring) SubDoubleRNSScalar(p1 *Poly, scalar0, scalar1 RNSScalar, p2 *Poly) { + NHalf := r.N() >> 1 + for i, s := range r.SubRings[:r.level+1] { + s.SubScalar(p1.Coeffs[i][:NHalf], scalar0[i], p2.Coeffs[i][:NHalf]) + s.SubScalar(p1.Coeffs[i][NHalf:], scalar1[i], p2.Coeffs[i][NHalf:]) + } +} + // SubScalar evaluates p2 = p1 - scalar coefficient-wise in the ring. func (r *Ring) SubScalar(p1 *Poly, scalar uint64, p2 *Poly) { for i, s := range r.SubRings[:r.level+1] { @@ -183,7 +193,7 @@ func (r *Ring) SubScalar(p1 *Poly, scalar uint64, p2 *Poly) { func (r *Ring) SubScalarBigint(p1 *Poly, scalar *big.Int, p2 *Poly) { tmp := new(big.Int) for i, s := range r.SubRings[:r.level+1] { - s.SubScalar(p1.Coeffs[i], tmp.Mod(scalar, NewUint(s.Modulus)).Uint64(), p2.Coeffs[i]) + s.SubScalar(p1.Coeffs[i], tmp.Mod(scalar, bignum.NewInt(s.Modulus)).Uint64(), p2.Coeffs[i]) } } @@ -221,7 +231,7 @@ func (r *Ring) MulScalarThenSub(p1 *Poly, scalar uint64, p2 *Poly) { func (r *Ring) MulScalarBigint(p1 *Poly, scalar *big.Int, p2 *Poly) { scalarQi := new(big.Int) for i, s := range r.SubRings[:r.level+1] { - scalarQi.Mod(scalar, NewUint(s.Modulus)) + scalarQi.Mod(scalar, bignum.NewInt(s.Modulus)) s.MulScalarMontgomery(p1.Coeffs[i], MForm(scalarQi.Uint64(), s.Modulus, s.BRedConstant), p2.Coeffs[i]) } } diff --git a/ring/primes.go b/ring/primes.go index 75330594..ba057115 100644 --- a/ring/primes.go +++ b/ring/primes.go @@ -3,11 +3,13 @@ package ring import ( "fmt" "math/bits" + + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // IsPrime applies the Baillie-PSW, which is 100% accurate for numbers bellow 2^64. func IsPrime(x uint64) bool { - return NewUint(x).ProbablyPrime(0) + return bignum.NewInt(x).ProbablyPrime(0) } // GenerateNTTPrimes generates n NthRoot NTT friendly primes given logQ = size of the primes. diff --git a/ring/ring.go b/ring/ring.go index 6ba53b3c..b5fe9e29 100644 --- a/ring/ring.go +++ b/ring/ring.go @@ -10,6 +10,7 @@ import ( "math/big" "github.com/tuneinsight/lattigo/v4/utils" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // GaloisGen is an integer of order N/2 modulo M that spans Z_M with the integer -1. @@ -268,9 +269,9 @@ func NewRingWithCustomNTT(N int, ModuliChain []uint64, ntt func(*SubRing, int) N // Computes bigQ for all levels r.ModulusAtLevel = make([]*big.Int, len(ModuliChain)) - r.ModulusAtLevel[0] = NewUint(ModuliChain[0]) + r.ModulusAtLevel[0] = bignum.NewInt(ModuliChain[0]) for i := 1; i < len(ModuliChain); i++ { - r.ModulusAtLevel[i] = new(big.Int).Mul(r.ModulusAtLevel[i-1], NewUint(ModuliChain[i])) + r.ModulusAtLevel[i] = new(big.Int).Mul(r.ModulusAtLevel[i-1], bignum.NewInt(ModuliChain[i])) } r.SubRings = make([]*SubRing, len(ModuliChain)) @@ -396,7 +397,7 @@ func (r *Ring) PolyToBigint(p1 *Poly, gap int, coeffsBigint []*big.Int) { coeffsBigint[i] = new(big.Int) for k := 0; k < r.level+1; k++ { - coeffsBigint[i].Add(coeffsBigint[i], tmp.Mul(NewUint(p1.Coeffs[k][j]), crtReconstruction[k])) + coeffsBigint[i].Add(coeffsBigint[i], tmp.Mul(bignum.NewInt(p1.Coeffs[k][j]), crtReconstruction[k])) } coeffsBigint[i].Mod(coeffsBigint[i], modulusBigint) @@ -436,7 +437,7 @@ func (r *Ring) PolyToBigintCentered(p1 *Poly, gap int, coeffsBigint []*big.Int) coeffsBigint[i].SetUint64(0) for k := 0; k < r.level+1; k++ { - coeffsBigint[i].Add(coeffsBigint[i], tmp.Mul(NewUint(p1.Coeffs[k][j]), crtReconstruction[k])) + coeffsBigint[i].Add(coeffsBigint[i], tmp.Mul(bignum.NewInt(p1.Coeffs[k][j]), crtReconstruction[k])) } coeffsBigint[i].Mod(coeffsBigint[i], modulusBigint) @@ -576,16 +577,16 @@ func (r *Ring) Log2OfStandardDeviation(poly *Poly) (std float64) { r.PolyToBigintCentered(poly, 1, coeffs) - mean := NewFloat(0, prec) - tmp := NewFloat(0, prec) + mean := bignum.NewFloat(0, prec) + tmp := bignum.NewFloat(0, prec) for i := 0; i < N; i++ { mean.Add(mean, tmp.SetInt(coeffs[i])) } - mean.Quo(mean, NewFloat(float64(N), prec)) + mean.Quo(mean, bignum.NewFloat(float64(N), prec)) - stdFloat := NewFloat(0, prec) + stdFloat := bignum.NewFloat(0, prec) for i := 0; i < N; i++ { tmp.SetInt(coeffs[i]) @@ -594,7 +595,7 @@ func (r *Ring) Log2OfStandardDeviation(poly *Poly) (std float64) { stdFloat.Add(stdFloat, tmp) } - stdFloat.Quo(stdFloat, NewFloat(float64(N-1), prec)) + stdFloat.Quo(stdFloat, bignum.NewFloat(float64(N-1), prec)) stdFloat.Sqrt(stdFloat) diff --git a/ring/ring_benchmark_test.go b/ring/ring_benchmark_test.go index 702a8aad..1c4bace4 100644 --- a/ring/ring_benchmark_test.go +++ b/ring/ring_benchmark_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/tuneinsight/lattigo/v4/ring/distribution" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) func BenchmarkRing(b *testing.B) { @@ -265,8 +266,8 @@ func benchMulScalar(tc *testParams, b *testing.B) { rand1 := RandUniform(tc.prng, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF) rand2 := RandUniform(tc.prng, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF) - scalarBigint := NewUint(rand1) - scalarBigint.Mul(scalarBigint, NewUint(rand2)) + scalarBigint := bignum.NewInt(rand1) + scalarBigint.Mul(scalarBigint, bignum.NewInt(rand2)) b.Run(testString("MulScalar/uint64/", tc.ringQ), func(b *testing.B) { for i := 0; i < b.N; i++ { diff --git a/ring/ring_test.go b/ring/ring_test.go index ec977e4f..f712353f 100644 --- a/ring/ring_test.go +++ b/ring/ring_test.go @@ -14,6 +14,7 @@ import ( "github.com/tuneinsight/lattigo/v4/utils/structs" "github.com/stretchr/testify/require" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) var flagLongTest = flag.Bool("long", false, "run the long test suite (all parameters). Overrides -short and requires -timeout=0.") @@ -231,8 +232,8 @@ func testDivFloorByLastModulusMany(tc *testParams, t *testing.T) { coeffs := make([]*big.Int, N) for i := 0; i < N; i++ { - coeffs[i] = RandInt(prng, tc.ringQ.ModulusAtLevel[level]) - coeffs[i].Quo(coeffs[i], NewUint(10)) + coeffs[i] = bignum.RandInt(prng, tc.ringQ.ModulusAtLevel[level]) + coeffs[i].Quo(coeffs[i], bignum.NewInt(10)) } nbRescales := level @@ -241,7 +242,7 @@ func testDivFloorByLastModulusMany(tc *testParams, t *testing.T) { for i := range coeffs { coeffsWant[i] = new(big.Int).Set(coeffs[i]) for j := 0; j < nbRescales; j++ { - coeffsWant[i].Quo(coeffsWant[i], NewUint(tc.ringQ.SubRings[level-j].Modulus)) + coeffsWant[i].Quo(coeffsWant[i], bignum.NewInt(tc.ringQ.SubRings[level-j].Modulus)) } } @@ -264,7 +265,7 @@ func testDivFloorByLastModulusMany(tc *testParams, t *testing.T) { func testDivRoundByLastModulusMany(tc *testParams, t *testing.T) { - t.Run(testString("DivRoundByLastModulusMany", tc.ringQ), func(t *testing.T) { + t.Run(testString("bignum.DivRoundByLastModulusMany", tc.ringQ), func(t *testing.T) { prng, _ := sampling.NewPRNG() @@ -276,8 +277,8 @@ func testDivRoundByLastModulusMany(tc *testParams, t *testing.T) { coeffs := make([]*big.Int, N) for i := 0; i < N; i++ { - coeffs[i] = RandInt(prng, tc.ringQ.ModulusAtLevel[level]) - coeffs[i].Quo(coeffs[i], NewUint(10)) + coeffs[i] = bignum.RandInt(prng, tc.ringQ.ModulusAtLevel[level]) + coeffs[i].Quo(coeffs[i], bignum.NewInt(10)) } nbRescals := level @@ -286,7 +287,7 @@ func testDivRoundByLastModulusMany(tc *testParams, t *testing.T) { for i := range coeffs { coeffsWant[i] = new(big.Int).Set(coeffs[i]) for j := 0; j < nbRescals; j++ { - DivRound(coeffsWant[i], NewUint(tc.ringQ.SubRings[level-j].Modulus), coeffsWant[i]) + bignum.DivRound(coeffsWant[i], bignum.NewInt(tc.ringQ.SubRings[level-j].Modulus), coeffsWant[i]) } } @@ -501,15 +502,15 @@ func testModularReduction(tc *testParams, t *testing.T) { for j, q := range tc.ringQ.ModuliChain() { - bigQ = NewUint(q) + bigQ = bignum.NewInt(q) brc := tc.ringQ.SubRings[j].BRedConstant x = 1 y = 1 - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, BRed(x, y, q, brc), result.Uint64(), "x = %v, y=%v", x, y) @@ -517,8 +518,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = 1 y = q - 1 - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, BRed(x, y, q, brc), result.Uint64(), "x = %v, y=%v", x, y) @@ -526,8 +527,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = 1 y = 0xFFFFFFFFFFFFFFFF - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, BRed(x, y, q, brc), result.Uint64(), "x = %v, y=%v", x, y) @@ -535,8 +536,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = q - 1 y = q - 1 - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, BRed(x, y, q, brc), result.Uint64(), "x = %v, y=%v", x, y) @@ -544,8 +545,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = q - 1 y = 0xFFFFFFFFFFFFFFFF - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, BRed(x, y, q, brc), result.Uint64(), "x = %v, y=%v", x, y) @@ -553,8 +554,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = 0xFFFFFFFFFFFFFFFF y = 0xFFFFFFFFFFFFFFFF - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, BRed(x, y, q, brc), result.Uint64(), "x = %v, y=%v", x, y) @@ -568,7 +569,7 @@ func testModularReduction(tc *testParams, t *testing.T) { for j, q := range tc.ringQ.ModuliChain() { - bigQ = NewUint(q) + bigQ = bignum.NewInt(q) brc := tc.ringQ.SubRings[j].BRedConstant mrc := tc.ringQ.SubRings[j].MRedConstant @@ -576,8 +577,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = 1 y = 1 - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, MRed(x, MForm(y, q, brc), q, mrc), result.Uint64(), "x = %v, y=%v", x, y) @@ -585,8 +586,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = 1 y = q - 1 - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, MRed(x, MForm(y, q, brc), q, mrc), result.Uint64(), "x = %v, y=%v", x, y) @@ -594,8 +595,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = 1 y = 0xFFFFFFFFFFFFFFFF - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, MRed(x, MForm(y, q, brc), q, mrc), result.Uint64(), "x = %v, y=%v", x, y) @@ -603,8 +604,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = q - 1 y = q - 1 - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, MRed(x, MForm(y, q, brc), q, mrc), result.Uint64(), "x = %v, y=%v", x, y) @@ -612,8 +613,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = q - 1 y = 0xFFFFFFFFFFFFFFFF - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, MRed(x, MForm(y, q, brc), q, mrc), result.Uint64(), "x = %v, y=%v", x, y) @@ -621,8 +622,8 @@ func testModularReduction(tc *testParams, t *testing.T) { x = 0xFFFFFFFFFFFFFFFF y = 0xFFFFFFFFFFFFFFFF - result = NewUint(x) - result.Mul(result, NewUint(y)) + result = bignum.NewInt(x) + result.Mul(result, bignum.NewInt(y)) result.Mod(result, bigQ) require.Equalf(t, MRed(x, MForm(y, q, brc), q, mrc), result.Uint64(), "x = %v, y=%v", x, y) @@ -654,8 +655,8 @@ func testMulScalarBigint(tc *testParams, t *testing.T) { rand1 := RandUniform(tc.prng, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF) rand2 := RandUniform(tc.prng, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF) - scalarBigint := NewUint(rand1) - scalarBigint.Mul(scalarBigint, NewUint(rand2)) + scalarBigint := bignum.NewInt(rand1) + scalarBigint.Mul(scalarBigint, bignum.NewInt(rand2)) tc.ringQ.MulScalar(polWant, rand1, polWant) tc.ringQ.MulScalar(polWant, rand2, polWant) @@ -688,7 +689,7 @@ func testExtendBasis(tc *testParams, t *testing.T) { coeffs := make([]*big.Int, N) for i := 0; i < N; i++ { - coeffs[i] = RandInt(prng, Q) + coeffs[i] = bignum.RandInt(prng, Q) coeffs[i].Sub(coeffs[i], QHalf) } @@ -728,7 +729,7 @@ func testExtendBasis(tc *testParams, t *testing.T) { coeffs := make([]*big.Int, N) for i := 0; i < N; i++ { - coeffs[i] = RandInt(prng, P) + coeffs[i] = bignum.RandInt(prng, P) coeffs[i].Sub(coeffs[i], PHalf) } @@ -768,14 +769,14 @@ func testExtendBasis(tc *testParams, t *testing.T) { coeffs := make([]*big.Int, N) for i := 0; i < N; i++ { - coeffs[i] = RandInt(prng, QP) - coeffs[i].Quo(coeffs[i], NewUint(10)) + coeffs[i] = bignum.RandInt(prng, QP) + coeffs[i].Quo(coeffs[i], bignum.NewInt(10)) } coeffsWant := make([]*big.Int, N) for i := range coeffs { coeffsWant[i] = new(big.Int).Set(coeffs[i]) - DivRound(coeffsWant[i], P, coeffsWant[i]) + bignum.DivRound(coeffsWant[i], P, coeffsWant[i]) } PolQHave := ringQ.NewPoly() @@ -815,14 +816,14 @@ func testExtendBasis(tc *testParams, t *testing.T) { coeffs := make([]*big.Int, N) for i := 0; i < N; i++ { - coeffs[i] = RandInt(prng, QP) - coeffs[i].Quo(coeffs[i], NewUint(10)) + coeffs[i] = bignum.RandInt(prng, QP) + coeffs[i].Quo(coeffs[i], bignum.NewInt(10)) } coeffsWant := make([]*big.Int, N) for i := range coeffs { coeffsWant[i] = new(big.Int).Set(coeffs[i]) - DivRound(coeffsWant[i], Q, coeffsWant[i]) + bignum.DivRound(coeffsWant[i], Q, coeffsWant[i]) } PolQHave := ringQ.NewPoly() diff --git a/ring/sampler_gaussian.go b/ring/sampler_gaussian.go index ec50f2fa..b3e85ab9 100644 --- a/ring/sampler_gaussian.go +++ b/ring/sampler_gaussian.go @@ -7,6 +7,7 @@ import ( "github.com/tuneinsight/lattigo/v4/ring/distribution" "github.com/tuneinsight/lattigo/v4/utils/sampling" + "github.com/tuneinsight/lattigo/v4/utils/bignum" ) // GaussianSampler keeps the state of a truncated Gaussian polynomial sampler. @@ -100,7 +101,7 @@ func (g *GaussianSampler) read(pol *Poly, f func(a, b, c uint64) uint64) { Qi := make([]*big.Int, len(moduli)) for i, qi := range moduli { - Qi[i] = NewUint(qi) + Qi[i] = bignum.NewInt(qi) } var coeffInt *big.Int @@ -125,9 +126,9 @@ func (g *GaussianSampler) read(pol *Poly, f func(a, b, c uint64) uint64) { normInt.Lsh(sigmaInt, uint(math.Log2(norm)+bias)) } - coeffInt = RandInt(g.prng, normInt) + coeffInt = bignum.RandInt(g.prng, normInt) - coeffInt.Mul(coeffInt, NewInt(2*int64(sign)-1)) + coeffInt.Mul(coeffInt, bignum.NewInt(2*int64(sign)-1)) if coeffInt.Cmp(boundInt) < 1 { break diff --git a/rlwe/evaluator.go b/rlwe/evaluator.go index 4e86f180..838047b0 100644 --- a/rlwe/evaluator.go +++ b/rlwe/evaluator.go @@ -8,6 +8,14 @@ import ( "github.com/tuneinsight/lattigo/v4/utils" ) +// Operand is a common interface for Ciphertext and Plaintext types. +type Operand interface { + El() *Ciphertext + Degree() int + Level() int + GetMetaData() *MetaData +} + // Evaluator is a struct that holds the necessary elements to execute general homomorphic // operation on RLWE ciphertexts, such as automorphisms, key-switching and relinearization. type Evaluator struct { @@ -136,11 +144,17 @@ func (eval *Evaluator) CheckAndGetRelinearizationKey() (evk *RelinearizationKey, // CheckBinary checks that: // -// Inputs are not nil -// op0.Degree() + op1.Degree() != 0 (i.e at least one operand is a ciphertext) -// opOut.Degree() >= opOutMinDegree -// op0.IsNTT = DefaultNTTFlag -// op1.IsNTT = DefaultNTTFlag +// Inputs are not nil +// op0.Degree() + op1.Degree() != 0 (i.e at least one operand is a ciphertext) +// opOut.Degree() >= opOutMinDegree +// op0.IsNTT == op1.IsNTT == DefaultNTTFlag +// op0.EncodingDomain == op1.EncodingDomain +// +// The method will also update the MetaData of OpOut: +// +// IsNTT <- DefaultNTTFlag +// EncodingDomain <- op0.EncodingDomain +// LogSlots <- max(op0.LogSlots, op1.LogSlots) // // and returns max(op0.Degree(), op1.Degree(), opOut.Degree()) and min(op0.Level(), op1.Level(), opOut.Level()) func (eval *Evaluator) CheckBinary(op0, op1, opOut Operand, opOutMinDegree int) (degree, level int) { @@ -164,12 +178,31 @@ func (eval *Evaluator) CheckBinary(op0, op1, opOut Operand, opOutMinDegree int) opOut.El().IsNTT = op0.El().IsNTT } - opOut.El().Resize(utils.Max(opOutMinDegree, opOut.Degree()), level) + if op0.El().IsNTT != op1.El().IsNTT || op0.El().IsNTT != eval.params.DefaultNTTFlag() { + panic(fmt.Sprintf("op0.El().IsNTT or op1.El().IsNTT != %t", eval.params.DefaultNTTFlag())) + } else { + opOut.El().IsNTT = op0.El().IsNTT + } + + if op0.El().EncodingDomain != op1.El().EncodingDomain { + panic("op1.El().EncodingDomain != op2.El().EncodingDomain") + } else { + opOut.El().EncodingDomain = op0.El().EncodingDomain + } + + opOut.El().LogSlots = utils.MaxInt(op0.El().LogSlots, op1.El().LogSlots) return } // CheckUnary checks that op0 and opOut are not nil and that op0 respects the DefaultNTTFlag. +// +// The method will also update the metadata of opOut: +// +// IsNTT <- DefaultNTTFlag +// EncodingDomain <- op0.EncodingDomain +// LogSlots <- op0.LogSlots +// // Also returns max(op0.Degree(), opOut.Degree()) and min(op0.Level(), opOut.Level()). func (eval *Evaluator) CheckUnary(op0, opOut Operand) (degree, level int) { @@ -179,9 +212,15 @@ func (eval *Evaluator) CheckUnary(op0, opOut Operand) (degree, level int) { if op0.El().IsNTT != eval.params.DefaultNTTFlag() { panic(fmt.Sprintf("op0.IsNTT() != %t", eval.params.DefaultNTTFlag())) + } else { + opOut.El().IsNTT = op0.El().IsNTT } - return utils.Max(op0.Degree(), opOut.Degree()), utils.Min(op0.Level(), opOut.Level()) + opOut.El().EncodingDomain = op0.El().EncodingDomain + + opOut.El().LogSlots = op0.El().LogSlots + + return utils.MaxInt(op0.Degree(), opOut.Degree()), utils.MinInt(op0.Level(), opOut.Level()) } // ShallowCopy creates a shallow copy of this Evaluator in which all the read-only data-structures are diff --git a/rlwe/linear_transform.go b/rlwe/linear_transform.go index 231a3d5e..aaed9eb8 100644 --- a/rlwe/linear_transform.go +++ b/rlwe/linear_transform.go @@ -124,6 +124,7 @@ func (eval *Evaluator) Expand(ctIn *Ciphertext, logN, logGap int) (ctOut []*Ciph ctOut = make([]*Ciphertext, 1<<(logN-logGap)) ctOut[0] = ctIn.CopyNew() + ctOut[0].LogSlots = 0 if ct := ctOut[0]; !ctIn.IsNTT { ringQ.NTT(ct.Value[0], ct.Value[0]) diff --git a/rlwe/metadata.go b/rlwe/metadata.go index a0a6b75c..06807e9b 100644 --- a/rlwe/metadata.go +++ b/rlwe/metadata.go @@ -7,21 +7,37 @@ import ( "github.com/google/go-cmp/cmp" ) +type EncodingDomain int + +const ( + SlotsDomain = EncodingDomain(0) + CoefficientsDomain = EncodingDomain(1) +) + // MetaData is a struct storing metadata. type MetaData struct { Scale - IsNTT bool - IsMontgomery bool + EncodingDomain EncodingDomain + LogSlots int + IsNTT bool + IsMontgomery bool } // Equal returns true if two MetaData structs are identical. -func (m *MetaData) Equal(other *MetaData) (res bool) { return cmp.Equal(&m.Scale, &other.Scale) && m.IsNTT == other.IsNTT && m.IsMontgomery == other.IsMontgomery +func (m *MetaData) Equal(other MetaData) (res bool) { + res = m.Scale.Cmp(other.Scale) == 0 + res = res && m.EncodingDomain == other.EncodingDomain + res = res && m.LogSlots == other.LogSlots + res = res && m.IsNTT == other.IsNTT + res = res && m.IsMontgomery == other.IsMontgomery + return } -// BinarySize returns the size in bytes that the object once marshalled into a binary form. -func (m *MetaData) BinarySize() int { - return 2 + m.Scale.BinarySize() +// Slots returns the number of slots. +func (m *MetaData) Slots() int { + return 1 << m.LogSlots +} } // MarshalBinary encodes the object into a binary form on a newly allocated slice of bytes. @@ -72,6 +88,14 @@ func (m *MetaData) Encode(p []byte) (n int, err error) { return 0, err } + ptr += inc + + data[ptr] = uint8(m.EncodingDomain) + ptr++ + + data[ptr] = uint8(m.LogSlots) + ptr++ + if m.IsNTT { p[n] = 1 } @@ -99,6 +123,12 @@ func (m *MetaData) Decode(p []byte) (n int, err error) { return } + m.EncodingDomain = EncodingDomain(data[ptr]) + ptr++ + + m.LogSlots = int(data[ptr]) + ptr++ + m.IsNTT = p[n] == 1 n++ diff --git a/rlwe/params.go b/rlwe/params.go index e2831077..cf5a3992 100644 --- a/rlwe/params.go +++ b/rlwe/params.go @@ -381,11 +381,6 @@ func (p Parameters) Q() []uint64 { return qi } -// QiFloat64 returns the float64 value of the Qi at position level in the modulus chain. -func (p Parameters) QiFloat64(level int) float64 { - return float64(p.qi[level]) -} - // QCount returns the number of factors of the ciphertext modulus Q func (p Parameters) QCount() int { return len(p.qi) diff --git a/utils/bignum/complex.go b/utils/bignum/complex.go new file mode 100644 index 00000000..f11302f0 --- /dev/null +++ b/utils/bignum/complex.go @@ -0,0 +1,202 @@ +// Package bignum implements arbitrary precision arithmetic for integers, reals and complex numbers. +package bignum + +import ( + "fmt" + "math/big" +) + +// Complex is a type for arbitrary precision complex number +type Complex [2]*big.Float + +// NewComplex creates a new arbitrary precision complex number +func NewComplex() (c *Complex) { + return &Complex{ + new(big.Float), + new(big.Float), + } +} + +// ToComplex takes a complex128, float64, int, int64, uint, uint64, *big.Int, *big.Float or *Complex and returns a *Complex set to the given precision. +func ToComplex(value interface{}, prec uint) (cmplx *Complex) { + + cmplx = new(Complex) + + switch value := value.(type) { + case complex128: + cmplx[0] = new(big.Float).SetPrec(prec).SetFloat64(real(value)) + cmplx[1] = new(big.Float).SetPrec(prec).SetFloat64(imag(value)) + case float64: + cmplx[0] = new(big.Float).SetPrec(prec).SetFloat64(value) + cmplx[1] = new(big.Float).SetPrec(prec) + case int: + cmplx[0] = new(big.Float).SetPrec(prec).SetInt64(int64(value)) + cmplx[1] = new(big.Float).SetPrec(prec) + case int64: + cmplx[0] = new(big.Float).SetPrec(prec).SetInt64(value) + cmplx[1] = new(big.Float).SetPrec(prec) + case uint64: + return ToComplex(new(big.Int).SetUint64(value), prec) + case *big.Float: + cmplx[0] = new(big.Float).SetPrec(prec).Set(value) + cmplx[1] = new(big.Float).SetPrec(prec) + case *big.Int: + cmplx[0] = new(big.Float).SetPrec(prec).SetInt(value) + cmplx[1] = new(big.Float).SetPrec(prec) + case *Complex: + cmplx[0] = new(big.Float).SetPrec(prec).Set(value[0]) + cmplx[1] = new(big.Float).SetPrec(prec).Set(value[1]) + default: + panic(fmt.Errorf("invalid value.(type): must be int, int64, uint64, float64, complex128, *big.Int, *big.Float or *Complex but is %T", value)) + } + + return +} + +// IsInt returns true if both the real and imaginary part are integers. +func (c *Complex) IsInt() bool { + return c[0].IsInt() && c[1].IsInt() +} + +func (c *Complex) IsReal() bool { + return c[1].Cmp(new(big.Float)) == 0 +} + +func (c *Complex) SetComplex128(x complex128) { + c[0].SetFloat64(real(x)) + c[1].SetFloat64(real(x)) +} + +// Set sets a arbitrary precision complex number +func (c *Complex) Set(a *Complex) { + c[0].Set(a[0]) + c[1].Set(a[1]) +} + +func (c *Complex) Prec() uint { + return c[0].Prec() +} + +func (c *Complex) SetPrec(prec uint) { + c[0].SetPrec(prec) + c[1].SetPrec(prec) +} + +// Copy returns a new copy of the target arbitrary precision complex number +func (c *Complex) Copy() *Complex { + return &Complex{new(big.Float).Set(c[0]), new(big.Float).Set(c[1])} +} + +// Real returns the real part as a big.Float +func (c *Complex) Real() *big.Float { + return c[0] +} + +// Imag returns the imaginary part as a big.Float +func (c *Complex) Imag() *big.Float { + return c[1] +} + +// Complex128 returns the arbitrary precision complex number as a complex128 +func (c *Complex) Complex128() complex128 { + + real, _ := c[0].Float64() + imag, _ := c[1].Float64() + + return complex(real, imag) +} + +// Add adds two arbitrary precision complex numbers together +func (c *Complex) Add(a, b *Complex) { + c[0].Add(a[0], b[0]) + c[1].Add(a[1], b[1]) +} + +// Sub subtracts two arbitrary precision complex numbers together +func (c *Complex) Sub(a, b *Complex) { + c[0].Sub(a[0], b[0]) + c[1].Sub(a[1], b[1]) +} + +// ComplexMultiplier is a struct for the multiplication or division of two arbitrary precision complex numbers +type ComplexMultiplier struct { + tmp0 *big.Float + tmp1 *big.Float + tmp2 *big.Float + tmp3 *big.Float +} + +// NewComplexMultiplier creates a new ComplexMultiplier +func NewComplexMultiplier() (cEval *ComplexMultiplier) { + cEval = new(ComplexMultiplier) + cEval.tmp0 = new(big.Float) + cEval.tmp1 = new(big.Float) + cEval.tmp2 = new(big.Float) + cEval.tmp3 = new(big.Float) + return +} + +// Mul multiplies two arbitrary precision complex numbers together +func (cEval *ComplexMultiplier) Mul(a, b, c *Complex) { + + if a.IsReal() { + if b.IsReal() { + c[0].Mul(a[0], b[0]) + c[1].SetFloat64(0) + } else { + c[1].Mul(a[0], b[1]) + c[0].Mul(a[0], b[0]) + } + } else { + if b.IsReal() { + c[1].Mul(a[1], b[0]) + c[0].Mul(a[0], b[0]) + } else { + cEval.tmp0.Mul(a[0], b[0]) + cEval.tmp1.Mul(a[1], b[1]) + cEval.tmp2.Mul(a[0], b[1]) + cEval.tmp3.Mul(a[1], b[0]) + + c[0].Sub(cEval.tmp0, cEval.tmp1) + c[1].Add(cEval.tmp2, cEval.tmp3) + } + } +} + +// Quo divides two arbitrary precision complex numbers together +func (cEval *ComplexMultiplier) Quo(a, b, c *Complex) { + + if a.IsReal() { + if b.IsReal() { + c[0].Quo(a[0], b[0]) + c[1].SetFloat64(0) + } else { + c[1].Quo(a[0], b[1]) + c[0].Quo(a[0], b[0]) + } + } else { + if b.IsReal() { + c[1].Quo(a[1], b[0]) + c[0].Quo(a[0], b[0]) + } else { + // tmp0 = (a[0] * b[0]) + (a[1] * b[1]) real part + // tmp1 = (a[1] * b[0]) - (a[0] * b[0]) imag part + // tmp2 = (b[0] * b[0]) + (b[1] * b[1]) denominator + + cEval.tmp0.Mul(a[0], b[0]) + cEval.tmp1.Mul(a[1], b[1]) + cEval.tmp2.Mul(a[1], b[0]) + cEval.tmp3.Mul(a[0], b[1]) + + cEval.tmp0.Add(cEval.tmp0, cEval.tmp1) + cEval.tmp1.Sub(cEval.tmp2, cEval.tmp3) + + cEval.tmp2.Mul(b[0], b[0]) + cEval.tmp3.Mul(b[1], b[1]) + cEval.tmp2.Add(cEval.tmp2, cEval.tmp3) + + c[0].Quo(cEval.tmp0, cEval.tmp2) + c[1].Quo(cEval.tmp1, cEval.tmp2) + } + } +} diff --git a/utils/bignum/float.go b/utils/bignum/float.go new file mode 100644 index 00000000..6bf186ee --- /dev/null +++ b/utils/bignum/float.go @@ -0,0 +1,142 @@ +package bignum + +import ( + "math" + "math/big" + + "github.com/ALTree/bigfloat" +) + +const pi = "3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679821480865132823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461284756482337867831652712019091456485669234603486104543266482133936072602491412737245870066063155881748815209209628292540917153643678925903600113305305488204665213841469519415116094330572703657595919530921861173819326117931051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798609437027705392171762931767523846748184676694051320005681271452635608277857713427577896091736371787214684409012249534301465495853710507922796892589235420199561121290219608640344181598136297747713099605187072113499999983729780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083814206171776691473035982534904287554687311595628638823537875937519577818577805321712268066130019278766111959092164201989" + +// Pi returns Pi with prec bits of precision. +func Pi(prec uint) *big.Float { + pi, _ := new(big.Float).SetPrec(prec).SetString(pi) + return pi +} + +// NewFloat creates a new big.Float element with "prec" bits of precision +func NewFloat(x interface{}, prec uint) (y *big.Float) { + + y = new(big.Float) + y.SetPrec(prec) // decimal precision + + if x == nil { + return + } + + switch x := x.(type) { + case int: + y.SetInt64(int64(x)) + case int64: + y.SetInt64(x) + case uint: + y.SetUint64(uint64(x)) + case uint64: + y.SetUint64(x) + case float64: + y.SetFloat64(x) + case *big.Int: + y.SetInt(x) + case *big.Float: + y.Set(x) + } + + return +} + +// Cos is an iterative arbitrary precision computation of Cos(x) +// Iterative process with an error of ~10^{−0.60206*k} = (1/4)^k after k iterations. +// ref : Johansson, B. Tomas, An elementary algorithm to evaluate trigonometric functions to high precision, 2018 +func Cos(x *big.Float) (cosx *big.Float) { + tmp := new(big.Float) + + t := NewFloat(0.5, x.Prec()) + half := new(big.Float).Copy(t) + + for i := uint(1); i < (x.Prec()>>1)-1; i++ { + t.Mul(t, half) + } + + s := new(big.Float).Mul(x, t) + s.Mul(s, x) + s.Mul(s, t) + + four := NewFloat(4.0, x.Prec()) + + for i := uint(1); i < x.Prec()>>1; i++ { // (1/4)^k = (1/2)^(2*k) + tmp.Sub(four, s) + s.Mul(s, tmp) + } + + cosx = new(big.Float).Quo(s, NewFloat(2.0, x.Prec())) + cosx.Sub(NewFloat(1.0, x.Prec()), cosx) + return +} + +func Sin(x *big.Float) (sinx *big.Float) { + halfPi := Pi(x.Prec()) + halfPi.Quo(halfPi, new(big.Float).SetInt64(2)) + return Cos(new(big.Float).Sub(x, halfPi)) +} + +// Log return ln(x) with 2^precisions bits. +func Log(x *big.Float) (ln *big.Float) { + return bigfloat.Log(x) +} + +// Exp returns exp(x) with 2^precisions bits. +func Exp(x *big.Float) (exp *big.Float) { + return bigfloat.Exp(x) +} + +// Pow returns x^y +func Pow(x, y *big.Float) (pow *big.Float) { + return bigfloat.Pow(x, y) +} + +// SinH returns hyperbolic sin(x) with 2^precisions bits. +func SinH(x *big.Float) (sinh *big.Float) { + sinh = new(big.Float).Set(x) + sinh.Add(sinh, sinh) + sinh.Neg(sinh) + sinh = Exp(sinh) + sinh.Neg(sinh) + sinh.Add(sinh, NewFloat(1, x.Prec())) + tmp := new(big.Float).Set(x) + tmp.Neg(tmp) + tmp = Exp(tmp) + tmp.Add(tmp, tmp) + sinh.Quo(sinh, tmp) + return +} + +// TanH returns hyperbolic tan(x) with 2^precisions bits. +func TanH(x *big.Float) (tanh *big.Float) { + tanh = new(big.Float).Set(x) + tanh.Add(tanh, tanh) + tanh = Exp(tanh) + tmp := new(big.Float).Set(tanh) + tmp.Add(tmp, NewFloat(1, x.Prec())) + tanh.Sub(tanh, NewFloat(1, x.Prec())) + tanh.Quo(tanh, tmp) + return +} + +// ArithmeticGeometricMean returns the arithmetic–geometric mean of x and y with 2^precisions bits. +func ArithmeticGeometricMean(x, y *big.Float) *big.Float { + precision := x.Prec() + a := new(big.Float).Set(x) + g := new(big.Float).Set(y) + tmp := new(big.Float) + half := NewFloat(0.5, x.Prec()) + + for i := 0; i < int(math.Log2(float64(precision))); i++ { + tmp.Mul(a, g) + a.Add(a, g) + a.Mul(a, half) + g.Sqrt(tmp) + } + + return a +} diff --git a/ring/int.go b/utils/bignum/int.go similarity index 50% rename from ring/int.go rename to utils/bignum/int.go index 1da13258..5120d5fd 100644 --- a/ring/int.go +++ b/utils/bignum/int.go @@ -1,29 +1,40 @@ -package ring +package bignum import ( "crypto/rand" + "fmt" "io" "math/big" ) -// NewInt creates a new Int with a given int64 value. -func NewInt(v int64) *big.Int { - return new(big.Int).SetInt64(v) -} +func NewInt(x interface{}) (y *big.Int) { -// NewUint creates a new Int with a given uint64 value. -func NewUint(v uint64) *big.Int { - return new(big.Int).SetUint64(v) -} + y = new(big.Int) -// NewIntFromString creates a new Int from a string. -// A prefix of "0x" or "0X" selects base 16; -// the "0" prefix selects base 8, and -// a "0b" or "0B" prefix selects base 2. -// Otherwise, the selected base is 10. -func NewIntFromString(s string) *big.Int { - i, _ := new(big.Int).SetString(s, 0) - return i + if x == nil { + return + } + + switch x := x.(type) { + case string: + y.SetString(x, 0) + case uint: + y.SetUint64(uint64(x)) + case uint64: + y.SetUint64(x) + case int64: + y.SetInt64(x) + case int: + y.SetInt64(int64(x)) + case *big.Float: + x.Int(y) + case *big.Int: + y.Set(x) + default: + panic(fmt.Sprintf("cannot Newint: accepted types are string, uint, uint64, int, int64, *big.Float, *big.Int, but is %T", x)) + } + + return } // RandInt generates a random Int in [0, max-1]. diff --git a/ring/int_test.go b/utils/bignum/int_test.go similarity index 68% rename from ring/int_test.go rename to utils/bignum/int_test.go index a18fb9da..c17b7b32 100644 --- a/ring/int_test.go +++ b/utils/bignum/int_test.go @@ -1,4 +1,4 @@ -package ring +package bignum import ( "math" @@ -24,9 +24,9 @@ var divRoundVec = []argDivRound{ {NewInt(987654321), NewInt(123456789), NewInt(8)}, {NewInt(-987654320), NewInt(123456789), NewInt(-8)}, {NewInt(-121932631112635269), NewInt(-987654321), NewInt(123456789)}, - {NewIntFromString("123456789123456789123456789123456789"), NewInt(123456789), NewIntFromString("1000000001000000001000000001")}, - {NewIntFromString("987654321987654321987654321987654321"), NewIntFromString("123456789123456789123456789123456789"), NewInt(8)}, - {NewIntFromString("-987654321987654321987654321987654321"), NewIntFromString("-123456789123456789123456789123456789"), NewInt(8)}, + {NewInt("123456789123456789123456789123456789"), NewInt(123456789), NewInt("1000000001000000001000000001")}, + {NewInt("987654321987654321987654321987654321"), NewInt("123456789123456789123456789123456789"), NewInt(8)}, + {NewInt("-987654321987654321987654321987654321"), NewInt("-123456789123456789123456789123456789"), NewInt(8)}, } func TestDivRound(t *testing.T) { @@ -39,8 +39,8 @@ func TestDivRound(t *testing.T) { func BenchmarkDivRound(b *testing.B) { z := new(big.Int) - x := NewIntFromString("123456789123456789123456789123456789") - y := NewIntFromString("987654321987654321987654321987654321") + x := NewInt("123456789123456789123456789123456789") + y := NewInt("987654321987654321987654321987654321") for i := 0; i < b.N; i++ { DivRound(x, y, z) } diff --git a/utils/bignum/poly.go b/utils/bignum/poly.go new file mode 100644 index 00000000..89f65333 --- /dev/null +++ b/utils/bignum/poly.go @@ -0,0 +1,214 @@ +package bignum + +import ( + "fmt" + "math" + "math/big" +) + +// BasisType is a type for the polynomials basis +type BasisType int + +const ( + // Monomial : x^(a+b) = x^a * x^b + Monomial = BasisType(0) + // Chebyshev : T_(a+b) = 2 * T_a * T_b - T_(|a-b|) + Chebyshev = BasisType(1) +) + +type Interval struct { + A, B *big.Float +} + +type Polynomial struct { + BasisType + Interval + Coeffs []*Complex + IsOdd bool + IsEven bool +} + +// NewPolynomial creates a new polynomial from the input parameters: +// basis: either `Monomial` or `Chebyshev` +// coeffs: []complex128, []float64, []*Complex or []*big.Float +// interval: [2]float64{a, b} or *Interval +func NewPolynomial(basis BasisType, coeffs interface{}, interval interface{}) *Polynomial { + var coefficients []*Complex + + switch coeffs := coeffs.(type) { + case []complex128: + coefficients = make([]*Complex, len(coeffs)) + for i := range coeffs { + if c := coeffs[i]; c != 0 { + coefficients[i] = &Complex{ + new(big.Float).SetFloat64(real(c)), + new(big.Float).SetFloat64(imag(c)), + } + } + } + case []float64: + coefficients = make([]*Complex, len(coeffs)) + for i := range coeffs { + if c := coeffs[i]; c != 0 { + coefficients[i] = &Complex{ + new(big.Float).SetFloat64(c), + new(big.Float), + } + } + } + case []*Complex: + coefficients = make([]*Complex, len(coeffs)) + copy(coefficients, coeffs) + case []*big.Float: + coefficients = make([]*Complex, len(coeffs)) + for i := range coeffs { + if coeffs[i] != nil { + coefficients[i] = &Complex{ + new(big.Float).Set(coeffs[i]), + new(big.Float), + } + } + } + default: + panic(fmt.Sprintf("invalid coefficient type, allowed types are []{complex128, float64, *Complex, *big.Float} but is %T", coeffs)) + } + + inter := Interval{} + switch interval := interval.(type) { + case [2]float64: + inter.A = new(big.Float).SetFloat64(interval[0]) + inter.B = new(big.Float).SetFloat64(interval[1]) + case *Interval: + inter.A = new(big.Float).Set(interval.A) + inter.B = new(big.Float).Set(interval.B) + case nil: + + default: + panic(fmt.Sprintf("invalid interval type, allowed types are [2]float64 or *Interval, but is %T", interval)) + } + + return &Polynomial{ + BasisType: basis, + Interval: inter, + Coeffs: coefficients, + IsOdd: true, + IsEven: true, + } +} + +// ChangeOfBasis returns change of basis required to evaluate the polynomial +// Change of basis is defined as follow: +// - Monomial: scalar=1, constant=0. +// - Chebyshev: scalar=2/(b-a), constant = (-a-b)/(b-a). +func (p *Polynomial) ChangeOfBasis() (scalar, constant *big.Float) { + + switch p.BasisType { + case Monomial: + scalar = new(big.Float).SetInt64(1) + constant = new(big.Float) + case Chebyshev: + num := new(big.Float).Sub(p.B, p.A) + + // 2 / (b-a) + scalar = new(big.Float).Quo(new(big.Float).SetInt64(2), num) + + // (-b-a)/(b-a) + constant = new(big.Float).Set(p.B) + constant.Neg(constant) + constant.Sub(constant, p.A) + constant.Quo(constant, num) + default: + panic(fmt.Sprintf("invalid basis type, allowed types are `Monomial` or `Chebyshev` but is %T", p.BasisType)) + } + + return +} + +// Depth returns the number of sequential multiplications needed to evaluate the polynomial. +func (p *Polynomial) Depth() int { + return int(math.Ceil(math.Log2(float64(p.Degree())))) +} + +// Degree returns the degree of the polynomial. +func (p *Polynomial) Degree() int { + return len(p.Coeffs) - 1 +} + +// Evaluate takes x a *big.Float or *big.Complex and returns y = P(x). +// The precision of x is used as reference precision for y. +func (p *Polynomial) Evaluate(x interface{}) (y *Complex) { + + var xcmplx *Complex + switch x := x.(type) { + case *big.Float: + xcmplx = ToComplex(x, x.Prec()) + case *Complex: + xcmplx = ToComplex(x, x.Prec()) + default: + panic(fmt.Errorf("cannot Evaluate: accepted x.(type) are *big.Float and *Complex but x is %T", x)) + } + + coeffs := p.Coeffs + + n := len(coeffs) + + mul := NewComplexMultiplier() + + switch p.BasisType { + case Monomial: + y = coeffs[n-1].Copy() + y.SetPrec(xcmplx.Prec()) + for i := n - 2; i >= 0; i-- { + mul.Mul(y, xcmplx, y) + if coeffs[i] != nil { + y.Add(y, coeffs[i]) + } + } + + case Chebyshev: + + tmp := &Complex{new(big.Float), new(big.Float)} + + scalar, constant := p.ChangeOfBasis() + + xcmplx[0].Mul(xcmplx[0], scalar) + xcmplx[1].Mul(xcmplx[1], scalar) + + xcmplx[0].Add(xcmplx[0], constant) + xcmplx[1].Add(xcmplx[1], constant) + + TPrev := &Complex{new(big.Float).SetInt64(1), new(big.Float)} + + T := xcmplx + if coeffs[0] != nil { + y = coeffs[0].Copy() + } else { + y = &Complex{new(big.Float), new(big.Float)} + } + + y.SetPrec(xcmplx.Prec()) + + two := new(big.Float).SetInt64(2) + for i := 1; i < n; i++ { + + if coeffs[i] != nil { + mul.Mul(T, coeffs[i], tmp) + y.Add(y, tmp) + } + + tmp[0].Mul(xcmplx[0], two) + tmp[1].Mul(xcmplx[1], two) + + mul.Mul(tmp, T, tmp) + tmp.Sub(tmp, TPrev) + + TPrev = T.Copy() + T = tmp.Copy() + } + + default: + panic(fmt.Sprintf("invalid basis type, allowed types are `Monomial` or `Chebyshev` but is %T", p.BasisType)) + } + + return +} diff --git a/utils/sampling/prng.go b/utils/sampling/prng.go index 87f4f9ca..61ca3800 100644 --- a/utils/sampling/prng.go +++ b/utils/sampling/prng.go @@ -44,6 +44,15 @@ func NewPRNG() (*KeyedPRNG, error) { return prng, err } +// Key returns a copy of the key used to seed the PRNG. +// This value can be used with `NewKeyedPRNG` to instantiate +// a new PRNG that will produce the same stream of bytes. +func (prng *KeyedPRNG) Key() (key []byte) { + key = make([]byte, len(prng.key)) + copy(key, prng.key) + return +} + // Read reads bytes from the KeyedPRNG on sum. func (prng *KeyedPRNG) Read(sum []byte) (n int, err error) { if n, err = prng.xof.Read(sum); err != nil {