~ubuntu-branches/ubuntu/karmic/python-scipy/karmic

« back to all changes in this revision

Viewing changes to Lib/sparse/_superluobject.c

  • Committer: Bazaar Package Importer
  • Author(s): Daniel T. Chen (new)
  • Date: 2005-03-16 02:15:29 UTC
  • Revision ID: james.westby@ubuntu.com-20050316021529-xrjlowsejs0cijig
Tags: upstream-0.3.2
ImportĀ upstreamĀ versionĀ 0.3.2

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
 
 
2
#include <setjmp.h>
 
3
#include "SuperLU/SRC/zsp_defs.h"
 
4
#define NO_IMPORT_ARRAY
 
5
#include "_superluobject.h"
 
6
 
 
7
extern jmp_buf _superlu_py_jmpbuf;
 
8
 
 
9
/*********************************************************************** 
 
10
 * SciPyLUObject methods
 
11
 */
 
12
 
 
13
static char solve_doc[] = "x = self.solve(b, trans)\n\
 
14
\n\
 
15
solves linear system of equations with one or sereral right hand sides.\n\
 
16
\n\
 
17
parameters\n\
 
18
----------\n\
 
19
\n\
 
20
b        array, right hand side(s) of equation\n\
 
21
x        array, solution vector(s)\n\
 
22
trans    'N': solve A   * x == b\n\
 
23
         'T': solve A^T * x == b\n\
 
24
         'H': solve A^H * x == b (not yet implemented)\n\
 
25
         (optional, default value 'N')\n\
 
26
";
 
27
 
 
28
static PyObject *
 
29
SciPyLU_solve(SciPyLUObject *self, PyObject *args, PyObject *kwds) {
 
30
  PyArrayObject *b, *x=NULL;
 
31
  SuperMatrix B;
 
32
  char itrans = 'N';
 
33
  int info;
 
34
  trans_t trans;
 
35
  SuperLUStat_t stat;
 
36
 
 
37
  static char *kwlist[] = {"rhs","trans",NULL};
 
38
 
 
39
  if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|c", kwlist,
 
40
                                   &PyArray_Type, &b, 
 
41
                                   &itrans))
 
42
    return NULL;
 
43
 
 
44
  /* solve transposed system: matrix was passed row-wise instead of column-wise */
 
45
  if (itrans == 'n' || itrans == 'N')
 
46
      trans = NOTRANS;
 
47
  else if (itrans == 't' || itrans == 'T')
 
48
      trans = TRANS;
 
49
  else if (itrans == 'h' || itrans == 'H')
 
50
      trans = CONJ;
 
51
  else {
 
52
    PyErr_SetString(PyExc_ValueError, "trans must be N, T, or H");
 
53
    return NULL;
 
54
  }
 
55
 
 
56
  if ((x = (PyArrayObject *) \
 
57
       PyArray_CopyFromObject((PyObject *)b,self->type,1,2))==NULL) return NULL;
 
58
 
 
59
  if (b->dimensions[0] != self->n) goto fail;
 
60
 
 
61
 
 
62
  if (setjmp(_superlu_py_jmpbuf)) goto fail; 
 
63
 
 
64
  if (DenseSuper_from_Numeric(&B, (PyObject *)x)) goto fail;
 
65
 
 
66
  StatInit(&stat);
 
67
 
 
68
  /* Solve the system, overwriting vector x. */
 
69
  switch(self->type) {
 
70
  case PyArray_FLOAT:
 
71
      sgstrs(trans, &self->L, &self->U, self->perm_c, self->perm_r, &B, &stat, &info);      
 
72
      break;
 
73
  case PyArray_DOUBLE:
 
74
      dgstrs(trans, &self->L, &self->U, self->perm_c, self->perm_r, &B, &stat, &info);      
 
75
      break;
 
76
  case PyArray_CFLOAT:
 
77
      cgstrs(trans, &self->L, &self->U, self->perm_c, self->perm_r, &B, &stat, &info);      
 
78
      break;
 
79
  case PyArray_CDOUBLE:
 
80
      zgstrs(trans, &self->L, &self->U, self->perm_c, self->perm_r, &B, &stat, &info); 
 
81
      break;
 
82
  default:
 
83
      PyErr_SetString(PyExc_TypeError, "Invalid type for array.");
 
84
      goto fail;
 
85
  }
 
