/rpython/rlib/rstring.py
Python | 778 lines | 675 code | 79 blank | 24 comment | 155 complexity | 89c897309f0d8be59cf1f1f020c9018f MD5 | raw file
Possible License(s): AGPL-3.0, BSD-3-Clause, Apache-2.0
- """ String builder interface and string functions
- """
- import sys
- from rpython.annotator.model import (SomeObject, SomeString, s_None, SomeChar,
- SomeInteger, SomeUnicodeCodePoint, SomeUnicodeString, SomePBC)
- from rpython.rtyper.llannotation import SomePtr
- from rpython.rlib import jit
- from rpython.rlib.objectmodel import newlist_hint, resizelist_hint, specialize
- from rpython.rlib.rarithmetic import ovfcheck, LONG_BIT as BLOOM_WIDTH
- from rpython.rlib.buffer import Buffer
- from rpython.rlib.unicodedata import unicodedb_5_2_0 as unicodedb
- from rpython.rtyper.extregistry import ExtRegistryEntry
- from rpython.tool.pairtype import pairtype
- # -------------- public API for string functions -----------------------
@specialize.argtype(0)
def _isspace(char):
    """Tell whether the single character `char` is whitespace.

    A byte character (str) is classified with str.isspace(); a unicode
    code point is looked up in the Unicode 5.2.0 database (`unicodedb`).
    """
    if not isinstance(char, str):
        assert isinstance(char, unicode)
        return unicodedb.isspace(ord(char))
    return char.isspace()
@specialize.argtype(0, 1)
def split(value, by=None, maxsplit=-1):
    """Split `value` into a list of pieces, in the manner of str.split().

    `value` may be a str, a unicode string or a list of characters.  `by`
    is either None or a separator of the matching kind (unicode for a
    unicode `value`, str otherwise).

    With by=None, runs of whitespace (as classified by _isspace) separate
    the words and no empty words are produced.  Otherwise every
    occurrence of `by` separates two pieces.  When maxsplit >= 0 at most
    `maxsplit` splits are done; a negative maxsplit means no limit.

    Raises ValueError("empty separator") if `by` is an empty separator.
    """
    if by is None:
        length = len(value)
        i = 0
        res = []
        while True:
            # find the beginning of the next word
            while i < length:
                if not _isspace(value[i]):
                    break  # found
                i += 1
            else:
                break  # end of string, finished

            # find the end of the word
            if maxsplit == 0:
                j = length   # take all the rest of the string
            else:
                j = i + 1
                while j < length and not _isspace(value[j]):
                    j += 1
                maxsplit -= 1   # NB. if it's already < 0, it stays < 0

            # the word is value[i:j]
            res.append(value[i:j])

            # continue to look from the character following the space after the word
            i = j + 1
        return res

    # The separator must be of the same kind as the value (a list of
    # characters is split by a str separator).
    if isinstance(value, unicode):
        assert isinstance(by, unicode)
    if isinstance(value, str):
        assert isinstance(by, str)
    if isinstance(value, list):
        assert isinstance(by, str)

    bylen = len(by)
    if bylen == 0:
        raise ValueError("empty separator")

    start = 0
    if bylen == 1:
        # fast path for a one-character separator: count its occurrences
        # once with count(), then locate each of them with find()
        by = by[0]    # annotator hack: string -> char
        cnt = count(value, by, 0, len(value))
        if 0 <= maxsplit < cnt:
            cnt = maxsplit
        res = newlist_hint(cnt + 1)
        while cnt > 0:
            next = find(value, by, start, len(value))
            assert next >= 0 # cannot fail: count() above found enough occurrences
            res.append(value[start:next])
            start = next + bylen
            cnt -= 1
        res.append(value[start:len(value)])
        return res

    # general case: pre-size the result only when the split count is bounded
    if maxsplit > 0:
        res = newlist_hint(min(maxsplit + 1, len(value)))
    else:
        res = []

    while maxsplit != 0:
        next = find(value, by, start, len(value))
        if next < 0:
            break
        assert start >= 0   # annotator hint: non-negative slice start
        res.append(value[start:next])
        start = next + bylen
        maxsplit -= 1   # NB. if it's already < 0, it stays < 0

    res.append(value[start:len(value)])
    return res
