~gary/python-openid/python-openid-2.2.1-patched

« back to all changes in this revision

Viewing changes to openid/test/test_association_response.py

  • Committer: Launchpad Patch Queue Manager
  • Date: 2007-11-30 02:46:28 UTC
  • mfrom: (1.1.1 pyopenid-2.0)
  • Revision ID: launchpad@pqm.canonical.com-20071130024628-qktwsew3383iawmq
[rs=SteveA] upgrade to python-openid-2.0.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
"""Tests for consumer handling of association responses
 
2
 
 
3
This duplicates some things that are covered by test_consumer, but
 
4
this works for now.
 
5
"""
 
6
from openid import oidutil
 
7
from openid.test.test_consumer import CatchLogs
 
8
from openid.message import Message, OPENID2_NS, OPENID_NS, no_default
 
9
from openid.server.server import DiffieHellmanSHA1ServerSession
 
10
from openid.consumer.consumer import GenericConsumer, \
 
11
     DiffieHellmanSHA1ConsumerSession, ProtocolError
 
12
from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_1_1_TYPE, OPENID_2_0_TYPE
 
13
from openid.store import memstore
 
14
import unittest
 
15
 
 
16
# Some values we can use for convenience (see mkAssocResponse)
 
17
association_response_values = {
 
18
    'expires_in': '1000',
 
19
    'assoc_handle':'a handle',
 
20
    'assoc_type':'a type',
 
21
    'session_type':'a session type',
 
22
    'ns':OPENID2_NS,
 
23
    }
 
24
 
 
25
def mkAssocResponse(*keys):
 
26
    """Build an association response message that contains the
 
27
    specified subset of keys. The values come from
 
28
    `association_response_values`.
 
29
 
 
30
    This is useful for testing for missing keys and other times that
 
31
    we don't care what the values are."""
 
32
    args = dict([(key, association_response_values[key]) for key in keys])
 
33
    return Message.fromOpenIDArgs(args)
 
34
 
 
35
class BaseAssocTest(CatchLogs, unittest.TestCase):
 
36
    def setUp(self):
 
37
        CatchLogs.setUp(self)
 
38
        self.store = memstore.MemoryStore()
 
39
        self.consumer = GenericConsumer(self.store)
 
40
        self.endpoint = OpenIDServiceEndpoint()
 
41
 
 
42
    def failUnlessProtocolError(self, str_prefix, func, *args, **kwargs):
 
43
        try:
 
44
            result = func(*args, **kwargs)
 
45
        except ProtocolError, e:
 
46
            message = 'Expected prefix %r, got %r' % (str_prefix, e[0])
 
47
            self.failUnless(e[0].startswith(str_prefix), message)
 
48
        else:
 
49
            self.fail('Expected ProtocolError, got %r' % (result,))
 
50
 
 
51
def mkExtractAssocMissingTest(keys):
 
52
    """Factory function for creating test methods for generating
 
53
    missing field tests.
 
54
 
 
55
    Make a test that ensures that an association response that
 
56
    is missing required fields will short-circuit return None.
 
57
 
 
58
    According to 'Association Session Response' subsection 'Common
 
59
    Response Parameters', the following fields are required for OpenID
 
60
    2.0:
 
61
 
 
62
     * ns
 
63
     * session_type
 
64
     * assoc_handle
 
65
     * assoc_type
 
66
     * expires_in
 
67
 
 
68
    If 'ns' is missing, it will fall back to OpenID 1 checking. In
 
69
    OpenID 1, everything except 'session_type' and 'ns' are required.
 
70
    """
 
71
 
 
72
    def test(self):
 
73
        msg = mkAssocResponse(*keys)
 
74
 
 
75
        self.failUnlessRaises(KeyError,
 
76
                              self.consumer._extractAssociation, msg, None)
 
77
 
 
78
    return test
 
79
 
 
80
class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest):
 
81
    """Test for returning an error upon missing fields in association
 
82
    responses for OpenID 2"""
 
83
 
 
84
    test_noFields_openid2 = mkExtractAssocMissingTest(['ns'])
 
