PageRenderTime 60ms CodeModel.GetById 23ms RepoModel.GetById 1ms app.codeStats 0ms

/pypy/module/micronumpy/loop.py

https://bitbucket.org/pypy/pypy/
Python | 1041 lines | 1014 code | 7 blank | 20 comment | 19 complexity | 4fd7a71bb58d02ca970c74258d89d778 MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
  1. """ This file is the main run loop as well as evaluation loops for various
  2. operations. This is the place to look for all the computations that iterate
  3. over all the array elements.
  4. """
  5. import py
  6. from pypy.interpreter.error import oefmt
  7. from rpython.rlib import jit
  8. from rpython.rlib.rstring import StringBuilder
  9. from rpython.rtyper.lltypesystem import lltype, rffi
  10. from pypy.module.micronumpy import support, constants as NPY
  11. from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
  12. from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter, \
  13. AllButAxisIter, ArrayIter
  14. from pypy.interpreter.argument import Arguments
  15. def call2(space, shape, func, calc_dtype, w_lhs, w_rhs, out):
  16. if w_lhs.get_size() == 1:
  17. w_left = w_lhs.get_scalar_value().convert_to(space, calc_dtype)
  18. left_iter = left_state = None
  19. else:
  20. w_left = None
  21. left_iter, left_state = w_lhs.create_iter(shape)
  22. left_iter.track_index = False
  23. if w_rhs.get_size() == 1:
  24. w_right = w_rhs.get_scalar_value().convert_to(space, calc_dtype)
  25. right_iter = right_state = None
  26. else:
  27. w_right = None
  28. right_iter, right_state = w_rhs.create_iter(shape)
  29. right_iter.track_index = False
  30. out_iter, out_state = out.create_iter(shape)
  31. shapelen = len(shape)
  32. res_dtype = out.get_dtype()
  33. call2_func = try_to_share_iterators_call2(left_iter, right_iter,
  34. left_state, right_state, out_state)
  35. params = (space, shapelen, func, calc_dtype, res_dtype, out,
  36. w_left, w_right, left_iter, right_iter, out_iter,
  37. left_state, right_state, out_state)
  38. return call2_func(*params)
  39. def try_to_share_iterators_call2(left_iter, right_iter, left_state, right_state, out_state):
  40. # these are all possible iterator sharing combinations
  41. # left == right == out
  42. # left == right
  43. # left == out
  44. # right == out
  45. right_out_equal = False
  46. if right_iter:
  47. # rhs is not a scalar
  48. if out_state.same(right_state):
  49. right_out_equal = True
  50. #
  51. if not left_iter:
  52. # lhs is a scalar
  53. if right_out_equal:
  54. return call2_advance_out_left
  55. else:
  56. # worst case, nothing can be shared and lhs is a scalar
  57. return call2_advance_out_left_right
  58. else:
  59. # lhs is NOT a scalar
  60. if out_state.same(left_state):
  61. # (2) out and left are the same -> remove left
  62. if right_out_equal:
  63. # the best case
  64. return call2_advance_out
  65. else:
  66. return call2_advance_out_right
  67. else:
  68. if right_out_equal:
  69. # right and out are equal, only advance left and out
  70. return call2_advance_out_left
  71. else:
  72. if right_iter and right_state.same(left_state):
  73. # left and right are equal, but still need to advance out
  74. return call2_advance_out_left_eq_right
  75. else:
  76. # worst case, nothing can be shared
  77. return call2_advance_out_left_right
  78. assert 0, "logical problem with the selection of the call2 case"
def generate_call2_cases(name, left_state, right_state):
    """Build one specialised call2 loop function named ``call2_<name>``.

    `left_state` and `right_state` are *source-code names* (strings) that
    are substituted into the loop template below.  They select which state
    variable is passed to ``left_iter.getitem`` / ``right_iter.getitem``.
    An operand's own state is advanced each iteration only when its own
    name ("left_state" / "right_state") is given.  Otherwise the operand is
    read through another, shared state (e.g. "out_state") and needs no
    separate advancing.

    The template is compiled with exec so that each specialisation gets its
    own JitDriver.  Its globals are this function's locals, which is how it
    sees call2_driver, advance_left_state and advance_right_state.
    """
    call2_driver = jit.JitDriver(name='numpy_call2_' + name,
                                 greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
                                 reds='auto', vectorize=True)
    #
    # True when the operand must advance its own state every iteration
    advance_left_state = left_state == "left_state"
    advance_right_state = right_state == "right_state"
    code = """
def method(space, shapelen, func, calc_dtype, res_dtype, out,
           w_left, w_right, left_iter, right_iter, out_iter,
           left_state, right_state, out_state):
    while not out_iter.done(out_state):
        call2_driver.jit_merge_point(shapelen=shapelen, func=func,
                calc_dtype=calc_dtype, res_dtype=res_dtype)
        if left_iter:
            w_left = left_iter.getitem({left_state}).convert_to(space, calc_dtype)
        if right_iter:
            w_right = right_iter.getitem({right_state}).convert_to(space, calc_dtype)
        w_out = func(calc_dtype, w_left, w_right)
        out_iter.setitem(out_state, w_out.convert_to(space, res_dtype))
        out_state = out_iter.next(out_state)
        if advance_left_state and left_iter:
            left_state = left_iter.next(left_state)
        if advance_right_state and right_iter:
            right_state = right_iter.next(right_state)
        #
        # if not set to None, the values will be loop carried
        # (for the var,var case), forcing the vectorization to unpack
        # the vector registers at the end of the loop
        if left_iter:
            w_left = None
        if right_iter:
            w_right = None
    return out
"""
    exec(py.code.Source(code.format(left_state=left_state,right_state=right_state)).compile(), locals())
    method.__name__ = "call2_" + name
    return method

