
Use unaligned loads/stores for interacting with the state.

It turns out the Go compiler doesn't 16-byte align the state vector under
certain conditions.  Sigh.
Yawning Angel committed 1 year ago
commit 6f059cfcee

2 changed files with 26 additions and 27 deletions:
  1. chacha20_amd64.py (+14, -15)
  2. chacha20_amd64.s (+12, -12)
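
For context (not part of the commit): Go only guarantees a field's natural
alignment, and the natural alignment of a [16]uint32 state array is that of
uint32 (4 bytes), so a caller-provided state is not necessarily 16-byte
aligned.  A minimal, hypothetical Go sketch illustrating this (the ctx struct
below is made up for illustration and is not from this repository):

    package main

    import (
        "fmt"
        "unsafe"
    )

    type ctx struct {
        off uint32     // hypothetical preceding field
        x   [16]uint32 // ChaCha20-style state vector
    }

    func main() {
        c := new(ctx)
        // Alignof reports the guaranteed alignment of the array: 4, not 16.
        fmt.Println("Alignof(x):", unsafe.Alignof(c.x))
        // With a 4-byte field in front, &c.x is typically not a multiple of 16,
        // which is exactly the case MOVDQA/VMOVDQA would fault on.
        fmt.Println("16-byte aligned:", uintptr(unsafe.Pointer(&c.x))%16 == 0)
    }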

chacha20_amd64.py (+14, -15)

@@ -76,8 +76,7 @@ def WriteXor_sse2(tmp, inp, outp, d, v0, v1, v2, v3):
     MOVDQU([outp+d+48], tmp)
 
 # SSE2 ChaCha20 (aka vec128).  Does not handle partial blocks, and will
-# process 4/2/1 blocks at a time.  x (the ChaCha20 state) must be 16 byte
-# aligned.
+# process 4/2/1 blocks at a time.
 with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     reg_x = GeneralPurposeRegister64()
     reg_inp = GeneralPurposeRegister64()
@@ -144,10 +143,10 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     SUB(reg_blocks, 4)
     JB(vector_loop4.end)
     with vector_loop4:
-        MOVDQA(xmm_v0, mem_s0)
-        MOVDQA(xmm_v1, mem_s1)
-        MOVDQA(xmm_v2, mem_s2)
-        MOVDQA(xmm_v3, mem_s3)
+        MOVDQU(xmm_v0, mem_s0)
+        MOVDQU(xmm_v1, mem_s1)
+        MOVDQU(xmm_v2, mem_s2)
+        MOVDQU(xmm_v3, mem_s3)
 
         MOVDQA(xmm_v4, xmm_v0)
         MOVDQA(xmm_v5, xmm_v1)
@@ -340,7 +339,7 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
         PADDD(xmm_v2, mem_s2)
         PADDD(xmm_v3, mem_s3)
         WriteXor_sse2(xmm_tmp, reg_inp, reg_outp, 0, xmm_v0, xmm_v1, xmm_v2, xmm_v3)
-        MOVDQA(xmm_v3, mem_s3)
+        MOVDQU(xmm_v3, mem_s3)
         PADDQ(xmm_v3, mem_one)
 
         PADDD(xmm_v4, mem_s0)
@@ -366,7 +365,7 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
         WriteXor_sse2(xmm_v0, reg_inp, reg_outp, 192, xmm_v12, xmm_v13, xmm_v14, xmm_v15)
         PADDQ(xmm_v3, mem_one)
 
-        MOVDQA(mem_s3, xmm_v3)
+        MOVDQU(mem_s3, xmm_v3)
 
         ADD(reg_inp, 4 * 64)
         ADD(reg_outp, 4 * 64)
@@ -386,10 +385,10 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     xmm_s2 = xmm_v10
     xmm_s3 = xmm_v11
     xmm_one = xmm_v13
-    MOVDQA(xmm_s0, mem_s0)
-    MOVDQA(xmm_s1, mem_s1)
-    MOVDQA(xmm_s2, mem_s2)
-    MOVDQA(xmm_s3, mem_s3)
+    MOVDQU(xmm_s0, mem_s0)
+    MOVDQU(xmm_s1, mem_s1)
+    MOVDQU(xmm_s2, mem_s2)
+    MOVDQU(xmm_s3, mem_s3)
     MOVDQA(xmm_one, mem_one)
 
     #
@@ -598,7 +597,7 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     # Write back the updated counter.  Stoping at 2^70 bytes is the user's
     # problem, not mine.  (Skipped if there's exactly a multiple of 4 blocks
     # because the counter is incremented in memory while looping.)
-    MOVDQA(mem_s3, xmm_s3)
+    MOVDQU(mem_s3, xmm_s3)
 
     LABEL(out)
 
@@ -666,7 +665,7 @@ def WriteXor_avx2(tmp, inp, outp, d, v0, v1, v2, v3):
     VMOVDQU([outp+d+96], tmp)
 
 # AVX2 ChaCha20 (aka avx2).  Does not handle partial blocks, will process
-# 8/4/2 blocks at a time.  Alignment blah blah blah fuck you.
+# 8/4/2 blocks at a time.
 with Function("blocksAmd64AVX2", (x, inp, outp, nrBlocks), target=uarch.broadwell):
     reg_x = GeneralPurposeRegister64()
     reg_inp = GeneralPurposeRegister64()
@@ -1238,7 +1237,7 @@ with Function("blocksAmd64AVX2", (x, inp, outp, nrBlocks), target=uarch.broadwel
     VPERM2I128(ymm_s3, ymm_s3, ymm_s3, 0x01) # Odd number of blocks.
 
     LABEL(out_write_even)
-    VMOVDQA(x_s3, ymm_s3.as_xmm) # Write back ymm_s3 to x_v3
+    VMOVDQU(x_s3, ymm_s3.as_xmm) # Write back ymm_s3 to x_v3
 
     # Paranoia, cleanse the scratch space.
     VPXOR(ymm_v0, ymm_v0, ymm_v0)

chacha20_amd64.s (+12, -12)

@@ -19,10 +19,10 @@ TEXT ·blocksAmd64SSE2(SB),4,$0-32
 	SUBQ $4, DX
 	JCS vector_loop4_end
 vector_loop4_begin:
-		MOVO 0(AX), X0
-		MOVO 16(AX), X1
-		MOVO 32(AX), X2
-		MOVO 48(AX), X3
+		MOVOU 0(AX), X0
+		MOVOU 16(AX), X1
+		MOVOU 32(AX), X2
+		MOVOU 48(AX), X3
 		MOVO X0, X4
 		MOVO X1, X5
 		MOVO X2, X6
@@ -283,7 +283,7 @@ rounds_loop4_begin:
 		MOVOU 48(BX), X12
 		PXOR X3, X12
 		MOVOU X12, 48(CX)
-		MOVO 48(AX), X3
+		MOVOU 48(AX), X3
 		PADDQ 0(SP), X3
 		PADDL 0(AX), X4
 		PADDL 16(AX), X5
@@ -337,7 +337,7 @@ rounds_loop4_begin:
 		PXOR X15, X0
 		MOVOU X0, 240(CX)
 		PADDQ 0(SP), X3
-		MOVO X3, 48(AX)
+		MOVOU X3, 48(AX)
 		ADDQ $256, BX
 		ADDQ $256, CX
 		SUBQ $4, DX
@@ -345,10 +345,10 @@ rounds_loop4_begin:
 vector_loop4_end:
 	ADDQ $4, DX
 	JEQ out
-	MOVO 0(AX), X8
-	MOVO 16(AX), X9
-	MOVO 32(AX), X10
-	MOVO 48(AX), X11
+	MOVOU 0(AX), X8
+	MOVOU 16(AX), X9
+	MOVOU 32(AX), X10
+	MOVOU 48(AX), X11
 	MOVO 0(SP), X13
 	SUBQ $2, DX
 	JCS process_1_block
@@ -593,7 +593,7 @@ rounds_loop1_begin:
 	MOVOU X12, 48(CX)
 	PADDQ X13, X11
 out_serial:
-	MOVO X11, 48(AX)
+	MOVOU X11, 48(AX)
 out:
 	PXOR X0, X0
 	MOVO X0, 16(SP)
@@ -1150,7 +1150,7 @@ rounds_loop2_begin:
 out_write_odd:
 	BYTE $0xC4; BYTE $0x43; BYTE $0x1D; BYTE $0x46; BYTE $0xE4; BYTE $0x01 // VPERM2I128 ymm12, ymm12, ymm12, 1
 out_write_even:
-	BYTE $0xC5; BYTE $0x79; BYTE $0x7F; BYTE $0x60; BYTE $0x30 // VMOVDQA [rax + 48], xmm12
+	BYTE $0xC5; BYTE $0x7A; BYTE $0x7F; BYTE $0x60; BYTE $0x30 // VMOVDQU [rax + 48], xmm12
 	BYTE $0xC5; BYTE $0xED; BYTE $0xEF; BYTE $0xD2 // VPXOR ymm2, ymm2, ymm2
 	BYTE $0xC5; BYTE $0xFD; BYTE $0x7F; BYTE $0x54; BYTE $0x24; BYTE $0x40 // VMOVDQA [rsp + 64], ymm2
 	BYTE $0xC5; BYTE $0xFD; BYTE $0x7F; BYTE $0x54; BYTE $0x24; BYTE $0x20 // VMOVDQA [rsp + 32], ymm2
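
Side note (not part of the commit): in the hand-encoded AVX2 epilogue the only
change is the second byte of the 2-byte VEX prefix, 0x79 -> 0x7A.  The low two
bits of that byte (the pp field) select the implied legacy prefix: 01 means 66
(VMOVDQA) and 10 means F3 (VMOVDQU), so flipping them turns the aligned store
of xmm12 into an unaligned one.  A minimal Go sketch decoding those bits:

    package main

    import "fmt"

    // ppName decodes the pp field (low two bits) of the second byte of a
    // 2-byte VEX prefix (0xC5 xx).  With opcode 0x7F, pp=01 (implied 66
    // prefix) encodes VMOVDQA and pp=10 (implied F3 prefix) encodes VMOVDQU.
    func ppName(vex2 byte) string {
        switch vex2 & 0x03 {
        case 0x01:
            return "66 -> VMOVDQA (aligned store)"
        case 0x02:
            return "F3 -> VMOVDQU (unaligned store)"
        default:
            return "other"
        }
    }

    func main() {
        fmt.Println("0x79:", ppName(0x79)) // old encoding
        fmt.Println("0x7A:", ppName(0x7A)) // new encoding
    }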