~ubuntu-branches/ubuntu/raring/python-scipy/raring-proposed

« back to all changes in this revision

Viewing changes to Lib/linsolve/SuperLU/SRC/dgstrs.c

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2007-01-07 14:12:12 UTC
  • mfrom: (1.1.1 upstream)
  • Revision ID: james.westby@ubuntu.com-20070107141212-mm0ebkh5b37hcpzn
* Remove build dependency on python-numpy-dev.
* python-scipy: Depend on python-numpy instead of python-numpy-dev.
* Package builds on other archs than i386. Closes: #402783.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
 
 
2
/*
 
3
 * -- SuperLU routine (version 3.0) --
 
4
 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
 
5
 * and Lawrence Berkeley National Lab.
 
6
 * October 15, 2003
 
7
 *
 
8
 */
 
9
/*
 
10
  Copyright (c) 1994 by Xerox Corporation.  All rights reserved.
 
11
 
 
12
  THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
 
13
  EXPRESSED OR IMPLIED.  ANY USE IS AT YOUR OWN RISK.
 
14
 
 
15
  Permission is hereby granted to use or copy this program for any
 
16
  purpose, provided the above notices are retained on all copies.
 
17
  Permission to modify the code and to distribute modified code is
 
18
  granted, provided the above notices are retained, and a notice that
 
19
  the code was modified is included with the above copyright notice.
 
20
*/
 
21
 
 
22
#include "dsp_defs.h"
 
23
 
 
24
 
 
25
/* 
 
26
 * Function prototypes 
 
27
 */
 
28
void dusolve(int, int, double*, double*);
 
29
void dlsolve(int, int, double*, double*);
 
30
void dmatvec(int, int, int, double*, double*, double*);
 
31
 
 
32
 
 
33
void
 
34
dgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
 
35
        int *perm_c, int *perm_r, SuperMatrix *B,
 
36
        SuperLUStat_t *stat, int *info)
 
37
{
 
38
/*
 
39
 * Purpose
 
40
 * =======
 
41
 *
 
42
 * DGSTRS solves a system of linear equations A*X=B or A'*X=B
 
43
 * with A sparse and B dense, using the LU factorization computed by
 
44
 * DGSTRF.
 
45
 *
 
46
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
 
47
 *
 
48
 * Arguments
 
49
 * =========
 
50
 *
 
51
 * trans   (input) trans_t
 
52
 *          Specifies the form of the system of equations:
 
53
 *          = NOTRANS: A * X = B  (No transpose)
 
54
 *          = TRANS:   A'* X = B  (Transpose)
 
55
 *          = CONJ:    A**H * X = B  (Conjugate transpose)
 
56
 *
 
57
 * L       (input) SuperMatrix*
 
58
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
 
59
 *         dgstrf(). Use compressed row subscripts storage for supernodes,
 
60
 *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_D, Mtype = SLU_TRLU.
 
61
 *
 
62
 * U       (input) SuperMatrix*
 
63
 *         The factor U from the factorization Pr*A*Pc=L*U as computed by
 
64
 *         dgstrf(). Use column-wise storage scheme, i.e., U has types:
 
65
 *         Stype = SLU_NC, Dtype = SLU_D, Mtype = SLU_TRU.
 
66
 *
 
67
 * perm_c  (input) int*, dimension (L->ncol)
 
68
 *         Column permutation vector, which defines the 
 
69
 *         permutation matrix Pc; perm_c[i] = j means column i of A is 
 
70
 *         in position j in A*Pc.
 
71
 *
 
72
 * perm_r  (input) int*, dimension (L->nrow)
 
73
 *         Row permutation vector, which defines the permutation matrix Pr; 
 
74
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
 
75
 *
 
76
 * B       (input/output) SuperMatrix*
 
77
 *         B has types: Stype = SLU_DN, Dtype = SLU_D, Mtype = SLU_GE.
 
78
 *         On entry, the right hand side matrix.
 
79
 *         On exit, the solution matrix if info = 0;
 
80
 *
 
81
 * stat     (output) SuperLUStat_t*
 
82
 *          Record the statistics on runtime and floating-point operation count.
 
83
 *          See util.h for the definition of 'SuperLUStat_t'.
 
84
 *
 
85
 * info    (output) int*
 
86
 *         = 0: successful exit
 
87
 *         < 0: if info = -i, the i-th argument had an illegal value
 
88
 *
 
89
 */
 
90
#ifdef _CRAY
 
91
    _fcd ftcs1, ftcs2, ftcs3, ftcs4;
 
92
#endif
 
93
    int      incx = 1, incy = 1;
 
94
#ifdef USE_VENDOR_BLAS
 
95
    double   alpha = 1.0, beta = 1.0;
 
96
    double   *work_col;
 
97
#endif
 
98
    DNformat *Bstore;
 
99
    double   *Bmat;
 
100
    SCformat *Lstore;
 
101
    NCformat *Ustore;
 
102
    double   *Lval, *Uval;
 
103
    int      fsupc, nrow, nsupr, nsupc, luptr, istart, irow;
 
104
    int      i, j, k, iptr, jcol, n, ldb, nrhs;
 
105
    double   *work, *rhs_work, *soln;
 
106
    flops_t  solve_ops;
 
107
    void dprint_soln();
 
108
 
 
109
    /* Test input parameters ... */
 
110
    *info = 0;
 
111
    Bstore = B->Store;
 
112
    ldb = Bstore->lda;
 
113
    nrhs = B->ncol;
 
114
    if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
 
115
    else if ( L->nrow != L->ncol || L->nrow < 0 ||
 
116
              L->Stype != SLU_SC || L->Dtype != SLU_D || L->Mtype != SLU_TRLU )
 
117
        *info = -2;
 
118
    else if ( U->nrow != U->ncol || U->nrow < 0 ||
 
119
              U->Stype != SLU_NC || U->Dtype != SLU_D || U->Mtype != SLU_TRU )
 
120
        *info = -3;
 
121
    else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
 
122
              B->Stype != SLU_DN || B->Dtype != SLU_D || B->Mtype != SLU_GE )
 
