~ubuntu-branches/ubuntu/lucid/python2.6/lucid

« back to all changes in this revision

Viewing changes to Lib/lib2to3/refactor.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2010-03-11 13:30:19 UTC
  • mto: (10.1.13 sid)
  • mto: This revision was merged to the branch mainline in revision 44.
  • Revision ID: james.westby@ubuntu.com-20100311133019-sblbooa3uqrkoe70
Tags: upstream-2.6.5~rc2
Import upstream version 2.6.5~rc2

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
#!/usr/bin/env python2.5
2
1
# Copyright 2006 Google, Inc. All Rights Reserved.
3
2
# Licensed to PSF under a Contributor Agreement.
4
3
 
15
14
# Python imports
16
15
import os
17
16
import sys
18
 
import difflib
19
17
import logging
20
18
import operator
21
 
from collections import defaultdict
 
19
import collections
 
20
import StringIO
22
21
from itertools import chain
23
22
 
24
23
# Local imports
25
 
from .pgen2 import driver
26
 
from .pgen2 import tokenize
27
 
 
28
 
from . import pytree
29
 
from . import patcomp
30
 
from . import fixes
31
 
from . import pygram
 
24
from .pgen2 import driver, tokenize, token
 
25
from . import pytree, pygram
32
26
 
33
27
 
34
28
def get_all_fix_names(fixer_pkg, remove_prefix=True):
43
37
            fix_names.append(name[:-3])
44
38
    return fix_names
45
39
 
46
 
def get_head_types(pat):
 
40
 
 
41
class _EveryNode(Exception):
 
42
    pass
 
43
 
 
44
 
 
45
def _get_head_types(pat):
47
46
    """ Accepts a pytree Pattern Node and returns a set
48
47
        of the pattern types which will match first. """
49
48
 
51
50
        # NodePatters must either have no type and no content
52
51
        #   or a type and content -- so they don't get any farther
53
52
        # Always return leafs
 
53
        if pat.type is None:
 
54
            raise _EveryNode
54
55
        return set([pat.type])
55
56
 
56
57
    if isinstance(pat, pytree.NegatedPattern):
57
58
        if pat.content:
58
 
            return get_head_types(pat.content)
59
 
        return set([None]) # Negated Patterns don't have a type
 
59
            return _get_head_types(pat.content)
 
60
        raise _EveryNode # Negated Patterns don't have a type
60
61
 
61
62
    if isinstance(pat, pytree.WildcardPattern):
62
63
        # Recurse on each node in content
63
64
        r = set()
64
65
        for p in pat.content:
65
66
            for x in p:
66
 
                r.update(get_head_types(x))
 
67
                r.update(_get_head_types(x))
67
68
        return r
68
69
 
69
70
    raise Exception("Oh no! I don't understand pattern %s" %(pat))
70
71
 
71
 
def get_headnode_dict(fixer_list):
 
72
 
 
73
def _get_headnode_dict(fixer_list):
72
74
    """ Accepts a list of fixers and returns a dictionary
73
75
        of head node type --> fixer list.  """
74
 
    head_nodes = defaultdict(list)
 
76
    head_nodes = collections.defaultdict(list)
 
77
    every = []
75
78
    for fixer in fixer_list:
76
 
        if not fixer.pattern:
77
 
            head_nodes[None].append(fixer)
78
 
            continue
79
 
        for t in get_head_types(fixer.pattern):
80
 
            head_nodes[t].append(fixer)
81
 
    return head_nodes
 
79
        if fixer.pattern:
 
80
            try:
 
81
                heads = _get_head_types(fixer.pattern)
 
82
            except _EveryNode:
 
83
                every.append(fixer)
 
84
            else:
 
85
                for node_type in heads:
 
86
                    head_nodes[node_type].append(fixer)
 
87
        else:
 
88
            if fixer._accept_type is not None:
 
89
                head_nodes[fixer._accept_type].append(fixer)
 
90
            else:
 
91
                every.append(fixer)
 
