~ubuntu-branches/debian/sid/python-django/sid

« back to all changes in this revision

Viewing changes to tests/backends/tests.py

  • Committer: Package Import Robot
  • Author(s): Luke Faraone
  • Date: 2013-11-07 15:33:49 UTC
  • mfrom: (1.3.12)
  • Revision ID: package-import@ubuntu.com-20131107153349-e31sc149l2szs3jb
Tags: 1.6-1
* New upstream version. Closes: #557474, #724637.
* python-django now also suggests the installation of ipython,
  bpython, python-django-doc, and libgdal1.
  Closes: #636511, #686333, #704203
* Set package maintainer to Debian Python Modules Team.
* Bump standards version to 3.9.5, no changes needed.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- coding: utf-8 -*-
 
2
# Unit and doctests for specific database backends.
 
3
from __future__ import absolute_import, unicode_literals
 
4
 
 
5
import datetime
 
6
from decimal import Decimal
 
7
import threading
 
8
 
 
9
from django.conf import settings
 
10
from django.core.management.color import no_style
 
11
from django.db import (connection, connections, DEFAULT_DB_ALIAS,
 
12
    DatabaseError, IntegrityError, transaction)
 
13
from django.db.backends.signals import connection_created
 
14
from django.db.backends.sqlite3.base import DatabaseOperations
 
15
from django.db.backends.postgresql_psycopg2 import version as pg_version
 
16
from django.db.backends.util import format_number
 
17
from django.db.models import Sum, Avg, Variance, StdDev
 
18
from django.db.models.fields import (AutoField, DateField, DateTimeField,
 
19
    DecimalField, IntegerField, TimeField)
 
20
from django.db.utils import ConnectionHandler
 
21
from django.test import (TestCase, skipUnlessDBFeature, skipIfDBFeature,
 
22
    TransactionTestCase)
 
23
from django.test.utils import override_settings, str_prefix
 
24
from django.utils import six, unittest
 
25
from django.utils.six.moves import xrange
 
26
 
 
27
from . import models
 
28
 
 
29
 
 
30
class DummyBackendTest(TestCase):
 
31
    def test_no_databases(self):
 
32
        """
 
33
        Test that empty DATABASES setting default to the dummy backend.
 
34
        """
 
35
        DATABASES = {}
 
36
        conns = ConnectionHandler(DATABASES)
 
37
        self.assertEqual(conns[DEFAULT_DB_ALIAS].settings_dict['ENGINE'],
 
38
            'django.db.backends.dummy')
 
39
 
 
40
 
 
41
class OracleChecks(unittest.TestCase):
 
42
 
 
43
    @unittest.skipUnless(connection.vendor == 'oracle',
 
44
                         "No need to check Oracle quote_name semantics")
 
45
    def test_quote_name(self):
 
46
        # Check that '%' chars are escaped for query execution.
 
47
        name = '"SOME%NAME"'
 
48
        quoted_name = connection.ops.quote_name(name)
 
49
        self.assertEqual(quoted_name % (), name)
 
50
 
 
51
    @unittest.skipUnless(connection.vendor == 'oracle',
 
52
                         "No need to check Oracle cursor semantics")
 
53
    def test_dbms_session(self):
 
54
        # If the backend is Oracle, test that we can call a standard
 
55
        # stored procedure through our cursor wrapper.
 
56
        from django.db.backends.oracle.base import convert_unicode
 
57
 
 
58
        cursor = connection.cursor()
 
59
        cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
 
60
                        [convert_unicode('_django_testing!')])
 
61
 
 
62
    @unittest.skipUnless(connection.vendor == 'oracle',
 
63
                         "No need to check Oracle cursor semantics")
 
64
    def test_cursor_var(self):
 
65
        # If the backend is Oracle, test that we can pass cursor variables
 
66
        # as query parameters.
 
67
        from django.db.backends.oracle.base import Database
 
68
 
 
69
        cursor = connection.cursor()
 
70
        var = cursor.var(Database.STRING)
 
71
        cursor.execute("BEGIN %s := 'X'; END; ", [var])
 
72
        self.assertEqual(var.getvalue(), 'X')
 
73
 
 
74
    @unittest.skipUnless(connection.vendor == 'oracle',
 
75
                         "No need to check Oracle cursor semantics")
 
76
    def test_long_string(self):
 
77
        # If the backend is Oracle, test that we can save a text longer
 
78
        # than 4000 chars and read it properly
 
79
        c = connection.cursor()
 
80
        c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
 
81
        long_str = ''.join([six.text_type(x) for x in xrange(4000)])
 
82
        c.execute('INSERT INTO ltext VALUES (%s)', [long_str])
 
83
        c.execute('SELECT text FROM ltext')
 
84
        row = c.fetchone()
 
85
        self.assertEqual(long_str, row[0].read())
 
86
        c.execute('DROP TABLE ltext')
 
87
 
 
88
    @unittest.skipUnless(connection.vendor == 'oracle',
 
89
                         "No need to check Oracle connection semantics")
 
90
    def test_client_encoding(self):
 
91
        # If the backend is Oracle, test that the client encoding is set
 
92
        # correctly.  This was broken under Cygwin prior to r14781.
 
93
        connection.cursor()  # Ensure the connection is initialized.
 
94
        self.assertEqual(connection.connection.encoding, "UTF-8")
 
95
        self.assertEqual(connection.connection.nencoding, "UTF-8")
 
96
 
 
97
    @unittest.skipUnless(connection.vendor == 'oracle',
 
98
                         "No need to check Oracle connection semantics")
 