85
 
 
86
    test_missingExpires_openid2 = mkExtractAssocMissingTest(
 
87
        ['assoc_handle', 'assoc_type', 'session_type', 'ns'])
 
88
 
 
89
    test_missingHandle_openid2 = mkExtractAssocMissingTest(
 
90
        ['expires_in', 'assoc_type', 'session_type', 'ns'])
 
91
 
 
92
    test_missingAssocType_openid2 = mkExtractAssocMissingTest(
 
93
        ['expires_in', 'assoc_handle', 'session_type', 'ns'])
 
94
 
 
95
    test_missingSessionType_openid2 = mkExtractAssocMissingTest(
 
96
        ['expires_in', 'assoc_handle', 'assoc_type', 'ns'])
 
97
 
 
98
class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest):
 
99
    """Test for returning an error upon missing fields in association
 
100
    responses for OpenID 2"""
 
101
 
 
102
    test_noFields_openid1 = mkExtractAssocMissingTest([])
 
103
 
 
104
    test_missingExpires_openid1 = mkExtractAssocMissingTest(
 
105
        ['assoc_handle', 'assoc_type'])
 
106
 
 
107
    test_missingHandle_openid1 = mkExtractAssocMissingTest(
 
108
        ['expires_in', 'assoc_type'])
 
109
 
 
110
    test_missingAssocType_openid1 = mkExtractAssocMissingTest(
 
111
        ['expires_in', 'assoc_handle'])
 
112
 
 
113
class DummyAssocationSession(object):
 
114
    def __init__(self, session_type, allowed_assoc_types=()):
 
115
        self.session_type = session_type
 
116
        self.allowed_assoc_types = allowed_assoc_types
 
117
 
 
118
class ExtractAssociationSessionTypeMismatch(BaseAssocTest):
 
119
    def mkTest(requested_session_type, response_session_type, openid1=False):
 
120
        def test(self):
 
121
            assoc_session = DummyAssocationSession(requested_session_type)
 
122
            keys = association_response_values.keys()
 
123
            if openid1:
 
124
                keys.remove('ns')
 
125
            msg = mkAssocResponse(*keys)
 
126
            msg.setArg(OPENID_NS, 'session_type', response_session_type)
 
127
            self.failUnlessProtocolError('Session type mismatch',
 
128
                self.consumer._extractAssociation, msg, assoc_session)
 
129
 
 
130
        return test
 
131
 
 
132
    test_typeMismatchNoEncBlank_openid2 = mkTest(
 
133
        requested_session_type='no-encryption',
 
134
        response_session_type='',
 
135
        )
 
136
 
 
137
    test_typeMismatchDHSHA1NoEnc_openid2 = mkTest(
 
138
        requested_session_type='DH-SHA1',
 
139
        response_session_type='no-encryption',
 
140
        )
 
141
 
 
142
    test_typeMismatchDHSHA256NoEnc_openid2 = mkTest(
 
143
        requested_session_type='DH-SHA256',
 
144
        response_session_type='no-encryption',
 
145
        )
 
146
 
 
147
    test_typeMismatchNoEncDHSHA1_openid2 = mkTest(
 
148
        requested_session_type='no-encryption',
 
149
        response_session_type='DH-SHA1',
 
150
        )
 
151
 
 
152
    test_typeMismatchDHSHA1NoEnc_openid1 = mkTest(
 
153
        requested_session_type='DH-SHA1',
 
154
        response_session_type='DH-SHA256',
 
155
        openid1=True,
 
156
        )
 
157
 
 
158
    test_typeMismatchDHSHA256NoEnc_openid1 = mkTest(
 
159
        requested_session_type='DH-SHA256',
 
160
        response_session_type='DH-SHA1',
 
161
        openid1=True,
 
162
        )
 
163
 
 
164
    test_typeMismatchNoEncDHSHA1_openid1 = mkTest(
 
165
        requested_session_type='no-encryption',
 
166
        response_session_type='DH-SHA1',
 
167
        openid1=True,
 
168
        )
 
