PageRenderTime 47ms CodeModel.GetById 20ms RepoModel.GetById 0ms app.codeStats 0ms

/contrib/scrooge/src/python/pants/contrib/scrooge/tasks/scrooge_gen.py

https://gitlab.com/Ivy001/pants
Python | 267 lines | 241 code | 19 blank | 7 comment | 8 complexity | 98b4c4735f3326e31ed6b626d4f8f95f MD5 | raw file
  1. # coding=utf-8
  2. # Copyright 2014 Pants project contributors (see CONTRIBUTORS.md).
  3. # Licensed under the Apache License, Version 2.0 (see LICENSE).
  4. from __future__ import (absolute_import, division, generators, nested_scopes, print_function,
  5. unicode_literals, with_statement)
  6. import os
  7. import re
  8. import tempfile
  9. from collections import defaultdict, namedtuple
  10. from pants.backend.codegen.subsystems.thrift_defaults import ThriftDefaults
  11. from pants.backend.codegen.targets.java_thrift_library import JavaThriftLibrary
  12. from pants.backend.codegen.tasks.simple_codegen_task import SimpleCodegenTask
  13. from pants.backend.jvm.tasks.nailgun_task import NailgunTask
  14. from pants.base.exceptions import TargetDefinitionException, TaskError
  15. from pants.build_graph.address import Address
  16. from pants.build_graph.address_lookup_error import AddressLookupError
  17. from pants.util.dirutil import safe_mkdir, safe_open
  18. from pants.util.memo import memoized_method, memoized_property
  19. from twitter.common.collections import OrderedSet
  20. from pants.contrib.scrooge.tasks.thrift_util import calculate_compile_sources
  21. _RPC_STYLES = frozenset(['sync', 'finagle', 'ostrich'])
  22. class ScroogeGen(SimpleCodegenTask, NailgunTask):
  # Resolved dependency maps: `service` is used for synthetic libraries whose thrift declares a
  # service, `structs` for all others (see _thrift_dependencies_for_target).
  DepInfo = namedtuple('DepInfo', 'service structs')
  # The per-target portion of a Scrooge invocation, built in execute_codegen and consumed by gen.
  PartialCmd = namedtuple('PartialCmd', 'language rpc_style namespace_map')
  @classmethod
  def register_options(cls, register):
    """Register ScroogeGen's options and its `scrooge-gen` JVM tool."""
    super(ScroogeGen, cls).register_options(register)

    register('--verbose', type=bool, help='Emit verbose output.')
    register('--strict', fingerprint=True, type=bool,
             help='Enable strict compilation.')

    # Per-language extra dependencies for synthetic targets; which map applies depends on
    # whether the thrift sources declare a service.
    dependency_options = (
      ('--service-deps', 'A map of language to targets to add as dependencies of '
                         'synthetic thrift libraries that contain services.'),
      ('--structs-deps', 'A map of language to targets to add as dependencies of '
                         'synthetic thrift libraries that contain structs.'),
    )
    for flag, help_text in dependency_options:
      register(flag, default={}, advanced=True, type=dict, help=help_text)

    default_target_types = {'scala': 'scala_library', 'java': 'java_library', 'android': 'java_library'}
    register('--target-types', default=default_target_types, advanced=True, type=dict,
             help='Registered target types.')

    cls.register_jvm_tool(register, 'scrooge-gen')
  @classmethod
  def global_subsystems(cls):
    """Return the inherited global subsystems plus ThriftDefaults."""
    inherited = super(ScroogeGen, cls).global_subsystems()
    return inherited + (ThriftDefaults,)
  @classmethod
  def product_types(cls):
    """Return this task's product types: 'java' and 'scala'."""
    produced = ['java', 'scala']
    return produced
  @classmethod
  def implementation_version(cls):
    """Append this task's ('ScroogeGen', 3) entry to the inherited implementation version."""
    versions = super(ScroogeGen, cls).implementation_version()
    return versions + [('ScroogeGen', 3)]
  def __init__(self, *args, **kwargs):
    super(ScroogeGen, self).__init__(*args, **kwargs)
    # Global ThriftDefaults, queried for each target's language, rpc_style and compiler.
    self._thrift_defaults = ThriftDefaults.global_instance()
    # NOTE(review): not read in the visible code (dependency info comes from the memoized
    # `_resolved_dep_info` property); kept in case code elsewhere relies on the attribute.
    self._depinfo = None
  # TODO(benjy): Use regular os-located tmpfiles, as we do everywhere else.
  def _tempname(self):
    """Create an empty temp file under <pants_workdir>/tmp and return its path.

    The OS-level descriptor is closed before returning; the caller owns the file.
    """
    # Anchor under pants_workdir instead of assuming the user's cwd is the buildroot.
    tmp_dir = os.path.join(self.get_options().pants_workdir, 'tmp')
    safe_mkdir(tmp_dir)
    handle, path = tempfile.mkstemp(prefix='', dir=tmp_dir)
    os.close(handle)
    return path
  def _resolve_deps(self, depmap):
    """Given a map of gen-key=>target specs, resolves the target specs into references.

    :param dict depmap: maps a category key (the --service-deps / --structs-deps options
                        describe it as a language) to an iterable of target address specs.
    :returns: a defaultdict mapping each category to an OrderedSet of resolved targets, in
              spec order. Categories absent from `depmap` yield an empty OrderedSet on lookup.
    :raises AddressLookupError: if a spec cannot be resolved; the message is extended with this
                                task's options scope so the offending option can be found.
    """
    # OrderedSet itself is the factory; no lambda wrapper is needed.
    deps = defaultdict(OrderedSet)
    build_graph = self.context.build_graph
    for category, depspecs in depmap.items():
      # Touch the key so a category with no specs still maps to an (empty) set.
      dependencies = deps[category]
      for depspec in depspecs:
        dep_address = Address.parse(depspec)
        try:
          build_graph.maybe_inject_address_closure(dep_address)
          dependencies.add(build_graph.get_target(dep_address))
        except AddressLookupError as e:
          raise AddressLookupError('{}\n referenced from {} scope'.format(e, self.options_scope))
    return deps
  def _validate_language(self, target):
    """Return `target`'s thrift language if it is a key of the --target-types option.

    :raises TargetDefinitionException: when no target type is registered for the language.
    """
    lang = self._thrift_defaults.language(target)
    aliases = self._registered_language_aliases()
    if lang in aliases:
      return lang
    raise TargetDefinitionException(
      target,
      'language {} not supported: expected one of {}.'.format(lang, aliases.keys()))
  def _validate_rpc_style(self, target):
    """Return `target`'s rpc_style if it is one of _RPC_STYLES.

    :raises TargetDefinitionException: for any other rpc_style.
    """
    style = self._thrift_defaults.rpc_style(target)
    if style in _RPC_STYLES:
      return style
    raise TargetDefinitionException(
      target,
      'rpc_style {} not supported: expected one of {}.'.format(style, _RPC_STYLES))
  @memoized_method
  def _registered_language_aliases(self):
    """Return the --target-types option: a map of language to target type alias."""
    options = self.get_options()
    return options.target_types
  @memoized_method
  def _target_type_for_language(self, language):
    """Resolve `language` to the one target type registered under its --target-types alias.

    :raises KeyError: if `language` has no entry in --target-types.
    :raises TaskError: if the alias maps to no registered target type, or to more than one.
    """
    alias = self._registered_language_aliases()[language]
    types_by_alias = self.context.build_file_parser.registered_aliases().target_types_by_alias
    candidates = types_by_alias.get(alias)
    if not candidates:
      raise TaskError('Registered target type `{0}` for language `{1}` does not exist!'.format(alias, language))
    if len(candidates) > 1:
      raise TaskError('More than one target type registered for language `{0}`'.format(language))
    # Exactly one candidate remains; unpack it.
    (only_type,) = candidates
    return only_type
  def execute_codegen(self, target, target_workdir):
    """Validate `target`, then generate its sources into `target_workdir` via `gen`."""
    self._validate_compiler_configs([target])
    self._must_have_sources(target)

    language = self._validate_language(target)
    rpc_style = self._validate_rpc_style(target)
    # Sorted so the --namespace-map arguments come out in a stable order.
    if target.namespace_map:
      namespace_map = tuple(sorted(target.namespace_map.items()))
    else:
      namespace_map = ()

    partial_cmd = self.PartialCmd(language=language,
                                  rpc_style=rpc_style,
                                  namespace_map=namespace_map)
    self.gen(partial_cmd, target, target_workdir)
  def gen(self, partial_cmd, target, target_workdir):
    """Run the Scrooge compiler over `target`'s thrift sources.

    :param partial_cmd: a PartialCmd holding the validated language, rpc_style and sorted
                        namespace map for `target`.
    :param target: the thrift library target whose sources are compiled.
    :param target_workdir: directory passed to Scrooge as --dest.
    :raises TaskError: if the Scrooge process exits with a non-zero code.
    """
    import_paths, _ = calculate_compile_sources([target], self.is_gentarget)

    args = []
    for import_path in import_paths:
      args.extend(['--import-path', import_path])

    args.extend(['--language', partial_cmd.language])

    for lhs, rhs in partial_cmd.namespace_map:
      args.extend(['--namespace-map', '{}={}'.format(lhs, rhs)])

    # 'ostrich' implies finagle as well.
    if partial_cmd.rpc_style == 'ostrich':
      args.append('--finagle')
      args.append('--ostrich')
    elif partial_cmd.rpc_style == 'finagle':
      args.append('--finagle')

    args.extend(['--dest', target_workdir])

    options = self.get_options()
    if not options.strict:
      args.append('--disable-strict')
    if options.verbose:
      args.append('--verbose')

    # Scrooge is asked to write a gen-file-map, but this path is local to this method and never
    # read back, so remove the temp file afterwards instead of leaking one per invocation.
    gen_file_map_path = self._tempname()
    try:
      args.extend(['--gen-file-map', os.path.relpath(gen_file_map_path)])
      args.extend(target.sources_relative_to_buildroot())

      classpath = self.tool_classpath('scrooge-gen')
      jvm_options = list(options.jvm_options)
      jvm_options.append('-Dfile.encoding=UTF-8')
      returncode = self.runjava(classpath=classpath,
                                main='com.twitter.scrooge.Main',
                                jvm_options=jvm_options,
                                args=args,
                                workunit_name='scrooge-gen')
    finally:
      try:
        os.remove(gen_file_map_path)
      except OSError:
        # Best-effort cleanup: failing to delete the map must not mask the compile outcome.
        pass

    if returncode != 0:
      raise TaskError('Scrooge compiler exited non-zero for {} ({})'.format(target, returncode))
  # Matches a line that begins (after optional whitespace) with `service <name>`.
  SERVICE_PARSER = re.compile(r'^\s*service\s+(?:[^\s{]+)')

  def _declares_service(self, source):
    """Return True if any line of the thrift file at path `source` matches SERVICE_PARSER."""
    with open(source) as thrift:
      for line in thrift:
        if self.SERVICE_PARSER.search(line):
          return True
    return False
  def parse_gen_file_map(self, gen_file_map_path, outdir):
    """Parse a Scrooge gen-file-map into a map of thrift source -> set of generated files.

    Each non-blank line is expected to look like ``<source> -> <generated file>``.

    :param gen_file_map_path: path of the map file to read.
    :param outdir: directory that generated file paths are made relative to.
    :returns: a defaultdict(set) keyed by the source path made relative to the current
              directory, valued by generated paths made relative to `outdir`.
    :raises ValueError: if a non-blank line does not contain exactly one '->'.
    """
    d = defaultdict(set)
    with safe_open(gen_file_map_path, 'r') as deps:
      for dep in deps:
        entry = dep.strip()
        if not entry:
          # Skip blank lines rather than failing to unpack them.
          continue
        src, cls = entry.split('->')
        src = os.path.relpath(src.strip())
        cls = os.path.relpath(cls.strip(), outdir)
        d[src].add(cls)
    return d
  def is_gentarget(self, target):
    """Return True for JavaThriftLibrary targets whose configured compiler is 'scrooge'.

    Requests for other compilers (for example 'thrift', the Apache thrift compiler) are not
    handled by this task.
    """
    return (isinstance(target, JavaThriftLibrary) and
            self._thrift_defaults.compiler(target) == 'scrooge')
  def _validate_compiler_configs(self, targets):
    """Require each thrift target to share its (language, rpc_style) config with its thrift deps.

    Each JavaThriftLibrary in `targets` is walked with a JavaThriftLibrary predicate, and every
    visited target whose config differs from the root's is collected.

    :raises TaskError: listing each root and its mismatched dependencies with their configs.
    """
    assert len(targets) == 1, ("TODO: This method now only ever receives one target. Simplify.")

    ValidateCompilerConfig = namedtuple('ValidateCompilerConfig', ['language', 'rpc_style'])

    def is_thrift_library(t):
      return isinstance(t, JavaThriftLibrary)

    def compiler_config(tgt):
      # The compiler is deliberately left out of the config: Scrooge and the Apache thrift
      # generators currently produce identical java sources and the Apache generator produces
      # no scala sources, so no language+rpc_style combination can yield incompatible output.
      defaults = self._thrift_defaults
      return ValidateCompilerConfig(language=defaults.language(tgt),
                                    rpc_style=defaults.rpc_style(tgt))

    mismatches = defaultdict(set)
    for root in [t for t in targets if is_thrift_library(t)]:
      root_config = compiler_config(root)

      def note_mismatch(dep):
        if compiler_config(dep) != root_config:
          mismatches[root].add(dep)

      root.walk(note_mismatch, predicate=is_thrift_library)

    if not mismatches:
      return

    report = ['Thrift dependency trees must be generated with a uniform compiler configuration.\n\n']
    for tgt in sorted(mismatches):
      report.append('%s - %s\n' % (tgt, compiler_config(tgt)))
      report.extend(' %s - %s\n' % (dep, compiler_config(dep)) for dep in mismatches[tgt])
    raise TaskError(''.join(report))
  def _must_have_sources(self, target):
    """Raise TargetDefinitionException if a JavaThriftLibrary target has no source paths."""
    if not isinstance(target, JavaThriftLibrary):
      return
    if not target.payload.sources.source_paths:
      raise TargetDefinitionException(target, 'no thrift files found')
  def synthetic_target_type(self, target):
    """Return the target type registered for `target`'s thrift language."""
    return self._target_type_for_language(self._thrift_defaults.language(target))
  def synthetic_target_extra_dependencies(self, target, target_workdir):
    """Return the configured service/structs deps for `target`, then `target`'s own dependencies.

    Order is preserved and duplicates are dropped.
    """
    extra = OrderedSet()
    extra.update(self._thrift_dependencies_for_target(target))
    extra.update(target.dependencies)
    return extra
  def _thrift_dependencies_for_target(self, target):
    """Return the resolved --service-deps or --structs-deps entry for `target`'s language.

    The service deps are chosen when any of `target`'s thrift sources declares a service.
    """
    dep_info = self._resolved_dep_info
    # NOTE(review): these paths are buildroot-relative but `_declares_service` opens them
    # relative to the current directory -- confirm cwd is always the buildroot here.
    declares_service = any(self._declares_service(path)
                           for path in target.sources_relative_to_buildroot())
    by_language = dep_info.service if declares_service else dep_info.structs
    return by_language[self._thrift_defaults.language(target)]
  @memoized_property
  def _resolved_dep_info(self):
    """Resolve --service-deps and --structs-deps into a DepInfo of per-language target sets."""
    options = self.get_options()
    service = self._resolve_deps(options.service_deps)
    structs = self._resolve_deps(options.structs_deps)
    return ScroogeGen.DepInfo(service=service, structs=structs)
  @property
  def _copy_target_attributes(self):
    """Names of target attributes to copy (presumably onto the synthetic target; the consumer
    is outside this file -- confirm)."""
    copied = ['provides']
    return copied