~wallyworld/gwacl/ensure-all-roles-have-costs

« back to all changes in this revision

Viewing changes to fork/tls/conn.go

  • Committer: Ian Booth
  • Date: 2014-12-02 00:36:45 UTC
  • Revision ID: ian.booth@canonical.com-20141202003645-ye8a5akifuf2wjk3
Ensure all regions have costs and remove custom formatting

Show diffs side-by-side

added added

removed removed

Lines of Context:
7
7
package tls
8
8
 
9
9
import (
10
 
    "bytes"
11
 
    "crypto/cipher"
12
 
    "crypto/subtle"
13
 
    "crypto/x509"
14
 
    "errors"
15
 
    "io"
16
 
    "net"
17
 
    "sync"
18
 
    "time"
 
10
        "bytes"
 
11
        "crypto/cipher"
 
12
        "crypto/subtle"
 
13
        "crypto/x509"
 
14
        "errors"
 
15
        "io"
 
16
        "net"
 
17
        "sync"
 
18
        "time"
19
19
)
20
20
 
21
21
// A Conn represents a secured connection.
22
22
// It implements the net.Conn interface.
23
23
type Conn struct {
24
 
    // constant
25
 
    conn     net.Conn
26
 
    isClient bool
27
 
 
28
 
    // constant after handshake; protected by handshakeMutex
29
 
    handshakeMutex    sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
30
 
    vers              uint16     // TLS version
31
 
    haveVers          bool       // version has been negotiated
32
 
    config            *Config    // configuration passed to constructor
33
 
    handshakeComplete bool
34
 
    cipherSuite       uint16
35
 
    ocspResponse      []byte // stapled OCSP response
36
 
    peerCertificates  []*x509.Certificate
37
 
    // verifiedChains contains the certificate chains that we built, as
38
 
    // opposed to the ones presented by the server.
39
 
    verifiedChains [][]*x509.Certificate
40
 
    // serverName contains the server name indicated by the client, if any.
41
 
    serverName string
42
 
 
43
 
    clientProtocol         string
44
 
    clientProtocolFallback bool
45
 
 
46
 
    // first permanent error
47
 
    errMutex sync.Mutex
48
 
    err      error
49
 
 
50
 
    // input/output
51
 
    in, out  halfConn     // in.Mutex < out.Mutex
52
 
    rawInput *block       // raw input, right off the wire
53
 
    input    *block       // application data waiting to be read
54
 
    hand     bytes.Buffer // handshake data waiting to be read
55
 
 
56
 
    tmp [16]byte
 
24
        // constant
 
25
        conn     net.Conn
 
26
        isClient bool
 
27
 
 
28
        // constant after handshake; protected by handshakeMutex
 
29
        handshakeMutex    sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
 
30
        vers              uint16     // TLS version
 
31
        haveVers          bool       // version has been negotiated
 
32
        config            *Config    // configuration passed to constructor
 
33
        handshakeComplete bool
 
34
        cipherSuite       uint16
 
35
        ocspResponse      []byte // stapled OCSP response
 
36
        peerCertificates  []*x509.Certificate
 
37
        // verifiedChains contains the certificate chains that we built, as
 
38
        // opposed to the ones presented by the server.
 
39
        verifiedChains [][]*x509.Certificate
 
40
        // serverName contains the server name indicated by the client, if any.
 
41
        serverName string
 
42
 
 
43
        clientProtocol         string
 
44
        clientProtocolFallback bool
 
45
 
 
46
        // first permanent error
 
47
        errMutex sync.Mutex
 
48
        err      error
 
49
 
 
50
        // input/output
 
51
        in, out  halfConn     // in.Mutex < out.Mutex
 
52
        rawInput *block       // raw input, right off the wire
 
53
        input    *block       // application data waiting to be read
 
54
        hand     bytes.Buffer // handshake data waiting to be read
 
55
 
 
56
        tmp [16]byte
57
57
}
58
58
 
59
59
func (c *Conn) setError(err error) error {
60
 
    c.errMutex.Lock()
61
 
    defer c.errMutex.Unlock()
 
60
        c.errMutex.Lock()
 
61
        defer c.errMutex.Unlock()
62
62
 
63
 
    if c.err == nil {
64
 
        c.err = err
65
 
    }
66
 
    return err
 
63
        if c.err == nil {
 
64
                c.err = err
 
65
        }
 
66
        return err
67
67
}
68
68
 
69
69
func (c *Conn) error() error {
70
 
    c.errMutex.Lock()
71
 
    defer c.errMutex.Unlock()
 
70
        c.errMutex.Lock()
 
71
        defer c.errMutex.Unlock()
72
72
 
73
 
    return c.err
 
73
        return c.err
74
74
}
75
75
 
76
76
// Access to net.Conn methods.
79
79
 
80
80
// LocalAddr returns the local network address.
81
81
func (c *Conn) LocalAddr() net.Addr {
82
 
    return c.conn.LocalAddr()
 
82
        return c.conn.LocalAddr()
83
83
}
84
84
 
85
85
// RemoteAddr returns the remote network address.
86
86
func (c *Conn) RemoteAddr() net.Addr {
87
 
    return c.conn.RemoteAddr()
 
87
        return c.conn.RemoteAddr()
88
88
}
89
89
 
90
90
// SetDeadline sets the read and write deadlines associated with the connection.
91
91
// A zero value for t means Read and Write will not time out.
92
92
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
93
93
func (c *Conn) SetDeadline(t time.Time) error {
94
 
    return c.conn.SetDeadline(t)
 
94
        return c.conn.SetDeadline(t)
95
95
}
96
96
 
97
97
// SetReadDeadline sets the read deadline on the underlying connection.
98
98
// A zero value for t means Read will not time out.
99
99
func (c *Conn) SetReadDeadline(t time.Time) error {
100
 
    return c.conn.SetReadDeadline(t)
 
100
        return c.conn.SetReadDeadline(t)
101
101
}
102
102
 
103
103
// SetWriteDeadline sets the write deadline on the underlying conneciton.
104
104
// A zero value for t means Write will not time out.
105
105
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
106
106
func (c *Conn) SetWriteDeadline(t time.Time) error {
107
 
    return c.conn.SetWriteDeadline(t)
 
107
        return c.conn.SetWriteDeadline(t)
108
108
}
109
109
 
110
110
// A halfConn represents one direction of the record layer
111
111
// connection, either sending or receiving.
112
112
type halfConn struct {
113
 
    sync.Mutex
114
 
    version uint16      // protocol version
115
 
    cipher  interface{} // cipher algorithm
116
 
    mac     macFunction
117
 
    seq     [8]byte // 64-bit sequence number
118
 
    bfree   *block  // list of free blocks
119
 
 
120
 
    nextCipher interface{} // next encryption state
121
 
    nextMac    macFunction // next MAC algorithm
122
 
 
123
 
    // used to save allocating a new buffer for each MAC.
124
 
    inDigestBuf, outDigestBuf []byte
 
113
        sync.Mutex
 
114
        version uint16      // protocol version
 
115
        cipher  interface{} // cipher algorithm
 
116
        mac     macFunction
 
117
        seq     [8]byte // 64-bit sequence number
 
118
        bfree   *block  // list of free blocks
 
119
 
 
120
        nextCipher interface{} // next encryption state
 
121
        nextMac    macFunction // next MAC algorithm
 
122
 
 
123
        // used to save allocating a new buffer for each MAC.
 
124
        inDigestBuf, outDigestBuf []byte
125
125
}
126
126
 
