/sourceryvsipl--/sourceryvsipl++-lite-2.1/src/src/vsip/opt/cuda/fastconv.hpp

https://github.com/somaproject/thirdparty-packages · C++ Header · 298 lines · 218 code · 49 blank · 31 comment · 22 complexity · 3195d25961d7a3f9054699a933b16df3 MD5 · raw file

  1. /* Copyright (c) 2009 by CodeSourcery. All rights reserved.
  2. This file is available for license from CodeSourcery, Inc. under the terms
  3. of a commercial license and under the GPL. It is not part of the VSIPL++
  4. reference implementation and is not available under the BSD license.
  5. */
  6. /** @file vsip/opt/cuda/fastconv.hpp
  7. @author Don McCoy
  8. @date 2009-03-22
  9. @brief VSIPL++ Library: Wrapper for fast convolution using CUDA.
  10. */
  11. #ifndef VSIP_OPT_CUDA_FASTCONV_HPP
  12. #define VSIP_OPT_CUDA_FASTCONV_HPP
  13. /***********************************************************************
  14. Included Files
  15. ***********************************************************************/
  16. #include <vsip/core/allocation.hpp>
  17. #include <vsip/core/config.hpp>
  18. #include <vsip/core/extdata.hpp>
  19. #include <vsip/opt/cuda/bindings.hpp>
  20. /***********************************************************************
  21. Declarations
  22. ***********************************************************************/
  23. namespace vsip
  24. {
  25. namespace impl
  26. {
  27. namespace cuda
  28. {
  29. template <dimension_type D,
  30. typename T,
  31. typename ComplexFmt>
  32. struct Fastconv_traits;
  33. template <>
  34. struct Fastconv_traits<1, std::complex<float>, Cmplx_inter_fmt>
  35. {
  36. static length_type const min_size = 16;
  37. static length_type const max_size = 8000000;
  38. };
  39. template <>
  40. struct Fastconv_traits<2, std::complex<float>, Cmplx_inter_fmt>
  41. {
  42. static length_type const min_size = 16;
  43. static length_type const max_size = 8000000;
  44. };
  45. /// Fast convolution object
  46. ///
  47. /// Template parameters:
  48. /// D to specify the dimensionality of the kernel (either a 1 or 2)
  49. /// T to be the value type of data that will be processed.
  50. /// ComplexFmt to be the complex format (either Cmplx_inter_fmt or
  51. /// Cmplx_split_fmt) to be processed.
  52. template <dimension_type D,
  53. typename T,
  54. typename ComplexFmt>
  55. class Fastconv_base
  56. {
  57. static dimension_type const dim = D;
  58. typedef ComplexFmt complex_type;
  59. typedef Layout<1, row1_type, Stride_unit_dense, complex_type> layout1_type;
  60. typedef Layout<2, row2_type, Stride_unit_dense, complex_type> layout2_type;
  61. public:
  62. Fastconv_base(length_type const input_size, bool transform_kernel)
  63. : size_ (input_size),
  64. transform_kernel_(transform_kernel)
  65. {
  66. assert(rt_valid_size(this->size_));
  67. }
  68. static bool rt_valid_size(length_type size)
  69. {
  70. return (size >= cuda::Fastconv_traits<dim, T, complex_type>::min_size &&
  71. size <= cuda::Fastconv_traits<dim, T, complex_type>::max_size);
  72. }
  73. template <typename Block0, typename Block1, typename Block2>
  74. void convolve(const_Vector<T, Block0> in, const_Vector<T, Block1> kernel, Vector<T, Block2> out)
  75. {
  76. Ext_data<Block0, layout1_type> ext_in (in.block(), SYNC_IN);
  77. Ext_data<Block1, layout1_type> ext_kernel(kernel.block(), SYNC_IN);
  78. Ext_data<Block2, layout1_type> ext_out (out.block(), SYNC_OUT);
  79. assert(dim == 1);
  80. assert(ext_in.stride(0) == 1);
  81. assert(ext_kernel.stride(0) == 1);
  82. assert(ext_out.stride(0) == 1);
  83. length_type rows = 1;
  84. fconv(ext_in.data(), ext_kernel.data(), ext_out.data(), rows, out.size(0), transform_kernel_);
  85. }
  86. template <typename Block0, typename Block1, typename Block2>
  87. void convolve(const_Matrix<T, Block0> in, const_Vector<T, Block1> kernel, Matrix<T, Block2> out)
  88. {
  89. Ext_data<Block0, layout2_type> ext_in (in.block(), SYNC_IN);
  90. Ext_data<Block1, layout1_type> ext_kernel(kernel.block(), SYNC_IN);
  91. Ext_data<Block2, layout2_type> ext_out (out.block(), SYNC_OUT);
  92. assert(dim == 1);
  93. assert(ext_in.stride(1) == 1);
  94. assert(ext_kernel.stride(0) == 1);
  95. assert(ext_out.stride(1) == 1);
  96. length_type rows = in.size(0);
  97. fconv(ext_in.data(), ext_kernel.data(), ext_out.data(), rows, out.size(1), transform_kernel_);
  98. }
  99. template <typename Block0, typename Block1, typename Block2>
  100. void convolve(const_Matrix<T, Block0> in, const_Matrix<T, Block1> kernel, Matrix<T, Block2> out)
  101. {
  102. Ext_data<Block0, layout2_type> ext_in (in.block(), SYNC_IN);
  103. Ext_data<Block1, layout2_type> ext_kernel(kernel.block(), SYNC_IN);
  104. Ext_data<Block2, layout2_type> ext_out (out.block(), SYNC_OUT);
  105. assert(dim == 2);
  106. assert(ext_in.stride(1) == 1);
  107. assert(ext_kernel.stride(1) == 1);
  108. assert(ext_out.stride(1) == 1);
  109. length_type rows = in.size(0);
  110. fconv(ext_in.data(), ext_kernel.data(), ext_out.data(), rows, out.size(1), transform_kernel_);
  111. }
  112. length_type size() { return size_; }
  113. private:
  114. typedef typename Scalar_of<T>::type uT;
  115. void fconv(T const* in, T const* kernel, T* out,
  116. length_type rows, length_type length, bool transform_kernel);
  117. // Member data.
  118. length_type size_;
  119. bool transform_kernel_;
  120. };
  121. template <dimension_type D,
  122. typename T,
  123. typename ComplexFmt = Cmplx_inter_fmt>
  124. class Fastconv;
  125. template <typename T, typename ComplexFmt>
  126. class Fastconv<1, T, ComplexFmt> : public Fastconv_base<1, T, ComplexFmt>
  127. {
  128. // Constructors, copies, assignments, and destructors.
  129. public:
  130. template <typename Block>
  131. Fastconv(Vector<T, Block> coeffs,
  132. length_type input_size,
  133. bool transform_kernel = true)
  134. VSIP_THROW((std::bad_alloc))
  135. : Fastconv_base<1, T, ComplexFmt>(input_size, transform_kernel),
  136. kernel_(input_size)
  137. {
  138. assert(coeffs.size(0) <= this->size());
  139. if (transform_kernel)
  140. {
  141. kernel_ = T();
  142. kernel_(view_domain(coeffs.local())) = coeffs.local();
  143. }
  144. else
  145. kernel_ = coeffs.local();
  146. }
  147. ~Fastconv() VSIP_NOTHROW {}
  148. // Fastconv operators.
  149. template <typename Block1,
  150. typename Block2>
  151. Vector<T, Block2>
  152. operator()(
  153. const_Vector<T, Block1> in,
  154. Vector<T, Block2> out)
  155. VSIP_NOTHROW
  156. {
  157. assert(in.size() == this->size());
  158. assert(out.size() == this->size());
  159. this->convolve(in.local(), this->kernel_, out.local());
  160. return out;
  161. }
  162. template <typename Block1,
  163. typename Block2>
  164. Matrix<T, Block2>
  165. operator()(
  166. const_Matrix<T, Block1> in,
  167. Matrix<T, Block2> out)
  168. VSIP_NOTHROW
  169. {
  170. assert(in.size(1) == this->size());
  171. assert(out.size(1) == this->size());
  172. this->convolve(in.local(), this->kernel_, out.local());
  173. return out;
  174. }
  175. private:
  176. typedef ComplexFmt complex_type;
  177. typedef Layout<1, row1_type,
  178. Stride_unit_dense, complex_type> kernel_layout_type;
  179. typedef Fast_block<1, T,
  180. kernel_layout_type, Local_map> kernel_block_type;
  181. typedef Vector<T, kernel_block_type> kernel_view_type;
  182. // Member data.
  183. kernel_view_type kernel_;
  184. };
  185. template <typename T, typename ComplexFmt>
  186. class Fastconv<2, T, ComplexFmt> : public Fastconv_base<2, T, ComplexFmt>
  187. {
  188. // Constructors, copies, assignments, and destructors.
  189. public:
  190. template <typename Block>
  191. Fastconv(Matrix<T, Block> coeffs,
  192. length_type input_size,
  193. bool transform_kernel = true)
  194. VSIP_THROW((std::bad_alloc))
  195. : Fastconv_base<2, T, ComplexFmt>(input_size, transform_kernel),
  196. kernel_(coeffs.local().size(0), input_size)
  197. {
  198. assert(coeffs.size(1) <= this->size());
  199. if (transform_kernel)
  200. {
  201. kernel_ = T();
  202. kernel_(view_domain(coeffs.local())) = coeffs.local();
  203. }
  204. else
  205. kernel_ = coeffs.local();
  206. }
  207. ~Fastconv() VSIP_NOTHROW {}
  208. // Fastconv operators.
  209. template <typename Block1,
  210. typename Block2>
  211. Vector<T, Block2>
  212. operator()(
  213. const_Vector<T, Block1> in,
  214. Vector<T, Block2> out)
  215. VSIP_NOTHROW
  216. {
  217. assert(in.size() == this->size());
  218. assert(out.size() == this->size());
  219. this->convolve(in.local(), this->kernel_, out.local());
  220. return out;
  221. }
  222. template <typename Block1,
  223. typename Block2>
  224. Matrix<T, Block2>
  225. operator()(
  226. const_Matrix<T, Block1> in,
  227. Matrix<T, Block2> out)
  228. VSIP_NOTHROW
  229. {
  230. assert(in.size(1) == this->size());
  231. assert(out.size(1) == this->size());
  232. this->convolve(in.local(), this->kernel_, out.local());
  233. return out;
  234. }
  235. private:
  236. // Member data.
  237. typedef ComplexFmt complex_type;
  238. typedef Layout<2, row2_type,
  239. Stride_unit_dense, complex_type> kernel_layout_type;
  240. typedef Fast_block<2, T,
  241. kernel_layout_type, Local_map> kernel_block_type;
  242. typedef Matrix<T, kernel_block_type> kernel_view_type;
  243. kernel_view_type kernel_;
  244. };
  245. } // namespace vsip::impl::cuda
  246. } // namespace vsip::impl
  247. } // namespace vsip
  248. #endif // VSIP_OPT_CUDA_FASTCONV_HPP