123
        *info = -6;
 
124
    if ( *info ) {
 
125
        i = -(*info);
 
126
        xerbla_("dgstrs", &i);
 
127
        return;
 
128
    }
 
129
 
 
130
    n = L->nrow;
 
131
    work = doubleCalloc(n * nrhs);
 
132
    if ( !work ) ABORT("Malloc fails for local work[].");
 
133
    soln = doubleMalloc(n);
 
134
    if ( !soln ) ABORT("Malloc fails for local soln[].");
 
135
 
 
136
    Bmat = Bstore->nzval;
 
137
    Lstore = L->Store;
 
138
    Lval = Lstore->nzval;
 
139
    Ustore = U->Store;
 
140
    Uval = Ustore->nzval;
 
141
    solve_ops = 0;
 
142
    
 
143
    if ( trans == NOTRANS ) {
 
144
        /* Permute right hand sides to form Pr*B */
 
145
        for (i = 0; i < nrhs; i++) {
 
146
            rhs_work = &Bmat[i*ldb];
 
147
            for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
 
148
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
 
149
        }
 
150
        
 
151
        /* Forward solve PLy=Pb. */
 
152
        for (k = 0; k <= Lstore->nsuper; k++) {
 
153
            fsupc = L_FST_SUPC(k);
 
154
            istart = L_SUB_START(fsupc);
 
155
            nsupr = L_SUB_START(fsupc+1) - istart;
 
156
            nsupc = L_FST_SUPC(k+1) - fsupc;
 
157
            nrow = nsupr - nsupc;
 
158
 
 
159
            solve_ops += nsupc * (nsupc - 1) * nrhs;
 
160
            solve_ops += 2 * nrow * nsupc * nrhs;
 
161
            
 
162
            if ( nsupc == 1 ) {
 
163
                for (j = 0; j < nrhs; j++) {
 
164
                    rhs_work = &Bmat[j*ldb];
 
165
                    luptr = L_NZ_START(fsupc);
 
166
                    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
 
167
                        irow = L_SUB(iptr);
 
168
                        ++luptr;
 
169
                        rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
 
170
                    }
 
171
                }
 
172
            } else {
 
173
                luptr = L_NZ_START(fsupc);
 
174
#ifdef USE_VENDOR_BLAS
 
175
#ifdef _CRAY
 
176
                ftcs1 = _cptofcd("L", strlen("L"));
 
177
                ftcs2 = _cptofcd("N", strlen("N"));
 
178
                ftcs3 = _cptofcd("U", strlen("U"));
 
179
                STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
 
180
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
 
181
                
 
182
                SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 
 
183
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
 
184
                        &beta, &work[0], &n );
 
185
#else
 
186
                dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
 
187
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
 
188
                
 
189
                dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 
 
190
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
 
191
                        &beta, &work[0], &n );
 
192
#endif
 
193
                for (j = 0; j < nrhs; j++) {
 
194
                    rhs_work = &Bmat[j*ldb];
 
195
                    work_col = &work[j*n];
 
196
                    iptr = istart + nsupc;
 
197
                    for (i = 0; i < nrow; i++) {
 
198
                        irow = L_SUB(iptr);
 
199
                        rhs_work[irow] -= work_col[i]; /* Scatter */
 
200
                        work_col[i] = 0.0;
 
201
                        iptr++;
 
202
                    }
 
203
                }
 
