PageRenderTime 26ms CodeModel.GetById 30ms RepoModel.GetById 1ms app.codeStats 0ms

/wav2vec_cycle_code/FragmentVC/src/fairseq/examples/speech_to_text/data_utils.py

https://gitlab.com/lwd17/enhanced_examplar_ae
Python | 309 lines | 279 code | 23 blank | 7 comment | 25 complexity | 16d74195e6449476ac95a521e724fd7e MD5 | raw file
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import csv
  7. import os
  8. import os.path as op
  9. import zipfile
  10. from functools import reduce
  11. from glob import glob
  12. from multiprocessing import cpu_count
  13. from typing import Any, Dict, List
  14. import numpy as np
  15. import pandas as pd
  16. import sentencepiece as sp
  17. from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
  18. from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
  19. from tqdm import tqdm
  20. UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
  21. BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
  22. EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
  23. PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
  24. def gen_vocab(
  25. input_path: str, output_path_prefix: str, model_type="bpe", vocab_size=1000,
  26. ):
  27. # Train SentencePiece Model
  28. arguments = [
  29. f"--input={input_path}",
  30. f"--model_prefix={output_path_prefix}",
  31. f"--model_type={model_type}",
  32. f"--vocab_size={vocab_size}",
  33. "--character_coverage=1.0",
  34. f"--num_threads={cpu_count()}",
  35. f"--unk_id={UNK_TOKEN_ID}",
  36. f"--bos_id={BOS_TOKEN_ID}",
  37. f"--eos_id={EOS_TOKEN_ID}",
  38. f"--pad_id={PAD_TOKEN_ID}",
  39. ]
  40. sp.SentencePieceTrainer.Train(" ".join(arguments))
  41. # Export fairseq dictionary
  42. spm = sp.SentencePieceProcessor()
  43. spm.Load(output_path_prefix + ".model")
  44. vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
  45. assert (
  46. vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
  47. and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
  48. and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
  49. and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
  50. )
  51. vocab = {
  52. i: s
  53. for i, s in vocab.items()
  54. if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
  55. }
  56. with open(output_path_prefix + ".txt", "w") as f_out:
  57. for _, s in sorted(vocab.items(), key=lambda x: x[0]):
  58. f_out.write(f"{s} 1\n")
  59. def extract_fbank_features(
  60. waveform,
  61. sample_rate,
  62. output_path=None,
  63. n_mel_bins=80,
  64. apply_utterance_cmvn=True,
  65. overwrite=False,
  66. ):
  67. if output_path is not None and op.exists(output_path) and not overwrite:
  68. return
  69. _waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers
  70. _waveform = _waveform.squeeze().numpy()
  71. features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
  72. if features is None:
  73. features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
  74. if features is None:
  75. raise ImportError(
  76. "Please install pyKaldi or torchaudio to enable "
  77. "online filterbank feature extraction"
  78. )
  79. if apply_utterance_cmvn:
  80. cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
  81. features = cmvn(features)
  82. if output_path is not None:
  83. np.save(output_path, features)
  84. else:
  85. return features
  86. def create_zip(data_root, zip_path):
  87. cwd = os.path.abspath(os.curdir)
  88. os.chdir(data_root)
  89. with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
  90. for filename in tqdm(glob("*.npy")):
  91. f.write(filename)
  92. os.chdir(cwd)
  93. def is_npy_data(data: bytes) -> bool:
  94. return data[0] == 147 and data[1] == 78
  95. def get_zip_manifest(zip_root, zip_filename):
  96. zip_path = op.join(zip_root, zip_filename)
  97. with zipfile.ZipFile(zip_path, mode="r") as f:
  98. info = f.infolist()
  99. manifest = {}
  100. for i in tqdm(info):
  101. utt_id = op.splitext(i.filename)[0]
  102. offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
  103. manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
  104. with open(zip_path, "rb") as f:
  105. f.seek(offset)
  106. data = f.read(file_size)
  107. assert len(data) > 1 and is_npy_data(data)
  108. return manifest
  109. def gen_config_yaml(
  110. data_root,
  111. spm_filename,
  112. yaml_filename="config.yaml",
  113. specaugment_policy="lb",
  114. prepend_tgt_lang_tag=False,
  115. sampling_alpha=1.0,
  116. ):
  117. data_root = op.abspath(data_root)
  118. writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
  119. writer.set_audio_root(op.abspath(data_root))
  120. writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
  121. writer.set_input_channels(1)
  122. writer.set_input_feat_per_channel(80)
  123. specaugment_setters = {
  124. "lb": writer.set_specaugment_lb_policy,
  125. "ld": writer.set_specaugment_ld_policy,
  126. "sm": writer.set_specaugment_sm_policy,
  127. "ss": writer.set_specaugment_ss_policy,
  128. }
  129. assert specaugment_policy in specaugment_setters
  130. specaugment_setters[specaugment_policy]()
  131. writer.set_bpe_tokenizer(
  132. {
  133. "bpe": "sentencepiece",
  134. "sentencepiece_model": op.join(data_root, spm_filename),
  135. }
  136. )
  137. if prepend_tgt_lang_tag:
  138. writer.set_prepend_tgt_lang_tag(True)
  139. writer.set_sampling_alpha(sampling_alpha)
  140. writer.set_feature_transforms("_train", ["specaugment"])
  141. writer.flush()
  142. def load_df_from_tsv(path: str):
  143. return pd.read_csv(
  144. path,
  145. sep="\t",
  146. header=0,
  147. encoding="utf-8",
  148. escapechar="\\",
  149. quoting=csv.QUOTE_NONE,
  150. na_filter=False,
  151. )
  152. def save_df_to_tsv(dataframe, path):
  153. dataframe.to_csv(
  154. path,
  155. sep="\t",
  156. header=True,
  157. index=False,
  158. encoding="utf-8",
  159. escapechar="\\",
  160. quoting=csv.QUOTE_NONE,
  161. )
  162. def filter_manifest_df(
  163. df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
  164. ):
  165. filters = {
  166. "no speech": df["audio"] == "",
  167. f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
  168. "empty sentence": df["tgt_text"] == "",
  169. }
  170. if is_train_split:
  171. filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
  172. if extra_filters is not None:
  173. filters.update(extra_filters)
  174. invalid = reduce(lambda x, y: x | y, filters.values())
  175. valid = ~invalid
  176. print(
  177. "| "
  178. + ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
  179. + f", total {invalid.sum()} filtered, {valid.sum()} remained."
  180. )
  181. return df[valid]
  182. class S2TDataConfigWriter(object):
  183. DEFAULT_VOCAB_FILENAME = "dict.txt"
  184. DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
  185. DEFAULT_INPUT_CHANNELS = 1
  186. def __init__(self, yaml_path):
  187. try:
  188. import yaml
  189. except ImportError:
  190. print("Please install PyYAML to load YAML files for S2T data config")
  191. self.yaml = yaml
  192. self.yaml_path = yaml_path
  193. self.config = {}
  194. def flush(self):
  195. with open(self.yaml_path, "w") as f:
  196. self.yaml.dump(self.config, f)
  197. def set_audio_root(self, audio_root=""):
  198. self.config["audio_root"] = audio_root
  199. def set_vocab_filename(self, vocab_filename="dict.txt"):
  200. self.config["vocab_filename"] = vocab_filename
  201. def set_specaugment(
  202. self,
  203. time_wrap_w: int,
  204. freq_mask_n: int,
  205. freq_mask_f: int,
  206. time_mask_n: int,
  207. time_mask_t: int,
  208. time_mask_p: float,
  209. ):
  210. self.config["specaugment"] = {
  211. "time_wrap_W": time_wrap_w,
  212. "freq_mask_N": freq_mask_n,
  213. "freq_mask_F": freq_mask_f,
  214. "time_mask_N": time_mask_n,
  215. "time_mask_T": time_mask_t,
  216. "time_mask_p": time_mask_p,
  217. }
  218. def set_specaugment_lb_policy(self):
  219. self.set_specaugment(
  220. time_wrap_w=0,
  221. freq_mask_n=1,
  222. freq_mask_f=27,
  223. time_mask_n=1,
  224. time_mask_t=100,
  225. time_mask_p=1.0,
  226. )
  227. def set_specaugment_ld_policy(self):
  228. self.set_specaugment(
  229. time_wrap_w=0,
  230. freq_mask_n=2,
  231. freq_mask_f=27,
  232. time_mask_n=2,
  233. time_mask_t=100,
  234. time_mask_p=1.0,
  235. )
  236. def set_specaugment_sm_policy(self):
  237. self.set_specaugment(
  238. time_wrap_w=0,
  239. freq_mask_n=2,
  240. freq_mask_f=15,
  241. time_mask_n=2,
  242. time_mask_t=70,
  243. time_mask_p=0.2,
  244. )
  245. def set_specaugment_ss_policy(self):
  246. self.set_specaugment(
  247. time_wrap_w=0,
  248. freq_mask_n=2,
  249. freq_mask_f=27,
  250. time_mask_n=2,
  251. time_mask_t=70,
  252. time_mask_p=0.2,
  253. )
  254. def set_input_channels(self, input_channels=1):
  255. self.config["input_channels"] = input_channels
  256. def set_input_feat_per_channel(self, input_feat_per_channel=80):
  257. self.config["input_feat_per_channel"] = input_feat_per_channel
  258. def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
  259. self.config["bpe_tokenizer"] = bpe_tokenizer
  260. def set_feature_transforms(self, split, transforms: List[str]):
  261. if "transforms" not in self.config:
  262. self.config["transforms"] = {}
  263. self.config["transforms"][split] = transforms
  264. def set_prepend_tgt_lang_tag(self, flag=True):
  265. self.config["prepend_tgt_lang_tag"] = flag
  266. def set_sampling_alpha(self, sampling_alpha=1.0):
  267. self.config["sampling_alpha"] = sampling_alpha