~ubuntu-branches/ubuntu/raring/python-scipy/raring-proposed

« back to all changes in this revision

Viewing changes to Lib/weave/accelerate_tools.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2007-01-07 14:12:12 UTC
  • mfrom: (1.1.1 upstream)
  • Revision ID: james.westby@ubuntu.com-20070107141212-mm0ebkh5b37hcpzn
* Remove build dependency on python-numpy-dev.
* python-scipy: Depend on python-numpy instead of python-numpy-dev.
* Package builds on other archs than i386. Closes: #402783.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#**************************************************************************#
 
2
#* FILE   **************    accelerate_tools.py    ************************#
 
3
#**************************************************************************#
 
4
#* Author: Patrick Miller February  9 2002                                *#
 
5
#**************************************************************************#
 
6
"""
 
7
accelerate_tools contains the interface for on-the-fly building of
 
8
C++ equivalents to Python functions.
 
9
"""
 
10
#**************************************************************************#
 
11
 
 
12
from types import InstanceType, XRangeType
 
13
import inspect
 
14
import md5
 
15
import scipy.weave as weave
 
16
 
 
17
from bytecodecompiler import CXXCoder,Type_Descriptor,Function_Descriptor
 
18
 
 
19
def CStr(s):
 
20
    "Hacky way to get legal C string from Python string"
 
21
    if s is None:
 
22
        return '""'
 
23
    assert isinstance(s, str), "only None and string allowed"
 
24
    r = repr('"'+s) # Better for embedded quotes
 
25
    return '"'+r[2:-1]+'"'
 
26
 
 
27
 
 
28
##################################################################
 
29
#                         CLASS INSTANCE                         #
 
30
##################################################################
 
31
class Instance(Type_Descriptor):
 
32
    cxxtype = 'PyObject*'
 
33
 
 
34
    def __init__(self,prototype):
 
35
        self.prototype  = prototype
 
36
 
 
37
    def check(self,s):
 
38
        return "PyInstance_Check(%s)"%s
 
39
 
 
40
    def inbound(self,s):
 
41
        return s
 
42
 
 
43
    def outbound(self,s):
 
44
        return s,0
 
45
 
 
46
    def get_attribute(self,name):
 
47
        proto = getattr(self.prototype,name)
 
48
        T = lookup_type(proto)
 
49
        code = 'tempPY = PyObject_GetAttrString(%%(rhs)s,"%s");\n'%name
 
50
        convert = T.inbound('tempPY')
 
51
        code += '%%(lhsType)s %%(lhs)s = %s;\n'%convert
 
52
        return T,code
 
53
 
 
54
    def set_attribute(self,name):
 
55
        proto = getattr(self.prototype,name)
 
56
        T = lookup_type(proto)
 
57
        convert,owned = T.outbound('%(rhs)s')
 
58
        code = 'tempPY = %s;'%convert
 
59
        if not owned:
 
60
            code += ' Py_INCREF(tempPY);'
 
61
        code += ' PyObject_SetAttrString(%%(lhs)s,"%s",tempPY);'%name
 
62
        code += ' Py_DECREF(tempPY);\n'
 
63
        return T,code
 
64
 
 
65
##################################################################
 
66
#                          CLASS BASIC                           #
 
67
##################################################################
 
68
class Basic(Type_Descriptor):
 
69
    owned = 1
 
70
    def check(self,s):
 
71
        return "%s(%s)"%(self.checker,s)
 
72
    def inbound(self,s):
 
73
        return "%s(%s)"%(self.inbounder,s)
 
74
    def outbound(self,s):
 
75
        return "%s(%s)"%(self.outbounder,s),self.owned
 
76
 
 
77
class Basic_Number(Basic):
 
78
    def literalizer(self,s):
 
79
        return str(s)
 
80
    def binop(self,symbol,a,b):
 
81
        assert symbol in ['+','-','*','/'],symbol
 
82
        return '%s %s %s'%(a,symbol,b),self
 
83
 
 
84
class Integer(Basic_Number):
 
85
    cxxtype = "long"
 
86
    checker = "PyInt_Check"
 
87
    inbounder = "PyInt_AsLong"
 
88
    outbounder = "PyInt_FromLong"
 
