~nskaggs/+junk/xenial-test

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package downloader

import (
	"io"
	"io/ioutil"
	"net/url"
	"os"

	"github.com/juju/errors"
	"github.com/juju/utils"
	"launchpad.net/tomb"
)

// Request holds a single download request.
type Request struct {
	// URL is the location from which the file will be downloaded.
	URL *url.URL

	// TargetDir is the directory into which the file will be downloaded.
	// It defaults to os.TempDir().
	TargetDir string

	// Verify is used to ensure that the download result is correct. If
	// the download is invalid then the func must return errors.NotValid.
	// If no func is provided then no verification happens.
	Verify func(*os.File) error
}

// Status represents the status of a completed download.
type Status struct {
	// File holds the downloaded data on success.
	File *os.File

	// Err describes any error encountered while downloading.
	Err error
}

// Download can download a file from the network.
type Download struct {
	tomb     tomb.Tomb
	done     chan Status
	openBlob func(*url.URL) (io.ReadCloser, error)
}

// StartDownload begins downloading the resource described by req in a
// background goroutine and immediately returns a handle to it. openBlob
// gives access to the blob, whether through an HTTP request or some other
// means; when nil, an HTTP blob opener is used.
func StartDownload(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) *Download {
	d := newDownload(openBlob)
	go d.run(req)
	return d
}

// newDownload constructs a Download that has not yet been started. A nil
// openBlob is replaced by an HTTP blob opener configured with
// utils.NoVerifySSLHostnames.
func newDownload(openBlob func(*url.URL) (io.ReadCloser, error)) *Download {
	opener := openBlob
	if opener == nil {
		opener = NewHTTPBlobOpener(utils.NoVerifySSLHostnames)
	}
	dl := &Download{openBlob: opener}
	dl.done = make(chan Status)
	return dl
}

// Stop stops any download that's in progress and then waits on the
// download's tomb, which the background goroutine marks done on exit.
func (dl *Download) Stop() {
	t := &dl.tomb
	t.Kill(nil)
	t.Wait()
}

// Done returns a receive-only channel on which the download's Status is
// delivered once it has completed. Whoever receives the status is
// responsible for closing and removing the received file.
func (dl *Download) Done() <-chan Status {
	var statusCh <-chan Status = dl.done
	return statusCh
}

// Wait blocks until the download completes or the abort channel receives.
//
// On success it returns the downloaded file; the caller then owns it and
// is responsible for closing and removing it. On failure or abort it
// returns a nil file and an error. If a failed status still carries a file
// (for example one that downloaded but failed verification), that file is
// closed and removed here, because the caller never receives it and would
// have no way to clean it up. The download is always stopped before Wait
// returns.
func (dl *Download) Wait(abort <-chan struct{}) (*os.File, error) {
	defer dl.Stop()

	select {
	case <-abort:
		logger.Infof("download aborted")
		return nil, errors.New("aborted")
	case status := <-dl.Done():
		if status.Err != nil {
			if status.File != nil {
				// The caller only gets a nil file on error, so this is the
				// last chance to release the descriptor and delete the
				// temp file; otherwise it leaks on disk.
				name := status.File.Name()
				if err := status.File.Close(); err != nil {
					logger.Errorf("failed to close file: %v", err)
				}
				if err := os.Remove(name); err != nil {
					logger.Errorf("cannot remove temp file %q: %v", name, err)
				}
			}
			return nil, errors.Trace(status.Err)
		}
		return status.File, nil
	}
}

// run performs the download described by req and, when req.Verify is set,
// verifies the result. It then either hands the outcome to a receiver on
// dl.done or, if the download is stopped first, discards the file.
func (dl *Download) run(req Request) {
	defer dl.tomb.Done()

	// TODO(dimitern) 2013-10-03 bug #1234715
	// Add a testing HTTPS storage to verify the
	// disableSSLHostnameVerification behavior here.
	file, err := download(req, dl.openBlob)
	if err != nil {
		err = errors.Annotatef(err, "cannot download %q", req.URL)
	} else {
		logger.Infof("download complete (%q)", req.URL)
		if req.Verify != nil {
			err = verifyDownload(file, req)
		}
	}

	select {
	case dl.done <- Status{File: file, Err: err}:
		// The receiver now owns file.
	case <-dl.tomb.Dying():
		// Nobody will receive the file, so release and delete it here.
		cleanTempFile(file)
	}
}

// verifyDownload runs req.Verify on file. If verification passes, it
// rewinds file to its beginning so the next reader starts from byte zero.
// A verification failure that satisfies errors.IsNotValid is also logged.
func verifyDownload(file *os.File, req Request) error {
	if verr := req.Verify(file); verr != nil {
		if errors.IsNotValid(verr) {
			logger.Errorf("download of %s invalid: %v", req.URL, verr)
		}
		return errors.Trace(verr)
	}
	logger.Infof("download verified (%q)", req.URL)

	_, seekErr := file.Seek(0, os.SEEK_SET)
	if seekErr != nil {
		logger.Errorf("failed to seek to beginning of file: %v", seekErr)
		return errors.Trace(seekErr)
	}
	return nil
}

// download copies the blob at req.URL into a new "inprogress-" temp file
// under req.TargetDir (or os.TempDir() when that is empty). It returns the
// file rewound to its start. On any error the partially written temp file
// is closed and removed before returning.
func download(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) (file *os.File, err error) {
	logger.Infof("downloading from %s", req.URL)

	targetDir := req.TargetDir
	if targetDir == "" {
		targetDir = os.TempDir()
	}

	f, err := ioutil.TempFile(targetDir, "inprogress-")
	if err != nil {
		return nil, errors.Trace(err)
	}
	// Relies on the named result err: any failing return below triggers
	// cleanup of the temp file.
	defer func() {
		if err != nil {
			cleanTempFile(f)
		}
	}()

	blob, err := openBlob(req.URL)
	if err != nil {
		return nil, errors.Trace(err)
	}
	defer blob.Close()

	if _, err = io.Copy(f, blob); err != nil {
		return nil, errors.Trace(err)
	}
	if _, err = f.Seek(0, 0); err != nil {
		return nil, errors.Trace(err)
	}
	return f, nil
}

// cleanTempFile closes f and deletes it from disk. A failure to delete is
// logged; a close error is deliberately ignored. A nil f is a no-op.
func cleanTempFile(f *os.File) {
	if f != nil {
		name := f.Name()
		f.Close() // close error intentionally ignored (best-effort cleanup)
		if err := os.Remove(name); err != nil {
			logger.Errorf("cannot remove temp file %q: %v", name, err)
		}
	}
}