~ubuntu-branches/ubuntu/maverick/python3.1/maverick

« back to all changes in this revision

Viewing changes to Lib/ast.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2009-03-23 00:01:27 UTC
  • Revision ID: james.westby@ubuntu.com-20090323000127-5fstfxju4ufrhthq
Tags: upstream-3.1~a1+20090322
ImportĀ upstreamĀ versionĀ 3.1~a1+20090322

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- coding: utf-8 -*-
 
2
"""
 
3
    ast
 
4
    ~~~
 
5
 
 
6
    The `ast` module helps Python applications to process trees of the Python
 
7
    abstract syntax grammar.  The abstract syntax itself might change with
 
8
    each Python release; this module helps to find out programmatically what
 
9
    the current grammar looks like and allows modifications of it.
 
10
 
 
11
    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
 
12
    a flag to the `compile()` builtin function or by using the `parse()`
 
13
    function from this module.  The result will be a tree of objects whose
 
14
    classes all inherit from `ast.AST`.
 
15
 
 
16
    A modified abstract syntax tree can be compiled into a Python code object
 
17
    using the built-in `compile()` function.
 
18
 
 
19
    Additionally various helper functions are provided that make working with
 
20
    the trees simpler.  The main intention of the helper functions and this
 
21
    module in general is to provide an easy to use interface for libraries
 
22
    that work tightly with the python syntax (template engines for example).
 
23
 
 
24
 
 
25
    :copyright: Copyright 2008 by Armin Ronacher.
 
26
    :license: Python License.
 
27
"""
 
28
from _ast import *
 
29
from _ast import __version__
 
30
 
 
31
 
 
32
def parse(expr, filename='<unknown>', mode='exec'):
 
33
    """
 
34
    Parse an expression into an AST node.
 
35
    Equivalent to compile(expr, filename, mode, PyCF_ONLY_AST).
 
36
    """
 
37
    return compile(expr, filename, mode, PyCF_ONLY_AST)
 
38
 
 
39
 
 
40
def literal_eval(node_or_string):
 
41
    """
 
42
    Safely evaluate an expression node or a string containing a Python
 
43
    expression.  The string or node provided may only consist of the following
 
44
    Python literal structures: strings, numbers, tuples, lists, dicts, booleans,
 
45
    and None.
 
46
    """
 
47
    _safe_names = {'None': None, 'True': True, 'False': False}
 
48
    if isinstance(node_or_string, str):
 
49
        node_or_string = parse(node_or_string, mode='eval')
 
50
    if isinstance(node_or_string, Expression):
 
51
        node_or_string = node_or_string.body
 
52
    def _convert(node):
 
53
        if isinstance(node, Str):
 
54
            return node.s
 
55
        elif isinstance(node, Num):
 
56
            return node.n
 
57
        elif isinstance(node, Tuple):
 
58
            return tuple(map(_convert, node.elts))
 
59
        elif isinstance(node, List):
 
60
            return list(map(_convert, node.elts))
 
61
        elif isinstance(node, Dict):
 
62
            return dict((_convert(k), _convert(v)) for k, v
 
63
                        in zip(node.keys, node.values))
 
64
        elif isinstance(node, Name):
 
65
            if node.id in _safe_names:
 
66
                return _safe_names[node.id]
 
67
        elif isinstance(node, BinOp) and \
 
68
             isinstance(node.op, (Add, Sub)) and \
 
69
             isinstance(node.right, Num) and \
 
70
             isinstance(node.right.n, complex) and \
 
71
             isinstance(node.left, Num) and \
 
72
             isinstance(node.left.n, (int, float)):
 
73
            left = node.left.n
 
74
            right = node.right.n
 
75
            if isinstance(node.op, Add):
 
76
                return left + right
 
77
            else:
 
78
                return left - right
 
79
        raise ValueError('malformed string')
 
80
    return _convert(node_or_string)
 
81
 
 
82
 
 
83
def dump(node, annotate_fields=True, include_attributes=False):
 
84
    """
 
85
    Return a formatted dump of the tree in *node*.  This is mainly useful for
 
86
    debugging purposes.  The returned string will show the names and the values
 
87
    for fields.  This makes the code impossible to evaluate, so if evaluation is
 
88
    wanted *annotate_fields* must be set to False.  Attributes such as line
 
89
    numbers and column offsets are not dumped by default.  If this is wanted,
 
90
    *include_attributes* can be set to True.
 
91
    """
 
92
    def _format(node):
 