# Specialisations chosen by try_to_share_iterators_call2.  The 2nd and 3rd
# arguments name the state the left and right operand are read through.
# left, right and out all share out's state:
call2_advance_out = generate_call2_cases("inc_out", "out_state", "out_state")
# right is read through out's state; left advances its own:
call2_advance_out_left = generate_call2_cases("inc_out_left", "left_state", "out_state")
# left is read through out's state; right advances its own:
call2_advance_out_right = generate_call2_cases("inc_out_right", "out_state", "right_state")
# right is read through left's state:
call2_advance_out_left_eq_right = generate_call2_cases("inc_out_left_eq_right", "left_state", "left_state")
# no sharing: every non-scalar operand advances its own state:
call2_advance_out_left_right = generate_call2_cases("inc_out_left_right", "left_state", "right_state")
# JIT driver for call1(); `share_iterator` is green, so the loop is traced
# separately for the shared and the non-shared case.
call1_driver = jit.JitDriver(
    name='numpy_call1',
    greens=['shapelen', 'share_iterator', 'func', 'calc_dtype', 'res_dtype'],
    reds='auto', vectorize=True)

def call1(space, shape, func, calc_dtype, w_obj, w_ret):
    """Apply the unary kernel `func` element-wise from `w_obj` into `w_ret`.

    Each input element is converted to `calc_dtype` and passed to
    ``func(calc_dtype, elem)``.  The result is converted to the dtype of
    `w_ret` before being stored.  Returns `w_ret`.
    """
    obj_iter, obj_state = w_obj.create_iter(shape)
    obj_iter.track_index = False
    out_iter, out_state = w_ret.create_iter(shape)
    shapelen = len(shape)
    res_dtype = w_ret.get_dtype()
    # when input and output have identical states, one state serves both
    share_iterator = out_state.same(obj_state)
    while not out_iter.done(out_state):
        call1_driver.jit_merge_point(shapelen=shapelen, func=func,
                                     share_iterator=share_iterator,
                                     calc_dtype=calc_dtype, res_dtype=res_dtype)
        if share_iterator:
            # use out state as param to getitem
            elem = obj_iter.getitem(out_state).convert_to(space, calc_dtype)
        else:
            elem = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
        out_iter.setitem(out_state, func(calc_dtype, elem).convert_to(space, res_dtype))
        if share_iterator:
            # only advance out, they share the same iteration space
            out_state = out_iter.next(out_state)
        else:
            out_state = out_iter.next(out_state)
            obj_state = obj_iter.next(obj_state)
        # presumably for the same reason as in the call2 template: keep the
        # box from being loop-carried, which hinders vectorization
        elem = None
    return w_ret
call_many_to_one_driver = jit.JitDriver(
    name='numpy_call_many_to_one',
    greens=['shapelen', 'nin', 'func', 'in_dtypes', 'res_dtype'],
    reds='auto')

def call_many_to_one(space, shape, func, in_dtypes, res_dtype, in_args, out):
    """Call the app-level callable `func` once per element position.

    `func` receives one argument per array in `in_args`; each argument is
    coerced to the matching entry of `in_dtypes`.  Each result is coerced
    to `res_dtype` and stored into `out`, which is returned.
    """
    # out must have been built. func needs no calc_dtype, is usually an
    # external ufunc
    nin = len(in_args)
    in_iters = [None] * nin
    in_states = [None] * nin
    for i in range(nin):
        in_i = in_args[i]
        assert isinstance(in_i, W_NDimArray)
        in_iter, in_state = in_i.create_iter(shape)
        in_iters[i] = in_iter
        in_states[i] = in_state
    shapelen = len(shape)
    assert isinstance(out, W_NDimArray)
    out_iter, out_state = out.create_iter(shape)
    # argument buffer, refilled on every iteration
    vals = [None] * nin
    while not out_iter.done(out_state):
        call_many_to_one_driver.jit_merge_point(shapelen=shapelen, func=func,
                                                in_dtypes=in_dtypes, res_dtype=res_dtype, nin=nin)
        for i in range(nin):
            vals[i] = in_dtypes[i].coerce(space, in_iters[i].getitem(in_states[i]))
        w_arglist = space.newlist(vals)
        w_out_val = space.call_args(func, Arguments.frompacked(space, w_arglist))
        out_iter.setitem(out_state, res_dtype.coerce(space, w_out_val))
        for i in range(nin):
            in_states[i] = in_iters[i].next(in_states[i])
        out_state = out_iter.next(out_state)
    return out
  183. call_many_to_many_driver = jit.JitDriver(
  184. name='numpy_call_many_to_many',
  185. greens=['shapelen', 'nin', 'nout', 'func', 'in_dtypes', 'out_dtypes'],
  186. reds='auto')
  187. def call_many_to_many(space, shape, func, in_dtypes, out_dtypes, in_args, out_args):
  188. # out must have been built. func needs no calc_type, is usually an
  189. # external ufunc
  190. nin = len(in_args)
  191. in_iters = [None] * nin
  192. in_states = [None] * nin
  193. nout = len(out_args)
  194. out_iters = [None] * nout
  195. out_states = [None] * nout
  196. for i in range(nin):
  197. in_i = in_args[i]
  198. assert isinstance(in_i, W_NDimArray)
  199. in_iter, in_state = in_i.create_iter(shape)
  200. in_iters[i] = in_iter
  201. in_states[i] = in_state
  202. for i in range(nout):
  203. out_i = out_args[i]
  204. assert isinstance(out_i, W_NDimArray)
  205. out_iter, out_state = out_i.create_iter(shape)
  206. out_iters[i] = out_iter
  207. out_states[i] = out_state
  208. shapelen = len(shape)
  209. vals = [None] * nin
  210. test_iter, test_state = in_iters[-1], in_states[-1]
  211. if nout > 0:
  212. test_iter, test_state = out_iters[0], out_states[0]
  213. while not test_iter.done(test_state):
  214. call_many_to_many_driver.jit_merge_point(shapelen=shapelen, func=func,
  215. in_dtypes=in_dtypes, out_dtypes=out_dtypes,
  216. nin=nin, nout=nout)
  217. for i in range(nin):
  218. vals[i] = in_dtypes[i].coerce(space, in_iters[i].getitem(in_states[i]))
  219. w_arglist = space.newlist(vals)
  220. w_outvals = space.call_args(func, Arguments.frompacked(space, w_arglist))
  221. # w_outvals should be a tuple, but func can return a single value as well
  222. if space.isinstance_w(w_outvals, space.w_tuple):
  223. batch = space.listview(w_outvals)
  224. for i in range(len(batch)):
  225. out_iters[i].setitem(out_states[i], out_dtypes[i].coerce(space, batch[i]))
  226. out_states[i] = out_iters[i].next(out_states[i])
  227. elif nout > 0:
  228. out_iters[0].setitem(out_states[0], out_dtypes[0].coerce(space, w_outvals))
  229. out_states[0] = out_iters[0].next(out_states[0])
  230. for i in range(nin):
  231. in_states[i] = in_iters[i].next(in_states[i])
  232. test_state = test_iter.next(test_state)
  233. return space.newtuple([convert_to_array(space, o) for o in out_args])
  234. setslice_driver = jit.JitDriver(name='numpy_setslice',
  235. greens = ['shapelen', 'dtype'],
  236. reds = 'auto', vectorize=True)
  237. def setslice(space, shape, target, source):
  238. if not shape:
  239. dtype = target.dtype
  240. val = source.getitem(source.start)
  241. if dtype.is_str_or_unicode():
  242. val = dtype.coerce(space, val)
  243. else:
  244. val = val.convert_to(space, dtype)
  245. target.setitem(target.start, val)
  246. return target
  247. return _setslice(space, shape, target, source)
