~hadware/magicicada-server/trusty-support

« back to all changes in this revision

Viewing changes to src/backends/db/dbwatcher.py

  • Committer: Facundo Batista
  • Date: 2015-08-05 13:10:02 UTC
  • Revision ID: facundo@taniquetil.com.ar-20150805131002-he7b7k704d8o7js6
First released version.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright 2008-2015 Canonical
 
2
#
 
3
# This program is free software: you can redistribute it and/or modify
 
4
# it under the terms of the GNU Affero General Public License as
 
5
# published by the Free Software Foundation, either version 3 of the
 
6
# License, or (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU Affero General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU Affero General Public License
 
14
# along with this program. If not, see <http://www.gnu.org/licenses/>.
 
15
#
 
16
# For further info, check  http://launchpad.net/filesync-server
 
17
 
 
18
"""Infrastructure for tracking activity on database connections."""
 
19
 
 
20
__metaclass__ = type
 
21
 
 
22
import re
 
23
 
 
24
 
 
25
class NoDatabaseName(Exception):
 
26
    """Could not detect database name in connect arguments."""
 
27
 
 
28
 
 
29
class DatabaseNotEnabled(Exception):
 
30
    """An attempt to use a disabled database was made."""
 
31
 
 
32
    def __init__(self, dbname):
 
33
        super(DatabaseNotEnabled, self).__init__(
 
34
            "An attempt was made to use %s without enabling it." % dbname)
 
35
 
 
36
 
 
37
class DatabaseWatcher:
 
38
    """Watch database connections for use and commits."""
 
39
 
 
40
    def __init__(self, module):
 
41
        self._module = module
 
42
        self._orig_connect = None
 
43
        self._enabled_databases = set()
 
44
        self._used_databases = set()
 
45
        self._dirty_databases = set()
 
46
        self._callbacks = {}
 
47
 
 
48
    def install(self):
 
49
        """Install the database watcher."""
 
50
        assert self._orig_connect is None, "Already installed"
 
51
        self._orig_connect = self._module.connect
 
52
        self._module.connect = self._connect
 
53
 
 
54
    def uninstall(self):
 
55
        """Uninstall the database watcher."""
 
56
        assert self._orig_connect is not None, "Watcher not installed"
 
57
        self._module.connect = self._orig_connect
 
58
        self._orig_connect = None
 
59
 
 
60
    def enable(self, dbname):
 
61
        """Enable use of a given database."""
 
62
        if dbname in self._enabled_databases:
 
63
            raise AssertionError("%s already enabled" % dbname)
 
64
        self._enabled_databases.add(dbname)
 
65
 
 
66
    def disable(self, dbname):
 
67
        """Disable use of a given database."""
 
68
        if dbname not in self._enabled_databases:
 
69
            raise AssertionError("%s not enabled" % dbname)
 
70
        self._enabled_databases.remove(dbname)
 
71
 
 
72
    def _check_enabled(self, dbname):
 
73
        """Raise an exception if access to a database has not been enabled."""
 
74
        if dbname not in self._enabled_databases:
 
75
            raise DatabaseNotEnabled(dbname)
 
76
 
 
77
    def _connect(self, *args, **kwargs):
 
78
        """Create a new connection object, noting the database name."""
 
79
        dbname = None
 
80
        if 'database' in kwargs:
 
81
            dbname = kwargs['database']
 
82
        elif len(args) > 0:
 
83
            match = re.search(r'dbname=(\w+)', args[0])
 
84
            if match:
 
85
                dbname = match.group(1)
 
86
        if dbname is None:
 
87
            raise NoDatabaseName("Could not determine database name.")
 
88
        self._check_enabled(dbname)
 
89
        return ConnectionWrapper(
 
90
            self, dbname, self._orig_connect(*args, **kwargs))
 
91
 
 
92
    def hook(self, dbname, callback):
 
93
        """Register a callback to be notified about a particular database.
 
94
 
 
95
        The callback will be called with arguments (dbname, commit).
 
96
        """
 
97
        self._callbacks.setdefault(dbname, set()).add(callback)
 
