4
* -- SuperLU routine (version 2.0) --
5
* Univ. of California Berkeley, Xerox Palo Alto Research Center,
6
* and Lawrence Berkeley National Lab.
11
Copyright (c) 1994 by Xerox Corporation. All rights reserved.
13
THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
14
EXPRESSED OR IMPLIED. ANY USE IS AT YOUR OWN RISK.
16
Permission is hereby granted to use or copy this program for any
17
purpose, provided the above notices are retained on all copies.
18
Permission to modify the code and to distribute modified code is
19
granted, provided the above notices are retained, and a notice that
20
the code was modified is included with the above copyright notice.
30
/*
 * Dense helper kernels that dgstrs() applies to the blocks of one
 * supernode.  At the visible call sites each receives the supernode's
 * row count (nsupr) as its first argument, which the vendor-BLAS branch
 * passes to dtrsm/dgemm as the leading dimension of the same block.
 * NOTE(review): these calls appear to be the non-USE_VENDOR_BLAS path;
 * the surrounding #else/#endif is not visible here -- confirm.
 */

/* Solve with the ncol-by-ncol diagonal block of an upper-triangular
 * supernode (called once per right-hand side in the U-solve).
 * Presumably overwrites rhs in place -- confirm against the definition. */
void dusolve(int ldm, int ncol, double *M, double *rhs);

/* Solve with the ncol-by-ncol diagonal block of a lower-triangular
 * supernode (called once per right-hand side in the L-solve; the
 * vendor-BLAS counterpart is a unit-diagonal dtrsm).
 * Presumably overwrites rhs in place -- confirm against the definition. */
void dlsolve(int ldm, int ncol, double *M, double *rhs);

/* Multiply the nrow-by-ncol off-diagonal block M by vec; the caller then
 * subtracts the first nrow entries of Mxvec from the scattered rows.
 * NOTE(review): whether Mxvec is overwritten or accumulated into is not
 * visible here (the vendor path uses dgemm with beta = 1) -- confirm. */
void dmatvec(int ldm, int nrow, int ncol, double *M, double *vec,
             double *Mxvec);
