diff --git a/bfv/evaluator.go b/bfv/evaluator.go index bf35d9be..b2c6d4f6 100644 --- a/bfv/evaluator.go +++ b/bfv/evaluator.go @@ -316,7 +316,6 @@ func (eval *evaluator) tensorAndRescale(ct0, ct1, ctOut *Element) { for i := range ct0.value { eval.baseconverterQ1Q2.ModUpSplitQP(levelQ, ct0.value[i], c0Q2[i]) - eval.ringQ.NTT(ct0.value[i], c0Q1[i]) eval.ringQMul.NTT(c0Q2[i], c0Q2[i]) } @@ -325,7 +324,6 @@ func (eval *evaluator) tensorAndRescale(ct0, ct1, ctOut *Element) { for i := range ct1.value { eval.baseconverterQ1Q2.ModUpSplitQP(levelQ, ct1.value[i], c1Q2[i]) - eval.ringQ.NTT(ct1.value[i], c1Q1[i]) eval.ringQMul.NTT(c1Q2[i], c1Q2[i]) } diff --git a/ckks/bootstrapp_test.go b/ckks/bootstrapp_test.go index c16e66a8..97a81334 100644 --- a/ckks/bootstrapp_test.go +++ b/ckks/bootstrapp_test.go @@ -17,7 +17,7 @@ func TestBootstrapp(t *testing.T) { var err error var testContext = new(testParams) - paramSet := uint64(0) + paramSet := uint64(1) shemeParams := DefaultBootstrappSchemeParams[paramSet : paramSet+1] bootstrappParams := DefaultBootstrappParams[paramSet : paramSet+1] diff --git a/ckks/evaluator.go b/ckks/evaluator.go index 62609543..73dd6c07 100644 --- a/ckks/evaluator.go +++ b/ckks/evaluator.go @@ -1643,33 +1643,32 @@ func (eval *evaluator) switchKeysInPlaceNoModDown(level uint64, cx *ring.Poly, e evakey1P.Coeffs = evakey.evakey[i][1].Coeffs[len(ringQ.Modulus):] if i == 0 { - ringQ.MulCoeffsMontgomeryLvl(level, evakey0Q, c2QiQ, pool2Q) - ringQ.MulCoeffsMontgomeryLvl(level, evakey1Q, c2QiQ, pool3Q) - ringP.MulCoeffsMontgomery(evakey0P, c2QiP, pool2P) - ringP.MulCoeffsMontgomery(evakey1P, c2QiP, pool3P) + ringQ.MulCoeffsMontgomeryConstantLvl(level, evakey0Q, c2QiQ, pool2Q) + ringQ.MulCoeffsMontgomeryConstantLvl(level, evakey1Q, c2QiQ, pool3Q) + ringP.MulCoeffsMontgomeryConstant(evakey0P, c2QiP, pool2P) + ringP.MulCoeffsMontgomeryConstant(evakey1P, c2QiP, pool3P) } else { - ringQ.MulCoeffsMontgomeryAndAddNoModLvl(level, evakey0Q, c2QiQ, pool2Q) - 
ringQ.MulCoeffsMontgomeryAndAddNoModLvl(level, evakey1Q, c2QiQ, pool3Q) - ringP.MulCoeffsMontgomeryAndAddNoMod(evakey0P, c2QiP, pool2P) - ringP.MulCoeffsMontgomeryAndAddNoMod(evakey1P, c2QiP, pool3P) + ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(level, evakey0Q, c2QiQ, pool2Q) + ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(level, evakey1Q, c2QiQ, pool3Q) + ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey0P, c2QiP, pool2P) + ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey1P, c2QiP, pool3P) } - if reduce&7 == 1 { - ringQ.ReduceLvl(level, pool2Q, pool2Q) - ringQ.ReduceLvl(level, pool3Q, pool3Q) - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) + // + if reduce&1 == 1 { + ringQ.ReduceConstantLvl(level, pool2Q, pool2Q) + ringQ.ReduceConstantLvl(level, pool3Q, pool3Q) + ringP.ReduceConstant(pool2P, pool2P) + ringP.ReduceConstant(pool3P, pool3P) } reduce++ } - if (reduce-1)&7 != 1 { - ringQ.ReduceLvl(level, pool2Q, pool2Q) - ringQ.ReduceLvl(level, pool3Q, pool3Q) - ringP.Reduce(pool2P, pool2P) - ringP.Reduce(pool3P, pool3P) - } + ringQ.ReduceLvl(level, pool2Q, pool2Q) + ringQ.ReduceLvl(level, pool3Q, pool3Q) + ringP.Reduce(pool2P, pool2P) + ringP.Reduce(pool3P, pool3P) } // switchKeysInPlace applies the general key-switching procedure of the form [c0 + cx*evakey[0], c1 + cx*evakey[1]] @@ -1707,11 +1706,11 @@ func (eval *evaluator) decomposeAndSplitNTT(level, beta uint64, c2NTT, c2InvNTT, p1tmp[j] = p0tmp[j] } } else { - ring.NTT(c2QiQ.Coeffs[x], c2QiQ.Coeffs[x], ringQ.N, nttPsi, qi, mredParams, bredParams) + ring.NTTLazy(c2QiQ.Coeffs[x], c2QiQ.Coeffs[x], ringQ.N, nttPsi, qi, mredParams, bredParams) } } // c2QiP = c2 mod qi mod pj - ringP.NTT(c2QiP, c2QiP) + ringP.NTTLazy(c2QiP, c2QiP) } // RotateHoisted takes an input Ciphertext and a list of rotations and returns a map of Ciphertext, where each element of the map is the input Ciphertext diff --git a/ring/ring_basis_extension.go b/ring/ring_basis_extension.go index 10428a3f..f8a9b1d2 100644 --- 
a/ring/ring_basis_extension.go +++ b/ring/ring_basis_extension.go @@ -3,6 +3,7 @@ package ring import ( "math" "math/big" + "math/bits" "unsafe" ) @@ -174,7 +175,7 @@ func (basisextender *FastBasisExtender) ModDownNTTPQ(level uint64, p1, p2 *Poly) // First we get the P basis part of p1 out of the NTT domain for j := 0; j < nPj; j++ { - InvNTT(p1.Coeffs[nQi+j], p1.Coeffs[nQi+j], ringP.N, ringP.GetNttPsiInv()[j], ringP.GetNttNInv()[j], ringP.Modulus[j], ringP.GetMredParams()[j]) + InvNTTLazy(p1.Coeffs[nQi+j], p1.Coeffs[nQi+j], ringP.N, ringP.GetNttPsiInv()[j], ringP.GetNttNInv()[j], ringP.Modulus[j], ringP.GetMredParams()[j]) } // Then we target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool @@ -185,15 +186,16 @@ func (basisextender *FastBasisExtender) ModDownNTTPQ(level uint64, p1, p2 *Poly) for i := uint64(0); i < level+1; i++ { qi := ringQ.Modulus[i] + twoqi := qi << 1 p1tmp := p1.Coeffs[i] p2tmp := p2.Coeffs[i] p3tmp := polypool.Coeffs[i] - params := modDownParams[i] + params := qi - modDownParams[i] mredParams := ringQ.MredParams[i] bredParams := ringQ.BredParams[i] // First we switch back the relevant polypool CRT array back to the NTT domain - NTT(p3tmp, p3tmp, ringQ.N, ringQ.GetNttPsi()[i], qi, mredParams, bredParams) + NTTLazy(p3tmp, p3tmp, ringQ.N, ringQ.GetNttPsi()[i], qi, mredParams, bredParams) // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi for j := uint64(0); j < ringQ.N; j = j + 8 { @@ -202,14 +204,14 @@ func (basisextender *FastBasisExtender) ModDownNTTPQ(level uint64, p1, p2 *Poly) y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) - z[0] = MRed(x[0]+(qi-y[0]), params, qi, mredParams) - z[1] = MRed(x[1]+(qi-y[1]), params, qi, mredParams) - z[2] = MRed(x[2]+(qi-y[2]), params, qi, mredParams) - z[3] = MRed(x[3]+(qi-y[3]), params, qi, mredParams) - z[4] = MRed(x[4]+(qi-y[4]), params, qi, mredParams) - z[5] = MRed(x[5]+(qi-y[5]), 
params, qi, mredParams) - z[6] = MRed(x[6]+(qi-y[6]), params, qi, mredParams) - z[7] = MRed(x[7]+(qi-y[7]), params, qi, mredParams) + z[0] = MRed(y[0]+twoqi-x[0], params, qi, mredParams) + z[1] = MRed(y[1]+twoqi-x[1], params, qi, mredParams) + z[2] = MRed(y[2]+twoqi-x[2], params, qi, mredParams) + z[3] = MRed(y[3]+twoqi-x[3], params, qi, mredParams) + z[4] = MRed(y[4]+twoqi-x[4], params, qi, mredParams) + z[5] = MRed(y[5]+twoqi-x[5], params, qi, mredParams) + z[6] = MRed(y[6]+twoqi-x[6], params, qi, mredParams) + z[7] = MRed(y[7]+twoqi-x[7], params, qi, mredParams) } } @@ -229,7 +231,7 @@ func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level uint64, p1Q, p1P polypool := basisextender.polypoolQ // First we get the P basis part of p1 out of the NTT domain - ringP.InvNTT(p1P, p1P) + ringP.InvNTTLazy(p1P, p1P) // Then we target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool // polypool is now the representation of the P basis of p1 but in basis Q (at the "level" of p1) @@ -239,15 +241,16 @@ func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level uint64, p1Q, p1P for i := uint64(0); i < level+1; i++ { qi := ringQ.Modulus[i] + twoqi := qi << 1 p1tmp := p1Q.Coeffs[i] p2tmp := p2.Coeffs[i] p3tmp := polypool.Coeffs[i] - params := modDownParams[i] + params := qi - modDownParams[i] mredParams := ringQ.MredParams[i] bredParams := ringQ.BredParams[i] // First we switch back the relevant polypool CRT array back to the NTT domain - NTT(p3tmp, p3tmp, ringQ.N, ringQ.GetNttPsi()[i], ringQ.Modulus[i], mredParams, bredParams) + NTTLazy(p3tmp, p3tmp, ringQ.N, ringQ.GetNttPsi()[i], ringQ.Modulus[i], mredParams, bredParams) // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi for j := uint64(0); j < ringQ.N; j = j + 8 { @@ -256,14 +259,14 @@ func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level uint64, p1Q, p1P y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) z := 
(*[8]uint64)(unsafe.Pointer(&p2tmp[j])) - z[0] = MRed(x[0]+(qi-y[0]), params, qi, mredParams) - z[1] = MRed(x[1]+(qi-y[1]), params, qi, mredParams) - z[2] = MRed(x[2]+(qi-y[2]), params, qi, mredParams) - z[3] = MRed(x[3]+(qi-y[3]), params, qi, mredParams) - z[4] = MRed(x[4]+(qi-y[4]), params, qi, mredParams) - z[5] = MRed(x[5]+(qi-y[5]), params, qi, mredParams) - z[6] = MRed(x[6]+(qi-y[6]), params, qi, mredParams) - z[7] = MRed(x[7]+(qi-y[7]), params, qi, mredParams) + z[0] = MRed(y[0]+twoqi-x[0], params, qi, mredParams) + z[1] = MRed(y[1]+twoqi-x[1], params, qi, mredParams) + z[2] = MRed(y[2]+twoqi-x[2], params, qi, mredParams) + z[3] = MRed(y[3]+twoqi-x[3], params, qi, mredParams) + z[4] = MRed(y[4]+twoqi-x[4], params, qi, mredParams) + z[5] = MRed(y[5]+twoqi-x[5], params, qi, mredParams) + z[6] = MRed(y[6]+twoqi-x[6], params, qi, mredParams) + z[7] = MRed(y[7]+twoqi-x[7], params, qi, mredParams) } } @@ -289,10 +292,11 @@ func (basisextender *FastBasisExtender) ModDownPQ(level uint64, p1, p2 *Poly) { for i := uint64(0); i < level+1; i++ { qi := ringQ.Modulus[i] + twoqi := qi << 1 p1tmp := p1.Coeffs[i] p2tmp := p2.Coeffs[i] p3tmp := polypool.Coeffs[i] - params := modDownParams[i] + params := qi - modDownParams[i] mredParams := ringQ.MredParams[i] // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi @@ -302,14 +306,14 @@ func (basisextender *FastBasisExtender) ModDownPQ(level uint64, p1, p2 *Poly) { y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) - z[0] = MRed(x[0]+(qi-y[0]), params, qi, mredParams) - z[1] = MRed(x[1]+(qi-y[1]), params, qi, mredParams) - z[2] = MRed(x[2]+(qi-y[2]), params, qi, mredParams) - z[3] = MRed(x[3]+(qi-y[3]), params, qi, mredParams) - z[4] = MRed(x[4]+(qi-y[4]), params, qi, mredParams) - z[5] = MRed(x[5]+(qi-y[5]), params, qi, mredParams) - z[6] = MRed(x[6]+(qi-y[6]), params, qi, mredParams) - z[7] = MRed(x[7]+(qi-y[7]), params, qi, mredParams) + z[0] = 
MRed(y[0]+twoqi-x[0], params, qi, mredParams) + z[1] = MRed(y[1]+twoqi-x[1], params, qi, mredParams) + z[2] = MRed(y[2]+twoqi-x[2], params, qi, mredParams) + z[3] = MRed(y[3]+twoqi-x[3], params, qi, mredParams) + z[4] = MRed(y[4]+twoqi-x[4], params, qi, mredParams) + z[5] = MRed(y[5]+twoqi-x[5], params, qi, mredParams) + z[6] = MRed(y[6]+twoqi-x[6], params, qi, mredParams) + z[7] = MRed(y[7]+twoqi-x[7], params, qi, mredParams) } } @@ -334,10 +338,11 @@ func (basisextender *FastBasisExtender) ModDownSplitPQ(level uint64, p1Q, p1P, p for i := uint64(0); i < level+1; i++ { qi := ringQ.Modulus[i] + twoqi := qi << 1 p1tmp := p1Q.Coeffs[i] p2tmp := p2.Coeffs[i] p3tmp := polypool.Coeffs[i] - params := modDownParams[i] + params := qi - modDownParams[i] mredParams := ringQ.MredParams[i] // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi @@ -347,14 +352,14 @@ func (basisextender *FastBasisExtender) ModDownSplitPQ(level uint64, p1Q, p1P, p y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) - z[0] = MRed(x[0]+(qi-y[0]), params, qi, mredParams) - z[1] = MRed(x[1]+(qi-y[1]), params, qi, mredParams) - z[2] = MRed(x[2]+(qi-y[2]), params, qi, mredParams) - z[3] = MRed(x[3]+(qi-y[3]), params, qi, mredParams) - z[4] = MRed(x[4]+(qi-y[4]), params, qi, mredParams) - z[5] = MRed(x[5]+(qi-y[5]), params, qi, mredParams) - z[6] = MRed(x[6]+(qi-y[6]), params, qi, mredParams) - z[7] = MRed(x[7]+(qi-y[7]), params, qi, mredParams) + z[0] = MRed(y[0]+twoqi-x[0], params, qi, mredParams) + z[1] = MRed(y[1]+twoqi-x[1], params, qi, mredParams) + z[2] = MRed(y[2]+twoqi-x[2], params, qi, mredParams) + z[3] = MRed(y[3]+twoqi-x[3], params, qi, mredParams) + z[4] = MRed(y[4]+twoqi-x[4], params, qi, mredParams) + z[5] = MRed(y[5]+twoqi-x[5], params, qi, mredParams) + z[6] = MRed(y[6]+twoqi-x[6], params, qi, mredParams) + z[7] = MRed(y[7]+twoqi-x[7], params, qi, mredParams) } } @@ -379,117 +384,55 @@ func (basisextender 
*FastBasisExtender) ModDownSplitQP(levelQ, levelP uint64, p1 for i := uint64(0); i < levelP+1; i++ { qi := ringP.Modulus[i] + twoqi := qi << 1 p1tmp := p1P.Coeffs[i] p2tmp := p2.Coeffs[i] p3tmp := polypool.Coeffs[i] - params := modDownParams[i] + params := qi - modDownParams[i] mredParams := ringP.MredParams[i] // Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi - for j := uint64(0); j < ringP.N; j++ { - p2tmp[j] = MRed(p1tmp[j]+(qi-p3tmp[j]), params, qi, mredParams) + for j := uint64(0); j < ringP.N; j = j + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) + y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) + z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) + + z[0] = MRed(y[0]+twoqi-x[0], params, qi, mredParams) + z[1] = MRed(y[1]+twoqi-x[1], params, qi, mredParams) + z[2] = MRed(y[2]+twoqi-x[2], params, qi, mredParams) + z[3] = MRed(y[3]+twoqi-x[3], params, qi, mredParams) + z[4] = MRed(y[4]+twoqi-x[4], params, qi, mredParams) + z[5] = MRed(y[5]+twoqi-x[5], params, qi, mredParams) + z[6] = MRed(y[6]+twoqi-x[6], params, qi, mredParams) + z[7] = MRed(y[7]+twoqi-x[7], params, qi, mredParams) } } // In total we do len(P) + len(Q) NTT, which is optimal (linear in the number of moduli of P and Q) } +// Caution, returns the values in [0, 2q-1] func modUpExact(p1, p2 [][]uint64, params *modupParams) { - var v0, v1, v2, v3, v4, v5, v6, v7 uint64 - var vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi7 float64 - var xpj0, xpj1, xpj2, xpj3, xpj4, xpj5, xpj6, xpj7 uint64 - - y0 := make([]uint64, len(p1), len(p1)) - y1 := make([]uint64, len(p1), len(p1)) - y2 := make([]uint64, len(p1), len(p1)) - y3 := make([]uint64, len(p1), len(p1)) - y4 := make([]uint64, len(p1), len(p1)) - y5 := make([]uint64, len(p1), len(p1)) - y6 := make([]uint64, len(p1), len(p1)) - y7 := make([]uint64, len(p1), len(p1)) - - var qibMont, qi, pj, mredParams uint64 - var qif float64 + var v [8]uint64 + var y0, y1, y2, y3, y4, y5, y6, y7 [32]uint64 // We loop over each coefficient and 
apply the basis extension for x := uint64(0); x < uint64(len(p1[0])); x = x + 8 { - vi0, vi1, vi2, vi3, vi4, vi5, vi6, vi7 = 0, 0, 0, 0, 0, 0, 0, 0 - - for i := 0; i < len(p1); i++ { - - qibMont = params.qibMont[i] - qi = params.Q[i] - mredParams = params.mredParamsQ[i] - qif = float64(qi) - - y0[i] = MRed(p1[i][x], qibMont, qi, mredParams) - y1[i] = MRed(p1[i][x+1], qibMont, qi, mredParams) - y2[i] = MRed(p1[i][x+2], qibMont, qi, mredParams) - y3[i] = MRed(p1[i][x+3], qibMont, qi, mredParams) - y4[i] = MRed(p1[i][x+4], qibMont, qi, mredParams) - y5[i] = MRed(p1[i][x+5], qibMont, qi, mredParams) - y6[i] = MRed(p1[i][x+6], qibMont, qi, mredParams) - y7[i] = MRed(p1[i][x+7], qibMont, qi, mredParams) - - // Computation of the correction term v * Q%pi - vi0 += float64(y0[i]) / qif - vi1 += float64(y1[i]) / qif - vi2 += float64(y2[i]) / qif - vi3 += float64(y3[i]) / qif - vi4 += float64(y4[i]) / qif - vi5 += float64(y5[i]) / qif - vi6 += float64(y6[i]) / qif - vi7 += float64(y7[i]) / qif - } - - // Index of the correction term - v0, v1, v2, v3, v4, v5, v6, v7 = uint64(vi0), uint64(vi1), uint64(vi2), uint64(vi3), uint64(vi4), uint64(vi5), uint64(vi6), uint64(vi7) + reconstructRNS(uint64(len(p1)), x, p1, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, params.Q, params.mredParamsQ, params.qibMont) for j := 0; j < len(p2); j++ { - xpj0, xpj1, xpj2, xpj3, xpj4, xpj5, xpj6, xpj7 = 0, 0, 0, 0, 0, 0, 0, 0 - - pj = params.P[j] - mredParams = params.mredParamsP[j] - bredParams := params.bredParamsP[j] + pj := params.P[j] + qInv := params.mredParamsP[j] qpjInv := params.qpjInv[j] qispjMont := params.qispjMont[j] - res := p2[j] - for i := 0; i < len(p1); i++ { - - xpj0 += MRed(y0[i], qispjMont[i], pj, mredParams) - xpj1 += MRed(y1[i], qispjMont[i], pj, mredParams) - xpj2 += MRed(y2[i], qispjMont[i], pj, mredParams) - xpj3 += MRed(y3[i], qispjMont[i], pj, mredParams) - xpj4 += MRed(y4[i], qispjMont[i], pj, mredParams) - xpj5 += MRed(y5[i], qispjMont[i], pj, mredParams) - xpj6 += 
MRed(y6[i], qispjMont[i], pj, mredParams) - xpj7 += MRed(y7[i], qispjMont[i], pj, mredParams) - - if i&7 == 6 { //Only every 7 additions, since we add one more 60 bit integer after the loop - xpj0 = BRedAdd(xpj0, pj, bredParams) - xpj1 = BRedAdd(xpj1, pj, bredParams) - xpj2 = BRedAdd(xpj2, pj, bredParams) - xpj3 = BRedAdd(xpj3, pj, bredParams) - xpj4 = BRedAdd(xpj4, pj, bredParams) - xpj5 = BRedAdd(xpj5, pj, bredParams) - xpj6 = BRedAdd(xpj6, pj, bredParams) - xpj7 = BRedAdd(xpj7, pj, bredParams) - } - } - - res[x+0] = BRedAdd(xpj0+qpjInv[v0], pj, bredParams) - res[x+1] = BRedAdd(xpj1+qpjInv[v1], pj, bredParams) - res[x+2] = BRedAdd(xpj2+qpjInv[v2], pj, bredParams) - res[x+3] = BRedAdd(xpj3+qpjInv[v3], pj, bredParams) - res[x+4] = BRedAdd(xpj4+qpjInv[v4], pj, bredParams) - res[x+5] = BRedAdd(xpj5+qpjInv[v5], pj, bredParams) - res[x+6] = BRedAdd(xpj6+qpjInv[v6], pj, bredParams) - res[x+7] = BRedAdd(xpj7+qpjInv[v7], pj, bredParams) + res := (*[8]uint64)(unsafe.Pointer(&p2[j][x])) + multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, uint64(len(p1)), pj, qInv, qpjInv, qispjMont) } } } @@ -573,259 +516,6 @@ func NewDecomposer(Q, P []uint64) (decomposer *Decomposer) { return } -// Decompose decomposes a polynomial p(x) in basis Q, reduces it modulo qi, and returns -// the result in basis QP. 
-func (decomposer *Decomposer) Decompose(level, crtDecompLevel uint64, p0, p1 *Poly) { - - alphai := decomposer.xalpha[crtDecompLevel] - - p0idxst := crtDecompLevel * decomposer.alpha - p0idxed := p0idxst + alphai - - // First we check if the vector can simply by coping and rearranging elements (the case where no reconstruction is needed) - if (p0idxed > level+1 && (level+1)%decomposer.nPprimes == 1) || alphai == 1 { - - for x := uint64(0); x < uint64(len(p0.Coeffs[0])); x = x + 8 { - - tmp := p0.Coeffs[p0idxst] - - for j := uint64(0); j < level+decomposer.nPprimes+1; j++ { - - p1.Coeffs[j][x+0] = tmp[x+0] - p1.Coeffs[j][x+1] = tmp[x+1] - p1.Coeffs[j][x+2] = tmp[x+2] - p1.Coeffs[j][x+3] = tmp[x+3] - p1.Coeffs[j][x+4] = tmp[x+4] - p1.Coeffs[j][x+5] = tmp[x+5] - p1.Coeffs[j][x+6] = tmp[x+6] - p1.Coeffs[j][x+7] = tmp[x+7] - } - } - - // Otherwise, we apply a fast exact base conversion for the reconstruction - } else { - - var index uint64 - if level >= alphai+crtDecompLevel*decomposer.alpha { - index = decomposer.xalpha[crtDecompLevel] - 2 - } else { - index = (level - 1) % decomposer.alpha - } - - params := decomposer.modUpParams[crtDecompLevel][index] - - v := make([]uint64, 8, 8) - vi := make([]float64, 8, 8) - xpj := make([]uint64, 8, 8) - - y0 := make([]uint64, index+2, index+2) - y1 := make([]uint64, index+2, index+2) - y2 := make([]uint64, index+2, index+2) - y3 := make([]uint64, index+2, index+2) - y4 := make([]uint64, index+2, index+2) - y5 := make([]uint64, index+2, index+2) - y6 := make([]uint64, index+2, index+2) - y7 := make([]uint64, index+2, index+2) - - var qibMont, qi, pj, mredParams uint64 - var qif float64 - - // We loop over each coefficient and apply the basis extension - for x := uint64(0); x < uint64(len(p0.Coeffs[0])); x = x + 8 { - - vi[0], vi[1], vi[2], vi[3], vi[4], vi[5], vi[6], vi[7] = 0, 0, 0, 0, 0, 0, 0, 0 - - // Coefficients to be decomposed - for i, j := uint64(0), p0idxst; i < index+2; i, j = i+1, j+1 { - - qibMont = params.qibMont[i] 
- qi = params.Q[i] - mredParams = params.mredParamsQ[i] - qif = float64(qi) - - px := p0.Coeffs[j] - py := p1.Coeffs[j] - - // For the coefficients to be decomposed, we can simply copy them - py[x+0] = px[x+0] - py[x+1] = px[x+1] - py[x+2] = px[x+2] - py[x+3] = px[x+3] - py[x+4] = px[x+4] - py[x+5] = px[x+5] - py[x+6] = px[x+6] - py[x+7] = px[x+7] - - y0[i] = MRed(px[x+0], qibMont, qi, mredParams) - y1[i] = MRed(px[x+1], qibMont, qi, mredParams) - y2[i] = MRed(px[x+2], qibMont, qi, mredParams) - y3[i] = MRed(px[x+3], qibMont, qi, mredParams) - y4[i] = MRed(px[x+4], qibMont, qi, mredParams) - y5[i] = MRed(px[x+5], qibMont, qi, mredParams) - y6[i] = MRed(px[x+6], qibMont, qi, mredParams) - y7[i] = MRed(px[x+7], qibMont, qi, mredParams) - - // Computation of the correction term v * Q%pi - vi[0] += float64(y0[i]) / qif - vi[1] += float64(y1[i]) / qif - vi[2] += float64(y2[i]) / qif - vi[3] += float64(y3[i]) / qif - vi[4] += float64(y4[i]) / qif - vi[5] += float64(y5[i]) / qif - vi[6] += float64(y6[i]) / qif - vi[7] += float64(y7[i]) / qif - } - - // Index of the correction term - v[0] = uint64(vi[0]) - v[1] = uint64(vi[1]) - v[2] = uint64(vi[2]) - v[3] = uint64(vi[3]) - v[4] = uint64(vi[4]) - v[5] = uint64(vi[5]) - v[6] = uint64(vi[6]) - v[7] = uint64(vi[7]) - - // Coefficients of index smaller than the ones to be decomposed - for j := uint64(0); j < p0idxst; j++ { - - xpj[0], xpj[1], xpj[2], xpj[3], xpj[4], xpj[5], xpj[6], xpj[7] = 0, 0, 0, 0, 0, 0, 0, 0 - - pj = params.P[j] - mredParams = params.mredParamsP[j] - bredParams := params.bredParamsP[j] - qpjInv := params.qpjInv[j] - qispjMont := params.qispjMont[j] - res := p1.Coeffs[j] - - for i := uint64(0); i < index+2; i++ { - - xpj[0] += MRed(y0[i], qispjMont[i], pj, mredParams) - xpj[1] += MRed(y1[i], qispjMont[i], pj, mredParams) - xpj[2] += MRed(y2[i], qispjMont[i], pj, mredParams) - xpj[3] += MRed(y3[i], qispjMont[i], pj, mredParams) - xpj[4] += MRed(y4[i], qispjMont[i], pj, mredParams) - xpj[5] += MRed(y5[i], 
qispjMont[i], pj, mredParams) - xpj[6] += MRed(y6[i], qispjMont[i], pj, mredParams) - xpj[7] += MRed(y7[i], qispjMont[i], pj, mredParams) - - if i&7 == 6 { // Only every 7 additions, since we add one more 60 bit integer after the loop - xpj[0] = BRedAdd(xpj[0], pj, bredParams) - xpj[1] = BRedAdd(xpj[1], pj, bredParams) - xpj[2] = BRedAdd(xpj[2], pj, bredParams) - xpj[3] = BRedAdd(xpj[3], pj, bredParams) - xpj[4] = BRedAdd(xpj[4], pj, bredParams) - xpj[5] = BRedAdd(xpj[5], pj, bredParams) - xpj[6] = BRedAdd(xpj[6], pj, bredParams) - xpj[7] = BRedAdd(xpj[7], pj, bredParams) - } - } - - res[x+0] = BRedAdd(xpj[0]+qpjInv[v[0]], pj, bredParams) - res[x+1] = BRedAdd(xpj[1]+qpjInv[v[1]], pj, bredParams) - res[x+2] = BRedAdd(xpj[2]+qpjInv[v[2]], pj, bredParams) - res[x+3] = BRedAdd(xpj[3]+qpjInv[v[3]], pj, bredParams) - res[x+4] = BRedAdd(xpj[4]+qpjInv[v[4]], pj, bredParams) - res[x+5] = BRedAdd(xpj[5]+qpjInv[v[5]], pj, bredParams) - res[x+6] = BRedAdd(xpj[6]+qpjInv[v[6]], pj, bredParams) - res[x+7] = BRedAdd(xpj[7]+qpjInv[v[7]], pj, bredParams) - - } - - // Coefficients of index greater than the ones to be decomposed - for j := decomposer.alpha * crtDecompLevel; j < level+1; j = j + 1 { - - xpj[0], xpj[1], xpj[2], xpj[3], xpj[4], xpj[5], xpj[6], xpj[7] = 0, 0, 0, 0, 0, 0, 0, 0 - - pj = params.P[j] - mredParams = params.mredParamsP[j] - bredParams := params.bredParamsP[j] - qpjInv := params.qpjInv[j] - qispjMont := params.qispjMont[j] - res := p1.Coeffs[j] - - for i := uint64(0); i < index+2; i++ { - - xpj[0] += MRed(y0[i], qispjMont[i], pj, mredParams) - xpj[1] += MRed(y1[i], qispjMont[i], pj, mredParams) - xpj[2] += MRed(y2[i], qispjMont[i], pj, mredParams) - xpj[3] += MRed(y3[i], qispjMont[i], pj, mredParams) - xpj[4] += MRed(y4[i], qispjMont[i], pj, mredParams) - xpj[5] += MRed(y5[i], qispjMont[i], pj, mredParams) - xpj[6] += MRed(y6[i], qispjMont[i], pj, mredParams) - xpj[7] += MRed(y7[i], qispjMont[i], pj, mredParams) - - if i&7 == 6 { // Only every 7 additions, since 
we add one more 60 bit integer after the loop - xpj[0] = BRedAdd(xpj[0], pj, bredParams) - xpj[1] = BRedAdd(xpj[1], pj, bredParams) - xpj[2] = BRedAdd(xpj[2], pj, bredParams) - xpj[3] = BRedAdd(xpj[3], pj, bredParams) - xpj[4] = BRedAdd(xpj[4], pj, bredParams) - xpj[5] = BRedAdd(xpj[5], pj, bredParams) - xpj[6] = BRedAdd(xpj[6], pj, bredParams) - xpj[7] = BRedAdd(xpj[7], pj, bredParams) - } - } - - res[x+0] = BRedAdd(xpj[0]+qpjInv[v[0]], pj, bredParams) - res[x+1] = BRedAdd(xpj[1]+qpjInv[v[1]], pj, bredParams) - res[x+2] = BRedAdd(xpj[2]+qpjInv[v[2]], pj, bredParams) - res[x+3] = BRedAdd(xpj[3]+qpjInv[v[3]], pj, bredParams) - res[x+4] = BRedAdd(xpj[4]+qpjInv[v[4]], pj, bredParams) - res[x+5] = BRedAdd(xpj[5]+qpjInv[v[5]], pj, bredParams) - res[x+6] = BRedAdd(xpj[6]+qpjInv[v[6]], pj, bredParams) - res[x+7] = BRedAdd(xpj[7]+qpjInv[v[7]], pj, bredParams) - - } - - // Coefficients of the special primes - for j, u := level+1, decomposer.nQprimes; j < level+1+decomposer.nPprimes; j, u = u+1, j+1 { - - xpj[0], xpj[1], xpj[2], xpj[3], xpj[4], xpj[5], xpj[6], xpj[7] = 0, 0, 0, 0, 0, 0, 0, 0 - - pj = params.P[j] - mredParams = params.mredParamsP[j] - bredParams := params.bredParamsP[j] - qpjInv := params.qpjInv[j] - qispjMont := params.qispjMont[j] - res := p1.Coeffs[u] - - for i := uint64(0); i < index+2; i++ { - - xpj[0] += MRed(y0[i], qispjMont[i], pj, mredParams) - xpj[1] += MRed(y1[i], qispjMont[i], pj, mredParams) - xpj[2] += MRed(y2[i], qispjMont[i], pj, mredParams) - xpj[3] += MRed(y3[i], qispjMont[i], pj, mredParams) - xpj[4] += MRed(y4[i], qispjMont[i], pj, mredParams) - xpj[5] += MRed(y5[i], qispjMont[i], pj, mredParams) - xpj[6] += MRed(y6[i], qispjMont[i], pj, mredParams) - xpj[7] += MRed(y7[i], qispjMont[i], pj, mredParams) - - if i&7 == 6 { // Only every 7 additions, since we add one more 60 bit integer after the loop - xpj[0] = BRedAdd(xpj[0], pj, bredParams) - xpj[1] = BRedAdd(xpj[1], pj, bredParams) - xpj[2] = BRedAdd(xpj[2], pj, bredParams) - xpj[3] = 
BRedAdd(xpj[3], pj, bredParams) - xpj[4] = BRedAdd(xpj[4], pj, bredParams) - xpj[5] = BRedAdd(xpj[5], pj, bredParams) - xpj[6] = BRedAdd(xpj[6], pj, bredParams) - xpj[7] = BRedAdd(xpj[7], pj, bredParams) - } - } - - res[x+0] = BRedAdd(xpj[0]+qpjInv[v[0]], pj, bredParams) - res[x+1] = BRedAdd(xpj[1]+qpjInv[v[1]], pj, bredParams) - res[x+2] = BRedAdd(xpj[2]+qpjInv[v[2]], pj, bredParams) - res[x+3] = BRedAdd(xpj[3]+qpjInv[v[3]], pj, bredParams) - res[x+4] = BRedAdd(xpj[4]+qpjInv[v[4]], pj, bredParams) - res[x+5] = BRedAdd(xpj[5]+qpjInv[v[5]], pj, bredParams) - res[x+6] = BRedAdd(xpj[6]+qpjInv[v[6]], pj, bredParams) - res[x+7] = BRedAdd(xpj[7]+qpjInv[v[7]], pj, bredParams) - } - } - } -} - // DecomposeAndSplit decomposes a polynomial p(x) in basis Q, reduces it modulo qi, and returns // the result in basis QP separately. func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0, p1Q, p1P *Poly) { @@ -838,33 +528,12 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0 // First we check if the vector can simply by coping and rearranging elements (the case where no reconstruction is needed) if (p0idxed > level+1 && (level+1)%decomposer.nPprimes == 1) || alphai == 1 { - for x := uint64(0); x < uint64(len(p0.Coeffs[0])); x = x + 8 { + for j := uint64(0); j < level+1; j++ { + copy(p1Q.Coeffs[j], p0.Coeffs[p0idxst]) + } - tmp := p0.Coeffs[p0idxst] - - for j := uint64(0); j < level+1; j++ { - - p1Q.Coeffs[j][x+0] = tmp[x+0] - p1Q.Coeffs[j][x+1] = tmp[x+1] - p1Q.Coeffs[j][x+2] = tmp[x+2] - p1Q.Coeffs[j][x+3] = tmp[x+3] - p1Q.Coeffs[j][x+4] = tmp[x+4] - p1Q.Coeffs[j][x+5] = tmp[x+5] - p1Q.Coeffs[j][x+6] = tmp[x+6] - p1Q.Coeffs[j][x+7] = tmp[x+7] - } - - for j := uint64(0); j < decomposer.nPprimes; j++ { - - p1P.Coeffs[j][x+0] = tmp[x+0] - p1P.Coeffs[j][x+1] = tmp[x+1] - p1P.Coeffs[j][x+2] = tmp[x+2] - p1P.Coeffs[j][x+3] = tmp[x+3] - p1P.Coeffs[j][x+4] = tmp[x+4] - p1P.Coeffs[j][x+5] = tmp[x+5] - p1P.Coeffs[j][x+6] = tmp[x+6] 
- p1P.Coeffs[j][x+7] = tmp[x+7] - } + for j := uint64(0); j < decomposer.nPprimes; j++ { + copy(p1P.Coeffs[j], p0.Coeffs[p0idxst]) } // Otherwise, we apply a fast exact base conversion for the reconstruction @@ -879,19 +548,9 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0 params := decomposer.modUpParams[crtDecompLevel][index] - v := make([]uint64, 8, 8) - vi := make([]float64, 8, 8) - xpj := make([]uint64, 8, 8) - - y0 := make([]uint64, index+2, index+2) - y1 := make([]uint64, index+2, index+2) - y2 := make([]uint64, index+2, index+2) - y3 := make([]uint64, index+2, index+2) - y4 := make([]uint64, index+2, index+2) - y5 := make([]uint64, index+2, index+2) - y6 := make([]uint64, index+2, index+2) - y7 := make([]uint64, index+2, index+2) - + var v [8]uint64 + var vi [8]float64 + var y0, y1, y2, y3, y4, y5, y6, y7 [32]uint64 var qibMont, qi, pj, mredParams uint64 var qif float64 @@ -908,27 +567,20 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0 mredParams = params.mredParamsQ[i] qif = float64(qi) - px := p0.Coeffs[j] - py := p1Q.Coeffs[j] + px := (*[8]uint64)(unsafe.Pointer(&p0.Coeffs[j][x])) + py := (*[8]uint64)(unsafe.Pointer(&p1Q.Coeffs[j][x])) // For the coefficients to be decomposed, we can simply copy them - py[x+0] = px[x+0] - py[x+1] = px[x+1] - py[x+2] = px[x+2] - py[x+3] = px[x+3] - py[x+4] = px[x+4] - py[x+5] = px[x+5] - py[x+6] = px[x+6] - py[x+7] = px[x+7] + py[0], py[1], py[2], py[3], py[4], py[5], py[6], py[7] = px[0], px[1], px[2], px[3], px[4], px[5], px[6], px[7] - y0[i] = MRed(px[x+0], qibMont, qi, mredParams) - y1[i] = MRed(px[x+1], qibMont, qi, mredParams) - y2[i] = MRed(px[x+2], qibMont, qi, mredParams) - y3[i] = MRed(px[x+3], qibMont, qi, mredParams) - y4[i] = MRed(px[x+4], qibMont, qi, mredParams) - y5[i] = MRed(px[x+5], qibMont, qi, mredParams) - y6[i] = MRed(px[x+6], qibMont, qi, mredParams) - y7[i] = MRed(px[x+7], qibMont, qi, mredParams) + y0[i] = MRed(px[0], qibMont, 
qi, mredParams) + y1[i] = MRed(px[1], qibMont, qi, mredParams) + y2[i] = MRed(px[2], qibMont, qi, mredParams) + y3[i] = MRed(px[3], qibMont, qi, mredParams) + y4[i] = MRed(px[4], qibMont, qi, mredParams) + y5[i] = MRed(px[5], qibMont, qi, mredParams) + y6[i] = MRed(px[6], qibMont, qi, mredParams) + y7[i] = MRed(px[7], qibMont, qi, mredParams) // Computation of the correction term v * Q%pi vi[0] += float64(y0[i]) / qif @@ -954,140 +606,151 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0 // Coefficients of index smaller than the ones to be decomposed for j := uint64(0); j < p0idxst; j++ { - xpj[0], xpj[1], xpj[2], xpj[3], xpj[4], xpj[5], xpj[6], xpj[7] = 0, 0, 0, 0, 0, 0, 0, 0 - pj = params.P[j] - mredParams = params.mredParamsP[j] - bredParams := params.bredParamsP[j] + qInv := params.mredParamsP[j] qpjInv := params.qpjInv[j] qispjMont := params.qispjMont[j] - res := p1Q.Coeffs[j] - for i := uint64(0); i < index+2; i++ { - - xpj[0] += MRed(y0[i], qispjMont[i], pj, mredParams) - xpj[1] += MRed(y1[i], qispjMont[i], pj, mredParams) - xpj[2] += MRed(y2[i], qispjMont[i], pj, mredParams) - xpj[3] += MRed(y3[i], qispjMont[i], pj, mredParams) - xpj[4] += MRed(y4[i], qispjMont[i], pj, mredParams) - xpj[5] += MRed(y5[i], qispjMont[i], pj, mredParams) - xpj[6] += MRed(y6[i], qispjMont[i], pj, mredParams) - xpj[7] += MRed(y7[i], qispjMont[i], pj, mredParams) - - if i&7 == 6 { // Only every 7 additions, since we add one more 60 bit integer after the loop - xpj[0] = BRedAdd(xpj[0], pj, bredParams) - xpj[1] = BRedAdd(xpj[1], pj, bredParams) - xpj[2] = BRedAdd(xpj[2], pj, bredParams) - xpj[3] = BRedAdd(xpj[3], pj, bredParams) - xpj[4] = BRedAdd(xpj[4], pj, bredParams) - xpj[5] = BRedAdd(xpj[5], pj, bredParams) - xpj[6] = BRedAdd(xpj[6], pj, bredParams) - xpj[7] = BRedAdd(xpj[7], pj, bredParams) - } - } - - res[x+0] = BRedAdd(xpj[0]+qpjInv[v[0]], pj, bredParams) - res[x+1] = BRedAdd(xpj[1]+qpjInv[v[1]], pj, bredParams) - res[x+2] = 
BRedAdd(xpj[2]+qpjInv[v[2]], pj, bredParams) - res[x+3] = BRedAdd(xpj[3]+qpjInv[v[3]], pj, bredParams) - res[x+4] = BRedAdd(xpj[4]+qpjInv[v[4]], pj, bredParams) - res[x+5] = BRedAdd(xpj[5]+qpjInv[v[5]], pj, bredParams) - res[x+6] = BRedAdd(xpj[6]+qpjInv[v[6]], pj, bredParams) - res[x+7] = BRedAdd(xpj[7]+qpjInv[v[7]], pj, bredParams) + res := (*[8]uint64)(unsafe.Pointer(&p1Q.Coeffs[j][x])) + multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, index+2, pj, qInv, qpjInv, qispjMont) } // Coefficients of index greater than the ones to be decomposed for j := decomposer.alpha * crtDecompLevel; j < level+1; j = j + 1 { - xpj[0], xpj[1], xpj[2], xpj[3], xpj[4], xpj[5], xpj[6], xpj[7] = 0, 0, 0, 0, 0, 0, 0, 0 - pj = params.P[j] - mredParams = params.mredParamsP[j] - bredParams := params.bredParamsP[j] + qInv := params.mredParamsP[j] qpjInv := params.qpjInv[j] qispjMont := params.qispjMont[j] - res := p1Q.Coeffs[j] - for i := uint64(0); i < index+2; i++ { - - xpj[0] += MRed(y0[i], qispjMont[i], pj, mredParams) - xpj[1] += MRed(y1[i], qispjMont[i], pj, mredParams) - xpj[2] += MRed(y2[i], qispjMont[i], pj, mredParams) - xpj[3] += MRed(y3[i], qispjMont[i], pj, mredParams) - xpj[4] += MRed(y4[i], qispjMont[i], pj, mredParams) - xpj[5] += MRed(y5[i], qispjMont[i], pj, mredParams) - xpj[6] += MRed(y6[i], qispjMont[i], pj, mredParams) - xpj[7] += MRed(y7[i], qispjMont[i], pj, mredParams) - - if i&7 == 6 { // Only every 7 additions, since we add one more 60 bit integer after the loop - xpj[0] = BRedAdd(xpj[0], pj, bredParams) - xpj[1] = BRedAdd(xpj[1], pj, bredParams) - xpj[2] = BRedAdd(xpj[2], pj, bredParams) - xpj[3] = BRedAdd(xpj[3], pj, bredParams) - xpj[4] = BRedAdd(xpj[4], pj, bredParams) - xpj[5] = BRedAdd(xpj[5], pj, bredParams) - xpj[6] = BRedAdd(xpj[6], pj, bredParams) - xpj[7] = BRedAdd(xpj[7], pj, bredParams) - } - } - - res[x+0] = BRedAdd(xpj[0]+qpjInv[v[0]], pj, bredParams) - res[x+1] = BRedAdd(xpj[1]+qpjInv[v[1]], pj, bredParams) - res[x+2] = 
BRedAdd(xpj[2]+qpjInv[v[2]], pj, bredParams) - res[x+3] = BRedAdd(xpj[3]+qpjInv[v[3]], pj, bredParams) - res[x+4] = BRedAdd(xpj[4]+qpjInv[v[4]], pj, bredParams) - res[x+5] = BRedAdd(xpj[5]+qpjInv[v[5]], pj, bredParams) - res[x+6] = BRedAdd(xpj[6]+qpjInv[v[6]], pj, bredParams) - res[x+7] = BRedAdd(xpj[7]+qpjInv[v[7]], pj, bredParams) + res := (*[8]uint64)(unsafe.Pointer(&p1Q.Coeffs[j][x])) + multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, index+2, pj, qInv, qpjInv, qispjMont) } - // Coefficients of the special primes + // Coefficients of the special primes Pi for j, u := uint64(0), decomposer.nQprimes; j < decomposer.nPprimes; j, u = j+1, u+1 { - xpj[0], xpj[1], xpj[2], xpj[3], xpj[4], xpj[5], xpj[6], xpj[7] = 0, 0, 0, 0, 0, 0, 0, 0 - pj = params.P[u] - mredParams = params.mredParamsP[u] - bredParams := params.bredParamsP[u] + qInv := params.mredParamsP[u] qpjInv := params.qpjInv[u] qispjMont := params.qispjMont[u] - res := p1P.Coeffs[j] - for i := uint64(0); i < index+2; i++ { - - xpj[0] += MRed(y0[i], qispjMont[i], pj, mredParams) - xpj[1] += MRed(y1[i], qispjMont[i], pj, mredParams) - xpj[2] += MRed(y2[i], qispjMont[i], pj, mredParams) - xpj[3] += MRed(y3[i], qispjMont[i], pj, mredParams) - xpj[4] += MRed(y4[i], qispjMont[i], pj, mredParams) - xpj[5] += MRed(y5[i], qispjMont[i], pj, mredParams) - xpj[6] += MRed(y6[i], qispjMont[i], pj, mredParams) - xpj[7] += MRed(y7[i], qispjMont[i], pj, mredParams) - - if i&7 == 6 { // Only every 7 additions, since we add one more 60 bit integer after the loop - xpj[0] = BRedAdd(xpj[0], pj, bredParams) - xpj[1] = BRedAdd(xpj[1], pj, bredParams) - xpj[2] = BRedAdd(xpj[2], pj, bredParams) - xpj[3] = BRedAdd(xpj[3], pj, bredParams) - xpj[4] = BRedAdd(xpj[4], pj, bredParams) - xpj[5] = BRedAdd(xpj[5], pj, bredParams) - xpj[6] = BRedAdd(xpj[6], pj, bredParams) - xpj[7] = BRedAdd(xpj[7], pj, bredParams) - } - } - - res[x+0] = BRedAdd(xpj[0]+qpjInv[v[0]], pj, bredParams) - res[x+1] = BRedAdd(xpj[1]+qpjInv[v[1]], pj, 
bredParams) - res[x+2] = BRedAdd(xpj[2]+qpjInv[v[2]], pj, bredParams) - res[x+3] = BRedAdd(xpj[3]+qpjInv[v[3]], pj, bredParams) - res[x+4] = BRedAdd(xpj[4]+qpjInv[v[4]], pj, bredParams) - res[x+5] = BRedAdd(xpj[5]+qpjInv[v[5]], pj, bredParams) - res[x+6] = BRedAdd(xpj[6]+qpjInv[v[6]], pj, bredParams) - res[x+7] = BRedAdd(xpj[7]+qpjInv[v[7]], pj, bredParams) + res := (*[8]uint64)(unsafe.Pointer(&p1P.Coeffs[j][x])) + multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, index+2, pj, qInv, qpjInv, qispjMont) } } } } + +func reconstructRNS(index, x uint64, p [][]uint64, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, Q, QInv, QbMont []uint64) { + + var vi [8]float64 + var qi, qiInv, qibMont uint64 + var qif float64 + + for i := uint64(0); i < index; i++ { + + qibMont = QbMont[i] + qi = Q[i] + qiInv = QInv[i] + qif = float64(qi) + + y0[i] = MRed(p[i][x+0], qibMont, qi, qiInv) + y1[i] = MRed(p[i][x+1], qibMont, qi, qiInv) + y2[i] = MRed(p[i][x+2], qibMont, qi, qiInv) + y3[i] = MRed(p[i][x+3], qibMont, qi, qiInv) + y4[i] = MRed(p[i][x+4], qibMont, qi, qiInv) + y5[i] = MRed(p[i][x+5], qibMont, qi, qiInv) + y6[i] = MRed(p[i][x+6], qibMont, qi, qiInv) + y7[i] = MRed(p[i][x+7], qibMont, qi, qiInv) + + // Computation of the correction term v * Q%pi + vi[0] += float64(y0[i]) / qif + vi[1] += float64(y1[i]) / qif + vi[2] += float64(y2[i]) / qif + vi[3] += float64(y3[i]) / qif + vi[4] += float64(y4[i]) / qif + vi[5] += float64(y5[i]) / qif + vi[6] += float64(y6[i]) / qif + vi[7] += float64(y7[i]) / qif + } + + v[0] = uint64(vi[0]) + v[1] = uint64(vi[1]) + v[2] = uint64(vi[2]) + v[3] = uint64(vi[3]) + v[4] = uint64(vi[4]) + v[5] = uint64(vi[5]) + v[6] = uint64(vi[6]) + v[7] = uint64(vi[7]) +} + +// Caution, returns the values in [0, 2q-1] +func multSum(res, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, index, pj, qInv uint64, qpjInv, qispjMont []uint64) { + + var rlo, rhi [8]uint64 + var mhi, mlo, c, hhi uint64 + + // Accumulates the sum on uint128 and 
does a lazy montgomery reduction at the end + for i := uint64(0); i < index; i++ { + + mhi, mlo = bits.Mul64(y0[i], qispjMont[i]) + rlo[0], c = bits.Add64(rlo[0], mlo, 0) + rhi[0] += mhi + c + + mhi, mlo = bits.Mul64(y1[i], qispjMont[i]) + rlo[1], c = bits.Add64(rlo[1], mlo, 0) + rhi[1] += mhi + c + + mhi, mlo = bits.Mul64(y2[i], qispjMont[i]) + rlo[2], c = bits.Add64(rlo[2], mlo, 0) + rhi[2] += mhi + c + + mhi, mlo = bits.Mul64(y3[i], qispjMont[i]) + rlo[3], c = bits.Add64(rlo[3], mlo, 0) + rhi[3] += mhi + c + + mhi, mlo = bits.Mul64(y4[i], qispjMont[i]) + rlo[4], c = bits.Add64(rlo[4], mlo, 0) + rhi[4] += mhi + c + + mhi, mlo = bits.Mul64(y5[i], qispjMont[i]) + rlo[5], c = bits.Add64(rlo[5], mlo, 0) + rhi[5] += mhi + c + + mhi, mlo = bits.Mul64(y6[i], qispjMont[i]) + rlo[6], c = bits.Add64(rlo[6], mlo, 0) + rhi[6] += mhi + c + + mhi, mlo = bits.Mul64(y7[i], qispjMont[i]) + rlo[7], c = bits.Add64(rlo[7], mlo, 0) + rhi[7] += mhi + c + } + + hhi, _ = bits.Mul64(rlo[0]*qInv, pj) + res[0] = rhi[0] - hhi + pj + qpjInv[v[0]] + + hhi, _ = bits.Mul64(rlo[1]*qInv, pj) + res[1] = rhi[1] - hhi + pj + qpjInv[v[1]] + + hhi, _ = bits.Mul64(rlo[2]*qInv, pj) + res[2] = rhi[2] - hhi + pj + qpjInv[v[2]] + + hhi, _ = bits.Mul64(rlo[3]*qInv, pj) + res[3] = rhi[3] - hhi + pj + qpjInv[v[3]] + + hhi, _ = bits.Mul64(rlo[4]*qInv, pj) + res[4] = rhi[4] - hhi + pj + qpjInv[v[4]] + + hhi, _ = bits.Mul64(rlo[5]*qInv, pj) + res[5] = rhi[5] - hhi + pj + qpjInv[v[5]] + + hhi, _ = bits.Mul64(rlo[6]*qInv, pj) + res[6] = rhi[6] - hhi + pj + qpjInv[v[6]] + + hhi, _ = bits.Mul64(rlo[7]*qInv, pj) + res[7] = rhi[7] - hhi + pj + qpjInv[v[7]] +} diff --git a/ring/ring_ntt.go b/ring/ring_ntt.go index 36e81dbc..4bf67d01 100644 --- a/ring/ring_ntt.go +++ b/ring/ring_ntt.go @@ -1,6 +1,7 @@ package ring import ( + "math/bits" "unsafe" ) @@ -34,142 +35,51 @@ func (r *Ring) InvNTTLvl(level uint64, p1, p2 *Poly) { } } -// butterfly computes X, Y = U + V*Psi, U - V*Psi mod Q. 
-func butterfly(U, V, Psi, Q, Qinv uint64) (X, Y uint64) { - if U > 2*Q { - U -= 2 * Q +// NTTLazy computes the NTT of p1 and returns the result on p2. +// Output values are in the range [0, 2q-1] +func (r *Ring) NTTLazy(p1, p2 *Poly) { + for x := range r.Modulus { + NTTLazy(p1.Coeffs[x], p2.Coeffs[x], r.N, r.NttPsi[x], r.Modulus[x], r.MredParams[x], r.BredParams[x]) } - V = MRedConstant(V, Psi, Q, Qinv) - X = U + V - Y = U + 2*Q - V - return } -// invbutterfly computes X, Y = U + V, (U - V) * Psi mod Q. -func invbutterfly(U, V, Psi, Q, Qinv uint64) (X, Y uint64) { - X = U + V - if X > 2*Q { - X -= 2 * Q +// NTTLazyLvl computes the NTT of p1 and returns the result on p2. +// The value level defines the number of moduli of the input polynomials. +// Output values are in the range [0, 2q-1] +func (r *Ring) NTTLazyLvl(level uint64, p1, p2 *Poly) { + for x := uint64(0); x < level+1; x++ { + NTTLazy(p1.Coeffs[x], p2.Coeffs[x], r.N, r.NttPsi[x], r.Modulus[x], r.MredParams[x], r.BredParams[x]) } - Y = MRedConstant(U+2*Q-V, Psi, Q, Qinv) // At the moment it is not possible to use MRedConstant if Q > 61 bits - return +} + +// InvNTTLazy computes the inverse-NTT of p1 and returns the result on p2. +// Output values are in the range [0, 2q-1] +func (r *Ring) InvNTTLazy(p1, p2 *Poly) { + for x := range r.Modulus { + InvNTTLazy(p1.Coeffs[x], p2.Coeffs[x], r.N, r.NttPsiInv[x], r.NttNInv[x], r.Modulus[x], r.MredParams[x]) + } +} + +// InvNTTLazyLvl computes the inverse-NTT of p1 and returns the result on p2. +// The value level defines the number of moduli of the input polynomials. +// Output values are in the range [0, 2q-1] +func (r *Ring) InvNTTLazyLvl(level uint64, p1, p2 *Poly) { + for x := uint64(0); x < level+1; x++ { + InvNTTLazy(p1.Coeffs[x], p2.Coeffs[x], r.N, r.NttPsiInv[x], r.NttNInv[x], r.Modulus[x], r.MredParams[x]) + } +} + +// butterfly computes X, Y = U + V*Psi, U - V*Psi mod Q. 
+func butterfly(U, V, Psi, twoQ, QBitLen, Q, Qinv uint64) (uint64, uint64) { + U -= (U >> QBitLen) * Q + V = MRedConstant(V, Psi, Q, Qinv) + return U + V, U + twoQ - V } // NTT computes the NTT on the input coefficients using the input parameters. func NTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, mredParams uint64, bredParams []uint64) { - var j1, j2, t uint64 - var F uint64 - - // Copy the result of the first round of butterflies on p2 with approximate reduction - t = N >> 1 - F = nttPsi[1] - for j := uint64(0); j <= t-1; j = j + 8 { - - xin := (*[8]uint64)(unsafe.Pointer(&coeffsIn[j])) - yin := (*[8]uint64)(unsafe.Pointer(&coeffsIn[j+t])) - - xout := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j])) - yout := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j+t])) - - xout[0], yout[0] = butterfly(xin[0], yin[0], F, Q, mredParams) - xout[1], yout[1] = butterfly(xin[1], yin[1], F, Q, mredParams) - xout[2], yout[2] = butterfly(xin[2], yin[2], F, Q, mredParams) - xout[3], yout[3] = butterfly(xin[3], yin[3], F, Q, mredParams) - xout[4], yout[4] = butterfly(xin[4], yin[4], F, Q, mredParams) - xout[5], yout[5] = butterfly(xin[5], yin[5], F, Q, mredParams) - xout[6], yout[6] = butterfly(xin[6], yin[6], F, Q, mredParams) - xout[7], yout[7] = butterfly(xin[7], yin[7], F, Q, mredParams) - } - - // Continue the rest of the second to the n-1 butterflies on p2 with approximate reduction - for m := uint64(2); m < N; m <<= 1 { - - t >>= 1 - - if t >= 8 { - - for i := uint64(0); i < m; i++ { - - j1 = (i * t) << 1 - - j2 = j1 + t - 1 - - F = nttPsi[m+i] - - for j := j1; j <= j2; j = j + 8 { - - x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j])) - y := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j+t])) - - x[0], y[0] = butterfly(x[0], y[0], F, Q, mredParams) - x[1], y[1] = butterfly(x[1], y[1], F, Q, mredParams) - x[2], y[2] = butterfly(x[2], y[2], F, Q, mredParams) - x[3], y[3] = butterfly(x[3], y[3], F, Q, mredParams) - x[4], y[4] = butterfly(x[4], y[4], F, Q, mredParams) - x[5], 
y[5] = butterfly(x[5], y[5], F, Q, mredParams) - x[6], y[6] = butterfly(x[6], y[6], F, Q, mredParams) - x[7], y[7] = butterfly(x[7], y[7], F, Q, mredParams) - } - } - - } else if t == 4 { - - for i := uint64(0); i < m; i = i + 2 { - - j1 = (i * t) << 1 - - psi := (*[2]uint64)(unsafe.Pointer(&nttPsi[m+i])) - x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) - - x[0], x[4] = butterfly(x[0], x[4], psi[0], Q, mredParams) - x[1], x[5] = butterfly(x[1], x[5], psi[0], Q, mredParams) - x[2], x[6] = butterfly(x[2], x[6], psi[0], Q, mredParams) - x[3], x[7] = butterfly(x[3], x[7], psi[0], Q, mredParams) - x[8], x[12] = butterfly(x[8], x[12], psi[1], Q, mredParams) - x[9], x[13] = butterfly(x[9], x[13], psi[1], Q, mredParams) - x[10], x[14] = butterfly(x[10], x[14], psi[1], Q, mredParams) - x[11], x[15] = butterfly(x[11], x[15], psi[1], Q, mredParams) - - } - - } else if t == 2 { - - for i := uint64(0); i < m; i = i + 4 { - - j1 = (i * t) << 1 - - psi := (*[4]uint64)(unsafe.Pointer(&nttPsi[m+i])) - x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) - - x[0], x[2] = butterfly(x[0], x[2], psi[0], Q, mredParams) - x[1], x[3] = butterfly(x[1], x[3], psi[0], Q, mredParams) - x[4], x[6] = butterfly(x[4], x[6], psi[1], Q, mredParams) - x[5], x[7] = butterfly(x[5], x[7], psi[1], Q, mredParams) - x[8], x[10] = butterfly(x[8], x[10], psi[2], Q, mredParams) - x[9], x[11] = butterfly(x[9], x[11], psi[2], Q, mredParams) - x[12], x[14] = butterfly(x[12], x[14], psi[3], Q, mredParams) - x[13], x[15] = butterfly(x[13], x[15], psi[3], Q, mredParams) - } - - } else { - - for i := uint64(0); i < m; i = i + 8 { - - psi := (*[8]uint64)(unsafe.Pointer(&nttPsi[m+i])) - x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[2*i])) - - x[0], x[1] = butterfly(x[0], x[1], psi[0], Q, mredParams) - x[2], x[3] = butterfly(x[2], x[3], psi[1], Q, mredParams) - x[4], x[5] = butterfly(x[4], x[5], psi[2], Q, mredParams) - x[6], x[7] = butterfly(x[6], x[7], psi[3], Q, mredParams) - x[8], x[9] = butterfly(x[8], x[9], 
psi[4], Q, mredParams) - x[10], x[11] = butterfly(x[10], x[11], psi[5], Q, mredParams) - x[12], x[13] = butterfly(x[12], x[13], psi[6], Q, mredParams) - x[14], x[15] = butterfly(x[14], x[15], psi[7], Q, mredParams) - } - } - } + NTTLazy(coeffsIn, coeffsOut, N, nttPsi, Q, mredParams, bredParams) // Finish with an exact reduction for i := uint64(0); i < N; i = i + 8 { @@ -184,11 +94,319 @@ func NTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, mredParams x[6] = BRedAdd(x[6], Q, bredParams) x[7] = BRedAdd(x[7], Q, bredParams) } +} +// NTTLazy computes the NTT on the input coefficients using the input parameters with output values in the range [0, 2q-1]. +func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv uint64, bredParams []uint64) { + var j1, j2, t uint64 + var F uint64 + + QBitLen := uint64(bits.Len64(Q)) + twoQ := 2 * Q + + // Copy the result of the first round of butterflies on p2 with approximate reduction + t = N >> 1 + F = nttPsi[1] + var V uint64 + + for j := uint64(0); j <= t-1; j = j + 8 { + + xin := (*[8]uint64)(unsafe.Pointer(&coeffsIn[j])) + yin := (*[8]uint64)(unsafe.Pointer(&coeffsIn[j+t])) + + xout := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j])) + yout := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j+t])) + + V = MRedConstant(yin[0], F, Q, QInv) + xout[0], yout[0] = xin[0]+V, xin[0]+twoQ-V + + V = MRedConstant(yin[1], F, Q, QInv) + xout[1], yout[1] = xin[1]+V, xin[1]+twoQ-V + + V = MRedConstant(yin[2], F, Q, QInv) + xout[2], yout[2] = xin[2]+V, xin[2]+twoQ-V + + V = MRedConstant(yin[3], F, Q, QInv) + xout[3], yout[3] = xin[3]+V, xin[3]+twoQ-V + + V = MRedConstant(yin[4], F, Q, QInv) + xout[4], yout[4] = xin[4]+V, xin[4]+twoQ-V + + V = MRedConstant(yin[5], F, Q, QInv) + xout[5], yout[5] = xin[5]+V, xin[5]+twoQ-V + + V = MRedConstant(yin[6], F, Q, QInv) + xout[6], yout[6] = xin[6]+V, xin[6]+twoQ-V + + V = MRedConstant(yin[7], F, Q, QInv) + xout[7], yout[7] = xin[7]+V, xin[7]+twoQ-V + } + + // Continue the rest of the 
second to the n-1 butterflies on p2 with approximate reduction + var reduce bool + + for m := uint64(2); m < N; m <<= 1 { + + reduce = (bits.Len64(m)&1 == 1) + + t >>= 1 + + if t >= 8 { + + for i := uint64(0); i < m; i++ { + + j1 = (i * t) << 1 + + j2 = j1 + t - 1 + + F = nttPsi[m+i] + + if reduce { + + for j := j1; j <= j2; j = j + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j])) + y := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j+t])) + + // input := (x[i] < 8q, y[i] < 8q) + // output := (x[i] < 2q, y[i] < 2q) + x[0], y[0] = butterfly(x[0], y[0], F, twoQ, QBitLen, Q, QInv) + x[1], y[1] = butterfly(x[1], y[1], F, twoQ, QBitLen, Q, QInv) + x[2], y[2] = butterfly(x[2], y[2], F, twoQ, QBitLen, Q, QInv) + x[3], y[3] = butterfly(x[3], y[3], F, twoQ, QBitLen, Q, QInv) + x[4], y[4] = butterfly(x[4], y[4], F, twoQ, QBitLen, Q, QInv) + x[5], y[5] = butterfly(x[5], y[5], F, twoQ, QBitLen, Q, QInv) + x[6], y[6] = butterfly(x[6], y[6], F, twoQ, QBitLen, Q, QInv) + x[7], y[7] = butterfly(x[7], y[7], F, twoQ, QBitLen, Q, QInv) + } + + } else { + + for j := j1; j <= j2; j = j + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j])) + y := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j+t])) + + // input := (x[i] < 2q, y[i] < 2q) + // output := (x[i] < 4q, y[i] < 4q) + V = MRedConstant(y[0], F, Q, QInv) + x[0], y[0] = x[0]+V, x[0]+twoQ-V + + V = MRedConstant(y[1], F, Q, QInv) + x[1], y[1] = x[1]+V, x[1]+twoQ-V + + V = MRedConstant(y[2], F, Q, QInv) + x[2], y[2] = x[2]+V, x[2]+twoQ-V + + V = MRedConstant(y[3], F, Q, QInv) + x[3], y[3] = x[3]+V, x[3]+twoQ-V + + V = MRedConstant(y[4], F, Q, QInv) + x[4], y[4] = x[4]+V, x[4]+twoQ-V + + V = MRedConstant(y[5], F, Q, QInv) + x[5], y[5] = x[5]+V, x[5]+twoQ-V + + V = MRedConstant(y[6], F, Q, QInv) + x[6], y[6] = x[6]+V, x[6]+twoQ-V + + V = MRedConstant(y[7], F, Q, QInv) + x[7], y[7] = x[7]+V, x[7]+twoQ-V + } + } + } + + } else if t == 4 { + + if reduce { + + for i := uint64(0); i < m; i = i + 2 { + + j1 = (i * t) << 1 + + psi := 
(*[2]uint64)(unsafe.Pointer(&nttPsi[m+i])) + x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) + + // input := (x[i] < 4q, y[i] < 4q) + // output := (x[i] < 2q, y[i] < 2q) + x[0], x[4] = butterfly(x[0], x[4], psi[0], twoQ, QBitLen, Q, QInv) + x[1], x[5] = butterfly(x[1], x[5], psi[0], twoQ, QBitLen, Q, QInv) + x[2], x[6] = butterfly(x[2], x[6], psi[0], twoQ, QBitLen, Q, QInv) + x[3], x[7] = butterfly(x[3], x[7], psi[0], twoQ, QBitLen, Q, QInv) + x[8], x[12] = butterfly(x[8], x[12], psi[1], twoQ, QBitLen, Q, QInv) + x[9], x[13] = butterfly(x[9], x[13], psi[1], twoQ, QBitLen, Q, QInv) + x[10], x[14] = butterfly(x[10], x[14], psi[1], twoQ, QBitLen, Q, QInv) + x[11], x[15] = butterfly(x[11], x[15], psi[1], twoQ, QBitLen, Q, QInv) + + } + } else { + + for i := uint64(0); i < m; i = i + 2 { + + j1 = (i * t) << 1 + + psi := (*[2]uint64)(unsafe.Pointer(&nttPsi[m+i])) + x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) + + // input := (x[i] < 2q, y[i] < 2q) + // output := (x[i] < 8q, y[i] < 8q) + V = MRedConstant(x[4], psi[0], Q, QInv) + x[0], x[4] = x[0]+V, x[0]+twoQ-V + + V = MRedConstant(x[5], psi[0], Q, QInv) + x[1], x[5] = x[1]+V, x[1]+twoQ-V + + V = MRedConstant(x[6], psi[0], Q, QInv) + x[2], x[6] = x[2]+V, x[2]+twoQ-V + + V = MRedConstant(x[7], psi[0], Q, QInv) + x[3], x[7] = x[3]+V, x[3]+twoQ-V + + V = MRedConstant(x[12], psi[1], Q, QInv) + x[8], x[12] = x[8]+V, x[8]+twoQ-V + + V = MRedConstant(x[13], psi[1], Q, QInv) + x[9], x[13] = x[9]+V, x[9]+twoQ-V + + V = MRedConstant(x[14], psi[1], Q, QInv) + x[10], x[14] = x[10]+V, x[10]+twoQ-V + + V = MRedConstant(x[15], psi[1], Q, QInv) + x[11], x[15] = x[11]+V, x[11]+twoQ-V + + } + + } + + } else if t == 2 { + + if reduce { + + for i := uint64(0); i < m; i = i + 4 { + + j1 = (i * t) << 1 + + psi := (*[4]uint64)(unsafe.Pointer(&nttPsi[m+i])) + x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) + + // input := (x[i] < 8q, y[i] < 8q) + // output := (x[i] < 2q, y[i] < 2q) + x[0], x[2] = butterfly(x[0], x[2], psi[0], twoQ, 
QBitLen, Q, QInv) + x[1], x[3] = butterfly(x[1], x[3], psi[0], twoQ, QBitLen, Q, QInv) + x[4], x[6] = butterfly(x[4], x[6], psi[1], twoQ, QBitLen, Q, QInv) + x[5], x[7] = butterfly(x[5], x[7], psi[1], twoQ, QBitLen, Q, QInv) + x[8], x[10] = butterfly(x[8], x[10], psi[2], twoQ, QBitLen, Q, QInv) + x[9], x[11] = butterfly(x[9], x[11], psi[2], twoQ, QBitLen, Q, QInv) + x[12], x[14] = butterfly(x[12], x[14], psi[3], twoQ, QBitLen, Q, QInv) + x[13], x[15] = butterfly(x[13], x[15], psi[3], twoQ, QBitLen, Q, QInv) + } + } else { + + for i := uint64(0); i < m; i = i + 4 { + + j1 = (i * t) << 1 + + psi := (*[4]uint64)(unsafe.Pointer(&nttPsi[m+i])) + x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) + + // input := (x[i] < 2q, y[i] < 2q) + // output := (x[i] < 8q, y[i] < 8q) + V = MRedConstant(x[2], psi[0], Q, QInv) + x[0], x[2] = x[0]+V, x[0]+twoQ-V + + V = MRedConstant(x[3], psi[0], Q, QInv) + x[1], x[3] = x[1]+V, x[1]+twoQ-V + + V = MRedConstant(x[6], psi[1], Q, QInv) + x[4], x[6] = x[4]+V, x[4]+twoQ-V + + V = MRedConstant(x[7], psi[1], Q, QInv) + x[5], x[7] = x[5]+V, x[5]+twoQ-V + + V = MRedConstant(x[10], psi[2], Q, QInv) + x[8], x[10] = x[8]+V, x[8]+twoQ-V + + V = MRedConstant(x[11], psi[2], Q, QInv) + x[9], x[11] = x[9]+V, x[9]+twoQ-V + + V = MRedConstant(x[14], psi[3], Q, QInv) + x[12], x[14] = x[12]+V, x[12]+twoQ-V + + V = MRedConstant(x[15], psi[3], Q, QInv) + x[13], x[15] = x[13]+V, x[13]+twoQ-V + } + } + + } else { + + if reduce { + + for i := uint64(0); i < m; i = i + 8 { + + psi := (*[8]uint64)(unsafe.Pointer(&nttPsi[m+i])) + x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[2*i])) + + // input := (x[i] < 8q, y[i] < 8q) + // output := (x[i] < 2q, y[i] < 2q) + x[0], x[1] = butterfly(x[0], x[1], psi[0], twoQ, QBitLen, Q, QInv) + x[2], x[3] = butterfly(x[2], x[3], psi[1], twoQ, QBitLen, Q, QInv) + x[4], x[5] = butterfly(x[4], x[5], psi[2], twoQ, QBitLen, Q, QInv) + x[6], x[7] = butterfly(x[6], x[7], psi[3], twoQ, QBitLen, Q, QInv) + x[8], x[9] = butterfly(x[8], x[9], 
psi[4], twoQ, QBitLen, Q, QInv) + x[10], x[11] = butterfly(x[10], x[11], psi[5], twoQ, QBitLen, Q, QInv) + x[12], x[13] = butterfly(x[12], x[13], psi[6], twoQ, QBitLen, Q, QInv) + x[14], x[15] = butterfly(x[14], x[15], psi[7], twoQ, QBitLen, Q, QInv) + } + } else { + + for i := uint64(0); i < m; i = i + 8 { + + psi := (*[8]uint64)(unsafe.Pointer(&nttPsi[m+i])) + x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[2*i])) + + // input := (x[i] < 2q, y[i] < 2q) + // output := (x[i] < 4q, y[i] < 4q) + V = MRedConstant(x[1], psi[0], Q, QInv) + x[0], x[1] = x[0]+V, x[0]+twoQ-V + + V = MRedConstant(x[3], psi[1], Q, QInv) + x[2], x[3] = x[2]+V, x[2]+twoQ-V + + V = MRedConstant(x[5], psi[2], Q, QInv) + x[4], x[5] = x[4]+V, x[4]+twoQ-V + + V = MRedConstant(x[7], psi[3], Q, QInv) + x[6], x[7] = x[6]+V, x[6]+twoQ-V + + V = MRedConstant(x[9], psi[4], Q, QInv) + x[8], x[9] = x[8]+V, x[8]+twoQ-V + + V = MRedConstant(x[11], psi[5], Q, QInv) + x[10], x[11] = x[10]+V, x[10]+twoQ-V + + V = MRedConstant(x[13], psi[6], Q, QInv) + x[12], x[13] = x[12]+V, x[12]+twoQ-V + + V = MRedConstant(x[15], psi[7], Q, QInv) + x[14], x[15] = x[14]+V, x[14]+twoQ-V + } + } + } + } +} + +// invbutterfly computes X, Y = U + V, (U - V) * Psi mod Q. +func invbutterfly(U, V, Psi, twoQ, Q, Qinv uint64) (X, Y uint64) { + X = U + V + if X >= twoQ { + X -= twoQ + } + Y = MRedConstant(U+twoQ-V, Psi, Q, Qinv) // At the moment it is not possible to use MRedConstant if Q > 61 bits + return } // InvNTT computes the InvNTT transformation on the input coefficients using the input parameters. 
-func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, Q, mredParams uint64) { +func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, Q, QInv uint64) { var j1, j2, h, t uint64 var F uint64 @@ -196,6 +414,7 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, // Copy the result of the first round of butterflies on p2 with approximate reduction t = 1 h = N >> 1 + twoQ := Q << 1 for i := uint64(0); i < h; i = i + 8 { @@ -203,14 +422,14 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, xin := (*[16]uint64)(unsafe.Pointer(&coeffsIn[2*i])) xout := (*[16]uint64)(unsafe.Pointer(&coeffsOut[2*i])) - xout[0], xout[1] = invbutterfly(xin[0], xin[1], psi[0], Q, mredParams) - xout[2], xout[3] = invbutterfly(xin[2], xin[3], psi[1], Q, mredParams) - xout[4], xout[5] = invbutterfly(xin[4], xin[5], psi[2], Q, mredParams) - xout[6], xout[7] = invbutterfly(xin[6], xin[7], psi[3], Q, mredParams) - xout[8], xout[9] = invbutterfly(xin[8], xin[9], psi[4], Q, mredParams) - xout[10], xout[11] = invbutterfly(xin[10], xin[11], psi[5], Q, mredParams) - xout[12], xout[13] = invbutterfly(xin[12], xin[13], psi[6], Q, mredParams) - xout[14], xout[15] = invbutterfly(xin[14], xin[15], psi[7], Q, mredParams) + xout[0], xout[1] = invbutterfly(xin[0], xin[1], psi[0], twoQ, Q, QInv) + xout[2], xout[3] = invbutterfly(xin[2], xin[3], psi[1], twoQ, Q, QInv) + xout[4], xout[5] = invbutterfly(xin[4], xin[5], psi[2], twoQ, Q, QInv) + xout[6], xout[7] = invbutterfly(xin[6], xin[7], psi[3], twoQ, Q, QInv) + xout[8], xout[9] = invbutterfly(xin[8], xin[9], psi[4], twoQ, Q, QInv) + xout[10], xout[11] = invbutterfly(xin[10], xin[11], psi[5], twoQ, Q, QInv) + xout[12], xout[13] = invbutterfly(xin[12], xin[13], psi[6], twoQ, Q, QInv) + xout[14], xout[15] = invbutterfly(xin[14], xin[15], psi[7], twoQ, Q, QInv) } // Continue the rest of the second to the n-1 butterflies on p2 with approximate reduction 
@@ -233,14 +452,14 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j])) y := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j+t])) - x[0], y[0] = invbutterfly(x[0], y[0], F, Q, mredParams) - x[1], y[1] = invbutterfly(x[1], y[1], F, Q, mredParams) - x[2], y[2] = invbutterfly(x[2], y[2], F, Q, mredParams) - x[3], y[3] = invbutterfly(x[3], y[3], F, Q, mredParams) - x[4], y[4] = invbutterfly(x[4], y[4], F, Q, mredParams) - x[5], y[5] = invbutterfly(x[5], y[5], F, Q, mredParams) - x[6], y[6] = invbutterfly(x[6], y[6], F, Q, mredParams) - x[7], y[7] = invbutterfly(x[7], y[7], F, Q, mredParams) + x[0], y[0] = invbutterfly(x[0], y[0], F, twoQ, Q, QInv) + x[1], y[1] = invbutterfly(x[1], y[1], F, twoQ, Q, QInv) + x[2], y[2] = invbutterfly(x[2], y[2], F, twoQ, Q, QInv) + x[3], y[3] = invbutterfly(x[3], y[3], F, twoQ, Q, QInv) + x[4], y[4] = invbutterfly(x[4], y[4], F, twoQ, Q, QInv) + x[5], y[5] = invbutterfly(x[5], y[5], F, twoQ, Q, QInv) + x[6], y[6] = invbutterfly(x[6], y[6], F, twoQ, Q, QInv) + x[7], y[7] = invbutterfly(x[7], y[7], F, twoQ, Q, QInv) } j1 = j1 + (t << 1) @@ -253,14 +472,14 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, psi := (*[2]uint64)(unsafe.Pointer(&nttPsiInv[h+i])) x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) - x[0], x[4] = invbutterfly(x[0], x[4], psi[0], Q, mredParams) - x[1], x[5] = invbutterfly(x[1], x[5], psi[0], Q, mredParams) - x[2], x[6] = invbutterfly(x[2], x[6], psi[0], Q, mredParams) - x[3], x[7] = invbutterfly(x[3], x[7], psi[0], Q, mredParams) - x[8], x[12] = invbutterfly(x[8], x[12], psi[1], Q, mredParams) - x[9], x[13] = invbutterfly(x[9], x[13], psi[1], Q, mredParams) - x[10], x[14] = invbutterfly(x[10], x[14], psi[1], Q, mredParams) - x[11], x[15] = invbutterfly(x[11], x[15], psi[1], Q, mredParams) + x[0], x[4] = invbutterfly(x[0], x[4], psi[0], twoQ, Q, QInv) + x[1], x[5] = invbutterfly(x[1], x[5], psi[0], twoQ, Q, 
QInv) + x[2], x[6] = invbutterfly(x[2], x[6], psi[0], twoQ, Q, QInv) + x[3], x[7] = invbutterfly(x[3], x[7], psi[0], twoQ, Q, QInv) + x[8], x[12] = invbutterfly(x[8], x[12], psi[1], twoQ, Q, QInv) + x[9], x[13] = invbutterfly(x[9], x[13], psi[1], twoQ, Q, QInv) + x[10], x[14] = invbutterfly(x[10], x[14], psi[1], twoQ, Q, QInv) + x[11], x[15] = invbutterfly(x[11], x[15], psi[1], twoQ, Q, QInv) j1 = j1 + (t << 2) } @@ -272,14 +491,14 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, psi := (*[4]uint64)(unsafe.Pointer(&nttPsiInv[h+i])) x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) - x[0], x[2] = invbutterfly(x[0], x[2], psi[0], Q, mredParams) - x[1], x[3] = invbutterfly(x[1], x[3], psi[0], Q, mredParams) - x[4], x[6] = invbutterfly(x[4], x[6], psi[1], Q, mredParams) - x[5], x[7] = invbutterfly(x[5], x[7], psi[1], Q, mredParams) - x[8], x[10] = invbutterfly(x[8], x[10], psi[2], Q, mredParams) - x[9], x[11] = invbutterfly(x[9], x[11], psi[2], Q, mredParams) - x[12], x[14] = invbutterfly(x[12], x[14], psi[3], Q, mredParams) - x[13], x[15] = invbutterfly(x[13], x[15], psi[3], Q, mredParams) + x[0], x[2] = invbutterfly(x[0], x[2], psi[0], twoQ, Q, QInv) + x[1], x[3] = invbutterfly(x[1], x[3], psi[0], twoQ, Q, QInv) + x[4], x[6] = invbutterfly(x[4], x[6], psi[1], twoQ, Q, QInv) + x[5], x[7] = invbutterfly(x[5], x[7], psi[1], twoQ, Q, QInv) + x[8], x[10] = invbutterfly(x[8], x[10], psi[2], twoQ, Q, QInv) + x[9], x[11] = invbutterfly(x[9], x[11], psi[2], twoQ, Q, QInv) + x[12], x[14] = invbutterfly(x[12], x[14], psi[3], twoQ, Q, QInv) + x[13], x[15] = invbutterfly(x[13], x[15], psi[3], twoQ, Q, QInv) j1 = j1 + (t << 3) } @@ -293,14 +512,133 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[i])) - x[0] = MRed(x[0], nttNInv, Q, mredParams) - x[1] = MRed(x[1], nttNInv, Q, mredParams) - x[2] = MRed(x[2], nttNInv, Q, mredParams) - x[3] = MRed(x[3], nttNInv, Q, 
mredParams) - x[4] = MRed(x[4], nttNInv, Q, mredParams) - x[5] = MRed(x[5], nttNInv, Q, mredParams) - x[6] = MRed(x[6], nttNInv, Q, mredParams) - x[7] = MRed(x[7], nttNInv, Q, mredParams) + x[0] = MRed(x[0], nttNInv, Q, QInv) + x[1] = MRed(x[1], nttNInv, Q, QInv) + x[2] = MRed(x[2], nttNInv, Q, QInv) + x[3] = MRed(x[3], nttNInv, Q, QInv) + x[4] = MRed(x[4], nttNInv, Q, QInv) + x[5] = MRed(x[5], nttNInv, Q, QInv) + x[6] = MRed(x[6], nttNInv, Q, QInv) + x[7] = MRed(x[7], nttNInv, Q, QInv) + } +} + +// InvNTTLazy computes the InvNTT transformation on the input coefficients using the input parameters with output values in the range [0, 2q-1]. +func InvNTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, Q, mredParams uint64) { + + var j1, j2, h, t uint64 + var F uint64 + + // Copy the result of the first round of butterflies on p2 with approximate reduction + t = 1 + h = N >> 1 + + twoQ := Q << 1 + + for i := uint64(0); i < h; i = i + 8 { + + psi := (*[8]uint64)(unsafe.Pointer(&nttPsiInv[h+i])) + xin := (*[16]uint64)(unsafe.Pointer(&coeffsIn[2*i])) + xout := (*[16]uint64)(unsafe.Pointer(&coeffsOut[2*i])) + + xout[0], xout[1] = invbutterfly(xin[0], xin[1], psi[0], twoQ, Q, mredParams) + xout[2], xout[3] = invbutterfly(xin[2], xin[3], psi[1], twoQ, Q, mredParams) + xout[4], xout[5] = invbutterfly(xin[4], xin[5], psi[2], twoQ, Q, mredParams) + xout[6], xout[7] = invbutterfly(xin[6], xin[7], psi[3], twoQ, Q, mredParams) + xout[8], xout[9] = invbutterfly(xin[8], xin[9], psi[4], twoQ, Q, mredParams) + xout[10], xout[11] = invbutterfly(xin[10], xin[11], psi[5], twoQ, Q, mredParams) + xout[12], xout[13] = invbutterfly(xin[12], xin[13], psi[6], twoQ, Q, mredParams) + xout[14], xout[15] = invbutterfly(xin[14], xin[15], psi[7], twoQ, Q, mredParams) + } + + // Continue the rest of the second to the n-1 butterflies on p2 with approximate reduction + t <<= 1 + for m := N >> 1; m > 1; m >>= 1 { + + j1 = 0 + h = m >> 1 + + if t >= 8 { + + for i := uint64(0); i 
< h; i++ { + + j2 = j1 + t - 1 + + F = nttPsiInv[h+i] + + for j := j1; j <= j2; j = j + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j])) + y := (*[8]uint64)(unsafe.Pointer(&coeffsOut[j+t])) + + x[0], y[0] = invbutterfly(x[0], y[0], F, twoQ, Q, mredParams) + x[1], y[1] = invbutterfly(x[1], y[1], F, twoQ, Q, mredParams) + x[2], y[2] = invbutterfly(x[2], y[2], F, twoQ, Q, mredParams) + x[3], y[3] = invbutterfly(x[3], y[3], F, twoQ, Q, mredParams) + x[4], y[4] = invbutterfly(x[4], y[4], F, twoQ, Q, mredParams) + x[5], y[5] = invbutterfly(x[5], y[5], F, twoQ, Q, mredParams) + x[6], y[6] = invbutterfly(x[6], y[6], F, twoQ, Q, mredParams) + x[7], y[7] = invbutterfly(x[7], y[7], F, twoQ, Q, mredParams) + } + + j1 = j1 + (t << 1) + } + + } else if t == 4 { + + for i := uint64(0); i < h; i = i + 2 { + + psi := (*[2]uint64)(unsafe.Pointer(&nttPsiInv[h+i])) + x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) + + x[0], x[4] = invbutterfly(x[0], x[4], psi[0], twoQ, Q, mredParams) + x[1], x[5] = invbutterfly(x[1], x[5], psi[0], twoQ, Q, mredParams) + x[2], x[6] = invbutterfly(x[2], x[6], psi[0], twoQ, Q, mredParams) + x[3], x[7] = invbutterfly(x[3], x[7], psi[0], twoQ, Q, mredParams) + x[8], x[12] = invbutterfly(x[8], x[12], psi[1], twoQ, Q, mredParams) + x[9], x[13] = invbutterfly(x[9], x[13], psi[1], twoQ, Q, mredParams) + x[10], x[14] = invbutterfly(x[10], x[14], psi[1], twoQ, Q, mredParams) + x[11], x[15] = invbutterfly(x[11], x[15], psi[1], twoQ, Q, mredParams) + + j1 = j1 + (t << 2) + } + + } else { + + for i := uint64(0); i < h; i = i + 4 { + + psi := (*[4]uint64)(unsafe.Pointer(&nttPsiInv[h+i])) + x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1])) + + x[0], x[2] = invbutterfly(x[0], x[2], psi[0], twoQ, Q, mredParams) + x[1], x[3] = invbutterfly(x[1], x[3], psi[0], twoQ, Q, mredParams) + x[4], x[6] = invbutterfly(x[4], x[6], psi[1], twoQ, Q, mredParams) + x[5], x[7] = invbutterfly(x[5], x[7], psi[1], twoQ, Q, mredParams) + x[8], x[10] = invbutterfly(x[8], x[10], 
psi[2], twoQ, Q, mredParams) + x[9], x[11] = invbutterfly(x[9], x[11], psi[2], twoQ, Q, mredParams) + x[12], x[14] = invbutterfly(x[12], x[14], psi[3], twoQ, Q, mredParams) + x[13], x[15] = invbutterfly(x[13], x[15], psi[3], twoQ, Q, mredParams) + + j1 = j1 + (t << 3) + } + } + + t <<= 1 + } + + // Finish with the scaling by nttNInv (MRedConstant: outputs stay in [0, 2q-1], no exact reduction) + for i := uint64(0); i < N; i = i + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[i])) + + x[0] = MRedConstant(x[0], nttNInv, Q, mredParams) + x[1] = MRedConstant(x[1], nttNInv, Q, mredParams) + x[2] = MRedConstant(x[2], nttNInv, Q, mredParams) + x[3] = MRedConstant(x[3], nttNInv, Q, mredParams) + x[4] = MRedConstant(x[4], nttNInv, Q, mredParams) + x[5] = MRedConstant(x[5], nttNInv, Q, mredParams) + x[6] = MRedConstant(x[6], nttNInv, Q, mredParams) + x[7] = MRedConstant(x[7], nttNInv, Q, mredParams) } } diff --git a/ring/ring_operations.go b/ring/ring_operations.go index b1b3e784..e9e300fc 100644 --- a/ring/ring_operations.go +++ b/ring/ring_operations.go @@ -258,6 +258,29 @@ func (r *Ring) Reduce(p1, p2 *Poly) { } } +// ReduceConstant applies a modular reduction on the coefficients of p1 and writes the result on p2.
+// Return values in [0, 2q-1] +func (r *Ring) ReduceConstant(p1, p2 *Poly) { + for i, qi := range r.Modulus { + p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i] + bredParams := r.BredParams[i] + for j := uint64(0); j < r.N; j = j + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) + z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) + + z[0] = BRedAddConstant(x[0], qi, bredParams) + z[1] = BRedAddConstant(x[1], qi, bredParams) + z[2] = BRedAddConstant(x[2], qi, bredParams) + z[3] = BRedAddConstant(x[3], qi, bredParams) + z[4] = BRedAddConstant(x[4], qi, bredParams) + z[5] = BRedAddConstant(x[5], qi, bredParams) + z[6] = BRedAddConstant(x[6], qi, bredParams) + z[7] = BRedAddConstant(x[7], qi, bredParams) + } + } +} + // ReduceLvl applies a modular reduction on the coefficients of p1 // for the moduli from q_0 up to q_level and writes the result on p2. func (r *Ring) ReduceLvl(level uint64, p1, p2 *Poly) { @@ -282,6 +305,31 @@ func (r *Ring) ReduceLvl(level uint64, p1, p2 *Poly) { } } +// ReduceConstantLvl applies a modular reduction on the coefficients of p1 +// for the moduli from q_0 up to q_level and writes the result on p2. +// Return values in [0, 2q-1] +func (r *Ring) ReduceConstantLvl(level uint64, p1, p2 *Poly) { + for i := uint64(0); i < level+1; i++ { + qi := r.Modulus[i] + p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i] + bredParams := r.BredParams[i] + for j := uint64(0); j < r.N; j = j + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) + z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) + + z[0] = BRedAddConstant(x[0], qi, bredParams) + z[1] = BRedAddConstant(x[1], qi, bredParams) + z[2] = BRedAddConstant(x[2], qi, bredParams) + z[3] = BRedAddConstant(x[3], qi, bredParams) + z[4] = BRedAddConstant(x[4], qi, bredParams) + z[5] = BRedAddConstant(x[5], qi, bredParams) + z[6] = BRedAddConstant(x[6], qi, bredParams) + z[7] = BRedAddConstant(x[7], qi, bredParams) + } + } +} + // Mod applies a modular reduction by m on the coefficients of p1 and writes the result on p2. 
func (r *Ring) Mod(p1 *Poly, m uint64, p2 *Poly) { bredParams := BRedParams(m) @@ -456,6 +504,31 @@ func (r *Ring) MulCoeffsMontgomeryLvl(level uint64, p1, p2, p3 *Poly) { } } +// MulCoeffsMontgomeryConstantLvl multiplies p1 by p2 coefficient-wise with a Montgomery +// modular reduction for the moduli from q_0 up to q_level and returns the result on p3. +func (r *Ring) MulCoeffsMontgomeryConstantLvl(level uint64, p1, p2, p3 *Poly) { + for i := uint64(0); i < level+1; i++ { + qi := r.Modulus[i] + p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i] + mredParams := r.MredParams[i] + for j := uint64(0); j < r.N; j = j + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) + y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) + z := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) + + z[0] = MRedConstant(x[0], y[0], qi, mredParams) + z[1] = MRedConstant(x[1], y[1], qi, mredParams) + z[2] = MRedConstant(x[2], y[2], qi, mredParams) + z[3] = MRedConstant(x[3], y[3], qi, mredParams) + z[4] = MRedConstant(x[4], y[4], qi, mredParams) + z[5] = MRedConstant(x[5], y[5], qi, mredParams) + z[6] = MRedConstant(x[6], y[6], qi, mredParams) + z[7] = MRedConstant(x[7], y[7], qi, mredParams) + } + } +} + // MulCoeffsMontgomeryAndAdd multiplies p1 by p2 coefficient-wise with a // Montgomery modular reduction and adds the result to p3. func (r *Ring) MulCoeffsMontgomeryAndAdd(p1, p2, p3 *Poly) { @@ -529,6 +602,31 @@ func (r *Ring) MulCoeffsMontgomeryAndAddNoMod(p1, p2, p3 *Poly) { } } +// MulCoeffsMontgomeryConstantAndAddNoMod multiplies p1 by p2 coefficient-wise with a +// Montgomery modular reduction and adds the result to p3 without modular reduction. 
+// Return values in [0, 3q-1] +func (r *Ring) MulCoeffsMontgomeryConstantAndAddNoMod(p1, p2, p3 *Poly) { + for i, qi := range r.Modulus { + p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i] + mredParams := r.MredParams[i] + for j := uint64(0); j < r.N; j = j + 8 { + + x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) + y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) + z := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) + + z[0] += MRedConstant(x[0], y[0], qi, mredParams) + z[1] += MRedConstant(x[1], y[1], qi, mredParams) + z[2] += MRedConstant(x[2], y[2], qi, mredParams) + z[3] += MRedConstant(x[3], y[3], qi, mredParams) + z[4] += MRedConstant(x[4], y[4], qi, mredParams) + z[5] += MRedConstant(x[5], y[5], qi, mredParams) + z[6] += MRedConstant(x[6], y[6], qi, mredParams) + z[7] += MRedConstant(x[7], y[7], qi, mredParams) + } + } +} + // MulCoeffsMontgomeryAndAddNoModLvl multiplies p1 by p2 coefficient-wise with a Montgomery modular // reduction for the moduli from q_0 up to q_level and adds the result to p3 without modular reduction. func (r *Ring) MulCoeffsMontgomeryAndAddNoModLvl(level uint64, p1, p2, p3 *Poly) { @@ -556,6 +654,7 @@ func (r *Ring) MulCoeffsMontgomeryAndAddNoModLvl(level uint64, p1, p2, p3 *Poly) // MulCoeffsMontgomeryConstantAndAddNoModLvl multiplies p1 by p2 coefficient-wise with a constant-time Montgomery // modular reduction for the moduli from q_0 up to q_level and adds the result to p3 without modular reduction. 
+// Return values in [0, 3q-1] func (r *Ring) MulCoeffsMontgomeryConstantAndAddNoModLvl(level uint64, p1, p2, p3 *Poly) { for i := uint64(0); i < level+1; i++ { qi := r.Modulus[i] diff --git a/ring/ring_scaling.go b/ring/ring_scaling.go index db89beef..97f83101 100644 --- a/ring/ring_scaling.go +++ b/ring/ring_scaling.go @@ -70,8 +70,8 @@ func (rnss *RNSScaler) DivByQOverTRounded(p1Q, p2T *Poly) { p2tmp := p2T.Coeffs[0] p3tmp := rnss.polypoolT.Coeffs[0] mredParams := rnss.mredParamsT - qInv := rnss.qInv - qHalfModT := rnss.qHalfModT + qInv := T - rnss.qInv + qHalfModT := T - rnss.qHalfModT // Multiply P_{Q} by t and extend the basis from P_{Q} to t*(P_{Q}||P_{t}) // Since the coefficients of P_{t} are multiplied by t, they are all zero, @@ -90,14 +90,14 @@ func (rnss *RNSScaler) DivByQOverTRounded(p1Q, p2T *Poly) { x := (*[8]uint64)(unsafe.Pointer(&p3tmp[j])) z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j])) - z[0] = MRed(qHalfModT+T-x[0], qInv, T, mredParams) - z[1] = MRed(qHalfModT+T-x[1], qInv, T, mredParams) - z[2] = MRed(qHalfModT+T-x[2], qInv, T, mredParams) - z[3] = MRed(qHalfModT+T-x[3], qInv, T, mredParams) - z[4] = MRed(qHalfModT+T-x[4], qInv, T, mredParams) - z[5] = MRed(qHalfModT+T-x[5], qInv, T, mredParams) - z[6] = MRed(qHalfModT+T-x[6], qInv, T, mredParams) - z[7] = MRed(qHalfModT+T-x[7], qInv, T, mredParams) + z[0] = MRed(qHalfModT+x[0], qInv, T, mredParams) + z[1] = MRed(qHalfModT+x[1], qInv, T, mredParams) + z[2] = MRed(qHalfModT+x[2], qInv, T, mredParams) + z[3] = MRed(qHalfModT+x[3], qInv, T, mredParams) + z[4] = MRed(qHalfModT+x[4], qInv, T, mredParams) + z[5] = MRed(qHalfModT+x[5], qInv, T, mredParams) + z[6] = MRed(qHalfModT+x[6], qInv, T, mredParams) + z[7] = MRed(qHalfModT+x[7], qInv, T, mredParams) } } @@ -298,17 +298,18 @@ func (r *Ring) DivFloorByLastModulusNTT(p0 *Poly) { pTmp := make([]uint64, r.N) - InvNTT(p0.Coeffs[level], p0.Coeffs[level], r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level]) + 
InvNTTLazy(p0.Coeffs[level], p0.Coeffs[level], r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level]) for i := 0; i < level; i++ { - NTT(p0.Coeffs[level], pTmp, r.N, r.NttPsi[i], r.Modulus[i], r.MredParams[i], r.BredParams[i]) + NTTLazy(p0.Coeffs[level], pTmp, r.N, r.NttPsi[i], r.Modulus[i], r.MredParams[i], r.BredParams[i]) p0tmp := p0.Coeffs[i] qi := r.Modulus[i] + twoqi := qi << 1 mredParams := r.MredParams[i] - rescalParams := r.RescaleParams[level-1][i] + rescalParams := qi - r.RescaleParams[level-1][i] // (x[i] - x[-1]) * InvQ for j := uint64(0); j < r.N; j = j + 8 { @@ -316,14 +317,14 @@ func (r *Ring) DivFloorByLastModulusNTT(p0 *Poly) { x := (*[8]uint64)(unsafe.Pointer(&pTmp[j])) z := (*[8]uint64)(unsafe.Pointer(&p0tmp[j])) - z[0] = MRed(z[0]+(qi-x[0]), rescalParams, qi, mredParams) - z[1] = MRed(z[1]+(qi-x[1]), rescalParams, qi, mredParams) - z[2] = MRed(z[2]+(qi-x[2]), rescalParams, qi, mredParams) - z[3] = MRed(z[3]+(qi-x[3]), rescalParams, qi, mredParams) - z[4] = MRed(z[4]+(qi-x[4]), rescalParams, qi, mredParams) - z[5] = MRed(z[5]+(qi-x[5]), rescalParams, qi, mredParams) - z[6] = MRed(z[6]+(qi-x[6]), rescalParams, qi, mredParams) - z[7] = MRed(z[7]+(qi-x[7]), rescalParams, qi, mredParams) + z[0] = MRed(twoqi-z[0]+x[0], rescalParams, qi, mredParams) + z[1] = MRed(twoqi-z[1]+x[1], rescalParams, qi, mredParams) + z[2] = MRed(twoqi-z[2]+x[2], rescalParams, qi, mredParams) + z[3] = MRed(twoqi-z[3]+x[3], rescalParams, qi, mredParams) + z[4] = MRed(twoqi-z[4]+x[4], rescalParams, qi, mredParams) + z[5] = MRed(twoqi-z[5]+x[5], rescalParams, qi, mredParams) + z[6] = MRed(twoqi-z[6]+x[6], rescalParams, qi, mredParams) + z[7] = MRed(twoqi-z[7]+x[7], rescalParams, qi, mredParams) } } @@ -340,23 +341,24 @@ func (r *Ring) DivFloorByLastModulus(p0 *Poly) { p0tmp := p0.Coeffs[level] p1tmp := p0.Coeffs[i] qi := r.Modulus[i] + twoqi := qi << 1 bredParams := r.BredParams[i] mredParams := r.MredParams[i] - rescaleParams := 
r.RescaleParams[level-1][i] + rescaleParams := qi - r.RescaleParams[level-1][i] // (x[i] - x[-1]) * InvQ for j := uint64(0); j < r.N; j = j + 8 { x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j])) z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) - z[0] = MRed(z[0]+(qi-BRedAdd(x[0], qi, bredParams)), rescaleParams, qi, mredParams) - z[1] = MRed(z[1]+(qi-BRedAdd(x[1], qi, bredParams)), rescaleParams, qi, mredParams) - z[2] = MRed(z[2]+(qi-BRedAdd(x[2], qi, bredParams)), rescaleParams, qi, mredParams) - z[3] = MRed(z[3]+(qi-BRedAdd(x[3], qi, bredParams)), rescaleParams, qi, mredParams) - z[4] = MRed(z[4]+(qi-BRedAdd(x[4], qi, bredParams)), rescaleParams, qi, mredParams) - z[5] = MRed(z[5]+(qi-BRedAdd(x[5], qi, bredParams)), rescaleParams, qi, mredParams) - z[6] = MRed(z[6]+(qi-BRedAdd(x[6], qi, bredParams)), rescaleParams, qi, mredParams) - z[7] = MRed(z[7]+(qi-BRedAdd(x[7], qi, bredParams)), rescaleParams, qi, mredParams) + z[0] = MRed(twoqi-z[0]+BRedAdd(x[0], qi, bredParams), rescaleParams, qi, mredParams) + z[1] = MRed(twoqi-z[1]+BRedAdd(x[1], qi, bredParams), rescaleParams, qi, mredParams) + z[2] = MRed(twoqi-z[2]+BRedAdd(x[2], qi, bredParams), rescaleParams, qi, mredParams) + z[3] = MRed(twoqi-z[3]+BRedAdd(x[3], qi, bredParams), rescaleParams, qi, mredParams) + z[4] = MRed(twoqi-z[4]+BRedAdd(x[4], qi, bredParams), rescaleParams, qi, mredParams) + z[5] = MRed(twoqi-z[5]+BRedAdd(x[5], qi, bredParams), rescaleParams, qi, mredParams) + z[6] = MRed(twoqi-z[6]+BRedAdd(x[6], qi, bredParams), rescaleParams, qi, mredParams) + z[7] = MRed(twoqi-z[7]+BRedAdd(x[7], qi, bredParams), rescaleParams, qi, mredParams) } } @@ -386,37 +388,22 @@ func (r *Ring) DivRoundByLastModulusNTT(p0 *Poly) { pTmp := make([]uint64, r.N) - InvNTT(p0.Coeffs[level], p0.Coeffs[level], r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level]) + InvNTTLazy(p0.Coeffs[level], p0.Coeffs[level], r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level]) // Center by 
(p-1)/2 pHalf = (r.Modulus[level] - 1) >> 1 p0tmp := p0.Coeffs[level] - pj := r.Modulus[level] - - for i := uint64(0); i < r.N; i = i + 8 { - - z := (*[8]uint64)(unsafe.Pointer(&p0tmp[i])) - - z[0] = CRed(z[0]+pHalf, pj) - z[1] = CRed(z[1]+pHalf, pj) - z[2] = CRed(z[2]+pHalf, pj) - z[3] = CRed(z[3]+pHalf, pj) - z[4] = CRed(z[4]+pHalf, pj) - z[5] = CRed(z[5]+pHalf, pj) - z[6] = CRed(z[6]+pHalf, pj) - z[7] = CRed(z[7]+pHalf, pj) - } for i := 0; i < level; i++ { p1tmp := p0.Coeffs[i] - qi := r.Modulus[i] + twoqi := qi << 1 bredParams := r.BredParams[i] mredParams := r.MredParams[i] - rescaleParams := r.RescaleParams[level-1][i] + rescaleParams := qi - r.RescaleParams[level-1][i] - pHalfNegQi = r.Modulus[i] - BRedAdd(pHalf, qi, bredParams) + pHalfNegQi = pHalf + r.Modulus[i] - BRedAdd(pHalf, qi, bredParams) for j := uint64(0); j < r.N; j = j + 8 { @@ -433,7 +420,7 @@ func (r *Ring) DivRoundByLastModulusNTT(p0 *Poly) { z[7] = x[7] + pHalfNegQi } - NTT(pTmp, pTmp, r.N, r.NttPsi[i], qi, mredParams, bredParams) + NTTLazy(pTmp, pTmp, r.N, r.NttPsi[i], qi, mredParams, bredParams) // (x[i] - x[-1]) * InvQ for j := uint64(0); j < r.N; j = j + 8 { @@ -441,14 +428,14 @@ func (r *Ring) DivRoundByLastModulusNTT(p0 *Poly) { x := (*[8]uint64)(unsafe.Pointer(&pTmp[j])) z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) - z[0] = MRed(z[0]+(qi-x[0]), rescaleParams, qi, mredParams) - z[1] = MRed(z[1]+(qi-x[1]), rescaleParams, qi, mredParams) - z[2] = MRed(z[2]+(qi-x[2]), rescaleParams, qi, mredParams) - z[3] = MRed(z[3]+(qi-x[3]), rescaleParams, qi, mredParams) - z[4] = MRed(z[4]+(qi-x[4]), rescaleParams, qi, mredParams) - z[5] = MRed(z[5]+(qi-x[5]), rescaleParams, qi, mredParams) - z[6] = MRed(z[6]+(qi-x[6]), rescaleParams, qi, mredParams) - z[7] = MRed(z[7]+(qi-x[7]), rescaleParams, qi, mredParams) + z[0] = MRed(twoqi+x[0]-z[0], rescaleParams, qi, mredParams) + z[1] = MRed(twoqi+x[1]-z[1], rescaleParams, qi, mredParams) + z[2] = MRed(twoqi+x[2]-z[2], rescaleParams, qi, mredParams) + z[3] = 
MRed(twoqi+x[3]-z[3], rescaleParams, qi, mredParams) + z[4] = MRed(twoqi+x[4]-z[4], rescaleParams, qi, mredParams) + z[5] = MRed(twoqi+x[5]-z[5], rescaleParams, qi, mredParams) + z[6] = MRed(twoqi+x[6]-z[6], rescaleParams, qi, mredParams) + z[7] = MRed(twoqi+x[7]-z[7], rescaleParams, qi, mredParams) } } @@ -470,16 +457,16 @@ func (r *Ring) DivRoundByLastModulus(p0 *Poly) { for i := uint64(0); i < r.N; i = i + 8 { - z := (*[8]uint64)(unsafe.Pointer(&p0tmp[i])) + x := (*[8]uint64)(unsafe.Pointer(&p0tmp[i])) - z[0] = CRed(z[0]+pHalf, pj) - z[1] = CRed(z[1]+pHalf, pj) - z[2] = CRed(z[2]+pHalf, pj) - z[3] = CRed(z[3]+pHalf, pj) - z[4] = CRed(z[4]+pHalf, pj) - z[5] = CRed(z[5]+pHalf, pj) - z[6] = CRed(z[6]+pHalf, pj) - z[7] = CRed(z[7]+pHalf, pj) + x[0] = CRed(x[0]+pHalf, pj) + x[1] = CRed(x[1]+pHalf, pj) + x[2] = CRed(x[2]+pHalf, pj) + x[3] = CRed(x[3]+pHalf, pj) + x[4] = CRed(x[4]+pHalf, pj) + x[5] = CRed(x[5]+pHalf, pj) + x[6] = CRed(x[6]+pHalf, pj) + x[7] = CRed(x[7]+pHalf, pj) } for i := 0; i < level; i++ { @@ -487,9 +474,10 @@ func (r *Ring) DivRoundByLastModulus(p0 *Poly) { p1tmp := p0.Coeffs[i] qi := r.Modulus[i] + twoqi := qi << 1 bredParams := r.BredParams[i] mredParams := r.MredParams[i] - rescaleParams := r.RescaleParams[level-1][i] + rescaleParams := qi - r.RescaleParams[level-1][i] pHalfNegQi = r.Modulus[i] - BRedAdd(pHalf, qi, bredParams) @@ -499,14 +487,14 @@ func (r *Ring) DivRoundByLastModulus(p0 *Poly) { x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j])) z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j])) - z[0] = MRed(z[0]+(qi-BRedAdd(x[0]+pHalfNegQi, qi, bredParams)), rescaleParams, qi, mredParams) - z[1] = MRed(z[1]+(qi-BRedAdd(x[1]+pHalfNegQi, qi, bredParams)), rescaleParams, qi, mredParams) - z[2] = MRed(z[2]+(qi-BRedAdd(x[2]+pHalfNegQi, qi, bredParams)), rescaleParams, qi, mredParams) - z[3] = MRed(z[3]+(qi-BRedAdd(x[3]+pHalfNegQi, qi, bredParams)), rescaleParams, qi, mredParams) - z[4] = MRed(z[4]+(qi-BRedAdd(x[4]+pHalfNegQi, qi, bredParams)), rescaleParams, 
qi, mredParams) - z[5] = MRed(z[5]+(qi-BRedAdd(x[5]+pHalfNegQi, qi, bredParams)), rescaleParams, qi, mredParams) - z[6] = MRed(z[6]+(qi-BRedAdd(x[6]+pHalfNegQi, qi, bredParams)), rescaleParams, qi, mredParams) - z[7] = MRed(z[7]+(qi-BRedAdd(x[7]+pHalfNegQi, qi, bredParams)), rescaleParams, qi, mredParams) + z[0] = MRed(x[0]+pHalfNegQi+twoqi-z[0], rescaleParams, qi, mredParams) + z[1] = MRed(x[1]+pHalfNegQi+twoqi-z[1], rescaleParams, qi, mredParams) + z[2] = MRed(x[2]+pHalfNegQi+twoqi-z[2], rescaleParams, qi, mredParams) + z[3] = MRed(x[3]+pHalfNegQi+twoqi-z[3], rescaleParams, qi, mredParams) + z[4] = MRed(x[4]+pHalfNegQi+twoqi-z[4], rescaleParams, qi, mredParams) + z[5] = MRed(x[5]+pHalfNegQi+twoqi-z[5], rescaleParams, qi, mredParams) + z[6] = MRed(x[6]+pHalfNegQi+twoqi-z[6], rescaleParams, qi, mredParams) + z[7] = MRed(x[7]+pHalfNegQi+twoqi-z[7], rescaleParams, qi, mredParams) } } diff --git a/ring/ring_test.go b/ring/ring_test.go index ea9ead3c..0a22a14e 100644 --- a/ring/ring_test.go +++ b/ring/ring_test.go @@ -55,7 +55,7 @@ func TestRing(t *testing.T) { var defaultParams []*Parameters if testing.Short() { - defaultParams = DefaultParams[:3] + defaultParams = DefaultParams[0:3] } else { defaultParams = DefaultParams } @@ -473,6 +473,8 @@ func testExtendBasis(testContext *testParams, t *testing.T) { basisextender.ModUpSplitQP(uint64(len(testContext.ringQ.Modulus)-1), Pol, PolTest) + testContext.ringP.Reduce(PolTest, PolTest) + for i := range testContext.ringP.Modulus { require.Equal(t, PolTest.Coeffs[i][:testContext.ringQ.N], PolWant.Coeffs[i][:testContext.ringQ.N]) } diff --git a/ring/ring_test_params.go b/ring/ring_test_params.go index 4636fb7f..87102ea5 100644 --- a/ring/ring_test_params.go +++ b/ring/ring_test_params.go @@ -11,51 +11,27 @@ type Parameters struct { var DefaultParams = []*Parameters{ {12, Qi60[len(Qi60)-2:], Pi60[len(Pi60)-2:]}, {13, Qi60[len(Qi60)-4:], Pi60[len(Pi60)-4:]}, - {14, Qi60[len(Qi60)-8:], Pi60[len(Pi60)-8:]}, - {15, 
Qi60[len(Qi60)-16:], Pi60[len(Pi60)-16:]}, - {16, Qi60[len(Qi60)-32:], Pi60[len(Pi60)-32:]}, + {14, Qi60[len(Qi60)-7:], Pi60[len(Pi60)-7:]}, + {15, Qi60[len(Qi60)-14:], Pi60[len(Pi60)-14:]}, + {16, Qi60[len(Qi60)-29:], Pi60[len(Pi60)-29:]}, } -// Pi60 are the first one hundred (from 0x800000000000000 and upward) 60-bit NTT-friendly primes for N up to 65536 -var Pi60 = []uint64{576460752308273153, 576460752315482113, 576460752319021057, 576460752319414273, 576460752321642497, - 576460752325705729, 576460752328327169, 576460752329113601, 576460752329506817, 576460752329900033, - 576460752331210753, 576460752337502209, 576460752340123649, 576460752342876161, 576460752347201537, - 576460752347332609, 576460752352837633, 576460752354017281, 576460752355065857, 576460752355459073, - 576460752358604801, 576460752364240897, 576460752368435201, 576460752371187713, 576460752373547009, - 576460752374333441, 576460752376692737, 576460752378003457, 576460752378396673, 576460752380755969, - 576460752381411329, 576460752386129921, 576460752395173889, 576460752395960321, 576460752396091393, - 576460752396484609, 576460752399106049, 576460752405135361, 576460752405921793, 576460752409722881, - 576460752410116097, 576460752411033601, 576460752412082177, 576460752416145409, 576460752416931841, - 576460752421257217, 576460752427548673, 576460752429514753, 576460752435281921, 576460752437248001, - 576460752438558721, 576460752441966593, 576460752449044481, 576460752451141633, 576460752451534849, - 576460752462938113, 576460752465952769, 576460752468705281, 576460752469491713, 576460752472375297, - 576460752473948161, 576460752475389953, 576460752480894977, 576460752483254273, 576460752484827137, - 576460752486793217, 576460752486924289, 576460752492691457, 576460752498589697, 576460752498720769, - 576460752499507201, 576460752504225793, 576460752505405441, 576460752507240449, 576460752507764737, - 576460752509206529, 576460752510124033, 576460752510779393, 576460752511959041, 
576460752514449409, - 576460752516284417, 576460752519168001, 576460752520347649, 576460752520609793, 576460752522969089, - 576460752523100161, 576460752524279809, 576460752525852673, 576460752526245889, 576460752526508033, - 576460752532013057, 576460752545120257, 576460752550100993, 576460752551804929, 576460752567402497, - 576460752568975361, 576460752573431809, 576460752580902913, 576460752585490433, 576460752586407937} +// Qi60 are the first [0:32] 61-bit close to 2^{62} NTT-friendly primes for N up to 2^{17} +var Qi60 = []uint64{0x1fffffffffe00001, 0x1fffffffffc80001, 0x1fffffffffb40001, 0x1fffffffff500001, + 0x1fffffffff380001, 0x1fffffffff000001, 0x1ffffffffef00001, 0x1ffffffffee80001, + 0x1ffffffffeb40001, 0x1ffffffffe780001, 0x1ffffffffe600001, 0x1ffffffffe4c0001, + 0x1ffffffffdf40001, 0x1ffffffffdac0001, 0x1ffffffffda40001, 0x1ffffffffc680001, + 0x1ffffffffc000001, 0x1ffffffffb880001, 0x1ffffffffb7c0001, 0x1ffffffffb300001, + 0x1ffffffffb1c0001, 0x1ffffffffadc0001, 0x1ffffffffa400001, 0x1ffffffffa140001, + 0x1ffffffff9d80001, 0x1ffffffff9140001, 0x1ffffffff8ac0001, 0x1ffffffff8a80001, + 0x1ffffffff81c0001, 0x1ffffffff7800001, 0x1ffffffff7680001, 0x1ffffffff7080001} -// Qi60 are the last one hundred (from 0xfffffffffffffff and downward) 60-bit NTT-friendly primes for N up to 65536 -var Qi60 = []uint64{1152921504606584833, 1152921504598720513, 1152921504592429057, 1152921504581419009, 1152921504580894721, - 1152921504578273281, 1152921504577748993, 1152921504577486849, 1152921504568836097, 1152921504565166081, - 1152921504563331073, 1152921504556515329, 1152921504555466753, 1152921504554156033, 1152921504552583169, - 1152921504542883841, 1152921504538951681, 1152921504537378817, 1152921504531873793, 1152921504521650177, - 1152921504509853697, 1152921504508280833, 1152921504506970113, 1152921504495697921, 1152921504491241473, - 1152921504488620033, 1152921504479444993, 1152921504470794241, 1152921504468172801, 1152921504462929921, - 1152921504462667777, 
1152921504455589889, 1152921504447987713, 1152921504442482689, 1152921504436191233, - 1152921504427278337, 1152921504419414017, 1152921504409190401, 1152921504403947521, 1152921504396869633, - 1152921504395821057, 1152921504373014529, 1152921504369344513, 1152921504368558081, 1152921504364625921, - 1152921504362790913, 1152921504361218049, 1152921504353615873, 1152921504337887233, 1152921504337625089, - 1152921504321372161, 1152921504314032129, 1152921504303022081, 1152921504301449217, 1152921504288342017, - 1152921504287293441, 1152921504286769153, 1152921504282836993, 1152921504274972673, 1152921504266321921, - 1152921504256622593, 1152921504253739009, 1152921504245088257, 1152921504241942529, 1152921504240107521, - 1152921504239583233, 1152921504238010369, 1152921504234078209, 1152921504231718913, 1152921504230670337, - 1152921504227524609, 1152921504214417409, 1152921504207339521, 1152921504205504513, 1152921504204193793, - 1152921504190824449, 1152921504179552257, 1152921504177192961, 1152921504176668673, 1152921504174309377, - 1152921504172474369, 1152921504164872193, 1152921504162512897, 1152921504139706369, 1152921504134987777, - 1152921504132628481, 1152921504122142721, 1152921504120832001, 1152921504116899841, 1152921504105627649, - 1152921504101957633, 1152921504100384769, 1152921504096452609, 1152921504093306881, 1152921504078364673, - 1152921504067092481, 1152921504066306049, 1152921504057917441, 1152921504053723137, 1152921504050839553} +// Pi60 are the next [32:64] 61-bit close to 2^{62} NTT-friendly primes for N up to 2^{17} +var Pi60 = []uint64{0x1ffffffff6c80001, 0x1ffffffff6140001, 0x1ffffffff5f40001, 0x1ffffffff5700001, + 0x1ffffffff4bc0001, 0x1ffffffff4380001, 0x1ffffffff3240001, 0x1ffffffff2dc0001, + 0x1ffffffff1a40001, 0x1ffffffff11c0001, 0x1ffffffff0fc0001, 0x1ffffffff0d80001, + 0x1ffffffff0c80001, 0x1ffffffff08c0001, 0x1fffffffefd00001, 0x1fffffffef9c0001, + 0x1fffffffef600001, 0x1fffffffeef40001, 0x1fffffffeed40001, 0x1fffffffeed00001, + 
0x1fffffffeebc0001, 0x1fffffffed540001, 0x1fffffffed440001, 0x1fffffffed2c0001, + 0x1fffffffed200001, 0x1fffffffec940001, 0x1fffffffec6c0001, 0x1fffffffebe80001, + 0x1fffffffebac0001, 0x1fffffffeba40001, 0x1fffffffeb4c0001, 0x1fffffffeb280001}