~ubuntu-branches/ubuntu/saucy/python-scipy/saucy

« back to all changes in this revision

Viewing changes to Lib/linalg/tests/test_decomp.py

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description was updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added

removed

Lines of Context:
1
 
#!/usr/bin/env python
#
# Created by: Pearu Peterson, March 2002
#
"""Test functions for the linalg.decomp module.

Covers eigenvalue routines (eig, eigvals, eig_banded, eigvals_banded and
the banded LAPACK wrappers), LU, SVD, Cholesky, QR, Schur and Hessenberg
decompositions.
"""
__usage__ = """
Build linalg:
  python setup_linalg.py build
Run tests if scipy is installed:
  python -c 'import scipy;scipy.linalg.test(<level>)'
Run tests if linalg is not installed:
  python tests/test_decomp.py [<level>]
"""

import sys
18
 
from numpy.testing import *
19
 
 
20
 
set_package_path()
21
 
from linalg import eig,eigvals,lu,svd,svdvals,cholesky,qr,schur,rsf2csf
22
 
from linalg import lu_solve,lu_factor,solve,diagsvd,hessenberg
23
 
from linalg import eig_banded,eigvals_banded
24
 
from linalg.flapack import dgbtrf, dgbtrs, zgbtrf, zgbtrs
25
 
from linalg.flapack import dsbev, dsbevd, dsbevx, zhbevd, zhbevx
26
 
 
27
 
restore_path()
28
 
 
29
 
from numpy import *
30
 
from numpy.random import rand
31
 
 
32
 
def random(size):
    """Return uniform [0, 1) samples shaped like the dimensions in *size*.

    *size* is a sequence of dimensions such as ``[n, m]``; its entries are
    passed positionally to ``rand``.
    """
    dims = tuple(size)
    return rand(*dims)

class test_eigvals(ScipyTestCase):
36
 
 
37
 
    def check_simple(self):
38
 
        a = [[1,2,3],[1,2,3],[2,5,6]]
39
 
        w = eigvals(a)
40
 
        exact_w = [(9+sqrt(93))/2,0,(9-sqrt(93))/2]
41
 
        assert_array_almost_equal(w,exact_w)
42
 
 
43
 
    def check_simple_tr(self):
44
 
        a = array([[1,2,3],[1,2,3],[2,5,6]],'d')
45
 
        a = transpose(a).copy()
46
 
        a = transpose(a)
47
 
        w = eigvals(a)
48
 
        exact_w = [(9+sqrt(93))/2,0,(9-sqrt(93))/2]
49
 
        assert_array_almost_equal(w,exact_w)
50
 
 
51
 
    def check_simple_complex(self):
52
 
        a = [[1,2,3],[1,2,3],[2,5,6+1j]]
53
 
        w = eigvals(a)
54
 
        exact_w = [(9+1j+sqrt(92+6j))/2,
55
 
                   0,
56
 
                   (9+1j-sqrt(92+6j))/2]
57
 
        assert_array_almost_equal(w,exact_w)
58
 
 
59
 
    def bench_random(self,level=5):
60
 
        import numpy.linalg as linalg
61
 
        Numeric_eigvals = linalg.eigvals
62
 
        print
63
 
        print '           Finding matrix eigenvalues'
64
 
        print '      =================================='
65
 
        print '      |    contiguous     '#'|   non-contiguous '
66
 
        print '----------------------------------------------'
67
 
        print ' size |  scipy  '#'| core |  scipy  | core '
68
 
 
69
 
        for size,repeat in [(20,150),(100,7),(200,2)]:
70
 
            repeat *= 1
71
 
            print '%5s' % size,
72
 
            sys.stdout.flush()
73
 
 
74
 
            a = random([size,size])
75
 
 
76
 
            print '| %6.2f ' % self.measure('eigvals(a)',repeat),
77
 
            sys.stdout.flush()
78
 
 
79
 
            print '   (secs for %s calls)' % (repeat)
80
 
 
81
 
