~certify-web-dev/twisted/certify-production

« back to all changes in this revision

Viewing changes to twisted/web/test/test_webclient.py

  • Committer: Marc Tardif
  • Date: 2010-05-20 19:56:06 UTC
  • Revision ID: marc.tardif@canonical.com-20100520195606-xdrf0ztlxhvwmmzb
Added twisted-web.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (c) 2001-2010 Twisted Matrix Laboratories.
 
2
# See LICENSE for details.
 
3
 
 
4
"""
 
5
Tests for L{twisted.web.client}.
 
6
"""
 
7
 
 
8
import os
 
9
from errno import ENOSPC
 
10
 
 
11
from urlparse import urlparse
 
12
 
 
13
from twisted.trial import unittest
 
14
from twisted.web import server, static, client, error, util, resource, http_headers
 
15
from twisted.internet import reactor, defer, interfaces
 
16
from twisted.python.filepath import FilePath
 
17
from twisted.python.log import msg
 
18
from twisted.protocols.policies import WrappingFactory
 
19
from twisted.test.proto_helpers import StringTransport
 
20
from twisted.test.proto_helpers import MemoryReactor
 
21
from twisted.internet.address import IPv4Address
 
22
from twisted.internet.task import Clock
 
23
from twisted.internet.error import ConnectionRefusedError
 
24
from twisted.internet.protocol import Protocol
 
25
from twisted.internet.defer import Deferred
 
26
from twisted.web.client import Request
 
27
from twisted.web.error import SchemeNotSupported
 
28
 
 
29
try:
 
30
    from twisted.internet import ssl
 
31
except:
 
32
    ssl = None
 
33
 
 
34
 
 
35
 
 
36
class ExtendedRedirect(resource.Resource):
 
37
    """
 
38
    Redirection resource.
 
39
 
 
40
    The HTTP status code is set according to the C{code} query parameter.
 
41
 
 
42
    @type lastMethod: C{str}
 
43
    @ivar lastMethod: Last handled HTTP request method
 
44
    """
 
45
    isLeaf = 1
 
46
    lastMethod = None
 
47
 
 
48
 
 
49
    def __init__(self, url):
 
50
        resource.Resource.__init__(self)
 
51
        self.url = url
 
52
 
 
53
 
 
54
    def render(self, request):
 
55
        if self.lastMethod:
 
56
            self.lastMethod = request.method
 
57
            return "OK Thnx!"
 
58
        else:
 
59
            self.lastMethod = request.method
 
60
            code = int(request.args['code'][0])
 
61
            return self.redirectTo(self.url, request, code)
 
62
 
 
63
 
 
64
    def getChild(self, name, request):
 
65
        return self
 
66
 
 
67
 
 
68
    def redirectTo(self, url, request, code):
 
69
        request.setResponseCode(code)
 
70
        request.setHeader("location", url)
 
71
        return "OK Bye!"
 
72
 
 
73
 
 
74
 
 
75
class ForeverTakingResource(resource.Resource):
 
76
    """
 
77
    L{ForeverTakingResource} is a resource which never finishes responding
 
78
    to requests.
 
79
    """
 
80
    def __init__(self, write=False):
 
81
        resource.Resource.__init__(self)
 
82
        self._write = write
 
83
 
 
84
    def render(self, request):
 
85
        if self._write:
 
86
            request.write('some bytes')
 
87
        return server.NOT_DONE_YET
 
88
 
 
89
 
 
90
class CookieMirrorResource(resource.Resource):
 
91
    def render(self, request):
 
92
        l = []
 
93
        for k,v in request.received_cookies.items():
 
94
            l.append((k, v))
 
95
        l.sort()
 
96
        return repr(l)
 
97
 
 
98
class RawCookieMirrorResource(resource.Resource):
 
99
    def render(self, request):
 
100
        return repr(request.getHeader('cookie'))
 
101
 
 
102
class ErrorResource(resource.Resource):
 
103
 
 
104
    def render(self, request):
 
105
        request.setResponseCode(401)
 
106
        if request.args.get("showlength"):
 
107
            request.setHeader("content-length", "0")
 
108
        return ""
 
109
 
 
110
class NoLengthResource(resource.Resource):
 
111
 
 
112
    def render(self, request):
 
113
        return "nolength"
 
114
 
 
115
 
 
116
 
 
117
class HostHeaderResource(resource.Resource):
 
118
    """
 
119
    A testing resource which renders itself as the value of the host header
 
120
    from the request.
 
121
    """
 
122
    def render(self, request):
 
123
        return request.received_headers['host']
 
124
 
 
125
 
 
126
 
 
127
class PayloadResource(resource.Resource):
 
128
    """
 
129
    A testing resource which renders itself as the contents of the request body
 
130
    as long as the request body is 100 bytes long, otherwise which renders
 
131
    itself as C{"ERROR"}.
 
132
    """
 
133
    def render(self, request):
 
134
        data = request.content.read()
 
135
        contentLength = request.received_headers['content-length']
 
136
        if len(data) != 100 or int(contentLength) != 100:
 
137
            return "ERROR"
 
138
        return data
 
139
 
 
140
 
 
141
 
 
142
class BrokenDownloadResource(resource.Resource):
 
143
 
 
144
    def render(self, request):
 
145
        # only sends 3 bytes even though it claims to send 5
 
146
        request.setHeader("content-length", "5")
 
147
        request.write('abc')
 
148
        return ''
 
149
 
 
150
class CountingRedirect(util.Redirect):
 
151
    """
 
152
    A L{util.Redirect} resource that keeps track of the number of times the
 
153
    resource has been accessed.
 
154
    """
 
155
    def __init__(self, *a, **kw):
 
156
        util.Redirect.__init__(self, *a, **kw)
 
157
        self.count = 0
 
158
 
 
159
    def render(self, request):
 
