~ubuntu-branches/ubuntu/quantal/python-django/quantal-security

« back to all changes in this revision

Viewing changes to django/db/models/sql/compiler.py

  • Committer: Bazaar Package Importer
  • Author(s): Chris Lamb
  • Date: 2010-05-21 07:52:55 UTC
  • mfrom: (1.3.6 upstream)
  • mto: This revision was merged to the branch mainline in revision 28.
  • Revision ID: james.westby@ubuntu.com-20100521075255-ii78v1dyfmyu3uzx
Tags: upstream-1.2
ImportĀ upstreamĀ versionĀ 1.2

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
from django.core.exceptions import FieldError
 
2
from django.db import connections
 
3
from django.db.backends.util import truncate_name
 
4
from django.db.models.sql.constants import *
 
5
from django.db.models.sql.datastructures import EmptyResultSet
 
6
from django.db.models.sql.expressions import SQLEvaluator
 
7
from django.db.models.sql.query import get_proxied_model, get_order_dir, \
 
8
     select_related_descend, Query
 
9
 
 
10
class SQLCompiler(object):
 
11
    def __init__(self, query, connection, using):
 
12
        self.query = query
 
13
        self.connection = connection
 
14
        self.using = using
 
15
        self.quote_cache = {}
 
16
 
 
17
    def pre_sql_setup(self):
 
18
        """
 
19
        Does any necessary class setup immediately prior to producing SQL. This
 
20
        is for things that can't necessarily be done in __init__ because we
 
21
        might not have all the pieces in place at that time.
 
22
        """
 
23
        if not self.query.tables:
 
24
            self.query.join((None, self.query.model._meta.db_table, None, None))
 
25
        if (not self.query.select and self.query.default_cols and not
 
26
                self.query.included_inherited_models):
 
27
            self.query.setup_inherited_models()
 
28
        if self.query.select_related and not self.query.related_select_cols:
 
29
            self.fill_related_selections()
 
30
 
 
31
    def quote_name_unless_alias(self, name):
 
32
        """
 
33
        A wrapper around connection.ops.quote_name that doesn't quote aliases
 
34
        for table names. This avoids problems with some SQL dialects that treat
 
35
        quoted strings specially (e.g. PostgreSQL).
 
36
        """
 
37
        if name in self.quote_cache:
 
38
            return self.quote_cache[name]
 
39
        if ((name in self.query.alias_map and name not in self.query.table_map) or
 
40
                name in self.query.extra_select):
 
41
            self.quote_cache[name] = name
 
42
            return name
 
43
        r = self.connection.ops.quote_name(name)
 
44
        self.quote_cache[name] = r
 
45
        return r
 
46
 
 
47
    def as_sql(self, with_limits=True, with_col_aliases=False):
 
48
        """
 
49
        Creates the SQL for this query. Returns the SQL string and list of
 
50
        parameters.
 
51
 
 
52
        If 'with_limits' is False, any limit/offset information is not included
 
53
        in the query.
 
54
        """
 
55
        self.pre_sql_setup()
 
56
        out_cols = self.get_columns(with_col_aliases)
 
57
        ordering, ordering_group_by = self.get_ordering()
 
58
 
 
59
        # This must come after 'select' and 'ordering' -- see docstring of
 
60
        # get_from_clause() for details.
 
61
        from_, f_params = self.get_from_clause()
 
62
 
 
63
        qn = self.quote_name_unless_alias
 
64
 
 
65
        where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
 
66
        having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
 
67
        params = []
 
68
        for val in self.query.extra_select.itervalues():
 
69
            params.extend(val[1])
 
70
 
 
71
        result = ['SELECT']
 
72
        if self.query.distinct:
 
73
            result.append('DISTINCT')
 
74
        result.append(', '.join(out_cols + self.query.ordering_aliases))
 
75
 
 
76
        result.append('FROM')
 
77
        result.extend(from_)
 
78
        params.extend(f_params)
 
79
 
 
80
        if where:
 
81
            result.append('WHERE %s' % where)
 
82
            params.extend(w_params)
 
83
 
 
84
        grouping, gb_params = self.get_grouping()
 
85
        if grouping:
 
86
            if ordering:
 
87
                # If the backend can't group by PK (i.e., any database
 
88
                # other than MySQL), then any fields mentioned in the
 
89
                # ordering clause needs to be in the group by clause.
 
90
                if not self.connection.features.allows_group_by_pk:
 
91
                    for col, col_params in ordering_group_by:
 
92
                        if col not in grouping:
 
93
                            grouping.append(str(col))
 
94
                            gb_params.extend(col_params)
 
95
            else:
 
96
                ordering = self.connection.ops.force_no_ordering()
 
97
            result.append('GROUP BY %s' % ', '.join(grouping))
 
98
            params.extend(gb_params)
 
99
 
 
100
        if having:
 
101
            result.append('HAVING %s' % having)
 
102
            params.extend(h_params)
 
103
 
 
104
        if ordering:
 
105
            result.append('ORDER BY %s' % ', '.join(ordering))
 
106
 
 
107
        if with_limits:
 
108
            if self.query.high_mark is not None:
 
109
                result.append('LIMIT %d' % (self.query.high_mark - self.query.low_mark))
 
110
            if self.query.low_mark:
 
111
                if self.query.high_mark is None:
 
112
                    val = self.connection.ops.no_limit_value()
 
113
                    if val:
 
114
                        result.append('LIMIT %d' % val)
 
