/wav2vec_cycle_code/fairseq/examples/speech_to_text/prep_mtedx_data.py
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple

import pandas as pd
import soundfile as sf
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    filter_manifest_df,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    load_df_from_tsv,
    save_df_to_tsv,
)
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from fairseq.data.audio.audio_utils import get_waveform, convert_waveform


log = logging.getLogger(__name__)


MANIFEST_COLUMNS = [
    "id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"
]
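# Note: each per-split TSV written by process() below has exactly these
# columns, one row per utterance; "audio" points into the packed ZIP and
# "n_frames" is the length reported by get_zip_manifest().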
class mTEDx(Dataset):
    """
    Create a Dataset for Multilingual TEDx.
    Each item is a tuple of the form: waveform, sample_rate, source utterance,
    target utterance, speaker_id, target language, utterance_id
    """

    SPLITS = ["train", "valid", "test"]
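    # Language pairs are "<src>-<tgt>": entries where source and target match
    # (e.g. "es-es") carry the transcriptions used for ASR, while the mixed
    # pairs carry the translations used for ST.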
    LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar",
                 "de-de", "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es",
                 "fr-pt", "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"]

    def __init__(self, root: str, lang: str, split: str) -> None:
        assert split in self.SPLITS and lang in self.LANGPAIRS
        _root = Path(root) / f"{lang}" / "data" / split
        wav_root, txt_root = _root / "wav", _root / "txt"
        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
        # Load audio segments
        try:
            import yaml
        except ImportError:
            print(
                "Please install PyYAML to load the Multilingual TEDx YAML files"
            )
        with open(txt_root / f"{split}.yaml") as f:
            segments = yaml.load(f, Loader=yaml.BaseLoader)
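        # Each YAML entry describes one segment of a talk, e.g. (illustrative
        # values):
        #   {"duration": "4.26", "offset": "105.97",
        #    "speaker_id": "spk.001", "wav": "talk_id.wav"}
        # Only these four keys are read from the YAML below.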
        # Load source and target utterances
        src, tgt = lang.split("-")
        for _lang in [src, tgt]:
            with open(txt_root / f"{split}.{_lang}") as f:
                utterances = [r.strip() for r in f]
            assert len(segments) == len(utterances)
            for i, u in enumerate(utterances):
                segments[i][_lang] = u
        # Gather info
        self.data = []
        for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
            wav_filename = wav_filename.replace(".wav", ".flac")
            wav_path = wav_root / wav_filename
            sample_rate = sf.info(wav_path.as_posix()).samplerate
            seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
            for i, segment in enumerate(seg_group):
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
                _id = f"{wav_path.stem}_{i}"
                self.data.append(
                    (
                        wav_path.as_posix(),
                        offset,
                        n_frames,
                        sample_rate,
                        segment[src],
                        segment[tgt],
                        segment["speaker_id"],
                        tgt,
                        _id,
                    )
                )

    def __getitem__(
        self, n: int
    ) -> Tuple[torch.Tensor, int, str, str, str, str, str]:
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, \
            utt_id = self.data[n]
        waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
        waveform = torch.from_numpy(waveform)
        return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id

    def __len__(self) -> int:
        return len(self.data)
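# Illustrative usage of the dataset above (the root path is a placeholder;
# it assumes the mTEDx corpus is unpacked as <lang-pair>/data/<split>/):
#
#     dataset = mTEDx("/path/to/mtedx", "es-en", "valid")
#     waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id = dataset[0]
#     # waveform is a 2-D torch.Tensor (channels x samples) at the file's
#     # native sample rate; src_utt/tgt_utt are the source/target strings.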


def process(args):
    root = Path(args.data_root).absolute()
    for lang in mTEDx.LANGPAIRS:
        cur_root = root / f"{lang}"
        if not cur_root.is_dir():
            print(f"{cur_root.as_posix()} does not exist. Skipped.")
            continue
        # Extract features
        audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80")
        audio_root.mkdir(exist_ok=True)
        for split in mTEDx.SPLITS:
            print(f"Fetching split {split}...")
            dataset = mTEDx(root.as_posix(), lang, split)
            if args.use_audio_input:
                print("Converting audios...")
                # Dataset items are 7-tuples; keep waveform, sample rate and
                # utterance id, ignore the text/speaker/language fields.
                for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
                    tgt_sample_rate = 16_000
                    _wavform, _ = convert_waveform(
                        waveform, sample_rate, to_mono=True,
                        to_sample_rate=tgt_sample_rate
                    )
                    sf.write(
                        (audio_root / f"{utt_id}.flac").as_posix(),
                        _wavform.T.numpy(),  # soundfile expects (frames, channels)
                        tgt_sample_rate
                    )
            else:
                print("Extracting log mel filter bank features...")
                for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
                    extract_fbank_features(
                        waveform, sample_rate, audio_root / f"{utt_id}.npy"
                    )
        # Pack features into ZIP
        zip_path = cur_root / f"{audio_root.name}.zip"
        print("ZIPing audios/features...")
        create_zip(audio_root, zip_path)
        print("Fetching ZIP manifest...")
        audio_paths, audio_lengths = get_zip_manifest(zip_path)
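        # audio_paths/audio_lengths map each utterance id to its location
        # inside the ZIP and to its length (samples for flac, frames for
        # fbank features); both are used to fill the TSV manifests below.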
        # Generate TSV manifest
        print("Generating manifest...")
        train_text = []
        for split in mTEDx.SPLITS:
            is_train_split = split.startswith("train")
            manifest = {c: [] for c in MANIFEST_COLUMNS}
            ds = mTEDx(args.data_root, lang, split)
            for _, _, src_utt, tgt_utt, spk_id, tgt_lang, utt_id in tqdm(ds):
                manifest["id"].append(utt_id)
                manifest["audio"].append(audio_paths[utt_id])
                manifest["n_frames"].append(audio_lengths[utt_id])
                manifest["tgt_text"].append(
                    src_utt if args.task == "asr" else tgt_utt
                )
                manifest["speaker"].append(spk_id)
                manifest["tgt_lang"].append(tgt_lang)
            if is_train_split:
                train_text.extend(manifest["tgt_text"])
            df = pd.DataFrame.from_dict(manifest)
            df = filter_manifest_df(df, is_train_split=is_train_split)
            save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
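        # filter_manifest_df drops degenerate rows (e.g. empty audio or text,
        # too-short clips, and overly long clips on train splits) before the
        # manifest is saved as <split>_<task>.tsv in the language-pair folder.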
        # Generate vocab
        v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
        spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
        with NamedTemporaryFile(mode="w") as f:
            for t in train_text:
                f.write(t + "\n")
            f.flush()  # make sure all text is on disk before it is read back
            gen_vocab(
                Path(f.name),
                cur_root / spm_filename_prefix,
                args.vocab_type,
                args.vocab_size,
            )
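        # gen_vocab trains a SentencePiece model on the collected training
        # text and writes <prefix>.model plus a matching fairseq dictionary
        # next to the manifests.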
        # Generate config YAML
        if args.use_audio_input:
            gen_config_yaml(
                cur_root,
                spm_filename=spm_filename_prefix + ".model",
                yaml_filename=f"config_{args.task}.yaml",
                specaugment_policy=None,
                extra={"use_audio_input": True}
            )
        else:
            gen_config_yaml(
                cur_root,
                spm_filename=spm_filename_prefix + ".model",
                yaml_filename=f"config_{args.task}.yaml",
                specaugment_policy="lb",
            )
        # Clean up
        shutil.rmtree(audio_root)


def process_joint(args):
    cur_root = Path(args.data_root)
    assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \
        "do not have downloaded data available for all languages"
    # Generate vocab
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
    with NamedTemporaryFile(mode="w") as f:
        for lang in mTEDx.LANGPAIRS:
            tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv"
            df = load_df_from_tsv(tsv_path)
            for t in df["tgt_text"]:
                f.write(t + "\n")
        f.flush()  # make sure all text is on disk before it is read back
        special_symbols = None
        if args.joint:
            # Add tgt_lang tags to dict
            special_symbols = list(
                {f'<lang:{lang.split("-")[1]}>' for lang in mTEDx.LANGPAIRS}
            )
        gen_vocab(
            Path(f.name),
            cur_root / spm_filename_prefix,
            args.vocab_type,
            args.vocab_size,
            special_symbols=special_symbols
        )
    # Generate config YAML
    gen_config_yaml(
        cur_root,
        spm_filename=spm_filename_prefix + ".model",
        yaml_filename=f"config_{args.task}.yaml",
        specaugment_policy="ld",
        prepend_tgt_lang_tag=(args.joint),
    )
    # Make symbolic links to manifests
    for lang in mTEDx.LANGPAIRS:
        for split in mTEDx.SPLITS:
            src_path = cur_root / f"{lang}" / f"{split}_{args.task}.tsv"
            dst_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
            if not dst_path.is_symlink():
                os.symlink(src_path, dst_path)
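    # The links give flat names at the data root, e.g. train_es-en_asr.tsv
    # pointing at es-en/train_asr.tsv, so joint training can read all
    # language pairs from one directory.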


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=8000, type=int)
    parser.add_argument("--task", type=str, choices=["asr", "st"])
    parser.add_argument("--joint", action="store_true", help="")
    parser.add_argument("--use-audio-input", action="store_true")
    args = parser.parse_args()

    if args.joint:
        process_joint(args)
    else:
        process(args)


if __name__ == "__main__":
    main()
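# Illustrative invocation (the data root and vocab size are placeholders,
# not fixed values):
#
#   python examples/speech_to_text/prep_mtedx_data.py \
#       --data-root /path/to/mtedx --task asr \
#       --vocab-type unigram --vocab-size 1000
#
# Add --joint to build a shared vocabulary and config across all language
# pairs, or --use-audio-input to pack 16 kHz mono flac files instead of
# log mel filter bank features.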