99
    def test_order_of_nls_parameters(self):
 
100
        # an 'almost right' datetime should work with configured
 
101
        # NLS parameters as per #18465.
 
102
        c = connection.cursor()
 
103
        query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
 
104
        # Test that the query succeeds without errors - pre #18465 this
 
105
        # wasn't the case.
 
106
        c.execute(query)
 
107
        self.assertEqual(c.fetchone()[0], 1)
 
108
 
 
109
 
 
110
class MySQLTests(TestCase):
 
111
    @unittest.skipUnless(connection.vendor == 'mysql',
 
112
                        "Test valid only for MySQL")
 
113
    def test_autoincrement(self):
 
114
        """
 
115
        Check that auto_increment fields are reset correctly by sql_flush().
 
116
        Before MySQL version 5.0.13 TRUNCATE did not do auto_increment reset.
 
117
        Refs #16961.
 
118
        """
 
119
        statements = connection.ops.sql_flush(no_style(),
 
120
                                              tables=['test'],
 
121
                                              sequences=[{
 
122
                                                  'table': 'test',
 
123
                                                  'col': 'somecol',
 
124
                                              }])
 
125
        found_reset = False
 
126
        for sql in statements:
 
127
            found_reset = found_reset or 'ALTER TABLE' in sql
 
128
        if connection.mysql_version < (5, 0, 13):
 
129
            self.assertTrue(found_reset)
 
130
        else:
 
131
            self.assertFalse(found_reset)
 
132
 
 
133
 
 
134
class DateQuotingTest(TestCase):
 
135
 
 
136
    def test_django_date_trunc(self):
 
137
        """
 
138
        Test the custom ``django_date_trunc method``, in particular against
 
139
        fields which clash with strings passed to it (e.g. 'year') - see
 
140
        #12818__.
 
141
 
 
142
        __: http://code.djangoproject.com/ticket/12818
 
143
 
 
144
        """
 
145
        updated = datetime.datetime(2010, 2, 20)
 
146
        models.SchoolClass.objects.create(year=2009, last_updated=updated)
 
147
        years = models.SchoolClass.objects.dates('last_updated', 'year')
 
148
        self.assertEqual(list(years), [datetime.date(2010, 1, 1)])
 
149
 
 
150
    def test_django_date_extract(self):
 
151
        """
 
152
        Test the custom ``django_date_extract method``, in particular against fields
 
153
        which clash with strings passed to it (e.g. 'day') - see #12818__.
 
154
 
 
155
        __: http://code.djangoproject.com/ticket/12818
 
156
 
 
157
        """
 
158
        updated = datetime.datetime(2010, 2, 20)
 
159
        models.SchoolClass.objects.create(year=2009, last_updated=updated)
 
160
        classes = models.SchoolClass.objects.filter(last_updated__day=20)
 
161
        self.assertEqual(len(classes), 1)
 
162
 
 
163
 
 
164
@override_settings(DEBUG=True)
 
165
class LastExecutedQueryTest(TestCase):
 
166
 
 
167
    def test_last_executed_query(self):
 
168
        """
 
169
        last_executed_query should not raise an exception even if no previous
 
170
        query has been run.
 
171
        """
 
172
        cursor = connection.cursor()
 
173
        try:
 
174
            connection.ops.last_executed_query(cursor, '', ())
 
175
        except Exception:
 
176
            self.fail("'last_executed_query' should not raise an exception.")
 
177
 
 
178
    def test_debug_sql(self):
 
179
        list(models.Reporter.objects.filter(first_name="test"))
 
180
        sql = connection.queries[-1]['sql'].lower()
 
181
        self.assertIn("select", sql)
 
182
        self.assertIn(models.Reporter._meta.db_table, sql)
 
183
 
 
184
    def test_query_encoding(self):
 
185
        """
 
186
        Test that last_executed_query() returns an Unicode string
 
187
        """
 
188
        persons = models.Reporter.objects.filter(raw_data=b'\x00\x46  \xFE').extra(select={'föö': 1})
 
189
        sql, params = persons.query.sql_with_params()
 
190
        cursor = persons.query.get_compiler('default').execute_sql(None)
 
191
        last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
 
192
        self.assertIsInstance(last_sql, six.text_type)
 
193
 
 
194
    @unittest.skipUnless(connection.vendor == 'sqlite',
 
195
                         "This test is specific to SQLite.")
 
196
    def test_no_interpolation_on_sqlite(self):
 
197
        # Regression for #17158
 
198
        # This shouldn't raise an exception
 
199
        query = "SELECT strftime('%Y', 'now');"
 
200
        connection.cursor().execute(query)
 
201
        self.assertEqual(connection.queries[-1]['sql'],
 
202
            str_prefix("QUERY = %(_)s\"SELECT strftime('%%Y', 'now');\" - PARAMS = ()"))
 
203
 
 
204
 
 
205
class ParameterHandlingTest(TestCase):
 
206
    def test_bad_parameter_count(self):
 
207
        "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
 
208
        cursor = connection.cursor()
 
209
        query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
 
210
            connection.introspection.table_name_converter('backends_square'),
 
211
            connection.ops.quote_name('root'),
 
212
            connection.ops.quote_name('square')
 
213
        ))
 
214
        self.assertRaises(Exception, cursor.executemany, query, [(1, 2, 3)])
 
215
        self.assertRaises(Exception, cursor.executemany, query, [(1,)])
 
