PageRenderTime 53ms CodeModel.GetById 23ms RepoModel.GetById 1ms app.codeStats 0ms

/nltk/translate/api.py

https://github.com/nltk/nltk
Python | 334 lines | 314 code | 3 blank | 17 comment | 0 complexity | 47b6c80a0b50f609643e6b780a96c0cd MD5 | raw file
  1. # Natural Language Toolkit: API for alignment and translation objects
  2. #
  3. # Copyright (C) 2001-2022 NLTK Project
  4. # Author: Will Zhang <wilzzha@gmail.com>
  5. # Guan Gui <ggui@student.unimelb.edu.au>
  6. # Steven Bird <stevenbird1@gmail.com>
  7. # Tah Wei Hoon <hoon.tw@gmail.com>
  8. # URL: <https://www.nltk.org/>
  9. # For license information, see LICENSE.TXT
  10. import subprocess
  11. from collections import namedtuple
  12. class AlignedSent:
  13. """
  14. Return an aligned sentence object, which encapsulates two sentences
  15. along with an ``Alignment`` between them.
  16. Typically used in machine translation to represent a sentence and
  17. its translation.
  18. >>> from nltk.translate import AlignedSent, Alignment
  19. >>> algnsent = AlignedSent(['klein', 'ist', 'das', 'Haus'],
  20. ... ['the', 'house', 'is', 'small'], Alignment.fromstring('0-3 1-2 2-0 3-1'))
  21. >>> algnsent.words
  22. ['klein', 'ist', 'das', 'Haus']
  23. >>> algnsent.mots
  24. ['the', 'house', 'is', 'small']
  25. >>> algnsent.alignment
  26. Alignment([(0, 3), (1, 2), (2, 0), (3, 1)])
  27. >>> from nltk.corpus import comtrans
  28. >>> print(comtrans.aligned_sents()[54])
  29. <AlignedSent: 'Weshalb also sollten...' -> 'So why should EU arm...'>
  30. >>> print(comtrans.aligned_sents()[54].alignment)
  31. 0-0 0-1 1-0 2-2 3-4 3-5 4-7 5-8 6-3 7-9 8-9 9-10 9-11 10-12 11-6 12-6 13-13
  32. :param words: Words in the target language sentence
  33. :type words: list(str)
  34. :param mots: Words in the source language sentence
  35. :type mots: list(str)
  36. :param alignment: Word-level alignments between ``words`` and ``mots``.
  37. Each alignment is represented as a 2-tuple (words_index, mots_index).
  38. :type alignment: Alignment
  39. """
  40. def __init__(self, words, mots, alignment=None):
  41. self._words = words
  42. self._mots = mots
  43. if alignment is None:
  44. self.alignment = Alignment([])
  45. else:
  46. assert type(alignment) is Alignment
  47. self.alignment = alignment
  48. @property
  49. def words(self):
  50. return self._words
  51. @property
  52. def mots(self):
  53. return self._mots
  54. def _get_alignment(self):
  55. return self._alignment
  56. def _set_alignment(self, alignment):
  57. _check_alignment(len(self.words), len(self.mots), alignment)
  58. self._alignment = alignment
  59. alignment = property(_get_alignment, _set_alignment)
  60. def __repr__(self):
  61. """
  62. Return a string representation for this ``AlignedSent``.
  63. :rtype: str
  64. """
  65. words = "[%s]" % (", ".join("'%s'" % w for w in self._words))
  66. mots = "[%s]" % (", ".join("'%s'" % w for w in self._mots))
  67. return f"AlignedSent({words}, {mots}, {self._alignment!r})"
  68. def _to_dot(self):
  69. """
  70. Dot representation of the aligned sentence
  71. """
  72. s = "graph align {\n"
  73. s += "node[shape=plaintext]\n"
  74. # Declare node
  75. for w in self._words:
  76. s += f'"{w}_source" [label="{w}"] \n'
  77. for w in self._mots:
  78. s += f'"{w}_target" [label="{w}"] \n'
  79. # Alignment
  80. for u, v in self._alignment:
  81. s += f'"{self._words[u]}_source" -- "{self._mots[v]}_target" \n'
  82. # Connect the source words
  83. for i in range(len(self._words) - 1):
  84. s += '"{}_source" -- "{}_source" [style=invis]\n'.format(
  85. self._words[i],
  86. self._words[i + 1],
  87. )
  88. # Connect the target words
  89. for i in range(len(self._mots) - 1):
  90. s += '"{}_target" -- "{}_target" [style=invis]\n'.format(
  91. self._mots[i],
  92. self._mots[i + 1],
  93. )
  94. # Put it in the same rank
  95. s += "{rank = same; %s}\n" % (" ".join('"%s_source"' % w for w in self._words))
  96. s += "{rank = same; %s}\n" % (" ".join('"%s_target"' % w for w in self._mots))
  97. s += "}"
  98. return s
  99. def _repr_svg_(self):
  100. """
  101. Ipython magic : show SVG representation of this ``AlignedSent``.
  102. """
  103. dot_string = self._to_dot().encode("utf8")
  104. output_format = "svg"
  105. try:
  106. process = subprocess.Popen(
  107. ["dot", "-T%s" % output_format],
  108. stdin=subprocess.PIPE,
  109. stdout=subprocess.PIPE,
  110. stderr=subprocess.PIPE,
  111. )
  112. except OSError as e:
  113. raise Exception("Cannot find the dot binary from Graphviz package") from e
  114. out, err = process.communicate(dot_string)
  115. return out.decode("utf8")
  116. def __str__(self):
  117. """
  118. Return a human-readable string representation for this ``AlignedSent``.
  119. :rtype: str
  120. """
  121. source = " ".join(self._words)[:20] + "..."
  122. target = " ".join(self._mots)[:20] + "..."
  123. return f"<AlignedSent: '{source}' -> '{target}'>"
  124. def invert(self):
  125. """
  126. Return the aligned sentence pair, reversing the directionality
  127. :rtype: AlignedSent
  128. """
  129. return AlignedSent(self._mots, self._words, self._alignment.invert())
  130. class Alignment(frozenset):
  131. """
  132. A storage class for representing alignment between two sequences, s1, s2.
  133. In general, an alignment is a set of tuples of the form (i, j, ...)
  134. representing an alignment between the i-th element of s1 and the
  135. j-th element of s2. Tuples are extensible (they might contain
  136. additional data, such as a boolean to indicate sure vs possible alignments).
  137. >>> from nltk.translate import Alignment
  138. >>> a = Alignment([(0, 0), (0, 1), (1, 2), (2, 2)])
  139. >>> a.invert()
  140. Alignment([(0, 0), (1, 0), (2, 1), (2, 2)])
  141. >>> print(a.invert())
  142. 0-0 1-0 2-1 2-2
  143. >>> a[0]
  144. [(0, 1), (0, 0)]
  145. >>> a.invert()[2]
  146. [(2, 1), (2, 2)]
  147. >>> b = Alignment([(0, 0), (0, 1)])
  148. >>> b.issubset(a)
  149. True
  150. >>> c = Alignment.fromstring('0-0 0-1')
  151. >>> b == c
  152. True
  153. """
  154. def __new__(cls, pairs):
  155. self = frozenset.__new__(cls, pairs)
  156. self._len = max(p[0] for p in self) if self != frozenset([]) else 0
  157. self._index = None
  158. return self
  159. @classmethod
  160. def fromstring(cls, s):
  161. """
  162. Read a giza-formatted string and return an Alignment object.
  163. >>> Alignment.fromstring('0-0 2-1 9-2 21-3 10-4 7-5')
  164. Alignment([(0, 0), (2, 1), (7, 5), (9, 2), (10, 4), (21, 3)])
  165. :type s: str
  166. :param s: the positional alignments in giza format
  167. :rtype: Alignment
  168. :return: An Alignment object corresponding to the string representation ``s``.
  169. """
  170. return Alignment([_giza2pair(a) for a in s.split()])
  171. def __getitem__(self, key):
  172. """
  173. Look up the alignments that map from a given index or slice.
  174. """
  175. if not self._index:
  176. self._build_index()
  177. return self._index.__getitem__(key)
  178. def invert(self):
  179. """
  180. Return an Alignment object, being the inverted mapping.
  181. """
  182. return Alignment(((p[1], p[0]) + p[2:]) for p in self)
  183. def range(self, positions=None):
  184. """
  185. Work out the range of the mapping from the given positions.
  186. If no positions are specified, compute the range of the entire mapping.
  187. """
  188. image = set()
  189. if not self._index:
  190. self._build_index()
  191. if not positions:
  192. positions = list(range(len(self._index)))
  193. for p in positions:
  194. image.update(f for _, f in self._index[p])
  195. return sorted(image)
  196. def __repr__(self):
  197. """
  198. Produce a Giza-formatted string representing the alignment.
  199. """
  200. return "Alignment(%r)" % sorted(self)
  201. def __str__(self):
  202. """
  203. Produce a Giza-formatted string representing the alignment.
  204. """
  205. return " ".join("%d-%d" % p[:2] for p in sorted(self))
  206. def _build_index(self):
  207. """
  208. Build a list self._index such that self._index[i] is a list
  209. of the alignments originating from word i.
  210. """
  211. self._index = [[] for _ in range(self._len + 1)]
  212. for p in self:
  213. self._index[p[0]].append(p)
  214. def _giza2pair(pair_string):
  215. i, j = pair_string.split("-")
  216. return int(i), int(j)
  217. def _naacl2pair(pair_string):
  218. i, j, p = pair_string.split("-")
  219. return int(i), int(j)
  220. def _check_alignment(num_words, num_mots, alignment):
  221. """
  222. Check whether the alignments are legal.
  223. :param num_words: the number of source language words
  224. :type num_words: int
  225. :param num_mots: the number of target language words
  226. :type num_mots: int
  227. :param alignment: alignment to be checked
  228. :type alignment: Alignment
  229. :raise IndexError: if alignment falls outside the sentence
  230. """
  231. assert type(alignment) is Alignment
  232. if not all(0 <= pair[0] < num_words for pair in alignment):
  233. raise IndexError("Alignment is outside boundary of words")
  234. if not all(pair[1] is None or 0 <= pair[1] < num_mots for pair in alignment):
  235. raise IndexError("Alignment is outside boundary of mots")
  236. PhraseTableEntry = namedtuple("PhraseTableEntry", ["trg_phrase", "log_prob"])
  237. class PhraseTable:
  238. """
  239. In-memory store of translations for a given phrase, and the log
  240. probability of the those translations
  241. """
  242. def __init__(self):
  243. self.src_phrases = dict()
  244. def translations_for(self, src_phrase):
  245. """
  246. Get the translations for a source language phrase
  247. :param src_phrase: Source language phrase of interest
  248. :type src_phrase: tuple(str)
  249. :return: A list of target language phrases that are translations
  250. of ``src_phrase``, ordered in decreasing order of
  251. likelihood. Each list element is a tuple of the target
  252. phrase and its log probability.
  253. :rtype: list(PhraseTableEntry)
  254. """
  255. return self.src_phrases[src_phrase]
  256. def add(self, src_phrase, trg_phrase, log_prob):
  257. """
  258. :type src_phrase: tuple(str)
  259. :type trg_phrase: tuple(str)
  260. :param log_prob: Log probability that given ``src_phrase``,
  261. ``trg_phrase`` is its translation
  262. :type log_prob: float
  263. """
  264. entry = PhraseTableEntry(trg_phrase=trg_phrase, log_prob=log_prob)
  265. if src_phrase not in self.src_phrases:
  266. self.src_phrases[src_phrase] = []
  267. self.src_phrases[src_phrase].append(entry)
  268. self.src_phrases[src_phrase].sort(key=lambda e: e.log_prob, reverse=True)
  269. def __contains__(self, src_phrase):
  270. return src_phrase in self.src_phrases