# Tests for the Python interfaces to the Fortran BLAS wrappers (fblas).
#
# The tests exercise the interfaces more than the underlying BLAS itself,
# so only very small inputs are checked (N=3 or so).
#
# !! Complex calculations really aren't checked that carefully.
# !! Most tests use real-valued complex numbers; only the ?gemv and complex
# !! ?ger tests use values with nonzero imaginary parts.
# Star import first so the explicit numpy names below take precedence.
from numpy.testing import *
from numpy import (arange, array, common_type, complex64, complex128,
                   conjugate, float32, float64, newaxis, shape, transpose,
                   zeros)

from blas import fblas
# Decimal accuracy to require between Python and LAPACK/BLAS calculations.

# numpy.dot likely uses the same BLAS that is under test, so reference
# results are computed with this pure-Python routine instead.
def matrixmultiply(a, b):
    """Return the matrix product of ``a`` and ``b`` using plain Python loops.

    The product is computed element by element on purpose, so that the
    reference result does not depend on the BLAS library being tested.

    Parameters
    ----------
    a : 2-D array of shape (m, k)
    b : 1-D array of shape (k,) or 2-D array of shape (k, n)

    Returns
    -------
    Array of dtype ``common_type(a, b)`` with shape (m,) when ``b`` is 1-D,
    otherwise shape (m, n).

    Raises
    ------
    ValueError
        If the inner dimensions of ``a`` and ``b`` do not match.
    """
    b_shape = b.shape
    b_is_vector = len(b_shape) == 1
    if b_is_vector:
        # Treat a vector as a single-column matrix; flattened again below.
        b = b[:, newaxis]
    # An explicit check (rather than ``assert``) still runs under ``python -O``.
    if a.shape[1] != b.shape[0]:
        raise ValueError("matrixmultiply: shapes %s and %s are not aligned"
                         % (a.shape, b_shape))
    c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
    for i in range(a.shape[0]):
        for j in range(b.shape[1]):
            s = 0
            for k in range(a.shape[1]):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    if b_is_vector:
        c = c.reshape((a.shape[0],))
    return c
##################################################
### Test blas ?axpy  (y := a*x + y)


class base_axpy(NumpyTestCase):
    """Interface checks shared by the ?axpy wrappers (y := a*x + y).

    Concrete subclasses set ``blas_func`` to the fblas routine under test and
    ``dtype`` to the numpy scalar type it operates on.  Every check calls
    ``blas_func`` for its side effect and then inspects ``y``, i.e. the
    wrapper is expected to update ``y`` in place.
    """

    def check_default_a(self):
        """Omitting ``a`` is expected to behave like a == 1."""
        x = arange(3., dtype=self.dtype)
        y = arange(3., dtype=x.dtype)
        real_y = x * 1. + y
        self.blas_func(x, y)
        assert_array_almost_equal(real_y, y)

    def check_simple(self):
        """Contiguous x and y with an explicit scale factor."""
        x = arange(3., dtype=self.dtype)
        y = arange(3., dtype=x.dtype)
        real_y = x * 3. + y
        self.blas_func(x, y, a=3.)
        assert_array_almost_equal(real_y, y)

    def check_x_stride(self):
        """incx=2: the routine reads x[0], x[2], x[4]."""
        x = arange(6., dtype=self.dtype)
        y = arange(3., dtype=x.dtype)
        real_y = x[::2] * 3. + y
        self.blas_func(x, y, a=3., n=3, incx=2)
        assert_array_almost_equal(real_y, y)

    def check_y_stride(self):
        """incy=2: the routine updates y[0], y[2], y[4]."""
        x = arange(3., dtype=self.dtype)
        y = zeros(6, x.dtype)
        real_y = x * 3. + y[::2]
        self.blas_func(x, y, a=3., n=3, incy=2)
        assert_array_almost_equal(real_y, y[::2])

    def check_x_and_y_stride(self):
        """Different strides on x (4) and y (2) at the same time."""
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        real_y = x[::4] * 3. + y[::2]
        self.blas_func(x, y, a=3., n=3, incx=4, incy=2)
        assert_array_almost_equal(real_y, y[::2])

    def check_x_bad_size(self):
        """n=4 with incx=5 would read x[15]; x has only 12 elements."""
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        try:
            self.blas_func(x, y, n=4, incx=5)
        except Exception:
            # The exact exception type raised by the wrapper is not pinned
            # down here; any error means the bad size was rejected.
            return
        # Raised outside the try so the except clause cannot swallow it, and
        # not an ``assert`` so it still fires under ``python -O``.
        raise AssertionError("blas_func accepted an x too short for n/incx")

    def check_y_bad_size(self):
        """n=3 with incy=5 would touch y[10]; y has only 6 elements."""
        # Use self.dtype rather than a fixed complex64, so the size of y is
        # the only thing wrong with the call (no dtype conversion involved).
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        try:
            self.blas_func(x, y, n=3, incy=5)
        except Exception:
            return
        raise AssertionError("blas_func accepted a y too short for n/incy")


