~ubuntu-branches/ubuntu/saucy/python-scipy/saucy

« back to all changes in this revision

Viewing changes to Lib/linsolve/SuperLU/SRC/dgstrs.c

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
 
2
 
/*
3
 
 * -- SuperLU routine (version 3.0) --
4
 
 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
5
 
 * and Lawrence Berkeley National Lab.
6
 
 * October 15, 2003
7
 
 *
8
 
 */
9
 
/*
10
 
  Copyright (c) 1994 by Xerox Corporation.  All rights reserved.
11
 
 
12
 
  THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
13
 
  EXPRESSED OR IMPLIED.  ANY USE IS AT YOUR OWN RISK.
14
 
 
15
 
  Permission is hereby granted to use or copy this program for any
16
 
  purpose, provided the above notices are retained on all copies.
17
 
  Permission to modify the code and to distribute modified code is
18
 
  granted, provided the above notices are retained, and a notice that
19
 
  the code was modified is included with the above copyright notice.
20
 
*/
21
 
 
22
 
#include "dsp_defs.h"
23
 
 
24
 
 
25
 
/* 
26
 
 * Function prototypes 
27
 
 */
28
 
void dusolve(int, int, double*, double*);
29
 
void dlsolve(int, int, double*, double*);
30
 
void dmatvec(int, int, int, double*, double*, double*);
31
 
 
32
 
 
33
 
void
34
 
dgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
35
 
        int *perm_c, int *perm_r, SuperMatrix *B,
36
 
        SuperLUStat_t *stat, int *info)
37
 
{
38
 
/*
39
 
 * Purpose
40
 
 * =======
41
 
 *
42
 
 * DGSTRS solves a system of linear equations A*X=B or A'*X=B
43
 
 * with A sparse and B dense, using the LU factorization computed by
44
 
 * DGSTRF.
45
 
 *
46
 
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
47
 
 *
48
 
 * Arguments
49
 
 * =========
50
 
 *
51
 
 * trans   (input) trans_t
52
 
 *          Specifies the form of the system of equations:
53
 
 *          = NOTRANS: A * X = B  (No transpose)
54
 
 *          = TRANS:   A'* X = B  (Transpose)
55
 
 *          = CONJ:    A**H * X = B  (Conjugate transpose)
56
 
 *
57
 
 * L       (input) SuperMatrix*
58
 
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
59
 
 *         dgstrf(). Use compressed row subscripts storage for supernodes,
60
 
 *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_D, Mtype = SLU_TRLU.
61
 
 *
62
 
 * U       (input) SuperMatrix*
63
 
 *         The factor U from the factorization Pr*A*Pc=L*U as computed by
64
 
 *         dgstrf(). Use column-wise storage scheme, i.e., U has types:
65
 
 *         Stype = SLU_NC, Dtype = SLU_D, Mtype = SLU_TRU.
66
 
 *
67
 
 * perm_c  (input) int*, dimension (L->ncol)
68
 
 *         Column permutation vector, which defines the 
69
 
 *         permutation matrix Pc; perm_c[i] = j means column i of A is 
70
 
 *         in position j in A*Pc.
71
 
 *
72
 
 * perm_r  (input) int*, dimension (L->nrow)
73
 
 *         Row permutation vector, which defines the permutation matrix Pr; 
74
 
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
75
 
 *
76
 
 * B       (input/output) SuperMatrix*
77
 
 *         B has types: Stype = SLU_DN, Dtype = SLU_D, Mtype = SLU_GE.
78
 
 *         On entry, the right hand side matrix.
79
 
 *         On exit, the solution matrix if info = 0;
80
 
 *
81
 
 * stat     (output) SuperLUStat_t*
82
 
 *          Record the statistics on runtime and floating-point operation count.
83
 
 *          See util.h for the definition of 'SuperLUStat_t'.
84
 
 *
85
 
 * info    (output) int*
86
 
 *         = 0: successful exit
87
 
 *         < 0: if info = -i, the i-th argument had an illegal value
88
 
 *
89
 
 */
90
 
#ifdef _CRAY
91
 
    _fcd ftcs1, ftcs2, ftcs3, ftcs4;
92
 
#endif
93
 
    int      incx = 1, incy = 1;
94
 
#ifdef USE_VENDOR_BLAS
95
 
    double   alpha = 1.0, beta = 1.0;
96
 
    double   *work_col;
97
 
#endif
98
 
    DNformat *Bstore;
99
 
    double   *Bmat;
100
 
    SCformat *Lstore;
101
 
    NCformat *Ustore;
102
 
    double   *Lval, *Uval;
103
 
    int      fsupc, nrow, nsupr, nsupc, luptr, istart, irow;
104
 
    int      i, j, k, iptr, jcol, n, ldb, nrhs;
105
 
    double   *work, *rhs_work, *soln;
106
 
    flops_t  solve_ops;
107
 
    void dprint_soln();
108
 
 
109
 
    /* Test input parameters ... */
110
 
    *info = 0;
111
 
    Bstore = B->Store;
112
 
    ldb = Bstore->lda;
113
 
    nrhs = B->ncol;
114
 
    if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
115
 
    else if ( L->nrow != L->ncol || L->nrow < 0 ||
116
 
              L->Stype != SLU_SC || L->Dtype != SLU_D || L->Mtype != SLU_TRLU )
117
 
        *info = -2;
118
 
    else if ( U->nrow != U->ncol || U->nrow < 0 ||
119
 
              U->Stype != SLU_NC || U->Dtype != SLU_D || U->Mtype != SLU_TRU )
120
 
        *info = -3;
121
 
    else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
122
 
              B->Stype != SLU_DN || B->Dtype != SLU_D || B->Mtype != SLU_GE )
