PageRenderTime 43ms CodeModel.GetById 14ms RepoModel.GetById 0ms app.codeStats 0ms

/languages/Python/modules/sympy/utilities/tests/test_pickling.py

https://bitbucket.org/ipre/calico
Python | 407 lines | 312 code | 66 blank | 29 comment | 53 complexity | e9271cb63d8adf4b1c2422e84d218b35 MD5 | raw file
Possible License(s): LGPL-2.1, LGPL-3.0, GPL-2.0, GPL-3.0, LGPL-2.0
  1. import copy
  2. import pickle
  3. import warnings
  4. import sys
  5. from sympy.utilities.pytest import XFAIL
  6. from sympy.core.basic import Atom, Basic
  7. from sympy.core.core import BasicMeta, BasicType, ClassRegistry
  8. from sympy.core.singleton import SingletonRegistry
  9. from sympy.core.symbol import Dummy, Symbol, Wild
  10. from sympy.core.numbers import (E, I, pi, oo, zoo, nan, Integer, Number,
  11. NumberSymbol, Rational, Float)
  12. from sympy.core.relational import (Equality, GreaterThan, LessThan, Relational,
  13. StrictGreaterThan, StrictLessThan, Unequality)
  14. from sympy.core.add import Add
  15. from sympy.core.mul import Mul
  16. from sympy.core.power import Pow
  17. from sympy.core.function import Derivative, Function, FunctionClass, Lambda,\
  18. WildFunction
  19. from sympy.core.sets import Interval
  20. from sympy.core.multidimensional import vectorize
  21. from sympy.functions import exp
  22. #from sympy.core.ast_parser import SymPyParser, SymPyTransformer
  23. from sympy.core.compatibility import callable
  24. from sympy.utilities.exceptions import SymPyDeprecationWarning
  25. from sympy import symbols, S
  26. excluded_attrs = set(['_assumptions', '_mhash'])
  27. def check(a, check_attr=True):
  28. """ Check that pickling and copying round-trips.
  29. """
  30. # The below hasattr() check will warn about is_Real in Python 2.5, so
  31. # disable this to keep the tests clean
  32. warnings.filterwarnings("ignore", category=SymPyDeprecationWarning)
  33. protocols = [0, 1, 2, copy.copy, copy.deepcopy]
  34. # Python 2.x doesn't support the third pickling protocol
  35. if sys.version_info[0] > 2:
  36. protocols.extend([3])
  37. for protocol in protocols:
  38. if callable(protocol):
  39. if isinstance(a, BasicType):
  40. # Classes can't be copied, but that's okay.
  41. return
  42. b = protocol(a)
  43. else:
  44. b = pickle.loads(pickle.dumps(a, protocol))
  45. d1 = dir(a)
  46. d2 = dir(b)
  47. assert d1==d2
  48. if not check_attr:
  49. continue
  50. def c(a, b, d):
  51. for i in d:
  52. if not hasattr(a, i) or i in excluded_attrs:
  53. continue
  54. attr = getattr(a, i)
  55. if not hasattr(attr, "__call__"):
  56. assert hasattr(b,i), i
  57. assert getattr(b,i) == attr
  58. c(a,b,d1)
  59. c(b,a,d2)
  60. warnings.filterwarnings("default", category=SymPyDeprecationWarning)
  61. #================== core =========================
  62. def test_core_basic():
  63. for c in (Atom, Atom(),
  64. Basic, Basic(),
  65. # XXX: dynamically created types are not picklable
  66. # BasicMeta, BasicMeta("test", (), {}),
  67. # BasicType, BasicType("test", (), {}),
  68. ClassRegistry, ClassRegistry(),
  69. SingletonRegistry, SingletonRegistry()):
  70. check(c)
  71. def test_core_symbol():
  72. # make the Symbol a unique name that doesn't class with any other
  73. # testing variable in this file since after this test the symbol
  74. # having the same name will be cached as noncommutative
  75. for c in (Dummy, Dummy("x", commutative=False), Symbol,
  76. Symbol("_issue_3130", commutative=False), Wild, Wild("x")):
  77. check(c)
  78. def test_core_numbers():
  79. for c in (Integer(2), Rational(2, 3), Float("1.2")):
  80. check(c)
  81. def test_core_relational():
  82. x = Symbol("x")
  83. y = Symbol("y")
  84. for c in (Equality, Equality(x,y), GreaterThan, GreaterThan(x, y),
  85. LessThan, LessThan(x,y), Relational, Relational(x,y),
  86. StrictGreaterThan, StrictGreaterThan(x,y), StrictLessThan,
  87. StrictLessThan(x,y), Unequality, Unequality(x,y)):
  88. check(c)
  89. def test_core_add():
  90. x = Symbol("x")
  91. for c in (Add, Add(x,4)):
  92. check(c)
  93. def test_core_mul():
  94. x = Symbol("x")
  95. for c in (Mul, Mul(x,4)):
  96. check(c)
  97. def test_core_power():
  98. x = Symbol("x")
  99. for c in (Pow, Pow(x,4)):
  100. check(c)
  101. def test_core_function():
  102. x = Symbol("x")
  103. for f in (Derivative, Derivative(x), Function, FunctionClass, Lambda,\
  104. WildFunction):
  105. check(f)
  106. @XFAIL
  107. def test_core_dynamicfunctions():
  108. # This fails because f is assumed to be a class at sympy.basic.function.f
  109. f = Function("f")
  110. check(f)
  111. def test_core_interval():
  112. for c in (Interval, Interval(0,2)):
  113. check(c)
  114. def test_core_multidimensional():
  115. for c in (vectorize, vectorize(0)):
  116. check(c)
  117. def test_Singletons():
  118. protocols = [0, 1, 2]
  119. if sys.version_info[0] > 2:
  120. protocols.extend([3])
  121. copiers = [copy.copy, copy.deepcopy]
  122. copiers += [lambda x: pickle.loads(pickle.dumps(x, proto))
  123. for proto in protocols]
  124. for obj in (Integer(-1), Integer(0), Integer(1), Rational(1, 2), pi, E, I,
  125. oo, -oo, zoo, nan, S.GoldenRatio, S.EulerGamma, S.Catalan,
  126. S.EmptySet, S.IdentityFunction):
  127. for func in copiers:
  128. assert func(obj) is obj
  129. #================== functions ===================
  130. from sympy.functions import (Piecewise, lowergamma, acosh,
  131. chebyshevu, chebyshevt, ln, chebyshevt_root, binomial, legendre,
  132. Heaviside, factorial, bernoulli, coth, tanh, assoc_legendre, sign,
  133. arg, asin, DiracDelta, re, rf, Abs, uppergamma, binomial, sinh, Ylm,
  134. cos, cot, acos, acot, gamma, bell, hermite, harmonic,
  135. LambertW, zeta, log, factorial, asinh, acoth, Zlm,
  136. cosh, dirichlet_eta, Eijk, loggamma, erf, ceiling, im, fibonacci,
  137. conjugate, tan, chebyshevu_root, floor, atanh, sqrt,
  138. RisingFactorial, sin, atan, ff, FallingFactorial, lucas, atan2,
  139. polygamma, exp)
  140. def test_functions():
  141. one_var = (acosh, ln, Heaviside, factorial, bernoulli, coth, tanh,
  142. sign, arg, asin, DiracDelta, re, Abs, sinh, cos, cot, acos, acot,
  143. gamma, bell, harmonic, LambertW, zeta, log, factorial, asinh,
  144. acoth, cosh, dirichlet_eta, loggamma, erf, ceiling, im, fibonacci,
  145. conjugate, tan, floor, atanh, sin, atan, lucas, exp)
  146. two_var = (rf, ff, lowergamma, chebyshevu, chebyshevt, binomial,
  147. atan2, polygamma, hermite, legendre, uppergamma)
  148. x, y, z = symbols("x,y,z")
  149. others = (chebyshevt_root, chebyshevu_root, Eijk(x, y, z),
  150. Piecewise( (0, x<-1), (x**2, x<=1), (x**3, True)),
  151. assoc_legendre)
  152. for cls in one_var:
  153. check(cls)
  154. c = cls(x)
  155. check(c)
  156. for cls in two_var:
  157. check(cls)
  158. c = cls(x, y)
  159. check(c)
  160. for cls in others:
  161. check(cls)
  162. #================== geometry ====================
  163. from sympy.geometry.entity import GeometryEntity
  164. from sympy.geometry.point import Point
  165. from sympy.geometry.ellipse import Circle, Ellipse
  166. from sympy.geometry.line import Line, LinearEntity, Ray, Segment
  167. from sympy.geometry.polygon import Polygon, RegularPolygon, Triangle
  168. def test_geometry():
  169. p1 = Point(1,2)
  170. p2 = Point(2,3)
  171. p3 = Point(0,0)
  172. p4 = Point(0,1)
  173. for c in (GeometryEntity, GeometryEntity(), Point, p1, Circle, Circle(p1,2),
  174. Ellipse, Ellipse(p1,3,4), Line, Line(p1,p2), LinearEntity,
  175. LinearEntity(p1,p2), Ray, Ray(p1,p2), Segment, Segment(p1,p2),
  176. Polygon, Polygon(p1,p2,p3,p4), RegularPolygon, RegularPolygon(p1,4,5),
  177. Triangle, Triangle(p1,p2,p3)):
  178. check(c, check_attr = False)
  179. #================== integrals ====================
  180. from sympy.integrals.integrals import Integral
  181. def test_integrals():
  182. x = Symbol("x")
  183. for c in (Integral, Integral(x)):
  184. check(c)
  185. #==================== logic =====================
  186. from sympy.core.logic import Logic
  187. def test_logic():
  188. for c in (Logic, Logic(1)):
  189. check(c)
  190. #================== matrices ====================
  191. from sympy.matrices.matrices import Matrix, SparseMatrix
  192. def test_matrices():
  193. for c in (Matrix, Matrix([1,2,3]), SparseMatrix, SparseMatrix([[1,2],[3,4]])):
  194. check(c)
  195. #================== ntheory =====================
  196. from sympy.ntheory.generate import Sieve
  197. def test_ntheory():
  198. for c in (Sieve, Sieve()):
  199. check(c)
  200. #================== physics =====================
  201. from sympy.physics.paulialgebra import Pauli
  202. from sympy.physics.units import Unit
  203. def test_physics():
  204. for c in (Unit, Unit("meter", "m"), Pauli, Pauli(1)):
  205. check(c)
  206. #================== plotting ====================
  207. # XXX: These tests are not complete, so XFAIL them
  208. @XFAIL
  209. def test_plotting():
  210. from sympy.plotting.color_scheme import ColorGradient, ColorScheme
  211. from sympy.plotting.managed_window import ManagedWindow
  212. from sympy.plotting.plot import Plot, ScreenShot
  213. from sympy.plotting.plot_axes import PlotAxes, PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate
  214. from sympy.plotting.plot_camera import PlotCamera
  215. from sympy.plotting.plot_controller import PlotController
  216. from sympy.plotting.plot_curve import PlotCurve
  217. from sympy.plotting.plot_interval import PlotInterval
  218. from sympy.plotting.plot_mode import PlotMode
  219. from sympy.plotting.plot_modes import Cartesian2D, Cartesian3D, Cylindrical,\
  220. ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical
  221. from sympy.plotting.plot_object import PlotObject
  222. from sympy.plotting.plot_surface import PlotSurface
  223. from sympy.plotting.plot_window import PlotWindow
  224. for c in (ColorGradient, ColorGradient(0.2,0.4), ColorScheme, ManagedWindow,
  225. ManagedWindow, Plot, ScreenShot, PlotAxes, PlotAxesBase,
  226. PlotAxesFrame, PlotAxesOrdinate, PlotCamera, PlotController,
  227. PlotCurve, PlotInterval, PlotMode, Cartesian2D, Cartesian3D,
  228. Cylindrical, ParametricCurve2D, ParametricCurve3D,
  229. ParametricSurface, Polar, Spherical, PlotObject, PlotSurface,
  230. PlotWindow):
  231. check(c)
  232. @XFAIL
  233. def test_plotting2():
  234. from sympy.plotting.color_scheme import ColorGradient, ColorScheme
  235. from sympy.plotting.managed_window import ManagedWindow
  236. from sympy.plotting.plot import Plot, ScreenShot
  237. from sympy.plotting.plot_axes import PlotAxes, PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate
  238. from sympy.plotting.plot_camera import PlotCamera
  239. from sympy.plotting.plot_controller import PlotController
  240. from sympy.plotting.plot_curve import PlotCurve
  241. from sympy.plotting.plot_interval import PlotInterval
  242. from sympy.plotting.plot_mode import PlotMode
  243. from sympy.plotting.plot_modes import Cartesian2D, Cartesian3D, Cylindrical,\
  244. ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical
  245. from sympy.plotting.plot_object import PlotObject
  246. from sympy.plotting.plot_surface import PlotSurface
  247. from sympy.plotting.plot_window import PlotWindow
  248. check(ColorScheme("rainbow"))
  249. check(Plot(1,visible=False))
  250. check(PlotAxes())
  251. #================== polys =======================
  252. from sympy.polys.polytools import Poly
  253. from sympy.polys.polyclasses import DMP, DMF, ANP
  254. from sympy.polys.rootoftools import RootOf, RootSum
  255. from sympy.polys.domains import (
  256. PythonIntegerRing,
  257. SymPyIntegerRing,
  258. SymPyRationalField,
  259. PolynomialRing,
  260. FractionField,
  261. ExpressionDomain,
  262. )
  263. def test_polys():
  264. x = Symbol("X")
  265. ZZ = PythonIntegerRing()
  266. QQ = SymPyRationalField()
  267. for c in (Poly, Poly(x, x)):
  268. check(c)
  269. for c in (DMP, DMP([[ZZ(1)],[ZZ(2)],[ZZ(3)]], ZZ)):
  270. check(c)
  271. for c in (DMF, DMF(([ZZ(1),ZZ(2)], [ZZ(1),ZZ(3)]), ZZ)):
  272. check(c)
  273. for c in (ANP, ANP([QQ(1),QQ(2)], [QQ(1),QQ(2),QQ(3)], QQ)):
  274. check(c)
  275. for c in (PythonIntegerRing, PythonIntegerRing()):
  276. check(c)
  277. for c in (SymPyIntegerRing, SymPyIntegerRing()):
  278. check(c)
  279. for c in (SymPyRationalField, SymPyRationalField()):
  280. check(c)
  281. for c in (PolynomialRing, PolynomialRing(ZZ, 'x', 'y')):
  282. check(c)
  283. for c in (FractionField, FractionField(ZZ, 'x', 'y')):
  284. check(c)
  285. for c in (ExpressionDomain, ExpressionDomain()):
  286. check(c)
  287. from sympy.polys.domains import PythonRationalField
  288. for c in (PythonRationalField, PythonRationalField()):
  289. check(c)
  290. from sympy.polys.domains import HAS_GMPY
  291. if HAS_GMPY:
  292. from sympy.polys.domains import GMPYIntegerRing, GMPYRationalField
  293. for c in (GMPYIntegerRing, GMPYIntegerRing()):
  294. check(c)
  295. for c in (GMPYRationalField, GMPYRationalField()):
  296. check(c)
  297. f = x**3 + x + 3
  298. g = exp
  299. for c in (RootOf, RootOf(f, 0), RootSum, RootSum(f, g)):
  300. check(c)
  301. #================== printing ====================
  302. from sympy.printing.latex import LatexPrinter
  303. from sympy.printing.mathml import MathMLPrinter
  304. from sympy.printing.pretty.pretty import PrettyPrinter
  305. from sympy.printing.pretty.stringpict import prettyForm, stringPict
  306. from sympy.printing.printer import Printer
  307. from sympy.printing.python import PythonPrinter
  308. def test_printing():
  309. for c in (LatexPrinter, LatexPrinter(), MathMLPrinter,
  310. PrettyPrinter, prettyForm, stringPict, stringPict("a"),
  311. Printer, Printer(), PythonPrinter, PythonPrinter()):
  312. check(c)
  313. @XFAIL
  314. def test_printing1():
  315. check(MathMLPrinter())
  316. @XFAIL
  317. def test_printing2():
  318. check(PrettyPrinter())
  319. #================== series ======================
  320. from sympy.series.limits import Limit
  321. from sympy.series.order import Order
  322. def test_series():
  323. e = Symbol("e")
  324. x = Symbol("x")
  325. for c in (Limit, Limit(e, x, 1), Order, Order(e)):
  326. check(c)
  327. #================== statistics ==================
  328. from sympy.statistics.distributions import ContinuousProbability, Normal, Sample, Uniform
  329. def test_statistics():
  330. x = Symbol("x")
  331. y = Symbol("y")
  332. for c in (ContinuousProbability, ContinuousProbability(), Normal,
  333. Normal(x,y), Sample, Sample([1,3,4]), Uniform, Uniform(x,y)):
  334. check(c)
  335. #================== concrete ==================
  336. from sympy.concrete.products import Product
  337. from sympy.concrete.summations import Sum
  338. def test_concrete():
  339. x = Symbol("x")
  340. for c in (Product, Product(x, (x, 2, 4)), Sum, Sum(x, (x, 2, 4))):
  341. check(c)