PageRenderTime 169ms CodeModel.GetById 61ms app.highlight 69ms RepoModel.GetById 31ms app.codeStats 1ms

/pypy/module/micronumpy/compile.py

https://bitbucket.org/mikestewart/pypy
Python | 673 lines | 571 code | 93 blank | 9 comment | 59 complexity | 14c817460ffe74961431f00c18293fb5 MD5 | raw file
  1
  2""" This is a set of tools for standalone compiling of numpy expressions.
  3It should not be imported by the module itself
  4"""
  5
  6import re
  7
  8from pypy.interpreter.baseobjspace import InternalSpaceCache, W_Root
  9from pypy.interpreter.error import OperationError
 10from pypy.module.micronumpy import interp_boxes
 11from pypy.module.micronumpy.interp_dtype import get_dtype_cache
 12from pypy.module.micronumpy.interp_numarray import (Scalar, BaseArray,
 13     scalar_w, W_NDimArray, array)
 14from pypy.module.micronumpy.interp_arrayops import where
 15from pypy.module.micronumpy import interp_ufuncs
 16from pypy.rlib.objectmodel import specialize, instantiate
 17
 18
 19class BogusBytecode(Exception):
 20    pass
 21
 22class ArgumentMismatch(Exception):
 23    pass
 24
 25class ArgumentNotAnArray(Exception):
 26    pass
 27
 28class WrongFunctionName(Exception):
 29    pass
 30
 31class TokenizerError(Exception):
 32    pass
 33
 34class BadToken(Exception):
 35    pass
 36
 37SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any",
 38                        "unegative", "flat", "tostring"]
 39TWO_ARG_FUNCTIONS = ["dot", 'take']
 40THREE_ARG_FUNCTIONS = ['where']
 41
 42class FakeSpace(object):
 43    w_ValueError = "ValueError"
 44    w_TypeError = "TypeError"
 45    w_IndexError = "IndexError"
 46    w_OverflowError = "OverflowError"
 47    w_NotImplementedError = "NotImplementedError"
 48    w_None = None
 49
 50    w_bool = "bool"
 51    w_int = "int"
 52    w_float = "float"
 53    w_list = "list"
 54    w_long = "long"
 55    w_tuple = 'tuple'
 56    w_slice = "slice"
 57    w_str = "str"
 58    w_unicode = "unicode"
 59
 60    def __init__(self):
 61        """NOT_RPYTHON"""
 62        self.fromcache = InternalSpaceCache(self).getorbuild
 63
 64    def _freeze_(self):
 65        return True
 66
 67    def issequence_w(self, w_obj):
 68        return isinstance(w_obj, ListObject) or isinstance(w_obj, W_NDimArray)
 69
 70    def isinstance_w(self, w_obj, w_tp):
 71        return w_obj.tp == w_tp
 72
 73    def decode_index4(self, w_idx, size):
 74        if isinstance(w_idx, IntObject):
 75            return (self.int_w(w_idx), 0, 0, 1)
 76        else:
 77            assert isinstance(w_idx, SliceObject)
 78            start, stop, step = w_idx.start, w_idx.stop, w_idx.step
 79            if step == 0:
 80                return (0, size, 1, size)
 81            if start < 0:
 82                start += size
 83            if stop < 0:
 84                stop += size + 1
 85            if step < 0:
 86                lgt = (stop - start + 1) / step + 1
 87            else:
 88                lgt = (stop - start - 1) / step + 1
 89            return (start, stop, step, lgt)
 90
 91    @specialize.argtype(1)
 92    def wrap(self, obj):
 93        if isinstance(obj, float):
 94            return FloatObject(obj)
 95        elif isinstance(obj, bool):
 96            return BoolObject(obj)
 97        elif isinstance(obj, int):
 98            return IntObject(obj)
 99        elif isinstance(obj, long):
