~jtaylor/ubuntu/maverick/python-django-piston/fix-884910

« back to all changes in this revision

Viewing changes to piston/oauth.py

  • Committer: Bazaar Package Importer
  • Author(s): Michael Ziegler
  • Date: 2010-02-22 08:43:21 UTC
  • Revision ID: james.westby@ubuntu.com-20100222084321-4w7ah3ue1j0tg480
Tags: upstream-0.2.2
ImportĀ upstreamĀ versionĀ 0.2.2

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
import cgi
 
2
import urllib
 
3
import time
 
4
import random
 
5
import urlparse
 
6
import hmac
 
7
import base64
 
8
 
 
9
VERSION = '1.0' # Hi Blaine!
 
10
HTTP_METHOD = 'GET'
 
11
SIGNATURE_METHOD = 'PLAINTEXT'
 
12
 
 
13
# Generic exception class
 
14
class OAuthError(RuntimeError):
 
15
    def get_message(self): 
 
16
        return self._message
 
17
 
 
18
    def set_message(self, message): 
 
19
        self._message = message
 
20
 
 
21
    message = property(get_message, set_message)
 
22
 
 
23
    def __init__(self, message='OAuth error occured.'):
 
24
        self.message = message
 
25
 
 
26
# optional WWW-Authenticate header (401 error)
 
27
def build_authenticate_header(realm=''):
 
28
    return { 'WWW-Authenticate': 'OAuth realm="%s"' % realm }
 
29
 
 
30
# url escape
 
31
def escape(s):
 
32
    # escape '/' too
 
33
    return urllib.quote(s, safe='~')
 
34
 
 
35
# util function: current timestamp
 
36
# seconds since epoch (UTC)
 
37
def generate_timestamp():
 
38
    return int(time.time())
 
39
 
 
40
# util function: nonce
 
41
# pseudorandom number
 
42
def generate_nonce(length=8):
 
43
    return ''.join(str(random.randint(0, 9)) for i in range(length))
 
44
 
 
45
# OAuthConsumer is a data type that represents the identity of the Consumer
 
46
# via its shared secret with the Service Provider.
 
47
class OAuthConsumer(object):
 
48
    key = None
 
49
    secret = None
 
50
 
 
51
    def __init__(self, key, secret):
 
52
        self.key = key
 
53
        self.secret = secret
 
54
 
 
55
# OAuthToken is a data type that represents an End User via either an access
 
56
# or request token.     
 
57
class OAuthToken(object):
 
58
    # access tokens and request tokens
 
59
    key = None
 
60
    secret = None
 
61
 
 
62
    '''
 
63
    key = the token
 
64
    secret = the token secret
 
65
    '''
 
66
    def __init__(self, key, secret):
 
67
        self.key = key
 
68
        self.secret = secret
 
69
 
 
70
    def to_string(self):
 
71
        return urllib.urlencode({'oauth_token': self.key, 'oauth_token_secret': self.secret})
 
72
 
 
73
    # return a token from something like:
 
74
    # oauth_token_secret=digg&oauth_token=digg
 
75
    @staticmethod   
 
76
    def from_string(s):
 
77
        params = cgi.parse_qs(s, keep_blank_values=False)
 
78
        key = params['oauth_token'][0]
 
79
        secret = params['oauth_token_secret'][0]
 
80
        return OAuthToken(key, secret)
 
81
 
 
82
    def __str__(self):
 
83
        return self.to_string()
 
84
 
 
85
# OAuthRequest represents the request and can be serialized
 
86
class OAuthRequest(object):
 
87
    '''
 
88
    OAuth parameters:
 
89
        - oauth_consumer_key 
 
90
        - oauth_token
 
91
        - oauth_signature_method
 
92
        - oauth_signature 
 
93
        - oauth_timestamp 
 
94
        - oauth_nonce
 
95
        - oauth_version
 
96
        ... any additional parameters, as defined by the Service Provider.
 
97
    '''
 
98
    parameters = None # oauth parameters
 
99
    http_method = HTTP_METHOD
 
100
    http_url = None
 
101
    version = VERSION
 
102
 
 
103
    def __init__(self, http_method=HTTP_METHOD, http_url=None, parameters=None):
 
104
        self.http_method = http_method
 
105
        self.http_url = http_url
 
106
        self.parameters = parameters or {}
 
107
 
 
108
    def set_parameter(self, parameter, value):
 
109
        self.parameters[parameter] = value
 
