~launchpad-committers/storm/lp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#
# Copyright (c) 2006, 2007 Canonical
#
# Written by Gustavo Niemeyer <gustavo@niemeyer.net>
#
# This file is part of Storm Object Relational Mapper.
#
# Storm is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation; either version 2.1 of
# the License, or (at your option) any later version.
#
# Storm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
from __future__ import print_function

import os

from six.moves.urllib.parse import urlunsplit

from storm.databases.mysql import MySQL
from storm.database import create_database
from storm.expr import Column, Insert
from storm.uri import URI
from storm.variables import IntVariable, UnicodeVariable

from storm.tests.databases.base import (
    DatabaseTest, DatabaseDisconnectionTest, UnsupportedDatabaseTest)
from storm.tests.databases.proxy import ProxyTCPServer
from storm.tests.helper import TestHelper


def create_proxy_and_uri(uri):
    """Create a TCP proxy to a Unix-domain database identified by `uri`.

    The proxy forwards to the socket path stored in the URI's
    ``unix_socket`` option.

    :return: a ``(proxy, proxy_uri)`` pair.  ``proxy_uri`` is a ``mysql``
        URI pointing at the proxy's listening host and port, with the
        path ``/storm_test``.
    """
    socket_path = uri.options["unix_socket"]
    proxy = ProxyTCPServer(socket_path)
    host, port = proxy.server_address
    netloc = "%s:%s" % (host, port)
    # No query string or fragment; only scheme, netloc and path are set.
    url = urlunsplit(("mysql", netloc, "/storm_test", "", ""))
    return proxy, URI(url)


