/deepxi/se_batch.py

https://github.com/anicolson/DeepXi
Python | 56 lines | 43 code | 2 blank | 11 comment | 8 complexity | 47b8753e1735c3fc4c644c06989335f4 MD5 | raw file
  1. ## AUTHOR: Aaron Nicolson
  2. ## AFFILIATION: Signal Processing Laboratory, Griffith University.
  3. ##
  4. ## This Source Code Form is subject to the terms of the Mozilla Public
  5. ## License, v. 2.0. If a copy of the MPL was not distributed with this
  6. ## file, You can obtain one at http://mozilla.org/MPL/2.0/.
  7. import contextlib, glob, os, pickle, platform, random, sys, wave
  8. import numpy as np
  9. from deepxi.utils import read_wav
  10. from scipy.io.wavfile import read
  11. def Batch(fdir, snr_l=[]):
  12. '''
  13. REQUIRES REWRITING. WILL BE MOVED TO deepxi/utils.py
  14. Places all of the test waveforms from the list into a numpy array.
  15. SPHERE format cannot be used. 'glob' is used to support Unix style pathname
  16. pattern expansions. Waveforms are padded to the maximum waveform length. The
  17. waveform lengths are recorded so that the correct lengths can be sliced
  18. for feature extraction. The SNR levels of each test file are placed into a
  19. numpy array. Also returns a list of the file names.
  20. Inputs:
  21. fdir - directory containing the waveforms.
  22. fnames - filename/s of the waveforms.
  23. snr_l - list of the SNR levels used.
  24. Outputs:
  25. wav_np - matrix of paded waveforms stored as a numpy array.
  26. len_np - length of each waveform strored as a numpy array.
  27. snr_test_np - numpy array of all the SNR levels for the test set.
  28. fname_l - list of filenames.
  29. '''
  30. fname_l = [] # list of file names.
  31. wav_l = [] # list for waveforms.
  32. snr_test_l = [] # list of SNR levels for the test set.
  33. # if isinstance(fnames, str): fnames = [fnames] # if string, put into list.
  34. fnames = ['*.wav', '*.flac', '*.mp3']
  35. for fname in fnames:
  36. for fpath in glob.glob(os.path.join(fdir, fname)):
  37. for snr in snr_l:
  38. if fpath.find('_' + str(snr) + 'dB') != -1:
  39. snr_test_l.append(snr) # append SNR level.
  40. (wav, _) = read_wav(fpath) # read waveform from given file path.
  41. if np.isnan(wav).any() or np.isinf(wav).any():
  42. raise ValueError('Error: NaN or Inf value. File path: %s.' % (file_path))
  43. wav_l.append(wav) # append.
  44. fname_l.append(os.path.basename(os.path.splitext(fpath)[0])) # append name.
  45. len_l = [] # list of the waveform lengths.
  46. maxlen = max(len(wav) for wav in wav_l) # maximum length of waveforms.
  47. wav_np = np.zeros([len(wav_l), maxlen], np.int16) # numpy array for waveform matrix.
  48. for (i, wav) in zip(range(len(wav_l)), wav_l):
  49. wav_np[i,:len(wav)] = wav # add waveform to numpy array.
  50. len_l.append(len(wav)) # append length of waveform to list.
  51. return wav_np, np.array(len_l, np.int32), np.array(snr_test_l, np.int32), fname_l