~ubuntu-branches/ubuntu/oneiric/python-scipy/oneiric-proposed

« back to all changes in this revision

Viewing changes to Lib/sandbox/umfpack/test_umfpack.py

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
#!/usr/bin/env python
2
 
# Created by: Robert Cimrman, 05.12.2005
3
 
 
4
 
"""Benchamrks for umfpack module"""
5
 
 
6
 
from optparse import OptionParser
7
 
import umfpack as um
8
 
import numpy as nm
9
 
#import scipy.io as io
10
 
import scipy.sparse as sp
11
 
import scipy.linalg as nla
12
 
import pylab
13
 
import time
14
 
import urllib
15
 
import gzip
16
 
 
17
 
defaultURL = 'http://www.cise.ufl.edu/research/sparse/HBformat/'
18
 
 
19
 
usage = """%%prog [options] <matrix file name> [<matrix file name>, ...]
20
 
 
21
 
<matrix file name> can be a local or distant (gzipped) file
22
 
 
23
 
default url is:
24
 
        %s
25
 
 
26
 
supported formats are:
27
 
        triplet .. [nRow, nCol, nItem] followed by 'nItem' * [ir, ic, value]
28
 
        hb      .. Harwell-Boeing format N/A
29
 
""" % defaultURL
30
 
 
31
 
 
32
 
##
33
 
# 05.12.2005, c
34
 
def read_triplet( fd ):
35
 
    nRow, nCol = map( int, fd.readline().split() )
36
 
    nItem = int( fd.readline() )
37
 
 
38
 
    ij = nm.zeros( (nItem,2), nm.int32 )
39
 
    val = nm.zeros( (nItem,), nm.float64 )
40
 
    for ii, row in enumerate( fd.readlines() ):
41
 
        aux = row.split()
42
 
        ij[ii] = int( aux[0] ), int( aux[1] )
43
 
        val[ii] = float( aux[2] )
44
 
 
45
 
    mtx = sp.csc_matrix( (val, ij), dims = (nRow, nCol), nzmax = nItem )
46
 
 
47
 
    return mtx
48
 
 
49
 
##
50
 
# 06.12.2005, c
51
 
def read_triplet2( fd ):
52
 
    nRow, nCol = map( int, fd.readline().split() )
53
 
    nItem = int( fd.readline() )
54
 
 
55
 
    ij, val = io.read_array( fd,
56
 
                             columns = [(0,1), (2,)],
57
 
                             atype = (nm.int32, nm.float64),
58
 
                             rowsize = nItem )
59
 
 
60
 
    mtx = sp.csc_matrix( (val, ij), dims = (nRow, nCol), nzmax = nItem )
61
 
 
62
 
    return mtx
63
 
 
64
 
 
65
 
formatMap = {'triplet' : read_triplet}
66
 
##
67
 
# 05.12.2005, c
68
 
def readMatrix( matrixName, options ):
69
 
 
70
 
    if options.default_url:
71
 
        matrixName = defaultURL + matrixName
72
 
 
73
 
    print 'url:', matrixName
74
 
 
75
 
    if matrixName[:7] == 'http://':
76
 
        fileName, status = urllib.urlretrieve( matrixName )
77
 
##        print status
78
 
    else:
79
 
        fileName = matrixName
80
 
 
81
 
    print 'file:', fileName
82
 
 
83
 
    try:
84
 
        readMatrix = formatMap[options.format]
85
 
    except:
86
 
        raise ValueError, 'unsupported format: %s' % options.format
87
 
 
88
 
    print 'format:', options.format
89
 
 
90
 
    print 'reading...'
91
 
    if fileName[-3:] == '.gz':
92
 
        fd = gzip.open( fileName )
93
 
    else:
94
 
        fd = open( fileName )
95
 
 
96
 
    mtx = readMatrix( fd )
97
 
 
98
 
    fd.close()
99
 
 
100
 
    print 'ok'
101
 
 
102
 
    return mtx
103
 
 
104
 
##
105
 
# 05.12.2005, c
106
 
def main():
107
 
    parser = OptionParser( usage = usage )
108
 
    parser.add_option( "-c", "--compare",
109
 
                       action = "store_true", dest = "compare",
110
 
                       default = False,
111
 
                       help = "compare with default scipy.sparse solver [default: %default]" )
112
 
    parser.add_option( "-p", "--plot",
113
 
                       action = "store_true", dest = "plot",
114
 
                       default = False,
115
 
                       help = "plot time statistics [default: %default]" )
116
 
    parser.add_option( "-d", "--default-url",
117
 
                       action = "store_true", dest = "default_url",
118
 
                       default = False,
119
 
                       help = "use default url [default: %default]" )
