PageRenderTime 55ms CodeModel.GetById 33ms RepoModel.GetById 1ms app.codeStats 0ms

/rpython/jit/backend/llsupport/test/ztranslation_test.py

https://bitbucket.org/pypy/pypy/
Python | 348 lines | 328 code | 7 blank | 13 comment | 0 complexity | c2074c6eea8eb771998781b080f4919f MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
  1. import os, sys, py
  2. from rpython.tool.udir import udir
  3. from rpython.rlib.jit import JitDriver, unroll_parameters, set_param
  4. from rpython.rlib.jit import PARAMETERS, dont_look_inside
  5. from rpython.rlib.jit import promote, _get_virtualizable_token
  6. from rpython.rlib import jit_hooks, rposix, rgc
  7. from rpython.rlib.objectmodel import keepalive_until_here
  8. from rpython.rlib.rthread import ThreadLocalReference, ThreadLocalField
  9. from rpython.jit.backend.detect_cpu import getcpuclass
  10. from rpython.jit.backend.test.support import CCompiledMixin
  11. from rpython.jit.codewriter.policy import StopAtXPolicy
  12. from rpython.config.config import ConfigError
  13. from rpython.translator.tool.cbuild import ExternalCompilationInfo
  14. from rpython.rtyper.lltypesystem import lltype, rffi, rstr
  15. from rpython.rlib.rjitlog import rjitlog as jl
class TranslationTest(CCompiledMixin):
    """Smoke test: push a feature-rich JIT-driven program through meta_interp.

    The result obtained through ``self.meta_interp`` is compared against the
    result of calling the same ``main`` function directly.
    """
    CPUClass = getcpuclass()

    def test_stuff_translates(self):
        # this is a basic test that tries to hit a number of features and their
        # translation:
        # - jitting of loops and bridges
        # - two virtualizable types
        # - set_param interface
        # - profiler
        # - full optimizer
        # - floats neg and abs
        # - cast_int_to_float
        # - llexternal with macro=True
        # - extra place for the zero after STR instances

        # First virtualizable type: Frame (via its base class BasicFrame).
        class BasicFrame(object):
            _virtualizable_ = ['i']

            def __init__(self, i):
                self.i = i

        class Frame(BasicFrame):
            pass

        # pypy_my_fabs is a C macro wrapping fabs(); it is bound twice below,
        # the two bindings differing only in the releasegil flag.
        eci = ExternalCompilationInfo(post_include_bits=['''
#define pypy_my_fabs(x) fabs(x)
'''], includes=['math.h'])
        myabs1 = rffi.llexternal('pypy_my_fabs', [lltype.Float],
                                 lltype.Float, macro=True, releasegil=False,
                                 compilation_info=eci)
        myabs2 = rffi.llexternal('pypy_my_fabs', [lltype.Float],
                                 lltype.Float, macro=True, releasegil=True,
                                 compilation_info=eci)

        # Location callback handed to the JitDriver; always returns the same
        # (filename, lineno, index) triple.
        @jl.returns(jl.MP_FILENAME,
                    jl.MP_LINENO,
                    jl.MP_INDEX)
        def get_location():
            return ("/home.py",0,0)

        jitdriver = JitDriver(greens = [],
                              reds = ['total', 'frame', 'prev_s', 'j'],
                              virtualizables = ['frame'],
                              get_location = get_location)

        def f(i, j):
            # Exercise set_param for every known parameter (presumably
            # PARAMETERS maps each name to its default -- confirm), then
            # override the two that control when tracing starts.
            for param, _ in unroll_parameters:
                defl = PARAMETERS[param]
                set_param(jitdriver, param, defl)
            set_param(jitdriver, "threshold", 3)
            set_param(jitdriver, "trace_eagerness", 2)
            total = 0
            frame = Frame(i)
            j = float(j)
            prev_s = rstr.mallocstr(16)
            while frame.i > 3:
                jitdriver.can_enter_jit(frame=frame, total=total, j=j,
                                        prev_s=prev_s)
                jitdriver.jit_merge_point(frame=frame, total=total, j=j,
                                          prev_s=prev_s)
                _get_virtualizable_token(frame)
                total += frame.i
                # Step by 3 while frame.i >= 20 and by 1 afterwards, so the
                # loop takes two different paths over its lifetime.
                if frame.i >= 20:
                    frame.i -= 2
                frame.i -= 1
                j *= -0.712
                # float negation must cancel exactly
                if j + (-j): raise ValueError
                j += frame.i
                # both C fabs bindings must agree with the builtin abs()
                k = myabs1(myabs2(j))
                if k - abs(j): raise ValueError
                if k - abs(-j): raise ValueError
                # write the final null char into a fresh STR and into the
                # one carried across iterations ("extra place for the zero")
                s = rstr.mallocstr(16)
                rgc.ll_write_final_null_char(s)
                rgc.ll_write_final_null_char(prev_s)
                if (frame.i & 3) == 0:
                    prev_s = s
            return chr(total % 253)
        #
        # Second virtualizable type, used by the libffi loop below.
        class Virt2(object):
            _virtualizable_ = ['i']

            def __init__(self, i):
                self.i = i

        from rpython.rlib.libffi import types, CDLL, ArgChain
        from rpython.rlib.test.test_clibffi import get_libm_name
        libm_name = get_libm_name(sys.platform)
        jitdriver2 = JitDriver(greens=[], reds = ['v2', 'func', 'res', 'x'],
                               virtualizables = ['v2'])

        def libffi_stuff(i, j):
            # Call fabs() from libm through libffi, i times, inside a jitted
            # loop; returns the result of the last call (0.0 if i <= 0).
            lib = CDLL(libm_name)
            func = lib.getpointer('fabs', [types.double], types.double)
            res = 0.0
            x = float(j)
            v2 = Virt2(i)
            while v2.i > 0:
                jitdriver2.jit_merge_point(v2=v2, res=res, func=func, x=x)
                promote(func)
                argchain = ArgChain()
                argchain.arg(x)
                res = func.call(argchain, rffi.DOUBLE)
                v2.i -= 1
            return res
        #
        def main(i, j):
            # Fold the results of both loops into a single integer.
            a_char = f(i, j)
            a_float = libffi_stuff(i, j)
            return ord(a_char) * 10 + int(a_float)

        # Reference value from a direct call, then the same entry point
        # through meta_interp.
        expected = main(40, -49)
        res = self.meta_interp(main, [40, -49])
        assert res == expected
