/lib/python/data_reader_DNN_v2.py

https://github.com/jtkim-kaist/VAD
import numpy as np
import glob
import utils
import scipy.io as sio

class DataReader(object):

    def __init__(self, input_dir, output_dir, norm_dir, w=19, u=9, name=None):
        # print(name.title() + " data reader initialization...")
        self._input_dir = input_dir
        self._output_dir = output_dir
        self._norm_dir = norm_dir
        self._input_file_list = sorted(glob.glob(input_dir + '/*.bin'))
        self._input_spec_list = sorted(glob.glob(input_dir + '/*.txt'))
        self._output_file_list = sorted(glob.glob(output_dir + '/*.bin'))
        self._file_len = len(self._input_file_list)
        self._name = name
        assert self._file_len == len(self._output_file_list), \
            "the numbers of input and output files do not match"
        self._w = w
        self._u = u
        self.eof = False
        self.file_change = False
        self.num_samples = 0
        self._inputs = 0
        self._outputs = 0
        self._epoch = 1
        self._num_file = 0
        self._start_idx = self._w
        # the .mat file is expected to hold 'global_mean' and 'global_std'
        norm_param = sio.loadmat(self._norm_dir + '/global_normalize_factor.mat')
        self.train_mean = norm_param['global_mean']
        self.train_std = norm_param['global_std']
        self.raw_inputs = 0  # un-normalized copy of the current batch, set in next_batch()
        # print("Done")
        # print("BOF : " + self._name + " file_" + str(self._num_file).zfill(2))
    def _binary_read_with_shape(self):
        pass

    @staticmethod
    def _read_input(input_file_dir, input_spec_dir):
        # flat float32 array holding a (num_frames, feature_size) matrix
        data = np.fromfile(input_file_dir, dtype=np.float32)
        with open(input_spec_dir, 'r') as f:
            spec = f.readline()
        size = spec.split(',')
        # column-major (Fortran) order: the .bin files are assumed to be
        # written column-major (as MATLAB does)
        data = data.reshape((int(size[0]), int(size[1])), order='F')
        return data
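
    # The companion .txt spec file is assumed to hold the matrix shape on its
    # first line as "num_frames,feature_size"; e.g. a (hypothetical) spec of
    # "7500,120" reshapes the binary into a (7500, 120) float32 matrix.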

    @staticmethod
    def _read_output(output_file_dir):
        data = np.fromfile(output_file_dir, dtype=np.float32)  # shape: (num_frames,)
        data = data.reshape(-1, 1)  # shape: (num_frames, 1)
        return data

    @staticmethod
    def _padding(inputs, batch_size, w_val):
        # zero-pad the tail so the frame count is a multiple of batch_size
        pad_size = batch_size - inputs.shape[0] % batch_size
        inputs = np.concatenate((inputs, np.zeros((pad_size, inputs.shape[1]), dtype=np.float32)))
        # add w_val zero rows at both ends so the context slice in
        # next_batch() never runs out of bounds
        window_pad = np.zeros((w_val, inputs.shape[1]))
        inputs = np.concatenate((window_pad, inputs, window_pad), axis=0)
        return inputs
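
    # Worked example of the padding arithmetic (hypothetical sizes): for 1230
    # frames with batch_size=500 and w_val=19, pad_size = 500 - 1230 % 500 = 270,
    # giving 1500 frames, plus 19 zero rows at each end -> 1538 rows in total.
    # (If the length is already a multiple of batch_size, a full extra batch of
    # zeros is appended.)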

    def next_batch(self, batch_size):
        if self._start_idx == self._w:
            # at the beginning of a file: load and pad the next input/output pair
            self._inputs = self._padding(
                self._read_input(self._input_file_list[self._num_file],
                                 self._input_spec_list[self._num_file]), batch_size, self._w)
            self._outputs = self._padding(self._read_output(self._output_file_list[self._num_file]),
                                          batch_size, self._w)
            assert np.shape(self._inputs)[0] == np.shape(self._outputs)[0], \
                ("the numbers of samples in the input (%d) and output (%d) files do not match"
                 % (np.shape(self._inputs)[0], np.shape(self._outputs)[0]))
            self.num_samples = np.shape(self._outputs)[0]

        if self._start_idx + batch_size > self.num_samples:
            # current file is exhausted: move on to the next one
            self._start_idx = self._w
            self.file_change = True
            self._num_file += 1
            # print("EOF : " + self._name + " file_" + str(self._num_file-1).zfill(2) +
            #       " -> BOF : " + self._name + " file_" + str(self._num_file).zfill(2))
            if self._num_file > self._file_len - 1:
                # last file reached: wrap around to the first one
                self.eof = True
                self._num_file = 0
                # print("EOF : last " + self._name + " file. " + "-> BOF : " + self._name + " file_" +
                #       str(self._num_file).zfill(2))
            self._inputs = self._padding(
                self._read_input(self._input_file_list[self._num_file],
                                 self._input_spec_list[self._num_file]), batch_size, self._w)
            self._outputs = self._padding(self._read_output(self._output_file_list[self._num_file]),
                                          batch_size, self._w)
            # trim the labels to the input length in case the padded sizes differ
            data_len = np.shape(self._inputs)[0]
            self._outputs = self._outputs[0:data_len, :]
            assert np.shape(self._inputs)[0] == np.shape(self._outputs)[0], \
                ("the numbers of samples in the input (%d) and output (%d) files do not match"
                 % (np.shape(self._inputs)[0], np.shape(self._outputs)[0]))
            self.num_samples = np.shape(self._outputs)[0]
        else:
            self.file_change = False
            self.eof = False

        # slice a batch plus w frames of context on each side
        inputs = self._inputs[self._start_idx - self._w:self._start_idx + batch_size + self._w, :]
        self.raw_inputs = inputs  # keep the un-normalized slice for callers that need it
        inputs = self.normalize(inputs)
        inputs = utils.bdnn_transform(inputs, self._w, self._u)
        inputs = inputs[self._w: -self._w, :]
        outputs = self._outputs[self._start_idx:self._start_idx + batch_size, :]
        self._start_idx += batch_size

        return inputs, outputs

    # num_batches = (np.shape(self._outputs)[0] - np.shape(self._outputs)[0] % batch_size) / batch_size
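
    # Shape note (assuming utils.bdnn_transform concatenates the context
    # frames selected by w and u, as in the bDNN scheme): the returned inputs
    # have batch_size rows, each a stack of context frames around one center
    # frame, and outputs has shape (batch_size, 1).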

    def normalize(self, x):
        x = (x - self.train_mean) / self.train_std
        # a = (np.std(x, axis=0))
        return x
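
    # normalize() relies on NumPy broadcasting: assuming 'global_mean' and
    # 'global_std' are row vectors of length feature_size, every frame (row)
    # is standardized feature-wise.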

    def reader_initialize(self):
        self._num_file = 0
        # reset to self._w (not 0) so the check at the top of next_batch()
        # reloads the first file on the next call
        self._start_idx = self._w
        self.eof = False
        self.file_change = False

    def eof_checker(self):
        return self.eof

    def file_change_checker(self):
        return self.file_change

    def file_change_initialize(self):
        self.file_change = False


def dense_to_one_hot(labels_dense, num_classes=2):
    """Convert class labels from scalars to one-hot vectors."""
    # copied from the TensorFlow MNIST tutorial
    num_labels = labels_dense.shape[0]
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    # cast to int so the float32 labels produced by _read_output index correctly
    labels_one_hot.flat[index_offset + labels_dense.ravel().astype(int)] = 1
    return labels_one_hot
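
# Example (hypothetical labels): dense_to_one_hot(np.array([0., 1., 1.]))
# returns [[1., 0.], [0., 1.], [0., 1.]].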

# file_dir = "/home/sbie/github/VAD_KJT/Datamake/Database/Aurora2withSE"
# input_dir1 = file_dir + "/STFT2"
# output_dir1 = file_dir + "/Labels"
# dr = DataReader(input_dir1, output_dir1, input_dir1, name='test')
#
# for i in range(1000000):
#     tt, pp = dr.next_batch(500)
#     print("asdf")
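
Below is a minimal sketch of how the reader is typically driven for one training epoch. The directory paths are placeholders, and it assumes they contain the .bin/.txt feature pairs, the .bin label files, and global_normalize_factor.mat described above:

# Hypothetical paths; replace with real feature/label/normalization directories.
train_reader = DataReader("feat/train", "label/train", "norm", w=19, u=9, name="train")

while not train_reader.eof_checker():
    inputs, labels = train_reader.next_batch(500)
    # ... feed (inputs, labels) to the model here ...
    if train_reader.file_change_checker():
        train_reader.file_change_initialize()

train_reader.reader_initialize()  # rewind for the next epoch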