@specialize.argtype(0, 1)
def rsplit(value, by=None, maxsplit=-1):
    """Split `value` from the right, in the manner of str.rsplit().

    Accepts the same kinds of `value`/`by` as split().  The difference is
    only visible with maxsplit >= 0: the splits are taken from the end,
    so the leftover (unsplit) text is the first piece.  The pieces are
    collected right-to-left and reversed before returning.

    Raises ValueError("empty separator") if `by` is an empty separator.
    """
    if by is None:
        res = []

        i = len(value) - 1
        while True:
            # starting from the end, find the end of the next word
            while i >= 0:
                if not _isspace(value[i]):
                    break  # found
                i -= 1
            else:
                break  # end of string, finished

            # find the start of the word
            # (more precisely, 'j' will be the space character before the word)
            if maxsplit == 0:
                j = -1   # take all the rest of the string
            else:
                j = i - 1
                while j >= 0 and not _isspace(value[j]):
                    j -= 1
                maxsplit -= 1   # NB. if it's already < 0, it stays < 0

            # the word is value[j+1:i+1]
            j1 = j + 1
            assert j1 >= 0   # annotator hint: non-negative slice start
            res.append(value[j1:i+1])

            # continue to look from the character before the space before the word
            i = j - 1

        res.reverse()
        return res

    # The separator must be of the same kind as the value (a list of
    # characters is split by a str separator).
    if isinstance(value, unicode):
        assert isinstance(by, unicode)
    if isinstance(value, str):
        assert isinstance(by, str)
    if isinstance(value, list):
        assert isinstance(by, str)

    if maxsplit > 0:
        res = newlist_hint(min(maxsplit + 1, len(value)))
    else:
        res = []
    end = len(value)
    bylen = len(by)
    if bylen == 0:
        raise ValueError("empty separator")

    # each iteration peels off the piece after the last separator in value[:end]
    while maxsplit != 0:
        next = rfind(value, by, 0, end)
        if next < 0:
            break
        res.append(value[next + bylen:end])
        end = next
        maxsplit -= 1   # NB. if it's already < 0, it stays < 0

    res.append(value[:end])
    res.reverse()
    return res