160
        self.count += 1
 
161
        return util.Redirect.render(self, request)
 
162
 
 
163
 
 
164
 
 
165
class ParseUrlTestCase(unittest.TestCase):
 
166
    """
 
167
    Test URL parsing facility and defaults values.
 
168
    """
 
169
 
 
170
    def test_parse(self):
 
171
        """
 
172
        L{client._parse} correctly parses a URL into its various components.
 
173
        """
 
174
        # The default port for HTTP is 80.
 
175
        self.assertEqual(
 
176
            client._parse('http://127.0.0.1/'),
 
177
            ('http', '127.0.0.1', 80, '/'))
 
178
 
 
179
        # The default port for HTTPS is 443.
 
180
        self.assertEqual(
 
181
            client._parse('https://127.0.0.1/'),
 
182
            ('https', '127.0.0.1', 443, '/'))
 
183
 
 
184
        # Specifying a port.
 
185
        self.assertEqual(
 
186
            client._parse('http://spam:12345/'),
 
187
            ('http', 'spam', 12345, '/'))
 
188
 
 
189
        # Weird (but commonly accepted) structure uses default port.
 
190
        self.assertEqual(
 
191
            client._parse('http://spam:/'),
 
192
            ('http', 'spam', 80, '/'))
 
193
 
 
194
        # Spaces in the hostname are trimmed, the default path is /.
 
195
        self.assertEqual(
 
196
            client._parse('http://foo '),
 
197
            ('http', 'foo', 80, '/'))
 
198
 
 
199
 
 
200
    def test_externalUnicodeInterference(self):
 
201
        """
 
202
        L{client._parse} should return C{str} for the scheme, host, and path
 
203
        elements of its return tuple, even when passed an URL which has
 
204
        previously been passed to L{urlparse} as a C{unicode} string.
 
205
        """
 
206
        badInput = u'http://example.com/path'
 
207
        goodInput = badInput.encode('ascii')
 
208
        urlparse(badInput)
 
209
        scheme, host, port, path = client._parse(goodInput)
 
210
        self.assertTrue(isinstance(scheme, str))
 
211
        self.assertTrue(isinstance(host, str))
 
212
        self.assertTrue(isinstance(path, str))
 
213
 
 
214
 
 
215
 
 
216
class HTTPPageGetterTests(unittest.TestCase):
 
217
    """
 
218
    Tests for L{HTTPPagerGetter}, the HTTP client protocol implementation
 
219
    used to implement L{getPage}.
 
220
    """
 
221
    def test_earlyHeaders(self):
 
222
        """
 
223
        When a connection is made, L{HTTPPagerGetter} sends the headers from
 
224
        its factory's C{headers} dict.  If I{Host} or I{Content-Length} is
 
225
        present in this dict, the values are not sent, since they are sent with
 
226
        special values before the C{headers} dict is processed.  If
 
227
        I{User-Agent} is present in the dict, it overrides the value of the
 
228
        C{agent} attribute of the factory.  If I{Cookie} is present in the
 
229
        dict, its value is added to the values from the factory's C{cookies}
 
230
        attribute.
 
231
        """
 
232
        factory = client.HTTPClientFactory(
 
233
            'http://foo/bar',
 
234
            agent="foobar",
 
235
            cookies={'baz': 'quux'},
 
236
            postdata="some data",
 
237
            headers={
 
238
                'Host': 'example.net',
 
239
                'User-Agent': 'fooble',
 
240
                'Cookie': 'blah blah',
 
241
                'Content-Length': '12981',
 
242
                'Useful': 'value'})
 
243
        transport = StringTransport()
 
244
        protocol = client.HTTPPageGetter()
 
245
        protocol.factory = factory
 
246
        protocol.makeConnection(transport)
 
247
        self.assertEqual(
 
248
            transport.value(),
 
249
            "GET /bar HTTP/1.0\r\n"
 
250
            "Host: example.net\r\n"
 
251
            "User-Agent: foobar\r\n"
 
252
            "Content-Length: 9\r\n"
 
253
            "Useful: value\r\n"
 
254
            "connection: close\r\n"
 
255
            "Cookie: blah blah; baz=quux\r\n"
 
256
            "\r\n"
 
257
            "some data")
 
258
 
 
259
 
 
260
 
 
261
class WebClientTestCase(unittest.TestCase):
 
262
    def _listen(self, site):
 
263
        return reactor.listenTCP(0, site, interface="127.0.0.1")
 
264
 
 
265
    def setUp(self):
 
266
        self.cleanupServerConnections = 0
 
267
        name = self.mktemp()
 
268
        os.mkdir(name)
 
269
        FilePath(name).child("file").setContent("0123456789")
 
270
        r = static.File(name)
 
271
        r.putChild("redirect", util.Redirect("/file"))
 
272
        self.infiniteRedirectResource = CountingRedirect("/infiniteRedirect")
 
273
        r.putChild("infiniteRedirect", self.infiniteRedirectResource)
 
274
        r.putChild("wait", ForeverTakingResource())
 
275
        r.putChild("write-then-wait", ForeverTakingResource(write=True))
 
276
        r.putChild("error", ErrorResource())
 
277
        r.putChild("nolength", NoLengthResource())
 
278
        r.putChild("host", HostHeaderResource())
 
279
        r.putChild("payload", PayloadResource())
 
280
        r.putChild("broken", BrokenDownloadResource())
 
281
        r.putChild("cookiemirror", CookieMirrorResource())
 
282
 
 
283
        miscasedHead = static.Data("miscased-head GET response content", "major/minor")
 
284
        miscasedHead.render_Head = lambda request: "miscased-head content"
 
285
        r.putChild("miscased-head", miscasedHead)
 
286
 
 
287
        self.extendedRedirect = ExtendedRedirect('/extendedRedirect')
 