class test_eig(ScipyTestCase):
    """Tests for ``eig``: eigenvalues together with left/right eigenvectors."""

    def check_simple(self):
        mat = [[1,2,3],[1,2,3],[2,5,6]]
        vals, vecs = eig(mat)
        root = sqrt(93)
        expected_vals = [(9+root)/2, 0, (9-root)/2]
        # Hand-computed eigenvectors, listed in the order of expected_vals,
        # then normalised to unit length.
        raw = [array([1,1,(1+root/3)/2]),
               array([3.,0,-1]),
               array([1,1,(1-root/3)/2])]
        expected_vecs = [e / sqrt(dot(e, transpose(e))) for e in raw]
        assert_array_almost_equal(vals, expected_vals)
        # Each computed vector is multiplied by the sign of its first entry
        # so the comparison does not depend on the arbitrary sign choice.
        for col, e in enumerate(expected_vecs):
            assert_array_almost_equal(e, vecs[:,col]*sign(vecs[0,col]))
        for col in range(3):
            assert_array_almost_equal(dot(mat, vecs[:,col]),
                                      vals[col]*vecs[:,col])
        # Left eigenvectors of mat satisfy transpose(mat) * v = w * v here.
        vals, vecs = eig(mat, left=1, right=0)
        for col in range(3):
            assert_array_almost_equal(dot(transpose(mat), vecs[:,col]),
                                      vals[col]*vecs[:,col])

    def check_simple_complex(self):
        mat = [[1,2,3],[1,2,3],[2,5,6+1j]]
        vals, left, right = eig(mat, left=1, right=1)
        # Right eigenvectors: mat * r = w * r.
        for col in range(3):
            assert_array_almost_equal(dot(mat, right[:,col]),
                                      vals[col]*right[:,col])
        # Left eigenvectors: mat^H * l = conj(w) * l.
        mat_h = conjugate(transpose(mat))
        for col in range(3):
            assert_array_almost_equal(dot(mat_h, left[:,col]),
                                      conjugate(vals[col])*left[:,col])