@specialize.argtype(0, 1)
@jit.elidable
def replace(input, sub, by, maxsplit=-1):
    """Return `input` with occurrences of `sub` replaced by `by`.

    `input` may be a str, a unicode string or a list of characters; the
    result is of the same kind and is assembled with the matching builder
    (StringBuilder, UnicodeBuilder or ByteListBuilder).

    When maxsplit > 0 at most `maxsplit` replacements are made; a negative
    maxsplit means no limit and maxsplit == 0 means no replacement.  An
    empty `sub` inserts `by` before every element and once at the end, as
    str.replace('', by) does.

    For a list `input`, the result is always a new list, never `input`
    itself.

    Raises OverflowError if the size of the result overflows.
    """
    if isinstance(input, str):
        Builder = StringBuilder
    elif isinstance(input, unicode):
        Builder = UnicodeBuilder
    else:
        assert isinstance(input, list)
        Builder = ByteListBuilder
    if maxsplit == 0:
        if isinstance(input, list):
            # Lists are mutable: return a copy so that the result does not
            # alias the caller's list.  Every other path already builds a
            # fresh list with ByteListBuilder.
            return input[:]
        return input

    if not sub:
        upper = len(input)
        if maxsplit > 0 and maxsplit < upper + 2:
            upper = maxsplit - 1
            assert upper >= 0
        # NOTE(review): the handler that merely re-raises looks redundant,
        # but RPython presumably only checks ovfcheck() overflows when an
        # OverflowError handler is present in the same function.  Keep it.
        try:
            result_size = ovfcheck(upper * len(by))
            result_size = ovfcheck(result_size + upper)
            result_size = ovfcheck(result_size + len(by))
            remaining_size = len(input) - upper
            result_size = ovfcheck(result_size + remaining_size)
        except OverflowError:
            raise
        builder = Builder(result_size)
        # `by` goes before each of the first `upper` elements, once more
        # after them, and the remaining elements are copied unchanged
        for i in range(upper):
            builder.append(by)
            builder.append(input[i])
        builder.append(by)
        builder.append_slice(input, upper, len(input))
    else:
        # First compute the exact result size
        cnt = count(input, sub, 0, len(input))
        if cnt > maxsplit and maxsplit > 0:
            cnt = maxsplit
        diff_len = len(by) - len(sub)
        # see the NOTE above: the local OverflowError handler is deliberate
        try:
            result_size = ovfcheck(diff_len * cnt)
            result_size = ovfcheck(result_size + len(input))
        except OverflowError:
            raise
        builder = Builder(result_size)
        start = 0
        sublen = len(sub)

        while maxsplit != 0:
            next = find(input, sub, start, len(input))
            if next < 0:
                break
            builder.append_slice(input, start, next)
            builder.append(by)
            start = next + sublen
            maxsplit -= 1   # NB. if it's already < 0, it stays < 0

        builder.append_slice(input, start, len(input))

    return builder.build()
- def _normalize_start_end(length, start, end):
- if start < 0:
- start += length
- if start < 0:
- start = 0
- if end < 0:
- end += length
- if end < 0:
- end = 0
- elif end > length:
- end = length
- return start, end
@specialize.argtype(0, 1)
@jit.elidable
def startswith(u_self, prefix, start=0, end=sys.maxint):
    """Tell whether u_self[start:end] begins with `prefix`.

    `start` and `end` follow slice conventions (see _normalize_start_end).
    The comparison is done element by element.
    """
    start, end = _normalize_start_end(len(u_self), start, end)
    plen = len(prefix)
    if start + plen > end:
        return False
    k = 0
    while k < plen:
        if u_self[start + k] != prefix[k]:
            return False
        k += 1
    return True
@specialize.argtype(0, 1)
@jit.elidable
def endswith(u_self, suffix, start=0, end=sys.maxint):
    """Tell whether u_self[start:end] ends with `suffix`.

    `start` and `end` follow slice conventions (see _normalize_start_end).
    The comparison is done element by element, from the back.
    """
    start, end = _normalize_start_end(len(u_self), start, end)
    slen = len(suffix)
    first = end - slen
    if first < start:
        return False
    k = slen
    while k > 0:
        k -= 1
        if u_self[first + k] != suffix[k]:
            return False
    return True
@specialize.argtype(0, 1)
def find(value, other, start, end):
    """Index of the first occurrence of `other` within value[start:end], or -1.

    str/str and unicode/unicode pairs use the native find() method; any
    other combination (e.g. a list of characters) goes through _search().
    """
    if isinstance(value, str) and isinstance(other, str):
        return value.find(other, start, end)
    if isinstance(value, unicode) and isinstance(other, unicode):
        return value.find(other, start, end)
    return _search(value, other, start, end, SEARCH_FIND)
@specialize.argtype(0, 1)
def rfind(value, other, start, end):
    """Index of the last occurrence of `other` within value[start:end], or -1.

    str/str and unicode/unicode pairs use the native rfind() method; any
    other combination (e.g. a list of characters) goes through _search().
    """
    if isinstance(value, str) and isinstance(other, str):
        return value.rfind(other, start, end)
    if isinstance(value, unicode) and isinstance(other, unicode):
        return value.rfind(other, start, end)
    return _search(value, other, start, end, SEARCH_RFIND)