93
        if isinstance(node, AST):
 
94
            fields = [(a, _format(b)) for a, b in iter_fields(node)]
 
95
            rv = '%s(%s' % (node.__class__.__name__, ', '.join(
 
96
                ('%s=%s' % field for field in fields)
 
97
                if annotate_fields else
 
98
                (b for a, b in fields)
 
99
            ))
 
100
            if include_attributes and node._attributes:
 
101
                rv += fields and ', ' or ' '
 
102
                rv += ', '.join('%s=%s' % (a, _format(getattr(node, a)))
 
103
                                for a in node._attributes)
 
104
            return rv + ')'
 
105
        elif isinstance(node, list):
 
106
            return '[%s]' % ', '.join(_format(x) for x in node)
 
107
        return repr(node)
 
108
    if not isinstance(node, AST):
 
109
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
 
110
    return _format(node)
 
111
 
 
112
 
 
113
def copy_location(new_node, old_node):
 
114
    """
 
115
    Copy source location (`lineno` and `col_offset` attributes) from
 
116
    *old_node* to *new_node* if possible, and return *new_node*.
 
117
    """
 
118
    for attr in 'lineno', 'col_offset':
 
119
        if attr in old_node._attributes and attr in new_node._attributes \
 
120
           and hasattr(old_node, attr):
 
121
            setattr(new_node, attr, getattr(old_node, attr))
 
122
    return new_node
 
123
 
 
124
 
 
125
def fix_missing_locations(node):
 
126
    """
 
127
    When you compile a node tree with compile(), the compiler expects lineno and
 
128
    col_offset attributes for every node that supports them.  This is rather
 
129
    tedious to fill in for generated nodes, so this helper adds these attributes
 
130
    recursively where not already set, by setting them to the values of the
 
131
    parent node.  It works recursively starting at *node*.
 
132
    """
 
133
    def _fix(node, lineno, col_offset):
 
134
        if 'lineno' in node._attributes:
 
135
            if not hasattr(node, 'lineno'):
 
136
                node.lineno = lineno
 
137
            else:
 
138
                lineno = node.lineno
 
139
        if 'col_offset' in node._attributes:
 
140
            if not hasattr(node, 'col_offset'):
 
141
                node.col_offset = col_offset
 
142
            else:
 
143
                col_offset = node.col_offset
 
144
        for child in iter_child_nodes(node):
 
145
            _fix(child, lineno, col_offset)
 
146
    _fix(node, 1, 0)
 
147
    return node
 
148
 
 
149
 
 
150
def increment_lineno(node, n=1):
 
151
    """
 
152
    Increment the line number of each node in the tree starting at *node* by *n*.
 
153
    This is useful to "move code" to a different location in a file.
 
154
    """
 
155
    if 'lineno' in node._attributes:
 
156
        node.lineno = getattr(node, 'lineno', 0) + n
 
157
    for child in walk(node):
 
158
        if 'lineno' in child._attributes:
 
159
            child.lineno = getattr(child, 'lineno', 0) + n
 
160
    return node
 
161
 
 
162
 
 
163
def iter_fields(node):
 
164
    """
 
165
    Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
 
166
    that is present on *node*.
 
167
    """
 
168
    for field in node._fields:
 
169
        try:
 
170
            yield field, getattr(node, field)
 
171
        except AttributeError:
 
172
            pass
 
173
 
 
174
 
 
175
def iter_child_nodes(node):
 
176
    """
 
177
    Yield all direct child nodes of *node*, that is, all fields that are nodes
 
178
    and all items of fields that are lists of nodes.
 
179
    """
 
180
    for name, field in iter_fields(node):
 
181
        if isinstance(field, AST):
 
182
            yield field
 
183
        elif isinstance(field, list):
 
184
            for item in field:
 
185
                if isinstance(item, AST):
 
186
                    yield item
 
187
 
 
188
 
 
189
def get_docstring(node, clean=True):
 
190
    """
 
191
    Return the docstring for the given node or None if no docstring can
 
192
    be found.  If the node provided does not have docstrings a TypeError
 
193
    will be raised.
 
194
    """
 
195
    if not isinstance(node, (FunctionDef, ClassDef, Module)):
 
196
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
 
197
    if node.body and isinstance(node.body[0], Expr) and \
 
198
       isinstance(node.body[0].value, Str):
 
199
        if clean:
 
200
            import inspect
 
201
            return inspect.cleandoc(node.body[0].value.s)
 