def _setslice(space, shape, target, source):
    """Copy every element of `source` into `target`, both iterated over
    `shape`, converting each value to the target's dtype.  Returns `target`.
    """
    # note that unlike everything else, target and source here are
    # array implementations, not arrays
    target_iter, target_state = target.create_iter(shape)
    source_iter, source_state = source.create_iter(shape)
    source_iter.track_index = False
    dtype = target.dtype
    shapelen = len(shape)
    while not target_iter.done(target_state):
        setslice_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
        val = source_iter.getitem(source_state)
        # string/unicode and record targets go through dtype.coerce; all
        # other dtypes use the box's convert_to
        if dtype.is_str_or_unicode() or dtype.is_record():
            val = dtype.coerce(space, val)
        else:
            val = val.convert_to(space, dtype)
        target_iter.setitem(target_state, val)
        target_state = target_iter.next(target_state)
        source_state = source_iter.next(source_state)
    return target
  267. def split_iter(arr, axis_flags):
  268. """Prepare 2 iterators for nested iteration over `arr`.
  269. Arguments:
  270. arr: instance of BaseConcreteArray
  271. axis_flags: list of bools, one for each dimension of `arr`.The inner
  272. iterator operates over the dimensions for which the flag is True
  273. """
  274. shape = arr.get_shape()
  275. strides = arr.get_strides()
  276. backstrides = arr.get_backstrides()
  277. shapelen = len(shape)
  278. assert len(axis_flags) == shapelen
  279. inner_shape = [-1] * shapelen
  280. inner_strides = [-1] * shapelen
  281. inner_backstrides = [-1] * shapelen
  282. outer_shape = [-1] * shapelen
  283. outer_strides = [-1] * shapelen
  284. outer_backstrides = [-1] * shapelen
  285. for i in range(len(shape)):
  286. if axis_flags[i]:
  287. inner_shape[i] = shape[i]
  288. inner_strides[i] = strides[i]
  289. inner_backstrides[i] = backstrides[i]
  290. outer_shape[i] = 1
  291. outer_strides[i] = 0
  292. outer_backstrides[i] = 0
  293. else:
  294. outer_shape[i] = shape[i]
  295. outer_strides[i] = strides[i]
  296. outer_backstrides[i] = backstrides[i]
  297. inner_shape[i] = 1
  298. inner_strides[i] = 0
  299. inner_backstrides[i] = 0
  300. inner_iter = ArrayIter(arr, support.product(inner_shape),
  301. inner_shape, inner_strides, inner_backstrides)
  302. outer_iter = ArrayIter(arr, support.product(outer_shape),
  303. outer_shape, outer_strides, outer_backstrides)
  304. return inner_iter, outer_iter
reduce_flat_driver = jit.JitDriver(
    name='numpy_reduce_flat',
    greens = ['shapelen', 'func', 'done_func', 'calc_dtype'], reds = 'auto',
    vectorize = True)

def reduce_flat(space, func, w_arr, calc_dtype, done_func, identity):
    """Reduce all elements of `w_arr`, in flat iteration order, to one box.

    Every element is converted to `calc_dtype` and folded in as
    ``cur_value = func(calc_dtype, cur_value, item)``.  If `identity` is
    None, the first element seeds the fold.  The code does not guard
    against an empty array in that case; presumably the caller handles it
    (TODO confirm).

    If `done_func` is given and returns true for an element, that element
    itself is returned immediately (short-circuit), not the partial
    result.
    """
    obj_iter, obj_state = w_arr.create_iter()
    if identity is None:
        cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
        obj_state = obj_iter.next(obj_state)
    else:
        cur_value = identity.convert_to(space, calc_dtype)
    shapelen = len(w_arr.get_shape())
    while not obj_iter.done(obj_state):
        reduce_flat_driver.jit_merge_point(
            shapelen=shapelen, func=func,
            done_func=done_func, calc_dtype=calc_dtype)
        rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
        if done_func is not None and done_func(calc_dtype, rval):
            return rval
        cur_value = func(calc_dtype, cur_value, rval)
        obj_state = obj_iter.next(obj_state)
    return cur_value
  327. reduce_driver = jit.JitDriver(
  328. name='numpy_reduce',
  329. greens=['shapelen', 'func', 'dtype'], reds='auto',
  330. vectorize=True)
  331. def reduce(space, func, w_arr, axis_flags, dtype, out, identity):
  332. out_iter, out_state = out.create_iter()
  333. out_iter.track_index = False
  334. shape = w_arr.get_shape()
  335. shapelen = len(shape)
  336. inner_iter, outer_iter = split_iter(w_arr.implementation, axis_flags)
  337. assert outer_iter.size == out_iter.size
  338. if identity is not None:
  339. identity = identity.convert_to(space, dtype)
  340. outer_state = outer_iter.reset()
  341. while not outer_iter.done(outer_state):
  342. inner_state = inner_iter.reset()
  343. inner_state.offset = outer_state.offset
  344. if identity is not None:
  345. w_val = identity
  346. else:
  347. w_val = inner_iter.getitem(inner_state).convert_to(space, dtype)
  348. inner_state = inner_iter.next(inner_state)
  349. while not inner_iter.done(inner_state):
  350. reduce_driver.jit_merge_point(
  351. shapelen=shapelen, func=func, dtype=dtype)
  352. w_item = inner_iter.getitem(inner_state).convert_to(space, dtype)
  353. w_val = func(dtype, w_item, w_val)
  354. inner_state = inner_iter.next(inner_state)
  355. out_iter.setitem(out_state, w_val)
  356. out_state = out_iter.next(out_state)
  357. outer_state = outer_iter.next(outer_state)
  358. return out
  359. accumulate_flat_driver = jit.JitDriver(
  360. name='numpy_accumulate_flat',
  361. greens=['shapelen', 'func', 'dtype', 'out_dtype'],
  362. reds='auto', vectorize=True)
  363. def accumulate_flat(space, func, w_arr, calc_dtype, w_out, identity):
  364. arr_iter, arr_state = w_arr.create_iter()
  365. out_iter, out_state = w_out.create_iter()
  366. out_iter.track_index = False
  367. if identity is None:
  368. cur_value = arr_iter.getitem(arr_state).convert_to(space, calc_dtype)
  369. out_iter.setitem(out_state, cur_value)
  370. out_state = out_iter.next(out_state)
  371. arr_state = arr_iter.next(arr_state)
  372. else:
  373. cur_value = identity.convert_to(space, calc_dtype)
  374. shapelen = len(w_arr.get_shape())
  375. out_dtype = w_out.get_dtype()
  376. while not arr_iter.done(arr_state):
  377. accumulate_flat_driver.jit_merge_point(
  378. shapelen=shapelen, func=func, dtype=calc_dtype,
  379. out_dtype=out_dtype)
  380. w_item = arr_iter.getitem(arr_state).convert_to(space, calc_dtype)
  381. cur_value = func(calc_dtype, cur_value, w_item)
  382. out_iter.setitem(out_state, out_dtype.coerce(space, cur_value))
  383. out_state = out_iter.next(out_state)
  384. arr_state = arr_iter.next(arr_state)
