
/wav2vec_cycle_code/FragmentVC/src/fairseq/examples/byte_level_bpe/get_bitext.py

https://gitlab.com/lwd17/enhanced_examplar_ae
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import os.path as op
from collections import namedtuple
from multiprocessing import cpu_count
from typing import List, Optional

import sentencepiece as sp

from fairseq.data.encoders.byte_bpe import ByteBPE
from fairseq.data.encoders.byte_utils import byte_encode
from fairseq.data.encoders.bytes import Bytes
from fairseq.data.encoders.characters import Characters
from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE


SPLITS = ["train", "valid", "test"]
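

# Pull sentences out of the IWSLT XML dev/test files: only <seg ...>...</seg>
# lines carry text; all other markup lines are skipped.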
def _convert_xml(in_path: str, out_path: str):
    with open(in_path) as f, open(out_path, "w") as f_o:
        for s in f:
            ss = s.strip()
            if not ss.startswith("<seg"):
                continue
            ss = ss.replace("</seg>", "").split('">')
            assert len(ss) == 2
            f_o.write(ss[1].strip() + "\n")
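

# Extract sentences from the raw training files, which interleave text with
# <...> metadata lines that are skipped here.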
def _convert_train(in_path: str, out_path: str):
    with open(in_path) as f, open(out_path, "w") as f_o:
        for s in f:
            ss = s.strip()
            if ss.startswith("<"):
                continue
            f_o.write(ss.strip() + "\n")
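

# Re-tokenize each line as a byte sequence using fairseq's Bytes encoder.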
def _get_bytes(in_path: str, out_path: str):
    with open(in_path) as f, open(out_path, "w") as f_o:
        for s in f:
            f_o.write(Bytes.encode(s.strip()) + "\n")
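

# Re-tokenize each line as a character sequence using fairseq's Characters
# encoder.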
def _get_chars(in_path: str, out_path: str):
    with open(in_path) as f, open(out_path, "w") as f_o:
        for s in f:
            f_o.write(Characters.encode(s.strip()) + "\n")
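

# Pre-tokenize with Moses. fairseq's MosesTokenizer reads its options from a
# parsed-args namespace, so a namedtuple stands in for one here.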
def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
    Args = namedtuple(
        "Args",
        [
            "moses_source_lang",
            "moses_target_lang",
            "moses_no_dash_splits",
            "moses_no_escape",
        ],
    )
    args = Args(
        moses_source_lang=src,
        moses_target_lang=tgt,
        moses_no_dash_splits=False,
        moses_no_escape=False,
    )
    pretokenizer = MosesTokenizer(args)
    with open(in_path) as f, open(out_path, "w") as f_o:
        for s in f:
            f_o.write(pretokenizer.encode(s.strip()) + "\n")
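

# Byte-encode the source and target sides of a pre-tokenized bitext into a
# single file; the byte-level BPE vocabulary is learned from this output.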
def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
    with open(out_path, "w") as f_o:
        for lang in [src, tgt]:
            with open(f"{in_path_prefix}.{lang}") as f:
                for s in f:
                    f_o.write(byte_encode(s.strip()) + "\n")
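

# Train a SentencePiece BPE model of the given vocabulary size. Normalization
# is disabled (identity rule) so the learned units match the input exactly.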
def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
    arguments = [
        f"--input={in_path}",
        f"--model_prefix={model_prefix}",
        "--model_type=bpe",
        f"--vocab_size={vocab_size}",
        "--character_coverage=1.0",
        "--normalization_rule_name=identity",
        f"--num_threads={cpu_count()}",
    ]
    sp.SentencePieceTrainer.Train(" ".join(arguments))
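

# Segment a file with a trained byte-level BPE (BBPE) model.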
def _apply_bbpe(model_path: str, in_path: str, out_path: str):
    Args = namedtuple("Args", ["sentencepiece_model_path"])
    args = Args(sentencepiece_model_path=model_path)
    tokenizer = ByteBPE(args)
    with open(in_path) as f, open(out_path, "w") as f_o:
        for s in f:
            f_o.write(tokenizer.encode(s.strip()) + "\n")
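

# Segment a file with a trained SentencePiece BPE model.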
def _apply_bpe(model_path: str, in_path: str, out_path: str):
    Args = namedtuple("Args", ["sentencepiece_model"])
    args = Args(sentencepiece_model=model_path)
    tokenizer = SentencepieceBPE(args)
    with open(in_path) as f, open(out_path, "w") as f_o:
        for s in f:
            f_o.write(tokenizer.encode(s.strip()) + "\n")
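

# Concatenate several text files into one.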
def _concat_files(in_paths: List[str], out_path: str):
    with open(out_path, "w") as f_o:
        for p in in_paths:
            with open(p) as f:
                for r in f:
                    f_o.write(r)
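

# Full IWSLT17 preprocessing pipeline: extract the bitext from the raw
# XML/tagged release, Moses-pretokenize it, then optionally emit BPE-, byte-,
# character-, and byte-level-BPE-tokenized versions of every split.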
def preprocess_iwslt17(
    root: str,
    src: str,
    tgt: str,
    bpe_size: Optional[int],
    need_chars: bool,
    bbpe_size: Optional[int],
    need_bytes: bool,
):
    # extract bitext
    in_root = op.join(root, f"{src}-{tgt}")
    for lang in [src, tgt]:
        _convert_train(
            op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
            op.join(root, f"train.{lang}"),
        )
        _convert_xml(
            op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
            op.join(root, f"valid.{lang}"),
        )
        _convert_xml(
            op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
            op.join(root, f"test.{lang}"),
        )
    # pre-tokenize
    for lang in [src, tgt]:
        for split in SPLITS:
            pretokenize(
                op.join(root, f"{split}.{lang}"),
                op.join(root, f"{split}.moses.{lang}"),
                src,
                tgt,
            )
    # tokenize with BPE vocabulary
    if bpe_size is not None:
        # learn vocabulary
        concatenated_train_path = op.join(root, "train.all")
        _concat_files(
            [
                op.join(root, f"train.moses.{src}"),
                op.join(root, f"train.moses.{tgt}"),
            ],
            concatenated_train_path,
        )
        bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
        _get_bpe(concatenated_train_path, bpe_model_prefix, bpe_size)
        os.remove(concatenated_train_path)
        # apply
        for lang in [src, tgt]:
            for split in SPLITS:
                _apply_bpe(
                    bpe_model_prefix + ".model",
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
                )
    # tokenize with bytes vocabulary
    if need_bytes:
        for lang in [src, tgt]:
            for split in SPLITS:
                _get_bytes(
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bytes.{lang}"),
                )
    # tokenize with characters vocabulary
    if need_chars:
        for lang in [src, tgt]:
            for split in SPLITS:
                _get_chars(
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.chars.{lang}"),
                )
    # tokenize with byte-level BPE vocabulary
    if bbpe_size is not None:
        # learn vocabulary
        bchar_path = op.join(root, "train.bchar")
        _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
        bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
        _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
        os.remove(bchar_path)
        # apply
        for lang in [src, tgt]:
            for split in SPLITS:
                _apply_bbpe(
                    bbpe_model_prefix + ".model",
                    op.join(root, f"{split}.moses.{lang}"),
                    op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
                )
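

# CLI entry point. The language pair is hard-coded to IWSLT17 fr-en below.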
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="data")
    parser.add_argument(
        "--bpe-vocab",
        default=None,
        type=int,
        help="Generate tokenized bitext with BPE of size K. "
        "Defaults to None (disabled).",
    )
    parser.add_argument(
        "--bbpe-vocab",
        default=None,
        type=int,
        help="Generate tokenized bitext with BBPE of size K. "
        "Defaults to None (disabled).",
    )
    parser.add_argument(
        "--byte-vocab",
        action="store_true",
        help="Generate tokenized bitext with bytes vocabulary",
    )
    parser.add_argument(
        "--char-vocab",
        action="store_true",
        help="Generate tokenized bitext with chars vocabulary",
    )
    args = parser.parse_args()
    preprocess_iwslt17(
        args.root,
        "fr",
        "en",
        args.bpe_vocab,
        args.char_vocab,
        args.bbpe_vocab,
        args.byte_vocab,
    )


if __name__ == "__main__":
    main()
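

# Example usage (a sketch; it assumes the raw IWSLT17 fr-en release has been
# downloaded and unpacked under data/fr-en/, the layout preprocess_iwslt17
# expects):
#
#   python get_bitext.py --root data \
#       --bpe-vocab 16384 --bbpe-vocab 2048 --byte-vocab --char-vocab
#
# The vocabulary sizes are illustrative, not values prescribed by this script.
# Outputs are written next to the extracted bitext, e.g.
# data/train.moses.bpe16384.fr.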