~nskaggs/+junk/xenial-test

« back to all changes in this revision

Viewing changes to src/gopkg.in/amz.v3/aws/sign.go

  • Committer: Nicholas Skaggs
  • Date: 2016-10-24 20:56:05 UTC
  • Revision ID: nicholas.skaggs@canonical.com-20161024205605-z8lta0uvuhtxwzwl
Initi with beta15

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
package aws
 
2
 
 
3
import (
 
4
        "bytes"
 
5
        "crypto/hmac"
 
6
        "crypto/sha256"
 
7
        "encoding/base64"
 
8
        "fmt"
 
9
        "io"
 
10
        "io/ioutil"
 
11
        "log"
 
12
        "net/http"
 
13
        "net/url"
 
14
        "sort"
 
15
        "strings"
 
16
        "time"
 
17
)
 
18
 
 
19
var debug = log.New(
 
20
        // Remove the c-style comment header to front of line to debug information.
 
21
        /*os.Stdout, //*/ ioutil.Discard,
 
22
        "DEBUG: ",
 
23
        log.LstdFlags,
 
24
)
 
25
 
 
26
type Signer func(*http.Request, Auth) error
 
27
 
 
28
// Ensure our signers meet the interface
 
29
var _ Signer = SignV2
 
30
var _ Signer = SignV4Factory("", "")
 
31
 
 
32
type hasher func(io.Reader) (string, error)
 
33
 
 
34
const (
 
35
        ISO8601BasicFormat      = "20060102T150405Z"
 
36
        ISO8601BasicFormatShort = "20060102"
 
37
)
 
38
 
 
39
// SignV2 signs an HTTP request utilizing version 2 of the AWS
 
40
// signature, and the given credentials.
 
41
func SignV2(req *http.Request, auth Auth) (err error) {
 
42
 
 
43
        queryVals := req.URL.Query()
 
44
        queryVals.Set("AWSAccessKeyId", auth.AccessKey)
 
45
        queryVals.Set("SignatureVersion", "2")
 
46
        queryVals.Set("SignatureMethod", "HmacSHA256")
 
47
 
 
48
        uriStr := canonicalURI(req.URL)
 
49
        queryStr := canonicalQueryString(queryVals)
 
50
 
 
51
        payload := new(bytes.Buffer)
 
52
        if err := errorCollector(
 
53
                fprintfWrapper(payload, "%s\n", requestMethodVerb(req.Method)),
 
54
                fprintfWrapper(payload, "%s\n", req.Host),
 
55
                fprintfWrapper(payload, "%s\n", uriStr),
 
56
                fprintfWrapper(payload, "%s", queryStr),
 
57
        ); err != nil {
 
58
                return err
 
59
        }
 
60
 
 
61
        hash := hmac.New(sha256.New, []byte(auth.SecretKey))
 
62
        hash.Write(payload.Bytes())
 
63
        signature := make([]byte, base64.StdEncoding.EncodedLen(hash.Size()))
 
64
        base64.StdEncoding.Encode(signature, hash.Sum(nil))
 
65
 
 
66
        queryVals.Set("Signature", string(signature))
 
67
        req.URL.RawQuery = queryVals.Encode()
 
68
 
 
69
        return nil
 
70
}
 
71
 
 
72
// SignV4Factory returns a version 4 Signer which will utilize the
 
73
// given region name.
 
74
func SignV4Factory(regionName, serviceName string) Signer {
 
75
        return func(req *http.Request, auth Auth) error {
 
76
                return SignV4(req, auth, regionName, serviceName)
 
77
        }
 
78
}
 
79
 
 
80
func SignV4URL(req *http.Request, auth Auth, regionName, svcName string, expires time.Duration) error {
 
81
        reqTime, err := requestTime(req)
 
82
        if err != nil {
 
83
                return err
 
84
        }
 
85
 
 
86
        req.Header.Del("date")
 
87
 
 
88
        credScope := credentialScope(reqTime, regionName, svcName)
 
89
 
 
90
        queryVals := req.URL.Query()
 
91
        queryVals.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256")
 
92
        queryVals.Set("X-Amz-Credential", auth.AccessKey+"/"+credScope)
 
93
        queryVals.Set("X-Amz-Date", reqTime.Format(ISO8601BasicFormat))
 
94
        queryVals.Set("X-Amz-Expires", fmt.Sprintf("%d", int(expires.Seconds())))
 
95
        queryVals.Set("X-Amz-SignedHeaders", "host")
 
96
        req.URL.RawQuery = queryVals.Encode()
 
97
 
 
98
        _, canonReqHash, _, err := canonicalRequest(req, sha256Hasher, false)
 
99
        if err != nil {
 
100
                return err
 
101
        }
 
102
 
 
103
        var strToSign string
 
104
        if strToSign, err = stringToSign(reqTime, canonReqHash, credScope); err != nil {
 
105
                return err
 
106
        }
 
107
 
 
108
        key := signingKey(reqTime, auth.SecretKey, regionName, svcName)
 
109
        signature := fmt.Sprintf("%x", hmacHasher(key, strToSign))
 
110
 
 
111
        debug.Printf("strToSign:\n\"\"\"\n%s\n\"\"\"", strToSign)
 
112
 
 
113
        queryVals.Set("X-Amz-Signature", signature)
 
114
 
 
115
        req.URL.RawQuery = queryVals.Encode()
 
116
 
 
117
        return nil
 
118
}
 