110
 
 
111
    def get_parameter(self, parameter):
 
112
        try:
 
113
            return self.parameters[parameter]
 
114
        except:
 
115
            raise OAuthError('Parameter not found: %s' % parameter)
 
116
 
 
117
    def _get_timestamp_nonce(self):
 
118
        return self.get_parameter('oauth_timestamp'), self.get_parameter('oauth_nonce')
 
119
 
 
120
    # get any non-oauth parameters
 
121
    def get_nonoauth_parameters(self):
 
122
        parameters = {}
 
123
        for k, v in self.parameters.iteritems():
 
124
            # ignore oauth parameters
 
125
            if k.find('oauth_') < 0:
 
126
                parameters[k] = v
 
127
        return parameters
 
128
 
 
129
    # serialize as a header for an HTTPAuth request
 
130
    def to_header(self, realm=''):
 
131
        auth_header = 'OAuth realm="%s"' % realm
 
132
        # add the oauth parameters
 
133
        if self.parameters:
 
134
            for k, v in self.parameters.iteritems():
 
135
                auth_header += ', %s="%s"' % (k, escape(str(v)))
 
136
        return {'Authorization': auth_header}
 
137
 
 
138
    # serialize as post data for a POST request
 
139
    def to_postdata(self):
 
140
        return '&'.join('%s=%s' % (escape(str(k)), escape(str(v))) for k, v in self.parameters.iteritems())
 
141
 
 
142
    # serialize as a url for a GET request
 
143
    def to_url(self):
 
144
        return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata())
 
145
 
 
146
    # return a string that consists of all the parameters that need to be signed
 
147
    def get_normalized_parameters(self):
 
148
        params = self.parameters
 
149
        try:
 
150
            # exclude the signature if it exists
 
151
            del params['oauth_signature']
 
152
        except:
 
153
            pass
 
154
        key_values = params.items()
 
155
        # sort lexicographically, first after key, then after value
 
156
        key_values.sort()
 
157
        # combine key value pairs in string and escape
 
158
        return '&'.join('%s=%s' % (escape(str(k)), escape(str(v))) for k, v in key_values)
 
159
 
 
160
    # just uppercases the http method
 
161
    def get_normalized_http_method(self):
 
162
        return self.http_method.upper()
 
163
 
 
164
    # parses the url and rebuilds it to be scheme://host/path
 
165
    def get_normalized_http_url(self):
 
166
        parts = urlparse.urlparse(self.http_url)
 
167
        url_string = '%s://%s%s' % (parts[0], parts[1], parts[2]) # scheme, netloc, path
 
168
        return url_string
 
169
        
 
170
    # set the signature parameter to the result of build_signature
 
171
    def sign_request(self, signature_method, consumer, token):
 
172
        # set the signature method
 
173
        self.set_parameter('oauth_signature_method', signature_method.get_name())
 
174
        # set the signature
 
175
        self.set_parameter('oauth_signature', self.build_signature(signature_method, consumer, token))
 
176
 
 
177
    def build_signature(self, signature_method, consumer, token):
 
178
        # call the build signature method within the signature method
 
179
        return signature_method.build_signature(self, consumer, token)
 
180
 
 
181
    @staticmethod
 
182
    def from_request(http_method, http_url, headers=None, parameters=None, query_string=None):
 
183
        # combine multiple parameter sources
 
184
        if parameters is None:
 
185
            parameters = {}
 
186
 
 
187
        # headers
 
188
        if headers and 'HTTP_AUTHORIZATION' in headers:
 
189
            auth_header = headers['HTTP_AUTHORIZATION']
 
190
            # check that the authorization header is OAuth
 
191
            if auth_header.index('OAuth') > -1:
 
192
                try:
 
193
                    # get the parameters from the header
 
194
                    header_params = OAuthRequest._split_header(auth_header)
 
195
                    parameters.update(header_params)
 
196
                except:
 
197
                    raise OAuthError('Unable to parse OAuth parameters from Authorization header.')
 
198
 
 
199
        # GET or POST query string
 
200
        if query_string:
 
201
            query_params = OAuthRequest._split_url_string(query_string)
 
202
            parameters.update(query_params)
 
203
 
 
204
        # URL parameters
 
205
        param_str = urlparse.urlparse(http_url)[4] # query
 
206
        url_params = OAuthRequest._split_url_string(param_str)
 
207
        parameters.update(url_params)
 
208
 
 
209
        if parameters:
 
210
            return OAuthRequest(http_method, http_url, parameters)
 
