PageRenderTime 44ms CodeModel.GetById 13ms app.highlight 26ms RepoModel.GetById 1ms app.codeStats 1ms

/ase/visualize/vtk/grid.py

https://gitlab.com/vote539/ase
Python | 315 lines | 264 code | 29 blank | 22 comment | 13 complexity | a422b50205edfa83c4878025b53d220c MD5 | raw file
  1
  2import numpy as np
  3
  4from vtk import vtkPointData, vtkDataArray, vtkUnstructuredGrid, vtkPoints, \
  5                vtkIdList, vtkStructuredPoints
  6from ase.visualize.vtk.cell import vtkUnitCellModule
  7from ase.visualize.vtk.data import vtkDataArrayFromNumPyBuffer, \
  8                                   vtkDoubleArrayFromNumPyArray, \
  9                                   vtkDoubleArrayFromNumPyMultiArray
 10
 11# -------------------------------------------------------------------
 12
 13class vtkBaseGrid:
 14    def __init__(self, npoints, cell):
 15        self.npoints = npoints
 16
 17        # Make sure cell argument is correct type
 18        assert isinstance(cell, vtkUnitCellModule)
 19        self.cell = cell
 20
 21        self.vtk_pointdata = None
 22
 23    def set_point_data(self, vtk_pointdata):
 24        if self.vtk_pointdata is not None:
 25            raise RuntimeError('VTK point data already present.')
 26
 27        assert isinstance(vtk_pointdata, vtkPointData)
 28        self.vtk_pointdata = vtk_pointdata
 29        #self.vtk_pointdata.SetCopyScalars(False)
 30        #self.vtk_pointdata.SetCopyVectors(False)
 31        #self.vtk_pointdata.SetCopyNormals(False)
 32
 33    def get_point_data(self):
 34        if self.vtk_pointdata is None:
 35            raise RuntimeError('VTK point data missing.')
 36
 37        return self.vtk_pointdata
 38
 39    def get_number_of_points(self):
 40        return self.npoints
 41
 42    def add_scalar_data_array(self, data, name=None, active=True):
 43
 44        # Are we converting from NumPy buffer to VTK array?
 45        if isinstance(data, vtkDataArray):
 46            vtk_sda = data
 47        elif isinstance(data, vtkDataArrayFromNumPyBuffer):
 48            vtk_sda = data.get_output()
 49        else:
 50            raise ValueError('Data is not a valid scalar data array.')
 51
 52        del data
 53
 54        assert vtk_sda.GetNumberOfComponents() == 1
 55        assert vtk_sda.GetNumberOfTuples() == self.npoints
 56
 57        if name is not None:
 58            vtk_sda.SetName(name)
 59
 60        # Add VTK array to VTK point data
 61        self.vtk_pointdata.AddArray(vtk_sda)
 62
 63        if active:
 64            self.vtk_pointdata.SetActiveScalars(name)
 65
 66        return vtk_sda
 67
 68    def add_vector_data_array(self, data, name=None, active=True):
 69
 70        # Are we converting from NumPy buffer to VTK array?
 71        if isinstance(data, vtkDataArray):
 72            vtk_vda = data
 73        elif isinstance(data, vtkDataArrayFromNumPyBuffer):
 74            vtk_vda = data.get_output()
 75        else:
 76            raise ValueError('Data is not a valid vector data array.')
 77
 78        del data
 79
 80        assert vtk_vda.GetNumberOfComponents() == 3
 81        assert vtk_vda.GetNumberOfTuples() == self.npoints
 82
 83        if name is not None:
 84            vtk_vda.SetName(name)
 85
 86        # Add VTK array to VTK point data
 87        self.vtk_pointdata.AddArray(vtk_vda)
 88
 89        if active:
 90            self.vtk_pointdata.SetActiveVectors(name)
 91
 92        return vtk_vda
 93
 94# -------------------------------------------------------------------
 95
 96class vtkAtomicPositions(vtkBaseGrid):
 97    """Provides an interface for adding ``Atoms``-centered data to VTK
 98    modules. Atomic positions, e.g. obtained using atoms.get_positions(),
 99    constitute an unstructured grid in VTK, to which scalar and vector
100    can be added as point data sets.
101
102    Just like ``Atoms``, instances of ``vtkAtomicPositions`` can be divided
103    into subsets, which makes it easy to select atoms and add properties.
104
105    Example:
106
107    >>> cell = vtkUnitCellModule(atoms)
108    >>> apos = vtkAtomicPositions(atoms.get_positions(), cell)
109    >>> apos.add_scalar_property(atoms.get_charges(), 'charges')
110    >>> apos.add_vector_property(atoms.get_forces(), 'forces')
111
112    """
113    def __init__(self, pos, cell):
114        """Construct basic VTK-representation of a set of atomic positions.
115
116        pos: NumPy array of dtype float and shape ``(n,3)``
117            Cartesian positions of the atoms.
118        cell: Instance of vtkUnitCellModule of subclass thereof
119            Holds information equivalent to that of atoms.get_cell().
120
121        """
122        # Make sure position argument is a valid array
123        if not isinstance(pos, np.ndarray):
124            pos = np.array(pos)
125
126        assert pos.dtype == float and pos.shape[1:] == (3,)
127
128        vtkBaseGrid.__init__(self, len(pos), cell)
129
130        # Convert positions to VTK array
131        npy2da = vtkDoubleArrayFromNumPyArray(pos)
132        vtk_pda = npy2da.get_output()
133        del npy2da
134
135        # Transfer atomic positions to VTK points
136        self.vtk_pts = vtkPoints()
137        self.vtk_pts.SetData(vtk_pda)
138
139        # Create a VTK unstructured grid of these points
140        self.vtk_ugd = vtkUnstructuredGrid()
141        self.vtk_ugd.SetWholeBoundingBox(self.cell.get_bounding_box())
142        self.vtk_ugd.SetPoints(self.vtk_pts)
143
144        # Extract the VTK point data set
145        self.set_point_data(self.vtk_ugd.GetPointData())
146
147    def get_points(self, subset=None):
148        """Return (subset of) vtkPoints containing atomic positions.
149
150        subset=None: list of int
151            A list of indices into the atomic positions; ignored if None.
152
153        """
154        if subset is None:
155            return self.vtk_pts
156
157        # Create a list of indices from the subset
158        vtk_il = vtkIdList()
159        for i in subset:
160            vtk_il.InsertNextId(i)
161
162        # Allocate VTK points for subset
163        vtk_subpts = vtkPoints()
164        vtk_subpts.SetDataType(self.vtk_pts.GetDataType())
165        vtk_subpts.SetNumberOfPoints(vtk_il.GetNumberOfIds())
166
167        # Transfer subset of VTK points
168        self.vtk_pts.GetPoints(vtk_il, vtk_subpts)
169
170        return vtk_subpts
171
172    def get_unstructured_grid(self, subset=None):
173        """Return (subset of) an unstructured grid of the atomic positions.
174
175        subset=None: list of int
176            A list of indices into the atomic positions; ignored if None.
177
178        """
179        if subset is None:
180            return self.vtk_ugd
181
182        # Get subset of VTK points
183        vtk_subpts = self.get_points(subset)
184
185        # Create a VTK unstructured grid of these points
186        vtk_subugd = vtkUnstructuredGrid()
187        vtk_subugd.SetWholeBoundingBox(self.cell.get_bounding_box())
188        vtk_subugd.SetPoints(vtk_subpts)
189
190        return vtk_subugd
191
192    def add_scalar_property(self, data, name=None, active=True):
193        """Add VTK-representation of scalar data at the atomic positions.
194
195        data: NumPy array of dtype float and shape ``(n,)``
196            Scalar values corresponding to the atomic positions.
197        name=None: str
198            Unique identifier for the scalar data.
199        active=True: bool
200            Flag indicating whether to use as active scalar data.
201
202        """
203        # Make sure data argument is a valid array
204        if not isinstance(data, np.ndarray):
205            data = np.array(data)
206
207        assert data.dtype == float and data.shape == (self.npoints,)
208
209        # Convert scalar properties to VTK array
210        npa2da = vtkDoubleArrayFromNumPyArray(data)
211        return vtkBaseGrid.add_scalar_data_array(self, npa2da, name, active)
212
213    def add_vector_property(self, data, name=None, active=True):
214        """Add VTK-representation of vector data at the atomic positions.
215
216        data: NumPy array of dtype float and shape ``(n,3)``
217            Vector components corresponding to the atomic positions.
218        name=None: str
219            Unique identifier for the vector data.
220        active=True: bool
221            Flag indicating whether to use as active vector data.
222
223        """
224        # Make sure data argument is a valid array
225        if not isinstance(data, np.ndarray):
226            data = np.array(data)
227
228        assert data.dtype == float and data.shape == (self.npoints,3,)
229
230        # Convert vector properties to VTK array
231        npa2da = vtkDoubleArrayFromNumPyArray(data)
232        return vtkBaseGrid.add_vector_data_array(self, npa2da, name, active)
233
234# -------------------------------------------------------------------
235
class vtkVolumeGrid(vtkBaseGrid):
    """VTK structured-points grid spanning the unit cell.

    Scalar and vector fields sampled on a regular mesh of ``elements``
    points can be added as point data.

    """
    def __init__(self, elements, cell, origin=None):
        """Construct a VTK grid of structured points.

        elements: integer array-like of shape ``(3,)``
            Number of grid points along each axis; all must be positive.
        cell: Instance of vtkUnitCellModule or subclass thereof
            Unit cell spanned by the grid; also determines the spacing.
        origin=None: sequence of 3 floats
            Optional grid origin, passed on to SetOrigin().

        """
        # Make sure element argument is a valid array
        if not isinstance(elements, np.ndarray):
            elements = np.array(elements)

        # Accept any integer dtype (not only the platform default ``int``),
        # and require a positive number of points along every axis.
        assert np.issubdtype(elements.dtype, np.integer) \
            and elements.shape == (3,) and (elements > 0).all()
        self.elements = elements

        vtkBaseGrid.__init__(self, np.prod(self.elements), cell)

        # Create a VTK grid of structured points
        self.vtk_spts = vtkStructuredPoints()
        self.vtk_spts.SetWholeBoundingBox(self.cell.get_bounding_box())
        self.vtk_spts.SetDimensions(tuple(int(n) for n in self.elements))
        self.vtk_spts.SetSpacing(self.get_grid_spacing())

        if origin is not None:
            self.vtk_spts.SetOrigin(origin)

        # Extract the VTK point data set
        self.set_point_data(self.vtk_spts.GetPointData())

    def get_grid_spacing(self):
        """Return the grid spacing along each axis as cell size divided by
        the number of intervals."""
        # Periodic boundary conditions leave out one boundary along an axis
        # (size / elements), zero/fixed boundary conditions leave out both
        # boundaries of an axis (size / (elements + 1)).
        return self.cell.get_size()/(self.elements+1.0-self.cell.get_pbc())

    def get_relaxation_factor(self):
        """Return a smoothing (relaxation) factor for the grid, or None.

        The factor grows with the ratio of the mean grid spacing to the
        cell's characteristic length and is rounded to one decimal. None is
        returned when it does not exceed 0.1.

        """
        # The relaxation factor is a floating point value between zero and
        # b = 0.5 (see below). It expresses the need for smoothing
        # (relaxation), e.g. of isosurfaces, due to coarse grid spacings.
        # Larger grid spacing -> larger relaxation.
        x = self.get_grid_spacing().mean()/self.cell.get_characteristic_length()

        # The relaxation function f(x) satisfies the following requirements
        # f(x) -> 0 for x -> 0+   and   f(x) -> b for x -> inf
        # f'(x) -> a for x -> 0+  and   f'(x) -> 0 for x -> inf

        # Furthermore, it is a rescaling of arctan, hence we know
        # f(x) = 2 b arctan(a pi x / 2 b) / pi

        # Our reference point is x = r for which medium relaxation is needed
        # f(r) = b/2   <=>   r = 2 b / a pi   <=>   a = 2 b / r pi
        # Substituting a gives f(x) = 2 b arctan(x / r) / pi
        r = 0.025 # corresponding to 0.2 Ang grid spacing in 8 Ang cell
        b = 0.5
        f = 2*b*np.arctan(x/r)/np.pi

        if f > 0.1:
            return np.round(f, 1)
        else:
            return None

    def get_structured_points(self):
        """Return the underlying vtkStructuredPoints grid."""
        return self.vtk_spts

    def add_scalar_field(self, data, name=None, active=True):
        """Add a scalar field sampled at the grid points.

        data: NumPy array of dtype float and shape ``tuple(elements)``
            Scalar value at every grid point.
        name=None: str
            Unique identifier for the scalar data.
        active=True: bool
            Flag indicating whether to use as active scalar data.

        Returns the vtkDataArray that was added.

        """
        # Make sure data argument is a valid array
        if not isinstance(data, np.ndarray):
            data = np.array(data)

        assert data.dtype == float and data.shape == tuple(self.elements)

        # Convert scalar field to VTK array. A trailing length-1 axis is
        # appended, presumably so the multi-array converter sees one
        # component per point (the base class asserts exactly one).
        npa2da = vtkDoubleArrayFromNumPyMultiArray(data[..., np.newaxis])
        return vtkBaseGrid.add_scalar_data_array(self, npa2da, name, active)

    def add_vector_field(self, data, name=None, active=True):
        """Add a vector field sampled at the grid points.

        data: NumPy array of dtype float and shape ``tuple(elements)+(3,)``
            Vector components at every grid point.
        name=None: str
            Unique identifier for the vector data.
        active=True: bool
            Flag indicating whether to use as active vector data.

        Returns the vtkDataArray that was added.

        """
        # Make sure data argument is a valid array
        if not isinstance(data, np.ndarray):
            data = np.array(data)

        assert data.dtype == float and data.shape == tuple(self.elements)+(3,)

        # Convert vector field to VTK array
        npa2da = vtkDoubleArrayFromNumPyMultiArray(data)
        return vtkBaseGrid.add_vector_data_array(self, npa2da, name, active)
315