mirror of
https://github.com/tuneinsight/lattigo.git
synced 2025-09-13 03:27:14 +00:00
[ckks]: updated DecodePublic & updated SECURITY.md
This commit is contained in:
14
SECURITY.md
14
SECURITY.md
@@ -12,11 +12,15 @@ Classified as an _approximate decryption_ scheme, the CKKS scheme is secure as l
|
||||
This attack demonstrates that, when using an approximate homomorphic encryption scheme, the usual CPA security may not sufficient depending on the application setting. Many applications do not require to share the result with external parties and are not affected by this attack, but the ones that do must take the appropriate steps to ensure that no key-dependent information is leaked. A homomorphic encryption scheme that provides such functionality and that can be secure when releasing decrypted plaintext to external parties is defined to be CPA<sup>D</sup> secure. The corresponding indistinguishability notion (IND-CPA<sup>D</sup>) is defined as "indistinguishability under chosen plaintext attacks with decryption oracles."
|
||||
|
||||
# CPA<sup>D</sup> Security for CKKS
|
||||
Lattigo implements tools to mitigate _Li and Micciancio_'s attack. In particular, the decoding step of CKKS (and its real-number variant R-CKKS) allows the user to add a key-independent error $e$ of standard deviation $\sigma$ to the decrypted plaintext before decoding.
|
||||
Lattigo implements tools to mitigate _Li and Micciancio_'s attack. In particular, the decoding step of CKKS (and its real-number variant R-CKKS) allows the user to specify the desired fixed-point bit-precision.
|
||||
|
||||
If at any point of an application, decrypted values have to be shared with external parties, then the user must ensure that each shared plaintext is first _sanitized_ before being shared. To do so, the user must use the $\textsf{DecodePublic}$ method instead of the usual $\textsf{Decode}$. $\textsf{DecodePublic}$ takes as additional input $\sigma$, and samples a key-independent error $e$ with standard deviation $\sigma$, that is added to the plaintext before decoding.
|
||||
Let $\epsilon$ be the scheme error after the decoding step. We compute the bit precision of the output as $\log_{2}(1/\epsilon)$.
|
||||
|
||||
Estimating $\sigma$ must be done carefully and we suggest the following iterative process to do so:
|
||||
If at any point of an application, decrypted values have to be shared with external parties, then the user must ensure that each shared plaintext is first _sanitized_ before being shared. To do so, the user must use the $\textsf{DecodePublic}$ method instead of the usual $\textsf{Decode}$. $\textsf{DecodePublic}$ takes as additional input the desired $\log_{2}(1/\epsilon)$-bit precision and rounds the value by evaluating $y = \lfloor x / \epsilon \rceil \cdot \epsilon$.
|
||||
|
||||
Estimating $E[\epsilon]$ of the circuit must be done carefully and we suggest the following iterative process to do so:
|
||||
1. Given a security parameter $\lambda$ and a circuit $C$ that takes as inputs length-_n_ vectors $\omega$ following a distribution $\chi$, select the appropriate parameters enabling the homomorphic evaluation of $C(\omega)$, denoted by $H(C(\omega))$, which includes the encoding, encryption, evaluation, decryption and decoding.
|
||||
2. Sample input vectors $\omega$ from the distribution $\chi$ and compute the standard deviation $\sigma$ in the time domain (coefficient domain) of $e=C(\omega) - H(C(\omega))$. This can be done using the encoder method $\textsf{GetErrSTDTimeDom}(C(\omega), H(C(\omega)), \Delta)$, where $\Delta$ is the scale of the plaintext after the decryption. The user should make sure that the underlying circuit computed by $H(C(\cdot))$ is identical to $C(\cdot)$; i.e., if the homomorphic implementation $H(C(\cdot))$ uses polynomial approximations, then $C(\cdot)$ should use them too, instead of using the original exact function. This will ensure that $e$, and therefore $\sigma$, are as close as possible to the actual underlying scheme error, and not influenced by function-approximation errors.
|
||||
3. Use the encoder method $\textsf{DecodePublic}$ with the parameter $\sigma$ to decode plaintexts that will be published. $\textsf{DecodePublic}$ adds an error $e$ with standard deviation $\sigma$ bounded by $B = \sigma\sqrt{2\pi}$. The precision loss, compared to a private decoding, should be less than half a bit on average.
|
||||
2. Sample input vectors $\omega$ from the distribution $\chi$ and record $\epsilon=C(\omega) - H(C(\omega))$. The user should make sure that the underlying circuit computed by $H(C(\cdot))$ is identical to $C(\cdot)$; i.e., if the homomorphic implementation $H(C(\cdot))$ uses polynomial approximations, then $C(\cdot)$ should use them too, instead of using the original exact function. Repeat until $\epsilon$ reaches a stable value.
|
||||
3. Use the encoder method $\textsf{DecodePublic}$ with the parameter $\log_{2}(1/\epsilon)$ to decode plaintexts that will be published. $\textsf{DecodePublic}$ will round the values to $\log_{2}(1/\epsilon)$-bits of precision.
|
||||
|
||||
Note that, for composability with differential privacy, the variance of the error introduced by the rounding is $\text{Var}[x - \lfloor x \cdot \epsilon \rceil / \epsilon] = \tfrac{\epsilon}{12}$ and therefore $\text{Var}[x - \lfloor x/(\sigma\sqrt{12})\rceil\cdot(\sigma\sqrt{12})] = \sigma^2$.
|
||||
|
||||
@@ -358,7 +358,7 @@ func TestBootstrapping(t *testing.T) {
|
||||
}
|
||||
|
||||
func verifyTestVectorsBootstrapping(params ckks.Parameters, encoder *ckks.Encoder, decryptor *rlwe.Decryptor, valuesWant, element interface{}, t *testing.T) {
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, element, nil, false)
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, element, 0, false)
|
||||
if *printPrecisionStats {
|
||||
t.Log(precStats.String())
|
||||
}
|
||||
|
||||
@@ -309,7 +309,7 @@ func testBootstrapHighPrecision(paramSet defaultParametersLiteral, t *testing.T)
|
||||
}
|
||||
|
||||
func verifyTestVectors(params ckks.Parameters, encoder *ckks.Encoder, decryptor *rlwe.Decryptor, valuesWant, valuesHave interface{}, t *testing.T) {
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, valuesHave, nil, false)
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, decryptor, valuesWant, valuesHave, 0, false)
|
||||
if *printPrecisionStats {
|
||||
t.Log(precStats.String())
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func TestComparisons(t *testing.T) {
|
||||
want[i] = polys.Evaluate(values[i])[0]
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, ecd, nil, want, have, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd, nil, want, have, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "Step"), func(t *testing.T) {
|
||||
@@ -95,7 +95,7 @@ func TestComparisons(t *testing.T) {
|
||||
want[i].Add(want[i], half)
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, ecd, nil, want, have, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd, nil, want, have, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "Max"), func(t *testing.T) {
|
||||
@@ -122,7 +122,7 @@ func TestComparisons(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, ecd, nil, want, have, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd, nil, want, have, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "Min"), func(t *testing.T) {
|
||||
@@ -149,7 +149,7 @@ func TestComparisons(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, ecd, nil, want, have, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd, nil, want, have, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,7 +232,7 @@ func testHomomorphicEncoding(params ckks.Parameters, LogSlots int, t *testing.T)
|
||||
}
|
||||
|
||||
// Compares
|
||||
ckks.VerifyTestVectors(params, ecd2N, nil, want, have, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd2N, nil, want, have, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
|
||||
} else {
|
||||
|
||||
@@ -276,8 +276,8 @@ func testHomomorphicEncoding(params ckks.Parameters, LogSlots int, t *testing.T)
|
||||
wantImag[i], wantImag[j] = vec1[i][0], vec1[i][1]
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, ecd2N, nil, wantReal, haveReal, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd2N, nil, wantImag, haveImag, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd2N, nil, wantReal, haveReal, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd2N, nil, wantImag, haveImag, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -424,6 +424,6 @@ func testHomomorphicDecoding(params ckks.Parameters, LogSlots int, t *testing.T)
|
||||
// Result is bit-reversed, so applies the bit-reverse permutation on the reference vector
|
||||
utils.BitReverseInPlaceSlice(valuesReal, slots)
|
||||
|
||||
ckks.VerifyTestVectors(params, encoder, decryptor, valuesReal, valuesTest, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, encoder, decryptor, valuesReal, valuesTest, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -202,7 +202,7 @@ func testCKKSLinearTransformation(tc *ckksTestContext, t *testing.T) {
|
||||
values[i][1].Quo(values[i][1], nB)
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "LinearTransform/BSGS=True"), func(t *testing.T) {
|
||||
@@ -263,7 +263,7 @@ func testCKKSLinearTransformation(tc *ckksTestContext, t *testing.T) {
|
||||
values[i].Add(values[i], tmp[(i+15)%slots])
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "LinearTransform/BSGS=False"), func(t *testing.T) {
|
||||
@@ -324,7 +324,7 @@ func testCKKSLinearTransformation(tc *ckksTestContext, t *testing.T) {
|
||||
values[i].Add(values[i], tmp[(i+15)%slots])
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -367,7 +367,7 @@ func testEvaluatePolynomial(tc *ckksTestContext, t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "Polynomial/PolyVector/Exp"), func(t *testing.T) {
|
||||
@@ -415,6 +415,6 @@ func testEvaluatePolynomial(tc *ckksTestContext, t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, valuesWant, ciphertext, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, valuesWant, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func TestInverse(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, 70, nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptor, values, ciphertext, 70, 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "PositiveDomain"), func(t *testing.T) {
|
||||
@@ -103,7 +103,7 @@ func TestInverse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, nil, want, have, 70, nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, nil, want, have, 70, 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "NegativeDomain"), func(t *testing.T) {
|
||||
@@ -130,7 +130,7 @@ func TestInverse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, nil, want, have, 70, nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, nil, want, have, 70, 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(params, "FullDomain"), func(t *testing.T) {
|
||||
@@ -157,7 +157,7 @@ func TestInverse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, nil, want, have, 70, nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, nil, want, have, 70, 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func testMod1(params ckks.Parameters, t *testing.T) {
|
||||
|
||||
values, ciphertext := evaluateMod1(evm, params, ecd, enc, eval, t)
|
||||
|
||||
ckks.VerifyTestVectors(params, ecd, dec, values, ciphertext, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd, dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run("CosDiscrete", func(t *testing.T) {
|
||||
@@ -108,7 +108,7 @@ func testMod1(params ckks.Parameters, t *testing.T) {
|
||||
|
||||
values, ciphertext := evaluateMod1(evm, params, ecd, enc, eval, t)
|
||||
|
||||
ckks.VerifyTestVectors(params, ecd, dec, values, ciphertext, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd, dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run("CosContinuous", func(t *testing.T) {
|
||||
@@ -125,7 +125,7 @@ func testMod1(params ckks.Parameters, t *testing.T) {
|
||||
|
||||
values, ciphertext := evaluateMod1(evm, params, ecd, enc, eval, t)
|
||||
|
||||
ckks.VerifyTestVectors(params, ecd, dec, values, ciphertext, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, ecd, dec, values, ciphertext, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -276,7 +276,88 @@ func testEncoder(tc *testContext, t *testing.T) {
|
||||
|
||||
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, nil, values, plaintext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, nil, values, plaintext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
logprec := float64(tc.params.LogDefaultScale()) / 2
|
||||
|
||||
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true/DecodePublic/[]float64"), func(t *testing.T) {
|
||||
|
||||
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
|
||||
|
||||
have := make([]float64, len(values))
|
||||
|
||||
require.NoError(t, tc.encoder.DecodePublic(plaintext, have, logprec))
|
||||
|
||||
want := make([]float64, len(values))
|
||||
for i := range want {
|
||||
want[i], _ = values[i][0].Float64()
|
||||
want[i] -= have[i]
|
||||
}
|
||||
|
||||
// Allows for a 10% error over the expected standard deviation of the error
|
||||
require.GreaterOrEqual(t, StandardDeviation(want, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true/DecodePublic/[]complex128"), func(t *testing.T) {
|
||||
|
||||
if tc.params.RingType() == ring.ConjugateInvariant {
|
||||
t.Skip()
|
||||
}
|
||||
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
|
||||
|
||||
have := make([]complex128, len(values))
|
||||
require.NoError(t, tc.encoder.DecodePublic(plaintext, have, logprec))
|
||||
|
||||
wantReal := make([]float64, len(values))
|
||||
wantImag := make([]float64, len(values))
|
||||
|
||||
for i := range have {
|
||||
wantReal[i], _ = values[i][0].Float64()
|
||||
wantImag[i], _ = values[i][1].Float64()
|
||||
|
||||
wantReal[i] -= real(have[i])
|
||||
wantImag[i] -= imag(have[i])
|
||||
}
|
||||
|
||||
// Allows for a 10% error over the expected standard deviation of the error
|
||||
require.GreaterOrEqual(t, StandardDeviation(wantReal, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
|
||||
require.GreaterOrEqual(t, StandardDeviation(wantImag, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true/DecodePublic/[]big.Float"), func(t *testing.T) {
|
||||
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
|
||||
have := make([]*big.Float, len(values))
|
||||
require.NoError(t, tc.encoder.DecodePublic(plaintext, have, logprec))
|
||||
|
||||
want := make([]*big.Float, len(values))
|
||||
for i := range want {
|
||||
want[i] = values[i][0].Sub(values[i][0], have[i])
|
||||
}
|
||||
|
||||
// Allows for a 10% error over the expected standard deviation of the error
|
||||
require.GreaterOrEqual(t, StandardDeviation(want, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Encoder/IsBatched=true/DecodePublic/[]bignum.Complex"), func(t *testing.T) {
|
||||
if tc.params.RingType() == ring.ConjugateInvariant {
|
||||
t.Skip()
|
||||
}
|
||||
values, plaintext, _ := newTestVectors(tc, nil, -1-1i, 1+1i, t)
|
||||
have := make([]*bignum.Complex, len(values))
|
||||
require.NoError(t, tc.encoder.DecodePublic(plaintext, have, logprec))
|
||||
|
||||
wantReal := make([]*big.Float, len(values))
|
||||
wantImag := make([]*big.Float, len(values))
|
||||
|
||||
for i := range have {
|
||||
wantReal[i] = values[i][0].Sub(values[i][0], have[i][0])
|
||||
wantImag[i] = values[i][1].Sub(values[i][1], have[i][1])
|
||||
}
|
||||
|
||||
// Allows for a 10% error over the expected standard deviation of the error
|
||||
require.GreaterOrEqual(t, StandardDeviation(wantReal, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
|
||||
require.GreaterOrEqual(t, StandardDeviation(wantImag, rlwe.NewScale(1)), math.Exp2(-logprec)/math.Sqrt(12)*0.9)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Encoder/IsBatched=false"), func(t *testing.T) {
|
||||
@@ -336,7 +417,7 @@ func testEvaluatorAdd(tc *testContext, t *testing.T) {
|
||||
ciphertext3, err := tc.evaluator.AddNew(ciphertext1, ciphertext2)
|
||||
require.NoError(t, err)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext3, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext3, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Add/Ct"), func(t *testing.T) {
|
||||
@@ -350,7 +431,7 @@ func testEvaluatorAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Add(ciphertext1, ciphertext2, ciphertext1))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Add/Pt"), func(t *testing.T) {
|
||||
@@ -364,7 +445,7 @@ func testEvaluatorAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Add(ciphertext1, plaintext2, ciphertext1))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Add/Scalar"), func(t *testing.T) {
|
||||
@@ -379,7 +460,7 @@ func testEvaluatorAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Add(ciphertext, constant, ciphertext))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Add/Vector"), func(t *testing.T) {
|
||||
@@ -393,7 +474,7 @@ func testEvaluatorAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Add(ciphertext, values2, ciphertext))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -411,7 +492,7 @@ func testEvaluatorSub(tc *testContext, t *testing.T) {
|
||||
ciphertext3, err := tc.evaluator.SubNew(ciphertext1, ciphertext2)
|
||||
require.NoError(t, err)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext3, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext3, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Sub/Ct"), func(t *testing.T) {
|
||||
@@ -425,7 +506,7 @@ func testEvaluatorSub(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Sub(ciphertext1, ciphertext2, ciphertext1))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Sub/Pt"), func(t *testing.T) {
|
||||
@@ -441,7 +522,7 @@ func testEvaluatorSub(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Sub(ciphertext1, plaintext2, ciphertext2))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, valuesTest, ciphertext2, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, valuesTest, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Sub/Scalar"), func(t *testing.T) {
|
||||
@@ -456,7 +537,7 @@ func testEvaluatorSub(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Sub(ciphertext, constant, ciphertext))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Sub/Vector"), func(t *testing.T) {
|
||||
@@ -470,7 +551,7 @@ func testEvaluatorSub(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Sub(ciphertext, values2, ciphertext))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -494,7 +575,7 @@ func testEvaluatorRescale(tc *testContext, t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/RescaleTo/Many"), func(t *testing.T) {
|
||||
@@ -520,7 +601,7 @@ func testEvaluatorRescale(tc *testContext, t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -539,7 +620,7 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
|
||||
ciphertext2, err := tc.evaluator.MulNew(ciphertext1, plaintext1)
|
||||
require.NoError(t, err)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext2, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Scalar"), func(t *testing.T) {
|
||||
@@ -556,7 +637,7 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.Mul(ciphertext, constant, ciphertext))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Vector"), func(t *testing.T) {
|
||||
@@ -572,7 +653,7 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
|
||||
|
||||
tc.evaluator.Mul(ciphertext, values2, ciphertext)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Pt"), func(t *testing.T) {
|
||||
@@ -587,7 +668,7 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/Mul/Ct/Ct/Degree0"), func(t *testing.T) {
|
||||
@@ -607,7 +688,7 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/MulRelin/Ct/Ct"), func(t *testing.T) {
|
||||
@@ -625,7 +706,7 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
|
||||
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1))
|
||||
require.Equal(t, ciphertext1.Degree(), 1)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
|
||||
// op1 <- op0 * op1
|
||||
values1, _, ciphertext1 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
|
||||
@@ -638,7 +719,7 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
|
||||
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext2))
|
||||
require.Equal(t, ciphertext2.Degree(), 1)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
|
||||
// op0 <- op0 * op0
|
||||
for i := range values1 {
|
||||
@@ -648,7 +729,7 @@ func testEvaluatorMul(tc *testContext, t *testing.T) {
|
||||
require.NoError(t, tc.evaluator.MulRelin(ciphertext1, ciphertext1, ciphertext1))
|
||||
require.Equal(t, ciphertext1.Degree(), 1)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -674,7 +755,7 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.NoError(t, tc.evaluator.MulThenAdd(ciphertext1, constant, ciphertext2))
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext2, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Vector"), func(t *testing.T) {
|
||||
@@ -697,7 +778,7 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.Equal(t, ciphertext1.Degree(), 1)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/MulThenAdd/Pt"), func(t *testing.T) {
|
||||
@@ -720,7 +801,7 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.Equal(t, ciphertext1.Degree(), 1)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
|
||||
t.Run(GetTestName(tc.params, "Evaluator/MulRelinThenAdd/Ct"), func(t *testing.T) {
|
||||
@@ -747,7 +828,7 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.Equal(t, ciphertext3.Degree(), 1)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext3, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values2, ciphertext3, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
|
||||
// op1 = op1 + op0*op0
|
||||
values1, _, ciphertext1 = newTestVectors(tc, tc.encryptorSk, -1-1i, 1+1i, t)
|
||||
@@ -763,7 +844,7 @@ func testEvaluatorMulThenAdd(tc *testContext, t *testing.T) {
|
||||
|
||||
require.Equal(t, ciphertext1.Degree(), 1)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values1, ciphertext1, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -810,7 +891,7 @@ func testBridge(tc *testContext, t *testing.T) {
|
||||
|
||||
switcher.RealToComplex(evalStandar, ctCI, stdCTHave)
|
||||
|
||||
VerifyTestVectors(stdParams, stdEncoder, stdDecryptor, values, stdCTHave, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(stdParams, stdEncoder, stdDecryptor, values, stdCTHave, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
|
||||
stdCTImag, err := stdEvaluator.MulNew(stdCTHave, 1i)
|
||||
require.NoError(t, err)
|
||||
@@ -819,6 +900,6 @@ func testBridge(tc *testContext, t *testing.T) {
|
||||
ciCTHave := NewCiphertext(ciParams, 1, stdCTHave.Level())
|
||||
switcher.ComplexToReal(evalStandar, stdCTHave, ciCTHave)
|
||||
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciCTHave, tc.params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
VerifyTestVectors(tc.params, tc.encoder, tc.decryptor, values, ciCTHave, tc.params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
124
ckks/encoder.go
124
ckks/encoder.go
@@ -2,6 +2,7 @@ package ckks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
|
||||
"github.com/tuneinsight/lattigo/v4/ring"
|
||||
@@ -179,14 +180,14 @@ func (ecd Encoder) Encode(values FloatSlice, pt *rlwe.Plaintext) (err error) {
|
||||
|
||||
// Decode decodes the input plaintext on a new FloatSlice.
|
||||
func (ecd Encoder) Decode(pt *rlwe.Plaintext, values FloatSlice) (err error) {
|
||||
return ecd.DecodePublic(pt, values, nil)
|
||||
return ecd.DecodePublic(pt, values, 0)
|
||||
}
|
||||
|
||||
// DecodePublic decodes the input plaintext on a FloatSlice.
|
||||
// It adds, before the decoding step (i.e. in the Ring) noise that follows the given distribution parameters.
|
||||
// If the underlying ringType is ConjugateInvariant, the imaginary part (and its related error) are zero.
|
||||
func (ecd Encoder) DecodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFlooding ring.DistributionParameters) (err error) {
|
||||
return ecd.decodePublic(pt, values, noiseFlooding)
|
||||
func (ecd Encoder) DecodePublic(pt *rlwe.Plaintext, values FloatSlice, logprec float64) (err error) {
|
||||
return ecd.decodePublic(pt, values, logprec)
|
||||
}
|
||||
|
||||
// Embed is a generic method to encode a FloatSlice on the target polyOut.
|
||||
@@ -477,7 +478,7 @@ func (ecd Encoder) plaintextToFloat(level int, scale rlwe.Scale, logSlots int, p
|
||||
|
||||
// decodePublic decode a plaintext to a FloatSlice.
|
||||
// The method will add a flooding noise before the decoding process following the defined distribution if it is not nil.
|
||||
func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFlooding ring.DistributionParameters) (err error) {
|
||||
func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, logprec float64) (err error) {
|
||||
|
||||
logSlots := pt.LogDimensions.Cols
|
||||
slots := 1 << logSlots
|
||||
@@ -492,16 +493,6 @@ func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFloo
|
||||
ecd.buff.CopyLvl(pt.Level(), pt.Value)
|
||||
}
|
||||
|
||||
if noiseFlooding != nil {
|
||||
Xe, err := ring.NewSampler(ecd.prng, ecd.parameters.RingQ(), noiseFlooding, pt.IsMontgomery)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot decode: noise flooding: %w", err)
|
||||
}
|
||||
|
||||
Xe.AtLevel(pt.Level()).ReadAndAdd(ecd.buff)
|
||||
}
|
||||
|
||||
switch values.(type) {
|
||||
case []complex128, []float64, []*bignum.Complex, []*big.Float:
|
||||
default:
|
||||
@@ -522,6 +513,22 @@ func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFloo
|
||||
return
|
||||
}
|
||||
|
||||
if logprec != 0 {
|
||||
|
||||
scale := math.Exp2(logprec)
|
||||
|
||||
switch values.(type) {
|
||||
case []*bignum.Complex, []complex128:
|
||||
for i := 0; i < slots; i++ {
|
||||
buffCmplx[i] = complex(math.Round(real(buffCmplx[i])*scale)/scale, math.Round(imag(buffCmplx[i])*scale)/scale)
|
||||
}
|
||||
default:
|
||||
for i := 0; i < slots; i++ {
|
||||
buffCmplx[i] = complex(math.Round(real(buffCmplx[i])*scale)/scale, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch values := values.(type) {
|
||||
case []float64:
|
||||
|
||||
@@ -530,10 +537,11 @@ func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFloo
|
||||
for i := 0; i < slots; i++ {
|
||||
values[i] = real(buffCmplx[i])
|
||||
}
|
||||
|
||||
case []complex128:
|
||||
copy(values, buffCmplx)
|
||||
|
||||
case []*big.Float:
|
||||
|
||||
slots := utils.Min(len(values), slots)
|
||||
|
||||
for i := 0; i < slots; i++ {
|
||||
@@ -582,6 +590,20 @@ func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFloo
|
||||
return
|
||||
}
|
||||
|
||||
var scale, half, zero *big.Float
|
||||
var tmp *big.Int
|
||||
if logprec != 0 {
|
||||
|
||||
// 2^logprec
|
||||
scale = new(big.Float).SetPrec(ecd.Prec()).SetFloat64(logprec)
|
||||
scale.Mul(scale, bignum.Log2(ecd.Prec()))
|
||||
scale = bignum.Exp(scale)
|
||||
|
||||
tmp = new(big.Int)
|
||||
half = new(big.Float).SetFloat64(0.5)
|
||||
zero = new(big.Float)
|
||||
}
|
||||
|
||||
switch values := values.(type) {
|
||||
case []float64:
|
||||
|
||||
@@ -591,6 +613,15 @@ func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFloo
|
||||
values[i], _ = buffCmplx[i][0].Float64()
|
||||
}
|
||||
|
||||
if logprec != 0 {
|
||||
|
||||
scaleF64, _ := scale.Float64()
|
||||
|
||||
for i := 0; i < slots; i++ {
|
||||
values[i] = math.Round(values[i]*scaleF64) / scaleF64
|
||||
}
|
||||
}
|
||||
|
||||
case []complex128:
|
||||
|
||||
slots := utils.Min(len(values), slots)
|
||||
@@ -599,6 +630,15 @@ func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFloo
|
||||
values[i] = buffCmplx[i].Complex128()
|
||||
}
|
||||
|
||||
if logprec != 0 {
|
||||
|
||||
scaleF64, _ := scale.Float64()
|
||||
|
||||
for i := 0; i < slots; i++ {
|
||||
values[i] = complex(math.Round(real(values[i])*scaleF64)/scaleF64, math.Round(imag(values[i])*scaleF64)/scaleF64)
|
||||
}
|
||||
}
|
||||
|
||||
case []*big.Float:
|
||||
slots := utils.Min(len(values), slots)
|
||||
|
||||
@@ -610,6 +650,25 @@ func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFloo
|
||||
|
||||
values[i].Set(buffCmplx[i][0])
|
||||
}
|
||||
if logprec != 0 {
|
||||
for i := range values {
|
||||
values[i].Mul(values[i], scale)
|
||||
|
||||
// Adds/Subtracts 0.5
|
||||
if values[i].Cmp(zero) >= 0 {
|
||||
values[i].Add(values[i], half)
|
||||
} else {
|
||||
values[i].Sub(values[i], half)
|
||||
}
|
||||
|
||||
// Round = floor +/- 0.5
|
||||
values[i].Int(tmp)
|
||||
|
||||
values[i].SetInt(tmp)
|
||||
|
||||
values[i].Quo(values[i], scale)
|
||||
}
|
||||
}
|
||||
|
||||
case []*bignum.Complex:
|
||||
|
||||
@@ -635,6 +694,41 @@ func (ecd Encoder) decodePublic(pt *rlwe.Plaintext, values FloatSlice, noiseFloo
|
||||
values[i][0].Set(buffCmplx[i][0])
|
||||
values[i][1].Set(buffCmplx[i][1])
|
||||
}
|
||||
|
||||
if logprec != 0 {
|
||||
for i := range values {
|
||||
|
||||
// Real
|
||||
values[i][0].Mul(values[i][0], scale)
|
||||
|
||||
// Adds/Subtracts 0.5
|
||||
if values[i][0].Cmp(zero) >= 0 {
|
||||
values[i][0].Add(values[i][0], half)
|
||||
} else {
|
||||
values[i][0].Sub(values[i][0], half)
|
||||
}
|
||||
|
||||
// Round = floor +/- 0.5
|
||||
values[i][0].Int(tmp)
|
||||
values[i][0].SetInt(tmp)
|
||||
values[i][0].Quo(values[i][0], scale)
|
||||
|
||||
// Imag
|
||||
values[i][1].Mul(values[i][1], scale)
|
||||
|
||||
// Adds/Subtracts 0.5
|
||||
if values[i][1].Cmp(zero) >= 0 {
|
||||
values[i][1].Add(values[i][1], half)
|
||||
} else {
|
||||
values[i][1].Sub(values[i][1], half)
|
||||
}
|
||||
|
||||
// Round = floor +/- 0.5
|
||||
values[i][1].Int(tmp)
|
||||
values[i][1].SetInt(tmp)
|
||||
values[i][1].Quo(values[i][1], scale)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -58,18 +58,18 @@ func (prec PrecisionStats) String() string {
|
||||
// GetPrecisionStats generates a PrecisionStats struct from the reference values and the decrypted values
|
||||
// vWant.(type) must be either []complex128 or []float64
|
||||
// element.(type) must be either *Plaintext, *Ciphertext, []complex128 or []float64. If not *Ciphertext, then decryptor can be nil.
|
||||
func GetPrecisionStats(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, want, have interface{}, noiseFlooding ring.DistributionParameters, computeDCF bool) (prec PrecisionStats) {
|
||||
func GetPrecisionStats(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, want, have interface{}, logprec float64, computeDCF bool) (prec PrecisionStats) {
|
||||
|
||||
if encoder.Prec() <= 53 {
|
||||
return getPrecisionStatsF64(params, encoder, decryptor, want, have, noiseFlooding, computeDCF)
|
||||
return getPrecisionStatsF64(params, encoder, decryptor, want, have, logprec, computeDCF)
|
||||
}
|
||||
|
||||
return getPrecisionStatsF128(params, encoder, decryptor, want, have, noiseFlooding, computeDCF)
|
||||
return getPrecisionStatsF128(params, encoder, decryptor, want, have, logprec, computeDCF)
|
||||
}
|
||||
|
||||
func VerifyTestVectors(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, valuesWant, valuesHave interface{}, log2MinPrec int, noise ring.DistributionParameters, printPrecisionStats bool, t *testing.T) {
|
||||
func VerifyTestVectors(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, valuesWant, valuesHave interface{}, log2MinPrec int, logprec float64, printPrecisionStats bool, t *testing.T) {
|
||||
|
||||
precStats := GetPrecisionStats(params, encoder, decryptor, valuesWant, valuesHave, noise, false)
|
||||
precStats := GetPrecisionStats(params, encoder, decryptor, valuesWant, valuesHave, logprec, false)
|
||||
|
||||
if printPrecisionStats {
|
||||
t.Log(precStats.String())
|
||||
@@ -92,7 +92,7 @@ func VerifyTestVectors(params Parameters, encoder *Encoder, decryptor *rlwe.Decr
|
||||
require.GreaterOrEqual(t, if64, float64(log2MinPrec))
|
||||
}
|
||||
|
||||
func getPrecisionStatsF64(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, want, have interface{}, noiseFlooding ring.DistributionParameters, computeDCF bool) (prec PrecisionStats) {
|
||||
func getPrecisionStatsF64(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, want, have interface{}, logprec float64, computeDCF bool) (prec PrecisionStats) {
|
||||
|
||||
precision := encoder.Prec()
|
||||
|
||||
@@ -128,12 +128,12 @@ func getPrecisionStatsF64(params Parameters, encoder *Encoder, decryptor *rlwe.D
|
||||
|
||||
switch have := have.(type) {
|
||||
case *rlwe.Ciphertext:
|
||||
if err := encoder.DecodePublic(decryptor.DecryptNew(have), valuesHave, noiseFlooding); err != nil {
|
||||
if err := encoder.DecodePublic(decryptor.DecryptNew(have), valuesHave, logprec); err != nil {
|
||||
// Sanity check, this error should never happen.
|
||||
panic(err)
|
||||
}
|
||||
case *rlwe.Plaintext:
|
||||
if err := encoder.DecodePublic(have, valuesHave, noiseFlooding); err != nil {
|
||||
if err := encoder.DecodePublic(have, valuesHave, logprec); err != nil {
|
||||
// Sanity check, this error should never happen.
|
||||
panic(err)
|
||||
}
|
||||
@@ -328,7 +328,7 @@ func calcmedianF64(values []struct{ Real, Imag, L2 float64 }) (median Stats) {
|
||||
}
|
||||
}
|
||||
|
||||
func getPrecisionStatsF128(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, want, have interface{}, noiseFlooding ring.DistributionParameters, computeDCF bool) (prec PrecisionStats) {
|
||||
func getPrecisionStatsF128(params Parameters, encoder *Encoder, decryptor *rlwe.Decryptor, want, have interface{}, logprec float64, computeDCF bool) (prec PrecisionStats) {
|
||||
precision := encoder.Prec()
|
||||
|
||||
var valuesWant []*bignum.Complex
|
||||
@@ -372,13 +372,13 @@ func getPrecisionStatsF128(params Parameters, encoder *Encoder, decryptor *rlwe.
|
||||
switch have := have.(type) {
|
||||
case *rlwe.Ciphertext:
|
||||
valuesHave = make([]*bignum.Complex, len(valuesWant))
|
||||
if err := encoder.DecodePublic(decryptor.DecryptNew(have), valuesHave, noiseFlooding); err != nil {
|
||||
if err := encoder.DecodePublic(decryptor.DecryptNew(have), valuesHave, logprec); err != nil {
|
||||
// Sanity check, this error should never happen.
|
||||
panic(err)
|
||||
}
|
||||
case *rlwe.Plaintext:
|
||||
valuesHave = make([]*bignum.Complex, len(valuesWant))
|
||||
if err := encoder.DecodePublic(have, valuesHave, noiseFlooding); err != nil {
|
||||
if err := encoder.DecodePublic(have, valuesHave, logprec); err != nil {
|
||||
// Sanity check, this error should never happen.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -221,7 +221,7 @@ func testEncToShareProtocol(tc *testContext, t *testing.T) {
|
||||
pt.Scale = ciphertext.Scale
|
||||
tc.ringQ.AtLevel(pt.Level()).SetCoefficientsBigint(rec.Value, pt.Value)
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, nil, coeffs, pt, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, nil, coeffs, pt, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
|
||||
crp := P[0].s2e.SampleCRP(params.MaxLevel(), tc.crs)
|
||||
|
||||
@@ -236,7 +236,7 @@ func testEncToShareProtocol(tc *testContext, t *testing.T) {
|
||||
ctRec.Scale = params.DefaultScale()
|
||||
P[0].s2e.GetEncryption(P[0].publicShareS2E, crp, ctRec)
|
||||
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptorSk0, coeffs, ctRec, params.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(params, tc.encoder, tc.decryptorSk0, coeffs, ctRec, params.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -464,7 +464,7 @@ func testRefreshParameterized(tc *testContext, paramsOut ckks.Parameters, skOut
|
||||
transform.Func(coeffs)
|
||||
}
|
||||
|
||||
ckks.VerifyTestVectors(paramsOut, ckks.NewEncoder(paramsOut), ckks.NewDecryptor(paramsOut, skIdealOut), coeffs, ciphertext, paramsOut.LogDefaultScale(), nil, *printPrecisionStats, t)
|
||||
ckks.VerifyTestVectors(paramsOut, ckks.NewEncoder(paramsOut), ckks.NewDecryptor(paramsOut, skIdealOut), coeffs, ciphertext, paramsOut.LogDefaultScale(), 0, *printPrecisionStats, t)
|
||||
}
|
||||
|
||||
func newTestVectors(tc *testContext, encryptor *rlwe.Encryptor, a, b complex128, logSlots int) (values []*bignum.Complex, plaintext *rlwe.Plaintext, ciphertext *rlwe.Ciphertext) {
|
||||
|
||||
@@ -220,7 +220,7 @@ func printDebug(params ckks.Parameters, ciphertext *rlwe.Ciphertext, valuesWant
|
||||
fmt.Printf("ValuesTest: %6.10f %6.10f %6.10f %6.10f...\n", valuesTest[0], valuesTest[1], valuesTest[2], valuesTest[3])
|
||||
fmt.Printf("ValuesWant: %6.10f %6.10f %6.10f %6.10f...\n", valuesWant[0], valuesWant[1], valuesWant[2], valuesWant[3])
|
||||
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, nil, false)
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, 0, false)
|
||||
|
||||
fmt.Println(precStats.String())
|
||||
fmt.Println()
|
||||
|
||||
@@ -326,14 +326,14 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Addition - ct + ct%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Addition - ct + ct%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
// ciphertext + plaintext
|
||||
ct3, err = eval.AddNew(ct1, pt2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Addition - ct + pt%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Addition - ct + pt%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
// ciphertext + vector
|
||||
// Note that the evaluator will encode this vector at the scale of the input ciphertext to ensure a noiseless addition.
|
||||
@@ -341,7 +341,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Addition - ct + vector%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Addition - ct + vector%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
// ciphertext + scalar
|
||||
scalar := 3.141592653589793 + 1.4142135623730951i
|
||||
@@ -354,7 +354,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Addition - ct + scalar%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Addition - ct + scalar%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
fmt.Printf("==============\n")
|
||||
fmt.Printf("MULTIPLICATION\n")
|
||||
@@ -418,14 +418,14 @@ func main() {
|
||||
// For the sake of conciseness, we will not rescale the output for the other multiplication example.
|
||||
// But this maintenance operation should usually be called (either before of after the multiplication depending on the choice of noise management)
|
||||
// to control the magnitude of the plaintext scale.
|
||||
fmt.Printf("Multiplication - ct * ct%s", ckks.GetPrecisionStats(params, ecd, dec, want, res, nil, false).String())
|
||||
fmt.Printf("Multiplication - ct * ct%s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())
|
||||
|
||||
// ciphertext + plaintext
|
||||
ct3, err = eval.MulRelinNew(ct1, pt2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Multiplication - ct * pt%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Multiplication - ct * pt%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
// ciphertext + vector
|
||||
// Note that when giving non-encoded vectors, the evaluator will internally encode this vector with the appropriate scale that ensure that
|
||||
@@ -434,7 +434,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Multiplication - ct * vector%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Multiplication - ct * vector%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
// ciphertext + scalar (scalar = pi + sqrt(2) * i)
|
||||
for i := 0; i < Slots; i++ {
|
||||
@@ -448,7 +448,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Multiplication - ct * scalar%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Multiplication - ct * scalar%s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
fmt.Printf("======================\n")
|
||||
fmt.Printf("ROTATION & CONJUGATION\n")
|
||||
@@ -488,7 +488,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Rotation by k=%d %s", rot, ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Rotation by k=%d %s", rot, ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
// Conjugation
|
||||
for i := 0; i < Slots; i++ {
|
||||
@@ -499,7 +499,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Conjugation %s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, nil, false).String())
|
||||
fmt.Printf("Conjugation %s", ckks.GetPrecisionStats(params, ecd, dec, want, ct3, 0, false).String())
|
||||
|
||||
// Note that rotations and conjugation only add a fixed additive noise independent of the ciphertext noise.
|
||||
// If the parameters are set correctly, this noise can be rounding error (thus negligible).
|
||||
@@ -574,7 +574,7 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Polynomial Evaluation %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, nil, false).String())
|
||||
fmt.Printf("Polynomial Evaluation %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())
|
||||
|
||||
// =============================
|
||||
// Vector Polynomials Evaluation
|
||||
@@ -616,7 +616,7 @@ func main() {
|
||||
// Note that this method can obviously be used to average values.
|
||||
// For a good noise management, it is recommended to first multiply the values by 1/n, then
|
||||
// apply the innersum and then only apply the rescaling.
|
||||
fmt.Printf("Innersum %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, nil, false).String())
|
||||
fmt.Printf("Innersum %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())
|
||||
|
||||
// The replicate operation is exactly the same as the innersum operation, but in reverse
|
||||
eval = eval.WithKey(rlwe.NewMemEvaluationKeySet(rlk, kgen.GenGaloisKeysNew(params.GaloisElementsForReplicate(batch, n), sk)...))
|
||||
@@ -633,7 +633,7 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Replicate %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, nil, false).String())
|
||||
fmt.Printf("Replicate %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())
|
||||
|
||||
// And we arrive to the linear transformation.
|
||||
// This method enables to evaluate arbitrary Slots x Slots matrices on a ciphertext.
|
||||
@@ -713,7 +713,7 @@ func main() {
|
||||
// We evaluate the same circuit in plaintext
|
||||
want = EvaluateLinearTransform(values1, diagonals)
|
||||
|
||||
fmt.Printf("vector x matrix %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, nil, false).String())
|
||||
fmt.Printf("vector x matrix %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())
|
||||
|
||||
// =============================
|
||||
// Homomorphic Encoding/Decoding
|
||||
|
||||
@@ -223,7 +223,7 @@ func printDebug(params ckks.Parameters, ciphertext *rlwe.Ciphertext, valuesWant
|
||||
fmt.Printf("ValuesWant: %6.10f %6.10f %6.10f %6.10f...\n", valuesWant[0], valuesWant[1], valuesWant[2], valuesWant[3])
|
||||
fmt.Println()
|
||||
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, nil, false)
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, 0, false)
|
||||
|
||||
fmt.Println(precStats.String())
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ func printDebug(params ckks.Parameters, ciphertext *rlwe.Ciphertext, valuesWant
|
||||
fmt.Printf("ValuesWant: %6.10f %6.10f %6.10f %6.10f...\n", valuesWant[0], valuesWant[1], valuesWant[2], valuesWant[3])
|
||||
fmt.Println()
|
||||
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, nil, false)
|
||||
precStats := ckks.GetPrecisionStats(params, encoder, nil, valuesWant, valuesTest, 0, false)
|
||||
|
||||
fmt.Println(precStats.String())
|
||||
|
||||
|
||||
@@ -103,5 +103,5 @@ func PrintPrecisionStats(params ckks.Parameters, ct *rlwe.Ciphertext, want []flo
|
||||
fmt.Printf("...\n")
|
||||
|
||||
// Pretty prints the precision stats
|
||||
fmt.Println(ckks.GetPrecisionStats(params, ecd, dec, have, want, nil, false).String())
|
||||
fmt.Println(ckks.GetPrecisionStats(params, ecd, dec, have, want, 0, false).String())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user