204
#else           
 
205
                for (j = 0; j < nrhs; j++) {
 
206
                    rhs_work = &Bmat[j*ldb];
 
207
                    dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
 
208
                    dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
 
209
                            &rhs_work[fsupc], &work[0] );
 
210
 
 
211
                    iptr = istart + nsupc;
 
212
                    for (i = 0; i < nrow; i++) {
 
213
                        irow = L_SUB(iptr);
 
214
                        rhs_work[irow] -= work[i];
 
215
                        work[i] = 0.0;
 
216
                        iptr++;
 
217
                    }
 
218
                }
 
219
#endif              
 
220
            } /* else ... */
 
221
        } /* for L-solve */
 
222
 
 
223
#ifdef DEBUG
 
224
        printf("After L-solve: y=\n");
 
225
        dprint_soln(n, nrhs, Bmat);
 
226
#endif
 
227
 
 
228
        /*
 
229
         * Back solve Ux=y.
 
230
         */
 
231
        for (k = Lstore->nsuper; k >= 0; k--) {
 
232
            fsupc = L_FST_SUPC(k);
 
233
            istart = L_SUB_START(fsupc);
 
234
            nsupr = L_SUB_START(fsupc+1) - istart;
 
235
            nsupc = L_FST_SUPC(k+1) - fsupc;
 
236
            luptr = L_NZ_START(fsupc);
 
237
 
 
238
            solve_ops += nsupc * (nsupc + 1) * nrhs;
 
239
 
 
240
            if ( nsupc == 1 ) {
 
241
                rhs_work = &Bmat[0];
 
242
                for (j = 0; j < nrhs; j++) {
 
243
                    rhs_work[fsupc] /= Lval[luptr];
 
244
                    rhs_work += ldb;
 
245
                }
 
246
            } else {
 
247
#ifdef USE_VENDOR_BLAS
 
248
#ifdef _CRAY
 
249
                ftcs1 = _cptofcd("L", strlen("L"));
 
250
                ftcs2 = _cptofcd("U", strlen("U"));
 
251
                ftcs3 = _cptofcd("N", strlen("N"));
 
252
                STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
 
253
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
 
254
#else
 
255
                dtrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
 
256
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
 
257
#endif
 
258
#else           
 
259
                for (j = 0; j < nrhs; j++)
 
260
                    dusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
 
261
#endif          
 
262
            }
 
263
 
 
264
            for (j = 0; j < nrhs; ++j) {
 
265
                rhs_work = &Bmat[j*ldb];
 
266
                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
 
267
                    solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
 
268
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
 
269
                        irow = U_SUB(i);
 
270
                        rhs_work[irow] -= rhs_work[jcol] * Uval[i];
 
271
                    }
 
272
                }
 
273
            }
 
274
            
 
275
        } /* for U-solve */
 
276
 
 
277
#ifdef DEBUG
 
278
        printf("After U-solve: x=\n");
 
279
        dprint_soln(n, nrhs, Bmat);
 
280
#endif
 
281
 
 
282
        /* Compute the final solution X := Pc*X. */
 
283
        for (i = 0; i < nrhs; i++) {
 
284
            rhs_work = &Bmat[i*ldb];
 
285
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
 
286
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
 
287
        }
 
288
        
 
289
        stat->ops[SOLVE] = solve_ops;
 
290
 
 
291
    } else { /* Solve A'*X=B */
 
292
        /* Permute right hand sides to form Pc'*B. */
 
293
        for (i = 0; i < nrhs; i++) {
 
294
            rhs_work = &Bmat[i*ldb];
 
295
            for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
 
296
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
 
297
        }
 
298
 
 
299
        stat->ops[SOLVE] = 0;
 
300
        
 
301
        for (k = 0; k < nrhs; ++k) {
 
302
            
 
303
            /* Multiply by inv(U'). */
 
304
            sp_dtrsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);
 
305
            
 
306
            /* Multiply by inv(L'). */
 
307
            sp_dtrsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info);
 
308
            
 
309
        }
 
310
        
 
311
        /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
 
312
        for (i = 0; i < nrhs; i++) {
 
313
            rhs_work = &Bmat[i*ldb];
 
314
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
 
315
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
 
316
        }
 
317
 
 
318
    }
 
319
 
 
320
    SUPERLU_FREE(work);
 
321
    SUPERLU_FREE(soln);
 
322
}
 
323
 
 
324
/*
 
325
 * Diagnostic print of the solution vector 
 
326
 */
 
327
void
 
328
dprint_soln(int n, int nrhs, double *soln)
 
329
{
 
330
    int i;
 
331
 
 
332
    for (i = 0; i < n; i++) 
 
333
        printf("\t%d: %.4f\n", i, soln[i]);
 
334
}