123
 
        *info = -6;
124
 
    if ( *info ) {
125
 
        i = -(*info);
126
 
        xerbla_("dgstrs", &i);
127
 
        return;
128
 
    }
129
 
 
130
 
    n = L->nrow;
131
 
    work = doubleCalloc(n * nrhs);
132
 
    if ( !work ) ABORT("Malloc fails for local work[].");
133
 
    soln = doubleMalloc(n);
134
 
    if ( !soln ) ABORT("Malloc fails for local soln[].");
135
 
 
136
 
    Bmat = Bstore->nzval;
137
 
    Lstore = L->Store;
138
 
    Lval = Lstore->nzval;
139
 
    Ustore = U->Store;
140
 
    Uval = Ustore->nzval;
141
 
    solve_ops = 0;
142
 
    
143
 
    if ( trans == NOTRANS ) {
144
 
        /* Permute right hand sides to form Pr*B */
145
 
        for (i = 0; i < nrhs; i++) {
146
 
            rhs_work = &Bmat[i*ldb];
147
 
            for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
148
 
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
149
 
        }
150
 
        
151
 
        /* Forward solve PLy=Pb. */
152
 
        for (k = 0; k <= Lstore->nsuper; k++) {
153
 
            fsupc = L_FST_SUPC(k);
154
 
            istart = L_SUB_START(fsupc);
155
 
            nsupr = L_SUB_START(fsupc+1) - istart;
156
 
            nsupc = L_FST_SUPC(k+1) - fsupc;
157
 
            nrow = nsupr - nsupc;
158
 
 
159
 
            solve_ops += nsupc * (nsupc - 1) * nrhs;
160
 
            solve_ops += 2 * nrow * nsupc * nrhs;
161
 
            
162
 
            if ( nsupc == 1 ) {
163
 
                for (j = 0; j < nrhs; j++) {
164
 
                    rhs_work = &Bmat[j*ldb];
165
 
                    luptr = L_NZ_START(fsupc);
166
 
                    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
167
 
                        irow = L_SUB(iptr);
168
 
                        ++luptr;
169
 
                        rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
170
 
                    }
171
 
                }
172
 
            } else {
173
 
                luptr = L_NZ_START(fsupc);
174
 
#ifdef USE_VENDOR_BLAS
175
 
#ifdef _CRAY
176
 
                ftcs1 = _cptofcd("L", strlen("L"));
177
 
                ftcs2 = _cptofcd("N", strlen("N"));
178
 
                ftcs3 = _cptofcd("U", strlen("U"));
179
 
                STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
180
 
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
181
 
                
182
 
                SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 
183
 
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
184
 
                        &beta, &work[0], &n );
185
 
#else
186
 
                dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
187
 
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
188
 
                
189
 
                dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 
190
 
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
191
 
                        &beta, &work[0], &n );
192
 
#endif
193
 
                for (j = 0; j < nrhs; j++) {
194
 
                    rhs_work = &Bmat[j*ldb];
195
 
                    work_col = &work[j*n];
196
 
                    iptr = istart + nsupc;
197
 
                    for (i = 0; i < nrow; i++) {
198
 
                        irow = L_SUB(iptr);
199
 
                        rhs_work[irow] -= work_col[i]; /* Scatter */
200
 
                        work_col[i] = 0.0;
201
 
                        iptr++;
202
 
                    }
203
 
                }
204
 
#else           
205
 
                for (j = 0; j < nrhs; j++) {
206
 
                    rhs_work = &Bmat[j*ldb];
207
 
                    dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
208
 
                    dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
209
 
                            &rhs_work[fsupc], &work[0] );