class test_eig_banded(ScipyTestCase):
    """Tests for eig_banded/eigvals_banded and the banded LAPACK wrappers.

    Each routine is checked against a dense reference computed from the
    full matrix: numpy.linalg.eig for eigenpairs, lu for the band LU
    factorisation and numpy.linalg.solve for the banded solves.
    """

    def __init__(self, *args):
        ScipyTestCase.__init__(self, *args)

        self.create_bandmat()

    def create_bandmat(self):
        """Create the full test matrices and their LAPACK band storage.

        Sets the dense matrices `sym_mat`, `herm_mat`, `real_mat` and
        `comp_mat`.  It also sets their band forms:
        `bandmat_sym`/`bandmat_herm` use upper band storage with KU+1 rows,
        and `bandmat_real`/`bandmat_comp` use general band storage with
        2*KL+KU+1 rows.  Finally it sets the sorted reference eigenpairs
        `w_*_lin`/`evec_*_lin` and the right-hand sides `b`/`bc`.
        """
        # Dense reference, bound explicitly instead of depending on which
        # `linalg` the module namespace happens to hold.
        import numpy.linalg as linalg

        N  = 10
        self.KL = 2   # number of subdiagonals (below the diagonal)
        self.KU = 2   # number of superdiagonals (above the diagonal)

        # symmetric band matrix
        self.sym_mat = ( diag(1.0*ones(N))
                     +  diag(-1.0*ones(N-1), -1) + diag(-1.0*ones(N-1), 1)
                     + diag(-2.0*ones(N-2), -2) + diag(-2.0*ones(N-2), 2) )

        # hermitian band matrix
        self.herm_mat = ( diag(-1.0*ones(N))
                     + 1j*diag(1.0*ones(N-1), -1) - 1j*diag(1.0*ones(N-1), 1)
                     + diag(-2.0*ones(N-2), -2) + diag(-2.0*ones(N-2), 2) )

        # general real band matrix
        self.real_mat = ( diag(1.0*ones(N))
                     +  diag(-1.0*ones(N-1), -1) + diag(-3.0*ones(N-1), 1)
                     + diag(2.0*ones(N-2), -2) + diag(-2.0*ones(N-2), 2) )

        # general complex band matrix
        self.comp_mat = ( 1j*diag(1.0*ones(N))
                     +  diag(-1.0*ones(N-1), -1) + 1j*diag(-3.0*ones(N-1), 1)
                     + diag(2.0*ones(N-2), -2) + diag(-2.0*ones(N-2), 2) )

        # Reference eigenvalues and -vectors from the dense eig, sorted by
        # (real part of the) eigenvalue so they can be compared with
        # sorted results below.
        ew, ev = linalg.eig(self.sym_mat)
        ew = ew.real
        args = argsort(ew)
        self.w_sym_lin = ew[args]
        self.evec_sym_lin = ev[:,args]

        ew, ev = linalg.eig(self.herm_mat)
        ew = ew.real
        args = argsort(ew)
        self.w_herm_lin = ew[args]
        self.evec_herm_lin = ev[:,args]

        # Upper band storage of the symmetric and hermitian matrices, used
        # by dsbev, dsbevd, dsbevx, zhbevd, zhbevx and their single
        # precision versions.  Superdiagonal i goes to row KU-i, columns
        # i..N-1.
        LDAB = self.KU + 1
        self.bandmat_sym  = zeros((LDAB, N), dtype=float)
        self.bandmat_herm = zeros((LDAB, N), dtype=complex)
        for i in xrange(LDAB):
            self.bandmat_sym[LDAB-i-1,i:N]  = diag(self.sym_mat, i)
            self.bandmat_herm[LDAB-i-1,i:N] = diag(self.herm_mat, i)

        # General band storage of the real and complex matrices, used by
        # dgbtrf, dgbtrs, zgbtrf, zgbtrs and their single precision
        # versions.  Per the LAPACK ?gbtrf convention, A[i,j] lives at row
        # KL+KU+i-j.  The diagonal row is therefore KL+KU; 2*KL agrees with
        # that only when KL == KU.  The first KL rows stay zero.
        LDAB = 2*self.KL + self.KU + 1
        self.bandmat_real = self._general_band(self.real_mat, LDAB, float)
        self.bandmat_comp = self._general_band(self.comp_mat, LDAB, complex)

        # right-hand sides for the linear systems A*x = b
        self.b = 1.0*arange(N)
        self.bc = self.b *(1 + 1j)

    def _general_band(self, mat, ldab, dtype):
        """Pack the dense band matrix `mat` into general band storage.

        The result has `ldab` rows.  The main diagonal is on row KL+KU.
        Superdiagonal k is on row KL+KU-k (columns k..N-1).  Subdiagonal k
        is on row KL+KU+k (columns 0..N-1-k).
        """
        N = mat.shape[1]
        d = self.KL + self.KU
        band = zeros((ldab, N), dtype=dtype)
        band[d,:] = diag(mat)
        for k in xrange(1, self.KU + 1):
            band[d-k, k:N] = diag(mat, k)
        for k in xrange(1, self.KL + 1):
            band[d+k, 0:N-k] = diag(mat, -k)
        return band

    def _upper_factor(self, lu_band, N):
        """Rebuild the dense N x N upper factor U from ?gbtrf output.

        U has KL+KU superdiagonals.  Superdiagonal k is stored on row
        KL+KU-k (columns k..N-1), above the diagonal on row KL+KU.
        """
        d = self.KL + self.KU
        u = diag(lu_band[d,:])
        for k in xrange(1, self.KL + self.KU + 1):
            u += diag(lu_band[d-k, k:N], k)
        return u

    #####################################################################

    def check_dsbev(self):
        """Compare dsbev eigenvalues and eigenvectors with
           the result of linalg.eig."""
        w, evec, info  = dsbev(self.bandmat_sym, compute_v=1)
        evec_ = evec[:,argsort(w)]
        assert_array_almost_equal(sort(w), self.w_sym_lin)
        # compare moduli so the arbitrary sign of each eigenvector is ignored
        assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))

    def check_dsbevd(self):
        """Compare dsbevd eigenvalues and eigenvectors with
           the result of linalg.eig."""
        w, evec, info = dsbevd(self.bandmat_sym, compute_v=1)
        evec_ = evec[:,argsort(w)]
        assert_array_almost_equal(sort(w), self.w_sym_lin)
        assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))

    def check_dsbevx(self):
        """Compare dsbevx eigenvalues and eigenvectors
           with the result of linalg.eig."""
        N = self.sym_mat.shape[0]
        # NOTE(review): the positional arguments after the band matrix look
        # like (vl, vu, il, iu).  With range=2, presumably only the index
        # bounds 1..N are used, selecting every eigenvalue -- confirm
        # against the flapack signature.
        w, evec, num, ifail, info = dsbevx(self.bandmat_sym, 0.0, 0.0, 1, N,
                                       compute_v=1, range=2)
        evec_ = evec[:,argsort(w)]
        assert_array_almost_equal(sort(w), self.w_sym_lin)
        assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))

    def check_zhbevd(self):
        """Compare zhbevd eigenvalues and eigenvectors
           with the result of linalg.eig."""
        w, evec, info = zhbevd(self.bandmat_herm, compute_v=1)
        evec_ = evec[:,argsort(w)]
        assert_array_almost_equal(sort(w), self.w_herm_lin)
        assert_array_almost_equal(abs(evec_), abs(self.evec_herm_lin))

    def check_zhbevx(self):
        """Compare zhbevx eigenvalues and eigenvectors
           with the result of linalg.eig."""
        N = self.herm_mat.shape[0]
        # NOTE(review): same argument layout as in check_dsbevx -- confirm.
        w, evec, num, ifail, info = zhbevx(self.bandmat_herm, 0.0, 0.0, 1, N,
                                       compute_v=1, range=2)
        evec_ = evec[:,argsort(w)]
        assert_array_almost_equal(sort(w), self.w_herm_lin)
        assert_array_almost_equal(abs(evec_), abs(self.evec_herm_lin))

    def check_eigvals_banded(self):
        """Compare eigenvalues of eigvals_banded with those of linalg.eig."""
        w_sym = eigvals_banded(self.bandmat_sym)
        w_sym = w_sym.real
        assert_array_almost_equal(sort(w_sym), self.w_sym_lin)

        w_herm = eigvals_banded(self.bandmat_herm)
        w_herm = w_herm.real
        assert_array_almost_equal(sort(w_herm), self.w_herm_lin)

        # extracting eigenvalues with respect to an index range
        ind1 = 2
        ind2 = 6
        w_sym_ind = eigvals_banded(self.bandmat_sym,
                                    select='i', select_range=(ind1, ind2) )
        assert_array_almost_equal(sort(w_sym_ind),
                                  self.w_sym_lin[ind1:ind2+1])
        w_herm_ind = eigvals_banded(self.bandmat_herm,
                                    select='i', select_range=(ind1, ind2) )
        assert_array_almost_equal(sort(w_herm_ind),
                                  self.w_herm_lin[ind1:ind2+1])

        # Extracting eigenvalues with respect to a value range.  The window
        # is padded by 1e-5 on each side so that w[ind1] and w[ind2]
        # themselves fall inside it.
        v_lower = self.w_sym_lin[ind1] - 1.0e-5
        v_upper = self.w_sym_lin[ind2] + 1.0e-5
        w_sym_val = eigvals_banded(self.bandmat_sym,
                                select='v', select_range=(v_lower, v_upper) )
        assert_array_almost_equal(sort(w_sym_val),
                                  self.w_sym_lin[ind1:ind2+1])

        v_lower = self.w_herm_lin[ind1] - 1.0e-5
        v_upper = self.w_herm_lin[ind2] + 1.0e-5
        w_herm_val = eigvals_banded(self.bandmat_herm,
                                select='v', select_range=(v_lower, v_upper) )
        assert_array_almost_equal(sort(w_herm_val),
                                  self.w_herm_lin[ind1:ind2+1])

    def check_eig_banded(self):
        """Compare eigenvalues and eigenvectors of eig_banded
           with those of linalg.eig. """
        w_sym, evec_sym = eig_banded(self.bandmat_sym)
        evec_sym_ = evec_sym[:,argsort(w_sym.real)]
        assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
        assert_array_almost_equal(abs(evec_sym_), abs(self.evec_sym_lin))

        w_herm, evec_herm = eig_banded(self.bandmat_herm)
        evec_herm_ = evec_herm[:,argsort(w_herm.real)]
        assert_array_almost_equal(sort(w_herm), self.w_herm_lin)
        assert_array_almost_equal(abs(evec_herm_), abs(self.evec_herm_lin))

        # extracting eigenvalues with respect to an index range
        ind1 = 2
        ind2 = 6
        w_sym_ind, evec_sym_ind = eig_banded(self.bandmat_sym,
                                    select='i', select_range=(ind1, ind2) )
        assert_array_almost_equal(sort(w_sym_ind),
                                  self.w_sym_lin[ind1:ind2+1])
        assert_array_almost_equal(abs(evec_sym_ind),
                                  abs(self.evec_sym_lin[:,ind1:ind2+1]) )

        w_herm_ind, evec_herm_ind = eig_banded(self.bandmat_herm,
                                    select='i', select_range=(ind1, ind2) )
        assert_array_almost_equal(sort(w_herm_ind),
                                  self.w_herm_lin[ind1:ind2+1])
        assert_array_almost_equal(abs(evec_herm_ind),
                                  abs(self.evec_herm_lin[:,ind1:ind2+1]) )

        # extracting eigenvalues with respect to a (padded) value range
        v_lower = self.w_sym_lin[ind1] - 1.0e-5
        v_upper = self.w_sym_lin[ind2] + 1.0e-5
        w_sym_val, evec_sym_val = eig_banded(self.bandmat_sym,
                                select='v', select_range=(v_lower, v_upper) )
        assert_array_almost_equal(sort(w_sym_val),
                                  self.w_sym_lin[ind1:ind2+1])
        assert_array_almost_equal(abs(evec_sym_val),
                                  abs(self.evec_sym_lin[:,ind1:ind2+1]) )

        v_lower = self.w_herm_lin[ind1] - 1.0e-5
        v_upper = self.w_herm_lin[ind2] + 1.0e-5
        w_herm_val, evec_herm_val = eig_banded(self.bandmat_herm,
                                select='v', select_range=(v_lower, v_upper) )
        assert_array_almost_equal(sort(w_herm_val),
                                  self.w_herm_lin[ind1:ind2+1])
        assert_array_almost_equal(abs(evec_herm_val),
                                  abs(self.evec_herm_lin[:,ind1:ind2+1]) )

    def check_dgbtrf(self):
        """Compare dgbtrf  LU factorisation with the LU factorisation result
           of linalg.lu."""
        N = self.real_mat.shape[1]
        lu_band, ipiv, info = dgbtrf(self.bandmat_real, self.KL, self.KU)

        u = self._upper_factor(lu_band, N)

        p_lin, l_lin, u_lin = lu(self.real_mat, permute_l=0)
        assert_array_almost_equal(u, u_lin)

    def check_zgbtrf(self):
        """Compare zgbtrf  LU factorisation with the LU factorisation result
           of linalg.lu."""
        N = self.comp_mat.shape[1]
        lu_band, ipiv, info = zgbtrf(self.bandmat_comp, self.KL, self.KU)

        u = self._upper_factor(lu_band, N)

        p_lin, l_lin, u_lin = lu(self.comp_mat, permute_l=0)
        assert_array_almost_equal(u, u_lin)

    def check_dgbtrs(self):
        """Compare dgbtrs  solutions for linear equation system  A*x = b
           with solutions of linalg.solve."""
        import numpy.linalg as linalg

        lu_band, ipiv, info = dgbtrf(self.bandmat_real, self.KL, self.KU)
        y, info = dgbtrs(lu_band, self.KL, self.KU, self.b, ipiv)

        y_lin = linalg.solve(self.real_mat, self.b)
        assert_array_almost_equal(y, y_lin)

    def check_zgbtrs(self):
        """Compare zgbtrs  solutions for linear equation system  A*x = b
           with solutions of linalg.solve."""
        import numpy.linalg as linalg

        lu_band, ipiv, info = zgbtrf(self.bandmat_comp, self.KL, self.KU)
        y, info = zgbtrs(lu_band, self.KL, self.KU, self.bc, ipiv)

        y_lin = linalg.solve(self.comp_mat, self.bc)
        assert_array_almost_equal(y, y_lin)