288
        r.putChild("extendedRedirect", self.extendedRedirect)
 
289
        self.site = server.Site(r, timeout=None)
 
290
        self.wrapper = WrappingFactory(self.site)
 
291
        self.port = self._listen(self.wrapper)
 
292
        self.portno = self.port.getHost().port
 
293
 
 
294
    def tearDown(self):
 
295
        # If the test indicated it might leave some server-side connections
 
296
        # around, clean them up.
 
297
        connections = self.wrapper.protocols.keys()
 
298
        # If there are fewer server-side connections than requested,
 
299
        # that's okay.  Some might have noticed that the client closed
 
300
        # the connection and cleaned up after themselves.
 
301
        for n in range(min(len(connections), self.cleanupServerConnections)):
 
302
            proto = connections.pop()
 
303
            msg("Closing %r" % (proto,))
 
304
            proto.transport.loseConnection()
 
305
        if connections:
 
306
            msg("Some left-over connections; this test is probably buggy.")
 
307
        return self.port.stopListening()
 
308
 
 
309
    def getURL(self, path):
 
310
        return "http://127.0.0.1:%d/%s" % (self.portno, path)
 
311
 
 
312
    def testPayload(self):
 
313
        s = "0123456789" * 10
 
314
        return client.getPage(self.getURL("payload"), postdata=s
 
315
            ).addCallback(self.assertEquals, s
 
316
            )
 
317
 
 
318
 
 
319
    def test_getPageBrokenDownload(self):
 
320
        """
 
321
        If the connection is closed before the number of bytes indicated by
 
322
        I{Content-Length} have been received, the L{Deferred} returned by
 
323
        L{getPage} fails with L{PartialDownloadError}.
 
324
        """
 
325
        d = client.getPage(self.getURL("broken"))
 
326
        d = self.assertFailure(d, client.PartialDownloadError)
 
327
        d.addCallback(lambda exc: self.assertEquals(exc.response, "abc"))
 
328
        return d
 
329
 
 
330
 
 
331
    def test_downloadPageBrokenDownload(self):
 
332
        """
 
333
        If the connection is closed before the number of bytes indicated by
 
334
        I{Content-Length} have been received, the L{Deferred} returned by
 
335
        L{downloadPage} fails with L{PartialDownloadError}.
 
336
        """
 
337
        # test what happens when download gets disconnected in the middle
 
338
        path = FilePath(self.mktemp())
 
339
        d = client.downloadPage(self.getURL("broken"), path.path)
 
340
        d = self.assertFailure(d, client.PartialDownloadError)
 
341
 
 
342
        def checkResponse(response):
 
343
            """
 
344
            The HTTP status code from the server is propagated through the
 
345
            C{PartialDownloadError}.
 
346
            """
 
347
            self.assertEquals(response.status, "200")
 
348
            self.assertEquals(response.message, "OK")
 
349
            return response
 
350
        d.addCallback(checkResponse)
 
351
 
 
352
        def cbFailed(ignored):
 
353
            self.assertEquals(path.getContent(), "abc")
 
354
        d.addCallback(cbFailed)
 
355
        return d
 
356
 
 
357
 
 
358
    def test_downloadPageLogsFileCloseError(self):
 
359
        """
 
360
        If there is an exception closing the file being written to after the
 
361
        connection is prematurely closed, that exception is logged.
 
362
        """
 
363
        class BrokenFile:
 
364
            def write(self, bytes):
 
365
                pass
 
366
 
 
367
            def close(self):
 
368
                raise IOError(ENOSPC, "No file left on device")
 
369
 
 
370
        d = client.downloadPage(self.getURL("broken"), BrokenFile())
 
371
        d = self.assertFailure(d, client.PartialDownloadError)
 
372
        def cbFailed(ignored):
 
373
            self.assertEquals(len(self.flushLoggedErrors(IOError)), 1)
 
374
        d.addCallback(cbFailed)
 
375
        return d
 
376
 
 
377
 
 
378
    def testHostHeader(self):
 
379
        # if we pass Host header explicitly, it should be used, otherwise
 
380
        # it should extract from url
 
381
        return defer.gatherResults([
 
382
            client.getPage(self.getURL("host")).addCallback(self.assertEquals, "127.0.0.1"),
 
383
            client.getPage(self.getURL("host"), headers={"Host": "www.example.com"}).addCallback(self.assertEquals, "www.example.com")])
 
384
 
 
385
 
 
386
    def test_getPage(self):
 
387
        """
 
388
        L{client.getPage} returns a L{Deferred} which is called back with
 
389
        the body of the response if the default method B{GET} is used.
 
390
        """
 
391
        d = client.getPage(self.getURL("file"))
 
392
        d.addCallback(self.assertEquals, "0123456789")
 
393
        return d
 
394
 
 
395
 
 
396
    def test_getPageHEAD(self):
 
397
        """
 
398
        L{client.getPage} returns a L{Deferred} which is called back with
 
399
        the empty string if the method is I{HEAD} and there is a successful
 
400
        response code.
 
401
        """
 
402
        d = client.getPage(self.getURL("file"), method="HEAD")
 
403
        d.addCallback(self.assertEquals, "")
 
404
        return d
 
405
 
 
406
 
 
407
 
 
408
    def test_getPageNotQuiteHEAD(self):
 
409
        """
 
410
        If the request method is a different casing of I{HEAD} (ie, not all
 
411
        capitalized) then it is not a I{HEAD} request and the response body
 
412
        is returned.
 
413
        """
 
414
        d = client.getPage(self.getURL("miscased-head"), method='Head')
 
415
        d.addCallback(self.assertEquals, "miscased-head content")
 
416
        return d
 
417
 
 
418
 
 
419
    def test_timeoutNotTriggering(self):
 
