# Source: /wav2vec_cycle_code/fairseq/examples/m2m_100/process_data/dedup_data.py
#
# Mirrored from https://gitlab.com/lwd17/enhanced_examplar_ae (Python, 91 lines).

  1. import argparse
  2. from collections import namedtuple
  3. import os
  4. DATADIR = "/path/to/train_data"
  5. DEDUP_FROM_DIR = "/path/to/eval/data"
  6. OUTPUT_DIR = "/path/to/output/data"
  7. def main(args):
  8. languages = set()
  9. for language_directory in os.listdir(DATADIR):
  10. if "_" in language_directory:
  11. src, tgt = language_directory.split("_")
  12. languages.add(LanguagePair(src=src, tgt=tgt))
  13. data = existing_data()
  14. train_languages = sorted(languages)
  15. for language_pair in train_languages[args.start_index:args.start_index + args.size]:
  16. print(language_pair)
  17. dedup(language_pair, data)
  18. LanguagePair = namedtuple("LanguagePair", ["src", "tgt"])
  19. def existing_data():
  20. data = set()
  21. for file in os.listdir(DEDUP_FROM_DIR):
  22. with open(os.path.join(DEDUP_FROM_DIR, file)) as f:
  23. data |= set(f.readlines())
  24. return data
  25. def dedup(language_pair, data, verbose=True, output=True):
  26. train_filenames = LanguagePair(
  27. src=f"{DATADIR}/{language_pair.src}_{language_pair.tgt}/train.{language_pair.src}",
  28. tgt=f"{DATADIR}/{language_pair.src}_{language_pair.tgt}/train.{language_pair.tgt}",
  29. )
  30. output_filenames = LanguagePair(
  31. src=f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}.{language_pair.src}",
  32. tgt=f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}.{language_pair.tgt}"
  33. )
  34. # If output exists, skip this pair. It has already been done.
  35. if (os.path.exists(output_filenames.src) and
  36. os.path.exists(output_filenames.tgt)):
  37. if verbose:
  38. print(f"{language_pair.src}-{language_pair.tgt} already done.")
  39. return
  40. if verbose:
  41. print(f"{language_pair.src}-{language_pair.tgt} ready, will check dups.")
  42. # If there is no output, no need to actually do the loop.
  43. if not output:
  44. return
  45. if os.path.exists(train_filenames.src) and os.path.exists(train_filenames.tgt):
  46. with open(train_filenames.src) as f:
  47. train_source = f.readlines()
  48. with open(train_filenames.tgt) as f:
  49. train_target = f.readlines()
  50. # do dedup
  51. new_train_source = []
  52. new_train_target = []
  53. for i, train_line in enumerate(train_source):
  54. if train_line not in data and train_target[i] not in data:
  55. new_train_source.append(train_line)
  56. new_train_target.append(train_target[i])
  57. assert len(train_source) == len(train_target)
  58. assert len(new_train_source) == len(new_train_target)
  59. assert len(new_train_source) <= len(train_source)
  60. with open(output_filenames.src, "w") as o:
  61. for line in new_train_source:
  62. o.write(line)
  63. with open(output_filenames.tgt, "w") as o:
  64. for line in new_train_target:
  65. o.write(line)
  66. if __name__ == '__main__':
  67. parser = argparse.ArgumentParser()
  68. parser.add_argument("-s", "--start-index", required=True, type=int)
  69. parser.add_argument("-n", "--size", required=True, type=int)
  70. main(parser.parse_args())