~ubuntu-branches/ubuntu/wily/numexpr/wily-proposed

« back to all changes in this revision

Viewing changes to numexpr/module.cpp

  • Committer: Package Import Robot
  • Author(s): Antonio Valentino
  • Date: 2013-09-28 09:03:27 UTC
  • mfrom: (7.1.7 sid)
  • Revision ID: package-import@ubuntu.com-20130928090327-s69mvg0n2xnz6cn8
New upstream release (fixes a build failure on s390)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
// Numexpr - Fast numerical array expression evaluator for NumPy.
 
2
//
 
3
//      License: MIT
 
4
//      Author:  See AUTHORS.txt
 
5
//
 
6
//  See LICENSE.txt for details about copyright and rights to use.
 
7
//
 
8
// module.cpp contains the CPython-specific module exposure.
 
9
 
 
10
#define DO_NUMPY_IMPORT_ARRAY
 
11
 
 
12
#include "module.hpp"
 
13
#include <structmember.h>
 
14
#include <vector>
 
15
 
 
16
#include "interpreter.hpp"
 
17
#include "numexpr_object.hpp"
 
18
 
 
19
using namespace std;
 
20
 
 
21
// Global state. The file interpreter.hpp also has some global state
 
22
// in its 'th_params' variable
 
23
global_state gs;
 
24
 
 
25
 
 
26
/* Do the worker job for a certain thread */
 
27
void *th_worker(void *tidptr)
 
28
{
 
29
    int tid = *(int *)tidptr;
 
30
    /* Parameters for threads */
 
31
    npy_intp start;
 
32
    npy_intp vlen;
 
33
    npy_intp block_size;
 
34
    NpyIter *iter;
 
35
    vm_params params;
 
36
    int *pc_error;
 
37
    int ret;
 
38
    int n_inputs;
 
39
    int n_constants;
 
40
    int n_temps;
 
41
    size_t memsize;
 
42
    char **mem;
 
43
    npy_intp *memsteps;
 
44
    npy_intp istart, iend;
 
45
    char **errmsg;
 
46
    // For output buffering if needed
 
47
    vector<char> out_buffer;
 
48
 
 
49
    while (1) {
 
50
 
 
51
        gs.init_sentinels_done = 0;     /* sentinels have to be initialised yet */
 
52
 
 
53
        /* Meeting point for all threads (wait for initialization) */
 
54
        pthread_mutex_lock(&gs.count_threads_mutex);
 
55
        if (gs.count_threads < gs.nthreads) {
 
56
            gs.count_threads++;
 
57
            pthread_cond_wait(&gs.count_threads_cv, &gs.count_threads_mutex);
 
58
        }
 
59
        else {
 
60
            pthread_cond_broadcast(&gs.count_threads_cv);
 
61
        }
 
62
        pthread_mutex_unlock(&gs.count_threads_mutex);
 
63
 
 
64
        /* Check if thread has been asked to return */
 
65
        if (gs.end_threads) {
 
66
            return(0);
 
67
        }
 
68
 
 
69
        /* Get parameters for this thread before entering the main loop */
 
70
        start = th_params.start;
 
71
        vlen = th_params.vlen;
 
72
        block_size = th_params.block_size;
 
73
        params = th_params.params;
 
74
        pc_error = th_params.pc_error;
 
75
 
 
76
        // If output buffering is needed, allocate it
 
77
        if (th_params.need_output_buffering) {
 
78
            out_buffer.resize(params.memsizes[0] * BLOCK_SIZE1);
 
79
            params.out_buffer = &out_buffer[0];
 
80
        } else {
 
81
            params.out_buffer = NULL;
 
82
        }
 
83
 
 
84
        /* Populate private data for each thread */
 
85
        n_inputs = params.n_inputs;
 
86
        n_constants = params.n_constants;
 
87
        n_temps = params.n_temps;
 
88
        memsize = (1+n_inputs+n_constants+n_temps) * sizeof(char *);
 
89
        /* XXX malloc seems thread safe for POSIX, but for Win? */
 
90
        mem = (char **)malloc(memsize);
 
91
        memcpy(mem, params.mem, memsize);
 
92
 
 
93
        errmsg = th_params.errmsg;
 
94
 
 
95
        params.mem = mem;
 
96
 
 
97
        /* Loop over blocks */
 
98
        pthread_mutex_lock(&gs.count_mutex);
 
99
        if (!gs.init_sentinels_done) {
 
100
            /* Set sentinels and other global variables */
 
101
            gs.gindex = start;
 
102
            istart = gs.gindex;
 
103
            iend = istart + block_size;
 
104
            if (iend > vlen) {
 
105
                iend = vlen;
 
106
            }
 
107
            gs.init_sentinels_done = 1;  /* sentinels have been initialised */
 
108
            gs.giveup = 0;            /* no giveup initially */
 
109
        } else {
 
110
            gs.gindex += block_size;
 
111
            istart = gs.gindex;
 
112
            iend = istart + block_size;
 
113
            if (iend > vlen) {
 
114
                iend = vlen;
 
115
            }
 
116
        }
 
117
        /* Grab one of the iterators */
 
118
        iter = th_params.iter[tid];
 
119
        if (iter == NULL) {
 
120
            th_params.ret_code = -1;
 
121
            gs.giveup = 1;
 
122
        }
 
123
        memsteps = th_params.memsteps[tid];
 
124
        /* Get temporary space for each thread */
 
125
        ret = get_temps_space(params, mem, BLOCK_SIZE1);
 
126
        if (ret < 0) {
 
127
            /* Propagate error to main thread */
 
128
            th_params.ret_code = ret;
 
129
            gs.giveup = 1;
 
130
        }
 
131
        pthread_mutex_unlock(&gs.count_mutex);
 
132
 
 
133
        while (istart < vlen && !gs.giveup) {
 
134
            /* Reset the iterator to the range for this task */
 
135
            ret = NpyIter_ResetToIterIndexRange(iter, istart, iend,
 
136
                                                errmsg);
 
137
            /* Execute the task */
 
138
            if (ret >= 0) {
 
139
                ret = vm_engine_iter_task(iter, memsteps, params, pc_error, errmsg);
 
140
            }
 
141
 
 
142
            if (ret < 0) {
 
143
                pthread_mutex_lock(&gs.count_mutex);
 
144
                gs.giveup = 1;
 
145
                /* Propagate error to main thread */
 
146
                th_params.ret_code = ret;
 
147
                pthread_mutex_unlock(&gs.count_mutex);
 
148
                break;
 
149
            }
 
150
 
 
151
            pthread_mutex_lock(&gs.count_mutex);
 
152
            gs.gindex += block_size;
 
153
            istart = gs.gindex;
 
154
            iend = istart + block_size;
 
155
            if (iend > vlen) {
 
156
                iend = vlen;
 
157
            }
 
158
            pthread_mutex_unlock(&gs.count_mutex);
 
159
        }
 
160
 
 
161
        /* Meeting point for all threads (wait for finalization) */
 
162
        pthread_mutex_lock(&gs.count_threads_mutex);
 
163
        if (gs.count_threads > 0) {
 
164
            gs.count_threads--;
 
165
            pthread_cond_wait(&gs.count_threads_cv, &gs.count_threads_mutex);
 
166
        }
 
167
        else {
 
168
            pthread_cond_broadcast(&gs.count_threads_cv);
 
169
        }
 
170
        pthread_mutex_unlock(&gs.count_threads_mutex);
 
171
 
 
172
        /* Release resources */
 
173
        free_temps_space(params, mem);
 
174
        free(mem);
 
175
 
 
176
    }  /* closes while(1) */
 
177
 
 
178
    /* This should never be reached, but anyway */
 
179
    return(0);
 
180
}
 
