~ubuntu-branches/debian/jessie/sqlalchemy/jessie

« back to all changes in this revision

Viewing changes to lib/sqlalchemy/testing/assertsql.py

  • Committer: Package Import Robot
  • Author(s): Piotr Ożarowski, Jakub Wilk, Piotr Ożarowski
  • Date: 2013-07-06 20:53:52 UTC
  • mfrom: (1.4.23) (16.1.17 experimental)
  • Revision ID: package-import@ubuntu.com-20130706205352-ryppl1eto3illd79
Tags: 0.8.2-1
[ Jakub Wilk ]
* Use canonical URIs for Vcs-* fields.

[ Piotr Ożarowski ]
* New upstream release
* Upload to unstable
* Build depend on python3-all instead of -dev, extensions are not built for
  Python 3.X 

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
 
 
2
from ..engine.default import DefaultDialect
 
3
from .. import util
 
4
import re
 
5
 
 
6
 
 
7
class AssertRule(object):
    """Base class for rules that inspect executed statements.

    ``SQLAssert`` feeds each statement to the rule at the head of its
    list through :meth:`process_execute` and
    :meth:`process_cursor_execute`, then asks the rule whether it has
    been consumed.
    """

    def process_execute(self, clauseelement, *multiparams, **params):
        """Observe a statement at the execute level; no-op by default."""
        return None

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Observe a statement at the cursor level; no-op by default."""
        return None

    def is_consumed(self):
        """Tell whether this rule has been consumed.

        Implementations return True or False, and should raise an
        AssertionError when the rule's condition has definitely failed.
        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError()

    def rule_passed(self):
        """Report the outcome of the most recent test of this rule.

        Implementations return True if it passed, False if it failed and
        None if no test was applied.  The base implementation always
        raises NotImplementedError.
        """
        raise NotImplementedError()

    def consume_final(self):
        """Return the result of :meth:`is_consumed` for a tested rule.

        Reads ``self._result``, which this base class does not set
        itself.  Raises an AssertionError when ``_result`` is None (the
        rule was never tested); otherwise delegates to
        :meth:`is_consumed`, which may itself raise if the rule failed.
        """
        assert self._result is not None, 'Rule has not been consumed'
        return self.is_consumed()
 
43
 
 
44
 
 
45
class SQLMatchRule(AssertRule):
    """Rule that records the outcome of a single statement match.

    Subclasses store the match outcome in ``_result`` (None until a
    statement has been tested) and a failure description in
    ``_errmsg``.
    """

    def __init__(self):
        self._errmsg = ""
        self._result = None

    def rule_passed(self):
        """Return the stored match outcome (None if nothing was tested)."""
        return self._result

    def is_consumed(self):
        """Return False if untested, True if the match succeeded.

        Raises an AssertionError carrying ``_errmsg`` when a test was
        applied and it failed.
        """
        if self._result is not None:
            assert self._result, self._errmsg
            return True
        return False
 
60
 
 
61
 
 
62
class ExactSQL(SQLMatchRule):
    """Rule matching the executed SQL string against an expected string.

    ``sql`` is the expected statement, written with ``:name``
    placeholders.  ``params``, if given, is the expected compiled
    parameters: a list, a single item (wrapped into a list), or a
    callable that receives the execution context and returns either.
    """

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.params = params
        self.sql = sql

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Compare the statement held by ``context`` with the expectation.

        Does nothing when ``context`` is falsy.  Otherwise stores the
        outcome in ``_result`` and, on mismatch, a description in
        ``_errmsg``.
        """
        if not context:
            return

        received_sql = _process_engine_statement(
            context.unicode_statement, context)
        received_params = context.compiled_parameters

        # TODO: remove this step once all unit tests are migrated, as
        # ExactSQL should really be *exact* SQL
        expected_sql = _process_assertion_statement(self.sql, context)

        matches = received_sql == expected_sql
        if not self.params:
            expected_params = {}
        else:
            if util.callable(self.params):
                expected_params = self.params(context)
            else:
                expected_params = self.params
            if not isinstance(expected_params, list):
                expected_params = [expected_params]
            # parameters are only compared once the SQL itself matched
            if matches:
                matches = expected_params == context.compiled_parameters

        self._result = matches
        if not matches:
            template = ('Testing for exact statement %r exact params %r, '
                        'received %r with params %r')
            self._errmsg = template % (expected_sql, expected_params,
                                       received_sql, received_params)
 
100
 
 
101
 
 
102
class RegexSQL(SQLMatchRule):
    """Rule matching the executed SQL string against a regular expression.

    ``regex`` is compiled once at construction and matched against the
    start of the received statement (``re.match`` semantics).
    ``params``, if given, is a dict, a list of dicts, or a callable that
    receives the execution context and returns either; only the keys
    present in the expected dicts are compared ("positive compare").
    """

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        # keep the uncompiled pattern for use in the failure message
        self.orig_regex = regex
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Test the statement held by ``context`` against the regex.

        Does nothing when ``context`` is falsy.  Otherwise stores the
        outcome in ``_result`` and, on mismatch, a description in
        ``_errmsg``.
        """
        if not context:
            return
        _received_statement = \
            _process_engine_statement(context.unicode_statement,
                                      context)
        _received_parameters = context.compiled_parameters
        equivalent = bool(self.regex.match(_received_statement))
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]

            # do a positive compare only: every expected key must be
            # present with an equal value; extra received keys are ignored

            for param, received in zip(params, _received_parameters):
                # items() works on both Python 2 and Python 3 dicts,
                # unlike the Python-2-only iteritems()
                for k, v in param.items():
                    if k not in received or received[k] != v:
                        equivalent = False
                        break
        else:
            params = {}
        self._result = equivalent
        if not self._result:
            self._errmsg = \
                'Testing for regex %r partial params %r, received %r '\
                'with params %r' % (self.orig_regex, params,
                                    _received_statement,
                                    _received_parameters)
 
