/wav2vec_cycle_code/FragmentVC/src/fairseq/examples/byte_level_bpe/get_bitext.py
Python | 254 lines | 243 code | 7 blank | 4 comment | 3 complexity | 7de6f69fa9e7f34ac4f3e61b63b211bd MD5 | raw file
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import argparse
- import os
- import os.path as op
- from collections import namedtuple
- from multiprocessing import cpu_count
- from typing import List, Optional
- import sentencepiece as sp
- from fairseq.data.encoders.byte_bpe import ByteBPE
- from fairseq.data.encoders.byte_utils import byte_encode
- from fairseq.data.encoders.bytes import Bytes
- from fairseq.data.encoders.characters import Characters
- from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
- from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
# Dataset partitions that every preprocessing stage below iterates over.
SPLITS = [
    "train",
    "valid",
    "test",
]
- def _convert_xml(in_path: str, out_path: str):
- with open(in_path) as f, open(out_path, "w") as f_o:
- for s in f:
- ss = s.strip()
- if not ss.startswith("<seg"):
- continue
- ss = ss.replace("</seg>", "").split('">')
- assert len(ss) == 2
- f_o.write(ss[1].strip() + "\n")
- def _convert_train(in_path: str, out_path: str):
- with open(in_path) as f, open(out_path, "w") as f_o:
- for s in f:
- ss = s.strip()
- if ss.startswith("<"):
- continue
- f_o.write(ss.strip() + "\n")
def _get_bytes(in_path: str, out_path: str):
    """Write the ``Bytes.encode`` form of every line of a text file.

    Args:
        in_path: Path of the text file to read.
        out_path: Path of the file to write (overwritten if it exists); it
            receives one encoded line per input line.
    """
    # Explicit UTF-8 so the text handed to the encoder does not depend on the
    # platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for line in f:
            f_o.write(Bytes.encode(line.strip()) + "\n")
def _get_chars(in_path: str, out_path: str):
    """Write the ``Characters.encode`` form of every line of a text file.

    Args:
        in_path: Path of the text file to read.
        out_path: Path of the file to write (overwritten if it exists); it
            receives one encoded line per input line.
    """
    # Explicit UTF-8 so the text handed to the encoder does not depend on the
    # platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for line in f:
            f_o.write(Characters.encode(line.strip()) + "\n")
def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
    """Run the Moses tokenizer over every line of a text file.

    A ``MosesTokenizer`` is built from an argument record whose source and
    target languages are ``src`` and ``tgt`` and whose ``moses_no_dash_splits``
    and ``moses_no_escape`` switches are both ``False``. Each stripped input
    line is passed through its ``encode`` method and written out on its own
    line.

    Args:
        in_path: Path of the text file to read.
        out_path: Path of the tokenized file to write (overwritten).
        src: Value for the tokenizer's ``moses_source_lang`` setting.
        tgt: Value for the tokenizer's ``moses_target_lang`` setting.
    """
    Args = namedtuple(
        "Args",
        [
            "moses_source_lang",
            "moses_target_lang",
            "moses_no_dash_splits",
            "moses_no_escape",
        ],
    )
    args = Args(
        moses_source_lang=src,
        moses_target_lang=tgt,
        moses_no_dash_splits=False,
        moses_no_escape=False,
    )
    pretokenizer = MosesTokenizer(args)
    # Explicit UTF-8 so the text handed to the tokenizer does not depend on
    # the platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for line in f:
            f_o.write(pretokenizer.encode(line.strip()) + "\n")
def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
    """Byte-encode the source and target files into one combined file.

    Reads ``<in_path_prefix>.<src>`` and then ``<in_path_prefix>.<tgt>``,
    applies ``byte_encode`` to every stripped line, and writes all results to
    ``out_path`` in that order, one per line.

    Args:
        in_path_prefix: Common path prefix of the two input files.
        src: Language suffix of the first input file.
        tgt: Language suffix of the second input file.
        out_path: Path of the combined file to write (overwritten).
    """
    # Explicit UTF-8 on both sides so neither the text being byte-encoded nor
    # the file written for later training depends on the platform locale.
    with open(out_path, "w", encoding="utf-8") as f_o:
        for lang in (src, tgt):
            with open(f"{in_path_prefix}.{lang}", encoding="utf-8") as f:
                for line in f:
                    f_o.write(byte_encode(line.strip()) + "\n")