class test_lu(ScipyTestCase):
    """Tests for ``lu``: P*L*U, and (P*L)*U with permute_l=1, must both
    reproduce the input matrix."""

    def _assert_lu_reconstructs(self, mat):
        perm, lower, upper = lu(mat)
        assert_array_almost_equal(dot(dot(perm,lower),upper), mat)
        perm_lower, upper = lu(mat, permute_l=1)
        assert_array_almost_equal(dot(perm_lower,upper), mat)

    def check_simple(self):
        self._assert_lu_reconstructs([[1,2,3],[1,2,3],[2,5,6]])

    def check_simple_complex(self):
        self._assert_lu_reconstructs([[1,2,3],[1,2,3],[2,5j,6]])

    # TODO: need more tests

class test_lu_solve(ScipyTestCase):
    """lu_solve applied to lu_factor's output must agree with solve."""

    def check_lu(self):
        a = random((10,10))
        b = random((10,))

        x1 = solve(a,b)

        lu_a = lu_factor(a)
        x2 = lu_solve(lu_a,b)

        # The two code paths are not guaranteed to be bitwise identical,
        # so compare to within rounding instead of exactly.
        assert_array_almost_equal(x1,x2)

class test_svd(ScipyTestCase):
    """Tests for ``svd``: orthogonality of u and vh, and reconstruction
    a == u * sigma * vh for square, singular, wide, tall and complex input."""

    def _assert_reconstructs(self, u, s, vh, a):
        """Assert u * sigma * vh == a.

        sigma is the (len(u), len(vh)) zero matrix with the singular values
        `s` on its leading diagonal.
        """
        sigma = zeros((u.shape[0],vh.shape[0]),s.dtype.char)
        for k in range(len(s)):
            sigma[k,k] = s[k]
        assert_array_almost_equal(dot(dot(u,sigma),vh),a)

    def check_simple(self):
        a = [[1,2,3],[1,20,3],[2,5,6]]
        u,s,vh = svd(a)
        assert_array_almost_equal(dot(transpose(u),u),identity(3))
        assert_array_almost_equal(dot(transpose(vh),vh),identity(3))
        self._assert_reconstructs(u,s,vh,a)

    def check_simple_singular(self):
        a = [[1,2,3],[1,2,3],[2,5,6]]
        u,s,vh = svd(a)
        assert_array_almost_equal(dot(transpose(u),u),identity(3))
        assert_array_almost_equal(dot(transpose(vh),vh),identity(3))
        self._assert_reconstructs(u,s,vh,a)

    def check_simple_underdet(self):
        a = [[1,2,3],[4,5,6]]
        u,s,vh = svd(a)
        assert_array_almost_equal(dot(transpose(u),u),identity(2))
        assert_array_almost_equal(dot(transpose(vh),vh),identity(3))
        self._assert_reconstructs(u,s,vh,a)

    def check_simple_overdet(self):
        a = [[1,2],[4,5],[3,4]]
        u,s,vh = svd(a)
        assert_array_almost_equal(dot(transpose(u),u),identity(3))
        assert_array_almost_equal(dot(transpose(vh),vh),identity(2))
        self._assert_reconstructs(u,s,vh,a)

    def check_random(self):
        n = 20
        m = 15
        for trial in range(3):
            for a in [random([n,m]),random([m,n])]:
                u,s,vh = svd(a)
                assert_array_almost_equal(dot(transpose(u),u),identity(len(u)))
                assert_array_almost_equal(dot(transpose(vh),vh),identity(len(vh)))
                self._assert_reconstructs(u,s,vh,a)

    def check_simple_complex(self):
        a = [[1,2,3],[1,2j,3],[2,5,6]]
        u,s,vh = svd(a)
        assert_array_almost_equal(dot(conj(transpose(u)),u),identity(3))
        assert_array_almost_equal(dot(conj(transpose(vh)),vh),identity(3))
        self._assert_reconstructs(u,s,vh,a)

    def check_random_complex(self):
        n = 20
        m = 15
        for trial in range(3):
            for a in [random([n,m]),random([m,n])]:
                a = a + 1j*random(list(a.shape))
                u,s,vh = svd(a)
                assert_array_almost_equal(dot(conj(transpose(u)),u),identity(len(u)))
                # NOTE(review): the matching unitarity check on vh is
                # disabled; the original remark says it fails for the
                # [m,n] (wide) case:
                #assert_array_almost_equal(dot(conj(transpose(vh)),vh),identity(len(vh),dtype=vh.dtype.char))
                self._assert_reconstructs(u,s,vh,a)