@specialize.argtype(0, 1)
def count(value, other, start, end):
    """Number of non-overlapping occurrences of `other` in value[start:end].

    str/str and unicode/unicode pairs use the native count() method; any
    other combination (e.g. a list of characters) goes through _search().
    """
    if isinstance(value, str) and isinstance(other, str):
        return value.count(other, start, end)
    if isinstance(value, unicode) and isinstance(other, unicode):
        return value.count(other, start, end)
    return _search(value, other, start, end, SEARCH_COUNT)
# -------------- substring searching helper ----------------
# XXX a lot of code duplication with lltypesystem.rstr :-(

# Values for the `mode` argument of _search(), selecting what it returns:
SEARCH_COUNT = 0   # number of non-overlapping occurrences
SEARCH_FIND = 1    # index of the first occurrence, or -1
SEARCH_RFIND = 2   # index of the last occurrence, or -1
def bloom_add(mask, c):
    """Return `mask` with the bloom bit of character `c` set.

    The bit index is ord(c) & (BLOOM_WIDTH - 1), so different characters
    may share a bit.
    """
    bit = ord(c) & (BLOOM_WIDTH - 1)
    return mask | (1 << bit)
def bloom(mask, c):
    """Return the bloom bit of character `c` in `mask` (0 if it is clear).

    A zero result proves `c` was never added with bloom_add(); a non-zero
    result may be a false positive.
    """
    bit = ord(c) & (BLOOM_WIDTH - 1)
    return mask & (1 << bit)
@specialize.argtype(0, 1)
def _search(value, other, start, end, mode):
    """Generic substring search behind find(), rfind() and count().

    Used for argument combinations without a native string method, e.g. a
    list of characters.  Searches for `other` inside value[start:end] and,
    depending on `mode`, returns:

      SEARCH_FIND   index of the first occurrence, or -1
      SEARCH_RFIND  index of the last occurrence, or -1
      SEARCH_COUNT  number of non-overlapping occurrences

    `start` is clamped at 0 and `end` at len(value).  A negative `end` is
    not interpreted relative to the end of `value`.

    The scan appears to mirror CPython's stringlib "fastsearch".  A bloom
    mask of the characters of `other` lets the loop jump over a whole
    window when the character just outside it cannot occur in `other`.
    """
    if start < 0:
        start = 0
    if end > len(value):
        end = len(value)
    if start > end:
        if mode == SEARCH_COUNT:
            return 0
        return -1

    count = 0    # NB. shadows the module-level count(); used for SEARCH_COUNT
    n = end - start
    m = len(other)

    if m == 0:
        # an empty needle matches at every position from start to end
        if mode == SEARCH_COUNT:
            return end - start + 1
        elif mode == SEARCH_RFIND:
            return end
        else:
            return start

    w = n - m    # last valid window offset, relative to start

    if w < 0:
        # the needle is longer than the searched range
        if mode == SEARCH_COUNT:
            return 0
        return -1

    mlast = m - 1
    skip = mlast - 1
    mask = 0

    if mode != SEARCH_RFIND:
        # Forward scan.  Put every character of `other` in the bloom mask.
        # `skip` becomes the shift that aligns the rightmost earlier copy
        # of the last character with the end of the window.
        for i in range(mlast):
            mask = bloom_add(mask, other[i])
            if other[i] == other[mlast]:
                skip = mlast - i - 1
        mask = bloom_add(mask, other[mlast])

        # i walks the window start from `start` to `start + w`.  A while
        # loop is used so that the body can advance i by more than one.
        i = start - 1
        while i + 1 <= start + w:
            i += 1
            if value[i + m - 1] == other[m - 1]:
                # the last character matches: compare the others
                for j in range(mlast):
                    if value[i + j] != other[j]:
                        break
                else:
                    if mode != SEARCH_COUNT:
                        return i
                    count += 1
                    # jump past this match (together with the loop's +1)
                    # so that counted occurrences do not overlap
                    i += mlast
                    continue

                # look at the character just after the window
                # ('\0' when that would be beyond len(value))
                if i + m < len(value):
                    c = value[i + m]
                else:
                    c = '\0'
                if not bloom(mask, c):
                    i += m
                else:
                    i += skip
            else:
                if i + m < len(value):
                    c = value[i + m]
                else:
                    c = '\0'
                if not bloom(mask, c):
                    i += m
    else:
        # Backward scan (SEARCH_RFIND), symmetric to the forward one.  The
        # mask covers `other`, and `skip` is based on the leftmost later
        # copy of the first character.
        mask = bloom_add(mask, other[0])
        for i in range(mlast, 0, -1):
            mask = bloom_add(mask, other[i])
            if other[i] == other[0]:
                skip = i - 1

        # i walks the window start downwards from `start + w` to `start`
        i = start + w + 1
        while i - 1 >= start:
            i -= 1
            if value[i] == other[0]:
                # the first character matches: compare the others
                for j in xrange(mlast, 0, -1):
                    if value[i + j] != other[j]:
                        break
                else:
                    return i
                # look at the character just before the window, if any
                if i - 1 >= 0 and not bloom(mask, value[i - 1]):
                    i -= m
                else:
                    i -= skip
            else:
                if i - 1 >= 0 and not bloom(mask, value[i - 1]):
                    i -= m

    if mode != SEARCH_COUNT:
        return -1
    return count
