diff --git a/CHANGELOG.md b/CHANGELOG.md
index 54611af1..ad3bec06 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,55 +4,96 @@ All notable changes to this project will be documented in this file.
## Unreleased
-- RLWE : added a new `rlwe` package as common implementation base for the lattigo RLWE schemes
-- DRLWE : added a new `drlwe` package as a common implementation base for the lattigo multiparty RLWE schemes
-- BFV/CKKS : the schemes are now using a common implementation for their keys
-- BFV/CKKS : the rotation-keys are now indexed by their corresponding galois automorphism
-- BFV/CKKS : the `Evaluator` interface now has a single method for all column rotations and one method for the row-rotation/conjugate.
-- BFV/CKKS : the relinearization and rotation keys are now passed to the `Evaluator` constructor methods (and no longer to the operations methods)
-- DBFV/DCKKS : added a common interface and implementation for each multiparty key-generation protocols
+- Added SECURITY.md
+- ALL: when possible, public functions now use `int` instead of `uint64` as parameters and return values.
+- RING: RNS rescaling API is now inplace and can take a different poly as output.
+- RING: added `ReadFromDistLvl` and `ReadAndAddFromDistLvl` to Gaussian sampler API.
+- RLWE: added a new `rlwe` package as common implementation base for the lattigo RLWE schemes.
+- DRLWE: added a new `drlwe` package as a common implementation base for the lattigo multiparty RLWE schemes.
+- BFV/CKKS: the schemes are now using a common implementation for their keys.
+- BFV/CKKS: the rotation-keys are now indexed by their corresponding Galois automorphism.
+- BFV/CKKS: the `Evaluator` interface now has a single method for all column rotations and one method for the row-rotation/conjugate.
+- BFV/CKKS: the relinearization and rotation keys are now passed to the `Evaluator` constructor methods (and no longer to the operations methods).
+- DBFV/DCKKS: added a common interface and implementation for each multiparty key-generation protocol.
+- DCKKS: public-refresh now takes a target desired output scale, which allows to refresh the ciphertext to the default scale.
+- CKKS: added methods for operating linear-transformation and improved several aspects listed below:
+
+#### CKKS Bootstrapping
+- The procedure now allows for a more granular parameterization.
+- Added flag in bootstrapping parameters for bit-reversed inputs (with bit-reversed output) CoeffsToSlots and SlotsToCoeffs.
+- Added optional Arcsine.
+- The procedure now uses the new linear-transformation API.
+- `CoeffsToSlots` and `SlotsToCoeffs` are now standalone public functions.
+
+#### New CKKS Evaluator methods
+- `RotateHoisted`: evaluate several rotations on a single ciphertext.
+- `LinearTransform`: evaluate one or more `PtDiagMatrix` on a ciphertext using `MultiplyByDiagMatrix` or `MultiplyByDiagMatrixBSGS` according to the encoding of `PtDiagMatrix`.
+- `MultiplyByDiagMatrix`: multiplies a ciphertext with a `PtDiagMatrix` using n rotations with single hoisting.
+- `MultiplyByDiagMatrixBSGS`: multiplies a ciphertext with a `PtDiagMatrix` using 2sqrt(n) rotations with double-hoisting.
+- `InnerSumLog`: optimal log approach that works for any value (not only powers of two) and can be parameterized to inner sum batches of values (sub-vectors).
+- `InnerSum`: naive approach that is faster for small values but needs more keys.
+- `ReplicateLog`: optimal log approach that works for any value (not only powers of two) and can be parameterized to replicate batches of values (sub-vectors).
+- `Replicate`: naive approach that is faster for small values but needs more keys.
+
+#### New CKKS Encoder methods
+- `PtDiagMatrix`: struct that represents a linear transformation.
+- `EncodeDiagMatrixBSGSAtLvl`: encodes a `PtDiagMatrix` at a given level, with a given scale for the BSGS algorithm.
+- `EncodeDiagMatrixAtLvl`: encodes a `PtDiagMatrix` at a given level, with a given scale for a naive evaluation.
+- `DecodePublic`: adds Gaussian noise of variance floor(sigma * sqrt(2*pi)) before the decoding step (see SECURITY.md).
+- `DecodeCoeffsPublic`: adds Gaussian noise of variance floor(sigma * sqrt(2*pi)) before the decoding step (see SECURITY.md).
+- `GetErrSTDFreqDom` : get the error standard deviation in the frequency domain (slots).
+- `GetErrSTDTimeDom`: get the error standard deviation in the time domain (coefficients).
+
+#### CKKS Fixes
+- `MultByi` now correctly sets the output ciphertext scale.
+- `Relinearize` now correctly sets the output ciphertext level.
+- matrix-vector multiplication now correctly manages ciphertext of higher level than the plaintext matrix.
+- matrix-vector encoding now properly works for negative diagonal indexes.
+
+#### Others
+- PrecisionStats now includes the standard deviation of the error in the slots and coefficients domains.
## [2.1.1] - 2020-12-23
### Added
-- BFV/CKKS : added a check for minimum polynomial degree when creating parameters.
-- BFV : added the `bfv.Element.Level` method.
-- RING : test for sparse ternary sampler.
+- BFV/CKKS: added a check for minimum polynomial degree when creating parameters.
+- BFV: added the `bfv.Element.Level` method.
+- RING: test for sparse ternary sampler.
### Changed
-- BFV/CKKS : pk is now (-as + e, a) instead of (-(as + e), a).
-- BFV : harmonized the EvaluationKey setter from `SetRelinKeys` to `Set`
-- CKKS : renamed `BootstrappParams` into `BootstrappingParameters`
-- CKKS : the `Evaluator.DropLevel`, `Parameters.SetLogSlots` and `Element.Copy` methods no longer return errors
-- RING : minimum poly degree modulus is 16 to ensure the NTT correctness.
-- RING : isPrime has been replaced by big.ProbablyPrime, which is deterministic for integers < 2^64.
+- BFV/CKKS: pk is now (-as + e, a) instead of (-(as + e), a).
+- BFV: harmonized the EvaluationKey setter from `SetRelinKeys` to `Set`.
+- CKKS: renamed `BootstrappParams` into `BootstrappingParameters`.
+- CKKS: the `Evaluator.DropLevel`, `Parameters.SetLogSlots` and `Element.Copy` methods no longer return errors.
+- RING: minimum poly degree modulus is 16 to ensure the NTT correctness.
+- RING: isPrime has been replaced by big.ProbablyPrime, which is deterministic for integers < 2^64.
### Fixed
-- ALL : reduced cyclomatic complexity of several functions.
-- ALL : fixed all instances reporeted by staticcheck and gosec excluding G103 (audit the use of unsafe).
-- ALL : test vectors are now generated using the crypto/rand instead of math/rand package.
-- ALL : fixed some unhandled errors.
-- BFV/CKKS : improved the documentation: documentated several hardcoded values and fixed typos.
-- RING : fixed bias in sparse ternary sampling for some parameters.
-- RING : tests for the modular reduction algorithms are now deterministic.
+- ALL: reduced cyclomatic complexity of several functions.
+- ALL: fixed all instances reported by staticcheck and gosec excluding G103 (audit the use of unsafe).
+- ALL: test vectors are now generated using the crypto/rand instead of math/rand package.
+- ALL: fixed some unhandled errors.
+- BFV/CKKS: improved the documentation: documented several hard-coded values and fixed typos.
+- RING: fixed bias in sparse ternary sampling for some parameters.
+- RING: tests for the modular reduction algorithms are now deterministic.
## [2.1.0] - 2020-12-11
### Added
-- BFV : special-purpose plaintext types (`PlaintextRingT` or `PlaintextMul`) for optimized ct-pt operations. See bfv/encoder.go and bfv/plaintext.go.
-- BFV : allocation-free `Encoder` methods
-- RING : `GenNTTPrimes` now takes the value `Nth` (for Nth primitive root) as input rather than `logN`.
+- BFV: special-purpose plaintext types (`PlaintextRingT` or `PlaintextMul`) for optimized ct-pt operations. See bfv/encoder.go and bfv/plaintext.go.
+- BFV: allocation-free `Encoder` methods.
+- RING: `GenNTTPrimes` now takes the value `Nth` (for Nth primitive root) as input rather than `logN`.
### Changed
-- BFV : the `Encoder.DecodeUint64` and `Encoder.DecodeInt64` methods now take the output slice as argument.
-- CKKS : API of `Evaluator.RotateColumns` becomes `Evaluator.Rotate`
-- CKKS : the change of variable in `Evaluator.EvaluateCheby` isn't done automatically anymore and the user must do it before calling the function to ensure correctness.
-- CKKS : when encoding, the number of slots must now be given in log2 basis. This is to prevent errors that would induced by zero values or non power of two values.
-- CKKS : new encoder API : `EncodeAtLvlNew` and `EncodeNTTAtLvlNew`, which allow a user to encode a plaintext at a specific level.
+- BFV: the `Encoder.DecodeUint64` and `Encoder.DecodeInt64` methods now take the output slice as argument.
+- CKKS: API of `Evaluator.RotateColumns` becomes `Evaluator.Rotate`.
+- CKKS: the change of variable in `Evaluator.EvaluateCheby` isn't done automatically anymore and the user must do it before calling the function to ensure correctness.
+- CKKS: when encoding, the number of slots must now be given in log2 basis. This is to prevent errors that would induced by zero values or non power of two values.
+- CKKS: new encoder API : `EncodeAtLvlNew` and `EncodeNTTAtLvlNew`, which allow a user to encode a plaintext at a specific level.
### Removed
-- CKKS : removed method `Evaluator.EvaluateChebySpecial`
-- BFV : removed `QiMul` field from `bfv.Parameters`. It is now automatically generated.
+- CKKS: removed method `Evaluator.EvaluateChebySpecial`.
+- BFV: removed `QiMul` field from `bfv.Parameters`. It is now automatically generated.
## [2.0.0] - 2020-10-07
@@ -96,7 +137,7 @@ All notable changes to this project will be documented in this file.
- CKKS: EvaluatePolyFast(.) and EvaluatePolyEco(.) are replaced by EvaluatePoly(.).
- CKKS: EvaluateChebyFast(.) and EvaluateChebyEco(.) are replaced by EvaluatePolyCheby(.).
- CKKS: EvaluateChebyEcoSpecial(.) and EvaluateChebyFastSpecial(.) are replaced by EvaluatePolyChebySpecial(.).
-- RING: The Float128 type was removed due to cross-platform incompatility.
+- RING: The Float128 type was removed due to cross-platform incompatibility.
### Fixes
- BFV: Fixed multiplication that was failing when #Qi != #QMul.
@@ -135,7 +176,7 @@ All notable changes to this project will be documented in this file.
- RING: Enabled dense and sparse ternary polynomials sampling directly from the context.
- RING: New API enabling "level"-wise polynomial arithmetic.
- RING: New API for modulus switching with flooring and rounding.
-- UTILS: The pacakge utils now regroups all the utility methods which were previously duplicated among packages.
+- UTILS: The package utils now regroups all the utility methods which were previously duplicated among packages.
### Removed
- BFV/CKKS/DBFV/DCKKS: Removed their respective context. Ring context remains public.
- All schemes: Removed key-switching with bit decomposition. This option will however be re-introduced at a later stage since applications using small parameters can be impacted by this change.
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 00000000..22eee521
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,19 @@
+# Code Review
+Lattigo 2.0.0 has been code-reviewed by ELCA in November 2020 and, within the allocated time for the code review, no critical or high-risk issues were found.
+
+# Security of Approximate-Numbers Homomorphic Encryption
+Homomorphic encryption schemes are by definition malleable, and are therefore not secure against chosen ciphertext attacks (CCA security). They can be though secure against chosen plaintext attacks (CPA security).
+
+Classified as an _approximate decryption_ scheme, the CKKS scheme is secure as long as the plaintext result of a decryption is only revealed to entities with knowledge of the secret-key. This is because, given a ciphertext (_-as + m + e_, _a_), the decryption outputs a plaintext _m+e_. [Li and Micciancio](https://eprint.iacr.org/2020/1533) show that using this plaintext, it is possible to recover the secret-key with ((_-as + m + e_) - (_m + e_)) * _a^-1 = asa^-1 = s_ (the probability of _a_ being invertible is overwhelming, and if _a_ is not invertible, only require a few more samples are required).
+
+This attack demonstrates that, when using an approximate homomorphic encryption scheme, the usual CPA security may not sufficient depending on the application setting. Many applications do not require to share the result with external parties and are not affected by this attack, but the ones that do must take the appropriate steps to ensure that no key-dependent information is leaked. A homomorphic encryption scheme that provides such functionality and that can be secure when releasing decrypted plaintext to external parties is defined to be CPAD secure. The corresponding indistinguishability notion (IND-CPAD) is defined as "indistinguishability under chosen plaintext attacks with decryption oracles."
+
+# CPAD Security for CKKS
+Lattigo implements tools to mitigate _Li and Micciancio_'s attack. In particular, the decoding step of CKKS (and its real-number variant R-CKKS) allows the user to add a key-independent error **_e_** of standard deviation **_σ_** to the decrypted plaintext before decoding.
+
+If at any point of an application, decrypted values have to be shared with external parties, then the user must ensure that each shared plaintext is first _sanitized_ before being shared. To do so, the user must use the **DecodePublic** method instead of the usual **Decode**. **DecodePublic** takes as additional input **_σ_**, and samples a key-independent error **_e_** with standard deviation **_σ_**, that is added to the plaintext before decoding.
+
+Estimating **_σ_** must be done carefully and we suggest the following iterative process to do so:
+ 1. Given a security parameter **_λ_** and a circuit **_C_** that takes as inputs length-**_n_** vectors **_ω_** following a distribution **_χ_**, select the appropriate parameters enabling the homomorphic evaluation of **_C(ω)_**, denoted by **_H(C(ω))_**, which includes the encoding, encryption, evaluation, decryption and decoding.
+ 2. Sample input vectors **_ω_** from the distribution **_χ_** and compute the standard deviation **_σ_** in the time domain (coefficient domain) of **_e = C(ω) - H(C(ω))_**. This can be done using the encoder method **GetErrSTDTimeDom(_C(ω)_, _H(C(ω))_, _Δ_)**, where **_Δ_** is the scale of the plaintext after the decryption. The user should make sure that the underlying circuit computed by **H(C())** is identical to **C()**; i.e., if the homomorphic implementation **H(C())** uses polynomial approximations, then **C()** should use them too, instead of using the original exact function. This will ensure that **_e_**, and therefore **_σ_**, are as close as possible to the actual underlying scheme error, and not influenced by function-approximation errors.
+ 3. Use the encoder method **DecodePublic** with the parameter **_σ_** to decode plaintexts that will be published. **DecodePublic** adds an error **_e_** with standard deviation **_σ_** bounded by **B = _σ • (2π)0.5_**. The precision loss, compared to a private decoding, should be less than half a bit on average.
diff --git a/bfv/bfv_test.go b/bfv/bfv_test.go
index 31f8059c..3ee097f5 100644
--- a/bfv/bfv_test.go
+++ b/bfv/bfv_test.go
@@ -120,7 +120,7 @@ func testParameters(testctx *testContext, t *testing.T) {
t.Run("Parameters/InverseGaloisElement/", func(t *testing.T) {
for i := 1; i < int(testctx.params.N()/2); i++ {
galEl := testctx.params.GaloisElementForColumnRotationBy(i)
- mod := 2 * testctx.params.N()
+ mod := uint64(2 * testctx.params.N())
inv := testctx.params.InverseGaloisElement(galEl)
res := (inv * galEl) % mod
assert.Equal(t, uint64(1), res)
@@ -645,11 +645,10 @@ func testEvaluatorRotate(testctx *testContext, t *testing.T) {
values, _, ciphertext := newTestVectorsRingQ(testctx, testctx.encryptorPk, t)
receiver := NewCiphertext(testctx.params, 1)
- for _, k := range rots {
+ for _, n := range rots {
- evaluator.RotateColumns(ciphertext, k, receiver)
- nColumns := testctx.params.N() >> 1
- valuesWant := append(utils.RotateUint64Slice(values.Coeffs[0][:nColumns], k), utils.RotateUint64Slice(values.Coeffs[0][nColumns:], k)...)
+ evaluator.RotateColumns(ciphertext, n, receiver)
+ valuesWant := utils.RotateUint64Slots(values.Coeffs[0], n)
verifyTestVectors(testctx, testctx.decryptor, &ring.Poly{Coeffs: [][]uint64{valuesWant}}, receiver, t)
}
@@ -659,11 +658,10 @@ func testEvaluatorRotate(testctx *testContext, t *testing.T) {
values, _, ciphertext := newTestVectorsRingQ(testctx, testctx.encryptorPk, t)
- for _, k := range rots {
+ for _, n := range rots {
- receiver := evaluator.RotateColumnsNew(ciphertext, k)
- nColumns := testctx.params.N() >> 1
- valuesWant := append(utils.RotateUint64Slice(values.Coeffs[0][:nColumns], k), utils.RotateUint64Slice(values.Coeffs[0][nColumns:], k)...)
+ receiver := evaluator.RotateColumnsNew(ciphertext, n)
+ valuesWant := utils.RotateUint64Slots(values.Coeffs[0], n)
verifyTestVectors(testctx, testctx.decryptor, &ring.Poly{Coeffs: [][]uint64{valuesWant}}, receiver, t)
}
diff --git a/bfv/ciphertext.go b/bfv/ciphertext.go
index be0e4529..d5ecf957 100644
--- a/bfv/ciphertext.go
+++ b/bfv/ciphertext.go
@@ -8,12 +8,12 @@ type Ciphertext struct {
}
// NewCiphertext creates a new ciphertext parameterized by degree, level and scale.
-func NewCiphertext(params *Parameters, degree uint64) (ciphertext *Ciphertext) {
+func NewCiphertext(params *Parameters, degree int) (ciphertext *Ciphertext) {
return &Ciphertext{newCiphertextElement(params, degree)}
}
// NewCiphertextRandom generates a new uniformly distributed ciphertext of degree, level and scale.
-func NewCiphertextRandom(prng utils.PRNG, params *Parameters, degree uint64) (ciphertext *Ciphertext) {
+func NewCiphertextRandom(prng utils.PRNG, params *Parameters, degree int) (ciphertext *Ciphertext) {
ciphertext = &Ciphertext{newCiphertextElement(params, degree)}
populateElementRandom(prng, params, ciphertext.Element)
return
diff --git a/bfv/decryptor.go b/bfv/decryptor.go
index 340daa56..3b6b3af7 100644
--- a/bfv/decryptor.go
+++ b/bfv/decryptor.go
@@ -54,7 +54,7 @@ func (decryptor *decryptor) Decrypt(ciphertext *Ciphertext, p *Plaintext) {
ringQ.NTTLazy(ciphertext.value[ciphertext.Degree()], p.value)
- for i := uint64(ciphertext.Degree()); i > 0; i-- {
+ for i := ciphertext.Degree(); i > 0; i-- {
ringQ.MulCoeffsMontgomery(p.value, decryptor.sk.Value, p.value)
ringQ.NTTLazy(ciphertext.value[i-1], tmp)
ringQ.Add(p.value, tmp, p.value)
diff --git a/bfv/encoder.go b/bfv/encoder.go
index f2d3a451..721390c2 100644
--- a/bfv/encoder.go
+++ b/bfv/encoder.go
@@ -91,13 +91,13 @@ func NewEncoder(params *Parameters) Encoder {
indexMatrix := make([]uint64, slots)
- logN := params.LogN()
+ logN := uint64(params.LogN())
rowSize := params.N() >> 1
- m = (params.N() << 1)
+ m = uint64(params.N()) << 1
pos = 1
- for i := uint64(0); i < rowSize; i++ {
+ for i := 0; i < rowSize; i++ {
index1 = (pos - 1) >> 1
index2 = (m - pos - 1) >> 1
@@ -129,7 +129,7 @@ func GenLiftParams(ringQ *ring.Ring, t uint64) (deltaMont []uint64) {
deltaMont = make([]uint64, len(ringQ.Modulus))
tmp := new(big.Int)
- bredParams := ringQ.GetBredParams()
+ bredParams := ringQ.BredParams
for i, Qi := range ringQ.Modulus {
deltaMont[i] = tmp.Mod(delta, ring.NewUint(Qi)).Uint64()
deltaMont[i] = ring.MForm(deltaMont[i], Qi, bredParams[i])
@@ -241,9 +241,9 @@ func scaleUp(ringQ *ring.Ring, deltaMont []uint64, pIn, pOut *ring.Poly) {
in := pIn.Coeffs[0]
d := deltaMont[i]
qi := ringQ.Modulus[i]
- mredParams := ringQ.GetMredParams()[i]
+ mredParams := ringQ.MredParams[i]
- for j := uint64(0); j < ringQ.N; j = j + 8 {
+ for j := 0; j < ringQ.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&in[j]))
z := (*[8]uint64)(unsafe.Pointer(&out[j]))
@@ -312,7 +312,7 @@ func (encoder *encoder) DecodeUint(p interface{}, coeffs []uint64) {
encoder.ringT.NTT(ptRt.value, encoder.tmpPoly)
- for i := uint64(0); i < encoder.ringQ.N; i++ {
+ for i := 0; i < encoder.ringQ.N; i++ {
coeffs[i] = encoder.tmpPoly.Coeffs[0][encoder.indexMatrix[i]]
}
}
@@ -336,7 +336,7 @@ func (encoder *encoder) DecodeInt(p interface{}, coeffs []int64) {
modulus := int64(encoder.params.t)
modulusHalf := modulus >> 1
var value int64
- for i := uint64(0); i < encoder.ringQ.N; i++ {
+ for i := 0; i < encoder.ringQ.N; i++ {
value = int64(encoder.tmpPoly.Coeffs[0][encoder.indexMatrix[i]])
coeffs[i] = value
diff --git a/bfv/encryptor.go b/bfv/encryptor.go
index 2e58a49c..dab7398c 100644
--- a/bfv/encryptor.go
+++ b/bfv/encryptor.go
@@ -63,10 +63,10 @@ type encryptor struct {
polypool [3]*ring.Poly
baseconverter *ring.FastBasisExtender
+ gaussianSamplerQP *ring.GaussianSampler
gaussianSamplerQ *ring.GaussianSampler
uniformSamplerQ *ring.UniformSampler
ternarySamplerMontgomeryQ *ring.TernarySampler
- gaussianSamplerQP *ring.GaussianSampler
uniformSamplerQP *ring.UniformSampler
ternarySamplerMontgomeryQP *ring.TernarySampler
}
@@ -126,10 +126,10 @@ func newEncryptor(params *Parameters) encryptor {
ringQP: ringQP,
polypool: [3]*ring.Poly{ringQP.NewPoly(), ringQP.NewPoly(), ringQP.NewPoly()},
baseconverter: baseconverter,
- gaussianSamplerQ: ring.NewGaussianSampler(prng, ringQ, params.Sigma(), uint64(6*params.Sigma())),
+ gaussianSamplerQP: ring.NewGaussianSampler(prng, ringQP, params.Sigma(), int(6*params.Sigma())),
+ gaussianSamplerQ: ring.NewGaussianSampler(prng, ringQ, params.Sigma(), int(6*params.Sigma())),
uniformSamplerQ: ring.NewUniformSampler(prng, ringQ),
ternarySamplerMontgomeryQ: ring.NewTernarySampler(prng, ringQ, 0.5, true),
- gaussianSamplerQP: ring.NewGaussianSampler(prng, ringQP, params.Sigma(), uint64(6*params.Sigma())),
uniformSamplerQP: ring.NewUniformSampler(prng, ringQP),
ternarySamplerMontgomeryQP: ring.NewTernarySampler(prng, ringQP, 0.5, true),
}
@@ -193,10 +193,10 @@ func (encryptor *pkEncryptor) encrypt(p *Plaintext, ciphertext *Ciphertext, fast
ringQ.InvNTT(encryptor.polypool[1], ciphertext.value[1])
// ct[0] = pk[0]*u + e0
- encryptor.gaussianSamplerQ.ReadAndAdd(ciphertext.value[0])
+ encryptor.gaussianSamplerQ.ReadAndAddLvl(len(ringQ.Modulus)-1, ciphertext.value[0])
// ct[1] = pk[1]*u + e1
- encryptor.gaussianSamplerQ.ReadAndAdd(ciphertext.value[1])
+ encryptor.gaussianSamplerQ.ReadAndAddLvl(len(ringQ.Modulus)-1, ciphertext.value[1])
} else {
@@ -215,14 +215,14 @@ func (encryptor *pkEncryptor) encrypt(p *Plaintext, ciphertext *Ciphertext, fast
ringQP.InvNTTLazy(encryptor.polypool[1], encryptor.polypool[1])
// ct[0] = pk[0]*u + e0
- encryptor.gaussianSamplerQP.ReadAndAdd(encryptor.polypool[0])
+ encryptor.gaussianSamplerQP.ReadAndAddLvl(len(ringQP.Modulus)-1, encryptor.polypool[0])
// ct[1] = pk[1]*u + e1
- encryptor.gaussianSamplerQP.ReadAndAdd(encryptor.polypool[1])
+ encryptor.gaussianSamplerQP.ReadAndAddLvl(len(ringQP.Modulus)-1, encryptor.polypool[1])
// We rescale the encryption of zero by the special prime, dividing the error by this prime
- encryptor.baseconverter.ModDownPQ(uint64(len(ringQ.Modulus))-1, encryptor.polypool[0], ciphertext.value[0])
- encryptor.baseconverter.ModDownPQ(uint64(len(ringQ.Modulus))-1, encryptor.polypool[1], ciphertext.value[1])
+ encryptor.baseconverter.ModDownPQ(len(ringQ.Modulus)-1, encryptor.polypool[0], ciphertext.value[0])
+ encryptor.baseconverter.ModDownPQ(len(ringQ.Modulus)-1, encryptor.polypool[1], ciphertext.value[1])
}
// ct[0] = pk[0]*u + e0 + m
// ct[1] = pk[1]*u + e1
@@ -285,7 +285,7 @@ func (encryptor *skEncryptor) encrypt(p *Plaintext, ciphertext *Ciphertext, crp
ringQ.InvNTT(ciphertext.value[0], ciphertext.value[0])
ringQ.InvNTT(crp, ciphertext.value[1])
- encryptor.gaussianSamplerQ.ReadAndAdd(ciphertext.value[0])
+ encryptor.gaussianSamplerQ.ReadAndAddLvl(len(ringQ.Modulus)-1, ciphertext.value[0])
// ct = [-a*s + m + e , a]
encryptor.ringQ.Add(ciphertext.value[0], p.value, ciphertext.value[0])
diff --git a/bfv/evaluator.go b/bfv/evaluator.go
index cb9ff8b9..e6a6cab6 100644
--- a/bfv/evaluator.go
+++ b/bfv/evaluator.go
@@ -2,6 +2,7 @@ package bfv
import (
"fmt"
+ "math"
"math/big"
"github.com/ldsec/lattigo/v2/ring"
@@ -81,7 +82,8 @@ func newEvaluatorPrecomp(params *Parameters) *evaluatorBase {
panic(err)
}
- qiMul := ring.GenerateNTTPrimesP(61, 2*params.N(), uint64(len(params.qi)))
+ // Generates #QiMul primes such that Q * QMul > Q*Q*N
+ qiMul := ring.GenerateNTTPrimesP(61, 2*params.N(), int(math.Ceil(float64(ev.ringQ.ModulusBigint.BitLen()+params.LogN())/61.0)))
if ev.ringQMul, err = ring.NewRing(params.N(), qiMul); err != nil {
panic(err)
@@ -192,33 +194,33 @@ func (eval *evaluator) WithKey(evaluationKey EvaluationKey) Evaluator {
// Add adds op0 to op1 and returns the result in ctOut.
func (eval *evaluator) Add(op0, op1 Operand, ctOut *Ciphertext) {
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()), true)
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()), true)
eval.evaluateInPlaceBinary(el0, el1, elOut, eval.ringQ.Add)
}
// AddNew adds op0 to op1 and creates a new element ctOut to store the result.
func (eval *evaluator) AddNew(op0, op1 Operand) (ctOut *Ciphertext) {
- ctOut = NewCiphertext(eval.params, utils.MaxUint64(op0.Degree(), op1.Degree()))
+ ctOut = NewCiphertext(eval.params, utils.MaxInt(op0.Degree(), op1.Degree()))
eval.Add(op0, op1, ctOut)
return
}
// AddNoMod adds op0 to op1 without modular reduction, and returns the result in cOut.
func (eval *evaluator) AddNoMod(op0, op1 Operand, ctOut *Ciphertext) {
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()), true)
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()), true)
eval.evaluateInPlaceBinary(el0, el1, elOut, eval.ringQ.AddNoMod)
}
// AddNoModNew adds op0 to op1 without modular reduction and creates a new element ctOut to store the result.
func (eval *evaluator) AddNoModNew(op0, op1 Operand) (ctOut *Ciphertext) {
- ctOut = NewCiphertext(eval.params, utils.MaxUint64(op0.Degree(), op1.Degree()))
+ ctOut = NewCiphertext(eval.params, utils.MaxInt(op0.Degree(), op1.Degree()))
eval.AddNoMod(op0, op1, ctOut)
return
}
// Sub subtracts op1 from op0 and returns the result in cOut.
func (eval *evaluator) Sub(op0, op1 Operand, ctOut *Ciphertext) {
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()), true)
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()), true)
eval.evaluateInPlaceBinary(el0, el1, elOut, eval.ringQ.Sub)
if el0.Degree() < el1.Degree() {
@@ -230,14 +232,14 @@ func (eval *evaluator) Sub(op0, op1 Operand, ctOut *Ciphertext) {
// SubNew subtracts op1 from op0 and creates a new element ctOut to store the result.
func (eval *evaluator) SubNew(op0, op1 Operand) (ctOut *Ciphertext) {
- ctOut = NewCiphertext(eval.params, utils.MaxUint64(op0.Degree(), op1.Degree()))
+ ctOut = NewCiphertext(eval.params, utils.MaxInt(op0.Degree(), op1.Degree()))
eval.Sub(op0, op1, ctOut)
return
}
// SubNoMod subtracts op1 from op0 without modular reduction and returns the result on ctOut.
func (eval *evaluator) SubNoMod(op0, op1 Operand, ctOut *Ciphertext) {
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()), true)
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()), true)
eval.evaluateInPlaceBinary(el0, el1, elOut, eval.ringQ.SubNoMod)
@@ -250,7 +252,7 @@ func (eval *evaluator) SubNoMod(op0, op1 Operand, ctOut *Ciphertext) {
// SubNoModNew subtracts op1 from op0 without modular reduction and creates a new element ctOut to store the result.
func (eval *evaluator) SubNoModNew(op0, op1 Operand) (ctOut *Ciphertext) {
- ctOut = NewCiphertext(eval.params, utils.MaxUint64(op0.Degree(), op1.Degree()))
+ ctOut = NewCiphertext(eval.params, utils.MaxInt(op0.Degree(), op1.Degree()))
eval.SubNoMod(op0, op1, ctOut)
return
}
@@ -328,7 +330,7 @@ func (eval *evaluator) tensorAndRescale(ct0, ct1, ctOut *Element) {
}
func (eval *evaluator) modUpAndNTT(ct *Element, cQ, cQMul []*ring.Poly) {
- levelQ := uint64(len(eval.ringQ.Modulus) - 1)
+ levelQ := len(eval.ringQ.Modulus) - 1
for i := range ct.value {
eval.baseconverterQ1Q2.ModUpSplitQP(levelQ, ct.value[i], cQMul[i])
eval.ringQ.NTTLazy(ct.value[i], cQ[i])
@@ -407,7 +409,7 @@ func (eval *evaluator) tensortLargeDeg(ct0, ct1 *Element) {
c2Q1 := eval.poolQ[2]
c2Q2 := eval.poolQmul[2]
- for i := uint64(0); i < ct0.Degree()+ct1.Degree()+1; i++ {
+ for i := 0; i < ct0.Degree()+ct1.Degree()+1; i++ {
c2Q1[i].Zero()
c2Q2[i].Zero()
}
@@ -423,7 +425,7 @@ func (eval *evaluator) tensortLargeDeg(ct0, ct1 *Element) {
eval.ringQMul.MForm(c0Q2[i], c00Q2[i])
}
- for i := uint64(0); i < ct0.Degree()+1; i++ {
+ for i := 0; i < ct0.Degree()+1; i++ {
for j := i + 1; j < ct0.Degree()+1; j++ {
eval.ringQ.MulCoeffsMontgomery(c00Q1[i], c0Q1[j], c2Q1[i+j])
eval.ringQMul.MulCoeffsMontgomery(c00Q2[i], c0Q2[j], c2Q2[i+j])
@@ -433,7 +435,7 @@ func (eval *evaluator) tensortLargeDeg(ct0, ct1 *Element) {
}
}
- for i := uint64(0); i < ct0.Degree()+1; i++ {
+ for i := 0; i < ct0.Degree()+1; i++ {
eval.ringQ.MulCoeffsMontgomeryAndAdd(c00Q1[i], c0Q1[i], c2Q1[i<<1])
eval.ringQMul.MulCoeffsMontgomeryAndAdd(c00Q2[i], c0Q2[i], c2Q2[i<<1])
}
@@ -453,8 +455,8 @@ func (eval *evaluator) tensortLargeDeg(ct0, ct1 *Element) {
func (eval *evaluator) quantize(ctOut *Element) {
- levelQ := uint64(len(eval.ringQ.Modulus) - 1)
- levelQMul := uint64(len(eval.ringQMul.Modulus) - 1)
+ levelQ := len(eval.ringQ.Modulus) - 1
+ levelQMul := len(eval.ringQMul.Modulus) - 1
c2Q1 := eval.poolQ[2]
c2Q2 := eval.poolQmul[2]
@@ -521,15 +523,15 @@ func (eval *evaluator) mulPlaintextRingT(ct0 *Ciphertext, ptRt *PlaintextRingT,
tmp := ctOut.value[i].Coeffs[j]
qi := ringQ.Modulus[j]
- nttPsi := ringQ.GetNttPsi()[j]
- bredParams := ringQ.GetBredParams()[j]
- mredParams := ringQ.GetMredParams()[j]
+ nttPsi := ringQ.NttPsi[j]
+ bredParams := ringQ.BredParams[j]
+ mredParams := ringQ.MredParams[j]
// Transforms the plaintext in the NTT domain of that qi
ring.NTTLazy(coeffs, coeffsNTT, ringQ.N, nttPsi, qi, mredParams, bredParams)
// Multiplies NTT_qi(pt) * NTT_qi(ct)
- for k := uint64(0); k < eval.ringQ.N; k = k + 8 {
+ for k := 0; k < eval.ringQ.N; k = k + 8 {
x := (*[8]uint64)(unsafe.Pointer(&coeffsNTT[k]))
z := (*[8]uint64)(unsafe.Pointer(&tmp[k]))
@@ -588,7 +590,7 @@ func (eval *evaluator) Relinearize(ct0 *Ciphertext, ctOut *Ciphertext) {
panic("evaluator has no relinearization key")
}
- if int(ct0.Degree()-1) > len(eval.rlk.Keys) {
+ if ct0.Degree()-1 > len(eval.rlk.Keys) {
panic("input ciphertext degree is too large to allow relinearization with the evluator's relinearization key")
}
@@ -732,15 +734,13 @@ func (eval *evaluator) permute(ct0 *Ciphertext, generator uint64, switchKey *rlw
// switchKeys applies the general key-switching procedure of the form [c0 + cx*evakey[0], c1 + cx*evakey[1]]
func (eval *evaluator) switchKeysInPlace(cx *ring.Poly, evakey *rlwe.SwitchingKey, pool2Q, pool3Q *ring.Poly) {
- var level, reduce uint64
-
ringQ := eval.ringQ
ringP := eval.ringP
pool2P := eval.poolPKS[1]
pool3P := eval.poolPKS[2]
- level = uint64(len(ringQ.Modulus)) - 1
+ level := len(ringQ.Modulus) - 1
c2QiQ := eval.poolQKS[0]
c2QiP := eval.poolPKS[0]
@@ -754,10 +754,10 @@ func (eval *evaluator) switchKeysInPlace(cx *ring.Poly, evakey *rlwe.SwitchingKe
// We switch the element on which the key-switching operation will be conducted out of the NTT domain
ringQ.NTTLazy(cx, c2)
- reduce = 0
+ var reduce int
// Key switching with CRT decomposition for the Qi
- for i := uint64(0); i < eval.params.Beta(); i++ {
+ for i := 0; i < eval.params.Beta(); i++ {
eval.decomposeAndSplitNTT(level, i, c2, cx, c2QiQ, c2QiP)
@@ -816,7 +816,7 @@ func (eval *evaluator) getRingQElem(op Operand) *Element {
}
// getElemAndCheckBinary unwraps the elements from the operands and checks that the receiver has sufficiently large degree.
-func (eval *evaluator) getElemAndCheckBinary(op0, op1, opOut Operand, opOutMinDegree uint64, ensureRingQ bool) (el0, el1, elOut *Element) {
+func (eval *evaluator) getElemAndCheckBinary(op0, op1, opOut Operand, opOutMinDegree int, ensureRingQ bool) (el0, el1, elOut *Element) {
if op0 == nil || op1 == nil || opOut == nil {
panic("operands cannot be nil")
}
@@ -836,7 +836,7 @@ func (eval *evaluator) getElemAndCheckBinary(op0, op1, opOut Operand, opOutMinDe
return op0.El(), op1.El(), opOut.El()
}
-func (eval *evaluator) getElemAndCheckUnary(op0, opOut Operand, opOutMinDegree uint64) (el0, elOut *Element) {
+func (eval *evaluator) getElemAndCheckUnary(op0, opOut Operand, opOutMinDegree int) (el0, elOut *Element) {
if op0 == nil || opOut == nil {
panic("operand cannot be nil")
}
@@ -857,7 +857,7 @@ func (eval *evaluator) evaluateInPlaceBinary(el0, el1, elOut *Element, evaluate
smallest, largest, _ := getSmallestLargest(el0, el1)
- for i := uint64(0); i < smallest.Degree()+1; i++ {
+ for i := 0; i < smallest.Degree()+1; i++ {
evaluate(el0.value[i], el1.value[i], elOut.value[i])
}
@@ -877,7 +877,7 @@ func evaluateInPlaceUnary(el0, elOut *Element, evaluate func(*ring.Poly, *ring.P
}
// decomposeAndSplitNTT decomposes the input polynomial into the target CRT basis.
-func (eval *evaluator) decomposeAndSplitNTT(level, beta uint64, c2NTT, c2InvNTT, c2QiQ, c2QiP *ring.Poly) {
+func (eval *evaluator) decomposeAndSplitNTT(level, beta int, c2NTT, c2InvNTT, c2QiQ, c2QiP *ring.Poly) {
ringQ := eval.ringQ
ringP := eval.ringP
@@ -888,17 +888,17 @@ func (eval *evaluator) decomposeAndSplitNTT(level, beta uint64, c2NTT, c2InvNTT,
p0idxed := p0idxst + eval.decomposer.Xalpha()[beta]
// c2_qi = cx mod qi mod qi
- for x := uint64(0); x < level+1; x++ {
+ for x := 0; x < level+1; x++ {
qi := ringQ.Modulus[x]
- nttPsi := ringQ.GetNttPsi()[x]
- bredParams := ringQ.GetBredParams()[x]
- mredParams := ringQ.GetMredParams()[x]
+ nttPsi := ringQ.NttPsi[x]
+ bredParams := ringQ.BredParams[x]
+ mredParams := ringQ.MredParams[x]
if p0idxst <= x && x < p0idxed {
p0tmp := c2NTT.Coeffs[x]
p1tmp := c2QiQ.Coeffs[x]
- for j := uint64(0); j < ringQ.N; j++ {
+ for j := 0; j < ringQ.N; j++ {
p1tmp[j] = p0tmp[j]
}
} else {
diff --git a/bfv/keygen.go b/bfv/keygen.go
index 87225c0c..4823fcaa 100644
--- a/bfv/keygen.go
+++ b/bfv/keygen.go
@@ -15,7 +15,7 @@ type KeyGenerator interface {
GenPublicKey(sk *SecretKey) (pk *PublicKey)
GenKeyPair() (sk *SecretKey, pk *PublicKey)
GenSwitchingKey(skIn, skOut *SecretKey) (evk *SwitchingKey)
- GenRelinearizationKey(sk *SecretKey, maxDegree uint64) (evk *RelinearizationKey)
+ GenRelinearizationKey(sk *SecretKey, maxDegree int) (evk *RelinearizationKey)
GenSwitchingKeyForGalois(galEl uint64, sk *SecretKey) (swk *SwitchingKey)
GenRotationKeys(galEls []uint64, sk *SecretKey) (rks *RotationKeySet)
GenRotationKeysForRotations(ks []int, includeSwapRow bool, sk *SecretKey) (rks *RotationKeySet)
@@ -61,7 +61,7 @@ func NewKeyGenerator(params *Parameters) KeyGenerator {
ringQP: ringQP,
pBigInt: pBigInt,
polypool: [2]*ring.Poly{ringQP.NewPoly(), ringQP.NewPoly()},
- gaussianSampler: ring.NewGaussianSampler(prng, ringQP, params.Sigma(), uint64(6*params.Sigma())),
+ gaussianSampler: ring.NewGaussianSampler(prng, ringQP, params.Sigma(), int(6*params.Sigma())),
uniformSampler: ring.NewUniformSampler(prng, ringQP),
}
}
@@ -112,7 +112,7 @@ func (keygen *keyGenerator) GenKeyPair() (sk *SecretKey, pk *PublicKey) {
// NewRelinKey generates a new evaluation key from the provided SecretKey. It will be used to relinearize a ciphertext (encrypted under a PublicKey generated from the provided SecretKey)
// of degree > 1 to a ciphertext of degree 1. Max degree is the maximum degree of the ciphertext allowed to relinearize.
-func (keygen *keyGenerator) GenRelinearizationKey(sk *SecretKey, maxDegree uint64) (evk *RelinearizationKey) {
+func (keygen *keyGenerator) GenRelinearizationKey(sk *SecretKey, maxDegree int) (evk *RelinearizationKey) {
if keygen.ringQP == nil {
panic("modulus P is empty")
@@ -129,7 +129,7 @@ func (keygen *keyGenerator) GenRelinearizationKey(sk *SecretKey, maxDegree uint6
ringQP := keygen.ringQP
keygen.polypool[1].Copy(sk.Value)
- for i := uint64(0); i < maxDegree; i++ {
+ for i := 0; i < maxDegree; i++ {
ringQP.MulCoeffsMontgomery(keygen.polypool[1], sk.Value, keygen.polypool[1])
keygen.newSwitchingKey(keygen.polypool[1], sk.Value, evk.Keys[i])
}
@@ -212,13 +212,13 @@ func (keygen *keyGenerator) newSwitchingKey(skIn, skOut *ring.Poly, swkOut *rlwe
alpha := keygen.params.Alpha()
beta := keygen.params.Beta()
- var index uint64
+ var index int
// delta_sk = skIn - skOut = GaloisEnd(skOut, rotation) - skOut
ringQP.MulScalarBigint(skIn, keygen.pBigInt, keygen.polypool[0])
- for i := uint64(0); i < beta; i++ {
+ for i := 0; i < beta; i++ {
// e
keygen.gaussianSampler.Read(swkOut.Value[i][0])
@@ -230,7 +230,7 @@ func (keygen *keyGenerator) newSwitchingKey(skIn, skOut *ring.Poly, swkOut *rlwe
// e + skIn * (qiBarre*qiStar) * 2^w
// (qiBarre*qiStar)%qi = 1, else 0
- for j := uint64(0); j < alpha; j++ {
+ for j := 0; j < alpha; j++ {
index = i*alpha + j
@@ -238,7 +238,7 @@ func (keygen *keyGenerator) newSwitchingKey(skIn, skOut *ring.Poly, swkOut *rlwe
p0tmp := keygen.polypool[0].Coeffs[index]
p1tmp := swkOut.Value[i][0].Coeffs[index]
- for w := uint64(0); w < ringQP.N; w++ {
+ for w := 0; w < ringQP.N; w++ {
p1tmp[w] = ring.CRed(p1tmp[w]+p0tmp[w], qi)
}
diff --git a/bfv/keys.go b/bfv/keys.go
index 60aa9b46..a094a371 100644
--- a/bfv/keys.go
+++ b/bfv/keys.go
@@ -40,11 +40,11 @@ func NewSwitchingKey(params *Parameters) *SwitchingKey {
}
// NewRelinearizationKey returns an allocated BFV public relinearization key with zero value for each degree in [2 < maxRelinDegree].
-func NewRelinearizationKey(params *Parameters, maxRelinDegree uint64) *RelinearizationKey {
+func NewRelinearizationKey(params *Parameters, maxRelinDegree int) *RelinearizationKey {
return &RelinearizationKey{*rlwe.NewRelinKey(maxRelinDegree, params.N(), params.QPiCount(), params.Beta())}
}
-// NewRotationKeySet return an allocated set of BFV public rotation keys with zero values for each galois element
+// NewRotationKeySet returns an allocated set of BFV public rotation keys with zero values for each galois element
// (i.e., for each supported rotation).
func NewRotationKeySet(params *Parameters, galoisElements []uint64) *RotationKeySet {
return &RotationKeySet{*rlwe.NewRotationKeySet(galoisElements, params.N(), params.QPiCount(), params.Beta())}
diff --git a/bfv/marshaler.go b/bfv/marshaler.go
index 14b423bb..1930120c 100644
--- a/bfv/marshaler.go
+++ b/bfv/marshaler.go
@@ -11,7 +11,7 @@ func (ciphertext *Ciphertext) MarshalBinary() (data []byte, err error) {
data[0] = uint8(len(ciphertext.value))
- var pointer, inc uint64
+ var pointer, inc int
pointer = 1
@@ -34,7 +34,7 @@ func (ciphertext *Ciphertext) UnmarshalBinary(data []byte) (err error) {
ciphertext.value = make([]*ring.Poly, uint8(data[0]))
- var pointer, inc uint64
+ var pointer, inc int
pointer = 1
for i := range ciphertext.value {
@@ -52,7 +52,7 @@ func (ciphertext *Ciphertext) UnmarshalBinary(data []byte) (err error) {
}
// GetDataLen returns the length in bytes of the target Ciphertext.
-func (ciphertext *Ciphertext) GetDataLen(WithMetaData bool) (dataLen uint64) {
+func (ciphertext *Ciphertext) GetDataLen(WithMetaData bool) (dataLen int) {
if WithMetaData {
dataLen++
}
diff --git a/bfv/operand.go b/bfv/operand.go
index 3223c530..8c4834a8 100644
--- a/bfv/operand.go
+++ b/bfv/operand.go
@@ -8,7 +8,7 @@ import (
// Operand is a common interface for Ciphertext and Plaintext.
type Operand interface {
El() *Element
- Degree() uint64
+ Degree() int
}
// Element is a common struct for Plaintexts and Ciphertexts. It stores a value
@@ -27,10 +27,10 @@ func getSmallestLargest(el0, el1 *Element) (smallest, largest *Element, sameDegr
return el0, el1, true
}
-func newCiphertextElement(params *Parameters, degree uint64) *Element {
+func newCiphertextElement(params *Parameters, degree int) *Element {
el := new(Element)
el.value = make([]*ring.Poly, degree+1)
- for i := uint64(0); i < degree+1; i++ {
+ for i := 0; i < degree+1; i++ {
el.value[i] = ring.NewPoly(params.N(), params.QiCount())
}
return el
@@ -78,24 +78,24 @@ func (el *Element) SetValue(value []*ring.Poly) {
}
// Degree returns the degree of the target Element.
-func (el *Element) Degree() uint64 {
- return uint64(len(el.value) - 1)
+func (el *Element) Degree() int {
+ return len(el.value) - 1
}
// Level returns the level of the target element.
-func (el *Element) Level() uint64 {
- return uint64(len(el.value[0].Coeffs) - 1)
+func (el *Element) Level() int {
+ return len(el.value[0].Coeffs) - 1
}
// Resize resizes the degree of the target element.
-func (el *Element) Resize(params *Parameters, degree uint64) {
+func (el *Element) Resize(params *Parameters, degree int) {
if el.Degree() > degree {
el.value = el.value[:degree+1]
} else if el.Degree() < degree {
for el.Degree() < degree {
el.value = append(el.value, []*ring.Poly{new(ring.Poly)}...)
el.value[el.Degree()].Coeffs = make([][]uint64, el.Level()+1)
- for i := uint64(0); i < el.Level()+1; i++ {
+ for i := 0; i < el.Level()+1; i++ {
el.value[el.Degree()].Coeffs[i] = make([]uint64, params.N())
}
}
diff --git a/bfv/params.go b/bfv/params.go
index 415eeb8c..186dea94 100644
--- a/bfv/params.go
+++ b/bfv/params.go
@@ -157,7 +157,7 @@ func (m *LogModuli) Copy() LogModuli {
// Parameters represents a given parameter set for the BFV cryptosystem.
type Parameters struct {
- logN uint64 // Log Ring degree (power of 2)
+ logN int // Log Ring degree (power of 2)
qi []uint64
pi []uint64
t uint64 // Plaintext modulus
@@ -165,7 +165,7 @@ type Parameters struct {
}
// NewParametersFromModuli creates a new Parameters struct and returns a pointer to it.
-func NewParametersFromModuli(logN uint64, m *Moduli, t uint64) (p *Parameters, err error) {
+func NewParametersFromModuli(logN int, m *Moduli, t uint64) (p *Parameters, err error) {
p = new(Parameters)
@@ -194,7 +194,7 @@ func NewParametersFromModuli(logN uint64, m *Moduli, t uint64) (p *Parameters, e
}
// NewParametersFromLogModuli creates a new Parameters struct and returns a pointer to it.
-func NewParametersFromLogModuli(logN uint64, lm *LogModuli, t uint64) (p *Parameters, err error) {
+func NewParametersFromLogModuli(logN int, lm *LogModuli, t uint64) (p *Parameters, err error) {
if err = checkLogModuli(lm); err != nil {
return nil, err
@@ -205,12 +205,12 @@ func NewParametersFromLogModuli(logN uint64, lm *LogModuli, t uint64) (p *Parame
}
// LogN returns the log of the degree of the polynomial ring
-func (p *Parameters) LogN() uint64 {
+func (p *Parameters) LogN() int {
return p.logN
}
// N returns power of two degree of the ring
-func (p *Parameters) N() uint64 {
+func (p *Parameters) N() int {
return 1 << p.logN
}
@@ -268,8 +268,8 @@ func (p *Parameters) Qi() []uint64 {
}
// QiCount returns the number of factors of the ciphertext modulus q
-func (p *Parameters) QiCount() uint64 {
- return uint64(len(p.qi))
+func (p *Parameters) QiCount() int {
+ return len(p.qi)
}
// Pi returns a new slice with the factors of the ciphertext modulus extension P
@@ -280,17 +280,17 @@ func (p *Parameters) Pi() []uint64 {
}
// PiCount returns the number of factors of the ciphertext modulus extension P
-func (p *Parameters) PiCount() uint64 {
- return uint64(len(p.pi))
+func (p *Parameters) PiCount() int {
+ return len(p.pi)
}
// QPiCount returns the number of factors of the ciphertext modulus Q + the modulus extension P
-func (p *Parameters) QPiCount() uint64 {
+func (p *Parameters) QPiCount() int {
return p.QiCount() + p.PiCount()
}
// LogQP returns the size of the extended modulus QP in bits
-func (p *Parameters) LogQP() uint64 {
+func (p *Parameters) LogQP() int {
tmp := ring.NewUint(1)
for _, qi := range p.qi {
tmp.Mul(tmp, ring.NewUint(qi))
@@ -298,25 +298,25 @@ func (p *Parameters) LogQP() uint64 {
for _, pi := range p.pi {
tmp.Mul(tmp, ring.NewUint(pi))
}
- return uint64(tmp.BitLen())
+ return tmp.BitLen()
}
// LogQ returns the size of the modulus Q in bits
-func (p *Parameters) LogQ() uint64 {
+func (p *Parameters) LogQ() int {
tmp := ring.NewUint(1)
for _, qi := range p.qi {
tmp.Mul(tmp, ring.NewUint(qi))
}
- return uint64(tmp.BitLen())
+ return tmp.BitLen()
}
// LogP returns the size of the modulus P in bits
-func (p *Parameters) LogP() uint64 {
+func (p *Parameters) LogP() int {
tmp := ring.NewUint(1)
for _, pi := range p.pi {
tmp.Mul(tmp, ring.NewUint(pi))
}
- return uint64(tmp.BitLen())
+ return tmp.BitLen()
}
// LogQAlpha returns the size in bits of the sum of the norm of
@@ -326,7 +326,7 @@ func (p *Parameters) LogP() uint64 {
// error during the keyswitching and then divided by P.
// LogQAlpha should be smaller than P or the error added during
// the key-switching wont be negligible.
-func (p *Parameters) LogQAlpha() uint64 {
+func (p *Parameters) LogQAlpha() int {
alpha := p.PiCount()
@@ -335,8 +335,8 @@ func (p *Parameters) LogQAlpha() uint64 {
}
res := ring.NewUint(0)
- var j uint64
- for i := uint64(0); i < p.QiCount(); i = i + alpha {
+ var j int
+ for i := 0; i < p.QiCount(); i = i + alpha {
j = i + alpha
if j > p.QiCount() {
@@ -351,18 +351,18 @@ func (p *Parameters) LogQAlpha() uint64 {
res.Add(res, tmp)
}
- return uint64(res.BitLen())
+ return res.BitLen()
}
// Alpha returns the number of moduli in in P
-func (p *Parameters) Alpha() uint64 {
+func (p *Parameters) Alpha() int {
return p.PiCount()
}
// Beta returns the number of element in the RNS decomposition basis: Ceil(lenQi / lenPi)
-func (p *Parameters) Beta() uint64 {
+func (p *Parameters) Beta() int {
if p.Alpha() != 0 {
- return uint64(math.Ceil(float64(p.QiCount()) / float64(p.Alpha())))
+ return int(math.Ceil(float64(p.QiCount()) / float64(p.Alpha())))
}
return 0
@@ -389,7 +389,7 @@ func (p *Parameters) NewPolyQP() *ring.Poly {
func (p *Parameters) GaloisElementForColumnRotationBy(k int) uint64 {
twoN := 1 << (p.logN + 1)
mask := twoN - 1
- kRed := uint64(k & mask)
+ kRed := k & mask
return ring.ModExp(GaloisGen, kRed, uint64(twoN))
}
@@ -399,9 +399,8 @@ func (p *Parameters) GaloisElementForRowRotation() uint64 {
return (1 << (p.logN + 1)) - 1
}
-// GaloisElementsForRowInnerSum returns a list of all galois elements required to
-// perform an InnerSum operation. This corresponds to all the left rotations by
-// k-positions where k is a power of two and the row-rotation element.
+// GaloisElementsForRowInnerSum returns a list of galois element corresponding to
+// all the left rotations by a k-position where k is a power of two.
func (p *Parameters) GaloisElementsForRowInnerSum() (galEls []uint64) {
galEls = make([]uint64, p.logN+1, p.logN+1)
galEls[0] = p.GaloisElementForRowRotation()
@@ -414,8 +413,8 @@ func (p *Parameters) GaloisElementsForRowInnerSum() (galEls []uint64) {
// InverseGaloisElement takes a galois element and returns the galois element
// corresponding to the inverse automorphism
func (p *Parameters) InverseGaloisElement(galEl uint64) uint64 {
- twoN := uint64(1 << (p.logN + 1))
- return ring.ModExp(galEl, twoN-1, twoN)
+ twoN := 1 << (p.logN + 1)
+ return ring.ModExp(galEl, twoN-1, uint64(twoN))
}
// Copy creates a copy of the target Parameters.
@@ -481,7 +480,7 @@ func (p *Parameters) UnmarshalBinary(data []byte) error {
}
b := utils.NewBuffer(data)
- p.logN = uint64(b.ReadUint8())
+ p.logN = int(b.ReadUint8())
if p.logN > MaxLogN {
return fmt.Errorf("logN larger than %d", MaxLogN)
@@ -505,7 +504,7 @@ func (p *Parameters) UnmarshalBinary(data []byte) error {
return nil
}
-func checkModuli(m *Moduli, logN uint64) (err error) {
+func checkModuli(m *Moduli, logN int) (err error) {
if len(m.Qi) > MaxModuliCount {
return fmt.Errorf("#qi is larger than %d", MaxModuliCount)
@@ -516,27 +515,27 @@ func checkModuli(m *Moduli, logN uint64) (err error) {
}
for i, qi := range m.Qi {
- if uint64(bits.Len64(qi)-1) > MaxModuliSize+1 {
+ if bits.Len64(qi)-1 > MaxModuliSize+1 {
return fmt.Errorf("qi bit-size for i=%d is larger than %d", i, MaxModuliSize)
}
}
for i, pi := range m.Pi {
- if uint64(bits.Len64(pi)-1) > MaxModuliSize+1 {
+ if bits.Len64(pi)-1 > MaxModuliSize+1 {
return fmt.Errorf("Pi bit-size for i=%d is larger than %d", i, MaxModuliSize)
}
}
- N := uint64(1 << logN)
+ N := 1 << logN
for i, qi := range m.Qi {
- if !ring.IsPrime(qi) || qi&((N<<1)-1) != 1 {
+ if !ring.IsPrime(qi) || qi&uint64((N<<1)-1) != 1 {
return fmt.Errorf("qi n°%d is not an NTT prime", i)
}
}
for i, pi := range m.Pi {
- if !ring.IsPrime(pi) || pi&((N<<1)-1) != 1 {
+ if !ring.IsPrime(pi) || pi&uint64((N<<1)-1) != 1 {
return fmt.Errorf("Pi n°%d is not an NTT prime", i)
}
}
@@ -575,12 +574,12 @@ func checkLogModuli(lm *LogModuli) (err error) {
}
// GenModuli generates the appropriate primes from the parameters using generateNTTPrimes such that all primes are different.
-func genModuli(lm *LogModuli, logN uint64) (m *Moduli) {
+func genModuli(lm *LogModuli, logN int) (m *Moduli) {
m = new(Moduli)
// Extracts all the different primes bit-size and maps their number
- primesbitlen := make(map[uint64]uint64)
+ primesbitlen := make(map[uint64]int)
for _, qi := range lm.LogQi {
primesbitlen[qi]++
@@ -593,7 +592,7 @@ func genModuli(lm *LogModuli, logN uint64) (m *Moduli) {
// For each bit-size, it finds that many primes
primes := make(map[uint64][]uint64)
for key, value := range primesbitlen {
- primes[key] = ring.GenerateNTTPrimes(key, 2< 1.
-func (eval *evaluator) PowerOf2(op *Ciphertext, logPow2 uint64, opOut *Ciphertext) {
+func (eval *evaluator) PowerOf2(op *Ciphertext, logPow2 int, opOut *Ciphertext) {
if logPow2 == 0 {
@@ -22,7 +22,7 @@ func (eval *evaluator) PowerOf2(op *Ciphertext, logPow2 uint64, opOut *Ciphertex
panic(err)
}
- for i := uint64(1); i < logPow2; i++ {
+ for i := 1; i < logPow2; i++ {
eval.MulRelin(opOut.El(), opOut.El(), opOut)
@@ -35,7 +35,7 @@ func (eval *evaluator) PowerOf2(op *Ciphertext, logPow2 uint64, opOut *Ciphertex
// PowerNew computes op^degree, consuming log(degree) levels, and returns the result on a new element. Providing an evaluation
// key is necessary when degree > 2.
-func (eval *evaluator) PowerNew(op *Ciphertext, degree uint64) (opOut *Ciphertext) {
+func (eval *evaluator) PowerNew(op *Ciphertext, degree int) (opOut *Ciphertext) {
opOut = NewCiphertext(eval.params, 1, op.Level(), op.Scale())
eval.Power(op, degree, opOut)
return
@@ -43,13 +43,17 @@ func (eval *evaluator) PowerNew(op *Ciphertext, degree uint64) (opOut *Ciphertex
// Power computes op^degree, consuming log(degree) levels, and returns the result on opOut. Providing an evaluation
// key is necessary when degree > 2.
-func (eval *evaluator) Power(op *Ciphertext, degree uint64, opOut *Ciphertext) {
+func (eval *evaluator) Power(op *Ciphertext, degree int, opOut *Ciphertext) {
+
+ if degree < 1 {
+ panic("eval.Power -> degree cannot be smaller than 1")
+ }
tmpct0 := op.CopyNew()
- var logDegree, po2Degree uint64
+ var logDegree, po2Degree int
- logDegree = uint64(bits.Len64(degree)) - 1
+ logDegree = bits.Len64(uint64(degree)) - 1
po2Degree = 1 << logDegree
eval.PowerOf2(tmpct0.Ciphertext(), logDegree, opOut)
@@ -58,7 +62,7 @@ func (eval *evaluator) Power(op *Ciphertext, degree uint64, opOut *Ciphertext) {
for degree > 0 {
- logDegree = uint64(bits.Len64(degree)) - 1
+ logDegree = bits.Len64(uint64(degree)) - 1
po2Degree = 1 << logDegree
tmp := NewCiphertext(eval.params, 1, tmpct0.Level(), tmpct0.Scale())
@@ -77,7 +81,7 @@ func (eval *evaluator) Power(op *Ciphertext, degree uint64, opOut *Ciphertext) {
// InverseNew computes 1/op and returns the result on a new element, iterating for n steps and consuming n levels. The algorithm requires the encrypted values to be in the range
// [-1.5 - 1.5i, 1.5 + 1.5i] or the result will be wrong. Each iteration increases the precision.
-func (eval *evaluator) InverseNew(op *Ciphertext, steps uint64) (opOut *Ciphertext) {
+func (eval *evaluator) InverseNew(op *Ciphertext, steps int) (opOut *Ciphertext) {
cbar := eval.NegNew(op)
@@ -86,7 +90,7 @@ func (eval *evaluator) InverseNew(op *Ciphertext, steps uint64) (opOut *Cipherte
tmp := eval.AddConstNew(cbar, 1)
opOut = tmp.CopyNew().Ciphertext()
- for i := uint64(1); i < steps; i++ {
+ for i := 1; i < steps; i++ {
eval.MulRelin(cbar.El(), cbar.El(), cbar.Ciphertext())
diff --git a/ckks/bootstrap.go b/ckks/bootstrap.go
index 57cf82cd..5f2986e9 100644
--- a/ckks/bootstrap.go
+++ b/ckks/bootstrap.go
@@ -2,11 +2,7 @@ package ckks
import (
"github.com/ldsec/lattigo/v2/ring"
- "github.com/ldsec/lattigo/v2/utils"
-
- //"log"
"math"
- //"time"
)
// Bootstrapp re-encrypt a ciphertext at lvl Q0 to a ciphertext at MaxLevel-k where k is the depth of the bootstrapping circuit.
@@ -14,6 +10,7 @@ import (
// If the input ciphertext is at level one or more, the input scale does not need to be an exact power of two as one level
// can be used to do a scale matching.
func (btp *Bootstrapper) Bootstrapp(ct *Ciphertext) *Ciphertext {
+
//var t time.Time
var ct0, ct1 *Ciphertext
@@ -40,7 +37,11 @@ func (btp *Bootstrapper) Bootstrapp(ct *Ciphertext) *Ciphertext {
btp.evaluator.DropLevel(ct, 1)
}
- // and does an integer constant mult by round((Q0/2^{10})/ctscle)
+ // and does an integer constant mult by round((Q0/Delta_m)/ctscle)
+
+ if btp.prescale < ct.Scale() {
+ panic("ciphetext scale > Q[0]/(Q[0]/Delta_m)")
+ }
btp.evaluator.ScaleUp(ct, math.Round(btp.prescale/ct.Scale()), ct)
}
@@ -59,7 +60,7 @@ func (btp *Bootstrapper) Bootstrapp(ct *Ciphertext) *Ciphertext {
// Part 1 : Coeffs to slots
//t = time.Now()
- ct0, ct1 = btp.coeffsToSlots(ct)
+ ct0, ct1 = CoeffsToSlots(ct, btp.pDFTInv, btp.evaluator)
//log.Println("After CtS :", time.Now().Sub(t), ct0.Level(), ct0.Scale())
// Part 2 : SineEval
@@ -69,7 +70,7 @@ func (btp *Bootstrapper) Bootstrapp(ct *Ciphertext) *Ciphertext {
// Part 3 : Slots to coeffs
//t = time.Now()
- ct0 = btp.slotsToCoeffs(ct0, ct1)
+ ct0 = SlotsToCoeffs(ct0, ct1, btp.pDFT, btp.evaluator)
ct0.SetScale(math.Exp2(math.Round(math.Log2(ct0.Scale())))) // rounds to the nearest power of two
//log.Println("After StC :", time.Now().Sub(t), ct0.Level(), ct0.Scale())
@@ -80,9 +81,9 @@ func (btp *Bootstrapper) subSum(ct *Ciphertext) *Ciphertext {
for i := btp.params.logSlots; i < btp.params.MaxLogSlots(); i++ {
- btp.evaluator.Rotate(ct, 1<> 1) - 1)
-
- if i != 0 {
- cOutQ[i] = [2]*ring.Poly{eval.ringQ.NewPolyLvl(ct0.Level()), eval.ringQ.NewPolyLvl(ct0.Level())}
- cOutP[i] = [2]*ring.Poly{eval.params.NewPolyP(), eval.params.NewPolyP()}
- eval.permuteNTTHoistedNoModDown(ct0, c2QiQDecomp, c2QiPDecomp, i, rotkeys, cOutQ[i], cOutP[i])
- }
- }
-
- c2QiQDecomp = nil
- c2QiPDecomp = nil
-
- return
-}
-
-func (eval *evaluator) permuteNTTHoistedNoModDown(ct0 *Ciphertext, c2QiQDecomp, c2QiPDecomp []*ring.Poly, k uint64, rotKeys *RotationKeySet, ctOutQ, ctOutP [2]*ring.Poly) {
-
- pool2Q := eval.poolQ[0]
- pool3Q := eval.poolQ[1]
-
- pool2P := eval.poolP[0]
- pool3P := eval.poolP[1]
-
- levelQ := ct0.Level()
- levelP := eval.params.PiCount() - 1
-
- galEl := eval.params.GaloisElementForColumnRotationBy(int(k))
- rtk := rotKeys.Keys[galEl]
- indexes := eval.permuteNTTIndex[galEl]
-
- eval.keyswitchHoistedNoModDown(levelQ, c2QiQDecomp, c2QiPDecomp, rtk, pool2Q, pool3Q, pool2P, pool3P)
-
- ring.PermuteNTTWithIndexLvl(levelQ, pool2Q, indexes, ctOutQ[0])
- ring.PermuteNTTWithIndexLvl(levelQ, pool3Q, indexes, ctOutQ[1])
-
- ring.PermuteNTTWithIndexLvl(levelP, pool2P, indexes, ctOutP[0])
- ring.PermuteNTTWithIndexLvl(levelP, pool3P, indexes, ctOutP[1])
-}
-
// Sine Evaluation ct0 = Q/(2pi) * sin((2pi/Q) * ct0)
func (btp *Bootstrapper) evaluateSine(ct0, ct1 *Ciphertext) (*Ciphertext, *Ciphertext) {
- ct0.MulScale(btp.deviation)
- btp.scale = ct0.Scale() // Reference scale is changed to the new ciphertext's scale.
-
- // pre-computes the target scale for the output of the polynomial evaluation such that
- // the output scale after the polynomial evaluation followed by the double angle formula
- // does not change the scale of the ciphertext.
- for i := uint64(0); i < btp.SinRescal; i++ {
- btp.scale *= float64(btp.params.qi[btp.StCLevel[0]+i+1])
- btp.scale = math.Sqrt(btp.scale)
- }
+ ct0.MulScale(btp.MessageRatio)
+ btp.evaluator.scale = btp.sinescale // Reference scale is changed to the Qi used for the SineEval (which is also close to the new ciphetext scale)
ct0 = btp.evaluateCheby(ct0)
- ct0.DivScale(btp.deviation * btp.postscale / btp.params.scale)
+ ct0.DivScale(btp.MessageRatio * btp.postscale / btp.params.scale)
if ct1 != nil {
- ct1.MulScale(btp.deviation)
+ ct1.MulScale(btp.MessageRatio)
ct1 = btp.evaluateCheby(ct1)
- ct1.DivScale(btp.deviation * btp.postscale / btp.params.scale)
+ ct1.DivScale(btp.MessageRatio * btp.postscale / btp.params.scale)
}
// Reference scale is changed back to the current ciphertext's scale.
- btp.scale = ct0.Scale()
+ btp.evaluator.scale = ct0.Scale()
return ct0, ct1
}
func (btp *Bootstrapper) evaluateCheby(ct *Ciphertext) *Ciphertext {
- cheby := btp.chebycoeffs
+ var err error
- sqrt2pi := math.Pow(0.15915494309189535, 1.0/float64(int(1< 0 {
+ if ct, err = btp.EvaluatePoly(ct, btp.arcSinePoly, ct.Scale()); err != nil {
panic(err)
}
}
diff --git a/ckks/bootstrap_bench_test.go b/ckks/bootstrap_bench_test.go
index 834a6735..01c45086 100644
--- a/ckks/bootstrap_bench_test.go
+++ b/ckks/bootstrap_bench_test.go
@@ -16,16 +16,25 @@ func BenchmarkBootstrapp(b *testing.B) {
var testContext = new(testParams)
var btp *Bootstrapper
- paramSet := uint64(3)
+ paramSet := 2
btpParams := DefaultBootstrapParams[paramSet]
- if testContext, err = genTestParams(DefaultBootstrapSchemeParams[paramSet], btpParams.H); err != nil {
+
+ params, err := btpParams.Params()
+ if err != nil {
+ panic(err)
+ }
+ if testContext, err = genTestParams(params, btpParams.H); err != nil {
panic(err)
}
- btpKey := testContext.kgen.GenBootstrappingKey(testContext.params.logSlots, btpParams, testContext.sk)
+ rotations := testContext.kgen.GenRotationIndexesForBootstrapping(testContext.params.logSlots, btpParams)
- if btp, err = NewBootstrapper(testContext.params, btpParams, *btpKey); err != nil {
+ rotkeys := testContext.kgen.GenRotationKeysForRotations(rotations, true, testContext.sk)
+
+ btpKey := BootstrappingKey{testContext.rlk, rotkeys}
+
+ if btp, err = NewBootstrapper(testContext.params, btpParams, btpKey); err != nil {
panic(err)
}
@@ -57,7 +66,7 @@ func BenchmarkBootstrapp(b *testing.B) {
// Part 1 : Coeffs to slots
t = time.Now()
- ct0, ct1 = btp.coeffsToSlots(ct)
+ ct0, ct1 = CoeffsToSlots(ct, btp.pDFTInv, btp.evaluator)
b.Log("After CtS :", time.Since(t), ct0.Level(), ct0.Scale())
// Part 2 : SineEval
@@ -67,7 +76,7 @@ func BenchmarkBootstrapp(b *testing.B) {
// Part 3 : Slots to coeffs
t = time.Now()
- ct0 = btp.slotsToCoeffs(ct0, ct1)
+ ct0 = SlotsToCoeffs(ct0, ct1, btp.pDFT, btp.evaluator)
ct0.SetScale(math.Exp2(math.Round(math.Log2(ct0.Scale()))))
b.Log("After StC :", time.Since(t), ct0.Level(), ct0.Scale())
}
diff --git a/ckks/bootstrap_params.go b/ckks/bootstrap_params.go
index 7b19be39..64ff0bee 100644
--- a/ckks/bootstrap_params.go
+++ b/ckks/bootstrap_params.go
@@ -1,26 +1,10 @@
package ckks
-// BootstrappingParameters is a struct for the default bootstrapping parameters
-type BootstrappingParameters struct {
- H uint64 // Hamming weight of the secret key
- SinType SinType // Choose between [Sin(2*pi*x)] or [cos(2*pi*x/r) with double angle formula]
- SinRange uint64 // K parameter (interpolation in the range -K to K)
- SinDeg uint64 // Degree of the interpolation
- SinRescal uint64 // Number of rescale and double angle formula (only applies for cos)
- CtSLevel []uint64 // Level of the Coeffs To Slots
- StCLevel []uint64 // Level of the Slots To Coeffs
- MaxN1N2Ratio float64 // n1/n2 ratio for the bsgs algo for matrix x vector eval
-}
-
-// CtSDepth returns the number of levels allocated to CoeffsToSlots
-func (b *BootstrappingParameters) CtSDepth() uint64 {
- return uint64(len(b.CtSLevel))
-}
-
-// StCDepth returns the number of levels allocated to SlotToCoeffs
-func (b *BootstrappingParameters) StCDepth() uint64 {
- return uint64(len(b.StCLevel))
-}
+import (
+ "github.com/ldsec/lattigo/v2/utils"
+ "math"
+ //"fmt"
+)
// SinType is the type of function used during the bootstrapping
// for the homomorphic modular reduction
@@ -33,71 +17,437 @@ const (
Cos2 = SinType(2) // Standard Chebyshev approximation of pow((1/2pi), 1/2^r) * cos(2pi(x-0.25)/2^r)
)
-// Copy return a new BootstrapParams which is a copy of the target
+// BootstrappingParameters is a struct for the default bootstrapping parameters
+type BootstrappingParameters struct {
+ ResidualModuli
+ KeySwitchModuli
+ SlotsToCoeffsModuli
+ SineEvalModuli
+ CoeffsToSlotsModuli
+ LogN int
+ LogSlots int
+ Scale float64
+ Sigma float64
+ H int // Hamming weight of the secret key
+ SinType SinType // Chose betwenn [Sin(2*pi*x)] or [cos(2*pi*x/r) with double angle formula]
+ MessageRatio float64 // Ratio between Q0 and m, i.e. Q[0]/|m|
+ SinRange int // K parameter (interpolation in the range -K to K)
+ SinDeg int // Degree of the interpolation
+ SinRescal int // Number of rescale and double angle formula (only applies for cos)
+ ArcSineDeg int // Degree of the Taylor arcsine composed with f(2*pi*x) (if zero then not used)
+ MaxN1N2Ratio float64 // n1/n2 ratio for the bsgs algo for matrix x vector eval
+ BitReversed bool // Flag for bit-reverseed input to the DFT (with bit-reversed output), by default false.
+}
+
+// Params generates a new set of Parameters from the BootstrappingParameters
+func (b *BootstrappingParameters) Params() (p *Parameters, err error) {
+ Qi := append(b.ResidualModuli, b.SlotsToCoeffsModuli.Qi...)
+ Qi = append(Qi, b.SineEvalModuli.Qi...)
+ Qi = append(Qi, b.CoeffsToSlotsModuli.Qi...)
+
+ if p, err = NewParametersFromModuli(b.LogN, &Moduli{Qi, b.KeySwitchModuli}); err != nil {
+ return nil, err
+ }
+
+ p.SetScale(b.Scale)
+ p.SetLogSlots(b.LogSlots)
+ p.SetSigma(b.Sigma)
+ return
+}
+
+// Copy return a new BootstrappingParameters which is a copy of the target
func (b *BootstrappingParameters) Copy() *BootstrappingParameters {
paramsCopy := &BootstrappingParameters{
+ LogN: b.LogN,
+ LogSlots: b.LogSlots,
+ Scale: b.Scale,
+ Sigma: b.Sigma,
H: b.H,
SinType: b.SinType,
+ MessageRatio: b.MessageRatio,
SinRange: b.SinRange,
SinDeg: b.SinDeg,
SinRescal: b.SinRescal,
- CtSLevel: make([]uint64, len(b.CtSLevel)),
- StCLevel: make([]uint64, len(b.StCLevel)),
+ ArcSineDeg: b.ArcSineDeg,
MaxN1N2Ratio: b.MaxN1N2Ratio,
+ BitReversed: b.BitReversed,
}
- copy(paramsCopy.CtSLevel, b.CtSLevel)
- copy(paramsCopy.StCLevel, b.StCLevel)
+
+ // KeySwitchModuli
+ paramsCopy.KeySwitchModuli = make([]uint64, len(b.KeySwitchModuli))
+ copy(paramsCopy.KeySwitchModuli, b.KeySwitchModuli)
+
+ // ResidualModuli
+ paramsCopy.ResidualModuli = make([]uint64, len(b.ResidualModuli))
+ copy(paramsCopy.ResidualModuli, b.ResidualModuli)
+
+ // CoeffsToSlotsModuli
+ paramsCopy.CoeffsToSlotsModuli.Qi = make([]uint64, b.CtSDepth(true))
+ copy(paramsCopy.CoeffsToSlotsModuli.Qi, b.CoeffsToSlotsModuli.Qi)
+
+ paramsCopy.CoeffsToSlotsModuli.ScalingFactor = make([][]float64, b.CtSDepth(true))
+ for i := range paramsCopy.CoeffsToSlotsModuli.ScalingFactor {
+ paramsCopy.CoeffsToSlotsModuli.ScalingFactor[i] = make([]float64, len(b.CoeffsToSlotsModuli.ScalingFactor[i]))
+ copy(paramsCopy.CoeffsToSlotsModuli.ScalingFactor[i], b.CoeffsToSlotsModuli.ScalingFactor[i])
+ }
+
+ // SineEvalModuli
+ paramsCopy.SineEvalModuli.Qi = make([]uint64, len(b.SineEvalModuli.Qi))
+ copy(paramsCopy.SineEvalModuli.Qi, b.SineEvalModuli.Qi)
+ paramsCopy.SineEvalModuli.ScalingFactor = b.SineEvalModuli.ScalingFactor
+
+ // SlotsToCoeffsModuli
+ paramsCopy.SlotsToCoeffsModuli.Qi = make([]uint64, b.StCDepth(true))
+ copy(paramsCopy.SlotsToCoeffsModuli.Qi, b.SlotsToCoeffsModuli.Qi)
+
+ paramsCopy.SlotsToCoeffsModuli.ScalingFactor = make([][]float64, b.StCDepth(true))
+ for i := range paramsCopy.SlotsToCoeffsModuli.ScalingFactor {
+ paramsCopy.SlotsToCoeffsModuli.ScalingFactor[i] = make([]float64, len(b.SlotsToCoeffsModuli.ScalingFactor[i]))
+ copy(paramsCopy.SlotsToCoeffsModuli.ScalingFactor[i], b.SlotsToCoeffsModuli.ScalingFactor[i])
+ }
+
return paramsCopy
}
-// DefaultBootstrapSchemeParams are default scheme params for the bootstrapping
-var DefaultBootstrapSchemeParams = []*Parameters{
+// ResidualModuli is a list of the moduli available after the bootstrapping.
+type ResidualModuli []uint64
+// KeySwitchModuli is a list of the special moduli used for the key-switching.
+type KeySwitchModuli []uint64
+
+// CoeffsToSlotsModuli is a list of the moduli used during he CoeffsToSlots step.
+type CoeffsToSlotsModuli struct {
+ Qi []uint64
+ ScalingFactor [][]float64
+}
+
+// SineEvalModuli is a list of the moduli used during the SineEval step.
+type SineEvalModuli struct {
+ Qi []uint64
+ ScalingFactor float64
+}
+
+// SlotsToCoeffsModuli is a list of the moduli used during the SlotsToCoeffs step.
+type SlotsToCoeffsModuli struct {
+ Qi []uint64
+ ScalingFactor [][]float64
+}
+
+// MaxLevel returns the maximum level of the bootstrapping parameters
+func (b *BootstrappingParameters) MaxLevel() int {
+ return len(b.ResidualModuli) + len(b.CoeffsToSlotsModuli.Qi) + len(b.SineEvalModuli.Qi) + len(b.SlotsToCoeffsModuli.Qi) - 1
+}
+
+// SineEvalDepth returns the depth of the SineEval. If true, then also
+// counts the double angle formula.
+func (b *BootstrappingParameters) SineEvalDepth(withRescale bool) int {
+ depth := int(math.Ceil(math.Log2(float64(b.SinDeg + 1))))
+
+ if withRescale {
+ depth += b.SinRescal
+ }
+
+ return depth
+}
+
+// ArcSineDepth returns the depth of the arcsine polynomial.
+func (b *BootstrappingParameters) ArcSineDepth() int {
+ return int(math.Ceil(math.Log2(float64(b.ArcSineDeg + 1))))
+}
+
+// CtSDepth returns the number of levels allocated to CoeffsToSlots.
+// If actual == true then returns the number of moduli consumed, else
+// returns the factorization depth.
+func (b *BootstrappingParameters) CtSDepth(actual bool) (depth int) {
+ if actual {
+ depth = len(b.CoeffsToSlotsModuli.ScalingFactor)
+ } else {
+ for i := range b.CoeffsToSlotsModuli.ScalingFactor {
+ for range b.CoeffsToSlotsModuli.ScalingFactor[i] {
+ depth++
+ }
+ }
+ }
+
+ return
+}
+
+// CtSLevels returns the index of the Qi used int CoeffsToSlots
+func (b *BootstrappingParameters) CtSLevels() (ctsLevel []int) {
+ ctsLevel = []int{}
+ for i := range b.CoeffsToSlotsModuli.Qi {
+ for range b.CoeffsToSlotsModuli.ScalingFactor[b.CtSDepth(true)-1-i] {
+ ctsLevel = append(ctsLevel, b.MaxLevel()-i)
+ }
+ }
+
+ return
+}
+
+// StCDepth returns the number of levels allocated to SlotToCoeffs.
+// If actual == true then returns the number of moduli consumed, else
+// returns the factorization depth.
+func (b *BootstrappingParameters) StCDepth(actual bool) (depth int) {
+ if actual {
+ depth = len(b.SlotsToCoeffsModuli.ScalingFactor)
+ } else {
+ for i := range b.SlotsToCoeffsModuli.ScalingFactor {
+ for range b.SlotsToCoeffsModuli.ScalingFactor[i] {
+ depth++
+ }
+ }
+ }
+
+ return
+}
+
+// StCLevels returns the index of the Qi used in SlotsToCoeffs
+func (b *BootstrappingParameters) StCLevels() (stcLevel []int) {
+ stcLevel = []int{}
+ for i := range b.SlotsToCoeffsModuli.Qi {
+ for range b.SlotsToCoeffsModuli.ScalingFactor[b.StCDepth(true)-1-i] {
+ stcLevel = append(stcLevel, b.MaxLevel()-b.CtSDepth(true)-b.SineEvalDepth(true)-b.ArcSineDepth()-i)
+ }
+ }
+
+ return
+}
+
+// GenCoeffsToSlotsMatrix generates the factorized encoding matrix
+// scaling : constant by witch the all the matrices will be multuplied by
+// encoder : ckks.Encoder
+func (b *BootstrappingParameters) GenCoeffsToSlotsMatrix(scaling complex128, encoder Encoder) []*PtDiagMatrix {
+
+ logSlots := b.LogSlots
+ slots := 1 << logSlots
+ depth := b.CtSDepth(false)
+ logdSlots := logSlots + 1
+ if logdSlots == b.LogN {
+ logdSlots--
+ }
+
+ roots := computeRoots(slots << 1)
+ pow5 := make([]int, (slots<<1)+1)
+ pow5[0] = 1
+ for i := 1; i < (slots<<1)+1; i++ {
+ pow5[i] = pow5[i-1] * 5
+ pow5[i] &= (slots << 2) - 1
+ }
+
+ ctsLevels := b.CtSLevels()
+
+ // CoeffsToSlots vectors
+ pDFTInv := make([]*PtDiagMatrix, len(ctsLevels))
+ pVecDFTInv := computeDFTMatrices(logSlots, logdSlots, depth, roots, pow5, scaling, true, b.BitReversed)
+ cnt := 0
+ for i := range b.CoeffsToSlotsModuli.ScalingFactor {
+ for j := range b.CoeffsToSlotsModuli.ScalingFactor[b.CtSDepth(true)-i-1] {
+ pDFTInv[cnt] = encoder.EncodeDiagMatrixBSGSAtLvl(ctsLevels[cnt], pVecDFTInv[cnt], b.CoeffsToSlotsModuli.ScalingFactor[b.CtSDepth(true)-i-1][j], b.MaxN1N2Ratio, logdSlots)
+ cnt++
+ }
+ }
+
+ return pDFTInv
+}
+
+// GenSlotsToCoeffsMatrix generates the factorized decoding matrix
+// scaling : constant by witch the all the matrices will be multuplied by
+// encoder : ckks.Encoder
+func (b *BootstrappingParameters) GenSlotsToCoeffsMatrix(scaling complex128, encoder Encoder) []*PtDiagMatrix {
+
+ logSlots := b.LogSlots
+ slots := 1 << logSlots
+ depth := b.StCDepth(false)
+ logdSlots := logSlots + 1
+ if logdSlots == b.LogN {
+ logdSlots--
+ }
+
+ roots := computeRoots(slots << 1)
+ pow5 := make([]int, (slots<<1)+1)
+ pow5[0] = 1
+ for i := 1; i < (slots<<1)+1; i++ {
+ pow5[i] = pow5[i-1] * 5
+ pow5[i] &= (slots << 2) - 1
+ }
+
+ stcLevels := b.StCLevels()
+
+ // CoeffsToSlots vectors
+ pDFT := make([]*PtDiagMatrix, len(stcLevels))
+ pVecDFT := computeDFTMatrices(logSlots, logdSlots, depth, roots, pow5, scaling, false, b.BitReversed)
+ cnt := 0
+ for i := range b.SlotsToCoeffsModuli.ScalingFactor {
+ for j := range b.SlotsToCoeffsModuli.ScalingFactor[b.StCDepth(true)-i-1] {
+ pDFT[cnt] = encoder.EncodeDiagMatrixBSGSAtLvl(stcLevels[cnt], pVecDFT[cnt], b.SlotsToCoeffsModuli.ScalingFactor[b.StCDepth(true)-i-1][j], b.MaxN1N2Ratio, logdSlots)
+ cnt++
+ }
+ }
+
+ return pDFT
+}
+
+// DefaultBootstrapParams are default bootstrapping params for the bootstrapping.
+var DefaultBootstrapParams = []*BootstrappingParameters{
+
+ // SET I
+ // 1546
{
- logN: 16,
- logSlots: 15,
- qi: []uint64{
- 0x80000000080001, // 55 Q0
- 0x2000000a0001, // 45
- 0x2000000e0001, // 45
- 0x1fffffc20001, // 45
- 0x200000440001, // 45
- 0x200000500001, // 45
- 0x200000620001, // 45
- 0x1fffff980001, // 45
- 0x2000006a0001, // 45
- 0x1fffff7e0001, // 45
- 0x200000860001, // 45
- 0x100000000060001, // 56 StC (28 + 28)
- 0xffa0001, // 28 StC
- 0x80000000440001, // 55 Sine (double angle)
- 0x7fffffffba0001, // 55 Sine (double angle)
- 0x80000000500001, // 55 Sine
- 0x7fffffffaa0001, // 55 Sine
- 0x800000005e0001, // 55 Sine
- 0x7fffffff7e0001, // 55 Sine
- 0x7fffffff380001, // 55 Sine
- 0x80000000ca0001, // 55 Sine
- 0x200000000e0001, // 53 CtS
- 0x20000000140001, // 53 CtS
- 0x20000000280001, // 53 CtS
- 0x1fffffffd80001, // 53 CtS
+ LogN: 16,
+ LogSlots: 15,
+ Scale: 1 << 40,
+ Sigma: DefaultSigma,
+ ResidualModuli: []uint64{
+ 0x10000000006e0001, // 60 Q0
+ 0x10000140001, // 40
+ 0xffffe80001, // 40
+ 0xffffc40001, // 40
+ 0x100003e0001, // 40
+ 0xffffb20001, // 40
+ 0x10000500001, // 40
+ 0xffff940001, // 40
+ 0xffff8a0001, // 40
+ 0xffff820001, // 40
},
- pi: []uint64{
- 0xfffffffff00001, // 56
- 0xffffffffd80001, // 56
- 0x1000000002a0001, // 56
- 0xffffffffd20001, // 56
- 0x100000000480001, // 56
+ KeySwitchModuli: []uint64{
+ 0x1fffffffffe00001, // Pi 61
+ 0x1fffffffffc80001, // Pi 61
+ 0x1fffffffffb40001, // Pi 61
+ 0x1fffffffff500001, // Pi 61
+ 0x1fffffffff420001, // Pi 61
},
- scale: 1 << 45,
- sigma: DefaultSigma,
+ SlotsToCoeffsModuli: SlotsToCoeffsModuli{
+ Qi: []uint64{
+ 0x7fffe60001, // 39 StC
+ 0x7fffe40001, // 39 StC
+ 0x7fffe00001, // 39 StC
+ },
+ ScalingFactor: [][]float64{
+ {0x7fffe60001},
+ {0x7fffe40001},
+ {0x7fffe00001},
+ },
+ },
+ SineEvalModuli: SineEvalModuli{
+ Qi: []uint64{
+ 0xfffffffff840001, // 60 Sine (double angle)
+ 0x1000000000860001, // 60 Sine (double angle)
+ 0xfffffffff6a0001, // 60 Sine
+ 0x1000000000980001, // 60 Sine
+ 0xfffffffff5a0001, // 60 Sine
+ 0x1000000000b00001, // 60 Sine
+ 0x1000000000ce0001, // 60 Sine
+ 0xfffffffff2a0001, // 60 Sine
+ },
+ ScalingFactor: 1 << 60,
+ },
+ CoeffsToSlotsModuli: CoeffsToSlotsModuli{
+ Qi: []uint64{
+ 0x100000000060001, // 58 CtS
+ 0xfffffffff00001, // 58 CtS
+ 0xffffffffd80001, // 58 CtS
+ 0x1000000002a0001, // 58 CtS
+ },
+ ScalingFactor: [][]float64{
+ {0x100000000060001},
+ {0xfffffffff00001},
+ {0xffffffffd80001},
+ {0x1000000002a0001},
+ },
+ },
+ H: 192,
+ SinType: Cos1,
+ MessageRatio: 256.0,
+ SinRange: 25,
+ SinDeg: 63,
+ SinRescal: 2,
+ ArcSineDeg: 0,
+ MaxN1N2Ratio: 16.0,
+ BitReversed: false,
},
+ // SET II
+ // 1547
{
- logN: 16,
- logSlots: 15,
- qi: []uint64{
+ LogN: 16,
+ LogSlots: 15,
+ Scale: 1 << 45,
+ Sigma: DefaultSigma,
+ ResidualModuli: []uint64{
+ 0x10000000006e0001, // 60 Q0
+ 0x2000000a0001, // 45
+ 0x2000000e0001, // 45
+ 0x1fffffc20001, // 45
+ 0x200000440001, // 45
+ 0x200000500001, // 45
+ },
+ KeySwitchModuli: []uint64{
+ 0x1fffffffffe00001, // Pi 61
+ 0x1fffffffffc80001, // Pi 61
+ 0x1fffffffffb40001, // Pi 61
+ 0x1fffffffff500001, // Pi 61
+ },
+ SlotsToCoeffsModuli: SlotsToCoeffsModuli{
+ Qi: []uint64{
+ 0x3ffffe80001, //42 StC
+ 0x3ffffd20001, //42 StC
+ 0x3ffffca0001, //42 StC
+ },
+ ScalingFactor: [][]float64{
+ {0x3ffffe80001},
+ {0x3ffffd20001},
+ {0x3ffffca0001},
+ },
+ },
+ SineEvalModuli: SineEvalModuli{
+ Qi: []uint64{
+ 0xffffffffffc0001, // ArcSine
+ 0xfffffffff240001, // ArcSine
+ 0x1000000000f00001, // ArcSine
+ 0xfffffffff840001, // Double angle
+ 0x1000000000860001, // Double angle
+ 0xfffffffff6a0001, // Sine
+ 0x1000000000980001, // Sine
+ 0xfffffffff5a0001, // Sine
+ 0x1000000000b00001, // Sine
+ 0x1000000000ce0001, // Sine
+ 0xfffffffff2a0001, // Sine
+ },
+ ScalingFactor: 1 << 60,
+ },
+ CoeffsToSlotsModuli: CoeffsToSlotsModuli{
+ Qi: []uint64{
+ 0x400000000360001, // 58 CtS
+ 0x3ffffffffbe0001, // 58 CtS
+ 0x400000000660001, // 58 CtS
+ 0x4000000008a0001, // 58 CtS
+ },
+ ScalingFactor: [][]float64{
+ {0x400000000360001},
+ {0x3ffffffffbe0001},
+ {0x400000000660001},
+ {0x4000000008a0001},
+ },
+ },
+ H: 192,
+ SinType: Cos1,
+ MessageRatio: 4.0,
+ SinRange: 25,
+ SinDeg: 63,
+ SinRescal: 2,
+ ArcSineDeg: 7,
+ MaxN1N2Ratio: 16.0,
+ BitReversed: false,
+ },
+
+ // SET III
+ // 1553
+ {
+ LogN: 16,
+ LogSlots: 15,
+ Scale: 1 << 30,
+ Sigma: DefaultSigma,
+ ResidualModuli: []uint64{
0x80000000080001, // 55 Q0
0xffffffffffc0001, // 60
0x10000000006e0001, // 60
@@ -106,65 +456,82 @@ var DefaultBootstrapSchemeParams = []*Parameters{
0xfffffffff6a0001, // 60
0x1000000000980001, // 60
0xfffffffff5a0001, // 60
- 0x1000000000b00001, // 60 StC (30)
- 0x1000000000ce0001, // 60 StC (30+30)
- 0x80000000440001, // 55 Sine (double angle)
- 0x7fffffffba0001, // 55 Sine (double angle)
- 0x80000000500001, // 55 Sine
- 0x7fffffffaa0001, // 55 Sine
- 0x800000005e0001, // 55 Sine
- 0x7fffffff7e0001, // 55 Sine
- 0x7fffffff380001, // 55 Sine
- 0x80000000ca0001, // 55 Sine
- 0x200000000e0001, // 53 CtS
- 0x20000000140001, // 53 CtS
- 0x20000000280001, // 53 CtS
- 0x1fffffffd80001, // 53 CtS
},
- pi: []uint64{
+ KeySwitchModuli: []uint64{
0x1fffffffffe00001, // Pi 61
0x1fffffffffc80001, // Pi 61
0x1fffffffffb40001, // Pi 61
0x1fffffffff500001, // Pi 61
0x1fffffffff420001, // Pi 61
},
- scale: 1 << 30,
- sigma: DefaultSigma,
+ SlotsToCoeffsModuli: SlotsToCoeffsModuli{
+ Qi: []uint64{
+ 0x1000000000b00001, // 60 StC (30)
+ 0x1000000000ce0001, // 60 StC (30+30)
+ },
+ ScalingFactor: [][]float64{
+ {1073741824.0},
+ {1073741824.0062866, 1073741824.0062866},
+ },
+ },
+ SineEvalModuli: SineEvalModuli{
+ Qi: []uint64{
+ 0x80000000440001, // 55 Sine (double angle)
+ 0x7fffffffba0001, // 55 Sine (double angle)
+ 0x80000000500001, // 55 Sine
+ 0x7fffffffaa0001, // 55 Sine
+ 0x800000005e0001, // 55 Sine
+ 0x7fffffff7e0001, // 55 Sine
+ 0x7fffffff380001, // 55 Sine
+ 0x80000000ca0001, // 55 Sine
+ },
+ ScalingFactor: 1 << 55,
+ },
+ CoeffsToSlotsModuli: CoeffsToSlotsModuli{
+ Qi: []uint64{
+ 0x200000000e0001, // 53 CtS
+ 0x20000000140001, // 53 CtS
+ 0x20000000280001, // 53 CtS
+ 0x1fffffffd80001, // 53 CtS
+ },
+ ScalingFactor: [][]float64{
+ {0x200000000e0001},
+ {0x20000000140001},
+ {0x20000000280001},
+ {0x1fffffffd80001},
+ },
+ },
+ H: 192,
+ SinType: Cos1,
+ MessageRatio: 256.0,
+ SinRange: 25,
+ SinDeg: 63,
+ SinRescal: 2,
+ ArcSineDeg: 0,
+ MaxN1N2Ratio: 16.0,
+ BitReversed: false,
},
+ // Set IV
+ // 1792
{
- logN: 16,
- logSlots: 15,
- qi: []uint64{
- 0x80000000080001, // 55 Q0
- 0x2000000a0001, // 45
- 0x2000000e0001, // 45
- 0x1fffffc20001, // 45
- 0x200000440001, // 45
- 0x200000500001, // 45
- 0x200000620001, // 45
- 0x1fffff980001, // 45
- 0x2000006a0001, // 45
- 0x1fffff7e0001, // 45
- 0x100000000060001, // 56 StC (28 + 28)
- 0xffa0001, // 28 StC
- 0xffffffffffc0001, // 60 Sine (double angle)
- 0x10000000006e0001, // 60 Sine (double angle)
- 0xfffffffff840001, // 60 Sine (double angle)
- 0x1000000000860001, // 60 Sine
- 0xfffffffff6a0001, // 60 Sine
- 0x1000000000980001, // 60 Sine
- 0xfffffffff5a0001, // 60 Sine
- 0x1000000000b00001, // 60 Sine
- 0x1000000000ce0001, // 60 Sine
- 0xfffffffff2a0001, // 60 Sine
- 0xfffffffff240001, // 60 Sine
- 0x200000000e0001, // 53 CtS
- 0x20000000140001, // 53 CtS
- 0x20000000280001, // 53 CtS
- 0x1fffffffd80001, // 53 CtS
+ LogN: 16,
+ LogSlots: 15,
+ Scale: 1 << 40,
+ Sigma: DefaultSigma,
+ ResidualModuli: []uint64{
+ 0x4000000120001, // 60 Q0
+ 0x10000140001,
+ 0xffffe80001,
+ 0xffffc40001,
+ 0x100003e0001,
+ 0xffffb20001,
+ 0x10000500001,
+ 0xffff940001,
+ 0xffff8a0001,
+ 0xffff820001,
},
- pi: []uint64{
+ KeySwitchModuli: []uint64{
0x1fffffffffe00001, // Pi 61
0x1fffffffffc80001, // Pi 61
0x1fffffffffb40001, // Pi 61
@@ -172,90 +539,563 @@ var DefaultBootstrapSchemeParams = []*Parameters{
0x1fffffffff420001, // Pi 61
0x1fffffffff380001, // Pi 61
},
- scale: 1 << 45,
- sigma: DefaultSigma,
- },
-
- {
- logN: 15,
- logSlots: 14,
- qi: []uint64{
- 0x7fffb0001, // 35 Q0
- 0x4000000420001, // 50
- 0x1fc0001, // 25
- 0xffffffffffc0001, // 60 StC (30+30)
- 0x4000000120001, // 50 Sine
- 0x40000001b0001, // 50 Sine
- 0x3ffffffdf0001, // 50 Sine
- 0x4000000270001, // 50 Sine
- 0x3ffffffd20001, // 50 Sine
- 0x3ffffffcd0001, // 50 Sine
- 0x4000000350001, // 50 Sine
- 0x3ffffffc70001, // 50 Sine
- 0x1fffffff50001, // 49 CtS
- 0x1ffffffea0001, // 49 CtS
+ SlotsToCoeffsModuli: SlotsToCoeffsModuli{
+ Qi: []uint64{
+ 0x100000000060001, // 56 StC (28 + 28)
+ 0xffa0001, // 28 StC
+ },
+ ScalingFactor: [][]float64{
+ {268435456.0007324, 268435456.0007324},
+ {0xffa0001},
+ },
},
- pi: []uint64{
- 0x7e40000000001, // 50
- 0x7c80000000001, // 50
+ SineEvalModuli: SineEvalModuli{
+ Qi: []uint64{
+ 0xffffffffffc0001, // 60 Sine (double angle)
+ 0x10000000006e0001, // 60 Sine (double angle)
+ 0xfffffffff840001, // 60 Sine (double angle)
+ 0x1000000000860001, // 60 Sine (double angle)
+ 0xfffffffff6a0001, // 60 Sine
+ 0x1000000000980001, // 60 Sine
+ 0xfffffffff5a0001, // 60 Sine
+ 0x1000000000b00001, // 60 Sine
+ 0x1000000000ce0001, // 60 Sine
+ 0xfffffffff2a0001, // 60 Sine
+ 0xfffffffff240001, // 60 Sine
+ 0x1000000000f00001, // 60 Sine
+ },
+ ScalingFactor: 1 << 60,
+ },
+ CoeffsToSlotsModuli: CoeffsToSlotsModuli{
+ Qi: []uint64{
+ 0x200000000e0001, // 53 CtS
+ 0x20000000140001, // 53 CtS
+ 0x20000000280001, // 53 CtS
+ 0x1fffffffd80001, // 53 CtS
+ },
+ ScalingFactor: [][]float64{
+ {0x200000000e0001},
+ {0x20000000140001},
+ {0x20000000280001},
+ {0x1fffffffd80001},
+ },
},
- scale: 1 << 25,
- sigma: DefaultSigma,
- },
-}
-
-// DefaultBootstrapParams are default bootstrapping params for the bootstrapping
-var DefaultBootstrapParams = []*BootstrappingParameters{
-
- // SET II
- // 1525 - 550
- {
- H: 192,
- SinType: Cos1,
- SinRange: 21,
- SinDeg: 52,
- SinRescal: 2,
- CtSLevel: []uint64{24, 23, 22, 21},
- StCLevel: []uint64{12, 11, 11},
- MaxN1N2Ratio: 16.0,
- },
-
- // SET V
- // 1553 - 505
- {
- H: 192,
- SinType: Cos1,
- SinRange: 21,
- SinDeg: 52,
- SinRescal: 2,
- CtSLevel: []uint64{21, 20, 19, 18},
- StCLevel: []uint64{9, 9, 8},
- MaxN1N2Ratio: 16.0,
- },
-
- // Set VII
- // 1773 - 460
- {
H: 32768,
SinType: Cos2,
- SinRange: 257,
- SinDeg: 250,
- SinRescal: 3,
- CtSLevel: []uint64{26, 25, 24, 23},
- StCLevel: []uint64{11, 10, 10},
+ MessageRatio: 256.0,
+ SinRange: 325,
+ SinDeg: 255,
+ SinRescal: 4,
+ ArcSineDeg: 0,
MaxN1N2Ratio: 16.0,
+ BitReversed: false,
},
- // Set IV
- // 768 - 110
+ // Set V
+ // 768
{
+ LogN: 15,
+ LogSlots: 14,
+ Scale: 1 << 25,
+ Sigma: DefaultSigma,
+ ResidualModuli: []uint64{
+ 0x1fff90001, // 32 Q0
+ 0x4000000420001, // 50
+ 0x1fc0001, // 25
+ },
+ KeySwitchModuli: []uint64{
+ 0x7fffffffe0001, // 51
+ 0x8000000110001, // 51
+ },
+ SlotsToCoeffsModuli: SlotsToCoeffsModuli{
+ Qi: []uint64{
+ 0xffffffffffc0001, // 60 StC (30+30)
+ },
+ ScalingFactor: [][]float64{
+ {1073741823.9998779, 1073741823.9998779},
+ },
+ },
+ SineEvalModuli: SineEvalModuli{
+ Qi: []uint64{
+ 0x4000000120001, // 50 Sine
+ 0x40000001b0001, // 50 Sine
+ 0x3ffffffdf0001, // 50 Sine
+ 0x4000000270001, // 50 Sine
+ 0x3ffffffd20001, // 50 Sine
+ 0x3ffffffcd0001, // 50 Sine
+ 0x4000000350001, // 50 Sine
+ 0x3ffffffc70001, // 50 Sine
+ },
+ ScalingFactor: 1 << 50,
+ },
+ CoeffsToSlotsModuli: CoeffsToSlotsModuli{
+ Qi: []uint64{
+ 0x1fffffff50001, // 49 CtS
+ 0x1ffffffea0001, // 49 CtS
+ },
+ ScalingFactor: [][]float64{
+ {0x1fffffff50001},
+ {0x1ffffffea0001},
+ },
+ },
H: 192,
SinType: Cos1,
- SinRange: 21,
- SinDeg: 52,
+ MessageRatio: 256.0,
+ SinRange: 25,
+ SinDeg: 63,
SinRescal: 2,
- CtSLevel: []uint64{13, 12},
- StCLevel: []uint64{3, 3},
+ ArcSineDeg: 0,
MaxN1N2Ratio: 16.0,
+ BitReversed: false,
},
}
+
+func computeRoots(N int) (roots []complex128) {
+
+ var angle float64
+
+ m := N << 1
+
+ roots = make([]complex128, m)
+
+ roots[0] = 1
+
+ for i := 1; i < m; i++ {
+ angle = 6.283185307179586 * float64(i) / float64(m)
+ roots[i] = complex(math.Cos(angle), math.Sin(angle))
+ }
+
+ return
+}
+
+func fftPlainVec(logN, dslots int, roots []complex128, pow5 []int) (a, b, c [][]complex128) {
+
+ var N, m, index, tt, gap, k, mask, idx1, idx2 int
+
+ N = 1 << logN
+
+ a = make([][]complex128, logN)
+ b = make([][]complex128, logN)
+ c = make([][]complex128, logN)
+
+ var size int
+ if 2*N == dslots {
+ size = 2
+ } else {
+ size = 1
+ }
+
+ index = 0
+ for m = 2; m <= N; m <<= 1 {
+
+ a[index] = make([]complex128, dslots)
+ b[index] = make([]complex128, dslots)
+ c[index] = make([]complex128, dslots)
+
+ tt = m >> 1
+
+ for i := 0; i < N; i += m {
+
+ gap = N / m
+ mask = (m << 2) - 1
+
+ for j := 0; j < m>>1; j++ {
+
+ k = (pow5[j] & mask) * gap
+
+ idx1 = i + j
+ idx2 = i + j + tt
+
+ for u := 0; u < size; u++ {
+ a[index][idx1+u*N] = 1
+ a[index][idx2+u*N] = -roots[k]
+ b[index][idx1+u*N] = roots[k]
+ c[index][idx2+u*N] = 1
+ }
+ }
+ }
+
+ index++
+ }
+
+ return
+}
+
+func fftInvPlainVec(logN, dslots int, roots []complex128, pow5 []int) (a, b, c [][]complex128) {
+
+ var N, m, index, tt, gap, k, mask, idx1, idx2 int
+
+ N = 1 << logN
+
+ a = make([][]complex128, logN)
+ b = make([][]complex128, logN)
+ c = make([][]complex128, logN)
+
+ var size int
+ if 2*N == dslots {
+ size = 2
+ } else {
+ size = 1
+ }
+
+ index = 0
+ for m = N; m >= 2; m >>= 1 {
+
+ a[index] = make([]complex128, dslots)
+ b[index] = make([]complex128, dslots)
+ c[index] = make([]complex128, dslots)
+
+ tt = m >> 1
+
+ for i := 0; i < N; i += m {
+
+ gap = N / m
+ mask = (m << 2) - 1
+
+ for j := 0; j < m>>1; j++ {
+
+ k = ((m << 2) - (pow5[j] & mask)) * gap
+
+ idx1 = i + j
+ idx2 = i + j + tt
+
+ for u := 0; u < size; u++ {
+
+ a[index][idx1+u*N] = 1
+ a[index][idx2+u*N] = -roots[k]
+ b[index][idx1+u*N] = 1
+ c[index][idx2+u*N] = roots[k]
+ }
+ }
+ }
+
+ index++
+ }
+
+ return
+}
+
+func computeDFTMatrices(logSlots, logdSlots, maxDepth int, roots []complex128, pow5 []int, diffscale complex128, inverse, bitreversed bool) (plainVector []map[int][]complex128) {
+
+ var fftLevel, depth, nextfftLevel int
+
+ fftLevel = logSlots
+
+ var a, b, c [][]complex128
+
+ if inverse {
+ a, b, c = fftInvPlainVec(logSlots, 1< 1< 1<>1 {
+ mat[i], mat[N-i] = mat[N-i], mat[i]
+ }
+ }
+}
+
+func conjugateDiagMatrix(mat map[int][]complex128) {
+ for i := range mat {
+
+ for j := range mat[i] {
+ c := mat[i][j]
+ mat[i][j] = complex(real(c), -imag(c))
+ }
+ }
+}
+
+func genBitReverseDiagMatrix(logN int) (diagMat map[int][]complex128) {
+
+ var N, iRev, diff int
+
+ diagMat = make(map[int][]complex128)
+
+ N = 1 << logN
+
+ for i := 0; i < N; i++ {
+ iRev = int(utils.BitReverse64(uint64(i), uint64(logN)))
+
+ diff = (i - iRev) & (N - 1)
+
+ if diagMat[diff] == nil {
+ diagMat[diff] = make([]complex128, N)
+ }
+
+ diagMat[diff][iRev] = complex(1, 0)
+ }
+
+ return
+}
+
+func addToDiagMatrix(diagMat map[int][]complex128, index int, vec []complex128) {
+ if diagMat[index] == nil {
+ diagMat[index] = vec
+ } else {
+ diagMat[index] = add(diagMat[index], vec)
+ }
+}
+
+func rotate(x []complex128, n int) (y []complex128) {
+
+ y = make([]complex128, len(x))
+
+ mask := int(len(x) - 1)
+
+ // Rotates to the left
+ for i := 0; i < len(x); i++ {
+ y[i] = x[(i+n)&mask]
+ }
+
+ return
+}
+
+func mul(a, b []complex128) (res []complex128) {
+
+ res = make([]complex128, len(a))
+
+ for i := 0; i < len(a); i++ {
+ res[i] = a[i] * b[i]
+ }
+
+ return
+}
+
+func add(a, b []complex128) (res []complex128) {
+
+ res = make([]complex128, len(a))
+
+ for i := 0; i < len(a); i++ {
+ res[i] = a[i] + b[i]
+ }
+
+ return
+}
diff --git a/ckks/bootstrap_test.go b/ckks/bootstrap_test.go
index c945d2d8..21286df3 100644
--- a/ckks/bootstrap_test.go
+++ b/ckks/bootstrap_test.go
@@ -20,23 +20,25 @@ func TestBootstrap(t *testing.T) {
t.Skip("skipping bootstrapping tests for GOARCH=wasm")
}
- var err error
var testContext = new(testParams)
- paramSet := uint64(1)
+ paramSet := 0
- shemeParams := DefaultBootstrapSchemeParams[paramSet : paramSet+1]
bootstrapParams := DefaultBootstrapParams[paramSet : paramSet+1]
- for paramSet := range shemeParams {
+ for paramSet := range bootstrapParams {
- params := shemeParams[paramSet]
btpParams := bootstrapParams[paramSet]
// Insecure params for fast testing only
if !*flagLongTest {
- params.logN = 14
- params.logSlots = 13
+ btpParams.LogN = 14
+ btpParams.LogSlots = 10
+ }
+
+ params, err := btpParams.Params()
+ if err != nil {
+ panic(err)
}
if testContext, err = genTestParams(params, btpParams.H); err != nil {
@@ -44,9 +46,9 @@ func TestBootstrap(t *testing.T) {
}
for _, testSet := range []func(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T){
- testChebySin,
- testChebyCos,
- testChebyCosNaive,
+ testEvalSine,
+ testCoeffsToSlots,
+ testSlotsToCoeffs,
testbootstrap,
} {
testSet(testContext, btpParams, t)
@@ -55,20 +57,17 @@ func TestBootstrap(t *testing.T) {
}
}
-func testChebySin(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T) {
- t.Run(testString(testContext, "ChebySin/"), func(t *testing.T) {
+func testEvalSine(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T) {
+
+ t.Run(testString(testContext, "Sin/"), func(t *testing.T) {
var err error
eval := testContext.evaluator
- params := testContext.params
-
DefaultScale := testContext.params.scale
- q := params.qi[params.MaxLevel()-uint64(len(btpParams.CtSLevel))]
-
- SineScale := math.Exp2(math.Round(math.Log2(float64(q))))
+ SineScale := btpParams.SineEvalModuli.ScalingFactor
testContext.params.scale = SineScale
eval.(*evaluator).scale = SineScale
@@ -76,8 +75,8 @@ func testChebySin(testContext *testParams, btpParams *BootstrappingParameters, t
deg := 127
K := float64(15)
- values, _, ciphertext := newTestVectorsSineBootstrapp(testContext, testContext.encryptorSk, -K+1, K-1, t)
- eval.DropLevel(ciphertext, uint64(len(btpParams.CtSLevel))-1)
+ values, _, ciphertext := newTestVectorsSineBootstrapp(testContext, btpParams, testContext.encryptorSk, -K+1, K-1, t)
+ eval.DropLevel(ciphertext, btpParams.CtSDepth(true)-1)
cheby := Approximate(sin2pi2pi, -complex(K, 0), complex(K, 0), deg)
@@ -89,44 +88,38 @@ func testChebySin(testContext *testParams, btpParams *BootstrappingParameters, t
eval.AddConst(ciphertext, (-cheby.a-cheby.b)/(cheby.b-cheby.a), ciphertext)
eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext)
- if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby); err != nil {
+ if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby, ciphertext.Scale()); err != nil {
t.Error(err)
}
- verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
testContext.params.scale = DefaultScale
eval.(*evaluator).scale = DefaultScale
})
-}
-func testChebyCos(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T) {
- t.Run(testString(testContext, "ChebyCos/"), func(t *testing.T) {
+ t.Run(testString(testContext, "Cos1/"), func(t *testing.T) {
var err error
eval := testContext.evaluator
- params := testContext.params
-
DefaultScale := testContext.params.scale
- q := params.qi[params.MaxLevel()-uint64(len(btpParams.CtSLevel))]
-
- SineScale := math.Exp2(math.Round(math.Log2(float64(q))))
+ SineScale := btpParams.SineEvalModuli.ScalingFactor
testContext.params.scale = SineScale
eval.(*evaluator).scale = SineScale
- K := 21
- deg := 52
- dev := float64(testContext.params.qi[0]) / DefaultScale
+ K := 25
+ deg := 63
+ dev := btpParams.MessageRatio
scNum := 2
scFac := complex(float64(int(1< 0 {
+ sqrt2pi = math.Pow(1, 1.0/real(scFac))
+ } else {
+ sqrt2pi = math.Pow(0.15915494309189535, 1.0/real(scFac))
+ }
for i := range cheby.coeffs {
cheby.coeffs[i] *= complex(sqrt2pi, 0)
}
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
+
for i := range values {
values[i] = cmplx.Cos(6.283185307179586 * (1 / scFac) * (values[i] - 0.25))
@@ -149,7 +149,9 @@ func testChebyCos(testContext *testParams, btpParams *BootstrappingParameters, t
values[i] = 2*values[i]*values[i] - 1
}
- values[i] /= 6.283185307179586
+ if btpParams.ArcSineDeg == 0 {
+ values[i] /= 6.283185307179586
+ }
}
eval.AddConst(ciphertext, -0.25, ciphertext)
@@ -158,7 +160,7 @@ func testChebyCos(testContext *testParams, btpParams *BootstrappingParameters, t
eval.AddConst(ciphertext, (-cheby.a-cheby.b)/(cheby.b-cheby.a), ciphertext)
eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext)
- if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby); err != nil {
+ if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby, ciphertext.Scale()); err != nil {
t.Error(err)
}
@@ -170,40 +172,38 @@ func testChebyCos(testContext *testParams, btpParams *BootstrappingParameters, t
eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext)
}
- verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
testContext.params.scale = DefaultScale
eval.(*evaluator).scale = DefaultScale
})
-}
-func testChebyCosNaive(testContext *testParams, btpParams *BootstrappingParameters, t *testing.T) {
- t.Run(testString(testContext, "ChebyCosNaive/"), func(t *testing.T) {
+ t.Run(testString(testContext, "Cos2/"), func(t *testing.T) {
+
+ if len(btpParams.SineEvalModuli.Qi) < 12 {
+ t.Skip()
+ }
var err error
eval := testContext.evaluator
- params := testContext.params
-
DefaultScale := testContext.params.scale
- q := params.qi[params.MaxLevel()-uint64(len(btpParams.CtSLevel))]
-
- SineScale := math.Exp2(math.Round(math.Log2(float64(q))))
+ SineScale := btpParams.SineEvalModuli.ScalingFactor
testContext.params.scale = SineScale
eval.(*evaluator).scale = SineScale
- K := 257
- deg := 250
- scNum := 3
+ K := 325
+ deg := 255
+ scNum := 4
scFac := complex(float64(int(1< must use SinType = Cos")
- }
-
- if btpParams.CtSLevel[0] != params.MaxLevel() {
- return nil, fmt.Errorf("BootstrapParams: CtSLevel start not consistent with MaxLevel")
+ return nil, fmt.Errorf("cannot use double angle formul for SinType = Sin -> must use SinType = Cos")
}
btp = newBootstrapper(params, btpParams)
@@ -78,6 +61,8 @@ func NewBootstrapper(params *Parameters, btpParams *BootstrappingParameters, btp
}
btp.evaluator = btp.evaluator.WithKey(EvaluationKey{btpKey.Rlk, btpKey.Rtks}).(*evaluator)
+ btp.evaluator = btp.evaluator.WithKey(EvaluationKey{btpKey.Rlk, btpKey.Rtks}).(*evaluator)
+
return btp, nil
}
@@ -92,14 +77,13 @@ func newBootstrapper(params *Parameters, btpParams *BootstrappingParameters) (bt
btp.dslots = params.Slots()
btp.logdslots = params.LogSlots()
if params.logSlots < params.MaxLogSlots() {
- btp.repack = true
btp.dslots <<= 1
btp.logdslots++
}
- btp.deviation = 1024.0
- btp.prescale = math.Exp2(math.Round(math.Log2(float64(params.qi[0]) / btp.deviation)))
- btp.postscale = math.Exp2(math.Round(math.Log2(float64(params.qi[len(params.qi)-1-len(btpParams.CtSLevel)])))) / btp.deviation
+ btp.prescale = math.Exp2(math.Round(math.Log2(float64(params.qi[0]) / btp.MessageRatio)))
+ btp.sinescale = math.Exp2(math.Round(math.Log2(btp.SineEvalModuli.ScalingFactor)))
+ btp.postscale = btp.sinescale / btp.MessageRatio
btp.encoder = NewEncoder(params)
btp.evaluator = NewEvaluator(params, EvaluationKey{}).(*evaluator) // creates an evaluator without keys for genDFTMatrices
@@ -109,13 +93,6 @@ func newBootstrapper(params *Parameters, btpParams *BootstrappingParameters) (bt
btp.ctxpool = NewCiphertext(params, 1, params.MaxLevel(), 0)
- for i := range btp.poolQ {
- btp.poolQ[i] = params.NewPolyQ()
- }
-
- for i := range btp.poolP {
- btp.poolP[i] = params.NewPolyP()
- }
return btp
}
@@ -130,7 +107,7 @@ func (btp *Bootstrapper) CheckKeys() (err error) {
return fmt.Errorf("rotation key is nil")
}
- rotMissing := []uint64{}
+ rotMissing := []int{}
for _, i := range btp.rotKeyIndex {
galEl := btp.params.GaloisElementForColumnRotationBy(int(i))
if _, generated := btp.Rtks.Keys[galEl]; !generated {
@@ -145,583 +122,135 @@ func (btp *Bootstrapper) CheckKeys() (err error) {
return nil
}
-func (btp *Bootstrapper) addMatrixRotToList(pVec *dftvectors, rotations []uint64, slots uint64, repack bool) {
+// AddMatrixRotToList adds the rotations neede to evaluate pVec to the list rotations
+func AddMatrixRotToList(pVec *PtDiagMatrix, rotations []int, slots int, repack bool) []int {
- var index uint64
- for j := range pVec.Vec {
-
- N1 := pVec.N1
-
- index = ((j / N1) * N1)
-
- if repack {
- // Sparse repacking, occurring during the first DFT matrix of the CoeffsToSlots.
- index &= (2*slots - 1)
- } else {
- // Other cases
- index &= (slots - 1)
+ if pVec.naive {
+ for j := range pVec.Vec {
+ if !utils.IsInSliceInt(j, rotations) {
+ rotations = append(rotations, j)
+ }
}
+ } else {
+ var index int
+ for j := range pVec.Vec {
- if index != 0 && !utils.IsInSliceUint64(index, rotations) {
- rotations = append(rotations, index)
- }
+ N1 := pVec.N1
- index = j & (N1 - 1)
+ index = ((j / N1) * N1)
- if index != 0 && !utils.IsInSliceUint64(index, rotations) {
- rotations = append(rotations, index)
+ if repack {
+ // Sparse repacking, occurring during the first DFT matrix of the CoeffsToSlots.
+ index &= 2*slots - 1
+ } else {
+ // Other cases
+ index &= slots - 1
+ }
+
+ if index != 0 && !utils.IsInSliceInt(index, rotations) {
+ rotations = append(rotations, index)
+ }
+
+ index = j & (N1 - 1)
+
+ if index != 0 && !utils.IsInSliceInt(index, rotations) {
+ rotations = append(rotations, index)
+ }
}
}
+
+ return rotations
}
func (btp *Bootstrapper) genDFTMatrices() {
- a := real(btp.chebycoeffs.a)
- b := real(btp.chebycoeffs.b)
+ a := real(btp.sineEvalPoly.a)
+ b := real(btp.sineEvalPoly.b)
n := float64(btp.params.N())
- scFac := float64(int(1 << btp.SinRescal))
qDiff := float64(btp.params.qi[0]) / math.Exp2(math.Round(math.Log2(float64(btp.params.qi[0]))))
- // Change of variable for the evaluation of the Chebyshev polynomial + cancelling factor for the DFT and SubSum + eventual scaling factor for the double angle formula
- btp.coeffsToSlotsDiffScale = complex(math.Pow(2.0/((b-a)*n*scFac*qDiff), 1.0/float64(len(btp.CtSLevel))), 0)
+ // Change of variable for the evaluation of the Chebyshev polynomial + cancelling factor for the DFT and SubSum + evantual scaling factor for the double angle formula
+ btp.coeffsToSlotsDiffScale = complex(math.Pow(2.0/((b-a)*n*btp.scFac*qDiff), 1.0/float64(btp.CtSDepth(false))), 0)
// Rescaling factor to set the final ciphertext to the desired scale
- btp.slotsToCoeffsDiffScale = complex(math.Pow((qDiff*btp.params.scale)/btp.postscale, 1.0/float64(len(btp.StCLevel))), 0)
+ btp.slotsToCoeffsDiffScale = complex(math.Pow((qDiff*btp.params.scale)/btp.postscale, 1.0/float64(btp.StCDepth(false))), 0)
- // Computation and encoding of the matrices for CoeffsToSlots and SlotsToCoeffs.
- btp.computePlaintextVectors()
+ // CoeffsToSlots vectors
+ btp.pDFTInv = btp.BootstrappingParameters.GenCoeffsToSlotsMatrix(btp.coeffsToSlotsDiffScale, btp.encoder)
+
+ // SlotsToCoeffs vectors
+ btp.pDFT = btp.BootstrappingParameters.GenSlotsToCoeffsMatrix(btp.slotsToCoeffsDiffScale, btp.encoder)
// List of the rotation key values to needed for the bootstrapp
- btp.rotKeyIndex = []uint64{}
+ btp.rotKeyIndex = []int{}
//SubSum rotation needed X -> Y^slots rotations
for i := btp.params.logSlots; i < btp.params.MaxLogSlots(); i++ {
- if !utils.IsInSliceUint64(1< 0 {
+ btp.sqrt2pi = 1.0
+
+ coeffs := make([]complex128, btp.ArcSineDeg+1)
+
+ coeffs[1] = 0.15915494309189535
+
+ for i := 3; i < btp.ArcSineDeg+1; i += 2 {
+
+ coeffs[i] = coeffs[i-2] * complex(float64(i*i-4*i+4)/float64(i*i-i), 0)
+
+ }
+
+ btp.arcSinePoly = NewPoly(coeffs)
+
+ } else {
+ btp.sqrt2pi = math.Pow(0.15915494309189535, 1.0/btp.scFac)
+ }
+
if btp.SinType == Sin {
- K := complex(float64(btp.SinRange), 0)
- btp.chebycoeffs = Approximate(sin2pi2pi, -K, K, int(btp.SinDeg))
+ btp.sineEvalPoly = Approximate(sin2pi2pi, -complex(float64(K)/btp.scFac, 0), complex(float64(K)/btp.scFac, 0), deg)
} else if btp.SinType == Cos1 {
- K := int(btp.SinRange)
- deg := int(btp.SinDeg)
- scFac := complex(float64(int(1< invalid sineType")
}
-}
-func computeRoots(N uint64) (roots []complex128) {
-
- var angle float64
-
- m := N << 1
-
- roots = make([]complex128, m)
-
- roots[0] = 1
-
- for i := uint64(1); i < m; i++ {
- angle = 6.283185307179586 * float64(i) / float64(m)
- roots[i] = complex(math.Cos(angle), math.Sin(angle))
- }
-
- return
-}
-
-func fftPlainVec(logN uint64, roots []complex128, pow5 []uint64) (a, b, c [][]complex128) {
-
- var N, m, index, tt, gap, k, mask, idx1, idx2 uint64
-
- N = 1 << logN
-
- a = make([][]complex128, logN)
- b = make([][]complex128, logN)
- c = make([][]complex128, logN)
-
- index = 0
- for m = 2; m <= N; m <<= 1 {
-
- a[index] = make([]complex128, 2*N)
- b[index] = make([]complex128, 2*N)
- c[index] = make([]complex128, 2*N)
-
- tt = m >> 1
-
- for i := uint64(0); i < N; i += m {
-
- gap = N / m
- mask = (m << 2) - 1
-
- for j := uint64(0); j < m>>1; j++ {
-
- k = (pow5[j] & mask) * gap
-
- idx1 = i + j
- idx2 = i + j + tt
-
- for u := uint64(0); u < 2; u++ {
- a[index][idx1+u*N] = 1
- a[index][idx2+u*N] = -roots[k]
- b[index][idx1+u*N] = roots[k]
- c[index][idx2+u*N] = 1
- }
- }
- }
-
- index++
- }
-
- return
-}
-
-func fftInvPlainVec(logN uint64, roots []complex128, pow5 []uint64) (a, b, c [][]complex128) {
-
- var N, m, index, tt, gap, k, mask, idx1, idx2 uint64
-
- N = 1 << logN
-
- a = make([][]complex128, logN)
- b = make([][]complex128, logN)
- c = make([][]complex128, logN)
-
- index = 0
- for m = N; m >= 2; m >>= 1 {
-
- a[index] = make([]complex128, 2*N)
- b[index] = make([]complex128, 2*N)
- c[index] = make([]complex128, 2*N)
-
- tt = m >> 1
-
- for i := uint64(0); i < N; i += m {
-
- gap = N / m
- mask = (m << 2) - 1
-
- for j := uint64(0); j < m>>1; j++ {
-
- k = ((m << 2) - (pow5[j] & mask)) * gap
-
- idx1 = i + j
- idx2 = i + j + tt
-
- for u := uint64(0); u < 2; u++ {
-
- a[index][idx1+u*N] = 1
- a[index][idx2+u*N] = -roots[k]
- b[index][idx1+u*N] = 1
- c[index][idx2+u*N] = roots[k]
- }
- }
- }
-
- index++
- }
-
- return
-}
-
-func (btp *Bootstrapper) computePlaintextVectors() {
-
- slots := btp.params.Slots()
- dslots := btp.dslots
-
- CtSLevel := btp.CtSLevel
- StCLevel := btp.StCLevel
-
- roots := computeRoots(slots << 1)
- pow5 := make([]uint64, (slots<<1)+1)
- pow5[0] = 1
- for i := uint64(1); i < (slots<<1)+1; i++ {
- pow5[i] = pow5[i-1] * 5
- pow5[i] &= (slots << 2) - 1
- }
-
- // CoeffsToSlots vectors
- btp.pDFTInv = make([]*dftvectors, len(CtSLevel))
- pVecDFTInv := btp.computeDFTPlaintextVectors(roots, pow5, btp.coeffsToSlotsDiffScale, true)
- for i, lvl := range CtSLevel {
- btp.pDFTInv[i] = new(dftvectors)
- btp.pDFTInv[i].N1 = findbestbabygiantstepsplit(pVecDFTInv[i], dslots, btp.MaxN1N2Ratio)
- btp.encodePVec(pVecDFTInv[i], btp.pDFTInv[i], lvl, true)
- }
-
- // SlotsToCoeffs vectors
- btp.pDFT = make([]*dftvectors, len(StCLevel))
- pVecDFT := btp.computeDFTPlaintextVectors(roots, pow5, btp.slotsToCoeffsDiffScale, false)
- for i, lvl := range StCLevel {
- btp.pDFT[i] = new(dftvectors)
- btp.pDFT[i].N1 = findbestbabygiantstepsplit(pVecDFT[i], dslots, btp.MaxN1N2Ratio)
- btp.encodePVec(pVecDFT[i], btp.pDFT[i], lvl, false)
+ for i := range btp.sineEvalPoly.coeffs {
+ btp.sineEvalPoly.coeffs[i] *= complex(btp.sqrt2pi, 0)
}
}
-
-// Finds the best N1*N2 = N for the baby-step giant-step algorithm for matrix multiplication.
-func findbestbabygiantstepsplit(vector map[uint64][]complex128, maxN uint64, maxRatio float64) (minN uint64) {
-
- for N1 := uint64(1); N1 < maxN; N1 <<= 1 {
-
- index := make(map[uint64][]uint64)
-
- for key := range vector {
-
- idx1 := key / N1
- idx2 := key & (N1 - 1)
-
- if index[idx1] == nil {
- index[idx1] = []uint64{idx2}
- } else {
- index[idx1] = append(index[idx1], idx2)
- }
- }
-
- if len(index[0]) > 0 {
-
- hoisted := len(index[0]) - 1
- normal := len(index) - 1
-
- // The matrice is very sparse already
- if normal == 0 {
- return N1 / 2
- }
-
- if hoisted > normal {
- // Finds the next split that has a ratio hoisted/normal greater or equal to maxRatio
- for float64(hoisted)/float64(normal) < maxRatio {
-
- if normal/2 == 0 {
- break
- }
- N1 *= 2
- hoisted = hoisted*2 + 1
- normal = normal / 2
- }
- return N1
- }
- }
- }
-
- return 1
-}
-
-func (btp *Bootstrapper) encodePVec(pVec map[uint64][]complex128, plaintextVec *dftvectors, level uint64, forward bool) {
- var N, N1 uint64
- var scale float64
-
- // N1*N2 = N
- N = btp.params.N()
- N1 = plaintextVec.N1
-
- index := make(map[uint64][]uint64)
-
- for key := range pVec {
- idx1 := key / N1
- idx2 := key & (N1 - 1)
- if index[idx1] == nil {
- index[idx1] = []uint64{idx2}
- } else {
- index[idx1] = append(index[idx1], idx2)
- }
- }
-
- plaintextVec.Vec = make(map[uint64][2]*ring.Poly)
-
- if forward {
- scale = float64(btp.params.qi[level])
- } else {
- // If the first moduli
- logQi := math.Round(math.Log2(float64(btp.params.qi[level])))
- if logQi >= 56.0 {
- scale = math.Exp2(logQi / 2)
- } else {
- scale = float64(btp.params.qi[level])
- }
- }
-
- plaintextVec.Level = level
- plaintextVec.Scale = scale
- ringQ := btp.evaluator.ringQ
- ringP := btp.evaluator.ringP
- encoder := btp.encoder.(*encoderComplex128)
-
- for j := range index {
-
- for _, i := range index[j] {
-
- // levels * n coefficients of 8 bytes each
- btp.plaintextSize += 8 * N * (level + 1 + btp.params.PiCount())
-
- encoder.Embed(rotate(pVec[N1*j+uint64(i)], (N>>1)-(N1*j))[:btp.dslots], btp.logdslots)
-
- plaintextQ := ring.NewPoly(N, level+1)
- encoder.ScaleUp(plaintextQ, scale, ringQ.Modulus[:level+1])
- ringQ.NTTLvl(level, plaintextQ, plaintextQ)
- ringQ.MFormLvl(level, plaintextQ, plaintextQ)
-
- plaintextP := ring.NewPoly(N, level+1)
- encoder.ScaleUp(plaintextP, scale, ringP.Modulus)
- ringP.NTT(plaintextP, plaintextP)
- ringP.MForm(plaintextP, plaintextP)
-
- plaintextVec.Vec[N1*j+uint64(i)] = [2]*ring.Poly{plaintextQ, plaintextP}
-
- encoder.WipeInternalMemory()
-
- }
- }
-}
-
-func (btp *Bootstrapper) computeDFTPlaintextVectors(roots []complex128, pow5 []uint64, diffscale complex128, forward bool) (plainVector []map[uint64][]complex128) {
-
- var level, depth, nextLevel, logSlots uint64
-
- logSlots = btp.params.logSlots
- level = logSlots
-
- var a, b, c [][]complex128
- var maxDepth uint64
-
- if forward {
- maxDepth = uint64(len(btp.CtSLevel))
- a, b, c = fftInvPlainVec(btp.params.logSlots, roots, pow5)
- } else {
- maxDepth = uint64(len(btp.StCLevel))
- a, b, c = fftPlainVec(btp.params.logSlots, roots, pow5)
- }
-
- plainVector = make([]map[uint64][]complex128, maxDepth)
-
- // We compute the chain of merge in order or reverse order depending if its DFT or InvDFT because
- // the way the levels are collapsed has an inpact on the total number of rotations and keys to be
- // stored. Ex. instead of using 255 + 64 plaintext vectors, we can use 127 + 128 plaintext vectors
- // by reversing the order of the merging.
- merge := make([]uint64, maxDepth)
- for i := uint64(0); i < maxDepth; i++ {
-
- depth = uint64(math.Ceil(float64(level) / float64(maxDepth-i)))
-
- if forward {
- merge[i] = depth
- } else {
- merge[uint64(len(merge))-i-1] = depth
-
- }
-
- level -= depth
- }
-
- level = logSlots
- for i := uint64(0); i < maxDepth; i++ {
-
- if btp.repack && !forward && i == 0 {
-
- // Special initial matrix for the repacking before SlotsToCoeffs
- plainVector[i] = genWfftRepack(logSlots, level)
-
- // Merges this special initial matrix with the first layer of SlotsToCoeffs DFT
- plainVector[i] = nextLevelfft(plainVector[i], logSlots, 2<ct0/"), func(t *testing.T) {
@@ -623,7 +631,7 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
testContext.evaluator.MulRelin(ciphertext1, plaintext1, ciphertext1)
- verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, t)
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Mul/ct0*pt->ct1/"), func(t *testing.T) {
@@ -636,7 +644,7 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
ciphertext2 := testContext.evaluator.MulRelinNew(ciphertext1, plaintext1)
- verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext2, t)
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext2, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Mul/ct0*ct1->ct0/"), func(t *testing.T) {
@@ -650,7 +658,7 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
testContext.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext1)
- verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext1, t)
+ verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext1, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Mul/ct0*ct1->ct1/"), func(t *testing.T) {
@@ -664,7 +672,7 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
testContext.evaluator.MulRelin(ciphertext1, ciphertext2, ciphertext2)
- verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext2, t)
+ verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext2, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Mul/ct0*ct1->ct2/"), func(t *testing.T) {
@@ -678,7 +686,7 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
ciphertext3 := testContext.evaluator.MulRelinNew(ciphertext1, ciphertext2)
- verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext3, t)
+ verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext3, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Mul/ct0*ct0->ct0/"), func(t *testing.T) {
@@ -691,7 +699,7 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
testContext.evaluator.MulRelin(ciphertext1, ciphertext1, ciphertext1)
- verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, t)
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Mul/ct0*ct0->ct1/"), func(t *testing.T) {
@@ -704,7 +712,7 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
ciphertext2 := testContext.evaluator.MulRelinNew(ciphertext1, ciphertext1)
- verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext2, t)
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext2, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Mul/Relinearize(ct0*ct1->ct0)/"), func(t *testing.T) {
@@ -721,12 +729,11 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
}
testContext.evaluator.Mul(ciphertext1, ciphertext2, ciphertext1)
- require.Equal(t, ciphertext1.Degree(), uint64(2))
-
+ require.Equal(t, ciphertext1.Degree(), 2)
testContext.evaluator.Relinearize(ciphertext1, ciphertext1)
- require.Equal(t, ciphertext1.Degree(), uint64(1))
+ require.Equal(t, ciphertext1.Degree(), 1)
- verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, t)
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Mul/Relinearize(ct0*ct1->ct1)/"), func(t *testing.T) {
@@ -743,11 +750,11 @@ func testEvaluatorMul(testContext *testParams, t *testing.T) {
}
testContext.evaluator.Mul(ciphertext1, ciphertext2, ciphertext2)
- require.Equal(t, ciphertext2.Degree(), uint64(2))
+ require.Equal(t, ciphertext2.Degree(), 2)
testContext.evaluator.Relinearize(ciphertext2, ciphertext2)
- require.Equal(t, ciphertext2.Degree(), uint64(1))
+ require.Equal(t, ciphertext2.Degree(), 1)
- verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext2, t)
+ verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext2, testContext.params.LogSlots(), 0, t)
})
}
@@ -766,14 +773,14 @@ func testFunctions(testContext *testParams, t *testing.T) {
values, _, ciphertext := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
- n := uint64(2)
+ n := 2
valuesWant := make([]complex128, len(values))
for i := 0; i < len(valuesWant); i++ {
valuesWant[i] = values[i]
}
- for i := uint64(0); i < n; i++ {
+ for i := 0; i < n; i++ {
for j := 0; j < len(valuesWant); j++ {
valuesWant[j] *= valuesWant[j]
}
@@ -781,7 +788,7 @@ func testFunctions(testContext *testParams, t *testing.T) {
testContext.evaluator.PowerOf2(ciphertext, n, ciphertext)
- verifyTestVectors(testContext, testContext.decryptor, valuesWant, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, valuesWant, ciphertext, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Power/"), func(t *testing.T) {
@@ -796,7 +803,7 @@ func testFunctions(testContext *testParams, t *testing.T) {
values, _, ciphertext := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
- n := uint64(3)
+ n := 3
for i := range values {
values[i] = cmplx.Pow(values[i], complex(float64(n), 0))
@@ -804,7 +811,7 @@ func testFunctions(testContext *testParams, t *testing.T) {
testContext.evaluator.Power(ciphertext, n, ciphertext)
- verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
})
t.Run(testString(testContext, "Evaluator/Inverse/"), func(t *testing.T) {
@@ -819,7 +826,7 @@ func testFunctions(testContext *testParams, t *testing.T) {
values, _, ciphertext := newTestVectors(testContext, testContext.encryptorSk, complex(0.1, 0), complex(1, 0), t)
- n := uint64(7)
+ n := 7
for i := range values {
values[i] = 1.0 / values[i]
@@ -827,7 +834,7 @@ func testFunctions(testContext *testParams, t *testing.T) {
ciphertext = testContext.evaluator.InverseNew(ciphertext, n)
- verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
})
}
@@ -864,11 +871,11 @@ func testEvaluatePoly(testContext *testParams, t *testing.T) {
values[i] = cmplx.Exp(values[i])
}
- if ciphertext, err = testContext.evaluator.EvaluatePoly(ciphertext, poly); err != nil {
+ if ciphertext, err = testContext.evaluator.EvaluatePoly(ciphertext, poly, ciphertext.Scale()); err != nil {
t.Error(err)
}
- verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
})
}
@@ -900,11 +907,57 @@ func testChebyshevInterpolator(testContext *testParams, t *testing.T) {
eval.AddConst(ciphertext, (-cheby.a-cheby.b)/(cheby.b-cheby.a), ciphertext)
eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext)
- if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby); err != nil {
+ if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby, ciphertext.Scale()); err != nil {
t.Error(err)
}
- verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
+ })
+}
+
+func testDecryptPublic(testContext *testParams, t *testing.T) {
+
+ var err error
+
+ t.Run(testString(testContext, "DecryptPublic/Sin/"), func(t *testing.T) {
+
+ if testContext.params.PiCount() == 0 {
+ t.Skip("#Pi is empty")
+ }
+
+ if testContext.params.MaxLevel() < 5 {
+ t.Skip("skipping test for params max level < 5")
+ }
+
+ eval := testContext.evaluator
+
+ values, _, ciphertext := newTestVectors(testContext, testContext.encryptorSk, complex(-1, 0), complex(1, 0), t)
+
+ cheby := Approximate(cmplx.Sin, complex(-1.5, 0), complex(1.5, 0), 15)
+
+ for i := range values {
+ values[i] = cmplx.Sin(values[i])
+ }
+
+ eval.MultByConst(ciphertext, 2/(cheby.b-cheby.a), ciphertext)
+ eval.AddConst(ciphertext, (-cheby.a-cheby.b)/(cheby.b-cheby.a), ciphertext)
+ eval.Rescale(ciphertext, eval.(*evaluator).scale, ciphertext)
+
+ if ciphertext, err = eval.EvaluateCheby(ciphertext, cheby, ciphertext.Scale()); err != nil {
+ t.Error(err)
+ }
+
+ plaintext := testContext.decryptor.DecryptNew(ciphertext)
+
+ valuesHave := testContext.encoder.Decode(plaintext, testContext.params.LogSlots())
+
+ verifyTestVectors(testContext, nil, values, valuesHave, testContext.params.LogSlots(), 0, t)
+
+ sigma := testContext.encoder.GetErrSTDCoeffDomain(values, valuesHave, plaintext.Scale())
+
+ valuesHave = testContext.encoder.DecodePublic(plaintext, testContext.params.LogSlots(), sigma)
+
+ verifyTestVectors(testContext, nil, values, valuesHave, testContext.params.LogSlots(), 0, t)
})
}
@@ -920,7 +973,7 @@ func testSwitchKeys(testContext *testParams, t *testing.T) {
switchingKey = testContext.kgen.GenSwitchingKey(testContext.sk, sk2)
}
- t.Run(testString(testContext, "SwitchKeys/InPlace/"), func(t *testing.T) {
+ t.Run(testString(testContext, "SwitchKeys/"), func(t *testing.T) {
if testContext.params.PiCount() == 0 {
t.Skip("#Pi is empty")
@@ -930,10 +983,10 @@ func testSwitchKeys(testContext *testParams, t *testing.T) {
testContext.evaluator.SwitchKeys(ciphertext, switchingKey, ciphertext)
- verifyTestVectors(testContext, decryptorSk2, values, ciphertext, t)
+ verifyTestVectors(testContext, decryptorSk2, values, ciphertext, testContext.params.LogSlots(), 0, t)
})
- t.Run(testString(testContext, "SwitchKeys/New/"), func(t *testing.T) {
+ t.Run(testString(testContext, "SwitchKeysNew/"), func(t *testing.T) {
if testContext.params.PiCount() == 0 {
t.Skip("#Pi is empty")
@@ -943,7 +996,7 @@ func testSwitchKeys(testContext *testParams, t *testing.T) {
ciphertext = testContext.evaluator.SwitchKeysNew(ciphertext, switchingKey)
- verifyTestVectors(testContext, decryptorSk2, values, ciphertext, t)
+ verifyTestVectors(testContext, decryptorSk2, values, ciphertext, testContext.params.LogSlots(), 0, t)
})
}
@@ -953,40 +1006,15 @@ func testAutomorphisms(testContext *testParams, t *testing.T) {
if testContext.params.PiCount() == 0 {
t.Skip("#Pi is empty")
}
- rots := []int{1, -1, 4, -4, 63, -63}
+ rots := []int{0, 1, -1, 4, -4, 63, -63}
rotKey := testContext.kgen.GenRotationKeysForRotations(rots, true, testContext.sk)
evaluator := testContext.evaluator.WithKey(EvaluationKey{testContext.rlk, rotKey})
- t.Run(testString(testContext, "RotateColumns/InPlace/"), func(t *testing.T) {
+ t.Run(testString(testContext, "Conjugate/"), func(t *testing.T) {
- values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
-
- ciphertext2 := NewCiphertext(testContext.params, ciphertext1.Degree(), ciphertext1.Level(), ciphertext1.Scale())
-
- for _, n := range rots {
-
- values2 := utils.RotateComplex128Slice(values1, n)
-
- evaluator.Rotate(ciphertext1, n, ciphertext2)
-
- verifyTestVectors(testContext, testContext.decryptor, values2, ciphertext2, t)
+ if testContext.params.PiCount() == 0 {
+ t.Skip("#Pi is empty")
}
- })
-
- t.Run(testString(testContext, "RotateColumns/New/"), func(t *testing.T) {
-
- values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
-
- for _, n := range rots {
-
- values2 := utils.RotateComplex128Slice(values1, n)
-
- verifyTestVectors(testContext, testContext.decryptor, values2, evaluator.RotateNew(ciphertext1, n), t)
- }
-
- })
-
- t.Run(testString(testContext, "Conjugate/InPlace/"), func(t *testing.T) {
values, _, ciphertext := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
@@ -996,10 +1024,14 @@ func testAutomorphisms(testContext *testParams, t *testing.T) {
evaluator.Conjugate(ciphertext, ciphertext)
- verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
})
- t.Run(testString(testContext, "Conjugate/New/"), func(t *testing.T) {
+ t.Run(testString(testContext, "ConjugateNew/"), func(t *testing.T) {
+
+ if testContext.params.PiCount() == 0 {
+ t.Skip("#Pi is empty")
+ }
values, _, ciphertext := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
@@ -1009,26 +1041,259 @@ func testAutomorphisms(testContext *testParams, t *testing.T) {
ciphertext = evaluator.ConjugateNew(ciphertext)
- verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, t)
+ verifyTestVectors(testContext, testContext.decryptor, values, ciphertext, testContext.params.LogSlots(), 0, t)
+ })
+
+ t.Run(testString(testContext, "Rotate/"), func(t *testing.T) {
+
+ values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
+
+ ciphertext2 := NewCiphertext(testContext.params, ciphertext1.Degree(), ciphertext1.Level(), ciphertext1.Scale())
+
+ for _, n := range rots {
+ evaluator.Rotate(ciphertext1, n, ciphertext2)
+ verifyTestVectors(testContext, testContext.decryptor, utils.RotateComplex128Slice(values1, n), ciphertext2, testContext.params.LogSlots(), 0, t)
+ }
+ })
+
+ t.Run(testString(testContext, "RotateNew/"), func(t *testing.T) {
+
+ values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
+
+ for _, n := range rots {
+ verifyTestVectors(testContext, testContext.decryptor, utils.RotateComplex128Slice(values1, n), evaluator.RotateNew(ciphertext1, n), testContext.params.LogSlots(), 0, t)
+ }
+
})
t.Run(testString(testContext, "RotateHoisted/"), func(t *testing.T) {
values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
- values2 := make([]complex128, len(values1))
-
ciphertexts := evaluator.RotateHoisted(ciphertext1, rots)
for _, n := range rots {
-
- values2 = utils.RotateComplex128Slice(values1, n)
-
- verifyTestVectors(testContext, testContext.decryptor, values2, ciphertexts[n], t)
+ verifyTestVectors(testContext, testContext.decryptor, utils.RotateComplex128Slice(values1, n), ciphertexts[n], testContext.params.LogSlots(), 0, t)
}
})
}
+func testInnerSum(testContext *testParams, t *testing.T) {
+
+ if testContext.params.PiCount() == 0 {
+ t.Skip("#Pi is empty")
+ }
+
+ t.Run(testString(testContext, "InnerSum/"), func(t *testing.T) {
+ batch := 2
+ n := 35
+
+ rotKey := testContext.kgen.GenRotationKeysForRotations(testContext.kgen.GenRotationIndexesForInnerSum(batch, n), false, testContext.sk)
+ eval := testContext.evaluator.WithKey(EvaluationKey{testContext.rlk, rotKey})
+
+ values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
+
+ eval.InnerSum(ciphertext1, batch, n, ciphertext1)
+
+ tmp0 := make([]complex128, len(values1))
+ copy(tmp0, values1)
+
+ for i := 1; i < n; i++ {
+
+ tmp1 := utils.RotateComplex128Slice(tmp0, i*batch)
+
+ for j := range values1 {
+ values1[j] += tmp1[j]
+ }
+ }
+
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
+ })
+
+ t.Run(testString(testContext, "InnerSumLog/"), func(t *testing.T) {
+
+ batch := 3
+ n := 15
+
+ rotKey := testContext.kgen.GenRotationKeysForRotations(testContext.kgen.GenRotationIndexesForInnerSumLog(batch, n), false, testContext.sk)
+ eval := testContext.evaluator.WithKey(EvaluationKey{testContext.rlk, rotKey})
+
+ values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
+
+ eval.InnerSumLog(ciphertext1, batch, n, ciphertext1)
+
+ tmp0 := make([]complex128, len(values1))
+ copy(tmp0, values1)
+
+ for i := 1; i < n; i++ {
+
+ tmp1 := utils.RotateComplex128Slice(tmp0, i*batch)
+
+ for j := range values1 {
+ values1[j] += tmp1[j]
+ }
+ }
+
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
+
+ })
+}
+
+func testReplicate(testContext *testParams, t *testing.T) {
+
+ if testContext.params.PiCount() == 0 {
+ t.Skip("#Pi is empty")
+ }
+
+ t.Run(testString(testContext, "Replicate/"), func(t *testing.T) {
+ batch := 2
+ n := 35
+
+ rotKey := testContext.kgen.GenRotationKeysForRotations(testContext.kgen.GenRotationIndexesForReplicate(batch, n), false, testContext.sk)
+ eval := testContext.evaluator.WithKey(EvaluationKey{testContext.rlk, rotKey})
+
+ values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
+
+ eval.Replicate(ciphertext1, batch, n, ciphertext1)
+
+ tmp0 := make([]complex128, len(values1))
+ copy(tmp0, values1)
+
+ for i := 1; i < n; i++ {
+
+ tmp1 := utils.RotateComplex128Slice(tmp0, i*-batch)
+
+ for j := range values1 {
+ values1[j] += tmp1[j]
+ }
+ }
+
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
+ })
+
+ t.Run(testString(testContext, "ReplicateLog/"), func(t *testing.T) {
+
+ batch := 3
+ n := 15
+
+ rotKey := testContext.kgen.GenRotationKeysForRotations(testContext.kgen.GenRotationIndexesForReplicateLog(batch, n), false, testContext.sk)
+ eval := testContext.evaluator.WithKey(EvaluationKey{testContext.rlk, rotKey})
+
+ values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
+
+ eval.ReplicateLog(ciphertext1, batch, n, ciphertext1)
+
+ tmp0 := make([]complex128, len(values1))
+ copy(tmp0, values1)
+
+ for i := 1; i < n; i++ {
+
+ tmp1 := utils.RotateComplex128Slice(tmp0, i*-batch)
+
+ for j := range values1 {
+ values1[j] += tmp1[j]
+ }
+ }
+
+ verifyTestVectors(testContext, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
+
+ })
+}
+
+func testLinearTransform(testContext *testParams, t *testing.T) {
+
+ if testContext.params.PiCount() == 0 {
+ t.Skip("#Pi is empty")
+ }
+
+ t.Run(testString(testContext, "LinearTransform/BSGS/"), func(t *testing.T) {
+
+ params := testContext.params
+
+ values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
+
+ diagMatrix := make(map[int][]complex128)
+
+ diagMatrix[-15] = make([]complex128, params.Slots())
+ diagMatrix[-4] = make([]complex128, params.Slots())
+ diagMatrix[-1] = make([]complex128, params.Slots())
+ diagMatrix[0] = make([]complex128, params.Slots())
+ diagMatrix[1] = make([]complex128, params.Slots())
+ diagMatrix[4] = make([]complex128, params.Slots())
+ diagMatrix[15] = make([]complex128, params.Slots())
+
+ for i := 0; i < params.Slots(); i++ {
+ diagMatrix[-15][i] = complex(1, 0)
+ diagMatrix[-4][i] = complex(1, 0)
+ diagMatrix[-1][i] = complex(1, 0)
+ diagMatrix[0][i] = complex(1, 0)
+ diagMatrix[1][i] = complex(1, 0)
+ diagMatrix[4][i] = complex(1, 0)
+ diagMatrix[15][i] = complex(1, 0)
+ }
+
+ ptDiagMatrix := testContext.encoder.EncodeDiagMatrixBSGSAtLvl(params.MaxLevel(), diagMatrix, params.Scale(), 1.0, params.LogSlots())
+
+ rots := testContext.kgen.GenRotationIndexesForDiagMatrix(ptDiagMatrix)
+
+ rotKey := testContext.kgen.GenRotationKeysForRotations(rots, false, testContext.sk)
+
+ eval := testContext.evaluator.WithKey(EvaluationKey{testContext.rlk, rotKey})
+
+ res := eval.LinearTransform(ciphertext1, ptDiagMatrix)[0]
+
+ tmp := make([]complex128, params.Slots())
+ copy(tmp, values1)
+
+ for i := 0; i < params.Slots(); i++ {
+ values1[i] += tmp[(i-15+params.Slots())%params.Slots()]
+ values1[i] += tmp[(i-4+params.Slots())%params.Slots()]
+ values1[i] += tmp[(i-1+params.Slots())%params.Slots()]
+ values1[i] += tmp[(i+1)%params.Slots()]
+ values1[i] += tmp[(i+4)%params.Slots()]
+ values1[i] += tmp[(i+15)%params.Slots()]
+ }
+
+ verifyTestVectors(testContext, testContext.decryptor, values1, res, testContext.params.LogSlots(), 0, t)
+ })
+
+ t.Run(testString(testContext, "LinearTransform/Naive/"), func(t *testing.T) {
+
+ params := testContext.params
+
+ values1, _, ciphertext1 := newTestVectors(testContext, testContext.encryptorSk, complex(-1, -1), complex(1, 1), t)
+
+ diagMatrix := make(map[int][]complex128)
+
+ diagMatrix[-1] = make([]complex128, params.Slots())
+ diagMatrix[0] = make([]complex128, params.Slots())
+
+ for i := 0; i < params.Slots(); i++ {
+ diagMatrix[-1][i] = complex(1, 0)
+ diagMatrix[0][i] = complex(1, 0)
+ }
+
+ ptDiagMatrix := testContext.encoder.EncodeDiagMatrixAtLvl(params.MaxLevel(), diagMatrix, params.Scale(), params.LogSlots())
+
+ rots := testContext.kgen.GenRotationIndexesForDiagMatrix(ptDiagMatrix)
+
+ rotKey := testContext.kgen.GenRotationKeysForRotations(rots, false, testContext.sk)
+
+ eval := testContext.evaluator.WithKey(EvaluationKey{testContext.rlk, rotKey})
+
+ res := eval.LinearTransform(ciphertext1, ptDiagMatrix)[0]
+
+ tmp := make([]complex128, params.Slots())
+ copy(tmp, values1)
+
+ for i := 0; i < params.Slots(); i++ {
+ values1[i] += tmp[(i-1+params.Slots())%params.Slots()]
+ }
+
+ verifyTestVectors(testContext, testContext.decryptor, values1, res, testContext.params.LogSlots(), 0, t)
+ })
+}
+
func testMarshaller(testContext *testParams, t *testing.T) {
ringQP := testContext.ringQP
@@ -1064,7 +1329,7 @@ func testMarshaller(testContext *testParams, t *testing.T) {
require.Error(t, ciphertextTest.UnmarshalBinary(nil))
require.NoError(t, ciphertextTest.UnmarshalBinary(marshalledCiphertext))
- require.Equal(t, ciphertext.Degree(), uint64(0))
+ require.Equal(t, ciphertext.Degree(), 0)
require.Equal(t, ciphertext.Level(), testContext.params.MaxLevel())
require.Equal(t, ciphertext.Scale(), testContext.params.Scale())
require.Equal(t, len(ciphertext.Value()), 1)
diff --git a/ckks/decryptor.go b/ckks/decryptor.go
index 8b27c42a..73487ae2 100644
--- a/ckks/decryptor.go
+++ b/ckks/decryptor.go
@@ -30,7 +30,7 @@ type decryptor struct {
// encrypted under the provided secret-key.
func NewDecryptor(params *Parameters, sk *SecretKey) Decryptor {
- if sk.Value.GetDegree() != int(params.N()) {
+ if sk.Value.Degree() != params.N() {
panic("secret_key is invalid for the provided parameters")
}
@@ -58,7 +58,7 @@ func (decryptor *decryptor) DecryptNew(ciphertext *Ciphertext) (plaintext *Plain
func (decryptor *decryptor) Decrypt(ciphertext *Ciphertext, plaintext *Plaintext) {
- level := utils.MinUint64(ciphertext.Level(), plaintext.Level())
+ level := utils.MinInt(ciphertext.Level(), plaintext.Level())
plaintext.SetScale(ciphertext.Scale())
@@ -66,7 +66,7 @@ func (decryptor *decryptor) Decrypt(ciphertext *Ciphertext, plaintext *Plaintext
plaintext.value.Coeffs = plaintext.value.Coeffs[:ciphertext.Level()+1]
- for i := uint64(ciphertext.Degree()); i > 0; i-- {
+ for i := ciphertext.Degree(); i > 0; i-- {
decryptor.ringQ.MulCoeffsMontgomeryLvl(level, plaintext.value, decryptor.sk.Value, plaintext.value)
decryptor.ringQ.AddLvl(level, plaintext.value, ciphertext.value[i-1], plaintext.value)
diff --git a/ckks/encoder.go b/ckks/encoder.go
index 6982de1a..3853fb31 100644
--- a/ckks/encoder.go
+++ b/ckks/encoder.go
@@ -7,6 +7,7 @@ import (
"math/big"
"github.com/ldsec/lattigo/v2/ring"
+ "github.com/ldsec/lattigo/v2/utils"
)
// GaloisGen is an integer of order N/2 modulo M and that spans Z_M with the integer -1.
@@ -17,30 +18,43 @@ var pi = "3.14159265358979323846264338327950288419716939937510582097494459230781
// Encoder is an interface implenting the encoding algorithms.
type Encoder interface {
- Encode(plaintext *Plaintext, values []complex128, logSlots uint64)
- EncodeNew(values []complex128, logSlots uint64) (plaintext *Plaintext)
- EncodeAtLvlNew(level uint64, values []complex128, logSlots uint64) (plaintext *Plaintext)
- EncodeNTT(plaintext *Plaintext, values []complex128, logSlots uint64)
- EncodeNTTNew(values []complex128, logSlots uint64) (plaintext *Plaintext)
- EncodeNTTAtLvlNew(level uint64, values []complex128, logSlots uint64) (plaintext *Plaintext)
- Decode(plaintext *Plaintext, logSlots uint64) (res []complex128)
- Embed(values []complex128, logSlots uint64)
+ Encode(plaintext *Plaintext, values []complex128, logSlots int)
+ EncodeNew(values []complex128, logSlots int) (plaintext *Plaintext)
+ EncodeAtLvlNew(level int, values []complex128, logSlots int) (plaintext *Plaintext)
+
+ EncodeNTT(plaintext *Plaintext, values []complex128, logSlots int)
+ EncodeNTTNew(values []complex128, logSlots int) (plaintext *Plaintext)
+ EncodeNTTAtLvlNew(level int, values []complex128, logSlots int) (plaintext *Plaintext)
+
+ EncodeDiagMatrixBSGSAtLvl(level int, vector map[int][]complex128, scale, maxM1N2Ratio float64, logSlots int) (matrix *PtDiagMatrix)
+ EncodeDiagMatrixAtLvl(level int, vector map[int][]complex128, scale float64, logSlots int) (matrix *PtDiagMatrix)
+
+ Decode(plaintext *Plaintext, logSlots int) (res []complex128)
+ DecodePublic(plaintext *Plaintext, logSlots int, sigma float64) []complex128
+
+ Embed(values []complex128, logSlots int)
ScaleUp(pol *ring.Poly, scale float64, moduli []uint64)
+
WipeInternalMemory()
+
EncodeCoeffs(values []float64, plaintext *Plaintext)
DecodeCoeffs(plaintext *Plaintext) (res []float64)
+ DecodeCoeffsPublic(plaintext *Plaintext, bound float64) (res []float64)
+
+ GetErrSTDCoeffDomain(valuesWant, valuesHave []complex128, scale float64) (std float64)
+ GetErrSTDSlotDomain(valuesWant, valuesHave []complex128, scale float64) (std float64)
}
// EncoderBigComplex is an interface implenting the encoding algorithms with arbitrary precision.
type EncoderBigComplex interface {
- Encode(plaintext *Plaintext, values []*ring.Complex, logSlots uint64)
- EncodeNew(values []*ring.Complex, logSlots uint64) (plaintext *Plaintext)
- EncodeAtLvlNew(level uint64, values []*ring.Complex, logSlots uint64) (plaintext *Plaintext)
- EncodeNTT(plaintext *Plaintext, values []*ring.Complex, logSlots uint64)
- EncodeNTTAtLvlNew(level uint64, values []*ring.Complex, logSlots uint64) (plaintext *Plaintext)
- Decode(plaintext *Plaintext, logSlots uint64) (res []*ring.Complex)
- FFT(values []*ring.Complex, N uint64)
- InvFFT(values []*ring.Complex, N uint64)
+ Encode(plaintext *Plaintext, values []*ring.Complex, logSlots int)
+ EncodeNew(values []*ring.Complex, logSlots int) (plaintext *Plaintext)
+ EncodeAtLvlNew(level int, values []*ring.Complex, logSlots int) (plaintext *Plaintext)
+ EncodeNTT(plaintext *Plaintext, values []*ring.Complex, logSlots int)
+ EncodeNTTAtLvlNew(level int, values []*ring.Complex, logSlots int) (plaintext *Plaintext)
+ Decode(plaintext *Plaintext, logSlots int) (res []*ring.Complex)
+ FFT(values []*ring.Complex, N int)
+ InvFFT(values []*ring.Complex, N int)
//EncodeCoeffs(values []*big.Float, plaintext *Plaintext)
//DecodeCoeffs(plaintext *Plaintext) (res []*big.Float)
@@ -50,12 +64,15 @@ type EncoderBigComplex interface {
type encoder struct {
params *Parameters
ringQ *ring.Ring
+ ringP *ring.Ring
bigintChain []*big.Int
bigintCoeffs []*big.Int
qHalf *big.Int
polypool *ring.Poly
- m uint64
- rotGroup []uint64
+ m int
+ rotGroup []int
+
+ gaussianSampler *ring.GaussianSampler
}
type encoderComplex128 struct {
@@ -75,23 +92,39 @@ func newEncoder(params *Parameters) encoder {
panic(err)
}
- rotGroup := make([]uint64, m>>1)
- fivePows := uint64(1)
- for i := uint64(0); i < m>>2; i++ {
+ var p *ring.Ring
+ if params.PiCount() != 0 {
+ if p, err = ring.NewRing(params.N(), params.pi); err != nil {
+ panic(err)
+ }
+ }
+
+ rotGroup := make([]int, m>>1)
+ fivePows := 1
+ for i := 0; i < m>>2; i++ {
rotGroup[i] = fivePows
- fivePows *= GaloisGen
+ fivePows *= int(GaloisGen)
fivePows &= (m - 1)
}
+ prng, err := utils.NewPRNG()
+ if err != nil {
+ panic(err)
+ }
+
+ gaussianSampler := ring.NewGaussianSampler(prng, q, params.Sigma(), int(6*params.Sigma()))
+
return encoder{
- params: params.Copy(),
- ringQ: q,
- bigintChain: genBigIntChain(params.qi),
- bigintCoeffs: make([]*big.Int, m>>1),
- qHalf: ring.NewUint(0),
- polypool: q.NewPoly(),
- m: m,
- rotGroup: rotGroup,
+ params: params.Copy(),
+ ringQ: q,
+ ringP: p,
+ bigintChain: genBigIntChain(params.qi),
+ bigintCoeffs: make([]*big.Int, m>>1),
+ qHalf: ring.NewUint(0),
+ polypool: q.NewPoly(),
+ m: m,
+ rotGroup: rotGroup,
+ gaussianSampler: gaussianSampler,
}
}
@@ -102,7 +135,7 @@ func NewEncoder(params *Parameters) Encoder {
var angle float64
roots := make([]complex128, encoder.m+1)
- for i := uint64(0); i < encoder.m; i++ {
+ for i := 0; i < encoder.m; i++ {
angle = 2 * 3.141592653589793 * float64(i) / float64(encoder.m)
roots[i] = complex(math.Cos(angle), math.Sin(angle))
@@ -118,19 +151,19 @@ func NewEncoder(params *Parameters) Encoder {
}
// EncodeNew encodes a slice of complex128 of length slots = 2^{logSlots} on new plaintext at the maximum level.
-func (encoder *encoderComplex128) EncodeNew(values []complex128, logSlots uint64) (plaintext *Plaintext) {
+func (encoder *encoderComplex128) EncodeNew(values []complex128, logSlots int) (plaintext *Plaintext) {
return encoder.EncodeAtLvlNew(encoder.params.MaxLevel(), values, logSlots)
}
// EncodeAtLvlNew encodes a slice of complex128 of length slots = 2^{logSlots} on new plaintext at the desired level.
-func (encoder *encoderComplex128) EncodeAtLvlNew(level uint64, values []complex128, logSlots uint64) (plaintext *Plaintext) {
+func (encoder *encoderComplex128) EncodeAtLvlNew(level int, values []complex128, logSlots int) (plaintext *Plaintext) {
plaintext = NewPlaintext(encoder.params, level, encoder.params.scale)
encoder.Encode(plaintext, values, logSlots)
return
}
// Encode encodes a slice of complex128 of length slots = 2^{logSlots} on the input plaintext.
-func (encoder *encoderComplex128) Encode(plaintext *Plaintext, values []complex128, logSlots uint64) {
+func (encoder *encoderComplex128) Encode(plaintext *Plaintext, values []complex128, logSlots int) {
encoder.Embed(values, logSlots)
encoder.ScaleUp(plaintext.value, plaintext.scale, encoder.ringQ.Modulus[:plaintext.Level()+1])
encoder.WipeInternalMemory()
@@ -139,13 +172,13 @@ func (encoder *encoderComplex128) Encode(plaintext *Plaintext, values []complex1
// EncodeNTTNew encodes a slice of complex128 of length slots = 2^{logSlots} on new plaintext at the maximum level.
// Returns a plaintext in the NTT domain.
-func (encoder *encoderComplex128) EncodeNTTNew(values []complex128, logSlots uint64) (plaintext *Plaintext) {
+func (encoder *encoderComplex128) EncodeNTTNew(values []complex128, logSlots int) (plaintext *Plaintext) {
return encoder.EncodeNTTAtLvlNew(encoder.params.MaxLevel(), values, logSlots)
}
// EncodeNTTAtLvlNew encodes a slice of complex128 of length slots = 2^{logSlots} on new plaintext at the desired level.
// Returns a plaintext in the NTT domain.
-func (encoder *encoderComplex128) EncodeNTTAtLvlNew(level uint64, values []complex128, logSlots uint64) (plaintext *Plaintext) {
+func (encoder *encoderComplex128) EncodeNTTAtLvlNew(level int, values []complex128, logSlots int) (plaintext *Plaintext) {
plaintext = NewPlaintext(encoder.params, level, encoder.params.scale)
encoder.EncodeNTT(plaintext, values, logSlots)
return
@@ -153,7 +186,7 @@ func (encoder *encoderComplex128) EncodeNTTAtLvlNew(level uint64, values []compl
// EncodeNTT encodes a slice of complex128 of length slots = 2^{logSlots} on the input plaintext.
// Returns a plaintext in the NTT domain.
-func (encoder *encoderComplex128) EncodeNTT(plaintext *Plaintext, values []complex128, logSlots uint64) {
+func (encoder *encoderComplex128) EncodeNTT(plaintext *Plaintext, values []complex128, logSlots int) {
encoder.Encode(plaintext, values, logSlots)
encoder.ringQ.NTTLvl(plaintext.Level(), plaintext.value, plaintext.value)
plaintext.isNTT = true
@@ -161,11 +194,11 @@ func (encoder *encoderComplex128) EncodeNTT(plaintext *Plaintext, values []compl
// Embed encodes a vector and stores internally the encoded values.
// To be used in conjunction with ScaleUp.
-func (encoder *encoderComplex128) Embed(values []complex128, logSlots uint64) {
+func (encoder *encoderComplex128) Embed(values []complex128, logSlots int) {
- slots := uint64(1 << logSlots)
+ slots := 1 << logSlots
- if uint64(len(values)) > encoder.params.N()/2 || uint64(len(values)) > slots || logSlots > encoder.params.LogN()-1 {
+ if len(values) > encoder.params.N()/2 || len(values) > slots || logSlots > encoder.params.LogN()-1 {
panic("cannot Encode: too many values/slots for the given ring degree")
}
@@ -173,17 +206,48 @@ func (encoder *encoderComplex128) Embed(values []complex128, logSlots uint64) {
encoder.values[i] = values[i]
}
- encoder.invfft(encoder.values, slots)
+ invfft(encoder.values, slots, encoder.m, encoder.rotGroup, encoder.roots)
gap := (encoder.ringQ.N >> 1) / slots
- for i, jdx, idx := uint64(0), encoder.ringQ.N>>1, uint64(0); i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
+ for i, jdx, idx := 0, encoder.ringQ.N>>1, 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
encoder.valuesfloat[idx] = real(encoder.values[i])
encoder.valuesfloat[jdx] = imag(encoder.values[i])
}
}
-// ScaleUp writes the internally stored encoded values on a polynomial.
+// GetErrSTDSlotDomain returns the scaled standard deviation of the difference between two complex vectors in the slot domains
+func (encoder *encoderComplex128) GetErrSTDSlotDomain(valuesWant, valuesHave []complex128, scale float64) (std float64) {
+
+ var err complex128
+ for i := range valuesWant {
+ err = valuesWant[i] - valuesHave[i]
+ encoder.valuesfloat[2*i] = real(err)
+ encoder.valuesfloat[2*i+1] = imag(err)
+ }
+
+ return StandardDeviation(encoder.valuesfloat[:len(valuesWant)*2], scale)
+}
+
+// GetErrSTDCoeffDomain returns the scaled standard deviation in the coefficient domain of the difference between two complex vectors in the slot domains
+func (encoder *encoderComplex128) GetErrSTDCoeffDomain(valuesWant, valuesHave []complex128, scale float64) (std float64) {
+
+ for i := range valuesHave {
+ encoder.values[i] = (valuesWant[i] - valuesHave[i])
+ }
+
+ invfft(encoder.values, len(valuesWant), encoder.m, encoder.rotGroup, encoder.roots)
+
+ for i := range valuesWant {
+ encoder.valuesfloat[2*i] = real(encoder.values[i])
+ encoder.valuesfloat[2*i+1] = imag(encoder.values[i])
+ }
+
+ return StandardDeviation(encoder.valuesfloat[:len(valuesWant)*2], scale)
+
+}
+
+// ScaleUp writes the internaly stored encoded values on a polynomial.
func (encoder *encoderComplex128) ScaleUp(pol *ring.Poly, scale float64, moduli []uint64) {
scaleUpVecExact(encoder.valuesfloat, scale, moduli, pol.Coeffs)
}
@@ -199,11 +263,385 @@ func (encoder *encoderComplex128) WipeInternalMemory() {
}
}
-// EncodeCoefficients takes as input a polynomial a0 + a1x + a2x^2 + ... + an-1x^n-1 with float coefficient
-// and returns a scaled integer plaintext polynomial in NTT.
+// DecodePublic decodes the Plaintext values to a slice of complex128 values of size at most N/2.
+// Adds a Gaussian error to the plaintext of variance sigma and bound floor(sqrt(2*pi)*sigma) before decoding
+func (encoder *encoderComplex128) DecodePublic(plaintext *Plaintext, logSlots int, bound float64) (res []complex128) {
+ return encoder.decodePublic(plaintext, logSlots, bound)
+}
+
+// Decode decodes the Plaintext values to a slice of complex128 values of size at most N/2.
+func (encoder *encoderComplex128) Decode(plaintext *Plaintext, logSlots int) (res []complex128) {
+ return encoder.decodePublic(plaintext, logSlots, 0)
+}
+
+func polyToComplexNoCRT(coeffs []uint64, values []complex128, scale float64, logSlots int, Q uint64) {
+
+ slots := 1 << logSlots
+ maxSlots := len(coeffs) >> 1
+ gap := maxSlots / slots
+
+ var real, imag float64
+ for i, idx := 0, 0; i < slots; i, idx = i+1, idx+gap {
+
+ if coeffs[idx] >= Q>>1 {
+ real = -float64(Q - coeffs[idx])
+ } else {
+ real = float64(coeffs[idx])
+ }
+
+ if coeffs[idx+maxSlots] >= Q>>1 {
+ imag = -float64(Q - coeffs[idx+maxSlots])
+ } else {
+ imag = float64(coeffs[idx+maxSlots])
+ }
+
+ values[i] = complex(real, imag) / complex(scale, 0)
+ }
+}
+
+func polyToComplexCRT(poly *ring.Poly, bigintCoeffs []*big.Int, values []complex128, scale float64, logSlots int, ringQ *ring.Ring, Q *big.Int) {
+
+ ringQ.PolyToBigint(poly, bigintCoeffs)
+
+ maxSlots := ringQ.N >> 1
+ slots := 1 << logSlots
+ gap := maxSlots / slots
+
+ qHalf := new(big.Int)
+ qHalf.Set(Q)
+ qHalf.Rsh(qHalf, 1)
+
+ var sign int
+
+ for i, idx := 0, 0; i < slots; i, idx = i+1, idx+gap {
+
+ // Centers the value around the current modulus
+ bigintCoeffs[idx].Mod(bigintCoeffs[idx], Q)
+ sign = bigintCoeffs[idx].Cmp(qHalf)
+ if sign == 1 || sign == 0 {
+ bigintCoeffs[idx].Sub(bigintCoeffs[idx], Q)
+ }
+
+ // Centers the value around the current modulus
+ bigintCoeffs[idx+maxSlots].Mod(bigintCoeffs[idx+maxSlots], Q)
+ sign = bigintCoeffs[idx+maxSlots].Cmp(qHalf)
+ if sign == 1 || sign == 0 {
+ bigintCoeffs[idx+maxSlots].Sub(bigintCoeffs[idx+maxSlots], Q)
+ }
+
+ values[i] = complex(scaleDown(bigintCoeffs[idx], scale), scaleDown(bigintCoeffs[idx+maxSlots], scale))
+ }
+}
+
+func (encoder *encoderComplex128) plaintextToComplex(level int, scale float64, logSlots int, p *ring.Poly, values []complex128) {
+ if level == 0 {
+ polyToComplexNoCRT(p.Coeffs[0], encoder.values, scale, logSlots, encoder.ringQ.Modulus[0])
+ } else {
+ polyToComplexCRT(p, encoder.bigintCoeffs, values, scale, logSlots, encoder.ringQ, encoder.bigintChain[level])
+ }
+}
+
+func roundComplexVector(values []complex128, bound float64) {
+ for i := range values {
+ a := math.Round(real(values[i])*bound) / bound
+ b := math.Round(imag(values[i])*bound) / bound
+ values[i] = complex(a, b)
+ }
+}
+
+func polyToFloatNoCRT(coeffs []uint64, values []float64, scale float64, Q uint64) {
+
+ for i, c := range coeffs {
+
+ if c >= Q>>1 {
+ values[i] = -float64(Q-c) / scale
+ } else {
+ values[i] = float64(c) / scale
+ }
+ }
+}
+
+// PtDiagMatrix is a struct storing a plaintext diagonalized matrix
+// ready to be evaluated on a ciphertext using evaluator.MultiplyByDiagMatrice.
+type PtDiagMatrix struct {
+ LogSlots int // Log of the number of slots of the plaintext (needed to compute the appropriate rotation keys)
+ N1 int // N1 is the number of inner loops of the baby-step giant-step algo used in the evaluation.
+ Level int // Level is the level at which the matrix is encoded (can be circuit dependant)
+ Scale float64 // Scale is the scale at which the matrix is encoded (can be circuit dependant)
+ Vec map[int][2]*ring.Poly // Vec is the matrix, in diagonal form, where each entry of vec is an indexed non zero diagonal.
+ naive bool
+ isGaussian bool // Each diagonal of the matrix is of the form [k, ..., k] for k a gaussian integer
+}
+
+func bsgsIndex(el interface{}, slots, N1 int) (index map[int][]int, rotations []int) {
+ index = make(map[int][]int)
+ rotations = []int{}
+ switch element := el.(type) {
+ case map[int][]complex128:
+ for key := range element {
+ key &= (slots - 1)
+ idx1 := key / N1
+ idx2 := key & (N1 - 1)
+ if index[idx1] == nil {
+ index[idx1] = []int{idx2}
+ } else {
+ index[idx1] = append(index[idx1], idx2)
+ }
+
+ if !utils.IsInSliceInt(idx2, rotations) {
+ rotations = append(rotations, idx2)
+ }
+ }
+ case map[int]bool:
+ for key := range element {
+ key &= (slots - 1)
+ idx1 := key / N1
+ idx2 := key & (N1 - 1)
+ if index[idx1] == nil {
+ index[idx1] = []int{idx2}
+ } else {
+ index[idx1] = append(index[idx1], idx2)
+ }
+ if !utils.IsInSliceInt(idx2, rotations) {
+ rotations = append(rotations, idx2)
+ }
+ }
+ case map[int][2]*ring.Poly:
+ for key := range element {
+ key &= (slots - 1)
+ idx1 := key / N1
+ idx2 := key & (N1 - 1)
+ if index[idx1] == nil {
+ index[idx1] = []int{idx2}
+ } else {
+ index[idx1] = append(index[idx1], idx2)
+ }
+ if !utils.IsInSliceInt(idx2, rotations) {
+ rotations = append(rotations, idx2)
+ }
+ }
+ }
+ return
+}
+
+// EncodeDiagMatrixBSGSAtLvl encodes a diagonalized plaintext matrix into PtDiagMatrix struct.
+// It can then be evaluated on a ciphertext using evaluator.LinearTransform.
+// Evaluation will use the optimized approach (doiuble hoisting and baby-step giant-step).
+// Faster if there is more than a few non-zero diagonals.
+// maxM1N2Ratio is the maximum ratio between the inner and outer loop of the baby-step giant-step algorithm used in evaluator.LinearTransform.
+// Optimal maxM1N2Ratio value is between 4 and 16 depending on the sparsity of the matrix.
+func (encoder *encoderComplex128) EncodeDiagMatrixBSGSAtLvl(level int, diagMatrix map[int][]complex128, scale, maxM1N2Ratio float64, logSlots int) (matrix *PtDiagMatrix) {
+
+ matrix = new(PtDiagMatrix)
+ matrix.LogSlots = logSlots
+ slots := 1 << logSlots
+
+ // N1*N2 = N
+ N1 := findbestbabygiantstepsplit(diagMatrix, slots, maxM1N2Ratio)
+ matrix.N1 = N1
+
+ index, _ := bsgsIndex(diagMatrix, slots, N1)
+
+ matrix.Vec = make(map[int][2]*ring.Poly)
+
+ matrix.Level = level
+ matrix.Scale = scale
+
+ for j := range index {
+
+ for _, i := range index[j] {
+
+ // manages inputs that have rotation between 0 and slots-1 or between -slots/2 and slots/2-1
+ v := diagMatrix[N1*j+i]
+ if len(v) == 0 {
+ v = diagMatrix[(N1*j+i)-slots]
+ }
+
+ matrix.Vec[N1*j+i] = encoder.encodeDiagonal(logSlots, level, scale, rotate(v, -N1*j))
+ }
+ }
+
+ return
+}
+
+// EncodeDiagMatrixAtLvl encodes a diagonalized plaintext matrix into PtDiagMatrix struct.
+// It can then be evaluated on a ciphertext using evaluator.LinearTransform.
+// Evaluation will use the naive approach (single hoisting and no baby-step giant-step).
+// Faster if there is only a few non-zero diagonals but uses more keys.
+func (encoder *encoderComplex128) EncodeDiagMatrixAtLvl(level int, diagMatrix map[int][]complex128, scale float64, logSlots int) (matrix *PtDiagMatrix) {
+
+ matrix = new(PtDiagMatrix)
+ matrix.Vec = make(map[int][2]*ring.Poly)
+ matrix.Level = level
+ matrix.Scale = scale
+ slots := 1 << logSlots
+
+ for i := range diagMatrix {
+
+ idx := i
+ if idx < 0 {
+ idx += slots
+ }
+ matrix.Vec[idx] = encoder.encodeDiagonal(logSlots, level, scale, diagMatrix[i])
+ }
+
+ matrix.naive = true
+
+ return
+}
+
+func (encoder *encoderComplex128) encodeDiagonal(logSlots, level int, scale float64, m []complex128) [2]*ring.Poly {
+
+ ringQ := encoder.ringQ
+ ringP := encoder.ringP
+
+ encoder.Embed(m, logSlots)
+
+ mQ := ringQ.NewPolyLvl(level + 1)
+ encoder.ScaleUp(mQ, scale, ringQ.Modulus[:level+1])
+ ringQ.NTTLvl(level, mQ, mQ)
+ ringQ.MFormLvl(level, mQ, mQ)
+
+ mP := ringP.NewPoly()
+ encoder.ScaleUp(mP, scale, ringP.Modulus)
+ ringP.NTT(mP, mP)
+ ringP.MForm(mP, mP)
+
+ encoder.WipeInternalMemory()
+
+ return [2]*ring.Poly{mQ, mP}
+}
+
+// Finds the best N1*N2 = N for the baby-step giant-step algorithm for matrix multiplication.
+func findbestbabygiantstepsplit(diagMatrix interface{}, maxN int, maxRatio float64) (minN int) {
+
+ for N1 := 1; N1 < maxN; N1 <<= 1 {
+
+ index, _ := bsgsIndex(diagMatrix, maxN, N1)
+
+ if len(index[0]) > 0 {
+
+ hoisted := len(index[0]) - 1
+ normal := len(index) - 1
+
+ // The matrice is very sparse already
+ if normal == 0 {
+ return N1 / 2
+ }
+
+ if hoisted > normal {
+ // Finds the next split that has a ratio hoisted/normal greater or equal to maxRatio
+ for float64(hoisted)/float64(normal) < maxRatio {
+
+ if normal/2 == 0 {
+ break
+ }
+ N1 *= 2
+ hoisted = hoisted*2 + 1
+ normal = normal / 2
+ }
+ return N1
+ }
+ }
+ }
+
+ return 1
+}
+
+func (encoder *encoderComplex128) decodePublic(plaintext *Plaintext, logSlots int, sigma float64) (res []complex128) {
+
+ if logSlots > encoder.params.LogN()-1 {
+ panic("cannot Decode: too many slots for the given ring degree")
+ }
+
+ slots := 1 << logSlots
+
+ if plaintext.isNTT {
+ encoder.ringQ.InvNTTLvl(plaintext.Level(), plaintext.value, encoder.polypool)
+ } else {
+ encoder.ringQ.CopyLvl(plaintext.Level(), plaintext.value, encoder.polypool)
+ }
+
+ // B = floor(sigma * sqrt(2*pi))
+ if sigma != 0 {
+ encoder.gaussianSampler.ReadAndAddFromDistLvl(plaintext.Level(), encoder.polypool, encoder.ringQ, sigma, int(2.5066282746310002*sigma))
+ }
+
+ encoder.plaintextToComplex(plaintext.Level(), plaintext.Scale(), logSlots, encoder.polypool, encoder.values)
+
+ fft(encoder.values, slots, encoder.m, encoder.rotGroup, encoder.roots)
+
+ res = make([]complex128, slots)
+
+ for i := range res {
+ res[i] = encoder.values[i]
+ }
+
+ for i := range encoder.values {
+ encoder.values[i] = 0
+ }
+
+ return
+}
+
+func invfft(values []complex128, N, M int, rotGroup []int, roots []complex128) {
+
+ var lenh, lenq, gap, idx int
+ var u, v complex128
+
+ for len := N; len >= 1; len >>= 1 {
+ for i := 0; i < N; i += len {
+ lenh = len >> 1
+ lenq = len << 2
+ gap = M / lenq
+ for j := 0; j < lenh; j++ {
+ idx = (lenq - (rotGroup[j] % lenq)) * gap
+ u = values[i+j] + values[i+j+lenh]
+ v = values[i+j] - values[i+j+lenh]
+ v *= roots[idx]
+ values[i+j] = u
+ values[i+j+lenh] = v
+
+ }
+ }
+ }
+
+ for i := 0; i < N; i++ {
+ values[i] /= complex(float64(N), 0)
+ }
+
+ sliceBitReverseInPlaceComplex128(values, N)
+}
+
+func fft(values []complex128, N, M int, rotGroup []int, roots []complex128) {
+
+ var lenh, lenq, gap, idx int
+ var u, v complex128
+
+ sliceBitReverseInPlaceComplex128(values, N)
+
+ for len := 2; len <= N; len <<= 1 {
+ for i := 0; i < N; i += len {
+ lenh = len >> 1
+ lenq = len << 2
+ gap = M / lenq
+ for j := 0; j < lenh; j++ {
+ idx = (rotGroup[j] % lenq) * gap
+ u = values[i+j]
+ v = values[i+j+lenh]
+ v *= roots[idx]
+ values[i+j] = u + v
+ values[i+j+lenh] = u - v
+ }
+ }
+ }
+}
+
+// EncodeCoeffs takes as input a polynomial a0 + a1x + a2x^2 + ... + an-1x^n-1 with float coefficient
+// and returns a scaled integer plaintext polynomial. Encodes at the input plaintext level.
func (encoder *encoderComplex128) EncodeCoeffs(values []float64, plaintext *Plaintext) {
- if uint64(len(values)) > encoder.params.N() {
+ if len(values) > encoder.params.N() {
panic("cannot EncodeCoeffs : too many values (maximum is N)")
}
@@ -212,16 +650,26 @@ func (encoder *encoderComplex128) EncodeCoeffs(values []float64, plaintext *Plai
plaintext.isNTT = false
}
-// EncodeCoefficients takes as input a polynomial a0 + a1x + a2x^2 + ... + an-1x^n-1 with float coefficient
-// and returns a scaled integer plaintext polynomial in NTT.
+// EncodeCoeffsNTT takes as input a polynomial a0 + a1x + a2x^2 + ... + an-1x^n-1 with float coefficient
+// and returns a scaled integer plaintext polynomial in NTT. Encodes at the input plaintext level.
func (encoder *encoderComplex128) EncodeCoeffsNTT(values []float64, plaintext *Plaintext) {
encoder.EncodeCoeffs(values, plaintext)
encoder.ringQ.NTTLvl(plaintext.Level(), plaintext.value, plaintext.value)
plaintext.isNTT = true
}
-// DecodeCoeffs takes as input a plaintext and returns the scaled down coefficient of the plaintext in float64.
+// DecodeCoeffsPublic takes as input a plaintext and returns the scaled down coefficient of the plaintext in float64.
+// Rounds the decimal part of the output (the bits under the scale) to "logPrecision" bits of precision.
+func (encoder *encoderComplex128) DecodeCoeffsPublic(plaintext *Plaintext, sigma float64) (res []float64) {
+ return encoder.decodeCoeffsPublic(plaintext, sigma)
+}
+
func (encoder *encoderComplex128) DecodeCoeffs(plaintext *Plaintext) (res []float64) {
+ return encoder.decodeCoeffsPublic(plaintext, 0)
+}
+
+// DecodeCoeffs takes as input a plaintext and returns the scaled down coefficient of the plaintext in float64.
+func (encoder *encoderComplex128) decodeCoeffsPublic(plaintext *Plaintext, sigma float64) (res []float64) {
if plaintext.isNTT {
encoder.ringQ.InvNTTLvl(plaintext.Level(), plaintext.value, encoder.polypool)
@@ -229,6 +677,11 @@ func (encoder *encoderComplex128) DecodeCoeffs(plaintext *Plaintext) (res []floa
encoder.ringQ.CopyLvl(plaintext.Level(), plaintext.value, encoder.polypool)
}
+ if sigma != 0 {
+ // B = floor(sigma * sqrt(2*pi))
+ encoder.gaussianSampler.ReadAndAddFromDistLvl(plaintext.Level(), encoder.polypool, encoder.ringQ, sigma, int(2.5066282746310002*sigma))
+ }
+
res = make([]float64, encoder.params.N())
// We have more than one moduli and need the CRT reconstruction
@@ -276,160 +729,19 @@ func (encoder *encoderComplex128) DecodeCoeffs(plaintext *Plaintext) (res []floa
return
}
-// Decode decodes the Plaintext values to a slice of complex128 values of size at most N/2.
-func (encoder *encoderComplex128) Decode(plaintext *Plaintext, logSlots uint64) (res []complex128) {
-
- if logSlots > encoder.params.LogN()-1 {
- panic("cannot Decode: too many slots for the given ring degree")
- }
-
- slots := uint64(1 << logSlots)
-
- if plaintext.isNTT {
- encoder.ringQ.InvNTTLvl(plaintext.Level(), plaintext.value, encoder.polypool)
- } else {
- encoder.ringQ.CopyLvl(plaintext.Level(), plaintext.value, encoder.polypool)
- }
-
- maxSlots := encoder.ringQ.N >> 1
- gap := maxSlots / slots
-
- // We have more than one moduli and need the CRT reconstruction
-
- if plaintext.Level() == 0 {
-
- Q := encoder.ringQ.Modulus[0]
- coeffs := encoder.polypool.Coeffs[0]
-
- var real, imag float64
- for i, idx := uint64(0), uint64(0); i < slots; i, idx = i+1, idx+gap {
-
- if coeffs[idx] >= Q>>1 {
- real = -float64(Q - coeffs[idx])
- } else {
- real = float64(coeffs[idx])
- }
-
- if coeffs[idx+maxSlots] >= Q>>1 {
- imag = -float64(Q - coeffs[idx+maxSlots])
- } else {
- imag = float64(coeffs[idx+maxSlots])
- }
-
- encoder.values[i] = complex(real, imag) / complex(plaintext.scale, 0)
- }
- } else {
-
- encoder.ringQ.PolyToBigint(encoder.polypool, encoder.bigintCoeffs)
-
- Q := encoder.bigintChain[plaintext.Level()]
-
- encoder.qHalf.Set(Q)
- encoder.qHalf.Rsh(encoder.qHalf, 1)
-
- var sign int
-
- for i, idx := uint64(0), uint64(0); i < slots; i, idx = i+1, idx+gap {
-
- // Centers the value around the current modulus
- encoder.bigintCoeffs[idx].Mod(encoder.bigintCoeffs[idx], Q)
- sign = encoder.bigintCoeffs[idx].Cmp(encoder.qHalf)
- if sign == 1 || sign == 0 {
- encoder.bigintCoeffs[idx].Sub(encoder.bigintCoeffs[idx], Q)
- }
-
- // Centers the value around the current modulus
- encoder.bigintCoeffs[idx+maxSlots].Mod(encoder.bigintCoeffs[idx+maxSlots], Q)
- sign = encoder.bigintCoeffs[idx+maxSlots].Cmp(encoder.qHalf)
- if sign == 1 || sign == 0 {
- encoder.bigintCoeffs[idx+maxSlots].Sub(encoder.bigintCoeffs[idx+maxSlots], Q)
- }
-
- encoder.values[i] = complex(scaleDown(encoder.bigintCoeffs[idx], plaintext.scale), scaleDown(encoder.bigintCoeffs[idx+maxSlots], plaintext.scale))
- }
- // We can directly get the coefficients
- }
-
- encoder.fft(encoder.values, slots)
-
- res = make([]complex128, slots)
-
- for i := range res {
- res[i] = encoder.values[i]
- }
-
- for i := uint64(0); i < encoder.ringQ.N>>1; i++ {
- encoder.values[i] = 0
- }
-
- return
-}
-
-func (encoder *encoderComplex128) invfft(values []complex128, N uint64) {
-
- var lenh, lenq, gap, idx uint64
- var u, v complex128
-
- for len := N; len >= 1; len >>= 1 {
- for i := uint64(0); i < N; i += len {
- lenh = len >> 1
- lenq = len << 2
- gap = encoder.m / lenq
- for j := uint64(0); j < lenh; j++ {
- idx = (lenq - (encoder.rotGroup[j] % lenq)) * gap
- u = values[i+j] + values[i+j+lenh]
- v = values[i+j] - values[i+j+lenh]
- v *= encoder.roots[idx]
- values[i+j] = u
- values[i+j+lenh] = v
-
- }
- }
- }
-
- for i := uint64(0); i < N; i++ {
- values[i] /= complex(float64(N), 0)
- }
-
- sliceBitReverseInPlaceComplex128(values, N)
-}
-
-func (encoder *encoderComplex128) fft(values []complex128, N uint64) {
-
- var lenh, lenq, gap, idx uint64
- var u, v complex128
-
- sliceBitReverseInPlaceComplex128(values, N)
-
- for len := uint64(2); len <= N; len <<= 1 {
- for i := uint64(0); i < N; i += len {
- lenh = len >> 1
- lenq = len << 2
- gap = encoder.m / lenq
- for j := uint64(0); j < lenh; j++ {
- idx = (encoder.rotGroup[j] % lenq) * gap
- u = values[i+j]
- v = values[i+j+lenh]
- v *= encoder.roots[idx]
- values[i+j] = u + v
- values[i+j+lenh] = u - v
- }
- }
- }
-}
-
type encoderBigComplex struct {
encoder
- zero *big.Float
- cMul *ring.ComplexMultiplier
- logPrecision uint64
- values []*ring.Complex
- valuesfloat []*big.Float
- roots []*ring.Complex
+ zero *big.Float
+ cMul *ring.ComplexMultiplier
+ logPrecision int
+ values []*ring.Complex
+ valuesfloat []*big.Float
+ roots []*ring.Complex
+ gaussianSampler *ring.GaussianSampler
}
// NewEncoderBigComplex creates a new encoder using arbitrary precision complex arithmetic.
-func NewEncoderBigComplex(params *Parameters, logPrecision uint64) EncoderBigComplex {
+func NewEncoderBigComplex(params *Parameters, logPrecision int) EncoderBigComplex {
encoder := newEncoder(params)
var PI = new(big.Float)
@@ -443,7 +755,7 @@ func NewEncoderBigComplex(params *Parameters, logPrecision uint64) EncoderBigCom
var angle *big.Float
roots := make([]*ring.Complex, encoder.m+1)
- for i := uint64(0); i < encoder.m; i++ {
+ for i := 0; i < encoder.m; i++ {
angle = ring.NewFloat(2, logPrecision)
angle.Mul(angle, PI)
angle.Mul(angle, ring.NewFloat(float64(i), logPrecision))
@@ -461,7 +773,7 @@ func NewEncoderBigComplex(params *Parameters, logPrecision uint64) EncoderBigCom
values := make([]*ring.Complex, encoder.m>>2)
valuesfloat := make([]*big.Float, encoder.m>>1)
- for i := uint64(0); i < encoder.m>>2; i++ {
+ for i := 0; i < encoder.m>>2; i++ {
values[i] = ring.NewComplex(ring.NewFloat(0, logPrecision), ring.NewFloat(0, logPrecision))
valuesfloat[i*2] = ring.NewFloat(0, logPrecision)
@@ -480,12 +792,12 @@ func NewEncoderBigComplex(params *Parameters, logPrecision uint64) EncoderBigCom
}
// EncodeNew encodes a slice of ring.Complex of length slots = 2^{logSlots} on a new plaintext at the maximum level.
-func (encoder *encoderBigComplex) EncodeNew(values []*ring.Complex, logSlots uint64) (plaintext *Plaintext) {
+func (encoder *encoderBigComplex) EncodeNew(values []*ring.Complex, logSlots int) (plaintext *Plaintext) {
return encoder.EncodeAtLvlNew(encoder.params.MaxLevel(), values, logSlots)
}
// EncodeAtLvlNew encodes a slice of ring.Complex of length slots = 2^{logSlots} on a new plaintext at the desired level.
-func (encoder *encoderBigComplex) EncodeAtLvlNew(level uint64, values []*ring.Complex, logSlots uint64) (plaintext *Plaintext) {
+func (encoder *encoderBigComplex) EncodeAtLvlNew(level int, values []*ring.Complex, logSlots int) (plaintext *Plaintext) {
plaintext = NewPlaintext(encoder.params, level, encoder.params.scale)
encoder.Encode(plaintext, values, logSlots)
return
@@ -493,13 +805,13 @@ func (encoder *encoderBigComplex) EncodeAtLvlNew(level uint64, values []*ring.Co
// EncodeNTTNew encodes a slice of ring.Complex of length slots = 2^{logSlots} on a plaintext at the maximum level.
// Returns a plaintext in the NTT domain.
-func (encoder *encoderBigComplex) EncodeNTTNew(values []*ring.Complex, logSlots uint64) (plaintext *Plaintext) {
+func (encoder *encoderBigComplex) EncodeNTTNew(values []*ring.Complex, logSlots int) (plaintext *Plaintext) {
return encoder.EncodeNTTAtLvlNew(encoder.params.MaxLevel(), values, logSlots)
}
// EncodeNTTAtLvlNew encodes a slice of ring.Complex of length slots = 2^{logSlots} on a plaintext at the desired level.
// Returns a plaintext in the NTT domain.
-func (encoder *encoderBigComplex) EncodeNTTAtLvlNew(level uint64, values []*ring.Complex, logSlots uint64) (plaintext *Plaintext) {
+func (encoder *encoderBigComplex) EncodeNTTAtLvlNew(level int, values []*ring.Complex, logSlots int) (plaintext *Plaintext) {
plaintext = NewPlaintext(encoder.params, encoder.params.MaxLevel(), encoder.params.scale)
encoder.EncodeNTT(plaintext, values, logSlots)
return
@@ -507,26 +819,26 @@ func (encoder *encoderBigComplex) EncodeNTTAtLvlNew(level uint64, values []*ring
// Encode encodes a slice of ring.Complex of length slots = 2^{logSlots} on a plaintext at the input plaintext level.
// Returns a plaintext in the NTT domain.
-func (encoder *encoderBigComplex) EncodeNTT(plaintext *Plaintext, values []*ring.Complex, logSlots uint64) {
+func (encoder *encoderBigComplex) EncodeNTT(plaintext *Plaintext, values []*ring.Complex, logSlots int) {
encoder.Encode(plaintext, values, logSlots)
encoder.ringQ.NTTLvl(plaintext.Level(), plaintext.value, plaintext.value)
plaintext.isNTT = true
}
// Encode encodes a slice of ring.Complex of length slots = 2^{logSlots} on a plaintext at the input plaintext level.
-func (encoder *encoderBigComplex) Encode(plaintext *Plaintext, values []*ring.Complex, logSlots uint64) {
+func (encoder *encoderBigComplex) Encode(plaintext *Plaintext, values []*ring.Complex, logSlots int) {
- slots := uint64(1 << logSlots)
+ slots := 1 << logSlots
- if uint64(len(values)) > encoder.params.N()/2 || uint64(len(values)) > slots || logSlots > encoder.params.LogN()-1 {
+ if len(values) > encoder.params.N()/2 || len(values) > slots || logSlots > encoder.params.LogN()-1 {
panic("cannot Encode: too many values/slots for the given ring degree")
}
- if uint64(len(values)) != slots {
+ if len(values) != slots {
panic("cannot Encode: number of values must be equal to slots")
}
- for i := uint64(0); i < slots; i++ {
+ for i := 0; i < slots; i++ {
encoder.values[i].Set(values[i])
}
@@ -534,7 +846,7 @@ func (encoder *encoderBigComplex) Encode(plaintext *Plaintext, values []*ring.Co
gap := (encoder.ringQ.N >> 1) / slots
- for i, jdx, idx := uint64(0), (encoder.ringQ.N >> 1), uint64(0); i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
+ for i, jdx, idx := 0, (encoder.ringQ.N >> 1), 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
encoder.valuesfloat[idx].Set(encoder.values[i].Real())
encoder.valuesfloat[jdx].Set(encoder.values[i].Imag())
}
@@ -545,26 +857,42 @@ func (encoder *encoderBigComplex) Encode(plaintext *Plaintext, values []*ring.Co
encoder.ringQ.PolyToBigint(plaintext.value, coeffsBigInt)
- for i := uint64(0); i < (encoder.ringQ.N >> 1); i++ {
+ for i := 0; i < (encoder.ringQ.N >> 1); i++ {
encoder.values[i].Real().Set(encoder.zero)
encoder.values[i].Imag().Set(encoder.zero)
}
- for i := uint64(0); i < encoder.ringQ.N; i++ {
+ for i := 0; i < encoder.ringQ.N; i++ {
encoder.valuesfloat[i].Set(encoder.zero)
}
}
-// Decode decodes the Plaintext values to a slice of complex128 values of size at most N/2.
-func (encoder *encoderBigComplex) Decode(plaintext *Plaintext, logSlots uint64) (res []*ring.Complex) {
+// DecodePublic decodes the Plaintext values to a slice of complex128 values of size at most N/2.
+// Adds a Gaussian error to the plaintext of variance sigma and bound floor(sqrt(2*pi)*sigma) before decoding
+func (encoder *encoderBigComplex) DecodePublic(plaintext *Plaintext, logSlots int, sigma float64) (res []*ring.Complex) {
+ return encoder.decodePublic(plaintext, logSlots, sigma)
+}
- slots := uint64(1 << logSlots)
+// Decode decodes the Plaintext values to a slice of complex128 values of size at most N/2.
+func (encoder *encoderBigComplex) Decode(plaintext *Plaintext, logSlots int) (res []*ring.Complex) {
+ return encoder.decodePublic(plaintext, logSlots, 0)
+}
+
+func (encoder *encoderBigComplex) decodePublic(plaintext *Plaintext, logSlots int, sigma float64) (res []*ring.Complex) {
+
+ slots := 1 << logSlots
if logSlots > encoder.params.LogN()-1 {
panic("cannot Decode: too many slots for the given ring degree")
}
encoder.ringQ.InvNTTLvl(plaintext.Level(), plaintext.value, encoder.polypool)
+
+ if sigma != 0 {
+ // B = floor(sigma * sqrt(2*pi))
+ encoder.gaussianSampler.ReadAndAddFromDistLvl(plaintext.Level(), encoder.polypool, encoder.ringQ, sigma, int(2.5066282746310002*sigma+0.5))
+ }
+
encoder.ringQ.PolyToBigint(encoder.polypool, encoder.bigintCoeffs)
Q := encoder.bigintChain[plaintext.Level()]
@@ -580,7 +908,7 @@ func (encoder *encoderBigComplex) Decode(plaintext *Plaintext, logSlots uint64)
var sign int
- for i, idx := uint64(0), uint64(0); i < slots; i, idx = i+1, idx+gap {
+ for i, idx := 0, 0; i < slots; i, idx = i+1, idx+gap {
// Centers the value around the current modulus
encoder.bigintCoeffs[idx].Mod(encoder.bigintCoeffs[idx], Q)
@@ -611,7 +939,7 @@ func (encoder *encoderBigComplex) Decode(plaintext *Plaintext, logSlots uint64)
res[i] = encoder.values[i].Copy()
}
- for i := uint64(0); i < maxSlots; i++ {
+ for i := 0; i < maxSlots; i++ {
encoder.values[i].Real().Set(encoder.zero)
encoder.values[i].Imag().Set(encoder.zero)
}
@@ -620,18 +948,18 @@ func (encoder *encoderBigComplex) Decode(plaintext *Plaintext, logSlots uint64)
}
// InvFFT evaluates the encoding matrix on a slice fo ring.Complex values.
-func (encoder *encoderBigComplex) InvFFT(values []*ring.Complex, N uint64) {
+func (encoder *encoderBigComplex) InvFFT(values []*ring.Complex, N int) {
- var lenh, lenq, gap, idx uint64
+ var lenh, lenq, gap, idx int
u := ring.NewComplex(nil, nil)
v := ring.NewComplex(nil, nil)
for len := N; len >= 1; len >>= 1 {
- for i := uint64(0); i < N; i += len {
+ for i := 0; i < N; i += len {
lenh = len >> 1
lenq = len << 2
gap = encoder.m / lenq
- for j := uint64(0); j < lenh; j++ {
+ for j := 0; j < lenh; j++ {
idx = (lenq - (encoder.rotGroup[j] % lenq)) * gap
u.Add(values[i+j], values[i+j+lenh])
v.Sub(values[i+j], values[i+j+lenh])
@@ -652,21 +980,21 @@ func (encoder *encoderBigComplex) InvFFT(values []*ring.Complex, N uint64) {
}
// FFT evaluates the decoding matrix on a slice fo ring.Complex values.
-func (encoder *encoderBigComplex) FFT(values []*ring.Complex, N uint64) {
+func (encoder *encoderBigComplex) FFT(values []*ring.Complex, N int) {
- var lenh, lenq, gap, idx uint64
+ var lenh, lenq, gap, idx int
u := ring.NewComplex(nil, nil)
v := ring.NewComplex(nil, nil)
sliceBitReverseInPlaceRingComplex(values, N)
- for len := uint64(2); len <= N; len <<= 1 {
- for i := uint64(0); i < N; i += len {
+ for len := 2; len <= N; len <<= 1 {
+ for i := 0; i < N; i += len {
lenh = len >> 1
lenq = len << 2
gap = encoder.m / lenq
- for j := uint64(0); j < lenh; j++ {
+ for j := 0; j < lenh; j++ {
idx = (encoder.rotGroup[j] % lenq) * gap
u.Set(values[i+j])
v.Set(values[i+j+lenh])
diff --git a/ckks/encryptor.go b/ckks/encryptor.go
index da8aef55..289eb561 100644
--- a/ckks/encryptor.go
+++ b/ckks/encryptor.go
@@ -59,11 +59,10 @@ type encryptor struct {
poolQ [3]*ring.Poly
poolP [3]*ring.Poly
- baseconverter *ring.FastBasisExtender
- gaussianSamplerQ *ring.GaussianSampler
- ternarySamplerQ *ring.TernarySampler
- uniformSamplerQ *ring.UniformSampler
- uniformSamplerP *ring.UniformSampler
+ baseconverter *ring.FastBasisExtender
+ gaussianSampler *ring.GaussianSampler
+ ternarySampler *ring.TernarySampler
+ uniformSampler *ring.UniformSampler
}
type pkEncryptor struct {
@@ -81,7 +80,7 @@ type skEncryptor struct {
func NewEncryptorFromPk(params *Parameters, pk *PublicKey) Encryptor {
enc := newEncryptor(params)
- if uint64(pk.Value[0].GetDegree()) != params.N() || uint64(pk.Value[1].GetDegree()) != params.N() {
+ if pk.Value[0].Degree() != params.N() || pk.Value[1].Degree() != params.N() {
panic("cannot newEncryptor: pk ring degree does not match params ring degree")
}
@@ -93,7 +92,7 @@ func NewEncryptorFromPk(params *Parameters, pk *PublicKey) Encryptor {
func NewEncryptorFromSk(params *Parameters, sk *SecretKey) Encryptor {
enc := newEncryptor(params)
- if uint64(sk.Value.GetDegree()) != params.N() {
+ if sk.Value.Degree() != params.N() {
panic("cannot newEncryptor: sk ring degree does not match params ring degree")
}
@@ -115,7 +114,6 @@ func newEncryptor(params *Parameters) encryptor {
var baseconverter *ring.FastBasisExtender
var poolP [3]*ring.Poly
- var uniformSamplerP *ring.UniformSampler
if params.PiCount() != 0 {
if p, err = ring.NewRing(params.N(), params.pi); err != nil {
@@ -125,21 +123,18 @@ func newEncryptor(params *Parameters) encryptor {
baseconverter = ring.NewFastBasisExtender(q, p)
poolP = [3]*ring.Poly{p.NewPoly(), p.NewPoly(), p.NewPoly()}
-
- uniformSamplerP = ring.NewUniformSampler(prng, p)
}
return encryptor{
- params: params.Copy(),
- ringQ: q,
- ringP: p,
- poolQ: [3]*ring.Poly{q.NewPoly(), q.NewPoly(), q.NewPoly()},
- poolP: poolP,
- baseconverter: baseconverter,
- gaussianSamplerQ: ring.NewGaussianSampler(prng, q, params.sigma, uint64(6*params.sigma)),
- ternarySamplerQ: ring.NewTernarySampler(prng, q, 0.5, false),
- uniformSamplerQ: ring.NewUniformSampler(prng, q),
- uniformSamplerP: uniformSamplerP,
+ params: params.Copy(),
+ ringQ: q,
+ ringP: p,
+ poolQ: [3]*ring.Poly{q.NewPoly(), q.NewPoly(), q.NewPoly()},
+ poolP: poolP,
+ baseconverter: baseconverter,
+ gaussianSampler: ring.NewGaussianSampler(prng, q, params.Sigma(), int(6*params.Sigma())),
+ ternarySampler: ring.NewTernarySampler(prng, q, 0.5, false),
+ uniformSampler: ring.NewUniformSampler(prng, q),
}
}
@@ -195,7 +190,7 @@ func (encryptor *pkEncryptor) EncryptFromCRPNew(plaintext *Plaintext, crp *ring.
// encrypt with sk: ciphertext = [-a*sk + m + e, a]
func (encryptor *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Ciphertext, fast bool) {
- lvl := utils.MinUint64(plaintext.Level(), ciphertext.Level())
+ lvl := utils.MinInt(plaintext.Level(), ciphertext.Level())
poolQ0 := encryptor.poolQ[0]
poolQ1 := encryptor.poolQ[1]
@@ -210,7 +205,7 @@ func (encryptor *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte
if fast {
- encryptor.ternarySamplerQ.ReadLvl(lvl, poolQ2)
+ encryptor.ternarySampler.ReadLvl(lvl, poolQ2)
ringQ.NTTLvl(lvl, poolQ2, poolQ2)
ringQ.MFormLvl(lvl, poolQ2, poolQ2)
@@ -220,14 +215,14 @@ func (encryptor *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte
ringQ.MulCoeffsMontgomeryLvl(lvl, poolQ2, encryptor.pk.Value[1], ciphertext.value[1])
// ct1 = u*pk1 + e1
- encryptor.gaussianSamplerQ.ReadLvl(lvl, poolQ0)
+ encryptor.gaussianSampler.ReadLvl(lvl, poolQ0)
ringQ.NTTLvl(lvl, poolQ0, poolQ0)
ringQ.AddLvl(lvl, ciphertext.value[1], poolQ0, ciphertext.value[1])
if !plaintext.isNTT {
// ct0 = u*pk0 + e0
- encryptor.gaussianSamplerQ.ReadLvl(lvl, poolQ0)
+ encryptor.gaussianSampler.ReadLvl(lvl, poolQ0)
// ct0 = (u*pk0 + e0)/P + m
ringQ.AddLvl(lvl, poolQ0, plaintext.value, poolQ0)
ringQ.NTTLvl(lvl, poolQ0, poolQ0)
@@ -235,7 +230,7 @@ func (encryptor *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte
} else {
// ct0 = u*pk0 + e0
- encryptor.gaussianSamplerQ.ReadLvl(lvl, poolQ0)
+ encryptor.gaussianSampler.ReadLvl(lvl, poolQ0)
ringQ.NTTLvl(lvl, poolQ0, poolQ0)
ringQ.AddLvl(lvl, ciphertext.value[0], poolQ0, ciphertext.value[0])
ringQ.AddLvl(lvl, ciphertext.value[0], plaintext.value, ciphertext.value[0])
@@ -245,7 +240,7 @@ func (encryptor *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte
ringP := encryptor.ringP
- encryptor.ternarySamplerQ.ReadLvl(lvl, poolQ2)
+ encryptor.ternarySampler.ReadLvl(lvl, poolQ2)
extendBasisSmallNormAndCenter(ringQ, ringP, poolQ2, poolP2)
@@ -275,13 +270,13 @@ func (encryptor *pkEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte
ringP.InvNTT(poolP1, poolP1)
// ct0 = u*pk0 + e0
- encryptor.gaussianSamplerQ.ReadLvl(lvl, poolQ2)
+ encryptor.gaussianSampler.ReadLvl(lvl, poolQ2)
extendBasisSmallNormAndCenter(ringQ, ringP, poolQ2, poolP2)
ringQ.AddLvl(lvl, poolQ0, poolQ2, poolQ0)
ringP.Add(poolP0, poolP2, poolP0)
// ct1 = u*pk1 + e1
- encryptor.gaussianSamplerQ.ReadLvl(lvl, poolQ2)
+ encryptor.gaussianSampler.ReadLvl(lvl, poolQ2)
extendBasisSmallNormAndCenter(ringQ, ringP, poolQ2, poolP2)
ringQ.AddLvl(lvl, poolQ1, poolQ2, poolQ1)
ringP.Add(poolP1, poolP2, poolP1)
@@ -344,7 +339,7 @@ func (encryptor *skEncryptor) EncryptFromCRP(plaintext *Plaintext, ciphertext *C
}
func (encryptor *skEncryptor) encryptSample(plaintext *Plaintext, ciphertext *Ciphertext) {
- encryptor.uniformSamplerQ.Readlvl(utils.MinUint64(plaintext.Level(), ciphertext.Level()), ciphertext.value[1])
+ encryptor.uniformSampler.Readlvl(utils.MinInt(plaintext.Level(), ciphertext.Level()), ciphertext.value[1])
encryptor.encrypt(plaintext, ciphertext, ciphertext.value[1])
}
@@ -352,7 +347,7 @@ func (encryptor *skEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte
ringQ := encryptor.ringQ
- lvl := utils.MinUint64(plaintext.Level(), ciphertext.Level())
+ lvl := utils.MinInt(plaintext.Level(), ciphertext.Level())
poolQ0 := encryptor.poolQ[0]
@@ -360,12 +355,12 @@ func (encryptor *skEncryptor) encrypt(plaintext *Plaintext, ciphertext *Cipherte
ringQ.NegLvl(lvl, ciphertext.value[0], ciphertext.value[0])
if plaintext.isNTT {
- encryptor.gaussianSamplerQ.ReadLvl(lvl, poolQ0)
+ encryptor.gaussianSampler.ReadLvl(lvl, poolQ0)
ringQ.NTTLvl(lvl, poolQ0, poolQ0)
ringQ.AddLvl(lvl, ciphertext.value[0], poolQ0, ciphertext.value[0])
ringQ.AddLvl(lvl, ciphertext.value[0], plaintext.value, ciphertext.value[0])
} else {
- encryptor.gaussianSamplerQ.ReadLvl(lvl, poolQ0)
+ encryptor.gaussianSampler.ReadLvl(lvl, poolQ0)
ringQ.AddLvl(lvl, poolQ0, plaintext.value, poolQ0)
ringQ.NTTLvl(lvl, poolQ0, poolQ0)
ringQ.AddLvl(lvl, ciphertext.value[0], poolQ0, ciphertext.value[0])
@@ -382,7 +377,7 @@ func extendBasisSmallNormAndCenter(ringQ, ringP *ring.Ring, polQ, polP *ring.Pol
Q = ringQ.Modulus[0]
QHalf = Q >> 1
- for j := uint64(0); j < ringQ.N; j++ {
+ for j := 0; j < ringQ.N; j++ {
coeff = polQ.Coeffs[0][j]
diff --git a/ckks/evaluator.go b/ckks/evaluator.go
index 02b51018..c85a3bba 100644
--- a/ckks/evaluator.go
+++ b/ckks/evaluator.go
@@ -13,58 +13,124 @@ import (
// Evaluator is an interface implementing the methodes to conduct homomorphic operations between ciphertext and/or plaintexts.
type Evaluator interface {
+ // ========================
+ // === Basic Arithmetic ===
+ // ========================
+
+ // Addition
Add(op0, op1 Operand, ctOut *Ciphertext)
AddNoMod(op0, op1 Operand, ctOut *Ciphertext)
AddNew(op0, op1 Operand) (ctOut *Ciphertext)
AddNoModNew(op0, op1 Operand) (ctOut *Ciphertext)
+
+ // Subtraction
Sub(op0, op1 Operand, ctOut *Ciphertext)
SubNoMod(op0, op1 Operand, ctOut *Ciphertext)
SubNew(op0, op1 Operand) (ctOut *Ciphertext)
SubNoModNew(op0, op1 Operand) (ctOut *Ciphertext)
- Neg(ct0 *Ciphertext, ctOut *Ciphertext)
- NegNew(ct0 *Ciphertext) (ctOut *Ciphertext)
- AddConstNew(ct0 *Ciphertext, constant interface{}) (ctOut *Ciphertext)
- AddConst(ct0 *Ciphertext, constant interface{}, ctOut *Ciphertext)
- MultByConstAndAdd(ct0 *Ciphertext, constant interface{}, ctOut *Ciphertext)
- MultByConstNew(ct0 *Ciphertext, constant interface{}) (ctOut *Ciphertext)
- MultByConst(ct0 *Ciphertext, constant interface{}, ctOut *Ciphertext)
- MultByGaussianInteger(ct0 *Ciphertext, cReal, cImag int64, ctOut *Ciphertext)
- MultByGaussianIntegerAndAdd(ct0 *Ciphertext, cReal, cImag int64, ctOut *Ciphertext)
- MultByiNew(ct0 *Ciphertext) (ctOut *Ciphertext)
- MultByi(ct0 *Ciphertext, ct1 *Ciphertext)
- DivByiNew(ct0 *Ciphertext) (ctOut *Ciphertext)
- DivByi(ct0 *Ciphertext, ct1 *Ciphertext)
- ScaleUpNew(ct0 *Ciphertext, scale float64) (ctOut *Ciphertext)
- ScaleUp(ct0 *Ciphertext, scale float64, ctOut *Ciphertext)
- SetScale(ct *Ciphertext, scale float64)
- MulByPow2New(ct0 *Ciphertext, pow2 uint64) (ctOut *Ciphertext)
- MulByPow2(ct0 *Element, pow2 uint64, ctOut *Element)
- ReduceNew(ct0 *Ciphertext) (ctOut *Ciphertext)
- Reduce(ct0 *Ciphertext, ctOut *Ciphertext) error
- DropLevelNew(ct0 *Ciphertext, levels uint64) (ctOut *Ciphertext)
- DropLevel(ct0 *Ciphertext, levels uint64)
- Rescale(ct0 *Ciphertext, threshold float64, c1 *Ciphertext) (err error)
- RescaleNew(ct0 *Ciphertext, threshold float64) (ctOut *Ciphertext, err error)
- RescaleMany(ct0 *Ciphertext, nbRescales uint64, c1 *Ciphertext) (err error)
+
+ // Negation
+ Neg(ctIn *Ciphertext, ctOut *Ciphertext)
+ NegNew(ctIn *Ciphertext) (ctOut *Ciphertext)
+
+ // Constant Addition
+ AddConstNew(ctIn *Ciphertext, constant interface{}) (ctOut *Ciphertext)
+ AddConst(ctIn *Ciphertext, constant interface{}, ctOut *Ciphertext)
+
+ // Constant Multiplication
+ MultByConstNew(ctIn *Ciphertext, constant interface{}) (ctOut *Ciphertext)
+ MultByConst(ctIn *Ciphertext, constant interface{}, ctOut *Ciphertext)
+ MultByGaussianInteger(ctIn *Ciphertext, cReal, cImag int64, ctOut *Ciphertext)
+
+ // Constant Multiplication with Addition
+ MultByConstAndAdd(ctIn *Ciphertext, constant interface{}, ctOut *Ciphertext)
+ MultByGaussianIntegerAndAdd(ctIn *Ciphertext, cReal, cImag int64, ctOut *Ciphertext)
+
+ // Multiplication by the imaginary unit
+ MultByiNew(ctIn *Ciphertext) (ctOut *Ciphertext)
+ MultByi(ctIn *Ciphertext, ctOut *Ciphertext)
+ DivByiNew(ctIn *Ciphertext) (ctOut *Ciphertext)
+ DivByi(ctIn *Ciphertext, ctOut *Ciphertext)
+
+ // Conjugation
+ ConjugateNew(ctIn *Ciphertext) (ctOut *Ciphertext)
+ Conjugate(ctIn *Ciphertext, ctOut *Ciphertext)
+
+ // Multiplication
Mul(op0, op1 Operand, ctOut *Ciphertext)
MulNew(op0, op1 Operand) (ctOut *Ciphertext)
- MulRelinNew(op0, op1 Operand) (ctOut *Ciphertext)
MulRelin(op0, op1 Operand, ctOut *Ciphertext)
- RelinearizeNew(ct0 *Ciphertext) (ctOut *Ciphertext)
- Relinearize(ct0 *Ciphertext, ctOut *Ciphertext)
- SwitchKeysNew(ct0 *Ciphertext, switchingKey *SwitchingKey) (ctOut *Ciphertext)
- SwitchKeys(ct0 *Ciphertext, switchingKey *SwitchingKey, ctOut *Ciphertext)
- RotateNew(ct0 *Ciphertext, k int) (ctOut *Ciphertext)
- Rotate(ct0 *Ciphertext, k int, ctOut *Ciphertext)
- RotateHoisted(ctIn *Ciphertext, rotations []int) (cOut map[int]*Ciphertext)
- ConjugateNew(ct0 *Ciphertext) (ctOut *Ciphertext)
- Conjugate(ct0 *Ciphertext, ctOut *Ciphertext)
- PowerOf2(el0 *Ciphertext, logPow2 uint64, elOut *Ciphertext)
- PowerNew(op *Ciphertext, degree uint64) (opOut *Ciphertext)
- Power(ct0 *Ciphertext, degree uint64, res *Ciphertext)
- InverseNew(ct0 *Ciphertext, steps uint64) (res *Ciphertext)
- EvaluatePoly(ct *Ciphertext, coeffs *Poly) (res *Ciphertext, err error)
- EvaluateCheby(ct *Ciphertext, cheby *ChebyshevInterpolation) (res *Ciphertext, err error)
+ MulRelinNew(op0, op1 Operand) (ctOut *Ciphertext)
+
+ // Slot Rotations
+ RotateNew(ctIn *Ciphertext, k int) (ctOut *Ciphertext)
+ Rotate(ctIn *Ciphertext, k int, ctOut *Ciphertext)
+ RotateHoisted(ctIn *Ciphertext, rotations []int) (ctOut map[int]*Ciphertext)
+
+ // ===========================
+ // === Advanced Arithmetic ===
+ // ===========================
+
+ // Multiplication by 2^{s}
+ MulByPow2New(ctIn *Ciphertext, pow2 int) (ctOut *Ciphertext)
+ MulByPow2(ctIn *Element, pow2 int, ctOut *Element)
+
+ // Exponentiation
+ PowerOf2(ctIn *Ciphertext, logPow2 int, ctOut *Ciphertext)
+ Power(ctIn *Ciphertext, degree int, ctOut *Ciphertext)
+ PowerNew(ctIn *Ciphertext, degree int) (ctOut *Ciphertext)
+
+ // Polynomial evaluation
+ EvaluatePoly(ctIn *Ciphertext, coeffs *Poly, targetScale float64) (ctOut *Ciphertext, err error)
+ EvaluateCheby(ctIn *Ciphertext, cheby *ChebyshevInterpolation, targetScale float64) (ctOut *Ciphertext, err error)
+
+ // Inversion
+ InverseNew(ctIn *Ciphertext, steps int) (ctOut *Ciphertext)
+
+ // Linear Transformations
+ LinearTransform(ctIn *Ciphertext, linearTransform interface{}) (ctOut []*Ciphertext)
+ MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatrix, c2QiQDecomp, c2QiPDecomp []*ring.Poly, ctOut *Ciphertext)
+ MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiagMatrix, c2QiQDecomp, c2QiPDecomp []*ring.Poly, ctOut *Ciphertext)
+
+ // Inner sum
+ InnerSumLog(ctIn *Ciphertext, batch, n int, ctOut *Ciphertext)
+ InnerSum(ctIn *Ciphertext, batch, n int, ctOut *Ciphertext)
+
+ // Replicatation (inverse of Inner sum)
+ ReplicateLog(ctIn *Ciphertext, batch, n int, ctOut *Ciphertext)
+ Replicate(ctIn *Ciphertext, batch, n int, ctOut *Ciphertext)
+
+ // =============================
+ // === Ciphertext Management ===
+ // =============================
+
+ // Key-Switching
+ SwitchKeysNew(ctIn *Ciphertext, switchingKey *SwitchingKey) (ctOut *Ciphertext)
+ SwitchKeys(ctIn *Ciphertext, switchingKey *SwitchingKey, ctOut *Ciphertext)
+
+ // Degree Management
+ RelinearizeNew(ctIn *Ciphertext) (ctOut *Ciphertext)
+ Relinearize(ctIn *Ciphertext, ctOut *Ciphertext)
+
+ // Scale Management
+ ScaleUpNew(ctIn *Ciphertext, scale float64) (ctOut *Ciphertext)
+ ScaleUp(ctIn *Ciphertext, scale float64, ctOut *Ciphertext)
+ SetScale(ctIn *Ciphertext, scale float64)
+ Rescale(ctIn *Ciphertext, minScale float64, ctOut *Ciphertext) (err error)
+
+ // Level Management
+ DropLevelNew(ctIn *Ciphertext, levels int) (ctOut *Ciphertext)
+ DropLevel(ctIn *Ciphertext, levels int)
+
+ // Modular Overflow Management
+ ReduceNew(ctIn *Ciphertext) (ctOut *Ciphertext)
+ Reduce(ctIn *Ciphertext, ctOut *Ciphertext) error
+
+ // ==============
+ // === Others ===
+ // ==============
+
+ DecompInternal(level int, c2NTT *ring.Poly, c2QiQDecomp, c2QiPDecomp []*ring.Poly)
ShallowCopy() Evaluator
WithKey(EvaluationKey) Evaluator
}
@@ -93,10 +159,13 @@ type evaluatorBase struct {
}
type evaluatorBuffers struct {
- poolQ [4]*ring.Poly // Memory pool in order : Decomp(c2), for NTT^-1(c2), res(c0', c1')
- poolP [3]*ring.Poly // Memory pool in order : Decomp(c2), res(c0', c1')
- poolQMul [3]*ring.Poly // Memory pool in order : for MForm(c0), MForm(c1), c2
- ctxpool *Ciphertext // Memory pool for ciphertext that need to be scaled up (to be removed eventually)
+ poolQ [6]*ring.Poly // Memory pool in order : Decomp(c2), for NTT^-1(c2), res(c0', c1')
+ poolP [6]*ring.Poly // Memory pool in order : Decomp(c2), res(c0', c1')
+ poolQMul [3]*ring.Poly // Memory pool in order : for MForm(c0), MForm(c1), c2
+ poolInvNTT *ring.Poly
+ c2QiQDecomp []*ring.Poly // Memory pool for the basis extension in hoisting
+ c2QiPDecomp []*ring.Poly // Memory pool for the basis extension in hoisting
+ ctxpool *Ciphertext // Memory pool for ciphertext that need to be scaled up (to be removed eventually)
}
func newEvaluatorBase(params *Parameters) *evaluatorBase {
@@ -120,12 +189,21 @@ func newEvaluatorBase(params *Parameters) *evaluatorBase {
func newEvaluatorBuffers(evalBase *evaluatorBase) *evaluatorBuffers {
buff := new(evaluatorBuffers)
ringQ, ringP := evalBase.ringQ, evalBase.ringP
- buff.poolQ = [4]*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly()}
+ buff.poolQ = [6]*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly()}
buff.poolQMul = [3]*ring.Poly{ringQ.NewPoly(), ringQ.NewPoly(), ringQ.NewPoly()}
if evalBase.params.PiCount() > 0 {
- buff.poolP = [3]*ring.Poly{ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly()}
+ buff.poolP = [6]*ring.Poly{ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly(), ringP.NewPoly()}
+
+ buff.c2QiQDecomp = make([]*ring.Poly, evalBase.params.Beta())
+ buff.c2QiPDecomp = make([]*ring.Poly, evalBase.params.Beta())
+
+ for i := 0; i < evalBase.params.Beta(); i++ {
+ buff.c2QiQDecomp[i] = ringQ.NewPoly()
+ buff.c2QiPDecomp[i] = ringP.NewPoly()
+ }
}
- buff.ctxpool = NewCiphertext(evalBase.params, 1, evalBase.params.MaxLevel(), evalBase.params.scale)
+ buff.poolInvNTT = ringQ.NewPoly()
+ buff.ctxpool = NewCiphertext(evalBase.params, 2, evalBase.params.MaxLevel(), evalBase.params.scale)
return buff
}
@@ -156,7 +234,7 @@ func (eval *evaluator) permuteNTTIndexesForKey(rtks *RotationKeySet) *map[uint64
}
permuteNTTIndex := make(map[uint64][]uint64, len(rtks.Keys))
for galEl := range rtks.Keys {
- permuteNTTIndex[galEl] = ring.PermuteNTTIndex(galEl, eval.ringQ.N)
+ permuteNTTIndex[galEl] = ring.PermuteNTTIndex(galEl, uint64(eval.ringQ.N))
}
return &permuteNTTIndex
}
@@ -194,7 +272,7 @@ func (eval *evaluator) WithKey(evaluationKey EvaluationKey) Evaluator {
}
}
-func (eval *evaluator) getElemAndCheckBinary(op0, op1, opOut Operand, opOutMinDegree uint64) (el0, el1, elOut *Element) {
+func (eval *evaluator) getElemAndCheckBinary(op0, op1, opOut Operand, opOutMinDegree int) (el0, el1, elOut *Element) {
if op0 == nil || op1 == nil || opOut == nil {
panic("operands cannot be nil")
}
@@ -221,22 +299,22 @@ func (eval *evaluator) getElemAndCheckBinary(op0, op1, opOut Operand, opOutMinDe
func (eval *evaluator) newCiphertextBinary(op0, op1 Operand) (ctOut *Ciphertext) {
- maxDegree := utils.MaxUint64(op0.Degree(), op1.Degree())
+ maxDegree := utils.MaxInt(op0.Degree(), op1.Degree())
maxScale := utils.MaxFloat64(op0.Scale(), op1.Scale())
- minLevel := utils.MinUint64(op0.Level(), op1.Level())
+ minLevel := utils.MinInt(op0.Level(), op1.Level())
return NewCiphertext(eval.params, maxDegree, minLevel, maxScale)
}
// Add adds op0 to op1 and returns the result in ctOut.
func (eval *evaluator) Add(op0, op1 Operand, ctOut *Ciphertext) {
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()))
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()))
eval.evaluateInPlace(el0, el1, elOut, eval.ringQ.AddLvl)
}
// AddNoMod adds op0 to op1 and returns the result in ctOut, without modular reduction.
func (eval *evaluator) AddNoMod(op0, op1 Operand, ctOut *Ciphertext) {
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()))
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()))
eval.evaluateInPlace(el0, el1, elOut, eval.ringQ.AddNoModLvl)
}
@@ -257,11 +335,11 @@ func (eval *evaluator) AddNoModNew(op0, op1 Operand) (ctOut *Ciphertext) {
// Sub subtracts op1 from op0 and returns the result in ctOut.
func (eval *evaluator) Sub(op0, op1 Operand, ctOut *Ciphertext) {
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()))
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()))
eval.evaluateInPlace(el0, el1, elOut, eval.ringQ.SubLvl)
- level := utils.MinUint64(utils.MinUint64(el0.Level(), el1.Level()), elOut.Level())
+ level := utils.MinInt(utils.MinInt(el0.Level(), el1.Level()), elOut.Level())
if el0.Degree() < el1.Degree() {
for i := el0.Degree() + 1; i < el1.Degree()+1; i++ {
@@ -274,11 +352,11 @@ func (eval *evaluator) Sub(op0, op1 Operand, ctOut *Ciphertext) {
// SubNoMod subtracts op1 from op0 and returns the result in ctOut, without modular reduction.
func (eval *evaluator) SubNoMod(op0, op1 Operand, ctOut *Ciphertext) {
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()))
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()))
eval.evaluateInPlace(el0, el1, elOut, eval.ringQ.SubNoModLvl)
- level := utils.MinUint64(utils.MinUint64(el0.Level(), el1.Level()), elOut.Level())
+ level := utils.MinInt(utils.MinInt(el0.Level(), el1.Level()), elOut.Level())
if el0.Degree() < el1.Degree() {
for i := el0.Degree() + 1; i < el1.Degree()+1; i++ {
@@ -302,37 +380,36 @@ func (eval *evaluator) SubNoModNew(op0, op1 Operand) (ctOut *Ciphertext) {
return
}
-func (eval *evaluator) evaluateInPlace(c0, c1, ctOut *Element, evaluate func(uint64, *ring.Poly, *ring.Poly, *ring.Poly)) {
+func (eval *evaluator) evaluateInPlace(c0, c1, ctOut *Element, evaluate func(int, *ring.Poly, *ring.Poly, *ring.Poly)) {
var tmp0, tmp1 *Element
- level := utils.MinUint64(utils.MinUint64(c0.Level(), c1.Level()), ctOut.Level())
+ level := utils.MinInt(utils.MinInt(c0.Level(), c1.Level()), ctOut.Level())
- maxDegree := utils.MaxUint64(c0.Degree(), c1.Degree())
- minDegree := utils.MinUint64(c0.Degree(), c1.Degree())
+ maxDegree := utils.MaxInt(c0.Degree(), c1.Degree())
+ minDegree := utils.MinInt(c0.Degree(), c1.Degree())
// Else resizes the receiver element
ctOut.Resize(eval.params, maxDegree)
- eval.DropLevel(ctOut.Ciphertext(), ctOut.Level()-utils.MinUint64(c0.Level(), c1.Level()))
+
+ if ctOut.Level() > level {
+ eval.DropLevel(ctOut.Ciphertext(), ctOut.Level()-utils.MinInt(c0.Level(), c1.Level()))
+ }
// Checks whether or not the receiver element is the same as one of the input elements
// and acts accordingly to avoid unnecessary element creation or element overwriting,
// and scales properly the element before the evaluation.
if ctOut == c0 {
- if c0.Scale() > c1.Scale() {
+ if c0.Scale() > c1.Scale() && math.Floor(c0.Scale()/c1.Scale()) > 1 {
tmp1 = eval.ctxpool.El()
- if uint64(c0.Scale()/c1.Scale()) != 0 {
- eval.MultByConst(c1.Ciphertext(), uint64(c0.Scale()/c1.Scale()), tmp1.Ciphertext())
- }
+ eval.MultByConst(c1.Ciphertext(), math.Floor(c0.Scale()/c1.Scale()), tmp1.Ciphertext())
- } else if c1.Scale() > c0.Scale() {
+ } else if c1.Scale() > c0.Scale() && math.Floor(c1.Scale()/c0.Scale()) > 1 {
- if uint64(c1.Scale()/c0.Scale()) != 0 {
- eval.MultByConst(c0.Ciphertext(), uint64(c1.Scale()/c0.Scale()), c0.Ciphertext())
- }
+ eval.MultByConst(c0.Ciphertext(), math.Floor(c1.Scale()/c0.Scale()), c0.Ciphertext())
c0.SetScale(c1.Scale())
@@ -347,18 +424,15 @@ func (eval *evaluator) evaluateInPlace(c0, c1, ctOut *Element, evaluate func(uin
} else if ctOut == c1 {
- if c1.Scale() > c0.Scale() {
+ if c1.Scale() > c0.Scale() && math.Floor(c1.Scale()/c0.Scale()) > 1 {
tmp0 = eval.ctxpool.El()
- if uint64(c1.Scale()/c0.Scale()) != 0 {
- eval.MultByConst(c0.Ciphertext(), uint64(c1.Scale()/c0.Scale()), tmp0.Ciphertext())
- }
- } else if c0.Scale() > c1.Scale() {
+ eval.MultByConst(c0.Ciphertext(), math.Floor(c1.Scale()/c0.Scale()), tmp0.Ciphertext())
- if uint64(c0.Scale()/c1.Scale()) != 0 {
- eval.MultByConst(c1.Ciphertext(), uint64(c0.Scale()/c1.Scale()), ctOut.Ciphertext())
- }
+ } else if c0.Scale() > c1.Scale() && math.Floor(c0.Scale()/c1.Scale()) > 1 {
+
+ eval.MultByConst(c1.Ciphertext(), math.Floor(c0.Scale()/c1.Scale()), ctOut.Ciphertext())
ctOut.SetScale(c0.Scale())
@@ -373,23 +447,19 @@ func (eval *evaluator) evaluateInPlace(c0, c1, ctOut *Element, evaluate func(uin
} else {
- if c1.Scale() > c0.Scale() {
+ if c1.Scale() > c0.Scale() && math.Floor(c1.Scale()/c0.Scale()) > 1 {
tmp0 = eval.ctxpool.El()
- if uint64(c1.Scale()/c0.Scale()) != 0 {
- eval.MultByConst(c0.Ciphertext(), uint64(c1.Scale()/c0.Scale()), tmp0.Ciphertext())
- }
+ eval.MultByConst(c0.Ciphertext(), math.Floor(c1.Scale()/c0.Scale()), tmp0.Ciphertext())
tmp1 = c1
- } else if c0.Scale() > c1.Scale() {
+ } else if c0.Scale() > c1.Scale() && math.Floor(c0.Scale()/c1.Scale()) > 1 {
tmp1 = eval.ctxpool.El()
- if uint64(c0.Scale()/c1.Scale()) != 0 {
- eval.MultByConst(c1.Ciphertext(), uint64(c0.Scale()/c1.Scale()), tmp1.Ciphertext())
- }
+ eval.MultByConst(c1.Ciphertext(), math.Floor(c0.Scale()/c1.Scale()), tmp1.Ciphertext())
tmp0 = c0
@@ -399,7 +469,7 @@ func (eval *evaluator) evaluateInPlace(c0, c1, ctOut *Element, evaluate func(uin
}
}
- for i := uint64(0); i < minDegree+1; i++ {
+ for i := 0; i < minDegree+1; i++ {
evaluate(level, tmp0.Value()[i], tmp1.Value()[i], ctOut.Value()[i])
}
@@ -422,7 +492,7 @@ func (eval *evaluator) evaluateInPlace(c0, c1, ctOut *Element, evaluate func(uin
// Neg negates the value of ct0 and returns the result in ctOut.
func (eval *evaluator) Neg(ct0 *Ciphertext, ctOut *Ciphertext) {
- level := utils.MinUint64(ct0.Level(), ctOut.Level())
+ level := utils.MinInt(ct0.Level(), ctOut.Level())
if ct0.Degree() != ctOut.Degree() {
panic("cannot Negate: invalid receiver Ciphertext does not match input Ciphertext degree")
@@ -431,6 +501,8 @@ func (eval *evaluator) Neg(ct0 *Ciphertext, ctOut *Ciphertext) {
for i := range ct0.value {
eval.ringQ.NegLvl(level, ct0.value[i], ctOut.Value()[i])
}
+
+ ctOut.SetScale(ct0.Scale())
}
// NegNew negates ct0 and returns the result in a newly created element.
@@ -447,7 +519,7 @@ func (eval *evaluator) AddConstNew(ct0 *Ciphertext, constant interface{}) (ctOut
return ctOut
}
-func (eval *evaluator) getConstAndScale(level uint64, constant interface{}) (cReal, cImag, scale float64) {
+func (eval *evaluator) getConstAndScale(level int, constant interface{}) (cReal, cImag, scale float64) {
// Converts to float64 and determines if a scaling is required (which is the case if either real or imag have a rational part)
scale = 1
@@ -506,19 +578,21 @@ func (eval *evaluator) getConstAndScale(level uint64, constant interface{}) (cRe
// AddConst adds the input constant (which can be a uint64, int64, float64 or complex128) to ct0 and returns the result in ctOut.
func (eval *evaluator) AddConst(ct0 *Ciphertext, constant interface{}, ctOut *Ciphertext) {
- var level = utils.MinUint64(ct0.Level(), ctOut.Level())
+ var level = utils.MinInt(ct0.Level(), ctOut.Level())
var scaledConst, scaledConstReal, scaledConstImag uint64
cReal, cImag, _ := eval.getConstAndScale(level, constant)
ringQ := eval.ringQ
+ ctOut.SetScale(ct0.Scale())
+
// Component wise addition of the following vector to the ciphertext:
// [a + b*psi_qi^2, ....., a + b*psi_qi^2, a - b*psi_qi^2, ...., a - b*psi_qi^2] mod Qi
// [{ N/2 }{ N/2 }]
// Which is equivalent outside of the NTT domain to adding a to the first coefficient of ct0 and b to the N/2-th coefficient of ct0.
var qi uint64
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
scaledConstReal = 0
scaledConstImag = 0
scaledConst = 0
@@ -531,14 +605,14 @@ func (eval *evaluator) AddConst(ct0 *Ciphertext, constant interface{}, ctOut *Ci
}
if cImag != 0 {
- scaledConstImag = ring.MRed(scaleUpExact(cImag, ctOut.Scale(), qi), ringQ.GetNttPsi()[i][1], qi, ringQ.GetMredParams()[i])
+ scaledConstImag = ring.MRed(scaleUpExact(cImag, ctOut.Scale(), qi), ringQ.NttPsi[i][1], qi, ringQ.MredParams[i])
scaledConst = ring.CRed(scaledConst+scaledConstImag, qi)
}
p1tmp := ctOut.Value()[0].Coeffs[i]
p0tmp := ct0.value[0].Coeffs[i]
- for j := uint64(0); j < ringQ.N>>1; j = j + 8 {
+ for j := 0; j < ringQ.N>>1; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
@@ -581,7 +655,7 @@ func (eval *evaluator) AddConst(ct0 *Ciphertext, constant interface{}, ctOut *Ci
// The scale of the receiver element will be set to the scale that the input element would have after the multiplication by the constant.
func (eval *evaluator) MultByConstAndAdd(ct0 *Ciphertext, constant interface{}, ctOut *Ciphertext) {
- var level = utils.MinUint64(ct0.Level(), ctOut.Level())
+ var level = utils.MinInt(ct0.Level(), ctOut.Level())
// Forces a drop of ctOut level to ct0 level
if ctOut.Level() > level {
@@ -602,9 +676,9 @@ func (eval *evaluator) MultByConstAndAdd(ct0 *Ciphertext, constant interface{},
// then brings ctOut scale to ct0's scale.
if ctOut.Scale() < ct0.Scale()*scale {
- if uint64((scale*ct0.Scale())/ctOut.Scale()) != 0 {
+ if scale := math.Floor((scale * ct0.Scale()) / ctOut.Scale()); scale > 1 {
- eval.MultByConst(ctOut, uint64((scale*ct0.Scale())/ctOut.Scale()), ctOut)
+ eval.MultByConst(ctOut, scale, ctOut)
}
@@ -626,8 +700,8 @@ func (eval *evaluator) MultByConstAndAdd(ct0 *Ciphertext, constant interface{},
} else if ct0.Scale() > ctOut.Scale() {
- if uint64(ct0.Scale()/ctOut.Scale()) != 0 {
- eval.MultByConst(ctOut, ct0.Scale()/ctOut.Scale(), ctOut)
+ if scale := math.Floor(ct0.Scale() / ctOut.Scale()); scale > 1 {
+ eval.MultByConst(ctOut, scale, ctOut)
}
ctOut.SetScale(ct0.Scale())
@@ -638,11 +712,11 @@ func (eval *evaluator) MultByConstAndAdd(ct0 *Ciphertext, constant interface{},
// [a + b*psi_qi^2, ....., a + b*psi_qi^2, a - b*psi_qi^2, ...., a - b*psi_qi^2] mod Qi
// [{ N/2 }{ N/2 }]
// Which is equivalent outside of the NTT domain to adding a to the first coefficient of ct0 and b to the N/2-th coefficient of ct0.
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
- mredParams := ringQ.GetMredParams()[i]
- bredParams := ringQ.GetBredParams()[i]
+ mredParams := ringQ.MredParams[i]
+ bredParams := ringQ.BredParams[i]
scaledConstReal = 0
scaledConstImag = 0
@@ -655,7 +729,7 @@ func (eval *evaluator) MultByConstAndAdd(ct0 *Ciphertext, constant interface{},
if cImag != 0 {
scaledConstImag = scaleUpExact(cImag, scale, qi)
- scaledConstImag = ring.MRed(scaledConstImag, ringQ.GetNttPsi()[i][1], qi, mredParams)
+ scaledConstImag = ring.MRed(scaledConstImag, ringQ.NttPsi[i][1], qi, mredParams)
scaledConst = ring.CRed(scaledConst+scaledConstImag, qi)
}
@@ -665,7 +739,7 @@ func (eval *evaluator) MultByConstAndAdd(ct0 *Ciphertext, constant interface{},
p0tmp := ct0.Value()[u].Coeffs[i]
p1tmp := ctOut.Value()[u].Coeffs[i]
- for j := uint64(0); j < ringQ.N>>1; j = j + 8 {
+ for j := 0; j < ringQ.N>>1; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
@@ -721,7 +795,7 @@ func (eval *evaluator) MultByConstNew(ct0 *Ciphertext, constant interface{}) (ct
// needs to be scaled (its rational part is not zero)). The constant can be a uint64, int64, float64 or complex128.
func (eval *evaluator) MultByConst(ct0 *Ciphertext, constant interface{}, ctOut *Ciphertext) {
- var level = utils.MinUint64(ct0.Level(), ctOut.Level())
+ var level = utils.MinInt(ct0.Level(), ctOut.Level())
cReal, cImag, scale := eval.getConstAndScale(level, constant)
@@ -731,11 +805,11 @@ func (eval *evaluator) MultByConst(ct0 *Ciphertext, constant interface{}, ctOut
// Which is equivalent outside of the NTT domain to adding a to the first coefficient of ct0 and b to the N/2-th coefficient of ct0.
ringQ := eval.ringQ
var scaledConst, scaledConstReal, scaledConstImag uint64
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
- bredParams := ringQ.GetBredParams()[i]
- mredParams := ringQ.GetMredParams()[i]
+ bredParams := ringQ.BredParams[i]
+ mredParams := ringQ.MredParams[i]
scaledConstReal = 0
scaledConstImag = 0
@@ -748,7 +822,7 @@ func (eval *evaluator) MultByConst(ct0 *Ciphertext, constant interface{}, ctOut
if cImag != 0 {
scaledConstImag = scaleUpExact(cImag, scale, qi)
- scaledConstImag = ring.MRed(scaledConstImag, ringQ.GetNttPsi()[i][1], qi, mredParams)
+ scaledConstImag = ring.MRed(scaledConstImag, ringQ.NttPsi[i][1], qi, mredParams)
scaledConst = ring.CRed(scaledConst+scaledConstImag, qi)
}
@@ -758,7 +832,7 @@ func (eval *evaluator) MultByConst(ct0 *Ciphertext, constant interface{}, ctOut
p0tmp := ct0.Value()[u].Coeffs[i]
p1tmp := ctOut.Value()[u].Coeffs[i]
- for j := uint64(0); j < ringQ.N>>1; j = j + 8 {
+ for j := 0; j < ringQ.N>>1; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
@@ -806,14 +880,16 @@ func (eval *evaluator) MultByGaussianInteger(ct0 *Ciphertext, cReal, cImag int64
ringQ := eval.ringQ
- level := utils.MinUint64(ct0.Level(), ctOut.Level())
+ level := utils.MinInt(ct0.Level(), ctOut.Level())
var scaledConst, scaledConstReal, scaledConstImag uint64
- for i := uint64(0); i < level+1; i++ {
+ ctOut.SetScale(ct0.Scale())
+
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
- bredParams := ringQ.GetBredParams()[i]
- mredParams := ringQ.GetMredParams()[i]
+ bredParams := ringQ.BredParams[i]
+ mredParams := ringQ.MredParams[i]
scaledConstReal = 0
scaledConstImag = 0
@@ -834,7 +910,7 @@ func (eval *evaluator) MultByGaussianInteger(ct0 *Ciphertext, cReal, cImag int64
} else {
scaledConstImag = uint64(cImag)
}
- scaledConstImag = ring.MRed(scaledConstImag, ringQ.GetNttPsi()[i][1], qi, mredParams)
+ scaledConstImag = ring.MRed(scaledConstImag, ringQ.NttPsi[i][1], qi, mredParams)
scaledConst = ring.CRed(scaledConst+scaledConstImag, qi)
}
@@ -844,7 +920,7 @@ func (eval *evaluator) MultByGaussianInteger(ct0 *Ciphertext, cReal, cImag int64
p0tmp := ct0.Value()[u].Coeffs[i]
p1tmp := ctOut.Value()[u].Coeffs[i]
- for j := uint64(0); j < ringQ.N>>1; j = j + 8 {
+ for j := 0; j < ringQ.N>>1; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
@@ -891,14 +967,14 @@ func (eval *evaluator) MultByGaussianIntegerAndAdd(ct0 *Ciphertext, cReal, cImag
ringQ := eval.ringQ
- level := utils.MinUint64(ct0.Level(), ctOut.Level())
+ level := utils.MinInt(ct0.Level(), ctOut.Level())
var scaledConst, scaledConstReal, scaledConstImag uint64
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
- bredParams := ringQ.GetBredParams()[i]
- mredParams := ringQ.GetMredParams()[i]
+ bredParams := ringQ.BredParams[i]
+ mredParams := ringQ.MredParams[i]
scaledConstReal = 0
scaledConstImag = 0
@@ -919,7 +995,7 @@ func (eval *evaluator) MultByGaussianIntegerAndAdd(ct0 *Ciphertext, cReal, cImag
} else {
scaledConstImag = uint64(cImag)
}
- scaledConstImag = ring.MRed(scaledConstImag, ringQ.GetNttPsi()[i][1], qi, mredParams)
+ scaledConstImag = ring.MRed(scaledConstImag, ringQ.NttPsi[i][1], qi, mredParams)
scaledConst = ring.CRed(scaledConst+scaledConstImag, qi)
}
@@ -929,7 +1005,7 @@ func (eval *evaluator) MultByGaussianIntegerAndAdd(ct0 *Ciphertext, cReal, cImag
p0tmp := ct0.Value()[u].Coeffs[i]
p1tmp := ctOut.Value()[u].Coeffs[i]
- for j := uint64(0); j < ringQ.N>>1; j = j + 8 {
+ for j := 0; j < ringQ.N>>1; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
@@ -984,25 +1060,26 @@ func (eval *evaluator) MultByiNew(ct0 *Ciphertext) (ctOut *Ciphertext) {
// It does not change the scale.
func (eval *evaluator) MultByi(ct0 *Ciphertext, ctOut *Ciphertext) {
- var level = utils.MinUint64(ct0.Level(), ctOut.Level())
+ var level = utils.MinInt(ct0.Level(), ctOut.Level())
+ ctOut.SetScale(ct0.Scale())
ringQ := eval.ringQ
var imag uint64
// Equivalent to a product by the monomial x^(n/2) outside of the NTT domain
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
- mredParams := ringQ.GetMredParams()[i]
+ mredParams := ringQ.MredParams[i]
- imag = ringQ.GetNttPsi()[i][1] // Psi^2
+ imag = ringQ.NttPsi[i][1] // Psi^2
for u := range ctOut.value {
p0tmp := ct0.value[u].Coeffs[i]
p1tmp := ctOut.value[u].Coeffs[i]
- for j := uint64(0); j < ringQ.N>>1; j = j + 8 {
+ for j := 0; j < ringQ.N>>1; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
@@ -1054,24 +1131,26 @@ func (eval *evaluator) DivByiNew(ct0 *Ciphertext) (ctOut *Ciphertext) {
// It does not change the scale.
func (eval *evaluator) DivByi(ct0 *Ciphertext, ctOut *Ciphertext) {
- var level = utils.MinUint64(ct0.Level(), ctOut.Level())
+ var level = utils.MinInt(ct0.Level(), ctOut.Level())
ringQ := eval.ringQ
+ ctOut.SetScale(ct0.Scale())
+
var imag uint64
// Equivalent to a product by the monomial x^(3*n/2) outside of the NTT domain
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
- mredParams := ringQ.GetMredParams()[i]
+ mredParams := ringQ.MredParams[i]
- imag = qi - ringQ.GetNttPsi()[i][1] // -Psi^2
+ imag = qi - ringQ.NttPsi[i][1] // -Psi^2
for u := range ctOut.value {
p0tmp := ct0.value[u].Coeffs[i]
p1tmp := ctOut.value[u].Coeffs[i]
- for j := uint64(0); j < ringQ.N>>1; j = j + 8 {
+ for j := 0; j < ringQ.N>>1; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
@@ -1087,7 +1166,7 @@ func (eval *evaluator) DivByi(ct0 *Ciphertext, ctOut *Ciphertext) {
}
}
- imag = ringQ.GetNttPsi()[i][1] // Psi^2
+ imag = ringQ.NttPsi[i][1] // Psi^2
for u := range ctOut.value {
p0tmp := ct0.value[u].Coeffs[i]
@@ -1144,15 +1223,16 @@ func (eval *evaluator) SetScale(ct *Ciphertext, scale float64) {
}
// MulByPow2New multiplies ct0 by 2^pow2 and returns the result in a newly created element.
-func (eval *evaluator) MulByPow2New(ct0 *Ciphertext, pow2 uint64) (ctOut *Ciphertext) {
+func (eval *evaluator) MulByPow2New(ct0 *Ciphertext, pow2 int) (ctOut *Ciphertext) {
ctOut = NewCiphertext(eval.params, ct0.Degree(), ct0.Level(), ct0.Scale())
eval.MulByPow2(ct0.El(), pow2, ctOut.El())
return
}
// MulByPow2 multiplies ct0 by 2^pow2 and returns the result in ctOut.
-func (eval *evaluator) MulByPow2(ct0 *Element, pow2 uint64, ctOut *Element) {
- var level = utils.MinUint64(ct0.Level(), ctOut.Level())
+func (eval *evaluator) MulByPow2(ct0 *Element, pow2 int, ctOut *Element) {
+ var level = utils.MinInt(ct0.Level(), ctOut.Level())
+ ctOut.SetScale(ct0.Scale())
for i := range ctOut.Value() {
eval.ringQ.MulByPow2Lvl(level, ct0.value[i], pow2, ctOut.Value()[i])
}
@@ -1178,15 +1258,17 @@ func (eval *evaluator) Reduce(ct0 *Ciphertext, ctOut *Ciphertext) error {
}
for i := range ct0.value {
- eval.ringQ.ReduceLvl(utils.MinUint64(ct0.Level(), ctOut.Level()), ct0.value[i], ctOut.value[i])
+ eval.ringQ.ReduceLvl(utils.MinInt(ct0.Level(), ctOut.Level()), ct0.value[i], ctOut.value[i])
}
+ ctOut.SetScale(ct0.Scale())
+
return nil
}
// DropLevelNew reduces the level of ct0 by levels and returns the result in a newly created element.
// No rescaling is applied during this procedure.
-func (eval *evaluator) DropLevelNew(ct0 *Ciphertext, levels uint64) (ctOut *Ciphertext) {
+func (eval *evaluator) DropLevelNew(ct0 *Ciphertext, levels int) (ctOut *Ciphertext) {
ctOut = ct0.CopyNew().Ciphertext()
eval.DropLevel(ctOut, levels)
return
@@ -1194,7 +1276,7 @@ func (eval *evaluator) DropLevelNew(ct0 *Ciphertext, levels uint64) (ctOut *Ciph
// DropLevel reduces the level of ct0 by levels and returns the result in ct0.
// No rescaling is applied during this procedure.
-func (eval *evaluator) DropLevel(ct0 *Ciphertext, levels uint64) {
+func (eval *evaluator) DropLevel(ct0 *Ciphertext, levels int) {
level := ct0.Level()
for i := range ct0.value {
ct0.value[i].Coeffs = ct0.value[i].Coeffs[:level+1-levels]
@@ -1219,75 +1301,46 @@ func (eval *evaluator) RescaleNew(ct0 *Ciphertext, threshold float64) (ctOut *Ci
// in ctOut. Since all the moduli in the moduli chain are generated to be close to the
// original scale, this procedure is equivalent to dividing the input element by the scale and adding
// some error.
-// Returns an error if "threshold <= 0", ct.Scale() = 0, ct.Level() = 0, ct.IsNTT() != true or if ct.Level() != ctOut.Level()
-func (eval *evaluator) Rescale(ct0 *Ciphertext, threshold float64, ctOut *Ciphertext) (err error) {
+// Returns an error if "minScale <= 0", ct.Scale() = 0, ct.Level() = 0, ct.IsNTT() != true or if ct.Leve() != ctOut.Level()
+func (eval *evaluator) Rescale(ctIn *Ciphertext, minScale float64, ctOut *Ciphertext) (err error) {
ringQ := eval.ringQ
- if threshold <= 0 {
- return errors.New("cannot Rescale: threshold is 0")
+ if minScale <= 0 {
+ return errors.New("cannot Rescale: minScale is 0")
}
- if ct0.Scale() == 0 {
+ if ctIn.Scale() == 0 {
return errors.New("cannot Rescale: ciphertext scale is 0")
}
- if ct0.Level() == 0 {
+ if ctIn.Level() == 0 {
return errors.New("cannot Rescale: input Ciphertext already at level 0")
}
- if ct0.Level() != ctOut.Level() {
- panic("cannot Rescale: degrees of receiver Ciphertext and input Ciphertext do not match")
+ if ctOut.Degree() != ctIn.Degree() {
+ return errors.New("cannot Rescale : ctIn.Degree() != ctOut.Degree()")
}
- if ct0.Scale() >= (threshold*float64(ringQ.Modulus[ctOut.Level()]))/2 {
+ ctOut.scale = ctIn.scale
+ ctOut.isNTT = true
- if !ct0.IsNTT() {
- panic("cannot Rescale: input Ciphertext not in NTT")
+ var nbRescale int
+ // Divides the scale by each moduli of the modulus chain as long as the scale isn't smaller than minScale/2
+ // or until the output Level() would be zero
+ for ctOut.Scale()/float64(ringQ.Modulus[ctIn.Level()-nbRescale]) >= minScale/2 && ctIn.Level()-nbRescale >= 0 {
+ ctOut.DivScale(float64(ringQ.Modulus[ctIn.Level()-nbRescale]))
+ nbRescale++
+ }
+
+ if ctIn.IsNTT() {
+ for i := range ctOut.Value() {
+ ringQ.DivRoundByLastModulusManyNTT(ctIn.Value()[i], ctOut.Value()[i], nbRescale)
}
-
- ctOut.Copy(ct0.El())
-
- for ctOut.Scale() >= (threshold*float64(ringQ.Modulus[ctOut.Level()]))/2 && ctOut.Level() != 0 {
-
- ctOut.DivScale(float64(ringQ.Modulus[ctOut.Level()]))
-
- for i := range ctOut.Value() {
- eval.ringQ.DivRoundByLastModulusNTT(ctOut.Value()[i])
- }
-
- }
-
} else {
- ctOut.Copy(ct0.El())
- }
-
- return nil
-}
-
-// RescaleMany applies Rescale several times in a row on the input Ciphertext.
-func (eval *evaluator) RescaleMany(ct0 *Ciphertext, nbRescales uint64, ctOut *Ciphertext) (err error) {
-
- if ct0.Level() < nbRescales {
- return errors.New("cannot RescaleMany: input Ciphertext level too low")
- }
-
- if ct0.Level() != ctOut.Level() {
- panic("cannot RescaleMany: degrees of receiver Ciphertext and input Ciphertext do not match")
- }
-
- if !ct0.IsNTT() {
- panic("cannot RescaleMany: input Ciphertext not in NTT")
- }
-
- ctOut.Copy(ct0.El())
-
- for i := uint64(0); i < nbRescales; i++ {
- ctOut.DivScale(float64(eval.ringQ.Modulus[ctOut.Level()-i]))
- }
-
- for i := range ctOut.Value() {
- eval.ringQ.DivRoundByLastModulusManyNTT(ctOut.Value()[i], nbRescales)
+ for i := range ctOut.Value() {
+ ringQ.DivRoundByLastModulusMany(ctIn.Value()[i], ctOut.Value()[i], nbRescale)
+ }
}
return nil
@@ -1296,7 +1349,7 @@ func (eval *evaluator) RescaleMany(ct0 *Ciphertext, nbRescales uint64, ctOut *Ci
// MulNew multiplies op0 with op1 without relinearization and returns the result in a newly created element.
// The procedure will panic if either op0.Degree or op1.Degree > 1.
func (eval *evaluator) MulNew(op0, op1 Operand) (ctOut *Ciphertext) {
- ctOut = NewCiphertext(eval.params, op0.Degree()+op1.Degree(), utils.MinUint64(op0.Level(), op1.Level()), 0)
+ ctOut = NewCiphertext(eval.params, op0.Degree()+op1.Degree(), utils.MinInt(op0.Level(), op1.Level()), 0)
eval.mulRelin(op0, op1, false, ctOut)
return
}
@@ -1312,7 +1365,7 @@ func (eval *evaluator) Mul(op0, op1 Operand, ctOut *Ciphertext) {
// The procedure will panic if either op0.Degree or op1.Degree > 1.
// The procedure will panic if the evaluator was not created with an relinearization key.
func (eval *evaluator) MulRelinNew(op0, op1 Operand) (ctOut *Ciphertext) {
- ctOut = NewCiphertext(eval.params, 1, utils.MinUint64(op0.Level(), op1.Level()), 0)
+ ctOut = NewCiphertext(eval.params, 1, utils.MinInt(op0.Level(), op1.Level()), 0)
eval.mulRelin(op0, op1, true, ctOut)
return
}
@@ -1327,13 +1380,9 @@ func (eval *evaluator) MulRelin(op0, op1 Operand, ctOut *Ciphertext) {
func (eval *evaluator) mulRelin(op0, op1 Operand, relin bool, ctOut *Ciphertext) {
- if relin && eval.rlk == nil {
- panic("evaluator has no relinearization key")
- }
+ el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxInt(op0.Degree(), op1.Degree()))
- el0, el1, elOut := eval.getElemAndCheckBinary(op0, op1, ctOut, utils.MaxUint64(op0.Degree(), op1.Degree()))
-
- level := utils.MinUint64(utils.MinUint64(el0.Level(), el1.Level()), elOut.Level())
+ level := utils.MinInt(utils.MinInt(el0.Level(), el1.Level()), elOut.Level())
if ctOut.Level() > level {
eval.DropLevel(elOut.Ciphertext(), elOut.Level()-level)
@@ -1392,7 +1441,7 @@ func (eval *evaluator) mulRelin(op0, op1 Operand, relin bool, ctOut *Ciphertext)
}
if relin {
- eval.switchKeysInPlace(level, c2, eval.rlk.Keys[0], eval.poolQ[1], eval.poolQ[2])
+ eval.SwitchKeysInPlace(level, c2, eval.rlk.Keys[0], eval.poolQ[1], eval.poolQ[2])
ringQ.AddLvl(level, c0, eval.poolQ[1], elOut.value[0])
ringQ.AddLvl(level, c1, eval.poolQ[2], elOut.value[1])
}
@@ -1426,23 +1475,20 @@ func (eval *evaluator) RelinearizeNew(ct0 *Ciphertext) (ctOut *Ciphertext) {
// Relinearize applies the relinearization procedure on ct0 and returns the result in ctOut. The input Ciphertext must be of degree two.
func (eval *evaluator) Relinearize(ct0 *Ciphertext, ctOut *Ciphertext) {
-
- if eval.rlk == nil {
- panic("evaluator has no relinearization key")
- }
-
if ct0.Degree() != 2 {
panic("cannot Relinearize: input Ciphertext is not of degree 2")
}
- if ctOut != ct0 {
- ctOut.SetScale(ct0.Scale())
+ if ctOut.Level() > ct0.Level() {
+ eval.DropLevel(ctOut, ctOut.Level()-ct0.Level())
}
- level := utils.MinUint64(ct0.Level(), ctOut.Level())
+ ctOut.SetScale(ct0.Scale())
+
+ level := utils.MinInt(ct0.Level(), ctOut.Level())
ringQ := eval.ringQ
- eval.switchKeysInPlace(level, ct0.value[2], eval.rlk.Keys[0], eval.poolQ[1], eval.poolQ[2])
+ eval.SwitchKeysInPlace(level, ct0.value[2], eval.rlk.Keys[0], eval.poolQ[1], eval.poolQ[2])
ringQ.AddLvl(level, ct0.value[0], eval.poolQ[1], ctOut.value[0])
ringQ.AddLvl(level, ct0.value[1], eval.poolQ[2], ctOut.value[1])
@@ -1468,10 +1514,12 @@ func (eval *evaluator) SwitchKeys(ct0 *Ciphertext, switchingKey *SwitchingKey, c
panic("cannot SwitchKeys: input and output Ciphertext must be of degree 1")
}
- level := utils.MinUint64(ct0.Level(), ctOut.Level())
+ level := utils.MinInt(ct0.Level(), ctOut.Level())
ringQ := eval.ringQ
- eval.switchKeysInPlace(level, ct0.value[1], &switchingKey.SwitchingKey, eval.poolQ[1], eval.poolQ[2])
+ ctOut.SetScale(ct0.Scale())
+
+ eval.SwitchKeysInPlace(level, ct0.value[1], &switchingKey.SwitchingKey, eval.poolQ[1], eval.poolQ[2])
ringQ.AddLvl(level, ct0.value[0], eval.poolQ[1], ctOut.value[0])
ringQ.CopyLvl(level, eval.poolQ[2], ctOut.value[1])
@@ -1500,12 +1548,8 @@ func (eval *evaluator) Rotate(ct0 *Ciphertext, k int, ctOut *Ciphertext) {
ctOut.SetScale(ct0.Scale())
galEl := eval.params.GaloisElementForColumnRotationBy(k)
- rtk, generated := eval.rtks.Keys[galEl]
- if !generated {
- panic(fmt.Errorf("evaluator has no rotation key for rotation by %d", k))
- }
- eval.permuteNTT(ct0, galEl, rtk, ctOut)
+ eval.permuteNTT(ct0, galEl, ctOut)
}
}
@@ -1523,28 +1567,27 @@ func (eval *evaluator) ConjugateNew(ct0 *Ciphertext) (ctOut *Ciphertext) {
func (eval *evaluator) Conjugate(ct0 *Ciphertext, ctOut *Ciphertext) {
galEl := eval.params.GaloisElementForRowRotation()
- rtk, generated := eval.rtks.Keys[galEl]
- if !generated {
- panic("evaluator has no rotation key for row rotation")
- }
-
ctOut.SetScale(ct0.Scale())
-
- eval.permuteNTT(ct0, galEl, rtk, ctOut)
+ eval.permuteNTT(ct0, galEl, ctOut)
}
-func (eval *evaluator) permuteNTT(ct0 *Ciphertext, galEl uint64, rtk *rlwe.SwitchingKey, ctOut *Ciphertext) {
+func (eval *evaluator) permuteNTT(ct0 *Ciphertext, galEl uint64, ctOut *Ciphertext) {
if ct0.Degree() != 1 || ctOut.Degree() != 1 {
panic("input and output Ciphertext must be of degree 1")
}
- level := utils.MinUint64(ct0.Level(), ctOut.Level())
+ rtk, generated := eval.rtks.Keys[galEl]
+ if !generated {
+ panic(fmt.Sprintf("rotation key k=%d not available", eval.params.InverseGaloisElement(galEl)))
+ }
+
+ level := utils.MinInt(ct0.Level(), ctOut.Level())
index := eval.permuteNTTIndex[galEl]
pool2Q := eval.poolQ[1]
pool3Q := eval.poolQ[2]
- eval.switchKeysInPlace(level, ct0.Value()[1], rtk, pool2Q, pool3Q)
+ eval.SwitchKeysInPlace(level, ct0.Value()[1], rtk, pool2Q, pool3Q)
eval.ringQ.AddLvl(level, pool2Q, ct0.value[0], pool2Q)
@@ -1552,8 +1595,60 @@ func (eval *evaluator) permuteNTT(ct0 *Ciphertext, galEl uint64, rtk *rlwe.Switc
ring.PermuteNTTWithIndexLvl(level, pool3Q, index, ctOut.value[1])
}
-func (eval *evaluator) switchKeysInPlaceNoModDown(level uint64, cx *ring.Poly, evakey *rlwe.SwitchingKey, pool2Q, pool2P, pool3Q, pool3P *ring.Poly) {
- var reduce uint64
+func (eval *evaluator) rotateHoistedNoModDown(ct0 *Ciphertext, rotations []int, c2QiQDecomp, c2QiPDecomp []*ring.Poly) (cOutQ, cOutP map[int][2]*ring.Poly) {
+
+ ringQ := eval.ringQ
+
+ cOutQ = make(map[int][2]*ring.Poly)
+ cOutP = make(map[int][2]*ring.Poly)
+
+ level := ct0.Level()
+
+ for _, i := range rotations {
+
+ if i != 0 {
+ cOutQ[i] = [2]*ring.Poly{ringQ.NewPolyLvl(level), ringQ.NewPolyLvl(level)}
+ cOutP[i] = [2]*ring.Poly{eval.params.NewPolyP(), eval.params.NewPolyP()}
+
+ eval.permuteNTTHoistedNoModDown(level, c2QiQDecomp, c2QiPDecomp, i, cOutQ[i][0], cOutQ[i][1], cOutP[i][0], cOutP[i][1])
+ }
+ }
+
+ return
+}
+
+func (eval *evaluator) permuteNTTHoistedNoModDown(level int, c2QiQDecomp, c2QiPDecomp []*ring.Poly, k int, ct0OutQ, ct1OutQ, ct0OutP, ct1OutP *ring.Poly) {
+
+ pool2Q := eval.poolQ[0]
+ pool3Q := eval.poolQ[1]
+
+ pool2P := eval.poolP[0]
+ pool3P := eval.poolP[1]
+
+ levelQ := level
+ levelP := eval.params.PiCount() - 1
+
+ galEl := eval.params.GaloisElementForColumnRotationBy(k)
+
+ rtk, generated := eval.rtks.Keys[galEl]
+ if !generated {
+ fmt.Println(k)
+ panic("switching key not available")
+ }
+ index := eval.permuteNTTIndex[galEl]
+
+ eval.keyswitchHoistedNoModDown(levelQ, c2QiQDecomp, c2QiPDecomp, rtk, pool2Q, pool3Q, pool2P, pool3P)
+
+ ring.PermuteNTTWithIndexLvl(levelQ, pool2Q, index, ct0OutQ)
+ ring.PermuteNTTWithIndexLvl(levelQ, pool3Q, index, ct1OutQ)
+
+ ring.PermuteNTTWithIndexLvl(levelP, pool2P, index, ct0OutP)
+ ring.PermuteNTTWithIndexLvl(levelP, pool3P, index, ct1OutP)
+}
+
+func (eval *evaluator) SwitchKeysInPlaceNoModDown(level int, cx *ring.Poly, evakey *rlwe.SwitchingKey, pool2Q, pool2P, pool3Q, pool3P *ring.Poly) {
+
+ var reduce int
ringQ := eval.ringQ
ringP := eval.ringP
@@ -1562,7 +1657,7 @@ func (eval *evaluator) switchKeysInPlaceNoModDown(level uint64, cx *ring.Poly, e
c2QiQ := eval.poolQ[0]
c2QiP := eval.poolP[0]
- c2 := eval.poolQ[3]
+ c2 := eval.poolInvNTT
evakey0Q := new(ring.Poly)
evakey1Q := new(ring.Poly)
@@ -1576,10 +1671,13 @@ func (eval *evaluator) switchKeysInPlaceNoModDown(level uint64, cx *ring.Poly, e
reduce = 0
alpha := eval.params.Alpha()
- beta := uint64(math.Ceil(float64(level+1) / float64(alpha)))
+ beta := int(math.Ceil(float64(level+1) / float64(alpha)))
+
+ QiOverF := eval.params.QiOverflowMargin(level) >> 1
+ PiOverF := eval.params.PiOverflowMargin() >> 1
// Key switching with CRT decomposition for the Qi
- for i := uint64(0); i < beta; i++ {
+ for i := 0; i < beta; i++ {
eval.decomposeAndSplitNTT(level, i, cx, c2, c2QiQ, c2QiP)
@@ -1600,34 +1698,56 @@ func (eval *evaluator) switchKeysInPlaceNoModDown(level uint64, cx *ring.Poly, e
ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey1P, c2QiP, pool3P)
}
- //
- if reduce&3 == 3 {
- ringQ.ReduceConstantLvl(level, pool2Q, pool2Q)
- ringQ.ReduceConstantLvl(level, pool3Q, pool3Q)
- ringP.ReduceConstant(pool2P, pool2P)
- ringP.ReduceConstant(pool3P, pool3P)
+ if reduce%QiOverF == QiOverF-1 {
+ ringQ.ReduceLvl(level, pool2Q, pool2Q)
+ ringQ.ReduceLvl(level, pool3Q, pool3Q)
+ }
+
+ if reduce%PiOverF == PiOverF-1 {
+ ringP.Reduce(pool2P, pool2P)
+ ringP.Reduce(pool3P, pool3P)
}
reduce++
}
- ringQ.ReduceLvl(level, pool2Q, pool2Q)
- ringQ.ReduceLvl(level, pool3Q, pool3Q)
- ringP.Reduce(pool2P, pool2P)
- ringP.Reduce(pool3P, pool3P)
+ if reduce%QiOverF != 0 {
+ ringQ.ReduceLvl(level, pool2Q, pool2Q)
+ ringQ.ReduceLvl(level, pool3Q, pool3Q)
+ }
+
+ if reduce%PiOverF != 0 {
+ ringP.Reduce(pool2P, pool2P)
+ ringP.Reduce(pool3P, pool3P)
+ }
}
-// switchKeysInPlace applies the general key-switching procedure of the form [c0 + cx*evakey[0], c1 + cx*evakey[1]]
-func (eval *evaluator) switchKeysInPlace(level uint64, cx *ring.Poly, evakey *rlwe.SwitchingKey, p0, p1 *ring.Poly) {
+// SwitchKeysInPlace applies the general key-switching procedure of the form [c0 + cx*evakey[0], c1 + cx*evakey[1]]
+func (eval *evaluator) SwitchKeysInPlace(level int, cx *ring.Poly, evakey *rlwe.SwitchingKey, p0, p1 *ring.Poly) {
- eval.switchKeysInPlaceNoModDown(level, cx, evakey, p0, eval.poolP[1], p1, eval.poolP[2])
+ eval.SwitchKeysInPlaceNoModDown(level, cx, evakey, p0, eval.poolP[1], p1, eval.poolP[2])
eval.baseconverter.ModDownSplitNTTPQ(level, p0, eval.poolP[1], p0)
eval.baseconverter.ModDownSplitNTTPQ(level, p1, eval.poolP[2], p1)
}
+func (eval *evaluator) DecompInternal(levelQ int, c2NTT *ring.Poly, c2QiQDecomp, c2QiPDecomp []*ring.Poly) {
+
+ ringQ := eval.ringQ
+
+ c2InvNTT := eval.poolInvNTT
+ ringQ.InvNTTLvl(levelQ, c2NTT, c2InvNTT)
+
+ alpha := eval.params.Alpha()
+ beta := int(math.Ceil(float64(levelQ+1) / float64(alpha)))
+
+ for i := 0; i < beta; i++ {
+ eval.decomposeAndSplitNTT(levelQ, i, c2NTT, c2InvNTT, c2QiQDecomp[i], c2QiPDecomp[i])
+ }
+}
+
// decomposeAndSplitNTT decomposes the input polynomial into the target CRT basis.
-func (eval *evaluator) decomposeAndSplitNTT(level, beta uint64, c2NTT, c2InvNTT, c2QiQ, c2QiP *ring.Poly) {
+func (eval *evaluator) decomposeAndSplitNTT(level, beta int, c2NTT, c2InvNTT, c2QiQ, c2QiP *ring.Poly) {
ringQ := eval.ringQ
ringP := eval.ringP
@@ -1638,17 +1758,17 @@ func (eval *evaluator) decomposeAndSplitNTT(level, beta uint64, c2NTT, c2InvNTT,
p0idxed := p0idxst + eval.decomposer.Xalpha()[beta]
// c2_qi = cx mod qi mod qi
- for x := uint64(0); x < level+1; x++ {
+ for x := 0; x < level+1; x++ {
qi := ringQ.Modulus[x]
- nttPsi := ringQ.GetNttPsi()[x]
- bredParams := ringQ.GetBredParams()[x]
- mredParams := ringQ.GetMredParams()[x]
+ nttPsi := ringQ.NttPsi[x]
+ bredParams := ringQ.BredParams[x]
+ mredParams := ringQ.MredParams[x]
if p0idxst <= x && x < p0idxed {
p0tmp := c2NTT.Coeffs[x]
p1tmp := c2QiQ.Coeffs[x]
- for j := uint64(0); j < ringQ.N; j++ {
+ for j := 0; j < ringQ.N; j++ {
p1tmp[j] = p0tmp[j]
}
} else {
@@ -1659,52 +1779,11 @@ func (eval *evaluator) decomposeAndSplitNTT(level, beta uint64, c2NTT, c2InvNTT,
ringP.NTTLazy(c2QiP, c2QiP)
}
-// RotateHoisted takes an input Ciphertext and a list of rotations and returns a map of Ciphertext, where each element of the map is the input Ciphertext
-// rotation by one element of the list. It is much faster than sequential calls to Rotate.
-func (eval *evaluator) RotateHoisted(ct0 *Ciphertext, rotations []int) (cOut map[int]*Ciphertext) {
-
- // Pre-computation for rotations using hoisting
- ringQ := eval.ringQ
- ringP := eval.ringP
-
- c2NTT := ct0.value[1]
- c2InvNTT := ringQ.NewPoly()
- ringQ.InvNTTLvl(ct0.Level(), c2NTT, c2InvNTT)
-
- alpha := eval.params.Alpha()
- beta := uint64(math.Ceil(float64(ct0.Level()+1) / float64(alpha)))
-
- c2QiQDecomp := make([]*ring.Poly, beta)
- c2QiPDecomp := make([]*ring.Poly, beta)
-
- for i := uint64(0); i < beta; i++ {
- c2QiQDecomp[i] = ringQ.NewPoly()
- c2QiPDecomp[i] = ringP.NewPoly()
- eval.decomposeAndSplitNTT(ct0.Level(), i, c2NTT, c2InvNTT, c2QiQDecomp[i], c2QiPDecomp[i])
- }
-
- cOut = make(map[int]*Ciphertext)
- for _, i := range rotations {
-
- if i == 0 {
- cOut[i] = ct0.CopyNew().Ciphertext()
- } else {
- cOut[i] = NewCiphertext(eval.params, 1, ct0.Level(), ct0.Scale())
- eval.permuteNTTHoisted(ct0, c2QiQDecomp, c2QiPDecomp, i, cOut[i])
- }
- }
-
- return
-}
-
-func (eval *evaluator) permuteNTTHoisted(ct0 *Ciphertext, c2QiQDecomp, c2QiPDecomp []*ring.Poly, k int, ctOut *Ciphertext) {
-
- if ct0.Degree() != 1 || ctOut.Degree() != 1 {
- panic("input and output Ciphertext must be of degree 1")
- }
+func (eval *evaluator) permuteNTTHoisted(level int, c0, c1 *ring.Poly, c2QiQDecomp, c2QiPDecomp []*ring.Poly, k int, cOut0, cOut1 *ring.Poly) {
if k == 0 {
- ctOut.Copy(ct0.Element)
+ cOut0.Copy(c0)
+ cOut1.Copy(c1)
return
}
@@ -1714,8 +1793,6 @@ func (eval *evaluator) permuteNTTHoisted(ct0 *Ciphertext, c2QiQDecomp, c2QiPDeco
panic(fmt.Sprintf("specific rotation has not been generated: %d", k))
}
- ctOut.SetScale(ct0.Scale())
-
index := eval.permuteNTTIndex[galEl]
pool2Q := eval.poolQ[0]
@@ -1724,17 +1801,15 @@ func (eval *evaluator) permuteNTTHoisted(ct0 *Ciphertext, c2QiQDecomp, c2QiPDeco
pool2P := eval.poolP[0]
pool3P := eval.poolP[1]
- level := ctOut.Level()
-
eval.keyswitchHoisted(level, c2QiQDecomp, c2QiPDecomp, rtk, pool2Q, pool3Q, pool2P, pool3P)
- eval.ringQ.AddLvl(level, pool2Q, ct0.value[0], pool2Q)
+ eval.ringQ.AddLvl(level, pool2Q, c0, pool2Q)
- ring.PermuteNTTWithIndexLvl(level, pool2Q, index, ctOut.value[0])
- ring.PermuteNTTWithIndexLvl(level, pool3Q, index, ctOut.value[1])
+ ring.PermuteNTTWithIndexLvl(level, pool2Q, index, cOut0)
+ ring.PermuteNTTWithIndexLvl(level, pool3Q, index, cOut1)
}
-func (eval *evaluator) keyswitchHoisted(level uint64, c2QiQDecomp, c2QiPDecomp []*ring.Poly, evakey *rlwe.SwitchingKey, pool2Q, pool3Q, pool2P, pool3P *ring.Poly) {
+func (eval *evaluator) keyswitchHoisted(level int, c2QiQDecomp, c2QiPDecomp []*ring.Poly, evakey *rlwe.SwitchingKey, pool2Q, pool3Q, pool2P, pool3P *ring.Poly) {
eval.keyswitchHoistedNoModDown(level, c2QiQDecomp, c2QiPDecomp, evakey, pool2Q, pool3Q, pool2P, pool3P)
@@ -1743,22 +1818,25 @@ func (eval *evaluator) keyswitchHoisted(level uint64, c2QiQDecomp, c2QiPDecomp [
eval.baseconverter.ModDownSplitNTTPQ(level, pool3Q, pool3P, pool3Q)
}
-func (eval *evaluator) keyswitchHoistedNoModDown(level uint64, c2QiQDecomp, c2QiPDecomp []*ring.Poly, evakey *rlwe.SwitchingKey, pool2Q, pool3Q, pool2P, pool3P *ring.Poly) {
+func (eval *evaluator) keyswitchHoistedNoModDown(level int, c2QiQDecomp, c2QiPDecomp []*ring.Poly, evakey *rlwe.SwitchingKey, pool2Q, pool3Q, pool2P, pool3P *ring.Poly) {
ringQ := eval.ringQ
ringP := eval.ringP
alpha := eval.params.Alpha()
- beta := uint64(math.Ceil(float64(level+1) / float64(alpha)))
+ beta := int(math.Ceil(float64(level+1) / float64(alpha)))
evakey0Q := new(ring.Poly)
evakey1Q := new(ring.Poly)
evakey0P := new(ring.Poly)
evakey1P := new(ring.Poly)
+ QiOverF := eval.params.QiOverflowMargin(level) >> 1
+ PiOverF := eval.params.PiOverflowMargin() >> 1
+
// Key switching with CRT decomposition for the Qi
- var reduce uint64
- for i := uint64(0); i < beta; i++ {
+ var reduce int
+ for i := 0; i < beta; i++ {
evakey0Q.Coeffs = evakey.Value[i][0].Coeffs[:level+1]
evakey1Q.Coeffs = evakey.Value[i][1].Coeffs[:level+1]
@@ -1766,20 +1844,23 @@ func (eval *evaluator) keyswitchHoistedNoModDown(level uint64, c2QiQDecomp, c2Qi
evakey1P.Coeffs = evakey.Value[i][1].Coeffs[len(ringQ.Modulus):]
if i == 0 {
- ringQ.MulCoeffsMontgomeryLvl(level, evakey0Q, c2QiQDecomp[i], pool2Q)
- ringQ.MulCoeffsMontgomeryLvl(level, evakey1Q, c2QiQDecomp[i], pool3Q)
- ringP.MulCoeffsMontgomery(evakey0P, c2QiPDecomp[i], pool2P)
- ringP.MulCoeffsMontgomery(evakey1P, c2QiPDecomp[i], pool3P)
+ ringQ.MulCoeffsMontgomeryConstantLvl(level, evakey0Q, c2QiQDecomp[i], pool2Q)
+ ringQ.MulCoeffsMontgomeryConstantLvl(level, evakey1Q, c2QiQDecomp[i], pool3Q)
+ ringP.MulCoeffsMontgomeryConstant(evakey0P, c2QiPDecomp[i], pool2P)
+ ringP.MulCoeffsMontgomeryConstant(evakey1P, c2QiPDecomp[i], pool3P)
} else {
- ringQ.MulCoeffsMontgomeryAndAddNoModLvl(level, evakey0Q, c2QiQDecomp[i], pool2Q)
- ringQ.MulCoeffsMontgomeryAndAddNoModLvl(level, evakey1Q, c2QiQDecomp[i], pool3Q)
- ringP.MulCoeffsMontgomeryAndAddNoMod(evakey0P, c2QiPDecomp[i], pool2P)
- ringP.MulCoeffsMontgomeryAndAddNoMod(evakey1P, c2QiPDecomp[i], pool3P)
+ ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(level, evakey0Q, c2QiQDecomp[i], pool2Q)
+ ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(level, evakey1Q, c2QiQDecomp[i], pool3Q)
+ ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey0P, c2QiPDecomp[i], pool2P)
+ ringP.MulCoeffsMontgomeryConstantAndAddNoMod(evakey1P, c2QiPDecomp[i], pool3P)
}
- if reduce&7 == 1 {
+ if reduce%QiOverF == QiOverF-1 {
ringQ.ReduceLvl(level, pool2Q, pool2Q)
ringQ.ReduceLvl(level, pool3Q, pool3Q)
+ }
+
+ if reduce%PiOverF == PiOverF-1 {
ringP.Reduce(pool2P, pool2P)
ringP.Reduce(pool3P, pool3P)
}
@@ -1787,9 +1868,12 @@ func (eval *evaluator) keyswitchHoistedNoModDown(level uint64, c2QiQDecomp, c2Qi
reduce++
}
- if (reduce-1)&7 != 1 {
+ if reduce%QiOverF != 0 {
ringQ.ReduceLvl(level, pool2Q, pool2Q)
ringQ.ReduceLvl(level, pool3Q, pool3Q)
+ }
+
+ if reduce%PiOverF != 0 {
ringP.Reduce(pool2P, pool2P)
ringP.Reduce(pool3P, pool3P)
}
diff --git a/ckks/keygen.go b/ckks/keygen.go
index 660d8ab2..7d6d0e47 100644
--- a/ckks/keygen.go
+++ b/ckks/keygen.go
@@ -14,17 +14,30 @@ type KeyGenerator interface {
GenSecretKey() (sk *SecretKey)
GenSecretKeyGaussian() (sk *SecretKey)
GenSecretKeyWithDistrib(p float64) (sk *SecretKey)
- GenSecretKeySparse(hw uint64) (sk *SecretKey)
+ GenSecretKeySparse(hw int) (sk *SecretKey)
GenPublicKey(sk *SecretKey) (pk *PublicKey)
GenKeyPair() (sk *SecretKey, pk *PublicKey)
- GenKeyPairSparse(hw uint64) (sk *SecretKey, pk *PublicKey)
+ GenKeyPairSparse(hw int) (sk *SecretKey, pk *PublicKey)
GenSwitchingKey(skInput, skOutput *SecretKey) (newevakey *SwitchingKey)
GenRelinearizationKey(sk *SecretKey) (evakey *RelinearizationKey)
GenSwitchingKeyForGalois(galEl uint64, sk *SecretKey) (swk *SwitchingKey)
+
GenRotationKeys(galEls []uint64, sk *SecretKey) (rks *RotationKeySet)
+
GenRotationKeysForRotations(ks []int, includeConjugate bool, sk *SecretKey) (rks *RotationKeySet)
- GenRotationKeysForInnerSum(sk *SecretKey) (rks *RotationKeySet)
- GenBootstrappingKey(logSlots uint64, btpParams *BootstrappingParameters, sk *SecretKey) (btpKey *BootstrappingKey)
+
+ GenRotationIndexesForSubSum(logSlots int) (rotations []int)
+ GenRotationIndexesForCoeffsToSlots(logSlots int, btpParams *BootstrappingParameters) (rotations []int)
+ GenRotationIndexesForSlotsToCoeffs(logSlots int, btpParams *BootstrappingParameters) (rotations []int)
+ GenRotationIndexesForBootstrapping(logSlots int, btpParams *BootstrappingParameters) (rotations []int)
+
+ GenRotationIndexesForInnerSumLog(batch, n int) (rotations []int)
+ GenRotationIndexesForInnerSum(batch, n int) (rotations []int)
+
+ GenRotationIndexesForReplicateLog(batch, n int) (rotations []int)
+ GenRotationIndexesForReplicate(batch, n int) (rotations []int)
+
+ GenRotationIndexesForDiagMatrix(matrix *PtDiagMatrix) (rotations []int)
}
// KeyGenerator is a structure that stores the elements required to create new keys,
@@ -66,7 +79,7 @@ func NewKeyGenerator(params *Parameters) KeyGenerator {
ringQP: qp,
pBigInt: pBigInt,
polypool: [2]*ring.Poly{qp.NewPoly(), qp.NewPoly()},
- gaussianSampler: ring.NewGaussianSampler(prng, qp, params.sigma, uint64(6*params.sigma)),
+ gaussianSampler: ring.NewGaussianSampler(prng, qp, params.Sigma(), int(6*params.Sigma())),
uniformSampler: ring.NewUniformSampler(prng, qp),
}
}
@@ -99,7 +112,7 @@ func (keygen *keyGenerator) GenSecretKeyWithDistrib(p float64) (sk *SecretKey) {
}
// GenSecretKeySparse generates a new SecretKey with exactly hw non-zero coefficients.
-func (keygen *keyGenerator) GenSecretKeySparse(hw uint64) (sk *SecretKey) {
+func (keygen *keyGenerator) GenSecretKeySparse(hw int) (sk *SecretKey) {
prng, err := utils.NewPRNG()
if err != nil {
panic(err)
@@ -121,6 +134,7 @@ func (keygen *keyGenerator) GenPublicKey(sk *SecretKey) (pk *PublicKey) {
//pk[0] = [-(a*s + e)]
//pk[1] = [a]
+
pk.Value[0] = keygen.gaussianSampler.ReadNew()
ringQP.NTT(pk.Value[0], pk.Value[0])
pk.Value[1] = keygen.uniformSampler.ReadNew()
@@ -137,7 +151,7 @@ func (keygen *keyGenerator) GenKeyPair() (sk *SecretKey, pk *PublicKey) {
}
// GenKeyPairSparse generates a new SecretKey with exactly hw non zero coefficients [1/2, 0, 1/2].
-func (keygen *keyGenerator) GenKeyPairSparse(hw uint64) (sk *SecretKey, pk *PublicKey) {
+func (keygen *keyGenerator) GenKeyPairSparse(hw int) (sk *SecretKey, pk *PublicKey) {
sk = keygen.GenSecretKeySparse(hw)
return sk, keygen.GenPublicKey(sk)
}
@@ -196,7 +210,7 @@ func (keygen *keyGenerator) genrotKey(sk *ring.Poly, galEl uint64, swk *rlwe.Swi
skIn := sk
skOut := keygen.polypool[1]
- index := ring.PermuteNTTIndex(galEl, keygen.ringQP.N)
+ index := ring.PermuteNTTIndex(galEl, uint64(keygen.ringQP.N))
ring.PermuteNTTWithIndexLvl(keygen.params.QPiCount()-1, skIn, index, skOut)
keygen.newSwitchingKey(skIn, skOut, swk)
@@ -217,10 +231,11 @@ func (keygen *keyGenerator) newSwitchingKey(skIn, skOut *ring.Poly, swk *rlwe.Sw
alpha := keygen.params.Alpha()
beta := keygen.params.Beta()
- var index uint64
- for i := uint64(0); i < beta; i++ {
+ var index int
+ for i := 0; i < beta; i++ {
// e
+
keygen.gaussianSampler.Read(swk.Value[i][0])
ringQP.NTTLazy(swk.Value[i][0], swk.Value[i][0])
ringQP.MForm(swk.Value[i][0], swk.Value[i][0])
@@ -235,7 +250,7 @@ func (keygen *keyGenerator) newSwitchingKey(skIn, skOut *ring.Poly, swk *rlwe.Sw
// q_tild = q_star^-1 mod q_prod
//
// Therefore : (skIn * P) * (q_star * q_tild) = sk*P mod q[i*alpha+j], else 0
- for j := uint64(0); j < alpha; j++ {
+ for j := 0; j < alpha; j++ {
index = i*alpha + j
@@ -243,7 +258,7 @@ func (keygen *keyGenerator) newSwitchingKey(skIn, skOut *ring.Poly, swk *rlwe.Sw
p0tmp := keygen.polypool[0].Coeffs[index]
p1tmp := swk.Value[i][0].Coeffs[index]
- for w := uint64(0); w < ringQP.N; w++ {
+ for w := 0; w < ringQP.N; w++ {
p1tmp[w] = ring.CRed(p1tmp[w]+p0tmp[w], qi)
}
@@ -284,131 +299,250 @@ func (keygen *keyGenerator) GenRotationKeysForRotations(ks []int, includeConjuga
return keygen.GenRotationKeys(galEls, sk)
}
-// GenRotationKeysForInnerSum generates a RotationKeySet supporting the InnerSum operation of the Evaluator
-func (keygen *keyGenerator) GenRotationKeysForInnerSum(sk *SecretKey) (rks *RotationKeySet) {
- return keygen.GenRotationKeys(keygen.params.GaloisElementsForRowInnerSum(), sk)
+// GenRotationIndexesForInnerSumNaive generates the rotation indexes for the
+// InnerSumNaive. To be then used with GenRotationKeysForRotations to generate
+// the RotationKeySet.
+func (keygen *keyGenerator) GenRotationIndexesForInnerSum(batch, n int) (rotations []int) {
+ rotations = []int{}
+ for i := 1; i < n; i++ {
+ rotations = append(rotations, i*batch)
+ }
+ return
}
-// GenKeys generates the bootstrapping keys
-func (keygen *keyGenerator) GenBootstrappingKey(logSlots uint64, btpParams *BootstrappingParameters, sk *SecretKey) (btpKey *BootstrappingKey) {
+// GenRotationIndexesForInnerSum generates the rotation indexes for the
+// InnerSum. To be then used with GenRotationKeysForRotations to generate
+// the RotationKeySet.
+func (keygen *keyGenerator) GenRotationIndexesForInnerSumLog(batch, n int) (rotations []int) {
- rotUint := computeBootstrappingDFTRotationList(keygen.params.logN, logSlots, btpParams)
- rotInt := make([]int, len(rotUint), len(rotUint))
- for i, r := range rotUint {
- rotInt[i] = int(r)
+ rotations = []int{}
+ var k int
+ for i := 1; i < n; i <<= 1 {
+
+ k = i
+ k *= batch
+
+ if !utils.IsInSliceInt(k, rotations) && k != 0 {
+ rotations = append(rotations, k)
+ }
+
+ k = n - (n & ((i << 1) - 1))
+ k *= batch
+
+ if !utils.IsInSliceInt(k, rotations) && k != 0 {
+ rotations = append(rotations, k)
+ }
}
- btpKey = &BootstrappingKey{
- Rlk: keygen.GenRelinearizationKey(sk),
- Rtks: keygen.GenRotationKeysForRotations(rotInt, true, sk),
- }
-
- /*
- nbKeys := uint64(len(rotKeyIndex)) + 2 //rot keys + conj key + relin key
- nbPoly := keygen.params.Beta()
- nbCoefficients := 2 * keygen.params.N() * keygen.params.QPiCount()
- bytesPerCoeff := uint64(8)
- log.Println("Switching-Keys size (GB) :", float64(nbKeys*nbPoly*nbCoefficients*bytesPerCoeff)/float64(1000000000), "(", nbKeys, "keys)")
- */
-
return
}
-func addMatrixRotToList(pVec map[uint64]bool, rotations []uint64, N1, slots uint64, repack bool) []uint64 {
-
- var index uint64
- for j := range pVec {
-
- index = (j / N1) * N1
-
- if repack {
- // Sparse repacking, occurring during the first DFT matrix of the CoeffsToSlots.
- index &= (2*slots - 1)
- } else {
- // Other cases
- index &= (slots - 1)
- }
-
- if index != 0 && !utils.IsInSliceUint64(index, rotations) {
- rotations = append(rotations, index)
- }
-
- index = j & (N1 - 1)
-
- if index != 0 && !utils.IsInSliceUint64(index, rotations) {
- rotations = append(rotations, index)
- }
- }
-
- return rotations
+func (keygen *keyGenerator) GenRotationIndexesForReplicateLog(batch, n int) (rotations []int) {
+ return keygen.GenRotationIndexesForInnerSumLog(-batch, n)
}
-func computeBootstrappingDFTRotationList(logN, logSlots uint64, btpParams *BootstrappingParameters) (rotKeyIndex []uint64) {
+func (keygen *keyGenerator) GenRotationIndexesForReplicate(batch, n int) (rotations []int) {
+ return keygen.GenRotationIndexesForInnerSum(-batch, n)
+}
- // List of the rotation key values to needed for the bootstrapp
- rotKeyIndex = []uint64{}
+// GetRotationIndexForDiagMatrix generates of all the rotations needed for a the multiplication
+// with the diagonal plaintext matrix.
+func (keygen *keyGenerator) GenRotationIndexesForDiagMatrix(matrix *PtDiagMatrix) []int {
+ slots := 1 << matrix.LogSlots
- var slots uint64 = 1 << logSlots
- var dslots uint64 = slots
- if logSlots < logN-1 {
- dslots <<= 1
- }
+ rotKeyIndex := []int{}
- //SubSum rotation needed X -> Y^slots rotations
- for i := logSlots; i < logN-1; i++ {
- if !utils.IsInSliceUint64(1< Y^slots rotations
+ for i := logSlots; i < logN-1; i++ {
+ if !utils.IsInSliceInt(1< Y^slots rotations
+ for i := logSlots; i < logN-1; i++ {
+ if !utils.IsInSliceInt(1< 0 {
-
- hoisted := len(index[0]) - 1
- normal := len(index) - 1
-
- // The matrice is very sparse already
- if normal == 0 {
- return N1 / 2
- }
-
- if hoisted > normal {
- // Finds the next split that has a ratio hoisted/normal greater or equal to maxRatio
- for float64(hoisted)/float64(normal) < maxRatio {
-
- if normal/2 == 0 {
- break
- }
- N1 *= 2
- hoisted = hoisted*2 + 1
- normal = normal / 2
- }
- return N1
- }
- }
- }
-
- return 1
-}
-
-func genWfftIndexMap(logL, level uint64, forward bool) (vectors map[uint64]bool) {
-
- var rot uint64
-
- if forward {
+ if forward && !bitreversed || !forward && bitreversed {
rot = 1 << (level - 1)
} else {
rot = 1 << (logL - level)
}
- vectors = make(map[uint64]bool)
+ vectors = make(map[int]bool)
vectors[0] = true
vectors[rot] = true
vectors[(1< 0; i, j = i+1, j>>1 {
+
+ // Starts by decomposing the input ciphertext
+ if i == 0 {
+ // If first iteration, then copies directly from the input ciphertext that hasn't been rotated
+ eval.DecompInternal(levelQ, ctIn.value[1], eval.c2QiQDecomp, eval.c2QiPDecomp)
+ } else {
+ // Else copies from the rotated input ciphertext
+ eval.DecompInternal(levelQ, tmpc1, eval.c2QiQDecomp, eval.c2QiPDecomp)
+ }
+
+ // If the binary reading scans a 1
+ if j&1 == 1 {
+
+ k := n - (n & ((2 << i) - 1))
+ k *= batchSize
+
+ // If the rotation is not zero
+ if k != 0 {
+
+ // Rotate((tmpc0, tmpc1), k)
+ eval.permuteNTTHoistedNoModDown(levelQ, eval.c2QiQDecomp, eval.c2QiPDecomp, k, pool2Q, pool3Q, pool2P, pool3P)
+
+ // ctOut += Rotate((tmpc0, tmpc1), k)
+ if copy {
+ ringQ.CopyLvl(levelQ, pool2Q, ct0OutQ)
+ ringQ.CopyLvl(levelQ, pool3Q, ct1OutQ)
+ ringP.Copy(pool2P, ct0OutP)
+ ringP.Copy(pool3P, ct1OutP)
+ copy = false
+ } else {
+ ringQ.AddLvl(levelQ, ct0OutQ, pool2Q, ct0OutQ)
+ ringQ.AddLvl(levelQ, ct1OutQ, pool3Q, ct1OutQ)
+ ringP.Add(ct0OutP, pool2P, ct0OutP)
+ ringP.Add(ct1OutP, pool3P, ct1OutP)
+ }
+
+ if i == 0 {
+ ring.PermuteNTTWithIndexLvl(levelQ, ctIn.value[0], eval.permuteNTTIndex[eval.params.GaloisElementForColumnRotationBy(k)], tmpc2)
+ } else {
+ ring.PermuteNTTWithIndexLvl(levelQ, tmpc0, eval.permuteNTTIndex[eval.params.GaloisElementForColumnRotationBy(k)], tmpc2)
+ }
+
+ ringQ.MulScalarBigintLvl(levelQ, tmpc2, ringP.ModulusBigint, tmpc2)
+ ringQ.AddLvl(levelQ, ct0OutQ, tmpc2, ct0OutQ)
+
+ } else {
+
+ state = true
+
+ // if n is not a power of two
+ if n&(n-1) != 0 {
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, ct0OutQ, ct0OutP, ct0OutQ) // Division by P
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, ct1OutQ, ct1OutP, ct1OutQ) // Division by P
+
+ // ctOut += (tmpc0, tmpc1)
+ ringQ.AddLvl(levelQ, ct0OutQ, tmpc0, ctOut.value[0])
+ ringQ.AddLvl(levelQ, ct1OutQ, tmpc1, ctOut.value[1])
+
+ } else {
+ ringQ.CopyLvl(levelQ, tmpc0, ctOut.value[0])
+ ringQ.CopyLvl(levelQ, tmpc1, ctOut.value[1])
+ }
+ }
+ }
+
+ if !state {
+ if i == 0 {
+ eval.permuteNTTHoisted(levelQ, ctIn.value[0], ctIn.value[1], eval.c2QiQDecomp, eval.c2QiPDecomp, (1<> 1
+ PiOverF := eval.params.PiOverflowMargin() >> 1
+
+ // If sum with only the first element, then returns the input
+ if n == 1 {
+
+ // If the input-output points differ, copies on the output
+ if ctIn != ctOut {
+ ringQ.CopyLvl(levelQ, ctIn.value[0], ctOut.value[0])
+ ringQ.CopyLvl(levelQ, ctIn.value[1], ctOut.value[1])
+ }
+ // If sum on at least two elements
+ } else {
+
+ // List of n-2 rotations
+ rotations := []int{}
+ for i := 1; i < n; i++ {
+ rotations = append(rotations, i*batchSize)
+ }
+
+ // Memory pool
+ tmpQ0 := eval.poolQMul[0] // unused memory pool from evaluator
+ tmpQ1 := eval.poolQMul[1] // unused memory pool from evaluator
+
+ pool2P := eval.poolP[1] // ctOut(c0', c1') from evaluator keyswitch memory pool
+ pool3P := eval.poolP[2] // ctOut(c0', c1') from evaluator keyswitch memory pool
+
+ // Basis decomposition
+ eval.DecompInternal(levelQ, ctIn.value[1], eval.c2QiQDecomp, eval.c2QiPDecomp)
+
+ // Pre-rotates all [1, ..., n-1] rotations
+ // Hoisted rotation without division by P
+ vecRotQ, vecRotP := eval.rotateHoistedNoModDown(ctIn, rotations, eval.c2QiQDecomp, eval.c2QiPDecomp)
+
+ // P*c0 -> tmpQ0
+ ringQ.MulScalarBigintLvl(levelQ, ctIn.value[0], ringP.ModulusBigint, tmpQ0)
+
+ // Adds phi_k(P*c0) on each of the vecRotQ
+ // Does not need to add on the vecRotP because mod P === 0
+ for _, i := range rotations {
+ if i != 0 {
+
+ galEl := eval.params.GaloisElementForColumnRotationBy(i)
+
+ _, generated := eval.rtks.Keys[galEl]
+ if !generated {
+ panic("switching key not available")
+ }
+
+ index := eval.permuteNTTIndex[galEl]
+
+ ring.PermuteNTTWithIndexLvl(levelQ, tmpQ0, index, tmpQ1) // phi(P*c0)
+ ringQ.AddLvl(levelQ, vecRotQ[i][0], tmpQ1, vecRotQ[i][0]) // phi(d0_Q) += phi(P*c0)
+ }
+ }
+
+ var reduce int
+ // Sums elements [2, ..., n-1]
+ for i := 1; i < n; i++ {
+
+ j := i * batchSize
+
+ if i == 1 {
+ ringQ.CopyLvl(levelQ, vecRotQ[j][0], tmpQ0)
+ ringQ.CopyLvl(levelQ, vecRotQ[j][1], tmpQ1)
+ ringP.Copy(vecRotP[j][0], pool2P)
+ ringP.Copy(vecRotP[j][1], pool3P)
+ } else {
+ ringQ.AddNoModLvl(levelQ, tmpQ0, vecRotQ[j][0], tmpQ0)
+ ringQ.AddNoModLvl(levelQ, tmpQ1, vecRotQ[j][1], tmpQ1)
+ ringP.AddNoMod(pool2P, vecRotP[j][0], pool2P)
+ ringP.AddNoMod(pool3P, vecRotP[j][1], pool3P)
+ }
+
+ if reduce%QiOverF == QiOverF-1 {
+ ringQ.ReduceLvl(levelQ, tmpQ0, tmpQ0)
+ ringQ.ReduceLvl(levelQ, tmpQ1, tmpQ1)
+ }
+
+ if reduce%PiOverF == PiOverF-1 {
+ ringP.Reduce(pool2P, pool2P)
+ ringP.Reduce(pool3P, pool3P)
+ }
+
+ reduce++
+ }
+
+ if reduce%QiOverF != 0 {
+ ringQ.ReduceLvl(levelQ, tmpQ0, tmpQ0)
+ ringQ.ReduceLvl(levelQ, tmpQ1, tmpQ1)
+ }
+
+ if reduce%PiOverF != 0 {
+ ringP.Reduce(pool2P, pool2P)
+ ringP.Reduce(pool3P, pool3P)
+ }
+
+ // Division by P of sum(elements [2, ..., n-1] )
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ0, pool2P, tmpQ0) // sum_{i=1, n-1}(phi(d0))/P
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ1, pool3P, tmpQ1) // sum_{i=1, n-1}(phi(d1))/P
+
+ // Adds element[1] (which did not require rotation)
+ ringQ.AddLvl(levelQ, ctIn.value[0], tmpQ0, ctOut.value[0]) // sum_{i=1, n-1}(phi(d0))/P + ct0
+ ringQ.AddLvl(levelQ, ctIn.value[1], tmpQ1, ctOut.value[1]) // sum_{i=1, n-1}(phi(d1))/P + ct1
+ }
+}
+
+// ReplicateLog applies an optimized replication on the ciphetext (log2(n) + HW(n) rotations with double hoisting).
+// It acts as the inverse of a inner sum (summing elements from left to right).
+// The replication is parameterized by the size of the sub-vectors to replicate "batchSize" and
+// the number of time "n" they need to be replicated.
+// To ensure correctness, a gap of zero values of size batchSize * (n-1) must exist between
+// two consecutive sub-vectors to replicate.
+// This method is faster than Replicate when the number of rotations is large and uses log2(n) + HW(n) instead of 'n'.
+func (eval *evaluator) ReplicateLog(ctIn *Ciphertext, batchSize, n int, ctOut *Ciphertext) {
+ eval.InnerSumLog(ctIn, -batchSize, n, ctOut)
+}
+
+// Replicate applies naive replication on the ciphetext (n rotations with single hoisting).
+// It acts as the inverse of a inner sum (summing elements from left to right).
+// The replication is parameterized by the size of the sub-vectors to replicate "batchSize" and
+// the number of time "n" they need to be replicated.
+// To ensure correctness, a gap of zero values of size batchSize * (n-1) must exist between
+// two consecutive sub-vectors to replicate.
+// This method is faster than ReplicateLog when the number of rotations is small but uses 'n' keys instead of log2(n) + HW(n).
+func (eval *evaluator) Replicate(ctIn *Ciphertext, batchSize, n int, ctOut *Ciphertext) {
+ eval.InnerSum(ctIn, -batchSize, n, ctOut)
+}
+
+// MultiplyByDiagMatrix multiplies the ciphertext "ctIn" by the plaintext matrix "matrix" and returns the result on the ciphertext
+// "ctOut". Memory pools for the decomposed ciphertext c2QiQDecomp, c2QiPDecomp must be provided, those are list of poly of ringQ and ringP
+// respectively, each of size params.Beta().
+// The naive approach is used (single hoisting and no baby-step giant-step), which is faster than MultiplyByDiagMatrixBSGS
+// for matrix of only a few non-zero diagonals but uses more keys.
+func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix *PtDiagMatrix, c2QiQDecomp, c2QiPDecomp []*ring.Poly, ctOut *Ciphertext) {
+
+ ringQ := eval.ringQ
+ ringP := eval.ringP
+
+ levelQ := utils.MinInt(ctOut.Level(), utils.MinInt(ctIn.Level(), matrix.Level))
+ levelP := eval.params.PiCount() - 1
+
+ QiOverF := eval.params.QiOverflowMargin(levelQ)
+ PiOverF := eval.params.PiOverflowMargin()
+
+ ksResP0 := eval.poolP[0] // Key-Switch ctOut[0] mod P
+ ksResP1 := eval.poolP[1] // Key-Switch ctOut[1] mod P
+ tmpP0 := eval.poolP[2] // Automorphism not-inplace pool res[0] mod P
+ tmpP1 := eval.poolQMul[0] // Automorphism not-inplace pool res[1] mod P
+ accP0 := eval.poolP[3] // Accumulator ctOut[0] mod P
+ accP1 := eval.poolP[4] // Accumulator ctOut[1] mod P
+
+ ct0TimesP := eval.poolQ[0] // ct0 * P mod Q
+ ksResQ0 := eval.poolQ[1] // Key-Switch ctOut[0] mod Q
+ ksResQ1 := eval.poolQ[2] // Key-Switch ctOut[0] mod Q
+ tmpQ0 := eval.poolQ[3] // Automorphism not-inplace pool ctOut[0] mod Q
+ tmpQ1 := eval.poolQ[4] // Automorphism not-inplace pool ctOut[1] mod Q
+
+ ringQ.MulScalarBigintLvl(levelQ, ctIn.value[0], ringP.ModulusBigint, ct0TimesP) // P*c0
+
+ var state bool
+ var cnt int
+ for k := range matrix.Vec {
+
+ k &= int((ringQ.N >> 1) - 1)
+
+ if k == 0 {
+ state = true
+ } else {
+
+ galEl := eval.params.GaloisElementForColumnRotationBy(k)
+
+ rtk, generated := eval.rtks.Keys[galEl]
+ if !generated {
+ panic("switching key not available")
+ }
+
+ index := eval.permuteNTTIndex[galEl]
+
+ eval.keyswitchHoistedNoModDown(levelQ, c2QiQDecomp, c2QiPDecomp, rtk, ksResQ0, ksResQ1, ksResP0, ksResP1)
+
+ ringQ.AddLvl(levelQ, ksResQ0, ct0TimesP, ksResQ0) // phi(d0_Q) += phi(P*c0)
+
+ ring.PermuteNTTWithIndexLvl(levelQ, ksResQ0, index, tmpQ0) // phi(P*c0 + d0_Q)
+ ring.PermuteNTTWithIndexLvl(levelQ, ksResQ1, index, tmpQ1) // phi( d1_Q)
+
+ ring.PermuteNTTWithIndexLvl(levelP, ksResP0, index, tmpP0) // phi(P*c0 + d0_P)
+ ring.PermuteNTTWithIndexLvl(levelP, ksResP1, index, tmpP1) // phi( d1_P)
+
+ plaintextQ := matrix.Vec[k][0]
+ plaintextP := matrix.Vec[k][1]
+
+ if cnt == 0 {
+ // keyswitch(c1_Q) = (d0_QP, d1_QP)
+ ringQ.MulCoeffsMontgomeryLvl(levelQ, plaintextQ, tmpQ0, ctOut.value[0]) // phi(P*c0 + d0_Q) * plaintext
+ ringQ.MulCoeffsMontgomeryLvl(levelQ, plaintextQ, tmpQ1, ctOut.value[1]) // phi(d1_Q) * plaintext
+ ringP.MulCoeffsMontgomery(plaintextP, tmpP0, accP0) // phi(d0_P) * plaintext
+ ringP.MulCoeffsMontgomery(plaintextP, tmpP1, accP1) // phi(d1_P) * plaintext
+ } else {
+ // keyswitch(c1_Q) = (d0_QP, d1_QP)
+ ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, plaintextQ, tmpQ0, ctOut.value[0]) // phi(P*c0 + d0_Q) * plaintext
+ ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, plaintextQ, tmpQ1, ctOut.value[1]) // phi(d1_Q) * plaintext
+ ringP.MulCoeffsMontgomeryAndAdd(plaintextP, tmpP0, accP0) // phi(d0_P) * plaintext
+ ringP.MulCoeffsMontgomeryAndAdd(plaintextP, tmpP1, accP1) // phi(d1_P) * plaintext
+ }
+
+ if cnt%QiOverF == QiOverF-1 {
+ ringQ.ReduceLvl(levelQ, ctOut.value[0], ctOut.value[0])
+ ringQ.ReduceLvl(levelQ, ctOut.value[1], ctOut.value[1])
+ }
+
+ if cnt%PiOverF == PiOverF-1 {
+ ringP.Reduce(accP0, accP0)
+ ringP.Reduce(accP1, accP1)
+ }
+
+ cnt++
+ }
+ }
+
+ if cnt%QiOverF == 0 {
+ ringQ.ReduceLvl(levelQ, ctOut.value[0], ctOut.value[0])
+ ringQ.ReduceLvl(levelQ, ctOut.value[1], ctOut.value[1])
+ }
+
+ if cnt%PiOverF == 0 {
+ ringP.Reduce(accP0, accP0)
+ ringP.Reduce(accP1, accP1)
+ }
+
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, ctOut.value[0], accP0, ctOut.value[0]) // sum(phi(c0 * P + d0_QP))/P
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, ctOut.value[1], accP1, ctOut.value[1]) // sum(phi(d1_QP))/P
+
+ if state { // Rotation by zero
+ ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0][0], ctIn.value[0], ctOut.value[0]) // ctOut += c0_Q * plaintext
+ ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0][0], ctIn.value[1], ctOut.value[1]) // ctOut += c1_Q * plaintext
+ }
+
+ ctOut.SetScale(matrix.Scale * ctIn.Scale())
+}
+
+// MultiplyByDiagMatrixBSGS multiplies the ciphertext "ctIn" by the plaintext matrix "matrix" and returns the result on the ciphertext
+// "ctOut". Memory pools for the decomposed ciphertext c2QiQDecomp, c2QiPDecomp must be provided, those are list of poly of ringQ and ringP
+// respectively, each of size params.Beta().
+// The BSGS approach is used (double hoisting with baby-step giant-step), which is faster than MultiplyByDiagMatrix
+// for matrix with more than a few non-zero diagonals and uses much less keys.
+func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix *PtDiagMatrix, c2QiQDecomp, c2QiPDecomp []*ring.Poly, ctOut *Ciphertext) {
+
+ // N1*N2 = N
+ N1 := matrix.N1
+
+ ringQ := eval.ringQ
+ ringP := eval.ringP
+
+ levelQ := utils.MinInt(ctOut.Level(), utils.MinInt(ctIn.Level(), matrix.Level))
+ levelP := eval.params.PiCount() - 1
+
+ QiOverF := eval.params.QiOverflowMargin(levelQ)
+ PiOverF := eval.params.PiOverflowMargin()
+
+ // Computes the rotations indexes of the non-zero rows of the diagonalized DFT matrix for the baby-step giang-step algorithm
+
+ index, rotations := bsgsIndex(matrix.Vec, 1<>1) == (QiOverF>>1)-1 {
+ ringQ.ReduceLvl(levelQ, tmpQ0, tmpQ0)
+ ringQ.ReduceLvl(levelQ, tmpQ1, tmpQ1)
+ }
+
+ if cnt1%(PiOverF>>1) == (PiOverF>>1)-1 {
+ ringP.Reduce(pool2P, pool2P)
+ ringP.Reduce(pool3P, pool3P)
+ }
+
+ cnt1++
+ }
+ }
+
+ if cnt1%(QiOverF>>1) != 0 {
+ ringQ.ReduceLvl(levelQ, tmpQ0, tmpQ0)
+ ringQ.ReduceLvl(levelQ, tmpQ1, tmpQ1)
+ }
+
+ if cnt1%(PiOverF>>1) != 0 {
+ ringP.Reduce(pool2P, pool2P)
+ ringP.Reduce(pool3P, pool3P)
+ }
+
+ // Hoisting of the ModDown of sum(sum(phi(d0 + P*c0) * plaintext)) and sum(sum(phi(d1) * plaintext))
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ0, pool2P, tmpQ0) // sum(phi(d0) * plaintext)/P
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ1, pool3P, tmpQ1) // sum(phi(d1) * plaintext)/P
+
+ // If i == 0
+ if state {
+
+ // If no loop before, then we copy the values on the accumulator instead of adding them
+ if len(index[j]) == 1 {
+ ringQ.MulCoeffsMontgomeryLvl(levelQ, matrix.Vec[N1*j][0], ctIn.value[0], tmpQ0) // c0 * plaintext + sum(phi(d0) * plaintext)/P + phi(c0) * plaintext mod Q
+ ringQ.MulCoeffsMontgomeryLvl(levelQ, matrix.Vec[N1*j][0], ctIn.value[1], tmpQ1) // c1 * plaintext + sum(phi(d1) * plaintext)/P + phi(c1) * plaintext mod Q
+ } else {
+ ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[N1*j][0], ctIn.value[0], tmpQ0) // c0 * plaintext + sum(phi(d0) * plaintext)/P + phi(c0) * plaintext mod Q
+ ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[N1*j][0], ctIn.value[1], tmpQ1) // c1 * plaintext + sum(phi(d1) * plaintext)/P + phi(c1) * plaintext mod Q
+ }
+
+ N1Rot++
+ }
+
+ galEl := eval.params.GaloisElementForColumnRotationBy(N1 * j)
+
+ rtk, generated := eval.rtks.Keys[galEl]
+ if !generated {
+ panic("switching key not available")
+ }
+
+ index := eval.permuteNTTIndex[galEl]
+
+ eval.SwitchKeysInPlaceNoModDown(levelQ, tmpQ1, rtk, pool2Q, pool2P, pool3Q, pool3P) // Switchkey(phi(tmpRes_1)) = (d0, d1) in base QP
+
+ // Outer loop rotations
+ ring.PermuteNTTWithIndexLvl(levelQ, tmpQ0, index, tmpQ1) // phi(tmpRes_0)
+ ringQ.AddLvl(levelQ, ctOut.value[0], tmpQ1, ctOut.value[0]) // ctOut += phi(tmpRes)
+
+ N2Rot++
+
+ if cnt0 == 0 {
+ ring.PermuteNTTWithIndexLvl(levelQ, pool2Q, index, tmpQ2) // sum(phi(d0_Q))
+ ring.PermuteNTTWithIndexLvl(levelQ, pool3Q, index, tmpQ3) // sum(phi(d1_Q))
+ ring.PermuteNTTWithIndexLvl(levelP, pool2P, index, tmpP2) // sum(phi(d0_P))
+ ring.PermuteNTTWithIndexLvl(levelP, pool3P, index, tmpP3) // sum(phi(d1_P))
+ } else {
+ ring.PermuteNTTWithIndexAndAddNoModLvl(levelQ, pool2Q, index, tmpQ2) // sum(phi(d0_Q))
+ ring.PermuteNTTWithIndexAndAddNoModLvl(levelQ, pool3Q, index, tmpQ3) // sum(phi(d1_Q))
+ ring.PermuteNTTWithIndexAndAddNoModLvl(levelP, pool2P, index, tmpP2) // sum(phi(d0_P))
+ ring.PermuteNTTWithIndexAndAddNoModLvl(levelP, pool3P, index, tmpP3) // sum(phi(d1_P))
+ }
+
+ if cnt0%QiOverF == QiOverF-1 {
+ ringQ.ReduceLvl(levelQ, tmpQ2, tmpQ2)
+ ringQ.ReduceLvl(levelQ, tmpQ3, tmpQ3)
+ }
+
+ if cnt0%PiOverF == PiOverF-1 {
+ ringP.Reduce(tmpP2, tmpP2)
+ ringP.Reduce(tmpP3, tmpP3)
+ }
+
+ cnt0++
+ }
+ }
+
+ if cnt0%QiOverF != 0 {
+ ringQ.ReduceLvl(levelQ, tmpQ2, tmpQ2)
+ ringQ.ReduceLvl(levelQ, tmpQ3, tmpQ3)
+ }
+
+ if cnt0%PiOverF != 0 {
+ ringP.Reduce(tmpP2, tmpP2)
+ ringP.Reduce(tmpP3, tmpP3)
+ }
+
+ // if j == 0 (N2 rotation by zero)
+ var state bool
+ var cnt1 int
+ for _, i := range index[0] {
+
+ if i == 0 {
+ state = true
+ } else {
+
+ plaintextQ := matrix.Vec[i][0]
+ plaintextP := matrix.Vec[i][1]
+ N1Rot++
+ // keyswitch(c1_Q) = (d0_QP, d1_QP)
+ ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, plaintextQ, vecRotQ[i][0], tmpQ2) // phi(P*c0 + d0_Q) * plaintext
+ ringQ.MulCoeffsMontgomeryConstantAndAddNoModLvl(levelQ, plaintextQ, vecRotQ[i][1], tmpQ3) // phi(d1_Q) * plaintext
+ ringP.MulCoeffsMontgomeryAndAddNoMod(plaintextP, vecRotP[i][0], tmpP2) // phi(d0_P) * plaintext
+ ringP.MulCoeffsMontgomeryAndAddNoMod(plaintextP, vecRotP[i][1], tmpP3) // phi(d1_P) * plaintext
+
+ if cnt1%(QiOverF>>1) == (QiOverF>>1)-1 {
+ ringQ.ReduceLvl(levelQ, tmpQ2, tmpQ2)
+ ringQ.ReduceLvl(levelQ, tmpQ3, tmpQ3)
+ }
+
+ if cnt1%(PiOverF>>1) == (PiOverF>>1)-1 {
+ ringP.Reduce(tmpP2, tmpP2)
+ ringP.Reduce(tmpP3, tmpP3)
+ }
+
+ cnt1++
+ }
+ }
+
+ if cnt1%(QiOverF>>1) != 0 {
+ ringQ.ReduceLvl(levelQ, tmpQ2, tmpQ2)
+ ringQ.ReduceLvl(levelQ, tmpQ3, tmpQ3)
+ }
+
+ if cnt1%(PiOverF>>1) != 0 {
+ ringP.Reduce(tmpP2, tmpP2)
+ ringP.Reduce(tmpP3, tmpP3)
+ }
+
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ2, tmpP2, tmpQ2) // sum(phi(c0 * P + d0_QP))/P
+ eval.baseconverter.ModDownSplitNTTPQ(levelQ, tmpQ3, tmpP3, tmpQ3) // sum(phi(d1_QP))/P
+
+ ringQ.AddLvl(levelQ, ctOut.value[0], tmpQ2, ctOut.value[0]) // ctOut += sum(phi(c0 * P + d0_QP))/P
+ ringQ.AddLvl(levelQ, ctOut.value[1], tmpQ3, ctOut.value[1]) // ctOut += sum(phi(d1_QP))/P
+
+ if state { // Rotation by zero
+ N1Rot++
+ ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0][0], ctIn.value[0], ctOut.value[0]) // ctOut += c0_Q * plaintext
+ ringQ.MulCoeffsMontgomeryAndAddLvl(levelQ, matrix.Vec[0][0], ctIn.value[1], ctOut.value[1]) // ctOut += c1_Q * plaintext
+ }
+
+ ctOut.SetScale(matrix.Scale * ctIn.Scale())
+
+ vecRotQ, vecRotP = nil, nil
+}
diff --git a/ckks/marshaler.go b/ckks/marshaler.go
index 8e807971..b9a98ce1 100644
--- a/ckks/marshaler.go
+++ b/ckks/marshaler.go
@@ -9,7 +9,7 @@ import (
)
// GetDataLen returns the length in bytes of the target Ciphertext.
-func (ciphertext *Ciphertext) GetDataLen(WithMetaData bool) (dataLen uint64) {
+func (ciphertext *Ciphertext) GetDataLen(WithMetaData bool) (dataLen int) {
// MetaData is :
// 1 byte : Degree
// 9 byte : Scale
@@ -39,7 +39,7 @@ func (ciphertext *Ciphertext) MarshalBinary() (data []byte, err error) {
data[10] = 1
}
- var pointer, inc uint64
+ var pointer, inc int
pointer = 11
@@ -56,8 +56,6 @@ func (ciphertext *Ciphertext) MarshalBinary() (data []byte, err error) {
}
// UnmarshalBinary decodes a previously marshaled Ciphertext on the target Ciphertext.
-// The target Ciphertext must be of the appropriate format and size, it can be created with the
-// method NewCiphertext(uint64).
func (ciphertext *Ciphertext) UnmarshalBinary(data []byte) (err error) {
if len(data) < 11 { // cf. ciphertext.GetDataLen()
return errors.New("too small bytearray")
@@ -73,7 +71,7 @@ func (ciphertext *Ciphertext) UnmarshalBinary(data []byte) (err error) {
ciphertext.isNTT = true
}
- var pointer, inc uint64
+ var pointer, inc int
pointer = 11
for i := range ciphertext.value {
@@ -87,7 +85,7 @@ func (ciphertext *Ciphertext) UnmarshalBinary(data []byte) (err error) {
pointer += inc
}
- if pointer != uint64(len(data)) {
+ if pointer != len(data) {
return errors.New("remaining unparsed data")
}
diff --git a/ckks/operand.go b/ckks/operand.go
index 72d01d93..14a7d35a 100644
--- a/ckks/operand.go
+++ b/ckks/operand.go
@@ -10,8 +10,8 @@ import (
type Operand interface {
El() *Element
IsNTT() bool
- Degree() uint64
- Level() uint64
+ Degree() int
+ Level() int
Scale() float64
}
@@ -38,13 +38,13 @@ func (el *Element) SetValue(value []*ring.Poly) {
}
// Degree returns the degree of the target element.
-func (el *Element) Degree() uint64 {
- return uint64(len(el.value) - 1)
+func (el *Element) Degree() int {
+ return len(el.value) - 1
}
// Level returns the level of the target element.
-func (el *Element) Level() uint64 {
- return uint64(len(el.value[0].Coeffs) - 1)
+func (el *Element) Level() int {
+ return len(el.value[0].Coeffs) - 1
}
// Scale returns the scale of the target element.
@@ -68,14 +68,14 @@ func (el *Element) DivScale(scale float64) {
}
// Resize resizes the degree of the target element.
-func (el *Element) Resize(params *Parameters, degree uint64) {
+func (el *Element) Resize(params *Parameters, degree int) {
if el.Degree() > degree {
el.value = el.value[:degree+1]
} else if el.Degree() < degree {
for el.Degree() < degree {
el.value = append(el.value, []*ring.Poly{new(ring.Poly)}...)
el.value[el.Degree()].Coeffs = make([][]uint64, el.Level()+1)
- for i := uint64(0); i < el.Level()+1; i++ {
+ for i := 0; i < el.Level()+1; i++ {
el.value[el.Degree()].Coeffs[i] = make([]uint64, params.N())
}
}
diff --git a/ckks/params.go b/ckks/params.go
index b2f4710c..39306fde 100644
--- a/ckks/params.go
+++ b/ckks/params.go
@@ -232,14 +232,14 @@ func (m *LogModuli) Copy() LogModuli {
type Parameters struct {
qi []uint64
pi []uint64
- logN uint64 // Ring degree (power of 2)
- logSlots uint64
+ logN int // Ring degree (power of 2)
+ logSlots int
scale float64
sigma float64 // Gaussian sampling variance
}
// NewParametersFromModuli creates a new Parameters struct and returns a pointer to it.
-func NewParametersFromModuli(logN uint64, m *Moduli) (p *Parameters, err error) {
+func NewParametersFromModuli(logN int, m *Moduli) (p *Parameters, err error) {
p = new(Parameters)
if (logN < MinLogN) || (logN > MaxLogN) {
@@ -264,7 +264,7 @@ func NewParametersFromModuli(logN uint64, m *Moduli) (p *Parameters, err error)
}
// NewParametersFromLogModuli creates a new Parameters struct and returns a pointer to it.
-func NewParametersFromLogModuli(logN uint64, lm *LogModuli) (p *Parameters, err error) {
+func NewParametersFromLogModuli(logN int, lm *LogModuli) (p *Parameters, err error) {
if err = checkLogModuli(lm); err != nil {
return nil, err
@@ -290,42 +290,42 @@ func (p *Parameters) NewPolyQP() *ring.Poly {
}
// N returns the ring degree
-func (p *Parameters) N() uint64 {
+func (p *Parameters) N() int {
return 1 << p.logN
}
// LogN returns the log of the degree of the polynomial ring
-func (p *Parameters) LogN() uint64 {
+func (p *Parameters) LogN() int {
return p.logN
}
// LogSlots returns the log of the number of slots
-func (p *Parameters) LogSlots() uint64 {
+func (p *Parameters) LogSlots() int {
return p.logSlots
}
// MaxLevel returns the maximum ciphertext level
-func (p *Parameters) MaxLevel() uint64 {
+func (p *Parameters) MaxLevel() int {
return p.QiCount() - 1
}
// Levels returns then number of total levels enabled by the parameters
-func (p *Parameters) Levels() uint64 {
+func (p *Parameters) Levels() int {
return p.QiCount()
}
// Slots returns number of available plaintext slots
-func (p *Parameters) Slots() uint64 {
+func (p *Parameters) Slots() int {
return 1 << p.logSlots
}
// MaxSlots returns the theoretical maximum of plaintext slots allowed by the ring degree
-func (p *Parameters) MaxSlots() uint64 {
+func (p *Parameters) MaxSlots() int {
return p.N() >> 1
}
// MaxLogSlots returns the log of the maximum number of slots enabled by the parameters
-func (p *Parameters) MaxLogSlots() uint64 {
+func (p *Parameters) MaxLogSlots() int {
return p.logN - 1
}
@@ -345,7 +345,7 @@ func (p *Parameters) SetScale(scale float64) {
}
// SetLogSlots sets the value logSlots of the parameters.
-func (p *Parameters) SetLogSlots(logSlots uint64) {
+func (p *Parameters) SetLogSlots(logSlots int) {
if (logSlots == 0) || (logSlots > p.MaxLogSlots()) {
panic(fmt.Errorf("slots cannot be greater than LogN-1"))
}
@@ -376,6 +376,18 @@ func (p *Parameters) LogModuli() (lm *LogModuli) {
return
}
+// QiOverflowMargin returns floor(2^64 / max(Qi)), i.e. the number of times elements of Z_max{Qi} can
+// be added together before overflowing 2^64.
+func (p *Parameters) QiOverflowMargin(level int) int {
+ return int(math.Exp2(64) / float64(utils.MaxSliceUint64(p.qi[:level+1])))
+}
+
+// PiOverflowMargin returns floor(2^64 / max(Pi)), i.e. the number of times elements of Z_max{Pi} can
+// be added together before overflowing 2^64.
+func (p *Parameters) PiOverflowMargin() int {
+ return int(math.Exp2(64) / float64(utils.MaxSliceUint64(p.pi)))
+}
+
// Moduli returns a struct Moduli with the moduli of the parameters
func (p *Parameters) Moduli() (m *Moduli) {
m = new(Moduli)
@@ -394,8 +406,8 @@ func (p *Parameters) Qi() []uint64 {
}
// QiCount returns the number of factors of the ciphertext modulus Q
-func (p *Parameters) QiCount() uint64 {
- return uint64(len(p.qi))
+func (p *Parameters) QiCount() int {
+ return len(p.qi)
}
// Pi returns a new slice with the factors of the ciphertext modulus extension P
@@ -406,17 +418,17 @@ func (p *Parameters) Pi() []uint64 {
}
// PiCount returns the number of factors of the ciphertext modulus extension P
-func (p *Parameters) PiCount() uint64 {
- return uint64(len(p.pi))
+func (p *Parameters) PiCount() int {
+ return len(p.pi)
}
// QPiCount returns the number of factors of the ciphertext modulus + the modulus extension P
-func (p *Parameters) QPiCount() uint64 {
- return uint64(len(p.qi) + len(p.pi))
+func (p *Parameters) QPiCount() int {
+ return len(p.qi) + len(p.pi)
}
// LogQP returns the size of the extended modulus QP in bits
-func (p *Parameters) LogQP() uint64 {
+func (p *Parameters) LogQP() int {
tmp := ring.NewUint(1)
for _, qi := range p.qi {
tmp.Mul(tmp, ring.NewUint(qi))
@@ -424,17 +436,17 @@ func (p *Parameters) LogQP() uint64 {
for _, pi := range p.pi {
tmp.Mul(tmp, ring.NewUint(pi))
}
- return uint64(tmp.BitLen())
+ return tmp.BitLen()
}
// LogQLvl returns the size of the modulus Q in bits at a specific level
-func (p *Parameters) LogQLvl(level uint64) uint64 {
+func (p *Parameters) LogQLvl(level int) int {
tmp := p.QLvl(level)
- return uint64(tmp.BitLen())
+ return tmp.BitLen()
}
// QLvl returns the product of the moduli at the given level as a big.Int
-func (p *Parameters) QLvl(level uint64) *big.Int {
+func (p *Parameters) QLvl(level int) *big.Int {
tmp := ring.NewUint(1)
for _, qi := range p.qi[:level+1] {
tmp.Mul(tmp, ring.NewUint(qi))
@@ -443,7 +455,7 @@ func (p *Parameters) QLvl(level uint64) *big.Int {
}
// LogQ returns the size of the modulus Q in bits
-func (p *Parameters) LogQ() uint64 {
+func (p *Parameters) LogQ() int {
return p.LogQLvl(p.QiCount() - 1)
}
@@ -453,12 +465,12 @@ func (p *Parameters) Q() *big.Int {
}
// LogP returns the size of the modulus P in bits
-func (p *Parameters) LogP() uint64 {
+func (p *Parameters) LogP() int {
tmp := ring.NewUint(1)
for _, pi := range p.pi {
tmp.Mul(tmp, ring.NewUint(pi))
}
- return uint64(tmp.BitLen())
+ return tmp.BitLen()
}
// LogQAlpha returns the size in bits of the sum of the norm of
@@ -468,7 +480,7 @@ func (p *Parameters) LogP() uint64 {
// error during the keyswitching and then divided by P.
// LogQAlpha should be smaller than P or the error added during
// the key-switching wont be negligible.
-func (p *Parameters) LogQAlpha() uint64 {
+func (p *Parameters) LogQAlpha() int {
alpha := p.PiCount()
@@ -477,8 +489,8 @@ func (p *Parameters) LogQAlpha() uint64 {
}
res := ring.NewUint(0)
- var j uint64
- for i := uint64(0); i < p.QiCount(); i = i + alpha {
+ var j int
+ for i := 0; i < p.QiCount(); i = i + alpha {
j = i + alpha
if j > p.QiCount() {
@@ -493,18 +505,18 @@ func (p *Parameters) LogQAlpha() uint64 {
res.Add(res, tmp)
}
- return uint64(res.BitLen())
+ return res.BitLen()
}
// Alpha returns the number of moduli in in P
-func (p *Parameters) Alpha() uint64 {
+func (p *Parameters) Alpha() int {
return p.PiCount()
}
// Beta returns the number of element in the RNS decomposition basis: Ceil(lenQi / lenPi)
-func (p *Parameters) Beta() uint64 {
+func (p *Parameters) Beta() int {
if p.Alpha() != 0 {
- return uint64(math.Ceil(float64(p.QiCount()) / float64(p.Alpha())))
+ return int(math.Ceil(float64(p.QiCount()) / float64(p.Alpha())))
}
return 0
@@ -516,8 +528,8 @@ func (p *Parameters) Beta() uint64 {
func (p *Parameters) GaloisElementForColumnRotationBy(k int) uint64 {
twoN := 1 << (p.logN + 1)
mask := twoN - 1
- kRed := uint64(k & mask)
- return ring.ModExp(GaloisGen, kRed, uint64(twoN))
+ kRed := k & mask
+ return ring.ModExp(uint64(GaloisGen), kRed, uint64(twoN))
}
// GaloisElementForRowRotation returns the galois element corresponding to a row rotation (conjugate) automorphism
@@ -525,9 +537,8 @@ func (p *Parameters) GaloisElementForRowRotation() uint64 {
return (1 << (p.logN + 1)) - 1
}
-// GaloisElementsForRowInnerSum returns a list of all galois elements required to
-// perform an InnerSum operation. This corresponds to all the left rotations by
-// k-positions where k is a power of two and the conjugate element.
+// GaloisElementsForRowInnerSum returns a list of galois element corresponding to
+// all the left rotations by a k-position where k is a power of two.
func (p *Parameters) GaloisElementsForRowInnerSum() (galEls []uint64) {
galEls = make([]uint64, p.logN+1, p.logN+1)
galEls[0] = p.GaloisElementForRowRotation()
@@ -539,8 +550,8 @@ func (p *Parameters) GaloisElementsForRowInnerSum() (galEls []uint64) {
// InverseGaloisElement returns the galois element for the inverse automorphism of galEl
func (p *Parameters) InverseGaloisElement(galEl uint64) uint64 {
- twoN := uint64(1 << (p.logN + 1))
- return ring.ModExp(galEl, twoN-1, twoN)
+ twoN := 1 << (p.logN + 1)
+ return ring.ModExp(galEl, twoN-1, uint64(twoN))
}
// Copy creates a copy of the target parameters.
@@ -610,13 +621,13 @@ func (p *Parameters) UnmarshalBinary(data []byte) (err error) {
b := utils.NewBuffer(data)
- p.logN = uint64(b.ReadUint8())
+ p.logN = int(b.ReadUint8())
if p.logN > MaxLogN {
return fmt.Errorf("LogN larger than %d", MaxLogN)
}
- p.logSlots = uint64(b.ReadUint8())
+ p.logSlots = int(b.ReadUint8())
if p.logSlots > p.logN-1 {
return fmt.Errorf("LogSlots larger than %d", MaxLogN-1)
@@ -641,7 +652,7 @@ func (p *Parameters) UnmarshalBinary(data []byte) (err error) {
return nil
}
-func checkModuli(m *Moduli, logN uint64) error {
+func checkModuli(m *Moduli, logN int) error {
if len(m.Qi) > MaxModuliCount {
return fmt.Errorf("#Qi is larger than %d", MaxModuliCount)
@@ -663,16 +674,16 @@ func checkModuli(m *Moduli, logN uint64) error {
}
}
- N := uint64(1 << logN)
+ N := 1 << logN
for i, qi := range m.Qi {
- if !ring.IsPrime(qi) || qi&((N<<1)-1) != 1 {
+ if !ring.IsPrime(qi) || qi&uint64((N<<1)-1) != 1 {
return fmt.Errorf("Qi (i=%d) is not an NTT prime", i)
}
}
for i, pi := range m.Pi {
- if !ring.IsPrime(pi) || pi&((N<<1)-1) != 1 {
+ if !ring.IsPrime(pi) || pi&uint64((N<<1)-1) != 1 {
return fmt.Errorf("Pi (i=%d) is not an NTT prime", i)
}
}
@@ -705,12 +716,12 @@ func checkLogModuli(m *LogModuli) error {
return nil
}
-func genModuli(lm *LogModuli, logN uint64) (m *Moduli) {
+func genModuli(lm *LogModuli, logN int) (m *Moduli) {
m = new(Moduli)
// Extracts all the different primes bit size and maps their number
- primesbitlen := make(map[uint64]uint64)
+ primesbitlen := make(map[uint64]int)
for _, qi := range lm.LogQi {
primesbitlen[qi]++
}
@@ -722,7 +733,7 @@ func genModuli(lm *LogModuli, logN uint64) (m *Moduli) {
// For each bit-size, finds that many primes
primes := make(map[uint64][]uint64)
for key, value := range primesbitlen {
- primes[key] = ring.GenerateNTTPrimes(key, 2<> 1) //optimalSplit(logDegree) //
- for i := uint64(2); i < (1 << logSplit); i++ {
+ for i := 2; i < (1 << logSplit); i++ {
if err = computePowerBasis(i, C, eval); err != nil {
return nil, err
}
@@ -77,7 +78,8 @@ func (eval *evaluator) EvaluatePoly(ct0 *Ciphertext, pol *Poly) (opOut *Cipherte
}
}
- opOut, err = recurse(eval.scale, logSplit, logDegree, pol, C, eval)
+ opOut, err = recurse(targetScale, logSplit, logDegree, pol, C, eval)
+
C = nil
return opOut, err
}
@@ -86,20 +88,21 @@ func (eval *evaluator) EvaluatePoly(ct0 *Ciphertext, pol *Poly) (opOut *Cipherte
// Returns an error if the input ciphertext does not have enough level to carry out the full polynomial evaluation.
// Returns an error if something is wrong with the scale.
// A change of basis ct' = (2/(b-a)) * (ct + (-a-b)/(b-a)) is necessary before the polynomial evaluation to ensure correctness.
-func (eval *evaluator) EvaluateCheby(op *Ciphertext, cheby *ChebyshevInterpolation) (opOut *Ciphertext, err error) {
+
+func (eval *evaluator) EvaluateCheby(op *Ciphertext, cheby *ChebyshevInterpolation, tartetScale float64) (opOut *Ciphertext, err error) {
if err := checkEnoughLevels(op.Level(), &cheby.Poly, 1); err != nil {
return op, err
}
- C := make(map[uint64]*Ciphertext)
+ C := make(map[int]*Ciphertext)
C[1] = op.CopyNew().Ciphertext()
- logDegree := uint64(bits.Len64(cheby.Degree()))
+ logDegree := int(bits.Len64(uint64(cheby.Degree())))
logSplit := (logDegree >> 1) //optimalSplit(logDegree) //
- for i := uint64(2); i < (1 << logSplit); i++ {
+ for i := 2; i < (1 << logSplit); i++ {
if err = computePowerBasisCheby(i, C, eval); err != nil {
return nil, err
}
@@ -111,19 +114,19 @@ func (eval *evaluator) EvaluateCheby(op *Ciphertext, cheby *ChebyshevInterpolati
}
}
- opOut, err = recurseCheby(eval.scale, logSplit, logDegree, &cheby.Poly, C, eval)
+ opOut, err = recurseCheby(tartetScale, logSplit, logDegree, &cheby.Poly, C, eval)
C = nil
return opOut, err
}
-func computePowerBasis(n uint64, C map[uint64]*Ciphertext, evaluator *evaluator) (err error) {
+func computePowerBasis(n int, C map[int]*Ciphertext, evaluator *evaluator) (err error) {
if C[n] == nil {
// Computes the index required to compute the asked ring evaluation
- a := uint64(math.Ceil(float64(n) / 2))
+ a := int(math.Ceil(float64(n) / 2))
b := n >> 1
// Recurses on the given indexes
@@ -145,7 +148,7 @@ func computePowerBasis(n uint64, C map[uint64]*Ciphertext, evaluator *evaluator)
return nil
}
-func computePowerBasisCheby(n uint64, C map[uint64]*Ciphertext, evaluator *evaluator) (err error) {
+func computePowerBasisCheby(n int, C map[int]*Ciphertext, evaluator *evaluator) (err error) {
// Given a hash table with the first three evaluations of the Chebyshev ring at x in the interval a, b:
// C0 = 1 (actually not stored in the hash table)
@@ -158,9 +161,9 @@ func computePowerBasisCheby(n uint64, C map[uint64]*Ciphertext, evaluator *evalu
if C[n] == nil {
// Computes the index required to compute the asked ring evaluation
- a := uint64(math.Ceil(float64(n) / 2))
+ a := int(math.Ceil(float64(n) / 2))
b := n >> 1
- c := uint64(math.Abs(float64(a) - float64(b)))
+ c := int(math.Abs(float64(a) - float64(b)))
// Recurses on the given indexes
if err = computePowerBasisCheby(a, C, evaluator); err != nil {
@@ -199,7 +202,7 @@ func computePowerBasisCheby(n uint64, C map[uint64]*Ciphertext, evaluator *evalu
return nil
}
-func splitCoeffs(coeffs *Poly, split uint64) (coeffsq, coeffsr *Poly) {
+func splitCoeffs(coeffs *Poly, split int) (coeffsq, coeffsr *Poly) {
// Splits a polynomial p such that p = q*C^degree + r.
@@ -211,7 +214,7 @@ func splitCoeffs(coeffs *Poly, split uint64) (coeffsq, coeffsr *Poly) {
coeffsr.maxDeg = coeffs.maxDeg - (coeffs.Degree() - split + 1)
}
- for i := uint64(0); i < split; i++ {
+ for i := 0; i < split; i++ {
coeffsr.coeffs[i] = coeffs.coeffs[i]
}
@@ -231,7 +234,7 @@ func splitCoeffs(coeffs *Poly, split uint64) (coeffsq, coeffsr *Poly) {
return coeffsq, coeffsr
}
-func splitCoeffsCheby(coeffs *Poly, split uint64) (coeffsq, coeffsr *Poly) {
+func splitCoeffsCheby(coeffs *Poly, split int) (coeffsq, coeffsr *Poly) {
// Splits a Chebyshev polynomial p such that p = q*C^degree + r, where q and r are a linear combination of a Chebyshev basis.
coeffsr = new(Poly)
@@ -242,7 +245,7 @@ func splitCoeffsCheby(coeffs *Poly, split uint64) (coeffsq, coeffsr *Poly) {
coeffsr.maxDeg = coeffs.maxDeg - (coeffs.Degree() - split + 1)
}
- for i := uint64(0); i < split; i++ {
+ for i := 0; i < split; i++ {
coeffsr.coeffs[i] = coeffs.coeffs[i]
}
@@ -251,7 +254,7 @@ func splitCoeffsCheby(coeffs *Poly, split uint64) (coeffsq, coeffsr *Poly) {
coeffsq.maxDeg = coeffs.maxDeg
coeffsq.coeffs[0] = coeffs.coeffs[split]
- for i, j := split+1, uint64(1); i < coeffs.Degree()+1; i, j = i+1, j+1 {
+ for i, j := split+1, 1; i < coeffs.Degree()+1; i, j = i+1, j+1 {
coeffsq.coeffs[i-split] = 2 * coeffs.coeffs[i]
coeffsr.coeffs[split-j] -= coeffs.coeffs[i]
}
@@ -263,13 +266,14 @@ func splitCoeffsCheby(coeffs *Poly, split uint64) (coeffsq, coeffsr *Poly) {
return coeffsq, coeffsr
}
-func recurse(targetScale float64, logSplit, logDegree uint64, coeffs *Poly, C map[uint64]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) {
+func recurse(targetScale float64, logSplit, logDegree int, coeffs *Poly, C map[int]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) {
+
// Recursively computes the evalution of the Chebyshev polynomial using a baby-set giant-step algorithm.
if coeffs.Degree() < (1 << logSplit) {
if coeffs.lead && coeffs.maxDeg > ((1< 1 {
- logDegree = uint64(bits.Len64(coeffs.Degree()))
+ logDegree = int(bits.Len64(uint64(coeffs.Degree())))
logSplit = logDegree >> 1
return recurse(targetScale, logSplit, logDegree, coeffs, C, evaluator)
@@ -278,7 +282,7 @@ func recurse(targetScale float64, logSplit, logDegree uint64, coeffs *Poly, C ma
return evaluatePolyFromPowerBasis(targetScale, coeffs, C, evaluator)
}
- var nextPower = uint64(1 << logSplit)
+ var nextPower = 1 << logSplit
for nextPower < (coeffs.Degree()>>1)+1 {
nextPower <<= 1
}
@@ -335,13 +339,14 @@ func recurse(targetScale float64, logSplit, logDegree uint64, coeffs *Poly, C ma
return
}
-func recurseCheby(targetScale float64, logSplit, logDegree uint64, coeffs *Poly, C map[uint64]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) {
+func recurseCheby(targetScale float64, logSplit, logDegree int, coeffs *Poly, C map[int]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) {
+
// Recursively computes the evalution of the Chebyshev polynomial using a baby-set giant-step algorithm.
if coeffs.Degree() < (1 << logSplit) {
if coeffs.lead && coeffs.maxDeg > ((1< 1 {
- logDegree = uint64(bits.Len64(coeffs.Degree()))
+ logDegree = int(bits.Len64(uint64(coeffs.Degree())))
logSplit = logDegree >> 1
return recurseCheby(targetScale, logSplit, logDegree, coeffs, C, evaluator)
@@ -350,7 +355,7 @@ func recurseCheby(targetScale float64, logSplit, logDegree uint64, coeffs *Poly,
return evaluatePolyFromPowerBasis(targetScale, coeffs, C, evaluator)
}
- var nextPower = uint64(1 << logSplit)
+ var nextPower = 1 << logSplit
for nextPower < (coeffs.Degree()>>1)+1 {
nextPower <<= 1
}
@@ -408,7 +413,7 @@ func recurseCheby(targetScale float64, logSplit, logDegree uint64, coeffs *Poly,
}
-func evaluatePolyFromPowerBasis(targetScale float64, coeffs *Poly, C map[uint64]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) {
+func evaluatePolyFromPowerBasis(targetScale float64, coeffs *Poly, C map[int]*Ciphertext, evaluator *evaluator) (res *Ciphertext, err error) {
if coeffs.Degree() == 0 {
diff --git a/ckks/precision.go b/ckks/precision.go
index f9f0cc25..4a2e407e 100644
--- a/ckks/precision.go
+++ b/ckks/precision.go
@@ -16,6 +16,8 @@ type PrecisionStats struct {
MeanPrecision complex128
MedianDelta complex128
MedianPrecision complex128
+ STDFreq float64
+ STDTime float64
RealDist, ImagDist []struct {
Prec float64
@@ -26,25 +28,27 @@ type PrecisionStats struct {
}
func (prec PrecisionStats) String() string {
- return fmt.Sprintf("\nMinimum precision : (%.2f, %.2f) bits \n", real(prec.MinPrecision), imag(prec.MinPrecision)) +
- fmt.Sprintf("Maximum precision : (%.2f, %.2f) bits \n", real(prec.MaxPrecision), imag(prec.MaxPrecision)) +
- fmt.Sprintf("Mean precision : (%.2f, %.2f) bits \n", real(prec.MeanPrecision), imag(prec.MeanPrecision)) +
- fmt.Sprintf("Median precision : (%.2f, %.2f) bits \n", real(prec.MedianPrecision), imag(prec.MedianPrecision))
+ return fmt.Sprintf("\nMIN Prec : (%.2f, %.2f) Log2 \n", real(prec.MinPrecision), imag(prec.MinPrecision)) +
+ fmt.Sprintf("MAX Prec : (%.2f, %.2f) Log2 \n", real(prec.MaxPrecision), imag(prec.MaxPrecision)) +
+ fmt.Sprintf("AVG Prec : (%.2f, %.2f) Log2 \n", real(prec.MeanPrecision), imag(prec.MeanPrecision)) +
+ fmt.Sprintf("MED Prec : (%.2f, %.2f) Log2 \n", real(prec.MedianPrecision), imag(prec.MedianPrecision)) +
+ fmt.Sprintf("Err stdF : %5.2f Log2 \n", math.Log2(prec.STDFreq)) +
+ fmt.Sprintf("Err stdT : %5.2f Log2 \n", math.Log2(prec.STDTime))
+
}
// GetPrecisionStats generates a PrecisionStats struct from the reference values and the decrypted values
-func GetPrecisionStats(params *Parameters, encoder Encoder, decryptor Decryptor, valuesWant []complex128, element interface{}) (prec PrecisionStats) {
+func GetPrecisionStats(params *Parameters, encoder Encoder, decryptor Decryptor, valuesWant []complex128, element interface{}, logSlots int, sigma float64) (prec PrecisionStats) {
var valuesTest []complex128
- logSlots := params.LogSlots()
slots := uint64(1 << logSlots)
switch element := element.(type) {
case *Ciphertext:
- valuesTest = encoder.Decode(decryptor.DecryptNew(element), logSlots)
+ valuesTest = encoder.DecodePublic(decryptor.DecryptNew(element), logSlots, sigma)
case *Plaintext:
- valuesTest = encoder.Decode(element, logSlots)
+ valuesTest = encoder.DecodePublic(element, logSlots, sigma)
case []complex128:
valuesTest = element
}
@@ -112,6 +116,8 @@ func GetPrecisionStats(params *Parameters, encoder Encoder, decryptor Decryptor,
prec.MeanPrecision = deltaToPrecision(prec.MeanDelta)
prec.MedianDelta = calcmedian(diff)
prec.MedianPrecision = deltaToPrecision(prec.MedianDelta)
+ prec.STDFreq = encoder.GetErrSTDSlotDomain(valuesWant[:], valuesTest[:], params.Scale())
+ prec.STDTime = encoder.GetErrSTDCoeffDomain(valuesWant, valuesTest, params.Scale())
return prec
}
diff --git a/ckks/utils.go b/ckks/utils.go
index c4311da6..28bd3952 100644
--- a/ckks/utils.go
+++ b/ckks/utils.go
@@ -1,12 +1,32 @@
package ckks
import (
+ "github.com/ldsec/lattigo/v2/ring"
"math"
"math/big"
-
- "github.com/ldsec/lattigo/v2/ring"
)
+// StandardDeviation computes the scaled standard deviation of the input vector.
+func StandardDeviation(vec []float64, scale float64) (std float64) {
+ // We assume that the error is centered around zero
+ var err, tmp, mean, n float64
+
+ n = float64(len(vec))
+
+ for _, c := range vec {
+ mean += c
+ }
+
+ mean /= n
+
+ for _, c := range vec {
+ tmp = c - mean
+ err += tmp * tmp
+ }
+
+ return math.Sqrt(err/n) * scale
+}
+
func scaleUpExact(value float64, n float64, q uint64) (res uint64) {
var isNegative bool
@@ -94,7 +114,7 @@ func scaleUpVecExact(values []float64, n float64, moduli []uint64, coeffs [][]ui
func scaleUpVecExactBigFloat(values []*big.Float, scale float64, moduli []uint64, coeffs [][]uint64) {
- prec := uint64(values[0].Prec())
+ prec := int(values[0].Prec())
xFlo := ring.NewFloat(0, prec)
xInt := new(big.Int)
@@ -167,18 +187,18 @@ func GenSwitchkeysRescalingParams(Q, P []uint64) (params []uint64) {
for i := 0; i < len(Q); i++ {
params[i] = tmp.Mod(PBig, ring.NewUint(Q[i])).Uint64()
- params[i] = ring.ModExp(params[i], Q[i]-2, Q[i])
+ params[i] = ring.ModExp(params[i], int(Q[i]-2), Q[i])
params[i] = ring.MForm(params[i], Q[i], ring.BRedParams(Q[i]))
}
return
}
-func sliceBitReverseInPlaceComplex128(slice []complex128, N uint64) {
+func sliceBitReverseInPlaceComplex128(slice []complex128, N int) {
- var bit, j uint64
+ var bit, j int
- for i := uint64(1); i < N; i++ {
+ for i := 1; i < N; i++ {
bit = N >> 1
@@ -195,11 +215,11 @@ func sliceBitReverseInPlaceComplex128(slice []complex128, N uint64) {
}
}
-func sliceBitReverseInPlaceRingComplex(slice []*ring.Complex, N uint64) {
+func sliceBitReverseInPlaceRingComplex(slice []*ring.Complex, N int) {
- var bit, j uint64
+ var bit, j int
- for i := uint64(1); i < N; i++ {
+ for i := 1; i < N; i++ {
bit = N >> 1
diff --git a/dbfv/dbfv.go b/dbfv/dbfv.go
index 082e3aa1..e91180e9 100644
--- a/dbfv/dbfv.go
+++ b/dbfv/dbfv.go
@@ -9,7 +9,7 @@ type dbfvContext struct {
params *bfv.Parameters
// Polynomial degree
- n uint64
+ n int
// floor(Q/T) mod each Qi in Montgomery form
deltaMont []uint64
diff --git a/dbfv/dbfv_benchmark_test.go b/dbfv/dbfv_benchmark_test.go
index a2072c7e..ae44abe8 100644
--- a/dbfv/dbfv_benchmark_test.go
+++ b/dbfv/dbfv_benchmark_test.go
@@ -101,7 +101,7 @@ func benchRelinKeyGen(testCtx *testContext, b *testing.B) {
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
@@ -237,7 +237,7 @@ func benchRotKeyGen(testCtx *testContext, b *testing.B) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.dbfvContext.ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
diff --git a/dbfv/dbfv_test.go b/dbfv/dbfv_test.go
index 7013f45b..34a19ac1 100644
--- a/dbfv/dbfv_test.go
+++ b/dbfv/dbfv_test.go
@@ -16,9 +16,9 @@ import (
)
var flagLongTest = flag.Bool("long", false, "run the long test suite (all parameters). Overrides -short and requires -timeout=0.")
-var parties uint64 = 3
+var parties int = 3
-func testString(opname string, parties uint64, params *bfv.Parameters) string {
+func testString(opname string, parties int, params *bfv.Parameters) string {
return fmt.Sprintf("%sparties=%d/LogN=%d/logQ=%d", opname, parties, params.LogN(), params.LogQP())
}
@@ -100,7 +100,7 @@ func gentestContext(defaultParams *bfv.Parameters) (testCtx *testContext, err er
tmp0 := testCtx.dbfvContext.ringQP.NewPoly()
tmp1 := testCtx.dbfvContext.ringQP.NewPoly()
- for j := uint64(0); j < parties; j++ {
+ for j := 0; j < parties; j++ {
testCtx.sk0Shards[j] = kgen.GenSecretKey()
testCtx.sk1Shards[j] = kgen.GenSecretKey()
testCtx.dbfvContext.ringQP.Add(tmp0, testCtx.sk0Shards[j].Value, tmp0)
@@ -141,7 +141,7 @@ func testPublicKeyGen(testCtx *testContext, t *testing.T) {
}
ckgParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.CKGProtocol = NewCKGProtocol(testCtx.params)
p.s = &sk0Shards[i].SecretKey
@@ -207,7 +207,7 @@ func testRelinKeyGen(testCtx *testContext, t *testing.T) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.dbfvContext.ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
@@ -266,7 +266,7 @@ func testKeyswitching(testCtx *testContext, t *testing.T) {
}
cksParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.CKSProtocol = NewCKSProtocol(testCtx.params, 6.36)
p.s0 = sk0Shards[i].Value
@@ -314,7 +314,7 @@ func testPublicKeySwitching(testCtx *testContext, t *testing.T) {
}
pcksParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.PCKSProtocol = NewPCKSProtocol(testCtx.params, 6.36)
p.s = sk0Shards[i].Value
@@ -355,7 +355,7 @@ func testRotKeyGenRotRows(testCtx *testContext, t *testing.T) {
}
pcksParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.RTGProtocol = NewRotKGProtocol(testCtx.params)
p.s = &sk0Shards[i].SecretKey
@@ -370,7 +370,7 @@ func testRotKeyGenRotRows(testCtx *testContext, t *testing.T) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.dbfvContext.ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
@@ -412,7 +412,7 @@ func testRotKeyGenRotCols(testCtx *testContext, t *testing.T) {
}
pcksParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.RTGProtocol = NewRotKGProtocol(testCtx.params)
p.s = &sk0Shards[i].SecretKey
@@ -428,7 +428,7 @@ func testRotKeyGenRotCols(testCtx *testContext, t *testing.T) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.dbfvContext.ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
@@ -450,7 +450,8 @@ func testRotKeyGenRotCols(testCtx *testContext, t *testing.T) {
}
evaluator := testCtx.evaluator.WithKey(bfv.EvaluationKey{Rlk: nil, Rtks: rotKeySet})
- for k := uint64(1); k < testCtx.params.N()>>1; k <<= 1 {
+ for k := 1; k < testCtx.params.N()>>1; k <<= 1 {
+
result := evaluator.RotateColumnsNew(ciphertext, int(k))
coeffsWant := utils.RotateUint64Slots(coeffs, int(k))
verifyTestVectors(testCtx, decryptorSk0, coeffsWant, result, t)
@@ -479,7 +480,7 @@ func testRefresh(testCtx *testContext, t *testing.T) {
}
RefreshParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.RefreshProtocol = NewRefreshProtocol(testCtx.params)
p.s = sk0Shards[i].Value
@@ -509,7 +510,7 @@ func testRefresh(testCtx *testContext, t *testing.T) {
evaluator.Relinearize(testCtx.evaluator.MulNew(ciphertextTmp, ciphertextTmp), ciphertextTmp)
for j := range coeffsTmp {
- coeffsTmp[j] = ring.BRed(coeffsTmp[j], coeffsTmp[j], testCtx.dbfvContext.ringT.Modulus[0], testCtx.dbfvContext.ringT.GetBredParams()[0])
+ coeffsTmp[j] = ring.BRed(coeffsTmp[j], coeffsTmp[j], testCtx.dbfvContext.ringT.Modulus[0], testCtx.dbfvContext.ringT.BredParams[0])
}
if utils.EqualSliceUint64(coeffsTmp, encoder.DecodeUintNew(decryptorSk0.DecryptNew(ciphertextTmp))) {
@@ -527,7 +528,7 @@ func testRefresh(testCtx *testContext, t *testing.T) {
errorRange.Quo(errorRange, testCtx.dbfvContext.ringT.ModulusBigint)
errorRange.Quo(errorRange, testCtx.dbfvContext.ringT.ModulusBigint)
- for i := uint64(0); i < testCtx.params.N(); i++ {
+ for i := 0; i < testCtx.params.N(); i++ {
coeffsBigint[i].Add(coeffsBigint[i], ring.RandInt(errorRange))
}
@@ -548,7 +549,7 @@ func testRefresh(testCtx *testContext, t *testing.T) {
evaluator.Relinearize(testCtx.evaluator.MulNew(ciphertext, ciphertext), ciphertext)
for j := range coeffs {
- coeffs[j] = ring.BRed(coeffs[j], coeffs[j], testCtx.dbfvContext.ringT.Modulus[0], testCtx.dbfvContext.ringT.GetBredParams()[0])
+ coeffs[j] = ring.BRed(coeffs[j], coeffs[j], testCtx.dbfvContext.ringT.Modulus[0], testCtx.dbfvContext.ringT.BredParams[0])
}
}
@@ -574,7 +575,7 @@ func testRefreshAndPermutation(testCtx *testContext, t *testing.T) {
}
RefreshParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.PermuteProtocol = NewPermuteProtocol(testCtx.params)
p.s = sk0Shards[i].Value
@@ -593,7 +594,7 @@ func testRefreshAndPermutation(testCtx *testContext, t *testing.T) {
permutation := make([]uint64, len(coeffs))
for i := range permutation {
- permutation[i] = ring.RandUniform(testCtx.prng, testCtx.params.N(), testCtx.params.N()-1)
+ permutation[i] = ring.RandUniform(testCtx.prng, uint64(testCtx.params.N()), uint64(testCtx.params.N()-1))
}
for i, p := range RefreshParties {
@@ -669,10 +670,10 @@ func testMarshalling(testCtx *testContext, t *testing.T) {
}
//comparing the results
- require.Equal(t, KeyGenShareBefore.GetDegree(), KeyGenShareAfter.GetDegree())
- require.Equal(t, KeyGenShareBefore.GetLenModuli(), KeyGenShareAfter.GetLenModuli())
+ require.Equal(t, KeyGenShareBefore.Degree(), KeyGenShareAfter.Degree())
+ require.Equal(t, KeyGenShareBefore.LenModuli(), KeyGenShareAfter.LenModuli())
- moduli := KeyGenShareBefore.GetLenModuli()
+ moduli := KeyGenShareBefore.LenModuli()
require.Equal(t, KeyGenShareAfter.Coeffs[:moduli], KeyGenShareBefore.Coeffs[:moduli])
})
@@ -694,8 +695,8 @@ func testMarshalling(testCtx *testContext, t *testing.T) {
//compare the shares.
ringBefore := SwitchShare[i]
ringAfter := SwitchShareReceiver[i]
- require.Equal(t, ringBefore.GetDegree(), ringAfter.GetDegree())
- moduli := ringAfter.GetLenModuli()
+ require.Equal(t, ringBefore.Degree(), ringAfter.Degree())
+ moduli := ringAfter.LenModuli()
require.Equal(t, ringAfter.Coeffs[:moduli], ringBefore.Coeffs[:moduli])
}
})
@@ -715,10 +716,10 @@ func testMarshalling(testCtx *testContext, t *testing.T) {
//now compare both shares.
- require.Equal(t, cksshare.GetDegree(), cksshareAfter.GetDegree())
- require.Equal(t, cksshare.GetLenModuli(), cksshareAfter.GetLenModuli())
+ require.Equal(t, cksshare.Degree(), cksshareAfter.Degree())
+ require.Equal(t, cksshare.LenModuli(), cksshareAfter.LenModuli())
- moduli := cksshare.GetLenModuli()
+ moduli := cksshare.LenModuli()
require.Equal(t, cksshare.Coeffs[:moduli], cksshareAfter.Coeffs[:moduli])
})
@@ -833,7 +834,7 @@ func testMarshallingRelin(testCtx *testContext, t *testing.T) {
// for i := 0; i < (len(r1)); i++ { // TODO test in drlwe
// a := r1[i][0]
// b := (*r1After)[i][0]
- // moduli := a.GetLenModuli()
+ // moduli := a.LenModuli()
// require.Equal(t, a.Coeffs[:moduli], b.Coeffs[:moduli])
// }
@@ -850,7 +851,7 @@ func testMarshallingRelin(testCtx *testContext, t *testing.T) {
// for idx := 0; idx < 2; idx++ {
// a := r2[i][idx]
// b := (*r2After)[i][idx]
- // moduli := a.GetLenModuli()
+ // moduli := a.LenModuli()
// require.Equal(t, a.Coeffs[:moduli], b.Coeffs[:moduli])
// }
diff --git a/dbfv/keyswitching.go b/dbfv/keyswitching.go
index cb9611c5..fcd6165f 100644
--- a/dbfv/keyswitching.go
+++ b/dbfv/keyswitching.go
@@ -55,7 +55,7 @@ func NewCKSProtocol(params *bfv.Parameters, sigmaSmudging float64) *CKSProtocol
if err != nil {
panic(err)
}
- cks.gaussianSampler = ring.NewGaussianSampler(prng, context.ringQP, sigmaSmudging, uint64(6*sigmaSmudging))
+ cks.gaussianSampler = ring.NewGaussianSampler(prng, cks.context.ringQP, sigmaSmudging, int(6*sigmaSmudging))
return cks
}
@@ -82,10 +82,9 @@ func (cks *CKSProtocol) GenShare(skInput, skOutput *ring.Poly, ct *bfv.Ciphertex
func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, shareOut CKSShare) {
- level := uint64(len(ct.Value()[1].Coeffs) - 1)
+ level := len(ct.Value()[1].Coeffs) - 1
ringQ := cks.context.ringQ
- ringQP := cks.context.ringQP
ringQ.NTTLazy(ct.Value()[1], cks.tmpNtt)
ringQ.MulCoeffsMontgomeryConstant(cks.tmpNtt, skDelta, shareOut.Poly)
@@ -93,13 +92,13 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *bfv.Ciphertext, sh
ringQ.InvNTTLazy(shareOut.Poly, shareOut.Poly)
- cks.gaussianSampler.ReadLvl(uint64(len(ringQP.Modulus)-1), cks.tmpNtt)
+ cks.gaussianSampler.Read(cks.tmpNtt)
ringQ.AddNoMod(shareOut.Poly, cks.tmpNtt, shareOut.Poly)
- for x, i := 0, uint64(len(ringQ.Modulus)); i < uint64(len(cks.context.ringQP.Modulus)); x, i = x+1, i+1 {
+ for x, i := 0, len(ringQ.Modulus); i < len(cks.context.ringQP.Modulus); x, i = x+1, i+1 {
tmphP := cks.hP.Coeffs[x]
tmpNTT := cks.tmpNtt.Coeffs[i]
- for j := uint64(0); j < ringQ.N; j++ {
+ for j := 0; j < ringQ.N; j++ {
tmphP[j] += tmpNTT[j]
}
}
diff --git a/dbfv/public_keyswitching.go b/dbfv/public_keyswitching.go
index 4cc17440..79d13cf6 100644
--- a/dbfv/public_keyswitching.go
+++ b/dbfv/public_keyswitching.go
@@ -94,7 +94,7 @@ func NewPCKSProtocol(params *bfv.Parameters, sigmaSmudging float64) *PCKSProtoco
if err != nil {
panic(err)
}
- pcks.gaussianSampler = ring.NewGaussianSampler(prng, context.ringQP, sigmaSmudging, uint64(6*sigmaSmudging))
+ pcks.gaussianSampler = ring.NewGaussianSampler(prng, pcks.context.ringQP, sigmaSmudging, int(6*sigmaSmudging))
pcks.ternarySamplerMontgomery = ring.NewTernarySampler(prng, context.ringQP, 0.5, true)
return pcks
@@ -115,31 +115,31 @@ func (pcks *PCKSProtocol) AllocateShares() (s PCKSShare) {
func (pcks *PCKSProtocol) GenShare(sk *ring.Poly, pk *bfv.PublicKey, ct *bfv.Ciphertext, shareOut PCKSShare) {
ringQ := pcks.context.ringQ
- contextKeys := pcks.context.ringQP
+ ringQP := pcks.context.ringQP
pcks.ternarySamplerMontgomery.Read(pcks.tmp)
- contextKeys.NTTLazy(pcks.tmp, pcks.tmp)
+ ringQP.NTTLazy(pcks.tmp, pcks.tmp)
// h_0 = u_i * pk_0
- contextKeys.MulCoeffsMontgomeryConstant(pcks.tmp, pk.Value[0], pcks.share0tmp)
+ ringQP.MulCoeffsMontgomeryConstant(pcks.tmp, pk.Value[0], pcks.share0tmp)
// h_1 = u_i * pk_1
- contextKeys.MulCoeffsMontgomeryConstant(pcks.tmp, pk.Value[1], pcks.share1tmp)
+ ringQP.MulCoeffsMontgomeryConstant(pcks.tmp, pk.Value[1], pcks.share1tmp)
- contextKeys.InvNTTLazy(pcks.share0tmp, pcks.share0tmp)
- contextKeys.InvNTTLazy(pcks.share1tmp, pcks.share1tmp)
+ ringQP.InvNTTLazy(pcks.share0tmp, pcks.share0tmp)
+ ringQP.InvNTTLazy(pcks.share1tmp, pcks.share1tmp)
// h_0 = u_i * pk_0 + e0
- pcks.gaussianSampler.ReadAndAdd(pcks.share0tmp)
+ pcks.gaussianSampler.ReadAndAddFromDistLvl(len(ringQP.Modulus)-1, pcks.share0tmp, ringQP, pcks.sigmaSmudging, int(6*pcks.sigmaSmudging))
// h_1 = u_i * pk_1 + e1
- pcks.gaussianSampler.ReadAndAdd(pcks.share1tmp)
+ pcks.gaussianSampler.ReadAndAddFromDistLvl(len(ringQP.Modulus)-1, pcks.share1tmp, ringQP, pcks.sigmaSmudging, int(6*pcks.sigmaSmudging))
// h_0 = (u_i * pk_0 + e0)/P
- pcks.baseconverter.ModDownPQ(uint64(len(ringQ.Modulus))-1, pcks.share0tmp, shareOut[0])
+ pcks.baseconverter.ModDownPQ(len(ringQ.Modulus)-1, pcks.share0tmp, shareOut[0])
// h_0 = (u_i * pk_0 + e0)/P
// Could be moved to the keyswitch phase, but the second element of the shares will be larger
- pcks.baseconverter.ModDownPQ(uint64(len(ringQ.Modulus))-1, pcks.share1tmp, shareOut[1])
+ pcks.baseconverter.ModDownPQ(len(ringQ.Modulus)-1, pcks.share1tmp, shareOut[1])
// tmp = s_i*c_1
ringQ.NTTLazy(ct.Value()[1], pcks.tmp)
diff --git a/dbfv/public_permute.go b/dbfv/public_permute.go
index 69be9a51..05c77a42 100644
--- a/dbfv/public_permute.go
+++ b/dbfv/public_permute.go
@@ -1,8 +1,6 @@
package dbfv
import (
- "math/bits"
-
"github.com/ldsec/lattigo/v2/bfv"
"github.com/ldsec/lattigo/v2/ring"
"github.com/ldsec/lattigo/v2/utils"
@@ -18,6 +16,7 @@ type PermuteProtocol struct {
baseconverter *ring.FastBasisExtender
scaler ring.Scaler
gaussianSampler *ring.GaussianSampler
+ sigma float64
uniformSampler *ring.UniformSampler
}
@@ -38,13 +37,13 @@ func NewPermuteProtocol(params *bfv.Parameters) (refreshProtocol *PermuteProtoco
indexMatrix := make([]uint64, params.N())
- logN := uint64(bits.Len64(params.N()) - 1)
+ logN := uint64(params.LogN())
rowSize := params.N() >> 1
- m = (params.N() << 1)
+ m = uint64(params.N()) << 1
pos = 1
- for i := uint64(0); i < rowSize; i++ {
+ for i := 0; i < rowSize; i++ {
index1 = (pos - 1) >> 1
index2 = (m - pos - 1) >> 1
@@ -64,7 +63,8 @@ func NewPermuteProtocol(params *bfv.Parameters) (refreshProtocol *PermuteProtoco
panic(err)
}
- refreshProtocol.gaussianSampler = ring.NewGaussianSampler(prng, context.ringQP, params.Sigma(), uint64(6*params.Sigma()))
+ refreshProtocol.gaussianSampler = ring.NewGaussianSampler(prng, context.ringQ, params.Sigma(), int(6*params.Sigma()))
+ refreshProtocol.sigma = params.Sigma()
refreshProtocol.uniformSampler = ring.NewUniformSampler(prng, context.ringT)
return
@@ -79,7 +79,7 @@ func (pp *PermuteProtocol) AllocateShares() RefreshShare {
// GenShares generates the shares of the PermuteProtocol
func (pp *PermuteProtocol) GenShares(sk *ring.Poly, ciphertext *bfv.Ciphertext, crs *ring.Poly, permutation []uint64, share RefreshShare) {
- level := uint64(len(ciphertext.Value()[1].Coeffs) - 1)
+ level := len(ciphertext.Value()[1].Coeffs) - 1
ringQ := pp.context.ringQ
ringT := pp.context.ringT
@@ -94,13 +94,13 @@ func (pp *PermuteProtocol) GenShares(sk *ring.Poly, ciphertext *bfv.Ciphertext,
ringQ.MulScalarBigint(share.RefreshShareDecrypt, pp.context.ringP.ModulusBigint, share.RefreshShareDecrypt)
// h0 = s*ct[1]*P + e
- pp.gaussianSampler.ReadLvl(uint64(len(ringQP.Modulus)-1), pp.tmp1)
+ pp.gaussianSampler.ReadFromDistLvl(len(ringQP.Modulus)-1, pp.tmp1, ringQP, pp.sigma, int(6*pp.sigma))
ringQ.Add(share.RefreshShareDecrypt, pp.tmp1, share.RefreshShareDecrypt)
- for x, i := 0, uint64(len(ringQ.Modulus)); i < uint64(len(pp.context.ringQP.Modulus)); x, i = x+1, i+1 {
+ for x, i := 0, len(ringQ.Modulus); i < len(pp.context.ringQP.Modulus); x, i = x+1, i+1 {
tmphP := pp.hP.Coeffs[x]
tmp1 := pp.tmp1.Coeffs[i]
- for j := uint64(0); j < ringQ.N; j++ {
+ for j := 0; j < ringQ.N; j++ {
tmphP[j] += tmp1[j]
}
}
@@ -115,7 +115,7 @@ func (pp *PermuteProtocol) GenShares(sk *ring.Poly, ciphertext *bfv.Ciphertext,
ringQP.InvNTTLazy(pp.tmp2, pp.tmp2)
// h1 = s*a + e'
- pp.gaussianSampler.ReadAndAdd(pp.tmp2)
+ pp.gaussianSampler.ReadAndAddFromDistLvl(len(ringQP.Modulus)-1, pp.tmp2, ringQP, pp.sigma, int(6*pp.sigma))
// h1 = (-s*a + e')/P
pp.baseconverter.ModDownPQ(level, pp.tmp2, share.RefreshShareRecrypt)
@@ -181,7 +181,7 @@ func (pp *PermuteProtocol) Recrypt(sharePlaintext *ring.Poly, crs *ring.Poly, sh
pp.context.ringQ.Add(sharePlaintext, shareRecrypt, ciphertextOut.Value()[0])
// ciphertext[1] = crs/P
- pp.baseconverter.ModDownPQ(uint64(len(ciphertextOut.Value()[1].Coeffs)-1), crs, ciphertextOut.Value()[1])
+ pp.baseconverter.ModDownPQ(len(ciphertextOut.Value()[1].Coeffs)-1, crs, ciphertextOut.Value()[1])
}
@@ -193,7 +193,7 @@ func (pp *PermuteProtocol) Finalize(ciphertext *bfv.Ciphertext, permutation []ui
}
func (pp *PermuteProtocol) permuteWithIndex(polIn *ring.Poly, index []uint64, polOut *ring.Poly) {
- for j := uint64(0); j < uint64(len(polIn.Coeffs[0])); j++ {
+ for j := 0; j < len(polIn.Coeffs[0]); j++ {
polOut.Coeffs[0][pp.indexMatrix[j]] = polIn.Coeffs[0][pp.indexMatrix[index[j]]]
}
}
diff --git a/dbfv/public_refresh.go b/dbfv/public_refresh.go
index 7152cbf7..a8e92f42 100644
--- a/dbfv/public_refresh.go
+++ b/dbfv/public_refresh.go
@@ -18,6 +18,7 @@ type RefreshProtocol struct {
baseconverter *ring.FastBasisExtender
scaler ring.Scaler
gaussianSampler *ring.GaussianSampler
+ sigma float64
uniformSampler *ring.UniformSampler
}
@@ -39,10 +40,10 @@ func (share *RefreshShare) MarshalBinary() ([]byte, error) {
lenRecrypt := (*share.RefreshShareRecrypt).GetDataLen(true)
data := make([]byte, lenDecrypt+lenRecrypt+2*8) // 2 * 3 to write the len of lenDecrypt and lenRecrypt.
- binary.BigEndian.PutUint64(data[0:8], lenDecrypt)
- binary.BigEndian.PutUint64(data[8:16], lenRecrypt)
+ binary.BigEndian.PutUint64(data[0:8], uint64(lenDecrypt))
+ binary.BigEndian.PutUint64(data[8:16], uint64(lenRecrypt))
- ptr := uint64(16)
+ ptr := 16
tmp, err := (*share.RefreshShareDecrypt).WriteTo(data[ptr : ptr+lenDecrypt])
if err != nil {
return []byte{}, err
@@ -97,7 +98,8 @@ func NewRefreshProtocol(params *bfv.Parameters) (refreshProtocol *RefreshProtoco
if err != nil {
panic(err)
}
- refreshProtocol.gaussianSampler = ring.NewGaussianSampler(prng, context.ringQP, params.Sigma(), uint64(6*params.Sigma()))
+ refreshProtocol.gaussianSampler = ring.NewGaussianSampler(prng, context.ringQ, params.Sigma(), int(6*params.Sigma()))
+ refreshProtocol.sigma = params.Sigma()
refreshProtocol.uniformSampler = ring.NewUniformSampler(prng, context.ringT)
return
@@ -112,7 +114,7 @@ func (rfp *RefreshProtocol) AllocateShares() RefreshShare {
// GenShares generates a share for the Refresh protocol.
func (rfp *RefreshProtocol) GenShares(sk *ring.Poly, ciphertext *bfv.Ciphertext, crs *ring.Poly, share RefreshShare) {
- level := uint64(len(ciphertext.Value()[1].Coeffs) - 1)
+ level := len(ciphertext.Value()[1].Coeffs) - 1
ringQ := rfp.context.ringQ
ringQP := rfp.context.ringQP
@@ -126,13 +128,13 @@ func (rfp *RefreshProtocol) GenShares(sk *ring.Poly, ciphertext *bfv.Ciphertext,
ringQ.MulScalarBigint(share.RefreshShareDecrypt, rfp.context.ringP.ModulusBigint, share.RefreshShareDecrypt)
// h0 = s*ct[1]*P + e
- rfp.gaussianSampler.ReadLvl(uint64(len(ringQP.Modulus)-1), rfp.tmp1)
+ rfp.gaussianSampler.ReadFromDistLvl(len(ringQP.Modulus)-1, rfp.tmp1, ringQP, rfp.sigma, int(6*rfp.sigma))
ringQ.Add(share.RefreshShareDecrypt, rfp.tmp1, share.RefreshShareDecrypt)
- for x, i := 0, uint64(len(ringQ.Modulus)); i < uint64(len(rfp.context.ringQP.Modulus)); x, i = x+1, i+1 {
+ for x, i := 0, len(ringQ.Modulus); i < len(rfp.context.ringQP.Modulus); x, i = x+1, i+1 {
tmphP := rfp.hP.Coeffs[x]
tmp1 := rfp.tmp1.Coeffs[i]
- for j := uint64(0); j < ringQ.N; j++ {
+ for j := 0; j < ringQ.N; j++ {
tmphP[j] += tmp1[j]
}
}
@@ -147,7 +149,7 @@ func (rfp *RefreshProtocol) GenShares(sk *ring.Poly, ciphertext *bfv.Ciphertext,
ringQP.InvNTTLazy(rfp.tmp2, rfp.tmp2)
// h1 = s*a + e'
- rfp.gaussianSampler.ReadAndAdd(rfp.tmp2)
+ rfp.gaussianSampler.ReadAndAddFromDistLvl(len(ringQP.Modulus)-1, rfp.tmp2, ringQP, rfp.sigma, int(6*rfp.sigma))
// h1 = (-s*a + e')/P
rfp.baseconverter.ModDownPQ(level, rfp.tmp2, share.RefreshShareRecrypt)
@@ -187,7 +189,7 @@ func (rfp *RefreshProtocol) Recrypt(sharePlaintext *ring.Poly, crs *ring.Poly, s
rfp.context.ringQ.Add(sharePlaintext, shareRecrypt, ciphertextOut.Value()[0])
// ciphertext[1] = crs/P
- rfp.baseconverter.ModDownPQ(uint64(len(ciphertextOut.Value()[1].Coeffs)-1), crs, ciphertextOut.Value()[1])
+ rfp.baseconverter.ModDownPQ(len(ciphertextOut.Value()[1].Coeffs)-1, crs, ciphertextOut.Value()[1])
}
@@ -202,7 +204,7 @@ func lift(p0, p1 *ring.Poly, context *dbfvContext) {
coeffs := p0.Coeffs[0]
var coeff uint64
- for j := uint64(0); j < context.n; j++ {
+ for j := 0; j < context.n; j++ {
coeff = coeffs[j]
for i := len(context.ringQ.Modulus) - 1; i >= 0; i-- {
p1.Coeffs[i][j] = ring.MRed(coeff, context.deltaMont[i], context.ringQ.Modulus[i], context.ringQ.MredParams[i])
diff --git a/dckks/dckks.go b/dckks/dckks.go
index 5bf78337..c101edab 100644
--- a/dckks/dckks.go
+++ b/dckks/dckks.go
@@ -8,14 +8,14 @@ import (
type dckksContext struct {
params *ckks.Parameters
- n uint64
+ n int
ringQ *ring.Ring
ringP *ring.Ring
ringQP *ring.Ring
- alpha uint64
- beta uint64
+ alpha int
+ beta int
}
func newDckksContext(params *ckks.Parameters) (context *dckksContext) {
diff --git a/dckks/dckks_benchmark_test.go b/dckks/dckks_benchmark_test.go
index be0298f1..5be3fe40 100644
--- a/dckks/dckks_benchmark_test.go
+++ b/dckks/dckks_benchmark_test.go
@@ -92,7 +92,7 @@ func benchRelinKeyGen(testCtx *testContext, b *testing.B) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.dckksContext.ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
@@ -227,7 +227,7 @@ func benchRotKeyGen(testCtx *testContext, b *testing.B) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
galEl := testCtx.params.GaloisElementForRowRotation()
@@ -262,7 +262,7 @@ func benchRefresh(testCtx *testContext, b *testing.B) {
sk0Shards := testCtx.sk0Shards
ringQ := testCtx.dckksContext.ringQ
- levelStart := uint64(3)
+ levelStart := 3
type Party struct {
*RefreshProtocol
@@ -284,7 +284,7 @@ func benchRefresh(testCtx *testContext, b *testing.B) {
b.Run(testString("Refresh/Gen/", parties, testCtx.params), func(b *testing.B) {
for i := 0; i < b.N; i++ {
- p.GenShares(p.s, levelStart, parties, ciphertext, crp, p.share1, p.share2)
+ p.GenShares(p.s, levelStart, parties, ciphertext, testCtx.params.Scale(), crp, p.share1, p.share2)
}
})
@@ -305,7 +305,7 @@ func benchRefresh(testCtx *testContext, b *testing.B) {
b.Run(testString("Refresh/Recode/", parties, testCtx.params), func(b *testing.B) {
for i := 0; i < b.N; i++ {
- p.Recode(ciphertext)
+ p.Recode(ciphertext, testCtx.params.Scale())
}
})
@@ -325,7 +325,7 @@ func benchRefreshAndPermute(testCtx *testContext, b *testing.B) {
sk0Shards := testCtx.sk0Shards
- levelStart := uint64(2)
+ levelStart := 3
type Party struct {
*PermuteProtocol
@@ -350,7 +350,7 @@ func benchRefreshAndPermute(testCtx *testContext, b *testing.B) {
permutation := make([]uint64, testCtx.params.Slots())
for i := range permutation {
- permutation[i] = ring.RandUniform(testCtx.prng, testCtx.params.Slots(), testCtx.params.Slots()-1)
+ permutation[i] = ring.RandUniform(testCtx.prng, uint64(testCtx.params.Slots()), uint64(testCtx.params.Slots()-1))
}
b.Run(testString("RefreshAndPermute/Gen/", parties, testCtx.params), func(b *testing.B) {
diff --git a/dckks/dckks_test.go b/dckks/dckks_test.go
index f5ea1ad0..d5462d0c 100644
--- a/dckks/dckks_test.go
+++ b/dckks/dckks_test.go
@@ -19,9 +19,9 @@ import (
var flagLongTest = flag.Bool("long", false, "run the long test suite (all parameters). Overrides -short and requires -timeout=0.")
var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats")
var minPrec float64 = 15.0
-var parties uint64 = 3
+var parties int = 3
-func testString(opname string, parties uint64, params *ckks.Parameters) string {
+func testString(opname string, parties int, params *ckks.Parameters) string {
return fmt.Sprintf("%sparties=%d/logN=%d/logQ=%d/levels=%d/alpha=%d/beta=%d",
opname,
parties,
@@ -109,7 +109,7 @@ func genTestParams(defaultParams *ckks.Parameters) (testCtx *testContext, err er
tmp0 := testCtx.dckksContext.ringQP.NewPoly()
tmp1 := testCtx.dckksContext.ringQP.NewPoly()
- for j := uint64(0); j < parties; j++ {
+ for j := 0; j < parties; j++ {
testCtx.sk0Shards[j] = kgen.GenSecretKey()
testCtx.sk1Shards[j] = kgen.GenSecretKey()
testCtx.dckksContext.ringQP.Add(tmp0, testCtx.sk0Shards[j].Value, tmp0)
@@ -150,7 +150,7 @@ func testPublicKeyGen(testCtx *testContext, t *testing.T) {
}
ckgParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.CKGProtocol = NewCKGProtocol(testCtx.params)
p.s = &sk0Shards[i].SecretKey
@@ -216,7 +216,7 @@ func testRelinKeyGen(testCtx *testContext, t *testing.T) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.dckksContext.ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
@@ -250,7 +250,7 @@ func testRelinKeyGen(testCtx *testContext, t *testing.T) {
evaluator.Rescale(ciphertext, testCtx.params.Scale(), ciphertext)
- require.Equal(t, ciphertext.Degree(), uint64(1))
+ require.Equal(t, ciphertext.Degree(), 1)
verifyTestVectors(testCtx, decryptorSk0, coeffs, ciphertext, t)
@@ -275,7 +275,7 @@ func testKeyswitching(testCtx *testContext, t *testing.T) {
}
cksParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.CKSProtocol = NewCKSProtocol(testCtx.params, 6.36)
p.s0 = sk0Shards[i].Value
@@ -328,7 +328,7 @@ func testPublicKeySwitching(testCtx *testContext, t *testing.T) {
}
pcksParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.PCKSProtocol = NewPCKSProtocol(testCtx.params, 6.36)
p.s = sk0Shards[i].Value
@@ -368,7 +368,7 @@ func testRotKeyGenConjugate(testCtx *testContext, t *testing.T) {
}
pcksParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.RTGProtocol = NewRotKGProtocol(testCtx.params)
p.s = &sk0Shards[i].SecretKey
@@ -383,7 +383,7 @@ func testRotKeyGenConjugate(testCtx *testContext, t *testing.T) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, testCtx.dckksContext.ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
@@ -406,7 +406,7 @@ func testRotKeyGenConjugate(testCtx *testContext, t *testing.T) {
coeffsWant := make([]complex128, ringQP.N>>1)
- for i := uint64(0); i < ringQP.N>>1; i++ {
+ for i := 0; i < ringQP.N>>1; i++ {
coeffsWant[i] = complex(real(coeffs[i]), -imag(coeffs[i]))
}
@@ -431,7 +431,7 @@ func testRotKeyGenCols(testCtx *testContext, t *testing.T) {
}
pcksParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.RTGProtocol = NewRotKGProtocol(testCtx.params)
p.s = &sk0Shards[i].SecretKey
@@ -444,7 +444,7 @@ func testRotKeyGenCols(testCtx *testContext, t *testing.T) {
crpGenerator := ring.NewUniformSampler(testCtx.prng, ringQP)
crp := make([]*ring.Poly, testCtx.params.Beta())
- for i := uint64(0); i < testCtx.params.Beta(); i++ {
+ for i := 0; i < testCtx.params.Beta(); i++ {
crp[i] = crpGenerator.ReadNew()
}
@@ -467,7 +467,7 @@ func testRotKeyGenCols(testCtx *testContext, t *testing.T) {
evaluator := testCtx.evaluator.WithKey(ckks.EvaluationKey{Rlk: nil, Rtks: rotKeySet})
- for k := uint64(1); k < ringQP.N>>1; k <<= 1 {
+ for k := 1; k < ringQP.N>>1; k <<= 1 {
evaluator.Rotate(ciphertext, int(k), receiver)
coeffsWant := utils.RotateComplex128Slice(coeffs, int(k))
@@ -484,7 +484,7 @@ func testRefresh(testCtx *testContext, t *testing.T) {
decryptorSk0 := testCtx.decryptorSk0
sk0Shards := testCtx.sk0Shards
- levelStart := uint64(3)
+ levelStart := 3
t.Run(testString("Refresh/", parties, testCtx.params), func(t *testing.T) {
@@ -500,7 +500,7 @@ func testRefresh(testCtx *testContext, t *testing.T) {
}
RefreshParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.RefreshProtocol = NewRefreshProtocol(testCtx.params)
p.s = sk0Shards[i].Value
@@ -520,7 +520,7 @@ func testRefresh(testCtx *testContext, t *testing.T) {
}
for i, p := range RefreshParties {
- p.GenShares(p.s, levelStart, parties, ciphertext, crp, p.share1, p.share2)
+ p.GenShares(p.s, levelStart, parties, ciphertext, testCtx.params.Scale(), crp, p.share1, p.share2)
if i > 0 {
P0.Aggregate(p.share1, P0.share1, P0.share1)
P0.Aggregate(p.share2, P0.share2, P0.share2)
@@ -528,9 +528,9 @@ func testRefresh(testCtx *testContext, t *testing.T) {
}
// We refresh the ciphertext with the simulated error
- P0.Decrypt(ciphertext, P0.share1) // Masked decryption
- P0.Recode(ciphertext) // Masked re-encoding
- P0.Recrypt(ciphertext, crp, P0.share2) // Masked re-encryption
+ P0.Decrypt(ciphertext, P0.share1) // Masked decryption
+ P0.Recode(ciphertext, testCtx.params.Scale()) // Masked re-encoding
+ P0.Recrypt(ciphertext, crp, P0.share2) // Masked re-encryption
require.Equal(t, ciphertext.Level(), testCtx.params.MaxLevel())
@@ -546,7 +546,7 @@ func testRefreshAndPermute(testCtx *testContext, t *testing.T) {
decryptorSk0 := testCtx.decryptorSk0
sk0Shards := testCtx.sk0Shards
- levelStart := uint64(3)
+ levelStart := 3
t.Run(testString("RefreshAndPermute/", parties, testCtx.params), func(t *testing.T) {
@@ -562,7 +562,7 @@ func testRefreshAndPermute(testCtx *testContext, t *testing.T) {
}
RefreshParties := make([]*Party, parties)
- for i := uint64(0); i < parties; i++ {
+ for i := 0; i < parties; i++ {
p := new(Party)
p.PermuteProtocol = NewPermuteProtocol(testCtx.params)
p.s = sk0Shards[i].Value
@@ -584,7 +584,7 @@ func testRefreshAndPermute(testCtx *testContext, t *testing.T) {
permutation := make([]uint64, testCtx.params.Slots())
for i := range permutation {
- permutation[i] = ring.RandUniform(testCtx.prng, testCtx.params.Slots(), testCtx.params.Slots()-1)
+ permutation[i] = ring.RandUniform(testCtx.prng, uint64(testCtx.params.Slots()), uint64(testCtx.params.Slots()-1))
}
for i, p := range RefreshParties {
@@ -619,7 +619,7 @@ func newTestVectors(testCtx *testContext, encryptor ckks.Encryptor, a float64, t
values = make([]complex128, slots)
- for i := uint64(0); i < slots; i++ {
+ for i := 0; i < slots; i++ {
values[i] = utils.RandComplex128(-a, a)
}
diff --git a/dckks/keyswitching.go b/dckks/keyswitching.go
index e3993dda..6945f8bc 100644
--- a/dckks/keyswitching.go
+++ b/dckks/keyswitching.go
@@ -47,7 +47,7 @@ func NewCKSProtocol(params *ckks.Parameters, sigmaSmudging float64) (cks *CKSPro
if err != nil {
panic(err)
}
- cks.gaussianSampler = ring.NewGaussianSampler(prng, dckksContext.ringQP, params.Sigma(), uint64(6*params.Sigma()))
+ cks.gaussianSampler = ring.NewGaussianSampler(prng, cks.dckksContext.ringQ, cks.dckksContext.params.Sigma(), int(6*cks.dckksContext.params.Sigma()))
return cks
}
@@ -74,12 +74,13 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *ckks.Ciphertext, s
ringQ := cks.dckksContext.ringQ
ringP := cks.dckksContext.ringP
+ sigma := cks.dckksContext.params.Sigma()
ringQ.MulCoeffsMontgomeryConstantLvl(ct.Level(), ct.Value()[1], skDelta, shareOut)
ringQ.MulScalarBigintLvl(ct.Level(), shareOut, ringP.ModulusBigint, shareOut)
- cks.gaussianSampler.ReadLvl(ct.Level(), cks.tmpQ)
+ cks.gaussianSampler.ReadFromDistLvl(ct.Level(), cks.tmpQ, ringQ, sigma, int(6*sigma))
extendBasisSmallNormAndCenter(ringQ, ringP, cks.tmpQ, cks.tmpP)
ringQ.NTTLvl(ct.Level(), cks.tmpQ, cks.tmpQ)
@@ -97,7 +98,7 @@ func (cks *CKSProtocol) genShareDelta(skDelta *ring.Poly, ct *ckks.Ciphertext, s
//
// [ctx[0] + sum((skInput_i - skOutput_i) * ctx[0] + e_i), ctx[1]]
func (cks *CKSProtocol) AggregateShares(share1, share2, shareOut CKSShare) {
- cks.dckksContext.ringQ.AddLvl(uint64(len(share1.Coeffs)-1), share1, share2, shareOut)
+ cks.dckksContext.ringQ.AddLvl(len(share1.Coeffs)-1, share1, share2, shareOut)
}
// KeySwitch performs the actual keyswitching operation on a ciphertext ct and put the result in ctOut
diff --git a/dckks/public_keyswitching.go b/dckks/public_keyswitching.go
index 2d9bc4e7..c68429ce 100644
--- a/dckks/public_keyswitching.go
+++ b/dckks/public_keyswitching.go
@@ -46,14 +46,14 @@ func NewPCKSProtocol(params *ckks.Parameters, sigmaSmudging float64) *PCKSProtoc
if err != nil {
panic(err)
}
- pcks.gaussianSampler = ring.NewGaussianSampler(prng, dckksContext.ringQP, params.Sigma(), uint64(6*params.Sigma()))
+ pcks.gaussianSampler = ring.NewGaussianSampler(prng, dckksContext.ringQ, pcks.dckksContext.params.Sigma(), int(6*pcks.dckksContext.params.Sigma()))
pcks.ternarySamplerMontgomery = ring.NewTernarySampler(prng, dckksContext.ringQP, 0.5, true)
return pcks
}
// AllocateShares allocates the share of the PCKS protocol.
-func (pcks *PCKSProtocol) AllocateShares(level uint64) (s PCKSShare) {
+func (pcks *PCKSProtocol) AllocateShares(level int) (s PCKSShare) {
s[0] = pcks.dckksContext.ringQ.NewPolyLvl(level)
s[1] = pcks.dckksContext.ringQ.NewPolyLvl(level)
return
@@ -70,6 +70,7 @@ func (pcks *PCKSProtocol) GenShare(sk *ring.Poly, pk *ckks.PublicKey, ct *ckks.C
ringQ := pcks.dckksContext.ringQ
ringQP := pcks.dckksContext.ringQP
+ sigma := pcks.dckksContext.params.Sigma()
pcks.ternarySamplerMontgomery.Read(pcks.tmp)
ringQP.NTTLazy(pcks.tmp, pcks.tmp)
@@ -80,11 +81,11 @@ func (pcks *PCKSProtocol) GenShare(sk *ring.Poly, pk *ckks.PublicKey, ct *ckks.C
ringQP.MulCoeffsMontgomeryConstant(pcks.tmp, pk.Value[1], pcks.share1tmp)
// h_0 = u_i * pk_0 + e0
- pcks.gaussianSampler.Read(pcks.tmp)
+ pcks.gaussianSampler.ReadFromDistLvl(len(ringQP.Modulus)-1, pcks.tmp, ringQP, sigma, int(sigma))
ringQP.NTT(pcks.tmp, pcks.tmp)
ringQP.Add(pcks.share0tmp, pcks.tmp, pcks.share0tmp)
// h_1 = u_i * pk_1 + e1
- pcks.gaussianSampler.Read(pcks.tmp)
+ pcks.gaussianSampler.ReadFromDistLvl(len(ringQP.Modulus)-1, pcks.tmp, ringQP, sigma, int(sigma))
ringQP.NTT(pcks.tmp, pcks.tmp)
ringQP.Add(pcks.share1tmp, pcks.tmp, pcks.share1tmp)
@@ -107,7 +108,7 @@ func (pcks *PCKSProtocol) GenShare(sk *ring.Poly, pk *ckks.PublicKey, ct *ckks.C
// [ctx[0] + sum(s_i * ctx[0] + u_i * pk[0] + e_0i), sum(u_i * pk[1] + e_1i)]
func (pcks *PCKSProtocol) AggregateShares(share1, share2, shareOut PCKSShare) {
- level := uint64(len(share1[0].Coeffs)) - 1
+ level := len(share1[0].Coeffs) - 1
pcks.dckksContext.ringQ.AddLvl(level, share1[0], share2[0], shareOut[0])
pcks.dckksContext.ringQ.AddLvl(level, share1[1], share2[1], shareOut[1])
}
diff --git a/dckks/public_permute.go b/dckks/public_permute.go
index e9bc310b..f3697516 100644
--- a/dckks/public_permute.go
+++ b/dckks/public_permute.go
@@ -22,7 +22,7 @@ type PermuteProtocol struct {
// NewPermuteProtocol creates a new instance of the PermuteProtocol.
func NewPermuteProtocol(params *ckks.Parameters) (pp *PermuteProtocol) {
- prec := uint64(256)
+ prec := int(256)
pp = new(PermuteProtocol)
pp.encoder = ckks.NewEncoderBigComplex(params, prec)
@@ -33,7 +33,7 @@ func NewPermuteProtocol(params *ckks.Parameters) (pp *PermuteProtocol) {
pp.maskFloat = make([]*big.Float, dckksContext.n)
pp.maskComplex = make([]*ring.Complex, dckksContext.n>>1)
- for i := uint64(0); i < dckksContext.n>>1; i++ {
+ for i := 0; i < dckksContext.n>>1; i++ {
pp.maskFloat[i] = new(big.Float)
pp.maskFloat[i].SetPrec(uint(prec))
@@ -47,13 +47,13 @@ func NewPermuteProtocol(params *ckks.Parameters) (pp *PermuteProtocol) {
if err != nil {
panic(err)
}
- pp.gaussianSampler = ring.NewGaussianSampler(prng, dckksContext.ringQ, params.Sigma(), uint64(6*params.Sigma()))
+ pp.gaussianSampler = ring.NewGaussianSampler(prng, pp.dckksContext.ringQ, pp.dckksContext.params.Sigma(), int(6*pp.dckksContext.params.Sigma()))
return
}
// AllocateShares allocates the shares of the Refresh protocol.
-func (pp *PermuteProtocol) AllocateShares(levelStart uint64) (RefreshShareDecrypt, RefreshShareRecrypt) {
+func (pp *PermuteProtocol) AllocateShares(levelStart int) (RefreshShareDecrypt, RefreshShareRecrypt) {
return pp.dckksContext.ringQ.NewPolyLvl(levelStart), pp.dckksContext.ringQ.NewPoly()
}
@@ -70,16 +70,17 @@ func (pp *PermuteProtocol) permuteWithIndex(permutation []uint64, values []*ring
}
// GenShares generates the decryption and recryption shares of the Refresh protocol.
-func (pp *PermuteProtocol) GenShares(sk *ring.Poly, levelStart, nParties uint64, ciphertext *ckks.Ciphertext, crs *ring.Poly, slots uint64, permutation []uint64, shareDecrypt RefreshShareDecrypt, shareRecrypt RefreshShareRecrypt) {
+func (pp *PermuteProtocol) GenShares(sk *ring.Poly, levelStart, nParties int, ciphertext *ckks.Ciphertext, crs *ring.Poly, slots int, permutation []uint64, shareDecrypt RefreshShareDecrypt, shareRecrypt RefreshShareRecrypt) {
ringQ := pp.dckksContext.ringQ
+ sigma := pp.dckksContext.params.Sigma()
bound := ring.NewUint(ringQ.Modulus[0])
- for i := uint64(1); i < levelStart+1; i++ {
+ for i := 1; i < levelStart+1; i++ {
bound.Mul(bound, ring.NewUint(ringQ.Modulus[i]))
}
- bound.Quo(bound, ring.NewUint(2*nParties))
+ bound.Quo(bound, ring.NewUint(uint64(2*nParties)))
boundHalf := new(big.Int).Rsh(bound, 1)
maxSlots := pp.dckksContext.n >> 1
@@ -87,7 +88,7 @@ func (pp *PermuteProtocol) GenShares(sk *ring.Poly, levelStart, nParties uint64,
// Samples the whole N coefficients for h0
var sign int
- for i := uint64(0); i < 2*maxSlots; i++ {
+ for i := 0; i < 2*maxSlots; i++ {
pp.maskBigint[i] = ring.RandInt(bound)
@@ -103,12 +104,12 @@ func (pp *PermuteProtocol) GenShares(sk *ring.Poly, levelStart, nParties uint64,
// h0 = sk*c1 + mask
ringQ.MulCoeffsMontgomeryAndAddLvl(levelStart, sk, ciphertext.Value()[1], shareDecrypt)
// h0 = sk*c1 + mask + e0
- pp.gaussianSampler.Read(pp.tmp)
- ringQ.NTT(pp.tmp, pp.tmp)
+ pp.gaussianSampler.ReadFromDistLvl(levelStart, pp.tmp, ringQ, sigma, int(6*sigma))
+ ringQ.NTTLvl(levelStart, pp.tmp, pp.tmp)
ringQ.AddLvl(levelStart, shareDecrypt, pp.tmp, shareDecrypt)
// Permutes only the (sparse) plaintext coefficients of h1
- for i, jdx, idx := uint64(0), maxSlots, uint64(0); i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
+ for i, jdx, idx := 0, maxSlots, 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
pp.maskFloat[idx].SetInt(pp.maskBigint[idx])
pp.maskFloat[jdx].SetInt(pp.maskBigint[jdx])
pp.maskComplex[idx][0] = pp.maskFloat[idx]
@@ -120,7 +121,7 @@ func (pp *PermuteProtocol) GenShares(sk *ring.Poly, levelStart, nParties uint64,
pp.permuteWithIndex(permutation, pp.maskComplex)
pp.encoder.InvFFT(pp.maskComplex, slots)
- for i, jdx, idx := uint64(0), maxSlots, uint64(0); i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
+ for i, jdx, idx := 0, maxSlots, 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
pp.maskComplex[i].Real().Int(pp.maskBigint[idx])
pp.maskComplex[i].Imag().Int(pp.maskBigint[jdx])
}
@@ -133,7 +134,7 @@ func (pp *PermuteProtocol) GenShares(sk *ring.Poly, levelStart, nParties uint64,
ringQ.MulCoeffsMontgomeryAndAdd(sk, crs, shareRecrypt)
// h1 = sk*a + mask + e1
- pp.gaussianSampler.Read(pp.tmp)
+ pp.gaussianSampler.ReadFromDistLvl(len(ringQ.Modulus)-1, pp.tmp, ringQ, sigma, int(6*sigma))
ringQ.NTT(pp.tmp, pp.tmp)
ringQ.Add(shareRecrypt, pp.tmp, shareRecrypt)
@@ -145,7 +146,7 @@ func (pp *PermuteProtocol) GenShares(sk *ring.Poly, levelStart, nParties uint64,
// Aggregate adds share1 with share2 on shareOut.
func (pp *PermuteProtocol) Aggregate(share1, share2, shareOut *ring.Poly) {
- pp.dckksContext.ringQ.AddLvl(uint64(len(share1.Coeffs)-1), share1, share2, shareOut)
+ pp.dckksContext.ringQ.AddLvl(len(share1.Coeffs)-1, share1, share2, shareOut)
}
// Decrypt operates a masked decryption on the ciphertext with the given decryption share.
@@ -155,7 +156,7 @@ func (pp *PermuteProtocol) Decrypt(ciphertext *ckks.Ciphertext, shareDecrypt Ref
// Permute takes a masked decrypted ciphertext at modulus Q_0 and returns the same masked decrypted ciphertext at modulus Q_L, with Q_0 << Q_L.
// Operates a permutation of the plaintext slots.
-func (pp *PermuteProtocol) Permute(ciphertext *ckks.Ciphertext, permutation []uint64, slots uint64) {
+func (pp *PermuteProtocol) Permute(ciphertext *ckks.Ciphertext, permutation []uint64, slots int) {
dckksContext := pp.dckksContext
ringQ := pp.dckksContext.ringQ
@@ -164,7 +165,7 @@ func (pp *PermuteProtocol) Permute(ciphertext *ckks.Ciphertext, permutation []ui
ringQ.PolyToBigint(ciphertext.Value()[0], pp.maskBigint)
QStart := ring.NewUint(ringQ.Modulus[0])
- for i := uint64(1); i < ciphertext.Level()+1; i++ {
+ for i := 1; i < ciphertext.Level()+1; i++ {
QStart.Mul(QStart, ring.NewUint(ringQ.Modulus[i]))
}
QHalf := new(big.Int).Rsh(QStart, 1)
@@ -173,7 +174,7 @@ func (pp *PermuteProtocol) Permute(ciphertext *ckks.Ciphertext, permutation []ui
gap := maxSlots / slots
var sign int
- for i, idx := uint64(0), uint64(0); i < slots; i, idx = i+1, idx+gap {
+ for i, idx := 0, 0; i < slots; i, idx = i+1, idx+gap {
// Centers the value around the current modulus
sign = pp.maskBigint[idx].Cmp(QHalf)
@@ -197,7 +198,7 @@ func (pp *PermuteProtocol) Permute(ciphertext *ckks.Ciphertext, permutation []ui
pp.encoder.InvFFT(pp.maskComplex, slots)
- for i, jdx, idx := uint64(0), maxSlots, uint64(0); i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
+ for i, jdx, idx := 0, maxSlots, 0; i < slots; i, jdx, idx = i+1, jdx+gap, idx+gap {
pp.maskComplex[i].Real().Int(pp.maskBigint[idx])
pp.maskComplex[i].Imag().Int(pp.maskBigint[jdx])
}
diff --git a/dckks/public_refresh.go b/dckks/public_refresh.go
index 803daadc..0757a7de 100644
--- a/dckks/public_refresh.go
+++ b/dckks/public_refresh.go
@@ -34,27 +34,28 @@ func NewRefreshProtocol(params *ckks.Parameters) (refreshProtocol *RefreshProtoc
if err != nil {
panic(err)
}
- refreshProtocol.gaussianSampler = ring.NewGaussianSampler(prng, dckksContext.ringQ, params.Sigma(), uint64(6*params.Sigma()))
+ refreshProtocol.gaussianSampler = ring.NewGaussianSampler(prng, refreshProtocol.dckksContext.ringQ, refreshProtocol.dckksContext.params.Sigma(), int(6*refreshProtocol.dckksContext.params.Sigma()))
return
}
// AllocateShares allocates the shares of the Refresh protocol.
-func (refreshProtocol *RefreshProtocol) AllocateShares(levelStart uint64) (RefreshShareDecrypt, RefreshShareRecrypt) {
+func (refreshProtocol *RefreshProtocol) AllocateShares(levelStart int) (RefreshShareDecrypt, RefreshShareRecrypt) {
return refreshProtocol.dckksContext.ringQ.NewPolyLvl(levelStart), refreshProtocol.dckksContext.ringQ.NewPoly()
}
// GenShares generates the decryption and recryption shares of the Refresh protocol.
-func (refreshProtocol *RefreshProtocol) GenShares(sk *ring.Poly, levelStart, nParties uint64, ciphertext *ckks.Ciphertext, crs *ring.Poly, shareDecrypt RefreshShareDecrypt, shareRecrypt RefreshShareRecrypt) {
+func (refreshProtocol *RefreshProtocol) GenShares(sk *ring.Poly, levelStart, nParties int, ciphertext *ckks.Ciphertext, targetScale float64, crs *ring.Poly, shareDecrypt RefreshShareDecrypt, shareRecrypt RefreshShareRecrypt) {
ringQ := refreshProtocol.dckksContext.ringQ
+ sigma := refreshProtocol.dckksContext.params.Sigma()
bound := ring.NewUint(ringQ.Modulus[0])
- for i := uint64(1); i < levelStart+1; i++ {
+ for i := 1; i < levelStart+1; i++ {
bound.Mul(bound, ring.NewUint(ringQ.Modulus[i]))
}
- bound.Quo(bound, ring.NewUint(2*nParties))
+ bound.Quo(bound, ring.NewUint(uint64(2*nParties)))
boundHalf := new(big.Int).Rsh(bound, 1)
var sign int
@@ -68,6 +69,22 @@ func (refreshProtocol *RefreshProtocol) GenShares(sk *ring.Poly, levelStart, nPa
// h0 = mask (at level min)
ringQ.SetCoefficientsBigintLvl(levelStart, refreshProtocol.maskBigint, shareDecrypt)
+
+ inputScaleFlo := ring.NewFloat(ciphertext.Scale(), 256)
+ outputScaleFlo := ring.NewFloat(targetScale, 256)
+
+ inputScaleInt := new(big.Int)
+ outputScaleInt := new(big.Int)
+
+ inputScaleFlo.Int(inputScaleInt)
+ outputScaleFlo.Int(outputScaleInt)
+
+ // Scales the mask by the ratio between the two scales
+ for i := range refreshProtocol.maskBigint {
+ refreshProtocol.maskBigint[i].Mul(refreshProtocol.maskBigint[i], outputScaleInt)
+ refreshProtocol.maskBigint[i].Quo(refreshProtocol.maskBigint[i], inputScaleInt)
+ }
+
// h1 = mask (at level max)
ringQ.SetCoefficientsBigint(refreshProtocol.maskBigint, shareRecrypt)
@@ -85,12 +102,12 @@ func (refreshProtocol *RefreshProtocol) GenShares(sk *ring.Poly, levelStart, nPa
ringQ.MulCoeffsMontgomeryAndAdd(sk, crs, shareRecrypt)
// h0 = sk*c1 + mask + e0
- refreshProtocol.gaussianSampler.Read(refreshProtocol.tmp)
- ringQ.NTT(refreshProtocol.tmp, refreshProtocol.tmp)
+ refreshProtocol.gaussianSampler.ReadFromDistLvl(levelStart, refreshProtocol.tmp, ringQ, sigma, int(6*sigma))
+ ringQ.NTTLvl(levelStart, refreshProtocol.tmp, refreshProtocol.tmp)
ringQ.AddLvl(levelStart, shareDecrypt, refreshProtocol.tmp, shareDecrypt)
// h1 = sk*a + mask + e1
- refreshProtocol.gaussianSampler.Read(refreshProtocol.tmp)
+ refreshProtocol.gaussianSampler.ReadFromDistLvl(len(ringQ.Modulus)-1, refreshProtocol.tmp, ringQ, sigma, int(6*sigma))
ringQ.NTT(refreshProtocol.tmp, refreshProtocol.tmp)
ringQ.Add(shareRecrypt, refreshProtocol.tmp, shareRecrypt)
@@ -102,7 +119,7 @@ func (refreshProtocol *RefreshProtocol) GenShares(sk *ring.Poly, levelStart, nPa
// Aggregate adds share1 with share2 on shareOut.
func (refreshProtocol *RefreshProtocol) Aggregate(share1, share2, shareOut *ring.Poly) {
- refreshProtocol.dckksContext.ringQ.AddLvl(uint64(len(share1.Coeffs)-1), share1, share2, shareOut)
+ refreshProtocol.dckksContext.ringQ.AddLvl(len(share1.Coeffs)-1, share1, share2, shareOut)
}
// Decrypt operates a masked decryption on the ciphertext with the given decryption share.
@@ -111,16 +128,25 @@ func (refreshProtocol *RefreshProtocol) Decrypt(ciphertext *ckks.Ciphertext, sha
}
// Recode takes a masked decrypted ciphertext at modulus Q_0 and returns the same masked decrypted ciphertext at modulus Q_L, with Q_0 << Q_L.
-func (refreshProtocol *RefreshProtocol) Recode(ciphertext *ckks.Ciphertext) {
+func (refreshProtocol *RefreshProtocol) Recode(ciphertext *ckks.Ciphertext, targetScale float64) {
dckksContext := refreshProtocol.dckksContext
ringQ := refreshProtocol.dckksContext.ringQ
+ inputScaleFlo := ring.NewFloat(ciphertext.Scale(), 256)
+ outputScaleFlo := ring.NewFloat(targetScale, 256)
+
+ inputScaleInt := new(big.Int)
+ outputScaleInt := new(big.Int)
+
+ inputScaleFlo.Int(inputScaleInt)
+ outputScaleFlo.Int(outputScaleInt)
+
ringQ.InvNTTLvl(ciphertext.Level(), ciphertext.Value()[0], ciphertext.Value()[0])
ringQ.PolyToBigint(ciphertext.Value()[0], refreshProtocol.maskBigint)
QStart := ring.NewUint(ringQ.Modulus[0])
- for i := uint64(1); i < ciphertext.Level()+1; i++ {
+ for i := 1; i < ciphertext.Level()+1; i++ {
QStart.Mul(QStart, ring.NewUint(ringQ.Modulus[i]))
}
@@ -132,16 +158,21 @@ func (refreshProtocol *RefreshProtocol) Recode(ciphertext *ckks.Ciphertext) {
}
var sign int
- for i := uint64(0); i < dckksContext.n; i++ {
+ for i := 0; i < dckksContext.n; i++ {
sign = refreshProtocol.maskBigint[i].Cmp(QHalf)
if sign == 1 || sign == 0 {
refreshProtocol.maskBigint[i].Sub(refreshProtocol.maskBigint[i], QStart)
}
+
+ refreshProtocol.maskBigint[i].Mul(refreshProtocol.maskBigint[i], outputScaleInt)
+ refreshProtocol.maskBigint[i].Quo(refreshProtocol.maskBigint[i], inputScaleInt)
}
ringQ.SetCoefficientsBigintLvl(ciphertext.Level(), refreshProtocol.maskBigint, ciphertext.Value()[0])
ringQ.NTTLvl(ciphertext.Level(), ciphertext.Value()[0], ciphertext.Value()[0])
+
+ ciphertext.SetScale(targetScale)
}
// Recrypt operates a masked recryption on the masked decrypted ciphertext.
diff --git a/dckks/utils.go b/dckks/utils.go
index 6237815c..719d6b59 100644
--- a/dckks/utils.go
+++ b/dckks/utils.go
@@ -9,7 +9,7 @@ func extendBasisSmallNormAndCenter(ringQ, ringP *ring.Ring, polQ, polP *ring.Pol
Q = ringQ.Modulus[0]
QHalf = Q >> 1
- for j := uint64(0); j < ringQ.N; j++ {
+ for j := 0; j < ringQ.N; j++ {
coeff = polQ.Coeffs[0][j]
diff --git a/drlwe/public_key_gen.go b/drlwe/public_key_gen.go
index 4630ae91..c0a4d0d8 100644
--- a/drlwe/public_key_gen.go
+++ b/drlwe/public_key_gen.go
@@ -17,7 +17,7 @@ type CollectivePublicKeyGenerator interface {
// CKGProtocol is the structure storing the parameters and and precomputations for the collective key generation protocol.
type CKGProtocol struct {
- n uint64
+ n int
ringQ *ring.Ring
ringP *ring.Ring
@@ -40,7 +40,7 @@ func (share *CKGShare) UnmarshalBinary(data []byte) error {
}
// NewCKGProtocol creates a new CKGProtocol instance
-func NewCKGProtocol(n uint64, q, p []uint64, sigma float64) *CKGProtocol { // TODO drlwe.Params
+func NewCKGProtocol(n int, q, p []uint64, sigma float64) *CKGProtocol { // TODO drlwe.Params
ckg := new(CKGProtocol)
var err error
@@ -60,7 +60,7 @@ func NewCKGProtocol(n uint64, q, p []uint64, sigma float64) *CKGProtocol { // TO
if err != nil {
panic(err)
}
- ckg.gaussianSampler = ring.NewGaussianSampler(prng, ckg.ringQP, sigma, uint64(6*sigma))
+ ckg.gaussianSampler = ring.NewGaussianSampler(prng, ckg.ringQP, sigma, int(6*sigma))
return ckg
}
diff --git a/drlwe/relin_key_gen.go b/drlwe/relin_key_gen.go
index 069917e1..359b5981 100644
--- a/drlwe/relin_key_gen.go
+++ b/drlwe/relin_key_gen.go
@@ -20,10 +20,10 @@ type RelinearizationKeyGenerator interface {
// RKGProtocol is the structure storing the parameters and and precomputations for the collective relinearization key generation protocol.
type RKGProtocol struct {
- ringQModCount uint64
- ringQPModCount uint64
- alpha uint64
- beta uint64
+ ringQModCount int
+ ringQPModCount int
+ alpha int
+ beta int
ringP *ring.Ring
ringQP *ring.Ring
gaussianSampler *ring.GaussianSampler
@@ -39,13 +39,13 @@ type RKGShare struct {
}
// NewRKGProtocol creates a new RKG protocol struct
-func NewRKGProtocol(n uint64, q, p []uint64, ephSkPr, sigma float64) *RKGProtocol {
+func NewRKGProtocol(n int, q, p []uint64, ephSkPr, sigma float64) *RKGProtocol {
rkg := new(RKGProtocol)
- rkg.ringQModCount = uint64(len(q))
- rkg.alpha = uint64(len(p))
+ rkg.ringQModCount = len(q)
+ rkg.alpha = len(p)
rkg.ringQPModCount = rkg.ringQModCount + rkg.alpha
if rkg.alpha != 0 {
- rkg.beta = uint64(math.Ceil(float64(len(q)) / float64(len(p))))
+ rkg.beta = int(math.Ceil(float64(len(q)) / float64(len(p))))
} else {
rkg.beta = 1
}
@@ -62,7 +62,7 @@ func NewRKGProtocol(n uint64, q, p []uint64, ephSkPr, sigma float64) *RKGProtoco
if err != nil {
panic(err) // TODO error
}
- rkg.gaussianSampler = ring.NewGaussianSampler(prng, rkg.ringQP, sigma, uint64(6*sigma))
+ rkg.gaussianSampler = ring.NewGaussianSampler(prng, rkg.ringQP, sigma, int(6*sigma))
rkg.ternarySampler = ring.NewTernarySampler(prng, rkg.ringQP, ephSkPr, true)
rkg.tmpPoly1, rkg.tmpPoly2 = rkg.ringQP.NewPoly(), rkg.ringQP.NewPoly()
return rkg
@@ -74,7 +74,7 @@ func (ekg *RKGProtocol) AllocateShares() (ephSk *rlwe.SecretKey, r1 *RKGShare, r
r1, r2 = new(RKGShare), new(RKGShare)
r1.value = make([][2]*ring.Poly, ekg.beta)
r2.value = make([][2]*ring.Poly, ekg.beta)
- for i := uint64(0); i < ekg.beta; i++ {
+ for i := 0; i < ekg.beta; i++ {
r1.value[i][0] = ekg.ringQP.NewPoly()
r1.value[i][1] = ekg.ringQP.NewPoly()
r2.value[i][0] = ekg.ringQP.NewPoly()
@@ -96,19 +96,19 @@ func (ekg *RKGProtocol) GenShareRoundOne(sk *rlwe.SecretKey, crp []*ring.Poly, e
ekg.ternarySampler.Read(ephSkOut.Value)
ekg.ringQP.NTT(ephSkOut.Value, ephSkOut.Value)
- for i := uint64(0); i < ekg.beta; i++ {
+ for i := 0; i < ekg.beta; i++ {
// h = e
ekg.gaussianSampler.Read(shareOut.value[i][0])
ekg.ringQP.NTT(shareOut.value[i][0], shareOut.value[i][0])
// h = sk*CrtBaseDecompQi + e
- for j := uint64(0); j < ekg.alpha; j++ {
+ for j := 0; j < ekg.alpha; j++ {
index := i*ekg.alpha + j
qi := ekg.ringQP.Modulus[index]
skP := ekg.tmpPoly1.Coeffs[index]
h := shareOut.value[i][0].Coeffs[index]
- for w := uint64(0); w < ekg.ringQP.N; w++ {
+ for w := 0; w < ekg.ringQP.N; w++ {
h[w] = ring.CRed(h[w]+skP[w], qi)
}
@@ -145,7 +145,7 @@ func (ekg *RKGProtocol) GenShareRoundTwo(ephSk, sk *rlwe.SecretKey, round1 *RKGS
// Each sample is of the form [-u*a_i + s*w_i + e_i]
// So for each element of the base decomposition w_i :
- for i := uint64(0); i < ekg.beta; i++ {
+ for i := 0; i < ekg.beta; i++ {
// Computes [(sum samples)*sk + e_1i, sk*a + e_2i]
@@ -169,7 +169,7 @@ func (ekg *RKGProtocol) GenShareRoundTwo(ephSk, sk *rlwe.SecretKey, round1 *RKGS
// AggregateShares combines two RKG shares into a single one
func (ekg *RKGProtocol) AggregateShares(share1, share2, shareOut *RKGShare) {
- for i := uint64(0); i < ekg.beta; i++ {
+ for i := 0; i < ekg.beta; i++ {
ekg.ringQP.Add(share1.value[i][0], share2.value[i][0], shareOut.value[i][0])
ekg.ringQP.Add(share1.value[i][1], share2.value[i][1], shareOut.value[i][1])
}
@@ -177,7 +177,7 @@ func (ekg *RKGProtocol) AggregateShares(share1, share2, shareOut *RKGShare) {
// GenRelinearizationKey computes the generated RLK from the public shares and write the result in evalKeyOut
func (ekg *RKGProtocol) GenRelinearizationKey(round1 *RKGShare, round2 *RKGShare, evalKeyOut *rlwe.RelinearizationKey) {
- for i := uint64(0); i < ekg.beta; i++ {
+ for i := 0; i < ekg.beta; i++ {
ekg.ringQP.Add(round2.value[i][0], round2.value[i][1], evalKeyOut.Keys[0].Value[i][0])
evalKeyOut.Keys[0].Value[i][1].Copy(round1.value[i][1])
@@ -190,7 +190,7 @@ func (ekg *RKGProtocol) GenRelinearizationKey(round1 *RKGShare, round2 *RKGShare
func (share *RKGShare) MarshalBinary() ([]byte, error) {
//we have modulus * bitLog * Len of 1 ring rings
rLength := (share.value[0])[0].GetDataLen(true)
- data := make([]byte, 1+2*rLength*uint64(len(share.value)))
+ data := make([]byte, 1+2*rLength*len(share.value))
if len(share.value) > 0xFF {
return []byte{}, errors.New("RKGShare : uint8 overflow on length")
}
@@ -198,7 +198,7 @@ func (share *RKGShare) MarshalBinary() ([]byte, error) {
//write all of our rings in the data.
//write all the polys
- ptr := uint64(1)
+ ptr := 1
for _, elem := range share.value {
_, err := elem[0].WriteTo(data[ptr : ptr+rLength])
if err != nil {
diff --git a/drlwe/rot_key_gen.go b/drlwe/rot_key_gen.go
index e084c0bd..61eff2c1 100644
--- a/drlwe/rot_key_gen.go
+++ b/drlwe/rot_key_gen.go
@@ -28,25 +28,25 @@ type RTGShare struct {
type RTGProtocol struct { // TODO rename GaloisKeyGen ?
ringQP *ring.Ring
ringPModulusBigint *big.Int
- ringQModCount uint64
- alpha uint64
- beta uint64
+ ringQModCount int
+ alpha int
+ beta int
tmpPoly [2]*ring.Poly
gaussianSampler *ring.GaussianSampler
}
// NewRTGProtocol creates a RTGProtocol instance
-func NewRTGProtocol(n uint64, q, p []uint64, sigma float64) *RTGProtocol {
+func NewRTGProtocol(n int, q, p []uint64, sigma float64) *RTGProtocol {
rtg := new(RTGProtocol)
- rtg.ringQModCount = uint64(len(q))
+ rtg.ringQModCount = len(q)
rtg.ringPModulusBigint = big.NewInt(1)
for _, pi := range p {
rtg.ringPModulusBigint.Mul(rtg.ringPModulusBigint, new(big.Int).SetUint64(pi))
}
- rtg.alpha = uint64(len(p))
+ rtg.alpha = len(p)
if rtg.alpha != 0 {
- rtg.beta = uint64(math.Ceil(float64(len(q)) / float64(len(p))))
+ rtg.beta = int(math.Ceil(float64(len(q)) / float64(len(p))))
} else {
rtg.beta = 1
}
@@ -60,7 +60,7 @@ func NewRTGProtocol(n uint64, q, p []uint64, sigma float64) *RTGProtocol {
if err != nil {
panic(err)
}
- rtg.gaussianSampler = ring.NewGaussianSampler(prng, rtg.ringQP, sigma, uint64(6*sigma))
+ rtg.gaussianSampler = ring.NewGaussianSampler(prng, rtg.ringQP, sigma, int(6*sigma))
rtg.tmpPoly = [2]*ring.Poly{rtg.ringQP.NewPoly(), rtg.ringQP.NewPoly()}
@@ -81,15 +81,15 @@ func (rtg *RTGProtocol) AllocateShares() (rtgShare *RTGShare) {
func (rtg *RTGProtocol) GenShare(sk *rlwe.SecretKey, galEl uint64, crp []*ring.Poly, shareOut *RTGShare) {
twoN := rtg.ringQP.N << 2
- galElInv := ring.ModExp(galEl, twoN-1, twoN)
+ galElInv := ring.ModExp(galEl, int(twoN-1), uint64(twoN))
ring.PermuteNTT(sk.Value, galElInv, rtg.tmpPoly[1])
rtg.ringQP.MulScalarBigint(sk.Value, rtg.ringPModulusBigint, rtg.tmpPoly[0])
- var index uint64
+ var index int
- for i := uint64(0); i < rtg.beta; i++ {
+ for i := 0; i < rtg.beta; i++ {
// e
rtg.gaussianSampler.Read(shareOut.Value[i])
@@ -100,7 +100,7 @@ func (rtg *RTGProtocol) GenShare(sk *rlwe.SecretKey, galEl uint64, crp []*ring.P
// e + sk_in * (qiBarre*qiStar) * 2^w
// (qiBarre*qiStar)%qi = 1, else 0
- for j := uint64(0); j < rtg.alpha; j++ {
+ for j := 0; j < rtg.alpha; j++ {
index = i*rtg.alpha + j
@@ -108,7 +108,7 @@ func (rtg *RTGProtocol) GenShare(sk *rlwe.SecretKey, galEl uint64, crp []*ring.P
tmp0 := rtg.tmpPoly[0].Coeffs[index]
tmp1 := shareOut.Value[i].Coeffs[index]
- for w := uint64(0); w < rtg.ringQP.N; w++ {
+ for w := 0; w < rtg.ringQP.N; w++ {
tmp1[w] = ring.CRed(tmp1[w]+tmp0[w], qi)
}
@@ -130,14 +130,14 @@ func (rtg *RTGProtocol) GenShare(sk *rlwe.SecretKey, galEl uint64, crp []*ring.P
// Aggregate aggregates two shares in the Rotation Key Generation protocol
func (rtg *RTGProtocol) Aggregate(share1, share2, shareOut *RTGShare) {
- for i := uint64(0); i < rtg.beta; i++ {
+ for i := 0; i < rtg.beta; i++ {
rtg.ringQP.Add(share1.Value[i], share2.Value[i], shareOut.Value[i])
}
}
// GenRotationKey finalizes the RTG protocol and populates the input RotationKey with the computed collective SwitchingKey.
func (rtg *RTGProtocol) GenRotationKey(share *RTGShare, crp []*ring.Poly, rotKey *rlwe.SwitchingKey) {
- for i := uint64(0); i < rtg.beta; i++ {
+ for i := 0; i < rtg.beta; i++ {
rtg.ringQP.Copy(share.Value[i], rotKey.Value[i][0])
rtg.ringQP.Copy(crp[i], rotKey.Value[i][1])
}
@@ -146,9 +146,9 @@ func (rtg *RTGProtocol) GenRotationKey(share *RTGShare, crp []*ring.Poly, rotKey
// MarshalBinary encode the target element on a slice of byte.
func (share *RTGShare) MarshalBinary() ([]byte, error) {
lenRing := share.Value[0].GetDataLen(true)
- data := make([]byte, 8+lenRing*uint64(len(share.Value)))
- binary.BigEndian.PutUint64(data[:8], lenRing)
- ptr := uint64(8)
+ data := make([]byte, 8+lenRing*len(share.Value))
+ binary.BigEndian.PutUint64(data[:8], uint64(lenRing))
+ ptr := 8
for _, val := range share.Value {
cnt, err := val.WriteTo(data[ptr : ptr+lenRing])
if err != nil {
diff --git a/examples/bfv/main.go b/examples/bfv/main.go
index c3fa319b..188d285c 100644
--- a/examples/bfv/main.go
+++ b/examples/bfv/main.go
@@ -47,7 +47,7 @@ func obliviousRiding() {
// The rider decrypts the result and chooses the closest driver.
// Number of drivers in the area
- nbDrivers := uint64(2048) //max is N
+ nbDrivers := 2048 //max is N
// BFV parameters (128 bit security) with plaintext modulus 65929217
params := bfv.DefaultParams[bfv.PN13QP218].WithT(0x3ee0001)
@@ -75,8 +75,8 @@ func obliviousRiding() {
1< 61 {
panic("logQ must be between 1 and 61")
@@ -28,13 +28,13 @@ func GenerateNTTPrimes(logQ, NthRoot, n uint64) (primes []uint64) {
// NextNTTPrime returns the next NthRoot NTT prime after q.
// The input q must be itself an NTT prime for the given NthRoot.
-func NextNTTPrime(q, NthRoot uint64) (qNext uint64, err error) {
+func NextNTTPrime(q uint64, NthRoot int) (qNext uint64, err error) {
- qNext = q + NthRoot
+ qNext = q + uint64(NthRoot)
for !IsPrime(qNext) {
- qNext += NthRoot
+ qNext += uint64(NthRoot)
if bits.Len64(qNext) > 61 {
return 0, fmt.Errorf("next NTT prime exceeds the maximum bit-size of 61 bits")
@@ -46,21 +46,21 @@ func NextNTTPrime(q, NthRoot uint64) (qNext uint64, err error) {
// PreviousNTTPrime returns the previous NthRoot NTT prime after q.
// The input q must be itself an NTT prime for the given NthRoot.
-func PreviousNTTPrime(q, NthRoot uint64) (qPrev uint64, err error) {
+func PreviousNTTPrime(q uint64, NthRoot int) (qPrev uint64, err error) {
- if q < NthRoot {
+ if q < uint64(NthRoot) {
return 0, fmt.Errorf("previous NTT prime is smaller than NthRoot")
}
- qPrev = q - NthRoot
+ qPrev = q - uint64(NthRoot)
for !IsPrime(qPrev) {
- if q < NthRoot {
+ if q < uint64(NthRoot) {
return 0, fmt.Errorf("previous NTT prime is smaller than NthRoot")
}
- qPrev -= NthRoot
+ qPrev -= uint64(NthRoot)
}
return qPrev, nil
@@ -68,14 +68,14 @@ func PreviousNTTPrime(q, NthRoot uint64) (qPrev uint64, err error) {
// GenerateNTTPrimesQ generates "levels" different NthRoot NTT-friendly
// primes starting from 2**LogQ and alternating between upward and downward.
-func GenerateNTTPrimesQ(logQ, NthRoot, levels uint64) (primes []uint64) {
+func GenerateNTTPrimesQ(logQ, NthRoot, levels int) (primes []uint64) {
var nextPrime, previousPrime, Qpow2 uint64
var checkfornextprime, checkforpreviousprime bool
primes = []uint64{}
- Qpow2 = 1 << logQ
+ Qpow2 = uint64(1 << logQ)
nextPrime = Qpow2 + 1
previousPrime = Qpow2 + 1
@@ -91,19 +91,19 @@ func GenerateNTTPrimesQ(logQ, NthRoot, levels uint64) (primes []uint64) {
if checkfornextprime {
- if nextPrime > 0xffffffffffffffff-NthRoot {
+ if nextPrime > 0xffffffffffffffff-uint64(NthRoot) {
checkfornextprime = false
} else {
- nextPrime += NthRoot
+ nextPrime += uint64(NthRoot)
if IsPrime(nextPrime) {
primes = append(primes, nextPrime)
- if uint64(len(primes)) == levels {
+ if len(primes) == levels {
return
}
}
@@ -112,19 +112,19 @@ func GenerateNTTPrimesQ(logQ, NthRoot, levels uint64) (primes []uint64) {
if checkforpreviousprime {
- if previousPrime < NthRoot {
+ if previousPrime < uint64(NthRoot) {
checkforpreviousprime = false
} else {
- previousPrime -= NthRoot
+ previousPrime -= uint64(NthRoot)
if IsPrime(previousPrime) {
primes = append(primes, previousPrime)
- if uint64(len(primes)) == levels {
+ if len(primes) == levels {
return
}
}
@@ -137,13 +137,13 @@ func GenerateNTTPrimesQ(logQ, NthRoot, levels uint64) (primes []uint64) {
// GenerateNTTPrimesP generates "levels" different NthRoot NTT-friendly
// primes starting from 2**LogP and downward.
// Special case were primes close to 2^{LogP} but with a smaller bit-size than LogP are sought.
-func GenerateNTTPrimesP(logP, NthRoot, n uint64) (primes []uint64) {
+func GenerateNTTPrimesP(logP, NthRoot, n int) (primes []uint64) {
var x, Ppow2 uint64
primes = []uint64{}
- Ppow2 = 1 << logP
+ Ppow2 = uint64(1 << logP)
x = Ppow2 + 1
@@ -151,15 +151,15 @@ func GenerateNTTPrimesP(logP, NthRoot, n uint64) (primes []uint64) {
// We start by subtracting 2N to ensure that the prime bit-length is smaller than LogP
- if x > NthRoot {
+ if x > uint64(NthRoot) {
- x -= NthRoot
+ x -= uint64(NthRoot)
if IsPrime(x) {
primes = append(primes, x)
- if uint64(len(primes)) == n {
+ if len(primes) == n {
return primes
}
}
diff --git a/ring/ring.go b/ring/ring.go
index e3a741f2..04afcfea 100644
--- a/ring/ring.go
+++ b/ring/ring.go
@@ -17,7 +17,7 @@ import (
type Ring struct {
// Polynomial nb.Coefficients
- N uint64
+ N int
// Moduli
Modulus []uint64
@@ -26,7 +26,7 @@ type Ring struct {
Mask []uint64
// Indicates whether NTT can be used with the current ring.
- allowsNTT bool
+ AllowsNTT bool
// Product of the Moduli
ModulusBigint *big.Int
@@ -44,12 +44,14 @@ type Ring struct {
NttPsi [][]uint64 //powers of the inverse of the 2N-th primitive root in Montgomery form (in bit-reversed order)
NttPsiInv [][]uint64 //powers of the inverse of the 2N-th primitive root in Montgomery form (in bit-reversed order)
NttNInv []uint64 //[N^-1] mod Qi in Montgomery form
+
+ polypool *Poly
}
// NewRing creates a new RNS Ring with degree N and coefficient moduli Moduli. N must be a power of two larger than 8. Moduli should be
// a non-empty []uint64 with distinct prime elements. For the Ring instance to support NTT operation, these elements must also be equal
// to 1 modulo 2*N. Non-nil r and error are returned in the case of non NTT-enabling parameters.
-func NewRing(N uint64, Moduli []uint64) (r *Ring, err error) {
+func NewRing(N int, Moduli []uint64) (r *Ring, err error) {
r = new(Ring)
err = r.setParameters(N, Moduli)
if err != nil {
@@ -60,7 +62,7 @@ func NewRing(N uint64, Moduli []uint64) (r *Ring, err error) {
// setParameters initializes a *Ring by setting the required precomputed values (except for the NTT-related values, which are set by the
// genNTTParams function).
-func (r *Ring) setParameters(N uint64, Modulus []uint64) error {
+func (r *Ring) setParameters(N int, Modulus []uint64) error {
// Checks if N is a power of 2
if (N < 16) || (N&(N-1)) != 0 && N != 0 {
@@ -75,7 +77,7 @@ func (r *Ring) setParameters(N uint64, Modulus []uint64) error {
return errors.New("invalid modulus (moduli are not distinct)")
}
- r.allowsNTT = false
+ r.AllowsNTT = false
r.N = N
@@ -108,6 +110,9 @@ func (r *Ring) setParameters(N uint64, Modulus []uint64) error {
r.MredParams[i] = MRedParams(qi)
}
}
+
+ r.polypool = r.NewPoly()
+
return nil
}
@@ -116,7 +121,7 @@ func (r *Ring) setParameters(N uint64, Modulus []uint64) error {
// NTT parameters.
func (r *Ring) genNTTParams() error {
- if r.allowsNTT {
+ if r.AllowsNTT {
return nil
}
@@ -130,8 +135,8 @@ func (r *Ring) genNTTParams() error {
return fmt.Errorf("invalid modulus (Modulus[%d] is not prime)", i)
}
- if qi&((r.N<<1)-1) != 1 {
- r.allowsNTT = false
+ if int(qi)&((r.N<<1)-1) != 1 {
+ r.AllowsNTT = false
return fmt.Errorf("invalid modulus (Modulus[%d] != 1 mod 2N)", i)
}
}
@@ -144,7 +149,7 @@ func (r *Ring) genNTTParams() error {
for i := 0; i < j; i++ {
- r.RescaleParams[j-1][i] = MForm(ModExp(r.Modulus[j], r.Modulus[i]-2, r.Modulus[i]), r.Modulus[i], r.BredParams[i])
+ r.RescaleParams[j-1][i] = MForm(r.Modulus[i]-ModExp(r.Modulus[j], int(r.Modulus[i]-2), r.Modulus[i]), r.Modulus[i], r.BredParams[i])
}
}
@@ -154,12 +159,12 @@ func (r *Ring) genNTTParams() error {
r.NttPsiInv = make([][]uint64, len(r.Modulus))
r.NttNInv = make([]uint64, len(r.Modulus))
- bitLenofN := uint64(bits.Len64(r.N) - 1)
+ bitLenofN := bits.Len64(uint64(r.N)) - 1
for i, qi := range r.Modulus {
// 1.1 Compute N^(-1) mod Q in Montgomery form
- r.NttNInv[i] = MForm(ModExp(r.N, qi-2, qi), qi, r.BredParams[i])
+ r.NttNInv[i] = MForm(ModExp(uint64(r.N), int(qi-2), qi), qi, r.BredParams[i])
// 1.2 Compute Psi and PsiInv in Montgomery form
r.NttPsi[i] = make([]uint64, r.N)
@@ -168,10 +173,10 @@ func (r *Ring) genNTTParams() error {
// Finds a 2N-th primitive Root
g := primitiveRoot(qi)
- _2n := uint64(r.N << 1)
+ _2n := r.N << 1
- power := (qi - 1) / _2n
- powerInv := (qi - 1) - power
+ power := (int(qi) - 1) / _2n
+ powerInv := (int(qi) - 1) - power
// Computes Psi and PsiInv in Montgomery form
PsiMont := MForm(ModExp(g, power, qi), qi, r.BredParams[i])
@@ -184,24 +189,24 @@ func (r *Ring) genNTTParams() error {
r.NttPsiInv[i][0] = MForm(1, qi, r.BredParams[i])
// Compute nttPsi[j] = nttPsi[j-1]*Psi and nttPsiInv[j] = nttPsiInv[j-1]*PsiInv
- for j := uint64(1); j < r.N; j++ {
+ for j := 1; j < r.N; j++ {
- indexReversePrev := utils.BitReverse64(j-1, bitLenofN)
- indexReverseNext := utils.BitReverse64(j, bitLenofN)
+ indexReversePrev := utils.BitReverse64(uint64(j-1), uint64(bitLenofN))
+ indexReverseNext := utils.BitReverse64(uint64(j), uint64(bitLenofN))
r.NttPsi[i][indexReverseNext] = MRed(r.NttPsi[i][indexReversePrev], PsiMont, qi, r.MredParams[i])
r.NttPsiInv[i][indexReverseNext] = MRed(r.NttPsiInv[i][indexReversePrev], PsiInvMont, qi, r.MredParams[i])
}
}
- r.allowsNTT = true
+ r.AllowsNTT = true
return nil
}
// Minimal required information to recover the full ring. Used to import and export the ring.
type ringParams struct {
- N uint64
+ N int
Modulus []uint64
}
@@ -239,46 +244,6 @@ func (r *Ring) UnmarshalBinary(data []byte) error {
return nil
}
-// AllowsNTT returns true if the ring allows NTT, and false otherwise.
-func (r *Ring) AllowsNTT() bool {
- return r.allowsNTT
-}
-
-// GetBredParams returns the Barret reduction parameters of the Ring.
-func (r *Ring) GetBredParams() [][]uint64 {
- return r.BredParams
-}
-
-// GetMredParams returns the Montgomery reduction parameters of the Ring.
-func (r *Ring) GetMredParams() []uint64 {
- return r.MredParams
-}
-
-// GetPsi returns the primitive root used to compute the NTT parameters of the Ring.
-func (r *Ring) GetPsi() []uint64 {
- return r.PsiMont
-}
-
-// GetPsiInv returns the primitive root used to compute the InvNTT parameters of the Ring.
-func (r *Ring) GetPsiInv() []uint64 {
- return r.PsiInvMont
-}
-
-// GetNttPsi returns the NTT parameters of the Ring.
-func (r *Ring) GetNttPsi() [][]uint64 {
- return r.NttPsi
-}
-
-// GetNttPsiInv returns the InvNTT parameters of the Ring.
-func (r *Ring) GetNttPsiInv() [][]uint64 {
- return r.NttPsiInv
-}
-
-// GetNttNInv returns 1/N mod each modulus.
-func (r *Ring) GetNttNInv() []uint64 {
- return r.NttNInv
-}
-
// NewPoly creates a new polynomial with all coefficients set to 0.
func (r *Ring) NewPoly() *Poly {
p := new(Poly)
@@ -292,11 +257,11 @@ func (r *Ring) NewPoly() *Poly {
}
// NewPolyLvl creates a new polynomial with all coefficients set to 0.
-func (r *Ring) NewPolyLvl(level uint64) *Poly {
+func (r *Ring) NewPolyLvl(level int) *Poly {
p := new(Poly)
p.Coeffs = make([][]uint64, level+1)
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
p.Coeffs[i] = make([]uint64, r.N)
}
@@ -348,11 +313,11 @@ func (r *Ring) SetCoefficientsBigint(coeffs []*big.Int, p1 *Poly) {
}
// SetCoefficientsBigintLvl sets the coefficients of p1 from an array of Int variables.
-func (r *Ring) SetCoefficientsBigintLvl(level uint64, coeffs []*big.Int, p1 *Poly) {
+func (r *Ring) SetCoefficientsBigintLvl(level int, coeffs []*big.Int, p1 *Poly) {
QiBigint := new(big.Int)
coeffTmp := new(big.Int)
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
QiBigint.SetUint64(r.Modulus[i])
for j, coeff := range coeffs {
p1.Coeffs[i][j] = coeffTmp.Mod(coeff, QiBigint).Uint64()
@@ -377,10 +342,9 @@ func (r *Ring) PolyToString(p1 *Poly) []string {
// PolyToBigint reconstructs p1 and returns the result in an array of Int.
func (r *Ring) PolyToBigint(p1 *Poly, coeffsBigint []*big.Int) {
+ var qi uint64
- var qi, level uint64
-
- level = uint64(len(p1.Coeffs) - 1)
+ level := p1.Level()
crtReconstruction := make([]*big.Int, level+1)
@@ -388,7 +352,7 @@ func (r *Ring) PolyToBigint(p1 *Poly, coeffsBigint []*big.Int) {
tmp := new(big.Int)
modulusBigint := NewUint(1)
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi = r.Modulus[i]
QiB.SetUint64(qi)
@@ -402,12 +366,12 @@ func (r *Ring) PolyToBigint(p1 *Poly, coeffsBigint []*big.Int) {
crtReconstruction[i].Mul(crtReconstruction[i], tmp)
}
- for x := uint64(0); x < r.N; x++ {
+ for x := 0; x < r.N; x++ {
tmp.SetUint64(0)
coeffsBigint[x] = new(big.Int)
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
coeffsBigint[x].Add(coeffsBigint[x], tmp.Mul(NewUint(p1.Coeffs[i][x]), crtReconstruction[i]))
}
@@ -418,9 +382,9 @@ func (r *Ring) PolyToBigint(p1 *Poly, coeffsBigint []*big.Int) {
// PolyToBigintNoAlloc reconstructs p1 and returns the result in an pre-allocated array of Int.
func (r *Ring) PolyToBigintNoAlloc(p1 *Poly, coeffsBigint []*big.Int) {
- var qi, level uint64
+ var qi uint64
- level = uint64(len(p1.Coeffs) - 1)
+ level := p1.Level()
crtReconstruction := make([]*big.Int, level+1)
@@ -428,7 +392,7 @@ func (r *Ring) PolyToBigintNoAlloc(p1 *Poly, coeffsBigint []*big.Int) {
tmp := new(big.Int)
modulusBigint := NewUint(1)
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi = r.Modulus[i]
QiB.SetUint64(qi)
@@ -442,11 +406,11 @@ func (r *Ring) PolyToBigintNoAlloc(p1 *Poly, coeffsBigint []*big.Int) {
crtReconstruction[i].Mul(crtReconstruction[i], tmp)
}
- for x := uint64(0); x < r.N; x++ {
+ for x := 0; x < r.N; x++ {
tmp.SetUint64(0)
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
coeffsBigint[x].Add(coeffsBigint[x], tmp.Mul(NewUint(p1.Coeffs[i][x]), crtReconstruction[i]))
}
@@ -467,7 +431,7 @@ func (r *Ring) Equal(p1, p2 *Poly) bool {
r.Reduce(p2, p2)
for i := 0; i < len(r.Modulus); i++ {
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
if p1.Coeffs[i][j] != p2.Coeffs[i][j] {
return false
}
@@ -478,9 +442,9 @@ func (r *Ring) Equal(p1, p2 *Poly) bool {
}
// EqualLvl checks if p1 = p2 in the given Ring, up to a given level.
-func (r *Ring) EqualLvl(level uint64, p1, p2 *Poly) bool {
+func (r *Ring) EqualLvl(level int, p1, p2 *Poly) bool {
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
if len(p1.Coeffs[i]) != len(p2.Coeffs[i]) {
return false
}
@@ -489,8 +453,8 @@ func (r *Ring) EqualLvl(level uint64, p1, p2 *Poly) bool {
r.ReduceLvl(level, p1, p1)
r.ReduceLvl(level, p2, p2)
- for i := uint64(0); i < level+1; i++ {
- for j := uint64(0); j < r.N; j++ {
+ for i := 0; i < level+1; i++ {
+ for j := 0; j < r.N; j++ {
if p1.Coeffs[i][j] != p2.Coeffs[i][j] {
return false
}
diff --git a/ring/ring_automorphism.go b/ring/ring_automorphism.go
index d625cde6..3008109e 100644
--- a/ring/ring_automorphism.go
+++ b/ring/ring_automorphism.go
@@ -86,31 +86,31 @@ func PermuteNTT(polIn *Poly, gen uint64, polOut *Poly) {
// PermuteNTTLvl applies the Galois transform on a polynomial in the NTT domain, up to a given level.
// It maps the coefficients x^i to x^(gen*i)
// It must be noted that the result cannot be in-place.
-func PermuteNTTLvl(level uint64, polIn *Poly, gen uint64, polOut *Poly) {
+func PermuteNTTLvl(level int, polIn *Poly, gen uint64, polOut *Poly) {
- var N, tmp, mask, logN, tmp1, tmp2 uint64
+ var tmp, tmp1, tmp2 uint64
- N = uint64(len(polIn.Coeffs[0]))
+ N := len(polIn.Coeffs[0])
- logN = uint64(bits.Len64(N) - 1)
+ logN := uint64(bits.Len64(uint64(N)) - 1)
- mask = (N << 1) - 1
+ mask := uint64((N << 1) - 1)
index := make([]uint64, N)
- for i := uint64(0); i < N; i++ {
- tmp1 = 2*utils.BitReverse64(i, logN) + 1
+ for i := 0; i < N; i++ {
+ tmp1 = 2*utils.BitReverse64(uint64(i), logN) + 1
tmp2 = ((gen * tmp1 & mask) - 1) >> 1
index[i] = utils.BitReverse64(tmp2, logN)
}
- for j := uint64(0); j < N; j++ {
+ for j := 0; j < N; j++ {
tmp = index[j]
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
polOut.Coeffs[i][j] = polIn.Coeffs[i][tmp]
}
@@ -120,13 +120,13 @@ func PermuteNTTLvl(level uint64, polIn *Poly, gen uint64, polOut *Poly) {
// PermuteNTTWithIndexLvl applies the Galois transform on a polynomial in the NTT domain, up to a given level.
// It maps the coefficients x^i to x^(gen*i) using the PermuteNTTIndex table.
// It must be noted that the result cannot be in-place.
-func PermuteNTTWithIndexLvl(level uint64, polIn *Poly, index []uint64, polOut *Poly) {
+func PermuteNTTWithIndexLvl(level int, polIn *Poly, index []uint64, polOut *Poly) {
- for j := uint64(0); j < uint64(len(polIn.Coeffs[0])); j = j + 8 {
+ for j := 0; j < len(polIn.Coeffs[0]); j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&index[j]))
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
z := (*[8]uint64)(unsafe.Pointer(&polOut.Coeffs[i][j]))
y := polIn.Coeffs[i]
@@ -147,13 +147,13 @@ func PermuteNTTWithIndexLvl(level uint64, polIn *Poly, index []uint64, polOut *P
// and adds the result to the output polynomial without modular reduction.
// It maps the coefficients x^i to x^(gen*i) using the PermuteNTTIndex table.
// It must be noted that the result cannot be in-place.
-func PermuteNTTWithIndexAndAddNoModLvl(level uint64, polIn *Poly, index []uint64, polOut *Poly) {
+func PermuteNTTWithIndexAndAddNoModLvl(level int, polIn *Poly, index []uint64, polOut *Poly) {
- for j := uint64(0); j < uint64(len(polIn.Coeffs[0])); j = j + 8 {
+ for j := 0; j < len(polIn.Coeffs[0]); j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&index[j]))
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
z := (*[8]uint64)(unsafe.Pointer(&polOut.Coeffs[i][j]))
y := polIn.Coeffs[i]
@@ -177,11 +177,11 @@ func (r *Ring) Permute(polIn *Poly, gen uint64, polOut *Poly) {
var mask, index, indexRaw, logN, tmp uint64
- mask = r.N - 1
+ mask = uint64(r.N - 1)
logN = uint64(bits.Len64(mask))
- for i := uint64(0); i < r.N; i++ {
+ for i := uint64(0); i < uint64(r.N); i++ {
indexRaw = i * gen
diff --git a/ring/ring_basis_extension.go b/ring/ring_basis_extension.go
index 4041bd19..27db52ee 100644
--- a/ring/ring_basis_extension.go
+++ b/ring/ring_basis_extension.go
@@ -44,12 +44,12 @@ func genModDownParams(ringP, ringQ *Ring) (params []uint64) {
params = make([]uint64, len(ringP.Modulus))
- bredParams := ringP.GetBredParams()
+ bredParams := ringP.BredParams
tmp := new(big.Int)
for i, Qi := range ringP.Modulus {
params[i] = tmp.Mod(ringQ.ModulusBigint, NewUint(Qi)).Uint64()
- params[i] = ModExp(params[i], Qi-2, Qi)
+ params[i] = ModExp(params[i], int(Qi-2), Qi)
params[i] = MForm(params[i], Qi, bredParams[i])
}
@@ -167,15 +167,15 @@ func (basisextender *FastBasisExtender) ShallowCopy() *FastBasisExtender {
// ModUpSplitQP extends the RNS basis of a polynomial from Q to QP.
// Given a polynomial with coefficients in basis {Q0,Q1....Qlevel},
// it extends its basis from {Q0,Q1....Qlevel} to {Q0,Q1....Qlevel,P0,P1...Pj}
-func (basisextender *FastBasisExtender) ModUpSplitQP(level uint64, p1, p2 *Poly) {
- modUpExact(p1.Coeffs[:level+1], p2.Coeffs[:uint64(len(basisextender.paramsQP.P))], basisextender.paramsQP)
+func (basisextender *FastBasisExtender) ModUpSplitQP(level int, p1, p2 *Poly) {
+ modUpExact(p1.Coeffs[:level+1], p2.Coeffs[:len(basisextender.paramsQP.P)], basisextender.paramsQP)
}
// ModUpSplitPQ extends the RNS basis of a polynomial from P to PQ.
// Given a polynomial with coefficients in basis {P0,P1....Plevel},
// it extends its basis from {P0,P1....Plevel} to {Q0,Q1...Qj}
-func (basisextender *FastBasisExtender) ModUpSplitPQ(level uint64, p1, p2 *Poly) {
- modUpExact(p1.Coeffs[:level+1], p2.Coeffs[:uint64(len(basisextender.paramsPQ.P))], basisextender.paramsPQ)
+func (basisextender *FastBasisExtender) ModUpSplitPQ(level int, p1, p2 *Poly) {
+ modUpExact(p1.Coeffs[:level+1], p2.Coeffs[:len(basisextender.paramsPQ.P)], basisextender.paramsPQ)
}
// ModDownNTTPQ reduces the basis RNS of a polynomial in the NTT domain
@@ -184,7 +184,7 @@ func (basisextender *FastBasisExtender) ModUpSplitPQ(level uint64, p1, p2 *Poly)
// it reduces its basis from {Q0,Q1....Qlevel,P0,P1...Pj} to {Q0,Q1....Qlevel}
// and performs a rounded integer division of the result by P.
// Inputs must be in the NTT domain.
-func (basisextender *FastBasisExtender) ModDownNTTPQ(level uint64, p1, p2 *Poly) {
+func (basisextender *FastBasisExtender) ModDownNTTPQ(level int, p1, p2 *Poly) {
ringQ := basisextender.ringQ
ringP := basisextender.ringP
@@ -195,7 +195,7 @@ func (basisextender *FastBasisExtender) ModDownNTTPQ(level uint64, p1, p2 *Poly)
// First we get the P basis part of p1 out of the NTT domain
for j := 0; j < nPj; j++ {
- InvNTTLazy(p1.Coeffs[nQi+j], p1.Coeffs[nQi+j], ringP.N, ringP.GetNttPsiInv()[j], ringP.GetNttNInv()[j], ringP.Modulus[j], ringP.GetMredParams()[j])
+ InvNTTLazy(p1.Coeffs[nQi+j], p1.Coeffs[nQi+j], ringP.N, ringP.NttPsiInv[j], ringP.NttNInv[j], ringP.Modulus[j], ringP.MredParams[j])
}
// Then we target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool
@@ -203,7 +203,7 @@ func (basisextender *FastBasisExtender) ModDownNTTPQ(level uint64, p1, p2 *Poly)
modUpExact(p1.Coeffs[nQi:nQi+nPj], polypool.Coeffs[:level+1], basisextender.paramsPQ)
// Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
twoqi := qi << 1
@@ -213,12 +213,13 @@ func (basisextender *FastBasisExtender) ModDownNTTPQ(level uint64, p1, p2 *Poly)
params := qi - modDownParams[i]
mredParams := ringQ.MredParams[i]
bredParams := ringQ.BredParams[i]
+ nttPsi := ringQ.NttPsi[i]
// First we switch back the relevant polypool CRT array back to the NTT domain
- NTTLazy(p3tmp, p3tmp, ringQ.N, ringQ.GetNttPsi()[i], qi, mredParams, bredParams)
+ NTTLazy(p3tmp, p3tmp, ringQ.N, nttPsi, qi, mredParams, bredParams)
// Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi
- for j := uint64(0); j < ringQ.N; j = j + 8 {
+ for j := 0; j < ringQ.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j]))
@@ -243,7 +244,7 @@ func (basisextender *FastBasisExtender) ModDownNTTPQ(level uint64, p1, p2 *Poly)
// it reduces its basis from {Q0,Q1....Qi} and {P0,P1...Pj} to {Q0,Q1....Qi}
// and does a rounded integer division of the result by P.
// Inputs must be in the NTT domain.
-func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level uint64, p1Q, p1P, p2 *Poly) {
+func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level int, p1Q, p1P, p2 *Poly) {
ringQ := basisextender.ringQ
ringP := basisextender.ringP
@@ -258,7 +259,7 @@ func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level uint64, p1Q, p1P
modUpExact(p1P.Coeffs, polypool.Coeffs[:level+1], basisextender.paramsPQ)
// Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
twoqi := qi << 1
@@ -268,12 +269,13 @@ func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level uint64, p1Q, p1P
params := qi - modDownParams[i]
mredParams := ringQ.MredParams[i]
bredParams := ringQ.BredParams[i]
+ nttPsi := ringQ.NttPsi[i]
// First we switch back the relevant polypool CRT array back to the NTT domain
- NTTLazy(p3tmp, p3tmp, ringQ.N, ringQ.GetNttPsi()[i], ringQ.Modulus[i], mredParams, bredParams)
+ NTTLazy(p3tmp, p3tmp, ringQ.N, nttPsi, qi, mredParams, bredParams)
// Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi
- for j := uint64(0); j < ringQ.N; j = j + 8 {
+ for j := 0; j < ringQ.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j]))
@@ -297,19 +299,19 @@ func (basisextender *FastBasisExtender) ModDownSplitNTTPQ(level uint64, p1Q, p1P
// Given a polynomial with coefficients in basis {Q0,Q1....Qlevel,P0,P1...Pj},
// it reduces its basis from {Q0,Q1....Qlevel,P0,P1...Pj} to {Q0,Q1....Qlevel}
// and does a rounded integer division of the result by P.
-func (basisextender *FastBasisExtender) ModDownPQ(level uint64, p1, p2 *Poly) {
+func (basisextender *FastBasisExtender) ModDownPQ(level int, p1, p2 *Poly) {
ringQ := basisextender.ringQ
modDownParams := basisextender.modDownParamsPQ
polypool := basisextender.polypoolQ
- nPi := uint64(len(basisextender.paramsQP.P))
+ nPi := len(basisextender.paramsQP.P)
// We target this P basis of p1 and convert it to a Q basis (at the "level" of p1) and copy it on polypool
// polypool is now the representation of the P basis of p1 but in basis Q (at the "level" of p1)
modUpExact(p1.Coeffs[level+1:level+1+nPi], polypool.Coeffs[:level+1], basisextender.paramsPQ)
// Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
twoqi := qi << 1
@@ -320,7 +322,7 @@ func (basisextender *FastBasisExtender) ModDownPQ(level uint64, p1, p2 *Poly) {
mredParams := ringQ.MredParams[i]
// Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi
- for j := uint64(0); j < ringQ.N; j = j + 8 {
+ for j := 0; j < ringQ.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j]))
@@ -344,7 +346,7 @@ func (basisextender *FastBasisExtender) ModDownPQ(level uint64, p1, p2 *Poly) {
// Given a polynomial with coefficients in basis {Q0,Q1....Qlevel} and {P0,P1...Pj},
// it reduces its basis from {Q0,Q1....Qlevel} and {P0,P1...Pj} to {Q0,Q1....Qlevel}
// and does a rounded integer division of the result by P.
-func (basisextender *FastBasisExtender) ModDownSplitPQ(level uint64, p1Q, p1P, p2 *Poly) {
+func (basisextender *FastBasisExtender) ModDownSplitPQ(level int, p1Q, p1P, p2 *Poly) {
ringQ := basisextender.ringQ
modDownParams := basisextender.modDownParamsPQ
@@ -355,7 +357,7 @@ func (basisextender *FastBasisExtender) ModDownSplitPQ(level uint64, p1Q, p1P, p
modUpExact(p1P.Coeffs, polypool.Coeffs[:level+1], basisextender.paramsPQ)
// Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := ringQ.Modulus[i]
twoqi := qi << 1
@@ -366,7 +368,7 @@ func (basisextender *FastBasisExtender) ModDownSplitPQ(level uint64, p1Q, p1P, p
mredParams := ringQ.MredParams[i]
// Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi
- for j := uint64(0); j < ringQ.N; j = j + 8 {
+ for j := 0; j < ringQ.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j]))
@@ -390,7 +392,7 @@ func (basisextender *FastBasisExtender) ModDownSplitPQ(level uint64, p1Q, p1P, p
// Given a polynomial with coefficients in basis {Q0,Q1....QlevelQ} and {P0,P1...PlevelP},
// it reduces its basis from {Q0,Q1....QlevelQ} and {P0,P1...PlevelP} to {P0,P1...PlevelP}
// and does a floored integer division of the result by Q.
-func (basisextender *FastBasisExtender) ModDownSplitQP(levelQ, levelP uint64, p1Q, p1P, p2 *Poly) {
+func (basisextender *FastBasisExtender) ModDownSplitQP(levelQ, levelP int, p1Q, p1P, p2 *Poly) {
ringP := basisextender.ringP
modDownParams := basisextender.modDownParamsQP
@@ -401,7 +403,7 @@ func (basisextender *FastBasisExtender) ModDownSplitQP(levelQ, levelP uint64, p1
basisextender.ModUpSplitQP(levelQ, p1Q, polypool)
// Finally, for each level of p1 (and polypool since they now share the same basis) we compute p2 = (P^-1) * (p1 - polypool) mod Q
- for i := uint64(0); i < levelP+1; i++ {
+ for i := 0; i < levelP+1; i++ {
qi := ringP.Modulus[i]
twoqi := qi << 1
@@ -412,7 +414,7 @@ func (basisextender *FastBasisExtender) ModDownSplitQP(levelQ, levelP uint64, p1
mredParams := ringP.MredParams[i]
// Then for each coefficient we compute (P^-1) * (p1[i][j] - polypool[i][j]) mod qi
- for j := uint64(0); j < ringP.N; j = j + 8 {
+ for j := 0; j < ringP.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p3tmp[j]))
@@ -439,9 +441,9 @@ func modUpExact(p1, p2 [][]uint64, params *modupParams) {
var y0, y1, y2, y3, y4, y5, y6, y7 [32]uint64
// We loop over each coefficient and apply the basis extension
- for x := uint64(0); x < uint64(len(p1[0])); x = x + 8 {
+ for x := 0; x < len(p1[0]); x = x + 8 {
- reconstructRNS(uint64(len(p1)), x, p1, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, params.Q, params.mredParamsQ, params.qibMont)
+ reconstructRNS(len(p1), x, p1, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, params.Q, params.mredParamsQ, params.qibMont)
for j := 0; j < len(p2); j++ {
@@ -452,7 +454,7 @@ func modUpExact(p1, p2 [][]uint64, params *modupParams) {
res := (*[8]uint64)(unsafe.Pointer(&p2[j][x]))
- multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, uint64(len(p1)), pj, qInv, qpjInv, qispjMont)
+ multSum(res, &v, &y0, &y1, &y2, &y3, &y4, &y5, &y6, &y7, len(p1), pj, qInv, qpjInv, qispjMont)
}
}
}
@@ -461,18 +463,18 @@ func modUpExact(p1, p2 [][]uint64, params *modupParams) {
// This decomposer takes a p(x)_Q (in basis Q) and returns p(x) mod qi in basis QP, where
// qi = prod(Q_i) for 0<=i<=L, where L is the number of factors in P.
type Decomposer struct {
- nQprimes uint64
- nPprimes uint64
- alpha uint64
- beta uint64
- xalpha []uint64
+ nQprimes int
+ nPprimes int
+ alpha int
+ beta int
+ xalpha []int
modUpParams [][]*modupParams
QInt *big.Int
PInt *big.Int
}
// Xalpha returns a slice that contains all the values of #Qi/#Pi.
-func (decomposer *Decomposer) Xalpha() (xalpha []uint64) {
+func (decomposer *Decomposer) Xalpha() (xalpha []int) {
return decomposer.xalpha
}
@@ -480,8 +482,8 @@ func (decomposer *Decomposer) Xalpha() (xalpha []uint64) {
func NewDecomposer(Q, P []uint64) (decomposer *Decomposer) {
decomposer = new(Decomposer)
- decomposer.nQprimes = uint64(len(Q))
- decomposer.nPprimes = uint64(len(P))
+ decomposer.nQprimes = len(Q)
+ decomposer.nPprimes = len(P)
decomposer.QInt = NewUint(1)
for i := range Q {
@@ -493,31 +495,31 @@ func NewDecomposer(Q, P []uint64) (decomposer *Decomposer) {
decomposer.PInt.Mul(decomposer.PInt, NewUint(P[i]))
}
- decomposer.alpha = uint64(len(P))
- decomposer.beta = uint64(math.Ceil(float64(len(Q)) / float64(decomposer.alpha)))
+ decomposer.alpha = len(P)
+ decomposer.beta = int(math.Ceil(float64(len(Q)) / float64(decomposer.alpha)))
- decomposer.xalpha = make([]uint64, decomposer.beta)
+ decomposer.xalpha = make([]int, decomposer.beta)
for i := range decomposer.xalpha {
decomposer.xalpha[i] = decomposer.alpha
}
- if uint64(len(Q))%decomposer.alpha != 0 {
- decomposer.xalpha[decomposer.beta-1] = uint64(len(Q)) % decomposer.alpha
+ if len(Q)%decomposer.alpha != 0 {
+ decomposer.xalpha[decomposer.beta-1] = len(Q) % decomposer.alpha
}
decomposer.modUpParams = make([][]*modupParams, decomposer.beta)
// Create a basis extension for each possible combination of [Qi,Pj] according to xalpha
- for i := uint64(0); i < decomposer.beta; i++ {
+ for i := 0; i < decomposer.beta; i++ {
decomposer.modUpParams[i] = make([]*modupParams, decomposer.xalpha[i]-1)
- for j := uint64(0); j < decomposer.xalpha[i]-1; j++ {
+ for j := 0; j < decomposer.xalpha[i]-1; j++ {
Qi := make([]uint64, j+2)
Pi := make([]uint64, len(Q)+len(P))
- for k := uint64(0); k < j+2; k++ {
+ for k := 0; k < j+2; k++ {
Qi[k] = Q[i*decomposer.alpha+k]
}
@@ -538,7 +540,7 @@ func NewDecomposer(Q, P []uint64) (decomposer *Decomposer) {
// DecomposeAndSplit decomposes a polynomial p(x) in basis Q, reduces it modulo qi, and returns
// the result in basis QP separately.
-func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0, p1Q, p1P *Poly) {
+func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel int, p0, p1Q, p1P *Poly) {
alphai := decomposer.xalpha[crtDecompLevel]
@@ -548,18 +550,18 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0
// First we check if the vector can simply by coping and rearranging elements (the case where no reconstruction is needed)
if (p0idxed > level+1 && (level+1)%decomposer.nPprimes == 1) || alphai == 1 {
- for j := uint64(0); j < level+1; j++ {
+ for j := 0; j < level+1; j++ {
copy(p1Q.Coeffs[j], p0.Coeffs[p0idxst])
}
- for j := uint64(0); j < decomposer.nPprimes; j++ {
+ for j := 0; j < decomposer.nPprimes; j++ {
copy(p1P.Coeffs[j], p0.Coeffs[p0idxst])
}
// Otherwise, we apply a fast exact base conversion for the reconstruction
} else {
- var index uint64
+ var index int
if level >= alphai+crtDecompLevel*decomposer.alpha {
index = decomposer.xalpha[crtDecompLevel] - 2
} else {
@@ -575,12 +577,12 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0
var qif float64
// We loop over each coefficient and apply the basis extension
- for x := uint64(0); x < uint64(len(p0.Coeffs[0])); x = x + 8 {
+ for x := 0; x < len(p0.Coeffs[0]); x = x + 8 {
vi[0], vi[1], vi[2], vi[3], vi[4], vi[5], vi[6], vi[7] = 0, 0, 0, 0, 0, 0, 0, 0
// Coefficients to be decomposed
- for i, j := uint64(0), p0idxst; i < index+2; i, j = i+1, j+1 {
+ for i, j := 0, p0idxst; i < index+2; i, j = i+1, j+1 {
qibMont = params.qibMont[i]
qi = params.Q[i]
@@ -624,7 +626,7 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0
v[7] = uint64(vi[7])
// Coefficients of index smaller than the ones to be decomposed
- for j := uint64(0); j < p0idxst; j++ {
+ for j := 0; j < p0idxst; j++ {
pj = params.P[j]
qInv := params.mredParamsP[j]
@@ -650,7 +652,7 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0
}
// Coefficients of the special primes Pi
- for j, u := uint64(0), decomposer.nQprimes; j < decomposer.nPprimes; j, u = j+1, u+1 {
+ for j, u := 0, decomposer.nQprimes; j < decomposer.nPprimes; j, u = j+1, u+1 {
pj = params.P[u]
qInv := params.mredParamsP[u]
@@ -665,27 +667,28 @@ func (decomposer *Decomposer) DecomposeAndSplit(level, crtDecompLevel uint64, p0
}
}
-func reconstructRNS(index, x uint64, p [][]uint64, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, Q, QInv, QbMont []uint64) {
+func reconstructRNS(index, x int, p [][]uint64, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, Q, QInv, QbMont []uint64) {
var vi [8]float64
var qi, qiInv, qibMont uint64
var qif float64
- for i := uint64(0); i < index; i++ {
+ for i := 0; i < index; i++ {
qibMont = QbMont[i]
qi = Q[i]
qiInv = QInv[i]
qif = float64(qi)
+ pTmp := (*[8]uint64)(unsafe.Pointer(&p[i][x]))
- y0[i] = MRed(p[i][x+0], qibMont, qi, qiInv)
- y1[i] = MRed(p[i][x+1], qibMont, qi, qiInv)
- y2[i] = MRed(p[i][x+2], qibMont, qi, qiInv)
- y3[i] = MRed(p[i][x+3], qibMont, qi, qiInv)
- y4[i] = MRed(p[i][x+4], qibMont, qi, qiInv)
- y5[i] = MRed(p[i][x+5], qibMont, qi, qiInv)
- y6[i] = MRed(p[i][x+6], qibMont, qi, qiInv)
- y7[i] = MRed(p[i][x+7], qibMont, qi, qiInv)
+ y0[i] = MRed(pTmp[0], qibMont, qi, qiInv)
+ y1[i] = MRed(pTmp[1], qibMont, qi, qiInv)
+ y2[i] = MRed(pTmp[2], qibMont, qi, qiInv)
+ y3[i] = MRed(pTmp[3], qibMont, qi, qiInv)
+ y4[i] = MRed(pTmp[4], qibMont, qi, qiInv)
+ y5[i] = MRed(pTmp[5], qibMont, qi, qiInv)
+ y6[i] = MRed(pTmp[6], qibMont, qi, qiInv)
+ y7[i] = MRed(pTmp[7], qibMont, qi, qiInv)
// Computation of the correction term v * Q%pi
vi[0] += float64(y0[i]) / qif
@@ -709,13 +712,13 @@ func reconstructRNS(index, x uint64, p [][]uint64, v *[8]uint64, y0, y1, y2, y3,
}
// Caution, returns the values in [0, 2q-1]
-func multSum(res, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, index, pj, qInv uint64, qpjInv, qispjMont []uint64) {
+func multSum(res, v *[8]uint64, y0, y1, y2, y3, y4, y5, y6, y7 *[32]uint64, index int, pj, qInv uint64, qpjInv, qispjMont []uint64) {
var rlo, rhi [8]uint64
var mhi, mlo, c, hhi uint64
// Accumulates the sum on uint128 and does a lazy montgomery reduction at the end
- for i := uint64(0); i < index; i++ {
+ for i := 0; i < index; i++ {
mhi, mlo = bits.Mul64(y0[i], qispjMont[i])
rlo[0], c = bits.Add64(rlo[0], mlo, 0)
diff --git a/ring/ring_benchmark_test.go b/ring/ring_benchmark_test.go
index 9f589667..c3f711b0 100644
--- a/ring/ring_benchmark_test.go
+++ b/ring/ring_benchmark_test.go
@@ -82,7 +82,7 @@ func benchSampling(testContext *testParams, b *testing.B) {
gaussianSampler := NewGaussianSampler(testContext.prng, testContext.ringQ, DefaultSigma, DefaultBound)
for i := 0; i < b.N; i++ {
- gaussianSampler.ReadLvl(uint64(len(testContext.ringQ.Modulus)-1), pol)
+ gaussianSampler.ReadLvl(len(testContext.ringQ.Modulus)-1, pol)
}
})
@@ -279,7 +279,7 @@ func benchExtendBasis(testContext *testParams, b *testing.B) {
p0 := testContext.uniformSamplerQ.ReadNew()
p1 := testContext.uniformSamplerP.ReadNew()
- level := uint64(len(testContext.ringQ.Modulus) - 1)
+ level := len(testContext.ringQ.Modulus) - 1
b.Run(fmt.Sprintf("ExtendBasis/ModUp/N=%d/limbsQ=%d/limbsP=%d", testContext.ringQ.N, len(testContext.ringQ.Modulus), len(testContext.ringP.Modulus)), func(b *testing.B) {
for i := 0; i < b.N; i++ {
@@ -302,53 +302,30 @@ func benchExtendBasis(testContext *testParams, b *testing.B) {
func benchDivByLastModulus(testContext *testParams, b *testing.B) {
- var p0 *Poly
+ p0 := testContext.uniformSamplerQ.ReadNew()
+ p1 := testContext.ringQ.NewPolyLvl(p0.Level() - 1)
b.Run(testString("DivByLastModulus/Floor/", testContext.ringQ), func(b *testing.B) {
-
for i := 0; i < b.N; i++ {
-
- b.StopTimer()
- p0 = testContext.uniformSamplerQ.ReadNew()
- b.StartTimer()
-
- testContext.ringQ.DivFloorByLastModulus(p0)
+ testContext.ringQ.DivFloorByLastModulus(p0, p1)
}
})
b.Run(testString("DivByLastModulus/FloorNTT/", testContext.ringQ), func(b *testing.B) {
-
for i := 0; i < b.N; i++ {
-
- b.StopTimer()
- p0 = testContext.uniformSamplerQ.ReadNew()
- b.StartTimer()
-
- testContext.ringQ.DivFloorByLastModulusNTT(p0)
+ testContext.ringQ.DivFloorByLastModulusNTT(p0, p1)
}
})
b.Run(testString("DivByLastModulus/Round/", testContext.ringQ), func(b *testing.B) {
-
for i := 0; i < b.N; i++ {
-
- b.StopTimer()
- p0 = testContext.uniformSamplerQ.ReadNew()
- b.StartTimer()
-
- testContext.ringQ.DivRoundByLastModulus(p0)
+ testContext.ringQ.DivRoundByLastModulus(p0, p1)
}
})
b.Run(testString("DivByLastModulus/RoundNTT/", testContext.ringQ), func(b *testing.B) {
-
for i := 0; i < b.N; i++ {
-
- b.StopTimer()
- p0 = testContext.uniformSamplerQ.ReadNew()
- b.StartTimer()
-
- testContext.ringQ.DivRoundByLastModulusNTT(p0)
+ testContext.ringQ.DivRoundByLastModulusNTT(p0, p1)
}
})
}
@@ -360,7 +337,7 @@ func benchDivByRNSBasis(testContext *testParams, b *testing.B) {
rescaler := NewSimpleScaler(T, testContext.ringQ)
coeffs := make([]*big.Int, testContext.ringQ.N)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
coeffs[i] = RandInt(testContext.ringQ.ModulusBigint)
}
@@ -379,7 +356,7 @@ func benchDivByRNSBasis(testContext *testParams, b *testing.B) {
rescaler := NewSimpleScaler(T, testContext.ringQ)
coeffs := make([]*big.Int, testContext.ringQ.N)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
coeffs[i] = RandInt(testContext.ringQ.ModulusBigint)
}
@@ -396,7 +373,7 @@ func benchDivByRNSBasis(testContext *testParams, b *testing.B) {
b.Run(testString("DivByRNSBasis/RNS/DivByQOverTRounded/", testContext.ringQ), func(b *testing.B) {
coeffs := make([]*big.Int, testContext.ringQ.N)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
coeffs[i] = RandInt(testContext.ringQ.ModulusBigint)
}
diff --git a/ring/ring_ntt.go b/ring/ring_ntt.go
index 9660dd12..b2202abd 100644
--- a/ring/ring_ntt.go
+++ b/ring/ring_ntt.go
@@ -14,8 +14,8 @@ func (r *Ring) NTT(p1, p2 *Poly) {
// NTTLvl computes the NTT of p1 and returns the result on p2.
// The value level defines the number of moduli of the input polynomials.
-func (r *Ring) NTTLvl(level uint64, p1, p2 *Poly) {
- for x := uint64(0); x < level+1; x++ {
+func (r *Ring) NTTLvl(level int, p1, p2 *Poly) {
+ for x := 0; x < level+1; x++ {
NTT(p1.Coeffs[x], p2.Coeffs[x], r.N, r.NttPsi[x], r.Modulus[x], r.MredParams[x], r.BredParams[x])
}
}
@@ -29,8 +29,8 @@ func (r *Ring) InvNTT(p1, p2 *Poly) {
// InvNTTLvl computes the inverse-NTT of p1 and returns the result on p2.
// The value level defines the number of moduli of the input polynomials.
-func (r *Ring) InvNTTLvl(level uint64, p1, p2 *Poly) {
- for x := uint64(0); x < level+1; x++ {
+func (r *Ring) InvNTTLvl(level int, p1, p2 *Poly) {
+ for x := 0; x < level+1; x++ {
InvNTT(p1.Coeffs[x], p2.Coeffs[x], r.N, r.NttPsiInv[x], r.NttNInv[x], r.Modulus[x], r.MredParams[x])
}
}
@@ -46,8 +46,8 @@ func (r *Ring) NTTLazy(p1, p2 *Poly) {
// NTTLazyLvl computes the NTT of p1 and returns the result on p2.
// The value level defines the number of moduli of the input polynomials.
// Output values are in the range [0, 2q-1]
-func (r *Ring) NTTLazyLvl(level uint64, p1, p2 *Poly) {
- for x := uint64(0); x < level+1; x++ {
+func (r *Ring) NTTLazyLvl(level int, p1, p2 *Poly) {
+ for x := 0; x < level+1; x++ {
NTTLazy(p1.Coeffs[x], p2.Coeffs[x], r.N, r.NttPsi[x], r.Modulus[x], r.MredParams[x], r.BredParams[x])
}
}
@@ -63,8 +63,8 @@ func (r *Ring) InvNTTLazy(p1, p2 *Poly) {
// InvNTTLazyLvl computes the inverse-NTT of p1 and returns the result on p2.
// The value level defines the number of moduli of the input polynomials.
// Output values are in the range [0, 2q-1]
-func (r *Ring) InvNTTLazyLvl(level uint64, p1, p2 *Poly) {
- for x := uint64(0); x < level+1; x++ {
+func (r *Ring) InvNTTLazyLvl(level int, p1, p2 *Poly) {
+ for x := 0; x < level+1; x++ {
InvNTTLazy(p1.Coeffs[x], p2.Coeffs[x], r.N, r.NttPsiInv[x], r.NttNInv[x], r.Modulus[x], r.MredParams[x])
}
}
@@ -79,11 +79,11 @@ func butterfly(U, V, Psi, twoQ, fourQ, Q, Qinv uint64) (uint64, uint64) {
}
// NTT computes the NTT on the input coefficients using the input parameters.
-func NTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, mredParams uint64, bredParams []uint64) {
+func NTT(coeffsIn, coeffsOut []uint64, N int, nttPsi []uint64, Q, mredParams uint64, bredParams []uint64) {
NTTLazy(coeffsIn, coeffsOut, N, nttPsi, Q, mredParams, bredParams)
// Finish with an exact reduction
- for i := uint64(0); i < N; i = i + 8 {
+ for i := 0; i < N; i = i + 8 {
x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[i]))
@@ -99,9 +99,9 @@ func NTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, mredParams
}
// NTTLazy computes the NTT on the input coefficients using the input parameters with output values in the range [0, 2q-1].
-func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv uint64, bredParams []uint64) {
- var j1, j2, t uint64
- var F uint64
+func NTTLazy(coeffsIn, coeffsOut []uint64, N int, nttPsi []uint64, Q, QInv uint64, bredParams []uint64) {
+ var j1, j2, t int
+ var F, V uint64
fourQ := 4 * Q
twoQ := 2 * Q
@@ -109,9 +109,8 @@ func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv ui
// Copy the result of the first round of butterflies on p2 with approximate reduction
t = N >> 1
F = nttPsi[1]
- var V uint64
- for j := uint64(0); j <= t-1; j = j + 8 {
+ for j := 0; j <= t-1; j = j + 8 {
xin := (*[8]uint64)(unsafe.Pointer(&coeffsIn[j]))
yin := (*[8]uint64)(unsafe.Pointer(&coeffsIn[j+t]))
@@ -147,15 +146,15 @@ func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv ui
// Continue the rest of the second to the n-1 butterflies on p2 with approximate reduction
var reduce bool
- for m := uint64(2); m < N; m <<= 1 {
+ for m := 2; m < N; m <<= 1 {
- reduce = (bits.Len64(m)&1 == 1)
+ reduce = (bits.Len64(uint64(m))&1 == 1)
t >>= 1
if t >= 8 {
- for i := uint64(0); i < m; i++ {
+ for i := 0; i < m; i++ {
j1 = (i * t) << 1
@@ -218,7 +217,7 @@ func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv ui
if reduce {
- for i := uint64(0); i < m; i = i + 2 {
+ for i := 0; i < m; i = i + 2 {
j1 = (i * t) << 1
@@ -237,7 +236,7 @@ func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv ui
}
} else {
- for i := uint64(0); i < m; i = i + 2 {
+ for i := 0; i < m; i = i + 2 {
j1 = (i * t) << 1
@@ -276,7 +275,7 @@ func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv ui
if reduce {
- for i := uint64(0); i < m; i = i + 4 {
+ for i := 0; i < m; i = i + 4 {
j1 = (i * t) << 1
@@ -294,7 +293,7 @@ func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv ui
}
} else {
- for i := uint64(0); i < m; i = i + 4 {
+ for i := 0; i < m; i = i + 4 {
j1 = (i * t) << 1
@@ -329,7 +328,7 @@ func NTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q, QInv ui
} else {
- for i := uint64(0); i < m; i = i + 8 {
+ for i := 0; i < m; i = i + 8 {
psi := (*[8]uint64)(unsafe.Pointer(&nttPsi[m+i]))
x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[2*i]))
@@ -390,9 +389,9 @@ func invbutterfly(U, V, Psi, twoQ, fourQ, Q, Qinv uint64) (X, Y uint64) {
}
// InvNTT computes the InvNTT transformation on the input coefficients using the input parameters.
-func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, Q, QInv uint64) {
+func InvNTT(coeffsIn, coeffsOut []uint64, N int, nttPsiInv []uint64, nttNInv, Q, QInv uint64) {
- var j1, j2, h, t uint64
+ var j1, j2, h, t int
var F uint64
// Copy the result of the first round of butterflies on p2 with approximate reduction
@@ -401,7 +400,7 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv,
twoQ := Q << 1
fourQ := Q << 2
- for i := uint64(0); i < h; i = i + 8 {
+ for i := 0; i < h; i = i + 8 {
psi := (*[8]uint64)(unsafe.Pointer(&nttPsiInv[h+i]))
xin := (*[16]uint64)(unsafe.Pointer(&coeffsIn[2*i]))
@@ -426,7 +425,7 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv,
if t >= 8 {
- for i := uint64(0); i < h; i++ {
+ for i := 0; i < h; i++ {
j2 = j1 + t - 1
@@ -452,7 +451,7 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv,
} else if t == 4 {
- for i := uint64(0); i < h; i = i + 2 {
+ for i := 0; i < h; i = i + 2 {
psi := (*[2]uint64)(unsafe.Pointer(&nttPsiInv[h+i]))
x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1]))
@@ -471,7 +470,7 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv,
} else {
- for i := uint64(0); i < h; i = i + 4 {
+ for i := 0; i < h; i = i + 4 {
psi := (*[4]uint64)(unsafe.Pointer(&nttPsiInv[h+i]))
x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1]))
@@ -493,7 +492,7 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv,
}
// Finish with an exact reduction
- for i := uint64(0); i < N; i = i + 8 {
+ for i := 0; i < N; i = i + 8 {
x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[i]))
@@ -509,9 +508,9 @@ func InvNTT(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv,
}
// InvNTTLazy computes the InvNTT transformation on the input coefficients using the input parameters with output values in the range [0, 2q-1].
-func InvNTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, Q, mredParams uint64) {
+func InvNTTLazy(coeffsIn, coeffsOut []uint64, N int, nttPsiInv []uint64, nttNInv, Q, mredParams uint64) {
- var j1, j2, h, t uint64
+ var j1, j2, h, t int
var F uint64
// Copy the result of the first round of butterflies on p2 with approximate reduction
@@ -521,7 +520,7 @@ func InvNTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttN
twoQ := Q << 1
fourQ := Q << 2
- for i := uint64(0); i < h; i = i + 8 {
+ for i := 0; i < h; i = i + 8 {
psi := (*[8]uint64)(unsafe.Pointer(&nttPsiInv[h+i]))
xin := (*[16]uint64)(unsafe.Pointer(&coeffsIn[2*i]))
@@ -546,7 +545,7 @@ func InvNTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttN
if t >= 8 {
- for i := uint64(0); i < h; i++ {
+ for i := 0; i < h; i++ {
j2 = j1 + t - 1
@@ -572,7 +571,7 @@ func InvNTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttN
} else if t == 4 {
- for i := uint64(0); i < h; i = i + 2 {
+ for i := 0; i < h; i = i + 2 {
psi := (*[2]uint64)(unsafe.Pointer(&nttPsiInv[h+i]))
x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1]))
@@ -591,7 +590,7 @@ func InvNTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttN
} else {
- for i := uint64(0); i < h; i = i + 4 {
+ for i := 0; i < h; i = i + 4 {
psi := (*[4]uint64)(unsafe.Pointer(&nttPsiInv[h+i]))
x := (*[16]uint64)(unsafe.Pointer(&coeffsOut[j1]))
@@ -613,7 +612,7 @@ func InvNTTLazy(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttN
}
// Finish with an exact reduction
- for i := uint64(0); i < N; i = i + 8 {
+ for i := 0; i < N; i = i + 8 {
x := (*[8]uint64)(unsafe.Pointer(&coeffsOut[i]))
@@ -673,20 +672,20 @@ func invbutterflyBarrett(U, V, Psi, Q uint64, bredParams []uint64) (X, Y uint64)
// NTTBarrett computes the NTT using Barrett reduction.
// For benchmark purposes only.
-func NTTBarrett(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q uint64, bredParams []uint64) {
- var j1, j2, t uint64
+func NTTBarrett(coeffsIn, coeffsOut []uint64, N int, nttPsi []uint64, Q uint64, bredParams []uint64) {
+ var j1, j2, t int
var F uint64
t = N >> 1
j2 = t - 1
F = nttPsi[1]
- for j := uint64(0); j <= j2; j++ {
+ for j := 0; j <= j2; j++ {
coeffsOut[j], coeffsOut[j+t] = butterflyBarrett(coeffsIn[j], coeffsIn[j+t], F, Q, bredParams)
}
- for m := uint64(2); m < N; m <<= 1 {
+ for m := 2; m < N; m <<= 1 {
t >>= 1
- for i := uint64(0); i < m; i++ {
+ for i := 0; i < m; i++ {
j1 = (i * t) << 1
@@ -700,23 +699,23 @@ func NTTBarrett(coeffsIn, coeffsOut []uint64, N uint64, nttPsi []uint64, Q uint6
}
}
- for i := uint64(0); i < N; i++ {
+ for i := 0; i < N; i++ {
coeffsOut[i] = BRedAdd(coeffsOut[i], Q, bredParams)
}
}
// InvNTTBarrett computes the Inverse NTT using Barrett reduction.
// For benchmark purposes only.
-func InvNTTBarrett(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, nttNInv, Q uint64, bredParams []uint64) {
+func InvNTTBarrett(coeffsIn, coeffsOut []uint64, N int, nttPsiInv []uint64, nttNInv, Q uint64, bredParams []uint64) {
- var j1, j2, h, t uint64
+ var j1, j2, h, t int
var F uint64
t = 1
j1 = 0
h = N >> 1
- for i := uint64(0); i < h; i++ {
+ for i := 0; i < h; i++ {
j2 = j1
@@ -735,7 +734,7 @@ func InvNTTBarrett(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, n
j1 = 0
h = m >> 1
- for i := uint64(0); i < h; i++ {
+ for i := 0; i < h; i++ {
j2 = j1 + t - 1
@@ -751,7 +750,7 @@ func InvNTTBarrett(coeffsIn, coeffsOut []uint64, N uint64, nttPsiInv []uint64, n
t <<= 1
}
- for j := uint64(0); j < N; j++ {
+ for j := 0; j < N; j++ {
coeffsOut[j] = BRed(coeffsOut[j], nttNInv, Q, bredParams)
}
}
diff --git a/ring/ring_ntt_test.go b/ring/ring_ntt_test.go
index f3f0f495..c9270664 100644
--- a/ring/ring_ntt_test.go
+++ b/ring/ring_ntt_test.go
@@ -8,7 +8,7 @@ import (
)
var testVector = []struct {
- N uint64
+ N int
Qis []uint64
poly *Poly
diff --git a/ring/ring_operations.go b/ring/ring_operations.go
index f3fed573..985d2721 100644
--- a/ring/ring_operations.go
+++ b/ring/ring_operations.go
@@ -12,7 +12,7 @@ import (
func (r *Ring) Add(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -32,11 +32,11 @@ func (r *Ring) Add(p1, p2, p3 *Poly) {
// AddLvl adds p1 to p2 coefficient-wise for the moduli from
// q_0 up to q_level and writes the result on p3.
-func (r *Ring) AddLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) AddLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -59,7 +59,7 @@ func (r *Ring) AddLvl(level uint64, p1, p2, p3 *Poly) {
func (r *Ring) AddNoMod(p1, p2, p3 *Poly) {
for i := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -79,10 +79,10 @@ func (r *Ring) AddNoMod(p1, p2, p3 *Poly) {
// AddNoModLvl adds p1 to p2 coefficient-wise without modular reduction
// for the moduli from q_0 up to q_level and writes the result on p3.
-func (r *Ring) AddNoModLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) AddNoModLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -104,7 +104,7 @@ func (r *Ring) AddNoModLvl(level uint64, p1, p2, p3 *Poly) {
func (r *Ring) Sub(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -123,11 +123,11 @@ func (r *Ring) Sub(p1, p2, p3 *Poly) {
}
// SubLvl subtracts p2 to p1 coefficient-wise and writes the result on p3.
-func (r *Ring) SubLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) SubLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -150,7 +150,7 @@ func (r *Ring) SubLvl(level uint64, p1, p2, p3 *Poly) {
func (r *Ring) SubNoMod(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -170,11 +170,11 @@ func (r *Ring) SubNoMod(p1, p2, p3 *Poly) {
// SubNoModLvl subtracts p2 to p1 coefficient-wise without modular reduction
// for the moduli from q_0 up to q_level and writes the result on p3.
-func (r *Ring) SubNoModLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) SubNoModLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -196,7 +196,7 @@ func (r *Ring) SubNoModLvl(level uint64, p1, p2, p3 *Poly) {
func (r *Ring) Neg(p1, p2 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -215,11 +215,11 @@ func (r *Ring) Neg(p1, p2 *Poly) {
// NegLvl sets the coefficients of p1 to their additive inverse for
// the moduli from q_0 up to q_level and writes the result on p2.
-func (r *Ring) NegLvl(level uint64, p1, p2 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) NegLvl(level int, p1, p2 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -241,7 +241,7 @@ func (r *Ring) Reduce(p1, p2 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -264,7 +264,7 @@ func (r *Ring) ReduceConstant(p1, p2 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -283,12 +283,12 @@ func (r *Ring) ReduceConstant(p1, p2 *Poly) {
// ReduceLvl applies a modular reduction on the coefficients of p1
// for the moduli from q_0 up to q_level and writes the result on p2.
-func (r *Ring) ReduceLvl(level uint64, p1, p2 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) ReduceLvl(level int, p1, p2 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -308,12 +308,12 @@ func (r *Ring) ReduceLvl(level uint64, p1, p2 *Poly) {
// ReduceConstantLvl applies a modular reduction on the coefficients of p1
// for the moduli from q_0 up to q_level and writes the result on p2.
// Return values in [0, 2q-1]
-func (r *Ring) ReduceConstantLvl(level uint64, p1, p2 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) ReduceConstantLvl(level int, p1, p2 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -335,7 +335,7 @@ func (r *Ring) Mod(p1 *Poly, m uint64, p2 *Poly) {
bredParams := BRedParams(m)
for i := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -356,7 +356,7 @@ func (r *Ring) Mod(p1 *Poly, m uint64, p2 *Poly) {
func (r *Ring) AND(p1 *Poly, m uint64, p2 *Poly) {
for i := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
p2tmp[j] = p1tmp[j] & m
}
}
@@ -366,7 +366,7 @@ func (r *Ring) AND(p1 *Poly, m uint64, p2 *Poly) {
func (r *Ring) OR(p1 *Poly, m uint64, p2 *Poly) {
for i := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
p2tmp[j] = p1tmp[j] | m
}
}
@@ -376,7 +376,7 @@ func (r *Ring) OR(p1 *Poly, m uint64, p2 *Poly) {
func (r *Ring) XOR(p1 *Poly, m uint64, p2 *Poly) {
for i := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
p2tmp[j] = p1tmp[j] ^ m
}
}
@@ -388,7 +388,7 @@ func (r *Ring) MulCoeffs(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -412,7 +412,7 @@ func (r *Ring) MulCoeffsAndAdd(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -436,7 +436,7 @@ func (r *Ring) MulCoeffsAndAddNoMod(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -461,7 +461,7 @@ func (r *Ring) MulCoeffsMontgomery(p1, p2, p3 *Poly) {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -481,12 +481,12 @@ func (r *Ring) MulCoeffsMontgomery(p1, p2, p3 *Poly) {
// MulCoeffsMontgomeryLvl multiplies p1 by p2 coefficient-wise with a Montgomery
// modular reduction for the moduli from q_0 up to q_level and returns the result on p3.
-func (r *Ring) MulCoeffsMontgomeryLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) MulCoeffsMontgomeryLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -506,12 +506,12 @@ func (r *Ring) MulCoeffsMontgomeryLvl(level uint64, p1, p2, p3 *Poly) {
// MulCoeffsMontgomeryConstantLvl multiplies p1 by p2 coefficient-wise with a Montgomery
// modular reduction for the moduli from q_0 up to q_level and returns the result on p3.
-func (r *Ring) MulCoeffsMontgomeryConstantLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) MulCoeffsMontgomeryConstantLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -535,7 +535,7 @@ func (r *Ring) MulCoeffsMontgomeryAndAdd(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -555,12 +555,12 @@ func (r *Ring) MulCoeffsMontgomeryAndAdd(p1, p2, p3 *Poly) {
// MulCoeffsMontgomeryAndAddLvl multiplies p1 by p2 coefficient-wise with a Montgomery
// modular reduction for the moduli from q_0 up to q_level and adds the result to p3.
-func (r *Ring) MulCoeffsMontgomeryAndAddLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) MulCoeffsMontgomeryAndAddLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -584,7 +584,7 @@ func (r *Ring) MulCoeffsMontgomeryAndAddNoMod(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -609,7 +609,7 @@ func (r *Ring) MulCoeffsMontgomeryConstantAndAddNoMod(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -629,12 +629,12 @@ func (r *Ring) MulCoeffsMontgomeryConstantAndAddNoMod(p1, p2, p3 *Poly) {
// MulCoeffsMontgomeryAndAddNoModLvl multiplies p1 by p2 coefficient-wise with a Montgomery modular
// reduction for the moduli from q_0 up to q_level and adds the result to p3 without modular reduction.
-func (r *Ring) MulCoeffsMontgomeryAndAddNoModLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) MulCoeffsMontgomeryAndAddNoModLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -655,12 +655,12 @@ func (r *Ring) MulCoeffsMontgomeryAndAddNoModLvl(level uint64, p1, p2, p3 *Poly)
// MulCoeffsMontgomeryConstantAndAddNoModLvl multiplies p1 by p2 coefficient-wise with a constant-time Montgomery
// modular reduction for the moduli from q_0 up to q_level and adds the result to p3 without modular reduction.
// Return values in [0, 3q-1]
-func (r *Ring) MulCoeffsMontgomeryConstantAndAddNoModLvl(level uint64, p1, p2, p3 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) MulCoeffsMontgomeryConstantAndAddNoModLvl(level int, p1, p2, p3 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -684,7 +684,7 @@ func (r *Ring) MulCoeffsMontgomeryAndSub(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -708,7 +708,7 @@ func (r *Ring) MulCoeffsMontgomeryAndSubNoMod(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -732,7 +732,7 @@ func (r *Ring) MulCoeffsConstant(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -756,7 +756,7 @@ func (r *Ring) MulCoeffsMontgomeryConstant(p1, p2, p3 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp, p3tmp := p1.Coeffs[i], p2.Coeffs[i], p3.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -814,13 +814,13 @@ func (r *Ring) MulPolyNaive(p1, p2, p3 *Poly) {
mredParams := r.MredParams[x]
- for i := uint64(0); i < r.N; i++ {
+ for i := 0; i < r.N; i++ {
- for j := uint64(0); j < i; j++ {
+ for j := 0; j < i; j++ {
p3tmp[j] = CRed(p3tmp[j]+(qi-MRed(p1tmp[i], p2tmp[r.N-i+j], qi, mredParams)), qi)
}
- for j := uint64(i); j < r.N; j++ {
+ for j := i; j < r.N; j++ {
p3tmp[j] = CRed(p3tmp[j]+MRed(p1tmp[i], p2tmp[j-i], qi, mredParams), qi)
}
}
@@ -841,13 +841,13 @@ func (r *Ring) MulPolyNaiveMontgomery(p1, p2, p3 *Poly) {
mredParams := r.MredParams[x]
- for i := uint64(0); i < r.N; i++ {
+ for i := 0; i < r.N; i++ {
- for j := uint64(0); j < i; j++ {
+ for j := 0; j < i; j++ {
p3tmp[j] = CRed(p3tmp[j]+(qi-MRed(p1tmp[i], p2tmp[r.N-i+j], qi, mredParams)), qi)
}
- for j := uint64(i); j < r.N; j++ {
+ for j := i; j < r.N; j++ {
p3tmp[j] = CRed(p3tmp[j]+MRed(p1tmp[i], p2tmp[j-i], qi, mredParams), qi)
}
}
@@ -856,7 +856,7 @@ func (r *Ring) MulPolyNaiveMontgomery(p1, p2, p3 *Poly) {
// Exp raises p1 to p1^e and writes the result on p2.
// IMPROVEMENT : implement Montgomery ladder.
-func (r *Ring) Exp(p1 *Poly, e uint64, p2 *Poly) {
+func (r *Ring) Exp(p1 *Poly, e int, p2 *Poly) {
r.NTT(p1, p1)
@@ -865,7 +865,7 @@ func (r *Ring) Exp(p1 *Poly, e uint64, p2 *Poly) {
for i := range r.Modulus {
p2tmp := p2.Coeffs[i]
- for x := uint64(0); x < r.N; x++ {
+ for x := 0; x < r.N; x++ {
p2tmp[x] = 1
}
}
@@ -885,7 +885,7 @@ func (r *Ring) Exp(p1 *Poly, e uint64, p2 *Poly) {
func (r *Ring) AddScalar(p1 *Poly, scalar uint64, p2 *Poly) {
for i, Qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p1.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -908,7 +908,7 @@ func (r *Ring) AddScalarBigint(p1 *Poly, scalar *big.Int, p2 *Poly) {
for i, Qi := range r.Modulus {
scalarQi := tmp.Mod(scalar, NewUint(Qi)).Uint64()
p1tmp, p2tmp := p1.Coeffs[i], p1.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -929,7 +929,7 @@ func (r *Ring) AddScalarBigint(p1 *Poly, scalar *big.Int, p2 *Poly) {
func (r *Ring) SubScalar(p1 *Poly, scalar uint64, p2 *Poly) {
for i, Qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p1.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -952,7 +952,7 @@ func (r *Ring) SubScalarBigint(p1 *Poly, scalar *big.Int, p2 *Poly) {
for i, Qi := range r.Modulus {
scalarQi := tmp.Mod(scalar, NewUint(Qi)).Uint64()
p1tmp, p2tmp := p1.Coeffs[i], p1.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -975,7 +975,7 @@ func (r *Ring) MulScalar(p1 *Poly, scalar uint64, p2 *Poly) {
scalarMont := MForm(BRedAdd(scalar, Qi, r.BredParams[i]), Qi, r.BredParams[i])
mredParams := r.MredParams[i]
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -993,13 +993,13 @@ func (r *Ring) MulScalar(p1 *Poly, scalar uint64, p2 *Poly) {
}
// MulScalarLvl multiplies each coefficient of p1 by a scalar for the moduli from q_0 up to q_level and writes the result on p2.
-func (r *Ring) MulScalarLvl(level uint64, p1 *Poly, scalar uint64, p2 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) MulScalarLvl(level int, p1 *Poly, scalar uint64, p2 *Poly) {
+ for i := 0; i < level+1; i++ {
Qi := r.Modulus[i]
scalarMont := MForm(BRedAdd(scalar, Qi, r.BredParams[i]), Qi, r.BredParams[i])
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1024,7 +1024,7 @@ func (r *Ring) MulScalarBigint(p1 *Poly, scalar *big.Int, p2 *Poly) {
scalarMont := MForm(BRedAdd(scalarQi.Uint64(), Qi, r.BredParams[i]), Qi, r.BredParams[i])
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1043,15 +1043,15 @@ func (r *Ring) MulScalarBigint(p1 *Poly, scalar *big.Int, p2 *Poly) {
// MulScalarBigintLvl multiplies each coefficient of p1 by a big.Int scalar
//for the moduli from q_0 up to q_level and writes the result on p2.
-func (r *Ring) MulScalarBigintLvl(level uint64, p1 *Poly, scalar *big.Int, p2 *Poly) {
+func (r *Ring) MulScalarBigintLvl(level int, p1 *Poly, scalar *big.Int, p2 *Poly) {
scalarQi := new(big.Int)
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
Qi := r.Modulus[i]
scalarQi.Mod(scalar, NewUint(Qi))
scalarMont := MForm(BRedAdd(scalarQi.Uint64(), Qi, r.BredParams[i]), Qi, r.BredParams[i])
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1069,8 +1069,8 @@ func (r *Ring) MulScalarBigintLvl(level uint64, p1 *Poly, scalar *big.Int, p2 *P
}
// Shift circulary shifts the coefficients of the polynomial p1 by n positions to the left and writes the result on p2.
-func (r *Ring) Shift(p1 *Poly, n uint64, p2 *Poly) {
- mask := uint64((1 << r.N) - 1)
+func (r *Ring) Shift(p1 *Poly, n int, p2 *Poly) {
+ mask := (1 << r.N) - 1
for i := range r.Modulus {
p2.Coeffs[i] = append(p1.Coeffs[i][(n&mask):], p1.Coeffs[i][:(n&mask)]...)
}
@@ -1081,7 +1081,7 @@ func (r *Ring) MForm(p1, p2 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
bredParams := r.BredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1099,12 +1099,12 @@ func (r *Ring) MForm(p1, p2 *Poly) {
}
// MFormLvl switches p1 to the Montgomery domain for the moduli from q_0 up to q_level and writes the result on p2.
-func (r *Ring) MFormLvl(level uint64, p1, p2 *Poly) {
- for i := uint64(0); i < level+1; i++ {
+func (r *Ring) MFormLvl(level int, p1, p2 *Poly) {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
bredParams := r.BredParams[i]
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1126,7 +1126,7 @@ func (r *Ring) InvMForm(p1, p2 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1144,11 +1144,11 @@ func (r *Ring) InvMForm(p1, p2 *Poly) {
}
// InvMFormLvl switches back p1 from the Montgomery domain to the conventional domain and writes the result on p2.
-func (r *Ring) InvMFormLvl(level uint64, p1, p2 *Poly) {
+func (r *Ring) InvMFormLvl(level int, p1, p2 *Poly) {
for i, qi := range r.Modulus[:level+1] {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1166,19 +1166,19 @@ func (r *Ring) InvMFormLvl(level uint64, p1, p2 *Poly) {
}
// MulByPow2New multiplies p1 by 2^pow2 and returns the result in a new polynomial p2.
-func (r *Ring) MulByPow2New(p1 *Poly, pow2 uint64) (p2 *Poly) {
+func (r *Ring) MulByPow2New(p1 *Poly, pow2 int) (p2 *Poly) {
p2 = r.NewPoly()
r.MulByPow2(p1, pow2, p2)
return
}
// MulByPow2 multiplies p1 by 2^pow2 and writes the result on p2.
-func (r *Ring) MulByPow2(p1 *Poly, pow2 uint64, p2 *Poly) {
+func (r *Ring) MulByPow2(p1 *Poly, pow2 int, p2 *Poly) {
r.MForm(p1, p2)
for i, Qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1196,13 +1196,13 @@ func (r *Ring) MulByPow2(p1 *Poly, pow2 uint64, p2 *Poly) {
}
// MulByPow2Lvl multiplies p1 by 2^pow2 for the moduli from q_0 up to q_level and writes the result on p2.
-func (r *Ring) MulByPow2Lvl(level uint64, p1 *Poly, pow2 uint64, p2 *Poly) {
+func (r *Ring) MulByPow2Lvl(level int, p1 *Poly, pow2 int, p2 *Poly) {
r.MFormLvl(level, p1, p2)
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
qi := r.Modulus[i]
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -1220,14 +1220,14 @@ func (r *Ring) MulByPow2Lvl(level uint64, p1 *Poly, pow2 uint64, p2 *Poly) {
}
// MultByMonomialNew multiplies p1 by x^monomialDeg and writes the result on a new polynomial p2.
-func (r *Ring) MultByMonomialNew(p1 *Poly, monomialDeg uint64) (p2 *Poly) {
+func (r *Ring) MultByMonomialNew(p1 *Poly, monomialDeg int) (p2 *Poly) {
p2 = r.NewPoly()
r.MultByMonomial(p1, monomialDeg, p2)
return
}
// MultByMonomial multiplies p1 by x^monomialDeg and writes the result on p2.
-func (r *Ring) MultByMonomial(p1 *Poly, monomialDeg uint64, p2 *Poly) {
+func (r *Ring) MultByMonomial(p1 *Poly, monomialDeg int, p2 *Poly) {
shift := monomialDeg % (r.N << 1)
@@ -1235,7 +1235,7 @@ func (r *Ring) MultByMonomial(p1 *Poly, monomialDeg uint64, p2 *Poly) {
for i := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
p2tmp[j] = p1tmp[j]
}
}
@@ -1248,7 +1248,7 @@ func (r *Ring) MultByMonomial(p1 *Poly, monomialDeg uint64, p2 *Poly) {
for i := range r.Modulus {
p1tmp, tmpxT := p1.Coeffs[i], tmpx.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
tmpxT[j] = p1tmp[j]
}
}
@@ -1257,7 +1257,7 @@ func (r *Ring) MultByMonomial(p1 *Poly, monomialDeg uint64, p2 *Poly) {
for i, qi := range r.Modulus {
p1tmp, tmpxT := p1.Coeffs[i], tmpx.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
tmpxT[j] = qi - p1tmp[j]
}
}
@@ -1267,7 +1267,7 @@ func (r *Ring) MultByMonomial(p1 *Poly, monomialDeg uint64, p2 *Poly) {
for i, qi := range r.Modulus {
p2tmp, tmpxT := p2.Coeffs[i], tmpx.Coeffs[i]
- for j := uint64(0); j < shift; j++ {
+ for j := 0; j < shift; j++ {
p2tmp[j] = qi - tmpxT[r.N-shift+j]
}
}
@@ -1287,7 +1287,7 @@ func (r *Ring) MulByVectorMontgomery(p1 *Poly, vector []uint64, p2 *Poly) {
for i, qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&vector[j]))
@@ -1310,7 +1310,7 @@ func (r *Ring) MulByVectorMontgomeryAndAddNoMod(p1 *Poly, vector []uint64, p2 *P
for i, qi := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
mredParams := r.MredParams[i]
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
y := (*[8]uint64)(unsafe.Pointer(&vector[j]))
@@ -1331,21 +1331,21 @@ func (r *Ring) MulByVectorMontgomeryAndAddNoMod(p1 *Poly, vector []uint64, p2 *P
// BitReverse applies a bit reverse permutation on the coefficients of p1 and writes the result on p2.
// In can safely be used for in-place permutation.
func (r *Ring) BitReverse(p1, p2 *Poly) {
- bitLenOfN := uint64(bits.Len64(r.N) - 1)
+ bitLenOfN := uint64(bits.Len64(uint64(r.N)) - 1)
if p1 != p2 {
for i := range r.Modulus {
p1tmp, p2tmp := p1.Coeffs[i], p2.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
- p2tmp[utils.BitReverse64(j, bitLenOfN)] = p1tmp[j]
+ for j := 0; j < r.N; j++ {
+ p2tmp[utils.BitReverse64(uint64(j), bitLenOfN)] = p1tmp[j]
}
}
} else { // In place in case p1 = p2
for x := range r.Modulus {
p2tmp := p2.Coeffs[x]
- for i := uint64(0); i < r.N; i++ {
- j := utils.BitReverse64(i, bitLenOfN)
- if i < j {
+ for i := 0; i < r.N; i++ {
+ j := utils.BitReverse64(uint64(i), bitLenOfN)
+ if i < int(j) {
p2tmp[i], p2tmp[j] = p2tmp[j], p2tmp[i]
}
}
@@ -1356,7 +1356,7 @@ func (r *Ring) BitReverse(p1, p2 *Poly) {
// Rotate applies a Galois automorphism on p1 in NTT form,
// rotating the coefficients to the right by n positions, and writes the result on p2.
// It requires the data to be permuted in bit-reversal order before applying the NTT.
-func (r *Ring) Rotate(p1 *Poly, n uint64, p2 *Poly) {
+func (r *Ring) Rotate(p1 *Poly, n int, p2 *Poly) {
var root, gal uint64
@@ -1374,7 +1374,7 @@ func (r *Ring) Rotate(p1 *Poly, n uint64, p2 *Poly) {
p1tmp, p2tmp := p1.Coeffs[i], p1.Coeffs[i]
- for j := uint64(1); j < r.N; j++ {
+ for j := 1; j < r.N; j++ {
gal = MRed(gal, root, qi, mredParams)
diff --git a/ring/ring_poly.go b/ring/ring_poly.go
index bde97089..bdaefc34 100644
--- a/ring/ring_poly.go
+++ b/ring/ring_poly.go
@@ -12,25 +12,30 @@ type Poly struct {
}
// NewPoly creates a new polynomial with N coefficients set to zero and nbModuli moduli.
-func NewPoly(N, nbModuli uint64) (pol *Poly) {
+func NewPoly(N, nbModuli int) (pol *Poly) {
pol = new(Poly)
pol.Coeffs = make([][]uint64, nbModuli)
- for i := uint64(0); i < nbModuli; i++ {
+ for i := 0; i < nbModuli; i++ {
pol.Coeffs[i] = make([]uint64, N)
}
return
}
-// GetDegree returns the number of coefficients of the polynomial, which equals the degree of the Ring cyclotomic polynomial.
-func (pol *Poly) GetDegree() int {
+// Degree returns the number of coefficients of the polynomial, which equals the degree of the Ring cyclotomic polynomial.
+func (pol *Poly) Degree() int {
return len(pol.Coeffs[0])
}
-// GetLenModuli returns the number of moduli.
-func (pol *Poly) GetLenModuli() int {
+// LenModuli returns the current number of moduli.
+func (pol *Poly) LenModuli() int {
return len(pol.Coeffs)
}
+// Level returns the current number of moduli minus 1.
+func (pol *Poly) Level() int {
+ return len(pol.Coeffs) - 1
+}
+
// Zero sets all coefficients of the target polynomial to 0.
func (pol *Poly) Zero() {
for i := range pol.Coeffs {
@@ -62,7 +67,7 @@ func (r *Ring) Copy(p0, p1 *Poly) {
if p0 != p1 {
for i := range r.Modulus {
p0tmp, p1tmp := p0.Coeffs[i], p1.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
p1tmp[j] = p0tmp[j]
}
}
@@ -71,12 +76,12 @@ func (r *Ring) Copy(p0, p1 *Poly) {
// CopyLvl copies the coefficients of p0 on p1 within the given Ring for the moduli from 0 to level.
// Requires p1 to be as big as the target Ring.
-func (r *Ring) CopyLvl(level uint64, p0, p1 *Poly) {
+func (r *Ring) CopyLvl(level int, p0, p1 *Poly) {
if p0 != p1 {
- for i := uint64(0); i < level+1; i++ {
+ for i := 0; i < level+1; i++ {
p0tmp, p1tmp := p0.Coeffs[i], p1.Coeffs[i]
- for j := uint64(0); j < r.N; j++ {
+ for j := 0; j < r.N; j++ {
p1tmp[j] = p0tmp[j]
}
}
@@ -120,10 +125,10 @@ func (pol *Poly) GetCoefficients() (coeffs [][]uint64) {
}
// WriteCoeffsTo converts a matrix of coefficients to a byte array.
-func WriteCoeffsTo(pointer, N, numberModuli uint64, coeffs [][]uint64, data []byte) (uint64, error) {
+func WriteCoeffsTo(pointer, N, numberModuli int, coeffs [][]uint64, data []byte) (int, error) {
tmp := N << 3
- for i := uint64(0); i < numberModuli; i++ {
- for j := uint64(0); j < N; j++ {
+ for i := 0; i < numberModuli; i++ {
+ for j := 0; j < N; j++ {
binary.BigEndian.PutUint64(data[pointer+(j<<3):pointer+((j+1)<<3)], coeffs[i][j])
}
pointer += tmp
@@ -134,12 +139,12 @@ func WriteCoeffsTo(pointer, N, numberModuli uint64, coeffs [][]uint64, data []by
// WriteTo writes the given poly to the data array.
// It returns the number of written bytes, and the corresponding error, if it occurred.
-func (pol *Poly) WriteTo(data []byte) (uint64, error) {
+func (pol *Poly) WriteTo(data []byte) (int, error) {
- N := uint64(pol.GetDegree())
- numberModuli := uint64(pol.GetLenModuli())
+ N := pol.Degree()
+ numberModuli := pol.LenModuli()
- if uint64(len(data)) < pol.GetDataLen(true) {
+ if len(data) < pol.GetDataLen(true) {
// The data is not big enough to write all the information
return 0, errors.New("data array is too small to write ring.Poly")
}
@@ -153,12 +158,12 @@ func (pol *Poly) WriteTo(data []byte) (uint64, error) {
// WriteTo32 writes the given poly to the data array.
// It returns the number of written bytes, and the corresponding error, if it occurred.
-func (pol *Poly) WriteTo32(data []byte) (uint64, error) {
+func (pol *Poly) WriteTo32(data []byte) (int, error) {
- N := uint64(pol.GetDegree())
- numberModuli := uint64(pol.GetLenModuli())
+ N := pol.Degree()
+ numberModuli := pol.LenModuli()
- if uint64(len(data)) < pol.GetDataLen32(true) {
+ if len(data) < pol.GetDataLen32(true) {
//The data is not big enough to write all the information
return 0, errors.New("data array is too small to write ring.Poly")
}
@@ -171,10 +176,10 @@ func (pol *Poly) WriteTo32(data []byte) (uint64, error) {
}
// WriteCoeffsTo32 converts a matrix of coefficients to a byte array.
-func WriteCoeffsTo32(pointer, N, numberModuli uint64, coeffs [][]uint64, data []byte) (uint64, error) {
+func WriteCoeffsTo32(pointer, N, numberModuli int, coeffs [][]uint64, data []byte) (int, error) {
tmp := N << 2
- for i := uint64(0); i < numberModuli; i++ {
- for j := uint64(0); j < N; j++ {
+ for i := 0; i < numberModuli; i++ {
+ for j := 0; j < N; j++ {
binary.BigEndian.PutUint32(data[pointer+(j<<2):pointer+((j+1)<<2)], uint32(coeffs[i][j]))
}
pointer += tmp
@@ -185,8 +190,8 @@ func WriteCoeffsTo32(pointer, N, numberModuli uint64, coeffs [][]uint64, data []
// GetDataLen32 returns the number of bytes the polynomial will take when written to data.
// It can take into account meta data if necessary.
-func (pol *Poly) GetDataLen32(WithMetadata bool) (cnt uint64) {
- cnt = uint64((pol.GetLenModuli() * pol.GetDegree()) << 2)
+func (pol *Poly) GetDataLen32(WithMetadata bool) (cnt int) {
+ cnt = (pol.LenModuli() * pol.Degree()) << 2
if WithMetadata {
cnt += 2
@@ -196,17 +201,15 @@ func (pol *Poly) GetDataLen32(WithMetadata bool) (cnt uint64) {
// WriteCoeffs writes the coefficients to the given data array.
// It fails if the data array is not big enough to contain the ring.Poly
-func (pol *Poly) WriteCoeffs(data []byte) (uint64, error) {
-
- cnt, err := WriteCoeffsTo(0, uint64(pol.GetDegree()), uint64(pol.GetLenModuli()), pol.Coeffs, data)
- return cnt, err
+func (pol *Poly) WriteCoeffs(data []byte) (int, error) {
+ return WriteCoeffsTo(0, pol.Degree(), pol.LenModuli(), pol.Coeffs, data)
}
// GetDataLen returns the number of bytes the polynomial will take when written to data.
// It can take into account meta data if necessary.
-func (pol *Poly) GetDataLen(WithMetadata bool) (cnt uint64) {
- cnt = uint64((pol.GetLenModuli() * pol.GetDegree()) << 3)
+func (pol *Poly) GetDataLen(WithMetadata bool) (cnt int) {
+ cnt = (pol.LenModuli() * pol.Degree()) << 3
if WithMetadata {
cnt += 2
@@ -215,10 +218,10 @@ func (pol *Poly) GetDataLen(WithMetadata bool) (cnt uint64) {
}
// DecodeCoeffs converts a byte array to a matrix of coefficients.
-func DecodeCoeffs(pointer, N, numberModuli uint64, coeffs [][]uint64, data []byte) (uint64, error) {
+func DecodeCoeffs(pointer, N, numberModuli int, coeffs [][]uint64, data []byte) (int, error) {
tmp := N << 3
- for i := uint64(0); i < numberModuli; i++ {
- for j := uint64(0); j < N; j++ {
+ for i := 0; i < numberModuli; i++ {
+ for j := 0; j < N; j++ {
coeffs[i][j] = binary.BigEndian.Uint64(data[pointer+(j<<3) : pointer+((j+1)<<3)])
}
pointer += tmp
@@ -228,11 +231,11 @@ func DecodeCoeffs(pointer, N, numberModuli uint64, coeffs [][]uint64, data []byt
}
// DecodeCoeffsNew converts a byte array to a matrix of coefficients.
-func DecodeCoeffsNew(pointer, N, numberModuli uint64, coeffs [][]uint64, data []byte) (uint64, error) {
+func DecodeCoeffsNew(pointer, N, numberModuli int, coeffs [][]uint64, data []byte) (int, error) {
tmp := N << 3
- for i := uint64(0); i < numberModuli; i++ {
+ for i := 0; i < numberModuli; i++ {
coeffs[i] = make([]uint64, N)
- for j := uint64(0); j < N; j++ {
+ for j := 0; j < N; j++ {
coeffs[i][j] = binary.BigEndian.Uint64(data[pointer+(j<<3) : pointer+((j+1)<<3)])
}
pointer += tmp
@@ -268,10 +271,10 @@ func (pol *Poly) UnmarshalBinary(data []byte) (err error) {
// DecodePolyNew decodes a slice of bytes in the target polynomial returns the number of bytes
// decoded.
-func (pol *Poly) DecodePolyNew(data []byte) (pointer uint64, err error) {
+func (pol *Poly) DecodePolyNew(data []byte) (pointer int, err error) {
- N := uint64(1 << data[0])
- numberModulies := uint64(data[1])
+ N := int(1 << data[0])
+ numberModulies := int(data[1])
pointer = 2
if pol.Coeffs == nil {
@@ -287,10 +290,10 @@ func (pol *Poly) DecodePolyNew(data []byte) (pointer uint64, err error) {
// DecodePolyNew32 decodes a slice of bytes in the target polynomial returns the number of bytes
// decoded.
-func (pol *Poly) DecodePolyNew32(data []byte) (pointer uint64, err error) {
+func (pol *Poly) DecodePolyNew32(data []byte) (pointer int, err error) {
- N := uint64(1 << data[0])
- numberModulies := uint64(data[1])
+ N := int(1 << data[0])
+ numberModulies := int(data[1])
pointer = 2
if pol.Coeffs == nil {
@@ -305,11 +308,11 @@ func (pol *Poly) DecodePolyNew32(data []byte) (pointer uint64, err error) {
}
// DecodeCoeffsNew32 converts a byte array to a matrix of coefficients.
-func DecodeCoeffsNew32(pointer, N, numberModuli uint64, coeffs [][]uint64, data []byte) (uint64, error) {
+func DecodeCoeffsNew32(pointer, N, numberModuli int, coeffs [][]uint64, data []byte) (int, error) {
tmp := N << 2
- for i := uint64(0); i < numberModuli; i++ {
+ for i := 0; i < numberModuli; i++ {
coeffs[i] = make([]uint64, N)
- for j := uint64(0); j < N; j++ {
+ for j := 0; j < N; j++ {
coeffs[i][j] = uint64(binary.BigEndian.Uint32(data[pointer+(j<<2) : pointer+((j+1)<<2)]))
}
pointer += tmp
diff --git a/ring/ring_sampler_gaussian.go b/ring/ring_sampler_gaussian.go
index b0b7f5db..9e8a1422 100644
--- a/ring/ring_sampler_gaussian.go
+++ b/ring/ring_sampler_gaussian.go
@@ -10,99 +10,106 @@ import (
// GaussianSampler keeps the state of a truncated Gaussian polynomial sampler.
type GaussianSampler struct {
baseSampler
+ sigma float64
+ bound int
randomBufferN []byte
ptr uint64
- sigma float64
- bound uint64
}
// NewGaussianSampler creates a new instance of GaussianSampler from a PRNG, a ring definition and the truncated
// Gaussian distribution parameters. Sigma is the desired standard deviation and bound is the maximum coefficient norm in absolute
// value.
-func NewGaussianSampler(prng utils.PRNG, baseRing *Ring, sigma float64, bound uint64) *GaussianSampler {
+func NewGaussianSampler(prng utils.PRNG, baseRing *Ring, sigma float64, bound int) *GaussianSampler {
gaussianSampler := new(GaussianSampler)
- gaussianSampler.baseRing = baseRing
gaussianSampler.prng = prng
- gaussianSampler.randomBufferN = make([]byte, baseRing.N)
+ gaussianSampler.randomBufferN = make([]byte, 1024)
gaussianSampler.ptr = 0
+ gaussianSampler.baseRing = baseRing
gaussianSampler.sigma = sigma
gaussianSampler.bound = bound
return gaussianSampler
}
-// Read samples a polynomial at the maximum level into pol
+// Read samples a truncated Gaussian polynomial on "pol" at the maximum level in the default ring, standard deviation and bound.
func (gaussianSampler *GaussianSampler) Read(pol *Poly) {
- gaussianSampler.ReadLvl(uint64(len(gaussianSampler.baseRing.Modulus)-1), pol)
+ gaussianSampler.ReadLvl(len(gaussianSampler.baseRing.Modulus)-1, pol)
}
-// ReadNew samples a new truncated Gaussian polynomial with
-// standard deviation sigma within the given bound using the Ziggurat algorithm.
+// ReadLvl samples a truncated Gaussian polynomial at the provided level, in the default ring, standard deviation and bound.
+func (gaussianSampler *GaussianSampler) ReadLvl(level int, pol *Poly) {
+ gaussianSampler.readLvl(level, pol, gaussianSampler.baseRing, gaussianSampler.sigma, gaussianSampler.bound)
+}
+
+// ReadNew samples a new truncated Gaussian polynomial at the maximum level in the default ring, standard deviation and bound.
func (gaussianSampler *GaussianSampler) ReadNew() (pol *Poly) {
pol = gaussianSampler.baseRing.NewPoly()
gaussianSampler.Read(pol)
return pol
}
-// ReadLvlNew samples a new truncated Gaussian polynomial with
-// standard deviation sigma within the given bound using the Ziggurat algorithm.
-func (gaussianSampler *GaussianSampler) ReadLvlNew(level uint64) (pol *Poly) {
+// ReadLvlNew samples a new truncated Gaussian polynomial at the provided level, in the default ring, standard deviation and bound.
+func (gaussianSampler *GaussianSampler) ReadLvlNew(level int) (pol *Poly) {
pol = gaussianSampler.baseRing.NewPolyLvl(level)
gaussianSampler.ReadLvl(level, pol)
return pol
}
-// ReadLvl samples a polynomial at the given level into pol.
-func (gaussianSampler *GaussianSampler) ReadLvl(level uint64, pol *Poly) {
+// ReadFromDistLvl samples a truncated Gaussian polynomial at the given level in the provided ring, standard deviation and bound.
+func (gaussianSampler *GaussianSampler) ReadFromDistLvl(level int, pol *Poly, ring *Ring, sigma float64, bound int) {
+ gaussianSampler.readLvl(level, pol, ring, sigma, bound)
+}
+// ReadAndAddLvl samples a truncated Gaussian polynomial at the given level for the receiver's default standard deviation and bound and adds it on "pol".
+func (gaussianSampler *GaussianSampler) ReadAndAddLvl(level int, pol *Poly) {
+ gaussianSampler.ReadAndAddFromDistLvl(level, pol, gaussianSampler.baseRing, gaussianSampler.sigma, gaussianSampler.bound)
+}
+
+// ReadAndAddFromDistLvl samples a truncated Gaussian polynomial at the given level in the provided ring, standard deviation and bound and adds it on "pol".
+func (gaussianSampler *GaussianSampler) ReadAndAddFromDistLvl(level int, pol *Poly, ring *Ring, sigma float64, bound int) {
var coeffFlo float64
- var coeffInt uint64
- var sign uint64
+ var coeffInt, sign uint64
gaussianSampler.prng.Clock(gaussianSampler.randomBufferN)
- for i := uint64(0); i < gaussianSampler.baseRing.N; i++ {
+ modulus := ring.Modulus[:level+1]
+
+ for i := 0; i < ring.N; i++ {
for {
coeffFlo, sign = gaussianSampler.normFloat64()
- if coeffInt = uint64(coeffFlo * gaussianSampler.sigma); coeffInt <= gaussianSampler.bound {
+ if coeffInt = uint64(coeffFlo*sigma + 0.5); coeffInt <= uint64(bound) {
break
}
}
- for j, qi := range gaussianSampler.baseRing.Modulus[:level+1] {
- pol.Coeffs[j][i] = (coeffInt * sign) | (qi-coeffInt)*(sign^1)
+ for j, qi := range modulus {
+ pol.Coeffs[j][i] = CRed(pol.Coeffs[j][i]+((coeffInt*sign)|(qi-coeffInt)*(sign^1)), qi)
}
}
}
-// ReadAndAdd adds on the input polynomial a truncated Gaussian polynomial of at the maximum level
-// with standard deviation sigma within the given bound using the Ziggurat algorithm.
-func (gaussianSampler *GaussianSampler) ReadAndAdd(pol *Poly) {
- gaussianSampler.ReadAndAddLvl(uint64(len(gaussianSampler.baseRing.Modulus)-1), pol)
-}
-
-// ReadAndAddLvl samples and adds a polynomial at the given level directly into pol. pol must be at the given level.
-func (gaussianSampler *GaussianSampler) ReadAndAddLvl(level uint64, pol *Poly) {
-
+func (gaussianSampler *GaussianSampler) readLvl(level int, pol *Poly, ring *Ring, sigma float64, bound int) {
var coeffFlo float64
var coeffInt uint64
var sign uint64
gaussianSampler.prng.Clock(gaussianSampler.randomBufferN)
- for i := uint64(0); i < gaussianSampler.baseRing.N; i++ {
+ modulus := ring.Modulus[:level+1]
+
+ for i := 0; i < ring.N; i++ {
for {
coeffFlo, sign = gaussianSampler.normFloat64()
- if coeffInt = uint64(coeffFlo * gaussianSampler.sigma); coeffInt <= gaussianSampler.bound {
+ if coeffInt = uint64(coeffFlo*sigma + 0.5); coeffInt <= uint64(bound) {
break
}
}
- for j, qi := range gaussianSampler.baseRing.Modulus[:level+1] {
- pol.Coeffs[j][i] = CRed(pol.Coeffs[j][i]+((coeffInt*sign)|(qi-coeffInt)*(sign^1)), qi)
+ for j, qi := range modulus {
+ pol.Coeffs[j][i] = (coeffInt * sign) | (qi-coeffInt)*(sign^1)
}
}
}
diff --git a/ring/ring_sampler_ternary.go b/ring/ring_sampler_ternary.go
index f1cbc8cf..f7d3bec3 100644
--- a/ring/ring_sampler_ternary.go
+++ b/ring/ring_sampler_ternary.go
@@ -13,8 +13,8 @@ type TernarySampler struct {
matrixProba [2][precision - 1]uint8
matrixValues [][3]uint64
p float64
- hw uint64
- sample func(lvl uint64, poly *Poly)
+ hw int
+ sample func(lvl int, poly *Poly)
}
// NewTernarySampler creates a new instance of TernarySampler from a PRNG, the ring definition and the distribution
@@ -39,7 +39,7 @@ func NewTernarySampler(prng utils.PRNG, baseRing *Ring, p float64, montgomery bo
// NewTernarySamplerSparse creates a new instance of a fixed-hamming-weight TernarySampler from a PRNG, the ring definition and the desired
// hamming weight for the output polynomials. If "montgomery" is set to true, polynomials read from this sampler
// are in Montgomery form.
-func NewTernarySamplerSparse(prng utils.PRNG, baseRing *Ring, hw uint64, montgomery bool) *TernarySampler {
+func NewTernarySamplerSparse(prng utils.PRNG, baseRing *Ring, hw int, montgomery bool) *TernarySampler {
ternarySampler := new(TernarySampler)
ternarySampler.baseRing = baseRing
ternarySampler.prng = prng
@@ -53,23 +53,23 @@ func NewTernarySamplerSparse(prng utils.PRNG, baseRing *Ring, hw uint64, montgom
// Read samples a polynomial into pol.
func (ts *TernarySampler) Read(pol *Poly) {
- ts.sample(uint64(len(ts.baseRing.Modulus)-1), pol)
+ ts.sample(len(ts.baseRing.Modulus)-1, pol)
}
// ReadLvl samples a polynomial into pol at the speciefied level.
-func (ts *TernarySampler) ReadLvl(lvl uint64, pol *Poly) {
+func (ts *TernarySampler) ReadLvl(lvl int, pol *Poly) {
ts.sample(lvl, pol)
}
// ReadNew allocates and samples a polynomial at the max level.
func (ts *TernarySampler) ReadNew() (pol *Poly) {
pol = ts.baseRing.NewPoly()
- ts.sample(uint64(len(ts.baseRing.Modulus)-1), pol)
+ ts.sample(len(ts.baseRing.Modulus)-1, pol)
return pol
}
// ReadLvlNew allocates and samples a polynomial at the speficied level.
-func (ts *TernarySampler) ReadLvlNew(lvl uint64) (pol *Poly) {
+func (ts *TernarySampler) ReadLvlNew(lvl int) (pol *Poly) {
pol = ts.baseRing.NewPolyLvl(lvl)
ts.sample(lvl, pol)
return pol
@@ -118,7 +118,7 @@ func (ts *TernarySampler) computeMatrixTernary(p float64) {
}
-func (ts *TernarySampler) sampleProba(lvl uint64, pol *Poly) {
+func (ts *TernarySampler) sampleProba(lvl int, pol *Poly) {
if ts.p == 0 {
panic("cannot sample -> p = 0")
@@ -137,13 +137,13 @@ func (ts *TernarySampler) sampleProba(lvl uint64, pol *Poly) {
ts.prng.Clock(randomBytesSign)
- for i := uint64(0); i < ts.baseRing.N; i++ {
+ for i := 0; i < ts.baseRing.N; i++ {
coeff = uint64(uint8(randomBytesCoeffs[i>>3])>>(i&7)) & 1
sign = uint64(uint8(randomBytesSign[i>>3])>>(i&7)) & 1
index = (coeff & (sign ^ 1)) | ((sign & coeff) << 1)
- for j := uint64(0); j < lvl+1; j++ {
+ for j := 0; j < lvl+1; j++ {
pol.Coeffs[j][i] = ts.matrixValues[j][index]
}
}
@@ -153,24 +153,24 @@ func (ts *TernarySampler) sampleProba(lvl uint64, pol *Poly) {
randomBytes := make([]byte, ts.baseRing.N)
pointer := uint8(0)
- bytePointer := uint64(0)
+ var bytePointer int
ts.prng.Clock(randomBytes)
- for i := uint64(0); i < ts.baseRing.N; i++ {
+ for i := 0; i < ts.baseRing.N; i++ {
coeff, sign, randomBytes, pointer, bytePointer = ts.kysampling(ts.prng, randomBytes, pointer, bytePointer, ts.baseRing.N)
index = (coeff & (sign ^ 1)) | ((sign & coeff) << 1)
- for j := uint64(0); j < lvl+1; j++ {
+ for j := 0; j < lvl+1; j++ {
pol.Coeffs[j][i] = ts.matrixValues[j][index]
}
}
}
}
-func (ts *TernarySampler) sampleSparse(lvl uint64, pol *Poly) {
+func (ts *TernarySampler) sampleSparse(lvl int, pol *Poly) {
if ts.hw > ts.baseRing.N {
ts.hw = ts.baseRing.N
@@ -179,8 +179,8 @@ func (ts *TernarySampler) sampleSparse(lvl uint64, pol *Poly) {
var mask, j uint64
var coeff uint8
- index := make([]uint64, ts.baseRing.N)
- for i := uint64(0); i < ts.baseRing.N; i++ {
+ index := make([]int, ts.baseRing.N)
+ for i := 0; i < ts.baseRing.N; i++ {
index[i] = i
}
@@ -189,16 +189,16 @@ func (ts *TernarySampler) sampleSparse(lvl uint64, pol *Poly) {
ts.prng.Clock(randomBytes)
- for i := uint64(0); i < ts.hw; i++ {
- mask = (1 << uint64(bits.Len64(ts.baseRing.N-i))) - 1 // rejection sampling of a random variable between [0, len(index)]
+ for i := 0; i < ts.hw; i++ {
+ mask = (1 << uint64(bits.Len64(uint64(ts.baseRing.N-i)))) - 1 // rejection sampling of a random variable between [0, len(index)]
j = randInt32(ts.prng, mask)
- for j >= ts.baseRing.N-i {
+ for j >= uint64(ts.baseRing.N-i) {
j = randInt32(ts.prng, mask)
}
coeff = (uint8(randomBytes[0]) >> (i & 7)) & 1 // random binary digit [0, 1] from the random bytes (0 = 1, 1 = -1)
- for k := uint64(0); k < lvl+1; k++ {
+ for k := 0; k < lvl+1; k++ {
pol.Coeffs[k][index[j]] = ts.matrixValues[k][coeff+1]
}
@@ -216,7 +216,7 @@ func (ts *TernarySampler) sampleSparse(lvl uint64, pol *Poly) {
}
// kysampling uses the binary expansion and random bytes matrix to sample a discrete Gaussian value and its sign.
-func (ts *TernarySampler) kysampling(prng utils.PRNG, randomBytes []byte, pointer uint8, bytePointer uint64, byteLength uint64) (uint64, uint64, []byte, uint8, uint64) {
+func (ts *TernarySampler) kysampling(prng utils.PRNG, randomBytes []byte, pointer uint8, bytePointer, byteLength int) (uint64, uint64, []byte, uint8, int) {
var sign uint8
diff --git a/ring/ring_sampler_uniform.go b/ring/ring_sampler_uniform.go
index a322cae7..b3224290 100644
--- a/ring/ring_sampler_uniform.go
+++ b/ring/ring_sampler_uniform.go
@@ -25,7 +25,7 @@ func NewUniformSampler(prng utils.PRNG, baseRing *Ring) *UniformSampler {
func (uniformSampler *UniformSampler) Read(Pol *Poly) {
var randomUint, mask, qi uint64
- var ptr uint64
+ var ptr int
uniformSampler.prng.Clock(uniformSampler.randomBufferN)
@@ -39,7 +39,7 @@ func (uniformSampler *UniformSampler) Read(Pol *Poly) {
ptmp := Pol.Coeffs[j]
// Iterate for each modulus over each coefficient
- for i := uint64(0); i < uniformSampler.baseRing.N; i++ {
+ for i := 0; i < uniformSampler.baseRing.N; i++ {
// Sample an integer between [0, qi-1]
for {
@@ -66,14 +66,14 @@ func (uniformSampler *UniformSampler) Read(Pol *Poly) {
}
// Readlvl generates a new polynomial with coefficients following a uniform distribution over [0, Qi-1].
-func (uniformSampler *UniformSampler) Readlvl(level uint64, Pol *Poly) {
+func (uniformSampler *UniformSampler) Readlvl(level int, Pol *Poly) {
var randomUint, mask, qi uint64
- var ptr uint64
+ var ptr int
uniformSampler.prng.Clock(uniformSampler.randomBufferN)
- for j := uint64(0); j < level+1; j++ {
+ for j := 0; j < level+1; j++ {
qi = uniformSampler.baseRing.Modulus[j]
@@ -83,7 +83,7 @@ func (uniformSampler *UniformSampler) Readlvl(level uint64, Pol *Poly) {
ptmp := Pol.Coeffs[j]
// Iterate for each modulus over each coefficient
- for i := uint64(0); i < uniformSampler.baseRing.N; i++ {
+ for i := 0; i < uniformSampler.baseRing.N; i++ {
// Sample an integer between [0, qi-1]
for {
@@ -119,7 +119,7 @@ func (uniformSampler *UniformSampler) ReadNew() (Pol *Poly) {
// ReadLvlNew generates a new polynomial with coefficients following a uniform distribution over [0, Qi-1].
// Polynomial is created at the specified level.
-func (uniformSampler *UniformSampler) ReadLvlNew(level uint64) (Pol *Poly) {
+func (uniformSampler *UniformSampler) ReadLvlNew(level int) (Pol *Poly) {
Pol = uniformSampler.baseRing.NewPolyLvl(level)
uniformSampler.Read(Pol)
return
diff --git a/ring/ring_scaling.go b/ring/ring_scaling.go
index 85b94a11..d011917a 100644
--- a/ring/ring_scaling.go
+++ b/ring/ring_scaling.go
@@ -43,7 +43,7 @@ func NewRNSScaler(t uint64, ringQ *Ring) (rnss *RNSScaler) {
rnss.t = t
rnss.qHalf = new(big.Int)
rnss.qInv = rnss.qHalf.Mod(ringQ.ModulusBigint, NewUint(t)).Uint64()
- rnss.qInv = ModExp(rnss.qInv, t-2, t)
+ rnss.qInv = ModExp(rnss.qInv, int(t-2), t)
rnss.qInv = MForm(rnss.qInv, t, BRedParams(t))
rnss.qHalf.Set(ringQ.ModulusBigint)
@@ -82,7 +82,7 @@ func (rnss *RNSScaler) DivByQOverTRounded(p1Q, p2T *Poly) {
modUpExact(p1Q.Coeffs, rnss.polypoolT.Coeffs, rnss.paramsQP)
// Compute [Q^{-1} * (t*P_{t} - (t*P_{Q} - ((Q-1)/2 mod t)))] mod t which returns round(t/Q * P_{Q}) mod t
- for j := uint64(0); j < ringQ.N; j = j + 8 {
+ for j := 0; j < ringQ.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p3tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
@@ -260,7 +260,7 @@ func (ss *SimpleScaler) reconstructThenScale(p1, p2 *Poly) {
// Algorithm from https://eprint.iacr.org/2018/117.pdf.
func (ss *SimpleScaler) reconstructAndScale(p1, p2 *Poly) {
- for i := uint64(0); i < ss.ringQ.N; i++ {
+ for i := 0; i < ss.ringQ.N; i++ {
var a uint64
var bBF big.Float
@@ -289,112 +289,172 @@ func (ss *SimpleScaler) reconstructAndScale(p1, p2 *Poly) {
// ============== Scaling-related methods ==============
// DivFloorByLastModulusNTT divides (floored) the polynomial by its last modulus. The input must be in the NTT domain.
-func (r *Ring) DivFloorByLastModulusNTT(p0 *Poly) {
+// Output poly level must be equal or one less than input level.
+func (r *Ring) DivFloorByLastModulusNTT(p0, p1 *Poly) {
+ r.divFloorByLastModulusNTT(p0.Level(), p0, p1)
+ p1.Coeffs = p1.Coeffs[:p0.Level()]
+}
- level := len(p0.Coeffs) - 1
+func (r *Ring) divFloorByLastModulusNTT(level int, p0, p1 *Poly) {
- pTmp := make([]uint64, r.N)
+ pool0 := r.polypool.Coeffs[0]
+ pool1 := r.polypool.Coeffs[1]
- InvNTTLazy(p0.Coeffs[level], p0.Coeffs[level], r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level])
+ InvNTTLazy(p0.Coeffs[level], pool0, r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level])
for i := 0; i < level; i++ {
- NTTLazy(p0.Coeffs[level], pTmp, r.N, r.NttPsi[i], r.Modulus[i], r.MredParams[i], r.BredParams[i])
+ NTTLazy(pool0, pool1, r.N, r.NttPsi[i], r.Modulus[i], r.MredParams[i], r.BredParams[i])
p0tmp := p0.Coeffs[i]
+ p1tmp := p1.Coeffs[i]
qi := r.Modulus[i]
twoqi := qi << 1
- mredParams := r.MredParams[i]
- rescalParams := qi - r.RescaleParams[level-1][i]
+ qInv := r.MredParams[i]
+ rescaleParams := r.RescaleParams[level-1][i]
// (x[i] - x[-1]) * InvQ
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
- x := (*[8]uint64)(unsafe.Pointer(&pTmp[j]))
- z := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
+ x := (*[8]uint64)(unsafe.Pointer(&pool1[j]))
+ y := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
+ z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
- z[0] = MRed(twoqi-z[0]+x[0], rescalParams, qi, mredParams)
- z[1] = MRed(twoqi-z[1]+x[1], rescalParams, qi, mredParams)
- z[2] = MRed(twoqi-z[2]+x[2], rescalParams, qi, mredParams)
- z[3] = MRed(twoqi-z[3]+x[3], rescalParams, qi, mredParams)
- z[4] = MRed(twoqi-z[4]+x[4], rescalParams, qi, mredParams)
- z[5] = MRed(twoqi-z[5]+x[5], rescalParams, qi, mredParams)
- z[6] = MRed(twoqi-z[6]+x[6], rescalParams, qi, mredParams)
- z[7] = MRed(twoqi-z[7]+x[7], rescalParams, qi, mredParams)
+ z[0] = MRed(twoqi-y[0]+x[0], rescaleParams, qi, qInv)
+ z[1] = MRed(twoqi-y[1]+x[1], rescaleParams, qi, qInv)
+ z[2] = MRed(twoqi-y[2]+x[2], rescaleParams, qi, qInv)
+ z[3] = MRed(twoqi-y[3]+x[3], rescaleParams, qi, qInv)
+ z[4] = MRed(twoqi-y[4]+x[4], rescaleParams, qi, qInv)
+ z[5] = MRed(twoqi-y[5]+x[5], rescaleParams, qi, qInv)
+ z[6] = MRed(twoqi-y[6]+x[6], rescaleParams, qi, qInv)
+ z[7] = MRed(twoqi-y[7]+x[7], rescaleParams, qi, qInv)
}
}
-
- p0.Coeffs = p0.Coeffs[:level]
}
// DivFloorByLastModulus divides (floored) the polynomial by its last modulus.
-func (r *Ring) DivFloorByLastModulus(p0 *Poly) {
+// Output poly level must be equal or one less than input level.
+func (r *Ring) DivFloorByLastModulus(p0, p1 *Poly) {
+ r.divFloorByLastModulus(p0.Level(), p0, p1)
+ p1.Coeffs = p1.Coeffs[:p0.Level()]
+}
- level := len(p0.Coeffs) - 1
+func (r *Ring) divFloorByLastModulus(level int, p0, p1 *Poly) {
for i := 0; i < level; i++ {
p0tmp := p0.Coeffs[level]
p1tmp := p0.Coeffs[i]
+ p2tmp := p1.Coeffs[i]
qi := r.Modulus[i]
twoqi := qi << 1
- bredParams := r.BredParams[i]
- mredParams := r.MredParams[i]
- rescaleParams := qi - r.RescaleParams[level-1][i]
+ qInv := r.MredParams[i]
+ rescaleParams := r.RescaleParams[level-1][i]
+
// (x[i] - x[-1]) * InvQ
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
- z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
+ y := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
+ z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
- z[0] = MRed(twoqi-z[0]+BRedAdd(x[0], qi, bredParams), rescaleParams, qi, mredParams)
- z[1] = MRed(twoqi-z[1]+BRedAdd(x[1], qi, bredParams), rescaleParams, qi, mredParams)
- z[2] = MRed(twoqi-z[2]+BRedAdd(x[2], qi, bredParams), rescaleParams, qi, mredParams)
- z[3] = MRed(twoqi-z[3]+BRedAdd(x[3], qi, bredParams), rescaleParams, qi, mredParams)
- z[4] = MRed(twoqi-z[4]+BRedAdd(x[4], qi, bredParams), rescaleParams, qi, mredParams)
- z[5] = MRed(twoqi-z[5]+BRedAdd(x[5], qi, bredParams), rescaleParams, qi, mredParams)
- z[6] = MRed(twoqi-z[6]+BRedAdd(x[6], qi, bredParams), rescaleParams, qi, mredParams)
- z[7] = MRed(twoqi-z[7]+BRedAdd(x[7], qi, bredParams), rescaleParams, qi, mredParams)
+ z[0] = MRed(twoqi-y[0]+x[0], rescaleParams, qi, qInv)
+ z[1] = MRed(twoqi-y[1]+x[1], rescaleParams, qi, qInv)
+ z[2] = MRed(twoqi-y[2]+x[2], rescaleParams, qi, qInv)
+ z[3] = MRed(twoqi-y[3]+x[3], rescaleParams, qi, qInv)
+ z[4] = MRed(twoqi-y[4]+x[4], rescaleParams, qi, qInv)
+ z[5] = MRed(twoqi-y[5]+x[5], rescaleParams, qi, qInv)
+ z[6] = MRed(twoqi-y[6]+x[6], rescaleParams, qi, qInv)
+ z[7] = MRed(twoqi-y[7]+x[7], rescaleParams, qi, qInv)
}
}
-
- p0.Coeffs = p0.Coeffs[:level]
}
// DivFloorByLastModulusManyNTT divides (floored) sequentially nbRescales times the polynomial by its last modulus. Input must be in the NTT domain.
-func (r *Ring) DivFloorByLastModulusManyNTT(p0 *Poly, nbRescales uint64) {
- r.InvNTTLvl(uint64(len(p0.Coeffs)-1), p0, p0)
- r.DivFloorByLastModulusMany(p0, nbRescales)
- r.NTTLvl(uint64(len(p0.Coeffs)-1), p0, p0)
-}
+// Output poly level must be equal or nbRescales less than input level.
+func (r *Ring) DivFloorByLastModulusManyNTT(p0, p1 *Poly, nbRescales int) {
-// DivFloorByLastModulusMany divides (floored) sequentially nbRescales times the polynomial by its last modulus.
-func (r *Ring) DivFloorByLastModulusMany(p0 *Poly, nbRescales uint64) {
- for k := uint64(0); k < nbRescales; k++ {
- r.DivFloorByLastModulus(p0)
+ level := p0.Level()
+
+ if nbRescales == 0 {
+
+ if p0 != p1 {
+ r.CopyLvl(p1.Level(), p0, p1)
+ }
+
+ } else {
+
+ r.InvNTTLvl(level, p0, r.polypool)
+
+ for i := 0; i < nbRescales; i++ {
+ r.divFloorByLastModulus(level-i, r.polypool, r.polypool)
+ }
+
+ p1.Coeffs = p1.Coeffs[:level-nbRescales+1]
+
+ r.NTTLvl(p1.Level(), r.polypool, p1)
}
}
+// DivFloorByLastModulusMany divides (floored) sequentially nbRescales times the polynomial by its last modulus.
+// Output poly level must be equal or nbRescales less than input level.
+func (r *Ring) DivFloorByLastModulusMany(p0, p1 *Poly, nbRescales int) {
+
+ level := p0.Level()
+
+ if nbRescales == 0 {
+
+ if p0 != p1 {
+ r.CopyLvl(p1.Level(), p0, p1)
+ }
+
+ } else {
+
+ if nbRescales > 1 {
+ r.divFloorByLastModulus(level, p0, r.polypool)
+
+ for i := 1; i < nbRescales; i++ {
+
+ if i == nbRescales-1 {
+ r.divFloorByLastModulus(level-i, r.polypool, p1)
+ } else {
+ r.divFloorByLastModulus(level-i, r.polypool, r.polypool)
+ }
+ }
+
+ } else {
+ r.divFloorByLastModulus(level, p0, p1)
+ }
+
+ p1.Coeffs = p1.Coeffs[:level-nbRescales+1]
+ }
+
+}
+
// DivRoundByLastModulusNTT divides (rounded) the polynomial by its last modulus. The input must be in the NTT domain.
-func (r *Ring) DivRoundByLastModulusNTT(p0 *Poly) {
+// Output poly level must be equal or one less than input level.
+func (r *Ring) DivRoundByLastModulusNTT(p0, p1 *Poly) {
+ r.divRoundByLastModulusNTT(p0.Level(), p0, p1)
+ p1.Coeffs = p1.Coeffs[:p0.Level()]
+}
+
+func (r *Ring) divRoundByLastModulusNTT(level int, p0, p1 *Poly) {
var pHalf, pHalfNegQi uint64
- level := len(p0.Coeffs) - 1
+ pool0 := r.polypool.Coeffs[0]
+ pool1 := r.polypool.Coeffs[1]
- pTmp := make([]uint64, r.N)
-
- InvNTT(p0.Coeffs[level], p0.Coeffs[level], r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level])
+ InvNTT(p0.Coeffs[level], pool0, r.N, r.NttPsiInv[level], r.NttNInv[level], r.Modulus[level], r.MredParams[level])
// Center by (p-1)/2
pj := r.Modulus[level]
pHalf = (pj - 1) >> 1
- p0tmp := p0.Coeffs[level]
- for i := uint64(0); i < r.N; i = i + 8 {
+ for i := 0; i < r.N; i = i + 8 {
- z := (*[8]uint64)(unsafe.Pointer(&p0tmp[i]))
+ z := (*[8]uint64)(unsafe.Pointer(&pool0[i]))
z[0] = CRed(z[0]+pHalf, pj)
z[1] = CRed(z[1]+pHalf, pj)
@@ -408,19 +468,21 @@ func (r *Ring) DivRoundByLastModulusNTT(p0 *Poly) {
for i := 0; i < level; i++ {
- p1tmp := p0.Coeffs[i]
+ p0tmp := p0.Coeffs[i]
+ p1tmp := p1.Coeffs[i]
qi := r.Modulus[i]
twoqi := qi << 1
+ qInv := r.MredParams[i]
bredParams := r.BredParams[i]
- mredParams := r.MredParams[i]
- rescaleParams := qi - r.RescaleParams[level-1][i]
+ nttPsi := r.NttPsi[i]
+ rescaleParams := r.RescaleParams[level-1][i]
pHalfNegQi = r.Modulus[i] - BRedAdd(pHalf, qi, bredParams)
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
- x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
- z := (*[8]uint64)(unsafe.Pointer(&pTmp[j]))
+ x := (*[8]uint64)(unsafe.Pointer(&pool0[j]))
+ z := (*[8]uint64)(unsafe.Pointer(&pool1[j]))
z[0] = x[0] + pHalfNegQi
z[1] = x[1] + pHalfNegQi
@@ -432,41 +494,44 @@ func (r *Ring) DivRoundByLastModulusNTT(p0 *Poly) {
z[7] = x[7] + pHalfNegQi
}
- NTTLazy(pTmp, pTmp, r.N, r.NttPsi[i], qi, mredParams, bredParams)
+ NTTLazy(pool1, pool1, r.N, nttPsi, qi, qInv, bredParams)
// (x[i] - x[-1]) * InvQ
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
- x := (*[8]uint64)(unsafe.Pointer(&pTmp[j]))
+ x := (*[8]uint64)(unsafe.Pointer(&pool1[j]))
+ y := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
- z[0] = MRed(twoqi+x[0]-z[0], rescaleParams, qi, mredParams)
- z[1] = MRed(twoqi+x[1]-z[1], rescaleParams, qi, mredParams)
- z[2] = MRed(twoqi+x[2]-z[2], rescaleParams, qi, mredParams)
- z[3] = MRed(twoqi+x[3]-z[3], rescaleParams, qi, mredParams)
- z[4] = MRed(twoqi+x[4]-z[4], rescaleParams, qi, mredParams)
- z[5] = MRed(twoqi+x[5]-z[5], rescaleParams, qi, mredParams)
- z[6] = MRed(twoqi+x[6]-z[6], rescaleParams, qi, mredParams)
- z[7] = MRed(twoqi+x[7]-z[7], rescaleParams, qi, mredParams)
+ z[0] = MRed(twoqi+x[0]-y[0], rescaleParams, qi, qInv)
+ z[1] = MRed(twoqi+x[1]-y[1], rescaleParams, qi, qInv)
+ z[2] = MRed(twoqi+x[2]-y[2], rescaleParams, qi, qInv)
+ z[3] = MRed(twoqi+x[3]-y[3], rescaleParams, qi, qInv)
+ z[4] = MRed(twoqi+x[4]-y[4], rescaleParams, qi, qInv)
+ z[5] = MRed(twoqi+x[5]-y[5], rescaleParams, qi, qInv)
+ z[6] = MRed(twoqi+x[6]-y[6], rescaleParams, qi, qInv)
+ z[7] = MRed(twoqi+x[7]-y[7], rescaleParams, qi, qInv)
}
}
-
- p0.Coeffs = p0.Coeffs[:level]
}
// DivRoundByLastModulus divides (rounded) the polynomial by its last modulus. The input must be in the NTT domain.
-func (r *Ring) DivRoundByLastModulus(p0 *Poly) {
+// Output poly level must be equal or one less than input level.
+func (r *Ring) DivRoundByLastModulus(p0, p1 *Poly) {
+ r.divRoundByLastModulus(p0.Level(), p0, p1)
+ p1.Coeffs = p1.Coeffs[:p0.Level()]
+}
+
+func (r *Ring) divRoundByLastModulus(level int, p0, p1 *Poly) {
var pHalf, pHalfNegQi uint64
- level := len(p0.Coeffs) - 1
-
// Center by (p-1)/2
pHalf = (r.Modulus[level] - 1) >> 1
p0tmp := p0.Coeffs[level]
pj := r.Modulus[level]
- for i := uint64(0); i < r.N; i = i + 8 {
+ for i := 0; i < r.N; i = i + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[i]))
@@ -483,45 +548,98 @@ func (r *Ring) DivRoundByLastModulus(p0 *Poly) {
for i := 0; i < level; i++ {
p1tmp := p0.Coeffs[i]
+ p2tmp := p1.Coeffs[i]
qi := r.Modulus[i]
twoqi := qi << 1
+ qInv := r.MredParams[i]
bredParams := r.BredParams[i]
- mredParams := r.MredParams[i]
- rescaleParams := qi - r.RescaleParams[level-1][i]
+ rescaleParams := r.RescaleParams[level-1][i]
pHalfNegQi = r.Modulus[i] - BRedAdd(pHalf, qi, bredParams)
// (x[i] - x[-1]) * InvQ
- for j := uint64(0); j < r.N; j = j + 8 {
+ for j := 0; j < r.N; j = j + 8 {
x := (*[8]uint64)(unsafe.Pointer(&p0tmp[j]))
- z := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
+ y := (*[8]uint64)(unsafe.Pointer(&p1tmp[j]))
+ z := (*[8]uint64)(unsafe.Pointer(&p2tmp[j]))
- z[0] = MRed(x[0]+pHalfNegQi+twoqi-z[0], rescaleParams, qi, mredParams)
- z[1] = MRed(x[1]+pHalfNegQi+twoqi-z[1], rescaleParams, qi, mredParams)
- z[2] = MRed(x[2]+pHalfNegQi+twoqi-z[2], rescaleParams, qi, mredParams)
- z[3] = MRed(x[3]+pHalfNegQi+twoqi-z[3], rescaleParams, qi, mredParams)
- z[4] = MRed(x[4]+pHalfNegQi+twoqi-z[4], rescaleParams, qi, mredParams)
- z[5] = MRed(x[5]+pHalfNegQi+twoqi-z[5], rescaleParams, qi, mredParams)
- z[6] = MRed(x[6]+pHalfNegQi+twoqi-z[6], rescaleParams, qi, mredParams)
- z[7] = MRed(x[7]+pHalfNegQi+twoqi-z[7], rescaleParams, qi, mredParams)
+ z[0] = MRed(x[0]+pHalfNegQi+twoqi-y[0], rescaleParams, qi, qInv)
+ z[1] = MRed(x[1]+pHalfNegQi+twoqi-y[1], rescaleParams, qi, qInv)
+ z[2] = MRed(x[2]+pHalfNegQi+twoqi-y[2], rescaleParams, qi, qInv)
+ z[3] = MRed(x[3]+pHalfNegQi+twoqi-y[3], rescaleParams, qi, qInv)
+ z[4] = MRed(x[4]+pHalfNegQi+twoqi-y[4], rescaleParams, qi, qInv)
+ z[5] = MRed(x[5]+pHalfNegQi+twoqi-y[5], rescaleParams, qi, qInv)
+ z[6] = MRed(x[6]+pHalfNegQi+twoqi-y[6], rescaleParams, qi, qInv)
+ z[7] = MRed(x[7]+pHalfNegQi+twoqi-y[7], rescaleParams, qi, qInv)
}
}
-
- p0.Coeffs = p0.Coeffs[:level]
}
// DivRoundByLastModulusManyNTT divides (rounded) sequentially nbRescales times the polynomial by its last modulus. The input must be in the NTT domain.
-func (r *Ring) DivRoundByLastModulusManyNTT(p0 *Poly, nbRescales uint64) {
- r.InvNTTLvl(uint64(len(p0.Coeffs)-1), p0, p0)
- r.DivRoundByLastModulusMany(p0, nbRescales)
- r.NTTLvl(uint64(len(p0.Coeffs)-1), p0, p0)
+// Output poly level must be equal or nbRescales less than input level.
+func (r *Ring) DivRoundByLastModulusManyNTT(p0, p1 *Poly, nbRescales int) {
+
+ level := p0.Level()
+
+ if nbRescales == 0 {
+
+ if p0 != p1 {
+ r.CopyLvl(p1.Level(), p0, p1)
+ }
+
+ } else {
+
+ if nbRescales > 1 {
+
+ r.InvNTTLvl(level, p0, r.polypool)
+
+ for i := 0; i < nbRescales; i++ {
+ r.divRoundByLastModulus(level-i, r.polypool, r.polypool)
+ }
+
+ r.NTTLvl(p1.Level(), r.polypool, p1)
+
+ } else {
+ r.divRoundByLastModulusNTT(level, p0, p1)
+ }
+
+ p1.Coeffs = p1.Coeffs[:level-nbRescales+1]
+ }
}
// DivRoundByLastModulusMany divides (rounded) sequentially nbRescales times the polynomial by its last modulus.
-func (r *Ring) DivRoundByLastModulusMany(p0 *Poly, nbRescales uint64) {
- for k := uint64(0); k < nbRescales; k++ {
- r.DivRoundByLastModulus(p0)
+// Output poly level must be equal or nbRescales less than input level.
+func (r *Ring) DivRoundByLastModulusMany(p0, p1 *Poly, nbRescales int) {
+
+ level := p0.Level()
+
+ if nbRescales == 0 {
+
+ if p0 != p1 {
+ r.CopyLvl(p1.Level(), p0, p1)
+ }
+
+ } else {
+
+ if nbRescales > 1 {
+
+ r.divRoundByLastModulus(level, p0, r.polypool)
+
+ for i := 1; i < nbRescales; i++ {
+
+ if i == nbRescales-1 {
+ r.divRoundByLastModulus(level-i, r.polypool, p1)
+ } else {
+ r.divRoundByLastModulus(level-i, r.polypool, r.polypool)
+ }
+ }
+
+ } else {
+ r.divRoundByLastModulus(level, p0, p1)
+ }
+
+ p1.Coeffs = p1.Coeffs[:level-nbRescales+1]
}
}
diff --git a/ring/ring_test.go b/ring/ring_test.go
index 1720a4a2..23df86e4 100644
--- a/ring/ring_test.go
+++ b/ring/ring_test.go
@@ -15,7 +15,7 @@ var flagLongTest = flag.Bool("long", false, "run the long test suite (all parame
var T = uint64(0x3ee0001)
var DefaultSigma = float64(3.2)
-var DefaultBound = uint64(6 * DefaultSigma)
+var DefaultBound = int(6 * DefaultSigma)
func testString(opname string, ringQ *Ring) string {
return fmt.Sprintf("%sN=%d/limbs=%d", opname, ringQ.N, len(ringQ.Modulus))
@@ -158,10 +158,10 @@ func testGenerateNTTPrimes(testContext *testParams, t *testing.T) {
t.Run(testString("GenerateNTTPrimes/", testContext.ringQ), func(t *testing.T) {
- primes := GenerateNTTPrimes(55, testContext.ringQ.N<<1, uint64(len(testContext.ringQ.Modulus)))
+ primes := GenerateNTTPrimes(55, testContext.ringQ.N<<1, len(testContext.ringQ.Modulus))
for _, q := range primes {
- require.Equal(t, q&((testContext.ringQ.N<<1)-1), uint64(1))
+ require.Equal(t, q&uint64((testContext.ringQ.N<<1)-1), uint64(1))
require.True(t, IsPrime(q), q)
}
})
@@ -185,7 +185,7 @@ func testDivFloorByLastModulusMany(testContext *testParams, t *testing.T) {
t.Run(testString("DivFloorByLastModulusMany/", testContext.ringQ), func(t *testing.T) {
coeffs := make([]*big.Int, testContext.ringQ.N)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
coeffs[i] = RandInt(testContext.ringQ.ModulusBigint)
coeffs[i].Quo(coeffs[i], NewUint(10))
}
@@ -200,16 +200,17 @@ func testDivFloorByLastModulusMany(testContext *testParams, t *testing.T) {
}
}
- polTest := testContext.ringQ.NewPoly()
+ polTest0 := testContext.ringQ.NewPoly()
+ polTest1 := testContext.ringQ.NewPoly()
polWant := testContext.ringQ.NewPoly()
- testContext.ringQ.SetCoefficientsBigint(coeffs, polTest)
+ testContext.ringQ.SetCoefficientsBigint(coeffs, polTest0)
testContext.ringQ.SetCoefficientsBigint(coeffsWant, polWant)
- testContext.ringQ.DivFloorByLastModulusMany(polTest, uint64(nbRescals))
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ testContext.ringQ.DivFloorByLastModulusMany(polTest0, polTest1, nbRescals)
+ for i := 0; i < testContext.ringQ.N; i++ {
for j := 0; j < len(testContext.ringQ.Modulus)-nbRescals; j++ {
- require.Equalf(t, polWant.Coeffs[j][i], polTest.Coeffs[j][i], "coeff %v Qi%v = %s", i, j, coeffs[i].String())
+ require.Equalf(t, polWant.Coeffs[j][i], polTest1.Coeffs[j][i], "coeff %v Qi%v = %s", i, j, coeffs[i].String())
}
}
})
@@ -220,7 +221,7 @@ func testDivRoundByLastModulusMany(testContext *testParams, t *testing.T) {
t.Run(testString("DivRoundByLastModulusMany/", testContext.ringQ), func(t *testing.T) {
coeffs := make([]*big.Int, testContext.ringQ.N)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
coeffs[i] = RandInt(testContext.ringQ.ModulusBigint)
coeffs[i].Quo(coeffs[i], NewUint(10))
}
@@ -235,16 +236,17 @@ func testDivRoundByLastModulusMany(testContext *testParams, t *testing.T) {
}
}
- polTest := testContext.ringQ.NewPoly()
+ polTest0 := testContext.ringQ.NewPoly()
+ polTest1 := testContext.ringQ.NewPoly()
polWant := testContext.ringQ.NewPoly()
- testContext.ringQ.SetCoefficientsBigint(coeffs, polTest)
+ testContext.ringQ.SetCoefficientsBigint(coeffs, polTest0)
testContext.ringQ.SetCoefficientsBigint(coeffsWant, polWant)
- testContext.ringQ.DivRoundByLastModulusMany(polTest, uint64(nbRescals))
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ testContext.ringQ.DivRoundByLastModulusMany(polTest0, polTest1, nbRescals)
+ for i := 0; i < testContext.ringQ.N; i++ {
for j := 0; j < len(testContext.ringQ.Modulus)-nbRescals; j++ {
- require.Equalf(t, polWant.Coeffs[j][i], polTest.Coeffs[j][i], "coeff %v Qi%v = %s", i, j, coeffs[i].String())
+ require.Equalf(t, polWant.Coeffs[j][i], polTest1.Coeffs[j][i], "coeff %v Qi%v = %s", i, j, coeffs[i].String())
}
}
})
@@ -283,7 +285,7 @@ func testUniformSampler(testContext *testParams, t *testing.T) {
t.Run(testString("UniformSampler/Read/", testContext.ringQ), func(t *testing.T) {
pol := testContext.ringQ.NewPoly()
testContext.uniformSamplerQ.Read(pol)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
for j, qi := range testContext.ringQ.Modulus {
require.False(t, pol.Coeffs[j][i] > qi)
}
@@ -292,7 +294,7 @@ func testUniformSampler(testContext *testParams, t *testing.T) {
t.Run(testString("UniformSampler/ReadNew/", testContext.ringQ), func(t *testing.T) {
pol := testContext.uniformSamplerQ.ReadNew()
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
for j, qi := range testContext.ringQ.Modulus {
require.False(t, pol.Coeffs[j][i] > qi)
}
@@ -306,7 +308,7 @@ func testGaussianSampler(testContext *testParams, t *testing.T) {
gaussianSampler := NewGaussianSampler(testContext.prng, testContext.ringQ, DefaultSigma, DefaultBound)
pol := gaussianSampler.ReadNew()
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
for j, qi := range testContext.ringQ.Modulus {
require.False(t, uint64(DefaultBound) < pol.Coeffs[j][i] && pol.Coeffs[j][i] < (qi-uint64(DefaultBound)))
}
@@ -335,7 +337,7 @@ func testTernarySampler(testContext *testParams, t *testing.T) {
})
}
- for _, p := range []uint64{0, 64, 96, 128, 256} {
+ for _, p := range []int{0, 64, 96, 128, 256} {
t.Run(testString(fmt.Sprintf("TernarySampler/hw=%d/", p), testContext.ringQ), func(t *testing.T) {
prng, err := utils.NewPRNG()
@@ -348,7 +350,7 @@ func testTernarySampler(testContext *testParams, t *testing.T) {
pol := ternarySampler.ReadNew()
for i := range testContext.ringQ.Modulus {
- hw := uint64(0)
+ hw := 0
for _, c := range pol.Coeffs[i] {
if c != 0 {
hw++
@@ -594,7 +596,7 @@ func testExtendBasis(testContext *testParams, t *testing.T) {
basisextender := NewFastBasisExtender(testContext.ringQ, testContext.ringP)
coeffs := make([]*big.Int, testContext.ringQ.N)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
coeffs[i] = RandInt(testContext.ringQ.ModulusBigint)
}
@@ -605,7 +607,7 @@ func testExtendBasis(testContext *testParams, t *testing.T) {
testContext.ringQ.SetCoefficientsBigint(coeffs, Pol)
testContext.ringP.SetCoefficientsBigint(coeffs, PolWant)
- basisextender.ModUpSplitQP(uint64(len(testContext.ringQ.Modulus)-1), Pol, PolTest)
+ basisextender.ModUpSplitQP(len(testContext.ringQ.Modulus)-1, Pol, PolTest)
testContext.ringP.Reduce(PolTest, PolTest)
@@ -622,7 +624,7 @@ func testScaling(testContext *testParams, t *testing.T) {
rescaler := NewSimpleScaler(T, testContext.ringQ)
coeffs := make([]*big.Int, testContext.ringQ.N)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
coeffs[i] = RandInt(testContext.ringQ.ModulusBigint)
}
@@ -640,7 +642,7 @@ func testScaling(testContext *testParams, t *testing.T) {
rescaler.DivByQOverTRounded(PolTest, PolTest)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
require.Equal(t, PolTest.Coeffs[0][i], coeffsWant[i].Uint64())
}
})
@@ -650,7 +652,7 @@ func testScaling(testContext *testParams, t *testing.T) {
scaler := NewRNSScaler(T, testContext.ringQ)
coeffs := make([]*big.Int, testContext.ringQ.N)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
coeffs[i] = RandInt(testContext.ringQ.ModulusBigint)
}
@@ -668,7 +670,7 @@ func testScaling(testContext *testParams, t *testing.T) {
scaler.DivByQOverTRounded(polyQ, polyT)
- for i := uint64(0); i < testContext.ringQ.N; i++ {
+ for i := 0; i < testContext.ringQ.N; i++ {
require.Equal(t, polyT.Coeffs[0][i], coeffsWant[i].Uint64())
}
})
diff --git a/ring/utils.go b/ring/utils.go
index fad2e4a1..f382f8c1 100644
--- a/ring/utils.go
+++ b/ring/utils.go
@@ -14,7 +14,7 @@ func Min(x, y int) int {
}
// PowerOf2 returns (x*2^n)%q where x is in Montgomery form
-func PowerOf2(x, n, q, qInv uint64) (r uint64) {
+func PowerOf2(x uint64, n int, q, qInv uint64) (r uint64) {
ahi, alo := x>>(64-n), x< 0; i >>= 1 {
@@ -41,7 +41,7 @@ func ModExp(x, e, p uint64) (result uint64) {
// modexpMontgomery performs the modular exponentiation x^e mod p,
// where x is in Montgomery form, and returns x^e in Montgomery form.
-func modexpMontgomery(x, e, q, qInv uint64, bredParams []uint64) (result uint64) {
+func modexpMontgomery(x uint64, e int, q, qInv uint64, bredParams []uint64) (result uint64) {
result = MForm(1, q, bredParams)
@@ -80,7 +80,7 @@ func primitiveRoot(q uint64) (g uint64) {
for _, factor := range factors {
tmp = (q - 1) / factor
// if for any factor of q-1, g^(q-1)/factor = 1 mod q, g is not a primitive root
- if ModExp(g, tmp, q) == 1 {
+ if ModExp(g, int(tmp), q) == 1 {
notFoundPrimitiveRoot = true
break
}
diff --git a/rlwe/keys.go b/rlwe/keys.go
index 58e7ffe9..a65a4b5e 100644
--- a/rlwe/keys.go
+++ b/rlwe/keys.go
@@ -35,7 +35,7 @@ type RotationKeySet struct {
}
// NewSecretKey generates a new SecretKey with zero values.
-func NewSecretKey(ringDegree, moduliCount uint64) *SecretKey {
+func NewSecretKey(ringDegree, moduliCount int) *SecretKey {
sk := new(SecretKey)
sk.Value = ring.NewPoly(ringDegree, moduliCount)
@@ -43,12 +43,12 @@ func NewSecretKey(ringDegree, moduliCount uint64) *SecretKey {
}
// NewPublicKey returns a new PublicKey with zero values.
-func NewPublicKey(ringDegree, moduliCount uint64) (pk *PublicKey) {
+func NewPublicKey(ringDegree, moduliCount int) (pk *PublicKey) {
return &PublicKey{Value: [2]*ring.Poly{ring.NewPoly(ringDegree, moduliCount), ring.NewPoly(ringDegree, moduliCount)}}
}
// NewRotationKeySet returns a new RotationKeySet with pre-allocated switching keys for each distinct galoisElement value.
-func NewRotationKeySet(galoisElement []uint64, ringDegree, moduliCount, decompSize uint64) (rotKey *RotationKeySet) {
+func NewRotationKeySet(galoisElement []uint64, ringDegree, moduliCount, decompSize int) (rotKey *RotationKeySet) {
rotKey = new(RotationKeySet)
rotKey.Keys = make(map[uint64]*SwitchingKey, len(galoisElement))
for _, galEl := range galoisElement {
@@ -65,13 +65,13 @@ func (rtks *RotationKeySet) GetRotationKey(galoisEl uint64) (*SwitchingKey, bool
}
// NewSwitchingKey returns a new public switching key with pre-allocated zero-value
-func NewSwitchingKey(ringDegree, moduliCount, decompSize uint64) *SwitchingKey {
+func NewSwitchingKey(ringDegree, moduliCount, decompSize int) *SwitchingKey {
swk := new(SwitchingKey)
swk.Value = make([][2]*ring.Poly, int(decompSize))
- for i := uint64(0); i < decompSize; i++ {
+ for i := 0; i < decompSize; i++ {
swk.Value[i][0] = ring.NewPoly(ringDegree, moduliCount)
swk.Value[i][1] = ring.NewPoly(ringDegree, moduliCount)
}
@@ -80,13 +80,13 @@ func NewSwitchingKey(ringDegree, moduliCount, decompSize uint64) *SwitchingKey {
}
// NewRelinKey creates a new EvaluationKey with zero values.
-func NewRelinKey(maxRelinDegree, ringDegree, moduliCount, decompSize uint64) (evakey *RelinearizationKey) {
+func NewRelinKey(maxRelinDegree, ringDegree, moduliCount, decompSize int) (evakey *RelinearizationKey) {
evakey = new(RelinearizationKey)
evakey.Keys = make([]*SwitchingKey, maxRelinDegree)
- for d := uint64(0); d < maxRelinDegree; d++ {
+ for d := 0; d < maxRelinDegree; d++ {
evakey.Keys[d] = NewSwitchingKey(ringDegree, moduliCount, decompSize)
}
@@ -94,7 +94,7 @@ func NewRelinKey(maxRelinDegree, ringDegree, moduliCount, decompSize uint64) (ev
}
// GetDataLen returns the length in bytes of the target SecretKey.
-func (sk *SecretKey) GetDataLen(WithMetadata bool) (dataLen uint64) {
+func (sk *SecretKey) GetDataLen(WithMetadata bool) (dataLen int) {
return sk.Value.GetDataLen(WithMetadata)
}
@@ -123,7 +123,7 @@ func (sk *SecretKey) UnmarshalBinary(data []byte) (err error) {
}
// GetDataLen returns the length in bytes of the target PublicKey.
-func (pk *PublicKey) GetDataLen(WithMetadata bool) (dataLen uint64) {
+func (pk *PublicKey) GetDataLen(WithMetadata bool) (dataLen int) {
for _, el := range pk.Value {
dataLen += el.GetDataLen(WithMetadata)
@@ -139,7 +139,7 @@ func (pk *PublicKey) MarshalBinary() (data []byte, err error) {
data = make([]byte, dataLen)
- var pointer, inc uint64
+ var pointer, inc int
if inc, err = pk.Value[0].WriteTo(data[pointer:]); err != nil {
return nil, err
@@ -156,7 +156,7 @@ func (pk *PublicKey) MarshalBinary() (data []byte, err error) {
// UnmarshalBinary decodes a previously marshaled PublicKey in the target PublicKey.
func (pk *PublicKey) UnmarshalBinary(data []byte) (err error) {
- var pointer, inc uint64
+ var pointer, inc int
pk.Value[0] = new(ring.Poly)
pk.Value[1] = new(ring.Poly)
@@ -173,7 +173,7 @@ func (pk *PublicKey) UnmarshalBinary(data []byte) (err error) {
}
// GetDataLen returns the length in bytes of the target EvaluationKey.
-func (evaluationkey *RelinearizationKey) GetDataLen(WithMetadata bool) (dataLen uint64) {
+func (evaluationkey *RelinearizationKey) GetDataLen(WithMetadata bool) (dataLen int) {
if WithMetadata {
dataLen++
@@ -189,7 +189,7 @@ func (evaluationkey *RelinearizationKey) GetDataLen(WithMetadata bool) (dataLen
// MarshalBinary encodes an EvaluationKey key in a byte slice.
func (evaluationkey *RelinearizationKey) MarshalBinary() (data []byte, err error) {
- var pointer uint64
+ var pointer int
dataLen := evaluationkey.GetDataLen(true)
@@ -212,13 +212,13 @@ func (evaluationkey *RelinearizationKey) MarshalBinary() (data []byte, err error
// UnmarshalBinary decodes a previously marshaled EvaluationKey in the target EvaluationKey.
func (evaluationkey *RelinearizationKey) UnmarshalBinary(data []byte) (err error) {
- deg := uint64(data[0])
+ deg := int(data[0])
evaluationkey.Keys = make([]*SwitchingKey, deg)
- pointer := uint64(1)
- var inc uint64
- for i := uint64(0); i < deg; i++ {
+ pointer := int(1)
+ var inc int
+ for i := 0; i < deg; i++ {
evaluationkey.Keys[i] = new(SwitchingKey)
if inc, err = evaluationkey.Keys[i].decode(data[pointer:]); err != nil {
return err
@@ -230,7 +230,7 @@ func (evaluationkey *RelinearizationKey) UnmarshalBinary(data []byte) (err error
}
// GetDataLen returns the length in bytes of the target SwitchingKey.
-func (switchkey *SwitchingKey) GetDataLen(WithMetadata bool) (dataLen uint64) {
+func (switchkey *SwitchingKey) GetDataLen(WithMetadata bool) (dataLen int) {
if WithMetadata {
dataLen++
@@ -266,17 +266,16 @@ func (switchkey *SwitchingKey) UnmarshalBinary(data []byte) (err error) {
return nil
}
-func (switchkey *SwitchingKey) encode(pointer uint64, data []byte) (uint64, error) {
+func (switchkey *SwitchingKey) encode(pointer int, data []byte) (int, error) {
var err error
-
- var inc uint64
+ var inc int
data[pointer] = uint8(len(switchkey.Value))
pointer++
- for j := uint64(0); j < uint64(len(switchkey.Value)); j++ {
+ for j := 0; j < len(switchkey.Value); j++ {
if inc, err = switchkey.Value[j][0].WriteTo(data[pointer : pointer+switchkey.Value[j][0].GetDataLen(true)]); err != nil {
return pointer, err
@@ -294,17 +293,17 @@ func (switchkey *SwitchingKey) encode(pointer uint64, data []byte) (uint64, erro
return pointer, nil
}
-func (switchkey *SwitchingKey) decode(data []byte) (pointer uint64, err error) {
+func (switchkey *SwitchingKey) decode(data []byte) (pointer int, err error) {
- decomposition := uint64(data[0])
+ decomposition := int(data[0])
- pointer = uint64(1)
+ pointer = 1
switchkey.Value = make([][2]*ring.Poly, decomposition)
- var inc uint64
+ var inc int
- for j := uint64(0); j < decomposition; j++ {
+ for j := 0; j < decomposition; j++ {
switchkey.Value[j][0] = new(ring.Poly)
if inc, err = switchkey.Value[j][0].DecodePolyNew(data[pointer:]); err != nil {
@@ -324,7 +323,7 @@ func (switchkey *SwitchingKey) decode(data []byte) (pointer uint64, err error) {
}
// GetDataLen returns the length in bytes of the target RotationKeys.
-func (rtks *RotationKeySet) GetDataLen(WithMetaData bool) (dataLen uint64) {
+func (rtks *RotationKeySet) GetDataLen(WithMetaData bool) (dataLen int) {
for _, k := range rtks.Keys {
if WithMetaData {
dataLen += 4
@@ -339,7 +338,7 @@ func (rtks *RotationKeySet) MarshalBinary() (data []byte, err error) {
data = make([]byte, rtks.GetDataLen(true))
- pointer := uint64(0)
+ pointer := int(0)
for galEL, key := range rtks.Keys {
@@ -365,7 +364,7 @@ func (rtks *RotationKeySet) UnmarshalBinary(data []byte) (err error) {
data = data[4:]
swk := new(SwitchingKey)
- var inc uint64
+ var inc int
if inc, err = swk.decode(data); err != nil {
return err
}
diff --git a/utils/utils.go b/utils/utils.go
index ed8aa0a0..721e1205 100644
--- a/utils/utils.go
+++ b/utils/utils.go
@@ -73,7 +73,7 @@ func IsInSliceInt(x int, slice []int) (v bool) {
return
}
-// MinUint64 returns the minimum value of the input slice of uint64 values.
+// MinUint64 returns the minimum value of the input of uint64 values.
func MinUint64(a, b uint64) (r uint64) {
if a <= b {
return a
@@ -81,7 +81,15 @@ func MinUint64(a, b uint64) (r uint64) {
return b
}
-// MaxUint64 returns the maximum value of the input slice of uint64 values.
+// MinInt returns the minimum value of the input of int values.
+func MinInt(a, b int) (r int) {
+ if a <= b {
+ return a
+ }
+ return b
+}
+
+// MaxUint64 returns the maximum value of the input of uint64 values.
func MaxUint64(a, b uint64) (r uint64) {
if a >= b {
return a
@@ -89,7 +97,15 @@ func MaxUint64(a, b uint64) (r uint64) {
return b
}
-// MaxFloat64 returns the maximum value of the input slice of uint64 values.
+// MaxInt returns the maximum value of the input of int values.
+func MaxInt(a, b int) (r int) {
+ if a >= b {
+ return a
+ }
+ return b
+}
+
+// MaxFloat64 returns the maximum value of the input slice of float64 values.
func MaxFloat64(a, b float64) (r float64) {
if a >= b {
return a
@@ -97,6 +113,14 @@ func MaxFloat64(a, b float64) (r float64) {
return b
}
+// MaxSliceUint64 returns the maximum value of the input slice of uint64 values.
+func MaxSliceUint64(slice []uint64) (max uint64) {
+ for i := range slice {
+ max = MaxUint64(max, slice[i])
+ }
+ return
+}
+
// BitReverse64 returns the bit-reverse value of the input value, within a context of 2^bitLen.
func BitReverse64(index, bitLen uint64) uint64 {
return bits.Reverse64(index) >> (64 - bitLen)