~gary/python-openid/python-openid-2.2.1-patched

« back to all changes in this revision

Viewing changes to openid/test/test_consumer.py

  • Committer: Launchpad Patch Queue Manager
  • Date: 2007-11-30 02:46:28 UTC
  • mfrom: (1.1.1 pyopenid-2.0)
  • Revision ID: launchpad@pqm.canonical.com-20071130024628-qktwsew3383iawmq
[rs=SteveA] upgrade to python-openid-2.0.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
import urlparse
 
2
import cgi
 
3
import time
 
4
import warnings
 
5
 
 
6
from openid.message import Message, OPENID_NS, OPENID2_NS, IDENTIFIER_SELECT, \
 
7
     OPENID1_NS, BARE_NS
 
8
from openid import cryptutil, dh, oidutil, kvform
 
9
from openid.store.nonce import mkNonce, split as splitNonce
 
10
from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_2_0_TYPE, \
 
11
     OPENID_1_1_TYPE
 
12
from openid.consumer.consumer import \
 
13
     AuthRequest, GenericConsumer, SUCCESS, FAILURE, CANCEL, SETUP_NEEDED, \
 
14
     SuccessResponse, FailureResponse, SetupNeededResponse, CancelResponse, \
 
15
     DiffieHellmanSHA1ConsumerSession, Consumer, PlainTextConsumerSession, \
 
16
     SetupNeededError, DiffieHellmanSHA256ConsumerSession, ServerError, \
 
17
     ProtocolError
 
18
from openid import association
 
19
from openid.server.server import \
 
20
     PlainTextServerSession, DiffieHellmanSHA1ServerSession
 
21
from openid.yadis.manager import Discovery
 
22
from openid.yadis.discover import DiscoveryFailure
 
23
from openid.dh import DiffieHellman
 
24
 
 
25
from openid.fetchers import HTTPResponse, HTTPFetchingError
 
26
from openid import fetchers
 
27
from openid.store import memstore
 
28
 
 
29
from support import CatchLogs
 
30
 
 
31
# (assoc_secret, assoc_handle) pairs for TestFetcher.  Each secret is
# exactly 20 characters long.
assocs = [
    ('another 20-byte key.', 'Snarky'),
    (20 * '\x00', 'Zeros'),
]
 
35
 
 
36
def mkSuccess(endpoint, q):
    """Return a SuccessResponse for endpoint built from the OpenID args q.

    Every key of q is reported as signed, under its 'openid.'-prefixed
    name.
    """
    signed = ['openid.' + name for name in q]
    message = Message.fromOpenIDArgs(q)
    return SuccessResponse(endpoint, message, signed)
 
41
 
 
42
def parseQuery(qs):
 
43
    q = {}
 
44
    for (k, v) in cgi.parse_qsl(qs):
 
45
        assert not q.has_key(k)
 
46
        q[k] = v
 
47
    return q
 
48
 
 
49
def associate(qs, assoc_secret, assoc_handle):
    """Do the server's half of the associate call, using the given
    secret and handle.

    qs is the urlencoded body of an associate request, which must ask for
    an HMAC-SHA1 association.  A DH-SHA1 session request is answered via
    DiffieHellmanSHA1ServerSession; any other request via
    PlainTextServerSession.  Returns the reply encoded with
    kvform.dictToKV.
    """
    q = parseQuery(qs)
    assert q['openid.mode'] == 'associate'
    assert q['openid.assoc_type'] == 'HMAC-SHA1'
    reply_dict = {
        'assoc_type': 'HMAC-SHA1',
        'assoc_handle': assoc_handle,
        'expires_in': '600',
        }

    if q.get('openid.session_type') == 'DH-SHA1':
        assert len(q) == 6 or len(q) == 4
        message = Message.fromPostArgs(q)
        session = DiffieHellmanSHA1ServerSession.fromMessage(message)
        reply_dict['session_type'] = 'DH-SHA1'
    else:
        # Only openid.mode and openid.assoc_type are expected here.
        assert len(q) == 2
        # Server sessions are constructed from a Message (see the DH
        # branch above).  The previous PlainTextServerSession.fromQuery
        # call would fail with AttributeError if this branch were reached.
        message = Message.fromPostArgs(q)
        session = PlainTextServerSession.fromMessage(message)

    reply_dict.update(session.answer(assoc_secret))
    return kvform.dictToKV(reply_dict)
 
72
 
 
73
 
 
74
# Signature value that GoodAssociation.checkMessageSignature accepts.
GOODSIG = '[A Good Signature]'
 
75
 
 
76
 
 
77
class GoodAssociation:
    """Stub association: any message whose openid.sig is GOODSIG verifies."""

    # Value reported by getExpiresIn().
    expiresIn = 3600
    # Default handle; tests may overwrite it on an instance.
    handle = "-blah-"

    def getExpiresIn(self):
        """Return this association's expiresIn value."""
        remaining = self.expiresIn
        return remaining

    def checkMessageSignature(self, message):
        """Return whether the message's openid.sig equals GOODSIG."""
        signature = message.getArg(OPENID_NS, 'sig')
        return signature == GOODSIG
 
86
 
 
87
 
 
88
class GoodAssocStore(memstore.MemoryStore):
    """MemoryStore whose association lookups always return a new
    GoodAssociation, whatever server URL or handle is asked for."""

    def getAssociation(self, server_url, handle=None):
        """Ignore both arguments and hand back a fresh GoodAssociation."""
        stub = GoodAssociation()
        return stub
 
91
 
 
92
 
 
93
class TestFetcher(object):
    """Fake HTTP fetcher for the consumer tests.

    A GET (no body) of user_url returns user_page with status 200.  A POST
    whose body is an associate request is answered by associate() with the
    configured secret and handle, and counted in num_assocs.  Everything
    else gets a 404 'Not found' response.
    """

    def __init__(self, user_url, user_page, assoc):
        # assoc is an (assoc_secret, assoc_handle) pair.  It is unpacked
        # here instead of through a tuple parameter in the signature:
        # tuple parameters were removed in Python 3 (PEP 3113).  Callers
        # pass it positionally exactly as before.
        assoc_secret, assoc_handle = assoc
        self.get_responses = {user_url: self.response(user_url, 200, user_page)}
        self.assoc_secret = assoc_secret
        self.assoc_handle = assoc_handle
        # Number of associate requests answered so far.
        self.num_assocs = 0

    def response(self, url, status, body):
        """Build an HTTPResponse for url with no headers."""
        return HTTPResponse(
            final_url=url, status=status, headers={}, body=body)

    def fetch(self, url, body=None, headers=None):
        """Answer a GET (body is None) or a POST (body given) to url."""
        if body is None:
            if url in self.get_responses:
                return self.get_responses[url]
        elif 'openid.mode=associate' in body:
            # Only DH-SHA1 associate requests are expected.
            assert 'DH-SHA1' in body
            response = associate(body, self.assoc_secret, self.assoc_handle)
            self.num_assocs += 1
            return self.response(url, 200, response)

        return self.response(url, 404, 'Not found')
 
121
 
 
122
def makeFastConsumerSession():
    """Return a DH-SHA1 consumer session over a tiny Diffie-Hellman group
    (modulus 100389557, generator 2) so that tests run quickly."""
    # Named so it does not shadow the module-level ``dh`` import.
    fast_group = DiffieHellman(100389557, 2)
    return DiffieHellmanSHA1ConsumerSession(fast_group)
 
128
 
 
129
def setConsumerSession(con):
    """Replace con.session_types with a mapping whose only entry is
    'DH-SHA1' -> makeFastConsumerSession."""
    fast_types = {}
    fast_types['DH-SHA1'] = makeFastConsumerSession
    con.session_types = fast_types
 
131
 
 
132
def _test_success(server_url, user_url, delegate_url, links, immediate=False):
    """Run an OpenID 1.1 checkid/id_res round trip four times.

    Each run uses a new GenericConsumer over one shared MemoryStore.  It
    checks the redirect to server_url, then feeds back an id_res response
    signed with the stored association and expects SUCCESS for user_url.
    The fake fetcher counts associate requests:
      - the first run creates an association and the second reuses it;
      - after that association is removed from the store, the next run
        creates a new one and the run after that reuses it.

    links is accepted but not used here.
    """
    if immediate:
        mode = 'checkid_immediate'
    else:
        mode = 'checkid_setup'

    store = memstore.MemoryStore()

    endpoint = OpenIDServiceEndpoint()
    endpoint.claimed_id = user_url
    endpoint.server_url = server_url
    endpoint.local_id = delegate_url
    endpoint.type_uris = [OPENID_1_1_TYPE]

    fetcher = TestFetcher(None, None, assocs[0])
    fetchers.setDefaultFetcher(fetcher, wrap_exceptions=False)

    def run():
        # The consumer's URL serves as both trust root and return_to.
        realm = return_to = consumer_url

        consumer = GenericConsumer(store)
        setConsumerSession(consumer)
        request = consumer.begin(endpoint)

        # getMessage is exercised for its own sake; its result is unused.
        request.getMessage(realm, return_to, immediate)
        redirect_url = request.redirectURL(realm, return_to, immediate)

        # Index 4 of the urlparse result is the query string.
        sent = parseQuery(urlparse.urlparse(redirect_url)[4])
        sent_return_to = sent.pop('openid.return_to')
        assert sent == {
            'openid.mode': mode,
            'openid.identity': delegate_url,
            'openid.trust_root': realm,
            'openid.assoc_handle': fetcher.assoc_handle,
            }, (sent, user_url, delegate_url, mode)

        assert sent_return_to.startswith(return_to)
        assert redirect_url.startswith(server_url)

        nonce_name = consumer.openid1_nonce_query_arg_name
        nonce_value = request.return_to_args[nonce_name]

        reply_args = {
            nonce_name: nonce_value,
            'openid.mode': 'id_res',
            'openid.return_to': sent_return_to,
            'openid.identity': delegate_url,
            'openid.assoc_handle': fetcher.assoc_handle,
            }

        assoc = store.getAssociation(server_url, fetcher.assoc_handle)

        signed_reply = assoc.signMessage(Message.fromPostArgs(reply_args))
        info = consumer.complete(signed_reply, request.endpoint)
        assert info.status == SUCCESS, info.message
        assert info.identity_url == user_url

    assert fetcher.num_assocs == 0

    # The first run creates an association; the second reuses it.
    for _ in range(2):
        run()
        assert fetcher.num_assocs == 1

    # Removing it forces exactly one new association, reused afterwards.
    store.removeAssociation(server_url, fetcher.assoc_handle)
    for _ in range(2):
        run()
        assert fetcher.num_assocs == 2
 