100            return LongObject(obj)
101        elif isinstance(obj, W_Root):
102            return obj
103        elif isinstance(obj, str):
104            return StringObject(obj)
105        raise NotImplementedError
106
107    def newlist(self, items):
108        return ListObject(items)
109
110    def listview(self, obj):
111        assert isinstance(obj, ListObject)
112        return obj.items
113    fixedview = listview
114
115    def float(self, w_obj):
116        if isinstance(w_obj, FloatObject):
117            return w_obj
118        assert isinstance(w_obj, interp_boxes.W_GenericBox)
119        return self.float(w_obj.descr_float(self))
120
121    def float_w(self, w_obj):
122        assert isinstance(w_obj, FloatObject)
123        return w_obj.floatval
124
125    def int_w(self, w_obj):
126        if isinstance(w_obj, IntObject):
127            return w_obj.intval
128        elif isinstance(w_obj, FloatObject):
129            return int(w_obj.floatval)
130        elif isinstance(w_obj, SliceObject):
131            raise OperationError(self.w_TypeError, self.wrap("slice."))
132        raise NotImplementedError
133
134    def index(self, w_obj):
135        return self.wrap(self.int_w(w_obj))
136
137    def str_w(self, w_obj):
138        if isinstance(w_obj, StringObject):
139            return w_obj.v
140        raise NotImplementedError
141
142    def int(self, w_obj):
143        if isinstance(w_obj, IntObject):
144            return w_obj
145        assert isinstance(w_obj, interp_boxes.W_GenericBox)
146        return self.int(w_obj.descr_int(self))
147
148    def is_true(self, w_obj):
149        assert isinstance(w_obj, BoolObject)
150        return w_obj.boolval
151
152    def is_w(self, w_obj, w_what):
153        return w_obj is w_what
154
155    def type(self, w_obj):
156        return w_obj.tp
157
158    def gettypefor(self, w_obj):
159        return None
160
161    def call_function(self, tp, w_dtype):
162        return w_dtype
163
164    @specialize.arg(1)
165    def interp_w(self, tp, what):
166        assert isinstance(what, tp)
167        return what
168
169    def allocate_instance(self, klass, w_subtype):
170        return instantiate(klass)
171
172    def newtuple(self, list_w):
173        return ListObject(list_w)
174
175    def newdict(self):
176        return {}
177
178    def setitem(self, dict, item, value):
179        dict[item] = value
180
181    def len_w(self, w_obj):
182        if isinstance(w_obj, ListObject):
183            return len(w_obj.items)
184        # XXX array probably
185        assert False
186
187    def exception_match(self, w_exc_type, w_check_class):
188        # Good enough for now
189        raise NotImplementedError
190
class FloatObject(W_Root):
    """Wrapped float; ``FakeSpace.wrap`` creates it for float values."""
    tp = FakeSpace.w_float

    def __init__(self, floatval):
        self.floatval = floatval
195
class BoolObject(W_Root):
    """Wrapped bool; ``FakeSpace.wrap`` creates it for bool values."""
    tp = FakeSpace.w_bool

    def __init__(self, boolval):
        self.boolval = boolval
200
class IntObject(W_Root):
    """Wrapped int; ``FakeSpace.wrap`` creates it for int values."""
    tp = FakeSpace.w_int

    def __init__(self, intval):
        self.intval = intval
205
class LongObject(W_Root):
    """Wrapped long; ``FakeSpace.wrap`` creates it for long values."""
    tp = FakeSpace.w_long

    def __init__(self, intval):
        self.intval = intval
210
class ListObject(W_Root):
    """Wrapped list; ``FakeSpace`` uses it for both lists and tuples."""
    tp = FakeSpace.w_list

    def __init__(self, items):
        self.items = items
215
class SliceObject(W_Root):
    """Wrapped slice holding plain ``start``, ``stop`` and ``step`` values."""
    tp = FakeSpace.w_slice

    def __init__(self, start, stop, step):
        self.start = start
        self.stop = stop
        self.step = step
222
class StringObject(W_Root):
    """Wrapped str; ``FakeSpace.wrap`` creates it for str values."""
    tp = FakeSpace.w_str

    def __init__(self, v):
        self.v = v
