Commit 2f530534 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

common/download: resume test

parent aa7d3b78
...@@ -99,8 +99,6 @@ func (d *DownloadClient) Cancel() { ...@@ -99,8 +99,6 @@ func (d *DownloadClient) Cancel() {
} }
func (d *DownloadClient) Get() (string, error) { func (d *DownloadClient) Get() (string, error) {
var f *os.File
// If we already have the file and it matches, then just return the target path. // If we already have the file and it matches, then just return the target path.
if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify {
log.Println("Initial checksum matched, no download needed.") log.Println("Initial checksum matched, no download needed.")
...@@ -115,6 +113,7 @@ func (d *DownloadClient) Get() (string, error) { ...@@ -115,6 +113,7 @@ func (d *DownloadClient) Get() (string, error) {
log.Printf("Parsed URL: %#v", url) log.Printf("Parsed URL: %#v", url)
// Files when we don't copy the file are special cased. // Files when we don't copy the file are special cased.
var f *os.File
var finalPath string var finalPath string
if url.Scheme == "file" && !d.config.CopyFile { if url.Scheme == "file" && !d.config.CopyFile {
finalPath = url.Path finalPath = url.Path
...@@ -199,6 +198,15 @@ func (*HTTPDownloader) Cancel() { ...@@ -199,6 +198,15 @@ func (*HTTPDownloader) Cancel() {
func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
log.Printf("Starting download: %s", src.String()) log.Printf("Starting download: %s", src.String())
// Seek to the beginning by default
if _, err := dst.Seek(0, 0); err != nil {
return err
}
// Make the request. We first make a HEAD request so we can check
// if the server supports range queries. If the server/URL doesn't
// support HEAD requests, we just fall back to GET.
req, err := http.NewRequest("HEAD", src.String(), nil) req, err := http.NewRequest("HEAD", src.String(), nil)
if err != nil { if err != nil {
return err return err
...@@ -215,33 +223,9 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { ...@@ -215,33 +223,9 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
} }
resp, err := httpClient.Do(req) resp, err := httpClient.Do(req)
if err != nil || resp.StatusCode != 200 { if err == nil && (resp.StatusCode >= 200 && resp.StatusCode < 300) {
req.Method = "GET" // If the HEAD request succeeded, then attempt to set the range
resp, err = httpClient.Do(req) // query if we can.
if err != nil {
return err
}
}
if resp.StatusCode != 200 {
log.Printf(
"Non-200 status code: %d. Getting error body.", resp.StatusCode)
if req.Method != "GET" {
req.Method = "GET"
resp, err = httpClient.Do(req)
if err != nil {
return err
}
}
errorBody := new(bytes.Buffer)
io.Copy(errorBody, resp.Body)
return fmt.Errorf("HTTP error '%d'! Remote side responded:\n%s",
resp.StatusCode, errorBody.String())
}
req.Method = "GET"
d.progress = 0
if resp.Header.Get("Accept-Ranges") == "bytes" { if resp.Header.Get("Accept-Ranges") == "bytes" {
if fi, err := dst.Stat(); err == nil { if fi, err := dst.Stat(); err == nil {
if _, err = dst.Seek(0, os.SEEK_END); err == nil { if _, err = dst.Seek(0, os.SEEK_END); err == nil {
...@@ -250,6 +234,10 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { ...@@ -250,6 +234,10 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
} }
} }
} }
}
// Set the request to GET now, and redo the query to download
req.Method = "GET"
resp, err = httpClient.Do(req) resp, err = httpClient.Do(req)
if err != nil { if err != nil {
...@@ -257,7 +245,6 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { ...@@ -257,7 +245,6 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
} }
d.total = uint(resp.ContentLength) d.total = uint(resp.ContentLength)
var buffer [4096]byte var buffer [4096]byte
for { for {
n, err := resp.Body.Read(buffer[:]) n, err := resp.Body.Read(buffer[:])
......
...@@ -161,6 +161,41 @@ func TestDownloadClient_checksumNoDownload(t *testing.T) { ...@@ -161,6 +161,41 @@ func TestDownloadClient_checksumNoDownload(t *testing.T) {
} }
} }
func TestDownloadClient_resume(t *testing.T) {
tf, _ := ioutil.TempFile("", "packer")
tf.Write([]byte("w"))
tf.Close()
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" {
rw.Header().Set("Accept-Ranges", "bytes")
rw.WriteHeader(204)
return
}
http.ServeFile(rw, r, "./test-fixtures/root/basic.txt")
}))
defer ts.Close()
client := NewDownloadClient(&DownloadConfig{
Url: ts.URL,
TargetPath: tf.Name(),
})
path, err := client.Get()
if err != nil {
t.Fatalf("err: %s", err)
}
raw, err := ioutil.ReadFile(path)
if err != nil {
t.Fatalf("err: %s", err)
}
if string(raw) != "wello\n" {
t.Fatalf("bad: %s", string(raw))
}
}
func TestDownloadClient_usesDefaultUserAgent(t *testing.T) { func TestDownloadClient_usesDefaultUserAgent(t *testing.T) {
tf, err := ioutil.TempFile("", "packer") tf, err := ioutil.TempFile("", "packer")
if err != nil { if err != nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment