~ubuntu-branches/ubuntu/saucy/python-scipy/saucy

« back to all changes in this revision

Viewing changes to scipy/lib/blas/tests/test_fblas.py

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Test interfaces to fortran blas.
 
2
#
 
3
# The tests are more of interface than they are of the underlying blas.
 
4
# Only very small matrices checked -- N=3 or so.
 
5
#
 
6
# !! Complex calculations really aren't checked that carefully.
 
7
# !! Only real valued complex numbers are used in tests.
 
8
 
 
9
from numpy import *
 
10
 
 
11
import sys
 
12
from numpy.testing import *
 
13
set_package_path()
 
14
from blas import fblas
 
15
restore_path()
 
16
 
 
17
#decimal accuracy to require between Python and LAPACK/BLAS calculations
 
18
accuracy = 5
 
19
 
 
20
# Since numpy.dot likely uses the same blas, use this routine
 
21
# to check.
 
22
def matrixmultiply(a, b):
 
23
    if len(b.shape) == 1:
 
24
        b_is_vector = True
 
25
        b = b[:,newaxis]
 
26
    else:
 
27
        b_is_vector = False
 
28
    assert a.shape[1] == b.shape[0]
 
29
    c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
 
30
    for i in xrange(a.shape[0]):
 
31
        for j in xrange(b.shape[1]):
 
32
            s = 0
 
33
            for k in xrange(a.shape[1]):
 
34
                s += a[i,k] * b[k, j]
 
35
            c[i,j] = s
 
36
    if b_is_vector:
 
37
        c = c.reshape((a.shape[0],))
 
38
    return c
 
39
 
 
40
##################################################
 
41
### Test blas ?axpy
 
42
 
 
43
class base_axpy(NumpyTestCase):
 
44
    def check_default_a(self):
 
45
        x = arange(3.,dtype=self.dtype)
 
46
        y = arange(3.,dtype=x.dtype)
 
47
        real_y = x*1.+y
 
48
        self.blas_func(x,y)
 
49
        assert_array_almost_equal(real_y,y)
 
50
    def check_simple(self):
 
51
        x = arange(3.,dtype=self.dtype)
 
52
        y = arange(3.,dtype=x.dtype)
 
53
        real_y = x*3.+y
 
54
        self.blas_func(x,y,a=3.)
 
55
        assert_array_almost_equal(real_y,y)
 
56
    def check_x_stride(self):
 
57
        x = arange(6.,dtype=self.dtype)
 
58
        y = zeros(3,x.dtype)
 
59
        y = arange(3.,dtype=x.dtype)
 
60
        real_y = x[::2]*3.+y
 
61
        self.blas_func(x,y,a=3.,n=3,incx=2)
 
62
        assert_array_almost_equal(real_y,y)
 
63
    def check_y_stride(self):
 
64
        x = arange(3.,dtype=self.dtype)
 
65
        y = zeros(6,x.dtype)
 
66
        real_y = x*3.+y[::2]
 
67
        self.blas_func(x,y,a=3.,n=3,incy=2)
 
68
        assert_array_almost_equal(real_y,y[::2])
 
69
    def check_x_and_y_stride(self):
 
70
        x = arange(12.,dtype=self.dtype)
 
71
        y = zeros(6,x.dtype)
 
72
        real_y = x[::4]*3.+y[::2]
 
73
        self.blas_func(x,y,a=3.,n=3,incx=4,incy=2)
 
74
        assert_array_almost_equal(real_y,y[::2])
 
75
    def check_x_bad_size(self):
 
76
        x = arange(12.,dtype=self.dtype)
 
77
        y = zeros(6,x.dtype)
 
78
        try:
 
79
            self.blas_func(x,y,n=4,incx=5)
 
80
        except: # what kind of error should be caught?
 
81
            return
 
82
        # should catch error and never get here
 
83
        assert(0)
 
84
    def check_y_bad_size(self):
 
85
        x = arange(12.,dtype=complex64)
 
86
        y = zeros(6,x.dtype)
 
87
        try:
 