- # -------------- numeric parsing support --------------------
def strip_spaces(s):
    """Return `s` without leading and trailing ASCII whitespace.

    Only the characters ' ', '\\f', '\\n', '\\r', '\\t' and '\\v' are
    removed.
    """
    # XXX this is not locale-dependent
    lo, hi = 0, len(s)
    while lo < hi and s[lo] in ' \f\n\r\t\v':
        lo += 1
    while hi > lo and s[hi - 1] in ' \f\n\r\t\v':
        hi -= 1
    assert hi >= lo  # annotator hint, don't remove
    return s[lo:hi]
class ParseStringError(Exception):
    """Error raised when a string cannot be parsed as a number.

    The human-readable explanation is stored in the `msg` attribute.
    """

    def __init__(self, msg):
        self.msg = msg
class InvalidBaseError(ParseStringError):
    """Signals an invalid base argument (neither 0 nor within 2..36)."""
class ParseStringOverflowError(Exception):
    """Error that carries the parser object in its `parser` attribute.

    NOTE(review): presumably raised when the parsed value does not fit the
    target integer type, so that a caller can resume with the same parser.
    Confirm against the raising sites.  It deliberately does not derive
    from ParseStringError.
    """

    def __init__(self, parser):
        self.parser = parser
# iterator-like class
class NumberStringParser:
    """Walk over the digits of an integer literal, one at a time.

    The constructor consumes an optional '-' or '+' sign.  After a sign,
    the rest of the text is passed through strip_spaces().  It also removes
    a base prefix ('0x', '0o', '0b', in either case) that matches the
    base.  A base of 0 means the base is deduced from the prefix, and a
    bare leading '0' selects octal.

    next_digit() then returns the digit values left to right (-1 once
    exhausted), and prev_digit() walks them back.

    Attributes: sign (1 or -1), base (effective base), original_base (as
    passed in), s (the digit text), n (its length), i (the cursor).
    """

    def error(self):
        raise ParseStringError("invalid literal for %s() with base %d" %
                               (self.fname, self.original_base))

    def __init__(self, s, literal, base, fname):
        # NOTE(review): `literal` is accepted but unused in this version.
        self.fname = fname
        sign = 1
        if s.startswith('-'):
            sign = -1
            s = strip_spaces(s[1:])
        elif s.startswith('+'):
            s = strip_spaces(s[1:])
        self.sign = sign
        self.original_base = base

        if base == 0:
            # deduce the base from the literal's prefix
            if s.startswith('0x') or s.startswith('0X'):
                base = 16
            elif s.startswith('0b') or s.startswith('0B'):
                base = 2
            elif s.startswith('0'):   # also covers the '0o' case
                base = 8
            else:
                base = 10
        elif base < 2 or base > 36:
            raise InvalidBaseError("%s() base must be >= 2 and <= 36" % fname)
        self.base = base

        # drop the prefix that matches the effective base (at most one can)
        if ((base == 16 and (s.startswith('0x') or s.startswith('0X'))) or
                (base == 8 and (s.startswith('0o') or s.startswith('0O'))) or
                (base == 2 and (s.startswith('0b') or s.startswith('0B')))):
            s = s[2:]
        if not s:
            self.error()
        self.s = s
        self.n = len(s)
        self.i = 0

    def rewind(self):
        self.i = 0

    def _digit_value(self, c):
        # Map '0'-'9' to 0-9 and 'A'-'Z' / 'a'-'z' to 10-35; -1 otherwise.
        if '0' <= c <= '9':
            return ord(c) - ord('0')
        if 'A' <= c <= 'Z':
            return ord(c) - ord('A') + 10
        if 'a' <= c <= 'z':
            return ord(c) - ord('a') + 10
        return -1

    def next_digit(self):  # -1 => exhausted
        if self.i >= self.n:
            return -1
        digit = self._digit_value(self.s[self.i])
        if digit < 0 or digit >= self.base:
            self.error()
        self.i += 1
        return digit

    def prev_digit(self):
        # After exhausting all n digits in next_digit(), you can walk them
        # again in reverse order by calling prev_digit() exactly n times
        i = self.i - 1
        assert i >= 0
        self.i = i
        digit = self._digit_value(self.s[i])
        if digit < 0:
            raise AssertionError
        return digit