211
 
 
212
import unittest
 
213
 
 
214
# Fixed URLs shared by the end-to-end success tests.
consumer_url = 'http://consumer.example.com/'
http_server_url = 'http://server.example.com/'
https_server_url = 'https://server.example.com/'
 
217
 
 
218
class TestSuccess(unittest.TestCase):
    """End-to-end successful logins via _test_success against server_url.

    Covers setup and immediate modes, with and without delegation.
    """

    server_url = http_server_url
    user_url = 'http://www.example.com/user.html'
    delegate_url = 'http://consumer.example.com/user'

    def setUp(self):
        server_link = '<link rel="openid.server" href="%s" />' % (
            self.server_url,)
        delegate_link = '<link rel="openid.delegate" href="%s" />' % (
            self.delegate_url,)
        # HTML link markup without and with a delegate.
        self.links = server_link
        self.delegate_links = server_link + delegate_link

    def test_nodelegate(self):
        _test_success(self.server_url, self.user_url,
                      self.user_url, self.links)

    def test_nodelegateImmediate(self):
        _test_success(self.server_url, self.user_url,
                      self.user_url, self.links, immediate=True)

    def test_delegate(self):
        _test_success(self.server_url, self.user_url,
                      self.delegate_url, self.delegate_links)

    def test_delegateImmediate(self):
        _test_success(self.server_url, self.user_url,
                      self.delegate_url, self.delegate_links,
                      immediate=True)
 
246
 
 
247
 
 
248
class TestSuccessHTTPS(TestSuccess):
    """Every TestSuccess case, repeated with an https server URL."""

    server_url = https_server_url
 
250
 
 
251
 
 
252
class TestConstruct(unittest.TestCase):
    """Construction of GenericConsumer."""

    def setUp(self):
        # Any unique object will do; the consumer should keep it as-is.
        self.store_sentinel = object()

    def test_construct(self):
        consumer = GenericConsumer(self.store_sentinel)
        self.failUnless(consumer.store is self.store_sentinel)

    def test_nostore(self):
        # The store argument is mandatory.
        self.failUnlessRaises(TypeError, GenericConsumer)
 
262
 
 
263
 
 
264
class TestIdRes(unittest.TestCase):
    """Base fixture: a consumer over a MemoryStore, plus an OpenID 1.1
    endpoint with placeholder identifiers."""

    consumer_class = GenericConsumer

    def setUp(self):
        self.store = memstore.MemoryStore()
        self.consumer = self.consumer_class(self.store)
        self.return_to = "nonny"

        self.consumer_id = "consu"
        self.server_url = "serlie"
        self.server_id = "sirod"

        self.endpoint = OpenIDServiceEndpoint()
        self.endpoint.claimed_id = self.consumer_id
        self.endpoint.server_url = self.server_url
        self.endpoint.local_id = self.server_id
        self.endpoint.type_uris = [OPENID_1_1_TYPE]

    def disableDiscoveryVerification(self):
        """Make discovery verification a no-op that accepts whatever
        endpoint it is given, for tests that do not care about it."""
        def acceptEndpoint(unused_message, endpoint):
            return endpoint
        self.consumer._verifyDiscoveryResults = acceptEndpoint
 
283
 
 
284
 
 
285
class TestIdResCheckSignature(TestIdRes):
    """Tests for GenericConsumer._idResCheckSignature."""

    def setUp(self):
        TestIdRes.setUp(self)
        # Store a stub association (accepts any message signed with
        # GOODSIG) under the handle the message below refers to.
        self.assoc = GoodAssociation()
        self.assoc.handle = "{not_dumb}"
        self.store.storeAssociation(self.endpoint.server_url, self.assoc)

        self.message = Message.fromPostArgs({
            'openid.mode': 'id_res',
            'openid.identity': '=example',
            'openid.sig': GOODSIG,
            'openid.assoc_handle': self.assoc.handle,
            'openid.signed': 'mode,identity,assoc_handle,signed',
            'frobboz': 'banzit',
            })

    def test_sign(self):
        # The handle names the stored association and the signature is
        # good, so no exception is raised.
        self.consumer._idResCheckSignature(self.message,
                                           self.endpoint.server_url)

    def test_signFailsWithBadSig(self):
        # The stored association is found, but the signature is wrong.
        self.message.setArg(OPENID_NS, 'sig', 'BAD SIGNATURE')
        self.failUnlessRaises(
            ProtocolError, self.consumer._idResCheckSignature,
            self.message, self.endpoint.server_url)

    def test_stateless(self):
        # No association is stored for 'dumbHandle', so verification goes
        # to the server.  The KV POST is stubbed out and the server's
        # answer counts as valid, so no exception is raised.
        self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle")
        self.consumer._processCheckAuthResponse = (
            lambda response, server_url: True)
        self.consumer._makeKVPost = lambda args, server_url: {}
        self.consumer._idResCheckSignature(self.message,
                                           self.endpoint.server_url)

    def test_statelessRaisesError(self):
        # No association is stored for 'dumbHandle' and _checkAuth reports
        # failure, so a ProtocolError is expected.
        self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle")
        self.consumer._checkAuth = lambda unused1, unused2: False
        self.failUnlessRaises(
            ProtocolError, self.consumer._idResCheckSignature,
            self.message, self.endpoint.server_url)

    def test_stateless_noStore(self):
        # Same as test_stateless, but the consumer has no store at all.
        self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle")
        self.consumer.store = None
        self.consumer._processCheckAuthResponse = (
            lambda response, server_url: True)
        self.consumer._makeKVPost = lambda args, server_url: {}
        self.consumer._idResCheckSignature(self.message,
                                           self.endpoint.server_url)

    def test_statelessRaisesError_noStore(self):
        # Same as test_statelessRaisesError, but the consumer has no store.
        self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle")
        self.consumer._checkAuth = lambda unused1, unused2: False
        self.consumer.store = None
        self.failUnlessRaises(
            ProtocolError, self.consumer._idResCheckSignature,
            self.message, self.endpoint.server_url)
 
350
 
 
351
 
 
352
class TestQueryFormat(TestIdRes):
 
353
    def test_notAList(self):
 
354
        # XXX: should be a Message object test, not a consumer test
 
355
 
 
356
        # Value should be a single string.  If it's a list, it should generate
 
357
        # an exception.
 
358
        query = {'openid.mode': ['cancel']}
 
359
        try:
 
360
            r = Message.fromPostArgs(query)
 
361
        except TypeError, err:
 
362
            self.failUnless(str(err).find('values') != -1, err)
 
363
        else:
 
364
            self.fail("expected TypeError, got this instead: %s" % (r,))
 
365
 
 
366
class TestComplete(TestIdRes):
    """Tests of GenericConsumer.complete as a whole.

    The other TestIdRes subclasses cover narrower parts of the flow.
    """

    def _completeFromPost(self, post_args):
        """Build a Message from POST-style args and complete() it against
        self.endpoint."""
        return self.consumer.complete(Message.fromPostArgs(post_args),
                                      self.endpoint)

    def test_setupNeededIdRes(self):
        setup_url_sentinel = object()
        message = Message.fromOpenIDArgs({'mode': 'id_res'})

        def raiseSetupNeeded(msg):
            # complete() must pass along the very message it was given.
            self.failUnless(msg is message)
            raise SetupNeededError(setup_url_sentinel)

        self.consumer._checkSetupNeeded = raiseSetupNeeded

        response = self.consumer.complete(message, None)
        self.failUnlessEqual(SETUP_NEEDED, response.status)
        self.failUnless(setup_url_sentinel is response.setup_url)

    def test_cancel(self):
        r = self._completeFromPost({'openid.mode': 'cancel'})
        self.failUnlessEqual(r.status, CANCEL)
        self.failUnless(r.identity_url == self.endpoint.claimed_id)

    def test_error(self):
        msg = 'an error message'
        r = self._completeFromPost({
            'openid.mode': 'error',
            'openid.error': msg,
            })
        self.failUnlessEqual(r.status, FAILURE)
        self.failUnless(r.identity_url == self.endpoint.claimed_id)
        self.failUnlessEqual(r.message, msg)

    def test_errorWithNoOptionalKeys(self):
        # Despite the name, contact is supplied; only reference is absent.
        msg = 'an error message'
        contact = 'some contact info here'
        r = self._completeFromPost({
            'openid.mode': 'error',
            'openid.error': msg,
            'openid.contact': contact,
            })
        self.failUnlessEqual(r.status, FAILURE)
        self.failUnless(r.identity_url == self.endpoint.claimed_id)
        self.failUnless(r.contact == contact)
        self.failUnless(r.reference is None)
        self.failUnlessEqual(r.message, msg)

    def test_errorWithOptionalKeys(self):
        msg = 'an error message'
        contact = 'me'
        reference = 'support ticket'
        r = self._completeFromPost({
            'openid.ns': OPENID2_NS,
            'openid.mode': 'error',
            'openid.error': msg,
            'openid.contact': contact,
            'openid.reference': reference,
            })
        self.failUnlessEqual(r.status, FAILURE)
        self.failUnless(r.identity_url == self.endpoint.claimed_id)
        self.failUnless(r.contact == contact)
        self.failUnless(r.reference == reference)
        self.failUnlessEqual(r.message, msg)

    def test_noMode(self):
        r = self._completeFromPost({})
        self.failUnlessEqual(r.status, FAILURE)
        self.failUnless(r.identity_url == self.endpoint.claimed_id)

    def test_idResMissingField(self):
        # XXX - this passes, but possibly not for the intended reason: the
        # ProtocolError may come from the check_auth path rather than from
        # the missing fields themselves.
        message = Message.fromPostArgs({'openid.mode': 'id_res'})
        self.failUnlessRaises(ProtocolError, self.consumer._doIdRes,
                              message, self.endpoint)

    def test_idResURLMismatch(self):
        # identity differs from the endpoint's local_id.  GoodAssocStore
        # lets the signature check pass, so the local_id check fails.
        message = Message.fromPostArgs({
            'openid.mode': 'id_res',
            'openid.return_to': 'return_to (just anything)',
            'openid.identity': 'something wrong (not self.consumer_id)',
            'openid.assoc_handle': 'does not matter',
            'openid.sig': GOODSIG,
            'openid.signed': 'identity,return_to',
            })
        self.consumer.store = GoodAssocStore()
        r = self.consumer.complete(message, self.endpoint)
        self.failUnlessEqual(r.status, FAILURE)
        self.failUnlessEqual(r.identity_url, self.consumer_id)
        self.failUnless(r.message.startswith('local_id mismatch'),
                        r.message)
 