127
127
// prepareCipherSpec sets the encryption and MAC states
128
128
// that a subsequent changeCipherSpec will use.
129
129
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
130
 
    hc.version = version
131
 
    hc.nextCipher = cipher
132
 
    hc.nextMac = mac
 
130
        hc.version = version
 
131
        hc.nextCipher = cipher
 
132
        hc.nextMac = mac
133
133
}
134
134
 
135
135
// changeCipherSpec changes the encryption and MAC states
136
136
// to the ones previously passed to prepareCipherSpec.
137
137
func (hc *halfConn) changeCipherSpec() error {
138
 
    if hc.nextCipher == nil {
139
 
        return alertInternalError
140
 
    }
141
 
    hc.cipher = hc.nextCipher
142
 
    hc.mac = hc.nextMac
143
 
    hc.nextCipher = nil
144
 
    hc.nextMac = nil
145
 
    for i := range hc.seq {
146
 
        hc.seq[i] = 0
147
 
    }
148
 
    return nil
 
138
        if hc.nextCipher == nil {
 
139
                return alertInternalError
 
140
        }
 
141
        hc.cipher = hc.nextCipher
 
142
        hc.mac = hc.nextMac
 
143
        hc.nextCipher = nil
 
144
        hc.nextMac = nil
 
145
        for i := range hc.seq {
 
146
                hc.seq[i] = 0
 
147
        }
 
148
        return nil
149
149
}
150
150
 
151
151
// incSeq increments the sequence number.
152
152
func (hc *halfConn) incSeq() {
153
 
    for i := 7; i >= 0; i-- {
154
 
        hc.seq[i]++
155
 
        if hc.seq[i] != 0 {
156
 
            return
157
 
        }
158
 
    }
 
153
        for i := 7; i >= 0; i-- {
 
154
                hc.seq[i]++
 
155
                if hc.seq[i] != 0 {
 
156
                        return
 
157
                }
 
158
        }
159
159
 
160
 
    // Not allowed to let sequence number wrap.
161
 
    // Instead, must renegotiate before it does.
162
 
    // Not likely enough to bother.
163
 
    panic("TLS: sequence number wraparound")
 
160
        // Not allowed to let sequence number wrap.
 
161
        // Instead, must renegotiate before it does.
 
162
        // Not likely enough to bother.
 
163
        panic("TLS: sequence number wraparound")
164
164
}
165
165
 
166
166
// resetSeq resets the sequence number to zero.
167
167
func (hc *halfConn) resetSeq() {
168
 
    for i := range hc.seq {
169
 
        hc.seq[i] = 0
170
 
    }
 
168
        for i := range hc.seq {
 
169
                hc.seq[i] = 0
 
170
        }
171
171
}
172
172
 
173
173
// removePadding returns an unpadded slice, in constant time, which is a prefix
174
174
// of the input. It also returns a byte which is equal to 255 if the padding
175
175
// was valid and 0 otherwise. See RFC 2246, section 6.2.3.2
176
176
func removePadding(payload []byte) ([]byte, byte) {
177
 
    if len(payload) < 1 {
178
 
        return payload, 0
179
 
    }
180
 
 
181
 
    paddingLen := payload[len(payload)-1]
182
 
    t := uint(len(payload)-1) - uint(paddingLen)
183
 
    // if len(payload) >= (paddingLen - 1) then the MSB of t is zero
184
 
    good := byte(int32(^t) >> 31)
185
 
 
186
 
    toCheck := 255 // the maximum possible padding length
187
 
    // The length of the padded data is public, so we can use an if here
188
 
    if toCheck+1 > len(payload) {
189
 
        toCheck = len(payload) - 1
190
 
    }
191
 
 
192
 
    for i := 0; i < toCheck; i++ {
193
 
        t := uint(paddingLen) - uint(i)
194
 
        // if i <= paddingLen then the MSB of t is zero
195
 
        mask := byte(int32(^t) >> 31)
196
 
        b := payload[len(payload)-1-i]
197
 
        good &^= mask&paddingLen ^ mask&b
198
 
    }
199
 
 
200
 
    // We AND together the bits of good and replicate the result across
201
 
    // all the bits.
202
 
    good &= good << 4
203
 
    good &= good << 2
204
 
    good &= good << 1
205
 
    good = uint8(int8(good) >> 7)
206
 
 
207
 
    toRemove := good&paddingLen + 1
208
 
    return payload[:len(payload)-int(toRemove)], good
 
177
        if len(payload) < 1 {
 
178
                return payload, 0
 
179
        }
 
180
 
 
181
        paddingLen := payload[len(payload)-1]
 
182
        t := uint(len(payload)-1) - uint(paddingLen)
 
183
        // if len(payload) >= (paddingLen - 1) then the MSB of t is zero
 
184
        good := byte(int32(^t) >> 31)
 
185
 
 
186
        toCheck := 255 // the maximum possible padding length
 
187
        // The length of the padded data is public, so we can use an if here
 
188
        if toCheck+1 > len(payload) {
 
189
                toCheck = len(payload) - 1
 
190
        }
 
191
 
 
192
        for i := 0; i < toCheck; i++ {
 
193
                t := uint(paddingLen) - uint(i)
 
194
                // if i <= paddingLen then the MSB of t is zero
 
195
                mask := byte(int32(^t) >> 31)
 
196
                b := payload[len(payload)-1-i]
 
197
                good &^= mask&paddingLen ^ mask&b
 
198
        }
 
199
 
 
200
        // We AND together the bits of good and replicate the result across
 
201
        // all the bits.
 
202
        good &= good << 4
 
203
        good &= good << 2
 
204
        good &= good << 1
 
205
        good = uint8(int8(good) >> 7)
 
206
 
 
207
        toRemove := good&paddingLen + 1
 
208
        return payload[:len(payload)-int(toRemove)], good
209
209
}
210
210
 
211
211
// removePaddingSSL30 is a replacement for removePadding in the case that the
212
212
// protocol version is SSLv3. In this version, the contents of the padding
213
213
// are random and cannot be checked.
214
214
func removePaddingSSL30(payload []byte) ([]byte, byte) {
215
 
    if len(payload) < 1 {
216
 
        return payload, 0
217
 
    }
218
 
 
219
 
    paddingLen := int(payload[len(payload)-1]) + 1
220
 
    if paddingLen > len(payload) {
221
 
        return payload, 0
222
 
    }
223
 
 
224
 
    return payload[:len(payload)-paddingLen], 255
 
215
        if len(payload) < 1 {
 
216
                return payload, 0
 
217
        }
 
218
 
 
219
        paddingLen := int(payload[len(payload)-1]) + 1
 
220
        if paddingLen > len(payload) {
 
221
                return payload, 0
 
222
        }
 
223
 
 
224
        return payload[:len(payload)-paddingLen], 255
225
225
}
226
226
 
227
227
func roundUp(a, b int) int {
228
 
    return a + (b-a%b)%b
 
228
        return a + (b-a%b)%b
229
229
}
230
230
 