211
 
 
212
        return None
 
213
 
 
214
    @staticmethod
 
215
    def from_consumer_and_token(oauth_consumer, token=None, http_method=HTTP_METHOD, http_url=None, parameters=None):
 
216
        if not parameters:
 
217
            parameters = {}
 
218
 
 
219
        defaults = {
 
220
            'oauth_consumer_key': oauth_consumer.key,
 
221
            'oauth_timestamp': generate_timestamp(),
 
222
            'oauth_nonce': generate_nonce(),
 
223
            'oauth_version': OAuthRequest.version,
 
224
        }
 
225
 
 
226
        defaults.update(parameters)
 
227
        parameters = defaults
 
228
 
 
229
        if token:
 
230
            parameters['oauth_token'] = token.key
 
231
 
 
232
        return OAuthRequest(http_method, http_url, parameters)
 
233
 
 
234
    @staticmethod
 
235
    def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD, http_url=None, parameters=None):
 
236
        if not parameters:
 
237
            parameters = {}
 
238
 
 
239
        parameters['oauth_token'] = token.key
 
240
 
 
241
        if callback:
 
242
            parameters['oauth_callback'] = escape(callback)
 
243
 
 
244
        return OAuthRequest(http_method, http_url, parameters)
 
245
 
 
246
    # util function: turn Authorization: header into parameters, has to do some unescaping
 
247
    @staticmethod
 
248
    def _split_header(header):
 
249
        params = {}
 
250
        parts = header.split(',')
 
251
        for param in parts:
 
252
            # ignore realm parameter
 
253
            if param.find('OAuth realm') > -1:
 
254
                continue
 
255
            # remove whitespace
 
256
            param = param.strip()
 
257
            # split key-value
 
258
            param_parts = param.split('=', 1)
 
259
            # remove quotes and unescape the value
 
260
            params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
 
261
        return params
 
262
    
 
263
    # util function: turn url string into parameters, has to do some unescaping
 
264
    @staticmethod
 
265
    def _split_url_string(param_str):
 
266
        parameters = cgi.parse_qs(param_str, keep_blank_values=False)
 
267
        for k, v in parameters.iteritems():
 
268
            parameters[k] = urllib.unquote(v[0])
 
269
        return parameters
 
270
 
 
271
# OAuthServer is a worker to check a requests validity against a data store
 
272
class OAuthServer(object):
 
273
    timestamp_threshold = 300 # in seconds, five minutes
 
274
    version = VERSION
 
275
    signature_methods = None
 
276
    data_store = None
 
277
 
 
278
    def __init__(self, data_store=None, signature_methods=None):
 
279
        self.data_store = data_store
 
280
        self.signature_methods = signature_methods or {}
 
281
 
 
282
    def set_data_store(self, oauth_data_store):
 
283
        self.data_store = data_store
 
284
 
 
285
    def get_data_store(self):
 
286
        return self.data_store
 
287
 
 
288
    def add_signature_method(self, signature_method):
 
289
        self.signature_methods[signature_method.get_name()] = signature_method
 
290
        return self.signature_methods
 
291
 
 
292
    # process a request_token request
 
293
    # returns the request token on success
 
294
    def fetch_request_token(self, oauth_request):
 
295
        try:
 
296
            # get the request token for authorization
 
297
            token = self._get_token(oauth_request, 'request')
 
298
        except OAuthError:
 
299
            # no token required for the initial token request
 
300
            version = self._get_version(oauth_request)
 
301
            consumer = self._get_consumer(oauth_request)
 
302
            self._check_signature(oauth_request, consumer, None)
 
303
            # fetch a new token
 
304
            token = self.data_store.fetch_request_token(consumer)
 
305
        return token
 
306
 
 
307
    # process an access_token request
 
308
    # returns the access token on success
 
309
    def fetch_access_token(self, oauth_request):
 
310
        version = self._get_version(oauth_request)
 
311
        consumer = self._get_consumer(oauth_request)
 
312
        # get the request token
 
313
        token = self._get_token(oauth_request, 'request')
 
314
        self._check_signature(oauth_request, consumer, token)
 
315
        new_token = self.data_store.fetch_access_token(consumer, token)
 
316
        return new_token
 
317
 
 
318
    # verify an api call, checks all the parameters
 
319
    def verify_request(self, oauth_request):
 
320
        # -> consumer and token
 
321
        version = self._get_version(oauth_request)
 
322
        consumer = self._get_consumer(oauth_request)
 
