Browse Source

Make the Encrypt/Decrypt API more like the AEAD one.

Now takes a `dst` slice to append do, won't allocate if there's
sufficient capacity.  Unlike the AEAD api, the dst and
plaintext/ciphertext can't overlap, because the destination is used to
store intermediary values.
Yawning Angel 2 years ago
parent
commit
deddc61482
3 changed files with 50 additions and 20 deletions
  1. 4 2
      aead.go
  2. 36 12
      aez.go
  3. 10 6
      aez_test.go

+ 4 - 2
aead.go

@@ -65,7 +65,8 @@ func (a *AeadAEZ) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
 	if additionalData != nil {
 		ad = append(ad, additionalData)
 	}
-	c := Encrypt(a.key[:], nonce, ad, aeadOverhead, plaintext)
+	// WARNING: The AEAD interface expects plaintext/dst overlap to be allowed.
+	c := Encrypt(a.key[:], nonce, ad, aeadOverhead, plaintext, nil)
 	dst = append(dst, c...)
 
 	return dst
@@ -85,7 +86,8 @@ func (a *AeadAEZ) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, e
 	if additionalData != nil {
 		ad = append(ad, additionalData)
 	}
-	d, ok := Decrypt(a.key[:], nonce, ad, aeadOverhead, ciphertext)
+	// WARNING: The AEAD interface expects ciphertext/dst overlap to be allowed.
+	d, ok := Decrypt(a.key[:], nonce, ad, aeadOverhead, ciphertext, nil)
 	if !ok {
 		return nil, errOpen
 	}

+ 36 - 12
aez.go

@@ -477,11 +477,22 @@ func (e *eState) decipher(delta *[blockSize]byte, in, out []byte) {
 }
 
 // Encrypt encrypts and authenticates the plaintext, authenticates the
-// additional data, and returns the resulting ciphertext.  The length
-// of the authentication tag in bytes is specified by tau.
-func Encrypt(key []byte, nonce []byte, additionalData [][]byte, tau int, plaintext []byte) []byte {
+// additional data, and appends the result to ciphertext, returning the
+// updated slice.  The length of the authentication tag in bytes is specified
+// by tau.  The plaintext and dst slices MUST NOT overlap.
+func Encrypt(key []byte, nonce []byte, additionalData [][]byte, tau int, plaintext, dst []byte) []byte {
 	var delta [blockSize]byte
-	x := make([]byte, tau+len(plaintext))
+
+	var x []byte
+	dstSz, xSz := len(dst), len(plaintext)+tau
+	if cap(dst) >= dstSz+xSz {
+		dst = dst[:dstSz+xSz]
+	} else {
+		x = make([]byte, dstSz+xSz)
+		copy(x, dst)
+		dst = x
+	}
+	x = dst[dstSz:]
 
 	var e eState
 	defer e.reset()
@@ -491,25 +502,38 @@ func Encrypt(key []byte, nonce []byte, additionalData [][]byte, tau int, plainte
 	if len(plaintext) == 0 {
 		e.aezPRF(&delta, tau, x)
 	} else {
+		memwipe(x[len(plaintext):])
 		copy(x, plaintext)
 		e.encipher(&delta, x, x)
 	}
 
-	return x
+	return dst
 }
 
 // Decrypt decrypts and authenticates the ciphertext, authenticates the
-// additional data, and if successful returns the plaintext and true.  The
-// length of the expected authentication tag in bytes is specified by tau.
-func Decrypt(key []byte, nonce []byte, additionalData [][]byte, tau int, ciphertext []byte) ([]byte, bool) {
+// additional data, and if successful appends the resulting plaintext to the
+// provided slice and returns the updated slice and true.  The length of the
+// expected authentication tag in bytes is specified by tau.  The ciphertext
+// and dst slices MUST NOT overlap.
+func Decrypt(key []byte, nonce []byte, additionalData [][]byte, tau int, ciphertext, dst []byte) ([]byte, bool) {
 	var delta [blockSize]byte
 	sum := byte(0)
-	x := make([]byte, len(ciphertext))
 
 	if len(ciphertext) < tau {
 		return nil, false
 	}
 
+	var x []byte
+	dstSz, xSz := len(dst), len(ciphertext)
+	if cap(dst) >= dstSz+xSz {
+		dst = dst[:dstSz+xSz]
+	} else {
+		x = make([]byte, dstSz+xSz)
+		copy(x, dst)
+		dst = x
+	}
+	x = dst[dstSz:]
+
 	var e eState
 	defer e.reset()
 
@@ -520,20 +544,20 @@ func Decrypt(key []byte, nonce []byte, additionalData [][]byte, tau int, ciphert
 		for i := 0; i < tau; i++ {
 			sum |= x[i] ^ ciphertext[i]
 		}
-		x = nil
+		dst = dst[:dstSz]
 	} else {
 		e.decipher(&delta, ciphertext, x)
 		for i := 0; i < tau; i++ {
 			sum |= x[len(ciphertext)-tau+i]
 		}
 		if sum == 0 {
-			x = x[:len(ciphertext)-tau]
+			dst = dst[:dstSz+len(ciphertext)-tau]
 		}
 	}
 	if sum != 0 { // return true if valid, false if invalid
 		return nil, false
 	}
-	return x, true
+	return dst, true
 }
 
 func memwipe(b []byte) {

+ 10 - 6
aez_test.go

@@ -1372,17 +1372,18 @@ func TestEncryptDecrypt(t *testing.T) {
 		}
 
 		e.init(vecK)
-		c := Encrypt(vecK, vecNonce, vecData, vec.tau, vecM)
+		c := Encrypt(vecK, vecNonce, vecData, vec.tau, vecM, nil)
 		assertEqual(t, i, vecC, c)
 		if aead != nil {
 			ac := aead.Seal(nil, vecNonce, vecM, ad)
 			assertEqual(t, i, vecC, ac)
 		}
 
-		m, ok := Decrypt(vecK, vecNonce, vecData, vec.tau, vecC)
+		m, ok := Decrypt(vecK, vecNonce, vecData, vec.tau, vecC, nil)
 		if !ok {
 			t.Fatalf("decrypt failed: [%d]", i)
 		}
+		assertEqual(t, i, vecM, m)
 		if aead != nil {
 			am, err := aead.Open(nil, vecNonce, vecC, ad)
 			if err != nil {
@@ -1390,7 +1391,6 @@ func TestEncryptDecrypt(t *testing.T) {
 			}
 			assertEqual(t, i, vecM, am)
 		}
-		assertEqual(t, i, vecM, m)
 	}
 }
 
@@ -1417,23 +1417,27 @@ func doBenchEncrypt(b *testing.B, n int) {
 		b.Fail()
 	}
 
+	const tau = 16
+
 	var nonce [16]byte
 	src := make([]byte, n)
+	dst := make([]byte, n+tau)
+	check := make([]byte, n+tau)
 
 	b.SetBytes(int64(n))
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		b.StartTimer()
-		dst := Encrypt(key[:], nonce[:], nil, 16, src[:n])
+		dst = Encrypt(key[:], nonce[:], nil, tau, src[:n], dst[:0])
 		b.StopTimer()
-		dec, ok := Decrypt(key[:], nonce[:], nil, 16, dst)
+		dec, ok := Decrypt(key[:], nonce[:], nil, tau, dst, check[:0])
 		if !ok {
 			b.Fatalf("decrypt failed")
 		}
 		if !bytes.Equal(dec, src) {
 			b.Fatalf("decrypt produced invalid output")
 		}
-		src = dec
+		copy(src, dst[:n])
 	}
 
 	benchOutput = src