92
    for node_type in chain(pygram.python_grammar.symbol2number.itervalues(),
 
93
                           pygram.python_grammar.tokens):
 
94
        head_nodes[node_type].extend(every)
 
95
    return dict(head_nodes)
 
96
 
82
97
 
83
98
def get_fixers_from_package(pkg_name):
84
99
    """
87
102
    return [pkg_name + "." + fix_name
88
103
            for fix_name in get_all_fix_names(pkg_name, False)]
89
104
 
 
105
def _identity(obj):
 
106
    return obj
 
107
 
 
108
if sys.version_info < (3, 0):
 
109
    import codecs
 
110
    _open_with_encoding = codecs.open
 
111
    # codecs.open doesn't translate newlines sadly.
 
112
    def _from_system_newlines(input):
 
113
        return input.replace(u"\r\n", u"\n")
 
114
    def _to_system_newlines(input):
 
115
        if os.linesep != "\n":
 
116
            return input.replace(u"\n", os.linesep)
 
117
        else:
 
118
            return input
 
119
else:
 
120
    _open_with_encoding = open
 
121
    _from_system_newlines = _identity
 
122
    _to_system_newlines = _identity
 
123
 
 
124
 
 
125
def _detect_future_print(source):
 
126
    have_docstring = False
 
127
    gen = tokenize.generate_tokens(StringIO.StringIO(source).readline)
 
128
    def advance():
 
129
        tok = next(gen)
 
130
        return tok[0], tok[1]
 
131
    ignore = frozenset((token.NEWLINE, tokenize.NL, token.COMMENT))
 
132
    try:
 
133
        while True:
 
134
            tp, value = advance()
 
135
            if tp in ignore:
 
136
                continue
 
137
            elif tp == token.STRING:
 
138
                if have_docstring:
 
139
                    break
 
140
                have_docstring = True
 
141
            elif tp == token.NAME and value == u"from":
 
142
                tp, value = advance()
 
143
                if tp != token.NAME and value != u"__future__":
 
144
                    break
 
145
                tp, value = advance()
 
146
                if tp != token.NAME and value != u"import":
 
147
                    break
 
148
                tp, value = advance()
 
149
                if tp == token.OP and value == u"(":
 
150
                    tp, value = advance()
 
151
                while tp == token.NAME:
 
152
                    if value == u"print_function":
 
153
                        return True
 
154
                    tp, value = advance()
 
155
                    if tp != token.OP and value != u",":
 
156
                        break
 
157
                    tp, value = advance()
 
158
            else:
 
159
                break
 
160
    except StopIteration:
 
161
        pass
 
162
    return False
 
163
 
90
164
 
91
165
class FixerError(Exception):
92
166
    """A fixer could not be loaded."""
94
168
 
95
169
class RefactoringTool(object):
96
170
 
97
 
    _default_options = {"print_function": False}
 
171
    _default_options = {"print_function" : False}
98
172
 
99
173
    CLASS_PREFIX = "Fix" # The prefix for fixer classes
100
174
    FILE_PREFIX = "fix_" # The prefix for modules with a fixer within
112
186
        self.options = self._default_options.copy()
113
187
        if options is not None:
114
188
            self.options.update(options)
 
189
        if self.options["print_function"]:
 
190
            self.grammar = pygram.python_grammar_no_print_statement
 
191
        else:
 
192
            self.grammar = pygram.python_grammar
115
193
        self.errors = []
116
194
        self.logger = logging.getLogger("RefactoringTool")
117
195
        self.fixer_log = []
118
196
        self.wrote = False
119
 
        if self.options["print_function"]:
120
 
            del pygram.python_grammar.keywords["print"]
121
 
        self.driver = driver.Driver(pygram.python_grammar,
 
197
        self.driver = driver.Driver(self.grammar,
122
198
                                    convert=pytree.convert,
123
199
                                    logger=self.logger)
124
200
        self.pre_order, self.post_order = self.get_fixers()
125
201
 
126
 
        self.pre_order_heads = get_headnode_dict(self.pre_order)
127
 
        self.post_order_heads = get_headnode_dict(self.post_order)
 
202
        self.pre_order_heads = _get_headnode_dict(self.pre_order)
 
203
        self.post_order_heads = _get_headnode_dict(self.post_order)
128
204
 
129
205
        self.files = []  # List of files that were or should be modified
130
206
 
183
259
            msg = msg % args
184
260
        self.logger.debug(msg)
185
261
 
186
 
    def print_output(self, lines):
187
 
        """Called with lines of output to give to the user."""
 
262
    def print_output(self, old_text, new_text, filename, equal):
 
263
        """Called with the old version, new version, and filename of a
 