115
                result.append('OFFSET %d' % self.query.low_mark)
 
116
 
 
117
        return ' '.join(result), tuple(params)
 
118
 
 
119
    def as_nested_sql(self):
 
120
        """
 
121
        Perform the same functionality as the as_sql() method, returning an
 
122
        SQL string and parameters. However, the alias prefixes are bumped
 
123
        beforehand (in a copy -- the current query isn't changed), and any
 
124
        ordering is removed if the query is unsliced.
 
125
 
 
126
        Used when nesting this query inside another.
 
127
        """
 
128
        obj = self.query.clone()
 
129
        if obj.low_mark == 0 and obj.high_mark is None:
 
130
            # If there is no slicing in use, then we can safely drop all ordering
 
131
            obj.clear_ordering(True)
 
132
        obj.bump_prefix()
 
133
        return obj.get_compiler(connection=self.connection).as_sql()
 
134
 
 
135
    def get_columns(self, with_aliases=False):
 
136
        """
 
137
        Returns the list of columns to use in the select statement. If no
 
138
        columns have been specified, returns all columns relating to fields in
 
139
        the model.
 
140
 
 
141
        If 'with_aliases' is true, any column names that are duplicated
 
142
        (without the table names) are given unique aliases. This is needed in
 
143
        some cases to avoid ambiguity with nested queries.
 
144
        """
 
145
        qn = self.quote_name_unless_alias
 
146
        qn2 = self.connection.ops.quote_name
 
147
        result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()]
 
148
        aliases = set(self.query.extra_select.keys())
 
149
        if with_aliases:
 
150
            col_aliases = aliases.copy()
 
151
        else:
 
152
            col_aliases = set()
 
153
        if self.query.select:
 
154
            only_load = self.deferred_to_columns()
 
155
            for col in self.query.select:
 
156
                if isinstance(col, (list, tuple)):
 
157
                    alias, column = col
 
158
                    table = self.query.alias_map[alias][TABLE_NAME]
 
159
                    if table in only_load and col not in only_load[table]:
 
160
                        continue
 
161
                    r = '%s.%s' % (qn(alias), qn(column))
 
162
                    if with_aliases:
 
163
                        if col[1] in col_aliases:
 
164
                            c_alias = 'Col%d' % len(col_aliases)
 
165
                            result.append('%s AS %s' % (r, c_alias))
 
166
                            aliases.add(c_alias)
 
167
                            col_aliases.add(c_alias)
 
168
                        else:
 
169
                            result.append('%s AS %s' % (r, qn2(col[1])))
 
170
                            aliases.add(r)
 
171
                            col_aliases.add(col[1])
 
172
                    else:
 
173
                        result.append(r)
 
174
                        aliases.add(r)
 
175
                        col_aliases.add(col[1])
 
176
                else:
 
177
                    result.append(col.as_sql(qn, self.connection))
 
178
 
 
179
                    if hasattr(col, 'alias'):
 
180
                        aliases.add(col.alias)
 
181
                        col_aliases.add(col.alias)
 
182
 
 
183
        elif self.query.default_cols:
 
184
            cols, new_aliases = self.get_default_columns(with_aliases,
 
185
                    col_aliases)
 
186
            result.extend(cols)
 
187
            aliases.update(new_aliases)
 
188
 
 
189
        max_name_length = self.connection.ops.max_name_length()
 
190
        result.extend([
 
191
            '%s%s' % (
 
192
                aggregate.as_sql(qn, self.connection),
 
193
                alias is not None
 
194
                    and ' AS %s' % qn(truncate_name(alias, max_name_length))
 
195
                    or ''
 
196
            )
 
197
            for alias, aggregate in self.query.aggregate_select.items()
 
198
        ])
 
199
 
 
200
        for table, col in self.query.related_select_cols:
 
201
            r = '%s.%s' % (qn(table), qn(col))
 
202
            if with_aliases and col in col_aliases:
 
203
                c_alias = 'Col%d' % len(col_aliases)
 
204
                result.append('%s AS %s' % (r, c_alias))
 
205
                aliases.add(c_alias)
 
206
                col_aliases.add(c_alias)
 
207
            else:
 
208
                result.append(r)
 
209
                aliases.add(r)
 
210
                col_aliases.add(col)
 
211
 
 
212
        self._select_aliases = aliases
 
213
        return result
 
214
 
 
215
    def get_default_columns(self, with_aliases=False, col_aliases=None,
 
216
            start_alias=None, opts=None, as_pairs=False, local_only=False):
 
217
        """
 
218
        Computes the default columns for selecting every field in the base
 
219
        model. Will sometimes be called to pull in related models (e.g. via
 
220
        select_related), in which case "opts" and "start_alias" will be given
 
221
        to provide a starting point for the traversal.
 
222
 
 
223
        Returns a list of strings, quoted appropriately for use in SQL
 
224
        directly, as well as a set of aliases used in the select statement (if
 
225
        'as_pairs' is True, returns a list of (alias, col_name) pairs instead
 
226
        of strings as the first component and None as the second component).
 
227
        """
 
228
        result = []
 
229
        if opts is None:
 
230
            opts = self.query.model._meta
 
231
        qn = self.quote_name_unless_alias
 
232
        qn2 = self.connection.ops.quote_name
 
233
        aliases = set()
 
234
        only_load = self.deferred_to_columns()
 
235
        # Skip all proxy to the root proxied model
 
236
        proxied_model = get_proxied_model(opts)
 
237
 
 
238
        if start_alias:
 