accumulate_driver = jit.JitDriver(
    name='numpy_accumulate',
    greens=['shapelen', 'func', 'calc_dtype'],
    reds='auto',
    vectorize=True)

def accumulate(space, func, w_arr, axis, calc_dtype, w_out, identity):
    """Running reduction of `w_arr` along `axis`, written into `w_out`.

    `temp` has `w_arr`'s shape with `axis` removed.  It holds, for every
    position of the other axes, the most recent partial result; it is read
    and written through an AxisIter built from the full array shape
    (presumably mapping each full-shape position onto its temp slot —
    TODO confirm).  Whenever the output index along `axis` is 0 a new run
    starts: the element itself (or ``func(identity, element)``) is stored.
    Otherwise ``func(calc_dtype, previous, element)`` is stored.
    Returns `w_out`.
    """
    out_iter, out_state = w_out.create_iter()
    arr_shape = w_arr.get_shape()
    temp_shape = arr_shape[:axis] + arr_shape[axis + 1:]
    temp = W_NDimArray.from_shape(space, temp_shape, calc_dtype, w_instance=w_arr)
    temp_iter = AxisIter(temp.implementation, w_arr.get_shape(), axis)
    temp_state = temp_iter.reset()
    arr_iter, arr_state = w_arr.create_iter()
    arr_iter.track_index = False
    if identity is not None:
        identity = identity.convert_to(space, calc_dtype)
    shapelen = len(arr_shape)
    while not out_iter.done(out_state):
        accumulate_driver.jit_merge_point(shapelen=shapelen, func=func,
                                          calc_dtype=calc_dtype)
        w_item = arr_iter.getitem(arr_state).convert_to(space, calc_dtype)
        arr_state = arr_iter.next(arr_state)

        # out_iter keeps index tracking on (it is not disabled above),
        # which this indices() call relies on
        out_indices = out_iter.indices(out_state)
        if out_indices[axis] == 0:
            if identity is not None:
                w_item = func(calc_dtype, identity, w_item)
        else:
            cur_value = temp_iter.getitem(temp_state)
            w_item = func(calc_dtype, cur_value, w_item)

        out_iter.setitem(out_state, w_item)
        out_state = out_iter.next(out_state)
        # remember the partial result for the next step along `axis`
        temp_iter.setitem(temp_state, w_item)
        temp_state = temp_iter.next(temp_state)
    return w_out
  419. def fill(arr, box):
  420. arr_iter, arr_state = arr.create_iter()
  421. while not arr_iter.done(arr_state):
  422. arr_iter.setitem(arr_state, box)
  423. arr_state = arr_iter.next(arr_state)
  424. def assign(space, arr, seq):
  425. arr_iter, arr_state = arr.create_iter()
  426. arr_dtype = arr.get_dtype()
  427. for item in seq:
  428. arr_iter.setitem(arr_state, arr_dtype.coerce(space, item))
  429. arr_state = arr_iter.next(arr_state)
where_driver = jit.JitDriver(name='numpy_where',
                             greens = ['shapelen', 'dtype', 'arr_dtype'],
                             reds = 'auto',
                             vectorize=True)

def where(space, out, shape, arr, x, y, dtype):
    """For every position in `shape`, store x's element where `arr` is true
    and y's element otherwise.  The selected value is converted to `dtype`
    and stored into `out`, which is returned.
    """
    out_iter, out_state = out.create_iter(shape)
    arr_iter, arr_state = arr.create_iter(shape)
    arr_dtype = arr.get_dtype()
    x_iter, x_state = x.create_iter(shape)
    y_iter, y_state = y.create_iter(shape)
    # Pick the iterator that decides when the loop ends: x if it is not a
    # scalar, else y if it is not a scalar, else arr.
    if x.is_scalar():
        if y.is_scalar():
            iter, state = arr_iter, arr_state
        else:
            iter, state = y_iter, y_state
    else:
        iter, state = x_iter, x_state
    # Index tracking is switched off everywhere except on the driving
    # iterator (presumably its done() check needs it — TODO confirm).
    out_iter.track_index = x_iter.track_index = False
    arr_iter.track_index = y_iter.track_index = False
    iter.track_index = True
    shapelen = len(shape)
    while not iter.done(state):
        where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                     arr_dtype=arr_dtype)
        w_cond = arr_iter.getitem(arr_state)
        if arr_dtype.itemtype.bool(w_cond):
            w_val = x_iter.getitem(x_state).convert_to(space, dtype)
        else:
            w_val = y_iter.getitem(y_state).convert_to(space, dtype)
        out_iter.setitem(out_state, w_val)
        out_state = out_iter.next(out_state)
        arr_state = arr_iter.next(arr_state)
        x_state = x_iter.next(x_state)
        y_state = y_iter.next(y_state)
        # re-read the driving state, using the same choice as above
        if x.is_scalar():
            if y.is_scalar():
                state = arr_state
            else:
                state = y_state
        else:
            state = x_state
    return out