89
 
 
90
class Double(Basic_Number):
 
91
    cxxtype = "double"
 
92
    checker = "PyFloat_Check"
 
93
    inbounder = "PyFloat_AsDouble"
 
94
    outbounder = "PyFloat_FromDouble"
 
95
 
 
96
class String(Basic):
 
97
    cxxtype = "char*"
 
98
    checker = "PyString_Check"
 
99
    inbounder = "PyString_AsString"
 
100
    outbounder = "PyString_FromString"
 
101
 
 
102
    def literalizer(self,s):
 
103
        return CStr(s)
 
104
 
 
105
# -----------------------------------------------
 
106
# Singletonize the type names
 
107
# -----------------------------------------------
 
108
Integer = Integer()
 
109
Double = Double()
 
110
String = String()
 
111
 
 
112
import numpy as nx
 
113
 
 
114
class Vector(Type_Descriptor):
 
115
    cxxtype = 'PyArrayObject*'
 
116
    refcount = 1
 
117
    dims = 1
 
118
    module_init_code = 'import_array();\n'
 
119
    inbounder = "(PyArrayObject*)"
 
120
    outbounder = "(PyObject*)"
 
121
    owned = 0 # Convertion is by casting!
 
122
 
 
123
    prerequisites = Type_Descriptor.prerequisites+\
 
124
                    ['#include "numpy/arrayobject.h"']
 
125
    dims = 1
 
126
    def check(self,s):
 
127
        return "PyArray_Check(%s) && ((PyArrayObject*)%s)->nd == %d &&  ((PyArrayObject*)%s)->descr->type_num == %s"%(
 
128
            s,s,self.dims,s,self.typecode)
 
129
 
 
130
    def inbound(self,s):
 
131
        return "%s(%s)"%(self.inbounder,s)
 
132
    def outbound(self,s):
 
133
        return "%s(%s)"%(self.outbounder,s),self.owned
 
134
 
 
135
    def getitem(self,A,v,t):
 
136
        assert self.dims == len(v),'Expect dimension %d'%self.dims
 
137
        code = '*((%s*)(%s->data'%(self.cxxbase,A)
 
138
        for i in range(self.dims):
 
139
            # assert that ''t[i]'' is an integer
 
140
            code += '+%s*%s->strides[%d]'%(v[i],A,i)
 
141
        code += '))'
 
142
        return code,self.pybase
 
143
    def setitem(self,A,v,t):
 
144
        return self.getitem(A,v,t)
 
145
 
 
146
class matrix(Vector):
 
147
    dims = 2
 
148
 
 
149
class IntegerVector(Vector):
 
150
    typecode = 'PyArray_INT'
 
151
    cxxbase = 'int'
 
152
    pybase = Integer
 
153
 
 
154
class Integermatrix(matrix):
 
155
    typecode = 'PyArray_INT'
 
156
    cxxbase = 'int'
 
157
    pybase = Integer
 
158
 
 
159
class LongVector(Vector):
 
160
    typecode = 'PyArray_LONG'
 
161
    cxxbase = 'long'
 
162
    pybase = Integer
 
163
 
 
164
class Longmatrix(matrix):
 
165
    typecode = 'PyArray_LONG'
 
166
    cxxbase = 'long'
 
167
    pybase = Integer
 
168
 
 
169
class DoubleVector(Vector):
 
170
    typecode = 'PyArray_DOUBLE'
 
171
    cxxbase = 'double'
 
172
    pybase = Double
 
173
 
 
174
class Doublematrix(matrix):
 
175
    typecode = 'PyArray_DOUBLE'
 
176
    cxxbase = 'double'
 
177
    pybase = Double
 
178
 
 
179
 
 
180
##################################################################
 
181
#                          CLASS XRANGE                          #
 
182
##################################################################
 
183
class XRange(Type_Descriptor):
 
184
    cxxtype = 'XRange'
 
185
    prerequisites = ['''
 
186
    class XRange {
 
187
    public:
 
188
    XRange(long aLow, long aHigh, long aStep=1)
 
189
    : low(aLow),high(aHigh),step(aStep)
 
190
    {
 
191
    }
 
192
    XRange(long aHigh)
 
193
    : low(0),high(aHigh),step(1)
 
194
    {
 
195
    }
 
196
    long low;
 
197
    long high;
 
198
    long step;
 
199
    };''']
 
