~ubuntu-branches/ubuntu/karmic/python-scipy/karmic

« back to all changes in this revision

Viewing changes to scipy/linsolve/SuperLU/SRC/zgstrs.c

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
 
 
2
/*
 
3
 * -- SuperLU routine (version 3.0) --
 
4
 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
 
5
 * and Lawrence Berkeley National Lab.
 
6
 * October 15, 2003
 
7
 *
 
8
 */
 
9
/*
 
10
  Copyright (c) 1994 by Xerox Corporation.  All rights reserved.
 
11
 
 
12
  THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
 
13
  EXPRESSED OR IMPLIED.  ANY USE IS AT YOUR OWN RISK.
 
14
 
 
15
  Permission is hereby granted to use or copy this program for any
 
16
  purpose, provided the above notices are retained on all copies.
 
17
  Permission to modify the code and to distribute modified code is
 
18
  granted, provided the above notices are retained, and a notice that
 
19
  the code was modified is included with the above copyright notice.
 
20
*/
 
21
 
 
22
#include "zsp_defs.h"
 
23
 
 
24
 
 
25
/* 
 
26
 * Function prototypes 
 
27
 */
 
28
void zusolve(int, int, doublecomplex*, doublecomplex*);
 
29
void zlsolve(int, int, doublecomplex*, doublecomplex*);
 
30
void zmatvec(int, int, int, doublecomplex*, doublecomplex*, doublecomplex*);
 
31
 
 
32
 
 
33
void
 
34
zgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
 
35
        int *perm_c, int *perm_r, SuperMatrix *B,
 
36
        SuperLUStat_t *stat, int *info)
 
37
{
 
38
/*
 
39
 * Purpose
 
40
 * =======
 
41
 *
 
42
 * ZGSTRS solves a system of linear equations A*X=B or A'*X=B
 
43
 * with A sparse and B dense, using the LU factorization computed by
 
44
 * ZGSTRF.
 
45
 *
 
46
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
 
47
 *
 
48
 * Arguments
 
49
 * =========
 
50
 *
 
51
 * trans   (input) trans_t
 
52
 *          Specifies the form of the system of equations:
 
53
 *          = NOTRANS: A * X = B  (No transpose)
 
54
 *          = TRANS:   A'* X = B  (Transpose)
 
55
 *          = CONJ:    A**H * X = B  (Conjugate transpose)
 
56
 *
 
57
 * L       (input) SuperMatrix*
 
58
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
 
59
 *         zgstrf(). Use compressed row subscripts storage for supernodes,
 
60
 *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_Z, Mtype = SLU_TRLU.
 
61
 *
 
62
 * U       (input) SuperMatrix*
 
63
 *         The factor U from the factorization Pr*A*Pc=L*U as computed by
 
64
 *         zgstrf(). Use column-wise storage scheme, i.e., U has types:
 
65
 *         Stype = SLU_NC, Dtype = SLU_Z, Mtype = SLU_TRU.
 
66
 *
 
67
 * perm_c  (input) int*, dimension (L->ncol)
 
68
 *         Column permutation vector, which defines the 
 
69
 *         permutation matrix Pc; perm_c[i] = j means column i of A is 
 
70
 *         in position j in A*Pc.
 
71
 *
 
72
 * perm_r  (input) int*, dimension (L->nrow)
 
73
 *         Row permutation vector, which defines the permutation matrix Pr; 
 
74
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
 
75
 *
 
76
 * B       (input/output) SuperMatrix*
 
77
 *         B has types: Stype = SLU_DN, Dtype = SLU_Z, Mtype = SLU_GE.
 
78
 *         On entry, the right hand side matrix.
 
79
 *         On exit, the solution matrix if info = 0;
 
80
 *
 
81
 * stat     (output) SuperLUStat_t*
 
82
 *          Record the statistics on runtime and floating-point operation count.
 
83
 *          See util.h for the definition of 'SuperLUStat_t'.
 
84
 *
 
85
 * info    (output) int*
 
86
 *         = 0: successful exit
 
87
 *         < 0: if info = -i, the i-th argument had an illegal value
 
88
 *
 
89
 */
 
90
#ifdef _CRAY
 
91
    _fcd ftcs1, ftcs2, ftcs3, ftcs4;
 
92
#endif
 
93
    int      incx = 1, incy = 1;
 
94
#ifdef USE_VENDOR_BLAS
 
95
    doublecomplex   alpha = {1.0, 0.0}, beta = {1.0, 0.0};
 
96
    doublecomplex   *work_col;
 
97
#endif
 
98
    doublecomplex   temp_comp;
 
99
    DNformat *Bstore;
 
100
    doublecomplex   *Bmat;
 
101
    SCformat *Lstore;
 
102
    NCformat *Ustore;
 
103
    doublecomplex   *Lval, *Uval;
 
104
    int      fsupc, nrow, nsupr, nsupc, luptr, istart, irow;
 
105
    int      i, j, k, iptr, jcol, n, ldb, nrhs;
 
106
    doublecomplex   *work, *rhs_work, *soln;
 
107
    flops_t  solve_ops;
 
108
    void zprint_soln();
 
109
 
 
110
    /* Test input parameters ... */
 
111
    *info = 0;
 
112
    Bstore = B->Store;
 
113
    ldb = Bstore->lda;
 
114
    nrhs = B->ncol;
 
115
    if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
 
116
    else if ( L->nrow != L->ncol || L->nrow < 0 ||
 
117
              L->Stype != SLU_SC || L->Dtype != SLU_Z || L->Mtype != SLU_TRLU )
 
118
        *info = -2;
 
119
    else if ( U->nrow != U->ncol || U->nrow < 0 ||
 
120
              U->Stype != SLU_NC || U->Dtype != SLU_Z || U->Mtype != SLU_TRU )
 
121
        *info = -3;
 
122
    else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
 
123
              B->Stype != SLU_DN || B->Dtype != SLU_Z || B->Mtype != SLU_GE )
 
