
/wav2vec_cycle_code/fairseq/examples/criss/mining/mine.py

https://gitlab.com/lwd17/enhanced_examplar_ae
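The file below is fairseq's CRISS bitext-mining script (examples/criss/mining/mine.py). It loads pre-computed sentence embeddings stored as raw float32 shards named all_avg_pool.<lang>* (with matching sentences.<lang>* text files), runs a sharded GPU k-nearest-neighbour search with Faiss in both directions, scores candidate pairs with a ratio margin, and writes the mined pairs plus a train/valid split to the output directory.

A minimal invocation sketch, assuming the embedding and sentence shards already exist; the directories and the si/en language codes are placeholders, and the numeric values simply repeat the argparse defaults defined below:

    python examples/criss/mining/mine.py \
        --src-lang si --tgt-lang en \
        --src-dir /path/to/si_embeddings --tgt-dir /path/to/en_embeddings \
        --dim 1024 --neighborhood 4 --threshold 1.06 \
        --min-count 50000 --valid-size 2000 \
        --output /path/to/mined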
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob
from subprocess import check_call

try:
    import faiss

    has_faiss = True
except ImportError:
    has_faiss = False
import numpy as np

GB = 1024 * 1024 * 1024

def call(cmd):
    # Echo and run a shell command, raising on non-zero exit.
    print(cmd)
    check_call(cmd, shell=True)


def get_batches(directory, lang, prefix="all_avg_pool"):
    # Collect the embedding shards for `lang` together with their matching sentence files.
    print(f"Finding in {directory}/{prefix}.{lang}*")
    files = glob.glob(f"{directory}/{prefix}.{lang}*")
    emb_files = []
    txt_files = []
    for emb_fi in files:
        emb_files.append(emb_fi)
        txt_fi = emb_fi.replace(prefix, "sentences")
        txt_files.append(txt_fi)
    return emb_files, txt_files


def load_batch(emb_file, dim):
    # Embeddings are stored as raw float32; reshape to (num_rows, dim) and
    # L2-normalize so that inner product equals cosine similarity.
    embeddings = np.fromfile(emb_file, dtype=np.float32)
    num_rows = int(embeddings.shape[0] / dim)
    embeddings = embeddings.reshape((num_rows, dim))
    faiss.normalize_L2(embeddings)
    return embeddings

def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
    # Sharded k-NN search: for every x shard, build a GPU IndexFlatIP over each
    # y shard in turn, search it, then merge the per-shard results and keep the
    # overall top-k neighbors (offsets are added so indices are global).
    if not has_faiss:
        raise ImportError("Please install Faiss")
    sims = []
    inds = []
    xfrom = 0
    xto = 0
    for x_batch_f in x_batches_f:
        yfrom = 0
        yto = 0
        x_batch = load_batch(x_batch_f, dim)
        xto = xfrom + x_batch.shape[0]
        bsims, binds = [], []
        for y_batch_f in y_batches_f:
            y_batch = load_batch(y_batch_f, dim)
            neighbor_size = min(k, y_batch.shape[0])
            yto = yfrom + y_batch.shape[0]
            print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
            idx = faiss.IndexFlatIP(dim)
            idx = faiss.index_cpu_to_all_gpus(idx)
            idx.add(y_batch)
            bsim, bind = idx.search(x_batch, neighbor_size)
            bsims.append(bsim)
            binds.append(bind + yfrom)
            yfrom += y_batch.shape[0]
            del idx
            del y_batch
        bsims = np.concatenate(bsims, axis=1)
        binds = np.concatenate(binds, axis=1)
        aux = np.argsort(-bsims, axis=1)
        sim_batch = np.zeros((x_batch.shape[0], k), dtype=np.float32)
        ind_batch = np.zeros((x_batch.shape[0], k), dtype=np.int64)
        for i in range(x_batch.shape[0]):
            for j in range(k):
                sim_batch[i, j] = bsims[i, aux[i, j]]
                ind_batch[i, j] = binds[i, aux[i, j]]
        sims.append(sim_batch)
        inds.append(ind_batch)
        xfrom += x_batch.shape[0]
        del x_batch
    sim = np.concatenate(sims, axis=0)
    ind = np.concatenate(inds, axis=0)
    return sim, ind

def score(sim, fwd_mean, bwd_mean, margin):
    # Margin scoring: compare the pair similarity against the average of the
    # forward and backward k-NN mean similarities.
    return margin(sim, (fwd_mean + bwd_mean) / 2)


def score_candidates(
    sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
):
    print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
    scores = np.zeros(candidate_inds.shape)
    for i in range(scores.shape[0]):
        for j in range(scores.shape[1]):
            k = int(candidate_inds[i, j])
            scores[i, j] = score(sim_mat[i, j], fwd_mean[i], bwd_mean[k], margin)
    return scores


def load_text(files):
    all_sentences = []
    for fi in files:
        with open(fi) as sentence_fi:
            for line in sentence_fi:
                all_sentences.append(line.strip())
    print(f"Read {len(all_sentences)} sentences")
    return all_sentences

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Mine bitext")
    parser.add_argument("--src-lang", help="Source language")
    parser.add_argument("--tgt-lang", help="Target language")
    parser.add_argument(
        "--dict-path", help="Path to dictionary file", default="dict.txt"
    )
    parser.add_argument(
        "--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
    )
    parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
    parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
    parser.add_argument("--src-dir", help="Source directory")
    parser.add_argument("--tgt-dir", help="Target directory")
    parser.add_argument("--output", help="Output path")
    parser.add_argument(
        "--neighborhood", type=int, default=4, help="Number of nearest neighbors (k)"
    )
    parser.add_argument(
        "--threshold", type=float, default=1.06, help="Threshold on mined bitext"
    )
    parser.add_argument(
        "--valid-size",
        type=int,
        default=2000,
        help="Number of sentences used for validation set",
    )
    parser.add_argument(
        "--min-count",
        type=int,
        default=50000,
        help="Min num sentences used for each language",
    )
    args = parser.parse_args()

    x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
    y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
    # Ratio margin: candidate similarity divided by the average neighborhood similarity.
    margin = lambda a, b: a / b
    y2x_sim, y2x_ind = knnGPU_sharded(
        y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
    )
    x2y_sim, x2y_ind = knnGPU_sharded(
        x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
    )
    x2y_mean = x2y_sim.mean(axis=1)
    y2x_mean = y2x_sim.mean(axis=1)
    fwd_scores = score_candidates(x2y_sim, x2y_ind, x2y_mean, y2x_mean, margin)
    bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
    # Best candidate in each direction, then the union of forward and backward pairs.
    fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
    bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
    indices = np.stack(
        (
            np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
            np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
        ),
        axis=1,
    )
    scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))

    x_sentences = load_text(x_sents_f)
    y_sentences = load_text(y_sents_f)
    threshold = args.threshold
    min_count = args.min_count
    seen_src, seen_trg = set(), set()
    directory = args.output
    call(f"mkdir -p {directory}")
    src_out = open(
        f"{directory}/all.{args.src_lang}",
        mode="w",
        encoding="utf-8",
        errors="surrogateescape",
    )
    tgt_out = open(
        f"{directory}/all.{args.tgt_lang}",
        mode="w",
        encoding="utf-8",
        errors="surrogateescape",
    )
    scores_out = open(
        f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
    )
    count = 0
    # Greedily accept pairs in descending score order, keeping each source and
    # target sentence at most once; below-threshold pairs are accepted only while
    # the mined set is still smaller than --min-count.
    for i in np.argsort(-scores):
        src_ind, trg_ind = indices[i]
        if src_ind not in seen_src and trg_ind not in seen_trg:
            seen_src.add(src_ind)
            seen_trg.add(trg_ind)
            if scores[i] > threshold or count < min_count:
                if x_sentences[src_ind]:
                    print(scores[i], file=scores_out)
                    print(x_sentences[src_ind], file=src_out)
                    print(y_sentences[trg_ind], file=tgt_out)
                    count += 1
                else:
                    print(f"Ignoring sentence: {x_sentences[src_ind]}")
    src_out.close()
    tgt_out.close()
    scores_out.close()
    print(f"Found {count} pairs for threshold={threshold}")

    with open(f"{directory}/all.{args.src_lang}") as all_s, open(
        f"{directory}/all.{args.tgt_lang}"
    ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
        f"{directory}/valid.{args.tgt_lang}", "w"
    ) as valid_t, open(
        f"{directory}/train.{args.src_lang}", "w"
    ) as train_s, open(
        f"{directory}/train.{args.tgt_lang}", "w"
    ) as train_t:
        count = 0
        for s_line, t_line in zip(all_s, all_t):
            s_line = s_line.split("\t")[1]
            t_line = t_line.split("\t")[1]
            if count >= args.valid_size:
                train_s.write(s_line)
                train_t.write(t_line)
            else:
                valid_s.write(s_line)
                valid_t.write(t_line)
            count += 1
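
For reference, the scoring above is a plain ratio margin: a candidate pair's cosine similarity divided by the average of its forward and backward k-NN mean similarities (margin = lambda a, b: a / b, applied inside score()). A tiny standalone sketch with made-up numbers, kept deliberately separate from the script:

    # Illustrative values only; they are not taken from any real run.
    sim = 0.82        # cosine similarity of a candidate pair (x, y)
    fwd_mean = 0.70   # mean similarity of x to its k nearest y neighbors
    bwd_mean = 0.66   # mean similarity of y to its k nearest x neighbors
    margin = lambda a, b: a / b
    print(margin(sim, (fwd_mean + bwd_mean) / 2))  # ~1.206, above the default 1.06 threshold, so the pair would be kept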