86
 
 
87
  if (info) { 
 
88
      PyErr_SetString(PyExc_SystemError, "gstrs was called with invalid arguments");
 
89
      goto fail;
 
90
  }
 
91
  
 
92
  /* free memory */
 
93
  Destroy_SuperMatrix_Store(&B);
 
94
  StatFree(&stat);
 
95
  return (PyObject *)x;
 
96
 
 
97
 fail:
 
98
  Destroy_SuperMatrix_Store(&B);  
 
99
  StatFree(&stat);
 
100
  Py_XDECREF(x);
 
101
  return NULL;
 
102
}
 
103
 
 
104
/** table of object methods
 
105
 */
 
106
PyMethodDef SciPyLU_methods[] = {
 
107
  {"solve", (PyCFunction)SciPyLU_solve, METH_VARARGS|METH_KEYWORDS, solve_doc},
 
108
  {NULL, NULL}                  /* sentinel */
 
109
};
 
110
 
 
111
 
 
112
/*********************************************************************** 
 
113
 * SciPySuperLUType methods
 
114
 */
 
115
 
 
116
static void
 
117
SciPyLU_dealloc(SciPyLUObject *self)
 
118
{
 
119
  SUPERLU_FREE(self->perm_r);
 
120
  SUPERLU_FREE(self->perm_c);
 
121
  Destroy_SuperNode_Matrix(&self->L);
 
122
  Destroy_CompCol_Matrix(&self->U);
 
123
  PyObject_Del(self);
 
124
}
 
125
 
 
126
static PyObject *
 
127
SciPyLU_getattr(SciPyLUObject *self, char *name)
 
128
{
 
129
  if (strcmp(name, "shape") == 0)
 
130
    return Py_BuildValue("(i,i)", self->m, self->n);
 
131
  if (strcmp(name, "nnz") == 0)
 
132
    return Py_BuildValue("i", ((SCformat *)self->L.Store)->nnz + ((SCformat *)self->U.Store)->nnz);
 
133
  if (strcmp(name, "__members__") == 0) {
 
134
    char *members[] = {"shape", "nnz"};
 
135
    int i;
 
136
 
 
137
    PyObject *list = PyList_New(sizeof(members)/sizeof(char *));
 
138
    if (list != NULL) {
 
139
      for (i = 0; i < sizeof(members)/sizeof(char *); i ++)
 
140
        PyList_SetItem(list, i, PyString_FromString(members[i]));
 
141
      if (PyErr_Occurred()) {
 
142
        Py_DECREF(list);
 
143
        list = NULL;
 
144
      }
 
145
    }
 
146
    return list;
 
147
  }
 
148
  return Py_FindMethod(SciPyLU_methods, (PyObject *)self, name);
 
149
}
 
150
 
 
151
 
 
152
/***********************************************************************
 
153
 * SciPySuperLUType structure
 
154
 */
 
155
 
 
156
PyTypeObject SciPySuperLUType = {
 
157
  PyObject_HEAD_INIT(NULL)
 
158
  0,
 
159
  "factored_lu",
 
160
  sizeof(SciPyLUObject),
 
161
  0,
 
162
  (destructor)SciPyLU_dealloc,   /* tp_dealloc */
 
163
  0,                            /* tp_print */
 
164
  (getattrfunc)SciPyLU_getattr,  /* tp_getattr */
 
165
  0,                            /* tp_setattr */
 
166
  0,                            /* tp_compare */
 
167
  0,                            /* tp_repr */
 
168
  0,                            /* tp_as_number*/
 
169
  0,                            /* tp_as_sequence*/
 
170
  0,                            /* tp_as_mapping*/
 
171
  0,                            /* tp_hash */
 
172
};
 
