PageRenderTime 28ms CodeModel.GetById 17ms RepoModel.GetById 0ms app.codeStats 0ms

/sympycore/heads/add.py

https://github.com/escheffel/pymaclab
Python | 341 lines | 333 code | 6 blank | 2 comment | 0 complexity | 8d957d5a7dfb2974424212e904bff8e5 MD5 | raw file
# Public API of this module: only the ADD head singleton defined at the end.
__all__ = ['ADD']
from .base import heads_precedence, ArithmeticHead
from ..core import init_module
# These calls populate this module's namespace. Judging by the names used
# below without a local definition (NUMBER, SUB, TERM_COEFF_DICT, POW,
# dict_add_item, add_new, mul_new, inttypes_set, ...), they presumably inject
# the head singletons, number types and low-level helpers -- TODO confirm in
# ..core.init_module.
init_module.import_heads()
init_module.import_numbers()
init_module.import_lowlevel_operations()
  7. @init_module
  8. def _init(module):
  9. from ..arithmetic.number_theory import multinomial_coefficients
  10. module.multinomial_coefficients = multinomial_coefficients
  11. class AddHead(ArithmeticHead):
  12. """
  13. AddHead represents addition n-ary operation where operands is
  14. given as a n-sequence of expressions. For example, expression 'a +
  15. 2*b' is 'Expr(ADD, (a, 2*b))' where ADD=AddHead()
  16. """
  17. op_mth = '__add__'
  18. op_rmth = '__radd__'
  19. def is_data_ok(self, cls, data):
  20. if type(data) in [tuple, list]:
  21. for a in data:
  22. if not isinstance(a, cls):
  23. return '%s data item must be %s instance but got %s' % (self, cls, type(a))
  24. else:
  25. return '%s data part must be a list but got %s' % (self, type(data))
  26. def __repr__(self): return 'ADD'
  27. def new(self, cls, operands, evaluate=True):
  28. if not evaluate:
  29. n = len(operands)
  30. if n==1:
  31. return operands[0]
  32. if n==0:
  33. return cls(NUMBER, 0)
  34. return cls(self, operands)
  35. d = {}
  36. l = []
  37. operands = list(operands)
  38. while operands:
  39. op = operands.pop(0)
  40. if op==0:
  41. continue
  42. head, data = op.pair
  43. if head is ADD:
  44. operands.extend(data)
  45. continue
  46. elif head is SUB:
  47. operands.append(data[0])
  48. for o in data[1:]:
  49. operands.append(-o)
  50. continue
  51. elif head is TERM_COEFF_DICT:
  52. for term, coeff in data.iteritems():
  53. n = len(d)
  54. dict_add_item(cls, d, term, coeff)
  55. if n < len(d):
  56. l.append(term)
  57. elif n > len(d):
  58. l.remove(term)
  59. else:
  60. term, coeff = op.head.term_coeff(cls, op)
  61. n = len(d)
  62. dict_add_item(cls, d, term, coeff)
  63. if n < len(d):
  64. l.append(term)
  65. elif n > len(d):
  66. l.remove(term)
  67. r = []
  68. one = cls(NUMBER, 1)
  69. for term in l:
  70. r.append(TERM_COEFF.new(cls, (term, d[term])))
  71. m = len(r)
  72. if m==0:
  73. return cls(NUMBER, 0)
  74. if m==1:
  75. return r[0]
  76. return cls(self, r)
  77. def reevaluate(self, cls, operands):
  78. r = cls(NUMBER, 0)
  79. for op in operands:
  80. r += op
  81. return r
  82. def data_to_str_and_precedence(self, cls, operands):
  83. m = len(operands)
  84. if m==0:
  85. return '0', heads_precedence.NUMBER
  86. if m==1:
  87. op = operands[0]
  88. return op.head.data_to_str_and_precedence(cls, op.data)
  89. add_p = heads_precedence.ADD
  90. r = ''
  91. evaluate_addition = cls.algebra_options.get('evaluate_addition')
  92. for op in operands:
  93. t,t_p = op.head.data_to_str_and_precedence(cls, op.data)
  94. if not r:
  95. r += '(' + t + ')' if t_p < add_p else t
  96. elif evaluate_addition and t.startswith('-'):
  97. r += ' - ' + t[1:]
  98. else:
  99. r += ' + (' + t + ')' if t_p < add_p else ' + ' + t
  100. return r, add_p
  101. def term_coeff(self, cls, expr):
  102. term_list = expr.data
  103. if not term_list:
  104. return cls(NUMBER, 0), 1
  105. if len(term_list)==1:
  106. expr = term_list[0]
  107. return expr.head.term_coeff(cls, expr)
  108. return expr, 1
  109. def neg(self, cls, expr):
  110. return cls(ADD, [-term for term in expr.data])
  111. def add(self, cls, lhs, rhs):
  112. term_list = lhs.data
  113. rhead, rdata = rhs.pair
  114. if rhead is ADD:
  115. return ADD.new(cls, lhs.data + rdata)
  116. if rhead is TERM_COEFF_DICT:
  117. rdata = [t * c for t, c in rdata.items()]
  118. return ADD.new(cls, lhs.data + rdata)
  119. if rhead is SUB:
  120. rdata = rdata[:1] + [-op for op in rdata[1:]]
  121. return ADD.new(cls, lhs.data + rdata)
  122. return ADD.new(cls, lhs.data + [rhs])
  123. inplace_add = add
  124. def sub(self, cls, lhs, rhs):
  125. return lhs + (-rhs)
  126. def commutative_mul(self, cls, lhs, rhs):
  127. rhead, rdata = rhs.pair
  128. if rhead is NUMBER:
  129. return ADD.new(cls, [op*rhs for op in lhs.data])
  130. if rhead is SYMBOL:
  131. return cls(BASE_EXP_DICT, {lhs:1, rhs:1})
  132. if rhead is ADD:
  133. if lhs==rhs:
  134. return lhs ** 2
  135. return cls(BASE_EXP_DICT, {lhs:1, rhs:1})
  136. if rhead is TERM_COEFF:
  137. term, coeff = rdata
  138. return (lhs * term) * coeff
  139. if rhead is POW:
  140. base, exp = rhs.data
  141. if lhs==base:
  142. return POW.new(cls, (base, exp + 1))
  143. return cls(BASE_EXP_DICT, {lhs:1, base:exp})
  144. if rhead is BASE_EXP_DICT:
  145. data = rdata.copy()
  146. dict_add_item(cls, data, lhs, 1)
  147. return BASE_EXP_DICT.new(cls, data)
  148. raise NotImplementedError(`self, lhs.pair, rhs.pair`)
  149. def commutative_mul_number(self, cls, lhs, rhs):
  150. return ADD.new(cls, [op*rhs for op in lhs.data])
  151. def pow(self, cls, base, exp):
  152. return POW.new(cls, (base, exp))
  153. pow_number = pow
  154. def walk(self, func, cls, data, target):
  155. l = []
  156. flag = False
  157. for op in data:
  158. o = op.head.walk(func, cls, op.data, op)
  159. if op is not o:
  160. flag = True
  161. l.append(o)
  162. if flag:
  163. r = ADD.new(cls, l)
  164. return func(cls, r.head, r.data, r)
  165. return func(cls, self, data, target)
  166. def scan(self, proc, cls, operands, target):
  167. for operand in operands:
  168. operand.head.scan(proc, cls, operand.data, target)
  169. proc(cls, self, operands, target)
  170. def expand(self, cls, expr):
  171. l = []
  172. for op in expr.data:
  173. h, d = op.pair
  174. l.append(h.expand(cls, op))
  175. return self.new(cls, l)
  176. def expand_intpow(self, cls, expr, intexp):
  177. if intexp<=1:
  178. return POW.new(cls, (expr, intexp))
  179. operands = expr.data
  180. mdata = multinomial_coefficients(len(operands), intexp)
  181. s = cls(NUMBER, 0)
  182. for exps, n in mdata.iteritems():
  183. m = cls(NUMBER, n)
  184. for i,e in enumerate(exps):
  185. m *= operands[i] ** e
  186. s += m
  187. return s
  188. def to_TERM_COEFF_DICT(self, Algebra, data, expr):
  189. s = Algebra(NUMBER, 0)
  190. for op in data:
  191. s += op.head.to_TERM_COEFF_DICT(Algebra, op.data, op)
  192. return s
  193. def to_ADD(self, Algebra, data, expr):
  194. return expr
  195. def algebra_pos(self, Algebra, expr):
  196. if Algebra.algebra_options.get('evaluate_addition'):
  197. if Algebra.algebra_options.get('is_additive_group_commutative'):
  198. return +ADD.to_TERM_COEFF_DICT(Algebra, expr.data, expr)
  199. return expr
  200. def algebra_neg(self, Algebra, expr):
  201. if Algebra.algebra_options.get('evaluate_addition'):
  202. if Algebra.algebra_options.get('is_additive_group_commutative'):
  203. return -ADD.to_TERM_COEFF_DICT(Algebra, expr.data, expr)
  204. return add_new(Algebra, [-op for op in expr.data[::-1]])
  205. return Algebra(NEG, expr)
  206. def combine_add_list(self, Algebra, data):
  207. """
  208. Combine add operands of an additive group in data.
  209. data will be changed in place.
  210. """
  211. commutative = Algebra.algebra_options.get('is_additive_group_commutative')
  212. if commutative:
  213. d = {}
  214. for op in data:
  215. term, coeff = op.head.term_coeff(Algebra, op)
  216. term_coeff_dict_add_item(Algebra, d, term, coeff)
  217. data[:] = [term_coeff_new(Algebra, term_coeff) for term_coeff in d.iteritems()]
  218. else:
  219. n = len(data)
  220. i0 = 0
  221. while 1:
  222. i = i0
  223. if i+1 >= n:
  224. break
  225. lhs = data[i]
  226. rhs = data[i+1]
  227. lterm, lcoeff = lhs.head.term_coeff(Algebra, lhs)
  228. rterm, rcoeff = rhs.head.term_coeff(Algebra, rhs)
  229. if lterm==rterm:
  230. coeff = lcoeff + rcoeff
  231. if coeff:
  232. del data[i+1]
  233. data[i] = term_coeff_new(Algebra, (lterm, coeff))
  234. i0 = i
  235. n -= 1
  236. else:
  237. del data[i:i+2]
  238. i0 = max(i - 1, 0)
  239. n -= 2
  240. elif not rcoeff:
  241. del data[i+1]
  242. n -= 1
  243. i0 = i
  244. elif not lcoeff:
  245. del data[i]
  246. n -= 1
  247. i0 = max(i-1,0)
  248. else:
  249. i0 += 1
  250. return data
  251. def algebra_add_number(self, Algebra, lhs, rhs, inplace):
  252. return self.algebra_add(Algebra, lhs, Algebra(NUMBER, rhs), inplace)
  253. def algebra_add(self, Algebra, lhs, rhs, inplace):
  254. rhead, rdata = rhs.pair
  255. if rhead is TERM_COEFF_DICT or rhead is EXP_COEFF_DICT or rhead is MUL or rhead is NEG:
  256. rhs = rhs.to(ADD)
  257. rhead, rdata = rhs.pair
  258. if inplace:
  259. data = lhs.data
  260. else:
  261. data = lhs.data[:]
  262. if rhead is ADD:
  263. data.extend(rdata)
  264. else:
  265. data.append(rhs)
  266. if Algebra.algebra_options.get('evaluate_addition'):
  267. self.combine_add_list(Algebra, data)
  268. if inplace:
  269. return add(Algebra, lhs)
  270. return add_new(Algebra, data)
  271. def algebra_mul_number(self, Algebra, lhs, rhs, inplace):
  272. ntype = type(rhs)
  273. if Algebra.algebra_options.get('is_additive_group_commutative'):
  274. if not rhs:
  275. return Algebra(NUMBER, 0)
  276. return ADD.to_TERM_COEFF_DICT(Algebra, lhs.data, lhs) * rhs
  277. else:
  278. if Algebra.algebra_options.get('evaluate_addition'):
  279. if rhs == 0:
  280. return Algebra(NUMBER, 0)
  281. if rhs == 1:
  282. return lhs
  283. if ntype in inttypes_set:
  284. if rhs > 0:
  285. # (x+y)*3 = x+y+x+y+x+y
  286. # TODO: optimize (x+y+x)*3 = x+y+x+x+y++x+x+y+x = x+y+2*x+y+2*x+y+x
  287. data = lhs.data * rhs
  288. else:
  289. data = [-op for op in (lhs.data * (-rhs))[::-1]]
  290. self.combine_add_list(Algebra, data)
  291. return add_new(Algebra, data)
  292. return mul_new(Algebra, [lhs, Algebra(NUMBER, rhs)])
  293. def algebra_mul(self, Algebra, lhs, rhs, inplace):
  294. ldata = lhs.data
  295. if Algebra.algebra_options.get('is_additive_group_commutative'):
  296. return ADD.to_TERM_COEFF_DICT(Algebra, lhs.data, lhs) * rhs
  297. else:
  298. if Algebra.algebra_options.get('evaluate_addition'):
  299. rhead, rdata = rhs.pair
  300. if rhead is NUMBER:
  301. return ADD.algebra_mul_number(Algebra, lhs, rdata, inplace)
  302. return super(type(self), self).algebra_mul(Algebra, lhs, rhs, inplace)
  303. return mul_new(Algebra, [lhs, rhs])
  304. ADD = AddHead()