class test_svdvals(ScipyTestCase):
    """Tests for ``svdvals``: the expected number of singular values is
    returned, in non-increasing order."""

    def _assert_ordered(self, mat, count):
        vals = svdvals(mat)
        assert len(vals)==count
        for k in range(count-1):
            assert vals[k]>=vals[k+1]

    def check_simple(self):
        self._assert_ordered([[1,2,3],[1,2,3],[2,5,6]], 3)

    def check_simple_underdet(self):
        self._assert_ordered([[1,2,3],[4,5,6]], 2)

    def check_simple_overdet(self):
        self._assert_ordered([[1,2],[4,5],[3,4]], 2)

    def check_simple_complex(self):
        self._assert_ordered([[1,2,3],[1,20,3j],[2,5,6]], 3)

    def check_simple_underdet_complex(self):
        self._assert_ordered([[1,2,3],[4,5j,6]], 2)

    def check_simple_overdet_complex(self):
        self._assert_ordered([[1,2],[4,5],[3j,4]], 2)

class test_diagsvd(ScipyTestCase):
    """diagsvd(s, 3, 3) must place s on the diagonal of a 3x3 zero matrix."""

    def check_simple(self):
        expected = [[1,0,0],
                    [0,0,0],
                    [0,0,0]]
        assert_array_almost_equal(diagsvd([1,0,0],3,3), expected)