239
            seen = {None: start_alias}
 
240
        for field, model in opts.get_fields_with_model():
 
241
            if local_only and model is not None:
 
242
                continue
 
243
            if start_alias:
 
244
                try:
 
245
                    alias = seen[model]
 
246
                except KeyError:
 
247
                    if model is proxied_model:
 
248
                        alias = start_alias
 
249
                    else:
 
250
                        link_field = opts.get_ancestor_link(model)
 
251
                        alias = self.query.join((start_alias, model._meta.db_table,
 
252
                                link_field.column, model._meta.pk.column))
 
253
                    seen[model] = alias
 
254
            else:
 
255
                # If we're starting from the base model of the queryset, the
 
256
                # aliases will have already been set up in pre_sql_setup(), so
 
257
                # we can save time here.
 
258
                alias = self.query.included_inherited_models[model]
 
259
            table = self.query.alias_map[alias][TABLE_NAME]
 
260
            if table in only_load and field.column not in only_load[table]:
 
261
                continue
 
262
            if as_pairs:
 
263
                result.append((alias, field.column))
 
264
                aliases.add(alias)
 
265
                continue
 
266
            if with_aliases and field.column in col_aliases:
 
267
                c_alias = 'Col%d' % len(col_aliases)
 
268
                result.append('%s.%s AS %s' % (qn(alias),
 
269
                    qn2(field.column), c_alias))
 
270
                col_aliases.add(c_alias)
 
271
                aliases.add(c_alias)
 
272
            else:
 
273
                r = '%s.%s' % (qn(alias), qn2(field.column))
 
274
                result.append(r)
 
275
                aliases.add(r)
 
276
                if with_aliases:
 
277
                    col_aliases.add(field.column)
 
278
        return result, aliases
 
279
 
 
280
    def get_ordering(self):
 
281
        """
 
282
        Returns a tuple containing a list representing the SQL elements in the
 
283
        "order by" clause, and the list of SQL elements that need to be added
 
284
        to the GROUP BY clause as a result of the ordering.
 
285
 
 
286
        Also sets the ordering_aliases attribute on this instance to a list of
 
287
        extra aliases needed in the select.
 
288
 
 
289
        Determining the ordering SQL can change the tables we need to include,
 
290
        so this should be run *before* get_from_clause().
 
291
        """
 
292
        if self.query.extra_order_by:
 
293
            ordering = self.query.extra_order_by
 
294
        elif not self.query.default_ordering:
 
295
            ordering = self.query.order_by
 
296
        else:
 
297
            ordering = self.query.order_by or self.query.model._meta.ordering
 
298
        qn = self.quote_name_unless_alias
 
299
        qn2 = self.connection.ops.quote_name
 
300
        distinct = self.query.distinct
 
301
        select_aliases = self._select_aliases
 
302
        result = []
 
303
        group_by = []
 
304
        ordering_aliases = []
 
305
        if self.query.standard_ordering:
 
306
            asc, desc = ORDER_DIR['ASC']
 
307
        else:
 
308
            asc, desc = ORDER_DIR['DESC']
 
309
 
 
310
        # It's possible, due to model inheritance, that normal usage might try
 
311
        # to include the same field more than once in the ordering. We track
 
312
        # the table/column pairs we use and discard any after the first use.
 
313
        processed_pairs = set()
 
314
 
 
315
        for field in ordering:
 
316
            if field == '?':
 
317
                result.append(self.connection.ops.random_function_sql())
 
318
                continue
 
319
            if isinstance(field, int):
 
320
                if field < 0:
 
321
                    order = desc
 
322
                    field = -field
 
323
                else:
 
324
                    order = asc
 
325
                result.append('%s %s' % (field, order))
 
326
                group_by.append((field, []))
 
327
                continue
 
328
            col, order = get_order_dir(field, asc)
 
329
            if col in self.query.aggregate_select:
 
330
                result.append('%s %s' % (col, order))
 
331
                continue
 
332
            if '.' in field:
 
333
                # This came in through an extra(order_by=...) addition. Pass it
 
334
                # on verbatim.
 
335
                table, col = col.split('.', 1)
 
336
                if (table, col) not in processed_pairs:
 
337
                    elt = '%s.%s' % (qn(table), col)
 
338
                    processed_pairs.add((table, col))
 
339
                    if not distinct or elt in select_aliases:
 
340
                        result.append('%s %s' % (elt, order))
 
341
                        group_by.append((elt, []))
 
342
            elif get_order_dir(field)[0] not in self.query.extra_select:
 
343
                # 'col' is of the form 'field' or 'field1__field2' or
 
344
                # '-field1__field2__field', etc.
 
345
                for table, col, order in self.find_ordering_name(field,
 
346
                        self.query.model._meta, default_order=asc):
 
347
                    if (table, col) not in processed_pairs:
 
348
                        elt = '%s.%s' % (qn(table), qn2(col))
 
349
                        processed_pairs.add((table, col))
 
350
                        if distinct and elt not in select_aliases:
 
351
                            ordering_aliases.append(elt)
 
352
                        result.append('%s %s' % (elt, order))
 
353
                        group_by.append((elt, []))
 
354
            else:
 
355
                elt = qn2(col)
 
356
                if distinct and col not in select_aliases:
 
357
                    ordering_aliases.append(elt)
 
358
                result.append('%s %s' % (elt, order))
 
359
                group_by.append(self.query.extra_select[col])
 
360
        self.query.ordering_aliases = ordering_aliases
 
