# Copyright 2008-2015 Canonical
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# For further info, check http://launchpad.net/filesync-server

"""Infrastructure for tracking activity on database connections."""
class NoDatabaseName(Exception):
    """Could not detect database name in connect arguments."""
class DatabaseNotEnabled(Exception):
    """An attempt to use a disabled database was made."""

    def __init__(self, dbname):
        """Build the error for an access to the database ``dbname``."""
        super(DatabaseNotEnabled, self).__init__(
            "An attempt was made to use %s without enabling it." % dbname)
        # Expose the offending name so handlers need not parse the message.
        self.dbname = dbname
class DatabaseWatcher:
    """Watch database connections for use and commits.

    While installed, the watcher replaces ``connect`` on the module given to
    the constructor. Connections created through it are wrapped so that
    statement executions and commits are reported back here. A database
    must be enabled with ``enable()`` before it can be connected to or
    used, and callbacks registered with ``hook()`` are told when a database
    is first used or first committed to (until ``reset()``).
    """

    def __init__(self, module):
        # Module whose ``connect`` attribute is replaced by install().
        self._module = module
        # The module's original ``connect``; None while not installed.
        self._orig_connect = None
        # Names of databases that may currently be connected to and used.
        self._enabled_databases = set()
        # Databases that executed a statement or committed since reset().
        self._used_databases = set()
        # Databases that committed since reset().
        self._dirty_databases = set()
        # Maps database name -> set of callbacks registered via hook().
        self._callbacks = {}

    def install(self):
        """Install the database watcher by wrapping the module's connect.

        :raises AssertionError: if the watcher is already installed.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; the exception type callers see is unchanged.
        if self._orig_connect is not None:
            raise AssertionError("Already installed")
        self._orig_connect = self._module.connect
        self._module.connect = self._connect

    def uninstall(self):
        """Uninstall the database watcher, restoring the original connect.

        :raises AssertionError: if the watcher is not installed.
        """
        if self._orig_connect is None:
            raise AssertionError("Watcher not installed")
        self._module.connect = self._orig_connect
        self._orig_connect = None

    def enable(self, dbname):
        """Enable use of a given database.

        :raises AssertionError: if the database is already enabled.
        """
        if dbname in self._enabled_databases:
            raise AssertionError("%s already enabled" % dbname)
        self._enabled_databases.add(dbname)

    def disable(self, dbname):
        """Disable use of a given database.

        :raises AssertionError: if the database is not enabled.
        """
        if dbname not in self._enabled_databases:
            raise AssertionError("%s not enabled" % dbname)
        self._enabled_databases.remove(dbname)

    def _check_enabled(self, dbname):
        """Raise an exception if access to a database has not been enabled.

        :raises DatabaseNotEnabled: if ``dbname`` is not enabled.
        """
        if dbname not in self._enabled_databases:
            raise DatabaseNotEnabled(dbname)

    @staticmethod
    def _extract_dbname(args, kwargs):
        """Return the database name found in ``connect()`` arguments, or None.

        A ``database`` (or ``dbname``) keyword argument wins. Otherwise the
        name is read from a ``dbname=...`` entry of the DSN string, taken
        from the first positional argument or the ``dsn`` keyword.
        """
        import re
        for key in ('database', 'dbname'):
            if key in kwargs:
                return kwargs[key]
        dsn = args[0] if args else kwargs.get('dsn')
        if not dsn:
            return None
        match = re.search(r'dbname=(\w+)', dsn)
        return match.group(1) if match else None

    def _connect(self, *args, **kwargs):
        """Create a new connection object, noting the database name.

        Stands in for the module's ``connect`` while installed.

        :raises NoDatabaseName: if no database name is found in the
            arguments.
        :raises DatabaseNotEnabled: if the database is not enabled.
        """
        dbname = self._extract_dbname(args, kwargs)
        if dbname is None:
            raise NoDatabaseName("Could not determine database name.")
        # Checked before connecting, so a disabled database is never opened.
        self._check_enabled(dbname)
        return ConnectionWrapper(
            self, dbname, self._orig_connect(*args, **kwargs))

    def hook(self, dbname, callback):
        """Register a callback to be notified about a particular database.

        The callback will be called with arguments (dbname, commit).
        """
        self._callbacks.setdefault(dbname, set()).add(callback)

    def unhook(self, dbname, callback):
        """Deregister a callback that was registered with hook().

        :raises KeyError: if the callback is not registered for ``dbname``.
        """
        callbacks = self._callbacks[dbname]
        callbacks.remove(callback)
        # Drop emptied sets so unused database names do not accumulate.
        if not callbacks:
            del self._callbacks[dbname]

    def reset(self, dbname):
        """Reset the used and dirty states for a database."""
        self._used_databases.discard(dbname)
        self._dirty_databases.discard(dbname)

    def _notify(self, dbname, commit=False):
        """Notify interested parties about activity on a database.

        Every callback hooked for ``dbname`` is called as
        ``callback(dbname, commit)``.
        """
        # Iterate over a snapshot so callbacks may hook()/unhook() safely.
        for callback in list(self._callbacks.get(dbname, ())):
            callback(dbname, commit)

    def _saw_execute(self, dbname):
        """Report statement execution for a database.

        Callbacks are notified (with commit=False) only on the first use
        since the last reset().
        """
        if dbname not in self._used_databases:
            self._used_databases.add(dbname)
            self._notify(dbname)

    def _saw_commit(self, dbname):
        """Report a commit on a database.

        Callbacks are notified (with commit=True) only on the first commit
        since the last reset().
        """
        if dbname not in self._dirty_databases:
            self._dirty_databases.add(dbname)
            self._notify(dbname, True)
        # If we've committed to the DB, mark it as used too.
        self._used_databases.add(dbname)
class ConnectionWrapper:
    """A wrapper around a DB-API connection that reports commits.

    Attribute reads and writes not handled here are forwarded to the
    wrapped connection.
    """

    def __init__(self, dbwatcher, dbname, real_connection):
        # Store directly in __dict__: a normal assignment would go through
        # __setattr__ below and land on the real connection instead.
        self.__dict__['_dbwatcher'] = dbwatcher
        self.__dict__['_dbname'] = dbname
        self.__dict__['_real_connection'] = real_connection

    def commit(self):
        """Commit the transaction and notify the watcher.

        The watcher's enabled-check runs before committing, so its exception
        (if any) propagates without touching the real connection.
        """
        self._dbwatcher._check_enabled(self._dbname)
        try:
            self._real_connection.commit()
        finally:
            # Reported even when the commit raised.
            # NOTE(review): presumably because a failed commit may still
            # have affected the database — confirm this is intended.
            self._dbwatcher._saw_commit(self._dbname)

    def cursor(self, *args, **kwargs):
        """Create a cursor that notifies the watcher of statement execution.

        Any arguments are passed through to the real connection's cursor().
        """
        return CursorWrapper(
            self, self._real_connection.cursor(*args, **kwargs))

    def __getattr__(self, attr):
        """Pass attribute access through to the real connection."""
        # Only reached for names missing from the instance. If the wrapper
        # was built without __init__ (e.g. copy/unpickle), looking up
        # _real_connection here would otherwise recurse forever.
        if attr == '_real_connection':
            raise AttributeError(attr)
        return getattr(self._real_connection, attr)

    def __setattr__(self, attr, value):
        """Pass attribute access through to the real connection."""
        setattr(self._real_connection, attr, value)
class CursorWrapper:
    """A wrapper around a DB-API cursor that reports executes.

    Attribute reads and writes not handled here are forwarded to the
    wrapped cursor.
    """

    def __init__(self, connection, real_cursor):
        # Store directly in __dict__: __setattr__ forwards to the cursor.
        self.__dict__['_connection'] = connection
        self.__dict__['_real_cursor'] = real_cursor

    def _watched(self, func, args, kwargs):
        """Call ``func(*args, **kwargs)``, reporting it to the watcher.

        The enabled-check runs first; the execution is reported afterwards
        even if ``func`` raised.
        """
        connection = self._connection
        watcher = connection._dbwatcher
        watcher._check_enabled(connection._dbname)
        try:
            return func(*args, **kwargs)
        finally:
            watcher._saw_execute(connection._dbname)

    def execute(self, *args, **kwargs):
        """Execute a statement and notify the watcher."""
        return self._watched(self._real_cursor.execute, args, kwargs)

    def executemany(self, *args, **kwargs):
        """Execute a statement and notify the watcher."""
        return self._watched(self._real_cursor.executemany, args, kwargs)

    def __iter__(self):
        """Iterate over the real cursor's result rows.

        Needed because ``iter()`` looks ``__iter__`` up on the type, so
        ``__getattr__`` alone does not forward it.
        """
        return iter(self._real_cursor)

    def __getattr__(self, attr):
        """Pass attribute access through to the real cursor."""
        # Guard against infinite recursion when built without __init__.
        if attr == '_real_cursor':
            raise AttributeError(attr)
        return getattr(self._real_cursor, attr)

    def __setattr__(self, attr, value):
        """Pass attribute access through to the real cursor."""
        setattr(self._real_cursor, attr, value)