216
 
 
217
 
 
218
# Unfortunately, the following tests would be a good test to run on all
 
219
# backends, but it breaks MySQL hard. Until #13711 is fixed, it can't be run
 
220
# everywhere (although it would be an effective test of #13711).
 
221
class LongNameTest(TestCase):
 
222
    """Long primary keys and model names can result in a sequence name
 
223
    that exceeds the database limits, which will result in truncation
 
224
    on certain databases (e.g., Postgres). The backend needs to use
 
225
    the correct sequence name in last_insert_id and other places, so
 
226
    check it is. Refs #8901.
 
227
    """
 
228
 
 
229
    @skipUnlessDBFeature('supports_long_model_names')
 
230
    def test_sequence_name_length_limits_create(self):
 
231
        """Test creation of model with long name and long pk name doesn't error. Ref #8901"""
 
232
        models.VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ.objects.create()
 
233
 
 
234
    @skipUnlessDBFeature('supports_long_model_names')
 
235
    def test_sequence_name_length_limits_m2m(self):
 
236
        """Test an m2m save of a model with a long name and a long m2m field name doesn't error as on Django >=1.2 this now uses object saves. Ref #8901"""
 
237
        obj = models.VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ.objects.create()
 
238
        rel_obj = models.Person.objects.create(first_name='Django', last_name='Reinhardt')
 
239
        obj.m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz.add(rel_obj)
 
240
 
 
241
    @skipUnlessDBFeature('supports_long_model_names')
 
242
    def test_sequence_name_length_limits_flush(self):
 
243
        """Test that sequence resetting as part of a flush with model with long name and long pk name doesn't error. Ref #8901"""
 
244
        # A full flush is expensive to the full test, so we dig into the
 
245
        # internals to generate the likely offending SQL and run it manually
 
246
 
 
247
        # Some convenience aliases
 
248
        VLM = models.VeryLongModelNameZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
 
249
        VLM_m2m = VLM.m2m_also_quite_long_zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz.through
 
250
        tables = [
 
251
            VLM._meta.db_table,
 
252
            VLM_m2m._meta.db_table,
 
253
        ]
 
254
        sequences = [
 
255
            {
 
256
                'column': VLM._meta.pk.column,
 
257
                'table': VLM._meta.db_table
 
258
            },
 
259
        ]
 
260
        cursor = connection.cursor()
 
261
        for statement in connection.ops.sql_flush(no_style(), tables, sequences):
 
262
            cursor.execute(statement)
 
263
 
 
264
 
 
265
class SequenceResetTest(TestCase):
 
266
    def test_generic_relation(self):
 
267
        "Sequence names are correct when resetting generic relations (Ref #13941)"
 
268
        # Create an object with a manually specified PK
 
269
        models.Post.objects.create(id=10, name='1st post', text='hello world')
 
270
 
 
271
        # Reset the sequences for the database
 
272
        cursor = connection.cursor()
 
273
        commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [models.Post])
 
274
        for sql in commands:
 
275
            cursor.execute(sql)
 
276
 
 
277
        # If we create a new object now, it should have a PK greater
 
278
        # than the PK we specified manually.
 
279
        obj = models.Post.objects.create(name='New post', text='goodbye world')
 
280
        self.assertTrue(obj.pk > 10)
 
281
 
 
282
 
 
283
class PostgresVersionTest(TestCase):
 
284
    def assert_parses(self, version_string, version):
 
285
        self.assertEqual(pg_version._parse_version(version_string), version)
 
286
 
 
287
    def test_parsing(self):
 
288
        """Test PostgreSQL version parsing from `SELECT version()` output"""
 
289
        self.assert_parses("PostgreSQL 8.3 beta4", 80300)
 
290
        self.assert_parses("PostgreSQL 8.3", 80300)
 
291
        self.assert_parses("EnterpriseDB 8.3", 80300)
 
292
        self.assert_parses("PostgreSQL 8.3.6", 80306)
 
293
        self.assert_parses("PostgreSQL 8.4beta1", 80400)
 
294
        self.assert_parses("PostgreSQL 8.3.1 on i386-apple-darwin9.2.2, compiled by GCC i686-apple-darwin9-gcc-4.0.1 (GCC) 4.0.1 (Apple Inc. build 5478)", 80301)
 
295
 
 
296
    def test_version_detection(self):
 
297
        """Test PostgreSQL version detection"""
 
298
 
 
299
        # Helper mocks
 
300
        class CursorMock(object):
 
301
            "Very simple mock of DB-API cursor"
 
302
            def execute(self, arg):
 
303
                pass
 
304
 
 
305
            def fetchone(self):
 
306
                return ["PostgreSQL 8.3"]
 
307
 
 
308
        class OlderConnectionMock(object):
 
309
            "Mock of psycopg2 (< 2.0.12) connection"
 
310
            def cursor(self):
 
311
                return CursorMock()
 
312
 
 
313
        # psycopg2 < 2.0.12 code path
 
314
        conn = OlderConnectionMock()
 
315
        self.assertEqual(pg_version.get_version(conn), 80300)
 
316
 
 
317
 
 
318
class PostgresNewConnectionTest(TestCase):
 
319
    """
 
320
    #17062: PostgreSQL shouldn't roll back SET TIME ZONE, even if the first
 
321
    transaction is rolled back.
 
322
    """
 
323
    @unittest.skipUnless(
 
324
        connection.vendor == 'postgresql',
 
325
        "This test applies only to PostgreSQL")
 
326
    def test_connect_and_rollback(self):
 
