~djfroofy/txaws/930359-headers

« back to all changes in this revision

Viewing changes to txaws/client/tests/test_base.py

  • Committer: Duncan McGreggor
  • Date: 2012-01-26 23:07:01 UTC
  • mfrom: (122.1.11 920309-fix-ca-certs)
  • Revision ID: duncan@dreamhost.com-20120126230701-3faby2yptkwfbktg
This change fixes issues with testing SSL certificates on Mac OS X. It also
adds unit tests for get_ca_certs, which previously had none.

Reviewers: Stephon Striplin, Jamu Kakar.
Fixes: https://bugs.launchpad.net/txaws/+bug/920309

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
1
import os
2
2
 
3
 
from OpenSSL.crypto import load_certificate, FILETYPE_PEM
4
 
from OpenSSL.SSL import Error as SSLError
5
 
from OpenSSL.version import __version__ as pyopenssl_version
6
 
 
7
3
from twisted.internet import reactor
8
 
from twisted.internet.ssl import DefaultOpenSSLContextFactory
9
4
from twisted.internet.error import ConnectionRefusedError
10
5
from twisted.protocols.policies import WrappingFactory
11
6
from twisted.python import log
12
7
from twisted.python.filepath import FilePath
13
8
from twisted.python.failure import Failure
 
9
from twisted.test.test_sslverify import makeCertificate
14
10
from twisted.web import server, static
15
11
from twisted.web.client import HTTPClientFactory
16
12
from twisted.web.error import Error as TwistedWebError
17
13
 
 
14
from txaws.client import ssl
18
15
from txaws.client.base import BaseClient, BaseQuery, error_wrapper
19
 
from txaws.client.ssl import VerifyingContextFactory
20
16
from txaws.service import AWSServiceEndpoint
21
17
from txaws.testing.base import TXAWSTestCase
22
18
 
23
19
 
24
 
def sibpath(path):
    """
    Return C{path} joined onto the directory that contains this module.

    Used to locate the SSL key/certificate fixture files kept next to the
    tests.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, path)
26
 
 
27
 
 
28
 
# SSL key/certificate fixture files stored next to this module, used by
# BaseQuerySSLTestCase: the plain pair is expected to verify, the "bad" pair
# to fail verification, and the "san" pair to exercise subjectAltName checks.
PRIVKEY, PUBKEY = sibpath("private.ssl"), sibpath("public.ssl")
BADPRIVKEY, BADPUBKEY = sibpath("badprivate.ssl"), sibpath("badpublic.ssl")
PRIVSANKEY, PUBSANKEY = sibpath("private_san.ssl"), sibpath("public_san.ssl")
34
 
 
35
 
 
36
20
class ErrorWrapperTestCase(TXAWSTestCase):
37
21
 
38
22
    def test_204_no_content(self):
168
152
        d.addCallback(query.get_response_headers)
169
153
        return d.addCallback(check_results)
170
154
 
 
155
    # XXX for systems that don't have certs in the DEFAULT_CERT_PATH, this test
 
156
    # will fail; instead, let's create some certs in a temp directory and set
 
157
    # the DEFAULT_CERT_PATH to point there.
171
158
    def test_ssl_hostname_verification(self):
172
159
        """
173
160
        If the endpoint passed to L{BaseQuery} has C{ssl_hostname_verification}
183
170
            def connectSSL(self, host, port, client, factory):
184
171
                self.connects.append((host, port, client, factory))
185
172
 
 
173
        certs = makeCertificate(O="Test Certificate", CN="something")[1]
 
174
        self.patch(ssl, "_ca_certs", certs)
186
175
        fake_reactor = FakeReactor()
187
176
        endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
188
177
        query = BaseQuery("an action", "creds", endpoint, fake_reactor)
190
179
        [(host, port, client, factory)] = fake_reactor.connects
191
180
        self.assertEqual("example.com", host)
192
181
        self.assertEqual(443, port)
193
 
        self.assertTrue(isinstance(factory, VerifyingContextFactory))
 
182
        self.assertTrue(isinstance(factory, ssl.VerifyingContextFactory))
194
183
        self.assertEqual("example.com", factory.host)
195
184
        self.assertNotEqual([], factory.caCerts)
196
 
 
197
 
 
198
 
class BaseQuerySSLTestCase(TXAWSTestCase):
199
 
 
200
 
    def setUp(self):
201
 
        self.cleanupServerConnections = 0
202
 
        name = self.mktemp()
203
 
        os.mkdir(name)
204
 
        FilePath(name).child("file").setContent("0123456789")
205
 
        r = static.File(name)
206
 
        self.site = server.Site(r, timeout=None)
207
 
        self.wrapper = WrappingFactory(self.site)
208
 
        from txaws.client import ssl