361
        return result, group_by
 
362
 
 
363
    def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
 
364
            already_seen=None):
 
365
        """
 
366
        Returns the table alias (the name might be ambiguous, the alias will
 
367
        not be) and column name for ordering by the given 'name' parameter.
 
368
        The 'name' is of the form 'field1__field2__...__fieldN'.
 
369
        """
 
370
        name, order = get_order_dir(name, default_order)
 
371
        pieces = name.split(LOOKUP_SEP)
 
372
        if not alias:
 
373
            alias = self.query.get_initial_alias()
 
374
        field, target, opts, joins, last, extra = self.query.setup_joins(pieces,
 
375
                opts, alias, False)
 
376
        alias = joins[-1]
 
377
        col = target.column
 
378
        if not field.rel:
 
379
            # To avoid inadvertent trimming of a necessary alias, use the
 
380
            # refcount to show that we are referencing a non-relation field on
 
381
            # the model.
 
382
            self.query.ref_alias(alias)
 
383
 
 
384
        # Must use left outer joins for nullable fields and their relations.
 
385
        self.query.promote_alias_chain(joins,
 
386
            self.query.alias_map[joins[0]][JOIN_TYPE] == self.query.LOUTER)
 
387
 
 
388
        # If we get to this point and the field is a relation to another model,
 
389
        # append the default ordering for that model.
 
390
        if field.rel and len(joins) > 1 and opts.ordering:
 
391
            # Firstly, avoid infinite loops.
 
392
            if not already_seen:
 
393
                already_seen = set()
 
394
            join_tuple = tuple([self.query.alias_map[j][TABLE_NAME] for j in joins])
 
395
            if join_tuple in already_seen:
 
396
                raise FieldError('Infinite loop caused by ordering.')
 
397
            already_seen.add(join_tuple)
 
398
 
 
399
            results = []
 
400
            for item in opts.ordering:
 
401
                results.extend(self.find_ordering_name(item, opts, alias,
 
402
                        order, already_seen))
 
403
            return results
 
404
 
 
405
        if alias:
 
406
            # We have to do the same "final join" optimisation as in
 
407
            # add_filter, since the final column might not otherwise be part of
 
408
            # the select set (so we can't order on it).
 
409
            while 1:
 
410
                join = self.query.alias_map[alias]
 
411
                if col != join[RHS_JOIN_COL]:
 
412
                    break
 
413
                self.query.unref_alias(alias)
 
414
                alias = join[LHS_ALIAS]
 
415
                col = join[LHS_JOIN_COL]
 
416
        return [(alias, col, order)]
 
417
 
 
418
    def get_from_clause(self):
 
419
        """
 
420
        Returns a list of strings that are joined together to go after the
 
421
        "FROM" part of the query, as well as a list any extra parameters that
 
422
        need to be included. Sub-classes, can override this to create a
 
423
        from-clause via a "select".
 
424
 
 
425
        This should only be called after any SQL construction methods that
 
426
        might change the tables we need. This means the select columns and
 
427
        ordering must be done first.
 
428
        """
 
429
        result = []
 
430
        qn = self.quote_name_unless_alias
 
431
        qn2 = self.connection.ops.quote_name
 
432
        first = True
 
433
        for alias in self.query.tables:
 
434
            if not self.query.alias_refcount[alias]:
 
435
                continue
 
436
            try:
 
437
                name, alias, join_type, lhs, lhs_col, col, nullable = self.query.alias_map[alias]
 
438
            except KeyError:
 
439
                # Extra tables can end up in self.tables, but not in the
 
440
                # alias_map if they aren't in a join. That's OK. We skip them.
 
441
                continue
 
442
            alias_str = (alias != name and ' %s' % alias or '')
 
443
            if join_type and not first:
 
444
                result.append('%s %s%s ON (%s.%s = %s.%s)'
 
445
                        % (join_type, qn(name), alias_str, qn(lhs),
 
446
                           qn2(lhs_col), qn(alias), qn2(col)))
 
447
            else:
 
448
                connector = not first and ', ' or ''
 
449
                result.append('%s%s%s' % (connector, qn(name), alias_str))
 
450
            first = False
 
451
        for t in self.query.extra_tables:
 
452
            alias, unused = self.query.table_alias(t)
 
453
            # Only add the alias if it's not already present (the table_alias()
 
454
            # calls increments the refcount, so an alias refcount of one means
 
455
            # this is the only reference.
 
456
            if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1:
 
457
                connector = not first and ', ' or ''
 
458
                result.append('%s%s' % (connector, qn(alias)))
 
459
                first = False
 
460
        return result, []
 
461
 
 
462
    def get_grouping(self):
 
463
        """
 
464
        Returns a tuple representing the SQL elements in the "group by" clause.
 
465
        """
 
466
        qn = self.quote_name_unless_alias
 
467
        result, params = [], []
 
468
        if self.query.group_by is not None:
 
469
            if len(self.query.model._meta.fields) == len(self.query.select) and \
 
470
                self.connection.features.allows_group_by_pk:
 
471
                self.query.group_by = [(self.query.model._meta.db_table, self.query.model._meta.pk.column)]
 
472
 
 
473
            group_by = self.query.group_by or []
 
474
 
 
475
            extra_selects = []
 
476
            for extra_select, extra_params in self.query.extra_select.itervalues():
 
477
                extra_selects.append(extra_select)
 
478
                params.extend(extra_params)
 
479
            for col in group_by + self.query.related_select_cols + extra_selects:
 
