Browse Source

Clean up the hardware acceleration plumbing a bit.

This is easier to read, maintain, and extend.
Yawning Angel 1 year ago
parent
commit
aa1972d8c9
7 changed files with 33 additions and 22 deletions
  1. 12 6
      hwaccel.go
  2. 9 6
      hwaccel_amd64.go
  3. 2 2
      kem_test.go
  4. 1 1
      kem_vectors_test.go
  5. 1 1
      kex_test.go
  6. 2 2
      poly.go
  7. 6 4
      polyvec.go

+ 12 - 6
hwaccel.go

@@ -7,23 +7,29 @@
 
 package kyber
 
-const implReference = "Reference"
-
 var (
 	isHardwareAccelerated = false
 	hardwareAccelImpl     = implReference
 
-	nttFn    = nttRef
-	invnttFn = invnttRef
+	implReference = &hwaccelImpl{
+		name:     "Reference",
+		nttFn:    nttRef,
+		invnttFn: invnttRef,
+	}
 )
 
+type hwaccelImpl struct {
+	name                   string
+	nttFn                  func(*[kyberN]uint16)
+	invnttFn               func(*[kyberN]uint16)
+	pointwiseAccMustFreeze bool
+}
+
 func forceDisableHardwareAcceleration() {
 	// This is for the benefit of testing, so that it's possible to test
 	// all versions that are supported by the host.
 	isHardwareAccelerated = false
 	hardwareAccelImpl = implReference
-	nttFn = nttRef
-	invnttFn = invnttRef
 }
 
 // IsHardwareAccelerated returns true iff the Kyber implementation will use

+ 9 - 6
hwaccel_amd64.go

@@ -9,8 +9,6 @@
 
 package kyber
 
-const implAVX2 = "AVX2"
-
 var zetasExp = [752]uint16{
 	3777, 3777, 3777, 3777, 3777, 3777, 3777, 3777, 3777, 3777, 3777, 3777,
 	3777, 3777, 3777, 3777, 4499, 4499, 4499, 4499, 4499, 4499, 4499, 4499,
@@ -190,19 +188,24 @@ func supportsAVX2() bool {
 	return regs[1]&avx2Bit != 0
 }
 
+var implAVX2 = &hwaccelImpl{
+	name:                   "AVX2",
+	nttFn:                  nttOpt,
+	invnttFn:               invnttOpt,
+	pointwiseAccMustFreeze: true,
+}
+
 func nttOpt(p *[kyberN]uint16) {
 	nttAVX2(&p[0], &zetasExp[0])
 }
 
-func invnttOpt(p *[kyberN]uint16) {
-	invnttAVX2(&p[0], &zetasInvExp[0])
+func invnttOpt(a *[kyberN]uint16) {
+	invnttAVX2(&a[0], &zetasInvExp[0])
 }
 
 func initHardwareAcceleration() {
 	if supportsAVX2() {
 		isHardwareAccelerated = true
 		hardwareAccelImpl = implAVX2
-		nttFn = nttOpt
-		invnttFn = invnttOpt
 	}
 }

+ 2 - 2
kem_test.go

@@ -47,7 +47,7 @@ func TestKEM(t *testing.T) {
 }
 
 func doTestKEM(t *testing.T) {
-	impl := "_" + hardwareAccelImpl
+	impl := "_" + hardwareAccelImpl.name
 	for _, p := range allParams {
 		t.Run(p.Name()+"_Keys"+impl, func(t *testing.T) { doTestKEMKeys(t, p) })
 		t.Run(p.Name()+"_Invalid_SecretKey_A"+impl, func(t *testing.T) { doTestKEMInvalidSkA(t, p) })
@@ -168,7 +168,7 @@ func BenchmarkKEM(b *testing.B) {
 }
 
 func doBenchmarkKEM(b *testing.B) {
-	impl := "_" + hardwareAccelImpl
+	impl := "_" + hardwareAccelImpl.name
 	for _, p := range allParams {
 		b.Run(p.Name()+"_GenerateKeyPair"+impl, func(b *testing.B) { doBenchKEMGenerateKeyPair(b, p) })
 		b.Run(p.Name()+"_KEMEncrypt"+impl, func(b *testing.B) { doBenchKEMEncDec(b, p, true) })

+ 1 - 1
kem_vectors_test.go

@@ -42,7 +42,7 @@ func TestKEMVectors(t *testing.T) {
 }
 
 func doTestKEMVectors(t *testing.T) {
-	impl := "_" + hardwareAccelImpl
+	impl := "_" + hardwareAccelImpl.name
 	for _, p := range allParams {
 		t.Run(p.Name()+impl, func(t *testing.T) { doTestKEMVectorsPick(t, p) })
 	}

+ 1 - 1
kex_test.go

@@ -27,7 +27,7 @@ func TestAKE(t *testing.T) {
 }
 
 func doTestKEX(t *testing.T) {
-	impl := "_" + hardwareAccelImpl
+	impl := "_" + hardwareAccelImpl.name
 	for _, p := range allParams {
 		t.Run(p.Name()+"_UAKE"+impl, func(t *testing.T) { doTestUAKE(t, p) })
 		t.Run(p.Name()+"_AKE"+impl, func(t *testing.T) { doTestAKE(t, p) })

+ 2 - 2
poly.go

@@ -121,14 +121,14 @@ func (p *poly) getNoise(seed []byte, nonce byte, eta int) {
 // Computes negacyclic number-theoretic transform (NTT) of a polynomial in
 // place; inputs assumed to be in normal order, output in bitreversed order.
 func (p *poly) ntt() {
-	nttFn(&p.coeffs)
+	hardwareAccelImpl.nttFn(&p.coeffs)
 }
 
 // Computes inverse of negacyclic number-theoretic transform (NTT) of a
 // polynomial in place; inputs assumed to be in bitreversed order, output in
 // normal order.
 func (p *poly) invntt() {
-	invnttFn(&p.coeffs)
+	hardwareAccelImpl.invnttFn(&p.coeffs)
 }
 
 // Add two polynomials.

+ 6 - 4
polyvec.go

@@ -94,11 +94,13 @@ func (p *poly) pointwiseAcc(a, b *polyVec) {
 			p.coeffs[j] += montgomeryReduce(uint32(a.vec[i].coeffs[j]) * uint32(t))
 		}
 
-		// HACK HACK HACK:
+		// The AVX2 inverse-NTT implementation (and possibly others in the
+		// future) requires fully reduced inputs, while the reference code
+		// does not.
 		//
-		// The AVX2 code assumes fully reduced coefficients.  Since it's
-		// the only acceleration target right now, just do this here.
-		if isHardwareAccelerated {
+		// Do the right thing based on the current implementation.  Eventually
+		// the AVX2 code will have its own implementation(s) of this routine.
+		if hardwareAccelImpl.pointwiseAccMustFreeze {
 			p.coeffs[j] = freeze(p.coeffs[j])
 		} else {
 			p.coeffs[j] = barrettReduce(p.coeffs[j])