/scripts/abins/kpointsdata.py

https://github.com/mantidproject/mantid · Python · 163 lines · 98 code · 29 blank · 36 comment · 25 complexity · 834fbb1fcc38ea65015764f4fe977fea MD5 · raw file

  1. # Mantid Repository : https://github.com/mantidproject/mantid
  2. #
  3. # Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI,
  4. # NScD Oak Ridge National Laboratory, European Spallation Source,
  5. # Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
  6. # SPDX - License - Identifier: GPL - 3.0 +
  7. import collections.abc
  8. from typing import List, NamedTuple, overload
  9. from math import isclose
  10. import numpy as np
  11. from mantid.kernel import logger as mantid_logger
  12. from abins.constants import (COMPLEX_ID, FLOAT_ID, GAMMA_POINT, SMALL_K)
  13. class KpointData(NamedTuple):
  14. """Vibrational frequency / displacement data at a particular k-point"""
  15. k: np.ndarray
  16. weight: float
  17. frequencies: np.ndarray
  18. atomic_displacements: np.ndarray
  19. class KpointsData(collections.abc.Sequence):
  20. """Class storing atomic frequencies and displacements at specific k-points
  21. Args:
  22. weights: weights of all k-points; weights.shape == (num_k,);
  23. k_vectors: k_vectors of all k-points; k_vectors.shape == (num_k, 3)
  24. frequencies: frequencies for all k-points; frequencies.shape == (num_k, num_freq)
  25. atomic_displacements: atomic displacements for all k-points;
  26. atomic_displacements.shape == (num_k, num_atoms, num_freq, 3)
  27. unit_cell: lattice vectors (use zeros for open boundary conditions);
  28. unit_cell.shape == (3, 3)
  29. logger: Logging instance. Defaults to Mantid logger. Alternate loggers
  30. may be useful for testing.
  31. """
  32. def __init__(self, *, frequencies: np.ndarray, atomic_displacements: np.ndarray,
  33. weights: np.ndarray, k_vectors: np.ndarray, unit_cell: np.ndarray,
  34. logger = None) -> None:
  35. super().__init__()
  36. if logger is None:
  37. logger = mantid_logger
  38. self._data = {}
  39. dim = 3
  40. for arg in (frequencies, atomic_displacements, weights, k_vectors, unit_cell):
  41. if not isinstance(arg, np.ndarray):
  42. raise TypeError("All arguments to KpointsData should be numpy arrays")
  43. # unit_cell
  44. if not (unit_cell.shape == (dim, dim)
  45. and unit_cell.dtype.num == FLOAT_ID):
  46. raise ValueError("Invalid values of unit cell vectors.")
  47. self.unit_cell = unit_cell
  48. # weights
  49. num_k = weights.size
  50. if not (weights.dtype.num == FLOAT_ID
  51. and np.allclose(weights, weights[weights >= 0])):
  52. raise ValueError("Invalid value of weights.")
  53. if not isclose(np.sum(weights), 1.0):
  54. logger.warning("k-point weights do not sum to 1. Re-normalising...")
  55. weights /= np.sum(weights)
  56. self._weights = weights
  57. # k_vectors
  58. if not (k_vectors.shape == (num_k, dim)
  59. and k_vectors.dtype.num == FLOAT_ID):
  60. raise ValueError("Invalid value of k_vectors.")
  61. self._k_vectors = k_vectors
  62. # frequencies
  63. num_freq = frequencies.shape[1]
  64. if not (frequencies.shape == (num_k, num_freq)
  65. and frequencies.dtype.num == FLOAT_ID):
  66. raise ValueError("Invalid value of frequencies.")
  67. self._frequencies = frequencies
  68. # atomic_displacements
  69. if len(atomic_displacements.shape) != 4:
  70. raise ValueError("atomic_displacements should have four dimensions")
  71. num_atoms = atomic_displacements.shape[1]
  72. if not (atomic_displacements.shape == (weights.size, num_atoms, num_freq, dim)
  73. and atomic_displacements.dtype.num == COMPLEX_ID):
  74. raise ValueError("Invalid value of atomic_displacements.")
  75. self._atomic_displacements = atomic_displacements
  76. @staticmethod
  77. def _array_to_dict(array, string_key=False):
  78. if string_key:
  79. return {str(i): row for i, row in enumerate(array)}
  80. else:
  81. return {i: row for i, row in enumerate(array)}
  82. def get_gamma_point_data(self):
  83. """
  84. Extracts k points data only for Gamma point.
  85. :returns: dictionary with data only for Gamma point
  86. """
  87. gamma_pkt_index = -1
  88. k_vectors = self._array_to_dict(self._k_vectors)
  89. # look for index of Gamma point
  90. for k_index, k in k_vectors.items():
  91. if np.linalg.norm(k) < SMALL_K:
  92. gamma_pkt_index = k_index
  93. break
  94. else:
  95. raise ValueError("Gamma point not found.")
  96. k_points = {"weights": {GAMMA_POINT: self._data["weights"][gamma_pkt_index]},
  97. "k_vectors": {GAMMA_POINT: self._data["k_vectors"][gamma_pkt_index]},
  98. "frequencies": {GAMMA_POINT: self._data["frequencies"][gamma_pkt_index]},
  99. "atomic_displacements": {GAMMA_POINT: self._data["atomic_displacements"][gamma_pkt_index]},
  100. "unit_cell": self.unit_cell}
  101. return k_points
  102. def extract(self):
  103. extracted = {"unit_cell": self.unit_cell,
  104. "weights": self._array_to_dict(self._weights, string_key=True),
  105. "k_vectors": self._array_to_dict(self._k_vectors, string_key=True),
  106. "frequencies": self._array_to_dict(self._frequencies, string_key=True),
  107. "atomic_displacements": self._array_to_dict(self._atomic_displacements, string_key=True)}
  108. return extracted
  109. def __str__(self):
  110. return "K-points data"
  111. def __len__(self):
  112. return self._weights.size
  113. @overload # noqa F811
  114. def __getitem__(self, item: int) -> KpointData:
  115. ...
  116. @overload # noqa F811
  117. def __getitem__(self, item: slice) -> List[KpointData]: # noqa F811
  118. ...
  119. def __getitem__(self, item): # noqa F811
  120. if isinstance(item, int):
  121. return KpointData(self._k_vectors[item],
  122. self._weights[item],
  123. self._frequencies[item],
  124. self._atomic_displacements[item])
  125. elif isinstance(item, slice):
  126. return [self[i] for i in range(len(self))[item]]
  127. return self._data[item]