# Concrete ?axpy test cases.  Only the single-precision wrappers (s*, c*)
# are treated as optional: if fblas lacks one, an empty placeholder class
# with the same name is defined instead of failing at import time.
try:
    class test_saxpy(base_axpy):
        blas_func = fblas.saxpy
        dtype = float32
except AttributeError:
    class test_saxpy: pass


class test_daxpy(base_axpy):
    blas_func = fblas.daxpy
    dtype = float64


try:
    class test_caxpy(base_axpy):
        blas_func = fblas.caxpy
        dtype = complex64
except AttributeError:
    class test_caxpy: pass


class test_zaxpy(base_axpy):
    blas_func = fblas.zaxpy
    dtype = complex128


##################################################
### Test blas ?scal  (x := alpha*x)


class base_scal(NumpyTestCase):
    """Interface checks shared by the ?scal wrappers (x := alpha*x).

    Concrete subclasses set ``blas_func`` and ``dtype``.  The checks inspect
    ``x`` after the call, i.e. the wrapper is expected to scale ``x`` in
    place.
    """

    def check_simple(self):
        """Scale a contiguous 3-element vector by 3."""
        x = arange(3., dtype=self.dtype)
        real_x = x * 3.
        self.blas_func(3., x)
        assert_array_almost_equal(real_x, x)

    def check_x_stride(self):
        """incx=2: only x[0], x[2], x[4] are scaled; the rest is untouched."""
        x = arange(6., dtype=self.dtype)
        real_x = x.copy()
        real_x[::2] = x[::2] * array(3., self.dtype)
        self.blas_func(3., x, n=3, incx=2)
        assert_array_almost_equal(real_x, x)

    def check_x_bad_size(self):
        """n=4 with incx=5 would touch x[15]; x has only 12 elements."""
        x = arange(12., dtype=self.dtype)
        try:
            self.blas_func(2., x, n=4, incx=5)
        except Exception:
            # The exact exception type raised by the wrapper is not pinned
            # down here; any error means the bad size was rejected.
            return
        # Raised outside the try so the except clause cannot swallow it, and
        # not an ``assert`` so it still fires under ``python -O``.
        raise AssertionError("blas_func accepted an x too short for n/incx")


# Concrete ?scal test cases.  Only the single-precision wrappers (s*, c*)
# are treated as optional: if fblas lacks one, an empty placeholder class
# with the same name is defined instead of failing at import time.
try:
    class test_sscal(base_scal):
        blas_func = fblas.sscal
        dtype = float32
except AttributeError:
    class test_sscal: pass


class test_dscal(base_scal):
    blas_func = fblas.dscal
    dtype = float64


try:
    class test_cscal(base_scal):
        blas_func = fblas.cscal
        dtype = complex64
except AttributeError:
    class test_cscal: pass


class test_zscal(base_scal):
    blas_func = fblas.zscal
    dtype = complex128


