~ubuntu-branches/ubuntu/precise/checkbox/precise

« back to all changes in this revision

Viewing changes to checkbox/lib/transport.py

  • Committer: Bazaar Package Importer
  • Author(s): Marc Tardif
  • Date: 2009-01-20 16:46:15 UTC
  • Revision ID: james.westby@ubuntu.com-20090120164615-7iz6nmlef41h4vx2
Tags: 0.4
* Setup bzr-builddeb in native mode.
* Removed LGPL notice from the copyright file.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#
 
2
# Copyright (c) 2008 Canonical
 
3
#
 
4
# Written by Marc Tardif <marc@interunion.ca>
 
5
#
 
6
# This file is part of Checkbox.
 
7
#
 
8
# Checkbox is free software: you can redistribute it and/or modify
 
9
# it under the terms of the GNU General Public License as published by
 
10
# the Free Software Foundation, either version 3 of the License, or
 
11
# (at your option) any later version.
 
12
#
 
13
# Checkbox is distributed in the hope that it will be useful,
 
14
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
15
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
16
# GNU General Public License for more details.
 
17
#
 
18
# You should have received a copy of the GNU General Public License
 
19
# along with Checkbox.  If not, see <http://www.gnu.org/licenses/>.
 
20
#
 
21
import logging
 
22
 
 
23
import os
 
24
import stat
 
25
import sys
 
26
import posixpath
 
27
 
 
28
import mimetools
 
29
import mimetypes
 
30
import socket
 
31
import httplib
 
32
import urllib
 
33
 
 
34
 
 
35
class ProxyHTTPConnection(httplib.HTTPConnection):
 
36
 
 
37
    _ports = {"http" : httplib.HTTP_PORT, "https" : httplib.HTTPS_PORT}
 
38
 
 
39
    def request(self, method, url, body=None, headers={}):
 
40
        #request is called before connect, so can interpret url and get
 
41
        #real host/port to be used to make CONNECT request to proxy
 
42
        scheme, rest = urllib.splittype(url)
 
43
        if scheme is None:
 
44
            raise ValueError, "unknown URL type: %s" % url
 
45
        #get host
 
46
        host, rest = urllib.splithost(rest)
 
47
        #try to get port
 
48
        host, port = urllib.splitport(host)
 
49
        #if port is not defined try to get from scheme
 
50
        if port is None:
 
51
            try:
 
52
                port = self._ports[scheme]
 
53
            except KeyError:
 
54
                raise ValueError, "unknown protocol for: %s" % url
 
55
        self._real_host = host
 
56
        self._real_port = port
 
57
        httplib.HTTPConnection.request(self, method, url, body, headers)
 
58
 
 
59
    def connect(self):
 
60
        httplib.HTTPConnection.connect(self)
 
61
        #send proxy CONNECT request
 
62
        self.send("CONNECT %s:%d HTTP/1.0\r\n\r\n" % (self._real_host, self._real_port))
 
63
        #expect a HTTP/1.0 200 Connection established
 
64
        response = self.response_class(self.sock, strict=self.strict, method=self._method)
 
65
        (version, code, message) = response._read_status()
 
66
        #probably here we can handle auth requests...
 
67
        if code != 200:
 
68
            #proxy returned and error, abort connection, and raise exception
 
69
            self.close()
 
70
            raise socket.error, "Proxy connection failed: %d %s" % (code, message.strip())
 
71
        #eat up header block from proxy....
 
72
        while True:
 
73
            #should not use directly fp probablu
 
74
            line = response.fp.readline()
 
75
            if line == "\r\n":
 
76
                break
 
77
 
 
78
 
 
79
class ProxyHTTPSConnection(ProxyHTTPConnection):
 
80
 
 
81
    default_port = httplib.HTTPS_PORT
 
82
 
 
83
    def __init__(self, host, port=None, key_file=None, cert_file=None, strict=None):
 
84
        ProxyHTTPConnection.__init__(self, host, port)
 
85
        self.key_file = key_file
 
86
        self.cert_file = cert_file
 
87
 
 
88
    def connect(self):
 
89
        ProxyHTTPConnection.connect(self)
 
90
        #make the sock ssl-aware
 
91
        ssl = socket.ssl(self.sock, self.key_file, self.cert_file)
 
92
        self.sock = httplib.FakeSocket(self.sock, ssl)
 
93
 
 
94
 
 
95
class HTTPTransport(object):
 
96
    """Transport makes a request to exchange message data over HTTP."""
 
97
 
 
98
    def __init__(self, url):
 
99
        self.url = url
 
100
 
 
101
        proxies = urllib.getproxies()
 
102
        self.http_proxy = proxies.get("http")
 
103
        self.https_proxy = proxies.get("https")
 
104
 
 
105
    def _unpack_host_and_port(self, string):
 
106
        scheme, rest = urllib.splittype(string)
 
107
        host, rest = urllib.splithost(rest)
 
108
        host, port = urllib.splitport(host)
 
109
        return (host, port)
 
110
 
 
111
    def _get_connection(self, timeout=0):
 
112
        if timeout:
 
113
            socket.setdefaulttimeout(timeout)
 
