PageRenderTime 52ms CodeModel.GetById 16ms RepoModel.GetById 0ms app.codeStats 0ms

/rpython/translator/backendopt/mallocv.py

https://bitbucket.org/pypy/pypy/
Python | 1055 lines | 967 code | 73 blank | 15 comment | 94 complexity | d84b99cdc8fe0352e3d63929b8bd1f8f MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
  1. from rpython.flowspace.model import Variable, Constant, Block, Link
  2. from rpython.flowspace.model import SpaceOperation, copygraph
  3. from rpython.flowspace.model import checkgraph
  4. from rpython.translator.backendopt.support import log
  5. from rpython.translator.simplify import join_blocks
  6. from rpython.translator.unsimplify import varoftype
  7. from rpython.rtyper.lltypesystem.lltype import getfunctionptr
  8. from rpython.rtyper.lltypesystem import lltype
  9. from rpython.rtyper.lltypesystem.lloperation import llop
  10. def virtualize_mallocs(translator, graphs, verbose=False):
  11. newgraphs = graphs[:]
  12. mallocv = MallocVirtualizer(newgraphs, translator.rtyper, verbose)
  13. while mallocv.remove_mallocs_once():
  14. pass
  15. for graph in newgraphs:
  16. checkgraph(graph)
  17. join_blocks(graph)
  18. assert newgraphs[:len(graphs)] == graphs
  19. del newgraphs[:len(graphs)]
  20. translator.graphs.extend(newgraphs)
  21. # ____________________________________________________________
  22. class MallocTypeDesc(object):
  23. def __init__(self, MALLOCTYPE):
  24. if not isinstance(MALLOCTYPE, lltype.GcStruct):
  25. raise CannotRemoveThisType
  26. self.MALLOCTYPE = MALLOCTYPE
  27. self.check_no_destructor()
  28. self.names_and_types = []
  29. self.name2index = {}
  30. self.name2subtype = {}
  31. self.initialize_type(MALLOCTYPE)
  32. #self.immutable_struct = MALLOCTYPE._hints.get('immutable')
  33. def check_no_destructor(self):
  34. STRUCT = self.MALLOCTYPE
  35. try:
  36. rttiptr = lltype.getRuntimeTypeInfo(STRUCT)
  37. except ValueError:
  38. return # ok
  39. destr_ptr = getattr(rttiptr._obj, 'destructor_funcptr', None)
  40. if destr_ptr:
  41. raise CannotRemoveThisType
  42. def initialize_type(self, TYPE):
  43. fieldnames = TYPE._names
  44. firstname, FIRSTTYPE = TYPE._first_struct()
  45. if FIRSTTYPE is not None:
  46. self.initialize_type(FIRSTTYPE)
  47. fieldnames = fieldnames[1:]
  48. for name in fieldnames:
  49. FIELDTYPE = TYPE._flds[name]
  50. if isinstance(FIELDTYPE, lltype.ContainerType):
  51. raise CannotRemoveThisType("inlined substructure")
  52. self.name2index[name] = len(self.names_and_types)
  53. self.names_and_types.append((name, FIELDTYPE))
  54. self.name2subtype[name] = TYPE
  55. class SpecNode(object):
  56. pass
  57. class RuntimeSpecNode(SpecNode):
  58. def __init__(self, name, TYPE):
  59. self.name = name
  60. self.TYPE = TYPE
  61. def newvar(self):
  62. v = Variable(self.name)
  63. v.concretetype = self.TYPE
  64. return v
  65. def getfrozenkey(self, memo):
  66. return 'R'
  67. def accumulate_nodes(self, rtnodes, vtnodes):
  68. rtnodes.append(self)
  69. def copy(self, memo, flagreadonly):
  70. return RuntimeSpecNode(self.name, self.TYPE)
  71. def bind_rt_nodes(self, memo, newnodes_iter):
  72. return newnodes_iter.next()
  73. class VirtualSpecNode(SpecNode):
  74. def __init__(self, typedesc, fields, readonly=False):
  75. self.typedesc = typedesc
  76. self.fields = fields # list of SpecNodes
  77. self.readonly = readonly
  78. def getfrozenkey(self, memo):
  79. if self in memo:
  80. return memo[self]
  81. else:
  82. memo[self] = len(memo)
  83. result = [self.typedesc, self.readonly]
  84. for subnode in self.fields:
  85. result.append(subnode.getfrozenkey(memo))
  86. return tuple(result)
  87. def accumulate_nodes(self, rtnodes, vtnodes):
  88. if self in vtnodes:
  89. return
  90. vtnodes[self] = True
  91. for subnode in self.fields:
  92. subnode.accumulate_nodes(rtnodes, vtnodes)
  93. def copy(self, memo, flagreadonly):
  94. if self in memo:
  95. return memo[self]
  96. readonly = self.readonly or self in flagreadonly
  97. newnode = VirtualSpecNode(self.typedesc, [], readonly)
  98. memo[self] = newnode
  99. for subnode in self.fields:
  100. newnode.fields.append(subnode.copy(memo, flagreadonly))
  101. return newnode
  102. def bind_rt_nodes(self, memo, newnodes_iter):
  103. if self in memo:
  104. return memo[self]
  105. newnode = VirtualSpecNode(self.typedesc, [], self.readonly)
  106. memo[self] = newnode
  107. for subnode in self.fields:
  108. newnode.fields.append(subnode.bind_rt_nodes(memo, newnodes_iter))
  109. return newnode
  110. class VirtualFrame(object):
  111. def __init__(self, sourceblock, nextopindex,
  112. allnodes, callerframe=None, calledgraphs={}):
  113. if isinstance(allnodes, dict):
  114. self.varlist = vars_alive_through_op(sourceblock, nextopindex)
  115. self.nodelist = [allnodes[v] for v in self.varlist]
  116. else:
  117. assert nextopindex == 0
  118. self.varlist = sourceblock.inputargs
  119. self.nodelist = allnodes[:]
  120. self.sourceblock = sourceblock
  121. self.nextopindex = nextopindex
  122. self.callerframe = callerframe
  123. self.calledgraphs = calledgraphs
  124. def get_nodes_in_use(self):
  125. return dict(zip(self.varlist, self.nodelist))
  126. def shallowcopy(self):
  127. newframe = VirtualFrame.__new__(VirtualFrame)
  128. newframe.varlist = self.varlist
  129. newframe.nodelist = self.nodelist
  130. newframe.sourceblock = self.sourceblock
  131. newframe.nextopindex = self.nextopindex
  132. newframe.callerframe = self.callerframe
  133. newframe.calledgraphs = self.calledgraphs
  134. return newframe
  135. def copy(self, memo, flagreadonly={}):
  136. newframe = self.shallowcopy()
  137. newframe.nodelist = [node.copy(memo, flagreadonly)
  138. for node in newframe.nodelist]
  139. if newframe.callerframe is not None:
  140. newframe.callerframe = newframe.callerframe.copy(memo,
  141. flagreadonly)
  142. return newframe
  143. def enum_call_stack(self):
  144. frame = self
  145. while frame is not None:
  146. yield frame
  147. frame = frame.callerframe
  148. def getfrozenkey(self):
  149. memo = {}
  150. key = []
  151. for frame in self.enum_call_stack():
  152. key.append(frame.sourceblock)
  153. key.append(frame.nextopindex)
  154. for node in frame.nodelist:
  155. key.append(node.getfrozenkey(memo))
  156. return tuple(key)
  157. def find_all_nodes(self):
  158. rtnodes = []
  159. vtnodes = {}
  160. for frame in self.enum_call_stack():
  161. for node in frame.nodelist:
  162. node.accumulate_nodes(rtnodes, vtnodes)
  163. return rtnodes, vtnodes
  164. def find_rt_nodes(self):
  165. rtnodes, vtnodes = self.find_all_nodes()
  166. return rtnodes
  167. def find_vt_nodes(self):
  168. rtnodes, vtnodes = self.find_all_nodes()
  169. return vtnodes
  170. def copynodes(nodelist, flagreadonly={}):
  171. memo = {}
  172. return [node.copy(memo, flagreadonly) for node in nodelist]
  173. def find_all_nodes(nodelist):
  174. rtnodes = []
  175. vtnodes = {}
  176. for node in nodelist:
  177. node.accumulate_nodes(rtnodes, vtnodes)
  178. return rtnodes, vtnodes
  179. def is_trivial_nodelist(nodelist):
  180. for node in nodelist:
  181. if not isinstance(node, RuntimeSpecNode):
  182. return False
  183. return True
  184. def bind_rt_nodes(srcnodelist, newnodes_list):
  185. """Return srcnodelist with all RuntimeNodes replaced by nodes
  186. coming from newnodes_list.
  187. """
  188. memo = {}
  189. newnodes_iter = iter(newnodes_list)
  190. result = [node.bind_rt_nodes(memo, newnodes_iter) for node in srcnodelist]
  191. rest = list(newnodes_iter)
  192. assert rest == [], "too many nodes in newnodes_list"
  193. return result
  194. class CannotVirtualize(Exception):
  195. pass
  196. class ForcedInline(Exception):
  197. pass
  198. class CannotRemoveThisType(Exception):
  199. pass
  200. # ____________________________________________________________
  201. class MallocVirtualizer(object):
  202. def __init__(self, graphs, rtyper, verbose=False):
  203. self.graphs = graphs
  204. self.rtyper = rtyper
  205. self.excdata = rtyper.exceptiondata
  206. self.graphbuilders = {}
  207. self.specialized_graphs = {}
  208. self.specgraphorigin = {}
  209. self.inline_and_remove = {} # {graph: op_to_remove}
  210. self.inline_and_remove_seen = {} # set of (graph, op_to_remove)
  211. self.malloctypedescs = {}
  212. self.count_virtualized = 0
  213. self.verbose = verbose
  214. self.EXCTYPE_to_vtable = self.build_obscure_mapping()
  215. def build_obscure_mapping(self):
  216. result = {}
  217. for rinstance in self.rtyper.instance_reprs.values():
  218. result[rinstance.lowleveltype.TO] = rinstance.rclass.getvtable()
  219. return result
  220. def report_result(self, progress):
  221. if progress:
  222. log.mallocv('removed %d mallocs so far' % self.count_virtualized)
  223. else:
  224. log.mallocv('done')
  225. def enum_all_mallocs(self, graph):
  226. for block in graph.iterblocks():
  227. for op in block.operations:
  228. if op.opname == 'malloc':
  229. MALLOCTYPE = op.result.concretetype.TO
  230. try:
  231. self.getmalloctypedesc(MALLOCTYPE)
  232. except CannotRemoveThisType:
  233. pass
  234. else:
  235. yield (block, op)
  236. elif op.opname == 'direct_call':
  237. graph = graph_called_by(op)
  238. if graph in self.inline_and_remove:
  239. yield (block, op)
  240. def remove_mallocs_once(self):
  241. self.flush_failed_specializations()
  242. prev = self.count_virtualized
  243. count_inline_and_remove = len(self.inline_and_remove)
  244. for graph in self.graphs:
  245. seen = {}
  246. while True:
  247. for block, op in self.enum_all_mallocs(graph):
  248. if op.result not in seen:
  249. seen[op.result] = True
  250. if self.try_remove_malloc(graph, block, op):
  251. break # graph mutated, restart enum_all_mallocs()
  252. else:
  253. break # enum_all_mallocs() exhausted, graph finished
  254. progress1 = self.count_virtualized - prev
  255. progress2 = len(self.inline_and_remove) - count_inline_and_remove
  256. progress = progress1 or bool(progress2)
  257. self.report_result(progress)
  258. return progress
  259. def flush_failed_specializations(self):
  260. for key, (mode, specgraph) in self.specialized_graphs.items():
  261. if mode == 'fail':
  262. del self.specialized_graphs[key]
  263. def fixup_except_block(self, exceptblock):
  264. # hack: this block's inputargs may be missing concretetypes...
  265. e1, v1 = exceptblock.inputargs
  266. e1.concretetype = self.excdata.lltype_of_exception_type
  267. v1.concretetype = self.excdata.lltype_of_exception_value
  268. def getmalloctypedesc(self, MALLOCTYPE):
  269. try:
  270. dsc = self.malloctypedescs[MALLOCTYPE]
  271. except KeyError:
  272. dsc = self.malloctypedescs[MALLOCTYPE] = MallocTypeDesc(MALLOCTYPE)
  273. return dsc
  274. def try_remove_malloc(self, graph, block, op):
  275. if (graph, op) in self.inline_and_remove_seen:
  276. return False # no point in trying again
  277. graphbuilder = GraphBuilder(self, graph)
  278. if graph in self.graphbuilders:
  279. graphbuilder.initialize_from_old_builder(self.graphbuilders[graph])
  280. graphbuilder.start_from_a_malloc(graph, block, op.result)
  281. try:
  282. graphbuilder.propagate_specializations()
  283. except CannotVirtualize as e:
  284. self.logresult(op, 'failed', e)
  285. return False
  286. except ForcedInline as e:
  287. self.logresult(op, 'forces inlining', e)
  288. self.inline_and_remove[graph] = op
  289. self.inline_and_remove_seen[graph, op] = True
  290. return False
  291. else:
  292. self.logresult(op, 'removed')
  293. graphbuilder.finished_removing_malloc()
  294. self.graphbuilders[graph] = graphbuilder
  295. self.count_virtualized += 1
  296. return True
  297. def logresult(self, op, msg, exc=None): # only for nice log outputs
  298. if self.verbose:
  299. if exc is None:
  300. exc = ''
  301. else:
  302. exc = ': %s' % (exc,)
  303. chain = []
  304. while True:
  305. chain.append(str(op.result))
  306. if op.opname != 'direct_call':
  307. break
  308. fobj = op.args[0].value._obj
  309. op = self.inline_and_remove[fobj.graph]
  310. log.mallocv('%s %s%s' % ('->'.join(chain), msg, exc))
  311. elif exc is None:
  312. log.dot()
  313. def get_specialized_graph(self, graph, nodelist):
  314. assert len(graph.getargs()) == len(nodelist)
  315. if is_trivial_nodelist(nodelist):
  316. return 'trivial', graph
  317. if graph in self.specgraphorigin:
  318. orggraph, orgnodelist = self.specgraphorigin[graph]
  319. nodelist = bind_rt_nodes(orgnodelist, nodelist)
  320. graph = orggraph
  321. virtualframe = VirtualFrame(graph.startblock, 0, nodelist)
  322. key = virtualframe.getfrozenkey()
  323. try:
  324. return self.specialized_graphs[key]
  325. except KeyError:
  326. self.build_specialized_graph(graph, key, nodelist)
  327. return self.specialized_graphs[key]
  328. def build_specialized_graph(self, graph, key, nodelist):
  329. graph2 = copygraph(graph)
  330. virtualframe = VirtualFrame(graph2.startblock, 0, nodelist)
  331. graphbuilder = GraphBuilder(self, graph2)
  332. specblock = graphbuilder.start_from_virtualframe(virtualframe)
  333. specgraph = graph2
  334. specgraph.name += '_mallocv'
  335. specgraph.startblock = specblock
  336. self.specialized_graphs[key] = ('call', specgraph)
  337. try:
  338. graphbuilder.propagate_specializations()
  339. except ForcedInline as e:
  340. if self.verbose:
  341. log.mallocv('%s inlined: %s' % (graph.name, e))
  342. self.specialized_graphs[key] = ('inline', None)
  343. except CannotVirtualize as e:
  344. if self.verbose:
  345. log.mallocv('%s failing: %s' % (graph.name, e))
  346. self.specialized_graphs[key] = ('fail', None)
  347. else:
  348. self.graphbuilders[specgraph] = graphbuilder
  349. self.specgraphorigin[specgraph] = graph, nodelist
  350. self.graphs.append(specgraph)