169
 
 
170
 
 
171
class TestOpenID1AssociationResponseSessionType(BaseAssocTest):
 
172
    def mkTest(expected_session_type, session_type_value):
 
173
        """Return a test method that will check what session type will
 
174
        be used if the OpenID 1 response to an associate call sets the
 
175
        'session_type' field to `session_type_value`
 
176
        """
 
177
        def test(self):
 
178
            self._doTest(expected_session_type, session_type_value)
 
179
            self.failUnlessEqual(0, len(self.messages))
 
180
 
 
181
        return test
 
182
 
 
183
    def _doTest(self, expected_session_type, session_type_value):
 
184
        # Create a Message with just 'session_type' in it, since
 
185
        # that's all this function will use. 'session_type' may be
 
186
        # absent if it's set to None.
 
187
        args = {}
 
188
        if session_type_value is not None:
 
189
            args['session_type'] = session_type_value
 
190
        message = Message.fromOpenIDArgs(args)
 
191
        self.failUnless(message.isOpenID1())
 
192
 
 
193
        actual_session_type = self.consumer._getOpenID1SessionType(message)
 
194
        error_message = ('Returned sesion type parameter %r was expected '
 
195
                         'to yield session type %r, but yielded %r' %
 
196
                         (session_type_value, expected_session_type,
 
197
                          actual_session_type))
 
198
        self.failUnlessEqual(
 
199
            expected_session_type, actual_session_type, error_message)
 
200
 
 
201
    test_none = mkTest(
 
202
        session_type_value=None,
 
203
        expected_session_type='no-encryption',
 
204
        )
 
205
 
 
206
    test_empty = mkTest(
 
207
        session_type_value='',
 
208
        expected_session_type='no-encryption',
 
209
        )
 
210
 
 
211
    # This one's different because it expects log messages
 
212
    def test_explicitNoEncryption(self):
 
213
        self._doTest(
 
214
            session_type_value='no-encryption',
 
215
            expected_session_type='no-encryption',
 
216
            )
 
217
        self.failUnlessEqual(1, len(self.messages))
 
218
        self.failUnless(self.messages[0].startswith(
 
219
            'WARNING: OpenID server sent "no-encryption"'))
 
220
 
 
221
    test_dhSHA1 = mkTest(
 
222
        session_type_value='DH-SHA1',
 
223
        expected_session_type='DH-SHA1',
 
224
        )
 
225
 
 
226
    # DH-SHA256 is not a valid session type for OpenID1, but this
 
227
    # function does not test that. This is mostly just to make sure
 
228
    # that it will pass-through stuff that is not explicitly handled,
 
229
    # so it will get handled the same way as it is handled for OpenID
 
230
    # 2
 
231
    test_dhSHA256 = mkTest(
 
232
        session_type_value='DH-SHA256',
 
233
        expected_session_type='DH-SHA256',
 
234
        )
 
235
 
 
236
class DummyAssociationSession(object):
 
237
    secret = "shh! don't tell!"
 
238
    extract_secret_called = False
 
239
 
 
240
    session_type = None
 
241
 
 
242
    allowed_assoc_types = None
 
243
 
 
244
    def extractSecret(self, message):
 
245
        self.extract_secret_called = True
 
246
        return self.secret
 
247
 
 
248
class TestInvalidFields(BaseAssocTest):
 
249
    def setUp(self):
 
250
        BaseAssocTest.setUp(self)
 
251
        self.session_type = 'testing-session'
 
252
 
 
253
        # This must something that works for Association.fromExpiresIn
 
254
        self.assoc_type = 'HMAC-SHA1'
 
255
 
 
256
        self.assoc_handle = 'testing-assoc-handle'
 
257
 
 
258
        # These arguments should all be valid
 
259
        self.assoc_response = Message.fromOpenIDArgs({
 
260
            'expires_in': '1000',
 
261
            'assoc_handle':self.assoc_handle,
 
262
            'assoc_type':self.assoc_type,
 
263
            'session_type':self.session_type,
 
264
            'ns':OPENID2_NS,
 
265
            })
 