480
                if isinstance(col, (list, tuple)):
 
481
                    result.append('%s.%s' % (qn(col[0]), qn(col[1])))
 
482
                elif hasattr(col, 'as_sql'):
 
483
                    result.append(col.as_sql(qn))
 
484
                else:
 
485
                    result.append('(%s)' % str(col))
 
486
        return result, params
 
487
 
 
488
    def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
 
489
            used=None, requested=None, restricted=None, nullable=None,
 
490
            dupe_set=None, avoid_set=None):
 
491
        """
 
492
        Fill in the information needed for a select_related query. The current
 
493
        depth is measured as the number of connections away from the root model
 
494
        (for example, cur_depth=1 means we are looking at models with direct
 
495
        connections to the root model).
 
496
        """
 
497
        if not restricted and self.query.max_depth and cur_depth > self.query.max_depth:
 
498
            # We've recursed far enough; bail out.
 
499
            return
 
500
 
 
501
        if not opts:
 
502
            opts = self.query.get_meta()
 
503
            root_alias = self.query.get_initial_alias()
 
504
            self.query.related_select_cols = []
 
505
            self.query.related_select_fields = []
 
506
        if not used:
 
507
            used = set()
 
508
        if dupe_set is None:
 
509
            dupe_set = set()
 
510
        if avoid_set is None:
 
511
            avoid_set = set()
 
512
        orig_dupe_set = dupe_set
 
513
 
 
514
        # Setup for the case when only particular related fields should be
 
515
        # included in the related selection.
 
516
        if requested is None:
 
517
            if isinstance(self.query.select_related, dict):
 
518
                requested = self.query.select_related
 
519
                restricted = True
 
520
            else:
 
521
                restricted = False
 
522
 
 
523
        for f, model in opts.get_fields_with_model():
 
524
            if not select_related_descend(f, restricted, requested):
 
525
                continue
 
526
            # The "avoid" set is aliases we want to avoid just for this
 
527
            # particular branch of the recursion. They aren't permanently
 
528
            # forbidden from reuse in the related selection tables (which is
 
529
            # what "used" specifies).
 
530
            avoid = avoid_set.copy()
 
531
            dupe_set = orig_dupe_set.copy()
 
532
            table = f.rel.to._meta.db_table
 
533
            promote = nullable or f.null
 
534
            if model:
 
535
                int_opts = opts
 
536
                alias = root_alias
 
537
                alias_chain = []
 
538
                for int_model in opts.get_base_chain(model):
 
539
                    # Proxy model have elements in base chain
 
540
                    # with no parents, assign the new options
 
541
                    # object and skip to the next base in that
 
542
                    # case
 
543
                    if not int_opts.parents[int_model]:
 
544
                        int_opts = int_model._meta
 
545
                        continue
 
546
                    lhs_col = int_opts.parents[int_model].column
 
547
                    dedupe = lhs_col in opts.duplicate_targets
 
548
                    if dedupe:
 
549
                        avoid.update(self.query.dupe_avoidance.get((id(opts), lhs_col),
 
550
                                ()))
 
551
                        dupe_set.add((opts, lhs_col))
 
552
                    int_opts = int_model._meta
 
553
                    alias = self.query.join((alias, int_opts.db_table, lhs_col,
 
554
                            int_opts.pk.column), exclusions=used,
 
555
                            promote=promote)
 
556
                    alias_chain.append(alias)
 
557
                    for (dupe_opts, dupe_col) in dupe_set:
 
558
                        self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
 
559
                if self.query.alias_map[root_alias][JOIN_TYPE] == self.query.LOUTER:
 
560
                    self.query.promote_alias_chain(alias_chain, True)
 
561
            else:
 
562
                alias = root_alias
 
563
 
 
564
            dedupe = f.column in opts.duplicate_targets
 
565
            if dupe_set or dedupe:
 
566
                avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ()))
 
567
                if dedupe:
 
568
                    dupe_set.add((opts, f.column))
 
569
 
 
570
            alias = self.query.join((alias, table, f.column,
 
571
                    f.rel.get_related_field().column),
 
572
                    exclusions=used.union(avoid), promote=promote)
 
573
            used.add(alias)
 
574
            columns, aliases = self.get_default_columns(start_alias=alias,
 
575
                    opts=f.rel.to._meta, as_pairs=True)
 
576
            self.query.related_select_cols.extend(columns)
 
577
            if self.query.alias_map[alias][JOIN_TYPE] == self.query.LOUTER:
 
578
                self.query.promote_alias_chain(aliases, True)
 
579
            self.query.related_select_fields.extend(f.rel.to._meta.fields)
 
580
            if restricted:
 
581
                next = requested.get(f.name, {})
 
582
            else:
 
583
                next = False
 
584
            new_nullable = f.null or promote
 
585
            for dupe_opts, dupe_col in dupe_set:
 
586
                self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
 
587
            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
 
588
                    used, next, restricted, new_nullable, dupe_set, avoid)
 
589
 
 
590
        if restricted:
 
591
            related_fields = [
 
592
                (o.field, o.model)
 
593
                for o in opts.get_all_related_objects()
 
594
                if o.field.unique
 
595
            ]
 
596
            for f, model in related_fields:
 
597
                if not select_related_descend(f, restricted, requested, reverse=True):
 
598
                    continue
 
599
                # The "avoid" set is aliases we want to avoid just for this
 
600
                # particular branch of the recursion. They aren't permanently
 