98
 
 
99
    def unhook(self, dbname, callback):
 
100
        """Deregister a callback that was registered with hook()."""
 
101
        try:
 
102
            self._callbacks[dbname].remove(callback)
 
103
        except KeyError:
 
104
            raise
 
105
 
 
106
    def reset(self, dbname):
 
107
        """Reset the used and dirty states for a database."""
 
108
        self._used_databases.discard(dbname)
 
109
        self._dirty_databases.discard(dbname)
 
110
 
 
111
    def _notify(self, dbname, commit=False):
 
112
        """Notify interested parties about activity on a database."""
 
113
        for callback in self._callbacks.get(dbname, set()):
 
114
            callback(dbname, commit)
 
115
 
 
116
    def _saw_execute(self, dbname):
 
117
        """Report statement execution for a database."""
 
118
        if dbname not in self._used_databases:
 
119
            self._used_databases.add(dbname)
 
120
            self._notify(dbname)
 
121
 
 
122
    def _saw_commit(self, dbname):
 
123
        """Report a commit on a database."""
 
124
        if dbname not in self._dirty_databases:
 
125
            self._dirty_databases.add(dbname)
 
126
            self._notify(dbname, True)
 
127
        # If we've committed to the DB, mark it as used too.
 
128
        self._used_databases.add(dbname)
 
129
 
 
130
 
 
131
class ConnectionWrapper:
 
132
    """A wrapper around a DB-API connection that reports commits."""
 
133
 
 
134
    def __init__(self, dbwatcher, dbname, real_connection):
 
135
        self.__dict__['_dbwatcher'] = dbwatcher
 
136
        self.__dict__['_dbname'] = dbname
 
137
        self.__dict__['_real_connection'] = real_connection
 
138
 
 
139
    def commit(self):
 
140
        """Commit the transaction and notify the watcher."""
 
141
        self._dbwatcher._check_enabled(self._dbname)
 
142
        try:
 
143
            self._real_connection.commit()
 
144
        finally:
 
145
            self._dbwatcher._saw_commit(self._dbname)
 
146
 
 
147
    def cursor(self):
 
148
        """Create a cursor that notifies the watcher of statement execution."""
 
149
        return CursorWrapper(self, self._real_connection.cursor())
 
150
 
 
151
    def __getattr__(self, attr):
 
152
        """Pass attribute access through to the real connection."""
 
153
        return getattr(self._real_connection, attr)
 
154
 
 
155
    def __setattr__(self, attr, value):
 
156
        """Pass attribute access through to the real connection."""
 
157
        setattr(self._real_connection, attr, value)
 
158
 
 
159
 
 
160
class CursorWrapper:
 
161
    """A wrapper around a DB-API cursor that reports executes."""
 
162
 
 
163
    def __init__(self, connection, real_cursor):
 
164
        self.__dict__['_connection'] = connection
 
165
        self.__dict__['_real_cursor'] = real_cursor
 
166
 
 
167
    def execute(self, *args, **kwargs):
 
168
        """Execute a statement and notify the watcher."""
 
169
        self._connection._dbwatcher._check_enabled(self._connection._dbname)
 
170
        try:
 
171
            return self._real_cursor.execute(*args, **kwargs)
 
172
        finally:
 
173
            self._connection._dbwatcher._saw_execute(self._connection._dbname)
 
174
 
 
175
    def executemany(self, *args, **kwargs):
 
176
        """Execute a statement and notify the watcher."""
 
177
        self._connection._dbwatcher._check_enabled(self._connection._dbname)
 
178
        try:
 
179
            return self._real_cursor.executemany(*args, **kwargs)
 
180
        finally:
 
181
            self._connection._dbwatcher._saw_execute(self._connection._dbname)
 
182
 
 
183
    def __getattr__(self, attr):
 
184
        """Pass attribute access through to the real cursor."""
 
185
        return getattr(self._real_cursor, attr)
 
186
 
 
187
    def __setattr__(self, attr, value):
 
188
        """Pass attribute access through to the real cursor."""
 
189
        setattr(self._real_cursor, attr, value)