266
 
 
267
        self.assoc_session = DummyAssociationSession()
 
268
 
 
269
        # Make the session for the response's session type
 
270
        self.assoc_session.session_type = self.session_type
 
271
        self.assoc_session.allowed_assoc_types = [self.assoc_type]
 
272
 
 
273
    def test_worksWithGoodFields(self):
 
274
        """Handle a full successful association response"""
 
275
        assoc = self.consumer._extractAssociation(
 
276
            self.assoc_response, self.assoc_session)
 
277
        self.failUnless(self.assoc_session.extract_secret_called)
 
278
        self.failUnlessEqual(self.assoc_session.secret, assoc.secret)
 
279
        self.failUnlessEqual(1000, assoc.lifetime)
 
280
        self.failUnlessEqual(self.assoc_handle, assoc.handle)
 
281
        self.failUnlessEqual(self.assoc_type, assoc.assoc_type)
 
282
 
 
283
    def test_badAssocType(self):
 
284
        # Make sure that the assoc type in the response is not valid
 
285
        # for the given session.
 
286
        self.assoc_session.allowed_assoc_types = []
 
287
        self.failUnlessProtocolError('Unsupported assoc_type for session',
 
288
            self.consumer._extractAssociation,
 
289
            self.assoc_response, self.assoc_session)
 
290
 
 
291
    def test_badExpiresIn(self):
 
292
        # Invalid value for expires_in should cause failure
 
293
        self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever')
 
294
        self.failUnlessProtocolError('Invalid expires_in',
 
295
            self.consumer._extractAssociation,
 
296
            self.assoc_response, self.assoc_session)
 
297
 
 
298
 
 
299
# XXX: This is what causes most of the imports in this file. It is
 
300
# sort of a unit test and sort of a functional test. I'm not terribly
 
301
# fond of it.
 
302
class TestExtractAssociationDiffieHellman(BaseAssocTest):
 
303
    secret = 'x' * 20
 
304
 
 
305
    def _setUpDH(self):
 
306
        sess, message = self.consumer._createAssociateRequest(
 
307
            self.endpoint, 'HMAC-SHA1', 'DH-SHA1')
 
308
 
 
309
        # XXX: this is testing _createAssociateRequest
 
310
        self.failUnlessEqual(self.endpoint.compatibilityMode(),
 
311
                             message.isOpenID1())
 
312
 
 
313
        server_sess = DiffieHellmanSHA1ServerSession.fromMessage(message)
 
314
        server_resp = server_sess.answer(self.secret)
 
315
        server_resp['assoc_type'] = 'HMAC-SHA1'
 
316
        server_resp['assoc_handle'] = 'handle'
 
317
        server_resp['expires_in'] = '1000'
 
318
        server_resp['session_type'] = 'DH-SHA1'
 
319
        return sess, Message.fromOpenIDArgs(server_resp)
 
320
 
 
321
    def test_success(self):
 
322
        sess, server_resp = self._setUpDH()
 
323
        ret = self.consumer._extractAssociation(server_resp, sess)
 
324
        self.failIf(ret is None)
 
325
        self.failUnlessEqual(ret.assoc_type, 'HMAC-SHA1')
 
326
        self.failUnlessEqual(ret.secret, self.secret)
 
327
        self.failUnlessEqual(ret.handle, 'handle')
 
328
        self.failUnlessEqual(ret.lifetime, 1000)
 
329
 
 
330
    def test_openid2success(self):
 
331
        # Use openid 2 type in endpoint so _setUpDH checks
 
332
        # compatibility mode state properly
 
333
        self.endpoint.type_uris = [OPENID_2_0_TYPE, OPENID_1_1_TYPE]
 
334
        self.test_success()
 
335
 
 
336
    def test_badDHValues(self):
 
337
        sess, server_resp = self._setUpDH()
 
338
        server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00')
 
339
        self.failUnlessProtocolError('Malformed response for',
 
340
            self.consumer._extractAssociation, server_resp, sess)