114
 
 
115
        scheme, rest = urllib.splittype(self.url)
 
116
        if scheme == "http":
 
117
            if self.http_proxy:
 
118
                host, port = self._unpack_host_and_port(self.http_proxy)
 
119
            else:
 
120
                host, port = self._unpack_host_and_port(self.url)
 
121
 
 
122
            connection = httplib.HTTPConnection(host, port)
 
123
        elif scheme == "https":
 
124
            if self.https_proxy:
 
125
                host, port = self._unpack_host_and_port(self.https_proxy)
 
126
                connection = ProxyHTTPSConnection(host, port)
 
127
            else:
 
128
                host, port = self._unpack_host_and_port(self.url)
 
129
                connection = httplib.HTTPSConnection(host, port)
 
130
        else:
 
131
            raise Exception, "Unknown URL scheme: %s" % scheme
 
132
 
 
133
        return connection
 
134
 
 
135
    def _encode_multipart_formdata(self, fields=[], files=[]):
 
136
        boundary = mimetools.choose_boundary()
 
137
 
 
138
        lines = []
 
139
        for (key, value) in fields:
 
140
            lines.append("--" + boundary)
 
141
            lines.append("Content-Disposition: form-data; name=\"%s\"" % key)
 
142
            lines.append("")
 
143
            lines.append(value)
 
144
 
 
145
        for (key, file) in files:
 
146
            if hasattr(file, "size"):
 
147
                length = file.size
 
148
            else:
 
149
                length = os.fstat(file.fileno())[stat.ST_SIZE]
 
150
 
 
151
            filename = posixpath.basename(file.name)
 
152
            if isinstance(filename, unicode):
 
153
                filename = filename.encode("UTF-8")
 
154
 
 
155
            lines.append("--" + boundary)
 
156
            lines.append("Content-Disposition: form-data; name=\"%s\"; filename=\"%s\""
 
157
                % (key, filename))
 
158
            lines.append("Content-Type: %s"
 
159
                % mimetypes.guess_type(filename)[0] or "application/octet-stream")
 
160
            lines.append("Content-Length: %s" % length)
 
161
            lines.append("")
 
162
 
 
163
            if hasattr(file, "seek"):
 
164
                file.seek(0)
 
165
            lines.append(file.read())
 
166
 
 
167
        lines.append("--" + boundary + "--")
 
168
        lines.append("")
 
169
 
 
170
        content_type = "multipart/form-data; boundary=%s" % boundary
 
171
        body = "\r\n".join(lines)
 
172
 
 
173
        return content_type, body
 
174
 
 
175
    def _encode_body(self, body=None):
 
176
        fields = []
 
177
        files = []
 
178
 
 
179
        content_type = "application/octet-stream"
 
180
        if body is not None and type(body) != str:
 
181
            if hasattr(body, "items"):
 
182
                body = body.items()
 
183
            else:
 
184
                try:
 
185
                    if len(body) and not isinstance(body[0], tuple):
 
186
                        raise TypeError
 
187
                except TypeError:
 
188
                    ty, va, tb = sys.exc_info()
 
189
                    raise TypeError, \
 
190
                        "Invalid non-string sequence or mapping", tb
 
191
 
 
192
            for key, value in body:
 
193
                if hasattr(value, "read"):
 
194
                    files.append((key, value))
 
195
                else:
 
196
                    fields.append((key, value))
 
197
 
 
198
            if files:
 
199
                content_type, body = self._encode_multipart_formdata(fields,
 
200
                    files)
 
201
            elif fields:
 
202
                content_type = "application/x-www-form-urlencoded"
 
203
                body = urllib.urlencode(fields)
 
204
            else:
 
205
                body = ""
 
206
 
 
207
        return content_type, body
 
208
 
 
209
    def exchange(self, body=None, headers={}, timeout=0):
 
210
        headers = dict(headers)
 
211
 
 
212
        if body is not None:
 
213
            method = "POST"
 
214
            (content_type, body) = self._encode_body(body)
 
215
            if "Content-Type" not in headers:
 
216
                headers["Content-Type"] = content_type
 
217
            if "Content-Length" not in headers:
 
218
                headers["Content-Length"] = len(body)
 
219
        else:
 
220
            method = "GET"
 
221
 
 
222
        response = None
 
223
        connection = self._get_connection(timeout)
 
224
 
 
225
        try:
 
226
            connection.request(method, self.url, body, headers)
 
227
        except IOError:
 
228
            logging.warning("Can't connect to %s", self.url)
 
229
        except socket.error:
 
230
            logging.error("Error connecting to %s", self.url)
 
231
        except socket.timeout:
 
232
            logging.warning("Timeout connecting to %s", self.url)
 
233
        else:
 
234
            try:
 
235
                response = connection.getresponse()
 
236
            except httplib.BadStatusLine:
 
237
                logging.warning("Service unavailable on %s", self.url)
 
238
            else:
 
239
                if response.status == httplib.FOUND:
 
240
                    # TODO prevent infinite redirect loop
 
241
                    self.url = self._get_location_header(response)
 
242
                    response = self.exchange(body, headers, timeout)
 
243
 
 
244
        return response