PageRenderTime 84ms CodeModel.GetById 32ms RepoModel.GetById 0ms app.codeStats 1ms

/pypy/module/micronumpy/compile.py

https://bitbucket.org/armisael/pypy
Python | 624 lines | 528 code | 87 blank | 9 comment | 64 complexity | 3005f4ba44679decc0174e0ad7a10734 MD5 | raw file
  1. """ This is a set of tools for standalone compiling of numpy expressions.
  2. It should not be imported by the module itself
  3. """
  4. import re
  5. from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root
  6. from pypy.module.micronumpy import interp_boxes
  7. from pypy.module.micronumpy.interp_dtype import get_dtype_cache
  8. from pypy.module.micronumpy.interp_numarray import (Scalar, BaseArray,
  9. scalar_w, W_NDimArray, array)
  10. from pypy.module.micronumpy import interp_ufuncs
  11. from pypy.rlib.objectmodel import specialize, instantiate
  12. class BogusBytecode(Exception):
  13. pass
  14. class ArgumentMismatch(Exception):
  15. pass
  16. class ArgumentNotAnArray(Exception):
  17. pass
  18. class WrongFunctionName(Exception):
  19. pass
  20. class TokenizerError(Exception):
  21. pass
  22. class BadToken(Exception):
  23. pass
  24. SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
  25. "unegative", "flat"]
  26. TWO_ARG_FUNCTIONS = ["dot", 'take']
  27. class FakeSpace(object):
  28. w_ValueError = None
  29. w_TypeError = None
  30. w_IndexError = None
  31. w_OverflowError = None
  32. w_NotImplementedError = None
  33. w_None = None
  34. w_bool = "bool"
  35. w_int = "int"
  36. w_float = "float"
  37. w_list = "list"
  38. w_long = "long"
  39. w_tuple = 'tuple'
  40. w_slice = "slice"
  41. def __init__(self):
  42. """NOT_RPYTHON"""
  43. self.fromcache = InternalSpaceCache(self).getorbuild
  44. def _freeze_(self):
  45. return True
  46. def issequence_w(self, w_obj):
  47. return isinstance(w_obj, ListObject) or isinstance(w_obj, W_NDimArray)
  48. def isinstance_w(self, w_obj, w_tp):
  49. return w_obj.tp == w_tp
  50. def decode_index4(self, w_idx, size):
  51. if isinstance(w_idx, IntObject):
  52. return (self.int_w(w_idx), 0, 0, 1)
  53. else:
  54. assert isinstance(w_idx, SliceObject)
  55. start, stop, step = w_idx.start, w_idx.stop, w_idx.step
  56. if step == 0:
  57. return (0, size, 1, size)
  58. if start < 0:
  59. start += size
  60. if stop < 0:
  61. stop += size + 1
  62. if step < 0:
  63. lgt = (stop - start + 1) / step + 1
  64. else:
  65. lgt = (stop - start - 1) / step + 1
  66. return (start, stop, step, lgt)
  67. @specialize.argtype(1)
  68. def wrap(self, obj):
  69. if isinstance(obj, float):
  70. return FloatObject(obj)
  71. elif isinstance(obj, bool):
  72. return BoolObject(obj)
  73. elif isinstance(obj, int):
  74. return IntObject(obj)
  75. elif isinstance(obj, W_Root):
  76. return obj
  77. raise NotImplementedError
  78. def newlist(self, items):
  79. return ListObject(items)
  80. def listview(self, obj):
  81. assert isinstance(obj, ListObject)
  82. return obj.items
  83. fixedview = listview
  84. def float(self, w_obj):
  85. if isinstance(w_obj, FloatObject):
  86. return w_obj
  87. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  88. return self.float(w_obj.descr_float(self))
  89. def float_w(self, w_obj):
  90. assert isinstance(w_obj, FloatObject)
  91. return w_obj.floatval
  92. def int_w(self, w_obj):
  93. if isinstance(w_obj, IntObject):
  94. return w_obj.intval
  95. elif isinstance(w_obj, FloatObject):
  96. return int(w_obj.floatval)
  97. raise NotImplementedError
  98. def int(self, w_obj):
  99. if isinstance(w_obj, IntObject):
  100. return w_obj
  101. assert isinstance(w_obj, interp_boxes.W_GenericBox)
  102. return self.int(w_obj.descr_int(self))
  103. def is_true(self, w_obj):
  104. assert isinstance(w_obj, BoolObject)
  105. return w_obj.boolval
  106. def is_w(self, w_obj, w_what):
  107. return w_obj is w_what
  108. def type(self, w_obj):
  109. return w_obj.tp
  110. def gettypefor(self, w_obj):
  111. return None
  112. def call_function(self, tp, w_dtype):
  113. return w_dtype
  114. @specialize.arg(1)
  115. def interp_w(self, tp, what):
  116. assert isinstance(what, tp)
  117. return what
  118. def allocate_instance(self, klass, w_subtype):
  119. return instantiate(klass)
  120. def newtuple(self, list_w):
  121. raise ValueError
  122. def len_w(self, w_obj):
  123. if isinstance(w_obj, ListObject):
  124. return len(w_obj.items)
  125. # XXX array probably
  126. assert False
  127. def exception_match(self, w_exc_type, w_check_class):
  128. # Good enough for now
  129. raise NotImplementedError
  130. class FloatObject(W_Root):
  131. tp = FakeSpace.w_float
  132. def __init__(self, floatval):
  133. self.floatval = floatval
  134. class BoolObject(W_Root):
  135. tp = FakeSpace.w_bool
  136. def __init__(self, boolval):
  137. self.boolval = boolval
  138. class IntObject(W_Root):
  139. tp = FakeSpace.w_int
  140. def __init__(self, intval):
  141. self.intval = intval
  142. class ListObject(W_Root):
  143. tp = FakeSpace.w_list
  144. def __init__(self, items):
  145. self.items = items
  146. class SliceObject(W_Root):
  147. tp = FakeSpace.w_slice
  148. def __init__(self, start, stop, step):
  149. self.start = start
  150. self.stop = stop
  151. self.step = step
  152. class InterpreterState(object):
  153. def __init__(self, code):
  154. self.code = code
  155. self.variables = {}
  156. self.results = []
  157. def run(self, space):
  158. self.space = space
  159. for stmt in self.code.statements:
  160. stmt.execute(self)
  161. class Node(object):
  162. def __eq__(self, other):
  163. return (self.__class__ == other.__class__ and
  164. self.__dict__ == other.__dict__)
  165. def __ne__(self, other):
  166. return not self == other
  167. def wrap(self, space):
  168. raise NotImplementedError
  169. def execute(self, interp):
  170. raise NotImplementedError
  171. class Assignment(Node):
  172. def __init__(self, name, expr):
  173. self.name = name
  174. self.expr = expr
  175. def execute(self, interp):
  176. interp.variables[self.name] = self.expr.execute(interp)
  177. def __repr__(self):
  178. return "%r = %r" % (self.name, self.expr)
  179. class ArrayAssignment(Node):
  180. def __init__(self, name, index, expr):
  181. self.name = name
  182. self.index = index
  183. self.expr = expr
  184. def execute(self, interp):
  185. arr = interp.variables[self.name]
  186. w_index = self.index.execute(interp)
  187. # cast to int
  188. if isinstance(w_index, FloatObject):
  189. w_index = IntObject(int(w_index.floatval))
  190. w_val = self.expr.execute(interp)
  191. assert isinstance(arr, BaseArray)
  192. arr.descr_setitem(interp.space, w_index, w_val)
  193. def __repr__(self):
  194. return "%s[%r] = %r" % (self.name, self.index, self.expr)
  195. class Variable(Node):
  196. def __init__(self, name):
  197. self.name = name.strip(" ")
  198. def execute(self, interp):
  199. return interp.variables[self.name]
  200. def __repr__(self):
  201. return 'v(%s)' % self.name