class test_cholesky(ScipyTestCase):
    """Tests for ``cholesky`` on Hermitian positive-definite matrices,
    covering both the default factor and the lower=1 factor."""

    def _check_factorisations(self, mat):
        # Without lower=1 the factor c satisfies c^H c == mat.
        fac = cholesky(mat)
        assert_array_almost_equal(mat, dot(transpose(conjugate(fac)),fac))
        # Its (plain) transpose L is then the lower=1 factor of L L^H.
        low = transpose(fac)
        rebuilt = dot(low, transpose(conjugate(low)))
        assert_array_almost_equal(cholesky(rebuilt,lower=1), low)

    def check_simple(self):
        self._check_factorisations([[8,2,3],[2,9,3],[3,3,6]])

    def check_simple_complex(self):
        tri = array([[3+1j,3+4j,5],[0,2+2j,2+7j],[0,0,7+4j]])
        self._check_factorisations(dot(transpose(conjugate(tri)),tri))

    def check_random(self):
        n = 20
        for trial in range(2):
            m = random([n,n])
            # enlarge the diagonal (presumably to keep m^T m well conditioned)
            for d in range(n):
                m[d,d] = 20*(.1+m[d,d])
            self._check_factorisations(dot(transpose(m),m))

    def check_random_complex(self):
        n = 20
        for trial in range(2):
            m = random([n,n])+1j*random([n,n])
            for d in range(n):
                m[d,d] = 20*(.1+abs(m[d,d]))
            self._check_factorisations(dot(transpose(conjugate(m)),m))

