1
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
2
# See LICENSE for details.
5
"""Tests for twisted.enterprise reflectors."""
7
from twisted.trial import unittest
11
from twisted.internet import reactor, interfaces, defer
12
from twisted.enterprise.row import RowObject
13
from twisted.enterprise.reflector import *
14
from twisted.enterprise.sqlreflector import SQLReflector
15
from twisted.enterprise import util
16
from twisted.test.test_adbapi import makeSQLTests
18
# Names of the parent and child tables the reflector tests operate on.
tableName, childTableName = "testTable", "childTable"
21
# Row object mapped onto the table named by tableName ("testTable"); its
# key column is the varchar column "key_string".
# NOTE(review): the interleaved bare numbers appear to be original line
# numbers. Original lines 23 and 26 are missing from this copy, so the
# rowColumns list below is never closed. Restore them from the upstream
# file rather than guessing the missing column definitions.
class TestRow(RowObject):
22
# (column name, SQL type) pairs; randomizeRow/rowMatches iterate these.
rowColumns = [("key_string", "varchar"),
24
("another_column", "varchar"),
25
("Column4", "varchar"),
27
rowKeyColumns = [("key_string", "varchar")]
28
rowTableName = tableName
30
# Row object mapped onto childTable, keyed on the integer column "childId".
# Its test_key column refers back to a TestRow's key_string: the tests set
# test_key to the parent's key ("first") and load children with
# parentRow=parent.
# NOTE(review): original lines 32, 34-36 and 42-43 are missing from this
# copy, so both rowColumns and rowForeignKeys are unterminated here. The
# remaining elements of the foreign-key tuple cannot be recovered from
# what is visible.
class ChildRow(RowObject):
31
rowColumns = [("childId", "int"),
33
("test_key", "varchar"),
37
rowKeyColumns = [("childId", "int")]
38
rowTableName = childTableName
39
# Presumably (parent table, columns in this table, matching parent
# columns, ...) -- confirm against the RowObject documentation.
rowForeignKeys = [(tableName,
40
[("test_key","varchar")],
41
[("key_string","varchar")],
44
# CREATE TABLE statements for the two test tables, executed by
# SQLReflectorTestBase.createReflector. The visible column names match
# TestRow.rowColumns and ChildRow.rowColumns.
# NOTE(review): original lines 47, 49-53 and 56-64 are missing. Neither
# string literal is complete in this copy: the closing quotes and several
# column definitions are gone. Do not add comments inside the literals,
# because they would become part of the SQL text.
main_table_schema = """
45
CREATE TABLE testTable (
46
key_string varchar(64),
48
another_column varchar(64),
54
child_table_schema = """
55
CREATE TABLE childTable (
65
# Assign a random value to every non-key column of `row` and report what
# was written.
#
# row                -- a RowObject; its rowColumns are iterated as
#                       (name, type) pairs.
# nulls_ok           -- when true, the `elif nulls_ok and ...` branch
#                       fires for about 1 column in 10. Its body is
#                       missing here; presumably it assigns None (NULL).
# trailing_spaces_ok -- when false, generated strings are rstrip()ed.
#
# Callers treat the return value as a dict mapping column name -> value,
# e.g. `values['test_key'] = ...` and `rowMatches(row, values)`.
#
# NOTE(review): this copy is missing original lines 66, 70, 72-73, 75,
# 77-78 and 84-86. These presumably hold `values = {}`, a `continue` after
# the key-column branch, the NULL/int branch headers, the special-string
# case, `values[name] = value` and `return values`. As shown, the first
# key column would raise NameError on `values`.
# `random` and `xrange` (Python 2) are not among the visible imports.
# Original lines 8-10 are missing, so confirm `import random` exists
# upstream. The loop variable `type` shadows the builtin.
def randomizeRow(row, nulls_ok=True, trailing_spaces_ok=True):
67
for name, type in row.rowColumns:
68
# Key columns are never randomized, so the row keeps its identity; only
# their current value is recorded.
if util.getKeyColumn(row, name):
69
values[name] = getattr(row, name)
71
elif nulls_ok and random.randint(0, 9) == 0:
74
# Presumably the integer-column branch (its condition, original line 73,
# is missing): uniform in [-10000, 10000].
value = random.randint(-10000, 10000)
76
# Presumably the string-column branch. About 1 in 10 values take the
# case on missing lines 77-78; otherwise the value is 1-64 printable
# ASCII characters (codes 32-126).
if random.randint(0, 9) == 0:
79
value = ''.join(map(lambda i:chr(random.randrange(32,127)),
80
xrange(random.randint(1, 64))))
81
# Codes 32-126 contain no whitespace except the space, so this strips
# only trailing spaces.
if not trailing_spaces_ok:
82
value = value.rstrip()
83
setattr(row, name, value)
87
# Check whether every column of `row` equals the entry for that column in
# `values` (a dict keyed by column name, as produced by randomizeRow).
# A diagnostic line is printed for a mismatching column. A column absent
# from `values` raises KeyError instead of counting as a mismatch.
# NOTE(review): original lines 92-94 are missing. Callers pass the result
# to failUnless, so those lines presumably hold `return False` (on
# mismatch) and a final `return True`. As shown, the function always
# returns None, which would make every assertion that uses it fail.
def rowMatches(row, values):
88
for name, type in row.rowColumns:
89
if getattr(row, name) != values[name]:
90
print ("Mismatch on column %s: |%s| (row) |%s| (values)" %
91
(name, getattr(row, name), values[name]))
95
class ReflectorTestBase:
96
"""Base class for testing reflectors.

Subclasses supply createReflector() (returning a Deferred that fires
with the reflector under test), destroyReflector(), and the attributes
nulls_ok, trailing_spaces_ok and num_iterations that the tests read.
"""
98
# Skip every test in the class when the reactor does not provide
# IReactorThreads.
if interfaces.IReactorThreads(reactor, None) is None:
99
skip = "No thread support, no reflector tests"
101
# NOTE(review): `count` is not read anywhere in the visible code; the
# iterative tests loop over self.num_iterations instead.
count = 100 # a parameter used for running iterative tests
103
# Randomize `row` using this test case's nulls_ok / trailing_spaces_ok
# settings and return the module-level randomizeRow's result.
def randomizeRow(self, row):
104
return randomizeRow(row, self.nulls_ok, self.trailing_spaces_ok)
107
# NOTE(review): original lines 105-106 (presumably `def setUp(self):`)
# and 109 (presumably `return d`) are missing. This fragment creates the
# reflector asynchronously and stores it via _cbSetUp.
d = self.createReflector()
108
d.addCallback(self._cbSetUp)
111
def _cbSetUp(self, reflector):
112
self.reflector = reflector
115
# NOTE(review): original lines 113-114 (presumably `def tearDown(self):`)
# are missing. This line returns whatever destroyReflector() returns.
return self.destroyReflector()
117
# Hook for subclasses to undo createReflector.
# NOTE(review): its body (original lines 118-119) is missing here.
def destroyReflector(self):
120
# Insert a TestRow keyed "first", load it back by key, check that every
# column round-tripped, then continue with _cbTestReflector. Presumably
# _cbTestReflector receives the loaded row, returned by _getParent.
# NOTE(review): these original lines are missing from this copy:
# - 122 (presumably `row = TestRow()`) and 125-126;
# - 128-129 and 135-136 (presumably the `def _loadBack(_):` and
#   `def _getParent(_):` headers);
# - 141-142 (presumably `return parent`) and 146-147 (presumably `return d`).
# EQUAL presumably comes from `from twisted.enterprise.reflector import *`.
def testReflector(self):
121
# create one row to work with
123
row.assignKeyAttr("key_string", "first")
124
values = self.randomizeRow(row)
127
d = self.reflector.insertRow(row)
130
# now load it back in
131
whereClause = [("key_string", EQUAL, "first")]
132
d = self.reflector.loadObjectsFrom(tableName,
133
whereClause=whereClause)
134
# The checks that follow read self.data, which gotData presumably sets.
return d.addCallback(self.gotData)
137
# make sure it came back as what we saved
138
self.failUnless(len(self.data) == 1, "no row")
139
parent = self.data[0]
140
self.failUnless(rowMatches(parent, values), "no match")
143
d.addCallback(_loadBack)
144
d.addCallback(_getParent)
145
d.addCallback(self._cbTestReflector)
148
# Second half of testReflector, driven by the loaded parent row. The steps
# run as one Deferred chain, built by the addCallback calls at the end:
# 1. Insert num_iterations ChildRows pointing at the parent.
# 2. Load them with parentRow=parent and check parent.childRows. This is
#    done twice, to catch duplicated children.
# 3. Update the parent, reload it and check it.
# 4. Bulk-insert num_iterations more TestRows, reload every row and
#    compare.
# 5. Update every row and compare again.
# 6. Delete every row and check the table is empty.
# NOTE(review): this copy is missing many original lines: 150-152, 154,
# 160-161, 163-165, 168, 171, 177, 180, 183, 188-189, 194, 200, 202-203,
# 206, 208, 213, 218, 226-227, 229, 235, 237-241, 246, 249 and 266-268.
# They include:
# - the initialisation of child_values, inserts, values, test_values and ds;
# - the message arguments of two failUnless calls;
# - the `def` lines of _loadObjects, _loadBack, _changeRows and
#   _deleteRows, which the chain below refers to.
# Because of the 237-241 gap, the delete loop appears to sit inside
# _cbChangeRows. Upstream it presumably belongs to _deleteRows.
def _cbTestReflector(self, parent):
149
# create some child rows
153
for i in range(0, self.num_iterations):
155
row.assignKeyAttr("childId", i)
156
values = self.randomizeRow(row)
157
# Link each child to the parent (which was loaded by key_string "first").
values['test_key'] = row.test_key = "first"
158
child_values[i] = values
159
inserts.append(self.reflector.insertRow(row))
162
d = defer.gatherResults(inserts)
166
d = self.reflector.loadObjectsFrom(childTableName, parentRow=parent)
167
return d.addCallback(self.gotData)
169
def _checkLoadObjects(_):
170
self.failUnless(len(self.data) == self.num_iterations,
172
self.failUnless(len(parent.childRows) == self.num_iterations,
173
"did not load child rows: %d" % len(parent.childRows))
174
# child_values is keyed by childId, which was set to the loop index i.
for child in parent.childRows:
175
self.failUnless(rowMatches(child, child_values[child.childId]),
176
"child %d does not match" % child.childId)
178
# Loading the same children a second time must not append duplicates to
# parent.childRows.
def _checkLoadObjects2(_):
179
self.failUnless(len(self.data) == self.num_iterations,
181
self.failUnless(len(parent.childRows) == self.num_iterations,
182
"child rows added twice!: %d" % len(parent.childRows))
184
def _changeParent(_):
185
# now change the parent
186
values[0] = self.randomizeRow(parent)
187
return self.reflector.updateRow(parent)
190
# now load it back in
191
whereClause = [("key_string", EQUAL, "first")]
192
d = self.reflector.loadObjectsFrom(tableName, whereClause=whereClause)
193
return d.addCallback(self.gotData)
195
def _checkLoadBack(_):
196
# make sure it came back as what we saved
197
self.failUnless(len(self.data) == 1, "no row")
198
# This binds a local `parent`; the enclosing `parent` argument is not
# rebound.
parent = self.data[0]
199
self.failUnless(rowMatches(parent, values[0]), "no match")
201
# Record the parent's expected values so the bulk comparison in
# _checkRowsBack covers the "first" row too.
test_values[parent.key_string] = values[0]
204
def _saveMoreTestRows(_):
205
# save some more test rows
207
for i in range(0, self.num_iterations):
209
row.assignKeyAttr("key_string", "bulk%d"%i)
210
test_values[row.key_string] = self.randomizeRow(row)
211
ds.append(self.reflector.insertRow(row))
212
return defer.gatherResults(ds)
214
def _loadRowsBack(_):
215
# now load them all back in
216
d = self.reflector.loadObjectsFrom("testTable")
217
return d.addCallback(self.gotData)
219
def _checkRowsBack(_):
220
# make sure they are the same
221
# Expect the num_iterations bulk rows plus the original "first" row.
self.failUnless(len(self.data) == self.num_iterations + 1,
222
"query did not get rows")
223
for row in self.data:
224
self.failUnless(rowMatches(row, test_values[row.key_string]),
225
"child %s does not match" % row.key_string)
228
# now change them all
230
for row in self.data:
231
test_values[row.key_string] = self.randomizeRow(row)
232
ds.append(self.reflector.updateRow(row))
233
d = defer.gatherResults(ds)
234
return d.addCallback(_cbChangeRows)
236
def _cbChangeRows(_):
242
for row in self.data:
243
ds.append(self.reflector.deleteRow(row))
244
d = defer.gatherResults(ds)
245
# NOTE(review): _cbChangeRows is reused as the callback after the deletes.
# Its body is not visible in this copy; confirm that sharing it is
# intended.
return d.addCallback(_cbChangeRows)
247
def _checkRowsDeleted(_):
248
self.failUnless(len(self.data) == 0, "rows were not deleted")
250
# Chain the steps in order. Presumably `d` here is the
# gatherResults(inserts) Deferred; the later `d = ...` assignments
# belong to the nested callbacks.
d.addCallback(_loadObjects)
251
d.addCallback(_checkLoadObjects)
252
d.addCallback(_loadObjects)
253
d.addCallback(_checkLoadObjects2)
254
d.addCallback(_changeParent)
255
d.addCallback(_loadBack)
256
d.addCallback(_checkLoadBack)
257
d.addCallback(_saveMoreTestRows)
258
d.addCallback(_loadRowsBack)
259
d.addCallback(_checkRowsBack)
260
d.addCallback(_changeRows)
261
d.addCallback(_loadRowsBack)
262
d.addCallback(_checkRowsBack)
263
d.addCallback(_deleteRows)
264
d.addCallback(_loadRowsBack)
265
d.addCallback(_checkRowsDeleted)
269
# Insert a single TestRow keyed "first", then delete it once the insert
# has completed.
# NOTE(review): these original lines are missing:
# - 271 (presumably `row = TestRow()`) and 274;
# - 276-277 (presumably `def _deleteRow(_):`);
# - 280 (presumably `return d`).
# `values` is unused in the visible code, and nothing visible verifies
# that the row is actually gone.
def testSaveAndDelete(self):
270
# create one row to work with
272
row.assignKeyAttr("key_string", "first")
273
values = self.randomizeRow(row)
275
d = self.reflector.insertRow(row)
278
return self.reflector.deleteRow(row)
279
d.addCallback(_deleteRow)
283
# Deferred callback for query results. The tests read self.data right
# after adding it to a chain.
# NOTE(review): the body (original line 284, presumably
# `self.data = data`) is missing from this copy.
def gotData(self, data):
286
# Class attribute, so every subclass inherits this per-test timeout
# (presumably in seconds, as trial uses it).
ReflectorTestBase.timeout = 30.0
288
class SQLReflectorTestBase(ReflectorTestBase):
289
"""Base class for the SQL reflector tests.

Concrete subclasses must provide makePool() (a pool offering
runOperation) and the escape_slashes flag.
"""
291
# Steps: create a fresh pool, drop any leftover test tables, create both
# tables from the module-level schemas, then fire with a reflector over
# TestRow and ChildRow.
# NOTE(review): original lines 290, 292, 294-296, 301, 303 and 310-311
# are missing.
# - The two `d = ...` assignments below are presumably the two branches
#   of a conditional (line 301 being `else:`). Otherwise the DROP chain
#   would simply be discarded.
# - Lines 310-311 presumably `return d`.
def createReflector(self):
293
self.dbpool = self.makePool()
297
# Best-effort cleanup: the errback swallows any failure, e.g. tables that
# do not exist yet. If the first DROP fails, the childTable DROP callback
# is skipped, so a leftover childTable could make the CREATE below fail.
d = self.dbpool.runOperation('DROP TABLE testTable')
298
d.addCallback(lambda _:
299
self.dbpool.runOperation('DROP TABLE childTable'))
300
d.addErrback(lambda _: None)
302
d = defer.succeed(None)
304
d.addCallback(lambda _: self.dbpool.runOperation(main_table_schema))
305
d.addCallback(lambda _: self.dbpool.runOperation(child_table_schema))
306
# When escape_slashes is false, use NoSlashSQLReflector, which escapes
# only single quotes. The `and/or` idiom is safe here because
# SQLReflector, being a class, is always truthy.
reflectorClass = self.escape_slashes and SQLReflector \
307
or NoSlashSQLReflector
308
d.addCallback(lambda _:
309
reflectorClass(self.dbpool, [TestRow, ChildRow]))
312
# Drop both test tables.
# NOTE(review): original lines 316-321 are missing. No errback is visible
# here, so if the first DROP fails, the childTable DROP is skipped.
def destroyReflector(self):
313
d = self.dbpool.runOperation('DROP TABLE testTable')
314
d.addCallback(lambda _:
315
self.dbpool.runOperation('DROP TABLE childTable'))
322
# Generate concrete TestCase classes from SQLReflectorTestBase, presumably
# one per database backend. They are injected into this module's
# namespace through globals(), which lets the test runner find them.
# Generated names:
# GadflyReflectorTestCase SQLiteReflectorTestCase PyPgSQLReflectorTestCase
323
# PsycopgReflectorTestCase MySQLReflectorTestCase FirebirdReflectorTestCase
324
makeSQLTests(SQLReflectorTestBase, 'ReflectorTestCase', globals())
326
class NoSlashSQLReflector(SQLReflector):
    """SQL reflector that escapes string literals by doubling single quotes.

    Backslashes pass through unchanged. The reflector tests pick this class
    when their ``escape_slashes`` flag is falsy.
    """

    def escape_string(self, text):
        """Return *text* with every ``'`` doubled to ``''``."""
        pieces = text.split("'")
        return "''".join(pieces)