~ubuntu-branches/debian/jessie/sqlalchemy/jessie

« back to all changes in this revision

Viewing changes to test/lib/assertsql.py

  • Committer: Package Import Robot
  • Author(s): Piotr Ożarowski, Jakub Wilk, Piotr Ożarowski
  • Date: 2013-07-06 20:53:52 UTC
  • mfrom: (1.4.23) (16.1.17 experimental)
  • Revision ID: package-import@ubuntu.com-20130706205352-ryppl1eto3illd79
Tags: 0.8.2-1
[ Jakub Wilk ]
* Use canonical URIs for Vcs-* fields.

[ Piotr Ożarowski ]
* New upstream release
* Upload to unstable
* Build-depend on python3-all instead of python3-all-dev; extensions are not
  built for Python 3.x

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
 
2
 
from sqlalchemy.interfaces import ConnectionProxy
3
 
from sqlalchemy.engine.default import DefaultDialect
4
 
from sqlalchemy.engine.base import Connection
5
 
from sqlalchemy import util
6
 
import re
7
 
 
8
 
class AssertRule(object):
    """Base class for assertion rules applied to executed statements.

    The two ``process_*`` hooks do nothing here; ``is_consumed`` and
    ``rule_passed`` must be supplied by subclasses.
    """

    def process_execute(self, clauseelement, *multiparams, **params):
        """Hook for an execute event; the base implementation ignores it."""
        pass

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Hook for a cursor-execute event; the base implementation
        ignores it."""
        pass

    def is_consumed(self):
        """Return True if this rule has been consumed, False if not.

        Should raise an AssertionError if this rule's condition has
        definitely failed.

        """
        raise NotImplementedError()

    def rule_passed(self):
        """Return True if the last test of this rule passed, False if
        failed, None if no test was applied."""
        raise NotImplementedError()

    def consume_final(self):
        """Return True if this rule has been consumed.

        Should raise an AssertionError if this rule's condition has not
        been consumed or has failed.

        """
        # Go through rule_passed() instead of reading a private ``_result``
        # attribute: this base class never defines ``_result``, so any
        # subclass implementing only the documented hooks would otherwise
        # fail here with AttributeError.
        if self.rule_passed() is None:
            # Raised explicitly so the check is not stripped under ``-O``.
            raise AssertionError('Rule has not been consumed')
        return self.is_consumed()
44
 
 
45
 
class SQLMatchRule(AssertRule):
    """Rule whose outcome is one stored result plus an error message.

    ``_result`` starts out as None and ``_errmsg`` as an empty string.
    """

    def __init__(self):
        self._result = None
        self._errmsg = ""

    def rule_passed(self):
        """Return the stored result; None until one has been recorded."""
        return self._result

    def is_consumed(self):
        """Return False while no result is stored, True for a truthy one.

        A stored result that is falsy but not None fails an assertion
        whose message is ``_errmsg``.
        """
        if self._result is not None:
            assert self._result, self._errmsg
            return True

        return False
60
 
 
61
 
class ExactSQL(SQLMatchRule):
    """Rule that compares the received statement string, and optionally
    its parameters, for equality with the expected ones."""

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.sql = sql
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Record whether the statement held by ``context`` matches.

        Does nothing when ``context`` is falsy.  ``self.params`` may be a
        callable taking the context; a non-list value is wrapped in a
        list before being compared with ``context.compiled_parameters``.
        On a mismatch the error message is stored on the rule.
        """
        if not context:
            return

        received_sql = _process_engine_statement(
            context.unicode_statement, context)
        received_params = context.compiled_parameters

        # TODO: remove this step once all unit tests are migrated, as
        # ExactSQL should really be *exact* SQL
        expected_sql = _process_assertion_statement(self.sql, context)
        matched = received_sql == expected_sql

        if not self.params:
            params = {}
        else:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            matched = matched and \
                params == context.compiled_parameters

        self._result = matched
        if not matched:
            self._errmsg = (
                'Testing for exact statement %r exact params %r, '
                'received %r with params %r'
                % (expected_sql, params, received_sql, received_params))
