diff --git a/Makefile b/Makefile index a884d67f..d3dc27b0 100644 --- a/Makefile +++ b/Makefile @@ -11,15 +11,13 @@ test_examples: @echo Running the examples go run ./examples/ring/vOLE -short > /dev/null go run ./examples/bfv > /dev/null + go run ./examples/ckks/bootstrapping -short > /dev/nul + go run ./examples/ckks/advanced/lut -short > /dev/null go run ./examples/ckks/euler > /dev/null go run ./examples/ckks/polyeval > /dev/null go run ./examples/dbfv/pir &> /dev/null go run ./examples/dbfv/psi &> /dev/null @echo ok - @echo Building resources-heavy examples - go build -o /dev/null ./examples/ckks/bootstrapping - go build -o /dev/null ./examples/ckks/advanced/rlwe_lwe_bridge_LHHMQ20 - @echo ok .PHONY: static_check static_check: check_tools @@ -33,6 +31,21 @@ static_check: check_tools echo $$FMTOUT;\ false;\ fi +.PHONY: test_gotest +test_gotest: + go test -v -timeout=0 ./utils + go test -v -timeout=0 ./ring + go test -v -timeout=0 ./rlwe + go test -v -timeout=0 ./rlwe/ringqp + go test -v -timeout=0 ./rlwe/gadget + go test -v -timeout=0 ./rlwe/rgsw + go test -v -timeout=0 ./rlwe/lut + go test -v -timeout=0 ./bfv + go test -v -timeout=0 ./dbfv + go test -v -timeout=0 ./ckks + go test -v -timeout=0 ./ckks/advanced + go test -v -timeout=0 ./ckks/bootstrapping -test-bootstrapping -short + go test -v -timeout=0 ./dckks @GOVETOUT=$$(go vet ./... 2>&1); \ if [ -z "$$GOVETOUT" ]; then\ diff --git a/examples/ckks/advanced/lut/main.go b/examples/ckks/advanced/lut/main.go new file mode 100644 index 00000000..fd5323c2 --- /dev/null +++ b/examples/ckks/advanced/lut/main.go @@ -0,0 +1,261 @@ +package main + +import ( + "flag" + "fmt" + "github.com/tuneinsight/lattigo/v3/ckks" + ckksAdvanced "github.com/tuneinsight/lattigo/v3/ckks/advanced" + "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe" + "github.com/tuneinsight/lattigo/v3/rlwe/lut" + "time" +) + +// This examples showcases how lookup tables can complement the CKKS scheme to compute non-linear functions +// such as sign. The examples starts by homomorphically decoding the CKKS ciphertext from the canonical embeding +// to the coefficient embeding. It then evaluates the Look-Up-Table (LUT) on each coefficient and repacks the +// outputs of each LUT in a single RLWE ciphertext. Finally it homomorphically encodes the RLWE ciphertext back +// to the canonical embeding of the CKKS scheme. + +// ============================== +// Functions to evaluate with LUT +// ============================== +func sign(x float64) (y float64) { + if x > 0 { + return 1 + } else if x < 0 { + return -1 + } else { + return 0 + } +} + +func main() { + + var err error + + // Base ring degree + LogN := 12 + + // Q modulus Q + Q := []uint64{0x800004001, 0x40002001} // 65.0000116961637 bits + + // P modulus P + P := []uint64{0x4000026001} // 38.00000081692261 bits + + flagShort := flag.Bool("short", false, "runs the example with insecure parameters for fast testing") + flag.Parse() + + if *flagShort { + LogN = 6 + } + + // Starting RLWE params, size of these params + // determine the complexity of the LUT: + // each LUT takes N RGSW ciphertext-ciphetext mul. + // LogN = 12 & LogQP = ~103 -> >128-bit secure. + var paramsN12 ckks.Parameters + if paramsN12, err = ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ + ParametersLiteral: rlwe.ParametersLiteral{ + LogN: LogN, + Q: Q, + P: P, + LogBase2: 0, + H: 0, + Sigma: rlwe.DefaultSigma, + RingType: ring.Standard, + }, + LogSlots: 4, + DefaultScale: 1 << 32, + }); err != nil { + panic(err) + } + + // Params for Key-switching N12 to N11. + // LogN = 12 & LogQP = ~54 -> >>>128-bit secure. + var paramsN12ToN11 ckks.Parameters + if paramsN12ToN11, err = ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ + ParametersLiteral: rlwe.ParametersLiteral{ + LogN: LogN, + Q: Q[:1], + P: []uint64{0x42001}, + LogBase2: 16, + H: 0, + Sigma: rlwe.DefaultSigma, + RingType: ring.Standard, + }, + LogSlots: 0, + DefaultScale: 0, + }); err != nil { + panic(err) + } + + // LUT RLWE params, N of these params determine + // the LUT poly and therefore precision. + // LogN = 11 & LogQP = ~54 -> 128-bit secure. + var paramsN11 ckks.Parameters + if paramsN11, err = ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ + ParametersLiteral: rlwe.ParametersLiteral{ + LogN: LogN - 1, + Q: Q[:1], + P: []uint64{0x42001}, + LogBase2: 12, + H: 0, + Sigma: rlwe.DefaultSigma, + RingType: ring.Standard, + }, + LogSlots: 0, + DefaultScale: 0, + }); err != nil { + panic(err) + } + + // LUT interval + a, b := -8.0, 8.0 + + // Rescale inputs during Homomorphic Decoding by the normalization of the + // LUT inputs and change of scale to ensure that upperbound on the homomorphic + // decryption of LWE during the LUT evaluation X^{dec(lwe)} is smaller than N + // to avoid negacyclic wrapping of X^{dec(lwe)}. + diffScale := paramsN11.QiFloat64(0) / (4.0 * paramsN12.DefaultScale()) + normalization := 2.0 / (b - a) // all inputs are normalized before the LUT evaluation. + + // SlotsToCoeffsParameters homomorphic encoding parameters + var SlotsToCoeffsParameters = ckksAdvanced.EncodingMatrixLiteral{ + LogN: paramsN12.LogN(), + LogSlots: paramsN12.LogSlots(), + Scaling: normalization * diffScale, + LinearTransformType: ckksAdvanced.SlotsToCoeffs, + RepackImag2Real: false, + LevelStart: 1, // starting level + BSGSRatio: 4.0, // ratio between n1/n2 for n1*n2 = slots + BitReversed: false, // bit-reversed input + ScalingFactor: [][]float64{ // Decomposition level of the encoding matrix + {paramsN12.QiFloat64(1)}, // Scale of the decoding matrix + }, + } + + // CoeffsToSlotsParameters homomorphic decoding parameters + var CoeffsToSlotsParameters = ckksAdvanced.EncodingMatrixLiteral{ + LinearTransformType: ckksAdvanced.CoeffsToSlots, + RepackImag2Real: false, + LogN: paramsN12.LogN(), + LogSlots: paramsN12.LogSlots(), + Scaling: 1 / float64(paramsN12.Slots()), + LevelStart: 1, // starting level + BSGSRatio: 4.0, // ratio between n1/n2 for n1*n2 = slots + BitReversed: false, // bit-reversed input + ScalingFactor: [][]float64{ // Decomposition level of the encoding matrix + {paramsN12.QiFloat64(1)}, // Scale of the encoding matrix + }, + } + + fmt.Printf("Generating LUT... ") + now := time.Now() + // Generate LUT, provide function, outputscale, ring and interval. + LUTPoly := lut.InitLUT(sign, paramsN12.DefaultScale(), paramsN12.RingQ(), a, b) + fmt.Printf("Done (%s)\n", time.Since(now)) + + // Index of the LUT poly and repacking after evaluating the LUT. + lutPolyMap := make(map[int]*ring.Poly) // Which slot to evaluate on the LUT + repackIndex := make(map[int]int) // Where to repack slots after the LUT + gapN11 := paramsN11.N() / (2 * paramsN12.Slots()) + gapN12 := paramsN12.N() / (2 * paramsN12.Slots()) + + for i := 0; i < paramsN12.Slots(); i++ { + lutPolyMap[i*gapN11] = LUTPoly + repackIndex[i*gapN11] = i * gapN12 + } + + kgenN12 := ckks.NewKeyGenerator(paramsN12) + skN12 := kgenN12.GenSecretKey() + encoderN12 := ckks.NewEncoder(paramsN12) + encryptorN12 := ckks.NewEncryptor(paramsN12, skN12) + decryptorN12 := ckks.NewDecryptor(paramsN12, skN12) + + kgenN11 := ckks.NewKeyGenerator(paramsN11) + skN11 := kgenN11.GenSecretKey() + //decryptorN11 := ckks.NewDecryptor(paramsN11, skN11) + //encoderN11 := ckks.NewEncoder(paramsN11) + + // Switchingkey RLWEN12 -> RLWEN11 + swkN12ToN11 := ckks.NewKeyGenerator(paramsN12ToN11).GenSwitchingKey(skN12, skN11) + + fmt.Printf("Gen SlotsToCoeffs Matrices... ") + now = time.Now() + SlotsToCoeffsMatrix := ckksAdvanced.NewHomomorphicEncodingMatrixFromLiteral(SlotsToCoeffsParameters, encoderN12) + CoeffsToSlotsMatrix := ckksAdvanced.NewHomomorphicEncodingMatrixFromLiteral(CoeffsToSlotsParameters, encoderN12) + fmt.Printf("Done (%s)\n", time.Since(now)) + + // Rotation Keys + rotations := []int{} + for i := 1; i < paramsN12.N(); i <<= 1 { + rotations = append(rotations, i) + } + + rotations = append(rotations, SlotsToCoeffsParameters.Rotations()...) + rotations = append(rotations, CoeffsToSlotsParameters.Rotations()...) + + rotKey := kgenN12.GenRotationKeysForRotations(rotations, true, skN12) + + // LUT Evaluator + evalLUT := lut.NewEvaluator(paramsN12.Parameters, paramsN11.Parameters, rotKey) + + // CKKS Evaluator + evalCKKS := ckksAdvanced.NewEvaluator(paramsN12, rlwe.EvaluationKey{Rlk: nil, Rtks: rotKey}) + evalCKKSN12ToN11 := ckks.NewEvaluator(paramsN12ToN11, rlwe.EvaluationKey{}) + + fmt.Printf("Encrypting bits of skLWE in RGSW... ") + now = time.Now() + LUTKEY := evalLUT.GenLUTKey(skN12, skN11) // Generate RGSW(sk_i) for all coefficients of sk + fmt.Printf("Done (%s)\n", time.Since(now)) + + // Generates the starting plaintext values. + interval := (b - a) / float64(paramsN12.Slots()) + values := make([]float64, paramsN12.Slots()) + for i := 0; i < paramsN12.Slots(); i++ { + values[i] = a + float64(i)*interval + } + + pt := ckks.NewPlaintext(paramsN12, paramsN12.MaxLevel(), paramsN12.DefaultScale()) + encoderN12.EncodeSlots(values, pt, paramsN12.LogSlots()) + ctN12 := encryptorN12.EncryptNew(pt) + + fmt.Printf("Homomorphic Decoding... ") + now = time.Now() + // Homomorphic Decoding: [(a+bi), (c+di)] -> [a, c, b, d] + ctN12 = evalCKKS.SlotsToCoeffsNew(ctN12, nil, SlotsToCoeffsMatrix) + ctN12.Scale = paramsN11.QiFloat64(0) / 4.0 + + // Key-Switch from LogN = 12 to LogN = 10 + evalCKKS.DropLevel(ctN12, ctN12.Level()) // drop to LUT level + ctTmp := evalCKKSN12ToN11.SwitchKeysNew(ctN12, swkN12ToN11) // key-switch to LWE degree + ctN11 := ckks.NewCiphertext(paramsN11, 1, paramsN11.MaxLevel(), ctTmp.Scale) + rlwe.SwitchCiphertextRingDegreeNTT(ctTmp.Ciphertext, paramsN11.RingQ(), paramsN12.RingQ(), ctN11.Ciphertext) + fmt.Printf("Done (%s)\n", time.Since(now)) + + //for i, v := range encoderN11.DecodeCoeffs(decryptorN11.DecryptNew(ctN11)){ + // fmt.Printf("%3d: %7.4f\n", i, v) + //} + + fmt.Printf("Evaluating LUT... ") + now = time.Now() + // Extracts & EvalLUT(LWEs, indexLUT) on the fly -> Repack(LWEs, indexRepack) -> RLWE + ctN12.Ciphertext = evalLUT.EvaluateAndRepack(ctN11.Ciphertext, lutPolyMap, repackIndex, LUTKEY) + ctN12.Scale = paramsN12.DefaultScale() + fmt.Printf("Done (%s)\n", time.Since(now)) + + //for i, v := range encoderN12.DecodeCoeffs(decryptorN12.DecryptNew(ctN12)){ + // fmt.Printf("%3d: %7.4f\n", i, v) + //} + + fmt.Printf("Homomorphic Encoding... ") + now = time.Now() + // Homomorphic Encoding: [LUT(a), LUT(c), LUT(b), LUT(d)] -> [(LUT(a)+LUT(b)i), (LUT(c)+LUT(d)i)] + ctN12, _ = evalCKKS.CoeffsToSlotsNew(ctN12, CoeffsToSlotsMatrix) + fmt.Printf("Done (%s)\n", time.Since(now)) + + for i, v := range encoderN12.Decode(decryptorN12.DecryptNew(ctN12), paramsN12.LogSlots()) { + fmt.Printf("%7.4f -> %7.4f\n", values[i], v) + } +} diff --git a/examples/ckks/advanced/rlwe_lwe_bridge_LHHMQ20/main.go b/examples/ckks/advanced/rlwe_lwe_bridge_LHHMQ20/main.go deleted file mode 100644 index e7b9a4f8..00000000 --- a/examples/ckks/advanced/rlwe_lwe_bridge_LHHMQ20/main.go +++ /dev/null @@ -1,425 +0,0 @@ -package main - -import ( - "fmt" - "math" - "time" - - "github.com/tuneinsight/lattigo/v3/ckks" - ckksAdvanced "github.com/tuneinsight/lattigo/v3/ckks/advanced" - "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/rlwe" - "github.com/tuneinsight/lattigo/v3/utils" -) - -// This example is an implementation of the RLWE -> LWE extraction followed by an LWE -> RLWE repacking -// (bridge between CKKS and FHEW ciphertext) based on "Pegasus: Bridging Polynomial and Non-polynomial -// Evaluations in Homomorphic Encryption". -// It showcases advanced tools of the CKKS scheme, such as homomorphic decoding and homomorphic modular reduction. - -func main() { - - // Ring Learning With Error parameters - fmt.Printf("Gen RLWE Parameters... ") - start := time.Now() - paramsRLWE := genRLWEParameters() - fmt.Printf("Done (%s)\n", time.Since(start)) - - // Learning With Error parameters - fmt.Printf("Gen LWE Parameters... ") - start = time.Now() - paramsLWE := genLWEParameters(paramsRLWE) - fmt.Printf("Done (%s)\n", time.Since(start)) - - fmt.Printf("RLWE Params : logN=%2d, logQP=%3d\n", paramsRLWE.LogN(), paramsRLWE.LogQP()) - fmt.Printf("LWE Params : logN=%2d, logQP=%3d\n", paramsLWE.LogN(), paramsLWE.LogQP()) - - // Homomorphic decoding parameters - SlotsToCoeffsParameters := ckksAdvanced.EncodingMatrixLiteral{ - LogN: paramsRLWE.LogN(), - LogSlots: paramsRLWE.LogSlots(), - Scaling: 1.0, - LinearTransformType: ckksAdvanced.SlotsToCoeffs, - LevelStart: 2, // starting level - BSGSRatio: 4.0, // ratio between n1/n2 for n1*n2 = slots - BitReversed: false, // bit-reversed input - ScalingFactor: [][]float64{ // Decomposition level of the encoding matrix - {paramsRLWE.QiFloat64(1)}, // Scale of the second matrix - {paramsRLWE.QiFloat64(2)}, // Scale of the first matrix - }, - } - - // Homomorphic modular reduction parameters - EvalModParameters := ckksAdvanced.EvalModLiteral{ - Q: paramsRLWE.Q()[0], // Modulus - LevelStart: paramsRLWE.MaxLevel() - 1, // Starting level of the procedure - SineType: ckksAdvanced.Cos1, // Type of approximation - MessageRatio: 256.0, // Q/|m| - K: 16, // Interval of approximation - SineDeg: 63, // Degree of approximation - DoubleAngle: 2, // Number of double angle evaluation - ArcSineDeg: 0, // Degree of arcsine Taylor polynomial - ScalingFactor: 1 << 60, // Scaling factor during the procedure - } - - // Generates the homomorphic modular reduction polynomial approximation - fmt.Printf("Gen EvalMod Poly... ") - start = time.Now() - EvalModPoly := ckksAdvanced.NewEvalModPolyFromLiteral(EvalModParameters) - fmt.Printf("Done (%s)\n", time.Since(start)) - - // RLWE Parameters - start = time.Now() - encoder := ckks.NewEncoder(paramsRLWE) - kgenRLWE := ckks.NewKeyGenerator(paramsRLWE) - skRLWE := kgenRLWE.GenSecretKey() - encryptor := ckks.NewEncryptor(paramsRLWE, skRLWE) - decryptor := ckks.NewDecryptor(paramsRLWE, skRLWE) - - fmt.Printf("Gen SlotsToCoeffs Matrices... ") - start = time.Now() - SlotsToCoeffsMatrix := ckksAdvanced.NewHomomorphicEncodingMatrixFromLiteral(SlotsToCoeffsParameters, encoder) - fmt.Printf("Done (%s)\n", time.Since(start)) - - fmt.Printf("Gen Evaluation Keys:\n") - fmt.Printf(" Decoding Keys... ") - start = time.Now() - rotKey := kgenRLWE.GenRotationKeysForRotations(SlotsToCoeffsParameters.Rotations(), true, skRLWE) - fmt.Printf("Done (%s)\n", time.Since(start)) - fmt.Printf(" Relinearization Key... ") - start = time.Now() - rlk := kgenRLWE.GenRelinearizationKey(skRLWE, 1) - fmt.Printf("Done (%s)\n", time.Since(start)) - - fmt.Printf(" Repacking Keys... ") - nonzerodiags := make([]int, paramsRLWE.Slots()) - for i := range nonzerodiags { - nonzerodiags[i] = i - } - rotationsRepack := paramsRLWE.RotationsForLinearTransform(nonzerodiags, paramsRLWE.Slots(), 4.0) - rotationsRepack = append(rotationsRepack, paramsRLWE.RotationsForTrace(paramsRLWE.LogSlots(), paramsLWE.LogN())...) - rotKeyRepack := kgenRLWE.GenRotationKeysForRotations(rotationsRepack, false, skRLWE) - - fmt.Printf("Done (%s)\n", time.Since(start)) - - eval := ckksAdvanced.NewEvaluator(paramsRLWE, rlwe.EvaluationKey{Rlk: rlk, Rtks: rotKey}) - - // LWE Parameters - kgenLWE := ckks.NewKeyGenerator(paramsLWE) - skLWE := kgenLWE.GenSecretKey() - - // RLWE -> LWE Switching key - fmt.Printf(" RLWE -> LWE Switching Key... ") - start = time.Now() - swkRLWEDimToLWEDim := kgenRLWE.GenSwitchingKey(skRLWE, skLWE) - fmt.Printf("Done (%s)\n", time.Since(start)) - - // Encodes and Encrypts skLWE - fmt.Printf("Encode & Encrypt SK LWE... ") - start = time.Now() - skLWEInvNTT := paramsLWE.RingQ().NewPoly() - ring.CopyValues(skLWE.Value.Q, skLWEInvNTT) - paramsLWE.RingQ().InvNTT(skLWEInvNTT, skLWEInvNTT) - Q := paramsRLWE.Q()[0] - paramsLWE.RingQ().InvMFormLvl(0, skLWEInvNTT, skLWEInvNTT) - skFloat := make([]float64, paramsLWE.N()) - for i, s := range skLWEInvNTT.Coeffs[0] { - if s >= Q>>1 { - skFloat[i] = -float64(Q - s) - } else { - skFloat[i] = float64(s) - } - - skFloat[i] *= math.Pow(1.0/(EvalModPoly.K()*EvalModPoly.QDiff()), 0.5) // sqrt(pre-scaling for Cheby) - } - - paramsLWE.RingQ().MFormLvl(0, skLWEInvNTT, skLWEInvNTT) - ptSk := ckks.NewPlaintext(paramsRLWE, paramsRLWE.MaxLevel(), paramsRLWE.QiFloat64(paramsRLWE.MaxLevel())) - encoder.Encode(skFloat, ptSk, paramsLWE.LogN()) - ctSk := encryptor.EncryptNew(ptSk) - fmt.Printf("Done (%s)\n", time.Since(start)) - - // ********** PLAINTEXT GENERATION & ENCRYPTION ************** - - // Random complex plaintext encrypted - fmt.Printf("Gen Plaintext & Encrypt... ") - start = time.Now() - values := make([]complex128, paramsRLWE.Slots()) - for i := range values { - values[i] = complex(float64(i+1)/float64(paramsRLWE.Slots()), 1+float64(i+1)/float64(paramsRLWE.Slots())) - } - - plaintext := ckks.NewPlaintext(paramsRLWE, paramsRLWE.MaxLevel(), paramsRLWE.DefaultScale()) - // Must encode with 2*Slots because a real vector is returned - encoder.Encode(values, plaintext, paramsRLWE.LogSlots()) - ct := encryptor.EncryptNew(plaintext) - fmt.Printf("Done (%s)\n", time.Since(start)) - - // ******** STEP 1 : HOMOMORPHIC DECODING ******* - fmt.Printf("Homomorphic Decoding... ") - start = time.Now() - ct = eval.SlotsToCoeffsNew(ct, nil, SlotsToCoeffsMatrix) - fmt.Printf("Done (%s)\n", time.Since(start)) - - // ******** STEP 2 : RLWE -> LWE EXTRACTION ************* - - fmt.Printf("RLWE -> LWE Extraction... ") - start = time.Now() - // Scale the message to Delta = Q/MessageRatio - scale := math.Exp2(math.Round(math.Log2(float64(EvalModParameters.Q) / EvalModParameters.MessageRatio))) - eval.ScaleUp(ct, math.Round(scale/ct.Scale), ct) - - // Switch from RLWE parameters to LWE parameters - ctTmp := eval.SwitchKeysNew(ct, swkRLWEDimToLWEDim) - ctLWE := ckks.NewCiphertext(paramsLWE, 1, 0, ctTmp.Scale) - - // Switch the ciphertext outside of the NTT domain for the LWE extraction - for i := range ctLWE.Value { - paramsRLWE.RingQ().InvNTTLvl(0, ctTmp.Value[i], ctTmp.Value[i]) - } - - rlwe.SwitchCiphertextRingDegree(ctTmp.El(), ctLWE.El()) - - // RLWE -> LWE Extraction - lweReal, lweImag := ExtractLWESamplesBitReversed(ctLWE, paramsLWE) - fmt.Printf("Done (%s)\n", time.Since(start)) - - // Visual of some values - fmt.Println("Visual Comparison :") - fmt.Printf("Slot %4d : RLWE %f LWE %f\n", 0, values[0], complex(DecryptLWE(paramsLWE.RingQ(), lweReal[0], scale, skLWEInvNTT), DecryptLWE(paramsLWE.RingQ(), lweImag[0], scale, skLWEInvNTT))) - fmt.Printf("Slot %4d : RLWE %f LWE %f\n", paramsLWE.Slots()-1, values[paramsLWE.Slots()-1], complex(DecryptLWE(paramsLWE.RingQ(), lweReal[paramsLWE.Slots()-1], scale, skLWEInvNTT), DecryptLWE(paramsLWE.RingQ(), lweImag[paramsLWE.Slots()-1], scale, skLWEInvNTT))) - - // ********* STEP 3 : LWE -> RLWE REPACKING - fmt.Printf("Encode LWE Samples... ") - start = time.Now() - ptLWE := ckks.NewPlaintext(paramsRLWE, paramsRLWE.MaxLevel(), 1.0) - - // Encode the LWE samples as a vector - lweEncoded := make([]complex128, paramsRLWE.Slots()) - for i := 0; i < paramsRLWE.Slots(); i++ { - lweEncoded[i] = complex(float64(lweReal[i].b), float64(lweImag[i].b)) - lweEncoded[i] *= complex(math.Pow(1/(EvalModPoly.K()*EvalModPoly.QDiff()), 1.0), 0) // pre-scaling for Cheby - } - - encoder.Encode(lweEncoded, ptLWE, paramsRLWE.LogSlots()) - fmt.Printf("Done (%s)\n", time.Since(start)) - - // Encode A - - // Compute skLeft skRight - // __________ _ __________ _ - // | | | | | | | | - // | | | | | | | | - // n | ALeft | x | | + | ARight | x | | = A x sk - // | | | | | | | | - // |__________| |_| |__________| |_| - // - // N/n 1 N/n 1 - - fmt.Printf("Encode A... ") - start = time.Now() - AVectors := make([][]complex128, paramsLWE.Slots()) - for i := range AVectors { - tmp := make([]complex128, paramsLWE.N()) - for j := 0; j < paramsLWE.N(); j++ { - tmp[j] = complex(float64(lweReal[i].a[j]), float64(lweImag[i].a[j])) - tmp[j] *= complex(math.Pow(1/(EvalModPoly.K()*EvalModPoly.QDiff()), 0.5), 0) // sqrt(pre-scaling for Cheby) - } - - AVectors[i] = tmp - } - - // Diagonalize - AMatDiag := make(map[int][]complex128) - for i := 0; i < paramsLWE.Slots(); i++ { - tmp := make([]complex128, paramsLWE.N()) - for j := 0; j < paramsLWE.N(); j++ { - tmp[j] = AVectors[j%paramsLWE.Slots()][(j+i)%paramsLWE.N()] - } - AMatDiag[i] = tmp - } - - linTransf := ckks.GenLinearTransformBSGS(encoder, AMatDiag, paramsRLWE.MaxLevel(), 1.0, 4.0, paramsLWE.LogN()) - fmt.Printf("Done (%s)\n", time.Since(start)) - - evalRepack := ckks.NewEvaluator(paramsRLWE, rlwe.EvaluationKey{Rlk: rlk, Rtks: rotKeyRepack}) - - fmt.Printf("Homomorphic Partial Decryption : pt = A x sk + encode(LWE) + I(X)*Q... ") - start = time.Now() - ctAs := evalRepack.LinearTransformNew(ctSk, linTransf)[0] // A_left * sk || A_right * sk - ctAs = evalRepack.TraceNew(ctAs, paramsLWE.LogSlots(), paramsLWE.LogN()) // A * sk || A * sk - evalRepack.MultByConst(ctAs, paramsLWE.N()/paramsLWE.Slots(), ctAs) - eval.Rescale(ctAs, 1.0, ctAs) - eval.Add(ctAs, ptLWE, ctAs) // A * sk || A * sk + LWE_real || LWE_imag = RLWE + I(X) * Q - ctAs.Scale = scale - fmt.Printf("Done (%s)\n", time.Since(start)) - - fmt.Printf("Homomorphic Modular Reduction : pt mod Q... ") - start = time.Now() - // Extract imaginary part : RLWE_real + I(X)*Q ; RLWE_imag + I(X)*Q - ctAsConj := eval.ConjugateNew(ctAs) - ctAsReal := eval.AddNew(ctAs, ctAsConj) - ctAsImag := eval.SubNew(ctAs, ctAsConj) - ctAsReal.Scale = ctAsReal.Scale * 2 // Divide by 2 - ctAsImag.Scale = ctAsImag.Scale * 2 // Divide by 2 - eval.ScaleUp(ctAsReal, math.Round((EvalModPoly.ScalingFactor()/EvalModPoly.MessageRatio())/ctAsReal.Scale), ctAsReal) // Scale the real message up to Sine/MessageRatio - eval.ScaleUp(ctAsImag, math.Round((EvalModPoly.ScalingFactor()/EvalModPoly.MessageRatio())/ctAsImag.Scale), ctAsImag) // Scale the imag message up to Sine/MessageRatio - ctAsReal = eval.EvalModNew(ctAsReal, EvalModPoly) // Real mod Q - eval.DivByi(ctAsImag, ctAsImag) - ctAsImag = eval.EvalModNew(ctAsImag, EvalModPoly) // (-i*imag mod Q)*i - eval.MultByi(ctAsImag, ctAsImag) - eval.Add(ctAsReal, ctAsImag, ctAsReal) // Repack both imag and real parts - fmt.Printf("Done (%s)\n", time.Since(start)) - - fmt.Println("Visual Comparison :") - v := encoder.DecodePublic(decryptor.DecryptNew(ctAsReal), paramsRLWE.LogSlots(), 0) - fmt.Printf("Slot %4d : Want %f Have %f\n", 0, values[0], v[0]) - fmt.Printf("Slot %4d : Want %f Have %f\n", paramsRLWE.Slots()-1, values[paramsRLWE.Slots()-1], v[paramsRLWE.Slots()-1]) - -} - -func genRLWEParameters() (paramsRLWE ckks.Parameters) { - var err error - if paramsRLWE, err = ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ - ParametersLiteral: rlwe.ParametersLiteral{ - LogN: 15, - Q: []uint64{ - 0xffff820001, // 40 Q0 - 0x2000000a0001, // 45 - 0x2000000e0001, // 45 - 0xfffffffff840001, // 60 Sine (double angle) - 0x1000000000860001, // 60 Sine (double angle) - 0xfffffffff6a0001, // 60 Sine - 0x1000000000980001, // 60 Sine - 0xfffffffff5a0001, // 60 Sine - 0x1000000000b00001, // 60 Sine - 0x1000000000ce0001, // 60 Sine - 0xfffffffff2a0001, // 60 Sine - 0x100000000060001, // 58 Repack & Change of basis - }, - P: []uint64{ - 0x1fffffffffe00001, // 61 - 0x1fffffffffc80001, // 61 - 0x1fffffffffb40001, // 61 - }, - LogBase2: 0, - H: 0, - Sigma: rlwe.DefaultSigma, - }, - LogSlots: 9, - DefaultScale: 1 << 30, - }); err != nil { - panic(err) - } - return -} - -func genLWEParameters(paramsRLWE ckks.Parameters) (paramsLWE ckks.Parameters) { - var err error - if paramsLWE, err = ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ - ParametersLiteral: rlwe.ParametersLiteral{ - LogN: 10, - Q: paramsRLWE.Q()[:1], // 40 Q0 - P: paramsRLWE.P()[:1], // Pi 61 - LogBase2: 0, - H: 64, - Sigma: paramsRLWE.Sigma(), - }, - LogSlots: paramsRLWE.LogSlots(), - DefaultScale: paramsRLWE.DefaultScale(), - }); err != nil { - panic(err) - } - return -} - -// DecryptLWE decrypts an LWE sample -func DecryptLWE(ringQ *ring.Ring, lwe RNSLWESample, scale float64, skInvNTT *ring.Poly) float64 { - - tmp := ringQ.NewPolyLvl(0) - pol := new(ring.Poly) - pol.Coeffs = [][]uint64{lwe.a} - ringQ.MulCoeffsMontgomeryLvl(0, pol, skInvNTT, tmp) - qi := ringQ.Modulus[0] - tmp0 := tmp.Coeffs[0] - tmp1 := lwe.b - for j := 0; j < ringQ.N; j++ { - tmp1 = ring.CRed(tmp1+tmp0[j], qi) - } - - if tmp1 >= ringQ.Modulus[0]>>1 { - tmp1 = ringQ.Modulus[0] - tmp1 - return -float64(tmp1) / scale - } - - return float64(tmp1) / scale -} - -// RNSLWESample is a struct for RNS LWE samples -type RNSLWESample struct { - b uint64 - a []uint64 -} - -// ExtractLWESamplesBitReversed extracts LWE samples from a R-LWE sample -func ExtractLWESamplesBitReversed(ct *ckks.Ciphertext, params ckks.Parameters) (LWEReal, LWEImag []RNSLWESample) { - - ringQ := params.RingQ() - - LWEReal = make([]RNSLWESample, params.Slots()) - LWEImag = make([]RNSLWESample, params.Slots()) - - // Copy coefficients multiplied by X^{N-1} in reverse order: - // a_{0} -a_{N-1} -a2_{N-2} ... -a_{1} - acc := ringQ.NewPolyLvl(ct.Level()) - for i, qi := range ringQ.Modulus[:ct.Level()+1] { - tmp0 := acc.Coeffs[i] - tmp1 := ct.Value[1].Coeffs[i] - tmp0[0] = tmp1[0] - for j := 1; j < ringQ.N; j++ { - tmp0[j] = qi - tmp1[ringQ.N-j] - } - } - - pol := ct.Value[0] - - gap := params.N() / (2 * params.Slots()) // Gap between plaintext coefficient if sparse packed - - // Real values - for i, idx := 0, 0; i < params.Slots(); i, idx = i+1, idx+gap { - - iRev := utils.BitReverse64(uint64(i), uint64(params.LogSlots())) - - LWEReal[iRev].b = pol.Coeffs[0][idx] - LWEReal[iRev].a = make([]uint64, params.N()) - copy(LWEReal[iRev].a, acc.Coeffs[0]) - - // Multiplies the accumulator by X^{N/(2*slots)} - MulBySmallMonomial(ringQ, acc, gap) - } - - // Imaginary values - for i, idx := 0, 0; i < params.Slots(); i, idx = i+1, idx+gap { - - iRev := utils.BitReverse64(uint64(i), uint64(params.LogSlots())) - - LWEImag[iRev].b = pol.Coeffs[0][idx+(params.N()>>1)] - LWEImag[iRev].a = make([]uint64, params.N()) - copy(LWEImag[iRev].a, acc.Coeffs[0]) - - // Multiply the accumulator by X^{N/(2*slots)} - MulBySmallMonomial(ringQ, acc, gap) - } - return -} - -//MulBySmallMonomial multiplies pol by x^n -func MulBySmallMonomial(ringQ *ring.Ring, pol *ring.Poly, n int) { - for i, qi := range ringQ.Modulus[:pol.Level()+1] { - pol.Coeffs[i] = append(pol.Coeffs[i][ringQ.N-n:], pol.Coeffs[i][:ringQ.N-n]...) - tmp := pol.Coeffs[i] - for j := 0; j < n; j++ { - tmp[j] = qi - tmp[j] - } - } -} diff --git a/examples/ckks/bootstrapping/main.go b/examples/ckks/bootstrapping/main.go index 9bddcc72..3a0ae2d8 100644 --- a/examples/ckks/bootstrapping/main.go +++ b/examples/ckks/bootstrapping/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "fmt" "math" @@ -14,6 +15,9 @@ func main() { var err error + flagShort := flag.Bool("short", false, "runs the example with insecure parameters for fast testing") + flag.Parse() + var btp *bootstrapping.Bootstrapper var kgen rlwe.KeyGenerator var encoder ckks.Encoder @@ -29,13 +33,37 @@ func main() { // github.com/tuneinsight/lattigo/v3/ckks/bootstrapping/default_params.go // // LogSlots is hardcoded to 15 in the parameters, but can be changed from 1 to 15. +<<<<<<< btp_eprint // When changing LogSlots make sure that the number of levels allocated to CtS and StC is // smaller or equal to LogSlots. +======= + // When changing logSlots make sure that the number of levels allocated to CtS and StC is + // smaller or equal to logSlots. +<<<<<<< 83ae36f5f9908381fe0d957ce0daa4f037d38e6f + ckksParams := bootstrapping.DefaultCKKSParameters[0] + btpParams := bootstrapping.DefaultParameters[0] +======= +<<<<<<< btp_eprint +>>>>>>> [rlwe]: further refactoring paramSet := bootstrapping.DefaultParametersSparse[0] // bootstrapping.DefaultParametersDense[0] ckksParams := paramSet.SchemeParams btpParams := paramSet.BootstrappingParams +<<<<<<< btp_eprint +======= +======= + ckksParams := bootstrapping.DefaultCKKSParameters[0] + + if *flagShort { + ckksParams.LogN = 13 + ckksParams.LogSlots = 12 + } + + btpParams := bootstrapping.DefaultParameters[0] +>>>>>>> [rlwe]: further refactoring +>>>>>>> [rlwe]: further refactoring +>>>>>>> [rlwe]: further refactoring params, err := ckks.NewParametersFromLiteral(ckksParams) if err != nil { panic(err) diff --git a/examples/main.go b/examples/main.go deleted file mode 100644 index b62ff6a1..00000000 --- a/examples/main.go +++ /dev/null @@ -1,311 +0,0 @@ -package main - -import ( - "fmt" - "github.com/tuneinsight/lattigo/v3/ckks" - ckksAdvanced "github.com/tuneinsight/lattigo/v3/ckks/advanced" - "github.com/tuneinsight/lattigo/v3/lwe" - "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/rlwe" - "math" - "time" -) - -func main() { - LUT() -} - -// ============================== -// Functions to evaluate with LUT -// ============================== -func sign(x float64) (y float64) { - if x > 0 { - return 1 - } else if x < 0 { - return -1 - } else { - return 0 - } -} - -func sigmoid(x float64) (y float64) { - return 1.0 / (math.Exp(-x) + 1) -} - -func identity(x float64) (y float64) { - return x -} - -func relu(x float64) (y float64) { - if x < 0 { - return 0 - } - - return x -} - -// Q modulus Q -var Q = []uint64{0x80000000080001, 0x2000000e0001, 0x1fffffc20001} - -// P modulus P -var P = []uint64{0x4000000008a0001} - -// Starting RLWE params, size of these params -// determine the complexity of the LUT: -// each LUT takes N RGSW ciphertext-ciphetext mul. -var ckksParamsN12 = ckks.ParametersLiteral{ - ParametersLiteral: rlwe.ParametersLiteral{ - LogN: 7, - Q: Q, - P: P, - LogBase2: 0, - H: 0, - Sigma: rlwe.DefaultSigma, - RingType: ring.Standard, - }, - LogSlots: 4, - DefaultScale: 1 << 40, -} - -// LUT RLWE params, N of these params determine -// the LUT poly and therefore precision. -var ckksParamsN10 = ckks.ParametersLiteral{ - ParametersLiteral: rlwe.ParametersLiteral{ - LogN: 5, - Q: Q[:1], - P: P[:1], - LogBase2: 0, - H: 0, - Sigma: rlwe.DefaultSigma, - RingType: ring.Standard, - }, - LogSlots: 0, - DefaultScale: 0, -} - -// LUT example -func LUT() { - var err error - var paramsN12, paramsN10 ckks.Parameters - if paramsN12, err = ckks.NewParametersFromLiteral(ckksParamsN12); err != nil { - panic(err) - } - - if paramsN10, err = ckks.NewParametersFromLiteral(ckksParamsN10); err != nil { - panic(err) - } - - // LUT interval - a, b := -8.0, 8.0 - - fmt.Printf("Generating LUT... ") - now := time.Now() - // Generate LUT, provide function, outputscale, ring and interval. - LUTPoly := lwe.InitLUT(sign, paramsN12.DefaultScale(), paramsN12.RingQ(), a, b) - fmt.Printf("Done (%s)\n", time.Since(now)) - - lutPolyMap := make(map[int]*ring.Poly) // Which slot to evaluate on the LUT - repackIndex := make(map[int]int) // Where to repack slots after the LUT - gapN10 := paramsN10.N() / (2 * paramsN12.Slots()) - gapN12 := paramsN12.N() / (2 * paramsN12.Slots()) - - for i := 0; i < paramsN12.Slots(); i++ { - lutPolyMap[i*gapN10] = LUTPoly - repackIndex[i*gapN10] = i * gapN12 - } - - kgenN12 := ckks.NewKeyGenerator(paramsN12) - skN12 := kgenN12.GenSecretKey() - encoderN12 := ckks.NewEncoder(paramsN12) - encryptorN12 := ckks.NewEncryptor(paramsN12, skN12) - decryptorN12 := ckks.NewDecryptor(paramsN12, skN12) - - kgenN10 := ckks.NewKeyGenerator(paramsN10) - skN10 := kgenN10.GenSecretKey() - //decryptorN10 := ckks.NewDecryptor(paramsN10, skN10) - //encoderN10 := ckks.NewEncoder(paramsN10) - - // Switchingkey RLWEN12 -> RLWEN10 - swkN12ToN10 := kgenN12.GenSwitchingKey(skN12, skN10) - - fmt.Printf("Gen SlotsToCoeffs Matrices... ") - now = time.Now() - - // Rescale inputs during Homomorphic Decoding by the normalization of the - // LUT inputs and change of scale to ensure that upperbound on the homomorphic - // decryption of LWE during the LUT evaluation X^{dec(lwe)} is smaller than N - // to avoid negacyclic wrapping of X^{dec(lwe)}. - diffScale := paramsN10.QiFloat64(0) / (4.0 * paramsN12.DefaultScale()) - normalization := 2.0 / (b - a) - - // SlotsToCoeffsParameters homomorphic encoding parameters - var SlotsToCoeffsParameters = ckksAdvanced.EncodingMatrixLiteral{ - LogN: paramsN12.LogN(), - LogSlots: paramsN12.LogSlots(), - Scaling: normalization * diffScale, - LinearTransformType: ckksAdvanced.SlotsToCoeffs, - RepackImag2Real: false, - LevelStart: 2, // starting level - BSGSRatio: 4.0, // ratio between n1/n2 for n1*n2 = slots - BitReversed: false, // bit-reversed input - ScalingFactor: [][]float64{ // Decomposition level of the encoding matrix - {0x2000000e0001}, // Scale of the second matriox - {0x1fffffc20001}, // Scale of the first matrix - }, - } - - // CoeffsToSlotsParameters homomorphic decoding parameters - var CoeffsToSlotsParameters = ckksAdvanced.EncodingMatrixLiteral{ - LinearTransformType: ckksAdvanced.CoeffsToSlots, - RepackImag2Real: false, - LogN: paramsN12.LogN(), - LogSlots: paramsN12.LogSlots(), - Scaling: 1 / float64(paramsN12.Slots()), - LevelStart: 2, // starting level - BSGSRatio: 4.0, // ratio between n1/n2 for n1*n2 = slots - BitReversed: false, // bit-reversed input - ScalingFactor: [][]float64{ // Decomposition level of the encoding matrix - {0x2000000e0001}, // Scale of the second matriox - {0x1fffffc20001}, // Scale of the first matrix - }, - } - - SlotsToCoeffsMatrix := ckksAdvanced.NewHomomorphicEncodingMatrixFromLiteral(SlotsToCoeffsParameters, encoderN12) - CoeffsToSlotsMatrix := ckksAdvanced.NewHomomorphicEncodingMatrixFromLiteral(CoeffsToSlotsParameters, encoderN12) - fmt.Printf("Done (%s)\n", time.Since(now)) - - // Rotation Keys - rotations := []int{} - for i := 1; i < paramsN12.N(); i <<= 1 { - rotations = append(rotations, i) - } - - rotations = append(rotations, SlotsToCoeffsParameters.Rotations()...) - rotations = append(rotations, CoeffsToSlotsParameters.Rotations()...) - - rotKey := kgenN12.GenRotationKeysForRotations(rotations, true, skN12) - - // LUT handler - handler := lwe.NewHandler(paramsN12.Parameters, paramsN10.Parameters, rotKey) - - eval := ckksAdvanced.NewEvaluator(paramsN12, rlwe.EvaluationKey{Rlk: nil, Rtks: rotKey}) - - fmt.Printf("Encrypting bits of skLWE in RGSW... ") - now = time.Now() - LUTKEY := handler.GenLUTKey(skN12, skN10) // Generate RGSW(sk_i) for all coefficients of sk - fmt.Printf("Done (%s)\n", time.Since(now)) - - interval := (b - a) / float64(paramsN12.Slots()) - values := make([]float64, paramsN12.Slots()) - for i := 0; i < paramsN12.Slots(); i++ { - values[i] = a + float64(i)*interval - } - - pt := ckks.NewPlaintext(paramsN12, paramsN12.MaxLevel(), paramsN12.DefaultScale()) - encoderN12.EncodeSlots(values, pt, paramsN12.LogSlots()) - ctN12 := encryptorN12.EncryptNew(pt) - - fmt.Printf("Homomorphic Decoding... ") - now = time.Now() - // Homomorphic decoding: [(a+bi), (c+di)] -> [a, c, b, d] - ctN12 = eval.SlotsToCoeffsNew(ctN12, nil, SlotsToCoeffsMatrix) - ctN12.Scale = paramsN10.QiFloat64(0) / 4.0 - eval.DropLevel(ctN12, ctN12.Level()) // drop to LUT level - ctTmp := eval.SwitchKeysNew(ctN12, swkN12ToN10) // key-switch to LWE degree - ctN10 := ckks.NewCiphertext(paramsN10, 1, paramsN10.MaxLevel(), ctTmp.Scale) - rlwe.SwitchCiphertextRingDegreeNTT(ctTmp.Ciphertext, paramsN10.RingQ(), paramsN12.RingQ(), ctN10.Ciphertext) - fmt.Printf("Done (%s)\n", time.Since(now)) - - //for i, v := range encoderN10.DecodeCoeffs(decryptorN10.DecryptNew(ctN10)){ - // fmt.Printf("%3d: %7.4f\n", i, v) - //} - - fmt.Printf("Evaluating LUT... ") - now = time.Now() - // Extracts & EvalLUT(LWEs, indexLUT) on the fly -> Repack(LWEs, indexRepack) -> RLWE - ctN12.Ciphertext = handler.ExtractAndEvaluateLUTAndRepack(ctN10.Ciphertext, lutPolyMap, repackIndex, LUTKEY) - ctN12.Scale = paramsN12.DefaultScale() - fmt.Printf("Done (%s)\n", time.Since(now)) - - //for i, v := range encoderN12.DecodeCoeffs(decryptorN12.DecryptNew(ctN12)){ - // fmt.Printf("%3d: %7.4f\n", i, v) - //} - - fmt.Println("Homomorphic Encoding... ") - now = time.Now() - // [LUT(a), LUT(c), LUT(b), LUT(d)] -> [(LUT(a)+LUT(b)i), (LUT(c)+LUT(d)i)] - ctN12, _ = eval.CoeffsToSlotsNew(ctN12, CoeffsToSlotsMatrix) - fmt.Printf("Done (%s)\n", time.Since(now)) - - v := encoderN12.Decode(decryptorN12.DecryptNew(ctN12), paramsN12.LogSlots()) - - for i := range v { - fmt.Printf("%7.4f -> %7.4f\n", values[i], v[i]) - } -} - -// PrintPoly prints poly -func PrintPoly(pol *ring.Poly, scale float64, Q uint64) { - fmt.Printf("[") - for _, c := range pol.Coeffs[0][:1] { - if c > Q>>1 { - fmt.Printf("%8.4f, ", float64(int(c)-int(Q))/scale) - } else { - fmt.Printf("%8.4f, ", float64(int(c))/scale) - } - } - fmt.Printf("]\n") -} - -// DecryptAndPrint decrypts and prints the first N values. -func DecryptAndPrint(decryptor ckks.Decryptor, LogSlots int, ringQ *ring.Ring, ciphertext *ckks.Ciphertext, scale float64) { - plaintext := decryptor.DecryptNew(ciphertext) - ringQ.InvNTTLvl(ciphertext.Level(), plaintext.Value, plaintext.Value) - - v := make([]float64, 1<= ringQ.Modulus[0]>>1 { - v[i] = -float64(ringQ.Modulus[0] - plaintext.Value.Coeffs[0][j]) - } else { - v[i] = float64(plaintext.Value.Coeffs[0][i]) - } - - v[i] /= scale - } - - for i := 0; i < 1< Q>>1 { - fmt.Printf("%10.6f, ", (float64(c)-float64(Q))/scale) - } else { - fmt.Printf("%10.6f, ", float64(c)/scale) - } - } - } - fmt.Printf("]\n") -} diff --git a/lwe/bin_fhe.go b/lwe/bin_fhe.go deleted file mode 100644 index 30e7eb87..00000000 --- a/lwe/bin_fhe.go +++ /dev/null @@ -1,106 +0,0 @@ -package lwe - -import ( - "github.com/tuneinsight/lattigo/v3/ring" -) - -// m0 : [0, 1/4] -// m1 : [0, 1/4] -// | 0 0 -> 1 (0/8) -> 2/8 -// | 0 1 -> 1 (2/8) -> 2/8 -// | 1 0 -> 1 (2/8) -> 2/8 -// | 1 1 -> 0 (4/8) -> 0/8 -func nandGate(x float64) float64 { - if x > -1/8.0 && x < 3/8.0 { - return 2 / 8.0 - } - - return 0 -} - -// m0 : [0, 1/4] -// m1 : [0, 1/4] -// | 0 0 -> 0 (0/8) -> 0/8 -// | 0 1 -> 0 (2/8) -> 0/8 -// | 1 0 -> 0 (2/8) -> 0/8 -// | 1 1 -> 1 (4/8) -> 2/8 -func andGate(x float64) float64 { - if x > -1/8.0 && x < 3/8.0 { - return 0 - } - - return 1 / 4.0 -} - -// m0 : [0, 1/4] -// m1 : [0, 1/4] -// | 0 0 -> 0 (0/8) -> 0/8 -// | 0 1 -> 1 (2/8) -> 2/8 -// | 1 0 -> 1 (2/8) -> 2/8 -// | 1 1 -> 0 (4/8) -> 0/8 -func xorGate(x float64) float64 { - if x > 1/8.0 && x < 3/8.0 { - return 2 / 8.0 - } - - return 0 -} - -// m0 : [0, 1/4] -// m1 : [0, 1/4] -// | 0 0 -> 1 (0/8) -> 2/8 -// | 0 1 -> 0 (2/8) -> 0/8 -// | 1 0 -> 0 (2/8) -> 0/8 -// | 1 1 -> 1 (4/8) -> 2/8 -func nxorGate(x float64) float64 { - if x > 1/8.0 && x < 3/8.0 { - return 0 - } - - return 2 / 8.0 -} - -// m0 : [0, 1/4] -// m1 : [0, 1/4] -// | 0 0 -> 0 (0/8) -> 0/8 -// | 0 1 -> 1 (2/8) -> 2/8 -// | 1 0 -> 1 (2/8) -> 2/8 -// | 1 1 -> 1 (4/8) -> 2/8 -func orGate(x float64) float64 { - if x > 1/8.0 && x < 5/8.0 { - return 2 / 8.0 - } - - return 0 -} - -// m0 : [0, 1/4] -// m1 : [0, 1/4] -// | 0 0 -> 1 (0/8) -> 2/8 -// | 0 1 -> 0 (2/8) -> 0/8 -// | 1 0 -> 0 (2/8) -> 0/8 -// | 1 1 -> 0 (4/8) -> 0/8 -func norGate(x float64) float64 { - if x > 1/8.0 && x < 5/8.0 { - return 0 - } - - return 2 / 8.0 -} - -// m0 : [0, 1/4] -// m1 : [0, 1/4] -// | 0 -> 1 (0/8) -> 2/8 -// | 1 -> 0 (2/8) -> 0/8 -func notGate(x float64) float64 { - if x > 1/8.0 && x < 3/8.0 { - return 0 - } - - return 2 / 8.0 -} - -// InitGate generate the test rlwe plaintext for the selected gate. -func InitGate(gate func(x float64) float64, r *ring.Ring) *ring.Poly { - return InitLUT(gate, float64(r.Modulus[0])/4.0, r, -1, 1) -} diff --git a/lwe/handler.go b/lwe/handler.go deleted file mode 100644 index db62c056..00000000 --- a/lwe/handler.go +++ /dev/null @@ -1,322 +0,0 @@ -package lwe - -import ( - "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/rlwe" - "github.com/tuneinsight/lattigo/v3/rlwe/rgsw" - "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" - "math/big" -) - -// Handler is a struct that stores necessary -// data to handle LWE <-> RLWE conversion and -// LUT evaluation. -type Handler struct { - evalRGSW *rgsw.Evaluator - evalRLWE *rlwe.Evaluator - paramsLUT rlwe.Parameters - paramsLWE rlwe.Parameters - rtks *rlwe.RotationKeySet - - xPowMinusOne []ringqp.Poly //X^n - 1 from 0 to 2N LWE - - poolMod2N [2]*ring.Poly - - accumulator *rlwe.Ciphertext - Sk *rlwe.SecretKey -} - -// NewHandler creates a new Handler -func NewHandler(paramsLUT, paramsLWE rlwe.Parameters, rtks *rlwe.RotationKeySet) (h *Handler) { - h = new(Handler) - h.evalRGSW = rgsw.NewEvaluator(paramsLUT) - h.evalRLWE = rlwe.NewEvaluator(paramsLUT, &rlwe.EvaluationKey{Rtks: rtks}) - h.paramsLUT = paramsLUT - h.paramsLWE = paramsLWE - - ringQ := paramsLUT.RingQ() - ringP := paramsLUT.RingP() - - h.poolMod2N = [2]*ring.Poly{paramsLWE.RingQ().NewPolyLvl(0), paramsLWE.RingQ().NewPolyLvl(0)} - h.accumulator = rlwe.NewCiphertextNTT(paramsLUT, 1, paramsLUT.MaxLevel()) - - // Compute X^{n} - 1 from 0 to 2N LWE - oneNTTMFormQ := ringQ.NewPoly() - for i := range ringQ.Modulus { - for j := 0; j < ringQ.N; j++ { - oneNTTMFormQ.Coeffs[i][j] = ring.MForm(1, ringQ.Modulus[i], ringQ.BredParams[i]) - } - } - - N := ringQ.N - - h.xPowMinusOne = make([]ringqp.Poly, 2*N) - for i := 0; i < N; i++ { - h.xPowMinusOne[i].Q = ringQ.NewPoly() - h.xPowMinusOne[i+N].Q = ringQ.NewPoly() - if i == 0 || i == 1 { - for j := range ringQ.Modulus { - h.xPowMinusOne[i].Q.Coeffs[j][i] = ring.MForm(1, ringQ.Modulus[j], ringQ.BredParams[j]) - } - - ringQ.NTT(h.xPowMinusOne[i].Q, h.xPowMinusOne[i].Q) - - // Negacyclic wrap-arround for n > N - ringQ.Neg(h.xPowMinusOne[i].Q, h.xPowMinusOne[i+N].Q) - - } else { - ringQ.MulCoeffsMontgomery(h.xPowMinusOne[1].Q, h.xPowMinusOne[i-1].Q, h.xPowMinusOne[i].Q) // X^{n} = X^{1} * X^{n-1} - - // Negacyclic wrap-arround for n > N - ringQ.Neg(h.xPowMinusOne[i].Q, h.xPowMinusOne[i+N].Q) // X^{2n} = -X^{1} * X^{n-1} - } - } - - // Subtract -1 in NTT - for i := 0; i < 2*N; i++ { - ringQ.Sub(h.xPowMinusOne[i].Q, oneNTTMFormQ, h.xPowMinusOne[i].Q) // X^{n} - 1 - } - - if ringP != nil { - oneNTTMFormP := ringP.NewPoly() - for i := range ringP.Modulus { - for j := 0; j < ringP.N; j++ { - oneNTTMFormP.Coeffs[i][j] = ring.MForm(1, ringP.Modulus[i], ringP.BredParams[i]) - } - } - - for i := 0; i < N; i++ { - h.xPowMinusOne[i].P = ringP.NewPoly() - h.xPowMinusOne[i+N].P = ringP.NewPoly() - if i == 0 || i == 1 { - for j := range ringP.Modulus { - h.xPowMinusOne[i].P.Coeffs[j][i] = ring.MForm(1, ringP.Modulus[j], ringP.BredParams[j]) - } - - ringP.NTT(h.xPowMinusOne[i].P, h.xPowMinusOne[i].P) - - // Negacyclic wrap-arround for n > N - ringP.Neg(h.xPowMinusOne[i].P, h.xPowMinusOne[i+N].P) - - } else { - // X^{n} = X^{1} * X^{n-1} - ringP.MulCoeffsMontgomery(h.xPowMinusOne[1].P, h.xPowMinusOne[i-1].P, h.xPowMinusOne[i].P) - - // Negacyclic wrap-arround for n > N - // X^{2n} = -X^{1} * X^{n-1} - ringP.Neg(h.xPowMinusOne[i].P, h.xPowMinusOne[i+N].P) - } - } - - // Subtract -1 in NTT - for i := 0; i < 2*N; i++ { - // X^{n} - 1 - ringP.Sub(h.xPowMinusOne[i].P, oneNTTMFormP, h.xPowMinusOne[i].P) - } - } - - return -} - -func (h *Handler) permuteNTTIndexesForKey(rtks *rlwe.RotationKeySet) *map[uint64][]uint64 { - if rtks == nil { - return &map[uint64][]uint64{} - } - permuteNTTIndex := make(map[uint64][]uint64, len(rtks.Keys)) - for galEl := range rtks.Keys { - permuteNTTIndex[galEl] = h.paramsLUT.RingQ().PermuteNTTIndex(galEl) - } - return &permuteNTTIndex -} - -// LUTKey is a struct storing the encryption -// of the bits of the LWE key. -type LUTKey struct { - SkPos []*rgsw.Ciphertext - SkNeg []*rgsw.Ciphertext - OneRGSW []*ring.Poly -} - -// GenLUTKey generates the LUT evaluation key -func (h *Handler) GenLUTKey(skRLWE, skLWE *rlwe.SecretKey) (lutkey *LUTKey) { - - paramsLUT := h.paramsLUT - paramsLWE := h.paramsLWE - - skLWEInvNTT := h.paramsLWE.RingQ().NewPoly() - - paramsLWE.RingQ().InvNTT(skLWE.Value.Q, skLWEInvNTT) - - plaintextRGSWOne := rlwe.NewPlaintext(paramsLUT, paramsLUT.MaxLevel()) - plaintextRGSWOne.Value.IsNTT = true - for j := 0; j < paramsLUT.QCount(); j++ { - for i := 0; i < paramsLUT.N(); i++ { - plaintextRGSWOne.Value.Coeffs[j][i] = 1 - } - } - - encryptor := rgsw.NewEncryptor(paramsLUT, skRLWE) - - ringQLUT := paramsLUT.RingQ() - ringPLUT := paramsLUT.RingP() - - levelQ := paramsLUT.QCount() - 1 - levelP := paramsLUT.PCount() - 1 - - var pBigInt *big.Int - if levelP > -1 { - if levelP == paramsLUT.PCount()-1 { - pBigInt = ringPLUT.ModulusBigint - } else { - P := ringPLUT.Modulus - pBigInt = new(big.Int).SetUint64(P[0]) - for i := 1; i < levelP+1; i++ { - pBigInt.Mul(pBigInt, ring.NewUint(P[i])) - } - } - } else { - pBigInt = ring.NewUint(1) - } - - OneRGSW := make([]*ring.Poly, paramsLUT.DecompBIT(paramsLUT.QCount()-1, paramsLUT.PCount()-1)) - OneRGSW[0] = ringQLUT.NewPoly() - tmp := new(big.Int) - for i := 0; i < levelQ+1; i++ { - OneRGSW[0].Coeffs[i][0] = tmp.Mod(pBigInt, ring.NewUint(ringQLUT.Modulus[i])).Uint64() - } - - ringQLUT.NTTLvl(levelQ, OneRGSW[0], OneRGSW[0]) - ringQLUT.MFormLvl(levelQ, OneRGSW[0], OneRGSW[0]) - - for i := 1; i < len(OneRGSW); i++ { - OneRGSW[i] = OneRGSW[0].CopyNew() - ringQLUT.MulByPow2(OneRGSW[i], i*paramsLUT.LogBase2(), OneRGSW[i]) - } - - skRGSWPos := make([]*rgsw.Ciphertext, paramsLWE.N()) - skRGSWNeg := make([]*rgsw.Ciphertext, paramsLWE.N()) - - ringQ := paramsLWE.RingQ() - Q := ringQ.Modulus[0] - OneMForm := ring.MForm(1, Q, ringQ.BredParams[0]) - MinusOneMform := ring.MForm(Q-1, Q, ringQ.BredParams[0]) - - for i, si := range skLWEInvNTT.Coeffs[0] { - - skRGSWPos[i] = rgsw.NewCiphertextNTT(paramsLUT, paramsLUT.MaxLevel()) - skRGSWNeg[i] = rgsw.NewCiphertextNTT(paramsLUT, paramsLUT.MaxLevel()) - - // sk_i = 1 -> [RGSW(1), RGSW(0)] - if si == OneMForm { - encryptor.Encrypt(plaintextRGSWOne, skRGSWPos[i]) - encryptor.Encrypt(nil, skRGSWNeg[i]) - // sk_i = -1 -> [RGSW(0), RGSW(1)] - } else if si == MinusOneMform { - encryptor.Encrypt(nil, skRGSWPos[i]) - encryptor.Encrypt(plaintextRGSWOne, skRGSWNeg[i]) - // sk_i = 0 -> [RGSW(0), RGSW(0)] - } else { - encryptor.Encrypt(nil, skRGSWPos[i]) - encryptor.Encrypt(nil, skRGSWNeg[i]) - } - } - - return &LUTKey{SkPos: skRGSWPos, SkNeg: skRGSWNeg, OneRGSW: OneRGSW} -} - -// ReduceRGSW applies a homomorphic modular reduction on the input RGSW ciphertext and returns -// the result on the output RGSW ciphertext. -func ReduceRGSW(ctIn *rgsw.Ciphertext, ringQP *ringqp.Ring, ctOut *rgsw.Ciphertext) { - - ringQ := ringQP.RingQ - ringP := ringQP.RingP - - for i := range ctIn.Value[0].Value { - for j := range ctIn.Value[0].Value[i] { - - ringQ.Reduce(ctIn.Value[0].Value[i][j][0].Q, ctOut.Value[0].Value[i][j][0].Q) - ringQ.Reduce(ctIn.Value[0].Value[i][j][1].Q, ctOut.Value[0].Value[i][j][1].Q) - ringQ.Reduce(ctIn.Value[1].Value[i][j][0].Q, ctOut.Value[1].Value[i][j][0].Q) - ringQ.Reduce(ctIn.Value[1].Value[i][j][1].Q, ctOut.Value[1].Value[i][j][1].Q) - - if ringP != nil { - ringP.Reduce(ctIn.Value[0].Value[i][j][0].P, ctOut.Value[0].Value[i][j][0].P) - ringP.Reduce(ctIn.Value[0].Value[i][j][1].P, ctOut.Value[0].Value[i][j][1].P) - ringP.Reduce(ctIn.Value[1].Value[i][j][0].P, ctOut.Value[1].Value[i][j][0].P) - ringP.Reduce(ctIn.Value[1].Value[i][j][1].P, ctOut.Value[1].Value[i][j][1].P) - } - } - } -} - -// AddRGSW adds the input RGSW ciphertext on the output RGSW ciphertext. -func AddRGSW(ctIn *rgsw.Ciphertext, ringQP *ringqp.Ring, ctOut *rgsw.Ciphertext) { - - ringQ := ringQP.RingQ - ringP := ringQP.RingP - - for i := range ctIn.Value[0].Value { - for j := range ctIn.Value[0].Value[i] { - - ringQ.AddNoMod(ctOut.Value[0].Value[i][j][0].Q, ctIn.Value[0].Value[i][j][0].Q, ctOut.Value[0].Value[i][j][0].Q) - ringQ.AddNoMod(ctOut.Value[0].Value[i][j][1].Q, ctIn.Value[0].Value[i][j][1].Q, ctOut.Value[0].Value[i][j][1].Q) - ringQ.AddNoMod(ctOut.Value[1].Value[i][j][0].Q, ctIn.Value[1].Value[i][j][0].Q, ctOut.Value[1].Value[i][j][0].Q) - ringQ.AddNoMod(ctOut.Value[1].Value[i][j][1].Q, ctIn.Value[1].Value[i][j][1].Q, ctOut.Value[1].Value[i][j][1].Q) - - if ringP != nil { - ringP.AddNoMod(ctOut.Value[0].Value[i][j][0].P, ctIn.Value[0].Value[i][j][0].P, ctOut.Value[0].Value[i][j][0].P) - ringP.AddNoMod(ctOut.Value[0].Value[i][j][1].P, ctIn.Value[0].Value[i][j][1].P, ctOut.Value[0].Value[i][j][1].P) - ringP.AddNoMod(ctOut.Value[1].Value[i][j][0].P, ctIn.Value[1].Value[i][j][0].P, ctOut.Value[1].Value[i][j][0].P) - ringP.AddNoMod(ctOut.Value[1].Value[i][j][1].P, ctIn.Value[1].Value[i][j][1].P, ctOut.Value[1].Value[i][j][1].P) - } - } - } -} - -// MulRGSWByXPowAlphaMinusOne multiplies the input RGSW ciphertext by (X^alpha - 1) and returns the result on the output RGSW ciphertext. -func MulRGSWByXPowAlphaMinusOne(ctIn *rgsw.Ciphertext, powXMinusOne ringqp.Poly, ringQP *ringqp.Ring, ctOut *rgsw.Ciphertext) { - - ringQ := ringQP.RingQ - ringP := ringQP.RingP - - for i := range ctIn.Value[0].Value { - for j := range ctIn.Value[0].Value[i] { - - ringQ.MulCoeffsMontgomeryConstant(ctIn.Value[0].Value[i][j][0].Q, powXMinusOne.Q, ctOut.Value[0].Value[i][j][0].Q) - ringQ.MulCoeffsMontgomeryConstant(ctIn.Value[0].Value[i][j][1].Q, powXMinusOne.Q, ctOut.Value[0].Value[i][j][1].Q) - ringQ.MulCoeffsMontgomeryConstant(ctIn.Value[1].Value[i][j][0].Q, powXMinusOne.Q, ctOut.Value[1].Value[i][j][0].Q) - ringQ.MulCoeffsMontgomeryConstant(ctIn.Value[1].Value[i][j][1].Q, powXMinusOne.Q, ctOut.Value[1].Value[i][j][1].Q) - - if ringP != nil { - ringP.MulCoeffsMontgomeryConstant(ctIn.Value[0].Value[i][j][0].P, powXMinusOne.P, ctOut.Value[0].Value[i][j][0].P) - ringP.MulCoeffsMontgomeryConstant(ctIn.Value[0].Value[i][j][1].P, powXMinusOne.P, ctOut.Value[0].Value[i][j][1].P) - ringP.MulCoeffsMontgomeryConstant(ctIn.Value[1].Value[i][j][0].P, powXMinusOne.P, ctOut.Value[1].Value[i][j][0].P) - ringP.MulCoeffsMontgomeryConstant(ctIn.Value[1].Value[i][j][1].P, powXMinusOne.P, ctOut.Value[1].Value[i][j][1].P) - } - } - } -} - -// MulRGSWByXPowAlphaMinusOneAndAdd multiplies the input RGSW ciphertext by (X^alpha - 1) and adds the result on the output RGSW ciphertext. -func MulRGSWByXPowAlphaMinusOneAndAdd(ctIn *rgsw.Ciphertext, powXMinusOne ringqp.Poly, ringQP *ringqp.Ring, ctOut *rgsw.Ciphertext) { - - ringQ := ringQP.RingQ - ringP := ringQP.RingP - - for i := range ctIn.Value[0].Value { - for j := range ctIn.Value[0].Value[i] { - - ringQ.MulCoeffsMontgomeryConstantAndAddNoMod(ctIn.Value[0].Value[i][j][0].Q, powXMinusOne.Q, ctOut.Value[0].Value[i][j][0].Q) - ringQ.MulCoeffsMontgomeryConstantAndAddNoMod(ctIn.Value[0].Value[i][j][1].Q, powXMinusOne.Q, ctOut.Value[0].Value[i][j][1].Q) - ringQ.MulCoeffsMontgomeryConstantAndAddNoMod(ctIn.Value[1].Value[i][j][0].Q, powXMinusOne.Q, ctOut.Value[1].Value[i][j][0].Q) - ringQ.MulCoeffsMontgomeryConstantAndAddNoMod(ctIn.Value[1].Value[i][j][1].Q, powXMinusOne.Q, ctOut.Value[1].Value[i][j][1].Q) - - if ringP != nil { - ringP.MulCoeffsMontgomeryConstantAndAddNoMod(ctIn.Value[0].Value[i][j][0].P, powXMinusOne.P, ctOut.Value[0].Value[i][j][0].P) - ringP.MulCoeffsMontgomeryConstantAndAddNoMod(ctIn.Value[0].Value[i][j][1].P, powXMinusOne.P, ctOut.Value[0].Value[i][j][1].P) - ringP.MulCoeffsMontgomeryConstantAndAddNoMod(ctIn.Value[1].Value[i][j][0].P, powXMinusOne.P, ctOut.Value[1].Value[i][j][0].P) - ringP.MulCoeffsMontgomeryConstantAndAddNoMod(ctIn.Value[1].Value[i][j][1].P, powXMinusOne.P, ctOut.Value[1].Value[i][j][1].P) - } - } - } -} diff --git a/lwe/lut.go b/lwe/lut.go deleted file mode 100644 index 579f1795..00000000 --- a/lwe/lut.go +++ /dev/null @@ -1,244 +0,0 @@ -package lwe - -import ( - "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/rlwe" - "github.com/tuneinsight/lattigo/v3/rlwe/rgsw" - "math/big" -) - -// InitLUT takes a function g, and creates an LUT polynomial for the function between the intervals a, b. -// Inputs to the LUT evaluation are assumed to have been normalized with the change of basis (2*x - a - b)/(b-a). -// Interval a, b should take into account the "drift" of the value x, caused by the change of modulus from Q to 2N. -func InitLUT(g func(x float64) (y float64), scale float64, ringQ *ring.Ring, a, b float64) (F *ring.Poly) { - F = ringQ.NewPoly() - Q := ringQ.Modulus - - // Discretization interval - interval := 2.0 / float64(ringQ.N) - - for j, qi := range Q { - - // Interval [-1, 0] of g(x) - for i := 0; i < (ringQ.N>>1)+1; i++ { - F.Coeffs[j][i] = scaleUp(g(normalizeInv(-interval*float64(i), a, b)), scale, qi) - } - - // Interval ]0, 1[ of g(x) - for i := (ringQ.N >> 1) + 1; i < ringQ.N; i++ { - F.Coeffs[j][i] = scaleUp(-g(normalizeInv(interval*float64(ringQ.N-i), a, b)), scale, qi) - } - } - - ringQ.NTT(F, F) - - return -} - -// ExtractAndEvaluateLUTAndRepack extracts on the fly LWE samples and evaluate the provided LUT on the LWE and repacks everything into a single rlwe.Ciphertext. -// ct : a rlwe Ciphertext with coefficient encoded values at level 0 -// lutPolyWihtSlotIndex : a map with [slot_index] -> LUT -// repackIndex : a map with [slot_index_have] -> slot_index_want -// lutKey : LUTKey -// Returns a *rlwe.Ciphertext -func (h *Handler) ExtractAndEvaluateLUTAndRepack(ct *rlwe.Ciphertext, lutPolyWihtSlotIndex map[int]*ring.Poly, repackIndex map[int]int, lutKey *LUTKey) (res *rlwe.Ciphertext) { - cts := h.ExtractAndEvaluateLUT(ct, lutPolyWihtSlotIndex, lutKey) - - ciphertexts := make(map[int]*rlwe.Ciphertext) - - for i := range cts { - ciphertexts[repackIndex[i]] = cts[i] - } - - return h.evalRLWE.MergeRLWE(ciphertexts) -} - -// EvalGate evaluates the selected binary gate on the input rlwe ciphertext -func (h *Handler) EvalGate(ct *Ciphertext, logNLWE int, gate *ring.Poly, lutKey *LUTKey) { - eval := h.evalRGSW - - acc := h.accumulator - - ringQLUT := h.paramsLUT.RingQ() - ringQLWE := h.paramsLWE.RingQ() - ringQPLUT := h.paramsLUT.RingQP() - - // mod 2N - mask := uint64(ringQLUT.N<<1) - 1 - - tmpRGSW := rgsw.NewCiphertextNTT(h.paramsLUT, h.paramsLUT.MaxLevel()) - - a := ct.Value[0][1:] - b := ct.Value[0][0] - - // LWE = -as + m + e, a - // LUT = LUT * X^{-as + m + e} - ringQLUT.MulCoeffsMontgomery(gate, h.xPowMinusOne[b].Q, acc.Value[0]) - ringQLUT.Add(acc.Value[0], gate, acc.Value[0]) - acc.Value[1].Zero() // TODO remove - for i := 0; i < ringQLWE.N; i++ { - MulRGSWByXPowAlphaMinusOne(lutKey.SkPos[i], h.xPowMinusOne[a[i]], ringQPLUT, tmpRGSW) - MulRGSWByXPowAlphaMinusOneAndAdd(lutKey.SkNeg[i], h.xPowMinusOne[-a[i]&mask], ringQPLUT, tmpRGSW) - AddOneRGSW(lutKey.OneRGSW, ringQLUT, tmpRGSW) - eval.ExternalProduct(acc, tmpRGSW, acc) - } - - eval.SwitchKeysInPlace(0, acc.Value[1], nil, eval.Pool[1].Q, eval.Pool[2].Q) // TODO : add RLWE -> LWE Key - ringQLUT.AddLvl(0, acc.Value[0], eval.Pool[1].Q, acc.Value[0]) - ringQLUT.InvNTT(eval.Pool[2].Q, acc.Value[1]) - ringQLUT.InvNTT(acc.Value[0], acc.Value[0]) - - Qflo := float64(ringQLUT.Modulus[0]) - maskLWE := uint64(2< LUT -// lutKey : LUTKey -// Returns a map[slot_index] -> LUT(ct[slot_index]) -func (h *Handler) ExtractAndEvaluateLUT(ct *rlwe.Ciphertext, lutPolyWihtSlotIndex map[int]*ring.Poly, lutKey *LUTKey) (res map[int]*rlwe.Ciphertext) { - - eval := h.evalRGSW - - bRLWEMod2N := h.poolMod2N[0] - aRLWEMod2N := h.poolMod2N[1] - - acc := h.accumulator - - ringQLUT := h.paramsLUT.RingQ() - ringQLWE := h.paramsLWE.RingQ() - ringQPLUT := h.paramsLUT.RingQP() - - // mod 2N - mask := uint64(ringQLUT.N<<1) - 1 - - ringQLWE.InvNTTLvl(ct.Level(), ct.Value[0], acc.Value[0]) - ringQLWE.InvNTTLvl(ct.Level(), ct.Value[1], acc.Value[1]) - - // Switch modulus from Q to 2N - h.ModSwitchRLWETo2NLvl(ct.Level(), acc.Value[1], acc.Value[1]) - - // Conversion from Convolution(a, sk) to DotProd(a, sk) for LWE decryption. - // Copy coefficients multiplied by X^{N-1} in reverse order: - // a_{0} -a_{N-1} -a2_{N-2} ... -a_{1} - tmp0 := aRLWEMod2N.Coeffs[0] - tmp1 := acc.Value[1].Coeffs[0] - tmp0[0] = tmp1[0] - for j := 1; j < ringQLWE.N; j++ { - tmp0[j] = -tmp1[ringQLWE.N-j] & mask - } - - h.ModSwitchRLWETo2NLvl(ct.Level(), acc.Value[0], bRLWEMod2N) - - res = make(map[int]*rlwe.Ciphertext) - - tmpRGSW := rgsw.NewCiphertextNTT(h.paramsLUT, h.paramsLUT.MaxLevel()) - - var prevIndex int - for index := 0; index < ringQLWE.N; index++ { - - if lut, ok := lutPolyWihtSlotIndex[index]; ok { - - MulBySmallMonomialMod2N(mask, aRLWEMod2N, index-prevIndex) - prevIndex = index - - a := aRLWEMod2N.Coeffs[0] - b := bRLWEMod2N.Coeffs[0][index] - - // LWE = -as + m + e, a - // LUT = LUT * X^{-as + m + e} - ringQLUT.MulCoeffsMontgomery(lut, h.xPowMinusOne[b].Q, acc.Value[0]) - ringQLUT.Add(acc.Value[0], lut, acc.Value[0]) - acc.Value[1].Zero() // TODO remove - - for j := 0; j < ringQLWE.N; j++ { - // RGSW[(X^{a} - 1) * sk_{j}[0] + (X^{-a} - 1) * sk_{j}[1] + 1] - MulRGSWByXPowAlphaMinusOne(lutKey.SkPos[j], h.xPowMinusOne[a[j]], ringQPLUT, tmpRGSW) - MulRGSWByXPowAlphaMinusOneAndAdd(lutKey.SkNeg[j], h.xPowMinusOne[-a[j]&mask], ringQPLUT, tmpRGSW) - AddOneRGSW(lutKey.OneRGSW, ringQLUT, tmpRGSW) - - // LUT[RLWE] = LUT[RLWE] x RGSW[(X^{a} - 1) * sk_{j}[0] + (X^{-a} - 1) * sk_{j}[1] + 1] - eval.ExternalProduct(acc, tmpRGSW, acc) - } - - res[index] = acc.CopyNew() - } - - // LUT[RLWE] = LUT[RLWE] * X^{m+e} - } - - return -} - -// AddOneRGSW adds one in plaintext on the output RGSW ciphertext. -func AddOneRGSW(oneRGSW []*ring.Poly, ringQ *ring.Ring, res *rgsw.Ciphertext) { - nQ := res.LevelQ() + 1 - nP := res.LevelP() + 1 - - if nP == 0 { - nP = 1 - } - - for i := range res.Value[0].Value { - for j := range res.Value[0].Value[i] { - start, end := i*nP, (i+1)*nP - if end > nQ { - end = nQ - } - for k := start; k < end; k++ { - ring.AddVecNoMod(res.Value[0].Value[i][j][0].Q.Coeffs[k], oneRGSW[j].Coeffs[k], res.Value[0].Value[i][j][0].Q.Coeffs[k]) - ring.AddVecNoMod(res.Value[1].Value[i][j][1].Q.Coeffs[k], oneRGSW[j].Coeffs[k], res.Value[1].Value[i][j][1].Q.Coeffs[k]) - } - } - } -} - -//MulBySmallMonomialMod2N multiplies pol by x^n, with 0 <= n < N -func MulBySmallMonomialMod2N(mask uint64, pol *ring.Poly, n int) { - if n != 0 { - N := len(pol.Coeffs[0]) - pol.Coeffs[0] = append(pol.Coeffs[0][N-n:], pol.Coeffs[0][:N-n]...) - tmp := pol.Coeffs[0] - for j := 0; j < n; j++ { - tmp[j] = -tmp[j] & mask - } - } -} - -// ModSwitchRLWETo2NLvl applys round(x * 2N / Q) to the coefficients of polQ and returns the result on pol2N. -func (h *Handler) ModSwitchRLWETo2NLvl(level int, polQ *ring.Poly, pol2N *ring.Poly) { - coeffsBigint := make([]*big.Int, len(polQ.Coeffs[0])) - - ringQ := h.paramsLWE.RingQ() - - ringQ.PolyToBigintLvl(level, polQ, 1, coeffsBigint) - - QBig := ring.NewUint(1) - for i := 0; i < level+1; i++ { - QBig.Mul(QBig, ring.NewUint(ringQ.Modulus[i])) - } - - twoN := uint64(h.paramsLUT.N() << 1) - twoNBig := ring.NewUint(twoN) - tmp := pol2N.Coeffs[0] - for i := 0; i < ringQ.N; i++ { - coeffsBigint[i].Mul(coeffsBigint[i], twoNBig) - ring.DivRound(coeffsBigint[i], QBig, coeffsBigint[i]) - tmp[i] = coeffsBigint[i].Uint64() & (twoN - 1) - } -} diff --git a/lwe/lwe.go b/lwe/lwe.go deleted file mode 100644 index 622a5176..00000000 --- a/lwe/lwe.go +++ /dev/null @@ -1,88 +0,0 @@ -package lwe - -import ( - "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/utils" -) - -// Plaintext is CRT representation of an -// integer m. -type Plaintext struct { - Value []uint64 -} - -// Ciphertext is a CRT representation of -// the LWE sample of size N+1, for N -// the degree of the LWE sample. -// The first element of the slice stores the CRT -// representation of <-a, s> + m + e and the N -// next element the CRT representation of a. -type Ciphertext struct { - Value [][]uint64 -} - -// Level returns the CRT level of the target. -func (pt *Plaintext) Level() int { - return len(pt.Value) - 1 -} - -// Level returns the CRT level of the target. -func (ct *Ciphertext) Level() int { - return len(ct.Value) - 1 -} - -// NewCiphertext allocates a new LWE sample of degree N -// and level level. -func NewCiphertext(N, level int) (ct *Ciphertext) { - ct = new(Ciphertext) - ct.Value = make([][]uint64, level+1) - for i := 0; i < level+1; i++ { - ct.Value[i] = make([]uint64, N+1) - } - return ct -} - -// Add adds ct0 to ct1 and returns the result on ct2. -func (h *Handler) Add(ct0, ct1, ct2 *Ciphertext) { - - level := utils.MinInt(utils.MinInt(ct0.Level(), ct1.Level()), ct2.Level()) - - for i := 0; i < level+1; i++ { - Q := h.paramsLUT.RingQ().Modulus[i] - ring.AddVec(ct0.Value[i][1:], ct1.Value[i][1:], ct2.Value[i][1:], Q) - ct2.Value[i][0] = ring.CRed(ct0.Value[i][0]+ct1.Value[i][0], Q) - } - - ct2.Value = ct2.Value[:level+1] -} - -// Sub subtracts ct1 to ct0 and returns the result on ct2. -func (h *Handler) Sub(ct0, ct1, ct2 *Ciphertext) { - - level := utils.MinInt(utils.MinInt(ct0.Level(), ct1.Level()), ct2.Level()) - - Q := h.paramsLUT.RingQ().Modulus - for i := 0; i < level+1; i++ { - ring.SubVec(ct0.Value[i][1:], ct1.Value[i][1:], ct2.Value[i][1:], Q[i]) - ct2.Value[i][0] = ring.CRed(Q[i]+ct0.Value[i][0]-ct1.Value[i][0], Q[i]) - } - - ct2.Value = ct2.Value[:level+1] -} - -// MulScalar multiplies ct0 by the provided scalar and returns the result on ct1. -func (h *Handler) MulScalar(ct0 *Ciphertext, scalar uint64, ct1 *Ciphertext) { - - level := utils.MinInt(ct0.Level(), ct1.Level()) - - ringQ := h.paramsLUT.RingQ() - for i := 0; i < level+1; i++ { - Q := ringQ.Modulus[i] - mredParams := ringQ.MredParams[i] - scalarMont := ring.MForm(scalar, Q, ringQ.BredParams[i]) - ring.MulScalarMontgomeryVec(ct0.Value[i][1:], ct1.Value[i][1:], scalarMont, Q, mredParams) - ct1.Value[i][0] = ring.MRed(ct0.Value[i][0], scalarMont, Q, mredParams) - } - - ct1.Value = ct1.Value[:level+1] -} diff --git a/lwe/lwe_test.go b/lwe/lwe_test.go deleted file mode 100644 index 3a18baee..00000000 --- a/lwe/lwe_test.go +++ /dev/null @@ -1,375 +0,0 @@ -package lwe - -import ( - "encoding/json" - "flag" - "fmt" - "github.com/stretchr/testify/assert" - "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/rlwe" - "math" - "runtime" - "testing" -) - -var flagParamString = flag.String("params", "", "specify the test cryptographic parameters as a JSON string. Overrides -short and -long.") - -var TestParams = []rlwe.ParametersLiteral{rlwe.TestPN12QP109, rlwe.TestPN13QP218} - -func testString(params rlwe.Parameters, opname string) string { - return fmt.Sprintf("%slogN=%d/logQ=%d/logP=%d/#Qi=%d/#Pi=%d", - opname, - params.LogN(), - params.LogQ(), - params.LogP(), - params.QCount(), - params.PCount()) -} - -func TestLWE(t *testing.T) { - defaultParams := TestParams // the default test runs for ring degree N=2^12, 2^13, 2^14, 2^15 - if testing.Short() { - defaultParams = TestParams[:1] // the short test suite runs for ring degree N=2^12, 2^13 - } - - if *flagParamString != "" { - var jsonParams rlwe.ParametersLiteral - json.Unmarshal([]byte(*flagParamString), &jsonParams) - defaultParams = []rlwe.ParametersLiteral{jsonParams} // the custom test suite reads the parameters from the -params flag - } - - for _, defaultParam := range defaultParams[1:] { - - params, err := rlwe.NewParametersFromLiteral(defaultParam) - if err != nil { - panic(err) - } - - for _, testSet := range []func(params rlwe.Parameters, t *testing.T){ - testBinFHE, - testLUT, - testRLWEToLWE, - testLWEToRLWE, - } { - testSet(params, t) - runtime.GC() - } - } -} - -func sign(x float64) (y float64) { - if x < 0 { - return -1 - } - - if x == 0 { - return 0 - } - - return 1 -} -func testBinFHE(params rlwe.Parameters, t *testing.T) { - var err error - - // N=1024, Q=0x7fff801 -> 2^131 - paramsLUT, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{ - LogN: 10, - Q: []uint64{0x7fff801}, - P: []uint64{}, - Sigma: rlwe.DefaultSigma, - LogBase2: 7, - }) - - assert.Nil(t, err) - - // N=512, Q=0x3001 -> 2^135 - paramsLWE, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{ - LogN: 9, - Q: []uint64{0x3001}, - P: []uint64{}, - Sigma: rlwe.DefaultSigma, - }) - - assert.Nil(t, err) - - paramsKS, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{ - LogN: 9, - Q: []uint64{0x7fff801}, - P: []uint64{}, - Sigma: rlwe.DefaultSigma, - LogBase2: 0, - }) - - assert.Nil(t, err) - - t.Run(testString(paramsLUT, "BinFHE/"), func(t *testing.T) { - - ringLWE := paramsLWE.RingQ() - ringLUT := paramsLUT.RingQ() - - scaleLWE := float64(paramsLWE.Q()[0]) / 4.0 - scaleLUT := float64(paramsLUT.Q()[0]) / 4.0 - - slots := 1 - - LUTPoly := InitGate(xorGate, ringLUT) - - lutPolyMap := make(map[int]*ring.Poly) - for i := 0; i < slots; i++ { - lutPolyMap[i] = LUTPoly - } - - skLWE := rlwe.NewKeyGenerator(paramsLWE).GenSecretKey() - encryptorLWE := rlwe.NewEncryptor(paramsLWE, skLWE) - - m0 := rlwe.NewPlaintext(paramsLWE, paramsLWE.MaxLevel()) - m0.Value.Coeffs[0][0] = uint64(1 * scaleLWE / 4.0) - - m1 := rlwe.NewPlaintext(paramsLWE, paramsLWE.MaxLevel()) - m1.Value.Coeffs[0][0] = uint64(1 * scaleLWE / 4.0) - - ctm0 := rlwe.NewCiphertextNTT(paramsLWE, 1, paramsLWE.MaxLevel()) - encryptorLWE.Encrypt(m0, ctm0) - - ctm1 := rlwe.NewCiphertextNTT(paramsLWE, 1, paramsLWE.MaxLevel()) - encryptorLWE.Encrypt(m1, ctm1) - - handler := NewHandler(paramsLUT, paramsLWE, nil) - - kgenLUT := rlwe.NewKeyGenerator(paramsLUT) - skLUT := kgenLUT.GenSecretKey() - LUTKEY := handler.GenLUTKey(skLUT, skLWE) - - skLWELarge := rlwe.NewSecretKey(paramsLWE) - skLWELarge2 := rlwe.NewSecretKey(paramsLUT) - ringLWE.InvNTT(skLWE.Value.Q, skLWELarge.Value.Q) - ringLWE.InvMForm(skLWELarge.Value.Q, skLWELarge.Value.Q) - - for i := range skLWELarge.Value.Q.Coeffs[0] { - c := skLWELarge.Value.Q.Coeffs[0][i] - if c == paramsLWE.Q()[0]-1 { - skLWELarge.Value.Q.Coeffs[0][i] = paramsLUT.Q()[0] - 1 - skLWELarge2.Value.Q.Coeffs[0][i*2] = paramsLUT.Q()[0] - 1 - } else if c != 0 { - skLWELarge.Value.Q.Coeffs[0][i] = 1 - skLWELarge2.Value.Q.Coeffs[0][i*2] = 1 - } - } - - paramsKS.RingQ().NTT(skLWELarge.Value.Q, skLWELarge.Value.Q) - paramsKS.RingQ().MForm(skLWELarge.Value.Q, skLWELarge.Value.Q) - - paramsLUT.RingQ().NTT(skLWELarge2.Value.Q, skLWELarge2.Value.Q) - paramsLUT.RingQ().MForm(skLWELarge2.Value.Q, skLWELarge2.Value.Q) - - skLUT2skLWE := kgenLUT.GenSwitchingKey(skLUT, skLWELarge) - - ringLWE.Add(ctm0.Value[0], ctm1.Value[0], ctm0.Value[0]) - ringLWE.Add(ctm0.Value[1], ctm1.Value[1], ctm0.Value[1]) - - ctsLUT := handler.ExtractAndEvaluateLUT(ctm0, lutPolyMap, LUTKEY) - - tmp := rlwe.NewCiphertextNTT(paramsLUT, 1, paramsLUT.MaxLevel()) - handler.evalRLWE.SwitchKeysInPlace(0, ctsLUT[0].Value[1], skLUT2skLWE, handler.evalRLWE.Pool[1].Q, handler.evalRLWE.Pool[2].Q) - ringLUT.AddLvl(0, ctsLUT[0].Value[0], handler.evalRLWE.Pool[1].Q, tmp.Value[0]) - ring.CopyValuesLvl(0, handler.evalRLWE.Pool[2].Q, tmp.Value[1]) - ctLWE := rlwe.NewCiphertextNTT(paramsKS, 1, paramsKS.MaxLevel()) - rlwe.SwitchCiphertextRingDegreeNTT(tmp, paramsKS.RingQ(), paramsLUT.RingQ(), ctLWE) - - for i := range ctLWE.Value { - paramsKS.RingQ().InvNTT(ctLWE.Value[i], ctLWE.Value[i]) - - Q := paramsKS.Q()[0] - q := paramsLWE.Q()[0] - ratio := float64(q) / float64(Q) - - for j := 0; j < paramsLWE.N(); j++ { - - c := ctLWE.Value[i].Coeffs[0][j] - c = uint64(float64(c)*ratio + 0.5) - ctLWE.Value[i].Coeffs[0][j] = c - } - - paramsLWE.RingQ().NTT(ctLWE.Value[i], ctLWE.Value[i]) - } - - q := paramsLWE.Q()[0] - qHalf := q >> 1 - - decryptorLWE := rlwe.NewDecryptor(paramsLWE, skLWE) - ptLWE := rlwe.NewPlaintext(paramsLWE, paramsLWE.MaxLevel()) - - decryptorLWE.Decrypt(ctLWE, ptLWE) - - _ = scaleLUT - - c := ptLWE.Value.Coeffs[0][0] - - var a float64 - if c >= qHalf { - a = -float64(q-c) / scaleLWE - } else { - a = float64(c) / scaleLWE - } - - fmt.Println(math.Round(a * 4)) - }) -} - -func testLUT(params rlwe.Parameters, t *testing.T) { - var err error - - // N=1024, Q=0x7fff801 -> 2^131 - paramsLUT, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{ - LogN: 8, - Q: []uint64{0x7fff801}, - P: []uint64{}, - Sigma: rlwe.DefaultSigma, - LogBase2: 7, - }) - - assert.Nil(t, err) - - // N=512, Q=0x3001 -> 2^135 - paramsLWE, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{ - LogN: 7, - Q: []uint64{0x3001}, - P: []uint64{}, - Sigma: rlwe.DefaultSigma, - }) - - assert.Nil(t, err) - - t.Run(testString(paramsLUT, "LUT/"), func(t *testing.T) { - - scaleLWE := float64(paramsLWE.Q()[0]) / 4.0 - scaleLUT := float64(paramsLUT.Q()[0]) / 4.0 - - slots := 16 - - LUTPoly := InitLUT(nandGate, scaleLUT, paramsLUT.RingQ(), -1, 1) - - lutPolyMap := make(map[int]*ring.Poly) - for i := 0; i < slots; i++ { - lutPolyMap[i] = LUTPoly - } - - skLWE := rlwe.NewKeyGenerator(paramsLWE).GenSecretKey() - encryptorLWE := rlwe.NewEncryptor(paramsLWE, skLWE) - - values := make([]float64, slots) - for i := 0; i < slots; i++ { - values[i] = -1 + float64(2*i)/float64(slots) - } - - ptLWE := rlwe.NewPlaintext(paramsLWE, paramsLWE.MaxLevel()) - for i := range values { - if values[i] < 0 { - ptLWE.Value.Coeffs[0][i] = paramsLWE.Q()[0] - uint64(-values[i]*scaleLWE) - } else { - ptLWE.Value.Coeffs[0][i] = uint64(values[i] * scaleLWE) - } - } - ctLWE := rlwe.NewCiphertextNTT(paramsLWE, 1, paramsLWE.MaxLevel()) - encryptorLWE.Encrypt(ptLWE, ctLWE) - - handler := NewHandler(paramsLUT, paramsLWE, nil) - - skLUT := rlwe.NewKeyGenerator(paramsLUT).GenSecretKey() - LUTKEY := handler.GenLUTKey(skLUT, skLWE) - - ctsLUT := handler.ExtractAndEvaluateLUT(ctLWE, lutPolyMap, LUTKEY) - - q := paramsLUT.Q()[0] - qHalf := q >> 1 - decryptorLUT := rlwe.NewDecryptor(paramsLUT, skLUT) - ptLUT := rlwe.NewPlaintext(paramsLUT, paramsLUT.MaxLevel()) - for i := 0; i < slots; i++ { - - decryptorLUT.Decrypt(ctsLUT[i], ptLUT) - - c := ptLUT.Value.Coeffs[0][i] - - var a float64 - if c >= qHalf { - a = -float64(q-c) / scaleLUT - } else { - a = float64(c) / scaleLUT - } - - fmt.Printf("%7.4f - %7.4f - %7.4f\n", math.Round(a*32)/32, math.Round(a*8)/8, values[i]) - } - }) -} - -func testRLWEToLWE(params rlwe.Parameters, t *testing.T) { - t.Run(testString(params, "RLWEToLWE/"), func(t *testing.T) { - kgen := rlwe.NewKeyGenerator(params) - sk := kgen.GenSecretKey() - encryptor := rlwe.NewEncryptor(params, sk) - pt := rlwe.NewPlaintext(params, params.MaxLevel()) - ct := rlwe.NewCiphertextNTT(params, 1, params.MaxLevel()) - encryptor.Encrypt(pt, ct) - - skInvNTT := params.RingQ().NewPoly() - params.RingQ().InvNTT(sk.Value.Q, skInvNTT) - - slotIndex := make(map[int]bool) - for i := 0; i < params.N(); i++ { - slotIndex[i] = true - } - - LWE := RLWEToLWE(ct, params.RingQ(), slotIndex) - - for i := 0; i < params.RingQ().N; i++ { - if math.Abs(DecryptLWE(LWE[i], params.RingQ(), skInvNTT)) > 19 { - t.Error() - } - } - }) -} - -func testLWEToRLWE(params rlwe.Parameters, t *testing.T) { - t.Run(testString(params, "LWEToRLWE/"), func(t *testing.T) { - kgen := rlwe.NewKeyGenerator(params) - sk := kgen.GenSecretKey() - encryptor := rlwe.NewEncryptor(params, sk) - decryptor := rlwe.NewDecryptor(params, sk) - pt := rlwe.NewPlaintext(params, params.MaxLevel()) - ct := rlwe.NewCiphertextNTT(params, 1, params.MaxLevel()) - encryptor.Encrypt(pt, ct) - - skInvNTT := params.RingQ().NewPoly() - params.RingQ().InvNTT(sk.Value.Q, skInvNTT) - - slotIndex := make(map[int]bool) - for i := 0; i < params.N(); i++ { - slotIndex[i] = true - } - - ctLWE := RLWEToLWE(ct, params.RingQ(), slotIndex) - - DecryptLWE(ctLWE[0], params.RingQ(), skInvNTT) - - handler := NewHandler(params, params, nil) - - ctRLWE := handler.LWEToRLWE(ctLWE) - - for i := 0; i < len(ctRLWE); i++ { - decryptor.Decrypt(ctRLWE[i], pt) - - for j := 0; j < pt.Level()+1; j++ { - - c := pt.Value.Coeffs[j][0] - - if c >= params.RingQ().Modulus[j]>>1 { - c = params.RingQ().Modulus[j] - c - } - - if c > 19 { - t.Fatal(i, j, c) - } - } - } - }) -} diff --git a/lwe/lwe_to_rlwe.go b/lwe/lwe_to_rlwe.go deleted file mode 100644 index 96e15f99..00000000 --- a/lwe/lwe_to_rlwe.go +++ /dev/null @@ -1,45 +0,0 @@ -package lwe - -import ( - "github.com/tuneinsight/lattigo/v3/rlwe" -) - -// LWEToRLWE transforms a set of LWE samples into their respective RLWE ciphertext such that decrypt(RLWE)[0] = decrypt(LWE) -func (h *Handler) LWEToRLWE(ctLWE map[int]*Ciphertext) (ctRLWE map[int]*rlwe.Ciphertext) { - - var level int - for i := range ctLWE { - level = ctLWE[i].Level() - break - } - - ringQ := h.paramsLUT.RingQ() - acc := ringQ.NewPolyLvl(level) - ctRLWE = make(map[int]*rlwe.Ciphertext) - for i := range ctLWE { - - // Alocates ciphertext - ctRLWE[i] = rlwe.NewCiphertextNTT(h.paramsLUT, 1, level) - - for u := 0; u < level+1; u++ { - - ctRLWE[i].Value[0].Coeffs[u][0] = ctLWE[i].Value[u][0] - - // Copy coefficients multiplied by X^{N-1} in reverse order: - // a_{0} -a_{N-1} -a2_{N-2} ... -a_{1} - tmp0, tmp1 := acc.Coeffs[u], ctLWE[i].Value[u][1:] - tmp0[0] = tmp1[0] - for k := 1; k < ringQ.N; k++ { - tmp0[k] = ringQ.Modulus[u] - tmp1[ringQ.N-k] - } - - copy(ctRLWE[i].Value[1].Coeffs[u], acc.Coeffs[u]) - } - - // Switches to NTT domain - ringQ.NTTLvl(level, ctRLWE[i].Value[0], ctRLWE[i].Value[0]) - ringQ.NTTLvl(level, ctRLWE[i].Value[1], ctRLWE[i].Value[1]) - } - - return -} diff --git a/lwe/rlwe_to_lwe.go b/lwe/rlwe_to_lwe.go deleted file mode 100644 index 81f72787..00000000 --- a/lwe/rlwe_to_lwe.go +++ /dev/null @@ -1,99 +0,0 @@ -package lwe - -import ( - "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/rlwe" -) - -// RLWEToLWESingle extract the first coefficient of the input RLWE and returns it -// as a LWE sample. -func RLWEToLWESingle(ct *rlwe.Ciphertext, ringQ *ring.Ring) (lwe *Ciphertext) { - - level := ct.Level() - - c0 := ringQ.NewPolyLvl(level) - c1 := ringQ.NewPolyLvl(level) - acc := ringQ.NewPolyLvl(level) - - ringQ.InvNTTLvl(level, ct.Value[0], c0) - ringQ.InvNTTLvl(level, ct.Value[1], c1) - - // Copy coefficients multiplied by X^{N-1} in reverse order: - // a_{0} -a_{N-1} -a2_{N-2} ... -a_{1} - for i, qi := range ringQ.Modulus[:level+1] { - tmp0 := acc.Coeffs[i] - tmp1 := c1.Coeffs[i] - tmp0[0] = tmp1[0] - for j := 1; j < ringQ.N; j++ { - tmp0[j] = qi - tmp1[ringQ.N-j] - } - } - - N := ringQ.N - - lwe = NewCiphertext(N, level) - - for j := 0; j < level+1; j++ { - lwe.Value[j][0] = c0.Coeffs[j][0] - copy(lwe.Value[j][1:], acc.Coeffs[j]) - } - - return -} - -// RLWEToLWE extracts all LWE samples from a RLWE ciphertext. -func RLWEToLWE(ct *rlwe.Ciphertext, ringQ *ring.Ring, slotIndex map[int]bool) (LWE map[int]*Ciphertext) { - - LWE = make(map[int]*Ciphertext) - - level := ct.Level() - - c0 := ringQ.NewPolyLvl(level) - c1 := ringQ.NewPolyLvl(level) - acc := ringQ.NewPolyLvl(level) - - ringQ.InvNTTLvl(level, ct.Value[0], c0) - ringQ.InvNTTLvl(level, ct.Value[1], c1) - - // Copy coefficients multiplied by X^{N-1} in reverse order: - // a_{0} -a_{N-1} -a2_{N-2} ... -a_{1} - for i, qi := range ringQ.Modulus[:level+1] { - tmp0 := acc.Coeffs[i] - tmp1 := c1.Coeffs[i] - tmp0[0] = tmp1[0] - for j := 1; j < ringQ.N; j++ { - tmp0[j] = qi - tmp1[ringQ.N-j] - } - } - - var prevIndex int - for index := 0; index < ringQ.N; index++ { - - if _, ok := slotIndex[index]; ok { - - // Multiplies the accumulator by X^{N/(2*slots)} - MulBySmallMonomial(ringQ, acc, index-prevIndex) - prevIndex = index - - LWE[index] = NewCiphertext(ringQ.N, level) - - for j := 0; j < level+1; j++ { - LWE[index].Value[j][0] = c0.Coeffs[j][index] - copy(LWE[index].Value[j][1:], acc.Coeffs[j]) - } - } - } - - return -} - -//MulBySmallMonomial multiplies pol by x^n -func MulBySmallMonomial(ringQ *ring.Ring, pol *ring.Poly, n int) { - for i, qi := range ringQ.Modulus[:pol.Level()+1] { - pol.Coeffs[i] = append(pol.Coeffs[i][ringQ.N-n:], pol.Coeffs[i][:ringQ.N-n]...) - tmp := pol.Coeffs[i] - for j := 0; j < n; j++ { - tmp[j] = qi - tmp[j] - } - } -} diff --git a/lwe/utils.go b/lwe/utils.go deleted file mode 100644 index 379a2c45..00000000 --- a/lwe/utils.go +++ /dev/null @@ -1,109 +0,0 @@ -package lwe - -import ( - "github.com/tuneinsight/lattigo/v3/ring" - "math/big" -) - -func normalizeInv(x, a, b float64) (y float64) { - return (x*(b-a) + b + a) / 2.0 -} - -func scaleUp(value float64, scale float64, Q uint64) (res uint64) { - - var isNegative bool - var xFlo *big.Float - var xInt *big.Int - - isNegative = false - if value < 0 { - isNegative = true - xFlo = big.NewFloat(-scale * value) - } else { - xFlo = big.NewFloat(scale * value) - } - - xFlo.Add(xFlo, big.NewFloat(0.5)) - - xInt = new(big.Int) - xFlo.Int(xInt) - xInt.Mod(xInt, ring.NewUint(Q)) - - res = xInt.Uint64() - - if isNegative { - res = Q - res - } - - return -} - -// DecryptLWE decrypts an LWE sample -func DecryptLWE(ct *Ciphertext, ringQ *ring.Ring, skMont *ring.Poly) float64 { - - level := ct.Level() - - pol := ringQ.NewPolyLvl(ct.Level()) - for i := 0; i < level+1; i++ { - copy(pol.Coeffs[i], ct.Value[i][1:]) - } - - ringQ.MulCoeffsMontgomeryLvl(level, pol, skMont, pol) - - a := make([]uint64, level+1) - - for i := 0; i < level+1; i++ { - qi := ringQ.Modulus[i] - tmp := pol.Coeffs[i] - a[i] = ct.Value[i][0] - for j := 0; j < ringQ.N; j++ { - a[i] = ring.CRed(a[i]+tmp[j], qi) - } - } - - crtReconstruction := make([]*big.Int, level+1) - - QiB := new(big.Int) - tmp := new(big.Int) - modulusBigint := ring.NewUint(1) - - for i := 0; i < level+1; i++ { - - qi := ringQ.Modulus[i] - QiB.SetUint64(qi) - - modulusBigint.Mul(modulusBigint, QiB) - - crtReconstruction[i] = new(big.Int) - crtReconstruction[i].Quo(ringQ.ModulusBigint, QiB) - tmp.ModInverse(crtReconstruction[i], QiB) - tmp.Mod(tmp, QiB) - crtReconstruction[i].Mul(crtReconstruction[i], tmp) - } - - tmp.SetUint64(0) - coeffsBigint := ring.NewUint(0) - - modulusBigintHalf := new(big.Int) - modulusBigintHalf.Rsh(modulusBigint, 1) - - var sign int - for i := 0; i < level+1; i++ { - coeffsBigint.Add(coeffsBigint, tmp.Mul(ring.NewUint(a[i]), crtReconstruction[i])) - } - - coeffsBigint.Mod(coeffsBigint, modulusBigint) - - // Centers the coefficients - sign = coeffsBigint.Cmp(modulusBigintHalf) - - if sign == 1 || sign == 0 { - coeffsBigint.Sub(coeffsBigint, modulusBigint) - } - - flo := new(big.Float) - flo.SetInt(coeffsBigint) - flo64, _ := flo.Float64() - - return flo64 -} diff --git a/rlwe/encryptor.go b/rlwe/encryptor.go index b0f89930..9b71425f 100644 --- a/rlwe/encryptor.go +++ b/rlwe/encryptor.go @@ -2,13 +2,15 @@ package rlwe import ( "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe/gadget" + "github.com/tuneinsight/lattigo/v3/rlwe/rgsw" "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" "github.com/tuneinsight/lattigo/v3/utils" ) // Encryptor a generic RLWE encryption interface. type Encryptor interface { - Encrypt(pt *Plaintext, ct *Ciphertext) + Encrypt(pt *Plaintext, ct interface{}) EncryptFromCRP(pt *Plaintext, crp *ring.Poly, ct *Ciphertext) ShallowCopy() Encryptor WithKey(key interface{}) Encryptor @@ -89,8 +91,14 @@ func newEncryptorSamplers(params Parameters) *encryptorSamplers { } type encryptorBuffers struct { +<<<<<<< dev_bfv_poly buffQ [2]*ring.Poly buffP [3]*ring.Poly +======= + poolQ [2]*ring.Poly + poolP [3]*ring.Poly + poolQP ringqp.Poly +>>>>>>> [rlwe]: further refactoring } func newEncryptorBuffers(params Parameters) *encryptorBuffers { @@ -104,8 +112,14 @@ func newEncryptorBuffers(params Parameters) *encryptorBuffers { } return &encryptorBuffers{ +<<<<<<< dev_bfv_poly buffQ: [2]*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly()}, buffP: buffP, +======= + poolQ: [2]*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly()}, + poolP: poolP, + poolQP: params.RingQP().NewPoly(), +>>>>>>> [rlwe]: further refactoring } } @@ -115,77 +129,55 @@ func newEncryptorBuffers(params Parameters) *encryptorBuffers { // The encryption procedures depends on the parameters. If the auxiliary modulus P is defined, // then the encryption of zero is sampled in QP before being rescaled by P; otherwise, it is directly // sampled in Q. -func (enc *pkEncryptor) Encrypt(pt *Plaintext, ct *Ciphertext) { - enc.uniformSamplerQ.ReadLvl(utils.MinInt(pt.Level(), ct.Level()), ct.Value[1]) +// The method accepts only *rlwe.Ciphertext as input. +func (enc *pkEncryptor) Encrypt(pt *Plaintext, ct interface{}) { - if enc.basisextender != nil { - enc.encrypt(pt, ct) - } else { - enc.encryptNoP(pt, ct) + switch el := ct.(type) { + case *Ciphertext: + enc.uniformSamplerQ.ReadLvl(utils.MinInt(pt.Level(), el.Level()), el.Value[1]) + if enc.basisextender != nil { + enc.encryptRLWE(pt, el) + } else { + enc.encryptNoPRLWE(pt, el) + } + default: + panic("input ciphertext type unsuported (must be *rlwe.Ciphertext or *rgsw.Ciphertext)") } + } -// EncryptFromCRP is not defined when using a public-key. This method will panic. +// EncryptFromCRP is not defined when using a public-key. This method will always panic. func (enc *pkEncryptor) EncryptFromCRP(pt *Plaintext, crp *ring.Poly, ct *Ciphertext) { panic("Cannot encrypt with CRP using a public-key") } -// Encrypt encrypts the input plaintext and write the result on ct. -func (enc *skEncryptor) Encrypt(pt *Plaintext, ct *Ciphertext) { - - enc.uniformSamplerQ.ReadLvl(utils.MinInt(pt.Level(), ct.Level()), ct.Value[1]) - - enc.encrypt(pt, ct) +// Encrypt encrypts the input plaintext using the stored public-key and writes the result on ct. +// The encryption procedure first samples an new encryption of zero under the public-key and +// then adds the plaintext. +// The encryption procedures depends on the parameters. If the auxiliary modulus P is defined, +// then the encryption of zero is sampled in QP before being rescaled by P; otherwise, it is directly +// sampled in Q. +// The method accepts only *rlwe.Ciphertext or *rgsw.Ciphertext as input and will panic otherwise. +func (enc *skEncryptor) Encrypt(pt *Plaintext, ct interface{}) { + switch el := ct.(type) { + case *Ciphertext: + enc.uniformSamplerQ.ReadLvl(utils.MinInt(pt.Level(), el.Level()), el.Value[1]) + enc.encryptRLWE(pt, el) + case *rgsw.Ciphertext: + enc.encryptRGSW(pt, el) + default: + panic("input ciphertext type unsuported (must be *rlwe.Ciphertext or *rgsw.Ciphertext)") + } } // EncryptFromCRP encrypts the input plaintext and writes the result on ct. // The encryption algorithm depends on the implementor. func (enc *skEncryptor) EncryptFromCRP(pt *Plaintext, crp *ring.Poly, ct *Ciphertext) { ring.CopyValues(crp, ct.Value[1]) - - enc.encrypt(pt, ct) + enc.encryptRLWE(pt, ct) } -// ShallowCopy creates a shallow copy of this pkEncryptor in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// Encryptors can be used concurrently. -func (enc *pkEncryptor) ShallowCopy() Encryptor { - return &pkEncryptor{*enc.encryptor.ShallowCopy(), enc.pk} -} - -// ShallowCopy creates a shallow copy of this skEncryptor in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// Encryptors can be used concurrently. -func (enc *skEncryptor) ShallowCopy() Encryptor { - return &skEncryptor{*enc.encryptor.ShallowCopy(), enc.sk} -} - -// ShallowCopy creates a shallow copy of this encryptor in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// Encryptors can be used concurrently. -func (enc *encryptor) ShallowCopy() *encryptor { - - var bc *ring.BasisExtender - if enc.params.PCount() != 0 { - bc = enc.basisextender.ShallowCopy() - } - - return &encryptor{ - encryptorBase: enc.encryptorBase, - encryptorSamplers: newEncryptorSamplers(enc.params), - encryptorBuffers: newEncryptorBuffers(enc.params), - basisextender: bc, - } -} - -// WithKey creates a shallow copy of this encryptor with a new key in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// Encryptors can be used concurrently. -func (enc *encryptor) WithKey(key interface{}) Encryptor { - return enc.ShallowCopy().setKey(key) -} - -func (enc *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Ciphertext) { +func (enc *pkEncryptor) encryptRLWE(plaintext *Plaintext, ciphertext *Ciphertext) { ringQ := enc.params.RingQ() ringQP := enc.params.RingQP() @@ -253,7 +245,7 @@ func (enc *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Ciphertext) { if ciphertextNTT { - if !plaintext.Value.IsNTT { + if plaintext != nil && !plaintext.Value.IsNTT { ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) } @@ -261,12 +253,12 @@ func (enc *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Ciphertext) { ringQ.NTTLvl(levelQ, ciphertext.Value[0], ciphertext.Value[0]) ringQ.NTTLvl(levelQ, ciphertext.Value[1], ciphertext.Value[1]) - if plaintext.Value.IsNTT { + if plaintext != nil && plaintext.Value.IsNTT { // ct0 = (u*pk0 + e0)/P + m ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) } - } else { + } else if plaintext != nil { if !plaintext.Value.IsNTT { ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) @@ -281,7 +273,7 @@ func (enc *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Ciphertext) { ciphertext.Value[1].Coeffs = ciphertext.Value[1].Coeffs[:levelQ+1] } -func (enc *pkEncryptor) encryptNoP(plaintext *Plaintext, ciphertext *Ciphertext) { +func (enc *pkEncryptor) encryptNoPRLWE(plaintext *Plaintext, ciphertext *Ciphertext) { levelQ := utils.MinInt(plaintext.Level(), ciphertext.Level()) buffQ0 := enc.buffQ[0] @@ -309,6 +301,7 @@ func (enc *pkEncryptor) encryptNoP(plaintext *Plaintext, ciphertext *Ciphertext) // ct0 = u*pk0 + e0 enc.gaussianSampler.ReadLvl(levelQ, buffQ0) +<<<<<<< dev_bfv_poly if !plaintext.Value.IsNTT { ringQ.AddLvl(levelQ, buffQ0, plaintext.Value, buffQ0) ringQ.NTTLvl(levelQ, buffQ0, buffQ0) @@ -317,6 +310,18 @@ func (enc *pkEncryptor) encryptNoP(plaintext *Plaintext, ciphertext *Ciphertext) ringQ.NTTLvl(levelQ, buffQ0, buffQ0) ringQ.AddLvl(levelQ, ciphertext.Value[0], buffQ0, ciphertext.Value[0]) ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) +======= + if plaintext != nil { + if !plaintext.Value.IsNTT { + ringQ.AddLvl(levelQ, poolQ0, plaintext.Value, poolQ0) + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) + } else { + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) + ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + } +>>>>>>> [rlwe]: further refactoring } } else { @@ -330,21 +335,30 @@ func (enc *pkEncryptor) encryptNoP(plaintext *Plaintext, ciphertext *Ciphertext) // ct[1] = pk[1]*u + e1 enc.gaussianSampler.ReadAndAddLvl(ciphertext.Level(), ciphertext.Value[1]) +<<<<<<< dev_bfv_poly if !plaintext.Value.IsNTT { ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) } else { ringQ.InvNTTLvl(levelQ, plaintext.Value, buffQ0) ringQ.AddLvl(levelQ, ciphertext.Value[0], buffQ0, ciphertext.Value[0]) +======= + if plaintext != nil { + if !plaintext.Value.IsNTT { + ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + } else { + ringQ.InvNTTLvl(levelQ, plaintext.Value, poolQ0) + ringQ.AddLvl(levelQ, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) + } +>>>>>>> [rlwe]: further refactoring } } ciphertext.Value[1].IsNTT = ciphertext.Value[0].IsNTT - ciphertext.Value[0].Coeffs = ciphertext.Value[0].Coeffs[:levelQ+1] ciphertext.Value[1].Coeffs = ciphertext.Value[1].Coeffs[:levelQ+1] } -func (enc *skEncryptor) encrypt(plaintext *Plaintext, ciphertext *Ciphertext) { +func (enc *skEncryptor) encryptRLWE(plaintext *Plaintext, ciphertext *Ciphertext) { ringQ := enc.params.RingQ() @@ -361,6 +375,7 @@ func (enc *skEncryptor) encrypt(plaintext *Plaintext, ciphertext *Ciphertext) { enc.gaussianSampler.ReadLvl(levelQ, buffQ0) +<<<<<<< dev_bfv_poly if plaintext.Value.IsNTT { ringQ.NTTLvl(levelQ, buffQ0, buffQ0) ringQ.AddLvl(levelQ, ciphertext.Value[0], buffQ0, ciphertext.Value[0]) @@ -369,35 +384,158 @@ func (enc *skEncryptor) encrypt(plaintext *Plaintext, ciphertext *Ciphertext) { ringQ.AddLvl(levelQ, buffQ0, plaintext.Value, buffQ0) ringQ.NTTLvl(levelQ, buffQ0, buffQ0) ringQ.AddLvl(levelQ, ciphertext.Value[0], buffQ0, ciphertext.Value[0]) +======= + if plaintext != nil { + if plaintext.Value.IsNTT { + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) + ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + } else { + ringQ.AddLvl(levelQ, poolQ0, plaintext.Value, poolQ0) + ringQ.NTTLvl(levelQ, poolQ0, poolQ0) + ringQ.AddLvl(levelQ, ciphertext.Value[0], poolQ0, ciphertext.Value[0]) + } } - - ciphertext.Value[0].IsNTT = true - ciphertext.Value[1].IsNTT = true - } else { + if plaintext != nil { + if plaintext.Value.IsNTT { + ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + ringQ.InvNTTLvl(levelQ, ciphertext.Value[0], ciphertext.Value[0]) - if plaintext.Value.IsNTT { - ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) - ringQ.InvNTTLvl(levelQ, ciphertext.Value[0], ciphertext.Value[0]) - - } else { - ringQ.InvNTTLvl(levelQ, ciphertext.Value[0], ciphertext.Value[0]) - ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + } else { + ringQ.InvNTTLvl(levelQ, ciphertext.Value[0], ciphertext.Value[0]) + ringQ.AddLvl(levelQ, ciphertext.Value[0], plaintext.Value, ciphertext.Value[0]) + } +>>>>>>> [rlwe]: further refactoring } enc.gaussianSampler.ReadAndAddLvl(ciphertext.Level(), ciphertext.Value[0]) ringQ.InvNTTLvl(levelQ, ciphertext.Value[1], ciphertext.Value[1]) - - ciphertext.Value[0].IsNTT = false - ciphertext.Value[1].IsNTT = false - } + ciphertext.Value[1].IsNTT = ciphertext.Value[0].IsNTT ciphertext.Value[0].Coeffs = ciphertext.Value[0].Coeffs[:levelQ+1] ciphertext.Value[1].Coeffs = ciphertext.Value[1].Coeffs[:levelQ+1] } +func (enc *skEncryptor) encryptRGSW(pt *Plaintext, ct *rgsw.Ciphertext) { + + params := enc.params + ringQ := params.RingQ() + levelQ := ct.LevelQ() + levelP := ct.LevelP() + + decompRNS := params.DecompRNS(levelQ, levelP) + decompBIT := params.DecompBIT(levelQ, levelP) + + for j := 0; j < decompBIT; j++ { + for i := 0; i < decompRNS; i++ { + enc.encryptZeroSymetricQP(levelQ, levelP, enc.sk.Value, true, true, true, ct.Value[0].Value[i][j]) + enc.encryptZeroSymetricQP(levelQ, levelP, enc.sk.Value, true, true, true, ct.Value[1].Value[i][j]) + } + } + + if pt != nil { + ringQ.MFormLvl(levelQ, pt.Value, enc.poolQP.Q) + if !pt.Value.IsNTT { + ringQ.NTTLvl(levelQ, enc.poolQP.Q, enc.poolQP.Q) + } + gadget.AddPolyToGadgetMatrix( + enc.poolQP.Q, + []gadget.Ciphertext{ct.Value[0], ct.Value[1]}, + *params.RingQP(), + params.LogBase2(), + enc.poolQP.Q) + } +} + +func (enc *encryptor) encryptZeroSymetricQP(levelQ, levelP int, sk ringqp.Poly, sample, montgomery, ntt bool, ct [2]ringqp.Poly) { + + params := enc.params + ringQP := params.RingQP() + + hasModulusP := ct[0].P != nil + + if ntt { + enc.gaussianSampler.ReadLvl(levelQ, ct[0].Q) + + if hasModulusP { + ringQP.ExtendBasisSmallNormAndCenter(ct[0].Q, levelP, nil, ct[0].P) + } + + ringQP.NTTLvl(levelQ, levelP, ct[0], ct[0]) + } + + if sample { + enc.uniformSamplerQ.ReadLvl(levelQ, ct[1].Q) + + if hasModulusP { + enc.uniformSamplerP.ReadLvl(levelP, ct[1].P) + } + } + + ringQP.MulCoeffsMontgomeryAndSubLvl(levelQ, levelP, ct[1], sk, ct[0]) + + if !ntt { + ringQP.InvNTTLvl(levelQ, levelP, ct[0], ct[0]) + ringQP.InvNTTLvl(levelQ, levelP, ct[1], ct[1]) + + e := enc.poolQP + enc.gaussianSampler.ReadLvl(levelQ, e.Q) + + if hasModulusP { + ringQP.ExtendBasisSmallNormAndCenter(e.Q, levelP, nil, e.P) + } + + ringQP.AddLvl(levelQ, levelP, ct[0], e, ct[0]) + } + + if montgomery { + ringQP.MFormLvl(levelQ, levelP, ct[0], ct[0]) + ringQP.MFormLvl(levelQ, levelP, ct[1], ct[1]) + } +} + +// ShallowCopy creates a shallow copy of this pkEncryptor in which all the read-only data-structures are +// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned +// Encryptors can be used concurrently. +func (enc *pkEncryptor) ShallowCopy() Encryptor { + return &pkEncryptor{*enc.encryptor.ShallowCopy(), enc.pk} +} + +// ShallowCopy creates a shallow copy of this skEncryptor in which all the read-only data-structures are +// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned +// Encryptors can be used concurrently. +func (enc *skEncryptor) ShallowCopy() Encryptor { + return &skEncryptor{*enc.encryptor.ShallowCopy(), enc.sk} +} + +// ShallowCopy creates a shallow copy of this encryptor in which all the read-only data-structures are +// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned +// Encryptors can be used concurrently. +func (enc *encryptor) ShallowCopy() *encryptor { + + var bc *ring.BasisExtender + if enc.params.PCount() != 0 { + bc = enc.basisextender.ShallowCopy() + } + + return &encryptor{ + encryptorBase: enc.encryptorBase, + encryptorSamplers: newEncryptorSamplers(enc.params), + encryptorBuffers: newEncryptorBuffers(enc.params), + basisextender: bc, + } +} + +// WithKey creates a shallow copy of this encryptor with a new key in which all the read-only data-structures are +// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned +// Encryptors can be used concurrently. +func (enc *encryptor) WithKey(key interface{}) Encryptor { + return enc.ShallowCopy().setKey(key) +} + func (enc *encryptor) setKey(key interface{}) Encryptor { switch key := key.(type) { case *PublicKey: diff --git a/rlwe/eval_automorphism.go b/rlwe/eval_automorphism.go new file mode 100644 index 00000000..23e361c2 --- /dev/null +++ b/rlwe/eval_automorphism.go @@ -0,0 +1,126 @@ +package rlwe + +import ( + "fmt" + "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" + "github.com/tuneinsight/lattigo/v3/utils" +) + +// Automorphism computes phi(ct), where phi is the map X -> X^galEl. The method requires +// that the corresponding RotationKey has been added to the Evaluator. The method will +// panic if either ctIn or ctOut degree is not equal to 1. +func (eval *Evaluator) Automorphism(ctIn *Ciphertext, galEl uint64, ctOut *Ciphertext) { + + if ctIn.Degree() != 1 || ctOut.Degree() != 1 { + panic("cannot apply Automorphism: input and output Ciphertext must be of degree 1") + } + + if galEl == 1 { + if ctOut != ctIn { + ctOut.Copy(ctIn) + } + return + } + + rtk, generated := eval.Rtks.GetRotationKey(galEl) + if !generated { + panic(fmt.Sprintf("galEl key 5^%d missing", eval.params.InverseGaloisElement(galEl))) + } + + level := utils.MinInt(ctIn.Level(), ctOut.Level()) + + ringQ := eval.params.RingQ() + + eval.SwitchKeysInPlace(level, ctIn.Value[1], rtk, eval.Pool[1].Q, eval.Pool[2].Q) + ringQ.AddLvl(level, eval.Pool[1].Q, ctIn.Value[0], eval.Pool[1].Q) + + if ctIn.Value[0].IsNTT { + ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[1].Q, eval.PermuteNTTIndex[galEl], ctOut.Value[0]) + ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[2].Q, eval.PermuteNTTIndex[galEl], ctOut.Value[1]) + } else { + ringQ.PermuteLvl(level, eval.Pool[1].Q, galEl, ctOut.Value[0]) + ringQ.PermuteLvl(level, eval.Pool[2].Q, galEl, ctOut.Value[1]) + } + + ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:level+1] + ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:level+1] +} + +// AutomorphismHoisted is similar to Automorphism, except that it takes as input ctIn and c1DecompQP, where c1DecompQP is the RNS +// decomposition of its element of degree 1. This decomposition can be obtained with DecomposeNTT. +// The method requires that the corresponding RotationKey has been added to the Evaluator. +// The method will panic if either ctIn or ctOut degree is not equal to 1. +func (eval *Evaluator) AutomorphismHoisted(level int, ctIn *Ciphertext, c1DecompQP []ringqp.Poly, galEl uint64, ctOut *Ciphertext) { + + if ctIn.Degree() != 1 || ctOut.Degree() != 1 { + panic("cannot apply AutomorphismHoisted: input and output Ciphertext must be of degree 1") + } + + if galEl == 1 { + if ctIn != ctOut { + ctOut.Copy(ctIn) + } + return + } + + rtk, generated := eval.Rtks.GetRotationKey(galEl) + if !generated { + panic(fmt.Sprintf("galEl key 5^%d missing", eval.params.InverseGaloisElement(galEl))) + } + + ringQ := eval.params.RingQ() + + eval.KeyswitchHoisted(level, c1DecompQP, rtk, eval.Pool[0].Q, eval.Pool[1].Q, eval.Pool[0].P, eval.Pool[1].P) + ringQ.AddLvl(level, eval.Pool[0].Q, ctIn.Value[0], eval.Pool[0].Q) + + if ctIn.Value[0].IsNTT { + ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[0].Q, eval.PermuteNTTIndex[galEl], ctOut.Value[0]) + ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[1].Q, eval.PermuteNTTIndex[galEl], ctOut.Value[1]) + } else { + ringQ.PermuteLvl(level, eval.Pool[0].Q, galEl, ctOut.Value[0]) + ringQ.PermuteLvl(level, eval.Pool[1].Q, galEl, ctOut.Value[1]) + } + + ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:level+1] + ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:level+1] +} + +// AutomorphismHoistedNoModDown is similar to AutomorphismHoisted, except that it returns a ciphertext modulo QP and scaled by P. +// The method requires that the corresponding RotationKey has been added to the Evaluator.The method will panic if either ctIn or ctOut degree is not equal to 1. +func (eval *Evaluator) AutomorphismHoistedNoModDown(levelQ int, c0 *ring.Poly, c1DecompQP []ringqp.Poly, galEl uint64, ct0OutQ, ct1OutQ, ct0OutP, ct1OutP *ring.Poly) { + + levelP := eval.params.PCount() - 1 + + rtk, generated := eval.Rtks.GetRotationKey(galEl) + if !generated { + panic(fmt.Sprintf("galEl key 5^%d missing", eval.params.InverseGaloisElement(galEl))) + } + + eval.KeyswitchHoistedNoModDown(levelQ, c1DecompQP, rtk, eval.Pool[0].Q, eval.Pool[1].Q, eval.Pool[0].P, eval.Pool[1].P) + + ringQ := eval.params.RingQ() + + if c0.IsNTT { + + index := eval.PermuteNTTIndex[galEl] + + ringQ.PermuteNTTWithIndexLvl(levelQ, eval.Pool[1].Q, index, ct1OutQ) + ringQ.PermuteNTTWithIndexLvl(levelP, eval.Pool[1].P, index, ct1OutP) + + ringQ.MulScalarBigintLvl(levelQ, c0, eval.params.RingP().ModulusBigint, eval.Pool[1].Q) + ringQ.AddLvl(levelQ, eval.Pool[0].Q, eval.Pool[1].Q, eval.Pool[0].Q) + + ringQ.PermuteNTTWithIndexLvl(levelQ, eval.Pool[0].Q, index, ct0OutQ) + ringQ.PermuteNTTWithIndexLvl(levelP, eval.Pool[0].P, index, ct0OutP) + } else { + ringQ.PermuteLvl(levelQ, eval.Pool[1].Q, galEl, ct1OutQ) + ringQ.PermuteLvl(levelP, eval.Pool[1].P, galEl, ct1OutP) + + ringQ.MulScalarBigintLvl(levelQ, c0, eval.params.RingP().ModulusBigint, eval.Pool[1].Q) + ringQ.AddLvl(levelQ, eval.Pool[0].Q, eval.Pool[1].Q, eval.Pool[0].Q) + + ringQ.PermuteLvl(levelQ, eval.Pool[0].Q, galEl, ct0OutQ) + ringQ.PermuteLvl(levelP, eval.Pool[0].P, galEl, ct0OutP) + } +} diff --git a/rlwe/rgsw/rgsw.go b/rlwe/eval_external_product.go similarity index 90% rename from rlwe/rgsw/rgsw.go rename to rlwe/eval_external_product.go index ffb37ff5..bbd463de 100644 --- a/rlwe/rgsw/rgsw.go +++ b/rlwe/eval_external_product.go @@ -1,31 +1,19 @@ -package rgsw +package rlwe import ( "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/rlwe" + "github.com/tuneinsight/lattigo/v3/rlwe/rgsw" "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" "math" ) -// Evaluator is a struct storing the necessary elements to perform -// homomorphic operations with RGSW ciphertexts. -type Evaluator struct { - params rlwe.Parameters - *rlwe.Evaluator -} - -// NewEvaluator creates a new evaluator. -func NewEvaluator(params rlwe.Parameters) *Evaluator { - return &Evaluator{params, rlwe.NewEvaluator(params, nil)} -} - // ExternalProduct computes RLWE x RGSW -> RLWE // RLWE : (-as + m + e, a) // x // RGSW : [(-as + P*w*m1 + e, a), (-bs + e, b + P*w*m1)] // = // RLWE : (, ) -func (eval *Evaluator) ExternalProduct(op0 *rlwe.Ciphertext, op1 *Ciphertext, op2 *rlwe.Ciphertext) { +func (eval *Evaluator) ExternalProduct(op0 *Ciphertext, op1 *rgsw.Ciphertext, op2 *Ciphertext) { levelQ, levelP := op1.LevelQ(), op1.LevelP() @@ -64,7 +52,7 @@ func (eval *Evaluator) ExternalProduct(op0 *rlwe.Ciphertext, op1 *Ciphertext, op } } -func (eval *Evaluator) externalProduct32Bit(ct0 *rlwe.Ciphertext, rgsw *Ciphertext, c0, c1 *ring.Poly) { +func (eval *Evaluator) externalProduct32Bit(ct0 *Ciphertext, rgsw *rgsw.Ciphertext, c0, c1 *ring.Poly) { // rgsw = [(-as + P*w*m1 + e, a), (-bs + e, b + P*w*m1)] // ct = [-cs + m0 + e, c] @@ -98,7 +86,7 @@ func (eval *Evaluator) externalProduct32Bit(ct0 *rlwe.Ciphertext, rgsw *Cipherte } } -func (eval *Evaluator) externalProductInPlaceSinglePAndBitDecomp(ct0 *rlwe.Ciphertext, rgsw *Ciphertext, c0QP, c1QP ringqp.Poly) { +func (eval *Evaluator) externalProductInPlaceSinglePAndBitDecomp(ct0 *Ciphertext, rgsw *rgsw.Ciphertext, c0QP, c1QP ringqp.Poly) { // rgsw = [(-as + P*w*m1 + e, a), (-bs + e, b + P*w*m1)] // ct = [-cs + m0 + e, c] @@ -159,7 +147,7 @@ func (eval *Evaluator) externalProductInPlaceSinglePAndBitDecomp(ct0 *rlwe.Ciphe } } -func (eval *Evaluator) externalProductInPlaceMultipleP(levelQ, levelP int, ct0 *rlwe.Ciphertext, rgsw *Ciphertext, c0OutQ, c0OutP, c1OutQ, c1OutP *ring.Poly) { +func (eval *Evaluator) externalProductInPlaceMultipleP(levelQ, levelP int, ct0 *Ciphertext, rgsw *rgsw.Ciphertext, c0OutQ, c0OutP, c1OutQ, c1OutP *ring.Poly) { var reduce int ringQ := eval.params.RingQ() diff --git a/rlwe/eval_keyswitch.go b/rlwe/eval_keyswitch.go new file mode 100644 index 00000000..ed914a28 --- /dev/null +++ b/rlwe/eval_keyswitch.go @@ -0,0 +1,384 @@ +package rlwe + +import ( + "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" + "github.com/tuneinsight/lattigo/v3/utils" +) + +// Relinearize applies the relinearization procedure on ct0 and returns the result in ctOut. +// The method will panic if the corresponding relinearization key to the ciphertext degree +// is missing. +func (eval *Evaluator) Relinearize(ctIn *Ciphertext, ctOut *Ciphertext) { + if eval.Rlk == nil || ctIn.Degree()-1 > len(eval.Rlk.Keys) { + panic("cannot Relinearize: relinearization key missing (or ciphertext degree is too large)") + } + + level := utils.MinInt(ctIn.Level(), ctOut.Level()) + + ringQ := eval.params.RingQ() + + eval.SwitchKeysInPlace(level, ctIn.Value[2], eval.Rlk.Keys[0], eval.Pool[1].Q, eval.Pool[2].Q) + ringQ.AddLvl(level, ctIn.Value[0], eval.Pool[1].Q, ctOut.Value[0]) + ringQ.AddLvl(level, ctIn.Value[1], eval.Pool[2].Q, ctOut.Value[1]) + + for deg := ctIn.Degree() - 1; deg > 1; deg-- { + eval.SwitchKeysInPlace(level, ctIn.Value[deg], eval.Rlk.Keys[deg-2], eval.Pool[1].Q, eval.Pool[2].Q) + ringQ.AddLvl(level, ctOut.Value[0], eval.Pool[1].Q, ctOut.Value[0]) + ringQ.AddLvl(level, ctOut.Value[1], eval.Pool[2].Q, ctOut.Value[1]) + } + + ctOut.Value = ctOut.Value[:2] + + ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:level+1] + ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:level+1] +} + +// SwitchKeys re-encrypts ctIn under a different key and returns the result in ctOut. +// It requires a SwitchingKey, which is computed from the key under which the Ciphertext is currently encrypted, +// and the key under which the Ciphertext will be re-encrypted. +// The method will panic if either ctIn or ctOut degree isn't 1. +func (eval *Evaluator) SwitchKeys(ctIn *Ciphertext, switchingKey *SwitchingKey, ctOut *Ciphertext) { + + if ctIn.Degree() != 1 || ctOut.Degree() != 1 { + panic("cannot SwitchKeys: input and output Ciphertext must be of degree 1") + } + + level := utils.MinInt(ctIn.Level(), ctOut.Level()) + ringQ := eval.params.RingQ() + + eval.SwitchKeysInPlace(level, ctIn.Value[1], switchingKey, eval.Pool[1].Q, eval.Pool[2].Q) + + ringQ.AddLvl(level, ctIn.Value[0], eval.Pool[1].Q, ctOut.Value[0]) + ring.CopyValuesLvl(level, eval.Pool[2].Q, ctOut.Value[1]) +} + +// SwitchKeysInPlace applies the general key-switching procedure of the form [c0 + cx*evakey[0], c1 + cx*evakey[1]] +// Will return the result in the same NTT domain as the input cx. +func (eval *Evaluator) SwitchKeysInPlace(levelQ int, cx *ring.Poly, evakey *SwitchingKey, p0, p1 *ring.Poly) { + + levelP := evakey.LevelP() + + if levelP > 0 { + eval.SwitchKeysInPlaceNoModDown(levelQ, cx, evakey, p0, eval.Pool[1].P, p1, eval.Pool[2].P) + } else { + eval.SwitchKeyInPlaceSinglePAndBitDecomp(levelQ, cx, evakey, p0, eval.Pool[1].P, p1, eval.Pool[2].P) + } + + if cx.IsNTT && levelP != -1 { + eval.BasisExtender.ModDownQPtoQNTT(levelQ, levelP, p0, eval.Pool[1].P, p0) + eval.BasisExtender.ModDownQPtoQNTT(levelQ, levelP, p1, eval.Pool[2].P, p1) + } else if !cx.IsNTT { + eval.params.RingQ().InvNTTLazyLvl(levelQ, p0, p0) + eval.params.RingQ().InvNTTLazyLvl(levelQ, p1, p1) + + if levelP != -1 { + eval.params.RingP().InvNTTLazyLvl(levelP, eval.Pool[1].P, eval.Pool[1].P) + eval.params.RingP().InvNTTLazyLvl(levelP, eval.Pool[2].P, eval.Pool[2].P) + eval.BasisExtender.ModDownQPtoQ(levelQ, levelP, p0, eval.Pool[1].P, p0) + eval.BasisExtender.ModDownQPtoQ(levelQ, levelP, p1, eval.Pool[2].P, p1) + } + } +} + +// DecomposeNTT applies the full RNS basis decomposition for all q_alpha_i on c2. +// Expects the IsNTT flag of c2 to correctly reflect the domain of c2. +// PoolDecompQ and PoolDecompQ are vectors of polynomials (mod Q and mod P) that store the +// special RNS decomposition of c2 (in the NTT domain) +func (eval *Evaluator) DecomposeNTT(levelQ, levelP, alpha int, c2 *ring.Poly, PoolDecomp []ringqp.Poly) { + + ringQ := eval.params.RingQ() + + var polyNTT, polyInvNTT *ring.Poly + + if c2.IsNTT { + polyNTT = c2 + polyInvNTT = eval.PoolInvNTT + ringQ.InvNTTLvl(levelQ, polyNTT, polyInvNTT) + } else { + polyNTT = eval.PoolInvNTT + polyInvNTT = c2 + ringQ.NTTLvl(levelQ, polyInvNTT, polyNTT) + } + + beta := (levelQ + 1 + levelP) / (levelP + 1) + + for i := 0; i < beta; i++ { + eval.DecomposeSingleNTT(levelQ, levelP, alpha, i, polyNTT, polyInvNTT, PoolDecomp[i].Q, PoolDecomp[i].P) + } +} + +// DecomposeSingleNTT takes the input polynomial c2 (c2NTT and c2InvNTT, respectively in the NTT and out of the NTT domain) +// modulo q_alpha_beta, and returns the result on c2QiQ are c2QiP the receiver polynomials +// respectively mod Q and mod P (in the NTT domain) +func (eval *Evaluator) DecomposeSingleNTT(levelQ, levelP, alpha, beta int, c2NTT, c2InvNTT, c2QiQ, c2QiP *ring.Poly) { + + ringQ := eval.params.RingQ() + ringP := eval.params.RingP() + + eval.Decomposer.DecomposeAndSplit(levelQ, levelP, alpha, beta, c2InvNTT, c2QiQ, c2QiP) + + p0idxst := beta * alpha + p0idxed := p0idxst + 1 + + // c2_qi = cx mod qi mod qi + for x := 0; x < levelQ+1; x++ { + if p0idxst <= x && x < p0idxed { + copy(c2QiQ.Coeffs[x], c2NTT.Coeffs[x]) + } else { + ringQ.NTTSingle(x, c2QiQ.Coeffs[x], c2QiQ.Coeffs[x]) + } + } + + if ringP != nil { + // c2QiP = c2 mod qi mod pj + ringP.NTTLvl(levelP, c2QiP, c2QiP) + } +} + +// SwitchKeysInPlaceNoModDown applies the key-switch to the polynomial cx : +// +// pool2 = dot(decomp(cx) * evakey[0]) mod QP (encrypted input is multiplied by P factor) +// pool3 = dot(decomp(cx) * evakey[1]) mod QP (encrypted input is multiplied by P factor) +// +// Expects the flag IsNTT of cx to correctly reflect the domain of cx. +func (eval *Evaluator) SwitchKeysInPlaceNoModDown(levelQ int, cx *ring.Poly, evakey *SwitchingKey, c0Q, c0P, c1Q, c1P *ring.Poly) { + + ringQ := eval.params.RingQ() + ringP := eval.params.RingP() + ringQP := eval.params.RingQP() + + c2QP := eval.Pool[0] + + var cxNTT, cxInvNTT *ring.Poly + if cx.IsNTT { + cxNTT = cx + cxInvNTT = eval.PoolInvNTT + ringQ.InvNTTLvl(levelQ, cxNTT, cxInvNTT) + } else { + cxNTT = eval.PoolInvNTT + cxInvNTT = cx + ringQ.NTTLvl(levelQ, cxInvNTT, cxNTT) + } + + c0QP := ringqp.Poly{Q: c0Q, P: c0P} + c1QP := ringqp.Poly{Q: c1Q, P: c1P} + + levelP := evakey.Value[0][0][0].P.Level() + decompRNS := (levelQ + 1 + levelP) / (levelP + 1) + + QiOverF := eval.params.QiOverflowMargin(levelQ) >> 1 + PiOverF := eval.params.PiOverflowMargin(levelP) >> 1 + + // Key switching with CRT decomposition for the Qi + var reduce int + for i := 0; i < decompRNS; i++ { + + eval.DecomposeSingleNTT(levelQ, levelP, levelP+1, i, cxNTT, cxInvNTT, c2QP.Q, c2QP.P) + + if i == 0 { + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0][0], c2QP, c0QP) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0][1], c2QP, c1QP) + } else { + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0][0], c2QP, c0QP) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0][1], c2QP, c1QP) + } + + if reduce%QiOverF == QiOverF-1 { + ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) + ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) + } + + if reduce%PiOverF == PiOverF-1 { + ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) + ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) + } + + reduce++ + } + + if reduce%QiOverF != 0 { + ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) + ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) + } + + if reduce%PiOverF != 0 { + ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) + ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) + } +} + +// SwitchKeyInPlaceSinglePAndBitDecomp applies the key-switch to the polynomial cx : +// +// pool2 = dot(decomp(cx) * evakey[0]) mod QP (encrypted input is multiplied by P factor) +// pool3 = dot(decomp(cx) * evakey[1]) mod QP (encrypted input is multiplied by P factor) +// +// Expects the flag IsNTT of cx to correctly reflect the domain of cx. +func (eval *Evaluator) SwitchKeyInPlaceSinglePAndBitDecomp(levelQ int, cx *ring.Poly, evakey *SwitchingKey, c0Q, c0P, c1Q, c1P *ring.Poly) { + + ringQ := eval.params.RingQ() + ringP := eval.params.RingP() + + var cxInvNTT *ring.Poly + if cx.IsNTT { + cxInvNTT = eval.PoolInvNTT + ringQ.InvNTTLvl(levelQ, cx, cxInvNTT) + } else { + cxInvNTT = cx + } + + c0QP := ringqp.Poly{Q: c0Q, P: c0P} + c1QP := ringqp.Poly{Q: c1Q, P: c1P} + + var levelP int + if evakey.Value[0][0][0].P != nil { + levelP = evakey.Value[0][0][0].P.Level() + } else { + levelP = -1 + } + + decompRNS := eval.params.DecompRNS(levelQ, levelP) + decompBIT := eval.params.DecompBIT(levelQ, levelP) + + lb2 := eval.params.logbase2 + + mask := uint64(((1 << lb2) - 1)) + + if mask == 0 { + mask = 0xFFFFFFFFFFFFFFFF + } + + cw := eval.Pool[0].Q.Coeffs[0] + cwNTT := eval.PoolBitDecomp + + QiOverF := eval.params.QiOverflowMargin(levelQ) >> 1 + PiOverF := eval.params.PiOverflowMargin(levelP) >> 1 + + // Key switching with CRT decomposition for the Qi + var reduce int + for i := 0; i < decompRNS; i++ { + for j := 0; j < decompBIT; j++ { + + ring.MaskVec(cxInvNTT.Coeffs[i], cw, j*lb2, mask) + + if i == 0 && j == 0 { + for u := 0; u < levelQ+1; u++ { + ringQ.NTTSingleLazy(u, cw, cwNTT) + ring.MulCoeffsMontgomeryConstantVec(evakey.Value[i][j][0].Q.Coeffs[u], cwNTT, c0QP.Q.Coeffs[u], ringQ.Modulus[u], ringQ.MredParams[u]) + ring.MulCoeffsMontgomeryConstantVec(evakey.Value[i][j][1].Q.Coeffs[u], cwNTT, c1QP.Q.Coeffs[u], ringQ.Modulus[u], ringQ.MredParams[u]) + } + + for u := 0; u < levelP+1; u++ { + ringP.NTTSingleLazy(u, cw, cwNTT) + ring.MulCoeffsMontgomeryConstantVec(evakey.Value[i][j][0].P.Coeffs[u], cwNTT, c0QP.P.Coeffs[u], ringP.Modulus[u], ringP.MredParams[u]) + ring.MulCoeffsMontgomeryConstantVec(evakey.Value[i][j][1].P.Coeffs[u], cwNTT, c1QP.P.Coeffs[u], ringP.Modulus[u], ringP.MredParams[u]) + } + } else { + for u := 0; u < levelQ+1; u++ { + ringQ.NTTSingleLazy(u, cw, cwNTT) + ring.MulCoeffsMontgomeryConstantAndAddNoModVec(evakey.Value[i][j][0].Q.Coeffs[u], cwNTT, c0QP.Q.Coeffs[u], ringQ.Modulus[u], ringQ.MredParams[u]) + ring.MulCoeffsMontgomeryConstantAndAddNoModVec(evakey.Value[i][j][1].Q.Coeffs[u], cwNTT, c1QP.Q.Coeffs[u], ringQ.Modulus[u], ringQ.MredParams[u]) + } + + for u := 0; u < levelP+1; u++ { + ringP.NTTSingleLazy(u, cw, cwNTT) + ring.MulCoeffsMontgomeryConstantAndAddNoModVec(evakey.Value[i][j][0].P.Coeffs[u], cwNTT, c0QP.P.Coeffs[u], ringP.Modulus[u], ringP.MredParams[u]) + ring.MulCoeffsMontgomeryConstantAndAddNoModVec(evakey.Value[i][j][1].P.Coeffs[u], cwNTT, c1QP.P.Coeffs[u], ringP.Modulus[u], ringP.MredParams[u]) + } + } + + if reduce%QiOverF == QiOverF-1 { + ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) + ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) + } + + if reduce%PiOverF == PiOverF-1 { + ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) + ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) + } + + reduce++ + } + } + + if reduce%QiOverF != 0 { + ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) + ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) + } + + if reduce%PiOverF != 0 { + ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) + ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) + } +} + +// KeyswitchHoisted applies the key-switch to the decomposed polynomial c2 mod QP (PoolDecompQ and PoolDecompP) +// and divides the result by P, reducing the basis from QP to Q. +// +// pool2 = dot(PoolDecompQ||PoolDecompP * evakey[0]) mod Q +// pool3 = dot(PoolDecompQ||PoolDecompP * evakey[1]) mod Q +func (eval *Evaluator) KeyswitchHoisted(levelQ int, PoolDecompQP []ringqp.Poly, evakey *SwitchingKey, c0Q, c1Q, c0P, c1P *ring.Poly) { + + eval.KeyswitchHoistedNoModDown(levelQ, PoolDecompQP, evakey, c0Q, c1Q, c0P, c1P) + + levelP := evakey.Value[0][0][0].P.Level() + + // Computes c0Q = c0Q/c0P and c1Q = c1Q/c1P + eval.BasisExtender.ModDownQPtoQNTT(levelQ, levelP, c0Q, c0P, c0Q) + eval.BasisExtender.ModDownQPtoQNTT(levelQ, levelP, c1Q, c1P, c1Q) +} + +// KeyswitchHoistedNoModDown applies the key-switch to the decomposed polynomial c2 mod QP (PoolDecompQ and PoolDecompP) +// +// pool2 = dot(PoolDecompQ||PoolDecompP * evakey[0]) mod QP +// pool3 = dot(PoolDecompQ||PoolDecompP * evakey[1]) mod QP +func (eval *Evaluator) KeyswitchHoistedNoModDown(levelQ int, PoolDecompQP []ringqp.Poly, evakey *SwitchingKey, c0Q, c1Q, c0P, c1P *ring.Poly) { + + ringQ := eval.params.RingQ() + ringP := eval.params.RingP() + ringQP := eval.params.RingQP() + + c0QP := ringqp.Poly{Q: c0Q, P: c0P} + c1QP := ringqp.Poly{Q: c1Q, P: c1P} + + levelP := evakey.Value[0][0][0].P.Level() + decompRNS := (levelQ + 1 + levelP) / (levelP + 1) + + QiOverF := eval.params.QiOverflowMargin(levelQ) >> 1 + PiOverF := eval.params.PiOverflowMargin(levelP) >> 1 + + // Key switching with CRT decomposition for the Qi + var reduce int + for i := 0; i < decompRNS; i++ { + + if i == 0 { + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0][0], PoolDecompQP[i], c0QP) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0][1], PoolDecompQP[i], c1QP) + } else { + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0][0], PoolDecompQP[i], c0QP) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0][1], PoolDecompQP[i], c1QP) + } + + if reduce%QiOverF == QiOverF-1 { + ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) + ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) + } + + if reduce%PiOverF == PiOverF-1 { + ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) + ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) + } + + reduce++ + } + + if reduce%QiOverF != 0 { + ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) + ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) + } + + if reduce%PiOverF != 0 { + ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) + ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) + } +} diff --git a/rlwe/evaluator.go b/rlwe/evaluator.go index 56947fec..2b14ff13 100644 --- a/rlwe/evaluator.go +++ b/rlwe/evaluator.go @@ -1,10 +1,8 @@ package rlwe import ( - "fmt" "github.com/tuneinsight/lattigo/v3/ring" "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" - "github.com/tuneinsight/lattigo/v3/utils" "math/bits" ) @@ -134,171 +132,6 @@ func (eval *Evaluator) WithKey(evaluationKey *EvaluationKey) *Evaluator { } } -// Automorphism computes phi(ct), where phi is the map X -> X^galEl. The method requires -// that the corresponding RotationKey has been added to the Evaluator. The method will -// panic if either ctIn or ctOut degree is not equal to 1. -func (eval *Evaluator) Automorphism(ctIn *Ciphertext, galEl uint64, ctOut *Ciphertext) { - - if ctIn.Degree() != 1 || ctOut.Degree() != 1 { - panic("cannot apply Automorphism: input and output Ciphertext must be of degree 1") - } - - if galEl == 1 { - if ctOut != ctIn { - ctOut.Copy(ctIn) - } - return - } - - rtk, generated := eval.Rtks.GetRotationKey(galEl) - if !generated { - panic(fmt.Sprintf("galEl key 5^%d missing", eval.params.InverseGaloisElement(galEl))) - } - - level := utils.MinInt(ctIn.Level(), ctOut.Level()) - - ringQ := eval.params.RingQ() - - eval.SwitchKeysInPlace(level, ctIn.Value[1], rtk, eval.Pool[1].Q, eval.Pool[2].Q) - ringQ.AddLvl(level, eval.Pool[1].Q, ctIn.Value[0], eval.Pool[1].Q) - - if ctIn.Value[0].IsNTT { - ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[1].Q, eval.PermuteNTTIndex[galEl], ctOut.Value[0]) - ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[2].Q, eval.PermuteNTTIndex[galEl], ctOut.Value[1]) - } else { - ringQ.PermuteLvl(level, eval.Pool[1].Q, galEl, ctOut.Value[0]) - ringQ.PermuteLvl(level, eval.Pool[2].Q, galEl, ctOut.Value[1]) - } - - ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:level+1] - ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:level+1] -} - -// AutomorphismHoisted is similar to Automorphism, except that it takes as input ctIn and c1DecompQP, where c1DecompQP is the RNS -// decomposition of its element of degree 1. This decomposition can be obtained with DecomposeNTT. -// The method requires that the corresponding RotationKey has been added to the Evaluator. -// The method will panic if either ctIn or ctOut degree is not equal to 1. -func (eval *Evaluator) AutomorphismHoisted(level int, ctIn *Ciphertext, c1DecompQP []ringqp.Poly, galEl uint64, ctOut *Ciphertext) { - - if ctIn.Degree() != 1 || ctOut.Degree() != 1 { - panic("cannot apply AutomorphismHoisted: input and output Ciphertext must be of degree 1") - } - - if galEl == 1 { - if ctIn != ctOut { - ctOut.Copy(ctIn) - } - return - } - - rtk, generated := eval.Rtks.GetRotationKey(galEl) - if !generated { - panic(fmt.Sprintf("galEl key 5^%d missing", eval.params.InverseGaloisElement(galEl))) - } - - ringQ := eval.params.RingQ() - - eval.KeyswitchHoisted(level, c1DecompQP, rtk, eval.Pool[0].Q, eval.Pool[1].Q, eval.Pool[0].P, eval.Pool[1].P) - ringQ.AddLvl(level, eval.Pool[0].Q, ctIn.Value[0], eval.Pool[0].Q) - - if ctIn.Value[0].IsNTT { - ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[0].Q, eval.PermuteNTTIndex[galEl], ctOut.Value[0]) - ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[1].Q, eval.PermuteNTTIndex[galEl], ctOut.Value[1]) - } else { - ringQ.PermuteLvl(level, eval.Pool[0].Q, galEl, ctOut.Value[0]) - ringQ.PermuteLvl(level, eval.Pool[1].Q, galEl, ctOut.Value[1]) - } - - ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:level+1] - ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:level+1] -} - -// AutomorphismHoistedNoModDown is similar to AutomorphismHoisted, except that it returns a ciphertext modulo QP and scaled by P. -// The method requires that the corresponding RotationKey has been added to the Evaluator.The method will panic if either ctIn or ctOut degree is not equal to 1. -func (eval *Evaluator) AutomorphismHoistedNoModDown(levelQ int, c0 *ring.Poly, c1DecompQP []ringqp.Poly, galEl uint64, ct0OutQ, ct1OutQ, ct0OutP, ct1OutP *ring.Poly) { - - levelP := eval.params.PCount() - 1 - - rtk, generated := eval.Rtks.GetRotationKey(galEl) - if !generated { - panic(fmt.Sprintf("galEl key 5^%d missing", eval.params.InverseGaloisElement(galEl))) - } - - eval.KeyswitchHoistedNoModDown(levelQ, c1DecompQP, rtk, eval.Pool[0].Q, eval.Pool[1].Q, eval.Pool[0].P, eval.Pool[1].P) - - ringQ := eval.params.RingQ() - - if c0.IsNTT { - - index := eval.PermuteNTTIndex[galEl] - - ringQ.PermuteNTTWithIndexLvl(levelQ, eval.Pool[1].Q, index, ct1OutQ) - ringQ.PermuteNTTWithIndexLvl(levelP, eval.Pool[1].P, index, ct1OutP) - - ringQ.MulScalarBigintLvl(levelQ, c0, eval.params.RingP().ModulusBigint, eval.Pool[1].Q) - ringQ.AddLvl(levelQ, eval.Pool[0].Q, eval.Pool[1].Q, eval.Pool[0].Q) - - ringQ.PermuteNTTWithIndexLvl(levelQ, eval.Pool[0].Q, index, ct0OutQ) - ringQ.PermuteNTTWithIndexLvl(levelP, eval.Pool[0].P, index, ct0OutP) - } else { - ringQ.PermuteLvl(levelQ, eval.Pool[1].Q, galEl, ct1OutQ) - ringQ.PermuteLvl(levelP, eval.Pool[1].P, galEl, ct1OutP) - - ringQ.MulScalarBigintLvl(levelQ, c0, eval.params.RingP().ModulusBigint, eval.Pool[1].Q) - ringQ.AddLvl(levelQ, eval.Pool[0].Q, eval.Pool[1].Q, eval.Pool[0].Q) - - ringQ.PermuteLvl(levelQ, eval.Pool[0].Q, galEl, ct0OutQ) - ringQ.PermuteLvl(levelP, eval.Pool[0].P, galEl, ct0OutP) - } -} - -// SwitchKeys re-encrypts ctIn under a different key and returns the result in ctOut. -// It requires a SwitchingKey, which is computed from the key under which the Ciphertext is currently encrypted, -// and the key under which the Ciphertext will be re-encrypted. -// The method will panic if either ctIn or ctOut degree isn't 1. -func (eval *Evaluator) SwitchKeys(ctIn *Ciphertext, switchingKey *SwitchingKey, ctOut *Ciphertext) { - - if ctIn.Degree() != 1 || ctOut.Degree() != 1 { - panic("cannot SwitchKeys: input and output Ciphertext must be of degree 1") - } - - level := utils.MinInt(ctIn.Level(), ctOut.Level()) - ringQ := eval.params.RingQ() - - eval.SwitchKeysInPlace(level, ctIn.Value[1], switchingKey, eval.Pool[1].Q, eval.Pool[2].Q) - - ringQ.AddLvl(level, ctIn.Value[0], eval.Pool[1].Q, ctOut.Value[0]) - ring.CopyValuesLvl(level, eval.Pool[2].Q, ctOut.Value[1]) -} - -// Relinearize applies the relinearization procedure on ct0 and returns the result in ctOut. -// The method will panic if the corresponding relinearization key to the ciphertext degree -// is missing. -func (eval *Evaluator) Relinearize(ctIn *Ciphertext, ctOut *Ciphertext) { - if eval.Rlk == nil || ctIn.Degree()-1 > len(eval.Rlk.Keys) { - panic("cannot Relinearize: relinearization key missing (or ciphertext degree is too large)") - } - - level := utils.MinInt(ctIn.Level(), ctOut.Level()) - - ringQ := eval.params.RingQ() - - eval.SwitchKeysInPlace(level, ctIn.Value[2], eval.Rlk.Keys[0], eval.Pool[1].Q, eval.Pool[2].Q) - ringQ.AddLvl(level, ctIn.Value[0], eval.Pool[1].Q, ctOut.Value[0]) - ringQ.AddLvl(level, ctIn.Value[1], eval.Pool[2].Q, ctOut.Value[1]) - - for deg := ctIn.Degree() - 1; deg > 1; deg-- { - eval.SwitchKeysInPlace(level, ctIn.Value[deg], eval.Rlk.Keys[deg-2], eval.Pool[1].Q, eval.Pool[2].Q) - ringQ.AddLvl(level, ctOut.Value[0], eval.Pool[1].Q, ctOut.Value[0]) - ringQ.AddLvl(level, ctOut.Value[1], eval.Pool[2].Q, ctOut.Value[1]) - } - - ctOut.Value = ctOut.Value[:2] - - ctOut.Value[0].Coeffs = ctOut.Value[0].Coeffs[:level+1] - ctOut.Value[1].Coeffs = ctOut.Value[1].Coeffs[:level+1] -} - // MergeRLWE merges a batch of RLWE, packing the first coefficient of each RLWE into a single RLWE. // The operation will require N/gap + log(gap) key-switches, where gap is the minimum gap between // two non-zero coefficients of the final ciphertext. @@ -413,9 +246,9 @@ func (eval *Evaluator) mergeRLWERecurse(ciphertexts []*Ciphertext, xPow []*ring. // if L-2 == -1, then gal = -1 if L == 1 { - eval.rotate(tmpEven, uint64(2*ringQ.N-1), tmpEven) + eval.Automorphism(tmpEven, uint64(2*ringQ.N-1), tmpEven) } else { - eval.rotate(tmpEven, eval.params.GaloisElementForColumnRotationBy(1<<(L-2)), tmpEven) + eval.Automorphism(tmpEven, eval.params.GaloisElementForColumnRotationBy(1<<(L-2)), tmpEven) } // ctEven + ctOdd * X^(N/2^L) + phi(ctEven - ctOdd * X^(N/2^L), 2^(L-2)) @@ -425,344 +258,3 @@ func (eval *Evaluator) mergeRLWERecurse(ciphertexts []*Ciphertext, xPow []*ring. return ctEven } - -func (eval *Evaluator) rotate(ctIn *Ciphertext, galEl uint64, ctOut *Ciphertext) { - ringQ := eval.params.RingQ() - rtk, _ := eval.Rtks.GetRotationKey(galEl) - level := utils.MinInt(ctIn.Level(), ctOut.Level()) - index := eval.PermuteNTTIndex[galEl] - eval.SwitchKeysInPlace(level, ctIn.Value[1], rtk, eval.Pool[1].Q, eval.Pool[2].Q) - ringQ.AddLvl(level, eval.Pool[1].Q, ctIn.Value[0], eval.Pool[1].Q) - ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[1].Q, index, ctOut.Value[0]) - ringQ.PermuteNTTWithIndexLvl(level, eval.Pool[2].Q, index, ctOut.Value[1]) -} - -// SwitchKeysInPlace applies the general key-switching procedure of the form [c0 + cx*evakey[0], c1 + cx*evakey[1]] -// Will return the result in the same NTT domain as the input cx. -func (eval *Evaluator) SwitchKeysInPlace(levelQ int, cx *ring.Poly, evakey *SwitchingKey, p0, p1 *ring.Poly) { - - levelP := evakey.LevelP() - - if levelP > 0 { - eval.SwitchKeysInPlaceNoModDown(levelQ, cx, evakey, p0, eval.Pool[1].P, p1, eval.Pool[2].P) - } else { - eval.SwitchKeyInPlaceSinglePAndBitDecomp(levelQ, cx, evakey, p0, eval.Pool[1].P, p1, eval.Pool[2].P) - } - - if cx.IsNTT && levelP != -1 { - eval.BasisExtender.ModDownQPtoQNTT(levelQ, levelP, p0, eval.Pool[1].P, p0) - eval.BasisExtender.ModDownQPtoQNTT(levelQ, levelP, p1, eval.Pool[2].P, p1) - } else if !cx.IsNTT { - eval.params.RingQ().InvNTTLazyLvl(levelQ, p0, p0) - eval.params.RingQ().InvNTTLazyLvl(levelQ, p1, p1) - - if levelP != -1 { - eval.params.RingP().InvNTTLazyLvl(levelP, eval.Pool[1].P, eval.Pool[1].P) - eval.params.RingP().InvNTTLazyLvl(levelP, eval.Pool[2].P, eval.Pool[2].P) - eval.BasisExtender.ModDownQPtoQ(levelQ, levelP, p0, eval.Pool[1].P, p0) - eval.BasisExtender.ModDownQPtoQ(levelQ, levelP, p1, eval.Pool[2].P, p1) - } - } -} - -// DecomposeNTT applies the full RNS basis decomposition for all q_alpha_i on c2. -// Expects the IsNTT flag of c2 to correctly reflect the domain of c2. -// PoolDecompQ and PoolDecompQ are vectors of polynomials (mod Q and mod P) that store the -// special RNS decomposition of c2 (in the NTT domain) -func (eval *Evaluator) DecomposeNTT(levelQ, levelP, alpha int, c2 *ring.Poly, PoolDecomp []ringqp.Poly) { - - ringQ := eval.params.RingQ() - - var polyNTT, polyInvNTT *ring.Poly - - if c2.IsNTT { - polyNTT = c2 - polyInvNTT = eval.PoolInvNTT - ringQ.InvNTTLvl(levelQ, polyNTT, polyInvNTT) - } else { - polyNTT = eval.PoolInvNTT - polyInvNTT = c2 - ringQ.NTTLvl(levelQ, polyInvNTT, polyNTT) - } - - beta := (levelQ + 1 + levelP) / (levelP + 1) - - for i := 0; i < beta; i++ { - eval.DecomposeSingleNTT(levelQ, levelP, alpha, i, polyNTT, polyInvNTT, PoolDecomp[i].Q, PoolDecomp[i].P) - } -} - -// DecomposeSingleNTT takes the input polynomial c2 (c2NTT and c2InvNTT, respectively in the NTT and out of the NTT domain) -// modulo q_alpha_beta, and returns the result on c2QiQ are c2QiP the receiver polynomials -// respectively mod Q and mod P (in the NTT domain) -func (eval *Evaluator) DecomposeSingleNTT(levelQ, levelP, alpha, beta int, c2NTT, c2InvNTT, c2QiQ, c2QiP *ring.Poly) { - - ringQ := eval.params.RingQ() - ringP := eval.params.RingP() - - eval.Decomposer.DecomposeAndSplit(levelQ, levelP, alpha, beta, c2InvNTT, c2QiQ, c2QiP) - - p0idxst := beta * alpha - p0idxed := p0idxst + 1 - - // c2_qi = cx mod qi mod qi - for x := 0; x < levelQ+1; x++ { - if p0idxst <= x && x < p0idxed { - copy(c2QiQ.Coeffs[x], c2NTT.Coeffs[x]) - } else { - ringQ.NTTSingle(x, c2QiQ.Coeffs[x], c2QiQ.Coeffs[x]) - } - } - - if ringP != nil { - // c2QiP = c2 mod qi mod pj - ringP.NTTLvl(levelP, c2QiP, c2QiP) - } -} - -// SwitchKeysInPlaceNoModDown applies the key-switch to the polynomial cx : -// -// pool2 = dot(decomp(cx) * evakey[0]) mod QP (encrypted input is multiplied by P factor) -// pool3 = dot(decomp(cx) * evakey[1]) mod QP (encrypted input is multiplied by P factor) -// -// Expects the flag IsNTT of cx to correctly reflect the domain of cx. -func (eval *Evaluator) SwitchKeysInPlaceNoModDown(levelQ int, cx *ring.Poly, evakey *SwitchingKey, c0Q, c0P, c1Q, c1P *ring.Poly) { - - ringQ := eval.params.RingQ() - ringP := eval.params.RingP() - ringQP := eval.params.RingQP() - - c2QP := eval.Pool[0] - - var cxNTT, cxInvNTT *ring.Poly - if cx.IsNTT { - cxNTT = cx - cxInvNTT = eval.PoolInvNTT - ringQ.InvNTTLvl(levelQ, cxNTT, cxInvNTT) - } else { - cxNTT = eval.PoolInvNTT - cxInvNTT = cx - ringQ.NTTLvl(levelQ, cxInvNTT, cxNTT) - } - - c0QP := ringqp.Poly{Q: c0Q, P: c0P} - c1QP := ringqp.Poly{Q: c1Q, P: c1P} - - levelP := evakey.Value[0][0][0].P.Level() - decompRNS := (levelQ + 1 + levelP) / (levelP + 1) - - QiOverF := eval.params.QiOverflowMargin(levelQ) >> 1 - PiOverF := eval.params.PiOverflowMargin(levelP) >> 1 - - // Key switching with CRT decomposition for the Qi - var reduce int - for i := 0; i < decompRNS; i++ { - - eval.DecomposeSingleNTT(levelQ, levelP, levelP+1, i, cxNTT, cxInvNTT, c2QP.Q, c2QP.P) - - if i == 0 { - ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0][0], c2QP, c0QP) - ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0][1], c2QP, c1QP) - } else { - ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0][0], c2QP, c0QP) - ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0][1], c2QP, c1QP) - } - - if reduce%QiOverF == QiOverF-1 { - ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) - ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) - } - - if reduce%PiOverF == PiOverF-1 { - ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) - ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) - } - - reduce++ - } - - if reduce%QiOverF != 0 { - ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) - ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) - } - - if reduce%PiOverF != 0 { - ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) - ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) - } -} - -// SwitchKeyInPlaceSinglePAndBitDecomp applies the key-switch to the polynomial cx : -// -// pool2 = dot(decomp(cx) * evakey[0]) mod QP (encrypted input is multiplied by P factor) -// pool3 = dot(decomp(cx) * evakey[1]) mod QP (encrypted input is multiplied by P factor) -// -// Expects the flag IsNTT of cx to correctly reflect the domain of cx. -func (eval *Evaluator) SwitchKeyInPlaceSinglePAndBitDecomp(levelQ int, cx *ring.Poly, evakey *SwitchingKey, c0Q, c0P, c1Q, c1P *ring.Poly) { - - ringQ := eval.params.RingQ() - ringP := eval.params.RingP() - - var cxInvNTT *ring.Poly - if cx.IsNTT { - cxInvNTT = eval.PoolInvNTT - ringQ.InvNTTLvl(levelQ, cx, cxInvNTT) - } else { - cxInvNTT = cx - } - - c0QP := ringqp.Poly{Q: c0Q, P: c0P} - c1QP := ringqp.Poly{Q: c1Q, P: c1P} - - var levelP int - if evakey.Value[0][0][0].P != nil { - levelP = evakey.Value[0][0][0].P.Level() - } else { - levelP = -1 - } - - decompRNS := eval.params.DecompRNS(levelQ, levelP) - decompBIT := eval.params.DecompBIT(levelQ, levelP) - - lb2 := eval.params.logbase2 - - mask := uint64(((1 << lb2) - 1)) - - if mask == 0 { - mask = 0xFFFFFFFFFFFFFFFF - } - - cw := eval.Pool[0].Q.Coeffs[0] - cwNTT := eval.PoolBitDecomp - - QiOverF := eval.params.QiOverflowMargin(levelQ) >> 1 - PiOverF := eval.params.PiOverflowMargin(levelP) >> 1 - - // Key switching with CRT decomposition for the Qi - var reduce int - for i := 0; i < decompRNS; i++ { - for j := 0; j < decompBIT; j++ { - - ring.MaskVec(cxInvNTT.Coeffs[i], cw, j*lb2, mask) - - if i == 0 && j == 0 { - for u := 0; u < levelQ+1; u++ { - ringQ.NTTSingleLazy(u, cw, cwNTT) - ring.MulCoeffsMontgomeryConstantVec(evakey.Value[i][j][0].Q.Coeffs[u], cwNTT, c0QP.Q.Coeffs[u], ringQ.Modulus[u], ringQ.MredParams[u]) - ring.MulCoeffsMontgomeryConstantVec(evakey.Value[i][j][1].Q.Coeffs[u], cwNTT, c1QP.Q.Coeffs[u], ringQ.Modulus[u], ringQ.MredParams[u]) - } - - for u := 0; u < levelP+1; u++ { - ringP.NTTSingleLazy(u, cw, cwNTT) - ring.MulCoeffsMontgomeryConstantVec(evakey.Value[i][j][0].P.Coeffs[u], cwNTT, c0QP.P.Coeffs[u], ringP.Modulus[u], ringP.MredParams[u]) - ring.MulCoeffsMontgomeryConstantVec(evakey.Value[i][j][1].P.Coeffs[u], cwNTT, c1QP.P.Coeffs[u], ringP.Modulus[u], ringP.MredParams[u]) - } - } else { - for u := 0; u < levelQ+1; u++ { - ringQ.NTTSingleLazy(u, cw, cwNTT) - ring.MulCoeffsMontgomeryConstantAndAddNoModVec(evakey.Value[i][j][0].Q.Coeffs[u], cwNTT, c0QP.Q.Coeffs[u], ringQ.Modulus[u], ringQ.MredParams[u]) - ring.MulCoeffsMontgomeryConstantAndAddNoModVec(evakey.Value[i][j][1].Q.Coeffs[u], cwNTT, c1QP.Q.Coeffs[u], ringQ.Modulus[u], ringQ.MredParams[u]) - } - - for u := 0; u < levelP+1; u++ { - ringP.NTTSingleLazy(u, cw, cwNTT) - ring.MulCoeffsMontgomeryConstantAndAddNoModVec(evakey.Value[i][j][0].P.Coeffs[u], cwNTT, c0QP.P.Coeffs[u], ringP.Modulus[u], ringP.MredParams[u]) - ring.MulCoeffsMontgomeryConstantAndAddNoModVec(evakey.Value[i][j][1].P.Coeffs[u], cwNTT, c1QP.P.Coeffs[u], ringP.Modulus[u], ringP.MredParams[u]) - } - } - - if reduce%QiOverF == QiOverF-1 { - ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) - ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) - } - - if reduce%PiOverF == PiOverF-1 { - ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) - ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) - } - - reduce++ - } - } - - if reduce%QiOverF != 0 { - ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) - ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) - } - - if reduce%PiOverF != 0 { - ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) - ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) - } -} - -// KeyswitchHoisted applies the key-switch to the decomposed polynomial c2 mod QP (PoolDecompQ and PoolDecompP) -// and divides the result by P, reducing the basis from QP to Q. -// -// pool2 = dot(PoolDecompQ||PoolDecompP * evakey[0]) mod Q -// pool3 = dot(PoolDecompQ||PoolDecompP * evakey[1]) mod Q -func (eval *Evaluator) KeyswitchHoisted(levelQ int, PoolDecompQP []ringqp.Poly, evakey *SwitchingKey, c0Q, c1Q, c0P, c1P *ring.Poly) { - - eval.KeyswitchHoistedNoModDown(levelQ, PoolDecompQP, evakey, c0Q, c1Q, c0P, c1P) - - levelP := evakey.Value[0][0][0].P.Level() - - // Computes c0Q = c0Q/c0P and c1Q = c1Q/c1P - eval.BasisExtender.ModDownQPtoQNTT(levelQ, levelP, c0Q, c0P, c0Q) - eval.BasisExtender.ModDownQPtoQNTT(levelQ, levelP, c1Q, c1P, c1Q) -} - -// KeyswitchHoistedNoModDown applies the key-switch to the decomposed polynomial c2 mod QP (PoolDecompQ and PoolDecompP) -// -// pool2 = dot(PoolDecompQ||PoolDecompP * evakey[0]) mod QP -// pool3 = dot(PoolDecompQ||PoolDecompP * evakey[1]) mod QP -func (eval *Evaluator) KeyswitchHoistedNoModDown(levelQ int, PoolDecompQP []ringqp.Poly, evakey *SwitchingKey, c0Q, c1Q, c0P, c1P *ring.Poly) { - - ringQ := eval.params.RingQ() - ringP := eval.params.RingP() - ringQP := eval.params.RingQP() - - c0QP := ringqp.Poly{Q: c0Q, P: c0P} - c1QP := ringqp.Poly{Q: c1Q, P: c1P} - - levelP := evakey.Value[0][0][0].P.Level() - decompRNS := (levelQ + 1 + levelP) / (levelP + 1) - - QiOverF := eval.params.QiOverflowMargin(levelQ) >> 1 - PiOverF := eval.params.PiOverflowMargin(levelP) >> 1 - - // Key switching with CRT decomposition for the Qi - var reduce int - for i := 0; i < decompRNS; i++ { - - if i == 0 { - ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0][0], PoolDecompQP[i], c0QP) - ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, evakey.Value[i][0][1], PoolDecompQP[i], c1QP) - } else { - ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0][0], PoolDecompQP[i], c0QP) - ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, evakey.Value[i][0][1], PoolDecompQP[i], c1QP) - } - - if reduce%QiOverF == QiOverF-1 { - ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) - ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) - } - - if reduce%PiOverF == PiOverF-1 { - ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) - ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) - } - - reduce++ - } - - if reduce%QiOverF != 0 { - ringQ.ReduceLvl(levelQ, c0QP.Q, c0QP.Q) - ringQ.ReduceLvl(levelQ, c1QP.Q, c1QP.Q) - } - - if reduce%PiOverF != 0 { - ringP.ReduceLvl(levelP, c0QP.P, c0QP.P) - ringP.ReduceLvl(levelP, c1QP.P, c1QP.P) - } -} diff --git a/rlwe/lut/evaluator.go b/rlwe/lut/evaluator.go new file mode 100644 index 00000000..661092aa --- /dev/null +++ b/rlwe/lut/evaluator.go @@ -0,0 +1,256 @@ +package lut + +import ( + "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe" + "github.com/tuneinsight/lattigo/v3/rlwe/rgsw" + "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" + "math/big" +) + +// Evaluator is a struct that stores necessary +// data to handle LWE <-> RLWE conversion and +// LUT evaluation. +type Evaluator struct { + *rlwe.Evaluator + paramsLUT rlwe.Parameters + paramsLWE rlwe.Parameters + rtks *rlwe.RotationKeySet + + xPowMinusOne []ringqp.Poly //X^n - 1 from 0 to 2N LWE + + poolMod2N [2]*ring.Poly + + accumulator *rlwe.Ciphertext + Sk *rlwe.SecretKey + + tmpRGSW *rgsw.Ciphertext +} + +// NewEvaluator creates a new Handler +func NewEvaluator(paramsLUT, paramsLWE rlwe.Parameters, rtks *rlwe.RotationKeySet) (eval *Evaluator) { + eval = new(Evaluator) + eval.Evaluator = rlwe.NewEvaluator(paramsLUT, &rlwe.EvaluationKey{Rtks: rtks}) + eval.paramsLUT = paramsLUT + eval.paramsLWE = paramsLWE + + ringQ := paramsLUT.RingQ() + ringP := paramsLUT.RingP() + + eval.poolMod2N = [2]*ring.Poly{paramsLWE.RingQ().NewPolyLvl(0), paramsLWE.RingQ().NewPolyLvl(0)} + eval.accumulator = rlwe.NewCiphertextNTT(paramsLUT, 1, paramsLUT.MaxLevel()) + + // Compute X^{n} - 1 from 0 to 2N LWE + oneNTTMFormQ := ringQ.NewPoly() + for i := range ringQ.Modulus { + for j := 0; j < ringQ.N; j++ { + oneNTTMFormQ.Coeffs[i][j] = ring.MForm(1, ringQ.Modulus[i], ringQ.BredParams[i]) + } + } + + N := ringQ.N + + eval.xPowMinusOne = make([]ringqp.Poly, 2*N) + for i := 0; i < N; i++ { + eval.xPowMinusOne[i].Q = ringQ.NewPoly() + eval.xPowMinusOne[i+N].Q = ringQ.NewPoly() + if i == 0 || i == 1 { + for j := range ringQ.Modulus { + eval.xPowMinusOne[i].Q.Coeffs[j][i] = ring.MForm(1, ringQ.Modulus[j], ringQ.BredParams[j]) + } + + ringQ.NTT(eval.xPowMinusOne[i].Q, eval.xPowMinusOne[i].Q) + + // Negacyclic wrap-arround for n > N + ringQ.Neg(eval.xPowMinusOne[i].Q, eval.xPowMinusOne[i+N].Q) + + } else { + ringQ.MulCoeffsMontgomery(eval.xPowMinusOne[1].Q, eval.xPowMinusOne[i-1].Q, eval.xPowMinusOne[i].Q) // X^{n} = X^{1} * X^{n-1} + + // Negacyclic wrap-arround for n > N + ringQ.Neg(eval.xPowMinusOne[i].Q, eval.xPowMinusOne[i+N].Q) // X^{2n} = -X^{1} * X^{n-1} + } + } + + // Subtract -1 in NTT + for i := 0; i < 2*N; i++ { + ringQ.Sub(eval.xPowMinusOne[i].Q, oneNTTMFormQ, eval.xPowMinusOne[i].Q) // X^{n} - 1 + } + + if ringP != nil { + oneNTTMFormP := ringP.NewPoly() + for i := range ringP.Modulus { + for j := 0; j < ringP.N; j++ { + oneNTTMFormP.Coeffs[i][j] = ring.MForm(1, ringP.Modulus[i], ringP.BredParams[i]) + } + } + + for i := 0; i < N; i++ { + eval.xPowMinusOne[i].P = ringP.NewPoly() + eval.xPowMinusOne[i+N].P = ringP.NewPoly() + if i == 0 || i == 1 { + for j := range ringP.Modulus { + eval.xPowMinusOne[i].P.Coeffs[j][i] = ring.MForm(1, ringP.Modulus[j], ringP.BredParams[j]) + } + + ringP.NTT(eval.xPowMinusOne[i].P, eval.xPowMinusOne[i].P) + + // Negacyclic wrap-arround for n > N + ringP.Neg(eval.xPowMinusOne[i].P, eval.xPowMinusOne[i+N].P) + + } else { + // X^{n} = X^{1} * X^{n-1} + ringP.MulCoeffsMontgomery(eval.xPowMinusOne[1].P, eval.xPowMinusOne[i-1].P, eval.xPowMinusOne[i].P) + + // Negacyclic wrap-arround for n > N + // X^{2n} = -X^{1} * X^{n-1} + ringP.Neg(eval.xPowMinusOne[i].P, eval.xPowMinusOne[i+N].P) + } + } + + // Subtract -1 in NTT + for i := 0; i < 2*N; i++ { + // X^{n} - 1 + ringP.Sub(eval.xPowMinusOne[i].P, oneNTTMFormP, eval.xPowMinusOne[i].P) + } + } + + levelQ := paramsLUT.QCount() - 1 + levelP := paramsLUT.PCount() - 1 + decompRNS := paramsLUT.DecompRNS(levelQ, levelP) + decompBIT := paramsLUT.DecompBIT(levelQ, levelP) + ringQP := paramsLUT.RingQP() + eval.tmpRGSW = rgsw.NewCiphertextNTT(levelQ, levelP, decompRNS, decompBIT, ringQP) + + return +} + +func (eval *Evaluator) permuteNTTIndexesForKey(rtks *rlwe.RotationKeySet) *map[uint64][]uint64 { + if rtks == nil { + return &map[uint64][]uint64{} + } + permuteNTTIndex := make(map[uint64][]uint64, len(rtks.Keys)) + for galEl := range rtks.Keys { + permuteNTTIndex[galEl] = eval.paramsLUT.RingQ().PermuteNTTIndex(galEl) + } + return &permuteNTTIndex +} + +// EvaluateAndRepack extracts on the fly LWE samples and evaluate the provided LUT on the LWE and repacks everything into a single rlwe.Ciphertext. +// ct : a rlwe Ciphertext with coefficient encoded values at level 0 +// lutPolyWihtSlotIndex : a map with [slot_index] -> LUT +// repackIndex : a map with [slot_index_have] -> slot_index_want +// lutKey : LUTKey +// Returns a *rlwe.Ciphertext +func (eval *Evaluator) EvaluateAndRepack(ct *rlwe.Ciphertext, lutPolyWihtSlotIndex map[int]*ring.Poly, repackIndex map[int]int, key Key) (res *rlwe.Ciphertext) { + cts := eval.Evaluate(ct, lutPolyWihtSlotIndex, key) + + ciphertexts := make(map[int]*rlwe.Ciphertext) + + for i := range cts { + ciphertexts[repackIndex[i]] = cts[i] + } + + return eval.MergeRLWE(ciphertexts) +} + +// Evaluate extracts on the fly LWE samples and evaluate the provided LUT on the LWE. +// ct : a rlwe Ciphertext with coefficient encoded values at level 0 +// lutPolyWihtSlotIndex : a map with [slot_index] -> LUT +// lutKey : lut.Key +// Returns a map[slot_index] -> LUT(ct[slot_index]) +func (eval *Evaluator) Evaluate(ct *rlwe.Ciphertext, lutPolyWihtSlotIndex map[int]*ring.Poly, key Key) (res map[int]*rlwe.Ciphertext) { + + bRLWEMod2N := eval.poolMod2N[0] + aRLWEMod2N := eval.poolMod2N[1] + + acc := eval.accumulator + + ringQLUT := eval.paramsLUT.RingQ() + ringQLWE := eval.paramsLWE.RingQ() + ringQPLUT := *eval.paramsLUT.RingQP() + + // mod 2N + mask := uint64(ringQLUT.N<<1) - 1 + + ringQLWE.InvNTTLvl(ct.Level(), ct.Value[0], acc.Value[0]) + ringQLWE.InvNTTLvl(ct.Level(), ct.Value[1], acc.Value[1]) + + // Switch modulus from Q to 2N + eval.ModSwitchRLWETo2NLvl(ct.Level(), acc.Value[1], acc.Value[1]) + + // Conversion from Convolution(a, sk) to DotProd(a, sk) for LWE decryption. + // Copy coefficients multiplied by X^{N-1} in reverse order: + // a_{0} -a_{N-1} -a2_{N-2} ... -a_{1} + tmp0 := aRLWEMod2N.Coeffs[0] + tmp1 := acc.Value[1].Coeffs[0] + tmp0[0] = tmp1[0] + for j := 1; j < ringQLWE.N; j++ { + tmp0[j] = -tmp1[ringQLWE.N-j] & mask + } + + eval.ModSwitchRLWETo2NLvl(ct.Level(), acc.Value[0], bRLWEMod2N) + + levelQ := key.SkPos[0].LevelQ() + levelP := key.SkPos[0].LevelP() + + res = make(map[int]*rlwe.Ciphertext) + + var prevIndex int + for index := 0; index < ringQLWE.N; index++ { + + if lut, ok := lutPolyWihtSlotIndex[index]; ok { + + MulBySmallMonomialMod2N(mask, aRLWEMod2N, index-prevIndex) + prevIndex = index + + a := aRLWEMod2N.Coeffs[0] + b := bRLWEMod2N.Coeffs[0][index] + + // LWE = -as + m + e, a + // LUT = LUT * X^{-as + m + e} + ringQLUT.MulCoeffsMontgomery(lut, eval.xPowMinusOne[b].Q, acc.Value[0]) + ringQLUT.Add(acc.Value[0], lut, acc.Value[0]) + acc.Value[1].Zero() // TODO remove + + for j := 0; j < ringQLWE.N; j++ { + // RGSW[(X^{a} - 1) * sk_{j}[0] + (X^{-a} - 1) * sk_{j}[1] + 1] + rgsw.MulByXPowAlphaMinusOneConstantLvl(levelQ, levelP, key.SkPos[j], eval.xPowMinusOne[a[j]], ringQPLUT, eval.tmpRGSW) + rgsw.MulByXPowAlphaMinusOneAndAddNoModLvl(levelQ, levelP, key.SkNeg[j], eval.xPowMinusOne[-a[j]&mask], ringQPLUT, eval.tmpRGSW) + rgsw.AddNoModLvl(levelQ, levelP, key.One, ringQPLUT, eval.tmpRGSW) + + // LUT[RLWE] = LUT[RLWE] x RGSW[(X^{a} - 1) * sk_{j}[0] + (X^{-a} - 1) * sk_{j}[1] + 1] + eval.ExternalProduct(acc, eval.tmpRGSW, acc) + } + + res[index] = acc.CopyNew() + } + + // LUT[RLWE] = LUT[RLWE] * X^{m+e} + } + + return +} + +// ModSwitchRLWETo2NLvl applys round(x * 2N / Q) to the coefficients of polQ and returns the result on pol2N. +func (eval *Evaluator) ModSwitchRLWETo2NLvl(level int, polQ *ring.Poly, pol2N *ring.Poly) { + coeffsBigint := make([]*big.Int, len(polQ.Coeffs[0])) + + ringQ := eval.paramsLWE.RingQ() + + ringQ.PolyToBigintLvl(level, polQ, 1, coeffsBigint) + + QBig := ring.NewUint(1) + for i := 0; i < level+1; i++ { + QBig.Mul(QBig, ring.NewUint(ringQ.Modulus[i])) + } + + twoN := uint64(eval.paramsLUT.N() << 1) + twoNBig := ring.NewUint(twoN) + tmp := pol2N.Coeffs[0] + for i := 0; i < ringQ.N; i++ { + coeffsBigint[i].Mul(coeffsBigint[i], twoNBig) + ring.DivRound(coeffsBigint[i], QBig, coeffsBigint[i]) + tmp[i] = coeffsBigint[i].Uint64() & (twoN - 1) + } +} diff --git a/rlwe/lut/keys.go b/rlwe/lut/keys.go new file mode 100644 index 00000000..6e6b695e --- /dev/null +++ b/rlwe/lut/keys.go @@ -0,0 +1,73 @@ +package lut + +import ( + "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe" + "github.com/tuneinsight/lattigo/v3/rlwe/rgsw" +) + +// Key is a struct storing the encryption +// of the bits of the LWE key. +type Key struct { + SkPos []*rgsw.Ciphertext + SkNeg []*rgsw.Ciphertext + One *rgsw.Plaintext +} + +// GenLUTKey generates the LUT evaluation key +func (eval *Evaluator) GenLUTKey(skRLWE, skLWE *rlwe.SecretKey) (key Key) { + + paramsLUT := eval.paramsLUT + paramsLWE := eval.paramsLWE + + skLWEInvNTT := eval.paramsLWE.RingQ().NewPoly() + + paramsLWE.RingQ().InvNTT(skLWE.Value.Q, skLWEInvNTT) + + plaintextRGSWOne := rlwe.NewPlaintext(paramsLUT, paramsLUT.MaxLevel()) + plaintextRGSWOne.Value.IsNTT = true + for j := 0; j < paramsLUT.QCount(); j++ { + for i := 0; i < paramsLUT.N(); i++ { + plaintextRGSWOne.Value.Coeffs[j][i] = 1 + } + } + + encryptor := rlwe.NewEncryptor(paramsLUT, skRLWE) + + levelQ := paramsLUT.QCount() - 1 + levelP := paramsLUT.PCount() - 1 + + skRGSWPos := make([]*rgsw.Ciphertext, paramsLWE.N()) + skRGSWNeg := make([]*rgsw.Ciphertext, paramsLWE.N()) + + ringQ := paramsLWE.RingQ() + Q := ringQ.Modulus[0] + OneMForm := ring.MForm(1, Q, ringQ.BredParams[0]) + MinusOneMform := ring.MForm(Q-1, Q, ringQ.BredParams[0]) + + decompRNS := paramsLUT.DecompRNS(levelQ, levelP) + decompBIT := paramsLUT.DecompBIT(levelQ, levelP) + ringQP := paramsLUT.RingQP() + + for i, si := range skLWEInvNTT.Coeffs[0] { + + skRGSWPos[i] = rgsw.NewCiphertextNTT(levelQ, levelP, decompRNS, decompBIT, ringQP) + skRGSWNeg[i] = rgsw.NewCiphertextNTT(levelQ, levelP, decompRNS, decompBIT, ringQP) + + // sk_i = 1 -> [RGSW(1), RGSW(0)] + if si == OneMForm { + encryptor.Encrypt(plaintextRGSWOne, skRGSWPos[i]) + encryptor.Encrypt(nil, skRGSWNeg[i]) + // sk_i = -1 -> [RGSW(0), RGSW(1)] + } else if si == MinusOneMform { + encryptor.Encrypt(nil, skRGSWPos[i]) + encryptor.Encrypt(plaintextRGSWOne, skRGSWNeg[i]) + // sk_i = 0 -> [RGSW(0), RGSW(0)] + } else { + encryptor.Encrypt(nil, skRGSWPos[i]) + encryptor.Encrypt(nil, skRGSWNeg[i]) + } + } + + return Key{SkPos: skRGSWPos, SkNeg: skRGSWNeg, One: rgsw.NewPlaintext(uint64(1), levelQ, levelP, paramsLUT.LogBase2(), decompBIT, *ringQP)} +} diff --git a/rlwe/lut/lut.go b/rlwe/lut/lut.go new file mode 100644 index 00000000..4c0f4bdf --- /dev/null +++ b/rlwe/lut/lut.go @@ -0,0 +1,33 @@ +package lut + +import ( + "github.com/tuneinsight/lattigo/v3/ring" +) + +// InitLUT takes a function g, and creates an LUT polynomial for the function between the intervals a, b. +// Inputs to the LUT evaluation are assumed to have been normalized with the change of basis (2*x - a - b)/(b-a). +// Interval a, b should take into account the "drift" of the value x, caused by the change of modulus from Q to 2N. +func InitLUT(g func(x float64) (y float64), scale float64, ringQ *ring.Ring, a, b float64) (F *ring.Poly) { + F = ringQ.NewPoly() + Q := ringQ.Modulus + + // Discretization interval + interval := 2.0 / float64(ringQ.N) + + for j, qi := range Q { + + // Interval [-1, 0] of g(x) + for i := 0; i < (ringQ.N>>1)+1; i++ { + F.Coeffs[j][i] = scaleUp(g(normalizeInv(-interval*float64(i), a, b)), scale, qi) + } + + // Interval ]0, 1[ of g(x) + for i := (ringQ.N >> 1) + 1; i < ringQ.N; i++ { + F.Coeffs[j][i] = scaleUp(-g(normalizeInv(interval*float64(ringQ.N-i), a, b)), scale, qi) + } + } + + ringQ.NTT(F, F) + + return +} diff --git a/rlwe/lut/test_lut.go b/rlwe/lut/test_lut.go new file mode 100644 index 00000000..078d2a96 --- /dev/null +++ b/rlwe/lut/test_lut.go @@ -0,0 +1,132 @@ +package lut + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe" + "math" + "runtime" + "testing" +) + +func testString(params rlwe.Parameters, opname string) string { + return fmt.Sprintf("%slogN=%d/logQ=%d/logP=%d/#Qi=%d/#Pi=%d", + opname, + params.LogN(), + params.LogQ(), + params.LogP(), + params.QCount(), + params.PCount()) +} + +// TestLUT tests the LUT evaluation. +func TestLUT(t *testing.T) { + for _, testSet := range []func(t *testing.T){ + testLUT, + } { + testSet(t) + runtime.GC() + } +} + +// m0 : [0, 1/4] +// m1 : [0, 1/4] +// | 0 0 -> 1 (0/8) -> 2/8 +// | 0 1 -> 1 (2/8) -> 2/8 +// | 1 0 -> 1 (2/8) -> 2/8 +// | 1 1 -> 0 (4/8) -> 0/8 +func nandGate(x float64) float64 { + if x > -1/8.0 && x < 3/8.0 { + return 2 / 8.0 + } + + return 0 +} + +func testLUT(t *testing.T) { + var err error + + // N=1024, Q=0x7fff801 -> 2^131 + paramsLUT, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{ + LogN: 8, + Q: []uint64{0x7fff801}, + P: []uint64{}, + Sigma: rlwe.DefaultSigma, + LogBase2: 7, + }) + + assert.Nil(t, err) + + // N=512, Q=0x3001 -> 2^135 + paramsLWE, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{ + LogN: 7, + Q: []uint64{0x3001}, + P: []uint64{}, + Sigma: rlwe.DefaultSigma, + }) + + assert.Nil(t, err) + + t.Run(testString(paramsLUT, "LUT/"), func(t *testing.T) { + + scaleLWE := float64(paramsLWE.Q()[0]) / 4.0 + scaleLUT := float64(paramsLUT.Q()[0]) / 4.0 + + slots := 16 + + LUTPoly := InitLUT(nandGate, scaleLUT, paramsLUT.RingQ(), -1, 1) + + lutPolyMap := make(map[int]*ring.Poly) + for i := 0; i < slots; i++ { + lutPolyMap[i] = LUTPoly + } + + skLWE := rlwe.NewKeyGenerator(paramsLWE).GenSecretKey() + encryptorLWE := rlwe.NewEncryptor(paramsLWE, skLWE) + + values := make([]float64, slots) + for i := 0; i < slots; i++ { + values[i] = -1 + float64(2*i)/float64(slots) + } + + ptLWE := rlwe.NewPlaintext(paramsLWE, paramsLWE.MaxLevel()) + for i := range values { + if values[i] < 0 { + ptLWE.Value.Coeffs[0][i] = paramsLWE.Q()[0] - uint64(-values[i]*scaleLWE) + } else { + ptLWE.Value.Coeffs[0][i] = uint64(values[i] * scaleLWE) + } + } + ctLWE := rlwe.NewCiphertextNTT(paramsLWE, 1, paramsLWE.MaxLevel()) + encryptorLWE.Encrypt(ptLWE, ctLWE) + + eval := NewEvaluator(paramsLUT, paramsLWE, nil) + + skLUT := rlwe.NewKeyGenerator(paramsLUT).GenSecretKey() + LUTKEY := eval.GenLUTKey(skLUT, skLWE) + + ctsLUT := eval.Evaluate(ctLWE, lutPolyMap, LUTKEY) + + q := paramsLUT.Q()[0] + qHalf := q >> 1 + decryptorLUT := rlwe.NewDecryptor(paramsLUT, skLUT) + ptLUT := rlwe.NewPlaintext(paramsLUT, paramsLUT.MaxLevel()) + for i := 0; i < slots; i++ { + + decryptorLUT.Decrypt(ctsLUT[i], ptLUT) + + c := ptLUT.Value.Coeffs[0][i] + + var a float64 + if c >= qHalf { + a = -float64(q-c) / scaleLUT + } else { + a = float64(c) / scaleLUT + } + + //fmt.Printf("%7.4f - %7.4f - %7.4f\n", math.Round(a*32)/32, math.Round(a*8)/8, values[i]) + assert.Equal(t, nandGate(values[i]), math.Round(a*8)/8) + } + }) +} diff --git a/rlwe/lut/utils.go b/rlwe/lut/utils.go new file mode 100644 index 00000000..1f5c7cec --- /dev/null +++ b/rlwe/lut/utils.go @@ -0,0 +1,51 @@ +package lut + +import ( + "github.com/tuneinsight/lattigo/v3/ring" + "math/big" +) + +//MulBySmallMonomialMod2N multiplies pol by x^n, with 0 <= n < N +func MulBySmallMonomialMod2N(mask uint64, pol *ring.Poly, n int) { + if n != 0 { + N := len(pol.Coeffs[0]) + pol.Coeffs[0] = append(pol.Coeffs[0][N-n:], pol.Coeffs[0][:N-n]...) + tmp := pol.Coeffs[0] + for j := 0; j < n; j++ { + tmp[j] = -tmp[j] & mask + } + } +} + +func normalizeInv(x, a, b float64) (y float64) { + return (x*(b-a) + b + a) / 2.0 +} + +func scaleUp(value float64, scale float64, Q uint64) (res uint64) { + + var isNegative bool + var xFlo *big.Float + var xInt *big.Int + + isNegative = false + if value < 0 { + isNegative = true + xFlo = big.NewFloat(-scale * value) + } else { + xFlo = big.NewFloat(scale * value) + } + + xFlo.Add(xFlo, big.NewFloat(0.5)) + + xInt = new(big.Int) + xFlo.Int(xInt) + xInt.Mod(xInt, ring.NewUint(Q)) + + res = xInt.Uint64() + + if isNegative { + res = Q - res + } + + return +} diff --git a/rlwe/rgsw/ciphertext.go b/rlwe/rgsw/ciphertext.go index 72564652..feb017bb 100644 --- a/rlwe/rgsw/ciphertext.go +++ b/rlwe/rgsw/ciphertext.go @@ -1,8 +1,8 @@ package rgsw import ( - "github.com/tuneinsight/lattigo/v3/rlwe" "github.com/tuneinsight/lattigo/v3/rlwe/gadget" + "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" ) // Ciphertext is a generic type for RGSW ciphertext. @@ -21,12 +21,11 @@ func (ct *Ciphertext) LevelP() int { } // NewCiphertextNTT allocates a new RGSW ciphertext in the NTT domain. -func NewCiphertextNTT(params rlwe.Parameters, levelQ int) (ct *Ciphertext) { - levelP := params.PCount() - 1 +func NewCiphertextNTT(levelQ, levelP, decompRNS, decompBit int, ringQP *ringqp.Ring) (ct *Ciphertext) { return &Ciphertext{ Value: [2]gadget.Ciphertext{ - *gadget.NewCiphertextNTT(levelQ, levelP, params.DecompRNS(levelQ, levelP), params.DecompBIT(levelQ, levelP), *params.RingQP()), - *gadget.NewCiphertextNTT(levelQ, levelP, params.DecompRNS(levelQ, levelP), params.DecompBIT(levelQ, levelP), *params.RingQP()), + *gadget.NewCiphertextNTT(levelQ, levelP, decompRNS, decompBit, *ringQP), + *gadget.NewCiphertextNTT(levelQ, levelP, decompRNS, decompBit, *ringQP), }, } } diff --git a/rlwe/rgsw/encryptor.go b/rlwe/rgsw/encryptor.go deleted file mode 100644 index fd3a2549..00000000 --- a/rlwe/rgsw/encryptor.go +++ /dev/null @@ -1,202 +0,0 @@ -package rgsw - -import ( - "github.com/tuneinsight/lattigo/v3/ring" - "github.com/tuneinsight/lattigo/v3/rlwe" - "github.com/tuneinsight/lattigo/v3/rlwe/gadget" - "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" - "github.com/tuneinsight/lattigo/v3/utils" -) - -// Encryptor a generic RLWE encryption interface. -type Encryptor interface { - Encrypt(pt *rlwe.Plaintext, ct *Ciphertext) - ShallowCopy() Encryptor - WithKey(key interface{}) Encryptor -} - -type encryptor struct { - *encryptorBase - *encryptorSamplers - *encryptorBuffers -} - -type skEncryptor struct { - encryptor - sk *rlwe.SecretKey -} - -// NewEncryptor creates a new Encryptor -// Accepts either a secret-key or a public-key. -func NewEncryptor(params rlwe.Parameters, key interface{}) Encryptor { - enc := newEncryptor(params) - return enc.setKey(key) -} - -func newEncryptor(params rlwe.Parameters) encryptor { - return encryptor{ - encryptorBase: newEncryptorBase(params), - encryptorSamplers: newEncryptorSamplers(params), - encryptorBuffers: newEncryptorBuffers(params), - } -} - -// encryptorBase is a struct used to encrypt Plaintexts. It stores the public-key and/or secret-key. -type encryptorBase struct { - params rlwe.Parameters -} - -func newEncryptorBase(params rlwe.Parameters) *encryptorBase { - return &encryptorBase{params} -} - -type encryptorSamplers struct { - gaussianSampler *ring.GaussianSampler - ternarySampler *ring.TernarySampler - uniformSamplerQ *ring.UniformSampler - uniformSamplerP *ring.UniformSampler -} - -func newEncryptorSamplers(params rlwe.Parameters) *encryptorSamplers { - prng, err := utils.NewPRNG() - if err != nil { - panic(err) - } - - var uniformSamplerP *ring.UniformSampler - if params.PCount() != 0 { - uniformSamplerP = ring.NewUniformSampler(prng, params.RingP()) - } - - return &encryptorSamplers{ - gaussianSampler: ring.NewGaussianSampler(prng, params.RingQ(), params.Sigma(), int(6*params.Sigma())), - ternarySampler: ring.NewTernarySamplerWithHammingWeight(prng, params.RingQ(), params.HammingWeight(), false), - uniformSamplerQ: ring.NewUniformSampler(prng, params.RingQ()), - uniformSamplerP: uniformSamplerP, - } -} - -type encryptorBuffers struct { - poolQP ringqp.Poly -} - -func newEncryptorBuffers(params rlwe.Parameters) *encryptorBuffers { - return &encryptorBuffers{ - poolQP: params.RingQP().NewPoly(), - } -} - -// ShallowCopy creates a shallow copy of this skEncryptor in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// Encryptors can be used concurrently. -func (enc *skEncryptor) ShallowCopy() Encryptor { - return &skEncryptor{*enc.encryptor.ShallowCopy(), enc.sk} -} - -// ShallowCopy creates a shallow copy of this encryptor in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// Encryptors can be used concurrently. -func (enc *encryptor) ShallowCopy() *encryptor { - return &encryptor{ - encryptorBase: enc.encryptorBase, - encryptorSamplers: newEncryptorSamplers(enc.params), - encryptorBuffers: newEncryptorBuffers(enc.params), - } -} - -// WithKey creates a shallow copy of this encryptor with a new key in which all the read-only data-structures are -// shared with the receiver and the temporary buffers are reallocated. The receiver and the returned -// Encryptors can be used concurrently. -func (enc *encryptor) WithKey(key interface{}) Encryptor { - return enc.ShallowCopy().setKey(key) -} - -func (enc *encryptor) setKey(key interface{}) Encryptor { - switch key := key.(type) { - case *rlwe.SecretKey: - if key.Value.Q.Degree() != enc.params.N() { - panic("cannot setKey: sk ring degree does not match params ring degree") - } - return &skEncryptor{*enc, key} - default: - panic("cannot setKey: key must be *rlwe.SecretKey") - } -} - -func (enc *skEncryptor) Encrypt(pt *rlwe.Plaintext, ct *Ciphertext) { - - params := enc.params - ringQ := params.RingQ() - levelQ := ct.LevelQ() - levelP := ct.LevelP() - - decompRNS := params.DecompRNS(levelQ, levelP) - decompBIT := params.DecompBIT(levelQ, levelP) - - for j := 0; j < decompBIT; j++ { - for i := 0; i < decompRNS; i++ { - enc.encryptZeroSymetricQP(levelQ, levelP, enc.sk.Value, true, true, true, ct.Value[0].Value[i][j]) - enc.encryptZeroSymetricQP(levelQ, levelP, enc.sk.Value, true, true, true, ct.Value[1].Value[i][j]) - } - } - - if pt != nil { - ringQ.MFormLvl(levelQ, pt.Value, enc.poolQP.Q) - if !pt.Value.IsNTT { - ringQ.NTTLvl(levelQ, enc.poolQP.Q, enc.poolQP.Q) - } - gadget.AddPolyToGadgetMatrix( - enc.poolQP.Q, - []gadget.Ciphertext{ct.Value[0], ct.Value[1]}, - *params.RingQP(), - params.LogBase2(), - enc.poolQP.Q) - } -} - -func (enc *encryptor) encryptZeroSymetricQP(levelQ, levelP int, sk ringqp.Poly, sample, montgomery, ntt bool, ct [2]ringqp.Poly) { - - params := enc.params - ringQP := params.RingQP() - - hasModulusP := ct[0].P != nil - - if ntt { - enc.gaussianSampler.ReadLvl(levelQ, ct[0].Q) - - if hasModulusP { - ringQP.ExtendBasisSmallNormAndCenter(ct[0].Q, levelP, nil, ct[0].P) - } - - ringQP.NTTLvl(levelQ, levelP, ct[0], ct[0]) - } - - if sample { - enc.uniformSamplerQ.ReadLvl(levelQ, ct[1].Q) - - if hasModulusP { - enc.uniformSamplerP.ReadLvl(levelP, ct[1].P) - } - } - - ringQP.MulCoeffsMontgomeryAndSubLvl(levelQ, levelP, ct[1], sk, ct[0]) - - if !ntt { - ringQP.InvNTTLvl(levelQ, levelP, ct[0], ct[0]) - ringQP.InvNTTLvl(levelQ, levelP, ct[1], ct[1]) - - e := enc.poolQP - enc.gaussianSampler.ReadLvl(levelQ, e.Q) - - if hasModulusP { - ringQP.ExtendBasisSmallNormAndCenter(e.Q, levelP, nil, e.P) - } - - ringQP.AddLvl(levelQ, levelP, ct[0], e, ct[0]) - } - - if montgomery { - ringQP.MFormLvl(levelQ, levelP, ct[0], ct[0]) - ringQP.MFormLvl(levelQ, levelP, ct[1], ct[1]) - } -} diff --git a/rlwe/rgsw/operations.go b/rlwe/rgsw/operations.go new file mode 100644 index 00000000..5d18c544 --- /dev/null +++ b/rlwe/rgsw/operations.go @@ -0,0 +1,80 @@ +package rgsw + +import ( + "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" +) + +// AddNoModLvl adds op to ctOut, without modular reduction. +func AddNoModLvl(levelQ, levelP int, op interface{}, ringQP ringqp.Ring, ctOut *Ciphertext) { + switch el := op.(type) { + case *Plaintext: + + nQ := levelQ + 1 + nP := levelP + 1 + + if nP == 0 { + nP = 1 + } + + for i := range ctOut.Value[0].Value { + for j := range ctOut.Value[0].Value[i] { + start, end := i*nP, (i+1)*nP + if end > nQ { + end = nQ + } + for k := start; k < end; k++ { + ring.AddVecNoMod(ctOut.Value[0].Value[i][j][0].Q.Coeffs[k], el.Value[j].Coeffs[k], ctOut.Value[0].Value[i][j][0].Q.Coeffs[k]) + ring.AddVecNoMod(ctOut.Value[1].Value[i][j][1].Q.Coeffs[k], el.Value[j].Coeffs[k], ctOut.Value[1].Value[i][j][1].Q.Coeffs[k]) + } + } + } + case *Ciphertext: + for i := range el.Value[0].Value { + for j := range el.Value[0].Value[i] { + ringQP.AddNoModLvl(levelQ, levelP, ctOut.Value[0].Value[i][j][0], el.Value[0].Value[i][j][0], ctOut.Value[0].Value[i][j][0]) + ringQP.AddNoModLvl(levelQ, levelP, ctOut.Value[0].Value[i][j][1], el.Value[0].Value[i][j][1], ctOut.Value[0].Value[i][j][1]) + ringQP.AddNoModLvl(levelQ, levelP, ctOut.Value[1].Value[i][j][0], el.Value[1].Value[i][j][0], ctOut.Value[1].Value[i][j][0]) + ringQP.AddNoModLvl(levelQ, levelP, ctOut.Value[1].Value[i][j][1], el.Value[1].Value[i][j][1], ctOut.Value[1].Value[i][j][1]) + } + } + default: + panic("unsuported op.(type), must be either *rgsw.Plaintext or *rgsw.Ciphertext") + } +} + +// ReduceLvl applies the modular reduction on ctIn and returns the result on ctOut. +func ReduceLvl(levelQ, levelP int, ctIn *Ciphertext, ringQP ringqp.Ring, ctOut *Ciphertext) { + for i := range ctIn.Value[0].Value { + for j := range ctIn.Value[0].Value[i] { + ringQP.ReduceLvl(levelQ, levelP, ctIn.Value[0].Value[i][j][0], ctOut.Value[0].Value[i][j][0]) + ringQP.ReduceLvl(levelQ, levelP, ctIn.Value[0].Value[i][j][1], ctOut.Value[0].Value[i][j][1]) + ringQP.ReduceLvl(levelQ, levelP, ctIn.Value[1].Value[i][j][0], ctOut.Value[1].Value[i][j][0]) + ringQP.ReduceLvl(levelQ, levelP, ctIn.Value[1].Value[i][j][1], ctOut.Value[1].Value[i][j][1]) + } + } +} + +// MulByXPowAlphaMinusOneConstantLvl multiplies ctOut by (X^alpha - 1) and returns the result on ctOut. +func MulByXPowAlphaMinusOneConstantLvl(levelQ, levelP int, ctIn *Ciphertext, powXMinusOne ringqp.Poly, ringQP ringqp.Ring, ctOut *Ciphertext) { + for i := range ctIn.Value[0].Value { + for j := range ctIn.Value[0].Value[i] { + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, ctIn.Value[0].Value[i][j][0], powXMinusOne, ctOut.Value[0].Value[i][j][0]) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, ctIn.Value[0].Value[i][j][1], powXMinusOne, ctOut.Value[0].Value[i][j][1]) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, ctIn.Value[1].Value[i][j][0], powXMinusOne, ctOut.Value[1].Value[i][j][0]) + ringQP.MulCoeffsMontgomeryConstantLvl(levelQ, levelP, ctIn.Value[1].Value[i][j][1], powXMinusOne, ctOut.Value[1].Value[i][j][1]) + } + } +} + +// MulByXPowAlphaMinusOneAndAddNoModLvl multiplies ctOut by (X^alpha - 1) and adds the result on ctOut. +func MulByXPowAlphaMinusOneAndAddNoModLvl(levelQ, levelP int, ctIn *Ciphertext, powXMinusOne ringqp.Poly, ringQP ringqp.Ring, ctOut *Ciphertext) { + for i := range ctIn.Value[0].Value { + for j := range ctIn.Value[0].Value[i] { + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, ctIn.Value[0].Value[i][j][0], powXMinusOne, ctOut.Value[0].Value[i][j][0]) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, ctIn.Value[0].Value[i][j][1], powXMinusOne, ctOut.Value[0].Value[i][j][1]) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, ctIn.Value[1].Value[i][j][0], powXMinusOne, ctOut.Value[1].Value[i][j][0]) + ringQP.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, levelP, ctIn.Value[1].Value[i][j][1], powXMinusOne, ctOut.Value[1].Value[i][j][1]) + } + } +} diff --git a/rlwe/rgsw/plaintext.go b/rlwe/rgsw/plaintext.go new file mode 100644 index 00000000..e90158bd --- /dev/null +++ b/rlwe/rgsw/plaintext.go @@ -0,0 +1,70 @@ +package rgsw + +import ( + "github.com/tuneinsight/lattigo/v3/ring" + "github.com/tuneinsight/lattigo/v3/rlwe/ringqp" + "math/big" +) + +// Plaintext stores an RGSW plaintext value. +type Plaintext struct { + Value []*ring.Poly +} + +// NewPlaintext creates a new RGSW plaintext fron value, which can be either uint64, int64 or *ring.Poly. +// Plaintext is returned in the NTT and Mongtomery domain. +func NewPlaintext(value interface{}, levelQ, levelP, logBase2, decompBIT int, ringQP ringqp.Ring) (pt *Plaintext) { + + ringQ := ringQP.RingQ + + pt = new(Plaintext) + pt.Value = make([]*ring.Poly, decompBIT) + + switch el := value.(type) { + case uint64: + pt.Value[0] = ringQ.NewPolyLvl(levelQ) + for i := range ringQ.Modulus[:levelQ+1] { + pt.Value[0].Coeffs[i][0] = el + } + case int64: + pt.Value[0] = ringQ.NewPolyLvl(levelQ) + if el < 0 { + for i, qi := range ringQ.Modulus[:levelQ+1] { + pt.Value[0].Coeffs[i][0] = qi - uint64(-el) + } + } else { + for i := range ringQ.Modulus[:levelQ+1] { + pt.Value[0].Coeffs[i][0] = uint64(el) + } + } + case *ring.Poly: + pt.Value[0] = el.CopyNew() + default: + panic("unsupported type, must be wither uint64 or *ring.Poly") + } + + var pBigInt *big.Int + if levelP > -1 { + ringP := ringQP.RingP + if levelP == len(ringP.Modulus)-1 { + pBigInt = ringP.ModulusBigint + } else { + P := ringP.Modulus + pBigInt = new(big.Int).SetUint64(P[0]) + for i := 1; i < levelP+1; i++ { + pBigInt.Mul(pBigInt, ring.NewUint(P[i])) + } + } + ringQ.MulScalarBigintLvl(levelQ, pt.Value[0], pBigInt, pt.Value[0]) + } + + ringQ.NTTLvl(levelQ, pt.Value[0], pt.Value[0]) + ringQ.MFormLvl(levelQ, pt.Value[0], pt.Value[0]) + + for i := 1; i < len(pt.Value); i++ { + pt.Value[i] = pt.Value[0].CopyNew() + ringQ.MulByPow2Lvl(levelQ, pt.Value[i], i*logBase2, pt.Value[i]) + } + + return +} diff --git a/rlwe/ringqp/ringqp.go b/rlwe/ringqp/ringqp.go index 336af26d..af59132c 100644 --- a/rlwe/ringqp/ringqp.go +++ b/rlwe/ringqp/ringqp.go @@ -257,6 +257,17 @@ func (r *Ring) MulCoeffsMontgomeryAndAddLvl(levelQ, levelP int, p1, p2, p3 Poly) } } +// ReduceLvl applies the modular reduction on the coefficients of p1 and returns the result on p2. +// The operation is performed at levelQ for the ringQ and levelP for the ringP. +func (r *Ring) ReduceLvl(levelQ, levelP int, p1, p2 Poly) { + if r.RingQ != nil { + r.RingQ.ReduceLvl(levelQ, p1.Q, p2.Q) + } + if r.RingP != nil { + r.RingP.ReduceLvl(levelP, p1.P, p2.P) + } +} + // PermuteNTTWithIndexLvl applies the automorphism X^{5^j} on p1 and writes the result on p2. // Index of automorphism must be provided. // Method is not in place. diff --git a/rlwe/rlwe_test.go b/rlwe/rlwe_test.go index b0985594..8167b9e4 100644 --- a/rlwe/rlwe_test.go +++ b/rlwe/rlwe_test.go @@ -44,7 +44,7 @@ func TestRLWE(t *testing.T) { defaultParams = []ParametersLiteral{jsonParams} // the custom test suite reads the parameters from the -params flag } - for _, defaultParam := range defaultParams[:1] { + for _, defaultParam := range defaultParams[:] { params, err := NewParametersFromLiteral(defaultParam) if err != nil { panic(err) @@ -651,7 +651,7 @@ func testManyRLWEToSingleRLWE(kgen KeyGenerator, t *testing.T) { ciphertexts := make(map[int]*Ciphertext) slotIndex := make(map[int]bool) - for i := 0; i < params.N()/2; i += 64 { + for i := 0; i < params.N(); i += params.N() / 16 { ciphertexts[i] = NewCiphertextNTT(params, 1, params.MaxLevel()) encryptor.Encrypt(pt, ciphertexts[i]) slotIndex[i] = true