209
 
        pub_key = file(PUBKEY)
210
 
        pub_key_data = pub_key.read()
211
 
        pub_key.close()
212
 
        pub_key_san = file(PUBSANKEY)
213
 
        pub_key_san_data = pub_key_san.read()
214
 
        pub_key_san.close()
215
 
        ssl._ca_certs = [load_certificate(FILETYPE_PEM, pub_key_data),
216
 
                         load_certificate(FILETYPE_PEM, pub_key_san_data)]
217
 
 
218
 
    def tearDown(self):
219
 
        from txaws.client import ssl
220
 
        ssl._ca_certs = None
221
 
        # If the test indicated it might leave some server-side connections
222
 
        # around, clean them up.
223
 
        connections = self.wrapper.protocols.keys()
224
 
        # If there are fewer server-side connections than requested,
225
 
        # that's okay.  Some might have noticed that the client closed
226
 
        # the connection and cleaned up after themselves.
227
 
        for n in range(min(len(connections), self.cleanupServerConnections)):
228
 
            proto = connections.pop()
229
 
            log.msg("Closing %r" % (proto,))
230
 
            proto.transport.loseConnection()
231
 
        if connections:
232
 
            log.msg("Some left-over connections; this test is probably buggy.")
233
 
        return self.port.stopListening()
234
 
 
235
 
    def _get_url(self, path):
236
 
        return "https://localhost:%d/%s" % (self.portno, path)
237
 
 
238
 
    def test_ssl_verification_positive(self):
239
 
        """
240
 
        The L{VerifyingContextFactory} properly allows to connect to the
241
 
        endpoint if the certificates match.
242
 
        """
243
 
        context_factory = DefaultOpenSSLContextFactory(PRIVKEY, PUBKEY)
244
 
        self.port = reactor.listenSSL(
245
 
            0, self.site, context_factory, interface="127.0.0.1")
246
 
        self.portno = self.port.getHost().port
247
 
 
248
 
        endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
249
 
        query = BaseQuery("an action", "creds", endpoint)
250
 
        d = query.get_page(self._get_url("file"))
251
 
        return d.addCallback(self.assertEquals, "0123456789")
252
 
 
253
 
    def test_ssl_verification_negative(self):
254
 
        """
255
 
        The L{VerifyingContextFactory} fails with a SSL error the certificates
256
 
        can't be checked.
257
 
        """
258
 
        context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY)
259
 
        self.port = reactor.listenSSL(
260
 
            0, self.site, context_factory, interface="127.0.0.1")
261
 
        self.portno = self.port.getHost().port
262
 
 
263
 
        endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
264
 
        query = BaseQuery("an action", "creds", endpoint)
265
 
        d = query.get_page(self._get_url("file"))
266
 
        return self.assertFailure(d, SSLError)
267
 
 
268
 
    def test_ssl_verification_bypassed(self):
269
 
        """
270
 
        L{BaseQuery} doesn't use L{VerifyingContextFactory}
271
 
        if C{ssl_hostname_verification} is C{False}, thus allowing to connect
272
 
        to non-secure endpoints.
273
 
        """
274
 
        context_factory = DefaultOpenSSLContextFactory(BADPRIVKEY, BADPUBKEY)
275
 
        self.port = reactor.listenSSL(
276
 
            0, self.site, context_factory, interface="127.0.0.1")
277
 
        self.portno = self.port.getHost().port
278
 
 
279
 
        endpoint = AWSServiceEndpoint(ssl_hostname_verification=False)
280
 
        query = BaseQuery("an action", "creds", endpoint)
281
 
        d = query.get_page(self._get_url("file"))
282
 
        return d.addCallback(self.assertEquals, "0123456789")
283
 
 
284
 
    def test_ssl_subject_alt_name(self):
285
 
        """
286
 
        L{VerifyingContextFactory} supports checking C{subjectAltName} in the
287
 
        certificate if it's available.
288
 
        """
289
 
        context_factory = DefaultOpenSSLContextFactory(PRIVSANKEY, PUBSANKEY)
290
 
        self.port = reactor.listenSSL(
291
 
            0, self.site, context_factory, interface="127.0.0.1")
292
 
        self.portno = self.port.getHost().port
293
 
 
294
 
        endpoint = AWSServiceEndpoint(ssl_hostname_verification=True)
295
 
        query = BaseQuery("an action", "creds", endpoint)
296
 
        d = query.get_page("https://127.0.0.1:%d/file" % (self.portno,))
297
 
        return d.addCallback(self.assertEquals, "0123456789")
298
 
 
299
 
    if pyopenssl_version < "0.12":
300
 
        test_ssl_subject_alt_name.skip = (
301
 
            "subjectAltName not supported by older PyOpenSSL")