Browse Source

Added GetBatch convenience mthod

Ryan Armstrong 4 years ago
parent
commit
1653c5609d
5 changed files with 88 additions and 9 deletions
  1. 11 5
      client.go
  2. 14 1
      client_test.go
  3. 8 0
      error.go
  4. 49 0
      grab.go
  5. 6 3
      response.go

+ 11 - 5
client.go

@@ -77,15 +77,13 @@ func (c *Client) Do(req *Request) (*Response, error) {
 // Any error which occurs during the file transfer will be set in the returned
 // Response.Error field at the time which it occurs.
 func (c *Client) DoAsync(req *Request) <-chan *Response {
-	r := make(chan *Response, 0)
+	r := make(chan *Response, 1)
 	go func() {
 		// prepare request with HEAD request
 		resp, err := c.do(req)
 		if err == nil && !resp.IsComplete() {
 			// transfer data in new goroutine
-			go func() {
-				resp.copy()
-			}()
+			go resp.copy()
 		}
 
 		r <- resp
@@ -99,12 +97,15 @@ func (c *Client) DoAsync(req *Request) <-chan *Response {
 // returns a channel to receive the file transfer response contexts. The channel
 // is closed once all responses have been received.
 //
+// If zero is given as the worker count, one worker will be created for each
+// given request.
+//
 // Each response is sent through the channel once the request is initiated via
 // HTTP GET or an error has occurred but before the file transfer begins.
 //
 // Any error which occurs during any of the file transfers will be set in the
 // associated Response.Error field.
-func (c *Client) DoBatch(reqs Requests, workers int) <-chan *Response {
+func (c *Client) DoBatch(workers int, reqs Requests) <-chan *Response {
 	// TODO: enable cancelling of batch jobs
 
 	responses := make(chan *Response, workers)
@@ -126,6 +127,11 @@ func (c *Client) DoBatch(reqs Requests, workers int) <-chan *Response {
 		close(responses)
 	}()
 
+	// default one worker per request
+	if workers == 0 {
+		workers = len(reqs)
+	}
+
 	// start workers
 	for i := 0; i < workers; i++ {
 		go func(i int) {

+ 14 - 1
client_test.go

@@ -47,6 +47,14 @@ func TestMain(m *testing.M) {
 			w.Header().Set("Content-Disposition", fmt.Sprintf("attachment;filename=\"%s\"", filenamep))
 		}
 
+		// sleep before responding?
+		sleep := 0
+		if sleepp := r.URL.Query().Get("sleep"); sleepp != "" {
+			if _, err := fmt.Sscanf(sleepp, "%d", &sleep); err != nil {
+				panic(err)
+			}
+		}
+
 		// compute offset
 		offset := 0
 		if rangeh := r.Header.Get("Range"); rangeh != "" {
@@ -55,6 +63,11 @@ func TestMain(m *testing.M) {
 			}
 		}
 
+		// delay response
+		if sleep > 0 {
+			time.Sleep(time.Duration(sleep) * time.Millisecond)
+		}
+
 		// set response headers
 		w.Header().Set("Content-Length", fmt.Sprintf("%d", size-offset))
 		w.Header().Set("Accept-Ranges", "bytes")
@@ -300,7 +313,7 @@ func TestBatch(t *testing.T) {
 	}
 
 	// batch run
-	responses := DefaultClient.DoBatch(reqs, 4)
+	responses := DefaultClient.DoBatch(4, reqs)
 
 	// listen for responses
 	for i := 0; i < len(reqs); {

+ 8 - 0
error.go

@@ -8,6 +8,7 @@ const (
 	errBadLength = iota
 	errNoFilename
 	errChecksumMismatch
+	errBadDestination
 )
 
 // grabError is a custom error type
@@ -59,3 +60,10 @@ func IsNoFilename(err error) bool {
 func IsChecksumMismatch(err error) bool {
 	return isErrorType(err, errChecksumMismatch)
 }
+
+// IsBadDestination returns a boolean indicating whether the error is known to
+// report that the given destination path is not valid for the requested
+// operation.
+func IsBadDestination(err error) bool {
+	return isErrorType(err, errBadDestination)
+}

+ 49 - 0
grab.go

@@ -31,6 +31,10 @@ protocol error.
 */
 package grab
 
+import (
+	"os"
+)
+
 // Get tranfers a file from the specified source URL to the given destination
 // path and returns the completed Response context.
 //
@@ -44,6 +48,7 @@ func Get(dst, src string) (*Response, error) {
 
 	req.Filename = dst
 
+	// execute with default client
 	return DefaultClient.Do(req)
 }
 
@@ -70,5 +75,49 @@ func GetAsync(dst, src string) (<-chan *Response, error) {
 
 	req.Filename = dst
 
+	// execute async with default client
 	return DefaultClient.DoAsync(req), nil
 }
+
+// GetBatch executes multiple requests with the given number of workers and
+// returns a channel to receive the file transfer response contexts. The channel
+// is closed once all responses have been received.
+//
+// GetBatch requires that the destination path is an existing directory. If not,
+// an error is returned which may be identified with IsBadDestination.
+//
+// If zero is given as the worker count, one worker will be created for each
+// given request.
+//
+// Each response is sent through the channel once the request is initiated via
+// HTTP GET or an error has occurred but before the file transfer begins.
+//
+// Any error which occurs during any of the file transfers will be set in the
+// associated Response.Error field.
+func GetBatch(workers int, dst string, sources ...string) (<-chan *Response, error) {
+	// check that dst is an existing directory
+	fi, err := os.Stat(dst)
+	if err != nil {
+		return nil, err
+	}
+
+	if !fi.IsDir() {
+		return nil, newGrabError(errBadDestination, "Destination path is not a directory")
+	}
+
+	// build slice of request
+	reqs := make(Requests, len(sources))
+	for i := 0; i < len(sources); i++ {
+		req, err := NewRequest(sources[i])
+		if err != nil {
+			return nil, err
+		}
+
+		req.Filename = dst
+
+		reqs[i] = req
+	}
+
+	// execute batch with default client
+	return DefaultClient.DoBatch(workers, reqs), nil
+}

+ 6 - 3
response.go

@@ -85,7 +85,8 @@ func (c *Response) copy() error {
 
 	// download and update progress
 	var buffer [4096]byte
-	for {
+	complete := false
+	for complete == false {
 		// read HTTP stream
 		n, err := c.HTTPResponse.Body.Read(buffer[:])
 		if err != nil && err != io.EOF {
@@ -102,12 +103,14 @@ func (c *Response) copy() error {
 
 		// break when finished
 		if err == io.EOF {
-			break
+			// download is ready for checksum validation
+			c.writer.Close()
+			complete = true
 		}
 	}
 
 	// validate checksum
-	if c.Request.Hash != nil && c.Request.Checksum != nil {
+	if complete && c.Request.Hash != nil && c.Request.Checksum != nil {
 		// open downloaded file
 		if f, err := os.Open(c.Filename); err != nil {
 			return c.close(err)