"""Tests for consumer handling of association responses

This duplicates some things that are covered by test_consumer, but
this works for now.
"""
import unittest

from openid import oidutil
from openid.consumer.consumer import GenericConsumer, \
     DiffieHellmanSHA1ConsumerSession, ProtocolError
from openid.consumer.discover import OpenIDServiceEndpoint, \
     OPENID_1_1_TYPE, OPENID_2_0_TYPE
from openid.message import Message, OPENID2_NS, OPENID_NS, no_default
from openid.server.server import DiffieHellmanSHA1ServerSession
from openid.store import memstore
from openid.test.test_consumer import CatchLogs
# Some values we can use for convenience (see mkAssocResponse).
# Every key that a test may ask mkAssocResponse for must be present here,
# including 'expires_in' and 'ns' (the missing-field tests request both).
association_response_values = {
    'expires_in': '1000',
    'assoc_handle': 'a handle',
    'assoc_type': 'a type',
    'session_type': 'a session type',
    # Including 'ns' makes the built message an OpenID 2 message.
    'ns': OPENID2_NS,
    }
def mkAssocResponse(*keys):
    """Return an association response Message holding only `keys`.

    Each requested key is paired with its canned value from
    `association_response_values`. Callers can build responses with
    chosen fields left out (for example to exercise missing-field
    handling) without caring what the values are.
    """
    selected = {}
    for name in keys:
        selected[name] = association_response_values[name]
    return Message.fromOpenIDArgs(selected)
35
class BaseAssocTest(CatchLogs, unittest.TestCase):
38
self.store = memstore.MemoryStore()
39
self.consumer = GenericConsumer(self.store)
40
self.endpoint = OpenIDServiceEndpoint()
42
def failUnlessProtocolError(self, str_prefix, func, *args, **kwargs):
44
result = func(*args, **kwargs)
45
except ProtocolError, e:
46
message = 'Expected prefix %r, got %r' % (str_prefix, e[0])
47
self.failUnless(e[0].startswith(str_prefix), message)
49
self.fail('Expected ProtocolError, got %r' % (result,))
def mkExtractAssocMissingTest(keys):
    """Factory function for creating test methods for generating
    missing field tests.

    Make a test that ensures that an association response that is
    missing required fields makes _extractAssociation raise KeyError.

    According to 'Association Session Response' subsection 'Common
    Response Parameters', the following fields are required for OpenID
    2.0:

     * ns
     * session_type
     * assoc_handle
     * assoc_type
     * expires_in

    If 'ns' is missing, it will fall back to OpenID 1 checking. In
    OpenID 1, everything except 'session_type' and 'ns' is required.

    `keys` lists the fields (taken from association_response_values)
    that the response DOES contain; every other field is left out. The
    return value is a method meant to be assigned in a TestCase body.
    """
    def test(self):
        msg = mkAssocResponse(*keys)

        self.failUnlessRaises(KeyError,
                              self.consumer._extractAssociation, msg, None)

    return test
class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest):
    """OpenID 2 association responses that lack a required field must
    make _extractAssociation raise (see mkExtractAssocMissingTest).

    Every response here keeps 'ns', so it is treated as OpenID 2. Each
    case drops one required field, except the first, which drops all
    of them.
    """

    test_noFields_openid2 = mkExtractAssocMissingTest(('ns',))

    test_missingExpires_openid2 = mkExtractAssocMissingTest(
        ('assoc_handle', 'assoc_type', 'session_type', 'ns'))

    test_missingHandle_openid2 = mkExtractAssocMissingTest(
        ('expires_in', 'assoc_type', 'session_type', 'ns'))

    test_missingAssocType_openid2 = mkExtractAssocMissingTest(
        ('expires_in', 'assoc_handle', 'session_type', 'ns'))

    test_missingSessionType_openid2 = mkExtractAssocMissingTest(
        ('expires_in', 'assoc_handle', 'assoc_type', 'ns'))
class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest):
    """Test for returning an error upon missing fields in association
    responses for OpenID 1.

    None of these responses contain 'ns', so they fall back to OpenID 1
    checking. There 'session_type' is optional, so it is never the
    field being left out.
    """

    test_noFields_openid1 = mkExtractAssocMissingTest([])

    test_missingExpires_openid1 = mkExtractAssocMissingTest(
        ['assoc_handle', 'assoc_type'])

    test_missingHandle_openid1 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_type'])

    test_missingAssocType_openid1 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_handle'])
# NOTE: the class name's spelling ("Assocation") is kept as-is; it is a
# different class from DummyAssociationSession further down this file.
class DummyAssocationSession(object):
    """Minimal stand-in for a consumer association session.

    Holds only the requested `session_type` and the
    `allowed_assoc_types` it accepts (empty by default).
    """

    def __init__(self, session_type, allowed_assoc_types=()):
        self.allowed_assoc_types = allowed_assoc_types
        self.session_type = session_type
