# figure_compare_differences.py
#
# Source: https://github.com/wafels/eitwave
# (Python | 121 lines | 79 code | 25 blank | 17 comment | 13 complexity | MD5 ae981f9a1edac5bfb3a797437c583236)
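#
# Script that loads several observational data sets and creates a grid of
# plots comparing the effect of different differencing algorithms: running
# difference (RD), fractional base difference (PBD) and running difference
# of persistence images (RDP).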
  4. import os
  5. from copy import deepcopy
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. import matplotlib.cm as cm
  9. import mapcube_tools
  10. import swave_study as sws
  11. import aware_utils
  12. filepath = '~/eitwave/img/difference_comparison.png'
  13. # Summing of the simulated observations in the time direction
  14. temporal_summing = sws.temporal_summing
  15. # Summing of the simulated observations in the spatial directions
  16. spatial_summing = sws.spatial_summing
  17. wave_names = ['longetal2014_figure8a', 'longetal2014_figure8e', 'longetal2014_figure4']
  18. differencing_types = ['RDP', 'RD', 'PBD']
  19. info = {'longetal2014_figure8a': 20,
  20. 'longetal2014_figure8e': 20,
  21. 'longetal2014_figure4': 20}
  22. fontsize = 8
  23. maps = {}
  24. figure_type = 2
  25. #
  26. # Set up the matplotlib plot
  27. #
  28. plt.close('all')
  29. fig, axes = plt.subplots(3, len(wave_names), figsize=(9, 9))
  30. # Go through each wave
  31. for i, wave_name in enumerate(wave_names):
  32. index = info[wave_name]
  33. # Load observational data from file
  34. euv_wave_data = aware_utils.create_input_to_aware_for_test_observational_data(wave_name)
  35. #
  36. print('Accumulating AIA data.')
  37. mc = euv_wave_data['finalmaps']
  38. mc = mapcube_tools.accumulate(mapcube_tools.superpixel(mc, spatial_summing), temporal_summing)
  39. for differencing_type in differencing_types:
  40. if differencing_type == 'RD':
  41. # running difference
  42. print('Calculating the running difference.')
  43. mc_rd = mapcube_tools.running_difference(mc)
  44. new = deepcopy(mc_rd[index])
  45. new.plot_settings['cmap'] = cm.RdGy
  46. if differencing_type == 'PBD':
  47. # fraction base difference
  48. print('Calculating the base difference.')
  49. mc_pbd = mapcube_tools.base_difference(mc, fraction=True)
  50. new = deepcopy(mc_pbd[index+1])
  51. new.plot_settings['norm'].vmax = 0.5
  52. new.plot_settings['norm'].vmin = -0.5
  53. new.plot_settings['cmap'] = cm.RdGy
  54. if differencing_type == 'RDP':
  55. # running difference persistence images
  56. print('Calculating the running difference persistence images.')
  57. mc_rdp = mapcube_tools.running_difference(mapcube_tools.persistence(mc))
  58. new = deepcopy(mc_rdp[index])
  59. new.plot_settings['cmap'] = cm.gray_r
  60. maps[differencing_type] = new
  61. rd_all_vmax = np.max([maps['RD'].data.max(), maps['RDP'].data.max()])
  62. maps['RD'].plot_settings['norm'].vmax = rd_all_vmax
  63. maps['RDP'].plot_settings['norm'].vmax = rd_all_vmax
  64. # Go through each differencing type
  65. for j, differencing_type in enumerate(differencing_types):
  66. tm = maps[differencing_type]
  67. ta = axes[j, i]
  68. # Just use the built in map plotting
  69. if figure_type == 1:
  70. if j == 0:
  71. tm.plot(axes=ta, title=differencing_type + '\n' + tm.date.strftime("%Y/%m/%d %H:%M:%S"))
  72. else:
  73. tm.plot(axes=ta, title=differencing_type)
  74. tm.draw_limb(color='black')
  75. ta.set_xlabel('x (arcsec)', fontsize=fontsize)
  76. xtl = ta.axes.xaxis.get_majorticklabels()
  77. for l in range(0, len(xtl)):
  78. xtl[l].set_fontsize(0.67*fontsize)
  79. ta.set_ylabel('y (arcsec)', fontsize=fontsize)
  80. ytl = ta.axes.yaxis.get_majorticklabels()
  81. for l in range(0, len(ytl)):
  82. ytl[l].set_fontsize(0.67*fontsize)
  83. ta.axes.yaxis.set_visible(False)
  84. ta.axes.xaxis.set_visible(False)
  85. # Show the image data itself.
  86. if figure_type == 2:
  87. ta.imshow(tm.data)
  88. ta.set_axis_off()
  89. if i == 0:
  90. fig.text(0.05, 0.2 + i*0.33, differencing_type)
  91. plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0, rect=(0.0, 0.0, 1.0, 1.0))
  92. plt.savefig(os.path.expanduser(filepath))
  93. plt.close('all')
  94. fig.text(0, 0.2, 'RD')
  95. fig.text(0, 0.2 + 0.33, 'PD')
  96. fig.text(0, 0.2 + 0.66, 'RDP')