420
        """
 
421
        When a non-zero timeout is passed to L{getPage} and the page is
 
422
        retrieved before the timeout period elapses, the L{Deferred} is
 
423
        called back with the contents of the page.
 
424
        """
 
425
        d = client.getPage(self.getURL("host"), timeout=100)
 
426
        d.addCallback(self.assertEquals, "127.0.0.1")
 
427
        return d
 
428
 
 
429
 
 
430
    def test_timeoutTriggering(self):
 
431
        """
 
432
        When a non-zero timeout is passed to L{getPage} and that many
 
433
        seconds elapse before the server responds to the request. the
 
434
        L{Deferred} is errbacked with a L{error.TimeoutError}.
 
435
        """
 
436
        # This will probably leave some connections around.
 
437
        self.cleanupServerConnections = 1
 
438
        return self.assertFailure(
 
439
            client.getPage(self.getURL("wait"), timeout=0.000001),
 
440
            defer.TimeoutError)
 
441
 
 
442
 
 
443
    def testDownloadPage(self):
 
444
        downloads = []
 
445
        downloadData = [("file", self.mktemp(), "0123456789"),
 
446
                        ("nolength", self.mktemp(), "nolength")]
 
447
 
 
448
        for (url, name, data) in downloadData:
 
449
            d = client.downloadPage(self.getURL(url), name)
 
450
            d.addCallback(self._cbDownloadPageTest, data, name)
 
451
            downloads.append(d)
 
452
        return defer.gatherResults(downloads)
 
453
 
 
454
    def _cbDownloadPageTest(self, ignored, data, name):
 
455
        bytes = file(name, "rb").read()
 
456
        self.assertEquals(bytes, data)
 
457
 
 
458
    def testDownloadPageError1(self):
 
459
        class errorfile:
 
460
            def write(self, data):
 
461
                raise IOError, "badness happened during write"
 
462
            def close(self):
 
463
                pass
 
464
        ef = errorfile()
 
465
        return self.assertFailure(
 
466
            client.downloadPage(self.getURL("file"), ef),
 
467
            IOError)
 
468
 
 
469
    def testDownloadPageError2(self):
 
470
        class errorfile:
 
471
            def write(self, data):
 
472
                pass
 
473
            def close(self):
 
474
                raise IOError, "badness happened during close"
 
475
        ef = errorfile()
 
476
        return self.assertFailure(
 
477
            client.downloadPage(self.getURL("file"), ef),
 
478
            IOError)
 
479
 
 
480
    def testDownloadPageError3(self):
 
481
        # make sure failures in open() are caught too. This is tricky.
 
482
        # Might only work on posix.
 
483
        tmpfile = open("unwritable", "wb")
 
484
        tmpfile.close()
 
485
        os.chmod("unwritable", 0) # make it unwritable (to us)
 
486
        d = self.assertFailure(
 
487
            client.downloadPage(self.getURL("file"), "unwritable"),
 
488
            IOError)
 
489
        d.addBoth(self._cleanupDownloadPageError3)
 
490
        return d
 
491
 
 
492
    def _cleanupDownloadPageError3(self, ignored):
 
493
        os.chmod("unwritable", 0700)
 
494
        os.unlink("unwritable")
 
495
        return ignored
 
496
 
 
497
    def _downloadTest(self, method):
 
498
        dl = []
 
499
        for (url, code) in [("nosuchfile", "404"), ("error", "401"),
 
500
                            ("error?showlength=1", "401")]:
 
501
            d = method(url)
 
502
            d = self.assertFailure(d, error.Error)
 
503
            d.addCallback(lambda exc, code=code: self.assertEquals(exc.args[0], code))
 
504
            dl.append(d)
 
505
        return defer.DeferredList(dl, fireOnOneErrback=True)
 
506
 
 
507
    def testServerError(self):
 
508
        return self._downloadTest(lambda url: client.getPage(self.getURL(url)))
 
509
 
 
510
    def testDownloadServerError(self):
 
511
        return self._downloadTest(lambda url: client.downloadPage(self.getURL(url), url.split('?')[0]))
 
512
 
 
513
    def testFactoryInfo(self):
 
514
        url = self.getURL('file')
 
515
        scheme, host, port, path = client._parse(url)
 
516
        factory = client.HTTPClientFactory(url)
 
517
        reactor.connectTCP(host, port, factory)
 
518
        return factory.deferred.addCallback(self._cbFactoryInfo, factory)
 
519
 
 
520
    def _cbFactoryInfo(self, ignoredResult, factory):
 
521
        self.assertEquals(factory.status, '200')
 
522
        self.assert_(factory.version.startswith('HTTP/'))
 
523
        self.assertEquals(factory.message, 'OK')
 
524
        self.assertEquals(factory.response_headers['content-length'][0], '10')
 
525
 
 
526
 
 
527
    def testRedirect(self):
 
528
        return client.getPage(self.getURL("redirect")).addCallback(self._cbRedirect)
 
529
 
 
530
    def _cbRedirect(self, pageData):
 
531
        self.assertEquals(pageData, "0123456789")
 
532
        d = self.assertFailure(
 
533
            client.getPage(self.getURL("redirect"), followRedirect=0),
 
534
            error.PageRedirect)
 
535
        d.addCallback(self._cbCheckLocation)
 
536
        return d
 
537
 
 
538
    def _cbCheckLocation(self, exc):
 
539
        self.assertEquals(exc.location, "/file")
 
540
 
 
541
 
 
542
    def test_infiniteRedirection(self):
 
543
        """
 
544
        When more than C{redirectLimit} HTTP redirects are encountered, the
 
545
        page request fails with L{InfiniteRedirection}.
 
546
        """
 
547
        def checkRedirectCount(*a):
 
548
            self.assertEquals(f._redirectCount, 13)
 