460
 
 
461
 
 
462
 
 
463
class TestCompleteMissingSig(unittest.TestCase, CatchLogs):
    """complete() on an OpenID 2 id_res response, varying which of the
    identity, return_to, assoc_handle and claimed_id fields are listed in
    openid.signed."""

    def setUp(self):
        self.store = GoodAssocStore()
        self.consumer = GenericConsumer(self.store)
        self.server_url = "http://idp.unittest/"
        CatchLogs.setUp(self)

        claimed_id = 'bogus.claimed'

        # A fully signed response (GoodAssocStore accepts GOODSIG);
        # individual tests trim openid.signed.
        self.message = Message.fromOpenIDArgs({
            'ns': OPENID2_NS,
            'mode': 'id_res',
            'op_endpoint': self.server_url,
            'claimed_id': claimed_id,
            'identity': claimed_id,
            'return_to': 'return_to (just anything)',
            'response_nonce': mkNonce(),
            'assoc_handle': 'does not matter',
            'sig': GOODSIG,
            'signed':
                'identity,return_to,response_nonce,assoc_handle,claimed_id',
            })

        self.endpoint = OpenIDServiceEndpoint()
        self.endpoint.server_url = self.server_url
        self.endpoint.claimed_id = claimed_id

    def tearDown(self):
        CatchLogs.tearDown(self)

    def _completeWithSigned(self, signed):
        """Set openid.signed to signed, then run complete()."""
        self.message.setArg(OPENID_NS, 'signed', signed)
        return self.consumer.complete(self.message, self.endpoint)

    def test_idResMissingNoSigs(self):
        # Every required field is signed; discovery verification is
        # stubbed to accept the endpoint.
        def _vrfy(resp_msg, endpoint=None):
            return endpoint

        self.consumer._verifyDiscoveryResults = _vrfy
        self.failUnlessSuccess(
            self.consumer.complete(self.message, self.endpoint))

    def test_idResNoIdentity(self):
        # With identity and claimed_id absent, neither has to be signed.
        self.message.delArg(OPENID_NS, 'identity')
        self.message.delArg(OPENID_NS, 'claimed_id')
        self.endpoint.claimed_id = None
        self.failUnlessSuccess(
            self._completeWithSigned('return_to,response_nonce,assoc_handle'))

    def test_idResMissingIdentitySig(self):
        r = self._completeWithSigned(
            'return_to,response_nonce,assoc_handle,claimed_id')
        self.failUnlessEqual(r.status, FAILURE)

    def test_idResMissingReturnToSig(self):
        r = self._completeWithSigned(
            'identity,response_nonce,assoc_handle,claimed_id')
        self.failUnlessEqual(r.status, FAILURE)

    def test_idResMissingAssocHandleSig(self):
        r = self._completeWithSigned(
            'identity,response_nonce,return_to,claimed_id')
        self.failUnlessEqual(r.status, FAILURE)

    def test_idResMissingClaimedIDSig(self):
        r = self._completeWithSigned(
            'identity,response_nonce,return_to,assoc_handle')
        self.failUnlessEqual(r.status, FAILURE)

    def failUnlessSuccess(self, response):
        """Fail the test unless response.status is SUCCESS."""
        if response.status == SUCCESS:
            return
        self.fail("Non-successful response: %s" % (response,))
 
539
 
 
540
 
 
541
 
 
542
class TestCheckAuthResponse(TestIdRes, CatchLogs):
    """Tests for GenericConsumer._processCheckAuthResponse, which handles
    the server's reply to a check_authentication request."""

    def setUp(self):
        CatchLogs.setUp(self)
        TestIdRes.setUp(self)

    def tearDown(self):
        CatchLogs.tearDown(self)

    def _createAssoc(self):
        """Store an HMAC-SHA1 association with handle 'handle' for
        self.server_url and check that the store hands it back."""
        issued = time.time()
        lifetime = 1000
        assoc = association.Association(
            'handle', 'secret', issued, lifetime, 'HMAC-SHA1')
        store = self.consumer.store
        store.storeAssociation(self.server_url, assoc)
        assoc2 = store.getAssociation(self.server_url)
        self.failUnlessEqual(assoc, assoc2)

    def test_goodResponse(self):
        """successful response to check_authentication"""
        response = Message.fromOpenIDArgs({'is_valid':'true',})
        r = self.consumer._processCheckAuthResponse(response, self.server_url)
        self.failUnless(r)

    def test_missingAnswer(self):
        """check_authentication returns false when the server sends no answer"""
        response = Message.fromOpenIDArgs({})
        r = self.consumer._processCheckAuthResponse(response, self.server_url)
        self.failIf(r)

    def test_badResponse(self):
        """check_authentication returns false when is_valid is false"""
        response = Message.fromOpenIDArgs({'is_valid':'false',})
        r = self.consumer._processCheckAuthResponse(response, self.server_url)
        self.failIf(r)

    def test_badResponseInvalidate(self):
        """Make sure that the handle is invalidated when is_valid is false

        "Verifying directly with the OpenID Provider" only spells out the
        case where "is_valid" is "true" and "invalidate_handle" is present
        (see test_invalidatePresent).  This test pins the consumer's
        behaviour for the other case: when the server answers with
        "is_valid" set to "false" and names a handle in
        "invalidate_handle", the response is rejected and the association
        is also removed from the store.
        """
        self._createAssoc()
        response = Message.fromOpenIDArgs({
            'is_valid':'false',
            'invalidate_handle':'handle',
            })
        r = self.consumer._processCheckAuthResponse(response, self.server_url)
        self.failIf(r)
        self.failUnless(
            self.consumer.store.getAssociation(self.server_url) is None)

    def test_invalidateMissing(self):
        """invalidate_handle with a handle that is not present

        The response is still accepted; the invalidate_handle is logged.
        """
        response = Message.fromOpenIDArgs({
            'is_valid':'true',
            'invalidate_handle':'missing',
            })
        r = self.consumer._processCheckAuthResponse(response, self.server_url)
        self.failUnless(r)
        self.failUnlessLogMatches(
            'Received "invalidate_handle"'
            )

    def test_invalidateMissing_noStore(self):
        """invalidate_handle with a handle that is not present, and no store

        With no store to remove the handle from, the consumer still
        accepts the response and logs that it got invalidate_handle
        without a store.
        """
        response = Message.fromOpenIDArgs({
            'is_valid':'true',
            'invalidate_handle':'missing',
            })
        self.consumer.store = None
        r = self.consumer._processCheckAuthResponse(response, self.server_url)
        self.failUnless(r)
        self.failUnlessLogMatches(
            'Received "invalidate_handle"',
            'Unexpectedly got invalidate_handle without a store')

    def test_invalidatePresent(self):
        """invalidate_handle with a handle that exists

        From "Verifying directly with the OpenID Provider"::

            If the OP responds with "is_valid" set to "true", and
            "invalidate_handle" is present, the Relying Party SHOULD
            NOT send further authentication requests with that handle.
        """
        self._createAssoc()
        response = Message.fromOpenIDArgs({
            'is_valid':'true',
            'invalidate_handle':'handle',
            })
        r = self.consumer._processCheckAuthResponse(response, self.server_url)
        self.failUnless(r)
        self.failUnless(
            self.consumer.store.getAssociation(self.server_url) is None)
 
640
 
 
641
class TestSetupNeeded(TestIdRes):
 
642
    def failUnlessSetupNeeded(self, expected_setup_url, message):
 
643
        try:
 
644
            self.consumer._checkSetupNeeded(message)
 
645
        except SetupNeededError, why:
 
646
            self.failUnlessEqual(expected_setup_url, why.user_setup_url)
 
647
        else:
 
648
            self.fail("Expected to find an immediate-mode response")
 
649
 
 
650
    def test_setupNeededOpenID1(self):
 
651
        """The minimum conditions necessary to trigger Setup Needed"""
 
652
        setup_url = 'http://unittest/setup-here'
 
653
        message = Message.fromPostArgs({
 
654
            'openid.mode': 'id_res',
 
655
            'openid.user_setup_url': setup_url,
 
656
            })
 
657
        self.failUnless(message.isOpenID1())
 
658
        self.failUnlessSetupNeeded(setup_url, message)
 
659
 
 
660
    def test_setupNeededOpenID1_extra(self):
 
661
        """Extra stuff along with setup_url still trigger Setup Needed"""
 
662
        setup_url = 'http://unittest/setup-here'
 
663
        message = Message.fromPostArgs({
 
664
            'openid.mode': 'id_res',
 
665
            'openid.user_setup_url': setup_url,
 
666
            'openid.identity': 'bogus',
 
667
            })
 
668
        self.failUnless(message.isOpenID1())
 
669
        self.failUnlessSetupNeeded(setup_url, message)
 
670
 
 
671
    def test_noSetupNeededOpenID1(self):
 
672
        """When the user_setup_url is missing on an OpenID 1 message,
 
673
        we assume that it's not a cancel response to checkid_immediate"""
 
674
        message = Message.fromOpenIDArgs({'mode': 'id_res'})
 
675
        self.failUnless(message.isOpenID1())
 
676
 
 
677
        # No SetupNeededError raised
 
678
        self.consumer._checkSetupNeeded(message)
 
679
 
 
680
    def test_setupNeededOpenID2(self):
 
681
        message = Message.fromOpenIDArgs({
 
682
            'mode':'setup_needed',
 
683
            'ns':OPENID2_NS,
 
684
            })
 
685
        self.failUnless(message.isOpenID2())
 
686
        response = self.consumer.complete(message, None, None)
 
687
        self.failUnlessEqual('setup_needed', response.status)
 
688
        self.failUnlessEqual(None, response.setup_url)
 
689
 
 
690
    def test_setupNeededDoesntWorkForOpenID1(self):
 
691
        message = Message.fromOpenIDArgs({
 
692
            'mode':'setup_needed',
 
693
            })
 
694
 
 
695
        # No SetupNeededError raised
 
696
        self.consumer._checkSetupNeeded(message)
 
697
 
 
698
        response = self.consumer.complete(message, None, None)
 
699
        self.failUnlessEqual('failure', response.status)
 
700
        self.failUnless(response.message.startswith('Invalid openid.mode'))
 