class GraphBuilder(object):
    """Builds specialized blocks for one graph.

    BlockSpecializers are queued in 'pending_specializations' and run by
    propagate_specializations().  Created blocks are cached in
    'specialized_blocks', keyed by the frozen key of the VirtualFrame
    they were created from.
    """

    def __init__(self, mallocv, graph):
        self.mallocv = mallocv
        self.graph = graph
        self.specialized_blocks = {}        # {frozen frame key: specblock}
        self.pending_specializations = []   # BlockSpecializers to process

    def initialize_from_old_builder(self, oldbuilder):
        # reuse the blocks already specialized by an earlier builder
        self.specialized_blocks.update(oldbuilder.specialized_blocks)

    def start_from_virtualframe(self, startframe):
        """Queue the specialization of 'startframe'; return the specialized
        block that will be its entry (filled in later)."""
        spec = BlockSpecializer(self)
        spec.initialize_renamings(startframe)
        self.pending_specializations.append(spec)
        return spec.specblock

    def start_from_a_malloc(self, graph, block, v_result):
        """Queue the specialization of 'block' in which the malloc whose
        result is 'v_result' gets expanded.  The specialized block keeps
        the input variables of 'block'; finished_removing_malloc() copies
        it back into 'block'.
        """
        # NOTE(review): the 'graph' argument is unused here
        assert v_result in [op.result for op in block.operations]
        nodelist = []
        for v in block.inputargs:
            nodelist.append(RuntimeSpecNode(v, v.concretetype))
        trivialframe = VirtualFrame(block, 0, nodelist)
        spec = BlockSpecializer(self, v_result)
        spec.initialize_renamings(trivialframe, keep_inputargs=True)
        self.pending_specializations.append(spec)
        self.pending_patch = (block, spec.specblock)

    def finished_removing_malloc(self):
        # overwrite the original block in place with its specialized
        # version, so links that already target 'srcblock' reach it
        (srcblock, specblock) = self.pending_patch
        srcblock.inputargs = specblock.inputargs
        srcblock.operations = specblock.operations
        srcblock.exitswitch = specblock.exitswitch
        srcblock.recloseblock(*specblock.exits)

    def create_outgoing_link(self, currentframe, targetblock,
                             nodelist, renamings, v_expand_malloc=None):
        """Return a Link from the block being specialized in 'currentframe'
        to 'targetblock', whose input values are described by 'nodelist'.

        'renamings' maps run-time nodes to the variables or constants that
        hold them; the link passes those values.
        """
        assert len(nodelist) == len(targetblock.inputargs)
        #
        if is_except(targetblock):
            v_expand_malloc = None
            # walk up the inlined caller frames, looking for one whose
            # pending call catches the exception
            while currentframe.callerframe is not None:
                currentframe = currentframe.callerframe
                newlink = self.handle_catch(currentframe, nodelist, renamings)
                if newlink:
                    return newlink
            else:
                # nobody caught it: it leaves the outermost frame
                targetblock = self.exception_escapes(nodelist, renamings)
                assert len(nodelist) == len(targetblock.inputargs)

        if (currentframe.callerframe is None and
              is_trivial_nodelist(nodelist)):
            # there is no more VirtualSpecNodes being passed around,
            # so we can stop specializing
            rtnodes = nodelist
            specblock = targetblock
        else:
            if is_return(targetblock):
                v_expand_malloc = None
                newframe = self.return_to_caller(currentframe, nodelist[0])
            else:
                targetnodes = dict(zip(targetblock.inputargs, nodelist))
                newframe = VirtualFrame(targetblock, 0, targetnodes,
                                        callerframe=currentframe.callerframe,
                                        calledgraphs=currentframe.calledgraphs)
            rtnodes = newframe.find_rt_nodes()
            specblock = self.get_specialized_block(newframe, v_expand_malloc)

        linkargs = [renamings[rtnode] for rtnode in rtnodes]
        return Link(linkargs, specblock)

    def return_to_caller(self, currentframe, retnode):
        """Return a copy of the caller frame of 'currentframe' in which the
        FutureReturnValue placeholders are replaced by 'retnode'.

        Raises ForcedInline if there is no caller frame.
        """
        callerframe = currentframe.callerframe
        if callerframe is None:
            raise ForcedInline("return block")
        nodelist = callerframe.nodelist
        callerframe = callerframe.shallowcopy()
        callerframe.nodelist = []
        for node in nodelist:
            if isinstance(node, FutureReturnValue):
                node = retnode
            callerframe.nodelist.append(node)
        return callerframe

    def handle_catch(self, catchingframe, nodelist, renamings):
        """If the pending call of 'catchingframe' is guarded by exception
        links, follow the first link matching the exception described by
        'nodelist' ([type node, value node]) and return the new Link.
        Return None if nothing catches the exception there.

        Raises CannotVirtualize if the exception type is not known.
        """
        if not self.has_exception_catching(catchingframe):
            return None
        [exc_node, exc_value_node] = nodelist
        v_exc_type = renamings.get(exc_node)
        if isinstance(v_exc_type, Constant):
            exc_type = v_exc_type.value
        elif isinstance(exc_value_node, VirtualSpecNode):
            # the type of a virtual exception value is known statically
            EXCTYPE = exc_value_node.typedesc.MALLOCTYPE
            exc_type = self.mallocv.EXCTYPE_to_vtable[EXCTYPE]
        else:
            raise CannotVirtualize("raising non-constant exc type")
        excdata = self.mallocv.excdata
        assert catchingframe.sourceblock.exits[0].exitcase is None
        for catchlink in catchingframe.sourceblock.exits[1:]:
            if excdata.fn_exception_match(exc_type, catchlink.llexitcase):
                # Match found.  Follow this link.
                mynodes = catchingframe.get_nodes_in_use()
                for node, attr in zip(nodelist,
                                      ['last_exception', 'last_exc_value']):
                    v = getattr(catchlink, attr)
                    if isinstance(v, Variable):
                        mynodes[v] = node
                #
                nodelist = []
                for v in catchlink.args:
                    if isinstance(v, Variable):
                        node = mynodes[v]
                    else:
                        node = getconstnode(v, renamings)
                    nodelist.append(node)
                return self.create_outgoing_link(catchingframe,
                                                 catchlink.target,
                                                 nodelist, renamings)
        else:
            # No match at all, propagate the exception to the caller
            return None

    def has_exception_catching(self, catchingframe):
        """True if the frame's source block catches exceptions and the
        frame is positioned after its last operation (the one guarded by
        the exception links)."""
        if not catchingframe.sourceblock.canraise:
            return False
        else:
            operations = catchingframe.sourceblock.operations
            assert 1 <= catchingframe.nextopindex <= len(operations)
            return catchingframe.nextopindex == len(operations)

    def exception_escapes(self, nodelist, renamings):
        """Return the block to jump to for an exception that leaves the
        graph: the graph's except block, or (for a virtual exception value
        of non-constant type) a block that rebuilds the exception first.

        Raises CannotVirtualize when neither is possible.
        """
        # the exception escapes
        if not is_trivial_nodelist(nodelist):
            # start of hacks to help handle_catch()
            [exc_node, exc_value_node] = nodelist
            v_exc_type = renamings.get(exc_node)
            if isinstance(v_exc_type, Constant):
                # cannot improve: handle_catch() would already be happy
                # by seeing the exc_type as a constant
                pass
            elif isinstance(exc_value_node, VirtualSpecNode):
                # can improve with a strange hack: we pretend that
                # the source code jumps to a block that itself allocates
                # the exception, sets all fields, and raises it by
                # passing a constant type.
                typedesc = exc_value_node.typedesc
                return self.get_exc_reconstruction_block(typedesc)
            else:
                # cannot improve: handle_catch() will have no clue about
                # the exception type
                pass
            raise CannotVirtualize("except block")
        targetblock = self.graph.exceptblock
        self.mallocv.fixup_except_block(targetblock)
        return targetblock

    def get_exc_reconstruction_block(self, typedesc):
        """Return a new block taking (exc_type, exc_value) that mallocs a
        fresh typedesc.MALLOCTYPE, copies every field from the incoming
        exception value into it, and jumps to the graph's except block
        with the constant vtable of MALLOCTYPE as exception type."""
        exceptblock = self.graph.exceptblock
        self.mallocv.fixup_except_block(exceptblock)
        TEXC = exceptblock.inputargs[0].concretetype
        TVAL = exceptblock.inputargs[1].concretetype
        #
        v_ignored_type = varoftype(TEXC)
        v_incoming_value = varoftype(TVAL)
        block = Block([v_ignored_type, v_incoming_value])
        #
        c_EXCTYPE = Constant(typedesc.MALLOCTYPE, lltype.Void)
        v = varoftype(lltype.Ptr(typedesc.MALLOCTYPE))
        c_flavor = Constant({'flavor': 'gc'}, lltype.Void)
        op = SpaceOperation('malloc', [c_EXCTYPE, c_flavor], v)
        block.operations.append(op)
        #
        # copy each field, casting both pointers to the struct that
        # declares the field
        for name, FIELDTYPE in typedesc.names_and_types:
            EXACTPTR = lltype.Ptr(typedesc.name2subtype[name])
            c_name = Constant(name)
            c_name.concretetype = lltype.Void
            #
            v_in = varoftype(EXACTPTR)
            op = SpaceOperation('cast_pointer', [v_incoming_value], v_in)
            block.operations.append(op)
            #
            v_field = varoftype(FIELDTYPE)
            op = SpaceOperation('getfield', [v_in, c_name], v_field)
            block.operations.append(op)
            #
            v_out = varoftype(EXACTPTR)
            op = SpaceOperation('cast_pointer', [v], v_out)
            block.operations.append(op)
            #
            v0 = varoftype(lltype.Void)
            op = SpaceOperation('setfield', [v_out, c_name, v_field], v0)
            block.operations.append(op)
        #
        v_exc_value = varoftype(TVAL)
        op = SpaceOperation('cast_pointer', [v], v_exc_value)
        block.operations.append(op)
        #
        exc_type = self.mallocv.EXCTYPE_to_vtable[typedesc.MALLOCTYPE]
        c_exc_type = Constant(exc_type, TEXC)
        block.closeblock(Link([c_exc_type, v_exc_value], exceptblock))
        return block

    def get_specialized_block(self, virtualframe, v_expand_malloc=None):
        """Return the specialized block for 'virtualframe', creating and
        queuing a BlockSpecializer for it if it is not cached yet."""
        # NOTE(review): the cache key does not include v_expand_malloc,
        # which only matters when the block is first created
        key = virtualframe.getfrozenkey()
        specblock = self.specialized_blocks.get(key)
        if specblock is None:
            orgblock = virtualframe.sourceblock
            assert len(orgblock.exits) != 0
            spec = BlockSpecializer(self, v_expand_malloc)
            spec.initialize_renamings(virtualframe)
            self.pending_specializations.append(spec)
            specblock = spec.specblock
            self.specialized_blocks[key] = specblock
        return specblock

    def propagate_specializations(self):
        """Run queued BlockSpecializers (last queued first) until none is
        left; following exits may queue more of them."""
        while self.pending_specializations:
            spec = self.pending_specializations.pop()
            spec.specialize_operations()
            spec.follow_exits()
  556. class BlockSpecializer(object):
  557. def __init__(self, graphbuilder, v_expand_malloc=None):
  558. self.graphbuilder = graphbuilder
  559. self.v_expand_malloc = v_expand_malloc
  560. self.specblock = Block([])
  561. def initialize_renamings(self, virtualframe, keep_inputargs=False):
  562. # we make a copy of the original 'virtualframe' because the
  563. # specialize_operations() will mutate some of its content.
  564. virtualframe = virtualframe.copy({})
  565. self.virtualframe = virtualframe
  566. self.nodes = virtualframe.get_nodes_in_use()
  567. self.renamings = {} # {RuntimeSpecNode(): Variable()}
  568. if keep_inputargs:
  569. assert virtualframe.varlist == virtualframe.sourceblock.inputargs
  570. specinputargs = []
  571. for i, rtnode in enumerate(virtualframe.find_rt_nodes()):
  572. if keep_inputargs:
  573. v = virtualframe.varlist[i]
  574. assert v.concretetype == rtnode.TYPE
  575. else:
  576. v = rtnode.newvar()
  577. self.renamings[rtnode] = v
  578. specinputargs.append(v)
  579. self.specblock.inputargs = specinputargs
  580. def setnode(self, v, node):
  581. assert v not in self.nodes
  582. self.nodes[v] = node
  583. def getnode(self, v):
  584. if isinstance(v, Variable):
  585. return self.nodes[v]
  586. else:
  587. return getconstnode(v, self.renamings)
  588. def rename_nonvirtual(self, v, where=None):
  589. if not isinstance(v, Variable):
  590. return v
  591. node = self.nodes[v]
  592. if not isinstance(node, RuntimeSpecNode):
  593. raise CannotVirtualize(where)
  594. return self.renamings[node]
  595. def expand_nodes(self, nodelist):
  596. rtnodes, vtnodes = find_all_nodes(nodelist)
  597. return [self.renamings[rtnode] for rtnode in rtnodes]
    def specialize_operations(self):
        """Translate the source operations, from the frame's nextopindex to
        the end of its block, into self.specblock.operations.

        Each operation is dispatched to handle_op_<opname> (or
        handle_default), which returns the list of operations to emit.

        Raises CannotVirtualize if a residual direct_call targets a graph
        listed in the frame's calledgraphs.
        """
        newoperations = []
        # number of operations emitted for the last source operation;
        # follow_exits() checks it when the block catches exceptions
        self.ops_produced_by_last_op = 0
        # note that 'self.virtualframe' can be changed during the loop!
        while True:
            operations = self.virtualframe.sourceblock.operations
            try:
                op = operations[self.virtualframe.nextopindex]
                self.virtualframe.nextopindex += 1
            except IndexError:
                break

            meth = getattr(self, 'handle_op_' + op.opname,
                           self.handle_default)
            newops_for_this_op = meth(op)
            newoperations += newops_for_this_op
            self.ops_produced_by_last_op = len(newops_for_this_op)
        # a residual call into one of 'calledgraphs' would be recursion
        for op in newoperations:
            if op.opname == 'direct_call':
                graph = graph_called_by(op)
                if graph in self.virtualframe.calledgraphs:
                    raise CannotVirtualize("recursion in residual call")
        self.specblock.operations = newoperations
    def follow_exits(self):
        """Close self.specblock with links mirroring the exits of the
        frame's source block, folding the exit switch where possible."""
        block = self.virtualframe.sourceblock
        self.specblock.exitswitch = self.rename_nonvirtual(block.exitswitch,
                                                           'exitswitch')
        links = block.exits
        catch_exc = self.specblock.canraise

        if not catch_exc and isinstance(self.specblock.exitswitch, Constant):
            # constant-fold the switch
            for link in links:
                if link.exitcase == 'default':
                    break
                if link.llexitcase == self.specblock.exitswitch.value:
                    break
            else:
                raise Exception("exit case not found?")
            links = (link,)
            self.specblock.exitswitch = None

        if catch_exc and self.ops_produced_by_last_op == 0:
            # the last op of the sourceblock did not produce any
            # operation in specblock, so we need to discard the
            # exception-catching.
            catch_exc = False
            links = links[:1]
            assert links[0].exitcase is None   # the non-exception-catching case
            self.specblock.exitswitch = None

        newlinks = []
        for link in links:
            is_catch_link = catch_exc and link.exitcase is not None
            if is_catch_link:
                # give the exception type/value variables of the link their
                # own run-time nodes and fresh variables
                extravars = []
                for attr in ['last_exception', 'last_exc_value']:
                    v = getattr(link, attr)
                    if isinstance(v, Variable):
                        rtnode = RuntimeSpecNode(v, v.concretetype)
                        self.setnode(v, rtnode)
                        self.renamings[rtnode] = v = rtnode.newvar()
                    extravars.append(v)

            linkargsnodes = [self.getnode(v1) for v1 in link.args]
            #
            newlink = self.graphbuilder.create_outgoing_link(
                self.virtualframe, link.target, linkargsnodes,
                self.renamings, self.v_expand_malloc)
            #
            if self.specblock.exitswitch is not None:
                newlink.exitcase = link.exitcase
                if hasattr(link, 'llexitcase'):
                    newlink.llexitcase = link.llexitcase
            if is_catch_link:
                newlink.extravars(*extravars)
            newlinks.append(newlink)

        self.specblock.closeblock(*newlinks)
  671. def make_rt_result(self, v_result):
  672. newrtnode = RuntimeSpecNode(v_result, v_result.concretetype)
  673. self.setnode(v_result, newrtnode)
  674. v_new = newrtnode.newvar()
  675. self.renamings[newrtnode] = v_new
  676. return v_new
  677. def make_const_rt_result(self, v_result, value):
  678. newrtnode = RuntimeSpecNode(v_result, v_result.concretetype)
  679. self.setnode(v_result, newrtnode)
  680. if v_result.concretetype is not lltype.Void:
  681. assert v_result.concretetype == lltype.typeOf(value)
  682. c_value = Constant(value)
  683. c_value.concretetype = v_result.concretetype
  684. self.renamings[newrtnode] = c_value
  685. def handle_default(self, op):
  686. newargs = [self.rename_nonvirtual(v, op) for v in op.args]
  687. constresult = try_fold_operation(op.opname, newargs,
  688. op.result.concretetype)
  689. if constresult:
  690. self.make_const_rt_result(op.result, constresult[0])
  691. return []
  692. else:
  693. newresult = self.make_rt_result(op.result)
  694. return [SpaceOperation(op.opname, newargs, newresult)]
  695. def handle_unreachable(self, op):
  696. from rpython.rtyper.lltypesystem.rstr import string_repr
  697. msg = 'unreachable: %s' % (op,)
  698. ll_msg = string_repr.convert_const(msg)
  699. c_msg = Constant(ll_msg, lltype.typeOf(ll_msg))
  700. newresult = self.make_rt_result(op.result)
  701. return [SpaceOperation('debug_fatalerror', [c_msg], newresult)]
  702. def handle_op_getfield(self, op):
  703. node = self.getnode(op.args[0])
  704. if isinstance(node, VirtualSpecNode):
  705. fieldname = op.args[1].value
  706. index = node.typedesc.name2index[fieldname]
  707. self.setnode(op.result, node.fields[index])
  708. return []
  709. else:
  710. return self.handle_default(op)
  711. def handle_op_setfield(self, op):
  712. node = self.getnode(op.args[0])
  713. if isinstance(node, VirtualSpecNode):
  714. if node.readonly:
  715. raise ForcedInline(op)
  716. fieldname = op.args[1].value
  717. index = node.typedesc.name2index[fieldname]
  718. node.fields[index] = self.getnode(op.args[2])
  719. return []
  720. else:
  721. return self.handle_default(op)
  722. def handle_op_same_as(self, op):
  723. node = self.getnode(op.args[0])
  724. if isinstance(node, VirtualSpecNode):
  725. node = self.getnode(op.args[0])
  726. self.setnode(op.result, node)
  727. return []
  728. else:
  729. return self.handle_default(op)
  730. def handle_op_cast_pointer(self, op):
  731. node = self.getnode(op.args[0])
  732. if isinstance(node, VirtualSpecNode):
  733. node = self.getnode(op.args[0])
  734. SOURCEPTR = lltype.Ptr(node.typedesc.MALLOCTYPE)
  735. TARGETPTR = op.result.concretetype
  736. try:
  737. if lltype.castable(TARGETPTR, SOURCEPTR) < 0:
  738. raise lltype.InvalidCast
  739. except lltype.InvalidCast:
  740. return self.handle_unreachable(op)
  741. self.setnode(op.result, node)
  742. return []
  743. else:
  744. return self.handle_default(op)
  745. def handle_op_ptr_nonzero(self, op):
  746. node = self.getnode(op.args[0])
  747. if isinstance(node, VirtualSpecNode):
  748. self.make_const_rt_result(op.result, True)
  749. return []
  750. else:
  751. return self.handle_default(op)
  752. def handle_op_ptr_iszero(self, op):
  753. node = self.getnode(op.args[0])
  754. if isinstance(node, VirtualSpecNode):
  755. self.make_const_rt_result(op.result, False)
  756. return []
  757. else:
  758. return self.handle_default(op)
  759. def handle_op_ptr_eq(self, op):
  760. node0 = self.getnode(op.args[0])
  761. node1 = self.getnode(op.args[1])
  762. if (isinstance(node0, VirtualSpecNode) or
  763. isinstance(node1, VirtualSpecNode)):
  764. self.make_const_rt_result(op.result, node0 is node1)
  765. return []
  766. else:
  767. return self.handle_default(op)
  768. def handle_op_ptr_ne(self, op):
  769. node0 = self.getnode(op.args[0])
  770. node1 = self.getnode(op.args[1])
  771. if (isinstance(node0, VirtualSpecNode) or
  772. isinstance(node1, VirtualSpecNode)):
  773. self.make_const_rt_result(op.result, node0 is not node1)
  774. return []
  775. else:
  776. return self.handle_default(op)
  777. def handle_op_malloc(self, op):
  778. if op.result is self.v_expand_malloc:
  779. MALLOCTYPE = op.result.concretetype.TO
  780. typedesc = self.graphbuilder.mallocv.getmalloctypedesc(MALLOCTYPE)
  781. virtualnode = VirtualSpecNode(typedesc, [])
  782. self.setnode(op.result, virtualnode)
  783. for name, FIELDTYPE in typedesc.names_and_types:
  784. fieldnode = RuntimeSpecNode(name, FIELDTYPE)
  785. virtualnode.fields.append(fieldnode)
  786. c = Constant(FIELDTYPE._defl())
  787. c.concretetype = FIELDTYPE
  788. self.renamings[fieldnode] = c
  789. self.v_expand_malloc = None # done
  790. return []
  791. else:
  792. return self.handle_default(op)
  793. def handle_op_direct_call(self, op):
  794. graph = graph_called_by(op)
  795. if graph is None:
  796. return self.handle_default(op)
  797. nb_args = len(op.args) - 1
  798. assert nb_args == len(graph.getargs())
  799. newnodes = [self.getnode(v) for v in op.args[1:]]
  800. myframe = self.get_updated_frame(op)
  801. mallocv = self.graphbuilder.mallocv
  802. if op.result is self.v_expand_malloc:
  803. # move to inlining the callee, and continue looking for the
  804. # malloc to expand in the callee's graph
  805. op_to_remove = mallocv.inline_and_remove[graph]
  806. self.v_expand_malloc = op_to_remove.result
  807. return self.handle_inlined_call(myframe, graph, newnodes)
  808. argnodes = copynodes(newnodes, flagreadonly=myframe.find_vt_nodes())
  809. kind, newgraph = mallocv.get_specialized_graph(graph, argnodes)
  810. if kind == 'trivial':
  811. return self.handle_default(op)
  812. elif kind == 'inline':
  813. return self.handle_inlined_call(myframe, graph, newnodes)
  814. elif kind == 'call':
  815. return self.handle_residual_call(op, newgraph, newnodes)
  816. elif kind == 'fail':
  817. raise CannotVirtualize(op)
  818. else:
  819. raise ValueError(kind)
  820. def get_updated_frame(self, op):
  821. sourceblock = self.virtualframe.sourceblock
  822. nextopindex = self.virtualframe.nextopindex
  823. self.nodes[op.result] = FutureReturnValue(op)
  824. myframe = VirtualFrame(sourceblock, nextopindex, self.nodes,
  825. self.virtualframe.callerframe,
  826. self.virtualframe.calledgraphs)
  827. del self.nodes[op.result]
  828. return myframe
  829. def handle_residual_call(self, op, newgraph, newnodes):
  830. fspecptr = getfunctionptr(newgraph)
  831. newargs = [Constant(fspecptr,
  832. concretetype=lltype.typeOf(fspecptr))]
  833. newargs += self.expand_nodes(newnodes)
  834. newresult = self.make_rt_result(op.result)
  835. newop = SpaceOperation('direct_call', newargs, newresult)
  836. return [newop]
  837. def handle_inlined_call(self, myframe, graph, newnodes):
  838. assert len(graph.getargs()) == len(newnodes)
  839. targetnodes = dict(zip(graph.getargs(), newnodes))
  840. calledgraphs = myframe.calledgraphs.copy()
  841. if graph in calledgraphs:
  842. raise CannotVirtualize("recursion during inlining")
  843. calledgraphs[graph] = True
  844. calleeframe = VirtualFrame(graph.startblock, 0,
  845. targetnodes, myframe, calledgraphs)
  846. self.virtualframe = calleeframe
  847. self.nodes = calleeframe.get_nodes_in_use()
  848. return []
  849. def handle_op_indirect_call(self, op):
  850. v_func = self.rename_nonvirtual(op.args[0], op)
  851. if isinstance(v_func, Constant):
  852. op = SpaceOperation('direct_call', [v_func] + op.args[1:-1],
  853. op.result)
  854. return self.handle_op_direct_call(op)
  855. else:
  856. return self.handle_default(op)
  857. class FutureReturnValue(object):
  858. def __init__(self, op):
  859. self.op = op # for debugging
  860. def getfrozenkey(self, memo):
  861. return None
  862. def accumulate_nodes(self, rtnodes, vtnodes):
  863. pass
  864. def copy(self, memo, flagreadonly):
  865. return self
  866. # ____________________________________________________________
  867. # helpers
  868. def vars_alive_through_op(block, index):
  869. # NB. make sure this always returns the variables in the same order
  870. if len(block.exits) == 0:
  871. return block.inputargs # return or except block
  872. result = []
  873. seen = {}
  874. def see(v):
  875. if isinstance(v, Variable) and v not in seen:
  876. result.append(v)
  877. seen[v] = True
  878. # don't include the variables produced by the current or future operations
  879. for op in block.operations[index:]:
  880. seen[op.result] = True
  881. # don't include the extra vars produced by exception-catching links
  882. for link in block.exits:
  883. for v in link.getextravars():
  884. seen[v] = True
  885. # but include the variables consumed by the current or any future operation
  886. for op in block.operations[index:]:
  887. for v in op.args:
  888. see(v)
  889. see(block.exitswitch)
  890. for link in block.exits:
  891. for v in link.args:
  892. see(v)
  893. return result
  894. def is_return(block):
  895. return len(block.exits) == 0 and len(block.inputargs) == 1
  896. def is_except(block):
  897. return len(block.exits) == 0 and len(block.inputargs) == 2
  898. class CannotConstFold(Exception):
  899. pass
  900. def try_fold_operation(opname, args_v, RESTYPE):
  901. args = []
  902. for c in args_v:
  903. if not isinstance(c, Constant):
  904. return
  905. args.append(c.value)
  906. try:
  907. op = getattr(llop, opname)
  908. except AttributeError:
  909. return
  910. if not op.is_pure(args_v):
  911. return
  912. try:
  913. result = op(RESTYPE, *args)
  914. except TypeError:
  915. pass
  916. except (KeyboardInterrupt, SystemExit):
  917. raise
  918. except Exception as e:
  919. pass
  920. #log.WARNING('constant-folding %s%r:' % (opname, args_v))
  921. #log.WARNING(' %s: %s' % (e.__class__.__name__, e))
  922. else:
  923. return (result,)
  924. def getconstnode(v, renamings):
  925. rtnode = RuntimeSpecNode(None, v.concretetype)
  926. renamings[rtnode] = v
  927. return rtnode
  928. def graph_called_by(op):
  929. assert op.opname == 'direct_call'
  930. fobj = op.args[0].value._obj
  931. graph = getattr(fobj, 'graph', None)
  932. return graph