def _new_argmin_argmax(op_name):
    """Build the (axis, flat) implementations of argmin or argmax.

    `op_name` is 'argmin' or 'argmax'.  It is used for the JIT driver names
    and as the name of the comparison method looked up on
    ``dtype.itemtype``.  That method is called as ``op(cur_best, w_val)``.
    When its result is falsy, w_val becomes the new best and its index is
    recorded.  Presumably a true result means "keep cur_best" — TODO
    confirm against the itemtype implementation.
    """
    arg_driver = jit.JitDriver(name='numpy_' + op_name,
                               greens = ['shapelen', 'dtype'],
                               reds = 'auto')
    arg_flat_driver = jit.JitDriver(name='numpy_flat_' + op_name,
                                    greens = ['shapelen', 'dtype'],
                                    reds = 'auto')

    def argmin_argmax(space, w_arr, w_out, axis):
        """Store, for each position of the other axes, the index along
        `axis` of the best element into `w_out` (boxed as the long dtype).
        Returns `w_out`."""
        # local import; presumably avoids an import cycle — TODO confirm
        from pypy.module.micronumpy.descriptor import get_dtype_cache
        dtype = w_arr.get_dtype()
        shapelen = len(w_arr.get_shape())
        # only `axis` goes to the inner iterator
        axis_flags = [False] * shapelen
        axis_flags[axis] = True
        inner_iter, outer_iter = split_iter(w_arr.implementation, axis_flags)
        outer_state = outer_iter.reset()
        out_iter, out_state = w_out.create_iter()
        while not outer_iter.done(outer_state):
            inner_state = inner_iter.reset()
            inner_state.offset = outer_state.offset
            # the first element along `axis` is the initial best (index 0)
            cur_best = inner_iter.getitem(inner_state)
            inner_state = inner_iter.next(inner_state)
            result = 0
            idx = 1
            while not inner_iter.done(inner_state):
                arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
                w_val = inner_iter.getitem(inner_state)
                old_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
                if not old_best:
                    result = idx
                    cur_best = w_val
                inner_state = inner_iter.next(inner_state)
                idx += 1
            result = get_dtype_cache(space).w_longdtype.box(result)
            out_iter.setitem(out_state, result)
            out_state = out_iter.next(out_state)
            outer_state = outer_iter.next(outer_state)
        return w_out

    def argmin_argmax_flat(w_arr):
        """Return the flat (iteration-order) index of the best element of
        `w_arr`, as a plain int."""
        result = 0
        idx = 1
        dtype = w_arr.get_dtype()
        iter, state = w_arr.create_iter()
        cur_best = iter.getitem(state)
        state = iter.next(state)
        shapelen = len(w_arr.get_shape())
        while not iter.done(state):
            arg_flat_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
            w_val = iter.getitem(state)
            old_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
            if not old_best:
                result = idx
                cur_best = w_val
            state = iter.next(state)
            idx += 1
        return result

    return argmin_argmax, argmin_argmax_flat

argmin, argmin_flat = _new_argmin_argmax('argmin')
argmax, argmax_flat = _new_argmin_argmax('argmax')
dot_driver = jit.JitDriver(name = 'numpy_dot',
                           greens = ['dtype'],
                           reds = 'auto',
                           vectorize=True)

def multidim_dot(space, left, right, result, dtype, right_critical_dim):
    ''' assumes left, right are concrete arrays
    given left.shape == [3, 5, 7],
    right.shape == [2, 7, 4]
    then
    result.shape == [3, 5, 2, 4]
    broadcast shape should be [3, 5, 2, 7, 4]
    result should skip dims 3 which is len(result_shape) - 1
    (note that if right is 1d, result should
    skip len(result_shape))
    left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
    except where it==(right.ndims-2)
    right should skip 0, 1

    The sum is taken over the last axis of `left` and axis
    `right_critical_dim` of `right`.  For every (left position, right
    position) pair, in that nested order, the products are added to the
    value already stored at the current result element (read with
    outi.getitem).  `result` therefore has to be initialised beforehand,
    presumably to zeros by the caller (TODO confirm).  Returns `result`.
    '''
    left_shape = left.get_shape()
    right_shape = right.get_shape()
    left_impl = left.implementation
    right_impl = right.implementation
    assert left_shape[-1] == right_shape[right_critical_dim]
    assert result.get_dtype() == dtype
    outi, outs = result.create_iter()
    outi.track_index = False
    # iterate over every axis except the summed one
    lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
    righti = AllButAxisIter(right_impl, right_critical_dim)
    lefts = lefti.reset()
    rights = righti.reset()
    n = left_impl.shape[-1]
    s1 = left_impl.strides[-1]
    s2 = right_impl.strides[right_critical_dim]
    while not lefti.done(lefts):
        while not righti.done(rights):
            oval = outi.getitem(outs)
            i1 = lefts.offset
            i2 = rights.offset
            i = 0
            while i < n:
                i += 1
                dot_driver.jit_merge_point(dtype=dtype)
                lval = left_impl.getitem(i1).convert_to(space, dtype)
                rval = right_impl.getitem(i2).convert_to(space, dtype)
                oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
                # strides along the summed axes, promoted to JIT constants
                i1 += jit.promote(s1)
                i2 += jit.promote(s2)
            outi.setitem(outs, oval)
            outs = outi.next(outs)
            rights = righti.next(rights)
        # rewind `right` for the next left position
        rights = righti.reset(rights)
        lefts = lefti.next(lefts)
    return result
