~tribaal/txaws/xss-hardening

« back to all changes in this revision

Viewing changes to txaws/client/base.py

  • Committer: Thomas Hervé
  • Date: 2011-12-01 12:19:46 UTC
  • mfrom: (102.1.4 ssl-verify)
  • Revision ID: thomas@canonical.com-20111201121946-kbhurhwhaiotemfm
Merge ssl-verify [r=free.ekanayaka] [f=781949]

Add an ssl_hostname_verification option to AWSServiceEndpoint, which enables
verification of the server's SSL certificate and hostname at connection time.

Show diffs side-by-side

added added

removed removed

Lines of Context:
3
3
except ImportError:
4
4
    from xml.parsers.expat import ExpatError as ParseError
5
5
 
6
 
from twisted.internet import reactor, ssl
 
6
from twisted.internet.ssl import ClientContextFactory
7
7
from twisted.web import http
8
8
from twisted.web.client import HTTPClientFactory
9
9
from twisted.web.error import Error as TwistedWebError
12
12
from txaws.credentials import AWSCredentials
13
13
from txaws.exception import AWSResponseParseError
14
14
from txaws.service import AWSServiceEndpoint
 
15
from txaws.client.ssl import VerifyingContextFactory
15
16
 
16
17
 
17
18
def error_wrapper(error, errorClass):
73
74
 
74
75
class BaseQuery(object):
75
76
 
76
 
    def __init__(self, action=None, creds=None, endpoint=None):
 
77
    def __init__(self, action=None, creds=None, endpoint=None, reactor=None):
77
78
        if not action:
78
79
            raise TypeError("The query requires an action parameter.")
79
80
        self.factory = HTTPClientFactory
80
81
        self.action = action
81
82
        self.creds = creds
82
83
        self.endpoint = endpoint
 
84
        if reactor is None:
 
85
            from twisted.internet import reactor
 
86
        self.reactor = reactor
83
87
        self.client = None
84
88
 
85
89
    def get_page(self, url, *args, **kwds):
92
96
        contextFactory = None
93
97
        scheme, host, port, path = parse(url)
94
98
        self.client = self.factory(url, *args, **kwds)
95
 
        if scheme == 'https':
96
 
            contextFactory = ssl.ClientContextFactory()
97
 
            reactor.connectSSL(host, port, self.client, contextFactory)
 
99
        if scheme == "https":
 
100
            if self.endpoint.ssl_hostname_verification:
 
101
                contextFactory = VerifyingContextFactory(host)
 
102
            else:
 
103
                contextFactory = ClientContextFactory()
 
104
            self.reactor.connectSSL(host, port, self.client, contextFactory)
98
105
        else:
99
 
            reactor.connectTCP(host, port, self.client)
 
106
            self.reactor.connectTCP(host, port, self.client)
100
107
        return self.client.deferred
101
108
 
102
109
    def get_request_headers(self, *args, **kwds):