class ExtractAssociationSessionTypeMismatch(BaseAssocTest):
    """_extractAssociation must reject a response whose session_type
    differs from the session type the consumer requested."""

    def mkTest(requested_session_type, response_session_type, openid1=False):
        """Build a test method that expects a 'Session type mismatch'
        ProtocolError.

        The response contains every canned field, with 'session_type'
        overwritten by `response_session_type`. The consumer-side session
        claims `requested_session_type`. When `openid1` is true, 'ns' is
        dropped so the response is checked as OpenID 1.
        """
        def test(self):
            assoc_session = DummyAssocationSession(requested_session_type)
            # list() gives a mutable copy even where dict.keys() returns
            # a view, so removing 'ns' below never touches the shared dict.
            keys = list(association_response_values.keys())
            if openid1:
                keys.remove('ns')
            msg = mkAssocResponse(*keys)
            msg.setArg(OPENID_NS, 'session_type', response_session_type)
            self.failUnlessProtocolError(
                'Session type mismatch',
                self.consumer._extractAssociation, msg, assoc_session)

        return test

    test_typeMismatchNoEncBlank_openid2 = mkTest(
        requested_session_type='no-encryption',
        response_session_type='',
        )

    test_typeMismatchDHSHA1NoEnc_openid2 = mkTest(
        requested_session_type='DH-SHA1',
        response_session_type='no-encryption',
        )

    test_typeMismatchDHSHA256NoEnc_openid2 = mkTest(
        requested_session_type='DH-SHA256',
        response_session_type='no-encryption',
        )

    test_typeMismatchNoEncDHSHA1_openid2 = mkTest(
        requested_session_type='no-encryption',
        response_session_type='DH-SHA1',
        )

    # NOTE(review): the next two names say "NoEnc", but the responses
    # carry DH-SHA256 / DH-SHA1. The names are kept so test IDs stay
    # stable; confirm which was intended.
    test_typeMismatchDHSHA1NoEnc_openid1 = mkTest(
        requested_session_type='DH-SHA1',
        response_session_type='DH-SHA256',
        openid1=True,
        )

    test_typeMismatchDHSHA256NoEnc_openid1 = mkTest(
        requested_session_type='DH-SHA256',
        response_session_type='DH-SHA1',
        openid1=True,
        )

    test_typeMismatchNoEncDHSHA1_openid1 = mkTest(
        requested_session_type='no-encryption',
        response_session_type='DH-SHA1',
        openid1=True,
        )
class TestOpenID1AssociationResponseSessionType(BaseAssocTest):
    """Check how GenericConsumer._getOpenID1SessionType interprets the
    'session_type' field of an OpenID 1 association response."""

    def mkTest(expected_session_type, session_type_value):
        """Return a test method that will check what session type will
        be used if the OpenID 1 response to an associate call sets the
        'session_type' field to `session_type_value`.

        The generated test also requires that nothing was logged.
        """
        def test(self):
            self._doTest(expected_session_type, session_type_value)
            self.failUnlessEqual(0, len(self.messages))

        return test

    def _doTest(self, expected_session_type, session_type_value):
        """Assert that _getOpenID1SessionType maps `session_type_value`
        to `expected_session_type`."""
        # Create a Message with just 'session_type' in it, since
        # that's all this function will use. 'session_type' may be
        # absent if it's set to None.
        args = {}
        if session_type_value is not None:
            args['session_type'] = session_type_value
        message = Message.fromOpenIDArgs(args)
        # No 'ns' key, so the message must be treated as OpenID 1.
        self.failUnless(message.isOpenID1())

        actual_session_type = self.consumer._getOpenID1SessionType(message)
        error_message = ('Returned session type parameter %r was expected '
                         'to yield session type %r, but yielded %r' %
                         (session_type_value, expected_session_type,
                          actual_session_type))
        self.failUnlessEqual(
            expected_session_type, actual_session_type, error_message)

    test_none = mkTest(
        session_type_value=None,
        expected_session_type='no-encryption',
        )

    test_empty = mkTest(
        session_type_value='',
        expected_session_type='no-encryption',
        )

    # This one's different because it expects log messages
    def test_explicitNoEncryption(self):
        self._doTest(
            session_type_value='no-encryption',
            expected_session_type='no-encryption',
            )
        self.failUnlessEqual(1, len(self.messages))
        self.failUnless(self.messages[0].startswith(
            'WARNING: OpenID server sent "no-encryption"'))

    test_dhSHA1 = mkTest(
        session_type_value='DH-SHA1',
        expected_session_type='DH-SHA1',
        )

    # DH-SHA256 is not a valid session type for OpenID1, but this
    # function does not test that. This is mostly just to make sure
    # that it will pass-through stuff that is not explicitly handled,
    # so it will get handled the same way as it is handled for OpenID
    # 2.
    test_dhSHA256 = mkTest(
        session_type_value='DH-SHA256',
        expected_session_type='DH-SHA256',
        )