124
        *info = -6;
 
125
    if ( *info ) {
 
126
        i = -(*info);
 
127
        xerbla_("zgstrs", &i);
 
128
        return;
 
129
    }
 
130
 
 
131
    n = L->nrow;
 
132
    work = doublecomplexCalloc(n * nrhs);
 
133
    if ( !work ) ABORT("Malloc fails for local work[].");
 
134
    soln = doublecomplexMalloc(n);
 
135
    if ( !soln ) ABORT("Malloc fails for local soln[].");
 
136
 
 
137
    Bmat = Bstore->nzval;
 
138
    Lstore = L->Store;
 
139
    Lval = Lstore->nzval;
 
140
    Ustore = U->Store;
 
141
    Uval = Ustore->nzval;
 
142
    solve_ops = 0;
 
143
    
 
144
    if ( trans == NOTRANS ) {
 
145
        /* Permute right hand sides to form Pr*B */
 
146
        for (i = 0; i < nrhs; i++) {
 
147
            rhs_work = &Bmat[i*ldb];
 
148
            for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
 
149
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
 
150
        }
 
151
        
 
152
        /* Forward solve PLy=Pb. */
 
153
        for (k = 0; k <= Lstore->nsuper; k++) {
 
154
            fsupc = L_FST_SUPC(k);
 
155
            istart = L_SUB_START(fsupc);
 
156
            nsupr = L_SUB_START(fsupc+1) - istart;
 
157
            nsupc = L_FST_SUPC(k+1) - fsupc;
 
158
            nrow = nsupr - nsupc;
 
159
 
 
160
            solve_ops += 4 * nsupc * (nsupc - 1) * nrhs;
 
161
            solve_ops += 8 * nrow * nsupc * nrhs;
 
162
            
 
163
            if ( nsupc == 1 ) {
 
164
                for (j = 0; j < nrhs; j++) {
 
165
                    rhs_work = &Bmat[j*ldb];
 
166
                    luptr = L_NZ_START(fsupc);
 
167
                    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
 
168
                        irow = L_SUB(iptr);
 
169
                        ++luptr;
 
170
                        zz_mult(&temp_comp, &rhs_work[fsupc], &Lval[luptr]);
 
171
                        z_sub(&rhs_work[irow], &rhs_work[irow], &temp_comp);
 
172
                    }
 
173
                }
 
174
            } else {
 
175
                luptr = L_NZ_START(fsupc);
 
176
#ifdef USE_VENDOR_BLAS
 
177
#ifdef _CRAY
 
178
                ftcs1 = _cptofcd("L", strlen("L"));
 
179
                ftcs2 = _cptofcd("N", strlen("N"));
 
180
                ftcs3 = _cptofcd("U", strlen("U"));
 
181
                CTRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
 
182
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
 
183
                
 
184
                CGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 
 
185
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
 
186
                        &beta, &work[0], &n );
 
187
#else
 
188
                ztrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
 
189
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
 
190
                
 
191
                zgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 
 
192
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
 
193
                        &beta, &work[0], &n );
 
194
#endif
 
195
                for (j = 0; j < nrhs; j++) {
 
196
                    rhs_work = &Bmat[j*ldb];
 
197
                    work_col = &work[j*n];
 
198
                    iptr = istart + nsupc;
 
199
                    for (i = 0; i < nrow; i++) {
 
200
                        irow = L_SUB(iptr);
 
201
                        z_sub(&rhs_work[irow], &rhs_work[irow], &work_col[i]);
 
202
                        work_col[i].r = 0.0;
 
203
                        work_col[i].i = 0.0;
 
204
                        iptr++;
 
205
                    }
 
206
                }
 
207
#else           
 
208
                for (j = 0; j < nrhs; j++) {
 
209
                    rhs_work = &Bmat[j*ldb];
 
210
                    zlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
 
211
                    zmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
 
212
                            &rhs_work[fsupc], &work[0] );
 
213
 
 
214
                    iptr = istart + nsupc;
 
