PageRenderTime 60ms CodeModel.GetById 28ms RepoModel.GetById 1ms app.codeStats 0ms

/pypy/module/micronumpy/compile.py

https://bitbucket.org/rguillebert/pypy
Python | 678 lines | 575 code | 93 blank | 10 comment | 76 complexity | be10c159c8433721e8d220a381e7b891 MD5 | raw file
  1. """ This is a set of tools for standalone compiling of numpy expressions.
  2. It should not be imported by the module itself
  3. """
  4. import re
  5. from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root
  6. from pypy.interpreter.error import OperationError
  7. from pypy.module.micronumpy import interp_boxes
  8. from pypy.module.micronumpy.interp_dtype import get_dtype_cache
  9. from pypy.module.micronumpy.base import W_NDimArray
  10. from pypy.module.micronumpy.interp_numarray import array
  11. from pypy.module.micronumpy.interp_arrayops import where
  12. from pypy.module.micronumpy import interp_ufuncs
  13. from pypy.rlib.objectmodel import specialize, instantiate
  14. class BogusBytecode(Exception):
  15. pass
  16. class ArgumentMismatch(Exception):
  17. pass
  18. class ArgumentNotAnArray(Exception):
  19. pass
  20. class WrongFunctionName(Exception):
  21. pass
  22. class TokenizerError(Exception):
  23. pass
  24. class BadToken(Exception):
  25. pass
  26. SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
  27. "unegative", "flat", "tostring","count_nonzero"]
  28. TWO_ARG_FUNCTIONS = ["dot", 'take']
  29. THREE_ARG_FUNCTIONS = ['where']
  30. class FakeSpace(object):
  31. w_ValueError = "ValueError"
  32. w_TypeError = "TypeError"
  33. w_IndexError = "IndexError"
  34. w_OverflowError = "OverflowError"
  35. w_NotImplementedError = "NotImplementedError"
  36. w_None = None
  37. w_bool = "bool"
  38. w_int = "int"
  39. w_float = "float"
  40. w_list = "list"
  41. w_long = "long"
  42. w_tuple = 'tuple'
  43. w_slice = "slice"
  44. w_str = "str"
  45. w_unicode = "unicode"
  46. def __init__(self):
  47. """NOT_RPYTHON"""
  48. self.fromcache = InternalSpaceCache(self).getorbuild
  49. def _freeze_(self):
  50. return True
  51. def issequence_w(self, w_obj):
  52. return isinstance(w_obj, ListObject) or isinstance(w_obj, W_NDimArray)
  53. def isinstance_w(self, w_obj, w_tp):
  54. return w_obj.tp == w_tp
  55. def decode_index4(self, w_idx, size):
  56. if isinstance(w_idx, IntObject):
  57. return (self.int_w(w_idx), 0, 0, 1)
  58. else:
  59. assert isinstance(w_idx, SliceObject)
  60. start, stop, step = w_idx.start, w_idx.stop, w_idx.step
  61. if step == 0:
  62. return (0, size, 1, size)
  63. if start < 0:
  64. start += size
  65. if stop < 0:
  66. stop += size + 1
  67. if step < 0:
  68. lgt = (stop - start + 1) / step + 1
  69. else:
  70. lgt = (stop - start - 1) / step + 1
  71. return (start, stop, step, lgt)
  72. @specialize.argtype(1)
  73. def wrap(self, obj):
  74. if isinstance(obj, float):
  75. return FloatObject(obj)
  76. elif isinstance(obj, bool):
  77. return BoolObject(obj)
  78. elif isinstance(obj, int):
  79. return IntObject(obj)
  80. elif isinstance(obj, long):
  81. return LongObject(obj)
  82. elif isinstance(obj, W_Root):
  83. return obj
  84. elif isinstance(obj, str):
  85. return StringObject(obj)
  86. raise NotImplementedError
  87. def newlist(self, items):
  88. return ListObject(items)
  89. def listview(self, obj):
  90. assert isinstance(obj, ListObject)
  91. return obj.items
  92. fixedview = listview
  93. def float(self, w_obj):
  94. if isinstance(w_obj, FloatObject):
  95. return w_obj
  96. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  97. return self.float(w_obj.descr_float(self))
  98. def float_w(self, w_obj):
  99. assert isinstance(w_obj, FloatObject)
  100. return w_obj.floatval
  101. def int_w(self, w_obj):
  102. if isinstance(w_obj, IntObject):
  103. return w_obj.intval
  104. elif isinstance(w_obj, FloatObject):
  105. return int(w_obj.floatval)
  106. elif isinstance(w_obj, SliceObject):
  107. raise OperationError(self.w_TypeError, self.wrap("slice."))
  108. raise NotImplementedError
  109. def index(self, w_obj):
  110. return self.wrap(self.int_w(w_obj))
  111. def str_w(self, w_obj):
  112. if isinstance(w_obj, StringObject):
  113. return w_obj.v
  114. raise NotImplementedError
  115. def int(self, w_obj):
  116. if isinstance(w_obj, IntObject):
  117. return w_obj
  118. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  119. return self.int(w_obj.descr_int(self))
  120. def is_true(self, w_obj):
  121. assert isinstance(w_obj, BoolObject)
  122. return False
  123. #return w_obj.boolval
  124. def is_w(self, w_obj, w_what):
  125. return w_obj is w_what
  126. def type(self, w_obj):
  127. return w_obj.tp
  128. def gettypefor(self, w_obj):
  129. return None
  130. def call_function(self, tp, w_dtype):
  131. return w_dtype
  132. @specialize.arg(1)
  133. def interp_w(self, tp, what):
  134. assert isinstance(what, tp)
  135. return what
  136. def allocate_instance(self, klass, w_subtype):
  137. return instantiate(klass)
  138. def newtuple(self, list_w):
  139. return ListObject(list_w)
  140. def newdict(self):
  141. return {}
  142. def setitem(self, dict, item, value):
  143. dict[item] = value
  144. def len_w(self, w_obj):
  145. if isinstance(w_obj, ListObject):
  146. return len(w_obj.items)
  147. # XXX array probably
  148. assert False
  149. def exception_match(self, w_exc_type, w_check_class):
  150. # Good enough for now
  151. raise NotImplementedError
  152. class FloatObject(W_Root):
  153. tp = FakeSpace.w_float
  154. def __init__(self, floatval):
  155. self.floatval = floatval
  156. class BoolObject(W_Root):
  157. tp = FakeSpace.w_bool
  158. def __init__(self, boolval):
  159. self.boolval = boolval
  160. class IntObject(W_Root):
  161. tp = FakeSpace.w_int
  162. def __init__(self, intval):
  163. self.intval = intval
  164. class LongObject(W_Root):
  165. tp = FakeSpace.w_long
  166. def __init__(self, intval):
  167. self.intval = intval
  168. class ListObject(W_Root):
  169. tp = FakeSpace.w_list
  170. def __init__(self, items):
  171. self.items = items
  172. class SliceObject(W_Root):
  173. tp = FakeSpace.w_slice
  174. def __init__(self, start, stop, step):
  175. self.start = start
  176. self.stop = stop
  177. self.step = step
  178. class StringObject(W_Root):
  179. tp = FakeSpace.w_str
  180. def __init__(self, v):
  181. self.v = v
  182. class InterpreterState(object):
  183. def __init__(self, code):
  184. self.code = code
  185. self.variables = {}
  186. self.results = []
  187. def run(self, space):
  188. self.space = space
  189. for stmt in self.code.statements:
  190. stmt.execute(self)
  191. class Node(object):
  192. def __eq__(self, other):
  193. return (self.__class__ == other.__class__ and
  194. self.__dict__ == other.__dict__)
  195. def __ne__(self, other):
  196. return not self == other
  197. def wrap(self, space):
  198. raise NotImplementedError
  199. def execute(self, interp):
  200. raise NotImplementedError
  201. class Assignment(Node):
  202. def __init__(self, name, expr):
  203. self.name = name
  204. self.expr = expr
  205. def execute(self, interp):
  206. interp.variables[self.name] = self.expr.execute(interp)
  207. def __repr__(self):
  208. return "%r = %r" % (self.name, self.expr)
  209. class ArrayAssignment(Node):
  210. def __init__(self, name, index, expr):
  211. self.name = name
  212. self.index = index
  213. self.expr = expr
  214. def execute(self, interp):
  215. arr = interp.variables[self.name]
  216. w_index = self.index.execute(interp)
  217. # cast to int
  218. if isinstance(w_index, FloatObject):
  219. w_index = IntObject(int(w_index.floatval))
  220. w_val = self.expr.execute(interp)
  221. assert isinstance(arr, W_NDimArray)
  222. arr.descr_setitem(interp.space, w_index, w_val)
  223. def __repr__(self):
  224. return "%s[%r] = %r" % (self.name, self.index, self.expr)
  225. class Variable(Node):
  226. def __init__(self, name):
  227. self.name = name.strip(" ")
  228. def execute(self, interp):
  229. return interp.variables[self.name]
  230. def __repr__(self):
  231. return 'v(%s)' % self.name
  232. class Operator(Node):
  233. def __init__(self, lhs, name, rhs):
  234. self.name = name
  235. self.lhs = lhs
  236. self.rhs = rhs
  237. def execute(self, interp):
  238. w_lhs = self.lhs.execute(interp)
  239. if isinstance(self.rhs, SliceConstant):
  240. w_rhs = self.rhs.wrap(interp.space)
  241. else:
  242. w_rhs = self.rhs.execute(interp)
  243. if not isinstance(w_lhs, W_NDimArray):
  244. # scalar
  245. dtype = get_dtype_cache(interp.space).w_float64dtype
  246. w_lhs = W_NDimArray.new_scalar(interp.space, dtype, w_lhs)
  247. assert isinstance(w_lhs, W_NDimArray)
  248. if self.name == '+':
  249. w_res = w_lhs.descr_add(interp.space, w_rhs)
  250. elif self.name == '*':
  251. w_res = w_lhs.descr_mul(interp.space, w_rhs)
  252. elif self.name == '-':
  253. w_res = w_lhs.descr_sub(interp.space, w_rhs)
  254. elif self.name == '->':
  255. if isinstance(w_rhs, FloatObject):
  256. w_rhs = IntObject(int(w_rhs.floatval))
  257. assert isinstance(w_lhs, W_NDimArray)
  258. w_res = w_lhs.descr_getitem(interp.space, w_rhs)
  259. else:
  260. raise NotImplementedError
  261. if (not isinstance(w_res, W_NDimArray) and
  262. not isinstance(w_res, interp_boxes.W_GenericBox)):
  263. dtype = get_dtype_cache(interp.space).w_float64dtype
  264. w_res = W_NDimArray.new_scalar(interp.space, dtype, w_res)
  265. return w_res
  266. def __repr__(self):
  267. return '(%r %s %r)' % (self.lhs, self.name, self.rhs)
  268. class FloatConstant(Node):
  269. def __init__(self, v):
  270. self.v = float(v)
  271. def __repr__(self):
  272. return "Const(%s)" % self.v
  273. def wrap(self, space):
  274. return space.wrap(self.v)
  275. def execute(self, interp):
  276. return interp.space.wrap(self.v)
  277. class RangeConstant(Node):
  278. def __init__(self, v):
  279. self.v = int(v)
  280. def execute(self, interp):
  281. w_list = interp.space.newlist(
  282. [interp.space.wrap(float(i)) for i in range(self.v)]
  283. )
  284. dtype = get_dtype_cache(interp.space).w_float64dtype
  285. return array(interp.space, w_list, w_dtype=dtype, w_order=None)
  286. def __repr__(self):
  287. return 'Range(%s)' % self.v
  288. class Code(Node):
  289. def __init__(self, statements):
  290. self.statements = statements
  291. def __repr__(self):
  292. return "\n".join([repr(i) for i in self.statements])
  293. class ArrayConstant(Node):
  294. def __init__(self, items):
  295. self.items = items
  296. def wrap(self, space):
  297. return space.newlist([item.wrap(space) for item in self.items])
  298. def execute(self, interp):
  299. w_list = self.wrap(interp.space)
  300. dtype = get_dtype_cache(interp.space).w_float64dtype
  301. return array(interp.space, w_list, w_dtype=dtype, w_order=None)
  302. def __repr__(self):
  303. return "[" + ", ".join([repr(item) for item in self.items]) + "]"
  304. class SliceConstant(Node):
  305. def __init__(self, start, stop, step):
  306. # no negative support for now
  307. self.start = start
  308. self.stop = stop
  309. self.step = step
  310. def wrap(self, space):
  311. return SliceObject(self.start, self.stop, self.step)
  312. def execute(self, interp):
  313. return SliceObject(self.start, self.stop, self.step)
  314. def __repr__(self):
  315. return 'slice(%s,%s,%s)' % (self.start, self.stop, self.step)
  316. class Execute(Node):
  317. def __init__(self, expr):
  318. self.expr = expr
  319. def __repr__(self):
  320. return repr(self.expr)
  321. def execute(self, interp):
  322. interp.results.append(self.expr.execute(interp))
