Browse Source

Enforce the minimum padding length at the handshake obfuscation layer.

Yawning Angel 4 years ago
parent
commit
76f6d342bc
2 changed files with 17 additions and 13 deletions
  1. 11 13
      handshake/handshake_test.go
  2. 6 0
      handshake/obfuscation.go

+ 11 - 13
handshake/handshake_test.go

@@ -24,6 +24,7 @@ import (
 	"net"
 	"sync"
 	"testing"
+	"time"
 
 	"git.schwanenlied.me/yawning/basket2.git/crypto/identity"
 )
@@ -60,13 +61,16 @@ func (s *testState) aliceRoutine() {
 	defer s.Done()
 	defer s.alicePipe.Close()
 
+	s.alicePipe.SetDeadline(time.Now().Add(5 * time.Second))
+
 	hs, err := NewClientHandshake(rand.Reader, s.kexMethod, &s.bobKeypair.PublicKey)
 	if err != nil {
 		s.aliceCh <- err
 		return
 	}
 
-	k, extData, err := hs.Handshake(s.aliceRw, clientExtData, 0)
+	padLen := MinHandshakeSize - (MessageSize + len(clientExtData))
+	k, extData, err := hs.Handshake(s.aliceRw, clientExtData, padLen)
 	if err != nil {
 		s.aliceCh <- err
 		return
@@ -84,6 +88,8 @@ func (s *testState) bobRoutine() {
 	defer s.Done()
 	defer s.bobPipe.Close()
 
+	s.bobPipe.SetDeadline(time.Now().Add(5 * time.Second))
+
 	hs, err := NewServerHandshake(rand.Reader, kexMethods, s.replay, s.bobKeypair)
 	if err != nil {
 		s.bobCh <- err
@@ -101,7 +107,8 @@ func (s *testState) bobRoutine() {
 		return
 	}
 
-	k, err := hs.SendHandshakeResp(s.bobRw, serverExtData, 0)
+	padLen := MinHandshakeSize - (MessageSize + len(serverExtData))
+	k, err := hs.SendHandshakeResp(s.bobRw, serverExtData, padLen)
 	if err != nil {
 		s.bobCh <- err
 		return
@@ -138,19 +145,10 @@ func (s *testState) oneIter() error {
 	}
 
 	// Sanity check to ensure that the amount of data sent on the wire each
-	// way is identical, after the difference in test vector length is
-	// accounted for (Never let ease of test writing detract from using cool
-	// test vectors).
-	lenDiff := uint64(len(serverExtData) - len(clientExtData))
-	if s.aliceRw.bytesWrite+lenDiff != s.bobRw.bytesWrite {
+	// way is identical.
+	if s.aliceRw.bytesWrite != s.bobRw.bytesWrite {
 		return fmt.Errorf("bytes written count mismatch")
 	}
-	if int(s.aliceRw.bytesWrite)-len(clientExtData) != MessageSize {
-		return fmt.Errorf("client bytes mismatch: %v", s.aliceRw.bytesWrite)
-	}
-	if int(s.bobRw.bytesWrite)-len(serverExtData) != MessageSize {
-		return fmt.Errorf("server bytes mismatch: %v", s.bobRw.bytesWrite)
-	}
 
 	return nil
 }

+ 6 - 0
handshake/obfuscation.go

@@ -154,6 +154,9 @@ func (o *clientObfsCtx) handshake(rw io.ReadWriter, msg []byte, padLen int) ([]b
 		// This is technically valid, but is stupid, so disallow it.
 		return nil, ErrNoPayload
 	}
+	if want < MinHandshakeSize-obfsServerOverhead {
+		return nil, ErrInvalidPayload
+	}
 	if want > MaxHandshakeSize-obfsServerOverhead {
 		return nil, ErrInvalidPayload
 	}
@@ -332,6 +335,9 @@ func (o *serverObfsCtx) recvHandshakeReq(r io.Reader) ([]byte, error) {
 		// This is technically valid, but is stupid, so disallow it.
 		return nil, ErrNoPayload
 	}
+	if want < MinHandshakeSize-obfsClientOverhead {
+		return nil, ErrInvalidPayload
+	}
 	if want > MaxHandshakeSize-obfsClientOverhead {
 		return nil, ErrInvalidPayload
 	}