/rpython/translator/backendopt/mallocv.py

https://bitbucket.org/pypy/pypy/ · Python · 1055 lines · 881 code · 123 blank · 51 comment · 205 complexity · d84b99cdc8fe0352e3d63929b8bd1f8f MD5 · raw file

  1. from rpython.flowspace.model import Variable, Constant, Block, Link
  2. from rpython.flowspace.model import SpaceOperation, copygraph
  3. from rpython.flowspace.model import checkgraph
  4. from rpython.translator.backendopt.support import log
  5. from rpython.translator.simplify import join_blocks
  6. from rpython.translator.unsimplify import varoftype
  7. from rpython.rtyper.lltypesystem.lltype import getfunctionptr
  8. from rpython.rtyper.lltypesystem import lltype
  9. from rpython.rtyper.lltypesystem.lloperation import llop
  10. def virtualize_mallocs(translator, graphs, verbose=False):
  11. newgraphs = graphs[:]
  12. mallocv = MallocVirtualizer(newgraphs, translator.rtyper, verbose)
  13. while mallocv.remove_mallocs_once():
  14. pass
  15. for graph in newgraphs:
  16. checkgraph(graph)
  17. join_blocks(graph)
  18. assert newgraphs[:len(graphs)] == graphs
  19. del newgraphs[:len(graphs)]
  20. translator.graphs.extend(newgraphs)
  21. # ____________________________________________________________
  22. class MallocTypeDesc(object):
  23. def __init__(self, MALLOCTYPE):
  24. if not isinstance(MALLOCTYPE, lltype.GcStruct):
  25. raise CannotRemoveThisType
  26. self.MALLOCTYPE = MALLOCTYPE
  27. self.check_no_destructor()
  28. self.names_and_types = []
  29. self.name2index = {}
  30. self.name2subtype = {}
  31. self.initialize_type(MALLOCTYPE)
  32. #self.immutable_struct = MALLOCTYPE._hints.get('immutable')
  33. def check_no_destructor(self):
  34. STRUCT = self.MALLOCTYPE
  35. try:
  36. rttiptr = lltype.getRuntimeTypeInfo(STRUCT)
  37. except ValueError:
  38. return # ok
  39. destr_ptr = getattr(rttiptr._obj, 'destructor_funcptr', None)
  40. if destr_ptr:
  41. raise CannotRemoveThisType
  42. def initialize_type(self, TYPE):
  43. fieldnames = TYPE._names
  44. firstname, FIRSTTYPE = TYPE._first_struct()
  45. if FIRSTTYPE is not None:
  46. self.initialize_type(FIRSTTYPE)
  47. fieldnames = fieldnames[1:]
  48. for name in fieldnames:
  49. FIELDTYPE = TYPE._flds[name]
  50. if isinstance(FIELDTYPE, lltype.ContainerType):
  51. raise CannotRemoveThisType("inlined substructure")
  52. self.name2index[name] = len(self.names_and_types)
  53. self.names_and_types.append((name, FIELDTYPE))
  54. self.name2subtype[name] = TYPE
  55. class SpecNode(object):
  56. pass
  57. class RuntimeSpecNode(SpecNode):
  58. def __init__(self, name, TYPE):
  59. self.name = name
  60. self.TYPE = TYPE
  61. def newvar(self):
  62. v = Variable(self.name)
  63. v.concretetype = self.TYPE
  64. return v
  65. def getfrozenkey(self, memo):
  66. return 'R'
  67. def accumulate_nodes(self, rtnodes, vtnodes):
  68. rtnodes.append(self)
  69. def copy(self, memo, flagreadonly):
  70. return RuntimeSpecNode(self.name, self.TYPE)
  71. def bind_rt_nodes(self, memo, newnodes_iter):
  72. return newnodes_iter.next()
  73. class VirtualSpecNode(SpecNode):
  74. def __init__(self, typedesc, fields, readonly=False):
  75. self.typedesc = typedesc
  76. self.fields = fields # list of SpecNodes
  77. self.readonly = readonly
  78. def getfrozenkey(self, memo):
  79. if self in memo:
  80. return memo[self]
  81. else:
  82. memo[self] = len(memo)
  83. result = [self.typedesc, self.readonly]
  84. for subnode in self.fields:
  85. result.append(subnode.getfrozenkey(memo))
  86. return tuple(result)
  87. def accumulate_nodes(self, rtnodes, vtnodes):
  88. if self in vtnodes:
  89. return
  90. vtnodes[self] = True
  91. for subnode in self.fields:
  92. subnode.accumulate_nodes(rtnodes, vtnodes)
  93. def copy(self, memo, flagreadonly):
  94. if self in memo:
  95. return memo[self]
  96. readonly = self.readonly or self in flagreadonly
  97. newnode = VirtualSpecNode(self.typedesc, [], readonly)
  98. memo[self] = newnode
  99. for subnode in self.fields:
  100. newnode.fields.append(subnode.copy(memo, flagreadonly))
  101. return newnode
  102. def bind_rt_nodes(self, memo, newnodes_iter):
  103. if self in memo:
  104. return memo[self]
  105. newnode = VirtualSpecNode(self.typedesc, [], self.readonly)
  106. memo[self] = newnode
  107. for subnode in self.fields:
  108. newnode.fields.append(subnode.bind_rt_nodes(memo, newnodes_iter))
  109. return newnode
  110. class VirtualFrame(object):
  111. def __init__(self, sourceblock, nextopindex,
  112. allnodes, callerframe=None, calledgraphs={}):
  113. if isinstance(allnodes, dict):
  114. self.varlist = vars_alive_through_op(sourceblock, nextopindex)
  115. self.nodelist = [allnodes[v] for v in self.varlist]
  116. else:
  117. assert nextopindex == 0
  118. self.varlist = sourceblock.inputargs
  119. self.nodelist = allnodes[:]
  120. self.sourceblock = sourceblock
  121. self.nextopindex = nextopindex
  122. self.callerframe = callerframe
  123. self.calledgraphs = calledgraphs
  124. def get_nodes_in_use(self):
  125. return dict(zip(self.varlist, self.nodelist))
  126. def shallowcopy(self):
  127. newframe = VirtualFrame.__new__(VirtualFrame)
  128. newframe.varlist = self.varlist
  129. newframe.nodelist = self.nodelist
  130. newframe.sourceblock = self.sourceblock
  131. newframe.nextopindex = self.nextopindex
  132. newframe.callerframe = self.callerframe
  133. newframe.calledgraphs = self.calledgraphs
  134. return newframe
  135. def copy(self, memo, flagreadonly={}):
  136. newframe = self.shallowcopy()
  137. newframe.nodelist = [node.copy(memo, flagreadonly)
  138. for node in newframe.nodelist]
  139. if newframe.callerframe is not None:
  140. newframe.callerframe = newframe.callerframe.copy(memo,
  141. flagreadonly)
  142. return newframe
  143. def enum_call_stack(self):
  144. frame = self
  145. while frame is not None:
  146. yield frame
  147. frame = frame.callerframe
  148. def getfrozenkey(self):
  149. memo = {}
  150. key = []
  151. for frame in self.enum_call_stack():
  152. key.append(frame.sourceblock)
  153. key.append(frame.nextopindex)
  154. for node in frame.nodelist:
  155. key.append(node.getfrozenkey(memo))
  156. return tuple(key)
  157. def find_all_nodes(self):
  158. rtnodes = []
  159. vtnodes = {}
  160. for frame in self.enum_call_stack():
  161. for node in frame.nodelist:
  162. node.accumulate_nodes(rtnodes, vtnodes)
  163. return rtnodes, vtnodes
  164. def find_rt_nodes(self):
  165. rtnodes, vtnodes = self.find_all_nodes()
  166. return rtnodes
  167. def find_vt_nodes(self):
  168. rtnodes, vtnodes = self.find_all_nodes()
  169. return vtnodes
  170. def copynodes(nodelist, flagreadonly={}):
  171. memo = {}
  172. return [node.copy(memo, flagreadonly) for node in nodelist]
  173. def find_all_nodes(nodelist):
  174. rtnodes = []
  175. vtnodes = {}
  176. for node in nodelist:
  177. node.accumulate_nodes(rtnodes, vtnodes)
  178. return rtnodes, vtnodes
  179. def is_trivial_nodelist(nodelist):
  180. for node in nodelist:
  181. if not isinstance(node, RuntimeSpecNode):
  182. return False
  183. return True
  184. def bind_rt_nodes(srcnodelist, newnodes_list):
  185. """Return srcnodelist with all RuntimeNodes replaced by nodes
  186. coming from newnodes_list.
  187. """
  188. memo = {}
  189. newnodes_iter = iter(newnodes_list)
  190. result = [node.bind_rt_nodes(memo, newnodes_iter) for node in srcnodelist]
  191. rest = list(newnodes_iter)
  192. assert rest == [], "too many nodes in newnodes_list"
  193. return result
  194. class CannotVirtualize(Exception):
  195. pass
  196. class ForcedInline(Exception):
  197. pass
  198. class CannotRemoveThisType(Exception):
  199. pass
  200. # ____________________________________________________________
  201. class MallocVirtualizer(object):
  202. def __init__(self, graphs, rtyper, verbose=False):
  203. self.graphs = graphs
  204. self.rtyper = rtyper
  205. self.excdata = rtyper.exceptiondata
  206. self.graphbuilders = {}
  207. self.specialized_graphs = {}
  208. self.specgraphorigin = {}
  209. self.inline_and_remove = {} # {graph: op_to_remove}
  210. self.inline_and_remove_seen = {} # set of (graph, op_to_remove)
  211. self.malloctypedescs = {}
  212. self.count_virtualized = 0
  213. self.verbose = verbose
  214. self.EXCTYPE_to_vtable = self.build_obscure_mapping()
  215. def build_obscure_mapping(self):
  216. result = {}
  217. for rinstance in self.rtyper.instance_reprs.values():
  218. result[rinstance.lowleveltype.TO] = rinstance.rclass.getvtable()
  219. return result
  220. def report_result(self, progress):
  221. if progress:
  222. log.mallocv('removed %d mallocs so far' % self.count_virtualized)
  223. else:
  224. log.mallocv('done')
  225. def enum_all_mallocs(self, graph):
  226. for block in graph.iterblocks():
  227. for op in block.operations:
  228. if op.opname == 'malloc':
  229. MALLOCTYPE = op.result.concretetype.TO
  230. try:
  231. self.getmalloctypedesc(MALLOCTYPE)
  232. except CannotRemoveThisType:
  233. pass
  234. else:
  235. yield (block, op)
  236. elif op.opname == 'direct_call':
  237. graph = graph_called_by(op)
  238. if graph in self.inline_and_remove:
  239. yield (block, op)
  240. def remove_mallocs_once(self):
  241. self.flush_failed_specializations()
  242. prev = self.count_virtualized
  243. count_inline_and_remove = len(self.inline_and_remove)
  244. for graph in self.graphs:
  245. seen = {}
  246. while True:
  247. for block, op in self.enum_all_mallocs(graph):
  248. if op.result not in seen:
  249. seen[op.result] = True
  250. if self.try_remove_malloc(graph, block, op):
  251. break # graph mutated, restart enum_all_mallocs()
  252. else:
  253. break # enum_all_mallocs() exhausted, graph finished
  254. progress1 = self.count_virtualized - prev
  255. progress2 = len(self.inline_and_remove) - count_inline_and_remove
  256. progress = progress1 or bool(progress2)
  257. self.report_result(progress)
  258. return progress
  259. def flush_failed_specializations(self):
  260. for key, (mode, specgraph) in self.specialized_graphs.items():
  261. if mode == 'fail':
  262. del self.specialized_graphs[key]
  263. def fixup_except_block(self, exceptblock):
  264. # hack: this block's inputargs may be missing concretetypes...
  265. e1, v1 = exceptblock.inputargs
  266. e1.concretetype = self.excdata.lltype_of_exception_type
  267. v1.concretetype = self.excdata.lltype_of_exception_value
  268. def getmalloctypedesc(self, MALLOCTYPE):
  269. try:
  270. dsc = self.malloctypedescs[MALLOCTYPE]
  271. except KeyError:
  272. dsc = self.malloctypedescs[MALLOCTYPE] = MallocTypeDesc(MALLOCTYPE)
  273. return dsc
  274. def try_remove_malloc(self, graph, block, op):
  275. if (graph, op) in self.inline_and_remove_seen:
  276. return False # no point in trying again
  277. graphbuilder = GraphBuilder(self, graph)
  278. if graph in self.graphbuilders:
  279. graphbuilder.initialize_from_old_builder(self.graphbuilders[graph])
  280. graphbuilder.start_from_a_malloc(graph, block, op.result)
  281. try:
  282. graphbuilder.propagate_specializations()
  283. except CannotVirtualize as e:
  284. self.logresult(op, 'failed', e)
  285. return False
  286. except ForcedInline as e:
  287. self.logresult(op, 'forces inlining', e)
  288. self.inline_and_remove[graph] = op
  289. self.inline_and_remove_seen[graph, op] = True
  290. return False
  291. else:
  292. self.logresult(op, 'removed')
  293. graphbuilder.finished_removing_malloc()
  294. self.graphbuilders[graph] = graphbuilder
  295. self.count_virtualized += 1
  296. return True
  297. def logresult(self, op, msg, exc=None): # only for nice log outputs
  298. if self.verbose:
  299. if exc is None:
  300. exc = ''
  301. else:
  302. exc = ': %s' % (exc,)
  303. chain = []
  304. while True:
  305. chain.append(str(op.result))
  306. if op.opname != 'direct_call':
  307. break
  308. fobj = op.args[0].value._obj
  309. op = self.inline_and_remove[fobj.graph]
  310. log.mallocv('%s %s%s' % ('->'.join(chain), msg, exc))
  311. elif exc is None:
  312. log.dot()
  313. def get_specialized_graph(self, graph, nodelist):
  314. assert len(graph.getargs()) == len(nodelist)
  315. if is_trivial_nodelist(nodelist):
  316. return 'trivial', graph
  317. if graph in self.specgraphorigin:
  318. orggraph, orgnodelist = self.specgraphorigin[graph]
  319. nodelist = bind_rt_nodes(orgnodelist, nodelist)
  320. graph = orggraph
  321. virtualframe = VirtualFrame(graph.startblock, 0, nodelist)
  322. key = virtualframe.getfrozenkey()
  323. try:
  324. return self.specialized_graphs[key]
  325. except KeyError:
  326. self.build_specialized_graph(graph, key, nodelist)
  327. return self.specialized_graphs[key]
  328. def build_specialized_graph(self, graph, key, nodelist):
  329. graph2 = copygraph(graph)
  330. virtualframe = VirtualFrame(graph2.startblock, 0, nodelist)
  331. graphbuilder = GraphBuilder(self, graph2)
  332. specblock = graphbuilder.start_from_virtualframe(virtualframe)
  333. specgraph = graph2
  334. specgraph.name += '_mallocv'
  335. specgraph.startblock = specblock
  336. self.specialized_graphs[key] = ('call', specgraph)
  337. try:
  338. graphbuilder.propagate_specializations()
  339. except ForcedInline as e:
  340. if self.verbose:
  341. log.mallocv('%s inlined: %s' % (graph.name, e))
  342. self.specialized_graphs[key] = ('inline', None)
  343. except CannotVirtualize as e:
  344. if self.verbose:
  345. log.mallocv('%s failing: %s' % (graph.name, e))
  346. self.specialized_graphs[key] = ('fail', None)
  347. else:
  348. self.graphbuilders[specgraph] = graphbuilder
  349. self.specgraphorigin[specgraph] = graph, nodelist
  350. self.graphs.append(specgraph)
  351. class GraphBuilder(object):
  352. def __init__(self, mallocv, graph):
  353. self.mallocv = mallocv
  354. self.graph = graph
  355. self.specialized_blocks = {}
  356. self.pending_specializations = []
  357. def initialize_from_old_builder(self, oldbuilder):
  358. self.specialized_blocks.update(oldbuilder.specialized_blocks)
  359. def start_from_virtualframe(self, startframe):
  360. spec = BlockSpecializer(self)
  361. spec.initialize_renamings(startframe)
  362. self.pending_specializations.append(spec)
  363. return spec.specblock
  364. def start_from_a_malloc(self, graph, block, v_result):
  365. assert v_result in [op.result for op in block.operations]
  366. nodelist = []
  367. for v in block.inputargs:
  368. nodelist.append(RuntimeSpecNode(v, v.concretetype))
  369. trivialframe = VirtualFrame(block, 0, nodelist)
  370. spec = BlockSpecializer(self, v_result)
  371. spec.initialize_renamings(trivialframe, keep_inputargs=True)
  372. self.pending_specializations.append(spec)
  373. self.pending_patch = (block, spec.specblock)
  374. def finished_removing_malloc(self):
  375. (srcblock, specblock) = self.pending_patch
  376. srcblock.inputargs = specblock.inputargs
  377. srcblock.operations = specblock.operations
  378. srcblock.exitswitch = specblock.exitswitch
  379. srcblock.recloseblock(*specblock.exits)
  380. def create_outgoing_link(self, currentframe, targetblock,
  381. nodelist, renamings, v_expand_malloc=None):
  382. assert len(nodelist) == len(targetblock.inputargs)
  383. #
  384. if is_except(targetblock):
  385. v_expand_malloc = None
  386. while currentframe.callerframe is not None:
  387. currentframe = currentframe.callerframe
  388. newlink = self.handle_catch(currentframe, nodelist, renamings)
  389. if newlink:
  390. return newlink
  391. else:
  392. targetblock = self.exception_escapes(nodelist, renamings)
  393. assert len(nodelist) == len(targetblock.inputargs)
  394. if (currentframe.callerframe is None and
  395. is_trivial_nodelist(nodelist)):
  396. # there is no more VirtualSpecNodes being passed around,
  397. # so we can stop specializing
  398. rtnodes = nodelist
  399. specblock = targetblock
  400. else:
  401. if is_return(targetblock):
  402. v_expand_malloc = None
  403. newframe = self.return_to_caller(currentframe, nodelist[0])
  404. else:
  405. targetnodes = dict(zip(targetblock.inputargs, nodelist))
  406. newframe = VirtualFrame(targetblock, 0, targetnodes,
  407. callerframe=currentframe.callerframe,
  408. calledgraphs=currentframe.calledgraphs)
  409. rtnodes = newframe.find_rt_nodes()
  410. specblock = self.get_specialized_block(newframe, v_expand_malloc)
  411. linkargs = [renamings[rtnode] for rtnode in rtnodes]
  412. return Link(linkargs, specblock)
  413. def return_to_caller(self, currentframe, retnode):
  414. callerframe = currentframe.callerframe
  415. if callerframe is None:
  416. raise ForcedInline("return block")
  417. nodelist = callerframe.nodelist
  418. callerframe = callerframe.shallowcopy()
  419. callerframe.nodelist = []
  420. for node in nodelist:
  421. if isinstance(node, FutureReturnValue):
  422. node = retnode
  423. callerframe.nodelist.append(node)
  424. return callerframe
  425. def handle_catch(self, catchingframe, nodelist, renamings):
  426. if not self.has_exception_catching(catchingframe):
  427. return None
  428. [exc_node, exc_value_node] = nodelist
  429. v_exc_type = renamings.get(exc_node)
  430. if isinstance(v_exc_type, Constant):
  431. exc_type = v_exc_type.value
  432. elif isinstance(exc_value_node, VirtualSpecNode):
  433. EXCTYPE = exc_value_node.typedesc.MALLOCTYPE
  434. exc_type = self.mallocv.EXCTYPE_to_vtable[EXCTYPE]
  435. else:
  436. raise CannotVirtualize("raising non-constant exc type")
  437. excdata = self.mallocv.excdata
  438. assert catchingframe.sourceblock.exits[0].exitcase is None
  439. for catchlink in catchingframe.sourceblock.exits[1:]:
  440. if excdata.fn_exception_match(exc_type, catchlink.llexitcase):
  441. # Match found. Follow this link.
  442. mynodes = catchingframe.get_nodes_in_use()
  443. for node, attr in zip(nodelist,
  444. ['last_exception', 'last_exc_value']):
  445. v = getattr(catchlink, attr)
  446. if isinstance(v, Variable):
  447. mynodes[v] = node
  448. #
  449. nodelist = []
  450. for v in catchlink.args:
  451. if isinstance(v, Variable):
  452. node = mynodes[v]
  453. else:
  454. node = getconstnode(v, renamings)
  455. nodelist.append(node)
  456. return self.create_outgoing_link(catchingframe,
  457. catchlink.target,
  458. nodelist, renamings)
  459. else:
  460. # No match at all, propagate the exception to the caller
  461. return None
  462. def has_exception_catching(self, catchingframe):
  463. if not catchingframe.sourceblock.canraise:
  464. return False
  465. else:
  466. operations = catchingframe.sourceblock.operations
  467. assert 1 <= catchingframe.nextopindex <= len(operations)
  468. return catchingframe.nextopindex == len(operations)
  469. def exception_escapes(self, nodelist, renamings):
  470. # the exception escapes
  471. if not is_trivial_nodelist(nodelist):
  472. # start of hacks to help handle_catch()
  473. [exc_node, exc_value_node] = nodelist
  474. v_exc_type = renamings.get(exc_node)
  475. if isinstance(v_exc_type, Constant):
  476. # cannot improve: handle_catch() would already be happy
  477. # by seeing the exc_type as a constant
  478. pass
  479. elif isinstance(exc_value_node, VirtualSpecNode):
  480. # can improve with a strange hack: we pretend that
  481. # the source code jumps to a block that itself allocates
  482. # the exception, sets all fields, and raises it by
  483. # passing a constant type.
  484. typedesc = exc_value_node.typedesc
  485. return self.get_exc_reconstruction_block(typedesc)
  486. else:
  487. # cannot improve: handle_catch() will have no clue about
  488. # the exception type
  489. pass
  490. raise CannotVirtualize("except block")
  491. targetblock = self.graph.exceptblock
  492. self.mallocv.fixup_except_block(targetblock)
  493. return targetblock
  494. def get_exc_reconstruction_block(self, typedesc):
  495. exceptblock = self.graph.exceptblock
  496. self.mallocv.fixup_except_block(exceptblock)
  497. TEXC = exceptblock.inputargs[0].concretetype
  498. TVAL = exceptblock.inputargs[1].concretetype
  499. #
  500. v_ignored_type = varoftype(TEXC)
  501. v_incoming_value = varoftype(TVAL)
  502. block = Block([v_ignored_type, v_incoming_value])
  503. #
  504. c_EXCTYPE = Constant(typedesc.MALLOCTYPE, lltype.Void)
  505. v = varoftype(lltype.Ptr(typedesc.MALLOCTYPE))
  506. c_flavor = Constant({'flavor': 'gc'}, lltype.Void)
  507. op = SpaceOperation('malloc', [c_EXCTYPE, c_flavor], v)
  508. block.operations.append(op)
  509. #
  510. for name, FIELDTYPE in typedesc.names_and_types:
  511. EXACTPTR = lltype.Ptr(typedesc.name2subtype[name])
  512. c_name = Constant(name)
  513. c_name.concretetype = lltype.Void
  514. #
  515. v_in = varoftype(EXACTPTR)
  516. op = SpaceOperation('cast_pointer', [v_incoming_value], v_in)
  517. block.operations.append(op)
  518. #
  519. v_field = varoftype(FIELDTYPE)
  520. op = SpaceOperation('getfield', [v_in, c_name], v_field)
  521. block.operations.append(op)
  522. #
  523. v_out = varoftype(EXACTPTR)
  524. op = SpaceOperation('cast_pointer', [v], v_out)
  525. block.operations.append(op)
  526. #
  527. v0 = varoftype(lltype.Void)
  528. op = SpaceOperation('setfield', [v_out, c_name, v_field], v0)
  529. block.operations.append(op)
  530. #
  531. v_exc_value = varoftype(TVAL)
  532. op = SpaceOperation('cast_pointer', [v], v_exc_value)
  533. block.operations.append(op)
  534. #
  535. exc_type = self.mallocv.EXCTYPE_to_vtable[typedesc.MALLOCTYPE]
  536. c_exc_type = Constant(exc_type, TEXC)
  537. block.closeblock(Link([c_exc_type, v_exc_value], exceptblock))
  538. return block
  539. def get_specialized_block(self, virtualframe, v_expand_malloc=None):
  540. key = virtualframe.getfrozenkey()
  541. specblock = self.specialized_blocks.get(key)
  542. if specblock is None:
  543. orgblock = virtualframe.sourceblock
  544. assert len(orgblock.exits) != 0
  545. spec = BlockSpecializer(self, v_expand_malloc)
  546. spec.initialize_renamings(virtualframe)
  547. self.pending_specializations.append(spec)
  548. specblock = spec.specblock
  549. self.specialized_blocks[key] = specblock
  550. return specblock
  551. def propagate_specializations(self):
  552. while self.pending_specializations:
  553. spec = self.pending_specializations.pop()
  554. spec.specialize_operations()
  555. spec.follow_exits()
  556. class BlockSpecializer(object):
  557. def __init__(self, graphbuilder, v_expand_malloc=None):
  558. self.graphbuilder = graphbuilder
  559. self.v_expand_malloc = v_expand_malloc
  560. self.specblock = Block([])
  561. def initialize_renamings(self, virtualframe, keep_inputargs=False):
  562. # we make a copy of the original 'virtualframe' because the
  563. # specialize_operations() will mutate some of its content.
  564. virtualframe = virtualframe.copy({})
  565. self.virtualframe = virtualframe
  566. self.nodes = virtualframe.get_nodes_in_use()
  567. self.renamings = {} # {RuntimeSpecNode(): Variable()}
  568. if keep_inputargs:
  569. assert virtualframe.varlist == virtualframe.sourceblock.inputargs
  570. specinputargs = []
  571. for i, rtnode in enumerate(virtualframe.find_rt_nodes()):
  572. if keep_inputargs:
  573. v = virtualframe.varlist[i]
  574. assert v.concretetype == rtnode.TYPE
  575. else:
  576. v = rtnode.newvar()
  577. self.renamings[rtnode] = v
  578. specinputargs.append(v)
  579. self.specblock.inputargs = specinputargs
  580. def setnode(self, v, node):
  581. assert v not in self.nodes
  582. self.nodes[v] = node
  583. def getnode(self, v):
  584. if isinstance(v, Variable):
  585. return self.nodes[v]
  586. else:
  587. return getconstnode(v, self.renamings)
  588. def rename_nonvirtual(self, v, where=None):
  589. if not isinstance(v, Variable):
  590. return v
  591. node = self.nodes[v]
  592. if not isinstance(node, RuntimeSpecNode):
  593. raise CannotVirtualize(where)
  594. return self.renamings[node]
  595. def expand_nodes(self, nodelist):
  596. rtnodes, vtnodes = find_all_nodes(nodelist)
  597. return [self.renamings[rtnode] for rtnode in rtnodes]
  598. def specialize_operations(self):
  599. newoperations = []
  600. self.ops_produced_by_last_op = 0
  601. # note that 'self.virtualframe' can be changed during the loop!
  602. while True:
  603. operations = self.virtualframe.sourceblock.operations
  604. try:
  605. op = operations[self.virtualframe.nextopindex]
  606. self.virtualframe.nextopindex += 1
  607. except IndexError:
  608. break
  609. meth = getattr(self, 'handle_op_' + op.opname,
  610. self.handle_default)
  611. newops_for_this_op = meth(op)
  612. newoperations += newops_for_this_op
  613. self.ops_produced_by_last_op = len(newops_for_this_op)
  614. for op in newoperations:
  615. if op.opname == 'direct_call':
  616. graph = graph_called_by(op)
  617. if graph in self.virtualframe.calledgraphs:
  618. raise CannotVirtualize("recursion in residual call")
  619. self.specblock.operations = newoperations
    def follow_exits(self):
        """Close the specialized block with links built from the exits
        of the source block.

        The exitswitch is renamed first (CannotVirtualize is raised by
        rename_nonvirtual() if it is virtual).  Each remaining exit is
        then turned into a new link by
        graphbuilder.create_outgoing_link().
        """
        block = self.virtualframe.sourceblock
        self.specblock.exitswitch = self.rename_nonvirtual(block.exitswitch,
                                                           'exitswitch')
        links = block.exits
        catch_exc = self.specblock.canraise

        if not catch_exc and isinstance(self.specblock.exitswitch, Constant):
            # constant-fold the switch
            for link in links:
                if link.exitcase == 'default':
                    break
                if link.llexitcase == self.specblock.exitswitch.value:
                    break
            else:
                raise Exception("exit case not found?")
            # only the selected link is kept
            links = (link,)
            self.specblock.exitswitch = None

        if catch_exc and self.ops_produced_by_last_op == 0:
            # the last op of the sourceblock did not produce any
            # operation in specblock, so we need to discard the
            # exception-catching.
            catch_exc = False
            links = links[:1]
            assert links[0].exitcase is None   # the non-exception-catching case
            self.specblock.exitswitch = None

        newlinks = []
        for link in links:
            is_catch_link = catch_exc and link.exitcase is not None
            if is_catch_link:
                # give a new runtime node and a new variable to each of
                # the link's exception variables
                extravars = []
                for attr in ['last_exception', 'last_exc_value']:
                    v = getattr(link, attr)
                    if isinstance(v, Variable):
                        rtnode = RuntimeSpecNode(v, v.concretetype)
                        self.setnode(v, rtnode)
                        self.renamings[rtnode] = v = rtnode.newvar()
                    extravars.append(v)

            linkargsnodes = [self.getnode(v1) for v1 in link.args]
            #
            newlink = self.graphbuilder.create_outgoing_link(
                self.virtualframe, link.target, linkargsnodes,
                self.renamings, self.v_expand_malloc)
            #
            if self.specblock.exitswitch is not None:
                # the switch is still there: copy the exit cases over
                newlink.exitcase = link.exitcase
                if hasattr(link, 'llexitcase'):
                    newlink.llexitcase = link.llexitcase
                if is_catch_link:
                    newlink.extravars(*extravars)
            newlinks.append(newlink)

        self.specblock.closeblock(*newlinks)
  671. def make_rt_result(self, v_result):
  672. newrtnode = RuntimeSpecNode(v_result, v_result.concretetype)
  673. self.setnode(v_result, newrtnode)
  674. v_new = newrtnode.newvar()
  675. self.renamings[newrtnode] = v_new
  676. return v_new
  677. def make_const_rt_result(self, v_result, value):
  678. newrtnode = RuntimeSpecNode(v_result, v_result.concretetype)
  679. self.setnode(v_result, newrtnode)
  680. if v_result.concretetype is not lltype.Void:
  681. assert v_result.concretetype == lltype.typeOf(value)
  682. c_value = Constant(value)
  683. c_value.concretetype = v_result.concretetype
  684. self.renamings[newrtnode] = c_value
  685. def handle_default(self, op):
  686. newargs = [self.rename_nonvirtual(v, op) for v in op.args]
  687. constresult = try_fold_operation(op.opname, newargs,
  688. op.result.concretetype)
  689. if constresult:
  690. self.make_const_rt_result(op.result, constresult[0])
  691. return []
  692. else:
  693. newresult = self.make_rt_result(op.result)
  694. return [SpaceOperation(op.opname, newargs, newresult)]
  695. def handle_unreachable(self, op):
  696. from rpython.rtyper.lltypesystem.rstr import string_repr
  697. msg = 'unreachable: %s' % (op,)
  698. ll_msg = string_repr.convert_const(msg)
  699. c_msg = Constant(ll_msg, lltype.typeOf(ll_msg))
  700. newresult = self.make_rt_result(op.result)
  701. return [SpaceOperation('debug_fatalerror', [c_msg], newresult)]
  702. def handle_op_getfield(self, op):
  703. node = self.getnode(op.args[0])
  704. if isinstance(node, VirtualSpecNode):
  705. fieldname = op.args[1].value
  706. index = node.typedesc.name2index[fieldname]
  707. self.setnode(op.result, node.fields[index])
  708. return []
  709. else:
  710. return self.handle_default(op)
  711. def handle_op_setfield(self, op):
  712. node = self.getnode(op.args[0])
  713. if isinstance(node, VirtualSpecNode):
  714. if node.readonly:
  715. raise ForcedInline(op)
  716. fieldname = op.args[1].value
  717. index = node.typedesc.name2index[fieldname]
  718. node.fields[index] = self.getnode(op.args[2])
  719. return []
  720. else:
  721. return self.handle_default(op)
  722. def handle_op_same_as(self, op):
  723. node = self.getnode(op.args[0])
  724. if isinstance(node, VirtualSpecNode):
  725. node = self.getnode(op.args[0])
  726. self.setnode(op.result, node)
  727. return []
  728. else:
  729. return self.handle_default(op)
  730. def handle_op_cast_pointer(self, op):
  731. node = self.getnode(op.args[0])
  732. if isinstance(node, VirtualSpecNode):
  733. node = self.getnode(op.args[0])
  734. SOURCEPTR = lltype.Ptr(node.typedesc.MALLOCTYPE)
  735. TARGETPTR = op.result.concretetype
  736. try:
  737. if lltype.castable(TARGETPTR, SOURCEPTR) < 0:
  738. raise lltype.InvalidCast
  739. except lltype.InvalidCast:
  740. return self.handle_unreachable(op)
  741. self.setnode(op.result, node)
  742. return []
  743. else:
  744. return self.handle_default(op)
  745. def handle_op_ptr_nonzero(self, op):
  746. node = self.getnode(op.args[0])
  747. if isinstance(node, VirtualSpecNode):
  748. self.make_const_rt_result(op.result, True)
  749. return []
  750. else:
  751. return self.handle_default(op)
  752. def handle_op_ptr_iszero(self, op):
  753. node = self.getnode(op.args[0])
  754. if isinstance(node, VirtualSpecNode):
  755. self.make_const_rt_result(op.result, False)
  756. return []
  757. else:
  758. return self.handle_default(op)
  759. def handle_op_ptr_eq(self, op):
  760. node0 = self.getnode(op.args[0])
  761. node1 = self.getnode(op.args[1])
  762. if (isinstance(node0, VirtualSpecNode) or
  763. isinstance(node1, VirtualSpecNode)):
  764. self.make_const_rt_result(op.result, node0 is node1)
  765. return []
  766. else:
  767. return self.handle_default(op)
  768. def handle_op_ptr_ne(self, op):
  769. node0 = self.getnode(op.args[0])
  770. node1 = self.getnode(op.args[1])
  771. if (isinstance(node0, VirtualSpecNode) or
  772. isinstance(node1, VirtualSpecNode)):
  773. self.make_const_rt_result(op.result, node0 is not node1)
  774. return []
  775. else:
  776. return self.handle_default(op)
  777. def handle_op_malloc(self, op):
  778. if op.result is self.v_expand_malloc:
  779. MALLOCTYPE = op.result.concretetype.TO
  780. typedesc = self.graphbuilder.mallocv.getmalloctypedesc(MALLOCTYPE)
  781. virtualnode = VirtualSpecNode(typedesc, [])
  782. self.setnode(op.result, virtualnode)
  783. for name, FIELDTYPE in typedesc.names_and_types:
  784. fieldnode = RuntimeSpecNode(name, FIELDTYPE)
  785. virtualnode.fields.append(fieldnode)
  786. c = Constant(FIELDTYPE._defl())
  787. c.concretetype = FIELDTYPE
  788. self.renamings[fieldnode] = c
  789. self.v_expand_malloc = None # done
  790. return []
  791. else:
  792. return self.handle_default(op)
  793. def handle_op_direct_call(self, op):
  794. graph = graph_called_by(op)
  795. if graph is None:
  796. return self.handle_default(op)
  797. nb_args = len(op.args) - 1
  798. assert nb_args == len(graph.getargs())
  799. newnodes = [self.getnode(v) for v in op.args[1:]]
  800. myframe = self.get_updated_frame(op)
  801. mallocv = self.graphbuilder.mallocv
  802. if op.result is self.v_expand_malloc:
  803. # move to inlining the callee, and continue looking for the
  804. # malloc to expand in the callee's graph
  805. op_to_remove = mallocv.inline_and_remove[graph]
  806. self.v_expand_malloc = op_to_remove.result
  807. return self.handle_inlined_call(myframe, graph, newnodes)
  808. argnodes = copynodes(newnodes, flagreadonly=myframe.find_vt_nodes())
  809. kind, newgraph = mallocv.get_specialized_graph(graph, argnodes)
  810. if kind == 'trivial':
  811. return self.handle_default(op)
  812. elif kind == 'inline':
  813. return self.handle_inlined_call(myframe, graph, newnodes)
  814. elif kind == 'call':
  815. return self.handle_residual_call(op, newgraph, newnodes)
  816. elif kind == 'fail':
  817. raise CannotVirtualize(op)
  818. else:
  819. raise ValueError(kind)
  820. def get_updated_frame(self, op):
  821. sourceblock = self.virtualframe.sourceblock
  822. nextopindex = self.virtualframe.nextopindex
  823. self.nodes[op.result] = FutureReturnValue(op)
  824. myframe = VirtualFrame(sourceblock, nextopindex, self.nodes,
  825. self.virtualframe.callerframe,
  826. self.virtualframe.calledgraphs)
  827. del self.nodes[op.result]
  828. return myframe
  829. def handle_residual_call(self, op, newgraph, newnodes):
  830. fspecptr = getfunctionptr(newgraph)
  831. newargs = [Constant(fspecptr,
  832. concretetype=lltype.typeOf(fspecptr))]
  833. newargs += self.expand_nodes(newnodes)
  834. newresult = self.make_rt_result(op.result)
  835. newop = SpaceOperation('direct_call', newargs, newresult)
  836. return [newop]
  837. def handle_inlined_call(self, myframe, graph, newnodes):
  838. assert len(graph.getargs()) == len(newnodes)
  839. targetnodes = dict(zip(graph.getargs(), newnodes))
  840. calledgraphs = myframe.calledgraphs.copy()
  841. if graph in calledgraphs:
  842. raise CannotVirtualize("recursion during inlining")
  843. calledgraphs[graph] = True
  844. calleeframe = VirtualFrame(graph.startblock, 0,
  845. targetnodes, myframe, calledgraphs)
  846. self.virtualframe = calleeframe
  847. self.nodes = calleeframe.get_nodes_in_use()
  848. return []
  849. def handle_op_indirect_call(self, op):
  850. v_func = self.rename_nonvirtual(op.args[0], op)
  851. if isinstance(v_func, Constant):
  852. op = SpaceOperation('direct_call', [v_func] + op.args[1:-1],
  853. op.result)
  854. return self.handle_op_direct_call(op)
  855. else:
  856. return self.handle_default(op)
  857. class FutureReturnValue(object):
  858. def __init__(self, op):
  859. self.op = op # for debugging
  860. def getfrozenkey(self, memo):
  861. return None
  862. def accumulate_nodes(self, rtnodes, vtnodes):
  863. pass
  864. def copy(self, memo, flagreadonly):
  865. return self
  866. # ____________________________________________________________
  867. # helpers
  868. def vars_alive_through_op(block, index):
  869. # NB. make sure this always returns the variables in the same order
  870. if len(block.exits) == 0:
  871. return block.inputargs # return or except block
  872. result = []
  873. seen = {}
  874. def see(v):
  875. if isinstance(v, Variable) and v not in seen:
  876. result.append(v)
  877. seen[v] = True
  878. # don't include the variables produced by the current or future operations
  879. for op in block.operations[index:]:
  880. seen[op.result] = True
  881. # don't include the extra vars produced by exception-catching links
  882. for link in block.exits:
  883. for v in link.getextravars():
  884. seen[v] = True
  885. # but include the variables consumed by the current or any future operation
  886. for op in block.operations[index:]:
  887. for v in op.args:
  888. see(v)
  889. see(block.exitswitch)
  890. for link in block.exits:
  891. for v in link.args:
  892. see(v)
  893. return result
  894. def is_return(block):
  895. return len(block.exits) == 0 and len(block.inputargs) == 1
  896. def is_except(block):
  897. return len(block.exits) == 0 and len(block.inputargs) == 2
  898. class CannotConstFold(Exception):
  899. pass
  900. def try_fold_operation(opname, args_v, RESTYPE):
  901. args = []
  902. for c in args_v:
  903. if not isinstance(c, Constant):
  904. return
  905. args.append(c.value)
  906. try:
  907. op = getattr(llop, opname)
  908. except AttributeError:
  909. return
  910. if not op.is_pure(args_v):
  911. return
  912. try:
  913. result = op(RESTYPE, *args)
  914. except TypeError:
  915. pass
  916. except (KeyboardInterrupt, SystemExit):
  917. raise
  918. except Exception as e:
  919. pass
  920. #log.WARNING('constant-folding %s%r:' % (opname, args_v))
  921. #log.WARNING(' %s: %s' % (e.__class__.__name__, e))
  922. else:
  923. return (result,)
  924. def getconstnode(v, renamings):
  925. rtnode = RuntimeSpecNode(None, v.concretetype)
  926. renamings[rtnode] = v
  927. return rtnode
  928. def graph_called_by(op):
  929. assert op.opname == 'direct_call'
  930. fobj = op.args[0].value._obj
  931. graph = getattr(fobj, 'graph', None)
  932. return graph