227
class InterpreterState(object):
    """Execution state for one piece of parsed code.

    ``variables`` maps names to values, ``results`` collects the values of
    expression statements, and ``space`` is set by ``run``.
    """

    def __init__(self, code):
        self.code = code
        self.variables = {}
        self.results = []

    def run(self, space):
        """Execute every statement of ``self.code`` in order using ``space``."""
        self.space = space
        for statement in self.code.statements:
            statement.execute(self)
238
class Node(object):
    """Base class of all syntax-tree nodes.

    Two nodes are equal when they have the same class and the same
    instance attributes.
    """

    def __eq__(self, other):
        if self.__class__ != other.__class__:
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        equal = self == other
        return not equal

    def wrap(self, space):
        """Return the wrapped constant; subclasses override this."""
        raise NotImplementedError

    def execute(self, interp):
        """Evaluate the node against ``interp``; subclasses override this."""
        raise NotImplementedError
252
class Assignment(Node):
    """Statement ``name = expr``."""

    def __init__(self, name, expr):
        self.name = name
        self.expr = expr

    def execute(self, interp):
        """Evaluate ``expr`` and store the value under ``name``."""
        value = self.expr.execute(interp)
        interp.variables[self.name] = value

    def __repr__(self):
        parts = (self.name, self.expr)
        return "%r = %r" % parts
263
class ArrayAssignment(Node):
    """Statement ``name[index] = expr``."""

    def __init__(self, name, index, expr):
        self.name = name
        self.index = index
        self.expr = expr

    def execute(self, interp):
        """Store the value of ``expr`` into the array variable ``name``.

        A float index is truncated to an int first.
        """
        target = interp.variables[self.name]
        w_index = self.index.execute(interp)
        if isinstance(w_index, FloatObject):
            # cast to int
            w_index = IntObject(int(w_index.floatval))
        w_value = self.expr.execute(interp)
        assert isinstance(target, BaseArray)
        target.descr_setitem(interp.space, w_index, w_value)

    def __repr__(self):
        parts = (self.name, self.index, self.expr)
        return "%s[%r] = %r" % parts
282
class Variable(Node):
    """Reference to a named variable.

    The parser also uses this node as a temporary holder for operator
    tokens.
    """

    def __init__(self, name):
        # token values keep their leading spaces, so remove them here
        self.name = name.strip(" ")

    def execute(self, interp):
        """Look the name up in ``interp.variables`` (KeyError if unset)."""
        variables = interp.variables
        return variables[self.name]

    def __repr__(self):
        return 'v(%s)' % (self.name,)
292
class Operator(Node):
    """Binary operation ``lhs <name> rhs``.

    ``execute`` supports '+', '*', '-' and '->' (item access); any other
    name raises NotImplementedError.
    """

    def __init__(self, lhs, name, rhs):
        self.lhs = lhs
        self.name = name
        self.rhs = rhs

    def execute(self, interp):
        """Evaluate both operands and apply the operator.

        The result is returned as is when it is an array or a box; anything
        else is turned into a float64 scalar.
        """
        w_lhs = self.lhs.execute(interp)
        space = interp.space
        # slice constants are wrapped directly instead of being executed
        w_rhs = (self.rhs.wrap(space) if isinstance(self.rhs, SliceConstant)
                 else self.rhs.execute(interp))
        if not isinstance(w_lhs, BaseArray):
            # scalar
            w_lhs = scalar_w(space, get_dtype_cache(space).w_float64dtype,
                             w_lhs)
        assert isinstance(w_lhs, BaseArray)
        op = self.name
        if op == '+':
            w_res = w_lhs.descr_add(space, w_rhs)
        elif op == '*':
            w_res = w_lhs.descr_mul(space, w_rhs)
        elif op == '-':
            w_res = w_lhs.descr_sub(space, w_rhs)
        elif op == '->':
            assert not isinstance(w_rhs, Scalar)
            if isinstance(w_rhs, FloatObject):
                # numeric indices are parsed as floats; truncate to int
                w_rhs = IntObject(int(w_rhs.floatval))
            assert isinstance(w_lhs, BaseArray)
            w_res = w_lhs.descr_getitem(space, w_rhs)
        else:
            raise NotImplementedError
        if (isinstance(w_res, BaseArray) or
            isinstance(w_res, interp_boxes.W_GenericBox)):
            return w_res
        return scalar_w(space, get_dtype_cache(space).w_float64dtype, w_res)

    def __repr__(self):
        parts = (self.lhs, self.name, self.rhs)
        return '(%r %s %r)' % parts