- # -------------- public API ---------------------------------
# Default `init_size` for the builder classes below.  ByteListBuilder uses
# it as a size hint; the untranslated AbstractStringBuilder ignores it.
INIT_SIZE = 100 # XXX tweak
class AbstractStringBuilder(object):
    # This is not the real implementation!
    # Plain-Python stand-in: the appended pieces are kept in a list and
    # joined by build().  Every method is marked NOT_RPYTHON; the
    # translated version presumably comes from the rtyper repr returned by
    # SomeStringBuilder/SomeUnicodeBuilder.  Subclasses set `_tp` to the
    # string type they accept and produce (str or unicode).

    def __init__(self, init_size=INIT_SIZE):
        "NOT_RPYTHON"
        # init_size is accepted for interface compatibility and ignored here
        self._l = []
        self._size = 0

    def _grow(self, size):
        "NOT_RPYTHON"
        self._size += size

    def append(self, s):
        "NOT_RPYTHON"
        assert isinstance(s, self._tp)
        self._l.append(s)
        self._grow(len(s))

    def append_slice(self, s, start, end):
        "NOT_RPYTHON"
        assert isinstance(s, self._tp)
        assert 0 <= start <= end <= len(s)
        piece = s[start:end]
        self._l.append(piece)
        self._grow(len(piece))

    def append_multiple_char(self, c, times):
        "NOT_RPYTHON"
        assert isinstance(c, self._tp)
        self._l.append(c * times)
        self._grow(times)

    def append_charpsize(self, s, size):
        "NOT_RPYTHON"
        assert size >= 0
        # `s` is indexed character by character (annotated as a pointer)
        chars = [s[k] for k in xrange(size)]
        self._l.append(self._tp("").join(chars))
        self._grow(size)

    def build(self):
        "NOT_RPYTHON"
        joined = self._tp("").join(self._l)
        assert len(joined) == self._size
        # keep only the joined result (presumably so that a later build()
        # does not redo the concatenation)
        self._l = [joined]
        return joined

    def getlength(self):
        "NOT_RPYTHON"
        return self._size
class StringBuilder(AbstractStringBuilder):
    """Builder that accepts str pieces and whose build() returns a str."""
    _tp = str
