
AVX2: Merge the YMM code paths down into 2 functions.

And a host of other improvements.
Yawning Angel, 7 months ago
commit 50712dfd3b
2 changed files with 180 additions and 251 deletions:
  1. hwaccel_amd64.go (+4, -103)
  2. hwaccel_amd64.s (+176, -148)

hwaccel_amd64.go (+4, -103)

@@ -18,22 +18,10 @@ func cpuidAmd64(cpuidParams *uint32)
 func xgetbv0Amd64(xcrVec *uint32)
 
 //go:noescape
-func initAVX2(s *uint64, key, iv *byte)
+func aeadEncryptAVX2(c, m, a []byte, nonce, key *byte)
 
 //go:noescape
-func absorbBlocksAVX2(s *uint64, in *byte, blocks uint64)
-
-//go:noescape
-func encryptBlocksAVX2(s *uint64, out, in *byte, blocks uint64)
-
-//go:noescape
-func decryptBlocksAVX2(s *uint64, out, in *byte, blocks uint64)
-
-//go:noescape
-func decryptLastBlockAVX2(s *uint64, out, in *byte, inLen uint64)
-
-//go:noescape
-func finalizeAVX2(s *uint64, tag *byte, lastBlock *uint64)
+func aeadDecryptAVX2(m, c, a []byte, nonce, key, tag *byte)
 
 func supportsAVX2() bool {
 	// https://software.intel.com/en-us/articles/how-to-detect-new-instruction-support-in-the-4th-generation-intel-core-processor-family
@@ -70,96 +58,15 @@ func supportsAVX2() bool {
 	return regs[1]&avx2Bit != 0
 }
 
-type ymmState struct {
-	s [20]uint64
-}
-
-func (s *ymmState) init(key, iv []byte) {
-	initAVX2(&s.s[0], &key[0], &iv[0])
-}
-
-func (s *ymmState) absorbData(in []byte) {
-	inLen, off := len(in), 0
-	if inLen == 0 {
-		return
-	}
-
-	if inBlocks := inLen / blockSize; inBlocks > 0 {
-		absorbBlocksAVX2(&s.s[0], &in[0], uint64(inBlocks))
-		off += inBlocks * blockSize
-	}
-	in = in[off:]
-
-	if len(in) > 0 {
-		var tmp [blockSize]byte
-		copy(tmp[:], in)
-		absorbBlocksAVX2(&s.s[0], &tmp[0], 1)
-	}
-}
-
-func (s *ymmState) encryptData(out, in []byte) {
-	inLen, off := len(in), 0
-	if inLen == 0 {
-		return
-	}
-
-	if inBlocks := inLen / blockSize; inBlocks > 0 {
-		encryptBlocksAVX2(&s.s[0], &out[0], &in[0], uint64(inBlocks))
-		off += inBlocks * blockSize
-	}
-	out, in = out[off:], in[off:]
-
-	if len(in) > 0 {
-		var tmp [blockSize]byte
-		copy(tmp[:], in)
-		encryptBlocksAVX2(&s.s[0], &tmp[0], &tmp[0], 1)
-		copy(out, tmp[:])
-	}
-}
-
-func (s *ymmState) decryptData(out, in []byte) {
-	inLen, off := len(in), 0
-	if inLen == 0 {
-		return
-	}
-
-	if inBlocks := inLen / blockSize; inBlocks > 0 {
-		decryptBlocksAVX2(&s.s[0], &out[0], &in[0], uint64(inBlocks))
-		off += inBlocks * blockSize
-	}
-	out, in = out[off:], in[off:]
-
-	if len(in) > 0 {
-		var tmp [blockSize]byte
-		copy(tmp[:], in)
-		decryptLastBlockAVX2(&s.s[0], &tmp[0], &tmp[0], uint64(len(in)))
-		copy(out, tmp[:])
-	}
-}
-
-func (s *ymmState) finalize(msgLen, adLen uint64, tag []byte) {
-	var lastBlock = [4]uint64{adLen << 3, msgLen << 3, 0, 0}
-	finalizeAVX2(&s.s[0], &tag[0], &lastBlock[0])
-}
-
 func aeadEncryptYMM(c, m, a, nonce, key []byte) []byte {
-	var s ymmState
 	mLen := len(m)
-
 	ret, out := sliceForAppend(c, mLen+TagSize)
-
-	s.init(key, nonce)
-	s.absorbData(a)
-	s.encryptData(out, m)
-	s.finalize(uint64(mLen), uint64(len(a)), out[mLen:])
-
-	burnUint64s(s.s[:])
+	aeadEncryptAVX2(out, m, a, &nonce[0], &key[0])
 
 	return ret
 }
 
 func aeadDecryptYMM(m, c, a, nonce, key []byte) ([]byte, bool) {
-	var s ymmState
 	var tag [TagSize]byte
 	cLen := len(c)
 
@@ -169,11 +76,7 @@ func aeadDecryptYMM(m, c, a, nonce, key []byte) ([]byte, bool) {
 
 	mLen := cLen - TagSize
 	ret, out := sliceForAppend(m, mLen)
-
-	s.init(key, nonce)
-	s.absorbData(a)
-	s.decryptData(out, c[:mLen])
-	s.finalize(uint64(mLen), uint64(len(a)), tag[:])
+	aeadDecryptAVX2(out, c[:mLen], a, &nonce[0], &key[0], &tag[0])
 
 	srcTag := c[mLen:]
 	ok := subtle.ConstantTimeCompare(srcTag, tag[:]) == 1
@@ -183,8 +86,6 @@ func aeadDecryptYMM(m, c, a, nonce, key []byte) ([]byte, bool) {
 		ret = nil
 	}
 
-	burnUint64s(s.s[:])
-
 	return ret, ok
 }
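
For cross-reference, a minimal sketch of how the merged entry points are used. aeadEncryptYMM, aeadDecryptYMM, and TagSize are taken from the diff above; the 32-byte key and 16-byte nonce sizes are inferred from the assembly (VMOVDQU (KEY), S1 and MOVOU (IV), X0), and the harness itself is illustrative, not part of the commit.

	// Hypothetical caller of the merged YMM entry points.
	key := make([]byte, 32)   // consumed by INIT_STATE via VMOVDQU (KEY), S1
	nonce := make([]byte, 16) // consumed by INIT_STATE via MOVOU (IV), X0
	msg := []byte("message")
	ad := []byte("additional data")

	// Seal: the TagSize-byte tag is appended to the ciphertext.
	ct := aeadEncryptYMM(nil, msg, ad, nonce, key)

	// Open: ok reports whether the constant-time tag compare passed;
	// on failure the returned plaintext is nil.
	pt, ok := aeadDecryptYMM(nil, ct, ad, nonce, key)
	if !ok {
		// authentication failed; pt must not be used
	}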
 

hwaccel_amd64.s (+176, -148)

@@ -33,9 +33,7 @@ TEXT ·xgetbv0Amd64(SB), NOSPLIT, $0-8
 // function, along with aliases for the registers used for readability.
 
 // YMM Registers: Sx -> State, Mx -> Message, Tx -> Temporary
-//
-// Note: Routines use other registers as temporaries, the Tx aliases are
-// for those that are clobbered by STATE_UPDATE().
+// GP Registers: RAX, RBX, RCX -> Temporary
 #define S0 Y0
 #define S1 Y1
 #define S2 Y2
@@ -45,20 +43,6 @@ TEXT ·xgetbv0Amd64(SB), NOSPLIT, $0-8
 #define T0 Y14
 #define T1 Y15
 
-#define LOAD_STATE(SRC) \
-	VMOVDQU (SRC), S0    \
-	VMOVDQU 32(SRC), S1  \
-	VMOVDQU 64(SRC), S2  \
-	VMOVDQU 96(SRC), S3  \
-	VMOVDQU 128(SRC), S4
-
-#define STORE_STATE(DST) \
-	VMOVDQU S0, (DST)    \
-	VMOVDQU S1, 32(DST)  \
-	VMOVDQU S2, 64(DST)  \
-	VMOVDQU S3, 96(DST)  \
-	VMOVDQU S4, 128(DST)
-
 // This is essentially a naive translation of the intrinsics, but neither GCC nor
 // clang's idea of what this should be appears to be better on Broadwell, and
 // there is a benefit to being easy to cross-reference with the upstream
@@ -108,165 +92,209 @@ TEXT ·xgetbv0Amd64(SB), NOSPLIT, $0-8
 	VPOR   T0, T1, S4    \
 	VPERMQ $-109, S2, S2
 
-// func initAVX2(s *uint64, key, iv *byte)
-TEXT ·initAVX2(SB), NOSPLIT, $0-24
-	MOVQ s+0(FP), R8
-	MOVQ key+8(FP), R9
-	MOVQ iv+16(FP), R10
-
-	VPXOR    S0, S0, S0
-	MOVOU    (R10), X0
-	VMOVDQU  (R9), S1
-	VPCMPEQD S2, S2, S2
-	VPXOR    S3, S3, S3
-	VMOVDQU  ·initializationConstants(SB), S4
-	VPXOR    M0, M0, M0
-	VMOVDQA  S1, Y6
-
-	MOVQ $16, AX
-
-initloop:
-	STATE_UPDATE()
-	SUBQ $1, AX
-	JNZ  initloop
-
-	VPXOR Y6, S1, S1
-	STORE_STATE(R8)
-
-	VZEROUPPER
-	RET
-
-// func absorbBlocksAVX2(s *uint64, in *byte, blocks uint64)
-TEXT ·absorbBlocksAVX2(SB), NOSPLIT, $0-24
-	MOVQ s+0(FP), R8
-	MOVQ in+8(FP), R10
-	MOVQ blocks+16(FP), R11
-
-	LOAD_STATE(R8)
-
-loopblocks:
-	VMOVDQU (R10), M0
+#define COPY(DST, SRC, LEN) \
+	MOVQ SRC, SI \
+	MOVQ DST, DI \
+	MOVQ LEN, CX \
+	REP          \
+	MOVSB
+
+#define INIT_STATE(IV, KEY) \
+	VPXOR     S0, S0, S0                       \
+	MOVOU     (IV), X0                         \
+	VMOVDQU   (KEY), S1                        \
+	VPCMPEQD  S2, S2, S2                       \
+	VPXOR     S3, S3, S3                       \
+	VMOVDQU   ·initializationConstants(SB), S4 \
+	VPXOR     M0, M0, M0                       \
+	VMOVDQA   S1, Y6                           \
+	MOVQ      $16, AX                          \
+	                                           \
+initLoop:                                    \
+	STATE_UPDATE()                             \
+	SUBQ      $1, AX                           \
+	JNZ       initLoop                         \
+	                                           \
+	VPXOR     Y6, S1, S1
+
+#define ABSORB_BLOCKS(A, ALEN, SCRATCH) \
+	MOVQ            ALEN, AX       \
+	SHRQ            $5, AX         \
+	JZ              absorbPartial  \
+loopAbsorbFull:                  \
+	VMOVDQU         (A), M0        \
+	STATE_UPDATE()                 \
+	ADDQ            $32, A         \
+	SUBQ            $1, AX         \
+	JNZ             loopAbsorbFull \
+absorbPartial:                   \
+	ANDQ            $31, ALEN      \
+	JZ              absorbDone     \
+	COPY(SCRATCH, A, ALEN)         \
+	VMOVDQU         (SCRATCH), M0  \
+	STATE_UPDATE()                 \
+absorbDone:
+
+#define FINALIZE(TAG, ALEN, MLEN, SCRATCH) \
+	SHLQ       $3, ALEN         \
+	MOVQ       ALEN, (SCRATCH)  \
+	SHLQ       $3, MLEN         \
+	MOVQ       MLEN, 8(SCRATCH) \
+	                            \
+	VPXOR      S4, S0, S4       \
+	VMOVDQU    (SCRATCH), M0    \
+	                            \
+	MOVQ       $10, AX          \
+loopFinal:                    \
+	STATE_UPDATE()              \
+	SUBQ       $1, AX           \
+	JNZ        loopFinal        \
+	                            \
+	VPERMQ     $57, S1, Y6      \
+	VPXOR      S0, Y6, Y6       \
+	VPAND      S2, S3, Y7       \
+	VPXOR      Y6, Y7, Y7       \
+	MOVOU      X7, (TAG)
+
+// func aeadEncryptAVX2(c, m, a []byte, nonce, key *byte)
+TEXT ·aeadEncryptAVX2(SB), NOSPLIT, $32-88
+	MOVQ    SP, R15
+	VPXOR   Y13, Y13, Y13
+	VMOVDQU Y13, (R15)
+
+	// Initialize the state.
+	MOVQ nonce+72(FP), R8
+	MOVQ key+80(FP), R9
+	INIT_STATE(R8, R9)
+
+	// Absorb the AD.
+	MOVQ a+48(FP), R8 // &a[0] -> R8
+	MOVQ a+56(FP), R9 // len(a) -> R9
+	ABSORB_BLOCKS(R8, R9, R15)
+
+	// Encrypt the data.
+	MOVQ m+24(FP), R8 // &m[0] -> R8
+	MOVQ m+32(FP), R9 // len(m) -> R9
+	MOVQ c+0(FP), R10 // &c[0] -> R10
+
+	MOVQ R9, AX
+	SHRQ $5, AX
+	JZ   encryptPartial
+
+loopEncryptFull:
+	VMOVDQU (R8), M0
+	VPERMQ  $57, S1, Y6
+	VPXOR   S0, Y6, Y6
+	VPAND   S2, S3, Y7
+	VPXOR   Y6, Y7, Y6
+	VPXOR   M0, Y6, Y6
+	VMOVDQU Y6, (R10)
 	STATE_UPDATE()
+	ADDQ    $32, R8
 	ADDQ    $32, R10
-	SUBQ    $1, R11
-	JNZ     loopblocks
-
-	STORE_STATE(R8)
-
-	VZEROUPPER
-	RET
-
-// func encryptBlocksAVX2(s *uint64, out, in *byte, blocks uint64)
-TEXT ·encryptBlocksAVX2(SB), NOSPLIT, $0-32
-	MOVQ s+0(FP), R8
-	MOVQ out+8(FP), R9
-	MOVQ in+16(FP), R10
-	MOVQ blocks+24(FP), R11
-
-	LOAD_STATE(R8)
-
-loopblocks:
-	VMOVDQU (R10), M0
+	SUBQ    $1, AX
+	JNZ     loopEncryptFull
+
+encryptPartial:
+	ANDQ    $31, R9
+	JZ      encryptDone
+	VMOVDQU Y13, (R15)
+	COPY(R15, R8, R9)
+	VMOVDQU (R15), M0
 	VPERMQ  $57, S1, Y6
 	VPXOR   S0, Y6, Y6
 	VPAND   S2, S3, Y7
 	VPXOR   Y6, Y7, Y6
 	VPXOR   M0, Y6, Y6
-	VMOVDQU Y6, (R9)
+	VMOVDQU Y6, (R15)
 	STATE_UPDATE()
-	ADDQ    $32, R9
-	ADDQ    $32, R10
-	SUBQ    $1, R11
-	JNZ     loopblocks
+	COPY(R10, R15, R9)
+	ADDQ    R9, R10
 
-	STORE_STATE(R8)
+encryptDone:
 
+	// Finalize and write the tag.
+	MOVQ    a+56(FP), R8 // len(a) -> R8
+	MOVQ    m+32(FP), R9 // len(m) -> R9
+	VMOVDQU Y13, (R15)
+	FINALIZE(R10, R8, R9, R15)
+
+	VMOVDQU Y13, (R15)
 	VZEROUPPER
 	RET
 
-// func decryptBlocksAVX2(s *uint64, out, in *byte, blocks uint64)
-TEXT ·decryptBlocksAVX2(SB), NOSPLIT, $0-32
-	MOVQ s+0(FP), R8
-	MOVQ out+8(FP), R9
-	MOVQ in+16(FP), R10
-	MOVQ blocks+24(FP), R11
-
-	LOAD_STATE(R8)
-
-loopblocks:
-	VMOVDQU (R10), M0
+// func aeadDecryptAVX2(m, c, a []byte, nonce, key, tag *byte)
+TEXT ·aeadDecryptAVX2(SB), NOSPLIT, $32-96
+	MOVQ    SP, R15
+	VPXOR   Y13, Y13, Y13
+	VMOVDQU Y13, (R15)
+
+	// Initialize the state.
+	MOVQ nonce+72(FP), R8
+	MOVQ key+80(FP), R9
+	INIT_STATE(R8, R9)
+
+	// Absorb the AD.
+	MOVQ a+48(FP), R8 // &a[0] -> R8
+	MOVQ a+56(FP), R9 // len(a) -> R9
+	ABSORB_BLOCKS(R8, R9, R15)
+
+	// Decrypt the data.
+	MOVQ c+24(FP), R8 // &c[0] -> R8
+	MOVQ c+32(FP), R9 // len(c) -> R9
+	MOVQ m+0(FP), R10 // &m[0] -> R10
+
+	MOVQ R9, AX
+	SHRQ $5, AX
+	JZ   decryptPartial
+
+loopDecryptFull:
+	VMOVDQU (R8), M0
 	VPERMQ  $57, S1, Y6
 	VPXOR   S0, Y6, Y6
 	VPAND   S2, S3, Y7
 	VPXOR   Y6, Y7, Y6
 	VPXOR   M0, Y6, M0
-	VMOVDQU M0, (R9)
+	VMOVDQU M0, (R10)
 	STATE_UPDATE()
-	ADDQ    $32, R9
+	ADDQ    $32, R8
 	ADDQ    $32, R10
-	SUBQ    $1, R11
-	JNZ     loopblocks
-
-	STORE_STATE(R8)
-
-	VZEROUPPER
-	RET
-
-// func decryptLastBlockAVX2(s *uint64, out, in *byte, inLen uint64)
-TEXT ·decryptLastBlockAVX2(SB), NOSPLIT, $0-32
-	MOVQ s+0(FP), R8
-	MOVQ out+8(FP), R9
-	MOVQ in+16(FP), R10
-	MOVQ inLen+24(FP), R11
-
-	LOAD_STATE(R8)
-
-	VMOVDQU (R10), M0
+	SUBQ    $1, AX
+	JNZ     loopDecryptFull
+
+decryptPartial:
+	ANDQ    $31, R9
+	JZ      decryptDone
+	VMOVDQU Y13, (R15)
+	COPY(R15, R8, R9)
+	VMOVDQU (R15), M0
 	VPERMQ  $57, S1, Y6
 	VPXOR   S0, Y6, Y6
 	VPAND   S2, S3, Y7
 	VPXOR   Y6, Y7, Y6
 	VPXOR   M0, Y6, M0
-	VMOVDQU M0, (R9)
-
-	MOVQ R11, AX
-
-loopclear:
-	MOVB $0, (R9)(AX*1)
-	ADDQ $1, AX
-	CMPQ AX, $32
-	JNE  loopclear
-
-	VMOVDQU (R9), M0
+	VMOVDQU M0, (R15)
+	COPY(R10, R15, R9)
+	MOVQ    $0, AX
+	MOVQ    R15, DI
+	MOVQ    $32, CX
+	SUBQ    R9, CX
+	ADDQ    R9, DI
+	CLD
+	REP
+	STOSB
+	VMOVDQU (R15), M0
 	STATE_UPDATE()
-	STORE_STATE(R8)
-
-	VZEROUPPER
-	RET
-
-// func finalizeAVX2(s *uint64, tag *byte, lastBlock *uint64)
-TEXT ·finalizeAVX2(SB), NOSPLIT, $0-24
-	MOVQ s+0(FP), R8
-	MOVQ tag+8(FP), R9
-	MOVQ lastBlock+16(FP), R10
-
-	LOAD_STATE(R8)
 
-	VPXOR   S4, S0, S4
-	VMOVDQU (R10), M0
-
-	MOVQ $10, AX
-
-finalloop:
-	STATE_UPDATE()
-	SUBQ $1, AX
-	JNZ  finalloop
+decryptDone:
 
-	VPERMQ $57, S1, Y6
-	VPXOR  S0, Y6, Y6
-	VPAND  S2, S3, Y7
-	VPXOR  Y6, Y7, Y7
-	MOVOU  X7, (R9)
+	// Finalize and write the tag.
+	MOVQ    a+56(FP), R8    // len(a) -> R8
+	MOVQ    m+32(FP), R9    // len(m) -> R9
+	MOVQ    tag+88(FP), R14 // tag -> R14
+	VMOVDQU Y13, (R15)
+	FINALIZE(R14, R8, R9, R15)
 
+	VMOVDQU Y13, (R15)
 	VZEROUPPER
 	RET
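
As a reading aid for the partial-block and finalization paths above, here is a short Go sketch of what the new code computes. stateUpdate is a stand-in for STATE_UPDATE(), the 32-byte scratch mirrors the zeroed $32 stack frame, and decryptTail's keystream argument abstracts the VPERMQ/VPXOR/VPAND keystream derivation; none of this is part of the commit.

	package sketch

	import "encoding/binary"

	const blockSize = 32 // one YMM register

	// stateUpdate stands in for the assembly STATE_UPDATE() round.
	func stateUpdate(block []byte) { _ = block }

	// absorb mirrors ABSORB_BLOCKS: whole blocks first (loopAbsorbFull),
	// then the tail staged through the scratch buffer (absorbPartial,
	// copied with COPY's REP MOVSB and implicitly zero-padded).
	func absorb(a []byte) {
		var scratch [blockSize]byte
		for ; len(a) >= blockSize; a = a[blockSize:] {
			stateUpdate(a[:blockSize])
		}
		if len(a) > 0 {
			copy(scratch[:], a)
			stateUpdate(scratch[:])
		}
	}

	// decryptTail mirrors decryptPartial: decrypt into scratch, copy out
	// only len(in) bytes, then zero everything past the message end (the
	// REP STOSB) before the padded plaintext block re-enters the state.
	func decryptTail(out, in []byte, keystream [blockSize]byte) {
		var scratch [blockSize]byte
		copy(scratch[:], in)
		for i := range scratch {
			scratch[i] ^= keystream[i] // VPXOR M0, Y6, M0
		}
		copy(out, scratch[:len(in)])
		for i := len(in); i < blockSize; i++ {
			scratch[i] = 0
		}
		stateUpdate(scratch[:])
	}

	// lengthBlock mirrors FINALIZE's encoding of the AD and message
	// lengths, in bits, into the first 16 bytes of the last block.
	func lengthBlock(adLen, mLen uint64) (b [blockSize]byte) {
		binary.LittleEndian.PutUint64(b[0:8], adLen<<3) // SHLQ $3, ALEN
		binary.LittleEndian.PutUint64(b[8:16], mLen<<3) // SHLQ $3, MLEN
		return
	}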