1
# -*- test-case-name: twisted.test.test_reflector -*-
2
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
3
# See LICENSE for details.
6
from twisted.enterprise import reflector
7
from twisted.enterprise.util import DBError, getKeyColumn, quote, safe
8
from twisted.enterprise.util import _TableInfo
9
from twisted.enterprise.row import RowObject
11
from twisted.python import reflect
13
class SQLReflector(reflector.Reflector):
    """I reflect on a database and load RowObjects from it.

    In order to do this, I interrogate a relational database to
    extract schema information and interface with RowObject class
    objects that can interact with specific tables.
    """

    # NOTE(review): this class was restored from a damaged copy that had
    # lost its indentation and its short lines (docstring terminators,
    # ``return`` statements, accumulator initialisation, separators).
    # ``_populate``, ``populated`` and the exact separator strings were
    # re-derived from how the surviving code uses them -- confirm against
    # version control before relying on byte-exact SQL output.

    # Set to 1 once every registered row class has been processed.
    populated = 0

    # SQL operator emitted for each comparison constant that may appear in
    # a whereClause tuple.
    conditionalLabels = {
        reflector.EQUAL: "=",
        reflector.LESSTHAN: "<",
        reflector.GREATERTHAN: ">",
        reflector.LIKE: "like",
    }

    def __init__(self, dbpool, rowClasses):
        """Initialize me against a database.

        @param dbpool: the connection pool every query is sent through; I
            call its C{runInteraction} and C{runOperation} methods.
        @param rowClasses: the row classes describing the tables I manage.
        """
        reflector.Reflector.__init__(self, rowClasses)
        self.dbpool = dbpool

    def _populate(self):
        """Build the schema information for all of my row classes."""
        # NOTE(review): presumably invoked by reflector.Reflector.__init__
        # as its population hook -- confirm against the base class.
        self._transPopulateSchema()

    def _transPopulateSchema(self):
        """Used to construct the row classes in a single interaction.

        @raise DBError: if one of my row classes is not derived from
            RowObject.
        """
        for rc in self.rowClasses:
            if not issubclass(rc, RowObject):
                # Name the rejected class itself. It is by definition not a
                # RowObject, so reading an attribute such as ``rowClass``
                # from it could mask this error with an AttributeError.
                raise DBError("Stub class (%s) is not derived from RowObject"
                              % reflect.qual(rc))
            self._populateSchemaFor(rc)
        self.populated = 1

    def _populateSchemaFor(self, rc):
        """Construct all the SQL templates for database operations on
        <tableName> and populate the class <rowClass> with that info.

        @param rc: a row class; it must define C{rowColumns},
            C{rowKeyColumns} and C{rowTableName}.
        @raise DBError: if C{rc} lacks one of those class variables.
        """
        for att in ("rowColumns", "rowKeyColumns", "rowTableName"):
            if not hasattr(rc, att):
                raise DBError("RowClass %s must have class variable: %s"
                              % (rc, att))

        tableInfo = _TableInfo(rc)
        tableInfo.updateSQL = self.buildUpdateSQL(tableInfo)
        tableInfo.insertSQL = self.buildInsertSQL(tableInfo)
        tableInfo.deleteSQL = self.buildDeleteSQL(tableInfo)
        self.populateSchemaFor(tableInfo)

    def escape_string(self, text):
        r"""Escape a string for use in an SQL statement. The default
        implementation escapes ' with '' and \ with \\. Redefine this
        function in a subclass if your database server uses different
        escaping rules.

        @param text: the string to escape.
        @return: the escaped string.
        """
        return safe(text)

    def quote_value(self, value, type):
        """Format a value for use in an SQL statement.

        @param value: a value to format as data in SQL.
        @param type: a key in util.dbTypeMap.
        @return: C{value} rendered as an SQL literal; strings are escaped
            with L{escape_string}.
        """
        return quote(value, type, string_escaper=self.escape_string)

    def loadObjectsFrom(self, tableName, parentRow=None, data=None,
                        whereClause=None, forceChildren=0):
        """Load a set of RowObjects from a database.

        Create a set of python objects of <rowClass> from the contents
        of a table populated with appropriate data members.
        Example::

          |  class EmployeeRow(row.RowObject):
          |      pass
          |
          |  def gotEmployees(employees):
          |      for emp in employees:
          |          emp.manager = "fred smith"
          |          manager.updateRow(emp)
          |
          |  reflector.loadObjectsFrom("employee",
          |                          data = userData,
          |                          whereClause = [("manager" , EQUAL, "fred smith")]
          |                          ).addCallback(gotEmployees)

        NOTE: the objects and all children should be loaded in a single transaction.
        NOTE: can specify a parentRow _OR_ a whereClause.

        @param tableName: name of the table to load rows from.
        @param parentRow: optional row object whose relationship to
            C{tableName} selects the rows to load.
        @param data: opaque value handed to the row factory method for
            every newly created row object.
        @param whereClause: optional list of C{(columnName, condition,
            value)} tuples, where C{condition} is a key of
            L{conditionalLabels}; the conditions are ANDed together.
        @param forceChildren: if true, also load children through
            relationships that are not flagged C{autoLoad}.
        @return: whatever C{dbpool.runInteraction} returns for the load
            (the example above attaches a callback to it).
        @raise DBError: if both C{parentRow} and C{whereClause} are given.
        """
        if parentRow and whereClause:
            raise DBError("Must specify one of parentRow _OR_ whereClause")

        if parentRow:
            # Derive the condition from the parent's relationship to the
            # requested table.
            info = self.getTableInfo(parentRow)
            relationship = info.getRelationshipFor(tableName)
            whereClause = self.buildWhereClause(relationship, parentRow)
        elif not whereClause:
            # Neither was supplied: load every row of the table.
            whereClause = []

        return self.dbpool.runInteraction(self._rowLoader, tableName,
                                          parentRow, data, whereClause,
                                          forceChildren)

    def _rowLoader(self, transaction, tableName, parentRow, data,
                   whereClause, forceChildren):
        """immediate loading of rowobjects from the table with the whereClause.

        Runs inside a database interaction and recurses, on the same
        transaction, to load child rows.

        @param transaction: object providing C{execute} and C{fetchall}.
        @return: the list of row objects matching the query, in the order
            the database returned them.
        """
        tableInfo = self.schema[tableName]

        # Build the SQL for the query.  Table and column names are
        # interpolated verbatim; only values go through quote_value.
        sql = "SELECT "
        sql = sql + ",".join([" %s" % column
                              for column, _ in tableInfo.rowColumns])
        sql = sql + " FROM %s " % (tableName)
        if whereClause:
            conditions = []
            for columnName, cond, value in whereClause:
                t = self.findTypeFor(tableName, columnName)
                quotedValue = self.quote_value(value, t)
                conditions.append("%s %s %s" % (columnName,
                                                self.conditionalLabels[cond],
                                                quotedValue))
            sql += " WHERE " + " AND ".join(conditions)

        # execute the query
        transaction.execute(sql)
        rows = transaction.fetchall()

        # Map each lower-cased column name to the attribute name declared
        # on the row class.  setdefault keeps the first declaration, the
        # same one a first-match linear search would pick.
        attrForColumn = {}
        for attr, _ in tableInfo.rowClass.rowColumns:
            attrForColumn.setdefault(attr.lower(), attr)
        selectedColumns = [column.lower()
                           for column, _ in tableInfo.rowColumns]

        # construct the row objects
        results = []
        newRows = []
        for args in rows:
            kw = {}
            for columnName, value in zip(selectedColumns, args):
                if columnName in attrForColumn:
                    kw[attrForColumn[columnName]] = value

            # find the row in the cache or add it
            resultObject = self.findInCache(tableInfo.rowClass, kw)
            if not resultObject:
                meth = tableInfo.rowFactoryMethod[0]
                resultObject = meth(tableInfo.rowClass, data, kw)
                self.addToCache(resultObject)
                newRows.append(resultObject)
            results.append(resultObject)

        # add these rows to the parentRow if required; rows that were
        # already cached are not attached a second time
        if parentRow:
            self.addToParent(parentRow, newRows, tableName)

        # load children or each of these rows if required
        for relationship in tableInfo.relationships:
            if not forceChildren and not relationship.autoLoad:
                continue
            for row in results:
                childWhereClause = self.buildWhereClause(relationship, row)
                # load the children immediately, but do nothing with them
                self._rowLoader(transaction,
                                relationship.childRowClass.rowTableName,
                                row, data, childWhereClause, forceChildren)

        return results

    def findTypeFor(self, tableName, columnName):
        """Look up the declared type of a column, ignoring case.

        @param tableName: name of a table present in my schema.
        @param columnName: name of the column to look up.
        @return: the type recorded for the column, or C{None} if the
            table has no such column.
        """
        tableInfo = self.schema[tableName]
        columnName = columnName.lower()
        for column, columnType in tableInfo.rowColumns:
            if column.lower() == columnName:
                return columnType
        return None

    def buildUpdateSQL(self, tableInfo):
        """(Internal) Build SQL template to update a RowObject.

        @param tableInfo: schema information for the table.
        @return: an UPDATE statement holding one C{%s} placeholder per
            non-key column followed by one per key column.
        """
        sql = "UPDATE %s SET" % tableInfo.rowTableName

        # build update attributes; key columns only appear in the WHERE part
        assignments = [" %s = %s" % (column, "%s")
                       for column, _ in tableInfo.rowColumns
                       if not getKeyColumn(tableInfo.rowClass, column)]
        sql = sql + ",".join(assignments)

        # build where clause
        sql = sql + " WHERE "
        sql = sql + " AND ".join([" %s = %s " % (keyColumn, "%s")
                                  for keyColumn, _ in tableInfo.rowKeyColumns])
        return sql

    def buildInsertSQL(self, tableInfo):
        """(Internal) Build SQL template to insert a new row.

        @param tableInfo: schema information for the table.
        @return: an INSERT statement, for a rowObject instance not created
            from the database, holding one C{%s} placeholder per column.
        """
        columns = [column for column, _ in tableInfo.rowColumns]
        sql = "INSERT INTO %s (" % tableInfo.rowTableName
        sql = sql + ", ".join(columns)
        sql = sql + " ) VALUES ("
        sql = sql + ", ".join(["%s"] * len(columns))
        sql = sql + ")"
        return sql

    def buildDeleteSQL(self, tableInfo):
        """Build the SQL template to delete a row from the table.

        @param tableInfo: schema information for the table.
        @return: a DELETE statement holding one C{%s} placeholder per key
            column.
        """
        sql = "DELETE FROM %s " % tableInfo.rowTableName
        sql = sql + " WHERE "
        sql = sql + " AND ".join([" %s = %s " % (keyColumn, "%s")
                                  for keyColumn, _ in tableInfo.rowKeyColumns])
        return sql

    def updateRowSQL(self, rowObject):
        """Build SQL to update the contents of rowObject.

        @param rowObject: the row object to write back.
        @return: the update template filled in with quoted values.
        """
        tableInfo = self.schema[rowObject.rowTableName]

        # The values must follow the placeholder order of the update
        # template: non-key columns first, then the key columns.
        args = [self.quote_value(rowObject.findAttribute(column), columnType)
                for column, columnType in tableInfo.rowColumns
                if not getKeyColumn(rowObject.__class__, column)]
        args.extend([self.quote_value(rowObject.findAttribute(keyColumn),
                                      keyType)
                     for keyColumn, keyType in tableInfo.rowKeyColumns])

        return self.getTableInfo(rowObject).updateSQL % tuple(args)

    def updateRow(self, rowObject):
        """Update the contents of rowObject to the database.

        The row is marked clean as soon as the statement has been built,
        before the database operation has run.

        @return: whatever C{dbpool.runOperation} returns.
        """
        sql = self.updateRowSQL(rowObject)
        rowObject.setDirty(0)
        return self.dbpool.runOperation(sql)

    def insertRowSQL(self, rowObject):
        """Build SQL to insert the contents of rowObject.

        @param rowObject: the row object to insert.
        @return: the insert template filled in with quoted values.
        """
        tableInfo = self.schema[rowObject.rowTableName]
        args = [self.quote_value(rowObject.findAttribute(column), columnType)
                for column, columnType in tableInfo.rowColumns]
        return self.getTableInfo(rowObject).insertSQL % tuple(args)

    def insertRow(self, rowObject):
        """Insert a new row for rowObject.

        The row is marked clean before the statement is built and run.

        @return: whatever C{dbpool.runOperation} returns.
        """
        rowObject.setDirty(0)
        sql = self.insertRowSQL(rowObject)
        return self.dbpool.runOperation(sql)

    def deleteRowSQL(self, rowObject):
        """Build SQL to delete rowObject from the database.

        @param rowObject: the row object to delete.
        @return: the delete template filled in with the quoted key values.
        """
        tableInfo = self.schema[rowObject.rowTableName]
        args = [self.quote_value(rowObject.findAttribute(keyColumn), keyType)
                for keyColumn, keyType in tableInfo.rowKeyColumns]
        return self.getTableInfo(rowObject).deleteSQL % tuple(args)

    def deleteRow(self, rowObject):
        """Delete the row for rowObject from the database.

        The row object is dropped from my cache before the database
        operation is issued.

        @return: whatever C{dbpool.runOperation} returns.
        """
        sql = self.deleteRowSQL(rowObject)
        self.removeFromCache(rowObject)
        return self.dbpool.runOperation(sql)
327
# Names exported by a star-import of this module.
__all__ = ['SQLReflector']