~ubuntu-branches/ubuntu/vivid/ironic/vivid-updates

« back to all changes in this revision

Viewing changes to ironic/openstack/common/db/sqlalchemy/test_migrations.py

  • Committer: Package Import Robot
  • Author(s): Chuck Short
  • Date: 2014-03-06 13:23:35 UTC
  • mfrom: (1.1.2)
  • Revision ID: package-import@ubuntu.com-20140306132335-5b49ji56jffxvtn4
Tags: 2014.1~b3-0ubuntu1
* New upstream release:
  - debian/patches/fix-requirements.patch: Dropped no longer needed.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright 2010-2011 OpenStack Foundation
 
2
# Copyright 2012-2013 IBM Corp.
 
3
# All Rights Reserved.
 
4
#
 
5
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
 
6
#    not use this file except in compliance with the License. You may obtain
 
7
#    a copy of the License at
 
8
#
 
9
#         http://www.apache.org/licenses/LICENSE-2.0
 
10
#
 
11
#    Unless required by applicable law or agreed to in writing, software
 
12
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 
13
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 
14
#    License for the specific language governing permissions and limitations
 
15
#    under the License.
 
16
 
 
17
import functools
 
18
import os
 
19
import subprocess
 
20
 
 
21
import lockfile
 
22
from six import moves
 
23
import sqlalchemy
 
24
import sqlalchemy.exc
 
25
 
 
26
from ironic.openstack.common.gettextutils import _
 
27
from ironic.openstack.common import log as logging
 
28
from ironic.openstack.common.py3kcompat import urlutils
 
29
from ironic.openstack.common import test
 
30
 
 
31
LOG = logging.getLogger(__name__)
 
32
 
 
33
 
 
34
def _get_connect_string(backend, user, passwd, database):
 
35
    """Get database connection
 
36
 
 
37
    Try to get a connection with a very specific set of values, if we get
 
38
    these then we'll run the tests, otherwise they are skipped
 
39
    """
 
40
    if backend == "postgres":
 
41
        backend = "postgresql+psycopg2"
 
42
    elif backend == "mysql":
 
43
        backend = "mysql+mysqldb"
 
44
    else:
 
45
        raise Exception("Unrecognized backend: '%s'" % backend)
 
46
 
 
47
    return ("%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
 
48
            % {'backend': backend, 'user': user, 'passwd': passwd,
 
49
               'database': database})
 
50
 
 
51
 
 
52
def _is_backend_avail(backend, user, passwd, database):
 
53
    try:
 
54
        connect_uri = _get_connect_string(backend, user, passwd, database)
 
55
        engine = sqlalchemy.create_engine(connect_uri)
 
56
        connection = engine.connect()
 
57
    except Exception:
 
58
        # intentionally catch all to handle exceptions even if we don't
 
59
        # have any backend code loaded.
 
60
        return False
 
61
    else:
 
62
        connection.close()
 
63
        engine.dispose()
 
64
        return True
 
65
 
 
66
 
 
67
def _have_mysql(user, passwd, database):
 
68
    present = os.environ.get('TEST_MYSQL_PRESENT')
 
69
    if present is None:
 
70
        return _is_backend_avail('mysql', user, passwd, database)
 
71
    return present.lower() in ('', 'true')
 
72
 
 
73
 
 
74
def _have_postgresql(user, passwd, database):
 
75
    present = os.environ.get('TEST_POSTGRESQL_PRESENT')
 
76
    if present is None:
 
77
        return _is_backend_avail('postgres', user, passwd, database)
 
78
    return present.lower() in ('', 'true')
 
79
 
 
80
 
 
81
def get_db_connection_info(conn_pieces):
 
82
    database = conn_pieces.path.strip('/')
 
83
    loc_pieces = conn_pieces.netloc.split('@')
 
84
    host = loc_pieces[1]
 
85
 
 
86
    auth_pieces = loc_pieces[0].split(':')
 
87
    user = auth_pieces[0]
 
88
    password = ""
 
89
    if len(auth_pieces) > 1:
 
90
        password = auth_pieces[1].strip()
 
91
 
 
92
    return (user, password, database, host)
 
93
 
 
94
 
 
95
def _set_db_lock(lock_path=None, lock_prefix=None):
 
96
    def decorator(f):
 
97
        @functools.wraps(f)
 
98
        def wrapper(*args, **kwargs):
 
99
            try:
 
100
                path = lock_path or os.environ.get("IRONIC_LOCK_PATH")
 
101
                lock = lockfile.FileLock(os.path.join(path, lock_prefix))
 
102
                with lock:
 
103
                    LOG.debug(_('Got lock "%s"') % f.__name__)
 