601
                # forbidden from reuse in the related selection tables (which is
 
602
                # what "used" specifies).
 
603
                avoid = avoid_set.copy()
 
604
                dupe_set = orig_dupe_set.copy()
 
605
                table = model._meta.db_table
 
606
 
 
607
                int_opts = opts
 
608
                alias = root_alias
 
609
                alias_chain = []
 
610
                chain = opts.get_base_chain(f.rel.to)
 
611
                if chain is not None:
 
612
                    for int_model in chain:
 
613
                        # Proxy model have elements in base chain
 
614
                        # with no parents, assign the new options
 
615
                        # object and skip to the next base in that
 
616
                        # case
 
617
                        if not int_opts.parents[int_model]:
 
618
                            int_opts = int_model._meta
 
619
                            continue
 
620
                        lhs_col = int_opts.parents[int_model].column
 
621
                        dedupe = lhs_col in opts.duplicate_targets
 
622
                        if dedupe:
 
623
                            avoid.update((self.query.dupe_avoidance.get(id(opts), lhs_col),
 
624
                                ()))
 
625
                            dupe_set.add((opts, lhs_col))
 
626
                        int_opts = int_model._meta
 
627
                        alias = self.query.join(
 
628
                            (alias, int_opts.db_table, lhs_col, int_opts.pk.column),
 
629
                            exclusions=used, promote=True, reuse=used
 
630
                        )
 
631
                        alias_chain.append(alias)
 
632
                        for dupe_opts, dupe_col in dupe_set:
 
633
                            self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
 
634
                    dedupe = f.column in opts.duplicate_targets
 
635
                    if dupe_set or dedupe:
 
636
                        avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ()))
 
637
                        if dedupe:
 
638
                            dupe_set.add((opts, f.column))
 
639
                alias = self.query.join(
 
640
                    (alias, table, f.rel.get_related_field().column, f.column),
 
641
                    exclusions=used.union(avoid),
 
642
                    promote=True
 
643
                )
 
644
                used.add(alias)
 
645
                columns, aliases = self.get_default_columns(start_alias=alias,
 
646
                    opts=model._meta, as_pairs=True, local_only=True)
 
647
                self.query.related_select_cols.extend(columns)
 
648
                self.query.related_select_fields.extend(model._meta.fields)
 
649
 
 
650
                next = requested.get(f.related_query_name(), {})
 
651
                new_nullable = f.null or None
 
652
 
 
653
                self.fill_related_selections(model._meta, table, cur_depth+1,
 
654
                    used, next, restricted, new_nullable)
 
655
 
 
656
    def deferred_to_columns(self):
 
657
        """
 
658
        Converts the self.deferred_loading data structure to mapping of table
 
659
        names to sets of column names which are to be loaded. Returns the
 
660
        dictionary.
 
661
        """
 
662
        columns = {}
 
663
        self.query.deferred_to_data(columns, self.query.deferred_to_columns_cb)
 
664
        return columns
 
665
 
 
666
    def results_iter(self):
 
667
        """
 
668
        Returns an iterator over the results from executing this query.
 
669
        """
 
670
        resolve_columns = hasattr(self, 'resolve_columns')
 
671
        fields = None
 
672
        for rows in self.execute_sql(MULTI):
 
673
            for row in rows:
 
674
                if resolve_columns:
 
675
                    if fields is None:
 
676
                        # We only set this up here because
 
677
                        # related_select_fields isn't populated until
 
678
                        # execute_sql() has been called.
 
679
                        if self.query.select_fields:
 
680
                            fields = self.query.select_fields + self.query.related_select_fields
 
681
                        else:
 
682
                            fields = self.query.model._meta.fields
 
683
                        # If the field was deferred, exclude it from being passed
 
684
                        # into `resolve_columns` because it wasn't selected.
 
685
                        only_load = self.deferred_to_columns()
 
686
                        if only_load:
 
687
                            db_table = self.query.model._meta.db_table
 
688
                            fields = [f for f in fields if db_table in only_load and
 
689
                                      f.column in only_load[db_table]]
 
690
                    row = self.resolve_columns(row, fields)
 
691
 
 
692
                if self.query.aggregate_select:
 
693
                    aggregate_start = len(self.query.extra_select.keys()) + len(self.query.select)
 
694
                    aggregate_end = aggregate_start + len(self.query.aggregate_select)
 
695
                    row = tuple(row[:aggregate_start]) + tuple([
 
696
                        self.query.resolve_aggregate(value, aggregate, self.connection)
 
697
                        for (alias, aggregate), value
 
698
                        in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end])
 
699
                    ]) + tuple(row[aggregate_end:])
 
700
 
 
701
                yield row
 
702
 
 
703
    def execute_sql(self, result_type=MULTI):
 
704
        """
 
705
        Run the query against the database and returns the result(s). The
 
706
        return value is a single data item if result_type is SINGLE, or an
 
707
        iterator over the results if the result_type is MULTI.
 
708
 
 
709
        result_type is either MULTI (use fetchmany() to retrieve all rows),
 
710
        SINGLE (only retrieve a single row), or None. In this last case, the
 
711
        cursor is returned if any query is executed, since it's used by
 
712
        subclasses such as InsertQuery). It's possible, however, that no query
 
713
        is needed, as the filters describe an empty set. In that case, None is
 
714
        returned, to avoid any unnecessary database interaction.
 
715
        """
 
716
        try:
 
717
            sql, params = self.as_sql()
 
718
            if not sql:
 
719
                raise EmptyResultSet
 