class MySQLTest(DatabaseTest, TestHelper):
    """Run the shared `DatabaseTest` suite against a MySQL database.

    The suite is only enabled when the ``STORM_MYSQL_URI`` environment
    variable is set; its value is the URI of the database to test.
    """

    # Tells the shared suite not to expect microsecond round-trips; note
    # that the datetime_test columns created below are declared without a
    # fractional-seconds precision.
    supports_microseconds = False

    def is_supported(self):
        # Enabled only when a (non-empty) MySQL URI is configured.
        return bool(os.environ.get("STORM_MYSQL_URI"))

    def create_database(self):
        self.database = create_database(os.environ["STORM_MYSQL_URI"])

    def create_tables(self):
        # Tables used by the shared suite.  All but `number` request the
        # InnoDB engine explicitly.
        self.connection.execute("CREATE TABLE number "
                                "(one INTEGER, two INTEGER, three INTEGER)")
        self.connection.execute("CREATE TABLE test "
                                "(id INT AUTO_INCREMENT PRIMARY KEY,"
                                " title VARCHAR(50)) ENGINE=InnoDB")
        self.connection.execute("CREATE TABLE datetime_test "
                                "(id INT AUTO_INCREMENT PRIMARY KEY,"
                                " dt TIMESTAMP, d DATE, t TIME, td TEXT) "
                                "ENGINE=InnoDB")
        self.connection.execute("CREATE TABLE bin_test "
                                "(id INT AUTO_INCREMENT PRIMARY KEY,"
                                " b BLOB) ENGINE=InnoDB")

    def test_wb_create_database(self):
        # White-box check: every URI component must end up in the keyword
        # arguments the MySQL backend will pass to its DB-API connect().
        database = create_database("mysql://un:pw@ht:12/db?unix_socket=us")
        self.assertIsInstance(database, MySQL)
        for key, value in [("db", "db"), ("host", "ht"), ("port", 12),
                           ("user", "un"), ("passwd", "pw"),
                           ("unix_socket", "us")]:
            self.assertEqual(database._connect_kwargs.get(key), value)

    def test_charset_defaults_to_utf8mb3(self):
        result = self.connection.execute("SELECT @@character_set_client")
        self.assertEqual(result.get_one(), ("utf8mb3",))

    def test_charset_option(self):
        # A `charset` URI option must override the client character set.
        uri = URI(os.environ["STORM_MYSQL_URI"])
        uri.options["charset"] = "ascii"
        database = create_database(uri)
        connection = database.connect()
        # This connection is separate from self.connection; close it so
        # each run does not leak a server connection.
        self.addCleanup(connection.close)
        result = connection.execute("SELECT @@character_set_client")
        self.assertEqual(result.get_one(), ("ascii",))

    def test_get_insert_identity(self):
        # Primary keys are filled in during execute() for MySQL
        pass

    def test_get_insert_identity_composed(self):
        # Primary keys are filled in during execute() for MySQL
        pass

    def test_execute_insert_auto_increment_primary_key(self):
        id_column = Column("id", "test")
        id_variable = IntVariable()
        title_column = Column("title", "test")
        title_variable = UnicodeVariable(u"testing")

        # This is not part of the table.  It is just used to show that
        # only one primary key variable is set from the insert ID.
        dummy_column = Column("dummy", "test")
        dummy_variable = IntVariable()

        insert = Insert({title_column: title_variable},
                        primary_columns=(id_column, dummy_column),
                        primary_variables=(id_variable, dummy_variable))
        self.connection.execute(insert)
        self.assertTrue(id_variable.is_defined())
        self.assertFalse(dummy_variable.is_defined())

        # The newly inserted row should have the maximum id value for
        # the table.
        result = self.connection.execute("SELECT MAX(id) FROM test")
        self.assertEqual(result.get_one()[0], id_variable.get())

    def test_mysql_specific_reserved_words(self):
        # Every word below must be recognised by the MySQL compiler as
        # reserved, so that identifiers using it get quoted.
        reserved_words = """
            accessible analyze asensitive before bigint binary blob call
            change condition current_user database databases day_hour
            day_microsecond day_minute day_second delayed deterministic
            distinctrow div dual each elseif enclosed escaped exit explain
            float4 float8 force fulltext high_priority hour_microsecond
            hour_minute hour_second if ignore index infile inout int1 int2
            int3 int4 int8 iterate keys kill leave limit linear lines load
            localtime localtimestamp lock long longblob longtext loop
            low_priority master_ssl_verify_server_cert mediumblob mediumint
            mediumtext middleint minute_microsecond minute_second mod modifies
            no_write_to_binlog optimize optionally out outfile purge range
            read_write reads regexp release rename repeat replace require
            return rlike schemas second_microsecond sensitive separator show
            spatial specific sql_big_result sql_calc_found_rows
            sql_small_result sqlexception sqlwarning ssl starting
            straight_join terminated tinyblob tinyint tinytext trigger undo
            unlock unsigned use utc_date utc_time utc_timestamp varbinary
            varcharacter while xor year_month zerofill
            """.split()
        for word in reserved_words:
            self.assertTrue(self.connection.compile.is_reserved_word(word),
                            "Word missing: %s" % (word,))


class MySQLUnsupportedTest(UnsupportedDatabaseTest, TestHelper):
    """Run the shared `UnsupportedDatabaseTest` checks for the MySQL backend."""

    # Name of the backend module under storm.databases.
    db_module_name = "mysql"
    # DB-API driver module(s) the MySQL backend relies on.
    dbapi_module_names = ["MySQLdb"]


class MySQLDisconnectionTest(DatabaseDisconnectionTest, TestHelper):
    """Run the shared disconnection tests against MySQL."""

    # Environment variables naming the database URIs; presumably read by
    # the shared base class (confirm there).
    environment_variable = "STORM_MYSQL_URI"
    host_environment_variable = "STORM_MYSQL_HOST_URI"
    # MySQL's standard TCP port.
    default_port = 3306

    def create_proxy(self, uri):
        """Return a proxy for `uri` (see `DatabaseDisconnectionMixin.create_proxy`).

        A URI carrying a ``unix_socket`` option gets a TCP proxy placed in
        front of that socket.  Any other URI is handled by the base class.
        """
        if "unix_socket" not in uri.options:
            return super(MySQLDisconnectionTest, self).create_proxy(uri)
        proxy, _unused_uri = create_proxy_and_uri(uri)
        return proxy