181
 
 
182
/* Initialize threads */
 
183
int init_threads(void)
 
184
{
 
185
    int tid, rc;
 
186
 
 
187
    /* Initialize mutex and condition variable objects */
 
188
    pthread_mutex_init(&gs.count_mutex, NULL);
 
189
 
 
190
    /* Barrier initialization */
 
191
    pthread_mutex_init(&gs.count_threads_mutex, NULL);
 
192
    pthread_cond_init(&gs.count_threads_cv, NULL);
 
193
    gs.count_threads = 0;      /* Reset threads counter */
 
194
 
 
195
    /* Finally, create the threads */
 
196
    for (tid = 0; tid < gs.nthreads; tid++) {
 
197
        gs.tids[tid] = tid;
 
198
        rc = pthread_create(&gs.threads[tid], NULL, th_worker,
 
199
                            (void *)&gs.tids[tid]);
 
200
        if (rc) {
 
201
            fprintf(stderr,
 
202
                    "ERROR; return code from pthread_create() is %d\n", rc);
 
203
            fprintf(stderr, "\tError detail: %s\n", strerror(rc));
 
204
            exit(-1);
 
205
        }
 
206
    }
 
207
 
 
208
    gs.init_threads_done = 1;                 /* Initialization done! */
 
209
    gs.pid = (int)getpid();                   /* save the PID for this process */
 
210
 
 
211
    return(0);
 
212
}
 
213
 
 
214
/* Set the number of threads in numexpr's VM */
 
215
int numexpr_set_nthreads(int nthreads_new)
 