88
            self.blas_func(x,y,n=3,incy=5)
 
89
        except: # what kind of error should be caught?
 
90
            return
 
91
        # should catch error and never get here
 
92
        assert(0)
 
93
 
 
94
try:
 
95
    class test_saxpy(base_axpy):
 
96
        blas_func = fblas.saxpy
 
97
        dtype = float32
 
98
except AttributeError:
 
99
    class test_saxpy: pass
 
100
class test_daxpy(base_axpy):
 
101
    blas_func = fblas.daxpy
 
102
    dtype = float64
 
103
try:
 
104
    class test_caxpy(base_axpy):
 
105
        blas_func = fblas.caxpy
 
106
        dtype = complex64
 
107
except AttributeError:
 
108
    class test_caxpy: pass
 
109
class test_zaxpy(base_axpy):
 
110
    blas_func = fblas.zaxpy
 
111
    dtype = complex128
 
112
 
 
113
 
 
114
##################################################
 
115
### Test blas ?scal
 
116
 
 
117
class base_scal(NumpyTestCase):
 
118
    def check_simple(self):
 
119
        x = arange(3.,dtype=self.dtype)
 
120
        real_x = x*3.
 
121
        self.blas_func(3.,x)
 
122
        assert_array_almost_equal(real_x,x)
 
123
    def check_x_stride(self):
 
124
        x = arange(6.,dtype=self.dtype)
 
125
        real_x = x.copy()
 
126
        real_x[::2] = x[::2]*array(3.,self.dtype)
 
127
        self.blas_func(3.,x,n=3,incx=2)
 
128
        assert_array_almost_equal(real_x,x)
 
129
    def check_x_bad_size(self):
 
130
        x = arange(12.,dtype=self.dtype)
 
131
        try:
 
132
            self.blas_func(2.,x,n=4,incx=5)
 
133
        except: # what kind of error should be caught?
 
134
            return
 
135
        # should catch error and never get here
 
136
        assert(0)
 
137
try:
 
138
    class test_sscal(base_scal):
 
139
        blas_func = fblas.sscal
 
140
        dtype = float32
 
141
except AttributeError:
 
142
    class test_sscal: pass
 
143
class test_dscal(base_scal):
 
144
    blas_func = fblas.dscal
 
145
    dtype = float64
 
146
try:
 
147
    class test_cscal(base_scal):
 
148
        blas_func = fblas.cscal
 
149
        dtype = complex64
 
150
except AttributeError:
 
151
    class test_cscal: pass
 
152
class test_zscal(base_scal):
 
153
    blas_func = fblas.zscal
 
154
    dtype = complex128
 
155
 
 
156
 
 
157
 
 
158
 
 
159
##################################################
 
160
### Test blas ?copy
 
161
 
 
162
class base_copy(NumpyTestCase):
 
163
    def check_simple(self):
 
164
        x = arange(3.,dtype=self.dtype)
 
165
        y = zeros(shape(x),x.dtype)
 
166
        self.blas_func(x,y)
 
167
        assert_array_almost_equal(x,y)
 
168
    def check_x_stride(self):
 
169
        x = arange(6.,dtype=self.dtype)
 
170
        y = zeros(3,x.dtype)
 
171
        self.blas_func(x,y,n=3,incx=2)
 
172
        assert_array_almost_equal(x[::2],y)
 
173
    def check_y_stride(self):
 
174
        x = arange(3.,dtype=self.dtype)
 
175
        y = zeros(6,x.dtype)
 
176
        self.blas_func(x,y,n=3,incy=2)
 
177
        assert_array_almost_equal(x,y[::2])
 
178
    def check_x_and_y_stride(self):
 
179
        x = arange(12.,dtype=self.dtype)
 
180
        y = zeros(6,x.dtype)
 
181
        self.blas_func(x,y,n=3,incx=4,incy=2)
 
182
        assert_array_almost_equal(x[::4],y[::2])
 
183
    def check_x_bad_size(self):
 
184
        x = arange(12.,dtype=self.dtype)
 
185
        y = zeros(6,x.dtype)
 