count_all_true_driver = jit.JitDriver(name = 'numpy_count',
                                      greens = ['shapelen', 'dtype'],
                                      reds = 'auto',
                                      vectorize=True)

def count_all_true_concrete(impl):
    """Return how many elements of the array implementation `impl` are
    true, according to its iterator's getitem_bool()."""
    s = 0
    iter, state = impl.create_iter()
    shapelen = len(impl.shape)
    dtype = impl.dtype
    while not iter.done(state):
        count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
        # a bool result adds 0 or 1
        s += iter.getitem_bool(state)
        state = iter.next(state)
    return s
  597. def count_all_true(arr):
  598. if arr.is_scalar():
  599. return arr.get_dtype().itemtype.bool(arr.get_scalar_value())
  600. else:
  601. return count_all_true_concrete(arr.implementation)
nonzero_driver = jit.JitDriver(name = 'numpy_nonzero',
                               greens = ['shapelen', 'dims', 'dtype'],
                               reds = 'auto',
                               vectorize=True)

def nonzero(res, arr, box):
    """Write the indices of the true elements of `arr` into `res`.

    Elements are visited in iteration order.  For each true element, its
    full index (one value per dimension) is written into consecutive
    elements of `res`; each value is wrapped with ``box(int)``.  `res` must
    already be large enough; presumably the caller sizes it as
    (number of true elements) * ndim — TODO confirm.  Returns `res`.
    """
    res_iter, res_state = res.create_iter()
    arr_iter, arr_state = arr.create_iter()
    shapelen = len(arr.shape)
    dtype = arr.dtype
    dims = range(shapelen)
    while not arr_iter.done(arr_state):
        nonzero_driver.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype)
        if arr_iter.getitem_bool(arr_state):
            arr_indices = arr_iter.indices(arr_state)
            for d in dims:
                res_iter.setitem(res_state, box(arr_indices[d]))
                res_state = res_iter.next(res_state)
        arr_state = arr_iter.next(arr_state)
    return res
getitem_filter_driver = jit.JitDriver(name = 'numpy_getitem_bool',
                                      greens = ['shapelen', 'arr_dtype',
                                                'index_dtype'],
                                      reds = 'auto',
                                      vectorize=True)

def getitem_filter(res, arr, index):
    """Boolean-mask indexing: copy the elements of `arr` whose position in
    `index` is true into `res`, in iteration order.  Returns `res`.
    """
    res_iter, res_state = res.create_iter()
    shapelen = len(arr.get_shape())
    # a mask with fewer than 2 dimensions applied to a multi-dimensional
    # array is iterated over arr's shape with backward_broadcast=True
    if shapelen > 1 and len(index.get_shape()) < 2:
        index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
    else:
        index_iter, index_state = index.create_iter()
    arr_iter, arr_state = arr.create_iter()
    arr_dtype = arr.get_dtype()
    index_dtype = index.get_dtype()
    # support the deprecated form where arr([True]) will return arr[0, ...]
    # by iterating over res_iter, not index_iter
    while not res_iter.done(res_state):
        getitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                              index_dtype=index_dtype,
                                              arr_dtype=arr_dtype,
                                              )
        if index_iter.getitem_bool(index_state):
            res_iter.setitem(res_state, arr_iter.getitem(arr_state))
            res_state = res_iter.next(res_state)
        index_state = index_iter.next(index_state)
        arr_state = arr_iter.next(arr_state)
    return res
setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
                                      greens = ['shapelen', 'arr_dtype',
                                                'index_dtype'],
                                      reds = 'auto',
                                      vectorize=True)

def setitem_filter(space, arr, index, value):
    """Boolean-mask assignment: for every true position in `index`, store
    the next element of `value` (coerced to arr's dtype) into `arr`.

    A `value` of size 1 is iterated over arr's shape, so its single element
    is reused for every true position.  Otherwise `value` is consumed one
    element per true position.  Its length is not checked here.
    """
    arr_iter, arr_state = arr.create_iter()
    shapelen = len(arr.get_shape())
    if shapelen > 1 and len(index.get_shape()) < 2:
        index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
    else:
        index_iter, index_state = index.create_iter()
    if value.get_size() == 1:
        value_iter, value_state = value.create_iter(arr.get_shape())
    else:
        value_iter, value_state = value.create_iter()
    index_dtype = index.get_dtype()
    arr_dtype = arr.get_dtype()
    while not index_iter.done(index_state):
        setitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                              index_dtype=index_dtype,
                                              arr_dtype=arr_dtype,
                                              )
        if index_iter.getitem_bool(index_state):
            val = arr_dtype.coerce(space, value_iter.getitem(value_state))
            # value only advances on true positions
            value_state = value_iter.next(value_state)
            arr_iter.setitem(arr_state, val)
        arr_state = arr_iter.next(arr_state)
        index_state = index_iter.next(index_state)
flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
                                        greens = ['dtype'],
                                        reds = 'auto',
                                        vectorize=True)

def flatiter_getitem(res, base_iter, base_state, step):
    """Fill `res` with elements of the base array, starting at `base_state`.

    After each element the base iterator jumps ahead by `step` flat
    positions (via goto).  Returns `res`.
    """
    ri, rs = res.create_iter()
    dtype = res.get_dtype()
    while not ri.done(rs):
        flatiter_getitem_driver.jit_merge_point(dtype=dtype)
        ri.setitem(rs, base_iter.getitem(base_state))
        base_state = base_iter.goto(base_state.index + step)
        rs = ri.next(rs)
    return res
flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
                                        greens = ['dtype'],
                                        reds = 'auto',
                                        vectorize=True)

def flatiter_setitem(space, dtype, val, arr_iter, arr_state, step, length):
    """Store `length` values into the target array.

    Writing starts at `arr_state` and jumps ahead `step` flat positions
    after each store.  Values are taken from the array `val`, which is
    restarted from its beginning whenever it runs out (cyclic reuse).  Each
    value is converted to `dtype`.
    """
    val_iter, val_state = val.create_iter()
    while length > 0:
        flatiter_setitem_driver.jit_merge_point(dtype=dtype)
        # note: rebinds the `val` parameter; val_iter was already created
        val = val_iter.getitem(val_state)
        if dtype.is_str_or_unicode():
            val = dtype.coerce(space, val)
        else:
            val = val.convert_to(space, dtype)
        arr_iter.setitem(arr_state, val)
        arr_state = arr_iter.goto(arr_state.index + step)
        val_state = val_iter.next(val_state)
        if val_iter.done(val_state):
            val_state = val_iter.reset(val_state)
        length -= 1
fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
                                  greens = ['itemsize', 'dtype'],
                                  reds = 'auto')

def fromstring_loop(space, a, dtype, itemsize, s):
    """Fill array `a` from the byte string `s`.

    Element number i is unpacked with ``dtype.runpack_str`` from
    ``s[i*itemsize : (i+1)*itemsize]``.  The loop runs until `a`'s iterator
    is exhausted; `s` is assumed to be long enough (not checked here).
    """
    i = 0
    ai, state = a.create_iter()
    while not ai.done(state):
        fromstring_driver.jit_merge_point(dtype=dtype, itemsize=itemsize)
        sub = s[i*itemsize:i*itemsize + itemsize]
        val = dtype.runpack_str(space, sub)
        ai.setitem(state, val)
        state = ai.next(state)
        i += 1
  723. def tostring(space, arr):
  724. builder = StringBuilder()
  725. iter, state = arr.create_iter()
  726. w_res_str = W_NDimArray.from_shape(space, [1], arr.get_dtype())
  727. itemsize = arr.get_dtype().elsize
  728. with w_res_str.implementation as storage:
  729. res_str_casted = rffi.cast(rffi.CArrayPtr(lltype.Char),
  730. support.get_storage_as_int(storage))
  731. while not iter.done(state):
  732. w_res_str.implementation.setitem(0, iter.getitem(state))
  733. for i in range(itemsize):
  734. builder.append(res_str_casted[i])
  735. state = iter.next(state)
  736. return builder.build()
getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
                                   greens = ['shapelen', 'indexlen',
                                             'prefixlen', 'dtype'],
                                   reds = 'auto')