##################################################
### Test blas ?copy  (y := x)


class base_copy(NumpyTestCase):
    """Interface checks shared by the ?copy wrappers (y := x).

    Concrete subclasses set ``blas_func`` and ``dtype``.  The checks inspect
    ``y`` after the call, i.e. the wrapper is expected to fill ``y`` in
    place.
    """

    def check_simple(self):
        """Copy a contiguous 3-element vector."""
        x = arange(3., dtype=self.dtype)
        y = zeros(shape(x), x.dtype)
        self.blas_func(x, y)
        assert_array_almost_equal(x, y)

    def check_x_stride(self):
        """incx=2: x[0], x[2], x[4] are copied into y."""
        x = arange(6., dtype=self.dtype)
        y = zeros(3, x.dtype)
        self.blas_func(x, y, n=3, incx=2)
        assert_array_almost_equal(x[::2], y)

    def check_y_stride(self):
        """incy=2: x is copied into y[0], y[2], y[4]."""
        x = arange(3., dtype=self.dtype)
        y = zeros(6, x.dtype)
        self.blas_func(x, y, n=3, incy=2)
        assert_array_almost_equal(x, y[::2])

    def check_x_and_y_stride(self):
        """Different strides on x (4) and y (2) at the same time."""
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        self.blas_func(x, y, n=3, incx=4, incy=2)
        assert_array_almost_equal(x[::4], y[::2])

    def check_x_bad_size(self):
        """n=4 with incx=5 would read x[15]; x has only 12 elements."""
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        try:
            self.blas_func(x, y, n=4, incx=5)
        except Exception:
            # The exact exception type raised by the wrapper is not pinned
            # down here; any error means the bad size was rejected.
            return
        # Raised outside the try so the except clause cannot swallow it, and
        # not an ``assert`` so it still fires under ``python -O``.
        raise AssertionError("blas_func accepted an x too short for n/incx")

    def check_y_bad_size(self):
        """n=3 with incy=5 would write y[10]; y has only 6 elements."""
        # Use self.dtype rather than a fixed complex64, so the size of y is
        # the only thing wrong with the call (no dtype conversion involved).
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        try:
            self.blas_func(x, y, n=3, incy=5)
        except Exception:
            return
        raise AssertionError("blas_func accepted a y too short for n/incy")

    # TODO: decide what copying into a y whose dtype differs from
    # self.dtype (e.g. a default float64 zeros array) should do, then add a
    # check_y_bad_type test for it.


# Concrete ?copy test cases.  Only the single-precision wrappers (s*, c*)
# are treated as optional: if fblas lacks one, an empty placeholder class
# with the same name is defined instead of failing at import time.
try:
    class test_scopy(base_copy):
        blas_func = fblas.scopy
        dtype = float32
except AttributeError:
    class test_scopy: pass


class test_dcopy(base_copy):
    blas_func = fblas.dcopy
    dtype = float64


try:
    class test_ccopy(base_copy):
        blas_func = fblas.ccopy
        dtype = complex64
except AttributeError:
    class test_ccopy: pass


class test_zcopy(base_copy):
    blas_func = fblas.zcopy
    dtype = complex128


##################################################
### Test blas ?swap  (x <-> y)