701
 
 
702
    def test_noSetupNeededOpenID2(self):
 
703
        message = Message.fromOpenIDArgs({
 
704
            'mode':'id_res',
 
705
            'game':'puerto_rico',
 
706
            'ns':OPENID2_NS,
 
707
            })
 
708
        self.failUnless(message.isOpenID2())
 
709
 
 
710
        # No SetupNeededError raised
 
711
        self.consumer._checkSetupNeeded(message)
 
712
 
 
713
class IdResCheckForFieldsTest(TestIdRes):
 
714
    def setUp(self):
 
715
        self.consumer = GenericConsumer(None)
 
716
 
 
717
    def mkSuccessTest(openid_args, signed_list):
 
718
        def test(self):
 
719
            message = Message.fromOpenIDArgs(openid_args)
 
720
            self.consumer._idResCheckForFields(message, signed_list)
 
721
        return test
 
722
 
 
723
    test_openid1Success = mkSuccessTest(
 
724
        {'return_to':'return',
 
725
         'assoc_handle':'assoc handle',
 
726
         'sig':'a signature',
 
727
         'identity':'someone',
 
728
         },
 
729
        ['return_to', 'identity'])
 
730
 
 
731
    test_openid2Success = mkSuccessTest(
 
732
        {'ns':OPENID2_NS,
 
733
         'return_to':'return',
 
734
         'assoc_handle':'assoc handle',
 
735
         'sig':'a signature',
 
736
         'op_endpoint':'my favourite server',
 
737
         'response_nonce':'use only once',
 
738
         },
 
739
        ['return_to', 'response_nonce', 'assoc_handle'])
 
740
 
 
741
    test_openid2Success_identifiers = mkSuccessTest(
 
742
        {'ns':OPENID2_NS,
 
743
         'return_to':'return',
 
744
         'assoc_handle':'assoc handle',
 
745
         'sig':'a signature',
 
746
         'claimed_id':'i claim to be me',
 
747
         'identity':'my server knows me as me',
 
748
         'op_endpoint':'my favourite server',
 
749
         'response_nonce':'use only once',
 
750
         },
 
751
        ['return_to', 'response_nonce', 'identity',
 
752
         'claimed_id', 'assoc_handle'])
 
753
 
 
754
    def mkFailureTest(openid_args, signed_list, sig_fail=False):
 
755
        def test(self):
 
756
            message = Message.fromOpenIDArgs(openid_args)
 
757
            try:
 
758
                self.consumer._idResCheckForFields(message, signed_list)
 
759
            except ProtocolError, why:
 
760
                if sig_fail:
 
761
                    self.failUnless(why[0].endswith('not signed'))
 
762
                else:
 
763
                    self.failUnless(why[0].startswith('Missing required'))
 
764
            else:
 
765
                self.fail('Expected an error, but none occurred')
 
766
        return test
 
767
 
 
768
    test_openid1Missing_returnToSig = mkFailureTest(
 
769
        {'return_to':'return',
 
770
         'assoc_handle':'assoc handle',
 
771
         'sig':'a signature',
 
772
         'identity':'someone',
 
773
         },
 
774
        ['identity'],
 
775
        sig_fail=True)
 
776
 
 
777
    test_openid1Missing_identitySig = mkFailureTest(
 
778
        {'return_to':'return',
 
779
         'assoc_handle':'assoc handle',
 
780
         'sig':'a signature',
 
781
         'identity':'someone',
 
782
         },
 
783
        ['return_to'],
 
784
        sig_fail=True)
 
785
 
 
786
    test_openid1MissingReturnTo = mkFailureTest(
 
787
        {'assoc_handle':'assoc handle',
 
788
         'sig':'a signature',
 
789
         'identity':'someone',
 
790
         },
 
791
        ['return_to', 'identity'])
 
792
 
 
793
    test_openid1MissingAssocHandle = mkFailureTest(
 
794
        {'return_to':'return',
 
795
         'sig':'a signature',
 
796
         'identity':'someone',
 
797
         },
 
798
        ['return_to', 'identity'])
 
799
 
 
800
    # XXX: I could go on...
 
801
 
 
802
class CheckAuthHappened(Exception):
    """Marker exception raised by CheckAuthDetectingConsumer._checkAuth,
    carrying the arguments it was called with."""
 
803
 
 
804
class CheckNonceVerifyTest(TestIdRes, CatchLogs):
    """Tests for the consumer's response-nonce handling
    (_idResCheckNonce and _idResGetNonceOpenID1)."""

    def setUp(self):
        CatchLogs.setUp(self)
        TestIdRes.setUp(self)
        # The OpenID 1 tests below carry the consumer-generated nonce in
        # the return_to query under this argument name.
        self.consumer.openid1_nonce_query_arg_name = 'nonce'

    def tearDown(self):
        # Undo setUp in reverse order.  TestIdRes.tearDown is called
        # explicitly so that overriding tearDown here does not skip it.
        TestIdRes.tearDown(self)
        CatchLogs.tearDown(self)

    def test_openid1Success(self):
        """use consumer-generated nonce"""
        self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),)
        self.response = Message.fromOpenIDArgs({'return_to': self.return_to})
        self.consumer._idResCheckNonce(self.response, self.endpoint)
        self.failUnlessLogEmpty()

    def test_openid1Missing(self):
        """an OpenID 1 response without a consumer-generated nonce
        yields None"""
        self.response = Message.fromOpenIDArgs({})
        n = self.consumer._idResGetNonceOpenID1(self.response, self.endpoint)
        self.failUnless(n is None, n)
        self.failUnlessLogEmpty()

    def test_consumerNonceOpenID2(self):
        """OpenID 2 does not use consumer-generated nonce"""
        self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),)
        self.response = Message.fromOpenIDArgs(
            {'return_to': self.return_to, 'ns': OPENID2_NS})
        self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce,
                              self.response, self.endpoint)
        self.failUnlessLogEmpty()

    def test_serverNonce(self):
        """use server-generated nonce"""
        self.response = Message.fromOpenIDArgs(
            {'ns': OPENID2_NS, 'response_nonce': mkNonce(),})
        self.consumer._idResCheckNonce(self.response, self.endpoint)
        self.failUnlessLogEmpty()

    def test_serverNonceOpenID1(self):
        """OpenID 1 does not use server-generated nonce"""
        self.response = Message.fromOpenIDArgs(
            {'ns': OPENID1_NS,
             'return_to': 'http://return.to/',
             'response_nonce': mkNonce(),})
        self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce,
                              self.response, self.endpoint)
        self.failUnlessLogEmpty()

    def test_badNonce(self):
        """reject a nonce that the store has already seen

        The nonce is recorded in the store before the check, so the
        response looks like a replay.  From "Checking the Nonce"::

            When the Relying Party checks the signature on an assertion, the
            Relying Party SHOULD ensure that an assertion has not yet
            been accepted with the same value for "openid.response_nonce"
            from the same OP Endpoint URL.
        """
        nonce = mkNonce()
        stamp, salt = splitNonce(nonce)
        self.store.useNonce(self.server_url, stamp, salt)
        self.response = Message.fromOpenIDArgs(
                                  {'response_nonce': nonce,
                                   'ns': OPENID2_NS,
                                   })
        self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce,
                              self.response, self.endpoint)

    def test_successWithNoStore(self):
        """When there is no store, checking the nonce succeeds"""
        self.consumer.store = None
        self.response = Message.fromOpenIDArgs(
                                  {'response_nonce': mkNonce(),
                                   'ns': OPENID2_NS,
                                   })
        self.consumer._idResCheckNonce(self.response, self.endpoint)
        self.failUnlessLogEmpty()

    def test_tamperedNonce(self):
        """Malformed nonce"""
        self.response = Message.fromOpenIDArgs(
                                  {'ns': OPENID2_NS,
                                   'response_nonce': 'malformed'})
        self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce,
                              self.response, self.endpoint)

    def test_missingNonce(self):
        """no nonce parameter on the return_to"""
        self.response = Message.fromOpenIDArgs(
                                  {'return_to': self.return_to})
        self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce,
                              self.response, self.endpoint)
 
898
 
 
899
class CheckAuthDetectingConsumer(GenericConsumer):
    """GenericConsumer that reports any check_authentication attempt by
    raising CheckAuthHappened, and treats every nonce as valid."""

    def _checkAuth(self, *args):
        # Hand the call and its arguments back to the test by raising.
        raise CheckAuthHappened(args)

    def _idResCheckNonce(self, *args):
        """Nonce checking is not under test here, so always report
        success."""
        return True
 