class UnicodeBuilder(AbstractStringBuilder):
    """Builder that accepts unicode pieces and whose build() returns unicode."""
    _tp = unicode
class ByteListBuilder(object):
    """Builder-like helper whose result is a list of characters.

    It mirrors the string builders' methods, so that code such as
    replace() can treat str, unicode and character lists uniformly.
    build() returns the internal list itself, not a copy.
    """

    def __init__(self, init_size=INIT_SIZE):
        assert init_size >= 0
        self.l = newlist_hint(init_size)

    @specialize.argtype(1)
    def append(self, s):
        dest = self.l
        for ch in s:
            dest.append(ch)

    @specialize.argtype(1)
    def append_slice(self, s, start, end):
        dest = self.l
        for idx in xrange(start, end):
            dest.append(s[idx])

    def append_multiple_char(self, c, times):
        assert isinstance(c, str)
        self.l.extend([c[0]] * times)

    def append_charpsize(self, s, size):
        assert size >= 0
        dest = self.l
        for idx in xrange(size):
            dest.append(s[idx])

    def build(self):
        return self.l

    def getlength(self):
        return len(self.l)
- # ------------------------------------------------------------
- # ----------------- implementation details -------------------
- # ------------------------------------------------------------
class SomeStringBuilder(SomeObject):
    """Annotation of a StringBuilder instance.

    Each method_<name> hook annotates a call to builder.<name>(): it checks
    the argument annotations and returns the annotation of the result.
    """

    def method_append(self, s_str):
        # the s_None annotation is let through; otherwise a str or a char
        assert s_str == s_None or isinstance(s_str, (SomeString, SomeChar))
        return s_None

    def method_append_slice(self, s_str, s_start, s_end):
        assert s_str == s_None or isinstance(s_str, SomeString)
        assert isinstance(s_start, SomeInteger)
        assert isinstance(s_end, SomeInteger)
        return s_None

    def method_append_multiple_char(self, s_char, s_times):
        assert isinstance(s_char, SomeChar)
        assert isinstance(s_times, SomeInteger)
        return s_None

    def method_append_charpsize(self, s_ptr, s_size):
        assert isinstance(s_ptr, SomePtr)
        assert isinstance(s_size, SomeInteger)
        return s_None

    def method_getlength(self):
        return SomeInteger(nonneg=True)

    def method_build(self):
        return SomeString()

    def rtyper_makerepr(self, rtyper):
        from rpython.rtyper.lltypesystem.rbuilder import stringbuilder_repr
        return stringbuilder_repr

    def rtyper_makekey(self):
        return (self.__class__,)

    def noneify(self):
        return self
class SomeUnicodeBuilder(SomeObject):
    """Annotation of a UnicodeBuilder instance.

    Each method_<name> hook annotates a call to builder.<name>(): it checks
    the argument annotations and returns the annotation of the result.
    """

    def method_append(self, s_str):
        # the s_None annotation is let through; otherwise unicode or a code point
        assert s_str == s_None or isinstance(
            s_str, (SomeUnicodeCodePoint, SomeUnicodeString))
        return s_None

    def method_append_slice(self, s_str, s_start, s_end):
        assert s_str == s_None or isinstance(s_str, SomeUnicodeString)
        assert isinstance(s_start, SomeInteger)
        assert isinstance(s_end, SomeInteger)
        return s_None

    def method_append_multiple_char(self, s_char, s_times):
        assert isinstance(s_char, SomeUnicodeCodePoint)
        assert isinstance(s_times, SomeInteger)
        return s_None

    def method_append_charpsize(self, s_ptr, s_size):
        assert isinstance(s_ptr, SomePtr)
        assert isinstance(s_size, SomeInteger)
        return s_None

    def method_getlength(self):
        return SomeInteger(nonneg=True)

    def method_build(self):
        return SomeUnicodeString()

    def rtyper_makerepr(self, rtyper):
        from rpython.rtyper.lltypesystem.rbuilder import unicodebuilder_repr
        return unicodebuilder_repr

    def rtyper_makekey(self):
        return (self.__class__,)

    def noneify(self):
        return self