class base_swap(NumpyTestCase):
    """Interface checks shared by the ?swap wrappers (exchange x and y).

    Concrete subclasses set ``blas_func`` and ``dtype``.  The expected values
    are copies taken before the call, and both arrays are inspected
    afterwards.  The wrapper is therefore expected to modify x and y in
    place.
    """

    def check_simple(self):
        """Swap two contiguous 3-element vectors."""
        x = arange(3., dtype=self.dtype)
        y = zeros(shape(x), x.dtype)
        desired_x = y.copy()
        desired_y = x.copy()
        self.blas_func(x, y)
        assert_array_almost_equal(desired_x, x)
        assert_array_almost_equal(desired_y, y)

    def check_x_stride(self):
        """incx=2: only x[0], x[2], x[4] are exchanged with y."""
        x = arange(6., dtype=self.dtype)
        y = zeros(3, x.dtype)
        desired_x = y.copy()
        desired_y = x.copy()[::2]
        self.blas_func(x, y, n=3, incx=2)
        assert_array_almost_equal(desired_x, x[::2])
        assert_array_almost_equal(desired_y, y)

    def check_y_stride(self):
        """incy=2: x is exchanged with y[0], y[2], y[4]."""
        x = arange(3., dtype=self.dtype)
        y = zeros(6, x.dtype)
        desired_x = y.copy()[::2]
        desired_y = x.copy()
        self.blas_func(x, y, n=3, incy=2)
        assert_array_almost_equal(desired_x, x)
        assert_array_almost_equal(desired_y, y[::2])

    def check_x_and_y_stride(self):
        """Different strides on x (4) and y (2) at the same time."""
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        desired_x = y.copy()[::2]
        desired_y = x.copy()[::4]
        self.blas_func(x, y, n=3, incx=4, incy=2)
        assert_array_almost_equal(desired_x, x[::4])
        assert_array_almost_equal(desired_y, y[::2])

    def check_x_bad_size(self):
        """n=4 with incx=5 would touch x[15]; x has only 12 elements."""
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        try:
            self.blas_func(x, y, n=4, incx=5)
        except Exception:
            # The exact exception type raised by the wrapper is not pinned
            # down here; any error means the bad size was rejected.
            return
        # Raised outside the try so the except clause cannot swallow it, and
        # not an ``assert`` so it still fires under ``python -O``.
        raise AssertionError("blas_func accepted an x too short for n/incx")

    def check_y_bad_size(self):
        """n=3 with incy=5 would touch y[10]; y has only 6 elements."""
        # Use self.dtype rather than a fixed complex64, so the size of y is
        # the only thing wrong with the call (no dtype conversion involved).
        x = arange(12., dtype=self.dtype)
        y = zeros(6, x.dtype)
        try:
            self.blas_func(x, y, n=3, incy=5)
        except Exception:
            return
        raise AssertionError("blas_func accepted a y too short for n/incy")


# Concrete ?swap test cases.  Only the single-precision wrappers (s*, c*)
# are treated as optional: if fblas lacks one, an empty placeholder class
# with the same name is defined instead of failing at import time.
try:
    class test_sswap(base_swap):
        blas_func = fblas.sswap
        dtype = float32
except AttributeError:
    class test_sswap: pass


class test_dswap(base_swap):
    blas_func = fblas.dswap
    dtype = float64


try:
    class test_cswap(base_swap):
        blas_func = fblas.cswap
        dtype = complex64
except AttributeError:
    class test_cswap: pass


class test_zswap(base_swap):
    blas_func = fblas.zswap
    dtype = complex128


##################################################
### Test blas ?gemv  (y := alpha*op(a)*x + beta*y)
### Testing every argument combination would be a mess; these checks cover
### the plain, transposed, conjugate-transposed and strided cases.


