/contrib/python/tests/python/pants_test/contrib/python/checks/tasks/checkstyle/test_import_order.py

https://gitlab.com/Ivy001/pants · Python · 152 lines · 120 code · 29 blank · 3 comment · 5 complexity · a8d32a2511b7a11db06a4826969b8a75 MD5 · raw file

  1. # coding=utf-8
  2. # Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
  3. # Licensed under the Apache License, Version 2.0 (see LICENSE).
  4. from __future__ import (absolute_import, division, generators, nested_scopes, print_function,
  5. unicode_literals, with_statement)
  6. import ast
  7. import textwrap
  8. from pants_test.contrib.python.checks.tasks.checkstyle.plugin_test_base import \
  9. CheckstylePluginTestBase
  10. from pants.contrib.python.checks.tasks.checkstyle.common import Nit
  11. from pants.contrib.python.checks.tasks.checkstyle.import_order import ImportOrder, ImportType
  12. IMPORT_CHUNKS = {
  13. ImportType.STDLIB: """
  14. import ast
  15. from collections import namedtuple
  16. import io
  17. """,
  18. ImportType.TWITTER: """
  19. from twitter.common import app
  20. from twitter.common.dirutil import (
  21. safe_mkdtemp,
  22. safe_open,
  23. safe_rmtree)
  24. """,
  25. ImportType.GEN: """
  26. from gen.twitter.aurora.ttypes import TwitterTaskInfo
  27. """,
  28. ImportType.PACKAGE: """
  29. from .import_order import (
  30. ImportOrder,
  31. ImportType
  32. )
  33. """,
  34. ImportType.THIRD_PARTY: """
  35. from kazoo.client import KazooClient
  36. import zookeeper
  37. """,
  38. }
  39. def strip_newline(stmt):
  40. return textwrap.dedent('\n'.join(filter(None, stmt.splitlines())))
  41. def stitch_chunks(newlines, *chunks):
  42. return ('\n' * newlines).join([strip_newline(IMPORT_CHUNKS.get(c)) for c in chunks])
  43. class ImportOrderTest(CheckstylePluginTestBase):
  44. plugin_type = ImportOrder
  45. def get_import_chunk_types(self, import_type):
  46. chunks = list(self.get_plugin(IMPORT_CHUNKS[import_type]).iter_import_chunks())
  47. self.assertEqual(1, len(chunks))
  48. return tuple(map(type, chunks[0]))
  49. def test_classify_import_chunks(self):
  50. self.assertEqual((ast.Import, ast.ImportFrom, ast.Import),
  51. self.get_import_chunk_types(ImportType.STDLIB))
  52. self.assertEqual((ast.ImportFrom, ast.ImportFrom),
  53. self.get_import_chunk_types(ImportType.TWITTER))
  54. self.assertEqual((ast.ImportFrom,),
  55. self.get_import_chunk_types(ImportType.GEN))
  56. self.assertEqual((ast.ImportFrom,),
  57. self.get_import_chunk_types(ImportType.PACKAGE))
  58. self.assertEqual((ast.ImportFrom, ast.Import),
  59. self.get_import_chunk_types(ImportType.THIRD_PARTY))
  60. def test_classify_import(self):
  61. for import_type, chunk in IMPORT_CHUNKS.items():
  62. io = self.get_plugin(chunk)
  63. import_chunks = list(io.iter_import_chunks())
  64. self.assertEqual(1, len(import_chunks))
  65. module_types, chunk_errors = io.classify_imports(import_chunks[0])
  66. self.assertEqual(1, len(module_types))
  67. self.assertEqual(import_type, module_types.pop())
  68. self.assertEqual([], chunk_errors)
  69. PAIRS = (
  70. (ImportType.STDLIB, ImportType.TWITTER),
  71. (ImportType.TWITTER, ImportType.GEN),
  72. (ImportType.PACKAGE, ImportType.THIRD_PARTY),
  73. )
  74. def test_pairwise_classify(self):
  75. for first, second in self.PAIRS:
  76. io = self.get_plugin(stitch_chunks(1, first, second))
  77. import_chunks = list(io.iter_import_chunks())
  78. self.assertEqual(2, len(import_chunks))
  79. module_types, chunk_errors = io.classify_imports(import_chunks[0])
  80. self.assertEqual(1, len(module_types))
  81. self.assertEqual(0, len(chunk_errors))
  82. self.assertEqual(first, module_types.pop())
  83. module_types, chunk_errors = io.classify_imports(import_chunks[1])
  84. self.assertEqual(1, len(module_types))
  85. self.assertEqual(0, len(chunk_errors))
  86. self.assertEqual(second, module_types.pop())
  87. for second, first in self.PAIRS:
  88. io = self.get_plugin(stitch_chunks(1, first, second))
  89. import_chunks = list(io.iter_import_chunks())
  90. self.assertEqual(2, len(import_chunks))
  91. nits = list(io.nits())
  92. self.assertEqual(1, len(nits))
  93. self.assertEqual('T406', nits[0].code)
  94. self.assertEqual(Nit.ERROR, nits[0].severity)
  95. def test_multiple_imports_error(self):
  96. io = self.get_plugin(stitch_chunks(0, ImportType.STDLIB, ImportType.TWITTER))
  97. import_chunks = list(io.iter_import_chunks())
  98. self.assertEqual(1, len(import_chunks))
  99. module_types, chunk_errors = io.classify_imports(import_chunks[0])
  100. self.assertEqual(1, len(chunk_errors))
  101. self.assertEqual('T405', chunk_errors[0].code)
  102. self.assertEqual(Nit.ERROR, chunk_errors[0].severity)
  103. self.assertItemsEqual([ImportType.STDLIB, ImportType.TWITTER], module_types)
  104. io = self.get_plugin("""
  105. import io, pkg_resources
  106. """)
  107. import_chunks = list(io.iter_import_chunks())
  108. self.assertEqual(1, len(import_chunks))
  109. module_types, chunk_errors = io.classify_imports(import_chunks[0])
  110. self.assertEqual(3, len(chunk_errors))
  111. self.assertItemsEqual(['T403', 'T405', 'T402'],
  112. [chunk_error.code for chunk_error in chunk_errors])
  113. self.assertItemsEqual([ImportType.STDLIB, ImportType.THIRD_PARTY], module_types)
  114. def test_import_lexical_order(self):
  115. imp = """
  116. from twitter.common.dirutil import safe_rmtree, safe_mkdtemp
  117. """
  118. self.assertNit(imp, 'T401')
  119. def test_import_wildcard(self):
  120. imp = """
  121. from twitter.common.dirutil import *
  122. """
  123. self.assertNit(imp, 'T400')