
/tests/test_util.cpp

https://github.com/jinhou/quda
Possible License(s): BSD-3-Clause
#include <complex>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>     // sqrt, atan2, cos, sin, pow, fabs used below
#include <sys/time.h> // gettimeofday used by the stopwatch functions

#include <short.h>

#if defined(QMP_COMMS)
#include <qmp.h>
#elif defined(MPI_COMMS)
#include <mpi.h>
#endif

#include <wilson_dslash_reference.h>
#include <test_util.h>

#include <face_quda.h>
#include <dslash_quda.h>

#include "misc.h"

using namespace std;

#define XUP 0
#define YUP 1
#define ZUP 2
#define TUP 3

int Z[4];
int V;
int Vh;
int Vs_x, Vs_y, Vs_z, Vs_t;
int Vsh_x, Vsh_y, Vsh_z, Vsh_t;
int faceVolume[4];

// extended volume, +4
int E1, E1h, E2, E3, E4;
int E[4];
int V_ex, Vh_ex;

int Ls;
int V5;
int V5h;

int mySpinorSiteSize;

extern float fat_link_max;

void initComms(int argc, char **argv, const int *commDims)
{
#if defined(QMP_COMMS)
  QMP_thread_level_t tl;
  QMP_init_msg_passing(&argc, &argv, QMP_THREAD_SINGLE, &tl);
#elif defined(MPI_COMMS)
  MPI_Init(&argc, &argv);
#endif
  initCommsGridQuda(4, commDims, NULL, NULL);
  initRand();
}

void finalizeComms()
{
#if defined(QMP_COMMS)
  QMP_finalize_msg_passing();
#elif defined(MPI_COMMS)
  MPI_Finalize();
#endif
}

void initRand()
{
  int rank = 0;

#if defined(QMP_COMMS)
  rank = QMP_get_node_number();
#elif defined(MPI_COMMS)
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
#endif

  srand(17*rank + 137);
}

void setDims(int *X) {
  V = 1;
  for (int d=0; d<4; d++) {
    V *= X[d];
    Z[d] = X[d];

    faceVolume[d] = 1;
    for (int i=0; i<4; i++) {
      if (i==d) continue;
      faceVolume[d] *= X[i];
    }
  }
  Vh = V/2;

  Vs_x = X[1]*X[2]*X[3];
  Vs_y = X[0]*X[2]*X[3];
  Vs_z = X[0]*X[1]*X[3];
  Vs_t = X[0]*X[1]*X[2];

  Vsh_x = Vs_x/2;
  Vsh_y = Vs_y/2;
  Vsh_z = Vs_z/2;
  Vsh_t = Vs_t/2;

  E1 = X[0]+4; E2 = X[1]+4; E3 = X[2]+4; E4 = X[3]+4;
  E1h = E1/2;
  E[0] = E1;
  E[1] = E2;
  E[2] = E3;
  E[3] = E4;
  V_ex = E1*E2*E3*E4;
  Vh_ex = V_ex/2;
}
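
// Worked example (comment added for illustration; not in the original file):
// for a local lattice X = {4, 4, 4, 8}, setDims() yields
//   V = 512, Vh = 256, Vs_t = 4*4*4 = 64, Vsh_t = 32,
//   E[] = {8, 8, 8, 12} (each dimension padded by 2 sites on both ends),
//   V_ex = 8*8*8*12 = 6144, Vh_ex = 3072.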
void dw_setDims(int *X, const int L5)
{
  V = 1;
  for (int d=0; d<4; d++) {
    V *= X[d];
    Z[d] = X[d];

    faceVolume[d] = 1;
    for (int i=0; i<4; i++) {
      if (i==d) continue;
      faceVolume[d] *= X[i];
    }
  }
  Vh = V/2;

  Ls = L5;
  V5 = V*Ls;
  V5h = Vh*Ls;

  Vs_t = Z[0]*Z[1]*Z[2]*Ls; //?
  Vsh_t = Vs_t/2;           //?
}
void setSpinorSiteSize(int n)
{
  mySpinorSiteSize = n;
}

template <typename Float>
static void printVector(Float *v) {
  printfQuda("{(%f %f) (%f %f) (%f %f)}\n", v[0], v[1], v[2], v[3], v[4], v[5]);
}

// X indexes the lattice site
void printSpinorElement(void *spinor, int X, QudaPrecision precision) {
  if (precision == QUDA_DOUBLE_PRECISION)
    for (int s=0; s<4; s++) printVector((double*)spinor+X*24+s*6);
  else
    for (int s=0; s<4; s++) printVector((float*)spinor+X*24+s*6);
}

// X indexes the full lattice
void printGaugeElement(void *gauge, int X, QudaPrecision precision) {
  if (getOddBit(X) == 0) {
    if (precision == QUDA_DOUBLE_PRECISION)
      for (int m=0; m<3; m++) printVector((double*)gauge + (X/2)*gaugeSiteSize + m*3*2);
    else
      for (int m=0; m<3; m++) printVector((float*)gauge + (X/2)*gaugeSiteSize + m*3*2);
  } else {
    if (precision == QUDA_DOUBLE_PRECISION)
      for (int m=0; m<3; m++) printVector((double*)gauge + (X/2+Vh)*gaugeSiteSize + m*3*2);
    else
      for (int m=0; m<3; m++) printVector((float*)gauge + (X/2+Vh)*gaugeSiteSize + m*3*2);
  }
}
// returns 0 or 1 according to whether the full lattice index Y is even or odd
int getOddBit(int Y) {
  int x4 = Y/(Z[2]*Z[1]*Z[0]);
  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
  int x2 = (Y/Z[0]) % Z[1];
  int x1 = Y % Z[0];
  return (x4+x3+x2+x1) % 2;
}
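
// Worked example (comment added for illustration; not in the original file):
// with Z = {4, 4, 4, 4}, the full index Y = 5 decodes to
// (x1, x2, x3, x4) = (1, 1, 0, 0), so getOddBit(5) returns
// (1+1+0+0) % 2 = 0, i.e. an even site.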
// a += b
template <typename Float>
inline void complexAddTo(Float *a, Float *b) {
  a[0] += b[0];
  a[1] += b[1];
}

// a = b*c
template <typename Float>
inline void complexProduct(Float *a, Float *b, Float *c) {
  a[0] = b[0]*c[0] - b[1]*c[1];
  a[1] = b[0]*c[1] + b[1]*c[0];
}

// a = conj(b)*conj(c)
template <typename Float>
inline void complexConjugateProduct(Float *a, Float *b, Float *c) {
  a[0] = b[0]*c[0] - b[1]*c[1];
  a[1] = -b[0]*c[1] - b[1]*c[0];
}

// a = conj(b)*c
template <typename Float>
inline void complexDotProduct(Float *a, Float *b, Float *c) {
  a[0] = b[0]*c[0] + b[1]*c[1];
  a[1] = b[0]*c[1] - b[1]*c[0];
}

// a += sign*b*c
template <typename Float>
inline void accumulateComplexProduct(Float *a, Float *b, Float *c, Float sign) {
  a[0] += sign*(b[0]*c[0] - b[1]*c[1]);
  a[1] += sign*(b[0]*c[1] + b[1]*c[0]);
}

// a += conj(b)*c
template <typename Float>
inline void accumulateComplexDotProduct(Float *a, Float *b, Float *c) {
  a[0] += b[0]*c[0] + b[1]*c[1];
  a[1] += b[0]*c[1] - b[1]*c[0];
}
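
// a += sign*conj(b)*conj(c)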
template <typename Float>
inline void accumulateConjugateProduct(Float *a, Float *b, Float *c, int sign) {
  a[0] += sign * (b[0]*c[0] - b[1]*c[1]);
  a[1] -= sign * (b[0]*c[1] + b[1]*c[0]);
}

template <typename Float>
inline void su3Construct12(Float *mat) {
  Float *w = mat+12;
  w[0] = 0.0;
  w[1] = 0.0;
  w[2] = 0.0;
  w[3] = 0.0;
  w[4] = 0.0;
  w[5] = 0.0;
}

// Stabilized Bunk and Sommer
template <typename Float>
inline void su3Construct8(Float *mat) {
  mat[0] = atan2(mat[1], mat[0]);
  mat[1] = atan2(mat[13], mat[12]);
  for (int i=8; i<18; i++) mat[i] = 0.0;
}

void su3_construct(void *mat, QudaReconstructType reconstruct, QudaPrecision precision) {
  if (reconstruct == QUDA_RECONSTRUCT_12) {
    if (precision == QUDA_DOUBLE_PRECISION) su3Construct12((double*)mat);
    else su3Construct12((float*)mat);
  } else {
    if (precision == QUDA_DOUBLE_PRECISION) su3Construct8((double*)mat);
    else su3Construct8((float*)mat);
  }
}

// given the first two rows (u,v) of an SU(3) matrix mat, reconstruct the third row
// as the cross product of the conjugate vectors: w = u* x v*
//
// 48 flops
template <typename Float>
static void su3Reconstruct12(Float *mat, int dir, int ga_idx, QudaGaugeParam *param) {
  Float *u = &mat[0*(3*2)];
  Float *v = &mat[1*(3*2)];
  Float *w = &mat[2*(3*2)];
  w[0] = 0.0; w[1] = 0.0; w[2] = 0.0; w[3] = 0.0; w[4] = 0.0; w[5] = 0.0;
  accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
  accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
  accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
  accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
  accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
  accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
  Float u0 = (dir < 3 ? param->anisotropy :
              (ga_idx >= (Z[3]-1)*Z[0]*Z[1]*Z[2]/2 ? param->t_boundary : 1));
  w[0]*=u0; w[1]*=u0; w[2]*=u0; w[3]*=u0; w[4]*=u0; w[5]*=u0;
}

template <typename Float>
static void su3Reconstruct8(Float *mat, int dir, int ga_idx, QudaGaugeParam *param) {
  // First reconstruct the first row
  Float row_sum = 0.0;
  row_sum += mat[2]*mat[2];
  row_sum += mat[3]*mat[3];
  row_sum += mat[4]*mat[4];
  row_sum += mat[5]*mat[5];
  Float u0 = (dir < 3 ? param->anisotropy :
              (ga_idx >= (Z[3]-1)*Z[0]*Z[1]*Z[2]/2 ? param->t_boundary : 1));
  Float U00_mag = sqrt(1.f/(u0*u0) - row_sum);

  mat[14] = mat[0];
  mat[15] = mat[1];

  mat[0] = U00_mag * cos(mat[14]);
  mat[1] = U00_mag * sin(mat[14]);

  Float column_sum = 0.0;
  for (int i=0; i<2; i++) column_sum += mat[i]*mat[i];
  for (int i=6; i<8; i++) column_sum += mat[i]*mat[i];
  Float U20_mag = sqrt(1.f/(u0*u0) - column_sum);

  mat[12] = U20_mag * cos(mat[15]);
  mat[13] = U20_mag * sin(mat[15]);
  // First column now restored

  // finally, reconstruct the remaining elements from the SU(2) rotation
  Float r_inv2 = 1.0/(u0*row_sum);

  // U11
  Float A[2];
  complexDotProduct(A, mat+0, mat+6);
  complexConjugateProduct(mat+8, mat+12, mat+4);
  accumulateComplexProduct(mat+8, A, mat+2, u0);
  mat[8] *= -r_inv2;
  mat[9] *= -r_inv2;

  // U12
  complexConjugateProduct(mat+10, mat+12, mat+2);
  accumulateComplexProduct(mat+10, A, mat+4, -u0);
  mat[10] *= r_inv2;
  mat[11] *= r_inv2;

  // U21
  complexDotProduct(A, mat+0, mat+12);
  complexConjugateProduct(mat+14, mat+6, mat+4);
  accumulateComplexProduct(mat+14, A, mat+2, -u0);
  mat[14] *= r_inv2;
  mat[15] *= r_inv2;

  // U22
  complexConjugateProduct(mat+16, mat+6, mat+2);
  accumulateComplexProduct(mat+16, A, mat+4, u0);
  mat[16] *= -r_inv2;
  mat[17] *= -r_inv2;
}

void su3_reconstruct(void *mat, int dir, int ga_idx, QudaReconstructType reconstruct, QudaPrecision precision, QudaGaugeParam *param) {
  if (reconstruct == QUDA_RECONSTRUCT_12) {
    if (precision == QUDA_DOUBLE_PRECISION) su3Reconstruct12((double*)mat, dir, ga_idx, param);
    else su3Reconstruct12((float*)mat, dir, ga_idx, param);
  } else {
    if (precision == QUDA_DOUBLE_PRECISION) su3Reconstruct8((double*)mat, dir, ga_idx, param);
    else su3Reconstruct8((float*)mat, dir, ga_idx, param);
  }
}
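
// Illustrative round-trip sketch (added; not part of the original file).
// For a genuine SU(3) link on a spatial direction with param->anisotropy = 1,
// su3_construct() followed by su3_reconstruct() restores the third row up to
// rounding, since w = u* x v* for a unitary matrix of unit determinant.
#if 0
static void reconstructRoundTripExample(double link[18], QudaGaugeParam *param)
{
  double ref[18];
  memcpy(ref, link, sizeof(ref));
  su3_construct(link, QUDA_RECONSTRUCT_12, QUDA_DOUBLE_PRECISION);                // zero the third row
  su3_reconstruct(link, 0, 0, QUDA_RECONSTRUCT_12, QUDA_DOUBLE_PRECISION, param); // rebuild it
  compare_floats(ref, link, 18, 1e-10, QUDA_DOUBLE_PRECISION);                    // expect a match
}
#endif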
/*
void su3_construct_8_half(float *mat, short *mat_half) {
  su3Construct8(mat);

  mat_half[0] = floatToShort(mat[0] / M_PI);
  mat_half[1] = floatToShort(mat[1] / M_PI);
  for (int i=2; i<18; i++) {
    mat_half[i] = floatToShort(mat[i]);
  }
}

void su3_reconstruct_8_half(float *mat, short *mat_half, int dir, int ga_idx, QudaGaugeParam *param) {
  for (int i=0; i<18; i++) {
    mat[i] = shortToFloat(mat_half[i]);
  }
  mat[0] *= M_PI;
  mat[1] *= M_PI;

  su3Reconstruct8(mat, dir, ga_idx, param);
}
*/

template <typename Float>
static int compareFloats(Float *a, Float *b, int len, double epsilon) {
  for (int i = 0; i < len; i++) {
    double diff = fabs(a[i] - b[i]);
    if (diff > epsilon) {
      printfQuda("ERROR: i=%d, a[%d]=%f, b[%d]=%f\n", i, i, a[i], i, b[i]);
      return 0;
    }
  }
  return 1;
}

int compare_floats(void *a, void *b, int len, double epsilon, QudaPrecision precision) {
  if (precision == QUDA_DOUBLE_PRECISION) return compareFloats((double*)a, (double*)b, len, epsilon);
  else return compareFloats((float*)a, (float*)b, len, epsilon);
}

int fullLatticeIndex(int dim[4], int index, int oddBit){
  int za = index/(dim[0]>>1);
  int zb = za/dim[1];
  int x2 = za - zb*dim[1];
  int x4 = zb/dim[2];
  int x3 = zb - x4*dim[2];

  return 2*index + ((x2 + x3 + x4 + oddBit) & 1);
}
// given a "half index" i into either an even or odd half lattice (corresponding
// to oddBit = {0, 1}), returns the corresponding full lattice index.
int fullLatticeIndex(int i, int oddBit) {
  /*
    int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2);
    return 2*i + (boundaryCrossings + oddBit) % 2;
  */

  int X1 = Z[0];
  int X2 = Z[1];
  int X3 = Z[2];
  //int X4 = Z[3];
  int X1h = X1/2;

  int sid = i;
  int za = sid/X1h;
  //int x1h = sid - za*X1h;
  int zb = za/X2;
  int x2 = za - zb*X2;
  int x4 = zb/X3;
  int x3 = zb - x4*X3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  //int x1 = 2*x1h + x1odd;
  int X = 2*sid + x1odd;

  return X;
}
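
// Worked example (comment added for illustration; not in the original file):
// with Z = {4, 4, 4, 4} and X1h = 2, the half index i = 2 on the odd
// sublattice (oddBit = 1) gives za = 1 and (x2, x3, x4) = (1, 0, 0), so
// x1odd = (1+0+0+1) & 1 = 0 and X = 2*2 + 0 = 4, i.e. the odd site (0,1,0,0).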
// i represents a "half index" into an even or odd "half lattice".
// when oddBit={0,1} the half lattice is {even,odd}.
//
// the displacements, such as dx, refer to the full lattice coordinates.
//
// neighborIndex() takes a "half index", displaces it, and returns the
// new "half index", which can be an index into either the even or odd lattices.
// displacements of magnitude one always interchange odd and even lattices.
//
int neighborIndex(int i, int oddBit, int dx4, int dx3, int dx2, int dx1) {
  int Y = fullLatticeIndex(i, oddBit);
  int x4 = Y/(Z[2]*Z[1]*Z[0]);
  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
  int x2 = (Y/Z[0]) % Z[1];
  int x1 = Y % Z[0];

  // assert (oddBit == (x+y+z+t)%2);

  x4 = (x4+dx4+Z[3]) % Z[3];
  x3 = (x3+dx3+Z[2]) % Z[2];
  x2 = (x2+dx2+Z[1]) % Z[1];
  x1 = (x1+dx1+Z[0]) % Z[0];

  return (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
}
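
// Worked example (comment added for illustration; not in the original file):
// with Z = {4, 4, 4, 4}, neighborIndex(0, 0, 0, 0, 0, +1) starts from the
// even site at the origin, hops one step in +x to the full-lattice site
// (1,0,0,0), and returns half index 0, now interpreted on the odd
// sublattice: a unit displacement always flips the parity, as noted above.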
int neighborIndex(int dim[4], int index, int oddBit, int dx[4]){
  const int fullIndex = fullLatticeIndex(dim, index, oddBit);

  int x[4];
  x[3] = fullIndex/(dim[2]*dim[1]*dim[0]);
  x[2] = (fullIndex/(dim[1]*dim[0])) % dim[2];
  x[1] = (fullIndex/dim[0]) % dim[1];
  x[0] = fullIndex % dim[0];

  for(int dir=0; dir<4; ++dir)
    x[dir] = (x[dir]+dx[dir]+dim[dir]) % dim[dir];

  return (((x[3]*dim[2] + x[2])*dim[1] + x[1])*dim[0] + x[0])/2;
}

int neighborIndex_mg(int i, int oddBit, int dx4, int dx3, int dx2, int dx1)
{
  int ret;

  int Y = fullLatticeIndex(i, oddBit);
  int x4 = Y/(Z[2]*Z[1]*Z[0]);
  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
  int x2 = (Y/Z[0]) % Z[1];
  int x1 = Y % Z[0];
  int ghost_x4 = x4 + dx4;

  // assert (oddBit == (x+y+z+t)%2);

  x4 = (x4+dx4+Z[3]) % Z[3];
  x3 = (x3+dx3+Z[2]) % Z[2];
  x2 = (x2+dx2+Z[1]) % Z[1];
  x1 = (x1+dx1+Z[0]) % Z[0];

  if (ghost_x4 >= 0 && ghost_x4 < Z[3]){
    ret = (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
  }else{
    ret = (x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
  }

  return ret;
}

/*
 * This computes the neighbor of a site given the full lattice index and the
 * displacement in each direction.
 */
int neighborIndexFullLattice(int i, int dx4, int dx3, int dx2, int dx1)
{
  int oddBit = 0;
  int half_idx = i;
  if (i >= Vh){
    oddBit = 1;
    half_idx = i - Vh;
  }

  int nbr_half_idx = neighborIndex(half_idx, oddBit, dx4, dx3, dx2, dx1);
  int oddBitChanged = (dx4+dx3+dx2+dx1) % 2;
  if (oddBitChanged){
    oddBit = 1 - oddBit;
  }
  int ret = nbr_half_idx;
  if (oddBit){
    ret = Vh + nbr_half_idx;
  }

  return ret;
}

int neighborIndexFullLattice(int dim[4], int index, int dx[4])
{
  const int volume = dim[0]*dim[1]*dim[2]*dim[3];
  const int halfVolume = volume/2;

  int oddBit = 0;
  int halfIndex = index;
  if(index >= halfVolume){
    oddBit = 1;
    halfIndex = index - halfVolume;
  }

  int neighborHalfIndex = neighborIndex(dim, halfIndex, oddBit, dx);

  int oddBitChanged = (dx[0]+dx[1]+dx[2]+dx[3]) % 2;
  if(oddBitChanged){
    oddBit = 1 - oddBit;
  }

  return neighborHalfIndex + oddBit*halfVolume;
}

int neighborIndexFullLattice_mg(int i, int dx4, int dx3, int dx2, int dx1)
{
  int ret;
  int oddBit = 0;
  int half_idx = i;
  if (i >= Vh){
    oddBit = 1;
    half_idx = i - Vh;
  }

  int Y = fullLatticeIndex(half_idx, oddBit);
  int x4 = Y/(Z[2]*Z[1]*Z[0]);
  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
  int x2 = (Y/Z[0]) % Z[1];
  int x1 = Y % Z[0];
  int ghost_x4 = x4 + dx4;

  x4 = (x4+dx4+Z[3]) % Z[3];
  x3 = (x3+dx3+Z[2]) % Z[2];
  x2 = (x2+dx2+Z[1]) % Z[1];
  x1 = (x1+dx1+Z[0]) % Z[0];

  if (ghost_x4 >= 0 && ghost_x4 < Z[3]){
    ret = (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
  }else{
    ret = (x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
    return ret;
  }

  int oddBitChanged = (dx4+dx3+dx2+dx1) % 2;
  if (oddBitChanged){
    oddBit = 1 - oddBit;
  }

  if (oddBit){
    ret += Vh;
  }

  return ret;
}

// 4d checkerboard.
// given a "half index" i into either an even or odd half lattice (corresponding
// to oddBit = {0, 1}), returns the corresponding full lattice index.
// Cf. GPGPU code in dslash_core_ante.h.
// There, i is the thread index.
int fullLatticeIndex_4d(int i, int oddBit) {
  if (i >= Vh || i < 0) { printf("i out of range in fullLatticeIndex_4d\n"); exit(-1); }
  /*
    int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2);
    return 2*i + (boundaryCrossings + oddBit) % 2;
  */

  int X1 = Z[0];
  int X2 = Z[1];
  int X3 = Z[2];
  //int X4 = Z[3];
  int X1h = X1/2;

  int sid = i;
  int za = sid/X1h;
  //int x1h = sid - za*X1h;
  int zb = za/X2;
  int x2 = za - zb*X2;
  int x4 = zb/X3;
  int x3 = zb - x4*X3;
  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
  //int x1 = 2*x1h + x1odd;
  int X = 2*sid + x1odd;

  return X;
}

// 5d checkerboard.
// given a "half index" i into either an even or odd half lattice (corresponding
// to oddBit = {0, 1}), returns the corresponding full lattice index.
// Cf. GPGPU code in dslash_core_ante.h.
// There, i is the thread index sid.
// This function is used by neighborIndex_5d in dslash_reference.cpp.
//ok
int fullLatticeIndex_5d(int i, int oddBit) {
  int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2) + i/(Z[3]*Z[2]*Z[1]*Z[0]/2);
  return 2*i + (boundaryCrossings + oddBit) % 2;
}
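
// Note (comment added for illustration; not in the original file): each term
// of boundaryCrossings counts, roughly, how many row boundaries the half
// index has crossed in one dimension; every crossing flips the checkerboard
// colour of the row that follows, hence the parity correction
// (boundaryCrossings + oddBit) % 2.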
int x4_from_full_index(int i)
{
  int oddBit = 0;
  int half_idx = i;
  if (i >= Vh){
    oddBit = 1;
    half_idx = i - Vh;
  }

  int Y = fullLatticeIndex(half_idx, oddBit);
  int x4 = Y/(Z[2]*Z[1]*Z[0]);

  return x4;
}

template <typename Float>
static void applyGaugeFieldScaling(Float **gauge, int Vh, QudaGaugeParam *param) {
  // Apply spatial scaling factor (u0) to spatial links
  for (int d = 0; d < 3; d++) {
    for (int i = 0; i < gaugeSiteSize*Vh*2; i++) {
      gauge[d][i] /= param->anisotropy;
    }
  }

  // only apply T-boundary at edge nodes
#ifdef MULTI_GPU
  bool last_node_in_t = (commCoords(3) == commDim(3)-1) ? true : false;
#else
  bool last_node_in_t = true;
#endif

  // Apply boundary conditions to temporal links
  if (param->t_boundary == QUDA_ANTI_PERIODIC_T && last_node_in_t) {
    for (int j = (Z[0]/2)*Z[1]*Z[2]*(Z[3]-1); j < Vh; j++) {
      for (int i = 0; i < gaugeSiteSize; i++) {
        gauge[3][j*gaugeSiteSize+i] *= -1.0;
        gauge[3][(Vh+j)*gaugeSiteSize+i] *= -1.0;
      }
    }
  }

  if (param->gauge_fix) {
    // set all gauge links (except for the last Z[0]*Z[1]*Z[2]/2) to the identity,
    // to simulate fixing to the temporal gauge.
    int iMax = ( last_node_in_t ? (Z[0]/2)*Z[1]*Z[2]*(Z[3]-1) : Vh );
    int dir = 3; // time direction only
    Float *even = gauge[dir];
    Float *odd  = gauge[dir]+Vh*gaugeSiteSize;
    for (int i = 0; i < iMax; i++) {
      for (int m = 0; m < 3; m++) {
        for (int n = 0; n < 3; n++) {
          even[i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
          even[i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
          odd [i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
          odd [i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
        }
      }
    }
  }
}

template <typename Float>
void applyGaugeFieldScaling_long(Float **gauge, int Vh, QudaGaugeParam *param)
{
  int X1h = param->X[0]/2;
  int X1  = param->X[0];
  int X2  = param->X[1];
  int X3  = param->X[2];
  int X4  = param->X[3];

  // rescale long links by the appropriate coefficient
  for(int d=0; d<4; d++){
    for(int i=0; i < V*gaugeSiteSize; i++){
      gauge[d][i] /= (-24*param->tadpole_coeff*param->tadpole_coeff);
    }
  }

  // apply the staggered phases
  for (int d = 0; d < 3; d++) {

    //even
    for (int i = 0; i < Vh; i++) {

      int index = fullLatticeIndex(i, 0);
      int i4 = index/(X3*X2*X1);
      int i3 = (index - i4*(X3*X2*X1))/(X2*X1);
      int i2 = (index - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
      int i1 = index - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;
      int sign = 1;

      if (d == 0) {
        if (i4 % 2 == 1){
          sign = -1;
        }
      }

      if (d == 1){
        if ((i4+i1) % 2 == 1){
          sign = -1;
        }
      }

      if (d == 2){
        if ((i4+i1+i2) % 2 == 1){
          sign = -1;
        }
      }

      for (int j=0; j < 6; j++){
        gauge[d][i*gaugeSiteSize + 12 + j] *= sign;
      }
    }

    //odd
    for (int i = 0; i < Vh; i++) {
      int index = fullLatticeIndex(i, 1);
      int i4 = index/(X3*X2*X1);
      int i3 = (index - i4*(X3*X2*X1))/(X2*X1);
      int i2 = (index - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
      int i1 = index - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;
      int sign = 1;

      if (d == 0) {
        if (i4 % 2 == 1){
          sign = -1;
        }
      }

      if (d == 1){
        if ((i4+i1) % 2 == 1){
          sign = -1;
        }
      }

      if (d == 2){
        if ((i4+i1+i2) % 2 == 1){
          sign = -1;
        }
      }

      for (int j=0; j < 6; j++){
        gauge[d][(Vh+i)*gaugeSiteSize + 12 + j] *= sign;
      }
    }
  }

  // Apply boundary conditions to temporal links
  if (param->t_boundary == QUDA_ANTI_PERIODIC_T) {
    for (int j = 0; j < Vh; j++) {
      int sign = 1;
      if (j >= (X4-3)*X1h*X2*X3){
        sign = -1;
      }

      for (int i = 0; i < 6; i++) {
        gauge[3][j*gaugeSiteSize + 12 + i] *= sign;
        gauge[3][(Vh+j)*gaugeSiteSize + 12 + i] *= sign;
      }
    }
  }
}
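
// Note (comment added for illustration; not in the original file): the sign
// pattern above appears to follow the standard staggered phase convention
// with the time coordinate ordered first,
//   eta_x(n) = (-1)^{n_t}, eta_y(n) = (-1)^{n_t+n_x}, eta_z(n) = (-1)^{n_t+n_x+n_y},
// and the anti-periodic temporal boundary flips the last three time slices,
// consistent with long links spanning three sites in the t direction.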
template <typename Float>
static void constructUnitGaugeField(Float **res, QudaGaugeParam *param) {
  Float *resOdd[4], *resEven[4];
  for (int dir = 0; dir < 4; dir++) {
    resEven[dir] = res[dir];
    resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
  }

  for (int dir = 0; dir < 4; dir++) {
    for (int i = 0; i < Vh; i++) {
      for (int m = 0; m < 3; m++) {
        for (int n = 0; n < 3; n++) {
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
        }
      }
    }
  }

  applyGaugeFieldScaling(res, Vh, param);
}

// normalize the vector a
template <typename Float>
static void normalize(complex<Float> *a, int len) {
  double sum = 0.0;
  for (int i=0; i<len; i++) sum += norm(a[i]);
  for (int i=0; i<len; i++) a[i] /= sqrt(sum);
}

// orthogonalize vector b to vector a
template <typename Float>
static void orthogonalize(complex<Float> *a, complex<Float> *b, int len) {
  complex<double> dot = 0.0;
  for (int i=0; i<len; i++) dot += conj(a[i])*b[i];
  for (int i=0; i<len; i++) b[i] -= (complex<Float>)dot*a[i];
}

template <typename Float>
static void constructGaugeField(Float **res, QudaGaugeParam *param) {
  Float *resOdd[4], *resEven[4];
  for (int dir = 0; dir < 4; dir++) {
    resEven[dir] = res[dir];
    resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
  }

  for (int dir = 0; dir < 4; dir++) {
    for (int i = 0; i < Vh; i++) {
      for (int m = 1; m < 3; m++) { // last 2 rows
        for (int n = 0; n < 3; n++) { // 3 columns
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
        }
      }
      normalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), 3);
      orthogonalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), (complex<Float>*)(resEven[dir] + (i*3+2)*3*2), 3);
      normalize((complex<Float>*)(resEven[dir] + (i*3+2)*3*2), 3);

      normalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), 3);
      orthogonalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), (complex<Float>*)(resOdd[dir] + (i*3+2)*3*2), 3);
      normalize((complex<Float>*)(resOdd[dir] + (i*3+2)*3*2), 3);

      {
        Float *w = resEven[dir]+(i*3+0)*3*2;
        Float *u = resEven[dir]+(i*3+1)*3*2;
        Float *v = resEven[dir]+(i*3+2)*3*2;

        for (int n = 0; n < 6; n++) w[n] = 0.0;
        accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
        accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
        accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
        accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
        accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
        accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
      }

      {
        Float *w = resOdd[dir]+(i*3+0)*3*2;
        Float *u = resOdd[dir]+(i*3+1)*3*2;
        Float *v = resOdd[dir]+(i*3+2)*3*2;

        for (int n = 0; n < 6; n++) w[n] = 0.0;
        accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
        accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
        accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
        accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
        accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
        accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
      }
    }
  }

  if (param->type == QUDA_WILSON_LINKS){
    applyGaugeFieldScaling(res, Vh, param);
  } else if (param->type == QUDA_ASQTAD_LONG_LINKS){
    applyGaugeFieldScaling_long(res, Vh, param);
  } else if (param->type == QUDA_ASQTAD_FAT_LINKS){
    for (int dir = 0; dir < 4; dir++){
      for (int i = 0; i < Vh; i++) {
        for (int m = 0; m < 3; m++) { // all 3 rows
          for (int n = 0; n < 3; n++) { // 3 columns
            resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = 1.0*rand() / (Float)RAND_MAX;
            resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 2.0*rand() / (Float)RAND_MAX;
            resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0]  = 3.0*rand() / (Float)RAND_MAX;
            resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1]  = 4.0*rand() / (Float)RAND_MAX;
          }
        }
      }
    }
  }
}

template <typename Float>
void constructUnitaryGaugeField(Float **res)
{
  Float *resOdd[4], *resEven[4];
  for (int dir = 0; dir < 4; dir++) {
    resEven[dir] = res[dir];
    resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
  }

  for (int dir = 0; dir < 4; dir++) {
    for (int i = 0; i < Vh; i++) {
      for (int m = 1; m < 3; m++) { // last 2 rows
        for (int n = 0; n < 3; n++) { // 3 columns
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
          resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
          resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
        }
      }
      normalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), 3);
      orthogonalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), (complex<Float>*)(resEven[dir] + (i*3+2)*3*2), 3);
      normalize((complex<Float>*)(resEven[dir] + (i*3+2)*3*2), 3);

      normalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), 3);
      orthogonalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), (complex<Float>*)(resOdd[dir] + (i*3+2)*3*2), 3);
      normalize((complex<Float>*)(resOdd[dir] + (i*3+2)*3*2), 3);

      {
        Float *w = resEven[dir]+(i*3+0)*3*2;
        Float *u = resEven[dir]+(i*3+1)*3*2;
        Float *v = resEven[dir]+(i*3+2)*3*2;

        for (int n = 0; n < 6; n++) w[n] = 0.0;
        accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
        accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
        accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
        accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
        accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
        accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
      }

      {
        Float *w = resOdd[dir]+(i*3+0)*3*2;
        Float *u = resOdd[dir]+(i*3+1)*3*2;
        Float *v = resOdd[dir]+(i*3+2)*3*2;

        for (int n = 0; n < 6; n++) w[n] = 0.0;
        accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
        accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
        accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
        accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
        accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
        accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
      }
    }
  }
}
void construct_gauge_field(void **gauge, int type, QudaPrecision precision, QudaGaugeParam *param) {
  if (type == 0) {
    if (precision == QUDA_DOUBLE_PRECISION) constructUnitGaugeField((double**)gauge, param);
    else constructUnitGaugeField((float**)gauge, param);
  } else if (type == 1) {
    if (precision == QUDA_DOUBLE_PRECISION) constructGaugeField((double**)gauge, param);
    else constructGaugeField((float**)gauge, param);
  } else {
    if (precision == QUDA_DOUBLE_PRECISION) applyGaugeFieldScaling((double**)gauge, Vh, param);
    else applyGaugeFieldScaling((float**)gauge, Vh, param);
  }
}

void construct_fat_long_gauge_field(void **fatlink, void **longlink,
                                    int type, QudaPrecision precision, QudaGaugeParam *param)
{
  if (type == 0) {
    if (precision == QUDA_DOUBLE_PRECISION) {
      constructUnitGaugeField((double**)fatlink, param);
      constructUnitGaugeField((double**)longlink, param);
    }else {
      constructUnitGaugeField((float**)fatlink, param);
      constructUnitGaugeField((float**)longlink, param);
    }
  } else {
    if (precision == QUDA_DOUBLE_PRECISION) {
      param->type = QUDA_ASQTAD_FAT_LINKS;
      constructGaugeField((double**)fatlink, param);
      param->type = QUDA_ASQTAD_LONG_LINKS;
      constructGaugeField((double**)longlink, param);
    }else {
      param->type = QUDA_ASQTAD_FAT_LINKS;
      constructGaugeField((float**)fatlink, param);
      param->type = QUDA_ASQTAD_LONG_LINKS;
      constructGaugeField((float**)longlink, param);
    }
  }
}

template <typename Float>
static void constructCloverField(Float *res, double norm, double diag) {
  Float c = 2.0 * norm / RAND_MAX;

  for(int i = 0; i < V; i++) {
    for (int j = 0; j < 72; j++) {
      res[i*72 + j] = c*rand() - norm;
    }
    for (int j = 0; j < 6; j++) {
      res[i*72 + j] += diag;
      res[i*72 + j+36] += diag;
    }
  }
}

void construct_clover_field(void *clover, double norm, double diag, QudaPrecision precision) {
  if (precision == QUDA_DOUBLE_PRECISION) constructCloverField((double *)clover, norm, diag);
  else constructCloverField((float *)clover, norm, diag);
}

/*void strong_check(void *spinorRef, void *spinorGPU, int len, QudaPrecision prec) {
  printf("Reference:\n");
  printSpinorElement(spinorRef, 0, prec); printf("...\n");
  printSpinorElement(spinorRef, len-1, prec); printf("\n");

  printf("\nCUDA:\n");
  printSpinorElement(spinorGPU, 0, prec); printf("...\n");
  printSpinorElement(spinorGPU, len-1, prec); printf("\n");

  compare_spinor(spinorRef, spinorGPU, len, prec);
}*/
template <typename Float>
static void checkGauge(Float **oldG, Float **newG, double epsilon) {
  const int fail_check = 17;
  int fail[4][fail_check];
  int iter[4][18];
  for (int d=0; d<4; d++) for (int i=0; i<fail_check; i++) fail[d][i] = 0;
  for (int d=0; d<4; d++) for (int i=0; i<18; i++) iter[d][i] = 0;

  for (int d=0; d<4; d++) {
    for (int eo=0; eo<2; eo++) {
      for (int i=0; i<Vh; i++) {
        int ga_idx = (eo*Vh+i);
        for (int j=0; j<18; j++) {
          double diff = fabs(newG[d][ga_idx*18+j] - oldG[d][ga_idx*18+j]); // / fabs(oldG[d][ga_idx*18+j]);

          for (int f=0; f<fail_check; f++) if (diff > pow(10.0,-(f+1))) fail[d][f]++;
          if (diff > epsilon) iter[d][j]++;
        }
      }
    }
  }

  printf("Component fails (X, Y, Z, T)\n");
  for (int i=0; i<18; i++) printf("%d fails = (%8d, %8d, %8d, %8d)\n", i, iter[0][i], iter[1][i], iter[2][i], iter[3][i]);

  printf("\nDeviation Failures = (X, Y, Z, T)\n");
  for (int f=0; f<fail_check; f++) {
    printf("%e Failures = (%9d, %9d, %9d, %9d) = (%6.5f, %6.5f, %6.5f, %6.5f)\n", pow(10.0,-(f+1)),
           fail[0][f], fail[1][f], fail[2][f], fail[3][f],
           fail[0][f]/(double)(V*18), fail[1][f]/(double)(V*18), fail[2][f]/(double)(V*18), fail[3][f]/(double)(V*18));
  }
}

void check_gauge(void **oldG, void **newG, double epsilon, QudaPrecision precision) {
  if (precision == QUDA_DOUBLE_PRECISION)
    checkGauge((double**)oldG, (double**)newG, epsilon);
  else
    checkGauge((float**)oldG, (float**)newG, epsilon);
}
void createSiteLinkCPU(void **link, QudaPrecision precision, int phase)
{
  if (precision == QUDA_DOUBLE_PRECISION) {
    constructUnitaryGaugeField((double**)link);
  }else {
    constructUnitaryGaugeField((float**)link);
  }

  // only apply the temporal boundary condition if I'm the last node in T
#ifdef MULTI_GPU
  bool last_node_in_t = (commCoords(3) == commDim(3)-1) ? true : false;
#else
  bool last_node_in_t = true;
#endif

  if(phase){

    for(int i=0; i < V; i++){
      for(int dir=XUP; dir <= TUP; dir++){

        int idx = i;
        int oddBit = 0;
        if (i >= Vh) {
          idx = i - Vh;
          oddBit = 1;
        }

        int X1 = Z[0];
        int X2 = Z[1];
        int X3 = Z[2];
        int X4 = Z[3];

        int full_idx = fullLatticeIndex(idx, oddBit);
        int i4 = full_idx/(X3*X2*X1);
        int i3 = (full_idx - i4*(X3*X2*X1))/(X2*X1);
        int i2 = (full_idx - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
        int i1 = full_idx - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;

        double coeff = 1.0;
        switch(dir){
        case XUP:
          if ((i4 & 1) != 0){
            coeff *= -1;
          }
          break;
        case YUP:
          if (((i4+i1) & 1) != 0){
            coeff *= -1;
          }
          break;
        case ZUP:
          if (((i4+i1+i2) & 1) != 0){
            coeff *= -1;
          }
          break;
        case TUP:
          if (last_node_in_t && i4 == (X4-1)){
            coeff *= -1;
          }
          break;
        default:
          printf("ERROR: wrong dir(%d)\n", dir);
          exit(1);
        }

        if (precision == QUDA_DOUBLE_PRECISION){
          //double* mylink = (double*)link;
          //mylink = mylink + (4*i + dir)*gaugeSiteSize;
          double *mylink = (double*)link[dir];
          mylink = mylink + i*gaugeSiteSize;

          mylink[12] *= coeff;
          mylink[13] *= coeff;
          mylink[14] *= coeff;
          mylink[15] *= coeff;
          mylink[16] *= coeff;
          mylink[17] *= coeff;
        }else{
          //float* mylink = (float*)link;
          //mylink = mylink + (4*i + dir)*gaugeSiteSize;
          float *mylink = (float*)link[dir];
          mylink = mylink + i*gaugeSiteSize;

          mylink[12] *= coeff;
          mylink[13] *= coeff;
          mylink[14] *= coeff;
          mylink[15] *= coeff;
          mylink[16] *= coeff;
          mylink[17] *= coeff;
        }
      }
    }
  }

#if 1
  // sanity check: flag NaNs and absurdly large entries
  for(int dir=0; dir < 4; dir++){
    for(int i=0; i < V*gaugeSiteSize; i++){
      if (precision == QUDA_SINGLE_PRECISION){
        float *f = (float*)link[dir];
        if (f[i] != f[i] || (fabsf(f[i]) > 1.e+3)){
          fprintf(stderr, "ERROR: %dth: bad number(%f) in function %s \n", i, f[i], __FUNCTION__);
          exit(1);
        }
      }else{
        double *f = (double*)link[dir];
        if (f[i] != f[i] || (fabs(f[i]) > 1.e+3)){
          fprintf(stderr, "ERROR: %dth: bad number(%f) in function %s \n", i, f[i], __FUNCTION__);
          exit(1);
        }
      }
    }
  }
#endif

  return;
}
template <typename Float>
int compareLink(Float **linkA, Float **linkB, int len) {
  const int fail_check = 16;
  int fail[fail_check];
  for (int f=0; f<fail_check; f++) fail[f] = 0;

  int iter[18];
  for (int i=0; i<18; i++) iter[i] = 0;

  for(int dir=0; dir < 4; dir++){
    for (int i=0; i<len; i++) {
      for (int j=0; j<18; j++) {
        int is = i*18+j;
        double diff = fabs(linkA[dir][is]-linkB[dir][is]);
        for (int f=0; f<fail_check; f++)
          if (diff > pow(10.0,-(f+1))) fail[f]++;
        //if (diff > 1e-1) printf("%d %d %e\n", i, j, diff);
        if (diff > 1e-3) iter[j]++;
      }
    }
  }

  for (int i=0; i<18; i++) printfQuda("%d fails = %d\n", i, iter[i]);

  int accuracy_level = 0;
  for(int f=0; f < fail_check; f++){
    if(fail[f] == 0){
      accuracy_level = f;
    }
  }

  for (int f=0; f<fail_check; f++) {
    printfQuda("%e Failures: %d / %d = %e\n", pow(10.0,-(f+1)), fail[f], 4*len*18, fail[f] / (double)(4*len*18));
  }

  return accuracy_level;
}

static int compare_link(void **linkA, void **linkB, int len, QudaPrecision precision)
{
  int ret;
  if (precision == QUDA_DOUBLE_PRECISION){
    ret = compareLink((double**)linkA, (double**)linkB, len);
  }else {
    ret = compareLink((float**)linkA, (float**)linkB, len);
  }

  return ret;
}

// X indexes the lattice site
static void printLinkElement(void *link, int X, QudaPrecision precision)
{
  if (precision == QUDA_DOUBLE_PRECISION){
    for(int i=0; i < 3; i++){
      printVector((double*)link + X*gaugeSiteSize + i*6);
    }
  }
  else{
    for(int i=0; i < 3; i++){
      printVector((float*)link + X*gaugeSiteSize + i*6);
    }
  }
}

int strong_check_link(void **linkA, const char *msgA,
                      void **linkB, const char *msgB,
                      int len, QudaPrecision prec)
{
  printfQuda("%s\n", msgA);
  printLinkElement(linkA[0], 0, prec);
  printfQuda("\n");
  printLinkElement(linkA[0], 1, prec);
  printfQuda("...\n");
  printLinkElement(linkA[3], len-1, prec);
  printfQuda("\n");

  printfQuda("\n%s\n", msgB);
  printLinkElement(linkB[0], 0, prec);
  printfQuda("\n");
  printLinkElement(linkB[0], 1, prec);
  printfQuda("...\n");
  printLinkElement(linkB[3], len-1, prec);
  printfQuda("\n");

  int ret = compare_link(linkA, linkB, len, prec);
  return ret;
}
void createMomCPU(void *mom, QudaPrecision precision)
{
  void *temp;

  size_t gSize = (precision == QUDA_DOUBLE_PRECISION) ? sizeof(double) : sizeof(float);
  temp = malloc(4*V*gaugeSiteSize*gSize);
  if (temp == NULL){
    fprintf(stderr, "Error: malloc failed for temp in function %s\n", __FUNCTION__);
    exit(1);
  }

  for(int i=0; i < V; i++){
    if (precision == QUDA_DOUBLE_PRECISION){
      for(int dir=0; dir < 4; dir++){
        double *thismom = (double*)mom;
        for(int k=0; k < momSiteSize; k++){
          thismom[(4*i+dir)*momSiteSize + k] = 1.0*rand()/RAND_MAX;
          if (k == momSiteSize-1) thismom[(4*i+dir)*momSiteSize + k] = 0.0;
        }
      }
    }else{
      for(int dir=0; dir < 4; dir++){
        float *thismom = (float*)mom;
        for(int k=0; k < momSiteSize; k++){
          thismom[(4*i+dir)*momSiteSize + k] = 1.0*rand()/RAND_MAX;
          if (k == momSiteSize-1) thismom[(4*i+dir)*momSiteSize + k] = 0.0;
        }
      }
    }
  }

  free(temp);
  return;
}

void createHwCPU(void *hw, QudaPrecision precision)
{
  for(int i=0; i < V; i++){
    if (precision == QUDA_DOUBLE_PRECISION){
      for(int dir=0; dir < 4; dir++){
        double *thishw = (double*)hw;
        for(int k=0; k < hwSiteSize; k++){
          thishw[(4*i+dir)*hwSiteSize + k] = 1.0*rand()/RAND_MAX;
        }
      }
    }else{
      for(int dir=0; dir < 4; dir++){
        float *thishw = (float*)hw;
        for(int k=0; k < hwSiteSize; k++){
          thishw[(4*i+dir)*hwSiteSize + k] = 1.0*rand()/RAND_MAX;
        }
      }
    }
  }

  return;
}

template <typename Float>
int compare_mom(Float *momA, Float *momB, int len) {
  const int fail_check = 16;
  int fail[fail_check];
  for (int f=0; f<fail_check; f++) fail[f] = 0;

  int iter[momSiteSize];
  for (int i=0; i<momSiteSize; i++) iter[i] = 0;

  for (int i=0; i<len; i++) {
    for (int j=0; j<momSiteSize; j++) {
      int is = i*momSiteSize+j;
      double diff = fabs(momA[is]-momB[is]);
      for (int f=0; f<fail_check; f++)
        if (diff > pow(10.0,-(f+1))) fail[f]++;
      //if (diff > 1e-1) printf("%d %d %e\n", i, j, diff);
      if (diff > 1e-3) iter[j]++;
    }
  }

  int accuracy_level = 0;
  for(int f=0; f < fail_check; f++){
    if(fail[f] == 0){
      accuracy_level = f+1;
    }
  }

  for (int i=0; i<momSiteSize; i++) printfQuda("%d fails = %d\n", i, iter[i]);

  for (int f=0; f<fail_check; f++) {
    printfQuda("%e Failures: %d / %d = %e\n", pow(10.0,-(f+1)), fail[f], len*9, fail[f]/(double)(len*9));
  }

  return accuracy_level;
}

static void printMomElement(void *mom, int X, QudaPrecision precision)
{
  if (precision == QUDA_DOUBLE_PRECISION){
    double *thismom = ((double*)mom) + X*momSiteSize;
    printVector(thismom);
    printfQuda("(%9f,%9f) (%9f,%9f)\n", thismom[6], thismom[7], thismom[8], thismom[9]);
  }else{
    float *thismom = ((float*)mom) + X*momSiteSize;
    printVector(thismom);
    printfQuda("(%9f,%9f) (%9f,%9f)\n", thismom[6], thismom[7], thismom[8], thismom[9]);
  }
}

int strong_check_mom(void *momA, void *momB, int len, QudaPrecision prec)
{
  printfQuda("mom:\n");
  printMomElement(momA, 0, prec);
  printfQuda("\n");
  printMomElement(momA, 1, prec);
  printfQuda("\n");
  printMomElement(momA, 2, prec);
  printfQuda("\n");
  printMomElement(momA, 3, prec);
  printfQuda("...\n");

  printfQuda("\nreference mom:\n");
  printMomElement(momB, 0, prec);
  printfQuda("\n");
  printMomElement(momB, 1, prec);
  printfQuda("\n");
  printMomElement(momB, 2, prec);
  printfQuda("\n");
  printMomElement(momB, 3, prec);
  printfQuda("\n");

  int ret;
  if (prec == QUDA_DOUBLE_PRECISION){
    ret = compare_mom((double*)momA, (double*)momB, len);
  }else{
    ret = compare_mom((float*)momA, (float*)momB, len);
  }

  return ret;
}
/************
 * return value
 *
 * 0: command line option matched and processed successfully
 * non-zero: command line option does not match
 *
 */
#ifdef MULTI_GPU
int device = -1;
#else
int device = 0;
#endif

QudaReconstructType link_recon = QUDA_RECONSTRUCT_NO;
QudaReconstructType link_recon_sloppy = QUDA_RECONSTRUCT_INVALID;
QudaPrecision prec = QUDA_SINGLE_PRECISION;
QudaPrecision prec_sloppy = QUDA_INVALID_PRECISION;
int xdim = 24;
int ydim = 24;
int zdim = 24;
int tdim = 24;
int Lsdim = 16;
QudaDagType dagger = QUDA_DAG_NO;
int gridsize_from_cmdline[4] = {1,1,1,1};
QudaDslashType dslash_type = QUDA_WILSON_DSLASH;
char latfile[256] = "";
bool tune = true;
int niter = 10;
int test_type = 0;
bool verify_results = true;

static int dim_partitioned[4] = {0,0,0,0};

int dimPartitioned(int dim)
{
  return ((gridsize_from_cmdline[dim] > 1) || dim_partitioned[dim]);
}

void __attribute__((weak)) usage_extra(char **argv){}

void usage(char **argv)
{
  printf("Usage: %s [options]\n", argv[0]);
  printf("Common options:\n");
#ifndef MULTI_GPU
  printf(" --device <n> # Set the CUDA device to use (default 0, single GPU only)\n");
#endif
  printf(" --prec <double/single/half> # Precision in GPU\n");
  printf(" --prec_sloppy <double/single/half> # Sloppy precision in GPU\n");
  printf(" --recon <8/9/12/13/18> # Link reconstruction type\n");
  printf(" --recon_sloppy <8/9/12/13/18> # Sloppy link reconstruction type\n");
  printf(" --dagger # Set the dagger to 1 (default 0)\n");
  printf(" --sdim <n> # Set spatial dimension (X/Y/Z) size\n");
  printf(" --xdim <n> # Set X dimension size (default 24)\n");
  printf(" --ydim <n> # Set Y dimension size (default 24)\n");
  printf(" --zdim <n> # Set Z dimension size (default 24)\n");
  printf(" --tdim <n> # Set T dimension size (default 24)\n");
  printf(" --Lsdim <n> # Set Ls dimension size (default 16)\n");
  printf(" --xgridsize <n> # Set grid size in X dimension (default 1)\n");
  printf(" --ygridsize <n> # Set grid size in Y dimension (default 1)\n");
  printf(" --zgridsize <n> # Set grid size in Z dimension (default 1)\n");
  printf(" --tgridsize <n> # Set grid size in T dimension (default 1)\n");
  printf(" --partition <mask> # Set the communication topology (X=1, Y=2, Z=4, T=8, and combinations of these)\n");
  printf(" --kernel_pack_t # Set T dimension kernel packing to be true (default false)\n");
  printf(" --dslash_type <type> # Set the dslash type, the following values are valid\n"
         "   wilson/clover/twisted_mass/asqtad/domain_wall\n");
  printf(" --load-gauge file # Load gauge field \"file\" for the test (requires QIO)\n");
  printf(" --niter <n> # The number of iterations to perform (default 10)\n");
  printf(" --tune <true/false> # Whether to autotune or not (default true)\n");
  printf(" --test # Test method (different for each test)\n");
  printf(" --verify <true/false> # Verify the GPU results using CPU results (default true)\n");
  printf(" --help # Print out this message\n");
  usage_extra(argv);
#ifdef MULTI_GPU
  char msg[]="multi";
#else
  char msg[]="single";
#endif
  printf("Note: this program is a %s-GPU build\n", msg);
  exit(1);
}
int process_command_line_option(int argc, char **argv, int *idx)
{
#ifdef MULTI_GPU
  char msg[]="multi";
#else
  char msg[]="single";
#endif

  int ret = -1;
  int i = *idx;

  if( strcmp(argv[i], "--help") == 0){
    usage(argv);
  }

  if( strcmp(argv[i], "--verify") == 0){
    if (i+1 >= argc){
      usage(argv);
    }

    if (strcmp(argv[i+1], "true") == 0){
      verify_results = true;
    }else if (strcmp(argv[i+1], "false") == 0){
      verify_results = false;
    }else{
      fprintf(stderr, "ERROR: invalid verify type\n");
      exit(1);
    }

    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--device") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    device = atoi(argv[i+1]);
    if (device < 0 || device > 16){
      printf("ERROR: Invalid CUDA device number (%d)\n", device);
      usage(argv);
    }
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--prec") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    prec = get_prec(argv[i+1]);
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--prec_sloppy") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    prec_sloppy = get_prec(argv[i+1]);
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--recon") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    link_recon = get_recon(argv[i+1]);
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--recon_sloppy") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    link_recon_sloppy = get_recon(argv[i+1]);
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--xdim") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    xdim = atoi(argv[i+1]);
    if (xdim < 0 || xdim > 512){
      printf("ERROR: invalid X dimension (%d)\n", xdim);
      usage(argv);
    }
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--ydim") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    ydim = atoi(argv[i+1]);
    if (ydim < 0 || ydim > 512){
      printf("ERROR: invalid Y dimension (%d)\n", ydim);
      usage(argv);
    }
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--zdim") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    zdim = atoi(argv[i+1]);
    if (zdim < 0 || zdim > 512){
      printf("ERROR: invalid Z dimension (%d)\n", zdim);
      usage(argv);
    }
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--tdim") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    tdim = atoi(argv[i+1]);
    if (tdim < 0 || tdim > 512){
      errorQuda("Error: invalid T dimension");
    }
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--sdim") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    int sdim = atoi(argv[i+1]);
    if (sdim < 0 || sdim > 512){
      printfQuda("ERROR: invalid S dimension\n");
    }
    xdim = ydim = zdim = sdim;
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--Lsdim") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    int Ls = atoi(argv[i+1]);
    if (Ls < 0 || Ls > 128){
      printfQuda("ERROR: invalid Ls dimension\n");
    }
    Lsdim = Ls;
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--dagger") == 0){
    dagger = QUDA_DAG_YES;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--partition") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
#ifdef MULTI_GPU
    int value = atoi(argv[i+1]);
    for(int j=0; j < 4; j++){
      if (value & (1 << j)){
        commDimPartitionedSet(j);
        dim_partitioned[j] = 1;
      }
    }
#else
    printfQuda("WARNING: Ignoring --partition option since this is a single-GPU build.\n");
#endif
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--kernel_pack_t") == 0){
    quda::setKernelPackT(true);
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--tune") == 0){
    if (i+1 >= argc){
      usage(argv);
    }

    if (strcmp(argv[i+1], "true") == 0){
      tune = true;
    }else if (strcmp(argv[i+1], "false") == 0){
      tune = false;
    }else{
      fprintf(stderr, "ERROR: invalid tuning type\n");
      exit(1);
    }

    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--xgridsize") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    int xsize = atoi(argv[i+1]);
    if (xsize <= 0){
      errorQuda("ERROR: invalid X grid size");
    }
    gridsize_from_cmdline[0] = xsize;
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--ygridsize") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    int ysize = atoi(argv[i+1]);
    if (ysize <= 0){
      errorQuda("ERROR: invalid Y grid size");
    }
    gridsize_from_cmdline[1] = ysize;
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--zgridsize") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    int zsize = atoi(argv[i+1]);
    if (zsize <= 0){
      errorQuda("ERROR: invalid Z grid size");
    }
    gridsize_from_cmdline[2] = zsize;
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--tgridsize") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    int tsize = atoi(argv[i+1]);
    if (tsize <= 0){
      errorQuda("ERROR: invalid T grid size");
    }
    gridsize_from_cmdline[3] = tsize;
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--dslash_type") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    dslash_type = get_dslash_type(argv[i+1]);
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--load-gauge") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    strcpy(latfile, argv[i+1]);
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--test") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    test_type = atoi(argv[i+1]);
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--niter") == 0){
    if (i+1 >= argc){
      usage(argv);
    }
    niter = atoi(argv[i+1]);
    if (niter < 1 || niter > 1e6){
      printf("ERROR: invalid number of iterations (%d)\n", niter);
      usage(argv);
    }
    i++;
    ret = 0;
    goto out;
  }

  if( strcmp(argv[i], "--version") == 0){
    printf("This program is linked against the QUDA library, version %s,",
           get_quda_ver_str());
    printf(" %s-GPU build\n", msg);
    exit(0);
  }

out:
  *idx = i;
  return ret;
}
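
// Illustrative usage sketch (added; not part of the original file): a test's
// main() would typically scan argv with process_command_line_option() and
// reject anything it does not recognize, along these lines:
#if 0
int main(int argc, char **argv)
{
  for (int i = 1; i < argc; i++) {
    if (process_command_line_option(argc, argv, &i) == 0) continue; // option consumed
    fprintf(stderr, "ERROR: Invalid option: %s\n", argv[i]);
    usage(argv);
  }
  initComms(argc, argv, gridsize_from_cmdline);
  // ... run the test ...
  finalizeComms();
  return 0;
}
#endif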
static struct timeval startTime;

void stopwatchStart() {
  gettimeofday(&startTime, NULL);
}

double stopwatchReadSeconds() {
  struct timeval endTime;
  gettimeofday(&endTime, NULL);

  long ds = endTime.tv_sec - startTime.tv_sec;
  long dus = endTime.tv_usec - startTime.tv_usec;
  return ds + 0.000001*dus;
}