class base_gemv(NumpyTestCase):
    """Interface checks shared by the ?gemv wrappers.

    Concrete subclasses set ``blas_func`` and ``dtype``.  The wrapper's
    return value is taken as the result.  ``trans`` selects op(a): 0 means
    ``a``, 1 means ``transpose(a)`` and 2 means ``transpose(conjugate(a))``.
    Expected values come from the pure-Python ``matrixmultiply``.
    """

    def get_data(self, x_stride=1, y_stride=1):
        """Return ``(alpha, beta, a, x, y)`` for a 3x3 problem.

        ``a`` is random normal.  ``x`` and ``y`` are ``arange`` vectors of
        length ``3*x_stride`` and ``3*y_stride``, so a strided call still
        sees 3 logical elements.  For complex dtypes every operand is scaled
        by 1+1j so that imaginary parts are exercised.
        """
        mult = array(1, dtype=self.dtype)
        if self.dtype in [complex64, complex128]:
            mult = array(1+1j, dtype=self.dtype)
        from numpy.random import normal
        alpha = array(1., dtype=self.dtype) * mult
        beta = array(1., dtype=self.dtype) * mult
        a = normal(0., 1., (3, 3)).astype(self.dtype) * mult
        x = arange(shape(a)[0] * x_stride, dtype=self.dtype) * mult
        y = arange(shape(a)[1] * y_stride, dtype=self.dtype) * mult
        return alpha, beta, a, x, y

    def check_simple(self):
        """y := alpha*a*x + beta*y with contiguous vectors."""
        alpha, beta, a, x, y = self.get_data()
        desired_y = alpha * matrixmultiply(a, x) + beta * y
        y = self.blas_func(alpha, a, x, beta, y)
        assert_array_almost_equal(desired_y, y)

    def check_default_beta_y(self):
        """Omitting beta and y is expected to give just alpha*a*x (alpha=1)."""
        alpha, beta, a, x, y = self.get_data()
        desired_y = matrixmultiply(a, x)
        y = self.blas_func(1, a, x)
        assert_array_almost_equal(desired_y, y)

    def check_simple_transpose(self):
        """trans=1 uses transpose(a)."""
        alpha, beta, a, x, y = self.get_data()
        desired_y = alpha * matrixmultiply(transpose(a), x) + beta * y
        y = self.blas_func(alpha, a, x, beta, y, trans=1)
        assert_array_almost_equal(desired_y, y)

    def check_simple_transpose_conj(self):
        """trans=2 uses the conjugate transpose of a."""
        alpha, beta, a, x, y = self.get_data()
        desired_y = (alpha * matrixmultiply(transpose(conjugate(a)), x)
                     + beta * y)
        y = self.blas_func(alpha, a, x, beta, y, trans=2)
        assert_array_almost_equal(desired_y, y)

    def check_x_stride(self):
        """incx=2 reads x[0], x[2], x[4]."""
        alpha, beta, a, x, y = self.get_data(x_stride=2)
        desired_y = alpha * matrixmultiply(a, x[::2]) + beta * y
        y = self.blas_func(alpha, a, x, beta, y, incx=2)
        assert_array_almost_equal(desired_y, y)

    def check_x_stride_transpose(self):
        """incx=2 combined with trans=1."""
        alpha, beta, a, x, y = self.get_data(x_stride=2)
        desired_y = alpha * matrixmultiply(transpose(a), x[::2]) + beta * y
        y = self.blas_func(alpha, a, x, beta, y, trans=1, incx=2)
        assert_array_almost_equal(desired_y, y)

    def check_x_stride_assert(self):
        """An incx that runs past the end of x must be rejected.

        get_data(x_stride=2) yields a 6-element x, but n=3 with incx=3 would
        read x[6].  Both the plain and the transposed call must fail.
        """
        alpha, beta, a, x, y = self.get_data(x_stride=2)
        for trans in (0, 1):
            try:
                self.blas_func(1, a, x, 1, y, trans=trans, incx=3)
            except Exception:
                continue
            # Raised outside the try: an ``assert(0)`` inside it would be
            # swallowed by the except clause and the test could never fail.
            raise AssertionError(
                "blas_func accepted incx=3 for a 6-element x (trans=%d)"
                % trans)

    def check_y_stride(self):
        """incy=2 updates y[0], y[2], y[4] and leaves the rest alone."""
        alpha, beta, a, x, y = self.get_data(y_stride=2)
        desired_y = y.copy()
        desired_y[::2] = alpha * matrixmultiply(a, x) + beta * y[::2]
        y = self.blas_func(alpha, a, x, beta, y, incy=2)
        assert_array_almost_equal(desired_y, y)

    def check_y_stride_transpose(self):
        """incy=2 combined with trans=1."""
        alpha, beta, a, x, y = self.get_data(y_stride=2)
        desired_y = y.copy()
        desired_y[::2] = (alpha * matrixmultiply(transpose(a), x)
                          + beta * y[::2])
        y = self.blas_func(alpha, a, x, beta, y, trans=1, incy=2)
        assert_array_almost_equal(desired_y, y)

    def check_y_stride_assert(self):
        """An incy that runs past the end of y must be rejected.

        get_data(y_stride=2) yields a 6-element y, but n=3 with incy=3 would
        touch y[6].  Both the plain and the transposed call must fail.
        """
        alpha, beta, a, x, y = self.get_data(y_stride=2)
        for trans in (0, 1):
            try:
                self.blas_func(1, a, x, 1, y, trans=trans, incy=3)
            except Exception:
                continue
            # Raised outside the try so the except clause cannot swallow it.
            raise AssertionError(
                "blas_func accepted incy=3 for a 6-element y (trans=%d)"
                % trans)