173
 
 
174
 
 
175
int DenseSuper_from_Numeric(SuperMatrix *X, PyObject *PyX)
 
176
{
 
177
  int m, n, ldx, nd;
 
178
  PyArrayObject *aX;
 
179
  
 
180
  if (!PyArray_Check(PyX)) {
 
181
    PyErr_SetString(PyExc_TypeError, "dgssv: Second argument is not an array.");
 
182
    return -1;
 
183
  }
 
184
 
 
185
  aX = (PyArrayObject *)PyX;
 
186
  nd = aX->nd;
 
187
 
 
188
  if (nd == 1) {
 
189
    m = aX->dimensions[0];
 
190
    n = 1;
 
191
    ldx = m;
 
192
  }
 
193
  else {  /* nd == 2 */
 
194
    m = aX->dimensions[1];
 
195
    n = aX->dimensions[0];
 
196
    ldx = m;
 
197
  }
 
198
  
 
199
  if (setjmp(_superlu_py_jmpbuf)) return -1;
 
200
  else 
 
201
      switch (aX->descr->type_num) {
 
202
      case PyArray_FLOAT:
 
203
          sCreate_Dense_Matrix(X, m, n, (float *)aX->data, ldx, SLU_DN, SLU_S, SLU_GE);
 
204
          break;
 
205
      case PyArray_DOUBLE:
 
206
          dCreate_Dense_Matrix(X, m, n, (double *)aX->data, ldx, SLU_DN, SLU_D, SLU_GE);
 
207
          break;
 
208
      case PyArray_CFLOAT:
 
209
          cCreate_Dense_Matrix(X, m, n, (complex *)aX->data, ldx, SLU_DN, SLU_C, SLU_GE);
 
210
          break;
 
211
      case PyArray_CDOUBLE:
 
212
          zCreate_Dense_Matrix(X, m, n, (doublecomplex *)aX->data, ldx, SLU_DN, SLU_Z, SLU_GE);
 
213
          break;
 
214
      default:
 
215
          PyErr_SetString(PyExc_TypeError, "Invalid type for Numeric array.");
 
216
          return -1;  
 
217
      }
 
218
  
 
219
  return 0;
 
220
}
 
221
 
 
222
/* Natively handles Compressed Sparse Row and CSC */
 
223
 
 
224
int NRFormat_from_spMatrix(SuperMatrix *A, int m, int n, int nnz, PyArrayObject *nzvals, PyArrayObject *colind, PyArrayObject *rowptr, int typenum)
 
225
{
 
226
  int err = 0;
 
227
    
 
228
  err = (nzvals->descr->type_num != typenum);
 
229
  err += (nzvals->nd != 1);
 
230
  err += (nnz > nzvals->dimensions[0]);
 
231
  if (err) {
 
232
    PyErr_SetString(PyExc_TypeError, "Fourth argument must be a 1-D array at least as big as third argument.");
 
233
    return -1;
 
234
  }
 
235
 
 
236
  if (setjmp(_superlu_py_jmpbuf)) return -1;
 
237
  else 
 
238
      switch (nzvals->descr->type_num) {
 
239
      case PyArray_FLOAT:
 
240
          sCreate_CompRow_Matrix(A, m, n, nnz, (float *)nzvals->data, (int *)colind->data, \
 
241
                                 (int *)rowptr->data, SLU_NR, SLU_S, SLU_GE);
 
242
          break;
 
243
      case PyArray_DOUBLE:
 
244
          dCreate_CompRow_Matrix(A, m, n, nnz, (double *)nzvals->data, (int *)colind->data, \
 
245
                                 (int *)rowptr->data, SLU_NR, SLU_D, SLU_GE);
 
246
          break;
 
247
      case PyArray_CFLOAT:
 
248
          cCreate_CompRow_Matrix(A, m, n, nnz, (complex *)nzvals->data, (int *)colind->data, \
 
249
                                 (int *)rowptr->data, SLU_NR, SLU_C, SLU_GE);
 
250
          break;
 
251
      case PyArray_CDOUBLE:
 
252
          zCreate_CompRow_Matrix(A, m, n, nnz, (doublecomplex *)nzvals->data, (int *)colind->data, \
 
253
                                 (int *)rowptr->data, SLU_NR, SLU_Z, SLU_GE);
 
254
          break;
 
255
      default:
 
256
          PyErr_SetString(PyExc_TypeError, "Invalid type for array.");
 
257
          return -1;  
 
258
      }
 
259
 
 
260
  return 0;
 
261
}
 
