PageRenderTime 62ms CodeModel.GetById 22ms RepoModel.GetById 1ms app.codeStats 0ms

/WWDirectComputeTest/SincConvolution.hlsl

http://bitspersampleconv2.googlecode.com/
text | 329 lines | 267 code | 62 blank | 0 comment | 0 complexity | be779ff4a4f5ed8927baf75a86213629 MD5 | raw file
  1. // 日本語UTF-8
  2. /*
  3. OutputBuffer[t+convolutionN] = Σ[sample[t+x] * sinc(πx + XBuffer[t])]
  4. CONV_START <= x < CONV_END
  5. を計算する
  6. convolutionN = 256
  7. sampleN = 100
  8. の場合
  9. CONV_START = -256
  10. CONV_END = 256
  11. CONV_COUNT = 512
  12. SAMPLE_N = 100
  13. GROUP_THREAD_COUNT 2の乗数
  14. を#defineしてCS5.0 DirectCompute シェーダーとしてコンパイルする。
  15. // シェーダー定数を渡す
  16. shaderParams.c_convOffs = 0
  17. shaderParams.c_dispatchCount = convolutionN*2/GROUP_THREAD_COUNT;
  18. ComputeShaderのrun(shaderParams, sampleN, 1, 1);
  19. する。
  20. 用意するデータ
  21. ①SampleDataBuffer…前後を水増しされたサンプルデータsample[t]
  22. SampleDataBuffer[0]~SampleDataBuffer[convolutionN-1]…0を詰める
  23. SampleDataBuffer[convolutionN]~SampleDataBuffer[convolutionN + sampleN-1]…サンプルデータsample[t]
  24. SampleDataBuffer[convolutionN+SampleN]~SampleDataBuffer[convolutionN*2 + sampleN-1]…0を詰める
  25. ②SinxBuffer リサンプル地点のsin(x) 適当に作る
  26. SinxBuffer[0]~SinxBuffer[sampleN-1] sin(x)の値
  27. ③XBuffer リサンプル地点x
  28. XBuffer[0]~XBuffer[sampleN-1] xの値
  29. ④出力バッファー
  30. OutputBuffer[0]~OutputBuffer[sampleN-1]
  31. OutputBuffer[]はsampleN個用意する
  32. */
  33. #ifdef HIGH_PRECISION
  34. // 主にdouble精度
  35. StructuredBuffer<float> g_SampleDataBuffer : register(t0);
  36. StructuredBuffer<double> g_SinxBuffer : register(t1);
  37. StructuredBuffer<float> g_XBuffer : register(t2);
  38. RWStructuredBuffer<float> g_OutputBuffer : register(u0);
  39. /// 定数。16バイトの倍数のサイズの構造体。
  40. cbuffer consts {
  41. /// 畳み込み要素オフセット値。n * GROUP_THREAD_COUNTの飛び飛びの値が渡る。
  42. uint c_convOffs;
  43. /// Dispatch繰り返し回数。
  44. uint c_dispatchCount;
  45. uint c_reserved1;
  46. uint c_reserved2;
  47. };
  48. inline double
  49. SincF(double sinx, float x)
  50. {
  51. if (-0.000000001f < x && x < 0.000000001f) {
  52. return 1.0;
  53. } else {
  54. // 割り算ができないので、ここで精度落ちる。残念。
  55. return sinx * rcp(x);
  56. }
  57. }
  58. #define PI_F 3.141592653589793238462643f
  59. // TGSM
  60. groupshared double s_scratch[GROUP_THREAD_COUNT];
  61. groupshared double s_sinX;
  62. groupshared float s_xOffs;
  63. /// 畳み込み計算要素1回実行。
  64. /// sample[t+x] * sinc(πx + XBuffer[t])
  65. inline double
  66. ConvolutionElemValue(uint pos, uint convOffs)
  67. {
  68. const int offs = c_convOffs + convOffs;
  69. const float x = mad(PI_F, offs + CONV_START, s_xOffs);
  70. return ((double)g_SampleDataBuffer[offs + pos]) * SincF(s_sinX, x);
  71. }
  72. // スレッドグループとTGSMを使用して、GPUメモリからの読み出し回数を減らす最適化。
  73. // groupIdXYZはDispatch()のパラメータXYZ=(nx,1,1)の場合(0,0,0)~(nx-1, 0, 0)。
  74. // スレッドグループが作られ、tid==0~groupDim_x-1までのtidを持ったスレッドが同時に走る。
  75. [numthreads(GROUP_THREAD_COUNT, 1, 1)]
  76. void
  77. CSMain(
  78. uint tid: SV_GroupIndex,
  79. uint3 groupIdXYZ: SV_GroupID)
  80. {
  81. uint offs = tid;
  82. if (tid == 0) {
  83. s_xOffs = g_XBuffer[groupIdXYZ.x];
  84. s_sinX = g_SinxBuffer[groupIdXYZ.x];
  85. }
  86. s_scratch[tid] = 0;
  87. GroupMemoryBarrierWithGroupSync();
  88. do {
  89. s_scratch[tid] +=
  90. ConvolutionElemValue(groupIdXYZ.x, offs) +
  91. ConvolutionElemValue(groupIdXYZ.x, offs + GROUP_THREAD_COUNT);
  92. offs += GROUP_THREAD_COUNT * 2;
  93. } while (offs < CONV_COUNT);
  94. GroupMemoryBarrierWithGroupSync();
  95. #if 1024 <= GROUP_THREAD_COUNT
  96. if (tid < 512) { s_scratch[tid] += s_scratch[tid + 512]; }
  97. GroupMemoryBarrierWithGroupSync();
  98. #endif
  99. #if 512 <= GROUP_THREAD_COUNT
  100. if (tid < 256) { s_scratch[tid] += s_scratch[tid + 256]; }
  101. GroupMemoryBarrierWithGroupSync();
  102. #endif
  103. #if 256 <= GROUP_THREAD_COUNT
  104. if (tid < 128) { s_scratch[tid] += s_scratch[tid + 128]; }
  105. GroupMemoryBarrierWithGroupSync();
  106. #endif
  107. #if 128 <= GROUP_THREAD_COUNT
  108. if (tid < 64) { s_scratch[tid] += s_scratch[tid + 64]; }
  109. GroupMemoryBarrierWithGroupSync();
  110. #endif
  111. #if 64 <= GROUP_THREAD_COUNT
  112. if (tid < 32) { s_scratch[tid] += s_scratch[tid + 32]; }
  113. //GroupMemoryBarrierWithGroupSync(); // これ以降要らないらしい。2260_GTC2010.pdf参照。
  114. #endif
  115. #if 32 <= GROUP_THREAD_COUNT
  116. if (tid < 16) { s_scratch[tid] += s_scratch[tid + 16]; }
  117. //GroupMemoryBarrierWithGroupSync();
  118. #endif
  119. #if 16 <= GROUP_THREAD_COUNT
  120. if (tid < 8) { s_scratch[tid] += s_scratch[tid + 8]; }
  121. //GroupMemoryBarrierWithGroupSync();
  122. #endif
  123. #if 8 <= GROUP_THREAD_COUNT
  124. if (tid < 4) { s_scratch[tid] += s_scratch[tid + 4]; }
  125. // GroupMemoryBarrierWithGroupSync();
  126. #endif
  127. #if 4 <= GROUP_THREAD_COUNT
  128. if (tid < 2) { s_scratch[tid] += s_scratch[tid + 2]; }
  129. //GroupMemoryBarrierWithGroupSync();
  130. #endif
  131. if (tid == 0) {
  132. s_scratch[0] += s_scratch[1];
  133. g_OutputBuffer[groupIdXYZ.x] = (float)s_scratch[0];
  134. }
  135. }
  136. #else
  137. // 主にfloat精度
  138. StructuredBuffer<float> g_SampleDataBuffer : register(t0);
  139. StructuredBuffer<float> g_SinxBuffer : register(t1);
  140. StructuredBuffer<float> g_XBuffer : register(t2);
  141. RWStructuredBuffer<float> g_OutputBuffer : register(u0);
  142. /// 定数。16バイトの倍数のサイズの構造体。
  143. cbuffer consts {
  144. /// 畳み込み要素オフセット値。n * GROUP_THREAD_COUNTの飛び飛びの値が渡る。
  145. uint c_convOffs;
  146. /// Dispatch繰り返し回数。
  147. uint c_dispatchCount;
  148. uint c_reserved1;
  149. uint c_reserved2;
  150. };
  151. inline float
  152. SincF(float sinx, float x)
  153. {
  154. if (-0.000000001f < x && x < 0.000000001f) {
  155. return 1.0f;
  156. } else {
  157. // どちらでも同じだった。
  158. #if 1
  159. return sinx * rcp(x);
  160. #else
  161. return sinx / x;
  162. #endif
  163. }
  164. }
  165. #define PI_F 3.141592653589793238462643f
  166. // TGSM
  167. groupshared float s_scratch[GROUP_THREAD_COUNT];
  168. groupshared float s_sinX;
  169. groupshared float s_xOffs;
  170. /// 畳み込み計算要素1回実行。
  171. /// sample[t+x] * sinc(πx + XBuffer[t])
  172. inline float
  173. ConvolutionElemValue(uint pos, uint convOffs)
  174. {
  175. const int offs = c_convOffs + convOffs;
  176. const float x = mad(PI_F, offs + CONV_START, s_xOffs);
  177. return g_SampleDataBuffer[offs + pos] * SincF(s_sinX, x);
  178. }
  179. // スレッドグループとTGSMを使用して、GPUメモリからの読み出し回数を減らす最適化。
  180. // groupIdXYZはDispatch()のパラメータXYZ=(nx,1,1)の場合(0,0,0)~(nx-1, 0, 0)。
  181. // スレッドグループが作られ、tid==0~groupDim_x-1までのtidを持ったスレッドが同時に走る。
  182. [numthreads(GROUP_THREAD_COUNT, 1, 1)]
  183. void
  184. CSMain(
  185. uint tid: SV_GroupIndex,
  186. uint3 groupIdXYZ: SV_GroupID)
  187. {
  188. uint offs = tid;
  189. if (tid == 0) {
  190. s_xOffs = g_XBuffer[groupIdXYZ.x];
  191. #if 1
  192. // 計算精度良好。
  193. s_sinX = g_SinxBuffer[groupIdXYZ.x];
  194. #else
  195. // こうすると精度が落ちる。GPUのsin()の精度に問題あり。
  196. s_sinX = sin(s_xOffs);
  197. #endif
  198. }
  199. s_scratch[tid] = 0;
  200. GroupMemoryBarrierWithGroupSync();
  201. do {
  202. s_scratch[tid] +=
  203. ConvolutionElemValue(groupIdXYZ.x, offs) +
  204. ConvolutionElemValue(groupIdXYZ.x, offs + GROUP_THREAD_COUNT);
  205. offs += GROUP_THREAD_COUNT * 2;
  206. } while (offs < CONV_COUNT);
  207. GroupMemoryBarrierWithGroupSync();
  208. #if 1024 <= GROUP_THREAD_COUNT
  209. if (tid < 512) { s_scratch[tid] += s_scratch[tid + 512]; }
  210. GroupMemoryBarrierWithGroupSync();
  211. #endif
  212. #if 512 <= GROUP_THREAD_COUNT
  213. if (tid < 256) { s_scratch[tid] += s_scratch[tid + 256]; }
  214. GroupMemoryBarrierWithGroupSync();
  215. #endif
  216. #if 256 <= GROUP_THREAD_COUNT
  217. if (tid < 128) { s_scratch[tid] += s_scratch[tid + 128]; }
  218. GroupMemoryBarrierWithGroupSync();
  219. #endif
  220. #if 128 <= GROUP_THREAD_COUNT
  221. if (tid < 64) { s_scratch[tid] += s_scratch[tid + 64]; }
  222. GroupMemoryBarrierWithGroupSync();
  223. #endif
  224. #if 64 <= GROUP_THREAD_COUNT
  225. if (tid < 32) { s_scratch[tid] += s_scratch[tid + 32]; }
  226. //GroupMemoryBarrierWithGroupSync(); // これ以降要らないらしい。2260_GTC2010.pdf参照。
  227. #endif
  228. #if 32 <= GROUP_THREAD_COUNT
  229. if (tid < 16) { s_scratch[tid] += s_scratch[tid + 16]; }
  230. //GroupMemoryBarrierWithGroupSync();
  231. #endif
  232. #if 16 <= GROUP_THREAD_COUNT
  233. if (tid < 8) { s_scratch[tid] += s_scratch[tid + 8]; }
  234. //GroupMemoryBarrierWithGroupSync();
  235. #endif
  236. #if 8 <= GROUP_THREAD_COUNT
  237. if (tid < 4) { s_scratch[tid] += s_scratch[tid + 4]; }
  238. // GroupMemoryBarrierWithGroupSync();
  239. #endif
  240. #if 4 <= GROUP_THREAD_COUNT
  241. if (tid < 2) { s_scratch[tid] += s_scratch[tid + 2]; }
  242. //GroupMemoryBarrierWithGroupSync();
  243. #endif
  244. if (tid == 0) {
  245. s_scratch[0] += s_scratch[1];
  246. g_OutputBuffer[groupIdXYZ.x] = s_scratch[0];
  247. }
  248. }
  249. #if 0
  250. // 最適化前
  251. [numthreads(1, 1, 1)]
  252. void
  253. CSMain(uint3 groupIdXYZ : SV_GroupID,
  254. uint threadIdx : SV_GroupIndex)
  255. {
  256. int i;
  257. float sinx = SinxBuffer[c_pos];
  258. float xOffs = XBuffer[c_pos];
  259. float r = 0.0f;
  260. for (i=CONV_START; i<CONV_END; ++i) {
  261. float x = mad(PI, i, xOffs);
  262. r = mad(SampleDataBuffer[c_pos+i+CONV_N], SincF(sinx, x), r);
  263. }
  264. OutputBuffer[c_pos] = r;
  265. }
  266. #endif // before optimization
  267. #endif // HIGH_PRECISION