1
#**************************************************************************#
2
#* FILE ************** accelerate_tools.py ************************#
3
#**************************************************************************#
4
#* Author: Patrick Miller February 9 2002 *#
5
#**************************************************************************#
7
accelerate_tools contains the interface for on-the-fly building of
8
C++ equivalents to Python functions.
10
#**************************************************************************#
12
from types import InstanceType, XRangeType
15
import scipy.weave as weave
17
from bytecodecompiler import CXXCoder,Type_Descriptor,Function_Descriptor
20
"Hacky way to get legal C string from Python string"
23
assert isinstance(s, str), "only None and string allowed"
24
r = repr('"'+s) # Better for embedded quotes
25
return '"'+r[2:-1]+'"'
28
##################################################################
30
##################################################################
31
class Instance(Type_Descriptor):
34
def __init__(self,prototype):
35
self.prototype = prototype
38
return "PyInstance_Check(%s)"%s
46
def get_attribute(self,name):
47
proto = getattr(self.prototype,name)
48
T = lookup_type(proto)
49
code = 'tempPY = PyObject_GetAttrString(%%(rhs)s,"%s");\n'%name
50
convert = T.inbound('tempPY')
51
code += '%%(lhsType)s %%(lhs)s = %s;\n'%convert
54
def set_attribute(self,name):
55
proto = getattr(self.prototype,name)
56
T = lookup_type(proto)
57
convert,owned = T.outbound('%(rhs)s')
58
code = 'tempPY = %s;'%convert
60
code += ' Py_INCREF(tempPY);'
61
code += ' PyObject_SetAttrString(%%(lhs)s,"%s",tempPY);'%name
62
code += ' Py_DECREF(tempPY);\n'
65
##################################################################
67
##################################################################
68
class Basic(Type_Descriptor):
71
return "%s(%s)"%(self.checker,s)
73
return "%s(%s)"%(self.inbounder,s)
75
return "%s(%s)"%(self.outbounder,s),self.owned
77
class Basic_Number(Basic):
78
def literalizer(self,s):
80
def binop(self,symbol,a,b):
81
assert symbol in ['+','-','*','/'],symbol
82
return '%s %s %s'%(a,symbol,b),self
84
class Integer(Basic_Number):
86
checker = "PyInt_Check"
87
inbounder = "PyInt_AsLong"
88
outbounder = "PyInt_FromLong"
90
class Double(Basic_Number):
92
checker = "PyFloat_Check"
93
inbounder = "PyFloat_AsDouble"
94
outbounder = "PyFloat_FromDouble"
98
checker = "PyString_Check"
99
inbounder = "PyString_AsString"
100
outbounder = "PyString_FromString"
102
def literalizer(self,s):
105
# -----------------------------------------------
106
# Singletonize the type names
107
# -----------------------------------------------
114
class Vector(Type_Descriptor):
115
cxxtype = 'PyArrayObject*'
118
module_init_code = 'import_array();\n'
119
inbounder = "(PyArrayObject*)"
120
outbounder = "(PyObject*)"
121
owned = 0 # Convertion is by casting!
123
prerequisites = Type_Descriptor.prerequisites+\
124
['#include "numpy/arrayobject.h"']
127
return "PyArray_Check(%s) && ((PyArrayObject*)%s)->nd == %d && ((PyArrayObject*)%s)->descr->type_num == %s"%(
128
s,s,self.dims,s,self.typecode)
131
return "%s(%s)"%(self.inbounder,s)
132
def outbound(self,s):
133
return "%s(%s)"%(self.outbounder,s),self.owned
135
def getitem(self,A,v,t):
136
assert self.dims == len(v),'Expect dimension %d'%self.dims
137
code = '*((%s*)(%s->data'%(self.cxxbase,A)
138
for i in range(self.dims):
139
# assert that ''t[i]'' is an integer
140
code += '+%s*%s->strides[%d]'%(v[i],A,i)
142
return code,self.pybase
143
def setitem(self, A, v, t):
    """Return the C++ expression used to store into element *v* of array *A*.

    The expression built by ``getitem`` starts with a pointer dereference
    (``*((T*)(A->data ...``), so it can serve directly as an assignment
    target.  The store therefore reuses ``getitem``'s result unchanged:
    the (code, python-side type) pair it returns.
    """
    lvalue_and_type = self.getitem(A, v, t)
    return lvalue_and_type
