Browse Source

Support 96 bit nonces, and a `c.Seek(counter)` method.

This is the IETF RFC 7539 style ChaCha20 with a 96 bit nonce and 32 bit
counter, used for constructing `AEAD_CHACHA20_POLY1305` and HS1-SIV
amongst other things.

The code will panic when the 32 bit counter wraps, because that's the
easy thing to do, and 2^32 blocks per key/nonce should be sufficient
for anybody, though I reserve the right to change this behavior as
needed at a later date.  Poly1305 only allows authenticating
2^64 bytes anyway so in practice this should never be a concern.
Yawning Angel 3 years ago
parent
commit
f3a398b735
4 changed files with 108 additions and 16 deletions
  1. 42 10
      chacha20.go
  2. 14 1
      chacha20_amd64.go
  3. 10 1
      chacha20_ref.go
  4. 42 4
      chacha20_test.go

+ 42 - 10
chacha20.go

@@ -11,6 +11,7 @@ import (
 	"crypto/cipher"
 	"encoding/binary"
 	"errors"
+	"math"
 	"runtime"
 	"unsafe"
 )
@@ -22,6 +23,9 @@ const (
 	// NonceSize is the ChaCha20 nonce size in bytes.
 	NonceSize = 8
 
+	// INonceSize is the IETF ChaCha20 nonce size in bytes.
+	INonceSize = 12
+
 	// XNonceSize is the XChaCha20 nonce size in bytes.
 	XNonceSize = 24
 
@@ -46,7 +50,10 @@ var (
 	ErrInvalidKey = errors.New("key length must be KeySize bytes")
 
 	// ErrInvalidNonce is the error returned when the nonce is invalid.
-	ErrInvalidNonce = errors.New("nonce length must be NonceSize/XNonceSize bytes")
+	ErrInvalidNonce = errors.New("nonce length must be NonceSize/INonceSize/XNonceSize bytes")
+
+	// ErrInvalidCounter is the error returned when the counter is invalid.
+	ErrInvalidCounter = errors.New("block counter is invalid (out of range)")
 
 	useUnsafe    = false
 	usingVectors = false
@@ -58,8 +65,9 @@ var (
 type Cipher struct {
 	state [stateSize]uint32
 
-	buf [BlockSize]byte
-	off int
+	buf  [BlockSize]byte
+	off  int
+	ietf bool
 }
 
 // Reset zeros the key data so that it will no longer appear in the process's
@@ -86,7 +94,7 @@ func (c *Cipher) XORKeyStream(dst, src []byte) {
 			nrBlocks := remaining / BlockSize
 			directBytes := nrBlocks * BlockSize
 			if nrBlocks > 0 {
-				blocksFn(&c.state, src, dst, nrBlocks)
+				blocksFn(&c.state, src, dst, nrBlocks, c.ietf)
 				remaining -= directBytes
 				if remaining == 0 {
 					return
@@ -97,7 +105,7 @@ func (c *Cipher) XORKeyStream(dst, src []byte) {
 
 			// If there's a partial block, generate 1 block of keystream into
 			// the internal buffer.
-			blocksFn(&c.state, nil, c.buf[:], 1)
+			blocksFn(&c.state, nil, c.buf[:], 1, c.ietf)
 			c.off = 0
 		}
 
@@ -127,7 +135,7 @@ func (c *Cipher) KeyStream(dst []byte) {
 			nrBlocks := remaining / BlockSize
 			directBytes := nrBlocks * BlockSize
 			if nrBlocks > 0 {
-				blocksFn(&c.state, nil, dst, nrBlocks)
+				blocksFn(&c.state, nil, dst, nrBlocks, c.ietf)
 				remaining -= directBytes
 				if remaining == 0 {
 					return
@@ -137,7 +145,7 @@ func (c *Cipher) KeyStream(dst []byte) {
 
 			// If there's a partial block, generate 1 block of keystream into
 			// the internal buffer.
-			blocksFn(&c.state, nil, c.buf[:], 1)
+			blocksFn(&c.state, nil, c.buf[:], 1, c.ietf)
 			c.off = 0
 		}
 
@@ -164,6 +172,7 @@ func (c *Cipher) ReKey(key, nonce []byte) error {
 
 	switch len(nonce) {
 	case NonceSize:
+	case INonceSize:
 	case XNonceSize:
 		var subkey [KeySize]byte
 		var subnonce [HNonceSize]byte
@@ -190,14 +199,37 @@ func (c *Cipher) ReKey(key, nonce []byte) error {
 	c.state[6] = binary.LittleEndian.Uint32(key[24:28])
 	c.state[7] = binary.LittleEndian.Uint32(key[28:32])
 	c.state[8] = 0
-	c.state[9] = 0
-	c.state[10] = binary.LittleEndian.Uint32(nonce[0:4])
-	c.state[11] = binary.LittleEndian.Uint32(nonce[4:8])
+	if len(nonce) == INonceSize {
+		c.state[9] = binary.LittleEndian.Uint32(nonce[0:4])
+		c.state[10] = binary.LittleEndian.Uint32(nonce[4:8])
+		c.state[11] = binary.LittleEndian.Uint32(nonce[8:12])
+		c.ietf = true
+	} else {
+		c.state[9] = 0
+		c.state[10] = binary.LittleEndian.Uint32(nonce[0:4])
+		c.state[11] = binary.LittleEndian.Uint32(nonce[4:8])
+		c.ietf = false
+	}
 	c.off = BlockSize
 	return nil
 
 }
 
+// Seek sets the block counter to a given offset.
+func (c *Cipher) Seek(blockCounter uint64) error {
+	if c.ietf {
+		if blockCounter > math.MaxUint32 {
+			return ErrInvalidCounter
+		}
+		c.state[8] = uint32(blockCounter)
+	} else {
+		c.state[8] = uint32(blockCounter)
+		c.state[9] = uint32(blockCounter >> 32)
+	}
+	c.off = BlockSize
+	return nil
+}
+
 // NewCipher returns a new ChaCha20/XChaCha20 instance.
 func NewCipher(key, nonce []byte) (*Cipher, error) {
 	c := new(Cipher)

+ 14 - 1
chacha20_amd64.go

@@ -6,8 +6,13 @@
 // <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
 
 // +build amd64,!gccgo,!appengine
+
 package chacha20
 
+import (
+	"math"
+)
+
 func blocksAmd64SSE2(sigma, one, x *uint32, in, out *byte, nrBlocks uint)
 
 // One day these won't be parameters when PeachPy fixes issue #11, and they
@@ -17,7 +22,15 @@ func blocksAmd64SSE2(sigma, one, x *uint32, in, out *byte, nrBlocks uint)
 var one = [4]uint32{1, 0, 0, 0}
 var sigma = [4]uint32{sigma0, sigma1, sigma2, sigma3}
 
-func blocksAmd64(x *[stateSize]uint32, in []byte, out []byte, nrBlocks int) {
+func blocksAmd64(x *[stateSize]uint32, in []byte, out []byte, nrBlocks int, isIetf bool) {
+	if isIetf {
+		var totalBlocks uint64
+		totalBlocks = uint64(x[8]) + uint64(nrBlocks)
+		if totalBlocks > math.MaxUint32 {
+			panic("chacha20: Exceeded keystream per nonce limit")
+		}
+	}
+
 	if in == nil {
 		for i := range out {
 			out[i] = 0

+ 10 - 1
chacha20_ref.go

@@ -9,10 +9,19 @@ package chacha20
 
 import (
 	"encoding/binary"
+	"math"
 	"unsafe"
 )
 
-func blocksRef(x *[stateSize]uint32, in []byte, out []byte, nrBlocks int) {
+func blocksRef(x *[stateSize]uint32, in []byte, out []byte, nrBlocks int, isIetf bool) {
+	if isIetf {
+		var totalBlocks uint64
+		totalBlocks = uint64(x[8]) + uint64(nrBlocks)
+		if totalBlocks > math.MaxUint32 {
+			panic("chacha20: Exceeded keystream per nonce limit")
+		}
+	}
+
 	for n := 0; n < nrBlocks; n++ {
 		x0, x1, x2, x3 := sigma0, sigma1, sigma2, sigma3
 		x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 := x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10], x[11]

+ 42 - 4
chacha20_test.go

@@ -16,10 +16,11 @@ import (
 // Test vectors taken from:
 // https://tools.ietf.org/html/draft-strombergson-chacha-test-vectors-01
 var draftTestVectors = []struct {
-	name   string
-	key    []byte
-	iv     []byte
-	stream []byte
+	name       string
+	key        []byte
+	iv         []byte
+	stream     []byte
+	seekOffset uint64
 }{
 	{
 		name: "IETF Draft: TC1: All zero key and IV.",
@@ -279,6 +280,37 @@ var draftTestVectors = []struct {
 			0xe4, 0x18, 0x3a,
 		},
 	},
+	{
+		name: "RFC 7539 Test Vector (96 bit nonce)",
+		key: []byte{
+			0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+			0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
+			0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
+			0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
+		},
+		iv: []byte{
+			0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4a,
+			0x00, 0x00, 0x00, 0x00,
+		},
+		stream: []byte{
+			0x22, 0x4f, 0x51, 0xf3, 0x40, 0x1b, 0xd9, 0xe1,
+			0x2f, 0xde, 0x27, 0x6f, 0xb8, 0x63, 0x1d, 0xed,
+			0x8c, 0x13, 0x1f, 0x82, 0x3d, 0x2c, 0x06, 0xe2,
+			0x7e, 0x4f, 0xca, 0xec, 0x9e, 0xf3, 0xcf, 0x78,
+			0x8a, 0x3b, 0x0a, 0xa3, 0x72, 0x60, 0x0a, 0x92,
+			0xb5, 0x79, 0x74, 0xcd, 0xed, 0x2b, 0x93, 0x34,
+			0x79, 0x4c, 0xba, 0x40, 0xc6, 0x3e, 0x34, 0xcd,
+			0xea, 0x21, 0x2c, 0x4c, 0xf0, 0x7d, 0x41, 0xb7,
+			0x69, 0xa6, 0x74, 0x9f, 0x3f, 0x63, 0x0f, 0x41,
+			0x22, 0xca, 0xfe, 0x28, 0xec, 0x4d, 0xc4, 0x7e,
+			0x26, 0xd4, 0x34, 0x6d, 0x70, 0xb9, 0x8c, 0x73,
+			0xf3, 0xe9, 0xc5, 0x3a, 0xc4, 0x0c, 0x59, 0x45,
+			0x39, 0x8b, 0x6e, 0xda, 0x1a, 0x83, 0x2c, 0x89,
+			0xc1, 0x67, 0xea, 0xcd, 0x90, 0x1d, 0x7e, 0x2b,
+			0xf3, 0x63,
+		},
+		seekOffset: 1,
+	},
 }
 
 func TestChaCha20(t *testing.T) {
@@ -288,6 +320,12 @@ func TestChaCha20(t *testing.T) {
 			t.Errorf("[%s]: New(k, iv) returned: %s", v.name, err)
 			continue
 		}
+		if v.seekOffset != 0 {
+			if err = c.Seek(v.seekOffset); err != nil {
+				t.Errorf("[%s]: Seek(seekOffset) returned: %s", v.name, err)
+				continue
+			}
+		}
 		out := make([]byte, len(v.stream))
 		c.XORKeyStream(out, out)
 		if !bytes.Equal(out, v.stream) {