323
        # get the access token
 
324
        token = self._get_token(oauth_request, 'access')
 
325
        self._check_signature(oauth_request, consumer, token)
 
326
        parameters = oauth_request.get_nonoauth_parameters()
 
327
        return consumer, token, parameters
 
328
 
 
329
    # authorize a request token
 
330
    def authorize_token(self, token, user):
 
331
        return self.data_store.authorize_request_token(token, user)
 
332
    
 
333
    # get the callback url
 
334
    def get_callback(self, oauth_request):
 
335
        return oauth_request.get_parameter('oauth_callback')
 
336
 
 
337
    # optional support for the authenticate header   
 
338
    def build_authenticate_header(self, realm=''):
 
339
        return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
 
340
 
 
341
    # verify the correct version request for this server
 
342
    def _get_version(self, oauth_request):
 
343
        try:
 
344
            version = oauth_request.get_parameter('oauth_version')
 
345
        except:
 
346
            version = VERSION
 
347
        if version and version != self.version:
 
348
            raise OAuthError('OAuth version %s not supported.' % str(version))
 
349
        return version
 
350
 
 
351
    # figure out the signature with some defaults
 
352
    def _get_signature_method(self, oauth_request):
 
353
        try:
 
354
            signature_method = oauth_request.get_parameter('oauth_signature_method')
 
355
        except:
 
356
            signature_method = SIGNATURE_METHOD
 
357
        try:
 
358
            # get the signature method object
 
359
            signature_method = self.signature_methods[signature_method]
 
360
        except:
 
361
            signature_method_names = ', '.join(self.signature_methods.keys())
 
362
            raise OAuthError('Signature method %s not supported try one of the following: %s' % (signature_method, signature_method_names))
 
363
 
 
364
        return signature_method
 
365
 
 
366
    def _get_consumer(self, oauth_request):
 
367
        consumer_key = oauth_request.get_parameter('oauth_consumer_key')
 
368
        if not consumer_key:
 
369
            raise OAuthError('Invalid consumer key.')
 
370
        consumer = self.data_store.lookup_consumer(consumer_key)
 
371
        if not consumer:
 
372
            raise OAuthError('Invalid consumer.')
 
373
        return consumer
 
374
 
 
375
    # try to find the token for the provided request token key
 
376
    def _get_token(self, oauth_request, token_type='access'):
 
377
        token_field = oauth_request.get_parameter('oauth_token')
 
378
        token = self.data_store.lookup_token(token_type, token_field)
 
379
        if not token:
 
380
            raise OAuthError('Invalid %s token: %s' % (token_type, token_field))
 
381
        return token
 
382
 
 
383
    def _check_signature(self, oauth_request, consumer, token):
 
384
        timestamp, nonce = oauth_request._get_timestamp_nonce()
 
385
        self._check_timestamp(timestamp)
 
386
        self._check_nonce(consumer, token, nonce)
 
387
        signature_method = self._get_signature_method(oauth_request)
 
388
        try:
 
389
            signature = oauth_request.get_parameter('oauth_signature')
 
390
        except:
 
391
            raise OAuthError('Missing signature.')
 
392
        # validate the signature
 
393
        valid_sig = signature_method.check_signature(oauth_request, consumer, token, signature)
 
394
        if not valid_sig:
 
395
            key, base = signature_method.build_signature_base_string(oauth_request, consumer, token)
 
396
            raise OAuthError('Invalid signature. Expected signature base string: %s' % base)
 
397
        built = signature_method.build_signature(oauth_request, consumer, token)
 
398
 
 
399
    def _check_timestamp(self, timestamp):
 
400
        # verify that timestamp is recentish
 
401
        timestamp = int(timestamp)
 
402
        now = int(time.time())
 
403
        lapsed = now - timestamp
 
404
        if lapsed > self.timestamp_threshold:
 
405
            raise OAuthError('Expired timestamp: given %d and now %s has a greater difference than threshold %d' % (timestamp, now, self.timestamp_threshold))
 
406
 
 
407
    def _check_nonce(self, consumer, token, nonce):
 
408
        # verify that the nonce is uniqueish
 
409
        nonce = self.data_store.lookup_nonce(consumer, token, nonce)
 
410
        if nonce:
 
411
            raise OAuthError('Nonce already used: %s' % str(nonce))
 
412
 
 
413
# OAuthClient is a worker to attempt to execute a request
 
414
class OAuthClient(object):
 
415
    consumer = None
 