332
class FloatConstant(Node):
    """Numeric literal; the value is always stored as a float."""

    def __init__(self, v):
        # ``v`` is usually the raw token text; float() does the conversion
        self.v = float(v)

    def __repr__(self):
        return "Const(%s)" % (self.v,)

    def wrap(self, space):
        """Return the value wrapped by ``space``."""
        return space.wrap(self.v)

    def execute(self, interp):
        """Return the value wrapped by the interpreter's space."""
        space = interp.space
        return space.wrap(self.v)
345
class RangeConstant(Node):
    """Range literal ``|n|``: the float64 array ``[0.0, 1.0, ..., n - 1]``."""

    def __init__(self, v):
        self.v = int(v)

    def execute(self, interp):
        """Build the array of the first ``self.v`` non-negative numbers."""
        space = interp.space
        wrapped = []
        for i in range(self.v):
            wrapped.append(space.wrap(float(i)))
        w_list = space.newlist(wrapped)
        dtype = get_dtype_cache(space).w_float64dtype
        return array(space, w_list, w_dtype=dtype, w_order=None)

    def __repr__(self):
        return 'Range(%s)' % (self.v,)
359
class Code(Node):
    """A parsed program: the list of its statement nodes."""

    def __init__(self, statements):
        self.statements = statements

    def __repr__(self):
        # one statement per line
        lines = []
        for statement in self.statements:
            lines.append(repr(statement))
        return "\n".join(lines)
366
class ArrayConstant(Node):
    """Array literal ``[a, b, ...]``; items may be nested array literals."""

    def __init__(self, items):
        self.items = items

    def wrap(self, space):
        """Return a wrapped list holding every item wrapped in turn."""
        wrapped = []
        for item in self.items:
            wrapped.append(item.wrap(space))
        return space.newlist(wrapped)

    def execute(self, interp):
        """Build a float64 array from the wrapped items."""
        space = interp.space
        w_list = self.wrap(space)
        dtype = get_dtype_cache(space).w_float64dtype
        return array(space, w_list, w_dtype=dtype, w_order=None)

    def __repr__(self):
        return "[%s]" % ", ".join([repr(item) for item in self.items])
381
class SliceConstant(Node):
    """Slice literal ``start:stop:step`` holding plain integers."""

    def __init__(self, start, stop, step):
        # no negative support for now
        self.start = start
        self.stop = stop
        self.step = step

    def wrap(self, space):
        """Return a new ``SliceObject``; ``space`` is not used."""
        return SliceObject(self.start, self.stop, self.step)

    def execute(self, interp):
        """Return a new ``SliceObject``; ``interp`` is not used."""
        start, stop, step = self.start, self.stop, self.step
        return SliceObject(start, stop, step)

    def __repr__(self):
        parts = (self.start, self.stop, self.step)
        return 'slice(%s,%s,%s)' % parts
397
class Execute(Node):
    """Expression statement: its value is appended to ``interp.results``."""

    def __init__(self, expr):
        self.expr = expr

    def __repr__(self):
        return repr(self.expr)

    def execute(self, interp):
        results = interp.results
        results.append(self.expr.execute(interp))
