3
# Copyright 2014 Canonical Ltd.
5
# This program is free software: you can redistribute it and/or modify it
6
# under the terms of the GNU Affero General Public License version 3, as
7
# published by the Free Software Foundation.
9
# This program is distributed in the hope that it will be useful, but
10
# WITHOUT ANY WARRANTY; without even the implied warranties of
11
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
12
# PURPOSE. See the GNU Affero General Public License for more details.
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
"""Migration script for the Non Functional Stats Service.
19
This script is called when the service is installed, and any time the service
20
is upgraded. It follows a few simple steps to manage the database version:
22
1) Connect to the database, and read the version from the db_version table. If
23
it doesn't exist, assume the version is -1.
25
2) Scan the db_patches directory for files that end in '.sql'. The files must
26
be named 'NNN-{name}.sql'. The NNN is the numeric order of the patches,
29
3) Apply patches in order. Each patch is applied in a DB transaction, and after
30
each patch we increment the database revision number in the db_version
31
table. As a special case, if we just upgraded to version '0', we create
42
from nfss.database import get_connection_for_request
46
# Migration driver: obtains a database connection, reads the current and the
# maximum schema versions, reports the current one, and then applies every
# intermediate version in ascending order.
# NOTE(review): the header of the enclosing definition is not visible in this
# excerpt - confirm which scope these statements live in.
connection = get_connection_for_request()
47
current_version = get_current_database_version(connection)
48
maximum_version = get_maximum_version()
49
print("Current database version is: %d" % current_version)
50
# Only announce and run upgrades when the two versions differ.
if current_version != maximum_version:
52
"Database schema can be updated to version: %d" % maximum_version
54
# Step through the versions one at a time, from current_version + 1 up to and
# including maximum_version.
for new_version in range(current_version + 1, maximum_version + 1):
55
upgrade_database_to_version(connection, new_version)
59
# Read the schema version stored in the db_version table, using a cursor
# obtained from 'connection'.
# NOTE(review): several lines of this function (including the 'try:' header
# and the return statements) are not visible in this excerpt. The fallback
# version of -1 is taken from the inline comments below - confirm against the
# full function.
def get_current_database_version(connection):
60
cursor = connection.cursor()
62
cursor.execute('SELECT version FROM db_version;')
63
except psycopg2.ProgrammingError:
64
# table doesn't exist, we have a blank database, and need to start from
69
data = cursor.fetchone()
71
# table exists, but contains no data - not sure what this means
72
# as we should never get here. Let's assume version -1.
78
def report_error(message):
    """Write 'message' to standard error, terminated by a newline."""
    # TODO: replace with proper logging!
    line = message + '\n'
    sys.stderr.write(line)
83
def get_maximum_version():
    """Return the maximum version we can patch the database to.

    The version is parsed from the first three characters of the last entry
    returned by get_all_patch_paths().

    Returns:
        The highest available patch version as an int, or -1 when there are
        no patch files at all (the same value the module docstring uses for
        a database without a version).
    """
    patches = get_all_patch_paths()
    if not patches:
        # No patches available: report "nothing to upgrade to" instead of
        # failing with an opaque IndexError on patches[-1].
        return -1
    # NOTE(review): this relies on get_all_patch_paths() returning its
    # entries in ascending order - confirm.
    return int(patches[-1][:3])
89
# Select the entries of the database patch directory whose names match the
# patch naming pattern (three digits, a dash, then a name containing "sql").
# NOTE(review): the lines wrapping the lambda and the os.listdir() call are
# not visible in this excerpt, so how the matches are collected and ordered
# cannot be confirmed from here.
def get_all_patch_paths():
90
patches_path = get_database_patch_dir()
93
# NOTE(review): the '.' before 'sql' is unescaped (matches any character) and
# the pattern has no end anchor, so a name such as '001-x.sql.bak' also
# matches - consider r'\d{3}-.*\.sql$'.
lambda f: re.match(r'\d{3}-.*.sql', f),
94
os.listdir(patches_path)
99
# Resolve the patch file for 'version_number' using the pattern "%03d-*.sql"
# (zero-padded, three-digit version prefix) together with the directory
# returned by get_database_patch_dir().
# NOTE(review): the matching call, the raise statements and the return are not
# visible in this excerpt. The two error messages below suggest that both
# "more than one match" and "no match" are treated as errors - confirm, and if
# so, mention the no-match case in the docstring as well.
def get_patch_file_path_for_version(version_number):
100
"""Given a database version number, get the path to a patch that upgrades
101
the schema to that version.
103
raise a RuntimeError if more than one file matches.
107
get_database_patch_dir(),
108
"%03d-*.sql" % version_number
113
"More than one file matched for version %d: %r" % (
114
version_number, matches)
118
"No patch file found for version %d" % version_number
123
# Return the absolute path of the database patch directory, computed from the
# directory that contains this file (__file__).
# NOTE(review): the remaining path components passed alongside
# os.path.dirname(__file__) are not visible in this excerpt.
def get_database_patch_dir():
124
return os.path.abspath(
126
os.path.dirname(__file__),
133
def upgrade_database_to_version(connection, new_version):
134
"""Do the work to upgrade the database to version 'new_version'.
136
This function starts a transaction on 'connection', and will either
137
commit or rollback the transaction before the function ends.
139
On success the function returns None, on failure the function raises an
142
Note: Only call this function if the database is at 'new_version' - 1.
145
patch_file = get_patch_file_path_for_version(new_version)
146
print("Upgrading to version %s ..." % new_version)
147
cursor = connection.cursor()
149
with open(patch_file, 'r') as patch:
150
cursor.execute(patch.read())
152
cursor.execute('CREATE TABLE db_version (version integer);')
153
cursor.execute('INSERT INTO db_version (version) VALUES (0);')
155
cursor.execute('UPDATE db_version SET version=%d'
157
except Exception as err:
158
print("Error: %s" % err)
159
connection.rollback()