264
        refactored file."""
188
265
        pass
189
266
 
190
267
    def refactor(self, items, write=False, doctests_only=False):
207
284
            dirnames.sort()
208
285
            filenames.sort()
209
286
            for name in filenames:
210
 
                if not name.startswith(".") and name.endswith("py"):
 
287
                if not name.startswith(".") and \
 
288
                        os.path.splitext(name)[1].endswith("py"):
211
289
                    fullname = os.path.join(dirpath, name)
212
290
                    self.refactor_file(fullname, write, doctests_only)
213
291
            # Modify dirnames in-place to remove subdirs with leading dots
214
292
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
215
293
 
 
294
    def _read_python_source(self, filename):
 
295
        """
 
296
        Do our best to decode a Python source file correctly.
 
297
        """
 
298
        try:
 
299
            f = open(filename, "rb")
 
300
        except IOError, err:
 
301
            self.log_error("Can't open %s: %s", filename, err)
 
302
            return None, None
 
303
        try:
 
304
            encoding = tokenize.detect_encoding(f.readline)[0]
 
305
        finally:
 
306
            f.close()
 
307
        with _open_with_encoding(filename, "r", encoding=encoding) as f:
 
308
            return _from_system_newlines(f.read()), encoding
 
309
 
216
310
    def refactor_file(self, filename, write=False, doctests_only=False):
217
311
        """Refactors a file."""
218
 
        try:
219
 
            f = open(filename)
220
 
        except IOError, err:
221
 
            self.log_error("Can't open %s: %s", filename, err)
 
312
        input, encoding = self._read_python_source(filename)
 
313
        if input is None:
 
314
            # Reading the file failed.
222
315
            return
223
 
        try:
224
 
            input = f.read() + "\n" # Silence certain parse errors
225
 
        finally:
226
 
            f.close()
 
316
        input += u"\n" # Silence certain parse errors
227
317
        if doctests_only:
228
318
            self.log_debug("Refactoring doctests in %s", filename)
229
319
            output = self.refactor_docstring(input, filename)
230
320
            if output != input:
231
 
                self.processed_file(output, filename, input, write=write)
 
321
                self.processed_file(output, filename, input, write, encoding)
232
322
            else:
233
323
                self.log_debug("No doctest changes in %s", filename)
234
324
        else:
235
325
            tree = self.refactor_string(input, filename)
236
326
            if tree and tree.was_changed:
237
327
                # The [:-1] is to take off the \n we added earlier
238
 
                self.processed_file(str(tree)[:-1], filename, write=write)
 
328
                self.processed_file(unicode(tree)[:-1], filename,
 
329
                                    write=write, encoding=encoding)
239
330
            else:
240
331
                self.log_debug("No changes in %s", filename)
241
332
 
250
341
            An AST corresponding to the refactored input stream; None if
251
342
            there were errors during the parse.
252
343
        """
 
344
        if _detect_future_print(data):
 
345
            self.driver.grammar = pygram.python_grammar_no_print_statement
253
346
        try:
254
347
            tree = self.driver.parse_string(data)
255
348
        except Exception, err:
256
349
            self.log_error("Can't parse %s: %s: %s",
257
350
                           name, err.__class__.__name__, err)
258
351
            return
 
352
        finally:
 
353
            self.driver.grammar = self.grammar