549
            self.assertEquals(self.infiniteRedirectResource.count, 13)
 
550
 
 
551
        f = client._makeGetterFactory(
 
552
            self.getURL('infiniteRedirect'),
 
553
            client.HTTPClientFactory,
 
554
            redirectLimit=13)
 
555
        d = self.assertFailure(f.deferred, error.InfiniteRedirection)
 
556
        d.addCallback(checkRedirectCount)
 
557
        return d
 
558
 
 
559
 
 
560
    def test_isolatedFollowRedirect(self):
 
561
        """
 
562
        C{client.HTTPPagerGetter} instances each obey the C{followRedirect}
 
563
        value passed to the L{client.getPage} call which created them.
 
564
        """
 
565
        d1 = client.getPage(self.getURL('redirect'), followRedirect=True)
 
566
        d2 = client.getPage(self.getURL('redirect'), followRedirect=False)
 
567
 
 
568
        d = self.assertFailure(d2, error.PageRedirect
 
569
            ).addCallback(lambda dummy: d1)
 
570
        return d
 
571
 
 
572
 
 
573
    def test_afterFoundGet(self):
 
574
        """
 
575
        Enabling unsafe redirection behaviour overwrites the method of
 
576
        redirected C{POST} requests with C{GET}.
 
577
        """
 
578
        url = self.getURL('extendedRedirect?code=302')
 
579
        f = client.HTTPClientFactory(url, followRedirect=True, method="POST")
 
580
        self.assertFalse(
 
581
            f.afterFoundGet,
 
582
            "By default, afterFoundGet must be disabled")
 
583
 
 
584
        def gotPage(page):
 
585
            self.assertEquals(
 
586
                self.extendedRedirect.lastMethod,
 
587
                "GET",
 
588
                "With afterFoundGet, the HTTP method must change to GET")
 
589
 
 
590
        d = client.getPage(
 
591
            url, followRedirect=True, afterFoundGet=True, method="POST")
 
592
        d.addCallback(gotPage)
 
593
        return d
 
594
 
 
595
 
 
596
    def testPartial(self):
 
597
        name = self.mktemp()
 
598
        f = open(name, "wb")
 
599
        f.write("abcd")
 
600
        f.close()
 
601
 
 
602
        partialDownload = [(True, "abcd456789"),
 
603
                           (True, "abcd456789"),
 
604
                           (False, "0123456789")]
 
605
 
 
606
        d = defer.succeed(None)
 
607
        for (partial, expectedData) in partialDownload:
 
608
            d.addCallback(self._cbRunPartial, name, partial)
 
609
            d.addCallback(self._cbPartialTest, expectedData, name)
 
610
 
 
611
        return d
 
612
 
 
613
    testPartial.skip = "Cannot test until webserver can serve partial data properly"
 
614
 
 
615
    def _cbRunPartial(self, ignored, name, partial):
 
616
        return client.downloadPage(self.getURL("file"), name, supportPartial=partial)
 
617
 
 
618
    def _cbPartialTest(self, ignored, expectedData, filename):
 
619
        bytes = file(filename, "rb").read()
 
620
        self.assertEquals(bytes, expectedData)
 
621
 
 
622
 
 
623
    def test_downloadTimeout(self):
 
624
        """
 
625
        If the timeout indicated by the C{timeout} parameter to
 
626
        L{client.HTTPDownloader.__init__} elapses without the complete response
 
627
        being received, the L{defer.Deferred} returned by
 
628
        L{client.downloadPage} fires with a L{Failure} wrapping a
 
629
        L{defer.TimeoutError}.
 
630
        """
 
631
        self.cleanupServerConnections = 2
 
632
        # Verify the behavior if no bytes are ever written.
 
633
        first = client.downloadPage(
 
634
            self.getURL("wait"),
 
635
            self.mktemp(), timeout=0.01)
 
636
 
 
637
        # Verify the behavior if some bytes are written but then the request
 
638
        # never completes.
 
639
        second = client.downloadPage(
 
640
            self.getURL("write-then-wait"),
 
641
            self.mktemp(), timeout=0.01)
 
642
 
 
643
        return defer.gatherResults([
 
644
            self.assertFailure(first, defer.TimeoutError),
 
645
            self.assertFailure(second, defer.TimeoutError)])
 
646
 
 
647
 
 
648
    def test_downloadHeaders(self):
 
649
        """
 
650
        After L{client.HTTPDownloader.deferred} fires, the
 
651
        L{client.HTTPDownloader} instance's C{status} and C{response_headers}
 
652
        attributes are populated with the values from the response.
 
653
        """
 
654
        def checkHeaders(factory):
 
655
            self.assertEquals(factory.status, '200')
 
656
            self.assertEquals(factory.response_headers['content-type'][0], 'text/html')
 
657
            self.assertEquals(factory.response_headers['content-length'][0], '10')
 
658
            os.unlink(factory.fileName)
 
659
        factory = client._makeGetterFactory(
 
660
            self.getURL('file'),
 
661
            client.HTTPDownloader,
 
662
            fileOrName=self.mktemp())
 
663
        return factory.deferred.addCallback(lambda _: checkHeaders(factory))
 
664
 
 
665
 
 
666
    def test_downloadCookies(self):
 
667
        """
 
668
        The C{cookies} dict passed to the L{client.HTTPDownloader}
 
669
        initializer is used to populate the I{Cookie} header included in the
 
670
        request sent to the server.
 
671
        """
 
672
        output = self.mktemp()
 
673
        factory = client._makeGetterFactory(
 
674
            self.getURL('cookiemirror'),
 
675
            client.HTTPDownloader,
 
676
            fileOrName=output,
 
677
            cookies={'foo': 'bar'})
 
678
        def cbFinished(ignored):
 
679
            self.assertEqual(
 
680
                FilePath(output).getContent(),
 
681
                "[('foo', 'bar')]")
 