407
class FunctionCall(Node):
    """Call of one of the functions listed in SINGLE_ARG_FUNCTIONS,
    TWO_ARG_FUNCTIONS or THREE_ARG_FUNCTIONS, e.g. ``sum(a)``.
    """

    def __init__(self, name, args):
        # token values keep their leading spaces, so remove them here
        self.name = name.strip(" ")
        self.args = args

    def __repr__(self):
        return "%s(%s)" % (self.name, ", ".join([repr(arg)
                                                 for arg in self.args]))

    def execute(self, interp):
        """Evaluate the arguments and call the function.

        Returns the result unchanged when it is an array; any other result
        is turned into a scalar (float64 for wrapped floats, bool for
        wrapped bools, the box's own dtype for boxes, no dtype otherwise).

        Raises:
            WrongFunctionName: the name is in none of the function tables.
            ArgumentMismatch: wrong number of arguments, including none.
            ArgumentNotAnArray: an argument that must be an array is not.
        """
        if not self.args:
            # Without this check ``self.args[0]`` below fails with a bare
            # IndexError for calls such as ``sum()``.
            if (self.name in SINGLE_ARG_FUNCTIONS or
                self.name in TWO_ARG_FUNCTIONS or
                self.name in THREE_ARG_FUNCTIONS):
                raise ArgumentMismatch
            raise WrongFunctionName
        arr = self.args[0].execute(interp)
        if not isinstance(arr, BaseArray):
            raise ArgumentNotAnArray
        if self.name in SINGLE_ARG_FUNCTIONS:
            # "sum" is the only one that accepts an extra argument
            if len(self.args) != 1 and self.name != 'sum':
                raise ArgumentMismatch
            if self.name == "sum":
                if len(self.args)>1:
                    w_res = arr.descr_sum(interp.space,
                                          self.args[1].execute(interp))
                else:
                    w_res = arr.descr_sum(interp.space)
            elif self.name == "prod":
                w_res = arr.descr_prod(interp.space)
            elif self.name == "max":
                w_res = arr.descr_max(interp.space)
            elif self.name == "min":
                w_res = arr.descr_min(interp.space)
            elif self.name == "any":
                w_res = arr.descr_any(interp.space)
            elif self.name == "all":
                w_res = arr.descr_all(interp.space)
            elif self.name == "unegative":
                neg = interp_ufuncs.get(interp.space).negative
                w_res = neg.call(interp.space, [arr])
            elif self.name == "flat":
                w_res = arr.descr_get_flatiter(interp.space)
            elif self.name == "tostring":
                # the string itself is discarded
                arr.descr_tostring(interp.space)
                w_res = None
            else:
                assert False # unreachable code
        elif self.name in TWO_ARG_FUNCTIONS:
            if len(self.args) != 2:
                raise ArgumentMismatch
            arg = self.args[1].execute(interp)
            if not isinstance(arg, BaseArray):
                raise ArgumentNotAnArray
            if self.name == "dot":
                w_res = arr.descr_dot(interp.space, arg)
            elif self.name == 'take':
                w_res = arr.descr_take(interp.space, arg)
            else:
                assert False # unreachable code
        elif self.name in THREE_ARG_FUNCTIONS:
            if len(self.args) != 3:
                raise ArgumentMismatch
            arg1 = self.args[1].execute(interp)
            arg2 = self.args[2].execute(interp)
            if not isinstance(arg1, BaseArray):
                raise ArgumentNotAnArray
            if not isinstance(arg2, BaseArray):
                raise ArgumentNotAnArray
            if self.name == "where":
                w_res = where(interp.space, arr, arg1, arg2)
            else:
                assert False
        else:
            raise WrongFunctionName
        if isinstance(w_res, BaseArray):
            return w_res
        # pick the dtype used to turn a non-array result into a scalar
        if isinstance(w_res, FloatObject):
            dtype = get_dtype_cache(interp.space).w_float64dtype
        elif isinstance(w_res, BoolObject):
            dtype = get_dtype_cache(interp.space).w_booldtype
        elif isinstance(w_res, interp_boxes.W_GenericBox):
            dtype = w_res.get_dtype(interp.space)
        else:
            dtype = None
        return scalar_w(interp.space, dtype, w_res)
488
# (pattern, token name) pairs; the tokenizer tries them in this order, so
# e.g. a number is recognised before an identifier or an operator.
_REGEXES = [
    (r'-?[\d\.]+', 'number'),
    (r'\[', 'array_left'),
    (r':', 'colon'),
    (r'\w+', 'identifier'),
    (r'\]', 'array_right'),
    (r'(->)|[\+\-\*\/]', 'operator'),
    (r'=', 'assign'),
    (r',', 'comma'),
    (r'\|', 'pipe'),
    (r'\(', 'paren_left'),
    (r'\)', 'paren_right'),
]
REGEXES = []