104
                    return f(*args, **kwargs)
 
105
            finally:
 
106
                LOG.debug(_('Lock released "%s"') % f.__name__)
 
107
        return wrapper
 
108
    return decorator
 
109
 
 
110
 
 
111
class BaseMigrationTestCase(test.BaseTestCase):
 
112
    """Base class fort testing of migration utils."""
 
113
 
 
114
    def __init__(self, *args, **kwargs):
 
115
        super(BaseMigrationTestCase, self).__init__(*args, **kwargs)
 
116
 
 
117
        self.DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__),
 
118
                                                'test_migrations.conf')
 
119
        # Test machines can set the TEST_MIGRATIONS_CONF variable
 
120
        # to override the location of the config file for migration testing
 
121
        self.CONFIG_FILE_PATH = os.environ.get('TEST_MIGRATIONS_CONF',
 
122
                                               self.DEFAULT_CONFIG_FILE)
 
123
        self.test_databases = {}
 
124
        self.migration_api = None
 
125
 
 
126
    def setUp(self):
 
127
        super(BaseMigrationTestCase, self).setUp()
 
128
 
 
129
        # Load test databases from the config file. Only do this
 
130
        # once. No need to re-run this on each test...
 
131
        LOG.debug('config_path is %s' % self.CONFIG_FILE_PATH)
 
132
        if os.path.exists(self.CONFIG_FILE_PATH):
 
133
            cp = moves.configparser.RawConfigParser()
 
134
            try:
 
135
                cp.read(self.CONFIG_FILE_PATH)
 
136
                defaults = cp.defaults()
 
137
                for key, value in defaults.items():
 
138
                    self.test_databases[key] = value
 
139
            except moves.configparser.ParsingError as e:
 
140
                self.fail("Failed to read test_migrations.conf config "
 
141
                          "file. Got error: %s" % e)
 
142
        else:
 
143
            self.fail("Failed to find test_migrations.conf config "
 
144
                      "file.")
 
145
 
 
146
        self.engines = {}
 
147
        for key, value in self.test_databases.items():
 
148
            self.engines[key] = sqlalchemy.create_engine(value)
 
149
 
 
150
        # We start each test case with a completely blank slate.
 
151
        self._reset_databases()
 
152
 
 
153
    def tearDown(self):
 
154
        # We destroy the test data store between each test case,
 
155
        # and recreate it, which ensures that we have no side-effects
 
156
        # from the tests
 
157
        self._reset_databases()
 
158
        super(BaseMigrationTestCase, self).tearDown()
 
159
 
 
160
    def execute_cmd(self, cmd=None):
 
161
        process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
 
162
                                   stderr=subprocess.STDOUT)
 
163
        output = process.communicate()[0]
 
164
        LOG.debug(output)
 
165
        self.assertEqual(0, process.returncode,
 
166
                         "Failed to run: %s\n%s" % (cmd, output))
 
167
 
 
168
    def _reset_pg(self, conn_pieces):
 
169
        (user, password, database, host) = get_db_connection_info(conn_pieces)
 
170
        os.environ['PGPASSWORD'] = password
 
171
        os.environ['PGUSER'] = user
 
172
        # note(boris-42): We must create and drop database, we can't
 
173
        # drop database which we have connected to, so for such
 
174
        # operations there is a special database template1.
 
175
        sqlcmd = ("psql -w -U %(user)s -h %(host)s -c"
 
176
                  " '%(sql)s' -d template1")
 
177
 
 
178
        sql = ("drop database if exists %s;") % database
 
179
        droptable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
 
180
        self.execute_cmd(droptable)
 
181
 
 
182
        sql = ("create database %s;") % database
 
183
        createtable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
 
184
        self.execute_cmd(createtable)
 
185
 
 
186
        os.unsetenv('PGPASSWORD')
 
187
        os.unsetenv('PGUSER')
 
188
 
 
189
    @_set_db_lock(lock_prefix='migration_tests-')
 
190
    def _reset_databases(self):
 
191
        for key, engine in self.engines.items():
 
192
            conn_string = self.test_databases[key]
 
193
            conn_pieces = urlutils.urlparse(conn_string)
 
194
            engine.dispose()
 
195
            if conn_string.startswith('sqlite'):
 
196
                # We can just delete the SQLite database, which is
 
197
                # the easiest and cleanest solution
 
198
                db_path = conn_pieces.path.strip('/')
 
199
                if os.path.exists(db_path):
 
200
                    os.unlink(db_path)
 
201
                # No need to recreate the SQLite DB. SQLite will
 
202
                # create it for us if it's not there...
 
203
            elif conn_string.startswith('mysql'):
 