720
        except EmptyResultSet:
 
721
            if result_type == MULTI:
 
722
                return empty_iter()
 
723
            else:
 
724
                return
 
725
 
 
726
        cursor = self.connection.cursor()
 
727
        cursor.execute(sql, params)
 
728
 
 
729
        if not result_type:
 
730
            return cursor
 
731
        if result_type == SINGLE:
 
732
            if self.query.ordering_aliases:
 
733
                return cursor.fetchone()[:-len(self.query.ordering_aliases)]
 
734
            return cursor.fetchone()
 
735
 
 
736
        # The MULTI case.
 
737
        if self.query.ordering_aliases:
 
738
            result = order_modified_iter(cursor, len(self.query.ordering_aliases),
 
739
                    self.connection.features.empty_fetchmany_value)
 
740
        else:
 
741
            result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
 
742
                    self.connection.features.empty_fetchmany_value)
 
743
        if not self.connection.features.can_use_chunked_reads:
 
744
            # If we are using non-chunked reads, we return the same data
 
745
            # structure as normally, but ensure it is all read into memory
 
746
            # before going any further.
 
747
            return list(result)
 
748
        return result
 
749
 
 
750
 
 
751
class SQLInsertCompiler(SQLCompiler):
 
752
    def placeholder(self, field, val):
 
753
        if field is None:
 
754
            # A field value of None means the value is raw.
 
755
            return val
 
756
        elif hasattr(field, 'get_placeholder'):
 
757
            # Some fields (e.g. geo fields) need special munging before
 
758
            # they can be inserted.
 
759
            return field.get_placeholder(val, self.connection)
 
760
        else:
 
761
            # Return the common case for the placeholder
 
762
            return '%s'
 
763
 
 
764
    def as_sql(self):
 
765
        # We don't need quote_name_unless_alias() here, since these are all
 
766
        # going to be column names (so we can avoid the extra overhead).
 
767
        qn = self.connection.ops.quote_name
 
768
        opts = self.query.model._meta
 
769
        result = ['INSERT INTO %s' % qn(opts.db_table)]
 
770
        result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
 
771
        values = [self.placeholder(*v) for v in self.query.values]
 
772
        result.append('VALUES (%s)' % ', '.join(values))
 
773
        params = self.query.params
 
774
        if self.return_id and self.connection.features.can_return_id_from_insert:
 
775
            col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
 
776
            r_fmt, r_params = self.connection.ops.return_insert_id()
 
777
            result.append(r_fmt % col)
 
778
            params = params + r_params
 
779
        return ' '.join(result), params
 
780
 
 
781
    def execute_sql(self, return_id=False):
 
782
        self.return_id = return_id
 
783
        cursor = super(SQLInsertCompiler, self).execute_sql(None)
 
784
        if not (return_id and cursor):
 
785
            return
 
786
        if self.connection.features.can_return_id_from_insert:
 
787
            return self.connection.ops.fetch_returned_insert_id(cursor)
 
788
        return self.connection.ops.last_insert_id(cursor,
 
789
                self.query.model._meta.db_table, self.query.model._meta.pk.column)
 
790
 
 
791
 
 
792
class SQLDeleteCompiler(SQLCompiler):
 
793
    def as_sql(self):
 
794
        """
 
795
        Creates the SQL for this query. Returns the SQL string and list of
 
796
        parameters.
 
797
        """
 
798
        assert len(self.query.tables) == 1, \
 
799
                "Can only delete from one table at a time."
 
800
        qn = self.quote_name_unless_alias
 
801
        result = ['DELETE FROM %s' % qn(self.query.tables[0])]
 
802
        where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
 
803
        result.append('WHERE %s' % where)
 
804
        return ' '.join(result), tuple(params)
 
805
 
 
806
class SQLUpdateCompiler(SQLCompiler):
 
807
    def as_sql(self):
 
808
        """
 
809
        Creates the SQL for this query. Returns the SQL string and list of
 
810
        parameters.
 
811
        """
 
812
        from django.db.models.base import Model
 
813
 
 
814
        self.pre_sql_setup()
 
815
        if not self.query.values:
 
816
            return '', ()
 
817
        table = self.query.tables[0]
 
818
        qn = self.quote_name_unless_alias
 
819
        result = ['UPDATE %s' % qn(table)]
 
820
        result.append('SET')
 
821
        values, update_params = [], []
 
822
        for field, model, val in self.query.values:
 
823
            if hasattr(val, 'prepare_database_save'):
 
824
                val = val.prepare_database_save(field)
 
825
            else:
 
826
                val = field.get_db_prep_save(val, connection=self.connection)
 
827
 
 
828
            # Getting the placeholder for the field.
 
829
            if hasattr(field, 'get_placeholder'):
 
830
                placeholder = field.get_placeholder(val, self.connection)
 
831
            else:
 
832
                placeholder = '%s'
 
833
 
 
834
            if hasattr(val, 'evaluate'):
 
835
                val = SQLEvaluator(val, self.query, allow_joins=False)
 
836
            name = field.column
 
837
            if hasattr(val, 'as_sql'):
 
838
                sql, params = val.as_sql(qn, self.connection)
 
839
                values.append('%s = %s' % (qn(name), sql))
 
840
                update_params.extend(params)
 
841
            elif val is not None:
 
842
                values.append('%s = %s' % (qn(name), placeholder))
 
843
                update_params.append(val)
 
844
            else:
 
845
                values.append('%s = NULL' % qn(name))
 
