PageRenderTime 46ms CodeModel.GetById 15ms RepoModel.GetById 1ms app.codeStats 0ms

/example/speech-demo/io_func/feat_io.py

https://gitlab.com/admin-github-cloud/mxnet
Python | 394 lines | 377 code | 13 blank | 4 comment | 14 complexity | 824c6e59c4bd565a2fedb57fd43ddb0b MD5 | raw file
import gzip
import os
import random
import re
import shlex
import sys
import time

import numpy

from utils import to_bool
from .feat_readers.common import *
from .feat_readers import stats
  10. class DataReadStream(object):
    # Description of the keys accepted in ``dataset_args`` (property-level
    # "required" flags, JSON Schema draft-3 style). Keys typed as
    # "string"/"integer"/"boolean" may arrive as strings; __init__ converts
    # most of them with int() or to_bool().
    SCHEMA = {
        "type": "object",
        "properties": {
            "gpu_chunk": {"type": ["string", "integer"], "required": False},
            "lst_file": {"type": "string"},
            "separate_lines": {"type": ["string", "integer", "boolean"], "required": False},
            "has_labels": {"type": ["string", "integer", "boolean"], "required": False},
            "file_format": {"type": "string"},
            "train_stat": {"type": "string", "required": False},
            "offset_labels": {"type": ["string", "integer", "boolean"], "required": False},
            # Disabled option, kept for reference:
            #"XXXchunk": {"type": ["string", "integer"], "required": False},
            "max_feats": {"type": ["string", "integer"], "required": False},
            "shuffle": {"type": ["string", "integer", "boolean"], "required": False},
            "seed": {"type": ["string", "integer"], "required": False},
            # Internal keys; __init__ requires these two to be given together.
            "_num_splits": {"type": ["string", "integer"], "required": False},
            "_split_id": {"type": ["string", "integer"], "required": False}
        }
    }
    # Sentinel return values: load_next_block() returns END_OF_DATA or
    # END_OF_PARTITION; load_next_seq() returns END_OF_SEQ.
    END_OF_DATA = -1
    END_OF_PARTITION = -2
    END_OF_SEQ = (None, None, None)
  32. def __init__(self, dataset_args, n_ins):
  33. # stats
  34. self.mean = None
  35. self.std = None
  36. if 'train_stat' in dataset_args.keys():
  37. train_stat = dataset_args['train_stat']
  38. featureStats = stats.FeatureStats()
  39. featureStats.Load(train_stat)
  40. self.mean = featureStats.GetMean()
  41. self.std = featureStats.GetInvStd()
  42. # open lstfile
  43. file_path = dataset_args["lst_file"]
  44. if file_path.endswith('.gz'):
  45. file_read = gzip.open(file_path, 'r')
  46. else:
  47. file_read = open(file_path, 'r')
  48. separate_lines = False
  49. if "separate_lines" in dataset_args:
  50. separate_lines = to_bool(dataset_args["separate_lines"])
  51. self.has_labels = True
  52. if "has_labels" in dataset_args:
  53. self.has_labels = to_bool(dataset_args["has_labels"])
  54. # parse it, file_lst is a list of (featureFile, labelFile) pairs in the input set
  55. lines = [ln.strip() for ln in file_read]
  56. lines = [ln for ln in lines if ln != "" ]
  57. if self.has_labels:
  58. if separate_lines:
  59. if len(lines) % 2 != 0:
  60. print("List has mis-matched number of feature files and label files")
  61. sys.exit(1)
  62. self.orig_file_lst = []
  63. for i in xrange(0, len(lines), 2):
  64. self.orig_file_lst.append((lines[i], lines[i+1]))
  65. else:
  66. self.orig_file_lst = []
  67. for i in xrange(len(lines)):
  68. pair = re.compile("\s+").split(lines[i])
  69. if len(pair) != 2:
  70. print(lines[i])
  71. print("Each line in the train and eval lists must contain feature file and label file separated by space character")
  72. sys.exit(1)
  73. self.orig_file_lst.append(pair)
  74. else:
  75. # no labels
  76. self.orig_file_lst = []
  77. for i in xrange(0, len(lines), 1):
  78. self.orig_file_lst.append((lines[i], None))
  79. # save arguments
  80. self.n_ins = n_ins
  81. self.file_format = dataset_args['file_format']
  82. self.file_format = "htk"
  83. if 'file_format' in dataset_args:
  84. self.file_format = dataset_args['file_format']
  85. self.offsetLabels = False
  86. if 'offset_labels' in dataset_args:
  87. self.offsetLabels = to_bool(dataset_args['offset_labels'])
  88. self.chunk_size = 32768
  89. if 'gpu_chunk' in dataset_args:
  90. self.chunk_size = int(dataset_args['gpu_chunk'])
  91. self.maxFeats = 0
  92. if "max_feats" in dataset_args:
  93. self.maxFeats = int(dataset_args["max_feats"])
  94. if self.maxFeats == 0:
  95. self.maxFeats = sys.maxint
  96. self.shuffle = True
  97. if 'shuffle' in dataset_args:
  98. self.shuffle = to_bool(dataset_args['shuffle'])
  99. self.seed = None
  100. if "seed" in dataset_args:
  101. self.seed = int(dataset_args["seed"])
  102. if int("_split_id" in dataset_args) + int("_num_splits" in dataset_args) == 1:
  103. raise Exception("_split_id must be used with _num_splits")
  104. self.num_splits = 0
  105. if "_num_splits" in dataset_args:
  106. self.num_splits = int(dataset_Args["_num_splits"])
  107. self.split_id = dataset_args["_split_id"]
  108. # internal state
  109. self.split_parts = False
  110. self.by_matrix = False
  111. self.x = numpy.zeros((self.chunk_size, self.n_ins), dtype=numpy.float32)
  112. if self.has_labels:
  113. self.y = numpy.zeros((self.chunk_size,), dtype=numpy.int32)
  114. else:
  115. self.y = None
  116. self.numpy_rng = numpy.random.RandomState(self.seed)
  117. #self.make_shared()
  118. self.initialize_read()
  119. def read_by_part(self):
  120. if self.file_format in ["kaldi"]:
  121. self.read_by_matrix()
  122. else: # htk
  123. self.split_parts = True
  124. def read_by_matrix(self):
  125. self.by_matrix = True
  126. def get_shared(self):
  127. return self.shared_x, self.shared_y
  128. def initialize_read(self):
  129. self.file_lst = self.orig_file_lst[:]
  130. if self.shuffle:
  131. self.numpy_rng.shuffle(self.file_lst)
  132. self.fileIndex = 0
  133. self.totalFrames = 0
  134. self.reader = None
  135. self.crossed_part = False
  136. self.done = False
  137. self.utt_id = None
  138. self.queued_feats = None
  139. self.queued_tgts = None
  140. def _end_of_data(self):
  141. return self.totalFrames >= self.maxFeats or self.fileIndex >= len(self.file_lst)
  142. def _queue_get(self, at_most):
  143. # if we have frames/labels queued, return at_most of those and queue the rest
  144. if self.queued_feats is None:
  145. return None
  146. num_queued = self.queued_feats.shape[0]
  147. at_most = min(at_most, num_queued)
  148. if at_most == num_queued: # no leftover after the split
  149. feats, tgts = self.queued_feats, self.queued_tgts
  150. self.queued_feats = None
  151. self.queued_tgts = None
  152. else:
  153. feats, self.queued_feats = numpy.array_split(self.queued_feats, [at_most])
  154. if self.queued_tgts is not None:
  155. tgts, self.queued_tgts = numpy.array_split(self.queued_tgts, [at_most])
  156. else:
  157. tgts = None
  158. return feats, tgts
  159. def _queue_excess(self, at_most, feats, tgts):
  160. assert(self.queued_feats is None)
  161. num_supplied = feats.shape[0]
  162. if num_supplied > at_most:
  163. feats, self.queued_feats = numpy.array_split(feats, [at_most])
  164. if tgts is not None:
  165. tgts, self.queued_tgts = numpy.array_split(tgts, [at_most])
  166. return feats, tgts
  167. # Returns frames/labels (if there are any) or None (otherwise) for current partition
  168. # Always set the pointers to the next partition
  169. def _load_fn(self, at_most):
  170. tup = self._queue_get(at_most)
  171. if tup is not None:
  172. return tup
  173. if self.reader is None:
  174. featureFile, labelFile = self.file_lst[self.fileIndex]
  175. self.reader = getReader(self.file_format, featureFile, labelFile)
  176. if self.reader.IsDone():
  177. self.fileIndex += 1
  178. self.reader.Cleanup()
  179. self.reader = None # cleanup
  180. return None
  181. tup = self.reader.Read()
  182. if tup is None:
  183. self.fileIndex += 1
  184. self.reader.Cleanup()
  185. self.reader = None # cleanup
  186. return None
  187. feats, tgts = tup
  188. # normalize here
  189. if self.mean is not None:
  190. feats -= self.mean
  191. if self.std is not None:
  192. feats *= self.std
  193. self.utt_id = self.reader.GetUttId()
  194. if feats.shape[1] != self.n_ins:
  195. errMs = "Dimension of features read does not match specified dimensions".format(feats.shape[1], self.n_ins)
  196. if self.has_labels and tgts is not None:
  197. if feats.shape[0] != tgts.shape[0]:
  198. errMs = "Number of frames in feature ({}) and label ({}) files does not match".format(self.featureFile, self.labelFile)
  199. raise FeatureException(errMsg)
  200. if self.offsetLabels:
  201. tgts = numpy.add(tgts, - 1)
  202. feats, tgts = self._queue_excess(at_most, feats, tgts)
  203. return feats, tgts
  204. def current_utt_id(self):
  205. assert(self.by_matrix or self.split_parts)
  206. return self.utt_id
  207. def load_next_seq(self):
  208. if self.done:
  209. return DataReadStream.END_OF_SEQ
  210. if self._end_of_data():
  211. if self.reader is not None:
  212. self.reader.Cleanup()
  213. self.reader = None
  214. self.done = True
  215. return DataReadStream.END_OF_SEQ
  216. num_feats = 0
  217. old_fileIndes = self.fileIndex
  218. self.utt_id = None
  219. tup = self._load_fn(self.chunk_size)
  220. if tup is None:
  221. return DataReadStream.END_OF_SEQ
  222. (loaded_feats, loaded_tgts) = tup
  223. return loaded_feats, loaded_tgts, self.utt_id
  224. def load_next_block(self):
  225. # if anything left...
  226. # set_value
  227. if self.crossed_part:
  228. self.crossed_part = False
  229. if not self.by_matrix: # <--- THERE IS A BUG IN THIS
  230. return DataReadStream.END_OF_PARTITION
  231. if self.done:
  232. return DataReadStream.END_OF_DATA
  233. if self._end_of_data():
  234. if self.reader is not None:
  235. self.reader.Cleanup()
  236. self.reader = None # cleanup
  237. self.done = True
  238. return DataReadStream.END_OF_DATA
  239. # keep loading features until we pass a partition or EOF
  240. num_feats = 0
  241. old_fileIndex = self.fileIndex
  242. self.utt_id = None
  243. while num_feats < self.chunk_size:
  244. if self.split_parts:
  245. if old_fileIndex != self.fileIndex:
  246. self.crossed_part = True
  247. break
  248. if self._end_of_data():
  249. break
  250. tup = self._load_fn(self.chunk_size - num_feats)
  251. if tup is None:
  252. continue
  253. (loaded_feat, loaded_label) = tup
  254. if self.has_labels and loaded_label is None:
  255. print >> sys.stderr, "Missing labels for: ", self.utt_id
  256. continue
  257. numFrames = loaded_feat.shape[0]
  258. # limit loaded_feat, loaded_label, and numFrames to maximum allowed
  259. allowed = self.maxFeats - self.totalFrames
  260. if numFrames > allowed:
  261. loaded_feat = loaded_feat[0:allowed]
  262. if self.has_labels:
  263. loaded_label = loaded_label[0:allowed]
  264. numFrames = allowed
  265. assert(numFrames == loaded_feat.shape[0])
  266. self.totalFrames += numFrames
  267. new_num_feats = num_feats + numFrames
  268. # if the x and y buffers are too small, make bigger ones
  269. # not possible any more; buffers are always fixed
  270. """
  271. if new_num_feats > self.x.shape[0]:
  272. newx = numpy.zeros((new_num_feats, self.n_ins), dtype=numpy.float32)
  273. newx[0:num_feats] = self.x[0:num_feats]
  274. self.x = newx
  275. if self.has_labels:
  276. newy = numpy.zeros((new_num_feats,), dtype=numpy.int32)
  277. newy[0:num_feats] = self.y[0:num_feats]
  278. self.y = newy
  279. """
  280. # place into [num_feats:num_feats+num_loaded]
  281. self.x[num_feats:new_num_feats] = loaded_feat
  282. if self.has_labels:
  283. self.y[num_feats:new_num_feats] = loaded_label
  284. num_feats = new_num_feats
  285. if self.by_matrix:
  286. break
  287. # if we loaded features, shuffle and copy to shared
  288. if num_feats != 0:
  289. if self.shuffle:
  290. x = self.x[0:num_feats]
  291. state = self.numpy_rng.get_state()
  292. self.numpy_rng.shuffle(x)
  293. self.x[0:num_feats] = x
  294. if self.has_labels:
  295. y = self.y[0:num_feats]
  296. self.numpy_rng.set_state(state)
  297. self.numpy_rng.shuffle(y)
  298. self.y[0:num_feats] = y
  299. assert(self.x.shape == (self.chunk_size, self.n_ins))
  300. self.shared_x.set_value(self.x, borrow = True)
  301. if self.has_labels:
  302. self.shared_y.set_value(self.y, borrow = True)
  303. #import hashlib
  304. #print self.totalFrames, self.x.sum(), hashlib.sha1(self.x.view(numpy.float32)).hexdigest()
  305. if self.by_matrix:
  306. self.crossed_part = True
  307. return num_feats
  308. def get_state(self):
  309. return self.numpy_rng.get_state()
  310. def set_state(self, state):
  311. self.numpy_rng.set_state(state)