200
 
 
201
# -----------------------------------------------
 
202
# Singletonize the type names
 
203
# -----------------------------------------------
 
204
IntegerVector = IntegerVector()
 
205
Integermatrix = Integermatrix()
 
206
LongVector = LongVector()
 
207
Longmatrix = Longmatrix()
 
208
DoubleVector = DoubleVector()
 
209
Doublematrix = Doublematrix()
 
210
XRange = XRange()
 
211
 
 
212
 
 
213
typedefs = {
 
214
    int : Integer,
 
215
    float : Double,
 
216
    str: String,
 
217
    (nx.ndarray,1,int): IntegerVector,
 
218
    (nx.ndarray,2,int): Integermatrix,
 
219
    (nx.ndarray,1,nx.long): LongVector,
 
220
    (nx.ndarray,2,nx.long): Longmatrix,
 
221
    (nx.ndarray,1,float): DoubleVector,
 
222
    (nx.ndarray,2,float): Doublematrix,
 
223
    XRangeType : XRange,
 
224
    }
 
225
 
 
226
import math
 
227
functiondefs = {
 
228
    (len,(String,)):
 
229
    Function_Descriptor(code='strlen(%s)',return_type=Integer),
 
230
 
 
231
    (len,(LongVector,)):
 
232
    Function_Descriptor(code='PyArray_Size((PyObject*)%s)',return_type=Integer),
 
233
 
 
234
    (float,(Integer,)):
 
235
    Function_Descriptor(code='(double)(%s)',return_type=Double),
 
236
 
 
237
    (range,(Integer,Integer)):
 
238
    Function_Descriptor(code='XRange(%s)',return_type=XRange),
 
239
 
 
240
    (range,(Integer)):
 
241
    Function_Descriptor(code='XRange(%s)',return_type=XRange),
 
242
 
 
243
    (math.sin,(Double,)):
 
244
    Function_Descriptor(code='sin(%s)',return_type=Double),
 
245
 
 
246
    (math.cos,(Double,)):
 
247
    Function_Descriptor(code='cos(%s)',return_type=Double),
 
248
 
 
249
    (math.sqrt,(Double,)):
 
250
    Function_Descriptor(code='sqrt(%s)',return_type=Double),
 
251
    }
 
252
 
 
253
 
 
254
 
 
255
##################################################################
 
256
#                      FUNCTION LOOKUP_TYPE                      #
 
257
##################################################################
 
258
def lookup_type(x):
 
259
    T = type(x)
 
260
    try:
 
261
        return typedefs[T]
 
262
    except:
 
263
        if isinstance(T,nx.ndarray):
 
264
            return typedefs[(T,len(x.shape),x.dtype.char)]
 
265
        elif issubclass(T, InstanceType):
 
266
            return Instance(x)
 
267
        else:
 
268
            raise NotImplementedError,T
 
269
 
 
270
##################################################################
 
271
#                        class ACCELERATE                        #
 
272
##################################################################
 
273
class accelerate(object):
 
274
 
 
275
    def __init__(self, function, *args, **kw):
 
276
        assert inspect.isfunction(function)
 
277
        self.function = function
 
278
        self.module = inspect.getmodule(function)
 
279
        if self.module is None:
 
280
            import __main__
 
281
            self.module = __main__
 
282
        self.__call_map = {}
 
283
 
 
284
    def __cache(self,*args):
 
285
        raise TypeError
 
286
 
 
287
    def __call__(self,*args):
 
288
        try:
 
289
            return self.__cache(*args)
 
290
        except TypeError:
 
291
            # Figure out type info -- Do as tuple so its hashable
 
292
            signature = tuple( map(lookup_type,args) )
 
293
 
 
294
            # If we know the function, call it
 
295
            try:
 
296
                fast = self.__call_map[signature]
 
297
            except:
 
298
                fast = self.singleton(signature)
 
299
                self.__cache = fast
 
300
                self.__call_map[signature] = fast
 
301
            return fast(*args)
 
302
 
 
303
    def signature(self,*args):
 
304
        # Figure out type info -- Do as tuple so its hashable
 
305
        signature = tuple( map(lookup_type,args) )
 