215
                    for (i = 0; i < nrow; i++) {
 
216
                        irow = L_SUB(iptr);
 
217
                        z_sub(&rhs_work[irow], &rhs_work[irow], &work[i]);
 
218
                        work[i].r = 0.;
 
219
                        work[i].i = 0.;
 
220
                        iptr++;
 
221
                    }
 
222
                }
 
223
#endif              
 
224
            } /* else ... */
 
225
        } /* for L-solve */
 
226
 
 
227
#ifdef DEBUG
 
228
        printf("After L-solve: y=\n");
 
229
        zprint_soln(n, nrhs, Bmat);
 
230
#endif
 
231
 
 
232
        /*
 
233
         * Back solve Ux=y.
 
234
         */
 
235
        for (k = Lstore->nsuper; k >= 0; k--) {
 
236
            fsupc = L_FST_SUPC(k);
 
237
            istart = L_SUB_START(fsupc);
 
238
            nsupr = L_SUB_START(fsupc+1) - istart;
 
239
            nsupc = L_FST_SUPC(k+1) - fsupc;
 
240
            luptr = L_NZ_START(fsupc);
 
241
 
 
242
            solve_ops += 4 * nsupc * (nsupc + 1) * nrhs;
 
243
 
 
244
            if ( nsupc == 1 ) {
 
245
                rhs_work = &Bmat[0];
 
246
                for (j = 0; j < nrhs; j++) {
 
247
                    z_div(&rhs_work[fsupc], &rhs_work[fsupc], &Lval[luptr]);
 
248
                    rhs_work += ldb;
 
249
                }
 
250
            } else {
 
251
#ifdef USE_VENDOR_BLAS
 
252
#ifdef _CRAY
 
253
                ftcs1 = _cptofcd("L", strlen("L"));
 
254
                ftcs2 = _cptofcd("U", strlen("U"));
 
255
                ftcs3 = _cptofcd("N", strlen("N"));
 
256
                CTRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
 
257
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
 
258
#else
 
259
                ztrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
 
260
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
 
261
#endif
 
262
#else           
 
263
                for (j = 0; j < nrhs; j++)
 
264
                    zusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
 
265
#endif          
 
266
            }
 
267
 
 
268
            for (j = 0; j < nrhs; ++j) {
 
269
                rhs_work = &Bmat[j*ldb];
 
270
                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
 
271
                    solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
 
272
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
 
273
                        irow = U_SUB(i);
 
274
                        zz_mult(&temp_comp, &rhs_work[jcol], &Uval[i]);
 
275
                        z_sub(&rhs_work[irow], &rhs_work[irow], &temp_comp);
 
276
                    }
 
277
                }
 
278
            }
 
279
            
 
280
        } /* for U-solve */
 
281
 
 
282
#ifdef DEBUG
 
283
        printf("After U-solve: x=\n");
 
284
        zprint_soln(n, nrhs, Bmat);
 
285
#endif
 
286
 
 
287
        /* Compute the final solution X := Pc*X. */
 
288
        for (i = 0; i < nrhs; i++) {
 
289
            rhs_work = &Bmat[i*ldb];
 
290
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
 
291
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
 
292
        }
 
293
        
 
294
        stat->ops[SOLVE] = solve_ops;
 
295
 
 
296
    } else { /* Solve A'*X=B */
 
297
        /* Permute right hand sides to form Pc'*B. */
 
298
        for (i = 0; i < nrhs; i++) {
 
299
            rhs_work = &Bmat[i*ldb];
 
300
            for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
 
301
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
 
302
        }
 
303
 
 
304
        stat->ops[SOLVE] = 0;
 
305
        
 
306
        if (trans == TRANS) {
 
307
        
 
308
            for (k = 0; k < nrhs; ++k) {
 
309
                
 
310
                /* Multiply by inv(U'). */
 
311
                sp_ztrsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);
 
312
                
 
313
                /* Multiply by inv(L'). */
 
314
                sp_ztrsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info);
 
315
                
 
316
            }
 
317
        }
 
318
        else {
 
319
            for (k = 0; k < nrhs; ++k) {
 
320
                /* Multiply by inv(U'). */
 
321
                sp_ztrsv("U", "C", "N", L, U, &Bmat[k*ldb], stat, info);
 
322
                
 
323
                /* Multiply by inv(L'). */
 
324
                sp_ztrsv("L", "C", "U", L, U, &Bmat[k*ldb], stat, info);
 
325
                
 
326
            }
 
327
        }
 
328
        
 
329
        /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
 
330
        for (i = 0; i < nrhs; i++) {
 
331
            rhs_work = &Bmat[i*ldb];
 
332
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
 
333
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
 
334
        }
 
335
 
 
336
    }
 
337
 
 
338
    SUPERLU_FREE(work);
 
339
    SUPERLU_FREE(soln);
 
340
}
 
341
 
 
342
/*
 
343
 * Diagnostic print of the solution vector 
 
344
 */
 
345
void
 
346
zprint_soln(int n, int nrhs, doublecomplex *soln)
 
347
{
 
348
    int i;
 
349
 
 
350
    for (i = 0; i < n; i++) 
 
351
        printf("\t%d: %.4f\n", i, soln[i]);
 
352
}