327
        new_connections = ConnectionHandler(settings.DATABASES)
 
328
        new_connection = new_connections[DEFAULT_DB_ALIAS]
 
329
        try:
 
330
            # Ensure the database default time zone is different than
 
331
            # the time zone in new_connection.settings_dict. We can
 
332
            # get the default time zone by reset & show.
 
333
            cursor = new_connection.cursor()
 
334
            cursor.execute("RESET TIMEZONE")
 
335
            cursor.execute("SHOW TIMEZONE")
 
336
            db_default_tz = cursor.fetchone()[0]
 
337
            new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC'
 
338
            new_connection.close()
 
339
 
 
340
            # Fetch a new connection with the new_tz as default
 
341
            # time zone, run a query and rollback.
 
342
            new_connection.settings_dict['TIME_ZONE'] = new_tz
 
343
            new_connection.enter_transaction_management()
 
344
            cursor = new_connection.cursor()
 
345
            new_connection.rollback()
 
346
 
 
347
            # Now let's see if the rollback rolled back the SET TIME ZONE.
 
348
            cursor.execute("SHOW TIMEZONE")
 
349
            tz = cursor.fetchone()[0]
 
350
            self.assertEqual(new_tz, tz)
 
351
        finally:
 
352
            try:
 
353
                new_connection.close()
 
354
            except DatabaseError:
 
355
                pass
 
356
 
 
357
 
 
358
# This test needs to run outside of a transaction, otherwise closing the
 
359
# connection would implicitly rollback and cause problems during teardown.
 
360
class ConnectionCreatedSignalTest(TransactionTestCase):
 
361
 
 
362
    available_apps = []
 
363
 
 
364
    # Unfortunately with sqlite3 the in-memory test database cannot be closed,
 
365
    # and so it cannot be re-opened during testing.
 
366
    @skipUnlessDBFeature('test_db_allows_multiple_connections')
 
367
    def test_signal(self):
 
368
        data = {}
 
369
 
 
370
        def receiver(sender, connection, **kwargs):
 
371
            data["connection"] = connection
 
372
 
 
373
        connection_created.connect(receiver)
 
374
        connection.close()
 
375
        connection.cursor()
 
376
        self.assertTrue(data["connection"].connection is connection.connection)
 
377
 
 
378
        connection_created.disconnect(receiver)
 
379
        data.clear()
 
380
        connection.cursor()
 
381
        self.assertTrue(data == {})
 
382
 
 
383
 
 
384
class EscapingChecks(TestCase):
 
385
    """
 
386
    All tests in this test case are also run with settings.DEBUG=True in
 
387
    EscapingChecksDebug test case, to also test CursorDebugWrapper.
 
388
    """
 
389
 
 
390
    # For Oracle, when you want to select a value, you need to specify the
 
391
    # special pseudo-table 'dual'; a select with no from clause is invalid.
 
392
    bare_select_suffix = " FROM DUAL" if connection.vendor == 'oracle' else ""
 
393
 
 
394
    def test_paramless_no_escaping(self):
 
395
        cursor = connection.cursor()
 
396
        cursor.execute("SELECT '%s'" + self.bare_select_suffix)
 
397
        self.assertEqual(cursor.fetchall()[0][0], '%s')
 
398
 
 
399
    def test_parameter_escaping(self):
 
400
        cursor = connection.cursor()
 
401
        cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
 
402
        self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
 
403
 
 
404
    @unittest.skipUnless(connection.vendor == 'sqlite',
 
405
                         "This is an sqlite-specific issue")
 
406
    def test_sqlite_parameter_escaping(self):
 
407
        #13648: '%s' escaping support for sqlite3
 
408
        cursor = connection.cursor()
 
409
        cursor.execute("select strftime('%s', date('now'))")
 
410
        response = cursor.fetchall()[0][0]
 
411
        # response should be an non-zero integer
 
412
        self.assertTrue(int(response))
 
413
 
 
414
@override_settings(DEBUG=True)
 
415
class EscapingChecksDebug(EscapingChecks):
 
416
    pass
 
417
 
 
418
 
 
419
class SqliteAggregationTests(TestCase):
 
420
    """
 
421
    #19360: Raise NotImplementedError when aggregating on date/time fields.
 
422
    """
 
423
    @unittest.skipUnless(connection.vendor == 'sqlite',
 
424
                         "No need to check SQLite aggregation semantics")
 
425
    def test_aggregation(self):
 
426
        for aggregate in (Sum, Avg, Variance, StdDev):
 
427
            self.assertRaises(NotImplementedError,
 
428
                models.Item.objects.all().aggregate, aggregate('time'))
 
429
            self.assertRaises(NotImplementedError,
 
430
                models.Item.objects.all().aggregate, aggregate('date'))
 
431
            self.assertRaises(NotImplementedError,
 
432
                models.Item.objects.all().aggregate, aggregate('last_modified'))
 
433
 
 
434
 
 
435
class SqliteChecks(TestCase):
 
436
 
 
437
    @unittest.skipUnless(connection.vendor == 'sqlite',
 
438
                         "No need to do SQLite checks")
 
439
    def test_convert_values_to_handle_null_value(self):
 
440
        database_operations = DatabaseOperations(connection)
 
441
        self.assertEqual(
 
442
            None,
 
443
            database_operations.convert_values(None, AutoField(primary_key=True))
 
444
        )
 
445
        self.assertEqual(
 
446
            None,
 
447
            database_operations.convert_values(None, DateField())
 
448
        )
 
