PageRenderTime 51ms CodeModel.GetById 20ms RepoModel.GetById 0ms app.codeStats 1ms

/pypy/module/micronumpy/compile.py

https://bitbucket.org/jerith/pypy
Python | 679 lines | 576 code | 93 blank | 10 comment | 76 complexity | 59dbb6765a2da3f06618164d9d77c634 MD5 | raw file
  1. """ This is a set of tools for standalone compiling of numpy expressions.
  2. It should not be imported by the module itself
  3. """
  4. import re
  5. from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root
  6. from pypy.interpreter.error import OperationError
  7. from pypy.module.micronumpy import interp_boxes
  8. from pypy.module.micronumpy.interp_dtype import get_dtype_cache
  9. from pypy.module.micronumpy.base import W_NDimArray
  10. from pypy.module.micronumpy.interp_numarray import array
  11. from pypy.module.micronumpy.interp_arrayops import where
  12. from pypy.module.micronumpy import interp_ufuncs
  13. from pypy.rlib.objectmodel import specialize, instantiate
  14. class BogusBytecode(Exception):
  15. pass
  16. class ArgumentMismatch(Exception):
  17. pass
  18. class ArgumentNotAnArray(Exception):
  19. pass
  20. class WrongFunctionName(Exception):
  21. pass
  22. class TokenizerError(Exception):
  23. pass
  24. class BadToken(Exception):
  25. pass
  26. SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
  27. "unegative", "flat", "tostring","count_nonzero"]
  28. TWO_ARG_FUNCTIONS = ["dot", 'take']
  29. THREE_ARG_FUNCTIONS = ['where']
  30. class FakeSpace(object):
  31. w_ValueError = "ValueError"
  32. w_TypeError = "TypeError"
  33. w_IndexError = "IndexError"
  34. w_OverflowError = "OverflowError"
  35. w_NotImplementedError = "NotImplementedError"
  36. w_None = None
  37. w_bool = "bool"
  38. w_int = "int"
  39. w_float = "float"
  40. w_list = "list"
  41. w_long = "long"
  42. w_tuple = 'tuple'
  43. w_slice = "slice"
  44. w_str = "str"
  45. w_unicode = "unicode"
  46. w_complex = "complex"
  47. def __init__(self):
  48. """NOT_RPYTHON"""
  49. self.fromcache = InternalSpaceCache(self).getorbuild
  50. def _freeze_(self):
  51. return True
  52. def issequence_w(self, w_obj):
  53. return isinstance(w_obj, ListObject) or isinstance(w_obj, W_NDimArray)
  54. def isinstance_w(self, w_obj, w_tp):
  55. return w_obj.tp == w_tp
  56. def decode_index4(self, w_idx, size):
  57. if isinstance(w_idx, IntObject):
  58. return (self.int_w(w_idx), 0, 0, 1)
  59. else:
  60. assert isinstance(w_idx, SliceObject)
  61. start, stop, step = w_idx.start, w_idx.stop, w_idx.step
  62. if step == 0:
  63. return (0, size, 1, size)
  64. if start < 0:
  65. start += size
  66. if stop < 0:
  67. stop += size + 1
  68. if step < 0:
  69. lgt = (stop - start + 1) / step + 1
  70. else:
  71. lgt = (stop - start - 1) / step + 1
  72. return (start, stop, step, lgt)
  73. @specialize.argtype(1)
  74. def wrap(self, obj):
  75. if isinstance(obj, float):
  76. return FloatObject(obj)
  77. elif isinstance(obj, bool):
  78. return BoolObject(obj)
  79. elif isinstance(obj, int):
  80. return IntObject(obj)
  81. elif isinstance(obj, long):
  82. return LongObject(obj)
  83. elif isinstance(obj, W_Root):
  84. return obj
  85. elif isinstance(obj, str):
  86. return StringObject(obj)
  87. raise NotImplementedError
  88. def newlist(self, items):
  89. return ListObject(items)
  90. def listview(self, obj):
  91. assert isinstance(obj, ListObject)
  92. return obj.items
  93. fixedview = listview
  94. def float(self, w_obj):
  95. if isinstance(w_obj, FloatObject):
  96. return w_obj
  97. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  98. return self.float(w_obj.descr_float(self))
  99. def float_w(self, w_obj):
  100. assert isinstance(w_obj, FloatObject)
  101. return w_obj.floatval
  102. def int_w(self, w_obj):
  103. if isinstance(w_obj, IntObject):
  104. return w_obj.intval
  105. elif isinstance(w_obj, FloatObject):
  106. return int(w_obj.floatval)
  107. elif isinstance(w_obj, SliceObject):
  108. raise OperationError(self.w_TypeError, self.wrap("slice."))
  109. raise NotImplementedError
  110. def index(self, w_obj):
  111. return self.wrap(self.int_w(w_obj))
  112. def str_w(self, w_obj):
  113. if isinstance(w_obj, StringObject):
  114. return w_obj.v
  115. raise NotImplementedError
  116. def int(self, w_obj):
  117. if isinstance(w_obj, IntObject):
  118. return w_obj
  119. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  120. return self.int(w_obj.descr_int(self))
  121. def is_true(self, w_obj):
  122. assert isinstance(w_obj, BoolObject)
  123. return False
  124. #return w_obj.boolval
  125. def is_w(self, w_obj, w_what):
  126. return w_obj is w_what
  127. def type(self, w_obj):
  128. return w_obj.tp
  129. def gettypefor(self, w_obj):
  130. return None
  131. def call_function(self, tp, w_dtype):
  132. return w_dtype
  133. @specialize.arg(1)
  134. def interp_w(self, tp, what):
  135. assert isinstance(what, tp)
  136. return what
  137. def allocate_instance(self, klass, w_subtype):
  138. return instantiate(klass)
  139. def newtuple(self, list_w):
  140. return ListObject(list_w)
  141. def newdict(self):
  142. return {}
  143. def setitem(self, dict, item, value):
  144. dict[item] = value
  145. def len_w(self, w_obj):
  146. if isinstance(w_obj, ListObject):
  147. return len(w_obj.items)
  148. # XXX array probably
  149. assert False
  150. def exception_match(self, w_exc_type, w_check_class):
  151. # Good enough for now
  152. raise NotImplementedError
  153. class FloatObject(W_Root):
  154. tp = FakeSpace.w_float
  155. def __init__(self, floatval):
  156. self.floatval = floatval
  157. class BoolObject(W_Root):
  158. tp = FakeSpace.w_bool
  159. def __init__(self, boolval):
  160. self.boolval = boolval
  161. class IntObject(W_Root):
  162. tp = FakeSpace.w_int
  163. def __init__(self, intval):
  164. self.intval = intval
  165. class LongObject(W_Root):
  166. tp = FakeSpace.w_long
  167. def __init__(self, intval):
  168. self.intval = intval
  169. class ListObject(W_Root):
  170. tp = FakeSpace.w_list
  171. def __init__(self, items):
  172. self.items = items
  173. class SliceObject(W_Root):
  174. tp = FakeSpace.w_slice
  175. def __init__(self, start, stop, step):
  176. self.start = start
  177. self.stop = stop
  178. self.step = step
  179. class StringObject(W_Root):
  180. tp = FakeSpace.w_str
  181. def __init__(self, v):
  182. self.v = v
  183. class InterpreterState(object):
  184. def __init__(self, code):
  185. self.code = code
  186. self.variables = {}
  187. self.results = []
  188. def run(self, space):
  189. self.space = space
  190. for stmt in self.code.statements:
  191. stmt.execute(self)
  192. class Node(object):
  193. def __eq__(self, other):
  194. return (self.__class__ == other.__class__ and
  195. self.__dict__ == other.__dict__)
  196. def __ne__(self, other):
  197. return not self == other
  198. def wrap(self, space):
  199. raise NotImplementedError
  200. def execute(self, interp):
  201. raise NotImplementedError
  202. class Assignment(Node):
  203. def __init__(self, name, expr):
  204. self.name = name
  205. self.expr = expr
  206. def execute(self, interp):
  207. interp.variables[self.name] = self.expr.execute(interp)
  208. def __repr__(self):
  209. return "%r = %r" % (self.name, self.expr)
  210. class ArrayAssignment(Node):
  211. def __init__(self, name, index, expr):
  212. self.name = name
  213. self.index = index
  214. self.expr = expr
  215. def execute(self, interp):
  216. arr = interp.variables[self.name]
  217. w_index = self.index.execute(interp)
  218. # cast to int
  219. if isinstance(w_index, FloatObject):
  220. w_index = IntObject(int(w_index.floatval))
  221. w_val = self.expr.execute(interp)
  222. assert isinstance(arr, W_NDimArray)
  223. arr.descr_setitem(interp.space, w_index, w_val)
  224. def __repr__(self):
  225. return "%s[%r] = %r" % (self.name, self.index, self.expr)
  226. class Variable(Node):
  227. def __init__(self, name):
  228. self.name = name.strip(" ")
  229. def execute(self, interp):
  230. return interp.variables[self.name]
  231. def __repr__(self):
  232. return 'v(%s)' % self.name
  233. class Operator(Node):
  234. def __init__(self, lhs, name, rhs):
  235. self.name = name
  236. self.lhs = lhs
  237. self.rhs = rhs
  238. def execute(self, interp):
  239. w_lhs = self.lhs.execute(interp)
  240. if isinstance(self.rhs, SliceConstant):
  241. w_rhs = self.rhs.wrap(interp.space)
  242. else:
  243. w_rhs = self.rhs.execute(interp)
  244. if not isinstance(w_lhs, W_NDimArray):
  245. # scalar
  246. dtype = get_dtype_cache(interp.space).w_float64dtype
  247. w_lhs = W_NDimArray.new_scalar(interp.space, dtype, w_lhs)
  248. assert isinstance(w_lhs, W_NDimArray)
  249. if self.name == '+':
  250. w_res = w_lhs.descr_add(interp.space, w_rhs)
  251. elif self.name == '*':
  252. w_res = w_lhs.descr_mul(interp.space, w_rhs)
  253. elif self.name == '-':
  254. w_res = w_lhs.descr_sub(interp.space, w_rhs)
  255. elif self.name == '->':
  256. if isinstance(w_rhs, FloatObject):
  257. w_rhs = IntObject(int(w_rhs.floatval))
  258. assert isinstance(w_lhs, W_NDimArray)
  259. w_res = w_lhs.descr_getitem(interp.space, w_rhs)
  260. else:
  261. raise NotImplementedError
  262. if (not isinstance(w_res, W_NDimArray) and
  263. not isinstance(w_res, interp_boxes.W_GenericBox)):
  264. dtype = get_dtype_cache(interp.space).w_float64dtype
  265. w_res = W_NDimArray.new_scalar(interp.space, dtype, w_res)
  266. return w_res
  267. def __repr__(self):
  268. return '(%r %s %r)' % (self.lhs, self.name, self.rhs)
  269. class FloatConstant(Node):
  270. def __init__(self, v):
  271. self.v = float(v)
  272. def __repr__(self):
  273. return "Const(%s)" % self.v
  274. def wrap(self, space):
  275. return space.wrap(self.v)
  276. def execute(self, interp):
  277. return interp.space.wrap(self.v)
  278. class RangeConstant(Node):
  279. def __init__(self, v):
  280. self.v = int(v)
  281. def execute(self, interp):
  282. w_list = interp.space.newlist(
  283. [interp.space.wrap(float(i)) for i in range(self.v)]
  284. )
  285. dtype = get_dtype_cache(interp.space).w_float64dtype
  286. return array(interp.space, w_list, w_dtype=dtype, w_order=None)
  287. def __repr__(self):
  288. return 'Range(%s)' % self.v
  289. class Code(Node):
  290. def __init__(self, statements):
  291. self.statements = statements
  292. def __repr__(self):
  293. return "\n".join([repr(i) for i in self.statements])
  294. class ArrayConstant(Node):
  295. def __init__(self, items):
  296. self.items = items
  297. def wrap(self, space):
  298. return space.newlist([item.wrap(space) for item in self.items])
  299. def execute(self, interp):
  300. w_list = self.wrap(interp.space)
  301. dtype = get_dtype_cache(interp.space).w_float64dtype
  302. return array(interp.space, w_list, w_dtype=dtype, w_order=None)
  303. def __repr__(self):
  304. return "[" + ", ".join([repr(item) for item in self.items]) + "]"
  305. class SliceConstant(Node):
  306. def __init__(self, start, stop, step):
  307. # no negative support for now
  308. self.start = start
  309. self.stop = stop
  310. self.step = step
  311. def wrap(self, space):
  312. return SliceObject(self.start, self.stop, self.step)
  313. def execute(self, interp):
  314. return SliceObject(self.start, self.stop, self.step)
  315. def __repr__(self):
  316. return 'slice(%s,%s,%s)' % (self.start, self.stop, self.step)
  317. class Execute(Node):
  318. def __init__(self, expr):
  319. self.expr = expr
  320. def __repr__(self):
  321. return repr(self.expr)
  322. def execute(self, interp):
  323. interp.results.append(self.expr.execute(interp))
  324. class FunctionCall(Node):
  325. def __init__(self, name, args):
  326. self.name = name.strip(" ")
  327. self.args = args
  328. def __repr__(self):
  329. return "%s(%s)" % (self.name, ", ".join([repr(arg)
  330. for arg in self.args]))
  331. def execute(self, interp):
  332. arr = self.args[0].execute(interp)
  333. if not isinstance(arr, W_NDimArray):
  334. raise ArgumentNotAnArray
  335. if self.name in SINGLE_ARG_FUNCTIONS:
  336. if len(self.args) != 1 and self.name != 'sum':
  337. raise ArgumentMismatch
  338. if self.name == "sum":
  339. if len(self.args)>1:
  340. w_res = arr.descr_sum(interp.space,
  341. self.args[1].execute(interp))
  342. else:
  343. w_res = arr.descr_sum(interp.space)
  344. elif self.name == "prod":
  345. w_res = arr.descr_prod(interp.space)
  346. elif self.name == "max":
  347. w_res = arr.descr_max(interp.space)
  348. elif self.name == "min":
  349. w_res = arr.descr_min(interp.space)
  350. elif self.name == "any":
  351. w_res = arr.descr_any(interp.space)
  352. elif self.name == "all":
  353. w_res = arr.descr_all(interp.space)
  354. elif self.name == "unegative":
  355. neg = interp_ufuncs.get(interp.space).negative
  356. w_res = neg.call(interp.space, [arr])
  357. elif self.name == "cos":
  358. cos = interp_ufuncs.get(interp.space).cos
  359. w_res = cos.call(interp.space, [arr])
  360. elif self.name == "flat":
  361. w_res = arr.descr_get_flatiter(interp.space)
  362. elif self.name == "tostring":
  363. arr.descr_tostring(interp.space)
  364. w_res = None
  365. else:
  366. assert False # unreachable code
  367. elif self.name in TWO_ARG_FUNCTIONS:
  368. if len(self.args) != 2:
  369. raise ArgumentMismatch
  370. arg = self.args[1].execute(interp)
  371. if not isinstance(arg, W_NDimArray):
  372. raise ArgumentNotAnArray
  373. if self.name == "dot":
  374. w_res = arr.descr_dot(interp.space, arg)
  375. elif self.name == 'take':
  376. w_res = arr.descr_take(interp.space, arg)
  377. else:
  378. assert False # unreachable code
  379. elif self.name in THREE_ARG_FUNCTIONS:
  380. if len(self.args) != 3:
  381. raise ArgumentMismatch
  382. arg1 = self.args[1].execute(interp)
  383. arg2 = self.args[2].execute(interp)
  384. if not isinstance(arg1, W_NDimArray):
  385. raise ArgumentNotAnArray
  386. if not isinstance(arg2, W_NDimArray):
  387. raise ArgumentNotAnArray
  388. if self.name == "where":
  389. w_res = where(interp.space, arr, arg1, arg2)
  390. else:
  391. assert False
  392. else:
  393. raise WrongFunctionName
  394. if isinstance(w_res, W_NDimArray):
  395. return w_res
  396. if isinstance(w_res, FloatObject):
  397. dtype = get_dtype_cache(interp.space).w_float64dtype
  398. elif isinstance(w_res, IntObject):
  399. dtype = get_dtype_cache(interp.space).w_int64dtype
  400. elif isinstance(w_res, BoolObject):
  401. dtype = get_dtype_cache(interp.space).w_booldtype
  402. elif isinstance(w_res, interp_boxes.W_GenericBox):
  403. dtype = w_res.get_dtype(interp.space)
  404. else:
  405. dtype = None
  406. return W_NDimArray.new_scalar(interp.space, dtype, w_res)
  407. _REGEXES = [
  408. ('-?[\d\.]+', 'number'),
  409. ('\[', 'array_left'),
  410. (':', 'colon'),
  411. ('\w+', 'identifier'),
  412. ('\]', 'array_right'),
  413. ('(->)|[\+\-\*\/]', 'operator'),
  414. ('=', 'assign'),
  415. (',', 'comma'),
  416. ('\|', 'pipe'),
  417. ('\(', 'paren_left'),
  418. ('\)', 'paren_right'),
  419. ]
  420. REGEXES = []
  421. for r, name in _REGEXES:
  422. REGEXES.append((re.compile(r' *(' + r + ')'), name))
  423. del _REGEXES
  424. class Token(object):
  425. def __init__(self, name, v):
  426. self.name = name
  427. self.v = v
  428. def __repr__(self):
  429. return '(%s, %s)' % (self.name, self.v)
  430. empty = Token('', '')
  431. class TokenStack(object):
  432. def __init__(self, tokens):
  433. self.tokens = tokens
  434. self.c = 0
  435. def pop(self):
  436. token = self.tokens[self.c]
  437. self.c += 1
  438. return token
  439. def get(self, i):
  440. if self.c + i >= len(self.tokens):
  441. return empty
  442. return self.tokens[self.c + i]
  443. def remaining(self):
  444. return len(self.tokens) - self.c
  445. def push(self):
  446. self.c -= 1
  447. def __repr__(self):
  448. return repr(self.tokens[self.c:])
  449. class Parser(object):
  450. def tokenize(self, line):
  451. tokens = []
  452. while True:
  453. for r, name in REGEXES:
  454. m = r.match(line)
  455. if m is not None:
  456. g = m.group(0)
  457. tokens.append(Token(name, g))
  458. line = line[len(g):]
  459. if not line:
  460. return TokenStack(tokens)
  461. break
  462. else:
  463. raise TokenizerError(line)
  464. def parse_number_or_slice(self, tokens):
  465. start_tok = tokens.pop()
  466. if start_tok.name == 'colon':
  467. start = 0
  468. else:
  469. if tokens.get(0).name != 'colon':
  470. return FloatConstant(start_tok.v)
  471. start = int(start_tok.v)
  472. tokens.pop()
  473. if not tokens.get(0).name in ['colon', 'number']:
  474. stop = -1
  475. step = 1
  476. else:
  477. next = tokens.pop()
  478. if next.name == 'colon':
  479. stop = -1
  480. step = int(tokens.pop().v)
  481. else:
  482. stop = int(next.v)
  483. if tokens.get(0).name == 'colon':
  484. tokens.pop()
  485. step = int(tokens.pop().v)
  486. else:
  487. step = 1
  488. return SliceConstant(start, stop, step)
  489. def parse_expression(self, tokens, accept_comma=False):
  490. stack = []
  491. while tokens.remaining():
  492. token = tokens.pop()
  493. if token.name == 'identifier':
  494. if tokens.remaining() and tokens.get(0).name == 'paren_left':
  495. stack.append(self.parse_function_call(token.v, tokens))
  496. else:
  497. stack.append(Variable(token.v))
  498. elif token.name == 'array_left':
  499. stack.append(ArrayConstant(self.parse_array_const(tokens)))
  500. elif token.name == 'operator':
  501. stack.append(Variable(token.v))
  502. elif token.name == 'number' or token.name == 'colon':
  503. tokens.push()
  504. stack.append(self.parse_number_or_slice(tokens))
  505. elif token.name == 'pipe':
  506. stack.append(RangeConstant(tokens.pop().v))
  507. end = tokens.pop()
  508. assert end.name == 'pipe'
  509. elif accept_comma and token.name == 'comma':
  510. continue
  511. else:
  512. tokens.push()
  513. break
  514. if accept_comma:
  515. return stack
  516. stack.reverse()
  517. lhs = stack.pop()
  518. while stack:
  519. op = stack.pop()
  520. assert isinstance(op, Variable)
  521. rhs = stack.pop()
  522. lhs = Operator(lhs, op.name, rhs)
  523. return lhs
  524. def parse_function_call(self, name, tokens):
  525. args = []
  526. tokens.pop() # lparen
  527. while tokens.get(0).name != 'paren_right':
  528. args += self.parse_expression(tokens, accept_comma=True)
  529. return FunctionCall(name, args)
  530. def parse_array_const(self, tokens):
  531. elems = []
  532. while True:
  533. token = tokens.pop()
  534. if token.name == 'number':
  535. elems.append(FloatConstant(token.v))
  536. elif token.name == 'array_left':
  537. elems.append(ArrayConstant(self.parse_array_const(tokens)))
  538. else:
  539. raise BadToken()
  540. token = tokens.pop()
  541. if token.name == 'array_right':
  542. return elems
  543. assert token.name == 'comma'
  544. def parse_statement(self, tokens):
  545. if (tokens.get(0).name == 'identifier' and
  546. tokens.get(1).name == 'assign'):
  547. lhs = tokens.pop().v
  548. tokens.pop()
  549. rhs = self.parse_expression(tokens)
  550. return Assignment(lhs, rhs)
  551. elif (tokens.get(0).name == 'identifier' and
  552. tokens.get(1).name == 'array_left'):
  553. name = tokens.pop().v
  554. tokens.pop()
  555. index = self.parse_expression(tokens)
  556. tokens.pop()
  557. tokens.pop()
  558. return ArrayAssignment(name, index, self.parse_expression(tokens))
  559. return Execute(self.parse_expression(tokens))
  560. def parse(self, code):
  561. statements = []
  562. for line in code.split("\n"):
  563. if '#' in line:
  564. line = line.split('#', 1)[0]
  565. line = line.strip(" ")
  566. if line:
  567. tokens = self.tokenize(line)
  568. statements.append(self.parse_statement(tokens))
  569. return Code(statements)
  570. def numpy_compile(code):
  571. parser = Parser()
  572. return InterpreterState(parser.parse(code))