def getitem_array_int(space, arr, res, iter_shape, indexes_w, prefix_w):
    """Integer-array ("fancy") indexing: fill `res` from `arr`.

    Iterates over `iter_shape` with a PureShapeIter.  At each step an index
    tuple is built: where the iterator holds an index iterator
    (``idx_w_i[i]`` is not None), the current element of that index array
    is used; otherwise ``indexes_w[i]`` is used unchanged.  ``arr[index]``
    is then stored at ``res[prefix + current iteration index]``.
    Returns `res`.
    """
    shapelen = len(iter_shape)
    prefixlen = len(prefix_w)
    indexlen = len(indexes_w)
    dtype = arr.get_dtype()
    iter = PureShapeIter(iter_shape, indexes_w)
    while not iter.done():
        getitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen,
                                           dtype=dtype, prefixlen=prefixlen)
        # prepare the index
        index_w = [None] * indexlen
        for i in range(indexlen):
            if iter.idx_w_i[i] is not None:
                index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
            else:
                index_w[i] = indexes_w[i]
        # prefix_w[:prefixlen] copies the whole prefix (prefixlen is its
        # length); presumably to get a fresh list for the concatenation
        res.descr_setitem(space, space.newtuple(prefix_w[:prefixlen] +
                                                iter.get_index(space, shapelen)),
                          arr.descr_getitem(space, space.newtuple(index_w)))
        iter.next()
    return res
  762. setitem_int_driver = jit.JitDriver(name = 'numpy_setitem_int',
  763. greens = ['shapelen', 'indexlen',
  764. 'prefixlen', 'dtype'],
  765. reds = 'auto')
  766. def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr,
  767. prefix_w):
  768. shapelen = len(iter_shape)
  769. indexlen = len(indexes_w)
  770. prefixlen = len(prefix_w)
  771. dtype = arr.get_dtype()
  772. iter = PureShapeIter(iter_shape, indexes_w)
  773. while not iter.done():
  774. setitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen,
  775. dtype=dtype, prefixlen=prefixlen)
  776. # prepare the index
  777. index_w = [None] * indexlen
  778. for i in range(indexlen):
  779. if iter.idx_w_i[i] is not None:
  780. index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
  781. else:
  782. index_w[i] = indexes_w[i]
  783. w_idx = space.newtuple(prefix_w[:prefixlen] + iter.get_index(space,
  784. shapelen))
  785. if val_arr.is_scalar():
  786. w_value = val_arr.get_scalar_value()
  787. else:
  788. w_value = val_arr.descr_getitem(space, w_idx)
  789. arr.descr_setitem(space, space.newtuple(index_w), w_value)
  790. iter.next()
  791. byteswap_driver = jit.JitDriver(name='numpy_byteswap_driver',
  792. greens = ['dtype'],
  793. reds = 'auto',
  794. vectorize=True)
  795. def byteswap(from_, to):
  796. dtype = from_.dtype
  797. from_iter, from_state = from_.create_iter()
  798. to_iter, to_state = to.create_iter()
  799. while not from_iter.done(from_state):
  800. byteswap_driver.jit_merge_point(dtype=dtype)
  801. val = dtype.itemtype.byteswap(from_iter.getitem(from_state))
  802. to_iter.setitem(to_state, val)
  803. to_state = to_iter.next(to_state)
  804. from_state = from_iter.next(from_state)
  805. choose_driver = jit.JitDriver(name='numpy_choose_driver',
  806. greens = ['shapelen', 'mode', 'dtype'],
  807. reds = 'auto',
  808. vectorize=True)
  809. def choose(space, arr, choices, shape, dtype, out, mode):
  810. shapelen = len(shape)
  811. pairs = [a.create_iter(shape) for a in choices]
  812. iterators = [i[0] for i in pairs]
  813. states = [i[1] for i in pairs]
  814. arr_iter, arr_state = arr.create_iter(shape)
  815. out_iter, out_state = out.create_iter(shape)
  816. while not arr_iter.done(arr_state):
  817. choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
  818. mode=mode)
  819. index = support.index_w(space, arr_iter.getitem(arr_state))
  820. if index < 0 or index >= len(iterators):
  821. if mode == NPY.RAISE:
  822. raise oefmt(space.w_ValueError,
  823. "invalid entry in choice array")
  824. elif mode == NPY.WRAP:
  825. index = index % (len(iterators))
  826. else:
  827. assert mode == NPY.CLIP
  828. if index < 0:
  829. index = 0
  830. else:
  831. index = len(iterators) - 1
  832. val = iterators[index].getitem(states[index]).convert_to(space, dtype)
  833. out_iter.setitem(out_state, val)
  834. for i in range(len(iterators)):
  835. states[i] = iterators[i].next(states[i])
  836. out_state = out_iter.next(out_state)
  837. arr_state = arr_iter.next(arr_state)
  838. clip_driver = jit.JitDriver(name='numpy_clip_driver',
  839. greens = ['shapelen', 'dtype'],
  840. reds = 'auto',
  841. vectorize=True)
  842. def clip(space, arr, shape, min, max, out):
  843. assert min or max
  844. arr_iter, arr_state = arr.create_iter(shape)
  845. if min is not None:
  846. min_iter, min_state = min.create_iter(shape)
  847. else:
  848. min_iter, min_state = None, None
  849. if max is not None:
  850. max_iter, max_state = max.create_iter(shape)
  851. else:
  852. max_iter, max_state = None, None
  853. out_iter, out_state = out.create_iter(shape)
  854. shapelen = len(shape)
  855. dtype = out.get_dtype()
  856. while not arr_iter.done(arr_state):
  857. clip_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
  858. w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
  859. arr_state = arr_iter.next(arr_state)
  860. if min_iter is not None:
  861. w_min = min_iter.getitem(min_state).convert_to(space, dtype)
  862. if dtype.itemtype.lt(w_v, w_min):
  863. w_v = w_min
  864. min_state = min_iter.next(min_state)
  865. if max_iter is not None:
  866. w_max = max_iter.getitem(max_state).convert_to(space, dtype)
  867. if dtype.itemtype.gt(w_v, w_max):
  868. w_v = w_max
  869. max_state = max_iter.next(max_state)
  870. out_iter.setitem(out_state, w_v)
  871. out_state = out_iter.next(out_state)
  872. round_driver = jit.JitDriver(name='numpy_round_driver',
  873. greens = ['shapelen', 'dtype'],
  874. reds = 'auto',
  875. vectorize=True)
  876. def round(space, arr, dtype, shape, decimals, out):
  877. arr_iter, arr_state = arr.create_iter(shape)
  878. out_iter, out_state = out.create_iter(shape)
  879. shapelen = len(shape)
  880. while not arr_iter.done(arr_state):
  881. round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
  882. w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
  883. w_v = dtype.itemtype.round(w_v, decimals)
  884. out_iter.setitem(out_state, w_v)
  885. arr_state = arr_iter.next(arr_state)
  886. out_state = out_iter.next(out_state)
  887. diagonal_simple_driver = jit.JitDriver(name='numpy_diagonal_simple_driver',
  888. greens = ['axis1', 'axis2'],
  889. reds = 'auto')
  890. def diagonal_simple(space, arr, out, offset, axis1, axis2, size):
  891. out_iter, out_state = out.create_iter()
  892. i = 0
  893. index = [0] * 2
  894. while i < size:
  895. diagonal_simple_driver.jit_merge_point(axis1=axis1, axis2=axis2)
  896. index[axis1] = i
  897. index[axis2] = i + offset
  898. out_iter.setitem(out_state, arr.getitem_index(space, index))
  899. i += 1
  900. out_state = out_iter.next(out_state)
  901. def diagonal_array(space, arr, out, offset, axis1, axis2, shape):
  902. out_iter, out_state = out.create_iter()
  903. iter = PureShapeIter(shape, [])
  904. shapelen_minus_1 = len(shape) - 1
  905. assert shapelen_minus_1 >= 0
  906. if axis1 < axis2:
  907. a = axis1
  908. b = axis2 - 1
  909. else:
  910. a = axis2
  911. b = axis1 - 1
  912. assert a >= 0
  913. assert b >= 0
  914. while not iter.done():
  915. last_index = iter.indexes[-1]
  916. if axis1 < axis2:
  917. indexes = (iter.indexes[:a] + [last_index] +
  918. iter.indexes[a:b] + [last_index + offset] +
  919. iter.indexes[b:shapelen_minus_1])
  920. else:
  921. indexes = (iter.indexes[:a] + [last_index + offset] +
  922. iter.indexes[a:b] + [last_index] +
  923. iter.indexes[b:shapelen_minus_1])
  924. out_iter.setitem(out_state, arr.getitem_index(space, indexes))
  925. iter.next()
  926. out_state = out_iter.next(out_state)
  927. def _new_binsearch(side, op_name):
  928. binsearch_driver = jit.JitDriver(name='numpy_binsearch_' + side,
  929. greens=['dtype'],
  930. reds='auto')
  931. def binsearch(space, arr, key, ret):
  932. assert len(arr.get_shape()) == 1
  933. dtype = key.get_dtype()
  934. op = getattr(dtype.itemtype, op_name)
  935. key_iter, key_state = key.create_iter()
  936. ret_iter, ret_state = ret.create_iter()
  937. ret_iter.track_index = False
  938. size = arr.get_size()
  939. min_idx = 0
  940. max_idx = size
  941. last_key_val = key_iter.getitem(key_state)
  942. while not key_iter.done(key_state):
  943. key_val = key_iter.getitem(key_state)
  944. if dtype.itemtype.lt(last_key_val, key_val):
  945. max_idx = size
  946. else:
  947. min_idx = 0
  948. max_idx = max_idx + 1 if max_idx < size else size
  949. last_key_val = key_val
  950. while min_idx < max_idx:
  951. binsearch_driver.jit_merge_point(dtype=dtype)
  952. mid_idx = min_idx + ((max_idx - min_idx) >> 1)
  953. mid_val = arr.getitem(space, [mid_idx]).convert_to(space, dtype)
  954. if op(mid_val, key_val):
  955. min_idx = mid_idx + 1
  956. else:
  957. max_idx = mid_idx
  958. ret_iter.setitem(ret_state, ret.get_dtype().box(min_idx))
  959. ret_state = ret_iter.next(ret_state)
  960. key_state = key_iter.next(key_state)
  961. return binsearch
  962. binsearch_left = _new_binsearch('left', 'lt')
  963. binsearch_right = _new_binsearch('right', 'le')