210
 
 
211
 
                    iptr = istart + nsupc;
212
 
                    for (i = 0; i < nrow; i++) {
213
 
                        irow = L_SUB(iptr);
214
 
                        rhs_work[irow] -= work[i];
215
 
                        work[i] = 0.0;
216
 
                        iptr++;
217
 
                    }
218
 
                }
219
 
#endif              
220
 
            } /* else ... */
221
 
        } /* for L-solve */
222
 
 
223
 
#ifdef DEBUG
224
 
        printf("After L-solve: y=\n");
225
 
        dprint_soln(n, nrhs, Bmat);
226
 
#endif
227
 
 
228
 
        /*
229
 
         * Back solve Ux=y.
230
 
         */
231
 
        for (k = Lstore->nsuper; k >= 0; k--) {
232
 
            fsupc = L_FST_SUPC(k);
233
 
            istart = L_SUB_START(fsupc);
234
 
            nsupr = L_SUB_START(fsupc+1) - istart;
235
 
            nsupc = L_FST_SUPC(k+1) - fsupc;
236
 
            luptr = L_NZ_START(fsupc);
237
 
 
238
 
            solve_ops += nsupc * (nsupc + 1) * nrhs;
239
 
 
240
 
            if ( nsupc == 1 ) {
241
 
                rhs_work = &Bmat[0];
242
 
                for (j = 0; j < nrhs; j++) {
243
 
                    rhs_work[fsupc] /= Lval[luptr];
244
 
                    rhs_work += ldb;
245
 
                }
246
 
            } else {
247
 
#ifdef USE_VENDOR_BLAS
248
 
#ifdef _CRAY
249
 
                ftcs1 = _cptofcd("L", strlen("L"));
250
 
                ftcs2 = _cptofcd("U", strlen("U"));
251
 
                ftcs3 = _cptofcd("N", strlen("N"));
252
 
                STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
253
 
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
254
 
#else
255
 
                dtrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
256
 
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
257
 
#endif
258
 
#else           
259
 
                for (j = 0; j < nrhs; j++)
260
 
                    dusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
261
 
#endif          
262
 
            }
263
 
 
264
 
            for (j = 0; j < nrhs; ++j) {
265
 
                rhs_work = &Bmat[j*ldb];
266
 
                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
267
 
                    solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
268
 
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
269
 
                        irow = U_SUB(i);
270
 
                        rhs_work[irow] -= rhs_work[jcol] * Uval[i];
271
 
                    }
272
 
                }
273
 
            }
274
 
            
275
 
        } /* for U-solve */
276
 
 
277
 
#ifdef DEBUG
278
 
        printf("After U-solve: x=\n");
279
 
        dprint_soln(n, nrhs, Bmat);
280
 
#endif
281
 
 
282
 
        /* Compute the final solution X := Pc*X. */
283
 
        for (i = 0; i < nrhs; i++) {
284
 
            rhs_work = &Bmat[i*ldb];
285
 
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
286
 
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
287
 
        }
288
 
        
289
 
        stat->ops[SOLVE] = solve_ops;
290
 
 
291
 
    } else { /* Solve A'*X=B */
292
 
        /* Permute right hand sides to form Pc'*B. */
293
 
        for (i = 0; i < nrhs; i++) {
294
 
            rhs_work = &Bmat[i*ldb];
295
 
            for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
296
 
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
297
 
        }
298
 
 
299
 
        stat->ops[SOLVE] = 0;
300
 
        
301
 
        for (k = 0; k < nrhs; ++k) {
302
 
            
303
 
            /* Multiply by inv(U'). */
304
 
            sp_dtrsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);
305
 
            
306
 
            /* Multiply by inv(L'). */
307
 
            sp_dtrsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info);
308
 
            
309
 
        }
310
 
        
311
 
        /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
312
 
        for (i = 0; i < nrhs; i++) {
313
 
            rhs_work = &Bmat[i*ldb];
314
 
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
315
 
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
316
 
        }
317
 
 
318
 
    }
319
 
 
320
 
    SUPERLU_FREE(work);
321
 
    SUPERLU_FREE(soln);
322
 
}
323
 
 
324
 
/*
325
 
 * Diagnostic print of the solution vector 
326
 
 */
327
 
void
328
 
dprint_soln(int n, int nrhs, double *soln)
329
 
{
330
 
    int i;
331
 
 
332
 
    for (i = 0; i < n; i++) 
333
 
        printf("\t%d: %.4f\n", i, soln[i]);
334
 
}