231
231
// decrypt checks and strips the mac and decrypts the data in b.
232
232
func (hc *halfConn) decrypt(b *block) (bool, alert) {
233
 
    // pull out payload
234
 
    payload := b.data[recordHeaderLen:]
235
 
 
236
 
    macSize := 0
237
 
    if hc.mac != nil {
238
 
        macSize = hc.mac.Size()
239
 
    }
240
 
 
241
 
    paddingGood := byte(255)
242
 
 
243
 
    // decrypt
244
 
    if hc.cipher != nil {
245
 
        switch c := hc.cipher.(type) {
246
 
        case cipher.Stream:
247
 
            c.XORKeyStream(payload, payload)
248
 
        case cipher.BlockMode:
249
 
            blockSize := c.BlockSize()
250
 
 
251
 
            if len(payload)%blockSize != 0 || len(payload) < roundUp(macSize+1, blockSize) {
252
 
                return false, alertBadRecordMAC
253
 
            }
254
 
 
255
 
            c.CryptBlocks(payload, payload)
256
 
            if hc.version == versionSSL30 {
257
 
                payload, paddingGood = removePaddingSSL30(payload)
258
 
            } else {
259
 
                payload, paddingGood = removePadding(payload)
260
 
            }
261
 
            b.resize(recordHeaderLen + len(payload))
262
 
 
263
 
            // note that we still have a timing side-channel in the
264
 
            // MAC check, below. An attacker can align the record
265
 
            // so that a correct padding will cause one less hash
266
 
            // block to be calculated. Then they can iteratively
267
 
            // decrypt a record by breaking each byte. See
268
 
            // "Password Interception in a SSL/TLS Channel", Brice
269
 
            // Canvel et al.
270
 
            //
271
 
            // However, our behavior matches OpenSSL, so we leak
272
 
            // only as much as they do.
273
 
        default:
274
 
            panic("unknown cipher type")
275
 
        }
276
 
    }
277
 
 
278
 
    // check, strip mac
279
 
    if hc.mac != nil {
280
 
        if len(payload) < macSize {
281
 
            return false, alertBadRecordMAC
282
 
        }
283
 
 
284
 
        // strip mac off payload, b.data
285
 
        n := len(payload) - macSize
286
 
        b.data[3] = byte(n >> 8)
287
 
        b.data[4] = byte(n)
288
 
        b.resize(recordHeaderLen + n)
289
 
        remoteMAC := payload[n:]
290
 
        localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data)
291
 
        hc.incSeq()
292
 
 
293
 
        if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
294
 
            return false, alertBadRecordMAC
295
 
        }
296
 
        hc.inDigestBuf = localMAC
297
 
    }
298
 
 
299
 
    return true, 0
 
233
        // pull out payload
 
234
        payload := b.data[recordHeaderLen:]
 
235
 
 
236
        macSize := 0
 
237
        if hc.mac != nil {
 
238
                macSize = hc.mac.Size()
 
239
        }
 
240
 
 
241
        paddingGood := byte(255)
 
242
 
 
243
        // decrypt
 
244
        if hc.cipher != nil {
 
245
                switch c := hc.cipher.(type) {
 
246
                case cipher.Stream:
 
247
                        c.XORKeyStream(payload, payload)
 
248
                case cipher.BlockMode:
 
249
                        blockSize := c.BlockSize()
 
250
 
 
251
                        if len(payload)%blockSize != 0 || len(payload) < roundUp(macSize+1, blockSize) {
 
252
                                return false, alertBadRecordMAC
 
253
                        }
 
254
 
 
255
                        c.CryptBlocks(payload, payload)
 
256
                        if hc.version == versionSSL30 {
 
257
                                payload, paddingGood = removePaddingSSL30(payload)
 
258
                        } else {
 
259
                                payload, paddingGood = removePadding(payload)
 
260
                        }
 
261
                        b.resize(recordHeaderLen + len(payload))
 
262
 
 
263
                        // note that we still have a timing side-channel in the
 
264
                        // MAC check, below. An attacker can align the record
 
265
                        // so that a correct padding will cause one less hash
 
266
                        // block to be calculated. Then they can iteratively
 
267
                        // decrypt a record by breaking each byte. See
 
268
                        // "Password Interception in a SSL/TLS Channel", Brice
 
269
                        // Canvel et al.
 
270
                        //
 
271
                        // However, our behavior matches OpenSSL, so we leak
 
272
                        // only as much as they do.
 
273
                default:
 
274
                        panic("unknown cipher type")
 
275
                }
 
276
        }
 
277
 
 
278
        // check, strip mac
 
279
        if hc.mac != nil {
 
280
                if len(payload) < macSize {
 
281
                        return false, alertBadRecordMAC
 
282
                }
 
283
 
 
284
                // strip mac off payload, b.data
 
285
                n := len(payload) - macSize
 
286
                b.data[3] = byte(n >> 8)
 
287
                b.data[4] = byte(n)
 
288
                b.resize(recordHeaderLen + n)
 
289
                remoteMAC := payload[n:]
 
290
                localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data)
 
291
                hc.incSeq()
 
292
 
 
293
                if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
 
294
                        return false, alertBadRecordMAC
 
295
                }
 
296
                hc.inDigestBuf = localMAC
 
297
        }
 
298
 
 
299
        return true, 0
300
300
}
301
301
 
302
302
// padToBlockSize calculates the needed padding block, if any, for a payload.
305
305
// any suffix of payload as well as the needed padding to make finalBlock a
306
306
// full block.
307
307
func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
308
 
    overrun := len(payload) % blockSize
309
 
    paddingLen := blockSize - overrun
310
 
    prefix = payload[:len(payload)-overrun]
311
 
    finalBlock = make([]byte, blockSize)
312
 
    copy(finalBlock, payload[len(payload)-overrun:])
313
 
    for i := overrun; i < blockSize; i++ {
314
 
        finalBlock[i] = byte(paddingLen - 1)
315
 
    }
316
 
    return
 
308
        overrun := len(payload) % blockSize
 
309
        paddingLen := blockSize - overrun
 
310
        prefix = payload[:len(payload)-overrun]
 
311
        finalBlock = make([]byte, blockSize)
 
312
        copy(finalBlock, payload[len(payload)-overrun:])
 
313
        for i := overrun; i < blockSize; i++ {
 
314
                finalBlock[i] = byte(paddingLen - 1)
 
315
        }
 
316
        return
317
317
}
318
318
 
319
319
// encrypt encrypts and macs the data in b.
320
320
func (hc *halfConn) encrypt(b *block) (bool, alert) {
321
 
    // mac
322
 
    if hc.mac != nil {
323
 
        mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data)
324
 
        hc.incSeq()
325
 
 
326
 
        n := len(b.data)
327
 
        b.resize(n + len(mac))
328
 
        copy(b.data[n:], mac)
329
 
        hc.outDigestBuf = mac
330
 
    }
331
 
 
332
 
    payload := b.data[recordHeaderLen:]
333
 
 
334
 
    // encrypt
335
 
    if hc.cipher != nil {
336
 
        switch c := hc.cipher.(type) {
337
 
        case cipher.Stream:
338
 
            c.XORKeyStream(payload, payload)
339
 
        case cipher.BlockMode:
340
 
            prefix, finalBlock := padToBlockSize(payload, c.BlockSize())
341
 
            b.resize(recordHeaderLen + len(prefix) + len(finalBlock))
342
 
            c.CryptBlocks(b.data[recordHeaderLen:], prefix)
343
 
            c.CryptBlocks(b.data[recordHeaderLen+len(prefix):], finalBlock)
344
 
        default:
345
 
            panic("unknown cipher type")
346
 
        }
347
 
    }
348
 
 
349
 
    // update length to include MAC and any block padding needed.
350
 
    n := len(b.data) - recordHeaderLen