# Concrete ?gemv test cases.  Only the single-precision wrappers (s*, c*)
# are treated as optional: if fblas lacks one, an empty placeholder class
# with the same name is defined instead of failing at import time.
try:
    class test_sgemv(base_gemv):
        blas_func = fblas.sgemv
        dtype = float32
except AttributeError:
    class test_sgemv: pass


class test_dgemv(base_gemv):
    blas_func = fblas.dgemv
    dtype = float64


try:
    class test_cgemv(base_gemv):
        blas_func = fblas.cgemv
        dtype = complex64
except AttributeError:
    class test_cgemv: pass


class test_zgemv(base_gemv):
    blas_func = fblas.zgemv
    dtype = complex128


##################################################
### Test blas ?ger  (a := alpha*x*y^T + a, a rank-1 update)
### Testing every argument combination would be a mess; these checks cover
### the simple and strided cases.


class base_ger(NumpyTestCase):
    """Interface checks shared by the real ?ger wrappers.

    Concrete subclasses set ``blas_func`` and ``dtype``.  The checks inspect
    ``a`` after the call, i.e. the wrapper is expected to update ``a`` in
    place.
    """

    def get_data(self, x_stride=1, y_stride=1):
        """Return ``(alpha, a, x, y)`` for a 3x3 problem with alpha == 1.

        ``x`` and ``y`` have lengths ``3*x_stride`` and ``3*y_stride``, so a
        strided call still sees 3 logical elements.
        """
        from numpy.random import normal
        alpha = array(1., dtype=self.dtype)
        a = normal(0., 1., (3, 3)).astype(self.dtype)
        x = arange(shape(a)[0] * x_stride, dtype=self.dtype)
        y = arange(shape(a)[1] * y_stride, dtype=self.dtype)
        return alpha, a, x, y

    def check_simple(self):
        """Rank-1 update with contiguous x and y."""
        alpha, a, x, y = self.get_data()
        # transpose takes care of Fortran vs. C (and Python) memory layout.
        desired_a = alpha * transpose(x[:, newaxis] * y) + a
        self.blas_func(x, y, a)
        assert_array_almost_equal(desired_a, a)

    def check_x_stride(self):
        """incx=2 reads x[0], x[2], x[4]."""
        alpha, a, x, y = self.get_data(x_stride=2)
        desired_a = alpha * transpose(x[::2, newaxis] * y) + a
        self.blas_func(x, y, a, incx=2)
        assert_array_almost_equal(desired_a, a)

    def check_x_stride_assert(self):
        """incx=3 over the 6-element x from get_data(x_stride=2) must fail."""
        alpha, a, x, y = self.get_data(x_stride=2)
        try:
            self.blas_func(x, y, a, incx=3)
        except Exception:
            return
        # Raised outside the try: an ``assert(0)`` inside it would be
        # swallowed by the except clause and the test could never fail.
        raise AssertionError("blas_func accepted incx=3 for a 6-element x")

    def check_y_stride(self):
        """incy=2 reads y[0], y[2], y[4]."""
        alpha, a, x, y = self.get_data(y_stride=2)
        desired_a = alpha * transpose(x[:, newaxis] * y[::2]) + a
        self.blas_func(x, y, a, incy=2)
        assert_array_almost_equal(desired_a, a)

    def check_y_stride_assert(self):
        """incy=3 over the 6-element y from get_data(y_stride=2) must fail."""
        alpha, a, x, y = self.get_data(y_stride=2)
        try:
            # Same (x, y, a) argument order as every other ?ger call here,
            # so the stride is the only thing wrong with the call.
            self.blas_func(x, y, a, incy=3)
        except Exception:
            return
        raise AssertionError("blas_func accepted incy=3 for a 6-element y")