def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
    """Train a SentencePiece model of type ``bpe`` on ``in_path``.

    The trainer is invoked with full character coverage, the ``identity``
    normalization rule and one thread per CPU reported by ``cpu_count()``.

    Args:
        in_path: Path of the training text.
        model_prefix: Value passed as the trainer's ``--model_prefix`` flag.
        vocab_size: Value passed as the trainer's ``--vocab_size`` flag.
    """
    # Flag name -> value, in the order the flags appear on the command string.
    flags = {
        "input": in_path,
        "model_prefix": model_prefix,
        "model_type": "bpe",
        "vocab_size": vocab_size,
        "character_coverage": "1.0",
        "normalization_rule_name": "identity",
        "num_threads": cpu_count(),
    }
    # NOTE(review): the flags are joined with single spaces, so a path that
    # contains whitespace would presumably be mis-parsed by the trainer --
    # confirm before using such paths.
    command = " ".join(f"--{name}={value}" for name, value in flags.items())
    sp.SentencePieceTrainer.Train(command)
def _apply_bbpe(model_path: str, in_path: str, out_path: str):
    """Encode every line of a text file with a ``ByteBPE`` tokenizer.

    Args:
        model_path: Model path handed to ``ByteBPE`` as its
            ``sentencepiece_model_path`` setting.
        in_path: Path of the text file to read.
        out_path: Path of the encoded file to write (overwritten).
    """
    Args = namedtuple("Args", ["sentencepiece_model_path"])
    tokenizer = ByteBPE(Args(sentencepiece_model_path=model_path))
    # Explicit UTF-8 so the text handed to the tokenizer does not depend on
    # the platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for line in f:
            f_o.write(tokenizer.encode(line.strip()) + "\n")
def _apply_bpe(model_path: str, in_path: str, out_path: str):
    """Encode every line of a text file with a ``SentencepieceBPE`` tokenizer.

    Args:
        model_path: Model path handed to ``SentencepieceBPE`` as its
            ``sentencepiece_model`` setting.
        in_path: Path of the text file to read.
        out_path: Path of the encoded file to write (overwritten).
    """
    Args = namedtuple("Args", ["sentencepiece_model"])
    tokenizer = SentencepieceBPE(Args(sentencepiece_model=model_path))
    # Explicit UTF-8 so the text handed to the tokenizer does not depend on
    # the platform locale.
    with open(in_path, encoding="utf-8") as f, open(
        out_path, "w", encoding="utf-8"
    ) as f_o:
        for line in f:
            f_o.write(tokenizer.encode(line.strip()) + "\n")