907
 
 
908
class TestCheckAuthTriggered(TestIdRes, CatchLogs):
    """Check when _doIdRes falls back to check_authentication.

    The consumer is a CheckAuthDetectingConsumer, so reaching _checkAuth
    shows up as a CheckAuthHappened exception.
    """
    consumer_class = CheckAuthDetectingConsumer

    def setUp(self):
        TestIdRes.setUp(self)
        CatchLogs.setUp(self)
        self.disableDiscoveryVerification()

    def tearDown(self):
        # unittest.TestCase.tearDown comes before CatchLogs.tearDown in this
        # class's MRO, so without this override the log hook installed by
        # CatchLogs.setUp would never be undone.  Undo in reverse order.
        CatchLogs.tearDown(self)
        TestIdRes.tearDown(self)

    def test_checkAuthTriggered(self):
        message = Message.fromPostArgs({
            'openid.return_to': self.return_to,
            'openid.identity': self.server_id,
            'openid.assoc_handle': 'not_found',
            'openid.sig': GOODSIG,
            'openid.signed': 'identity,return_to',
            })
        try:
            result = self.consumer._doIdRes(message, self.endpoint)
        except CheckAuthHappened:
            pass
        else:
            self.fail('_checkAuth did not happen. Result was: %r %s' %
                      (result, self.messages))

    def test_checkAuthTriggeredWithAssoc(self):
        # Store an association for this server that does not match the
        # handle that is in the message
        issued = time.time()
        lifetime = 1000
        assoc = association.Association(
            'handle', 'secret', issued, lifetime, 'HMAC-SHA1')
        self.store.storeAssociation(self.server_url, assoc)

        message = Message.fromPostArgs({
            'openid.return_to': self.return_to,
            'openid.identity': self.server_id,
            'openid.assoc_handle': 'not_found',
            'openid.sig': GOODSIG,
            'openid.signed': 'identity,return_to',
            })
        try:
            result = self.consumer._doIdRes(message, self.endpoint)
        except CheckAuthHappened:
            pass
        else:
            self.fail('_checkAuth did not happen. Result was: %r' % (result,))

    def test_expiredAssoc(self):
        # Store an expired association for the server with the handle
        # that is in the message
        issued = time.time() - 10
        lifetime = 0
        handle = 'handle'
        assoc = association.Association(
            handle, 'secret', issued, lifetime, 'HMAC-SHA1')
        self.failUnless(assoc.expiresIn <= 0)
        self.store.storeAssociation(self.server_url, assoc)

        message = Message.fromPostArgs({
            'openid.return_to': self.return_to,
            'openid.identity': self.server_id,
            'openid.assoc_handle': handle,
            'openid.sig': GOODSIG,
            'openid.signed': 'identity,return_to',
            })
        self.failUnlessRaises(ProtocolError, self.consumer._doIdRes,
                              message, self.endpoint)

    def test_newerAssoc(self):
        # The message names, and is signed with, the older of two stored
        # associations for this server.  _doIdRes must still succeed even
        # though a newer association exists.
        lifetime = 1000

        good_issued = time.time() - 10
        good_handle = 'handle'
        good_assoc = association.Association(
            good_handle, 'secret', good_issued, lifetime, 'HMAC-SHA1')
        self.store.storeAssociation(self.server_url, good_assoc)

        bad_issued = time.time() - 5
        bad_handle = 'handle2'
        bad_assoc = association.Association(
            bad_handle, 'secret', bad_issued, lifetime, 'HMAC-SHA1')
        self.store.storeAssociation(self.server_url, bad_assoc)

        query = {
            'return_to': self.return_to,
            'identity': self.server_id,
            'assoc_handle': good_handle,
            }

        message = Message.fromOpenIDArgs(query)
        message = good_assoc.signMessage(message)
        info = self.consumer._doIdRes(message, self.endpoint)
        self.failUnlessEqual(info.status, SUCCESS, info.message)
        self.failUnlessEqual(self.consumer_id, info.identity_url)
 
1002
 
 
1003
 
 
1004
 
 
1005
class TestReturnToArgs(unittest.TestCase):
    """Verifying the Return URL parameters.

    From the specification "Verifying the Return URL"::

        To verify that the "openid.return_to" URL matches the URL that is
        processing this assertion:

         - The URL scheme, authority, and path MUST be the same between the
           two URLs.

         - Any query parameters that are present in the "openid.return_to"
           URL MUST also be present with the same values in the
           accepting URL.

    The direct _verifyReturnToArgs tests cover only the second item.
    test_completeBadReturnTo covers scheme, authority, path and query
    mismatches through GenericConsumer.complete().
    """

    def setUp(self):
        # A bare placeholder object stands in for the store.
        store = object()
        self.consumer = GenericConsumer(store)

    def test_returnToArgsOkay(self):
        query = {
            'openid.mode': 'id_res',
            'openid.return_to': 'http://example.com/?foo=bar',
            'foo': 'bar',
            }
        # no return value, success is assumed if there are no exceptions.
        self.consumer._verifyReturnToArgs(query)

    def test_returnToMismatch(self):
        query = {
            'openid.mode': 'id_res',
            'openid.return_to': 'http://example.com/?foo=bar',
            }
        # fail, query has no key 'foo'.
        self.failUnlessRaises(ValueError,
                              self.consumer._verifyReturnToArgs, query)

        query['foo'] = 'baz'
        # fail, values for 'foo' do not match.
        self.failUnlessRaises(ValueError,
                              self.consumer._verifyReturnToArgs, query)

    def test_noReturnTo(self):
        query = {'openid.mode': 'id_res'}
        self.failUnlessRaises(ValueError,
                              self.consumer._verifyReturnToArgs, query)

    def test_completeBadReturnTo(self):
        """Test GenericConsumer.complete()'s handling of bad return_to
        values.
        """
        return_to = "http://some.url/path?foo=bar"

        # Scheme, authority, and path differences are checked by
        # GenericConsumer._checkReturnTo.  Query args checked by
        # GenericConsumer._verifyReturnToArgs.
        bad_return_tos = [
            # Scheme only
            "https://some.url/path?foo=bar",
            # Authority only
            "http://some.url.invalid/path?foo=bar",
            # Path only
            "http://some.url/path_extra?foo=bar",
            # Query args differ
            "http://some.url/path?foo=bar2",
            "http://some.url/path?foo2=bar",
            ]

        m = Message(OPENID1_NS)
        m.setArg(OPENID_NS, 'mode', 'cancel')
        m.setArg(BARE_NS, 'foo', 'bar')
        endpoint = None

        for bad in bad_return_tos:
            m.setArg(OPENID_NS, 'return_to', bad)
            result = self.consumer.complete(m, endpoint, return_to)
            self.failUnless(isinstance(result, FailureResponse),
                            "Expected FailureResponse, got %r for %s" % (result, bad))
            # failUnlessEqual reports the actual message on mismatch.
            self.failUnlessEqual(result.message,
                                 "openid.return_to does not match return URL")

    def test_completeGoodReturnTo(self):
        """Test GenericConsumer.complete()'s handling of good
        return_to values.
        """
        return_to = "http://some.url/path"

        good_return_tos = [
            (return_to, {}),
            (return_to + "?another=arg", {(BARE_NS, 'another'): 'arg'}),
            (return_to + "?another=arg#fragment", {(BARE_NS, 'another'): 'arg'}),
            ]

        endpoint = None

        for good, extra in good_return_tos:
            m = Message(OPENID1_NS)
            m.setArg(OPENID_NS, 'mode', 'cancel')

            for (ns, key), value in extra.items():
                m.setArg(ns, key, value)

            m.setArg(OPENID_NS, 'return_to', good)
            result = self.consumer.complete(m, endpoint, return_to)
            self.failUnless(isinstance(result, CancelResponse),
                            "Expected CancelResponse, got %r for %s" % (result, good,))
 
1116
 
 
1117
class MockFetcher(object):
    """Fetcher that records every request and replays one canned response.

    A false-valued ``response`` argument (including the default None) is
    replaced by an empty HTTPResponse().
    """

    def __init__(self, response=None):
        if not response:
            response = HTTPResponse()
        self.response = response
        self.fetches = []

    def fetch(self, url, body=None, headers=None):
        request = (url, body, headers)
        self.fetches.append(request)
        return self.response
 
1125
 
 
1126
class ExceptionRaisingMockFetcher(object):
    """Fetcher whose every fetch fails with its own MyException type."""

    class MyException(Exception):
        """The exception raised by every call to fetch()."""

    def fetch(self, url, body=None, headers=None):
        failure = self.MyException('mock fetcher exception')
        raise failure
 
1132
 
 
1133
class BadArgCheckingConsumer(GenericConsumer):
    """GenericConsumer whose _makeKVPost asserts on the exact arguments it
    receives and returns None instead of performing a request."""

    def _makeKVPost(self, args, _):
        expected = {
            'openid.mode': 'check_authentication',
            'openid.signed': 'foo',
            }
        assert args == expected, args
        # No response object is produced.
        return None
 