449
        self.assertEqual(
 
450
            None,
 
451
            database_operations.convert_values(None, DateTimeField())
 
452
        )
 
453
        self.assertEqual(
 
454
            None,
 
455
            database_operations.convert_values(None, DecimalField())
 
456
        )
 
457
        self.assertEqual(
 
458
            None,
 
459
            database_operations.convert_values(None, IntegerField())
 
460
        )
 
461
        self.assertEqual(
 
462
            None,
 
463
            database_operations.convert_values(None, TimeField())
 
464
        )
 
465
 
 
466
 
 
467
class BackendTestCase(TestCase):
 
468
 
 
469
    def create_squares_with_executemany(self, args):
 
470
        self.create_squares(args, 'format', True)
 
471
 
 
472
    def create_squares(self, args, paramstyle, multiple):    
 
473
        cursor = connection.cursor()
 
474
        opts = models.Square._meta
 
475
        tbl = connection.introspection.table_name_converter(opts.db_table)
 
476
        f1 = connection.ops.quote_name(opts.get_field('root').column)
 
477
        f2 = connection.ops.quote_name(opts.get_field('square').column)
 
478
        if paramstyle=='format':
 
479
            query = 'INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (tbl, f1, f2)
 
480
        elif paramstyle=='pyformat':
 
481
            query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2)
 
482
        else:
 
483
            raise ValueError("unsupported paramstyle in test")
 
484
        if multiple:
 
485
            cursor.executemany(query, args)
 
486
        else:
 
487
            cursor.execute(query, args)
 
488
 
 
489
    def test_cursor_executemany(self):
 
490
        #4896: Test cursor.executemany
 
491
        args = [(i, i**2) for i in range(-5, 6)]
 
492
        self.create_squares_with_executemany(args)
 
493
        self.assertEqual(models.Square.objects.count(), 11)
 
494
        for i in range(-5, 6):
 
495
            square = models.Square.objects.get(root=i)
 
496
            self.assertEqual(square.square, i**2)
 
497
 
 
498
    def test_cursor_executemany_with_empty_params_list(self):
 
499
        #4765: executemany with params=[] does nothing
 
500
        args = []
 
501
        self.create_squares_with_executemany(args)
 
502
        self.assertEqual(models.Square.objects.count(), 0)
 
503
 
 
504
    def test_cursor_executemany_with_iterator(self):
 
505
        #10320: executemany accepts iterators
 
506
        args = iter((i, i**2) for i in range(-3, 2))
 
507
        self.create_squares_with_executemany(args)
 
508
        self.assertEqual(models.Square.objects.count(), 5)
 
509
 
 
510
        args = iter((i, i**2) for i in range(3, 7))
 
511
        with override_settings(DEBUG=True):
 
512
            # same test for DebugCursorWrapper
 
513
            self.create_squares_with_executemany(args)
 
514
        self.assertEqual(models.Square.objects.count(), 9)
 
515
 
 
516
    @skipUnlessDBFeature('supports_paramstyle_pyformat')
 
517
    def test_cursor_execute_with_pyformat(self):
 
518
        #10070: Support pyformat style passing of paramters
 
519
        args = {'root': 3, 'square': 9}
 
520
        self.create_squares(args, 'pyformat', multiple=False)
 
521
        self.assertEqual(models.Square.objects.count(), 1)
 
522
 
 
523
    @skipUnlessDBFeature('supports_paramstyle_pyformat')
 
524
    def test_cursor_executemany_with_pyformat(self):
 
525
        #10070: Support pyformat style passing of paramters
 
526
        args = [{'root': i, 'square': i**2} for i in range(-5, 6)]
 
527
        self.create_squares(args, 'pyformat', multiple=True)
 
528
        self.assertEqual(models.Square.objects.count(), 11)
 
529
        for i in range(-5, 6):
 
530
            square = models.Square.objects.get(root=i)
 
531
            self.assertEqual(square.square, i**2)
 
532
 
 
533
    @skipUnlessDBFeature('supports_paramstyle_pyformat')
 
534
    def test_cursor_executemany_with_pyformat_iterator(self):
 
535
        args = iter({'root': i, 'square': i**2} for i in range(-3, 2))
 
536
        self.create_squares(args, 'pyformat', multiple=True)
 
537
        self.assertEqual(models.Square.objects.count(), 5)
 
538
 
 
539
        args = iter({'root': i, 'square': i**2} for i in range(3, 7))
 
540
        with override_settings(DEBUG=True):
 
541
            # same test for DebugCursorWrapper
 
542
            self.create_squares(args, 'pyformat', multiple=True)
 
543
        self.assertEqual(models.Square.objects.count(), 9)
 
544
        
 
545
    def test_unicode_fetches(self):
 
546
        #6254: fetchone, fetchmany, fetchall return strings as unicode objects
 
547
        qn = connection.ops.quote_name
 
548
        models.Person(first_name="John", last_name="Doe").save()
 
549
        models.Person(first_name="Jane", last_name="Doe").save()
 
550
        models.Person(first_name="Mary", last_name="Agnelline").save()
 
551
        models.Person(first_name="Peter", last_name="Parker").save()
 
552
        models.Person(first_name="Clark", last_name="Kent").save()
 
553
        opts2 = models.Person._meta
 
554
        f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
 
555
        query2 = ('SELECT %s, %s FROM %s ORDER BY %s'
 
556
          % (qn(f3.column), qn(f4.column), connection.introspection.table_name_converter(opts2.db_table),
 
557
             qn(f3.column)))
 