682
        factory.deferred.addCallback(cbFinished)
 
683
        return factory.deferred
 
684
 
 
685
 
 
686
    def test_downloadRedirectLimit(self):
 
687
        """
 
688
        When more than C{redirectLimit} HTTP redirects are encountered, the
 
689
        page request fails with L{InfiniteRedirection}.
 
690
        """
 
691
        def checkRedirectCount(*a):
 
692
            self.assertEquals(f._redirectCount, 7)
 
693
            self.assertEquals(self.infiniteRedirectResource.count, 7)
 
694
 
 
695
        f = client._makeGetterFactory(
 
696
            self.getURL('infiniteRedirect'),
 
697
            client.HTTPDownloader,
 
698
            fileOrName=self.mktemp(),
 
699
            redirectLimit=7)
 
700
        d = self.assertFailure(f.deferred, error.InfiniteRedirection)
 
701
        d.addCallback(checkRedirectCount)
 
702
        return d
 
703
 
 
704
 
 
705
 
 
706
class WebClientSSLTestCase(WebClientTestCase):
 
707
    def _listen(self, site):
 
708
        from twisted import test
 
709
        return reactor.listenSSL(0, site,
 
710
                                 contextFactory=ssl.DefaultOpenSSLContextFactory(
 
711
            FilePath(test.__file__).sibling('server.pem').path,
 
712
            FilePath(test.__file__).sibling('server.pem').path,
 
713
            ),
 
714
                                 interface="127.0.0.1")
 
715
 
 
716
    def getURL(self, path):
 
717
        return "https://127.0.0.1:%d/%s" % (self.portno, path)
 
718
 
 
719
    def testFactoryInfo(self):
 
720
        url = self.getURL('file')
 
721
        scheme, host, port, path = client._parse(url)
 
722
        factory = client.HTTPClientFactory(url)
 
723
        reactor.connectSSL(host, port, factory, ssl.ClientContextFactory())
 
724
        # The base class defines _cbFactoryInfo correctly for this
 
725
        return factory.deferred.addCallback(self._cbFactoryInfo, factory)
 
726
 
 
727
class WebClientRedirectBetweenSSLandPlainText(unittest.TestCase):
 
728
    def getHTTPS(self, path):
 
729
        return "https://127.0.0.1:%d/%s" % (self.tlsPortno, path)
 
730
 
 
731
    def getHTTP(self, path):
 
732
        return "http://127.0.0.1:%d/%s" % (self.plainPortno, path)
 
733
 
 
734
    def setUp(self):
 
735
        plainRoot = static.Data('not me', 'text/plain')
 
736
        tlsRoot = static.Data('me neither', 'text/plain')
 
737
 
 
738
        plainSite = server.Site(plainRoot, timeout=None)
 
739
        tlsSite = server.Site(tlsRoot, timeout=None)
 
740
 
 
741
        from twisted import test
 
742
        self.tlsPort = reactor.listenSSL(0, tlsSite,
 
743
                                         contextFactory=ssl.DefaultOpenSSLContextFactory(
 
744
            FilePath(test.__file__).sibling('server.pem').path,
 
745
            FilePath(test.__file__).sibling('server.pem').path,
 
746
            ),
 
747
                                         interface="127.0.0.1")
 
748
        self.plainPort = reactor.listenTCP(0, plainSite, interface="127.0.0.1")
 
749
 
 
750
        self.plainPortno = self.plainPort.getHost().port
 
751
        self.tlsPortno = self.tlsPort.getHost().port
 
752
 
 
753
        plainRoot.putChild('one', util.Redirect(self.getHTTPS('two')))
 
754
        tlsRoot.putChild('two', util.Redirect(self.getHTTP('three')))
 
755
        plainRoot.putChild('three', util.Redirect(self.getHTTPS('four')))
 
756
        tlsRoot.putChild('four', static.Data('FOUND IT!', 'text/plain'))
 
757
 
 
758
    def tearDown(self):
 
759
        ds = map(defer.maybeDeferred,
 
760
                 [self.plainPort.stopListening, self.tlsPort.stopListening])
 
761
        return defer.gatherResults(ds)
 
762
 
 
763
    def testHoppingAround(self):
 
764
        return client.getPage(self.getHTTP("one")
 
765
            ).addCallback(self.assertEquals, "FOUND IT!"
 
766
            )
 
767
 
 
768
class FakeTransport:
 
769
    disconnecting = False
 
770
    def __init__(self):
 
771
        self.data = []
 
772
    def write(self, stuff):
 
773
        self.data.append(stuff)
 
774
 
 
775
class CookieTestCase(unittest.TestCase):
 
776
    def _listen(self, site):
 
777
        return reactor.listenTCP(0, site, interface="127.0.0.1")
 
778
 
 
779
    def setUp(self):
 
780
        root = static.Data('El toro!', 'text/plain')
 
781
        root.putChild("cookiemirror", CookieMirrorResource())
 
782
        root.putChild("rawcookiemirror", RawCookieMirrorResource())
 
783
        site = server.Site(root, timeout=None)
 
784
        self.port = self._listen(site)
 
785
        self.portno = self.port.getHost().port
 
786
 
 
787
    def tearDown(self):
 
788
        return self.port.stopListening()
 
789
 
 
790
    def getHTTP(self, path):
 
791
        return "http://127.0.0.1:%d/%s" % (self.portno, path)
 
792
 
 
793
    def testNoCookies(self):
 
794
        return client.getPage(self.getHTTP("cookiemirror")
 
795
            ).addCallback(self.assertEquals, "[]"
 
796
            )
 
797
 
 
798
    def testSomeCookies(self):
 
799
        cookies = {'foo': 'bar', 'baz': 'quux'}
 