1140
 
 
1141
class TestCheckAuth(unittest.TestCase, CatchLogs):
    """Tests for GenericConsumer._checkAuth and _createCheckAuthRequest,
    run against a MockFetcher installed as the default fetcher."""
    consumer_class = GenericConsumer

    def setUp(self):
        CatchLogs.setUp(self)
        self.store = memstore.MemoryStore()

        self.consumer = self.consumer_class(self.store)

        # Swap in a mock fetcher; tearDown restores the saved one.
        self._orig_fetcher = fetchers.getDefaultFetcher()
        self.fetcher = MockFetcher()
        fetchers.setDefaultFetcher(self.fetcher)

    def tearDown(self):
        CatchLogs.tearDown(self)
        fetchers.setDefaultFetcher(self._orig_fetcher, wrap_exceptions=False)

    def test_error(self):
        self.fetcher.response = HTTPResponse(
            "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n')
        message = Message.fromPostArgs({
            'openid.signed': 'stuff',
            'openid.stuff': 'a value',
            })
        outcome = self.consumer._checkAuth(message, http_server_url)
        # A 404 gives a false result and leaves log output behind.
        self.failIf(outcome)
        self.failUnless(self.messages)

    def test_bad_args(self):
        post_args = {
            'openid.signed': 'foo',
            'closid.foo': 'something',
            }
        checking_consumer = BadArgCheckingConsumer(self.store)
        checking_consumer._checkAuth(Message.fromPostArgs(post_args),
                                     'does://not.matter')

    def test_signedList(self):
        # Fields outside openid.signed (here 'foo') are dropped, and the
        # mode becomes check_authentication.
        response = Message.fromOpenIDArgs({
            'mode': 'id_res',
            'sig': 'rabbits',
            'identity': '=example',
            'assoc_handle': 'munchkins',
            'signed': 'identity,mode',
            'foo': 'bar',
            })
        wanted = Message.fromOpenIDArgs({
            'mode': 'check_authentication',
            'sig': 'rabbits',
            'assoc_handle': 'munchkins',
            'identity': '=example',
            'signed': 'identity,mode'
            })
        request = self.consumer._createCheckAuthRequest(response)
        self.failUnlessEqual(request.toPostArgs(), wanted.toPostArgs())
 
1195
 
 
1196
 
 
1197
 
 
1198
class TestFetchAssoc(unittest.TestCase, CatchLogs):
    """How fetcher failures surface from _makeKVPost, _getAssociation and
    _checkAuth, with and without exception wrapping."""
    consumer_class = GenericConsumer

    def setUp(self):
        CatchLogs.setUp(self)
        self.store = memstore.MemoryStore()
        # These tests replace the process-wide default fetcher.  Remember
        # the original so tearDown can put it back and later tests are not
        # left with a mock (or an always-raising) fetcher.
        self._orig_fetcher = fetchers.getDefaultFetcher()
        self.fetcher = MockFetcher()
        fetchers.setDefaultFetcher(self.fetcher)
        self.consumer = self.consumer_class(self.store)

    def tearDown(self):
        # unittest.TestCase.tearDown comes before CatchLogs.tearDown in the
        # MRO, so the log hook must be undone explicitly (as TestCheckAuth
        # does).
        CatchLogs.tearDown(self)
        fetchers.setDefaultFetcher(self._orig_fetcher, wrap_exceptions=False)

    def test_error_404(self):
        """404 from a kv post raises HTTPFetchingError"""
        self.fetcher.response = HTTPResponse(
            "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n')
        self.failUnlessRaises(
            fetchers.HTTPFetchingError,
            self.consumer._makeKVPost,
            Message.fromPostArgs({'mode':'associate'}),
            "http://server_url")

    def test_error_exception_unwrapped(self):
        """Ensure that exceptions are bubbled through from fetchers
        when making associations
        """
        self.fetcher = ExceptionRaisingMockFetcher()
        fetchers.setDefaultFetcher(self.fetcher, wrap_exceptions=False)
        self.failUnlessRaises(self.fetcher.MyException,
                              self.consumer._makeKVPost,
                              Message.fromPostArgs({'mode':'associate'}),
                              "http://server_url")

        # With wrapping disabled, the fetcher's own exception also
        # propagates out of association lookup and check_authentication.
        e = OpenIDServiceEndpoint()
        e.server_url = 'some://url'
        self.failUnlessRaises(self.fetcher.MyException,
                              self.consumer._getAssociation, e)

        self.failUnlessRaises(self.fetcher.MyException,
                              self.consumer._checkAuth,
                              Message.fromPostArgs({'openid.signed':''}),
                              'some://url')

    def test_error_exception_wrapped(self):
        """Ensure that openid.fetchers.HTTPFetchingError is caught by
        the association creation stuff.

        _makeKVPost lets HTTPFetchingError through, while _getAssociation
        and _checkAuth return a false value.
        """
        self.fetcher = ExceptionRaisingMockFetcher()
        # This will wrap exceptions!
        fetchers.setDefaultFetcher(self.fetcher)
        self.failUnlessRaises(fetchers.HTTPFetchingError,
                              self.consumer._makeKVPost,
                              Message.fromOpenIDArgs({'mode':'associate'}),
                              "http://server_url")

        # exception fetching returns no association
        e = OpenIDServiceEndpoint()
        e.server_url = 'some://url'
        self.failUnless(self.consumer._getAssociation(e) is None)

        msg = Message.fromPostArgs({'openid.signed':''})
        self.failIf(self.consumer._checkAuth(msg, 'some://url'))
 
1259
 
 
1260
 
 
1261
class TestSuccessResponse(unittest.TestCase):
    """Tests for SuccessResponse's extension and return_to accessors."""

    def setUp(self):
        self.endpoint = OpenIDServiceEndpoint()
        self.endpoint.claimed_id = 'identity_url'

    def test_extensionResponse(self):
        openid_args = {
            'ns.sreg': 'urn:sreg',
            'ns.unittest': 'urn:unittest',
            'unittest.one': '1',
            'unittest.two': '2',
            'sreg.nickname': 'j3h',
            'return_to': 'return_to',
            }
        resp = mkSuccess(self.endpoint, openid_args)
        self.failUnlessEqual(resp.extensionResponse('urn:unittest', False),
                             {'one': '1', 'two': '2'})
        self.failUnlessEqual(resp.extensionResponse('urn:sreg', False),
                             {'nickname': 'j3h'})

    def test_extensionResponseSigned(self):
        args = {
            'ns.sreg': 'urn:sreg',
            'ns.unittest': 'urn:unittest',
            'unittest.one': '1',
            'unittest.two': '2',
            'sreg.nickname': 'j3h',
            'sreg.dob': 'yesterday',
            'return_to': 'return_to',
            'signed': 'sreg.nickname,unittest.one,sreg.dob',
            }

        # Build the response by hand: mkSuccess would mark every field as
        # signed, but here only the fields named in 'signed' are.
        signed_list = ['openid.' + name for name in args['signed'].split(',')]
        resp = SuccessResponse(self.endpoint,
                               Message.fromOpenIDArgs(args),
                               signed_list)

        # Every sreg field is signed, so all of them come back.
        self.failUnlessEqual(resp.extensionResponse('urn:sreg', True),
                             {'nickname': 'j3h', 'dob': 'yesterday'})

        # unittest.two is unsigned, so the signed-only request yields None.
        self.failUnlessEqual(resp.extensionResponse('urn:unittest', True),
                             None)

    def test_noReturnTo(self):
        resp = mkSuccess(self.endpoint, {})
        self.failUnless(resp.getReturnTo() is None)

    def test_returnTo(self):
        resp = mkSuccess(self.endpoint, {'return_to': 'return_to'})
        self.failUnlessEqual(resp.getReturnTo(), 'return_to')
 
1317
 
 
1318
class StubConsumer(object):
    """Stand-in for GenericConsumer, used by ConsumerTest.

    begin() builds an AuthRequest around a placeholder association and
    remembers the service.  complete() asserts that it is given that same
    endpoint and returns whatever the test stored in ``response``.
    """

    def __init__(self):
        self.assoc = object()   # opaque placeholder association
        self.response = None
        self.endpoint = None

    def begin(self, service):
        request = AuthRequest(service, self.assoc)
        self.endpoint = service
        return request

    def complete(self, message, endpoint, return_to=None):
        assert endpoint is self.endpoint
        return self.response
 
1332
 
 
1333
class ConsumerTest(unittest.TestCase):
 
1334
    """Tests for high-level consumer.Consumer functions.
 
1335
 
 
1336
    Its GenericConsumer component is stubbed out with StubConsumer.
 
1337
    """
 
1338
    def setUp(self):
 
1339
        self.endpoint = OpenIDServiceEndpoint()
 
1340
        self.endpoint.claimed_id = self.identity_url = 'http://identity.url/'
 
1341
        self.store = None
 
1342
        self.session = {}
 
1343
        self.consumer = Consumer(self.session, self.store)
 
1344
        self.consumer.consumer = StubConsumer()
 
1345
        self.discovery = Discovery(self.session,
 
1346
                                   self.identity_url,
 
1347
                                   self.consumer.session_key_prefix)
 
1348
 
 
1349
    def test_setAssociationPreference(self):
 
1350
        self.consumer.setAssociationPreference([])
 
1351
        self.failUnless(isinstance(self.consumer.consumer.negotiator,
 
1352
                                   association.SessionNegotiator))
 
1353
        self.failUnlessEqual([],
 
1354
                             self.consumer.consumer.negotiator.allowed_types)
 
1355
        self.consumer.setAssociationPreference([('FOO', 'BAR')])
 
1356
        self.failUnlessEqual([('FOO', 'BAR')],
 
1357
                             self.consumer.consumer.negotiator.allowed_types)
 
1358
 
 
1359
    def withDummyDiscovery(self, callable, dummy_getNextService):
 
1360
        class DummyDisco(object):
 
1361
            def __init__(self, *ignored):
 
1362
                pass
 
1363
 
 
1364
            getNextService = dummy_getNextService
 
1365
 
 
1366
        import openid.consumer.consumer
 
1367
        old_discovery = openid.consumer.consumer.Discovery
 
1368
        try:
 
1369
            openid.consumer.consumer.Discovery = DummyDisco
 
1370
            callable()
 
1371
        finally:
 
1372
            openid.consumer.consumer.Discovery = old_discovery
 
1373
 
 
1374
    def test_beginHTTPError(self):
 
1375
        """Make sure that the discovery HTTP failure case behaves properly
 
1376
        """
 
1377
        def getNextService(self, ignored):
 
1378
            raise HTTPFetchingError("Unit test")
 
1379
 
 
1380
        def test():
 
1381
            try:
 
1382
                self.consumer.begin('unused in this test')
 
1383
            except DiscoveryFailure, why:
 
1384
                self.failUnless(why[0].startswith('Error fetching'))
 
1385
                self.failIf(why[0].find('Unit test') == -1)
 
1386
            else:
 
1387
                self.fail('Expected DiscoveryFailure')
 
1388
 
 
1389
        self.withDummyDiscovery(test, getNextService)
 
1390
 
 
1391
    def test_beginNoServices(self):
 
1392
        def getNextService(self, ignored):
 
1393
            return None
 
1394
 
 
1395
        url = 'http://a.user.url/'
 
1396
        def test():
 
1397
            try:
 
1398
                self.consumer.begin(url)
 
1399
            except DiscoveryFailure, why:
 
1400
                self.failUnless(why[0].startswith('No usable OpenID'))
 
1401
                self.failIf(why[0].find(url) == -1)
 
1402
            else:
 
1403
                self.fail('Expected DiscoveryFailure')
 
1404
 
 
1405
        self.withDummyDiscovery(test, getNextService)
 
1406
 
 
1407
 
 
1408
    def test_beginWithoutDiscovery(self):
 
1409
        # Does this really test anything non-trivial?
 
1410
        result = self.consumer.beginWithoutDiscovery(self.endpoint)
 
1411
 
 
1412
        # The result is an auth request
 
1413
        self.failUnless(isinstance(result, AuthRequest))
 
1414
 
 
1415
        # Side-effect of calling beginWithoutDiscovery is setting the
 
1416
        # session value to the endpoint attribute of the result
 
1417
        self.failUnless(self.session[self.consumer._token_key] is result.endpoint)
 
1418
 
 
1419
        # The endpoint that we passed in is the endpoint on the auth_request
 
1420
        self.failUnless(result.endpoint is self.endpoint)
 
1421
 
 
1422
    def test_completeEmptySession(self):
 
1423
        response = self.consumer.complete({})
 
1424
        self.failUnlessEqual(response.status, FAILURE)
 
1425
        self.failUnless(response.identity_url is None)
 
1426
 
 
1427
    def _doResp(self, auth_req, exp_resp):
 
1428
        """complete a transaction, using the expected response from
 
1429
        the generic consumer."""
 
1430
        # response is an attribute of StubConsumer, returned by
 
1431
        # StubConsumer.complete.
 
1432
        self.consumer.consumer.response = exp_resp
 
1433
 
 
1434
        # endpoint is stored in the session
 
1435
        self.failUnless(self.session)
 
1436
        resp = self.consumer.complete({})
 
1437
 
 
1438
        # All responses should have the same identity URL, and the
 
1439
        # session should be cleaned out
 
1440
        self.failUnless(resp.identity_url is self.identity_url)
 
1441
        self.failIf(self.consumer._token_key in self.session)
 
1442
 
 
1443
        # Expected status response
 
1444
        self.failUnlessEqual(resp.status, exp_resp.status)
 
1445
 
 
1446
        return resp
 
1447
 
 
1448
    def _doRespNoDisco(self, exp_resp):
 
1449
        """Set up a transaction without discovery"""
 
1450
        auth_req = self.consumer.beginWithoutDiscovery(self.endpoint)
 
1451
        resp = self._doResp(auth_req, exp_resp)
 
1452
        # There should be nothing left in the session once we have completed.
 
1453
        self.failIf(self.session)
 
1454
        return resp
 
1455
 
 
1456
    def test_noDiscoCompleteSuccessWithToken(self):
 
1457
        self._doRespNoDisco(mkSuccess(self.endpoint, {}))
 
1458
 
 
1459
    def test_noDiscoCompleteCancelWithToken(self):
 
1460
        self._doRespNoDisco(CancelResponse(self.endpoint))
 
1461
 
 
1462
    def test_noDiscoCompleteFailure(self):
 
1463
        msg = 'failed!'
 
1464
        resp = self._doRespNoDisco(FailureResponse(self.endpoint, msg))
 
1465
        self.failUnless(resp.message is msg)
 
1466
 
 
1467
    def test_noDiscoCompleteSetupNeeded(self):
 
1468
        setup_url = 'http://setup.url/'
 
1469
        resp = self._doRespNoDisco(
 
1470
            SetupNeededResponse(self.endpoint, setup_url))
 
1471
        self.failUnless(resp.setup_url is setup_url)
 
1472
 
 
1473
    # To test that discovery is cleaned up, we need to initialize a
 
1474
    # Yadis manager, and have it put its values in the session.
 
1475
    def _doRespDisco(self, is_clean, exp_resp):
 
1476
        """Set up and execute a transaction, with discovery"""
 
1477
        self.discovery.createManager([self.endpoint], self.identity_url)
 
1478
        auth_req = self.consumer.begin(self.identity_url)
 
1479
        resp = self._doResp(auth_req, exp_resp)
 
1480
 
 
1481
        manager = self.discovery.getManager()
 
1482
        if is_clean:
 
1483
            self.failUnless(self.discovery.getManager() is None, manager)
 
1484
        else:
 
1485
            self.failIf(self.discovery.getManager() is None, manager)
 
1486
 
 
1487
        return resp
 
1488
 
 
1489
    # Cancel and success DO clean up the discovery process
 
1490
    def test_completeSuccess(self):
 
1491
        self._doRespDisco(True, mkSuccess(self.endpoint, {}))
 
1492
 
 
1493
    def test_completeCancel(self):
 
1494
        self._doRespDisco(True, CancelResponse(self.endpoint))
 
1495
 
 
1496
    # Failure and setup_needed don't clean up the discovery process
 
1497
    def test_completeFailure(self):
 
1498
        msg = 'failed!'
 
1499
        resp = self._doRespDisco(False, FailureResponse(self.endpoint, msg))
 
1500
        self.failUnless(resp.message is msg)
 
1501
 
 
1502
    def test_completeSetupNeeded(self):
 
1503
        setup_url = 'http://setup.url/'
 
1504
        resp = self._doRespDisco(
 
1505
            False,
 
1506
            SetupNeededResponse(self.endpoint, setup_url))
 
1507
        self.failUnless(resp.setup_url is setup_url)
 
1508
 
 
1509
    def test_begin(self):
 
1510
        self.discovery.createManager([self.endpoint], self.identity_url)
 
1511
        # Should not raise an exception
 
1512
        auth_req = self.consumer.begin(self.identity_url)
 
1513
        self.failUnless(isinstance(auth_req, AuthRequest))
 
1514
        self.failUnless(auth_req.endpoint is self.endpoint)
 
1515
        self.failUnless(auth_req.endpoint is self.consumer.consumer.endpoint)
 
1516
        self.failUnless(auth_req.assoc is self.consumer.consumer.assoc)
 
1517
 
 
1518
 
 
1519
 
 
1520
class IDPDrivenTest(unittest.TestCase):
    """IdP-driven flow: the consumer starts from an endpoint that only knows
    the server URL, and the identity arrives in the positive assertion."""

    def setUp(self):
        self.store = GoodAssocStore()
        self.consumer = GenericConsumer(self.store)
        # Only the server URL is known up front; no claimed identifier.
        self.endpoint = OpenIDServiceEndpoint()
        self.endpoint.server_url = "http://idp.unittest/"

    def test_idpDrivenBegin(self):
        # Testing here that the token-handling doesn't explode...
        self.consumer.begin(self.endpoint)

    def test_idpDrivenComplete(self):
        """An assertion for an identity the starting endpoint did not know
        succeeds once discovery verification vouches for it."""
        identifier = '=directed_identifier'
        message = Message.fromPostArgs({
            'openid.identity': identifier,
            'openid.return_to': 'x',
            'openid.assoc_handle': 'z',
            'openid.signed': 'identity,return_to',
            'openid.sig': GOODSIG,
            })

        # The endpoint the stubbed verification step reports as discovered.
        discovered_endpoint = OpenIDServiceEndpoint()
        discovered_endpoint.claimed_id = identifier
        discovered_endpoint.server_url = self.endpoint.server_url
        discovered_endpoint.local_id = identifier
        iverified = []

        # First parameter renamed so it no longer shadows `identifier`.
        def verifyDiscoveryResults(resp_msg, endpoint):
            # _doIdRes must verify against the endpoint it was given.
            self.failUnless(endpoint is self.endpoint)
            iverified.append(discovered_endpoint)
            return discovered_endpoint
        self.consumer._verifyDiscoveryResults = verifyDiscoveryResults
        # Stub nonce checking to always pass; it is not under test here.
        self.consumer._idResCheckNonce = lambda *args: True
        response = self.consumer._doIdRes(message, self.endpoint)

        self.failUnlessSuccess(response)
        self.failUnlessEqual(response.identity_url, identifier)

        # Discovery verification was invoked exactly once.
        self.failUnlessEqual(iverified, [discovered_endpoint])

    def test_idpDrivenCompleteFraud(self):
        """When discovery cannot confirm the asserted identifier, the
        DiscoveryFailure raised by verification propagates out of
        _doIdRes."""
        message = Message.fromPostArgs({
            'openid.identity': '=directed_identifier',
            'openid.return_to': 'x',
            'openid.assoc_handle': 'z',
            'openid.signed': 'identity,return_to',
            'openid.sig': GOODSIG,
            })

        def verifyDiscoveryResults(resp_msg, endpoint):
            raise DiscoveryFailure("PHREAK!", None)
        self.consumer._verifyDiscoveryResults = verifyDiscoveryResults
        self.failUnlessRaises(DiscoveryFailure, self.consumer._doIdRes,
                              message, self.endpoint)

    def failUnlessSuccess(self, response):
        """Fail the test unless ``response.status`` is SUCCESS."""
        if response.status != SUCCESS:
            self.fail("Non-successful response: %s" % (response,))
 
1583
 
 
1584
 
 
1585
 
 
1586
class TestDiscoveryVerification(unittest.TestCase):
 
1587
    services = []
 
1588
 
 
1589
    def setUp(self):
 
1590
        self.store = GoodAssocStore()
 
1591
        self.consumer = GenericConsumer(self.store)
 
1592
 
 
1593
        self.consumer._discover = self.discoveryFunc
 
1594
 
 
1595
        self.identifier = "http://idp.unittest/1337"
 
1596
        self.server_url = "http://endpoint.unittest/"
 
1597
 
 
1598
        self.message = Message.fromPostArgs({
 
1599
            'openid.ns': OPENID2_NS,
 
1600
            'openid.identity': self.identifier,
 
1601
            'openid.claimed_id': self.identifier,
 
1602
            'openid.op_endpoint': self.server_url,
 
1603
            })
 
1604
 
 
1605
        self.endpoint = OpenIDServiceEndpoint()
 
1606
        self.endpoint.server_url = self.server_url
 
1607
 
 
1608
    def test_theGoodStuff(self):
 
1609
        endpoint = OpenIDServiceEndpoint()
 
1610
        endpoint.type_uris = [OPENID_2_0_TYPE]
 
1611
        endpoint.claimed_id = self.identifier
 
1612
        endpoint.server_url = self.server_url
 
1613
        endpoint.local_id = self.identifier
 
1614
        self.services = [endpoint]
 
1615
        r = self.consumer._verifyDiscoveryResults(self.message, endpoint)
 
1616
 
 
1617
        self.failUnlessEqual(r, endpoint)
 
1618
 
 
1619
 
 
1620
    def test_otherServer(self):
 
1621
        # a set of things without the stuff
 
1622
        endpoint = OpenIDServiceEndpoint()
 
1623
        endpoint.type_uris = [OPENID_2_0_TYPE]
 
1624
        endpoint.claimed_id = self.identifier
 
1625
        endpoint.server_url = "http://the-MOON.unittest/"
 
1626
        endpoint.local_id = self.identifier
 
1627
        self.services = [endpoint]
 
1628
        try:
 
1629
            r = self.consumer._verifyDiscoveryResults(self.message, endpoint)
 
1630
        except ProtocolError, e:
 
1631
            # Should we make more ProtocolError subclasses?
 
1632
            self.failUnless('OP Endpoint mismatch' in str(e), e)
 
1633
        else:
 
1634
            self.fail("expected ProtocolError, %r returned." % (r,))
 
1635
            
 
1636
 
 
1637
    def test_foreignDelegate(self):
 
1638
        # a set of things with the server stuff but other delegate
 
1639
        endpoint = OpenIDServiceEndpoint()
 
1640
        endpoint.type_uris = [OPENID_2_0_TYPE]
 
1641
        endpoint.claimed_id = self.identifier
 
1642
        endpoint.server_url = self.server_url
 
1643
        endpoint.local_id = "http://unittest/juan-carlos"
 
1644
        try:
 
1645
            r = self.consumer._verifyDiscoveryResults(self.message, endpoint)
 
1646
        except ProtocolError, e:
 
1647
            self.failUnless('local_id mismatch' in str(e), e)
 
1648
        else:
 
1649
            self.fail("expected ProtocolError, %r returned." % (r,))
 
1650
 
 
1651
 
 
1652
    def test_nothingDiscovered(self):
 
1653
        # a set of no things.
 
1654
        self.services = []
 
1655
        self.failUnlessRaises(DiscoveryFailure,
 
1656
                              self.consumer._verifyDiscoveryResults,
 
1657
                              self.message, self.endpoint)
 
1658
 
 
1659
 
 
1660
    def discoveryFunc(self, identifier):
 
1661
        return identifier, self.services
 
1662
 
 
1663
 
 
1664
class TestCreateAssociationRequest(unittest.TestCase):
    """Checks the session object and request message produced by
    GenericConsumer._createAssociateRequest for each session type."""

    def setUp(self):
        class DummyEndpoint(object):
            # Set to True to make the endpoint report compatibility mode.
            use_compatibility = False

            def compatibilityMode(self):
                return self.use_compatibility

        self.endpoint = DummyEndpoint()
        self.consumer = GenericConsumer(store=None)
        self.assoc_type = 'HMAC-SHA1'

    def _createRequest(self, session_type):
        """Return the (session, message) pair for ``session_type``."""
        return self.consumer._createAssociateRequest(
            self.endpoint, self.assoc_type, session_type)

    def test_noEncryptionSendsType(self):
        session_type = 'no-encryption'
        session, args = self._createRequest(session_type)

        self.failUnless(isinstance(session, PlainTextConsumerSession))
        expected = Message.fromOpenIDArgs({
            'ns': OPENID2_NS,
            'session_type': session_type,
            'mode': 'associate',
            'assoc_type': self.assoc_type,
            })

        self.failUnlessEqual(expected, args)

    def test_noEncryptionCompatibility(self):
        self.endpoint.use_compatibility = True
        session, args = self._createRequest('no-encryption')

        self.failUnless(isinstance(session, PlainTextConsumerSession))
        # In compatibility mode neither ns nor session_type is sent.
        expected = Message.fromOpenIDArgs({
            'mode': 'associate',
            'assoc_type': self.assoc_type,
            })
        self.failUnlessEqual(expected, args)

    def test_dhSHA1Compatibility(self):
        # Swap in a faster DH session on the consumer (via
        # setConsumerSession) since this test needs one.
        setConsumerSession(self.consumer)

        self.endpoint.use_compatibility = True
        session_type = 'DH-SHA1'
        session, args = self._createRequest(session_type)

        self.failUnless(isinstance(session, DiffieHellmanSHA1ConsumerSession))

        # The consumer public key is random base-64, so only check that it
        # is present, then drop it before comparing the rest.
        self.failUnless(args.getArg(OPENID1_NS, 'dh_consumer_public'))
        args.delArg(OPENID1_NS, 'dh_consumer_public')

        # Unlike the no-encryption compatibility case, session_type IS sent.
        expected = Message.fromOpenIDArgs({
            'mode': 'associate',
            'session_type': 'DH-SHA1',
            'assoc_type': self.assoc_type,
            'dh_modulus': 'BfvStQ==',
            'dh_gen': 'Ag==',
            })

        self.failUnlessEqual(expected, args)

    # XXX: test the other types
 
1731
 
 
1732
class TestDiffieHellmanResponseParameters(object):
    """Mixin that tests a DH consumer session's extractSecret().

    Concrete subclasses also inherit unittest.TestCase and set
    ``session_cls`` (the consumer session class under test) and
    ``message_namespace`` (namespace used to build the response Message).
    """
    session_cls = None
    message_namespace = None

    def setUp(self):
        # Pre-compute DH with small prime so tests run quickly.
        modulus, generator = 100389557, 2
        self.server_dh = DiffieHellman(modulus, generator)
        self.consumer_dh = DiffieHellman(modulus, generator)

        # base64(btwoc(g ^ xb mod p))
        self.dh_server_public = cryptutil.longToBase64(self.server_dh.public)

        self.secret = cryptutil.randomString(self.session_cls.secret_size)

        # The MAC key as the server side would encode it: the secret masked
        # via server_dh.xorSecret with the session's hash function, then
        # base64-encoded.
        masked = self.server_dh.xorSecret(self.consumer_dh.public,
                                          self.secret,
                                          self.session_cls.hash_func)
        self.enc_mac_key = oidutil.toBase64(masked)

        self.consumer_session = self.session_cls(self.consumer_dh)

        self.msg = Message(self.message_namespace)

    def _failUnlessExtractRaises(self, exc_class):
        """extractSecret(self.msg) must raise ``exc_class``."""
        self.failUnlessRaises(exc_class,
                              self.consumer_session.extractSecret, self.msg)

    def testExtractSecret(self):
        # With both fields present, the original secret comes back.
        self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public)
        self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key)

        self.failUnlessEqual(self.consumer_session.extractSecret(self.msg),
                             self.secret)

    def testAbsentServerPublic(self):
        self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key)

        self._failUnlessExtractRaises(KeyError)

    def testAbsentMacKey(self):
        self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public)

        self._failUnlessExtractRaises(KeyError)

    def testInvalidBase64Public(self):
        self.msg.setArg(OPENID_NS, 'dh_server_public', 'n o t b a s e 6 4.')
        self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key)

        self._failUnlessExtractRaises(ValueError)

    def testInvalidBase64MacKey(self):
        self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public)
        self.msg.setArg(OPENID_NS, 'enc_mac_key', 'n o t base 64')

        self._failUnlessExtractRaises(ValueError)
 
