~ubuntu-branches/ubuntu/oneiric/python-scipy/oneiric-proposed

« back to all changes in this revision

Viewing changes to Lib/sandbox/pysparse/superlu/dgstrs.c

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
 
2
 
 
3
 
/*
4
 
 * -- SuperLU routine (version 2.0) --
5
 
 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
6
 
 * and Lawrence Berkeley National Lab.
7
 
 * November 15, 1997
8
 
 *
9
 
 */
10
 
/*
11
 
  Copyright (c) 1994 by Xerox Corporation.  All rights reserved.
12
 
 
13
 
  THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
14
 
  EXPRESSED OR IMPLIED.  ANY USE IS AT YOUR OWN RISK.
15
 
 
16
 
  Permission is hereby granted to use or copy this program for any
17
 
  purpose, provided the above notices are retained on all copies.
18
 
  Permission to modify the code and to distribute modified code is
19
 
  granted, provided the above notices are retained, and a notice that
20
 
  the code was modified is included with the above copyright notice.
21
 
*/
22
 
 
23
 
#include "dsp_defs.h"
24
 
#include "util.h"
25
 
 
26
 
 
27
 
/* 
28
 
 * Function prototypes 
29
 
 */
30
 
void dusolve(int, int, double*, double*);
31
 
void dlsolve(int, int, double*, double*);
32
 
void dmatvec(int, int, int, double*, double*, double*);
33
 
 
34
 
 
35
 
void
36
 
dgstrs (char *trans, SuperMatrix *L, SuperMatrix *U,
37
 
        int *perm_r, int *perm_c, SuperMatrix *B, int *info)
38
 
{
39
 
/*
40
 
 * Purpose
41
 
 * =======
42
 
 *
43
 
 * DGSTRS solves a system of linear equations A*X=B or A'*X=B
44
 
 * with A sparse and B dense, using the LU factorization computed by
45
 
 * DGSTRF.
46
 
 *
47
 
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
48
 
 *
49
 
 * Arguments
50
 
 * =========
51
 
 *
52
 
 * trans   (input) char*
53
 
 *          Specifies the form of the system of equations:
54
 
 *          = 'N':  A * X = B  (No transpose)
55
 
 *          = 'T':  A'* X = B  (Transpose)
56
 
 *          = 'C':  A**H * X = B  (Conjugate transpose)
57
 
 *
58
 
 * L       (input) SuperMatrix*
59
 
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
60
 
 *         dgstrf(). Use compressed row subscripts storage for supernodes,
61
 
 *         i.e., L has types: Stype = SC, Dtype = D_, Mtype = TRLU.
62
 
 *
63
 
 * U       (input) SuperMatrix*
64
 
 *         The factor U from the factorization Pr*A*Pc=L*U as computed by
65
 
 *         dgstrf(). Use column-wise storage scheme, i.e., U has types:
66
 
 *         Stype = NC, Dtype = D_, Mtype = TRU.
67
 
 *
68
 
 * perm_r  (input) int*, dimension (L->nrow)
69
 
 *         Row permutation vector, which defines the permutation matrix Pr; 
70
 
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
71
 
 *
72
 
 * perm_c  (input) int*, dimension (L->ncol)
73
 
 *         Column permutation vector, which defines the 
74
 
 *         permutation matrix Pc; perm_c[i] = j means column i of A is 
75
 
 *         in position j in A*Pc.
76
 
 *
77
 
 * B       (input/output) SuperMatrix*
78
 
 *         B has types: Stype = DN, Dtype = D_, Mtype = GE.
79
 
 *         On entry, the right hand side matrix.
80
 
 *         On exit, the solution matrix if info = 0;
81
 
 *
82
 
 * info    (output) int*
83
 
 *         = 0: successful exit
84
 
 *         < 0: if info = -i, the i-th argument had an illegal value
85
 
 *
86
 
 */
87
 
#ifdef _CRAY
88
 
    _fcd ftcs1, ftcs2, ftcs3, ftcs4;
89
 
#endif
90
 
    int      incx = 1, incy = 1;
91
 
    double   alpha = 1.0, beta = 1.0;
92
 
    DNformat *Bstore;
93
 
    double   *Bmat;
94
 
    SCformat *Lstore;
95
 
    NCformat *Ustore;
96
 
    double   *Lval, *Uval;
97
 
    int      nrow, notran;
98
 
    int      fsupc, nsupr, nsupc, luptr, istart, irow;
99
 
    int      i, j, k, iptr, jcol, n, ldb, nrhs;
100
 
    double   *work, *work_col, *rhs_work, *soln;
101
 
    flops_t  solve_ops;
102
 
    extern SuperLUStat_t SuperLUStat;
103
 
    void dprint_soln();
104
 
 
105
 
    /* Test input parameters ... */
106
 
    *info = 0;
107
 
    Bstore = B->Store;
108
 
    ldb = Bstore->lda;
109
 
    nrhs = B->ncol;
110
 
    notran = lsame_(trans, "N");
111
 
    if ( !notran && !lsame_(trans, "T") && !lsame_(trans, "C") ) *info = -1;
112
 
    else if ( L->nrow != L->ncol || L->nrow < 0 ||
113
 
              L->Stype != SC || L->Dtype != D_ || L->Mtype != TRLU )
114
 
        *info = -2;
115
 
    else if ( U->nrow != U->ncol || U->nrow < 0 ||
116
 
              U->Stype != NC || U->Dtype != D_ || U->Mtype != TRU )
117
 
        *info = -3;
118
 
    else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
119
 
              B->Stype != DN || B->Dtype != D_ || B->Mtype != GE )