119
 
 
120
// SignV4 signs an HTTP request utilizing version 4 of the AWS
 
121
// signature, and the given credentials.
 
122
func SignV4(req *http.Request, auth Auth, regionName, svcName string) (err error) {
 
123
 
 
124
        var reqTime time.Time
 
125
        if reqTime, err = requestTime(req); err != nil {
 
126
                return err
 
127
        }
 
128
 
 
129
        // Remove any existing authorization headers as they will corrupt
 
130
        // the signing.
 
131
        delete(req.Header, "Authorization")
 
132
        delete(req.Header, "authorization")
 
133
 
 
134
        credScope := credentialScope(reqTime, regionName, svcName)
 
135
 
 
136
        _, canonReqHash, sortedHdrNames, err := canonicalRequest(req, sha256Hasher, true)
 
137
        if err != nil {
 
138
                return err
 
139
        }
 
140
 
 
141
        var strToSign string
 
142
        if strToSign, err = stringToSign(reqTime, canonReqHash, credScope); err != nil {
 
143
                return err
 
144
        }
 
145
 
 
146
        key := signingKey(reqTime, auth.SecretKey, regionName, svcName)
 
147
        signature := fmt.Sprintf("%x", hmacHasher(key, strToSign))
 
148
 
 
149
        debug.Printf("strToSign:\n\"\"\"\n%s\n\"\"\"", strToSign)
 
150
 
 
151
        var authHdrVal string
 
152
        if authHdrVal, err = authHeaderString(
 
153
                req.Header,
 
154
                auth.AccessKey,
 
155
                signature,
 
156
                credScope,
 
157
                sortedHdrNames,
 
158
        ); err != nil {
 
159
                return err
 
160
        }
 
161
 
 
162
        req.Header.Set("Authorization", authHdrVal)
 
163
 
 
164
        return nil
 
165
}
 
166
 
 
167
// Task 1: Create a Canonical Request.
 
168
// Returns the canonical request, and its hash.
 
169
func canonicalRequest(
 
170
        req *http.Request,
 
171
        hasher hasher,
 
172
        calcPayHash bool,
 
173
) (canReq, canReqHash string, sortedHdrNames []string, err error) {
 
174
 
 
175
        payHash := "UNSIGNED-PAYLOAD"
 
176
        if calcPayHash {
 
177
                if payHash, err = payloadHash(req, hasher); err != nil {
 
178
                        return
 
179
                }
 
180
                req.Header.Set("x-amz-content-sha256", payHash)
 
181
        }
 
182
 
 
183
        sortedHdrNames = sortHeaderNames(req.Header, "host")
 
184
        var canHdr string
 
185
        if canHdr, err = canonicalHeaders(sortedHdrNames, req.Host, req.Header); err != nil {
 
186
                return
 
187
        }
 
188
 
 
189
        debug.Printf("canHdr:\n\"\"\"\n%s\n\"\"\"", canHdr)
 
190
        debug.Printf("signedHeader: %s\n\n", strings.Join(sortedHdrNames, ";"))
 
191
 
 
192
        uriStr := canonicalURI(req.URL)
 
193
        queryStr := canonicalQueryString(req.URL.Query())
 
194
 
 
195
        c := new(bytes.Buffer)
 
196
        if err := errorCollector(
 
197
                fprintfWrapper(c, "%s\n", requestMethodVerb(req.Method)),
 
198
                fprintfWrapper(c, "%s\n", uriStr),
 
199
                fprintfWrapper(c, "%s\n", queryStr),
 
200
                fprintfWrapper(c, "%s\n", canHdr),
 
201
                fprintfWrapper(c, "%s\n", strings.Join(sortedHdrNames, ";")),
 
202
                fprintfWrapper(c, "%s", payHash),
 
203
        ); err != nil {
 
204
                return "", "", nil, err
 
205
        }
 
206
 
 
207
        canReq = c.String()
 
208
        debug.Printf("canReq:\n\"\"\"\n%s\n\"\"\"", canReq)
 
209
        canReqHash, err = hasher(bytes.NewBuffer([]byte(canReq)))
 
210
 
 
211
        return canReq, canReqHash, sortedHdrNames, err
 
212
}
 
