PageRenderTime 48ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/virtual_ipm/frontends/gui/analysis/views.py

https://gitlab.com/IPMsim/Virtual-IPM
Python | 368 lines | 298 code | 51 blank | 19 comment | 25 complexity | 78855bc808c631dae64487b0e719a7e7 MD5 | raw file
  1. # Virtual-IPM is a software for simulating IPMs and other related devices.
  2. # Copyright (C) 2021 The IPMSim collaboration <https://ipmsim.gitlab.io/>
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as
  6. # published by the Free Software Foundation, either version 3 of the
  7. # License, or (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  16. import os
  17. import re
  18. import matplotlib.pyplot as plt
  19. # noinspection PyUnresolvedReferences
  20. from mpl_toolkits.mplot3d import Axes3D
  21. import matplotlib.backends.backend_qt5agg
  22. import numpy as np
  23. import pandas
  24. from pandas.errors import ParserError
  25. import PyQt5.QtCore as QtCore
  26. import PyQt5.QtGui as QtGui
  27. import PyQt5.QtWidgets as Widgets
  28. from virtual_ipm.simulation.output import BasicRecorder
  29. from ..utils import getOpenFileName
# Short alias for the Qt5/Agg canvas class; the widgets in this module wrap
# their matplotlib figures in it so the figures can be added to Qt layouts.
FigureCanvas = matplotlib.backends.backend_qt5agg.FigureCanvasQTAgg
  31. class InitialFinalMapAnalyzer(Widgets.QMainWindow):
  32. def __init__(self, parent=None):
  33. super().__init__(parent=parent, flags=QtCore.Qt.WindowType.Window)
  34. self._profile_plot = ProfilePlot()
  35. self._initial_scatter = ScatterPlot('initial')
  36. self._final_scatter = ScatterPlot('final')
  37. v_splitter = Widgets.QSplitter(QtCore.Qt.Orientation.Vertical)
  38. v_splitter.addWidget(self._profile_plot)
  39. h_splitter = Widgets.QSplitter(QtCore.Qt.Orientation.Horizontal)
  40. h_splitter.addWidget(self._initial_scatter)
  41. h_splitter.addWidget(self._final_scatter)
  42. v_splitter.addWidget(h_splitter)
  43. self.setCentralWidget(v_splitter)
  44. menubar = self.menuBar()
  45. file_menu = menubar.addMenu('File')
  46. open_file_action = Widgets.QAction(
  47. QtGui.QIcon(os.path.join(os.path.split(__file__)[0], '../icons/open_xml.png')),
  48. 'Open output file',
  49. self
  50. )
  51. open_file_action.triggered.connect(self.open_file)
  52. open_file_action.setShortcut(QtGui.QKeySequence.Open)
  53. file_menu.addAction(open_file_action)
  54. toolbar = Widgets.QToolBar()
  55. toolbar.addAction(open_file_action)
  56. self.addToolBar(toolbar)
  57. self.setWindowTitle('[IPMSim] Virtual-IPM')
  58. self.resize(QtCore.QSize(1500, 1000))
  59. def open_file(self):
  60. filename = getOpenFileName(
  61. self,
  62. caption='Choose an output file',
  63. filter='CSV Files (*.csv);;All Files (*.*)'
  64. )
  65. if not filename:
  66. return
  67. try:
  68. df = pandas.read_csv(filename)
  69. except ParserError as err:
  70. Widgets.QMessageBox.critical(
  71. self,
  72. type(err).__name__,
  73. str(err)
  74. )
  75. else:
  76. if 'status' in df:
  77. statuses, counts = np.unique(df['status'], return_counts=True)
  78. if 'DETECTED' not in statuses or statuses.size > 1:
  79. Widgets.QMessageBox.information(
  80. self,
  81. 'File contains undetected particles',
  82. 'The selected data file contains particles which are marked as not '
  83. 'detected. The following statuses were encountered:\n\n'
  84. + '\n'.join('{}: {}'.format(s.capitalize(), c)
  85. for s, c in zip(statuses, counts))
  86. + '\n\nOnly detected particles will be included in the plots.'
  87. )
  88. df = df.loc[df['status'] == 'DETECTED']
  89. self._profile_plot.data_frame = df
  90. self._initial_scatter.data_frame = df
  91. self._final_scatter.data_frame = df
  92. class ProfilePlot(Widgets.QWidget):
  93. bin_size_slider_multiplier = 5
  94. def __init__(self, df=None, parent=None):
  95. super().__init__(parent=parent, flags=QtCore.Qt.WindowType.Widget)
  96. self._df = df
  97. self._bin_size_slider = Widgets.QSlider(QtCore.Qt.Orientation.Horizontal)
  98. # Range and bin size in [um].
  99. self._bin_size_slider.setRange(1, 20)
  100. self._bin_size_line_edit = Widgets.QLineEdit()
  101. self._replot_timer = QtCore.QTimer()
  102. # Wait 500 milliseconds until replotting in order to avoid plotting for fast changes of
  103. # the slider.
  104. self._replot_timer.setInterval(500)
  105. self._bin_size_slider.valueChanged.connect(
  106. lambda x: self._bin_size_line_edit.setText(
  107. '%d um' % (x * self.bin_size_slider_multiplier)
  108. )
  109. )
  110. self._bin_size_slider.valueChanged.connect(lambda x: self._replot_timer.start())
  111. self._bin_size_slider.setEnabled(False)
  112. self._bin_size_line_edit.setText(
  113. '%d um' % (self._bin_size_slider.value() * self.bin_size_slider_multiplier)
  114. )
  115. self._replot_timer.timeout.connect(self.replot)
  116. self._figure = plt.figure()
  117. self._axes = self._figure.add_subplot(111)
  118. self._canvas = FigureCanvas(self._figure)
  119. layout = Widgets.QVBoxLayout()
  120. h_layout = Widgets.QHBoxLayout()
  121. h_layout.addWidget(Widgets.QLabel('<b>Profiles</b>'))
  122. h_layout.addStretch(1)
  123. layout.addLayout(h_layout)
  124. h_layout = Widgets.QHBoxLayout()
  125. h_layout.addWidget(Widgets.QLabel('Bin size:'))
  126. h_layout.addWidget(self._bin_size_line_edit)
  127. h_layout.addWidget(self._bin_size_slider, stretch=1)
  128. layout.addLayout(h_layout)
  129. layout.addWidget(self._canvas, stretch=1)
  130. self.setLayout(layout)
  131. @property
  132. def data_frame(self):
  133. return self._df
  134. @data_frame.setter
  135. def data_frame(self, df):
  136. self._df = df
  137. self._bin_size_slider.setEnabled(True)
  138. self.replot(self._bin_size_slider.value())
  139. def replot(self, *args):
  140. self._reset_figure()
  141. self._plot_profile(BasicRecorder.possible_column_names['initial x'])
  142. self._plot_profile(BasicRecorder.possible_column_names['final x'])
  143. self._axes.set_xlabel('x [mm]')
  144. self._axes.set_ylabel('[a.u.]')
  145. self._axes.legend()
  146. self._canvas.draw()
  147. def _plot_profile(self, column_name):
  148. try:
  149. centers, bins = self._generate_histogram(self._df[column_name])
  150. except KeyError:
  151. pass
  152. else:
  153. self._axes.plot(centers, bins, label=column_name.split()[0])
  154. def _generate_histogram(self, samples):
  155. samples = np.array(samples) * 1.0e3 # [m] -> [mm]
  156. bin_size = (
  157. self._bin_size_slider.value() * self.bin_size_slider_multiplier
  158. * 1.0e-3 # [um] -> [mm]
  159. )
  160. n_bins = int((np.max(samples) - np.min(samples)) / bin_size)
  161. bins, edges = np.histogram(samples, bins=n_bins)
  162. centers = edges[:-1] + (edges[1] - edges[0]) / 2.
  163. return centers, bins
  164. def _reset_figure(self):
  165. self._figure.clear()
  166. self._axes = self._figure.add_subplot(111)
  167. self._canvas.draw()
  168. class ScatterPlot(Widgets.QWidget):
  169. distribution_options_spatial_3d = {
  170. 'x_scaling_factor': 1.0e3,
  171. 'y_scaling_factor': 1.0e3,
  172. 'z_scaling_factor': 1.0e3,
  173. 'x_label': 'x [mm]',
  174. 'y_label': 'y [mm]',
  175. 'z_label': 'z [mm]',
  176. }
  177. distribution_options_spatial_2d = {
  178. 'x_scaling_factor': 1.0e3,
  179. 'y_scaling_factor': 1.0e3,
  180. 'x_label': '{0} [mm]',
  181. 'y_label': '{0} [mm]',
  182. }
  183. distribution_options_time_and_spatial_2d = {
  184. 'x_scaling_factor': 1,
  185. 'y_scaling_factor': 1.0e3,
  186. 'x_label': 'simulation step',
  187. 'y_label': '{0} [mm]',
  188. }
  189. distributions = {
  190. '3d': {
  191. 'column-names': ['{0} x', '{0} y', '{0} z'],
  192. 'options': distribution_options_spatial_3d
  193. },
  194. 'xy-plane': {
  195. 'column-names': ['{0} x', '{0} y'],
  196. 'options': distribution_options_spatial_2d
  197. },
  198. 'xz-plane': {
  199. 'column-names': ['{0} x', '{0} z'],
  200. 'options': distribution_options_spatial_2d
  201. },
  202. 'yz-plane': {
  203. 'column-names': ['{0} y', '{0} z'],
  204. 'options': distribution_options_spatial_2d
  205. },
  206. 'tx-distribution': {
  207. 'column-names': ['{0} sim. step', '{0} x'],
  208. 'options': distribution_options_time_and_spatial_2d
  209. },
  210. 'ty-distribution': {
  211. 'column-names': ['{0} sim. step', '{0} y'],
  212. 'options': distribution_options_time_and_spatial_2d
  213. },
  214. }
  215. scatter_plot_marker_size = 1
  216. def __init__(self, stage, df=None, parent=None):
  217. super().__init__(parent=parent, flags=QtCore.Qt.WindowType.Widget)
  218. self._df = df
  219. if stage not in ('initial', 'final'):
  220. raise ValueError('Invalid value for stage: %s' % stage)
  221. self._stage = stage
  222. self._figure_3d = plt.figure()
  223. self._axes_3d = self._figure_3d.add_subplot(111, projection='3d')
  224. self._canvas_3d = FigureCanvas(self._figure_3d)
  225. self._figure_2d = plt.figure()
  226. self._axes_2d = self._figure_2d.add_subplot(111)
  227. self._canvas_2d = FigureCanvas(self._figure_2d)
  228. self._plot_selector = Widgets.QComboBox()
  229. self._plot_stack = Widgets.QStackedWidget()
  230. self._plot_stack.addWidget(self._canvas_3d)
  231. self._plot_stack.addWidget(self._canvas_2d)
  232. for distribution in sorted(self.distributions, reverse=True):
  233. self._plot_selector.addItem(distribution)
  234. self._plot_selector.currentIndexChanged.connect(self.plot)
  235. self._plot_selector.setEnabled(False)
  236. v_layout = Widgets.QVBoxLayout()
  237. h_layout = Widgets.QHBoxLayout()
  238. h_layout.addWidget(self._plot_selector)
  239. h_layout.addWidget(Widgets.QLabel(
  240. '<b>{0} particle distribution</b>'.format(self._stage.capitalize())
  241. ))
  242. h_layout.addStretch(1)
  243. v_layout.addLayout(h_layout)
  244. v_layout.addWidget(self._plot_stack)
  245. self.setLayout(v_layout)
  246. @property
  247. def data_frame(self):
  248. return self._df
  249. @data_frame.setter
  250. def data_frame(self, df):
  251. self._df = df
  252. self._plot_selector.setEnabled(True)
  253. self.plot(self._plot_selector.currentIndex())
  254. def plot(self, index):
  255. if self._df is None:
  256. return
  257. distribution = self._plot_selector.itemText(index)
  258. column_names = list(map(
  259. lambda x: BasicRecorder.possible_column_names[x.format(self._stage)],
  260. self.distributions[distribution]['column-names']
  261. ))
  262. options = self.distributions[distribution]['options'].copy()
  263. if len(column_names) == 3:
  264. self.plot_3d(*column_names, **options)
  265. else:
  266. plane = re.match(r'([a-z]{2})-(plane|distribution)', distribution).groups()[0]
  267. options['x_label'] = options['x_label'].format(plane[0])
  268. options['y_label'] = options['y_label'].format(plane[1])
  269. self.plot_2d(*column_names, **options)
  270. def plot_3d(self, x_name, y_name, z_name, x_label=None, y_label=None, z_label=None,
  271. x_scaling_factor=1.0, y_scaling_factor=1.0, z_scaling_factor=1.0):
  272. self._reset_3d_figure()
  273. self._plot_stack.setCurrentWidget(self._canvas_3d)
  274. try:
  275. xs = self._df[x_name.format(self._stage)] * x_scaling_factor
  276. ys = self._df[y_name.format(self._stage)] * y_scaling_factor
  277. zs = self._df[z_name.format(self._stage)] * z_scaling_factor
  278. except KeyError as err:
  279. Widgets.QMessageBox.information(
  280. self,
  281. 'Incomplete data',
  282. 'The plot could not be created because column "%s" is missing in the data file.'
  283. % str(err)
  284. )
  285. return
  286. self._axes_3d.scatter(xs, ys, zs, s=self.scatter_plot_marker_size)
  287. self._axes_3d.set_xlabel(x_label or x_name)
  288. self._axes_3d.set_ylabel(y_label or y_name)
  289. self._axes_3d.set_zlabel(z_label or z_name)
  290. self._canvas_3d.draw()
  291. def plot_2d(self, x_name, y_name, x_label=None, y_label=None, x_scaling_factor=1.0,
  292. y_scaling_factor=1.0):
  293. self._reset_2d_figure()
  294. self._plot_stack.setCurrentWidget(self._canvas_2d)
  295. try:
  296. xs = self._df[x_name.format(self._stage)] * x_scaling_factor
  297. ys = self._df[y_name.format(self._stage)] * y_scaling_factor
  298. except KeyError as err:
  299. Widgets.QMessageBox.information(
  300. self,
  301. 'Incomplete data',
  302. 'The plot could not be created because column "%s" is missing in the data file.'
  303. % str(err)
  304. )
  305. return
  306. self._axes_2d.scatter(xs, ys, s=self.scatter_plot_marker_size)
  307. self._axes_2d.set_xlabel(x_label or x_name)
  308. self._axes_2d.set_ylabel(y_label or y_name)
  309. self._canvas_2d.draw()
  310. def _reset_3d_figure(self):
  311. self._figure_3d.clear()
  312. self._axes_3d = self._figure_3d.add_subplot(111, projection='3d')
  313. self._canvas_3d.draw()
  314. def _reset_2d_figure(self):
  315. self._figure_2d.clear()
  316. self._axes_2d = self._figure_2d.add_subplot(111)
  317. self._canvas_2d.draw()