186
        try:
 
187
            self.blas_func(x,y,n=4,incx=5)
 
188
        except: # what kind of error should be caught?
 
189
            return
 
190
        # should catch error and never get here
 
191
        assert(0)
 
192
    def check_y_bad_size(self):
 
193
        x = arange(12.,dtype=complex64)
 
194
        y = zeros(6,x.dtype)
 
195
        try:
 
196
            self.blas_func(x,y,n=3,incy=5)
 
197
        except: # what kind of error should be caught?
 
198
            return
 
199
        # should catch error and never get here
 
200
        assert(0)
 
201
    #def check_y_bad_type(self):
 
202
    ##   Hmmm. Should this work?  What should be the output.
 
203
    #    x = arange(3.,dtype=self.dtype)
 
204
    #    y = zeros(shape(x))
 
205
    #    self.blas_func(x,y)
 
206
    #    assert_array_almost_equal(x,y)
 
207
 
 
208
try:
 
209
    class test_scopy(base_copy):
 
210
        blas_func = fblas.scopy
 
211
        dtype = float32
 
212
except AttributeError:
 
213
    class test_scopy: pass
 
214
class test_dcopy(base_copy):
 
215
    blas_func = fblas.dcopy
 
216
    dtype = float64
 
217
try:
 
218
    class test_ccopy(base_copy):
 
219
        blas_func = fblas.ccopy
 
220
        dtype = complex64
 
221
except AttributeError:
 
222
    class test_ccopy: pass
 
223
class test_zcopy(base_copy):
 
224
    blas_func = fblas.zcopy
 
225
    dtype = complex128
 
226
 
 
227
 
 
228
##################################################
 
229
### Test blas ?swap
 
230
 
 
231
class base_swap(NumpyTestCase):
 
232
    def check_simple(self):
 
233
        x = arange(3.,dtype=self.dtype)
 
234
        y = zeros(shape(x),x.dtype)
 
235
        desired_x = y.copy()
 
236
        desired_y = x.copy()
 
237
        self.blas_func(x,y)
 
238
        assert_array_almost_equal(desired_x,x)
 
239
        assert_array_almost_equal(desired_y,y)
 
240
    def check_x_stride(self):
 
241
        x = arange(6.,dtype=self.dtype)
 
242
        y = zeros(3,x.dtype)
 
243
        desired_x = y.copy()
 
244
        desired_y = x.copy()[::2]
 
245
        self.blas_func(x,y,n=3,incx=2)
 
246
        assert_array_almost_equal(desired_x,x[::2])
 
247
        assert_array_almost_equal(desired_y,y)
 
248
    def check_y_stride(self):
 
249
        x = arange(3.,dtype=self.dtype)
 
250
        y = zeros(6,x.dtype)
 
251
        desired_x = y.copy()[::2]
 
252
        desired_y = x.copy()
 
253
        self.blas_func(x,y,n=3,incy=2)
 
254
        assert_array_almost_equal(desired_x,x)
 
255
        assert_array_almost_equal(desired_y,y[::2])
 
256
 
 
257
    def check_x_and_y_stride(self):
 
258
        x = arange(12.,dtype=self.dtype)
 
259
        y = zeros(6,x.dtype)
 
260
        desired_x = y.copy()[::2]
 
261
        desired_y = x.copy()[::4]
 
262
        self.blas_func(x,y,n=3,incx=4,incy=2)
 
263
        assert_array_almost_equal(desired_x,x[::4])
 
264
        assert_array_almost_equal(desired_y,y[::2])
 
265
    def check_x_bad_size(self):
 
266
        x = arange(12.,dtype=self.dtype)
 
267
        y = zeros(6,x.dtype)
 
268
        try:
 
269
            self.blas_func(x,y,n=4,incx=5)
 
270
        except: # what kind of error should be caught?
 
271
            return
 
272
        # should catch error and never get here
 
273
        assert(0)
 
274
    def check_y_bad_size(self):
 
275
        x = arange(12.,dtype=complex64)
 
276
        y = zeros(6,x.dtype)
 