120
 
    parser.add_option( "-f", "--format", type = type( '' ),
121
 
                       dest = "format", default = 'triplet',
122
 
                       help = "matrix format [default: %default]" )
123
 
    (options, args) = parser.parse_args()
124
 
 
125
 
    if (len( args ) >= 1):
126
 
        matrixNames = args;
127
 
    else:
128
 
        parser.print_help(),
129
 
        return
130
 
 
131
 
    sizes, nnzs, times, errors = [], [], [], []
132
 
    legends = ['umfpack', 'sparse.solve']
133
 
    for ii, matrixName in enumerate( matrixNames ):
134
 
 
135
 
        print '*' * 50
136
 
        mtx = readMatrix( matrixName, options )
137
 
 
138
 
        sizes.append( mtx.shape )
139
 
        nnzs.append( mtx.nnz )
140
 
        tts = nm.zeros( (2,), dtype = nm.double )
141
 
        times.append( tts )
142
 
        err = nm.zeros( (2,2), dtype = nm.double )
143
 
        errors.append( err )
144
 
 
145
 
        print 'size              : %s (%d nnz)' % (mtx.shape, mtx.nnz)
146
 
 
147
 
        sol0 = nm.ones( (mtx.shape[0],), dtype = nm.double )
148
 
        rhs = mtx * sol0
149
 
 
150
 
        umfpack = um.UmfpackContext()
151
 
 
152
 
        tt = time.clock()
153
 
        sol = umfpack( um.UMFPACK_A, mtx, rhs, autoTranspose = True )
154
 
        tts[0] = time.clock() - tt
155
 
        print "umfpack           : %.2f s" % tts[0]
156
 
 
157
 
        error = mtx * sol - rhs
158
 
        err[0,0] = nla.norm( error )
159
 
        print '||Ax-b||          :', err[0,0]
160
 
 
161
 
        error = sol0 - sol
162
 
        err[0,1] = nla.norm( error )
163
 
        print '||x - x_{exact}|| :', err[0,1]
164
 
 
165
 
        if options.compare:
166
 
            tt = time.clock()
167
 
            sol = sp.solve( mtx, rhs )
168
 
            tts[1] = time.clock() - tt
169
 
            print "sparse.solve      : %.2f s" % tts[1]
170
 
 
171
 
            error = mtx * sol - rhs
172
 
            err[1,0] = nla.norm( error )
173
 
            print '||Ax-b||          :', err[1,0]
174
 
 
175
 
            error = sol0 - sol
176
 
            err[1,1] = nla.norm( error )
177
 
            print '||x - x_{exact}|| :', err[1,1]
178
 
 
179
 
    if options.plot:
180
 
        times = nm.array( times )
181
 
        print times
182
 
        pylab.plot( times[:,0], 'b-o' )
183
 
        if options.compare:
184
 
            pylab.plot( times[:,1], 'r-s' )
185
 
        else:
186
 
            del legends[1]
187
 
 
188
 
        print legends
189
 
 
190
 
        ax = pylab.axis()
191
 
        y2 = 0.5 * (ax[3] - ax[2])
192
 
        xrng = range( len( nnzs ) )
193
 
        for ii in xrng:
194
 
            yy = y2 + 0.4 * (ax[3] - ax[2])\
195
 
                 * nm.sin( ii * 2 * nm.pi / (len( xrng ) - 1) )
196
 
 
197
 
            if options.compare:
198
 
                pylab.text( ii+0.02, yy,
199
 
                            '%s\n%.2e err_umf\n%.2e err_sp'
200
 
                            % (sizes[ii], nm.sum( errors[ii][0,:] ),
201
 
                               nm.sum( errors[ii][1,:] )) )
202
 
            else:
203
 
                pylab.text( ii+0.02, yy,
204
 
                            '%s\n%.2e err_umf'
205
 
                            % (sizes[ii], nm.sum( errors[ii][0,:] )) )
206
 
            pylab.plot( [ii, ii], [ax[2], ax[3]], 'k:' )
207
 
 
208
 
        pylab.xticks( xrng, ['%d' % (nnzs[ii] ) for ii in xrng] )
209
 
        pylab.xlabel( 'nnz' )
210
 
        pylab.ylabel( 'time [s]' )
211
 
        pylab.legend( legends )
212
 
        pylab.axis( [ax[0] - 0.05, ax[1] + 1, ax[2], ax[3]] )
213
 
        pylab.show()
214
 
 
215
 
if __name__ == '__main__':
216
 
    main()