/c/lcao.c

https://gitlab.com/marcindulak/gpaw · C · 213 lines · 190 code · 12 blank · 11 comment · 48 complexity · ae6f106c54e8ef5ef34dc37aeb4a16a4 MD5 · raw file

  1. /* Copyright (C) 2003-2007 CAMP
  2. * Copyright (C) 2007-2009 CAMd
  3. * Please see the accompanying LICENSE file for further information. */
  4. #include "extensions.h"
  5. #include "localized_functions.h"
  6. #include "bmgs/bmgs.h"
  7. #include <complex.h>
  8. #ifdef GPAW_NO_UNDERSCORE_BLAS
  9. # define dgemv_ dgemv
  10. # define dgemm_ dgemm
  11. #endif
  12. int dgemv_(char *trans, int *m, int * n,
  13. double *alpha, double *a, int *lda,
  14. double *x, int *incx, double *beta,
  15. double *y, int *incy);
  16. int dgemm_(char *transa, char *transb, int *m, int * n,
  17. int *k, const double *alpha, double *a, int *lda,
  18. double *b, int *ldb, double *beta,
  19. double *c, int *ldc);
  20. // +-----------n
  21. // +----m +----m | +----c+m |
  22. // | | | | | | | |
  23. // | b | = | v | * | | a | |
  24. // | | | | | | | |
  25. // 0----+ 0----+ | c----+ |
  26. // | |
  27. // 0-----------+
  28. void cut(const double* a, const int n[3], const int c[3],
  29. const double* v,
  30. double* b, const int m[3])
  31. {
  32. a += c[2] + (c[1] + c[0] * n[1]) * n[2];
  33. for (int i0 = 0; i0 < m[0]; i0++)
  34. {
  35. for (int i1 = 0; i1 < m[1]; i1++)
  36. {
  37. for (int i2 = 0; i2 < m[2]; i2++)
  38. b[i2] = v[i2] * a[i2];
  39. a += n[2];
  40. b += m[2];
  41. v += m[2];
  42. }
  43. a += n[2] * (n[1] - m[1]);
  44. }
  45. }
  46. PyObject * overlap(PyObject* self, PyObject *args)
  47. {
  48. PyObject* lfs_b_obj;
  49. PyArrayObject* m_b_obj;
  50. PyArrayObject* phase_bk_obj;
  51. PyArrayObject* vt_sG_obj;
  52. PyArrayObject* Vt_skmm_obj;
  53. if (!PyArg_ParseTuple(args, "OOOOO", &lfs_b_obj, &m_b_obj, &phase_bk_obj,
  54. &vt_sG_obj, &Vt_skmm_obj))
  55. return NULL;
  56. int nk = PyArray_DIMS(phase_bk_obj)[1];
  57. int nm = PyArray_DIMS(Vt_skmm_obj)[2];
  58. int nspins = PyArray_DIMS(vt_sG_obj)[0];
  59. const long *m_b = LONGP(m_b_obj);
  60. const double complex *phase_bk = COMPLEXP(phase_bk_obj);
  61. const double *vt_sG = DOUBLEP(vt_sG_obj);
  62. double *Vt_smm = 0;
  63. double complex *Vt_skmm = 0;
  64. if (nk == 0)
  65. Vt_smm = DOUBLEP(Vt_skmm_obj);
  66. else
  67. Vt_skmm = COMPLEXP(Vt_skmm_obj);
  68. int nb = PyList_Size(lfs_b_obj);
  69. int nmem = 0;
  70. double* a1 = 0;
  71. for (int b1 = 0; b1 < nb; b1++)
  72. {
  73. const LocalizedFunctionsObject* lf1 =
  74. (const LocalizedFunctionsObject*)PyList_GetItem(lfs_b_obj, b1);
  75. int m1 = m_b[b1];
  76. int nao1 = lf1->nf;
  77. double* f1 = lf1->f;
  78. double* vt1 = GPAW_MALLOC(double, lf1->ng0 * nspins);
  79. for (int s = 0; s < nspins; s++)
  80. bmgs_cut(vt_sG + s * lf1->ng, lf1->size, lf1->start,
  81. vt1 + s * lf1->ng0, lf1->size0);
  82. for (int b2 = b1; b2 < nb; b2++)
  83. {
  84. const LocalizedFunctionsObject* lf2 =
  85. (const LocalizedFunctionsObject*)PyList_GetItem(lfs_b_obj, b2);
  86. int beg[3];
  87. int end[3];
  88. int size[3];
  89. int beg1[3];
  90. int beg2[3];
  91. bool overlap = true;
  92. for (int c = 0; c < 3; c++)
  93. {
  94. beg[c] = MAX(lf1->start[c], lf2->start[c]);
  95. end[c] = MIN(lf1->start[c] + lf1->size0[c],
  96. lf2->start[c] + lf2->size0[c]);
  97. size[c] = end[c] - beg[c];
  98. if (size[c] <= 0)
  99. {
  100. overlap = false;
  101. continue;
  102. }
  103. beg1[c] = beg[c] - lf1->start[c];
  104. beg2[c] = beg[c] - lf2->start[c];
  105. }
  106. int nao2 = lf2->nf;
  107. if (overlap)
  108. {
  109. int ng = size[0] * size[1] * size[2];
  110. int n = ng * (nao1 + nao2) + nao1 * nao2;
  111. if (n > nmem)
  112. {
  113. if (nmem != 0)
  114. free(a1);
  115. nmem = n;
  116. a1 = GPAW_MALLOC(double, nmem);
  117. }
  118. double* a2 = a1 + ng * nao1;
  119. double* H = a2 + ng * nao2;
  120. double* f2 = lf2->f;
  121. double* vt2 = lf2->w;
  122. double dv = lf1->dv;
  123. int m2 = m_b[b2];
  124. if (b2 > b1)
  125. for (int i = 0; i < nao2; i++)
  126. bmgs_cut(f2 + i * lf2->ng0, lf2->size0, beg2,
  127. a2 + i * ng, size);
  128. else
  129. a2 = f2;
  130. for (int s = 0; s < nspins; s++)
  131. {
  132. if (b2 > b1)
  133. {
  134. bmgs_cut(vt1 + s * lf1->ng0, lf1->size0, beg1, vt2, size);
  135. for (int i = 0; i < nao1; i++)
  136. cut(f1 + i * lf1->ng0, lf1->size0, beg1, vt2,
  137. a1 + i * ng, size);
  138. }
  139. else
  140. {
  141. for (int i1 = 0; i1 < nao1; i1++)
  142. for (int g = 0; g < ng; g++)
  143. a1[i1 * ng + g] = (vt1[g + s * lf1->ng0] *
  144. f1[i1 * ng + g]);
  145. }
  146. double zero = 0.0;
  147. dgemm_("t", "n", &nao2, &nao1, &ng, &dv,
  148. a2, &ng, a1, &ng, &zero, H, &nao2);
  149. if (nk == 0)
  150. {
  151. double* Vt_mm = (Vt_smm + s * nm * nm + m1 + m2 * nm);
  152. if (b2 == b1)
  153. for (int i1 = 0; i1 < nao1; i1++)
  154. for (int i2 = i1; i2 < nao2; i2++)
  155. Vt_mm[i1 + i2 * nm] += H[i2 + i1 * nao2];
  156. else if (m1 == m2)
  157. for (int i1 = 0; i1 < nao1; i1++)
  158. for (int i2 = i1; i2 < nao2; i2++)
  159. Vt_mm[i1 + i2 * nm] += (H[i2 + i1 * nao2] +
  160. H[i1 + i2 * nao2]);
  161. else
  162. for (int ii = 0, i1 = 0; i1 < nao1; i1++)
  163. for (int i2 = 0; i2 < nao2; i2++, ii++)
  164. Vt_mm[i1 + i2 * nm] += H[ii];
  165. }
  166. else
  167. for (int k = 0; k < nk; k++)
  168. {
  169. double complex* Vt_mm = (Vt_skmm +
  170. (s * nk + k) * nm * nm +
  171. m1 + m2 * nm);
  172. if (b2 == b1)
  173. for (int i1 = 0; i1 < nao1; i1++)
  174. for (int i2 = i1; i2 < nao2; i2++)
  175. Vt_mm[i1 + i2 * nm] += H[i2 + i1 * nao2];
  176. else
  177. {
  178. double complex phase = \
  179. (phase_bk[b1 * nk + k] *
  180. conj(phase_bk[b2 * nk + k]));
  181. if (m1 == m2)
  182. for (int i1 = 0; i1 < nao1; i1++)
  183. for (int i2 = i1; i2 < nao2; i2++)
  184. Vt_mm[i1 + i2 * nm] += \
  185. (phase * H[i2 + i1 * nao2] +
  186. conj(phase) * H[i1 + i2 * nao2]);
  187. else
  188. for (int ii = 0, i1 = 0; i1 < nao1; i1++)
  189. for (int i2 = 0; i2 < nao2; i2++, ii++)
  190. Vt_mm[i1 + i2 * nm] += phase * H[ii];
  191. }
  192. }
  193. }
  194. }
  195. }
  196. free(vt1);
  197. }
  198. if (nmem != 0)
  199. free(a1);
  200. Py_RETURN_NONE;
  201. }