From ca54f2bba4300bda5936ed8a543eb4e4403046c2 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 29 Nov 2019 15:42:33 +0100 Subject: [PATCH] BFV/CKKS : public gen API for DBFV and DCKKS --- bfv/bfv.go | 4 +-- bfv/encoder.go | 2 +- bfv/evaluator.go | 2 +- bfv/utils.go | 51 +++------------------------ ckks/ckks.go | 4 +-- ckks/utils.go | 89 +++++++++++++----------------------------------- ring/utils.go | 64 +++++++++++++++++----------------- 7 files changed, 64 insertions(+), 152 deletions(-) diff --git a/bfv/bfv.go b/bfv/bfv.go index 2121c167..4292b1c1 100644 --- a/bfv/bfv.go +++ b/bfv/bfv.go @@ -81,7 +81,7 @@ func (context *Context) SetParameters(params *Parameters) { N := uint64(1 << LogN) t := params.T - ModuliQ1, ModuliP, ModuliQ2 := genModuli(params) + ModuliQ1, ModuliP, ModuliQ2 := GenModuli(params) sigma := params.Sigma context.n = N @@ -110,7 +110,7 @@ func (context *Context) SetParameters(params *Parameters) { context.alpha = uint64(len(ModuliP)) context.beta = uint64(math.Ceil(float64(len(ModuliQ1)) / float64(context.alpha))) - context.rescaleParamsKeys = genRescalingParams(context.contextP, context.contextQ1) + context.rescaleParamsKeys = GenRescalingParams(context.contextP, context.contextQ1) context.logQ = uint64(context.contextQ1P.ModulusBigint.BitLen()) diff --git a/bfv/encoder.go b/bfv/encoder.go index 5a87e117..680d0feb 100644 --- a/bfv/encoder.go +++ b/bfv/encoder.go @@ -48,7 +48,7 @@ func NewEncoder(params *Parameters) (encoder *Encoder) { pos &= (m - 1) } - encoder.deltaMont = genLiftParams(encoder.bfvContext.contextQ1, encoder.bfvContext.t) + encoder.deltaMont = GenLiftParams(encoder.bfvContext.contextQ1, encoder.bfvContext.t) encoder.simplescaler = ring.NewSimpleScaler(encoder.bfvContext.t, encoder.bfvContext.contextQ1) encoder.polypool = encoder.bfvContext.contextT.NewPoly() diff --git a/bfv/evaluator.go b/bfv/evaluator.go index 394c867d..49339680 100644 --- a/bfv/evaluator.go +++ b/bfv/evaluator.go @@ -44,7 +44,7 @@ func NewEvaluator(params *Parameters) (evaluator *Evaluator) { evaluator.baseconverter = ring.NewFastBasisExtender(evaluator.bfvContext.contextQ1.Modulus, evaluator.bfvContext.contextP.Modulus) evaluator.decomposer = ring.NewArbitraryDecomposer(evaluator.bfvContext.contextQ1.Modulus, evaluator.bfvContext.contextP.Modulus) - evaluator.rescaleParamsMul = genRescalingParams(evaluator.bfvContext.contextQ1, evaluator.bfvContext.contextQ2) + evaluator.rescaleParamsMul = GenRescalingParams(evaluator.bfvContext.contextQ1, evaluator.bfvContext.contextQ2) evaluator.pHalf = new(big.Int).Rsh(evaluator.bfvContext.contextQ2.ModulusBigint, 1) for i := 0; i < 2; i++ { diff --git a/bfv/utils.go b/bfv/utils.go index 14d85eb2..76e60a05 100644 --- a/bfv/utils.go +++ b/bfv/utils.go @@ -5,7 +5,7 @@ import ( "math/big" ) -func genLiftParams(context *ring.Context, t uint64) (deltaMont []uint64) { +func GenLiftParams(context *ring.Context, t uint64) (deltaMont []uint64) { delta := new(big.Int).Quo(context.ModulusBigint, ring.NewUint(t)) @@ -21,7 +21,7 @@ func genLiftParams(context *ring.Context, t uint64) (deltaMont []uint64) { return } -func genRescalingParams(contextQ1, contextQ2 *ring.Context) (params []uint64) { +func GenRescalingParams(contextQ1, contextQ2 *ring.Context) (params []uint64) { params = make([]uint64, len(contextQ2.Modulus)) @@ -38,7 +38,7 @@ func genRescalingParams(contextQ1, contextQ2 *ring.Context) (params []uint64) { } // genModuli generates the appropriate primes from the parameters using generateCKKSPrimes such that all primes are different. -func genModuli(params *Parameters) (Q1 []uint64, P []uint64, Q2 []uint64) { +func GenModuli(params *Parameters) (Q1 []uint64, P []uint64, Q2 []uint64) { // Extracts all the different primes bit size and maps their number primesbitlen := make(map[uint64]uint64) @@ -71,7 +71,7 @@ func genModuli(params *Parameters) (Q1 []uint64, P []uint64, Q2 []uint64) { // For each bitsize, finds that many primes primes := make(map[uint64][]uint64) for key, value := range primesbitlen { - primes[key] = generateNTTPrimes(key, uint64(params.LogN), value) + primes[key] = ring.GenerateNTTPrimes(key, uint64(params.LogN), value) } // Assigns the primes to the ckks moduli chain @@ -97,47 +97,4 @@ func genModuli(params *Parameters) (Q1 []uint64, P []uint64, Q2 []uint64) { return Q1, P, Q2 } -func generateNTTPrimes(logQ, logN, levels uint64) (primes []uint64) { - // generateCKKSPrimes generates primes given logQ = size of the primes, logN = size of N and level, the number - // of levels required. Will return all the appropriate primes, up to the number of level, with the - // best avaliable deviation from the base power of 2 for the given level. - - if logQ > 60 { - panic("logQ must be between 1 and 60") - } - - var x, y, Qpow2, _2N uint64 - - primes = []uint64{} - - Qpow2 = 1 << logQ - - _2N = 2 << logN - - x = Qpow2 + 1 - y = Qpow2 + 1 - - for true { - - if ring.IsPrime(y) { - primes = append(primes, y) - if uint64(len(primes)) == levels { - return primes - } - } - - y -= _2N - - if ring.IsPrime(x) { - primes = append(primes, x) - if uint64(len(primes)) == levels { - return primes - } - } - - x += _2N - } - - return -} diff --git a/ckks/ckks.go b/ckks/ckks.go index 709a0810..391bebca 100644 --- a/ckks/ckks.go +++ b/ckks/ckks.go @@ -70,7 +70,7 @@ func NewContext(params *Parameters) (ckkscontext *Context) { ckkscontext.alpha = uint64(len(params.P)) ckkscontext.beta = uint64(math.Ceil(float64(ckkscontext.levels) / float64(ckkscontext.alpha))) - ckkscontext.moduli, ckkscontext.specialprimes = genModuli(params) + ckkscontext.moduli, ckkscontext.specialprimes = GenModuli(params) ckkscontext.bigintChain = genBigIntChain(ckkscontext.moduli) @@ -88,7 +88,7 @@ func NewContext(params *Parameters) (ckkscontext *Context) { ckkscontext.logQ = uint64(ckkscontext.contextKeys.ModulusBigint.BitLen()) - ckkscontext.rescaleParamsKeys = genSwitchkeysRescalingParams(ckkscontext.moduli, ckkscontext.specialprimes) + ckkscontext.rescaleParamsKeys = GenSwitchkeysRescalingParams(ckkscontext.moduli, ckkscontext.specialprimes) ckkscontext.gaussianSampler = ckkscontext.contextKeys.NewKYSampler(params.Sigma, int(6*params.Sigma)) diff --git a/ckks/utils.go b/ckks/utils.go index f71cd1ca..65e1a93c 100644 --- a/ckks/utils.go +++ b/ckks/utils.go @@ -120,8 +120,29 @@ func genBigIntChain(Q []uint64) (bigintChain []*big.Int) { return } +func GenSwitchkeysRescalingParams(Q, P []uint64) (params []uint64) { + + params = make([]uint64, len(Q)) + + PBig := ring.NewUint(1) + for _, pj := range P { + PBig.Mul(PBig, ring.NewUint(pj)) + } + + tmp := ring.NewUint(0) + + for i := 0; i < len(Q); i++ { + + params[i] = tmp.Mod(PBig, ring.NewUint(Q[i])).Uint64() + params[i] = ring.ModExp(params[i], Q[i]-2, Q[i]) + params[i] = ring.MForm(params[i], Q[i], ring.BRedParams(Q[i])) + } + + return +} + // genModuli generates the appropriate primes from the parameters using generateCKKSPrimes such that all primes are different. -func genModuli(params *Parameters) (Q []uint64, P []uint64) { +func GenModuli(params *Parameters) (Q []uint64, P []uint64) { // Extracts all the different primes bit size and maps their number primesbitlen := make(map[uint64]uint64) @@ -145,7 +166,7 @@ func genModuli(params *Parameters) (Q []uint64, P []uint64) { // For each bitsize, finds that many primes primes := make(map[uint64][]uint64) for key, value := range primesbitlen { - primes[key] = generateCKKSPrimes(key, uint64(params.LogN), value) + primes[key] = ring.GenerateNTTPrimes(key, uint64(params.LogN), value) } // Assigns the primes to the ckks moduli chain @@ -165,51 +186,6 @@ func genModuli(params *Parameters) (Q []uint64, P []uint64) { return Q, P } -func generateCKKSPrimes(logQ, logN, levels uint64) (primes []uint64) { - - // generateCKKSPrimes generates primes given logQ = size of the primes, logN = size of N and level, the number - // of levels required. Will return all the appropriate primes, up to the number of level, with the - // best avaliable deviation from the base power of 2 for the given level. - - if logQ > 60 { - panic("logQ must be between 1 and 60") - } - - var x, y, Qpow2, _2N uint64 - - primes = []uint64{} - - Qpow2 = 1 << logQ - - _2N = 2 << logN - - x = Qpow2 + 1 - y = Qpow2 + 1 - - for true { - - if ring.IsPrime(y) { - primes = append(primes, y) - if uint64(len(primes)) == levels { - return primes - } - } - - y -= _2N - - if ring.IsPrime(x) { - primes = append(primes, x) - if uint64(len(primes)) == levels { - return primes - } - } - - x += _2N - } - - return -} - func sliceBitReverseInPlaceComplex128(slice []complex128, N uint64) { var bit, j uint64 @@ -231,23 +207,4 @@ func sliceBitReverseInPlaceComplex128(slice []complex128, N uint64) { } } -func genSwitchkeysRescalingParams(Q, P []uint64) (params []uint64) { - params = make([]uint64, len(Q)) - - PBig := ring.NewUint(1) - for _, pj := range P { - PBig.Mul(PBig, ring.NewUint(pj)) - } - - tmp := ring.NewUint(0) - - for i := 0; i < len(Q); i++ { - - params[i] = tmp.Mod(PBig, ring.NewUint(Q[i])).Uint64() - params[i] = ring.ModExp(params[i], Q[i]-2, Q[i]) - params[i] = ring.MForm(params[i], Q[i], ring.BRedParams(Q[i])) - } - - return -} diff --git a/ring/utils.go b/ring/utils.go index d17981ff..d57e5aea 100644 --- a/ring/utils.go +++ b/ring/utils.go @@ -1,7 +1,6 @@ package ring import ( - "errors" "math/bits" ) @@ -121,49 +120,48 @@ func IsPrime(num uint64) bool { return true } -// GenerateNTTPrimes generates "n" primes of bitlen "bitLen", suited for NTT with "N", -// starting from the integer "start" (which must be 1 mod 2N) and increasing (true) / decreasing (false) order -func GenerateNTTPrimes(N, start, n, bitLen uint64, sign bool) ([]uint64, error) { - var x, v uint64 +// GenerateCKKSPrimes generates primes given logQ = size of the primes, logN = size of N and level, the number +// of levels required. Will return all the appropriate primes, up to the number of level, with the +// best avaliable deviation from the base power of 2 for the given level. +func GenerateNTTPrimes(logQ, logN, levels uint64) (primes []uint64) { - if uint64(bits.Len64(start)) != bitLen { - return nil, errors.New("error : start != bitLen") + if logQ > 60 { + panic("logQ must be between 1 and 60") } - v = N << 1 - if start != 0 { - if start&((N<<1)-1) != 1 { - return nil, errors.New("error : start != 1 mod 2*N") + var x, y, Qpow2, _2N uint64 + + primes = []uint64{} + + Qpow2 = 1 << logQ + + _2N = 2 << logN + + x = Qpow2 + 1 + y = Qpow2 + 1 + + for true { + + if IsPrime(y) { + primes = append(primes, y) + if uint64(len(primes)) == levels { + return primes + } } - x = start - } else { - x = v<<(bitLen-uint64(bits.Len64(v))) + 1 - } - primes := make([]uint64, n) - - i := uint64(0) - - for i < n { - - // x gets out of the bitLen bound - if uint64(bits.Len64(x)) != bitLen { - return primes, nil - } + y -= _2N if IsPrime(x) { - primes[i] = x - i++ + primes = append(primes, x) + if uint64(len(primes)) == levels { + return primes + } } - if sign { - x += v - } else { - x -= v - } + x += _2N } - return primes, nil + return } //===========================