216
{
 
217
    int nthreads_old = gs.nthreads;
 
218
    int t, rc;
 
219
    void *status;
 
220
 
 
221
    if (nthreads_new > MAX_THREADS) {
 
222
        fprintf(stderr,
 
223
                "Error.  nthreads cannot be larger than MAX_THREADS (%d)",
 
224
                MAX_THREADS);
 
225
        return -1;
 
226
    }
 
227
    else if (nthreads_new <= 0) {
 
228
        fprintf(stderr, "Error.  nthreads must be a positive integer");
 
229
        return -1;
 
230
    }
 
231
 
 
232
    /* Only join threads if they are not initialized or if our PID is
 
233
       different from that in pid var (probably means that we are a
 
234
       subprocess, and thus threads are non-existent). */
 
235
    if (gs.nthreads > 1 && gs.init_threads_done && gs.pid == getpid()) {
 
236
        /* Tell all existing threads to finish */
 
237
        gs.end_threads = 1;
 
238
        pthread_mutex_lock(&gs.count_threads_mutex);
 
239
        if (gs.count_threads < gs.nthreads) {
 
240
            gs.count_threads++;
 
241
            pthread_cond_wait(&gs.count_threads_cv, &gs.count_threads_mutex);
 
242
        }
 
243
        else {
 
244
            pthread_cond_broadcast(&gs.count_threads_cv);
 
245
        }
 
246
        pthread_mutex_unlock(&gs.count_threads_mutex);
 
247
 
 
248
        /* Join exiting threads */
 
249
        for (t=0; t<gs.nthreads; t++) {
 
250
            rc = pthread_join(gs.threads[t], &status);
 
251
            if (rc) {
 
252
                fprintf(stderr,
 
253
                        "ERROR; return code from pthread_join() is %d\n",
 
254
                        rc);
 
255
                fprintf(stderr, "\tError detail: %s\n", strerror(rc));
 
256
                exit(-1);
 
257
            }
 
258
        }
 
259
        gs.init_threads_done = 0;
 
260
        gs.end_threads = 0;
 
261
    }
 
262
 
 
263
    /* Launch a new pool of threads (if necessary) */
 
264
    gs.nthreads = nthreads_new;
 
265
    if (gs.nthreads > 1 && (!gs.init_threads_done || gs.pid != getpid())) {
 
266
        init_threads();
 
267
    }
 
268
 
 
269
    return nthreads_old;
 
270
}
 
271
 
 
272
 
 
273
#ifdef USE_VML
 
274
 
 
275
static PyObject *
 
276
_get_vml_version(PyObject *self, PyObject *args)
 
277
{
 
278
    int len=198;
 
279
    char buf[198];
 
280
    MKL_Get_Version_String(buf, len);
 
281
    return Py_BuildValue("s", buf);
 
282
}
 
283
 
 
284
static PyObject *
 
285
_set_vml_accuracy_mode(PyObject *self, PyObject *args)
 
286
{
 
287
    int mode_in, mode_old;
 
288
    if (!PyArg_ParseTuple(args, "i", &mode_in))
 
289
    return NULL;
 
290
    mode_old = vmlGetMode() & VML_ACCURACY_MASK;
 
291
    vmlSetMode((mode_in & VML_ACCURACY_MASK) | VML_ERRMODE_IGNORE );
 
292
    return Py_BuildValue("i", mode_old);
 
293
}
 
294
 
 
295
static PyObject *
 
296
_set_vml_num_threads(PyObject *self, PyObject *args)
 
297
{
 
298
    int max_num_threads;
 
299
    if (!PyArg_ParseTuple(args, "i", &max_num_threads))
 
300
    return NULL;
 
301
    mkl_domain_set_num_threads(max_num_threads, MKL_VML);
 
302
    Py_RETURN_NONE;
 
303
}
 
304
 
 
305
#endif
 
306
 
 
307
static PyObject *
 
308
_set_num_threads(PyObject *self, PyObject *args)
 
309
{
 
310
    int num_threads, nthreads_old;
 
311
    if (!PyArg_ParseTuple(args, "i", &num_threads))
 
312
    return NULL;
 
313
    nthreads_old = numexpr_set_nthreads(num_threads);
 
314
    return Py_BuildValue("i", nthreads_old);
 
315
}
 
316
 
 
317
static PyMethodDef module_methods[] = {
 
318
#ifdef USE_VML
 
319
    {"_get_vml_version", _get_vml_version, METH_VARARGS,
 
320
     "Get the VML/MKL library version."},
 
321
    {"_set_vml_accuracy_mode", _set_vml_accuracy_mode, METH_VARARGS,
 
322
     "Set accuracy mode for VML functions."},
 
323
    {"_set_vml_num_threads", _set_vml_num_threads, METH_VARARGS,
 
324
     "Suggests a maximum number of threads to be used in VML operations."},
 
325
#endif
 
326
    {"_set_num_threads", _set_num_threads, METH_VARARGS,
 
327
     "Suggests a maximum number of threads to be used in operations."},
 
328
    {NULL}
 
329
};
 