99
 
 
100
 
 
101
 
class RegexSQL(SQLMatchRule):
    """Rule that matches the received statement against a regular
    expression and checks a subset of its parameters."""

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Record whether the statement held by ``context`` matches.

        Does nothing when ``context`` is falsy.  The pattern is applied
        with ``match``, i.e. anchored at the start of the statement.
        Expected parameter dicts are paired positionally with the
        received ones, and only the expected keys are compared.
        """
        if not context:
            return

        received_sql = _process_engine_statement(
            context.unicode_statement, context)
        received_params = context.compiled_parameters
        matched = bool(self.regex.match(received_sql))

        if not self.params:
            params = {}
        else:
            params = self.params
            if util.callable(params):
                params = params(context)
            if not isinstance(params, list):
                params = [params]

            # do a positive compare only
            for expected, received in zip(params, received_params):
                if any(key not in received or received[key] != value
                       for key, value in expected.iteritems()):
                    matched = False

        self._result = matched
        if not matched:
            self._errmsg = (
                'Testing for regex %r partial params %r, received %r '
                'with params %r'
                % (self.orig_regex, params, received_sql,
                   received_params))
142
 
 
143
 
class CompiledSQL(SQLMatchRule):
    """Rule that recompiles the executed statement with the default
    dialect and compares the resulting string and its parameter sets."""

    def __init__(self, statement, params):
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Record whether the statement held by ``context`` matches.

        Does nothing when ``context`` is falsy.  ``self.params`` may be a
        callable taking the context; a non-list value is wrapped in a
        list.  Each expected dict is filled in with defaults from
        ``context.compiled.params`` and must equal one received parameter
        set; received sets left unmatched also count as a mismatch.  On
        a mismatch the message is printed and stored on the rule.
        """
        if not context:
            return
        _received_parameters = list(context.compiled_parameters)

        # recompile from the context, using the default dialect
        compiled = context.compiled.statement.compile(
            dialect=DefaultDialect(),
            column_keys=context.compiled.column_keys)
        _received_statement = re.sub(r'\n', '', str(compiled))
        equivalent = self.statement == _received_statement

        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]

            # Work from a copy.  Popping entries off ``params`` itself
            # would drain the list stored on ``self.params``, so a later
            # evaluation of this same rule would see no expected
            # parameters and skip the check altogether.
            all_params = list(params)
            all_received = list(_received_parameters)
            for expected in all_params:
                param = dict(expected)
                for k, v in context.compiled.params.iteritems():
                    param.setdefault(k, v)
                if param not in _received_parameters:
                    equivalent = False
                    break
                _received_parameters.remove(param)
            if _received_parameters:
                equivalent = False
        else:
            all_params = {}
            all_received = []

        self._result = equivalent
        if not self._result:
            self._errmsg = (
                'Testing for compiled statement %r partial params %r, '
                'received %r with params %r'
                % (self.statement, all_params, _received_statement,
                   all_received))
            # Same text the original echoed; a single parenthesized
            # argument prints identically as a statement or a function.
            print(self._errmsg)
199
 
 
200
 
class CountStatements(AssertRule):
    """Rule that counts execute events and compares the total with an
    expected count when finalized."""

    def __init__(self, count):
        self._statement_count = 0
        self.count = count

    def process_execute(self, clauseelement, *multiparams, **params):
        """Add one to the running statement count."""
        self._statement_count = self._statement_count + 1

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Cursor-level events are ignored by this rule."""
        pass

    def is_consumed(self):
        """Always False; the count is only checked in consume_final()."""
        return False

    def consume_final(self):
        """Return True if the counted total equals the expected count;
        otherwise fail an assertion naming both numbers."""
        counts_match = self.count == self._statement_count
        assert counts_match, (
            'desired statement count %d does not match %d'
            % (self.count, self._statement_count))
        return True
221
 
 
222
 
class AllOf(AssertRule):
    """Composite rule holding a set of rules; every event is forwarded
    to each rule still in the set."""

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        """Forward the execute event to every remaining rule."""
        for member in self.rules:
            member.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Forward the cursor-execute event to every remaining rule."""
        for member in self.rules:
            member.process_cursor_execute(
                statement, parameters, context, executemany)

    def is_consumed(self):
        """Drop the first remaining rule that reports a pass.

        Returns True when no rules remain (before or after the removal)
        and False when some still do.  Fails an assertion if rules
        remain but none of them passed.
        """
        if not self.rules:
            return True

        # a rule passed, move on
        satisfied = next(
            (member for member in list(self.rules)
             if member.rule_passed()),
            None)
        if satisfied is None:
            assert False, 'No assertion rules were satisfied for statement'
        self.rules.remove(satisfied)
        return len(self.rules) == 0

    def consume_final(self):
        """Return True only when every rule has been removed."""
        return len(self.rules) == 0