1783
 
 
1784
class TestOpenID1SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase):
    """DH-SHA1 secret extraction from an OPENID1_NS message."""
    message_namespace = OPENID1_NS
    session_cls = DiffieHellmanSHA1ConsumerSession
 
1787
 
 
1788
class TestOpenID2SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase):
    """DH-SHA1 secret extraction from an OPENID2_NS message."""
    message_namespace = OPENID2_NS
    session_cls = DiffieHellmanSHA1ConsumerSession
 
1791
 
 
1792
if not cryptutil.SHA256_AVAILABLE:
    # Without SHA-256 support the SHA256 session tests are not defined.
    warnings.warn("Not running SHA256 association session tests.")
else:
    class TestOpenID2SHA256(TestDiffieHellmanResponseParameters, unittest.TestCase):
        """DH-SHA256 secret extraction from an OPENID2_NS message."""
        message_namespace = OPENID2_NS
        session_cls = DiffieHellmanSHA256ConsumerSession
 
1798
 
 
1799
class TestNoStore(unittest.TestCase):
    """GenericConsumer behavior when it is constructed with store=None."""

    def setUp(self):
        self.consumer = GenericConsumer(None)

    def test_completeNoGetAssoc(self):
        """_getAssociation is never called when the store is None"""
        def notCalled(unused):
            self.fail('This method was unexpectedly called')

        endpoint = OpenIDServiceEndpoint()
        endpoint.claimed_id = 'identity_url'

        self.consumer._getAssociation = notCalled
        # If begin() calls _getAssociation, the stub fails the test;
        # returning normally is the pass condition. The returned auth
        # request itself is not inspected.
        self.consumer.begin(endpoint)
 
