
/wav2vec_cycle_code/FragmentVC/src/fairseq/examples/speech_to_text/prep_covost_data.py

https://gitlab.com/lwd17/enhanced_examplar_ae
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import csv
import logging
import os
import os.path as op
import shutil
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple

import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    filter_manifest_df,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm


log = logging.getLogger(__name__)


MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]


class CoVoST(Dataset):
    """Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).

    Args:
        root (str): root path to the dataset and generated manifests/features
        source_language (str): source (audio) language
        target_language (str, optional): target (text) language,
            None for no translation (default: None)
        version (int, optional): CoVoST version. (default: 2)
        download (bool, optional): Whether to download the dataset if it is not
            found at root path. (default: ``False``).
    """

    CV_URL_TEMPLATE = (
        "https://voice-prod-bundler-ee1969a6ce8178826482b88"
        "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
    )
    COVOST_URL_TEMPLATE = (
        "https://dl.fbaipublicfiles.com/covost/"
        "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
    )

    VERSIONS = {2}
    SPLITS = ["train", "dev", "test"]

    CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}

    XX_EN_LANGUAGES = {
        1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
        2: [
            "fr",
            "de",
            "es",
            "ca",
            "it",
            "ru",
            "zh-CN",
            "pt",
            "fa",
            "et",
            "mn",
            "nl",
            "tr",
            "ar",
            "sv-SE",
            "lv",
            "sl",
            "ta",
            "ja",
            "id",
            "cy",
        ],
    }
    EN_XX_LANGUAGES = {
        1: [],
        2: [
            "de",
            "tr",
            "fa",
            "sv-SE",
            "mn",
            "zh-CN",
            "cy",
            "ca",
            "sl",
            "et",
            "id",
            "ar",
            "ta",
            "lv",
            "ja",
        ],
    }

    def __init__(
        self,
        root: str,
        split: str,
        source_language: str,
        target_language: Optional[str] = None,
        version: int = 2,
        download: bool = False,
    ) -> None:
        assert version in self.VERSIONS and split in self.SPLITS
        assert source_language is not None
        self.no_translation = target_language is None
        if not self.no_translation:
            assert "en" in {source_language, target_language}
            if source_language == "en":
                assert target_language in self.EN_XX_LANGUAGES[version]
            else:
                assert source_language in self.XX_EN_LANGUAGES[version]
        else:
            # Hack here so that we can get "split" column from CoVoST TSV.
            # Note that we use CoVoST train split for ASR which is an extension
            # to Common Voice train split.
            target_language = "de" if source_language == "en" else "en"

        self.root = os.path.join(root, "raw")
        os.makedirs(self.root, exist_ok=True)

        # Download and extract the Common Voice audio archive if requested.
        cv_url = self.CV_URL_TEMPLATE.format(
            ver=self.CV_VERSION_ID[version], lang=source_language
        )
        cv_archive = os.path.join(self.root, os.path.basename(cv_url))
        if download:
            if not os.path.isfile(cv_archive):
                download_url(cv_url, self.root, hash_value=None)
            extract_archive(cv_archive)

        # Download and extract the CoVoST translation TSV if requested.
        covost_url = self.COVOST_URL_TEMPLATE.format(
            src_lang=source_language, tgt_lang=target_language
        )
        covost_archive = os.path.join(self.root, os.path.basename(covost_url))
        if download:
            if not os.path.isfile(covost_archive):
                download_url(covost_url, self.root, hash_value=None)
            extract_archive(covost_archive)

        # Join Common Voice transcripts with CoVoST translations on clip path,
        # then keep only the rows belonging to the requested split.
        cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
        covost_tsv = self.load_from_tsv(
            os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
        )
        df = pd.merge(
            left=cv_tsv[["path", "sentence", "client_id"]],
            right=covost_tsv[["path", "translation", "split"]],
            how="inner",
            on="path",
        )
        if split == "train":
            df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
        else:
            df = df[df["split"] == split]
        self.data = df.to_dict(orient="index").items()
        self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]

    @classmethod
    def load_from_tsv(cls, path: str):
        return pd.read_csv(
            path,
            sep="\t",
            header=0,
            encoding="utf-8",
            escapechar="\\",
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    def __getitem__(
        self, n: int
    ) -> Tuple[Tensor, int, str, Optional[str], str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
            sample_id)``
        """
        data = self.data[n]
        path = os.path.join(self.root, "clips", data["path"])
        waveform, sample_rate = torchaudio.load(path)
        sentence = data["sentence"]
        translation = None if self.no_translation else data["translation"]
        speaker_id = data["client_id"]
        _id = data["path"].replace(".mp3", "")
        return waveform, sample_rate, sentence, translation, speaker_id, _id

    def __len__(self) -> int:
        return len(self.data)
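
    # Usage sketch (illustrative; not part of the upstream script). The root
    # path and language pair below are assumptions for demonstration:
    #
    #   dataset = CoVoST("/data/covost/fr", "dev", "fr", "en", download=True)
    #   waveform, sample_rate, sentence, translation, speaker_id, utt_id = dataset[0]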


def process(args):
    root = op.join(args.data_root, args.src_lang)
    os.makedirs(root, exist_ok=True)
    # Extract features
    feature_root = op.join(root, "fbank80")
    os.makedirs(feature_root, exist_ok=True)
    for split in CoVoST.SPLITS:
        print(f"Fetching split {split}...")
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
        print("Extracting log mel filter bank features...")
        for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
            extract_fbank_features(
                waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
            )
    # Pack features into ZIP
    zip_filename = "fbank80.zip"
    zip_path = op.join(root, zip_filename)
    print("ZIPing features...")
    create_zip(feature_root, zip_path)
    print("Fetching ZIP manifest...")
    zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
    # Generate TSV manifest
    print("Generating manifest...")
    train_text = []
    task = f"asr_{args.src_lang}"
    if args.tgt_lang is not None:
        task = f"st_{args.src_lang}_{args.tgt_lang}"
    for split in CoVoST.SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
        for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
            manifest["id"].append(utt_id)
            manifest["audio"].append(zip_manifest[utt_id])
            duration_ms = int(wav.size(1) / sr * 1000)
            # Frame count for a 25 ms analysis window with a 10 ms shift.
            manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
            manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
            manifest["speaker"].append(speaker_id)
        is_train_split = split.startswith("train")
        if is_train_split:
            train_text.extend(manifest["tgt_text"])
        df = pd.DataFrame.from_dict(manifest)
        df = filter_manifest_df(df, is_train_split=is_train_split)
        save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
    # Generate vocab
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
    with NamedTemporaryFile(mode="w") as f:
        for t in train_text:
            f.write(t + "\n")
        gen_vocab(
            f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
        )
    # Generate config YAML
    gen_config_yaml(
        root,
        spm_filename_prefix + ".model",
        yaml_filename=f"config_{task}.yaml",
        specaugment_policy="lb",
    )
    # Clean up
    shutil.rmtree(feature_root)
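
# Outputs written under <data_root>/<src_lang>/ by process(), as derived from
# the calls above: fbank80.zip, one <split>_<task>.tsv manifest per split, a
# SentencePiece vocab with prefix spm_<vocab_type><size>_<task> (its .model
# file is referenced by the YAML), and config_<task>.yaml.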


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    # Note: the original listing ended this call with a stray trailing comma
    # and marked the argument required despite giving a default; the default
    # now applies when --vocab-type is omitted.
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=1000, type=int)
    parser.add_argument("--src-lang", "-s", required=True, type=str)
    parser.add_argument("--tgt-lang", "-t", type=str)
    args = parser.parse_args()

    process(args)


if __name__ == "__main__":
    main()
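
# Usage sketch (illustrative; the data-root path is an assumption). Run from
# the fairseq repository root so that examples.speech_to_text is importable:
#
#   python examples/speech_to_text/prep_covost_data.py \
#       --data-root /data/covost --vocab-type char \
#       --src-lang fr --tgt-lang en
#
# Omitting --tgt-lang prepares ASR manifests ("asr_<src>") instead of ST.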