846
        if not values:
 
847
            return '', ()
 
848
        result.append(', '.join(values))
 
849
        where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
 
850
        if where:
 
851
            result.append('WHERE %s' % where)
 
852
        return ' '.join(result), tuple(update_params + params)
 
853
 
 
854
    def execute_sql(self, result_type):
 
855
        """
 
856
        Execute the specified update. Returns the number of rows affected by
 
857
        the primary update query. The "primary update query" is the first
 
858
        non-empty query that is executed. Row counts for any subsequent,
 
859
        related queries are not available.
 
860
        """
 
861
        cursor = super(SQLUpdateCompiler, self).execute_sql(result_type)
 
862
        rows = cursor and cursor.rowcount or 0
 
863
        is_empty = cursor is None
 
864
        del cursor
 
865
        for query in self.query.get_related_updates():
 
866
            aux_rows = query.get_compiler(self.using).execute_sql(result_type)
 
867
            if is_empty:
 
868
                rows = aux_rows
 
869
                is_empty = False
 
870
        return rows
 
871
 
 
872
    def pre_sql_setup(self):
 
873
        """
 
874
        If the update depends on results from other tables, we need to do some
 
875
        munging of the "where" conditions to match the format required for
 
876
        (portable) SQL updates. That is done here.
 
877
 
 
878
        Further, if we are going to be running multiple updates, we pull out
 
879
        the id values to update at this point so that they don't change as a
 
880
        result of the progressive updates.
 
881
        """
 
882
        self.query.select_related = False
 
883
        self.query.clear_ordering(True)
 
884
        super(SQLUpdateCompiler, self).pre_sql_setup()
 
885
        count = self.query.count_active_tables()
 
886
        if not self.query.related_updates and count == 1:
 
887
            return
 
888
 
 
889
        # We need to use a sub-select in the where clause to filter on things
 
890
        # from other tables.
 
891
        query = self.query.clone(klass=Query)
 
892
        query.bump_prefix()
 
893
        query.extra = {}
 
894
        query.select = []
 
895
        query.add_fields([query.model._meta.pk.name])
 
896
        must_pre_select = count > 1 and not self.connection.features.update_can_self_select
 
897
 
 
898
        # Now we adjust the current query: reset the where clause and get rid
 
899
        # of all the tables we don't need (since they're in the sub-select).
 
900
        self.query.where = self.query.where_class()
 
901
        if self.query.related_updates or must_pre_select:
 
902
            # Either we're using the idents in multiple update queries (so
 
903
            # don't want them to change), or the db backend doesn't support
 
904
            # selecting from the updating table (e.g. MySQL).
 
905
            idents = []
 
906
            for rows in query.get_compiler(self.using).execute_sql(MULTI):
 
907
                idents.extend([r[0] for r in rows])
 
908
            self.query.add_filter(('pk__in', idents))
 
909
            self.query.related_ids = idents
 
910
        else:
 
911
            # The fast path. Filters and updates in one query.
 
912
            self.query.add_filter(('pk__in', query))
 
913
        for alias in self.query.tables[1:]:
 
914
            self.query.alias_refcount[alias] = 0
 
915
 
 
916
class SQLAggregateCompiler(SQLCompiler):
 
917
    def as_sql(self, qn=None):
 
918
        """
 
919
        Creates the SQL for this query. Returns the SQL string and list of
 
920
        parameters.
 
921
        """
 
922
        if qn is None:
 
923
            qn = self.quote_name_unless_alias
 
924
        sql = ('SELECT %s FROM (%s) subquery' % (
 
925
            ', '.join([
 
926
                aggregate.as_sql(qn, self.connection)
 
927
                for aggregate in self.query.aggregate_select.values()
 
928
            ]),
 
929
            self.query.subquery)
 
930
        )
 
931
        params = self.query.sub_params
 
932
        return (sql, params)
 
933
 
 
934
class SQLDateCompiler(SQLCompiler):
 
935
    def results_iter(self):
 
936
        """
 
937
        Returns an iterator over the results from executing this query.
 
938
        """
 
939
        resolve_columns = hasattr(self, 'resolve_columns')
 
940
        if resolve_columns:
 
941
            from django.db.models.fields import DateTimeField
 
942
            fields = [DateTimeField()]
 
943
        else:
 
944
            from django.db.backends.util import typecast_timestamp
 
945
            needs_string_cast = self.connection.features.needs_datetime_string_cast
 
946
 
 
947
        offset = len(self.query.extra_select)
 
948
        for rows in self.execute_sql(MULTI):
 
949
            for row in rows:
 
950
                date = row[offset]
 
951
                if resolve_columns:
 
952
                    date = self.resolve_columns(row, fields)[offset]
 
953
                elif needs_string_cast:
 
954
                    date = typecast_timestamp(str(date))
 
955
                yield date
 
956
 
 
957
 
 
958
def empty_iter():
 
959
    """
 
960
    Returns an iterator containing no results.
 
961
    """
 
962
    yield iter([]).next()
 
963
 
 
964
 
 
965
def order_modified_iter(cursor, trim, sentinel):
 
966
    """
 
967
    Yields blocks of rows from a cursor. We use this iterator in the special
 
968
    case when extra output columns have been added to support ordering
 
969
    requirements. We must trim those extra columns before anything else can use
 
970
    the results, since they're only needed to make the SQL valid.
 
971
    """
 
972
    for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
 
973
            sentinel):
 
974
        yield [r[:-trim] for r in rows]