class BaseEntry(object):
    """Shared registry logic for calls to StringBuilder(...)/UnicodeBuilder(...).

    Subclasses set `use_unicode` to choose which builder annotation the
    call produces.
    """

    def compute_result_annotation(self, s_init_size=None):
        if s_init_size is not None:
            assert isinstance(s_init_size, SomeInteger)
        if not self.use_unicode:
            return SomeStringBuilder()
        return SomeUnicodeBuilder()

    def specialize_call(self, hop):
        # construction is delegated to the result repr's rtyper_new()
        return hop.r_result.rtyper_new(hop)
class StringBuilderEntry(BaseEntry, ExtRegistryEntry):
    """Registry entry for calls to StringBuilder(...)."""
    use_unicode = False
    _about_ = StringBuilder
class UnicodeBuilderEntry(BaseEntry, ExtRegistryEntry):
    """Registry entry for calls to UnicodeBuilder(...)."""
    use_unicode = True
    _about_ = UnicodeBuilder
class __extend__(pairtype(SomeStringBuilder, SomeStringBuilder)):
    def union(pair):
        # merging two string-builder annotations yields the first one
        first, _ = pair
        return first
class __extend__(pairtype(SomeUnicodeBuilder, SomeUnicodeBuilder)):
    def union(pair):
        # merging two unicode-builder annotations yields the first one
        first, _ = pair
        return first
class PrebuiltStringBuilderEntry(ExtRegistryEntry):
    """Annotates StringBuilder instances met as prebuilt objects (via `_type_`)."""
    _type_ = StringBuilder

    def compute_annotation(self):
        return SomeStringBuilder()
class PrebuiltUnicodeBuilderEntry(ExtRegistryEntry):
    """Annotates UnicodeBuilder instances met as prebuilt objects (via `_type_`)."""
    _type_ = UnicodeBuilder

    def compute_annotation(self):
        return SomeUnicodeBuilder()
- #___________________________________________________________________
- # Support functions for SomeString.no_nul
def assert_str0(fname):
    """Return `fname` unchanged, asserting that it contains no NUL byte.

    During annotation, the Entry class registered for this function gives
    the result the argument's annotation with no_nul=True, so later code
    sees the value as NUL-free.
    """
    assert '\x00' not in fname, "NUL byte in string"
    return fname
class Entry(ExtRegistryEntry):
    """Annotator/rtyper support for assert_str0().

    The result annotation is the argument's annotation with no_nul set.
    After rtyping, the call just yields its argument.
    """
    _about_ = assert_str0

    def compute_result_annotation(self, s_obj):
        if not s_None.contains(s_obj):
            assert isinstance(s_obj, (SomeString, SomeUnicodeString))
            if not s_obj.no_nul:
                # clone the annotation without running __init__, then flag
                # the clone as NUL-free
                s_nonul = SomeObject.__new__(s_obj.__class__)
                s_nonul.__dict__ = s_obj.__dict__.copy()
                s_nonul.no_nul = True
                return s_nonul
        return s_obj

    def specialize_call(self, hop):
        hop.exception_cannot_occur()
        return hop.inputarg(hop.args_r[0], arg=0)
def check_str0(fname):
    """A 'probe' to trigger a failure at translation time, if the
    string was not proved to not contain NUL characters.

    When run untranslated, it only asserts that `fname` has no NUL byte.
    It returns None.
    """
    assert '\x00' not in fname, "NUL byte in string"
- class Entry(ExtRegistryEntry):
- _about_ = check_str0
- def compute_result_annotation(self, s_obj):
- if not isinstance(s_obj, (SomeString, SomeUnicodeString)):
- return s_obj
- if not s_obj.no_nul:
- raise ValueError("Value is not no_nul")
- def specialize_call(self, hop):
- hop.exception_cannot_occur()