800
        return client.getPage(self.getHTTP("cookiemirror"), cookies=cookies
 
801
            ).addCallback(self.assertEquals, "[('baz', 'quux'), ('foo', 'bar')]"
 
802
            )
 
803
 
 
804
    def testRawNoCookies(self):
 
805
        return client.getPage(self.getHTTP("rawcookiemirror")
 
806
            ).addCallback(self.assertEquals, "None"
 
807
            )
 
808
 
 
809
    def testRawSomeCookies(self):
 
810
        cookies = {'foo': 'bar', 'baz': 'quux'}
 
811
        return client.getPage(self.getHTTP("rawcookiemirror"), cookies=cookies
 
812
            ).addCallback(self.assertEquals, "'foo=bar; baz=quux'"
 
813
            )
 
814
 
 
815
    def testCookieHeaderParsing(self):
 
816
        factory = client.HTTPClientFactory('http://foo.example.com/')
 
817
        proto = factory.buildProtocol('127.42.42.42')
 
818
        proto.transport = FakeTransport()
 
819
        proto.connectionMade()
 
820
        for line in [
 
821
            '200 Ok',
 
822
            'Squash: yes',
 
823
            'Hands: stolen',
 
824
            'Set-Cookie: CUSTOMER=WILE_E_COYOTE; path=/; expires=Wednesday, 09-Nov-99 23:12:40 GMT',
 
825
            'Set-Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001; path=/',
 
826
            'Set-Cookie: SHIPPING=FEDEX; path=/foo',
 
827
            '',
 
828
            'body',
 
829
            'more body',
 
830
            ]:
 
831
            proto.dataReceived(line + '\r\n')
 
832
        self.assertEquals(proto.transport.data,
 
833
                          ['GET / HTTP/1.0\r\n',
 
834
                           'Host: foo.example.com\r\n',
 
835
                           'User-Agent: Twisted PageGetter\r\n',
 
836
                           '\r\n'])
 
837
        self.assertEquals(factory.cookies,
 
838
                          {
 
839
            'CUSTOMER': 'WILE_E_COYOTE',
 
840
            'PART_NUMBER': 'ROCKET_LAUNCHER_0001',
 
841
            'SHIPPING': 'FEDEX',
 
842
            })
 
843
 
 
844
 
 
845
 
 
846
class StubHTTPProtocol(Protocol):
 
847
    """
 
848
    A protocol like L{HTTP11ClientProtocol} but which does not actually know
 
849
    HTTP/1.1 and only collects requests in a list.
 
850
 
 
851
    @ivar requests: A C{list} of two-tuples.  Each time a request is made, a
 
852
        tuple consisting of the request and the L{Deferred} returned from the
 
853
        request method is appended to this list.
 
854
    """
 
855
    def __init__(self):
 
856
        self.requests = []
 
857
 
 
858
 
 
859
    def request(self, request):
 
860
        """
 
861
        Capture the given request for later inspection.
 
862
 
 
863
        @return: A L{Deferred} which this code will never fire.
 
864
        """
 
865
        result = Deferred()
 
866
        self.requests.append((request, result))
 
867
        return result
 
868
 
 
869
 
 
870
 
 
871
class AgentTests(unittest.TestCase):
 
872
    """
 
873
    Tests for the new HTTP client API provided by L{Agent}.
 
874
    """
 
875
    def setUp(self):
 
876
        """
 
877
        Create an L{Agent} wrapped around a fake reactor.
 
878
        """
 
879
        class Reactor(MemoryReactor, Clock):
 
880
            def __init__(self):
 
881
                MemoryReactor.__init__(self)
 
882
                Clock.__init__(self)
 
883
 
 
884
        self.reactor = Reactor()
 
885
        self.agent = client.Agent(self.reactor)
 
886
 
 
887
 
 
888
    def completeConnection(self):
 
889
        """
 
890
        Do whitebox stuff to finish any outstanding connection attempts the
 
891
        agent may have initiated.
 
892
 
 
893
        This spins the fake reactor clock just enough to get L{ClientCreator},
 
894
        which agent is implemented in terms of, to fire its Deferreds.
 
895
        """
 
896
        self.reactor.advance(0)
 
897
 
 
898
 
 
899
    def _verifyAndCompleteConnectionTo(self, host, port):
 
900
        """
 
901
        Assert that the destination of the oldest unverified TCP connection
 
902
        attempt is the given host and port.  Then pop it, create a protocol,
 
903
        connect it to a L{StringTransport}, and return the protocol.
 
904
        """
 
905
        # Grab the connection attempt, make sure it goes to the right place,
 
906
        # and cause it to succeed.
 
907
        host, port, factory = self.reactor.tcpClients.pop()[:3]
 
908
        self.assertEquals(host, host)
 
909
        self.assertEquals(port, port)
 
910
 
 
911
        protocol = factory.buildProtocol(IPv4Address('TCP', '10.0.0.3', 1234))
 
912
        transport = StringTransport()
 
913
        protocol.makeConnection(transport)
 
914
        self.completeConnection()
 
915
        return protocol
 
916
 
 
917
 
 
918
    def test_unsupportedScheme(self):
 
919
        """
 
920
        L{Agent.request} returns a L{Deferred} which fails with
 
921
        L{SchemeNotSupported} if the scheme of the URI passed to it is not
 
922
        C{'http'}.
 
923
        """
 
924
        return self.assertFailure(
 
925
            self.agent.request('GET', 'mailto:alice@example.com'),
 
926
            SchemeNotSupported)
 
927
 
 
928
 
 
929
    def test_connectionFailed(self):
 
930
        """
 
931
        The L{Deferred} returned by L{Agent.request} fires with a L{Failure} if
 
932
        the TCP connection attempt fails.
 
933
        """
 
934
        result = self.agent.request('GET', 'http://foo/')
 