558
        cursor = connection.cursor()
 
559
        cursor.execute(query2)
 
560
        self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
 
561
        self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
 
562
        self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
 
563
 
 
564
    def test_unicode_password(self):
 
565
        old_password = connection.settings_dict['PASSWORD']
 
566
        connection.settings_dict['PASSWORD'] = "françois"
 
567
        try:
 
568
            connection.cursor()
 
569
        except DatabaseError:
 
570
            # As password is probably wrong, a database exception is expected
 
571
            pass
 
572
        except Exception as e:
 
573
            self.fail("Unexpected error raised with unicode password: %s" % e)
 
574
        finally:
 
575
            connection.settings_dict['PASSWORD'] = old_password
 
576
 
 
577
    def test_database_operations_helper_class(self):
 
578
        # Ticket #13630
 
579
        self.assertTrue(hasattr(connection, 'ops'))
 
580
        self.assertTrue(hasattr(connection.ops, 'connection'))
 
581
        self.assertEqual(connection, connection.ops.connection)
 
582
 
 
583
    def test_cached_db_features(self):
 
584
        self.assertIn(connection.features.supports_transactions, (True, False))
 
585
        self.assertIn(connection.features.supports_stddev, (True, False))
 
586
        self.assertIn(connection.features.can_introspect_foreign_keys, (True, False))
 
587
 
 
588
    def test_duplicate_table_error(self):
 
589
        """ Test that creating an existing table returns a DatabaseError """
 
590
        cursor = connection.cursor()
 
591
        query = 'CREATE TABLE %s (id INTEGER);' % models.Article._meta.db_table
 
592
        with self.assertRaises(DatabaseError):
 
593
            cursor.execute(query)
 
594
 
 
595
 
 
596
# We don't make these tests conditional because that means we would need to
 
597
# check and differentiate between:
 
598
# * MySQL+InnoDB, MySQL+MYISAM (something we currently can't do).
 
599
# * if sqlite3 (if/once we get #14204 fixed) has referential integrity turned
 
600
#   on or not, something that would be controlled by runtime support and user
 
601
#   preference.
 
602
# verify if its type is django.database.db.IntegrityError.
 
603
class FkConstraintsTests(TransactionTestCase):
 
604
 
 
605
    available_apps = ['backends']
 
606
 
 
607
    def setUp(self):
 
608
        # Create a Reporter.
 
609
        self.r = models.Reporter.objects.create(first_name='John', last_name='Smith')
 
610
 
 
611
    def test_integrity_checks_on_creation(self):
 
612
        """
 
613
        Try to create a model instance that violates a FK constraint. If it
 
614
        fails it should fail with IntegrityError.
 
615
        """
 
616
        a1 = models.Article(headline="This is a test", pub_date=datetime.datetime(2005, 7, 27), reporter_id=30)
 
617
        try:
 
618
            a1.save()
 
619
        except IntegrityError:
 
620
            pass
 
621
        else:
 
622
            self.skipTest("This backend does not support integrity checks.")
 
623
        # Now that we know this backend supports integrity checks we make sure
 
624
        # constraints are also enforced for proxy models. Refs #17519
 
625
        a2 = models.Article(headline='This is another test', reporter=self.r,
 
626
                            pub_date=datetime.datetime(2012, 8, 3),
 
627
                            reporter_proxy_id=30)
 
628
        self.assertRaises(IntegrityError, a2.save)
 
629
 
 
630
    def test_integrity_checks_on_update(self):
 
631
        """
 
632
        Try to update a model instance introducing a FK constraint violation.
 
633
        If it fails it should fail with IntegrityError.
 
634
        """
 
635
        # Create an Article.
 
636
        models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r)
 
637
        # Retrieve it from the DB
 
638
        a1 = models.Article.objects.get(headline="Test article")
 
639
        a1.reporter_id = 30
 
640
        try:
 
641
            a1.save()
 
642
        except IntegrityError:
 
643
            pass
 
644
        else:
 
645
            self.skipTest("This backend does not support integrity checks.")
 
646
        # Now that we know this backend supports integrity checks we make sure
 
647
        # constraints are also enforced for proxy models. Refs #17519
 
648
        # Create another article
 
649
        r_proxy = models.ReporterProxy.objects.get(pk=self.r.pk)
 
650
        models.Article.objects.create(headline='Another article',
 
651
                                      pub_date=datetime.datetime(1988, 5, 15),
 
652
                                      reporter=self.r, reporter_proxy=r_proxy)
 
653
        # Retreive the second article from the DB
 
654
        a2 = models.Article.objects.get(headline='Another article')
 
655
        a2.reporter_proxy_id = 30
 
656
        self.assertRaises(IntegrityError, a2.save)
 
657
 
 
658
    def test_disable_constraint_checks_manually(self):
 
659
        """
 
660
        When constraint checks are disabled, should be able to write bad data without IntegrityErrors.
 
661
        """
 
662
        with transaction.atomic():
 
663
            # Create an Article.
 
664
            models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r)
 
665
            # Retrive it from the DB
 
666
            a = models.Article.objects.get(headline="Test article")
 
667
            a.reporter_id = 30
 
668
            try:
 
669
                connection.disable_constraint_checking()
 
670
                a.save()
 
671
                connection.enable_constraint_checking()
 
672
            except IntegrityError:
 
673
                self.fail("IntegrityError should not have occurred.")
 
674
            transaction.set_rollback(True)
 