36
dgstrs (char *trans, SuperMatrix *L, SuperMatrix *U,
37
int *perm_r, int *perm_c, SuperMatrix *B, int *info)
43
* DGSTRS solves a system of linear equations A*X=B or A'*X=B
44
* with A sparse and B dense, using the LU factorization computed by
47
* See supermatrix.h for the definition of 'SuperMatrix' structure.
53
* Specifies the form of the system of equations:
54
* = 'N': A * X = B (No transpose)
55
* = 'T': A'* X = B (Transpose)
56
* = 'C': A**H * X = B (Conjugate transpose)
58
* L (input) SuperMatrix*
59
* The factor L from the factorization Pr*A*Pc=L*U as computed by
60
* dgstrf(). Use compressed row subscripts storage for supernodes,
61
* i.e., L has types: Stype = SC, Dtype = D_, Mtype = TRLU.
63
* U (input) SuperMatrix*
64
* The factor U from the factorization Pr*A*Pc=L*U as computed by
65
* dgstrf(). Use column-wise storage scheme, i.e., U has types:
66
* Stype = NC, Dtype = D_, Mtype = TRU.
68
* perm_r (input) int*, dimension (L->nrow)
69
* Row permutation vector, which defines the permutation matrix Pr;
70
* perm_r[i] = j means row i of A is in position j in Pr*A.
72
* perm_c (input) int*, dimension (L->ncol)
73
* Column permutation vector, which defines the
74
* permutation matrix Pc; perm_c[i] = j means column i of A is
75
* in position j in A*Pc.
77
* B (input/output) SuperMatrix*
78
* B has types: Stype = DN, Dtype = D_, Mtype = GE.
79
* On entry, the right hand side matrix.
80
* On exit, the solution matrix if info = 0;
83
* = 0: successful exit
84
* < 0: if info = -i, the i-th argument had an illegal value
88
_fcd ftcs1, ftcs2, ftcs3, ftcs4;
90
int incx = 1, incy = 1;
91
double alpha = 1.0, beta = 1.0;
98
int fsupc, nsupr, nsupc, luptr, istart, irow;
99
int i, j, k, iptr, jcol, n, ldb, nrhs;
100
double *work, *work_col, *rhs_work, *soln;
102
extern SuperLUStat_t SuperLUStat;
105
/* Test input parameters ... */
110
notran = lsame_(trans, "N");
111
if ( !notran && !lsame_(trans, "T") && !lsame_(trans, "C") ) *info = -1;
112
else if ( L->nrow != L->ncol || L->nrow < 0 ||
113
L->Stype != SC || L->Dtype != D_ || L->Mtype != TRLU )
115
else if ( U->nrow != U->ncol || U->nrow < 0 ||
116
U->Stype != NC || U->Dtype != D_ || U->Mtype != TRU )
118
else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
119
B->Stype != DN || B->Dtype != D_ || B->Mtype != GE )
123
xerbla_("dgstrs", &i);
128
work = doubleCalloc(n * nrhs);
129
if ( !work ) ABORT("Malloc fails for local work[].");
130
soln = doubleMalloc(n);
131
if ( !soln ) ABORT("Malloc fails for local soln[].");
133
Bmat = Bstore->nzval;
135
Lval = Lstore->nzval;
137
Uval = Ustore->nzval;
141
/* Permute right hand sides to form Pr*B */
142
for (i = 0; i < nrhs; i++) {
143
rhs_work = &Bmat[i*ldb];
144
for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
145
for (k = 0; k < n; k++) rhs_work[k] = soln[k];
148
/* Forward solve PLy=Pb. */
149
for (k = 0; k <= Lstore->nsuper; k++) {
150
fsupc = L_FST_SUPC(k);
151
istart = L_SUB_START(fsupc);
152
nsupr = L_SUB_START(fsupc+1) - istart;
153
nsupc = L_FST_SUPC(k+1) - fsupc;
154
nrow = nsupr - nsupc;
156
solve_ops += nsupc * (nsupc - 1) * nrhs;
157
solve_ops += 2 * nrow * nsupc * nrhs;
160
for (j = 0; j < nrhs; j++) {
161
rhs_work = &Bmat[j*ldb];
162
luptr = L_NZ_START(fsupc);
163
for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
166
rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
170
luptr = L_NZ_START(fsupc);
171
#ifdef USE_VENDOR_BLAS
173
ftcs1 = _cptofcd("L", strlen("L"));
174
ftcs2 = _cptofcd("N", strlen("N"));
175
ftcs3 = _cptofcd("U", strlen("U"));
176
STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
177
&Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
179
SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha,
180
&Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
181
&beta, &work[0], &n );
183
dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
184
&Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
186
dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha,
187
&Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
188
&beta, &work[0], &n );
190
for (j = 0; j < nrhs; j++) {
191
rhs_work = &Bmat[j*ldb];
192
work_col = &work[j*n];
193
iptr = istart + nsupc;
194
for (i = 0; i < nrow; i++) {
196
rhs_work[irow] -= work_col[i]; /* Scatter */
202
for (j = 0; j < nrhs; j++) {
203
rhs_work = &Bmat[j*ldb];
204
dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
205
dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
206
&rhs_work[fsupc], &work[0] );
208
iptr = istart + nsupc;
209
for (i = 0; i < nrow; i++) {
211
rhs_work[irow] -= work[i];
221
printf("After L-solve: y=\n");
222
dprint_soln(n, nrhs, Bmat);
228
for (k = Lstore->nsuper; k >= 0; k--) {
229
fsupc = L_FST_SUPC(k);
230
istart = L_SUB_START(fsupc);
231
nsupr = L_SUB_START(fsupc+1) - istart;
232
nsupc = L_FST_SUPC(k+1) - fsupc;
233
luptr = L_NZ_START(fsupc);
235
solve_ops += nsupc * (nsupc + 1) * nrhs;
239
for (j = 0; j < nrhs; j++) {
240
rhs_work[fsupc] /= Lval[luptr];
244
#ifdef USE_VENDOR_BLAS
246
ftcs1 = _cptofcd("L", strlen("L"));
247
ftcs2 = _cptofcd("U", strlen("U"));
248
ftcs3 = _cptofcd("N", strlen("N"));
249
STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
250
&Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
252
dtrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
253
&Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
256
for (j = 0; j < nrhs; j++)
257
dusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
261
for (j = 0; j < nrhs; ++j) {
262
rhs_work = &Bmat[j*ldb];
263
for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
264
solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
265
for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
267
rhs_work[irow] -= rhs_work[jcol] * Uval[i];
275
printf("After U-solve: x=\n");
276
dprint_soln(n, nrhs, Bmat);
279
/* Compute the final solution X := Pc*X. */
280
for (i = 0; i < nrhs; i++) {
281
rhs_work = &Bmat[i*ldb];
282
for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
283
for (k = 0; k < n; k++) rhs_work[k] = soln[k];
286
SuperLUStat.ops[SOLVE] = solve_ops;
288
} else { /* Solve A'*X=B */
289
/* Permute right hand sides to form Pc'*B. */
290
for (i = 0; i < nrhs; i++) {
291
rhs_work = &Bmat[i*ldb];
292
for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
293
for (k = 0; k < n; k++) rhs_work[k] = soln[k];
296
SuperLUStat.ops[SOLVE] = 0;
298
for (k = 0; k < nrhs; ++k) {
300
/* Multiply by inv(U'). */
301
sp_dtrsv("U", "T", "N", L, U, &Bmat[k*ldb], info);
303
/* Multiply by inv(L'). */
304
sp_dtrsv("L", "T", "U", L, U, &Bmat[k*ldb], info);
308
/* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
309
for (i = 0; i < nrhs; i++) {
310
rhs_work = &Bmat[i*ldb];
311
for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
312
for (k = 0; k < n; k++) rhs_work[k] = soln[k];
322
* Diagnostic print of the solution vector
325
dprint_soln(int n, int nrhs, double *soln)
329
for (i = 0; i < n; i++)
330
printf("\t%d: %.4f\n", i, soln[i]);