/gpaw/utilities/blas.py

https://gitlab.com/5rXUTAIYJSBcdxFmjqpuwaPo71b283/gpaw · Python · 353 lines · 322 code · 17 blank · 14 comment · 33 complexity · 4c405142f0b7f681e69422d27972dfb4 MD5 · raw file

  1. # Copyright (C) 2003 CAMP
  2. # Please see the accompanying LICENSE file for further information.
  3. """
  4. Python wrapper functions for the ``C`` package:
  5. Basic Linear Algebra Subroutines (BLAS)
  6. See also:
  7. http://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms
  8. and
  9. http://www.netlib.org/lapack/lug/node145.html
  10. """
  11. from typing import TypeVar
  12. import numpy as np
  13. import scipy.linalg.blas as blas
  14. from gpaw import debug
  15. import _gpaw
  16. __all__ = ['mmm']
  17. T = TypeVar('T', float, complex)
  18. def mmm(alpha: T,
  19. a: np.ndarray,
  20. opa: str,
  21. b: np.ndarray,
  22. opb: str,
  23. beta: T,
  24. c: np.ndarray) -> None:
  25. """Matrix-matrix multiplication using dgemm or zgemm.
  26. For opa='n' and opb='n', we have::
  27. c <- alpha * a * b + beta * c.
  28. Use 't' to transpose matrices and 'c' to transpose and complex conjugate
  29. matrices.
  30. """
  31. assert opa in 'NTC'
  32. assert opb in 'NTC'
  33. if opa == 'N':
  34. a1, a2 = a.shape
  35. else:
  36. a2, a1 = a.shape
  37. if opb == 'N':
  38. b1, b2 = b.shape
  39. else:
  40. b2, b1 = b.shape
  41. assert a2 == b1
  42. assert c.shape == (a1, b2)
  43. assert a.strides[1] == b.strides[1] == c.strides[1] == c.itemsize
  44. assert a.dtype == b.dtype == c.dtype
  45. if a.dtype == float:
  46. assert not isinstance(alpha, complex)
  47. assert not isinstance(beta, complex)
  48. else:
  49. assert a.dtype == complex
  50. _gpaw.mmm(alpha, a, opa, b, opb, beta, c)
  51. def gemm(alpha, a, b, beta, c, transa='n'):
  52. """General Matrix Multiply.
  53. Performs the operation::
  54. c <- alpha * b.a + beta * c
  55. If transa is "n", ``b.a`` denotes the matrix multiplication defined by::
  56. _
  57. \
  58. (b.a) = ) b * a
  59. ijkl... /_ ip pjkl...
  60. p
  61. If transa is "t" or "c", ``b.a`` denotes the matrix multiplication
  62. defined by::
  63. _
  64. \
  65. (b.a) = ) b * a
  66. ij /_ iklm... jklm...
  67. klm...
  68. where in case of "c" also complex conjugate of a is taken.
  69. """
  70. assert np.isfinite(c).all()
  71. assert (a.dtype == float and b.dtype == float and c.dtype == float and
  72. isinstance(alpha, float) and isinstance(beta, float) or
  73. a.dtype == complex and b.dtype == complex and c.dtype == complex)
  74. if transa == 'n':
  75. assert a.size == 0 or a[0].flags.contiguous
  76. assert c.flags.contiguous or c.ndim == 2 and c.strides[1] == c.itemsize
  77. assert b.ndim == 2
  78. assert b.size == 0 or b.strides[1] == b.itemsize
  79. assert a.shape[0] == b.shape[1]
  80. assert c.shape == b.shape[0:1] + a.shape[1:]
  81. else:
  82. assert a.flags.contiguous
  83. assert b.size == 0 or b[0].flags.contiguous
  84. assert c.strides[1] == c.itemsize
  85. assert a.shape[1:] == b.shape[1:]
  86. assert c.shape == (b.shape[0], a.shape[0])
  87. _gpaw.gemm(alpha, a, b, beta, c, transa)
  88. def axpy(alpha, x, y):
  89. """alpha x plus y.
  90. Performs the operation::
  91. y <- alpha * x + y
  92. """
  93. if x.size == 0:
  94. return
  95. x = x.ravel()
  96. y = y.ravel()
  97. if x.dtype == float:
  98. z = blas.daxpy(x, y, a=alpha)
  99. else:
  100. z = blas.zaxpy(x, y, a=alpha)
  101. assert z is y, (x, y, x.shape, y.shape)
  102. def rk(alpha, a, beta, c, trans='c'):
  103. """Rank-k update of a matrix.
  104. Performs the operation::
  105. dag
  106. c <- alpha * a . a + beta * c
  107. where ``a.b`` denotes the matrix multiplication defined by::
  108. _
  109. \
  110. (a.b) = ) a * b
  111. ij /_ ipklm... pjklm...
  112. pklm...
  113. ``dag`` denotes the hermitian conjugate (complex conjugation plus a
  114. swap of axis 0 and 1).
  115. Only the lower triangle of ``c`` will contain sensible numbers.
  116. """
  117. assert beta == 0.0 or np.isfinite(c).all()
  118. assert (a.dtype == float and c.dtype == float or
  119. a.dtype == complex and c.dtype == complex)
  120. assert a.flags.contiguous
  121. assert a.ndim > 1
  122. if trans == 'n':
  123. assert c.shape == (a.shape[1], a.shape[1])
  124. else:
  125. assert c.shape == (a.shape[0], a.shape[0])
  126. assert c.strides[1] == c.itemsize
  127. _gpaw.rk(alpha, a, beta, c, trans)
  128. def r2k(alpha, a, b, beta, c, trans='c'):
  129. """Rank-2k update of a matrix.
  130. Performs the operation::
  131. dag cc dag
  132. c <- alpha * a . b + alpha * b . a + beta * c
  133. or if trans='n'::
  134. dag cc dag
  135. c <- alpha * a . b + alpha * b . a + beta * c
  136. where ``a.b`` denotes the matrix multiplication defined by::
  137. _
  138. \
  139. (a.b) = ) a * b
  140. ij /_ ipklm... pjklm...
  141. pklm...
  142. ``cc`` denotes complex conjugation.
  143. ``dag`` denotes the hermitian conjugate (complex conjugation plus a
  144. swap of axis 0 and 1).
  145. Only the lower triangle of ``c`` will contain sensible numbers.
  146. """
  147. assert beta == 0.0 or np.isfinite(np.tril(c)).all()
  148. assert (a.dtype == float and b.dtype == float and c.dtype == float or
  149. a.dtype == complex and b.dtype == complex and c.dtype == complex)
  150. assert a.flags.contiguous and b.flags.contiguous
  151. assert a.ndim > 1
  152. assert a.shape == b.shape
  153. if trans == 'c':
  154. assert c.shape == (a.shape[0], a.shape[0])
  155. else:
  156. assert c.shape == (a.shape[1], a.shape[1])
  157. assert c.strides[1] == c.itemsize
  158. _gpaw.r2k(alpha, a, b, beta, c, trans)
  159. def _gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
  160. """Matrix multiplication using gemm.
  161. return reference to out, where::
  162. out <- alpha * a . b + beta * out
  163. If out is None, a suitably sized zero array will be created.
  164. ``a.b`` denotes matrix multiplication, where the product-sum is
  165. over the last dimension of a, and either
  166. the first dimension of b (for trans='n'), or
  167. the last dimension of b (for trans='t' or 'c').
  168. If trans='c', the complex conjugate of b is used.
  169. """
  170. # Store original shapes
  171. ashape = a.shape
  172. bshape = b.shape
  173. # Vector-vector multiplication is handled by dotu
  174. if a.ndim == 1 and b.ndim == 1:
  175. assert out is None
  176. if trans == 'c':
  177. return alpha * np.vdot(b, a) # dotc conjugates *first* argument
  178. else:
  179. return alpha * a.dot(b)
  180. # Map all arrays to 2D arrays
  181. a = a.reshape(-1, a.shape[-1])
  182. if trans == 'n':
  183. b = b.reshape(b.shape[0], -1)
  184. outshape = a.shape[0], b.shape[1]
  185. else: # 't' or 'c'
  186. b = b.reshape(-1, b.shape[-1])
  187. # Apply BLAS gemm routine
  188. outshape = a.shape[0], b.shape[trans == 'n']
  189. if out is None:
  190. # (ATLAS can't handle uninitialized output array)
  191. out = np.zeros(outshape, a.dtype)
  192. else:
  193. out = out.reshape(outshape)
  194. gemm(alpha, b, a, beta, out, trans)
  195. # Determine actual shape of result array
  196. if trans == 'n':
  197. outshape = ashape[:-1] + bshape[1:]
  198. else: # 't' or 'c'
  199. outshape = ashape[:-1] + bshape[:-1]
  200. return out.reshape(outshape)
  201. if not hasattr(_gpaw, 'mmm'):
  202. def gemm(alpha, a, b, beta, c, transa='n'): # noqa
  203. if c.size == 0:
  204. return
  205. if beta == 0:
  206. c[:] = 0.0
  207. else:
  208. c *= beta
  209. if a.size == 0:
  210. return
  211. if transa == 'n':
  212. c += alpha * b.dot(a.reshape((len(a), -1))).reshape(c.shape)
  213. elif transa == 't':
  214. c += alpha * b.reshape((len(b), -1)).dot(
  215. a.reshape((len(a), -1)).T)
  216. else:
  217. c += alpha * b.reshape((len(b), -1)).dot(
  218. a.reshape((len(a), -1)).T.conj())
  219. def rk(alpha, a, beta, c, trans='c'): # noqa
  220. if c.size == 0:
  221. return
  222. if beta == 0:
  223. c[:] = 0.0
  224. else:
  225. c *= beta
  226. if trans == 'n':
  227. c += alpha * a.conj().T.dot(a)
  228. else:
  229. a = a.reshape((len(a), -1))
  230. c += alpha * a.dot(a.conj().T)
  231. def r2k(alpha, a, b, beta, c, trans='c'): # noqa
  232. if c.size == 0:
  233. return
  234. if beta == 0.0:
  235. c[:] = 0.0
  236. else:
  237. c *= beta
  238. if trans == 'c':
  239. c += (alpha * a.reshape((len(a), -1))
  240. .dot(b.reshape((len(b), -1)).conj().T) +
  241. alpha * b.reshape((len(b), -1))
  242. .dot(a.reshape((len(a), -1)).conj().T))
  243. else:
  244. c += alpha * (a.conj().T @ b + b.conj().T @ a)
  245. def op(o, m):
  246. if o == 'N':
  247. return m
  248. if o == 'T':
  249. return m.T
  250. return m.conj().T
  251. def mmm(alpha: T, a: np.ndarray, opa: str, # noqa
  252. b: np.ndarray, opb: str,
  253. beta: T, c: np.ndarray) -> None:
  254. if beta == 0.0:
  255. c[:] = 0.0
  256. else:
  257. c *= beta
  258. c += alpha * op(opa, a).dot(op(opb, b))
  259. gemmdot = _gemmdot
  260. elif not debug:
  261. mmm = _gpaw.mmm # noqa
  262. gemm = _gpaw.gemm # noqa
  263. rk = _gpaw.rk # noqa
  264. r2k = _gpaw.r2k # noqa
  265. gemmdot = _gemmdot
  266. else:
  267. def gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
  268. assert a.flags.contiguous
  269. assert b.flags.contiguous
  270. assert a.dtype == b.dtype
  271. if trans == 'n':
  272. assert a.shape[-1] == b.shape[0]
  273. else:
  274. assert a.shape[-1] == b.shape[-1]
  275. if out is not None:
  276. assert out.flags.contiguous
  277. assert a.dtype == out.dtype
  278. assert a.ndim > 1 or b.ndim > 1
  279. if trans == 'n':
  280. assert out.shape == a.shape[:-1] + b.shape[1:]
  281. else:
  282. assert out.shape == a.shape[:-1] + b.shape[:-1]
  283. return _gemmdot(a, b, alpha, beta, out, trans)