262
 
 
263
int NCFormat_from_spMatrix(SuperMatrix *A, int m, int n, int nnz, PyArrayObject *nzvals, PyArrayObject *rowind, PyArrayObject *colptr, int typenum)
 
264
{
 
265
  int err=0;
 
266
 
 
267
  err = (nzvals->descr->type_num != typenum);
 
268
  err += (nzvals->nd != 1);
 
269
  err += (nnz > nzvals->dimensions[0]);
 
270
  if (err) {
 
271
    PyErr_SetString(PyExc_TypeError, "Fifth argument must be a 1-D array at least as big as fourth argument.");
 
272
    return -1;
 
273
  }
 
274
 
 
275
 
 
276
  if (setjmp(_superlu_py_jmpbuf)) return -1;
 
277
  else 
 
278
      switch (nzvals->descr->type_num) {
 
279
      case PyArray_FLOAT:
 
280
          sCreate_CompCol_Matrix(A, m, n, nnz, (float *)nzvals->data, (int *)rowind->data, \
 
281
                                 (int *)colptr->data, SLU_NC, SLU_S, SLU_GE);
 
282
          break;
 
283
      case PyArray_DOUBLE:
 
284
          dCreate_CompCol_Matrix(A, m, n, nnz, (double *)nzvals->data, (int *)rowind->data, \
 
285
                                 (int *)colptr->data, SLU_NC, SLU_D, SLU_GE);
 
286
          break;
 
287
      case PyArray_CFLOAT:
 
288
          cCreate_CompCol_Matrix(A, m, n, nnz, (complex *)nzvals->data, (int *)rowind->data, \
 
289
                                 (int *)colptr->data, SLU_NC, SLU_C, SLU_GE);
 
290
          break;
 
291
      case PyArray_CDOUBLE:
 
292
          zCreate_CompCol_Matrix(A, m, n, nnz, (doublecomplex *)nzvals->data, (int *)rowind->data, \
 
293
                                 (int *)colptr->data, SLU_NC, SLU_Z, SLU_GE);
 
294
          break;
 
295
      default:
 
296
          PyErr_SetString(PyExc_TypeError, "Invalid type for array.");
 
297
          return -1;  
 
298
      }
 
299
 
 
300
  return 0;
 
301
}
 
302
 
 
303
colperm_t superlu_module_getpermc(int permc_spec)
 
304
{
 
305
  switch(permc_spec) {
 
306
  case 0:
 
307
    return NATURAL;
 
308
  case 1:
 
309
    return MMD_ATA;
 
310
  case 2:
 
311
    return MMD_AT_PLUS_A;
 
312
  case 3:
 
313
    return COLAMD;
 
314
  }
 
315
  ABORT("Invalid input for permc_spec.");
 
316
  return NATURAL; /* compiler complains... */
 
317
}
 
318
 
 
319
PyObject *
 
320
newSciPyLUObject(SuperMatrix *A, double diag_pivot_thresh,
 
321
                 double drop_tol, int relax, int panel_size, int permc_spec,
 
322
                 int intype)
 