class FunctionCall(Node):
    """Call of a built-in function ``name(arg, ...)``.

    The first argument must evaluate to an ndarray; ``name`` selects the
    array method, ufunc or helper to invoke, and the accepted names and
    arities come from SINGLE_/TWO_/THREE_ARG_FUNCTIONS.  A result that is
    not already an ndarray is wrapped with ``W_NDimArray.new_scalar``.

    Raises ArgumentNotAnArray, ArgumentMismatch or WrongFunctionName for bad
    calls.
    """
    def __init__(self, name, args):
        # the tokenizer can leave leading spaces on identifiers
        self.name = name.strip(" ")
        self.args = args
    def __repr__(self):
        return "%s(%s)" % (self.name, ", ".join([repr(arg)
                                                 for arg in self.args]))
    def execute(self, interp):
        # the first argument is evaluated before the name is validated
        arr = self.args[0].execute(interp)
        if not isinstance(arr, W_NDimArray):
            raise ArgumentNotAnArray
        if self.name in SINGLE_ARG_FUNCTIONS:
            # 'sum' may take a second argument (passed to descr_sum,
            # presumably the axis -- TODO confirm); any further extra
            # arguments to 'sum' are ignored
            if len(self.args) != 1 and self.name != 'sum':
                raise ArgumentMismatch
            if self.name == "sum":
                if len(self.args)>1:
                    w_res = arr.descr_sum(interp.space,
                                          self.args[1].execute(interp))
                else:
                    w_res = arr.descr_sum(interp.space)
            elif self.name == "prod":
                w_res = arr.descr_prod(interp.space)
            elif self.name == "max":
                w_res = arr.descr_max(interp.space)
            elif self.name == "min":
                w_res = arr.descr_min(interp.space)
            elif self.name == "any":
                w_res = arr.descr_any(interp.space)
            elif self.name == "all":
                w_res = arr.descr_all(interp.space)
            elif self.name == "unegative":
                neg = interp_ufuncs.get(interp.space).negative
                w_res = neg.call(interp.space, [arr])
            elif self.name == "cos":
                # only reachable if "cos" is listed in SINGLE_ARG_FUNCTIONS
                cos = interp_ufuncs.get(interp.space).cos
                w_res = cos.call(interp.space, [arr])
            elif self.name == "flat":
                w_res = arr.descr_get_flatiter(interp.space)
            elif self.name == "tostring":
                # the string is discarded; None reaches new_scalar below
                # with dtype None
                arr.descr_tostring(interp.space)
                w_res = None
            else:
                # NOTE(review): this is reachable -- "count_nonzero" is in
                # SINGLE_ARG_FUNCTIONS but has no branch above
                assert False # unreachable code
        elif self.name in TWO_ARG_FUNCTIONS:
            if len(self.args) != 2:
                raise ArgumentMismatch
            arg = self.args[1].execute(interp)
            if not isinstance(arg, W_NDimArray):
                raise ArgumentNotAnArray
            if self.name == "dot":
                w_res = arr.descr_dot(interp.space, arg)
            elif self.name == 'take':
                w_res = arr.descr_take(interp.space, arg)
            else:
                assert False # unreachable code
        elif self.name in THREE_ARG_FUNCTIONS:
            if len(self.args) != 3:
                raise ArgumentMismatch
            arg1 = self.args[1].execute(interp)
            arg2 = self.args[2].execute(interp)
            if not isinstance(arg1, W_NDimArray):
                raise ArgumentNotAnArray
            if not isinstance(arg2, W_NDimArray):
                raise ArgumentNotAnArray
            if self.name == "where":
                w_res = where(interp.space, arr, arg1, arg2)
            else:
                assert False
        else:
            raise WrongFunctionName
        if isinstance(w_res, W_NDimArray):
            return w_res
        # pick a dtype matching the kind of wrapped scalar that came back
        if isinstance(w_res, FloatObject):
            dtype = get_dtype_cache(interp.space).w_float64dtype
        elif isinstance(w_res, IntObject):
            dtype = get_dtype_cache(interp.space).w_int64dtype
        elif isinstance(w_res, BoolObject):
            dtype = get_dtype_cache(interp.space).w_booldtype
        elif isinstance(w_res, interp_boxes.W_GenericBox):
            dtype = w_res.get_dtype(interp.space)
        else:
            dtype = None
        return W_NDimArray.new_scalar(interp.space, dtype, w_res)
  406. _REGEXES = [
  407. ('-?[\d\.]+', 'number'),
  408. ('\[', 'array_left'),
  409. (':', 'colon'),
  410. ('\w+', 'identifier'),
  411. ('\]', 'array_right'),
  412. ('(->)|[\+\-\*\/]', 'operator'),
  413. ('=', 'assign'),
  414. (',', 'comma'),
  415. ('\|', 'pipe'),
  416. ('\(', 'paren_left'),
  417. ('\)', 'paren_right'),
  418. ]
  419. REGEXES = []
  420. for r, name in _REGEXES:
  421. REGEXES.append((re.compile(r' *(' + r + ')'), name))
  422. del _REGEXES
  423. class Token(object):
  424. def __init__(self, name, v):
  425. self.name = name
  426. self.v = v
  427. def __repr__(self):
  428. return '(%s, %s)' % (self.name, self.v)
  429. empty = Token('', '')
  430. class TokenStack(object):
  431. def __init__(self, tokens):
  432. self.tokens = tokens
  433. self.c = 0
  434. def pop(self):
  435. token = self.tokens[self.c]
  436. self.c += 1
  437. return token
  438. def get(self, i):
  439. if self.c + i >= len(self.tokens):
  440. return empty
  441. return self.tokens[self.c + i]
  442. def remaining(self):
  443. return len(self.tokens) - self.c
  444. def push(self):
  445. self.c -= 1
  446. def __repr__(self):
  447. return repr(self.tokens[self.c:])
  448. class Parser(object):
  449. def tokenize(self, line):
  450. tokens = []
  451. while True:
  452. for r, name in REGEXES:
  453. m = r.match(line)
  454. if m is not None:
  455. g = m.group(0)
  456. tokens.append(Token(name, g))
  457. line = line[len(g):]
  458. if not line:
  459. return TokenStack(tokens)
  460. break
  461. else:
  462. raise TokenizerError(line)
  463. def parse_number_or_slice(self, tokens):
  464. start_tok = tokens.pop()
  465. if start_tok.name == 'colon':
  466. start = 0
  467. else:
  468. if tokens.get(0).name != 'colon':
  469. return FloatConstant(start_tok.v)
  470. start = int(start_tok.v)
  471. tokens.pop()
  472. if not tokens.get(0).name in ['colon', 'number']:
  473. stop = -1
  474. step = 1
  475. else:
  476. next = tokens.pop()
  477. if next.name == 'colon':
  478. stop = -1
  479. step = int(tokens.pop().v)
  480. else:
  481. stop = int(next.v)
  482. if tokens.get(0).name == 'colon':
  483. tokens.pop()
  484. step = int(tokens.pop().v)
  485. else:
  486. step = 1
  487. return SliceConstant(start, stop, step)
  488. def parse_expression(self, tokens, accept_comma=False):
  489. stack = []
  490. while tokens.remaining():
  491. token = tokens.pop()
  492. if token.name == 'identifier':
  493. if tokens.remaining() and tokens.get(0).name == 'paren_left':
  494. stack.append(self.parse_function_call(token.v, tokens))
  495. else:
  496. stack.append(Variable(token.v))
  497. elif token.name == 'array_left':
  498. stack.append(ArrayConstant(self.parse_array_const(tokens)))
  499. elif token.name == 'operator':
  500. stack.append(Variable(token.v))
  501. elif token.name == 'number' or token.name == 'colon':
  502. tokens.push()
  503. stack.append(self.parse_number_or_slice(tokens))
  504. elif token.name == 'pipe':
  505. stack.append(RangeConstant(tokens.pop().v))
  506. end = tokens.pop()
  507. assert end.name == 'pipe'
  508. elif accept_comma and token.name == 'comma':
  509. continue
  510. else:
  511. tokens.push()
  512. break
  513. if accept_comma:
  514. return stack
  515. stack.reverse()
  516. lhs = stack.pop()
  517. while stack:
  518. op = stack.pop()
  519. assert isinstance(op, Variable)
  520. rhs = stack.pop()
  521. lhs = Operator(lhs, op.name, rhs)
  522. return lhs
  523. def parse_function_call(self, name, tokens):
  524. args = []
  525. tokens.pop() # lparen
  526. while tokens.get(0).name != 'paren_right':
  527. args += self.parse_expression(tokens, accept_comma=True)
  528. return FunctionCall(name, args)
  529. def parse_array_const(self, tokens):
  530. elems = []
  531. while True:
  532. token = tokens.pop()
  533. if token.name == 'number':
  534. elems.append(FloatConstant(token.v))
  535. elif token.name == 'array_left':
  536. elems.append(ArrayConstant(self.parse_array_const(tokens)))
  537. else:
  538. raise BadToken()
  539. token = tokens.pop()
  540. if token.name == 'array_right':
  541. return elems
  542. assert token.name == 'comma'
  543. def parse_statement(self, tokens):
  544. if (tokens.get(0).name == 'identifier' and
  545. tokens.get(1).name == 'assign'):
  546. lhs = tokens.pop().v
  547. tokens.pop()
  548. rhs = self.parse_expression(tokens)
  549. return Assignment(lhs, rhs)
  550. elif (tokens.get(0).name == 'identifier' and
  551. tokens.get(1).name == 'array_left'):
  552. name = tokens.pop().v
  553. tokens.pop()
  554. index = self.parse_expression(tokens)
  555. tokens.pop()
  556. tokens.pop()
  557. return ArrayAssignment(name, index, self.parse_expression(tokens))
  558. return Execute(self.parse_expression(tokens))
  559. def parse(self, code):
  560. statements = []
  561. for line in code.split("\n"):
  562. if '#' in line:
  563. line = line.split('#', 1)[0]
  564. line = line.strip(" ")
  565. if line:
  566. tokens = self.tokenize(line)
  567. statements.append(self.parse_statement(tokens))
  568. return Code(statements)
  569. def numpy_compile(code):
  570. parser = Parser()
  571. return InterpreterState(parser.parse(code))