351
 
    b.data[3] = byte(n >> 8)
352
 
    b.data[4] = byte(n)
353
 
 
354
 
    return true, 0
 
321
        // mac
 
322
        if hc.mac != nil {
 
323
                mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data)
 
324
                hc.incSeq()
 
325
 
 
326
                n := len(b.data)
 
327
                b.resize(n + len(mac))
 
328
                copy(b.data[n:], mac)
 
329
                hc.outDigestBuf = mac
 
330
        }
 
331
 
 
332
        payload := b.data[recordHeaderLen:]
 
333
 
 
334
        // encrypt
 
335
        if hc.cipher != nil {
 
336
                switch c := hc.cipher.(type) {
 
337
                case cipher.Stream:
 
338
                        c.XORKeyStream(payload, payload)
 
339
                case cipher.BlockMode:
 
340
                        prefix, finalBlock := padToBlockSize(payload, c.BlockSize())
 
341
                        b.resize(recordHeaderLen + len(prefix) + len(finalBlock))
 
342
                        c.CryptBlocks(b.data[recordHeaderLen:], prefix)
 
343
                        c.CryptBlocks(b.data[recordHeaderLen+len(prefix):], finalBlock)
 
344
                default:
 
345
                        panic("unknown cipher type")
 
346
                }
 
347
        }
 
348
 
 
349
        // update length to include MAC and any block padding needed.
 
350
        n := len(b.data) - recordHeaderLen
 
351
        b.data[3] = byte(n >> 8)
 
352
        b.data[4] = byte(n)
 
353
 
 
354
        return true, 0
355
355
}
356
356
 
357
357
// A block is a simple data buffer.
358
358
type block struct {
359
 
    data []byte
360
 
    off  int // index for Read
361
 
    link *block
 
359
        data []byte
 
360
        off  int // index for Read
 
361
        link *block
362
362
}
363
363
 
364
364
// resize resizes block to be n bytes, growing if necessary.
365
365
func (b *block) resize(n int) {
366
 
    if n > cap(b.data) {
367
 
        b.reserve(n)
368
 
    }
369
 
    b.data = b.data[0:n]
 
366
        if n > cap(b.data) {
 
367
                b.reserve(n)
 
368
        }
 
369
        b.data = b.data[0:n]
370
370
}
371
371
 
372
372
// reserve makes sure that block contains a capacity of at least n bytes.
373
373
func (b *block) reserve(n int) {
374
 
    if cap(b.data) >= n {
375
 
        return
376
 
    }
377
 
    m := cap(b.data)
378
 
    if m == 0 {
379
 
        m = 1024
380
 
    }
381
 
    for m < n {
382
 
        m *= 2
383
 
    }
384
 
    data := make([]byte, len(b.data), m)
385
 
    copy(data, b.data)
386
 
    b.data = data
 
374
        if cap(b.data) >= n {
 
375
                return
 
376
        }
 
377
        m := cap(b.data)
 
378
        if m == 0 {
 
379
                m = 1024
 
380
        }
 
381
        for m < n {
 
382
                m *= 2
 
383
        }
 
384
        data := make([]byte, len(b.data), m)
 
385
        copy(data, b.data)
 
386
        b.data = data
387
387
}
388
388
 
389
389
// readFromUntil reads from r into b until b contains at least n bytes
390
390
// or else returns an error.
391
391
func (b *block) readFromUntil(r io.Reader, n int) error {
392
 
    // quick case
393
 
    if len(b.data) >= n {
394
 
        return nil
395
 
    }
 
392
        // quick case
 
393
        if len(b.data) >= n {
 
394
                return nil
 
395
        }
396
396
 
397
 
    // read until have enough.
398
 
    b.reserve(n)
399
 
    for {
400
 
        m, err := r.Read(b.data[len(b.data):cap(b.data)])
401
 
        b.data = b.data[0 : len(b.data)+m]
402
 
        if len(b.data) >= n {
403
 
            break
404
 
        }
405
 
        if err != nil {
406
 
            return err
407
 
        }
408
 
    }
409
 
    return nil
 
397
        // read until have enough.
 
398
        b.reserve(n)
 
399
        for {
 
400
                m, err := r.Read(b.data[len(b.data):cap(b.data)])
 
401
                b.data = b.data[0 : len(b.data)+m]
 
402
                if len(b.data) >= n {
 
403
                        break
 
404
                }
 
405
                if err != nil {
 
406
                        return err
 
407
                }
 
408
        }
 
409
        return nil
410
410
}
411
411
 
412
412
func (b *block) Read(p []byte) (n int, err error) {
413
 
    n = copy(p, b.data[b.off:])
414
 
    b.off += n
415
 
    return
 
413
        n = copy(p, b.data[b.off:])
 
414
        b.off += n
 
415
        return
416
416
}
417
417
 
418
418
// newBlock allocates a new block, from hc's free list if possible.
419
419
func (hc *halfConn) newBlock() *block {
420
 
    b := hc.bfree
421
 
    if b == nil {
422
 
        return new(block)
423
 
    }
424
 
    hc.bfree = b.link
425
 
    b.link = nil
426
 
    b.resize(0)
427
 
    return b
 
420
        b := hc.bfree
 
421
        if b == nil {
 
422
                return new(block)
 
423
        }
 
424
        hc.bfree = b.link
 
425
        b.link = nil
 
426
        b.resize(0)
 
427
        return b
428
428
}
429
429
 
430
430
// freeBlock returns a block to hc's free list.
432
432
// its free list at a time, so there's no need to worry about
433
433
// trimming the list, etc.
434
434
func (hc *halfConn) freeBlock(b *block) {
435
 
    b.link = hc.bfree
436
 
    hc.bfree = b
 
435
        b.link = hc.bfree
 
436
        hc.bfree = b
437
437
}
438
438
 
439
439
// splitBlock splits a block after the first n bytes,
440
440
// returning a block with those n bytes and a
441
441
// block with the remainder.  the latter may be nil.
442
442
func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
443
 
    if len(b.data) <= n {
444
 
        return b, nil
445
 
    }
446
 
    bb := hc.newBlock()
447
 
    bb.resize(len(b.data) - n)
448
 
    copy(bb.data, b.data[n:])
449
 
    b.data = b.data[0:n]
450
 
    return b, bb
 
443
        if len(b.data) <= n {
 
444
                return b, nil
 
445
        }
 
446
        bb := hc.newBlock()
 
447
        bb.resize(len(b.data) - n)
 
448
        copy(bb.data, b.data[n:])
 
449
        b.data = b.data[0:n]
 
450
        return b, bb
451
451
}
452
452
 