259
354
        self.log_debug("Refactoring %s", name)
260
355
        self.refactor_tree(tree, name)
261
356
        return tree
272
367
        else:
273
368
            tree = self.refactor_string(input, "<stdin>")
274
369
            if tree and tree.was_changed:
275
 
                self.processed_file(str(tree), "<stdin>", input)
 
370
                self.processed_file(unicode(tree), "<stdin>", input)
276
371
            else:
277
372
                self.log_debug("No changes in stdin")
278
373
 
312
407
        if not fixers:
313
408
            return
314
409
        for node in traversal:
315
 
            for fixer in fixers[node.type] + fixers[None]:
 
410
            for fixer in fixers[node.type]:
316
411
                results = fixer.match(node)
317
412
                if results:
318
413
                    new = fixer.transform(node, results)
319
 
                    if new is not None and (new != node or
320
 
                                            str(new) != str(node)):
 
414
                    if new is not None:
321
415
                        node.replace(new)
322
416
                        node = new
323
417
 
324
 
    def processed_file(self, new_text, filename, old_text=None, write=False):
 
418
    def processed_file(self, new_text, filename, old_text=None, write=False,
 
419
                       encoding=None):
325
420
        """
326
421
        Called when a file has been refactored, and there are changes.
327
422
        """
328
423
        self.files.append(filename)
329
424
        if old_text is None:
330
 
            try:
331
 
                f = open(filename, "r")
332
 
            except IOError, err:
333
 
                self.log_error("Can't read %s: %s", filename, err)
 
425
            old_text = self._read_python_source(filename)[0]
 
426
            if old_text is None:
334
427
                return
335
 
            try:
336
 
                old_text = f.read()
337
 
            finally:
338
 
                f.close()
339
 
        if old_text == new_text:
 
428
        equal = old_text == new_text
 
429
        self.print_output(old_text, new_text, filename, equal)
 
430
        if equal:
340
431
            self.log_debug("No changes to %s", filename)
341
432
            return
342
 
        self.print_output(diff_texts(old_text, new_text, filename))
343
433
        if write:
344
 
            self.write_file(new_text, filename, old_text)
 
434
            self.write_file(new_text, filename, old_text, encoding)
345
435
        else:
346
436
            self.log_debug("Not writing changes to %s", filename)
347
437
 
348
 
    def write_file(self, new_text, filename, old_text):
 
438
    def write_file(self, new_text, filename, old_text, encoding=None):
349
439
        """Writes a string to a file.
350
440
 
351
441
        It first shows a unified diff between the old text and the new text, and
353
443
        set.
354
444
        """
355
445
        try:
356
 
            f = open(filename, "w")
 
446
            f = _open_with_encoding(filename, "w", encoding=encoding)
357
447
        except os.error, err:
358
448
            self.log_error("Can't create %s: %s", filename, err)
359
449
            return
360
450
        try:
361
 
            f.write(new_text)
 
451
            f.write(_to_system_newlines(new_text))
362
452
        except os.error, err:
363
453
            self.log_error("Can't write %s: %s", filename, err)
364
454
        finally:
398
488
                indent = line[:i]
399
489
            elif (indent is not None and
400
490
                  (line.startswith(indent + self.PS2) or
401
 
                   line == indent + self.PS2.rstrip() + "\n")):
 
491
                   line == indent + self.PS2.rstrip() + u"\n")):
402
492
                block.append(line)
403
493
            else:
404
494
                if block is not None:
410
500
        if block is not None:
411
501
            result.extend(self.refactor_doctest(block, block_lineno,
412
502
                                                indent, filename))
413
 
        return "".join(result)
 
503
        return u"".join(result)
414
504
 
415
505
    def refactor_doctest(self, block, lineno, indent, filename):