306
        return self.singleton(signature)
 
307
 
 
308
 
 
309
    def singleton(self,signature):
 
310
        identifier = self.identifier(signature)
 
311
 
 
312
        # Generate a new function, then call it
 
313
        f = self.function
 
314
 
 
315
        # See if we have an accelerated version of module
 
316
        try:
 
317
            print 'lookup',self.module.__name__+'_weave'
 
318
            accelerated_module = __import__(self.module.__name__+'_weave')
 
319
            print 'have accelerated',self.module.__name__+'_weave'
 
320
            fast = getattr(accelerated_module,identifier)
 
321
            return fast
 
322
        except ImportError:
 
323
            accelerated_module = None
 
324
        except AttributeError:
 
325
            pass
 
326
 
 
327
        P = self.accelerate(signature,identifier)
 
328
 
 
329
        E = weave.ext_tools.ext_module(self.module.__name__+'_weave')
 
330
        E.add_function(P)
 
331
        E.generate_file()
 
332
        weave.build_tools.build_extension(self.module.__name__+'_weave.cpp',verbose=2)
 
333
 
 
334
        if accelerated_module:
 
335
            raise NotImplementedError,'Reload'
 
336
        else:
 
337
            accelerated_module = __import__(self.module.__name__+'_weave')
 
338
 
 
339
        fast = getattr(accelerated_module,identifier)
 
340
        return fast
 
341
 
 
342
    def identifier(self,signature):
 
343
        # Build an MD5 checksum
 
344
        f = self.function
 
345
        co = f.func_code
 
346
        identifier = str(signature)+\
 
347
                     str(co.co_argcount)+\
 
348
                     str(co.co_consts)+\
 
349
                     str(co.co_varnames)+\
 
350
                     co.co_code
 
351
        return 'F'+md5.md5(identifier).hexdigest()
 
352
 
 
353
    def accelerate(self,signature,identifier):
 
354
        P = Python2CXX(self.function,signature,name=identifier)
 
355
        return P
 
356
 
 
357
    def code(self,*args):
 
358
        if len(args) != self.function.func_code.co_argcount:
 
359
            raise TypeError,'%s() takes exactly %d arguments (%d given)'%(
 
360
                self.function.__name__,
 
361
                self.function.func_code.co_argcount,
 
362
                len(args))
 
363
        signature = tuple( map(lookup_type,args) )
 
364
        ident = self.function.__name__
 
365
        return self.accelerate(signature,ident).function_code()
 
366
 
 
367
 
 
368
##################################################################
 
369
#                        CLASS PYTHON2CXX                        #
 
370
##################################################################
 
371
class Python2CXX(CXXCoder):
 
372
    def typedef_by_value(self,v):
 
373
        T = lookup_type(v)
 
374
        if T not in self.used:
 
375
            self.used.append(T)
 
376
        return T
 
377
 
 
378
    def function_by_signature(self,signature):
 
379
        descriptor = functiondefs[signature]
 
380
        if descriptor.return_type not in self.used:
 
381
            self.used.append(descriptor.return_type)
 
382
        return descriptor
 
383
 
 
384
    def __init__(self,f,signature,name=None):
 
385
        # Make sure function is a function
 
386
        assert inspect.isfunction(f)
 
387
        # and check the input type signature
 
388
        assert reduce(lambda x,y: x and y,
 
389
                      map(lambda x: isinstance(x,Type_Descriptor),
 
390
                          signature),
 
391
                      1),'%s not all type objects'%signature
 
392
        self.arg_specs = []
 
393
        self.customize = weave.base_info.custom_info()
 
394
 
 
395
        CXXCoder.__init__(self,f,signature,name)
 
396
 
 
397
        return
 
398
 
 
399
    def function_code(self):
 
400
        code = self.wrapped_code()
 
401
        for T in self.used:
 
402
            if T != None and T.module_init_code:
 
403
                self.customize.add_module_init_code(T.module_init_code)
 
404
        return code
 
405
 
 
406
    def python_function_definition_code(self):
 
407
        return '{ "%s", wrapper_%s, METH_VARARGS, %s },\n'%(
 
408
            self.name,
 
409
            self.name,
 
410
            CStr(self.function.__doc__))