PageRenderTime 69ms CodeModel.GetById 30ms RepoModel.GetById 0ms app.codeStats 0ms

/pypy/jit/metainterp/warmspot.py

https://bitbucket.org/pypy/pypy/
Python | 936 lines | 782 code | 81 blank | 73 comment | 148 complexity | 16457cd7bdf1273b1133b3f351698eaa MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
  1. import sys, py
  2. from pypy.tool.sourcetools import func_with_new_name
  3. from pypy.rpython.lltypesystem import lltype, llmemory
  4. from pypy.rpython.annlowlevel import llhelper, MixLevelHelperAnnotator,\
  5. cast_base_ptr_to_instance, hlstr
  6. from pypy.annotation import model as annmodel
  7. from pypy.rpython.llinterp import LLException
  8. from pypy.rpython.test.test_llinterp import get_interpreter, clear_tcache
  9. from pypy.objspace.flow.model import SpaceOperation, Variable, Constant
  10. from pypy.objspace.flow.model import checkgraph, Link, copygraph
  11. from pypy.rlib.objectmodel import we_are_translated
  12. from pypy.rlib.unroll import unrolling_iterable
  13. from pypy.rlib.debug import fatalerror
  14. from pypy.rlib.rstackovf import StackOverflow
  15. from pypy.translator.simplify import get_functype
  16. from pypy.translator.unsimplify import call_final_function
  17. from pypy.jit.metainterp import history, pyjitpl, gc, memmgr
  18. from pypy.jit.metainterp.pyjitpl import MetaInterpStaticData
  19. from pypy.jit.metainterp.jitprof import Profiler, EmptyProfiler
  20. from pypy.jit.metainterp.jitexc import JitException
  21. from pypy.jit.metainterp.jitdriver import JitDriverStaticData
  22. from pypy.jit.codewriter import support, codewriter, longlong
  23. from pypy.jit.codewriter.policy import JitPolicy
  24. from pypy.jit.codewriter.effectinfo import EffectInfo
  25. from pypy.jit.metainterp.optimizeopt import ALL_OPTS_NAMES
  26. # ____________________________________________________________
  27. # Bootstrapping
  28. def apply_jit(translator, backend_name="auto", inline=False,
  29. enable_opts=ALL_OPTS_NAMES, **kwds):
  30. if 'CPUClass' not in kwds:
  31. from pypy.jit.backend.detect_cpu import getcpuclass
  32. kwds['CPUClass'] = getcpuclass(backend_name)
  33. ProfilerClass = Profiler
  34. # Always use Profiler here, which should have a very low impact.
  35. # Otherwise you can try with ProfilerClass = EmptyProfiler.
  36. warmrunnerdesc = WarmRunnerDesc(translator,
  37. translate_support_code=True,
  38. listops=True,
  39. no_stats = True,
  40. ProfilerClass = ProfilerClass,
  41. **kwds)
  42. for jd in warmrunnerdesc.jitdrivers_sd:
  43. jd.warmstate.set_param_inlining(inline)
  44. jd.warmstate.set_param_enable_opts(enable_opts)
  45. warmrunnerdesc.finish()
  46. translator.warmrunnerdesc = warmrunnerdesc # for later debugging
  47. def ll_meta_interp(function, args, backendopt=False, type_system='lltype',
  48. listcomp=False, translationoptions={}, **kwds):
  49. if listcomp:
  50. extraconfigopts = {'translation.list_comprehension_operations': True}
  51. else:
  52. extraconfigopts = {}
  53. for key, value in translationoptions.items():
  54. extraconfigopts['translation.' + key] = value
  55. interp, graph = get_interpreter(function, args,
  56. backendopt=False, # will be done below
  57. type_system=type_system,
  58. **extraconfigopts)
  59. clear_tcache()
  60. return jittify_and_run(interp, graph, args, backendopt=backendopt, **kwds)
  61. def jittify_and_run(interp, graph, args, repeat=1, graph_and_interp_only=False,
  62. backendopt=False, trace_limit=sys.maxint,
  63. inline=False, loop_longevity=0, retrace_limit=5,
  64. function_threshold=4,
  65. enable_opts=ALL_OPTS_NAMES, max_retrace_guards=15, **kwds):
  66. from pypy.config.config import ConfigError
  67. translator = interp.typer.annotator.translator
  68. try:
  69. translator.config.translation.gc = "boehm"
  70. except ConfigError:
  71. pass
  72. try:
  73. translator.config.translation.list_comprehension_operations = True
  74. except ConfigError:
  75. pass
  76. try:
  77. translator.config.translation.jit_ffi = True
  78. except ConfigError:
  79. pass
  80. warmrunnerdesc = WarmRunnerDesc(translator, backendopt=backendopt, **kwds)
  81. for jd in warmrunnerdesc.jitdrivers_sd:
  82. jd.warmstate.set_param_threshold(3) # for tests
  83. jd.warmstate.set_param_function_threshold(function_threshold)
  84. jd.warmstate.set_param_trace_eagerness(2) # for tests
  85. jd.warmstate.set_param_trace_limit(trace_limit)
  86. jd.warmstate.set_param_inlining(inline)
  87. jd.warmstate.set_param_loop_longevity(loop_longevity)
  88. jd.warmstate.set_param_retrace_limit(retrace_limit)
  89. jd.warmstate.set_param_max_retrace_guards(max_retrace_guards)
  90. jd.warmstate.set_param_enable_opts(enable_opts)
  91. warmrunnerdesc.finish()
  92. if graph_and_interp_only:
  93. return interp, graph
  94. res = interp.eval_graph(graph, args)
  95. if not kwds.get('translate_support_code', False):
  96. warmrunnerdesc.metainterp_sd.profiler.finish()
  97. warmrunnerdesc.metainterp_sd.cpu.finish_once()
  98. print '~~~ return value:', res
  99. while repeat > 1:
  100. print '~' * 79
  101. res1 = interp.eval_graph(graph, args)
  102. if isinstance(res, int):
  103. assert res1 == res
  104. repeat -= 1
  105. return res
  106. def rpython_ll_meta_interp(function, args, backendopt=True, **kwds):
  107. return ll_meta_interp(function, args, backendopt=backendopt,
  108. translate_support_code=True, **kwds)
  109. def _find_jit_marker(graphs, marker_name, check_driver=True):
  110. results = []
  111. for graph in graphs:
  112. for block in graph.iterblocks():
  113. for i in range(len(block.operations)):
  114. op = block.operations[i]
  115. if (op.opname == 'jit_marker' and
  116. op.args[0].value == marker_name and
  117. (not check_driver or op.args[1].value is None or
  118. op.args[1].value.active)): # the jitdriver
  119. results.append((graph, block, i))
  120. return results
  121. def find_can_enter_jit(graphs):
  122. return _find_jit_marker(graphs, 'can_enter_jit')
  123. def find_loop_headers(graphs):
  124. return _find_jit_marker(graphs, 'loop_header')
  125. def find_jit_merge_points(graphs):
  126. results = _find_jit_marker(graphs, 'jit_merge_point')
  127. if not results:
  128. raise Exception("no jit_merge_point found!")
  129. seen = set([graph for graph, block, pos in results])
  130. assert len(seen) == len(results), (
  131. "found several jit_merge_points in the same graph")
  132. return results
  133. def find_access_helpers(graphs):
  134. return _find_jit_marker(graphs, 'access_helper', False)
  135. def locate_jit_merge_point(graph):
  136. [(graph, block, pos)] = find_jit_merge_points([graph])
  137. return block, pos, block.operations[pos]
  138. def find_set_param(graphs):
  139. return _find_jit_marker(graphs, 'set_param')
  140. def find_force_quasi_immutable(graphs):
  141. results = []
  142. for graph in graphs:
  143. for block in graph.iterblocks():
  144. for i in range(len(block.operations)):
  145. op = block.operations[i]
  146. if op.opname == 'jit_force_quasi_immutable':
  147. results.append((graph, block, i))
  148. return results
  149. def get_stats():
  150. return pyjitpl._warmrunnerdesc.stats
  151. def reset_stats():
  152. pyjitpl._warmrunnerdesc.stats.clear()
  153. def get_translator():
  154. return pyjitpl._warmrunnerdesc.translator
  155. def debug_checks():
  156. stats = get_stats()
  157. stats.maybe_view()
  158. stats.check_consistency()
  159. class ContinueRunningNormallyBase(JitException):
  160. pass
  161. # ____________________________________________________________
  162. class WarmRunnerDesc(object):
    def __init__(self, translator, policy=None, backendopt=True, CPUClass=None,
                 ProfilerClass=EmptyProfiler, **kwds):
        """Set up JIT support for the program held by *translator*.

        The steps below run in a fixed order: later steps read attributes
        that earlier ones create (see the comments).  Extra keyword
        arguments go to build_cpu(), which consumes translate_support_code,
        no_stats and the supports_* flags and passes the rest to
        history.Options.  The callers in this module (apply_jit,
        jittify_and_run) call finish() afterwards.

        :param policy: a JitPolicy; a default one is created when None.
        :param backendopt: run the full pre-JIT backend optimizations;
            otherwise only minimal inlining, and only if the 'listops'
            option is set.
        :param CPUClass: backend CPU class (build_cpu asserts it is given).
        :param ProfilerClass: passed to MetaInterpStaticData.
        """
        pyjitpl._warmrunnerdesc = self   # this is a global for debugging only!
        self.set_translator(translator)
        self.memory_manager = memmgr.MemoryManager()
        # creates self.opt, self.stats, self.cpu (and self.annhelper when
        # translate_support_code is true)
        self.build_cpu(CPUClass, **kwds)
        # creates self.jitdrivers_sd: one entry per jit_merge_point
        self.find_portals()
        self.codewriter = codewriter.CodeWriter(self.cpu, self.jitdrivers_sd)
        if policy is None:
            policy = JitPolicy()
        # tell the policy which value kinds the chosen CPU supports
        policy.set_supports_floats(self.cpu.supports_floats)
        policy.set_supports_longlong(self.cpu.supports_longlong)
        policy.set_supports_singlefloats(self.cpu.supports_singlefloats)
        graphs = self.codewriter.find_all_graphs(policy)
        policy.dump_unsafe_loops()
        self.check_access_directly_sanity(graphs)
        if backendopt:
            self.prejit_optimizations(policy, graphs)
        elif self.opt.listops:
            self.prejit_optimizations_minimal_inline(policy, graphs)

        # creates self.metainterp_sd, which make_exception_classes() needs
        self.build_meta_interp(ProfilerClass)
        # per-jitdriver argument types; make_virtualizable_infos() and
        # make_driverhook_graphs() read them
        self.make_args_specifications()
        #
        from pypy.jit.metainterp.virtualref import VirtualRefInfo
        vrefinfo = VirtualRefInfo(self)
        self.codewriter.setup_vrefinfo(vrefinfo)
        #
        self.hooks = policy.jithookiface
        self.make_virtualizable_infos()
        # the portal runner built by rewrite_jit_merge_points() catches
        # these exception classes
        self.make_exception_classes()
        self.make_driverhook_graphs()
        # creates jd._maybe_enter_jit_fn / jd._maybe_enter_from_start_fn,
        # used by the rewrite_* steps below
        self.make_enter_functions()
        self.rewrite_jit_merge_points(policy)

        verbose = False # not self.cpu.translate_support_code
        self.rewrite_access_helpers()
        self.codewriter.make_jitcodes(verbose=verbose)
        self.rewrite_can_enter_jits()
        self.rewrite_set_param()
        self.rewrite_force_virtual(vrefinfo)
        self.rewrite_force_quasi_immutable()
        # registers the end-of-program finish hook (translated builds only)
        self.add_finish()
        self.metainterp_sd.finish_setup(self.codewriter)
  205. def finish(self):
  206. vinfos = set([jd.virtualizable_info for jd in self.jitdrivers_sd])
  207. for vinfo in vinfos:
  208. if vinfo is not None:
  209. vinfo.finish()
  210. if self.cpu.translate_support_code:
  211. self.annhelper.finish()
  212. def _freeze_(self):
  213. return True
  214. def set_translator(self, translator):
  215. self.translator = translator
  216. self.rtyper = translator.rtyper
  217. self.gcdescr = gc.get_description(translator.config)
  218. def find_portals(self):
  219. self.jitdrivers_sd = []
  220. graphs = self.translator.graphs
  221. for jit_merge_point_pos in find_jit_merge_points(graphs):
  222. self.split_graph_and_record_jitdriver(*jit_merge_point_pos)
  223. #
  224. assert (len(set([jd.jitdriver for jd in self.jitdrivers_sd])) ==
  225. len(self.jitdrivers_sd)), \
  226. "there are multiple jit_merge_points with the same jitdriver"
    def split_graph_and_record_jitdriver(self, graph, block, pos):
        """Build the JitDriverStaticData for the jit_merge_point at
        ``block.operations[pos]`` in *graph*.

        The original graph is left untouched here.  A copy of it is split
        so that its start block begins at the jit_merge_point.  That copy
        becomes ``jd.portal_graph`` and is appended to translator.graphs.
        The new jd is appended to self.jitdrivers_sd.
        """
        op = block.operations[pos]
        jd = JitDriverStaticData()
        jd._jit_merge_point_in = graph
        # annotations of the jit_merge_point's arguments (args[0] is the
        # marker name, args[1] the jitdriver)
        args = op.args[2:]
        s_binding = self.translator.annotator.binding
        jd._portal_args_s = [s_binding(v) for v in args]
        graph = copygraph(graph)
        [jmpp] = find_jit_merge_points([graph])
        graph.startblock = support.split_before_jit_merge_point(*jmpp)
        # a crash in the following checkgraph() means that you forgot
        # to list some variable in greens=[] or reds=[] in JitDriver,
        # or that a jit_merge_point() takes a constant as an argument.
        checkgraph(graph)
        # the portal's inputs must be distinct Variables (no Constants)
        for v in graph.getargs():
            assert isinstance(v, Variable)
        assert len(dict.fromkeys(graph.getargs())) == len(graph.getargs())
        self.translator.graphs.append(graph)
        jd.portal_graph = graph
        # it's a bit unbelievable to have a portal without func
        assert hasattr(graph, "func")
        graph.func._dont_inline_ = True
        graph.func._jit_unroll_safe_ = True
        # 'block' is still the block of the ORIGINAL graph, so this reads
        # the jitdriver from the original jit_merge_point operation
        jd.jitdriver = block.operations[pos].args[1].value
        # placeholder; rewrite_jit_merge_point() sets the real pointer
        jd.portal_runner_ptr = "<not set so far>"
        # first letter of the return kind ('v' for a void result, see
        # make_args_specification)
        jd.result_type = history.getkind(jd.portal_graph.getreturnvar()
                                         .concretetype)[0]
        self.jitdrivers_sd.append(jd)
  255. def check_access_directly_sanity(self, graphs):
  256. from pypy.translator.backendopt.inline import collect_called_graphs
  257. jit_graphs = set(graphs)
  258. for graph in collect_called_graphs(self.translator.entry_point_graph,
  259. self.translator):
  260. if graph in jit_graphs:
  261. continue
  262. assert not getattr(graph, 'access_directly', False)
  263. def prejit_optimizations(self, policy, graphs):
  264. from pypy.translator.backendopt.all import backend_optimizations
  265. backend_optimizations(self.translator,
  266. graphs=graphs,
  267. merge_if_blocks=True,
  268. constfold=True,
  269. raisingop2direct_call=False,
  270. remove_asserts=True,
  271. really_remove_asserts=True)
  272. def prejit_optimizations_minimal_inline(self, policy, graphs):
  273. from pypy.translator.backendopt.inline import auto_inline_graphs
  274. auto_inline_graphs(self.translator, graphs, 0.01)
  275. def build_cpu(self, CPUClass, translate_support_code=False,
  276. no_stats=False, supports_floats=True,
  277. supports_longlong=True, supports_singlefloats=True,
  278. **kwds):
  279. assert CPUClass is not None
  280. self.opt = history.Options(**kwds)
  281. if no_stats:
  282. stats = history.NoStats()
  283. else:
  284. stats = history.Stats()
  285. self.stats = stats
  286. if translate_support_code:
  287. self.annhelper = MixLevelHelperAnnotator(self.translator.rtyper)
  288. cpu = CPUClass(self.translator.rtyper, self.stats, self.opt,
  289. translate_support_code, gcdescr=self.gcdescr)
  290. if not supports_floats: cpu.supports_floats = False
  291. if not supports_longlong: cpu.supports_longlong = False
  292. if not supports_singlefloats: cpu.supports_singlefloats = False
  293. self.cpu = cpu
  294. def build_meta_interp(self, ProfilerClass):
  295. self.metainterp_sd = MetaInterpStaticData(self.cpu,
  296. self.opt,
  297. ProfilerClass=ProfilerClass,
  298. warmrunnerdesc=self)
  299. def make_virtualizable_infos(self):
  300. vinfos = {}
  301. for jd in self.jitdrivers_sd:
  302. #
  303. jd.greenfield_info = None
  304. for name in jd.jitdriver.greens:
  305. if '.' in name:
  306. from pypy.jit.metainterp.greenfield import GreenFieldInfo
  307. jd.greenfield_info = GreenFieldInfo(self.cpu, jd)
  308. break
  309. #
  310. if not jd.jitdriver.virtualizables:
  311. jd.virtualizable_info = None
  312. jd.index_of_virtualizable = -1
  313. continue
  314. else:
  315. assert jd.greenfield_info is None, "XXX not supported yet"
  316. #
  317. jitdriver = jd.jitdriver
  318. assert len(jitdriver.virtualizables) == 1 # for now
  319. [vname] = jitdriver.virtualizables
  320. # XXX skip the Voids here too
  321. jd.index_of_virtualizable = jitdriver.reds.index(vname)
  322. #
  323. index = jd.num_green_args + jd.index_of_virtualizable
  324. VTYPEPTR = jd._JIT_ENTER_FUNCTYPE.ARGS[index]
  325. if VTYPEPTR not in vinfos:
  326. from pypy.jit.metainterp.virtualizable import VirtualizableInfo
  327. vinfos[VTYPEPTR] = VirtualizableInfo(self, VTYPEPTR)
  328. jd.virtualizable_info = vinfos[VTYPEPTR]
    def make_exception_classes(self):
        """Create the JitException subclasses used to leave a portal frame.

        Fresh classes are created on every call.  They are published as
        attributes both on self and on self.metainterp_sd, so
        build_meta_interp() must have run first.  The portal runner built
        in rewrite_jit_merge_point() catches them.
        """

        # the portal returned with no result
        class DoneWithThisFrameVoid(JitException):
            def __str__(self):
                return 'DoneWithThisFrameVoid()'

        # the portal returned a Signed result
        class DoneWithThisFrameInt(JitException):
            def __init__(self, result):
                assert lltype.typeOf(result) is lltype.Signed
                self.result = result
            def __str__(self):
                return 'DoneWithThisFrameInt(%s)' % (self.result,)

        # the portal returned a reference, typed as the CPU's BASETYPE
        class DoneWithThisFrameRef(JitException):
            def __init__(self, cpu, result):
                assert lltype.typeOf(result) == cpu.ts.BASETYPE
                self.result = result
            def __str__(self):
                return 'DoneWithThisFrameRef(%s)' % (self.result,)

        # the portal returned a float, stored as longlong.FLOATSTORAGE
        class DoneWithThisFrameFloat(JitException):
            def __init__(self, result):
                assert lltype.typeOf(result) is longlong.FLOATSTORAGE
                self.result = result
            def __str__(self):
                return 'DoneWithThisFrameFloat(%s)' % (self.result,)

        # the portal raised; 'value' is the exception object as BASETYPE
        class ExitFrameWithExceptionRef(JitException):
            def __init__(self, cpu, value):
                assert lltype.typeOf(value) == cpu.ts.BASETYPE
                self.value = value
            def __str__(self):
                return 'ExitFrameWithExceptionRef(%s)' % (self.value,)

        # restart the portal with new arguments; the attribute names
        # (green_int, red_ref, ...) are looked up by name in
        # rewrite_jit_merge_point()
        class ContinueRunningNormally(ContinueRunningNormallyBase):
            def __init__(self, gi, gr, gf, ri, rr, rf):
                # the six arguments are: lists of green ints, greens refs,
                # green floats, red ints, red refs, and red floats.
                self.green_int = gi
                self.green_ref = gr
                self.green_float = gf
                self.red_int = ri
                self.red_ref = rr
                self.red_float = rf
            def __str__(self):
                return 'ContinueRunningNormally(%s, %s, %s, %s, %s, %s)' % (
                    self.green_int, self.green_ref, self.green_float,
                    self.red_int, self.red_ref, self.red_float)

        # XXX there is no point any more to not just have the exceptions
        # as globals
        self.DoneWithThisFrameVoid = DoneWithThisFrameVoid
        self.DoneWithThisFrameInt = DoneWithThisFrameInt
        self.DoneWithThisFrameRef = DoneWithThisFrameRef
        self.DoneWithThisFrameFloat = DoneWithThisFrameFloat
        self.ExitFrameWithExceptionRef = ExitFrameWithExceptionRef
        self.ContinueRunningNormally = ContinueRunningNormally
        self.metainterp_sd.DoneWithThisFrameVoid = DoneWithThisFrameVoid
        self.metainterp_sd.DoneWithThisFrameInt = DoneWithThisFrameInt
        self.metainterp_sd.DoneWithThisFrameRef = DoneWithThisFrameRef
        self.metainterp_sd.DoneWithThisFrameFloat = DoneWithThisFrameFloat
        self.metainterp_sd.ExitFrameWithExceptionRef = ExitFrameWithExceptionRef
        self.metainterp_sd.ContinueRunningNormally = ContinueRunningNormally
  385. def make_enter_functions(self):
  386. for jd in self.jitdrivers_sd:
  387. self.make_enter_function(jd)
    def make_enter_function(self, jd):
        """Create the warm state of *jd* and its two JIT entry helpers.

        Sets jd.warmstate (a WarmEnterState), jd._maybe_enter_jit_fn (used
        when rewriting can_enter_jit markers) and
        jd._maybe_enter_from_start_fn (called by the portal runner before
        the first portal call).  Both call the function returned by
        state.make_entry_point() with a different threshold-incrementing
        callable as their first argument.
        """
        from pypy.jit.metainterp.warmstate import WarmEnterState
        state = WarmEnterState(self, jd)
        maybe_compile_and_run = state.make_entry_point()
        jd.warmstate = state

        # Handle an exception escaping maybe_compile_and_run.  JitException,
        # MemoryError and StackOverflow are re-raised unchanged.  Anything
        # else is treated as a JIT crash: untranslated, it is printed (with
        # a pdb post-mortem when stdout is the real stdout) and re-raised
        # with its original traceback; translated, fatalerror() is called.
        def crash_in_jit(e):
            # grab the traceback of the exception being handled by the
            # caller now, before 'raise e' below replaces it (when
            # translated, tb is simply False)
            tb = not we_are_translated() and sys.exc_info()[2]
            # re-raise so the except clauses below dispatch on e's type
            try:
                raise e
            except JitException:
                raise     # go through
            except MemoryError:
                raise     # go through
            except StackOverflow:
                raise     # go through
            except Exception, e:
                if not we_are_translated():
                    print "~~~ Crash in JIT!"
                    print '~~~ %s: %s' % (e.__class__, e)
                    if sys.stdout == sys.__stdout__:
                        import pdb; pdb.post_mortem(tb)
                    raise e.__class__, e, tb
                fatalerror('~~~ Crash in JIT! %s' % (e,), traceback=True)
        # inlining hints (presumably read by the RPython backend inliner)
        crash_in_jit._dont_inline_ = True

        # only the lltype variant routes unexpected exceptions through
        # crash_in_jit(); the other type system calls straight through
        if self.translator.rtyper.type_system.name == 'lltypesystem':
            def maybe_enter_jit(*args):
                try:
                    maybe_compile_and_run(state.increment_threshold, *args)
                except Exception, e:
                    crash_in_jit(e)
            maybe_enter_jit._always_inline_ = True
        else:
            def maybe_enter_jit(*args):
                maybe_compile_and_run(state.increment_threshold, *args)
            maybe_enter_jit._always_inline_ = True
        jd._maybe_enter_jit_fn = maybe_enter_jit

        # same entry point, but counting against the function threshold
        def maybe_enter_from_start(*args):
            maybe_compile_and_run(state.increment_function_threshold, *args)
        maybe_enter_from_start._always_inline_ = True
        jd._maybe_enter_from_start_fn = maybe_enter_from_start
  428. def make_driverhook_graphs(self):
  429. from pypy.rlib.jit import BaseJitCell
  430. bk = self.rtyper.annotator.bookkeeper
  431. classdef = bk.getuniqueclassdef(BaseJitCell)
  432. s_BaseJitCell_or_None = annmodel.SomeInstance(classdef,
  433. can_be_None=True)
  434. s_BaseJitCell_not_None = annmodel.SomeInstance(classdef)
  435. s_Str = annmodel.SomeString()
  436. #
  437. annhelper = MixLevelHelperAnnotator(self.translator.rtyper)
  438. for jd in self.jitdrivers_sd:
  439. jd._set_jitcell_at_ptr = self._make_hook_graph(jd,
  440. annhelper, jd.jitdriver.set_jitcell_at, annmodel.s_None,
  441. s_BaseJitCell_not_None)
  442. jd._get_jitcell_at_ptr = self._make_hook_graph(jd,
  443. annhelper, jd.jitdriver.get_jitcell_at, s_BaseJitCell_or_None)
  444. jd._get_printable_location_ptr = self._make_hook_graph(jd,
  445. annhelper, jd.jitdriver.get_printable_location, s_Str)
  446. jd._confirm_enter_jit_ptr = self._make_hook_graph(jd,
  447. annhelper, jd.jitdriver.confirm_enter_jit, annmodel.s_Bool,
  448. onlygreens=False)
  449. jd._can_never_inline_ptr = self._make_hook_graph(jd,
  450. annhelper, jd.jitdriver.can_never_inline, annmodel.s_Bool)
  451. jd._should_unroll_one_iteration_ptr = self._make_hook_graph(jd,
  452. annhelper, jd.jitdriver.should_unroll_one_iteration,
  453. annmodel.s_Bool)
  454. annhelper.finish()
  455. def _make_hook_graph(self, jitdriver_sd, annhelper, func,
  456. s_result, s_first_arg=None, onlygreens=True):
  457. if func is None:
  458. return None
  459. #
  460. extra_args_s = []
  461. if s_first_arg is not None:
  462. extra_args_s.append(s_first_arg)
  463. #
  464. args_s = jitdriver_sd._portal_args_s
  465. if onlygreens:
  466. args_s = args_s[:len(jitdriver_sd._green_args_spec)]
  467. graph = annhelper.getgraph(func, extra_args_s + args_s, s_result)
  468. funcptr = annhelper.graph2delayed(graph)
  469. return funcptr
  470. def make_args_specifications(self):
  471. for jd in self.jitdrivers_sd:
  472. self.make_args_specification(jd)
  473. def make_args_specification(self, jd):
  474. graph = jd._jit_merge_point_in
  475. _, _, op = locate_jit_merge_point(graph)
  476. greens_v, reds_v = support.decode_hp_hint_args(op)
  477. ALLARGS = [v.concretetype for v in (greens_v + reds_v)]
  478. jd._green_args_spec = [v.concretetype for v in greens_v]
  479. jd.red_args_types = [history.getkind(v.concretetype) for v in reds_v]
  480. jd.num_green_args = len(jd._green_args_spec)
  481. jd.num_red_args = len(jd.red_args_types)
  482. RESTYPE = graph.getreturnvar().concretetype
  483. (jd._JIT_ENTER_FUNCTYPE,
  484. jd._PTR_JIT_ENTER_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, lltype.Void)
  485. (jd._PORTAL_FUNCTYPE,
  486. jd._PTR_PORTAL_FUNCTYPE) = self.cpu.ts.get_FuncType(ALLARGS, RESTYPE)
  487. #
  488. if jd.result_type == 'v':
  489. ASMRESTYPE = lltype.Void
  490. elif jd.result_type == history.INT:
  491. ASMRESTYPE = lltype.Signed
  492. elif jd.result_type == history.REF:
  493. ASMRESTYPE = llmemory.GCREF
  494. elif jd.result_type == history.FLOAT:
  495. ASMRESTYPE = lltype.Float
  496. else:
  497. assert False
  498. (_, jd._PTR_ASSEMBLER_HELPER_FUNCTYPE) = self.cpu.ts.get_FuncType(
  499. [lltype.Signed, llmemory.GCREF], ASMRESTYPE)
  500. def rewrite_can_enter_jits(self):
  501. sublists = {}
  502. for jd in self.jitdrivers_sd:
  503. sublists[jd.jitdriver] = jd, []
  504. jd.no_loop_header = True
  505. #
  506. loop_headers = find_loop_headers(self.translator.graphs)
  507. for graph, block, index in loop_headers:
  508. op = block.operations[index]
  509. jitdriver = op.args[1].value
  510. assert jitdriver in sublists, \
  511. "loop_header with no matching jit_merge_point"
  512. jd, sublist = sublists[jitdriver]
  513. jd.no_loop_header = False
  514. #
  515. can_enter_jits = find_can_enter_jit(self.translator.graphs)
  516. for graph, block, index in can_enter_jits:
  517. op = block.operations[index]
  518. jitdriver = op.args[1].value
  519. assert jitdriver in sublists, \
  520. "can_enter_jit with no matching jit_merge_point"
  521. jd, sublist = sublists[jitdriver]
  522. origportalgraph = jd._jit_merge_point_in
  523. if graph is not origportalgraph:
  524. sublist.append((graph, block, index))
  525. jd.no_loop_header = False
  526. else:
  527. pass # a 'can_enter_jit' before the 'jit-merge_point', but
  528. # originally in the same function: we ignore it here
  529. # see e.g. test_jitdriver.test_simple
  530. for jd in self.jitdrivers_sd:
  531. _, sublist = sublists[jd.jitdriver]
  532. self.rewrite_can_enter_jit(jd, sublist)
    def rewrite_can_enter_jit(self, jd, can_enter_jits):
        """Replace each can_enter_jit marker of *jd* by a direct_call.

        The call goes to jd._maybe_enter_jit_fn (wrapped by helper_func)
        with the marker's green and red arguments.  If *can_enter_jits* is
        empty, a synthetic can_enter_jit marker is first inserted in front
        of the jit_merge_point at the start of jd.portal_graph, and that
        marker is rewritten instead.

        :param can_enter_jits: list of (graph, block, index) locations.
        """
        FUNCPTR = jd._PTR_JIT_ENTER_FUNCTYPE
        jit_enter_fnptr = self.helper_func(FUNCPTR, jd._maybe_enter_jit_fn)

        if len(can_enter_jits) == 0:
            # see test_warmspot.test_no_loop_at_all
            operations = jd.portal_graph.startblock.operations
            op1 = operations[0]
            # the split portal graph starts with its jit_merge_point
            assert (op1.opname == 'jit_marker' and
                    op1.args[0].value == 'jit_merge_point')
            # a can_enter_jit marker with the same jitdriver and arguments
            op0 = SpaceOperation(
                'jit_marker',
                [Constant('can_enter_jit', lltype.Void)] + op1.args[1:],
                None)
            operations.insert(0, op0)
            can_enter_jits = [(jd.portal_graph, jd.portal_graph.startblock, 0)]

        for graph, block, index in can_enter_jits:
            # markers in the original portal function are left alone
            if graph is jd._jit_merge_point_in:
                continue

            op = block.operations[index]
            greens_v, reds_v = support.decode_hp_hint_args(op)
            args_v = greens_v + reds_v

            vlist = [Constant(jit_enter_fnptr, FUNCPTR)] + args_v

            v_result = Variable()
            v_result.concretetype = lltype.Void
            newop = SpaceOperation('direct_call', vlist, v_result)
            block.operations[index] = newop
  559. def helper_func(self, FUNCPTR, func):
  560. if not self.cpu.translate_support_code:
  561. return llhelper(FUNCPTR, func)
  562. FUNC = get_functype(FUNCPTR)
  563. args_s = [annmodel.lltype_to_annotation(ARG) for ARG in FUNC.ARGS]
  564. s_result = annmodel.lltype_to_annotation(FUNC.RESULT)
  565. graph = self.annhelper.getgraph(func, args_s, s_result)
  566. return self.annhelper.graph2delayed(graph, FUNC)
  567. def rewrite_access_helpers(self):
  568. ah = find_access_helpers(self.translator.graphs)
  569. for graph, block, index in ah:
  570. op = block.operations[index]
  571. self.rewrite_access_helper(op)
  572. def rewrite_access_helper(self, op):
  573. ARGS = [arg.concretetype for arg in op.args[2:]]
  574. RESULT = op.result.concretetype
  575. FUNCPTR = lltype.Ptr(lltype.FuncType(ARGS, RESULT))
  576. # make sure we make a copy of function so it no longer belongs
  577. # to extregistry
  578. func = op.args[1].value
  579. func = func_with_new_name(func, func.func_name + '_compiled')
  580. ptr = self.helper_func(FUNCPTR, func)
  581. op.opname = 'direct_call'
  582. op.args = [Constant(ptr, FUNCPTR)] + op.args[2:]
  583. def rewrite_jit_merge_points(self, policy):
  584. for jd in self.jitdrivers_sd:
  585. self.rewrite_jit_merge_point(jd, policy)
    def rewrite_jit_merge_point(self, jd, policy):
        """Build the portal runner and assembler helpers for *jd*, then make
        the original portal function call the portal runner.

        Sets on *jd*: _portal_ptr, _ll_portal_runner, portal_runner_ptr,
        portal_runner_adr, portal_calldescr, _assembler_call_helper,
        _assembler_helper_ptr, assembler_helper_adr, handle_jitexc_from_bh,
        and vable_token_descr when a virtualizable is used.  *policy* is
        not used here.
        """
        #
        # Mutate the original portal graph from this:
        #
        #       def original_portal(..):
        #           stuff
        #           while 1:
        #               jit_merge_point(*args)
        #               more stuff
        #
        # to that:
        #
        #       def original_portal(..):
        #           stuff
        #           return portal_runner(*args)
        #
        #       def portal_runner(*args):
        #           while 1:
        #               try:
        #                   return portal(*args)
        #               except ContinueRunningNormally, e:
        #                   *args = *e.new_args
        #               except DoneWithThisFrame, e:
        #                   return e.return
        #               except ExitFrameWithException, e:
        #                   raise Exception, e.value
        #
        #       def portal(*args):
        #           while 1:
        #               more stuff
        #
        origportalgraph = jd._jit_merge_point_in
        portalgraph = jd.portal_graph
        PORTALFUNC = jd._PORTAL_FUNCTYPE

        # ____________________________________________________________
        # Prepare the portal_runner() helper
        #
        from pypy.jit.metainterp.warmstate import specialize_value
        from pypy.jit.metainterp.warmstate import unspecialize_value
        portal_ptr = self.cpu.ts.functionptr(PORTALFUNC, 'portal',
                                         graph = portalgraph)
        jd._portal_ptr = portal_ptr
        #
        # For each portal argument, record (ARG, attrname, count): the
        # ContinueRunningNormally attribute ('green_int', 'red_ref', ...)
        # holding its new value, and the index of the value in that list.
        portalfunc_ARGS = []
        nums = {}
        for i, ARG in enumerate(PORTALFUNC.ARGS):
            kind = history.getkind(ARG)
            assert kind != 'void'
            if i < len(jd.jitdriver.greens):
                color = 'green'
            else:
                color = 'red'
            attrname = '%s_%s' % (color, kind)
            count = nums.get(attrname, 0)
            nums[attrname] = count + 1
            portalfunc_ARGS.append((ARG, attrname, count))
        # presumably lets the RPython annotator unroll the loops below,
        # which build tuples of mixed argument types
        portalfunc_ARGS = unrolling_iterable(portalfunc_ARGS)
        #
        rtyper = self.translator.rtyper
        RESULT = PORTALFUNC.RESULT
        result_kind = history.getkind(RESULT)
        ts = self.cpu.ts

        # Run the portal, looping while the JIT asks to continue normally.
        # The from-start entry hook is tried only before the first call.
        def ll_portal_runner(*args):
            start = True
            while 1:
                try:
                    if start:
                        jd._maybe_enter_from_start_fn(*args)
                    return support.maybe_on_top_of_llinterp(rtyper,
                                                      portal_ptr)(*args)
                except self.ContinueRunningNormally, e:
                    # rebuild the portal arguments from the exception
                    args = ()
                    for ARGTYPE, attrname, count in portalfunc_ARGS:
                        x = getattr(e, attrname)[count]
                        x = specialize_value(ARGTYPE, x)
                        args = args + (x,)
                    start = False
                    continue
                except self.DoneWithThisFrameVoid:
                    assert result_kind == 'void'
                    return
                except self.DoneWithThisFrameInt, e:
                    assert result_kind == 'int'
                    return specialize_value(RESULT, e.result)
                except self.DoneWithThisFrameRef, e:
                    assert result_kind == 'ref'
                    return specialize_value(RESULT, e.result)
                except self.DoneWithThisFrameFloat, e:
                    assert result_kind == 'float'
                    return specialize_value(RESULT, e.result)
                except self.ExitFrameWithExceptionRef, e:
                    # re-raise the application-level exception: as an
                    # LLException when untranslated, else as the instance
                    value = ts.cast_to_baseclass(e.value)
                    if not we_are_translated():
                        raise LLException(ts.get_typeptr(value), value)
                    else:
                        value = cast_base_ptr_to_instance(Exception, value)
                        raise Exception, value

        # Turn a JitException into the portal's result.  Unlike
        # ll_portal_runner(), results are returned unspecialized.
        # Used by assembler_call_helper() and
        # handle_jitexception_from_blackhole() below.
        def handle_jitexception(e):
            # XXX the bulk of this function is mostly a copy-paste from above
            try:
                raise e
            except self.ContinueRunningNormally, e:
                args = ()
                for ARGTYPE, attrname, count in portalfunc_ARGS:
                    x = getattr(e, attrname)[count]
                    x = specialize_value(ARGTYPE, x)
                    args = args + (x,)
                result = ll_portal_runner(*args)
                if result_kind != 'void':
                    result = unspecialize_value(result)
                return result
            except self.DoneWithThisFrameVoid:
                assert result_kind == 'void'
                return
            except self.DoneWithThisFrameInt, e:
                assert result_kind == 'int'
                return e.result
            except self.DoneWithThisFrameRef, e:
                assert result_kind == 'ref'
                return e.result
            except self.DoneWithThisFrameFloat, e:
                assert result_kind == 'float'
                return e.result
            except self.ExitFrameWithExceptionRef, e:
                value = ts.cast_to_baseclass(e.value)
                if not we_are_translated():
                    raise LLException(ts.get_typeptr(value), value)
                else:
                    value = cast_base_ptr_to_instance(Exception, value)
                    raise Exception, value

        jd._ll_portal_runner = ll_portal_runner # for debugging
        jd.portal_runner_ptr = self.helper_func(jd._PTR_PORTAL_FUNCTYPE,
                                                ll_portal_runner)
        jd.portal_runner_adr = llmemory.cast_ptr_to_adr(jd.portal_runner_ptr)
        jd.portal_calldescr = self.cpu.calldescrof(
            jd._PTR_PORTAL_FUNCTYPE.TO,
            jd._PTR_PORTAL_FUNCTYPE.TO.ARGS,
            jd._PTR_PORTAL_FUNCTYPE.TO.RESULT,
            EffectInfo.MOST_GENERAL)

        vinfo = jd.virtualizable_info

        # Called with a fail index and a virtualizable reference (the
        # argument types set up in make_args_specification).  It resets
        # the virtualizable's token if there is one, then lets the fail
        # descr handle the failure.  handle_fail() is expected to always
        # raise a JitException, which is turned into the result.
        def assembler_call_helper(failindex, virtualizableref):
            fail_descr = self.cpu.get_fail_descr_from_number(failindex)
            if vinfo is not None:
                virtualizable = lltype.cast_opaque_ptr(
                    vinfo.VTYPEPTR, virtualizableref)
                vinfo.reset_vable_token(virtualizable)
            try:
                fail_descr.handle_fail(self.metainterp_sd, jd)
            except JitException, e:
                return handle_jitexception(e)
            else:
                assert 0, "should have raised"

        jd._assembler_call_helper = assembler_call_helper # for debugging
        jd._assembler_helper_ptr = self.helper_func(
            jd._PTR_ASSEMBLER_HELPER_FUNCTYPE,
            assembler_call_helper)
        jd.assembler_helper_adr = llmemory.cast_ptr_to_adr(
            jd._assembler_helper_ptr)
        if vinfo is not None:
            jd.vable_token_descr = vinfo.vable_token_descr

        # Same as handle_jitexception(), but the result is stored into the
        # blackhole caller with the _setup_return_value_* method that
        # matches the result kind.
        def handle_jitexception_from_blackhole(bhcaller, e):
            result = handle_jitexception(e)
            if result_kind == 'void':
                pass
            elif result_kind == 'int':
                bhcaller._setup_return_value_i(result)
            elif result_kind == 'ref':
                bhcaller._setup_return_value_r(result)
            elif result_kind == 'float':
                bhcaller._setup_return_value_f(result)
            else:
                assert False
        jd.handle_jitexc_from_bh = handle_jitexception_from_blackhole

        # ____________________________________________________________
        # Now mutate origportalgraph to end with a call to portal_runner_ptr
        #
        origblock, origindex, op = locate_jit_merge_point(origportalgraph)
        assert op.opname == 'jit_marker'
        assert op.args[0].value == 'jit_merge_point'
        greens_v, reds_v = support.decode_hp_hint_args(op)
        vlist = [Constant(jd.portal_runner_ptr, jd._PTR_PORTAL_FUNCTYPE)]
        vlist += greens_v
        vlist += reds_v
        v_result = Variable()
        v_result.concretetype = PORTALFUNC.RESULT
        newop = SpaceOperation('direct_call', vlist, v_result)
        # drop the jit_merge_point and everything after it in this block,
        # and return the portal runner's result directly
        del origblock.operations[origindex:]
        origblock.operations.append(newop)
        origblock.exitswitch = None
        origblock.recloseblock(Link([v_result], origportalgraph.returnblock))
        #
        checkgraph(origportalgraph)
  778. def add_finish(self):
  779. def finish():
  780. if self.metainterp_sd.profiler.initialized:
  781. self.metainterp_sd.profiler.finish()
  782. self.metainterp_sd.cpu.finish_once()
  783. if self.cpu.translate_support_code:
  784. call_final_function(self.translator, finish,
  785. annhelper = self.annhelper)
  786. def rewrite_set_param(self):
  787. from pypy.rpython.lltypesystem.rstr import STR
  788. closures = {}
  789. graphs = self.translator.graphs
  790. _, PTR_SET_PARAM_FUNCTYPE = self.cpu.ts.get_FuncType([lltype.Signed],
  791. lltype.Void)
  792. _, PTR_SET_PARAM_STR_FUNCTYPE = self.cpu.ts.get_FuncType(
  793. [lltype.Ptr(STR)], lltype.Void)
  794. def make_closure(jd, fullfuncname, is_string):
  795. if jd is None:
  796. def closure(i):
  797. if is_string:
  798. i = hlstr(i)
  799. for jd in self.jitdrivers_sd:
  800. getattr(jd.warmstate, fullfuncname)(i)
  801. else:
  802. state = jd.warmstate
  803. def closure(i):
  804. if is_string:
  805. i = hlstr(i)
  806. getattr(state, fullfuncname)(i)
  807. if is_string:
  808. TP = PTR_SET_PARAM_STR_FUNCTYPE
  809. else:
  810. TP = PTR_SET_PARAM_FUNCTYPE
  811. funcptr = self.helper_func(TP, closure)
  812. return Constant(funcptr, TP)
  813. #
  814. for graph, block, i in find_set_param(graphs):
  815. op = block.operations[i]
  816. if op.args[1].value is not None:
  817. for jd in self.jitdrivers_sd:
  818. if jd.jitdriver is op.args[1].value:
  819. break
  820. else:
  821. assert 0, "jitdriver of set_param() not found"
  822. else:
  823. jd = None
  824. funcname = op.args[2].value
  825. key = jd, funcname
  826. if key not in closures:
  827. closures[key] = make_closure(jd, 'set_param_' + funcname,
  828. funcname == 'enable_opts')
  829. op.opname = 'direct_call'
  830. op.args[:3] = [closures[key]]
  831. def rewrite_force_virtual(self, vrefinfo):
  832. if self.cpu.ts.name != 'lltype':
  833. py.test.skip("rewrite_force_virtual: port it to ootype")
  834. all_graphs = self.translator.graphs
  835. vrefinfo.replace_force_virtual_with_call(all_graphs)
  836. def replace_force_quasiimmut_with_direct_call(self, op):
  837. ARG = op.args[0].concretetype
  838. mutatefieldname = op.args[1].value
  839. key = (ARG, mutatefieldname)
  840. if key in self._cache_force_quasiimmed_funcs:
  841. cptr = self._cache_force_quasiimmed_funcs[key]
  842. else:
  843. from pypy.jit.metainterp import quasiimmut
  844. func = quasiimmut.make_invalidation_function(ARG, mutatefieldname)
  845. FUNC = lltype.Ptr(lltype.FuncType([ARG], lltype.Void))
  846. llptr = self.helper_func(FUNC, func)
  847. cptr = Constant(llptr, FUNC)
  848. self._cache_force_quasiimmed_funcs[key] = cptr
  849. op.opname = 'direct_call'
  850. op.args = [cptr, op.args[0]]
  851. def rewrite_force_quasi_immutable(self):
  852. self._cache_force_quasiimmed_funcs = {}
  853. graphs = self.translator.graphs
  854. for graph, block, i in find_force_quasi_immutable(graphs):
  855. self.replace_force_quasiimmut_with_direct_call(block.operations[i])