class test_qr(ScipyTestCase):
    """Tests for ``qr``: Q must have orthonormal columns and Q*R must give
    back the input, for square, wide, tall (full and economy) and complex
    matrices."""

    def _check_qr(self, a, size, **kwds):
        """Factor `a`, check Q^H Q == I(size) and Q R == a, return (q, r)."""
        q, r = qr(a, **kwds)
        assert_array_almost_equal(dot(conj(transpose(q)),q), identity(size))
        assert_array_almost_equal(dot(q,r), a)
        return q, r

    def check_simple(self):
        self._check_qr([[8,2,3],[2,9,3],[5,3,6]], 3)

    def check_simple_trap(self):
        self._check_qr([[8,2,3],[2,9,3]], 2)

    def check_simple_tall(self):
        # full version
        self._check_qr([[8,2],[2,9],[5,3]], 3)

    def check_simple_tall_e(self):
        # economy version
        q, r = self._check_qr([[8,2],[2,9],[5,3]], 2, econ=True)
        assert_equal(q.shape, (3,2))
        assert_equal(r.shape, (2,2))

    def check_simple_complex(self):
        self._check_qr([[3,3+4j,5],[5,2,2+7j],[3,2,7]], 3)

    def check_random(self):
        n = 20
        for trial in range(2):
            self._check_qr(random([n,n]), n)

    def check_random_tall(self):
        # full version
        m, n = 200, 100
        for trial in range(2):
            self._check_qr(random([m,n]), m)

    def check_random_tall_e(self):
        # economy version
        m, n = 200, 100
        for trial in range(2):
            q, r = self._check_qr(random([m,n]), n, econ=True)
            assert_equal(q.shape, (m,n))
            assert_equal(r.shape, (n,n))

    def check_random_trap(self):
        m, n = 100, 200
        for trial in range(2):
            self._check_qr(random([m,n]), m)

    def check_random_complex(self):
        n = 20
        for trial in range(2):
            self._check_qr(random([n,n])+1j*random([n,n]), n)

