PageRenderTime 65ms CodeModel.GetById 34ms RepoModel.GetById 1ms app.codeStats 0ms

/pypy/module/micronumpy/compile.py

https://bitbucket.org/sonwell/pypy
Python | 728 lines | 616 code | 102 blank | 10 comment | 79 complexity | c484e907ef6a0e9107817b6e030147fe MD5 | raw file
  1. """ This is a set of tools for standalone compiling of numpy expressions.
  2. It should not be imported by the module itself
  3. """
  4. import re
  5. from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root
  6. from pypy.interpreter.error import OperationError
  7. from pypy.module.micronumpy import interp_boxes
  8. from pypy.module.micronumpy.interp_dtype import get_dtype_cache
  9. from pypy.module.micronumpy.base import W_NDimArray
  10. from pypy.module.micronumpy.interp_numarray import array
  11. from pypy.module.micronumpy.interp_arrayops import where
  12. from pypy.module.micronumpy import interp_ufuncs
  13. from pypy.rlib.objectmodel import specialize, instantiate
  14. class BogusBytecode(Exception):
  15. pass
  16. class ArgumentMismatch(Exception):
  17. pass
  18. class ArgumentNotAnArray(Exception):
  19. pass
  20. class WrongFunctionName(Exception):
  21. pass
  22. class TokenizerError(Exception):
  23. pass
  24. class BadToken(Exception):
  25. pass
  26. SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
  27. "unegative", "flat", "tostring","count_nonzero"]
  28. TWO_ARG_FUNCTIONS = ["dot", 'take']
  29. THREE_ARG_FUNCTIONS = ['where']
  30. class FakeSpace(object):
  31. w_ValueError = "ValueError"
  32. w_TypeError = "TypeError"
  33. w_IndexError = "IndexError"
  34. w_OverflowError = "OverflowError"
  35. w_NotImplementedError = "NotImplementedError"
  36. w_None = None
  37. w_bool = "bool"
  38. w_int = "int"
  39. w_float = "float"
  40. w_list = "list"
  41. w_long = "long"
  42. w_tuple = 'tuple'
  43. w_slice = "slice"
  44. w_str = "str"
  45. w_unicode = "unicode"
  46. w_complex = "complex"
  47. def __init__(self):
  48. """NOT_RPYTHON"""
  49. self.fromcache = InternalSpaceCache(self).getorbuild
  50. def _freeze_(self):
  51. return True
  52. def is_none(self, w_obj):
  53. return w_obj is None or w_obj is self.w_None
  54. def issequence_w(self, w_obj):
  55. return isinstance(w_obj, ListObject) or isinstance(w_obj, W_NDimArray)
  56. def isinstance_w(self, w_obj, w_tp):
  57. return w_obj.tp == w_tp
  58. def decode_index4(self, w_idx, size):
  59. if isinstance(w_idx, IntObject):
  60. return (self.int_w(w_idx), 0, 0, 1)
  61. else:
  62. assert isinstance(w_idx, SliceObject)
  63. start, stop, step = w_idx.start, w_idx.stop, w_idx.step
  64. if step == 0:
  65. return (0, size, 1, size)
  66. if start < 0:
  67. start += size
  68. if stop < 0:
  69. stop += size + 1
  70. if step < 0:
  71. lgt = (stop - start + 1) / step + 1
  72. else:
  73. lgt = (stop - start - 1) / step + 1
  74. return (start, stop, step, lgt)
  75. @specialize.argtype(1)
  76. def wrap(self, obj):
  77. if isinstance(obj, float):
  78. return FloatObject(obj)
  79. elif isinstance(obj, bool):
  80. return BoolObject(obj)
  81. elif isinstance(obj, int):
  82. return IntObject(obj)
  83. elif isinstance(obj, long):
  84. return LongObject(obj)
  85. elif isinstance(obj, W_Root):
  86. return obj
  87. elif isinstance(obj, str):
  88. return StringObject(obj)
  89. raise NotImplementedError
  90. def newlist(self, items):
  91. return ListObject(items)
  92. def newcomplex(self, r, i):
  93. return ComplexObject(r, i)
  94. def listview(self, obj):
  95. assert isinstance(obj, ListObject)
  96. return obj.items
  97. fixedview = listview
  98. def float(self, w_obj):
  99. if isinstance(w_obj, FloatObject):
  100. return w_obj
  101. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  102. return self.float(w_obj.descr_float(self))
  103. def float_w(self, w_obj):
  104. assert isinstance(w_obj, FloatObject)
  105. return w_obj.floatval
  106. def int_w(self, w_obj):
  107. if isinstance(w_obj, IntObject):
  108. return w_obj.intval
  109. elif isinstance(w_obj, FloatObject):
  110. return int(w_obj.floatval)
  111. elif isinstance(w_obj, SliceObject):
  112. raise OperationError(self.w_TypeError, self.wrap("slice."))
  113. raise NotImplementedError
  114. def unpackcomplex(self, w_obj):
  115. if isinstance(w_obj, ComplexObject):
  116. return w_obj.r, w_obj.i
  117. raise NotImplementedError
  118. def index(self, w_obj):
  119. return self.wrap(self.int_w(w_obj))
  120. def str_w(self, w_obj):
  121. if isinstance(w_obj, StringObject):
  122. return w_obj.v
  123. raise NotImplementedError
  124. def int(self, w_obj):
  125. if isinstance(w_obj, IntObject):
  126. return w_obj
  127. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  128. return self.int(w_obj.descr_int(self))
  129. def str(self, w_obj):
  130. if isinstance(w_obj, StringObject):
  131. return w_obj
  132. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  133. return self.str(w_obj.descr_str(self))
  134. def is_true(self, w_obj):
  135. assert isinstance(w_obj, BoolObject)
  136. return False
  137. #return w_obj.boolval
  138. def is_w(self, w_obj, w_what):
  139. return w_obj is w_what
  140. def type(self, w_obj):
  141. return w_obj.tp
  142. def gettypefor(self, w_obj):
  143. return None
  144. def call_function(self, tp, w_dtype):
  145. return w_dtype
  146. @specialize.arg(1)
  147. def interp_w(self, tp, what):
  148. assert isinstance(what, tp)
  149. return what
  150. def allocate_instance(self, klass, w_subtype):
  151. return instantiate(klass)
  152. def newtuple(self, list_w):
  153. return ListObject(list_w)
  154. def newdict(self):
  155. return {}
  156. def setitem(self, dict, item, value):
  157. dict[item] = value
  158. def len_w(self, w_obj):
  159. if isinstance(w_obj, ListObject):
  160. return len(w_obj.items)
  161. # XXX array probably
  162. assert False
  163. def exception_match(self, w_exc_type, w_check_class):
  164. # Good enough for now
  165. raise NotImplementedError
  166. class FloatObject(W_Root):
  167. tp = FakeSpace.w_float
  168. def __init__(self, floatval):
  169. self.floatval = floatval
  170. class BoolObject(W_Root):
  171. tp = FakeSpace.w_bool
  172. def __init__(self, boolval):
  173. self.boolval = boolval
  174. class IntObject(W_Root):
  175. tp = FakeSpace.w_int
  176. def __init__(self, intval):
  177. self.intval = intval
  178. class LongObject(W_Root):
  179. tp = FakeSpace.w_long
  180. def __init__(self, intval):
  181. self.intval = intval
  182. class ListObject(W_Root):
  183. tp = FakeSpace.w_list
  184. def __init__(self, items):
  185. self.items = items
  186. class SliceObject(W_Root):
  187. tp = FakeSpace.w_slice
  188. def __init__(self, start, stop, step):
  189. self.start = start
  190. self.stop = stop
  191. self.step = step
  192. class StringObject(W_Root):
  193. tp = FakeSpace.w_str
  194. def __init__(self, v):
  195. self.v = v
  196. class ComplexObject(W_Root):
  197. tp = FakeSpace.w_complex
  198. def __init__(self, r, i):
  199. self.r = r
  200. self.i = i
  201. class InterpreterState(object):
  202. def __init__(self, code):
  203. self.code = code
  204. self.variables = {}
  205. self.results = []
  206. def run(self, space):
  207. self.space = space
  208. for stmt in self.code.statements:
  209. stmt.execute(self)
  210. class Node(object):
  211. def __eq__(self, other):
  212. return (self.__class__ == other.__class__ and
  213. self.__dict__ == other.__dict__)
  214. def __ne__(self, other):
  215. return not self == other
  216. def wrap(self, space):
  217. raise NotImplementedError
  218. def execute(self, interp):
  219. raise NotImplementedError
  220. class Assignment(Node):
  221. def __init__(self, name, expr):
  222. self.name = name
  223. self.expr = expr
  224. def execute(self, interp):
  225. interp.variables[self.name] = self.expr.execute(interp)
  226. def __repr__(self):
  227. return "%r = %r" % (self.name, self.expr)
  228. class ArrayAssignment(Node):
  229. def __init__(self, name, index, expr):
  230. self.name = name
  231. self.index = index
  232. self.expr = expr
  233. def execute(self, interp):
  234. arr = interp.variables[self.name]
  235. w_index = self.index.execute(interp)
  236. # cast to int
  237. if isinstance(w_index, FloatObject):
  238. w_index = IntObject(int(w_index.floatval))
  239. w_val = self.expr.execute(interp)
  240. assert isinstance(arr, W_NDimArray)
  241. arr.descr_setitem(interp.space, w_index, w_val)
  242. def __repr__(self):
  243. return "%s[%r] = %r" % (self.name, self.index, self.expr)
  244. class Variable(Node):
  245. def __init__(self, name):
  246. self.name = name.strip(" ")
  247. def execute(self, interp):
  248. return interp.variables[self.name]
  249. def __repr__(self):
  250. return 'v(%s)' % self.name
  251. class Operator(Node):
  252. def __init__(self, lhs, name, rhs):
  253. self.name = name
  254. self.lhs = lhs
  255. self.rhs = rhs
  256. def execute(self, interp):
  257. w_lhs = self.lhs.execute(interp)
  258. if isinstance(self.rhs, SliceConstant):
  259. w_rhs = self.rhs.wrap(interp.space)
  260. else:
  261. w_rhs = self.rhs.execute(interp)
  262. if not isinstance(w_lhs, W_NDimArray):
  263. # scalar
  264. dtype = get_dtype_cache(interp.space).w_float64dtype
  265. w_lhs = W_NDimArray.new_scalar(interp.space, dtype, w_lhs)
  266. assert isinstance(w_lhs, W_NDimArray)
  267. if self.name == '+':
  268. w_res = w_lhs.descr_add(interp.space, w_rhs)
  269. elif self.name == '*':
  270. w_res = w_lhs.descr_mul(interp.space, w_rhs)
  271. elif self.name == '-':
  272. w_res = w_lhs.descr_sub(interp.space, w_rhs)
  273. elif self.name == '->':
  274. if isinstance(w_rhs, FloatObject):
  275. w_rhs = IntObject(int(w_rhs.floatval))
  276. assert isinstance(w_lhs, W_NDimArray)
  277. w_res = w_lhs.descr_getitem(interp.space, w_rhs)
  278. else:
  279. raise NotImplementedError
  280. if (not isinstance(w_res, W_NDimArray) and
  281. not isinstance(w_res, interp_boxes.W_GenericBox)):
  282. dtype = get_dtype_cache(interp.space).w_float64dtype
  283. w_res = W_NDimArray.new_scalar(interp.space, dtype, w_res)
  284. return w_res
  285. def __repr__(self):
  286. return '(%r %s %r)' % (self.lhs, self.name, self.rhs)
  287. class FloatConstant(Node):
  288. def __init__(self, v):
  289. self.v = float(v)
  290. def __repr__(self):
  291. return "Const(%s)" % self.v
  292. def wrap(self, space):
  293. return space.wrap(self.v)
  294. def execute(self, interp):
  295. return interp.space.wrap(self.v)
  296. class ComplexConstant(Node):
  297. def __init__(self, r, i):
  298. self.r = float(r)
  299. self.i = float(i)
  300. def __repr__(self):
  301. return 'ComplexConst(%s, %s)' % (self.r, self.i)
  302. def wrap(self, space):
  303. return space.newcomplex(self.r, self.i)
  304. def execute(self, interp):
  305. return self.wrap(interp.space)
  306. class RangeConstant(Node):
  307. def __init__(self, v):
  308. self.v = int(v)
  309. def execute(self, interp):
  310. w_list = interp.space.newlist(
  311. [interp.space.wrap(float(i)) for i in range(self.v)]
  312. )
  313. dtype = get_dtype_cache(interp.space).w_float64dtype
  314. return array(interp.space, w_list, w_dtype=dtype, w_order=None)
  315. def __repr__(self):
  316. return 'Range(%s)' % self.v
  317. class Code(Node):
  318. def __init__(self, statements):
  319. self.statements = statements
  320. def __repr__(self):
  321. return "\n".join([repr(i) for i in self.statements])
  322. class ArrayConstant(Node):
  323. def __init__(self, items):
  324. self.items = items
  325. def wrap(self, space):
  326. return space.newlist([item.wrap(space) for item in self.items])
  327. def execute(self, interp):
  328. w_list = self.wrap(interp.space)
  329. return array(interp.space, w_list)
  330. def __repr__(self):
  331. return "[" + ", ".join([repr(item) for item in self.items]) + "]"
  332. class SliceConstant(Node):
  333. def __init__(self, start, stop, step):
  334. # no negative support for now
  335. self.start = start
  336. self.stop = stop
  337. self.step = step
  338. def wrap(self, space):
  339. return SliceObject(self.start, self.stop, self.step)
  340. def execute(self, interp):
  341. return SliceObject(self.start, self.stop, self.step)
  342. def __repr__(self):
  343. return 'slice(%s,%s,%s)' % (self.start, self.stop, self.step)
  344. class Execute(Node):
  345. def __init__(self, expr):
  346. self.expr = expr
  347. def __repr__(self):
  348. return repr(self.expr)
  349. def execute(self, interp):
  350. interp.results.append(self.expr.execute(interp))
  351. class FunctionCall(Node):
  352. def __init__(self, name, args):
  353. self.name = name.strip(" ")
  354. self.args = args
  355. def __repr__(self):
  356. return "%s(%s)" % (self.name, ", ".join([repr(arg)
  357. for arg in self.args]))
  358. def execute(self, interp):
  359. arr = self.args[0].execute(interp)
  360. if not isinstance(arr, W_NDimArray):
  361. raise ArgumentNotAnArray
  362. if self.name in SINGLE_ARG_FUNCTIONS:
  363. if len(self.args) != 1 and self.name != 'sum':
  364. raise ArgumentMismatch
  365. if self.name == "sum":
  366. if len(self.args)>1:
  367. w_res = arr.descr_sum(interp.space,
  368. self.args[1].execute(interp))
  369. else:
  370. w_res = arr.descr_sum(interp.space)
  371. elif self.name == "prod":
  372. w_res = arr.descr_prod(interp.space)
  373. elif self.name == "max":
  374. w_res = arr.descr_max(interp.space)
  375. elif self.name == "min":
  376. w_res = arr.descr_min(interp.space)
  377. elif self.name == "any":
  378. w_res = arr.descr_any(interp.space)
  379. elif self.name == "all":
  380. w_res = arr.descr_all(interp.space)
  381. elif self.name == "unegative":
  382. neg = interp_ufuncs.get(interp.space).negative
  383. w_res = neg.call(interp.space, [arr])
  384. elif self.name == "cos":
  385. cos = interp_ufuncs.get(interp.space).cos
  386. w_res = cos.call(interp.space, [arr])
  387. elif self.name == "flat":
  388. w_res = arr.descr_get_flatiter(interp.space)
  389. elif self.name == "tostring":
  390. arr.descr_tostring(interp.space)
  391. w_res = None
  392. else:
  393. assert False # unreachable code
  394. elif self.name in TWO_ARG_FUNCTIONS:
  395. if len(self.args) != 2:
  396. raise ArgumentMismatch
  397. arg = self.args[1].execute(interp)
  398. if not isinstance(arg, W_NDimArray):
  399. raise ArgumentNotAnArray
  400. if self.name == "dot":
  401. w_res = arr.descr_dot(interp.space, arg)
  402. elif self.name == 'take':
  403. w_res = arr.descr_take(interp.space, arg)
  404. else:
  405. assert False # unreachable code
  406. elif self.name in THREE_ARG_FUNCTIONS:
  407. if len(self.args) != 3:
  408. raise ArgumentMismatch
  409. arg1 = self.args[1].execute(interp)
  410. arg2 = self.args[2].execute(interp)
  411. if not isinstance(arg1, W_NDimArray):
  412. raise ArgumentNotAnArray
  413. if not isinstance(arg2, W_NDimArray):
  414. raise ArgumentNotAnArray
  415. if self.name == "where":
  416. w_res = where(interp.space, arr, arg1, arg2)
  417. else:
  418. assert False
  419. else:
  420. raise WrongFunctionName
  421. if isinstance(w_res, W_NDimArray):
  422. return w_res
  423. if isinstance(w_res, FloatObject):
  424. dtype = get_dtype_cache(interp.space).w_float64dtype
  425. elif isinstance(w_res, IntObject):
  426. dtype = get_dtype_cache(interp.space).w_int64dtype
  427. elif isinstance(w_res, BoolObject):
  428. dtype = get_dtype_cache(interp.space).w_booldtype
  429. elif isinstance(w_res, interp_boxes.W_GenericBox):
  430. dtype = w_res.get_dtype(interp.space)
  431. else:
  432. dtype = None
  433. return W_NDimArray.new_scalar(interp.space, dtype, w_res)
  434. _REGEXES = [
  435. ('-?[\d\.]+', 'number'),
  436. ('\[', 'array_left'),
  437. (':', 'colon'),
  438. ('\w+', 'identifier'),
  439. ('\]', 'array_right'),
  440. ('(->)|[\+\-\*\/]', 'operator'),
  441. ('=', 'assign'),
  442. (',', 'comma'),
  443. ('\|', 'pipe'),
  444. ('\(', 'paren_left'),
  445. ('\)', 'paren_right'),
  446. ]
  447. REGEXES = []
  448. for r, name in _REGEXES:
  449. REGEXES.append((re.compile(r' *(' + r + ')'), name))
  450. del _REGEXES
  451. class Token(object):
  452. def __init__(self, name, v):
  453. self.name = name
  454. self.v = v
  455. def __repr__(self):
  456. return '(%s, %s)' % (self.name, self.v)
  457. empty = Token('', '')
  458. class TokenStack(object):
  459. def __init__(self, tokens):
  460. self.tokens = tokens
  461. self.c = 0
  462. def pop(self):
  463. token = self.tokens[self.c]
  464. self.c += 1
  465. return token
  466. def get(self, i):
  467. if self.c + i >= len(self.tokens):
  468. return empty
  469. return self.tokens[self.c + i]
  470. def remaining(self):
  471. return len(self.tokens) - self.c
  472. def push(self):
  473. self.c -= 1
  474. def __repr__(self):
  475. return repr(self.tokens[self.c:])
  476. class Parser(object):
  477. def tokenize(self, line):
  478. tokens = []
  479. while True:
  480. for r, name in REGEXES:
  481. m = r.match(line)
  482. if m is not None:
  483. g = m.group(0)
  484. tokens.append(Token(name, g))
  485. line = line[len(g):]
  486. if not line:
  487. return TokenStack(tokens)
  488. break
  489. else:
  490. raise TokenizerError(line)
  491. def parse_number_or_slice(self, tokens):
  492. start_tok = tokens.pop()
  493. if start_tok.name == 'colon':
  494. start = 0
  495. else:
  496. if tokens.get(0).name != 'colon':
  497. return FloatConstant(start_tok.v)
  498. start = int(start_tok.v)
  499. tokens.pop()
  500. if not tokens.get(0).name in ['colon', 'number']:
  501. stop = -1
  502. step = 1
  503. else:
  504. next = tokens.pop()
  505. if next.name == 'colon':
  506. stop = -1
  507. step = int(tokens.pop().v)
  508. else:
  509. stop = int(next.v)
  510. if tokens.get(0).name == 'colon':
  511. tokens.pop()
  512. step = int(tokens.pop().v)
  513. else:
  514. step = 1
  515. return SliceConstant(start, stop, step)
  516. def parse_expression(self, tokens, accept_comma=False):
  517. stack = []
  518. while tokens.remaining():
  519. token = tokens.pop()
  520. if token.name == 'identifier':
  521. if tokens.remaining() and tokens.get(0).name == 'paren_left':
  522. stack.append(self.parse_function_call(token.v, tokens))
  523. else:
  524. stack.append(Variable(token.v))
  525. elif token.name == 'array_left':
  526. stack.append(ArrayConstant(self.parse_array_const(tokens)))
  527. elif token.name == 'operator':
  528. stack.append(Variable(token.v))
  529. elif token.name == 'number' or token.name == 'colon':
  530. tokens.push()
  531. stack.append(self.parse_number_or_slice(tokens))
  532. elif token.name == 'pipe':
  533. stack.append(RangeConstant(tokens.pop().v))
  534. end = tokens.pop()
  535. assert end.name == 'pipe'
  536. elif token.name == 'paren_left':
  537. stack.append(self.parse_complex_constant(tokens))
  538. elif accept_comma and token.name == 'comma':
  539. continue
  540. else:
  541. tokens.push()
  542. break
  543. if accept_comma:
  544. return stack
  545. stack.reverse()
  546. lhs = stack.pop()
  547. while stack:
  548. op = stack.pop()
  549. assert isinstance(op, Variable)
  550. rhs = stack.pop()
  551. lhs = Operator(lhs, op.name, rhs)
  552. return lhs
  553. def parse_function_call(self, name, tokens):
  554. args = []
  555. tokens.pop() # lparen
  556. while tokens.get(0).name != 'paren_right':
  557. args += self.parse_expression(tokens, accept_comma=True)
  558. return FunctionCall(name, args)
  559. def parse_complex_constant(self, tokens):
  560. r = tokens.pop()
  561. assert r.name == 'number'
  562. assert tokens.pop().name == 'comma'
  563. i = tokens.pop()
  564. assert i.name == 'number'
  565. assert tokens.pop().name == 'paren_right'
  566. return ComplexConstant(r.v, i.v)
  567. def parse_array_const(self, tokens):
  568. elems = []
  569. while True:
  570. token = tokens.pop()
  571. if token.name == 'number':
  572. elems.append(FloatConstant(token.v))
  573. elif token.name == 'array_left':
  574. elems.append(ArrayConstant(self.parse_array_const(tokens)))
  575. elif token.name == 'paren_left':
  576. elems.append(self.parse_complex_constant(tokens))
  577. else:
  578. raise BadToken()
  579. token = tokens.pop()
  580. if token.name == 'array_right':
  581. return elems
  582. assert token.name == 'comma'
  583. def parse_statement(self, tokens):
  584. if (tokens.get(0).name == 'identifier' and
  585. tokens.get(1).name == 'assign'):
  586. lhs = tokens.pop().v
  587. tokens.pop()
  588. rhs = self.parse_expression(tokens)
  589. return Assignment(lhs, rhs)
  590. elif (tokens.get(0).name == 'identifier' and
  591. tokens.get(1).name == 'array_left'):
  592. name = tokens.pop().v
  593. tokens.pop()
  594. index = self.parse_expression(tokens)
  595. tokens.pop()
  596. tokens.pop()
  597. return ArrayAssignment(name, index, self.parse_expression(tokens))
  598. return Execute(self.parse_expression(tokens))
  599. def parse(self, code):
  600. statements = []
  601. for line in code.split("\n"):
  602. if '#' in line:
  603. line = line.split('#', 1)[0]
  604. line = line.strip(" ")
  605. if line:
  606. tokens = self.tokenize(line)
  607. statements.append(self.parse_statement(tokens))
  608. return Code(statements)
  609. def numpy_compile(code):
  610. parser = Parser()
  611. return InterpreterState(parser.parse(code))