248
 
 
249
 
def _process_engine_statement(query, context):
    """Return ``query`` normalized for comparison.

    Under Jython the query is first converted with ``unicode()``.  When
    the engine is named 'mssql', a trailing '; select scope_identity()'
    is cut off.  All newline characters are then removed.
    """
    if util.jython:
        # oracle+zxjdbc passes a PyStatement when returning into
        query = unicode(query)

    identity_suffix = '; select scope_identity()'
    if context.engine.name == 'mssql' \
            and query.endswith(identity_suffix):
        query = query[:-len(identity_suffix)]

    return re.sub(r'\n', '', query)
260
 
 
261
 
def _process_assertion_statement(query, context):
262
 
    paramstyle = context.dialect.paramstyle
263
 
    if paramstyle == 'named':
264
 
        pass
265
 
    elif paramstyle =='pyformat':
266
 
        query = re.sub(r':([\w_]+)', r"%(\1)s", query)
267
 
    else:
268
 
        # positional params
269
 
        repl = None
270
 
        if paramstyle=='qmark':
271
 
            repl = "?"
272
 
        elif paramstyle=='format':
273
 
            repl = r"%s"
274
 
        elif paramstyle=='numeric':
275
 
            repl = None
276
 
        query = re.sub(r':([\w_]+)', repl, query)
277
 
 
278
 
    return query
279
 
 
280
 
class SQLAssert(object):
    """Holds an ordered list of assertion rules and feeds execution
    events to the rule at the front of the list.

    ``rules`` is None while no rules are installed; in that state the
    ``execute`` and ``cursor_execute`` handlers do nothing.
    """

    rules = None

    def add_rules(self, rules):
        """Install ``rules`` (copied into a new list) as the pending
        rules."""
        self.rules = list(rules)

    def statement_complete(self):
        """Finalize every pending rule.

        Raises AssertionError if any rule's ``consume_final()`` returns
        a falsy value.
        """
        for rule in self.rules:
            if not rule.consume_final():
                raise AssertionError(
                    'All statements are complete, but pending '
                    'assertion rules remain')

    def clear_rules(self):
        """Return to the "no rules installed" state."""
        # Assign None rather than ``del self.rules``: the delete raises
        # AttributeError when no rules were ever installed on this
        # instance, or when clear_rules() is called twice in a row.
        self.rules = None

    def execute(self, conn, clauseelement, multiparams, params, result):
        """Pass an execute event to the first pending rule and drop the
        rule once it reports itself consumed.

        Raises AssertionError if rules were installed but none remain.
        """
        if self.rules is None:
            return
        if not self.rules:
            raise AssertionError(
                'All rules have been exhausted, but further '
                'statements remain')
        rule = self.rules[0]
        rule.process_execute(clauseelement, *multiparams, **params)
        if rule.is_consumed():
            self.rules.pop(0)

    def cursor_execute(self, conn, cursor, statement, parameters,
                       context, executemany):
        """Pass a cursor-execute event to the first pending rule, if
        there is one."""
        if self.rules:
            rule = self.rules[0]
            rule.process_cursor_execute(statement, parameters, context,
                                        executemany)
314
 
 
315
 
# Module-level SQLAssert instance, created when this module is imported.
asserter = SQLAssert()
316