/tests/cuda/todo/cuda_poly.py

http://github.com/npinto/python-cuda
Python | 100 lines | 79 code | 19 blank | 2 comment | 14 complexity | 775029e91613a62de6bf1442510d52bf MD5 | raw file
  1. #!/bin/env python
  2. # coding:utf-8: © Arno Pähler, 2007-08
  3. from ctypes import *
  4. from time import time
  5. from cuda.cuda_defs import *
  6. from cuda.cuda_api import *
  7. from cuda.cuda_utils import *
  8. from cpuFunctions import vectorInit,checkError
  9. from cpuFunctions import cpuPOLY5,cpuPOLY10,cpuPOLY20,cpuPOLY40
  10. from gpuFunctions import gpuPOLY5,gpuPOLY10,gpuPOLY20,gpuPOLY40
  11. BLOCK_SIZE = 144
  12. GRID_SIZE = 192
  13. checkErrorFlag = False
  14. S4 = sizeof(c_float)
  15. psize = 5
  16. def main(vlength = 128,loops = 1,m1 = 1):
  17. print "%5d %5d %5d" % (l,loops,m1),
  18. alfa = c_float(.5)
  19. n2 = vlength ## Vector length
  20. mp = 1 << (m1-1)
  21. print "%5d" % (mp*psize),
  22. gpuPOLY = eval("gpuPOLY%d"%(mp*psize))
  23. h_X = (c_float*n2)()
  24. h_Y = (c_float*n2)()
  25. g_Y = (c_float*n2)()
  26. vectorInit(h_X)
  27. d_X = getMemory(h_X)
  28. d_Y = getMemory(h_Y)
  29. blockDim = dim3(BLOCK_SIZE,1,1)
  30. gridDim = dim3(GRID_SIZE,1,1)
  31. t0 = time()
  32. cudaThreadSynchronize()
  33. for i in range(loops):
  34. cudaConfigureCall(gridDim,blockDim,0,0)
  35. gpuPOLY(d_X,d_Y,n2)
  36. cudaThreadSynchronize()
  37. t0 = time()-t0
  38. flops = (2.e-9*m1*n2*(psize-1))*float(loops)
  39. cudaMemcpy(g_Y,d_Y,S4*n2,cudaMemcpyDeviceToHost)
  40. cudaThreadSynchronize()
  41. cudaFree(d_X)
  42. cudaFree(d_Y)
  43. cudaThreadExit()
  44. cpuPOLY = eval("cpuPOLY%d" % (mp*psize))
  45. t1 = time()
  46. for i in range(loops):
  47. cpuPOLY(h_X,h_Y)
  48. t1 = time()-t1
  49. print "%10d%6.2f%6.2f" % (vlength,flops/t1,flops/t0)
  50. if checkErrorFlag:
  51. err,mxe = checkError(h_Y,g_Y)
  52. print "Avg and max rel error = %.2e %.2e" % (err,mxe)
  53. if __name__ == "__main__":
  54. import sys
  55. cudaSetDevice(0)
  56. lmin,lmax = 7,23
  57. if len(sys.argv) > 1:
  58. lmin = lmax = int(sys.argv[1])
  59. loopx = -1
  60. if len(sys.argv) > 2:
  61. loopx = int(sys.argv[2])
  62. m1 = 4
  63. if len(sys.argv) > 3:
  64. m1 = min(4,int(sys.argv[3]))
  65. lmax = min(max(0,lmax),23)
  66. lmin = min(max(0,lmin),lmax)
  67. for l in range(lmin,lmax+1):
  68. if l < 10:
  69. loops = 10000/m1
  70. elif l < 13:
  71. loops = 5000/m1
  72. elif l < 17:
  73. loops = 500/m1
  74. elif l < 21:
  75. loops = 250/m1
  76. else:
  77. loops = 100/m1
  78. vlength = 1 << l
  79. if loopx > 0:
  80. loops = loopx
  81. main(vlength,loops,m1)