675
 
 
676
    def test_disable_constraint_checks_context_manager(self):
 
677
        """
 
678
        When constraint checks are disabled (using context manager), should be able to write bad data without IntegrityErrors.
 
679
        """
 
680
        with transaction.atomic():
 
681
            # Create an Article.
 
682
            models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r)
 
683
            # Retrive it from the DB
 
684
            a = models.Article.objects.get(headline="Test article")
 
685
            a.reporter_id = 30
 
686
            try:
 
687
                with connection.constraint_checks_disabled():
 
688
                    a.save()
 
689
            except IntegrityError:
 
690
                self.fail("IntegrityError should not have occurred.")
 
691
            transaction.set_rollback(True)
 
692
 
 
693
    def test_check_constraints(self):
 
694
        """
 
695
        Constraint checks should raise an IntegrityError when bad data is in the DB.
 
696
        """
 
697
        with transaction.atomic():
 
698
            # Create an Article.
 
699
            models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r)
 
700
            # Retrive it from the DB
 
701
            a = models.Article.objects.get(headline="Test article")
 
702
            a.reporter_id = 30
 
703
            with connection.constraint_checks_disabled():
 
704
                a.save()
 
705
                with self.assertRaises(IntegrityError):
 
706
                    connection.check_constraints()
 
707
            transaction.set_rollback(True)
 
708
 
 
709
 
 
710
class ThreadTests(TestCase):
 
711
 
 
712
    def test_default_connection_thread_local(self):
 
713
        """
 
714
        Ensure that the default connection (i.e. django.db.connection) is
 
715
        different for each thread.
 
716
        Refs #17258.
 
717
        """
 
718
        # Map connections by id because connections with identical aliases
 
719
        # have the same hash.
 
720
        connections_dict = {}
 
721
        connection.cursor()
 
722
        connections_dict[id(connection)] = connection
 
723
 
 
724
        def runner():
 
725
            # Passing django.db.connection between threads doesn't work while
 
726
            # connections[DEFAULT_DB_ALIAS] does.
 
727
            from django.db import connections
 
728
            connection = connections[DEFAULT_DB_ALIAS]
 
729
            # Allow thread sharing so the connection can be closed by the
 
730
            # main thread.
 
731
            connection.allow_thread_sharing = True
 
732
            connection.cursor()
 
733
            connections_dict[id(connection)] = connection
 
734
        for x in range(2):
 
735
            t = threading.Thread(target=runner)
 
736
            t.start()
 
737
            t.join()
 
738
        # Check that each created connection got different inner connection.
 
739
        self.assertEqual(
 
740
            len(set(conn.connection for conn in connections_dict.values())),
 
741
            3)
 
742
        # Finish by closing the connections opened by the other threads (the
 
743
        # connection opened in the main thread will automatically be closed on
 
744
        # teardown).
 
745
        for conn in connections_dict.values():
 
746
            if conn is not connection:
 
747
                conn.close()
 
748
 
 
749
    def test_connections_thread_local(self):
 
750
        """
 
751
        Ensure that the connections are different for each thread.
 
752
        Refs #17258.
 
753
        """
 
754
        # Map connections by id because connections with identical aliases
 
755
        # have the same hash.
 
756
        connections_dict = {}
 
757
        for conn in connections.all():
 
758
            connections_dict[id(conn)] = conn
 
759
 
 
760
        def runner():
 
761
            from django.db import connections
 
762
            for conn in connections.all():
 
763
                # Allow thread sharing so the connection can be closed by the
 
764
                # main thread.
 
765
                conn.allow_thread_sharing = True
 
766
                connections_dict[id(conn)] = conn
 
767
        for x in range(2):
 
768
            t = threading.Thread(target=runner)
 
769
            t.start()
 
770
            t.join()
 
771
        self.assertEqual(len(connections_dict), 6)
 
772
        # Finish by closing the connections opened by the other threads (the
 
773
        # connection opened in the main thread will automatically be closed on
 
774
        # teardown).
 
775
        for conn in connections_dict.values():
 
776
            if conn is not connection:
 
777
                conn.close()
 
778
 
 
779
    def test_pass_connection_between_threads(self):
 
780
        """
 
781
        Ensure that a connection can be passed from one thread to the other.
 
782
        Refs #17258.
 
783
        """
 
784
        models.Person.objects.create(first_name="John", last_name="Doe")
 
785
 
 
786
        def do_thread():
 
787
            def runner(main_thread_connection):
 
788
                from django.db import connections
 
789
                connections['default'] = main_thread_connection
 
790
                try:
 
791
                    models.Person.objects.get(first_name="John", last_name="Doe")
 
792
                except Exception as e:
 
793
                    exceptions.append(e)
 
794
            t = threading.Thread(target=runner, args=[connections['default']])
 
795
            t.start()
 
796
            t.join()
 
797
 
 
798
        # Without touching allow_thread_sharing, which should be False by default.
 
799
        exceptions = []
 
800
        do_thread()
 
801
        # Forbidden!
 
802
        self.assertIsInstance(exceptions[0], DatabaseError)
 
803
 
 
804
        # If explicitly setting allow_thread_sharing to False
 
805
        connections['default'].allow_thread_sharing = False
 
806
        exceptions = []
 
807
        do_thread()
 
808
        # Forbidden!
 
809
        self.assertIsInstance(exceptions[0], DatabaseError)
 
810
 
 
811
        # If explicitly setting allow_thread_sharing to True
 
812
        connections['default'].allow_thread_sharing = True
 
813
        exceptions = []
 
