Browse Source

Add code neccecary to handle io.CopyBuffer().

Yawning Angel 3 years ago
parent
commit
04ca627b5e
2 changed files with 15 additions and 2 deletions
  1. 14 1
      common.go
  2. 1 1
      padding_obfs4.go

+ 14 - 1
common.go

@@ -47,7 +47,8 @@ const (
 	minReqExtDataSize  = 1 + 1 + 1 // Version, nrPaddingAlgs, > 1 padding alg.
 	minRespExtDataSize = 1 + 1 + 1 // Version, authPolicy, padding alg.
 
-	tauReadDelay = 5000 // Microseconds.
+	tauReadDelay          = 5000 // Microseconds.
+	defaultCopyBufferSize = 32 * 1024
 )
 
 var (
@@ -181,6 +182,7 @@ type commonConn struct {
 	paddingMethod PaddingMethod
 
 	maxRecordSize     int
+	copyBufferSize    int
 	enforceRecordSize bool
 	enableReadDelay   bool
 
@@ -192,6 +194,16 @@ func (c *commonConn) Stats() *ConnStats {
 	return &c.stats
 }
 
+// SetCopyBufferSize sets the hint used to detect large bulk transfers
+// when the connection is the destination side of io.Copy()/io.CopyBuffer().
+// By default something sensible for io.Copy() will be used.
+func (c *commonConn) SetCopyBufferSize(sz int) {
+	if sz <= 0 {
+		panic("basket2: SetCopyBufferSize called with invalid value")
+	}
+	c.copyBufferSize = sz
+}
+
 // Write writes len(p) bytes to the stream, and returns the number of bytes
 // written, or an error.  All errors must be considered fatal.
 func (c *commonConn) Write(p []byte) (n int, err error) {
@@ -276,6 +288,7 @@ func (c *commonConn) initConn(conn net.Conn) error {
 
 	c.paddingMethod = PaddingInvalid
 	c.mRNG = rand.New()
+	c.copyBufferSize = defaultCopyBufferSize
 
 	// Derive the "max" record size based off the remote address,
 	// under the assumption that 1500 byte MTU ethernet is in use.

+ 1 - 1
padding_obfs4.go

@@ -113,7 +113,7 @@ func (p *obfs4Padding) largeWrite(b []byte) (n int, err error) {
 	// enabled).
 
 	remaining := len(b)
-	isLargeWrite := remaining >= 32*1024 // XXX: What about CopyBuffer?
+	isLargeWrite := remaining >= p.conn.copyBufferSize
 
 	tailPadLen := p.burstDist.Sample(p.conn.mRNG)
 	// tailPadLen += c.conn.maxRecordSize * c.conn.mRNG.Intn(3)