# Concrete real ?ger test cases.
class test_sger(base_ger):
    blas_func = fblas.sger
    dtype = float32


class test_dger(base_ger):
    blas_func = fblas.dger
    dtype = float64


##################################################
### Test blas ?geru / ?gerc  (complex rank-1 updates)
### Testing every argument combination would be a mess; only the simple
### case is specialised here.


class base_ger_complex(base_ger):
    """Complex-valued variant of the ?ger checks, shared by ?geru and ?gerc.

    Concrete subclasses set ``blas_func`` and ``dtype`` and also define
    ``transform(y)``.  It is applied to y when building the expected result;
    presumably it is the identity for ?geru and the complex conjugate for
    ?gerc.
    """

    # NOTE(review): the stride checks are inherited unchanged from base_ger.
    # Those checks neither pass ``alpha`` (1+1j here) nor apply transform(),
    # so they do not model the complex routines correctly yet.

    def get_data(self, x_stride=1, y_stride=1):
        """Return ``(alpha, a, x, y)`` with nonzero imaginary parts.

        - alpha is 1+1j.
        - ``a`` has independent random real and imaginary parts.
        - ``x`` and ``y`` are random, with imaginary part equal to real part.
        """
        from numpy.random import normal
        alpha = array(1+1j, dtype=self.dtype)
        a = normal(0., 1., (3, 3)).astype(self.dtype)
        a = a + normal(0., 1., (3, 3)) * array(1j, dtype=self.dtype)
        x = normal(0., 1., shape(a)[0] * x_stride).astype(self.dtype)
        x = x + x * array(1j, dtype=self.dtype)
        y = normal(0., 1., shape(a)[1] * y_stride).astype(self.dtype)
        y = y + y * array(1j, dtype=self.dtype)
        return alpha, a, x, y

    def check_simple(self):
        """Rank-1 update of a zero matrix with alpha = 1+1j."""
        alpha, a, x, y = self.get_data()
        # Start from a zero matrix so the result is just the rank-1 term.
        a = a * array(0., dtype=self.dtype)
        # transpose takes care of Fortran vs. C (and Python) memory layout.
        desired_a = alpha * transpose(x[:, newaxis] * self.transform(y)) + a
        # Call the routine under test rather than a hard-coded fblas.cgeru,
        # so each concrete subclass exercises its own wrapper.
        self.blas_func(x, y, a, alpha=alpha)
        assert_array_almost_equal(desired_a, a)


# Concrete complex ?ger test cases.  transform() is applied to y when the
# expected result is built in base_ger_complex.check_simple.
class test_cgeru(base_ger_complex):
    blas_func = fblas.cgeru
    dtype = complex64

    def transform(self, x):
        """?geru is the unconjugated update, so y is used as-is."""
        return x


class test_zgeru(base_ger_complex):
    blas_func = fblas.zgeru
    dtype = complex128

    def transform(self, x):
        """?geru is the unconjugated update, so y is used as-is."""
        return x


class test_cgerc(base_ger_complex):
    blas_func = fblas.cgerc
    dtype = complex64

    def transform(self, x):
        """?gerc uses the complex conjugate of y."""
        return conjugate(x)


class test_zgerc(base_ger_complex):
    blas_func = fblas.zgerc
    dtype = complex128

    def transform(self, x):
        """?gerc uses the complex conjugate of y."""
        return conjugate(x)


520
if __name__ == "__main__":