323
{
 
324
 
 
325
   /* A must be in SLU_NC format used by the factorization routine. */
 
326
  SciPyLUObject *self;
 
327
  SuperMatrix AC;     /* Matrix postmultiplied by Pc */
 
328
  int lwork = 0;
 
329
  int *etree=NULL;
 
330
  int info;
 
331
  int n;
 
332
  superlu_options_t options;
 
333
  SuperLUStat_t stat;
 
334
  
 
335
  n = A->ncol;
 
336
 
 
337
  /* Create SciPyLUObject */
 
338
  self = PyObject_New(SciPyLUObject, &SciPySuperLUType);
 
339
  if (self == NULL)
 
340
    return PyErr_NoMemory();
 
341
  self->m = A->nrow;
 
342
  self->n = n;
 
343
  self->perm_r = NULL;
 
344
  self->perm_c = NULL;
 
345
  self->type = intype;
 
346
 
 
347
  if (setjmp(_superlu_py_jmpbuf)) goto fail;
 
348
  
 
349
  /* Calculate and apply minimum degree ordering*/
 
350
  etree = intMalloc(n);
 
351
  self->perm_r = intMalloc(n);
 
352
  self->perm_c = intMalloc(n);
 
353
  
 
354
  set_default_options(&options);
 
355
  options.ColPerm=superlu_module_getpermc(permc_spec);
 
356
  options.DiagPivotThresh = diag_pivot_thresh;
 
357
  StatInit(&stat);
 
358
  
 
359
  get_perm_c(permc_spec, A, self->perm_c); /* calc column permutation */
 
360
  sp_preorder(&options, A, self->perm_c, etree, &AC); /* apply column permutation */
 
361
  
 
362
 
 
363
  /* Perform factorization */
 
364
  switch (A->Dtype) {
 
365
  case SLU_S:
 
366
      sgstrf(&options, &AC, (float) drop_tol, relax, panel_size,
 
367
             etree, NULL, lwork, self->perm_c, self->perm_r,
 
368
             &self->L, &self->U, &stat, &info);
 
369
      break;
 
370
  case SLU_D:
 
371
      dgstrf(&options, &AC, drop_tol, relax, panel_size,
 
372
             etree, NULL, lwork, self->perm_c, self->perm_r,
 
373
             &self->L, &self->U, &stat, &info);
 
374
      break;
 
375
  case SLU_C:          
 
376
      cgstrf(&options, &AC, (float) drop_tol, relax, panel_size,
 
377
             etree, NULL, lwork, self->perm_c, self->perm_r,
 
378
             &self->L, &self->U, &stat, &info);
 
379
      break;
 
380
  case SLU_Z:          
 
381
      zgstrf(&options, &AC, drop_tol, relax, panel_size,
 
382
             etree, NULL, lwork, self->perm_c, self->perm_r,
 
383
             &self->L, &self->U, &stat, &info);
 
384
      break;
 
385
  default:
 
386
      PyErr_SetString(PyExc_ValueError, "Invalid type in SuperMatrix.");
 
387
      goto fail;
 
388
  }
 
389
  
 
390
  if (info) {
 
391
    if (info < 0)
 
392
        PyErr_SetString(PyExc_SystemError, "dgstrf was called with invalid arguments");
 
393
    else {
 
394
        if (info <= n) 
 
395
            PyErr_SetString(PyExc_RuntimeError, "Factor is exactly singular");
 
396
        else
 
397
            PyErr_NoMemory();
 
398
    }
 
399
    goto fail;
 
400
  }
 
401
  
 
402
  /* free memory */
 
403
  SUPERLU_FREE(etree);
 
404
  Destroy_CompCol_Permuted(&AC);
 
405
  StatFree(&stat);
 
406
  
 
407
  return (PyObject *)self;
 
408
 
 
409
 fail:
 
410
  SUPERLU_FREE(etree);
 
411
  Destroy_CompCol_Permuted(&AC);
 
412
  StatFree(&stat);
 
413
  SciPyLU_dealloc(self);
 
414
  return NULL;
 
415
}