3
* -- SuperLU routine (version 3.0) --
4
* Univ. of California Berkeley, Xerox Palo Alto Research Center,
5
* and Lawrence Berkeley National Lab.
10
Copyright (c) 1994 by Xerox Corporation. All rights reserved.
12
THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
13
EXPRESSED OR IMPLIED. ANY USE IS AT YOUR OWN RISK.
15
Permission is hereby granted to use or copy this program for any
16
purpose, provided the above notices are retained on all copies.
17
Permission to modify the code and to distribute modified code is
18
granted, provided the above notices are retained, and a notice that
19
the code was modified is included with the above copyright notice.
28
void zusolve(int, int, doublecomplex*, doublecomplex*);
29
void zlsolve(int, int, doublecomplex*, doublecomplex*);
30
void zmatvec(int, int, int, doublecomplex*, doublecomplex*, doublecomplex*);
34
zgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
35
int *perm_c, int *perm_r, SuperMatrix *B,
36
SuperLUStat_t *stat, int *info)
42
* ZGSTRS solves a system of linear equations A*X=B or A'*X=B
43
* with A sparse and B dense, using the LU factorization computed by
46
* See supermatrix.h for the definition of 'SuperMatrix' structure.
51
* trans (input) trans_t
52
* Specifies the form of the system of equations:
53
* = NOTRANS: A * X = B (No transpose)
54
* = TRANS: A'* X = B (Transpose)
55
* = CONJ: A**H * X = B (Conjugate transpose)
57
* L (input) SuperMatrix*
58
* The factor L from the factorization Pr*A*Pc=L*U as computed by
59
* zgstrf(). Use compressed row subscripts storage for supernodes,
60
* i.e., L has types: Stype = SLU_SC, Dtype = SLU_Z, Mtype = SLU_TRLU.
62
* U (input) SuperMatrix*
63
* The factor U from the factorization Pr*A*Pc=L*U as computed by
64
* zgstrf(). Use column-wise storage scheme, i.e., U has types:
65
* Stype = SLU_NC, Dtype = SLU_Z, Mtype = SLU_TRU.
67
* perm_c (input) int*, dimension (L->ncol)
68
* Column permutation vector, which defines the
69
* permutation matrix Pc; perm_c[i] = j means column i of A is
70
* in position j in A*Pc.
72
* perm_r (input) int*, dimension (L->nrow)
73
* Row permutation vector, which defines the permutation matrix Pr;
74
* perm_r[i] = j means row i of A is in position j in Pr*A.
76
* B (input/output) SuperMatrix*
77
* B has types: Stype = SLU_DN, Dtype = SLU_Z, Mtype = SLU_GE.
78
* On entry, the right hand side matrix.
79
* On exit, the solution matrix if info = 0;
81
* stat (output) SuperLUStat_t*
82
* Record the statistics on runtime and floating-point operation count.
83
* See util.h for the definition of 'SuperLUStat_t'.
86
* = 0: successful exit
87
* < 0: if info = -i, the i-th argument had an illegal value
91
_fcd ftcs1, ftcs2, ftcs3, ftcs4;
93
int incx = 1, incy = 1;
94
#ifdef USE_VENDOR_BLAS
95
doublecomplex alpha = {1.0, 0.0}, beta = {1.0, 0.0};
96
doublecomplex *work_col;
98
doublecomplex temp_comp;
103
doublecomplex *Lval, *Uval;
104
int fsupc, nrow, nsupr, nsupc, luptr, istart, irow;
105
int i, j, k, iptr, jcol, n, ldb, nrhs;
106
doublecomplex *work, *rhs_work, *soln;
110
/* Test input parameters ... */
115
if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
116
else if ( L->nrow != L->ncol || L->nrow < 0 ||
117
L->Stype != SLU_SC || L->Dtype != SLU_Z || L->Mtype != SLU_TRLU )
119
else if ( U->nrow != U->ncol || U->nrow < 0 ||
120
U->Stype != SLU_NC || U->Dtype != SLU_Z || U->Mtype != SLU_TRU )
122
else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
123
B->Stype != SLU_DN || B->Dtype != SLU_Z || B->Mtype != SLU_GE )
127
xerbla_("zgstrs", &i);
132
work = doublecomplexCalloc(n * nrhs);
133
if ( !work ) ABORT("Malloc fails for local work[].");
134
soln = doublecomplexMalloc(n);
135
if ( !soln ) ABORT("Malloc fails for local soln[].");
137
Bmat = Bstore->nzval;
139
Lval = Lstore->nzval;
141
Uval = Ustore->nzval;
144
if ( trans == NOTRANS ) {
145
/* Permute right hand sides to form Pr*B */
146
for (i = 0; i < nrhs; i++) {
147
rhs_work = &Bmat[i*ldb];
148
for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
149
for (k = 0; k < n; k++) rhs_work[k] = soln[k];
152
/* Forward solve PLy=Pb. */
153
for (k = 0; k <= Lstore->nsuper; k++) {
154
fsupc = L_FST_SUPC(k);
155
istart = L_SUB_START(fsupc);
156
nsupr = L_SUB_START(fsupc+1) - istart;
157
nsupc = L_FST_SUPC(k+1) - fsupc;
158
nrow = nsupr - nsupc;
160
solve_ops += 4 * nsupc * (nsupc - 1) * nrhs;
161
solve_ops += 8 * nrow * nsupc * nrhs;
164
for (j = 0; j < nrhs; j++) {
165
rhs_work = &Bmat[j*ldb];
166
luptr = L_NZ_START(fsupc);
167
for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
170
zz_mult(&temp_comp, &rhs_work[fsupc], &Lval[luptr]);
171
z_sub(&rhs_work[irow], &rhs_work[irow], &temp_comp);
175
luptr = L_NZ_START(fsupc);
176
#ifdef USE_VENDOR_BLAS
178
ftcs1 = _cptofcd("L", strlen("L"));
179
ftcs2 = _cptofcd("N", strlen("N"));
180
ftcs3 = _cptofcd("U", strlen("U"));
181
CTRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
182
&Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
184
CGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha,
185
&Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
186
&beta, &work[0], &n );
188
ztrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
189
&Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
191
zgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha,
192
&Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
193
&beta, &work[0], &n );
195
for (j = 0; j < nrhs; j++) {
196
rhs_work = &Bmat[j*ldb];
197
work_col = &work[j*n];
198
iptr = istart + nsupc;
199
for (i = 0; i < nrow; i++) {
201
z_sub(&rhs_work[irow], &rhs_work[irow], &work_col[i]);
208
for (j = 0; j < nrhs; j++) {
209
rhs_work = &Bmat[j*ldb];
210
zlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
211
zmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
212
&rhs_work[fsupc], &work[0] );
214
iptr = istart + nsupc;
215
for (i = 0; i < nrow; i++) {
217
z_sub(&rhs_work[irow], &rhs_work[irow], &work[i]);
228
printf("After L-solve: y=\n");
229
zprint_soln(n, nrhs, Bmat);
235
for (k = Lstore->nsuper; k >= 0; k--) {
236
fsupc = L_FST_SUPC(k);
237
istart = L_SUB_START(fsupc);
238
nsupr = L_SUB_START(fsupc+1) - istart;
239
nsupc = L_FST_SUPC(k+1) - fsupc;
240
luptr = L_NZ_START(fsupc);
242
solve_ops += 4 * nsupc * (nsupc + 1) * nrhs;
246
for (j = 0; j < nrhs; j++) {
247
z_div(&rhs_work[fsupc], &rhs_work[fsupc], &Lval[luptr]);
251
#ifdef USE_VENDOR_BLAS
253
ftcs1 = _cptofcd("L", strlen("L"));
254
ftcs2 = _cptofcd("U", strlen("U"));
255
ftcs3 = _cptofcd("N", strlen("N"));
256
CTRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
257
&Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
259
ztrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
260
&Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
263
for (j = 0; j < nrhs; j++)
264
zusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
268
for (j = 0; j < nrhs; ++j) {
269
rhs_work = &Bmat[j*ldb];
270
for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
271
solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
272
for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
274
zz_mult(&temp_comp, &rhs_work[jcol], &Uval[i]);
275
z_sub(&rhs_work[irow], &rhs_work[irow], &temp_comp);
283
printf("After U-solve: x=\n");
284
zprint_soln(n, nrhs, Bmat);
287
/* Compute the final solution X := Pc*X. */
288
for (i = 0; i < nrhs; i++) {
289
rhs_work = &Bmat[i*ldb];
290
for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
291
for (k = 0; k < n; k++) rhs_work[k] = soln[k];
294
stat->ops[SOLVE] = solve_ops;
296
} else { /* Solve A'*X=B */
297
/* Permute right hand sides to form Pc'*B. */
298
for (i = 0; i < nrhs; i++) {
299
rhs_work = &Bmat[i*ldb];
300
for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
301
for (k = 0; k < n; k++) rhs_work[k] = soln[k];
304
stat->ops[SOLVE] = 0;
306
if (trans == TRANS) {
308
for (k = 0; k < nrhs; ++k) {
310
/* Multiply by inv(U'). */
311
sp_ztrsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);
313
/* Multiply by inv(L'). */
314
sp_ztrsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info);
319
for (k = 0; k < nrhs; ++k) {
320
/* Multiply by inv(U'). */
321
sp_ztrsv("U", "C", "N", L, U, &Bmat[k*ldb], stat, info);
323
/* Multiply by inv(L'). */
324
sp_ztrsv("L", "C", "U", L, U, &Bmat[k*ldb], stat, info);
329
/* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
330
for (i = 0; i < nrhs; i++) {
331
rhs_work = &Bmat[i*ldb];
332
for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
333
for (k = 0; k < n; k++) rhs_work[k] = soln[k];
343
* Diagnostic print of the solution vector
346
zprint_soln(int n, int nrhs, doublecomplex *soln)
350
for (i = 0; i < n; i++)
351
printf("\t%d: %.4f\n", i, soln[i]);