PageRenderTime 57ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/wav2vec_cycle_code/fairseq/examples/speech_to_text/prep_covost_data.py

https://gitlab.com/lwd17/enhanced_examplar_ae
Python | 279 lines | 244 code | 16 blank | 19 comment | 10 complexity | 9caf74af6a5167e871b1c4afdf4cf9ec MD5 | raw file
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import argparse
  7. import logging
  8. from pathlib import Path
  9. import shutil
  10. from tempfile import NamedTemporaryFile
  11. from typing import Optional, Tuple
  12. import pandas as pd
  13. import torchaudio
  14. from examples.speech_to_text.data_utils import (
  15. create_zip,
  16. extract_fbank_features,
  17. filter_manifest_df,
  18. gen_config_yaml,
  19. gen_vocab,
  20. get_zip_manifest,
  21. load_df_from_tsv,
  22. save_df_to_tsv,
  23. )
  24. from torch import Tensor
  25. from torch.utils.data import Dataset
  26. from torchaudio.datasets.utils import download_url, extract_archive
  27. from tqdm import tqdm
  28. log = logging.getLogger(__name__)
  29. MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
  30. class CoVoST(Dataset):
  31. """Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).
  32. Args:
  33. root (str): root path to the dataset and generated manifests/features
  34. source_language (str): source (audio) language
  35. target_language (str, optional): target (text) language,
  36. None for no translation (default: None)
  37. version (int, optional): CoVoST version. (default: 2)
  38. download (bool, optional): Whether to download the dataset if it is not
  39. found at root path. (default: ``False``).
  40. """
  41. COVOST_URL_TEMPLATE = (
  42. "https://dl.fbaipublicfiles.com/covost/"
  43. "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
  44. )
  45. VERSIONS = {2}
  46. SPLITS = ["train", "dev", "test"]
  47. XX_EN_LANGUAGES = {
  48. 1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
  49. 2: [
  50. "fr",
  51. "de",
  52. "es",
  53. "ca",
  54. "it",
  55. "ru",
  56. "zh-CN",
  57. "pt",
  58. "fa",
  59. "et",
  60. "mn",
  61. "nl",
  62. "tr",
  63. "ar",
  64. "sv-SE",
  65. "lv",
  66. "sl",
  67. "ta",
  68. "ja",
  69. "id",
  70. "cy",
  71. ],
  72. }
  73. EN_XX_LANGUAGES = {
  74. 1: [],
  75. 2: [
  76. "de",
  77. "tr",
  78. "fa",
  79. "sv-SE",
  80. "mn",
  81. "zh-CN",
  82. "cy",
  83. "ca",
  84. "sl",
  85. "et",
  86. "id",
  87. "ar",
  88. "ta",
  89. "lv",
  90. "ja",
  91. ],
  92. }
  93. def __init__(
  94. self,
  95. root: str,
  96. split: str,
  97. source_language: str,
  98. target_language: Optional[str] = None,
  99. version: int = 2,
  100. ) -> None:
  101. assert version in self.VERSIONS and split in self.SPLITS
  102. assert source_language is not None
  103. self.no_translation = target_language is None
  104. if not self.no_translation:
  105. assert "en" in {source_language, target_language}
  106. if source_language == "en":
  107. assert target_language in self.EN_XX_LANGUAGES[version]
  108. else:
  109. assert source_language in self.XX_EN_LANGUAGES[version]
  110. else:
  111. # Hack here so that we can get "split" column from CoVoST TSV.
  112. # Note that we use CoVoST train split for ASR which is an extension
  113. # to Common Voice train split.
  114. target_language = "de" if source_language == "en" else "en"
  115. self.root: Path = Path(root)
  116. cv_tsv_path = self.root / "validated.tsv"
  117. assert cv_tsv_path.is_file()
  118. covost_url = self.COVOST_URL_TEMPLATE.format(
  119. src_lang=source_language, tgt_lang=target_language
  120. )
  121. covost_archive = self.root / Path(covost_url).name
  122. if not covost_archive.is_file():
  123. download_url(covost_url, self.root.as_posix(), hash_value=None)
  124. extract_archive(covost_archive.as_posix())
  125. cv_tsv = load_df_from_tsv(cv_tsv_path)
  126. covost_tsv = load_df_from_tsv(
  127. self.root / Path(covost_url).name.replace(".tar.gz", "")
  128. )
  129. df = pd.merge(
  130. left=cv_tsv[["path", "sentence", "client_id"]],
  131. right=covost_tsv[["path", "translation", "split"]],
  132. how="inner",
  133. on="path",
  134. )
  135. if split == "train":
  136. df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
  137. else:
  138. df = df[df["split"] == split]
  139. data = df.to_dict(orient="index").items()
  140. data = [v for k, v in sorted(data, key=lambda x: x[0])]
  141. self.data = []
  142. for e in data:
  143. try:
  144. path = self.root / "clips" / e["path"]
  145. _ = torchaudio.info(path.as_posix())
  146. self.data.append(e)
  147. except RuntimeError:
  148. pass
  149. def __getitem__(
  150. self, n: int
  151. ) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
  152. """Load the n-th sample from the dataset.
  153. Args:
  154. n (int): The index of the sample to be loaded
  155. Returns:
  156. tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
  157. sample_id)``
  158. """
  159. data = self.data[n]
  160. path = self.root / "clips" / data["path"]
  161. waveform, sample_rate = torchaudio.load(path)
  162. sentence = data["sentence"]
  163. translation = None if self.no_translation else data["translation"]
  164. speaker_id = data["client_id"]
  165. _id = data["path"].replace(".mp3", "")
  166. return waveform, sample_rate, sentence, translation, speaker_id, _id
  167. def __len__(self) -> int:
  168. return len(self.data)
  169. def process(args):
  170. root = Path(args.data_root).absolute() / args.src_lang
  171. if not root.is_dir():
  172. raise NotADirectoryError(f"{root} does not exist")
  173. # Extract features
  174. feature_root = root / "fbank80"
  175. feature_root.mkdir(exist_ok=True)
  176. for split in CoVoST.SPLITS:
  177. print(f"Fetching split {split}...")
  178. dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
  179. print("Extracting log mel filter bank features...")
  180. for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
  181. extract_fbank_features(
  182. waveform, sample_rate, feature_root / f"{utt_id}.npy"
  183. )
  184. # Pack features into ZIP
  185. zip_path = root / "fbank80.zip"
  186. print("ZIPing features...")
  187. create_zip(feature_root, zip_path)
  188. print("Fetching ZIP manifest...")
  189. audio_paths, audio_lengths = get_zip_manifest(zip_path)
  190. # Generate TSV manifest
  191. print("Generating manifest...")
  192. train_text = []
  193. task = f"asr_{args.src_lang}"
  194. if args.tgt_lang is not None:
  195. task = f"st_{args.src_lang}_{args.tgt_lang}"
  196. for split in CoVoST.SPLITS:
  197. manifest = {c: [] for c in MANIFEST_COLUMNS}
  198. dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
  199. for _, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
  200. manifest["id"].append(utt_id)
  201. manifest["audio"].append(audio_paths[utt_id])
  202. manifest["n_frames"].append(audio_lengths[utt_id])
  203. manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
  204. manifest["speaker"].append(speaker_id)
  205. is_train_split = split.startswith("train")
  206. if is_train_split:
  207. train_text.extend(manifest["tgt_text"])
  208. df = pd.DataFrame.from_dict(manifest)
  209. df = filter_manifest_df(df, is_train_split=is_train_split)
  210. save_df_to_tsv(df, root / f"{split}_{task}.tsv")
  211. # Generate vocab
  212. vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
  213. spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
  214. with NamedTemporaryFile(mode="w") as f:
  215. for t in train_text:
  216. f.write(t + "\n")
  217. gen_vocab(
  218. Path(f.name),
  219. root / spm_filename_prefix,
  220. args.vocab_type,
  221. args.vocab_size
  222. )
  223. # Generate config YAML
  224. gen_config_yaml(
  225. root,
  226. spm_filename=spm_filename_prefix + ".model",
  227. yaml_filename=f"config_{task}.yaml",
  228. specaugment_policy="lb",
  229. )
  230. # Clean up
  231. shutil.rmtree(feature_root)
  232. def main():
  233. parser = argparse.ArgumentParser()
  234. parser.add_argument(
  235. "--data-root", "-d", required=True, type=str,
  236. help="data root with sub-folders for each language <root>/<src_lang>"
  237. )
  238. parser.add_argument(
  239. "--vocab-type",
  240. default="unigram",
  241. required=True,
  242. type=str,
  243. choices=["bpe", "unigram", "char"],
  244. ),
  245. parser.add_argument("--vocab-size", default=1000, type=int)
  246. parser.add_argument("--src-lang", "-s", required=True, type=str)
  247. parser.add_argument("--tgt-lang", "-t", type=str)
  248. args = parser.parse_args()
  249. process(args)
  250. if __name__ == "__main__":
  251. main()