277
        try:
 
278
            self.blas_func(x,y,n=3,incy=5)
 
279
        except: # what kind of error should be caught?
 
280
            return
 
281
        # should catch error and never get here
 
282
        assert(0)
 
283
 
 
284
try:
 
285
    class test_sswap(base_swap):
 
286
        blas_func = fblas.sswap
 
287
        dtype = float32
 
288
except AttributeError:
 
289
    class test_sswap: pass
 
290
class test_dswap(base_swap):
 
291
    blas_func = fblas.dswap
 
292
    dtype = float64
 
293
try:
 
294
    class test_cswap(base_swap):
 
295
        blas_func = fblas.cswap
 
296
        dtype = complex64
 
297
except AttributeError:
 
298
    class test_cswap: pass
 
299
class test_zswap(base_swap):
 
300
    blas_func = fblas.zswap
 
301
    dtype = complex128
 
302
 
 
303
##################################################
 
304
### Test blas ?gemv
 
305
### This will be a mess to test all cases.
 
306
 
 
307
class base_gemv(NumpyTestCase):
 
308
    def get_data(self,x_stride=1,y_stride=1):
 
309
        mult = array(1, dtype = self.dtype)
 
310
        if self.dtype in [complex64, complex128]:
 
311
            mult = array(1+1j, dtype = self.dtype)
 
312
        from numpy.random import normal
 
313
        alpha = array(1., dtype = self.dtype) * mult
 
314
        beta = array(1.,dtype = self.dtype) * mult
 
315
        a = normal(0.,1.,(3,3)).astype(self.dtype) * mult
 
316
        x = arange(shape(a)[0]*x_stride,dtype=self.dtype) * mult
 
317
        y = arange(shape(a)[1]*y_stride,dtype=self.dtype) * mult
 
318
        return alpha,beta,a,x,y
 
319
    def check_simple(self):
 
320
        alpha,beta,a,x,y = self.get_data()
 
321
        desired_y = alpha*matrixmultiply(a,x)+beta*y
 
322
        y = self.blas_func(alpha,a,x,beta,y)
 
323
        assert_array_almost_equal(desired_y,y)
 
324
    def check_default_beta_y(self):
 
325
        alpha,beta,a,x,y = self.get_data()
 
326
        desired_y = matrixmultiply(a,x)
 
327
        y = self.blas_func(1,a,x)
 
328
        assert_array_almost_equal(desired_y,y)
 
329
    def check_simple_transpose(self):
 
330
        alpha,beta,a,x,y = self.get_data()
 
331
        desired_y = alpha*matrixmultiply(transpose(a),x)+beta*y
 
332
        y = self.blas_func(alpha,a,x,beta,y,trans=1)
 
333
        assert_array_almost_equal(desired_y,y)
 
334
    def check_simple_transpose_conj(self):
 
335
        alpha,beta,a,x,y = self.get_data()
 
336
        desired_y = alpha*matrixmultiply(transpose(conjugate(a)),x)+beta*y
 
337
        y = self.blas_func(alpha,a,x,beta,y,trans=2)
 
338
        assert_array_almost_equal(desired_y,y)
 
339
    def check_x_stride(self):
 
340
        alpha,beta,a,x,y = self.get_data(x_stride=2)
 
341
        desired_y = alpha*matrixmultiply(a,x[::2])+beta*y
 
342
        y = self.blas_func(alpha,a,x,beta,y,incx=2)
 
343
        assert_array_almost_equal(desired_y,y)
 
344
    def check_x_stride_transpose(self):
 
345
        alpha,beta,a,x,y = self.get_data(x_stride=2)
 
346
        desired_y = alpha*matrixmultiply(transpose(a),x[::2])+beta*y
 
347
        y = self.blas_func(alpha,a,x,beta,y,trans=1,incx=2)
 
348
        assert_array_almost_equal(desired_y,y)
 
349
    def check_x_stride_assert(self):
 
350
        # What is the use of this test?
 
351
        alpha,beta,a,x,y = self.get_data(x_stride=2)
 
352
        try:
 
