~ubuntu-branches/ubuntu/quantal/python-django/quantal-security

« back to all changes in this revision

Viewing changes to django/db/models/sql/subqueries.py

  • Committer: Bazaar Package Importer
  • Author(s): Chris Lamb
  • Date: 2010-05-21 07:52:55 UTC
  • mfrom: (1.3.6 upstream)
  • mto: This revision was merged to the branch mainline in revision 28.
  • Revision ID: james.westby@ubuntu.com-20100521075255-ii78v1dyfmyu3uzx
Tags: upstream-1.2
ImportĀ upstreamĀ versionĀ 1.2

Show diffs side-by-side

added added

removed removed

Lines of Context:
3
3
"""
4
4
 
5
5
from django.core.exceptions import FieldError
 
6
from django.db import connections
6
7
from django.db.models.sql.constants import *
7
8
from django.db.models.sql.datastructures import Date
8
9
from django.db.models.sql.expressions import SQLEvaluator
17
18
    Delete queries are done through this class, since they are more constrained
18
19
    than general queries.
19
20
    """
20
 
    def as_sql(self):
21
 
        """
22
 
        Creates the SQL for this query. Returns the SQL string and list of
23
 
        parameters.
24
 
        """
25
 
        assert len(self.tables) == 1, \
26
 
                "Can only delete from one table at a time."
27
 
        result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])]
28
 
        where, params = self.where.as_sql()
29
 
        result.append('WHERE %s' % where)
30
 
        return ' '.join(result), tuple(params)
31
 
 
32
 
    def do_query(self, table, where):
 
21
 
 
22
    compiler = 'SQLDeleteCompiler'
 
23
 
 
24
    def do_query(self, table, where, using):
33
25
        self.tables = [table]
34
26
        self.where = where
35
 
        self.execute_sql(None)
36
 
 
37
 
    def delete_batch_related(self, pk_list):
38
 
        """
39
 
        Set up and execute delete queries for all the objects related to the
40
 
        primary key values in pk_list. To delete the objects themselves, use
41
 
        the delete_batch() method.
42
 
 
43
 
        More than one physical query may be executed if there are a
44
 
        lot of values in pk_list.
45
 
        """
46
 
        from django.contrib.contenttypes import generic
47
 
        cls = self.model
48
 
        for related in cls._meta.get_all_related_many_to_many_objects():
49
 
            if not isinstance(related.field, generic.GenericRelation):
50
 
                for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
51
 
                    where = self.where_class()
52
 
                    where.add((Constraint(None,
53
 
                            related.field.m2m_reverse_name(), related.field),
54
 
                            'in',
55
 
                            pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]),
56
 
                            AND)
57
 
                    self.do_query(related.field.m2m_db_table(), where)
58
 
 
59
 
        for f in cls._meta.many_to_many:
60
 
            w1 = self.where_class()
61
 
            if isinstance(f, generic.GenericRelation):
62
 
                from django.contrib.contenttypes.models import ContentType
63
 
                field = f.rel.to._meta.get_field(f.content_type_field_name)
64
 
                w1.add((Constraint(None, field.column, field), 'exact',
65
 
                        ContentType.objects.get_for_model(cls).id), AND)
66
 
            for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
67
 
                where = self.where_class()