# Each compiled pattern also swallows the spaces in front of the token.
for r, name in _REGEXES:
    REGEXES.append((re.compile(' *(%s)' % r), name))
del _REGEXES
507
class Token(object):
    """One lexical token: its kind (``name``) and the matched text (``v``)."""

    def __init__(self, name, v):
        self.name = name
        self.v = v

    def __repr__(self):
        fields = (self.name, self.v)
        return '(%s, %s)' % fields
515
# Placeholder token that TokenStack.get returns when looking past the end.
empty = Token(name='', v='')
517
class TokenStack(object):
    """Cursor over a list of tokens; ``c`` is the index of the next token."""

    def __init__(self, tokens):
        self.tokens = tokens
        self.c = 0

    def pop(self):
        """Return the next token and advance (IndexError at the end)."""
        position = self.c
        token = self.tokens[position]
        self.c = position + 1
        return token

    def get(self, i):
        """Peek ``i`` tokens ahead without advancing.

        Returns the ``empty`` token when that position is past the end.
        """
        position = self.c + i
        if position < len(self.tokens):
            return self.tokens[position]
        return empty

    def remaining(self):
        """Return how many tokens have not been consumed yet."""
        return len(self.tokens) - self.c

    def push(self):
        """Step the cursor back by one token."""
        self.c -= 1

    def __repr__(self):
        unread = self.tokens[self.c:]
        return repr(unread)