453
453
// readRecord reads the next TLS record from the connection
454
454
// and updates the record layer state.
455
455
// c.in.Mutex <= L; c.input == nil.
456
456
func (c *Conn) readRecord(want recordType) error {
457
 
    // Caller must be in sync with connection:
458
 
    // handshake data if handshake not yet completed,
459
 
    // else application data.
460
 
    switch want {
461
 
    default:
462
 
        return c.sendAlert(alertInternalError)
463
 
    case recordTypeHandshake, recordTypeChangeCipherSpec:
464
 
        if c.handshakeComplete {
465
 
            return c.sendAlert(alertInternalError)
466
 
        }
467
 
    case recordTypeApplicationData:
468
 
        if !c.handshakeComplete {
469
 
            return c.sendAlert(alertInternalError)
470
 
        }
471
 
    }
 
457
        // Caller must be in sync with connection:
 
458
        // handshake data if handshake not yet completed,
 
459
        // else application data.
 
460
        switch want {
 
461
        default:
 
462
                return c.sendAlert(alertInternalError)
 
463
        case recordTypeHandshake, recordTypeChangeCipherSpec:
 
464
                if c.handshakeComplete {
 
465
                        return c.sendAlert(alertInternalError)
 
466
                }
 
467
        case recordTypeApplicationData:
 
468
                if !c.handshakeComplete {
 
469
                        return c.sendAlert(alertInternalError)
 
470
                }
 
471
        }
472
472
 
473
473
Again:
474
 
    if c.rawInput == nil {
475
 
        c.rawInput = c.in.newBlock()
476
 
    }
477
 
    b := c.rawInput
478
 
 
479
 
    // Read header, payload.
480
 
    if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
481
 
        // RFC suggests that EOF without an alertCloseNotify is
482
 
        // an error, but popular web sites seem to do this,
483
 
        // so we can't make it an error.
484
 
        // if err == io.EOF {
485
 
        //      err = io.ErrUnexpectedEOF
486
 
        // }
487
 
        if e, ok := err.(net.Error); !ok || !e.Temporary() {
488
 
            c.setError(err)
489
 
        }
490
 
        return err
491
 
    }
492
 
    typ := recordType(b.data[0])
493
 
    vers := uint16(b.data[1])<<8 | uint16(b.data[2])
494
 
    n := int(b.data[3])<<8 | int(b.data[4])
495
 
    if c.haveVers && vers != c.vers {
496
 
        return c.sendAlert(alertProtocolVersion)
497
 
    }
498
 
    if n > maxCiphertext {
499
 
        return c.sendAlert(alertRecordOverflow)
500
 
    }
501
 
    if !c.haveVers {
502
 
        // First message, be extra suspicious:
503
 
        // this might not be a TLS client.
504
 
        // Bail out before reading a full 'body', if possible.
505
 
        // The current max version is 3.1.
506
 
        // If the version is >= 16.0, it's probably not real.
507
 
        // Similarly, a clientHello message encodes in
508
 
        // well under a kilobyte.  If the length is >= 12 kB,
509
 
        // it's probably not real.
510
 
        if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
511
 
            return c.sendAlert(alertUnexpectedMessage)
512
 
        }
513
 
    }
514
 
    if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
515
 
        if err == io.EOF {
516
 
            err = io.ErrUnexpectedEOF
517
 
        }
518
 
        if e, ok := err.(net.Error); !ok || !e.Temporary() {
519
 
            c.setError(err)
520
 
        }
521
 
        return err
522
 
    }
523
 
 
524
 
    // Process message.
525
 
    b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
526
 
    b.off = recordHeaderLen
527
 
    if ok, err := c.in.decrypt(b); !ok {
528
 
        return c.sendAlert(err)
529
 
    }
530
 
    data := b.data[b.off:]
531
 
    if len(data) > maxPlaintext {
532
 
        c.sendAlert(alertRecordOverflow)
533
 
        c.in.freeBlock(b)
534
 
        return c.error()
535
 
    }
536
 
 
537
 
    switch typ {
538
 
    default:
539
 
        c.sendAlert(alertUnexpectedMessage)
540
 
 
541
 
    case recordTypeAlert:
542
 
        if len(data) != 2 {
543
 
            c.sendAlert(alertUnexpectedMessage)
544
 
            break
545
 
        }
546
 
        if alert(data[1]) == alertCloseNotify {
547
 
            c.setError(io.EOF)
548
 
            break
549
 
        }
550
 
        switch data[0] {
551
 
        case alertLevelWarning:
552
 
            // drop on the floor
553
 
            c.in.freeBlock(b)
554
 
            goto Again
555
 
        case alertLevelError:
556
 
            c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
557
 
        default:
558
 
            c.sendAlert(alertUnexpectedMessage)
559
 
        }
560
 
 
561
 
    case recordTypeChangeCipherSpec:
562
 
        if typ != want || len(data) != 1 || data[0] != 1 {
563
 
            c.sendAlert(alertUnexpectedMessage)
564
 
            break
565
 
        }
566
 
        err := c.in.changeCipherSpec()
567
 
        if err != nil {
568
 
            c.sendAlert(err.(alert))
569
 
        }
570
 
 
571
 
    case recordTypeApplicationData:
572
 
        if typ != want {
573
 
            c.sendAlert(alertUnexpectedMessage)
574
 
            break
575
 
        }
576
 
        c.input = b
577
 
        b = nil
578
 
 
579
 
    case recordTypeHandshake:
580
 
        // TODO(rsc): Should at least pick off connection close.
581
 
        if typ != want && !c.isClient {
582
 
            return c.sendAlert(alertNoRenegotiation)
583
 
        }
584
 
        c.hand.Write(data)
585
 
    }
586
 
 
587
 
    if b != nil {
588
 
        c.in.freeBlock(b)
589
 
    }
590
 
    return c.error()
 
474
        if c.rawInput == nil {
 
475
                c.rawInput = c.in.newBlock()
 
476
        }
 
477
        b := c.rawInput
 
478
 
 
479
        // Read header, payload.
 
480
        if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
 
481
                // RFC suggests that EOF without an alertCloseNotify is
 
482
                // an error, but popular web sites seem to do this,
 
483
                // so we can't make it an error.
 
484
                // if err == io.EOF {
 
485
                //      err = io.ErrUnexpectedEOF
 
486
                // }
 
487
                if e, ok := err.(net.Error); !ok || !e.Temporary() {
 
488
                        c.setError(err)
 
489
                }
 
490
                return err
 
491
        }
 
492
        typ := recordType(b.data[0])
 
493
        vers := uint16(b.data[1])<<8 | uint16(b.data[2])
 
494
        n := int(b.data[3])<<8 | int(b.data[4])
 
495
        if c.haveVers && vers != c.vers {
 
496
                return c.sendAlert(alertProtocolVersion)
 
497
        }
 
498
        if n > maxCiphertext {
 
499
                return c.sendAlert(alertRecordOverflow)
 
500
        }
 
501
        if !c.haveVers {
 
502
                // First message, be extra suspicious:
 
503
                // this might not be a TLS client.
 
504
                // Bail out before reading a full 'body', if possible.
 
505
                // The current max version is 3.1.
 
506
                // If the version is >= 16.0, it's probably not real.
 
507
                // Similarly, a clientHello message encodes in
 
508
                // well under a kilobyte.  If the length is >= 12 kB,
 
509
                // it's probably not real.
 
510
                if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
 
511
                        return c.sendAlert(alertUnexpectedMessage)
 
512
                }
 
513
        }
 
514
        if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
 
515
                if err == io.EOF {
 
516
                        err = io.ErrUnexpectedEOF
 
517
                }
 
518
                if e, ok := err.(net.Error); !ok || !e.Temporary() {
 
519
                        c.setError(err)
 
520
                }
 
521
                return err
 
522
        }
 
523
 
 
524
        // Process message.
 
525
        b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
 
526
        b.off = recordHeaderLen
 
527
        if ok, err := c.in.decrypt(b); !ok {
 
528
                return c.sendAlert(err)
 
529
        }
 
530
        data := b.data[b.off:]
 
531
        if len(data) > maxPlaintext {
 
532
                c.sendAlert(alertRecordOverflow)
 
533
                c.in.freeBlock(b)
 
534
                return c.error()
 
535
        }
 
