
/wav2vec_cycle_code/fairseq/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py

https://gitlab.com/lwd17/enhanced_examplar_ae
import math
import os
import json
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
import yaml
from fairseq import checkpoint_utils, tasks
from fairseq.file_io import PathManager

try:
    from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
    from simuleval.agents import SpeechAgent
    from simuleval.states import ListEntry, SpeechStates
except ImportError:
    print("Please install simuleval 'pip install simuleval'")
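
# Default feature-extraction parameters (frame shift / window length in ms,
# sample rate in Hz, number of mel bins) and the SentencePiece word-boundary
# marker used when merging subword units back into words.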
SHIFT_SIZE = 10
WINDOW_SIZE = 25
SAMPLE_RATE = 16000
FEATURE_DIM = 80
BOW_PREFIX = "\u2581"


class OnlineFeatureExtractor:
    """
    Extract speech feature on the fly.
    """

    def __init__(self, args):
        self.shift_size = args.shift_size
        self.window_size = args.window_size
        assert self.window_size >= self.shift_size

        self.sample_rate = args.sample_rate
        self.feature_dim = args.feature_dim
        self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000)
        self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000)
        self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000
        self.previous_residual_samples = []
        self.global_cmvn = args.global_cmvn

    def clear_cache(self):
        self.previous_residual_samples = []
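
    # __call__ consumes raw samples incrementally: it prepends any residual
    # samples left over from the previous call, works out how many complete
    # frames the buffer now covers, extracts log-mel filterbank features for
    # those frames with kaldi.fbank, and keeps the remaining tail samples as
    # the residual for the next call.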
    def __call__(self, new_samples):
        samples = self.previous_residual_samples + new_samples
        if len(samples) < self.num_samples_per_window:
            self.previous_residual_samples = samples
            return

        # num_frames is the number of frames from the new segment
        num_frames = math.floor(
            (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size))
            / self.num_samples_per_shift
        )

        # the number of frames used for feature extraction
        # including some part of the previous segment
        effective_num_samples = int(
            num_frames * self.len_ms_to_samples(self.shift_size)
            + self.len_ms_to_samples(self.window_size - self.shift_size)
        )

        input_samples = samples[:effective_num_samples]
        self.previous_residual_samples = samples[
            num_frames * self.num_samples_per_shift:
        ]

        torch.manual_seed(1)
        output = kaldi.fbank(
            torch.FloatTensor(input_samples).unsqueeze(0),
            num_mel_bins=self.feature_dim,
            frame_length=self.window_size,
            frame_shift=self.shift_size,
        ).numpy()

        output = self.transform(output)

        return torch.from_numpy(output)
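
    # Apply global cepstral mean/variance normalization when stats were loaded.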
    def transform(self, input):
        if self.global_cmvn is None:
            return input

        mean = self.global_cmvn["mean"]
        std = self.global_cmvn["std"]
        x = np.subtract(input, mean)
        x = np.divide(x, std)
        return x
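

# A ListEntry whose values are concatenated into a single tensor; the agent
# uses it to accumulate the source feature frames read so far.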
class TensorListEntry(ListEntry):
    """
    Data structure to store a list of tensor.
    """

    def append(self, value):
        if len(self.value) == 0:
            self.value = value
            return

        self.value = torch.cat([self.value] + [value], dim=0)

    def info(self):
        return {
            "type": str(self.new_value_type),
            "length": self.__len__(),
            "value": "" if type(self.value) is list else self.value.size(),
        }
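

# Simultaneous speech-to-text agent for SimulEval. On every READ action the new
# speech segment is converted to filterbank features and the encoder is re-run
# over everything received so far; the `action` field of the decoder output then
# determines whether to keep reading (READ_ACTION) or to emit the next subword
# token (WRITE_ACTION).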
class FairseqSimulSTAgent(SpeechAgent):

    speech_segment_size = 40  # in ms, 4 pooling ratio * 10 ms step size

    def __init__(self, args):
        super().__init__(args)

        self.eos = DEFAULT_EOS
        self.gpu = getattr(args, "gpu", False)
        self.args = args

        self.load_model_vocab(args)

        if getattr(
            self.model.decoder.layers[0].encoder_attn,
            'pre_decision_ratio',
            None
        ) is not None:
            self.speech_segment_size *= (
                self.model.decoder.layers[0].encoder_attn.pre_decision_ratio
            )

        args.global_cmvn = None
        if args.config:
            with open(os.path.join(args.data_bin, args.config), "r") as f:
                config = yaml.load(f, Loader=yaml.BaseLoader)

            if "global_cmvn" in config:
                args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"])

        if args.global_stats:
            with PathManager.open(args.global_stats, "r") as f:
                global_cmvn = json.loads(f.read())
                self.global_cmvn = {"mean": global_cmvn["mean"], "std": global_cmvn["stddev"]}

        self.feature_extractor = OnlineFeatureExtractor(args)
        self.max_len = args.max_len
        self.force_finish = args.force_finish

        torch.set_grad_enabled(False)

    def build_states(self, args, client, sentence_id):
        # Initialize states here, for example add customized entry to states
        # This function will be called at beginning of every new sentence
        states = SpeechStates(args, client, sentence_id, self)
        self.initialize_states(states)
        return states

    def to_device(self, tensor):
        if self.gpu:
            return tensor.cuda()
        else:
            return tensor.cpu()

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--model-path', type=str, required=True,
                            help='path to your pretrained model.')
        parser.add_argument("--data-bin", type=str, required=True,
                            help="Path of data binary")
        parser.add_argument("--config", type=str, default=None,
                            help="Path to config yaml file")
        parser.add_argument("--global-stats", type=str, default=None,
                            help="Path to json file containing cmvn stats")
        parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
                            help="Subword splitter type for target text")
        parser.add_argument("--tgt-splitter-path", type=str, default=None,
                            help="Subword splitter model path for target text")
        parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation",
                            help="User directory for simultaneous translation")
        parser.add_argument("--max-len", type=int, default=200,
                            help="Max length of translation")
        parser.add_argument("--force-finish", default=False, action="store_true",
                            help="Force the model to finish the hypothesis if the source is not finished")
        parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE,
                            help="Shift size of feature extraction window.")
        parser.add_argument("--window-size", type=int, default=WINDOW_SIZE,
                            help="Window size of feature extraction window.")
        parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE,
                            help="Sample rate")
        parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM,
                            help="Acoustic feature dimension.")
        # fmt: on
        return parser
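
    # Load the checkpoint on CPU, rebuild the task and model from the stored
    # config, and keep the target dictionary for turning indices back into text.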
    def load_model_vocab(self, args):
        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        task_args = state["cfg"]["task"]
        task_args.data = args.data_bin

        if args.config is not None:
            task_args.config_yaml = args.config

        task = tasks.setup_task(task_args)

        # build model for ensemble
        state["cfg"]["model"].load_pretrained_encoder_from = None
        state["cfg"]["model"].load_pretrained_decoder_from = None
        self.model = task.build_model(state["cfg"]["model"])
        self.model.load_state_dict(state["model"], strict=True)
        self.model.eval()
        self.model.share_memory()

        if self.gpu:
            self.model.cuda()

        # Set dictionary
        self.dict = {}
        self.dict["tgt"] = task.target_dictionary

    def initialize_states(self, states):
        self.feature_extractor.clear_cache()
        states.units.source = TensorListEntry()
        states.units.target = ListEntry()
        states.incremental_states = dict()

    def segment_to_units(self, segment, states):
        # Convert speech samples to features
        features = self.feature_extractor(segment)
        if features is not None:
            return [features]
        else:
            return []
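
    # Emit full words only: decoded subword units are buffered until a token
    # starting with BOW_PREFIX signals the start of a new word (or EOS /
    # max-len is reached), at which point the buffered pieces are joined and
    # returned as one segment.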
    def units_to_segment(self, units, states):
        # Merge sub word to full word.
        if self.model.decoder.dictionary.eos() == units[0]:
            return DEFAULT_EOS

        segment = []
        if None in units.value:
            units.value.remove(None)

        for index in units:
            if index is None:
                units.pop()
            token = self.model.decoder.dictionary.string([index])
            if token.startswith(BOW_PREFIX):
                if len(segment) == 0:
                    segment += [token.replace(BOW_PREFIX, "")]
                else:
                    for j in range(len(segment)):
                        units.pop()

                    string_to_return = ["".join(segment)]

                    if self.model.decoder.dictionary.eos() == units[0]:
                        string_to_return += [DEFAULT_EOS]

                    return string_to_return
            else:
                segment += [token.replace(BOW_PREFIX, "")]

        if (
            len(units) > 0
            and self.model.decoder.dictionary.eos() == units[-1]
            or len(states.units.target) > self.max_len
        ):
            tokens = [self.model.decoder.dictionary.string([unit]) for unit in units]
            return ["".join(tokens).replace(BOW_PREFIX, "")] + [DEFAULT_EOS]

        return None
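
    # Re-encode the full source feature sequence accumulated so far; this is a
    # complete forward pass over all received frames, not an incremental update.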
    def update_model_encoder(self, states):
        if len(states.units.source) == 0:
            return

        src_indices = self.to_device(
            states.units.source.value.unsqueeze(0)
        )
        src_lengths = self.to_device(
            torch.LongTensor([states.units.source.value.size(0)])
        )

        states.encoder_states = self.model.encoder(src_indices, src_lengths)

        torch.cuda.empty_cache()

    def update_states_read(self, states):
        # Happens after a read action.
        self.update_model_encoder(states)
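
    # The read/write policy: run the incremental decoder on the current target
    # prefix and read the decision from `outputs.action`
    # (0 -> READ more source, otherwise WRITE the next token).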
    def policy(self, states):
        if not getattr(states, "encoder_states", None):
            return READ_ACTION

        tgt_indices = self.to_device(
            torch.LongTensor(
                [self.model.decoder.dictionary.eos()]
                + [x for x in states.units.target.value if x is not None]
            ).unsqueeze(0)
        )

        states.incremental_states["steps"] = {
            "src": states.encoder_states["encoder_out"][0].size(0),
            "tgt": 1 + len(states.units.target),
        }

        states.incremental_states["online"] = {"only": torch.tensor(not states.finish_read())}

        x, outputs = self.model.decoder.forward(
            prev_output_tokens=tgt_indices,
            encoder_out=states.encoder_states,
            incremental_state=states.incremental_states,
        )

        states.decoder_out = x
        states.decoder_out_extra = outputs

        torch.cuda.empty_cache()

        if outputs.action == 0:
            return READ_ACTION
        else:
            return WRITE_ACTION
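
    # Greedy prediction: take the argmax token of the last decoder step. With
    # --force-finish, an EOS predicted before the source is fully read is
    # replaced by None so that decoding continues.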
    def predict(self, states):
        decoder_states = states.decoder_out

        lprobs = self.model.get_normalized_probs(
            [decoder_states[:, -1:]], log_probs=True
        )

        index = lprobs.argmax(dim=-1)
        index = index[0, 0].item()

        if (
            self.force_finish
            and index == self.model.decoder.dictionary.eos()
            and not states.finish_read()
        ):
            # If we want to force finish the translation
            # (don't stop before finish reading), return a None
            # self.model.decoder.clear_cache(states.incremental_states)
            index = None

        return index
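

# Illustrative usage (a sketch, not part of this file): SimulEval loads the
# agent class from this module via --agent, and the agent-specific flags are
# the ones registered in add_args above. File names below are placeholders.
#
#   simuleval \
#       --agent fairseq_simul_st_agent.py \
#       --source source_audio_list.txt \
#       --target reference_translations.txt \
#       --data-bin ${DATA_BIN} \
#       --config config.yaml \
#       --model-path checkpoint_best.pt \
#       --output ${OUTPUT_DIR}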