68
 
                where.add((Constraint(None, f.m2m_column_name(), f), 'in',
69
 
                        pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
70
 
                        AND)
71
 
                if w1:
72
 
                    where.add(w1, AND)
73
 
                self.do_query(f.m2m_db_table(), where)
74
 
 
75
 
    def delete_batch(self, pk_list):
76
 
        """
77
 
        Set up and execute delete queries for all the objects in pk_list. This
78
 
        should be called after delete_batch_related(), if necessary.
 
27
        self.get_compiler(using).execute_sql(None)
 
28
 
 
29
    def delete_batch(self, pk_list, using):
 
30
        """
 
31
        Set up and execute delete queries for all the objects in pk_list.
79
32
 
80
33
        More than one physical query may be executed if there are a
81
34
        lot of values in pk_list.
85
38
            field = self.model._meta.pk
86
39
            where.add((Constraint(None, field.column, field), 'in',
87
40
                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
88
 
            self.do_query(self.model._meta.db_table, where)
 
41
            self.do_query(self.model._meta.db_table, where, using=using)
89
42
 
90
43
class UpdateQuery(Query):
91
44
    """
92
45
    Represents an "update" SQL query.
93
46
    """
 
47
 
 
48
    compiler = 'SQLUpdateCompiler'
 
49
 
94
50
    def __init__(self, *args, **kwargs):
95
51
        super(UpdateQuery, self).__init__(*args, **kwargs)
96
52
        self._setup_query()
110
66
        return super(UpdateQuery, self).clone(klass,
111
67
                related_updates=self.related_updates.copy(), **kwargs)
112
68
 
113
 
    def execute_sql(self, result_type=None):
114
 
        """
115
 
        Execute the specified update. Returns the number of rows affected by
116
 
        the primary update query. The "primary update query" is the first
117
 
        non-empty query that is executed. Row counts for any subsequent,
118
 
        related queries are not available.
119
 
        """
120
 
        cursor = super(UpdateQuery, self).execute_sql(result_type)
121
 
        rows = cursor and cursor.rowcount or 0
122
 
        is_empty = cursor is None
123
 
        del cursor
124
 
        for query in self.get_related_updates():
125
 
            aux_rows = query.execute_sql(result_type)
126
 
            if is_empty:
127
 
                rows = aux_rows
128
 
                is_empty = False
129
 
        return rows
130
 
 
131
 
    def as_sql(self):
132
 
        """
133
 
        Creates the SQL for this query. Returns the SQL string and list of
134
 
        parameters.
135
 
        """
136
 
        self.pre_sql_setup()
137
 
        if not self.values:
138
 
            return '', ()
139
 
        table = self.tables[0]
140
 
        qn = self.quote_name_unless_alias
141
 
        result = ['UPDATE %s' % qn(table)]
142
 
        result.append('SET')
143
 
        values, update_params = [], []
144
 
        for name, val, placeholder in self.values:
145
 
            if hasattr(val, 'as_sql'):
146
 
                sql, params = val.as_sql(qn)
147
 
                values.append('%s = %s' % (qn(name), sql))
148
 
                update_params.extend(params)
149
 
            elif val is not None:
150
 
                values.append('%s = %s' % (qn(name), placeholder))
151
 
                update_params.append(val)
152
 
            else:
153
 
                values.append('%s = NULL' % qn(name))
154
 
        result.append(', '.join(values))
155
 
        where, params = self.where.as_sql()
156
 
        if where:
157
 
            result.append('WHERE %s' % where)
158
 
        return ' '.join(result), tuple(update_params + params)
159
 
 
160
 
    def pre_sql_setup(self):
161
 
        """
162
 
        If the update depends on results from other tables, we need to do some
163
 
        munging of the "where" conditions to match the format required for
164
 
        (portable) SQL updates. That is done here.
165
 
 
166
 
        Further, if we are going to be running multiple updates, we pull out
167
 
        the id values to update at this point so that they don't change as a
168
 
        result of the progressive updates.
169
 
        """
170
 
        self.select_related = False
171
 
        self.clear_ordering(True)
172
 
        super(UpdateQuery, self).pre_sql_setup()
173
 
        count = self.count_active_tables()
174
 
        if not self.related_updates and count == 1:
175
 
            return
176
 
 
177
 
        # We need to use a sub-select in the where clause to filter on things
178
 
        # from other tables.
179
 
        query = self.clone(klass=Query)
180
 
        query.bump_prefix()
181
 
        query.extra = {}
182
 
        query.select = []
183
 
        query.add_fields([query.model._meta.pk.name])
184
 
        must_pre_select = count > 1 and not self.connection.features.update_can_self_select
185
 
 
186
 
        # Now we adjust the current query: reset the where clause and get rid
187
 
        # of all the tables we don't need (since they're in the sub-select).
188
 
        self.where = self.where_class()
189
 
        if self.related_updates or must_pre_select:
190
 
            # Either we're using the idents in multiple update queries (so
191
 
            # don't want them to change), or the db backend doesn't support
192
 
            # selecting from the updating table (e.g. MySQL).
193
 
            idents = []
194
 
            for rows in query.execute_sql(MULTI):
195
 
                idents.extend([r[0] for r in rows])
196
 
            self.add_filter(('pk__in', idents))
197
 
            self.related_ids = idents
198
 
        else:
199
 
            # The fast path. Filters and updates in one query.
200
 
            self.add_filter(('pk__in', query))
201
 
        for alias in self.tables[1:]:
202
 
            self.alias_refcount[alias] = 0
203
 
 
204
 
    def clear_related(self, related_field, pk_list):
 
69
 
 
70
    def clear_related(self, related_field, pk_list, using):
205
71
        """
206
72
        Set up and execute an update query that clears related entries for the
207
73
        keys in pk_list.
214
80
            self.where.add((Constraint(None, f.column, f), 'in',
215
81
                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
216
82
                    AND)
217
 
            self.values = [(related_field.column, None, '%s')]
218
 
            self.execute_sql(None)
 
83
            self.values = [(related_field, None, None)]
 
84
            self.get_compiler(using).execute_sql(None)
219
85
 
220
86
    def add_update_values(self, values):
221
87
        """
228
94
            field, model, direct, m2m = self.model._meta.get_field_by_name(name)
229
95
            if not direct or m2m:
230
96
                raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field)
 
97
            if model:
 
98
                self.add_related_update(model, field, val)
 
99
                continue
231
100
            values_seq.append((field, model, val))
232
101
        return self.add_update_fields(values_seq)
233
102
 
237
106
        Used by add_update_values() as well as the "fast" update path when
238
107
        saving models.
239
108
        """
240
 
        from django.db.models.base import Model
241
 
        for field, model, val in values_seq:
242
 
            if hasattr(val, 'prepare_database_save'):
243
 
                val = val.prepare_database_save(field)
244
 
            else:
245
 
                val = field.get_db_prep_save(val)
246
 
 
247
 
            # Getting the placeholder for the field.
248
 
            if hasattr(field, 'get_placeholder'):
249
 
                placeholder = field.get_placeholder(val)
250
 
            else:
251
 
                placeholder = '%s'
252
 
 
253
 
            if hasattr(val, 'evaluate'):
254
 
                val = SQLEvaluator(val, self, allow_joins=False)
255
 
            if model:
256
 
                self.add_related_update(model, field.column, val, placeholder)
257
 
            else:
258
 
                self.values.append((field.column, val, placeholder))
259
 
 
260
 
    def add_related_update(self, model, column, value, placeholder):
 
109
        self.values.extend(values_seq)
 
110
 
 
111
    def add_related_update(self, model, field, value):
261
112
        """
262
113
        Adds (name, value) to an update query for an ancestor model.
263
114
 
264
115
        Updates are coalesced so that we only run one update query per ancestor.
265
116
        """
266
117
        try:
267
 
            self.related_updates[model].append((column, value, placeholder))
 
118
            self.related_updates[model].append((field, None, value))
268
119
        except KeyError:
269
 
            self.related_updates[model] = [(column, value, placeholder)]
 
120
            self.related_updates[model] = [(field, None, value)]
270
121
 
271
122
    def get_related_updates(self):
272
123
        """
278
129
            return []
279
130
        result = []
280
131
        for model, values in self.related_updates.iteritems():
281
 
            query = UpdateQuery(model, self.connection)
 
132
            query = UpdateQuery(model)
282
133
            query.values = values
283
 
            if self.related_ids:
 
134
            if self.related_ids is not None:
284
135
                query.add_filter(('pk__in', self.related_ids))
285
136
            result.append(query)
286
137
        return result
287
138
 
288
139
class InsertQuery(Query):
 
140
    compiler = 'SQLInsertCompiler'
 
141
 
289
142
    def __init__(self, *args, **kwargs):
290
143
        super(InsertQuery, self).__init__(*args, **kwargs)
291
144
        self.columns = []
292
145
        self.values = []
293
146
        self.params = ()
294
 
        self.return_id = False
295
147
 
296
148
    def clone(self, klass=None, **kwargs):
297
 
        extras = {'columns': self.columns[:], 'values': self.values[:],
298
 
                  'params': self.params, 'return_id': self.return_id}
 
149
        extras = {
 
150
            'columns': self.columns[:],
 
151
            'values': self.values[:],
 
152
            'params': self.params
 
153
        }
299
154
        extras.update(kwargs)
300
155
        return super(InsertQuery, self).clone(klass, **extras)
301
156
 
302
 
    def as_sql(self):
303
 
        # We don't need quote_name_unless_alias() here, since these are all
304
 
        # going to be column names (so we can avoid the extra overhead).
305
 
        qn = self.connection.ops.quote_name
306
 
        opts = self.model._meta
307
 
        result = ['INSERT INTO %s' % qn(opts.db_table)]
308
 
        result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
309
 
        result.append('VALUES (%s)' % ', '.join(self.values))
310
 
        params = self.params
311
 
        if self.return_id and self.connection.features.can_return_id_from_insert:
312
 
            col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
313
 
            r_fmt, r_params = self.connection.ops.return_insert_id()
314
 
            result.append(r_fmt % col)
315
 
            params = params + r_params
316
 
        return ' '.join(result), params
317
 
 
318
 
    def execute_sql(self, return_id=False):
319
 
        self.return_id = return_id
320
 
        cursor = super(InsertQuery, self).execute_sql(None)
321
 
        if not (return_id and cursor):
322
 
            return
323
 
        if self.connection.features.can_return_id_from_insert:
324
 
            return self.connection.ops.fetch_returned_insert_id(cursor)
325
 
        return self.connection.ops.last_insert_id(cursor,
326
 
                self.model._meta.db_table, self.model._meta.pk.column)
327
 
 
328
157
    def insert_values(self, insert_values, raw_values=False):
329
158
        """
330
159
        Set up the insert query from the 'insert_values' dictionary. The
337
166
        """
338
167
        placeholders, values = [], []
339
168
        for field, val in insert_values:
340
 
            if hasattr(field, 'get_placeholder'):
341
 
                # Some fields (e.g. geo fields) need special munging before
342
 
                # they can be inserted.
343
 
                placeholders.append(field.get_placeholder(val))
344
 
            else:
345
 
                placeholders.append('%s')
346
 
 
 
169
            placeholders.append((field, val))
347
170
            self.columns.append(field.column)
348
171
            values.append(val)
349
172
        if raw_values:
350
 
            self.values.extend(values)
 
173
            self.values.extend([(None, v) for v in values])
351
174
        else:
352
175
            self.params += tuple(values)
353
176
            self.values.extend(placeholders)
358
181
    date field. This requires some special handling when converting the results
359
182
    back to Python objects, so we put it in a separate class.
360
183
    """
361
 
    def __getstate__(self):
362
 
        """
363
 
        Special DateQuery-specific pickle handling.
364
 
        """
365
 
        for elt in self.select:
366
 
            if isinstance(elt, Date):
367
 
                # Eliminate a method reference that can't be pickled. The
368
 
                # __setstate__ method restores this.
369
 
                elt.date_sql_func = None
370
 
        return super(DateQuery, self).__getstate__()
371
 
 
372
 
    def __setstate__(self, obj_dict):
373
 
        super(DateQuery, self).__setstate__(obj_dict)
374
 
        for elt in self.select:
375
 
            if isinstance(elt, Date):
376
 
                self.date_sql_func = self.connection.ops.date_trunc_sql
377
 
 
378
 
    def results_iter(self):
379
 
        """
380
 
        Returns an iterator over the results from executing this query.
381
 
        """
382
 
        resolve_columns = hasattr(self, 'resolve_columns')
383
 
        if resolve_columns:
384
 
            from django.db.models.fields import DateTimeField
385
 
            fields = [DateTimeField()]
386
 
        else:
387
 
            from django.db.backends.util import typecast_timestamp
388
 
            needs_string_cast = self.connection.features.needs_datetime_string_cast
389
 
 
390
 
        offset = len(self.extra_select)
391
 
        for rows in self.execute_sql(MULTI):
392
 
            for row in rows:
393
 
                date = row[offset]
394
 
                if resolve_columns:
395
 
                    date = self.resolve_columns(row, fields)[offset]
396
 
                elif needs_string_cast:
397
 
                    date = typecast_timestamp(str(date))
398
 
                yield date
 
184
 
 
185
    compiler = 'SQLDateCompiler'
399
186
 
400
187
    def add_date_select(self, field, lookup_type, order='ASC'):
401
188
        """
404
191
        result = self.setup_joins([field.name], self.get_meta(),
405
192
                self.get_initial_alias(), False)
406
193
        alias = result[3][-1]
407
 
        select = Date((alias, field.column), lookup_type,
408
 
                self.connection.ops.date_trunc_sql)
 
194
        select = Date((alias, field.column), lookup_type)
409
195
        self.select = [select]
410
196
        self.select_fields = [None]
411
197
        self.select_related = False # See #7097.
412
 
        self.extra = {}
 
198
        self.set_extra_mask([])
413
199
        self.distinct = True
414
200
        self.order_by = order == 'ASC' and [1] or [-1]
415
201
 
418
204
    An AggregateQuery takes another query as a parameter to the FROM
419
205
    clause and only selects the elements in the provided list.
420
206
    """
421
 
    def add_subquery(self, query):
422
 
        self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)
423
 
 
424
 
    def as_sql(self, quote_func=None):
425
 
        """
426
 
        Creates the SQL for this query. Returns the SQL string and list of
427
 
        parameters.
428
 
        """
429
 
        sql = ('SELECT %s FROM (%s) subquery' % (
430
 
            ', '.join([
431
 
                aggregate.as_sql()
432
 
                for aggregate in self.aggregate_select.values()
433
 
            ]),
434
 
            self.subquery)
435
 
        )
436
 
        params = self.sub_params
437
 
        return (sql, params)
 
207
 
 
208
    compiler = 'SQLAggregateCompiler'
 
209
 
 
210
    def add_subquery(self, query, using):
 
211
        self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True)