class DummyAssociationSession(object):
    """Fake association session for the invalid-field tests.

    extractSecret records that it was called and returns the fixed
    `secret`. Tests can then check that _extractAssociation asked the
    session for the secret and stored what it got back.
    """

    # Value handed back by extractSecret.
    secret = "shh! don't tell!"

    # Set to True on the instance by extractSecret.
    extract_secret_called = False

    # Tests assign these on the instance before use.
    session_type = None
    allowed_assoc_types = None

    def extractSecret(self, message):
        """Record the call and return `self.secret`; `message` is
        ignored."""
        self.extract_secret_called = True
        return self.secret
class TestInvalidFields(BaseAssocTest):
    """Run _extractAssociation on a fully valid response and on
    responses where a single field is invalid."""

    def setUp(self):
        BaseAssocTest.setUp(self)
        self.session_type = 'testing-session'

        # This must be something that works for Association.fromExpiresIn
        self.assoc_type = 'HMAC-SHA1'

        self.assoc_handle = 'testing-assoc-handle'

        # These arguments should all be valid
        self.assoc_response = Message.fromOpenIDArgs({
            'expires_in': '1000',
            'assoc_handle': self.assoc_handle,
            'assoc_type': self.assoc_type,
            'session_type': self.session_type,
            'ns': OPENID2_NS,
            })

        self.assoc_session = DummyAssociationSession()

        # Make the session match the response's session type
        self.assoc_session.session_type = self.session_type
        self.assoc_session.allowed_assoc_types = [self.assoc_type]

    def test_worksWithGoodFields(self):
        """Handle a full successful association response"""
        assoc = self.consumer._extractAssociation(
            self.assoc_response, self.assoc_session)
        self.failUnless(self.assoc_session.extract_secret_called)
        self.failUnlessEqual(self.assoc_session.secret, assoc.secret)
        self.failUnlessEqual(1000, assoc.lifetime)
        self.failUnlessEqual(self.assoc_handle, assoc.handle)
        self.failUnlessEqual(self.assoc_type, assoc.assoc_type)

    def test_badAssocType(self):
        # Make sure that the assoc type in the response is not valid
        # for the given session.
        self.assoc_session.allowed_assoc_types = []
        self.failUnlessProtocolError('Unsupported assoc_type for session',
                                     self.consumer._extractAssociation,
                                     self.assoc_response, self.assoc_session)

    def test_badExpiresIn(self):
        # Invalid value for expires_in should cause failure
        self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever')
        self.failUnlessProtocolError('Invalid expires_in',
                                     self.consumer._extractAssociation,
                                     self.assoc_response, self.assoc_session)
# XXX: This is what causes most of the imports in this file. It is
# sort of a unit test and sort of a functional test. I'm not terribly
# fond of it.
class TestExtractAssociationDiffieHellman(BaseAssocTest):
    """Round-trip a DH-SHA1 association through a real server session
    and check what the consumer extracts from the response."""

    # Secret the server session encodes into the response. It is 20
    # bytes, presumably to match the SHA1 digest length used by DH-SHA1
    # and HMAC-SHA1 (TODO confirm).
    secret = 'x' * 20

    def _setUpDH(self):
        """Create a DH-SHA1 associate request and a matching server
        response.

        Returns a (consumer session, response Message) pair.
        """
        sess, message = self.consumer._createAssociateRequest(
            self.endpoint, 'HMAC-SHA1', 'DH-SHA1')

        # XXX: this is testing _createAssociateRequest
        self.failUnlessEqual(self.endpoint.compatibilityMode(),
                             message.isOpenID1())

        server_sess = DiffieHellmanSHA1ServerSession.fromMessage(message)
        server_resp = server_sess.answer(self.secret)
        server_resp['assoc_type'] = 'HMAC-SHA1'
        server_resp['assoc_handle'] = 'handle'
        server_resp['expires_in'] = '1000'
        server_resp['session_type'] = 'DH-SHA1'
        return sess, Message.fromOpenIDArgs(server_resp)

    def test_success(self):
        sess, server_resp = self._setUpDH()
        ret = self.consumer._extractAssociation(server_resp, sess)
        self.failIf(ret is None)
        self.failUnlessEqual(ret.assoc_type, 'HMAC-SHA1')
        self.failUnlessEqual(ret.secret, self.secret)
        self.failUnlessEqual(ret.handle, 'handle')
        self.failUnlessEqual(ret.lifetime, 1000)

    def test_openid2success(self):
        # Use openid 2 type in endpoint so _setUpDH checks
        # compatibility mode state properly
        self.endpoint.type_uris = [OPENID_2_0_TYPE, OPENID_1_1_TYPE]
        self.test_success()

    def test_badDHValues(self):
        sess, server_resp = self._setUpDH()
        server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00')
        self.failUnlessProtocolError('Malformed response for',
                                     self.consumer._extractAssociation,
                                     server_resp, sess)