~ubuntu-branches/ubuntu/utopic/calendarserver/utopic

« back to all changes in this revision

Viewing changes to twext/enterprise/dal/parseschema.py

  • Committer: Package Import Robot
  • Author(s): Rahul Amaram
  • Date: 2012-05-29 18:12:12 UTC
  • mfrom: (1.1.2)
  • Revision ID: package-import@ubuntu.com-20120529181212-mxjdfncopy6vou0f
Tags: 3.2+dfsg-1
* New upstream release
* Moved from using cdbs to dh sequencer
* Modified calenderserver init.d script based on /etc/init.d/skeleton script
* Removed ldapdirectory.patch as the OpenLDAP directory service has been 
  merged upstream
* Moved package to section "net" as calendarserver is more service than 
  library (Closes: #665859)
* Changed Architecture of calendarserver package to any as the package
  now includes compiled architecture dependent Python extensions
* Unowned files are no longer left on the system upon purging
  (Closes: #668731)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- test-case-name: twext.enterprise.dal.test.test_parseschema -*-
 
2
##
 
3
# Copyright (c) 2010 Apple Inc. All rights reserved.
 
4
#
 
5
# Licensed under the Apache License, Version 2.0 (the "License");
 
6
# you may not use this file except in compliance with the License.
 
7
# You may obtain a copy of the License at
 
8
#
 
9
# http://www.apache.org/licenses/LICENSE-2.0
 
10
#
 
11
# Unless required by applicable law or agreed to in writing, software
 
12
# distributed under the License is distributed on an "AS IS" BASIS,
 
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
14
# See the License for the specific language governing permissions and
 
15
# limitations under the License.
 
16
##
 
17
 
 
18
"""
 
19
Parser for SQL schema.
 
20
"""
 
21
 
 
22
from itertools import chain
 
23
 
 
24
from sqlparse import parse, keywords
 
25
from sqlparse.tokens import Keyword, Punctuation, Number, String, Name
 
26
from sqlparse.sql import (Comment, Identifier, Parenthesis, IdentifierList,
 
27
                          Function)
 
28
 
 
29
from twext.enterprise.dal.model import (
 
30
    Schema, Table, SQLType, ProcedureCall, Constraint, Sequence, Index)
 
31
 
 
32
 
 
33
 
 
34
def _fixKeywords():
 
35
    """
 
36
    Work around bugs in SQLParse, adding SEQUENCE as a keyword (since it is
 
37
    treated as one in postgres) and removing ACCESS and SIZE (since we use those
 
38
    as column names).  Technically those are keywords in SQL, but they aren't
 
39
    treated as such by postgres's parser.
 
40
    """
 
41
    keywords.KEYWORDS['SEQUENCE'] = Keyword
 
42
    for columnNameKeyword in ['ACCESS', 'SIZE']:
 
43
        del keywords.KEYWORDS[columnNameKeyword]
 
44
 
 
45
_fixKeywords()
 
46
 
 
47
 
 
48
 
 
49
def tableFromCreateStatement(schema, stmt):
 
50
    """
 
51
    Add a table from a CREATE TABLE sqlparse statement object.
 
52
 
 
53
    @param schema: The schema to add the table statement to.
 
54
 
 
55
    @type schema: L{Schema}
 
56
 
 
57
    @param stmt: The C{CREATE TABLE} statement object.
 
58
 
 
59
    @type stmt: L{Statement}
 
60
    """
 
61
    i = iterSignificant(stmt)
 
62
    expect(i, ttype=Keyword.DDL, value='CREATE')
 
63
    expect(i, ttype=Keyword, value='TABLE')
 
64
    function = expect(i, cls=Function)
 
65
    i = iterSignificant(function)
 
66
    name = expect(i, cls=Identifier).get_name().encode('utf-8')
 
67
    self = Table(schema, name)
 
68
    parens = expect(i, cls=Parenthesis)
 
69
    cp = _ColumnParser(self, iterSignificant(parens), parens)
 
70
    cp.parse()
 
71
    return self
 
72
 
 
73
 
 
74
 
 
75
def schemaFromPath(path):
 
76
    """
 
77
    Get a L{Schema}.
 
78
 
 
79
    @param path: a L{FilePath}-like object containing SQL.
 
80
 
 
81
    @return: a L{Schema} object with the contents of the given C{path} parsed
 
82
        and added to it as L{Table} objects.
 
83
    """
 
84
    schema = Schema(path.basename())
 
85
    schemaData = path.getContent()
 
86
    addSQLToSchema(schema, schemaData)
 
87
    return schema
 
88
 
 
89
 
 
90
 
 
91
def addSQLToSchema(schema, schemaData):
 
92
    """
 
93
    Add new SQL to an existing schema.
 
94
 
 
95
    @param schema: The schema to add the new SQL to.
 
96
 
 
97
    @type schema: L{Schema}
 
98
 
 
99
    @param schemaData: A string containing some SQL statements.
 
100
 
 
101
    @type schemaData: C{str}
 
102
 
 
103
    @return: the C{schema} argument
 
104
    """
 
105
    parsed = parse(schemaData)
 
106
    for stmt in parsed:
 
107
        preface = ''
 
108
        while stmt.tokens and not significant(stmt.tokens[0]):
 
109
            preface += str(stmt.tokens.pop(0))
 
110
        if not stmt.tokens:
 
111
            continue
 
112
        if stmt.get_type() == 'CREATE':
 
113
            createType = stmt.token_next(1, True).value.upper()
 
114
            if createType == u'TABLE':
 
115
                t = tableFromCreateStatement(schema, stmt)
 
116
                t.addComment(preface)
 
117
            elif createType == u'SEQUENCE':
 
118
                Sequence(schema,
 
119
                         stmt.token_next(2, True).get_name().encode('utf-8'))
 
120
            elif createType == u'INDEX':
 
121
                signifindex = iterSignificant(stmt)
 
122
                expect(signifindex, ttype=Keyword.DDL, value='CREATE')
 
123
                expect(signifindex, ttype=Keyword, value='INDEX')
 
124
                indexName = nameOrIdentifier(signifindex.next())
 
125
                expect(signifindex, ttype=Keyword, value='ON')
 
126
                [tableName, columnArgs] = iterSignificant(expect(signifindex,
 
127
                                                                 cls=Function))
 
128
                tableName = nameOrIdentifier(tableName)
 
129
                arggetter = iterSignificant(columnArgs)
 
130
 
 
131
                expect(arggetter, ttype=Punctuation, value=u'(')
 
132
                valueOrValues = arggetter.next()
 
133
                if isinstance(valueOrValues, IdentifierList):
 
134
                    valuelist = valueOrValues.get_identifiers()
 
135
                else:
 
136
                    valuelist = [valueOrValues]
 
137
                expect(arggetter, ttype=Punctuation, value=u')')
 
138
 
 
139
                idx = Index(schema, indexName, schema.tableNamed(tableName))
 
140
                for token in valuelist:
 
141
                    columnName = nameOrIdentifier(token)
 
142
                    idx.addColumn(idx.table.columnNamed(columnName))
 
143
        elif stmt.get_type() == 'INSERT':
 
144
            insertTokens = iterSignificant(stmt)
 
145
            expect(insertTokens, ttype=Keyword.DML, value='INSERT')
 
146
            expect(insertTokens, ttype=Keyword, value='INTO')
 
147
            tableName = expect(insertTokens, cls=Identifier).get_name()
 
148
            expect(insertTokens, ttype=Keyword, value='VALUES')
 
149
            values = expect(insertTokens, cls=Parenthesis)
 
150
            vals = iterSignificant(values)
 
151
            expect(vals, ttype=Punctuation, value='(')
 
152
            valuelist = expect(vals, cls=IdentifierList)
 
153
            expect(vals, ttype=Punctuation, value=')')
 
154
            rowData = []
 
155
            for ident in valuelist.get_identifiers():
 
156
                rowData.append(
 
157
                    {Number.Integer: int,
 
158
                     String.Single: _destringify}
 
159
                    [ident.ttype](ident.value)
 
160
                )
 
161
 
 
162
            schema.tableNamed(tableName).insertSchemaRow(rowData)
 
163
        else:
 
164
            print 'unknown type:', stmt.get_type()
 
165
    return schema
 
166
 
 
167
 
 
168
 
 
169
class _ColumnParser(object):
 
170
    """
 
171
    Stateful parser for the things between commas.
 
172
    """
 
173
 
 
174
    def __init__(self, table, parenIter, parens):
 
175
        """
 
176
        @param table: the L{Table} to add data to.
 
177
 
 
178
        @param parenIter: the iterator.
 
179
        """
 
180
        self.parens = parens
 
181
        self.iter = parenIter
 
182
        self.table = table
 
183
 
 
184
 
 
185
    def __iter__(self):
 
186
        """
 
187
        This object is an iterator; return itself.
 
188
        """
 
189
        return self
 
190
 
 
191
 
 
192
    def next(self):
 
193
        """
 
194
        Get the next L{IdentifierList}.
 
195
        """
 
196
        result = self.iter.next()
 
197
        if isinstance(result, IdentifierList):
 
198
            # Expand out all identifier lists, since they seem to pop up
 
199
            # incorrectly.  We should never see one in a column list anyway.
 
200
            # http://code.google.com/p/python-sqlparse/issues/detail?id=25
 
201
            while result.tokens:
 
202
                it = result.tokens.pop()
 
203
                if significant(it):
 
204
                    self.pushback(it)
 
205
            return self.next()
 
206
        return result
 
207
 
 
208
 
 
209
    def pushback(self, value):
 
210
        """
 
211
        Push the value back onto this iterator so it will be returned by the
 
212
        next call to C{next}.
 
213
        """
 
214
        self.iter = chain(iter((value,)), self.iter)
 
215
 
 
216
 
 
217
    def parse(self):
 
218
        """
 
219
        Parse everything.
 
220
        """
 
221
        expect(self.iter, ttype=Punctuation, value=u"(")
 
222
        while self.nextColumn():
 
223
            pass
 
224
 
 
225
 
 
226
    def nextColumn(self):
 
227
        """
 
228
        Parse the next column or constraint, depending on the next token.
 
229
        """
 
230
        maybeIdent = self.next()
 
231
        if maybeIdent.ttype == Name:
 
232
            return self.parseColumn(maybeIdent.value)
 
233
        elif isinstance(maybeIdent, Identifier):
 
234
            return self.parseColumn(maybeIdent.get_name())
 
235
        else:
 
236
            return self.parseConstraint(maybeIdent)
 
237
 
 
238
 
 
239
    def namesInParens(self, parens):
 
240
        parens = iterSignificant(parens)
 
241
        expect(parens, ttype=Punctuation, value="(")
 
242
        idorids = parens.next()
 
243
        if isinstance(idorids, Identifier):
 
244
            idnames = [idorids.get_name()]
 
245
        elif isinstance(idorids, IdentifierList):
 
246
            idnames = [x.get_name() for x in idorids.get_identifiers()]
 
247
        else:
 
248
            raise ViolatedExpectation("identifier or list", repr(idorids))
 
249
        expect(parens, ttype=Punctuation, value=")")
 
250
        return idnames
 
251
 
 
252
 
 
253
    def parseConstraint(self, constraintType):
 
254
        """
 
255
        Parse a 'free' constraint, described explicitly in the table as opposed
 
256
        to being implicitly associated with a column by being placed after it.
 
257
        """
 
258
        # only know about PRIMARY KEY and UNIQUE for now
 
259
        if constraintType.match(Keyword, 'PRIMARY'):
 
260
            expect(self, ttype=Keyword, value='KEY')
 
261
            names = self.namesInParens(expect(self, cls=Parenthesis))
 
262
            self.table.primaryKey = [self.table.columnNamed(n) for n in names]
 
263
        elif constraintType.match(Keyword, 'UNIQUE'):
 
264
            names = self.namesInParens(expect(self, cls=Parenthesis))
 
265
            self.table.tableConstraint(Constraint.UNIQUE, names)
 
266
        else:
 
267
            raise ViolatedExpectation('PRIMARY or UNIQUE', constraintType)
 
268
        return self.checkEnd(self.next())
 
269
 
 
270
 
 
271
    def checkEnd(self, val):
 
272
        """
 
273
        After a column or constraint, check the end.
 
274
        """
 
275
        if val.value == u",":
 
276
            return True
 
277
        elif val.value == u")":
 
278
            return False
 
279
        else:
 
280
            raise ViolatedExpectation(", or )", val)
 
281
 
 
282
 
 
283
    def parseColumn(self, name):
 
284
        """
 
285
        Parse a column with the given name.
 
286
        """
 
287
        typeName = self.next()
 
288
        if isinstance(typeName, Function):
 
289
            [funcIdent, args] = iterSignificant(typeName)
 
290
            typeName = funcIdent
 
291
            arggetter = iterSignificant(args)
 
292
            expect(arggetter, value=u'(')
 
293
            typeLength = int(expect(arggetter,
 
294
                                    ttype=Number.Integer).value.encode('utf-8'))
 
295
        else:
 
296
            maybeTypeArgs = self.next()
 
297
            if isinstance(maybeTypeArgs, Parenthesis):
 
298
                # type arguments
 
299
                significant = iterSignificant(maybeTypeArgs)
 
300
                expect(significant, value=u"(")
 
301
                typeLength = int(significant.next().value)
 
302
            else:
 
303
                # something else
 
304
                typeLength = None
 
305
                self.pushback(maybeTypeArgs)
 
306
        theType = SQLType(typeName.value.encode("utf-8"), typeLength)
 
307
        theColumn = self.table.addColumn(
 
308
            name=name.encode("utf-8"), type=theType
 
309
        )
 
310
        for val in self:
 
311
            if val.ttype == Punctuation:
 
312
                return self.checkEnd(val)
 
313
            else:
 
314
                expected = True
 
315
                def oneConstraint(t):
 
316
                    self.table.tableConstraint(t,
 
317
                                               [theColumn.name])
 
318
 
 
319
                if val.match(Keyword, 'PRIMARY'):
 
320
                    expect(self, ttype=Keyword, value='KEY')
 
321
                    # XXX check to make sure there's no other primary key yet
 
322
                    self.table.primaryKey = [theColumn]
 
323
                elif val.match(Keyword, 'UNIQUE'):
 
324
                    # XXX add UNIQUE constraint
 
325
                    oneConstraint(Constraint.UNIQUE)
 
326
                elif val.match(Keyword, 'NOT'):
 
327
                    # possibly not necessary, as 'NOT NULL' is a single keyword
 
328
                    # in sqlparse as of 0.1.2
 
329
                    expect(self, ttype=Keyword, value='NULL')
 
330
                    oneConstraint(Constraint.NOT_NULL)
 
331
                elif val.match(Keyword, 'NOT NULL'):
 
332
                    oneConstraint(Constraint.NOT_NULL)
 
333
                elif val.match(Keyword, 'DEFAULT'):
 
334
                    theDefault = self.next()
 
335
                    if isinstance(theDefault, Function):
 
336
                        thingo = theDefault.tokens[0].get_name()
 
337
                        parens = expectSingle(
 
338
                            theDefault.tokens[-1], cls=Parenthesis
 
339
                        )
 
340
                        pareniter = iterSignificant(parens)
 
341
                        if thingo.upper() == 'NEXTVAL':
 
342
                            expect(pareniter, ttype=Punctuation, value="(")
 
343
                            seqname = _destringify(
 
344
                                expect(pareniter, ttype=String.Single).value)
 
345
                            defaultValue = self.table.schema.sequenceNamed(
 
346
                                seqname
 
347
                            )
 
348
                            defaultValue.referringColumns.append(theColumn)
 
349
                        else:
 
350
                            defaultValue = ProcedureCall(thingo.encode('utf-8'),
 
351
                                                         parens)
 
352
                    elif theDefault.ttype == Number.Integer:
 
353
                        defaultValue = int(theDefault.value)
 
354
                    elif (theDefault.ttype == Keyword and
 
355
                          theDefault.value.lower() == 'false'):
 
356
                        defaultValue = False
 
357
                    elif (theDefault.ttype == Keyword and
 
358
                          theDefault.value.lower() == 'true'):
 
359
                        defaultValue = True
 
360
                    elif (theDefault.ttype == Keyword and
 
361
                          theDefault.value.lower() == 'null'):
 
362
                        defaultValue = None
 
363
                    elif theDefault.ttype == String.Single:
 
364
                        defaultValue = _destringify(theDefault.value)
 
365
                    else:
 
366
                        raise RuntimeError(
 
367
                            "not sure what to do: default %r" % (
 
368
                            theDefault))
 
369
                    theColumn.setDefaultValue(defaultValue)
 
370
                elif val.match(Keyword, 'REFERENCES'):
 
371
                    target = nameOrIdentifier(self.next())
 
372
                    theColumn.doesReferenceName(target)
 
373
                elif val.match(Keyword, 'ON'):
 
374
                    expect(self, ttype=Keyword.DML, value='DELETE')
 
375
                    expect(self, ttype=Keyword, value='CASCADE')
 
376
                    theColumn.cascade = True
 
377
                else:
 
378
                    expected = False
 
379
                if not expected:
 
380
                    print 'UNEXPECTED TOKEN:', repr(val), theColumn
 
381
                    print self.parens
 
382
                    import pprint
 
383
                    pprint.pprint(self.parens.tokens)
 
384
                    return 0
 
385
 
 
386
 
 
387
 
 
388
 
 
389
class ViolatedExpectation(Exception):
 
390
    """
 
391
    An expectation about the structure of the SQL syntax was violated.
 
392
    """
 
393
 
 
394
    def __init__(self, expected, got):
 
395
        self.expected = expected
 
396
        self.got = got
 
397
        super(ViolatedExpectation, self).__init__(
 
398
            "Expected %r got %s" % (expected, got)
 
399
        )
 
400
 
 
401
 
 
402
 
 
403
def nameOrIdentifier(token):
 
404
    """
 
405
    Determine if the given object is a name or an identifier, and return the
 
406
    textual value of that name or identifier.
 
407
 
 
408
    @rtype: L{str}
 
409
    """
 
410
    if isinstance(token, Identifier):
 
411
        return token.get_name()
 
412
    elif token.ttype == Name:
 
413
        return token.value
 
414
    else:
 
415
        raise ViolatedExpectation("identifier or name", repr(token))
 
416
 
 
417
 
 
418
 
 
419
def expectSingle(nextval, ttype=None, value=None, cls=None):
 
420
    """
 
421
    Expect some properties from retrieved value.
 
422
 
 
423
    @param ttype: A token type to compare against.
 
424
 
 
425
    @param value: A value to compare against.
 
426
 
 
427
    @param cls: A class to check if the value is an instance of.
 
428
 
 
429
    @raise ViolatedExpectation: if an unexpected token is found.
 
430
 
 
431
    @return: C{nextval}, if it matches.
 
432
    """
 
433
    if ttype is not None:
 
434
        if nextval.ttype != ttype:
 
435
            raise ViolatedExpectation(ttype, '%s:%r' % (nextval.ttype, nextval))
 
436
    if value is not None:
 
437
        if nextval.value.upper() != value.upper():
 
438
            raise ViolatedExpectation(value, nextval.value)
 
439
    if cls is not None:
 
440
        if nextval.__class__ != cls:
 
441
            raise ViolatedExpectation(cls, repr(nextval))
 
442
    return nextval
 
443
 
 
444
 
 
445
 
 
446
def expect(iterator, **kw):
 
447
    """
 
448
    Retrieve a value from an iterator and check its properties.  Same signature
 
449
    as L{expectSingle}, except it takes an iterator instead of a value.
 
450
 
 
451
    @see: L{expectSingle}
 
452
    """
 
453
    nextval = iterator.next()
 
454
    return expectSingle(nextval, **kw)
 
455
 
 
456
 
 
457
 
 
458
def significant(token):
 
459
    """
 
460
    Determine if the token is 'significant', i.e. that it is not a comment and
 
461
    not whitespace.
 
462
    """
 
463
    # comment has 'None' is_whitespace() result.  intentional?
 
464
    return (not isinstance(token, Comment) and not token.is_whitespace())
 
465
 
 
466
 
 
467
 
 
468
def iterSignificant(tokenList):
 
469
    """
 
470
    Iterate tokens that pass the test given by L{significant}, from a given
 
471
    L{TokenList}.
 
472
    """
 
473
    for token in tokenList.tokens:
 
474
        if significant(token):
 
475
            yield token
 
476
 
 
477
 
 
478
 
 
479
def _destringify(strval):
 
480
    """
 
481
    Convert a single-quoted SQL string into its actual represented value.
 
482
    (Assumes standards compliance, since we should be controlling all the input
 
483
    here.  The only quoting syntax respected is "''".)
 
484
    """
 
485
    return strval[1:-1].replace("''", "'")
 
486
 
 
487
 
 
488