- def _concat_files(in_paths: List[str], out_path: str):
- with open(out_path, "w") as f_o:
- for p in in_paths:
- with open(p) as f:
- for r in f:
- f_o.write(r)
def preprocess_iwslt17(
    root: str,
    src: str,
    tgt: str,
    bpe_size: Optional[int],
    need_chars: bool,
    bbpe_size: Optional[int],
    need_bytes: bool,
):
    """Build the tokenized IWSLT17 bitext variants under ``root``.

    All outputs are written directly into ``root``:

    1. Raw files from ``<root>/<src>-<tgt>`` are converted to plain text as
       ``<split>.<lang>`` (train from the ``train.tags`` file, valid from
       ``dev2010``, test from ``tst2015``).
    2. Every split is pre-tokenized into ``<split>.moses.<lang>``.
    3. Depending on the arguments, the pre-tokenized files are further encoded
       into ``.bpe<N>``, ``.bytes``, ``.chars`` and/or ``.bbpe<N>`` variants.

    Args:
        root: Directory holding the ``<src>-<tgt>`` input folder; also the
            output directory.
        src: Source language code.
        tgt: Target language code.
        bpe_size: BPE vocabulary size, or ``None`` to skip the BPE variant.
        need_chars: Whether to produce the characters variant.
        bbpe_size: Byte-level BPE vocabulary size, or ``None`` to skip it.
        need_bytes: Whether to produce the bytes variant.
    """
    langs = [src, tgt]

    # extract bitext
    in_root = op.join(root, f"{src}-{tgt}")
    for lang in langs:
        _convert_train(
            op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
            op.join(root, f"train.{lang}"),
        )
        _convert_xml(
            op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
            op.join(root, f"valid.{lang}"),
        )
        _convert_xml(
            op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
            op.join(root, f"test.{lang}"),
        )

    # pre-tokenize
    # NOTE(review): both language files are tokenized with the same (src, tgt)
    # arguments; whether the target file should use a tokenizer keyed on its
    # own language depends on MosesTokenizer -- confirm.
    for lang in langs:
        for split in SPLITS:
            pretokenize(
                op.join(root, f"{split}.{lang}"),
                op.join(root, f"{split}.moses.{lang}"),
                src,
                tgt,
            )

    # tokenize with BPE vocabulary
    if bpe_size is not None:
        # learn vocabulary on the requested language pair; these names were
        # previously hard-coded to "fr" and "en" regardless of src/tgt
        concated_train_path = op.join(root, "train.all")
        _concat_files(
            [op.join(root, f"train.moses.{lang}") for lang in langs],
            concated_train_path,
        )
        bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
        _get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
        # the concatenated corpus is only needed for training
        os.remove(concated_train_path)
        # apply
        for lang in langs:
            for split in SPLITS:
                _apply_bpe(
                    bpe_model_prefix + ".model",
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
                )

    # tokenize with bytes vocabulary
    if need_bytes:
        for lang in langs:
            for split in SPLITS:
                _get_bytes(
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bytes.{lang}"),
                )

    # tokenize with characters vocabulary
    if need_chars:
        for lang in langs:
            for split in SPLITS:
                _get_chars(
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.chars.{lang}"),
                )

    # tokenize with byte-level BPE vocabulary
    if bbpe_size is not None:
        # learn vocabulary
        bchar_path = op.join(root, "train.bchar")
        _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
        bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
        _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
        # the byte-encoded corpus is only needed for training
        os.remove(bchar_path)
        # apply
        for lang in langs:
            for split in SPLITS:
                _apply_bbpe(
                    bbpe_model_prefix + ".model",
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
                )
def main():
    """Parse the command line and preprocess the fr-en IWSLT17 bitext.

    Options:
        --root: Data directory (default ``data``).
        --bpe-vocab: BPE vocabulary size; omitted means disabled.
        --bbpe-vocab: Byte-level BPE vocabulary size; omitted means disabled.
        --byte-vocab: Flag enabling the bytes vocabulary variant.
        --char-vocab: Flag enabling the characters vocabulary variant.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="data")
    parser.add_argument(
        "--bpe-vocab",
        default=None,
        type=int,
        # The adjacent literals are concatenated verbatim, so the separating
        # space has to be part of the first one.
        help="Generate tokenized bitext with BPE of size K. "
        "Default to None (disabled).",
    )
    parser.add_argument(
        "--bbpe-vocab",
        default=None,
        type=int,
        help="Generate tokenized bitext with BBPE of size K. "
        "Default to None (disabled).",
    )
    parser.add_argument(
        "--byte-vocab",
        action="store_true",
        help="Generate tokenized bitext with bytes vocabulary",
    )
    parser.add_argument(
        "--char-vocab",
        action="store_true",
        help="Generate tokenized bitext with chars vocabulary",
    )
    args = parser.parse_args()

    # The language pair is fixed to French -> English for this script.
    preprocess_iwslt17(
        args.root,
        "fr",
        "en",
        args.bpe_vocab,
        args.char_vocab,
        args.bbpe_vocab,
        args.byte_vocab,
    )


if __name__ == "__main__":
    main()