146
class matrix(Vector):
149
class IntegerVector(Vector):
150
typecode = 'PyArray_INT'
154
class Integermatrix(matrix):
155
typecode = 'PyArray_INT'
159
class LongVector(Vector):
160
typecode = 'PyArray_LONG'
164
class Longmatrix(matrix):
165
typecode = 'PyArray_LONG'
169
class DoubleVector(Vector):
170
typecode = 'PyArray_DOUBLE'
174
class Doublematrix(matrix):
175
typecode = 'PyArray_DOUBLE'
180
##################################################################
182
##################################################################
183
class XRange(Type_Descriptor):
188
XRange(long aLow, long aHigh, long aStep=1)
189
: low(aLow),high(aHigh),step(aStep)
193
: low(0),high(aHigh),step(1)
201
# -----------------------------------------------
# Singletonize the type names
# -----------------------------------------------
# Rebind each array-type class name to one shared instance of itself, so the
# type lookup table below maps (ndarray, rank, element type) keys to
# descriptor objects rather than classes.  Instances are created in the same
# order as before (the right-hand side of each line is evaluated first).
IntegerVector, Integermatrix = IntegerVector(), Integermatrix()
LongVector, Longmatrix = LongVector(), Longmatrix()
DoubleVector, Doublematrix = DoubleVector(), Doublematrix()
217
(nx.ndarray,1,int): IntegerVector,
218
(nx.ndarray,2,int): Integermatrix,
219
(nx.ndarray,1,nx.long): LongVector,
220
(nx.ndarray,2,nx.long): Longmatrix,
221
(nx.ndarray,1,float): DoubleVector,
222
(nx.ndarray,2,float): Doublematrix,
229
Function_Descriptor(code='strlen(%s)',return_type=Integer),
232
Function_Descriptor(code='PyArray_Size((PyObject*)%s)',return_type=Integer),
235
Function_Descriptor(code='(double)(%s)',return_type=Double),
237
(range,(Integer,Integer)):
238
Function_Descriptor(code='XRange(%s)',return_type=XRange),
241
Function_Descriptor(code='XRange(%s)',return_type=XRange),
243
(math.sin,(Double,)):
244
Function_Descriptor(code='sin(%s)',return_type=Double),
246
(math.cos,(Double,)):
247
Function_Descriptor(code='cos(%s)',return_type=Double),
249
(math.sqrt,(Double,)):
250
Function_Descriptor(code='sqrt(%s)',return_type=Double),
255
##################################################################
256
# FUNCTION LOOKUP_TYPE #
257
##################################################################
263
if isinstance(T,nx.ndarray):
264
return typedefs[(T,len(x.shape),x.dtype.char)]
265
elif issubclass(T, InstanceType):
268
raise NotImplementedError,T
270
##################################################################
272
##################################################################
273
class accelerate(object):
275
def __init__(self, function, *args, **kw):
276
assert inspect.isfunction(function)
277
self.function = function
278
self.module = inspect.getmodule(function)
279
if self.module is None:
281
self.module = __main__
284
def __cache(self,*args):
287
def __call__(self,*args):
289
return self.__cache(*args)
291
# Figure out type info -- Do as tuple so its hashable
292
signature = tuple( map(lookup_type,args) )
294
# If we know the function, call it
296
fast = self.__call_map[signature]
298
fast = self.singleton(signature)
300
self.__call_map[signature] = fast
303
def signature(self, *args):
    """Return ``self.singleton`` applied to the type signature of *args*.

    Each argument value is mapped to its type descriptor with
    ``lookup_type``, which may raise NotImplementedError for unsupported
    values.  The descriptors are collected into a tuple so the signature
    is hashable.  This path does not consult or update the per-signature
    call cache; it always goes through ``singleton``.
    """
    type_signature = tuple(lookup_type(arg) for arg in args)
    return self.singleton(type_signature)
309
def singleton(self,signature):
310
identifier = self.identifier(signature)
312
# Generate a new function, then call it
315
# See if we have an accelerated version of module
317
print 'lookup',self.module.__name__+'_weave'
318
accelerated_module = __import__(self.module.__name__+'_weave')
319
print 'have accelerated',self.module.__name__+'_weave'
320
fast = getattr(accelerated_module,identifier)
323
accelerated_module = None
324
except AttributeError:
327
P = self.accelerate(signature,identifier)
329
E = weave.ext_tools.ext_module(self.module.__name__+'_weave')
332
weave.build_tools.build_extension(self.module.__name__+'_weave.cpp',verbose=2)
334
if accelerated_module:
335
raise NotImplementedError,'Reload'
337
accelerated_module = __import__(self.module.__name__+'_weave')
339
fast = getattr(accelerated_module,identifier)
342
def identifier(self,signature):
343
# Build an MD5 checksum
346
identifier = str(signature)+\
347
str(co.co_argcount)+\
349
str(co.co_varnames)+\
351
return 'F'+md5.md5(identifier).hexdigest()
353
def accelerate(self,signature,identifier):
354
P = Python2CXX(self.function,signature,name=identifier)
357
def code(self,*args):
358
if len(args) != self.function.func_code.co_argcount:
359
raise TypeError,'%s() takes exactly %d arguments (%d given)'%(
360
self.function.__name__,
361
self.function.func_code.co_argcount,
363
signature = tuple( map(lookup_type,args) )
364
ident = self.function.__name__
365
return self.accelerate(signature,ident).function_code()
368
##################################################################
370
##################################################################
371
class Python2CXX(CXXCoder):
372
def typedef_by_value(self,v):
374
if T not in self.used:
378
def function_by_signature(self,signature):
379
descriptor = functiondefs[signature]
380
if descriptor.return_type not in self.used:
381
self.used.append(descriptor.return_type)
384
def __init__(self,f,signature,name=None):
385
# Make sure function is a function
386
assert inspect.isfunction(f)
387
# and check the input type signature
388
assert reduce(lambda x,y: x and y,
389
map(lambda x: isinstance(x,Type_Descriptor),
391
1),'%s not all type objects'%signature
393
self.customize = weave.base_info.custom_info()
395
CXXCoder.__init__(self,f,signature,name)
399
def function_code(self):
400
code = self.wrapped_code()
402
if T != None and T.module_init_code:
403
self.customize.add_module_init_code(T.module_init_code)
406
def python_function_definition_code(self):
407
return '{ "%s", wrapper_%s, METH_VARARGS, %s },\n'%(
410
CStr(self.function.__doc__))