213
 
 
214
// Task 2: Create a string to Sign
 
215
// Returns a string in the defined format to sign for the authorization header.
 
216
func stringToSign(
 
217
        t time.Time,
 
218
        canonReqHash string,
 
219
        credScope string,
 
220
) (string, error) {
 
221
        w := new(bytes.Buffer)
 
222
        if err := errorCollector(
 
223
                fprintfWrapper(w, "AWS4-HMAC-SHA256\n"),
 
224
                fprintfWrapper(w, "%s\n", t.Format(ISO8601BasicFormat)),
 
225
                fprintfWrapper(w, "%s\n", credScope),
 
226
                fprintfWrapper(w, "%s", canonReqHash),
 
227
        ); err != nil {
 
228
                return "", err
 
229
        }
 
230
 
 
231
        return w.String(), nil
 
232
}
 
233
 
 
234
// Task 3: Calculate the Signature
 
235
// Returns a derived signing key.
 
236
func signingKey(t time.Time, secretKey, regionName, svcName string) []byte {
 
237
 
 
238
        kSecret := secretKey
 
239
        kDate := hmacHasher([]byte("AWS4"+kSecret), t.Format(ISO8601BasicFormatShort))
 
240
        kRegion := hmacHasher(kDate, regionName)
 
241
        kService := hmacHasher(kRegion, svcName)
 
242
        kSigning := hmacHasher(kService, "aws4_request")
 
243
 
 
244
        return kSigning
 
245
}
 
246
 
 
247
// Task 4: Add the Signing Information to the Request
 
248
// Returns a string to be placed in the Authorization header for the request.
 
249
func authHeaderString(
 
250
        header http.Header,
 
251
        accessKey,
 
252
        signature string,
 
253
        credScope string,
 
254
        sortedHeaderNames []string,
 
255
) (string, error) {
 
256
        w := new(bytes.Buffer)
 
257
        if err := errorCollector(
 
258
                fprintfWrapper(w, "AWS4-HMAC-SHA256 "),
 
259
                fprintfWrapper(w, "Credential=%s/%s, ", accessKey, credScope),
 
260
                fprintfWrapper(w, "SignedHeaders=%s, ", strings.Join(sortedHeaderNames, ";")),
 
261
                fprintfWrapper(w, "Signature=%s", signature),
 
262
        ); err != nil {
 
263
                return "", err
 
264
        }
 
265
 
 
266
        return w.String(), nil
 
267
}
 
268
 
 
269
func canonicalURI(u *url.URL) string {
 
270
 
 
271
        // The algorithm states that if the path is empty, to just use a "/".
 
272
        if u.Path == "" {
 
273
                return "/"
 
274
        }
 
275
 
 
276
        // Each path segment must be URI-encoded.
 
277
        segments := strings.Split(u.Path, "/")
 
278
        for i, segment := range segments {
 
279
                segments[i] = goToAwsUrlEncoding(url.QueryEscape(segment))
 
280
        }
 
281
 
 
282
        return strings.Join(segments, "/")
 
283
}
 
284
 
 
285
func canonicalQueryString(queryVals url.Values) string {
 
286
 
 
287
        // AWS dictates that if duplicate keys exist, their values be
 
288
        // sorted as well.
 
289
        for _, values := range queryVals {
 
290
                sort.Strings(values)
 
291
        }
 
292
 
 
293
        return goToAwsUrlEncoding(queryVals.Encode())
 
294
}
 
295
 
 
296
func goToAwsUrlEncoding(urlEncoded string) string {
 
297
        // AWS dictates that we use %20 for encoding spaces rather than +.
 
298
        // All significant +s should already be encoded into their
 
299
        // hexadecimal equivalents before doing the string replace.
 
300
        return strings.Replace(urlEncoded, "+", "%20", -1)
 
301
}
 
302
 
 
303
func canonicalHeaders(sortedHeaderNames []string, host string, hdr http.Header) (string, error) {
 
304
        buffer := new(bytes.Buffer)
 
305
 
 
306
        for _, hName := range sortedHeaderNames {
 
307
 
 
308
                hdrVals := host
 
309
                if hName != "host" {
 
310
                        canonHdrKey := http.CanonicalHeaderKey(hName)
 
311
                        sortedHdrVals := hdr[canonHdrKey]
 
312
                        sort.Strings(sortedHdrVals)
 
313
                        hdrVals = strings.Join(sortedHdrVals, ",")
 
314
                }
 
315
 
 
316
                if _, err := fmt.Fprintf(buffer, "%s:%s\n", hName, hdrVals); err != nil {
 
317
                        return "", err
 
318
                }
 
319
        }
 
320
 
 
321
        // There is intentionally a hanging newline at the end of the
 
322
        // header list.
 
323
        return buffer.String(), nil
 
324
}
 