class Operator(Node):
    """Binary expression ``lhs <name> rhs``.

    Supported operators are '+', '*' and '-', which map to array arithmetic
    via descr_add / descr_mul / descr_sub, and '->', which maps to item access
    via descr_getitem.  Any other operator name, e.g. '/', raises
    NotImplementedError when executed.
    """
    def __init__(self, lhs, name, rhs):
        self.name = name
        self.lhs = lhs
        self.rhs = rhs
    def execute(self, interp):
        w_lhs = self.lhs.execute(interp)
        # A slice operand (e.g. the "1:3" in "a -> 1:3") is turned straight
        # into a SliceObject rather than evaluated.
        if isinstance(self.rhs, SliceConstant):
            w_rhs = self.rhs.wrap(interp.space)
        else:
            w_rhs = self.rhs.execute(interp)
        if not isinstance(w_lhs, BaseArray):
            # scalar: promote a non-array left operand to a float64 scalar so
            # the BaseArray descr_* methods below can be used.
            dtype = get_dtype_cache(interp.space).w_float64dtype
            w_lhs = scalar_w(interp.space, dtype, w_lhs)
        assert isinstance(w_lhs, BaseArray)
        if self.name == '+':
            w_res = w_lhs.descr_add(interp.space, w_rhs)
        elif self.name == '*':
            w_res = w_lhs.descr_mul(interp.space, w_rhs)
        elif self.name == '-':
            w_res = w_lhs.descr_sub(interp.space, w_rhs)
        elif self.name == '->':
            assert not isinstance(w_rhs, Scalar)
            # Numeric literals evaluate to floats; item access needs an int.
            if isinstance(w_rhs, FloatObject):
                w_rhs = IntObject(int(w_rhs.floatval))
            # NOTE(review): repeated narrowing assert, presumably kept for the
            # RPython annotator -- confirm before removing.
            assert isinstance(w_lhs, BaseArray)
            w_res = w_lhs.descr_getitem(interp.space, w_rhs)
        else:
            raise NotImplementedError
        # Wrap results that are neither arrays nor boxes into a float64
        # scalar, so callers always get an array or a box back.
        if (not isinstance(w_res, BaseArray) and
            not isinstance(w_res, interp_boxes.W_GenericBox)):
            dtype = get_dtype_cache(interp.space).w_float64dtype
            w_res = scalar_w(interp.space, dtype, w_res)
        return w_res
    def __repr__(self):
        return '(%r %s %r)' % (self.lhs, self.name, self.rhs)
  239. class FloatConstant(Node):
  240. def __init__(self, v):
  241. self.v = float(v)
  242. def __repr__(self):
  243. return "Const(%s)" % self.v
  244. def wrap(self, space):
  245. return space.wrap(self.v)
  246. def execute(self, interp):
  247. return interp.space.wrap(self.v)
  248. class RangeConstant(Node):
  249. def __init__(self, v):
  250. self.v = int(v)
  251. def execute(self, interp):
  252. w_list = interp.space.newlist(
  253. [interp.space.wrap(float(i)) for i in range(self.v)]
  254. )
  255. dtype = get_dtype_cache(interp.space).w_float64dtype
  256. return array(interp.space, w_list, w_dtype=dtype, w_order=None)
  257. def __repr__(self):
  258. return 'Range(%s)' % self.v
  259. class Code(Node):
  260. def __init__(self, statements):
  261. self.statements = statements
  262. def __repr__(self):
  263. return "\n".join([repr(i) for i in self.statements])
  264. class ArrayConstant(Node):
  265. def __init__(self, items):
  266. self.items = items
  267. def wrap(self, space):
  268. return space.newlist([item.wrap(space) for item in self.items])
  269. def execute(self, interp):
  270. w_list = self.wrap(interp.space)
  271. dtype = get_dtype_cache(interp.space).w_float64dtype
  272. return array(interp.space, w_list, w_dtype=dtype, w_order=None)
  273. def __repr__(self):
  274. return "[" + ", ".join([repr(item) for item in self.items]) + "]"
  275. class SliceConstant(Node):
  276. def __init__(self, start, stop, step):
  277. # no negative support for now
  278. self.start = start
  279. self.stop = stop
  280. self.step = step
  281. def wrap(self, space):
  282. return SliceObject(self.start, self.stop, self.step)
  283. def execute(self, interp):
  284. return SliceObject(self.start, self.stop, self.step)
  285. def __repr__(self):
  286. return 'slice(%s,%s,%s)' % (self.start, self.stop, self.step)
  287. class Execute(Node):
  288. def __init__(self, expr):
  289. self.expr = expr
  290. def __repr__(self):
  291. return repr(self.expr)
  292. def execute(self, interp):
  293. interp.results.append(self.expr.execute(interp))
  294. class FunctionCall(Node):
  295. def __init__(self, name, args):
  296. self.name = name.strip(" ")
  297. self.args = args
  298. def __repr__(self):
  299. return "%s(%s)" % (self.name, ", ".join([repr(arg)
  300. for arg in self.args]))
  301. def execute(self, interp):
  302. arr = self.args[0].execute(interp)
  303. if not isinstance(arr, BaseArray):
  304. raise ArgumentNotAnArray
  305. if self.name in SINGLE_ARG_FUNCTIONS:
  306. if len(self.args) != 1 and self.name != 'sum':
  307. raise ArgumentMismatch
  308. if self.name == "sum":
  309. if len(self.args)>1:
  310. w_res = arr.descr_sum(interp.space,
  311. self.args[1].execute(interp))
  312. else:
  313. w_res = arr.descr_sum(interp.space)
  314. elif self.name == "prod":
  315. w_res = arr.descr_prod(interp.space)
  316. elif self.name == "max":
  317. w_res = arr.descr_max(interp.space)
  318. elif self.name == "min":
  319. w_res = arr.descr_min(interp.space)
  320. elif self.name == "any":
  321. w_res = arr.descr_any(interp.space)
  322. elif self.name == "all":
  323. w_res = arr.descr_all(interp.space)
  324. elif self.name == "unegative":
  325. neg = interp_ufuncs.get(interp.space).negative
  326. w_res = neg.call(interp.space, [arr])
  327. elif self.name == "flat":
  328. w_res = arr.descr_get_flatiter(interp.space)
  329. else:
  330. assert False # unreachable code
  331. elif self.name in TWO_ARG_FUNCTIONS:
  332. if len(self.args) != 2:
  333. raise ArgumentMismatch
  334. arg = self.args[1].execute(interp)
  335. if not isinstance(arg, BaseArray):
  336. raise ArgumentNotAnArray
  337. if not isinstance(arg, BaseArray):
  338. raise ArgumentNotAnArray
  339. if self.name == "dot":
  340. w_res = arr.descr_dot(interp.space, arg)
  341. elif self.name == 'take':
  342. w_res = arr.descr_take(interp.space, arg)
  343. else:
  344. assert False # unreachable code
  345. else:
  346. raise WrongFunctionName
  347. if isinstance(w_res, BaseArray):
  348. return w_res
  349. if isinstance(w_res, FloatObject):
  350. dtype = get_dtype_cache(interp.space).w_float64dtype
  351. elif isinstance(w_res, BoolObject):
  352. dtype = get_dtype_cache(interp.space).w_booldtype
  353. elif isinstance(w_res, interp_boxes.W_GenericBox):
  354. dtype = w_res.get_dtype(interp.space)
  355. else:
  356. dtype = None
  357. return scalar_w(interp.space, dtype, w_res)
  358. _REGEXES = [
  359. ('-?[\d\.]+', 'number'),
  360. ('\[', 'array_left'),
  361. (':', 'colon'),
  362. ('\w+', 'identifier'),
  363. ('\]', 'array_right'),
  364. ('(->)|[\+\-\*\/]', 'operator'),
  365. ('=', 'assign'),
  366. (',', 'comma'),
  367. ('\|', 'pipe'),
  368. ('\(', 'paren_left'),
  369. ('\)', 'paren_right'),
  370. ]
  371. REGEXES = []
  372. for r, name in _REGEXES:
  373. REGEXES.append((re.compile(r' *(' + r + ')'), name))
  374. del _REGEXES
  375. class Token(object):
  376. def __init__(self, name, v):
  377. self.name = name
  378. self.v = v
  379. def __repr__(self):
  380. return '(%s, %s)' % (self.name, self.v)
  381. empty = Token('', '')
  382. class TokenStack(object):
  383. def __init__(self, tokens):
  384. self.tokens = tokens
  385. self.c = 0
  386. def pop(self):
  387. token = self.tokens[self.c]
  388. self.c += 1
  389. return token
  390. def get(self, i):
  391. if self.c + i >= len(self.tokens):
  392. return empty
  393. return self.tokens[self.c + i]
  394. def remaining(self):
  395. return len(self.tokens) - self.c
  396. def push(self):
  397. self.c -= 1
  398. def __repr__(self):
  399. return repr(self.tokens[self.c:])
  400. class Parser(object):
  401. def tokenize(self, line):
  402. tokens = []
  403. while True:
  404. for r, name in REGEXES:
  405. m = r.match(line)
  406. if m is not None:
  407. g = m.group(0)
  408. tokens.append(Token(name, g))
  409. line = line[len(g):]
  410. if not line:
  411. return TokenStack(tokens)
  412. break
  413. else:
  414. raise TokenizerError(line)
  415. def parse_number_or_slice(self, tokens):
  416. start_tok = tokens.pop()
  417. if start_tok.name == 'colon':
  418. start = 0
  419. else:
  420. if tokens.get(0).name != 'colon':
  421. return FloatConstant(start_tok.v)
  422. start = int(start_tok.v)
  423. tokens.pop()
  424. if not tokens.get(0).name in ['colon', 'number']:
  425. stop = -1
  426. step = 1
  427. else:
  428. next = tokens.pop()
  429. if next.name == 'colon':
  430. stop = -1
  431. step = int(tokens.pop().v)
  432. else:
  433. stop = int(next.v)
  434. if tokens.get(0).name == 'colon':
  435. tokens.pop()
  436. step = int(tokens.pop().v)
  437. else:
  438. step = 1
  439. return SliceConstant(start, stop, step)
  440. def parse_expression(self, tokens, accept_comma=False):
  441. stack = []
  442. while tokens.remaining():
  443. token = tokens.pop()
  444. if token.name == 'identifier':
  445. if tokens.remaining() and tokens.get(0).name == 'paren_left':
  446. stack.append(self.parse_function_call(token.v, tokens))
  447. else:
  448. stack.append(Variable(token.v))
  449. elif token.name == 'array_left':
  450. stack.append(ArrayConstant(self.parse_array_const(tokens)))
  451. elif token.name == 'operator':
  452. stack.append(Variable(token.v))
  453. elif token.name == 'number' or token.name == 'colon':
  454. tokens.push()
  455. stack.append(self.parse_number_or_slice(tokens))
  456. elif token.name == 'pipe':
  457. stack.append(RangeConstant(tokens.pop().v))
  458. end = tokens.pop()
  459. assert end.name == 'pipe'
  460. elif accept_comma and token.name == 'comma':
  461. continue
  462. else:
  463. tokens.push()
  464. break
  465. if accept_comma:
  466. return stack
  467. stack.reverse()
  468. lhs = stack.pop()
  469. while stack:
  470. op = stack.pop()
  471. assert isinstance(op, Variable)
  472. rhs = stack.pop()
  473. lhs = Operator(lhs, op.name, rhs)
  474. return lhs
  475. def parse_function_call(self, name, tokens):
  476. args = []
  477. tokens.pop() # lparen
  478. while tokens.get(0).name != 'paren_right':
  479. args += self.parse_expression(tokens, accept_comma=True)
  480. return FunctionCall(name, args)
  481. def parse_array_const(self, tokens):
  482. elems = []
  483. while True:
  484. token = tokens.pop()
  485. if token.name == 'number':
  486. elems.append(FloatConstant(token.v))
  487. elif token.name == 'array_left':
  488. elems.append(ArrayConstant(self.parse_array_const(tokens)))
  489. else:
  490. raise BadToken()
  491. token = tokens.pop()
  492. if token.name == 'array_right':
  493. return elems
  494. assert token.name == 'comma'
  495. def parse_statement(self, tokens):
  496. if (tokens.get(0).name == 'identifier' and
  497. tokens.get(1).name == 'assign'):
  498. lhs = tokens.pop().v
  499. tokens.pop()
  500. rhs = self.parse_expression(tokens)
  501. return Assignment(lhs, rhs)
  502. elif (tokens.get(0).name == 'identifier' and
  503. tokens.get(1).name == 'array_left'):
  504. name = tokens.pop().v
  505. tokens.pop()
  506. index = self.parse_expression(tokens)
  507. tokens.pop()
  508. tokens.pop()
  509. return ArrayAssignment(name, index, self.parse_expression(tokens))
  510. return Execute(self.parse_expression(tokens))
  511. def parse(self, code):
  512. statements = []
  513. for line in code.split("\n"):
  514. if '#' in line:
  515. line = line.split('#', 1)[0]
  516. line = line.strip(" ")
  517. if line:
  518. tokens = self.tokenize(line)
  519. statements.append(self.parse_statement(tokens))
  520. return Code(statements)
  521. def numpy_compile(code):
  522. parser = Parser()
  523. return InterpreterState(parser.parse(code))