PageRenderTime 797ms CodeModel.GetById 13ms RepoModel.GetById 0ms app.codeStats 0ms

/sympycore/heads/mul.py

https://github.com/escheffel/pymaclab
Python | 350 lines | 330 code | 10 blank | 10 comment | 11 complexity | e4a563d2d8a9185904de616e14ae04b7 MD5 | raw file
  1. __all__ = ['MUL']
  2. from .base import Head, heads_precedence, Pair, Expr, ArithmeticHead
  3. from ..core import init_module
  4. init_module.import_heads()
  5. init_module.import_numbers()
  6. init_module.import_lowlevel_operations()
  7. class MulHead(ArithmeticHead, Head):
  8. """
  9. Algebra(MUL, <list of factors>)
  10. The list of factors is should not contain numeric expressions
  11. representing a coefficient, numeric part should be represented
  12. via TERM_COEFF expression:
  13. Algebra(TERM_COEFF, (Algebra(MUL, <list of factors>), <numeric part>))
  14. """
  15. op_mth = '__mul__'
  16. op_rmth = '__rmul__'
  17. def is_data_ok(self, cls, data):
  18. if type(data) in [tuple, list]:
  19. for a in data:
  20. if not isinstance(a, cls):
  21. return '%s data item must be %s instance but got %s' % (self, cls, type(a))
  22. else:
  23. return '%s data part must be a list but got %s' % (self, type(data))
  24. def __repr__(self): return 'MUL'
  25. def new(self, cls, operands, evaluate=True):
  26. operands = [op for op in operands if op!=1]
  27. n = len(operands)
  28. if n==0:
  29. return cls(NUMBER, 1)
  30. if n==1:
  31. return operands[0]
  32. return cls(MUL, operands)
  33. def reevaluate(self, cls, operands):
  34. r = operands[0]
  35. for op in operands[1:]:
  36. r *= op
  37. return r
  38. def base_exp(self, cls, expr):
  39. return expr, 1
  40. def data_to_str_and_precedence(self, cls, operands):
  41. m = len(operands)
  42. if m==0:
  43. return '1', heads_precedence.NUMBER
  44. if m==1:
  45. op = operands[0]
  46. return op.head.data_to_str_and_precedence(cls, op.data)
  47. mul_p = heads_precedence.MUL
  48. r = ''
  49. for op in operands:
  50. f, f_p = op.head.data_to_str_and_precedence(cls, op.data)
  51. if f=='1': continue
  52. if not r or r=='-':
  53. r += '('+f+')' if f_p<mul_p else f
  54. elif f.startswith('1/'):
  55. r += '*('+f+')' if f_p<mul_p else f[1:]
  56. else:
  57. r += '*('+f+')' if f_p<mul_p else '*'+f
  58. if not r:
  59. return '1', heads_precedence.NUMBER
  60. return r, mul_p
  61. def term_coeff(self, Algebra, expr):
  62. data = []
  63. coeff = 1
  64. for op in expr.data:
  65. t, c = op.head.term_coeff(Algebra, op)
  66. if c is not 1:
  67. coeff *= c
  68. if t.head is NUMBER:
  69. assert t.data==1,`t`
  70. else:
  71. data.append(t)
  72. else:
  73. data.append(op)
  74. if coeff is 1:
  75. return expr, 1
  76. return mul_new(Algebra, data), coeff
  77. def combine(self, cls, factors_list):
  78. """ Combine factors in a list and return result.
  79. """
  80. lst = []
  81. compart = 1
  82. for factor in factors_list:
  83. if factor.head is NUMBER:
  84. compart = compart * factor.data
  85. continue
  86. r = None
  87. b2, e2 = factor.head.base_exp(cls, factor)
  88. if lst and b2.head is MUL:
  89. l = b2.data
  90. if len(l)<=len(lst) and lst[-len(l):]==l:
  91. # x*a*b*(a*b)**2 -> x*(a*b)**3
  92. r = b2 ** (e2 + 1)
  93. del lst[-len(l):]
  94. if lst and r is None:
  95. b1, e1 = lst[-1].head.base_exp(cls, lst[-1])
  96. if b1==b2:
  97. # x*a**3*a**2 -> x*a**5
  98. r = b2 ** (e1 + e2)
  99. del lst[-1]
  100. if r is None:
  101. lst.append(factor)
  102. for i in range(2,len(lst)):
  103. b1, e1 = lst[-i-1].head.base_exp(cls, lst[-i-1]);
  104. if b1.head is MUL:
  105. c, l = b1.data
  106. if l == lst[-i:] and c==1:
  107. # x*(a*b)**2*a*b -> x*(a*b)**3
  108. r = b1 ** (e1 + 1)
  109. del lst[-i-1:]
  110. break
  111. if lst[-i:]==lst[-2*i:-i]:
  112. # x*a*b * a*b -> x*(a*b)**2
  113. r = cls(MUL, lst[-i:])**2
  114. del lst[-2*i:]
  115. break
  116. if r is not None:
  117. if r.head is NUMBER:
  118. compart = compart * r
  119. else:
  120. lst.append(r)
  121. if not lst:
  122. return compart
  123. if compart==1:
  124. if len(lst)==1:
  125. return lst[0]
  126. return cls(MUL, lst)
  127. return cls(TERM_COEFF, (cls(MUL, lst), compart))
  128. def non_commutative_mul(self, cls, lhs, rhs):
  129. head, data = rhs.pair
  130. if head is NUMBER:
  131. if data==1:
  132. return lhs
  133. return TERM_COEFF.new(cls, (lhs, data))
  134. if head is SYMBOL or head is POW:
  135. return self.combine(cls, lhs.data + [rhs])
  136. if head is TERM_COEFF:
  137. term, coeff = data
  138. return (lhs * term) * coeff
  139. if head is MUL:
  140. return self.combine(cls, lhs.data + rhs.data)
  141. raise NotImplementedError(`self, cls, lhs.pair, rhs.pair`)
  142. def non_commutative_mul_number(self, cls, lhs, rhs):
  143. return term_coeff_new(cls, (lhs, rhs))
  144. non_commutative_rmul_number = non_commutative_mul_number
  145. def pow(self, cls, base, exp):
  146. if exp==0:
  147. return cls(NUMBER, 1)
  148. if exp==1:
  149. return base
  150. if exp==-1:
  151. factors_list = [factor**-1 for factor in base.data]
  152. factors_list.reverse()
  153. return cls(MUL, factors_list)
  154. term, coeff = self.term_coeff(cls, base)
  155. if coeff!=1:
  156. return NUMBER.pow(cls, coeff, exp) * term**exp
  157. if isinstance(exp, Expr):
  158. h, d = exp.pair
  159. if h is NUMBER and isinstance(d, inttypes) and d>0:
  160. factors_list = base.data
  161. first = factors_list[0]
  162. last = factors_list[-1]
  163. a = last * first
  164. if a is not None and a.head is NUMBER: # todo: or a is commutative
  165. compart = NUMBER.pow_number(cls, a, d)
  166. rest = factors_list[1:-1]
  167. if not rest:
  168. return compart
  169. if len(rest)==1:
  170. middle = rest[0]
  171. else:
  172. middle = cls(MUL, rest)
  173. return compart * first * middle**d * last # could be optimized
  174. if h is NUMBER:
  175. exp = d
  176. return cls(POW, (base, exp))
  177. def pow_number(self, cls, base, exp):
  178. return self.pow(cls, base, cls(NUMBER, exp))
  179. def walk(self, func, cls, data, target):
  180. l = []
  181. flag = False
  182. for op in data:
  183. o = op.head.walk(func, cls, op.data, op)
  184. if op is not o:
  185. flag = True
  186. l.append(o)
  187. if flag:
  188. r = MUL.new(cls, l)
  189. return func(cls, r.head, r.data, r)
  190. return func(cls, self, data, target)
  191. def scan(self, proc, cls, operands, target):
  192. for operand in operands:
  193. operand.head.scan(proc, cls, operand.data, target)
  194. proc(cls, self, operands, target)
  195. def combine_mul_list(self, Algebra, data):
  196. """
  197. Combine mul operands of an multiplicative group in data.
  198. data will be changed in place.
  199. """
  200. commutative = Algebra.algebra_options.get('is_multiplicative_group_commutative')
  201. coeff = 1
  202. if commutative:
  203. d = {}
  204. for op in data:
  205. base, exp = op.head.base_exp(Algebra, op)
  206. base_exp_dict_add_item(Algebra, d, base, exp)
  207. data[:] = [pow_new(Algebra, base_exp) for base_exp in d.iteritems()]
  208. else:
  209. n = len(data)
  210. i0 = 0
  211. while 1:
  212. i = i0
  213. if i+1 >= n:
  214. break
  215. lhs = data[i]
  216. if lhs.head is NUMBER:
  217. coeff *= lhs.data
  218. del data[i]
  219. n -= 1
  220. continue
  221. rhs = data[i+1]
  222. if rhs.head is NUMBER:
  223. coeff *= rhs.data
  224. del data[i+1]
  225. n -= 1
  226. continue
  227. lbase, lexp = lhs.head.base_exp(Algebra, lhs)
  228. rbase, rexp = rhs.head.base_exp(Algebra, rhs)
  229. if lbase==rbase:
  230. exp = lexp + rexp
  231. if exp:
  232. del data[i+1]
  233. data[i] = pow_new(Algebra, (lbase, exp))
  234. i0 = i
  235. n -= 1
  236. else:
  237. del data[i:i+2]
  238. i0 = max(i - 1, 0)
  239. n -= 2
  240. elif not rexp:
  241. del data[i+1]
  242. n -= 1
  243. i0 = i
  244. elif not lexp:
  245. del data[i]
  246. n -= 1
  247. i0 = max(i-1,0)
  248. else:
  249. i0 += 1
  250. return coeff
  251. def to_TERM_COEFF_DICT(self, Algebra, data, expr):
  252. m = data[0].to(TERM_COEFF_DICT)
  253. for op in data[1:]:
  254. m *= op.to(TERM_COEFF_DICT)
  255. return m
  256. def to_ADD(self, Algebra, data, expr):
  257. m = data[0].to(ADD)
  258. for op in data[1:]:
  259. m *= op.head.to_ADD(Algebra, op.data, op)
  260. return m
  261. def algebra_pos(self, Algebra, expr):
  262. return expr
  263. def algebra_neg(self, Algebra, expr):
  264. if Algebra.algebra_options.get('evaluate_addition'):
  265. return self.algebra_mul_number(Algebra, expr, -1, False)
  266. return Algebra(NEG, expr)
  267. def algebra_add_number(self, Algebra, lhs, rhs, inplace):
  268. return self.algebra_add(Algebra, lhs, Algebra(NUMBER, rhs), inplace)
  269. def algebra_add(self, Algebra, lhs, rhs, inplace):
  270. if Algebra.algebra_options.get('evaluate_addition'):
  271. rhead, rdata = rhs.pair
  272. if rhead is TERM_COEFF_DICT or rhead is EXP_COEFF_DICT:
  273. rhs = rhs.head.to_ADD(Algebra, rdata, rhs)
  274. rhead, rdata = rhs.pair
  275. if rhead is ADD:
  276. data = [lhs] + rdata
  277. else:
  278. data = [lhs, rhs]
  279. ADD.combine_add_list(Algebra, data)
  280. return add_new(Algebra, data)
  281. return Algebra(ADD, [lhs, rhs])
  282. def algebra_mul_number(self, Algebra, lhs, rhs, inplace):
  283. return self.algebra_mul(Algebra, lhs, Algebra(NUMBER, rhs), inplace)
  284. def algebra_mul(self, Algebra, lhs, rhs, inplace):
  285. rhead, rdata = rhs.pair
  286. if rhead is BASE_EXP_DICT or rhead is TERM_COEFF:
  287. rhs = rhs.to(MUL)
  288. rhead, rdata = rhs.pair
  289. if inplace:
  290. data = lhs.data
  291. else:
  292. data = lhs.data[:]
  293. if rhead is MUL:
  294. data.extend(rdata)
  295. else:
  296. data.append(rhs)
  297. if Algebra.algebra_options.get('evaluate_multiplication'):
  298. coeff = self.combine_mul_list(Algebra, data)
  299. if not coeff:
  300. return Algebra(NUMBER, 0)
  301. if coeff != 1:
  302. if len(data)==1:
  303. return Algebra(TERM_COEFF, (data[0], coeff))
  304. data.insert(0, Algebra(NUMBER, coeff))
  305. if inplace:
  306. return mul(Algebra, lhs)
  307. return mul_new(Algebra, data)
  308. def algebra_pow_number(self, Algebra, lhs, rhs, inplace):
  309. if Algebra.algebra_options.get('evaluate_multiplication'):
  310. if rhs==1:
  311. return lhs
  312. if not rhs:
  313. return Algebra(NUMBER, 1)
  314. if rhs==-1:
  315. return Algebra(MUL, [op**-1 for op in lhs.data[::-1]])
  316. return super(type(self), self).algebra_pow_number(Algebra, lhs, rhs, inplace)
  317. MUL = MulHead()