416
    token = None
 
417
 
 
418
    def __init__(self, oauth_consumer, oauth_token):
 
419
        self.consumer = oauth_consumer
 
420
        self.token = oauth_token
 
421
 
 
422
    def get_consumer(self):
 
423
        return self.consumer
 
424
 
 
425
    def get_token(self):
 
426
        return self.token
 
427
 
 
428
    def fetch_request_token(self, oauth_request):
 
429
        # -> OAuthToken
 
430
        raise NotImplementedError
 
431
 
 
432
    def fetch_access_token(self, oauth_request):
 
433
        # -> OAuthToken
 
434
        raise NotImplementedError
 
435
 
 
436
    def access_resource(self, oauth_request):
 
437
        # -> some protected resource
 
438
        raise NotImplementedError
 
439
 
 
440
# OAuthDataStore is a database abstraction used to lookup consumers and tokens
 
441
class OAuthDataStore(object):
 
442
 
 
443
    def lookup_consumer(self, key):
 
444
        # -> OAuthConsumer
 
445
        raise NotImplementedError
 
446
 
 
447
    def lookup_token(self, oauth_consumer, token_type, token_token):
 
448
        # -> OAuthToken
 
449
        raise NotImplementedError
 
450
 
 
451
    def lookup_nonce(self, oauth_consumer, oauth_token, nonce, timestamp):
 
452
        # -> OAuthToken
 
453
        raise NotImplementedError
 
454
 
 
455
    def fetch_request_token(self, oauth_consumer):
 
456
        # -> OAuthToken
 
457
        raise NotImplementedError
 
458
 
 
459
    def fetch_access_token(self, oauth_consumer, oauth_token):
 
460
        # -> OAuthToken
 
461
        raise NotImplementedError
 
462
 
 
463
    def authorize_request_token(self, oauth_token, user):
 
464
        # -> OAuthToken
 
465
        raise NotImplementedError
 
466
 
 
467
# OAuthSignatureMethod is a strategy class that implements a signature method
 
468
class OAuthSignatureMethod(object):
 
469
    def get_name(self):
 
470
        # -> str
 
471
        raise NotImplementedError
 
472
 
 
473
    def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token):
 
474
        # -> str key, str raw
 
475
        raise NotImplementedError
 
476
 
 
477
    def build_signature(self, oauth_request, oauth_consumer, oauth_token):
 
478
        # -> str
 
479
        raise NotImplementedError
 
480
 
 
481
    def check_signature(self, oauth_request, consumer, token, signature):
 
482
        built = self.build_signature(oauth_request, consumer, token)
 
483
        return built == signature
 
484
 
 
485
class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod):
 
486
 
 
487
    def get_name(self):
 
488
        return 'HMAC-SHA1'
 
489
        
 
490
    def build_signature_base_string(self, oauth_request, consumer, token):
 
491
        sig = (
 
492
            escape(oauth_request.get_normalized_http_method()),
 
493
            escape(oauth_request.get_normalized_http_url()),
 
494
            escape(oauth_request.get_normalized_parameters()),
 
495
        )
 
496
 
 
497
        key = '%s&' % escape(consumer.secret)
 
498
        if token:
 
499
            key += escape(token.secret)
 
500
        raw = '&'.join(sig)
 
501
        return key, raw
 
502
 
 
503
    def build_signature(self, oauth_request, consumer, token):
 
504
        # build the base signature string
 
505
        key, raw = self.build_signature_base_string(oauth_request, consumer, token)
 
506
 
 
507
        # hmac object
 
508
        try:
 
509
            import hashlib # 2.5
 
510
            hashed = hmac.new(key, raw, hashlib.sha1)
 
511
        except:
 
512
            import sha # deprecated
 
513
            hashed = hmac.new(key, raw, sha)
 
514
 
 
515
        # calculate the digest base 64
 
516
        return base64.b64encode(hashed.digest())
 
517
 
 
518
class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod):
 
519
 
 
520
    def get_name(self):
 
521
        return 'PLAINTEXT'
 
522
 
 
523
    def build_signature_base_string(self, oauth_request, consumer, token):
 
524
        # concatenate the consumer key and secret
 
525
        sig = escape(consumer.secret) + '&'
 
526
        if token:
 
527
            sig = sig + escape(token.secret)
 
528
        return sig
 
529
 
 
530
    def build_signature(self, oauth_request, consumer, token):
 
531
        return self.build_signature_base_string(oauth_request, consumer, token)