More AMD64 cleanups.

 * It seems safe to assume that x is 16-byte aligned.  I was already
   relying on this by using PADDD with a memory operand, but was loading
   it with MOVDQU for some stupid reason.

 * Align the stack (where the counter increment vector resides) to 16
   bytes before placing the vector, so that I can use MOVDQA.

 * Hopefully improve readability.

The alignment stuff doesn't really impact performance in a meaningful
sense, since the penalty for an unaligned load into an XMM register is
1 clock cycle.  This feels cleaner though; the stack-alignment arithmetic
is sketched below.
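
A minimal sketch (plain Python, not the PeachPy DSL; align_rsp is a made-up
helper) of the arithmetic the new prologue emits.  The point is that padding
by 16 - (rsp & 0x0f) lands on a 16 byte boundary because the incoming stack
pointer here is already at least 8 byte aligned, and that an already-aligned
stack is simply padded by a full 16 bytes, which is harmless since the same
amount is added back before RETURN().

    # Model of: MOV tmp, rsp; AND tmp, 0x0f; MOV pad, 0x10; SUB pad, tmp;
    # SUB rsp, pad ... and the matching ADD rsp, pad in the epilogue.
    # Illustration only, not the generated code.
    def align_rsp(rsp):
        pad = 0x10 - (rsp & 0x0F)     # 1..16 bytes of padding
        return rsp - pad, pad

    # The incoming RSP is at least 8 byte aligned, so the result is 16 byte aligned.
    for rsp in (0x7fff_ffe0, 0x7fff_ffe8):
        aligned, pad = align_rsp(rsp)
        assert aligned % 16 == 0
        assert aligned + pad == rsp   # the epilogue's ADD restores the original RSP
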
Author: Yawning Angel
Commit: 6788ab3601

2 changed files with 73 additions and 48 deletions:
  chacha20_amd64.py  (+50 -29)
  chacha20_amd64.s   (+23 -19)

chacha20_amd64.py  (+50 -29)

@@ -31,6 +31,14 @@ inp = Argument(ptr(const_uint8_t))
 outp = Argument(ptr(uint8_t))
 nrBlocks = Argument(ptr(size_t))
 
+# Helper routines for the actual ChaCha round function.
+#
+# Note:
+#   It's been pointed out by the PeachPy author that the tmp variable for a
+#   scratch register is kind of silly, but I only have one XMM register that
+#   can be used for scratch (everything else is used to store the cipher state
+#   or output).
+
 def RotV1(x):
     PSHUFD(x, x, 0x39)
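
For reference, a hypothetical lane-level model (plain Python, illustration
only) of what the PSHUFD immediate encodes: each 2-bit field selects a source
dword lane, lowest lane first, so 0x39 is the rotate-by-one-lane that RotV1
performs (the same operation appears as PSHUFL $57 in the generated assembly).

    # Hypothetical model of PSHUFD lane selection; not part of the generator.
    def pshufd(lanes, imm):
        return [lanes[(imm >> (2 * i)) & 3] for i in range(4)]

    assert pshufd(['a', 'b', 'c', 'd'], 0x39) == ['b', 'c', 'd', 'a']  # RotV1's rotation
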
 
@@ -129,6 +137,8 @@ def WriteXor(tmp, inp, outp, d, v0, v1, v2, v3):
     PXOR(tmp, v3)
     MOVDQU([outp+d+48], tmp)
 
+# SSE2 ChaCha20.  Does not handle partial blocks, and will process 3 blocks at
+# a time.  x (the ChaCha20 state) must be 16 byte aligned.
 with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     reg_x = GeneralPurposeRegister64()
     reg_inp = GeneralPurposeRegister64()
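
The tail of WriteXor shown in the hunk above xors v3 into bytes 48..63 at
outp+d.  Assuming the elided lines before it do the same for v0..v2 at offsets
0, 16 and 32, a byte-level sketch of the whole helper (plain Python; the name
and types are illustrative, and the scratch register is not modelled) is:

    # Assumed behaviour of WriteXor: xor one 64-byte keystream block, held as
    # four 16-byte vectors, into the input at offset d and store the result
    # at the same offset of the output.
    def write_xor(inp, outp, d, v0, v1, v2, v3):
        keystream = bytes(v0) + bytes(v1) + bytes(v2) + bytes(v3)
        outp[d:d + 64] = bytes(a ^ b for a, b in zip(inp[d:d + 64], keystream))

    out = bytearray(64)
    write_xor(bytes(64), out, 0, b'\x01' * 16, b'\x02' * 16, b'\x03' * 16, b'\x04' * 16)
    assert out[:16] == b'\x01' * 16 and out[48:] == b'\x04' * 16
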
@@ -140,6 +150,15 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     LOAD.ARGUMENT(reg_outp, outp)
     LOAD.ARGUMENT(reg_blocks, nrBlocks)
 
+    # Align the stack to a 16 byte boundary.
+    reg_align_tmp = GeneralPurposeRegister64()
+    MOV(reg_align_tmp, registers.rsp)
+    AND(reg_align_tmp, 0x0f)
+    reg_align = GeneralPurposeRegister64()
+    MOV(reg_align, 0x10)
+    SUB(reg_align, reg_align_tmp)
+    SUB(registers.rsp, reg_align)
+
     # Build the counter increment vector on the stack.
     SUB(registers.rsp, 16)
     reg_tmp = GeneralPurposeRegister32()
@@ -149,14 +168,16 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     MOV([registers.rsp+4], reg_tmp)
     MOV([registers.rsp+8], reg_tmp)
     MOV([registers.rsp+12], reg_tmp)
+    mem_one = [registers.rsp]  # (Stack) Counter increment vector
 
-    xmm_tmp = XMMRegister()
-    xmm_s1 = XMMRegister()
-    MOVDQU(xmm_s1, [reg_x+16])
-    xmm_s2 = XMMRegister()
-    MOVDQU(xmm_s2, [reg_x+32])
-    xmm_s3 = XMMRegister()
-    MOVDQU(xmm_s3, [reg_x+48])
+    xmm_tmp = XMMRegister()    # The single scratch register
+    mem_s0 = [reg_x]           # (Memory) Cipher state [0..3]
+    xmm_s1 = XMMRegister()     # (Fixed Reg) Cipher state [4..7]
+    MOVDQA(xmm_s1, [reg_x+16])
+    xmm_s2 = XMMRegister()     # (Fixed Reg) Cipher state [8..11]
+    MOVDQA(xmm_s2, [reg_x+32])
+    xmm_s3 = XMMRegister()     # (Fixed Reg) Cipher state [12..15]
+    MOVDQA(xmm_s3, [reg_x+48])
 
     vector_loop = Loop()
     serial_loop = Loop()
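
The increment vector just built on the stack is the dwords {1, 0, 0, 0}; read
as two little-endian qwords that is {1, 0}, so each PADDQ against it bumps
only the 64-bit block counter in the low half of the last state row and leaves
the nonce half alone.  That is presumably why PADDQ is used rather than PADDD:
the carry out of the low dword must propagate.  A small model in plain Python,
with made-up values:

    import struct

    # Model of PADDQ on the last state row: two independent 64-bit additions.
    def paddq(a16, b16):
        a, b = struct.unpack('<2Q', a16), struct.unpack('<2Q', b16)
        return struct.pack('<2Q', (a[0] + b[0]) % 2**64, (a[1] + b[1]) % 2**64)

    one = struct.pack('<4I', 1, 0, 0, 0)                     # the stack-resident increment vector
    s3 = struct.pack('<2Q', 2**32 - 1, 0x0123456789abcdef)   # counter about to carry, made-up nonce
    assert struct.unpack('<2Q', paddq(s3, one)) == (2**32, 0x0123456789abcdef)
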
@@ -179,23 +200,22 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     SUB(reg_blocks, 3)
     JB(vector_loop.end)
     with vector_loop:
-        MOVDQU(xmm_v0, [reg_x]) # <- sigma
+        MOVDQA(xmm_v0, mem_s0)
         MOVDQA(xmm_v1, xmm_s1)
         MOVDQA(xmm_v2, xmm_s2)
         MOVDQA(xmm_v3, xmm_s3)
-        MOVDQU(xmm_tmp, [registers.rsp])
 
-        MOVDQA(xmm_v4, xmm_v0) # <- sigma
+        MOVDQA(xmm_v4, mem_s0)
         MOVDQA(xmm_v5, xmm_s1)
         MOVDQA(xmm_v6, xmm_s2)
         MOVDQA(xmm_v7, xmm_s3)
-        PADDQ(xmm_v7, xmm_tmp) # + counter
+        PADDQ(xmm_v7, mem_one)
 
-        MOVDQA(xmm_v8, xmm_v0) # <- sigma
+        MOVDQA(xmm_v8, mem_s0)
         MOVDQA(xmm_v9, xmm_s1)
         MOVDQA(xmm_v10, xmm_s2)
         MOVDQA(xmm_v11, xmm_v7)
-        PADDQ(xmm_v11, xmm_tmp) # +  counter
+        PADDQ(xmm_v11, mem_one)
 
         reg_rounds = GeneralPurposeRegister64()
         MOV(reg_rounds, 20)
@@ -207,27 +227,26 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
             SUB(reg_rounds, 2)
             JNZ(rounds_loop.begin)
 
-        PADDD(xmm_v0, [reg_x])
+        PADDD(xmm_v0, mem_s0)
         PADDD(xmm_v1, xmm_s1)
         PADDD(xmm_v2, xmm_s2)
         PADDD(xmm_v3, xmm_s3)
         WriteXor(xmm_tmp, reg_inp, reg_outp, 0, xmm_v0, xmm_v1, xmm_v2, xmm_v3)
-        MOVDQU(xmm_v0, [registers.rsp])
-        PADDQ(xmm_s3, xmm_v0) # + counter
+        PADDQ(xmm_s3, mem_one)
 
-        PADDD(xmm_v4, [reg_x])
+        PADDD(xmm_v4, mem_s0)
         PADDD(xmm_v5, xmm_s1)
         PADDD(xmm_v6, xmm_s2)
         PADDD(xmm_v7, xmm_s3)
         WriteXor(xmm_tmp, reg_inp, reg_outp, 64, xmm_v4, xmm_v5, xmm_v6, xmm_v7)
-        PADDQ(xmm_s3, xmm_v0) # +counter
+        PADDQ(xmm_s3, mem_one)
 
-        PADDD(xmm_v8, [reg_x])
+        PADDD(xmm_v8, mem_s0)
         PADDD(xmm_v9, xmm_s1)
         PADDD(xmm_v10, xmm_s2)
         PADDD(xmm_v11, xmm_s3)
         WriteXor(xmm_tmp, reg_inp, reg_outp, 128, xmm_v8, xmm_v9, xmm_v10, xmm_v11)
-        PADDQ(xmm_s3, xmm_v0) # + counter
+        PADDQ(xmm_s3, mem_one)
 
         ADD(reg_inp, 192)
         ADD(reg_outp, 192)
@@ -238,13 +257,14 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
     ADD(reg_blocks, 3)
     JZ(serial_loop.end)
 
-    # "serial" component.  Since we actually have registers now:
-    #    xmm_v4 = 1, 0, 0, 0
-    #    xmm_v5 = sigma
-    MOVDQU(xmm_v4, [registers.rsp])
-    MOVDQU(xmm_v5, [reg_x])
+    # Since we're only doing 1 block at a time, we can use registers for s0
+    # and the counter vector now.
+    xmm_s0 = xmm_v4
+    xmm_one = xmm_v5
+    MOVDQA(xmm_s0, mem_s0)   # sigma
+    MOVDQA(xmm_one, mem_one) # counter increment
     with serial_loop:
-        MOVDQA(xmm_v0, xmm_v5)
+        MOVDQA(xmm_v0, xmm_s0)
         MOVDQA(xmm_v1, xmm_s1)
         MOVDQA(xmm_v2, xmm_s2)
         MOVDQA(xmm_v3, xmm_s3)
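
Since the vector loop consumes blocks three at a time and the serial loop mops
up whatever remains one block at a time, the counter in s3 advances by exactly
one per block either way (three PADDQs per vector pass, one per serial pass).
A control-flow-only sketch in plain Python (no crypto, hypothetical name):

    # How nrBlocks is split between the two loops and how far the counter
    # moves; mirrors the subtract-and-branch structure of both loops.
    def block_schedule(nr_blocks):
        vector_iters, counter = 0, 0
        while nr_blocks >= 3:
            vector_iters += 1
            counter += 3              # three PADDQ(xmm_s3, mem_one) per pass
            nr_blocks -= 3
        counter += nr_blocks          # one PADDQ(xmm_s3, xmm_one) per leftover block
        return vector_iters, nr_blocks, counter

    assert block_schedule(7) == (2, 1, 7)
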
@@ -257,12 +277,12 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
             SUB(reg_rounds, 2)
             JNZ(rounds_loop.begin)
 
-        PADDD(xmm_v0, xmm_v5)
+        PADDD(xmm_v0, xmm_s0)
         PADDD(xmm_v1, xmm_s1)
         PADDD(xmm_v2, xmm_s2)
         PADDD(xmm_v3, xmm_s3)
         WriteXor(xmm_tmp, reg_inp, reg_outp, 0, xmm_v0, xmm_v1, xmm_v2, xmm_v3)
-        PADDQ(xmm_s3, xmm_v4)
+        PADDQ(xmm_s3, xmm_one)
 
         ADD(reg_inp, 64)
         ADD(reg_outp, 64)
@@ -272,8 +292,9 @@ with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
 
 # Write back the updated counter.  Stopping at 2^70 bytes is the user's
 # problem, not mine.
-    MOVDQU([reg_x+48], xmm_s3)
+    MOVDQA([reg_x+48], xmm_s3)
 
     ADD(registers.rsp, 16)
+    ADD(registers.rsp, reg_align)
 
     RETURN()

chacha20_amd64.s  (+23 -19)

@@ -7,6 +7,11 @@ TEXT ·blocksAmd64SSE2(SB),4,$0-32
 	MOVQ inp+8(FP), BX
 	MOVQ outp+16(FP), CX
 	MOVQ nrBlocks+24(FP), DX
+	MOVQ SP, DI
+	ANDQ $15, DI
+	MOVQ $16, SI
+	SUBQ DI, SI
+	SUBQ SI, SP
 	SUBQ $16, SP
 	MOVL $1, DI
 	MOVL DI, 0(SP)
@@ -14,27 +19,26 @@ TEXT ·blocksAmd64SSE2(SB),4,$0-32
 	MOVL DI, 4(SP)
 	MOVL DI, 8(SP)
 	MOVL DI, 12(SP)
-	MOVOU 16(AX), X1
-	MOVOU 32(AX), X2
-	MOVOU 48(AX), X3
+	MOVO 16(AX), X1
+	MOVO 32(AX), X2
+	MOVO 48(AX), X3
 	SUBQ $3, DX
 	JCS vector_loop_end
 vector_loop_begin:
-		MOVOU 0(AX), X4
+		MOVO 0(AX), X4
 		MOVO X1, X5
 		MOVO X2, X6
 		MOVO X3, X7
-		MOVOU 0(SP), X0
-		MOVO X4, X8
+		MOVO 0(AX), X8
 		MOVO X1, X9
 		MOVO X2, X10
 		MOVO X3, X11
-		PADDQ X0, X11
-		MOVO X4, X12
+		PADDQ 0(SP), X11
+		MOVO 0(AX), X12
 		MOVO X1, X13
 		MOVO X2, X14
 		MOVO X11, X15
-		PADDQ X0, X15
+		PADDQ 0(SP), X15
 		MOVQ $20, DI
 rounds_loop0_begin:
 			PADDL X5, X4
@@ -217,8 +221,7 @@ rounds_loop0_begin:
 		MOVOU 48(BX), X0
 		PXOR X7, X0
 		MOVOU X0, 48(CX)
-		MOVOU 0(SP), X4
-		PADDQ X4, X3
+		PADDQ 0(SP), X3
 		PADDL 0(AX), X8
 		PADDL X1, X9
 		PADDL X2, X10
@@ -235,7 +238,7 @@ rounds_loop0_begin:
 		MOVOU 112(BX), X0
 		PXOR X11, X0
 		MOVOU X0, 112(CX)
-		PADDQ X4, X3
+		PADDQ 0(SP), X3
 		PADDL 0(AX), X12
 		PADDL X1, X13
 		PADDL X2, X14
@@ -252,7 +255,7 @@ rounds_loop0_begin:
 		MOVOU 176(BX), X0
 		PXOR X15, X0
 		MOVOU X0, 176(CX)
-		PADDQ X4, X3
+		PADDQ 0(SP), X3
 		ADDQ $192, BX
 		ADDQ $192, CX
 		SUBQ $3, DX
@@ -260,10 +263,10 @@ rounds_loop0_begin:
 vector_loop_end:
 	ADDQ $3, DX
 	JEQ serial_loop_end
-	MOVOU 0(SP), X8
-	MOVOU 0(AX), X9
+	MOVO 0(AX), X8
+	MOVO 0(SP), X9
 serial_loop_begin:
-		MOVO X9, X4
+		MOVO X8, X4
 		MOVO X1, X5
 		MOVO X2, X6
 		MOVO X3, X7
@@ -325,7 +328,7 @@ rounds_loop1_begin:
 			PSHUFL $57, X7, X7
 			SUBQ $2, DI
 			JNE rounds_loop1_begin
-		PADDL X9, X4
+		PADDL X8, X4
 		PADDL X1, X5
 		PADDL X2, X6
 		PADDL X3, X7
@@ -341,12 +344,13 @@ rounds_loop1_begin:
 		MOVOU 48(BX), X0
 		PXOR X7, X0
 		MOVOU X0, 48(CX)
-		PADDQ X8, X3
+		PADDQ X9, X3
 		ADDQ $64, BX
 		ADDQ $64, CX
 		SUBQ $1, DX
 		JNE serial_loop_begin
 serial_loop_end:
-	MOVOU X3, 48(AX)
+	MOVO X3, 48(AX)
 	ADDQ $16, SP
+	ADDQ SI, SP
 	RET