1814
 
 
1815
 
 
1816
 
 
1817
 
 
1818
class NonAnonymousAuthRequest(object):
    """Stand-in auth request that always refuses to become anonymous.

    TestConsumerAnonymous uses it to check that beginWithoutDiscovery turns
    the ValueError from setAnonymous into a ProtocolError.
    """
    endpoint = 'unused'

    def setAnonymous(self, unused):
        reason = 'Should trigger ProtocolError'
        raise ValueError(reason)
 
1823
 
 
1824
class TestConsumerAnonymous(unittest.TestCase):
    def test_beginWithoutDiscoveryAnonymousFail(self):
        """A ValueError raised while marking the auth request anonymous
        must come out of beginWithoutDiscovery as a ProtocolError.
        """
        consumer = Consumer({}, None)
        # Make the wrapped consumer hand back a request whose setAnonymous
        # always raises ValueError.
        consumer.consumer.begin = lambda unused: NonAnonymousAuthRequest()
        self.failUnlessRaises(ProtocolError,
                              consumer.beginWithoutDiscovery, None)
 
1837
 
 
1838
 
 
1839
class TestDiscoverAndVerify(unittest.TestCase):
    """Tests for GenericConsumer._discoverAndVerify with both discovery and
    per-endpoint verification stubbed out."""

    def setUp(self):
        self.consumer = GenericConsumer(None)
        # What the stubbed discovery returns: an (identifier, services) pair.
        self.discovery_result = None
        def dummyDiscover(unused_identifier):
            return self.discovery_result
        self.consumer._discover = dummyDiscover
        self.to_match = OpenIDServiceEndpoint()

    def failUnlessDiscoveryFailure(self):
        """_discoverAndVerify(self.to_match) must raise DiscoveryFailure."""
        self.failUnlessRaises(
            DiscoveryFailure,
            self.consumer._discoverAndVerify, self.to_match)

    def test_noServices(self):
        """Discovery returning no results results in a
        DiscoveryFailure exception"""
        self.discovery_result = (None, [])
        self.failUnlessDiscoveryFailure()

    def test_noMatches(self):
        """If no discovered endpoint matches the values from the
        assertion (every per-endpoint check raises ProtocolError), then
        _discoverAndVerify raises DiscoveryFailure.
        """
        self.discovery_result = (None, ['unused'])
        def raiseProtocolError(unused1, unused2):
            raise ProtocolError('unit test')
        self.consumer._verifyDiscoverySingle = raiseProtocolError
        self.failUnlessDiscoveryFailure()

    def test_matches(self):
        """If an endpoint matches, we return it
        """
        # Discovery returns a single "endpoint" object
        matching_endpoint = 'matching endpoint'
        self.discovery_result = (None, [matching_endpoint])

        # Make verifying discovery return True for this endpoint
        def returnTrue(unused1, unused2):
            return True
        self.consumer._verifyDiscoverySingle = returnTrue

        # Since _verifyDiscoverySingle returns True, we should get the
        # first endpoint that we passed in as a result.
        result = self.consumer._discoverAndVerify(self.to_match)
        self.failUnlessEqual(matching_endpoint, result)
 
1885
 
 
1886
from openid.extension import Extension
 
1887
class SillyExtension(Extension):
    """Minimal Extension fixture for testing AuthRequest.addExtension."""
    ns_alias = 'silly'
    ns_uri = 'http://silly.example.com/'

    def getExtensionArgs(self):
        """Return a fresh dict of this extension's arguments."""
        args = {}
        args['i_am'] = 'silly'
        return args
 
1893
 
 
1894
class TestAddExtension(unittest.TestCase):

    def test_SillyExtension(self):
        """Arguments from addExtension appear in the request message under
        the extension's namespace URI."""
        extension = SillyExtension()
        request = AuthRequest(OpenIDServiceEndpoint(), None)
        request.addExtension(extension)
        added = request.message.getArgs(extension.ns_uri)
        self.failUnlessEqual(extension.getExtensionArgs(), added)
 
1902
if __name__ == "__main__":
    # When executed as a script, run every test case in this module.
    unittest.main()