814
        do_thread()
 
815
        # All good
 
816
        self.assertEqual(exceptions, [])
 
817
 
 
818
    def test_closing_non_shared_connections(self):
 
819
        """
 
820
        Ensure that a connection that is not explicitly shareable cannot be
 
821
        closed by another thread.
 
822
        Refs #17258.
 
823
        """
 
824
        # First, without explicitly enabling the connection for sharing.
 
825
        exceptions = set()
 
826
 
 
827
        def runner1():
 
828
            def runner2(other_thread_connection):
 
829
                try:
 
830
                    other_thread_connection.close()
 
831
                except DatabaseError as e:
 
832
                    exceptions.add(e)
 
833
            t2 = threading.Thread(target=runner2, args=[connections['default']])
 
834
            t2.start()
 
835
            t2.join()
 
836
        t1 = threading.Thread(target=runner1)
 
837
        t1.start()
 
838
        t1.join()
 
839
        # The exception was raised
 
840
        self.assertEqual(len(exceptions), 1)
 
841
 
 
842
        # Then, with explicitly enabling the connection for sharing.
 
843
        exceptions = set()
 
844
 
 
845
        def runner1():
 
846
            def runner2(other_thread_connection):
 
847
                try:
 
848
                    other_thread_connection.close()
 
849
                except DatabaseError as e:
 
850
                    exceptions.add(e)
 
851
            # Enable thread sharing
 
852
            connections['default'].allow_thread_sharing = True
 
853
            t2 = threading.Thread(target=runner2, args=[connections['default']])
 
854
            t2.start()
 
855
            t2.join()
 
856
        t1 = threading.Thread(target=runner1)
 
857
        t1.start()
 
858
        t1.join()
 
859
        # No exception was raised
 
860
        self.assertEqual(len(exceptions), 0)
 
861
 
 
862
 
 
863
class MySQLPKZeroTests(TestCase):
 
864
    """
 
865
    Zero as id for AutoField should raise exception in MySQL, because MySQL
 
866
    does not allow zero for automatic primary key.
 
867
    """
 
868
 
 
869
    @skipIfDBFeature('allows_primary_key_0')
 
870
    def test_zero_as_autoval(self):
 
871
        with self.assertRaises(ValueError):
 
872
            models.Square.objects.create(id=0, root=0, square=1)
 
873
 
 
874
 
 
875
class DBConstraintTestCase(TransactionTestCase):
 
876
 
 
877
    available_apps = ['backends']
 
878
 
 
879
    def test_can_reference_existant(self):
 
880
        obj = models.Object.objects.create()
 
881
        ref = models.ObjectReference.objects.create(obj=obj)
 
882
        self.assertEqual(ref.obj, obj)
 
883
 
 
884
        ref = models.ObjectReference.objects.get(obj=obj)
 
885
        self.assertEqual(ref.obj, obj)
 
886
 
 
887
    def test_can_reference_non_existant(self):
 
888
        self.assertFalse(models.Object.objects.filter(id=12345).exists())
 
889
        ref = models.ObjectReference.objects.create(obj_id=12345)
 
890
        ref_new = models.ObjectReference.objects.get(obj_id=12345)
 
891
        self.assertEqual(ref, ref_new)
 
892
 
 
893
        with self.assertRaises(models.Object.DoesNotExist):
 
894
            ref.obj
 
895
 
 
896
    def test_many_to_many(self):
 
897
        obj = models.Object.objects.create()
 
898
        obj.related_objects.create()
 
899
        self.assertEqual(models.Object.objects.count(), 2)
 
900
        self.assertEqual(obj.related_objects.count(), 1)
 
901
 
 
902
        intermediary_model = models.Object._meta.get_field_by_name("related_objects")[0].rel.through
 
903
        intermediary_model.objects.create(from_object_id=obj.id, to_object_id=12345)
 
904
        self.assertEqual(obj.related_objects.count(), 1)
 
905
        self.assertEqual(intermediary_model.objects.count(), 2)
 
906
 
 
907
 
 
908
class BackendUtilTests(TestCase):
 
909
 
 
910
    def test_format_number(self):
 
911
        """
 
912
        Test the format_number converter utility
 
913
        """
 
914
        def equal(value, max_d, places, result):
 
915
            self.assertEqual(format_number(Decimal(value), max_d, places), result)
 
916
 
 
917
        equal('0', 12, 3,
 
918
              '0.000')
 
919
        equal('0', 12, 8,
 
920
              '0.00000000')
 
921
        equal('1', 12, 9,
 
922
              '1.000000000')
 
923
        equal('0.00000000', 12, 8,
 
924
              '0.00000000')
 
925
        equal('0.000000004', 12, 8,
 
926
              '0.00000000')
 
927
        equal('0.000000008', 12, 8,
 
928
              '0.00000001')
 
929
        equal('0.000000000000000000999', 10, 8,
 
930
              '0.00000000')
 
931
        equal('0.1234567890', 12, 10,
 
932
              '0.1234567890')
 
933
        equal('0.1234567890', 12, 9,
 
934
              '0.123456789')
 
935
        equal('0.1234567890', 12, 8,
 
936
              '0.12345679')
 
937
        equal('0.1234567890', 12, 5,
 
938
              '0.12346')
 
939
        equal('0.1234567890', 12, 3,
 
940
              '0.123')
 
941
        equal('0.1234567890', 12, 1,
 
942
              '0.1')
 
943
        equal('0.1234567890', 12, 0,
 
944
              '0')