143
 
 
144
 
 
145
class CompiledSQL(SQLMatchRule):
    """Rule matching a statement as recompiled with the default dialect.

    ``statement`` is the expected SQL string.  ``params``, if given, is
    a dict, a list of dicts, or a callable that receives the execution
    context and returns either.
    """

    def __init__(self, statement, params=None):
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Recompile the executed construct and compare it.

        Does nothing when ``context`` is falsy.  Otherwise stores the
        outcome in ``_result``; on mismatch a description is stored in
        ``_errmsg`` and also printed.
        """
        if not context:
            return
        from sqlalchemy.schema import _DDLCompiles
        # work on a copy: matched entries are removed from this list below
        _received_parameters = list(context.compiled_parameters)

        # recompile from the context, using the default dialect

        if isinstance(context.compiled.statement, _DDLCompiles):
            compiled = \
                context.compiled.statement.compile(dialect=DefaultDialect())
        else:
            compiled = \
                context.compiled.statement.compile(
                    dialect=DefaultDialect(),
                    column_keys=context.compiled.column_keys)
        # newlines and tabs are dropped before comparing
        _received_statement = re.sub(r'[\n\t]', '', str(compiled))
        equivalent = self.statement == _received_statement
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            else:
                params = list(params)
            # untouched copies, kept for the failure message
            all_params = list(params)
            all_received = list(_received_parameters)
            while params:
                param = dict(params.pop(0))
                # fill keys the expectation omits from the compiled
                # statement's own parameters; items() works on both
                # Python 2 and 3, unlike the Python-2-only iteritems()
                for k, v in context.compiled.params.items():
                    param.setdefault(k, v)
                if param not in _received_parameters:
                    equivalent = False
                    break
                else:
                    _received_parameters.remove(param)
            # any received parameter set left unmatched is a mismatch
            if _received_parameters:
                equivalent = False
        else:
            all_params = {}
            all_received = []
        self._result = equivalent
        if not self._result:
            self._errmsg = \
                'Testing for compiled statement %r partial params %r, '\
                'received %r with params %r' % (self.statement,
                    all_params, _received_statement, all_received)
            # single-argument print(...) emits the same text under the
            # Python 2 print statement and the Python 3 print function
            print(self._errmsg)
 
208
 
 
209
class CountStatements(AssertRule):
    """Rule asserting how many statements were executed in total.

    Every ``process_execute`` call is counted; the tally is compared
    with ``count`` only in :meth:`consume_final`.
    """

    def __init__(self, count):
        self._statement_count = 0
        self.count = count

    def process_execute(self, clauseelement, *multiparams, **params):
        """Count one more executed statement."""
        self._statement_count = self._statement_count + 1

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Cursor-level executions are not counted."""
        return None

    def is_consumed(self):
        """Always False: this rule is only checked in ``consume_final``."""
        return False

    def consume_final(self):
        """Return True, or raise AssertionError if the counts differ."""
        expected = self.count
        seen = self._statement_count
        assert expected == seen, \
            'desired statement count %d does not match %d' % (expected, seen)
        return True
 
