PageRenderTime 502ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/wav2vec_cycle_code/fairseq/examples/speech_to_text/prep_mtedx_data.py

https://gitlab.com/lwd17/enhanced_examplar_ae
Python | 271 lines | 239 code | 15 blank | 17 comment | 28 complexity | 3a746bfab87b03f768788506834ac885 MD5 | raw file
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import argparse
  7. import logging
  8. import os
  9. from pathlib import Path
  10. import shutil
  11. from itertools import groupby
  12. from tempfile import NamedTemporaryFile
  13. from typing import Tuple
  14. import pandas as pd
  15. import soundfile as sf
  16. from examples.speech_to_text.data_utils import (
  17. create_zip,
  18. extract_fbank_features,
  19. filter_manifest_df,
  20. gen_config_yaml,
  21. gen_vocab,
  22. get_zip_manifest,
  23. load_df_from_tsv,
  24. save_df_to_tsv,
  25. )
  26. import torch
  27. from torch.utils.data import Dataset
  28. from tqdm import tqdm
  29. from fairseq.data.audio.audio_utils import get_waveform, convert_waveform
  30. log = logging.getLogger(__name__)
  31. MANIFEST_COLUMNS = [
  32. "id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"
  33. ]
  34. class mTEDx(Dataset):
  35. """
  36. Create a Dataset for Multilingual TEDx.
  37. Each item is a tuple of the form: waveform, sample_rate, source utterance,
  38. target utterance, speaker_id, utterance_id
  39. """
  40. SPLITS = ["train", "valid", "test"]
  41. LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar",
  42. "de-de", "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es",
  43. "fr-pt", "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"]
  44. def __init__(self, root: str, lang: str, split: str) -> None:
  45. assert split in self.SPLITS and lang in self.LANGPAIRS
  46. _root = Path(root) / f"{lang}" / "data" / split
  47. wav_root, txt_root = _root / "wav", _root / "txt"
  48. assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
  49. # Load audio segments
  50. try:
  51. import yaml
  52. except ImportError:
  53. print(
  54. "Please install PyYAML to load the Multilingual TEDx YAML files"
  55. )
  56. with open(txt_root / f"{split}.yaml") as f:
  57. segments = yaml.load(f, Loader=yaml.BaseLoader)
  58. # Load source and target utterances
  59. src, tgt = lang.split("-")
  60. for _lang in [src, tgt]:
  61. with open(txt_root / f"{split}.{_lang}") as f:
  62. utterances = [r.strip() for r in f]
  63. assert len(segments) == len(utterances)
  64. for i, u in enumerate(utterances):
  65. segments[i][_lang] = u
  66. # Gather info
  67. self.data = []
  68. for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
  69. wav_filename = wav_filename.replace(".wav", ".flac")
  70. wav_path = wav_root / wav_filename
  71. sample_rate = sf.info(wav_path.as_posix()).samplerate
  72. seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
  73. for i, segment in enumerate(seg_group):
  74. offset = int(float(segment["offset"]) * sample_rate)
  75. n_frames = int(float(segment["duration"]) * sample_rate)
  76. _id = f"{wav_path.stem}_{i}"
  77. self.data.append(
  78. (
  79. wav_path.as_posix(),
  80. offset,
  81. n_frames,
  82. sample_rate,
  83. segment[src],
  84. segment[tgt],
  85. segment["speaker_id"],
  86. tgt,
  87. _id,
  88. )
  89. )
  90. def __getitem__(
  91. self, n: int
  92. ) -> Tuple[torch.Tensor, int, str, str, str, str, str]:
  93. wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, \
  94. utt_id = self.data[n]
  95. waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
  96. waveform = torch.from_numpy(waveform)
  97. return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id
  98. def __len__(self) -> int:
  99. return len(self.data)
  100. def process(args):
  101. root = Path(args.data_root).absolute()
  102. for lang in mTEDx.LANGPAIRS:
  103. cur_root = root / f"{lang}"
  104. if not cur_root.is_dir():
  105. print(f"{cur_root.as_posix()} does not exist. Skipped.")
  106. continue
  107. # Extract features
  108. audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80")
  109. audio_root.mkdir(exist_ok=True)
  110. for split in mTEDx.SPLITS:
  111. print(f"Fetching split {split}...")
  112. dataset = mTEDx(root.as_posix(), lang, split)
  113. if args.use_audio_input:
  114. print("Converting audios...")
  115. for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
  116. tgt_sample_rate = 16_000
  117. _wavform, _ = convert_waveform(
  118. waveform, sample_rate, to_mono=True,
  119. to_sample_rate=tgt_sample_rate
  120. )
  121. sf.write(
  122. (audio_root / f"{utt_id}.flac").as_posix(),
  123. _wavform.numpy(), tgt_sample_rate
  124. )
  125. else:
  126. print("Extracting log mel filter bank features...")
  127. for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
  128. extract_fbank_features(
  129. waveform, sample_rate, audio_root / f"{utt_id}.npy"
  130. )
  131. # Pack features into ZIP
  132. zip_path = cur_root / f"{audio_root.name}.zip"
  133. print("ZIPing audios/features...")
  134. create_zip(audio_root, zip_path)
  135. print("Fetching ZIP manifest...")
  136. audio_paths, audio_lengths = get_zip_manifest(zip_path)
  137. # Generate TSV manifest
  138. print("Generating manifest...")
  139. train_text = []
  140. for split in mTEDx.SPLITS:
  141. is_train_split = split.startswith("train")
  142. manifest = {c: [] for c in MANIFEST_COLUMNS}
  143. ds = mTEDx(args.data_root, lang, split)
  144. for _, _, src_utt, tgt_utt, spk_id, tgt_lang, utt_id in tqdm(ds):
  145. manifest["id"].append(utt_id)
  146. manifest["audio"].append(audio_paths[utt_id])
  147. manifest["n_frames"].append(audio_lengths[utt_id])
  148. manifest["tgt_text"].append(
  149. src_utt if args.task == "asr" else tgt_utt
  150. )
  151. manifest["speaker"].append(spk_id)
  152. manifest["tgt_lang"].append(tgt_lang)
  153. if is_train_split:
  154. train_text.extend(manifest["tgt_text"])
  155. df = pd.DataFrame.from_dict(manifest)
  156. df = filter_manifest_df(df, is_train_split=is_train_split)
  157. save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
  158. # Generate vocab
  159. v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
  160. spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
  161. with NamedTemporaryFile(mode="w") as f:
  162. for t in train_text:
  163. f.write(t + "\n")
  164. gen_vocab(
  165. Path(f.name),
  166. cur_root / spm_filename_prefix,
  167. args.vocab_type,
  168. args.vocab_size,
  169. )
  170. # Generate config YAML
  171. if args.use_audio_input:
  172. gen_config_yaml(
  173. cur_root,
  174. spm_filename=spm_filename_prefix + ".model",
  175. yaml_filename=f"config_{args.task}.yaml",
  176. specaugment_policy=None,
  177. extra={"use_audio_input": True}
  178. )
  179. else:
  180. gen_config_yaml(
  181. cur_root,
  182. spm_filename=spm_filename_prefix + ".model",
  183. yaml_filename=f"config_{args.task}.yaml",
  184. specaugment_policy="lb",
  185. )
  186. # Clean up
  187. shutil.rmtree(audio_root)
  188. def process_joint(args):
  189. cur_root = Path(args.data_root)
  190. assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \
  191. "do not have downloaded data available for all languages"
  192. # Generate vocab
  193. vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
  194. spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
  195. with NamedTemporaryFile(mode="w") as f:
  196. for lang in mTEDx.LANGPAIRS:
  197. tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv"
  198. df = load_df_from_tsv(tsv_path)
  199. for t in df["tgt_text"]:
  200. f.write(t + "\n")
  201. special_symbols = None
  202. if args.joint:
  203. # Add tgt_lang tags to dict
  204. special_symbols = list(
  205. {f'<lang:{lang.split("-")[1]}>' for lang in mTEDx.LANGPAIRS}
  206. )
  207. gen_vocab(
  208. Path(f.name),
  209. cur_root / spm_filename_prefix,
  210. args.vocab_type,
  211. args.vocab_size,
  212. special_symbols=special_symbols
  213. )
  214. # Generate config YAML
  215. gen_config_yaml(
  216. cur_root,
  217. spm_filename=spm_filename_prefix + ".model",
  218. yaml_filename=f"config_{args.task}.yaml",
  219. specaugment_policy="ld",
  220. prepend_tgt_lang_tag=(args.joint),
  221. )
  222. # Make symbolic links to manifests
  223. for lang in mTEDx.LANGPAIRS:
  224. for split in mTEDx.SPLITS:
  225. src_path = cur_root / f"{lang}" / f"{split}_{args.task}.tsv"
  226. desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
  227. if not desc_path.is_symlink():
  228. os.symlink(src_path, desc_path)
  229. def main():
  230. parser = argparse.ArgumentParser()
  231. parser.add_argument("--data-root", "-d", required=True, type=str)
  232. parser.add_argument(
  233. "--vocab-type",
  234. default="unigram",
  235. required=True,
  236. type=str,
  237. choices=["bpe", "unigram", "char"],
  238. ),
  239. parser.add_argument("--vocab-size", default=8000, type=int)
  240. parser.add_argument("--task", type=str, choices=["asr", "st"])
  241. parser.add_argument("--joint", action="store_true", help="")
  242. parser.add_argument("--use-audio-input", action="store_true")
  243. args = parser.parse_args()
  244. if args.joint:
  245. process_joint(args)
  246. else:
  247. process(args)
  248. if __name__ == "__main__":
  249. main()