204
                # We can execute the MySQL client to destroy and re-create
 
205
                # the MYSQL database, which is easier and less error-prone
 
206
                # than using SQLAlchemy to do this via MetaData...trust me.
 
207
                (user, password, database, host) = \
 
208
                    get_db_connection_info(conn_pieces)
 
209
                sql = ("drop database if exists %(db)s; "
 
210
                       "create database %(db)s;") % {'db': database}
 
211
                cmd = ("mysql -u \"%(user)s\" -p\"%(password)s\" -h %(host)s "
 
212
                       "-e \"%(sql)s\"") % {'user': user, 'password': password,
 
213
                                            'host': host, 'sql': sql}
 
214
                self.execute_cmd(cmd)
 
215
            elif conn_string.startswith('postgresql'):
 
216
                self._reset_pg(conn_pieces)
 
217
 
 
218
 
 
219
class WalkVersionsMixin(object):
 
220
    def _walk_versions(self, engine=None, snake_walk=False, downgrade=True):
 
221
        # Determine latest version script from the repo, then
 
222
        # upgrade from 1 through to the latest, with no data
 
223
        # in the databases. This just checks that the schema itself
 
224
        # upgrades successfully.
 
225
 
 
226
        # Place the database under version control
 
227
        self.migration_api.version_control(engine, self.REPOSITORY,
 
228
                                           self.INIT_VERSION)
 
229
        self.assertEqual(self.INIT_VERSION,
 
230
                         self.migration_api.db_version(engine,
 
231
                                                       self.REPOSITORY))
 
232
 
 
233
        LOG.debug('latest version is %s' % self.REPOSITORY.latest)
 
234
        versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
 
235
 
 
236
        for version in versions:
 
237
            # upgrade -> downgrade -> upgrade
 
238
            self._migrate_up(engine, version, with_data=True)
 
239
            if snake_walk:
 
240
                downgraded = self._migrate_down(
 
241
                    engine, version - 1, with_data=True)
 
242
                if downgraded:
 
243
                    self._migrate_up(engine, version)
 
244
 
 
245
        if downgrade:
 
246
            # Now walk it back down to 0 from the latest, testing
 
247
            # the downgrade paths.
 
248
            for version in reversed(versions):
 
249
                # downgrade -> upgrade -> downgrade
 
250
                downgraded = self._migrate_down(engine, version - 1)
 
251
 
 
252
                if snake_walk and downgraded:
 
253
                    self._migrate_up(engine, version)
 
254
                    self._migrate_down(engine, version - 1)
 
255
 
 
256
    def _migrate_down(self, engine, version, with_data=False):
 
257
        try:
 
258
            self.migration_api.downgrade(engine, self.REPOSITORY, version)
 
259
        except NotImplementedError:
 
260
            # NOTE(sirp): some migrations, namely release-level
 
261
            # migrations, don't support a downgrade.
 
262
            return False
 
263
 
 
264
        self.assertEqual(
 
265
            version, self.migration_api.db_version(engine, self.REPOSITORY))
 
266
 
 
267
        # NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target'
 
268
        # version). So if we have any downgrade checks, they need to be run for
 
269
        # the previous (higher numbered) migration.
 
270
        if with_data:
 
271
            post_downgrade = getattr(
 
272
                self, "_post_downgrade_%03d" % (version + 1), None)
 
273
            if post_downgrade:
 
274
                post_downgrade(engine)
 
275
 
 
276
        return True
 
277
 
 
278
    def _migrate_up(self, engine, version, with_data=False):
 
279
        """migrate up to a new version of the db.
 
280
 
 
281
        We allow for data insertion and post checks at every
 
282
        migration version with special _pre_upgrade_### and
 
283
        _check_### functions in the main test.
 
284
        """
 
285
        # NOTE(sdague): try block is here because it's impossible to debug
 
286
        # where a failed data migration happens otherwise
 
287
        try:
 
288
            if with_data:
 
289
                data = None
 
290
                pre_upgrade = getattr(
 
291
                    self, "_pre_upgrade_%03d" % version, None)
 
292
                if pre_upgrade:
 
293
                    data = pre_upgrade(engine)
 
294
 
 
295
            self.migration_api.upgrade(engine, self.REPOSITORY, version)
 
296
            self.assertEqual(version,
 
297
                             self.migration_api.db_version(engine,
 
298
                                                           self.REPOSITORY))
 
299
            if with_data:
 
300
                check = getattr(self, "_check_%03d" % version, None)
 
301
                if check:
 
302
                    check(engine, data)
 
303
        except Exception:
 
304
            LOG.error("Failed to migrate to version %s on engine %s" %
 
305
                      (version, engine))
 
306
            raise