PageRenderTime 47ms CodeModel.GetById 17ms RepoModel.GetById 0ms app.codeStats 0ms

/pypy/module/math/interp_math.py

https://bitbucket.org/SeanTater/pypy-bugfix-st
Python | 603 lines | 427 code | 58 blank | 118 comment | 109 complexity | b344629effd16ac3d6fb3af1c287b88a MD5 | raw file
  1. import math
  2. import sys
  3. from pypy.rlib import rfloat, unroll
  4. from pypy.interpreter.error import OperationError
  5. class State:
  6. def __init__(self, space):
  7. self.w_e = space.wrap(math.e)
  8. self.w_pi = space.wrap(math.pi)
  9. def get(space):
  10. return space.fromcache(State)
  11. def _get_double(space, w_x):
  12. if space.is_w(space.type(w_x), space.w_float):
  13. return space.float_w(w_x)
  14. else:
  15. return space.float_w(space.float(w_x))
  16. def math1(space, f, w_x):
  17. x = _get_double(space, w_x)
  18. try:
  19. y = f(x)
  20. except OverflowError:
  21. raise OperationError(space.w_OverflowError,
  22. space.wrap("math range error"))
  23. except ValueError:
  24. raise OperationError(space.w_ValueError,
  25. space.wrap("math domain error"))
  26. return space.wrap(y)
  27. math1._annspecialcase_ = 'specialize:arg(1)'
  28. def math1_w(space, f, w_x):
  29. x = _get_double(space, w_x)
  30. try:
  31. r = f(x)
  32. except OverflowError:
  33. raise OperationError(space.w_OverflowError,
  34. space.wrap("math range error"))
  35. except ValueError:
  36. raise OperationError(space.w_ValueError,
  37. space.wrap("math domain error"))
  38. return r
  39. math1_w._annspecialcase_ = 'specialize:arg(1)'
  40. def math2(space, f, w_x, w_snd):
  41. x = _get_double(space, w_x)
  42. snd = _get_double(space, w_snd)
  43. try:
  44. r = f(x, snd)
  45. except OverflowError:
  46. raise OperationError(space.w_OverflowError,
  47. space.wrap("math range error"))
  48. except ValueError:
  49. raise OperationError(space.w_ValueError,
  50. space.wrap("math domain error"))
  51. return space.wrap(r)
  52. math2._annspecialcase_ = 'specialize:arg(1)'
  53. def trunc(space, w_x):
  54. """Truncate x."""
  55. return space.trunc(w_x)
  56. def copysign(space, w_x, w_y):
  57. """Return x with the sign of y."""
  58. # No exceptions possible.
  59. x = _get_double(space, w_x)
  60. y = _get_double(space, w_y)
  61. return space.wrap(rfloat.copysign(x, y))
  62. def isinf(space, w_x):
  63. """Return True if x is infinity."""
  64. return space.wrap(rfloat.isinf(_get_double(space, w_x)))
  65. def isnan(space, w_x):
  66. """Return True if x is not a number."""
  67. return space.wrap(rfloat.isnan(_get_double(space, w_x)))
  68. def pow(space, w_x, w_y):
  69. """pow(x,y)
  70. Return x**y (x to the power of y).
  71. """
  72. return math2(space, math.pow, w_x, w_y)
  73. def cosh(space, w_x):
  74. """cosh(x)
  75. Return the hyperbolic cosine of x.
  76. """
  77. return math1(space, math.cosh, w_x)
  78. def ldexp(space, w_x, w_i):
  79. """ldexp(x, i) -> x * (2**i)
  80. """
  81. x = _get_double(space, w_x)
  82. if (space.isinstance_w(w_i, space.w_int) or
  83. space.isinstance_w(w_i, space.w_long)):
  84. try:
  85. exp = space.int_w(w_i)
  86. except OperationError, e:
  87. if not e.match(space, space.w_OverflowError):
  88. raise
  89. if space.is_true(space.lt(w_i, space.wrap(0))):
  90. exp = -sys.maxint
  91. else:
  92. exp = sys.maxint
  93. else:
  94. raise OperationError(space.w_TypeError,
  95. space.wrap("integer required for second argument"))
  96. try:
  97. r = math.ldexp(x, exp)
  98. except OverflowError:
  99. raise OperationError(space.w_OverflowError,
  100. space.wrap("math range error"))
  101. except ValueError:
  102. raise OperationError(space.w_ValueError,
  103. space.wrap("math domain error"))
  104. return space.wrap(r)
  105. def hypot(space, w_x, w_y):
  106. """hypot(x,y)
  107. Return the Euclidean distance, sqrt(x*x + y*y).
  108. """
  109. return math2(space, math.hypot, w_x, w_y)
  110. def tan(space, w_x):
  111. """tan(x)
  112. Return the tangent of x (measured in radians).
  113. """
  114. return math1(space, math.tan, w_x)
  115. def asin(space, w_x):
  116. """asin(x)
  117. Return the arc sine (measured in radians) of x.
  118. """
  119. return math1(space, math.asin, w_x)
  120. def fabs(space, w_x):
  121. """fabs(x)
  122. Return the absolute value of the float x.
  123. """
  124. return math1(space, math.fabs, w_x)
  125. def floor(space, w_x):
  126. """floor(x)
  127. Return the floor of x as a float.
  128. This is the largest integral value <= x.
  129. """
  130. x = _get_double(space, w_x)
  131. return space.wrap(math.floor(x))
  132. def sqrt(space, w_x):
  133. """sqrt(x)
  134. Return the square root of x.
  135. """
  136. return math1(space, math.sqrt, w_x)
  137. def frexp(space, w_x):
  138. """frexp(x)
  139. Return the mantissa and exponent of x, as pair (m, e).
  140. m is a float and e is an int, such that x = m * 2.**e.
  141. If x is 0, m and e are both 0. Else 0.5 <= abs(m) < 1.0.
  142. """
  143. mant, expo = math1_w(space, math.frexp, w_x)
  144. return space.newtuple([space.wrap(mant), space.wrap(expo)])
  145. degToRad = math.pi / 180.0
  146. def degrees(space, w_x):
  147. """degrees(x) -> converts angle x from radians to degrees
  148. """
  149. return space.wrap(_get_double(space, w_x) / degToRad)
  150. def _log_any(space, w_x, base):
  151. # base is supposed to be positive or 0.0, which means we use e
  152. try:
  153. if space.is_true(space.isinstance(w_x, space.w_long)):
  154. # special case to support log(extremely-large-long)
  155. num = space.bigint_w(w_x)
  156. result = num.log(base)
  157. else:
  158. x = _get_double(space, w_x)
  159. if base == 10.0:
  160. result = math.log10(x)
  161. else:
  162. result = math.log(x)
  163. if base != 0.0:
  164. den = math.log(base)
  165. result /= den
  166. except OverflowError:
  167. raise OperationError(space.w_OverflowError,
  168. space.wrap('math range error'))
  169. except ValueError:
  170. raise OperationError(space.w_ValueError,
  171. space.wrap('math domain error'))
  172. return space.wrap(result)
  173. def log(space, w_x, w_base=None):
  174. """log(x[, base]) -> the logarithm of x to the given base.
  175. If the base not specified, returns the natural logarithm (base e) of x.
  176. """
  177. if w_base is None:
  178. base = 0.0
  179. else:
  180. base = _get_double(space, w_base)
  181. if base <= 0.0:
  182. # just for raising the proper errors
  183. return math1(space, math.log, w_base)
  184. return _log_any(space, w_x, base)
  185. def log10(space, w_x):
  186. """log10(x) -> the base 10 logarithm of x.
  187. """
  188. return _log_any(space, w_x, 10.0)
  189. def fmod(space, w_x, w_y):
  190. """fmod(x,y)
  191. Return fmod(x, y), according to platform C. x % y may differ.
  192. """
  193. return math2(space, math.fmod, w_x, w_y)
  194. def atan(space, w_x):
  195. """atan(x)
  196. Return the arc tangent (measured in radians) of x.
  197. """
  198. return math1(space, math.atan, w_x)
  199. def ceil(space, w_x):
  200. """ceil(x)
  201. Return the ceiling of x as a float.
  202. This is the smallest integral value >= x.
  203. """
  204. return math1(space, math.ceil, w_x)
  205. def sinh(space, w_x):
  206. """sinh(x)
  207. Return the hyperbolic sine of x.
  208. """
  209. return math1(space, math.sinh, w_x)
  210. def cos(space, w_x):
  211. """cos(x)
  212. Return the cosine of x (measured in radians).
  213. """
  214. return math1(space, math.cos, w_x)
  215. def tanh(space, w_x):
  216. """tanh(x)
  217. Return the hyperbolic tangent of x.
  218. """
  219. return math1(space, math.tanh, w_x)
  220. def radians(space, w_x):
  221. """radians(x) -> converts angle x from degrees to radians
  222. """
  223. return space.wrap(_get_double(space, w_x) * degToRad)
  224. def sin(space, w_x):
  225. """sin(x)
  226. Return the sine of x (measured in radians).
  227. """
  228. return math1(space, math.sin, w_x)
  229. def atan2(space, w_y, w_x):
  230. """atan2(y, x)
  231. Return the arc tangent (measured in radians) of y/x.
  232. Unlike atan(y/x), the signs of both x and y are considered.
  233. """
  234. return math2(space, math.atan2, w_y, w_x)
  235. def modf(space, w_x):
  236. """modf(x)
  237. Return the fractional and integer parts of x. Both results carry the sign
  238. of x. The integer part is returned as a real.
  239. """
  240. frac, intpart = math1_w(space, math.modf, w_x)
  241. return space.newtuple([space.wrap(frac), space.wrap(intpart)])
  242. def exp(space, w_x):
  243. """exp(x)
  244. Return e raised to the power of x.
  245. """
  246. return math1(space, math.exp, w_x)
  247. def acos(space, w_x):
  248. """acos(x)
  249. Return the arc cosine (measured in radians) of x.
  250. """
  251. return math1(space, math.acos, w_x)
  252. def fsum(space, w_iterable):
  253. """Sum an iterable of floats, trying to keep precision."""
  254. w_iter = space.iter(w_iterable)
  255. inf_sum = special_sum = 0.0
  256. partials = []
  257. while True:
  258. try:
  259. w_value = space.next(w_iter)
  260. except OperationError, e:
  261. if not e.match(space, space.w_StopIteration):
  262. raise
  263. break
  264. v = _get_double(space, w_value)
  265. original = v
  266. added = 0
  267. for y in partials:
  268. if abs(v) < abs(y):
  269. v, y = y, v
  270. hi = v + y
  271. yr = hi - v
  272. lo = y - yr
  273. if lo != 0.0:
  274. partials[added] = lo
  275. added += 1
  276. v = hi
  277. del partials[added:]
  278. if v != 0.0:
  279. if rfloat.isinf(v) or rfloat.isnan(v):
  280. if (not rfloat.isinf(original) and
  281. not rfloat.isnan(original)):
  282. raise OperationError(space.w_OverflowError,
  283. space.wrap("intermediate overflow"))
  284. if rfloat.isinf(original):
  285. inf_sum += original
  286. special_sum += original
  287. del partials[:]
  288. else:
  289. partials.append(v)
  290. if special_sum != 0.0:
  291. if rfloat.isnan(special_sum):
  292. raise OperationError(space.w_ValueError, space.wrap("-inf + inf"))
  293. return space.wrap(special_sum)
  294. hi = 0.0
  295. if partials:
  296. hi = partials[-1]
  297. j = 0
  298. lo = 0
  299. for j in range(len(partials) - 2, -1, -1):
  300. v = hi
  301. y = partials[j]
  302. assert abs(y) < abs(v)
  303. hi = v + y
  304. yr = hi - v
  305. lo = y - yr
  306. if lo != 0.0:
  307. break
  308. if j > 0 and (lo < 0.0 and partials[j - 1] < 0.0 or
  309. lo > 0.0 and partials[j - 1] > 0.0):
  310. y = lo * 2.0
  311. v = hi + y
  312. yr = v - hi
  313. if y == yr:
  314. hi = v
  315. return space.wrap(hi)
  316. def log1p(space, w_x):
  317. """Find log(x + 1)."""
  318. return math1(space, rfloat.log1p, w_x)
  319. def acosh(space, w_x):
  320. """Inverse hyperbolic cosine"""
  321. return math1(space, rfloat.acosh, w_x)
  322. def asinh(space, w_x):
  323. """Inverse hyperbolic sine"""
  324. return math1(space, rfloat.asinh, w_x)
  325. def atanh(space, w_x):
  326. """Inverse hyperbolic tangent"""
  327. return math1(space, rfloat.atanh, w_x)
  328. def expm1(space, w_x):
  329. """exp(x) - 1"""
  330. return math1(space, rfloat.expm1, w_x)
  331. def erf(space, w_x):
  332. """The error function"""
  333. return math1(space, _erf, w_x)
  334. def erfc(space, w_x):
  335. """The complementary error function"""
  336. return math1(space, _erfc, w_x)
  337. def gamma(space, w_x):
  338. """Compute the gamma function for x."""
  339. return math1(space, _gamma, w_x)
  340. def lgamma(space, w_x):
  341. """Compute the natural logarithm of the gamma function for x."""
  342. return math1(space, _lgamma, w_x)
  343. # Implementation of the error function, the complimentary error function, the
  344. # gamma function, and the natural log of the gamma function. These exist in
  345. # libm, but I hear those implementations are horrible.
  346. ERF_SERIES_CUTOFF = 1.5
  347. ERF_SERIES_TERMS = 25
  348. ERFC_CONTFRAC_CUTOFF = 30.
  349. ERFC_CONTFRAC_TERMS = 50
  350. _sqrtpi = 1.772453850905516027298167483341145182798
  351. def _erf_series(x):
  352. x2 = x * x
  353. acc = 0.
  354. fk = ERF_SERIES_TERMS + .5
  355. for i in range(ERF_SERIES_TERMS):
  356. acc = 2.0 + x2 * acc / fk
  357. fk -= 1.
  358. return acc * x * math.exp(-x2) / _sqrtpi
  359. def _erfc_contfrac(x):
  360. if x >= ERFC_CONTFRAC_CUTOFF:
  361. return 0.
  362. x2 = x * x
  363. a = 0.
  364. da = .5
  365. p = 1.
  366. p_last = 0.
  367. q = da + x2
  368. q_last = 1.
  369. for i in range(ERFC_CONTFRAC_TERMS):
  370. a += da
  371. da += 2.
  372. b = da + x2
  373. p_last, p = p, b * p - a * p_last
  374. q_last, q = q, b * q - a * q_last
  375. return p / q * x * math.exp(-x2) / _sqrtpi
  376. def _erf(x):
  377. if rfloat.isnan(x):
  378. return x
  379. absx = abs(x)
  380. if absx < ERF_SERIES_CUTOFF:
  381. return _erf_series(x)
  382. else:
  383. cf = _erfc_contfrac(absx)
  384. return 1. - cf if x > 0. else cf - 1.
  385. def _erfc(x):
  386. if rfloat.isnan(x):
  387. return x
  388. absx = abs(x)
  389. if absx < ERF_SERIES_CUTOFF:
  390. return 1. - _erf_series(x)
  391. else:
  392. cf = _erfc_contfrac(absx)
  393. return cf if x > 0. else 2. - cf
  394. def _sinpi(x):
  395. y = math.fmod(abs(x), 2.)
  396. n = int(rfloat.round_away(2. * y))
  397. if n == 0:
  398. r = math.sin(math.pi * y)
  399. elif n == 1:
  400. r = math.cos(math.pi * (y - .5))
  401. elif n == 2:
  402. r = math.sin(math.pi * (1. - y))
  403. elif n == 3:
  404. r = -math.cos(math.pi * (y - 1.5))
  405. elif n == 4:
  406. r = math.sin(math.pi * (y - 2.))
  407. else:
  408. raise AssertionError("should not reach")
  409. return rfloat.copysign(1., x) * r
  410. _lanczos_g = 6.024680040776729583740234375
  411. _lanczos_g_minus_half = 5.524680040776729583740234375
  412. _lanczos_num_coeffs = [
  413. 23531376880.410759688572007674451636754734846804940,
  414. 42919803642.649098768957899047001988850926355848959,
  415. 35711959237.355668049440185451547166705960488635843,
  416. 17921034426.037209699919755754458931112671403265390,
  417. 6039542586.3520280050642916443072979210699388420708,
  418. 1439720407.3117216736632230727949123939715485786772,
  419. 248874557.86205415651146038641322942321632125127801,
  420. 31426415.585400194380614231628318205362874684987640,
  421. 2876370.6289353724412254090516208496135991145378768,
  422. 186056.26539522349504029498971604569928220784236328,
  423. 8071.6720023658162106380029022722506138218516325024,
  424. 210.82427775157934587250973392071336271166969580291,
  425. 2.5066282746310002701649081771338373386264310793408
  426. ]
  427. _lanczos_den_coeffs = [
  428. 0.0, 39916800.0, 120543840.0, 150917976.0, 105258076.0, 45995730.0,
  429. 13339535.0, 2637558.0, 357423.0, 32670.0, 1925.0, 66.0, 1.0]
  430. LANCZOS_N = len(_lanczos_den_coeffs)
  431. _lanczos_n_iter = unroll.unrolling_iterable(range(LANCZOS_N))
  432. _lanczos_n_iter_back = unroll.unrolling_iterable(range(LANCZOS_N - 1, -1, -1))
  433. _gamma_integrals = [
  434. 1.0, 1.0, 2.0, 6.0, 24.0, 120.0, 720.0, 5040.0, 40320.0, 362880.0,
  435. 3628800.0, 39916800.0, 479001600.0, 6227020800.0, 87178291200.0,
  436. 1307674368000.0, 20922789888000.0, 355687428096000.0,
  437. 6402373705728000.0, 121645100408832000.0, 2432902008176640000.0,
  438. 51090942171709440000.0, 1124000727777607680000.0]
  439. def _lanczos_sum(x):
  440. num = 0.
  441. den = 0.
  442. assert x > 0.
  443. if x < 5.:
  444. for i in _lanczos_n_iter_back:
  445. num = num * x + _lanczos_num_coeffs[i]
  446. den = den * x + _lanczos_den_coeffs[i]
  447. else:
  448. for i in _lanczos_n_iter:
  449. num = num / x + _lanczos_num_coeffs[i]
  450. den = den / x + _lanczos_den_coeffs[i]
  451. return num / den
  452. def _gamma(x):
  453. if rfloat.isnan(x) or (rfloat.isinf(x) and x > 0.):
  454. return x
  455. if rfloat.isinf(x):
  456. raise ValueError("math domain error")
  457. if x == 0.:
  458. raise ValueError("math domain error")
  459. if x == math.floor(x):
  460. if x < 0.:
  461. raise ValueError("math domain error")
  462. if x < len(_gamma_integrals):
  463. return _gamma_integrals[int(x) - 1]
  464. absx = abs(x)
  465. if absx < 1e-20:
  466. r = 1. / x
  467. if rfloat.isinf(r):
  468. raise OverflowError("math range error")
  469. return r
  470. if absx > 200.:
  471. if x < 0.:
  472. return 0. / -_sinpi(x)
  473. else:
  474. raise OverflowError("math range error")
  475. y = absx + _lanczos_g_minus_half
  476. if absx > _lanczos_g_minus_half:
  477. q = y - absx
  478. z = q - _lanczos_g_minus_half
  479. else:
  480. q = y - _lanczos_g_minus_half
  481. z = q - absx
  482. z = z * _lanczos_g / y
  483. if x < 0.:
  484. r = -math.pi / _sinpi(absx) / absx * math.exp(y) / _lanczos_sum(absx)
  485. r -= z * r
  486. if absx < 140.:
  487. r /= math.pow(y, absx - .5)
  488. else:
  489. sqrtpow = math.pow(y, absx / 2. - .25)
  490. r /= sqrtpow
  491. r /= sqrtpow
  492. else:
  493. r = _lanczos_sum(absx) / math.exp(y)
  494. r += z * r
  495. if absx < 140.:
  496. r *= math.pow(y, absx - .5)
  497. else:
  498. sqrtpow = math.pow(y, absx / 2. - .25)
  499. r *= sqrtpow
  500. r *= sqrtpow
  501. if rfloat.isinf(r):
  502. raise OverflowError("math range error")
  503. return r
  504. def _lgamma(x):
  505. if rfloat.isnan(x):
  506. return x
  507. if rfloat.isinf(x):
  508. return rfloat.INFINITY
  509. if x == math.floor(x) and x <= 2.:
  510. if x <= 0.:
  511. raise ValueError("math range error")
  512. return 0.
  513. absx = abs(x)
  514. if absx < 1e-20:
  515. return -math.log(absx)
  516. if x > 0.:
  517. r = (math.log(_lanczos_sum(x)) - _lanczos_g + (x - .5) *
  518. (math.log(x + _lanczos_g - .5) - 1))
  519. else:
  520. r = (math.log(math.pi) - math.log(abs(_sinpi(absx))) - math.log(absx) -
  521. (math.log(_lanczos_sum(absx)) - _lanczos_g +
  522. (absx - .5) * (math.log(absx + _lanczos_g - .5) - 1)))
  523. if rfloat.isinf(r):
  524. raise OverflowError("math domain error")
  525. return r