Browse Source

Clamp the various read sizes to plausible/sensible values.

Yawning Angel 4 years ago
parent
commit
0a0b910989
5 changed files with 36 additions and 18 deletions
  1. 2 2
      client.go
  2. 15 9
      common.go
  3. 11 5
      handshake/handshake.go
  4. 6 0
      handshake/obfuscation.go
  5. 2 2
      server.go

+ 2 - 2
client.go

@@ -72,11 +72,11 @@ func (c *ClientConn) Handshake(conn net.Conn) (err error) {
 	// adding a random amount of padding.
 	//
 	// All requests on the wire will be of length [min, max).
-	padLen := minHandshakeSize - (handshake.MessageSize + len(reqExtData))
+	padLen := handshake.MinHandshakeSize - (handshake.MessageSize + len(reqExtData))
 	if padLen < 0 { // Should never happen.
 		panic("basket2: handshake request exceeds payload capacity")
 	}
-	padLen += c.mRNG.Intn(maxHandshakeSize - minHandshakeSize)
+	padLen += c.mRNG.Intn(handshake.MaxHandshakeSize - handshake.MinHandshakeSize)
 
 	// Send the request, receive the response, and derive the session keys.
 	var keys *handshake.SessionKeys

+ 15 - 9
common.go

@@ -41,9 +41,6 @@ const (
 	// ProtocolVersion is the transport protocol version.
 	ProtocolVersion = 0
 
-	// XXX: Should I adjust these??  Maybe make them framesize based?
-	minHandshakeSize                 = 4096
-	maxHandshakeSize                 = 8192
 	minReqExtDataSize                = 1 + 1 + 1 // Version, nrPaddingAlgs, > 1 padding alg.
 	minRespExtDataSize               = 1 + 1 + 1 // Version, authPolicy, padding alg.
 	paddingInvalid     PaddingMethod = 0xff
@@ -62,6 +59,9 @@ var (
 	// padding method.
 	ErrInvalidPadding = errors.New("basket2: invalid padding")
 
+	// ErrMsgSize is the error returned on a message size violation.
+	ErrMsgSize = errors.New("basket2: oversized message")
+
 	// ErrInvalidExtData is the error returned when the req/resp handshake
 	// extData is invalid.
 	ErrInvalidExtData = errors.New("basket2: invalid ext data")
@@ -106,13 +106,16 @@ type commonConn struct {
 	mRNG  *mrand.Rand
 	state connState
 
-	rawConn  net.Conn
-	isClient bool
+	rawConn net.Conn
+
+	txEncoder *tentp.Encoder
+	rxDecoder *tentp.Decoder
+	impl      paddingImpl
 
-	txEncoder     *tentp.Encoder
-	rxDecoder     *tentp.Decoder
-	impl          paddingImpl
-	maxRecordSize int
+	maxRecordSize     int
+	enforceRecordSize bool
+
+	isClient bool
 }
 
 // Conn returns the raw underlying net.Conn associated with the basket2
@@ -384,6 +387,9 @@ func (c *commonConn) RecvRawRecord() (cmd byte, msg []byte, err error) {
 		// Record with no payload, return early.
 		return
 	}
+	if c.enforceRecordSize && want > c.maxRecordSize {
+		return 0, nil, ErrMsgSize
+	}
 
 	// Receive/Decode the TENTP record body.
 	recBody := make([]byte, want)

+ 11 - 5
handshake/handshake.go

@@ -41,6 +41,12 @@ const (
 	// without including user extData or padding (2146 bytes).
 	MessageSize = x448RespSize + obfsServerOverhead
 
+	// MinHandshakeSize is the minimum total handshake length.
+	MinHandshakeSize = 4096
+
+	// MaxHandshakeSize is the maximum total handshake length.
+	MaxHandshakeSize = 8192
+
 	minReqSize = 1 + 1
 )
 
@@ -199,8 +205,8 @@ type ServerHandshake struct {
 // RecvHandshakeReq receives and validates the client's handshake request and
 // returns the client's extData if any.  Callers are responsible for setting
 // timeouts as appropriate.
-func (s *ServerHandshake) RecvHandshakeReq(rw io.ReadWriter) ([]byte, error) {
-	reqBlob, err := s.obfs.recvHandshakeReq(rw)
+func (s *ServerHandshake) RecvHandshakeReq(r io.Reader) ([]byte, error) {
+	reqBlob, err := s.obfs.recvHandshakeReq(r)
 	if err != nil {
 		return nil, err
 	}
@@ -233,14 +239,14 @@ func (s *ServerHandshake) RecvHandshakeReq(rw io.ReadWriter) ([]byte, error) {
 // extData is encrypted/authenticated without PFS.  Callers are responsible
 // for setting timeouts as appropriate.  Upon return, Reset will be called
 // automatically.
-func (s *ServerHandshake) SendHandshakeResp(rw io.ReadWriter, extData []byte, padLen int) (*SessionKeys, error) {
+func (s *ServerHandshake) SendHandshakeResp(w io.Writer, extData []byte, padLen int) (*SessionKeys, error) {
 	defer s.Reset()
 
 	switch s.kexMethod {
 	case X25519NewHope:
-		return s.sendRespX25519(rw, extData, padLen)
+		return s.sendRespX25519(w, extData, padLen)
 	case X448NewHope:
-		return s.sendRespX448(rw, extData, padLen)
+		return s.sendRespX448(w, extData, padLen)
 	default:
 		return nil, ErrInvalidKEXMethod
 	}

+ 6 - 0
handshake/obfuscation.go

@@ -153,6 +153,9 @@ func (o *clientObfsCtx) handshake(rw io.ReadWriter, msg []byte, padLen int) ([]b
 		// This is technically valid, but is stupid, so disallow it.
 		return nil, ErrNoPayload
 	}
+	if want > MaxHandshakeSize-obfsServerOverhead {
+		return nil, ErrInvalidPayload
+	}
 
 	// Receive/Decode the peer's response payload body.
 	//
@@ -324,6 +327,9 @@ func (o *serverObfsCtx) recvHandshakeReq(r io.Reader) ([]byte, error) {
 		// This is technically valid, but is stupid, so disallow it.
 		return nil, ErrNoPayload
 	}
+	if want > MaxHandshakeSize-obfsClientOverhead {
+		return nil, ErrInvalidPayload
+	}
 
 	// Read/Decode client request body.
 	reqBody := make([]byte, want)

+ 2 - 2
server.go

@@ -110,11 +110,11 @@ func (s *ServerConn) Handshake(conn net.Conn) (err error) {
 	// Determine the response padding length by adding padding required to
 	// bring the response size up to the minimum target length, and then
 	// adding a random amount of padding.
-	padLen := minHandshakeSize - (handshake.MessageSize + len(respExtData))
+	padLen := handshake.MinHandshakeSize - (handshake.MessageSize + len(respExtData))
 	if padLen < 0 { // Should never happen.
 		panic("basket2: handshake response exceeds payload capacity")
 	}
-	padLen += s.mRNG.Intn(maxHandshakeSize - minHandshakeSize)
+	padLen += s.mRNG.Intn(handshake.MaxHandshakeSize - handshake.MinHandshakeSize)
 
 	// Send the handshake response and derive the session keys.
 	var keys *handshake.SessionKeys