/ase/visualize/vtk/grid.py

https://gitlab.com/vote539/ase · Python · 315 lines · 264 code · 29 blank · 22 comment · 14 complexity · a422b50205edfa83c4878025b53d220c MD5 · raw file

  1. import numpy as np
  2. from vtk import vtkPointData, vtkDataArray, vtkUnstructuredGrid, vtkPoints, \
  3. vtkIdList, vtkStructuredPoints
  4. from ase.visualize.vtk.cell import vtkUnitCellModule
  5. from ase.visualize.vtk.data import vtkDataArrayFromNumPyBuffer, \
  6. vtkDoubleArrayFromNumPyArray, \
  7. vtkDoubleArrayFromNumPyMultiArray
  8. # -------------------------------------------------------------------
  9. class vtkBaseGrid:
  10. def __init__(self, npoints, cell):
  11. self.npoints = npoints
  12. # Make sure cell argument is correct type
  13. assert isinstance(cell, vtkUnitCellModule)
  14. self.cell = cell
  15. self.vtk_pointdata = None
  16. def set_point_data(self, vtk_pointdata):
  17. if self.vtk_pointdata is not None:
  18. raise RuntimeError('VTK point data already present.')
  19. assert isinstance(vtk_pointdata, vtkPointData)
  20. self.vtk_pointdata = vtk_pointdata
  21. #self.vtk_pointdata.SetCopyScalars(False)
  22. #self.vtk_pointdata.SetCopyVectors(False)
  23. #self.vtk_pointdata.SetCopyNormals(False)
  24. def get_point_data(self):
  25. if self.vtk_pointdata is None:
  26. raise RuntimeError('VTK point data missing.')
  27. return self.vtk_pointdata
  28. def get_number_of_points(self):
  29. return self.npoints
  30. def add_scalar_data_array(self, data, name=None, active=True):
  31. # Are we converting from NumPy buffer to VTK array?
  32. if isinstance(data, vtkDataArray):
  33. vtk_sda = data
  34. elif isinstance(data, vtkDataArrayFromNumPyBuffer):
  35. vtk_sda = data.get_output()
  36. else:
  37. raise ValueError('Data is not a valid scalar data array.')
  38. del data
  39. assert vtk_sda.GetNumberOfComponents() == 1
  40. assert vtk_sda.GetNumberOfTuples() == self.npoints
  41. if name is not None:
  42. vtk_sda.SetName(name)
  43. # Add VTK array to VTK point data
  44. self.vtk_pointdata.AddArray(vtk_sda)
  45. if active:
  46. self.vtk_pointdata.SetActiveScalars(name)
  47. return vtk_sda
  48. def add_vector_data_array(self, data, name=None, active=True):
  49. # Are we converting from NumPy buffer to VTK array?
  50. if isinstance(data, vtkDataArray):
  51. vtk_vda = data
  52. elif isinstance(data, vtkDataArrayFromNumPyBuffer):
  53. vtk_vda = data.get_output()
  54. else:
  55. raise ValueError('Data is not a valid vector data array.')
  56. del data
  57. assert vtk_vda.GetNumberOfComponents() == 3
  58. assert vtk_vda.GetNumberOfTuples() == self.npoints
  59. if name is not None:
  60. vtk_vda.SetName(name)
  61. # Add VTK array to VTK point data
  62. self.vtk_pointdata.AddArray(vtk_vda)
  63. if active:
  64. self.vtk_pointdata.SetActiveVectors(name)
  65. return vtk_vda
  66. # -------------------------------------------------------------------
  67. class vtkAtomicPositions(vtkBaseGrid):
  68. """Provides an interface for adding ``Atoms``-centered data to VTK
  69. modules. Atomic positions, e.g. obtained using atoms.get_positions(),
  70. constitute an unstructured grid in VTK, to which scalar and vector
  71. can be added as point data sets.
  72. Just like ``Atoms``, instances of ``vtkAtomicPositions`` can be divided
  73. into subsets, which makes it easy to select atoms and add properties.
  74. Example:
  75. >>> cell = vtkUnitCellModule(atoms)
  76. >>> apos = vtkAtomicPositions(atoms.get_positions(), cell)
  77. >>> apos.add_scalar_property(atoms.get_charges(), 'charges')
  78. >>> apos.add_vector_property(atoms.get_forces(), 'forces')
  79. """
  80. def __init__(self, pos, cell):
  81. """Construct basic VTK-representation of a set of atomic positions.
  82. pos: NumPy array of dtype float and shape ``(n,3)``
  83. Cartesian positions of the atoms.
  84. cell: Instance of vtkUnitCellModule of subclass thereof
  85. Holds information equivalent to that of atoms.get_cell().
  86. """
  87. # Make sure position argument is a valid array
  88. if not isinstance(pos, np.ndarray):
  89. pos = np.array(pos)
  90. assert pos.dtype == float and pos.shape[1:] == (3,)
  91. vtkBaseGrid.__init__(self, len(pos), cell)
  92. # Convert positions to VTK array
  93. npy2da = vtkDoubleArrayFromNumPyArray(pos)
  94. vtk_pda = npy2da.get_output()
  95. del npy2da
  96. # Transfer atomic positions to VTK points
  97. self.vtk_pts = vtkPoints()
  98. self.vtk_pts.SetData(vtk_pda)
  99. # Create a VTK unstructured grid of these points
  100. self.vtk_ugd = vtkUnstructuredGrid()
  101. self.vtk_ugd.SetWholeBoundingBox(self.cell.get_bounding_box())
  102. self.vtk_ugd.SetPoints(self.vtk_pts)
  103. # Extract the VTK point data set
  104. self.set_point_data(self.vtk_ugd.GetPointData())
  105. def get_points(self, subset=None):
  106. """Return (subset of) vtkPoints containing atomic positions.
  107. subset=None: list of int
  108. A list of indices into the atomic positions; ignored if None.
  109. """
  110. if subset is None:
  111. return self.vtk_pts
  112. # Create a list of indices from the subset
  113. vtk_il = vtkIdList()
  114. for i in subset:
  115. vtk_il.InsertNextId(i)
  116. # Allocate VTK points for subset
  117. vtk_subpts = vtkPoints()
  118. vtk_subpts.SetDataType(self.vtk_pts.GetDataType())
  119. vtk_subpts.SetNumberOfPoints(vtk_il.GetNumberOfIds())
  120. # Transfer subset of VTK points
  121. self.vtk_pts.GetPoints(vtk_il, vtk_subpts)
  122. return vtk_subpts
  123. def get_unstructured_grid(self, subset=None):
  124. """Return (subset of) an unstructured grid of the atomic positions.
  125. subset=None: list of int
  126. A list of indices into the atomic positions; ignored if None.
  127. """
  128. if subset is None:
  129. return self.vtk_ugd
  130. # Get subset of VTK points
  131. vtk_subpts = self.get_points(subset)
  132. # Create a VTK unstructured grid of these points
  133. vtk_subugd = vtkUnstructuredGrid()
  134. vtk_subugd.SetWholeBoundingBox(self.cell.get_bounding_box())
  135. vtk_subugd.SetPoints(vtk_subpts)
  136. return vtk_subugd
  137. def add_scalar_property(self, data, name=None, active=True):
  138. """Add VTK-representation of scalar data at the atomic positions.
  139. data: NumPy array of dtype float and shape ``(n,)``
  140. Scalar values corresponding to the atomic positions.
  141. name=None: str
  142. Unique identifier for the scalar data.
  143. active=True: bool
  144. Flag indicating whether to use as active scalar data.
  145. """
  146. # Make sure data argument is a valid array
  147. if not isinstance(data, np.ndarray):
  148. data = np.array(data)
  149. assert data.dtype == float and data.shape == (self.npoints,)
  150. # Convert scalar properties to VTK array
  151. npa2da = vtkDoubleArrayFromNumPyArray(data)
  152. return vtkBaseGrid.add_scalar_data_array(self, npa2da, name, active)
  153. def add_vector_property(self, data, name=None, active=True):
  154. """Add VTK-representation of vector data at the atomic positions.
  155. data: NumPy array of dtype float and shape ``(n,3)``
  156. Vector components corresponding to the atomic positions.
  157. name=None: str
  158. Unique identifier for the vector data.
  159. active=True: bool
  160. Flag indicating whether to use as active vector data.
  161. """
  162. # Make sure data argument is a valid array
  163. if not isinstance(data, np.ndarray):
  164. data = np.array(data)
  165. assert data.dtype == float and data.shape == (self.npoints,3,)
  166. # Convert vector properties to VTK array
  167. npa2da = vtkDoubleArrayFromNumPyArray(data)
  168. return vtkBaseGrid.add_vector_data_array(self, npa2da, name, active)
  169. # -------------------------------------------------------------------
  170. class vtkVolumeGrid(vtkBaseGrid):
  171. def __init__(self, elements, cell, origin=None):
  172. # Make sure element argument is a valid array
  173. if not isinstance(elements, np.ndarray):
  174. elements = np.array(elements)
  175. assert elements.dtype == int and elements.shape == (3,)
  176. self.elements = elements
  177. vtkBaseGrid.__init__(self, np.prod(self.elements), cell)
  178. # Create a VTK grid of structured points
  179. self.vtk_spts = vtkStructuredPoints()
  180. self.vtk_spts.SetWholeBoundingBox(self.cell.get_bounding_box())
  181. self.vtk_spts.SetDimensions(self.elements)
  182. self.vtk_spts.SetSpacing(self.get_grid_spacing())
  183. if origin is not None:
  184. self.vtk_spts.SetOrigin(origin)
  185. # Extract the VTK point data set
  186. self.set_point_data(self.vtk_spts.GetPointData())
  187. def get_grid_spacing(self):
  188. # Periodic boundary conditions leave out one boundary along an axis
  189. # Zero/fixed boundary conditions leave out both boundaries of an axis
  190. return self.cell.get_size()/(self.elements+1.0-self.cell.get_pbc())
  191. def get_relaxation_factor(self):
  192. # The relaxation factor is a floating point value between zero and one.
  193. # It expresses the need for smoothening (relaxation) e.g. of isosurfaces
  194. # due to coarse grid spacings. Larger grid spacing -> larger relaxation.
  195. x = self.get_grid_spacing().mean()/self.cell.get_characteristic_length()
  196. # The relaxation function f(x) satisfies the following requirements
  197. # f(x) -> 0 for x -> 0+ and f(x) -> b for x -> inf
  198. # f'(x) -> a for x -> 0+ and f'(x) -> 0 for x -> inf
  199. # Furthermore, it is a rescaling of arctan, hence we know
  200. # f(x) = 2 b arctan(a pi x / 2 b) / pi
  201. # Our reference point is x = r for which medium relaxion is needed
  202. # f(r) = b/2 <=> r = 2 b / a pi <=> a = 2 b / r pi
  203. r = 0.025 # corresponding to 0.2 Ang grid spacing in 8 Ang cell
  204. b = 0.5
  205. f = 2*b*np.arctan(x/r)/np.pi
  206. if f > 0.1:
  207. return f.round(1)
  208. else:
  209. return None
  210. def get_structured_points(self):
  211. return self.vtk_spts
  212. def add_scalar_field(self, data, name=None, active=True):
  213. # Make sure data argument is a valid array
  214. if not isinstance(data, np.ndarray):
  215. data = np.array(data)
  216. assert data.dtype == float and data.shape == tuple(self.elements)
  217. # Convert scalar field to VTK array
  218. npa2da = vtkDoubleArrayFromNumPyMultiArray(data[...,np.newaxis])
  219. return vtkBaseGrid.add_scalar_data_array(self, npa2da, name, active)
  220. def add_vector_field(self, data, name=None, active=True):
  221. # Make sure data argument is a valid array
  222. if not isinstance(data, np.ndarray):
  223. data = np.array(data)
  224. assert data.dtype == float and data.shape == tuple(self.elements)+(3,)
  225. # Convert vector field to VTK array
  226. npa2da = vtkDoubleArrayFromNumPyMultiArray(data)
  227. return vtkBaseGrid.add_vector_data_array(self, npa2da, name, active)