120
 
        *info = -6;
121
 
    if ( *info ) {
122
 
        i = -(*info);
123
 
        xerbla_("dgstrs", &i);
124
 
        return;
125
 
    }
126
 
 
127
 
    n = L->nrow;
128
 
    work = doubleCalloc(n * nrhs);
129
 
    if ( !work ) ABORT("Malloc fails for local work[].");
130
 
    soln = doubleMalloc(n);
131
 
    if ( !soln ) ABORT("Malloc fails for local soln[].");
132
 
 
133
 
    Bmat = Bstore->nzval;
134
 
    Lstore = L->Store;
135
 
    Lval = Lstore->nzval;
136
 
    Ustore = U->Store;
137
 
    Uval = Ustore->nzval;
138
 
    solve_ops = 0;
139
 
    
140
 
    if ( notran ) {
141
 
        /* Permute right hand sides to form Pr*B */
142
 
        for (i = 0; i < nrhs; i++) {
143
 
            rhs_work = &Bmat[i*ldb];
144
 
            for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
145
 
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
146
 
        }
147
 
        
148
 
        /* Forward solve PLy=Pb. */
149
 
        for (k = 0; k <= Lstore->nsuper; k++) {
150
 
            fsupc = L_FST_SUPC(k);
151
 
            istart = L_SUB_START(fsupc);
152
 
            nsupr = L_SUB_START(fsupc+1) - istart;
153
 
            nsupc = L_FST_SUPC(k+1) - fsupc;
154
 
            nrow = nsupr - nsupc;
155
 
 
156
 
            solve_ops += nsupc * (nsupc - 1) * nrhs;
157
 
            solve_ops += 2 * nrow * nsupc * nrhs;
158
 
            
159
 
            if ( nsupc == 1 ) {
160
 
                for (j = 0; j < nrhs; j++) {
161
 
                    rhs_work = &Bmat[j*ldb];
162
 
                    luptr = L_NZ_START(fsupc);
163
 
                    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
164
 
                        irow = L_SUB(iptr);
165
 
                        ++luptr;
166
 
                        rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
167
 
                    }
168
 
                }
169
 
            } else {
170
 
                luptr = L_NZ_START(fsupc);
171
 
#ifdef USE_VENDOR_BLAS
172
 
#ifdef _CRAY
173
 
                ftcs1 = _cptofcd("L", strlen("L"));
174
 
                ftcs2 = _cptofcd("N", strlen("N"));
175
 
                ftcs3 = _cptofcd("U", strlen("U"));
176
 
                STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
177
 
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
178
 
                
179
 
                SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 
180
 
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
181
 
                        &beta, &work[0], &n );
182
 
#else
183
 
                dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
184
 
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
185
 
                
186
 
                dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 
187
 
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
188
 
                        &beta, &work[0], &n );
189
 
#endif
190
 
                for (j = 0; j < nrhs; j++) {
191
 
                    rhs_work = &Bmat[j*ldb];
192
 
                    work_col = &work[j*n];
193
 
                    iptr = istart + nsupc;
194
 
                    for (i = 0; i < nrow; i++) {
195
 
                        irow = L_SUB(iptr);
196
 
                        rhs_work[irow] -= work_col[i]; /* Scatter */
197
 
                        work_col[i] = 0.0;
198
 
                        iptr++;
199
 
                    }
200
 
                }
201
 
#else           
202
 
                for (j = 0; j < nrhs; j++) {
203
 
                    rhs_work = &Bmat[j*ldb];
204
 
                    dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
205
 
                    dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
206
 
                            &rhs_work[fsupc], &work[0] );