536
 
 
537
        switch typ {
 
538
        default:
 
539
                c.sendAlert(alertUnexpectedMessage)
 
540
 
 
541
        case recordTypeAlert:
 
542
                if len(data) != 2 {
 
543
                        c.sendAlert(alertUnexpectedMessage)
 
544
                        break
 
545
                }
 
546
                if alert(data[1]) == alertCloseNotify {
 
547
                        c.setError(io.EOF)
 
548
                        break
 
549
                }
 
550
                switch data[0] {
 
551
                case alertLevelWarning:
 
552
                        // drop on the floor
 
553
                        c.in.freeBlock(b)
 
554
                        goto Again
 
555
                case alertLevelError:
 
556
                        c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
 
557
                default:
 
558
                        c.sendAlert(alertUnexpectedMessage)
 
559
                }
 
560
 
 
561
        case recordTypeChangeCipherSpec:
 
562
                if typ != want || len(data) != 1 || data[0] != 1 {
 
563
                        c.sendAlert(alertUnexpectedMessage)
 
564
                        break
 
565
                }
 
566
                err := c.in.changeCipherSpec()
 
567
                if err != nil {
 
568
                        c.sendAlert(err.(alert))
 
569
                }
 
570
 
 
571
        case recordTypeApplicationData:
 
572
                if typ != want {
 
573
                        c.sendAlert(alertUnexpectedMessage)
 
574
                        break
 
575
                }
 
576
                c.input = b
 
577
                b = nil
 
578
 
 
579
        case recordTypeHandshake:
 
580
                // TODO(rsc): Should at least pick off connection close.
 
581
                if typ != want && !c.isClient {
 
582
                        return c.sendAlert(alertNoRenegotiation)
 
583
                }
 
584
                c.hand.Write(data)
 
585
        }
 
586
 
 
587
        if b != nil {
 
588
                c.in.freeBlock(b)
 
589
        }
 
590
        return c.error()
591
591
}
592
592
 
593
593
// sendAlert sends a TLS alert message.
594
594
// c.out.Mutex <= L.
595
595
func (c *Conn) sendAlertLocked(err alert) error {
596
 
    c.tmp[0] = alertLevelError
597
 
    if err == alertNoRenegotiation {
598
 
        c.tmp[0] = alertLevelWarning
599
 
    }
600
 
    c.tmp[1] = byte(err)
601
 
    c.writeRecord(recordTypeAlert, c.tmp[0:2])
602
 
    // closeNotify is a special case in that it isn't an error:
603
 
    if err != alertCloseNotify {
604
 
        return c.setError(&net.OpError{Op: "local error", Err: err})
605
 
    }
606
 
    return nil
 
596
        c.tmp[0] = alertLevelError
 
597
        if err == alertNoRenegotiation {
 
598
                c.tmp[0] = alertLevelWarning
 
599
        }
 
600
        c.tmp[1] = byte(err)
 
601
        c.writeRecord(recordTypeAlert, c.tmp[0:2])
 
602
        // closeNotify is a special case in that it isn't an error:
 
603
        if err != alertCloseNotify {
 
604
                return c.setError(&net.OpError{Op: "local error", Err: err})
 
605
        }
 
606
        return nil
607
607
}
608
608
 
609
609
// sendAlert sends a TLS alert message.
610
610
// L < c.out.Mutex.
611
611
func (c *Conn) sendAlert(err alert) error {
612
 
    c.out.Lock()
613
 
    defer c.out.Unlock()
614
 
    return c.sendAlertLocked(err)
 
612
        c.out.Lock()
 
613
        defer c.out.Unlock()
 
614
        return c.sendAlertLocked(err)
615
615
}
616
616
 
617
617
// writeRecord writes a TLS record with the given type and payload
618
618
// to the connection and updates the record layer state.
619
619
// c.out.Mutex <= L.
620
620
func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
621
 
    b := c.out.newBlock()
622
 
    for len(data) > 0 {
623
 
        m := len(data)
624
 
        if m > maxPlaintext {
625
 
            m = maxPlaintext
626
 
        }
627
 
        b.resize(recordHeaderLen + m)
628
 
        b.data[0] = byte(typ)
629
 
        vers := c.vers
630
 
        if vers == 0 {
631
 
            vers = maxVersion
632
 
        }
633
 
        b.data[1] = byte(vers >> 8)
634
 
        b.data[2] = byte(vers)
635
 
        b.data[3] = byte(m >> 8)
636
 
        b.data[4] = byte(m)
637
 
        copy(b.data[recordHeaderLen:], data)
638
 
        c.out.encrypt(b)
639
 
        _, err = c.conn.Write(b.data)
640
 
        if err != nil {
641
 
            break
642
 
        }
643
 
        n += m
644
 
        data = data[m:]
645
 
    }
646
 
    c.out.freeBlock(b)
 
621
        b := c.out.newBlock()
 
622
        for len(data) > 0 {
 
623
                m := len(data)
 
624
                if m > maxPlaintext {
 
625
                        m = maxPlaintext
 
626
                }
 
627
                b.resize(recordHeaderLen + m)
 
628
                b.data[0] = byte(typ)
 
629
                vers := c.vers
 
630
                if vers == 0 {
 
631
                        vers = maxVersion
 
632
                }
 
633
                b.data[1] = byte(vers >> 8)
 
634
                b.data[2] = byte(vers)
 
635
                b.data[3] = byte(m >> 8)
 
636
                b.data[4] = byte(m)
 
637
                copy(b.data[recordHeaderLen:], data)
 
638
                c.out.encrypt(b)
 
639
                _, err = c.conn.Write(b.data)
 
640
                if err != nil {
 
641
                        break
 
642
                }
 
643
                n += m
 
644
                data = data[m:]
 
645
        }
 
646
        c.out.freeBlock(b)
647
647
 
648
 
    if typ == recordTypeChangeCipherSpec {
649
 
        err = c.out.changeCipherSpec()
650
 
        if err != nil {
651
 
            // Cannot call sendAlert directly,
652
 
            // because we already hold c.out.Mutex.
653
 
            c.tmp[0] = alertLevelError
654
 
            c.tmp[1] = byte(err.(alert))
655
 
            c.writeRecord(recordTypeAlert, c.tmp[0:2])
656
 
            c.err = &net.OpError{Op: "local error", Err: err}
657
 
            return n, c.err
658
 
        }
659
 
    }
660
 
    return
 
648
        if typ == recordTypeChangeCipherSpec {
 
649
                err = c.out.changeCipherSpec()
 
650
                if err != nil {
 
651
                        // Cannot call sendAlert directly,
 
652
                        // because we already hold c.out.Mutex.
 
653
                        c.tmp[0] = alertLevelError
 
654
                        c.tmp[1] = byte(err.(alert))
 
655
                        c.writeRecord(recordTypeAlert, c.tmp[0:2])
 
656
                        c.err = &net.OpError{Op: "local error", Err: err}
 
657
                        return n, c.err
 
658
                }
 
659
        }
 
660
        return
661
661
}
662
662
 
