~vcs-imports/python-mysqldb/old-svn-trunk

« back to all changes in this revision

Viewing changes to src/connections.c

  • Committer: adustman
  • Date: 2009-03-30 20:21:24 UTC
  • Revision ID: vcs-imports@canonical.com-20090330202124-j3ehf98sy2zl06ih
Reimplement MySQL->Python type conversion in C; much simpler and easier to deal with now. Hey, all my tests pass, so I guess that means I need to write some more tests.

Show diffs side-by-side

added added

removed removed

Lines of Context:
9
9
        PyObject *kwargs)
10
10
{
11
11
        MYSQL *conn = NULL;
12
 
        PyObject *conv = NULL;
 
12
        PyObject *decoder_stack = NULL;
13
13
        PyObject *ssl = NULL;
14
14
#if HAVE_OPENSSL
15
15
        char *key = NULL, *cert = NULL, *ca = NULL,
20
20
        unsigned int port = 0;
21
21
        unsigned int client_flag = 0;
22
22
        static char *kwlist[] = { "host", "user", "passwd", "db", "port",
23
 
                                  "unix_socket", "conv",
 
23
                                  "unix_socket", "decoder_stack",
24
24
                                  "connect_timeout", "compress",
25
25
                                  "named_pipe", "init_command",
26
26
                                  "read_default_file", "read_default_group",
33
33
             *read_default_file=NULL,
34
34
             *read_default_group=NULL;
35
35
        
36
 
        self->converter = NULL;
 
36
        self->decoder_stack = NULL;
37
37
        self->open = 0;
38
38
        check_server_init(-1);
39
39
        if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssisOiiisssiOi:connect",
40
40
                                         kwlist,
41
41
                                         &host, &user, &passwd, &db,
42
 
                                         &port, &unix_socket, &conv,
 
42
                                         &port, &unix_socket, &decoder_stack,
43
43
                                         &connect_timeout,
44
44
                                         &compress, &named_pipe,
45
45
                                         &init_command, &read_default_file,
49
49
                                         ))
50
50
                return -1;
51
51
 
52
 
        /* Keep the converter mapping or a blank mapping dict */
53
 
        if (!conv)
54
 
                conv = PyDict_New();
 
52
        if (!decoder_stack)
 
53
                decoder_stack = PyList_New(0);
55
54
        else
56
 
                Py_INCREF(conv);
57
 
        if (!conv)
58
 
                return -1;
59
 
        self->converter = conv;
60
 
 
 
55
                Py_INCREF(decoder_stack);
 
56
        self->decoder_stack = decoder_stack;
 
57
        
61
58
#define _stringsuck(d,t,s) {t=PyMapping_GetItemString(s,#d);\
62
59
        if(t){d=PyString_AsString(t);Py_DECREF(t);}\
63
60
        PyErr_Clear();}
148
145
unix_socket\n\
149
146
  string, location of unix_socket (UNIX-ish only)\n\
150
147
\n\
151
 
conv\n\
152
 
  mapping, maps MySQL FIELD_TYPE.* to Python functions which\n\
153
 
  convert a string to the appropriate Python type\n\
154
 
\n\
155
148
connect_timeout\n\
156
149
  number of seconds to wait before the connection\n\
157
150
  attempt fails.\n\
201
194
        visitproc visit,
202
195
        void *arg)
203
196
{
204
 
        if (self->converter)
205
 
                return visit(self->converter, arg);
 
197
        if (self->decoder_stack)
 
198
                return visit(self->decoder_stack, arg);
206
199
        return 0;
207
200
}
208
201
 
209
202
static int _mysql_ConnectionObject_clear(
210
203
        _mysql_ConnectionObject *self)
211
204
{
212
 
        Py_XDECREF(self->converter);
213
 
        self->converter = NULL;
 
205
        Py_XDECREF(self->decoder_stack);
 
206
        self->decoder_stack = NULL;
214
207
        return 0;
215
208
}
216
209
 
218
211
_escape_item(
219
212
        PyObject *item,
220
213
        PyObject *d);
221
 
        
222
 
char _mysql_escape__doc__[] =
223
 
"escape(obj, dict) -- escape any special characters in object obj\n\
224
 
using mapping dict to provide quoting functions for each type.\n\
225
 
Returns a SQL literal string.";
226
 
PyObject *
227
 
_mysql_escape(
228
 
        PyObject *self,
229
 
        PyObject *args)
230
 
{
231
 
        PyObject *o=NULL, *d=NULL;
232
 
        if (!PyArg_ParseTuple(args, "O|O:escape", &o, &d))
233
 
                return NULL;
234
 
        if (d) {
235
 
                if (!PyMapping_Check(d)) {
236
 
                        PyErr_SetString(PyExc_TypeError,
237
 
                                        "argument 2 must be a mapping");
238
 
                        return NULL;
239
 
                }
240
 
                return _escape_item(o, d);
241
 
        } else {
242
 
                if (!self) {
243
 
                        PyErr_SetString(PyExc_TypeError,
244
 
                                        "argument 2 must be a mapping");
245
 
                        return NULL;
246
 
                }
247
 
                return _escape_item(o,
248
 
                           ((_mysql_ConnectionObject *) self)->converter);
249
 
        }
250
 
}
251
214
 
252
215
char _mysql_escape_string__doc__[] =
253
216
"escape_string(s) -- quote any SQL-interpreted characters in string s.\n\
254
 
\n\
255
 
