~ubuntu-branches/ubuntu/raring/python-scipy/raring-proposed

« back to all changes in this revision

Viewing changes to Lib/linalg/tests/test_basic.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2007-01-07 14:12:12 UTC
  • mfrom: (1.1.1 upstream)
  • Revision ID: james.westby@ubuntu.com-20070107141212-mm0ebkh5b37hcpzn
* Remove build dependency on python-numpy-dev.
* python-scipy: Depend on python-numpy instead of python-numpy-dev.
* Package builds on other archs than i386. Closes: #402783.

Show diffs side-by-side

added

removed

Lines of Context:
19
19
  python tests/test_basic.py [<level>]
20
20
"""
21
21
 
22
 
import Numeric
23
 
from Numeric import arange, add, array, dot, zeros, identity
 
22
import numpy
 
23
from numpy import arange, add, array, dot, zeros, identity, conjugate, transpose
24
24
 
25
25
import sys
26
 
from scipy_test.testing import *
 
26
from numpy.testing import *
27
27
set_package_path()
28
28
from linalg import solve,inv,det,lstsq, toeplitz, hankel, tri, triu, tril
29
29
from linalg import pinv, pinv2, solve_banded
30
 
del sys.path[0]
 
30
restore_path()
31
31
 
32
32
import unittest
33
33
 
52
52
        for b in ([[1,0,0,0],[0,0,0,1],[0,1,0,0],[0,1,0,0]],
53
53
                  [[2,1],[-30,4],[2,3],[1,3]]):
54
54
            x = solve_banded((l,u),ab,b)
55
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
55
            assert_array_almost_equal(numpy.dot(a,x),b)
56
56
 
57
57
class test_solve(ScipyTestCase):
58
58
 
59
59
    def check_20Feb04_bug(self):
60
60
        a = [[1,1],[1.0,0]] # ok
61
61
        x0 = solve(a,[1,0j])
62
 
        assert_array_almost_equal(Numeric.matrixmultiply(a,x0),[1,0])
 
62
        assert_array_almost_equal(numpy.dot(a,x0),[1,0])
63
63
 
64
64
        a = [[1,1],[1.2,0]] # gives failure with clapack.zgesv(..,rowmajor=0)
65
65
        b = [1,0j]
66
66
        x0 = solve(a,b)
67
 
        assert_array_almost_equal(Numeric.matrixmultiply(a,x0),[1,0])
 
67
        assert_array_almost_equal(numpy.dot(a,x0),[1,0])
68
68
 
69
69
    def check_simple(self):
70
70
        a = [[1,20],[-30,4]]
71
71
        for b in ([[1,0],[0,1]],[1,0],
72
72
                  [[2,1],[-30,4]]):
73
73
            x = solve(a,b)
74
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
75
 
        
 
74
            assert_array_almost_equal(numpy.dot(a,x),b)
 
75
 
76
76
    def check_simple_sym(self):
77
77
        a = [[2,3],[3,5]]
78
78
        for lower in [0,1]:
79
79
            for b in ([[1,0],[0,1]],[1,0]):
80
80
                x = solve(a,b,sym_pos=1,lower=lower)
81
 
                assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
81
                assert_array_almost_equal(numpy.dot(a,x),b)
82
82
 
83
83
    def check_simple_sym_complex(self):
84
84
        a = [[5,2],[2,4]]
87
87
                   [0,2]],
88
88
                  ]:
89
89
            x = solve(a,b,sym_pos=1)
90
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
90
            assert_array_almost_equal(numpy.dot(a,x),b)
91
91
 
92
92
    def check_simple_complex(self):
93
93
        a = array([[5,2],[2j,4]],'D')
98
98
                  array([1,0],'D'),
99
99
                  ]:
100
100
            x = solve(a,b)
101
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
101
            assert_array_almost_equal(numpy.dot(a,x),b)
102
102
 
103
103
    def check_nils_20Feb04(self):
104
104
        n = 2
105
105
        A = random([n,n])+random([n,n])*1j
106
106
        X = zeros((n,n),'D')
107
 
        Ainv = inv(A) 
 
107
        Ainv = inv(A)
108
108
        R = identity(n)+identity(n)*0j
109
109
        for i in arange(0,n):
110
110
            r = R[:,i]
119
119
        for i in range(4):
120
120
            b = random([n,3])
121
121
            x = solve(a,b)
122
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
122
            assert_array_almost_equal(numpy.dot(a,x),b)
123
123
 
124
124
    def check_random_complex(self):
125
125
        n = 20
128
128
        for i in range(2):
129
129
            b = random([n,3])
130
130
            x = solve(a,b)
131
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
131
            assert_array_almost_equal(numpy.dot(a,x),b)
132
132
 
133
133
    def check_random_sym(self):
134
134
        n = 20
140
140
        for i in range(4):
141
141
            b = random([n])
142
142
            x = solve(a,b,sym_pos=1)
143
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
143
            assert_array_almost_equal(numpy.dot(a,x),b)
144
144
 
145
145
    def check_random_sym_complex(self):
146
146
        n = 20
149
149
        for i in range(n):
150
150
            a[i,i] = abs(20*(.1+a[i,i]))
151
151
            for j in range(i):
152
 
                a[i,j] = Numeric.conjugate(a[j,i])
 
152
                a[i,j] = numpy.conjugate(a[j,i])
153
153
        b = random([n])+2j*random([n])
154
154
        for i in range(2):
155
155
            x = solve(a,b,sym_pos=1)
156
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
156
            assert_array_almost_equal(numpy.dot(a,x),b)
157
157
 
158
158
    def bench_random(self,level=5):
159
 
        import LinearAlgebra
160
 
        Numeric_solve = LinearAlgebra.solve_linear_equations
 
159
        import numpy.linalg as linalg
 
160
        basic_solve = linalg.solve
161
161
        print
162
162
        print '      Solving system of linear equations'
163
163
        print '      =================================='
164
164
 
165
165
        print '      |    contiguous     |   non-contiguous '
166
166
        print '----------------------------------------------'
167
 
        print ' size |  scipy  | Numeric |  scipy  | Numeric'
168
 
        
 
167
        print ' size |  scipy  | basic   |  scipy  | basic '
 
168
 
169
169
        for size,repeat in [(20,1000),(100,150),(500,2),(1000,1)][:-1]:
170
170
            repeat *= 2
171
171
            print '%5s' % size,
172
172
            sys.stdout.flush()
173
 
            
 
173
 
174
174
            a = random([size,size])
175
175
            # larger diagonal ensures non-singularity:
176
176
            for i in range(size): a[i,i] = 10*(.1+a[i,i])
179
179
            print '| %6.2f ' % self.measure('solve(a,b)',repeat),
180
180
            sys.stdout.flush()
181
181
 
182
 
            print '| %6.2f ' % self.measure('Numeric_solve(a,b)',repeat),
 
182
            print '| %6.2f ' % self.measure('basic_solve(a,b)',repeat),
183
183
            sys.stdout.flush()
184
 
                        
 
184
 
185
185
            a = a[-1::-1,-1::-1] # turn into a non-contiguous array
186
 
            assert not a.iscontiguous()
 
186
            assert not a.flags['CONTIGUOUS']
187
187
 
188
188
            print '| %6.2f ' % self.measure('solve(a,b)',repeat),
189
189
            sys.stdout.flush()
190
190
 
191
 
            print '| %6.2f ' % self.measure('Numeric_solve(a,b)',repeat),
 
191
            print '| %6.2f ' % self.measure('basic_solve(a,b)',repeat),
192
192
            sys.stdout.flush()
193
193
 
194
194
            print '   (secs for %s calls)' % (repeat)
198
198
    def check_simple(self):
199
199
        a = [[1,2],[3,4]]
200
200
        a_inv = inv(a)
201
 
        assert_array_almost_equal(Numeric.matrixmultiply(a,a_inv),
 
201
        assert_array_almost_equal(numpy.dot(a,a_inv),
202
202
                                  [[1,0],[0,1]])
203
203
        a = [[1,2,3],[4,5,6],[7,8,10]]
204
204
        a_inv = inv(a)
205
 
        assert_array_almost_equal(Numeric.matrixmultiply(a,a_inv),
 
205
        assert_array_almost_equal(numpy.dot(a,a_inv),
206
206
                                  [[1,0,0],[0,1,0],[0,0,1]])
207
207
 
208
208
    def check_random(self):
211
211
            a = random([n,n])
212
212
            for i in range(n): a[i,i] = 20*(.1+a[i,i])
213
213
            a_inv = inv(a)
214
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,a_inv),
215
 
                                      Numeric.identity(n))
 
214
            assert_array_almost_equal(numpy.dot(a,a_inv),
 
215
                                      numpy.identity(n))
216
216
    def check_simple_complex(self):
217
217
        a = [[1,2],[3,4j]]
218
218
        a_inv = inv(a)
219
 
        assert_array_almost_equal(Numeric.matrixmultiply(a,a_inv),
 
219
        assert_array_almost_equal(numpy.dot(a,a_inv),
220
220
                                  [[1,0],[0,1]])
221
221
 
222
222
    def check_random_complex(self):
225
225
            a = random([n,n])+2j*random([n,n])
226
226
            for i in range(n): a[i,i] = 20*(.1+a[i,i])
227
227
            a_inv = inv(a)
228
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,a_inv),
229
 
                                      Numeric.identity(n))
 
228
            assert_array_almost_equal(numpy.dot(a,a_inv),
 
229
                                      numpy.identity(n))
230
230
 
231
231
    def bench_random(self,level=5):
232
 
        import LinearAlgebra
233
 
        Numeric_inv = LinearAlgebra.inverse
 
232
        import numpy.linalg as linalg
 
233
        basic_inv = linalg.inv
234
234
        print
235
235
        print '           Finding matrix inverse'
236
236
        print '      =================================='
237
237
        print '      |    contiguous     |   non-contiguous '
238
238
        print '----------------------------------------------'
239
 
        print ' size |  scipy  | Numeric |  scipy  | Numeric'
240
 
        
 
239
        print ' size |  scipy  | basic   |  scipy  | basic'
 
240
 
241
241
        for size,repeat in [(20,1000),(100,150),(500,2),(1000,1)][:-1]:
242
242
            repeat *= 2
243
243
            print '%5s' % size,
244
244
            sys.stdout.flush()
245
 
            
 
245
 
246
246
            a = random([size,size])
247
247
            # large diagonal ensures non-singularity:
248
248
            for i in range(size): a[i,i] = 10*(.1+a[i,i])
250
250
            print '| %6.2f ' % self.measure('inv(a)',repeat),
251
251
            sys.stdout.flush()
252
252
 
253
 
            print '| %6.2f ' % self.measure('Numeric_inv(a)',repeat),
 
253
            print '| %6.2f ' % self.measure('basic_inv(a)',repeat),
254
254
            sys.stdout.flush()
255
 
                        
 
255
 
256
256
            a = a[-1::-1,-1::-1] # turn into a non-contiguous array
257
 
            assert not a.iscontiguous()
 
257
            assert not a.flags['CONTIGUOUS']
258
258
 
259
259
            print '| %6.2f ' % self.measure('inv(a)',repeat),
260
260
            sys.stdout.flush()
261
261
 
262
 
            print '| %6.2f ' % self.measure('Numeric_inv(a)',repeat),
 
262
            print '| %6.2f ' % self.measure('basic_inv(a)',repeat),
263
263
            sys.stdout.flush()
264
264
 
265
265
            print '   (secs for %s calls)' % (repeat)
278
278
        assert_almost_equal(a_det,-6+4j)
279
279
 
280
280
    def check_random(self):
281
 
        import LinearAlgebra
282
 
        Numeric_det = LinearAlgebra.determinant
 
281
        import numpy.linalg as linalg
 
282
        basic_det = linalg.det
283
283
        n = 20
284
284
        for i in range(4):
285
285
            a = random([n,n])
286
286
            d1 = det(a)
287
 
            d2 = Numeric_det(a)
 
287
            d2 = basic_det(a)
288
288
            assert_almost_equal(d1,d2)
289
289
 
290
290
    def check_random_complex(self):
291
 
        import LinearAlgebra
292
 
        Numeric_det = LinearAlgebra.determinant
 
291
        import numpy.linalg as linalg
 
292
        basic_det = linalg.det
293
293
        n = 20
294
294
        for i in range(4):
295
295
            a = random([n,n]) + 2j*random([n,n])
296
296
            d1 = det(a)
297
 
            d2 = Numeric_det(a)
 
297
            d2 = basic_det(a)
298
298
            assert_almost_equal(d1,d2)
299
299
 
300
300
    def bench_random(self,level=5):
301
 
        import LinearAlgebra
302
 
        Numeric_det = LinearAlgebra.determinant
 
301
        import numpy.linalg as linalg
 
302
        basic_det = linalg.det
303
303
        print
304
304
        print '           Finding matrix determinant'
305
305
        print '      =================================='
306
306
        print '      |    contiguous     |   non-contiguous '
307
307
        print '----------------------------------------------'
308
 
        print ' size |  scipy  | Numeric |  scipy  | Numeric'
309
 
        
 
308
        print ' size |  scipy  | basic   |  scipy  | basic '
 
309
 
310
310
        for size,repeat in [(20,1000),(100,150),(500,2),(1000,1)][:-1]:
311
311
            repeat *= 2
312
312
            print '%5s' % size,
313
313
            sys.stdout.flush()
314
 
            
 
314
 
315
315
            a = random([size,size])
316
316
 
317
317
            print '| %6.2f ' % self.measure('det(a)',repeat),
318
318
            sys.stdout.flush()
319
319
 
320
 
            print '| %6.2f ' % self.measure('Numeric_det(a)',repeat),
 
320
            print '| %6.2f ' % self.measure('basic_det(a)',repeat),
321
321
            sys.stdout.flush()
322
 
                        
 
322
 
323
323
            a = a[-1::-1,-1::-1] # turn into a non-contiguous array
324
 
            assert not a.iscontiguous()
 
324
            assert not a.flags['CONTIGUOUS']
325
325
 
326
326
            print '| %6.2f ' % self.measure('det(a)',repeat),
327
327
            sys.stdout.flush()
328
328
 
329
 
            print '| %6.2f ' % self.measure('Numeric_det(a)',repeat),
 
329
            print '| %6.2f ' % self.measure('basic_det(a)',repeat),
330
330
            sys.stdout.flush()
331
331
 
332
332
            print '   (secs for %s calls)' % (repeat)
333
333
 
334
334
 
335
 
def direct_lstsq(a,b):
336
 
    a1 = Numeric.matrixmultiply(Numeric.transpose(a),a)
337
 
    b1 = Numeric.matrixmultiply(Numeric.transpose(a),b)
338
 
    return solve(a1,b1)
 
335
def direct_lstsq(a,b,cmplx=0):
 
336
    at = transpose(a)
 
337
    if cmplx:
 
338
        at = conjugate(at)
 
339
    a1 = dot(at, a)
 
340
    b1 = dot(at, b)
 
341
    return solve(a1, b1)
339
342
 
340
343
class test_lstsq(ScipyTestCase):
341
 
 
342
344
    def check_random_overdet_large(self):
343
345
        #bug report: Nils Wagner
344
346
        n = 200
353
355
        for b in ([[1,0],[0,1]],[1,0],
354
356
                  [[2,1],[-30,4]]):
355
357
            x = lstsq(a,b)[0]
356
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
358
            assert_array_almost_equal(numpy.dot(a,x),b)
357
359
 
358
360
    def check_simple_overdet(self):
359
361
        a = [[1,2],[4,5],[3,4]]
367
369
        b = [1,2]
368
370
        x,res,r,s = lstsq(a,b)
369
371
        #XXX: need independent check
370
 
        assert_array_almost_equal(x,[[-0.05555556],[0.11111111],[0.27777778]])
 
372
        assert_array_almost_equal(x,[[-0.05555556],
 
373
                                     [0.11111111],[0.27777778]])
371
374
 
372
375
    def check_random_exact(self):
373
376
 
377
380
        for i in range(4):
378
381
            b = random([n,3])
379
382
            x = lstsq(a,b)[0]
380
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
383
            assert_array_almost_equal(numpy.dot(a,x),b)
381
384
 
382
385
    def check_random_complex_exact(self):
383
386
        n = 20
386
389
        for i in range(2):
387
390
            b = random([n,3])
388
391
            x = lstsq(a,b)[0]
389
 
            assert_array_almost_equal(Numeric.matrixmultiply(a,x),b)
 
392
            assert_array_almost_equal(numpy.dot(a,x),b)
390
393
 
391
394
    def check_random_overdet(self):
392
395
        n = 20
404
407
        n = 20
405
408
        m = 15
406
409
        a = random([n,m]) + 1j * random([n,m])
407
 
        for i in range(m): a[i,i] = 20*(.1+a[i,i])
 
410
        for i in range(m):
 
411
            a[i,i] = 20*(.1+a[i,i])
408
412
        for i in range(2):
409
413
            b = random([n,3])
410
414
            x,res,r,s = lstsq(a,b)
411
415
            assert r==m,'unexpected efficient rank'
412
416
            #XXX: check definition of res
413
 
            assert_array_almost_equal(x,direct_lstsq(a,b),1e-3)
414
 
            #XXX: tolerance 1e-3 is quite large, investigate the reason
 
417
            assert_array_almost_equal(x,direct_lstsq(a,b,1))
415
418
 
416
419
class test_tri(unittest.TestCase):
417
420
    def check_basic(self):
419
422
                                   [1,1,0,0],
420
423
                                   [1,1,1,0],
421
424
                                   [1,1,1,1]]))
422
 
        assert_equal(tri(4,typecode='f'),array([[1,0,0,0],
 
425
        assert_equal(tri(4,dtype='f'),array([[1,0,0,0],
423
426
                                                [1,1,0,0],
424
427
                                                [1,1,1,0],
425
428
                                                [1,1,1,1]],'f'))
426
429
    def check_diag(self):
427
430
        assert_equal(tri(4,k=1),array([[1,1,0,0],
428
431
                                       [1,1,1,0],
429
 
                                       [0,1,1,1],
 
432
                                       [1,1,1,1],
430
433
                                       [1,1,1,1]]))
431
434
        assert_equal(tri(4,k=-1),array([[0,0,0,0],
432
435
                                        [1,0,0,0],
439
442
                                     [1,1,1]]))
440
443
        assert_equal(tri(3,4),array([[1,0,0,0],
441
444
                                     [1,1,0,0],
442
 
                                     [1,1,1,0]]))        
 
445
                                     [1,1,1,0]]))
443
446
    def check_diag2d(self):
444
447
        assert_equal(tri(3,4,k=2),array([[1,1,1,0],
445
448
                                         [1,1,1,1],
458
461
                b[k,l] = 0
459
462
        assert_equal(tril(a),b)
460
463
 
461
 
    def check_diag(self):        
 
464
    def check_diag(self):
462
465
        a = (100*get_mat(5)).astype('f')
463
466
        b = a.copy()
464
467
        for k in range(5):
480
483
                b[l,k] = 0
481
484
        assert_equal(triu(a),b)
482
485
 
483
 
    def check_diag(self):        
 
486
    def check_diag(self):
484
487
        a = (100*get_mat(5)).astype('f')
485
488
        b = a.copy()
486
489
        for k in range(5):
491
494
        for k in range(5):
492
495
            for l in range(k+3,5):
493
496
                b[l,k] = 0
494
 
        assert_equal(tril(a,k=-2),b)
 
497
        assert_equal(triu(a,k=-2),b)
495
498
 
496
499
class test_toeplitz(unittest.TestCase):
497
500
    def check_basic(self):
499
502
        assert_array_equal(y,[[1,2,3],[2,1,2],[3,2,1]])
500
503
        y = toeplitz([1,2,3],[1,4,5])
501
504
        assert_array_equal(y,[[1,4,5],[2,1,4],[3,2,1]])
502
 
        
 
505
 
503
506
class test_hankel(unittest.TestCase):
504
507
    def check_basic(self):
505
508
        y = hankel([1,2,3])
535
538
        assert_array_almost_equal(a_pinv,a_pinv2)
536
539
 
537
540
if __name__ == "__main__":
538
 
    ScipyTest('linalg.basic').run()
 
541
    ScipyTest().run()