663
663
// readHandshake reads the next handshake message from
664
664
// the record layer.
665
665
// c.in.Mutex < L; c.out.Mutex < L.
666
666
func (c *Conn) readHandshake() (interface{}, error) {
667
 
    for c.hand.Len() < 4 {
668
 
        if c.err != nil {
669
 
            return nil, c.err
670
 
        }
671
 
        if err := c.readRecord(recordTypeHandshake); err != nil {
672
 
            return nil, err
673
 
        }
674
 
    }
675
 
 
676
 
    data := c.hand.Bytes()
677
 
    n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
678
 
    if n > maxHandshake {
679
 
        c.sendAlert(alertInternalError)
680
 
        return nil, c.err
681
 
    }
682
 
    for c.hand.Len() < 4+n {
683
 
        if c.err != nil {
684
 
            return nil, c.err
685
 
        }
686
 
        if err := c.readRecord(recordTypeHandshake); err != nil {
687
 
            return nil, err
688
 
        }
689
 
    }
690
 
    data = c.hand.Next(4 + n)
691
 
    var m handshakeMessage
692
 
    switch data[0] {
693
 
    case typeHelloRequest:
694
 
        m = new(helloRequestMsg)
695
 
    case typeClientHello:
696
 
        m = new(clientHelloMsg)
697
 
    case typeServerHello:
698
 
        m = new(serverHelloMsg)
699
 
    case typeCertificate:
700
 
        m = new(certificateMsg)
701
 
    case typeCertificateRequest:
702
 
        m = new(certificateRequestMsg)
703
 
    case typeCertificateStatus:
704
 
        m = new(certificateStatusMsg)
705
 
    case typeServerKeyExchange:
706
 
        m = new(serverKeyExchangeMsg)
707
 
    case typeServerHelloDone:
708
 
        m = new(serverHelloDoneMsg)
709
 
    case typeClientKeyExchange:
710
 
        m = new(clientKeyExchangeMsg)
711
 
    case typeCertificateVerify:
712
 
        m = new(certificateVerifyMsg)
713
 
    case typeNextProtocol:
714
 
        m = new(nextProtoMsg)
715
 
    case typeFinished:
716
 
        m = new(finishedMsg)
717
 
    default:
718
 
        c.sendAlert(alertUnexpectedMessage)
719
 
        return nil, alertUnexpectedMessage
720
 
    }
721
 
 
722
 
    // The handshake message unmarshallers
723
 
    // expect to be able to keep references to data,
724
 
    // so pass in a fresh copy that won't be overwritten.
725
 
    data = append([]byte(nil), data...)
726
 
 
727
 
    if !m.unmarshal(data) {
728
 
        c.sendAlert(alertUnexpectedMessage)
729
 
        return nil, alertUnexpectedMessage
730
 
    }
731
 
    return m, nil
 
667
        for c.hand.Len() < 4 {
 
668
                if c.err != nil {
 
669
                        return nil, c.err
 
670
                }
 
671
                if err := c.readRecord(recordTypeHandshake); err != nil {
 
672
                        return nil, err
 
673
                }
 
674
        }
 
675
 
 
676
        data := c.hand.Bytes()
 
677
        n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
 
678
        if n > maxHandshake {
 
679
                c.sendAlert(alertInternalError)
 
680
                return nil, c.err
 
681
        }
 
682
        for c.hand.Len() < 4+n {
 
683
                if c.err != nil {
 
684
                        return nil, c.err
 
685
                }
 
686
                if err := c.readRecord(recordTypeHandshake); err != nil {
 
687
                        return nil, err
 
688
                }
 
689
        }
 
690
        data = c.hand.Next(4 + n)
 
691
        var m handshakeMessage
 
692
        switch data[0] {
 
693
        case typeHelloRequest:
 
694
                m = new(helloRequestMsg)
 
695
        case typeClientHello:
 
696
                m = new(clientHelloMsg)
 
697
        case typeServerHello:
 
698
                m = new(serverHelloMsg)
 
699
        case typeCertificate:
 
700
                m = new(certificateMsg)
 
701
        case typeCertificateRequest:
 
702
                m = new(certificateRequestMsg)
 
703
        case typeCertificateStatus:
 
704
                m = new(certificateStatusMsg)
 
705
        case typeServerKeyExchange:
 
706
                m = new(serverKeyExchangeMsg)
 
707
        case typeServerHelloDone:
 
708
                m = new(serverHelloDoneMsg)
 
709
        case typeClientKeyExchange:
 
710
                m = new(clientKeyExchangeMsg)
 
711
        case typeCertificateVerify:
 
712
                m = new(certificateVerifyMsg)
 
713
        case typeNextProtocol:
 
714
                m = new(nextProtoMsg)
 
715
        case typeFinished:
 
716
                m = new(finishedMsg)
 
717
        default:
 
718
                c.sendAlert(alertUnexpectedMessage)
 
719
                return nil, alertUnexpectedMessage
 
720
        }
 
721
 
 
722
        // The handshake message unmarshallers
 
723
        // expect to be able to keep references to data,
 
724
        // so pass in a fresh copy that won't be overwritten.
 
725
        data = append([]byte(nil), data...)
 
726
 
 
727
        if !m.unmarshal(data) {
 
728
                c.sendAlert(alertUnexpectedMessage)
 
729
                return nil, alertUnexpectedMessage
 
730
        }
 
731
        return m, nil
732
732
}
733
733
 
734
734
// Write writes data to the connection.
735
735
func (c *Conn) Write(b []byte) (int, error) {
736
 
    if c.err != nil {
737
 
        return 0, c.err
738
 
    }
739
 
 
740
 
    if c.err = c.Handshake(); c.err != nil {
741
 
        return 0, c.err
742
 
    }
743
 
 
744
 
    c.out.Lock()
745
 
    defer c.out.Unlock()
746
 
 
747
 
    if !c.handshakeComplete {
748
 
        return 0, alertInternalError
749
 
    }
750
 
 
751
 
    var n int
752
 
    n, c.err = c.writeRecord(recordTypeApplicationData, b)
753
 
    return n, c.err
 
736
        if c.err != nil {
 
737
                return 0, c.err
 
738
        }
 
739
 
 
740
        if c.err = c.Handshake(); c.err != nil {
 
741
                return 0, c.err
 
742
        }
 
743
 
 
744
        c.out.Lock()
 
745
        defer c.out.Unlock()
 
746
 
 
747
        if !c.handshakeComplete {
 
748
                return 0, alertInternalError
 
749
        }
 
750
 
 
751
        var n int
 
752
        n, c.err = c.writeRecord(recordTypeApplicationData, b)
 
753
        return n, c.err
754
754
}
755
755
 
756
756
func (c *Conn) handleRenegotiation() error {
757
 
    c.handshakeComplete = false
758
 
    if !c.isClient {
759
 
        panic("renegotiation should only happen for a client")
760
 
    }
761
 
 
762
 
    msg, err := c.readHandshake()
763
 
    if err != nil {
764
 
        return err
765
 
    }
766
 
    _, ok := msg.(*helloRequestMsg)
767
 
    if !ok {
768
 
        c.sendAlert(alertUnexpectedMessage)
769
 
        return alertUnexpectedMessage
770
 
    }
771
 
 
772
 
    return c.Handshake()
 
757
        c.handshakeComplete = false
 
758
        if !c.isClient {
 
759
                panic("renegotiation should only happen for a client")
 
760
        }
 
761
 
 
762
        msg, err := c.readHandshake()
 
763
        if err != nil {
 
764
                return err
 
765
        }
 
766
        _, ok := msg.(*helloRequestMsg)
 
767
        if !ok {
 
768
                c.sendAlert(alertUnexpectedMessage)
 
769
                return alertUnexpectedMessage
 
770
        }
 
771
 
 
772
        return c.Handshake()
773
773
}
774
774
 