541
class Parser(object):
    """Tokenizer and parser for the line-oriented mini numpy language.

    ``parse`` is the entry point; it returns a ``Code`` node with one
    statement node per non-empty source line.
    """

    def tokenize(self, line):
        """Split ``line`` into tokens and return them as a ``TokenStack``.

        The patterns of ``REGEXES`` are tried in order against the start of
        the remaining text and the first match wins.  Token values keep the
        spaces that preceded them.

        Raises TokenizerError, carrying the unconsumed text, when no pattern
        matches.
        """
        tokens = []
        while True:
            for r, name in REGEXES:
                m = r.match(line)
                if m is not None:
                    g = m.group(0)
                    tokens.append(Token(name, g))
                    line = line[len(g):]
                    if not line:
                        return TokenStack(tokens)
                    break
            else:
                raise TokenizerError(line)

    def parse_number_or_slice(self, tokens):
        """Parse a number or a ``start:stop:step`` slice.

        Returns a ``FloatConstant`` for a number not followed by a colon,
        otherwise a ``SliceConstant``.  A missing start is 0, a missing stop
        is encoded as -1 and a missing step is 1.
        """
        start_tok = tokens.pop()
        if start_tok.name == 'colon':
            start = 0
        else:
            if tokens.get(0).name != 'colon':
                return FloatConstant(start_tok.v)
            start = int(start_tok.v)
            tokens.pop()
        if tokens.get(0).name not in ['colon', 'number']:
            stop = -1
            step = 1
        else:
            tok = tokens.pop()
            if tok.name == 'colon':
                stop = -1
                step = int(tokens.pop().v)
            else:
                stop = int(tok.v)
                if tokens.get(0).name == 'colon':
                    tokens.pop()
                    step = int(tokens.pop().v)
                else:
                    step = 1
        return SliceConstant(start, stop, step)


    def parse_expression(self, tokens, accept_comma=False):
        """Parse operands and operators up to the first unacceptable token.

        That token is left unconsumed.  With ``accept_comma`` set, commas
        are skipped and the flat list of parsed nodes is returned (used for
        function arguments).  Otherwise the nodes are folded strictly left
        to right, without operator precedence, into nested ``Operator``
        nodes and the resulting node is returned.
        """
        stack = []
        while tokens.remaining():
            token = tokens.pop()
            if token.name == 'identifier':
                if tokens.remaining() and tokens.get(0).name == 'paren_left':
                    stack.append(self.parse_function_call(token.v, tokens))
                else:
                    stack.append(Variable(token.v))
            elif token.name == 'array_left':
                stack.append(ArrayConstant(self.parse_array_const(tokens)))
            elif token.name == 'operator':
                # operators are carried as Variable nodes until folding
                stack.append(Variable(token.v))
            elif token.name == 'number' or token.name == 'colon':
                tokens.push()
                stack.append(self.parse_number_or_slice(tokens))
            elif token.name == 'pipe':
                stack.append(RangeConstant(tokens.pop().v))
                end = tokens.pop()
                assert end.name == 'pipe'
            elif accept_comma and token.name == 'comma':
                continue
            else:
                tokens.push()
                break
        if accept_comma:
            return stack
        # reversed so that pop() yields the nodes in source order
        stack.reverse()
        lhs = stack.pop()
        while stack:
            op = stack.pop()
            assert isinstance(op, Variable)
            rhs = stack.pop()
            lhs = Operator(lhs, op.name, rhs)
        return lhs

    def parse_function_call(self, name, tokens):
        """Parse ``( arg, ... )`` and return the ``FunctionCall`` node.

        The cursor must be on the opening parenthesis.  The closing
        parenthesis is consumed, so that whatever follows the call (an
        operator, or further arguments of an enclosing call) is parsed
        instead of being dropped.

        Raises BadToken when the argument list is not terminated or contains
        a token that cannot be parsed; without this check the loop below
        would never end.
        """
        args = []
        tokens.pop() # lparen
        while tokens.get(0).name != 'paren_right':
            before = tokens.remaining()
            args += self.parse_expression(tokens, accept_comma=True)
            if tokens.remaining() == before:
                # nothing was consumed: end of input or unparsable token
                raise BadToken()
        tokens.pop() # rparen
        return FunctionCall(name, args)

    def parse_array_const(self, tokens):
        """Parse the elements of an array literal after its opening ``[``.

        Returns the list of element nodes; nested literals are parsed
        recursively.  Raises BadToken when an element is neither a number
        nor a nested array.
        """
        elems = []
        while True:
            token = tokens.pop()
            if token.name == 'number':
                elems.append(FloatConstant(token.v))
            elif token.name == 'array_left':
                elems.append(ArrayConstant(self.parse_array_const(tokens)))
            else:
                raise BadToken()
            token = tokens.pop()
            if token.name == 'array_right':
                return elems
            assert token.name == 'comma'

    def parse_statement(self, tokens):
        """Parse one line.

        Returns an ``Assignment`` for ``name = expr``, an
        ``ArrayAssignment`` for ``name[index] = expr`` and an ``Execute``
        node for a bare expression.
        """
        if (tokens.get(0).name == 'identifier' and
            tokens.get(1).name == 'assign'):
            lhs = tokens.pop().v
            tokens.pop() # assign
            rhs = self.parse_expression(tokens)
            return Assignment(lhs, rhs)
        elif (tokens.get(0).name == 'identifier' and
              tokens.get(1).name == 'array_left'):
            name = tokens.pop().v
            tokens.pop() # array_left
            index = self.parse_expression(tokens)
            tokens.pop() # array_right
            tokens.pop() # assign
            return ArrayAssignment(name, index, self.parse_expression(tokens))
        return Execute(self.parse_expression(tokens))

    def parse(self, code):
        """Parse ``code`` line by line and return a ``Code`` node.

        Everything after a '#' is a comment; lines that are empty once
        comments and surrounding spaces are removed are skipped.
        """
        statements = []
        for line in code.split("\n"):
            if '#' in line:
                line = line.split('#', 1)[0]
            line = line.strip(" ")
            if line:
                tokens = self.tokenize(line)
                statements.append(self.parse_statement(tokens))
        return Code(statements)
670
def numpy_compile(code):
    """Parse ``code`` and return an ``InterpreterState`` ready to be run."""
    parsed = Parser().parse(code)
    return InterpreterState(parsed)