PageRenderTime 47ms CodeModel.GetById 18ms RepoModel.GetById 1ms app.codeStats 0ms

/numpy/build_utils/src/apple_sgemv_fix.c

http://github.com/numpy/numpy
C | 229 lines | 145 code | 21 blank | 63 comment | 34 complexity | 13cc1eb927870fbe2bd6fd092ce8ff30 MD5 | raw file
Possible License(s): BSD-3-Clause, JSON, Unlicense
  1. /* This is a collection of ugly hacks to circumvent a bug in
  2. * Apple Accelerate framework's SGEMV subroutine.
  3. *
  4. * See: https://github.com/numpy/numpy/issues/4007
  5. *
  6. * SGEMV in Accelerate framework will segfault on MacOS X version 10.9
  7. * (aka Mavericks) if arrays are not aligned to 32 byte boundaries
  8. * and the CPU supports AVX instructions. This can produce segfaults
  9. * in np.dot.
  10. *
 * This patch overrides (shadows) the symbols cblas_sgemv, sgemv_ and sgemv
 * exported by Accelerate to produce the correct behavior. The MacOS X
 * version and CPU features are checked on module import. If Mavericks and
 * AVX are detected, the call to SGEMV is emulated with a call to SGEMM
 * when the arrays are not 32 byte aligned. If vecLib cannot be opened or
 * Accelerate's original symbols cannot be resolved on module import, a
 * fatal error is produced and the process aborts. All the fixes are in a self-contained C file
  18. * and do not alter the multiarray C code. The patch is not applied
  19. * unless NumPy is configured to link with Apple's Accelerate
  20. * framework.
  21. *
  22. */
  23. #define NPY_NO_DEPRECATED_API NPY_API_VERSION
  24. #include "Python.h"
  25. #include "numpy/arrayobject.h"
  26. #include <string.h>
  27. #include <dlfcn.h>
  28. #include <stdlib.h>
  29. #include <stdio.h>
  30. /* ----------------------------------------------------------------- */
  31. /* Original cblas_sgemv */
  32. #define VECLIB_FILE "/System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/vecLib"
  33. enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
  34. enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
  35. extern void cblas_xerbla(int info, const char *rout, const char *form, ...);
  36. typedef void cblas_sgemv_t(const enum CBLAS_ORDER order,
  37. const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
  38. const float alpha, const float *A, const int lda,
  39. const float *X, const int incX,
  40. const float beta, float *Y, const int incY);
  41. typedef void cblas_sgemm_t(const enum CBLAS_ORDER order,
  42. const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB,
  43. const int M, const int N, const int K,
  44. const float alpha, const float *A, const int lda,
  45. const float *B, const int ldb,
  46. const float beta, float *C, const int incC);
  47. typedef void fortran_sgemv_t( const char* trans, const int* m, const int* n,
  48. const float* alpha, const float* A, const int* ldA,
  49. const float* X, const int* incX,
  50. const float* beta, float* Y, const int* incY );
  51. static void *veclib = NULL;
  52. static cblas_sgemv_t *accelerate_cblas_sgemv = NULL;
  53. static cblas_sgemm_t *accelerate_cblas_sgemm = NULL;
  54. static fortran_sgemv_t *accelerate_sgemv = NULL;
  55. static int AVX_and_10_9 = 0;
  56. /* Dynamic check for AVX support
  57. * __builtin_cpu_supports("avx") is available in gcc 4.8,
  58. * but clang and icc do not currently support it. */
  59. #define cpu_supports_avx()\
  60. (system("sysctl -n machdep.cpu.features | grep -q AVX") == 0)
  61. /* Check if we are using MacOS X version 10.9 */
  62. #define using_mavericks()\
  63. (system("sw_vers -productVersion | grep -q 10\\.9\\.") == 0)
  64. __attribute__((destructor))
  65. static void unloadlib(void)
  66. {
  67. if (veclib) dlclose(veclib);
  68. }
  69. __attribute__((constructor))
  70. static void loadlib()
  71. /* automatically executed on module import */
  72. {
  73. char errormsg[1024];
  74. int AVX, MAVERICKS;
  75. memset((void*)errormsg, 0, sizeof(errormsg));
  76. /* check if the CPU supports AVX */
  77. AVX = cpu_supports_avx();
  78. /* check if the OS is MacOS X Mavericks */
  79. MAVERICKS = using_mavericks();
  80. /* we need the workaround when the CPU supports
  81. * AVX and the OS version is Mavericks */
  82. AVX_and_10_9 = AVX && MAVERICKS;
  83. /* load vecLib */
  84. veclib = dlopen(VECLIB_FILE, RTLD_LOCAL | RTLD_FIRST);
  85. if (!veclib) {
  86. veclib = NULL;
  87. sprintf(errormsg,"Failed to open vecLib from location '%s'.", VECLIB_FILE);
  88. Py_FatalError(errormsg); /* calls abort() and dumps core */
  89. }
  90. /* resolve Fortran SGEMV from Accelerate */
  91. accelerate_sgemv = (fortran_sgemv_t*) dlsym(veclib, "sgemv_");
  92. if (!accelerate_sgemv) {
  93. unloadlib();
  94. sprintf(errormsg,"Failed to resolve symbol 'sgemv_'.");
  95. Py_FatalError(errormsg);
  96. }
  97. /* resolve cblas_sgemv from Accelerate */
  98. accelerate_cblas_sgemv = (cblas_sgemv_t*) dlsym(veclib, "cblas_sgemv");
  99. if (!accelerate_cblas_sgemv) {
  100. unloadlib();
  101. sprintf(errormsg,"Failed to resolve symbol 'cblas_sgemv'.");
  102. Py_FatalError(errormsg);
  103. }
  104. /* resolve cblas_sgemm from Accelerate */
  105. accelerate_cblas_sgemm = (cblas_sgemm_t*) dlsym(veclib, "cblas_sgemm");
  106. if (!accelerate_cblas_sgemm) {
  107. unloadlib();
  108. sprintf(errormsg,"Failed to resolve symbol 'cblas_sgemm'.");
  109. Py_FatalError(errormsg);
  110. }
  111. }
  112. /* ----------------------------------------------------------------- */
  113. /* Fortran SGEMV override */
  114. void sgemv_( const char* trans, const int* m, const int* n,
  115. const float* alpha, const float* A, const int* ldA,
  116. const float* X, const int* incX,
  117. const float* beta, float* Y, const int* incY )
  118. {
  119. /* It is safe to use the original SGEMV if we are not using AVX on Mavericks
  120. * or the input arrays A, X and Y are all aligned on 32 byte boundaries. */
  121. #define BADARRAY(x) (((npy_intp)(void*)x) % 32)
  122. const int use_sgemm = AVX_and_10_9 && (BADARRAY(A) || BADARRAY(X) || BADARRAY(Y));
  123. if (!use_sgemm) {
  124. accelerate_sgemv(trans,m,n,alpha,A,ldA,X,incX,beta,Y,incY);
  125. return;
  126. }
  127. /* Arrays are misaligned, the CPU supports AVX, and we are running
  128. * Mavericks.
  129. *
  130. * Emulation of SGEMV with SGEMM:
  131. *
  132. * SGEMV allows vectors to be strided. SGEMM requires all arrays to be
  133. * contiguous along the leading dimension. To emulate striding in SGEMV
  134. * with the leading dimension arguments in SGEMM we compute
  135. *
  136. * Y = alpha * op(A) @ X + beta * Y
  137. *
  138. * as
  139. *
  140. * Y.T = alpha * X.T @ op(A).T + beta * Y.T
  141. *
  142. * Because Fortran uses column major order and X.T and Y.T are row vectors,
  143. * the leading dimensions of X.T and Y.T in SGEMM become equal to the
  144. * strides of the the column vectors X and Y in SGEMV. */
  145. switch (*trans) {
  146. case 'T':
  147. case 't':
  148. case 'C':
  149. case 'c':
  150. accelerate_cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
  151. 1, *n, *m, *alpha, X, *incX, A, *ldA, *beta, Y, *incY );
  152. break;
  153. case 'N':
  154. case 'n':
  155. accelerate_cblas_sgemm( CblasColMajor, CblasNoTrans, CblasTrans,
  156. 1, *m, *n, *alpha, X, *incX, A, *ldA, *beta, Y, *incY );
  157. break;
  158. default:
  159. cblas_xerbla(1, "SGEMV", "Illegal transpose setting: %c\n", *trans);
  160. }
  161. }
  162. /* ----------------------------------------------------------------- */
  163. /* Override for an alias symbol for sgemv_ in Accelerate */
  164. void sgemv (char *trans,
  165. const int *m, const int *n,
  166. const float *alpha,
  167. const float *A, const int *lda,
  168. const float *B, const int *incB,
  169. const float *beta,
  170. float *C, const int *incC)
  171. {
  172. sgemv_(trans,m,n,alpha,A,lda,B,incB,beta,C,incC);
  173. }
  174. /* ----------------------------------------------------------------- */
  175. /* cblas_sgemv override, based on Netlib CBLAS code */
  176. void cblas_sgemv(const enum CBLAS_ORDER order,
  177. const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
  178. const float alpha, const float *A, const int lda,
  179. const float *X, const int incX, const float beta,
  180. float *Y, const int incY)
  181. {
  182. char TA;
  183. if (order == CblasColMajor)
  184. {
  185. if (TransA == CblasNoTrans) TA = 'N';
  186. else if (TransA == CblasTrans) TA = 'T';
  187. else if (TransA == CblasConjTrans) TA = 'C';
  188. else
  189. {
  190. cblas_xerbla(2, "cblas_sgemv","Illegal TransA setting, %d\n", TransA);
  191. }
  192. sgemv_(&TA, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
  193. }
  194. else if (order == CblasRowMajor)
  195. {
  196. if (TransA == CblasNoTrans) TA = 'T';
  197. else if (TransA == CblasTrans) TA = 'N';
  198. else if (TransA == CblasConjTrans) TA = 'N';
  199. else
  200. {
  201. cblas_xerbla(2, "cblas_sgemv", "Illegal TransA setting, %d\n", TransA);
  202. return;
  203. }
  204. sgemv_(&TA, &N, &M, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
  205. }
  206. else
  207. cblas_xerbla(1, "cblas_sgemv", "Illegal Order setting, %d\n", order);
  208. }