353
            y = self.blas_func(1,a,x,1,y,trans=0,incx=3)
 
354
            assert(0)
 
355
        except:
 
356
            pass
 
357
        try:
 
358
            y = self.blas_func(1,a,x,1,y,trans=1,incx=3)
 
359
            assert(0)
 
360
        except:
 
361
            pass
 
362
    def check_y_stride(self):
 
363
        alpha,beta,a,x,y = self.get_data(y_stride=2)
 
364
        desired_y = y.copy()
 
365
        desired_y[::2] = alpha*matrixmultiply(a,x)+beta*y[::2]
 
366
        y = self.blas_func(alpha,a,x,beta,y,incy=2)
 
367
        assert_array_almost_equal(desired_y,y)
 
368
    def check_y_stride_transpose(self):
 
369
        alpha,beta,a,x,y = self.get_data(y_stride=2)
 
370
        desired_y = y.copy()
 
371
        desired_y[::2] = alpha*matrixmultiply(transpose(a),x)+beta*y[::2]
 
372
        y = self.blas_func(alpha,a,x,beta,y,trans=1,incy=2)
 
373
        assert_array_almost_equal(desired_y,y)
 
374
    def check_y_stride_assert(self):
 
375
        # What is the use of this test?
 
376
        alpha,beta,a,x,y = self.get_data(y_stride=2)
 
377
        try:
 
378
            y = self.blas_func(1,a,x,1,y,trans=0,incy=3)
 
379
            assert(0)
 
380
        except:
 
381
            pass
 
382
        try:
 
383
            y = self.blas_func(1,a,x,1,y,trans=1,incy=3)
 
384
            assert(0)
 
385
        except:
 
386
            pass
 
387
 
 
388
try:
 
389
    class test_sgemv(base_gemv):
 
390
        blas_func = fblas.sgemv
 
391
        dtype = float32
 
392
except AttributeError:
 
393
    class test_sgemv: pass
 
394
class test_dgemv(base_gemv):
 
395
    blas_func = fblas.dgemv
 
396
    dtype = float64
 
397
try:
 
398
    class test_cgemv(base_gemv):
 
399
        blas_func = fblas.cgemv
 
400
        dtype = complex64
 
401
except AttributeError:
 
402
    class test_cgemv: pass
 
403
class test_zgemv(base_gemv):
 
404
    blas_func = fblas.zgemv
 
405
    dtype = complex128
 
406
 
 
407
"""
 
408
##################################################
 
409
### Test blas ?ger
 
410
### This will be a mess to test all cases.
 
411
 
 
412
class base_ger(NumpyTestCase):
 
413
    def get_data(self,x_stride=1,y_stride=1):
 
414
        from numpy.random import normal
 
415
        alpha = array(1., dtype = self.dtype)
 
416
        a = normal(0.,1.,(3,3)).astype(self.dtype)
 
417
        x = arange(shape(a)[0]*x_stride,dtype=self.dtype)
 
418
        y = arange(shape(a)[1]*y_stride,dtype=self.dtype)
 
419
        return alpha,a,x,y
 
420
    def check_simple(self):
 
421
        alpha,a,x,y = self.get_data()
 
422
        # tranpose takes care of Fortran vs. C(and Python) memory layout
 
423
        desired_a = alpha*transpose(x[:,newaxis]*y) + a
 
424
        self.blas_func(x,y,a)
 
425
        assert_array_almost_equal(desired_a,a)
 
426
    def check_x_stride(self):
 
427
        alpha,a,x,y = self.get_data(x_stride=2)
 
428
        desired_a = alpha*transpose(x[::2,newaxis]*y) + a
 
429
        self.blas_func(x,y,a,incx=2)
 
430
        assert_array_almost_equal(desired_a,a)
 
431
    def check_x_stride_assert(self):
 
432
        alpha,a,x,y = self.get_data(x_stride=2)
 
433
        try:
 
434
            self.blas_func(x,y,a,incx=3)
 
435
            assert(0)
 
436
        except:
 
437
            pass
 
438
    def check_y_stride(self):
 
439
        alpha,a,x,y = self.get_data(y_stride=2)
 
440
        desired_a = alpha*transpose(x[:,newaxis]*y[::2]) + a
 
441
        self.blas_func(x,y,a,incy=2)
 
442
        assert_array_almost_equal(desired_a,a)
 
443
 
 
444
    def check_y_stride_assert(self):
 
445
        alpha,a,x,y = self.get_data(y_stride=2)
 
446
        try:
 
447
            self.blas_func(a,x,y,incy=3)
 
448
            assert(0)
 
449
        except:
 
450
            pass
 
451
 
 
452
class test_sger(base_ger):
 
453
    blas_func = fblas.sger
 
454
    dtype = float32
 
455
class test_dger(base_ger):
 
456
    blas_func = fblas.dger
 
457
    dtype = float64
 
458
"""
 
