# -*- test-case-name: twext.enterprise.dal.test.test_parseschema -*-
#
# Copyright (c) 2010 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Parser for SQL schema.
"""

from itertools import chain

from sqlparse import parse, keywords
from sqlparse.tokens import Keyword, Punctuation, Number, String, Name
from sqlparse.sql import (Comment, Identifier, Parenthesis, IdentifierList,
                          Function)

from twext.enterprise.dal.model import (
    Schema, Table, SQLType, ProcedureCall, Constraint, Sequence, Index)

def _fixKeywords():
    """
    Work around bugs in SQLParse, adding SEQUENCE as a keyword (since it is
    treated as one in postgres) and removing ACCESS and SIZE (since we use
    those as column names).  Technically those are keywords in SQL, but they
    aren't treated as such by postgres's parser.
    """
    # NOTE(review): the enclosing function header was lost from the listing;
    # the name used here is private and only called on the line below.
    keywords.KEYWORDS['SEQUENCE'] = Keyword
    for columnNameKeyword in ['ACCESS', 'SIZE']:
        # pop() instead of del: a sqlparse build that does not list one of
        # these words must not make importing this module fail.
        keywords.KEYWORDS.pop(columnNameKeyword, None)

_fixKeywords()

def tableFromCreateStatement(schema, stmt):
    """
    Add a table from a CREATE TABLE sqlparse statement object.

    @param schema: The schema to add the table statement to.
    @type schema: L{Schema}

    @param stmt: The C{CREATE TABLE} statement object.
    @type stmt: L{Statement}

    @return: the table that was created and populated with columns.
    @rtype: L{Table}

    @raise ViolatedExpectation: if the statement does not have the expected
        C{CREATE TABLE name (...)} token structure.
    """
    tokens = iterSignificant(stmt)
    expect(tokens, ttype=Keyword.DDL, value='CREATE')
    expect(tokens, ttype=Keyword, value='TABLE')
    # The table name and its parenthesized column list are expected to arrive
    # wrapped together in a single Function token.
    function = expect(tokens, cls=Function)
    tokens = iterSignificant(function)
    name = expect(tokens, cls=Identifier).get_name().encode('utf-8')
    table = Table(schema, name)
    parens = expect(tokens, cls=Parenthesis)
    columnParser = _ColumnParser(table, iterSignificant(parens), parens)
    # NOTE(review): the call that drives the column parser was dropped from
    # the listing; 'parse' is assumed to be the _ColumnParser entry point that
    # consumes "(" and loops over nextColumn() -- confirm.
    columnParser.parse()
    return table

def schemaFromPath(path):
    """
    Build a L{Schema} from a file of SQL.

    @param path: a L{FilePath}-like object containing SQL.  Only its
        C{basename()} (used as the schema's name) and C{getContent()} (the
        SQL text) are used.

    @return: a L{Schema} object with the contents of the given C{path} parsed
        and added to it as L{Table} objects.
    """
    schema = Schema(path.basename())
    schemaData = path.getContent()
    addSQLToSchema(schema, schemaData)
    return schema

def _addIndexFromCreateStatement(schema, stmt):
    """
    Add an L{Index} to C{schema} from a C{CREATE INDEX} statement.

    @param schema: the schema to add the index to; the indexed table is
        looked up on it by name.
    @type schema: L{Schema}

    @param stmt: the C{CREATE INDEX name ON table (columns)} statement.
    @type stmt: L{Statement}

    @raise ViolatedExpectation: if the statement is not shaped as expected.
    """
    signifindex = iterSignificant(stmt)
    expect(signifindex, ttype=Keyword.DDL, value='CREATE')
    expect(signifindex, ttype=Keyword, value='INDEX')
    indexName = nameOrIdentifier(signifindex.next())
    expect(signifindex, ttype=Keyword, value='ON')
    # "table (col, ...)" is expected as one Function token with exactly two
    # significant children: the table name and the parenthesized columns.
    [tableName, columnArgs] = iterSignificant(
        expect(signifindex, cls=Function))
    tableName = nameOrIdentifier(tableName)
    arggetter = iterSignificant(columnArgs)

    expect(arggetter, ttype=Punctuation, value=u'(')
    valueOrValues = arggetter.next()
    if isinstance(valueOrValues, IdentifierList):
        valuelist = valueOrValues.get_identifiers()
    else:
        # A single indexed column is not wrapped in an IdentifierList.
        valuelist = [valueOrValues]
    expect(arggetter, ttype=Punctuation, value=u')')

    idx = Index(schema, indexName, schema.tableNamed(tableName))
    for token in valuelist:
        columnName = nameOrIdentifier(token)
        idx.addColumn(idx.table.columnNamed(columnName))


def _addRowFromInsertStatement(schema, stmt):
    """
    Record the row described by an C{INSERT INTO table VALUES (...)}
    statement on the named table of C{schema}.

    @param schema: the schema containing the target table.
    @type schema: L{Schema}

    @param stmt: the C{INSERT} statement.
    @type stmt: L{Statement}

    @raise ViolatedExpectation: if the statement is not shaped as expected.
    @raise KeyError: if a value is neither an integer nor a single-quoted
        string literal.
    """
    insertTokens = iterSignificant(stmt)
    expect(insertTokens, ttype=Keyword.DML, value='INSERT')
    expect(insertTokens, ttype=Keyword, value='INTO')
    tableName = expect(insertTokens, cls=Identifier).get_name()
    expect(insertTokens, ttype=Keyword, value='VALUES')
    values = expect(insertTokens, cls=Parenthesis)
    vals = iterSignificant(values)
    expect(vals, ttype=Punctuation, value='(')
    valuelist = expect(vals, cls=IdentifierList)
    expect(vals, ttype=Punctuation, value=')')

    # Only integer and single-quoted string literals are understood.
    converters = {Number.Integer: int,
                  String.Single: _destringify}
    rowData = [converters[ident.ttype](ident.value)
               for ident in valuelist.get_identifiers()]
    schema.tableNamed(tableName).insertSchemaRow(rowData)


def addSQLToSchema(schema, schemaData):
    """
    Add new SQL to an existing schema.

    C{CREATE TABLE}, C{CREATE SEQUENCE}, C{CREATE INDEX} and C{INSERT}
    statements are processed; any other statement type is reported on
    standard output and otherwise ignored.

    @param schema: The schema to add the new SQL to.
    @type schema: L{Schema}

    @param schemaData: A string containing some SQL statements.
    @type schemaData: C{str}

    @return: the C{schema} argument
    """
    parsed = parse(schemaData)
    for stmt in parsed:
        # Strip leading comments/whitespace off the statement; that text is
        # kept as the comment attached to a created table.
        preface = ''
        while stmt.tokens and not significant(stmt.tokens[0]):
            preface += str(stmt.tokens.pop(0))
        # NOTE(review): restored guard -- a statement consisting only of
        # comments/whitespace has nothing left to interpret.
        if not stmt.tokens:
            continue
        stmtType = stmt.get_type()
        if stmtType == 'CREATE':
            createType = stmt.token_next(1, True).value.upper()
            if createType == u'TABLE':
                t = tableFromCreateStatement(schema, stmt)
                t.addComment(preface)
            elif createType == u'SEQUENCE':
                # The instance is not kept here; presumably constructing it
                # registers the sequence on the schema -- confirm in model.
                Sequence(schema,
                         stmt.token_next(2, True).get_name().encode('utf-8'))
            elif createType == u'INDEX':
                _addIndexFromCreateStatement(schema, stmt)
        elif stmtType == 'INSERT':
            _addRowFromInsertStatement(schema, stmt)
        else:
            print('unknown type: %s' % (stmtType,))
    return schema

class _ColumnParser(object):
    """
    Stateful parser for the things between commas.

    It walks the significant tokens of the parenthesized body of a
    C{CREATE TABLE} statement and adds each column or table-level constraint
    it finds to a L{Table}.
    """

    def __init__(self, table, parenIter, parens):
        """
        @param table: the L{Table} to add data to.

        @param parenIter: the iterator.

        @param parens: the parenthesized token group being parsed; it is only
            used for diagnostic output when an unexpected token is met.
        """
        self.parens = parens
        self.iter = parenIter
        self.table = table


    def __iter__(self):
        """
        This object is an iterator; return itself.
        """
        return self


    def next(self):
        """
        Get the next token, flattening any L{IdentifierList} encountered.
        """
        result = self.iter.next()
        if isinstance(result, IdentifierList):
            # Expand out all identifier lists, since they seem to pop up
            # incorrectly.  We should never see one in a column list anyway.
            # http://code.google.com/p/python-sqlparse/issues/detail?id=25
            # NOTE(review): the loop around pop() was dropped from the
            # listing; popping from the end and pushing to the front keeps
            # the tokens in their original order -- confirm.
            while result.tokens:
                it = result.tokens.pop()
                if significant(it):
                    self.pushback(it)
            return self.next()
        return result


    def pushback(self, value):
        """
        Push the value back onto this iterator so it will be returned by the
        next call to C{next}.
        """
        self.iter = chain(iter((value,)), self.iter)


    def parse(self):
        """
        Parse everything: consume the opening parenthesis, then columns and
        constraints until one of them reports the end of the list.
        """
        expect(self.iter, ttype=Punctuation, value=u"(")
        while self.nextColumn():
            pass


    def nextColumn(self):
        """
        Parse the next column or constraint, depending on the next token.

        @return: a true value if more columns or constraints follow.
        """
        maybeIdent = self.next()
        if maybeIdent.ttype == Name:
            return self.parseColumn(maybeIdent.value)
        elif isinstance(maybeIdent, Identifier):
            return self.parseColumn(maybeIdent.get_name())
        else:
            return self.parseConstraint(maybeIdent)


    def namesInParens(self, parens):
        """
        Extract the identifier names from a parenthesized group such as
        C{(a)} or C{(a, b)}.

        @param parens: the L{Parenthesis} token group.

        @return: a C{list} of the names found.

        @raise ViolatedExpectation: if the contents are neither an identifier
            nor an identifier list.
        """
        parens = iterSignificant(parens)
        expect(parens, ttype=Punctuation, value="(")
        idorids = parens.next()
        if isinstance(idorids, Identifier):
            idnames = [idorids.get_name()]
        elif isinstance(idorids, IdentifierList):
            idnames = [x.get_name() for x in idorids.get_identifiers()]
        else:
            raise ViolatedExpectation("identifier or list", repr(idorids))
        expect(parens, ttype=Punctuation, value=")")
        return idnames


    def parseConstraint(self, constraintType):
        """
        Parse a 'free' constraint, described explicitly in the table as opposed
        to being implicitly associated with a column by being placed after it.

        @param constraintType: the token that starts the constraint.

        @return: the result of L{checkEnd} on the token after the constraint.

        @raise ViolatedExpectation: for any constraint other than
            C{PRIMARY KEY} or C{UNIQUE}.
        """
        # only know about PRIMARY KEY and UNIQUE for now
        if constraintType.match(Keyword, 'PRIMARY'):
            expect(self, ttype=Keyword, value='KEY')
            names = self.namesInParens(expect(self, cls=Parenthesis))
            self.table.primaryKey = [self.table.columnNamed(n) for n in names]
        elif constraintType.match(Keyword, 'UNIQUE'):
            names = self.namesInParens(expect(self, cls=Parenthesis))
            self.table.tableConstraint(Constraint.UNIQUE, names)
        else:
            raise ViolatedExpectation('PRIMARY or UNIQUE', constraintType)
        return self.checkEnd(self.next())


    def checkEnd(self, val):
        """
        After a column or constraint, check the end.

        @return: C{True} if C{val} is a comma (more items follow), C{False}
            if it is the closing parenthesis.

        @raise ViolatedExpectation: if C{val} is neither.
        """
        if val.value == u",":
            return True
        elif val.value == u")":
            return False
        else:
            raise ViolatedExpectation(", or )", val)


    def _parseDefault(self, theColumn):
        """
        Parse the value following a C{DEFAULT} keyword.

        @param theColumn: the column the default belongs to; it is registered
            as a referrer when the default is a C{nextval('sequence')} call.

        @return: the parsed default value.
        """
        theDefault = self.next()
        if isinstance(theDefault, Function):
            thingo = theDefault.tokens[0].get_name()
            parens = expectSingle(
                theDefault.tokens[-1], cls=Parenthesis
            )
            pareniter = iterSignificant(parens)
            if thingo.upper() == 'NEXTVAL':
                expect(pareniter, ttype=Punctuation, value="(")
                seqname = _destringify(
                    expect(pareniter, ttype=String.Single).value)
                defaultValue = self.table.schema.sequenceNamed(
                    seqname
                )
                defaultValue.referringColumns.append(theColumn)
            else:
                # NOTE(review): the second argument was dropped from the
                # listing; the call's parenthesized arguments are assumed.
                defaultValue = ProcedureCall(thingo.encode('utf-8'),
                                             parens)
        elif theDefault.ttype == Number.Integer:
            defaultValue = int(theDefault.value)
        elif (theDefault.ttype == Keyword and
              theDefault.value.lower() == 'false'):
            defaultValue = False
        elif (theDefault.ttype == Keyword and
              theDefault.value.lower() == 'true'):
            defaultValue = True
        elif (theDefault.ttype == Keyword and
              theDefault.value.lower() == 'null'):
            defaultValue = None
        elif theDefault.ttype == String.Single:
            defaultValue = _destringify(theDefault.value)
        else:
            # NOTE(review): the exception type was dropped from the listing;
            # RuntimeError is assumed -- confirm.
            raise RuntimeError(
                "not sure what to do: default %r" % (
                    theDefault,))
        return defaultValue


    def parseColumn(self, name):
        """
        Parse a column with the given name.

        @param name: the column's name.

        @return: the result of L{checkEnd} on the punctuation that ends the
            column, or a false value if an unexpected token was met.
        """
        typeName = self.next()
        if isinstance(typeName, Function):
            # A type with a length, e.g. "varchar(255)", seen as one token.
            [funcIdent, args] = iterSignificant(typeName)
            typeName = funcIdent
            arggetter = iterSignificant(args)
            expect(arggetter, value=u'(')
            typeLength = int(expect(arggetter,
                                    ttype=Number.Integer).value.encode('utf-8'))
        else:
            maybeTypeArgs = self.next()
            if isinstance(maybeTypeArgs, Parenthesis):
                # type arguments
                typeArgs = iterSignificant(maybeTypeArgs)
                expect(typeArgs, value=u"(")
                typeLength = int(typeArgs.next().value)
            else:
                # something else; not part of the type, so put it back
                typeLength = None
                self.pushback(maybeTypeArgs)
        theType = SQLType(typeName.value.encode("utf-8"), typeLength)
        theColumn = self.table.addColumn(
            name=name.encode("utf-8"), type=theType
        )

        def oneConstraint(t):
            # NOTE(review): the column list argument was dropped from the
            # listing; a single-column constraint on this column is assumed.
            self.table.tableConstraint(t,
                                       [theColumn.name])

        for val in self:
            if val.ttype == Punctuation:
                return self.checkEnd(val)
            if val.match(Keyword, 'PRIMARY'):
                expect(self, ttype=Keyword, value='KEY')
                # XXX check to make sure there's no other primary key yet
                self.table.primaryKey = [theColumn]
            elif val.match(Keyword, 'UNIQUE'):
                # XXX add UNIQUE constraint
                oneConstraint(Constraint.UNIQUE)
            elif val.match(Keyword, 'NOT'):
                # possibly not necessary, as 'NOT NULL' is a single keyword
                # in sqlparse as of 0.1.2
                expect(self, ttype=Keyword, value='NULL')
                oneConstraint(Constraint.NOT_NULL)
            elif val.match(Keyword, 'NOT NULL'):
                oneConstraint(Constraint.NOT_NULL)
            elif val.match(Keyword, 'DEFAULT'):
                theColumn.setDefaultValue(self._parseDefault(theColumn))
            elif val.match(Keyword, 'REFERENCES'):
                target = nameOrIdentifier(self.next())
                theColumn.doesReferenceName(target)
            elif val.match(Keyword, 'ON'):
                expect(self, ttype=Keyword.DML, value='DELETE')
                expect(self, ttype=Keyword, value='CASCADE')
                theColumn.cascade = True
            else:
                # Best-effort diagnostics, then stop parsing this table: the
                # false result ends the loop in parse().
                print('UNEXPECTED TOKEN: %s %s' % (repr(val), theColumn))
                # NOTE(review): this line was dropped from the listing and is
                # restored on the assumption it dumped the token group.
                print(self.parens)
                import pprint
                pprint.pprint(self.parens.tokens)
                return False

class ViolatedExpectation(Exception):
    """
    An expectation about the structure of the SQL syntax was violated.

    @ivar expected: what the parser was looking for.

    @ivar got: what it found instead.
    """

    def __init__(self, expected, got):
        """
        @param expected: a description of what was expected; shown with
            C{%r} in the message.

        @param got: what was actually found; shown with C{%s} in the message.
        """
        self.expected = expected
        self.got = got
        super(ViolatedExpectation, self).__init__(
            "Expected %r got %s" % (expected, got)
        )

def nameOrIdentifier(token):
    """
    Determine if the given object is a name or an identifier, and return the
    textual value of that name or identifier.

    @param token: an L{Identifier} group or a token whose type is C{Name}.

    @return: the name, as text.

    @raise ViolatedExpectation: if C{token} is neither.
    """
    if isinstance(token, Identifier):
        return token.get_name()
    elif token.ttype == Name:
        return token.value
    else:
        raise ViolatedExpectation("identifier or name", repr(token))

def expectSingle(nextval, ttype=None, value=None, cls=None):
    """
    Expect some properties from retrieved value.

    Each check is only performed when its argument is not C{None}.

    @param ttype: A token type to compare against.

    @param value: A value to compare against (case-insensitively).

    @param cls: A class to check if the value is an instance of.  The match
        is on the exact class; subclasses do not satisfy it.

    @raise ViolatedExpectation: if an unexpected token is found.

    @return: C{nextval}, if it matches.
    """
    if ttype is not None:
        if nextval.ttype != ttype:
            raise ViolatedExpectation(ttype, '%s:%r' % (nextval.ttype, nextval))
    if value is not None:
        if nextval.value.upper() != value.upper():
            raise ViolatedExpectation(value, nextval.value)
    if cls is not None:
        # Deliberately not isinstance(): the exact-class comparison of the
        # original is preserved.
        if nextval.__class__ is not cls:
            raise ViolatedExpectation(cls, repr(nextval))
    return nextval

def expect(iterator, **kw):
    """
    Retrieve a value from an iterator and check its properties.  Same signature
    as L{expectSingle}, except it takes an iterator instead of a value.

    @param iterator: an object with a C{next} method yielding tokens.

    @raise ViolatedExpectation: if the retrieved token does not match.

    @return: the retrieved token, if it matches.

    @see: L{expectSingle}
    """
    nextval = iterator.next()
    return expectSingle(nextval, **kw)

def significant(token):
    """
    Determine if the token is 'significant', i.e. that it is not a comment and
    not whitespace.

    @param token: a sqlparse token or token group.

    @return: a true value if the token should be looked at by the parser.
    """
    # comment has 'None' is_whitespace() result.  intentional?
    return (not isinstance(token, Comment) and not token.is_whitespace())

def iterSignificant(tokenList):
    """
    Iterate tokens that pass the test given by L{significant}, from a given
    token list.

    @param tokenList: an object with a C{tokens} attribute to iterate over.

    @return: a generator of the significant tokens, in order.
    """
    for token in tokenList.tokens:
        if significant(token):
            yield token

479
def _destringify(strval):
481
Convert a single-quoted SQL string into its actual represented value.
482
(Assumes standards compliance, since we should be controlling all the input
483
here. The only quoting syntax respected is "''".)
485
return strval[1:-1].replace("''", "'")