Use connection.escape_string(s), if you use it at all.\n\
256
 
_mysql.escape_string(s) cannot handle character sets. You are\n\
257
 
probably better off using connection.escape(o) instead, since\n\
258
 
it will escape entire sequences as well as strings.";
 
217
If you want quotes around your value, use string_literal(s) instead.\n\
 
218
";
259
219
 
260
220
PyObject *
261
221
_mysql_escape_string(
269
229
        str = PyString_FromStringAndSize((char *) NULL, size*2+1);
270
230
        if (!str) return PyErr_NoMemory();
271
231
        out = PyString_AS_STRING(str);
272
 
#if MYSQL_VERSION_ID < 32321
273
 
        len = mysql_escape_string(out, in, size);
274
 
#else
275
 
        check_server_init(NULL);
276
 
        if (self && self->open)
277
 
                len = mysql_real_escape_string(&(self->connection), out, in, size);
278
 
        else
279
 
                len = mysql_escape_string(out, in, size);
280
 
#endif
 
232
        len = mysql_real_escape_string(&(self->connection), out, in, size);
281
233
        if (_PyString_Resize(&str, len) < 0) return NULL;
282
234
        return (str);
283
235
}
284
236
 
285
237
char _mysql_string_literal__doc__[] =
286
 
"string_literal(obj) -- converts object obj into a SQL string literal.\n\
 
238
"string_literal(s) -- converts string s into a SQL string literal.\n\
287
239
This means, any special SQL characters are escaped, and it is enclosed\n\
288
240
within single quotes. In other words, it performs:\n\
289
241
\n\
290
 
\"'%s'\" % escape_string(str(obj))\n\
291
 
\n\
292
 
Use connection.string_literal(obj), if you use it at all.\n\
293
 
_mysql.string_literal(obj) cannot handle character sets.";
 
242
\"'%s'\" % escape_string(s)\n\
 
243
";
294
244
 
295
245
PyObject *
296
246
_mysql_string_literal(
297
247
        _mysql_ConnectionObject *self,
298
248
        PyObject *args)
299
249
{
300
 
        PyObject *str, *s, *o, *d;
 
250
        PyObject *str;
301
251
        char *in, *out;
302
252
        int len, size;
303
 
        if (!PyArg_ParseTuple(args, "O|O:string_literal", &o, &d)) return NULL;
304
 
        s = PyObject_Str(o);
305
 
        if (!s) return NULL;
306
 
        in = PyString_AsString(s);
307
 
        size = PyString_GET_SIZE(s);
 
253
        if (!PyArg_ParseTuple(args, "s#:string_literal", &in, &size)) return NULL;
308
254
        str = PyString_FromStringAndSize((char *) NULL, size*2+3);
309
255
        if (!str) return PyErr_NoMemory();
310
256
        out = PyString_AS_STRING(str);
311
 
#if MYSQL_VERSION_ID < 32321
312
 
        len = mysql_escape_string(out+1, in, size);
313
 
#else
314
 
        check_server_init(NULL);
315
 
        if (self && self->open)
316
 
                len = mysql_real_escape_string(&(self->connection), out+1, in, size);
317
 
        else
318
 
                len = mysql_escape_string(out+1, in, size);
319
 
#endif
 
257
        len = mysql_real_escape_string(&(self->connection), out+1, in, size);
320
258
        *out = *(out+len+1) = '\'';
321
259
        if (_PyString_Resize(&str, len+2) < 0) return NULL;
322
 
        Py_DECREF(s);
323
260
        return (str);
324
261
}
325
262
 
994
931
        _mysql_ResultObject *r=NULL;
995
932
 
996
933
        check_connection(self);
997
 
        arglist = Py_BuildValue("(OiO)", self, 0, self->converter);
 
934
        arglist = Py_BuildValue("(OiO)", self, 0, self->decoder_stack);
998
935
        if (!arglist) goto error;
999
936
        kwarglist = PyDict_New();
1000
937
        if (!kwarglist) goto error;
1054
991
        _mysql_ResultObject *r=NULL;
1055
992
 
1056
993
        check_connection(self);
1057
 
        arglist = Py_BuildValue("(OiO)", self, 1, self->converter);
 
994
        arglist = Py_BuildValue("(OiO)", self, 1, self->decoder_stack);
1058
995
        if (!arglist) return NULL;
1059
996
        kwarglist = PyDict_New();
1060
997
        if (!kwarglist) goto error;
1197
1134
                _mysql_ConnectionObject_dump_debug_info__doc__
1198
1135
        },
1199
1136
        {
1200
 
                "escape",
1201
 
                (PyCFunction)_mysql_escape,
1202
 
                METH_VARARGS,
1203
 
                _mysql_escape__doc__
1204
 
        },
1205
 
        {
1206
1137
                "escape_string",
1207
1138
                (PyCFunction)_mysql_escape_string,
1208
1139
                METH_VARARGS,
1327
1258
                "True if connection is open"
1328
1259
        },
1329
1260
        {
1330
 
                "converter",
 
1261
                "decoder_stack",
1331
1262
                T_OBJECT,
1332
 
                offsetof(_mysql_ConnectionObject, converter),
 
1263
                offsetof(_mysql_ConnectionObject, decoder_stack),
1333
1264
                0,
1334
 
                "Type conversion mapping"
 
1265
                "Type decoder stack"
1335
1266
        },
1336
1267
        {
1337
1268
                "server_capabilities",