
/tests/test_util.cpp

https://github.com/jinhou/quda
   1#include <complex>
   2#include <stdlib.h>
   3#include <stdio.h>
   4#include <string.h>
   5#include <short.h>
   6
   7#if defined(QMP_COMMS)
   8#include <qmp.h>
   9#elif defined(MPI_COMMS)
  10#include <mpi.h>
  11#endif
  12
  13#include <wilson_dslash_reference.h>
  14#include <test_util.h>
  15
  16#include <face_quda.h>
  17#include <dslash_quda.h>
  18#include "misc.h"
  19
  20using namespace std;
  21
  22#define XUP 0
  23#define YUP 1
  24#define ZUP 2
  25#define TUP 3
  26
  27
  28int Z[4];
  29int V;
  30int Vh;
  31int Vs_x, Vs_y, Vs_z, Vs_t;
  32int Vsh_x, Vsh_y, Vsh_z, Vsh_t;
  33int faceVolume[4];
  34
   35// extended volume: the local dimensions padded by 4 (a two-site halo on each side)
  36int E1, E1h, E2, E3, E4; 
  37int E[4];
  38int V_ex, Vh_ex;
  39
  40int Ls;
  41int V5;
  42int V5h;
  43
  44int mySpinorSiteSize;
  45
  46extern float fat_link_max;
  47
  48
  49void initComms(int argc, char **argv, const int *commDims)
  50{
  51#if defined(QMP_COMMS)
  52  QMP_thread_level_t tl;
  53  QMP_init_msg_passing(&argc, &argv, QMP_THREAD_SINGLE, &tl);
  54#elif defined(MPI_COMMS)
  55  MPI_Init(&argc, &argv);
  56#endif
  57  initCommsGridQuda(4, commDims, NULL, NULL);
  58  initRand();
  59}
  60
  61
  62void finalizeComms()
  63{
  64#if defined(QMP_COMMS)
  65  QMP_finalize_msg_passing();
  66#elif defined(MPI_COMMS)
  67  MPI_Finalize();
  68#endif
  69}
  70
  71
  72void initRand()
  73{
  74  int rank = 0;
  75
  76#if defined(QMP_COMMS)
  77  rank = QMP_get_node_number();
  78#elif defined(MPI_COMMS)
  79  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  80#endif
  81
  82  srand(17*rank + 137);
  83}
  84
  85
  86void setDims(int *X) {
  87  V = 1;
  88  for (int d=0; d< 4; d++) {
  89    V *= X[d];
  90    Z[d] = X[d];
  91
  92    faceVolume[d] = 1;
  93    for (int i=0; i<4; i++) {
  94      if (i==d) continue;
  95      faceVolume[d] *= X[i];
  96    }
  97  }
  98  Vh = V/2;
  99
 100  Vs_x = X[1]*X[2]*X[3];
 101  Vs_y = X[0]*X[2]*X[3];
 102  Vs_z = X[0]*X[1]*X[3];
 103  Vs_t = X[0]*X[1]*X[2];
 104  
 105  Vsh_x = Vs_x/2;
 106  Vsh_y = Vs_y/2;
 107  Vsh_z = Vs_z/2;
 108  Vsh_t = Vs_t/2;
 109
 110
 111  E1=X[0]+4; E2=X[1]+4; E3=X[2]+4; E4=X[3]+4;
 112  E1h=E1/2;
 113  E[0] = E1;
 114  E[1] = E2;
 115  E[2] = E3;
 116  E[3] = E4;
 117  V_ex = E1*E2*E3*E4;
 118  Vh_ex = V_ex/2;
 119
 120}
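
// Example (sketch): how a test driver typically uses setDims(). The 8^3 x 16
// local lattice below is an arbitrary, hypothetical choice; after the call the
// globals V, Vh, Vs_* and faceVolume[] hold the derived volumes.
static void exampleSetDims()
{
  int X[4] = {8, 8, 8, 16};  // hypothetical local lattice dimensions
  setDims(X);
  printf("V = %d, Vh = %d, Vs_t = %d, faceVolume[3] = %d\n", V, Vh, Vs_t, faceVolume[3]);
}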
 121
 122
 123void dw_setDims(int *X, const int L5) 
 124{
 125  V = 1;
 126  for (int d=0; d< 4; d++) 
 127  {
 128    V *= X[d];
 129    Z[d] = X[d];
 130
 131    faceVolume[d] = 1;
 132    for (int i=0; i<4; i++) {
 133      if (i==d) continue;
 134      faceVolume[d] *= X[i];
 135    }
 136  }
 137  Vh = V/2;
 138  
 139  Ls = L5;
 140  V5 = V*Ls;
 141  V5h = Vh*Ls;
 142
 143  Vs_t = Z[0]*Z[1]*Z[2]*Ls;//?
 144  Vsh_t = Vs_t/2;  //?
 145}
 146
 147
 148void setSpinorSiteSize(int n)
 149{
 150  mySpinorSiteSize = n;
 151}
 152
 153
 154template <typename Float>
 155static void printVector(Float *v) {
 156  printfQuda("{(%f %f) (%f %f) (%f %f)}\n", v[0], v[1], v[2], v[3], v[4], v[5]);
 157}
 158
 159// X indexes the lattice site
 160void printSpinorElement(void *spinor, int X, QudaPrecision precision) {
 161  if (precision == QUDA_DOUBLE_PRECISION)
 162    for (int s=0; s<4; s++) printVector((double*)spinor+X*24+s*6);
 163  else
 164    for (int s=0; s<4; s++) printVector((float*)spinor+X*24+s*6);
 165}
 166
 167// X indexes the full lattice
 168void printGaugeElement(void *gauge, int X, QudaPrecision precision) {
 169  if (getOddBit(X) == 0) {
 170    if (precision == QUDA_DOUBLE_PRECISION)
 171      for (int m=0; m<3; m++) printVector((double*)gauge +(X/2)*gaugeSiteSize + m*3*2);
 172    else
 173      for (int m=0; m<3; m++) printVector((float*)gauge +(X/2)*gaugeSiteSize + m*3*2);
 174      
 175  } else {
 176    if (precision == QUDA_DOUBLE_PRECISION)
 177      for (int m = 0; m < 3; m++) printVector((double*)gauge + (X/2+Vh)*gaugeSiteSize + m*3*2);
 178    else
 179      for (int m = 0; m < 3; m++) printVector((float*)gauge + (X/2+Vh)*gaugeSiteSize + m*3*2);
 180  }
 181}
 182
  183// returns 0 or 1 if the full lattice index Y is even or odd
 184int getOddBit(int Y) {
 185  int x4 = Y/(Z[2]*Z[1]*Z[0]);
 186  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
 187  int x2 = (Y/Z[0]) % Z[1];
 188  int x1 = Y % Z[0];
 189  return (x4+x3+x2+x1) % 2;
 190}
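
// Example (sketch): once the extents are set via setDims(), getOddBit() gives
// the parity (x1+x2+x3+x4) mod 2 of a full-lattice index; with any even
// extents, full site 0 is even and full site 1 is odd. The lattice size is an
// arbitrary, hypothetical choice.
static void exampleGetOddBit()
{
  int X[4] = {8, 8, 8, 16};  // hypothetical local lattice dimensions
  setDims(X);
  printf("parity of full site 0: %d, full site 1: %d\n", getOddBit(0), getOddBit(1));
}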
 191
 192// a+=b
 193template <typename Float>
 194inline void complexAddTo(Float *a, Float *b) {
 195  a[0] += b[0];
 196  a[1] += b[1];
 197}
 198
 199// a = b*c
 200template <typename Float>
 201inline void complexProduct(Float *a, Float *b, Float *c) {
 202  a[0] = b[0]*c[0] - b[1]*c[1];
 203  a[1] = b[0]*c[1] + b[1]*c[0];
 204}
 205
 206// a = conj(b)*conj(c)
 207template <typename Float>
 208inline void complexConjugateProduct(Float *a, Float *b, Float *c) {
 209  a[0] = b[0]*c[0] - b[1]*c[1];
 210  a[1] = -b[0]*c[1] - b[1]*c[0];
 211}
 212
 213// a = conj(b)*c
 214template <typename Float>
 215inline void complexDotProduct(Float *a, Float *b, Float *c) {
 216  a[0] = b[0]*c[0] + b[1]*c[1];
 217  a[1] = b[0]*c[1] - b[1]*c[0];
 218}
 219
 220// a += b*c
 221template <typename Float>
 222inline void accumulateComplexProduct(Float *a, Float *b, Float *c, Float sign) {
 223  a[0] += sign*(b[0]*c[0] - b[1]*c[1]);
 224  a[1] += sign*(b[0]*c[1] + b[1]*c[0]);
 225}
 226
  227// a += conj(b)*c
 228template <typename Float>
 229inline void accumulateComplexDotProduct(Float *a, Float *b, Float *c) {
 230  a[0] += b[0]*c[0] + b[1]*c[1];
 231  a[1] += b[0]*c[1] - b[1]*c[0];
 232}
 233
 234template <typename Float>
 235inline void accumulateConjugateProduct(Float *a, Float *b, Float *c, int sign) {
 236  a[0] += sign * (b[0]*c[0] - b[1]*c[1]);
 237  a[1] -= sign * (b[0]*c[1] + b[1]*c[0]);
 238}
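
// Example (sketch): a numerical sanity check of the complex helpers above,
// using b = 1+2i and c = 3+4i. complexProduct gives b*c = -5+10i,
// complexDotProduct gives conj(b)*c = 11-2i, and accumulateConjugateProduct
// adds sign*conj(b*c), i.e. -5-10i for sign = +1.
static void exampleComplexHelpers()
{
  double b[2] = {1.0, 2.0};
  double c[2] = {3.0, 4.0};
  double p[2], d[2], acc[2] = {0.0, 0.0};
  complexProduct(p, b, c);                    // p = (-5, 10)
  complexDotProduct(d, b, c);                 // d = (11, -2)
  accumulateConjugateProduct(acc, b, c, +1);  // acc = (-5, -10)
  printf("b*c = (%f, %f), conj(b)*c = (%f, %f), conj(b*c) = (%f, %f)\n",
         p[0], p[1], d[0], d[1], acc[0], acc[1]);
}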
 239
 240template <typename Float>
 241inline void su3Construct12(Float *mat) {
 242  Float *w = mat+12;
 243  w[0] = 0.0;
 244  w[1] = 0.0;
 245  w[2] = 0.0;
 246  w[3] = 0.0;
 247  w[4] = 0.0;
 248  w[5] = 0.0;
 249}
 250
 251// Stabilized Bunk and Sommer
 252template <typename Float>
 253inline void su3Construct8(Float *mat) {
 254  mat[0] = atan2(mat[1], mat[0]);
 255  mat[1] = atan2(mat[13], mat[12]);
 256  for (int i=8; i<18; i++) mat[i] = 0.0;
 257}
 258
 259void su3_construct(void *mat, QudaReconstructType reconstruct, QudaPrecision precision) {
 260  if (reconstruct == QUDA_RECONSTRUCT_12) {
 261    if (precision == QUDA_DOUBLE_PRECISION) su3Construct12((double*)mat);
 262    else su3Construct12((float*)mat);
 263  } else {
 264    if (precision == QUDA_DOUBLE_PRECISION) su3Construct8((double*)mat);
 265    else su3Construct8((float*)mat);
 266  }
 267}
 268
 269// given first two rows (u,v) of SU(3) matrix mat, reconstruct the third row
 270// as the cross product of the conjugate vectors: w = u* x v*
 271// 
 272// 48 flops
 273template <typename Float>
 274static void su3Reconstruct12(Float *mat, int dir, int ga_idx, QudaGaugeParam *param) {
 275  Float *u = &mat[0*(3*2)];
 276  Float *v = &mat[1*(3*2)];
 277  Float *w = &mat[2*(3*2)];
 278  w[0] = 0.0; w[1] = 0.0; w[2] = 0.0; w[3] = 0.0; w[4] = 0.0; w[5] = 0.0;
 279  accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
 280  accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
 281  accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
 282  accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
 283  accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
 284  accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
 285  Float u0 = (dir < 3 ? param->anisotropy :
 286	      (ga_idx >= (Z[3]-1)*Z[0]*Z[1]*Z[2]/2 ? param->t_boundary : 1));
 287  w[0]*=u0; w[1]*=u0; w[2]*=u0; w[3]*=u0; w[4]*=u0; w[5]*=u0;
 288}
 289
 290template <typename Float>
 291static void su3Reconstruct8(Float *mat, int dir, int ga_idx, QudaGaugeParam *param) {
 292  // First reconstruct first row
 293  Float row_sum = 0.0;
 294  row_sum += mat[2]*mat[2];
 295  row_sum += mat[3]*mat[3];
 296  row_sum += mat[4]*mat[4];
 297  row_sum += mat[5]*mat[5];
 298  Float u0 = (dir < 3 ? param->anisotropy :
 299	      (ga_idx >= (Z[3]-1)*Z[0]*Z[1]*Z[2]/2 ? param->t_boundary : 1));
 300  Float U00_mag = sqrt(1.f/(u0*u0) - row_sum);
 301
 302  mat[14] = mat[0];
 303  mat[15] = mat[1];
 304
 305  mat[0] = U00_mag * cos(mat[14]);
 306  mat[1] = U00_mag * sin(mat[14]);
 307
 308  Float column_sum = 0.0;
 309  for (int i=0; i<2; i++) column_sum += mat[i]*mat[i];
 310  for (int i=6; i<8; i++) column_sum += mat[i]*mat[i];
 311  Float U20_mag = sqrt(1.f/(u0*u0) - column_sum);
 312
 313  mat[12] = U20_mag * cos(mat[15]);
 314  mat[13] = U20_mag * sin(mat[15]);
 315
 316  // First column now restored
 317
 318  // finally reconstruct last elements from SU(2) rotation
 319  Float r_inv2 = 1.0/(u0*row_sum);
 320
 321  // U11
 322  Float A[2];
 323  complexDotProduct(A, mat+0, mat+6);
 324  complexConjugateProduct(mat+8, mat+12, mat+4);
 325  accumulateComplexProduct(mat+8, A, mat+2, u0);
 326  mat[8] *= -r_inv2;
 327  mat[9] *= -r_inv2;
 328
 329  // U12
 330  complexConjugateProduct(mat+10, mat+12, mat+2);
 331  accumulateComplexProduct(mat+10, A, mat+4, -u0);
 332  mat[10] *= r_inv2;
 333  mat[11] *= r_inv2;
 334
 335  // U21
 336  complexDotProduct(A, mat+0, mat+12);
 337  complexConjugateProduct(mat+14, mat+6, mat+4);
 338  accumulateComplexProduct(mat+14, A, mat+2, -u0);
 339  mat[14] *= r_inv2;
 340  mat[15] *= r_inv2;
 341
  342  // U22
 343  complexConjugateProduct(mat+16, mat+6, mat+2);
 344  accumulateComplexProduct(mat+16, A, mat+4, u0);
 345  mat[16] *= -r_inv2;
 346  mat[17] *= -r_inv2;
 347}
 348
 349void su3_reconstruct(void *mat, int dir, int ga_idx, QudaReconstructType reconstruct, QudaPrecision precision, QudaGaugeParam *param) {
 350  if (reconstruct == QUDA_RECONSTRUCT_12) {
 351    if (precision == QUDA_DOUBLE_PRECISION) su3Reconstruct12((double*)mat, dir, ga_idx, param);
 352    else su3Reconstruct12((float*)mat, dir, ga_idx, param);
 353  } else {
 354    if (precision == QUDA_DOUBLE_PRECISION) su3Reconstruct8((double*)mat, dir, ga_idx, param);
 355    else su3Reconstruct8((float*)mat, dir, ga_idx, param);
 356  }
 357}
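
// Example (sketch): a round trip of the 12-parameter reconstruction on the
// identity link. The QudaGaugeParam here is hypothetical and set up minimally;
// with dir < 3 only param.anisotropy is consulted, so t_boundary and the
// lattice extents never enter.
static void exampleSu3Reconstruct12Roundtrip()
{
  double mat[18];
  memset(mat, 0, sizeof(mat));
  mat[0] = 1.0;   // row 0 = (1, 0, 0)
  mat[8] = 1.0;   // row 1 = (0, 1, 0)
  mat[16] = 1.0;  // row 2 = (0, 0, 1)

  QudaGaugeParam param;
  memset(&param, 0, sizeof(param));
  param.anisotropy = 1.0;

  su3_construct(mat, QUDA_RECONSTRUCT_12, QUDA_DOUBLE_PRECISION);  // drops row 2
  su3_reconstruct(mat, 0, 0, QUDA_RECONSTRUCT_12, QUDA_DOUBLE_PRECISION, &param);

  // the conjugate cross product of the first two rows restores the third:
  // mat[16], mat[17] come back as (1.0, 0.0)
  printf("reconstructed U(2,2) = (%f, %f)\n", mat[16], mat[17]);
}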
 358
 359/*
 360  void su3_construct_8_half(float *mat, short *mat_half) {
 361  su3Construct8(mat);
 362
 363  mat_half[0] = floatToShort(mat[0] / M_PI);
 364  mat_half[1] = floatToShort(mat[1] / M_PI);
 365  for (int i=2; i<18; i++) {
 366  mat_half[i] = floatToShort(mat[i]);
 367  }
 368  }
 369
 370  void su3_reconstruct_8_half(float *mat, short *mat_half, int dir, int ga_idx, QudaGaugeParam *param) {
 371
 372  for (int i=0; i<18; i++) {
 373  mat[i] = shortToFloat(mat_half[i]);
 374  }
 375  mat[0] *= M_PI;
 376  mat[1] *= M_PI;
 377
 378  su3Reconstruct8(mat, dir, ga_idx, param);
 379  }*/
 380
 381template <typename Float>
 382static int compareFloats(Float *a, Float *b, int len, double epsilon) {
 383  for (int i = 0; i < len; i++) {
 384    double diff = fabs(a[i] - b[i]);
 385    if (diff > epsilon) {
 386      printfQuda("ERROR: i=%d, a[%d]=%f, b[%d]=%f\n", i, i, a[i], i, b[i]);
 387      return 0;
 388    }
 389  }
 390  return 1;
 391}
 392
 393int compare_floats(void *a, void *b, int len, double epsilon, QudaPrecision precision) {
 394  if  (precision == QUDA_DOUBLE_PRECISION) return compareFloats((double*)a, (double*)b, len, epsilon);
 395  else return compareFloats((float*)a, (float*)b, len, epsilon);
 396}
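
// Example (sketch): compare_floats() reports element-wise agreement within a
// tolerance; the hypothetical arrays below differ only by 1e-8, well inside
// the 1e-6 epsilon, so it returns 1 (success).
static void exampleCompareFloats()
{
  double a[4] = {1.0, 2.0, 3.0, 4.0};
  double b[4] = {1.0, 2.0, 3.0, 4.0 + 1e-8};
  int ok = compare_floats(a, b, 4, 1e-6, QUDA_DOUBLE_PRECISION);
  printf("compare_floats: %d\n", ok);
}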
 397
 398int fullLatticeIndex(int dim[4], int index, int oddBit){
 399
 400  int za = index/(dim[0]>>1);
 401  int zb = za/dim[1];
 402  int x2 = za - zb*dim[1];
 403  int x4 = zb/dim[2];
 404  int x3 = zb - x4*dim[2];
 405  
 406  return  2*index + ((x2 + x3 + x4 + oddBit) & 1);
 407}
 408
 409// given a "half index" i into either an even or odd half lattice (corresponding
 410// to oddBit = {0, 1}), returns the corresponding full lattice index.
 411int fullLatticeIndex(int i, int oddBit) {
 412  /*
 413    int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2);
 414    return 2*i + (boundaryCrossings + oddBit) % 2;
 415  */
 416
 417  int X1 = Z[0];  
 418  int X2 = Z[1];
 419  int X3 = Z[2];
 420  //int X4 = Z[3];
 421  int X1h =X1/2;
 422
 423  int sid =i;
 424  int za = sid/X1h;
 425  //int x1h = sid - za*X1h;
 426  int zb = za/X2;
 427  int x2 = za - zb*X2;
 428  int x4 = zb/X3;
 429  int x3 = zb - x4*X3;
 430  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
 431  //int x1 = 2*x1h + x1odd;
 432  int X = 2*sid + x1odd; 
 433
 434  return X;
 435}
 436
 437
 438// i represents a "half index" into an even or odd "half lattice".
 439// when oddBit={0,1} the half lattice is {even,odd}.
 440// 
 441// the displacements, such as dx, refer to the full lattice coordinates. 
 442//
 443// neighborIndex() takes a "half index", displaces it, and returns the
 444// new "half index", which can be an index into either the even or odd lattices.
 445// displacements of magnitude one always interchange odd and even lattices.
 446//
 447
 448int neighborIndex(int i, int oddBit, int dx4, int dx3, int dx2, int dx1) {
 449  int Y = fullLatticeIndex(i, oddBit);
 450  int x4 = Y/(Z[2]*Z[1]*Z[0]);
 451  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
 452  int x2 = (Y/Z[0]) % Z[1];
 453  int x1 = Y % Z[0];
 454  
 455  // assert (oddBit == (x+y+z+t)%2);
 456  
 457  x4 = (x4+dx4+Z[3]) % Z[3];
 458  x3 = (x3+dx3+Z[2]) % Z[2];
 459  x2 = (x2+dx2+Z[1]) % Z[1];
 460  x1 = (x1+dx1+Z[0]) % Z[0];
 461  
 462  return (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
 463}
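
// Example (sketch): a displacement of odd total magnitude always lands on the
// opposite checkerboard, so the half index returned here must be paired with
// the flipped parity by the caller (as neighborIndexFullLattice() does below).
// The 8^3 x 16 lattice is an arbitrary, hypothetical choice.
static void exampleNeighborIndex()
{
  int X[4] = {8, 8, 8, 16};
  setDims(X);
  int i = 0, oddBit = 0;                            // first site of the even sublattice
  int fwd = neighborIndex(i, oddBit, 0, 0, 0, +1);  // one step in +x
  printf("+x neighbor of even half-index 0 lives at odd half-index %d\n", fwd);
}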
 464
 465
 466int neighborIndex(int dim[4], int index, int oddBit, int dx[4]){
 467
 468  const int fullIndex = fullLatticeIndex(dim, index, oddBit);
 469
 470  int x[4];
 471  x[3] = fullIndex/(dim[2]*dim[1]*dim[0]);
 472  x[2] = (fullIndex/(dim[1]*dim[0])) % dim[2];
 473  x[1] = (fullIndex/dim[0]) % dim[1];
 474  x[0] = fullIndex % dim[0];
 475
 476  for(int dir=0; dir<4; ++dir)
 477    x[dir] = (x[dir]+dx[dir]+dim[dir]) % dim[dir];
 478
 479  return (((x[3]*dim[2] + x[2])*dim[1] + x[1])*dim[0] + x[0])/2;
 480}
 481
 482int
 483neighborIndex_mg(int i, int oddBit, int dx4, int dx3, int dx2, int dx1)
 484{
 485  int ret;
 486  
 487  int Y = fullLatticeIndex(i, oddBit);
 488  int x4 = Y/(Z[2]*Z[1]*Z[0]);
 489  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
 490  int x2 = (Y/Z[0]) % Z[1];
 491  int x1 = Y % Z[0];
 492  
 493  int ghost_x4 = x4+ dx4;
 494  
 495  // assert (oddBit == (x+y+z+t)%2);
 496  
 497  x4 = (x4+dx4+Z[3]) % Z[3];
 498  x3 = (x3+dx3+Z[2]) % Z[2];
 499  x2 = (x2+dx2+Z[1]) % Z[1];
 500  x1 = (x1+dx1+Z[0]) % Z[0];
 501  
 502  if ( ghost_x4 >= 0 && ghost_x4 < Z[3]){
 503    ret = (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
 504  }else{
 505    ret = (x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;    
 506  }
 507
 508  
 509  return ret;
 510}
 511
 512
 513/*  
 514 * This is a computation of neighbor using the full index and the displacement in each direction
 515 *
 516 */
 517
 518int
 519neighborIndexFullLattice(int i, int dx4, int dx3, int dx2, int dx1) 
 520{
 521  int oddBit = 0;
 522  int half_idx = i;
 523  if (i >= Vh){
 524    oddBit =1;
 525    half_idx = i - Vh;
 526  }
 527    
 528  int nbr_half_idx = neighborIndex(half_idx, oddBit, dx4,dx3,dx2,dx1);
 529  int oddBitChanged = (dx4+dx3+dx2+dx1)%2;
 530  if (oddBitChanged){
 531    oddBit = 1 - oddBit;
 532  }
 533  int ret = nbr_half_idx;
 534  if (oddBit){
 535    ret = Vh + nbr_half_idx;
 536  }
 537    
 538  return ret;
 539}
 540
 541int
 542neighborIndexFullLattice(int dim[4], int index, int dx[4])
 543{
 544  const int volume = dim[0]*dim[1]*dim[2]*dim[3];
 545  const int halfVolume = volume/2;
 546  int oddBit = 0;
 547  int halfIndex = index;
 548
 549  if(index >= halfVolume){
 550    oddBit = 1;
 551    halfIndex = index - halfVolume;
 552  }
 553
 554  int neighborHalfIndex = neighborIndex(dim, halfIndex, oddBit, dx);
 555
 556  int oddBitChanged = (dx[0]+dx[1]+dx[2]+dx[3])%2;
 557  if(oddBitChanged){
 558    oddBit = 1 - oddBit;
 559  }
 560
 561  return neighborHalfIndex + oddBit*halfVolume;
 562}
 563
 564
 565
 566int
 567neighborIndexFullLattice_mg(int i, int dx4, int dx3, int dx2, int dx1) 
 568{
 569  int ret;
 570  int oddBit = 0;
 571  int half_idx = i;
 572  if (i >= Vh){
 573    oddBit =1;
 574    half_idx = i - Vh;
 575  }
 576    
 577  int Y = fullLatticeIndex(half_idx, oddBit);
 578  int x4 = Y/(Z[2]*Z[1]*Z[0]);
 579  int x3 = (Y/(Z[1]*Z[0])) % Z[2];
 580  int x2 = (Y/Z[0]) % Z[1];
 581  int x1 = Y % Z[0];
 582  int ghost_x4 = x4+ dx4;
 583    
 584  x4 = (x4+dx4+Z[3]) % Z[3];
 585  x3 = (x3+dx3+Z[2]) % Z[2];
 586  x2 = (x2+dx2+Z[1]) % Z[1];
 587  x1 = (x1+dx1+Z[0]) % Z[0];
 588
 589  if ( ghost_x4 >= 0 && ghost_x4 < Z[3]){
 590    ret = (x4*(Z[2]*Z[1]*Z[0]) + x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;
 591  }else{
 592    ret = (x3*(Z[1]*Z[0]) + x2*(Z[0]) + x1) / 2;    
 593    return ret;
 594  }
 595
 596  int oddBitChanged = (dx4+dx3+dx2+dx1)%2;
 597  if (oddBitChanged){
 598    oddBit = 1 - oddBit;
 599  }
 600    
 601  if (oddBit){
 602    ret += Vh;
 603  }
 604    
 605  return ret;
 606}
 607
 608
 609// 4d checkerboard.
 610// given a "half index" i into either an even or odd half lattice (corresponding
 611// to oddBit = {0, 1}), returns the corresponding full lattice index.
 612// Cf. GPGPU code in dslash_core_ante.h.
 613// There, i is the thread index.
 614int fullLatticeIndex_4d(int i, int oddBit) {
  615  if (i >= Vh || i < 0) {printf("i out of range in fullLatticeIndex_4d\n"); exit(-1);}
 616  /*
 617    int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2);
 618    return 2*i + (boundaryCrossings + oddBit) % 2;
 619  */
 620
 621  int X1 = Z[0];  
 622  int X2 = Z[1];
 623  int X3 = Z[2];
 624  //int X4 = Z[3];
 625  int X1h =X1/2;
 626
 627  int sid =i;
 628  int za = sid/X1h;
 629  //int x1h = sid - za*X1h;
 630  int zb = za/X2;
 631  int x2 = za - zb*X2;
 632  int x4 = zb/X3;
 633  int x3 = zb - x4*X3;
 634  int x1odd = (x2 + x3 + x4 + oddBit) & 1;
 635  //int x1 = 2*x1h + x1odd;
 636  int X = 2*sid + x1odd; 
 637
 638  return X;
 639}
 640
 641// 5d checkerboard.
 642// given a "half index" i into either an even or odd half lattice (corresponding
 643// to oddBit = {0, 1}), returns the corresponding full lattice index.
 644// Cf. GPGPU code in dslash_core_ante.h.
 645// There, i is the thread index sid.
 646// This function is used by neighborIndex_5d in dslash_reference.cpp.
 647//ok
 648int fullLatticeIndex_5d(int i, int oddBit) {
 649  int boundaryCrossings = i/(Z[0]/2) + i/(Z[1]*Z[0]/2) + i/(Z[2]*Z[1]*Z[0]/2) + i/(Z[3]*Z[2]*Z[1]*Z[0]/2);
 650  return 2*i + (boundaryCrossings + oddBit) % 2;
 651}
 652
 653int 
 654x4_from_full_index(int i)
 655{
 656  int oddBit = 0;
 657  int half_idx = i;
 658  if (i >= Vh){
 659    oddBit =1;
 660    half_idx = i - Vh;
 661  }
 662  
 663  int Y = fullLatticeIndex(half_idx, oddBit);
 664  int x4 = Y/(Z[2]*Z[1]*Z[0]);
 665  
 666  return x4;
 667}
 668
 669template <typename Float>
 670static void applyGaugeFieldScaling(Float **gauge, int Vh, QudaGaugeParam *param) {
 671  // Apply spatial scaling factor (u0) to spatial links
 672  for (int d = 0; d < 3; d++) {
 673    for (int i = 0; i < gaugeSiteSize*Vh*2; i++) {
 674      gauge[d][i] /= param->anisotropy;
 675    }
 676  }
 677    
 678  // only apply T-boundary at edge nodes
 679#ifdef MULTI_GPU
 680  bool last_node_in_t = (commCoords(3) == commDim(3)-1) ? true : false;
 681#else
 682  bool last_node_in_t = true;
 683#endif
 684
 685  // Apply boundary conditions to temporal links
 686  if (param->t_boundary == QUDA_ANTI_PERIODIC_T && last_node_in_t) {
 687    for (int j = (Z[0]/2)*Z[1]*Z[2]*(Z[3]-1); j < Vh; j++) {
 688      for (int i = 0; i < gaugeSiteSize; i++) {
 689	gauge[3][j*gaugeSiteSize+i] *= -1.0;
 690	gauge[3][(Vh+j)*gaugeSiteSize+i] *= -1.0;
 691      }
 692    }
 693  }
 694    
 695  if (param->gauge_fix) {
 696    // set all gauge links (except for the last Z[0]*Z[1]*Z[2]/2) to the identity,
 697    // to simulate fixing to the temporal gauge.
 698    int iMax = ( last_node_in_t ? (Z[0]/2)*Z[1]*Z[2]*(Z[3]-1) : Vh );
 699    int dir = 3; // time direction only
 700    Float *even = gauge[dir];
 701    Float *odd  = gauge[dir]+Vh*gaugeSiteSize;
 702    for (int i = 0; i< iMax; i++) {
 703      for (int m = 0; m < 3; m++) {
 704	for (int n = 0; n < 3; n++) {
 705	  even[i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
 706	  even[i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
 707	  odd [i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
 708	  odd [i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
 709	}
 710      }
 711    }
 712  }
 713}
 714
 715template <typename Float>
 716void applyGaugeFieldScaling_long(Float **gauge, int Vh, QudaGaugeParam *param)
 717{
 718
 719  int X1h=param->X[0]/2;
 720  int X1 =param->X[0];
 721  int X2 =param->X[1];
 722  int X3 =param->X[2];
 723  int X4 =param->X[3];
 724
 725  // rescale long links by the appropriate coefficient
 726  for(int d=0; d<4; d++){
 727    for(int i=0; i < V*gaugeSiteSize; i++){
 728      gauge[d][i] /= (-24*param->tadpole_coeff*param->tadpole_coeff);
 729    }
 730  }
 731
 732  // apply the staggered phases
 733  for (int d = 0; d < 3; d++) {
 734
 735    //even
 736    for (int i = 0; i < Vh; i++) {
 737
 738      int index = fullLatticeIndex(i, 0);
 739      int i4 = index /(X3*X2*X1);
 740      int i3 = (index - i4*(X3*X2*X1))/(X2*X1);
 741      int i2 = (index - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
 742      int i1 = index - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;
 743      int sign=1;
 744
 745      if (d == 0) {
 746	if (i4 % 2 == 1){
 747	  sign= -1;
 748	}
 749      }
 750
 751      if (d == 1){
 752	if ((i4+i1) % 2 == 1){
 753	  sign= -1;
 754	}
 755      }
 756      if (d == 2){
 757	if ( (i4+i1+i2) % 2 == 1){
 758	  sign= -1;
 759	}
 760      }
 761
 762      for (int j=0;j < 6; j++){
 763	gauge[d][i*gaugeSiteSize + 12+ j] *= sign;
 764      }
 765    }
 766    //odd
 767    for (int i = 0; i < Vh; i++) {
 768      int index = fullLatticeIndex(i, 1);
 769      int i4 = index /(X3*X2*X1);
 770      int i3 = (index - i4*(X3*X2*X1))/(X2*X1);
 771      int i2 = (index - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
 772      int i1 = index - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;
 773      int sign=1;
 774
 775      if (d == 0) {
 776	if (i4 % 2 == 1){
 777	  sign= -1;
 778	}
 779      }
 780
 781      if (d == 1){
 782	if ((i4+i1) % 2 == 1){
 783	  sign= -1;
 784	}
 785      }
 786      if (d == 2){
 787	if ( (i4+i1+i2) % 2 == 1){
 788	  sign = -1;
 789	}
 790      }
 791
 792      for (int j=0;j < 6; j++){
 793	gauge[d][(Vh+i)*gaugeSiteSize + 12 + j] *= sign;
 794      }
 795    }
 796
 797  }
 798
 799  // Apply boundary conditions to temporal links
 800  if (param->t_boundary == QUDA_ANTI_PERIODIC_T) {
 801    for (int j = 0; j < Vh; j++) {
 802      int sign =1;
 803      if (j >= (X4-3)*X1h*X2*X3 ){
 804	sign= -1;
 805      }
 806
 807      for (int i = 0; i < 6; i++) {
 808	gauge[3][j*gaugeSiteSize+ 12+ i ] *= sign;
 809	gauge[3][(Vh+j)*gaugeSiteSize+12 +i] *= sign;
 810      }
 811    }
 812  }
 813}
 814
 815
 816
 817template <typename Float>
 818static void constructUnitGaugeField(Float **res, QudaGaugeParam *param) {
 819  Float *resOdd[4], *resEven[4];
 820  for (int dir = 0; dir < 4; dir++) {  
 821    resEven[dir] = res[dir];
 822    resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
 823  }
 824    
 825  for (int dir = 0; dir < 4; dir++) {
 826    for (int i = 0; i < Vh; i++) {
 827      for (int m = 0; m < 3; m++) {
 828	for (int n = 0; n < 3; n++) {
 829	  resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
 830	  resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
 831	  resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = (m==n) ? 1 : 0;
 832	  resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 0.0;
 833	}
 834      }
 835    }
 836  }
 837    
 838  applyGaugeFieldScaling(res, Vh, param);
 839}
 840
 841// normalize the vector a
 842template <typename Float>
 843static void normalize(complex<Float> *a, int len) {
 844  double sum = 0.0;
 845  for (int i=0; i<len; i++) sum += norm(a[i]);
 846  for (int i=0; i<len; i++) a[i] /= sqrt(sum);
 847}
 848
 849// orthogonalize vector b to vector a
 850template <typename Float>
 851static void orthogonalize(complex<Float> *a, complex<Float> *b, int len) {
 852  complex<double> dot = 0.0;
 853  for (int i=0; i<len; i++) dot += conj(a[i])*b[i];
 854  for (int i=0; i<len; i++) b[i] -= (complex<Float>)dot*a[i];
 855}
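
// Example (sketch): the normalize/orthogonalize pair implements one
// Gram-Schmidt step. The hypothetical 3-vectors below end up unit-length and
// mutually orthogonal, the same pattern applied to SU(3) rows in the
// construct*GaugeField() routines that follow.
static void exampleGramSchmidt()
{
  complex<double> a[3] = {1.0, 1.0, 0.0};
  complex<double> b[3] = {1.0, 0.0, 1.0};
  normalize(a, 3);
  orthogonalize(a, b, 3);
  normalize(b, 3);
  complex<double> overlap = 0.0;
  for (int i = 0; i < 3; i++) overlap += conj(a[i])*b[i];
  printf("post Gram-Schmidt overlap = (%e, %e)\n", real(overlap), imag(overlap));
}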
 856
 857template <typename Float> 
 858static void constructGaugeField(Float **res, QudaGaugeParam *param) {
 859  Float *resOdd[4], *resEven[4];
 860  for (int dir = 0; dir < 4; dir++) {  
 861    resEven[dir] = res[dir];
 862    resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
 863  }
 864    
 865  for (int dir = 0; dir < 4; dir++) {
 866    for (int i = 0; i < Vh; i++) {
 867      for (int m = 1; m < 3; m++) { // last 2 rows
 868	for (int n = 0; n < 3; n++) { // 3 columns
 869	  resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
 870	  resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
 871	  resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
 872	  resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;                    
 873	}
 874      }
 875      normalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), 3);
 876      orthogonalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), (complex<Float>*)(resEven[dir] + (i*3+2)*3*2), 3);
 877      normalize((complex<Float>*)(resEven[dir] + (i*3 + 2)*3*2), 3);
 878      
 879      normalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), 3);
 880      orthogonalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), (complex<Float>*)(resOdd[dir] + (i*3+2)*3*2), 3);
 881      normalize((complex<Float>*)(resOdd[dir] + (i*3 + 2)*3*2), 3);
 882
 883      {
 884	Float *w = resEven[dir]+(i*3+0)*3*2;
 885	Float *u = resEven[dir]+(i*3+1)*3*2;
 886	Float *v = resEven[dir]+(i*3+2)*3*2;
 887	
 888	for (int n = 0; n < 6; n++) w[n] = 0.0;
 889	accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
 890	accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
 891	accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
 892	accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
 893	accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
 894	accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
 895      }
 896
 897      {
 898	Float *w = resOdd[dir]+(i*3+0)*3*2;
 899	Float *u = resOdd[dir]+(i*3+1)*3*2;
 900	Float *v = resOdd[dir]+(i*3+2)*3*2;
 901	
 902	for (int n = 0; n < 6; n++) w[n] = 0.0;
 903	accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
 904	accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
 905	accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
 906	accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
 907	accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
 908	accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
 909      }
 910
 911    }
 912  }
 913
 914  if (param->type == QUDA_WILSON_LINKS){  
 915    applyGaugeFieldScaling(res, Vh, param);
 916  } else if (param->type == QUDA_ASQTAD_LONG_LINKS){
 917    applyGaugeFieldScaling_long(res, Vh, param);      
 918  } else if (param->type == QUDA_ASQTAD_FAT_LINKS){
 919    for (int dir = 0; dir < 4; dir++){ 
 920      for (int i = 0; i < Vh; i++) {
  921	for (int m = 0; m < 3; m++) { // all three rows
 922	  for (int n = 0; n < 3; n++) { // 3 columns
 923	    resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] =1.0* rand() / (Float)RAND_MAX;
 924	    resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] =2.0* rand() / (Float)RAND_MAX;
 925	    resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = 3.0*rand() / (Float)RAND_MAX;
 926	    resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = 4.0*rand() / (Float)RAND_MAX;
 927	  }
 928	}
 929      }
 930    }
 931    
 932  }
 933
 934}
 935
 936template <typename Float> 
 937void constructUnitaryGaugeField(Float **res) 
 938{
 939  Float *resOdd[4], *resEven[4];
 940  for (int dir = 0; dir < 4; dir++) {  
 941    resEven[dir] = res[dir];
 942    resOdd[dir]  = res[dir]+Vh*gaugeSiteSize;
 943  }
 944  
 945  for (int dir = 0; dir < 4; dir++) {
 946    for (int i = 0; i < Vh; i++) {
 947      for (int m = 1; m < 3; m++) { // last 2 rows
 948	for (int n = 0; n < 3; n++) { // 3 columns
 949	  resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
 950	  resEven[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;
 951	  resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 0] = rand() / (Float)RAND_MAX;
 952	  resOdd[dir][i*(3*3*2) + m*(3*2) + n*(2) + 1] = rand() / (Float)RAND_MAX;                    
 953	}
 954      }
 955      normalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), 3);
 956      orthogonalize((complex<Float>*)(resEven[dir] + (i*3+1)*3*2), (complex<Float>*)(resEven[dir] + (i*3+2)*3*2), 3);
 957      normalize((complex<Float>*)(resEven[dir] + (i*3 + 2)*3*2), 3);
 958      
 959      normalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), 3);
 960      orthogonalize((complex<Float>*)(resOdd[dir] + (i*3+1)*3*2), (complex<Float>*)(resOdd[dir] + (i*3+2)*3*2), 3);
 961      normalize((complex<Float>*)(resOdd[dir] + (i*3 + 2)*3*2), 3);
 962
 963      {
 964	Float *w = resEven[dir]+(i*3+0)*3*2;
 965	Float *u = resEven[dir]+(i*3+1)*3*2;
 966	Float *v = resEven[dir]+(i*3+2)*3*2;
 967	
 968	for (int n = 0; n < 6; n++) w[n] = 0.0;
 969	accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
 970	accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
 971	accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
 972	accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
 973	accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
 974	accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
 975      }
 976      
 977      {
 978	Float *w = resOdd[dir]+(i*3+0)*3*2;
 979	Float *u = resOdd[dir]+(i*3+1)*3*2;
 980	Float *v = resOdd[dir]+(i*3+2)*3*2;
 981	
 982	for (int n = 0; n < 6; n++) w[n] = 0.0;
 983	accumulateConjugateProduct(w+0*(2), u+1*(2), v+2*(2), +1);
 984	accumulateConjugateProduct(w+0*(2), u+2*(2), v+1*(2), -1);
 985	accumulateConjugateProduct(w+1*(2), u+2*(2), v+0*(2), +1);
 986	accumulateConjugateProduct(w+1*(2), u+0*(2), v+2*(2), -1);
 987	accumulateConjugateProduct(w+2*(2), u+0*(2), v+1*(2), +1);
 988	accumulateConjugateProduct(w+2*(2), u+1*(2), v+0*(2), -1);
 989      }
 990      
 991    }
 992  }
 993}
 994
 995
 996void construct_gauge_field(void **gauge, int type, QudaPrecision precision, QudaGaugeParam *param) {
 997  if (type == 0) {
 998    if (precision == QUDA_DOUBLE_PRECISION) constructUnitGaugeField((double**)gauge, param);
 999    else constructUnitGaugeField((float**)gauge, param);
1000  } else if (type == 1) {
1001    if (precision == QUDA_DOUBLE_PRECISION) constructGaugeField((double**)gauge, param);
1002    else constructGaugeField((float**)gauge, param);
1003  } else {
1004    if (precision == QUDA_DOUBLE_PRECISION) applyGaugeFieldScaling((double**)gauge, Vh, param);
1005    else applyGaugeFieldScaling((float**)gauge, Vh, param);    
1006  }
1007
1008}
1009
1010void
1011construct_fat_long_gauge_field(void **fatlink, void** longlink,  
1012			       int type, QudaPrecision precision, QudaGaugeParam* param)
1013{
1014  if (type == 0) {
1015    if (precision == QUDA_DOUBLE_PRECISION) {
1016      constructUnitGaugeField((double**)fatlink, param);
1017      constructUnitGaugeField((double**)longlink, param);
1018    }else {
1019      constructUnitGaugeField((float**)fatlink, param);
1020      constructUnitGaugeField((float**)longlink, param);
1021    }
1022  } else {
1023    if (precision == QUDA_DOUBLE_PRECISION) {
1024      param->type = QUDA_ASQTAD_FAT_LINKS;
1025      constructGaugeField((double**)fatlink, param);
1026      param->type = QUDA_ASQTAD_LONG_LINKS;
1027      constructGaugeField((double**)longlink, param);
1028    }else {
1029      param->type = QUDA_ASQTAD_FAT_LINKS;
1030      constructGaugeField((float**)fatlink, param);
1031      param->type = QUDA_ASQTAD_LONG_LINKS;
1032      constructGaugeField((float**)longlink, param);
1033    }
1034  }
1035}
1036
1037
1038template <typename Float>
1039static void constructCloverField(Float *res, double norm, double diag) {
1040
1041  Float c = 2.0 * norm / RAND_MAX;
1042
1043  for(int i = 0; i < V; i++) {
1044    for (int j = 0; j < 72; j++) {
1045      res[i*72 + j] = c*rand() - norm;
1046    }
1047    for (int j = 0; j< 6; j++) {
1048      res[i*72 + j] += diag;
1049      res[i*72 + j+36] += diag;
1050    }
1051  }
1052}
1053
1054void construct_clover_field(void *clover, double norm, double diag, QudaPrecision precision) {
1055
1056  if (precision == QUDA_DOUBLE_PRECISION) constructCloverField((double *)clover, norm, diag);
1057  else constructCloverField((float *)clover, norm, diag);
1058}
1059
1060/*void strong_check(void *spinorRef, void *spinorGPU, int len, QudaPrecision prec) {
1061  printf("Reference:\n");
1062  printSpinorElement(spinorRef, 0, prec); printf("...\n");
1063  printSpinorElement(spinorRef, len-1, prec); printf("\n");    
1064    
1065  printf("\nCUDA:\n");
1066  printSpinorElement(spinorGPU, 0, prec); printf("...\n");
1067  printSpinorElement(spinorGPU, len-1, prec); printf("\n");
1068
1069  compare_spinor(spinorRef, spinorGPU, len, prec);
1070  }*/
1071
1072template <typename Float>
1073static void checkGauge(Float **oldG, Float **newG, double epsilon) {
1074
1075  const int fail_check = 17;
1076  int fail[4][fail_check];
1077  int iter[4][18];
1078  for (int d=0; d<4; d++) for (int i=0; i<fail_check; i++) fail[d][i] = 0;
1079  for (int d=0; d<4; d++) for (int i=0; i<18; i++) iter[d][i] = 0;
1080
1081  for (int d=0; d<4; d++) {
1082    for (int eo=0; eo<2; eo++) {
1083      for (int i=0; i<Vh; i++) {
1084	int ga_idx = (eo*Vh+i);
1085	for (int j=0; j<18; j++) {
1086	  double diff = fabs(newG[d][ga_idx*18+j] - oldG[d][ga_idx*18+j]);/// fabs(oldG[d][ga_idx*18+j]);
1087
1088	  for (int f=0; f<fail_check; f++) if (diff > pow(10.0,-(f+1))) fail[d][f]++;
1089	  if (diff > epsilon) iter[d][j]++;
1090	}
1091      }
1092    }
1093  }
1094
1095  printf("Component fails (X, Y, Z, T)\n");
1096  for (int i=0; i<18; i++) printf("%d fails = (%8d, %8d, %8d, %8d)\n", i, iter[0][i], iter[1][i], iter[2][i], iter[3][i]);
1097
1098  printf("\nDeviation Failures = (X, Y, Z, T)\n");
1099  for (int f=0; f<fail_check; f++) {
1100    printf("%e Failures = (%9d, %9d, %9d, %9d) = (%6.5f, %6.5f, %6.5f, %6.5f)\n", pow(10.0,-(f+1)), 
1101	   fail[0][f], fail[1][f], fail[2][f], fail[3][f],
1102	   fail[0][f]/(double)(V*18), fail[1][f]/(double)(V*18), fail[2][f]/(double)(V*18), fail[3][f]/(double)(V*18));
1103  }
1104
1105}
1106
1107void check_gauge(void **oldG, void **newG, double epsilon, QudaPrecision precision) {
1108  if (precision == QUDA_DOUBLE_PRECISION) 
1109    checkGauge((double**)oldG, (double**)newG, epsilon);
1110  else 
1111    checkGauge((float**)oldG, (float**)newG, epsilon);
1112}
1113
1114
1115
1116void 
1117createSiteLinkCPU(void** link,  QudaPrecision precision, int phase) 
1118{
1119    
1120  if (precision == QUDA_DOUBLE_PRECISION) {
1121    constructUnitaryGaugeField((double**)link);
1122  }else {
1123    constructUnitaryGaugeField((float**)link);
1124  }
1125
1126  // only apply temporal boundary condition if I'm the last node in T
1127#ifdef MULTI_GPU
1128  bool last_node_in_t = (commCoords(3) == commDim(3)-1) ? true : false;
1129#else
1130  bool last_node_in_t = true;
1131#endif
1132
1133  if(phase){
1134	
1135    for(int i=0;i < V;i++){
1136      for(int dir =XUP; dir <= TUP; dir++){
1137	int idx = i;
1138	int oddBit =0;
1139	if (i >= Vh) {
1140	  idx = i - Vh;
1141	  oddBit = 1;
1142	}
1143
1144	int X1 = Z[0];
1145	int X2 = Z[1];
1146	int X3 = Z[2];
1147	int X4 = Z[3];
1148
1149	int full_idx = fullLatticeIndex(idx, oddBit);
1150	int i4 = full_idx /(X3*X2*X1);
1151	int i3 = (full_idx - i4*(X3*X2*X1))/(X2*X1);
1152	int i2 = (full_idx - i4*(X3*X2*X1) - i3*(X2*X1))/X1;
1153	int i1 = full_idx - i4*(X3*X2*X1) - i3*(X2*X1) - i2*X1;	    
1154
1155	double coeff= 1.0;
1156	switch(dir){
1157	case XUP:
1158	  if ( (i4 & 1) != 0){
1159	    coeff *= -1;
1160	  }
1161	  break;
1162
1163	case YUP:
1164	  if ( ((i4+i1) & 1) != 0){
1165	    coeff *= -1;
1166	  }
1167	  break;
1168
1169	case ZUP:
1170	  if ( ((i4+i1+i2) & 1) != 0){
1171	    coeff *= -1;
1172	  }
1173	  break;
1174		
1175	case TUP:
1176	  if (last_node_in_t && i4 == (X4-1)){
1177	    coeff *= -1;
1178	  }
1179	  break;
1180
1181	default:
1182	  printf("ERROR: wrong dir(%d)\n", dir);
1183	  exit(1);
1184	}
1185	
1186	if (precision == QUDA_DOUBLE_PRECISION){
1187	  //double* mylink = (double*)link;
1188	  //mylink = mylink + (4*i + dir)*gaugeSiteSize;
1189	  double* mylink = (double*)link[dir];
1190	  mylink = mylink + i*gaugeSiteSize;
1191
1192	  mylink[12] *= coeff;
1193	  mylink[13] *= coeff;
1194	  mylink[14] *= coeff;
1195	  mylink[15] *= coeff;
1196	  mylink[16] *= coeff;
1197	  mylink[17] *= coeff;
1198		
1199	}else{
1200	  //float* mylink = (float*)link;
1201	  //mylink = mylink + (4*i + dir)*gaugeSiteSize;
1202	  float* mylink = (float*)link[dir];
1203	  mylink = mylink + i*gaugeSiteSize;
1204		  
1205	  mylink[12] *= coeff;
1206	  mylink[13] *= coeff;
1207	  mylink[14] *= coeff;
1208	  mylink[15] *= coeff;
1209	  mylink[16] *= coeff;
1210	  mylink[17] *= coeff;
1211		
1212	}
1213      }
1214    }
1215  }    
1216
1217    
1218#if 1
1219  for(int dir= 0;dir < 4;dir++){
1220    for(int i=0;i< V*gaugeSiteSize;i++){
1221      if (precision ==QUDA_SINGLE_PRECISION){
1222	float* f = (float*)link[dir];
1223	if (f[i] != f[i] || (fabsf(f[i]) > 1.e+3) ){
1224	  fprintf(stderr, "ERROR:  %dth: bad number(%f) in function %s \n",i, f[i], __FUNCTION__);
1225	  exit(1);
1226	}
1227      }else{
1228	double* f = (double*)link[dir];
1229	if (f[i] != f[i] || (fabs(f[i]) > 1.e+3)){
1230	  fprintf(stderr, "ERROR:  %dth: bad number(%f) in function %s \n",i, f[i], __FUNCTION__);
1231	  exit(1);
1232	}
1233	  
1234      }
1235	
1236    }
1237  }
1238#endif
1239
1240  return;
1241}
1242
1243
1244
1245
1246template <typename Float>
1247int compareLink(Float **linkA, Float **linkB, int len) {
1248  const int fail_check = 16;
1249  int fail[fail_check];
1250  for (int f=0; f<fail_check; f++) fail[f] = 0;
1251
1252  int iter[18];
1253  for (int i=0; i<18; i++) iter[i] = 0;
1254  
1255  for(int dir=0;dir < 4; dir++){
1256    for (int i=0; i<len; i++) {
1257      for (int j=0; j<18; j++) {
1258	int is = i*18+j;
1259	double diff = fabs(linkA[dir][is]-linkB[dir][is]);
1260	for (int f=0; f<fail_check; f++)
1261	  if (diff > pow(10.0,-(f+1))) fail[f]++;
1262	//if (diff > 1e-1) printf("%d %d %e\n", i, j, diff);
1263	if (diff > 1e-3) iter[j]++;
1264      }
1265    }
1266  }
1267  
1268  for (int i=0; i<18; i++) printfQuda("%d fails = %d\n", i, iter[i]);
1269  
1270  int accuracy_level = 0;
1271  for(int f =0; f < fail_check; f++){
1272    if(fail[f] == 0){
1273      accuracy_level =f;
1274    }
1275  }
1276
1277  for (int f=0; f<fail_check; f++) {
1278    printfQuda("%e Failures: %d / %d  = %e\n", pow(10.0,-(f+1)), fail[f], 4*len*18, fail[f] / (double)(4*len*18));
1279  }
1280  
1281  return accuracy_level;
1282}
1283
1284static int
1285compare_link(void **linkA, void **linkB, int len, QudaPrecision precision)
1286{
1287  int ret;
1288
1289  if (precision == QUDA_DOUBLE_PRECISION){    
1290    ret = compareLink((double**)linkA, (double**)linkB, len);
1291  }else {
1292    ret = compareLink((float**)linkA, (float**)linkB, len);
1293  }    
1294
1295  return ret;
1296}
1297
1298
1299// X indexes the lattice site
1300static void 
1301printLinkElement(void *link, int X, QudaPrecision precision) 
1302{
1303  if (precision == QUDA_DOUBLE_PRECISION){
1304    for(int i=0; i < 3;i++){
1305      printVector((double*)link+ X*gaugeSiteSize + i*6);
1306    }
1307	
1308  }
1309  else{
1310    for(int i=0;i < 3;i++){
1311      printVector((float*)link+X*gaugeSiteSize + i*6);
1312    }
1313  }
1314}
1315
1316int strong_check_link(void** linkA, const char* msgA, 
1317		      void **linkB, const char* msgB, 
1318		      int len, QudaPrecision prec) 
1319{
1320  printfQuda("%s\n", msgA);
1321  printLinkElement(linkA[0], 0, prec); 
1322  printfQuda("\n");
1323  printLinkElement(linkA[0], 1, prec); 
1324  printfQuda("...\n");
1325  printLinkElement(linkA[3], len-1, prec); 
1326  printfQuda("\n");    
1327    
1328  printfQuda("\n%s\n", msgB);
1329  printLinkElement(linkB[0], 0, prec); 
1330  printfQuda("\n");
1331  printLinkElement(linkB[0], 1, prec); 
1332  printfQuda("...\n");
1333  printLinkElement(linkB[3], len-1, prec); 
1334  printfQuda("\n");
1335    
1336  int ret = compare_link(linkA, linkB, len, prec);
1337  return ret;
1338}
1339
1340
1341void 
1342createMomCPU(void* mom,  QudaPrecision precision) 
1343{
1344  void* temp;
1345    
1346  size_t gSize = (precision == QUDA_DOUBLE_PRECISION) ? sizeof(double) : sizeof(float);
1347  temp = malloc(4*V*gaugeSiteSize*gSize);
1348  if (temp == NULL){
1349    fprintf(stderr, "Error: malloc failed for temp in function %s\n", __FUNCTION__);
1350    exit(1);
1351  }
1352    
1353    
1354    
1355  for(int i=0;i < V;i++){
1356    if (precision == QUDA_DOUBLE_PRECISION){
1357      for(int dir=0;dir < 4;dir++){
1358	double* thismom = (double*)mom;	    
1359	for(int k=0; k < momSiteSize; k++){
1360	  thismom[ (4*i+dir)*momSiteSize + k ]= 1.0* rand() /RAND_MAX;				
1361	  if (k==momSiteSize-1) thismom[ (4*i+dir)*momSiteSize + k ]= 0.0;
1362	}	    
1363      }	    
1364    }else{
1365      for(int dir=0;dir < 4;dir++){
1366	float* thismom=(float*)mom;
1367	for(int k=0; k < momSiteSize; k++){
1368	  thismom[ (4*i+dir)*momSiteSize + k ]= 1.0* rand() /RAND_MAX;		
1369	  if (k==momSiteSize-1) thismom[ (4*i+dir)*momSiteSize + k ]= 0.0;
1370	}	    
1371      }
1372    }
1373  }
1374    
1375  free(temp);
1376  return;
1377}
1378
1379void
1380createHwCPU(void* hw,  QudaPrecision precision)
1381{
1382  for(int i=0;i < V;i++){
1383    if (precision == QUDA_DOUBLE_PRECISION){
1384      for(int dir=0;dir < 4;dir++){
1385	double* thishw = (double*)hw;
1386	for(int k=0; k < hwSiteSize; k++){
1387	  thishw[ (4*i+dir)*hwSiteSize + k ]= 1.0* rand() /RAND_MAX;
1388	}
1389      }
1390    }else{
1391      for(int dir=0;dir < 4;dir++){
1392	float* thishw=(float*)hw;
1393	for(int k=0; k < hwSiteSize; k++){
1394	  thishw[ (4*i+dir)*hwSiteSize + k ]= 1.0* rand() /RAND_MAX;
1395	}
1396      }
1397    }
1398  }
1399
1400  return;
1401}
1402
1403
1404template <typename Float>
1405int compare_mom(Float *momA, Float *momB, int len) {
1406  const int fail_check = 16;
1407  int fail[fail_check];
1408  for (int f=0; f<fail_check; f++) fail[f] = 0;
1409
1410  int iter[momSiteSize];
1411  for (int i=0; i<momSiteSize; i++) iter[i] = 0;
1412  
1413  for (int i=0; i<len; i++) {
1414    for (int j=0; j<momSiteSize; j++) {
1415      int is = i*momSiteSize+j;
1416      double diff = fabs(momA[is]-momB[is]);
1417      for (int f=0; f<fail_check; f++)
1418	if (diff > pow(10.0,-(f+1))) fail[f]++;
1419      //if (diff > 1e-1) printf("%d %d %e\n", i, j, diff);
1420      if (diff > 1e-3) iter[j]++;
1421    }
1422  }
1423  
1424  int accuracy_level = 0;
1425  for(int f =0; f < fail_check; f++){
1426    if(fail[f] == 0){
1427      accuracy_level =f+1;
1428    }
1429  }
1430
1431  for (int i=0; i<momSiteSize; i++) printfQuda("%d fails = %d\n", i, iter[i]);
1432  
1433  for (int f=0; f<fail_check; f++) {
1434    printfQuda("%e Failures: %d / %d  = %e\n", pow(10.0,-(f+1)), fail[f], len*9, fail[f]/(double)(len*9));
1435  }
1436  
1437  return accuracy_level;
1438}
1439
1440static void 
1441printMomElement(void *mom, int X, QudaPrecision precision) 
1442{
1443  if (precision == QUDA_DOUBLE_PRECISION){
1444    double* thismom = ((double*)mom)+ X*momSiteSize;
1445    printVector(thismom);
1446    printfQuda("(%9f,%9f) (%9f,%9f)\n", thismom[6], thismom[7], thismom[8], thismom[9]);
1447  }else{
1448    float* thismom = ((float*)mom)+ X*momSiteSize;
1449    printVector(thismom);
1450    printfQuda("(%9f,%9f) (%9f,%9f)\n", thismom[6], thismom[7], thismom[8], thismom[9]);	
1451  }
1452}
1453int strong_check_mom(void * momA, void *momB, int len, QudaPrecision prec) 
1454{    
1455  printfQuda("mom:\n");
1456  printMomElement(momA, 0, prec); 
1457  printfQuda("\n");
1458  printMomElement(momA, 1, prec); 
1459  printfQuda("\n");
1460  printMomElement(momA, 2, prec); 
1461  printfQuda("\n");
1462  printMomElement(momA, 3, prec); 
1463  printfQuda("...\n");
1464  
1465  printfQuda("\nreference mom:\n");
1466  printMomElement(momB, 0, prec); 
1467  printfQuda("\n");
1468  printMomElement(momB, 1, prec); 
1469  printfQuda("\n");
1470  printMomElement(momB, 2, prec); 
1471  printfQuda("\n");
1472  printMomElement(momB, 3, prec); 
1473  printfQuda("\n");
1474  
1475  int ret;
1476  if (prec == QUDA_DOUBLE_PRECISION){
1477    ret = compare_mom((double*)momA, (double*)momB, len);
1478  }else{
1479    ret = compare_mom((float*)momA, (float*)momB, len);
1480  }
1481  
1482  return ret;
1483}
1484
1485
1486/************
1487 * return value
1488 *
 1489 * 0: command line option matched and processed successfully
1490 * non-zero: command line option does not match
1491 *
1492 */
1493
1494#ifdef MULTI_GPU
1495int device = -1;
1496#else
1497int device = 0;
1498#endif
1499
1500QudaReconstructType link_recon = QUDA_RECONSTRUCT_NO;
1501QudaReconstructType link_recon_sloppy = QUDA_RECONSTRUCT_INVALID;
1502QudaPrecision prec = QUDA_SINGLE_PRECISION;
1503QudaPrecision  prec_sloppy = QUDA_INVALID_PRECISION;
1504int xdim = 24;
1505int ydim = 24;
1506int zdim = 24;
1507int tdim = 24;
1508int Lsdim = 16;
1509QudaDagType dagger = QUDA_DAG_NO;
1510int gridsize_from_cmdline[4] = {1,1,1,1};
1511QudaDslashType dslash_type = QUDA_WILSON_DSLASH;
1512char latfile[256] = "";
1513bool tune = true;
1514int niter = 10;
1515int test_type = 0;
1516bool verify_results = true;
1517
1518static int dim_partitioned[4] = {0,0,0,0};
1519
1520int dimPartitioned(int dim)
1521{
1522  return ((gridsize_from_cmdline[dim] > 1) || dim_partitioned[dim]);
1523}
1524
1525
1526void __attribute__((weak)) usage_extra(char** argv){};
1527
1528void usage(char** argv )
1529{
1530  printf("Usage: %s [options]\n", argv[0]);
1531  printf("Common options: \n");
1532#ifndef MULTI_GPU
1533  printf("    --device <n>                              # Set the CUDA device to use (default 0, single GPU only)\n");     
1534#endif
1535  printf("    --prec <double/single/half>               # Precision in GPU\n"); 
1536  printf("    --prec_sloppy <double/single/half>        # Sloppy precision in GPU\n"); 
1537  printf("    --recon <8/9/12/13/18>                         # Link reconstruction type\n"); 
1538  printf("    --recon_sloppy <8/9/12/13/18>                  # Sloppy link reconstruction type\n"); 
1539  printf("    --dagger                                  # Set the dagger to 1 (default 0)\n"); 
1540  printf("    --sdim <n>                                # Set space dimention(X/Y/Z) size\n"); 
1541  printf("    --xdim <n>                                # Set X dimension size(default 24)\n");     
1542  printf("    --ydim <n>                                # Set X dimension size(default 24)\n");     
1543  printf("    --zdim <n>                                # Set X dimension size(default 24)\n");     
1544  printf("    --tdim <n>                                # Set T dimension size(default 24)\n");  
1545  printf("    --Lsdim <n>                                # Set Ls dimension size(default 16)\n");  
1546  printf("    --xgridsize <n>                           # Set grid size in X dimension (default 1)\n");
1547  printf("    --ygridsize <n>                           # Set grid size in Y dimension (default 1)\n");
1548  printf("    --zgridsize <n>                           # Set grid size in Z dimension (default 1)\n");
1549  printf("    --tgridsize <n>                           # Set grid size in T dimension (default 1)\n");
1550  printf("    --partition <mask>                        # Set the communication topology (X=1, Y=2, Z=4, T=8, and combinations of these)\n");
1551  printf("    --kernel_pack_t                           # Set T dimension kernel packing to be true (default false)\n");
1552  printf("    --dslash_type <type>                      # Set the dslash type, the following values are valid\n"
1553	 "                                                  wilson/clover/twisted_mass/asqtad/domain_wall\n");
1554  printf("    --load-gauge file                         # Load gauge field \"file\" for the test (requires QIO)\n");
1555  printf("    --niter <n>                               # The number of iterations to perform (default 10)\n");
1556  printf("    --tune <true/false>                       # Whether to autotune or not (default true)\n");     
1557  printf("    --test                                    # Test method (different for each test)\n");
1558  printf("    --verify <true/false>                     # Verify the GPU results using CPU results (default true)\n");
1559  printf("    --help                                    # Print out this message\n"); 
1560  usage_extra(argv); 
1561#ifdef MULTI_GPU
1562  char msg[]="multi";
1563#else
1564  char msg[]="single";
1565#endif  
1566  printf("Note: this program is %s GPU build\n", msg);
1567  exit(1);
1568  return ;
1569}
1570
1571int process_command_line_option(int argc, char** argv, int* idx)
1572{
1573#ifdef MULTI_GPU
1574  char msg[]="multi";
1575#else
1576  char msg[]="single";
1577#endif
1578
1579  int ret = -1;
1580  
1581  int i = *idx;
1582
1583  if( strcmp(argv[i], "--help")== 0){
1584    usage(argv);
1585  }
1586
1587  if( strcmp(argv[i], "--verify") == 0){
1588    if (i+1 >= argc){
1589      usage(argv);
1590    }	    
1591
1592    if (strcmp(argv[i+1], "true") == 0){
1593      verify_results = true;
1594    }else if (strcmp(argv[i+1], "false") == 0){
1595      verify_results = false;
1596    }else{
1597      fprintf(stderr, "ERROR: invalid verify type\n");	
1598      exit(1);
1599    }
1600
1601    i++;
1602    ret = 0;
1603    goto out;
1604  }
1605  
1606  if( strcmp(argv[i], "--device") == 0){
1607    if (i+1 >= argc){
1608      usage(argv);
1609    }
1610    device = atoi(argv[i+1]);
1611    if (device < 0 || device > 16){
1612      printf("ERROR: Invalid CUDA device number (%d)\n", device);
1613      usage(argv);
1614    }
1615    i++;
1616    ret = 0;
1617    goto out;
1618  }
1619
1620  if( strcmp(argv[i], "--prec") == 0){
1621    if (i+1 >= argc){
1622      usage(argv);
1623    }	    
1624    prec =  get_prec(argv[i+1]);
1625    i++;
1626    ret = 0;
1627    goto out;
1628  }
1629
1630  if( strcmp(argv[i], "--prec_sloppy") == 0){
1631    if (i+1 >= argc){
1632      usage(argv);
1633    }	    
1634    prec_sloppy =  get_prec(argv[i+1]);
1635    i++;
1636    ret = 0;
1637    goto out;
1638  }
1639  
1640  if( strcmp(argv[i], "--recon") == 0){
1641    if (i+1 >= argc){
1642      usage(argv);
1643    }	    
1644    link_recon =  get_recon(argv[i+1]);
1645    i++;
1646    ret = 0;
1647    goto out;
1648  }
1649
1650  if( strcmp(argv[i], "--recon_sloppy") == 0){
1651    if (i+1 >= argc){
1652      usage(argv);
1653    }	    
1654    link_recon_sloppy =  get_recon(argv[i+1]);
1655    i++;
1656    ret = 0;
1657    goto out;
1658  }
1659  
1660  if( strcmp(argv[i], "--xdim") == 0){
1661    if (i+1 >= argc){
1662      usage(argv);
1663    }
1664    xdim= atoi(argv[i+1]);
1665    if (xdim < 0 || xdim > 512){
1666      printf("ERROR: invalid X dimension (%d)\n", xdim);
1667      usage(argv);
1668    }
1669    i++;
1670    ret = 0;
1671    goto out;
1672  }
1673
1674  if( strcmp(argv[i], "--ydim") == 0){
1675    if (i+1 >= argc){
1676      usage(argv);
1677    }
1678    ydim= atoi(argv[i+1]);
1679    if (ydim < 0 || ydim > 512){
1680      printf("ERROR: invalid T dimension (%d)\n", ydim);
1681      usage(argv);
1682    }
1683    i++;
1684    ret = 0;
1685    goto out;
1686  }
1687
1688
1689  if( strcmp(argv[i], "--zdim") == 0){
1690    if (i+1 >= argc){
1691      usage(argv);
1692    }
1693    zdim= atoi(argv[i+1]);
1694    if (zdim < 0 || zdim > 512){
1695      printf("ERROR: invalid T dimension (%d)\n", zdim);
1696      usage(argv);
1697    }
1698    i++;
1699    ret = 0;
1700    goto out;
1701  }
1702
1703  if( strcmp(argv[i], "--tdim") == 0){
1704    if (i+1 >= argc){
1705      usage(argv);
1706    }	    
1707    tdim =  atoi(argv[i+1]);
1708    if (tdim < 0 || tdim > 512){
1709      errorQuda("Error: invalid t dimension");
1710    }
1711    i++;
1712    ret = 0;
1713    goto out;
1714  }
1715
1716  if( strcmp(argv[i], "--sdim") == 0){
1717    if (i+1 >= argc){
1718      usage(argv);
1719    }	    
1720    int sdim =  atoi(argv[i+1]);
1721    if (sdim < 0 || sdim > 512){
1722      printfQuda("ERROR: invalid S dimension\n");
1723    }
1724    xdim=ydim=zdim=sdim;
1725    i++;
1726    ret = 0;
1727    goto out;
1728  }
1729  
1730  if( strcmp(argv[i], "--Lsdim") == 0){
1731    if (i+1 >= argc){
1732      usage(argv);
1733    }	    
1734    int Ls =  atoi(argv[i+1]);
1735    if (Ls < 0 || Ls > 128){
1736      printfQuda("ERROR: invalid Ls dimension\n");
1737    }
1738    Lsdim=Ls;
1739    i++;
1740    ret = 0;
1741    goto out;
1742  }
1743  
1744  if( strcmp(argv[i], "--dagger") == 0){
1745    dagger = QUDA_DAG_YES;
1746    ret = 0;
1747    goto out;
1748  }	
1749  
1750  if( strcmp(argv[i], "--partition") == 0){
1751    if (i+1 >= argc){
1752      usage(argv);
1753    }
1754#ifdef MULTI_GPU
1755    int value  =  atoi(argv[i+1]);
1756    for(int j=0; j < 4;j++){
1757      if (value &  (1 << j)){
1758	commDimPartitionedSet(j);
1759	dim_partitioned[j] = 1;
1760      }
1761    }
1762#else
1763    printfQuda("WARNING: Ignoring --partition option since this is a single-GPU build.\n");
1764#endif
1765    i++;
1766    ret = 0;
1767    goto out;
1768  }
1769  
1770  if( strcmp(argv[i], "--kernel_pack_t") == 0){
1771    quda::setKernelPackT(true);
1772    ret= 0;
1773    goto out;
1774  }
1775
1776
1777  if( strcmp(argv[i], "--tune") == 0){
1778    if (i+1 >= argc){
1779      usage(argv);
1780    }	    
1781
1782    if (strcmp(argv[i+1], "true") == 0){
1783      tune = true;
1784    }else if (strcmp(argv[i+1], "false") == 0){
1785      tune = false;
1786    }else{
1787      fprintf(stderr, "ERROR: invalid tuning type\n");	
1788      exit(1);
1789    }
1790
1791    i++;
1792    ret = 0;
1793    goto out;
1794  }
1795
1796  if( strcmp(argv[i], "--xgridsize") == 0){
1797    if (i+1 >= argc){ 
1798      usage(argv);
1799    }     
1800    int xsize =  atoi(argv[i+1]);
1801    if (xsize <= 0 ){
1802      errorQuda("ERROR: invalid X grid size");
1803    }
1804    gridsize_from_cmdline[0] = xsize;
1805    i++;
1806    ret = 0;
1807    goto out;
1808  }
1809
1810  if( strcmp(argv[i], "--ygridsize") == 0){
1811    if (i+1 >= argc){
1812      usage(argv);
1813    }     
1814    int ysize =  atoi(argv[i+1]);
1815    if (ysize <= 0 ){
1816      errorQuda("ERROR: invalid Y grid size");
1817    }
1818    gridsize_from_cmdline[1] = ysize;
1819    i++;
1820    ret = 0;
1821    goto out;
1822  }
1823
1824  if( strcmp(argv[i], "--zgridsize") == 0){
1825    if (i+1 >= argc){
1826      usage(argv);
1827    }     
1828    int zsize =  atoi(argv[i+1]);
1829    if (zsize <= 0 ){
1830      errorQuda("ERROR: invalid Z grid size");
1831    }
1832    gridsize_from_cmdline[2] = zsize;
1833    i++;
1834    ret = 0;
1835    goto out;
1836  }
1837  
1838  if( strcmp(argv[i], "--tgridsize") == 0){
1839    if (i+1 >= argc){
1840      usage(argv);
1841    }     
1842    int tsize =  atoi(argv[i+1]);
1843    if (tsize <= 0 ){
1844      errorQuda("ERROR: invalid T grid size");
1845    }
1846    gridsize_from_cmdline[3] = tsize;
1847    i++;
1848    ret = 0;
1849    goto out;
1850  }
1851  
1852  if( strcmp(argv[i], "--dslash_type") == 0){
1853    if (i+1 >= argc){
1854      usage(argv);
1855    }     
1856    dslash_type =  get_dslash_type(argv[i+1]);
1857    i++;
1858    ret = 0;
1859    goto out;
1860  }
1861  
1862  if( strcmp(argv[i], "--load-gauge") == 0){
1863    if (i+1 >= argc){
1864      usage(argv);
1865    }     
1866    strcpy(latfile, argv[i+1]);
1867    i++;
1868    ret = 0;
1869    goto out;
1870  }
1871  
1872  if( strcmp(argv[i], "--test") == 0){
1873    if (i+1 >= argc){
1874      usage(argv);
1875    }	    
1876    test_type = atoi(argv[i+1]);
1877    i++;
1878    ret = 0;
1879    goto out;	    
1880  }
1881    
1882  if( strcmp(argv[i], "--niter") == 0){
1883    if (i+1 >= argc){
1884      usage(argv);
1885    }
1886    niter= atoi(argv[i+1]);
1887    if (niter < 1 || niter > 1e6){
1888      printf("ERROR: invalid number of iterations (%d)\n", niter);
1889      usage(argv);
1890    }
1891    i++;
1892    ret = 0;
1893    goto out;
1894  }
1895
1896  if( strcmp(argv[i], "--version") == 0){
1897    printf("This program is linked with QUDA library, version %s,", 
1898	   get_quda_ver_str());
1899    printf(" %s GPU build\n", msg);
1900    exit(0);
1901  }
1902
1903 out:
1904  *idx = i;
1905  return ret ;
1906
1907}
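
// Example (sketch): the parsing loop a test's main() typically wraps around
// process_command_line_option(). The function name and error text here are
// illustrative; the loop relies only on the return-value convention documented
// above (0 when the option was recognized and consumed).
static void exampleParseCommandLine(int argc, char **argv)
{
  for (int i = 1; i < argc; i++) {
    if (process_command_line_option(argc, argv, &i) == 0) continue;  // option consumed
    fprintf(stderr, "ERROR: unrecognized option %s\n", argv[i]);
    usage(argv);
  }
  initComms(argc, argv, gridsize_from_cmdline);  // then bring up the communications grid
}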
1908
1909
1910static struct timeval startTime;
1911
1912void stopwatchStart() {
1913  gettimeofday(&startTime, NULL);
1914}
1915
1916double stopwatchReadSeconds() {
1917  struct timeval endTime;
1918  gettimeofday(&endTime, NULL);
1919
1920  long ds = endTime.tv_sec - startTime.tv_sec;
1921  long dus = endTime.tv_usec - startTime.tv_usec;
1922  return ds + 0.000001*dus;
1923}
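
// Example (sketch): timing a region with the stopwatch helpers above; the loop
// being timed is only a placeholder for real work.
static void exampleStopwatch()
{
  stopwatchStart();
  double sum = 0.0;
  for (int i = 0; i < 1000000; i++) sum += i*1e-6;  // placeholder work
  double secs = stopwatchReadSeconds();
  printf("elapsed = %f s (sum = %f)\n", secs, sum);
}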
1924
1925