459
##################################################
 
460
### Test blas ?gerc
 
461
### This will be a mess to test all cases.
 
462
 
 
463
"""
 
464
class base_ger_complex(base_ger):
 
465
    def get_data(self,x_stride=1,y_stride=1):
 
466
        from numpy.random import normal
 
467
        alpha = array(1+1j, dtype = self.dtype)
 
468
        a = normal(0.,1.,(3,3)).astype(self.dtype)
 
469
        a = a + normal(0.,1.,(3,3)) * array(1j, dtype = self.dtype)
 
470
        x = normal(0.,1.,shape(a)[0]*x_stride).astype(self.dtype)
 
471
        x = x + x * array(1j, dtype = self.dtype)
 
472
        y = normal(0.,1.,shape(a)[1]*y_stride).astype(self.dtype)
 
473
        y = y + y * array(1j, dtype = self.dtype)
 
474
        return alpha,a,x,y
 
475
    def check_simple(self):
 
476
        alpha,a,x,y = self.get_data()
 
477
        # tranpose takes care of Fortran vs. C(and Python) memory layout
 
478
        a = a * array(0.,dtype = self.dtype)
 
479
        #desired_a = alpha*transpose(x[:,newaxis]*self.transform(y)) + a
 
480
        desired_a = alpha*transpose(x[:,newaxis]*y) + a
 
481
        #self.blas_func(x,y,a,alpha = alpha)
 
482
        fblas.cgeru(x,y,a,alpha = alpha)
 
483
        assert_array_almost_equal(desired_a,a)
 
484
 
 
485
    #def check_x_stride(self):
 
486
    #    alpha,a,x,y = self.get_data(x_stride=2)
 
487
    #    desired_a = alpha*transpose(x[::2,newaxis]*self.transform(y)) + a
 
488
    #    self.blas_func(x,y,a,incx=2)
 
489
    #    assert_array_almost_equal(desired_a,a)
 
490
    #def check_y_stride(self):
 
491
    #    alpha,a,x,y = self.get_data(y_stride=2)
 
492
    #    desired_a = alpha*transpose(x[:,newaxis]*self.transform(y[::2])) + a
 
493
    #    self.blas_func(x,y,a,incy=2)
 
494
    #    assert_array_almost_equal(desired_a,a)
 
495
 
 
496
class test_cgeru(base_ger_complex):
 
497
    blas_func = fblas.cgeru
 
498
    dtype = complex64
 
499
    def transform(self,x):
 
500
        return x
 
501
class test_zgeru(base_ger_complex):
 
502
    blas_func = fblas.zgeru
 
503
    dtype = complex128
 
504
    def transform(self,x):
 
505
        return x
 
506
 
 
507
class test_cgerc(base_ger_complex):
 
508
    blas_func = fblas.cgerc
 
509
    dtype = complex64
 
510
    def transform(self,x):
 
511
        return conjugate(x)
 
512
 
 
513
class test_zgerc(base_ger_complex):
 
514
    blas_func = fblas.zgerc
 
515
    dtype = complex128
 
516
    def transform(self,x):
 
517
        return conjugate(x)
 
518
"""
 
519
 
 
520
if __name__ == "__main__":
 
521
    NumpyTest().run()