PageRenderTime 26ms CodeModel.GetById 12ms RepoModel.GetById 0ms app.codeStats 1ms

/scipy/stats/_sobol.pyx

https://github.com/matthew-brett/scipy
Cython | 363 lines | 318 code | 34 blank | 11 comment | 40 complexity | cccdbc997b5f037a5046178be76ffbd4 MD5 | raw file
  1. from __future__ import division, absolute_import
  2. cimport cython
  3. cimport numpy as cnp
  4. import os
  5. import numpy as np
  6. cnp.import_array()
  7. # Parameters are linked to the direction numbers list.
  8. # See `initialize_direction_numbers` for more details.
  9. # Declared using DEF to be known at compilation time for ``poly`` et ``vinit``
  10. DEF MAXDIM = 21201 # max number of dimensions
  11. DEF MAXDEG = 18 # max polynomial degree
  12. DEF MAXBIT = 30 # max number of bits
  13. # Needed to be accessed with python
  14. cdef extern from *:
  15. """
  16. int MAXDIM_DEFINE = 21201;
  17. int MAXDEG_DEFINE = 18;
  18. int MAXBIT_DEFINE = 30;
  19. """
  20. int MAXDIM_DEFINE # max number of dimensions
  21. int MAXDEG_DEFINE # max polynomial degree
  22. int MAXBIT_DEFINE # max number of bits
  23. _MAXDIM = MAXDIM_DEFINE
  24. _MAXDEG = MAXDEG_DEFINE
  25. _MAXBIT = MAXBIT_DEFINE
  26. cdef int poly[MAXDIM]
  27. cdef int vinit[MAXDIM][MAXDEG]
  28. cdef int LARGEST_NUMBER = 2 ** MAXBIT # largest possible integer
  29. cdef float RECIPD = 1.0 / LARGEST_NUMBER # normalization constant
  30. cdef bint is_initialized = False
  31. def initialize_direction_numbers():
  32. """Load direction numbers.
  33. Direction numbers obtained using the search criterion D(6)
  34. up to the dimension 21201. This is the recommended choice by the authors.
  35. Original data can be found at https://web.maths.unsw.edu.au/~fkuo/sobol/.
  36. For additional details on the quantities involved, see [1].
  37. [1] S. Joe and F. Y. Kuo. Remark on algorithm 659: Implementing sobol's
  38. quasirandom sequence generator. ACM Trans. Math. Softw., 29(1):49-57,
  39. Mar. 2003.
  40. The C-code generated from putting the numbers in as literals is obscenely
  41. large/inefficient. The data file was thus packaged and save as an .npz data
  42. file for fast loading using the following code (this assumes that the file
  43. https://web.maths.unsw.edu.au/~fkuo/sobol/new-joe-kuo-6.21201 is present in
  44. the working directory):
  45. import pandas as pd
  46. import numpy as np
  47. # read in file content
  48. with open("./new-joe-kuo-6.21201", "r") as f:
  49. lines = f.readlines()
  50. rows = []
  51. # parse data from file line by line
  52. for l in lines[1:]:
  53. nums = [int(n) for n in l.replace(" \n", "").split()]
  54. d, s, a = nums[:3]
  55. vs = {f"v{i}": int(v) for i,v in enumerate(nums[3:])}
  56. rows.append({"d": d, "s": s, "a": a, **vs})
  57. # read in as dataframe, explicitly use zero values
  58. df = pd.DataFrame(rows).fillna(0).astype(int)
  59. # perform conversion
  60. df["poly"] = 2 * df["a"] + 2 ** df["s"] + 1
  61. # ensure columns are properly ordered
  62. vs = df[[f"v{i}" for i in range(18)]].values
  63. # add the degenerate d=1 column (not included in the data file)
  64. vs = np.vstack([vs[0][np.newaxis, :], vs])
  65. poly = np.concatenate([[1], df["poly"].values])
  66. # save as compressed .npz file to minimize size of distribution
  67. np.savez_compressed("./_sobol_direction_numbers", vinit=vs, poly=poly)
  68. """
  69. cdef int[:] dns_poly
  70. cdef int[:, :] dns_vinit
  71. global is_initialized
  72. if not is_initialized:
  73. dns = np.load(os.path.join(os.path.dirname(__file__), "_sobol_direction_numbers.npz"))
  74. dns_poly = dns["poly"].astype(np.intc)
  75. dns_vinit = dns["vinit"].astype(np.intc)
  76. for i in range(MAXDIM):
  77. poly[i] = dns_poly[i]
  78. for i in range(MAXDIM):
  79. for j in range(MAXDEG):
  80. vinit[i][j] = dns_vinit[i, j]
  81. is_initialized = True
  82. @cython.boundscheck(False)
  83. @cython.wraparound(False)
  84. cdef int bit_length(const int n):
  85. cdef int bits = 0
  86. cdef int nloc = n
  87. while nloc != 0:
  88. nloc >>= 1
  89. bits += 1
  90. return bits
  91. @cython.boundscheck(False)
  92. @cython.wraparound(False)
  93. cdef int low_0_bit(const int x) nogil:
  94. """Get the position of the right-most 0 bit for an integer.
  95. Examples:
  96. >>> low_0_bit(0)
  97. 1
  98. >>> low_0_bit(1)
  99. 2
  100. >>> low_0_bit(2)
  101. 1
  102. >>> low_0_bit(5)
  103. 2
  104. >>> low_0_bit(7)
  105. 4
  106. Parameters
  107. ----------
  108. x: int
  109. An integer.
  110. Returns
  111. -------
  112. position: int
  113. Position of the right-most 0 bit.
  114. """
  115. cdef int i = 0
  116. while x & (1 << i) != 0:
  117. i += 1
  118. return i + 1
  119. @cython.boundscheck(False)
  120. @cython.wraparound(False)
  121. cdef int ibits(const int x, const int pos, const int length) nogil:
  122. """Extract a sequence of bits from the bit representation of an integer.
  123. Extract the sequence from position `pos` (inclusive) to ``pos + length``
  124. (not inclusive), leftwise.
  125. Examples:
  126. >>> ibits(1, 0, 1)
  127. 1
  128. >>> ibits(1, 1, 1)
  129. 0
  130. >>> ibits(2, 0, 1)
  131. 0
  132. >>> ibits(2, 0, 2)
  133. 2
  134. >>> ibits(25, 1, 5)
  135. 12
  136. Parameters
  137. ----------
  138. x: int
  139. Integer to convert to bit representation.
  140. pos: int
  141. Starting position of sequence in bit representation of integer.
  142. length: int
  143. Length of sequence (number of bits).
  144. Returns
  145. -------
  146. ibits: int
  147. Integer value corresponding to bit sequence.
  148. """
  149. return (x >> pos) & ((1 << length) - 1)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void initialize_v(cnp.int_t[:, :] v, const int dim):
    """Fill ``v`` in place with the scaled direction numbers for ``dim`` dims.

    Row 0 is all ones. For row ``d`` (``1 <= d < dim``), the first ``m``
    entries are copied from ``vinit[d]``, where ``m = bit_length(poly[d]) - 1``
    is the polynomial degree. The remaining entries up to ``MAXBIT`` follow the
    Bratley & Fox (1988) recurrence. Finally, column ``c`` of every filled row
    is multiplied by ``2**(MAXBIT - 1 - c)``.

    Parameters
    ----------
    v : 2-D memoryview of cnp.int_t
        Output, modified in place. Bounds checking is disabled, so it must
        have at least ``dim`` rows and ``MAXBIT`` columns.
    dim : int
        Number of rows (dimensions) to fill; at most MAXDIM, since ``poly``
        and ``vinit`` are indexed by row. Nothing is done when 0.

    Notes
    -----
    Reads the module-level ``poly``/``vinit`` arrays. Presumably the caller
    runs `initialize_direction_numbers` first; this is not checked here.
    """
    cdef int d, i, j, k, m, p, newv, pow2
    if dim == 0:
        return
    # first row of v is all 1s
    for i in range(MAXBIT):
        v[0, i] = 1
    # Remaining rows of v (row 2 through dim, indexed by [1:dim])
    for d in range(1, dim):
        p = poly[d]
        # degree of the polynomial = index of its highest set bit
        m = bit_length(p) - 1
        # First m elements of row d comes from vinit
        for j in range(m):
            v[d, j] = vinit[d][j]
        # Fill in remaining elements of v as in Section 2 (top of pg. 90) of:
        #
        # P. Bratley and B. L. Fox. Algorithm 659: Implementing sobol's
        # quasirandom sequence generator. ACM Trans.
        # Math. Softw., 14(1):88-100, Mar. 1988.
        #
        # v[d, j] = v[d, j-m] XOR the sum over k of 2**(k+1) * v[d, j-k-1],
        # where term k is included iff bit (m-1-k) of p is set. The lowest bit
        # of p (k = m-1) contributes the 2**m * v[d, j-m] term.
        for j in range(m, MAXBIT):
            newv = v[d, j - m]
            pow2 = 1
            for k in range(m):
                pow2 = pow2 << 1
                if (p >> (m - 1 - k)) & 1:
                    newv = newv ^ (pow2 * v[d, j - k - 1])
            v[d, j] = newv
    # Multiply each column of v by power of 2:
    # v * [2^(maxbit-1), 2^(maxbit-2),..., 2, 1]
    pow2 = 1
    for d in range(MAXBIT):
        # here d indexes columns from the right: column MAXBIT-1-d gets 2**d
        for i in range(dim):
            v[i, MAXBIT - 1 - d] *= pow2
        pow2 = pow2 << 1
  187. @cython.boundscheck(False)
  188. @cython.wraparound(False)
  189. cpdef void _draw(const int n,
  190. const int num_gen,
  191. const int dim,
  192. cnp.int_t[:, :] sv,
  193. cnp.int_t[:] quasi,
  194. cnp.float_t[:, :] result) nogil:
  195. cdef int i, j, l, qtmp
  196. cdef int num_gen_loc = num_gen
  197. for i in range(n):
  198. l = low_0_bit(num_gen_loc)
  199. for j in range(dim):
  200. qtmp = quasi[j] ^ sv[j, l - 1]
  201. quasi[j] = qtmp
  202. result[i, j] = qtmp * RECIPD
  203. num_gen_loc += 1
  204. @cython.boundscheck(False)
  205. @cython.wraparound(False)
  206. cpdef void _fast_forward(const int n,
  207. const int num_gen,
  208. const int dim,
  209. cnp.int_t[:, :] sv,
  210. cnp.int_t[:] quasi) nogil:
  211. cdef int i, j, l
  212. cdef int num_gen_loc = num_gen
  213. for i in range(n):
  214. l = low_0_bit(num_gen_loc)
  215. for j in range(dim):
  216. quasi[j] = quasi[j] ^ sv[j, l - 1]
  217. num_gen_loc += 1
  218. @cython.boundscheck(False)
  219. @cython.wraparound(False)
  220. cdef int cdot_pow2(cnp.int_t[:] a) nogil:
  221. cdef int i
  222. cdef int size = a.shape[0]
  223. cdef int z = 0
  224. cdef int pow2 = 1
  225. for i in range(size):
  226. z += a[size - 1 - i] * pow2
  227. pow2 *= 2
  228. return z
  229. @cython.boundscheck(False)
  230. @cython.wraparound(False)
  231. cpdef void _cscramble(const int dim,
  232. cnp.int_t[:, :, :] ltm,
  233. cnp.int_t[:, :] sv) nogil:
  234. cdef int d, i, j, k, l, lsm, lsmdp, p, t1, t2, vdj
  235. # Set diagonals of maxbit x maxbit arrays to 1
  236. for d in range(dim):
  237. for i in range(MAXBIT):
  238. ltm[d, i, i] = 1
  239. for d in range(dim):
  240. for j in range(MAXBIT):
  241. vdj = sv[d, j]
  242. l = 1
  243. t2 = 0
  244. for p in range(MAXBIT - 1, -1, -1):
  245. lsmdp = cdot_pow2(ltm[d, p, :])
  246. t1 = 0
  247. for k in range(MAXBIT):
  248. t1 += ibits(lsmdp, k, 1) * ibits(vdj, k, 1)
  249. t1 = t1 % 2
  250. t2 = t2 + t1 * l
  251. l = 2 * l
  252. sv[d, j] = t2
  253. @cython.boundscheck(False)
  254. @cython.wraparound(False)
  255. cpdef void _fill_p_cumulative(cnp.float_t[:] p,
  256. cnp.float_t[:] p_cumulative) nogil:
  257. cdef int i
  258. cdef int len_p = p.shape[0]
  259. cdef float tot = 0
  260. cdef float t
  261. for i in range(len_p):
  262. t = tot + p[i]
  263. p_cumulative[i] = t
  264. tot = t
  265. @cython.boundscheck(False)
  266. @cython.wraparound(False)
  267. cpdef void _categorize(cnp.float_t[:] draws,
  268. cnp.float_t[:] p_cumulative,
  269. cnp.int_t[:] result) nogil:
  270. cdef int i
  271. cdef int n_p = p_cumulative.shape[0]
  272. for i in range(draws.shape[0]):
  273. j = _find_index(p_cumulative, n_p, draws[i])
  274. result[j] = result[j] + 1
  275. @cython.boundscheck(False)
  276. @cython.wraparound(False)
  277. cdef int _find_index(cnp.float_t[:] p_cumulative,
  278. const int size,
  279. const float value) nogil:
  280. cdef int l = 0
  281. cdef int r = size - 1
  282. cdef int m
  283. while r > l:
  284. m = (l + r) // 2
  285. if value > p_cumulative[m]:
  286. l = m + 1
  287. else:
  288. r = m
  289. return r
  290. def _test_find_index(p_cumulative, size, value):
  291. # type: (np.ndarray, int, float) -> int
  292. """Wrapper for testing in python"""
  293. return _find_index(p_cumulative, size, value)