775
775
// Read can be made to time out and return a net.Error with Timeout() == true
776
776
// after a fixed time limit; see SetDeadline and SetReadDeadline.
777
777
func (c *Conn) Read(b []byte) (n int, err error) {
778
 
    if err = c.Handshake(); err != nil {
779
 
        return
780
 
    }
781
 
 
782
 
    c.in.Lock()
783
 
    defer c.in.Unlock()
784
 
 
785
 
    for c.input == nil && c.err == nil {
786
 
        if err := c.readRecord(recordTypeApplicationData); err != nil {
787
 
            // Soft error, like EAGAIN
788
 
            return 0, err
789
 
        }
790
 
        if c.hand.Len() > 0 {
791
 
            // We received handshake bytes, indicating the start of
792
 
            // a renegotiation.
793
 
            if err := c.handleRenegotiation(); err != nil {
794
 
                return 0, err
795
 
            }
796
 
            continue
797
 
        }
798
 
    }
799
 
    if c.err != nil {
800
 
        return 0, c.err
801
 
    }
802
 
    n, err = c.input.Read(b)
803
 
    if c.input.off >= len(c.input.data) {
804
 
        c.in.freeBlock(c.input)
805
 
        c.input = nil
806
 
    }
807
 
    return n, nil
 
778
        if err = c.Handshake(); err != nil {
 
779
                return
 
780
        }
 
781
 
 
782
        c.in.Lock()
 
783
        defer c.in.Unlock()
 
784
 
 
785
        for c.input == nil && c.err == nil {
 
786
                if err := c.readRecord(recordTypeApplicationData); err != nil {
 
787
                        // Soft error, like EAGAIN
 
788
                        return 0, err
 
789
                }
 
790
                if c.hand.Len() > 0 {
 
791
                        // We received handshake bytes, indicating the start of
 
792
                        // a renegotiation.
 
793
                        if err := c.handleRenegotiation(); err != nil {
 
794
                                return 0, err
 
795
                        }
 
796
                        continue
 
797
                }
 
798
        }
 
799
        if c.err != nil {
 
800
                return 0, c.err
 
801
        }
 
802
        n, err = c.input.Read(b)
 
803
        if c.input.off >= len(c.input.data) {
 
804
                c.in.freeBlock(c.input)
 
805
                c.input = nil
 
806
        }
 
807
        return n, nil
808
808
}
809
809
 
810
810
// Close closes the connection.
811
811
func (c *Conn) Close() error {
812
 
    var alertErr error
813
 
 
814
 
    c.handshakeMutex.Lock()
815
 
    defer c.handshakeMutex.Unlock()
816
 
    if c.handshakeComplete {
817
 
        alertErr = c.sendAlert(alertCloseNotify)
818
 
    }
819
 
 
820
 
    if err := c.conn.Close(); err != nil {
821
 
        return err
822
 
    }
823
 
    return alertErr
 
812
        var alertErr error
 
813
 
 
814
        c.handshakeMutex.Lock()
 
815
        defer c.handshakeMutex.Unlock()
 
816
        if c.handshakeComplete {
 
817
                alertErr = c.sendAlert(alertCloseNotify)
 
818
        }
 
819
 
 
820
        if err := c.conn.Close(); err != nil {
 
821
                return err
 
822
        }
 
823
        return alertErr
824
824
}
825
825
 
826
826
// Handshake runs the client or server handshake
828
828
// Most uses of this package need not call Handshake
829
829
// explicitly: the first Read or Write will call it automatically.
830
830
func (c *Conn) Handshake() error {
831
 
    c.handshakeMutex.Lock()
832
 
    defer c.handshakeMutex.Unlock()
833
 
    if err := c.error(); err != nil {
834
 
        return err
835
 
    }
836
 
    if c.handshakeComplete {
837
 
        return nil
838
 
    }
839
 
    if c.isClient {
840
 
        return c.clientHandshake()
841
 
    }
842
 
    return c.serverHandshake()
 
831
        c.handshakeMutex.Lock()
 
832
        defer c.handshakeMutex.Unlock()
 
833
        if err := c.error(); err != nil {
 
834
                return err
 
835
        }
 
836
        if c.handshakeComplete {
 
837
                return nil
 
838
        }
 
839
        if c.isClient {
 
840
                return c.clientHandshake()
 
841
        }
 
842
        return c.serverHandshake()
843
843
}
844
844
 
845
845
// ConnectionState returns basic TLS details about the connection.
846
846
func (c *Conn) ConnectionState() ConnectionState {
847
 
    c.handshakeMutex.Lock()
848
 
    defer c.handshakeMutex.Unlock()
849
 
 
850
 
    var state ConnectionState
851
 
    state.HandshakeComplete = c.handshakeComplete
852
 
    if c.handshakeComplete {
853
 
        state.NegotiatedProtocol = c.clientProtocol
854
 
        state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
855
 
        state.CipherSuite = c.cipherSuite
856
 
        state.PeerCertificates = c.peerCertificates
857
 
        state.VerifiedChains = c.verifiedChains
858
 
        state.ServerName = c.serverName
859
 
    }
860
 
 
861
 
    return state
 
847
        c.handshakeMutex.Lock()
 
848
        defer c.handshakeMutex.Unlock()
 
849
 
 
850
        var state ConnectionState
 
851
        state.HandshakeComplete = c.handshakeComplete
 
852
        if c.handshakeComplete {
 
853
                state.NegotiatedProtocol = c.clientProtocol
 
854
                state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
 
855
                state.CipherSuite = c.cipherSuite
 
856
                state.PeerCertificates = c.peerCertificates
 
857
                state.VerifiedChains = c.verifiedChains
 
858
                state.ServerName = c.serverName
 
859
        }
 
860
 
 
861
        return state
862
862
}
863
863
 
864
864
// OCSPResponse returns the stapled OCSP response from the TLS server, if
865
865
// any. (Only valid for client connections.)
866
866
func (c *Conn) OCSPResponse() []byte {
867
 
    c.handshakeMutex.Lock()
868
 
    defer c.handshakeMutex.Unlock()
 
867
        c.handshakeMutex.Lock()
 
868
        defer c.handshakeMutex.Unlock()
869
869
 
870
 
    return c.ocspResponse
 
870
        return c.ocspResponse
871
871
}
872
872
 
873
873
// VerifyHostname checks that the peer certificate chain is valid for
874
874
// connecting to host.  If so, it returns nil; if not, it returns an error
875
875
// describing the problem.
876
876
func (c *Conn) VerifyHostname(host string) error {
877
 
    c.handshakeMutex.Lock()
878
 
    defer c.handshakeMutex.Unlock()
879
 
    if !c.isClient {
880
 
        return errors.New("VerifyHostname called on TLS server connection")
881
 
    }
882
 
    if !c.handshakeComplete {
883
 
        return errors.New("TLS handshake has not yet been performed")
884
 
    }
885
 
    return c.peerCertificates[0].VerifyHostname(host)
 
877
        c.handshakeMutex.Lock()
 
878
        defer c.handshakeMutex.Unlock()
 
879
        if !c.isClient {
 
880
                return errors.New("VerifyHostname called on TLS server connection")
 
881
        }
 
882
        if !c.handshakeComplete {
 
883
                return errors.New("TLS handshake has not yet been performed")
 
884
        }
 
885
        return c.peerCertificates[0].VerifyHostname(host)
886
886
}