416
506
        """Refactors one doctest.
425
515
        except Exception, err:
426
516
            if self.log.isEnabledFor(logging.DEBUG):
427
517
                for line in block:
428
 
                    self.log_debug("Source: %s", line.rstrip("\n"))
 
518
                    self.log_debug("Source: %s", line.rstrip(u"\n"))
429
519
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
430
520
                           filename, lineno, err.__class__.__name__, err)
431
521
            return block
432
522
        if self.refactor_tree(tree, filename):
433
 
            new = str(tree).splitlines(True)
 
523
            new = unicode(tree).splitlines(True)
434
524
            # Undo the adjustment of the line numbers in wrap_toks() below.
435
525
            clipped, new = new[:lineno-1], new[lineno-1:]
436
 
            assert clipped == ["\n"] * (lineno-1), clipped
437
 
            if not new[-1].endswith("\n"):
438
 
                new[-1] += "\n"
 
526
            assert clipped == [u"\n"] * (lineno-1), clipped
 
527
            if not new[-1].endswith(u"\n"):
 
528
                new[-1] += u"\n"
439
529
            block = [indent + self.PS1 + new.pop(0)]
440
530
            if new:
441
531
                block += [indent + self.PS2 + line for line in new]
497
587
        for line in block:
498
588
            if line.startswith(prefix):
499
589
                yield line[len(prefix):]
500
 
            elif line == prefix.rstrip() + "\n":
501
 
                yield "\n"
 
590
            elif line == prefix.rstrip() + u"\n":
 
591
                yield u"\n"
502
592
            else:
503
593
                raise AssertionError("line=%r, prefix=%r" % (line, prefix))
504
594
            prefix = prefix2
506
596
            yield ""
507
597
 
508
598
 
509
 
def diff_texts(a, b, filename):
510
 
    """Return a unified diff of two strings."""
511
 
    a = a.splitlines()
512
 
    b = b.splitlines()
513
 
    return difflib.unified_diff(a, b, filename, filename,
514
 
                                "(original)", "(refactored)",
515
 
                                lineterm="")
 
599
class MultiprocessingUnsupported(Exception):
 
600
    pass
 
601
 
 
602
 
 
603
class MultiprocessRefactoringTool(RefactoringTool):
 
604
 
 
605
    def __init__(self, *args, **kwargs):
 
606
        super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs)
 
607
        self.queue = None
 
608
 
 
609
    def refactor(self, items, write=False, doctests_only=False,
 
610
                 num_processes=1):
 
611
        if num_processes == 1:
 
612
            return super(MultiprocessRefactoringTool, self).refactor(
 
613
                items, write, doctests_only)
 
614
        try:
 
615
            import multiprocessing
 
616
        except ImportError:
 
617
            raise MultiprocessingUnsupported
 
618
        if self.queue is not None:
 
619
            raise RuntimeError("already doing multiple processes")
 
620
        self.queue = multiprocessing.JoinableQueue()
 
621
        processes = [multiprocessing.Process(target=self._child)
 
622
                     for i in xrange(num_processes)]
 
623
        try:
 
624
            for p in processes:
 
625
                p.start()
 
626
            super(MultiprocessRefactoringTool, self).refactor(items, write,
 
627
                                                              doctests_only)
 
628
        finally:
 
629
            self.queue.join()
 
630
            for i in xrange(num_processes):
 
631
                self.queue.put(None)
 
632
            for p in processes:
 
633
                if p.is_alive():
 
634
                    p.join()
 
635
            self.queue = None
 
636
 
 
637
    def _child(self):
 
638
        task = self.queue.get()
 
639
        while task is not None:
 
640
            args, kwargs = task
 
641
            try:
 
642
                super(MultiprocessRefactoringTool, self).refactor_file(
 
643
                    *args, **kwargs)
 
644
            finally:
 
645
                self.queue.task_done()
 
646
            task = self.queue.get()
 
647
 
 
648
    def refactor_file(self, *args, **kwargs):
 
649
        if self.queue is not None:
 
650
            self.queue.put((args, kwargs))
 
651
        else:
 
652
            return super(MultiprocessRefactoringTool, self).refactor_file(
 
653
                *args, **kwargs)