/sedfitter/convolve/monochromatic.py

https://github.com/astrofrog/sedfitter
Python | 135 lines | 72 code | 31 blank | 32 comment | 15 complexity | 30258b374edb750b06a7d924e15540a3 MD5 | raw file
  1. from __future__ import print_function, division
  2. import os
  3. import glob
  4. import numpy as np
  5. from astropy.logger import log
  6. from astropy.table import Table
  7. from astropy import units as u
  8. from astropy.utils.console import ProgressBar
  9. from ..convolved_fluxes import ConvolvedFluxes
  10. from ..sed import SED
  11. from ..models import load_parameter_table
  12. from .. import six
  13. from ..utils import parfile
  # Public names exported by this module.
__all__ = ['convolve_model_dir_monochromatic', ]


def convolve_model_dir_monochromatic(model_dir, overwrite=False, max_ram=8,
                                     wav_min=-np.inf * u.micron, wav_max=np.inf * u.micron):
    """
    Convolve all the model SEDs in a model directory with monochromatic filters.

    One convolved-flux file is written per selected wavelength, to
    ``<model_dir>/convolved/MO###.fits``, where ``###`` is the 1-based index
    of that wavelength in the SED wavelength array (zero-padded to at least
    three digits).

    Parameters
    ----------
    model_dir : str
        The path to the model directory
    overwrite : bool, optional
        Whether to overwrite the output files
    max_ram : float, optional
        The maximum amount of RAM that can be used (in Gb). This controls how
        many wavelengths are processed at once.
    wav_min : `~astropy.units.Quantity`, optional
        The minimum wavelength to consider. Wavelengths greater than or equal
        to this value are output.
    wav_max : `~astropy.units.Quantity`, optional
        The maximum wavelength to consider. Wavelengths strictly below this
        value are output.

    Returns
    -------
    filters : `~astropy.table.Table`
        Table with a ``wav`` column (all SED wavelengths) and a ``filter``
        column holding the ``MO###`` name of each wavelength that was written
        (empty for wavelengths outside the requested range).

    Raises
    ------
    ValueError
        If the model directory is a new-style (version > 1) directory, or if
        no SED wavelength lies in the requested wavelength range.
    Exception
        If no SED files are found in the model directory.
    """

    modpar = parfile.read(os.path.join(model_dir, 'models.conf'), 'conf')
    if modpar.get('version', 1) > 1:
        raise ValueError("monochromatic filters are no longer used for new-style model directories")

    # Create 'convolved' sub-directory if needed
    convolved_dir = os.path.join(model_dir, 'convolved')
    if not os.path.exists(convolved_dir):
        os.mkdir(convolved_dir)

    # Find all SED files to convolve (flat or one sub-directory deep,
    # compressed or not)
    seds_dir = os.path.join(model_dir, 'seds')
    sed_files = sorted(glob.glob(os.path.join(seds_dir, '*.fits.gz')) +
                       glob.glob(os.path.join(seds_dir, '*', '*.fits.gz')) +
                       glob.glob(os.path.join(seds_dir, '*.fits')) +
                       glob.glob(os.path.join(seds_dir, '*', '*.fits')))

    par_table = load_parameter_table(model_dir)

    # Find number of models
    n_models = len(sed_files)
    if n_models == 0:
        raise Exception("No SEDs found in %s" % model_dir)
    else:
        log.info("{0} SEDs found in {1}".format(n_models, model_dir))

    # Find out apertures and wavelengths (taken from the first SED and
    # assumed to be shared by all SEDs in the directory)
    first_sed = SED.read(sed_files[0])
    n_ap = first_sed.n_ap
    apertures = first_sed.apertures
    n_wav = first_sed.n_wav
    wavelengths = first_sed.wav

    # For model grids that are very large, it is not possible to compute all
    # fluxes in one go, so we need to process in chunks in wavelength space.
    # The estimate assumes 4 bytes for each flux and each error value, for
    # every model and aperture.
    chunk_size = min(n_wav, int(np.floor(max_ram * 1024. ** 3 / (4. * 2. * n_models * n_ap))))
    # Always process at least one wavelength per chunk, even if max_ram is
    # too small for a single wavelength (a zero step would break the loop).
    chunk_size = max(chunk_size, 1)

    if chunk_size == n_wav:
        log.info("Producing all monochromatic files in one go")
    else:
        log.info("Producing monochromatic files in chunks of {0}".format(chunk_size))

    filters = Table()
    filters['wav'] = wavelengths
    filters['filter'] = np.zeros(wavelengths.shape, dtype='S10')

    # Figure out range of wavelength indices to use
    # (wavelengths array is sorted in reverse order). jlo is the index of the
    # largest wavelength < wav_max, jhi the index of the smallest wavelength
    # >= wav_min; both bounds are inclusive indices.
    jlo = n_wav - 1 - (wavelengths[::-1].searchsorted(wav_max) - 1)
    jhi = n_wav - 1 - wavelengths[::-1].searchsorted(wav_min)

    if jhi < jlo:
        raise ValueError("No wavelengths found between wav_min={0} and "
                         "wav_max={1}".format(wav_min, wav_max))

    chunk_size = min(chunk_size, jhi - jlo + 1)

    # Loop over wavelength chunks. jhi is inclusive, hence jhi + 1 as the
    # range stop so that the last wavelength is never skipped.
    for jmin in range(jlo, jhi + 1, chunk_size):

        # Find upper wavelength to compute; the final chunk may be shorter
        # than chunk_size.
        jmax = min(jmin + chunk_size - 1, jhi)
        n_chunk = jmax - jmin + 1

        log.info('Processing wavelengths {0} to {1}'.format(jmin, jmax))

        # Set up convolved fluxes, one per wavelength in this chunk
        fluxes = [ConvolvedFluxes(model_names=np.zeros(n_models, dtype='U30' if six.PY3 else 'S30'),
                                  apertures=apertures, initialize_arrays=True)
                  for _ in range(n_chunk)]

        # These do not depend on the SED, so set them once per chunk
        for j, conv in enumerate(fluxes):
            conv.central_wavelength = wavelengths[jmin + j]
            conv.apertures = apertures

        # Loop over SEDs
        with ProgressBar(n_models) as bar:
            for im, sed_file in enumerate(sed_files):

                bar.update()

                log.debug('Processing {0}'.format(os.path.basename(sed_file)))

                # Read in SED
                s = SED.read(sed_file, unit_freq=u.Hz, unit_flux=u.mJy, order='nu')

                # Copy the flux at each wavelength of this chunk
                for j, conv in enumerate(fluxes):
                    jw = jmin + j
                    conv.model_names[im] = s.name
                    if n_ap == 1:
                        conv.flux[im] = s.flux[0, jw]
                        conv.error[im] = s.error[0, jw]
                    else:
                        conv.flux[im, :] = s.flux[:, jw]
                        conv.error[im, :] = s.error[:, jw]

        # Write out one file per wavelength, ordered like the parameter table
        for j, conv in enumerate(fluxes):
            jw = jmin + j
            conv.sort_to_match(par_table['MODEL_NAME'])
            conv.write(os.path.join(convolved_dir, 'MO{0:03d}.fits'.format(jw + 1)),
                       overwrite=overwrite)
            filters['filter'][jw] = "MO{0:03d}".format(jw + 1)

    return filters