230
 
 
231
 
 
232
class AllOf(AssertRule):
    """Composite rule holding a set of rules that must each pass.

    Every incoming statement is forwarded to all remaining rules; a
    rule is dropped from the set once it reports that it passed.
    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        """Forward the execute-level event to every remaining rule."""
        for member in self.rules:
            member.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Forward the cursor-level event to every remaining rule."""
        for member in self.rules:
            member.process_cursor_execute(
                statement, parameters, context, executemany)

    def is_consumed(self):
        """Drop the first rule that passed; True once none remain.

        Raises an AssertionError when rules remain but none of them
        reports a pass.
        """
        pending = self.rules
        if not pending:
            return True
        # iterate over a snapshot because the set is mutated on a hit
        for candidate in tuple(pending):
            if not candidate.rule_passed():
                continue
            # a rule passed, move on
            pending.remove(candidate)
            return not pending
        assert False, 'No assertion rules were satisfied for statement'

    def consume_final(self):
        """Return True only when every rule has been removed."""
        return not self.rules
 
258
 
 
259
 
 
260
def _process_engine_statement(query, context):
    """Normalise a statement received from the engine for comparison.

    Drops the trailing ``'; select scope_identity()'`` when the engine
    is named ``'mssql'``, and removes all newline characters.
    """
    if util.jython:
        # oracle+zxjdbc passes a PyStatement when returning into
        query = unicode(query)

    scope_identity_suffix = '; select scope_identity()'
    is_mssql = context.engine.name == 'mssql'
    if is_mssql and query.endswith(scope_identity_suffix):
        query = query[:-len(scope_identity_suffix)]

    return re.sub(r'\n', '', query)
 
271
 
 
272
 
 
273
def _process_assertion_statement(query, context):
 
274
    paramstyle = context.dialect.paramstyle
 
275
    if paramstyle == 'named':
 
276
        pass
 
277
    elif paramstyle == 'pyformat':
 
278
        query = re.sub(r':([\w_]+)', r"%(\1)s", query)
 
279
    else:
 
280
        # positional params
 
281
        repl = None
 
282
        if paramstyle == 'qmark':
 
283
            repl = "?"
 
284
        elif paramstyle == 'format':
 
285
            repl = r"%s"
 
286
        elif paramstyle == 'numeric':
 
287
            repl = None
 
288
        query = re.sub(r':([\w_]+)', repl, query)
 
289
 
 
290
    return query
 
291
 
 
292
 
 
293
class SQLAssert(object):
    """Drives a list of assertion rules against executed statements.

    ``rules`` is None while no rules are installed; in that state
    ``execute`` and ``cursor_execute`` do nothing.
    """

    # class-level default: no rules installed
    rules = None

    def add_rules(self, rules):
        """Install ``rules`` (copied into a list) as the pending rules."""
        self.rules = list(rules)

    def statement_complete(self):
        """Finalise every pending rule.

        Raises an AssertionError when a rule's ``consume_final``
        returns a false value (the rule may also raise on its own).
        """
        for rule in self.rules:
            if not rule.consume_final():
                assert False, \
                    'All statements are complete, but pending '\
                    'assertion rules remain'

    def clear_rules(self):
        """Remove any installed rules, reverting to the class default.

        Safe to call when no rules were installed or when rules were
        already cleared; a bare ``del self.rules`` would raise
        AttributeError in that case because ``rules`` then exists only
        as a class attribute.
        """
        self.__dict__.pop('rules', None)

    def execute(self, conn, clauseelement, multiparams, params, result):
        """Feed an execute-level event to the rule at the head of the list.

        The head rule is popped once it reports being consumed.  Raises
        an AssertionError when rules were installed but none are left.
        ``conn`` and ``result`` are not used here.
        """
        if self.rules is not None:
            if not self.rules:
                assert False, \
                    'All rules have been exhausted, but further '\
                    'statements remain'
            rule = self.rules[0]
            rule.process_execute(clauseelement, *multiparams, **params)
            if rule.is_consumed():
                self.rules.pop(0)

    def cursor_execute(self, conn, cursor, statement, parameters,
                       context, executemany):
        """Feed a cursor-level event to the rule at the head of the list.

        Does nothing when no rules are pending.  ``conn`` and ``cursor``
        are not used here.
        """
        if self.rules:
            rule = self.rules[0]
            rule.process_cursor_execute(statement, parameters, context,
                                        executemany)
 
327
 
 
328
# Module-level SQLAssert instance; it starts with no rules installed.
asserter = SQLAssert()