/cbits/VectorScalar.c

https://bitbucket.org/haskellnumerics/vector-vectorized · C · 261 lines · 68 code · 42 blank · 151 comment · 3 complexity · 2b7e77145ce20b7e0e149d177978b8c9 MD5 · raw file

  #include "simd.h"
  #include <stddef.h>
  #include <stdint.h>
  #include <tgmath.h>
  // #include <complex.h>
  4. /*
  5. note: using 32bit signed
  6. */
  7. /*
  8. we have 3 versions:
  9. avx2(untested)
  10. sse4/avx1 (and I guess sse2 by accident there too, well,
  11. explicit dot product vs not, so some complication)
  12. plus scalar fallback
  13. */
  14. /*
  15. + - * , abs,
  16. note, for now using 32bit int types for array operation sizes and strides
  17. because you really shouldn't do more than 4gb of work in one sequential ffi call!
  18. */
  19. /* CPP macro to generate declarations and body for strided scalar versions */
  20. // UnaryOpScalarArray(arrayGeneralLog)
  21. // UnaryOpScalarArray(arrayGeneralAbs,fabs,type) \ // this is wrong for complex numbers
  22. /*
  23. double carg(double complex);
  24. double cimag(double complex);
  25. double creal(double complex);
  26. double complex cacos(double complex);
  27. double complex cacosh(double complex);
  28. double complex casin(double complex);
  29. double complex casinh(double complex);
  30. double complex catan(double complex);
  31. double complex catanh(double complex);
  32. double complex ccos(double complex);
  33. double complex ccosh(double complex);
  34. double complex cexp(double complex);
  35. double complex clog(double complex);
  36. double complex conj(double complex);
  37. double complex cproj(double complex);
  38. double complex csin(double complex);
  39. double complex csinh(double complex);
  40. double complex csqrt(double complex);
  41. double complex ctan(double complex);
  42. double complex ctanh(double complex);
  43. double complex cpow(double complex, double complex);
  44. */
  45. /*
  46. atan2()
  47. cbrt()
  48. ceil()
  49. copysign()
  50. erf()
  51. erfc()
  52. exp2()
  53. expm1()
  54. fdim()
  55. floor()
  56. fma()
  57. fmax()
  58. fmin()
  59. fmod()
  60. frexp()
  61. hypot()
  62. ilogb()
  63. ldexp()
  64. lgamma()
  65. llrint()
  66. llround()
  67. log10()
  68. log1p()
  69. log2()
  70. logb()
  71. lrint()
  72. lround()
  73. nearbyint()
  74. nextafter()
  75. nexttoward()
  76. remainder()
  77. remquo()
  78. rint()
  79. round()
  80. scalbn()
  81. scalbln()
  82. tgamma()
  83. trunc()
  84. */
  /*
   * BinaryOpScalarArray(name, binaryop, type)
   *
   * Declares and defines
   *   void name_type(int32_t length, type *left, int32_t leftStride,
   *                  type *right, int32_t rightStride,
   *                  type *result, int32_t resultStride);
   * which, for every ix in [0, length), stores
   *   result[ix * resultStride] = left[ix * leftStride] binaryop right[ix * rightStride]
   * `binaryop` is spliced in as an infix operator token (+, -, *, /).
   * A non-positive length writes nothing.
   *
   * The prototype and the definition must agree exactly, so both spell every
   * stride as int32_t. Element offsets are formed in ptrdiff_t because
   * ix * stride evaluated in int32_t overflows (undefined behavior) once the
   * product exceeds INT32_MAX, which a long, widely strided view can reach.
   */
  #define BinaryOpScalarArray(name, binaryop, type) \
  void name##_##type(int32_t length, type *left, int32_t leftStride, \
                     type *right, int32_t rightStride, \
                     type *result, int32_t resultStride); \
  \
  void name##_##type(int32_t length, type *left, int32_t leftStride, \
                     type *right, int32_t rightStride, \
                     type *result, int32_t resultStride) \
  { \
      for (int32_t ix = 0; ix < length; ix++) { \
          result[(ptrdiff_t)ix * resultStride] = \
              (left[(ptrdiff_t)ix * leftStride]) binaryop (right[(ptrdiff_t)ix * rightStride]); \
      } \
  }
  /*
   * UnaryOpScalarArray(name, op, type)
   *
   * Declares and defines
   *   void name_type(int32_t length, type *in, int32_t inStride,
   *                  type *out, int32_t outStride);
   * which, for every ix in [0, length), stores
   *   out[ix * outStride] = (type) op(in[ix * inStride])
   * `op` may be a function, a <tgmath.h> type-generic macro (sqrt, fabs), or
   * a function-like macro such as negate/reciprocal. A non-positive length
   * writes nothing.
   *
   * WARNING / NOTE: the explicit cast back to `type` lets `op` compute in a
   * wider type. reciprocal(x) expands to (1.0/(x)), so for float and
   * complex_float elements the division is evaluated in double precision and
   * then narrowed; in the complex_float case this may give unexpectedly good
   * precision compared with a single-precision complex divide.
   *
   * Element offsets are formed in ptrdiff_t because ix * stride evaluated in
   * int32_t overflows (undefined behavior) once it exceeds INT32_MAX.
   */
  #define UnaryOpScalarArray(name, op, type) \
  void name##_##type(int32_t length, type *in, int32_t inStride, \
                     type *out, int32_t outStride); \
  \
  void name##_##type(int32_t length, type *in, int32_t inStride, \
                     type *out, int32_t outStride) \
  { \
      for (int32_t ix = 0; ix < length; ix++) { \
          out[(ptrdiff_t)ix * outStride] = (type) op(in[(ptrdiff_t)ix * inStride]); \
      } \
  }
  /*
   * DotProductScalarArray(name, binaryop, type, init)
   *
   * Declares and defines
   *   type name_type(int32_t length, type *left, int32_t leftStride,
   *                  type *right, int32_t rightStride);
   * which returns
   *   init + sum over ix in [0, length) of
   *          binaryop(left[ix * leftStride], right[ix * rightStride])
   * A non-positive length returns init unchanged.
   *
   * The running sum is held in `type` itself; there is no wider accumulator,
   * so e.g. a float dot product accumulates in single precision.
   * Element offsets are formed in ptrdiff_t because ix * stride evaluated in
   * int32_t overflows (undefined behavior) once it exceeds INT32_MAX.
   */
  #define DotProductScalarArray(name, binaryop, type, init) \
  type name##_##type(int32_t length, type *left, int32_t leftStride, \
                     type *right, int32_t rightStride); \
  \
  type name##_##type(int32_t length, type *left, int32_t leftStride, \
                     type *right, int32_t rightStride) \
  { \
      type res = (init); \
      for (int32_t ix = 0; ix < length; ix++) { \
          res += binaryop((left[(ptrdiff_t)ix * leftStride]), \
                          (right[(ptrdiff_t)ix * rightStride])); \
      } \
      return res; \
  }
  118. ////////////
  119. /// Scalar versions
  120. ///////////
  121. /*
  122. for now I shall assume that everything only works on < 4gb / 32 sized ranges
  123. also, will try to write things so that if any of the read (input) arrays
  124. */
  125. /*
  126. do a typedef for the complex types so that writing the macro stuff
  127. mixes well
  128. */
  /*
   * Function-like helpers passed as the `op` argument of UnaryOpScalarArray.
   * reciprocal divides the double literal 1.0 by its argument, so the
   * quotient is at least double precision before UnaryOpScalarArray casts it
   * back to the element type.
   */
  #define negate(x) (-(x))
  #define reciprocal(x) (1.0 / (x))

  /*
   * Instantiates, for one element type T, the strided kernels
   * arrayPlus_T, arrayMinus_T, arrayTimes_T, arrayDivide_T,
   * arrayNegate_T, arrayReciprocal_T and arraySqrt_T.
   */
  #define mkNumFracOpsScalar(type) \
      BinaryOpScalarArray(arrayPlus, +, type) \
      BinaryOpScalarArray(arrayMinus, -, type) \
      BinaryOpScalarArray(arrayTimes, *, type) \
      BinaryOpScalarArray(arrayDivide, /, type) \
      UnaryOpScalarArray(arrayNegate, negate, type) \
      UnaryOpScalarArray(arrayReciprocal, reciprocal, type) \
      UnaryOpScalarArray(arraySqrt, sqrt, type)
  /*
   * Element products used as the `binaryop` of DotProductScalarArray.
   * complextimes conjugates its SECOND (right) argument, so a complex dot
   * product built from it computes sum(left[i] * conj(right[i])).
   * Both arguments are parenthesized so that compound expressions such as
   * realtimes(a + b, c) multiply as (a + b) * c rather than a + b * c.
   */
  #define realtimes(x, y) ((x) * (y))
  #define complextimes(x, y) ((x) * conj(y))

  /*
   * Single-identifier names for the C99 complex types, so they can be pasted
   * by the kernel macros into generated names such as arrayPlus_complex_double.
   */
  typedef float complex complex_float;
  typedef double complex complex_double;
  143. DotProductScalarArray(arrayDotProduct,realtimes,double,0.0)
  144. DotProductScalarArray(arrayDotProduct,realtimes,float,0.0)
  145. // name mangling the complex dot products because I need to
  146. // wrap them to take a singlen array pointer where I write the result
  147. // because haskell currently doesn't have an FFI story for structs and complex numbers
  148. DotProductScalarArray(arrayDotProduct_internal,complextimes,complex_double,0.0 + I*0.0)
  149. DotProductScalarArray(arrayDotProduct_internal,complextimes,complex_float,0.0f + I*0.0f)
  150. /// NOTE: complex valued dot products have a different type than the real float dot products
  151. /// this is important to remember!
  152. void arrayDotProduct_complex_double(int32_t length, complex_double * left, int32_t leftStride ,complex_double * right,int32_t rightStride, complex_double * resultSingleton);
  153. void arrayDotProduct_complex_double(int32_t length, complex_double * left, int32_t leftStride ,complex_double * right,int32_t rightStride, complex_double * resultSingleton){
  154. complex_double result = arrayDotProduct_internal_complex_double(length,left,leftStride,right,rightStride);
  155. resultSingleton[0] = result ;
  156. }
  157. void arrayDotProduct_complex_float(int32_t length, complex_float * left, int32_t leftStride ,complex_float * right,int32_t rightStride, complex_float * resultSingleton);
  158. void arrayDotProduct_complex_float(int32_t length, complex_float * left, int32_t leftStride ,complex_float * right,int32_t rightStride, complex_float * resultSingleton){
  159. complex_float result = arrayDotProduct_internal_complex_float(length,left,leftStride,right,rightStride);
  160. resultSingleton[0] = result ;
  161. }
  162. UnaryOpScalarArray(arrayGeneralAbs,fabs,float)
  163. UnaryOpScalarArray(arrayGeneralAbs,fabs,double)
  164. mkNumFracOpsScalar(complex_double)
  165. mkNumFracOpsScalar(complex_float)
  166. mkNumFracOpsScalar(double)
  167. mkNumFracOpsScalar(float)