# fairseq/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py
import math
import os
import json

import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
import yaml

from fairseq import checkpoint_utils, tasks
from fairseq.file_io import PathManager

try:
    from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
    from simuleval.agents import SpeechAgent
    from simuleval.states import ListEntry, SpeechStates
except ImportError:
    print("Please install simuleval: pip install simuleval")

SHIFT_SIZE = 10  # feature frame shift, in ms
WINDOW_SIZE = 25  # feature analysis window, in ms
SAMPLE_RATE = 16000  # expected input sample rate, in Hz
FEATURE_DIM = 80  # number of log-mel filterbank bins
BOW_PREFIX = "\u2581"  # SentencePiece beginning-of-word marker

class OnlineFeatureExtractor:
    """
    Extract speech features on the fly.
    """

    def __init__(self, args):
        self.shift_size = args.shift_size
        self.window_size = args.window_size
        assert self.window_size >= self.shift_size

        self.sample_rate = args.sample_rate
        self.feature_dim = args.feature_dim
        self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000)
        self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000)
        self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000
        self.previous_residual_samples = []
        self.global_cmvn = args.global_cmvn
    def clear_cache(self):
        self.previous_residual_samples = []
    def __call__(self, new_samples):
        samples = self.previous_residual_samples + new_samples
        if len(samples) < self.num_samples_per_window:
            self.previous_residual_samples = samples
            return

        # num_frames is the number of frames from the new segment
        num_frames = math.floor(
            (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size))
            / self.num_samples_per_shift
        )

        # the number of samples used for feature extraction,
        # including some part of the previous segment
        effective_num_samples = int(
            num_frames * self.len_ms_to_samples(self.shift_size)
            + self.len_ms_to_samples(self.window_size - self.shift_size)
        )

        input_samples = samples[:effective_num_samples]
        self.previous_residual_samples = samples[
            num_frames * self.num_samples_per_shift:
        ]
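        # Worked example with the defaults (shift 10 ms, window 25 ms, 16 kHz):
        # one shift is 160 samples and the window/shift overlap is 240 samples.
        # With 720 buffered samples: num_frames = floor((720 - 240) / 160) = 3,
        # effective_num_samples = 3 * 160 + 240 = 720, and the residual kept
        # for the next call is samples[480:], i.e. the trailing 240 samples.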
        torch.manual_seed(1)  # fix the RNG so any dithering in fbank is repeatable
        output = kaldi.fbank(
            torch.FloatTensor(input_samples).unsqueeze(0),
            num_mel_bins=self.feature_dim,
            frame_length=self.window_size,
            frame_shift=self.shift_size,
        ).numpy()

        output = self.transform(output)

        return torch.from_numpy(output)
    def transform(self, input):
        if self.global_cmvn is None:
            return input

        # apply global cepstral mean and variance normalization
        mean = self.global_cmvn["mean"]
        std = self.global_cmvn["std"]
        x = np.subtract(input, mean)
        x = np.divide(x, std)
        return x
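
# Usage sketch (illustrative, not part of the original file): stream chunks of
# 16 kHz samples through the extractor. It buffers input until a full analysis
# window is available, then returns a (num_frames, feature_dim) float tensor.
#
#     extractor = OnlineFeatureExtractor(args)  # args as built by add_args below
#     extractor.clear_cache()                   # reset between utterances
#     for chunk in audio_chunks:                # each chunk: a list of float samples
#         feats = extractor(chunk)
#         if feats is not None:
#             ...                               # feed feats to the encoder
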
class TensorListEntry(ListEntry):
    """
    Data structure to store a list of tensors.
    """

    def append(self, value):
        if len(self.value) == 0:
            self.value = value
            return

        self.value = torch.cat([self.value] + [value], dim=0)

    def info(self):
        return {
            "type": str(self.new_value_type),
            "length": self.__len__(),
            "value": "" if type(self.value) is list else self.value.size(),
        }
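
# For example (illustrative): appending a (12, 80) feature tensor and then an
# (8, 80) one leaves the entry's value as a single (20, 80) tensor, so the
# encoder always sees the source frames as one contiguous sequence.
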
class FairseqSimulSTAgent(SpeechAgent):

    speech_segment_size = 40  # in ms, 4 pooling ratio * 10 ms step size

    def __init__(self, args):
        super().__init__(args)

        self.eos = DEFAULT_EOS

        self.gpu = getattr(args, "gpu", False)

        self.args = args

        self.load_model_vocab(args)

        if getattr(
            self.model.decoder.layers[0].encoder_attn,
            'pre_decision_ratio',
            None
        ) is not None:
            self.speech_segment_size *= (
                self.model.decoder.layers[0].encoder_attn.pre_decision_ratio
            )
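            # e.g. with a pre-decision ratio of 8 (a hypothetical value), the
            # agent reads 40 ms * 8 = 320 ms of speech per READ action, so the
            # policy decides once per pre-decision block rather than per frame.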
        args.global_cmvn = None
        if args.config:
            with open(os.path.join(args.data_bin, args.config), "r") as f:
                config = yaml.load(f, Loader=yaml.BaseLoader)

            if "global_cmvn" in config:
                args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"])

        if args.global_stats:
            with PathManager.open(args.global_stats, "r") as f:
                global_cmvn = json.loads(f.read())
                self.global_cmvn = {"mean": global_cmvn["mean"], "std": global_cmvn["stddev"]}

        self.feature_extractor = OnlineFeatureExtractor(args)

        self.max_len = args.max_len

        self.force_finish = args.force_finish

        torch.set_grad_enabled(False)
    def build_states(self, args, client, sentence_id):
        # Initialize states here, for example by adding customized entries.
        # This function is called at the beginning of every new sentence.
        states = SpeechStates(args, client, sentence_id, self)
        self.initialize_states(states)
        return states

    def to_device(self, tensor):
        if self.gpu:
            return tensor.cuda()
        else:
            return tensor.cpu()
    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--model-path', type=str, required=True,
                            help='path to your pretrained model.')
        parser.add_argument("--data-bin", type=str, required=True,
                            help="Path of data binary")
        parser.add_argument("--config", type=str, default=None,
                            help="Path to config yaml file")
        parser.add_argument("--global-stats", type=str, default=None,
                            help="Path to json file containing cmvn stats")
        parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
                            help="Subword splitter type for target text")
        parser.add_argument("--tgt-splitter-path", type=str, default=None,
                            help="Subword splitter model path for target text")
        parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation",
                            help="User directory for simultaneous translation")
        parser.add_argument("--max-len", type=int, default=200,
                            help="Max length of translation")
        parser.add_argument("--force-finish", default=False, action="store_true",
                            help="Force the model to finish the hypothesis if the source is not finished")
        parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE,
                            help="Shift size of feature extraction window.")
        parser.add_argument("--window-size", type=int, default=WINDOW_SIZE,
                            help="Window size of feature extraction window.")
        parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE,
                            help="Sample rate")
        parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM,
                            help="Acoustic feature dimension.")
        # fmt: on
        return parser
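
    # A typical SimulEval invocation for this agent might look like the
    # following (illustrative; paths are placeholders and the exact flags
    # depend on the installed simuleval version):
    #
    #     simuleval \
    #         --agent fairseq_simul_st_agent.py \
    #         --source src_audio.list --target tgt_text.txt \
    #         --data-bin ${DATA_ROOT} --config config_st.yaml \
    #         --model-path ${SAVE_DIR}/checkpoint_best.pt \
    #         --output ${OUTPUT_DIR} --scores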
    def load_model_vocab(self, args):

        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        task_args = state["cfg"]["task"]
        task_args.data = args.data_bin

        if args.config is not None:
            task_args.config_yaml = args.config

        task = tasks.setup_task(task_args)

        # build model for ensemble
        state["cfg"]["model"].load_pretrained_encoder_from = None
        state["cfg"]["model"].load_pretrained_decoder_from = None

        self.model = task.build_model(state["cfg"]["model"])
        self.model.load_state_dict(state["model"], strict=True)
        self.model.eval()
        self.model.share_memory()

        if self.gpu:
            self.model.cuda()

        # Set dictionary
        self.dict = {}
        self.dict["tgt"] = task.target_dictionary
    def initialize_states(self, states):
        self.feature_extractor.clear_cache()
        states.units.source = TensorListEntry()
        states.units.target = ListEntry()
        states.incremental_states = dict()

    def segment_to_units(self, segment, states):
        # Convert speech samples to features
        features = self.feature_extractor(segment)
        if features is not None:
            return [features]
        else:
            return []
    def units_to_segment(self, units, states):
        # Merge subwords into a full word.
        if self.model.decoder.dictionary.eos() == units[0]:
            return DEFAULT_EOS

        segment = []
        if None in units.value:
            units.value.remove(None)

        for index in units:
            if index is None:
                units.pop()
                continue  # skip the placeholder instead of detokenizing None
            token = self.model.decoder.dictionary.string([index])
            if token.startswith(BOW_PREFIX):
                if len(segment) == 0:
                    segment += [token.replace(BOW_PREFIX, "")]
                else:
                    for j in range(len(segment)):
                        units.pop()

                    string_to_return = ["".join(segment)]

                    if self.model.decoder.dictionary.eos() == units[0]:
                        string_to_return += [DEFAULT_EOS]

                    return string_to_return
            else:
                segment += [token.replace(BOW_PREFIX, "")]

        if (
            (len(units) > 0 and self.model.decoder.dictionary.eos() == units[-1])
            or len(states.units.target) > self.max_len
        ):
            tokens = [self.model.decoder.dictionary.string([unit]) for unit in units]
            return ["".join(tokens).replace(BOW_PREFIX, "")] + [DEFAULT_EOS]

        return None
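
    # Merging example (illustrative): with SentencePiece units queued as
    # ["\u2581Hel", "lo", "\u2581wor"], the first two build up
    # segment = ["Hel", "lo"]; when the next BOW_PREFIX token "\u2581wor"
    # arrives, those two units are popped and "Hello" is returned as one full
    # word, while "\u2581wor" stays queued until its own word is complete.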
    def update_model_encoder(self, states):
        if len(states.units.source) == 0:
            return

        src_indices = self.to_device(
            states.units.source.value.unsqueeze(0)
        )
        src_lengths = self.to_device(
            torch.LongTensor([states.units.source.value.size(0)])
        )

        states.encoder_states = self.model.encoder(src_indices, src_lengths)

        torch.cuda.empty_cache()

    def update_states_read(self, states):
        # Happens after a read action.
        self.update_model_encoder(states)
    def policy(self, states):
        if not getattr(states, "encoder_states", None):
            # Nothing has been encoded yet: read some source first.
            return READ_ACTION

        tgt_indices = self.to_device(
            torch.LongTensor(
                [self.model.decoder.dictionary.eos()]
                + [x for x in states.units.target.value if x is not None]
            ).unsqueeze(0)
        )

        states.incremental_states["steps"] = {
            "src": states.encoder_states["encoder_out"][0].size(0),
            "tgt": 1 + len(states.units.target),
        }

        states.incremental_states["online"] = {"only": torch.tensor(not states.finish_read())}

        x, outputs = self.model.decoder.forward(
            prev_output_tokens=tgt_indices,
            encoder_out=states.encoder_states,
            incremental_state=states.incremental_states,
        )

        states.decoder_out = x

        states.decoder_out_extra = outputs

        torch.cuda.empty_cache()

        if outputs.action == 0:
            return READ_ACTION
        else:
            return WRITE_ACTION
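
    # The "online" flag above tells the decoder whether the source is still
    # incomplete (states.finish_read() is False); outputs.action is the
    # simultaneous policy's decision, mapped here onto SimulEval's
    # READ_ACTION (0, take more source) and WRITE_ACTION (emit a token).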
    def predict(self, states):
        decoder_states = states.decoder_out

        lprobs = self.model.get_normalized_probs(
            [decoder_states[:, -1:]], log_probs=True
        )

        index = lprobs.argmax(dim=-1)

        index = index[0, 0].item()

        if (
            self.force_finish
            and index == self.model.decoder.dictionary.eos()
            and not states.finish_read()
        ):
            # If we want to force-finish the translation
            # (i.e. not stop before the source is fully read), return None
            # self.model.decoder.clear_cache(states.incremental_states)
            index = None

        return index
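
# End-to-end sketch of how SimulEval drives this agent (illustrative
# pseudocode; the actual loop lives inside simuleval, not in this file):
#
#     agent = FairseqSimulSTAgent(args)
#     states = agent.build_states(args, client, sentence_id)
#     while not finished:
#         if agent.policy(states) == READ_ACTION:
#             # pull speech_segment_size ms of audio, featurize, re-encode
#             units = agent.segment_to_units(next_speech_segment, states)
#             agent.update_states_read(states)
#         else:
#             token_index = agent.predict(states)
#             # units_to_segment() merges subwords and emits full words,
#             # ending with DEFAULT_EOS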