935
 
 
936
        # Cause the connection to be refused
 
937
        host, port, factory = self.reactor.tcpClients.pop()[:3]
 
938
        factory.clientConnectionFailed(None, ConnectionRefusedError())
 
939
        self.completeConnection()
 
940
 
 
941
        return self.assertFailure(result, ConnectionRefusedError)
 
942
 
 
943
 
 
944
    def test_request(self):
 
945
        """
 
946
        L{Agent.request} establishes a new connection to the host indicated by
 
947
        the host part of the URI passed to it and issues a request using the
 
948
        method, the path portion of the URI, the headers, and the body producer
 
949
        passed to it.  It returns a L{Deferred} which fires with a L{Response}
 
950
        from the server.
 
951
        """
 
952
        self.agent._protocol = StubHTTPProtocol
 
953
 
 
954
        headers = http_headers.Headers({'foo': ['bar']})
 
955
        # Just going to check the body for identity, so it doesn't need to be
 
956
        # real.
 
957
        body = object()
 
958
        self.agent.request(
 
959
            'GET', 'http://example.com:1234/foo?bar', headers, body)
 
960
 
 
961
        protocol = self._verifyAndCompleteConnectionTo('example.com', 1234)
 
962
 
 
963
        # The request should be issued.
 
964
        self.assertEquals(len(protocol.requests), 1)
 
965
        req, res = protocol.requests.pop()
 
966
        self.assertTrue(isinstance(req, Request))
 
967
        self.assertEquals(req.method, 'GET')
 
968
        self.assertEquals(req.uri, '/foo?bar')
 
969
        self.assertEquals(
 
970
            req.headers,
 
971
            http_headers.Headers({'foo': ['bar'],
 
972
                                  'host': ['example.com:1234']}))
 
973
        self.assertIdentical(req.bodyProducer, body)
 
974
 
 
975
 
 
976
    def test_hostProvided(self):
 
977
        """
 
978
        If C{None} is passed to L{Agent.request} for the C{headers}
 
979
        parameter, a L{Headers} instance is created for the request and a
 
980
        I{Host} header added to it.
 
981
        """
 
982
        self.agent._protocol = StubHTTPProtocol
 
983
 
 
984
        self.agent.request('GET', 'http://example.com/foo')
 
985
 
 
986
        protocol = self._verifyAndCompleteConnectionTo('example.com', 80)
 
987
 
 
988
        # The request should have been issued with a host header based on
 
989
        # the request URL.
 
990
        self.assertEquals(len(protocol.requests), 1)
 
991
        req, res = protocol.requests.pop()
 
992
        self.assertEquals(req.headers.getRawHeaders('host'), ['example.com'])
 
993
 
 
994
 
 
995
    def test_hostOverride(self):
 
996
        """
 
997
        If the headers passed to L{Agent.request} includes a value for the
 
998
        I{Host} header, that value takes precedence over the one which would
 
999
        otherwise be automatically provided.
 
1000
        """
 
1001
        self.agent._protocol = StubHTTPProtocol
 
1002
 
 
1003
        headers = http_headers.Headers({'foo': ['bar'], 'host': ['quux']})
 
1004
        body = object()
 
1005
        self.agent.request(
 
1006
            'GET', 'http://example.com/baz', headers, body)
 
1007
 
 
1008
        protocol = self._verifyAndCompleteConnectionTo('example.com', 80)
 
1009
 
 
1010
        # The request should have been issued with the host header specified
 
1011
        # above, not one based on the request URI.
 
1012
        self.assertEquals(len(protocol.requests), 1)
 
1013
        req, res = protocol.requests.pop()
 
1014
        self.assertEquals(req.headers.getRawHeaders('host'), ['quux'])
 
1015
 
 
1016
 
 
1017
    def test_headersUnmodified(self):
 
1018
        """
 
1019
        If a I{Host} header must be added to the request, the L{Headers}
 
1020
        instance passed to L{Agent.request} is not modified.
 
1021
        """
 
1022
        self.agent._protocol = StubHTTPProtocol
 
1023
 
 
1024
        headers = http_headers.Headers()
 
1025
        body = object()
 
1026
        self.agent.request(
 
1027
            'GET', 'http://example.com/foo', headers, body)
 
1028
 
 
1029
        protocol = self._verifyAndCompleteConnectionTo('example.com', 80)
 
1030
 
 
1031
        # The request should have been issued.
 
1032
        self.assertEquals(len(protocol.requests), 1)
 
1033
        # And the headers object passed in should not have changed.
 
1034
        self.assertEquals(headers, http_headers.Headers())
 
1035
 
 
1036
 
 
1037
    def test_hostValue(self):
 
1038
        """
 
1039
        L{Agent._computeHostValue} returns just the hostname it is passed if
 
1040
        the port number it is passed is the default for the scheme it is
 
1041
        passed, otherwise it returns a string containing both the host and port
 
1042
        separated by C{":"}.
 
1043
        """
 
1044
        self.assertEquals(
 
1045
            self.agent._computeHostValue('http', 'example.com', 80),
 
1046
            'example.com')
 
1047
 
 
1048
        self.assertEquals(
 
1049
            self.agent._computeHostValue('http', 'example.com', 54321),
 
1050
            'example.com:54321')
 
1051
 
 
1052
 
 
1053
 
 
1054
if ssl is None or not hasattr(ssl, 'DefaultOpenSSLContextFactory'):
 
1055
    for case in [WebClientSSLTestCase, WebClientRedirectBetweenSSLandPlainText]:
 
1056
        case.skip = "OpenSSL not present"
 
1057
 
 
1058
if not interfaces.IReactorSSL(reactor, None):
 
1059
    for case in [WebClientSSLTestCase, WebClientRedirectBetweenSSLandPlainText]:
 
1060
        case.skip = "Reactor doesn't support SSL"