PageRenderTime 57ms CodeModel.GetById 8ms RepoModel.GetById 0ms app.codeStats 0ms

/rpython/jit/metainterp/optimizeopt/schedule.py

https://bitbucket.org/pypy/pypy/
Python | 1100 lines | 856 code | 125 blank | 119 comment | 183 complexity | e3c5c7ca4b441bf71fbcd24d5558721f MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
  1. from rpython.jit.metainterp.history import (VECTOR, FLOAT, INT,
  2. ConstInt, ConstFloat, TargetToken)
  3. from rpython.jit.metainterp.resoperation import (rop, ResOperation,
  4. GuardResOp, VecOperation, OpHelpers, VecOperationNew,
  5. VectorizationInfo)
  6. from rpython.jit.metainterp.optimizeopt.dependency import (DependencyGraph,
  7. MemoryRef, Node, IndexVar)
  8. from rpython.jit.metainterp.optimizeopt.renamer import Renamer
  9. from rpython.jit.metainterp.resume import AccumInfo
  10. from rpython.rlib.objectmodel import we_are_translated
  11. from rpython.jit.metainterp.jitexc import NotAProfitableLoop
  12. from rpython.rlib.objectmodel import specialize, always_inline
  13. from rpython.jit.metainterp.jitexc import NotAVectorizeableLoop, NotAProfitableLoop
  14. from rpython.rtyper.lltypesystem.lloperation import llop
  15. from rpython.rtyper.lltypesystem import lltype
  16. def forwarded_vecinfo(op):
  17. fwd = op.get_forwarded()
  18. if fwd is None or not isinstance(fwd, VectorizationInfo):
  19. # the optimizer clears getforwarded AFTER
  20. # vectorization, it happens that this is not clean
  21. fwd = VectorizationInfo(op)
  22. if not op.is_constant():
  23. op.set_forwarded(fwd)
  24. return fwd
  25. class SchedulerState(object):
  26. def __init__(self, graph):
  27. self.renamer = Renamer()
  28. self.graph = graph
  29. self.oplist = []
  30. self.worklist = []
  31. self.invariant_oplist = []
  32. self.invariant_vector_vars = []
  33. self.seen = {}
  34. def post_schedule(self):
  35. loop = self.graph.loop
  36. self.renamer.rename(loop.jump)
  37. self.ensure_args_unpacked(loop.jump)
  38. loop.operations = self.oplist
  39. loop.prefix = self.invariant_oplist
  40. if len(self.invariant_vector_vars) + len(self.invariant_oplist) > 0:
  41. # label
  42. args = loop.label.getarglist_copy() + self.invariant_vector_vars
  43. opnum = loop.label.getopnum()
  44. op = loop.label.copy_and_change(opnum, args)
  45. self.renamer.rename(op)
  46. loop.prefix_label = op
  47. # jump
  48. args = loop.jump.getarglist_copy() + self.invariant_vector_vars
  49. opnum = loop.jump.getopnum()
  50. op = loop.jump.copy_and_change(opnum, args)
  51. self.renamer.rename(op)
  52. loop.jump = op
  53. def profitable(self):
  54. return True
  55. def prepare(self):
  56. for node in self.graph.nodes:
  57. if node.depends_count() == 0:
  58. self.worklist.insert(0, node)
  59. def emit(self, node, scheduler):
  60. # implement me in subclass. e.g. as in VecScheduleState
  61. return False
  62. def delay(self, node):
  63. return False
  64. def has_more(self):
  65. return len(self.worklist) > 0
  66. def ensure_args_unpacked(self, op):
  67. pass
  68. def post_emit(self, node):
  69. pass
  70. def pre_emit(self, node):
  71. pass
  72. class Scheduler(object):
  73. """ Create an instance of this class to (re)schedule a vector trace. """
  74. def __init__(self):
  75. pass
  76. def next(self, state):
  77. """ select the next candidate node to be emitted, or None """
  78. worklist = state.worklist
  79. visited = 0
  80. while len(worklist) > 0:
  81. if visited == len(worklist):
  82. return None
  83. node = worklist.pop()
  84. if node.emitted:
  85. continue
  86. if not self.delay(node, state):
  87. return node
  88. worklist.insert(0, node)
  89. visited += 1
  90. return None
  91. def try_to_trash_pack(self, state):
  92. # one element a pack has several dependencies pointing to
  93. # it thus we MUST skip this pack!
  94. if len(state.worklist) > 0:
  95. # break the first!
  96. i = 0
  97. node = state.worklist[i]
  98. i += 1
  99. while i < len(state.worklist) and not node.pack:
  100. node = state.worklist[i]
  101. i += 1
  102. if not node.pack:
  103. return False
  104. pack = node.pack
  105. for n in node.pack.operations:
  106. if n.depends_count() > 0:
  107. pack.clear()
  108. return True
  109. else:
  110. return False
  111. return False
  112. def delay(self, node, state):
  113. """ Delay this operation?
  114. Only if any dependency has not been resolved """
  115. if state.delay(node):
  116. return True
  117. return node.depends_count() != 0
  118. def mark_emitted(self, node, state, unpack=True):
  119. """ An operation has been emitted, adds new operations to the worklist
  120. whenever their dependency count drops to zero.
  121. Keeps worklist sorted (see priority) """
  122. worklist = state.worklist
  123. provides = node.provides()[:]
  124. for dep in provides: # COPY
  125. target = dep.to
  126. node.remove_edge_to(target)
  127. if not target.emitted and target.depends_count() == 0:
  128. # sorts them by priority
  129. i = len(worklist)-1
  130. while i >= 0:
  131. cur = worklist[i]
  132. c = (cur.priority - target.priority)
  133. if c < 0: # meaning itnode.priority < target.priority:
  134. worklist.insert(i+1, target)
  135. break
  136. elif c == 0:
  137. # if they have the same priority, sort them
  138. # using the original position in the trace
  139. if target.getindex() < cur.getindex():
  140. worklist.insert(i+1, target)
  141. break
  142. i -= 1
  143. else:
  144. worklist.insert(0, target)
  145. node.clear_dependencies()
  146. node.emitted = True
  147. if not node.is_imaginary():
  148. op = node.getoperation()
  149. state.renamer.rename(op)
  150. if unpack:
  151. state.ensure_args_unpacked(op)
  152. state.post_emit(node)
  153. def walk_and_emit(self, state):
  154. """ Emit all the operations into the oplist parameter.
  155. Initiates the scheduling. """
  156. assert isinstance(state, SchedulerState)
  157. while state.has_more():
  158. node = self.next(state)
  159. if node:
  160. if not state.emit(node, self):
  161. if not node.emitted:
  162. state.pre_emit(node)
  163. self.mark_emitted(node, state)
  164. if not node.is_imaginary():
  165. op = node.getoperation()
  166. state.seen[op] = None
  167. state.oplist.append(op)
  168. continue
  169. # it happens that packs can emit many nodes that have been
  170. # added to the scheuldable_nodes list, in this case it could
  171. # be that no next exists even though the list contains elements
  172. if not state.has_more():
  173. break
  174. if self.try_to_trash_pack(state):
  175. continue
  176. raise AssertionError("schedule failed cannot continue. possible reason: cycle")
  177. if not we_are_translated():
  178. for node in state.graph.nodes:
  179. assert node.emitted
  180. def failnbail_transformation(msg):
  181. msg = '%s\n' % msg
  182. if we_are_translated():
  183. llop.debug_print(lltype.Void, msg)
  184. else:
  185. import pdb; pdb.set_trace()
  186. raise NotImplementedError(msg)
  187. class TypeRestrict(object):
  188. ANY_TYPE = '\x00'
  189. ANY_SIZE = -1
  190. ANY_SIGN = -1
  191. ANY_COUNT = -1
  192. SIGNED = 1
  193. UNSIGNED = 0
  194. def __init__(self,
  195. type=ANY_TYPE,
  196. bytesize=ANY_SIZE,
  197. count=ANY_SIGN,
  198. sign=ANY_COUNT):
  199. self.type = type
  200. self.bytesize = bytesize
  201. self.sign = sign
  202. self.count = count
  203. @always_inline
  204. def any_size(self):
  205. return self.bytesize == TypeRestrict.ANY_SIZE
  206. @always_inline
  207. def any_count(self):
  208. return self.count == TypeRestrict.ANY_COUNT
  209. def check(self, value):
  210. vecinfo = forwarded_vecinfo(value)
  211. assert vecinfo.datatype != '\x00'
  212. if self.type != TypeRestrict.ANY_TYPE:
  213. if self.type != vecinfo.datatype:
  214. msg = "type mismatch %s != %s" % \
  215. (self.type, vecinfo.datatype)
  216. failnbail_transformation(msg)
  217. assert vecinfo.bytesize > 0
  218. if not self.any_size():
  219. if self.bytesize != vecinfo.bytesize:
  220. msg = "bytesize mismatch %s != %s" % \
  221. (self.bytesize, vecinfo.bytesize)
  222. failnbail_transformation(msg)
  223. assert vecinfo.count > 0
  224. if self.count != TypeRestrict.ANY_COUNT:
  225. if vecinfo.count < self.count:
  226. msg = "count mismatch %s < %s" % \
  227. (self.count, vecinfo.count)
  228. failnbail_transformation(msg)
  229. if self.sign != TypeRestrict.ANY_SIGN:
  230. if bool(self.sign) == vecinfo.sign:
  231. msg = "sign mismatch %s < %s" % \
  232. (self.sign, vecinfo.sign)
  233. failnbail_transformation(msg)
  234. def max_input_count(self, count):
  235. """ How many """
  236. if self.count != TypeRestrict.ANY_COUNT:
  237. return self.count
  238. return count
  239. class OpRestrict(object):
  240. def __init__(self, argument_restris):
  241. self.argument_restrictions = argument_restris
  242. def check_operation(self, state, pack, op):
  243. pass
  244. def crop_vector(self, op, newsize, size):
  245. return newsize, size
  246. def must_crop_vector(self, op, index):
  247. restrict = self.argument_restrictions[index]
  248. vecinfo = forwarded_vecinfo(op.getarg(index))
  249. size = vecinfo.bytesize
  250. newsize = self.crop_to_size(op, index)
  251. return not restrict.any_size() and newsize != size
  252. @always_inline
  253. def crop_to_size(self, op, index):
  254. restrict = self.argument_restrictions[index]
  255. return restrict.bytesize
  256. def opcount_filling_vector_register(self, op, vec_reg_size):
  257. """ How many operations of that kind can one execute
  258. with a machine instruction of register size X?
  259. """
  260. if op.is_typecast():
  261. if op.casts_down():
  262. size = op.cast_input_bytesize(vec_reg_size)
  263. return size // op.cast_from_bytesize()
  264. else:
  265. return vec_reg_size // op.cast_to_bytesize()
  266. vecinfo = forwarded_vecinfo(op)
  267. return vec_reg_size // vecinfo.bytesize
  268. class GuardRestrict(OpRestrict):
  269. def opcount_filling_vector_register(self, op, vec_reg_size):
  270. arg = op.getarg(0)
  271. vecinfo = forwarded_vecinfo(arg)
  272. return vec_reg_size // vecinfo.bytesize
  273. class LoadRestrict(OpRestrict):
  274. def opcount_filling_vector_register(self, op, vec_reg_size):
  275. assert rop.is_primitive_load(op.opnum)
  276. descr = op.getdescr()
  277. return vec_reg_size // descr.get_item_size_in_bytes()
  278. class StoreRestrict(OpRestrict):
  279. def __init__(self, argument_restris):
  280. self.argument_restrictions = argument_restris
  281. def must_crop_vector(self, op, index):
  282. vecinfo = forwarded_vecinfo(op.getarg(index))
  283. bytesize = vecinfo.bytesize
  284. return self.crop_to_size(op, index) != bytesize
  285. @always_inline
  286. def crop_to_size(self, op, index):
  287. # there is only one parameter that needs to be transformed!
  288. descr = op.getdescr()
  289. return descr.get_item_size_in_bytes()
  290. def opcount_filling_vector_register(self, op, vec_reg_size):
  291. assert rop.is_primitive_store(op.opnum)
  292. descr = op.getdescr()
  293. return vec_reg_size // descr.get_item_size_in_bytes()
  294. class OpMatchSizeTypeFirst(OpRestrict):
  295. def check_operation(self, state, pack, op):
  296. i = 0
  297. infos = [forwarded_vecinfo(o) for o in op.getarglist()]
  298. arg0 = op.getarg(i)
  299. while arg0.is_constant() and i < op.numargs():
  300. i += 1
  301. arg0 = op.getarg(i)
  302. vecinfo = forwarded_vecinfo(arg0)
  303. bytesize = vecinfo.bytesize
  304. datatype = vecinfo.datatype
  305. for arg in op.getarglist():
  306. if arg.is_constant():
  307. continue
  308. curvecinfo = forwarded_vecinfo(arg)
  309. if curvecinfo.bytesize != bytesize:
  310. raise NotAVectorizeableLoop()
  311. if curvecinfo.datatype != datatype:
  312. raise NotAVectorizeableLoop()
  313. class trans(object):
  314. TR_ANY = TypeRestrict()
  315. TR_ANY_FLOAT = TypeRestrict(FLOAT)
  316. TR_ANY_INTEGER = TypeRestrict(INT)
  317. TR_FLOAT_2 = TypeRestrict(FLOAT, 4, 2)
  318. TR_DOUBLE_2 = TypeRestrict(FLOAT, 8, 2)
  319. TR_INT32_2 = TypeRestrict(INT, 4, 2)
  320. OR_MSTF_I = OpMatchSizeTypeFirst([TR_ANY_INTEGER, TR_ANY_INTEGER])
  321. OR_MSTF_F = OpMatchSizeTypeFirst([TR_ANY_FLOAT, TR_ANY_FLOAT])
  322. STORE_RESTRICT = StoreRestrict([None, None, TR_ANY])
  323. LOAD_RESTRICT = LoadRestrict([])
  324. GUARD_RESTRICT = GuardRestrict([TR_ANY_INTEGER])
  325. # note that the following definition is x86 arch specific
  326. MAPPING = {
  327. rop.VEC_INT_ADD: OR_MSTF_I,
  328. rop.VEC_INT_SUB: OR_MSTF_I,
  329. rop.VEC_INT_MUL: OR_MSTF_I,
  330. rop.VEC_INT_AND: OR_MSTF_I,
  331. rop.VEC_INT_OR: OR_MSTF_I,
  332. rop.VEC_INT_XOR: OR_MSTF_I,
  333. rop.VEC_INT_EQ: OR_MSTF_I,
  334. rop.VEC_INT_NE: OR_MSTF_I,
  335. rop.VEC_FLOAT_ADD: OR_MSTF_F,
  336. rop.VEC_FLOAT_SUB: OR_MSTF_F,
  337. rop.VEC_FLOAT_MUL: OR_MSTF_F,
  338. rop.VEC_FLOAT_TRUEDIV: OR_MSTF_F,
  339. rop.VEC_FLOAT_ABS: OpRestrict([TR_ANY_FLOAT]),
  340. rop.VEC_FLOAT_NEG: OpRestrict([TR_ANY_FLOAT]),
  341. rop.VEC_RAW_STORE: STORE_RESTRICT,
  342. rop.VEC_SETARRAYITEM_RAW: STORE_RESTRICT,
  343. rop.VEC_SETARRAYITEM_GC: STORE_RESTRICT,
  344. rop.VEC_RAW_LOAD_I: LOAD_RESTRICT,
  345. rop.VEC_RAW_LOAD_F: LOAD_RESTRICT,
  346. rop.VEC_GETARRAYITEM_RAW_I: LOAD_RESTRICT,
  347. rop.VEC_GETARRAYITEM_RAW_F: LOAD_RESTRICT,
  348. rop.VEC_GETARRAYITEM_GC_I: LOAD_RESTRICT,
  349. rop.VEC_GETARRAYITEM_GC_F: LOAD_RESTRICT,
  350. rop.VEC_GUARD_TRUE: GUARD_RESTRICT,
  351. rop.VEC_GUARD_FALSE: GUARD_RESTRICT,
  352. ## irregular
  353. rop.VEC_INT_SIGNEXT: OpRestrict([TR_ANY_INTEGER]),
  354. rop.VEC_CAST_FLOAT_TO_SINGLEFLOAT: OpRestrict([TR_DOUBLE_2]),
  355. # weird but the trace will store single floats in int boxes
  356. rop.VEC_CAST_SINGLEFLOAT_TO_FLOAT: OpRestrict([TR_INT32_2]),
  357. rop.VEC_CAST_FLOAT_TO_INT: OpRestrict([TR_DOUBLE_2]),
  358. rop.VEC_CAST_INT_TO_FLOAT: OpRestrict([TR_INT32_2]),
  359. rop.VEC_FLOAT_EQ: OpRestrict([TR_ANY_FLOAT,TR_ANY_FLOAT]),
  360. rop.VEC_FLOAT_NE: OpRestrict([TR_ANY_FLOAT,TR_ANY_FLOAT]),
  361. rop.VEC_INT_IS_TRUE: OpRestrict([TR_ANY_INTEGER,TR_ANY_INTEGER]),
  362. }
  363. @staticmethod
  364. def get(op):
  365. res = trans.MAPPING.get(op.vector, None)
  366. if not res:
  367. failnbail_transformation("could not get OpRestrict for " + str(op))
  368. return res
  369. def turn_into_vector(state, pack):
  370. """ Turn a pack into a vector instruction """
  371. check_if_pack_supported(state, pack)
  372. state.costmodel.record_pack_savings(pack, pack.numops())
  373. left = pack.leftmost()
  374. oprestrict = trans.get(left)
  375. if oprestrict is not None:
  376. oprestrict.check_operation(state, pack, left)
  377. args = left.getarglist_copy()
  378. prepare_arguments(state, pack, args)
  379. vecop = VecOperation(left.vector, args, left,
  380. pack.numops(), left.getdescr())
  381. for i,node in enumerate(pack.operations):
  382. op = node.getoperation()
  383. if op.returns_void():
  384. continue
  385. state.setvector_of_box(op,i,vecop)
  386. if pack.is_accumulating() and not op.is_guard():
  387. state.renamer.start_renaming(op, vecop)
  388. if left.is_guard():
  389. prepare_fail_arguments(state, pack, left, vecop)
  390. state.oplist.append(vecop)
  391. assert vecop.count >= 1
  392. def prepare_arguments(state, pack, args):
  393. # Transforming one argument to a vector box argument
  394. # The following cases can occur:
  395. # 1) argument is present in the box_to_vbox map.
  396. # a) vector can be reused immediatly (simple case)
  397. # b) the size of the input is mismatching (crop the vector)
  398. # c) values are scattered in differnt registers
  399. # d) the operand is not at the right position in the vector
  400. # 2) argument is not known to reside in a vector
  401. # a) expand vars/consts before the label and add as argument
  402. # b) expand vars created in the loop body
  403. #
  404. oprestrict = trans.MAPPING.get(pack.leftmost().vector, None)
  405. if not oprestrict:
  406. return
  407. restrictions = oprestrict.argument_restrictions
  408. for i,arg in enumerate(args):
  409. if i >= len(restrictions) or restrictions[i] is None:
  410. # ignore this argument
  411. continue
  412. restrict = restrictions[i]
  413. if arg.returns_vector():
  414. restrict.check(arg)
  415. continue
  416. pos, vecop = state.getvector_of_box(arg)
  417. if not vecop:
  418. # 2) constant/variable expand this box
  419. expand(state, pack, args, arg, i)
  420. restrict.check(args[i])
  421. continue
  422. # 1)
  423. args[i] = vecop # a)
  424. assemble_scattered_values(state, pack, args, i) # c)
  425. crop_vector(state, oprestrict, restrict, pack, args, i) # b)
  426. position_values(state, restrict, pack, args, i, pos) # d)
  427. restrict.check(args[i])
  428. def prepare_fail_arguments(state, pack, left, vecop):
  429. assert isinstance(left, GuardResOp)
  430. assert isinstance(vecop, GuardResOp)
  431. args = left.getfailargs()[:]
  432. for i, arg in enumerate(args):
  433. pos, newarg = state.getvector_of_box(arg)
  434. if newarg is None:
  435. newarg = arg
  436. if newarg.is_vector(): # can be moved to guard exit!
  437. newarg = unpack_from_vector(state, newarg, 0, 1)
  438. args[i] = newarg
  439. vecop.setfailargs(args)
  440. # TODO vecop.rd_snapshot = left.rd_snapshot
  441. @always_inline
  442. def crop_vector(state, oprestrict, restrict, pack, args, i):
  443. # convert size i64 -> i32, i32 -> i64, ...
  444. arg = args[i]
  445. vecinfo = forwarded_vecinfo(arg)
  446. size = vecinfo.bytesize
  447. left = pack.leftmost()
  448. if oprestrict.must_crop_vector(left, i):
  449. newsize = oprestrict.crop_to_size(left, i)
  450. assert arg.type == 'i'
  451. state._prevent_signext(newsize, size)
  452. count = vecinfo.count
  453. vecop = VecOperationNew(rop.VEC_INT_SIGNEXT, [arg, ConstInt(newsize)],
  454. 'i', newsize, vecinfo.signed, count)
  455. state.oplist.append(vecop)
  456. state.costmodel.record_cast_int(size, newsize, count)
  457. args[i] = vecop
  458. @always_inline
  459. def assemble_scattered_values(state, pack, args, index):
  460. args_at_index = [node.getoperation().getarg(index) for node in pack.operations]
  461. args_at_index[0] = args[index]
  462. vectors = pack.argument_vectors(state, pack, index, args_at_index)
  463. if len(vectors) > 1:
  464. # the argument is scattered along different vector boxes
  465. args[index] = gather(state, vectors, pack.numops())
  466. state.remember_args_in_vector(pack, index, args[index])
  467. @always_inline
  468. def gather(state, vectors, count): # packed < packable and packed < stride:
  469. (_, arg) = vectors[0]
  470. i = 1
  471. while i < len(vectors):
  472. (newarg_pos, newarg) = vectors[i]
  473. vecinfo = forwarded_vecinfo(arg)
  474. newvecinfo = forwarded_vecinfo(newarg)
  475. if vecinfo.count + newvecinfo.count <= count:
  476. arg = pack_into_vector(state, arg, vecinfo.count, newarg, newarg_pos, newvecinfo.count)
  477. i += 1
  478. return arg
  479. @always_inline
  480. def position_values(state, restrict, pack, args, index, position):
  481. arg = args[index]
  482. vecinfo = forwarded_vecinfo(arg)
  483. count = vecinfo.count
  484. newcount = restrict.count
  485. if not restrict.any_count() and newcount != count:
  486. if position == 0:
  487. pass
  488. pass
  489. if position != 0:
  490. # The vector box is at a position != 0 but it
  491. # is required to be at position 0. Unpack it!
  492. arg = args[index]
  493. vecinfo = forwarded_vecinfo(arg)
  494. count = restrict.max_input_count(vecinfo.count)
  495. args[index] = unpack_from_vector(state, arg, position, count)
  496. state.remember_args_in_vector(pack, index, args[index])
  497. def check_if_pack_supported(state, pack):
  498. left = pack.leftmost()
  499. vecinfo = forwarded_vecinfo(left)
  500. insize = vecinfo.bytesize
  501. if left.is_typecast():
  502. # prohibit the packing of signext calls that
  503. # cast to int16/int8.
  504. state._prevent_signext(left.cast_to_bytesize(),
  505. left.cast_from_bytesize())
  506. if left.getopnum() == rop.INT_MUL:
  507. if insize == 8 or insize == 1:
  508. # see assembler for comment why
  509. raise NotAProfitableLoop
  510. def unpack_from_vector(state, arg, index, count):
  511. """ Extract parts of the vector box into another vector box """
  512. assert count > 0
  513. vecinfo = forwarded_vecinfo(arg)
  514. assert index + count <= vecinfo.count
  515. args = [arg, ConstInt(index), ConstInt(count)]
  516. vecop = OpHelpers.create_vec_unpack(arg.type, args, vecinfo.bytesize,
  517. vecinfo.signed, count)
  518. state.costmodel.record_vector_unpack(arg, index, count)
  519. state.oplist.append(vecop)
  520. return vecop
  521. def pack_into_vector(state, tgt, tidx, src, sidx, scount):
  522. """ tgt = [1,2,3,4,_,_,_,_]
  523. src = [5,6,_,_]
  524. new_box = [1,2,3,4,5,6,_,_] after the operation, tidx=4, scount=2
  525. """
  526. assert sidx == 0 # restriction
  527. vecinfo = forwarded_vecinfo(tgt)
  528. newcount = vecinfo.count + scount
  529. args = [tgt, src, ConstInt(tidx), ConstInt(scount)]
  530. vecop = OpHelpers.create_vec_pack(tgt.type, args, vecinfo.bytesize, vecinfo.signed, newcount)
  531. state.oplist.append(vecop)
  532. state.costmodel.record_vector_pack(src, sidx, scount)
  533. if not we_are_translated():
  534. _check_vec_pack(vecop)
  535. return vecop
  536. def _check_vec_pack(op):
  537. arg0 = op.getarg(0)
  538. arg1 = op.getarg(1)
  539. index = op.getarg(2)
  540. count = op.getarg(3)
  541. assert op.is_vector()
  542. assert arg0.is_vector()
  543. assert index.is_constant()
  544. assert isinstance(count, ConstInt)
  545. vecinfo = forwarded_vecinfo(op)
  546. argvecinfo = forwarded_vecinfo(arg0)
  547. assert argvecinfo.bytesize == vecinfo.bytesize
  548. if arg1.is_vector():
  549. assert argvecinfo.bytesize == vecinfo.bytesize
  550. else:
  551. assert count.value == 1
  552. assert index.value < vecinfo.count
  553. assert index.value + count.value <= vecinfo.count
  554. assert vecinfo.count > argvecinfo.count
  555. def expand(state, pack, args, arg, index):
  556. """ Expand a value into a vector box. useful for arith metic
  557. of one vector with a scalar (either constant/varialbe)
  558. """
  559. left = pack.leftmost()
  560. box_type = arg.type
  561. expanded_map = state.expanded_map
  562. ops = state.invariant_oplist
  563. variables = state.invariant_vector_vars
  564. if not arg.is_constant() and arg not in state.inputargs:
  565. # cannot be created before the loop, expand inline
  566. ops = state.oplist
  567. variables = None
  568. for i, node in enumerate(pack.operations):
  569. op = node.getoperation()
  570. if not arg.same_box(op.getarg(index)):
  571. break
  572. i += 1
  573. else:
  574. # note that heterogenous nodes are not yet tracked
  575. vecop = state.find_expanded([arg])
  576. if vecop:
  577. args[index] = vecop
  578. return vecop
  579. left = pack.leftmost()
  580. vecinfo = forwarded_vecinfo(left)
  581. vecop = OpHelpers.create_vec_expand(arg, vecinfo.bytesize, vecinfo.signed, pack.numops())
  582. ops.append(vecop)
  583. if variables is not None:
  584. variables.append(vecop)
  585. state.expand([arg], vecop)
  586. args[index] = vecop
  587. return vecop
  588. # quick search if it has already been expanded
  589. expandargs = [op.getoperation().getarg(index) for op in pack.operations]
  590. vecop = state.find_expanded(expandargs)
  591. if vecop:
  592. args[index] = vecop
  593. return vecop
  594. arg_vecinfo = forwarded_vecinfo(arg)
  595. vecop = OpHelpers.create_vec(arg.type, arg_vecinfo.bytesize, arg_vecinfo.signed, pack.opnum())
  596. ops.append(vecop)
  597. for i,node in enumerate(pack.operations):
  598. op = node.getoperation()
  599. arg = op.getarg(index)
  600. arguments = [vecop, arg, ConstInt(i), ConstInt(1)]
  601. vecinfo = forwarded_vecinfo(vecop)
  602. vecop = OpHelpers.create_vec_pack(arg.type, arguments, vecinfo.bytesize,
  603. vecinfo.signed, vecinfo.count+1)
  604. ops.append(vecop)
  605. state.expand(expandargs, vecop)
  606. if variables is not None:
  607. variables.append(vecop)
  608. args[index] = vecop
  609. class VecScheduleState(SchedulerState):
  610. def __init__(self, graph, packset, cpu, costmodel):
  611. SchedulerState.__init__(self, graph)
  612. self.box_to_vbox = {}
  613. self.cpu = cpu
  614. self.vec_reg_size = cpu.vector_register_size
  615. self.expanded_map = {}
  616. self.costmodel = costmodel
  617. self.inputargs = {}
  618. self.packset = packset
  619. for arg in graph.loop.inputargs:
  620. self.inputargs[arg] = None
  621. self.accumulation = {}
  622. def expand(self, args, vecop):
  623. index = 0
  624. if len(args) == 1:
  625. # loop is executed once, thus sets -1 as index
  626. index = -1
  627. for arg in args:
  628. self.expanded_map.setdefault(arg, []).append((vecop, index))
  629. index += 1
  630. def find_expanded(self, args):
  631. if len(args) == 1:
  632. candidates = self.expanded_map.get(args[0], [])
  633. for (vecop, index) in candidates:
  634. if index == -1:
  635. # found an expanded variable/constant
  636. return vecop
  637. return None
  638. possible = {}
  639. for i, arg in enumerate(args):
  640. expansions = self.expanded_map.get(arg, [])
  641. candidates = [vecop for (vecop, index) in expansions \
  642. if i == index and possible.get(vecop,True)]
  643. for vecop in candidates:
  644. for key in possible.keys():
  645. if key not in candidates:
  646. # delete every not possible key,value
  647. possible[key] = False
  648. # found a candidate, append it if not yet present
  649. possible[vecop] = True
  650. if not possible:
  651. # no possibility left, this combination is not expanded
  652. return None
  653. for vecop,valid in possible.items():
  654. if valid:
  655. return vecop
  656. return None
  657. def post_emit(self, node):
  658. pass
  659. def pre_emit(self, node):
  660. op = node.getoperation()
  661. if op.is_guard():
  662. # add accumulation info to the descriptor
  663. failargs = op.getfailargs()[:]
  664. descr = op.getdescr()
  665. # note: stitching a guard must resemble the order of the label
  666. # otherwise a wrong mapping is handed to the register allocator
  667. for i,arg in enumerate(failargs):
  668. if arg is None:
  669. continue
  670. accum = self.accumulation.get(arg, None)
  671. if accum:
  672. from rpython.jit.metainterp.compile import AbstractResumeGuardDescr
  673. assert isinstance(accum, AccumPack)
  674. assert isinstance(descr, AbstractResumeGuardDescr)
  675. info = AccumInfo(i, arg, accum.operator)
  676. descr.attach_vector_info(info)
  677. seed = accum.getleftmostseed()
  678. failargs[i] = self.renamer.rename_map.get(seed, seed)
  679. op.setfailargs(failargs)
  680. def profitable(self):
  681. return self.costmodel.profitable()
  682. def prepare(self):
  683. SchedulerState.prepare(self)
  684. self.packset.accumulate_prepare(self)
  685. for arg in self.graph.loop.label.getarglist():
  686. self.seen[arg] = None
  687. def emit(self, node, scheduler):
  688. """ If you implement a scheduler this operations is called
  689. to emit the actual operation into the oplist of the scheduler.
  690. """
  691. if node.pack:
  692. assert node.pack.numops() > 1
  693. for node in node.pack.operations:
  694. self.pre_emit(node)
  695. scheduler.mark_emitted(node, self, unpack=False)
  696. turn_into_vector(self, node.pack)
  697. return True
  698. return False
  699. def delay(self, node):
  700. if node.pack:
  701. pack = node.pack
  702. if pack.is_accumulating():
  703. for node in pack.operations:
  704. for dep in node.depends():
  705. if dep.to.pack is not pack:
  706. return True
  707. else:
  708. for node in pack.operations:
  709. if node.depends_count() > 0:
  710. return True
  711. return False
  712. def ensure_args_unpacked(self, op):
  713. """ If a box is needed that is currently stored within a vector
  714. box, this utility creates a unpacking instruction.
  715. """
  716. # unpack for an immediate use
  717. for i, argument in enumerate(op.getarglist()):
  718. if not argument.is_constant():
  719. arg = self.ensure_unpacked(i, argument)
  720. if argument is not arg:
  721. op.setarg(i, arg)
  722. # unpack for a guard exit
  723. if op.is_guard():
  724. # could be moved to the guard exit
  725. fail_args = op.getfailargs()
  726. for i, argument in enumerate(fail_args):
  727. if argument and not argument.is_constant():
  728. arg = self.ensure_unpacked(i, argument)
  729. if argument is not arg:
  730. fail_args[i] = arg
  731. op.setfailargs(fail_args)
  732. def ensure_unpacked(self, index, arg):
  733. if arg in self.seen or arg.is_vector():
  734. return arg
  735. (pos, var) = self.getvector_of_box(arg)
  736. if var:
  737. if var in self.invariant_vector_vars:
  738. return arg
  739. if arg in self.accumulation:
  740. return arg
  741. args = [var, ConstInt(pos), ConstInt(1)]
  742. vecinfo = forwarded_vecinfo(var)
  743. vecop = OpHelpers.create_vec_unpack(var.type, args, vecinfo.bytesize,
  744. vecinfo.signed, 1)
  745. self.renamer.start_renaming(arg, vecop)
  746. self.seen[vecop] = None
  747. self.costmodel.record_vector_unpack(var, pos, 1)
  748. self.oplist.append(vecop)
  749. return vecop
  750. return arg
  751. def _prevent_signext(self, outsize, insize):
  752. if insize != outsize:
  753. if outsize < 4 or insize < 4:
  754. raise NotAProfitableLoop
  755. def getvector_of_box(self, arg):
  756. return self.box_to_vbox.get(arg, (-1, None))
  757. def setvector_of_box(self, var, off, vector):
  758. if var.returns_void():
  759. assert 0, "not allowed to rename void resop"
  760. vecinfo = forwarded_vecinfo(vector)
  761. assert off < vecinfo.count
  762. assert not var.is_vector()
  763. self.box_to_vbox[var] = (off, vector)
  764. def remember_args_in_vector(self, pack, index, box):
  765. arguments = [op.getoperation().getarg(index) for op in pack.operations]
  766. for i,arg in enumerate(arguments):
  767. vecinfo = forwarded_vecinfo(arg)
  768. if i >= vecinfo.count:
  769. break
  770. self.setvector_of_box(arg, i, box)
  771. class Pack(object):
  772. """ A pack is a set of n statements that are:
  773. * isomorphic
  774. * independent
  775. """
  776. FULL = 0
  777. _attrs_ = ('operations', 'accumulator', 'operator', 'position')
  778. operator = '\x00'
  779. position = -1
  780. accumulator = None
  781. def __init__(self, ops):
  782. self.operations = ops
  783. self.update_pack_of_nodes()
  784. def numops(self):
  785. return len(self.operations)
  786. @specialize.arg(1)
  787. def leftmost(self, node=False):
  788. if node:
  789. return self.operations[0]
  790. return self.operations[0].getoperation()
  791. @specialize.arg(1)
  792. def rightmost(self, node=False):
  793. if node:
  794. return self.operations[-1]
  795. return self.operations[-1].getoperation()
  796. def pack_type(self):
  797. ptype = self.input_type
  798. if self.input_type is None:
  799. # load does not have an input type, but only an output type
  800. ptype = self.output_type
  801. return ptype
  802. def input_byte_size(self):
  803. """ The amount of bytes the operations need with the current
  804. entries in self.operations. E.g. cast_singlefloat_to_float
  805. takes only #2 operations.
  806. """
  807. return self._byte_size(self.input_type)
  808. def output_byte_size(self):
  809. """ The amount of bytes the operations need with the current
  810. entries in self.operations. E.g. vec_load(..., descr=short)
  811. with 10 operations returns 20
  812. """
  813. return self._byte_size(self.output_type)
  814. def pack_load(self, vec_reg_size):
  815. """ Returns the load of the pack a vector register would hold
  816. just after executing the operation.
  817. returns: < 0 - empty, nearly empty
  818. = 0 - full
  819. > 0 - overloaded
  820. """
  821. left = self.leftmost()
  822. if left.returns_void():
  823. if rop.is_primitive_store(left.opnum):
  824. # make this case more general if it turns out this is
  825. # not the only case where packs need to be trashed
  826. descr = left.getdescr()
  827. bytesize = descr.get_item_size_in_bytes()
  828. return bytesize * self.numops() - vec_reg_size
  829. else:
  830. assert left.is_guard() and left.getopnum() in \
  831. (rop.GUARD_TRUE, rop.GUARD_FALSE)
  832. vecinfo = forwarded_vecinfo(left.getarg(0))
  833. bytesize = vecinfo.bytesize
  834. return bytesize * self.numops() - vec_reg_size
  835. return 0
  836. if self.numops() == 0:
  837. return -1
  838. if left.is_typecast():
  839. # casting is special, often only takes a half full vector
  840. if left.casts_down():
  841. # size is reduced
  842. size = left.cast_input_bytesize(vec_reg_size)
  843. return left.cast_from_bytesize() * self.numops() - size
  844. else:
  845. # size is increased
  846. #size = left.cast_input_bytesize(vec_reg_size)
  847. return left.cast_to_bytesize() * self.numops() - vec_reg_size
  848. vecinfo = forwarded_vecinfo(left)
  849. return vecinfo.bytesize * self.numops() - vec_reg_size
  850. def is_full(self, vec_reg_size):
  851. """ If one input element times the opcount is equal
  852. to the vector register size, we are full!
  853. """
  854. return self.pack_load(vec_reg_size) == Pack.FULL
  855. def opnum(self):
  856. assert len(self.operations) > 0
  857. return self.operations[0].getoperation().getopnum()
  858. def clear(self):
  859. for node in self.operations:
  860. node.pack = None
  861. node.pack_position = -1
  862. def update_pack_of_nodes(self):
  863. for i,node in enumerate(self.operations):
  864. node.pack = self
  865. node.pack_position = i
  866. def split(self, packlist, vec_reg_size):
  867. """ Combination phase creates the biggest packs that are possible.
  868. In this step the pack is reduced in size to fit into an
  869. vector register.
  870. """
  871. before_count = len(packlist)
  872. pack = self
  873. while pack.pack_load(vec_reg_size) > Pack.FULL:
  874. pack.clear()
  875. oplist, newoplist = pack.slice_operations(vec_reg_size)
  876. pack.operations = oplist
  877. pack.update_pack_of_nodes()
  878. if not pack.leftmost().is_typecast():
  879. assert pack.is_full(vec_reg_size)
  880. #
  881. newpack = pack.clone(newoplist)
  882. load = newpack.pack_load(vec_reg_size)
  883. if load >= Pack.FULL:
  884. pack.update_pack_of_nodes()
  885. pack = newpack
  886. packlist.append(newpack)
  887. else:
  888. newpack.clear()
  889. newpack.operations = []
  890. break
  891. pack.update_pack_of_nodes()
  892. def opcount_filling_vector_register(self, vec_reg_size):
  893. left = self.leftmost()
  894. oprestrict = trans.get(left)
  895. return oprestrict.opcount_filling_vector_register(left, vec_reg_size)
  896. def slice_operations(self, vec_reg_size):
  897. count = self.opcount_filling_vector_register(vec_reg_size)
  898. assert count > 0
  899. newoplist = self.operations[count:]
  900. oplist = self.operations[:count]
  901. assert len(newoplist) + len(oplist) == len(self.operations)
  902. assert len(newoplist) != 0
  903. return oplist, newoplist
  904. def rightmost_match_leftmost(self, other):
  905. """ Check if pack A can be combined with pack B """
  906. assert isinstance(other, Pack)
  907. rightmost = self.operations[-1]
  908. leftmost = other.operations[0]
  909. # if it is not accumulating it is valid
  910. if self.is_accumulating():
  911. if not other.is_accumulating():
  912. return False
  913. elif self.position != other.position:
  914. return False
  915. return rightmost is leftmost
  916. def argument_vectors(self, state, pack, index, pack_args_index):
  917. vectors = []
  918. last = None
  919. for arg in pack_args_index:
  920. pos, vecop = state.getvector_of_box(arg)
  921. if vecop is not last and vecop is not None:
  922. vectors.append((pos, vecop))
  923. last = vecop
  924. return vectors
  925. def __repr__(self):
  926. if len(self.operations) == 0:
  927. return "Pack(empty)"
  928. packs = self.operations[0].op.getopname() + '[' + ','.join(['%2d' % (o.opidx) for o in self.operations]) + ']'
  929. if self.operations[0].op.getdescr():
  930. packs += 'descr=' + str(self.operations[0].op.getdescr())
  931. return "Pack(%dx %s)" % (self.numops(), packs)
  932. def is_accumulating(self):
  933. return False
  934. def clone(self, oplist):
  935. return Pack(oplist)
  936. class Pair(Pack):
  937. """ A special Pack object with only two statements. """
  938. def __init__(self, left, right):
  939. assert isinstance(left, Node)
  940. assert isinstance(right, Node)
  941. Pack.__init__(self, [left, right])
  942. def __eq__(self, other):
  943. if isinstance(other, Pair):
  944. return self.left is other.left and \
  945. self.right is other.right
  946. class AccumPack(Pack):
  947. SUPPORTED = { rop.FLOAT_ADD: '+',
  948. rop.INT_ADD: '+',
  949. rop.FLOAT_MUL: '*',
  950. }
  951. def __init__(self, nodes, operator, position):
  952. Pack.__init__(self, nodes)
  953. self.operator = operator
  954. self.position = position
  955. def getdatatype(self):
  956. accum = self.leftmost().getarg(self.position)
  957. vecinfo = forwarded_vecinfo(accum)
  958. return vecinfo.datatype
  959. def getbytesize(self):
  960. accum = self.leftmost().getarg(self.position)
  961. vecinfo = forwarded_vecinfo(accum)
  962. return vecinfo.bytesize
  963. def getleftmostseed(self):
  964. return self.leftmost().getarg(self.position)
  965. def getseeds(self):
  966. """ The accumulatoriable holding the seed value """
  967. return [op.getoperation().getarg(self.position) for op in self.operations]
  968. def reduce_init(self):
  969. if self.operator == '*':
  970. return 1
  971. return 0
  972. def is_accumulating(self):
  973. return True
  974. def clone(self, oplist):
  975. return AccumPack(oplist, self.operator, self.position)