Browse Source

Added checksums and tests

Ryan Armstrong 4 years ago
parent
commit
899c12233e
4 changed files with 85 additions and 7 deletions
  1. 54 0
      client_test.go
  2. 7 0
      error.go
  3. 0 2
      request.go
  4. 24 5
      response.go

+ 54 - 0
client_test.go

@@ -2,6 +2,7 @@ package grab
 
 import (
 	"bufio"
+	"encoding/hex"
 	"fmt"
 	"net/http"
 	"net/http/httptest"
@@ -12,6 +13,8 @@ import (
 // ts is the test HTTP server instance initiated by TestMain().
 var ts *httptest.Server
 
+// TestMail starts a HTTP test server for all test cases to use as a download
+// source.
 func TestMain(m *testing.M) {
 	// start test HTTP server
 	ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -64,6 +67,8 @@ func TestMain(m *testing.M) {
 	os.Exit(m.Run())
 }
 
+// testFilename executes a request and asserts that the downloaded filename
+// matches the given filename.
 func testFilename(t *testing.T, req *Request, filename string) {
 	// fetch
 	resp, err := DefaultClient.Do(req)
@@ -82,18 +87,67 @@ func testFilename(t *testing.T, req *Request, filename string) {
 	}
 }
 
+// TestWithFilename asserts that the downloaded filename matches a filename
+// specified explicitely via Request.Filename, and not a name matching the
+// request URL or Content-Disposition header.
 func TestWithFilename(t *testing.T) {
 	req, _ := NewRequest(ts.URL + "/url-filename?filename=header-filename")
 	req.Filename = ".testWithFilename"
+
 	testFilename(t, req, req.Filename)
 }
 
+// TestWithHeaderFilename asserts that the downloaded filename matches a
+// filename specified explicitely via the Content-Disposition header and not a
+// name matching the request URL.
 func TestWithHeaderFilename(t *testing.T) {
 	req, _ := NewRequest(ts.URL + "/url-filename?filename=.testWithHeaderFilename")
 	testFilename(t, req, ".testWithHeaderFilename")
 }
 
+// TestWithURLFilename asserts that the downloaded filename matches the
+// requested URL.
 func TestWithURLFilename(t *testing.T) {
 	req, _ := NewRequest(ts.URL + "/.testWithURLFilename?params-filename")
 	testFilename(t, req, ".testWithURLFilename")
 }
+
+// testChecksum executes a request and asserts that the computed checksum for
+// the downloaded file does or does not match the expected checksum.
+func testChecksum(t *testing.T, size int, sum string, match bool) {
+	// create request
+	req, _ := NewRequest(ts.URL + fmt.Sprintf("?size=%d", size))
+	req.Filename = fmt.Sprintf(".testChecksum-%s", sum)
+
+	// set expected checksum
+	sumb, _ := hex.DecodeString(sum)
+	req.SetChecksum("sha256", sumb)
+
+	// fetch
+	resp, err := DefaultClient.Do(req)
+	if err != nil {
+		if !IsChecksumMismatch(err) {
+			t.Errorf("Error in Client.Do(): %v", err)
+		} else if match {
+			t.Errorf("%v (%v bytes)", err, size)
+		}
+	} else if !match {
+		t.Errorf("Expected checksum mismatch but comparison succeeded (%v bytes)", size)
+	}
+
+	// delete downloaded file
+	if err := os.Remove(resp.Filename); err != nil {
+		t.Errorf("Error deleting test file: %v", err)
+	}
+}
+
+// TestChecksums executes a number of checksum tests via testChecksum.
+func TestChecksums(t *testing.T) {
+	testChecksum(t, 128, "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be5", true)
+	testChecksum(t, 1024, "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c9", true)
+	testChecksum(t, 1048576, "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83", true)
+
+	testChecksum(t, 128, "00112233", false)
+	testChecksum(t, 1024, "00112233", false)
+	testChecksum(t, 1048576, "00112233", false)
+}

+ 7 - 0
error.go

@@ -7,6 +7,7 @@ import (
 const (
 	errBadLength = iota
 	errNoFilename
+	errChecksumMismatch
 )
 
 // grabError is a custom error type
@@ -52,3 +53,9 @@ func IsContentLengthMismatch(err error) bool {
 func IsNoFilename(err error) bool {
 	return isErrorType(err, errNoFilename)
 }
+
+// IsChecksumMismatch returns a boolean indicating whether the error is known to
+// report that the downloaded file did not match the expected checksum value.
+func IsChecksumMismatch(err error) bool {
+	return isErrorType(err, errChecksumMismatch)
+}

+ 0 - 2
request.go

@@ -46,8 +46,6 @@ type Request struct {
 	// If the checksum values do not match, the file is deleted and an error
 	// returned.
 	Checksum []byte
-
-	SuccessNotify chan<- bool
 }
 
 // NewRequest returns a new file transfer Request given a URL, suitable for use

+ 24 - 5
response.go

@@ -1,8 +1,11 @@
 package grab
 
 import (
+	"bytes"
+	"encoding/hex"
 	"io"
 	"net/http"
+	"os"
 	"sync/atomic"
 	"time"
 )
@@ -100,6 +103,27 @@ func (c *Response) copy() error {
 		}
 	}
 
+	// validate checksum
+	if c.Request.Hash != nil && c.Request.Checksum != nil {
+		// open downloaded file
+		if f, err := os.Open(c.Filename); err != nil {
+			return c.close(err)
+		} else {
+			defer f.Close()
+
+			// hash file
+			if _, err := io.Copy(c.Request.Hash, f); err != nil {
+				return c.close(err)
+			}
+
+			// checksum
+			sum := c.Request.Hash.Sum(nil)
+			if !bytes.Equal(sum, c.Request.Checksum) {
+				return c.close(newGrabError(errChecksumMismatch, "Checksum mismatch: %v", hex.EncodeToString(sum)))
+			}
+		}
+	}
+
 	return c.close(nil)
 }
 
@@ -116,11 +140,6 @@ func (c *Response) close(err error) error {
 	// stop time
 	c.End = time.Now()
 
-	// signal
-	if c.Request != nil && c.Request.SuccessNotify != nil {
-		c.Request.SuccessNotify <- (err == nil)
-	}
-
 	// pass error back
 	return err
 }