202
        return node.body[0].value.s
 
203
 
 
204
 
 
205
def walk(node):
 
206
    """
 
207
    Recursively yield all child nodes of *node*, in no specified order.  This is
 
208
    useful if you only want to modify nodes in place and don't care about the
 
209
    context.
 
210
    """
 
211
    from collections import deque
 
212
    todo = deque([node])
 
213
    while todo:
 
214
        node = todo.popleft()
 
215
        todo.extend(iter_child_nodes(node))
 
216
        yield node
 
217
 
 
218
 
 
219
class NodeVisitor(object):
 
220
    """
 
221
    A node visitor base class that walks the abstract syntax tree and calls a
 
222
    visitor function for every node found.  This function may return a value
 
223
    which is forwarded by the `visit` method.
 
224
 
 
225
    This class is meant to be subclassed, with the subclass adding visitor
 
226
    methods.
 
227
 
 
228
    Per default the visitor functions for the nodes are ``'visit_'`` +
 
229
    class name of the node.  So a `TryFinally` node visit function would
 
230
    be `visit_TryFinally`.  This behavior can be changed by overriding
 
231
    the `visit` method.  If no visitor function exists for a node
 
232
    (return value `None`) the `generic_visit` visitor is used instead.
 
233
 
 
234
    Don't use the `NodeVisitor` if you want to apply changes to nodes during
 
235
    traversing.  For this a special visitor exists (`NodeTransformer`) that
 
236
    allows modifications.
 
237
    """
 
238
 
 
239
    def visit(self, node):
 
240
        """Visit a node."""
 
241
        method = 'visit_' + node.__class__.__name__
 
242
        visitor = getattr(self, method, self.generic_visit)
 
243
        return visitor(node)
 
244
 
 
245
    def generic_visit(self, node):
 
246
        """Called if no explicit visitor function exists for a node."""
 
247
        for field, value in iter_fields(node):
 
248
            if isinstance(value, list):
 
249
                for item in value:
 
250
                    if isinstance(item, AST):
 
251
                        self.visit(item)
 
252
            elif isinstance(value, AST):
 
253
                self.visit(value)
 
254
 
 
255
 
 
256
class NodeTransformer(NodeVisitor):
 
257
    """
 
258
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
 
259
    allows modification of nodes.
 
260
 
 
261
    The `NodeTransformer` will walk the AST and use the return value of the
 
262
    visitor methods to replace or remove the old node.  If the return value of
 
263
    the visitor method is ``None``, the node will be removed from its location,
 
264
    otherwise it is replaced with the return value.  The return value may be the
 
265
    original node in which case no replacement takes place.
 
266
 
 
267
    Here is an example transformer that rewrites all occurrences of name lookups
 
268
    (``foo``) to ``data['foo']``::
 
269
 
 
270
       class RewriteName(NodeTransformer):
 
271
 
 
272
           def visit_Name(self, node):
 
273
               return copy_location(Subscript(
 
274
                   value=Name(id='data', ctx=Load()),
 
275
                   slice=Index(value=Str(s=node.id)),
 
276
                   ctx=node.ctx
 
277
               ), node)
 
278
 
 
279
    Keep in mind that if the node you're operating on has child nodes you must
 
280
    either transform the child nodes yourself or call the :meth:`generic_visit`
 
281
    method for the node first.
 
282
 
 
283
    For nodes that were part of a collection of statements (that applies to all
 
284
    statement nodes), the visitor may also return a list of nodes rather than
 
285
    just a single node.
 
286
 
 
287
    Usually you use the transformer like this::
 
288
 
 
289
       node = YourTransformer().visit(node)
 
290
    """
 
291
 
 
292
    def generic_visit(self, node):
 
293
        for field, old_value in iter_fields(node):
 
294
            old_value = getattr(node, field, None)
 
295
            if isinstance(old_value, list):
 
296
                new_values = []
 
297
                for value in old_value:
 
298
                    if isinstance(value, AST):
 
299
                        value = self.visit(value)
 
300
                        if value is None:
 
301
                            continue
 
302
                        elif not isinstance(value, AST):
 
303
                            new_values.extend(value)
 
304
                            continue
 
305
                    new_values.append(value)
 
306
                old_value[:] = new_values
 
307
            elif isinstance(old_value, AST):
 
308
                new_node = self.visit(old_value)
 
309
                if new_node is None:
 
310
                    delattr(node, field)
 
311
                else:
 
312
                    setattr(node, field, new_node)
 
313
        return node