PageRenderTime 50ms CodeModel.GetById 16ms RepoModel.GetById 0ms app.codeStats 0ms

/pypy/module/micronumpy/compile.py

https://bitbucket.org/timfel/pypy
Python | 1116 lines | 1082 code | 28 blank | 6 comment | 18 complexity | 53f595ff68266fa111c8e449b9d24c8f MD5 | raw file
Possible License(s): Apache-2.0, AGPL-3.0, BSD-3-Clause
  1. """ This is a set of tools for standalone compiling of numpy expressions.
  2. It should not be imported by the module itself
  3. """
  4. import re
  5. import py
  6. from pypy.interpreter import special
  7. from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root, ObjSpace
  8. from pypy.interpreter.error import oefmt
  9. from rpython.rlib.objectmodel import specialize, instantiate
  10. from rpython.rlib.nonconst import NonConstant
  11. from rpython.rlib.rarithmetic import base_int
  12. from pypy.module.micronumpy import boxes, ufuncs
  13. from pypy.module.micronumpy.arrayops import where
  14. from pypy.module.micronumpy.ndarray import W_NDimArray
  15. from pypy.module.micronumpy.ctors import array
  16. from pypy.module.micronumpy.descriptor import get_dtype_cache
  17. from pypy.interpreter.miscutils import ThreadLocals, make_weak_value_dictionary
  18. from pypy.interpreter.executioncontext import (ExecutionContext, ActionFlag,
  19. UserDelAction)
  20. from pypy.interpreter.pyframe import PyFrame
  21. class BogusBytecode(Exception):
  22. pass
  23. class ArgumentMismatch(Exception):
  24. pass
  25. class ArgumentNotAnArray(Exception):
  26. pass
  27. class WrongFunctionName(Exception):
  28. pass
  29. class TokenizerError(Exception):
  30. pass
  31. class BadToken(Exception):
  32. pass
  33. SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
  34. "unegative", "flat", "tostring", "count_nonzero",
  35. "argsort", "cumsum", "logical_xor_reduce"]
  36. TWO_ARG_FUNCTIONS = ["dot", 'take', 'searchsorted', 'multiply']
  37. TWO_ARG_FUNCTIONS_OR_NONE = ['view', 'astype', 'reshape']
  38. THREE_ARG_FUNCTIONS = ['where']
  39. class W_TypeObject(W_Root):
  40. def __init__(self, name):
  41. self.name = name
  42. def lookup(self, name):
  43. return self.getdictvalue(self, name)
  44. def getname(self, space):
  45. return self.name
  46. class FakeSpace(ObjSpace):
  47. w_ValueError = W_TypeObject("ValueError")
  48. w_TypeError = W_TypeObject("TypeError")
  49. w_IndexError = W_TypeObject("IndexError")
  50. w_OverflowError = W_TypeObject("OverflowError")
  51. w_NotImplementedError = W_TypeObject("NotImplementedError")
  52. w_AttributeError = W_TypeObject("AttributeError")
  53. w_StopIteration = W_TypeObject("StopIteration")
  54. w_KeyError = W_TypeObject("KeyError")
  55. w_SystemExit = W_TypeObject("SystemExit")
  56. w_KeyboardInterrupt = W_TypeObject("KeyboardInterrupt")
  57. w_RuntimeError = W_TypeObject("RuntimeError")
  58. w_RecursionError = W_TypeObject("RecursionError") # py3.5
  59. w_VisibleDeprecationWarning = W_TypeObject("VisibleDeprecationWarning")
  60. w_None = W_Root()
  61. w_bool = W_TypeObject("bool")
  62. w_int = W_TypeObject("int")
  63. w_float = W_TypeObject("float")
  64. w_list = W_TypeObject("list")
  65. w_long = W_TypeObject("long")
  66. w_tuple = W_TypeObject('tuple')
  67. w_slice = W_TypeObject("slice")
  68. w_bytes = W_TypeObject("str")
  69. w_text = w_bytes
  70. w_unicode = W_TypeObject("unicode")
  71. w_complex = W_TypeObject("complex")
  72. w_dict = W_TypeObject("dict")
  73. w_object = W_TypeObject("object")
  74. w_buffer = W_TypeObject("buffer")
  75. w_type = W_TypeObject("type")
  76. def __init__(self, config=None):
  77. """NOT_RPYTHON"""
  78. self.fromcache = InternalSpaceCache(self).getorbuild
  79. self.w_Ellipsis = special.Ellipsis()
  80. self.w_NotImplemented = special.NotImplemented()
  81. if config is None:
  82. from pypy.config.pypyoption import get_pypy_config
  83. config = get_pypy_config(translating=False)
  84. self.config = config
  85. self.interned_strings = make_weak_value_dictionary(self, str, W_Root)
  86. self.builtin = DictObject({})
  87. self.FrameClass = PyFrame
  88. self.threadlocals = ThreadLocals()
  89. self.actionflag = ActionFlag() # changed by the signal module
  90. self.check_signal_action = None # changed by the signal module
  91. def _freeze_(self):
  92. return True
  93. def is_none(self, w_obj):
  94. return w_obj is None or w_obj is self.w_None
  95. def issequence_w(self, w_obj):
  96. return isinstance(w_obj, ListObject) or isinstance(w_obj, W_NDimArray)
  97. def len(self, w_obj):
  98. if isinstance(w_obj, ListObject):
  99. return self.wrap(len(w_obj.items))
  100. elif isinstance(w_obj, DictObject):
  101. return self.wrap(len(w_obj.items))
  102. raise NotImplementedError
  103. def getattr(self, w_obj, w_attr):
  104. assert isinstance(w_attr, StringObject)
  105. if isinstance(w_obj, DictObject):
  106. return w_obj.getdictvalue(self, w_attr)
  107. return None
  108. def issubtype_w(self, w_sub, w_type):
  109. is_root(w_type)
  110. return NonConstant(True)
  111. def isinstance_w(self, w_obj, w_tp):
  112. try:
  113. return w_obj.tp == w_tp
  114. except AttributeError:
  115. return False
  116. def iter(self, w_iter):
  117. if isinstance(w_iter, ListObject):
  118. raise NotImplementedError
  119. #return IterObject(space, w_iter.items)
  120. elif isinstance(w_iter, DictObject):
  121. return IterDictObject(self, w_iter)
  122. def next(self, w_iter):
  123. return w_iter.next()
  124. def contains(self, w_iter, w_key):
  125. if isinstance(w_iter, DictObject):
  126. return self.wrap(w_key in w_iter.items)
  127. raise NotImplementedError
  128. def decode_index4(self, w_idx, size):
  129. if isinstance(w_idx, IntObject):
  130. return (self.int_w(w_idx), 0, 0, 1)
  131. else:
  132. assert isinstance(w_idx, SliceObject)
  133. start, stop, step = w_idx.start, w_idx.stop, w_idx.step
  134. if step == 0:
  135. return (0, size, 1, size)
  136. if start < 0:
  137. start += size
  138. if stop < 0:
  139. stop += size + 1
  140. if step < 0:
  141. start, stop = stop, start
  142. start -= 1
  143. stop -= 1
  144. lgt = (stop - start + 1) / step + 1
  145. else:
  146. lgt = (stop - start - 1) / step + 1
  147. return (start, stop, step, lgt)
  148. def unicode_from_object(self, w_item):
  149. # XXX
  150. return StringObject("")
  151. @specialize.argtype(1)
  152. def wrap(self, obj):
  153. if isinstance(obj, float):
  154. return FloatObject(obj)
  155. elif isinstance(obj, bool):
  156. return BoolObject(obj)
  157. elif isinstance(obj, int):
  158. return IntObject(obj)
  159. elif isinstance(obj, base_int):
  160. return LongObject(obj)
  161. elif isinstance(obj, W_Root):
  162. return obj
  163. elif isinstance(obj, str):
  164. return StringObject(obj)
  165. raise NotImplementedError
  166. def newtext(self, obj):
  167. return StringObject(obj)
  168. newbytes = newtext
  169. def newunicode(self, obj):
  170. raise NotImplementedError
  171. def newlist(self, items):
  172. return ListObject(items)
  173. def newcomplex(self, r, i):
  174. return ComplexObject(r, i)
  175. def newfloat(self, f):
  176. return FloatObject(f)
  177. def newslice(self, start, stop, step):
  178. return SliceObject(self.int_w(start), self.int_w(stop),
  179. self.int_w(step))
  180. def le(self, w_obj1, w_obj2):
  181. assert isinstance(w_obj1, boxes.W_GenericBox)
  182. assert isinstance(w_obj2, boxes.W_GenericBox)
  183. return w_obj1.descr_le(self, w_obj2)
  184. def lt(self, w_obj1, w_obj2):
  185. assert isinstance(w_obj1, boxes.W_GenericBox)
  186. assert isinstance(w_obj2, boxes.W_GenericBox)
  187. return w_obj1.descr_lt(self, w_obj2)
  188. def ge(self, w_obj1, w_obj2):
  189. assert isinstance(w_obj1, boxes.W_GenericBox)
  190. assert isinstance(w_obj2, boxes.W_GenericBox)
  191. return w_obj1.descr_ge(self, w_obj2)
  192. def add(self, w_obj1, w_obj2):
  193. assert isinstance(w_obj1, boxes.W_GenericBox)
  194. assert isinstance(w_obj2, boxes.W_GenericBox)
  195. return w_obj1.descr_add(self, w_obj2)
  196. def sub(self, w_obj1, w_obj2):
  197. return self.wrap(1)
  198. def mul(self, w_obj1, w_obj2):
  199. assert isinstance(w_obj1, boxes.W_GenericBox)
  200. assert isinstance(w_obj2, boxes.W_GenericBox)
  201. return w_obj1.descr_mul(self, w_obj2)
  202. def pow(self, w_obj1, w_obj2, _):
  203. return self.wrap(1)
  204. def neg(self, w_obj1):
  205. return self.wrap(0)
  206. def repr(self, w_obj1):
  207. return self.wrap('fake')
  208. def getitem(self, obj, index):
  209. if isinstance(obj, DictObject):
  210. w_dict = obj.getdict(self)
  211. if w_dict is not None:
  212. try:
  213. return w_dict[index]
  214. except KeyError as e:
  215. raise oefmt(self.w_KeyError, "key error")
  216. assert isinstance(obj, ListObject)
  217. assert isinstance(index, IntObject)
  218. return obj.items[index.intval]
  219. def listview(self, obj, number=-1):
  220. assert isinstance(obj, ListObject)
  221. if number != -1:
  222. assert number == 2
  223. return [obj.items[0], obj.items[1]]
  224. return obj.items
  225. fixedview = listview
  226. def float(self, w_obj):
  227. if isinstance(w_obj, FloatObject):
  228. return w_obj
  229. assert isinstance(w_obj, boxes.W_GenericBox)
  230. return self.float(w_obj.descr_float(self))
  231. def float_w(self, w_obj, allow_conversion=True):
  232. assert isinstance(w_obj, FloatObject)
  233. return w_obj.floatval
  234. def int_w(self, w_obj, allow_conversion=True):
  235. if isinstance(w_obj, IntObject):
  236. return w_obj.intval
  237. elif isinstance(w_obj, FloatObject):
  238. return int(w_obj.floatval)
  239. elif isinstance(w_obj, SliceObject):
  240. raise oefmt(self.w_TypeError, "slice.")
  241. raise NotImplementedError
  242. def unpackcomplex(self, w_obj):
  243. if isinstance(w_obj, ComplexObject):
  244. return w_obj.r, w_obj.i
  245. raise NotImplementedError
  246. def index(self, w_obj):
  247. return self.wrap(self.int_w(w_obj))
  248. def bytes_w(self, w_obj):
  249. if isinstance(w_obj, StringObject):
  250. return w_obj.v
  251. raise NotImplementedError
  252. text_w = bytes_w
  253. def unicode_w(self, w_obj):
  254. # XXX
  255. if isinstance(w_obj, StringObject):
  256. return unicode(w_obj.v)
  257. raise NotImplementedError
  258. def int(self, w_obj):
  259. if isinstance(w_obj, IntObject):
  260. return w_obj
  261. assert isinstance(w_obj, boxes.W_GenericBox)
  262. return self.int(w_obj.descr_int(self))
  263. def long(self, w_obj):
  264. if isinstance(w_obj, LongObject):
  265. return w_obj
  266. assert isinstance(w_obj, boxes.W_GenericBox)
  267. return self.int(w_obj.descr_long(self))
  268. def str(self, w_obj):
  269. if isinstance(w_obj, StringObject):
  270. return w_obj
  271. assert isinstance(w_obj, boxes.W_GenericBox)
  272. return self.str(w_obj.descr_str(self))
  273. def is_true(self, w_obj):
  274. assert isinstance(w_obj, BoolObject)
  275. return bool(w_obj.intval)
  276. def gt(self, w_lhs, w_rhs):
  277. return BoolObject(self.int_w(w_lhs) > self.int_w(w_rhs))
  278. def lt(self, w_lhs, w_rhs):
  279. return BoolObject(self.int_w(w_lhs) < self.int_w(w_rhs))
  280. def is_w(self, w_obj, w_what):
  281. return w_obj is w_what
  282. def eq_w(self, w_obj, w_what):
  283. return w_obj == w_what
  284. def issubtype(self, w_type1, w_type2):
  285. return BoolObject(True)
  286. def type(self, w_obj):
  287. if self.is_none(w_obj):
  288. return self.w_None
  289. try:
  290. return w_obj.tp
  291. except AttributeError:
  292. if isinstance(w_obj, W_NDimArray):
  293. return W_NDimArray
  294. return self.w_None
  295. def lookup(self, w_obj, name):
  296. w_type = self.type(w_obj)
  297. if not self.is_none(w_type):
  298. return w_type.lookup(name)
  299. def gettypefor(self, w_obj):
  300. return W_TypeObject(w_obj.typedef.name)
  301. def call_function(self, tp, w_dtype, *args):
  302. if tp is self.w_float:
  303. if isinstance(w_dtype, boxes.W_Float64Box):
  304. return FloatObject(float(w_dtype.value))
  305. if isinstance(w_dtype, boxes.W_Float32Box):
  306. return FloatObject(float(w_dtype.value))
  307. if isinstance(w_dtype, boxes.W_Int64Box):
  308. return FloatObject(float(int(w_dtype.value)))
  309. if isinstance(w_dtype, boxes.W_Int32Box):
  310. return FloatObject(float(int(w_dtype.value)))
  311. if isinstance(w_dtype, boxes.W_Int16Box):
  312. return FloatObject(float(int(w_dtype.value)))
  313. if isinstance(w_dtype, boxes.W_Int8Box):
  314. return FloatObject(float(int(w_dtype.value)))
  315. if isinstance(w_dtype, IntObject):
  316. return FloatObject(float(w_dtype.intval))
  317. if tp is self.w_int:
  318. if isinstance(w_dtype, FloatObject):
  319. return IntObject(int(w_dtype.floatval))
  320. return w_dtype
  321. @specialize.arg(2)
  322. def call_method(self, w_obj, s, *args):
  323. # XXX even the hacks have hacks
  324. if s == 'size': # used in _array() but never called by tests
  325. return IntObject(0)
  326. if s == '__buffer__':
  327. # descr___buffer__ does not exist on W_Root
  328. return self.w_None
  329. return getattr(w_obj, 'descr_' + s)(self, *args)
  330. @specialize.arg(1)
  331. def interp_w(self, tp, what):
  332. assert isinstance(what, tp)
  333. return what
  334. def allocate_instance(self, klass, w_subtype):
  335. return instantiate(klass)
  336. def newtuple(self, list_w):
  337. return ListObject(list_w)
  338. def newdict(self, module=True):
  339. return DictObject({})
  340. @specialize.argtype(1)
  341. def newint(self, i):
  342. if isinstance(i, IntObject):
  343. return i
  344. if isinstance(i, base_int):
  345. return LongObject(i)
  346. return IntObject(i)
  347. def setitem(self, obj, index, value):
  348. obj.items[index] = value
  349. def exception_match(self, w_exc_type, w_check_class):
  350. assert isinstance(w_exc_type, W_TypeObject)
  351. assert isinstance(w_check_class, W_TypeObject)
  352. return w_exc_type.name == w_check_class.name
  353. def warn(self, w_msg, w_warn_type):
  354. pass
  355. def is_root(w_obj):
  356. assert isinstance(w_obj, W_Root)
  357. is_root.expecting = W_Root
  358. class FloatObject(W_Root):
  359. tp = FakeSpace.w_float
  360. def __init__(self, floatval):
  361. self.floatval = floatval
  362. class BoolObject(W_Root):
  363. tp = FakeSpace.w_bool
  364. def __init__(self, boolval):
  365. self.intval = boolval
  366. FakeSpace.w_True = BoolObject(True)
  367. FakeSpace.w_False = BoolObject(False)
  368. class IntObject(W_Root):
  369. tp = FakeSpace.w_int
  370. def __init__(self, intval):
  371. self.intval = intval
  372. class LongObject(W_Root):
  373. tp = FakeSpace.w_long
  374. def __init__(self, intval):
  375. self.intval = intval
  376. class ListObject(W_Root):
  377. tp = FakeSpace.w_list
  378. def __init__(self, items):
  379. self.items = items
  380. class DictObject(W_Root):
  381. tp = FakeSpace.w_dict
  382. def __init__(self, items):
  383. self.items = items
  384. def getdict(self, space):
  385. return self.items
  386. def getdictvalue(self, space, key):
  387. return self.items[key]
  388. def descr_memoryview(self, space, buf):
  389. raise oefmt(space.w_TypeError, "error")
  390. class IterDictObject(W_Root):
  391. def __init__(self, space, w_dict):
  392. self.space = space
  393. self.items = w_dict.items.items()
  394. self.i = 0
  395. def __iter__(self):
  396. return self
  397. def next(self):
  398. space = self.space
  399. if self.i >= len(self.items):
  400. raise oefmt(space.w_StopIteration, "stop iteration")
  401. self.i += 1
  402. return self.items[self.i-1][0]
  403. class SliceObject(W_Root):
  404. tp = FakeSpace.w_slice
  405. def __init__(self, start, stop, step):
  406. self.start = start
  407. self.stop = stop
  408. self.step = step
  409. class StringObject(W_Root):
  410. tp = FakeSpace.w_bytes
  411. def __init__(self, v):
  412. self.v = v
  413. class ComplexObject(W_Root):
  414. tp = FakeSpace.w_complex
  415. def __init__(self, r, i):
  416. self.r = r
  417. self.i = i
  418. class InterpreterState(object):
  419. def __init__(self, code):
  420. self.code = code
  421. self.variables = {}
  422. self.results = []
  423. def run(self, space):
  424. self.space = space
  425. for stmt in self.code.statements:
  426. stmt.execute(self)
  427. class Node(object):
  428. def __eq__(self, other):
  429. return (self.__class__ == other.__class__ and
  430. self.__dict__ == other.__dict__)
  431. def __ne__(self, other):
  432. return not self == other
  433. def wrap(self, space):
  434. raise NotImplementedError
  435. def execute(self, interp):
  436. raise NotImplementedError
  437. class Assignment(Node):
  438. def __init__(self, name, expr):
  439. self.name = name
  440. self.expr = expr
  441. def execute(self, interp):
  442. interp.variables[self.name] = self.expr.execute(interp)
  443. def __repr__(self):
  444. return "%r = %r" % (self.name, self.expr)
  445. class ArrayAssignment(Node):
  446. def __init__(self, name, index, expr):
  447. self.name = name
  448. self.index = index
  449. self.expr = expr
  450. def execute(self, interp):
  451. arr = interp.variables[self.name]
  452. w_index = self.index.execute(interp)
  453. # cast to int
  454. if isinstance(w_index, FloatObject):
  455. w_index = IntObject(int(w_index.floatval))
  456. w_val = self.expr.execute(interp)
  457. assert isinstance(arr, W_NDimArray)
  458. arr.descr_setitem(interp.space, w_index, w_val)
  459. def __repr__(self):
  460. return "%s[%r] = %r" % (self.name, self.index, self.expr)
  461. class Variable(Node):
  462. def __init__(self, name):
  463. self.name = name.strip(" ")
  464. def execute(self, interp):
  465. if self.name == 'None':
  466. return None
  467. return interp.variables[self.name]
  468. def __repr__(self):
  469. return 'v(%s)' % self.name
  470. class Operator(Node):
  471. def __init__(self, lhs, name, rhs):
  472. self.name = name
  473. self.lhs = lhs
  474. self.rhs = rhs
  475. def execute(self, interp):
  476. w_lhs = self.lhs.execute(interp)
  477. if isinstance(self.rhs, SliceConstant):
  478. w_rhs = self.rhs.wrap(interp.space)
  479. else:
  480. w_rhs = self.rhs.execute(interp)
  481. if not isinstance(w_lhs, W_NDimArray):
  482. # scalar
  483. dtype = get_dtype_cache(interp.space).w_float64dtype
  484. w_lhs = W_NDimArray.new_scalar(interp.space, dtype, w_lhs)
  485. assert isinstance(w_lhs, W_NDimArray)
  486. if self.name == '+':
  487. w_res = w_lhs.descr_add(interp.space, w_rhs)
  488. elif self.name == '*':
  489. w_res = w_lhs.descr_mul(interp.space, w_rhs)
  490. elif self.name == '-':
  491. w_res = w_lhs.descr_sub(interp.space, w_rhs)
  492. elif self.name == '**':
  493. w_res = w_lhs.descr_pow(interp.space, w_rhs)
  494. elif self.name == '->':
  495. if isinstance(w_rhs, FloatObject):
  496. w_rhs = IntObject(int(w_rhs.floatval))
  497. assert isinstance(w_lhs, W_NDimArray)
  498. w_res = w_lhs.descr_getitem(interp.space, w_rhs)
  499. if isinstance(w_rhs, IntObject):
  500. if isinstance(w_res, boxes.W_Float64Box):
  501. print "access", w_lhs, "[", w_rhs.intval, "] => ", float(w_res.value)
  502. if isinstance(w_res, boxes.W_Float32Box):
  503. print "access", w_lhs, "[", w_rhs.intval, "] => ", float(w_res.value)
  504. if isinstance(w_res, boxes.W_Int64Box):
  505. print "access", w_lhs, "[", w_rhs.intval, "] => ", int(w_res.value)
  506. if isinstance(w_res, boxes.W_Int32Box):
  507. print "access", w_lhs, "[", w_rhs.intval, "] => ", int(w_res.value)
  508. else:
  509. raise NotImplementedError
  510. if (not isinstance(w_res, W_NDimArray) and
  511. not isinstance(w_res, boxes.W_GenericBox)):
  512. dtype = get_dtype_cache(interp.space).w_float64dtype
  513. w_res = W_NDimArray.new_scalar(interp.space, dtype, w_res)
  514. return w_res
  515. def __repr__(self):
  516. return '(%r %s %r)' % (self.lhs, self.name, self.rhs)
  517. class NumberConstant(Node):
  518. def __init__(self, v):
  519. if isinstance(v, int):
  520. self.v = v
  521. elif isinstance(v, float):
  522. self.v = v
  523. else:
  524. assert isinstance(v, str)
  525. assert len(v) > 0
  526. c = v[-1]
  527. if c == 'f':
  528. self.v = float(v[:-1])
  529. elif c == 'i':
  530. self.v = int(v[:-1])
  531. else:
  532. self.v = float(v)
  533. def __repr__(self):
  534. return "Const(%s)" % self.v
  535. def wrap(self, space):
  536. return space.wrap(self.v)
  537. def execute(self, interp):
  538. return interp.space.wrap(self.v)
  539. class ComplexConstant(Node):
  540. def __init__(self, r, i):
  541. self.r = float(r)
  542. self.i = float(i)
  543. def __repr__(self):
  544. return 'ComplexConst(%s, %s)' % (self.r, self.i)
  545. def wrap(self, space):
  546. return space.newcomplex(self.r, self.i)
  547. def execute(self, interp):
  548. return self.wrap(interp.space)
  549. class RangeConstant(Node):
  550. def __init__(self, v):
  551. self.v = int(v)
  552. def execute(self, interp):
  553. w_list = interp.space.newlist(
  554. [interp.space.newfloat(float(i)) for i in range(self.v)]
  555. )
  556. dtype = get_dtype_cache(interp.space).w_float64dtype
  557. return array(interp.space, w_list, w_dtype=dtype, w_order=None)
  558. def __repr__(self):
  559. return 'Range(%s)' % self.v
  560. class Code(Node):
  561. def __init__(self, statements):
  562. self.statements = statements
  563. def __repr__(self):
  564. return "\n".join([repr(i) for i in self.statements])
  565. class ArrayConstant(Node):
  566. def __init__(self, items):
  567. self.items = items
  568. def wrap(self, space):
  569. return space.newlist([item.wrap(space) for item in self.items])
  570. def execute(self, interp):
  571. w_list = self.wrap(interp.space)
  572. return array(interp.space, w_list)
  573. def __repr__(self):
  574. return "[" + ", ".join([repr(item) for item in self.items]) + "]"
  575. class SliceConstant(Node):
  576. def __init__(self, start, stop, step):
  577. self.start = start
  578. self.stop = stop
  579. self.step = step
  580. def wrap(self, space):
  581. return SliceObject(self.start, self.stop, self.step)
  582. def execute(self, interp):
  583. return SliceObject(self.start, self.stop, self.step)
  584. def __repr__(self):
  585. return 'slice(%s,%s,%s)' % (self.start, self.stop, self.step)
  586. class ArrayClass(Node):
  587. def __init__(self):
  588. self.v = W_NDimArray
  589. def execute(self, interp):
  590. return self.v
  591. def __repr__(self):
  592. return '<class W_NDimArray>'
  593. class DtypeClass(Node):
  594. def __init__(self, dt):
  595. self.v = dt
  596. def execute(self, interp):
  597. if self.v == 'int':
  598. dtype = get_dtype_cache(interp.space).w_int64dtype
  599. elif self.v == 'int8':
  600. dtype = get_dtype_cache(interp.space).w_int8dtype
  601. elif self.v == 'int16':
  602. dtype = get_dtype_cache(interp.space).w_int16dtype
  603. elif self.v == 'int32':
  604. dtype = get_dtype_cache(interp.space).w_int32dtype
  605. elif self.v == 'uint':
  606. dtype = get_dtype_cache(interp.space).w_uint64dtype
  607. elif self.v == 'uint8':
  608. dtype = get_dtype_cache(interp.space).w_uint8dtype
  609. elif self.v == 'uint16':
  610. dtype = get_dtype_cache(interp.space).w_uint16dtype
  611. elif self.v == 'uint32':
  612. dtype = get_dtype_cache(interp.space).w_uint32dtype
  613. elif self.v == 'float':
  614. dtype = get_dtype_cache(interp.space).w_float64dtype
  615. elif self.v == 'float32':
  616. dtype = get_dtype_cache(interp.space).w_float32dtype
  617. else:
  618. raise BadToken('unknown v to dtype "%s"' % self.v)
  619. return dtype
  620. def __repr__(self):
  621. return '<class %s dtype>' % self.v
  622. class Execute(Node):
  623. def __init__(self, expr):
  624. self.expr = expr
  625. def __repr__(self):
  626. return repr(self.expr)
  627. def execute(self, interp):
  628. interp.results.append(self.expr.execute(interp))
  629. class FunctionCall(Node):
  630. def __init__(self, name, args):
  631. self.name = name.strip(" ")
  632. self.args = args
  633. def __repr__(self):
  634. return "%s(%s)" % (self.name, ", ".join([repr(arg)
  635. for arg in self.args]))
  636. def execute(self, interp):
  637. arr = self.args[0].execute(interp)
  638. if not isinstance(arr, W_NDimArray):
  639. raise ArgumentNotAnArray
  640. if self.name in SINGLE_ARG_FUNCTIONS:
  641. if len(self.args) != 1 and self.name != 'sum':
  642. raise ArgumentMismatch
  643. if self.name == "sum":
  644. if len(self.args)>1:
  645. var = self.args[1]
  646. if isinstance(var, DtypeClass):
  647. w_res = arr.descr_sum(interp.space, None, var.execute(interp))
  648. else:
  649. w_res = arr.descr_sum(interp.space,
  650. self.args[1].execute(interp))
  651. else:
  652. w_res = arr.descr_sum(interp.space)
  653. elif self.name == "prod":
  654. w_res = arr.descr_prod(interp.space)
  655. elif self.name == "max":
  656. w_res = arr.descr_max(interp.space)
  657. elif self.name == "min":
  658. w_res = arr.descr_min(interp.space)
  659. elif self.name == "any":
  660. w_res = arr.descr_any(interp.space)
  661. elif self.name == "all":
  662. w_res = arr.descr_all(interp.space)
  663. elif self.name == "cumsum":
  664. w_res = arr.descr_cumsum(interp.space)
  665. elif self.name == "logical_xor_reduce":
  666. logical_xor = ufuncs.get(interp.space).logical_xor
  667. w_res = logical_xor.reduce(interp.space, arr, None)
  668. elif self.name == "unegative":
  669. neg = ufuncs.get(interp.space).negative
  670. w_res = neg.call(interp.space, [arr], None, 'unsafe', None)
  671. elif self.name == "cos":
  672. cos = ufuncs.get(interp.space).cos
  673. w_res = cos.call(interp.space, [arr], None, 'unsafe', None)
  674. elif self.name == "flat":
  675. w_res = arr.descr_get_flatiter(interp.space)
  676. elif self.name == "argsort":
  677. w_res = arr.descr_argsort(interp.space)
  678. elif self.name == "tostring":
  679. arr.descr_tostring(interp.space)
  680. w_res = None
  681. else:
  682. assert False # unreachable code
  683. elif self.name in TWO_ARG_FUNCTIONS:
  684. if len(self.args) != 2:
  685. raise ArgumentMismatch
  686. arg = self.args[1].execute(interp)
  687. if not isinstance(arg, W_NDimArray):
  688. raise ArgumentNotAnArray
  689. if self.name == "dot":
  690. w_res = arr.descr_dot(interp.space, arg)
  691. elif self.name == 'multiply':
  692. w_res = arr.descr_mul(interp.space, arg)
  693. elif self.name == 'take':
  694. w_res = arr.descr_take(interp.space, arg)
  695. elif self.name == "searchsorted":
  696. w_res = arr.descr_searchsorted(interp.space, arg,
  697. interp.space.newtext('left'))
  698. else:
  699. assert False # unreachable code
  700. elif self.name in THREE_ARG_FUNCTIONS:
  701. if len(self.args) != 3:
  702. raise ArgumentMismatch
  703. arg1 = self.args[1].execute(interp)
  704. arg2 = self.args[2].execute(interp)
  705. if not isinstance(arg1, W_NDimArray):
  706. raise ArgumentNotAnArray
  707. if not isinstance(arg2, W_NDimArray):
  708. raise ArgumentNotAnArray
  709. if self.name == "where":
  710. w_res = where(interp.space, arr, arg1, arg2)
  711. else:
  712. assert False # unreachable code
  713. elif self.name in TWO_ARG_FUNCTIONS_OR_NONE:
  714. if len(self.args) != 2:
  715. raise ArgumentMismatch
  716. arg = self.args[1].execute(interp)
  717. if self.name == 'view':
  718. w_res = arr.descr_view(interp.space, arg)
  719. elif self.name == 'astype':
  720. w_res = arr.descr_astype(interp.space, arg)
  721. elif self.name == 'reshape':
  722. w_arg = self.args[1]
  723. assert isinstance(w_arg, ArrayConstant)
  724. order = -1
  725. w_res = arr.reshape(interp.space, w_arg.wrap(interp.space), order)
  726. else:
  727. assert False
  728. else:
  729. raise WrongFunctionName
  730. if isinstance(w_res, W_NDimArray):
  731. return w_res
  732. if isinstance(w_res, FloatObject):
  733. dtype = get_dtype_cache(interp.space).w_float64dtype
  734. elif isinstance(w_res, IntObject):
  735. dtype = get_dtype_cache(interp.space).w_int64dtype
  736. elif isinstance(w_res, BoolObject):
  737. dtype = get_dtype_cache(interp.space).w_booldtype
  738. elif isinstance(w_res, boxes.W_GenericBox):
  739. dtype = w_res.get_dtype(interp.space)
  740. else:
  741. dtype = None
  742. return W_NDimArray.new_scalar(interp.space, dtype, w_res)
  743. _REGEXES = [
  744. ('-?[\d\.]+(i|f)?', 'number'),
  745. ('\[', 'array_left'),
  746. (':', 'colon'),
  747. ('\w+', 'identifier'),
  748. ('\]', 'array_right'),
  749. ('(->)|[\+\-\*\/]+', 'operator'),
  750. ('=', 'assign'),
  751. (',', 'comma'),
  752. ('\|', 'pipe'),
  753. ('\(', 'paren_left'),
  754. ('\)', 'paren_right'),
  755. ]
  756. REGEXES = []
  757. for r, name in _REGEXES:
  758. REGEXES.append((re.compile(r' *(' + r + ')'), name))
  759. del _REGEXES
  760. class Token(object):
  761. def __init__(self, name, v):
  762. self.name = name
  763. self.v = v
  764. def __repr__(self):
  765. return '(%s, %s)' % (self.name, self.v)
  766. empty = Token('', '')
  767. class TokenStack(object):
  768. def __init__(self, tokens):
  769. self.tokens = tokens
  770. self.c = 0
  771. def pop(self):
  772. token = self.tokens[self.c]
  773. self.c += 1
  774. return token
  775. def get(self, i):
  776. if self.c + i >= len(self.tokens):
  777. return empty
  778. return self.tokens[self.c + i]
  779. def remaining(self):
  780. return len(self.tokens) - self.c
  781. def push(self):
  782. self.c -= 1
  783. def __repr__(self):
  784. return repr(self.tokens[self.c:])
  785. class Parser(object):
  786. def tokenize(self, line):
  787. tokens = []
  788. while True:
  789. for r, name in REGEXES:
  790. m = r.match(line)
  791. if m is not None:
  792. g = m.group(0)
  793. tokens.append(Token(name, g))
  794. line = line[len(g):]
  795. if not line:
  796. return TokenStack(tokens)
  797. break
  798. else:
  799. raise TokenizerError(line)
  800. def parse_number_or_slice(self, tokens):
  801. start_tok = tokens.pop()
  802. if start_tok.name == 'colon':
  803. start = 0
  804. else:
  805. if tokens.get(0).name != 'colon':
  806. return NumberConstant(start_tok.v)
  807. start = int(start_tok.v)
  808. tokens.pop()
  809. if not tokens.get(0).name in ['colon', 'number']:
  810. stop = -1
  811. step = 1
  812. else:
  813. next = tokens.pop()
  814. if next.name == 'colon':
  815. stop = -1
  816. step = int(tokens.pop().v)
  817. else:
  818. stop = int(next.v)
  819. if tokens.get(0).name == 'colon':
  820. tokens.pop()
  821. step = int(tokens.pop().v)
  822. else:
  823. step = 1
  824. return SliceConstant(start, stop, step)
  825. def parse_expression(self, tokens, accept_comma=False):
  826. stack = []
  827. while tokens.remaining():
  828. token = tokens.pop()
  829. if token.name == 'identifier':
  830. if tokens.remaining() and tokens.get(0).name == 'paren_left':
  831. stack.append(self.parse_function_call(token.v, tokens))
  832. elif token.v.strip(' ') == 'ndarray':
  833. stack.append(ArrayClass())
  834. elif token.v.strip(' ') == 'int':
  835. stack.append(DtypeClass('int'))
  836. elif token.v.strip(' ') == 'int8':
  837. stack.append(DtypeClass('int8'))
  838. elif token.v.strip(' ') == 'int16':
  839. stack.append(DtypeClass('int16'))
  840. elif token.v.strip(' ') == 'int32':
  841. stack.append(DtypeClass('int32'))
  842. elif token.v.strip(' ') == 'int64':
  843. stack.append(DtypeClass('int'))
  844. elif token.v.strip(' ') == 'uint':
  845. stack.append(DtypeClass('uint'))
  846. elif token.v.strip(' ') == 'uint8':
  847. stack.append(DtypeClass('uint8'))
  848. elif token.v.strip(' ') == 'uint16':
  849. stack.append(DtypeClass('uint16'))
  850. elif token.v.strip(' ') == 'uint32':
  851. stack.append(DtypeClass('uint32'))
  852. elif token.v.strip(' ') == 'uint64':
  853. stack.append(DtypeClass('uint'))
  854. elif token.v.strip(' ') == 'float':
  855. stack.append(DtypeClass('float'))
  856. elif token.v.strip(' ') == 'float32':
  857. stack.append(DtypeClass('float32'))
  858. elif token.v.strip(' ') == 'float64':
  859. stack.append(DtypeClass('float'))
  860. else:
  861. stack.append(Variable(token.v.strip(' ')))
  862. elif token.name == 'array_left':
  863. stack.append(ArrayConstant(self.parse_array_const(tokens)))
  864. elif token.name == 'operator':
  865. stack.append(Variable(token.v))
  866. elif token.name == 'number' or token.name == 'colon':
  867. tokens.push()
  868. stack.append(self.parse_number_or_slice(tokens))
  869. elif token.name == 'pipe':
  870. stack.append(RangeConstant(tokens.pop().v))
  871. end = tokens.pop()
  872. assert end.name == 'pipe'
  873. elif token.name == 'paren_left':
  874. stack.append(self.parse_complex_constant(tokens))
  875. elif accept_comma and token.name == 'comma':
  876. continue
  877. else:
  878. tokens.push()
  879. break
  880. if accept_comma:
  881. return stack
  882. stack.reverse()
  883. lhs = stack.pop()
  884. while stack:
  885. op = stack.pop()
  886. assert isinstance(op, Variable)
  887. rhs = stack.pop()
  888. lhs = Operator(lhs, op.name, rhs)
  889. return lhs
  890. def parse_function_call(self, name, tokens):
  891. args = []
  892. tokens.pop() # lparen
  893. while tokens.get(0).name != 'paren_right':
  894. args += self.parse_expression(tokens, accept_comma=True)
  895. return FunctionCall(name, args)
  896. def parse_complex_constant(self, tokens):
  897. r = tokens.pop()
  898. assert r.name == 'number'
  899. assert tokens.pop().name == 'comma'
  900. i = tokens.pop()
  901. assert i.name == 'number'
  902. assert tokens.pop().name == 'paren_right'
  903. return ComplexConstant(r.v, i.v)
  904. def parse_array_const(self, tokens):
  905. elems = []
  906. while True:
  907. token = tokens.pop()
  908. if token.name == 'number':
  909. elems.append(NumberConstant(token.v))
  910. elif token.name == 'array_left':
  911. elems.append(ArrayConstant(self.parse_array_const(tokens)))
  912. elif token.name == 'paren_left':
  913. elems.append(self.parse_complex_constant(tokens))
  914. else:
  915. raise BadToken()
  916. token = tokens.pop()
  917. if token.name == 'array_right':
  918. return elems
  919. assert token.name == 'comma'
  920. def parse_statement(self, tokens):
  921. if (tokens.get(0).name == 'identifier' and
  922. tokens.get(1).name == 'assign'):
  923. lhs = tokens.pop().v
  924. tokens.pop()
  925. rhs = self.parse_expression(tokens)
  926. return Assignment(lhs, rhs)
  927. elif (tokens.get(0).name == 'identifier' and
  928. tokens.get(1).name == 'array_left'):
  929. name = tokens.pop().v
  930. tokens.pop()
  931. index = self.parse_expression(tokens)
  932. tokens.pop()
  933. tokens.pop()
  934. return ArrayAssignment(name, index, self.parse_expression(tokens))
  935. return Execute(self.parse_expression(tokens))
  936. def parse(self, code):
  937. statements = []
  938. for line in code.split("\n"):
  939. if '#' in line:
  940. line = line.split('#', 1)[0]
  941. line = line.strip(" ")
  942. if line:
  943. tokens = self.tokenize(line)
  944. statements.append(self.parse_statement(tokens))
  945. return Code(statements)
  946. def numpy_compile(code):
  947. parser = Parser()
  948. return InterpreterState(parser.parse(code))