class TranslationTestCallAssembler(CCompiledMixin):
    """Checks a recursive portal and a stack-overflow probe via meta_interp."""
    CPUClass = getcpuclass()

    def test_direct_assembler_call_translates(self):
        """Test CALL_ASSEMBLER and the recursion limit"""
        # - also tests threadlocalref_get
        from rpython.rlib.rstackovf import StackOverflow

        class Thing(object):
            def __init__(self, val):
                self.val = val

        class Frame(object):
            _virtualizable_ = ['thing']

        driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                           virtualizables = ['frame'],
                           get_printable_location = lambda codeno: str(codeno))

        # Holder through which change() reaches the outermost frame without
        # receiving it as an argument.
        class SomewhereElse(object):
            pass

        somewhere_else = SomewhereElse()

        class Foo(object):
            pass

        # Thread-local state that portal() re-checks on every iteration.
        t = ThreadLocalReference(Foo, loop_invariant=True)
        tf = ThreadLocalField(lltype.Char, "test_call_assembler_")

        def change(newthing):
            # Mutates the frame stored by main(); passed to StopAtXPolicy
            # in the meta_interp call below.
            somewhere_else.frame.thing = newthing

        def main(codeno):
            frame = Frame()
            somewhere_else.frame = frame
            frame.thing = Thing(0)
            portal(codeno, frame)
            return frame.thing.val

        def portal(codeno, frame):
            i = 0
            while i < 10:
                driver.can_enter_jit(frame=frame, codeno=codeno, i=i)
                driver.jit_merge_point(frame=frame, codeno=codeno, i=i)
                nextval = frame.thing.val
                if codeno == 0:
                    # outer level: recurse into the portal with codeno 1
                    subframe = Frame()
                    subframe.thing = Thing(nextval)
                    nextval = portal(1, subframe)
                elif frame.thing.val > 40:
                    # inner level: reset the outermost frame via change()
                    change(Thing(13))
                    nextval = 13
                frame.thing = Thing(nextval + 1)
                i += 1
                # the thread-local values installed by setup() must be
                # visible from here
                if t.get().nine != 9: raise ValueError
                if ord(tf.getraw()) != 0x92: raise ValueError
            return frame.thing.val

        driver2 = JitDriver(greens = [], reds = ['n'])

        def main2(bound):
            # Keep doubling the recursion depth while portal2() still returns
            # the expected value; stop at the first StackOverflow and return
            # the depth reached.
            try:
                while portal2(bound) == -bound+1:
                    bound *= 2
            except StackOverflow:
                pass
            return bound

        def portal2(n):
            # Recurses once per unit of n; for n >= 1 the result is -n+1
            # (checked for n == 10 just below).
            while True:
                driver2.jit_merge_point(n=n)
                n -= 1
                if n <= 0:
                    return n
                n = portal2(n)
        assert portal2(10) == -9

        def setup(value):
            # Install the thread-local reference and raw field read back by
            # portal(); returns the Foo so the caller can hold on to it.
            foo = Foo()
            foo.nine = value
            t.set(foo)
            tf.setraw("\x92")
            return foo

        def mainall(codeno, bound):
            # bound + 8 equals the 9 that portal() expects when bound is 1,
            # which is the value passed to meta_interp below.
            foo = setup(bound + 8)
            result = main(codeno) + main2(bound)
            keepalive_until_here(foo)
            return result

        # portal() checks the thread-locals, so they must be set before the
        # direct main(0) call. tmp_obj is otherwise unused -- presumably it
        # only keeps the Foo referenced; confirm before removing.
        tmp_obj = setup(9)
        expected_1 = main(0)
        res = self.meta_interp(mainall, [0, 1], inline=True,
                               policy=StopAtXPolicy(change))
        print hex(res)
        # res is main(0) + main2(1): the low 8 bits are compared with the
        # direct main(0) result, the remaining bits are the depth reached by
        # main2().
        assert res & 255 == expected_1
        bound = res & ~255
        assert 1024 <= bound <= 131072
        assert bound & (bound-1) == 0 # a power of two
  201. class TranslationTestJITStats(CCompiledMixin):
  202. CPUClass = getcpuclass()
  203. def test_jit_get_stats(self):
  204. py.test.skip("disabled feature")
  205. driver = JitDriver(greens = [], reds = ['i'])
  206. def f():
  207. i = 0
  208. while i < 100000:
  209. driver.jit_merge_point(i=i)
  210. i += 1
  211. def main():
  212. jit_hooks.stats_set_debug(None, True)
  213. f()
  214. ll_times = jit_hooks.stats_get_loop_run_times(None)
  215. return len(ll_times)
  216. res = self.meta_interp(main, [])
  217. assert res == 2
  218. # one for loop and one for the prologue, no unrolling
  219. def test_flush_trace_counts(self):
  220. driver = JitDriver(greens = [], reds = ['i'])
  221. def f():
  222. i = 0
  223. while i < 100000:
  224. driver.jit_merge_point(i=i)
  225. i += 1
  226. def main():
  227. jit_hooks.stats_set_debug(None, True)
  228. f()
  229. jl.stats_flush_trace_counts(None)
  230. return 0
  231. res = self.meta_interp(main, [])
  232. assert res == 0
  233. class TranslationRemoveTypePtrTest(CCompiledMixin):
  234. CPUClass = getcpuclass()
  235. def test_external_exception_handling_translates(self):
  236. jitdriver = JitDriver(greens = [], reds = ['n', 'total'])
  237. class ImDone(Exception):
  238. def __init__(self, resvalue):
  239. self.resvalue = resvalue
  240. @dont_look_inside
  241. def f(x, total):
  242. if x <= 30:
  243. raise ImDone(total * 10)
  244. if x > 200:
  245. return 2
  246. raise ValueError
  247. @dont_look_inside
  248. def g(x):
  249. if x > 150:
  250. raise ValueError
  251. return 2
  252. class Base:
  253. def meth(self):
  254. return 2
  255. class Sub(Base):
  256. def meth(self):
  257. return 1
  258. @dont_look_inside
  259. def h(x):
  260. if x < 20000:
  261. return Sub()
  262. else:
  263. return Base()
  264. def myportal(i):
  265. set_param(jitdriver, "threshold", 3)
  266. set_param(jitdriver, "trace_eagerness", 2)
  267. total = 0
  268. n = i
  269. while True:
  270. jitdriver.can_enter_jit(n=n, total=total)
  271. jitdriver.jit_merge_point(n=n, total=total)
  272. try:
  273. total += f(n, total)
  274. except ValueError:
  275. total += 1
  276. try:
  277. total += g(n)
  278. except ValueError:
  279. total -= 1
  280. n -= h(n).meth() # this is to force a GUARD_CLASS
  281. def main(i):
  282. try:
  283. myportal(i)
  284. except ImDone as e:
  285. return e.resvalue
  286. # XXX custom fishing, depends on the exact env var and format
  287. logfile = udir.join('test_ztranslation.log')
  288. os.environ['PYPYLOG'] = 'jit-log-opt:%s' % (logfile,)
  289. try:
  290. res = self.meta_interp(main, [400])
  291. assert res == main(400)
  292. except ConfigError as e:
  293. assert str(e).startswith('invalid value asmgcc')
  294. py.test.skip('asmgcc not supported')
  295. finally:
  296. del os.environ['PYPYLOG']
  297. guard_class = 0
  298. for line in open(str(logfile)):
  299. if 'guard_class' in line:
  300. guard_class += 1
  301. # if we get many more guard_classes (~93), it means that we generate
  302. # guards that always fail (the following assert's original purpose
  303. # is to catch the following case: each GUARD_CLASS is misgenerated
  304. # and always fails with "gcremovetypeptr")
  305. assert 0 < guard_class < 10