diff --git a/CHANGELOG.md b/CHANGELOG.md index 78e62159..8f9fcd64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,32 @@ # Changelog All notable changes to this project will be documented in this file. +## [Unreleased] +- RING: added `MapSmallDimensionToLargerDimensionNTT` method which maps from Y = X^{N/n} to X in the NTT domain. +- RING: `FastBasisExtender` type can now extend the basis of polynomials of any level in base Q to polynomials of any level in base P. +- RING: changed RNS division `Div[floor/round]BylastModulus[NTT]` to `Div[floor/round]BylastModulus[NTT]Lvl` (the level of the last modulus must always be provided). +- RING: RNS division no longer modifies the output polynomial's level, this is to facilitate the usage of memory pools. +- RING: added the method `MFormVector`, which switches a slice of `uint64` into the Montgomery domain. +- RING: RNS scaler (used in BFV) does not modify the input anymore. +- RLWE: `GenSwitchingKey` now accepts secret-keys of different dimensions and level as input to enable re-encryption between different ciphertext degrees. +- RLWE: added `SwitchCiphertextRingDegreeNTT` and `SwitchCiphertextRingDegree` to switch ciphertext ring degrees. +- RLWE: added the `rlwe.RingQP` type to represent the extended ring R_qp. +- RLWE: added the `rlwe.PolyQP` type to represent polynomials in the extended ring R_qp. +- DRLWE: added the `CKGCRP`, `RKGCRP`, `RTGCRP` and `CKSCRP` types to represent the common reference polynomials in these protocols. +- DRLWE: added the `CRS` interface for PRNGs that implement a common reference string among the parties. +- DRLWE: added the `SampleCRP(crs CRS)` method to each protocol types to sample their respective CRP type. +- BFV: changed the plaintext scaling from `floor(Q/T)*m` to `round((Q*m)/T)` to reduce the initial ciphertext noise. +- CKKS: added the `ckks/advanced` sub-package and moved the homomorphic encoding, decoding and modular reduction into it. +- CKKS: added the `ckks/bootstrapping` sub-package and moved the CKKS bootstrapping into it. This package now mostly relies on the `ckks/advanced` package. +- CKKS: renamed the `ChebyshevInterpolation` type to `Polynomial`. +- CKKS: removed the `EvaluateCheby` method that was redundant with the `EvaluatePoly` one. +- CKKS: optimized the `EvaluatePoly` to account for odd/even polynomials and fixed some small imprecisions in scale management occurring for some specific polynomial degrees. +- CKKS: some advanced methods related to automorphisms are now public to facilitate their external use. +- CKKS: improved the consistency of the API for in-place and `[..]New` methods. +- CKKS: added the method `NewCiphertextAtLevelFromPoly`, which creates a ciphertext at a specific level from two polynomials. +- DBFV/DCKKS: both now use their respective CRP type for each protocol. +- EXAMPLE: added showcase of the `ckks/advanced` sub-package: a bridge between CKKS and FHEW ciphertexts using homomorphic decoding, ring dimension switching, homomorphic matrix multiplication and homomorphic modular reduction. + ## [2.2.0] - 2020-07-15 - Added SECURITY.md diff --git a/Makefile b/Makefile index 9c84bfad..ee4843e4 100644 --- a/Makefile +++ b/Makefile @@ -15,12 +15,14 @@ test_examples: @echo ok @echo Building resources-heavy examples go build -o /dev/null ./examples/ckks/bootstrapping + go build -o /dev/null ./examples/ckks/advanced @echo ok .PHONY: test_gotest test_gotest: go test -v -timeout=0 ./utils ./ring ./bfv ./dbfv ./dckks - go test -v -timeout=0 ./ckks -test-bootstrapping + go test -v -timeout=0 ./ckks/advanced + go test -v -timeout=0 ./ckks/bootstrapping -test-bootstrapping -short .PHONY: test test: test_fmt test_gotest test_examples diff --git a/bfv/bfv_test.go b/bfv/bfv_test.go index e32ace75..31930e3f 100644 --- a/bfv/bfv_test.go +++ b/bfv/bfv_test.go @@ -25,7 +25,6 @@ func testString(opname string, p Parameters) string { type testContext struct { params Parameters ringQ *ring.Ring - ringQP *ring.Ring ringT *ring.Ring prng utils.PRNG uSampler *ring.UniformSampler @@ -90,7 +89,6 @@ func genTestParams(params Parameters) (testctx *testContext, err error) { } testctx.ringQ = params.RingQ() - testctx.ringQP = params.RingQP() testctx.ringT = params.RingT() testctx.uSampler = ring.NewUniformSampler(testctx.prng, testctx.ringT) diff --git a/bfv/encoder.go b/bfv/encoder.go index 7a13b174..35d9ec97 100644 --- a/bfv/encoder.go +++ b/bfv/encoder.go @@ -60,12 +60,10 @@ type Encoder interface { type encoder struct { params Parameters - ringQ *ring.Ring - ringT *ring.Ring - indexMatrix []uint64 scaler ring.Scaler - deltaMont []uint64 + + rescaleParams []uint64 tmpPoly *ring.Poly tmpPtRt *PlaintextRingT @@ -101,35 +99,21 @@ func NewEncoder(params Parameters) Encoder { pos &= (m - 1) } + rescaleParams := make([]uint64, len(ringQ.Modulus)) + for i, qi := range ringQ.Modulus { + rescaleParams[i] = ring.MForm(ring.ModExp(params.T(), qi-2, qi), qi, ringQ.BredParams[i]) + } + return &encoder{ - params: params, - ringQ: ringQ, - ringT: ringT, - indexMatrix: indexMatrix, - deltaMont: GenLiftParams(ringQ, params.T()), - scaler: ring.NewRNSScaler(params.T(), ringQ), - tmpPoly: ringT.NewPoly(), - tmpPtRt: NewPlaintextRingT(params), + params: params, + indexMatrix: indexMatrix, + scaler: ring.NewRNSScaler(ringQ, ringT), + rescaleParams: rescaleParams, + tmpPoly: ringT.NewPoly(), + tmpPtRt: NewPlaintextRingT(params), } } -// GenLiftParams generates the lifting parameters. -func GenLiftParams(ringQ *ring.Ring, t uint64) (deltaMont []uint64) { - - delta := new(big.Int).Quo(ringQ.ModulusBigint, ring.NewUint(t)) - - deltaMont = make([]uint64, len(ringQ.Modulus)) - - tmp := new(big.Int) - bredParams := ringQ.BredParams - for i, Qi := range ringQ.Modulus { - deltaMont[i] = tmp.Mod(delta, ring.NewUint(Qi)).Uint64() - deltaMont[i] = ring.MForm(deltaMont[i], Qi, bredParams[i]) - } - - return -} - // EncodeUintRingT encodes a slice of uint64 into a Plaintext in R_t func (encoder *encoder) EncodeUintRingT(coeffs []uint64, p *PlaintextRingT) { if len(coeffs) > len(encoder.indexMatrix) { @@ -148,7 +132,7 @@ func (encoder *encoder) EncodeUintRingT(coeffs []uint64, p *PlaintextRingT) { p.Value.Coeffs[0][encoder.indexMatrix[i]] = 0 } - encoder.ringT.InvNTT(p.Value, p.Value) + encoder.params.RingT().InvNTT(p.Value, p.Value) } // EncodeUint encodes an uint64 slice of size at most N on a plaintext. @@ -198,7 +182,7 @@ func (encoder *encoder) EncodeIntRingT(coeffs []int64, p *PlaintextRingT) { p.Value.Coeffs[0][encoder.indexMatrix[i]] = 0 } - encoder.ringT.InvNTTLazy(p.Value, p.Value) + encoder.params.RingT().InvNTTLazy(p.Value, p.Value) } func (encoder *encoder) EncodeInt(coeffs []int64, p *Plaintext) { @@ -223,31 +207,57 @@ func (encoder *encoder) EncodeIntMul(coeffs []int64, p *PlaintextMul) { // ScaleUp transforms a PlaintextRingT (R_t) into a Plaintext (R_q) by scaling up the coefficient by Q/t. func (encoder *encoder) ScaleUp(ptRt *PlaintextRingT, pt *Plaintext) { - scaleUp(encoder.ringQ, encoder.deltaMont, ptRt.Value, pt.Value) + encoder.scaleUp(encoder.params.RingQ(), encoder.params.RingT(), encoder.tmpPoly.Coeffs[0], ptRt.Value, pt.Value) } -func scaleUp(ringQ *ring.Ring, deltaMont []uint64, pIn, pOut *ring.Poly) { +// takes m mod T and returns round((m*Q)/T) mod Q +func (encoder *encoder) scaleUp(ringQ, ringT *ring.Ring, tmp []uint64, pIn, pOut *ring.Poly) { - for i := len(ringQ.Modulus) - 1; i >= 0; i-- { - out := pOut.Coeffs[i] - in := pIn.Coeffs[0] - d := deltaMont[i] + qModTmontgomery := ring.MForm(new(big.Int).Mod(ringQ.ModulusBigint, ringT.ModulusBigint).Uint64(), ringT.Modulus[0], ringT.BredParams[0]) + + t := ringT.Modulus[0] + tHalf := t >> 1 + tInv := ringT.MredParams[0] + + // (x * Q + T/2) mod T + for i := 0; i < ringQ.N; i = i + 8 { + x := (*[8]uint64)(unsafe.Pointer(&pIn.Coeffs[0][i])) + z := (*[8]uint64)(unsafe.Pointer(&tmp[i])) + + z[0] = ring.CRed(ring.MRed(x[0], qModTmontgomery, t, tInv)+tHalf, t) + z[1] = ring.CRed(ring.MRed(x[1], qModTmontgomery, t, tInv)+tHalf, t) + z[2] = ring.CRed(ring.MRed(x[2], qModTmontgomery, t, tInv)+tHalf, t) + z[3] = ring.CRed(ring.MRed(x[3], qModTmontgomery, t, tInv)+tHalf, t) + z[4] = ring.CRed(ring.MRed(x[4], qModTmontgomery, t, tInv)+tHalf, t) + z[5] = ring.CRed(ring.MRed(x[5], qModTmontgomery, t, tInv)+tHalf, t) + z[6] = ring.CRed(ring.MRed(x[6], qModTmontgomery, t, tInv)+tHalf, t) + z[7] = ring.CRed(ring.MRed(x[7], qModTmontgomery, t, tInv)+tHalf, t) + } + + // (x * T^-1 - T/2) mod Qi + for i := 0; i < len(pOut.Coeffs); i++ { + p0tmp := tmp + p1tmp := pOut.Coeffs[i] qi := ringQ.Modulus[i] + bredParams := ringQ.BredParams[i] mredParams := ringQ.MredParams[i] + rescaleParams := qi - encoder.rescaleParams[i] + + tHalfNegQi := qi - ring.BRedAdd(tHalf, qi, bredParams) for j := 0; j < ringQ.N; j = j + 8 { - x := (*[8]uint64)(unsafe.Pointer(&in[j])) - z := (*[8]uint64)(unsafe.Pointer(&out[j])) + x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j])) + z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) - z[0] = ring.MRed(x[0], d, qi, mredParams) - z[1] = ring.MRed(x[1], d, qi, mredParams) - z[2] = ring.MRed(x[2], d, qi, mredParams) - z[3] = ring.MRed(x[3], d, qi, mredParams) - z[4] = ring.MRed(x[4], d, qi, mredParams) - z[5] = ring.MRed(x[5], d, qi, mredParams) - z[6] = ring.MRed(x[6], d, qi, mredParams) - z[7] = ring.MRed(x[7], d, qi, mredParams) + z[0] = ring.MRed(x[0]+tHalfNegQi, rescaleParams, qi, mredParams) + z[1] = ring.MRed(x[1]+tHalfNegQi, rescaleParams, qi, mredParams) + z[2] = ring.MRed(x[2]+tHalfNegQi, rescaleParams, qi, mredParams) + z[3] = ring.MRed(x[3]+tHalfNegQi, rescaleParams, qi, mredParams) + z[4] = ring.MRed(x[4]+tHalfNegQi, rescaleParams, qi, mredParams) + z[5] = ring.MRed(x[5]+tHalfNegQi, rescaleParams, qi, mredParams) + z[6] = ring.MRed(x[6]+tHalfNegQi, rescaleParams, qi, mredParams) + z[7] = ring.MRed(x[7]+tHalfNegQi, rescaleParams, qi, mredParams) } } } @@ -263,19 +273,19 @@ func (encoder *encoder) RingTToMul(ptRt *PlaintextRingT, ptMul *PlaintextMul) { if ptRt.Value != ptMul.Value { copy(ptMul.Value.Coeffs[0], ptRt.Value.Coeffs[0]) } - for i := 1; i < len(encoder.ringQ.Modulus); i++ { + for i := 1; i < len(encoder.params.RingQ().Modulus); i++ { copy(ptMul.Value.Coeffs[i], ptRt.Value.Coeffs[0]) } - encoder.ringQ.NTTLazy(ptMul.Value, ptMul.Value) - encoder.ringQ.MForm(ptMul.Value, ptMul.Value) + encoder.params.RingQ().NTTLazy(ptMul.Value, ptMul.Value) + encoder.params.RingQ().MForm(ptMul.Value, ptMul.Value) } // MulToRingT transforms a PlaintextMul into PlaintextRingT by operating the inverse NTT transform of R_q and // putting the coefficients out of the Montgomery form. func (encoder *encoder) MulToRingT(pt *PlaintextMul, ptRt *PlaintextRingT) { - encoder.ringQ.InvNTTLvl(0, pt.Value, ptRt.Value) - encoder.ringQ.InvMFormLvl(0, ptRt.Value, ptRt.Value) + encoder.params.RingQ().InvNTTLvl(0, pt.Value, ptRt.Value) + encoder.params.RingQ().InvMFormLvl(0, ptRt.Value, ptRt.Value) } // DecodeRingT decodes any plaintext type into a PlaintextRingT. It panics if p is not PlaintextRingT, Plaintext or PlaintextMul. @@ -302,9 +312,9 @@ func (encoder *encoder) DecodeUint(p interface{}, coeffs []uint64) { ptRt = encoder.tmpPtRt } - encoder.ringT.NTT(ptRt.Value, encoder.tmpPoly) + encoder.params.RingT().NTT(ptRt.Value, encoder.tmpPoly) - for i := 0; i < encoder.ringQ.N; i++ { + for i := 0; i < encoder.params.RingQ().N; i++ { coeffs[i] = encoder.tmpPoly.Coeffs[0][encoder.indexMatrix[i]] } } @@ -312,7 +322,7 @@ func (encoder *encoder) DecodeUint(p interface{}, coeffs []uint64) { // DecodeUintNew decodes any plaintext type and returns the coefficients in a new []uint64. // It panics if p is not PlaintextRingT, Plaintext or PlaintextMul. func (encoder *encoder) DecodeUintNew(p interface{}) (coeffs []uint64) { - coeffs = make([]uint64, encoder.ringQ.N) + coeffs = make([]uint64, encoder.params.RingQ().N) encoder.DecodeUint(p, coeffs) return } @@ -323,12 +333,12 @@ func (encoder *encoder) DecodeInt(p interface{}, coeffs []int64) { encoder.DecodeRingT(p, encoder.tmpPtRt) - encoder.ringT.NTT(encoder.tmpPtRt.Value, encoder.tmpPoly) + encoder.params.RingT().NTT(encoder.tmpPtRt.Value, encoder.tmpPoly) modulus := int64(encoder.params.T()) modulusHalf := modulus >> 1 var value int64 - for i := 0; i < encoder.ringQ.N; i++ { + for i := 0; i < encoder.params.RingQ().N; i++ { value = int64(encoder.tmpPoly.Coeffs[0][encoder.indexMatrix[i]]) coeffs[i] = value @@ -341,7 +351,7 @@ func (encoder *encoder) DecodeInt(p interface{}, coeffs []int64) { // DecodeIntNew decodes any plaintext type and returns the coefficients in a new []int64. It also decodes the sign // modulus (by centering the values around the plaintext). It panics if p is not PlaintextRingT, Plaintext or PlaintextMul. func (encoder *encoder) DecodeIntNew(p interface{}) (coeffs []int64) { - coeffs = make([]int64, encoder.ringQ.N) + coeffs = make([]int64, encoder.params.RingQ().N) encoder.DecodeInt(p, coeffs) return } diff --git a/bfv/evaluator.go b/bfv/evaluator.go index 1da3aaff..6c9cbf80 100644 --- a/bfv/evaluator.go +++ b/bfv/evaluator.go @@ -55,6 +55,8 @@ type evaluator struct { *evaluatorBuffers *rlwe.KeySwitcher + lightEncoder *encoder + rlk *rlwe.RelinearizationKey rtks *rlwe.RotationKeySet @@ -69,8 +71,6 @@ type evaluatorBase struct { t uint64 pHalf *big.Int - - deltaMont []uint64 } func newEvaluatorPrecomp(params Parameters) *evaluatorBase { @@ -85,7 +85,6 @@ func newEvaluatorPrecomp(params Parameters) *evaluatorBase { ev.ringQMul = params.RingQMul() ev.pHalf = new(big.Int).Rsh(ev.ringQMul.ModulusBigint, 1) - ev.deltaMont = GenLiftParams(ev.ringQ, params.T()) return ev } @@ -121,6 +120,13 @@ func NewEvaluator(params Parameters, evaluationKey rlwe.EvaluationKey) Evaluator ev := new(evaluator) ev.evaluatorBase = newEvaluatorPrecomp(params) ev.evaluatorBuffers = newEvaluatorBuffer(ev.evaluatorBase) + + rescaleParams := make([]uint64, len(params.RingQ().Modulus)) + for i, qi := range params.RingQ().Modulus { + rescaleParams[i] = ring.MForm(ring.ModExp(params.T(), qi-2, qi), qi, params.RingQ().BredParams[i]) + } + + ev.lightEncoder = &encoder{rescaleParams: rescaleParams} ev.baseconverterQ1Q2 = ring.NewFastBasisExtender(ev.ringQ, ev.ringQMul) if params.PCount() != 0 { ev.KeySwitcher = rlwe.NewKeySwitcher(params.Parameters) @@ -312,7 +318,7 @@ func (eval *evaluator) tensorAndRescale(ct0, ct1, ctOut *rlwe.Ciphertext) { func (eval *evaluator) modUpAndNTT(ct *rlwe.Ciphertext, cQ, cQMul []*ring.Poly) { levelQ := len(eval.ringQ.Modulus) - 1 for i := range ct.Value { - eval.baseconverterQ1Q2.ModUpSplitQP(levelQ, ct.Value[i], cQMul[i]) + eval.baseconverterQ1Q2.ModUpQtoP(levelQ, len(eval.ringQMul.Modulus)-1, ct.Value[i], cQMul[i]) eval.ringQ.NTTLazy(ct.Value[i], cQ[i]) eval.ringQMul.NTTLazy(cQMul[i], cQMul[i]) } @@ -448,11 +454,11 @@ func (eval *evaluator) quantize(ctOut *rlwe.Ciphertext) { eval.ringQMul.InvNTTLazy(c2Q2[i], c2Q2[i]) // Extends the basis Q of ct(x) to the basis P and Divides (ct(x)Q -> P) by Q - eval.baseconverterQ1Q2.ModDownSplitQP(levelQ, levelQMul, c2Q1[i], c2Q2[i], c2Q2[i]) + eval.baseconverterQ1Q2.ModDownQPtoP(levelQ, levelQMul, c2Q1[i], c2Q2[i], c2Q2[i]) // Centers (ct(x)Q -> P)/Q by (P-1)/2 and extends ((ct(x)Q -> P)/Q) to the basis Q eval.ringQMul.AddScalarBigint(c2Q2[i], eval.pHalf, c2Q2[i]) - eval.baseconverterQ1Q2.ModUpSplitPQ(levelQMul, c2Q2[i], ctOut.Value[i]) + eval.baseconverterQ1Q2.ModUpPtoQ(levelQMul, levelQ, c2Q2[i], ctOut.Value[i]) eval.ringQ.SubScalarBigint(ctOut.Value[i], eval.pHalf, ctOut.Value[i]) // Option (2) (ct(x)/Q)*T, doing so only requires that Q*P > Q*Q, faster but adds error ~|T| @@ -548,9 +554,9 @@ func (eval *evaluator) relinearize(ct0 *Ciphertext, ctOut *Ciphertext) { } for deg := uint64(ct0.Degree()); deg > 1; deg-- { - eval.SwitchKeysInPlace(ct0.Value[deg].Level(), ct0.Value[deg], eval.rlk.Keys[deg-2], eval.PoolQ[1], eval.PoolQ[2]) - eval.ringQ.Add(ctOut.Value[0], eval.PoolQ[1], ctOut.Value[0]) - eval.ringQ.Add(ctOut.Value[1], eval.PoolQ[2], ctOut.Value[1]) + eval.SwitchKeysInPlace(ct0.Value[deg].Level(), ct0.Value[deg], eval.rlk.Keys[deg-2], eval.Pool[1].Q, eval.Pool[2].Q) + eval.ringQ.Add(ctOut.Value[0], eval.Pool[1].Q, ctOut.Value[0]) + eval.ringQ.Add(ctOut.Value[1], eval.Pool[2].Q, ctOut.Value[1]) } ctOut.SetValue(ctOut.Value[:2]) @@ -605,10 +611,10 @@ func (eval *evaluator) SwitchKeys(ct0 *Ciphertext, switchKey *rlwe.SwitchingKey, panic("cannot SwitchKeys: input and output must be of degree 1 to allow key switching") } - eval.SwitchKeysInPlace(ct0.Value[1].Level(), ct0.Value[1], switchKey, eval.PoolQ[1], eval.PoolQ[2]) + eval.SwitchKeysInPlace(ct0.Value[1].Level(), ct0.Value[1], switchKey, eval.Pool[1].Q, eval.Pool[2].Q) - eval.ringQ.Add(ct0.Value[0], eval.PoolQ[1], ctOut.Value[0]) - ring.CopyValues(eval.PoolQ[2], ctOut.Value[1]) + eval.ringQ.Add(ct0.Value[0], eval.Pool[1].Q, ctOut.Value[0]) + ring.CopyValues(eval.Pool[2].Q, ctOut.Value[1]) } // SwitchKeysNew applies the key-switching procedure to the ciphertext ct0 and creates a new ciphertext to store the result. It requires as an additional input a valid switching-key: @@ -702,12 +708,12 @@ func (eval *evaluator) InnerSum(ct0 *Ciphertext, ctOut *Ciphertext) { // permute performs a column rotation on ct0 and returns the result in ctOut func (eval *evaluator) permute(ct0 *Ciphertext, generator uint64, switchKey *rlwe.SwitchingKey, ctOut *Ciphertext) { - eval.SwitchKeysInPlace(ct0.Value[1].Level(), ct0.Value[1], switchKey, eval.PoolQ[1], eval.PoolQ[2]) + eval.SwitchKeysInPlace(ct0.Value[1].Level(), ct0.Value[1], switchKey, eval.Pool[1].Q, eval.Pool[2].Q) - eval.ringQ.Add(eval.PoolQ[1], ct0.Value[0], eval.PoolQ[1]) + eval.ringQ.Add(eval.Pool[1].Q, ct0.Value[0], eval.Pool[1].Q) - eval.ringQ.Permute(eval.PoolQ[1], generator, ctOut.Value[0]) - eval.ringQ.Permute(eval.PoolQ[2], generator, ctOut.Value[1]) + eval.ringQ.Permute(eval.Pool[1].Q, generator, ctOut.Value[0]) + eval.ringQ.Permute(eval.Pool[2].Q, generator, ctOut.Value[1]) } func (eval *evaluator) getRingQElem(op Operand) *rlwe.Ciphertext { @@ -715,7 +721,7 @@ func (eval *evaluator) getRingQElem(op Operand) *rlwe.Ciphertext { case *Ciphertext, *Plaintext: return o.El() case *PlaintextRingT: - scaleUp(eval.ringQ, eval.deltaMont, o.Value, eval.tmpPt.Value) + eval.lightEncoder.scaleUp(eval.params.RingQ(), eval.params.RingT(), eval.Pool[0].Q.Coeffs[0], o.Value, eval.tmpPt.Value) return eval.tmpPt.El() default: panic(fmt.Errorf("invalid operand type for operation: %T", o)) diff --git a/bfv/keys.go b/bfv/keys.go index dbb79eca..2909a167 100644 --- a/bfv/keys.go +++ b/bfv/keys.go @@ -19,7 +19,7 @@ func NewPublicKey(params Parameters) (pk *rlwe.PublicKey) { // NewSwitchingKey returns an allocated BFV public switching key with zero values. func NewSwitchingKey(params Parameters) *rlwe.SwitchingKey { - return rlwe.NewSwitchingKey(params.Parameters) + return rlwe.NewSwitchingKey(params.Parameters, params.QCount()-1, params.PCount()-1) } // NewRelinearizationKey returns an allocated BFV public relinearization key with zero value for each degree in [2 < maxRelinDegree]. diff --git a/bfv/utils.go b/bfv/utils.go new file mode 100644 index 00000000..e3a24b8b --- /dev/null +++ b/bfv/utils.go @@ -0,0 +1,84 @@ +package bfv + +import ( + "fmt" + "github.com/ldsec/lattigo/v2/ring" + "math" + "math/big" +) + +// DecryptAndPrintError decrypts a ciphertext and prints the log2 of the error. +func DecryptAndPrintError(ptWant *Plaintext, cthave *Ciphertext, ringQ *ring.Ring, decryptor Decryptor) { + ringQ.Sub(cthave.Value[0], ptWant.Value, cthave.Value[0]) + plaintext := decryptor.DecryptNew(cthave) + bigintCoeffs := make([]*big.Int, ringQ.N) + ringQ.PolyToBigint(plaintext.Value, bigintCoeffs) + center(bigintCoeffs, ringQ.ModulusBigint) + stdErr, minErr, maxErr := errorStats(bigintCoeffs) + fmt.Printf("STD : %f - Min : %f - Max : %f\n", math.Log2(stdErr), math.Log2(minErr), math.Log2(maxErr)) +} + +func errorStats(vec []*big.Int) (float64, float64, float64) { + + vecfloat := make([]*big.Float, len(vec)) + minErr := new(big.Float).SetFloat64(0) + maxErr := new(big.Float).SetFloat64(0) + tmp := new(big.Float) + minErr.SetInt(vec[0]) + minErr.Abs(minErr) + for i := range vec { + vecfloat[i] = new(big.Float) + vecfloat[i].SetInt(vec[i]) + + tmp.Abs(vecfloat[i]) + + if minErr.Cmp(tmp) == 1 { + minErr.Set(tmp) + } + + if maxErr.Cmp(tmp) == -1 { + maxErr.Set(tmp) + } + } + + n := new(big.Float).SetFloat64(float64(len(vec))) + + mean := new(big.Float).SetFloat64(0) + + for _, c := range vecfloat { + mean.Add(mean, c) + } + + mean.Quo(mean, n) + + err := new(big.Float).SetFloat64(0) + for _, c := range vecfloat { + tmp.Sub(c, mean) + tmp.Mul(tmp, tmp) + err.Add(err, tmp) + } + + err.Quo(err, n) + err.Sqrt(err) + + x, _ := err.Float64() + y, _ := minErr.Float64() + z, _ := maxErr.Float64() + + return x, y, z + +} + +func center(coeffs []*big.Int, Q *big.Int) { + qHalf := new(big.Int) + qHalf.Set(Q) + qHalf.Rsh(qHalf, 1) + var sign int + for i := range coeffs { + coeffs[i].Mod(coeffs[i], Q) + sign = coeffs[i].Cmp(qHalf) + if sign == 1 || sign == 0 { + coeffs[i].Sub(coeffs[i], Q) + } + } +} diff --git a/ckks/bettersine/bettersine.go b/ckks/advanced/cosine_approx.go similarity index 82% rename from ckks/bettersine/bettersine.go rename to ckks/advanced/cosine_approx.go index b72787dc..10b8816e 100644 --- a/ckks/bettersine/bettersine.go +++ b/ckks/advanced/cosine_approx.go @@ -1,14 +1,74 @@ -package bettersine +package advanced -// This is the Go implementation of the approximation polynomial algorithm in +// This is the Go implementation of the approximation polynomial algorithm from Han and Ki in // "Better Bootstrapping for Approximate Homomorphic Encryption", . // The algorithm was originally implemented in C++, available at // https://github.com/DohyeongKi/better-homomorphic-sine-evaluation import ( + //"fmt" + "math" "math/big" ) +// NewFloat creates a new big.Float element with 1000 bits of precision +func NewFloat(x float64) (y *big.Float) { + y = new(big.Float) + y.SetPrec(1000) // log2 precision + y.SetFloat64(x) + return +} + +// BigintCos is an iterative arbitrary precision computation of Cos(x) +// Iterative process with an error of ~10^{−0.60206*k} after k iterations. +// ref : Johansson, B. Tomas, An elementary algorithm to evaluate trigonometric functions to high precision, 2018 +func BigintCos(x *big.Float) (cosx *big.Float) { + tmp := new(big.Float) + + k := 1000 // number of iterations + t := NewFloat(0.5) + half := new(big.Float).Copy(t) + + for i := 1; i < k-1; i++ { + t.Mul(t, half) + } + + s := new(big.Float).Mul(x, t) + s.Mul(s, x) + s.Mul(s, t) + + four := NewFloat(4.0) + + for i := 1; i < k; i++ { + tmp.Sub(four, s) + s.Mul(s, tmp) + } + + cosx = new(big.Float).Quo(s, NewFloat(2.0)) + cosx.Sub(NewFloat(1.0), cosx) + return + +} + +// BigintSin is an iterative arbitrary precision computation of Sin(x) +func BigintSin(x *big.Float) (sinx *big.Float) { + + sinx = NewFloat(1) + tmp := BigintCos(x) + tmp.Mul(tmp, tmp) + sinx.Sub(sinx, tmp) + sinx.Sqrt(sinx) + return +} + +func log2(x float64) float64 { + return math.Log2(x) +} + +func abs(x float64) float64 { + return math.Abs(x) +} + var pi = "3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679821480865132823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461284756482337867831652712019091456485669234603486104543266482133936072602491412737245870066063155881748815209209628292540917153643678925903600113305305488204665213841469519415116094330572703657595919530921861173819326117931051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798609437027705392171762931767523846748184676694051320005681271452635608277857713427577896091736371787214684409012249534301465495853710507922796892589235420199561121290219608640344181598136297747713099605187072113499999983729780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083814206171776691473035982534904287554687311595628638823537875937519577818577805321712268066130019278766111959092164201989" var mPI = 3.141592653589793238462643383279502884 @@ -141,7 +201,7 @@ func genNodes(deg []int, dev float64, totdeg, K, scnum int) ([]*big.Float, []*bi tmp = NewFloat(float64(2*j - 1)) tmp.Mul(tmp, PI) tmp.Quo(tmp, NewFloat(float64(2*deg[i]))) - tmp = Cos(tmp) + tmp = BigintCos(tmp) tmp.Mul(tmp, intersize) @@ -160,7 +220,7 @@ func genNodes(deg []int, dev float64, totdeg, K, scnum int) ([]*big.Float, []*bi tmp = NewFloat(float64(2*j - 1)) tmp.Mul(tmp, PI) tmp.Quo(tmp, NewFloat(float64(2*deg[0]))) - tmp = Cos(tmp) + tmp = BigintCos(tmp) tmp.Mul(tmp, intersize) z[cnt] = new(big.Float).Add(NewFloat(0), tmp) @@ -181,7 +241,7 @@ func genNodes(deg []int, dev float64, totdeg, K, scnum int) ([]*big.Float, []*bi z[i].Quo(z[i], scfac) d[i].Mul(d[i], z[i]) - d[i] = Cos(d[i]) + d[i] = BigintCos(d[i]) //tmp := new(big.Float).Sqrt(PI) //tmp.Sqrt(tmp) @@ -204,7 +264,7 @@ func genNodes(deg []int, dev float64, totdeg, K, scnum int) ([]*big.Float, []*bi x[i].Quo(x[i], scfac) tmp.Mul(NewFloat(float64(i)), PI) tmp.Quo(tmp, NewFloat(float64(totdeg-1))) - x[i].Mul(x[i], Cos(tmp)) + x[i].Mul(x[i], BigintCos(tmp)) } var c = make([]*big.Float, totdeg) @@ -221,10 +281,10 @@ func genNodes(deg []int, dev float64, totdeg, K, scnum int) ([]*big.Float, []*bi return x, p, c, totdeg } -// Approximate computes a polynomial approximation of degree "degree" in Chevyshev basis of the function +// ApproximateCos computes a polynomial approximation of degree "degree" in Chevyshev basis of the function // cos(2*pi*x/2^"scnum") in the range -"K" to "K" // The nodes of the Chevyshev approximation are are located from -dev to +dev at each integer value between -K and -K -func Approximate(K, degree int, dev float64, scnum int) []complex128 { +func ApproximateCos(K, degree int, dev float64, scnum int) []complex128 { var scfac = NewFloat(float64(int(1 << scnum))) diff --git a/ckks/advanced/evaluator.go b/ckks/advanced/evaluator.go new file mode 100644 index 00000000..9a7820ec --- /dev/null +++ b/ckks/advanced/evaluator.go @@ -0,0 +1,287 @@ +package advanced + +import ( + "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/ring" + "github.com/ldsec/lattigo/v2/rlwe" + "github.com/ldsec/lattigo/v2/utils" + "math" +) + +// Evaluator is an interface embeding the ckks.Evaluator interface with +// additional advanced arithmetic features. +type Evaluator interface { + + // ======================================= + // === Original ckks.Evaluator methods === + // ======================================= + + Add(op0, op1 ckks.Operand, ctOut *ckks.Ciphertext) + AddNoMod(op0, op1 ckks.Operand, ctOut *ckks.Ciphertext) + AddNew(op0, op1 ckks.Operand) (ctOut *ckks.Ciphertext) + AddNoModNew(op0, op1 ckks.Operand) (ctOut *ckks.Ciphertext) + Sub(op0, op1 ckks.Operand, ctOut *ckks.Ciphertext) + SubNoMod(op0, op1 ckks.Operand, ctOut *ckks.Ciphertext) + SubNew(op0, op1 ckks.Operand) (ctOut *ckks.Ciphertext) + SubNoModNew(op0, op1 ckks.Operand) (ctOut *ckks.Ciphertext) + Neg(ctIn *ckks.Ciphertext, ctOut *ckks.Ciphertext) + NegNew(ctIn *ckks.Ciphertext) (ctOut *ckks.Ciphertext) + AddConstNew(ctIn *ckks.Ciphertext, constant interface{}) (ctOut *ckks.Ciphertext) + AddConst(ctIn *ckks.Ciphertext, constant interface{}, ctOut *ckks.Ciphertext) + MultByConstNew(ctIn *ckks.Ciphertext, constant interface{}) (ctOut *ckks.Ciphertext) + MultByConst(ctIn *ckks.Ciphertext, constant interface{}, ctOut *ckks.Ciphertext) + MultByGaussianInteger(ctIn *ckks.Ciphertext, cReal, cImag interface{}, ctOut *ckks.Ciphertext) + MultByConstAndAdd(ctIn *ckks.Ciphertext, constant interface{}, ctOut *ckks.Ciphertext) + MultByGaussianIntegerAndAdd(ctIn *ckks.Ciphertext, cReal, cImag interface{}, ctOut *ckks.Ciphertext) + MultByiNew(ctIn *ckks.Ciphertext) (ctOut *ckks.Ciphertext) + MultByi(ctIn *ckks.Ciphertext, ctOut *ckks.Ciphertext) + DivByiNew(ctIn *ckks.Ciphertext) (ctOut *ckks.Ciphertext) + DivByi(ctIn *ckks.Ciphertext, ctOut *ckks.Ciphertext) + ConjugateNew(ctIn *ckks.Ciphertext) (ctOut *ckks.Ciphertext) + Conjugate(ctIn *ckks.Ciphertext, ctOut *ckks.Ciphertext) + Mul(op0, op1 ckks.Operand, ctOut *ckks.Ciphertext) + MulNew(op0, op1 ckks.Operand) (ctOut *ckks.Ciphertext) + MulRelin(op0, op1 ckks.Operand, ctOut *ckks.Ciphertext) + MulRelinNew(op0, op1 ckks.Operand) (ctOut *ckks.Ciphertext) + RotateNew(ctIn *ckks.Ciphertext, k int) (ctOut *ckks.Ciphertext) + Rotate(ctIn *ckks.Ciphertext, k int, ctOut *ckks.Ciphertext) + RotateHoistedNew(ctIn *ckks.Ciphertext, rotations []int) (ctOut map[int]*ckks.Ciphertext) + RotateHoisted(ctIn *ckks.Ciphertext, rotations []int, ctOut map[int]*ckks.Ciphertext) + MulByPow2New(ctIn *ckks.Ciphertext, pow2 int) (ctOut *ckks.Ciphertext) + MulByPow2(ctIn *ckks.Ciphertext, pow2 int, ctOut *ckks.Ciphertext) + PowerOf2(ctIn *ckks.Ciphertext, logPow2 int, ctOut *ckks.Ciphertext) + Power(ctIn *ckks.Ciphertext, degree int, ctOut *ckks.Ciphertext) + PowerNew(ctIn *ckks.Ciphertext, degree int) (ctOut *ckks.Ciphertext) + EvaluatePoly(ctIn *ckks.Ciphertext, pol *ckks.Polynomial, targetScale float64) (ctOut *ckks.Ciphertext, err error) + InverseNew(ctIn *ckks.Ciphertext, steps int) (ctOut *ckks.Ciphertext) + LinearTransformNew(ctIn *ckks.Ciphertext, linearTransform interface{}) (ctOut []*ckks.Ciphertext) + LinearTransform(ctIn *ckks.Ciphertext, linearTransform interface{}, ctOut []*ckks.Ciphertext) + MultiplyByDiagMatrix(ctIn *ckks.Ciphertext, matrix ckks.PtDiagMatrix, c2DecompQP []rlwe.PolyQP, ctOut *ckks.Ciphertext) + MultiplyByDiagMatrixBSGS(ctIn *ckks.Ciphertext, matrix ckks.PtDiagMatrix, c2DecompQP []rlwe.PolyQP, ctOut *ckks.Ciphertext) + InnerSumLog(ctIn *ckks.Ciphertext, batch, n int, ctOut *ckks.Ciphertext) + InnerSum(ctIn *ckks.Ciphertext, batch, n int, ctOut *ckks.Ciphertext) + ReplicateLog(ctIn *ckks.Ciphertext, batch, n int, ctOut *ckks.Ciphertext) + Replicate(ctIn *ckks.Ciphertext, batch, n int, ctOut *ckks.Ciphertext) + TraceNew(ctIn *ckks.Ciphertext, logSlotsStart, logSlotsEnd int) *ckks.Ciphertext + Trace(ctIn *ckks.Ciphertext, logSlotsStart, logSlotsEnd int, ctOut *ckks.Ciphertext) + SwitchKeysNew(ctIn *ckks.Ciphertext, switchingKey *rlwe.SwitchingKey) (ctOut *ckks.Ciphertext) + SwitchKeys(ctIn *ckks.Ciphertext, switchingKey *rlwe.SwitchingKey, ctOut *ckks.Ciphertext) + RelinearizeNew(ctIn *ckks.Ciphertext) (ctOut *ckks.Ciphertext) + Relinearize(ctIn *ckks.Ciphertext, ctOut *ckks.Ciphertext) + ScaleUpNew(ctIn *ckks.Ciphertext, scale float64) (ctOut *ckks.Ciphertext) + ScaleUp(ctIn *ckks.Ciphertext, scale float64, ctOut *ckks.Ciphertext) + SetScale(ctIn *ckks.Ciphertext, scale float64) + Rescale(ctIn *ckks.Ciphertext, minScale float64, ctOut *ckks.Ciphertext) (err error) + DropLevelNew(ctIn *ckks.Ciphertext, levels int) (ctOut *ckks.Ciphertext) + DropLevel(ctIn *ckks.Ciphertext, levels int) + ReduceNew(ctIn *ckks.Ciphertext) (ctOut *ckks.Ciphertext) + Reduce(ctIn *ckks.Ciphertext, ctOut *ckks.Ciphertext) error + + // ====================================== + // === advanced.Evaluator new methods === + // ====================================== + + CoeffsToSlotsNew(ctIn *ckks.Ciphertext, ctsMatrices EncodingMatrix) (ctReal, ctImag *ckks.Ciphertext) + CoeffsToSlots(ctIn *ckks.Ciphertext, ctsMatrices EncodingMatrix, ctReal, ctImag *ckks.Ciphertext) + SlotsToCoeffsNew(ctReal, ctImag *ckks.Ciphertext, stcMatrices EncodingMatrix) (ctOut *ckks.Ciphertext) + SlotsToCoeffs(ctReal, ctImag *ckks.Ciphertext, stcMatrices EncodingMatrix, ctOut *ckks.Ciphertext) + EvalModNew(ctIn *ckks.Ciphertext, evalModPoly EvalModPoly) (ctOut *ckks.Ciphertext) + + // ================================================= + // === original ckks.Evaluator redefined methods === + // ================================================= + + GetKeySwitcher() *rlwe.KeySwitcher + PoolQMul() [3]*ring.Poly + CtxPool() *ckks.Ciphertext + ShallowCopy() Evaluator + WithKey(rlwe.EvaluationKey) Evaluator +} + +type evaluator struct { + ckks.Evaluator + params ckks.Parameters +} + +// NewEvaluator creates a new Evaluator. +func NewEvaluator(params ckks.Parameters, evaluationKey rlwe.EvaluationKey) Evaluator { + return &evaluator{ckks.NewEvaluator(params, evaluationKey), params} +} + +// ShallowCopy creates a shallow copy of this evaluator in which all the read-only data-structures are +// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned +// Evaluators can be used concurrently. +func (eval *evaluator) ShallowCopy() Evaluator { + return &evaluator{eval.Evaluator.ShallowCopy(), eval.params} +} + +// WithKey creates a shallow copy of the receiver Evaluator for which the new EvaluationKey is evaluationKey +// and where the temporary buffers are shared. The receiver and the returned Evaluators cannot be used concurrently. +func (eval *evaluator) WithKey(evaluationKey rlwe.EvaluationKey) Evaluator { + return &evaluator{eval.Evaluator.WithKey(evaluationKey), eval.params} +} + +// CoeffsToSlotsNew applies the homomorphic encoding and returns the result on new ciphertexts. +// Homomorphically encodes a complex vector vReal + i*vImag. +// If the packing is sparse (n < N/2), then returns ctReal = Ecd(vReal || vImag) and ctImag = nil. +// If the packing is dense (n == N/2), then returns ctReal = Ecd(vReal) and ctImag = Ecd(vImag). +func (eval *evaluator) CoeffsToSlotsNew(ctIn *ckks.Ciphertext, ctsMatrices EncodingMatrix) (ctReal, ctImag *ckks.Ciphertext) { + ctReal = ckks.NewCiphertext(eval.params, 1, ctIn.Level()-ctsMatrices.Depth(true), 0) + + if eval.params.LogSlots() == eval.params.LogN()-1 { + ctImag = ckks.NewCiphertext(eval.params, 1, ctIn.Level()-ctsMatrices.Depth(true), 0) + } + + eval.CoeffsToSlots(ctIn, ctsMatrices, ctReal, ctImag) + return +} + +// CoeffsToSlots applies the homomorphic encoding and returns the results on the provided ciphertexts. +// Homomorphically encodes a complex vector vReal + i*vImag of size n on a real vector of size 2n. +// If the packing is sparse (n < N/2), then returns ctReal = Ecd(vReal || vImag) and ctImag = nil. +// If the packing is dense (n == N/2), then returns ctReal = Ecd(vReal) and ctImag = Ecd(vImag). +func (eval *evaluator) CoeffsToSlots(ctIn *ckks.Ciphertext, ctsMatrices EncodingMatrix, ctReal, ctImag *ckks.Ciphertext) { + zV := ctIn.CopyNew() + eval.dft(ctIn, ctsMatrices.matrices, zV) + + eval.Conjugate(zV, ctReal) + + var tmp *ckks.Ciphertext + if ctImag != nil { + tmp = ctImag + } else { + tmp = ckks.NewCiphertextAtLevelFromPoly(ctReal.Level(), [2]*ring.Poly{eval.CtxPool().Value[0], eval.CtxPool().Value[1]}) + } + + // Imag part + eval.Sub(zV, ctReal, tmp) + eval.DivByi(tmp, tmp) + + // Real part + eval.Add(ctReal, zV, ctReal) + + // If repacking, then ct0 and ct1 right n/2 slots are zero. + if eval.params.LogSlots() < eval.params.LogN()-1 { + eval.Rotate(tmp, eval.params.Slots(), tmp) + eval.Add(ctReal, tmp, ctReal) + } + + zV = nil +} + +// SlotsToCoeffsNew applies the homomorphic decoding and returns the result on a new ciphertext. +// Homomorphically decodes a real vector of size 2n on a complex vector vReal + i*vImag of size n. +// If the packing is sparse (n < N/2) then ctReal = Ecd(vReal || vImag) and ctImag = nil. +// If the packing is dense (n == N/2), then ctReal = Ecd(vReal) and ctImag = Ecd(vImag). +func (eval *evaluator) SlotsToCoeffsNew(ctReal, ctImag *ckks.Ciphertext, stcMatrices EncodingMatrix) (ctOut *ckks.Ciphertext) { + level := ctReal.Level() + if ctImag != nil { + level = utils.MinInt(level, ctImag.Level()) + } + ctOut = ckks.NewCiphertext(eval.params, 1, level, ctReal.Scale) + eval.SlotsToCoeffs(ctReal, ctImag, stcMatrices, ctOut) + return + +} + +// SlotsToCoeffsNew applies the homomorphic decoding and returns the result on the provided ciphertext. +// Homomorphically decodes a real vector of size 2n on a complex vector vReal + i*vImag of size n. +// If the packing is sparse (n < N/2) then ctReal = Ecd(vReal || vImag) and ctImag = nil. +// If the packing is dense (n == N/2), then ctReal = Ecd(vReal) and ctImag = Ecd(vImag). +func (eval *evaluator) SlotsToCoeffs(ctReal, ctImag *ckks.Ciphertext, stcMatrices EncodingMatrix, ctOut *ckks.Ciphertext) { + // If full packing, the repacking can be done directly using ct0 and ct1. + if ctImag != nil { + eval.MultByi(ctImag, ctOut) + eval.Add(ctOut, ctReal, ctOut) + eval.dft(ctOut, stcMatrices.matrices, ctOut) + } else { + eval.dft(ctReal, stcMatrices.matrices, ctOut) + } + + return +} + +func (eval *evaluator) dft(ctIn *ckks.Ciphertext, plainVectors []ckks.PtDiagMatrix, ctOut *ckks.Ciphertext) { + + // Sequentially multiplies w with the provided dft matrices. + var in, out *ckks.Ciphertext + for i, plainVector := range plainVectors { + in, out = ctOut, ctOut + if i == 0 { + in, out = ctIn, ctOut + } + scale := out.Scale + eval.LinearTransform(in, plainVector, []*ckks.Ciphertext{out}) + if err := eval.Rescale(out, scale, out); err != nil { + panic(err) + } + } + + return +} + +// EvalModNew applies a homomorphic mod Q on a vector scaled by Delta, scaled down to mod 1 : +// +// 1) Delta * (Q/Delta * I(X) + m(X)) (Delta = scaling factor, I(X) integer poly, m(X) message) +// 2) Delta * (I(X) + Delta/Q * m(X)) (divide by Q/Delta) +// 3) Delta * (Delta/Q * m(X)) (x mod 1) +// 4) Delta * (m(X)) (multiply back by Q/Delta) +// +// Since Q is not a power of two, but Delta is, then does an approximate division by the closest +// power of two to Q instead. Hence, it assumes that the input plaintext is already scaled by +// the correcting factor Q/2^{round(log(Q))}. +// +// !! Assumes that the input is normalized by 1/K for K the range of the approximation. +// +// Scaling back error correction by 2^{round(log(Q))}/Q afterward is included in the polynomial +func (eval *evaluator) EvalModNew(ct *ckks.Ciphertext, evalModPoly EvalModPoly) *ckks.Ciphertext { + + // Stores default scales + prevScaleCt := ct.Scale + + // Normalize the modular reduction to mod by 1 (division by Q) + ct.Scale = evalModPoly.scalingFactor + + var err error + + // Compute the scales that the ciphertext should have before the double angle + // formula such that after it it has the scale it had before the polynomial + // evaluation + targetScale := ct.Scale + for i := 0; i < evalModPoly.doubleAngle; i++ { + targetScale = math.Sqrt(targetScale * eval.params.QiFloat64(evalModPoly.levelStart-evalModPoly.sinePoly.Depth()-evalModPoly.doubleAngle+i+1)) + } + + // Division by 1/2^r and change of variable for the Chebysehev evaluation + if evalModPoly.sineType == Cos1 || evalModPoly.sineType == Cos2 { + eval.AddConst(ct, -0.5/(complex(evalModPoly.scFac, 0)*(evalModPoly.sinePoly.B-evalModPoly.sinePoly.A)), ct) + } + + // Chebyshev evaluation + if ct, err = eval.EvaluatePoly(ct, evalModPoly.sinePoly, targetScale); err != nil { + panic(err) + } + + // Double angle + sqrt2pi := evalModPoly.sqrt2Pi + for i := 0; i < evalModPoly.doubleAngle; i++ { + sqrt2pi *= sqrt2pi + eval.MulRelin(ct, ct, ct) + eval.Add(ct, ct, ct) + eval.AddConst(ct, -sqrt2pi, ct) + if err := eval.Rescale(ct, targetScale, ct); err != nil { + panic(err) + } + } + + // ArcSine + if evalModPoly.arcSinePoly != nil { + if ct, err = eval.EvaluatePoly(ct, evalModPoly.arcSinePoly, ct.Scale); err != nil { + panic(err) + } + } + + // Multiplies back by q + ct.Scale = prevScaleCt + return ct +} diff --git a/ckks/advanced/homomorphic_encoding.go b/ckks/advanced/homomorphic_encoding.go new file mode 100644 index 00000000..d134a17b --- /dev/null +++ b/ckks/advanced/homomorphic_encoding.go @@ -0,0 +1,650 @@ +package advanced + +import ( + "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/utils" + "math" +) + +// LinearTransformType is a type used to distinguish different linear transformations. +type LinearTransformType int + +// CoeffsToSlots and SlotsToCoeffs are two linear transformations. +const ( + CoeffsToSlots = LinearTransformType(0) // Homomorphic Encoding + SlotsToCoeffs = LinearTransformType(1) // Homomorphic Decoding +) + +// EncodingMatrix is a struct storing the factorized DFT matrix +type EncodingMatrix struct { + EncodingMatrixLiteral + matrices []ckks.PtDiagMatrix +} + +// EncodingMatrixLiteral is a struct storing the parameters to generate the factorized DFT matrix. +type EncodingMatrixLiteral struct { + LinearTransformType LinearTransformType + LevelStart int // Encoding level + BitReversed bool // Flag for bit-reverseed input to the DFT (with bit-reversed output), by default false. + BSGSRatio float64 // n1/n2 ratio for the bsgs algo for matrix x vector eval + ScalingFactor [][]float64 +} + +// Depth returns the number of levels allocated. +// If actual == true then returns the number of moduli consumed, else +// returns the factorization depth. +func (mParams *EncodingMatrixLiteral) Depth(actual bool) (depth int) { + if actual { + depth = len(mParams.ScalingFactor) + } else { + for i := range mParams.ScalingFactor { + for range mParams.ScalingFactor[i] { + depth++ + } + } + } + return +} + +// Levels returns the index of the Qi used int CoeffsToSlots. +func (mParams *EncodingMatrixLiteral) Levels() (levels []int) { + levels = []int{} + trueDepth := mParams.Depth(true) + for i := range mParams.ScalingFactor { + for range mParams.ScalingFactor[trueDepth-1-i] { + levels = append(levels, mParams.LevelStart-i) + } + } + + return +} + +// Rotations returns the list of rotations performed during the CoeffsToSlot operation. +func (mParams *EncodingMatrixLiteral) Rotations(logN, logSlots int) (rotations []int) { + rotations = []int{} + + slots := 1 << logSlots + dslots := slots + if logSlots < logN-1 { + dslots <<= 1 + if mParams.LinearTransformType == CoeffsToSlots { + rotations = append(rotations, slots) + } + } + + indexCtS := computeBootstrappingDFTIndexMap(logN, logSlots, mParams.Depth(false), mParams.LinearTransformType, mParams.BitReversed) + + // Coeffs to Slots rotations + for i, pVec := range indexCtS { + N1 := ckks.FindBestBSGSSplit(pVec, dslots, mParams.BSGSRatio) + rotations = addMatrixRotToList(pVec, rotations, N1, slots, mParams.LinearTransformType == SlotsToCoeffs && logSlots < logN-1 && i == 0) + } + + return +} + +// NewHomomorphicEncodingMatrixFromLiteral generates the factorized encoding matrix. +// scaling : constant by witch the all the matrices will be multuplied by. +// encoder : ckks.Encoder. +func NewHomomorphicEncodingMatrixFromLiteral(mParams EncodingMatrixLiteral, encoder ckks.Encoder, logN, logSlots int, scaling complex128) EncodingMatrix { + + slots := 1 << logSlots + depth := mParams.Depth(false) + logdSlots := logSlots + 1 + if logdSlots == logN { + logdSlots-- + } + + roots := computeRoots(slots << 1) + pow5 := make([]int, (slots<<1)+1) + pow5[0] = 1 + for i := 1; i < (slots<<1)+1; i++ { + pow5[i] = pow5[i-1] * 5 + pow5[i] &= (slots << 2) - 1 + } + + ctsLevels := mParams.Levels() + + // CoeffsToSlots vectors + matrices := make([]ckks.PtDiagMatrix, len(ctsLevels)) + pVecDFT := computeDFTMatrices(logSlots, logdSlots, depth, roots, pow5, scaling, mParams.LinearTransformType, mParams.BitReversed) + cnt := 0 + trueDepth := mParams.Depth(true) + for i := range mParams.ScalingFactor { + for j := range mParams.ScalingFactor[trueDepth-i-1] { + matrices[cnt] = encoder.EncodeDiagMatrixBSGSAtLvl(ctsLevels[cnt], pVecDFT[cnt], mParams.ScalingFactor[trueDepth-i-1][j], mParams.BSGSRatio, logdSlots) + cnt++ + } + } + + return EncodingMatrix{EncodingMatrixLiteral: mParams, matrices: matrices} +} + +func computeRoots(N int) (roots []complex128) { + + var angle float64 + + m := N << 1 + + roots = make([]complex128, m) + + roots[0] = 1 + + for i := 1; i < m; i++ { + angle = 6.283185307179586 * float64(i) / float64(m) + roots[i] = complex(math.Cos(angle), math.Sin(angle)) + } + + return +} + +func fftPlainVec(logN, dslots int, roots []complex128, pow5 []int) (a, b, c [][]complex128) { + + var N, m, index, tt, gap, k, mask, idx1, idx2 int + + N = 1 << logN + + a = make([][]complex128, logN) + b = make([][]complex128, logN) + c = make([][]complex128, logN) + + var size int + if 2*N == dslots { + size = 2 + } else { + size = 1 + } + + index = 0 + for m = 2; m <= N; m <<= 1 { + + a[index] = make([]complex128, dslots) + b[index] = make([]complex128, dslots) + c[index] = make([]complex128, dslots) + + tt = m >> 1 + + for i := 0; i < N; i += m { + + gap = N / m + mask = (m << 2) - 1 + + for j := 0; j < m>>1; j++ { + + k = (pow5[j] & mask) * gap + + idx1 = i + j + idx2 = i + j + tt + + for u := 0; u < size; u++ { + a[index][idx1+u*N] = 1 + a[index][idx2+u*N] = -roots[k] + b[index][idx1+u*N] = roots[k] + c[index][idx2+u*N] = 1 + } + } + } + + index++ + } + + return +} + +func fftInvPlainVec(logN, dslots int, roots []complex128, pow5 []int) (a, b, c [][]complex128) { + + var N, m, index, tt, gap, k, mask, idx1, idx2 int + + N = 1 << logN + + a = make([][]complex128, logN) + b = make([][]complex128, logN) + c = make([][]complex128, logN) + + var size int + if 2*N == dslots { + size = 2 + } else { + size = 1 + } + + index = 0 + for m = N; m >= 2; m >>= 1 { + + a[index] = make([]complex128, dslots) + b[index] = make([]complex128, dslots) + c[index] = make([]complex128, dslots) + + tt = m >> 1 + + for i := 0; i < N; i += m { + + gap = N / m + mask = (m << 2) - 1 + + for j := 0; j < m>>1; j++ { + + k = ((m << 2) - (pow5[j] & mask)) * gap + + idx1 = i + j + idx2 = i + j + tt + + for u := 0; u < size; u++ { + + a[index][idx1+u*N] = 1 + a[index][idx2+u*N] = -roots[k] + b[index][idx1+u*N] = 1 + c[index][idx2+u*N] = roots[k] + } + } + } + + index++ + } + + return +} + +func addMatrixRotToList(pVec map[int]bool, rotations []int, N1, slots int, repack bool) []int { + + if len(pVec) < 3 { + for j := range pVec { + if !utils.IsInSliceInt(j, rotations) { + rotations = append(rotations, j) + } + } + } else { + var index int + for j := range pVec { + + index = (j / N1) * N1 + + if repack { + // Sparse repacking, occurring during the first DFT matrix of the CoeffsToSlots. + index &= (2*slots - 1) + } else { + // Other cases + index &= (slots - 1) + } + + if index != 0 && !utils.IsInSliceInt(index, rotations) { + rotations = append(rotations, index) + } + + index = j & (N1 - 1) + + if index != 0 && !utils.IsInSliceInt(index, rotations) { + rotations = append(rotations, index) + } + } + } + + return rotations +} + +func computeBootstrappingDFTIndexMap(logN, logSlots, maxDepth int, ltType LinearTransformType, bitreversed bool) (rotationMap []map[int]bool) { + + var level, depth, nextLevel int + + level = logSlots + + rotationMap = make([]map[int]bool, maxDepth) + + // We compute the chain of merge in order or reverse order depending if its DFT or InvDFT because + // the way the levels are collapsed has an impact on the total number of rotations and keys to be + // stored. Ex. instead of using 255 + 64 plaintext vectors, we can use 127 + 128 plaintext vectors + // by reversing the order of the merging. + merge := make([]int, maxDepth) + for i := 0; i < maxDepth; i++ { + + depth = int(math.Ceil(float64(level) / float64(maxDepth-i))) + + if ltType == CoeffsToSlots { + merge[i] = depth + } else { + merge[len(merge)-i-1] = depth + + } + + level -= depth + } + + level = logSlots + for i := 0; i < maxDepth; i++ { + + if logSlots < logN-1 && ltType == SlotsToCoeffs && i == 0 { + + // Special initial matrix for the repacking before SlotsToCoeffs + rotationMap[i] = genWfftRepackIndexMap(logSlots, level) + + // Merges this special initial matrix with the first layer of SlotsToCoeffs DFT + rotationMap[i] = nextLevelfftIndexMap(rotationMap[i], logSlots, 2< 1< 1<>1 { + mat[i], mat[N-i] = mat[N-i], mat[i] + } + } +} + +func conjugateDiagMatrix(mat map[int][]complex128) { + for i := range mat { + + for j := range mat[i] { + c := mat[i][j] + mat[i][j] = complex(real(c), -imag(c)) + } + } +} + +func genBitReverseDiagMatrix(logN int) (diagMat map[int][]complex128) { + + var N, iRev, diff int + + diagMat = make(map[int][]complex128) + + N = 1 << logN + + for i := 0; i < N; i++ { + iRev = int(utils.BitReverse64(uint64(i), uint64(logN))) + + diff = (i - iRev) & (N - 1) + + if diagMat[diff] == nil { + diagMat[diff] = make([]complex128, N) + } + + diagMat[diff][iRev] = complex(1, 0) + } + + return +} + +func addToDiagMatrix(diagMat map[int][]complex128, index int, vec []complex128) { + if diagMat[index] == nil { + diagMat[index] = vec + } else { + diagMat[index] = add(diagMat[index], vec) + } +} + +func rotate(x []complex128, n int) (y []complex128) { + + y = make([]complex128, len(x)) + + mask := int(len(x) - 1) + + // Rotates to the left + for i := 0; i < len(x); i++ { + y[i] = x[(i+n)&mask] + } + + return +} + +func mul(a, b []complex128) (res []complex128) { + + res = make([]complex128, len(a)) + + for i := 0; i < len(a); i++ { + res[i] = a[i] * b[i] + } + + return +} + +func add(a, b []complex128) (res []complex128) { + + res = make([]complex128, len(a)) + + for i := 0; i < len(a); i++ { + res[i] = a[i] + b[i] + } + + return +} diff --git a/ckks/advanced/homomorphic_encoding_test.go b/ckks/advanced/homomorphic_encoding_test.go new file mode 100644 index 00000000..a71610a1 --- /dev/null +++ b/ckks/advanced/homomorphic_encoding_test.go @@ -0,0 +1,343 @@ +package advanced + +import ( + "flag" + "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/rlwe" + "github.com/ldsec/lattigo/v2/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "math" + "runtime" + "testing" +) + +var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats") + +var minPrec float64 = 15 + +func TestHomomorphicEncoding(t *testing.T) { + var err error + + if runtime.GOARCH == "wasm" { + t.Skip("skipping homomorphic encoding tests for GOARCH=wasm") + } + + ParametersLiteral := ckks.ParametersLiteral{ + LogN: 13, + LogSlots: 12, + Scale: 1 << 45, + Sigma: rlwe.DefaultSigma, + Q: []uint64{ + 0x10000000006e0001, // 60 Q0 + 0x2000000a0001, // 45 + 0x2000000e0001, // 45 + 0x1fffffc20001, // 45 + + }, + P: []uint64{ + 0x1fffffffffe00001, // Pi 61 + 0x1fffffffffc80001, // Pi 61 + }, + } + + testEncodingMatrixLiteralMarshalling(t) + + var params ckks.Parameters + if params, err = ckks.NewParametersFromLiteral(ParametersLiteral); err != nil { + panic(err) + } + + for _, testSet := range []func(params ckks.Parameters, t *testing.T){ + testCoeffsToSlots, + testSlotsToCoeffs, + } { + testSet(params, t) + runtime.GC() + } + + ParametersLiteral.LogSlots-- + if params, err = ckks.NewParametersFromLiteral(ParametersLiteral); err != nil { + panic(err) + } + + for _, testSet := range []func(params ckks.Parameters, t *testing.T){ + testCoeffsToSlots, + testSlotsToCoeffs, + } { + testSet(params, t) + runtime.GC() + } +} + +func testEncodingMatrixLiteralMarshalling(t *testing.T) { + t.Run("Marshalling", func(t *testing.T) { + m := EncodingMatrixLiteral{ + LinearTransformType: CoeffsToSlots, + LevelStart: 12, + BSGSRatio: 16.0, + BitReversed: false, + ScalingFactor: [][]float64{ + {0x100000000060001}, + {0xfffffffff00001}, + {0xffffffffd80001}, + {0x1000000002a0001}, + }, + } + + data, err := m.MarshalBinary() + assert.Nil(t, err) + + mNew := new(EncodingMatrixLiteral) + if err := mNew.UnmarshalBinary(data); err != nil { + assert.Nil(t, err) + } + assert.Equal(t, m, *mNew) + }) +} + +func testCoeffsToSlots(params ckks.Parameters, t *testing.T) { + + packing := "FullPacking" + if params.LogSlots() < params.LogN()-1 { + packing = "SparsePacking" + } + + t.Run("CoeffsToSlots/"+packing, func(t *testing.T) { + + // This test tests the homomorphic encoding + // It first generates a vector of complex values of size params.Slots() + // + // vReal + i*vImag + // + // Then encode coefficient-wise and encrypts the vectors : + // + // Enc(bitReverse(vReal)||bitReverse(vImg)) + // + // And applies the homomorphic Encoding (will merge both vectors if there was two) + // + // Enc(iFFT(vReal+ i*vImag)) + // + // And returns the result in one ciphextext if the ciphertext can store it else in two ciphertexts + // + // Enc(Ecd(vReal) || Ecd(vImag)) or Enc(Ecd(vReal)) and Enc(Ecd(vImag)) + // + // Then checks that Dcd(Dec(Enc(Ecd(vReal)))) = vReal and Dcd(Dec(Enc(Ecd(vImag)))) = vImag + + CoeffsToSlotsParametersLiteral := EncodingMatrixLiteral{ + LinearTransformType: CoeffsToSlots, + LevelStart: params.MaxLevel(), + BSGSRatio: 16.0, + BitReversed: false, + ScalingFactor: [][]float64{ + {params.QiFloat64(params.MaxLevel() - 2)}, + {params.QiFloat64(params.MaxLevel() - 1)}, + {params.QiFloat64(params.MaxLevel() - 0)}, + }, + } + + kgen := ckks.NewKeyGenerator(params) + sk := kgen.GenSecretKey() + encoder := ckks.NewEncoder(params) + encryptor := ckks.NewEncryptor(params, sk) + decryptor := ckks.NewDecryptor(params, sk) + + n := math.Pow(1.0/float64(2*params.Slots()), 1.0/float64(CoeffsToSlotsParametersLiteral.Depth(true))) + + // Generates the encoding matrices + CoeffsToSlotMatrices := NewHomomorphicEncodingMatrixFromLiteral(CoeffsToSlotsParametersLiteral, encoder, params.LogN(), params.LogSlots(), complex(n, 0)) + + // Gets the rotations indexes for CoeffsToSlots + rotations := CoeffsToSlotsParametersLiteral.Rotations(params.LogN(), params.LogSlots()) + + // Generates the rotation keys + rotKey := kgen.GenRotationKeysForRotations(rotations, true, sk) + + // Creates an evaluator with the rotation keys + eval := NewEvaluator(params, rlwe.EvaluationKey{Rlk: nil, Rtks: rotKey}) + + // Generates the vector of random complex values + values := make([]complex128, params.Slots()) + for i := range values { + values[i] = complex(utils.RandFloat64(-1, 1), utils.RandFloat64(-1, 1)) + } + + // Splits between real and imaginary + valuesReal := make([]complex128, params.Slots()) + for i := range valuesReal { + valuesReal[i] = complex(real(values[i]), 0) + } + + valuesImag := make([]complex128, params.Slots()) + for i := range valuesImag { + valuesImag[i] = complex(imag(values[i]), 0) + } + + // Applies bit-reverse on the original complex vector + ckks.SliceBitReverseInPlaceComplex128(values, params.Slots()) + + // Maps to a float vector + // Add gaps if sparse packing + valuesFloat := make([]float64, params.N()) + gap := params.N() / (2 * params.Slots()) + for i, jdx, idx := 0, params.N()>>1, 0; i < params.Slots(); i, jdx, idx = i+1, jdx+gap, idx+gap { + valuesFloat[idx] = real(values[i]) + valuesFloat[jdx] = imag(values[i]) + } + + // Encodes coefficient-wise and encrypts the test vector + plaintext := ckks.NewPlaintext(params, params.MaxLevel(), params.Scale()) + encoder.EncodeCoeffs(valuesFloat, plaintext) + ciphertext := encryptor.EncryptNew(plaintext) + + // Applies the homomorphic DFT + ct0, ct1 := eval.CoeffsToSlotsNew(ciphertext, CoeffsToSlotMatrices) + + // Checks against the original coefficients + var coeffsReal, coeffsImag []complex128 + if params.LogSlots() < params.LogN()-1 { + coeffsRealImag := encoder.DecodePublic(decryptor.DecryptNew(ct0), params.LogSlots()+1, 0) + coeffsReal = coeffsRealImag[:params.Slots()] + coeffsImag = coeffsRealImag[params.Slots():] + } else { + coeffsReal = encoder.DecodePublic(decryptor.DecryptNew(ct0), params.LogSlots(), 0) + coeffsImag = encoder.DecodePublic(decryptor.DecryptNew(ct1), params.LogSlots(), 0) + } + + verifyTestVectors(params, encoder, nil, valuesReal, coeffsReal, params.LogSlots(), 0, t) + verifyTestVectors(params, encoder, nil, valuesImag, coeffsImag, params.LogSlots(), 0, t) + }) +} + +func testSlotsToCoeffs(params ckks.Parameters, t *testing.T) { + + packing := "FullPacking" + if params.LogSlots() < params.LogN()-1 { + packing = "SparsePacking" + } + + t.Run("SlotsToCoeffs/"+packing, func(t *testing.T) { + + // This test tests the homomorphic decoding + // It first generates a complex vector of size 2*slots + // if 2*slots == N, then two vectors are generated, one for the real part, one for the imaginary part : + // + // vReal and vReal (both floating point vectors because the encoding always result in a real vector) + // + // Then encode and encrypts the vectors : + // + // Enc(Ecd(vReal)) and Enc(Ecd(vImag)) + // + // And applies the homomorphic decoding (will merge both vectors if there was two) + // + // Enc(FFT(Ecd(vReal) + i*Ecd(vImag))) + // + // The result should be the decoding of the initial vectors bit-reversed + // + // Enc(FFT(Ecd(vReal) + i*Ecd(vImag))) = Enc(BitReverse(Dcd(Ecd(vReal + i*vImag)))) + // + // The first N/2 slots of the plaintext will be the real part while the last N/2 the imaginary part + // In case of 2*slots < N, then there is a gap of N/(2*slots) between each values + + SlotsToCoeffsParametersLiteral := EncodingMatrixLiteral{ + LinearTransformType: SlotsToCoeffs, + LevelStart: params.MaxLevel(), + BSGSRatio: 16.0, + BitReversed: false, + ScalingFactor: [][]float64{ + {params.QiFloat64(params.MaxLevel() - 2)}, + {params.QiFloat64(params.MaxLevel() - 1)}, + {params.QiFloat64(params.MaxLevel() - 0)}, + }, + } + + kgen := ckks.NewKeyGenerator(params) + sk := kgen.GenSecretKey() + encoder := ckks.NewEncoder(params) + encryptor := ckks.NewEncryptor(params, sk) + decryptor := ckks.NewDecryptor(params, sk) + + // Generates the encoding matrices + SlotsToCoeffsMatrix := NewHomomorphicEncodingMatrixFromLiteral(SlotsToCoeffsParametersLiteral, encoder, params.LogN(), params.LogSlots(), 1.0) + + // Gets the rotations indexes for SlotsToCoeffs + rotations := SlotsToCoeffsParametersLiteral.Rotations(params.LogN(), params.LogSlots()) + + // Generates the rotation keys + rotKey := kgen.GenRotationKeysForRotations(rotations, true, sk) + + // Creates an evaluator with the rotation keys + eval := NewEvaluator(params, rlwe.EvaluationKey{Rlk: nil, Rtks: rotKey}) + + // Generates the n first slots of the test vector (real part to encode) + valuesReal := make([]complex128, params.Slots()) + for i := range valuesReal { + valuesReal[i] = complex(float64(i+1)/float64(params.Slots()), 0) + } + + // Generates the n first slots of the test vector (imaginary part to encode) + valuesImag := make([]complex128, params.Slots()) + for i := range valuesImag { + valuesImag[i] = complex(float64(i+1)/float64(params.Slots()), 0) + } + + // If sparse, there there is the space to store both vectors in one + if params.LogSlots() < params.LogN()-1 { + for i := range valuesReal { + valuesReal[i] += complex(0, real(valuesImag[i])) + } + } + + // Encodes and encrypts the test vectors + logSlots := params.LogSlots() + if params.LogSlots() < params.LogN()-1 { + logSlots++ + } + + plaintext := ckks.NewPlaintext(params, params.MaxLevel(), params.Scale()) + encoder.Encode(plaintext, valuesReal, logSlots) + ct0 := encryptor.EncryptNew(plaintext) + var ct1 *ckks.Ciphertext + if params.LogSlots() == params.LogN()-1 { + encoder.Encode(plaintext, valuesImag, logSlots) + ct1 = encryptor.EncryptNew(plaintext) + } + + // Applies the homomorphic DFT + res := eval.SlotsToCoeffsNew(ct0, ct1, SlotsToCoeffsMatrix) + + // Decrypt and decode in the coefficient domain + coeffsFloat := encoder.DecodeCoeffsPublic(decryptor.DecryptNew(res), 0) + + // Extracts the coefficients and construct the complex vector + // This is simply coefficient ordering + valuesTest := make([]complex128, params.Slots()) + gap := params.N() / (2 * params.Slots()) + for i, idx := 0, 0; i < params.Slots(); i, idx = i+1, idx+gap { + valuesTest[i] = complex(coeffsFloat[idx], coeffsFloat[idx+(params.N()>>1)]) + } + + // The result is always returned as a single complex vector, so if full-packing (2 initial vectors) + // then repacks both vectors together + if params.LogSlots() == params.LogN()-1 { + for i := range valuesReal { + valuesReal[i] += complex(0, real(valuesImag[i])) + } + } + + // Result is bit-reversed, so applies the bit-reverse permutation on the reference vector + ckks.SliceBitReverseInPlaceComplex128(valuesReal, params.Slots()) + + verifyTestVectors(params, encoder, decryptor, valuesReal, valuesTest, params.LogSlots(), 0, t) + }) +} + +func verifyTestVectors(params ckks.Parameters, encoder ckks.Encoder, decryptor ckks.Decryptor, valuesWant []complex128, element interface{}, logSlots int, bound float64, t *testing.T) { + + precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, element, logSlots, bound) + if *printPrecisionStats { + t.Log(precStats.String()) + } + require.GreaterOrEqual(t, real(precStats.MeanPrecision), minPrec) + require.GreaterOrEqual(t, imag(precStats.MeanPrecision), minPrec) +} diff --git a/ckks/advanced/homomorphic_mod.go b/ckks/advanced/homomorphic_mod.go new file mode 100644 index 00000000..898cf801 --- /dev/null +++ b/ckks/advanced/homomorphic_mod.go @@ -0,0 +1,183 @@ +package advanced + +import ( + "github.com/ldsec/lattigo/v2/ckks" + "math" + "math/cmplx" +) + +// SineType is the type of function used during the bootstrapping +// for the homomorphic modular reduction +type SineType uint64 + +func sin2pi2pi(x complex128) complex128 { + return cmplx.Sin(6.283185307179586 * x) // 6.283185307179586 +} + +func cos2pi(x complex128) complex128 { + return cmplx.Cos(6.283185307179586 * x) +} + +// Sin and Cos are the two proposed functions for SineType +const ( + Sin = SineType(0) // Standard Chebyshev approximation of (1/2pi) * sin(2pix) + Cos1 = SineType(1) // Special approximation (Han and Ki) of pow((1/2pi), 1/2^r) * cos(2pi(x-0.25)/2^r) + Cos2 = SineType(2) // Standard Chebyshev approximation of pow((1/2pi), 1/2^r) * cos(2pi(x-0.25)/2^r) +) + +// EvalModLiteral a struct for the paramters of the EvalMod step +// of the bootstrapping +type EvalModLiteral struct { + Q uint64 // Q to reduce by during EvalMod + LevelStart int // Starting level of EvalMod + ScalingFactor float64 // Scaling factor used during EvalMod + SineType SineType // Chose betwenn [Sin(2*pi*x)] or [cos(2*pi*x/r) with double angle formula] + MessageRatio float64 // Ratio between Q0 and m, i.e. Q[0]/|m| + K int // K parameter (interpolation in the range -K to K) + SineDeg int // Degree of the interpolation + DoubleAngle int // Number of rescale and double angle formula (only applies for cos) + ArcSineDeg int // Degree of the Taylor arcsine composed with f(2*pi*x) (if zero then not used) +} + +// QDiff return Q/ClosestedPow2 +// This is the error introduced by the approximate division by Q +func (evm *EvalModLiteral) QDiff() float64 { + return float64(evm.Q) / math.Exp2(math.Round(math.Log2(float64(evm.Q)))) +} + +// EvalModPoly is a struct storing the EvalModLiteral with +// the polynomials. +type EvalModPoly struct { + levelStart int + scalingFactor float64 + sineType SineType + messageRatio float64 + doubleAngle int + qDiff float64 + scFac float64 + sqrt2Pi float64 + sinePoly *ckks.Polynomial + arcSinePoly *ckks.Polynomial +} + +// LevelStart returns the starting level of the EvalMod. +func (evp *EvalModPoly) LevelStart() int { + return evp.levelStart +} + +// ScalingFactor returns scaling factor used during the EvalMod. +func (evp *EvalModPoly) ScalingFactor() float64 { + return evp.scalingFactor +} + +// ScFac returns 1/2^r where r is the number of double angle evaluation. +func (evp *EvalModPoly) ScFac() float64 { + return evp.scFac +} + +// MessageRatio returns the pre-set ratio Q[0]/|m|. +func (evp *EvalModPoly) MessageRatio() float64 { + return evp.messageRatio +} + +// A returns the left bound of the sine approximation (scaled by 1/2^r). +func (evp *EvalModPoly) A() float64 { + return real(evp.sinePoly.A) +} + +// B returns the right bound of the sine approximation (scaled by 1/2^r). +func (evp *EvalModPoly) B() float64 { + return real(evp.sinePoly.B) +} + +// K return the sine approximation range. +func (evp *EvalModPoly) K() float64 { + return real(evp.sinePoly.B) * evp.scFac +} + +// QDiff return Q/ClosestedPow2 +// This is the error introduced by the approximate division by Q. +func (evp *EvalModPoly) QDiff() float64 { + return evp.qDiff +} + +// NewEvalModPolyFromLiteral generates an EvalModPoly fromt the EvalModLiteral. +func NewEvalModPolyFromLiteral(evm EvalModLiteral) EvalModPoly { + + var arcSinePoly *ckks.Polynomial + var sinePoly *ckks.Polynomial + var sqrt2pi float64 + + scFac := math.Exp2(float64(evm.DoubleAngle)) + + qDiff := evm.QDiff() + + if evm.ArcSineDeg > 0 { + + sqrt2pi = 1.0 + + coeffs := make([]complex128, evm.ArcSineDeg+1) + + coeffs[1] = 0.15915494309189535 * complex(qDiff, 0) + + for i := 3; i < evm.ArcSineDeg+1; i += 2 { + + coeffs[i] = coeffs[i-2] * complex(float64(i*i-4*i+4)/float64(i*i-i), 0) + + } + + arcSinePoly = ckks.NewPoly(coeffs) + + } else { + sqrt2pi = math.Pow(0.15915494309189535*qDiff, 1.0/scFac) + } + + if evm.SineType == Sin { + + if evm.DoubleAngle != 0 { + panic("cannot user double angle with SineType == Sin") + } + + sinePoly = ckks.Approximate(sin2pi2pi, -complex(float64(evm.K), 0), complex(float64(evm.K), 0), evm.SineDeg) + + } else if evm.SineType == Cos1 { + + sinePoly = new(ckks.Polynomial) + sinePoly.Coeffs = ApproximateCos(evm.K, evm.SineDeg, evm.MessageRatio, int(evm.DoubleAngle)) + sinePoly.MaxDeg = sinePoly.Degree() + sinePoly.A = complex(float64(-evm.K)/scFac, 0) + sinePoly.B = complex(float64(evm.K)/scFac, 0) + sinePoly.Lead = true + sinePoly.Basis = ckks.ChebyshevBasis + + } else if evm.SineType == Cos2 { + sinePoly = ckks.Approximate(cos2pi, -complex(float64(evm.K)/scFac, 0), complex(float64(evm.K)/scFac, 0), evm.SineDeg) + } else { + panic("invalid SineType") + } + + for i := range sinePoly.Coeffs { + sinePoly.Coeffs[i] *= complex(sqrt2pi, 0) + } + + return EvalModPoly{ + levelStart: evm.LevelStart, + scalingFactor: evm.ScalingFactor, + sineType: evm.SineType, + messageRatio: evm.MessageRatio, + doubleAngle: evm.DoubleAngle, + qDiff: qDiff, + scFac: scFac, + sqrt2Pi: sqrt2pi, + arcSinePoly: arcSinePoly, + sinePoly: sinePoly} +} + +// Depth returns the depth of the SineEval. If true, then also +// counts the double angle formula. +func (evm *EvalModLiteral) Depth() int { + depth := int(math.Ceil(math.Log2(float64(evm.SineDeg + 1)))) + depth += evm.DoubleAngle + depth += int(math.Ceil(math.Log2(float64(evm.ArcSineDeg + 1)))) + return depth +} diff --git a/ckks/advanced/homomorphic_mod_test.go b/ckks/advanced/homomorphic_mod_test.go new file mode 100644 index 00000000..a432527d --- /dev/null +++ b/ckks/advanced/homomorphic_mod_test.go @@ -0,0 +1,214 @@ +package advanced + +import ( + "math" + "runtime" + "testing" + + "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/rlwe" + "github.com/ldsec/lattigo/v2/utils" + "github.com/stretchr/testify/assert" +) + +func TestHomomorphicMod(t *testing.T) { + var err error + + if runtime.GOARCH == "wasm" { + t.Skip("skipping homomorphic mod tests for GOARCH=wasm") + } + + ParametersLiteral := ckks.ParametersLiteral{ + LogN: 14, + LogSlots: 13, + Scale: 1 << 45, + Sigma: rlwe.DefaultSigma, + Q: []uint64{ + 0x80000000080001, // 55 Q0 + 0xffffffffffc0001, // 60 + 0x10000000006e0001, // 60 + 0xfffffffff840001, // 60 + 0x1000000000860001, // 60 + 0xfffffffff6a0001, // 60 + 0x1000000000980001, // 60 + 0xfffffffff5a0001, // 60 + 0x1000000000b00001, // 60 + 0x1000000000ce0001, // 60 + 0xfffffffff2a0001, // 60 + 0xfffffffff240001, // 60 + 0x1000000000f00001, // 60 + 0x200000000e0001, // 53 + }, + P: []uint64{ + 0x1fffffffffe00001, // Pi 61 + 0x1fffffffffc80001, // Pi 61 + 0x1fffffffffb40001, // Pi 61 + 0x1fffffffff500001, // Pi 61 + 0x1fffffffff420001, // Pi 61 + }, + } + + testEvalModMarshalling(t) + + var params ckks.Parameters + if params, err = ckks.NewParametersFromLiteral(ParametersLiteral); err != nil { + panic(err) + } + + for _, testSet := range []func(params ckks.Parameters, t *testing.T){ + testEvalMod, + } { + testSet(params, t) + runtime.GC() + } + +} + +func testEvalModMarshalling(t *testing.T) { + t.Run("Marshalling", func(t *testing.T) { + + evm := EvalModLiteral{ + Q: 0x80000000080001, + LevelStart: 12, + SineType: Sin, + MessageRatio: 256.0, + K: 14, + SineDeg: 127, + DoubleAngle: 0, + ArcSineDeg: 7, + ScalingFactor: 1 << 60, + } + + data, err := evm.MarshalBinary() + assert.Nil(t, err) + + evmNew := new(EvalModLiteral) + if err := evmNew.UnmarshalBinary(data); err != nil { + assert.Nil(t, err) + } + assert.Equal(t, evm, *evmNew) + + }) + +} + +func testEvalMod(params ckks.Parameters, t *testing.T) { + + kgen := ckks.NewKeyGenerator(params) + sk := kgen.GenSecretKey() + rlk := kgen.GenRelinearizationKey(sk, 2) + encoder := ckks.NewEncoder(params) + encryptor := ckks.NewEncryptor(params, sk) + decryptor := ckks.NewDecryptor(params, sk) + eval := NewEvaluator(params, rlwe.EvaluationKey{Rlk: rlk, Rtks: nil}) + + t.Run("SineChebyshevWithArcSine", func(t *testing.T) { + + evm := EvalModLiteral{ + Q: 0x80000000080001, + LevelStart: 12, + SineType: Sin, + MessageRatio: 256.0, + K: 14, + SineDeg: 127, + DoubleAngle: 0, + ArcSineDeg: 7, + ScalingFactor: 1 << 60, + } + + EvalModPoly := NewEvalModPolyFromLiteral(evm) + + values, _, ciphertext := newTestVectorsEvalMod(params, encryptor, encoder, evm, t) + + scale := math.Exp2(math.Round(math.Log2(float64(evm.Q) / evm.MessageRatio))) + + // Scale the message to Delta = Q/MessageRatio + eval.ScaleUp(ciphertext, math.Round(scale/ciphertext.Scale), ciphertext) + + // Scale the message up to Sine/MessageRatio + eval.ScaleUp(ciphertext, math.Round((evm.ScalingFactor/evm.MessageRatio)/ciphertext.Scale), ciphertext) + + // Normalization + eval.MultByConst(ciphertext, 1/(float64(evm.K)*evm.QDiff()), ciphertext) + eval.Rescale(ciphertext, params.Scale(), ciphertext) + + // EvalMod + ciphertext = eval.EvalModNew(ciphertext, EvalModPoly) + + // PlaintextCircuit + //pi2r := 6.283185307179586/complex(math.Exp2(float64(evm.DoubleAngle)), 0) + for i := range values { + values[i] -= complex(evm.MessageRatio*evm.QDiff()*math.Round(real(values[i])/(evm.MessageRatio/evm.QDiff())), 0) + } + + verifyTestVectors(params, encoder, decryptor, values, ciphertext, params.LogSlots(), 0, t) + }) + + t.Run("CosOptimizedChebyshevWithArcSine", func(t *testing.T) { + + evm := EvalModLiteral{ + Q: 0x80000000080001, + LevelStart: 12, + SineType: Cos1, + MessageRatio: 256.0, + K: 10, + SineDeg: 31, + DoubleAngle: 2, + ArcSineDeg: 7, + ScalingFactor: 1 << 60, + } + + EvalModPoly := NewEvalModPolyFromLiteral(evm) + + values, _, ciphertext := newTestVectorsEvalMod(params, encryptor, encoder, evm, t) + + scale := math.Exp2(math.Round(math.Log2(float64(evm.Q) / evm.MessageRatio))) + + // Scale the message to Delta = Q/MessageRatio + eval.ScaleUp(ciphertext, math.Round(scale/ciphertext.Scale), ciphertext) + + // Scale the message up to Sine/MessageRatio + eval.ScaleUp(ciphertext, math.Round((evm.ScalingFactor/evm.MessageRatio)/ciphertext.Scale), ciphertext) + + // Normalization + eval.MultByConst(ciphertext, 1/(float64(evm.K)*evm.QDiff()), ciphertext) + eval.Rescale(ciphertext, params.Scale(), ciphertext) + + // EvalMod + ciphertext = eval.EvalModNew(ciphertext, EvalModPoly) + + // PlaintextCircuit + //pi2r := 6.283185307179586/complex(math.Exp2(float64(evm.DoubleAngle)), 0) + for i := range values { + values[i] -= complex(evm.MessageRatio*evm.QDiff()*math.Round(real(values[i])/(evm.MessageRatio/evm.QDiff())), 0) + } + + verifyTestVectors(params, encoder, decryptor, values, ciphertext, params.LogSlots(), 0, t) + }) +} + +func newTestVectorsEvalMod(params ckks.Parameters, encryptor ckks.Encryptor, encoder ckks.Encoder, evm EvalModLiteral, t *testing.T) (values []complex128, plaintext *ckks.Plaintext, ciphertext *ckks.Ciphertext) { + + logSlots := params.LogSlots() + + values = make([]complex128, 1< 1 { - btp.evaluator.DropLevel(ct, 1) - } - - // Brings the ciphertext scale to Q0/2^{10} - if ct.Level() == 1 { - - // if one level is available, then uses it to match the scale - btp.evaluator.SetScale(ct, btp.prescale) - - // then drops to level 0 - for ct.Level() != 0 { - btp.evaluator.DropLevel(ct, 1) - } - - } else { - - // else drop to level 0 - for ct.Level() != 0 { - btp.evaluator.DropLevel(ct, 1) - } - - // and does an integer constant mult by round((Q0/Delta_m)/ctscle) - - if btp.prescale < ct.Scale { - panic("ciphetext scale > Q[0]/(Q[0]/Delta_m)") - } - btp.evaluator.ScaleUp(ct, math.Round(btp.prescale/ct.Scale), ct) - } - - // ModUp ct_{Q_0} -> ct_{Q_L} - //t = time.Now() - ct = btp.modUp(ct) - //log.Println("After ModUp :", time.Now().Sub(t), ct.Level(), ct.Scale()) - - // Brings the ciphertext scale to sineQi/(Q0/scale) if its under - btp.evaluator.ScaleUp(ct, math.Round(btp.postscale/ct.Scale), ct) - - //SubSum X -> (N/dslots) * Y^dslots - //t = time.Now() - ct = btp.subSum(ct) - //log.Println("After SubSum :", time.Now().Sub(t), ct.Level(), ct.Scale()) - // Part 1 : Coeffs to slots - - //t = time.Now() - ct0, ct1 = CoeffsToSlots(ct, btp.pDFTInv, btp.evaluator) - //log.Println("After CtS :", time.Now().Sub(t), ct0.Level(), ct0.Scale()) - - // Part 2 : SineEval - //t = time.Now() - ct0, ct1 = btp.evaluateSine(ct0, ct1) - //log.Println("After Sine :", time.Now().Sub(t), ct0.Level(), ct0.Scale()) - - // Part 3 : Slots to coeffs - //t = time.Now() - ct0 = SlotsToCoeffs(ct0, ct1, btp.pDFT, btp.evaluator) - - ct0.Scale = math.Exp2(math.Round(math.Log2(ct0.Scale))) // rounds to the nearest power of two - //log.Println("After StC :", time.Now().Sub(t), ct0.Level(), ct0.Scale()) - return ct0 -} - -func (btp *Bootstrapper) subSum(ct *Ciphertext) *Ciphertext { - - for i := btp.params.LogSlots(); i < btp.params.MaxLogSlots(); i++ { - - btp.evaluator.Rotate(ct, 1< (Q >> 1) { - ct.Value[u].Coeffs[i][j] = qi - ring.BRedAdd(Q-coeff, qi, bredparams[i]) - } else { - ct.Value[u].Coeffs[i][j] = ring.BRedAdd(coeff, qi, bredparams[i]) - } - } - } - } - - for i := range ct.Value { - ringQ.NTTLvl(ct.Level(), ct.Value[i], ct.Value[i]) - } - - return ct -} - -// CoeffsToSlots applies the homomorphic encoding -func CoeffsToSlots(vec *Ciphertext, pDFTInv []*PtDiagMatrix, eval Evaluator) (ct0, ct1 *Ciphertext) { - - var zV, zVconj *Ciphertext - - zV = dft(vec, pDFTInv, true, eval) - - zVconj = eval.ConjugateNew(zV) - - // The real part is stored in ct0 - ct0 = eval.AddNew(zV, zVconj) - - // The imaginary part is stored in ct1 - ct1 = eval.SubNew(zV, zVconj) - - eval.DivByi(ct1, ct1) - - // If repacking, then ct0 and ct1 right n/2 slots are zero. - if eval.(*evaluator).params.LogSlots() < eval.(*evaluator).params.LogN()-1 { - eval.Rotate(ct1, eval.(*evaluator).params.Slots(), ct1) - eval.Add(ct0, ct1, ct0) - return ct0, nil - } - - zV = nil - zVconj = nil - - return ct0, ct1 -} - -// SlotsToCoeffs applies the homomorphic decoding -func SlotsToCoeffs(ct0, ct1 *Ciphertext, pDFT []*PtDiagMatrix, eval Evaluator) (ct *Ciphertext) { - - // If full packing, the repacking can be done directly using ct0 and ct1. - if ct1 != nil { - eval.MultByi(ct1, ct1) - eval.Add(ct0, ct1, ct0) - } - - ct1 = nil - - return dft(ct0, pDFT, false, eval) -} - -func dft(vec *Ciphertext, plainVectors []*PtDiagMatrix, forward bool, eval Evaluator) *Ciphertext { - - // Sequentially multiplies w with the provided dft matrices. - for _, plainVector := range plainVectors { - scale := vec.Scale - vec = eval.LinearTransform(vec, plainVector)[0] - if err := eval.Rescale(vec, scale, vec); err != nil { - panic(err) - } - } - - return vec -} - -// Sine Evaluation ct0 = Q/(2pi) * sin((2pi/Q) * ct0) -func (btp *Bootstrapper) evaluateSine(ct0, ct1 *Ciphertext) (*Ciphertext, *Ciphertext) { - - ct0.Scale *= btp.MessageRatio - btp.evaluator.scale = btp.sinescale // Reference scale is changed to the Qi used for the SineEval (which is also close to the new ciphetext scale) - - ct0 = btp.evaluateCheby(ct0) - - ct0.Scale /= (btp.MessageRatio * btp.postscale / btp.params.Scale()) - - if ct1 != nil { - ct1.Scale *= btp.MessageRatio - ct1 = btp.evaluateCheby(ct1) - ct1.Scale /= (btp.MessageRatio * btp.postscale / btp.params.Scale()) - } - - // Reference scale is changed back to the current ciphertext's scale. - btp.evaluator.scale = ct0.Scale - - return ct0, ct1 -} - -func (btp *Bootstrapper) evaluateCheby(ct *Ciphertext) *Ciphertext { - - var err error - - cheby := btp.sineEvalPoly - - targetScale := btp.sinescale - - // Compute the scales that the ciphertext should have before the double angle - // formula such that after it it has the scale it had before the polynomial - // evaluation - for i := 0; i < btp.SinRescal; i++ { - targetScale = math.Sqrt(targetScale * float64(btp.SineEvalModuli.Qi[i])) - } - - // Division by 1/2^r and change of variable for the Chebysehev evaluation - if btp.SinType == Cos1 || btp.SinType == Cos2 { - btp.AddConst(ct, -0.5/(complex(btp.scFac, 0)*(cheby.b-cheby.a)), ct) - } - - // Chebyshev evaluation - if ct, err = btp.EvaluateCheby(ct, cheby, targetScale); err != nil { - panic(err) - } - - // Double angle - sqrt2pi := btp.sqrt2pi - for i := 0; i < btp.SinRescal; i++ { - sqrt2pi *= sqrt2pi - btp.MulRelin(ct, ct, ct) - btp.Add(ct, ct, ct) - btp.AddConst(ct, -sqrt2pi, ct) - if err := btp.Rescale(ct, btp.evaluator.scale, ct); err != nil { - panic(err) - } - } - - // ArcSine - if btp.ArcSineDeg > 0 { - if ct, err = btp.EvaluatePoly(ct, btp.arcSinePoly, ct.Scale); err != nil { - panic(err) - } - } - - return ct -} diff --git a/ckks/bootstrap_bench_test.go b/ckks/bootstrap_bench_test.go deleted file mode 100644 index fed9551f..00000000 --- a/ckks/bootstrap_bench_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package ckks - -import ( - "math" - "testing" - "time" -) - -func BenchmarkBootstrapp(b *testing.B) { - - if !*testBootstrapping { - b.Skip("skipping bootstrapping test") - } - - var err error - var testContext = new(testParams) - var btp *Bootstrapper - - paramSet := 2 - - btpParams := DefaultBootstrapParams[paramSet] - - params, err := btpParams.Params() - if err != nil { - panic(err) - } - if testContext, err = genTestParams(params, btpParams.H); err != nil { - panic(err) - } - - rotations := btpParams.RotationsForBootstrapping(testContext.params.LogSlots()) - - rotkeys := testContext.kgen.GenRotationKeysForRotations(rotations, true, testContext.sk) - - btpKey := BootstrappingKey{testContext.rlk, rotkeys} - - if btp, err = NewBootstrapper(testContext.params, btpParams, btpKey); err != nil { - panic(err) - } - - b.Run(testString(testContext, "Bootstrapp/"), func(b *testing.B) { - for i := 0; i < b.N; i++ { - - b.StopTimer() - ct := NewCiphertextRandom(testContext.prng, testContext.params, 1, 0, testContext.params.Scale()) - b.StartTimer() - - var t time.Time - var ct0, ct1 *Ciphertext - - // Brings the ciphertext scale to Q0/2^{10} - btp.evaluator.ScaleUp(ct, math.Round(btp.prescale/ct.Scale), ct) - - // ModUp ct_{Q_0} -> ct_{Q_L} - t = time.Now() - ct = btp.modUp(ct) - b.Log("After ModUp :", time.Since(t), ct.Level(), ct.Scale) - - // Brings the ciphertext scale to sineQi/(Q0/scale) if its under - btp.evaluator.ScaleUp(ct, math.Round(btp.postscale/ct.Scale), ct) - - //SubSum X -> (N/dslots) * Y^dslots - t = time.Now() - ct = btp.subSum(ct) - b.Log("After SubSum :", time.Since(t), ct.Level(), ct.Scale) - - // Part 1 : Coeffs to slots - t = time.Now() - ct0, ct1 = CoeffsToSlots(ct, btp.pDFTInv, btp.evaluator) - b.Log("After CtS :", time.Since(t), ct0.Level(), ct0.Scale) - - // Part 2 : SineEval - t = time.Now() - ct0, ct1 = btp.evaluateSine(ct0, ct1) - b.Log("After Sine :", time.Since(t), ct0.Level(), ct0.Scale) - - // Part 3 : Slots to coeffs - t = time.Now() - ct0 = SlotsToCoeffs(ct0, ct1, btp.pDFT, btp.evaluator) - ct0.Scale = math.Exp2(math.Round(math.Log2(ct0.Scale))) - b.Log("After StC :", time.Since(t), ct0.Level(), ct0.Scale) - } - }) -} diff --git a/ckks/bootstrap_params.go b/ckks/bootstrap_params.go deleted file mode 100644 index 7125f33b..00000000 --- a/ckks/bootstrap_params.go +++ /dev/null @@ -1,1332 +0,0 @@ -package ckks - -import ( - "math" - - "github.com/ldsec/lattigo/v2/rlwe" - "github.com/ldsec/lattigo/v2/utils" - //"fmt" -) - -// SinType is the type of function used during the bootstrapping -// for the homomorphic modular reduction -type SinType uint64 - -// Sin and Cos are the two proposed functions for SinType -const ( - Sin = SinType(0) // Standard Chebyshev approximation of (1/2pi) * sin(2pix) - Cos1 = SinType(1) // Special approximation (Han and Ki) of pow((1/2pi), 1/2^r) * cos(2pi(x-0.25)/2^r) - Cos2 = SinType(2) // Standard Chebyshev approximation of pow((1/2pi), 1/2^r) * cos(2pi(x-0.25)/2^r) -) - -// BootstrappingParameters is a struct for the default bootstrapping parameters -type BootstrappingParameters struct { - ResidualModuli - KeySwitchModuli - SlotsToCoeffsModuli - SineEvalModuli - CoeffsToSlotsModuli - LogN int - LogSlots int - Scale float64 - Sigma float64 - H int // Hamming weight of the secret key - SinType SinType // Chose betwenn [Sin(2*pi*x)] or [cos(2*pi*x/r) with double angle formula] - MessageRatio float64 // Ratio between Q0 and m, i.e. Q[0]/|m| - SinRange int // K parameter (interpolation in the range -K to K) - SinDeg int // Degree of the interpolation - SinRescal int // Number of rescale and double angle formula (only applies for cos) - ArcSineDeg int // Degree of the Taylor arcsine composed with f(2*pi*x) (if zero then not used) - MaxN1N2Ratio float64 // n1/n2 ratio for the bsgs algo for matrix x vector eval - BitReversed bool // Flag for bit-reverseed input to the DFT (with bit-reversed output), by default false. -} - -// Params generates a new set of Parameters from the BootstrappingParameters -func (b *BootstrappingParameters) Params() (p Parameters, err error) { - Qi := append(b.ResidualModuli, b.SlotsToCoeffsModuli.Qi...) - Qi = append(Qi, b.SineEvalModuli.Qi...) - Qi = append(Qi, b.CoeffsToSlotsModuli.Qi...) - - if p, err = NewParametersFromLiteral(ParametersLiteral{ - Q: Qi, - P: b.KeySwitchModuli, - LogN: b.LogN, - Sigma: b.Sigma, - Scale: b.Scale, - LogSlots: b.LogSlots, - }); err != nil { - return Parameters{}, err - } - return -} - -// Copy return a new BootstrappingParameters which is a copy of the target -func (b *BootstrappingParameters) Copy() *BootstrappingParameters { - paramsCopy := &BootstrappingParameters{ - LogN: b.LogN, - LogSlots: b.LogSlots, - Scale: b.Scale, - Sigma: b.Sigma, - H: b.H, - SinType: b.SinType, - MessageRatio: b.MessageRatio, - SinRange: b.SinRange, - SinDeg: b.SinDeg, - SinRescal: b.SinRescal, - ArcSineDeg: b.ArcSineDeg, - MaxN1N2Ratio: b.MaxN1N2Ratio, - BitReversed: b.BitReversed, - } - - // KeySwitchModuli - paramsCopy.KeySwitchModuli = make([]uint64, len(b.KeySwitchModuli)) - copy(paramsCopy.KeySwitchModuli, b.KeySwitchModuli) - - // ResidualModuli - paramsCopy.ResidualModuli = make([]uint64, len(b.ResidualModuli)) - copy(paramsCopy.ResidualModuli, b.ResidualModuli) - - // CoeffsToSlotsModuli - paramsCopy.CoeffsToSlotsModuli.Qi = make([]uint64, b.CtSDepth(true)) - copy(paramsCopy.CoeffsToSlotsModuli.Qi, b.CoeffsToSlotsModuli.Qi) - - paramsCopy.CoeffsToSlotsModuli.ScalingFactor = make([][]float64, b.CtSDepth(true)) - for i := range paramsCopy.CoeffsToSlotsModuli.ScalingFactor { - paramsCopy.CoeffsToSlotsModuli.ScalingFactor[i] = make([]float64, len(b.CoeffsToSlotsModuli.ScalingFactor[i])) - copy(paramsCopy.CoeffsToSlotsModuli.ScalingFactor[i], b.CoeffsToSlotsModuli.ScalingFactor[i]) - } - - // SineEvalModuli - paramsCopy.SineEvalModuli.Qi = make([]uint64, len(b.SineEvalModuli.Qi)) - copy(paramsCopy.SineEvalModuli.Qi, b.SineEvalModuli.Qi) - paramsCopy.SineEvalModuli.ScalingFactor = b.SineEvalModuli.ScalingFactor - - // SlotsToCoeffsModuli - paramsCopy.SlotsToCoeffsModuli.Qi = make([]uint64, b.StCDepth(true)) - copy(paramsCopy.SlotsToCoeffsModuli.Qi, b.SlotsToCoeffsModuli.Qi) - - paramsCopy.SlotsToCoeffsModuli.ScalingFactor = make([][]float64, b.StCDepth(true)) - for i := range paramsCopy.SlotsToCoeffsModuli.ScalingFactor { - paramsCopy.SlotsToCoeffsModuli.ScalingFactor[i] = make([]float64, len(b.SlotsToCoeffsModuli.ScalingFactor[i])) - copy(paramsCopy.SlotsToCoeffsModuli.ScalingFactor[i], b.SlotsToCoeffsModuli.ScalingFactor[i]) - } - - return paramsCopy -} - -// ResidualModuli is a list of the moduli available after the bootstrapping. -type ResidualModuli []uint64 - -// KeySwitchModuli is a list of the special moduli used for the key-switching. -type KeySwitchModuli []uint64 - -// CoeffsToSlotsModuli is a list of the moduli used during he CoeffsToSlots step. -type CoeffsToSlotsModuli struct { - Qi []uint64 - ScalingFactor [][]float64 -} - -// SineEvalModuli is a list of the moduli used during the SineEval step. -type SineEvalModuli struct { - Qi []uint64 - ScalingFactor float64 -} - -// SlotsToCoeffsModuli is a list of the moduli used during the SlotsToCoeffs step. -type SlotsToCoeffsModuli struct { - Qi []uint64 - ScalingFactor [][]float64 -} - -// MaxLevel returns the maximum level of the bootstrapping parameters -func (b *BootstrappingParameters) MaxLevel() int { - return len(b.ResidualModuli) + len(b.CoeffsToSlotsModuli.Qi) + len(b.SineEvalModuli.Qi) + len(b.SlotsToCoeffsModuli.Qi) - 1 -} - -// SineEvalDepth returns the depth of the SineEval. If true, then also -// counts the double angle formula. -func (b *BootstrappingParameters) SineEvalDepth(withRescale bool) int { - depth := int(math.Ceil(math.Log2(float64(b.SinDeg + 1)))) - - if withRescale { - depth += b.SinRescal - } - - return depth -} - -// ArcSineDepth returns the depth of the arcsine polynomial. -func (b *BootstrappingParameters) ArcSineDepth() int { - return int(math.Ceil(math.Log2(float64(b.ArcSineDeg + 1)))) -} - -// CtSDepth returns the number of levels allocated to CoeffsToSlots. -// If actual == true then returns the number of moduli consumed, else -// returns the factorization depth. -func (b *BootstrappingParameters) CtSDepth(actual bool) (depth int) { - if actual { - depth = len(b.CoeffsToSlotsModuli.ScalingFactor) - } else { - for i := range b.CoeffsToSlotsModuli.ScalingFactor { - for range b.CoeffsToSlotsModuli.ScalingFactor[i] { - depth++ - } - } - } - - return -} - -// CtSLevels returns the index of the Qi used int CoeffsToSlots -func (b *BootstrappingParameters) CtSLevels() (ctsLevel []int) { - ctsLevel = []int{} - for i := range b.CoeffsToSlotsModuli.Qi { - for range b.CoeffsToSlotsModuli.ScalingFactor[b.CtSDepth(true)-1-i] { - ctsLevel = append(ctsLevel, b.MaxLevel()-i) - } - } - - return -} - -// StCDepth returns the number of levels allocated to SlotToCoeffs. -// If actual == true then returns the number of moduli consumed, else -// returns the factorization depth. -func (b *BootstrappingParameters) StCDepth(actual bool) (depth int) { - if actual { - depth = len(b.SlotsToCoeffsModuli.ScalingFactor) - } else { - for i := range b.SlotsToCoeffsModuli.ScalingFactor { - for range b.SlotsToCoeffsModuli.ScalingFactor[i] { - depth++ - } - } - } - - return -} - -// StCLevels returns the index of the Qi used in SlotsToCoeffs -func (b *BootstrappingParameters) StCLevels() (stcLevel []int) { - stcLevel = []int{} - for i := range b.SlotsToCoeffsModuli.Qi { - for range b.SlotsToCoeffsModuli.ScalingFactor[b.StCDepth(true)-1-i] { - stcLevel = append(stcLevel, b.MaxLevel()-b.CtSDepth(true)-b.SineEvalDepth(true)-b.ArcSineDepth()-i) - } - } - - return -} - -// GenCoeffsToSlotsMatrix generates the factorized encoding matrix -// scaling : constant by witch the all the matrices will be multuplied by -// encoder : ckks.Encoder -func (b *BootstrappingParameters) GenCoeffsToSlotsMatrix(scaling complex128, encoder Encoder) []*PtDiagMatrix { - - logSlots := b.LogSlots - slots := 1 << logSlots - depth := b.CtSDepth(false) - logdSlots := logSlots + 1 - if logdSlots == b.LogN { - logdSlots-- - } - - roots := computeRoots(slots << 1) - pow5 := make([]int, (slots<<1)+1) - pow5[0] = 1 - for i := 1; i < (slots<<1)+1; i++ { - pow5[i] = pow5[i-1] * 5 - pow5[i] &= (slots << 2) - 1 - } - - ctsLevels := b.CtSLevels() - - // CoeffsToSlots vectors - pDFTInv := make([]*PtDiagMatrix, len(ctsLevels)) - pVecDFTInv := computeDFTMatrices(logSlots, logdSlots, depth, roots, pow5, scaling, true, b.BitReversed) - cnt := 0 - for i := range b.CoeffsToSlotsModuli.ScalingFactor { - for j := range b.CoeffsToSlotsModuli.ScalingFactor[b.CtSDepth(true)-i-1] { - pDFTInv[cnt] = encoder.EncodeDiagMatrixBSGSAtLvl(ctsLevels[cnt], pVecDFTInv[cnt], b.CoeffsToSlotsModuli.ScalingFactor[b.CtSDepth(true)-i-1][j], b.MaxN1N2Ratio, logdSlots) - cnt++ - } - } - - return pDFTInv -} - -// GenSlotsToCoeffsMatrix generates the factorized decoding matrix -// scaling : constant by witch the all the matrices will be multuplied by -// encoder : ckks.Encoder -func (b *BootstrappingParameters) GenSlotsToCoeffsMatrix(scaling complex128, encoder Encoder) []*PtDiagMatrix { - - logSlots := b.LogSlots - slots := 1 << logSlots - depth := b.StCDepth(false) - logdSlots := logSlots + 1 - if logdSlots == b.LogN { - logdSlots-- - } - - roots := computeRoots(slots << 1) - pow5 := make([]int, (slots<<1)+1) - pow5[0] = 1 - for i := 1; i < (slots<<1)+1; i++ { - pow5[i] = pow5[i-1] * 5 - pow5[i] &= (slots << 2) - 1 - } - - stcLevels := b.StCLevels() - - // CoeffsToSlots vectors - pDFT := make([]*PtDiagMatrix, len(stcLevels)) - pVecDFT := computeDFTMatrices(logSlots, logdSlots, depth, roots, pow5, scaling, false, b.BitReversed) - cnt := 0 - for i := range b.SlotsToCoeffsModuli.ScalingFactor { - for j := range b.SlotsToCoeffsModuli.ScalingFactor[b.StCDepth(true)-i-1] { - pDFT[cnt] = encoder.EncodeDiagMatrixBSGSAtLvl(stcLevels[cnt], pVecDFT[cnt], b.SlotsToCoeffsModuli.ScalingFactor[b.StCDepth(true)-i-1][j], b.MaxN1N2Ratio, logdSlots) - cnt++ - } - } - - return pDFT -} - -// RotationsForCoeffsToSlots returns the list of rotations performed during the CoeffsToSlot operation. -func (b *BootstrappingParameters) RotationsForCoeffsToSlots(logSlots int) (rotations []int) { - rotations = []int{} - - slots := 1 << logSlots - dslots := slots - if logSlots < b.LogN-1 { - dslots <<= 1 - rotations = append(rotations, slots) - } - - indexCtS := computeBootstrappingDFTIndexMap(b.LogN, logSlots, b.CtSDepth(false), true, b.BitReversed) - - // Coeffs to Slots rotations - for _, pVec := range indexCtS { - N1 := findbestbabygiantstepsplit(pVec, dslots, b.MaxN1N2Ratio) - rotations = addMatrixRotToList(pVec, rotations, N1, slots, false) - } - - return -} - -// RotationsForSlotsToCoeffs returns the list of rotations performed during the SlotsToCoeffs operation. -func (b *BootstrappingParameters) RotationsForSlotsToCoeffs(logSlots int) (rotations []int) { - rotations = []int{} - - slots := 1 << logSlots - dslots := slots - if logSlots < b.LogN-1 { - dslots <<= 1 - } - - indexStC := computeBootstrappingDFTIndexMap(b.LogN, logSlots, b.StCDepth(false), false, b.BitReversed) - - // Slots to Coeffs rotations - for i, pVec := range indexStC { - N1 := findbestbabygiantstepsplit(pVec, dslots, b.MaxN1N2Ratio) - rotations = addMatrixRotToList(pVec, rotations, N1, slots, logSlots < b.LogN-1 && i == 0) - } - - return -} - -// RotationsForBootstrapping returns the list of rotations performed during the Bootstrapping operation. -func (b *BootstrappingParameters) RotationsForBootstrapping(logSlots int) (rotations []int) { - - // List of the rotation key values to needed for the bootstrapp - rotations = []int{} - - slots := 1 << logSlots - dslots := slots - if logSlots < b.LogN-1 { - dslots <<= 1 - } - - //SubSum rotation needed X -> Y^slots rotations - for i := logSlots; i < b.LogN-1; i++ { - if !utils.IsInSliceInt(1<> 1 - - for i := 0; i < N; i += m { - - gap = N / m - mask = (m << 2) - 1 - - for j := 0; j < m>>1; j++ { - - k = (pow5[j] & mask) * gap - - idx1 = i + j - idx2 = i + j + tt - - for u := 0; u < size; u++ { - a[index][idx1+u*N] = 1 - a[index][idx2+u*N] = -roots[k] - b[index][idx1+u*N] = roots[k] - c[index][idx2+u*N] = 1 - } - } - } - - index++ - } - - return -} - -func fftInvPlainVec(logN, dslots int, roots []complex128, pow5 []int) (a, b, c [][]complex128) { - - var N, m, index, tt, gap, k, mask, idx1, idx2 int - - N = 1 << logN - - a = make([][]complex128, logN) - b = make([][]complex128, logN) - c = make([][]complex128, logN) - - var size int - if 2*N == dslots { - size = 2 - } else { - size = 1 - } - - index = 0 - for m = N; m >= 2; m >>= 1 { - - a[index] = make([]complex128, dslots) - b[index] = make([]complex128, dslots) - c[index] = make([]complex128, dslots) - - tt = m >> 1 - - for i := 0; i < N; i += m { - - gap = N / m - mask = (m << 2) - 1 - - for j := 0; j < m>>1; j++ { - - k = ((m << 2) - (pow5[j] & mask)) * gap - - idx1 = i + j - idx2 = i + j + tt - - for u := 0; u < size; u++ { - - a[index][idx1+u*N] = 1 - a[index][idx2+u*N] = -roots[k] - b[index][idx1+u*N] = 1 - c[index][idx2+u*N] = roots[k] - } - } - } - - index++ - } - - return -} - -func addMatrixRotToList(pVec map[int]bool, rotations []int, N1, slots int, repack bool) []int { - - if len(pVec) < 3 { - for j := range pVec { - if !utils.IsInSliceInt(j, rotations) { - rotations = append(rotations, j) - } - } - } else { - var index int - for j := range pVec { - - index = (j / N1) * N1 - - if repack { - // Sparse repacking, occurring during the first DFT matrix of the CoeffsToSlots. - index &= (2*slots - 1) - } else { - // Other cases - index &= (slots - 1) - } - - if index != 0 && !utils.IsInSliceInt(index, rotations) { - rotations = append(rotations, index) - } - - index = j & (N1 - 1) - - if index != 0 && !utils.IsInSliceInt(index, rotations) { - rotations = append(rotations, index) - } - } - } - - return rotations -} - -func computeBootstrappingDFTIndexMap(logN, logSlots, maxDepth int, forward, bitreversed bool) (rotationMap []map[int]bool) { - - var level, depth, nextLevel int - - level = logSlots - - rotationMap = make([]map[int]bool, maxDepth) - - // We compute the chain of merge in order or reverse order depending if its DFT or InvDFT because - // the way the levels are collapsed has an impact on the total number of rotations and keys to be - // stored. Ex. instead of using 255 + 64 plaintext vectors, we can use 127 + 128 plaintext vectors - // by reversing the order of the merging. - merge := make([]int, maxDepth) - for i := 0; i < maxDepth; i++ { - - depth = int(math.Ceil(float64(level) / float64(maxDepth-i))) - - if forward { - merge[i] = depth - } else { - merge[len(merge)-i-1] = depth - - } - - level -= depth - } - - level = logSlots - for i := 0; i < maxDepth; i++ { - - if logSlots < logN-1 && !forward && i == 0 { - - // Special initial matrix for the repacking before SlotsToCoeffs - rotationMap[i] = genWfftRepackIndexMap(logSlots, level) - - // Merges this special initial matrix with the first layer of SlotsToCoeffs DFT - rotationMap[i] = nextLevelfftIndexMap(rotationMap[i], logSlots, 2< 1< 1<>1 { - mat[i], mat[N-i] = mat[N-i], mat[i] - } - } -} - -func conjugateDiagMatrix(mat map[int][]complex128) { - for i := range mat { - - for j := range mat[i] { - c := mat[i][j] - mat[i][j] = complex(real(c), -imag(c)) - } - } -} - -func genBitReverseDiagMatrix(logN int) (diagMat map[int][]complex128) { - - var N, iRev, diff int - - diagMat = make(map[int][]complex128) - - N = 1 << logN - - for i := 0; i < N; i++ { - iRev = int(utils.BitReverse64(uint64(i), uint64(logN))) - - diff = (i - iRev) & (N - 1) - - if diagMat[diff] == nil { - diagMat[diff] = make([]complex128, N) - } - - diagMat[diff][iRev] = complex(1, 0) - } - - return -} - -func addToDiagMatrix(diagMat map[int][]complex128, index int, vec []complex128) { - if diagMat[index] == nil { - diagMat[index] = vec - } else { - diagMat[index] = add(diagMat[index], vec) - } -} - -func rotate(x []complex128, n int) (y []complex128) { - - y = make([]complex128, len(x)) - - mask := int(len(x) - 1) - - // Rotates to the left - for i := 0; i < len(x); i++ { - y[i] = x[(i+n)&mask] - } - - return -} - -func mul(a, b []complex128) (res []complex128) { - - res = make([]complex128, len(a)) - - for i := 0; i < len(a); i++ { - res[i] = a[i] * b[i] - } - - return -} - -func add(a, b []complex128) (res []complex128) { - - res = make([]complex128, len(a)) - - for i := 0; i < len(a); i++ { - res[i] = a[i] + b[i] - } - - return -} diff --git a/ckks/bootstrap_test.go b/ckks/bootstrap_test.go deleted file mode 100644 index 1b39101a..00000000 --- a/ckks/bootstrap_test.go +++ /dev/null @@ -1,543 +0,0 @@ -package ckks - -import ( - "github.com/ldsec/lattigo/v2/ckks/bettersine" - "github.com/ldsec/lattigo/v2/rlwe" - "github.com/ldsec/lattigo/v2/utils" - "math" - "math/cmplx" - "runtime" - "testing" -) - -func TestBootstrap(t *testing.T) { - - if !*testBootstrapping { - t.Skip("skipping bootstrapping test") - } - - if runtime.GOARCH == "wasm" { - t.Skip("skipping bootstrapping tests for GOARCH=wasm") - } - - var testContext = new(testParams) - - paramSet := 0 - - bootstrapParams := DefaultBootstrapParams[paramSet : paramSet+1] - - for paramSet := range bootstrapParams { - - btpParams := bootstrapParams[paramSet] - - // Insecure params for fast testing only - if !*flagLongTest { - btpParams.LogN = 14 - btpParams.LogSlots = 13 - } - - // Tests homomorphic modular reduction encoding and bootstrapping on sparse slots - params, err := btpParams.Params() - if err != nil { - panic(err) - } - - if testContext, err = genTestParams(params, btpParams.H); err != nil { // TODO: setting the param.scale field is not something the user can do - panic(err) - } - - for _, testSet := range []func(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T){ - testEvalSine, - } { - testSet(testContext, btpParams, t) - runtime.GC() - } - - for _, testSet := range []func(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T){ - testCoeffsToSlots, - testSlotsToCoeffs, - testbootstrap, - } { - testSet(testContext, btpParams, t) - runtime.GC() - } - - if !*flagLongTest { - btpParams.LogSlots = 12 - } - - // Tests homomorphic encoding and bootstrapping on full slots - params, err = btpParams.Params() - if err != nil { - panic(err) - } - - if testContext, err = genTestParams(params, btpParams.H); err != nil { // TODO: setting the param.scale field is not something the user can do - panic(err) - } - - for _, testSet := range []func(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T){ - testCoeffsToSlots, - testSlotsToCoeffs, - testbootstrap, - } { - testSet(testContext, btpParams, t) - runtime.GC() - } - } -} - -func testEvalSine(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T) { - - t.Run(testString(testContext, "Sin/"), func(t *testing.T) { - - var err error - - eval := testContext.evaluator - - DefaultScale := testContext.params.Scale() - - SineScale := btpParams.SineEvalModuli.ScalingFactor - - testContext.params.scale = SineScale - eval.(*evaluator).scale = SineScale - - deg := 127 - K := float64(15) - - values, _, ciphertext := newTestVectorsSineBootstrapp(testContext, btpParams, testContext.encryptorSk, -K+1, K-1, t) - eval.DropLevel(ciphertext, btpParams.CtSDepth(true)-1) - - cheby := Approximate(sin2pi2pi, -complex(K, 0), complex(K, 0), deg) - - for i := range values { - values[i] = sin2pi2pi(values[i]) - } - - eval.MultByConst(ciphertext, 2/(cheby.b-cheby.a), ciphertext) - eval.AddConst(ciphertext, (-cheby.a-cheby.b)/(cheby.b-cheby.a), ciphertext) - eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext) - - if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby, ciphertext.Scale); err != nil { - t.Error(err) - } - - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) - - testContext.params.scale = DefaultScale - eval.(*evaluator).scale = DefaultScale - }) - - t.Run(testString(testContext, "Cos1/"), func(t *testing.T) { - - var err error - - eval := testContext.evaluator - - DefaultScale := testContext.params.Scale() - - SineScale := btpParams.SineEvalModuli.ScalingFactor - - testContext.params.scale = SineScale - eval.(*evaluator).scale = SineScale - - K := 25 - deg := 63 - dev := btpParams.MessageRatio - scNum := 2 - - scFac := complex(float64(int(1< 0 { - sqrt2pi = math.Pow(1, 1.0/real(scFac)) - } else { - sqrt2pi = math.Pow(0.15915494309189535, 1.0/real(scFac)) - } - - for i := range cheby.coeffs { - cheby.coeffs[i] *= complex(sqrt2pi, 0) - } - - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) - - for i := range values { - - values[i] = cmplx.Cos(6.283185307179586 * (1 / scFac) * (values[i] - 0.25)) - - for j := 0; j < scNum; j++ { - values[i] = 2*values[i]*values[i] - 1 - } - - if btpParams.ArcSineDeg == 0 { - values[i] /= 6.283185307179586 - } - } - - eval.AddConst(ciphertext, -0.25, ciphertext) - - eval.MultByConst(ciphertext, 2/((cheby.b-cheby.a)*scFac), ciphertext) - eval.AddConst(ciphertext, (-cheby.a-cheby.b)/(cheby.b-cheby.a), ciphertext) - eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext) - - if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby, ciphertext.Scale); err != nil { - t.Error(err) - } - - for i := 0; i < scNum; i++ { - sqrt2pi *= sqrt2pi - eval.MulRelin(ciphertext, ciphertext, ciphertext) - eval.Add(ciphertext, ciphertext, ciphertext) - eval.AddConst(ciphertext, -sqrt2pi, ciphertext) - eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext) - } - - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) - - testContext.params.scale = DefaultScale - eval.(*evaluator).scale = DefaultScale - - }) - - t.Run(testString(testContext, "Cos2/"), func(t *testing.T) { - - if len(btpParams.SineEvalModuli.Qi) < 12 { - t.Skip() - } - - var err error - - eval := testContext.evaluator - - DefaultScale := testContext.params.Scale() - - SineScale := btpParams.SineEvalModuli.ScalingFactor - - testContext.params.scale = SineScale - eval.(*evaluator).scale = SineScale - - K := 325 - deg := 255 - scNum := 4 - - scFac := complex(float64(int(1<>1, 0; i < params.Slots(); i, jdx, idx = i+1, jdx+gap, idx+gap { - valuesFloat[idx] = real(values[i]) - valuesFloat[jdx] = imag(values[i]) - } - - // Encodes coefficient-wise and encrypts the test vector - plaintext := NewPlaintext(params, params.MaxLevel(), params.Scale()) - testContext.encoder.EncodeCoeffs(valuesFloat, plaintext) - ciphertext := testContext.encryptorPk.EncryptNew(plaintext) - - // Creates an evaluator with the rotation keys - eval := testContext.evaluator.WithKey(rlwe.EvaluationKey{Rlk: testContext.rlk, Rtks: rotKey}) - - // Applies the homomorphic DFT - ct0, ct1 := CoeffsToSlots(ciphertext, CoeffsToSlotMatrices, eval) - - // Checks against the original coefficients - var coeffsReal, coeffsImag []complex128 - if params.LogSlots() < params.LogN()-1 { - coeffsRealImag := testContext.encoder.DecodePublic(testContext.decryptor.DecryptNew(ct0), params.LogSlots()+1, 0) - coeffsReal = coeffsRealImag[:params.Slots()] - coeffsImag = coeffsRealImag[params.Slots():] - } else { - coeffsReal = testContext.encoder.DecodePublic(testContext.decryptor.DecryptNew(ct0), params.LogSlots(), 0) - coeffsImag = testContext.encoder.DecodePublic(testContext.decryptor.DecryptNew(ct1), params.LogSlots(), 0) - } - - verifyTestVectors(testContext, nil, valuesReal, coeffsReal, params.LogSlots(), 0, t) - verifyTestVectors(testContext, nil, valuesImag, coeffsImag, params.LogSlots(), 0, t) - }) -} - -func testSlotsToCoeffs(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T) { - t.Run(testString(testContext, "SlotsToCoeffs/"), func(t *testing.T) { - - // This test tests the homomorphic decoding - // It first generates a complex vector of size 2*slots - // if 2*slots == N, then two vectors are generated, one for the real part, one for the imaginary part : - // - // vReal and vReal (both floating point vectors because the encoding always result in a real vector) - // - // Then encode and encrypts the vectors : - // - // Enc(Ecd(vReal)) and Enc(Ecd(vImag)) - // - // And applies the homomorphic decoding (will merge both vectors if there was two) - // - // Enc(FFT(Ecd(vReal) + i*Ecd(vImag))) - // - // The result should be the decoding of the initial vectors bit-reversed - // - // Enc(FFT(Ecd(vReal) + i*Ecd(vImag))) = Enc(BitReverse(Dcd(Ecd(vReal + i*vImag)))) - // - // The first N/2 slots of the plaintext will be the real part while the last N/2 the imaginary part - // In case of 2*slots < N, then there is a gap of N/(2*slots) between each values - - params := testContext.params - - // Generates the encoding matrices - SlotsToCoeffsMatrix := btpParams.GenSlotsToCoeffsMatrix(1.0, testContext.encoder) - - // Gets the rotations indexes for SlotsToCoeffs - rotations := btpParams.RotationsForSlotsToCoeffs(params.LogSlots()) - - // Generates the rotation keys - rotKey := testContext.kgen.GenRotationKeysForRotations(rotations, true, testContext.sk) - - // Creates an evaluator with the rotation keys - eval := testContext.evaluator.WithKey(rlwe.EvaluationKey{Rlk: testContext.rlk, Rtks: rotKey}) - - // Generates the n first slots of the test vector (real part to encode) - valuesReal := make([]complex128, params.Slots()) - for i := range valuesReal { - valuesReal[i] = complex(float64(i+1)/float64(params.Slots()), 0) - } - - // Generates the n first slots of the test vector (imaginary part to encode) - valuesImag := make([]complex128, params.Slots()) - for i := range valuesImag { - valuesImag[i] = complex(1+float64(i+1)/float64(params.Slots()), 0) - } - - // If sparse, there there is the space to store both vectors in one - if params.LogSlots() < params.LogN()-1 { - for i := range valuesReal { - valuesReal[i] += complex(0, real(valuesImag[i])) - } - } - - // Encodes and encrypts the test vectors - logSlots := params.LogSlots() - if params.LogSlots() < params.LogN()-1 { - logSlots++ - } - encoder := testContext.encoder.(*encoderComplex128) - plaintext := NewPlaintext(params, params.MaxLevel(), params.Scale()) - encoder.Encode(plaintext, valuesReal, logSlots) - ct0 := testContext.encryptorPk.EncryptNew(plaintext) - var ct1 *Ciphertext - if params.LogSlots() == params.LogN()-1 { - testContext.encoder.Encode(plaintext, valuesImag, logSlots) - ct1 = testContext.encryptorPk.EncryptNew(plaintext) - } - - // Applies the homomorphic DFT - res := SlotsToCoeffs(ct0, ct1, SlotsToCoeffsMatrix, eval) - - // Decrypt and decode in the coefficient domain - coeffsFloat := testContext.encoder.DecodeCoeffsPublic(testContext.decryptor.DecryptNew(res), 0) - - // Extracts the coefficients and construct the complex vector - // This is simply coefficient ordering - valuesTest := make([]complex128, params.Slots()) - gap := params.N() / (2 * params.Slots()) - for i, idx := 0, 0; i < params.Slots(); i, idx = i+1, idx+gap { - valuesTest[i] = complex(coeffsFloat[idx], coeffsFloat[idx+(params.N()>>1)]) - } - - // The result is always returned as a single complex vector, so if full-packing (2 initial vectors) - // then repacks both vectors together - if params.LogSlots() == params.LogN()-1 { - for i := range valuesReal { - valuesReal[i] += complex(0, real(valuesImag[i])) - } - } - - // Result is bit-reversed, so applies the bit-reverse permutation on the reference vector - sliceBitReverseInPlaceComplex128(valuesReal, params.Slots()) - - verifyTestVectors(testContext, testContext.decryptor, valuesReal, valuesTest, params.LogSlots(), 0, t) - }) -} - -func testbootstrap(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T) { - - t.Run(testString(testContext, "Bootstrapping/FullCircuit/"), func(t *testing.T) { - - params := testContext.params - - rotations := btpParams.RotationsForBootstrapping(testContext.params.LogSlots()) - rotkeys := testContext.kgen.GenRotationKeysForRotations(rotations, true, testContext.sk) - btpKey := BootstrappingKey{testContext.rlk, rotkeys} - - btp, err := NewBootstrapper(testContext.params, btpParams, btpKey) - if err != nil { - panic(err) - } - - values := make([]complex128, 1< 2 { - values[2] = complex(0.9238795325112867, 0.3826834323650898) - values[3] = complex(0.9238795325112867, 0.3826834323650898) - } - - plaintext := NewPlaintext(params, params.MaxLevel(), params.Scale()) - testContext.encoder.Encode(plaintext, values, params.LogSlots()) - - ciphertext := testContext.encryptorPk.EncryptNew(plaintext) - - eval := testContext.evaluator - for ciphertext.Level() != 0 { - eval.DropLevel(ciphertext, 1) - } - - for i := 0; i < 1; i++ { - - ciphertext = btp.Bootstrapp(ciphertext) - //testContext.evaluator.SetScale(ciphertext, testContext.params.scale) - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) - } - - }) -} - -func newTestVectorsSineBootstrapp(testContext *testParams, btpParams *BootstrappingParameters, encryptor Encryptor, a, b float64, t *testing.T) (values []complex128, plaintext *Plaintext, ciphertext *Ciphertext) { - - logSlots := testContext.params.LogSlots() - - values = make([]complex128, 1< must use SinType = Cos") - } - - btp = newBootstrapper(params, btpParams) - - btp.BootstrappingKey = &BootstrappingKey{btpKey.Rlk, btpKey.Rtks} - if err = btp.CheckKeys(); err != nil { - return nil, fmt.Errorf("invalid bootstrapping key: %w", err) - } - btp.evaluator = btp.evaluator.WithKey(rlwe.EvaluationKey{Rlk: btpKey.Rlk, Rtks: btpKey.Rtks}).(*evaluator) - - return btp, nil -} - -// newBootstrapper is a constructor of "dummy" bootstrapper to enable the generation of bootstrapping-related constants -// without providing a bootstrapping key. To be replaced by a proper factorization of the bootstrapping pre-computations. -func newBootstrapper(params Parameters, btpParams *BootstrappingParameters) (btp *Bootstrapper) { - btp = new(Bootstrapper) - - btp.params = params - btp.BootstrappingParameters = *btpParams.Copy() - - btp.dslots = params.Slots() - btp.logdslots = params.LogSlots() - if params.LogSlots() < params.MaxLogSlots() { - btp.dslots <<= 1 - btp.logdslots++ - } - - btp.prescale = math.Exp2(math.Round(math.Log2(float64(params.Q()[0]) / btp.MessageRatio))) - btp.sinescale = math.Exp2(math.Round(math.Log2(btp.SineEvalModuli.ScalingFactor))) - btp.postscale = btp.sinescale / btp.MessageRatio - - btp.encoder = NewEncoder(params) - btp.evaluator = NewEvaluator(params, rlwe.EvaluationKey{}).(*evaluator) // creates an evaluator without keys for genDFTMatrices - - btp.genSinePoly() - btp.genDFTMatrices() - - btp.ctxpool = NewCiphertext(params, 1, params.MaxLevel(), 0) - - return btp -} - -// CheckKeys checks if all the necessary keys are present -func (btp *Bootstrapper) CheckKeys() (err error) { - - if btp.Rlk == nil { - return fmt.Errorf("relinearization key is nil") - } - - if btp.Rtks == nil { - return fmt.Errorf("rotation key is nil") - } - - rotMissing := []int{} - for _, i := range btp.rotKeyIndex { - galEl := btp.params.GaloisElementForColumnRotationBy(int(i)) - if _, generated := btp.Rtks.Keys[galEl]; !generated { - rotMissing = append(rotMissing, i) - } - } - - if len(rotMissing) != 0 { - return fmt.Errorf("rotation key(s) missing: %d", rotMissing) - } - - return nil -} - -// AddMatrixRotToList adds the rotations neede to evaluate pVec to the list rotations -func AddMatrixRotToList(pVec *PtDiagMatrix, rotations []int, slots int, repack bool) []int { - - if pVec.naive { - for j := range pVec.Vec { - if !utils.IsInSliceInt(j, rotations) { - rotations = append(rotations, j) - } - } - } else { - var index int - for j := range pVec.Vec { - - N1 := pVec.N1 - - index = ((j / N1) * N1) - - if repack { - // Sparse repacking, occurring during the first DFT matrix of the CoeffsToSlots. - index &= 2*slots - 1 - } else { - // Other cases - index &= slots - 1 - } - - if index != 0 && !utils.IsInSliceInt(index, rotations) { - rotations = append(rotations, index) - } - - index = j & (N1 - 1) - - if index != 0 && !utils.IsInSliceInt(index, rotations) { - rotations = append(rotations, index) - } - } - } - - return rotations -} - -func (btp *Bootstrapper) genDFTMatrices() { - - a := real(btp.sineEvalPoly.a) - b := real(btp.sineEvalPoly.b) - n := float64(btp.params.N()) - qDiff := float64(btp.params.Q()[0]) / math.Exp2(math.Round(math.Log2(float64(btp.params.Q()[0])))) - - // Change of variable for the evaluation of the Chebyshev polynomial + cancelling factor for the DFT and SubSum + evantual scaling factor for the double angle formula - btp.coeffsToSlotsDiffScale = complex(math.Pow(2.0/((b-a)*n*btp.scFac*qDiff), 1.0/float64(btp.CtSDepth(false))), 0) - - // Rescaling factor to set the final ciphertext to the desired scale - btp.slotsToCoeffsDiffScale = complex(math.Pow((qDiff*btp.params.Scale())/btp.postscale, 1.0/float64(btp.StCDepth(false))), 0) - - // CoeffsToSlots vectors - btp.pDFTInv = btp.BootstrappingParameters.GenCoeffsToSlotsMatrix(btp.coeffsToSlotsDiffScale, btp.encoder) - - // SlotsToCoeffs vectors - btp.pDFT = btp.BootstrappingParameters.GenSlotsToCoeffsMatrix(btp.slotsToCoeffsDiffScale, btp.encoder) - - // List of the rotation key values to needed for the bootstrapp - btp.rotKeyIndex = []int{} - - //SubSum rotation needed X -> Y^slots rotations - for i := btp.params.LogSlots(); i < btp.params.MaxLogSlots(); i++ { - if !utils.IsInSliceInt(1< 0 { - btp.sqrt2pi = 1.0 - - coeffs := make([]complex128, btp.ArcSineDeg+1) - - coeffs[1] = 0.15915494309189535 - - for i := 3; i < btp.ArcSineDeg+1; i += 2 { - - coeffs[i] = coeffs[i-2] * complex(float64(i*i-4*i+4)/float64(i*i-i), 0) - - } - - btp.arcSinePoly = NewPoly(coeffs) - - } else { - btp.sqrt2pi = math.Pow(0.15915494309189535, 1.0/btp.scFac) - } - - if btp.SinType == Sin { - - btp.sineEvalPoly = Approximate(sin2pi2pi, -complex(float64(K)/btp.scFac, 0), complex(float64(K)/btp.scFac, 0), deg) - - } else if btp.SinType == Cos1 { - - btp.sineEvalPoly = new(ChebyshevInterpolation) - - btp.sineEvalPoly.coeffs = bettersine.Approximate(K, deg, btp.MessageRatio, int(btp.SinRescal)) - - btp.sineEvalPoly.maxDeg = btp.sineEvalPoly.Degree() - btp.sineEvalPoly.a = complex(float64(-K)/btp.scFac, 0) - btp.sineEvalPoly.b = complex(float64(K)/btp.scFac, 0) - btp.sineEvalPoly.lead = true - - } else if btp.SinType == Cos2 { - - btp.sineEvalPoly = Approximate(cos2pi, -complex(float64(K)/btp.scFac, 0), complex(float64(K)/btp.scFac, 0), deg) - - } else { - panic("Bootstrapper -> invalid sineType") - } - - for i := range btp.sineEvalPoly.coeffs { - btp.sineEvalPoly.coeffs[i] *= complex(btp.sqrt2pi, 0) - } -} diff --git a/ckks/bootstrapping/bootstrap.go b/ckks/bootstrapping/bootstrap.go new file mode 100644 index 00000000..5542a9c9 --- /dev/null +++ b/ckks/bootstrapping/bootstrap.go @@ -0,0 +1,122 @@ +package bootstrapping + +import ( + "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/ring" + "math" +) + +// Bootstrapp re-encrypt a ciphertext at lvl Q0 to a ciphertext at MaxLevel-k where k is the depth of the bootstrapping circuit. +// If the input ciphertext level is zero, the input scale must be an exact power of two smaller or equal to round(Q0/2^{10}). +// If the input ciphertext is at level one or more, the input scale does not need to be an exact power of two as one level +// can be used to do a scale matching. +func (btp *Bootstrapper) Bootstrapp(ctIn *ckks.Ciphertext) (ctOut *ckks.Ciphertext) { + + ctOut = ctIn.CopyNew() + + bootstrappingScale := math.Exp2(math.Round(math.Log2(btp.params.QiFloat64(0) / btp.evalModPoly.MessageRatio()))) + + // Drops the level to 1 + for ctOut.Level() > 1 { + btp.DropLevel(ctOut, 1) + } + + // Brings the ciphertext scale to Q0/MessageRatio + if ctOut.Level() == 1 { + + // If one level is available, then uses it to match the scale + btp.SetScale(ctOut, bootstrappingScale) + + // Then drops to level 0 + for ctOut.Level() != 0 { + btp.DropLevel(ctOut, 1) + } + + } else { + + // Does an integer constant mult by round((Q0/Delta_m)/ctscle) + if bootstrappingScale < ctOut.Scale { + panic("ciphetext scale > q/||m||)") + } + + btp.ScaleUp(ctOut, math.Round(bootstrappingScale/ctOut.Scale), ctOut) + } + + // Step 1 : Extend the basis from q to Q + ctOut = btp.modUpFromQ0(ctOut) + + // Brings the ciphertext scale to sineQi/(Q0/scale) if Q0 < sineQi + // Does it after modUp to avoid plaintext overflow + // Reduces the additive error of the next steps + btp.ScaleUp(ctOut, math.Round((btp.evalModPoly.ScalingFactor()/btp.evalModPoly.MessageRatio())/ctOut.Scale), ctOut) + + //SubSum X -> (N/dslots) * Y^dslots + btp.Trace(ctOut, btp.params.LogSlots(), btp.params.LogN()-1, ctOut) + + // Step 2 : CoeffsToSlots (Homomorphic encoding) + ctReal, ctImag := btp.CoeffsToSlotsNew(ctOut, btp.ctsMatrices) + + // Step 3 : EvalMod (Homomorphic modular reduction) + // ctReal = Ecd(real) + // ctImag = Ecd(imag) + // If n < N/2 then ctReal = Ecd(real|imag) + ctReal = btp.EvalModNew(ctReal, btp.evalModPoly) + ctReal.Scale = btp.params.Scale() + + if ctImag != nil { + ctImag = btp.EvalModNew(ctImag, btp.evalModPoly) + ctImag.Scale = btp.params.Scale() + } + + // Step 4 : SlotsToCoeffs (Homomorphic decoding) + ctOut = btp.SlotsToCoeffsNew(ctReal, ctImag, btp.stcMatrices) + + return +} + +func (btp *Bootstrapper) modUpFromQ0(ct *ckks.Ciphertext) *ckks.Ciphertext { + + ringQ := btp.params.RingQ() + + for i := range ct.Value { + ringQ.InvNTTLvl(ct.Level(), ct.Value[i], ct.Value[i]) + } + + // Extend the ciphertext with zero polynomials. + for u := range ct.Value { + ct.Value[u].Coeffs = append(ct.Value[u].Coeffs, make([][]uint64, btp.params.MaxLevel())...) + for i := 1; i < btp.params.MaxLevel()+1; i++ { + ct.Value[u].Coeffs[i] = make([]uint64, btp.params.N()) + } + } + + //Centers the values around Q0 and extends the basis from Q0 to QL + Q := ringQ.Modulus[0] + bredparams := ringQ.BredParams + + var coeff, qi uint64 + for u := range ct.Value { + + for j := 0; j < btp.params.N(); j++ { + + coeff = ct.Value[u].Coeffs[0][j] + + for i := 1; i < btp.params.MaxLevel()+1; i++ { + + qi = ringQ.Modulus[i] + + if coeff > (Q >> 1) { + ct.Value[u].Coeffs[i][j] = qi - ring.BRedAdd(Q-coeff, qi, bredparams[i]) + } else { + ct.Value[u].Coeffs[i][j] = ring.BRedAdd(coeff, qi, bredparams[i]) + } + } + } + } + + for i := range ct.Value { + ringQ.NTTLvl(ct.Level(), ct.Value[i], ct.Value[i]) + } + + return ct +} diff --git a/ckks/bootstrapping/bootstrap_bench_test.go b/ckks/bootstrapping/bootstrap_bench_test.go new file mode 100644 index 00000000..6175fedb --- /dev/null +++ b/ckks/bootstrapping/bootstrap_bench_test.go @@ -0,0 +1,82 @@ +package bootstrapping + +import ( + "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/rlwe" + "math" + "testing" + "time" +) + +func BenchmarkBootstrapp(b *testing.B) { + + var err error + var btp *Bootstrapper + + paramSet := 0 + + ckksParams := DefaultCKKSParameters[paramSet] + btpParams := DefaultParameters[paramSet] + + params, err := ckks.NewParametersFromLiteral(ckksParams) + if err != nil { + panic(err) + } + + kgen := ckks.NewKeyGenerator(params) + sk := kgen.GenSecretKeySparse(btpParams.H) + rlk := kgen.GenRelinearizationKey(sk, 2) + + rotations := btpParams.RotationsForBootstrapping(params.LogN(), params.LogSlots()) + rotkeys := kgen.GenRotationKeysForRotations(rotations, true, sk) + + if btp, err = NewBootstrapper(params, btpParams, rlwe.EvaluationKey{Rlk: rlk, Rtks: rotkeys}); err != nil { + panic(err) + } + + b.Run(ParamsToString(params, "Bootstrapp/"), func(b *testing.B) { + for i := 0; i < b.N; i++ { + + bootstrappingScale := math.Exp2(math.Round(math.Log2(btp.params.QiFloat64(0) / btp.evalModPoly.MessageRatio()))) + + b.StopTimer() + ct := ckks.NewCiphertext(params, 1, 0, bootstrappingScale) + b.StartTimer() + + var t time.Time + var ct0, ct1 *ckks.Ciphertext + + // ModUp ct_{Q_0} -> ct_{Q_L} + t = time.Now() + ct = btp.modUpFromQ0(ct) + b.Log("After ModUp :", time.Since(t), ct.Level(), ct.Scale) + + //SubSum X -> (N/dslots) * Y^dslots + t = time.Now() + btp.Trace(ct, btp.params.LogSlots(), btp.params.LogN()-1, ct) + b.Log("After SubSum :", time.Since(t), ct.Level(), ct.Scale) + + // Part 1 : Coeffs to slots + t = time.Now() + ct0, ct1 = btp.CoeffsToSlotsNew(ct, btp.ctsMatrices) + b.Log("After CtS :", time.Since(t), ct0.Level(), ct0.Scale) + + // Part 2 : SineEval + t = time.Now() + ct0 = btp.EvalModNew(ct0, btp.evalModPoly) + ct0.Scale = btp.params.Scale() + + if ct1 != nil { + ct1 = btp.EvalModNew(ct1, btp.evalModPoly) + ct1.Scale = btp.params.Scale() + } + b.Log("After Sine :", time.Since(t), ct0.Level(), ct0.Scale) + + // Part 3 : Slots to coeffs + t = time.Now() + ct0 = btp.SlotsToCoeffsNew(ct0, ct1, btp.stcMatrices) + ct0.Scale = math.Exp2(math.Round(math.Log2(ct0.Scale))) + b.Log("After StC :", time.Since(t), ct0.Level(), ct0.Scale) + } + }) +} diff --git a/ckks/bootstrapping/bootstrap_params.go b/ckks/bootstrapping/bootstrap_params.go new file mode 100644 index 00000000..9c2880a5 --- /dev/null +++ b/ckks/bootstrapping/bootstrap_params.go @@ -0,0 +1,499 @@ +package bootstrapping + +import ( + "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/ckks/advanced" + "github.com/ldsec/lattigo/v2/rlwe" + "github.com/ldsec/lattigo/v2/utils" +) + +// Parameters is a struct for the default bootstrapping parameters +type Parameters struct { + SlotsToCoeffsParameters advanced.EncodingMatrixLiteral + EvalModParameters advanced.EvalModLiteral + CoeffsToSlotsParameters advanced.EncodingMatrixLiteral + H int // Hamming weight of the secret key +} + +// MarshalBinary encode the target Parameters on a slice of bytes. +func (p *Parameters) MarshalBinary() (data []byte, err error) { + data = []byte{} + tmp := []byte{} + + if tmp, err = p.SlotsToCoeffsParameters.MarshalBinary(); err != nil { + return nil, err + } + + data = append(data, uint8(len(tmp))) + data = append(data, tmp...) + + if tmp, err = p.EvalModParameters.MarshalBinary(); err != nil { + return nil, err + } + + data = append(data, uint8(len(tmp))) + data = append(data, tmp...) + + if tmp, err = p.CoeffsToSlotsParameters.MarshalBinary(); err != nil { + return nil, err + } + + data = append(data, uint8(len(tmp))) + data = append(data, tmp...) + + tmp = make([]byte, 4) + tmp[0] = uint8(p.H >> 24) + tmp[1] = uint8(p.H >> 16) + tmp[2] = uint8(p.H >> 8) + tmp[3] = uint8(p.H >> 0) + data = append(data, tmp...) + return +} + +// UnmarshalBinary decodes a slice of bytes on the target Parameters. +func (p *Parameters) UnmarshalBinary(data []byte) (err error) { + + pt := 0 + dLen := int(data[pt]) + + if err := p.SlotsToCoeffsParameters.UnmarshalBinary(data[pt+1 : pt+dLen+1]); err != nil { + return err + } + + pt += dLen + pt++ + dLen = int(data[pt]) + + if err := p.EvalModParameters.UnmarshalBinary(data[pt+1 : pt+dLen+1]); err != nil { + return err + } + + pt += dLen + pt++ + dLen = int(data[pt]) + + if err := p.CoeffsToSlotsParameters.UnmarshalBinary(data[pt+1 : pt+dLen+1]); err != nil { + return err + } + + pt += dLen + pt++ + dLen = int(data[pt]) + + p.H = int(data[pt])<<24 | int(data[pt+1])<<16 | int(data[pt+2])<<8 | int(data[pt+3]) + + return +} + +// RotationsForBootstrapping returns the list of rotations performed during the Bootstrapping operation. +func (p *Parameters) RotationsForBootstrapping(LogN, LogSlots int) (rotations []int) { + + // List of the rotation key values to needed for the bootstrapp + rotations = []int{} + + slots := 1 << LogSlots + dslots := slots + if LogSlots < LogN-1 { + dslots <<= 1 + } + + //SubSum rotation needed X -> Y^slots rotations + for i := LogSlots; i < LogN-1; i++ { + if !utils.IsInSliceInt(1< 2 { + values[2] = complex(0.9238795325112867, 0.3826834323650898) + values[3] = complex(0.9238795325112867, 0.3826834323650898) + } + + plaintext := ckks.NewPlaintext(params, 0, params.Scale()) + encoder.Encode(plaintext, values, params.LogSlots()) + + ciphertexts := make([]*ckks.Ciphertext, 2) + bootstrappers := make([]*Bootstrapper, 2) + for i := range ciphertexts { + ciphertexts[i] = encryptor.EncryptNew(plaintext) + if i == 0 { + bootstrappers[i] = btp + } else { + bootstrappers[i] = bootstrappers[0].ShallowCopy() + } + } + + var wg sync.WaitGroup + wg.Add(2) + for i := range ciphertexts { + go func(index int) { + ciphertexts[index] = bootstrappers[index].Bootstrapp(ciphertexts[index]) + //btp.SetScale(ciphertexts[index], params.Scale()) + wg.Done() + }(i) + } + wg.Wait() + + for i := range ciphertexts { + verifyTestVectors(params, encoder, decryptor, values, ciphertexts[i], params.LogSlots(), 0, t) + } + }) +} + +func verifyTestVectors(params ckks.Parameters, encoder ckks.Encoder, decryptor ckks.Decryptor, valuesWant []complex128, element interface{}, logSlots int, bound float64, t *testing.T) { + precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, element, logSlots, bound) + t.Log(precStats.String()) +} diff --git a/ckks/bootstrapping/bootstrapper.go b/ckks/bootstrapping/bootstrapper.go new file mode 100644 index 00000000..c79eea84 --- /dev/null +++ b/ckks/bootstrapping/bootstrapper.go @@ -0,0 +1,137 @@ +package bootstrapping + +import ( + "fmt" + "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/ckks/advanced" + "github.com/ldsec/lattigo/v2/rlwe" + "math" +) + +// Bootstrapper is a struct to stores a memory pool the plaintext matrices +// the polynomial approximation and the keys for the bootstrapping. +type Bootstrapper struct { + advanced.Evaluator + *bootstrapperBase +} + +type bootstrapperBase struct { + Parameters + params ckks.Parameters + + dslots int // Number of plaintext slots after the re-encoding + logdslots int + + evalModPoly advanced.EvalModPoly + stcMatrices advanced.EncodingMatrix + ctsMatrices advanced.EncodingMatrix +} + +// NewBootstrapper creates a new Bootstrapper. +func NewBootstrapper(params ckks.Parameters, btpParams Parameters, btpKey rlwe.EvaluationKey) (btp *Bootstrapper, err error) { + + if btpParams.EvalModParameters.SineType == advanced.Sin && btpParams.EvalModParameters.DoubleAngle != 0 { + return nil, fmt.Errorf("cannot use double angle formul for SineType = Sin -> must use SineType = Cos") + } + + if btpParams.CoeffsToSlotsParameters.LevelStart-btpParams.CoeffsToSlotsParameters.Depth(true) != btpParams.EvalModParameters.LevelStart { + return nil, fmt.Errorf("starting level and depth of CoeffsToSlotsParameters inconsistent starting level of SineEvalParameters") + } + + if btpParams.EvalModParameters.LevelStart-btpParams.EvalModParameters.Depth() != btpParams.SlotsToCoeffsParameters.LevelStart { + return nil, fmt.Errorf("starting level and depth of SineEvalParameters inconsistent starting level of CoeffsToSlotsParameters") + } + + btp = new(Bootstrapper) + btp.bootstrapperBase = newBootstrapperBase(params, btpParams, btpKey) + + if err = btp.bootstrapperBase.CheckKeys(btpKey); err != nil { + return nil, fmt.Errorf("invalid bootstrapping key: %w", err) + } + + btp.Evaluator = advanced.NewEvaluator(params, btpKey) + + return +} + +// ShallowCopy creates a shallow copy of this Bootstrapper in which all the read-only data-structures are +// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned +// Bootstrapper can be used concurrently. +func (btp *Bootstrapper) ShallowCopy() *Bootstrapper { + return &Bootstrapper{ + Evaluator: btp.Evaluator.ShallowCopy(), + bootstrapperBase: btp.bootstrapperBase, + } +} + +// CheckKeys checks if all the necessary keys are present in the instantiated Bootstrapper +func (bb *bootstrapperBase) CheckKeys(btpKey rlwe.EvaluationKey) (err error) { + + if btpKey.Rlk == nil { + return fmt.Errorf("relinearization key is nil") + } + + if btpKey.Rtks == nil { + return fmt.Errorf("rotation key is nil") + } + + rotKeyIndex := []int{} + rotKeyIndex = append(rotKeyIndex, bb.params.RotationsForTrace(bb.params.LogSlots(), bb.params.MaxLogSlots())...) + rotKeyIndex = append(rotKeyIndex, bb.CoeffsToSlotsParameters.Rotations(bb.params.LogN(), bb.params.LogSlots())...) + rotKeyIndex = append(rotKeyIndex, bb.SlotsToCoeffsParameters.Rotations(bb.params.LogN(), bb.params.LogSlots())...) + + rotMissing := []int{} + for _, i := range rotKeyIndex { + galEl := bb.params.GaloisElementForColumnRotationBy(int(i)) + if _, generated := btpKey.Rtks.Keys[galEl]; !generated { + rotMissing = append(rotMissing, i) + } + } + + if len(rotMissing) != 0 { + return fmt.Errorf("rotation key(s) missing: %d", rotMissing) + } + + return nil +} + +func newBootstrapperBase(params ckks.Parameters, btpParams Parameters, btpKey rlwe.EvaluationKey) (bb *bootstrapperBase) { + bb = new(bootstrapperBase) + bb.params = params + bb.Parameters = btpParams + + bb.dslots = params.Slots() + bb.logdslots = params.LogSlots() + if params.LogSlots() < params.MaxLogSlots() { + bb.dslots <<= 1 + bb.logdslots++ + } + + bb.evalModPoly = advanced.NewEvalModPolyFromLiteral(btpParams.EvalModParameters) + + scFac := bb.evalModPoly.ScFac() + K := bb.evalModPoly.K() / scFac + n := float64(bb.params.N()) + ctsDepth := float64(bb.CoeffsToSlotsParameters.Depth(false)) + stcDepth := float64(bb.SlotsToCoeffsParameters.Depth(false)) + + // Correcting factor for approximate division by Q + // The second correcting factor for approximate multiplication by Q is included in the coefficients of the EvalMod polynomials + qDiff := bb.evalModPoly.QDiff() + + encoder := ckks.NewEncoder(bb.params) + + // CoeffsToSlots vectors + // Change of variable for the evaluation of the Chebyshev polynomial + cancelling factor for the DFT and SubSum + evantual scaling factor for the double angle formula + coeffsToSlotsDiffScale := complex(math.Pow(1.0/(K*n*scFac*qDiff), 1.0/ctsDepth), 0) + bb.ctsMatrices = advanced.NewHomomorphicEncodingMatrixFromLiteral(bb.CoeffsToSlotsParameters, encoder, bb.params.LogN(), bb.params.LogSlots(), coeffsToSlotsDiffScale) + + // SlotsToCoeffs vectors + // Rescaling factor to set the final ciphertext to the desired scale + slotsToCoeffsDiffScale := complex(math.Pow(bb.params.Scale()/(bb.evalModPoly.ScalingFactor()/bb.evalModPoly.MessageRatio()), 1.0/stcDepth), 0) + bb.stcMatrices = advanced.NewHomomorphicEncodingMatrixFromLiteral(bb.SlotsToCoeffsParameters, encoder, bb.params.LogN(), bb.params.LogSlots(), slotsToCoeffsDiffScale) + + encoder = nil + + return +} diff --git a/ckks/chebyshev_interpolation.go b/ckks/chebyshev_interpolation.go index 483bb59c..a1cd9f5a 100644 --- a/ckks/chebyshev_interpolation.go +++ b/ckks/chebyshev_interpolation.go @@ -4,32 +4,16 @@ import ( "math" ) -// ChebyshevInterpolation is a struct storing the coefficients, degree and range of a Chebyshev interpolation polynomial. -type ChebyshevInterpolation struct { - Poly - a complex128 - b complex128 -} - -// A returns the start of the approximation interval. -func (c *ChebyshevInterpolation) A() complex128 { - return c.a -} - -// B returns the end of the approximation interval. -func (c *ChebyshevInterpolation) B() complex128 { - return c.b -} - // Approximate computes a Chebyshev approximation of the input function, for the range [-a, b] of degree degree. // To be used in conjunction with the function EvaluateCheby. -func Approximate(function func(complex128) complex128, a, b complex128, degree int) (cheby *ChebyshevInterpolation) { +func Approximate(function func(complex128) complex128, a, b complex128, degree int) (pol *Polynomial) { - cheby = new(ChebyshevInterpolation) - cheby.a = a - cheby.b = b - cheby.maxDeg = degree - cheby.lead = true + pol = new(Polynomial) + pol.A = a + pol.B = b + pol.MaxDeg = degree + pol.Lead = true + pol.Basis = ChebyshevBasis nodes := chebyshevNodes(degree+1, a, b) @@ -38,17 +22,15 @@ func Approximate(function func(complex128) complex128, a, b complex128, degree i fi[i] = function(nodes[i]) } - cheby.coeffs = chebyCoeffs(nodes, fi, a, b) + pol.Coeffs = chebyCoeffs(nodes, fi, a, b) return } func chebyshevNodes(n int, a, b complex128) (u []complex128) { u = make([]complex128, n) - var x, y complex128 + x, y := 0.5*(a+b), 0.5*(b-a) for k := 1; k < n+1; k++ { - x = 0.5 * (a + b) - y = 0.5 * (b - a) u[k-1] = x + y*complex(math.Cos((float64(k)-0.5)*(3.141592653589793/float64(n))), 0) } return diff --git a/ckks/ciphertext.go b/ckks/ciphertext.go index 2545802b..46bc9e5b 100644 --- a/ckks/ciphertext.go +++ b/ckks/ciphertext.go @@ -39,6 +39,16 @@ func NewCiphertextRandom(prng utils.PRNG, params Parameters, degree, level int, return ciphertext } +// NewCiphertextAtLevelFromPoly construct a new Ciphetext at a specific level +// where the message is set to the passed poly. No checks are performed on poly and +// the returned Ciphertext will share its backing array of coefficient. +func NewCiphertextAtLevelFromPoly(level int, poly [2]*ring.Poly) *Ciphertext { + v0, v1 := new(ring.Poly), new(ring.Poly) + v0.IsNTT, v1.IsNTT = true, true + v0.Coeffs, v1.Coeffs = poly[0].Coeffs[:level+1], poly[1].Coeffs[:level+1] + return &Ciphertext{Ciphertext: &rlwe.Ciphertext{Value: []*ring.Poly{v0, v1}}, Scale: 0} +} + // ScalingFactor returns the scaling factor of the ciphertext func (ct *Ciphertext) ScalingFactor() float64 { return ct.Scale diff --git a/ckks/ckks_benchmarks_test.go b/ckks/ckks_benchmarks_test.go index 4bb995dc..bf2dae42 100644 --- a/ckks/ckks_benchmarks_test.go +++ b/ckks/ckks_benchmarks_test.go @@ -45,7 +45,7 @@ func benchEncoder(testContext *testParams, b *testing.B) { encoder := testContext.encoder logSlots := testContext.params.LogSlots() - b.Run(testString(testContext, "Encoder/Encode/"), func(b *testing.B) { + b.Run(GetTestName(testContext.params, "Encoder/Encode/"), func(b *testing.B) { values := make([]complex128, 1<ct0/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/ct0*pt->ct0/"), func(t *testing.T) { values1, plaintext1, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -494,10 +491,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { testContext.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1) - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/pt*ct0->ct0/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/pt*ct0->ct0/"), func(t *testing.T) { values1, plaintext1, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -507,10 +504,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { testContext.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1) - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/ct0*pt->ct1/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/ct0*pt->ct1/"), func(t *testing.T) { values1, plaintext1, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -520,10 +517,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { ciphertext2 := testContext.evaluator.MulRelinNew(ciphertext1, plaintext1) - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext2, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext2, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/ct0*ct1->ct0/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/ct0*ct1->ct0/"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) values2, _, ciphertext2 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -534,10 +531,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { testContext.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1) - verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values2, ciphertext1, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/ct0*ct1->ct1/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/ct0*ct1->ct1/"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) values2, _, ciphertext2 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -548,10 +545,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { testContext.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext2) - verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext2, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values2, ciphertext2, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/ct0*ct1->ct2/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/ct0*ct1->ct2/"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) values2, _, ciphertext2 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -562,10 +559,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { ciphertext3 := testContext.evaluator.MulRelinNew(ciphertext1, ciphertext2) - verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext3, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values2, ciphertext3, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/ct0*ct0->ct0/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/ct0*ct0->ct0/"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -575,10 +572,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { testContext.evaluator.MulRelin(ciphertext1, ciphertext1, ciphertext1) - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/ct0*ct0->ct1/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/ct0*ct0->ct1/"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -588,10 +585,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { ciphertext2 := testContext.evaluator.MulRelinNew(ciphertext1, ciphertext1) - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext2, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext2, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/Relinearize(ct0*ct1->ct0)/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/Relinearize(ct0*ct1->ct0)/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -609,10 +606,10 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { testContext.evaluator.Relinearize(ciphertext1, ciphertext1) require.Equal(t, ciphertext1.Degree(), 1) - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Mul/Relinearize(ct0*ct1->ct1)/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Mul/Relinearize(ct0*ct1->ct1)/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -630,14 +627,14 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) { testContext.evaluator.Relinearize(ciphertext2, ciphertext2) require.Equal(t, ciphertext2.Degree(), 1) - verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext2, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values2, ciphertext2, testContext.params.LogSlots(), 0, t) }) } func testFunctions(testContext *testParams, t *testing.T) { - t.Run(testString(testContext, "Evaluator/PowerOf2/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/PowerOf2/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -664,10 +661,10 @@ func testFunctions(testContext *testParams, t *testing.T) { testContext.evaluator.PowerOf2(ciphertext, n, ciphertext) - verifyTestVectors(testContext, testContext.decryptor, valuesWant, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, valuesWant, ciphertext, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Power/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Power/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -687,10 +684,10 @@ func testFunctions(testContext *testParams, t *testing.T) { testContext.evaluator.Power(ciphertext, n, ciphertext) - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Evaluator/Inverse/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Evaluator/Inverse/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -710,7 +707,7 @@ func testFunctions(testContext *testParams, t *testing.T) { ciphertext = testContext.evaluator.InverseNew(ciphertext, n) - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) }) } @@ -718,7 +715,7 @@ func testEvaluatePoly(testContext *testParams, t *testing.T) { var err error - t.Run(testString(testContext, "EvaluatePoly/Exp/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "EvaluatePoly/Exp/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -751,7 +748,7 @@ func testEvaluatePoly(testContext *testParams, t *testing.T) { t.Error(err) } - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) }) } @@ -759,7 +756,7 @@ func testChebyshevInterpolator(testContext *testParams, t *testing.T) { var err error - t.Run(testString(testContext, "ChebyshevInterpolator/Sin/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "ChebyshevInterpolator/Sin/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -773,21 +770,21 @@ func testChebyshevInterpolator(testContext *testParams, t *testing.T) { values, _, ciphertext := newTestVectors(testContext, testContext.encryptorSk, complex(-1, 0), complex(1, 0), t) - cheby := Approximate(cmplx.Sin, complex(-1.5, 0), complex(1.5, 0), 15) + poly := Approximate(cmplx.Sin, complex(-1.5, 0), complex(1.5, 0), 15) for i := range values { values[i] = cmplx.Sin(values[i]) } - eval.MultByConst(ciphertext, 2/(cheby.b-cheby.a), ciphertext) - eval.AddConst(ciphertext, (-cheby.a-cheby.b)/(cheby.b-cheby.a), ciphertext) - eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext) + eval.MultByConst(ciphertext, 2/(poly.B-poly.A), ciphertext) + eval.AddConst(ciphertext, (-poly.A-poly.B)/(poly.B-poly.A), ciphertext) + eval.Rescale(ciphertext, testContext.params.Scale(), ciphertext) - if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby, ciphertext.Scale); err != nil { + if ciphertext, err = eval.EvaluatePoly(ciphertext, poly, ciphertext.Scale); err != nil { t.Error(err) } - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) }) } @@ -795,7 +792,7 @@ func testDecryptPublic(testContext *testParams, t *testing.T) { var err error - t.Run(testString(testContext, "DecryptPublic/Sin/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "DecryptPublic/Sin/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -809,17 +806,17 @@ func testDecryptPublic(testContext *testParams, t *testing.T) { values, _, ciphertext := newTestVectors(testContext, testContext.encryptorSk, complex(-1, 0), complex(1, 0), t) - cheby := Approximate(cmplx.Sin, complex(-1.5, 0), complex(1.5, 0), 15) + poly := Approximate(cmplx.Sin, complex(-1.5, 0), complex(1.5, 0), 15) for i := range values { values[i] = cmplx.Sin(values[i]) } - eval.MultByConst(ciphertext, 2/(cheby.b-cheby.a), ciphertext) - eval.AddConst(ciphertext, (-cheby.a-cheby.b)/(cheby.b-cheby.a), ciphertext) - eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext) + eval.MultByConst(ciphertext, 2/(poly.B-poly.A), ciphertext) + eval.AddConst(ciphertext, (-poly.A-poly.B)/(poly.B-poly.A), ciphertext) + eval.Rescale(ciphertext, testContext.params.Scale(), ciphertext) - if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby, ciphertext.Scale); err != nil { + if ciphertext, err = eval.EvaluatePoly(ciphertext, poly, ciphertext.Scale); err != nil { t.Error(err) } @@ -827,13 +824,13 @@ func testDecryptPublic(testContext *testParams, t *testing.T) { valuesHave := testContext.encoder.Decode(plaintext, testContext.params.LogSlots()) - verifyTestVectors(testContext, nil, values, valuesHave, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, nil, values, valuesHave, testContext.params.LogSlots(), 0, t) sigma := testContext.encoder.GetErrSTDCoeffDomain(values, valuesHave, plaintext.Scale) valuesHave = testContext.encoder.DecodePublic(plaintext, testContext.params.LogSlots(), sigma) - verifyTestVectors(testContext, nil, values, valuesHave, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, nil, values, valuesHave, testContext.params.LogSlots(), 0, t) }) } @@ -849,7 +846,7 @@ func testSwitchKeys(testContext *testParams, t *testing.T) { switchingKey = testContext.kgen.GenSwitchingKey(testContext.sk, sk2) } - t.Run(testString(testContext, "SwitchKeys/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "SwitchKeys/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -859,10 +856,10 @@ func testSwitchKeys(testContext *testParams, t *testing.T) { testContext.evaluator.SwitchKeys(ciphertext, switchingKey, ciphertext) - verifyTestVectors(testContext, decryptorSk2, values, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, decryptorSk2, values, ciphertext, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "SwitchKeysNew/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "SwitchKeysNew/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -872,7 +869,7 @@ func testSwitchKeys(testContext *testParams, t *testing.T) { ciphertext = testContext.evaluator.SwitchKeysNew(ciphertext, switchingKey) - verifyTestVectors(testContext, decryptorSk2, values, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, decryptorSk2, values, ciphertext, testContext.params.LogSlots(), 0, t) }) } @@ -886,7 +883,7 @@ func testAutomorphisms(testContext *testParams, t *testing.T) { rotKey := testContext.kgen.GenRotationKeysForRotations(rots, true, testContext.sk) evaluator := testContext.evaluator.WithKey(rlwe.EvaluationKey{Rlk: testContext.rlk, Rtks: rotKey}) - t.Run(testString(testContext, "Conjugate/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Conjugate/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -900,10 +897,10 @@ func testAutomorphisms(testContext *testParams, t *testing.T) { evaluator.Conjugate(ciphertext, ciphertext) - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "ConjugateNew/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "ConjugateNew/"), func(t *testing.T) { if testContext.params.PCount() == 0 { t.Skip("#Pi is empty") @@ -917,10 +914,10 @@ func testAutomorphisms(testContext *testParams, t *testing.T) { ciphertext = evaluator.ConjugateNew(ciphertext) - verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "Rotate/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Rotate/"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) @@ -928,28 +925,28 @@ func testAutomorphisms(testContext *testParams, t *testing.T) { for _, n := range rots { evaluator.Rotate(ciphertext1, n, ciphertext2) - verifyTestVectors(testContext, testContext.decryptor, utils.RotateComplex128Slice(values1, n), ciphertext2, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, utils.RotateComplex128Slice(values1, n), ciphertext2, testContext.params.LogSlots(), 0, t) } }) - t.Run(testString(testContext, "RotateNew/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "RotateNew/"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) for _, n := range rots { - verifyTestVectors(testContext, testContext.decryptor, utils.RotateComplex128Slice(values1, n), evaluator.RotateNew(ciphertext1, n), testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, utils.RotateComplex128Slice(values1, n), evaluator.RotateNew(ciphertext1, n), testContext.params.LogSlots(), 0, t) } }) - t.Run(testString(testContext, "RotateHoisted/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "RotateHoisted/"), func(t *testing.T) { values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t) - ciphertexts := evaluator.RotateHoisted(ciphertext1, rots) + ciphertexts := evaluator.RotateHoistedNew(ciphertext1, rots) for _, n := range rots { - verifyTestVectors(testContext, testContext.decryptor, utils.RotateComplex128Slice(values1, n), ciphertexts[n], testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, utils.RotateComplex128Slice(values1, n), ciphertexts[n], testContext.params.LogSlots(), 0, t) } }) } @@ -960,7 +957,7 @@ func testInnerSum(testContext *testParams, t *testing.T) { t.Skip("#Pi is empty") } - t.Run(testString(testContext, "InnerSum/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "InnerSum/"), func(t *testing.T) { batch := 2 n := 35 @@ -983,10 +980,10 @@ func testInnerSum(testContext *testParams, t *testing.T) { } } - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "InnerSumLog/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "InnerSumLog/"), func(t *testing.T) { batch := 3 n := 15 @@ -1010,7 +1007,7 @@ func testInnerSum(testContext *testParams, t *testing.T) { } } - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) }) } @@ -1021,7 +1018,7 @@ func testReplicate(testContext *testParams, t *testing.T) { t.Skip("#Pi is empty") } - t.Run(testString(testContext, "Replicate/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "Replicate/"), func(t *testing.T) { batch := 2 n := 35 @@ -1044,10 +1041,10 @@ func testReplicate(testContext *testParams, t *testing.T) { } } - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "ReplicateLog/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "ReplicateLog/"), func(t *testing.T) { batch := 3 n := 15 @@ -1071,7 +1068,7 @@ func testReplicate(testContext *testParams, t *testing.T) { } } - verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t) }) } @@ -1082,7 +1079,7 @@ func testLinearTransform(testContext *testParams, t *testing.T) { t.Skip("#Pi is empty") } - t.Run(testString(testContext, "LinearTransform/BSGS/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "LinearTransform/BSGS/"), func(t *testing.T) { params := testContext.params @@ -1116,7 +1113,7 @@ func testLinearTransform(testContext *testParams, t *testing.T) { eval := testContext.evaluator.WithKey(rlwe.EvaluationKey{Rlk: testContext.rlk, Rtks: rotKey}) - res := eval.LinearTransform(ciphertext1, ptDiagMatrix)[0] + res := eval.LinearTransformNew(ciphertext1, ptDiagMatrix)[0] tmp := make([]complex128, params.Slots()) copy(tmp, values1) @@ -1130,10 +1127,10 @@ func testLinearTransform(testContext *testParams, t *testing.T) { values1[i] += tmp[(i+15)%params.Slots()] } - verifyTestVectors(testContext, testContext.decryptor, values1, res, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, res, testContext.params.LogSlots(), 0, t) }) - t.Run(testString(testContext, "LinearTransform/Naive/"), func(t *testing.T) { + t.Run(GetTestName(testContext.params, "LinearTransform/Naive/"), func(t *testing.T) { params := testContext.params @@ -1157,7 +1154,7 @@ func testLinearTransform(testContext *testParams, t *testing.T) { eval := testContext.evaluator.WithKey(rlwe.EvaluationKey{Rlk: testContext.rlk, Rtks: rotKey}) - res := eval.LinearTransform(ciphertext1, ptDiagMatrix)[0] + res := eval.LinearTransformNew(ciphertext1, ptDiagMatrix)[0] tmp := make([]complex128, params.Slots()) copy(tmp, values1) @@ -1166,7 +1163,7 @@ func testLinearTransform(testContext *testParams, t *testing.T) { values1[i] += tmp[(i-1+params.Slots())%params.Slots()] } - verifyTestVectors(testContext, testContext.decryptor, values1, res, testContext.params.LogSlots(), 0, t) + verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, res, testContext.params.LogSlots(), 0, t) }) } @@ -1212,7 +1209,7 @@ func testMarshaller(testctx *testParams, t *testing.T) { }) t.Run("Marshaller/Ciphertext/", func(t *testing.T) { - t.Run(testString(testctx, "EndToEnd/"), func(t *testing.T) { + t.Run(GetTestName(testctx.params, "EndToEnd/"), func(t *testing.T) { ciphertextWant := NewCiphertextRandom(testctx.prng, testctx.params, 2, testctx.params.MaxLevel(), testctx.params.Scale()) @@ -1231,7 +1228,7 @@ func testMarshaller(testctx *testParams, t *testing.T) { } }) - t.Run(testString(testctx, "Minimal/"), func(t *testing.T) { + t.Run(GetTestName(testctx.params, "Minimal/"), func(t *testing.T) { ciphertext := NewCiphertextRandom(testctx.prng, testctx.params, 0, testctx.params.MaxLevel(), testctx.params.Scale()) diff --git a/ckks/encoder.go b/ckks/encoder.go index 4fed7ddf..e9ab73cf 100644 --- a/ckks/encoder.go +++ b/ckks/encoder.go @@ -8,6 +8,7 @@ import ( "math/bits" "github.com/ldsec/lattigo/v2/ring" + "github.com/ldsec/lattigo/v2/rlwe" "github.com/ldsec/lattigo/v2/utils" ) @@ -27,8 +28,9 @@ type Encoder interface { EncodeNTTNew(values []complex128, logSlots int) (plaintext *Plaintext) EncodeNTTAtLvlNew(level int, values []complex128, logSlots int) (plaintext *Plaintext) - EncodeDiagMatrixBSGSAtLvl(level int, vector map[int][]complex128, scale, maxM1N2Ratio float64, logSlots int) (matrix *PtDiagMatrix) - EncodeDiagMatrixAtLvl(level int, vector map[int][]complex128, scale float64, logSlots int) (matrix *PtDiagMatrix) + EncodeDiagMatrixBSGSAtLvl(level int, vector map[int][]complex128, scale, maxM1N2Ratio float64, logSlots int) (matrix PtDiagMatrix) + EncodeDiagMatrixAtLvl(level int, vector map[int][]complex128, scale float64, logSlots int) (matrix PtDiagMatrix) + EncodeDiagonal(logSlots, level int, scale float64, m []complex128) (vecQP rlwe.PolyQP) Decode(plaintext *Plaintext, logSlots int) (res []complex128) DecodePublic(plaintext *Plaintext, logSlots int, sigma float64) []complex128 @@ -64,8 +66,6 @@ type EncoderBigComplex interface { // encoder is a struct storing the necessary parameters to encode a slice of complex number on a Plaintext. type encoder struct { params Parameters - ringQ *ring.Ring - ringP *ring.Ring bigintChain []*big.Int bigintCoeffs []*big.Int qHalf *big.Int @@ -87,19 +87,6 @@ func newEncoder(params Parameters) encoder { m := 2 * params.N() - var q *ring.Ring - var err error - if q, err = ring.NewRing(params.N(), params.Q()); err != nil { - panic(err) - } - - var p *ring.Ring - if params.PCount() != 0 { - if p, err = ring.NewRing(params.N(), params.P()); err != nil { - panic(err) - } - } - rotGroup := make([]int, m>>1) fivePows := 1 for i := 0; i < m>>2; i++ { @@ -113,16 +100,14 @@ func newEncoder(params Parameters) encoder { panic(err) } - gaussianSampler := ring.NewGaussianSampler(prng, q, params.Sigma(), int(6*params.Sigma())) + gaussianSampler := ring.NewGaussianSampler(prng, params.RingQ(), params.Sigma(), int(6*params.Sigma())) return encoder{ params: params, - ringQ: q, - ringP: p, bigintChain: genBigIntChain(params.Q()), bigintCoeffs: make([]*big.Int, m>>1), qHalf: ring.NewUint(0), - polypool: q.NewPoly(), + polypool: params.RingQ().NewPoly(), m: m, rotGroup: rotGroup, gaussianSampler: gaussianSampler, @@ -166,7 +151,7 @@ func (encoder *encoderComplex128) EncodeAtLvlNew(level int, values []complex128, // Encode encodes a slice of complex128 of length slots = 2^{logSlots} on the input plaintext. func (encoder *encoderComplex128) Encode(plaintext *Plaintext, values []complex128, logSlots int) { encoder.Embed(values, logSlots) - encoder.ScaleUp(plaintext.Value, plaintext.Scale, encoder.ringQ.Modulus[:plaintext.Level()+1]) + encoder.ScaleUp(plaintext.Value, plaintext.Scale, encoder.params.RingQ().Modulus[:plaintext.Level()+1]) plaintext.Value.IsNTT = false } @@ -188,7 +173,7 @@ func (encoder *encoderComplex128) EncodeNTTAtLvlNew(level int, values []complex1 // Returns a plaintext in the NTT domain. func (encoder *encoderComplex128) EncodeNTT(plaintext *Plaintext, values []complex128, logSlots int) { encoder.Encode(plaintext, values, logSlots) - encoder.ringQ.NTTLvl(plaintext.Level(), plaintext.Value, plaintext.Value) + encoder.params.RingQ().NTTLvl(plaintext.Level(), plaintext.Value, plaintext.Value) plaintext.Value.IsNTT = true } @@ -209,10 +194,10 @@ func (encoder *encoderComplex128) Embed(values []complex128, logSlots int) { } invfft(encoder.values, slots, encoder.m, encoder.rotGroup, encoder.roots) + N := encoder.params.RingQ().N + gap := (N >> 1) / slots - gap := (encoder.ringQ.N >> 1) / slots - - for i, jdx, idx := 0, encoder.ringQ.N>>1, 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap { + for i, jdx, idx := 0, N>>1, 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap { encoder.valuesfloat[idx] = real(encoder.values[i]) encoder.valuesfloat[jdx] = imag(encoder.values[i]) } @@ -340,9 +325,9 @@ func polyToComplexCRT(poly *ring.Poly, bigintCoeffs []*big.Int, values []complex func (encoder *encoderComplex128) plaintextToComplex(level int, scale float64, logSlots int, p *ring.Poly, values []complex128) { if level == 0 { - polyToComplexNoCRT(p.Coeffs[0], encoder.values, scale, logSlots, encoder.ringQ.Modulus[0]) + polyToComplexNoCRT(p.Coeffs[0], encoder.values, scale, logSlots, encoder.params.RingQ().Modulus[0]) } else { - polyToComplexCRT(p, encoder.bigintCoeffs, values, scale, logSlots, encoder.ringQ, encoder.bigintChain[level]) + polyToComplexCRT(p, encoder.bigintCoeffs, values, scale, logSlots, encoder.params.RingQ(), encoder.bigintChain[level]) } } @@ -369,12 +354,12 @@ func polyToFloatNoCRT(coeffs []uint64, values []float64, scale float64, Q uint64 // PtDiagMatrix is a struct storing a plaintext diagonalized matrix // ready to be evaluated on a ciphertext using evaluator.MultiplyByDiagMatrice. type PtDiagMatrix struct { - LogSlots int // Log of the number of slots of the plaintext (needed to compute the appropriate rotation keys) - N1 int // N1 is the number of inner loops of the baby-step giant-step algo used in the evaluation. - Level int // Level is the level at which the matrix is encoded (can be circuit dependant) - Scale float64 // Scale is the scale at which the matrix is encoded (can be circuit dependant) - Vec map[int][2]*ring.Poly // Vec is the matrix, in diagonal form, where each entry of vec is an indexed non zero diagonal. - naive bool + LogSlots int // Log of the number of slots of the plaintext (needed to compute the appropriate rotation keys) + N1 int // N1 is the number of inner loops of the baby-step giant-step algo used in the evaluation. + Level int // Level is the level at which the matrix is encoded (can be circuit dependant) + Scale float64 // Scale is the scale at which the matrix is encoded (can be circuit dependant) + Vec map[int]rlwe.PolyQP // Vec is the matrix, in diagonal form, where each entry of vec is an indexed non zero diagonal. + Naive bool isGaussian bool // Each diagonal of the matrix is of the form [k, ..., k] for k a gaussian integer } @@ -392,7 +377,6 @@ func bsgsIndex(el interface{}, slots, N1 int) (index map[int][]int, rotations [] } else { index[idx1] = append(index[idx1], idx2) } - if !utils.IsInSliceInt(idx2, rotations) { rotations = append(rotations, idx2) } @@ -411,7 +395,21 @@ func bsgsIndex(el interface{}, slots, N1 int) (index map[int][]int, rotations [] rotations = append(rotations, idx2) } } - case map[int][2]*ring.Poly: + case map[int]rlwe.PolyQP: + for key := range element { + key &= (slots - 1) + idx1 := key / N1 + idx2 := key & (N1 - 1) + if index[idx1] == nil { + index[idx1] = []int{idx2} + } else { + index[idx1] = append(index[idx1], idx2) + } + if !utils.IsInSliceInt(idx2, rotations) { + rotations = append(rotations, idx2) + } + } + case []int: for key := range element { key &= (slots - 1) idx1 := key / N1 @@ -426,6 +424,7 @@ func bsgsIndex(el interface{}, slots, N1 int) (index map[int][]int, rotations [] } } } + return } @@ -435,90 +434,79 @@ func bsgsIndex(el interface{}, slots, N1 int) (index map[int][]int, rotations [] // Faster if there is more than a few non-zero diagonals. // maxM1N2Ratio is the maximum ratio between the inner and outer loop of the baby-step giant-step algorithm used in evaluator.LinearTransform. // Optimal maxM1N2Ratio value is between 4 and 16 depending on the sparsity of the matrix. -func (encoder *encoderComplex128) EncodeDiagMatrixBSGSAtLvl(level int, diagMatrix map[int][]complex128, scale, maxM1N2Ratio float64, logSlots int) (matrix *PtDiagMatrix) { +func (encoder *encoderComplex128) EncodeDiagMatrixBSGSAtLvl(level int, diagMatrix map[int][]complex128, scale, maxM1N2Ratio float64, logSlots int) (matrix PtDiagMatrix) { - matrix = new(PtDiagMatrix) - matrix.LogSlots = logSlots slots := 1 << logSlots // N1*N2 = N - N1 := findbestbabygiantstepsplit(diagMatrix, slots, maxM1N2Ratio) - matrix.N1 = N1 + n1 := FindBestBSGSSplit(diagMatrix, slots, maxM1N2Ratio) - index, _ := bsgsIndex(diagMatrix, slots, N1) + index, _ := bsgsIndex(diagMatrix, slots, n1) - matrix.Vec = make(map[int][2]*ring.Poly) - - matrix.Level = level - matrix.Scale = scale + vec := make(map[int]rlwe.PolyQP) for j := range index { for _, i := range index[j] { // manages inputs that have rotation between 0 and slots-1 or between -slots/2 and slots/2-1 - v := diagMatrix[N1*j+i] + v := diagMatrix[n1*j+i] if len(v) == 0 { - v = diagMatrix[(N1*j+i)-slots] + v = diagMatrix[(n1*j+i)-slots] } - matrix.Vec[N1*j+i] = encoder.encodeDiagonal(logSlots, level, scale, rotate(v, -N1*j)) + if len(v) != slots { + panic("diagMatrix []complex slices mut have len '1< 0 { - encoder.ringQ.PolyToBigint(encoder.polypool, encoder.bigintCoeffs) + encoder.params.RingQ().PolyToBigint(encoder.polypool, encoder.bigintCoeffs) Q := encoder.bigintChain[plaintext.Level()] @@ -716,7 +704,7 @@ func (encoder *encoderComplex128) decodeCoeffsPublic(plaintext *Plaintext, sigma // We can directly get the coefficients } else { - Q := encoder.ringQ.Modulus[0] + Q := encoder.params.RingQ().Modulus[0] coeffs := encoder.polypool.Coeffs[0] for i := range res { @@ -826,7 +814,7 @@ func (encoder *encoderBigComplex) EncodeNTTAtLvlNew(level int, values []*ring.Co // Returns a plaintext in the NTT domain. func (encoder *encoderBigComplex) EncodeNTT(plaintext *Plaintext, values []*ring.Complex, logSlots int) { encoder.Encode(plaintext, values, logSlots) - encoder.ringQ.NTTLvl(plaintext.Level(), plaintext.Value, plaintext.Value) + encoder.params.RingQ().NTTLvl(plaintext.Level(), plaintext.Value, plaintext.Value) plaintext.Value.IsNTT = true } @@ -849,25 +837,25 @@ func (encoder *encoderBigComplex) Encode(plaintext *Plaintext, values []*ring.Co encoder.InvFFT(encoder.values, slots) - gap := (encoder.ringQ.N >> 1) / slots + gap := (encoder.params.RingQ().N >> 1) / slots - for i, jdx, idx := 0, (encoder.ringQ.N >> 1), 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap { + for i, jdx, idx := 0, (encoder.params.RingQ().N >> 1), 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap { encoder.valuesfloat[idx].Set(encoder.values[i].Real()) encoder.valuesfloat[jdx].Set(encoder.values[i].Imag()) } - scaleUpVecExactBigFloat(encoder.valuesfloat, plaintext.Scale, encoder.ringQ.Modulus[:plaintext.Level()+1], plaintext.Value.Coeffs) + scaleUpVecExactBigFloat(encoder.valuesfloat, plaintext.Scale, encoder.params.RingQ().Modulus[:plaintext.Level()+1], plaintext.Value.Coeffs) coeffsBigInt := make([]*big.Int, encoder.params.N()) - encoder.ringQ.PolyToBigint(plaintext.Value, coeffsBigInt) + encoder.params.RingQ().PolyToBigint(plaintext.Value, coeffsBigInt) - for i := 0; i < (encoder.ringQ.N >> 1); i++ { + for i := 0; i < (encoder.params.RingQ().N >> 1); i++ { encoder.values[i].Real().Set(encoder.zero) encoder.values[i].Imag().Set(encoder.zero) } - for i := 0; i < encoder.ringQ.N; i++ { + for i := 0; i < encoder.params.RingQ().N; i++ { encoder.valuesfloat[i].Set(encoder.zero) } } @@ -891,18 +879,18 @@ func (encoder *encoderBigComplex) decodePublic(plaintext *Plaintext, logSlots in panic("cannot Decode: too many slots for the given ring degree") } - encoder.ringQ.InvNTTLvl(plaintext.Level(), plaintext.Value, encoder.polypool) + encoder.params.RingQ().InvNTTLvl(plaintext.Level(), plaintext.Value, encoder.polypool) if sigma != 0 { // B = floor(sigma * sqrt(2*pi)) - encoder.gaussianSampler.ReadAndAddFromDistLvl(plaintext.Level(), encoder.polypool, encoder.ringQ, sigma, int(2.5066282746310002*sigma+0.5)) + encoder.gaussianSampler.ReadAndAddFromDistLvl(plaintext.Level(), encoder.polypool, encoder.params.RingQ(), sigma, int(2.5066282746310002*sigma+0.5)) } - encoder.ringQ.PolyToBigint(encoder.polypool, encoder.bigintCoeffs) + encoder.params.RingQ().PolyToBigint(encoder.polypool, encoder.bigintCoeffs) Q := encoder.bigintChain[plaintext.Level()] - maxSlots := encoder.ringQ.N >> 1 + maxSlots := encoder.params.RingQ().N >> 1 scaleFlo := ring.NewFloat(plaintext.Scale, encoder.logPrecision) @@ -981,7 +969,7 @@ func (encoder *encoderBigComplex) InvFFT(values []*ring.Complex, N int) { values[i][1].Quo(values[i][1], NBig) } - sliceBitReverseInPlaceRingComplex(values, N) + SliceBitReverseInPlaceRingComplex(values, N) } // FFT evaluates the decoding matrix on a slice fo ring.Complex values. @@ -992,7 +980,7 @@ func (encoder *encoderBigComplex) FFT(values []*ring.Complex, N int) { u := ring.NewComplex(nil, nil) v := ring.NewComplex(nil, nil) - sliceBitReverseInPlaceRingComplex(values, N) + SliceBitReverseInPlaceRingComplex(values, N) for len := 2; len <= N; len <<= 1 { for i := 0; i < N; i += len { diff --git a/ckks/evaluator.go b/ckks/evaluator.go index 2f99a5dd..33c35d10 100644 --- a/ckks/evaluator.go +++ b/ckks/evaluator.go @@ -49,11 +49,11 @@ type Evaluator interface { // Constant Multiplication MultByConstNew(ctIn *Ciphertext, constant interface{}) (ctOut *Ciphertext) MultByConst(ctIn *Ciphertext, constant interface{}, ctOut *Ciphertext) - MultByGaussianInteger(ctIn *Ciphertext, cReal, cImag int64, ctOut *Ciphertext) + MultByGaussianInteger(ctIn *Ciphertext, cReal, cImag interface{}, ctOut *Ciphertext) // Constant Multiplication with Addition MultByConstAndAdd(ctIn *Ciphertext, constant interface{}, ctOut *Ciphertext) - MultByGaussianIntegerAndAdd(ctIn *Ciphertext, cReal, cImag int64, ctOut *Ciphertext) + MultByGaussianIntegerAndAdd(ctIn *Ciphertext, cReal, cImag interface{}, ctOut *Ciphertext) // Multiplication by the imaginary unit MultByiNew(ctIn *Ciphertext) (ctOut *Ciphertext) @@ -74,7 +74,10 @@ type Evaluator interface { // Slot Rotations RotateNew(ctIn *Ciphertext, k int) (ctOut *Ciphertext) Rotate(ctIn *Ciphertext, k int, ctOut *Ciphertext) - RotateHoisted(ctIn *Ciphertext, rotations []int) (ctOut map[int]*Ciphertext) + RotateHoistedNew(ctIn *Ciphertext, rotations []int) (ctOut map[int]*Ciphertext) + RotateHoisted(ctIn *Ciphertext, rotations []int, ctOut map[int]*Ciphertext) + PermuteNTTHoistedNoModDown(level int, c2DecompQP []rlwe.PolyQP, k int, ct0OutQ, ct1OutQ, ct0OutP, ct1OutP *ring.Poly) + PermuteNTTHoisted(level int, c0, c1 *ring.Poly, c2DecompQP []rlwe.PolyQP, k int, cOut0, cOut1 *ring.Poly) // =========================== // === Advanced Arithmetic === @@ -90,16 +93,16 @@ type Evaluator interface { PowerNew(ctIn *Ciphertext, degree int) (ctOut *Ciphertext) // Polynomial evaluation - EvaluatePoly(ctIn *Ciphertext, coeffs *Poly, targetScale float64) (ctOut *Ciphertext, err error) - EvaluateCheby(ctIn *Ciphertext, cheby *ChebyshevInterpolation, targetScale float64) (ctOut *Ciphertext, err error) + EvaluatePoly(ctIn *Ciphertext, pol *Polynomial, targetScale float64) (ctOut *Ciphertext, err error) // Inversion InverseNew(ctIn *Ciphertext, steps int) (ctOut *Ciphertext) // Linear Transformations - LinearTransform(ctIn *Ciphertext, linearTransform interface{}) (ctOut []*Ciphertext) - MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatrix, c2QiQDecomp, c2QiPDecomp []*ring.Poly, ctOut *Ciphertext) - MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiagMatrix, c2QiQDecomp, c2QiPDecomp []*ring.Poly, ctOut *Ciphertext) + LinearTransformNew(ctIn *Ciphertext, linearTransform interface{}) (ctOut []*Ciphertext) + LinearTransform(ctIn *Ciphertext, linearTransform interface{}, ctOut []*Ciphertext) + MultiplyByDiagMatrix(ctIn *Ciphertext, matrix PtDiagMatrix, c2DecompQP []rlwe.PolyQP, ctOut *Ciphertext) + MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix PtDiagMatrix, c2DecompQP []rlwe.PolyQP, ctOut *Ciphertext) // Inner sum InnerSumLog(ctIn *Ciphertext, batch, n int, ctOut *Ciphertext) @@ -109,6 +112,10 @@ type Evaluator interface { ReplicateLog(ctIn *Ciphertext, batch, n int, ctOut *Ciphertext) Replicate(ctIn *Ciphertext, batch, n int, ctOut *Ciphertext) + // Trace + Trace(ctIn *Ciphertext, logSlotsStart, logSlotsEnd int, ctOut *Ciphertext) + TraceNew(ctIn *Ciphertext, logSlotsStart, logSlotsEnd int) (ctOut *Ciphertext) + // ============================= // === Ciphertext Management === // ============================= @@ -138,6 +145,9 @@ type Evaluator interface { // ============== // === Others === // ============== + GetKeySwitcher() *rlwe.KeySwitcher + PoolQMul() [3]*ring.Poly + CtxPool() *Ciphertext ShallowCopy() Evaluator WithKey(rlwe.EvaluationKey) Evaluator } @@ -156,10 +166,6 @@ type evaluator struct { type evaluatorBase struct { params Parameters - scale float64 - - ringQ *ring.Ring - ringP *ring.Ring } type evaluatorBuffers struct { @@ -167,21 +173,28 @@ type evaluatorBuffers struct { ctxpool *Ciphertext // Memory pool for ciphertext that need to be scaled up (to be removed eventually) } +// PoolQMul returns a pointer to internal memory pool poolQMul. +func (eval *evaluator) PoolQMul() [3]*ring.Poly { + return eval.poolQMul +} + +// CtxPool returns a pointer to internal memory pool CtxPool. +func (eval *evaluator) CtxPool() *Ciphertext { + return eval.ctxpool +} + func newEvaluatorBase(params Parameters) *evaluatorBase { ev := new(evaluatorBase) ev.params = params - ev.scale = params.Scale() - ev.ringQ = params.RingQ() - ev.ringP = params.RingP() - return ev } func newEvaluatorBuffers(evalBase *evaluatorBase) *evaluatorBuffers { buff := new(evaluatorBuffers) - ringQ := evalBase.ringQ + params := evalBase.params + ringQ := params.RingQ() buff.poolQMul = [3]*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly()} - buff.ctxpool = NewCiphertext(evalBase.params, 2, evalBase.params.MaxLevel(), evalBase.params.Scale()) + buff.ctxpool = NewCiphertext(params, 2, params.MaxLevel(), params.Scale()) return buff } @@ -211,12 +224,18 @@ func (eval *evaluator) permuteNTTIndexesForKey(rtks *rlwe.RotationKeySet) *map[u return &map[uint64][]uint64{} } permuteNTTIndex := make(map[uint64][]uint64, len(rtks.Keys)) + N := uint64(eval.params.RingQ().N) for galEl := range rtks.Keys { - permuteNTTIndex[galEl] = ring.PermuteNTTIndex(galEl, uint64(eval.ringQ.N)) + permuteNTTIndex[galEl] = ring.PermuteNTTIndex(galEl, N) } return &permuteNTTIndex } +// GetKeySwitcher returns a pointer to the internal rlwe.KeySwither. +func (eval *evaluator) GetKeySwitcher() *rlwe.KeySwitcher { + return eval.KeySwitcher +} + // ShallowCopy creates a shallow copy of this evaluator in which all the read-only data-structures are // shared with the receiver and the temporary buffers are reallocated. The receiver and the returned // Evaluators can be used concurrently. @@ -290,13 +309,13 @@ func (eval *evaluator) newCiphertextBinary(op0, op1 Operand) (ctOut *Ciphertext) // Add adds op0 to op1 and returns the result in ctOut. func (eval *evaluator) Add(op0, op1 Operand, ctOut *Ciphertext) { eval.checkBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree())) - eval.evaluateInPlace(op0, op1, ctOut, eval.ringQ.AddLvl) + eval.evaluateInPlace(op0, op1, ctOut, eval.params.RingQ().AddLvl) } // AddNoMod adds op0 to op1 and returns the result in ctOut, without modular reduction. func (eval *evaluator) AddNoMod(op0, op1 Operand, ctOut *Ciphertext) { eval.checkBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree())) - eval.evaluateInPlace(op0, op1, ctOut, eval.ringQ.AddNoModLvl) + eval.evaluateInPlace(op0, op1, ctOut, eval.params.RingQ().AddNoModLvl) } // AddNew adds op0 to op1 and returns the result in a newly created element. @@ -318,13 +337,13 @@ func (eval *evaluator) Sub(op0, op1 Operand, ctOut *Ciphertext) { eval.checkBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree())) - eval.evaluateInPlace(op0, op1, ctOut, eval.ringQ.SubLvl) + eval.evaluateInPlace(op0, op1, ctOut, eval.params.RingQ().SubLvl) level := utils.MinInt(utils.MinInt(op0.Level(), op1.Level()), ctOut.Level()) if op0.Degree() < op1.Degree() { for i := op0.Degree() + 1; i < op1.Degree()+1; i++ { - eval.ringQ.NegLvl(level, ctOut.Value[i], ctOut.Value[i]) + eval.params.RingQ().NegLvl(level, ctOut.Value[i], ctOut.Value[i]) } } @@ -335,13 +354,13 @@ func (eval *evaluator) SubNoMod(op0, op1 Operand, ctOut *Ciphertext) { eval.checkBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree())) - eval.evaluateInPlace(op0, op1, ctOut, eval.ringQ.SubNoModLvl) + eval.evaluateInPlace(op0, op1, ctOut, eval.params.RingQ().SubNoModLvl) level := utils.MinInt(utils.MinInt(op0.Level(), op1.Level()), ctOut.Level()) if op0.Degree() < op1.Degree() { for i := op0.Degree() + 1; i < op1.Degree()+1; i++ { - eval.ringQ.NegLvl(level, ctOut.Value[i], ctOut.Value[i]) + eval.params.RingQ().NegLvl(level, ctOut.Value[i], ctOut.Value[i]) } } @@ -484,7 +503,7 @@ func (eval *evaluator) Neg(ct0 *Ciphertext, ctOut *Ciphertext) { } for i := range ct0.Value { - eval.ringQ.NegLvl(level, ct0.Value[i], ctOut.Value[i]) + eval.params.RingQ().NegLvl(level, ct0.Value[i], ctOut.Value[i]) } ctOut.Scale = ct0.Scale @@ -518,7 +537,7 @@ func (eval *evaluator) getConstAndScale(level int, constant interface{}) (cReal, valueFloat := cReal - float64(valueInt) if valueFloat != 0 { - scale = float64(eval.ringQ.Modulus[level]) + scale = float64(eval.params.RingQ().Modulus[level]) } } @@ -527,7 +546,7 @@ func (eval *evaluator) getConstAndScale(level int, constant interface{}) (cReal, valueFloat := cImag - float64(valueInt) if valueFloat != 0 { - scale = float64(eval.ringQ.Modulus[level]) + scale = float64(eval.params.RingQ().Modulus[level]) } } @@ -540,7 +559,7 @@ func (eval *evaluator) getConstAndScale(level int, constant interface{}) (cReal, valueFloat := cReal - float64(valueInt) if valueFloat != 0 { - scale = float64(eval.ringQ.Modulus[level]) + scale = float64(eval.params.RingQ().Modulus[level]) } } @@ -568,7 +587,7 @@ func (eval *evaluator) AddConst(ct0 *Ciphertext, constant interface{}, ctOut *Ci cReal, cImag, _ := eval.getConstAndScale(level, constant) - ringQ := eval.ringQ + ringQ := eval.params.RingQ() ctOut.Scale = ct0.Scale @@ -651,7 +670,7 @@ func (eval *evaluator) MultByConstAndAdd(ct0 *Ciphertext, constant interface{}, var scaledConst, scaledConstReal, scaledConstImag uint64 - ringQ := eval.ringQ + ringQ := eval.params.RingQ() // If a scaling would be required to multiply by the constant, // it equalizes scales such that the scales match in the end. @@ -788,7 +807,7 @@ func (eval *evaluator) MultByConst(ct0 *Ciphertext, constant interface{}, ctOut // [a + b*psi_qi^2, ....., a + b*psi_qi^2, a - b*psi_qi^2, ...., a - b*psi_qi^2] mod Qi // [{ N/2 }{ N/2 }] // Which is equivalent outside of the NTT domain to adding a to the first coefficient of ct0 and b to the N/2-th coefficient of ct0. - ringQ := eval.ringQ + ringQ := eval.params.RingQ() var scaledConst, scaledConstReal, scaledConstImag uint64 for i := 0; i < level+1; i++ { @@ -861,9 +880,11 @@ func (eval *evaluator) MultByConst(ct0 *Ciphertext, constant interface{}, ctOut ctOut.Scale = ct0.Scale * scale } -func (eval *evaluator) MultByGaussianInteger(ct0 *Ciphertext, cReal, cImag int64, ctOut *Ciphertext) { +// MultByGaussianInteger multiples the ct0 by the gaussian integer cReal + i*cImag and returns the result on ctOut. +// Accepted types for cReal and cImag are uint64, int64 and big.Int. +func (eval *evaluator) MultByGaussianInteger(ct0 *Ciphertext, cReal, cImag interface{}, ctOut *Ciphertext) { - ringQ := eval.ringQ + ringQ := eval.params.RingQ() level := utils.MinInt(ct0.Level(), ctOut.Level()) var scaledConst, scaledConstReal, scaledConstImag uint64 @@ -876,25 +897,11 @@ func (eval *evaluator) MultByGaussianInteger(ct0 *Ciphertext, cReal, cImag int64 bredParams := ringQ.BredParams[i] mredParams := ringQ.MredParams[i] - scaledConstReal = 0 - scaledConstImag = 0 - scaledConst = 0 + scaledConstReal = interfaceMod(cReal, qi) + scaledConstImag = interfaceMod(cImag, qi) + scaledConst = scaledConstReal - if cReal != 0 { - if cReal < 0 { - scaledConstReal = uint64(int64(qi) + cReal%int64(qi)) - } else { - scaledConstReal = uint64(cReal) - } - scaledConst = scaledConstReal - } - - if cImag != 0 { - if cImag < 0 { - scaledConstImag = uint64(int64(qi) + cImag%int64(qi)) - } else { - scaledConstImag = uint64(cImag) - } + if scaledConstImag != 0 { scaledConstImag = ring.MRed(scaledConstImag, ringQ.NttPsi[i][1], qi, mredParams) scaledConst = ring.CRed(scaledConst+scaledConstImag, qi) } @@ -948,9 +955,11 @@ func (eval *evaluator) MultByGaussianInteger(ct0 *Ciphertext, cReal, cImag int64 } } -func (eval *evaluator) MultByGaussianIntegerAndAdd(ct0 *Ciphertext, cReal, cImag int64, ctOut *Ciphertext) { +// MultByGaussianIntegerAndAdd multiples the ct0 by the gaussian integer cReal + i*cImag and adds the result on ctOut. +// Accepted types for cReal and cImag are uint64, int64 and big.Int. +func (eval *evaluator) MultByGaussianIntegerAndAdd(ct0 *Ciphertext, cReal, cImag interface{}, ctOut *Ciphertext) { - ringQ := eval.ringQ + ringQ := eval.params.RingQ() level := utils.MinInt(ct0.Level(), ctOut.Level()) var scaledConst, scaledConstReal, scaledConstImag uint64 @@ -961,25 +970,11 @@ func (eval *evaluator) MultByGaussianIntegerAndAdd(ct0 *Ciphertext, cReal, cImag bredParams := ringQ.BredParams[i] mredParams := ringQ.MredParams[i] - scaledConstReal = 0 - scaledConstImag = 0 - scaledConst = 0 + scaledConstReal = interfaceMod(cReal, qi) + scaledConstImag = interfaceMod(cImag, qi) + scaledConst = scaledConstReal - if cReal != 0 { - if cReal < 0 { - scaledConstReal = uint64(int64(qi) + cReal%int64(qi)) - } else { - scaledConstReal = uint64(cReal) - } - scaledConst = scaledConstReal - } - - if cImag != 0 { - if cImag < 0 { - scaledConstImag = uint64(int64(qi) + cImag%int64(qi)) - } else { - scaledConstImag = uint64(cImag) - } + if scaledConstImag != 0 { scaledConstImag = ring.MRed(scaledConstImag, ringQ.NttPsi[i][1], qi, mredParams) scaledConst = ring.CRed(scaledConst+scaledConstImag, qi) } @@ -1048,7 +1043,7 @@ func (eval *evaluator) MultByi(ct0 *Ciphertext, ctOut *Ciphertext) { var level = utils.MinInt(ct0.Level(), ctOut.Level()) ctOut.Scale = ct0.Scale - ringQ := eval.ringQ + ringQ := eval.params.RingQ() var imag uint64 @@ -1118,7 +1113,7 @@ func (eval *evaluator) DivByi(ct0 *Ciphertext, ctOut *Ciphertext) { var level = utils.MinInt(ct0.Level(), ctOut.Level()) - ringQ := eval.ringQ + ringQ := eval.params.RingQ() ctOut.Scale = ct0.Scale @@ -1191,20 +1186,11 @@ func (eval *evaluator) ScaleUp(ct0 *Ciphertext, scale float64, ctOut *Ciphertext // SetScale sets the scale of the ciphertext to the input scale (consumes a level) func (eval *evaluator) SetScale(ct *Ciphertext, scale float64) { - - var tmp = eval.params.Scale() - - eval.scale = scale - eval.MultByConst(ct, scale/ct.Scale, ct) - if err := eval.Rescale(ct, scale, ct); err != nil { panic(err) } - ct.Scale = scale - - eval.scale = tmp } // MulByPow2New multiplies ct0 by 2^pow2 and returns the result in a newly created element. @@ -1219,7 +1205,7 @@ func (eval *evaluator) MulByPow2(ct0 *Ciphertext, pow2 int, ctOut *Ciphertext) { var level = utils.MinInt(ct0.Level(), ctOut.Level()) ctOut.Scale = ct0.Scale for i := range ctOut.Value { - eval.ringQ.MulByPow2Lvl(level, ct0.Value[i], pow2, ctOut.Value[i]) + eval.params.RingQ().MulByPow2Lvl(level, ct0.Value[i], pow2, ctOut.Value[i]) } } @@ -1243,7 +1229,7 @@ func (eval *evaluator) Reduce(ct0 *Ciphertext, ctOut *Ciphertext) error { } for i := range ct0.Value { - eval.ringQ.ReduceLvl(utils.MinInt(ct0.Level(), ctOut.Level()), ct0.Value[i], ctOut.Value[i]) + eval.params.RingQ().ReduceLvl(utils.MinInt(ct0.Level(), ctOut.Level()), ct0.Value[i], ctOut.Value[i]) } ctOut.Scale = ct0.Scale @@ -1289,7 +1275,7 @@ func (eval *evaluator) RescaleNew(ct0 *Ciphertext, threshold float64) (ctOut *Ci // Returns an error if "minScale <= 0", ct.Scale = 0, ct.Level() = 0, ct.IsNTT() != true or if ct.Leve() != ctOut.Level() func (eval *evaluator) Rescale(ctIn *Ciphertext, minScale float64, ctOut *Ciphertext) (err error) { - ringQ := eval.ringQ + ringQ := eval.params.RingQ() if minScale <= 0 { return errors.New("cannot Rescale: minScale is 0") @@ -1309,16 +1295,24 @@ func (eval *evaluator) Rescale(ctIn *Ciphertext, minScale float64, ctOut *Cipher ctOut.Scale = ctIn.Scale - var nbRescale int + var nbRescales int // Divides the scale by each moduli of the modulus chain as long as the scale isn't smaller than minScale/2 // or until the output Level() would be zero - for ctOut.Scale/float64(ringQ.Modulus[ctIn.Level()-nbRescale]) >= minScale/2 && ctIn.Level()-nbRescale >= 0 { - ctOut.Scale /= (float64(ringQ.Modulus[ctIn.Level()-nbRescale])) - nbRescale++ + for ctOut.Scale/float64(ringQ.Modulus[ctIn.Level()-nbRescales]) >= minScale/2 && ctIn.Level()-nbRescales >= 0 { + ctOut.Scale /= (float64(ringQ.Modulus[ctIn.Level()-nbRescales])) + nbRescales++ } - for i := range ctOut.Value { - ringQ.DivRoundByLastModulusManyNTT(ctIn.Value[i], ctOut.Value[i], nbRescale) + if nbRescales > 0 { + level := ctIn.Level() + for i := range ctOut.Value { + ringQ.DivRoundByLastModulusManyNTTLvl(level, nbRescales, ctIn.Value[i], eval.poolQMul[0], ctOut.Value[i]) + ctOut.Value[i].Coeffs = ctOut.Value[i].Coeffs[:level+1-nbRescales] + } + } else { + if ctIn != ctOut { + ctOut.Copy(ctIn) + } } return nil @@ -1372,7 +1366,7 @@ func (eval *evaluator) mulRelin(op0, op1 Operand, relin bool, ctOut *Ciphertext) ctOut.Scale = op0.ScalingFactor() * op1.ScalingFactor() - ringQ := eval.ringQ + ringQ := eval.params.RingQ() var c00, c01, c0, c1, c2 *ring.Poly @@ -1420,9 +1414,9 @@ func (eval *evaluator) mulRelin(op0, op1 Operand, relin bool, ctOut *Ciphertext) if relin { c2.IsNTT = true - eval.SwitchKeysInPlace(level, c2, eval.rlk.Keys[0], eval.PoolQ[1], eval.PoolQ[2]) - ringQ.AddLvl(level, c0, eval.PoolQ[1], ctOut.Value[0]) - ringQ.AddLvl(level, c1, eval.PoolQ[2], ctOut.Value[1]) + eval.SwitchKeysInPlace(level, c2, eval.rlk.Keys[0], eval.Pool[1].Q, eval.Pool[2].Q) + ringQ.AddLvl(level, c0, eval.Pool[1].Q, ctOut.Value[0]) + ringQ.AddLvl(level, c1, eval.Pool[2].Q, ctOut.Value[1]) } // Case Plaintext (x) Ciphertext or Ciphertext (x) Plaintext @@ -1465,12 +1459,12 @@ func (eval *evaluator) Relinearize(ct0 *Ciphertext, ctOut *Ciphertext) { ctOut.Scale = ct0.Scale level := utils.MinInt(ct0.Level(), ctOut.Level()) - ringQ := eval.ringQ + ringQ := eval.params.RingQ() - eval.SwitchKeysInPlace(level, ct0.Value[2], eval.rlk.Keys[0], eval.PoolQ[1], eval.PoolQ[2]) + eval.SwitchKeysInPlace(level, ct0.Value[2], eval.rlk.Keys[0], eval.Pool[1].Q, eval.Pool[2].Q) - ringQ.AddLvl(level, ct0.Value[0], eval.PoolQ[1], ctOut.Value[0]) - ringQ.AddLvl(level, ct0.Value[1], eval.PoolQ[2], ctOut.Value[1]) + ringQ.AddLvl(level, ct0.Value[0], eval.Pool[1].Q, ctOut.Value[0]) + ringQ.AddLvl(level, ct0.Value[1], eval.Pool[2].Q, ctOut.Value[1]) ctOut.El().Resize(eval.params.Parameters, 1) } @@ -1494,14 +1488,14 @@ func (eval *evaluator) SwitchKeys(ct0 *Ciphertext, switchingKey *rlwe.SwitchingK } level := utils.MinInt(ct0.Level(), ctOut.Level()) - ringQ := eval.ringQ + ringQ := eval.params.RingQ() ctOut.Scale = ct0.Scale - eval.SwitchKeysInPlace(level, ct0.Value[1], switchingKey, eval.PoolQ[1], eval.PoolQ[2]) + eval.SwitchKeysInPlace(level, ct0.Value[1], switchingKey, eval.Pool[1].Q, eval.Pool[2].Q) - ringQ.AddLvl(level, ct0.Value[0], eval.PoolQ[1], ctOut.Value[0]) - ring.CopyValuesLvl(level, eval.PoolQ[2], ctOut.Value[1]) + ringQ.AddLvl(level, ct0.Value[0], eval.Pool[1].Q, ctOut.Value[0]) + ring.CopyValuesLvl(level, eval.Pool[2].Q, ctOut.Value[1]) } // RotateNew rotates the columns of ct0 by k positions to the left, and returns the result in a newly created element. @@ -1545,6 +1539,10 @@ func (eval *evaluator) ConjugateNew(ct0 *Ciphertext) (ctOut *Ciphertext) { // If the provided element is a Ciphertext, a key-switching operation is necessary and a rotation key for the row rotation needs to be provided. func (eval *evaluator) Conjugate(ct0 *Ciphertext, ctOut *Ciphertext) { + if ct0.Degree() != 1 || ctOut.Degree() != 1 { + panic("input and output Ciphertext must be of degree 1") + } + galEl := eval.params.GaloisElementForRowRotation() ctOut.Scale = ct0.Scale eval.permuteNTT(ct0, galEl, ctOut) @@ -1552,57 +1550,52 @@ func (eval *evaluator) Conjugate(ct0 *Ciphertext, ctOut *Ciphertext) { func (eval *evaluator) permuteNTT(ct0 *Ciphertext, galEl uint64, ctOut *Ciphertext) { - if ct0.Degree() != 1 || ctOut.Degree() != 1 { - panic("input and output Ciphertext must be of degree 1") - } - - rtk, generated := eval.rtks.Keys[galEl] + rtk, generated := eval.rtks.GetRotationKey(galEl) if !generated { panic(fmt.Sprintf("rotation key k=%d not available", eval.params.InverseGaloisElement(galEl))) } level := utils.MinInt(ct0.Level(), ctOut.Level()) index := eval.permuteNTTIndex[galEl] - pool2Q := eval.PoolQ[1] - pool3Q := eval.PoolQ[2] + pool2Q := eval.Pool[1].Q + pool3Q := eval.Pool[2].Q eval.SwitchKeysInPlace(level, ct0.Value[1], rtk, pool2Q, pool3Q) - eval.ringQ.AddLvl(level, pool2Q, ct0.Value[0], pool2Q) + eval.params.RingQ().AddLvl(level, pool2Q, ct0.Value[0], pool2Q) ring.PermuteNTTWithIndexLvl(level, pool2Q, index, ctOut.Value[0]) ring.PermuteNTTWithIndexLvl(level, pool3Q, index, ctOut.Value[1]) } -func (eval *evaluator) rotateHoistedNoModDown(ct0 *Ciphertext, rotations []int, c2QiQDecomp, c2QiPDecomp []*ring.Poly) (cOutQ, cOutP map[int][2]*ring.Poly) { +func (eval *evaluator) rotateHoistedNoModDown(level int, rotations []int, c2DecompQP []rlwe.PolyQP) (cOutQ, cOutP map[int][2]*ring.Poly) { - ringQ := eval.ringQ + ringQ := eval.params.RingQ() + ringP := eval.params.RingP() cOutQ = make(map[int][2]*ring.Poly) cOutP = make(map[int][2]*ring.Poly) - level := ct0.Level() - for _, i := range rotations { if i != 0 { cOutQ[i] = [2]*ring.Poly{ringQ.NewPolyLvl(level), ringQ.NewPolyLvl(level)} - cOutP[i] = [2]*ring.Poly{eval.ringP.NewPoly(), eval.ringP.NewPoly()} + cOutP[i] = [2]*ring.Poly{ringP.NewPoly(), ringP.NewPoly()} - eval.permuteNTTHoistedNoModDown(level, c2QiQDecomp, c2QiPDecomp, i, cOutQ[i][0], cOutQ[i][1], cOutP[i][0], cOutP[i][1]) + eval.PermuteNTTHoistedNoModDown(level, c2DecompQP, i, cOutQ[i][0], cOutQ[i][1], cOutP[i][0], cOutP[i][1]) } } return } -func (eval *evaluator) permuteNTTHoistedNoModDown(level int, c2QiQDecomp, c2QiPDecomp []*ring.Poly, k int, ct0OutQ, ct1OutQ, ct0OutP, ct1OutP *ring.Poly) { +func (eval *evaluator) PermuteNTTHoistedNoModDown(level int, c2DecompQP []rlwe.PolyQP, k int, ct0OutQ, ct1OutQ, ct0OutP, ct1OutP *ring.Poly) { - pool2Q := eval.PoolQ[0] - pool3Q := eval.PoolQ[1] + pool2Q := eval.Pool[0].Q + pool3Q := eval.Pool[1].Q - pool2P := eval.PoolP[0] - pool3P := eval.PoolP[1] + pool2P := eval.Pool[0].P + pool3P := eval.Pool[1].P levelQ := level levelP := eval.params.PCount() - 1 @@ -1616,7 +1609,7 @@ func (eval *evaluator) permuteNTTHoistedNoModDown(level int, c2QiQDecomp, c2QiPD } index := eval.permuteNTTIndex[galEl] - eval.KeyswitchHoistedNoModDown(levelQ, c2QiQDecomp, c2QiPDecomp, rtk, pool2Q, pool3Q, pool2P, pool3P) + eval.KeyswitchHoistedNoModDown(levelQ, c2DecompQP, rtk, pool2Q, pool3Q, pool2P, pool3P) ring.PermuteNTTWithIndexLvl(levelQ, pool2Q, index, ct0OutQ) ring.PermuteNTTWithIndexLvl(levelQ, pool3Q, index, ct1OutQ) @@ -1625,7 +1618,7 @@ func (eval *evaluator) permuteNTTHoistedNoModDown(level int, c2QiQDecomp, c2QiPD ring.PermuteNTTWithIndexLvl(levelP, pool3P, index, ct1OutP) } -func (eval *evaluator) permuteNTTHoisted(level int, c0, c1 *ring.Poly, c2QiQDecomp, c2QiPDecomp []*ring.Poly, k int, cOut0, cOut1 *ring.Poly) { +func (eval *evaluator) PermuteNTTHoisted(level int, c0, c1 *ring.Poly, c2DecompQP []rlwe.PolyQP, k int, cOut0, cOut1 *ring.Poly) { if k == 0 { cOut0.Copy(c0) @@ -1641,15 +1634,15 @@ func (eval *evaluator) permuteNTTHoisted(level int, c0, c1 *ring.Poly, c2QiQDeco index := eval.permuteNTTIndex[galEl] - pool2Q := eval.PoolQ[0] - pool3Q := eval.PoolQ[1] + pool2Q := eval.Pool[0].Q + pool3Q := eval.Pool[1].Q - pool2P := eval.PoolP[0] - pool3P := eval.PoolP[1] + pool2P := eval.Pool[0].P + pool3P := eval.Pool[1].P - eval.KeyswitchHoisted(level, c2QiQDecomp, c2QiPDecomp, rtk, pool2Q, pool3Q, pool2P, pool3P) + eval.KeyswitchHoisted(level, c2DecompQP, rtk, pool2Q, pool3Q, pool2P, pool3P) - eval.ringQ.AddLvl(level, pool2Q, c0, pool2Q) + eval.params.RingQ().AddLvl(level, pool2Q, c0, pool2Q) ring.PermuteNTTWithIndexLvl(level, pool2Q, index, cOut0) ring.PermuteNTTWithIndexLvl(level, pool3Q, index, cOut1) diff --git a/ckks/keys.go b/ckks/keys.go index 46e18260..fd7224dc 100644 --- a/ckks/keys.go +++ b/ckks/keys.go @@ -7,10 +7,6 @@ func NewKeyGenerator(params Parameters) rlwe.KeyGenerator { return rlwe.NewKeyGenerator(params.Parameters) } -// BootstrappingKey is a type for a CKKS bootstrapping key, wich regroups the necessary public relinearization -// and rotation keys (i.e., an EvaluationKey). -type BootstrappingKey rlwe.EvaluationKey - // NewSecretKey returns an allocated CKKS secret key with zero values. func NewSecretKey(params Parameters) (sk *rlwe.SecretKey) { return rlwe.NewSecretKey(params.Parameters) @@ -23,7 +19,7 @@ func NewPublicKey(params Parameters) (pk *rlwe.PublicKey) { // NewSwitchingKey returns an allocated CKKS public switching key with zero values. func NewSwitchingKey(params Parameters) *rlwe.SwitchingKey { - return rlwe.NewSwitchingKey(params.Parameters) + return rlwe.NewSwitchingKey(params.Parameters, params.QCount()-1, params.PCount()-1) } // NewRelinearizationKey returns an allocated CKKS public relinearization key with zero value. diff --git a/ckks/linear_transform.go b/ckks/linear_transform.go index 767bde20..d382238f 100644 --- a/ckks/linear_transform.go +++ b/ckks/linear_transform.go @@ -2,39 +2,69 @@ package ckks import ( "github.com/ldsec/lattigo/v2/ring" + "github.com/ldsec/lattigo/v2/rlwe" "github.com/ldsec/lattigo/v2/utils" ) -// RotateHoisted takes an input Ciphertext and a list of rotations and returns a map of Ciphertext, where each element of the map is the input Ciphertext -// rotation by one element of the list. It is much faster than sequential calls to Rotate. -func (eval *evaluator) RotateHoisted(ctIn *Ciphertext, rotations []int) (cOut map[int]*Ciphertext) { - - level := ctIn.Level() - - eval.DecomposeNTT(level, ctIn.Value[1], eval.PoolDecompQ, eval.PoolDecompP) - - cOut = make(map[int]*Ciphertext) - for _, i := range rotations { - - if i == 0 { - cOut[i] = ctIn.CopyNew() - } else { - cOut[i] = NewCiphertext(eval.params, 1, level, ctIn.Scale) - eval.permuteNTTHoisted(level, ctIn.Value[0], ctIn.Value[1], eval.PoolDecompQ, eval.PoolDecompP, i, cOut[i].Value[0], cOut[i].Value[1]) - } +// Trace maps X -> sum((-1)^i * X^{i*n+1}) for 0 <= i < N +// For log(n) = logSlotStart and log(N/2) = logSlotsEnd +func (eval *evaluator) Trace(ctIn *Ciphertext, logSlotsStart, logSlotsEnd int, ctOut *Ciphertext) { + if ctIn != ctOut { + ctOut.Copy(ctIn) + } else { + ctOut = ctIn } + for i := logSlotsStart; i < logSlotsEnd; i++ { + eval.permuteNTT(ctOut, eval.params.GaloisElementForColumnRotationBy(1< sum((-1)^i * X^{i*n+1}) for 0 <= i < N and returns the result on a new ciphertext. +// For log(n) = logSlotStart and log(N/2) = logSlotsEnd +func (eval *evaluator) TraceNew(ctIn *Ciphertext, logSlotsStart, logSlotsEnd int) (ctOut *Ciphertext) { + ctOut = NewCiphertext(eval.params, 1, ctIn.Level(), ctIn.Scale) + eval.Trace(ctIn, logSlotsStart, logSlotsEnd, ctOut) return } -// LinearTransform evaluates a linear transform on the ciphertext. The linearTransform can either be an (ordered) list of -// PtDiagMatrix or a single PtDiagMatrix. In either case a list of ciphertext is return (the second case returnign a list of +// RotateHoistedNew takes an input Ciphertext and a list of rotations and returns a map of Ciphertext, where each element of the map is the input Ciphertext +// rotation by one element of the list. It is much faster than sequential calls to Rotate. +func (eval *evaluator) RotateHoistedNew(ctIn *Ciphertext, rotations []int) (ctOut map[int]*Ciphertext) { + ctOut = make(map[int]*Ciphertext) + for _, i := range rotations { + ctOut[i] = NewCiphertext(eval.params, 1, ctIn.Level(), ctIn.Scale) + } + eval.RotateHoisted(ctIn, rotations, ctOut) + return +} + +// RotateHoisted takes an input Ciphertext and a list of rotations and populates a map of pre-allocated Ciphertexts, +// where each element of the map is the input Ciphertext rotation by one element of the list. +// It is much faster than sequential calls to Rotate. +func (eval *evaluator) RotateHoisted(ctIn *Ciphertext, rotations []int, ctOut map[int]*Ciphertext) { + levelQ := ctIn.Level() + eval.DecomposeNTT(levelQ, eval.params.PCount()-1, eval.params.PCount(), ctIn.Value[1], eval.PoolDecompQP) + for _, i := range rotations { + if i == 0 { + ctOut[i].Copy(ctIn) + } else { + eval.PermuteNTTHoisted(levelQ, ctIn.Value[0], ctIn.Value[1], eval.PoolDecompQP, i, ctOut[i].Value[0], ctOut[i].Value[1]) + } + } +} + +// LinearTransformNew evaluates a linear transform on the ciphertext and returns the result on a new ciphertext. +// The linearTransform can either be an (ordered) list of PtDiagMatrix or a single PtDiagMatrix. +// In either case a list of ciphertext is return (the second case returnign a list of // containing a single ciphertext. A PtDiagMatrix is a diagonalized plaintext matrix contructed with an Encoder using // the method encoder.EncodeDiagMatrixAtLvl(*). -func (eval *evaluator) LinearTransform(ctIn *Ciphertext, linearTransform interface{}) (ctOut []*Ciphertext) { +func (eval *evaluator) LinearTransformNew(ctIn *Ciphertext, linearTransform interface{}) (ctOut []*Ciphertext) { switch element := linearTransform.(type) { - case []*PtDiagMatrix: + case []PtDiagMatrix: ctOut = make([]*Ciphertext, len(element)) var maxLevel int @@ -44,45 +74,82 @@ func (eval *evaluator) LinearTransform(ctIn *Ciphertext, linearTransform interfa minLevel := utils.MinInt(maxLevel, ctIn.Level()) - eval.DecomposeNTT(minLevel, ctIn.Value[1], eval.PoolDecompQ, eval.PoolDecompP) + eval.DecomposeNTT(minLevel, eval.params.PCount()-1, eval.params.PCount(), ctIn.Value[1], eval.PoolDecompQP) for i, matrix := range element { ctOut[i] = NewCiphertext(eval.params, 1, minLevel, ctIn.Scale) - if matrix.naive { - eval.MultiplyByDiagMatrix(ctIn, matrix, eval.PoolDecompQ, eval.PoolDecompP, ctOut[i]) + if matrix.Naive { + eval.MultiplyByDiagMatrix(ctIn, matrix, eval.PoolDecompQP, ctOut[i]) } else { - eval.MultiplyByDiagMatrixBSGS(ctIn, matrix, eval.PoolDecompQ, eval.PoolDecompP, ctOut[i]) + eval.MultiplyByDiagMatrixBSGS(ctIn, matrix, eval.PoolDecompQP, ctOut[i]) } } - case *PtDiagMatrix: + case PtDiagMatrix: minLevel := utils.MinInt(element.Level, ctIn.Level()) - eval.DecomposeNTT(minLevel, ctIn.Value[1], eval.PoolDecompQ, eval.PoolDecompP) + eval.DecomposeNTT(minLevel, eval.params.PCount()-1, eval.params.PCount(), ctIn.Value[1], eval.PoolDecompQP) ctOut = []*Ciphertext{NewCiphertext(eval.params, 1, minLevel, ctIn.Scale)} - if element.naive { - eval.MultiplyByDiagMatrix(ctIn, element, eval.PoolDecompQ, eval.PoolDecompP, ctOut[0]) + if element.Naive { + eval.MultiplyByDiagMatrix(ctIn, element, eval.PoolDecompQP, ctOut[0]) } else { - eval.MultiplyByDiagMatrixBSGS(ctIn, element, eval.PoolDecompQ, eval.PoolDecompP, ctOut[0]) + eval.MultiplyByDiagMatrixBSGS(ctIn, element, eval.PoolDecompQP, ctOut[0]) } } - return } +// LinearTransformNew evaluates a linear transform on the pre-allocated ciphertexts. +// The linearTransform can either be an (ordered) list of PtDiagMatrix or a single PtDiagMatrix. +// In either case a list of ciphertext is return (the second case returnign a list of +// containing a single ciphertext. A PtDiagMatrix is a diagonalized plaintext matrix contructed with an Encoder using +// the method encoder.EncodeDiagMatrixAtLvl(*). +func (eval *evaluator) LinearTransform(ctIn *Ciphertext, linearTransform interface{}, ctOut []*Ciphertext) { + + switch element := linearTransform.(type) { + case []PtDiagMatrix: + var maxLevel int + for _, matrix := range element { + maxLevel = utils.MaxInt(maxLevel, matrix.Level) + } + + minLevel := utils.MinInt(maxLevel, ctIn.Level()) + + eval.DecomposeNTT(minLevel, eval.params.PCount()-1, eval.params.PCount(), ctIn.Value[1], eval.PoolDecompQP) + + for i, matrix := range element { + if matrix.Naive { + eval.MultiplyByDiagMatrix(ctIn, matrix, eval.PoolDecompQP, ctOut[i]) + } else { + eval.MultiplyByDiagMatrixBSGS(ctIn, matrix, eval.PoolDecompQP, ctOut[i]) + } + } + + case PtDiagMatrix: + minLevel := utils.MinInt(element.Level, ctIn.Level()) + eval.DecomposeNTT(minLevel, eval.params.PCount()-1, eval.params.PCount(), ctIn.Value[1], eval.PoolDecompQP) + if element.Naive { + eval.MultiplyByDiagMatrix(ctIn, element, eval.PoolDecompQP, ctOut[0]) + } else { + eval.MultiplyByDiagMatrixBSGS(ctIn, element, eval.PoolDecompQP, ctOut[0]) + } + } +} + // InnerSumLog applies an optimized inner sum on the ciphetext (log2(n) + HW(n) rotations with double hoisting). // The operation assumes that `ctIn` encrypts SlotCount/`batchSize` sub-vectors of size `batchSize` which it adds together (in parallel) by groups of `n`. // It outputs in ctOut a ciphertext for which the "leftmost" sub-vector of each group is equal to the sum of the group. // This method is faster than InnerSum when the number of rotations is large and uses log2(n) + HW(n) insteadn of 'n' keys. func (eval *evaluator) InnerSumLog(ctIn *Ciphertext, batchSize, n int, ctOut *Ciphertext) { - ringQ := eval.ringQ - ringP := eval.ringP + ringQ := eval.params.RingQ() + ringP := eval.params.RingP() levelQ := ctIn.Level() + levelP := len(ringP.Modulus) - 1 //QiOverF := eval.params.QiOverflowMargin(levelQ) //PiOverF := eval.params.PiOverflowMargin() @@ -100,22 +167,22 @@ func (eval *evaluator) InnerSumLog(ctIn *Ciphertext, batchSize, n int, ctOut *Ci tmpc2 := eval.poolQMul[2] // Accumulator outer loop for ctOut = ctOut + rot(ctIn, k) in QP - ct0OutQ := eval.PoolQ[4] - ct1OutQ := eval.PoolQ[5] - ct0OutP := eval.PoolP[4] - ct1OutP := eval.PoolP[5] + ct0OutQ := eval.Pool[4].Q + ct1OutQ := eval.Pool[5].Q + ct0OutP := eval.Pool[4].P + ct1OutP := eval.Pool[5].P // Memory pool for rot(ctIn, k) - pool2Q := eval.PoolQ[2] // ctOut(c0', c1') from evaluator keyswitch memory pool - pool3Q := eval.PoolQ[3] // ctOut(c0', c1') from evaluator keyswitch memory pool - pool2P := eval.PoolP[2] // ctOut(c0', c1') from evaluator keyswitch memory pool - pool3P := eval.PoolP[3] // ctOut(c0', c1') from evaluator keyswitch memory pool + pool2Q := eval.Pool[2].Q // ctOut(c0', c1') from evaluator keyswitch memory pool + pool3Q := eval.Pool[3].Q // ctOut(c0', c1') from evaluator keyswitch memory pool + pool2P := eval.Pool[2].P // ctOut(c0', c1') from evaluator keyswitch memory pool + pool3P := eval.Pool[3].P // ctOut(c0', c1') from evaluator keyswitch memory pool // Used by the key-switch - // eval.poolQ[0] - // eval.poolQ[1] - // eval.poolP[0] - // eval.poolP[1] + // eval.Pool[0].Q + // eval.Pool[1].Q + // eval.Pool[0].P + // eval.Pool[1].P state := false copy := true @@ -125,11 +192,11 @@ func (eval *evaluator) InnerSumLog(ctIn *Ciphertext, batchSize, n int, ctOut *Ci // Starts by decomposing the input ciphertext if i == 0 { // If first iteration, then copies directly from the input ciphertext that hasn't been rotated - eval.DecomposeNTT(levelQ, ctIn.Value[1], eval.PoolDecompQ, eval.PoolDecompP) + eval.DecomposeNTT(levelQ, levelP, levelP+1, ctIn.Value[1], eval.PoolDecompQP) } else { // Else copies from the rotated input ciphertext tmpc1.IsNTT = true - eval.DecomposeNTT(levelQ, tmpc1, eval.PoolDecompQ, eval.PoolDecompP) + eval.DecomposeNTT(levelQ, levelP, levelP+1, tmpc1, eval.PoolDecompQP) } // If the binary reading scans a 1 @@ -142,7 +209,7 @@ func (eval *evaluator) InnerSumLog(ctIn *Ciphertext, batchSize, n int, ctOut *Ci if k != 0 { // Rotate((tmpc0, tmpc1), k) - eval.permuteNTTHoistedNoModDown(levelQ, eval.PoolDecompQ, eval.PoolDecompP, k, pool2Q, pool3Q, pool2P, pool3P) + eval.PermuteNTTHoistedNoModDown(levelQ, eval.PoolDecompQP, k, pool2Q, pool3Q, pool2P, pool3P) // ctOut += Rotate((tmpc0, tmpc1), k) if copy { @@ -173,8 +240,8 @@ func (eval *evaluator) InnerSumLog(ctIn *Ciphertext, batchSize, n int, ctOut *Ci // if n is not a power of two if n&(n-1) != 0 { - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, ct0OutQ, ct0OutP, ct0OutQ) // Division by P - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, ct1OutQ, ct1OutP, ct1OutQ) // Division by P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, ct0OutQ, ct0OutP, ct0OutQ) // Division by P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, ct1OutQ, ct1OutP, ct1OutQ) // Division by P // ctOut += (tmpc0, tmpc1) ringQ.AddLvl(levelQ, ct0OutQ, tmpc0, ctOut.Value[0]) @@ -191,13 +258,13 @@ func (eval *evaluator) InnerSumLog(ctIn *Ciphertext, batchSize, n int, ctOut *Ci if !state { if i == 0 { - eval.permuteNTTHoisted(levelQ, ctIn.Value[0], ctIn.Value[1], eval.PoolDecompQ, eval.PoolDecompP, (1<> 1 - PiOverF := eval.params.PiOverflowMargin() >> 1 + PiOverF := eval.params.PiOverflowMargin(levelP) >> 1 // If sum with only the first element, then returns the input if n == 1 { @@ -242,15 +310,15 @@ func (eval *evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, ctOut *Ciphe tmpQ0 := eval.poolQMul[0] // unused memory pool from evaluator tmpQ1 := eval.poolQMul[1] // unused memory pool from evaluator - pool2P := eval.PoolP[1] // ctOut(c0', c1') from evaluator keyswitch memory pool - pool3P := eval.PoolP[2] // ctOut(c0', c1') from evaluator keyswitch memory pool + pool2P := eval.Pool[1].P // ctOut(c0', c1') from evaluator keyswitch memory pool + pool3P := eval.Pool[2].P // ctOut(c0', c1') from evaluator keyswitch memory pool // Basis decomposition - eval.DecomposeNTT(levelQ, ctIn.Value[1], eval.PoolDecompQ, eval.PoolDecompP) + eval.DecomposeNTT(levelQ, levelP, levelP+1, ctIn.Value[1], eval.PoolDecompQP) // Pre-rotates all [1, ..., n-1] rotations // Hoisted rotation without division by P - vecRotQ, vecRotP := eval.rotateHoistedNoModDown(ctIn, rotations, eval.PoolDecompQ, eval.PoolDecompP) + vecRotQ, vecRotP := eval.rotateHoistedNoModDown(levelQ, rotations, eval.PoolDecompQP) // P*c0 -> tmpQ0 ringQ.MulScalarBigintLvl(levelQ, ctIn.Value[0], ringP.ModulusBigint, tmpQ0) @@ -283,13 +351,13 @@ func (eval *evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, ctOut *Ciphe if i == 1 { ring.CopyValuesLvl(levelQ, vecRotQ[j][0], tmpQ0) ring.CopyValuesLvl(levelQ, vecRotQ[j][1], tmpQ1) - ring.CopyValues(vecRotP[j][0], pool2P) - ring.CopyValues(vecRotP[j][1], pool3P) + ring.CopyValuesLvl(levelP, vecRotP[j][0], pool2P) + ring.CopyValuesLvl(levelP, vecRotP[j][1], pool3P) } else { ringQ.AddNoModLvl(levelQ, tmpQ0, vecRotQ[j][0], tmpQ0) ringQ.AddNoModLvl(levelQ, tmpQ1, vecRotQ[j][1], tmpQ1) - ringP.AddNoMod(pool2P, vecRotP[j][0], pool2P) - ringP.AddNoMod(pool3P, vecRotP[j][1], pool3P) + ringP.AddNoModLvl(levelP, pool2P, vecRotP[j][0], pool2P) + ringP.AddNoModLvl(levelP, pool3P, vecRotP[j][1], pool3P) } if reduce%QiOverF == QiOverF-1 { @@ -298,8 +366,8 @@ func (eval *evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, ctOut *Ciphe } if reduce%PiOverF == PiOverF-1 { - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + ringP.ReduceLvl(levelP, pool2P, pool2P) + ringP.ReduceLvl(levelP, pool3P, pool3P) } reduce++ @@ -311,13 +379,13 @@ func (eval *evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, ctOut *Ciphe } if reduce%PiOverF != 0 { - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + ringP.ReduceLvl(levelP, pool2P, pool2P) + ringP.ReduceLvl(levelP, pool3P, pool3P) } // Division by P of sum(elements [2, ..., n-1] ) - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ0, pool2P, tmpQ0) // sum_{i=1, n-1}(phi(d0))/P - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ1, pool3P, tmpQ1) // sum_{i=1, n-1}(phi(d1))/P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, tmpQ0, pool2P, tmpQ0) // sum_{i=1, n-1}(phi(d0))/P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, tmpQ1, pool3P, tmpQ1) // sum_{i=1, n-1}(phi(d1))/P // Adds element[1] (which did not require rotation) ringQ.AddLvl(levelQ, ctIn.Value[0], tmpQ0, ctOut.Value[0]) // sum_{i=1, n-1}(phi(d0))/P + ct0 @@ -352,31 +420,40 @@ func (eval *evaluator) Replicate(ctIn *Ciphertext, batchSize, n int, ctOut *Ciph // respectively, each of size params.Beta(). // The naive approach is used (single hoisting and no baby-step giant-step), which is faster than MultiplyByDiagMatrixBSGS // for matrix of only a few non-zero diagonals but uses more keys. -func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatrix, PoolDecompQ, PoolDecompP []*ring.Poly, ctOut *Ciphertext) { +func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix PtDiagMatrix, PoolDecompQP []rlwe.PolyQP, ctOut *Ciphertext) { - ringQ := eval.ringQ - ringP := eval.ringP + ringQ := eval.params.RingQ() + ringP := eval.params.RingP() levelQ := utils.MinInt(ctOut.Level(), utils.MinInt(ctIn.Level(), matrix.Level)) - levelP := eval.params.PCount() - 1 + levelP := len(ringP.Modulus) - 1 QiOverF := eval.params.QiOverflowMargin(levelQ) - PiOverF := eval.params.PiOverflowMargin() + PiOverF := eval.params.PiOverflowMargin(levelP) - ksResP0 := eval.PoolP[0] // Key-Switch ctOut[0] mod P - ksResP1 := eval.PoolP[1] // Key-Switch ctOut[1] mod P - tmpP0 := eval.PoolP[2] // Automorphism not-inplace pool res[0] mod P + ksResP0 := eval.Pool[0].P // Key-Switch ctOut[0] mod P + ksResP1 := eval.Pool[1].P // Key-Switch ctOut[1] mod P + tmpP0 := eval.Pool[2].P // Automorphism not-inplace pool res[0] mod P tmpP1 := eval.poolQMul[0] // Automorphism not-inplace pool res[1] mod P - accP0 := eval.PoolP[3] // Accumulator ctOut[0] mod P - accP1 := eval.PoolP[4] // Accumulator ctOut[1] mod P + accP0 := eval.Pool[3].P // Accumulator ctOut[0] mod P + accP1 := eval.Pool[4].P // Accumulator ctOut[1] mod P - ct0TimesP := eval.PoolQ[0] // ct0 * P mod Q - ksResQ0 := eval.PoolQ[1] // Key-Switch ctOut[0] mod Q - ksResQ1 := eval.PoolQ[2] // Key-Switch ctOut[0] mod Q - tmpQ0 := eval.PoolQ[3] // Automorphism not-inplace pool ctOut[0] mod Q - tmpQ1 := eval.PoolQ[4] // Automorphism not-inplace pool ctOut[1] mod Q + ct0TimesP := eval.Pool[0].Q // ct0 * P mod Q + ksResQ0 := eval.Pool[1].Q // Key-Switch ctOut[0] mod Q + ksResQ1 := eval.Pool[2].Q // Key-Switch ctOut[0] mod Q + tmpQ0 := eval.Pool[3].Q // Automorphism not-inplace pool ctOut[0] mod Q + tmpQ1 := eval.Pool[4].Q // Automorphism not-inplace pool ctOut[1] mod Q - ringQ.MulScalarBigintLvl(levelQ, ctIn.Value[0], ringP.ModulusBigint, ct0TimesP) // P*c0 + var ctInTmp0, ctInTmp1 *ring.Poly + if ctIn != ctOut { + ring.CopyValuesLvl(levelQ, ctIn.Value[0], eval.ctxpool.Value[0]) + ring.CopyValuesLvl(levelQ, ctIn.Value[1], eval.ctxpool.Value[1]) + ctInTmp0, ctInTmp1 = eval.ctxpool.Value[0], eval.ctxpool.Value[1] + } else { + ctInTmp0, ctInTmp1 = ctIn.Value[0], ctIn.Value[1] + } + + ringQ.MulScalarBigintLvl(levelQ, ctInTmp0, ringP.ModulusBigint, ct0TimesP) // P*c0 var state bool var cnt int @@ -397,7 +474,7 @@ func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatr index := eval.permuteNTTIndex[galEl] - eval.KeyswitchHoistedNoModDown(levelQ, PoolDecompQ, PoolDecompP, rtk, ksResQ0, ksResQ1, ksResP0, ksResP1) + eval.KeyswitchHoistedNoModDown(levelQ, PoolDecompQP, rtk, ksResQ0, ksResQ1, ksResP0, ksResP1) ringQ.AddLvl(levelQ, ksResQ0, ct0TimesP, ksResQ0) // phi(d0_Q) += phi(P*c0) @@ -407,21 +484,21 @@ func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatr ring.PermuteNTTWithIndexLvl(levelP, ksResP0, index, tmpP0) // phi(P*c0 + d0_P) ring.PermuteNTTWithIndexLvl(levelP, ksResP1, index, tmpP1) // phi( d1_P) - plaintextQ := matrix.Vec[k][0] - plaintextP := matrix.Vec[k][1] + plaintextQ := matrix.Vec[k].Q + plaintextP := matrix.Vec[k].P if cnt == 0 { // keyswitch(c1_Q) = (d0_QP, d1_QP) ringQ.MulCoeffsMontgomeryLvl(levelQ, plaintextQ, tmpQ0, ctOut.Value[0]) // phi(P*c0 + d0_Q) * plaintext ringQ.MulCoeffsMontgomeryLvl(levelQ, plaintextQ, tmpQ1, ctOut.Value[1]) // phi(d1_Q) * plaintext - ringP.MulCoeffsMontgomery(plaintextP, tmpP0, accP0) // phi(d0_P) * plaintext - ringP.MulCoeffsMontgomery(plaintextP, tmpP1, accP1) // phi(d1_P) * plaintext + ringP.MulCoeffsMontgomeryLvl(levelP, plaintextP, tmpP0, accP0) // phi(d0_P) * plaintext + ringP.MulCoeffsMontgomeryLvl(levelP, plaintextP, tmpP1, accP1) // phi(d1_P) * plaintext } else { // keyswitch(c1_Q) = (d0_QP, d1_QP) ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, plaintextQ, tmpQ0, ctOut.Value[0]) // phi(P*c0 + d0_Q) * plaintext ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, plaintextQ, tmpQ1, ctOut.Value[1]) // phi(d1_Q) * plaintext - ringP.MulCoeffsMontgomeryAndAdd(plaintextP, tmpP0, accP0) // phi(d0_P) * plaintext - ringP.MulCoeffsMontgomeryAndAdd(plaintextP, tmpP1, accP1) // phi(d1_P) * plaintext + ringP.MulCoeffsMontgomeryAndAddLvl(levelP, plaintextP, tmpP0, accP0) // phi(d0_P) * plaintext + ringP.MulCoeffsMontgomeryAndAddLvl(levelP, plaintextP, tmpP1, accP1) // phi(d1_P) * plaintext } if cnt%QiOverF == QiOverF-1 { @@ -430,8 +507,8 @@ func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatr } if cnt%PiOverF == PiOverF-1 { - ringP.Reduce(accP0, accP0) - ringP.Reduce(accP1, accP1) + ringP.ReduceLvl(levelP, accP0, accP0) + ringP.ReduceLvl(levelP, accP1, accP1) } cnt++ @@ -444,16 +521,16 @@ func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatr } if cnt%PiOverF == 0 { - ringP.Reduce(accP0, accP0) - ringP.Reduce(accP1, accP1) + ringP.ReduceLvl(levelP, accP0, accP0) + ringP.ReduceLvl(levelP, accP1, accP1) } - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, ctOut.Value[0], accP0, ctOut.Value[0]) // sum(phi(c0 * P + d0_QP))/P - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, ctOut.Value[1], accP1, ctOut.Value[1]) // sum(phi(d1_QP))/P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, ctOut.Value[0], accP0, ctOut.Value[0]) // sum(phi(c0 * P + d0_QP))/P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, ctOut.Value[1], accP1, ctOut.Value[1]) // sum(phi(d1_QP))/P if state { // Rotation by zero - ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0][0], ctIn.Value[0], ctOut.Value[0]) // ctOut += c0_Q * plaintext - ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0][0], ctIn.Value[1], ctOut.Value[1]) // ctOut += c1_Q * plaintext + ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0].Q, ctInTmp0, ctOut.Value[0]) // ctOut += c0_Q * plaintext + ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0].Q, ctInTmp1, ctOut.Value[1]) // ctOut += c1_Q * plaintext } ctOut.Scale = matrix.Scale * ctIn.Scale @@ -464,26 +541,35 @@ func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatr // respectively, each of size params.Beta(). // The BSGS approach is used (double hoisting with baby-step giant-step), which is faster than MultiplyByDiagMatrix // for matrix with more than a few non-zero diagonals and uses much less keys. -func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiagMatrix, PoolDecompQ, PoolDecompP []*ring.Poly, ctOut *Ciphertext) { +func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix PtDiagMatrix, PoolDecompQP []rlwe.PolyQP, ctOut *Ciphertext) { // N1*N2 = N N1 := matrix.N1 - ringQ := eval.ringQ - ringP := eval.ringP + ringQ := eval.params.RingQ() + ringP := eval.params.RingP() levelQ := utils.MinInt(ctOut.Level(), utils.MinInt(ctIn.Level(), matrix.Level)) - levelP := eval.params.PCount() - 1 + levelP := len(ringP.Modulus) - 1 QiOverF := eval.params.QiOverflowMargin(levelQ) - PiOverF := eval.params.PiOverflowMargin() + PiOverF := eval.params.PiOverflowMargin(levelP) // Computes the rotations indexes of the non-zero rows of the diagonalized DFT matrix for the baby-step giang-step algorithm index, rotations := bsgsIndex(matrix.Vec, 1<>1) == (PiOverF>>1)-1 { - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + ringP.ReduceLvl(levelP, pool2P, pool2P) + ringP.ReduceLvl(levelP, pool3P, pool3P) } cnt1++ @@ -580,24 +666,24 @@ func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiag } if cnt1%(PiOverF>>1) != 0 { - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + ringP.ReduceLvl(levelP, pool2P, pool2P) + ringP.ReduceLvl(levelP, pool3P, pool3P) } // Hoisting of the ModDown of sum(sum(phi(d0 + P*c0) * plaintext)) and sum(sum(phi(d1) * plaintext)) - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ0, pool2P, tmpQ0) // sum(phi(d0) * plaintext)/P - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ1, pool3P, tmpQ1) // sum(phi(d1) * plaintext)/P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, tmpQ0, pool2P, tmpQ0) // sum(phi(d0) * plaintext)/P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, tmpQ1, pool3P, tmpQ1) // sum(phi(d1) * plaintext)/P // If i == 0 if state { // If no loop before, then we copy the values on the accumulator instead of adding them if len(index[j]) == 1 { - ringQ.MulCoeffsMontgomeryLvl(levelQ, matrix.Vec[N1*j][0], ctIn.Value[0], tmpQ0) // c0 * plaintext + sum(phi(d0) * plaintext)/P + phi(c0) * plaintext mod Q - ringQ.MulCoeffsMontgomeryLvl(levelQ, matrix.Vec[N1*j][0], ctIn.Value[1], tmpQ1) // c1 * plaintext + sum(phi(d1) * plaintext)/P + phi(c1) * plaintext mod Q + ringQ.MulCoeffsMontgomeryLvl(levelQ, matrix.Vec[N1*j].Q, ctInTmp0, tmpQ0) // c0 * plaintext + sum(phi(d0) * plaintext)/P + phi(c0) * plaintext mod Q + ringQ.MulCoeffsMontgomeryLvl(levelQ, matrix.Vec[N1*j].Q, ctInTmp1, tmpQ1) // c1 * plaintext + sum(phi(d1) * plaintext)/P + phi(c1) * plaintext mod Q } else { - ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[N1*j][0], ctIn.Value[0], tmpQ0) // c0 * plaintext + sum(phi(d0) * plaintext)/P + phi(c0) * plaintext mod Q - ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[N1*j][0], ctIn.Value[1], tmpQ1) // c1 * plaintext + sum(phi(d1) * plaintext)/P + phi(c1) * plaintext mod Q + ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[N1*j].Q, ctInTmp0, tmpQ0) // c0 * plaintext + sum(phi(d0) * plaintext)/P + phi(c0) * plaintext mod Q + ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[N1*j].Q, ctInTmp1, tmpQ1) // c1 * plaintext + sum(phi(d1) * plaintext)/P + phi(c1) * plaintext mod Q } N1Rot++ @@ -639,8 +725,8 @@ func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiag } if cnt0%PiOverF == PiOverF-1 { - ringP.Reduce(tmpP2, tmpP2) - ringP.Reduce(tmpP3, tmpP3) + ringP.ReduceLvl(levelP, tmpP2, tmpP2) + ringP.ReduceLvl(levelP, tmpP3, tmpP3) } cnt0++ @@ -653,8 +739,8 @@ func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiag } if cnt0%PiOverF != 0 { - ringP.Reduce(tmpP2, tmpP2) - ringP.Reduce(tmpP3, tmpP3) + ringP.ReduceLvl(levelP, tmpP2, tmpP2) + ringP.ReduceLvl(levelP, tmpP3, tmpP3) } // if j == 0 (N2 rotation by zero) @@ -666,14 +752,14 @@ func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiag state = true } else { - plaintextQ := matrix.Vec[i][0] - plaintextP := matrix.Vec[i][1] + plaintextQ := matrix.Vec[i].Q + plaintextP := matrix.Vec[i].P N1Rot++ // keyswitch(c1_Q) = (d0_QP, d1_QP) ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, plaintextQ, vecRotQ[i][0], tmpQ2) // phi(P*c0 + d0_Q) * plaintext ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, plaintextQ, vecRotQ[i][1], tmpQ3) // phi(d1_Q) * plaintext - ringP.MulCoeffsMontgomeryAndAddNoMod(plaintextP, vecRotP[i][0], tmpP2) // phi(d0_P) * plaintext - ringP.MulCoeffsMontgomeryAndAddNoMod(plaintextP, vecRotP[i][1], tmpP3) // phi(d1_P) * plaintext + ringP.MulCoeffsMontgomeryAndAddNoModLvl(levelP, plaintextP, vecRotP[i][0], tmpP2) // phi(d0_P) * plaintext + ringP.MulCoeffsMontgomeryAndAddNoModLvl(levelP, plaintextP, vecRotP[i][1], tmpP3) // phi(d1_P) * plaintext if cnt1%(QiOverF>>1) == (QiOverF>>1)-1 { ringQ.ReduceLvl(levelQ, tmpQ2, tmpQ2) @@ -681,8 +767,8 @@ func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiag } if cnt1%(PiOverF>>1) == (PiOverF>>1)-1 { - ringP.Reduce(tmpP2, tmpP2) - ringP.Reduce(tmpP3, tmpP3) + ringP.ReduceLvl(levelP, tmpP2, tmpP2) + ringP.ReduceLvl(levelP, tmpP3, tmpP3) } cnt1++ @@ -695,20 +781,20 @@ func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiag } if cnt1%(PiOverF>>1) != 0 { - ringP.Reduce(tmpP2, tmpP2) - ringP.Reduce(tmpP3, tmpP3) + ringP.ReduceLvl(levelP, tmpP2, tmpP2) + ringP.ReduceLvl(levelP, tmpP3, tmpP3) } - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ2, tmpP2, tmpQ2) // sum(phi(c0 * P + d0_QP))/P - eval.Baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ3, tmpP3, tmpQ3) // sum(phi(d1_QP))/P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, tmpQ2, tmpP2, tmpQ2) // sum(phi(c0 * P + d0_QP))/P + eval.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, tmpQ3, tmpP3, tmpQ3) // sum(phi(d1_QP))/P ringQ.AddLvl(levelQ, ctOut.Value[0], tmpQ2, ctOut.Value[0]) // ctOut += sum(phi(c0 * P + d0_QP))/P ringQ.AddLvl(levelQ, ctOut.Value[1], tmpQ3, ctOut.Value[1]) // ctOut += sum(phi(d1_QP))/P if state { // Rotation by zero N1Rot++ - ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0][0], ctIn.Value[0], ctOut.Value[0]) // ctOut += c0_Q * plaintext - ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0][0], ctIn.Value[1], ctOut.Value[1]) // ctOut += c1_Q * plaintext + ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0].Q, ctInTmp0, ctOut.Value[0]) // ctOut += c0_Q * plaintext + ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0].Q, ctInTmp1, ctOut.Value[1]) // ctOut += c1_Q * plaintext } ctOut.Scale = matrix.Scale * ctIn.Scale diff --git a/ckks/marshaler.go b/ckks/marshaler.go index 8e231c38..adcd6ac5 100644 --- a/ckks/marshaler.go +++ b/ckks/marshaler.go @@ -3,10 +3,9 @@ package ckks import ( "encoding/binary" "errors" - "math" - "github.com/ldsec/lattigo/v2/ring" "github.com/ldsec/lattigo/v2/rlwe" + "math" ) // GetDataLen returns the length in bytes of the target Ciphertext. diff --git a/ckks/params.go b/ckks/params.go index 6f8a296d..81fa24ce 100644 --- a/ckks/params.go +++ b/ckks/params.go @@ -289,15 +289,12 @@ func (p Parameters) RotationsForReplicateLog(batch, n int) (rotations []int) { return p.RotationsForInnerSumLog(-batch, n) } -// RotationsForSubSum generates the rotations that will be performed by the +// RotationsForTrace generates the rotations that will be performed by the // `Evaluator.SubSum` operation. -func (p Parameters) RotationsForSubSum(logSlots int) (rotations []int) { +func (p Parameters) RotationsForTrace(logSlotsStart, logSlotsEnd int) (rotations []int) { rotations = []int{} - - logN := p.LogN() - //SubSum rotation needed X -> Y^slots rotations - for i := logSlots; i < logN-1; i++ { + for i := logSlotsStart; i < logSlotsEnd; i++ { if !utils.IsInSliceInt(1< cannot evaluate", levels, logDegree) + if levels < depth { + return fmt.Errorf("%d levels < %d log(d) -> cannot evaluate", levels, depth) } return nil } -// Degree returns the degree of the polynomial -func (p *Poly) Degree() int { - return len(p.coeffs) - 1 -} - // EvaluatePoly evaluates a polynomial in standard basis on the input Ciphertext in ceil(log2(deg+1)) levels. // Returns an error if the input ciphertext does not have enough level to carry out the full polynomial evaluation. // Returns an error if something is wrong with the scale. - -func (eval *evaluator) EvaluatePoly(ct0 *Ciphertext, pol *Poly, targetScale float64) (opOut *Ciphertext, err error) { +// If the polynomial is given in Chebyshev basis, then a change of basis ct' = (2/(b-a)) * (ct + (-a-b)/(b-a)) +// is necessary before the polynomial evaluation to ensure correctness. +// Coefficients of the polynomial with an absolute value smaller than "IsNegligbleThreshold" will automatically be set to zero +// if the polynomial is "even" or "odd" (to ensure that the even or odd property remains valid +// after the "splitCoeffs" polynomial decomposition). +func (eval *evaluator) EvaluatePoly(ct0 *Ciphertext, pol *Polynomial, targetScale float64) (opOut *Ciphertext, err error) { if err := checkEnoughLevels(ct0.Level(), pol, 1); err != nil { return ct0, err @@ -66,212 +93,133 @@ func (eval *evaluator) EvaluatePoly(ct0 *Ciphertext, pol *Poly, targetScale floa logDegree := bits.Len64(uint64(pol.Degree())) logSplit := (logDegree >> 1) //optimalSplit(logDegree) // + odd, even := isOddOrEvenPolynomial(pol.Coeffs) + for i := 2; i < (1 << logSplit); i++ { - if err = computePowerBasis(i, C, eval); err != nil { - return nil, err + if !(even || odd) || (i&1 == 0 && even) || (i&1 == 1 && odd) { + if err = computePowerBasis(i, C, targetScale, pol.Basis, eval); err != nil { + return nil, err + } } } for i := logSplit; i < logDegree; i++ { - if err = computePowerBasis(1<> 1) //optimalSplit(logDegree) // - - for i := 2; i < (1 << logSplit); i++ { - if err = computePowerBasisCheby(i, C, eval); err != nil { - return nil, err - } - } - - for i := logSplit; i < logDegree; i++ { - if err = computePowerBasisCheby(1<> 1 + var a, b, c int + if n&(n-1) == 0 { + a, b = n/2, n/2 //Necessary for optimal depth + } else { + // [Lee et al. 2020] : High-Precision and Low-Complexity Approximate Homomorphic Encryption by Error Variance Minimization + // Maximize the number of odd terms of Chebyshev basis + k := int(math.Ceil(math.Log2(float64(n)))) - 1 + a = (1 << k) - 1 + b = n + 1 - (1 << k) - // Recurses on the given indexes - if err = computePowerBasis(a, C, evaluator); err != nil { - return err - } - if err = computePowerBasis(b, C, evaluator); err != nil { - return err - } - - // Computes C[n] = C[a]*C[b] - C[n] = evaluator.MulRelinNew(C[a], C[b]) - - if err = evaluator.Rescale(C[n], evaluator.scale, C[n]); err != nil { - return err - } - } - - return nil -} - -func computePowerBasisCheby(n int, C map[int]*Ciphertext, evaluator *evaluator) (err error) { - - // Given a hash table with the first three evaluations of the Chebyshev ring at x in the interval a, b: - // C0 = 1 (actually not stored in the hash table) - // C1 = (2*x - a - b)/(b-a) - // C2 = 2*C1*C1 - C0 - // Evaluates the nth degree Chebyshev ring in a recursive manner, storing intermediate results in the hashtable. - // Consumes at most ceil(sqrt(n)) levels for an evaluation at Cn. - // Uses the following property: for a given Chebyshev ring Cn = 2*Ca*Cb - Cc, n = a+b and c = abs(a-b) - - if C[n] == nil { - - // Computes the index required to compute the asked ring evaluation - a := int(math.Ceil(float64(n) / 2)) - b := n >> 1 - c := int(math.Abs(float64(a) - float64(b))) - - // Recurses on the given indexes - if err = computePowerBasisCheby(a, C, evaluator); err != nil { - return err - } - if err = computePowerBasisCheby(b, C, evaluator); err != nil { - return err - } - - // Since C[0] is not stored (but rather seen as the constant 1), only recurses on c if c!= 0 - if c != 0 { - if err = computePowerBasisCheby(c, C, evaluator); err != nil { - return err + if basis == ChebyshevBasis { + c = int(math.Abs(float64(a) - float64(b))) // Cn = 2*Ca*Cb - Cc, n = a+b and c = abs(a-b) } } - - // Computes C[n] = C[a]*C[b] - //fmt.Println("Mul", C[a].Level(), C[b].Level()) - C[n] = evaluator.MulRelinNew(C[a], C[b]) - if err = evaluator.Rescale(C[n], evaluator.scale, C[n]); err != nil { + // Recurses on the given indexes + if err = computePowerBasis(a, C, scale, basis, evaluator); err != nil { + return err + } + if err = computePowerBasis(b, C, scale, basis, evaluator); err != nil { return err } - // Computes C[n] = 2*C[a]*C[b] - evaluator.Add(C[n], C[n], C[n]) + // Computes C[n] = C[a]*C[b] + C[n] = evaluator.MulRelinNew(C[a], C[b]) - // Computes C[n] = 2*C[a]*C[b] - C[c] - if c == 0 { - evaluator.AddConst(C[n], -1, C[n]) - } else { - evaluator.Sub(C[n], C[c], C[n]) + if err = evaluator.Rescale(C[n], scale, C[n]); err != nil { + return err } + if basis == ChebyshevBasis { + + // Computes C[n] = 2*C[a]*C[b] + evaluator.Add(C[n], C[n], C[n]) + + // Computes C[n] = 2*C[a]*C[b] - C[c] + if c == 0 { + evaluator.AddConst(C[n], -1, C[n]) + } else { + // Since C[0] is not stored (but rather seen as the constant 1), only recurses on c if c!= 0 + if err = computePowerBasis(c, C, scale, basis, evaluator); err != nil { + return err + } + evaluator.Sub(C[n], C[c], C[n]) + } + } } return nil } -func splitCoeffs(coeffs *Poly, split int) (coeffsq, coeffsr *Poly) { +func splitCoeffs(coeffs *Polynomial, split int) (coeffsq, coeffsr *Polynomial) { // Splits a polynomial p such that p = q*C^degree + r. - coeffsr = new(Poly) - coeffsr.coeffs = make([]complex128, split) - if coeffs.maxDeg == coeffs.Degree() { - coeffsr.maxDeg = split - 1 + coeffsr = new(Polynomial) + coeffsr.Coeffs = make([]complex128, split) + if coeffs.MaxDeg == coeffs.Degree() { + coeffsr.MaxDeg = split - 1 } else { - coeffsr.maxDeg = coeffs.maxDeg - (coeffs.Degree() - split + 1) + coeffsr.MaxDeg = coeffs.MaxDeg - (coeffs.Degree() - split + 1) } for i := 0; i < split; i++ { - coeffsr.coeffs[i] = coeffs.coeffs[i] + coeffsr.Coeffs[i] = coeffs.Coeffs[i] } - coeffsq = new(Poly) - coeffsq.coeffs = make([]complex128, coeffs.Degree()-split+1) - coeffsq.maxDeg = coeffs.maxDeg + coeffsq = new(Polynomial) + coeffsq.Coeffs = make([]complex128, coeffs.Degree()-split+1) + coeffsq.MaxDeg = coeffs.MaxDeg - coeffsq.coeffs[0] = coeffs.coeffs[split] - for i := split + 1; i < coeffs.Degree()+1; i++ { - coeffsq.coeffs[i-split] = coeffs.coeffs[i] + coeffsq.Coeffs[0] = coeffs.Coeffs[split] + + if coeffs.Basis == StandardBasis { + for i := split + 1; i < coeffs.Degree()+1; i++ { + coeffsq.Coeffs[i-split] = coeffs.Coeffs[i] + } + } else if coeffs.Basis == ChebyshevBasis { + for i, j := split+1, 1; i < coeffs.Degree()+1; i, j = i+1, j+1 { + coeffsq.Coeffs[i-split] = 2 * coeffs.Coeffs[i] + coeffsr.Coeffs[split-j] -= coeffs.Coeffs[i] + } } - if coeffs.lead { - coeffsq.lead = true + if coeffs.Lead { + coeffsq.Lead = true } + coeffsq.Basis, coeffsr.Basis = coeffs.Basis, coeffs.Basis + return coeffsq, coeffsr } -func splitCoeffsCheby(coeffs *Poly, split int) (coeffsq, coeffsr *Poly) { +func recurse(targetScale float64, logSplit, logDegree int, coeffs *Polynomial, C map[int]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) { - // Splits a Chebyshev polynomial p such that p = q*C^degree + r, where q and r are a linear combination of a Chebyshev basis. - coeffsr = new(Poly) - coeffsr.coeffs = make([]complex128, split) - if coeffs.maxDeg == coeffs.Degree() { - coeffsr.maxDeg = split - 1 - } else { - coeffsr.maxDeg = coeffs.maxDeg - (coeffs.Degree() - split + 1) - } - - for i := 0; i < split; i++ { - coeffsr.coeffs[i] = coeffs.coeffs[i] - } - - coeffsq = new(Poly) - coeffsq.coeffs = make([]complex128, coeffs.Degree()-split+1) - coeffsq.maxDeg = coeffs.maxDeg - - coeffsq.coeffs[0] = coeffs.coeffs[split] - for i, j := split+1, 1; i < coeffs.Degree()+1; i, j = i+1, j+1 { - coeffsq.coeffs[i-split] = 2 * coeffs.coeffs[i] - coeffsr.coeffs[split-j] -= coeffs.coeffs[i] - } - - if coeffs.lead { - coeffsq.lead = true - } - - return coeffsq, coeffsr -} - -func recurse(targetScale float64, logSplit, logDegree int, coeffs *Poly, C map[int]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) { - - // Recursively computes the evalution of the Chebyshev polynomial using a baby-set giant-step algorithm. + // Recursively computes the evaluation of the Chebyshev polynomial using a baby-set giant-step algorithm. if coeffs.Degree() < (1 << logSplit) { - if coeffs.lead && coeffs.maxDeg > ((1< 1 { + if coeffs.Lead && logSplit > 1 && coeffs.MaxDeg%(1<<(logSplit+1)) > (1<<(logSplit-1)) { logDegree = int(bits.Len64(uint64(coeffs.Degree()))) logSplit = logDegree >> 1 @@ -279,7 +227,7 @@ func recurse(targetScale float64, logSplit, logDegree int, coeffs *Poly, C map[i return recurse(targetScale, logSplit, logDegree, coeffs, C, evaluator) } - return evaluatePolyFromPowerBasis(targetScale, coeffs, C, evaluator) + return evaluatePolyFromPowerBasis(targetScale, coeffs, logSplit, C, evaluator) } var nextPower = 1 << logSplit @@ -291,21 +239,17 @@ func recurse(targetScale float64, logSplit, logDegree int, coeffs *Poly, C map[i level := C[nextPower].Level() - 1 - if coeffsq.maxDeg >= 1<<(logDegree-1) && coeffsq.lead { + if coeffsq.MaxDeg >= 1<<(logDegree-1) && coeffsq.Lead { level++ } - currentQi := float64(evaluator.params.Q()[level]) + currentQi := evaluator.params.QiFloat64(level) - //fmt.Printf("X^%2d: %d %d %t %d\n", nextPower, coeffsq.maxDeg, coeffsr.maxDeg, coeffsq.maxDeg >= 1<<(logDegree-1), level) - //fmt.Printf("X^%2d: %f %f\n", nextPower, targetScale, targetScale* currentQi / C[nextPower].Scale()) - //fmt.Printf("X^%2d : qi %d %t %d %d\n", nextPower, level, coeffsq.lead, coeffsq.maxDeg, 1<<(logDegree-1)) - //fmt.Println() - var tmp *Ciphertext if res, err = recurse(targetScale*currentQi/C[nextPower].Scale, logSplit, logDegree, coeffsq, C, evaluator); err != nil { return nil, err } + var tmp *Ciphertext if tmp, err = recurse(targetScale, logSplit, logDegree, coeffsr, C, evaluator); err != nil { return nil, err } @@ -316,22 +260,18 @@ func recurse(targetScale float64, logSplit, logDegree int, coeffs *Poly, C map[i } } - //fmt.Printf("X^%2d: (%d %f -> \n", nextPower, res.Level(), res.Scale()) evaluator.MulRelin(res, C[nextPower], res) if res.Level() > tmp.Level() { - if err = evaluator.Rescale(res, evaluator.scale, res); err != nil { + if err = evaluator.Rescale(res, targetScale, res); err != nil { return nil, err } - //fmt.Printf("%f = %d) + (%d %f) = ", res.Scale(), res.Level(), tmp.Level(), tmp.Scale()) evaluator.Add(res, tmp, res) - //fmt.Printf("(%d %f) %f\n", res.Level(), res.Scale(), res.Scale()-tmp.Scale()) } else { evaluator.Add(res, tmp, res) - if err = evaluator.Rescale(res, evaluator.scale, res); err != nil { + if err = evaluator.Rescale(res, targetScale, res); err != nil { return nil, err } - } tmp = nil @@ -339,124 +279,96 @@ func recurse(targetScale float64, logSplit, logDegree int, coeffs *Poly, C map[i return } -func recurseCheby(targetScale float64, logSplit, logDegree int, coeffs *Poly, C map[int]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) { +func evaluatePolyFromPowerBasis(targetScale float64, coeffs *Polynomial, logSplit int, C map[int]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) { - // Recursively computes the evalution of the Chebyshev polynomial using a baby-set giant-step algorithm. - if coeffs.Degree() < (1 << logSplit) { + minimumDegreeNonZeroCoefficient := 0 - if coeffs.lead && coeffs.maxDeg > ((1< 1 { - - logDegree = int(bits.Len64(uint64(coeffs.Degree()))) - logSplit = logDegree >> 1 - - return recurseCheby(targetScale, logSplit, logDegree, coeffs, C, evaluator) - } - - return evaluatePolyFromPowerBasis(targetScale, coeffs, C, evaluator) - } - - var nextPower = 1 << logSplit - for nextPower < (coeffs.Degree()>>1)+1 { - nextPower <<= 1 - } - - coeffsq, coeffsr := splitCoeffsCheby(coeffs, nextPower) - - level := C[nextPower].Level() - 1 - - if coeffsq.maxDeg >= 1<<(logDegree-1) && coeffsq.lead { - level++ - } - - currentQi := float64(evaluator.params.Q()[level]) - - //fmt.Printf("X^%2d: %d %d %t %d\n", nextPower, coeffsq.maxDeg, coeffsr.maxDeg, coeffsq.maxDeg >= 1<<(logDegree-1), level) - //fmt.Printf("X^%2d: %f %f\n", nextPower, targetScale, targetScale* currentQi / C[nextPower].Scale()) - //fmt.Printf("X^%2d : qi %d %t %d %d\n", nextPower, level, coeffsq.lead, coeffsq.maxDeg, 1<<(logDegree-1)) - //fmt.Println() - - if res, err = recurseCheby(targetScale*currentQi/C[nextPower].Scale, logSplit, logDegree, coeffsq, C, evaluator); err != nil { - return nil, err - } - - var tmp *Ciphertext - if tmp, err = recurseCheby(targetScale, logSplit, logDegree, coeffsr, C, evaluator); err != nil { - return nil, err - } - - if res.Level() > tmp.Level() { - for res.Level() != tmp.Level()+1 { - evaluator.DropLevel(res, 1) + for i := coeffs.Degree(); i > 0; i-- { + if isNotNegligible(coeffs.Coeffs[i]) { + minimumDegreeNonZeroCoefficient = i + break } } - //fmt.Printf("X^%2d: (%d %f -> \n", nextPower, res.Level(), res.Scale()) - evaluator.MulRelin(res, C[nextPower], res) + c := coeffs.Coeffs[0] - if res.Level() > tmp.Level() { - if err = evaluator.Rescale(res, evaluator.scale, res); err != nil { - return nil, err - } - //fmt.Printf("%f = %d) + (%d %f) = ", res.Scale(), res.Level(), tmp.Level(), tmp.Scale()) - evaluator.Add(res, tmp, res) - //fmt.Printf("(%d %f) %f\n", res.Level(), res.Scale(), res.Scale()-tmp.Scale()) - } else { - evaluator.Add(res, tmp, res) - if err = evaluator.Rescale(res, evaluator.scale, res); err != nil { - return nil, err - } - } - - tmp = nil - - return - -} - -func evaluatePolyFromPowerBasis(targetScale float64, coeffs *Poly, C map[int]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) { - - if coeffs.Degree() == 0 { + if minimumDegreeNonZeroCoefficient == 0 { res = NewCiphertext(evaluator.params, 1, C[1].Level(), targetScale) - if math.Abs(real(coeffs.coeffs[0])) > 1e-14 || math.Abs(imag(coeffs.coeffs[0])) > 1e-14 { - evaluator.AddConst(res, coeffs.coeffs[0], res) + if isNotNegligible(c) { + evaluator.AddConst(res, c, res) } return } - currentQi := float64(evaluator.params.Q()[C[coeffs.Degree()].Level()]) + minimumDegreeNonZeroCoefficient = coeffs.Degree() + + currentQi := evaluator.params.QiFloat64(C[(minimumDegreeNonZeroCoefficient)].Level()) ctScale := targetScale * currentQi - //fmt.Printf("%d %f\n", coeffs.maxDeg, targetScale) - //fmt.Println("current Qi", evaluator.params.Qi[C[coeffs.Degree()].Level()]) - //fmt.Println(coeffs.Degree(), C[coeffs.Degree()].Level()) + res = NewCiphertext(evaluator.params, 1, C[minimumDegreeNonZeroCoefficient].Level(), ctScale) - res = NewCiphertext(evaluator.params, 1, C[coeffs.Degree()].Level(), ctScale) - - if math.Abs(real(coeffs.coeffs[0])) > 1e-14 || math.Abs(imag(coeffs.coeffs[0])) > 1e-14 { - evaluator.AddConst(res, coeffs.coeffs[0], res) + if isNotNegligible(c) { + evaluator.AddConst(res, c, res) } + cRealFlo, cImagFlo, constScale := ring.NewFloat(0, 128), ring.NewFloat(0, 128), ring.NewFloat(0, 128) + cRealBig, cImagBig := ring.NewUint(0), ring.NewUint(0) + for key := coeffs.Degree(); key > 0; key-- { - if key != 0 && (math.Abs(real(coeffs.coeffs[key])) > 1e-14 || math.Abs(imag(coeffs.coeffs[key])) > 1e-14) { + c = coeffs.Coeffs[key] + + if key != 0 && isNotNegligible(c) { + + cRealFlo.SetFloat64(real(c)) + cImagFlo.SetFloat64(imag(c)) + constScale.SetFloat64(targetScale * currentQi / C[key].Scale) // Target scale * rescale-scale / power basis scale - constScale := targetScale * currentQi / C[key].Scale + cRealFlo.Mul(cRealFlo, constScale).Int(cRealBig) + cImagFlo.Mul(cImagFlo, constScale).Int(cImagBig) - cReal := int64(real(coeffs.coeffs[key]) * constScale) - cImag := int64(imag(coeffs.coeffs[key]) * constScale) - - evaluator.MultByGaussianIntegerAndAdd(C[key], cReal, cImag, res) + evaluator.MultByGaussianIntegerAndAdd(C[key], cRealBig, cImagBig, res) } } - if err = evaluator.Rescale(res, evaluator.scale, res); err != nil { + if err = evaluator.Rescale(res, targetScale, res); err != nil { return nil, err } return } + +func isNotNegligible(c complex128) bool { + return (math.Abs(real(c)) > IsNegligbleThreshold || math.Abs(imag(c)) > IsNegligbleThreshold) +} + +func isOddOrEvenPolynomial(coeffs []complex128) (odd, even bool) { + even = true + odd = true + for i, c := range coeffs { + isnotnegligible := isNotNegligible(c) + odd = odd && !(i&1 == 0 && isnotnegligible) + even = even && !(i&1 == 1 && isnotnegligible) + if !odd && !even { + break + } + } + + // If even or odd, then sets the expected zero coefficients to zero + if even || odd { + var start int + if even { + start = 1 + } + for i := start; i < len(coeffs); i += 2 { + coeffs[i] = complex(0, 0) + } + } + + return +} diff --git a/ckks/utils.go b/ckks/utils.go index 4c156ba0..464c0102 100644 --- a/ckks/utils.go +++ b/ckks/utils.go @@ -27,6 +27,35 @@ func StandardDeviation(vec []float64, scale float64) (std float64) { return math.Sqrt(err/n) * scale } +func interfaceMod(x interface{}, qi uint64) uint64 { + + switch x := x.(type) { + + case uint64: + return x % qi + + case int64: + + if x > 0 { + return uint64(x) + } else if x < 0 { + return uint64(int64(qi) + x%int64(qi)) + } + return 0 + + case *big.Int: + + if x.Cmp(ring.NewUint(0)) != 0 { + return new(big.Int).Mod(x, ring.NewUint(qi)).Uint64() + } + + return 0 + + default: + panic("constant must either be uint64, int64 or *big.Int") + } +} + func scaleUpExact(value float64, n float64, q uint64) (res uint64) { var isNegative bool @@ -187,14 +216,15 @@ func GenSwitchkeysRescalingParams(Q, P []uint64) (params []uint64) { for i := 0; i < len(Q); i++ { params[i] = tmp.Mod(PBig, ring.NewUint(Q[i])).Uint64() - params[i] = ring.ModExp(params[i], int(Q[i]-2), Q[i]) + params[i] = ring.ModExp(params[i], Q[i]-2, Q[i]) params[i] = ring.MForm(params[i], Q[i], ring.BRedParams(Q[i])) } return } -func sliceBitReverseInPlaceFloat64(slice []float64, N int) { +// SliceBitReverseInPlaceFloat64 applies an in-place bit-reverse permuation on the input slice. +func SliceBitReverseInPlaceFloat64(slice []float64, N int) { var bit, j int @@ -215,7 +245,8 @@ func sliceBitReverseInPlaceFloat64(slice []float64, N int) { } } -func sliceBitReverseInPlaceComplex128(slice []complex128, N int) { +// SliceBitReverseInPlaceComplex128 applies an in-place bit-reverse permuation on the input slice. +func SliceBitReverseInPlaceComplex128(slice []complex128, N int) { var bit, j int @@ -236,7 +267,8 @@ func sliceBitReverseInPlaceComplex128(slice []complex128, N int) { } } -func sliceBitReverseInPlaceRingComplex(slice []*ring.Complex, N int) { +// SliceBitReverseInPlaceRingComplex applies an in-place bit-reverse permuation on the input slice. +func SliceBitReverseInPlaceRingComplex(slice []*ring.Complex, N int) { var bit, j int diff --git a/dbfv/dbfv_benchmark_test.go b/dbfv/dbfv_benchmark_test.go index d5a66159..76715556 100644 --- a/dbfv/dbfv_benchmark_test.go +++ b/dbfv/dbfv_benchmark_test.go @@ -6,7 +6,6 @@ import ( "github.com/ldsec/lattigo/v2/bfv" "github.com/ldsec/lattigo/v2/drlwe" - "github.com/ldsec/lattigo/v2/ring" "github.com/ldsec/lattigo/v2/rlwe" ) @@ -45,10 +44,6 @@ func benchPublicKeyGen(testCtx *testContext, b *testing.B) { sk0Shards := testCtx.sk0Shards - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - - crp := crpGenerator.ReadNew() - type Party struct { *CKGProtocol s *rlwe.SecretKey @@ -60,6 +55,8 @@ func benchPublicKeyGen(testCtx *testContext, b *testing.B) { p.s = sk0Shards[0] p.s1 = p.AllocateShares() + crp := p.SampleCRP(testCtx.crs) + b.Run(testString("PublicKeyGen/Round1/Gen", parties, testCtx.params), func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -102,13 +99,7 @@ func benchRelinKeyGen(testCtx *testContext, b *testing.B) { p.ephSk, p.share1, p.share2 = p.RKGProtocol.AllocateShares() p.rlk = bfv.NewRelinearizationKey(testCtx.params, 2) - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - - crp := make([]*ring.Poly, testCtx.params.Beta()) - - for i := 0; i < testCtx.params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } + crp := p.SampleCRP(testCtx.crs) b.Run(testString("RelinKeyGen/Round1/Gen", parties, testCtx.params), func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -124,7 +115,7 @@ func benchRelinKeyGen(testCtx *testContext, b *testing.B) { b.Run(testString("RelinKeyGen/Round2/Gen", parties, testCtx.params), func(b *testing.B) { for i := 0; i < b.N; i++ { - p.GenShareRoundTwo(p.ephSk, p.sk, p.share1, crp, p.share2) + p.GenShareRoundTwo(p.ephSk, p.sk, p.share1, p.share2) } }) @@ -153,7 +144,7 @@ func benchKeyswitching(testCtx *testContext, b *testing.B) { share *drlwe.CKSShare } - ciphertext := bfv.NewCiphertextRandom(testCtx.prng, testCtx.params, 1) + ciphertext := bfv.NewCiphertext(testCtx.params, 1) p := new(Party) p.CKSProtocol = NewCKSProtocol(testCtx.params, 6.36) @@ -188,7 +179,7 @@ func benchPublicKeySwitching(testCtx *testContext, b *testing.B) { sk0Shards := testCtx.sk0Shards pk1 := testCtx.pk1 - ciphertext := bfv.NewCiphertextRandom(testCtx.prng, testCtx.params, 1) + ciphertext := bfv.NewCiphertext(testCtx.params, 1) type Party struct { *PCKSProtocol @@ -239,12 +230,7 @@ func benchRotKeyGen(testCtx *testContext, b *testing.B) { p.s = sk0Shards[0] p.share = p.AllocateShares() - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := make([]*ring.Poly, testCtx.params.Beta()) - - for i := 0; i < testCtx.params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } + crp := p.SampleCRP(testCtx.crs) b.Run(testString("RotKeyGen/Round1/Gen", parties, testCtx.params), func(b *testing.B) { @@ -284,10 +270,9 @@ func benchRefresh(testCtx *testContext, b *testing.B) { p.s = sk0Shards[0] p.share = p.AllocateShare() - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - crp := crpGenerator.ReadNew() + ciphertext := bfv.NewCiphertext(testCtx.params, 1) - ciphertext := bfv.NewCiphertextRandom(testCtx.prng, testCtx.params, 1) + crp := p.SampleCRP(ciphertext.Level(), testCtx.crs) b.Run(testString("Refresh/Round1/Gen", parties, testCtx.params), func(b *testing.B) { diff --git a/dbfv/dbfv_test.go b/dbfv/dbfv_test.go index e22ecb3a..b9d3c730 100644 --- a/dbfv/dbfv_test.go +++ b/dbfv/dbfv_test.go @@ -30,13 +30,10 @@ type testContext struct { // Polynomial degree n int - // floor(Q/T) mod each Qi in Montgomery form - deltaMont []uint64 - // Polynomial contexts - ringT *ring.Ring - ringQ *ring.Ring - ringQP *ring.Ring + ringT *ring.Ring + ringQ *ring.Ring + ringP *ring.Ring prng utils.PRNG @@ -55,6 +52,9 @@ type testContext struct { decryptorSk0 bfv.Decryptor decryptorSk1 bfv.Decryptor evaluator bfv.Evaluator + + crs drlwe.CRS + uniformSampler *ring.UniformSampler } func Test_DBFV(t *testing.T) { @@ -103,13 +103,11 @@ func gentestContext(params bfv.Parameters) (testCtx *testContext, err error) { testCtx.ringT = params.RingT() testCtx.ringQ = params.RingQ() - testCtx.ringQP = params.RingQP() + testCtx.ringP = params.RingP() - testCtx.deltaMont = bfv.GenLiftParams(testCtx.ringQ, params.T()) - - if testCtx.prng, err = utils.NewPRNG(); err != nil { - return nil, err - } + prng, _ := utils.NewKeyedPRNG([]byte{'t', 'e', 's', 't'}) + testCtx.crs = prng + testCtx.uniformSampler = ring.NewUniformSampler(prng, params.RingQ()) testCtx.encoder = bfv.NewEncoder(testCtx.params) testCtx.evaluator = bfv.NewEvaluator(testCtx.params, rlwe.EvaluationKey{}) @@ -119,21 +117,17 @@ func gentestContext(params bfv.Parameters) (testCtx *testContext, err error) { // SecretKeys testCtx.sk0Shards = make([]*rlwe.SecretKey, parties) testCtx.sk1Shards = make([]*rlwe.SecretKey, parties) - tmp0 := testCtx.ringQP.NewPoly() - tmp1 := testCtx.ringQP.NewPoly() - - for j := 0; j < parties; j++ { - testCtx.sk0Shards[j] = kgen.GenSecretKey() - testCtx.sk1Shards[j] = kgen.GenSecretKey() - testCtx.ringQP.Add(tmp0, testCtx.sk0Shards[j].Value, tmp0) - testCtx.ringQP.Add(tmp1, testCtx.sk1Shards[j].Value, tmp1) - } testCtx.sk0 = bfv.NewSecretKey(testCtx.params) testCtx.sk1 = bfv.NewSecretKey(testCtx.params) - testCtx.sk0.Value.Copy(tmp0) - testCtx.sk1.Value.Copy(tmp1) + ringQP, levelQ, levelP := params.RingQP(), params.QCount()-1, params.PCount()-1 + for j := 0; j < parties; j++ { + testCtx.sk0Shards[j] = kgen.GenSecretKey() + testCtx.sk1Shards[j] = kgen.GenSecretKey() + ringQP.AddLvl(levelQ, levelP, testCtx.sk0.Value, testCtx.sk0Shards[j].Value, testCtx.sk0.Value) + ringQP.AddLvl(levelQ, levelP, testCtx.sk1.Value, testCtx.sk1Shards[j].Value, testCtx.sk1.Value) + } // Publickeys testCtx.pk0 = kgen.GenPublicKey(testCtx.sk0) @@ -153,9 +147,6 @@ func testPublicKeyGen(testCtx *testContext, t *testing.T) { t.Run(testString("PublicKeyGen/", parties, testCtx.params), func(t *testing.T) { - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := crpGenerator.ReadNew() - type Party struct { *CKGProtocol s *rlwe.SecretKey @@ -172,7 +163,9 @@ func testPublicKeyGen(testCtx *testContext, t *testing.T) { } P0 := ckgParties[0] - // Checks that bfv.CKGProtocol complies to the drlwe.CollectivePublicKeyGenerator interface + crp := P0.SampleCRP(testCtx.crs) + + // Checks that dbfv.CKGProtocol complies to the drlwe.CollectivePublicKeyGenerator interface var _ drlwe.CollectivePublicKeyGenerator = P0.CKGProtocol // Each party creates a new CKGProtocol instance @@ -226,12 +219,7 @@ func testRelinKeyGen(testCtx *testContext, t *testing.T) { // checks that bfv.RKGProtocol complies to the drlwe.RelinearizationKeyGenerator interface var _ drlwe.RelinearizationKeyGenerator = P0.RKGProtocol - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := make([]*ring.Poly, testCtx.params.Beta()) - - for i := 0; i < testCtx.params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } + crp := P0.SampleCRP(testCtx.crs) // ROUND 1 for i, p := range rkgParties { @@ -243,7 +231,7 @@ func testRelinKeyGen(testCtx *testContext, t *testing.T) { //ROUND 2 for i, p := range rkgParties { - p.GenShareRoundTwo(p.ephSk, p.sk, P0.share1, crp, p.share2) + p.GenShareRoundTwo(p.ephSk, p.sk, P0.share1, p.share2) if i > 0 { P0.AggregateShares(p.share2, P0.share2, P0.share2) } @@ -395,12 +383,7 @@ func testRotKeyGenRotRows(testCtx *testContext, t *testing.T) { // Checks that bfv.RTGProtocol complies to the drlwe.RotationKeyGenerator interface var _ drlwe.RotationKeyGenerator = P0.RTGProtocol - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := make([]*ring.Poly, testCtx.params.Beta()) - - for i := 0; i < testCtx.params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } + crp := P0.SampleCRP(testCtx.crs) galEl := testCtx.params.GaloisElementForRowRotation() rotKeySet := bfv.NewRotationKeySet(testCtx.params, []uint64{galEl}) @@ -453,12 +436,7 @@ func testRotKeyGenRotCols(testCtx *testContext, t *testing.T) { // Checks that bfv.RTGProtocol complies to the drlwe.RotationKeyGenerator interface var _ drlwe.RotationKeyGenerator = P0.RTGProtocol - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := make([]*ring.Poly, testCtx.params.Beta()) - - for i := 0; i < testCtx.params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } + crp := P0.SampleCRP(testCtx.crs) coeffs, _, ciphertext := newTestVectors(testCtx, encryptorPk0, t) @@ -534,19 +512,18 @@ func testEncToShares(testCtx *testContext, t *testing.T) { }) - crs := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - c1 := crs.ReadNew() + crp := P[0].e2s.SampleCRP(params.MaxLevel(), testCtx.crs) t.Run(testString("S2EProtocol/", parties, testCtx.params), func(t *testing.T) { for i, p := range P { - p.s2e.GenShare(p.sk, c1, p.secretShare, p.publicShare) + p.s2e.GenShare(p.sk, crp, p.secretShare, p.publicShare) if i > 0 { p.s2e.AggregateShares(P[0].publicShare, p.publicShare, P[0].publicShare) } } ctRec := bfv.NewCiphertext(testCtx.params, 1) - P[0].s2e.GetEncryption(P[0].publicShare, c1, ctRec) + P[0].s2e.GetEncryption(P[0].publicShare, crp, ctRec) verifyTestVectors(testCtx, testCtx.decryptorSk0, coeffs, ctRec, t) }) @@ -584,8 +561,7 @@ func testRefresh(testCtx *testContext, t *testing.T) { P0 := RefreshParties[0] - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - crp := crpGenerator.ReadNew() + crp := P0.SampleCRP(testCtx.params.MaxLevel(), testCtx.crs) coeffs, _, ciphertext := newTestVectors(testCtx, encryptorPk0, t) @@ -680,15 +656,15 @@ func testRefreshAndPermutation(testCtx *testContext, t *testing.T) { P0 := RefreshParties[0] - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - crp := crpGenerator.ReadNew() + crp := P0.SampleCRP(testCtx.params.MaxLevel(), testCtx.crs) coeffs, _, ciphertext := newTestVectors(testCtx, encryptorPk0, t) permutation := make([]uint64, len(coeffs)) N := uint64(testCtx.params.N()) + prng, _ := utils.NewPRNG() for i := range permutation { - permutation[i] = ring.RandUniform(testCtx.prng, N, N-1) + permutation[i] = ring.RandUniform(prng, N, N-1) } permute := func(ptIn, ptOut bfv.PlaintextRingT) { @@ -723,8 +699,8 @@ func testRefreshAndPermutation(testCtx *testContext, t *testing.T) { func newTestVectors(testCtx *testContext, encryptor bfv.Encryptor, t *testing.T) (coeffs []uint64, plaintext *bfv.Plaintext, ciphertext *bfv.Ciphertext) { - uniformSampler := ring.NewUniformSampler(testCtx.prng, testCtx.ringT) - + prng, _ := utils.NewPRNG() + uniformSampler := ring.NewUniformSampler(prng, testCtx.ringT) coeffsPol := uniformSampler.ReadNew() plaintext = bfv.NewPlaintext(testCtx.params) testCtx.encoder.EncodeUint(coeffsPol.Coeffs[0], plaintext) @@ -737,21 +713,19 @@ func verifyTestVectors(testCtx *testContext, decryptor bfv.Decryptor, coeffs []u } func testMarshalling(testCtx *testContext, t *testing.T) { + ciphertext := bfv.NewCiphertext(testCtx.params, 1) + testCtx.uniformSampler.Read(ciphertext.Value[0]) + testCtx.uniformSampler.Read(ciphertext.Value[1]) - //verify if the un.marshalling works properly - - crsGen := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crs := crsGen.ReadNew() - ringQ := testCtx.ringQ - - ciphertext := bfv.NewCiphertextRandom(testCtx.prng, testCtx.params, 1) - - t.Run(fmt.Sprintf("Marshalling/Refresh/N=%d/limbQ=%d/limbsP=%d", ringQ.N, testCtx.params.QCount(), testCtx.params.PCount()), func(t *testing.T) { + t.Run(testString("MarshallingRefresh/", parties, testCtx.params), func(t *testing.T) { //testing refresh shares refreshproto := NewRefreshProtocol(testCtx.params, 3.2) refreshshare := refreshproto.AllocateShare() - refreshproto.GenShares(testCtx.sk0, ciphertext, crs, refreshshare) + + crp := refreshproto.SampleCRP(testCtx.params.MaxLevel(), testCtx.crs) + + refreshproto.GenShares(testCtx.sk0, ciphertext, crp, refreshshare) data, err := refreshshare.MarshalBinary() if err != nil { diff --git a/dbfv/refresh.go b/dbfv/refresh.go index 35d992b7..910d8932 100644 --- a/dbfv/refresh.go +++ b/dbfv/refresh.go @@ -2,7 +2,7 @@ package dbfv import ( "github.com/ldsec/lattigo/v2/bfv" - "github.com/ldsec/lattigo/v2/ring" + "github.com/ldsec/lattigo/v2/drlwe" "github.com/ldsec/lattigo/v2/rlwe" ) @@ -25,12 +25,13 @@ func NewRefreshProtocol(params bfv.Parameters, sigmaSmudging float64) (rfp *Refr // AllocateShare allocates the shares of the PermuteProtocol func (rfp *RefreshProtocol) AllocateShare() *RefreshShare { - return &RefreshShare{*rfp.MaskedTransformProtocol.AllocateShare()} + share := rfp.MaskedTransformProtocol.AllocateShare() + return &RefreshShare{*share} } // GenShares generates a share for the Refresh protocol. -func (rfp *RefreshProtocol) GenShares(sk *rlwe.SecretKey, ciphertext *bfv.Ciphertext, crs *ring.Poly, shareOut *RefreshShare) { - rfp.MaskedTransformProtocol.GenShares(sk, ciphertext, crs, nil, &shareOut.MaskedTransformShare) +func (rfp *RefreshProtocol) GenShares(sk *rlwe.SecretKey, ciphertext *bfv.Ciphertext, crp drlwe.CKSCRP, shareOut *RefreshShare) { + rfp.MaskedTransformProtocol.GenShares(sk, ciphertext, crp, nil, &shareOut.MaskedTransformShare) } // Aggregate aggregates two parties' shares in the Refresh protocol. @@ -39,6 +40,6 @@ func (rfp *RefreshProtocol) Aggregate(share1, share2, shareOut *RefreshShare) { } // Finalize applies Decrypt, Recode and Recrypt on the input ciphertext. -func (rfp *RefreshProtocol) Finalize(ciphertext *bfv.Ciphertext, crs *ring.Poly, share *RefreshShare, ciphertextOut *bfv.Ciphertext) { - rfp.MaskedTransformProtocol.Transform(ciphertext, nil, crs, &share.MaskedTransformShare, ciphertextOut) +func (rfp *RefreshProtocol) Finalize(ciphertext *bfv.Ciphertext, crp drlwe.CKSCRP, share *RefreshShare, ciphertextOut *bfv.Ciphertext) { + rfp.MaskedTransformProtocol.Transform(ciphertext, nil, crp, &share.MaskedTransformShare, ciphertextOut) } diff --git a/dbfv/sharing.go b/dbfv/sharing.go index 20ca8c65..63fd6f02 100644 --- a/dbfv/sharing.go +++ b/dbfv/sharing.go @@ -89,19 +89,19 @@ func NewS2EProtocol(params bfv.Parameters, sigmaSmudging float64) *S2EProtocol { } // GenShare generates a party's in the shares-to-encryption protocol given the party's secret-key share `sk`, a common -// polynomial sampled from the CRS `c1` and the party's secret share of the message. -func (s2e *S2EProtocol) GenShare(sk *rlwe.SecretKey, c1 *ring.Poly, secretShare *rlwe.AdditiveShare, c0ShareOut *drlwe.CKSShare) { +// polynomial sampled from the CRS `crp` and the party's secret share of the message. +func (s2e *S2EProtocol) GenShare(sk *rlwe.SecretKey, crp drlwe.CKSCRP, secretShare *rlwe.AdditiveShare, c0ShareOut *drlwe.CKSShare) { s2e.encoder.ScaleUp(&bfv.PlaintextRingT{Plaintext: &rlwe.Plaintext{Value: &secretShare.Value}}, s2e.tmpPlaintext) - s2e.CKSProtocol.GenShare(s2e.zero, sk, &rlwe.Ciphertext{Value: []*ring.Poly{c0ShareOut.Value, c1}}, c0ShareOut) + s2e.CKSProtocol.GenShare(s2e.zero, sk, &rlwe.Ciphertext{Value: []*ring.Poly{c0ShareOut.Value, (*ring.Poly)(&crp)}}, c0ShareOut) s2e.ringQ.Add(c0ShareOut.Value, s2e.tmpPlaintext.Value, c0ShareOut.Value) } // GetEncryption computes the final encryption of the secret-shared message when provided with the aggregation `c0Agg` of the parties' -// share in the protocol and with the common, CRS-sampled polynomial `c1`. -func (s2e *S2EProtocol) GetEncryption(c0Agg *drlwe.CKSShare, c1 *ring.Poly, ctOut *bfv.Ciphertext) { +// share in the protocol and with the common, CRS-sampled polynomial `crp`. +func (s2e *S2EProtocol) GetEncryption(c0Agg *drlwe.CKSShare, crp drlwe.CKSCRP, ctOut *bfv.Ciphertext) { if ctOut.Degree() != 1 { panic("ctOut must have degree 1.") } ctOut.Value[0].Copy(c0Agg.Value) - ctOut.Value[1].Copy(c1) + ctOut.Value[1].Copy((*ring.Poly)(&crp)) } diff --git a/dbfv/transform.go b/dbfv/transform.go index 61e7d7f5..59612999 100644 --- a/dbfv/transform.go +++ b/dbfv/transform.go @@ -5,6 +5,7 @@ import ( "github.com/ldsec/lattigo/v2/drlwe" "github.com/ldsec/lattigo/v2/ring" "github.com/ldsec/lattigo/v2/rlwe" + "github.com/ldsec/lattigo/v2/utils" ) // MaskedTransformProtocol is a struct storing the parameters for the MaskedTransformProtocol protocol. @@ -71,6 +72,12 @@ func NewMaskedTransformProtocol(params bfv.Parameters, sigmaSmudging float64) (r return } +// SampleCRP samples a common random polynomial to be used in the Masked-Transform protocol from the provided +// common reference string. +func (rfp *MaskedTransformProtocol) SampleCRP(level int, crs utils.PRNG) drlwe.CKSCRP { + return rfp.e2s.SampleCRP(level, crs) +} + // AllocateShare allocates the shares of the PermuteProtocol func (rfp *MaskedTransformProtocol) AllocateShare() *MaskedTransformShare { level := len(rfp.ringQ.Modulus) - 1 @@ -78,7 +85,7 @@ func (rfp *MaskedTransformProtocol) AllocateShare() *MaskedTransformShare { } // GenShares generates the shares of the PermuteProtocol -func (rfp *MaskedTransformProtocol) GenShares(sk *rlwe.SecretKey, ciphertext *bfv.Ciphertext, crs *ring.Poly, transform MaskedTransformFunc, shareOut *MaskedTransformShare) { +func (rfp *MaskedTransformProtocol) GenShares(sk *rlwe.SecretKey, ciphertext *bfv.Ciphertext, crs drlwe.CKSCRP, transform MaskedTransformFunc, shareOut *MaskedTransformShare) { rfp.e2s.GenShare(sk, ciphertext, &rlwe.AdditiveShare{Value: *rfp.tmpMask}, &shareOut.e2sShare) mask := rfp.tmpMask if transform != nil { @@ -95,7 +102,7 @@ func (rfp *MaskedTransformProtocol) Aggregate(share1, share2, shareOut *MaskedTr } // Transform applies Decrypt, Recode and Recrypt on the input ciphertext. -func (rfp *MaskedTransformProtocol) Transform(ciphertext *bfv.Ciphertext, transform MaskedTransformFunc, crs *ring.Poly, share *MaskedTransformShare, ciphertextOut *bfv.Ciphertext) { +func (rfp *MaskedTransformProtocol) Transform(ciphertext *bfv.Ciphertext, transform MaskedTransformFunc, crs drlwe.CKSCRP, share *MaskedTransformShare, ciphertextOut *bfv.Ciphertext) { rfp.e2s.GetShare(nil, &share.e2sShare, ciphertext, &rlwe.AdditiveShare{Value: *rfp.tmpMask}) // tmpMask RingT(m - sum M_i) mask := rfp.tmpMask if transform != nil { diff --git a/dckks/dckks_benchmark_test.go b/dckks/dckks_benchmark_test.go index ea465977..5eac6a16 100644 --- a/dckks/dckks_benchmark_test.go +++ b/dckks/dckks_benchmark_test.go @@ -49,9 +49,6 @@ func benchPublicKeyGen(testCtx *testContext, b *testing.B) { sk0Shards := testCtx.sk0Shards params := testCtx.params - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := crpGenerator.ReadNew() - type Party struct { *CKGProtocol s *rlwe.SecretKey @@ -63,6 +60,8 @@ func benchPublicKeyGen(testCtx *testContext, b *testing.B) { p.s = sk0Shards[0] p.s1 = p.AllocateShares() + crp := p.SampleCRP(testCtx.crs) + b.Run(testString("PublicKeyGen/Gen/", parties, params), func(b *testing.B) { // Each party creates a new CKGProtocol instance @@ -98,12 +97,7 @@ func benchRelinKeyGen(testCtx *testContext, b *testing.B) { p.sk = sk0Shards[0] p.ephSk, p.share1, p.share2 = p.RKGProtocol.AllocateShares() - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := make([]*ring.Poly, params.Beta()) - - for i := 0; i < params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } + crp := p.SampleCRP(testCtx.crs) b.Run(testString("RelinKeyGen/Round1Gen/", parties, params), func(b *testing.B) { @@ -122,7 +116,7 @@ func benchRelinKeyGen(testCtx *testContext, b *testing.B) { b.Run(testString("RelinKeyGen/Round2Gen/", parties, params), func(b *testing.B) { for i := 0; i < b.N; i++ { - p.GenShareRoundTwo(p.ephSk, p.sk, p.share1, crp, p.share2) + p.GenShareRoundTwo(p.ephSk, p.sk, p.share1, p.share2) } }) @@ -141,7 +135,7 @@ func benchKeySwitching(testCtx *testContext, b *testing.B) { sk1Shards := testCtx.sk1Shards params := testCtx.params - ciphertext := ckks.NewCiphertextRandom(testCtx.prng, params, 1, params.MaxLevel(), params.Scale()) + ciphertext := ckks.NewCiphertext(params, 1, params.MaxLevel(), params.Scale()) type Party struct { *CKSProtocol @@ -184,7 +178,7 @@ func benchPublicKeySwitching(testCtx *testContext, b *testing.B) { pk1 := testCtx.pk1 params := testCtx.params - ciphertext := ckks.NewCiphertextRandom(testCtx.prng, params, 1, params.MaxLevel(), params.Scale()) + ciphertext := ckks.NewCiphertext(params, 1, params.MaxLevel(), params.Scale()) type Party struct { *PCKSProtocol @@ -221,7 +215,6 @@ func benchPublicKeySwitching(testCtx *testContext, b *testing.B) { func benchRotKeyGen(testCtx *testContext, b *testing.B) { - ringQP := testCtx.ringQP sk0Shards := testCtx.sk0Shards params := testCtx.params @@ -236,12 +229,8 @@ func benchRotKeyGen(testCtx *testContext, b *testing.B) { p.s = sk0Shards[0] p.share = p.AllocateShares() - crpGenerator := ring.NewUniformSampler(testCtx.prng, ringQP) - crp := make([]*ring.Poly, params.Beta()) + crp := p.SampleCRP(testCtx.crs) - for i := 0; i < params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } galEl := params.GaloisElementForRowRotation() b.Run(testString("RotKeyGen/Round1/Gen/", parties, params), func(b *testing.B) { @@ -286,10 +275,9 @@ func benchRefresh(testCtx *testContext, b *testing.B) { p.s = sk0Shards[0] p.share = p.AllocateShare(minLevel, params.MaxLevel()) - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - crp := crpGenerator.ReadNew() + ciphertext := ckks.NewCiphertext(params, 1, minLevel, params.Scale()) - ciphertext := ckks.NewCiphertextRandom(testCtx.prng, params, 1, minLevel, params.Scale()) + crp := p.SampleCRP(params.MaxLevel(), testCtx.crs) b.Run(testString("Refresh/Round1/Gen", parties, params), func(b *testing.B) { @@ -333,15 +321,14 @@ func benchMaskedTransform(testCtx *testContext, b *testing.B) { share *MaskedTransformShare } - ciphertext := ckks.NewCiphertextRandom(testCtx.prng, params, 1, minLevel, params.Scale()) + ciphertext := ckks.NewCiphertext(params, 1, minLevel, params.Scale()) p := new(Party) p.MaskedTransformProtocol = NewMaskedTransformProtocol(params, logBound, 3.2) p.s = sk0Shards[0] p.share = p.AllocateShare(ciphertext.Level(), params.MaxLevel()) - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - crp := crpGenerator.ReadNew() + crp := p.SampleCRP(params.MaxLevel(), testCtx.crs) permute := func(ptIn, ptOut []*ring.Complex) { for i := range ptIn { diff --git a/dckks/dckks_test.go b/dckks/dckks_test.go index 264a1450..c7b10226 100644 --- a/dckks/dckks_test.go +++ b/dckks/dckks_test.go @@ -36,10 +36,8 @@ func testString(opname string, parties int, params ckks.Parameters) string { type testContext struct { params ckks.Parameters - ringQ *ring.Ring - ringQP *ring.Ring - - prng utils.PRNG + ringQ *ring.Ring + ringP *ring.Ring encoder ckks.Encoder evaluator ckks.Evaluator @@ -56,6 +54,9 @@ type testContext struct { sk0Shards []*rlwe.SecretKey sk1Shards []*rlwe.SecretKey + + crs drlwe.CRS + uniformSampler *ring.UniformSampler } func TestDCKKS(t *testing.T) { @@ -98,19 +99,18 @@ func TestDCKKS(t *testing.T) { } } -func genTestParams(defaultParams ckks.Parameters) (testCtx *testContext, err error) { +func genTestParams(params ckks.Parameters) (testCtx *testContext, err error) { testCtx = new(testContext) - testCtx.params = defaultParams + testCtx.params = params - testCtx.ringQ = defaultParams.RingQ() + testCtx.ringQ = params.RingQ() + testCtx.ringP = params.RingP() - testCtx.ringQP = defaultParams.RingQP() - - if testCtx.prng, err = utils.NewPRNG(); err != nil { - return nil, err - } + prng, _ := utils.NewKeyedPRNG([]byte{'t', 'e', 's', 't'}) + testCtx.crs = prng + testCtx.uniformSampler = ring.NewUniformSampler(prng, params.RingQ()) testCtx.encoder = ckks.NewEncoder(testCtx.params) testCtx.evaluator = ckks.NewEvaluator(testCtx.params, rlwe.EvaluationKey{}) @@ -120,21 +120,17 @@ func genTestParams(defaultParams ckks.Parameters) (testCtx *testContext, err err // SecretKeys testCtx.sk0Shards = make([]*rlwe.SecretKey, parties) testCtx.sk1Shards = make([]*rlwe.SecretKey, parties) - tmp0 := testCtx.ringQP.NewPoly() - tmp1 := testCtx.ringQP.NewPoly() + testCtx.sk0 = ckks.NewSecretKey(testCtx.params) + testCtx.sk1 = ckks.NewSecretKey(testCtx.params) + ringQP, levelQ, levelP := params.RingQP(), params.QCount()-1, params.PCount()-1 for j := 0; j < parties; j++ { testCtx.sk0Shards[j] = kgen.GenSecretKey() testCtx.sk1Shards[j] = kgen.GenSecretKey() - testCtx.ringQP.Add(tmp0, testCtx.sk0Shards[j].Value, tmp0) - testCtx.ringQP.Add(tmp1, testCtx.sk1Shards[j].Value, tmp1) + ringQP.AddLvl(levelQ, levelP, testCtx.sk0.Value, testCtx.sk0Shards[j].Value, testCtx.sk0.Value) + ringQP.AddLvl(levelQ, levelP, testCtx.sk1.Value, testCtx.sk1Shards[j].Value, testCtx.sk1.Value) } - testCtx.sk0 = ckks.NewSecretKey(testCtx.params) - testCtx.sk1 = ckks.NewSecretKey(testCtx.params) - testCtx.sk0.Value.Copy(tmp0) - testCtx.sk1.Value.Copy(tmp1) - // Publickeys testCtx.pk0 = kgen.GenPublicKey(testCtx.sk0) testCtx.pk1 = kgen.GenPublicKey(testCtx.sk1) @@ -153,8 +149,6 @@ func testPublicKeyGen(testCtx *testContext, t *testing.T) { params := testCtx.params t.Run(testString("PublicKeyGen/", parties, params), func(t *testing.T) { - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := crpGenerator.ReadNew() type Party struct { *CKGProtocol @@ -172,6 +166,8 @@ func testPublicKeyGen(testCtx *testContext, t *testing.T) { } P0 := ckgParties[0] + crp := P0.SampleCRP(testCtx.crs) + var _ drlwe.CollectivePublicKeyGenerator = P0.CKGProtocol // Each party creates a new CKGProtocol instance @@ -224,16 +220,11 @@ func testRelinKeyGen(testCtx *testContext, t *testing.T) { P0 := rkgParties[0] + crp := P0.SampleCRP(testCtx.crs) + // Checks that ckks.RKGProtocol complies to the drlwe.RelinearizationKeyGenerator interface var _ drlwe.RelinearizationKeyGenerator = P0.RKGProtocol - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := make([]*ring.Poly, params.Beta()) - - for i := 0; i < params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } - // ROUND 1 for i, p := range rkgParties { p.GenShareRoundOne(p.sk, crp, p.ephSk, p.share1) @@ -244,7 +235,7 @@ func testRelinKeyGen(testCtx *testContext, t *testing.T) { //ROUND 2 for i, p := range rkgParties { - p.GenShareRoundTwo(p.ephSk, p.sk, P0.share1, crp, p.share2) + p.GenShareRoundTwo(p.ephSk, p.sk, P0.share1, p.share2) if i > 0 { P0.AggregateShares(p.share2, P0.share2, P0.share2) } @@ -389,7 +380,6 @@ func testPublicKeySwitching(testCtx *testContext, t *testing.T) { func testRotKeyGenConjugate(testCtx *testContext, t *testing.T) { - ringQP := testCtx.ringQP encryptorPk0 := testCtx.encryptorPk0 decryptorSk0 := testCtx.decryptorSk0 sk0Shards := testCtx.sk0Shards @@ -416,12 +406,7 @@ func testRotKeyGenConjugate(testCtx *testContext, t *testing.T) { // checks that ckks.RTGProtocol complies to the drlwe.RotationKeyGenerator interface var _ drlwe.RotationKeyGenerator = P0.RTGProtocol - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) - crp := make([]*ring.Poly, params.Beta()) - - for i := 0; i < params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } + crp := P0.SampleCRP(testCtx.crs) galEl := params.GaloisElementForRowRotation() rotKeySet := ckks.NewRotationKeySet(params, []uint64{galEl}) @@ -440,9 +425,9 @@ func testRotKeyGenConjugate(testCtx *testContext, t *testing.T) { evaluator := testCtx.evaluator.WithKey(rlwe.EvaluationKey{Rlk: nil, Rtks: rotKeySet}) evaluator.Conjugate(ciphertext, ciphertext) - coeffsWant := make([]complex128, ringQP.N>>1) + coeffsWant := make([]complex128, params.Slots()) - for i := 0; i < ringQP.N>>1; i++ { + for i := 0; i < params.Slots(); i++ { coeffsWant[i] = complex(real(coeffs[i]), -imag(coeffs[i])) } @@ -453,7 +438,6 @@ func testRotKeyGenConjugate(testCtx *testContext, t *testing.T) { func testRotKeyGenCols(testCtx *testContext, t *testing.T) { - ringQP := testCtx.ringQP encryptorPk0 := testCtx.encryptorPk0 decryptorSk0 := testCtx.decryptorSk0 sk0Shards := testCtx.sk0Shards @@ -478,12 +462,7 @@ func testRotKeyGenCols(testCtx *testContext, t *testing.T) { P0 := pcksParties[0] - crpGenerator := ring.NewUniformSampler(testCtx.prng, ringQP) - crp := make([]*ring.Poly, params.Beta()) - - for i := 0; i < params.Beta(); i++ { - crp[i] = crpGenerator.ReadNew() - } + crp := P0.SampleCRP(testCtx.crs) coeffs, _, ciphertext := newTestVectors(testCtx, encryptorPk0, -1, 1, t) @@ -504,7 +483,7 @@ func testRotKeyGenCols(testCtx *testContext, t *testing.T) { evaluator := testCtx.evaluator.WithKey(rlwe.EvaluationKey{Rlk: nil, Rtks: rotKeySet}) - for k := 1; k < ringQP.N>>1; k <<= 1 { + for k := 1; k < params.Slots(); k <<= 1 { evaluator.Rotate(ciphertext, int(k), receiver) coeffsWant := utils.RotateComplex128Slice(coeffs, int(k)) @@ -583,12 +562,11 @@ func testE2SProtocol(testCtx *testContext, t *testing.T) { verifyTestVectors(testCtx, nil, coeffs, pt, t) - crs := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - c1 := crs.ReadLvlNew(params.Parameters.MaxLevel()) + crp := P[0].s2e.SampleCRP(params.Parameters.MaxLevel(), testCtx.crs) for i, p := range P { - p.s2e.GenShare(p.sk, c1, p.secretShare, p.publicShareS2E) + p.s2e.GenShare(p.sk, crp, p.secretShare, p.publicShareS2E) if i > 0 { p.s2e.AggregateShares(P[0].publicShareS2E, p.publicShareS2E, P[0].publicShareS2E) @@ -596,7 +574,7 @@ func testE2SProtocol(testCtx *testContext, t *testing.T) { } ctRec := ckks.NewCiphertext(params, 1, params.Parameters.MaxLevel(), ciphertext.Scale) - P[0].s2e.GetEncryption(P[0].publicShareS2E, c1, ctRec) + P[0].s2e.GetEncryption(P[0].publicShareS2E, crp, ctRec) verifyTestVectors(testCtx, testCtx.decryptorSk0, coeffs, ctRec, t) @@ -629,22 +607,21 @@ func testRefresh(testCtx *testContext, t *testing.T) { // Brings ciphertext to level 2 testCtx.evaluator.DropLevel(ciphertext, ciphertext.Level()-minLevel) - levelMin := ciphertext.Level() - levelMax := params.MaxLevel() + levelIn := ciphertext.Level() + levelOut := params.MaxLevel() RefreshParties := make([]*Party, parties) for i := 0; i < parties; i++ { p := new(Party) p.RefreshProtocol = NewRefreshProtocol(params, logBound, 3.2) p.s = sk0Shards[i] - p.share = p.AllocateShare(levelMin, levelMax) + p.share = p.AllocateShare(levelIn, levelOut) RefreshParties[i] = p } P0 := RefreshParties[0] - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - crp := crpGenerator.ReadLvlNew(levelMax) + crp := P0.SampleCRP(levelOut, testCtx.crs) for i, p := range RefreshParties { p.GenShares(p.s, logBound, params.LogSlots(), ciphertext, crp, p.share) @@ -685,22 +662,20 @@ func testRefreshAndTransform(testCtx *testContext, t *testing.T) { // Drops the ciphertext to the minimum level that ensures correctness and 128-bit security testCtx.evaluator.DropLevel(ciphertext, ciphertext.Level()-minLevel) - levelMin := ciphertext.Level() - levelMax := params.MaxLevel() + levelIn := ciphertext.Level() + levelOut := params.MaxLevel() RefreshParties := make([]*Party, parties) for i := 0; i < parties; i++ { p := new(Party) p.MaskedTransformProtocol = NewMaskedTransformProtocol(params, logBound, 3.2) p.s = sk0Shards[i] - p.share = p.AllocateShare(levelMin, levelMax) + p.share = p.AllocateShare(levelIn, levelOut) RefreshParties[i] = p } P0 := RefreshParties[0] - - crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.ringQ) - crp := crpGenerator.ReadLvlNew(levelMax) + crp := P0.SampleCRP(levelOut, testCtx.crs) permute := func(ptIn, ptOut []*ring.Complex) { for i := range ptIn { @@ -727,7 +702,6 @@ func testRefreshAndTransform(testCtx *testContext, t *testing.T) { } func testMarshalling(testCtx *testContext, t *testing.T) { - crsGen := ring.NewUniformSampler(testCtx.prng, testCtx.ringQP) params := testCtx.params t.Run(testString("Marshalling/Refresh/", parties, params), func(t *testing.T) { @@ -738,14 +712,17 @@ func testMarshalling(testCtx *testContext, t *testing.T) { t.Skip("Not enough levels to ensure correcness and 128 security") } - ciphertext := ckks.NewCiphertextRandom(testCtx.prng, testCtx.params, 1, minLevel, testCtx.params.Scale()) - - crsLevel := crsGen.ReadLvlNew(minLevel) + ciphertext := ckks.NewCiphertext(params, 1, minLevel, params.Scale()) + testCtx.uniformSampler.Read(ciphertext.Value[0]) + testCtx.uniformSampler.Read(ciphertext.Value[1]) //testing refresh shares refreshproto := NewRefreshProtocol(testCtx.params, logBound, 3.2) - refreshshare := refreshproto.AllocateShare(ciphertext.Level(), ciphertext.Level()) - refreshproto.GenShares(testCtx.sk0, logBound, params.LogSlots(), ciphertext, crsLevel, refreshshare) + refreshshare := refreshproto.AllocateShare(ciphertext.Level(), params.MaxLevel()) + + crp := refreshproto.SampleCRP(params.MaxLevel(), testCtx.crs) + + refreshproto.GenShares(testCtx.sk0, logBound, params.LogSlots(), ciphertext, crp, refreshshare) data, err := refreshshare.MarshalBinary() diff --git a/dckks/refresh.go b/dckks/refresh.go index de38a8c2..c2c9a172 100644 --- a/dckks/refresh.go +++ b/dckks/refresh.go @@ -2,7 +2,7 @@ package dckks import ( "github.com/ldsec/lattigo/v2/ckks" - "github.com/ldsec/lattigo/v2/ring" + "github.com/ldsec/lattigo/v2/drlwe" "github.com/ldsec/lattigo/v2/rlwe" ) @@ -24,8 +24,9 @@ func NewRefreshProtocol(params ckks.Parameters, precision int, sigmaSmudging flo } // AllocateShare allocates the shares of the PermuteProtocol -func (rfp *RefreshProtocol) AllocateShare(minLevel, maxLevel int) *RefreshShare { - return &RefreshShare{*rfp.MaskedTransformProtocol.AllocateShare(minLevel, maxLevel)} +func (rfp *RefreshProtocol) AllocateShare(inputLevel, outputLevel int) *RefreshShare { + share := rfp.MaskedTransformProtocol.AllocateShare(inputLevel, outputLevel) + return &RefreshShare{*share} } // GenShares generates a share for the Refresh protocol. @@ -35,7 +36,7 @@ func (rfp *RefreshProtocol) AllocateShare(minLevel, maxLevel int) *RefreshShare // // The method "GetMinimumLevelForBootstrapping" should be used to get the minimum level at which the refresh can be called while still ensure 128-bits of security, as well as the // value for logBound. -func (rfp *RefreshProtocol) GenShares(sk *rlwe.SecretKey, logBound, logSlots int, ciphertext *ckks.Ciphertext, crs *ring.Poly, shareOut *RefreshShare) { +func (rfp *RefreshProtocol) GenShares(sk *rlwe.SecretKey, logBound, logSlots int, ciphertext *ckks.Ciphertext, crs drlwe.CKSCRP, shareOut *RefreshShare) { rfp.MaskedTransformProtocol.GenShares(sk, logBound, logSlots, ciphertext, crs, nil, &shareOut.MaskedTransformShare) } @@ -45,6 +46,6 @@ func (rfp *RefreshProtocol) Aggregate(share1, share2, shareOut *RefreshShare) { } // Finalize applies Decrypt, Recode and Recrypt on the input ciphertext. -func (rfp *RefreshProtocol) Finalize(ciphertext *ckks.Ciphertext, logSlots int, crs *ring.Poly, share *RefreshShare, ciphertextOut *ckks.Ciphertext) { +func (rfp *RefreshProtocol) Finalize(ciphertext *ckks.Ciphertext, logSlots int, crs drlwe.CKSCRP, share *RefreshShare, ciphertextOut *ckks.Ciphertext) { rfp.MaskedTransformProtocol.Transform(ciphertext, logSlots, nil, crs, &share.MaskedTransformShare, ciphertextOut) } diff --git a/dckks/sharing.go b/dckks/sharing.go index 50034d22..74ba33cd 100644 --- a/dckks/sharing.go +++ b/dckks/sharing.go @@ -145,6 +145,7 @@ func (e2s *E2SProtocol) GetShare(secretShare *rlwe.AdditiveShareBigint, aggregat // required by the shares-to-encryption protocol. type S2EProtocol struct { CKSProtocol + params ckks.Parameters ringQ *ring.Ring tmp *ring.Poly ssBigint []*big.Int @@ -155,6 +156,7 @@ type S2EProtocol struct { func NewS2EProtocol(params ckks.Parameters, sigmaSmudging float64) *S2EProtocol { s2e := new(S2EProtocol) s2e.CKSProtocol = *NewCKSProtocol(params, sigmaSmudging) + s2e.params = params s2e.ringQ = params.RingQ() s2e.tmp = s2e.ringQ.NewPoly() s2e.ssBigint = make([]*big.Int, s2e.ringQ.N) @@ -171,7 +173,9 @@ func (s2e S2EProtocol) AllocateShare(level int) (share *drlwe.CKSShare) { // GenShare generates a party's in the shares-to-encryption protocol given the party's secret-key share `sk`, a common // polynomial sampled from the CRS `c1` and the party's secret share of the message. -func (s2e *S2EProtocol) GenShare(sk *rlwe.SecretKey, c1 *ring.Poly, secretShare *rlwe.AdditiveShareBigint, c0ShareOut *drlwe.CKSShare) { +func (s2e *S2EProtocol) GenShare(sk *rlwe.SecretKey, crs drlwe.CKSCRP, secretShare *rlwe.AdditiveShareBigint, c0ShareOut *drlwe.CKSShare) { + + c1 := ring.Poly(crs) if c1.Level() != c0ShareOut.Value.Level() { panic("c1 and c0ShareOut level must be equal") @@ -179,7 +183,7 @@ func (s2e *S2EProtocol) GenShare(sk *rlwe.SecretKey, c1 *ring.Poly, secretShare // Generates an encryption share c1.IsNTT = true - s2e.CKSProtocol.GenShare(s2e.zero, sk, &rlwe.Ciphertext{Value: []*ring.Poly{nil, c1}}, c0ShareOut) + s2e.CKSProtocol.GenShare(s2e.zero, sk, &rlwe.Ciphertext{Value: []*ring.Poly{nil, &c1}}, c0ShareOut) s2e.ringQ.SetCoefficientsBigintLvl(c1.Level(), secretShare.Value, s2e.tmp) s2e.ringQ.NTTLvl(c1.Level(), s2e.tmp, s2e.tmp) @@ -188,12 +192,14 @@ func (s2e *S2EProtocol) GenShare(sk *rlwe.SecretKey, c1 *ring.Poly, secretShare // GetEncryption computes the final encryption of the secret-shared message when provided with the aggregation `c0Agg` of the parties' // share in the protocol and with the common, CRS-sampled polynomial `c1`. -func (s2e *S2EProtocol) GetEncryption(c0Agg *drlwe.CKSShare, c1 *ring.Poly, ctOut *ckks.Ciphertext) { +func (s2e *S2EProtocol) GetEncryption(c0Agg *drlwe.CKSShare, crs drlwe.CKSCRP, ctOut *ckks.Ciphertext) { if ctOut.Degree() != 1 { panic("ctOut must have degree 1.") } + c1 := ring.Poly(crs) + if c0Agg.Value.Level() != c1.Level() { panic("c0Agg level must be equal to c1 level") } @@ -203,5 +209,5 @@ func (s2e *S2EProtocol) GetEncryption(c0Agg *drlwe.CKSShare, c1 *ring.Poly, ctOu } ctOut.Value[0].Copy(c0Agg.Value) - ctOut.Value[1].Copy(c1) + ctOut.Value[1].Copy(&c1) } diff --git a/dckks/transform.go b/dckks/transform.go index 1a58707a..d8d102ee 100644 --- a/dckks/transform.go +++ b/dckks/transform.go @@ -9,6 +9,7 @@ import ( "github.com/ldsec/lattigo/v2/drlwe" "github.com/ldsec/lattigo/v2/ring" "github.com/ldsec/lattigo/v2/rlwe" + "github.com/ldsec/lattigo/v2/utils" ) // MaskedTransformProtocol is a struct storing the parameters for the MaskedTransformProtocol protocol. @@ -92,6 +93,12 @@ func (rfp *MaskedTransformProtocol) AllocateShare(levelDecrypt, levelRecrypt int return &MaskedTransformShare{*rfp.e2s.AllocateShare(levelDecrypt), *rfp.s2e.AllocateShare(levelRecrypt)} } +// SampleCRP samples a common random polynomial to be used in the Masked-Transform protocol from the provided +// common reference string. +func (rfp *MaskedTransformProtocol) SampleCRP(level int, crs utils.PRNG) drlwe.CKSCRP { + return rfp.s2e.SampleCRP(level, crs) +} + // GenShares generates the shares of the PermuteProtocol // This protocol requires additional inputs which are : // logBound : the bit length of the masks @@ -99,13 +106,13 @@ func (rfp *MaskedTransformProtocol) AllocateShare(levelDecrypt, levelRecrypt int // // The method "GetMinimumLevelForBootstrapping" should be used to get the minimum level at which the masked transform can be called while still ensure 128-bits of security, as well as the // value for logBound. -func (rfp *MaskedTransformProtocol) GenShares(sk *rlwe.SecretKey, logBound, logSlots int, ct *ckks.Ciphertext, crs *ring.Poly, transform MaskedTransformFunc, shareOut *MaskedTransformShare) { +func (rfp *MaskedTransformProtocol) GenShares(sk *rlwe.SecretKey, logBound, logSlots int, ct *ckks.Ciphertext, crs drlwe.CKSCRP, transform MaskedTransformFunc, shareOut *MaskedTransformShare) { if ct.Level() != shareOut.e2sShare.Value.Level() { panic("ciphertext level must be equal to e2sShare") } - if crs.Level() != shareOut.s2eShare.Value.Level() { + if (*ring.Poly)(&crs).Level() != shareOut.s2eShare.Value.Level() { panic("crs level must be equal to s2eShare") } @@ -172,18 +179,18 @@ func (rfp *MaskedTransformProtocol) Aggregate(share1, share2, shareOut *MaskedTr } // Transform applies Decrypt, Recode and Recrypt on the input ciphertext. -func (rfp *MaskedTransformProtocol) Transform(ct *ckks.Ciphertext, logSlots int, transform MaskedTransformFunc, crs *ring.Poly, share *MaskedTransformShare, ciphertextOut *ckks.Ciphertext) { +func (rfp *MaskedTransformProtocol) Transform(ct *ckks.Ciphertext, logSlots int, transform MaskedTransformFunc, crs drlwe.CKSCRP, share *MaskedTransformShare, ciphertextOut *ckks.Ciphertext) { if ct.Level() != share.e2sShare.Value.Level() { panic("ciphertext level and e2s level must be the same") } - if crs.Level() != share.s2eShare.Value.Level() { + maxLevel := (*ring.Poly)(&crs).Level() + + if maxLevel != share.s2eShare.Value.Level() { panic("crs level and s2e level must be the same") } - maxLevel := crs.Level() - // Returns -sum(M_i) + x (outside of the NTT domain) rfp.e2s.GetShare(nil, &share.e2sShare, ct, &rlwe.AdditiveShareBigint{Value: rfp.tmpMask}) @@ -225,7 +232,7 @@ func (rfp *MaskedTransformProtocol) Transform(ct *ckks.Ciphertext, logSlots int, } // Extend the levels of the ciphertext for future allocation - for ciphertextOut.Level() != crs.Level() { + for ciphertextOut.Level() != maxLevel { level := ciphertextOut.Level() + 1 ciphertextOut.Value[0].Coeffs = append(ciphertextOut.Value[0].Coeffs, make([][]uint64, 1)...) diff --git a/dckks/utils.go b/dckks/utils.go index 63eb8ae5..56ddb82e 100644 --- a/dckks/utils.go +++ b/dckks/utils.go @@ -1,32 +1,10 @@ package dckks import ( - "github.com/ldsec/lattigo/v2/ring" "math" "math/bits" ) -func extendBasisSmallNormAndCenter(ringQ, ringP *ring.Ring, polQ, polP *ring.Poly) { - var coeff, Q, QHalf, sign uint64 - Q = ringQ.Modulus[0] - QHalf = Q >> 1 - - for j := 0; j < ringQ.N; j++ { - - coeff = polQ.Coeffs[0][j] - - sign = 1 - if coeff > QHalf { - coeff = Q - coeff - sign = 0 - } - - for i, pi := range ringP.Modulus { - polP.Coeffs[i][j] = (coeff * sign) | (pi-coeff)*(sign^1) - } - } -} - // GetMinimumLevelForBootstrapping takes the security parameter lambda, the ciphertext scale, the number of parties and the moduli chain // and returns the minimum level at which the collective refresh can be called with a security of at least 128-bits. // It returns 3 parameters : diff --git a/drlwe/crs.go b/drlwe/crs.go new file mode 100644 index 00000000..c3e80080 --- /dev/null +++ b/drlwe/crs.go @@ -0,0 +1,12 @@ +package drlwe + +import ( + "github.com/ldsec/lattigo/v2/utils" +) + +// CRS is an interface for Common Reference Strings. +// CRSs are PRNGs for which the read bits are the same for +// all parties. +type CRS interface { + utils.PRNG +} diff --git a/drlwe/drlwe_test.go b/drlwe/drlwe_test.go index 4f66ef36..857de1b0 100644 --- a/drlwe/drlwe_test.go +++ b/drlwe/drlwe_test.go @@ -4,15 +4,16 @@ import ( "encoding/json" "flag" "fmt" - "github.com/ldsec/lattigo/v2/ring" - "github.com/ldsec/lattigo/v2/rlwe" - "github.com/ldsec/lattigo/v2/utils" - "github.com/stretchr/testify/require" "math" "math/big" "math/bits" "runtime" "testing" + + "github.com/ldsec/lattigo/v2/ring" + "github.com/ldsec/lattigo/v2/rlwe" + "github.com/ldsec/lattigo/v2/utils" + "github.com/stretchr/testify/require" ) var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short and -long.") @@ -34,26 +35,25 @@ type testContext struct { params rlwe.Parameters kgen rlwe.KeyGenerator sk0, sk1, sk2, skIdeal *rlwe.SecretKey - crpGenerator *ring.UniformSampler + uniformSampler *ring.UniformSampler + crs utils.PRNG } func newTestContext(params rlwe.Parameters) testContext { - var err error + kgen := rlwe.NewKeyGenerator(params) sk0 := kgen.GenSecretKey() sk1 := kgen.GenSecretKey() sk2 := kgen.GenSecretKey() skIdeal := sk0.CopyNew() - params.RingQP().Add(skIdeal.Value, sk1.Value, skIdeal.Value) - params.RingQP().Add(skIdeal.Value, sk2.Value, skIdeal.Value) + levelQ, levelP := params.QCount()-1, params.PCount()-1 + params.RingQP().AddLvl(levelQ, levelP, skIdeal.Value, sk1.Value, skIdeal.Value) + params.RingQP().AddLvl(levelQ, levelP, skIdeal.Value, sk2.Value, skIdeal.Value) - var prng utils.PRNG - if prng, err = utils.NewPRNG(); err != nil { - panic(err) - } - crpGenerator := ring.NewUniformSampler(prng, params.RingQP()) + prng, _ := utils.NewKeyedPRNG([]byte{'t', 'e', 's', 't'}) + unifSampler := ring.NewUniformSampler(prng, params.RingQ()) - return testContext{params, kgen, sk0, sk1, sk2, skIdeal, crpGenerator} + return testContext{params, kgen, sk0, sk1, sk2, skIdeal, unifSampler, prng} } func TestDRLWE(t *testing.T) { @@ -93,34 +93,38 @@ func TestDRLWE(t *testing.T) { func testPublicKeyGen(testCtx testContext, t *testing.T) { params := testCtx.params + ringQ := params.RingQ() + ringP := params.RingP() ringQP := params.RingQP() + levelQ, levelP := params.QCount()-1, params.PCount()-1 t.Run(testString(params, "PublicKeyGen/"), func(t *testing.T) { - CKGProtocol := NewCKGProtocol(params) + ckg := NewCKGProtocol(params) - share0 := CKGProtocol.AllocateShares() - share1 := CKGProtocol.AllocateShares() - share2 := CKGProtocol.AllocateShares() + share0 := ckg.AllocateShares() + share1 := ckg.AllocateShares() + share2 := ckg.AllocateShares() - crp := testCtx.crpGenerator.ReadNew() + crp := ckg.SampleCRP(testCtx.crs) - CKGProtocol.GenShare(testCtx.sk0, crp, share0) - CKGProtocol.GenShare(testCtx.sk1, crp, share1) - CKGProtocol.GenShare(testCtx.sk2, crp, share2) + ckg.GenShare(testCtx.sk0, crp, share0) + ckg.GenShare(testCtx.sk1, crp, share1) + ckg.GenShare(testCtx.sk2, crp, share2) - CKGProtocol.AggregateShares(share0, share1, share0) - CKGProtocol.AggregateShares(share0, share2, share0) + ckg.AggregateShares(share0, share1, share0) + ckg.AggregateShares(share0, share2, share0) pk := rlwe.NewPublicKey(params) - CKGProtocol.GenPublicKey(share0, crp, pk) + ckg.GenPublicKey(share0, crp, pk) // [-as + e] + [as] - ringQP.MulCoeffsMontgomeryAndAdd(testCtx.skIdeal.Value, pk.Value[1], pk.Value[0]) - ringQP.InvNTT(pk.Value[0], pk.Value[0]) + ringQP.MulCoeffsMontgomeryAndAddLvl(levelQ, levelP, testCtx.skIdeal.Value, pk.Value[1], pk.Value[0]) + ringQP.InvNTTLvl(levelQ, levelP, pk.Value[0], pk.Value[0]) log2Bound := bits.Len64(3 * uint64(math.Floor(rlwe.DefaultSigma*6)) * uint64(params.N())) - require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(pk.Value[0].Level(), ringQP, pk.Value[0])) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(pk.Value[0].Q.Level(), ringQ, pk.Value[0].Q)) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(pk.Value[0].P.Level(), ringP, pk.Value[0].P)) }) } @@ -128,8 +132,8 @@ func testKeySwitching(testCtx testContext, t *testing.T) { params := testCtx.params ringQ := params.RingQ() - level := len(ringQ.Modulus) - 1 - + ringQP := params.RingQP() + levelQ, levelP := params.QCount()-1, params.PCount()-1 t.Run(testString(params, "KeySwitching/"), func(t *testing.T) { sk0Out := testCtx.kgen.GenSecretKey() @@ -137,33 +141,34 @@ func testKeySwitching(testCtx testContext, t *testing.T) { sk2Out := testCtx.kgen.GenSecretKey() skOutIdeal := sk0Out.CopyNew() - params.RingQP().Add(skOutIdeal.Value, sk1Out.Value, skOutIdeal.Value) - params.RingQP().Add(skOutIdeal.Value, sk2Out.Value, skOutIdeal.Value) + ringQP.AddLvl(levelQ, levelP, skOutIdeal.Value, sk1Out.Value, skOutIdeal.Value) + ringQP.AddLvl(levelQ, levelP, skOutIdeal.Value, sk2Out.Value, skOutIdeal.Value) - ciphertext := &rlwe.Ciphertext{Value: []*ring.Poly{ringQ.NewPoly(), testCtx.crpGenerator.ReadLvlNew(level)}} - ringQ.MulCoeffsMontgomeryAndSub(ciphertext.Value[1], testCtx.skIdeal.Value, ciphertext.Value[0]) + ciphertext := &rlwe.Ciphertext{Value: []*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly()}} + testCtx.uniformSampler.Read(ciphertext.Value[1]) + ringQ.MulCoeffsMontgomeryAndSub(ciphertext.Value[1], testCtx.skIdeal.Value.Q, ciphertext.Value[0]) ciphertext.Value[0].IsNTT = true ciphertext.Value[1].IsNTT = true - CKSProtocol := NewCKSProtocol(params, rlwe.DefaultSigma) + cks := NewCKSProtocol(params, rlwe.DefaultSigma) - share0 := CKSProtocol.AllocateShare(level) - share1 := CKSProtocol.AllocateShare(level) - share2 := CKSProtocol.AllocateShare(level) + share0 := cks.AllocateShare(ciphertext.Level()) + share1 := cks.AllocateShare(ciphertext.Level()) + share2 := cks.AllocateShare(ciphertext.Level()) - CKSProtocol.GenShare(testCtx.sk0, sk0Out, ciphertext, share0) - CKSProtocol.GenShare(testCtx.sk1, sk1Out, ciphertext, share1) - CKSProtocol.GenShare(testCtx.sk2, sk2Out, ciphertext, share2) + cks.GenShare(testCtx.sk0, sk0Out, ciphertext, share0) + cks.GenShare(testCtx.sk1, sk1Out, ciphertext, share1) + cks.GenShare(testCtx.sk2, sk2Out, ciphertext, share2) - CKSProtocol.AggregateShares(share0, share1, share0) - CKSProtocol.AggregateShares(share0, share2, share0) + cks.AggregateShares(share0, share1, share0) + cks.AggregateShares(share0, share2, share0) ksCiphertext := &rlwe.Ciphertext{Value: []*ring.Poly{params.RingQ().NewPoly(), params.RingQ().NewPoly()}} - CKSProtocol.KeySwitch(share0, ciphertext, ksCiphertext) + cks.KeySwitch(share0, ciphertext, ksCiphertext) // [-as + e] + [as] - ringQ.MulCoeffsMontgomeryAndAdd(ksCiphertext.Value[1], skOutIdeal.Value, ksCiphertext.Value[0]) + ringQ.MulCoeffsMontgomeryAndAdd(ksCiphertext.Value[1], skOutIdeal.Value.Q, ksCiphertext.Value[0]) ringQ.InvNTT(ksCiphertext.Value[0], ksCiphertext.Value[0]) log2Bound := bits.Len64(3 * uint64(math.Floor(rlwe.DefaultSigma*6)) * uint64(params.N())) require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(ksCiphertext.Value[0].Level(), ringQ, ksCiphertext.Value[0])) @@ -175,22 +180,22 @@ func testPublicKeySwitching(testCtx testContext, t *testing.T) { params := testCtx.params ringQ := params.RingQ() - level := len(ringQ.Modulus) - 1 t.Run(testString(params, "PublicKeySwitching/"), func(t *testing.T) { skOut, pkOut := testCtx.kgen.GenKeyPair() - ciphertext := &rlwe.Ciphertext{Value: []*ring.Poly{ringQ.NewPoly(), testCtx.crpGenerator.ReadLvlNew(level)}} - ringQ.MulCoeffsMontgomeryAndSub(ciphertext.Value[1], testCtx.skIdeal.Value, ciphertext.Value[0]) + ciphertext := &rlwe.Ciphertext{Value: []*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly()}} + testCtx.uniformSampler.Read(ciphertext.Value[1]) + ringQ.MulCoeffsMontgomeryAndSub(ciphertext.Value[1], testCtx.skIdeal.Value.Q, ciphertext.Value[0]) ciphertext.Value[0].IsNTT = true ciphertext.Value[1].IsNTT = true PCKSProtocol := NewPCKSProtocol(params, rlwe.DefaultSigma) - share0 := PCKSProtocol.AllocateShare(level) - share1 := PCKSProtocol.AllocateShare(level) - share2 := PCKSProtocol.AllocateShare(level) + share0 := PCKSProtocol.AllocateShare(ciphertext.Level()) + share1 := PCKSProtocol.AllocateShare(ciphertext.Level()) + share2 := PCKSProtocol.AllocateShare(ciphertext.Level()) PCKSProtocol.GenShare(testCtx.sk0, pkOut, ciphertext, share0) PCKSProtocol.GenShare(testCtx.sk1, pkOut, ciphertext, share1) @@ -204,7 +209,7 @@ func testPublicKeySwitching(testCtx testContext, t *testing.T) { PCKSProtocol.KeySwitch(share0, ciphertext, ksCiphertext) // [-as + e] + [as] - ringQ.MulCoeffsMontgomeryAndAdd(ksCiphertext.Value[1], skOut.Value, ksCiphertext.Value[0]) + ringQ.MulCoeffsMontgomeryAndAdd(ksCiphertext.Value[1], skOut.Value.Q, ksCiphertext.Value[0]) ringQ.InvNTT(ksCiphertext.Value[0], ksCiphertext.Value[0]) log2Bound := bits.Len64(3 * uint64(math.Floor(rlwe.DefaultSigma*6)) * uint64(params.N())) require.GreaterOrEqual(t, log2Bound+5, log2OfInnerSum(ksCiphertext.Value[0].Level(), ringQ, ksCiphertext.Value[0])) @@ -214,49 +219,48 @@ func testPublicKeySwitching(testCtx testContext, t *testing.T) { func testRelinKeyGen(testCtx testContext, t *testing.T) { params := testCtx.params - ringQP := params.RingQP() ringQ := params.RingQ() + ringP := params.RingP() + ringQP := params.RingQP() + levelQ, levelP := params.QCount()-1, params.PCount()-1 t.Run(testString(params, "RelinKeyGen/"), func(t *testing.T) { - RKGProtocol := NewRKGProtocol(params, rlwe.DefaultSigma) + rkg := NewRKGProtocol(params, rlwe.DefaultSigma) - ephSk0, share10, share20 := RKGProtocol.AllocateShares() - ephSk1, share11, share21 := RKGProtocol.AllocateShares() - ephSk2, share12, share22 := RKGProtocol.AllocateShares() + ephSk0, share10, share20 := rkg.AllocateShares() + ephSk1, share11, share21 := rkg.AllocateShares() + ephSk2, share12, share22 := rkg.AllocateShares() - crp := make([]*ring.Poly, params.Beta()) - for i := 0; i < params.Beta(); i++ { - crp[i] = testCtx.crpGenerator.ReadNew() - } + crp := rkg.SampleCRP(testCtx.crs) - RKGProtocol.GenShareRoundOne(testCtx.sk0, crp, ephSk0, share10) - RKGProtocol.GenShareRoundOne(testCtx.sk1, crp, ephSk1, share11) - RKGProtocol.GenShareRoundOne(testCtx.sk2, crp, ephSk2, share12) + rkg.GenShareRoundOne(testCtx.sk0, crp, ephSk0, share10) + rkg.GenShareRoundOne(testCtx.sk1, crp, ephSk1, share11) + rkg.GenShareRoundOne(testCtx.sk2, crp, ephSk2, share12) - RKGProtocol.AggregateShares(share10, share11, share10) - RKGProtocol.AggregateShares(share10, share12, share10) + rkg.AggregateShares(share10, share11, share10) + rkg.AggregateShares(share10, share12, share10) - RKGProtocol.GenShareRoundTwo(ephSk0, testCtx.sk0, share10, crp, share20) - RKGProtocol.GenShareRoundTwo(ephSk1, testCtx.sk1, share10, crp, share21) - RKGProtocol.GenShareRoundTwo(ephSk2, testCtx.sk2, share10, crp, share22) + rkg.GenShareRoundTwo(ephSk0, testCtx.sk0, share10, share20) + rkg.GenShareRoundTwo(ephSk1, testCtx.sk1, share10, share21) + rkg.GenShareRoundTwo(ephSk2, testCtx.sk2, share10, share22) - RKGProtocol.AggregateShares(share20, share21, share20) - RKGProtocol.AggregateShares(share20, share22, share20) + rkg.AggregateShares(share20, share21, share20) + rkg.AggregateShares(share20, share22, share20) rlk := rlwe.NewRelinKey(params, 2) - RKGProtocol.GenRelinearizationKey(share10, share20, rlk) + rkg.GenRelinearizationKey(share10, share20, rlk) skIn := testCtx.skIdeal.CopyNew() skOut := testCtx.skIdeal.CopyNew() - ringQP.MulCoeffsMontgomery(skIn.Value, skIn.Value, skIn.Value) + ringQP.MulCoeffsMontgomeryLvl(levelQ, levelP, skIn.Value, skIn.Value, skIn.Value) swk := rlk.Keys[0] // Decrypts // [-asIn + w*P*sOut + e, a] + [asIn] for j := range swk.Value { - ringQ.MulCoeffsMontgomeryAndAdd(swk.Value[j][1], skOut.Value, swk.Value[j][0]) + ringQP.MulCoeffsMontgomeryAndAddLvl(levelQ, levelP, swk.Value[j][1], skOut.Value, swk.Value[j][0]) } poly := swk.Value[0][0] @@ -265,24 +269,26 @@ func testRelinKeyGen(testCtx testContext, t *testing.T) { // sum([1]_w * [w*P*sOut + e]) = P*sOut + sum(e) for j := range swk.Value { if j > 0 { - ringQ.Add(poly, swk.Value[j][0], poly) + ringQP.AddLvl(levelQ, levelP, poly, swk.Value[j][0], poly) } } // sOut * P - ringQ.MulScalarBigint(skIn.Value, params.RingP().ModulusBigint, skIn.Value) + ringQ.MulScalarBigint(skIn.Value.Q, ringP.ModulusBigint, skIn.Value.Q) // P*s^i + sum(e) - P*s^i = sum(e) - ringQ.Sub(poly, skIn.Value, poly) + ringQ.Sub(swk.Value[0][0].Q, skIn.Value.Q, swk.Value[0][0].Q) // Checks that the error is below the bound - ringQ.InvNTT(poly, poly) - ringQ.InvMForm(poly, poly) + // Worst error bound is N * floor(6*sigma) * #Keys + ringQP.InvNTTLvl(levelQ, levelP, poly, poly) + ringQP.InvMFormLvl(levelQ, levelP, poly, poly) // Worst bound of inner sum // N*#Keys*(N * #Parties * floor(sigma*6) + #Parties * floor(sigma*6) + N * #Parties + #Parties * floor(6*sigma)) log2Bound := bits.Len64(uint64(params.N() * len(swk.Value) * (params.N()*3*int(math.Floor(rlwe.DefaultSigma*6)) + 2*3*int(math.Floor(rlwe.DefaultSigma*6)) + params.N()*3))) - require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringQ.Modulus)-1, ringQ, poly)) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringQ.Modulus)-1, ringQ, swk.Value[0][0].Q)) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringP.Modulus)-1, ringP, swk.Value[0][0].P)) }) } @@ -290,85 +296,88 @@ func testRotKeyGen(testCtx testContext, t *testing.T) { params := testCtx.params ringQ := params.RingQ() + ringP := params.RingP() + ringQP := params.RingQP() + levelQ, levelP := params.QCount()-1, params.PCount()-1 t.Run(testString(params, "RotKeyGen/"), func(t *testing.T) { - crp := make([]*ring.Poly, params.Beta()) - for i := 0; i < params.Beta(); i++ { - crp[i] = testCtx.crpGenerator.ReadNew() - } + rtg := NewRTGProtocol(params) - RTGProtocol := NewRTGProtocol(params) + share0 := rtg.AllocateShares() + share1 := rtg.AllocateShares() + share2 := rtg.AllocateShares() - share0 := RTGProtocol.AllocateShares() - share1 := RTGProtocol.AllocateShares() - share2 := RTGProtocol.AllocateShares() + crp := rtg.SampleCRP(testCtx.crs) galEl := params.GaloisElementForRowRotation() - RTGProtocol.GenShare(testCtx.sk0, galEl, crp, share0) - RTGProtocol.GenShare(testCtx.sk1, galEl, crp, share1) - RTGProtocol.GenShare(testCtx.sk2, galEl, crp, share2) + rtg.GenShare(testCtx.sk0, galEl, crp, share0) + rtg.GenShare(testCtx.sk1, galEl, crp, share1) + rtg.GenShare(testCtx.sk2, galEl, crp, share2) - RTGProtocol.Aggregate(share0, share1, share0) - RTGProtocol.Aggregate(share0, share2, share0) + rtg.Aggregate(share0, share1, share0) + rtg.Aggregate(share0, share2, share0) rotKeySet := rlwe.NewRotationKeySet(params, []uint64{galEl}) - RTGProtocol.GenRotationKey(share0, crp, rotKeySet.Keys[galEl]) + rtg.GenRotationKey(share0, crp, rotKeySet.Keys[galEl]) skIn := testCtx.skIdeal.CopyNew() skOut := testCtx.skIdeal.CopyNew() - galElInv := ring.ModExp(galEl, int(4*params.N()-1), uint64(4*params.N())) - ring.PermuteNTT(testCtx.skIdeal.Value, galElInv, skOut.Value) + galElInv := ring.ModExp(galEl, uint64(2*params.N()-1), uint64(2*params.N())) + ring.PermuteNTT(testCtx.skIdeal.Value.Q, galElInv, skOut.Value.Q) + ring.PermuteNTT(testCtx.skIdeal.Value.P, galElInv, skOut.Value.P) swk := rotKeySet.Keys[galEl] // Decrypts // [-asIn + w*P*sOut + e, a] + [asIn] for j := range swk.Value { - ringQ.MulCoeffsMontgomeryAndAdd(swk.Value[j][1], skOut.Value, swk.Value[j][0]) - } + ringQP.MulCoeffsMontgomeryAndAddLvl(levelQ, levelP, swk.Value[j][1], skOut.Value, swk.Value[j][0]) - poly := swk.Value[0][0] + } // Sums all basis together (equivalent to multiplying with CRT decomposition of 1) // sum([1]_w * [w*P*sOut + e]) = P*sOut + sum(e) for j := range swk.Value { if j > 0 { - ringQ.Add(poly, swk.Value[j][0], poly) + ringQP.AddLvl(levelQ, levelP, swk.Value[0][0], swk.Value[j][0], swk.Value[0][0]) } } // sOut * P - ringQ.MulScalarBigint(skIn.Value, params.RingP().ModulusBigint, skIn.Value) + ringQ.MulScalarBigint(skIn.Value.Q, ringP.ModulusBigint, skIn.Value.Q) // P*s^i + sum(e) - P*s^i = sum(e) - ringQ.Sub(poly, skIn.Value, poly) + ringQ.Sub(swk.Value[0][0].Q, skIn.Value.Q, swk.Value[0][0].Q) // Checks that the error is below the bound - ringQ.InvNTT(poly, poly) - ringQ.InvMForm(poly, poly) + // Worst error bound is N * floor(6*sigma) * #Keys + ringQP.InvNTTLvl(levelQ, levelP, swk.Value[0][0], swk.Value[0][0]) + ringQP.InvMFormLvl(levelQ, levelP, swk.Value[0][0], swk.Value[0][0]) // Worst bound of inner sum // N*#Keys*(N * #Parties * floor(sigma*6) + #Parties * floor(sigma*6) + N * #Parties + #Parties * floor(6*sigma)) - log2Bound := bits.Len64(uint64(params.N() * len(swk.Value) * (params.N()*3*int(math.Floor(rlwe.DefaultSigma*6)) + 2*3*int(math.Floor(rlwe.DefaultSigma*6)) + params.N()*3))) - require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringQ.Modulus)-1, ringQ, poly)) + log2Bound := bits.Len64(3 * uint64(math.Floor(rlwe.DefaultSigma*6)) * uint64(params.N())) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringQ.Modulus)-1, ringQ, swk.Value[0][0].Q)) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringP.Modulus)-1, ringP, swk.Value[0][0].P)) }) } func testMarshalling(testCtx testContext, t *testing.T) { - crs := testCtx.crpGenerator.ReadNew() params := testCtx.params - level := len(params.RingQ().Modulus) - 1 + ciphertext := &rlwe.Ciphertext{Value: []*ring.Poly{params.RingQ().NewPoly(), params.RingQ().NewPoly()}} + testCtx.uniformSampler.Read(ciphertext.Value[0]) + testCtx.uniformSampler.Read(ciphertext.Value[1]) - ciphertext := &rlwe.Ciphertext{Value: []*ring.Poly{testCtx.crpGenerator.ReadLvlNew(level), testCtx.crpGenerator.ReadLvlNew(level)}} + t.Run(testString(params, "Marshalling/CKG/"), func(t *testing.T) { + ckg := NewCKGProtocol(testCtx.params) + KeyGenShareBefore := ckg.AllocateShares() + crs := ckg.SampleCRP(testCtx.crs) - t.Run(testString(params, "Marshalling/CPK/"), func(t *testing.T) { - keygenProtocol := NewCKGProtocol(testCtx.params) - KeyGenShareBefore := keygenProtocol.AllocateShares() - keygenProtocol.GenShare(testCtx.sk0, crs, KeyGenShareBefore) + ckg.GenShare(testCtx.sk0, crs, KeyGenShareBefore) //now we marshall it data, err := KeyGenShareBefore.MarshalBinary() @@ -383,11 +392,12 @@ func testMarshalling(testCtx testContext, t *testing.T) { } //comparing the results - require.Equal(t, KeyGenShareBefore.Degree(), KeyGenShareAfter.Degree()) - require.Equal(t, KeyGenShareBefore.LenModuli(), KeyGenShareAfter.LenModuli()) - - moduli := KeyGenShareBefore.LenModuli() - require.Equal(t, KeyGenShareAfter.Coeffs[:moduli], KeyGenShareBefore.Coeffs[:moduli]) + require.Equal(t, KeyGenShareBefore.Value.Q.Degree(), KeyGenShareAfter.Value.Q.Degree()) + require.Equal(t, KeyGenShareBefore.Value.P.Degree(), KeyGenShareAfter.Value.P.Degree()) + require.Equal(t, KeyGenShareBefore.Value.Q.LenModuli(), KeyGenShareAfter.Value.Q.LenModuli()) + require.Equal(t, KeyGenShareBefore.Value.P.LenModuli(), KeyGenShareAfter.Value.P.LenModuli()) + require.Equal(t, KeyGenShareAfter.Value.Q.Coeffs, KeyGenShareBefore.Value.Q.Coeffs) + require.Equal(t, KeyGenShareAfter.Value.P.Coeffs, KeyGenShareBefore.Value.P.Coeffs) }) t.Run(testString(params, "Marshalling/PCKS/"), func(t *testing.T) { @@ -405,14 +415,12 @@ func testMarshalling(testCtx testContext, t *testing.T) { err = SwitchShareReceiver.UnmarshalBinary(data) require.NoError(t, err) - for i := 0; i < 2; i++ { - //compare the shares. - ringBefore := SwitchShare.Value[i] - ringAfter := SwitchShareReceiver.Value[i] - require.Equal(t, ringBefore.Degree(), ringAfter.Degree()) - moduli := ringAfter.LenModuli() - require.Equal(t, ringAfter.Coeffs[:moduli], ringBefore.Coeffs[:moduli]) - } + require.Equal(t, SwitchShare.Value[0].Degree(), SwitchShareReceiver.Value[0].Degree()) + require.Equal(t, SwitchShare.Value[1].Degree(), SwitchShareReceiver.Value[1].Degree()) + require.Equal(t, SwitchShare.Value[0].LenModuli(), SwitchShareReceiver.Value[0].LenModuli()) + require.Equal(t, SwitchShare.Value[1].LenModuli(), SwitchShareReceiver.Value[1].LenModuli()) + require.Equal(t, SwitchShare.Value[0].Coeffs, SwitchShareReceiver.Value[0].Coeffs) + require.Equal(t, SwitchShare.Value[1].Coeffs, SwitchShareReceiver.Value[1].Coeffs) }) t.Run(testString(params, "Marshalling/CKS/"), func(t *testing.T) { @@ -433,8 +441,7 @@ func testMarshalling(testCtx testContext, t *testing.T) { require.Equal(t, cksshare.Value.Degree(), cksshareAfter.Value.Degree()) require.Equal(t, cksshare.Value.LenModuli(), cksshareAfter.Value.LenModuli()) - moduli := cksshare.Value.LenModuli() - require.Equal(t, cksshare.Value.Coeffs[:moduli], cksshareAfter.Value.Coeffs[:moduli]) + require.Equal(t, cksshare.Value.Coeffs, cksshareAfter.Value.Coeffs) }) t.Run(testString(params, "Marshalling/RKG/"), func(t *testing.T) { @@ -445,10 +452,7 @@ func testMarshalling(testCtx testContext, t *testing.T) { ephSk0, share10, _ := RKGProtocol.AllocateShares() - crp := make([]*ring.Poly, params.Beta()) - for i := 0; i < params.Beta(); i++ { - crp[i] = testCtx.crpGenerator.ReadNew() - } + crp := RKGProtocol.SampleCRP(testCtx.crs) RKGProtocol.GenShareRoundOne(testCtx.sk0, crp, ephSk0, share10) @@ -461,14 +465,15 @@ func testMarshalling(testCtx testContext, t *testing.T) { require.Equal(t, len(rkgShare.Value), len(share10.Value)) for i, val := range share10.Value { - require.Equal(t, len(rkgShare.Value[i][0].Coeffs), len(val[0].Coeffs)) - moduli := val[0].LenModuli() - require.Equal(t, rkgShare.Value[i][0].Coeffs[:moduli], val[0].Coeffs[:moduli]) - - require.Equal(t, len(rkgShare.Value[i][1].Coeffs), len(val[1].Coeffs)) - moduli = val[1].LenModuli() - require.Equal(t, rkgShare.Value[i][1].Coeffs[:moduli], val[1].Coeffs[:moduli]) + require.Equal(t, len(rkgShare.Value[i][0].Q.Coeffs), len(val[0].Q.Coeffs)) + require.Equal(t, len(rkgShare.Value[i][0].P.Coeffs), len(val[0].P.Coeffs)) + require.Equal(t, rkgShare.Value[i][0].Q.Coeffs, val[0].Q.Coeffs) + require.Equal(t, rkgShare.Value[i][0].P.Coeffs, val[0].P.Coeffs) + require.Equal(t, len(rkgShare.Value[i][1].Q.Coeffs), len(val[1].Q.Coeffs)) + require.Equal(t, len(rkgShare.Value[i][1].P.Coeffs), len(val[1].P.Coeffs)) + require.Equal(t, rkgShare.Value[i][1].Q.Coeffs, val[1].Q.Coeffs) + require.Equal(t, rkgShare.Value[i][1].P.Coeffs, val[1].P.Coeffs) } }) @@ -476,16 +481,14 @@ func testMarshalling(testCtx testContext, t *testing.T) { //check RTGShare - crp := make([]*ring.Poly, params.Beta()) - for i := 0; i < params.Beta(); i++ { - crp[i] = testCtx.crpGenerator.ReadNew() - } - galEl := testCtx.params.GaloisElementForColumnRotationBy(64) - RTGProtocol := NewRTGProtocol(testCtx.params) - rtgShare := RTGProtocol.AllocateShares() - RTGProtocol.GenShare(testCtx.sk1, galEl, crp, rtgShare) + rtg := NewRTGProtocol(testCtx.params) + rtgShare := rtg.AllocateShares() + + crp := rtg.SampleCRP(testCtx.crs) + + rtg.GenShare(testCtx.sk1, galEl, crp, rtgShare) data, err := rtgShare.MarshalBinary() require.NoError(t, err) @@ -497,10 +500,10 @@ func testMarshalling(testCtx testContext, t *testing.T) { require.Equal(t, len(resRTGShare.Value), len(rtgShare.Value)) for i, val := range rtgShare.Value { - require.Equal(t, len(resRTGShare.Value[i].Coeffs), len(val.Coeffs)) - moduli := val.LenModuli() - require.Equal(t, resRTGShare.Value[i].Coeffs[:moduli], val.Coeffs[:moduli]) - + require.Equal(t, len(resRTGShare.Value[i].Q.Coeffs), len(val.Q.Coeffs)) + require.Equal(t, resRTGShare.Value[i].Q.Coeffs, val.Q.Coeffs) + require.Equal(t, len(resRTGShare.Value[i].P.Coeffs), len(val.P.Coeffs)) + require.Equal(t, resRTGShare.Value[i].P.Coeffs, val.P.Coeffs) } }) } diff --git a/drlwe/keygen_cpk.go b/drlwe/keygen_cpk.go index 92c4fffc..156d9612 100644 --- a/drlwe/keygen_cpk.go +++ b/drlwe/keygen_cpk.go @@ -10,73 +10,89 @@ import ( // CollectivePublicKeyGenerator is an interface describing the local steps of a generic RLWE CKG protocol. type CollectivePublicKeyGenerator interface { AllocateShares() *CKGShare - GenShare(sk *rlwe.SecretKey, crs *ring.Poly, shareOut *CKGShare) + GenShare(sk *rlwe.SecretKey, crp CKGCRP, shareOut *CKGShare) AggregateShares(share1, share2, shareOut *CKGShare) - GenPublicKey(aggregatedShare *CKGShare, crs *ring.Poly, pubkey *rlwe.PublicKey) + GenPublicKey(aggregatedShare *CKGShare, crp CKGCRP, pubkey *rlwe.PublicKey) } // CKGProtocol is the structure storing the parameters and and precomputations for the collective key generation protocol. type CKGProtocol struct { - params rlwe.Parameters - - ringQP *ring.Ring - gaussianSampler *ring.GaussianSampler + params rlwe.Parameters + gaussianSamplerQ *ring.GaussianSampler } // CKGShare is a struct storing the CKG protocol's share. type CKGShare struct { - *ring.Poly + Value rlwe.PolyQP } -// UnmarshalBinary decode a marshaled CKG share on the target CKG share. -func (share *CKGShare) UnmarshalBinary(data []byte) error { - if share.Poly == nil { - share.Poly = new(ring.Poly) +// CKGCRP is a type for common reference polynomials in the CKG protocol. +type CKGCRP rlwe.PolyQP + +// MarshalBinary encodes the target element on a slice of bytes. +func (share *CKGShare) MarshalBinary() (data []byte, err error) { + data = make([]byte, share.Value.GetDataLen(true)) + if _, err = share.Value.WriteTo(data); err != nil { + return nil, err } - err := share.Poly.UnmarshalBinary(data) + return +} + +// UnmarshalBinary decodes a slice of bytes on the target element. +func (share *CKGShare) UnmarshalBinary(data []byte) (err error) { + _, err = share.Value.DecodePolyNew(data) return err } // NewCKGProtocol creates a new CKGProtocol instance -func NewCKGProtocol(params rlwe.Parameters) *CKGProtocol { // TODO drlwe.Params - +func NewCKGProtocol(params rlwe.Parameters) *CKGProtocol { ckg := new(CKGProtocol) ckg.params = params - ckg.ringQP = params.RingQP() - var err error prng, err := utils.NewPRNG() if err != nil { panic(err) } - ckg.gaussianSampler = ring.NewGaussianSampler(prng, ckg.ringQP, params.Sigma(), int(6*params.Sigma())) + ckg.gaussianSamplerQ = ring.NewGaussianSampler(prng, ckg.params.RingQ(), params.Sigma(), int(6*params.Sigma())) return ckg } // AllocateShares allocates the share of the CKG protocol. func (ckg *CKGProtocol) AllocateShares() *CKGShare { - return &CKGShare{ckg.ringQP.NewPoly()} + return &CKGShare{ckg.params.RingQP().NewPoly()} +} + +// SampleCRP samples a common random polynomial to be used in the CKG protocol from the provided +// common reference string. +func (ckg *CKGProtocol) SampleCRP(crs CRS) CKGCRP { + crp := ckg.params.RingQP().NewPoly() + rlwe.NewUniformSamplerQP(ckg.params, crs, ckg.params.RingQP()).Read(&crp) + return CKGCRP(crp) } // GenShare generates the party's public key share from its secret key as: // -// crs*s_i + e_i +// crp*s_i + e_i // // for the receiver protocol. Has no effect is the share was already generated. -func (ckg *CKGProtocol) GenShare(sk *rlwe.SecretKey, crs *ring.Poly, shareOut *CKGShare) { - ringQP := ckg.ringQP - ckg.gaussianSampler.Read(shareOut.Poly) - ringQP.NTT(shareOut.Poly, shareOut.Poly) - ringQP.MulCoeffsMontgomeryAndSub(sk.Value, crs, shareOut.Poly) +func (ckg *CKGProtocol) GenShare(sk *rlwe.SecretKey, crp CKGCRP, shareOut *CKGShare) { + ringQP := ckg.params.RingQP() + + ckg.gaussianSamplerQ.Read(shareOut.Value.Q) + ringQP.ExtendBasisSmallNormAndCenter(shareOut.Value.Q, ckg.params.PCount()-1, nil, shareOut.Value.P) + levelQ, levelP := ckg.params.QCount()-1, ckg.params.PCount()-1 + ringQP.NTTLvl(levelQ, levelP, shareOut.Value, shareOut.Value) + + ringQP.MulCoeffsMontgomeryAndSubLvl(levelQ, levelP, sk.Value, rlwe.PolyQP(crp), shareOut.Value) } // AggregateShares aggregates a new share to the aggregate key func (ckg *CKGProtocol) AggregateShares(share1, share2, shareOut *CKGShare) { - ckg.ringQP.Add(share1.Poly, share2.Poly, shareOut.Poly) + ckg.params.RingQP().AddLvl(ckg.params.QCount()-1, ckg.params.PCount()-1, share1.Value, share2.Value, shareOut.Value) } // GenPublicKey return the current aggregation of the received shares as a bfv.PublicKey. -func (ckg *CKGProtocol) GenPublicKey(roundShare *CKGShare, crs *ring.Poly, pubkey *rlwe.PublicKey) { - pubkey.Value[0].Copy(roundShare.Poly) - pubkey.Value[1].Copy(crs) +func (ckg *CKGProtocol) GenPublicKey(roundShare *CKGShare, crp CKGCRP, pubkey *rlwe.PublicKey) { + pubkey.Value[0].Copy(roundShare.Value) + pubkey.Value[1].Copy(rlwe.PolyQP(crp)) } diff --git a/drlwe/keygen_relin.go b/drlwe/keygen_relin.go index 780e41b1..0b4ca4f9 100644 --- a/drlwe/keygen_relin.go +++ b/drlwe/keygen_relin.go @@ -2,54 +2,57 @@ package drlwe import ( "errors" + "math/big" + "github.com/ldsec/lattigo/v2/ring" "github.com/ldsec/lattigo/v2/rlwe" "github.com/ldsec/lattigo/v2/utils" - "math/big" ) // RelinearizationKeyGenerator is an interface describing the local steps of a generic RLWE RKG protocol type RelinearizationKeyGenerator interface { AllocateShares() (ephKey *rlwe.SecretKey, r1 *RKGShare, r2 *RKGShare) - GenShareRoundOne(sk *rlwe.SecretKey, crp []*ring.Poly, ephKeyOut *rlwe.SecretKey, shareOut *RKGShare) - GenShareRoundTwo(ephSk, sk *rlwe.SecretKey, round1 *RKGShare, crp []*ring.Poly, shareOut *RKGShare) + GenShareRoundOne(sk *rlwe.SecretKey, crp RKGCRP, ephKeyOut *rlwe.SecretKey, shareOut *RKGShare) + GenShareRoundTwo(ephSk, sk *rlwe.SecretKey, round1 *RKGShare, shareOut *RKGShare) AggregateShares(share1, share2, shareOut *RKGShare) - GenRelinearizationKey(round1 *RKGShare, round2 *RKGShare, relinKeyOut *rlwe.RelinearizationKey) // TODO type for generic eval key + GenRelinearizationKey(round1 *RKGShare, round2 *RKGShare, relinKeyOut *rlwe.RelinearizationKey) } // RKGProtocol is the structure storing the parameters and and precomputations for the collective relinearization key generation protocol. type RKGProtocol struct { - params rlwe.Parameters - pBigInt *big.Int - ringQP *ring.Ring - gaussianSampler *ring.GaussianSampler - ternarySampler *ring.TernarySampler // sampling in Montgomerry form + params rlwe.Parameters + pBigInt *big.Int + gaussianSamplerQ *ring.GaussianSampler + ternarySamplerQ *ring.TernarySampler // sampling in Montgomerry form - tmpPoly1 *ring.Poly - tmpPoly2 *ring.Poly + tmpPoly1 rlwe.PolyQP + tmpPoly2 rlwe.PolyQP } // RKGShare is a share in the RKG protocol type RKGShare struct { - Value [][2]*ring.Poly + Value [][2]rlwe.PolyQP } +// RKGCRP is a type for common reference polynomials in the RKG protocol. +type RKGCRP []rlwe.PolyQP + // NewRKGProtocol creates a new RKG protocol struct func NewRKGProtocol(params rlwe.Parameters, ephSkPr float64) *RKGProtocol { rkg := new(RKGProtocol) rkg.params = params - rkg.ringQP = params.RingQP() var err error prng, err := utils.NewPRNG() if err != nil { - panic(err) // TODO error + panic(err) } rkg.pBigInt = params.PBigInt() - rkg.gaussianSampler = ring.NewGaussianSampler(prng, rkg.ringQP, params.Sigma(), int(6*params.Sigma())) - rkg.ternarySampler = ring.NewTernarySampler(prng, rkg.ringQP, ephSkPr, true) - rkg.tmpPoly1, rkg.tmpPoly2 = rkg.ringQP.NewPoly(), rkg.ringQP.NewPoly() + rkg.gaussianSamplerQ = ring.NewGaussianSampler(prng, params.RingQ(), params.Sigma(), int(6*params.Sigma())) + rkg.ternarySamplerQ = ring.NewTernarySampler(prng, params.RingQ(), ephSkPr, false) + rkg.tmpPoly1 = params.RingQP().NewPoly() + rkg.tmpPoly2 = params.RingQP().NewPoly() return rkg } @@ -57,61 +60,84 @@ func NewRKGProtocol(params rlwe.Parameters, ephSkPr float64) *RKGProtocol { func (ekg *RKGProtocol) AllocateShares() (ephSk *rlwe.SecretKey, r1 *RKGShare, r2 *RKGShare) { ephSk = rlwe.NewSecretKey(ekg.params) r1, r2 = new(RKGShare), new(RKGShare) - r1.Value = make([][2]*ring.Poly, ekg.params.Beta()) - r2.Value = make([][2]*ring.Poly, ekg.params.Beta()) + r1.Value = make([][2]rlwe.PolyQP, ekg.params.Beta()) + r2.Value = make([][2]rlwe.PolyQP, ekg.params.Beta()) for i := 0; i < ekg.params.Beta(); i++ { - r1.Value[i][0] = ekg.ringQP.NewPoly() - r1.Value[i][1] = ekg.ringQP.NewPoly() - r2.Value[i][0] = ekg.ringQP.NewPoly() - r2.Value[i][1] = ekg.ringQP.NewPoly() + r1.Value[i][0] = ekg.params.RingQP().NewPoly() + r1.Value[i][1] = ekg.params.RingQP().NewPoly() + r2.Value[i][0] = ekg.params.RingQP().NewPoly() + r2.Value[i][1] = ekg.params.RingQP().NewPoly() } return } +// SampleCRP samples a common random polynomial to be used in the RKG protocol from the provided +// common reference string. +func (ekg *RKGProtocol) SampleCRP(crs CRS) RKGCRP { + crp := make([]rlwe.PolyQP, ekg.params.Beta()) + us := rlwe.NewUniformSamplerQP(ekg.params, crs, ekg.params.RingQP()) + for i := range crp { + crp[i] = ekg.params.RingQP().NewPoly() + us.Read(&crp[i]) + } + return RKGCRP(crp) +} + // GenShareRoundOne is the first of three rounds of the RKGProtocol protocol. Each party generates a pseudo encryption of // its secret share of the key s_i under its ephemeral key u_i : [-u_i*a + s_i*w + e_i] and broadcasts it to the other // j-1 parties. -func (ekg *RKGProtocol) GenShareRoundOne(sk *rlwe.SecretKey, crp []*ring.Poly, ephSkOut *rlwe.SecretKey, shareOut *RKGShare) { +func (ekg *RKGProtocol) GenShareRoundOne(sk *rlwe.SecretKey, crp RKGCRP, ephSkOut *rlwe.SecretKey, shareOut *RKGShare) { // Given a base decomposition w_i (here the CRT decomposition) // computes [-u*a_i + P*s_i + e_i] // where a_i = crp_i - ekg.ringQP.MulScalarBigint(sk.Value, ekg.pBigInt, ekg.tmpPoly1) - ekg.ringQP.InvMForm(ekg.tmpPoly1, ekg.tmpPoly1) - ekg.ternarySampler.Read(ephSkOut.Value) - ekg.ringQP.NTT(ephSkOut.Value, ephSkOut.Value) + ringQ := ekg.params.RingQ() + ringQP := ekg.params.RingQP() + levelQ := ekg.params.QCount() - 1 + levelP := ekg.params.PCount() - 1 + + ringQ.MulScalarBigint(sk.Value.Q, ekg.pBigInt, ekg.tmpPoly1.Q) + ringQ.InvMForm(ekg.tmpPoly1.Q, ekg.tmpPoly1.Q) + + ekg.ternarySamplerQ.Read(ephSkOut.Value.Q) + ringQP.ExtendBasisSmallNormAndCenter(ephSkOut.Value.Q, levelP, nil, ephSkOut.Value.P) + ringQP.NTTLvl(levelQ, levelP, ephSkOut.Value, ephSkOut.Value) + ringQP.MFormLvl(levelQ, levelP, ephSkOut.Value, ephSkOut.Value) for i := 0; i < ekg.params.Beta(); i++ { // h = e - ekg.gaussianSampler.Read(shareOut.Value[i][0]) - ekg.ringQP.NTT(shareOut.Value[i][0], shareOut.Value[i][0]) + ekg.gaussianSamplerQ.Read(shareOut.Value[i][0].Q) + ringQP.ExtendBasisSmallNormAndCenter(shareOut.Value[i][0].Q, levelP, nil, shareOut.Value[i][0].P) + ringQP.NTTLvl(levelQ, levelP, shareOut.Value[i][0], shareOut.Value[i][0]) // h = sk*CrtBaseDecompQi + e for j := 0; j < ekg.params.PCount(); j++ { index := i*ekg.params.PCount() + j - qi := ekg.ringQP.Modulus[index] - skP := ekg.tmpPoly1.Coeffs[index] - h := shareOut.Value[i][0].Coeffs[index] - - for w := 0; w < ekg.ringQP.N; w++ { - h[w] = ring.CRed(h[w]+skP[w], qi) - } // Handles the case where nb pj does not divides nb qi if index >= ekg.params.QCount() { break } + + qi := ringQ.Modulus[index] + skP := ekg.tmpPoly1.Q.Coeffs[index] + h := shareOut.Value[i][0].Q.Coeffs[index] + + for w := 0; w < ringQ.N; w++ { + h[w] = ring.CRed(h[w]+skP[w], qi) + } } // h = sk*CrtBaseDecompQi + -u*a + e - ekg.ringQP.MulCoeffsMontgomeryAndSub(ephSkOut.Value, crp[i], shareOut.Value[i][0]) + ringQP.MulCoeffsMontgomeryAndSubLvl(levelQ, levelP, ephSkOut.Value, crp[i], shareOut.Value[i][0]) // Second Element // e_2i - ekg.gaussianSampler.Read(shareOut.Value[i][1]) - ekg.ringQP.NTT(shareOut.Value[i][1], shareOut.Value[i][1]) + ekg.gaussianSamplerQ.Read(shareOut.Value[i][1].Q) + ringQP.ExtendBasisSmallNormAndCenter(shareOut.Value[i][1].Q, levelP, nil, shareOut.Value[i][1].P) + ringQP.NTTLvl(levelQ, levelP, shareOut.Value[i][1], shareOut.Value[i][1]) // s*a + e_2i - ekg.ringQP.MulCoeffsMontgomeryAndAdd(sk.Value, crp[i], shareOut.Value[i][1]) + ringQP.MulCoeffsMontgomeryAndAddLvl(levelQ, levelP, sk.Value, crp[i], shareOut.Value[i][1]) } } @@ -122,9 +148,14 @@ func (ekg *RKGProtocol) GenShareRoundOne(sk *rlwe.SecretKey, crp []*ring.Poly, e // = [s_i * (-u*a + s*w + e) + e_i1, s_i*a + e_i2] // // and broadcasts both values to the other j-1 parties. -func (ekg *RKGProtocol) GenShareRoundTwo(ephSk, sk *rlwe.SecretKey, round1 *RKGShare, crp []*ring.Poly, shareOut *RKGShare) { +func (ekg *RKGProtocol) GenShareRoundTwo(ephSk, sk *rlwe.SecretKey, round1 *RKGShare, shareOut *RKGShare) { + + ringQP := ekg.params.RingQP() + levelQ := ekg.params.QCount() - 1 + levelP := ekg.params.PCount() - 1 + // (u_i - s_i) - ekg.ringQP.Sub(ephSk.Value, sk.Value, ekg.tmpPoly1) + ringQP.SubLvl(levelQ, levelP, ephSk.Value, sk.Value, ekg.tmpPoly1) // Each sample is of the form [-u*a_i + s*w_i + e_i] // So for each element of the base decomposition w_i : @@ -133,47 +164,48 @@ func (ekg *RKGProtocol) GenShareRoundTwo(ephSk, sk *rlwe.SecretKey, round1 *RKGS // Computes [(sum samples)*sk + e_1i, sk*a + e_2i] // (AggregateShareRoundTwo samples) * sk - ekg.ringQP.MulCoeffsMontgomeryConstant(round1.Value[i][0], sk.Value, shareOut.Value[i][0]) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, round1.Value[i][0], sk.Value, shareOut.Value[i][0]) // (AggregateShareRoundTwo samples) * sk + e_1i - ekg.gaussianSampler.Read(ekg.tmpPoly2) - ekg.ringQP.NTT(ekg.tmpPoly2, ekg.tmpPoly2) - ekg.ringQP.Add(shareOut.Value[i][0], ekg.tmpPoly2, shareOut.Value[i][0]) + ekg.gaussianSamplerQ.Read(ekg.tmpPoly2.Q) + ringQP.ExtendBasisSmallNormAndCenter(ekg.tmpPoly2.Q, levelP, nil, ekg.tmpPoly2.P) + ringQP.NTTLvl(levelQ, levelP, ekg.tmpPoly2, ekg.tmpPoly2) + ringQP.AddLvl(levelQ, levelP, shareOut.Value[i][0], ekg.tmpPoly2, shareOut.Value[i][0]) // second part // (u - s) * (sum [x][s*a_i + e_2i]) + e3i - ekg.gaussianSampler.Read(shareOut.Value[i][1]) - ekg.ringQP.NTT(shareOut.Value[i][1], shareOut.Value[i][1]) - ekg.ringQP.MulCoeffsMontgomeryAndAdd(ekg.tmpPoly1, round1.Value[i][1], shareOut.Value[i][1]) + ekg.gaussianSamplerQ.Read(shareOut.Value[i][1].Q) + ringQP.ExtendBasisSmallNormAndCenter(shareOut.Value[i][1].Q, levelP, nil, shareOut.Value[i][1].P) + ringQP.NTTLvl(levelQ, levelP, shareOut.Value[i][1], shareOut.Value[i][1]) + ringQP.MulCoeffsMontgomeryAndAddLvl(levelQ, levelP, ekg.tmpPoly1, round1.Value[i][1], shareOut.Value[i][1]) } } // AggregateShares combines two RKG shares into a single one func (ekg *RKGProtocol) AggregateShares(share1, share2, shareOut *RKGShare) { - + ringQP, levelQ, levelP := ekg.params.RingQP(), ekg.params.QCount()-1, ekg.params.PCount()-1 for i := 0; i < ekg.params.Beta(); i++ { - ekg.ringQP.Add(share1.Value[i][0], share2.Value[i][0], shareOut.Value[i][0]) - ekg.ringQP.Add(share1.Value[i][1], share2.Value[i][1], shareOut.Value[i][1]) + ringQP.AddLvl(levelQ, levelP, share1.Value[i][0], share2.Value[i][0], shareOut.Value[i][0]) + ringQP.AddLvl(levelQ, levelP, share1.Value[i][1], share2.Value[i][1], shareOut.Value[i][1]) } } // GenRelinearizationKey computes the generated RLK from the public shares and write the result in evalKeyOut func (ekg *RKGProtocol) GenRelinearizationKey(round1 *RKGShare, round2 *RKGShare, evalKeyOut *rlwe.RelinearizationKey) { + ringQP, levelQ, levelP := ekg.params.RingQP(), ekg.params.QCount()-1, ekg.params.PCount()-1 for i := 0; i < ekg.params.Beta(); i++ { - ekg.ringQP.Add(round2.Value[i][0], round2.Value[i][1], evalKeyOut.Keys[0].Value[i][0]) + ringQP.AddLvl(levelQ, levelP, round2.Value[i][0], round2.Value[i][1], evalKeyOut.Keys[0].Value[i][0]) evalKeyOut.Keys[0].Value[i][1].Copy(round1.Value[i][1]) - - ekg.ringQP.MForm(evalKeyOut.Keys[0].Value[i][0], evalKeyOut.Keys[0].Value[i][0]) - ekg.ringQP.MForm(evalKeyOut.Keys[0].Value[i][1], evalKeyOut.Keys[0].Value[i][1]) + ringQP.MFormLvl(levelQ, levelP, evalKeyOut.Keys[0].Value[i][0], evalKeyOut.Keys[0].Value[i][0]) + ringQP.MFormLvl(levelQ, levelP, evalKeyOut.Keys[0].Value[i][1], evalKeyOut.Keys[0].Value[i][1]) } } // MarshalBinary encodes the target element on a slice of bytes. func (share *RKGShare) MarshalBinary() ([]byte, error) { //we have modulus * bitLog * Len of 1 ring rings - rLength := (share.Value[0])[0].GetDataLen(true) - data := make([]byte, 1+2*rLength*len(share.Value)) + data := make([]byte, 1+2*share.Value[0][0].GetDataLen(true)*len(share.Value)) if len(share.Value) > 0xFF { return []byte{}, errors.New("RKGShare : uint8 overflow on length") } @@ -182,17 +214,19 @@ func (share *RKGShare) MarshalBinary() ([]byte, error) { //write all of our rings in the data. //write all the polys ptr := 1 + var inc int + var err error for _, elem := range share.Value { - _, err := elem[0].WriteTo(data[ptr : ptr+rLength]) - if err != nil { + + if inc, err = elem[0].WriteTo(data[ptr:]); err != nil { return []byte{}, err } - ptr += rLength - _, err = elem[1].WriteTo(data[ptr : ptr+rLength]) - if err != nil { + ptr += inc + + if inc, err = elem[1].WriteTo(data[ptr:]); err != nil { return []byte{}, err } - ptr += rLength + ptr += inc } return data, nil @@ -200,31 +234,20 @@ func (share *RKGShare) MarshalBinary() ([]byte, error) { } // UnmarshalBinary decodes a slice of bytes on the target element. -func (share *RKGShare) UnmarshalBinary(data []byte) error { - lenShare := data[0] - rLength := (len(data) - 1) / (2 * int(lenShare)) - - if share.Value == nil { - share.Value = make([][2]*ring.Poly, lenShare) - } - ptr := (1) - for i := (0); i < int(lenShare); i++ { - if share.Value[i][0] == nil || share.Value[i][1] == nil { - share.Value[i][0] = new(ring.Poly) - share.Value[i][1] = new(ring.Poly) - } - - err := share.Value[i][0].UnmarshalBinary(data[ptr : ptr+rLength]) - if err != nil { +func (share *RKGShare) UnmarshalBinary(data []byte) (err error) { + share.Value = make([][2]rlwe.PolyQP, data[0]) + ptr := 1 + var inc int + for i := range share.Value { + if inc, err = share.Value[i][0].DecodePolyNew(data[ptr:]); err != nil { return err } - ptr += rLength - err = share.Value[i][1].UnmarshalBinary(data[ptr : ptr+rLength]) - if err != nil { + ptr += inc + + if inc, err = share.Value[i][1].DecodePolyNew(data[ptr:]); err != nil { return err } - ptr += rLength - + ptr += inc } return nil diff --git a/drlwe/keygen_rot.go b/drlwe/keygen_rot.go index c7ed3767..2ebf2e10 100644 --- a/drlwe/keygen_rot.go +++ b/drlwe/keygen_rot.go @@ -1,9 +1,7 @@ package drlwe import ( - "encoding/binary" "errors" - "math/big" "github.com/ldsec/lattigo/v2/ring" "github.com/ldsec/lattigo/v2/rlwe" @@ -13,69 +11,90 @@ import ( // RotationKeyGenerator is an interface for the local operation in the generation of rotation keys type RotationKeyGenerator interface { AllocateShares() (rtgShare *RTGShare) - GenShare(sk *rlwe.SecretKey, galEl uint64, crp []*ring.Poly, shareOut *RTGShare) + GenShare(sk *rlwe.SecretKey, galEl uint64, crp RTGCRP, shareOut *RTGShare) Aggregate(share1, share2, shareOut *RTGShare) - GenRotationKey(share *RTGShare, crp []*ring.Poly, rotKey *rlwe.SwitchingKey) + GenRotationKey(share *RTGShare, crp RTGCRP, rotKey *rlwe.SwitchingKey) } // RTGShare is represent a Party's share in the RTG protocol type RTGShare struct { - Value []*ring.Poly + Value []rlwe.PolyQP } +// RTGCRP is a type for common reference polynomials in the RTG protocol. +type RTGCRP []rlwe.PolyQP + // RTGProtocol is the structure storing the parameters for the collective rotation-keys generation. -type RTGProtocol struct { // TODO rename GaloisKeyGen ? - params rlwe.Parameters - ringQP *ring.Ring - pBigInt *big.Int - tmpPoly [2]*ring.Poly - gaussianSampler *ring.GaussianSampler +type RTGProtocol struct { + params rlwe.Parameters + tmpPoly0 rlwe.PolyQP + tmpPoly1 rlwe.PolyQP + gaussianSamplerQ *ring.GaussianSampler } // NewRTGProtocol creates a RTGProtocol instance func NewRTGProtocol(params rlwe.Parameters) *RTGProtocol { rtg := new(RTGProtocol) rtg.params = params - rtg.ringQP = params.RingQP() - rtg.pBigInt = params.PBigInt() - var err error + prng, err := utils.NewPRNG() if err != nil { panic(err) } - rtg.gaussianSampler = ring.NewGaussianSampler(prng, rtg.ringQP, params.Sigma(), int(6*params.Sigma())) - rtg.tmpPoly = [2]*ring.Poly{rtg.ringQP.NewPoly(), rtg.ringQP.NewPoly()} + rtg.gaussianSamplerQ = ring.NewGaussianSampler(prng, params.RingQ(), params.Sigma(), int(6*params.Sigma())) + rtg.tmpPoly0 = params.RingQP().NewPoly() + rtg.tmpPoly1 = params.RingQP().NewPoly() return rtg } // AllocateShares allocates a party's share in the RTG protocol func (rtg *RTGProtocol) AllocateShares() (rtgShare *RTGShare) { rtgShare = new(RTGShare) - rtgShare.Value = make([]*ring.Poly, rtg.params.Beta()) + rtgShare.Value = make([]rlwe.PolyQP, rtg.params.Beta()) for i := range rtgShare.Value { - rtgShare.Value[i] = rtg.ringQP.NewPoly() + rtgShare.Value[i] = rtg.params.RingQP().NewPoly() } return } +// SampleCRP samples a common random polynomial to be used in the RTG protocol from the provided +// common reference string. +func (rtg *RTGProtocol) SampleCRP(crs CRS) RTGCRP { + crp := make([]rlwe.PolyQP, rtg.params.Beta()) + us := rlwe.NewUniformSamplerQP(rtg.params, crs, rtg.params.RingQP()) + for i := range crp { + crp[i] = rtg.params.RingQP().NewPoly() + us.Read(&crp[i]) + } + return RTGCRP(crp) +} + // GenShare generates a party's share in the RTG protocol -func (rtg *RTGProtocol) GenShare(sk *rlwe.SecretKey, galEl uint64, crp []*ring.Poly, shareOut *RTGShare) { +func (rtg *RTGProtocol) GenShare(sk *rlwe.SecretKey, galEl uint64, crp RTGCRP, shareOut *RTGShare) { - twoN := rtg.ringQP.N << 2 - galElInv := ring.ModExp(galEl, int(twoN-1), uint64(twoN)) + ringQ := rtg.params.RingQ() + ringP := rtg.params.RingP() + ringQP := rtg.params.RingQP() + levelQ := rtg.params.QCount() - 1 + levelP := rtg.params.PCount() - 1 - ring.PermuteNTT(sk.Value, galElInv, rtg.tmpPoly[1]) + twoN := uint64(ringQ.N << 1) + galElInv := ring.ModExp(galEl, twoN-1, twoN) - rtg.ringQP.MulScalarBigint(sk.Value, rtg.pBigInt, rtg.tmpPoly[0]) + ring.PermuteNTT(sk.Value.Q, galElInv, rtg.tmpPoly1.Q) + ring.PermuteNTT(sk.Value.P, galElInv, rtg.tmpPoly1.P) + + ringQ.MulScalarBigint(sk.Value.Q, ringP.ModulusBigint, rtg.tmpPoly0.Q) var index int for i := 0; i < rtg.params.Beta(); i++ { // e - rtg.gaussianSampler.Read(shareOut.Value[i]) - rtg.ringQP.NTTLazy(shareOut.Value[i], shareOut.Value[i]) - rtg.ringQP.MForm(shareOut.Value[i], shareOut.Value[i]) + rtg.gaussianSamplerQ.Read(shareOut.Value[i].Q) + ringQP.ExtendBasisSmallNormAndCenter(shareOut.Value[i].Q, levelP, nil, shareOut.Value[i].P) + ringQP.NTTLazyLvl(levelQ, levelP, shareOut.Value[i], shareOut.Value[i]) + ringQP.MFormLvl(levelQ, levelP, shareOut.Value[i], shareOut.Value[i]) // a is the CRP @@ -85,79 +104,70 @@ func (rtg *RTGProtocol) GenShare(sk *rlwe.SecretKey, galEl uint64, crp []*ring.P index = i*rtg.params.PCount() + j - qi := rtg.ringQP.Modulus[index] - tmp0 := rtg.tmpPoly[0].Coeffs[index] - tmp1 := shareOut.Value[i].Coeffs[index] - - for w := 0; w < rtg.ringQP.N; w++ { - tmp1[w] = ring.CRed(tmp1[w]+tmp0[w], qi) - } - // Handles the case where nb pj does not divides nb qi if index >= rtg.params.QCount() { break } + + qi := ringQ.Modulus[index] + tmp0 := rtg.tmpPoly0.Q.Coeffs[index] + tmp1 := shareOut.Value[i].Q.Coeffs[index] + + for w := 0; w < ringQ.N; w++ { + tmp1[w] = ring.CRed(tmp1[w]+tmp0[w], qi) + } } // sk_in * (qiBarre*qiStar) * 2^w - a*sk + e - rtg.ringQP.MulCoeffsMontgomeryAndSub(crp[i], rtg.tmpPoly[1], shareOut.Value[i]) + ringQP.MulCoeffsMontgomeryAndSubLvl(levelQ, levelP, crp[i], rtg.tmpPoly1, shareOut.Value[i]) } - - rtg.tmpPoly[0].Zero() - rtg.tmpPoly[1].Zero() - - return } // Aggregate aggregates two shares in the Rotation Key Generation protocol func (rtg *RTGProtocol) Aggregate(share1, share2, shareOut *RTGShare) { + ringQP, levelQ, levelP := rtg.params.RingQP(), rtg.params.QCount()-1, rtg.params.PCount()-1 for i := 0; i < rtg.params.Beta(); i++ { - rtg.ringQP.Add(share1.Value[i], share2.Value[i], shareOut.Value[i]) + ringQP.AddLvl(levelQ, levelP, share1.Value[i], share2.Value[i], shareOut.Value[i]) } } // GenRotationKey finalizes the RTG protocol and populates the input RotationKey with the computed collective SwitchingKey. -func (rtg *RTGProtocol) GenRotationKey(share *RTGShare, crp []*ring.Poly, rotKey *rlwe.SwitchingKey) { +func (rtg *RTGProtocol) GenRotationKey(share *RTGShare, crp RTGCRP, rotKey *rlwe.SwitchingKey) { for i := 0; i < rtg.params.Beta(); i++ { - ring.CopyValues(share.Value[i], rotKey.Value[i][0]) - ring.CopyValues(crp[i], rotKey.Value[i][1]) + rotKey.Value[i][0].CopyValues(share.Value[i]) + rotKey.Value[i][1].CopyValues(crp[i]) } } // MarshalBinary encode the target element on a slice of byte. -func (share *RTGShare) MarshalBinary() ([]byte, error) { - lenRing := share.Value[0].GetDataLen(true) - data := make([]byte, 8+lenRing*len(share.Value)) - binary.BigEndian.PutUint64(data[:8], uint64(lenRing)) - ptr := 8 +func (share *RTGShare) MarshalBinary() (data []byte, err error) { + data = make([]byte, 1+share.Value[0].GetDataLen(true)*len(share.Value)) + if len(share.Value) > 0xFF { + return []byte{}, errors.New("RKGShare : uint8 overflow on length") + } + data[0] = uint8(len(share.Value)) + ptr := 1 + var inc int for _, val := range share.Value { - cnt, err := val.WriteTo(data[ptr : ptr+lenRing]) - if err != nil { + if inc, err = val.WriteTo(data[ptr:]); err != nil { return []byte{}, err } - ptr += cnt + ptr += inc } return data, nil } // UnmarshalBinary decodes a slice of bytes on the target element. -func (share *RTGShare) UnmarshalBinary(data []byte) error { - if len(data) <= 8 { - return errors.New("Unsufficient data length") - } - - lenRing := binary.BigEndian.Uint64(data[:8]) - valLength := uint64(len(data)-8) / lenRing - share.Value = make([]*ring.Poly, valLength) - ptr := uint64(8) +func (share *RTGShare) UnmarshalBinary(data []byte) (err error) { + share.Value = make([]rlwe.PolyQP, data[0]) + ptr := 1 + var inc int for i := range share.Value { - share.Value[i] = new(ring.Poly) - err := share.Value[i].UnmarshalBinary(data[ptr : ptr+lenRing]) - if err != nil { + if inc, err = share.Value[i].DecodePolyNew(data[ptr:]); err != nil { return err } - ptr += lenRing + ptr += inc } return nil diff --git a/drlwe/keyswitch_pk.go b/drlwe/keyswitch_pk.go index 05172e4f..76a2ea37 100644 --- a/drlwe/keyswitch_pk.go +++ b/drlwe/keyswitch_pk.go @@ -8,7 +8,7 @@ import ( // PublicKeySwitchingProtocol is an interface describing the local steps of a generic RLWE PCKS protocol. type PublicKeySwitchingProtocol interface { - AllocateShare(level int) *PCKSShare + AllocateShare(levelQ int) *PCKSShare GenShare(skInput *rlwe.SecretKey, pkOutput *rlwe.PublicKey, ct *rlwe.Ciphertext, shareOut *PCKSShare) AggregateShares(share1, share2, shareOut *PCKSShare) KeySwitch(combined *PCKSShare, ct *rlwe.Ciphertext, ctOut *rlwe.Ciphertext) @@ -21,17 +21,11 @@ type PCKSShare struct { // PCKSProtocol is the structure storing the parameters for the collective public key-switching. type PCKSProtocol struct { - ringQ *ring.Ring - ringP *ring.Ring - ringQP *ring.Ring + params rlwe.Parameters sigmaSmudging float64 - tmpQ *ring.Poly - tmpP *ring.Poly - share0tmpQ *ring.Poly - share1tmpQ *ring.Poly - share0tmpP *ring.Poly - share1tmpP *ring.Poly + tmpQP rlwe.PolyQP + tmpP [2]*ring.Poly baseconverter *ring.FastBasisExtender gaussianSampler *ring.GaussianSampler @@ -40,36 +34,28 @@ type PCKSProtocol struct { // NewPCKSProtocol creates a new PCKSProtocol object and will be used to re-encrypt a ciphertext ctx encrypted under a secret-shared key among j parties under a new // collective public-key. -func NewPCKSProtocol(params rlwe.Parameters, sigmaSmudging float64) *PCKSProtocol { - pcks := new(PCKSProtocol) - pcks.ringQ = params.RingQ() - pcks.ringP = params.RingP() +func NewPCKSProtocol(params rlwe.Parameters, sigmaSmudging float64) (pcks *PCKSProtocol) { + pcks = new(PCKSProtocol) + pcks.params = params pcks.sigmaSmudging = sigmaSmudging - pcks.tmpQ = pcks.ringQ.NewPoly() - pcks.tmpP = pcks.ringP.NewPoly() - pcks.share0tmpQ = pcks.ringQ.NewPoly() - pcks.share1tmpQ = pcks.ringQ.NewPoly() - pcks.share0tmpP = pcks.ringP.NewPoly() - pcks.share1tmpP = pcks.ringP.NewPoly() + pcks.tmpQP = params.RingQP().NewPoly() + pcks.tmpP = [2]*ring.Poly{params.RingP().NewPoly(), params.RingP().NewPoly()} - pcks.baseconverter = ring.NewFastBasisExtender(pcks.ringQ, pcks.ringP) + pcks.baseconverter = ring.NewFastBasisExtender(params.RingQ(), params.RingP()) prng, err := utils.NewPRNG() if err != nil { panic(err) } - pcks.gaussianSampler = ring.NewGaussianSampler(prng, pcks.ringQ, sigmaSmudging, int(6*sigmaSmudging)) - pcks.ternarySamplerMontgomeryQ = ring.NewTernarySampler(prng, pcks.ringQ, 0.5, false) + pcks.gaussianSampler = ring.NewGaussianSampler(prng, params.RingQ(), sigmaSmudging, int(6*sigmaSmudging)) + pcks.ternarySamplerMontgomeryQ = ring.NewTernarySampler(prng, params.RingQ(), 0.5, false) return pcks } // AllocateShare allocates the shares of the PCKS protocol -func (pcks *PCKSProtocol) AllocateShare(level int) (s *PCKSShare) { - s = new(PCKSShare) - s.Value[0] = pcks.ringQ.NewPolyLvl(level) - s.Value[1] = pcks.ringQ.NewPolyLvl(level) - return +func (pcks *PCKSProtocol) AllocateShare(levelQ int) (s *PCKSShare) { + return &PCKSShare{[2]*ring.Poly{pcks.params.RingQ().NewPolyLvl(levelQ), pcks.params.RingQ().NewPolyLvl(levelQ)}} } // GenShare is the first part of the unique round of the PCKSProtocol protocol. Each party computes the following : @@ -81,73 +67,59 @@ func (pcks *PCKSProtocol) GenShare(sk *rlwe.SecretKey, pk *rlwe.PublicKey, ct *r el := ct.RLWEElement() - ringQ := pcks.ringQ - ringP := pcks.ringP + ringQ := pcks.params.RingQ() + ringP := pcks.params.RingP() + ringQP := pcks.params.RingQP() - level := el.Level() - - pk0Q := new(ring.Poly) - pk0P := new(ring.Poly) - pk1Q := new(ring.Poly) - pk1P := new(ring.Poly) - - // Splits pk[0] and pk[1] with their respective modulus Qlvl and P - pk0Q.Coeffs = pk.Value[0].Coeffs[:level+1] - pk1Q.Coeffs = pk.Value[1].Coeffs[:level+1] - pk0P.Coeffs = pk.Value[0].Coeffs[len(ringQ.Modulus):] - pk1P.Coeffs = pk.Value[1].Coeffs[len(ringQ.Modulus):] + levelQ := el.Level() + levelP := len(ringP.Modulus) - 1 // samples MForm(u_i) in Q and P separately - pcks.ternarySamplerMontgomeryQ.ReadLvl(level, pcks.tmpQ) - extendBasisSmallNormAndCenter(ringQ.Modulus[0], ringP.Modulus, pcks.tmpQ.Coeffs[0], pcks.tmpP.Coeffs) - ringQ.MFormLvl(level, pcks.tmpQ, pcks.tmpQ) - ringP.MForm(pcks.tmpP, pcks.tmpP) - ringQ.NTTLvl(level, pcks.tmpQ, pcks.tmpQ) - ringP.NTT(pcks.tmpP, pcks.tmpP) + pcks.ternarySamplerMontgomeryQ.ReadLvl(levelQ, pcks.tmpQP.Q) + ringQP.ExtendBasisSmallNormAndCenter(pcks.tmpQP.Q, levelP, nil, pcks.tmpQP.P) + ringQP.MFormLvl(levelQ, levelP, pcks.tmpQP, pcks.tmpQP) + ringQP.NTTLvl(levelQ, levelP, pcks.tmpQP, pcks.tmpQP) - // h_0 = NTT(u_i * pk_0) - ringQ.MulCoeffsMontgomeryLvl(level, pcks.tmpQ, pk0Q, pcks.share0tmpQ) - ringP.MulCoeffsMontgomery(pcks.tmpP, pk0P, pcks.share0tmpP) - ringQ.InvNTTLvl(level, pcks.share0tmpQ, pcks.share0tmpQ) - ringP.InvNTT(pcks.share0tmpP, pcks.share0tmpP) + shareOutQP0 := rlwe.PolyQP{Q: shareOut.Value[0], P: pcks.tmpP[0]} + shareOutQP1 := rlwe.PolyQP{Q: shareOut.Value[1], P: pcks.tmpP[1]} - // h_1 = NTT(u_i * pk_1) - ringQ.MulCoeffsMontgomeryLvl(level, pcks.tmpQ, pk1Q, pcks.share1tmpQ) - ringP.MulCoeffsMontgomery(pcks.tmpP, pk1P, pcks.share1tmpP) - ringQ.InvNTTLvl(level, pcks.share1tmpQ, pcks.share1tmpQ) - ringP.InvNTT(pcks.share1tmpP, pcks.share1tmpP) + // h_0 = u_i * pk_0 + // h_1 = u_i * pk_1 + ringQP.MulCoeffsMontgomeryLvl(levelQ, levelP, pcks.tmpQP, pk.Value[0], shareOutQP0) + ringQP.MulCoeffsMontgomeryLvl(levelQ, levelP, pcks.tmpQP, pk.Value[1], shareOutQP1) - // h_0 = u_i * pk_0 + e0 - pcks.gaussianSampler.ReadLvl(level, pcks.tmpQ) - extendBasisSmallNormAndCenter(ringQ.Modulus[0], ringP.Modulus, pcks.tmpQ.Coeffs[0], pcks.tmpP.Coeffs) - ringQ.AddLvl(level, pcks.share0tmpQ, pcks.tmpQ, pcks.share0tmpQ) - ringP.Add(pcks.share0tmpP, pcks.tmpP, pcks.share0tmpP) + ringQP.InvNTTLvl(levelQ, levelP, shareOutQP0, shareOutQP0) + ringQP.InvNTTLvl(levelQ, levelP, shareOutQP1, shareOutQP1) + + // h_0 = u_i * pk_0 + pcks.gaussianSampler.ReadLvl(levelQ, pcks.tmpQP.Q) + ringQP.ExtendBasisSmallNormAndCenter(pcks.tmpQP.Q, levelP, nil, pcks.tmpQP.P) + ringQP.AddLvl(levelQ, levelP, shareOutQP0, pcks.tmpQP, shareOutQP0) // h_1 = u_i * pk_1 + e1 - pcks.gaussianSampler.ReadLvl(level, pcks.tmpQ) - extendBasisSmallNormAndCenter(ringQ.Modulus[0], ringP.Modulus, pcks.tmpQ.Coeffs[0], pcks.tmpP.Coeffs) - ringQ.AddLvl(level, pcks.share1tmpQ, pcks.tmpQ, pcks.share1tmpQ) - ringP.Add(pcks.share1tmpP, pcks.tmpP, pcks.share1tmpP) + pcks.gaussianSampler.ReadLvl(levelQ, pcks.tmpQP.Q) + ringQP.ExtendBasisSmallNormAndCenter(pcks.tmpQP.Q, levelP, nil, pcks.tmpQP.P) + ringQP.AddLvl(levelQ, levelP, shareOutQP1, pcks.tmpQP, shareOutQP1) // h_0 = (u_i * pk_0 + e0)/P - pcks.baseconverter.ModDownSplitPQ(level, pcks.share0tmpQ, pcks.share0tmpP, shareOut.Value[0]) + pcks.baseconverter.ModDownQPtoQ(levelQ, levelP, shareOutQP0.Q, shareOutQP0.P, shareOutQP0.Q) // h_1 = (u_i * pk_1 + e1)/P - pcks.baseconverter.ModDownSplitPQ(level, pcks.share1tmpQ, pcks.share1tmpP, shareOut.Value[1]) + pcks.baseconverter.ModDownQPtoQ(levelQ, levelP, shareOutQP1.Q, shareOutQP1.P, shareOutQP1.Q) // h_0 = s_i*c_1 + (u_i * pk_0 + e0)/P if el.Value[0].IsNTT { - ringQ.NTTLvl(level, shareOut.Value[0], shareOut.Value[0]) - ringQ.NTTLvl(level, shareOut.Value[1], shareOut.Value[1]) - ringQ.MulCoeffsMontgomeryAndAddLvl(level, el.Value[1], sk.Value, shareOut.Value[0]) + ringQ.NTTLvl(levelQ, shareOut.Value[0], shareOut.Value[0]) + ringQ.NTTLvl(levelQ, shareOut.Value[1], shareOut.Value[1]) + ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, el.Value[1], sk.Value.Q, shareOut.Value[0]) } else { // tmp = s_i*c_1 - ringQ.NTTLazyLvl(level, el.Value[1], pcks.tmpQ) - ringQ.MulCoeffsMontgomeryConstantLvl(level, pcks.tmpQ, sk.Value, pcks.tmpQ) - ringQ.InvNTTLvl(level, pcks.tmpQ, pcks.tmpQ) + ringQ.NTTLazyLvl(levelQ, el.Value[1], pcks.tmpQP.Q) + ringQ.MulCoeffsMontgomeryConstantLvl(levelQ, pcks.tmpQP.Q, sk.Value.Q, pcks.tmpQP.Q) + ringQ.InvNTTLvl(levelQ, pcks.tmpQP.Q, pcks.tmpQP.Q) // h_0 = s_i*c_1 + (u_i * pk_0 + e0)/P - ringQ.AddLvl(level, shareOut.Value[0], pcks.tmpQ, shareOut.Value[0]) + ringQ.AddLvl(levelQ, shareOut.Value[0], pcks.tmpQP.Q, shareOut.Value[0]) } } @@ -156,60 +128,49 @@ func (pcks *PCKSProtocol) GenShare(sk *rlwe.SecretKey, pk *rlwe.PublicKey, ct *r // // [ctx[0] + sum(s_i * ctx[0] + u_i * pk[0] + e_0i), sum(u_i * pk[1] + e_1i)] func (pcks *PCKSProtocol) AggregateShares(share1, share2, shareOut *PCKSShare) { - level1, level2 := len(share1.Value[0].Coeffs)-1, len(share2.Value[0].Coeffs)-1 - if level1 != level2 { - panic("cannot aggreate two shares at different levels.") + levelQ1, levelQ2 := len(share1.Value[0].Coeffs)-1, len(share2.Value[1].Coeffs)-1 + if levelQ1 != levelQ2 { + panic("cannot aggreate two shares at different levelQs.") } - pcks.ringQ.AddLvl(level1, share1.Value[0], share2.Value[0], shareOut.Value[0]) - pcks.ringQ.AddLvl(level1, share1.Value[1], share2.Value[1], shareOut.Value[1]) + pcks.params.RingQ().AddLvl(levelQ1, share1.Value[0], share2.Value[0], shareOut.Value[0]) + pcks.params.RingQ().AddLvl(levelQ1, share1.Value[1], share2.Value[1], shareOut.Value[1]) + } // KeySwitch performs the actual keyswitching operation on a ciphertext ct and put the result in ctOut func (pcks *PCKSProtocol) KeySwitch(combined *PCKSShare, ct, ctOut *rlwe.Ciphertext) { el, elOut := ct.RLWEElement(), ctOut.RLWEElement() - pcks.ringQ.AddLvl(el.Level(), el.Value[0], combined.Value[0], elOut.Value[0]) + pcks.params.RingQ().AddLvl(el.Level(), el.Value[0], combined.Value[0], elOut.Value[0]) ring.CopyValuesLvl(el.Level(), combined.Value[1], elOut.Value[1]) } // MarshalBinary encodes a PCKS share on a slice of bytes. -func (share *PCKSShare) MarshalBinary() ([]byte, error) { - lenR1 := share.Value[0].GetDataLen(true) - lenR2 := share.Value[1].GetDataLen(true) - - data := make([]byte, lenR1+lenR2) - _, err := share.Value[0].WriteTo(data[0:lenR1]) - if err != nil { - return []byte{}, err +func (share *PCKSShare) MarshalBinary() (data []byte, err error) { + data = make([]byte, share.Value[0].GetDataLen(true)+share.Value[1].GetDataLen(true)) + var inc, pt int + if inc, err = share.Value[0].WriteTo(data[pt:]); err != nil { + return nil, err } + pt += inc - _, err = share.Value[1].WriteTo(data[lenR1 : lenR1+lenR2]) - if err != nil { - return []byte{}, err + if _, err = share.Value[1].WriteTo(data[pt:]); err != nil { + return nil, err } - - return data, nil + return } // UnmarshalBinary decodes marshaled PCKS share on the target PCKS share. -func (share *PCKSShare) UnmarshalBinary(data []byte) error { - - if share.Value[0] == nil { - share.Value[0] = new(ring.Poly) +func (share *PCKSShare) UnmarshalBinary(data []byte) (err error) { + var pt, inc int + share.Value[0] = new(ring.Poly) + if inc, err = share.Value[0].DecodePolyNew(data[pt:]); err != nil { + return } + pt += inc - if share.Value[1] == nil { - share.Value[1] = new(ring.Poly) + share.Value[1] = new(ring.Poly) + if _, err = share.Value[1].DecodePolyNew(data[pt:]); err != nil { + return } - - err := share.Value[0].UnmarshalBinary(data[0 : len(data)/2]) - if err != nil { - return err - } - - err = share.Value[1].UnmarshalBinary(data[len(data)/2:]) - if err != nil { - return err - } - - return nil + return } diff --git a/drlwe/keyswitch_sk.go b/drlwe/keyswitch_sk.go index 3082bc7e..8ff5c691 100644 --- a/drlwe/keyswitch_sk.go +++ b/drlwe/keyswitch_sk.go @@ -16,14 +16,11 @@ type KeySwitchingProtocol interface { // CKSProtocol is the structure storing the parameters and and precomputations for the collective key-switching protocol. type CKSProtocol struct { - ringQ *ring.Ring - ringP *ring.Ring + params rlwe.Parameters gaussianSampler *ring.GaussianSampler baseconverter *ring.FastBasisExtender - - tmpP *ring.Poly - tmpQ *ring.Poly - tmpDelta *ring.Poly + tmpQP rlwe.PolyQP + tmpDelta *ring.Poly } // CKSShare is a type for the CKS protocol shares. @@ -31,16 +28,17 @@ type CKSShare struct { Value *ring.Poly } +// CKSCRP is a type for common reference polynomials in the CKS protocol. +type CKSCRP ring.Poly + // MarshalBinary encodes a CKS share on a slice of bytes. -func (ckss *CKSShare) MarshalBinary() ([]byte, error) { +func (ckss *CKSShare) MarshalBinary() (data []byte, err error) { return ckss.Value.MarshalBinary() } // UnmarshalBinary decodes marshaled CKS share on the target CKS share. -func (ckss *CKSShare) UnmarshalBinary(data []byte) error { - if ckss.Value == nil { - ckss.Value = new(ring.Poly) - } +func (ckss *CKSShare) UnmarshalBinary(data []byte) (err error) { + ckss.Value = new(ring.Poly) return ckss.Value.UnmarshalBinary(data) } @@ -49,27 +47,29 @@ func (ckss *CKSShare) UnmarshalBinary(data []byte) error { // parties. func NewCKSProtocol(params rlwe.Parameters, sigmaSmudging float64) *CKSProtocol { cks := new(CKSProtocol) - cks.ringQ = params.RingQ() - cks.ringP = params.RingP() // TODO this assumes that P larger than 1 - + cks.params = params prng, err := utils.NewPRNG() if err != nil { panic(err) } - - cks.gaussianSampler = ring.NewGaussianSampler(prng, cks.ringQ, sigmaSmudging, int(6*sigmaSmudging)) - cks.baseconverter = ring.NewFastBasisExtender(cks.ringQ, cks.ringP) - - cks.tmpQ = cks.ringQ.NewPoly() - cks.tmpP = cks.ringP.NewPoly() - cks.tmpDelta = cks.ringQ.NewPoly() - + cks.gaussianSampler = ring.NewGaussianSampler(prng, params.RingQ(), sigmaSmudging, int(6*sigmaSmudging)) + cks.baseconverter = ring.NewFastBasisExtender(params.RingQ(), params.RingP()) + cks.tmpQP = params.RingQP().NewPoly() + cks.tmpDelta = params.RingQ().NewPoly() return cks } // AllocateShare allocates the shares of the CKSProtocol func (cks *CKSProtocol) AllocateShare(level int) *CKSShare { - return &CKSShare{cks.ringQ.NewPolyLvl(level)} + return &CKSShare{cks.params.RingQ().NewPolyLvl(level)} +} + +// SampleCRP samples a common random polynomial to be used in the CKS protocol from the provided +// common reference string. +func (cks *CKSProtocol) SampleCRP(level int, crs CRS) CKSCRP { + crp := cks.params.RingQ().NewPolyLvl(level) + ring.NewUniformSampler(crs, cks.params.RingQ()).ReadLvl(level, crp) + return CKSCRP(*crp) } // GenShare computes a party's share in the CKS protocol. @@ -79,57 +79,60 @@ func (cks *CKSProtocol) GenShare(skInput, skOutput *rlwe.SecretKey, ct *rlwe.Cip el := ct.RLWEElement() - ringQ := cks.ringQ - ringP := cks.ringP + ringQ := cks.params.RingQ() + ringP := cks.params.RingP() + ringQP := cks.params.RingQP() level := utils.MinInt(len(ringQ.Modulus)-1, el.Value[1].Level()) + levelP := cks.params.PCount() - 1 - ringQ.SubLvl(level, skInput.Value, skOutput.Value, cks.tmpDelta) + ringQ.SubLvl(level, skInput.Value.Q, skOutput.Value.Q, cks.tmpDelta) ct1 := el.Value[1] if !el.Value[1].IsNTT { - ringQ.NTTLazyLvl(level, el.Value[1], cks.tmpQ) - ct1 = cks.tmpQ + ringQ.NTTLazyLvl(level, el.Value[1], cks.tmpQP.Q) + ct1 = cks.tmpQP.Q } // a * (skIn - skOut) mod Q ringQ.MulCoeffsMontgomeryConstantLvl(level, ct1, cks.tmpDelta, shareOut.Value) // P * a * (skIn - skOut) mod QP (mod P = 0) - ringQ.MulScalarBigintLvl(level, shareOut.Value, cks.ringP.ModulusBigint, shareOut.Value) + ringQ.MulScalarBigintLvl(level, shareOut.Value, ringP.ModulusBigint, shareOut.Value) if !el.Value[1].IsNTT { // InvNTT(P * a * (skIn - skOut)) mod QP (mod P = 0) ringQ.InvNTTLazyLvl(level, shareOut.Value, shareOut.Value) // Samples e in Q - cks.gaussianSampler.ReadLvl(level, cks.tmpQ) + cks.gaussianSampler.ReadLvl(level, cks.tmpQP.Q) // Extend e to P (assumed to have norm < qi) - extendBasisSmallNormAndCenter(ringQ.Modulus[0], ringP.Modulus, cks.tmpQ.Coeffs[0], cks.tmpP.Coeffs) + ringQP.ExtendBasisSmallNormAndCenter(cks.tmpQP.Q, levelP, nil, cks.tmpQP.P) // InvNTT(P * a * (skIn - skOut) + e) mod QP (mod P = e) - ringQ.AddNoModLvl(level, shareOut.Value, cks.tmpQ, shareOut.Value) + ringQ.AddNoModLvl(level, shareOut.Value, cks.tmpQP.Q, shareOut.Value) // InvNTT(P * a * (skIn - skOut) + e) * (1/P) mod QP (mod P = e) - cks.baseconverter.ModDownSplitPQ(level, shareOut.Value, cks.tmpP, shareOut.Value) + cks.baseconverter.ModDownQPtoQ(level, levelP, shareOut.Value, cks.tmpQP.P, shareOut.Value) } else { // Sample e in Q - cks.gaussianSampler.ReadLvl(level, cks.tmpQ) + cks.gaussianSampler.ReadLvl(level, cks.tmpQP.Q) // Extend e to P (assumed to have norm < qi) - extendBasisSmallNormAndCenter(ringQ.Modulus[0], ringP.Modulus, cks.tmpQ.Coeffs[0], cks.tmpP.Coeffs) + ringQP.ExtendBasisSmallNormAndCenter(cks.tmpQP.Q, levelP, nil, cks.tmpQP.P) // Takes the error to the NTT domain - ringQ.NTTLvl(level, cks.tmpQ, cks.tmpQ) - ringP.NTT(cks.tmpP, cks.tmpP) + ringQ.InvNTTLvl(level, shareOut.Value, shareOut.Value) // P * a * (skIn - skOut) + e mod Q (mod P = 0, so P = e) - ringQ.AddLvl(level, shareOut.Value, cks.tmpQ, shareOut.Value) + ringQ.AddLvl(level, shareOut.Value, cks.tmpQP.Q, shareOut.Value) // (P * a * (skIn - skOut) + e) * (1/P) mod QP (mod P = e) - cks.baseconverter.ModDownSplitNTTPQ(level, shareOut.Value, cks.tmpP, shareOut.Value) + cks.baseconverter.ModDownQPtoQ(level, levelP, shareOut.Value, cks.tmpQP.P, shareOut.Value) + + ringQ.NTTLvl(level, shareOut.Value, shareOut.Value) } shareOut.Value.Coeffs = shareOut.Value.Coeffs[:level+1] @@ -139,12 +142,12 @@ func (cks *CKSProtocol) GenShare(skInput, skOutput *rlwe.SecretKey, ct *rlwe.Cip // // [ctx[0] + sum((skInput_i - skOutput_i) * ctx[0] + e_i), ctx[1]] func (cks *CKSProtocol) AggregateShares(share1, share2, shareOut *CKSShare) { - cks.ringQ.AddLvl(share1.Value.Level(), share1.Value, share2.Value, shareOut.Value) + cks.params.RingQ().AddLvl(share1.Value.Level(), share1.Value, share2.Value, shareOut.Value) } // KeySwitch performs the actual keyswitching operation on a ciphertext ct and put the result in ctOut func (cks *CKSProtocol) KeySwitch(combined *CKSShare, ct, ctOut *rlwe.Ciphertext) { el, elOut := ct.RLWEElement(), ctOut.RLWEElement() - cks.ringQ.AddLvl(el.Level(), el.Value[0], combined.Value, elOut.Value[0]) + cks.params.RingQ().AddLvl(el.Level(), el.Value[0], combined.Value, elOut.Value[0]) ring.CopyValuesLvl(el.Level(), el.Value[1], elOut.Value[1]) } diff --git a/drlwe/utils.go b/drlwe/utils.go deleted file mode 100644 index 472ec077..00000000 --- a/drlwe/utils.go +++ /dev/null @@ -1,18 +0,0 @@ -package drlwe - -func extendBasisSmallNormAndCenter(Q uint64, modulusP []uint64, coeffsQ []uint64, coeffsP [][]uint64) { - QHalf := Q >> 1 - var sign uint64 - for j, c := range coeffsQ { - - sign = 1 - if c > QHalf { - c = Q - c - sign = 0 - } - - for i, pi := range modulusP { - coeffsP[i][j] = (c * sign) | (pi-c)*(sign^1) - } - } -} diff --git a/examples/ckks/advanced/main.go b/examples/ckks/advanced/main.go new file mode 100644 index 00000000..5b276a55 --- /dev/null +++ b/examples/ckks/advanced/main.go @@ -0,0 +1,416 @@ +package main + +import ( + "fmt" + "math" + "time" + + "github.com/ldsec/lattigo/v2/ckks" + ckksAdvanced "github.com/ldsec/lattigo/v2/ckks/advanced" + "github.com/ldsec/lattigo/v2/ring" + "github.com/ldsec/lattigo/v2/rlwe" + "github.com/ldsec/lattigo/v2/utils" +) + +// This example is an implementation of the RLWE -> LWE extraction followed by an LWE -> RLWE repacking +// (bridge between CKKS and FHEW ciphertext) based on "Pegasus: Bridging Polynomial and Non-polynomial +// Evaluations in Homomorphic Encryption". +// It showcases advanced tools of the CKKS scheme, such as homomorphic decoding and homomorphic modular reduction. + +func main() { + + // Ring Learning With Error parameters + fmt.Printf("Gen RLWE Parameters... ") + start := time.Now() + paramsRLWE := genRLWEParameters() + fmt.Printf("Done (%s)\n", time.Since(start)) + + // Learning With Error parameters + fmt.Printf("Gen LWE Parameters... ") + start = time.Now() + paramsLWE := genLWEParameters(paramsRLWE) + fmt.Printf("Done (%s)\n", time.Since(start)) + + fmt.Printf("RLWE Params : logN=%2d, logQP=%3d\n", paramsRLWE.LogN(), paramsRLWE.LogQP()) + fmt.Printf("LWE Params : logN=%2d, logQP=%3d\n", paramsLWE.LogN(), paramsLWE.LogQP()) + + // Homomorphic decoding parameters + SlotsToCoeffsParameters := ckksAdvanced.EncodingMatrixLiteral{ + LinearTransformType: ckksAdvanced.SlotsToCoeffs, + LevelStart: 2, // starting level + BSGSRatio: 16.0, // ratio between n1/n2 for n1*n2 = slots + BitReversed: false, // bit-reversed input + ScalingFactor: [][]float64{ // Decomposition level of the encoding matrix + {paramsRLWE.QiFloat64(1)}, // Scale of the second matriox + {paramsRLWE.QiFloat64(2)}, // Scale of the first matrix + }, + } + + // Homomorphic modular reduction parameters + EvalModParameters := ckksAdvanced.EvalModLiteral{ + Q: paramsRLWE.Q()[0], // Modulus + LevelStart: paramsRLWE.MaxLevel() - 1, // Starting level of the procedure + SineType: ckksAdvanced.Cos1, // Type of approximation + MessageRatio: 256.0, // Q/|m| + K: 16, // Interval of approximation + SineDeg: 63, // Degree of approximation + DoubleAngle: 2, // Number of double angle evaluation + ArcSineDeg: 0, // Degree of arcsine Taylor polynomial + ScalingFactor: 1 << 60, // Scaling factor during the procedure + } + + // Generates the homomorphic modular reduction polynomial approximation + fmt.Printf("Gen EvalMod Poly... ") + start = time.Now() + EvalModPoly := ckksAdvanced.NewEvalModPolyFromLiteral(EvalModParameters) + fmt.Printf("Done (%s)\n", time.Since(start)) + + // RLWE Parameters + start = time.Now() + encoder := ckks.NewEncoder(paramsRLWE) + kgenRLWE := ckks.NewKeyGenerator(paramsRLWE) + skRLWE := kgenRLWE.GenSecretKey() + encryptor := ckks.NewEncryptor(paramsRLWE, skRLWE) + decryptor := ckks.NewDecryptor(paramsRLWE, skRLWE) + + fmt.Printf("Gen SlotsToCoeffs Matrices... ") + start = time.Now() + SlotsToCoeffsMatrix := ckksAdvanced.NewHomomorphicEncodingMatrixFromLiteral(SlotsToCoeffsParameters, encoder, paramsRLWE.LogN(), paramsRLWE.LogSlots(), 1.0) + fmt.Printf("Done (%s)\n", time.Since(start)) + + fmt.Printf("Gen Evaluation Keys:\n") + fmt.Printf(" Decoding Keys... ") + start = time.Now() + rotKey := kgenRLWE.GenRotationKeysForRotations(SlotsToCoeffsParameters.Rotations(paramsRLWE.LogN(), paramsRLWE.LogSlots()), true, skRLWE) + fmt.Printf("Done (%s)\n", time.Since(start)) + fmt.Printf(" Relinearization Key... ") + start = time.Now() + rlk := kgenRLWE.GenRelinearizationKey(skRLWE, 2) + fmt.Printf("Done (%s)\n", time.Since(start)) + + fmt.Printf(" Repacking Keys... ") + nonzerodiags := make([]int, paramsRLWE.Slots()) + for i := range nonzerodiags { + nonzerodiags[i] = i + } + rotationsRepack := paramsRLWE.RotationsForDiaMatrixMultRaw(nonzerodiags, paramsRLWE.Slots(), 16.0) + rotationsRepack = append(rotationsRepack, paramsRLWE.RotationsForTrace(paramsRLWE.LogSlots(), paramsLWE.LogN())...) + rotKeyRepack := kgenRLWE.GenRotationKeysForRotations(rotationsRepack, false, skRLWE) + + fmt.Printf("Done (%s)\n", time.Since(start)) + + eval := ckksAdvanced.NewEvaluator(paramsRLWE, rlwe.EvaluationKey{Rlk: rlk, Rtks: rotKey}) + + // LWE Parameters + kgenLWE := ckks.NewKeyGenerator(paramsLWE) + skLWE := kgenLWE.GenSecretKeySparse(64) + + // RLWE -> LWE Switching key + fmt.Printf(" RLWE -> LWE Switching Key... ") + start = time.Now() + swkRLWEDimToLWEDim := kgenRLWE.GenSwitchingKey(skRLWE, skLWE) + fmt.Printf("Done (%s)\n", time.Since(start)) + + // Encodes and Encrypts skLWE + fmt.Printf("Encode & Encrypt SK LWE... ") + start = time.Now() + skLWEInvNTT := paramsLWE.RingQ().NewPoly() + ring.CopyValues(skLWE.Value.Q, skLWEInvNTT) + paramsLWE.RingQ().InvNTT(skLWEInvNTT, skLWEInvNTT) + Q := paramsRLWE.Q()[0] + paramsLWE.RingQ().InvMFormLvl(0, skLWEInvNTT, skLWEInvNTT) + skFloat := make([]complex128, paramsLWE.N()) + for i, s := range skLWEInvNTT.Coeffs[0] { + if s >= Q>>1 { + skFloat[i] = -complex(float64(Q-s), 0) + } else { + skFloat[i] = complex(float64(s), 0) + } + + skFloat[i] *= complex(math.Pow(1.0/(EvalModPoly.K()*EvalModPoly.QDiff()), 0.5), 0) // sqrt(pre-scaling for Cheby) + } + + paramsLWE.RingQ().MFormLvl(0, skLWEInvNTT, skLWEInvNTT) + ptSk := ckks.NewPlaintext(paramsRLWE, paramsRLWE.MaxLevel(), paramsRLWE.QiFloat64(paramsRLWE.MaxLevel())) + encoder.Encode(ptSk, skFloat, paramsLWE.LogN()) + ctSk := encryptor.EncryptNew(ptSk) + fmt.Printf("Done (%s)\n", time.Since(start)) + + // ********** PLAINTEXT GENERATION & ENCRYPTION ************** + + // Random complex plaintext encrypted + fmt.Printf("Gen Plaintext & Encrypt... ") + start = time.Now() + values := make([]complex128, paramsRLWE.Slots()) + for i := range values { + values[i] = complex(float64(i+1)/float64(paramsRLWE.Slots()), 1+float64(i+1)/float64(paramsRLWE.Slots())) + } + + plaintext := ckks.NewPlaintext(paramsRLWE, paramsRLWE.MaxLevel(), paramsRLWE.Scale()) + // Must encode with 2*Slots because a real vector is returned + encoder.EncodeNTT(plaintext, values, utils.MinInt(paramsRLWE.LogSlots()+1, paramsRLWE.LogN()-1)) + ct := encryptor.EncryptNew(plaintext) + fmt.Printf("Done (%s)\n", time.Since(start)) + + // ******** STEP 1 : HOMOMORPHIC DECODING ******* + fmt.Printf("Homomorphic Decoding... ") + start = time.Now() + ct = eval.SlotsToCoeffsNew(ct, nil, SlotsToCoeffsMatrix) + fmt.Printf("Done (%s)\n", time.Since(start)) + + // ******** STEP 2 : RLWE -> LWE EXTRACTION ************* + + fmt.Printf("RLWE -> LWE Extraction... ") + start = time.Now() + // Scale the message to Delta = Q/MessageRatio + scale := math.Exp2(math.Round(math.Log2(float64(EvalModParameters.Q) / EvalModParameters.MessageRatio))) + eval.ScaleUp(ct, math.Round(scale/ct.Scale), ct) + + // Switch from RLWE parameters to LWE parameters + ctTmp := eval.SwitchKeysNew(ct, swkRLWEDimToLWEDim) + ctLWE := ckks.NewCiphertext(paramsLWE, 1, 0, ctTmp.Scale) + + // Switch the ciphertext outside of the NTT domain for the LWE extraction + for i := range ctLWE.Value { + paramsRLWE.RingQ().InvNTTLvl(0, ctTmp.Value[i], ctTmp.Value[i]) + } + + rlwe.SwitchCiphertextRingDegree(ctTmp.El(), ctLWE.El()) + + // RLWE -> LWE Extraction + lweReal, lweImag := ExtractLWESamplesBitReversed(ctLWE, paramsLWE) + fmt.Printf("Done (%s)\n", time.Since(start)) + + // Visual of some values + fmt.Println("Visual Comparison :") + fmt.Printf("Slot %4d : RLWE %f LWE %f\n", 0, values[0], complex(DecryptLWE(paramsLWE.RingQ(), lweReal[0], scale, skLWEInvNTT), DecryptLWE(paramsLWE.RingQ(), lweImag[0], scale, skLWEInvNTT))) + fmt.Printf("Slot %4d : RLWE %f LWE %f\n", paramsLWE.Slots()-1, values[paramsLWE.Slots()-1], complex(DecryptLWE(paramsLWE.RingQ(), lweReal[paramsLWE.Slots()-1], scale, skLWEInvNTT), DecryptLWE(paramsLWE.RingQ(), lweImag[paramsLWE.Slots()-1], scale, skLWEInvNTT))) + + // ********* STEP 3 : LWE -> RLWE REPACKING + fmt.Printf("Encode LWE Samples... ") + start = time.Now() + ptLWE := ckks.NewPlaintext(paramsRLWE, paramsRLWE.MaxLevel(), 1.0) + + // Encode the LWE samples as a vector + lweEncoded := make([]complex128, paramsRLWE.Slots()) + for i := 0; i < paramsRLWE.Slots(); i++ { + lweEncoded[i] = complex(float64(lweReal[i].b), float64(lweImag[i].b)) + lweEncoded[i] *= complex(math.Pow(1/(EvalModPoly.K()*EvalModPoly.QDiff()), 1.0), 0) // pre-scaling for Cheby + } + + encoder.EncodeNTT(ptLWE, lweEncoded, paramsRLWE.LogSlots()) + fmt.Printf("Done (%s)\n", time.Since(start)) + + // Encode A + + // Compute skLeft skRight + // __________ _ __________ _ + // | | | | | | | | + // | | | | | | | | + // n | ALeft | x | | + | ARight | x | | = A x sk + // | | | | | | | | + // |__________| |_| |__________| |_| + // + // N/n 1 N/n 1 + + fmt.Printf("Encode A... ") + start = time.Now() + AVectors := make([][]complex128, paramsLWE.Slots()) + for i := range AVectors { + tmp := make([]complex128, paramsLWE.N()) + for j := 0; j < paramsLWE.N(); j++ { + tmp[j] = complex(float64(lweReal[i].a[j]), float64(lweImag[i].a[j])) + tmp[j] *= complex(math.Pow(1/(EvalModPoly.K()*EvalModPoly.QDiff()), 0.5), 0) // sqrt(pre-scaling for Cheby) + } + + AVectors[i] = tmp + } + + // Diagonalize + AMatDiag := make(map[int][]complex128) + for i := 0; i < paramsLWE.Slots(); i++ { + tmp := make([]complex128, paramsLWE.N()) + for j := 0; j < paramsLWE.N(); j++ { + tmp[j] = AVectors[j%paramsLWE.Slots()][(j+i)%paramsLWE.N()] + } + AMatDiag[i] = tmp + } + + ptMatDiag := encoder.EncodeDiagMatrixBSGSAtLvl(paramsRLWE.MaxLevel(), AMatDiag, 1.0, 16.0, paramsLWE.LogN()) + fmt.Printf("Done (%s)\n", time.Since(start)) + + evalRepack := ckks.NewEvaluator(paramsRLWE, rlwe.EvaluationKey{Rlk: rlk, Rtks: rotKeyRepack}) + + fmt.Printf("Homomorphic Partial Decryption : pt = A x sk + encode(LWE) + I(X)*Q... ") + start = time.Now() + ctAs := evalRepack.LinearTransformNew(ctSk, ptMatDiag)[0] // A_left * sk || A_right * sk + ctAs = evalRepack.TraceNew(ctAs, paramsLWE.LogSlots(), paramsLWE.LogN()) // A * sk || A * sk + eval.Rescale(ctAs, 1.0, ctAs) + eval.Add(ctAs, ptLWE, ctAs) // A * sk || A * sk + LWE_real || LWE_imag = RLWE + I(X) * Q + ctAs.Scale = scale + fmt.Printf("Done (%s)\n", time.Since(start)) + + fmt.Printf("Homomorphic Modular Reduction : pt mod Q... ") + start = time.Now() + // Extract imaginary part : RLWE_real + I(X)*Q ; RLWE_imag + I(X)*Q + ctAsConj := eval.ConjugateNew(ctAs) + ctAsReal := eval.AddNew(ctAs, ctAsConj) + ctAsImag := eval.SubNew(ctAs, ctAsConj) + ctAsReal.Scale = ctAsReal.Scale * 2 // Divide by 2 + ctAsImag.Scale = ctAsImag.Scale * 2 // Divide by 2 + eval.ScaleUp(ctAsReal, math.Round((EvalModPoly.ScalingFactor()/EvalModPoly.MessageRatio())/ctAsReal.Scale), ctAsReal) // Scale the real message up to Sine/MessageRatio + eval.ScaleUp(ctAsImag, math.Round((EvalModPoly.ScalingFactor()/EvalModPoly.MessageRatio())/ctAsImag.Scale), ctAsImag) // Scale the imag message up to Sine/MessageRatio + v := encoder.DecodePublic(decryptor.DecryptNew(ctAsReal), paramsRLWE.LogSlots(), 0) + fmt.Printf("Slot %4d : Want %f Have %f\n", 0, values[0], v[0]) + fmt.Printf("Slot %4d : Want %f Have %f\n", paramsRLWE.Slots()-1, values[paramsRLWE.Slots()-1], v[paramsRLWE.Slots()-1]) + ctAsReal = eval.EvalModNew(ctAsReal, EvalModPoly) // Real mod Q + eval.DivByi(ctAsImag, ctAsImag) + ctAsImag = eval.EvalModNew(ctAsImag, EvalModPoly) // (-i*imag mod Q)*i + eval.MultByi(ctAsImag, ctAsImag) + eval.Add(ctAsReal, ctAsImag, ctAsReal) // Repack both imag and real parts + fmt.Printf("Done (%s)\n", time.Since(start)) + + fmt.Println("Visual Comparison :") + v = encoder.DecodePublic(decryptor.DecryptNew(ctAsReal), paramsRLWE.LogSlots(), 0) + fmt.Printf("Slot %4d : Want %f Have %f\n", 0, values[0], v[0]) + fmt.Printf("Slot %4d : Want %f Have %f\n", paramsRLWE.Slots()-1, values[paramsRLWE.Slots()-1], v[paramsRLWE.Slots()-1]) + +} + +func genRLWEParameters() (paramsRLWE ckks.Parameters) { + var err error + if paramsRLWE, err = ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ + LogN: 14, + LogSlots: 8, + Scale: 1 << 30, + Sigma: rlwe.DefaultSigma, + Q: []uint64{ + 0xffff820001, // 40 Q0 + 0x2000000a0001, // 45 + 0x2000000e0001, // 45 + 0xfffffffff840001, // 60 Sine (double angle) + 0x1000000000860001, // 60 Sine (double angle) + 0xfffffffff6a0001, // 60 Sine + 0x1000000000980001, // 60 Sine + 0xfffffffff5a0001, // 60 Sine + 0x1000000000b00001, // 60 Sine + 0x1000000000ce0001, // 60 Sine + 0xfffffffff2a0001, // 60 Sine + 0x100000000060001, // 58 Repack & Change of basis + }, + P: []uint64{ + 0x1fffffffffe00001, // 61 + 0x1fffffffffc80001, // 61 + 0x1fffffffffb40001, // 61 + }, + }); err != nil { + panic(err) + } + return +} + +func genLWEParameters(paramsRLWE ckks.Parameters) (paramsLWE ckks.Parameters) { + var err error + if paramsLWE, err = ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ + LogN: 10, + LogSlots: paramsRLWE.LogSlots(), + Scale: paramsRLWE.Scale(), + Sigma: paramsRLWE.Sigma(), + Q: paramsRLWE.Q()[:1], // 40 Q0 + P: paramsRLWE.P()[:1], // Pi 61 + }); err != nil { + panic(err) + } + return +} + +// DecryptLWE decrypts an LWE sample +func DecryptLWE(ringQ *ring.Ring, lwe RNSLWESample, scale float64, skInvNTT *ring.Poly) float64 { + + tmp := ringQ.NewPolyLvl(0) + pol := new(ring.Poly) + pol.Coeffs = [][]uint64{lwe.a} + ringQ.MulCoeffsMontgomeryLvl(0, pol, skInvNTT, tmp) + qi := ringQ.Modulus[0] + tmp0 := tmp.Coeffs[0] + tmp1 := lwe.b + for j := 0; j < ringQ.N; j++ { + tmp1 = ring.CRed(tmp1+tmp0[j], qi) + } + + if tmp1 >= ringQ.Modulus[0]>>1 { + tmp1 = ringQ.Modulus[0] - tmp1 + return -float64(tmp1) / scale + } + + return float64(tmp1) / scale +} + +// RNSLWESample is a struct for RNS LWE samples +type RNSLWESample struct { + b uint64 + a []uint64 +} + +// ExtractLWESamplesBitReversed extracts LWE samples from a R-LWE sample +func ExtractLWESamplesBitReversed(ct *ckks.Ciphertext, params ckks.Parameters) (LWEReal, LWEImag []RNSLWESample) { + + ringQ := params.RingQ() + + LWEReal = make([]RNSLWESample, params.Slots()) + LWEImag = make([]RNSLWESample, params.Slots()) + + // Copy coefficients multiplied by X^{N-1} in reverse order: + // a_{0} -a_{N-1} -a2_{N-2} ... -a_{1} + acc := ringQ.NewPolyLvl(ct.Level()) + for i, qi := range ringQ.Modulus[:ct.Level()+1] { + tmp0 := acc.Coeffs[i] + tmp1 := ct.Value[1].Coeffs[i] + tmp0[0] = tmp1[0] + for j := 1; j < ringQ.N; j++ { + tmp0[j] = qi - tmp1[ringQ.N-j] + } + } + + pol := ct.Value[0] + + gap := params.N() / (2 * params.Slots()) // Gap between plaintext coefficient if sparse packed + + // Real values + for i, idx := 0, 0; i < params.Slots(); i, idx = i+1, idx+gap { + + iRev := utils.BitReverse64(uint64(i), uint64(params.LogSlots())) + + LWEReal[iRev].b = pol.Coeffs[0][idx] + LWEReal[iRev].a = make([]uint64, params.N()) + copy(LWEReal[iRev].a, acc.Coeffs[0]) + + // Multiplies the accumulator by X^{N/(2*slots)} + MulBySmallMonomial(ringQ, acc, gap) + } + + // Imaginary values + for i, idx := 0, 0; i < params.Slots(); i, idx = i+1, idx+gap { + + iRev := utils.BitReverse64(uint64(i), uint64(params.LogSlots())) + + LWEImag[iRev].b = pol.Coeffs[0][idx+(params.N()>>1)] + LWEImag[iRev].a = make([]uint64, params.N()) + copy(LWEImag[iRev].a, acc.Coeffs[0]) + + // Multiply the accumulator by X^{N/(2*slots)} + MulBySmallMonomial(ringQ, acc, gap) + } + return +} + +//MulBySmallMonomial multiplies pol by x^n +func MulBySmallMonomial(ringQ *ring.Ring, pol *ring.Poly, n int) { + for i, qi := range ringQ.Modulus[:pol.Level()+1] { + pol.Coeffs[i] = append(pol.Coeffs[i][ringQ.N-n:], pol.Coeffs[i][:ringQ.N-n]...) + tmp := pol.Coeffs[i] + for j := 0; j < n; j++ { + tmp[j] = qi - tmp[j] + } + } +} diff --git a/examples/ckks/bootstrapping/main.go b/examples/ckks/bootstrapping/main.go index 3d18bae4..fe785cc2 100644 --- a/examples/ckks/bootstrapping/main.go +++ b/examples/ckks/bootstrapping/main.go @@ -5,6 +5,7 @@ import ( "math" "github.com/ldsec/lattigo/v2/ckks" + "github.com/ldsec/lattigo/v2/ckks/bootstrapping" "github.com/ldsec/lattigo/v2/rlwe" "github.com/ldsec/lattigo/v2/utils" ) @@ -13,7 +14,7 @@ func main() { var err error - var btp *ckks.Bootstrapper + var btp *bootstrapping.Bootstrapper var kgen rlwe.KeyGenerator var encoder ckks.Encoder var sk *rlwe.SecretKey @@ -28,8 +29,9 @@ func main() { // LogSlots is hardcoded to 15 in the parameters, but can be changed from 1 to 15. // When changing logSlots make sure that the number of levels allocated to CtS and StC is // smaller or equal to logSlots. - btpParams := ckks.DefaultBootstrapParams[0] - params, err := btpParams.Params() + ckksParams := bootstrapping.DefaultCKKSParameters[0] + btpParams := bootstrapping.DefaultParameters[0] + params, err := ckks.NewParametersFromLiteral(ckksParams) if err != nil { panic(err) } @@ -48,11 +50,10 @@ func main() { fmt.Println() fmt.Println("Generating bootstrapping keys...") - rotations := btpParams.RotationsForBootstrapping(params.LogSlots()) + rotations := btpParams.RotationsForBootstrapping(params.LogN(), params.LogSlots()) rotkeys := kgen.GenRotationKeysForRotations(rotations, true, sk) rlk := kgen.GenRelinearizationKey(sk, 2) - btpKey := ckks.BootstrappingKey{Rlk: rlk, Rtks: rotkeys} - if btp, err = ckks.NewBootstrapper(params, btpParams, btpKey); err != nil { + if btp, err = bootstrapping.NewBootstrapper(params, btpParams, rlwe.EvaluationKey{Rlk: rlk, Rtks: rotkeys}); err != nil { panic(err) } fmt.Println("Done") diff --git a/examples/ckks/sigmoid/main.go b/examples/ckks/sigmoid/main.go index cc8b4562..b3a0a570 100644 --- a/examples/ckks/sigmoid/main.go +++ b/examples/ckks/sigmoid/main.go @@ -69,8 +69,8 @@ func chebyshevinterpolation() { // We approximate f(x) in the range [-8, 8] with a Chebyshev interpolant of 33 coefficients (degree 32). chebyapproximation := ckks.Approximate(f, -8, 8, 33) - a := chebyapproximation.A() - b := chebyapproximation.B() + a := chebyapproximation.A + b := chebyapproximation.B // Change of variable evaluator.MultByConst(ciphertext, 2/(b-a), ciphertext) @@ -80,7 +80,7 @@ func chebyshevinterpolation() { } // We evaluate the interpolated Chebyshev interpolant on the ciphertext - if ciphertext, err = evaluator.EvaluateCheby(ciphertext, chebyapproximation, ciphertext.Scale); err != nil { + if ciphertext, err = evaluator.EvaluatePoly(ciphertext, chebyapproximation, ciphertext.Scale); err != nil { panic(err) } diff --git a/examples/dbfv/pir/main.go b/examples/dbfv/pir/main.go index 936924af..0a565cbc 100644 --- a/examples/dbfv/pir/main.go +++ b/examples/dbfv/pir/main.go @@ -10,7 +10,6 @@ import ( "github.com/ldsec/lattigo/v2/bfv" "github.com/ldsec/lattigo/v2/dbfv" "github.com/ldsec/lattigo/v2/drlwe" - "github.com/ldsec/lattigo/v2/ring" "github.com/ldsec/lattigo/v2/rlwe" "github.com/ldsec/lattigo/v2/utils" ) @@ -109,32 +108,25 @@ func main() { panic(err) } - // PRNG keyed with "lattigo" - lattigoPRNG, err := utils.NewKeyedPRNG([]byte{'l', 'a', 't', 't', 'i', 'g', 'o'}) + // Common reference polynomial generator that uses the PRNG + crs, err := utils.NewKeyedPRNG([]byte{'l', 'a', 't', 't', 'i', 'g', 'o'}) if err != nil { panic(err) } - // Ring for the common reference polynomials sampling - ringQP := params.RingQP() - - // Common reference polynomial generator that uses the PRNG - crsGen := ring.NewUniformSampler(lattigoPRNG, ringQP) - ternarySamplerMontgomery := ring.NewTernarySampler(lattigoPRNG, ringQP, 0.5, true) - // Instantiation of each of the protocols needed for the PIR example // Create each party, and allocate the memory for all the shares that the protocols will need - P := genparties(params, N, ternarySamplerMontgomery, ringQP) + P := genparties(params, N) // 1) Collective public key generation - pk := ckgphase(params, crsGen, P) + pk := ckgphase(params, crs, P) // 2) Collective relinearization key generation - rlk := rkgphase(params, crsGen, P) + rlk := rkgphase(params, crs, P) // 3) Collective rotation keys generation - rtk := rtkphase(params, crsGen, P) + rtk := rtkphase(params, crs, P) l.Printf("\tSetup done (cloud: %s, party: %s)\n", elapsedCKGCloud+elapsedRKGCloud+elapsedRTGCloud, @@ -160,7 +152,7 @@ func main() { encoder.EncodeUintMul(maskCoeffs, plainMask[i]) } - // Ciphertexts encrypted under CPK and stored in the cloud + // Ciphertexts encrypted under CKG and stored in the cloud l.Println("> Encrypt Phase") encryptor := bfv.NewEncryptor(params, pk) pt := bfv.NewPlaintext(params) @@ -230,7 +222,7 @@ func cksphase(params bfv.Parameters, P []*party, result *bfv.Ciphertext) *bfv.Ci return encOut } -func genparties(params bfv.Parameters, N int, sampler *ring.TernarySampler, ringQP *ring.Ring) []*party { +func genparties(params bfv.Parameters, N int) []*party { P := make([]*party, N) @@ -251,34 +243,34 @@ func genparties(params bfv.Parameters, N int, sampler *ring.TernarySampler, ring return P } -func ckgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *rlwe.PublicKey { +func ckgphase(params bfv.Parameters, crs utils.PRNG, P []*party) *rlwe.PublicKey { l := log.New(os.Stderr, "", 0) l.Println("> CKG Phase") ckg := dbfv.NewCKGProtocol(params) // Public key generation - crs := crsGen.ReadNew() // for the public-key + ckgCombined := ckg.AllocateShares() for _, pi := range P { pi.ckgShare = ckg.AllocateShares() } + crp := ckg.SampleCRP(crs) + elapsedCKGParty = runTimedParty(func() { for _, pi := range P { - ckg.GenShare(pi.sk, crs, pi.ckgShare) + ckg.GenShare(pi.sk, crp, pi.ckgShare) } }, len(P)) - ckgCombined := ckg.AllocateShares() - pk := bfv.NewPublicKey(params) elapsedCKGCloud = runTimed(func() { for _, pi := range P { ckg.AggregateShares(pi.ckgShare, ckgCombined, ckgCombined) } - ckg.GenPublicKey(ckgCombined, crs, pk) + ckg.GenPublicKey(ckgCombined, crp, pk) }) l.Printf("\tdone (cloud: %s, party: %s)\n", elapsedCKGCloud, elapsedCKGParty) @@ -286,21 +278,20 @@ func ckgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *r return pk } -func rkgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *rlwe.RelinearizationKey { +func rkgphase(params bfv.Parameters, crs utils.PRNG, P []*party) *rlwe.RelinearizationKey { l := log.New(os.Stderr, "", 0) l.Println("> RKG Phase") rkg := dbfv.NewRKGProtocol(params) // Relineariation key generation + _, rkgCombined1, rkgCombined2 := rkg.AllocateShares() + for _, pi := range P { pi.rlkEphemSk, pi.rkgShareOne, pi.rkgShareTwo = rkg.AllocateShares() } - crp := make([]*ring.Poly, params.Beta()) // for the relinearization keys - for i := 0; i < params.Beta(); i++ { - crp[i] = crsGen.ReadNew() - } + crp := rkg.SampleCRP(crs) elapsedRKGParty = runTimedParty(func() { for _, pi := range P { @@ -308,8 +299,6 @@ func rkgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *r } }, len(P)) - _, rkgCombined1, rkgCombined2 := rkg.AllocateShares() - elapsedRKGCloud = runTimed(func() { for _, pi := range P { rkg.AggregateShares(pi.rkgShareOne, rkgCombined1, rkgCombined1) @@ -318,7 +307,7 @@ func rkgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *r elapsedRKGParty += runTimedParty(func() { for _, pi := range P { - rkg.GenShareRoundTwo(pi.rlkEphemSk, pi.sk, rkgCombined1, crp, pi.rkgShareTwo) + rkg.GenShareRoundTwo(pi.rlkEphemSk, pi.sk, rkgCombined1, pi.rkgShareTwo) } }, len(P)) @@ -335,7 +324,7 @@ func rkgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *r return rlk } -func rtkphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *rlwe.RotationKeySet { +func rtkphase(params bfv.Parameters, crs utils.PRNG, P []*party) *rlwe.RotationKeySet { l := log.New(os.Stderr, "", 0) @@ -347,28 +336,26 @@ func rtkphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *r pi.rtgShare = rtg.AllocateShares() } - crpRot := make([]*ring.Poly, params.Beta()) // for the rotation keys - - for i := 0; i < params.Beta(); i++ { - crpRot[i] = crsGen.ReadNew() - } - galEls := params.GaloisElementsForRowInnerSum() rotKeySet := bfv.NewRotationKeySet(params, galEls) + for _, galEl := range galEls { + rtgShareCombined := rtg.AllocateShares() + + crp := rtg.SampleCRP(crs) + elapsedRTGParty += runTimedParty(func() { for _, pi := range P { - rtg.GenShare(pi.sk, galEl, crpRot, pi.rtgShare) + rtg.GenShare(pi.sk, galEl, crp, pi.rtgShare) } }, len(P)) - rtgShareCombined := rtg.AllocateShares() elapsedRTGCloud += runTimed(func() { for _, pi := range P { rtg.Aggregate(pi.rtgShare, rtgShareCombined, rtgShareCombined) } - rtg.GenRotationKey(rtgShareCombined, crpRot, rotKeySet.Keys[galEl]) + rtg.GenRotationKey(rtgShareCombined, crp, rotKeySet.Keys[galEl]) }) } l.Printf("\tdone (cloud: %s, party %s)\n", elapsedRTGCloud, elapsedRTGParty) diff --git a/examples/dbfv/psi/main.go b/examples/dbfv/psi/main.go index 549adc59..bc170907 100644 --- a/examples/dbfv/psi/main.go +++ b/examples/dbfv/psi/main.go @@ -10,7 +10,6 @@ import ( "github.com/ldsec/lattigo/v2/bfv" "github.com/ldsec/lattigo/v2/dbfv" "github.com/ldsec/lattigo/v2/drlwe" - "github.com/ldsec/lattigo/v2/ring" "github.com/ldsec/lattigo/v2/rlwe" "github.com/ldsec/lattigo/v2/utils" ) @@ -96,41 +95,27 @@ func main() { panic(err) } - // PRNG keyed with "lattigo" - lattigoPRNG, err := utils.NewKeyedPRNG([]byte{'l', 'a', 't', 't', 'i', 'g', 'o'}) + crs, err := utils.NewKeyedPRNG([]byte{'l', 'a', 't', 't', 'i', 'g', 'o'}) if err != nil { panic(err) } - // Ring for the common reference polynomials sampling - ringQP, _ := ring.NewRing(1< RKG Phase") rkg := dbfv.NewRKGProtocol(params) // Relineariation key generation + _, rkgCombined1, rkgCombined2 := rkg.AllocateShares() for _, pi := range P { pi.rlkEphemSk, pi.rkgShareOne, pi.rkgShareTwo = rkg.AllocateShares() } - crp := make([]*ring.Poly, params.Beta()) // for the relinearization keys - for i := 0; i < params.Beta(); i++ { - crp[i] = crsGen.ReadNew() - } + crp := rkg.SampleCRP(crs) elapsedRKGParty = runTimedParty(func() { for _, pi := range P { @@ -359,8 +342,6 @@ func rkgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *r } }, len(P)) - _, rkgCombined1, rkgCombined2 := rkg.AllocateShares() - elapsedRKGCloud = runTimed(func() { for _, pi := range P { rkg.AggregateShares(pi.rkgShareOne, rkgCombined1, rkgCombined1) @@ -369,7 +350,7 @@ func rkgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *r elapsedRKGParty += runTimedParty(func() { for _, pi := range P { - rkg.GenShareRoundTwo(pi.rlkEphemSk, pi.sk, rkgCombined1, crp, pi.rkgShareTwo) + rkg.GenShareRoundTwo(pi.rlkEphemSk, pi.sk, rkgCombined1, pi.rkgShareTwo) } }, len(P)) @@ -386,34 +367,33 @@ func rkgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *r return rlk } -func ckgphase(params bfv.Parameters, crsGen *ring.UniformSampler, P []*party) *rlwe.PublicKey { +func ckgphase(params bfv.Parameters, crs utils.PRNG, P []*party) *rlwe.PublicKey { l := log.New(os.Stderr, "", 0) l.Println("> CKG Phase") ckg := dbfv.NewCKGProtocol(params) // Public key generation - crs := crsGen.ReadNew() // for the public-key - + ckgCombined := ckg.AllocateShares() for _, pi := range P { pi.ckgShare = ckg.AllocateShares() } + crp := ckg.SampleCRP(crs) + elapsedCKGParty = runTimedParty(func() { for _, pi := range P { - ckg.GenShare(pi.sk, crs, pi.ckgShare) + ckg.GenShare(pi.sk, crp, pi.ckgShare) } }, len(P)) - ckgCombined := ckg.AllocateShares() - pk := bfv.NewPublicKey(params) elapsedCKGCloud = runTimed(func() { for _, pi := range P { ckg.AggregateShares(pi.ckgShare, ckgCombined, ckgCombined) } - ckg.GenPublicKey(ckgCombined, crs, pk) + ckg.GenPublicKey(ckgCombined, crp, pk) }) l.Printf("\tdone (cloud: %s, party: %s)\n", elapsedCKGCloud, elapsedCKGParty) diff --git a/examples/ring/vOLE/main.go b/examples/ring/vOLE/main.go index 161bb6aa..cc365a87 100644 --- a/examples/ring/vOLE/main.go +++ b/examples/ring/vOLE/main.go @@ -178,6 +178,8 @@ func main() { // ********* 1. SETUP ********* + pool := ringQ.NewPoly() + // NTT(MForm(skBob)) skBob := ternarySamplerMontgomeryQ.ReadNew() ringQ.NTT(skBob, skBob) @@ -286,7 +288,8 @@ func main() { ringQ.MulCoeffsMontgomeryAndSub(a[i], sigmaAlice, rhoAlice[i]) // rhoAlice = NTT(skBob * u) + NTT(a*skBob*skAlice - a*sigmaAlice) * (P/Q) - ringQ.DivRoundByLastModulusManyNTT(rhoAlice[i], rhoAlice[i], qlevel-plevel) + ringQ.DivRoundByLastModulusManyNTTLvl(qlevel, qlevel-plevel, rhoAlice[i], pool, rhoAlice[i]) + rhoAlice[i].Coeffs = rhoAlice[i].Coeffs[:plevel+1] } elapsed = time.Since(start) @@ -303,7 +306,8 @@ func main() { ringQ.MulCoeffsMontgomery(a[i], sigmaBob, rhoBob[i]) // rhoBob = NTT(a * sigmaBob * (P/Q)) - ringQ.DivRoundByLastModulusManyNTT(rhoBob[i], rhoBob[i], qlevel-plevel) + ringQ.DivRoundByLastModulusManyNTTLvl(qlevel, qlevel-plevel, rhoBob[i], pool, rhoBob[i]) + rhoBob[i].Coeffs = rhoBob[i].Coeffs[:plevel+1] // rhoBob = NTT(-a * sigmaBob) * (P/Q) ringQ.NegLvl(plevel, rhoBob[i], rhoBob[i]) @@ -394,7 +398,8 @@ func main() { // beta = (u*(v * (P/M) + e + a'*skAlice) * u - a'*-a*sigmaBob * (P/Q)) * (M/P) // = (M/P) * u * (v * (P/M) + e + a'*skAlice) - a'*-a*sigmaBob * (M/Q) // = u*v + a'*skAlice*u*(M/P) + a'*a*sigmaBob * (M/Q) - ringQ.DivRoundByLastModulusMany(beta[i], beta[i], plevel-mlevel) + ringQ.DivRoundByLastModulusManyLvl(plevel, plevel-mlevel, beta[i], pool, beta[i]) + beta[i].Coeffs = beta[i].Coeffs[:mlevel+1] } elapsed = time.Since(start) @@ -413,7 +418,8 @@ func main() { // alpha = (a'*skAlice*u + (a'*a*skBob*skAlice - a'*a*sigmaAlice) * (P/Q)) * (M/P) // = a'*skAlice*u*(M/P) + a'*a*skBob*skAlice*(M/Q) - a'*a*sigmaAlice*(M/Q) - ringQ.DivRoundByLastModulusMany(alpha[i], alpha[i], plevel-mlevel) + ringQ.DivRoundByLastModulusManyLvl(plevel, plevel-mlevel, alpha[i], pool, alpha[i]) + alpha[i].Coeffs = alpha[i].Coeffs[:mlevel+1] // alpha = - a'*skAlice*u*(M/P) - a'*a*skBob*skAlice*(M/Q) + a'*a*sigmaAlice*(M/Q) // = - a'*skAlice*u*(M/P) - a'*a*skBob*skAlice*(M/Q) + a'*a*(skBob*skAlice - sigmaBob)*(M/Q) diff --git a/ring/ring.go b/ring/ring.go index 07b7d2d6..b4db67aa 100644 --- a/ring/ring.go +++ b/ring/ring.go @@ -145,7 +145,7 @@ func (r *Ring) genNTTParams() error { for i := 0; i < j; i++ { - r.RescaleParams[j-1][i] = MForm(r.Modulus[i]-ModExp(r.Modulus[j], int(r.Modulus[i]-2), r.Modulus[i]), r.Modulus[i], r.BredParams[i]) + r.RescaleParams[j-1][i] = MForm(r.Modulus[i]-ModExp(r.Modulus[j], r.Modulus[i]-2, r.Modulus[i]), r.Modulus[i], r.BredParams[i]) } } @@ -160,7 +160,7 @@ func (r *Ring) genNTTParams() error { for i, qi := range r.Modulus { // 1.1 Compute N^(-1) mod Q in Montgomery form - r.NttNInv[i] = MForm(ModExp(uint64(r.N), int(qi-2), qi), qi, r.BredParams[i]) + r.NttNInv[i] = MForm(ModExp(uint64(r.N), qi-2, qi), qi, r.BredParams[i]) // 1.2 Compute Psi and PsiInv in Montgomery form r.NttPsi[i] = make([]uint64, r.N) @@ -169,10 +169,10 @@ func (r *Ring) genNTTParams() error { // Finds a 2N-th primitive Root g := primitiveRoot(qi) - _2n := r.N << 1 + _2n := uint64(r.N << 1) - power := (int(qi) - 1) / _2n - powerInv := (int(qi) - 1) - power + power := (qi - 1) / _2n + powerInv := (qi - 1) - power // Computes Psi and PsiInv in Montgomery form PsiMont := MForm(ModExp(g, power, qi), qi, r.BredParams[i]) diff --git a/ring/ring_basis_extension.go b/ring/ring_basis_extension.go index 27db52ee..a7e847be 100644 --- a/ring/ring_basis_extension.go +++ b/ring/ring_basis_extension.go @@ -2,7 +2,6 @@ package ring import ( "math" - "math/big" "math/bits" "unsafe" ) @@ -10,47 +9,44 @@ import ( // FastBasisExtender stores the necessary parameters for RNS basis extension. // The used algorithm is from https://eprint.iacr.org/2018/117.pdf. type FastBasisExtender struct { - ringQ *Ring - ringP *Ring - paramsQP *modupParams - paramsPQ *modupParams - modDownParamsPQ []uint64 - modDownParamsQP []uint64 + ringQ *Ring + ringP *Ring + paramsQtoP []modupParams + paramsPtoQ []modupParams + modDownparamsPtoQ [][]uint64 + modDownparamsQtoP [][]uint64 polypoolQ *Poly polypoolP *Poly } type modupParams struct { - Q []uint64 - P []uint64 - //Parameters for basis extension from Q to P // (Q/Qi)^-1) (mod each Qi) (in Montgomery form) - qibMont []uint64 + qoverqiinvqi []uint64 // Q/qi (mod each Pj) (in Montgomery form) - qispjMont [][]uint64 + qoverqimodp [][]uint64 // Q*v (mod each Pj) for v in [1,...,k] where k is the number of Pj moduli - qpjInv [][]uint64 - - bredParamsQ [][]uint64 - mredParamsQ []uint64 - - bredParamsP [][]uint64 - mredParamsP []uint64 + vtimesqmodp [][]uint64 } -func genModDownParams(ringP, ringQ *Ring) (params []uint64) { +func genModDownParams(ringQ, ringP *Ring) (params [][]uint64) { - params = make([]uint64, len(ringP.Modulus)) + params = make([][]uint64, len(ringP.Modulus)) - bredParams := ringP.BredParams - tmp := new(big.Int) - for i, Qi := range ringP.Modulus { + bredParams := ringQ.BredParams + mredParams := ringQ.MredParams - params[i] = tmp.Mod(ringQ.ModulusBigint, NewUint(Qi)).Uint64() - params[i] = ModExp(params[i], int(Qi-2), Qi) - params[i] = MForm(params[i], Qi, bredParams[i]) + for j := range ringP.Modulus { + params[j] = make([]uint64, len(ringQ.Modulus)) + for i, qi := range ringQ.Modulus { + params[j][i] = ModExp(ringP.Modulus[j], qi-2, qi) + params[j][i] = MForm(params[j][i], qi, bredParams[i]) + + if j > 0 { + params[j][i] = MRed(params[j][i], params[j-1][i], qi, mredParams[i]) + } + } } return @@ -64,11 +60,18 @@ func NewFastBasisExtender(ringQ, ringP *Ring) *FastBasisExtender { newParams.ringQ = ringQ newParams.ringP = ringP - newParams.paramsQP = basisextenderparameters(ringQ.Modulus, ringP.Modulus) - newParams.paramsPQ = basisextenderparameters(ringP.Modulus, ringQ.Modulus) + newParams.paramsQtoP = make([]modupParams, len(ringQ.Modulus)) + for i := range ringQ.Modulus { + newParams.paramsQtoP[i] = basisextenderparameters(ringQ.Modulus[:i+1], ringP.Modulus) + } - newParams.modDownParamsPQ = genModDownParams(ringQ, ringP) - newParams.modDownParamsQP = genModDownParams(ringP, ringQ) + newParams.paramsPtoQ = make([]modupParams, len(ringP.Modulus)) + for i := range ringP.Modulus { + newParams.paramsPtoQ[i] = basisextenderparameters(ringP.Modulus[:i+1], ringQ.Modulus) + } + + newParams.modDownparamsPtoQ = genModDownParams(ringQ, ringP) + newParams.modDownparamsQtoP = genModDownParams(ringP, ringQ) newParams.polypoolQ = ringQ.NewPoly() newParams.polypoolP = ringP.NewPoly() @@ -76,73 +79,76 @@ func NewFastBasisExtender(ringQ, ringP *Ring) *FastBasisExtender { return newParams } -func basisextenderparameters(Q, P []uint64) (params *modupParams) { +func basisextenderparameters(Q, P []uint64) modupParams { - params = new(modupParams) + bredParamsQ := make([][]uint64, len(Q)) + mredParamsQ := make([]uint64, len(Q)) + bredParamsP := make([][]uint64, len(P)) + mredParamsP := make([]uint64, len(P)) - params.Q = make([]uint64, len(Q)) - params.bredParamsQ = make([][]uint64, len(Q)) - params.mredParamsQ = make([]uint64, len(Q)) - for i, qi := range Q { - params.Q[i] = Q[i] - params.bredParamsQ[i] = BRedParams(qi) - params.mredParamsQ[i] = MRedParams(qi) + for i := range Q { + bredParamsQ[i] = BRedParams(Q[i]) + mredParamsQ[i] = MRedParams(Q[i]) } - params.P = make([]uint64, len(P)) - params.bredParamsP = make([][]uint64, len(P)) - params.mredParamsP = make([]uint64, len(P)) - for i, pj := range P { - params.P[i] = P[i] - params.bredParamsP[i] = BRedParams(pj) - params.mredParamsP[i] = MRedParams(pj) - } - - tmp := new(big.Int) - QiB := new(big.Int) - QiStar := new(big.Int) - QiBarre := new(big.Int) - - modulusbigint := NewUint(1) - for _, qi := range Q { - modulusbigint.Mul(modulusbigint, NewUint(qi)) - } - - params.qibMont = make([]uint64, len(Q)) - params.qispjMont = make([][]uint64, len(P)) - for i := range P { - params.qispjMont[i] = make([]uint64, len(Q)) + bredParamsP[i] = BRedParams(P[i]) + mredParamsP[i] = MRedParams(P[i]) } + qoverqiinvqi := make([]uint64, len(Q)) + qoverqimodp := make([][]uint64, len(P)) + + for i := range P { + qoverqimodp[i] = make([]uint64, len(Q)) + } + + var qiStar uint64 for i, qi := range Q { - QiB.SetUint64(qi) - QiStar.Quo(modulusbigint, QiB) - QiBarre.ModInverse(QiStar, QiB) - QiBarre.Mod(QiBarre, QiB) + qiStar = MForm(1, qi, bredParamsQ[i]) + + for j := 0; j < len(Q); j++ { + if j != i { + qiStar = MRed(qiStar, MForm(Q[j], qi, bredParamsQ[i]), qi, mredParamsQ[i]) + } + } // (Q/Qi)^-1) * r (mod Qi) (in Montgomery form) - params.qibMont[i] = MForm(QiBarre.Uint64(), qi, params.bredParamsQ[i]) + qoverqiinvqi[i] = ModexpMontgomery(qiStar, int(qi-2), qi, mredParamsQ[i], bredParamsQ[i]) for j, pj := range P { // (Q/qi * r) (mod Pj) (in Montgomery form) - params.qispjMont[j][i] = MForm(tmp.Mod(QiStar, NewUint(pj)).Uint64(), pj, params.bredParamsP[j]) + qiStar = 1 + for u := 0; u < len(Q); u++ { + if u != i { + qiStar = MRed(qiStar, MForm(Q[u], pj, bredParamsP[j]), pj, mredParamsP[j]) + } + } + + qoverqimodp[j][i] = MForm(qiStar, pj, bredParamsP[j]) } } - params.qpjInv = make([][]uint64, len(P)) + vtimesqmodp := make([][]uint64, len(P)) + var QmodPi uint64 for j, pj := range P { - params.qpjInv[j] = make([]uint64, len(Q)+1) + vtimesqmodp[j] = make([]uint64, len(Q)+1) // Correction Term (v*Q) mod each Pj - v := pj - tmp.Mod(modulusbigint, NewUint(pj)).Uint64() - params.qpjInv[j][0] = 0 + + QmodPi = 1 + for _, qi := range Q { + QmodPi = MRed(QmodPi, MForm(qi, pj, bredParamsP[j]), pj, mredParamsP[j]) + } + + v := pj - QmodPi + vtimesqmodp[j][0] = 0 for i := 1; i < len(Q)+1; i++ { - params.qpjInv[j][i] = CRed(params.qpjInv[j][i-1]+v, pj) + vtimesqmodp[j][i] = CRed(vtimesqmodp[j][i-1]+v, pj) } } - return + return modupParams{qoverqiinvqi: qoverqiinvqi, qoverqimodp: qoverqimodp, vtimesqmodp: vtimesqmodp} } // ShallowCopy creates a shallow copy of this basis extender in which the read-only data-structures are @@ -152,71 +158,56 @@ func (basisextender *FastBasisExtender) ShallowCopy() *FastBasisExtender { return nil } return &FastBasisExtender{ - ringQ: basisextender.ringQ, - ringP: basisextender.ringP, - paramsQP: basisextender.paramsQP, - paramsPQ: basisextender.paramsPQ, - modDownParamsQP: basisextender.modDownParamsQP, - modDownParamsPQ: basisextender.modDownParamsPQ, + ringQ: basisextender.ringQ, + ringP: basisextender.ringP, + paramsQtoP: basisextender.paramsQtoP, + paramsPtoQ: basisextender.paramsPtoQ, + modDownparamsQtoP: basisextender.modDownparamsQtoP, + modDownparamsPtoQ: basisextender.modDownparamsPtoQ, polypoolQ: basisextender.ringQ.NewPoly(), polypoolP: basisextender.ringP.NewPoly(), } } -// ModUpSplitQP extends the RNS basis of a polynomial from Q to QP. +// ModUpQtoP extends the RNS basis of a polynomial from Q to QP. // Given a polynomial with coefficients in basis {Q0,Q1....Qlevel}, // it extends its basis from {Q0,Q1....Qlevel} to {Q0,Q1....Qlevel,P0,P1...Pj} -func (basisextender *FastBasisExtender) ModUpSplitQP(level int, p1, p2 *Poly) { - modUpExact(p1.Coeffs[:level+1], p2.Coeffs[:len(basisextender.paramsQP.P)], basisextender.paramsQP) +func (basisextender *FastBasisExtender) ModUpQtoP(levelQ, levelP int, polQ, polP *Poly) { + modUpExact(polQ.Coeffs[:levelQ+1], polP.Coeffs[:levelP+1], basisextender.ringQ, basisextender.ringP, basisextender.paramsQtoP[levelQ]) } -// ModUpSplitPQ extends the RNS basis of a polynomial from P to PQ. +// ModUpPtoQ extends the RNS basis of a polynomial from P to PQ. // Given a polynomial with coefficients in basis {P0,P1....Plevel}, // it extends its basis from {P0,P1....Plevel} to {Q0,Q1...Qj} -func (basisextender *FastBasisExtender) ModUpSplitPQ(level int, p1, p2 *Poly) { - modUpExact(p1.Coeffs[:level+1], p2.Coeffs[:len(basisextender.paramsPQ.P)], basisextender.paramsPQ) +func (basisextender *FastBasisExtender) ModUpPtoQ(levelP, levelQ int, polP, polQ *Poly) { + modUpExact(polP.Coeffs[:levelP+1], polQ.Coeffs[:levelQ+1], basisextender.ringP, basisextender.ringQ, basisextender.paramsPtoQ[levelP]) } -// ModDownNTTPQ reduces the basis RNS of a polynomial in the NTT domain -// from QP to Q and divides its coefficients by P. -// Given a polynomial with coefficients in basis {Q0,Q1....Qlevel,P0,P1...Pj}, -// it reduces its basis from {Q0,Q1....Qlevel,P0,P1...Pj} to {Q0,Q1....Qlevel} -// and performs a rounded integer division of the result by P. -// Inputs must be in the NTT domain. -func (basisextender *FastBasisExtender) ModDownNTTPQ(level int, p1, p2 *Poly) { +// ModDownQPtoQ reduces the basis of a polynomial. +// Given a polynomial with coefficients in basis {Q0,Q1....Qlevel} and {P0,P1...Pj}, +// it reduces its basis from {Q0,Q1....Qlevel} and {P0,P1...Pj} to {Q0,Q1....Qlevel} +// and does a rounded integer division of the result by P. +func (basisextender *FastBasisExtender) ModDownQPtoQ(levelQ, levelP int, p1Q, p1P, p2Q *Poly) { ringQ := basisextender.ringQ - ringP := basisextender.ringP - modDownParams := basisextender.modDownParamsPQ + modDownParams := basisextender.modDownparamsPtoQ polypool := basisextender.polypoolQ - nQi := len(ringQ.Modulus) - nPj := len(ringP.Modulus) - - // First we get the P basis part of p1 out of the NTT domain - for j := 0; j < nPj; j++ { - InvNTTLazy(p1.Coeffs[nQi+j], p1.Coeffs[nQi+j], ringP.N, ringP.NttPsiInv[j], ringP.NttNInv[j], ringP.Modulus[j], ringP.MredParams[j]) - } // Then we target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool // polypool is now the representation of the P basis of p1 but in basis Q (at the "level" of p1) - modUpExact(p1.Coeffs[nQi:nQi+nPj], polypool.Coeffs[:level+1], basisextender.paramsPQ) + basisextender.ModUpPtoQ(levelP, levelQ, p1P, polypool) // Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q - for i := 0; i < level+1; i++ { + for i := 0; i < levelQ+1; i++ { qi := ringQ.Modulus[i] twoqi := qi << 1 - p1tmp := p1.Coeffs[i] - p2tmp := p2.Coeffs[i] + p1tmp := p1Q.Coeffs[i] + p2tmp := p2Q.Coeffs[i] p3tmp := polypool.Coeffs[i] - params := qi - modDownParams[i] + params := qi - modDownParams[levelP][i] mredParams := ringQ.MredParams[i] - bredParams := ringQ.BredParams[i] - nttPsi := ringQ.NttPsi[i] - - // First we switch back the relevant polypool CRT array back to the NTT domain - NTTLazy(p3tmp, p3tmp, ringQ.N, nttPsi, qi, mredParams, bredParams) // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi for j := 0; j < ringQ.N; j = j + 8 { @@ -239,34 +230,34 @@ func (basisextender *FastBasisExtender) ModDownNTTPQ(level int, p1, p2 *Poly) { // In total we do len(P) + len(Q) NTT, which is optimal (linear in the number of moduli of P and Q) } -// ModDownSplitNTTPQ reduces the basis of a polynomial. +// ModDownQPtoQNTT reduces the basis of a polynomial. // Given a polynomial with coefficients in basis {Q0,Q1....Qi} and {P0,P1...Pj}, // it reduces its basis from {Q0,Q1....Qi} and {P0,P1...Pj} to {Q0,Q1....Qi} // and does a rounded integer division of the result by P. // Inputs must be in the NTT domain. -func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level int, p1Q, p1P, p2 *Poly) { +func (basisextender *FastBasisExtender) ModDownQPtoQNTT(levelQ, levelP int, p1Q, p1P, p2Q *Poly) { ringQ := basisextender.ringQ ringP := basisextender.ringP - modDownParams := basisextender.modDownParamsPQ + modDownParams := basisextender.modDownparamsPtoQ polypool := basisextender.polypoolQ // First we get the P basis part of p1 out of the NTT domain - ringP.InvNTTLazy(p1P, p1P) + ringP.InvNTTLazyLvl(levelP, p1P, p1P) // Then we target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool // polypool is now the representation of the P basis of p1 but in basis Q (at the "level" of p1) - modUpExact(p1P.Coeffs, polypool.Coeffs[:level+1], basisextender.paramsPQ) + basisextender.ModUpPtoQ(levelP, levelQ, p1P, polypool) // Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q - for i := 0; i < level+1; i++ { + for i := 0; i < levelQ+1; i++ { qi := ringQ.Modulus[i] twoqi := qi << 1 p1tmp := p1Q.Coeffs[i] - p2tmp := p2.Coeffs[i] + p2tmp := p2Q.Coeffs[i] p3tmp := polypool.Coeffs[i] - params := qi - modDownParams[i] + params := qi - modDownParams[levelP][i] mredParams := ringQ.MredParams[i] bredParams := ringQ.BredParams[i] nttPsi := ringQ.NttPsi[i] @@ -295,112 +286,19 @@ func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level int, p1Q, p1P, p // In total we do len(P) + len(Q) NTT, which is optimal (linear in the number of moduli of P and Q) } -// ModDownPQ reduces the basis of a polynomial. -// Given a polynomial with coefficients in basis {Q0,Q1....Qlevel,P0,P1...Pj}, -// it reduces its basis from {Q0,Q1....Qlevel,P0,P1...Pj} to {Q0,Q1....Qlevel} -// and does a rounded integer division of the result by P. -func (basisextender *FastBasisExtender) ModDownPQ(level int, p1, p2 *Poly) { - - ringQ := basisextender.ringQ - modDownParams := basisextender.modDownParamsPQ - polypool := basisextender.polypoolQ - nPi := len(basisextender.paramsQP.P) - - // We target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool - // polypool is now the representation of the P basis of p1 but in basis Q (at the "level" of p1) - modUpExact(p1.Coeffs[level+1:level+1+nPi], polypool.Coeffs[:level+1], basisextender.paramsPQ) - - // Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q - for i := 0; i < level+1; i++ { - - qi := ringQ.Modulus[i] - twoqi := qi << 1 - p1tmp := p1.Coeffs[i] - p2tmp := p2.Coeffs[i] - p3tmp := polypool.Coeffs[i] - params := qi - modDownParams[i] - mredParams := ringQ.MredParams[i] - - // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi - for j := 0; j < ringQ.N; j = j + 8 { - - x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) - y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) - z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) - - z[0] = MRed(y[0]+twoqi-x[0], params, qi, mredParams) - z[1] = MRed(y[1]+twoqi-x[1], params, qi, mredParams) - z[2] = MRed(y[2]+twoqi-x[2], params, qi, mredParams) - z[3] = MRed(y[3]+twoqi-x[3], params, qi, mredParams) - z[4] = MRed(y[4]+twoqi-x[4], params, qi, mredParams) - z[5] = MRed(y[5]+twoqi-x[5], params, qi, mredParams) - z[6] = MRed(y[6]+twoqi-x[6], params, qi, mredParams) - z[7] = MRed(y[7]+twoqi-x[7], params, qi, mredParams) - } - } - - // In total we do len(P) + len(Q) NTT, which is optimal (linear in the number of moduli of P and Q) -} - -// ModDownSplitPQ reduces the basis of a polynomial. -// Given a polynomial with coefficients in basis {Q0,Q1....Qlevel} and {P0,P1...Pj}, -// it reduces its basis from {Q0,Q1....Qlevel} and {P0,P1...Pj} to {Q0,Q1....Qlevel} -// and does a rounded integer division of the result by P. -func (basisextender *FastBasisExtender) ModDownSplitPQ(level int, p1Q, p1P, p2 *Poly) { - - ringQ := basisextender.ringQ - modDownParams := basisextender.modDownParamsPQ - polypool := basisextender.polypoolQ - - // Then we target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool - // polypool is now the representation of the P basis of p1 but in basis Q (at the "level" of p1) - modUpExact(p1P.Coeffs, polypool.Coeffs[:level+1], basisextender.paramsPQ) - - // Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q - for i := 0; i < level+1; i++ { - - qi := ringQ.Modulus[i] - twoqi := qi << 1 - p1tmp := p1Q.Coeffs[i] - p2tmp := p2.Coeffs[i] - p3tmp := polypool.Coeffs[i] - params := qi - modDownParams[i] - mredParams := ringQ.MredParams[i] - - // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi - for j := 0; j < ringQ.N; j = j + 8 { - - x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) - y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) - z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) - - z[0] = MRed(y[0]+twoqi-x[0], params, qi, mredParams) - z[1] = MRed(y[1]+twoqi-x[1], params, qi, mredParams) - z[2] = MRed(y[2]+twoqi-x[2], params, qi, mredParams) - z[3] = MRed(y[3]+twoqi-x[3], params, qi, mredParams) - z[4] = MRed(y[4]+twoqi-x[4], params, qi, mredParams) - z[5] = MRed(y[5]+twoqi-x[5], params, qi, mredParams) - z[6] = MRed(y[6]+twoqi-x[6], params, qi, mredParams) - z[7] = MRed(y[7]+twoqi-x[7], params, qi, mredParams) - } - } - - // In total we do len(P) + len(Q) NTT, which is optimal (linear in the number of moduli of P and Q) -} - -// ModDownSplitQP reduces the basis of a polynomial. +// ModDownQPtoP reduces the basis of a polynomial. // Given a polynomial with coefficients in basis {Q0,Q1....QlevelQ} and {P0,P1...PlevelP}, // it reduces its basis from {Q0,Q1....QlevelQ} and {P0,P1...PlevelP} to {P0,P1...PlevelP} // and does a floored integer division of the result by Q. -func (basisextender *FastBasisExtender) ModDownSplitQP(levelQ, levelP int, p1Q, p1P, p2 *Poly) { +func (basisextender *FastBasisExtender) ModDownQPtoP(levelQ, levelP int, p1Q, p1P, p2P *Poly) { ringP := basisextender.ringP - modDownParams := basisextender.modDownParamsQP + modDownParams := basisextender.modDownparamsQtoP polypool := basisextender.polypoolP // Then we target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool // polypool is now the representation of the P basis of p1 but in basis Q (at the "level" of p1) - basisextender.ModUpSplitQP(levelQ, p1Q, polypool) + basisextender.ModUpQtoP(levelQ, levelP, p1Q, polypool) // Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q for i := 0; i < levelP+1; i++ { @@ -408,9 +306,9 @@ func (basisextender *FastBasisExtender) ModDownSplitQP(levelQ, levelP int, p1Q, qi := ringP.Modulus[i] twoqi := qi << 1 p1tmp := p1P.Coeffs[i] - p2tmp := p2.Coeffs[i] + p2tmp := p2P.Coeffs[i] p3tmp := polypool.Coeffs[i] - params := qi - modDownParams[i] + params := qi - modDownParams[levelP][i] mredParams := ringP.MredParams[i] // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi @@ -435,26 +333,24 @@ func (basisextender *FastBasisExtender) ModDownSplitQP(levelQ, levelP int, p1Q, } // Caution, returns the values in [0, 2q-1] -func modUpExact(p1, p2 [][]uint64, params *modupParams) { +func modUpExact(p1, p2 [][]uint64, ringQ, ringP *Ring, params modupParams) { var v [8]uint64 var y0, y1, y2, y3, y4, y5, y6, y7 [32]uint64 + Q := ringQ.Modulus + P := ringP.Modulus + mredParamsQ := ringQ.MredParams + mredParamsP := ringP.MredParams + vtimesqmodp := params.vtimesqmodp + qoverqiinvqi := params.qoverqiinvqi + qoverqimodp := params.qoverqimodp + // We loop over each coefficient and apply the basis extension for x := 0; x < len(p1[0]); x = x + 8 { - - reconstructRNS(len(p1), x, p1, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, params.Q, params.mredParamsQ, params.qibMont) - + reconstructRNS(len(p1), x, p1, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, Q, mredParamsQ, qoverqiinvqi) for j := 0; j < len(p2); j++ { - - pj := params.P[j] - qInv := params.mredParamsP[j] - qpjInv := params.qpjInv[j] - qispjMont := params.qispjMont[j] - - res := (*[8]uint64)(unsafe.Pointer(&p2[j][x])) - - multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, len(p1), pj, qInv, qpjInv, qispjMont) + multSum((*[8]uint64)(unsafe.Pointer(&p2[j][x])), &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, len(p1), P[j], mredParamsP[j], vtimesqmodp[j], qoverqimodp[j]) } } } @@ -463,75 +359,63 @@ func modUpExact(p1, p2 [][]uint64, params *modupParams) { // This decomposer takes a p(x)_Q (in basis Q) and returns p(x) mod qi in basis QP, where // qi = prod(Q_i) for 0<=i<=L, where L is the number of factors in P. type Decomposer struct { - nQprimes int - nPprimes int - alpha int - beta int - xalpha []int - modUpParams [][]*modupParams - QInt *big.Int - PInt *big.Int -} - -// Xalpha returns a slice that contains all the values of #Qi/#Pi. -func (decomposer *Decomposer) Xalpha() (xalpha []int) { - return decomposer.xalpha + ringQ, ringP *Ring + modUpParams [][][]modupParams } // NewDecomposer creates a new Decomposer. -func NewDecomposer(Q, P []uint64) (decomposer *Decomposer) { +func NewDecomposer(ringQ, ringP *Ring) (decomposer *Decomposer) { decomposer = new(Decomposer) - decomposer.nQprimes = len(Q) - decomposer.nPprimes = len(P) + decomposer.ringQ = ringQ + decomposer.ringP = ringP - decomposer.QInt = NewUint(1) - for i := range Q { - decomposer.QInt.Mul(decomposer.QInt, NewUint(Q[i])) - } + Q := ringQ.Modulus - decomposer.PInt = NewUint(1) - for i := range P { - decomposer.PInt.Mul(decomposer.PInt, NewUint(P[i])) - } + decomposer.modUpParams = make([][][]modupParams, len(ringP.Modulus)-1) - decomposer.alpha = len(P) - decomposer.beta = int(math.Ceil(float64(len(Q)) / float64(decomposer.alpha))) + for lvlP := range ringP.Modulus[1:] { - decomposer.xalpha = make([]int, decomposer.beta) - for i := range decomposer.xalpha { - decomposer.xalpha[i] = decomposer.alpha - } + P := ringP.Modulus[:lvlP+2] - if len(Q)%decomposer.alpha != 0 { - decomposer.xalpha[decomposer.beta-1] = len(Q) % decomposer.alpha - } + alpha := len(P) + beta := int(math.Ceil(float64(len(Q)) / float64(alpha))) - decomposer.modUpParams = make([][]*modupParams, decomposer.beta) + xalpha := make([]int, beta) + for i := range xalpha { + xalpha[i] = alpha + } - // Create a basis extension for each possible combination of [Qi,Pj] according to xalpha - for i := 0; i < decomposer.beta; i++ { + if len(Q)%alpha != 0 { + xalpha[beta-1] = len(Q) % alpha + } - decomposer.modUpParams[i] = make([]*modupParams, decomposer.xalpha[i]-1) + decomposer.modUpParams[lvlP] = make([][]modupParams, beta) - for j := 0; j < decomposer.xalpha[i]-1; j++ { + // Create modUpParams for each possible combination of [Qi,Pj] according to xalpha + for i := 0; i < beta; i++ { - Qi := make([]uint64, j+2) - Pi := make([]uint64, len(Q)+len(P)) + decomposer.modUpParams[lvlP][i] = make([]modupParams, xalpha[i]-1) - for k := 0; k < j+2; k++ { - Qi[k] = Q[i*decomposer.alpha+k] + for j := 0; j < xalpha[i]-1; j++ { + + Qi := make([]uint64, j+2) + Pi := make([]uint64, len(Q)+len(P)) + + for k := 0; k < j+2; k++ { + Qi[k] = Q[i*alpha+k] + } + + for k := 0; k < len(Q); k++ { + Pi[k] = Q[k] + } + + for k := len(Q); k < len(Q)+len(P); k++ { + Pi[k] = P[k-len(Q)] + } + + decomposer.modUpParams[lvlP][i][j] = basisextenderparameters(Qi, Pi) } - - for k := 0; k < len(Q); k++ { - Pi[k] = Q[k] - } - - for k := len(Q); k < len(Q)+len(P); k++ { - Pi[k] = P[k-len(Q)] - } - - decomposer.modUpParams[i][j] = basisextenderparameters(Qi, Pi) } } @@ -540,69 +424,75 @@ func NewDecomposer(Q, P []uint64) (decomposer *Decomposer) { // DecomposeAndSplit decomposes a polynomial p(x) in basis Q, reduces it modulo qi, and returns // the result in basis QP separately. -func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel int, p0, p1Q, p1P *Poly) { +func (decomposer *Decomposer) DecomposeAndSplit(levelQ, levelP, alpha, beta int, p0Q, p1Q, p1P *Poly) { - alphai := decomposer.xalpha[crtDecompLevel] + ringQ := decomposer.ringQ + ringP := decomposer.ringP - p0idxst := crtDecompLevel * decomposer.alpha - p0idxed := p0idxst + alphai + lvlQStart := beta * alpha + + var decompLvl int + if levelQ > alpha*(beta+1)-1 { + decompLvl = alpha - 2 + } else { + decompLvl = (levelQ % alpha) - 1 + } // First we check if the vector can simply by coping and rearranging elements (the case where no reconstruction is needed) - if (p0idxed > level+1 && (level+1)%decomposer.nPprimes == 1) || alphai == 1 { + if decompLvl == -1 { - for j := 0; j < level+1; j++ { - copy(p1Q.Coeffs[j], p0.Coeffs[p0idxst]) + for j := 0; j < levelQ+1; j++ { + copy(p1Q.Coeffs[j], p0Q.Coeffs[lvlQStart]) } - for j := 0; j < decomposer.nPprimes; j++ { - copy(p1P.Coeffs[j], p0.Coeffs[p0idxst]) + for j := 0; j < levelP+1; j++ { + copy(p1P.Coeffs[j], p0Q.Coeffs[lvlQStart]) } // Otherwise, we apply a fast exact base conversion for the reconstruction } else { - var index int - if level >= alphai+crtDecompLevel*decomposer.alpha { - index = decomposer.xalpha[crtDecompLevel] - 2 - } else { - index = (level - 1) % decomposer.alpha - } - - params := decomposer.modUpParams[crtDecompLevel][index] + params := decomposer.modUpParams[alpha-2][beta][decompLvl] var v [8]uint64 var vi [8]float64 var y0, y1, y2, y3, y4, y5, y6, y7 [32]uint64 - var qibMont, qi, pj, mredParams uint64 - var qif float64 + + Q := ringQ.Modulus + P := ringP.Modulus + mredParamsQ := ringQ.MredParams + mredParamsP := ringP.MredParams + qoverqiinvqi := params.qoverqiinvqi + vtimesqmodp := params.vtimesqmodp + qoverqimodp := params.qoverqimodp // We loop over each coefficient and apply the basis extension - for x := 0; x < len(p0.Coeffs[0]); x = x + 8 { + for x := 0; x < len(p0Q.Coeffs[0]); x = x + 8 { vi[0], vi[1], vi[2], vi[3], vi[4], vi[5], vi[6], vi[7] = 0, 0, 0, 0, 0, 0, 0, 0 // Coefficients to be decomposed - for i, j := 0, p0idxst; i < index+2; i, j = i+1, j+1 { + for i, j := 0, lvlQStart; i < decompLvl+2; i, j = i+1, j+1 { - qibMont = params.qibMont[i] - qi = params.Q[i] - mredParams = params.mredParamsQ[i] - qif = float64(qi) + qqiinv := qoverqiinvqi[i] + qi := Q[j] + mredParams := mredParamsQ[j] + qif := float64(qi) - px := (*[8]uint64)(unsafe.Pointer(&p0.Coeffs[j][x])) + px := (*[8]uint64)(unsafe.Pointer(&p0Q.Coeffs[j][x])) py := (*[8]uint64)(unsafe.Pointer(&p1Q.Coeffs[j][x])) // For the coefficients to be decomposed, we can simply copy them py[0], py[1], py[2], py[3], py[4], py[5], py[6], py[7] = px[0], px[1], px[2], px[3], px[4], px[5], px[6], px[7] - y0[i] = MRed(px[0], qibMont, qi, mredParams) - y1[i] = MRed(px[1], qibMont, qi, mredParams) - y2[i] = MRed(px[2], qibMont, qi, mredParams) - y3[i] = MRed(px[3], qibMont, qi, mredParams) - y4[i] = MRed(px[4], qibMont, qi, mredParams) - y5[i] = MRed(px[5], qibMont, qi, mredParams) - y6[i] = MRed(px[6], qibMont, qi, mredParams) - y7[i] = MRed(px[7], qibMont, qi, mredParams) + y0[i] = MRed(px[0], qqiinv, qi, mredParams) + y1[i] = MRed(px[1], qqiinv, qi, mredParams) + y2[i] = MRed(px[2], qqiinv, qi, mredParams) + y3[i] = MRed(px[3], qqiinv, qi, mredParams) + y4[i] = MRed(px[4], qqiinv, qi, mredParams) + y5[i] = MRed(px[5], qqiinv, qi, mredParams) + y6[i] = MRed(px[6], qqiinv, qi, mredParams) + y7[i] = MRed(px[7], qqiinv, qi, mredParams) // Computation of the correction term v * Q%pi vi[0] += float64(y0[i]) / qif @@ -626,42 +516,18 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel int, p0, p v[7] = uint64(vi[7]) // Coefficients of index smaller than the ones to be decomposed - for j := 0; j < p0idxst; j++ { - - pj = params.P[j] - qInv := params.mredParamsP[j] - qpjInv := params.qpjInv[j] - qispjMont := params.qispjMont[j] - - res := (*[8]uint64)(unsafe.Pointer(&p1Q.Coeffs[j][x])) - - multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, index+2, pj, qInv, qpjInv, qispjMont) + for j := 0; j < lvlQStart; j++ { + multSum((*[8]uint64)(unsafe.Pointer(&p1Q.Coeffs[j][x])), &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, decompLvl+2, Q[j], mredParamsQ[j], vtimesqmodp[j], qoverqimodp[j]) } // Coefficients of index greater than the ones to be decomposed - for j := decomposer.alpha * crtDecompLevel; j < level+1; j = j + 1 { - - pj = params.P[j] - qInv := params.mredParamsP[j] - qpjInv := params.qpjInv[j] - qispjMont := params.qispjMont[j] - - res := (*[8]uint64)(unsafe.Pointer(&p1Q.Coeffs[j][x])) - - multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, index+2, pj, qInv, qpjInv, qispjMont) + for j := alpha * beta; j < levelQ+1; j = j + 1 { + multSum((*[8]uint64)(unsafe.Pointer(&p1Q.Coeffs[j][x])), &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, decompLvl+2, Q[j], mredParamsQ[j], vtimesqmodp[j], qoverqimodp[j]) } // Coefficients of the special primes Pi - for j, u := 0, decomposer.nQprimes; j < decomposer.nPprimes; j, u = j+1, u+1 { - - pj = params.P[u] - qInv := params.mredParamsP[u] - qpjInv := params.qpjInv[u] - qispjMont := params.qispjMont[u] - - res := (*[8]uint64)(unsafe.Pointer(&p1P.Coeffs[j][x])) - - multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, index+2, pj, qInv, qpjInv, qispjMont) + for j, u := 0, len(ringQ.Modulus); j < levelP+1; j, u = j+1, u+1 { + multSum((*[8]uint64)(unsafe.Pointer(&p1P.Coeffs[j][x])), &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, decompLvl+2, P[j], mredParamsP[j], vtimesqmodp[u], qoverqimodp[u]) } } } @@ -670,25 +536,25 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel int, p0, p func reconstructRNS(index, x int, p [][]uint64, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, Q, QInv, QbMont []uint64) { var vi [8]float64 - var qi, qiInv, qibMont uint64 + var qi, qiInv, qoverqiinvqi uint64 var qif float64 for i := 0; i < index; i++ { - qibMont = QbMont[i] + qoverqiinvqi = QbMont[i] qi = Q[i] qiInv = QInv[i] qif = float64(qi) pTmp := (*[8]uint64)(unsafe.Pointer(&p[i][x])) - y0[i] = MRed(pTmp[0], qibMont, qi, qiInv) - y1[i] = MRed(pTmp[1], qibMont, qi, qiInv) - y2[i] = MRed(pTmp[2], qibMont, qi, qiInv) - y3[i] = MRed(pTmp[3], qibMont, qi, qiInv) - y4[i] = MRed(pTmp[4], qibMont, qi, qiInv) - y5[i] = MRed(pTmp[5], qibMont, qi, qiInv) - y6[i] = MRed(pTmp[6], qibMont, qi, qiInv) - y7[i] = MRed(pTmp[7], qibMont, qi, qiInv) + y0[i] = MRed(pTmp[0], qoverqiinvqi, qi, qiInv) + y1[i] = MRed(pTmp[1], qoverqiinvqi, qi, qiInv) + y2[i] = MRed(pTmp[2], qoverqiinvqi, qi, qiInv) + y3[i] = MRed(pTmp[3], qoverqiinvqi, qi, qiInv) + y4[i] = MRed(pTmp[4], qoverqiinvqi, qi, qiInv) + y5[i] = MRed(pTmp[5], qoverqiinvqi, qi, qiInv) + y6[i] = MRed(pTmp[6], qoverqiinvqi, qi, qiInv) + y7[i] = MRed(pTmp[7], qoverqiinvqi, qi, qiInv) // Computation of the correction term v * Q%pi vi[0] += float64(y0[i]) / qif @@ -712,68 +578,68 @@ func reconstructRNS(index, x int, p [][]uint64, v *[8]uint64, y0, y1, y2, y3, y4 } // Caution, returns the values in [0, 2q-1] -func multSum(res, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, index int, pj, qInv uint64, qpjInv, qispjMont []uint64) { +func multSum(res, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, alpha int, pj, qInv uint64, vtimesqmodp, qoverqimodp []uint64) { var rlo, rhi [8]uint64 var mhi, mlo, c, hhi uint64 // Accumulates the sum on uint128 and does a lazy montgomery reduction at the end - for i := 0; i < index; i++ { + for i := 0; i < alpha; i++ { - mhi, mlo = bits.Mul64(y0[i], qispjMont[i]) + mhi, mlo = bits.Mul64(y0[i], qoverqimodp[i]) rlo[0], c = bits.Add64(rlo[0], mlo, 0) rhi[0] += mhi + c - mhi, mlo = bits.Mul64(y1[i], qispjMont[i]) + mhi, mlo = bits.Mul64(y1[i], qoverqimodp[i]) rlo[1], c = bits.Add64(rlo[1], mlo, 0) rhi[1] += mhi + c - mhi, mlo = bits.Mul64(y2[i], qispjMont[i]) + mhi, mlo = bits.Mul64(y2[i], qoverqimodp[i]) rlo[2], c = bits.Add64(rlo[2], mlo, 0) rhi[2] += mhi + c - mhi, mlo = bits.Mul64(y3[i], qispjMont[i]) + mhi, mlo = bits.Mul64(y3[i], qoverqimodp[i]) rlo[3], c = bits.Add64(rlo[3], mlo, 0) rhi[3] += mhi + c - mhi, mlo = bits.Mul64(y4[i], qispjMont[i]) + mhi, mlo = bits.Mul64(y4[i], qoverqimodp[i]) rlo[4], c = bits.Add64(rlo[4], mlo, 0) rhi[4] += mhi + c - mhi, mlo = bits.Mul64(y5[i], qispjMont[i]) + mhi, mlo = bits.Mul64(y5[i], qoverqimodp[i]) rlo[5], c = bits.Add64(rlo[5], mlo, 0) rhi[5] += mhi + c - mhi, mlo = bits.Mul64(y6[i], qispjMont[i]) + mhi, mlo = bits.Mul64(y6[i], qoverqimodp[i]) rlo[6], c = bits.Add64(rlo[6], mlo, 0) rhi[6] += mhi + c - mhi, mlo = bits.Mul64(y7[i], qispjMont[i]) + mhi, mlo = bits.Mul64(y7[i], qoverqimodp[i]) rlo[7], c = bits.Add64(rlo[7], mlo, 0) rhi[7] += mhi + c } hhi, _ = bits.Mul64(rlo[0]*qInv, pj) - res[0] = rhi[0] - hhi + pj + qpjInv[v[0]] + res[0] = rhi[0] - hhi + pj + vtimesqmodp[v[0]] hhi, _ = bits.Mul64(rlo[1]*qInv, pj) - res[1] = rhi[1] - hhi + pj + qpjInv[v[1]] + res[1] = rhi[1] - hhi + pj + vtimesqmodp[v[1]] hhi, _ = bits.Mul64(rlo[2]*qInv, pj) - res[2] = rhi[2] - hhi + pj + qpjInv[v[2]] + res[2] = rhi[2] - hhi + pj + vtimesqmodp[v[2]] hhi, _ = bits.Mul64(rlo[3]*qInv, pj) - res[3] = rhi[3] - hhi + pj + qpjInv[v[3]] + res[3] = rhi[3] - hhi + pj + vtimesqmodp[v[3]] hhi, _ = bits.Mul64(rlo[4]*qInv, pj) - res[4] = rhi[4] - hhi + pj + qpjInv[v[4]] + res[4] = rhi[4] - hhi + pj + vtimesqmodp[v[4]] hhi, _ = bits.Mul64(rlo[5]*qInv, pj) - res[5] = rhi[5] - hhi + pj + qpjInv[v[5]] + res[5] = rhi[5] - hhi + pj + vtimesqmodp[v[5]] hhi, _ = bits.Mul64(rlo[6]*qInv, pj) - res[6] = rhi[6] - hhi + pj + qpjInv[v[6]] + res[6] = rhi[6] - hhi + pj + vtimesqmodp[v[6]] hhi, _ = bits.Mul64(rlo[7]*qInv, pj) - res[7] = rhi[7] - hhi + pj + qpjInv[v[7]] + res[7] = rhi[7] - hhi + pj + vtimesqmodp[v[7]] } diff --git a/ring/ring_benchmark_test.go b/ring/ring_benchmark_test.go index a917f43f..4ac7e4ae 100644 --- a/ring/ring_benchmark_test.go +++ b/ring/ring_benchmark_test.go @@ -279,23 +279,24 @@ func benchExtendBasis(testContext *testParams, b *testing.B) { p0 := testContext.uniformSamplerQ.ReadNew() p1 := testContext.uniformSamplerP.ReadNew() - level := len(testContext.ringQ.Modulus) - 1 + levelQ := len(testContext.ringQ.Modulus) - 1 + levelP := len(testContext.ringP.Modulus) - 1 b.Run(fmt.Sprintf("ExtendBasis/ModUp/N=%d/limbsQ=%d/limbsP=%d", testContext.ringQ.N, len(testContext.ringQ.Modulus), len(testContext.ringP.Modulus)), func(b *testing.B) { for i := 0; i < b.N; i++ { - basisExtender.ModUpSplitQP(level, p0, p1) + basisExtender.ModUpQtoP(levelQ, levelP, p0, p1) } }) b.Run(fmt.Sprintf("ExtendBasis/ModDown/N=%d/limbsQ=%d/limbsP=%d", testContext.ringQ.N, len(testContext.ringQ.Modulus), len(testContext.ringP.Modulus)), func(b *testing.B) { for i := 0; i < b.N; i++ { - basisExtender.ModDownSplitPQ(level, p0, p1, p0) + basisExtender.ModDownQPtoQ(levelQ, levelP, p0, p1, p0) } }) b.Run(fmt.Sprintf("ExtendBasis/ModDownNTT/N=%d/limbsQ=%d/limbsP=%d", testContext.ringQ.N, len(testContext.ringQ.Modulus), len(testContext.ringP.Modulus)), func(b *testing.B) { for i := 0; i < b.N; i++ { - basisExtender.ModDownSplitNTTPQ(level, p0, p1, p0) + basisExtender.ModDownQPtoQNTT(levelQ, levelP, p0, p1, p0) } }) } @@ -305,27 +306,29 @@ func benchDivByLastModulus(testContext *testParams, b *testing.B) { p0 := testContext.uniformSamplerQ.ReadNew() p1 := testContext.ringQ.NewPolyLvl(p0.Level() - 1) + pool := testContext.ringQ.NewPoly() + b.Run(testString("DivByLastModulus/Floor/", testContext.ringQ), func(b *testing.B) { for i := 0; i < b.N; i++ { - testContext.ringQ.DivFloorByLastModulus(p0, p1) + testContext.ringQ.DivFloorByLastModulusLvl(p0.Level(), p0, p1) } }) b.Run(testString("DivByLastModulus/FloorNTT/", testContext.ringQ), func(b *testing.B) { for i := 0; i < b.N; i++ { - testContext.ringQ.DivFloorByLastModulusNTT(p0, p1) + testContext.ringQ.DivFloorByLastModulusNTTLvl(p0.Level(), p0, pool, p1) } }) b.Run(testString("DivByLastModulus/Round/", testContext.ringQ), func(b *testing.B) { for i := 0; i < b.N; i++ { - testContext.ringQ.DivRoundByLastModulus(p0, p1) + testContext.ringQ.DivRoundByLastModulusLvl(p0.Level(), p0, p1) } }) b.Run(testString("DivByLastModulus/RoundNTT/", testContext.ringQ), func(b *testing.B) { for i := 0; i < b.N; i++ { - testContext.ringQ.DivRoundByLastModulusNTT(p0, p1) + testContext.ringQ.DivRoundByLastModulusNTTLvl(p0.Level(), p0, pool, p1) } }) } @@ -377,7 +380,9 @@ func benchDivByRNSBasis(testContext *testParams, b *testing.B) { coeffs[i] = RandInt(testContext.ringQ.ModulusBigint) } - scaler := NewRNSScaler(T, testContext.ringQ) + ringT, _ := NewRing(testContext.ringQ.N, []uint64{T}) + + scaler := NewRNSScaler(testContext.ringQ, ringT) polyQ := testContext.ringQ.NewPoly() polyT := NewPoly(testContext.ringQ.N, 1) testContext.ringQ.SetCoefficientsBigint(coeffs, polyQ) diff --git a/ring/ring_operations.go b/ring/ring_operations.go index 69b1073d..97aaa238 100644 --- a/ring/ring_operations.go +++ b/ring/ring_operations.go @@ -707,6 +707,24 @@ func (r *Ring) Shift(p1 *Poly, n int, p2 *Poly) { } } +// MFormVec switches the input vector to the Montgomery domain. +func MFormVec(p0, p1 []uint64, qi uint64, bredParams []uint64) { + + for j := 0; j < len(p0); j = j + 8 { + x := (*[8]uint64)(unsafe.Pointer(&p0[j])) + z := (*[8]uint64)(unsafe.Pointer(&p1[j])) + + z[0] = MForm(x[0], qi, bredParams) + z[1] = MForm(x[1], qi, bredParams) + z[2] = MForm(x[2], qi, bredParams) + z[3] = MForm(x[3], qi, bredParams) + z[4] = MForm(x[4], qi, bredParams) + z[5] = MForm(x[5], qi, bredParams) + z[6] = MForm(x[6], qi, bredParams) + z[7] = MForm(x[7], qi, bredParams) + } +} + // MForm switches p1 to the Montgomery domain and writes the result on p2. func (r *Ring) MForm(p1, p2 *Poly) { r.MFormLvl(r.minLevelBinary(p1, p2), p1, p2) @@ -718,20 +736,7 @@ func (r *Ring) MFormLvl(level int, p1, p2 *Poly) { qi := r.Modulus[i] bredParams := r.BredParams[i] p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i] - for j := 0; j < r.N; j = j + 8 { - - x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) - z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) - - z[0] = MForm(x[0], qi, bredParams) - z[1] = MForm(x[1], qi, bredParams) - z[2] = MForm(x[2], qi, bredParams) - z[3] = MForm(x[3], qi, bredParams) - z[4] = MForm(x[4], qi, bredParams) - z[5] = MForm(x[5], qi, bredParams) - z[6] = MForm(x[6], qi, bredParams) - z[7] = MForm(x[7], qi, bredParams) - } + MFormVec(p1tmp, p2tmp, qi, bredParams) } } @@ -907,6 +912,21 @@ func (r *Ring) MulByVectorMontgomeryAndAddNoMod(p1 *Poly, vector []uint64, p2 *P } } +// MapSmallDimensionToLargerDimensionNTT maps Y = X^{N/n} -> X directly in the NTT domain +func MapSmallDimensionToLargerDimensionNTT(polSmall, polLarge *Poly) { + gap := len(polLarge.Coeffs[0]) / len(polSmall.Coeffs[0]) + for j := range polSmall.Coeffs { + tmp0 := polSmall.Coeffs[j] + tmp1 := polLarge.Coeffs[j] + for i := range polSmall.Coeffs[0] { + coeff := tmp0[i] + for w := 0; w < gap; w++ { + tmp1[i*gap+w] = coeff + } + } + } +} + // BitReverse applies a bit reverse permutation on the coefficients of p1 and writes the result on p2. // In can safely be used for in-place permutation. func (r *Ring) BitReverse(p1, p2 *Poly) { @@ -947,7 +967,7 @@ func (r *Ring) Rotate(p1 *Poly, n int, p2 *Poly) { root = MRed(r.PsiMont[i], r.PsiMont[i], qi, mredParams) - root = modexpMontgomery(root, n, qi, mredParams, r.BredParams[i]) + root = ModexpMontgomery(root, n, qi, mredParams, r.BredParams[i]) gal = MForm(1, qi, r.BredParams[i]) diff --git a/ring/ring_sampler.go b/ring/ring_sampler.go index bf13d750..391a8f79 100644 --- a/ring/ring_sampler.go +++ b/ring/ring_sampler.go @@ -10,3 +10,10 @@ type baseSampler struct { prng utils.PRNG baseRing *Ring } + +// Sampler is an interface for random polynomial samplers. +// It has a single Read method which takes as argument the polynomial to be +// populated according to the Sampler's distribution. +type Sampler interface { + Read(pOut *Poly) +} diff --git a/ring/ring_sampler_uniform.go b/ring/ring_sampler_uniform.go index 0fd6de66..4e33005a 100644 --- a/ring/ring_sampler_uniform.go +++ b/ring/ring_sampler_uniform.go @@ -29,7 +29,7 @@ func (uniformSampler *UniformSampler) Read(Pol *Poly) { uniformSampler.prng.Clock(uniformSampler.randomBufferN) - for j := range uniformSampler.baseRing.Modulus { + for j := range uniformSampler.baseRing.Modulus[:len(Pol.Coeffs)] { qi = uniformSampler.baseRing.Modulus[j] diff --git a/ring/ring_scaling.go b/ring/ring_scaling.go index f144506c..83e8d106 100644 --- a/ring/ring_scaling.go +++ b/ring/ring_scaling.go @@ -15,35 +15,33 @@ type Scaler interface { // RNSScaler implements the Scaler interface by performing a scaling by t/Q in the RNS domain. // This implementation of the Scaler interface is preferred over the SimpleScaler implementation. type RNSScaler struct { - ringQ *Ring - polypoolT *Poly + ringQ, ringT *Ring + polypoolQ *Poly + polypoolT *Poly qHalf *big.Int // (q-1)/2 qHalfModT uint64 // (q-1)/2 mod t + qInv uint64 //(q mod t)^-1 mod t - t uint64 - qInv uint64 //(q mod t)^-1 mod t - - mredParamsT uint64 - - paramsQP *modupParams + paramsQP modupParams } // NewRNSScaler creates a new SimpleScaler from t, the modulus under which the reconstruction is returned, the Ring in which the polynomial to reconstruct is represented. -func NewRNSScaler(t uint64, ringQ *Ring) (rnss *RNSScaler) { +func NewRNSScaler(ringQ, ringT *Ring) (rnss *RNSScaler) { rnss = new(RNSScaler) rnss.ringQ = ringQ + rnss.ringT = ringT - rnss.mredParamsT = MRedParams(t) + rnss.polypoolQ = ringQ.NewPoly() + rnss.polypoolT = ringT.NewPoly() - rnss.polypoolT = NewPoly(ringQ.N, 1) + t := ringT.Modulus[0] - rnss.t = t rnss.qHalf = new(big.Int) rnss.qInv = rnss.qHalf.Mod(ringQ.ModulusBigint, NewUint(t)).Uint64() - rnss.qInv = ModExp(rnss.qInv, int(t-2), t) + rnss.qInv = ModExp(rnss.qInv, t-2, t) rnss.qInv = MForm(rnss.qInv, t, BRedParams(t)) rnss.qHalf.Set(ringQ.ModulusBigint) @@ -62,24 +60,25 @@ func NewRNSScaler(t uint64, ringQ *Ring) (rnss *RNSScaler) { func (rnss *RNSScaler) DivByQOverTRounded(p1Q, p2T *Poly) { ringQ := rnss.ringQ + ringT := rnss.ringT - T := rnss.t + T := ringT.Modulus[0] p2tmp := p2T.Coeffs[0] p3tmp := rnss.polypoolT.Coeffs[0] - mredParams := rnss.mredParamsT + mredParams := rnss.ringT.MredParams[0] qInv := T - rnss.qInv qHalfModT := T - rnss.qHalfModT // Multiply P_{Q} by t and extend the basis from P_{Q} to t*(P_{Q}||P_{t}) // Since the coefficients of P_{t} are multiplied by t, they are all zero, // hence the basis extension can be omitted - ringQ.MulScalar(p1Q, T, p1Q) + ringQ.MulScalar(p1Q, T, rnss.polypoolQ) // Center t*P_{Q} around (Q-1)/2 to round instead of floor during the division - ringQ.AddScalarBigint(p1Q, rnss.qHalf, p1Q) + ringQ.AddScalarBigint(rnss.polypoolQ, rnss.qHalf, rnss.polypoolQ) // Extend the basis of (t*P_{Q} + (Q-1)/2) to (t*P_{t} + (Q-1)/2) - modUpExact(p1Q.Coeffs, rnss.polypoolT.Coeffs, rnss.paramsQP) + modUpExact(rnss.polypoolQ.Coeffs, rnss.polypoolT.Coeffs, ringQ, ringT, rnss.paramsQP) // Compute [Q^{-1} * (t*P_{t} - (t*P_{Q} - ((Q-1)/2 mod t)))] mod t which returns round(t/Q * P_{Q}) mod t for j := 0; j < ringQ.N; j = j + 8 { @@ -288,17 +287,12 @@ func (ss *SimpleScaler) reconstructAndScale(p1, p2 *Poly) { // ============== Scaling-related methods ============== -// DivFloorByLastModulusNTT divides (floored) the polynomial by its last modulus. The input must be in the NTT domain. +// DivFloorByLastModulusNTTLvl divides (floored) the polynomial by its last modulus. The input must be in the NTT domain. // Output poly level must be equal or one less than input level. -func (r *Ring) DivFloorByLastModulusNTT(p0, p1 *Poly) { - r.divFloorByLastModulusNTT(p0.Level(), p0, p1) - p1.Coeffs = p1.Coeffs[:p0.Level()] -} +func (r *Ring) DivFloorByLastModulusNTTLvl(level int, p0, pool, p1 *Poly) { -func (r *Ring) divFloorByLastModulusNTT(level int, p0, p1 *Poly) { - - pool0 := make([]uint64, len(p0.Coeffs[0])) - pool1 := make([]uint64, len(p0.Coeffs[0])) + pool0 := pool.Coeffs[0] + pool1 := pool.Coeffs[1] InvNTTLazy(p0.Coeffs[level], pool0, r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level]) @@ -334,14 +328,9 @@ func (r *Ring) divFloorByLastModulusNTT(level int, p0, p1 *Poly) { } } -// DivFloorByLastModulus divides (floored) the polynomial by its last modulus. +// DivFloorByLastModulusLvl divides (floored) the polynomial by its last modulus. // Output poly level must be equal or one less than input level. -func (r *Ring) DivFloorByLastModulus(p0, p1 *Poly) { - r.divFloorByLastModulus(p0.Level(), p0, p1) - p1.Coeffs = p1.Coeffs[:p0.Level()] -} - -func (r *Ring) divFloorByLastModulus(level int, p0, p1 *Poly) { +func (r *Ring) DivFloorByLastModulusLvl(level int, p0, p1 *Poly) { for i := 0; i < level; i++ { p0tmp := p0.Coeffs[level] @@ -371,80 +360,66 @@ func (r *Ring) divFloorByLastModulus(level int, p0, p1 *Poly) { } } -// DivFloorByLastModulusManyNTT divides (floored) sequentially nbRescales times the polynomial by its last modulus. Input must be in the NTT domain. +// DivFloorByLastModulusManyNTTLvl divides (floored) sequentially nbRescales times the polynomial by its last modulus. Input must be in the NTT domain. // Output poly level must be equal or nbRescales less than input level. -func (r *Ring) DivFloorByLastModulusManyNTT(p0, p1 *Poly, nbRescales int) { - - level := p0.Level() +func (r *Ring) DivFloorByLastModulusManyNTTLvl(level, nbRescales int, p0, pool, p1 *Poly) { if nbRescales == 0 { if p0 != p1 { - CopyValuesLvl(p1.Level(), p0, p1) + CopyValuesLvl(level, p0, p1) } } else { - r.InvNTTLvl(level, p0, p1) + r.InvNTTLvl(level, p0, pool) for i := 0; i < nbRescales; i++ { - r.divFloorByLastModulus(level-i, p1, p1) + r.DivFloorByLastModulusLvl(level-i, pool, pool) } - p1.Coeffs = p1.Coeffs[:level-nbRescales+1] - - r.NTTLvl(p1.Level(), p1, p1) + r.NTTLvl(level-nbRescales, pool, p1) } } -// DivFloorByLastModulusMany divides (floored) sequentially nbRescales times the polynomial by its last modulus. +// DivFloorByLastModulusManyLvl divides (floored) sequentially nbRescales times the polynomial by its last modulus. // Output poly level must be equal or nbRescales less than input level. -func (r *Ring) DivFloorByLastModulusMany(p0, p1 *Poly, nbRescales int) { - - level := p0.Level() +func (r *Ring) DivFloorByLastModulusManyLvl(level, nbRescales int, p0, pool, p1 *Poly) { if nbRescales == 0 { if p0 != p1 { - CopyValuesLvl(p1.Level(), p0, p1) + CopyValuesLvl(level, p0, p1) } } else { if nbRescales > 1 { - r.divFloorByLastModulus(level, p0, p1) + r.DivFloorByLastModulusLvl(level, p0, pool) for i := 1; i < nbRescales; i++ { if i == nbRescales-1 { - r.divFloorByLastModulus(level-i, p1, p1) + r.DivFloorByLastModulusLvl(level-i, pool, p1) } else { - r.divFloorByLastModulus(level-i, p1, p1) + r.DivFloorByLastModulusLvl(level-i, pool, pool) } } } else { - r.divFloorByLastModulus(level, p0, p1) + r.DivFloorByLastModulusLvl(level, p0, p1) } - - p1.Coeffs = p1.Coeffs[:level-nbRescales+1] } - } -// DivRoundByLastModulusNTT divides (rounded) the polynomial by its last modulus. The input must be in the NTT domain. +// DivRoundByLastModulusNTTLvl divides (rounded) the polynomial by its last modulus. The input must be in the NTT domain. // Output poly level must be equal or one less than input level. -func (r *Ring) DivRoundByLastModulusNTT(p0, p1 *Poly) { - r.divRoundByLastModulusNTT(p0.Level(), p0, p1) - p1.Coeffs = p1.Coeffs[:p0.Level()] -} - -func (r *Ring) divRoundByLastModulusNTT(level int, p0, p1 *Poly) { +func (r *Ring) DivRoundByLastModulusNTTLvl(level int, p0, pool, p1 *Poly) { var pHalf, pHalfNegQi uint64 - pool0 := make([]uint64, len(p0.Coeffs[0])) - pool1 := make([]uint64, len(p0.Coeffs[0])) + pool0 := pool.Coeffs[0] + pool1 := pool.Coeffs[1] InvNTT(p0.Coeffs[level], pool0, r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level]) @@ -515,14 +490,9 @@ func (r *Ring) divRoundByLastModulusNTT(level int, p0, p1 *Poly) { } } -// DivRoundByLastModulus divides (rounded) the polynomial by its last modulus. The input must be in the NTT domain. +// DivRoundByLastModulusLvl divides (rounded) the polynomial by its last modulus. The input must be in the NTT domain. // Output poly level must be equal or one less than input level. -func (r *Ring) DivRoundByLastModulus(p0, p1 *Poly) { - r.divRoundByLastModulus(p0.Level(), p0, p1) - p1.Coeffs = p1.Coeffs[:p0.Level()] -} - -func (r *Ring) divRoundByLastModulus(level int, p0, p1 *Poly) { +func (r *Ring) DivRoundByLastModulusLvl(level int, p0, p1 *Poly) { var pHalf, pHalfNegQi uint64 @@ -577,69 +547,61 @@ func (r *Ring) divRoundByLastModulus(level int, p0, p1 *Poly) { } } -// DivRoundByLastModulusManyNTT divides (rounded) sequentially nbRescales times the polynomial by its last modulus. The input must be in the NTT domain. +// DivRoundByLastModulusManyNTTLvl divides (rounded) sequentially nbRescales times the polynomial by its last modulus. The input must be in the NTT domain. // Output poly level must be equal or nbRescales less than input level. -func (r *Ring) DivRoundByLastModulusManyNTT(p0, p1 *Poly, nbRescales int) { - - level := p0.Level() +func (r *Ring) DivRoundByLastModulusManyNTTLvl(level, nbRescales int, p0, pool, p1 *Poly) { if nbRescales == 0 { if p0 != p1 { - CopyValuesLvl(p1.Level(), p0, p1) + CopyValuesLvl(level, p0, p1) } } else { if nbRescales > 1 { - r.InvNTTLvl(level, p0, p1) + r.InvNTTLvl(level, p0, pool) for i := 0; i < nbRescales; i++ { - r.divRoundByLastModulus(level-i, p1, p1) + r.DivRoundByLastModulusLvl(level-i, pool, pool) } - r.NTTLvl(p1.Level(), p1, p1) + r.NTTLvl(p1.Level(), pool, p1) } else { - r.divRoundByLastModulusNTT(level, p0, p1) + r.DivRoundByLastModulusNTTLvl(level, p0, pool, p1) } - - p1.Coeffs = p1.Coeffs[:level-nbRescales+1] } } -// DivRoundByLastModulusMany divides (rounded) sequentially nbRescales times the polynomial by its last modulus. +// DivRoundByLastModulusManyLvl divides (rounded) sequentially nbRescales times the polynomial by its last modulus. // Output poly level must be equal or nbRescales less than input level. -func (r *Ring) DivRoundByLastModulusMany(p0, p1 *Poly, nbRescales int) { - - level := p0.Level() +func (r *Ring) DivRoundByLastModulusManyLvl(level, nbRescales int, p0, pool, p1 *Poly) { if nbRescales == 0 { if p0 != p1 { - CopyValuesLvl(p1.Level(), p0, p1) + CopyValuesLvl(level, p0, p1) } } else { if nbRescales > 1 { - r.divRoundByLastModulus(level, p0, p1) + r.DivRoundByLastModulusLvl(level, p0, pool) for i := 1; i < nbRescales; i++ { if i == nbRescales-1 { - r.divRoundByLastModulus(level-i, p1, p1) + r.DivRoundByLastModulusLvl(level-i, pool, p1) } else { - r.divRoundByLastModulus(level-i, p1, p1) + r.DivRoundByLastModulusLvl(level-i, pool, pool) } } } else { - r.divRoundByLastModulus(level, p0, p1) + r.DivRoundByLastModulusLvl(level, p0, p1) } - - p1.Coeffs = p1.Coeffs[:level-nbRescales+1] } } diff --git a/ring/ring_test.go b/ring/ring_test.go index 4fa6d7cc..0c54757c 100644 --- a/ring/ring_test.go +++ b/ring/ring_test.go @@ -60,7 +60,7 @@ func TestRing(t *testing.T) { } testNewRing(t) - for _, defaultParam := range defaultParams { + for _, defaultParam := range defaultParams[:] { var testContext *testParams if testContext, err = genTestParams(defaultParam); err != nil { @@ -189,12 +189,12 @@ func testDivFloorByLastModulusMany(testContext *testParams, t *testing.T) { coeffs[i].Quo(coeffs[i], NewUint(10)) } - nbRescals := len(testContext.ringQ.Modulus) - 1 + nbRescales := len(testContext.ringQ.Modulus) - 1 coeffsWant := make([]*big.Int, testContext.ringQ.N) for i := range coeffs { coeffsWant[i] = new(big.Int).Set(coeffs[i]) - for j := 0; j < nbRescals; j++ { + for j := 0; j < nbRescales; j++ { coeffsWant[i].Quo(coeffsWant[i], NewUint(testContext.ringQ.Modulus[len(testContext.ringQ.Modulus)-1-j])) } } @@ -202,13 +202,14 @@ func testDivFloorByLastModulusMany(testContext *testParams, t *testing.T) { polTest0 := testContext.ringQ.NewPoly() polTest1 := testContext.ringQ.NewPoly() polWant := testContext.ringQ.NewPoly() + pool := testContext.ringQ.NewPoly() testContext.ringQ.SetCoefficientsBigint(coeffs, polTest0) testContext.ringQ.SetCoefficientsBigint(coeffsWant, polWant) - testContext.ringQ.DivFloorByLastModulusMany(polTest0, polTest1, nbRescals) + testContext.ringQ.DivFloorByLastModulusManyLvl(polTest0.Level(), nbRescales, polTest0, pool, polTest1) for i := 0; i < testContext.ringQ.N; i++ { - for j := 0; j < len(testContext.ringQ.Modulus)-nbRescals; j++ { + for j := 0; j < polTest0.Level()-nbRescales+1; j++ { require.Equalf(t, polWant.Coeffs[j][i], polTest1.Coeffs[j][i], "coeff %v Qi%v = %s", i, j, coeffs[i].String()) } } @@ -238,13 +239,14 @@ func testDivRoundByLastModulusMany(testContext *testParams, t *testing.T) { polTest0 := testContext.ringQ.NewPoly() polTest1 := testContext.ringQ.NewPoly() polWant := testContext.ringQ.NewPoly() + pool := testContext.ringQ.NewPoly() testContext.ringQ.SetCoefficientsBigint(coeffs, polTest0) testContext.ringQ.SetCoefficientsBigint(coeffsWant, polWant) - testContext.ringQ.DivRoundByLastModulusMany(polTest0, polTest1, nbRescals) + testContext.ringQ.DivRoundByLastModulusManyLvl(polTest0.Level(), nbRescals, polTest0, pool, polTest1) for i := 0; i < testContext.ringQ.N; i++ { - for j := 0; j < len(testContext.ringQ.Modulus)-nbRescals; j++ { + for j := 0; j < polTest0.Level()-nbRescals+1; j++ { require.Equalf(t, polWant.Coeffs[j][i], polTest1.Coeffs[j][i], "coeff %v Qi%v = %s", i, j, coeffs[i].String()) } } @@ -558,30 +560,87 @@ func testMulScalarBigint(testContext *testParams, t *testing.T) { func testExtendBasis(testContext *testParams, t *testing.T) { - t.Run(testString("ExtendBasis/", testContext.ringQ), func(t *testing.T) { + t.Run(testString("ModUp/", testContext.ringQ), func(t *testing.T) { basisextender := NewFastBasisExtender(testContext.ringQ, testContext.ringP) + levelQ := len(testContext.ringQ.Modulus) - 2 + levelP := len(testContext.ringQ.Modulus) - 2 + + Q := NewUint(testContext.ringQ.Modulus[0]) + for i := 1; i < levelQ+1; i++ { + Q.Mul(Q, NewUint(testContext.ringQ.Modulus[i])) + } + coeffs := make([]*big.Int, testContext.ringQ.N) for i := 0; i < testContext.ringQ.N; i++ { - coeffs[i] = RandInt(testContext.ringQ.ModulusBigint) + coeffs[i] = RandInt(Q) } - Pol := testContext.ringQ.NewPoly() - PolTest := testContext.ringP.NewPoly() - PolWant := testContext.ringP.NewPoly() + PolQHave := testContext.ringQ.NewPolyLvl(levelQ) + PolPTest := testContext.ringP.NewPolyLvl(levelP) + PolPWant := testContext.ringP.NewPolyLvl(levelP) - testContext.ringQ.SetCoefficientsBigint(coeffs, Pol) - testContext.ringP.SetCoefficientsBigint(coeffs, PolWant) + testContext.ringQ.SetCoefficientsBigintLvl(levelQ, coeffs, PolQHave) + testContext.ringP.SetCoefficientsBigintLvl(levelP, coeffs, PolPWant) - basisextender.ModUpSplitQP(len(testContext.ringQ.Modulus)-1, Pol, PolTest) + basisextender.ModUpQtoP(levelQ, levelP, PolQHave, PolPTest) + testContext.ringP.Reduce(PolPTest, PolPTest) - testContext.ringP.Reduce(PolTest, PolTest) - - for i := range testContext.ringP.Modulus { - require.Equal(t, PolTest.Coeffs[i][:testContext.ringQ.N], PolWant.Coeffs[i][:testContext.ringQ.N]) + for i := range testContext.ringP.Modulus[:levelP+1] { + require.Equal(t, PolPTest.Coeffs[i][:testContext.ringQ.N], PolPWant.Coeffs[i][:testContext.ringQ.N]) } }) + + t.Run(testString("ModDown/", testContext.ringQ), func(t *testing.T) { + + basisextender := NewFastBasisExtender(testContext.ringQ, testContext.ringP) + + levelQ := len(testContext.ringQ.Modulus) - 2 + levelP := len(testContext.ringP.Modulus) - 2 + + Q := NewUint(1) + P := NewUint(1) + QP := NewUint(1) + for i := range testContext.ringQ.Modulus[:levelQ+1] { + Q.Mul(Q, NewUint(testContext.ringQ.Modulus[i])) + } + + for i := range testContext.ringP.Modulus[:levelP+1] { + P.Mul(P, NewUint(testContext.ringP.Modulus[i])) + } + + QP.Mul(QP, Q) + QP.Mul(QP, P) + + coeffs := make([]*big.Int, testContext.ringQ.N) + for i := 0; i < testContext.ringQ.N; i++ { + coeffs[i] = RandInt(QP) + coeffs[i].Quo(coeffs[i], NewUint(10)) + } + + coeffsWant := make([]*big.Int, testContext.ringQ.N) + for i := range coeffs { + coeffsWant[i] = new(big.Int).Set(coeffs[i]) + coeffsWant[i].Quo(coeffsWant[i], P) + } + + PolQHave := testContext.ringQ.NewPolyLvl(levelQ) + PolPHave := testContext.ringP.NewPolyLvl(levelP) + PolQWant := testContext.ringP.NewPolyLvl(levelQ) + + testContext.ringQ.SetCoefficientsBigintLvl(levelQ, coeffs, PolQHave) + testContext.ringP.SetCoefficientsBigintLvl(levelP, coeffs, PolPHave) + testContext.ringQ.SetCoefficientsBigintLvl(levelQ, coeffsWant, PolQWant) + + basisextender.ModDownQPtoQ(levelQ, levelP, PolQHave, PolPHave, PolQHave) + testContext.ringQ.Reduce(PolQHave, PolQHave) + + for i := 0; i < levelQ+1; i++ { + require.Equal(t, PolQHave.Coeffs[i][:testContext.ringQ.N], PolQWant.Coeffs[i][:testContext.ringQ.N]) + } + + }) } func testScaling(testContext *testParams, t *testing.T) { @@ -616,7 +675,9 @@ func testScaling(testContext *testParams, t *testing.T) { t.Run(testString("Scaling/RNS", testContext.ringQ), func(t *testing.T) { - scaler := NewRNSScaler(T, testContext.ringQ) + ringT, _ := NewRing(testContext.ringQ.N, []uint64{T}) + + scaler := NewRNSScaler(testContext.ringQ, ringT) coeffs := make([]*big.Int, testContext.ringQ.N) for i := 0; i < testContext.ringQ.N; i++ { diff --git a/ring/utils.go b/ring/utils.go index f382f8c1..9fa03af6 100644 --- a/ring/utils.go +++ b/ring/utils.go @@ -27,7 +27,7 @@ func PowerOf2(x uint64, n int, q, qInv uint64) (r uint64) { // ModExp performs the modular exponentiation x^e mod p, // x and p are required to be at most 64 bits to avoid an overflow. -func ModExp(x uint64, e int, p uint64) (result uint64) { +func ModExp(x, e, p uint64) (result uint64) { params := BRedParams(p) result = 1 for i := e; i > 0; i >>= 1 { @@ -39,9 +39,9 @@ func ModExp(x uint64, e int, p uint64) (result uint64) { return result } -// modexpMontgomery performs the modular exponentiation x^e mod p, +// ModexpMontgomery performs the modular exponentiation x^e mod p, // where x is in Montgomery form, and returns x^e in Montgomery form. -func modexpMontgomery(x uint64, e int, q, qInv uint64, bredParams []uint64) (result uint64) { +func ModexpMontgomery(x uint64, e int, q, qInv uint64, bredParams []uint64) (result uint64) { result = MForm(1, q, bredParams) @@ -80,7 +80,7 @@ func primitiveRoot(q uint64) (g uint64) { for _, factor := range factors { tmp = (q - 1) / factor // if for any factor of q-1, g^(q-1)/factor = 1 mod q, g is not a primitive root - if ModExp(g, int(tmp), q) == 1 { + if ModExp(g, tmp, q) == 1 { notFoundPrimitiveRoot = true break } diff --git a/rlwe/decryptor.go b/rlwe/decryptor.go index d75bac8d..ef546dc9 100644 --- a/rlwe/decryptor.go +++ b/rlwe/decryptor.go @@ -24,7 +24,7 @@ type decryptor struct { // NewDecryptor instantiates a new generic RLWE Decryptor. func NewDecryptor(params Parameters, sk *SecretKey) Decryptor { - if sk.Value.Degree() != params.N() { + if sk.Value.Q.Degree() != params.N() { panic("secret_key is invalid for the provided parameters") } @@ -55,7 +55,7 @@ func (decryptor *decryptor) Decrypt(ciphertext *Ciphertext, plaintext *Plaintext for i := ciphertext.Degree(); i > 0; i-- { - ringQ.MulCoeffsMontgomeryLvl(level, plaintext.Value, decryptor.sk.Value, plaintext.Value) + ringQ.MulCoeffsMontgomeryLvl(level, plaintext.Value, decryptor.sk.Value.Q, plaintext.Value) if !ciphertext.Value[0].IsNTT { ringQ.NTTLazyLvl(level, ciphertext.Value[i-1], decryptor.pool) diff --git a/rlwe/elements.go b/rlwe/elements.go index dd6e472c..aae6a573 100644 --- a/rlwe/elements.go +++ b/rlwe/elements.go @@ -131,6 +131,59 @@ func (el *Ciphertext) Resize(params Parameters, degree int) { } } +// SwitchCiphertextRingDegreeNTT changes the ring degree of ctIn to the one of ctOut. +// Maps Y^{N/n} -> X^{N} or X^{N} -> Y^{N/n}. +// If the ring degree of ctOut is larger than the one of ctIn, then the ringQ of ctIn +// must be provided (else a nil pointer). +// The ctIn must be in the NTT domain and ctOut will be in the NTT domain. +func SwitchCiphertextRingDegreeNTT(ctIn *Ciphertext, ringQLargeDim *ring.Ring, ctOut *Ciphertext) { + + NIn, NOut := len(ctIn.Value[0].Coeffs[0]), len(ctOut.Value[0].Coeffs[0]) + + if NIn > NOut { + r := ringQLargeDim + gap := NIn / NOut + pool := make([]uint64, NIn) + for i := range ctOut.Value { + for j := range ctOut.Value[i].Coeffs { + tmp0, tmp1 := ctOut.Value[i].Coeffs[j], ctIn.Value[i].Coeffs[j] + ring.InvNTT(tmp1, pool, NIn, r.NttPsiInv[j], r.NttNInv[j], r.Modulus[j], r.MredParams[j]) + for w0, w1 := 0, 0; w0 < NOut; w0, w1 = w0+1, w1+gap { + pool[w0] = pool[w1] + } + ring.NTT(pool, tmp0, NOut, r.NttPsi[j], r.Modulus[j], r.MredParams[j], r.BredParams[j]) + } + } + } else { + for i := range ctOut.Value { + ring.MapSmallDimensionToLargerDimensionNTT(ctIn.Value[i], ctOut.Value[i]) + } + } +} + +// SwitchCiphertextRingDegree changes the ring degree of ctIn to the one of ctOut. +// Maps Y^{N/n} -> X^{N} or X^{N} -> Y^{N/n}. +// If the ring degree of ctOut is larger than the one of ctIn, then the ringQ of ctIn +// must be provided (else a nil pointer). +func SwitchCiphertextRingDegree(ctIn *Ciphertext, ctOut *Ciphertext) { + + NIn, NOut := len(ctIn.Value[0].Coeffs[0]), len(ctOut.Value[0].Coeffs[0]) + + gapIn, gapOut := NOut/NIn, 1 + if NIn > NOut { + gapIn, gapOut = 1, NIn/NOut + } + + for i := range ctOut.Value { + for j := range ctOut.Value[i].Coeffs { + tmp0, tmp1 := ctOut.Value[i].Coeffs[j], ctIn.Value[i].Coeffs[j] + for w0, w1 := 0, 0; w0 < NOut; w0, w1 = w0+gapIn, w1+gapOut { + tmp0[w0] = tmp1[w1] + } + } + } +} + // CopyNew creates a new element as a copy of the target element. func (el *Ciphertext) CopyNew() *Ciphertext { diff --git a/rlwe/encryptor.go b/rlwe/encryptor.go index 9630c141..133c901b 100644 --- a/rlwe/encryptor.go +++ b/rlwe/encryptor.go @@ -52,7 +52,7 @@ type skEncryptor struct { func NewEncryptor(params Parameters, key interface{}) Encryptor { switch key := key.(type) { case *PublicKey: - if key.Value[0].Degree() != params.N() || key.Value[1].Degree() != params.N() { + if key.Value[0].Q.Degree() != params.N() || key.Value[1].Q.Degree() != params.N() { panic("cannot newEncryptor: pk ring degree does not match params ring degree") } encryptorBase := newEncryptorBase(params) @@ -62,7 +62,7 @@ func NewEncryptor(params Parameters, key interface{}) Encryptor { } return &pkFastEncryptor{encryptorBase, key} case *SecretKey: - if key.Value.Degree() != params.N() { + if key.Value.Q.Degree() != params.N() { panic("cannot newEncryptor: sk ring degree does not match params ring degree") } return &skEncryptor{newEncryptorBase(params), key} @@ -80,8 +80,11 @@ func NewFastEncryptor(params Parameters, key *PublicKey) Encryptor { // Encrypt encrypts the input Plaintext and write the result in ctOut. func (encryptor *pkEncryptor) Encrypt(plaintext *Plaintext, ctOut *Ciphertext) { + ringQ := encryptor.ringQ + ringQP := encryptor.params.RingQP() - lvl := utils.MinInt(plaintext.Level(), ctOut.Level()) + levelQ := utils.MinInt(plaintext.Level(), ctOut.Level()) + levelP := 0 poolQ0 := encryptor.poolQ[0] poolP0 := encryptor.poolP[0] @@ -90,90 +93,80 @@ func (encryptor *pkEncryptor) Encrypt(plaintext *Plaintext, ctOut *Ciphertext) { // We sample a R-WLE instance (encryption of zero) over the extended ring (ciphertext ring + special prime) - ringQ := encryptor.ringQ - ringP := encryptor.ringP - ciphertextNTT := ctOut.Value[0].IsNTT - encryptor.ternarySampler.ReadLvl(lvl, poolQ0) - extendBasisSmallNormAndCenter(ringQ, ringP, poolQ0, poolP0) + u := PolyQP{Q: poolQ0, P: poolP2} + + encryptor.ternarySampler.ReadLvl(levelQ, u.Q) + ringQP.ExtendBasisSmallNormAndCenter(u.Q, levelP, nil, u.P) // (#Q + #P) NTT - ringQ.NTTLvl(lvl, poolQ0, poolQ0) - ringP.NTT(poolP0, poolP0) + ringQP.NTTLvl(levelQ, levelP, u, u) + ringQP.MFormLvl(levelQ, levelP, u, u) - ringQ.MFormLvl(lvl, poolQ0, poolQ0) - ringP.MForm(poolP0, poolP0) - - pk0P := new(ring.Poly) - pk1P := new(ring.Poly) - pk0P.Coeffs = encryptor.pk.Value[0].Coeffs[len(ringQ.Modulus):] - pk1P.Coeffs = encryptor.pk.Value[1].Coeffs[len(ringQ.Modulus):] + ct0QP := PolyQP{Q: ctOut.Value[0], P: poolP0} + ct1QP := PolyQP{Q: ctOut.Value[1], P: poolP1} // ct0 = u*pk0 // ct1 = u*pk1 - ringQ.MulCoeffsMontgomeryLvl(lvl, poolQ0, encryptor.pk.Value[0], ctOut.Value[0]) - ringQ.MulCoeffsMontgomeryLvl(lvl, poolQ0, encryptor.pk.Value[1], ctOut.Value[1]) - ringP.MulCoeffsMontgomery(poolP0, pk1P, poolP1) - ringP.MulCoeffsMontgomery(poolP0, pk0P, poolP0) + ringQP.MulCoeffsMontgomeryLvl(levelQ, levelP, u, encryptor.pk.Value[0], ct0QP) + ringQP.MulCoeffsMontgomeryLvl(levelQ, levelP, u, encryptor.pk.Value[1], ct1QP) // 2*(#Q + #P) NTT - ringQ.InvNTTLvl(lvl, ctOut.Value[0], ctOut.Value[0]) - ringQ.InvNTTLvl(lvl, ctOut.Value[1], ctOut.Value[1]) - ringP.InvNTT(poolP0, poolP0) - ringP.InvNTT(poolP1, poolP1) + ringQP.InvNTTLvl(levelQ, levelP, ct0QP, ct0QP) + ringQP.InvNTTLvl(levelQ, levelP, ct1QP, ct1QP) - encryptor.gaussianSampler.ReadLvl(lvl, poolQ0) - extendBasisSmallNormAndCenter(ringQ, ringP, poolQ0, poolP2) - ringQ.AddLvl(lvl, ctOut.Value[0], poolQ0, ctOut.Value[0]) - ringP.Add(poolP0, poolP2, poolP0) + e := PolyQP{Q: poolQ0, P: poolP2} - encryptor.gaussianSampler.ReadLvl(lvl, poolQ0) - extendBasisSmallNormAndCenter(ringQ, ringP, poolQ0, poolP2) - ringQ.AddLvl(lvl, ctOut.Value[1], poolQ0, ctOut.Value[1]) - ringP.Add(poolP1, poolP2, poolP1) + encryptor.gaussianSampler.ReadLvl(levelQ, e.Q) + ringQP.ExtendBasisSmallNormAndCenter(e.Q, levelP, nil, e.P) + ringQP.AddLvl(levelQ, levelP, ct0QP, e, ct0QP) + + encryptor.gaussianSampler.ReadLvl(levelQ, e.Q) + ringQP.ExtendBasisSmallNormAndCenter(e.Q, levelP, nil, e.P) + ringQP.AddLvl(levelQ, levelP, ct1QP, e, ct1QP) // ct0 = (u*pk0 + e0)/P - encryptor.baseconverter.ModDownSplitPQ(lvl, ctOut.Value[0], poolP0, ctOut.Value[0]) + encryptor.baseconverter.ModDownQPtoQ(levelQ, levelP, ct0QP.Q, ct0QP.P, ct0QP.Q) // ct1 = (u*pk1 + e1)/P - encryptor.baseconverter.ModDownSplitPQ(lvl, ctOut.Value[1], poolP1, ctOut.Value[1]) + encryptor.baseconverter.ModDownQPtoQ(levelQ, levelP, ct1QP.Q, ct1QP.P, ct1QP.Q) if ciphertextNTT { if !plaintext.Value.IsNTT { - ringQ.AddLvl(lvl, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) + ringQ.AddLvl(levelQ, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) } // 2*#Q NTT - ringQ.NTTLvl(lvl, ctOut.Value[0], ctOut.Value[0]) - ringQ.NTTLvl(lvl, ctOut.Value[1], ctOut.Value[1]) + ringQ.NTTLvl(levelQ, ctOut.Value[0], ctOut.Value[0]) + ringQ.NTTLvl(levelQ, ctOut.Value[1], ctOut.Value[1]) if plaintext.Value.IsNTT { // ct0 = (u*pk0 + e0)/P + m - ringQ.AddLvl(lvl, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) + ringQ.AddLvl(levelQ, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) } } else { if !plaintext.Value.IsNTT { - ringQ.AddLvl(lvl, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) + ringQ.AddLvl(levelQ, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) } else { - ringQ.InvNTTLvl(lvl, plaintext.Value, poolQ0) - ringQ.AddLvl(lvl, ctOut.Value[0], poolQ0, ctOut.Value[0]) + ringQ.InvNTTLvl(levelQ, plaintext.Value, poolQ0) + ringQ.AddLvl(levelQ, ctOut.Value[0], poolQ0, ctOut.Value[0]) } } ctOut.Value[1].IsNTT = ctOut.Value[0].IsNTT - ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:lvl+1] - ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:lvl+1] + ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:levelQ+1] + ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:levelQ+1] } // Encrypt encrypts the input Plaintext and write the result in ctOut. // It first encrypts zero in Q and then adds the plaintext. // This method is faster than the normal encryptor but result in a noisier ciphertext. func (encryptor *pkFastEncryptor) Encrypt(plaintext *Plaintext, ctOut *Ciphertext) { - lvl := utils.MinInt(plaintext.Level(), ctOut.Level()) + levelQ := utils.MinInt(plaintext.Level(), ctOut.Level()) poolQ0 := encryptor.poolQ[0] @@ -181,39 +174,39 @@ func (encryptor *pkFastEncryptor) Encrypt(plaintext *Plaintext, ctOut *Ciphertex ciphertextNTT := ctOut.Value[0].IsNTT - encryptor.ternarySampler.ReadLvl(lvl, poolQ0) - ringQ.NTTLvl(lvl, poolQ0, poolQ0) - ringQ.MFormLvl(lvl, poolQ0, poolQ0) + encryptor.ternarySampler.ReadLvl(levelQ, poolQ0) + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.MFormLvl(levelQ, poolQ0, poolQ0) // ct0 = u*pk0 - ringQ.MulCoeffsMontgomeryLvl(lvl, poolQ0, encryptor.pk.Value[0], ctOut.Value[0]) + ringQ.MulCoeffsMontgomeryLvl(levelQ, poolQ0, encryptor.pk.Value[0].Q, ctOut.Value[0]) // ct1 = u*pk1 - ringQ.MulCoeffsMontgomeryLvl(lvl, poolQ0, encryptor.pk.Value[1], ctOut.Value[1]) + ringQ.MulCoeffsMontgomeryLvl(levelQ, poolQ0, encryptor.pk.Value[1].Q, ctOut.Value[1]) if ciphertextNTT { // ct1 = u*pk1 + e1 - encryptor.gaussianSampler.ReadLvl(lvl, poolQ0) - ringQ.NTTLvl(lvl, poolQ0, poolQ0) - ringQ.AddLvl(lvl, ctOut.Value[1], poolQ0, ctOut.Value[1]) + encryptor.gaussianSampler.ReadLvl(levelQ, poolQ0) + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ctOut.Value[1], poolQ0, ctOut.Value[1]) // ct0 = u*pk0 + e0 - encryptor.gaussianSampler.ReadLvl(lvl, poolQ0) + encryptor.gaussianSampler.ReadLvl(levelQ, poolQ0) if !plaintext.Value.IsNTT { - ringQ.AddLvl(lvl, poolQ0, plaintext.Value, poolQ0) - ringQ.NTTLvl(lvl, poolQ0, poolQ0) - ringQ.AddLvl(lvl, ctOut.Value[0], poolQ0, ctOut.Value[0]) + ringQ.AddLvl(levelQ, poolQ0, plaintext.Value, poolQ0) + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ctOut.Value[0], poolQ0, ctOut.Value[0]) } else { - ringQ.NTTLvl(lvl, poolQ0, poolQ0) - ringQ.AddLvl(lvl, ctOut.Value[0], poolQ0, ctOut.Value[0]) - ringQ.AddLvl(lvl, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ctOut.Value[0], poolQ0, ctOut.Value[0]) + ringQ.AddLvl(levelQ, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) } } else { - ringQ.InvNTTLvl(lvl, ctOut.Value[0], ctOut.Value[0]) - ringQ.InvNTTLvl(lvl, ctOut.Value[1], ctOut.Value[1]) + ringQ.InvNTTLvl(levelQ, ctOut.Value[0], ctOut.Value[0]) + ringQ.InvNTTLvl(levelQ, ctOut.Value[1], ctOut.Value[1]) // ct[0] = pk[0]*u + e0 encryptor.gaussianSampler.ReadAndAddLvl(ctOut.Level(), ctOut.Value[0]) @@ -222,17 +215,17 @@ func (encryptor *pkFastEncryptor) Encrypt(plaintext *Plaintext, ctOut *Ciphertex encryptor.gaussianSampler.ReadAndAddLvl(ctOut.Level(), ctOut.Value[1]) if !plaintext.Value.IsNTT { - ringQ.AddLvl(lvl, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) + ringQ.AddLvl(levelQ, ctOut.Value[0], plaintext.Value, ctOut.Value[0]) } else { - ringQ.InvNTTLvl(lvl, plaintext.Value, poolQ0) - ringQ.AddLvl(lvl, ctOut.Value[0], poolQ0, ctOut.Value[0]) + ringQ.InvNTTLvl(levelQ, plaintext.Value, poolQ0) + ringQ.AddLvl(levelQ, ctOut.Value[0], poolQ0, ctOut.Value[0]) } } ctOut.Value[1].IsNTT = ctOut.Value[0].IsNTT - ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:lvl+1] - ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:lvl+1] + ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:levelQ+1] + ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:levelQ+1] } // Encrypt encrypts the input Plaintext and write the result in ctOut. @@ -251,27 +244,27 @@ func (encryptor *skEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte ringQ := encryptor.ringQ - lvl := utils.MinInt(plaintext.Level(), ciphertext.Level()) + levelQ := utils.MinInt(plaintext.Level(), ciphertext.Level()) poolQ0 := encryptor.poolQ[0] ciphertextNTT := ciphertext.Value[0].IsNTT - ringQ.MulCoeffsMontgomeryLvl(lvl, ciphertext.Value[1], encryptor.sk.Value, ciphertext.Value[0]) - ringQ.NegLvl(lvl, ciphertext.Value[0], ciphertext.Value[0]) + ringQ.MulCoeffsMontgomeryLvl(levelQ, ciphertext.Value[1], encryptor.sk.Value.Q, ciphertext.Value[0]) + ringQ.NegLvl(levelQ, ciphertext.Value[0], ciphertext.Value[0]) if ciphertextNTT { - encryptor.gaussianSampler.ReadLvl(lvl, poolQ0) + encryptor.gaussianSampler.ReadLvl(levelQ, poolQ0) if plaintext.Value.IsNTT { - ringQ.NTTLvl(lvl, poolQ0, poolQ0) - ringQ.AddLvl(lvl, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) - ringQ.AddLvl(lvl, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) + ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) } else { - ringQ.AddLvl(lvl, poolQ0, plaintext.Value, poolQ0) - ringQ.NTTLvl(lvl, poolQ0, poolQ0) - ringQ.AddLvl(lvl, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) + ringQ.AddLvl(levelQ, poolQ0, plaintext.Value, poolQ0) + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) } ciphertext.Value[0].IsNTT = true @@ -280,25 +273,25 @@ func (encryptor *skEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte } else { if plaintext.Value.IsNTT { - ringQ.AddLvl(lvl, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) - ringQ.InvNTTLvl(lvl, ciphertext.Value[0], ciphertext.Value[0]) + ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + ringQ.InvNTTLvl(levelQ, ciphertext.Value[0], ciphertext.Value[0]) } else { - ringQ.InvNTTLvl(lvl, ciphertext.Value[0], ciphertext.Value[0]) - ringQ.AddLvl(lvl, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + ringQ.InvNTTLvl(levelQ, ciphertext.Value[0], ciphertext.Value[0]) + ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) } encryptor.gaussianSampler.ReadAndAddLvl(ciphertext.Level(), ciphertext.Value[0]) - ringQ.InvNTTLvl(lvl, ciphertext.Value[1], ciphertext.Value[1]) + ringQ.InvNTTLvl(levelQ, ciphertext.Value[1], ciphertext.Value[1]) ciphertext.Value[0].IsNTT = false ciphertext.Value[1].IsNTT = false } - ciphertext.Value[0].Coeffs = ciphertext.Value[0].Coeffs[:lvl+1] - ciphertext.Value[1].Coeffs = ciphertext.Value[1].Coeffs[:lvl+1] + ciphertext.Value[0].Coeffs = ciphertext.Value[0].Coeffs[:levelQ+1] + ciphertext.Value[1].Coeffs = ciphertext.Value[1].Coeffs[:levelQ+1] } func newEncryptorBase(params Parameters) encryptorBase { @@ -331,24 +324,3 @@ func newEncryptorBase(params Parameters) encryptorBase { func (encryptor *encryptorBase) EncryptFromCRP(plaintext *Plaintext, crp *ring.Poly, ctOut *Ciphertext) { panic("Cannot encrypt with CRP using an encryptor created with the public-key") } - -func extendBasisSmallNormAndCenter(ringQ, ringP *ring.Ring, polQ, polP *ring.Poly) { - var coeff, Q, QHalf, sign uint64 - Q = ringQ.Modulus[0] - QHalf = Q >> 1 - - for j := 0; j < ringQ.N; j++ { - - coeff = polQ.Coeffs[0][j] - - sign = 1 - if coeff > QHalf { - coeff = Q - coeff - sign = 0 - } - - for i, pi := range ringP.Modulus { - polP.Coeffs[i][j] = (coeff * sign) | (pi-coeff)*(sign^1) - } - } -} diff --git a/rlwe/keygen.go b/rlwe/keygen.go index 61d51534..05e396c4 100644 --- a/rlwe/keygen.go +++ b/rlwe/keygen.go @@ -1,6 +1,7 @@ package rlwe import ( + "math" "math/big" "github.com/ldsec/lattigo/v2/ring" @@ -29,32 +30,30 @@ type KeyGenerator interface { // KeyGenerator is a structure that stores the elements required to create new keys, // as well as a small memory pool for intermediate values. type keyGenerator struct { - params Parameters - ringQP *ring.Ring - pBigInt *big.Int - polypool [2]*ring.Poly - gaussianSampler *ring.GaussianSampler - uniformSampler *ring.UniformSampler + params Parameters + poolQ *ring.Poly + poolQP PolyQP + gaussianSamplerQ *ring.GaussianSampler + uniformSamplerQ *ring.UniformSampler + uniformSamplerP *ring.UniformSampler } // NewKeyGenerator creates a new KeyGenerator, from which the secret and public keys, as well as the evaluation, // rotation and switching keys can be generated. func NewKeyGenerator(params Parameters) KeyGenerator { - ringQP := params.RingQP() - prng, err := utils.NewPRNG() if err != nil { panic(err) } return &keyGenerator{ - params: params, - ringQP: ringQP, - pBigInt: params.PBigInt(), - polypool: [2]*ring.Poly{ringQP.NewPoly(), ringQP.NewPoly()}, - gaussianSampler: ring.NewGaussianSampler(prng, ringQP, params.Sigma(), int(6*params.Sigma())), - uniformSampler: ring.NewUniformSampler(prng, ringQP), + params: params, + poolQ: params.RingQ().NewPoly(), + poolQP: params.RingQP().NewPoly(), + gaussianSamplerQ: ring.NewGaussianSampler(prng, params.RingQ(), params.Sigma(), int(6*params.Sigma())), + uniformSamplerQ: ring.NewUniformSampler(prng, params.RingQ()), + uniformSamplerP: ring.NewUniformSampler(prng, params.RingP()), } } @@ -63,12 +62,9 @@ func (keygen *keyGenerator) GenSecretKey() (sk *SecretKey) { return keygen.GenSecretKeyWithDistrib(1.0 / 3) } +// GenSecretKey generates a new SecretKey with the error distribution. func (keygen *keyGenerator) GenSecretKeyGaussian() (sk *SecretKey) { - sk = new(SecretKey) - - sk.Value = keygen.gaussianSampler.ReadNew() - keygen.ringQP.NTT(sk.Value, sk.Value) - return sk + return keygen.genSecretKeyFromSampler(keygen.gaussianSamplerQ) } // GenSecretKeyWithDistrib generates a new SecretKey with the distribution [(p-1)/2, p, (p-1)/2]. @@ -77,12 +73,8 @@ func (keygen *keyGenerator) GenSecretKeyWithDistrib(p float64) (sk *SecretKey) { if err != nil { panic(err) } - ternarySamplerMontgomery := ring.NewTernarySampler(prng, keygen.ringQP, p, true) - - sk = new(SecretKey) - sk.Value = ternarySamplerMontgomery.ReadNew() - keygen.ringQP.NTT(sk.Value, sk.Value) - return sk + ternarySamplerMontgomery := ring.NewTernarySampler(prng, keygen.params.RingQ(), p, false) + return keygen.genSecretKeyFromSampler(ternarySamplerMontgomery) } // GenSecretKeySparse generates a new SecretKey with exactly hw non-zero coefficients. @@ -91,30 +83,28 @@ func (keygen *keyGenerator) GenSecretKeySparse(hw int) (sk *SecretKey) { if err != nil { panic(err) } - ternarySamplerMontgomery := ring.NewTernarySamplerSparse(prng, keygen.ringQP, hw, true) - - sk = new(SecretKey) - sk.Value = ternarySamplerMontgomery.ReadNew() - keygen.ringQP.NTT(sk.Value, sk.Value) - return sk + ternarySamplerMontgomery := ring.NewTernarySamplerSparse(prng, keygen.params.RingQ(), hw, false) + return keygen.genSecretKeyFromSampler(ternarySamplerMontgomery) } // GenPublicKey generates a new public key from the provided SecretKey. func (keygen *keyGenerator) GenPublicKey(sk *SecretKey) (pk *PublicKey) { pk = new(PublicKey) - - ringQP := keygen.ringQP + ringQP := keygen.params.RingQP() + levelQ, levelP := keygen.params.QCount()-1, keygen.params.PCount()-1 //pk[0] = [-as + e] //pk[1] = [a] + pk = NewPublicKey(keygen.params) + keygen.gaussianSamplerQ.Read(pk.Value[0].Q) + ringQP.ExtendBasisSmallNormAndCenter(pk.Value[0].Q, levelP, nil, pk.Value[0].P) + ringQP.NTTLvl(levelQ, levelP, pk.Value[0], pk.Value[0]) - pk.Value[0] = keygen.gaussianSampler.ReadNew() - ringQP.NTT(pk.Value[0], pk.Value[0]) - pk.Value[1] = keygen.uniformSampler.ReadNew() - - ringQP.MulCoeffsMontgomeryAndSub(sk.Value, pk.Value[1], pk.Value[0]) + keygen.uniformSamplerQ.Read(pk.Value[1].Q) + keygen.uniformSamplerP.Read(pk.Value[1].P) + ringQP.MulCoeffsMontgomeryAndSubLvl(levelQ, levelP, sk.Value, pk.Value[1], pk.Value[0]) return pk } @@ -133,29 +123,26 @@ func (keygen *keyGenerator) GenKeyPairSparse(hw int) (sk *SecretKey, pk *PublicK // GenRelinKey generates a new EvaluationKey that will be used to relinearize Ciphertexts during multiplication. func (keygen *keyGenerator) GenRelinearizationKey(sk *SecretKey, maxDegree int) (evk *RelinearizationKey) { - if keygen.ringQP == nil { + if keygen.params.PCount() == 0 { panic("modulus P is empty") } + levelQ := keygen.params.QCount() - 1 + levelP := keygen.params.PCount() - 1 + evk = new(RelinearizationKey) evk.Keys = make([]*SwitchingKey, maxDegree) for i := range evk.Keys { - evk.Keys[i] = NewSwitchingKey(keygen.params) + evk.Keys[i] = NewSwitchingKey(keygen.params, levelQ, levelP) } - keygen.polypool[0].CopyValues(sk.Value) // TODO Remove ? - - ringQP := keygen.ringQP - - keygen.polypool[1].CopyValues(sk.Value) + keygen.poolQP.Q.CopyValues(sk.Value.Q) + ringQ := keygen.params.RingQ() for i := 0; i < maxDegree; i++ { - ringQP.MulCoeffsMontgomery(keygen.polypool[1], sk.Value, keygen.polypool[1]) - keygen.newSwitchingKey(keygen.polypool[1], sk.Value, evk.Keys[i]) + ringQ.MulCoeffsMontgomery(keygen.poolQP.Q, sk.Value.Q, keygen.poolQP.Q) + keygen.genSwitchingKey(keygen.poolQP.Q, sk.Value, evk.Keys[i]) } - keygen.polypool[0].Zero() - keygen.polypool[1].Zero() - return } @@ -170,7 +157,7 @@ func (keygen *keyGenerator) GenRotationKeys(galEls []uint64, sk *SecretKey) (rks } func (keygen *keyGenerator) GenSwitchingKeyForRotationBy(k int, sk *SecretKey) (swk *SwitchingKey) { - swk = NewSwitchingKey(keygen.params) + swk = NewSwitchingKey(keygen.params, keygen.params.QCount()-1, keygen.params.PCount()-1) galElInv := keygen.params.GaloisElementForColumnRotationBy(-int(k)) keygen.genrotKey(sk.Value, galElInv, swk) return @@ -190,28 +177,14 @@ func (keygen *keyGenerator) GenRotationKeysForRotations(ks []int, includeConjuga return keygen.GenRotationKeys(galEls, sk) } -// GenSwitchingKey generates a new key-switching key, that will re-encrypt a Ciphertext encrypted under the input key into the output key. -func (keygen *keyGenerator) GenSwitchingKey(skInput, skOutput *SecretKey) (newevakey *SwitchingKey) { - - if keygen.params.PCount() == 0 { - panic("Cannot GenSwitchingKey: modulus P is empty") - } - - ring.CopyValues(skInput.Value, keygen.polypool[0]) - newevakey = NewSwitchingKey(keygen.params) - keygen.newSwitchingKey(keygen.polypool[0], skOutput.Value, newevakey) - keygen.polypool[0].Zero() - return -} - func (keygen *keyGenerator) GenSwitchingKeyForRowRotation(sk *SecretKey) (swk *SwitchingKey) { - swk = NewSwitchingKey(keygen.params) + swk = NewSwitchingKey(keygen.params, keygen.params.QCount()-1, keygen.params.PCount()-1) keygen.genrotKey(sk.Value, keygen.params.GaloisElementForRowRotation(), swk) return } func (keygen *keyGenerator) GenSwitchingKeyForGalois(galoisEl uint64, sk *SecretKey) (swk *SwitchingKey) { - swk = NewSwitchingKey(keygen.params) + swk = NewSwitchingKey(keygen.params, keygen.params.QCount()-1, keygen.params.PCount()-1) keygen.genrotKey(sk.Value, keygen.params.InverseGaloisElement(galoisEl), swk) return } @@ -221,41 +194,136 @@ func (keygen *keyGenerator) GenRotationKeysForInnerSum(sk *SecretKey) (rks *Rota return keygen.GenRotationKeys(keygen.params.GaloisElementsForRowInnerSum(), sk) } -func (keygen *keyGenerator) genrotKey(sk *ring.Poly, galEl uint64, swk *SwitchingKey) { +func (keygen *keyGenerator) genrotKey(sk PolyQP, galEl uint64, swk *SwitchingKey) { skIn := sk - skOut := keygen.polypool[1] + skOut := keygen.poolQP - index := ring.PermuteNTTIndex(galEl, uint64(keygen.ringQP.N)) - ring.PermuteNTTWithIndexLvl(keygen.params.QPCount()-1, skIn, index, skOut) + index := ring.PermuteNTTIndex(galEl, uint64(keygen.params.N())) + ring.PermuteNTTWithIndexLvl(keygen.params.QCount()-1, skIn.Q, index, skOut.Q) + ring.PermuteNTTWithIndexLvl(keygen.params.PCount()-1, skIn.P, index, skOut.P) - keygen.newSwitchingKey(skIn, skOut, swk) - - keygen.polypool[0].Zero() - keygen.polypool[1].Zero() + keygen.genSwitchingKey(skIn.Q, skOut, swk) } -func (keygen *keyGenerator) newSwitchingKey(skIn, skOut *ring.Poly, swk *SwitchingKey) { +// GenSwitchingKey generates a new key-switching key, that will re-encrypt a Ciphertext encrypted under the input key into the output key. +// If the ringDegree(skOutput) > ringDegree(skInput), generates [-a*SkOut + w*P*skIn_{Y^{N/n}} + e, a] in X^{N}. +// If the ringDegree(skOutput) < ringDegree(skInput), generates [-a*skOut_{Y^{N/n}} + w*P*skIn + e_{N}, a_{N}] in X^{N}. +// Else generates [-a*skOut + w*P*skIn + e, a] in X^{N}. +// The output switching key is always given in max(N, n) and in the moduli of the output switching key. +// When key-switching a ciphertext from Y^{N/n} to X^{N}, the ciphertext must first be mapped to X^{N} +// using SwitchCiphertextRingDegreeNTT(ctSmallDim, nil, ctLargeDim). +// When key-switching a ciphertext from X^{N} to Y^{N/n}, the output of the key-switch is in still X^{N} and +// must be mapped Y^{N/n} using SwitchCiphertextRingDegreeNTT(ctLargeDim, ringQLargeDim, ctSmallDim). +func (keygen *keyGenerator) GenSwitchingKey(skInput, skOutput *SecretKey) (swk *SwitchingKey) { - ringQP := keygen.ringQP + if keygen.params.PCount() == 0 { + panic("Cannot GenSwitchingKey: modulus P is empty") + } + + swk = NewSwitchingKey(keygen.params, skOutput.Value.Q.Level(), skOutput.Value.P.Level()) + + // n -> N + if len(skInput.Value.Q.Coeffs[0]) > len(skOutput.Value.Q.Coeffs[0]) { + + ring.MapSmallDimensionToLargerDimensionNTT(skOutput.Value.Q, keygen.poolQP.Q) + ring.MapSmallDimensionToLargerDimensionNTT(skOutput.Value.P, keygen.poolQP.P) + keygen.genSwitchingKey(skInput.Value.Q, keygen.poolQP, swk) + // N -> N or N -> n + } else { + + ring.MapSmallDimensionToLargerDimensionNTT(skInput.Value.Q, keygen.poolQ) + + if skInput.Value.Q.Level() < skOutput.Value.Q.Level() { + + ringQ := keygen.params.RingQ() + + ringQ.InvNTTLvl(0, keygen.poolQ, keygen.poolQP.Q) + ringQ.InvMFormLvl(0, keygen.poolQP.Q, keygen.poolQP.Q) + + Q := ringQ.Modulus[0] + QHalf := Q >> 1 + + polQ := keygen.poolQP.Q + polP := keygen.poolQ + var sign uint64 + for j := 0; j < ringQ.N; j++ { + + coeff := polQ.Coeffs[0][j] + + sign = 1 + if coeff > QHalf { + coeff = Q - coeff + sign = 0 + } + + for i := skInput.Value.Q.Level() + 1; i < skOutput.Value.Q.Level()+1; i++ { + polP.Coeffs[i][j] = (coeff * sign) | (ringQ.Modulus[i]-coeff)*(sign^1) + } + } + + for i := skInput.Value.Q.Level() + 1; i < skOutput.Value.Q.Level()+1; i++ { + ring.NTT(polP.Coeffs[i], polP.Coeffs[i], ringQ.N, ringQ.NttPsi[i], ringQ.Modulus[i], ringQ.MredParams[i], ringQ.BredParams[i]) + ring.MFormVec(polP.Coeffs[i], polP.Coeffs[i], ringQ.Modulus[i], ringQ.BredParams[i]) + } + } + + keygen.genSwitchingKey(keygen.poolQ, skOutput.Value, swk) + } + + return +} + +// genSecretKeyFromSampler generates a new SecretKey sampled from the provided Sampler. +func (keygen *keyGenerator) genSecretKeyFromSampler(sampler ring.Sampler) *SecretKey { + ringQP := keygen.params.RingQP() + sk := new(SecretKey) + sk.Value = ringQP.NewPoly() + levelQ, levelP := keygen.params.QCount()-1, keygen.params.PCount()-1 + sampler.Read(sk.Value.Q) + ringQP.ExtendBasisSmallNormAndCenter(sk.Value.Q, levelP, nil, sk.Value.P) + ringQP.NTTLvl(levelQ, levelP, sk.Value, sk.Value) + ringQP.MFormLvl(levelQ, levelP, sk.Value, sk.Value) + return sk +} + +func (keygen *keyGenerator) genSwitchingKey(skIn *ring.Poly, skOut PolyQP, swk *SwitchingKey) { + + ringQ := keygen.params.RingQ() + ringQP := keygen.params.RingQP() + + levelQ := len(swk.Value[0][0].Q.Coeffs) - 1 + levelP := len(swk.Value[0][0].P.Coeffs) - 1 + + var pBigInt *big.Int + if levelP == keygen.params.PCount()-1 { + pBigInt = keygen.params.RingP().ModulusBigint + } else { + P := keygen.params.RingP().Modulus + pBigInt = new(big.Int).SetUint64(P[0]) + for i := 1; i < levelP+1; i++ { + pBigInt.Mul(pBigInt, ring.NewUint(P[i])) + } + } // Computes P * skIn - ringQP.MulScalarBigint(skIn, keygen.pBigInt, keygen.polypool[0]) + ringQ.MulScalarBigintLvl(levelQ, skIn, pBigInt, keygen.poolQ) - alpha := keygen.params.PCount() - beta := keygen.params.Beta() + alpha := levelP + 1 + beta := int(math.Ceil(float64(levelQ+1) / float64(levelP+1))) var index int for i := 0; i < beta; i++ { // e - - keygen.gaussianSampler.Read(swk.Value[i][0]) - ringQP.NTTLazy(swk.Value[i][0], swk.Value[i][0]) - ringQP.MForm(swk.Value[i][0], swk.Value[i][0]) + keygen.gaussianSamplerQ.ReadLvl(levelQ, swk.Value[i][0].Q) + ringQP.ExtendBasisSmallNormAndCenter(swk.Value[i][0].Q, levelP, nil, swk.Value[i][0].P) + ringQP.NTTLazyLvl(levelQ, levelP, swk.Value[i][0], swk.Value[i][0]) + ringQP.MFormLvl(levelQ, levelP, swk.Value[i][0], swk.Value[i][0]) // a (since a is uniform, we consider we already sample it in the NTT and Montgomery domain) - keygen.uniformSampler.Read(swk.Value[i][1]) + keygen.uniformSamplerQ.ReadLvl(levelQ, swk.Value[i][1].Q) + keygen.uniformSamplerP.ReadLvl(levelP, swk.Value[i][1].P) // e + (skIn * P) * (q_star * q_tild) mod QP // @@ -268,21 +336,21 @@ func (keygen *keyGenerator) newSwitchingKey(skIn, skOut *ring.Poly, swk *Switchi index = i*alpha + j - qi := ringQP.Modulus[index] - p0tmp := keygen.polypool[0].Coeffs[index] - p1tmp := swk.Value[i][0].Coeffs[index] - - for w := 0; w < ringQP.N; w++ { - p1tmp[w] = ring.CRed(p1tmp[w]+p0tmp[w], qi) + // It handles the case where nb pj does not divide nb qi + if index >= levelQ+1 { + break } - // It handles the case where nb pj does not divide nb qi - if index >= keygen.params.QCount() { - break + qi := ringQ.Modulus[index] + p0tmp := keygen.poolQ.Coeffs[index] + p1tmp := swk.Value[i][0].Q.Coeffs[index] + + for w := 0; w < ringQ.N; w++ { + p1tmp[w] = ring.CRed(p1tmp[w]+p0tmp[w], qi) } } // (skIn * P) * (q_star * q_tild) - a * skOut + e mod QP - ringQP.MulCoeffsMontgomeryAndSub(swk.Value[i][1], skOut, swk.Value[i][0]) + ringQP.MulCoeffsMontgomeryAndSubLvl(levelQ, levelP, swk.Value[i][1], skOut, swk.Value[i][0]) } } diff --git a/rlwe/keys.go b/rlwe/keys.go index 4f5e4f49..c1d4249f 100644 --- a/rlwe/keys.go +++ b/rlwe/keys.go @@ -1,24 +1,22 @@ package rlwe import ( - "encoding/binary" - - "github.com/ldsec/lattigo/v2/ring" + "math" ) // SecretKey is a type for generic RLWE secret keys. type SecretKey struct { - Value *ring.Poly + Value PolyQP } // PublicKey is a type for generic RLWE public keys. type PublicKey struct { - Value [2]*ring.Poly + Value [2]PolyQP } // SwitchingKey is a type for generic RLWE public switching keys. type SwitchingKey struct { - Value [][2]*ring.Poly + Value [][2]PolyQP } // RelinearizationKey is a type for generic RLWE public relinearization keys. It stores a slice with a @@ -43,17 +41,12 @@ type EvaluationKey struct { // NewSecretKey generates a new SecretKey with zero values. func NewSecretKey(params Parameters) *SecretKey { - - sk := new(SecretKey) - sk.Value = ring.NewPoly(params.N(), params.QPCount()) - return sk + return &SecretKey{Value: params.RingQP().NewPoly()} } // NewPublicKey returns a new PublicKey with zero values. func NewPublicKey(params Parameters) (pk *PublicKey) { - ringDegree := params.N() - moduliCount := params.QPCount() - return &PublicKey{Value: [2]*ring.Poly{ring.NewPoly(ringDegree, moduliCount), ring.NewPoly(ringDegree, moduliCount)}} + return &PublicKey{Value: [2]PolyQP{params.RingQP().NewPoly(), params.RingQP().NewPoly()}} } // Equals checks two PublicKey struct for equality. @@ -61,8 +54,7 @@ func (pk *PublicKey) Equals(other *PublicKey) bool { if pk == other { return true } - nilVal := [2]*ring.Poly{} - return pk.Value != nilVal && other.Value != nilVal && pk.Value[0].Equals(other.Value[0]) && pk.Value[1].Equals(other.Value[1]) + return pk.Value[0].Equals(other.Value[0]) && pk.Value[1].Equals(other.Value[1]) } // NewRotationKeySet returns a new RotationKeySet with pre-allocated switching keys for each distinct galoisElement value. @@ -70,7 +62,7 @@ func NewRotationKeySet(params Parameters, galoisElement []uint64) (rotKey *Rotat rotKey = new(RotationKeySet) rotKey.Keys = make(map[uint64]*SwitchingKey, len(galoisElement)) for _, galEl := range galoisElement { - rotKey.Keys[galEl] = NewSwitchingKey(params) + rotKey.Keys[galEl] = NewSwitchingKey(params, params.QCount()-1, params.PCount()-1) } return } @@ -78,22 +70,21 @@ func NewRotationKeySet(params Parameters, galoisElement []uint64) (rotKey *Rotat // GetRotationKey return the rotation key for the given galois element or nil if such key is not in the set. The // second argument is true iff the first one is non-nil. func (rtks *RotationKeySet) GetRotationKey(galoisEl uint64) (*SwitchingKey, bool) { + if rtks.Keys == nil { + return nil, false + } rotKey, inSet := rtks.Keys[galoisEl] return rotKey, inSet } // NewSwitchingKey returns a new public switching key with pre-allocated zero-value -func NewSwitchingKey(params Parameters) *SwitchingKey { - ringDegree := params.N() - moduliCount := params.QPCount() - decompSize := params.Beta() - +func NewSwitchingKey(params Parameters, levelQ, levelP int) *SwitchingKey { + decompSize := int(math.Ceil(float64(levelQ+1) / float64(levelP+1))) swk := new(SwitchingKey) - swk.Value = make([][2]*ring.Poly, int(decompSize)) - + swk.Value = make([][2]PolyQP, int(decompSize)) for i := 0; i < decompSize; i++ { - swk.Value[i][0] = ring.NewPoly(ringDegree, moduliCount) - swk.Value[i][1] = ring.NewPoly(ringDegree, moduliCount) + swk.Value[i][0] = params.RingQP().NewPolyLvl(levelQ, levelP) + swk.Value[i][1] = params.RingQP().NewPolyLvl(levelQ, levelP) } return swk @@ -107,105 +98,26 @@ func NewRelinKey(params Parameters, maxRelinDegree int) (evakey *Relinearization evakey.Keys = make([]*SwitchingKey, maxRelinDegree) for d := 0; d < maxRelinDegree; d++ { - evakey.Keys[d] = NewSwitchingKey(params) + evakey.Keys[d] = NewSwitchingKey(params, params.QCount()-1, params.PCount()-1) } return } -// GetDataLen returns the length in bytes of the target SecretKey. -func (sk *SecretKey) GetDataLen(WithMetadata bool) (dataLen int) { - return sk.Value.GetDataLen(WithMetadata) -} - -// MarshalBinary encodes a secret key in a byte slice. -func (sk *SecretKey) MarshalBinary() (data []byte, err error) { - - data = make([]byte, sk.GetDataLen(true)) - - if _, err = sk.Value.WriteTo(data); err != nil { - return nil, err - } - - return data, nil -} - -// UnmarshalBinary decodes a previously marshaled SecretKey in the target SecretKey. -func (sk *SecretKey) UnmarshalBinary(data []byte) (err error) { - - sk.Value = new(ring.Poly) - - if _, err = sk.Value.DecodePolyNew(data); err != nil { - return err - } - - return nil -} - // CopyNew creates a deep copy of the receiver secret key and returns it. func (sk *SecretKey) CopyNew() *SecretKey { - if sk == nil || sk.Value == nil { + if sk == nil { return nil } return &SecretKey{sk.Value.CopyNew()} } -// GetDataLen returns the length in bytes of the target PublicKey. -func (pk *PublicKey) GetDataLen(WithMetadata bool) (dataLen int) { - - for _, el := range pk.Value { - dataLen += el.GetDataLen(WithMetadata) - } - - return -} - -// MarshalBinary encodes a PublicKey in a byte slice. -func (pk *PublicKey) MarshalBinary() (data []byte, err error) { - - dataLen := pk.GetDataLen(true) - - data = make([]byte, dataLen) - - var pointer, inc int - - if inc, err = pk.Value[0].WriteTo(data[pointer:]); err != nil { - return nil, err - } - - if _, err = pk.Value[1].WriteTo(data[pointer+inc:]); err != nil { - return nil, err - } - - return data, err - -} - -// UnmarshalBinary decodes a previously marshaled PublicKey in the target PublicKey. -func (pk *PublicKey) UnmarshalBinary(data []byte) (err error) { - - var pointer, inc int - - pk.Value[0] = new(ring.Poly) - pk.Value[1] = new(ring.Poly) - - if inc, err = pk.Value[0].DecodePolyNew(data[pointer:]); err != nil { - return err - } - - if _, err = pk.Value[1].DecodePolyNew(data[pointer+inc:]); err != nil { - return err - } - - return nil -} - // CopyNew creates a deep copy of the receiver PublicKey and returns it. func (pk *PublicKey) CopyNew() *PublicKey { - if pk == nil || pk.Value[0] == nil || pk.Value[1] == nil { + if pk == nil { return nil } - return &PublicKey{[2]*ring.Poly{pk.Value[0].CopyNew(), pk.Value[1].CopyNew()}} + return &PublicKey{[2]PolyQP{pk.Value[0].CopyNew(), pk.Value[1].CopyNew()}} } // Equals checks two RelinearizationKeys for equality. @@ -227,63 +139,6 @@ func (rlk *RelinearizationKey) Equals(other *RelinearizationKey) bool { return true } -// GetDataLen returns the length in bytes of the target EvaluationKey. -func (rlk *RelinearizationKey) GetDataLen(WithMetadata bool) (dataLen int) { - - if WithMetadata { - dataLen++ - } - - for _, evakey := range rlk.Keys { - dataLen += (*SwitchingKey)(evakey).GetDataLen(WithMetadata) - } - - return -} - -// MarshalBinary encodes an EvaluationKey key in a byte slice. -func (rlk *RelinearizationKey) MarshalBinary() (data []byte, err error) { - - var pointer int - - dataLen := rlk.GetDataLen(true) - - data = make([]byte, dataLen) - - data[0] = uint8(len(rlk.Keys)) - - pointer++ - - for _, evakey := range rlk.Keys { - - if pointer, err = (*SwitchingKey)(evakey).encode(pointer, data); err != nil { - return nil, err - } - } - - return data, nil -} - -// UnmarshalBinary decodes a previously marshaled EvaluationKey in the target EvaluationKey. -func (rlk *RelinearizationKey) UnmarshalBinary(data []byte) (err error) { - - deg := int(data[0]) - - rlk.Keys = make([]*SwitchingKey, deg) - - pointer := 1 - var inc int - for i := 0; i < deg; i++ { - rlk.Keys[i] = new(SwitchingKey) - if inc, err = rlk.Keys[i].decode(data[pointer:]); err != nil { - return err - } - pointer += inc - } - - return nil -} - // CopyNew creates a deep copy of the receiver RelinearizationKey and returns it. func (rlk *RelinearizationKey) CopyNew() *RelinearizationKey { if rlk == nil || len(rlk.Keys) == 0 { @@ -308,118 +163,25 @@ func (swk *SwitchingKey) Equals(other *SwitchingKey) bool { return false } for i := range swk.Value { - if !(swk.Value[i][0].Equals(other.Value[i][0]) && swk.Value[i][1].Equals(other.Value[i][1])) { + if !((&PublicKey{Value: swk.Value[i]}).Equals(&PublicKey{Value: other.Value[i]})) { return false } } return true } -// GetDataLen returns the length in bytes of the target SwitchingKey. -func (swk *SwitchingKey) GetDataLen(WithMetadata bool) (dataLen int) { - - if WithMetadata { - dataLen++ - } - - for j := uint64(0); j < uint64(len(swk.Value)); j++ { - dataLen += swk.Value[j][0].GetDataLen(WithMetadata) - dataLen += swk.Value[j][1].GetDataLen(WithMetadata) - } - - return -} - -// MarshalBinary encodes an SwitchingKey in a byte slice. -func (swk *SwitchingKey) MarshalBinary() (data []byte, err error) { - - data = make([]byte, swk.GetDataLen(true)) - - if _, err = swk.encode(0, data); err != nil { - return nil, err - } - - return data, nil -} - -// UnmarshalBinary decode a previously marshaled SwitchingKey in the target SwitchingKey. -func (swk *SwitchingKey) UnmarshalBinary(data []byte) (err error) { - - if _, err = swk.decode(data); err != nil { - return err - } - - return nil -} - // CopyNew creates a deep copy of the receiver SwitchingKey and returns it. func (swk *SwitchingKey) CopyNew() *SwitchingKey { if swk == nil || len(swk.Value) == 0 { return nil } - swkb := &SwitchingKey{Value: make([][2]*ring.Poly, len(swk.Value))} + swkb := &SwitchingKey{Value: make([][2]PolyQP, len(swk.Value))} for i, el := range swk.Value { - swkb.Value[i] = [2]*ring.Poly{el[0].CopyNew(), el[1].CopyNew()} + swkb.Value[i] = [2]PolyQP{el[0].CopyNew(), el[1].CopyNew()} } return swkb } -func (swk *SwitchingKey) encode(pointer int, data []byte) (int, error) { - - var err error - var inc int - - data[pointer] = uint8(len(swk.Value)) - - pointer++ - - for j := 0; j < len(swk.Value); j++ { - - if inc, err = swk.Value[j][0].WriteTo(data[pointer : pointer+swk.Value[j][0].GetDataLen(true)]); err != nil { - return pointer, err - } - - pointer += inc - - if inc, err = swk.Value[j][1].WriteTo(data[pointer : pointer+swk.Value[j][1].GetDataLen(true)]); err != nil { - return pointer, err - } - - pointer += inc - } - - return pointer, nil -} - -func (swk *SwitchingKey) decode(data []byte) (pointer int, err error) { - - decomposition := int(data[0]) - - pointer = 1 - - swk.Value = make([][2]*ring.Poly, decomposition) - - var inc int - - for j := 0; j < decomposition; j++ { - - swk.Value[j][0] = new(ring.Poly) - if inc, err = swk.Value[j][0].DecodePolyNew(data[pointer:]); err != nil { - return pointer, err - } - pointer += inc - - swk.Value[j][1] = new(ring.Poly) - if inc, err = swk.Value[j][1].DecodePolyNew(data[pointer:]); err != nil { - return pointer, err - } - pointer += inc - - } - - return pointer, nil -} - // Equals checks to RotationKeySets for equality. func (rtks *RotationKeySet) Equals(other *RotationKeySet) bool { if rtks == other { @@ -431,7 +193,12 @@ func (rtks *RotationKeySet) Equals(other *RotationKeySet) bool { if len(rtks.Keys) != len(other.Keys) { return false } - return rtks.Includes(other) + for galEl, otherKey := range other.Keys { + if key, inSet := rtks.Keys[galEl]; !inSet || !otherKey.Equals(key) { + return false + } + } + return true } // Includes checks whether the receiver RotationKeySet includes the given other RotationKeySet. @@ -446,57 +213,3 @@ func (rtks *RotationKeySet) Includes(other *RotationKeySet) bool { } return true } - -// GetDataLen returns the length in bytes of the target RotationKeys. -func (rtks *RotationKeySet) GetDataLen(WithMetaData bool) (dataLen int) { - for _, k := range rtks.Keys { - if WithMetaData { - dataLen += 4 - } - dataLen += k.GetDataLen(WithMetaData) - } - return -} - -// MarshalBinary encodes a RotationKeys struct in a byte slice. -func (rtks *RotationKeySet) MarshalBinary() (data []byte, err error) { - - data = make([]byte, rtks.GetDataLen(true)) - - pointer := int(0) - - for galEL, key := range rtks.Keys { - - binary.BigEndian.PutUint32(data[pointer:pointer+4], uint32(galEL)) - pointer += 4 - - if pointer, err = key.encode(pointer, data); err != nil { - return nil, err - } - } - - return data, nil -} - -// UnmarshalBinary decodes a previously marshaled RotationKeys in the target RotationKeys. -func (rtks *RotationKeySet) UnmarshalBinary(data []byte) (err error) { - - rtks.Keys = make(map[uint64]*SwitchingKey) - - for len(data) > 0 { - - galEl := uint64(binary.BigEndian.Uint32(data)) - data = data[4:] - - swk := new(SwitchingKey) - var inc int - if inc, err = swk.decode(data); err != nil { - return err - } - data = data[inc:] - rtks.Keys[galEl] = swk - - } - - return nil -} diff --git a/rlwe/keyswitch.go b/rlwe/keyswitch.go index baeb4243..1ab1b0c0 100644 --- a/rlwe/keyswitch.go +++ b/rlwe/keyswitch.go @@ -1,8 +1,9 @@ package rlwe import ( - "github.com/ldsec/lattigo/v2/ring" "math" + + "github.com/ldsec/lattigo/v2/ring" ) // KeySwitcher is a struct for RLWE key-switching. @@ -16,31 +17,24 @@ type KeySwitcher struct { type keySwitcherBuffer struct { // PoolQ[0]/PoolP[0] : on the fly decomp(c2) // PoolQ[1-5]/PoolP[1-5] : available - PoolQ [6]*ring.Poly - PoolP [6]*ring.Poly - PoolInvNTT *ring.Poly - PoolDecompQ []*ring.Poly // Memory pool for the basis extension in hoisting - PoolDecompP []*ring.Poly // Memory pool for the basis extension in hoisting + Pool [6]PolyQP + PoolInvNTT *ring.Poly + PoolDecompQP []PolyQP // Memory pool for the basis extension in hoisting } func newKeySwitcherBuffer(params Parameters) *keySwitcherBuffer { buff := new(keySwitcherBuffer) beta := params.Beta() - ringQ := params.RingQ() - ringP := params.RingP() + ringQP := params.RingQP() - buff.PoolQ = [6]*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly()} - buff.PoolP = [6]*ring.Poly{ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly()} + buff.Pool = [6]PolyQP{ringQP.NewPoly(), ringQP.NewPoly(), ringQP.NewPoly(), ringQP.NewPoly(), ringQP.NewPoly(), ringQP.NewPoly()} - buff.PoolInvNTT = ringQ.NewPoly() - - buff.PoolDecompQ = make([]*ring.Poly, beta) - buff.PoolDecompP = make([]*ring.Poly, beta) + buff.PoolInvNTT = params.RingQ().NewPoly() + buff.PoolDecompQP = make([]PolyQP, beta) for i := 0; i < beta; i++ { - buff.PoolDecompQ[i] = ringQ.NewPoly() - buff.PoolDecompP[i] = ringP.NewPoly() + buff.PoolDecompQP[i] = ringQP.NewPoly() } return buff @@ -51,7 +45,7 @@ func NewKeySwitcher(params Parameters) *KeySwitcher { ks := new(KeySwitcher) ks.Parameters = ¶ms ks.Baseconverter = ring.NewFastBasisExtender(params.RingQ(), params.RingP()) - ks.Decomposer = ring.NewDecomposer(params.RingQ().Modulus, params.RingP().Modulus) + ks.Decomposer = ring.NewDecomposer(params.RingQ(), params.RingP()) ks.keySwitcherBuffer = newKeySwitcherBuffer(params) return ks } @@ -68,21 +62,23 @@ func (ks *KeySwitcher) ShallowCopy() *KeySwitcher { // SwitchKeysInPlace applies the general key-switching procedure of the form [c0 + cx*evakey[0], c1 + cx*evakey[1]] // Will return the result in the same NTT domain as the input cx. -func (ks *KeySwitcher) SwitchKeysInPlace(level int, cx *ring.Poly, evakey *SwitchingKey, p0, p1 *ring.Poly) { - ks.SwitchKeysInPlaceNoModDown(level, cx, evakey, p0, ks.PoolP[1], p1, ks.PoolP[2]) +func (ks *KeySwitcher) SwitchKeysInPlace(levelQ int, cx *ring.Poly, evakey *SwitchingKey, p0, p1 *ring.Poly) { + ks.SwitchKeysInPlaceNoModDown(levelQ, cx, evakey, p0, ks.Pool[1].P, p1, ks.Pool[2].P) + + levelP := len(evakey.Value[0][0].P.Coeffs) - 1 if cx.IsNTT { - ks.Baseconverter.ModDownSplitNTTPQ(level, p0, ks.PoolP[1], p0) - ks.Baseconverter.ModDownSplitNTTPQ(level, p1, ks.PoolP[2], p1) + ks.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, p0, ks.Pool[1].P, p0) + ks.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, p1, ks.Pool[2].P, p1) } else { - ks.ringQ.InvNTTLazyLvl(level, p0, p0) - ks.ringQ.InvNTTLazyLvl(level, p1, p1) - ks.ringP.InvNTTLazy(ks.PoolP[1], ks.PoolP[1]) - ks.ringP.InvNTTLazy(ks.PoolP[2], ks.PoolP[2]) + ks.ringQ.InvNTTLazyLvl(levelQ, p0, p0) + ks.ringQ.InvNTTLazyLvl(levelQ, p1, p1) + ks.ringP.InvNTTLazyLvl(levelP, ks.Pool[1].P, ks.Pool[1].P) + ks.ringP.InvNTTLazyLvl(levelP, ks.Pool[2].P, ks.Pool[2].P) - ks.Baseconverter.ModDownSplitPQ(level, p0, ks.PoolP[1], p0) - ks.Baseconverter.ModDownSplitPQ(level, p1, ks.PoolP[2], p1) + ks.Baseconverter.ModDownQPtoQ(levelQ, levelP, p0, ks.Pool[1].P, p0) + ks.Baseconverter.ModDownQPtoQ(levelQ, levelP, p1, ks.Pool[2].P, p1) } } @@ -90,7 +86,7 @@ func (ks *KeySwitcher) SwitchKeysInPlace(level int, cx *ring.Poly, evakey *Switc // Expects the IsNTT flag of c2 to correctly reflect the domain of c2. // PoolDecompQ and PoolDecompQ are vectors of polynomials (mod Q and mod P) that store the // special RNS decomposition of c2 (in the NTT domain) -func (ks *KeySwitcher) DecomposeNTT(levelQ int, c2 *ring.Poly, PoolDecompQ, PoolDecompP []*ring.Poly) { +func (ks *KeySwitcher) DecomposeNTT(levelQ, levelP, alpha int, c2 *ring.Poly, PoolDecomp []PolyQP) { ringQ := ks.RingQ() @@ -106,29 +102,28 @@ func (ks *KeySwitcher) DecomposeNTT(levelQ int, c2 *ring.Poly, PoolDecompQ, Pool ringQ.NTTLvl(levelQ, polyInvNTT, polyNTT) } - alpha := ks.Parameters.PCount() - beta := int(math.Ceil(float64(levelQ+1) / float64(alpha))) + beta := int(math.Ceil(float64(levelQ+1) / float64(levelP+1))) for i := 0; i < beta; i++ { - ks.DecomposeSingleNTT(levelQ, i, polyNTT, polyInvNTT, PoolDecompQ[i], PoolDecompP[i]) + ks.DecomposeSingleNTT(levelQ, levelP, alpha, i, polyNTT, polyInvNTT, PoolDecomp[i].Q, PoolDecomp[i].P) } } // DecomposeSingleNTT takes the input polynomial c2 (c2NTT and c2InvNTT, respectively in the NTT and out of the NTT domain) // modulo q_alpha_beta, and returns the result on c2QiQ are c2QiP the receiver polynomials // respectively mod Q and mod P (in the NTT domain) -func (ks *KeySwitcher) DecomposeSingleNTT(level, beta int, c2NTT, c2InvNTT, c2QiQ, c2QiP *ring.Poly) { +func (ks *KeySwitcher) DecomposeSingleNTT(levelQ, levelP, alpha, beta int, c2NTT, c2InvNTT, c2QiQ, c2QiP *ring.Poly) { ringQ := ks.RingQ() ringP := ks.RingP() - ks.Decomposer.DecomposeAndSplit(level, beta, c2InvNTT, c2QiQ, c2QiP) + ks.Decomposer.DecomposeAndSplit(levelQ, levelP, alpha, beta, c2InvNTT, c2QiQ, c2QiP) - p0idxst := beta * len(ringP.Modulus) - p0idxed := p0idxst + ks.Decomposer.Xalpha()[beta] + p0idxst := beta * (levelP + 1) + p0idxed := p0idxst + 1 // c2_qi = cx mod qi mod qi - for x := 0; x < level+1; x++ { + for x := 0; x < levelQ+1; x++ { if p0idxst <= x && x < p0idxed { copy(c2QiQ.Coeffs[x], c2NTT.Coeffs[x]) } else { @@ -136,7 +131,7 @@ func (ks *KeySwitcher) DecomposeSingleNTT(level, beta int, c2NTT, c2InvNTT, c2Qi } } // c2QiP = c2 mod qi mod pj - ringP.NTTLazy(c2QiP, c2QiP) + ringP.NTTLazyLvl(levelP, c2QiP, c2QiP) } // SwitchKeysInPlaceNoModDown applies the key-switch to the polynomial cx : @@ -145,83 +140,73 @@ func (ks *KeySwitcher) DecomposeSingleNTT(level, beta int, c2NTT, c2InvNTT, c2Qi // pool3 = dot(decomp(cx) * evakey[1]) mod QP (encrypted input is multiplied by P factor) // // Expects the flag IsNTT of cx to correctly reflect the domain of cx. -func (ks *KeySwitcher) SwitchKeysInPlaceNoModDown(level int, cx *ring.Poly, evakey *SwitchingKey, pool2Q, pool2P, pool3Q, pool3P *ring.Poly) { +func (ks *KeySwitcher) SwitchKeysInPlaceNoModDown(levelQ int, cx *ring.Poly, evakey *SwitchingKey, pool2Q, pool2P, pool3Q, pool3P *ring.Poly) { var reduce int - ringQ := ks.ringQ - ringP := ks.ringP + ringQ := ks.RingQ() + ringP := ks.RingP() + ringQP := ks.RingQP() - c2QiQ := ks.PoolQ[0] - c2QiP := ks.PoolP[0] + c2QP := ks.Pool[0] var cxNTT, cxInvNTT *ring.Poly if cx.IsNTT { cxNTT = cx cxInvNTT = ks.PoolInvNTT - ringQ.InvNTTLvl(level, cxNTT, cxInvNTT) + ringQ.InvNTTLvl(levelQ, cxNTT, cxInvNTT) } else { cxNTT = ks.PoolInvNTT cxInvNTT = cx - ringQ.NTTLvl(level, cxInvNTT, cxNTT) + ringQ.NTTLvl(levelQ, cxInvNTT, cxNTT) } - evakey0Q := new(ring.Poly) - evakey1Q := new(ring.Poly) - evakey0P := new(ring.Poly) - evakey1P := new(ring.Poly) + pool2QP := PolyQP{pool2Q, pool2P} + pool3QP := PolyQP{pool3Q, pool3P} reduce = 0 - alpha := len(ringP.Modulus) - beta := int(math.Ceil(float64(level+1) / float64(alpha))) + alpha := len(evakey.Value[0][0].P.Coeffs) + levelP := alpha - 1 + beta := int(math.Ceil(float64(levelQ+1) / float64(levelP+1))) - QiOverF := ks.Parameters.QiOverflowMargin(level) >> 1 - PiOverF := ks.Parameters.PiOverflowMargin() >> 1 + QiOverF := ks.Parameters.QiOverflowMargin(levelQ) >> 1 + PiOverF := ks.Parameters.PiOverflowMargin(levelP) >> 1 // Key switching with CRT decomposition for the Qi for i := 0; i < beta; i++ { - ks.DecomposeSingleNTT(level, i, cxNTT, cxInvNTT, c2QiQ, c2QiP) - - evakey0Q.Coeffs = evakey.Value[i][0].Coeffs[:level+1] - evakey1Q.Coeffs = evakey.Value[i][1].Coeffs[:level+1] - evakey0P.Coeffs = evakey.Value[i][0].Coeffs[len(ringQ.Modulus):] - evakey1P.Coeffs = evakey.Value[i][1].Coeffs[len(ringQ.Modulus):] + ks.DecomposeSingleNTT(levelQ, levelP, alpha, i, cxNTT, cxInvNTT, c2QP.Q, c2QP.P) if i == 0 { - ringQ.MulCoeffsMontgomeryConstantLvl(level, evakey0Q, c2QiQ, pool2Q) - ringQ.MulCoeffsMontgomeryConstantLvl(level, evakey1Q, c2QiQ, pool3Q) - ringP.MulCoeffsMontgomeryConstant(evakey0P, c2QiP, pool2P) - ringP.MulCoeffsMontgomeryConstant(evakey1P, c2QiP, pool3P) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0], c2QP, pool2QP) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][1], c2QP, pool3QP) } else { - ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(level, evakey0Q, c2QiQ, pool2Q) - ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(level, evakey1Q, c2QiQ, pool3Q) - ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey0P, c2QiP, pool2P) - ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey1P, c2QiP, pool3P) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0], c2QP, pool2QP) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][1], c2QP, pool3QP) } if reduce%QiOverF == QiOverF-1 { - ringQ.ReduceLvl(level, pool2Q, pool2Q) - ringQ.ReduceLvl(level, pool3Q, pool3Q) + ringQ.ReduceLvl(levelQ, pool2QP.Q, pool2QP.Q) + ringQ.ReduceLvl(levelQ, pool3QP.Q, pool3QP.Q) } if reduce%PiOverF == PiOverF-1 { - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + ringP.ReduceLvl(levelP, pool2QP.P, pool2QP.P) + ringP.ReduceLvl(levelP, pool3QP.P, pool3QP.P) } reduce++ } if reduce%QiOverF != 0 { - ringQ.ReduceLvl(level, pool2Q, pool2Q) - ringQ.ReduceLvl(level, pool3Q, pool3Q) + ringQ.ReduceLvl(levelQ, pool2QP.Q, pool2QP.Q) + ringQ.ReduceLvl(levelQ, pool3QP.Q, pool3QP.Q) } if reduce%PiOverF != 0 { - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + ringP.ReduceLvl(levelP, pool2QP.P, pool2QP.P) + ringP.ReduceLvl(levelP, pool3QP.P, pool3QP.P) } } @@ -230,76 +215,69 @@ func (ks *KeySwitcher) SwitchKeysInPlaceNoModDown(level int, cx *ring.Poly, evak // // pool2 = dot(PoolDecompQ||PoolDecompP * evakey[0]) mod Q // pool3 = dot(PoolDecompQ||PoolDecompP * evakey[1]) mod Q -func (ks *KeySwitcher) KeyswitchHoisted(level int, PoolDecompQ, PoolDecompP []*ring.Poly, evakey *SwitchingKey, pool2Q, pool3Q, pool2P, pool3P *ring.Poly) { +func (ks *KeySwitcher) KeyswitchHoisted(levelQ int, PoolDecompQP []PolyQP, evakey *SwitchingKey, pool2Q, pool3Q, pool2P, pool3P *ring.Poly) { - ks.KeyswitchHoistedNoModDown(level, PoolDecompQ, PoolDecompP, evakey, pool2Q, pool3Q, pool2P, pool3P) + ks.KeyswitchHoistedNoModDown(levelQ, PoolDecompQP, evakey, pool2Q, pool3Q, pool2P, pool3P) + + levelP := len(evakey.Value[0][0].P.Coeffs) - 1 // Computes pool2Q = pool2Q/pool2P and pool3Q = pool3Q/pool3P - ks.Baseconverter.ModDownSplitNTTPQ(level, pool2Q, pool2P, pool2Q) - ks.Baseconverter.ModDownSplitNTTPQ(level, pool3Q, pool3P, pool3Q) + ks.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, pool2Q, pool2P, pool2Q) + ks.Baseconverter.ModDownQPtoQNTT(levelQ, levelP, pool3Q, pool3P, pool3Q) } // KeyswitchHoistedNoModDown applies the key-switch to the decomposed polynomial c2 mod QP (PoolDecompQ and PoolDecompP) // // pool2 = dot(PoolDecompQ||PoolDecompP * evakey[0]) mod QP // pool3 = dot(PoolDecompQ||PoolDecompP * evakey[1]) mod QP -func (ks *KeySwitcher) KeyswitchHoistedNoModDown(level int, PoolDecompQ, PoolDecompP []*ring.Poly, evakey *SwitchingKey, pool2Q, pool3Q, pool2P, pool3P *ring.Poly) { +func (ks *KeySwitcher) KeyswitchHoistedNoModDown(levelQ int, PoolDecompQP []PolyQP, evakey *SwitchingKey, pool2Q, pool3Q, pool2P, pool3P *ring.Poly) { - ringQ := ks.ringQ - ringP := ks.ringP + ringQ := ks.RingQ() + ringP := ks.RingP() + ringQP := ks.RingQP() - alpha := len(ringP.Modulus) - beta := int(math.Ceil(float64(level+1) / float64(alpha))) + pool2QP := PolyQP{pool2Q, pool2P} + pool3QP := PolyQP{pool3Q, pool3P} - evakey0Q := new(ring.Poly) - evakey1Q := new(ring.Poly) - evakey0P := new(ring.Poly) - evakey1P := new(ring.Poly) + alpha := len(evakey.Value[0][0].P.Coeffs) + levelP := alpha - 1 + beta := int(math.Ceil(float64(levelQ+1) / float64(alpha))) - QiOverF := ks.Parameters.QiOverflowMargin(level) >> 1 - PiOverF := ks.Parameters.PiOverflowMargin() >> 1 + QiOverF := ks.Parameters.QiOverflowMargin(levelQ) >> 1 + PiOverF := ks.Parameters.PiOverflowMargin(levelP) >> 1 // Key switching with CRT decomposition for the Qi var reduce int for i := 0; i < beta; i++ { - evakey0Q.Coeffs = evakey.Value[i][0].Coeffs[:level+1] - evakey1Q.Coeffs = evakey.Value[i][1].Coeffs[:level+1] - evakey0P.Coeffs = evakey.Value[i][0].Coeffs[len(ringQ.Modulus):] - evakey1P.Coeffs = evakey.Value[i][1].Coeffs[len(ringQ.Modulus):] - if i == 0 { - ringQ.MulCoeffsMontgomeryConstantLvl(level, evakey0Q, PoolDecompQ[i], pool2Q) - ringQ.MulCoeffsMontgomeryConstantLvl(level, evakey1Q, PoolDecompQ[i], pool3Q) - ringP.MulCoeffsMontgomeryConstant(evakey0P, PoolDecompP[i], pool2P) - ringP.MulCoeffsMontgomeryConstant(evakey1P, PoolDecompP[i], pool3P) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0], PoolDecompQP[i], pool2QP) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][1], PoolDecompQP[i], pool3QP) } else { - ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(level, evakey0Q, PoolDecompQ[i], pool2Q) - ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(level, evakey1Q, PoolDecompQ[i], pool3Q) - ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey0P, PoolDecompP[i], pool2P) - ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey1P, PoolDecompP[i], pool3P) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0], PoolDecompQP[i], pool2QP) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][1], PoolDecompQP[i], pool3QP) } if reduce%QiOverF == QiOverF-1 { - ringQ.ReduceLvl(level, pool2Q, pool2Q) - ringQ.ReduceLvl(level, pool3Q, pool3Q) + ringQ.ReduceLvl(levelQ, pool2QP.Q, pool2QP.Q) + ringQ.ReduceLvl(levelQ, pool3QP.Q, pool3QP.Q) } if reduce%PiOverF == PiOverF-1 { - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + ringP.ReduceLvl(levelP, pool2QP.P, pool2QP.P) + ringP.ReduceLvl(levelP, pool3QP.P, pool3QP.P) } reduce++ } if reduce%QiOverF != 0 { - ringQ.ReduceLvl(level, pool2Q, pool2Q) - ringQ.ReduceLvl(level, pool3Q, pool3Q) + ringQ.ReduceLvl(levelQ, pool2QP.Q, pool2QP.Q) + ringQ.ReduceLvl(levelQ, pool3QP.Q, pool3QP.Q) } if reduce%PiOverF != 0 { - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + ringP.ReduceLvl(levelP, pool2QP.P, pool2QP.P) + ringP.ReduceLvl(levelP, pool3QP.P, pool3QP.P) } } diff --git a/rlwe/marshaler.go b/rlwe/marshaler.go new file mode 100644 index 00000000..f85785b1 --- /dev/null +++ b/rlwe/marshaler.go @@ -0,0 +1,267 @@ +package rlwe + +import ( + "encoding/binary" + + "github.com/ldsec/lattigo/v2/ring" +) + +// GetDataLen returns the length in bytes of the target SecretKey. +func (sk *SecretKey) GetDataLen(WithMetadata bool) (dataLen int) { + return sk.Value.GetDataLen(WithMetadata) +} + +// MarshalBinary encodes a secret key in a byte slice. +func (sk *SecretKey) MarshalBinary() (data []byte, err error) { + data = make([]byte, sk.GetDataLen(true)) + if _, err = sk.Value.WriteTo(data); err != nil { + return nil, err + } + return +} + +// UnmarshalBinary decodes a previously marshaled SecretKey in the target SecretKey. +func (sk *SecretKey) UnmarshalBinary(data []byte) (err error) { + _, err = sk.Value.DecodePolyNew(data) + return +} + +// GetDataLen returns the length in bytes of the target PublicKey. +func (pk *PublicKey) GetDataLen(WithMetadata bool) (dataLen int) { + return pk.Value[0].GetDataLen(WithMetadata) + pk.Value[1].GetDataLen(WithMetadata) +} + +// MarshalBinary encodes a PublicKey in a byte slice. +func (pk *PublicKey) MarshalBinary() (data []byte, err error) { + data = make([]byte, pk.GetDataLen(true)) + var inc, pt int + if inc, err = pk.Value[0].WriteTo(data[pt:]); err != nil { + return nil, err + } + pt += inc + + if _, err = pk.Value[1].WriteTo(data[pt:]); err != nil { + return nil, err + } + + return +} + +// UnmarshalBinary decodes a previously marshaled PublicKey in the target PublicKey. +func (pk *PublicKey) UnmarshalBinary(data []byte) (err error) { + + var pt, inc int + if inc, err = pk.Value[0].DecodePolyNew(data[pt:]); err != nil { + return + } + pt += inc + + if _, err = pk.Value[1].DecodePolyNew(data[pt:]); err != nil { + return + } + + return +} + +// GetDataLen returns the length in bytes of the target EvaluationKey. +func (rlk *RelinearizationKey) GetDataLen(WithMetadata bool) (dataLen int) { + + if WithMetadata { + dataLen++ + } + + for _, evakey := range rlk.Keys { + dataLen += (*SwitchingKey)(evakey).GetDataLen(WithMetadata) + } + + return +} + +// MarshalBinary encodes an EvaluationKey key in a byte slice. +func (rlk *RelinearizationKey) MarshalBinary() (data []byte, err error) { + + var pointer int + + dataLen := rlk.GetDataLen(true) + + data = make([]byte, dataLen) + + data[0] = uint8(len(rlk.Keys)) + + pointer++ + + for _, evakey := range rlk.Keys { + + if pointer, err = (*SwitchingKey)(evakey).encode(pointer, data); err != nil { + return nil, err + } + } + + return data, nil +} + +// UnmarshalBinary decodes a previously marshaled EvaluationKey in the target EvaluationKey. +func (rlk *RelinearizationKey) UnmarshalBinary(data []byte) (err error) { + + deg := int(data[0]) + + rlk.Keys = make([]*SwitchingKey, deg) + + pointer := 1 + var inc int + for i := 0; i < deg; i++ { + rlk.Keys[i] = new(SwitchingKey) + if inc, err = rlk.Keys[i].decode(data[pointer:]); err != nil { + return err + } + pointer += inc + } + + return nil +} + +// GetDataLen returns the length in bytes of the target SwitchingKey. +func (swk *SwitchingKey) GetDataLen(WithMetadata bool) (dataLen int) { + + if WithMetadata { + dataLen++ + } + + for j := uint64(0); j < uint64(len(swk.Value)); j++ { + dataLen += swk.Value[j][0].GetDataLen(WithMetadata) + dataLen += swk.Value[j][1].GetDataLen(WithMetadata) + } + + return +} + +// MarshalBinary encodes an SwitchingKey in a byte slice. +func (swk *SwitchingKey) MarshalBinary() (data []byte, err error) { + + data = make([]byte, swk.GetDataLen(true)) + + if _, err = swk.encode(0, data); err != nil { + return nil, err + } + + return data, nil +} + +// UnmarshalBinary decode a previously marshaled SwitchingKey in the target SwitchingKey. +func (swk *SwitchingKey) UnmarshalBinary(data []byte) (err error) { + + if _, err = swk.decode(data); err != nil { + return err + } + + return nil +} + +func (swk *SwitchingKey) encode(pointer int, data []byte) (int, error) { + + var err error + var inc int + + data[pointer] = uint8(len(swk.Value)) + + pointer++ + + for j := 0; j < len(swk.Value); j++ { + + if inc, err = swk.Value[j][0].WriteTo(data[pointer : pointer+swk.Value[j][0].GetDataLen(true)]); err != nil { + return pointer, err + } + + pointer += inc + + if inc, err = swk.Value[j][1].WriteTo(data[pointer : pointer+swk.Value[j][1].GetDataLen(true)]); err != nil { + return pointer, err + } + + pointer += inc + } + + return pointer, nil +} + +func (swk *SwitchingKey) decode(data []byte) (pointer int, err error) { + + decomposition := int(data[0]) + + pointer = 1 + + swk.Value = make([][2]PolyQP, decomposition) + + var inc int + + for j := 0; j < decomposition; j++ { + + swk.Value[j][0].Q = new(ring.Poly) + if inc, err = swk.Value[j][0].DecodePolyNew(data[pointer:]); err != nil { + return + } + pointer += inc + + swk.Value[j][1].P = new(ring.Poly) + if inc, err = swk.Value[j][1].DecodePolyNew(data[pointer:]); err != nil { + return + } + pointer += inc + } + + return +} + +// GetDataLen returns the length in bytes of the target RotationKeys. +func (rtks *RotationKeySet) GetDataLen(WithMetaData bool) (dataLen int) { + for _, k := range rtks.Keys { + if WithMetaData { + dataLen += 4 + } + dataLen += k.GetDataLen(WithMetaData) + } + return +} + +// MarshalBinary encodes a RotationKeys struct in a byte slice. +func (rtks *RotationKeySet) MarshalBinary() (data []byte, err error) { + + data = make([]byte, rtks.GetDataLen(true)) + + pointer := int(0) + + for galEL, key := range rtks.Keys { + + binary.BigEndian.PutUint32(data[pointer:pointer+4], uint32(galEL)) + pointer += 4 + + if pointer, err = key.encode(pointer, data); err != nil { + return nil, err + } + } + + return data, nil +} + +// UnmarshalBinary decodes a previously marshaled RotationKeys in the target RotationKeys. +func (rtks *RotationKeySet) UnmarshalBinary(data []byte) (err error) { + + rtks.Keys = make(map[uint64]*SwitchingKey) + + for len(data) > 0 { + + galEl := uint64(binary.BigEndian.Uint32(data)) + data = data[4:] + + swk := new(SwitchingKey) + var inc int + if inc, err = swk.decode(data); err != nil { + return err + } + data = data[inc:] + rtks.Keys[galEl] = swk + + } + + return nil +} diff --git a/rlwe/params.go b/rlwe/params.go index 6da25bb4..c177694f 100644 --- a/rlwe/params.go +++ b/rlwe/params.go @@ -46,13 +46,12 @@ type ParametersLiteral struct { // Parameters represents a set of generic RLWE parameters. Its fields are private and // immutable. See ParametersLiteral for user-specified parameters. type Parameters struct { - logN int - qi []uint64 - pi []uint64 - sigma float64 - ringQ *ring.Ring - ringP *ring.Ring - ringQP *ring.Ring + logN int + qi []uint64 + pi []uint64 + sigma float64 + ringQ *ring.Ring + ringP *ring.Ring } var ( @@ -123,10 +122,6 @@ func NewParameters(logn int, q, p []uint64, sigma float64) (Parameters, error) { } } - if params.ringQP, err = ring.NewRing(1<> 1 + + if polyInQ != polyOutQ && polyOutQ != nil { + polyOutQ.Copy(polyInQ) + } + + for j := 0; j < r.RingQ.N; j++ { + + coeff = polyInQ.Coeffs[0][j] + + sign = 1 + if coeff > QHalf { + coeff = Q - coeff + sign = 0 + } + + for i, pi := range r.RingP.Modulus[:levelP+1] { + polyOutP.Coeffs[i][j] = (coeff * sign) | (pi-coeff)*(sign^1) + } + } +} + +// Copy copies the input polyQP on the target polyQP. +func (p *PolyQP) Copy(polFrom PolyQP) { + p.Q.Copy(polFrom.Q) + p.P.Copy(polFrom.P) +} + +// GetDataLen returns the length in byte of the target PolyQP +func (p *PolyQP) GetDataLen(WithMetadata bool) (dataLen int) { + return p.Q.GetDataLen(WithMetadata) + p.P.GetDataLen(WithMetadata) +} + +// WriteTo writes a polyQP on the inpute data. +func (p *PolyQP) WriteTo(data []byte) (pt int, err error) { + var inc int + if inc, err = p.Q.WriteTo(data[pt:]); err != nil { + return + } + pt += inc + + if inc, err = p.P.WriteTo(data[pt:]); err != nil { + return + } + pt += inc + + return +} + +// DecodePolyNew decodes the input bytes on the target polyQP. +func (p *PolyQP) DecodePolyNew(data []byte) (pt int, err error) { + p.Q = new(ring.Poly) + var inc int + if inc, err = p.Q.DecodePolyNew(data[pt:]); err != nil { + return + } + pt += inc + + p.P = new(ring.Poly) + if inc, err = p.P.DecodePolyNew(data[pt:]); err != nil { + return + } + pt += inc + + return +} + +// UniformSamplerQP is a type for sampling polynomials in RingQP. +type UniformSamplerQP struct { + samplerQ, samplerP ring.UniformSampler +} + +// NewUniformSamplerQP instantiates a new UniformSamplerQP from a given PRNG. +func NewUniformSamplerQP(params Parameters, prng utils.PRNG, baseRing *RingQP) (s UniformSamplerQP) { + s.samplerQ = *ring.NewUniformSampler(prng, params.RingQ()) + s.samplerP = *ring.NewUniformSampler(prng, params.RingP()) + return s +} + +// Read samples a new polynomial in RingQP and stores it into p. +func (s UniformSamplerQP) Read(p *PolyQP) { + s.samplerQ.Read(p.Q) + s.samplerP.Read(p.P) +} diff --git a/rlwe/rlwe_benchmark_test.go b/rlwe/rlwe_benchmark_test.go index 04e4088a..24e25e78 100644 --- a/rlwe/rlwe_benchmark_test.go +++ b/rlwe/rlwe_benchmark_test.go @@ -51,14 +51,14 @@ func benchHoistedKeySwitch(kgen KeyGenerator, keySwitcher *KeySwitcher, b *testi b.Run(testString(params, "DecomposeNTT/"), func(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - keySwitcher.DecomposeNTT(ciphertext.Level(), ciphertext.Value[1], keySwitcher.PoolDecompQ, keySwitcher.PoolDecompP) + keySwitcher.DecomposeNTT(ciphertext.Level(), params.PCount()-1, params.PCount(), ciphertext.Value[1], keySwitcher.PoolDecompQP) } }) b.Run(testString(params, "KeySwitchHoisted/"), func(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - keySwitcher.KeyswitchHoisted(ciphertext.Level(), keySwitcher.PoolDecompQ, keySwitcher.PoolDecompP, swk, ciphertext.Value[0], ciphertext.Value[1], keySwitcher.PoolP[1], keySwitcher.PoolP[2]) + keySwitcher.KeyswitchHoisted(ciphertext.Level(), keySwitcher.PoolDecompQP, swk, ciphertext.Value[0], ciphertext.Value[1], keySwitcher.Pool[1].P, keySwitcher.Pool[2].P) } }) } diff --git a/rlwe/rlwe_test.go b/rlwe/rlwe_test.go index 440f14fd..0bbac376 100644 --- a/rlwe/rlwe_test.go +++ b/rlwe/rlwe_test.go @@ -42,7 +42,7 @@ func TestRLWE(t *testing.T) { defaultParams = []ParametersLiteral{jsonParams} // the custom test suite reads the parameters from the -params flag } - for _, defaultParam := range defaultParams { + for _, defaultParam := range defaultParams[:] { params, err := NewParametersFromLiteral(defaultParam) if err != nil { panic(err) @@ -56,6 +56,7 @@ func TestRLWE(t *testing.T) { testEncryptor, testDecryptor, testKeySwitcher, + testKeySwitchDimension, testMarshaller, } { testSet(kgen, t) @@ -135,17 +136,16 @@ func testGenKeyPair(kgen KeyGenerator, t *testing.T) { // Checks that sum([-as + e, a] + [as])) <= N * 6 * sigma t.Run(testString(params, "PKGen/"), func(t *testing.T) { - - ringQP := params.RingQP() - sk, pk := kgen.GenKeyPair() // [-as + e] + [as] - ringQP.MulCoeffsMontgomeryAndAdd(sk.Value, pk.Value[1], pk.Value[0]) - ringQP.InvNTT(pk.Value[0], pk.Value[0]) + params.RingQP().MulCoeffsMontgomeryAndAddLvl(sk.Value.Q.Level(), sk.Value.P.Level(), sk.Value, pk.Value[1], pk.Value[0]) + params.RingQP().InvNTTLvl(sk.Value.Q.Level(), sk.Value.P.Level(), pk.Value[0], pk.Value[0]) log2Bound := bits.Len64(uint64(math.Floor(DefaultSigma*6)) * uint64(params.N())) - require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(pk.Value[0].Level(), ringQP, pk.Value[0])) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(pk.Value[0].Q.Level(), params.RingQ(), pk.Value[0].Q)) + + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(pk.Value[0].P.Level(), params.RingP(), pk.Value[0].P)) }) } @@ -161,42 +161,46 @@ func testSwitchKeyGen(kgen KeyGenerator, t *testing.T) { t.Run(testString(params, "SWKGen/"), func(t *testing.T) { ringQ := params.RingQ() + ringP := params.RingP() + ringQP := params.RingQP() skIn := kgen.GenSecretKey() skOut := kgen.GenSecretKey() + levelQ, levelP := params.QCount()-1, params.PCount()-1 // Generates Decomp([-asIn + w*P*sOut + e, a]) - swk := NewSwitchingKey(params) - kgen.(*keyGenerator).newSwitchingKey(skIn.Value, skOut.Value, swk) + swk := NewSwitchingKey(params, params.QCount()-1, params.PCount()-1) + kgen.(*keyGenerator).genSwitchingKey(skIn.Value.Q, skOut.Value, swk) // Decrypts // [-asIn + w*P*sOut + e, a] + [asIn] for j := range swk.Value { - ringQ.MulCoeffsMontgomeryAndAdd(swk.Value[j][1], skOut.Value, swk.Value[j][0]) + ringQP.MulCoeffsMontgomeryAndAddLvl(levelQ, levelP, swk.Value[j][1], skOut.Value, swk.Value[j][0]) } - poly := swk.Value[0][0] - // Sums all basis together (equivalent to multiplying with CRT decomposition of 1) // sum([1]_w * [w*P*sOut + e]) = P*sOut + sum(e) for j := range swk.Value { if j > 0 { - ringQ.Add(poly, swk.Value[j][0], poly) + ringQP.AddLvl(levelQ, levelP, swk.Value[0][0], swk.Value[j][0], swk.Value[0][0]) } } // sOut * P - ringQ.MulScalarBigint(skIn.Value, kgen.(*keyGenerator).pBigInt, skIn.Value) + ringQ.MulScalarBigint(skIn.Value.Q, ringP.ModulusBigint, skIn.Value.Q) // P*s^i + sum(e) - P*s^i = sum(e) - ringQ.Sub(poly, skIn.Value, poly) + ringQ.Sub(swk.Value[0][0].Q, skIn.Value.Q, swk.Value[0][0].Q) // Checks that the error is below the bound // Worst error bound is N * floor(6*sigma) * #Keys - ringQ.InvNTT(poly, poly) - ringQ.InvMForm(poly, poly) + + ringQP.InvNTTLvl(levelQ, levelP, swk.Value[0][0], swk.Value[0][0]) + ringQP.InvMFormLvl(levelQ, levelP, swk.Value[0][0], swk.Value[0][0]) log2Bound := bits.Len64(uint64(math.Floor(DefaultSigma*6)) * uint64(params.N()*len(swk.Value))) - require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringQ.Modulus)-1, ringQ, poly)) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringQ.Modulus)-1, ringQ, swk.Value[0][0].Q)) + require.GreaterOrEqual(t, log2Bound, log2OfInnerSum(len(ringP.Modulus)-1, ringP, swk.Value[0][0].P)) + }) } @@ -215,7 +219,7 @@ func testEncryptor(kgen KeyGenerator, t *testing.T) { ciphertext := NewCiphertextNTT(params, 1, plaintext.Level()) encryptor.Encrypt(plaintext, ciphertext) require.Equal(t, plaintext.Level(), ciphertext.Level()) - ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value, ciphertext.Value[0]) + ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value.Q, ciphertext.Value[0]) ringQ.InvNTTLvl(ciphertext.Level(), ciphertext.Value[0], ciphertext.Value[0]) require.GreaterOrEqual(t, 12+params.LogN(), log2OfInnerSum(ciphertext.Level(), ringQ, ciphertext.Value[0])) }) @@ -227,7 +231,7 @@ func testEncryptor(kgen KeyGenerator, t *testing.T) { ciphertext := NewCiphertextNTT(params, 1, plaintext.Level()) encryptor.Encrypt(plaintext, ciphertext) require.Equal(t, plaintext.Level(), ciphertext.Level()) - ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value, ciphertext.Value[0]) + ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value.Q, ciphertext.Value[0]) ringQ.InvNTTLvl(ciphertext.Level(), ciphertext.Value[0], ciphertext.Value[0]) require.GreaterOrEqual(t, 12+params.LogN(), log2OfInnerSum(ciphertext.Level(), ringQ, ciphertext.Value[0])) }) @@ -242,7 +246,7 @@ func testEncryptor(kgen KeyGenerator, t *testing.T) { ciphertext := NewCiphertextNTT(params, 1, plaintext.Level()) encryptor.Encrypt(plaintext, ciphertext) require.Equal(t, plaintext.Level(), ciphertext.Level()) - ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value, ciphertext.Value[0]) + ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value.Q, ciphertext.Value[0]) ringQ.InvNTTLvl(ciphertext.Level(), ciphertext.Value[0], ciphertext.Value[0]) require.GreaterOrEqual(t, 9+params.LogN(), log2OfInnerSum(ciphertext.Level(), ringQ, ciphertext.Value[0])) }) @@ -257,7 +261,7 @@ func testEncryptor(kgen KeyGenerator, t *testing.T) { ciphertext := NewCiphertextNTT(params, 1, plaintext.Level()) encryptor.Encrypt(plaintext, ciphertext) require.Equal(t, plaintext.Level(), ciphertext.Level()) - ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value, ciphertext.Value[0]) + ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value.Q, ciphertext.Value[0]) ringQ.InvNTTLvl(ciphertext.Level(), ciphertext.Value[0], ciphertext.Value[0]) require.GreaterOrEqual(t, 9+params.LogN(), log2OfInnerSum(ciphertext.Level(), ringQ, ciphertext.Value[0])) }) @@ -269,7 +273,7 @@ func testEncryptor(kgen KeyGenerator, t *testing.T) { ciphertext := NewCiphertextNTT(params, 1, plaintext.Level()) encryptor.Encrypt(plaintext, ciphertext) require.Equal(t, plaintext.Level(), ciphertext.Level()) - ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value, ciphertext.Value[0]) + ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value.Q, ciphertext.Value[0]) ringQ.InvNTTLvl(ciphertext.Level(), ciphertext.Value[0], ciphertext.Value[0]) require.GreaterOrEqual(t, 5+params.LogN(), log2OfInnerSum(ciphertext.Level(), ringQ, ciphertext.Value[0])) }) @@ -281,7 +285,7 @@ func testEncryptor(kgen KeyGenerator, t *testing.T) { ciphertext := NewCiphertextNTT(params, 1, plaintext.Level()) encryptor.Encrypt(plaintext, ciphertext) require.Equal(t, plaintext.Level(), ciphertext.Level()) - ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value, ciphertext.Value[0]) + ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], sk.Value.Q, ciphertext.Value[0]) ringQ.InvNTTLvl(ciphertext.Level(), ciphertext.Value[0], ciphertext.Value[0]) require.GreaterOrEqual(t, 5+params.LogN(), log2OfInnerSum(ciphertext.Level(), ringQ, ciphertext.Value[0])) }) @@ -324,10 +328,24 @@ func testKeySwitcher(kgen KeyGenerator, t *testing.T) { skOut := kgen.GenSecretKey() ks := NewKeySwitcher(params) - ringQP := params.RingQP() ringQ := params.RingQ() + ringP := params.RingP() - plaintext := NewPlaintext(params, params.MaxLevel()) + levelQ := params.MaxLevel() + alpha := params.PCount() + levelP := alpha - 1 + + QBig := ring.NewUint(1) + for i := range ringQ.Modulus[:levelQ+1] { + QBig.Mul(QBig, ring.NewUint(ringQ.Modulus[i])) + } + + PBig := ring.NewUint(1) + for i := range ringP.Modulus[:levelP+1] { + PBig.Mul(PBig, ring.NewUint(ringP.Modulus[i])) + } + + plaintext := NewPlaintext(params, levelQ) plaintext.Value.IsNTT = true encryptor := NewEncryptor(params, sk) ciphertext := NewCiphertextNTT(params, 1, plaintext.Level()) @@ -337,65 +355,162 @@ func testKeySwitcher(kgen KeyGenerator, t *testing.T) { // reconstruction mod each RNS t.Run(testString(params, "DecomposeNTT/"), func(t *testing.T) { - c2 := ciphertext.Value[1] + c2InvNTT := ringQ.NewPolyLvl(ciphertext.Level()) + ringQ.InvNTT(ciphertext.Value[1], c2InvNTT) - ks.DecomposeNTT(ciphertext.Level(), c2, ks.PoolDecompQ, ks.PoolDecompP) - - coeffsBigintHave := make([]*big.Int, ringQ.N) + coeffsBigintHaveQ := make([]*big.Int, ringQ.N) + coeffsBigintHaveP := make([]*big.Int, ringQ.N) coeffsBigintRef := make([]*big.Int, ringQ.N) coeffsBigintWant := make([]*big.Int, ringQ.N) for i := range coeffsBigintRef { - coeffsBigintHave[i] = new(big.Int) + coeffsBigintHaveQ[i] = new(big.Int) + coeffsBigintHaveP[i] = new(big.Int) coeffsBigintRef[i] = new(big.Int) coeffsBigintWant[i] = new(big.Int) } - ringQ.PolyToBigintCenteredLvl(len(ringQ.Modulus)-1, c2, coeffsBigintRef) + ringQ.PolyToBigintCenteredLvl(ciphertext.Level(), c2InvNTT, coeffsBigintRef) - for i := 0; i < len(ks.PoolDecompQ); i++ { + tmpQ := ringQ.NewPolyLvl(ciphertext.Level()) + tmpP := ringP.NewPolyLvl(levelP) + + for i := 0; i < len(ks.PoolDecompQP); i++ { + + ks.DecomposeSingleNTT(levelQ, levelP, alpha, i, ciphertext.Value[1], c2InvNTT, ks.PoolDecompQP[i].Q, ks.PoolDecompQP[i].P) // Compute q_alpha_i in bigInt - modulus := ring.NewInt(1) + qalphai := ring.NewInt(1) - for j := 0; j < params.PCount(); j++ { - idx := i*params.PCount() + j - if idx > params.QCount()-1 { + for j := 0; j < alpha; j++ { + idx := i*alpha + j + if idx > levelQ { break } - modulus.Mul(modulus, ring.NewUint(ringQ.Modulus[idx])) + qalphai.Mul(qalphai, ring.NewUint(ringQ.Modulus[idx])) } - // Reconstruct the decomposed polynomial - polyQP := new(ring.Poly) - polyQP.Coeffs = append(ks.PoolDecompQ[i].Coeffs, ks.PoolDecompP[i].Coeffs...) - ringQP.PolyToBigintCenteredLvl(len(ringQP.Modulus)-1, polyQP, coeffsBigintHave) + ringQ.ReduceLvl(levelQ, ks.PoolDecompQP[i].Q, ks.PoolDecompQP[i].Q) + ringP.ReduceLvl(levelP, ks.PoolDecompQP[i].P, ks.PoolDecompQP[i].P) + + ringQ.InvNTTLvl(levelQ, ks.PoolDecompQP[i].Q, tmpQ) + ringP.InvNTTLvl(levelP, ks.PoolDecompQP[i].P, tmpP) + + ringQ.PolyToBigintCenteredLvl(levelQ, tmpQ, coeffsBigintHaveQ) + ringP.PolyToBigintCenteredLvl(levelP, tmpP, coeffsBigintHaveP) // Checks that Reconstruct(NTT(c2 mod Q)) mod q_alpha_i == Reconstruct(NTT(Decomp(c2 mod Q, q_alpha-i) mod QP)) - for i := range coeffsBigintWant { - coeffsBigintHave[i].Mod(coeffsBigintHave[i], modulus) - coeffsBigintWant[i].Mod(coeffsBigintRef[i], modulus) - require.Equal(t, coeffsBigintHave[i].Cmp(coeffsBigintWant[i]), 0) + for i := range coeffsBigintWant[:1] { + + coeffsBigintWant[i].Mod(coeffsBigintRef[i], qalphai) + coeffsBigintWant[i].Mod(coeffsBigintWant[i], QBig) + coeffsBigintHaveQ[i].Mod(coeffsBigintHaveQ[i], QBig) + require.Equal(t, coeffsBigintHaveQ[i].Cmp(coeffsBigintWant[i]), 0) + + coeffsBigintWant[i].Mod(coeffsBigintRef[i], qalphai) + coeffsBigintWant[i].Mod(coeffsBigintWant[i], PBig) + coeffsBigintHaveP[i].Mod(coeffsBigintHaveP[i], PBig) + require.Equal(t, coeffsBigintHaveP[i].Cmp(coeffsBigintWant[i]), 0) + } } }) // Test that Dec(KS(Enc(ct, sk), skOut), skOut) has a small norm - t.Run(testString(params, "KeySwitch/"), func(t *testing.T) { + t.Run(testString(params, "KeySwitch/Standard/"), func(t *testing.T) { swk := kgen.GenSwitchingKey(sk, skOut) - ks.SwitchKeysInPlace(ciphertext.Value[1].Level(), ciphertext.Value[1], swk, ks.PoolQ[1], ks.PoolQ[2]) - ringQ.Add(ciphertext.Value[0], ks.PoolQ[1], ciphertext.Value[0]) - ring.CopyValues(ks.PoolQ[2], ciphertext.Value[1]) - ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], skOut.Value, ciphertext.Value[0]) + ks.SwitchKeysInPlace(ciphertext.Value[1].Level(), ciphertext.Value[1], swk, ks.Pool[1].Q, ks.Pool[2].Q) + ringQ.Add(ciphertext.Value[0], ks.Pool[1].Q, ciphertext.Value[0]) + ring.CopyValues(ks.Pool[2].Q, ciphertext.Value[1]) + ringQ.MulCoeffsMontgomeryAndAddLvl(ciphertext.Level(), ciphertext.Value[1], skOut.Value.Q, ciphertext.Value[0]) ringQ.InvNTTLvl(ciphertext.Level(), ciphertext.Value[0], ciphertext.Value[0]) require.GreaterOrEqual(t, 10+params.LogN(), log2OfInnerSum(ciphertext.Level(), ringQ, ciphertext.Value[0])) }) } +func testKeySwitchDimension(kgen KeyGenerator, t *testing.T) { + + paramsLargeDim := kgen.(*keyGenerator).params + paramsSmallDim, _ := NewParametersFromLiteral(ParametersLiteral{ + LogN: paramsLargeDim.LogN() - 1, + Q: paramsLargeDim.Q()[:1], + P: paramsLargeDim.P()[:1], + Sigma: DefaultSigma, + }) + + t.Run(testString(paramsLargeDim, "KeySwitchDimension/LargeToSmall/"), func(t *testing.T) { + + ringQLargeDim := paramsLargeDim.RingQ() + ringQSmallDim := paramsSmallDim.RingQ() + + kgenLargeDim := NewKeyGenerator(paramsLargeDim) + skLargeDim := kgenLargeDim.GenSecretKey() + kgenSmallDim := NewKeyGenerator(paramsSmallDim) + skSmallDim := kgenSmallDim.GenSecretKey() + + swk := kgenLargeDim.GenSwitchingKey(skLargeDim, skSmallDim) + + plaintext := NewPlaintext(paramsLargeDim, paramsLargeDim.MaxLevel()) + plaintext.Value.IsNTT = true + encryptor := NewEncryptor(paramsLargeDim, skLargeDim) + ctLargeDim := NewCiphertextNTT(paramsLargeDim, 1, plaintext.Level()) + encryptor.Encrypt(plaintext, ctLargeDim) + + ks := NewKeySwitcher(paramsLargeDim) + ks.SwitchKeysInPlace(paramsSmallDim.MaxLevel(), ctLargeDim.Value[1], swk, ks.Pool[1].Q, ks.Pool[2].Q) + ringQLargeDim.AddLvl(paramsSmallDim.MaxLevel(), ctLargeDim.Value[0], ks.Pool[1].Q, ctLargeDim.Value[0]) + ring.CopyValues(ks.Pool[2].Q, ctLargeDim.Value[1]) + + //Extracts Coefficients + ctSmallDim := NewCiphertextNTT(paramsSmallDim, 1, paramsSmallDim.MaxLevel()) + + SwitchCiphertextRingDegreeNTT(ctLargeDim, ringQLargeDim, ctSmallDim) + + // Decrypts with smaller dimension key + ringQSmallDim.MulCoeffsMontgomeryAndAddLvl(ctSmallDim.Level(), ctSmallDim.Value[1], skSmallDim.Value.Q, ctSmallDim.Value[0]) + ringQSmallDim.InvNTTLvl(ctSmallDim.Level(), ctSmallDim.Value[0], ctSmallDim.Value[0]) + + require.GreaterOrEqual(t, 10+paramsSmallDim.LogN(), log2OfInnerSum(ctSmallDim.Level(), ringQSmallDim, ctSmallDim.Value[0])) + }) + + t.Run(testString(paramsLargeDim, "KeySwitchDimension/SmallToLarge/"), func(t *testing.T) { + + ringQLargeDim := paramsLargeDim.RingQ() + + kgenLargeDim := NewKeyGenerator(paramsLargeDim) + skLargeDim := kgenLargeDim.GenSecretKey() + kgenSmallDim := NewKeyGenerator(paramsSmallDim) + skSmallDim := kgenSmallDim.GenSecretKey() + + swk := kgenLargeDim.GenSwitchingKey(skSmallDim, skLargeDim) + + plaintext := NewPlaintext(paramsSmallDim, paramsSmallDim.MaxLevel()) + plaintext.Value.IsNTT = true + encryptor := NewEncryptor(paramsSmallDim, skSmallDim) + ctSmallDim := NewCiphertextNTT(paramsSmallDim, 1, plaintext.Level()) + encryptor.Encrypt(plaintext, ctSmallDim) + + //Extracts Coefficients + ctLargeDim := NewCiphertextNTT(paramsLargeDim, 1, plaintext.Level()) + + SwitchCiphertextRingDegreeNTT(ctSmallDim, nil, ctLargeDim) + + ks := NewKeySwitcher(paramsLargeDim) + ks.SwitchKeysInPlace(ctLargeDim.Value[1].Level(), ctLargeDim.Value[1], swk, ks.Pool[1].Q, ks.Pool[2].Q) + ringQLargeDim.Add(ctLargeDim.Value[0], ks.Pool[1].Q, ctLargeDim.Value[0]) + ring.CopyValues(ks.Pool[2].Q, ctLargeDim.Value[1]) + + // Decrypts with smaller dimension key + ringQLargeDim.MulCoeffsMontgomeryAndAddLvl(ctLargeDim.Level(), ctLargeDim.Value[1], skLargeDim.Value.Q, ctLargeDim.Value[0]) + ringQLargeDim.InvNTTLvl(ctLargeDim.Level(), ctLargeDim.Value[0], ctLargeDim.Value[0]) + + require.GreaterOrEqual(t, 10+paramsSmallDim.LogN(), log2OfInnerSum(ctLargeDim.Level(), ringQLargeDim, ctLargeDim.Value[0])) + }) +} + func testMarshaller(kgen KeyGenerator, t *testing.T) { params := kgen.(*keyGenerator).params - ringQP := params.RingQP() sk, pk := kgen.GenKeyPair() @@ -431,8 +546,7 @@ func testMarshaller(kgen KeyGenerator, t *testing.T) { err = skTest.UnmarshalBinary(marshalledSk) require.NoError(t, err) - require.True(t, ringQP.Equal(sk.Value, skTest.Value)) - + require.True(t, sk.Value.Equals(skTest.Value)) }) t.Run(testString(params, "Marshaller/Pk/"), func(t *testing.T) { @@ -444,9 +558,7 @@ func testMarshaller(kgen KeyGenerator, t *testing.T) { err = pkTest.UnmarshalBinary(marshalledPk) require.NoError(t, err) - for k := range pk.Value { - require.Truef(t, ringQP.Equal(pk.Value[k], pkTest.Value[k]), "Marshal PublicKey element [%d]", k) - } + require.True(t, pk.Equals(pkTest)) }) t.Run(testString(params, "Marshaller/EvaluationKey/"), func(t *testing.T) { @@ -463,14 +575,7 @@ func testMarshaller(kgen KeyGenerator, t *testing.T) { err = resEvalKey.UnmarshalBinary(data) require.NoError(t, err) - evakeyWant := evalKey.Keys[0].Value - evakeyTest := resEvalKey.Keys[0].Value - - for j := range evakeyWant { - for k := range evakeyWant[j] { - require.Truef(t, ringQP.Equal(evakeyWant[j][k], evakeyTest[j][k]), "Marshal EvaluationKey element [%d][%d]", j, k) - } - } + require.True(t, evalKey.Equals(resEvalKey)) }) t.Run(testString(params, "Marshaller/SwitchingKey/"), func(t *testing.T) { @@ -489,14 +594,7 @@ func testMarshaller(kgen KeyGenerator, t *testing.T) { err = resSwitchingKey.UnmarshalBinary(data) require.NoError(t, err) - evakeyWant := switchingKey.Value - evakeyTest := resSwitchingKey.Value - - for j := range evakeyWant { - for k := range evakeyWant[j] { - require.True(t, ringQP.Equal(evakeyWant[j][k], evakeyTest[j][k])) - } - } + require.True(t, switchingKey.Equals(resSwitchingKey)) }) t.Run(testString(params, "Marshaller/RotationKey/"), func(t *testing.T) { @@ -520,16 +618,6 @@ func testMarshaller(kgen KeyGenerator, t *testing.T) { err = resRotationKey.UnmarshalBinary(data) require.NoError(t, err) - for _, galEl := range galEls { - - evakeyWant := rotationKey.Keys[galEl].Value - evakeyTest := resRotationKey.Keys[galEl].Value - - for j := range evakeyWant { - for k := range evakeyWant[j] { - require.Truef(t, ringQP.Equal(evakeyWant[j][k], evakeyTest[j][k]), "Marshal RotationKey RotateLeft %d element [%d][%d]", galEl, j, k) - } - } - } + rotationKey.Equals(resRotationKey) }) }