PageRenderTime 43ms CodeModel.GetById 20ms app.highlight 9ms RepoModel.GetById 12ms app.codeStats 0ms

/Lib/lib2to3/fixes/fix_import.py

http://unladen-swallow.googlecode.com/
Python | 90 lines | 72 code | 3 blank | 15 comment | 3 complexity | 9973617c9e868b2b9afb0a609ef30b35 MD5 | raw file
 1"""Fixer for import statements.
 2If spam is being imported from the local directory, this import:
 3    from spam import eggs
 4Becomes:
 5    from .spam import eggs
 6
 7And this import:
 8    import spam
 9Becomes:
10    from . import spam
11"""
12
13# Local imports
14from .. import fixer_base
15from os.path import dirname, join, exists, pathsep
16from ..fixer_util import FromImport, syms, token
17
18
def traverse_imports(names):
    """
    Walks over all the names imported in a dotted_as_names node.

    Yields each imported module name as a string, with dotted names joined
    back together (e.g. "ham.eggs"). Names come out in source order, and
    ``as`` aliases are ignored.

    Raises AssertionError if a node of an unexpected type is encountered.
    """
    pending = [names]
    while pending:
        node = pending.pop()
        if node.type == token.NAME:
            yield node.value
        elif node.type == syms.dotted_name:
            yield "".join([ch.value for ch in node.children])
        elif node.type == syms.dotted_as_name:
            # Only the module part (first child) matters; skip 'as' + alias.
            pending.append(node.children[0])
        elif node.type == syms.dotted_as_names:
            # Children alternate name, ',', name, ...; stepping back by 2
            # from the last child picks every name. Pushing them reversed
            # means pop() returns them in source order.
            pending.extend(node.children[::-2])
        else:
            raise AssertionError("unknown node type: %r" % (node.type,))
36
37
class FixImport(fixer_base.BaseFix):
    """Rewrite implicit sibling imports as explicit relative imports.

    ``from spam import eggs`` becomes ``from .spam import eggs`` and
    ``import spam`` becomes ``from . import spam`` when ``spam`` appears to
    live next to the file being fixed (see ``probably_a_local_import``).
    """

    PATTERN = """
    import_from< 'from' imp=any 'import' ['('] any [')'] >
    |
    import_name< 'import' imp=any >
    """

    def transform(self, node, results):
        """Return the rewritten node, or None to leave *node* unchanged.

        For ``from X import ...`` the first name of X gets a leading dot,
        modified in place. For ``import X[, Y, ...]`` a new
        ``from . import ...`` node replaces the statement, but only when
        every imported module looks local.
        """
        imp = results['imp']

        if node.type == syms.import_from:
            # The module part of a from-import may be a bare name
            # ('from ham import x'), a dotted name ('from ham.eggs import x'),
            # or a single leaf such as '.' ('from . import x'). Descend to
            # the first leaf: that is where a leading dot must be attached.
            while not hasattr(imp, 'value'):
                imp = imp.children[0]
            if self.probably_a_local_import(imp.value):
                imp.value = "." + imp.value
                imp.changed()
                return node
        else:
            have_local = False
            have_absolute = False
            for mod_name in traverse_imports(imp):
                if self.probably_a_local_import(mod_name):
                    have_local = True
                else:
                    have_absolute = True
            if have_absolute:
                if have_local:
                    # We won't handle both sibling and absolute imports in the
                    # same statement at the moment.
                    self.warning(node, "absolute and local imports together")
                return

            # NOTE(review): a local dotted import such as 'import pkg.mod'
            # becomes 'from . import pkg.mod', which is not valid syntax;
            # this fixer does not currently handle that case.
            new = FromImport('.', [imp])
            new.set_prefix(node.get_prefix())
            return new

    def probably_a_local_import(self, imp_name):
        """Guess whether *imp_name* names a module beside ``self.filename``.

        Only the first component of a dotted name is checked. Returns False
        for names that are already relative (start with '.') and for files
        that are not inside a package (no ``__init__.py`` in their
        directory). Otherwise returns True when a module file (.py, .pyc,
        .so, .sl, .pyd) or a directory of that name exists next to the file.
        """
        # ``sep`` is not among the module-level os.path imports.
        from os.path import sep

        if imp_name.startswith('.'):
            # Already relative (e.g. the '.' leaf of 'from . import x');
            # prefixing another dot would turn it into a parent import.
            return False
        imp_name = imp_name.split('.', 1)[0]
        base_path = join(dirname(self.filename), imp_name)
        # If there is no __init__.py next to the file, it's not in a package,
        # so it can't use a relative import.
        if not exists(join(dirname(base_path), '__init__.py')):
            return False
        # ``base_path + sep`` names a directory, which detects sibling
        # packages. (``os.pathsep`` is the PATH-list separator and would
        # never match.)
        return any(exists(base_path + ext)
                   for ext in ['.py', sep, '.pyc', '.so', '.sl', '.pyd'])