325
 
 
326
// Returns a SHA256 checksum of the request body. Represented as a
 
327
// lowercase hexadecimal string.
 
328
func payloadHash(req *http.Request, hasher hasher) (string, error) {
 
329
        if req.Body == nil {
 
330
                return hasher(bytes.NewBuffer(nil))
 
331
        }
 
332
 
 
333
        return hasher(req.Body)
 
334
}
 
335
 
 
336
// Retrieve the header names, lower-case them, and sort them.
 
337
func sortHeaderNames(header http.Header, injectedNames ...string) []string {
 
338
 
 
339
        sortedNames := injectedNames
 
340
        for hName, _ := range header {
 
341
                sortedNames = append(sortedNames, strings.ToLower(hName))
 
342
        }
 
343
 
 
344
        sort.Strings(sortedNames)
 
345
 
 
346
        return sortedNames
 
347
}
 
348
 
 
349
func hmacHasher(key []byte, value string) []byte {
 
350
        h := hmac.New(sha256.New, key)
 
351
        h.Write([]byte(value))
 
352
        return h.Sum(nil)
 
353
}
 
354
 
 
355
func sha256Hasher(payloadReader io.Reader) (string, error) {
 
356
        hasher := sha256.New()
 
357
        _, err := io.Copy(hasher, payloadReader)
 
358
 
 
359
        return fmt.Sprintf("%x", hasher.Sum(nil)), err
 
360
}
 
361
 
 
362
func credentialScope(t time.Time, regionName, svcName string) string {
 
363
        return fmt.Sprintf(
 
364
                "%s/%s/%s/aws4_request",
 
365
                t.Format(ISO8601BasicFormatShort),
 
366
                regionName,
 
367
                svcName,
 
368
        )
 
369
}
 
370
 
 
371
// We do a lot of fmt.Fprintfs in this package. Create a higher-order
 
372
// function to elide the bytes written return value so we can submit
 
373
// these calls to an error collector.
 
374
func fprintfWrapper(w io.Writer, format string, vals ...interface{}) func() error {
 
375
        return func() error {
 
376
                _, err := fmt.Fprintf(w, format, vals...)
 
377
                return err
 
378
        }
 
379
}
 
380
 
 
381
// Poor man's maybe monad.
 
382
func errorCollector(writers ...func() error) error {
 
383
        for _, writer := range writers {
 
384
                if err := writer(); err != nil {
 
385
                        return err
 
386
                }
 
387
        }
 
388
 
 
389
        return nil
 
390
}
 
391
 
 
392
// Retrieve the request time from the request. We will attempt to
 
393
// parse whatever we find, but we will not make up a request date for
 
394
// the user (i.e.: Magic!).
 
395
func requestTime(req *http.Request) (time.Time, error) {
 
396
 
 
397
        // Time formats to try. We want to do everything we can to accept
 
398
        // all time formats, but ultimately we may fail. In the package
 
399
        // scope so it doesn't get initialized for every request.
 
400
        var timeFormats = []string{
 
401
                time.RFC822,
 
402
                ISO8601BasicFormat,
 
403
                time.RFC1123,
 
404
                time.ANSIC,
 
405
                time.UnixDate,
 
406
                time.RubyDate,
 
407
                time.RFC822Z,
 
408
                time.RFC850,
 
409
                time.RFC1123Z,
 
410
                time.RFC3339,
 
411
                time.RFC3339Nano,
 
412
                time.Kitchen,
 
413
        }
 
414
 
 
415
        // Get a date header.
 
416
        var date string
 
417
        if date = req.Header.Get("x-amz-date"); date == "" {
 
418
                if date = req.Header.Get("date"); date == "" {
 
419
                        return time.Time{}, fmt.Errorf(`Could not retrieve a request date. Please provide one in either "x-amz-date", or "date".`)
 
420
                }
 
421
        }
 
422
 
 
423
        // Start attempting to parse
 
424
        for _, format := range timeFormats {
 
425
                if parsedTime, err := time.Parse(format, date); err == nil {
 
426
                        return parsedTime, nil
 
427
                }
 
428
        }
 
429
 
 
430
        return time.Time{}, fmt.Errorf(
 
431
                "Could not parse the given date. Please utilize one of the following formats: %s",
 
432
                strings.Join(timeFormats, ","),
 
433
        )
 
434
}
 
435
 
 
436
// http.Request's Method member returns the entire method. Derive the
 
437
// verb.
 
438
func requestMethodVerb(rawMethod string) (verb string) {
 
439
        verbPlus := strings.SplitN(rawMethod, " ", 2)
 
440
        switch {
 
441
        case len(verbPlus) == 0: // Per docs, Method will be empty if it's GET.
 
442
                verb = "GET"
 
443
        default:
 
444
                verb = verbPlus[0]
 
445
        }
 
446
        return verb
 
447
}