# Short aliases used by the Schur and Hessenberg tests below.  Note that
# binding `any` shadows the builtin of the same name with numpy's sometrue.
transp, any = transpose, sometrue

class test_schur(ScipyTestCase):
    """Tests for ``schur`` (real and complex forms) and ``rsf2csf``."""

    def check_simple(self):
        mat = [[8,12,3],[2,9,3],[10,3,6]]

        def rebuilt(tmat, zmat):
            # Z * T * Z^H must give back the original matrix.
            return dot(dot(zmat,tmat),transp(conj(zmat)))

        t,z = schur(mat)
        assert_array_almost_equal(rebuilt(t,z),mat)

        tc,zc = schur(mat,'complex')
        # the complex form must actually carry imaginary parts
        assert any(ravel(iscomplex(zc))) and any(ravel(iscomplex(tc)))
        assert_array_almost_equal(rebuilt(tc,zc),mat)

        tc2,zc2 = rsf2csf(tc,zc)
        assert_array_almost_equal(rebuilt(tc2,zc2),mat)

class test_hessenberg(ScipyTestCase):
    """Tests for ``hessenberg`` with calc_q=1: Q^H * A * Q must equal H."""

    def _reduce(self, mat):
        """Run hessenberg(mat, calc_q=1), check Q^H A Q == H, return H."""
        h, q = hessenberg(mat, calc_q=1)
        assert_array_almost_equal(dot(transp(conj(q)),dot(mat,q)), h)
        return h

    def check_simple(self):
        a = [[-149, -50,-154],
             [ 537, 180, 546],
             [ -27,  -9, -25]]
        # reference Hessenberg form, compared to 4 decimals
        expected_h = [[-149.0000,42.2037,-156.3165],
                      [-537.6783,152.5511,-554.9272],
                      [0,0.0728, 2.4489]]
        h = self._reduce(a)
        assert_array_almost_equal(h,expected_h,decimal=4)

    def check_simple_complex(self):
        a = [[-149, -50,-154],
             [ 537, 180j, 546],
             [ -27j,  -9, -25]]
        self._reduce(a)

    def check_simple2(self):
        a = [[1,2,3,4,5,6,7],
             [0,2,3,4,6,7,2],
             [0,2,2,3,0,3,2],
             [0,0,2,8,0,0,2],
             [0,3,1,2,0,1,2],
             [0,1,2,3,0,1,0],
             [0,0,0,0,0,1,2]]
        self._reduce(a)

    def check_random(self):
        n = 20
        for trial in range(2):
            self._reduce(random([n,n]))

    def check_random_complex(self):
        n = 20
        for trial in range(2):
            self._reduce(random([n,n])+1j*random([n,n]))

if __name__ == "__main__":
    # Hand this module to numpy's legacy ScipyTest runner.
    runner = ScipyTest()
    runner.run()