
Implement the bitsliced 64-bit Pass 1/Pass 2. #2

Before:
 BenchmarkEncrypt_65536-4     	     300	   4930493 ns/op	  13.29 MB/s

After:
 BenchmarkEncrypt_65536-4     	    1000	   1821470 ns/op	  35.98 MB/s

Somewhat disappointing; I'm fairly sure I can improve this by doing an
optimization pass over the Pass 2 code. But it works, and it is a clear
improvement over the 32-bit code.
Yawning Angel · 2 years ago · commit 8bd96ae53e
3 changed files with 192 additions and 5 deletions:
  1. aez.go (+6 -2)
  2. round_bitsliced32.go (+5 -3)
  3. round_bitsliced64.go (+181 -0)
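
As a quick sanity check on the MB/s figures above: testing.B derives that column as bytes-per-op divided by seconds-per-op, using 1 MB = 10^6 bytes. A standalone snippet reproducing the numbers (the constants are taken directly from the benchmark output in the commit message):

	package main

	import "fmt"

	func main() {
		// 65536 bytes processed per op; ns/op figures from the benchmark.
		const bytesPerOp = 65536.0
		fmt.Printf("before: %.2f MB/s\n", bytesPerOp/(4930493e-9)/1e6) // 13.29
		fmt.Printf("after:  %.2f MB/s\n", bytesPerOp/(1821470e-9)/1e6) // 35.98
	}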

aez.go (+6 -2)

@@ -33,7 +33,7 @@ const (
 
 var (
 	extractBlake2Cfg             = &blake2b.Config{Size: extractedKeySize}
-	newAes           aesImplCtor = newRoundB32
+	newAes           aesImplCtor = newRoundB64
 	zero                         = [blockSize]byte{}
 )
 
@@ -238,6 +238,8 @@ func (e *eState) aezCorePass1Slow(in, out []byte, X *[blockSize]byte, sz int) {
 	switch a := e.aes.(type) {
 	case *roundB32:
 		a.aezCorePass1(e, in, out, X, sz)
+	case *roundB64:
+		a.aezCorePass1(e, in, out, X, sz)
 	default:
 		e.aezCorePass1Ref(in, out, X)
 	}
@@ -249,7 +251,9 @@ func (e *eState) aezCorePass2Slow(in, out []byte, Y, S *[blockSize]byte, sz int)
 	// Use one of the portable bitsliced options if possible.
 	switch a := e.aes.(type) {
 	case *roundB32:
-		a.aezCorePass2(e, in, out, Y, S, sz)
+		a.aezCorePass2(e, out, Y, S, sz)
+	case *roundB64:
+		a.aezCorePass2(e, out, Y, S, sz)
 	default:
 		e.aezCorePass2Ref(in, out, Y, S)
 	}
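
Two details in this hunk: the Pass 2 helpers drop their in argument, since Pass 2 only reads and writes out in place (the round_bitsliced32.go hunk below makes the corresponding signature change), and dispatch stays a plain type switch. An alternative shape for the dispatch, sketched here with hypothetical names rather than anything from the repo, is an optional interface, which keeps the switch from growing with each new implementation:

	package main

	import "fmt"

	// bulkPass1 is a hypothetical optional interface: implementations
	// that provide a bulk Pass 1 satisfy it; everything else takes the
	// reference fallback.
	type bulkPass1 interface {
		corePass1(in, out []byte)
	}

	type fastImpl struct{} // stand-in for e.g. roundB64

	func (fastImpl) corePass1(in, out []byte) { fmt.Println("bitsliced path") }

	type refImpl struct{} // stand-in for an implementation with no bulk routine

	func dispatchPass1(impl interface{}, in, out []byte) {
		if a, ok := impl.(bulkPass1); ok {
			a.corePass1(in, out) // specialized bulk routine
			return
		}
		fmt.Println("reference fallback") // aezCorePass1Ref in the repo
	}

	func main() {
		dispatchPass1(fastImpl{}, nil, nil)
		dispatchPass1(refImpl{}, nil, nil)
	}

With only two bitsliced variants the explicit switch is arguably clearer; the sketch just shows the tradeoff.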

round_bitsliced32.go (+5 -3)

@@ -41,7 +41,9 @@ func (r *roundB32) AES4(j, i, l *[blockSize]byte, src []byte, dst *[blockSize]by
 	memwipeU32(q[:])
 }
 
-func (r *roundB32) aes4x2(j0, i0, l0 *[blockSize]byte, src0 []byte, dst0 *[blockSize]byte, j1, i1, l1 *[blockSize]byte, src1 []byte, dst1 *[blockSize]byte) {
+func (r *roundB32) aes4x2(
+	j0, i0, l0 *[blockSize]byte, src0 []byte, dst0 *[blockSize]byte,
+	j1, i1, l1 *[blockSize]byte, src1 []byte, dst1 *[blockSize]byte) {
 	// XXX/performance: Fairly sure i, src, and dst are the only things
 	// that are ever different here so XORs can be pruned.
 
@@ -123,7 +125,7 @@ func (r *roundB32) aezCorePass1(e *eState, in, out []byte, X *[blockSize]byte, s
 	memwipe(I[:])
 }
 
-func (r *roundB32) aezCorePass2(e *eState, in, out []byte, Y, S *[blockSize]byte, sz int) {
+func (r *roundB32) aezCorePass2(e *eState, out []byte, Y, S *[blockSize]byte, sz int) {
 	var tmp0, tmp1, I [blockSize]byte
 
 	copy(I[:], e.I[1][:])
@@ -158,7 +160,7 @@ func (r *roundB32) aezCorePass2(e *eState, in, out []byte, Y, S *[blockSize]byte
 		copy(out[blockSize*3:], tmp1[:])
 
 		sz -= 4 * blockSize
-		in, out = in[64:], out[64:]
+		out = out[64:]
 		if (i+1)%8 == 0 {
 			doubleBlock(&I)
 		}

round_bitsliced64.go (+181 -0)

@@ -41,6 +41,27 @@ func (r *roundB64) AES4(j, i, l *[blockSize]byte, src []byte, dst *[blockSize]by
 	memwipeU64(q[:])
 }
 
+func (r *roundB64) aes4x4(
+	j0, i0, l0 *[blockSize]byte, src0 []byte, dst0 *[blockSize]byte,
+	j1, i1, l1 *[blockSize]byte, src1 []byte, dst1 *[blockSize]byte,
+	j2, i2, l2 *[blockSize]byte, src2 []byte, dst2 *[blockSize]byte,
+	j3, i3, l3 *[blockSize]byte, src3 []byte, dst3 *[blockSize]byte) {
+	var q [8]uint64
+	xorBytes4x16(j0[:], i0[:], l0[:], src0, dst0[:])
+	xorBytes4x16(j1[:], i1[:], l1[:], src1, dst1[:])
+	xorBytes4x16(j2[:], i2[:], l2[:], src2, dst2[:])
+	xorBytes4x16(j3[:], i3[:], l3[:], src3, dst3[:])
+
+	r.Load16xU32(&q, dst0[:], dst1[:], dst2[:], dst3[:])
+	r.round(&q, r.skey[8:])  // J
+	r.round(&q, r.skey[0:])  // I
+	r.round(&q, r.skey[16:]) // L
+	r.round(&q, r.skey[24:]) // zero
+	r.Store16xU32(dst0[:], dst1[:], dst2[:], dst3[:], &q)
+
+	memwipeU64(q[:])
+}
+
 func (r *roundB64) AES10(l *[blockSize]byte, src []byte, dst *[blockSize]byte) {
 	var q [8]uint64
 	xorBytes1x16(src, l[:], dst[:])
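
The new aes4x4 above is the heart of this change: the bitsliced state q packs four 128-bit AES blocks into its eight uint64 words, so a single trip through the four round calls advances four AES4 computations at once. The intent, roughly (an illustrative fragment with made-up variable names, matching how Pass 1 below uses it):

	// Before: four independent AES4 calls, each paying for its own
	// bitslice load, four rounds, and store.
	r.AES4(j, i, l0, src0, &dst0)
	r.AES4(j, i, l1, src1, &dst1)
	r.AES4(j, i, l2, src2, &dst2)
	r.AES4(j, i, l3, src3, &dst3)

	// After: one call loads all four lanes, runs the J, I, L, zero
	// rounds once over the combined state, and stores four results.
	r.aes4x4(j, i, l0, src0, &dst0,
		j, i, l1, src1, &dst1,
		j, i, l2, src2, &dst2,
		j, i, l3, src3, &dst3)
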
@@ -64,6 +85,166 @@ func (r *roundB64) round(q *[8]uint64, k []uint64) {
 	r.AddRoundKey(q, k)
 }
 
+func (r *roundB64) aezCorePass1(e *eState, in, out []byte, X *[blockSize]byte, sz int) {
+	var tmp0, tmp1, tmp2, tmp3, I [blockSize]byte
+
+	copy(I[:], e.I[1][:])
+	i := 1
+
+	// Process 8 * 16 bytes at a time in a loop.
+	for mult := false; sz >= 8*blockSize; mult = !mult {
+		r.aes4x4(&e.J[0], &I, &e.L[(i+0)%8], in[blockSize:], &tmp0,
+			&e.J[0], &I, &e.L[(i+1)%8], in[blockSize*3:], &tmp1,
+			&e.J[0], &I, &e.L[(i+2)%8], in[blockSize*5:], &tmp2,
+			&e.J[0], &I, &e.L[(i+3)%8], in[blockSize*7:], &tmp3) // E(1,i) ... E(1,i+3)
+		xorBytes1x16(in[:], tmp0[:], out[:])
+		xorBytes1x16(in[blockSize*2:], tmp1[:], out[blockSize*2:])
+		xorBytes1x16(in[blockSize*4:], tmp2[:], out[blockSize*4:])
+		xorBytes1x16(in[blockSize*6:], tmp3[:], out[blockSize*6:])
+
+		r.aes4x4(&zero, &e.I[0], &e.L[0], out[:], &tmp0,
+			&zero, &e.I[0], &e.L[0], out[blockSize*2:], &tmp1,
+			&zero, &e.I[0], &e.L[0], out[blockSize*4:], &tmp2,
+			&zero, &e.I[0], &e.L[0], out[blockSize*6:], &tmp3) // E(0,0) x4
+		xorBytes1x16(in[blockSize:], tmp0[:], out[blockSize:])
+		xorBytes1x16(in[blockSize*3:], tmp1[:], out[blockSize*3:])
+		xorBytes1x16(in[blockSize*5:], tmp2[:], out[blockSize*5:])
+		xorBytes1x16(in[blockSize*7:], tmp3[:], out[blockSize*7:])
+
+		xorBytes1x16(out[blockSize:], X[:], X[:])
+		xorBytes1x16(out[blockSize*3:], X[:], X[:])
+		xorBytes1x16(out[blockSize*5:], X[:], X[:])
+		xorBytes1x16(out[blockSize*7:], X[:], X[:])
+
+		sz -= 8 * blockSize
+		in, out = in[128:], out[128:]
+		if mult { // Multiply every other pass.
+			doubleBlock(&I)
+		}
+		i += 4
+	}
+
+	// XXX/performance: 4 * 16 bytes at a time.
+
+	for sz > 0 {
+		r.AES4(&e.J[0], &I, &e.L[i%8], in[blockSize:], &tmp0) // E(1,i)
+		xorBytes1x16(in[:], tmp0[:], out[:])
+		r.AES4(&zero, &e.I[0], &e.L[0], out[:], &tmp0) // E(0,0)
+		xorBytes1x16(in[blockSize:], tmp0[:], out[blockSize:])
+		xorBytes1x16(out[blockSize:], X[:], X[:])
+
+		sz -= 2 * blockSize
+		in, out = in[32:], out[32:]
+		if i%8 == 0 {
+			doubleBlock(&I)
+		}
+		i++
+	}
+
+	memwipe(tmp0[:])
+	memwipe(tmp1[:])
+	memwipe(tmp2[:])
+	memwipe(tmp3[:])
+	memwipe(I[:])
+}
+
+func (r *roundB64) aezCorePass2(e *eState, out []byte, Y, S *[blockSize]byte, sz int) {
+	var tmp0, tmp1, tmp2, tmp3, I [blockSize]byte
+
+	copy(I[:], e.I[1][:])
+	i := 1
+
+	// Process 8 * 16 bytes at a time in a loop.
+	for mult := false; sz >= 8*blockSize; mult = !mult {
+		r.aes4x4(&e.J[1], &I, &e.L[(i+0)%8], S[:], &tmp0,
+			&e.J[1], &I, &e.L[(i+1)%8], S[:], &tmp1,
+			&e.J[1], &I, &e.L[(i+2)%8], S[:], &tmp2,
+			&e.J[1], &I, &e.L[(i+3)%8], S[:], &tmp3) // E(2,i) .. E(2,i+3)
+		xorBytes1x16(out, tmp0[:], out[:])
+		xorBytes1x16(out[blockSize*2:], tmp1[:], out[blockSize*2:])
+		xorBytes1x16(out[blockSize*4:], tmp2[:], out[blockSize*4:])
+		xorBytes1x16(out[blockSize*6:], tmp3[:], out[blockSize*6:])
+		xorBytes1x16(out[blockSize:], tmp0[:], out[blockSize:])
+		xorBytes1x16(out[blockSize*3:], tmp1[:], out[blockSize*3:])
+		xorBytes1x16(out[blockSize*5:], tmp2[:], out[blockSize*5:])
+		xorBytes1x16(out[blockSize*7:], tmp3[:], out[blockSize*7:])
+		xorBytes1x16(out, Y[:], Y[:])
+		xorBytes1x16(out[blockSize*2:], Y[:], Y[:])
+		xorBytes1x16(out[blockSize*4:], Y[:], Y[:])
+		xorBytes1x16(out[blockSize*6:], Y[:], Y[:])
+
+		r.aes4x4(&zero, &e.I[0], &e.L[0], out[blockSize:], &tmp0,
+			&zero, &e.I[0], &e.L[0], out[blockSize*3:], &tmp1,
+			&zero, &e.I[0], &e.L[0], out[blockSize*5:], &tmp2,
+			&zero, &e.I[0], &e.L[0], out[blockSize*7:], &tmp3) // E(0,0)x4
+		xorBytes1x16(out, tmp0[:], out[:])
+		xorBytes1x16(out[blockSize*2:], tmp1[:], out[blockSize*2:])
+		xorBytes1x16(out[blockSize*4:], tmp2[:], out[blockSize*4:])
+		xorBytes1x16(out[blockSize*6:], tmp3[:], out[blockSize*6:])
+
+		r.aes4x4(&e.J[0], &I, &e.L[(i+0)%8], out[:], &tmp0,
+			&e.J[0], &I, &e.L[(i+1)%8], out[blockSize*2:], &tmp1,
+			&e.J[0], &I, &e.L[(i+2)%8], out[blockSize*4:], &tmp2,
+			&e.J[0], &I, &e.L[(i+3)%8], out[blockSize*6:], &tmp3) // E(1,i) ...  E(1,i+3)
+		xorBytes1x16(out[blockSize:], tmp0[:], out[blockSize:])
+		xorBytes1x16(out[blockSize*3:], tmp1[:], out[blockSize*3:])
+		xorBytes1x16(out[blockSize*5:], tmp2[:], out[blockSize*5:])
+		xorBytes1x16(out[blockSize*7:], tmp3[:], out[blockSize*7:])
+
+		copy(tmp0[:], out[:])
+		copy(tmp1[:], out[blockSize*2:])
+		copy(tmp2[:], out[blockSize*4:])
+		copy(tmp3[:], out[blockSize*6:])
+		copy(out[:blockSize], out[blockSize:])
+		copy(out[blockSize*2:blockSize*3], out[blockSize*3:])
+		copy(out[blockSize*4:blockSize*5], out[blockSize*5:])
+		copy(out[blockSize*6:blockSize*7], out[blockSize*7:])
+		copy(out[blockSize:], tmp0[:])
+		copy(out[blockSize*3:], tmp1[:])
+		copy(out[blockSize*5:], tmp2[:])
+		copy(out[blockSize*7:], tmp3[:])
+
+		sz -= 8 * blockSize
+		out = out[128:]
+		if mult { // Multiply every other pass.
+			doubleBlock(&I)
+		}
+		i += 4
+	}
+
+	// XXX/performance: 4 * 16 bytes at a time.
+
+	for sz > 0 {
+		r.AES4(&e.J[1], &I, &e.L[i%8], S[:], &tmp0) // E(2,i)
+		xorBytes1x16(out, tmp0[:], out[:])
+		xorBytes1x16(out[blockSize:], tmp0[:], out[blockSize:])
+		xorBytes1x16(out, Y[:], Y[:])
+
+		r.AES4(&zero, &e.I[0], &e.L[0], out[blockSize:], &tmp0) // E(0,0)
+		xorBytes1x16(out, tmp0[:], out[:])
+
+		r.AES4(&e.J[0], &I, &e.L[i%8], out[:], &tmp0) // E(1,i)
+		xorBytes1x16(out[blockSize:], tmp0[:], out[blockSize:])
+
+		copy(tmp0[:], out[:])
+		copy(out[:blockSize], out[blockSize:])
+		copy(out[blockSize:], tmp0[:])
+
+		sz -= 2 * blockSize
+		out = out[32:]
+		if i%8 == 0 {
+			doubleBlock(&I)
+		}
+		i++
+	}
+
+	memwipe(tmp0[:])
+	memwipe(tmp1[:])
+	memwipe(tmp2[:])
+	memwipe(tmp3[:])
+	memwipe(I[:])
+}
+
 func memwipeU64(s []uint64) {
 	for i := range s {
 		s[i] = 0
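
One subtlety shared by the two bulk loops in this file: the reference code doubles I after every eighth block index, while the bitsliced main loops consume four indices per iteration, so the mult flag doubles I on every other iteration instead. The two schedules line up; a small self-contained check (illustrative only, with doubleBlock replaced by recording the index at which a doubling occurs):

	package main

	import "fmt"

	func main() {
		// Reference schedule: one index per step, double after i = 8, 16, ...
		var ref []int
		for i := 1; i <= 16; i++ {
			if i%8 == 0 {
				ref = append(ref, i)
			}
		}

		// Bitsliced schedule: four indices per step, mult toggles each
		// iteration, so I doubles on every other step.
		var bs []int
		i := 1
		for mult := false; i <= 16; mult = !mult {
			if mult {
				bs = append(bs, i+3) // last index covered this step
			}
			i += 4
		}

		fmt.Println(ref, bs) // [8 16] [8 16]
	}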