207
 
 
208
 
                    iptr = istart + nsupc;
209
 
                    for (i = 0; i < nrow; i++) {
210
 
                        irow = L_SUB(iptr);
211
 
                        rhs_work[irow] -= work[i];
212
 
                        work[i] = 0.0;
213
 
                        iptr++;
214
 
                    }
215
 
                }
216
 
#endif              
217
 
            } /* else ... */
218
 
        } /* for L-solve */
219
 
 
220
 
#ifdef DEBUG
221
 
        printf("After L-solve: y=\n");
222
 
        dprint_soln(n, nrhs, Bmat);
223
 
#endif
224
 
 
225
 
        /*
226
 
         * Back solve Ux=y.
227
 
         */
228
 
        for (k = Lstore->nsuper; k >= 0; k--) {
229
 
            fsupc = L_FST_SUPC(k);
230
 
            istart = L_SUB_START(fsupc);
231
 
            nsupr = L_SUB_START(fsupc+1) - istart;
232
 
            nsupc = L_FST_SUPC(k+1) - fsupc;
233
 
            luptr = L_NZ_START(fsupc);
234
 
 
235
 
            solve_ops += nsupc * (nsupc + 1) * nrhs;
236
 
 
237
 
            if ( nsupc == 1 ) {
238
 
                rhs_work = &Bmat[0];
239
 
                for (j = 0; j < nrhs; j++) {
240
 
                    rhs_work[fsupc] /= Lval[luptr];
241
 
                    rhs_work += ldb;
242
 
                }
243
 
            } else {
244
 
#ifdef USE_VENDOR_BLAS
245
 
#ifdef _CRAY
246
 
                ftcs1 = _cptofcd("L", strlen("L"));
247
 
                ftcs2 = _cptofcd("U", strlen("U"));
248
 
                ftcs3 = _cptofcd("N", strlen("N"));
249
 
                STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
250
 
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
251
 
#else
252
 
                dtrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
253
 
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
254
 
#endif
255
 
#else           
256
 
                for (j = 0; j < nrhs; j++)
257
 
                    dusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
258
 
#endif          
259
 
            }
260
 
 
261
 
            for (j = 0; j < nrhs; ++j) {
262
 
                rhs_work = &Bmat[j*ldb];
263
 
                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
264
 
                    solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
265
 
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
266
 
                        irow = U_SUB(i);
267
 
                        rhs_work[irow] -= rhs_work[jcol] * Uval[i];
268
 
                    }
269
 
                }
270
 
            }
271
 
            
272
 
        } /* for U-solve */
273
 
 
274
 
#ifdef DEBUG
275
 
        printf("After U-solve: x=\n");
276
 
        dprint_soln(n, nrhs, Bmat);
277
 
#endif
278
 
 
279
 
        /* Compute the final solution X := Pc*X. */
280
 
        for (i = 0; i < nrhs; i++) {
281
 
            rhs_work = &Bmat[i*ldb];
282
 
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
283
 
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
284
 
        }
285
 
        
286
 
        SuperLUStat.ops[SOLVE] = solve_ops;
287
 
 
288
 
    } else { /* Solve A'*X=B */
289
 
        /* Permute right hand sides to form Pc'*B. */
290
 
        for (i = 0; i < nrhs; i++) {
291
 
            rhs_work = &Bmat[i*ldb];
292
 
            for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
293
 
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
294
 
        }
295
 
 
296
 
        SuperLUStat.ops[SOLVE] = 0;
297
 
        
298
 
        for (k = 0; k < nrhs; ++k) {
299
 
            
300
 
            /* Multiply by inv(U'). */
301
 
            sp_dtrsv("U", "T", "N", L, U, &Bmat[k*ldb], info);
302
 
            
303
 
            /* Multiply by inv(L'). */
304
 
            sp_dtrsv("L", "T", "U", L, U, &Bmat[k*ldb], info);
305
 
            
306
 
        }
307
 
        
308
 
        /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
309
 
        for (i = 0; i < nrhs; i++) {
310
 
            rhs_work = &Bmat[i*ldb];
311
 
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
312
 
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
313
 
        }
314
 
 
315
 
    }
316
 
 
317
 
    SUPERLU_FREE(work);
318
 
    SUPERLU_FREE(soln);
319
 
}
320
 
 
321
 
/*
322
 
 * Diagnostic print of the solution vector 
323
 
 */
324
 
void
325
 
dprint_soln(int n, int nrhs, double *soln)
326
 
{
327
 
    int i;
328
 
 
329
 
    for (i = 0; i < n; i++) 
330
 
        printf("\t%d: %.4f\n", i, soln[i]);
331
 
}