330
 
 
331
static int
 
332
add_symbol(PyObject *d, const char *sname, int name, const char* routine_name)
 
333
{
 
334
    PyObject *o, *s;
 
335
    int r;
 
336
 
 
337
    if (!sname) {
 
338
        return 0;
 
339
    }
 
340
 
 
341
    o = PyLong_FromLong(name);
 
342
    s = PyBytes_FromString(sname);
 
343
    if (!s) {
 
344
        PyErr_SetString(PyExc_RuntimeError, routine_name);
 
345
        return -1;
 
346
    }
 
347
    r = PyDict_SetItem(d, s, o);
 
348
    Py_XDECREF(o);
 
349
    return r;
 
350
}
 
351
 
 
352
#ifdef __cplusplus
 
353
extern "C" {
 
354
#endif
 
355
 
 
356
#if PY_MAJOR_VERSION >= 3
 
357
 
 
358
/* XXX: handle the "global_state" state via moduedef */
 
359
static struct PyModuleDef moduledef = {
 
360
        PyModuleDef_HEAD_INIT,
 
361
        "interpreter",
 
362
        NULL,
 
363
        -1,                 /* sizeof(struct global_state), */
 
364
        module_methods,
 
365
        NULL,
 
366
        NULL,               /* module_traverse, */
 
367
        NULL,               /* module_clear, */
 
368
        NULL
 
369
};
 
370
 
 
371
#define INITERROR return NULL
 
372
 
 
373
PyObject *
 
374
PyInit_interpreter(void)
 
375
 
 
376
#else
 
377
#define INITERROR return
 
378
 
 
379
PyMODINIT_FUNC
 
380
initinterpreter()
 
381
#endif
 
382
{
 
383
    PyObject *m, *d;
 
384
 
 
385
    if (PyType_Ready(&NumExprType) < 0)
 
386
        INITERROR;
 
387
 
 
388
#if PY_MAJOR_VERSION >= 3
 
389
    m = PyModule_Create(&moduledef);
 
390
#else
 
391
    m = Py_InitModule3("interpreter", module_methods, NULL);
 
392
#endif
 
393
 
 
394
    if (m == NULL)
 
395
        INITERROR;
 
396
 
 
397
    Py_INCREF(&NumExprType);
 
398
    PyModule_AddObject(m, "NumExpr", (PyObject *)&NumExprType);
 
399
 
 
400
    import_array();
 
401
 
 
402
    d = PyDict_New();
 
403
    if (!d) INITERROR;
 
404
 
 
405
#define OPCODE(n, name, sname, ...)                              \
 
406
    if (add_symbol(d, sname, name, "add_op") < 0) { INITERROR; }
 
407
#include "opcodes.hpp"
 
408
#undef OPCODE
 
409
 
 
410
    if (PyModule_AddObject(m, "opcodes", d) < 0) INITERROR;
 
411
 
 
412
    d = PyDict_New();
 
413
    if (!d) INITERROR;
 
414
 
 
415
#define add_func(name, sname)                           \
 
416
    if (add_symbol(d, sname, name, "add_func") < 0) { INITERROR; }
 
417
#define FUNC_FF(name, sname, ...)  add_func(name, sname);
 
418
#define FUNC_FFF(name, sname, ...) add_func(name, sname);
 
419
#define FUNC_DD(name, sname, ...)  add_func(name, sname);
 
420
#define FUNC_DDD(name, sname, ...) add_func(name, sname);
 
421
#define FUNC_CC(name, sname, ...)  add_func(name, sname);
 
422
#define FUNC_CCC(name, sname, ...) add_func(name, sname);
 
423
#include "functions.hpp"
 
424
#undef FUNC_CCC
 
425
#undef FUNC_CC
 
426
#undef FUNC_DDD
 
427
#undef FUNC_DD
 
428
#undef FUNC_DD
 
429
#undef FUNC_FFF
 
430
#undef FUNC_FF
 
431
#undef add_func
 
432
 
 
433
    if (PyModule_AddObject(m, "funccodes", d) < 0) INITERROR;
 
434
 
 
435
    if (PyModule_AddObject(m, "allaxes", PyLong_FromLong(255)) < 0) INITERROR;
 
436
    if (PyModule_AddObject(m, "maxdims", PyLong_FromLong(NPY_MAXDIMS)) < 0) INITERROR;
 
437
 
 
438
#if PY_MAJOR_VERSION >= 3
 
439
    return m;
 
440
#endif
 
441
}
 
442
 
 
443
#ifdef __cplusplus
 
444
}  // extern "C"
 
445
#endif