
/wav2vec_cycle_code/fairseq/examples/simultaneous_translation/models/transformer_monotonic_attention.py

https://gitlab.com/lwd17/enhanced_examplar_ae
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
    TransformerMonotonicDecoderLayer,
    TransformerMonotonicEncoderLayer,
)
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    TransformerModel,
    TransformerEncoder,
    TransformerDecoder,
    base_architecture,
    transformer_iwslt_de_en,
    transformer_vaswani_wmt_en_de_big,
    transformer_vaswani_wmt_en_fr_big,
    tiny_architecture,
)
from torch import Tensor

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

READ_ACTION = 0
WRITE_ACTION = 1

TransformerMonotonicDecoderOut = NamedTuple(
    "TransformerMonotonicDecoderOut",
    [
        ("action", int),
        ("p_choose", Optional[Tensor]),
        ("attn_list", Optional[List[Optional[Dict[str, Tensor]]]]),
        ("encoder_out", Optional[Dict[str, List[Tensor]]]),
        ("encoder_padding_mask", Optional[Tensor]),
    ],
)
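
# Illustrative sketch, not part of the original module: every call to
# TransformerMonotonicDecoder.extract_features returns one of these tuples,
# and its `action` field tells the simultaneous-decoding agent what to do
# next. Hypothetical examples of the two cases (p, attns, enc_out, enc_mask
# are placeholder values):
#
#     # the model wants more source context -> READ (action == READ_ACTION)
#     TransformerMonotonicDecoderOut(
#         action=READ_ACTION, p_choose=p, attn_list=None,
#         encoder_out=None, encoder_padding_mask=None,
#     )
#     # the model is ready to emit a target token -> WRITE (action == WRITE_ACTION)
#     TransformerMonotonicDecoderOut(
#         action=WRITE_ACTION, p_choose=p, attn_list=attns,
#         encoder_out=enc_out, encoder_padding_mask=enc_mask,
#     )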

@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)


@register_model("transformer_monotonic")
class TransformerModelSimulTrans(TransformerModel):
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)


class TransformerMonotonicEncoder(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        self.dictionary = dictionary
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerMonotonicEncoderLayer(args)
                for i in range(args.encoder_layers)
            ]
        )


class TransformerMonotonicDecoder(TransformerDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerMonotonicDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)

        self.dictionary = dictionary
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerMonotonicDecoderLayer(args)
                for _ in range(args.decoder_layers)
            ]
        )
        self.policy_criterion = getattr(args, "policy_criterion", "any")
        self.num_updates = None

    def set_num_updates(self, num_updates):
        self.num_updates = num_updates

    def pre_attention(
        self,
        prev_output_tokens,
        encoder_out_dict: Dict[str, List[Tensor]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ):
        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = encoder_out_dict["encoder_out"][0]

        if "encoder_padding_mask" in encoder_out_dict:
            encoder_padding_mask = (
                encoder_out_dict["encoder_padding_mask"][0]
                if encoder_out_dict["encoder_padding_mask"]
                and len(encoder_out_dict["encoder_padding_mask"]) > 0
                else None
            )
        else:
            encoder_padding_mask = None

        return x, encoder_out, encoder_padding_mask

    def post_attention(self, x):
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x

    def clean_cache(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        end_id: Optional[int] = None,
    ):
  140. """
  141. Clean cache in the monotonic layers.
  142. The cache is generated because of a forward pass of decoder has run but no prediction,
  143. so that the self attention key value in decoder is written in the incremental state.
  144. end_id is the last idx of the layers
  145. """
        if end_id is None:
            end_id = len(self.layers)

        for index, layer in enumerate(self.layers):
            if index < end_id:
                layer.prune_incremental_state(incremental_state)
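
    # clean_cache is invoked from extract_features below: when the policy
    # decides to READ more source tokens instead of writing a target token,
    # the self-attention states cached by the aborted step (layers 0 .. i)
    # are pruned so they are not duplicated when the step is re-run.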

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,  # unused
        alignment_layer: Optional[int] = None,  # unused
        alignment_heads: Optional[int] = None,  # unused
    ):
  160. """
  161. Similar to *forward* but only return features.
  162. Returns:
  163. tuple:
  164. - the decoder's features of shape `(batch, tgt_len, embed_dim)`
  165. - a dictionary with any model-specific outputs
  166. """
        # incremental_state = None
        assert encoder_out is not None
        (x, encoder_outs, encoder_padding_mask) = self.pre_attention(
            prev_output_tokens, encoder_out, incremental_state
        )
        attn = None
        inner_states = [x]
        attn_list: List[Optional[Dict[str, Tensor]]] = []

        p_choose = torch.tensor([1.0])

        for i, layer in enumerate(self.layers):

            x, attn, _ = layer(
                x=x,
                encoder_out=encoder_outs,
                encoder_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                self_attn_mask=self.buffered_future_mask(x)
                if incremental_state is None
                else None,
            )

            inner_states.append(x)
            attn_list.append(attn)

            if incremental_state is not None:
                if_online = incremental_state["online"]["only"]
                assert if_online is not None

                if if_online.to(torch.bool):
                    # Online indicates that the encoder states are still changing
                    assert attn is not None
                    if self.policy_criterion == "any":
                        # If any attention head decides to read, then read
                        head_read = layer.encoder_attn._get_monotonic_buffer(
                            incremental_state
                        )["head_read"]
                        assert head_read is not None
                        if head_read.any():
                            # The model decided to read rather than write, so this
                            # step will be re-run once more source is available;
                            # prune the self_attn saved_state written by this pass,
                            # otherwise it would be duplicated.
                            self.clean_cache(incremental_state, i + 1)

                            return x, TransformerMonotonicDecoderOut(
                                action=READ_ACTION,
                                p_choose=p_choose,
                                attn_list=None,
                                encoder_out=None,
                                encoder_padding_mask=None,
                            )

        x = self.post_attention(x)

        return x, TransformerMonotonicDecoderOut(
            action=WRITE_ACTION,
            p_choose=p_choose,
            attn_list=attn_list,
            encoder_out=encoder_out,
            encoder_padding_mask=encoder_padding_mask,
        )
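

# Illustrative sketch, not part of the original module: a simultaneous-decoding
# agent would typically call extract_features step by step and branch on the
# returned action. The names below (model, states, read_source, emit_token)
# are hypothetical placeholders.
#
#     x, outputs = model.decoder.extract_features(
#         prev_output_tokens, encoder_out, incremental_state
#     )
#     if outputs.action == READ_ACTION:
#         read_source(states)   # wait for / consume more source tokens
#     else:                     # WRITE_ACTION
#         emit_token(model.decoder.output_layer(x))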


@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_architecture(args):
    base_architecture(args)
    args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)


@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
    transformer_iwslt_de_en(args)
    base_monotonic_architecture(args)


# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
    transformer_vaswani_wmt_en_de_big(args)


@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    transformer_vaswani_wmt_en_fr_big(args)


@register_model_architecture(
    "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
    transformer_iwslt_de_en(args)


@register_model_architecture("transformer_monotonic", "transformer_monotonic_tiny")
def monotonic_tiny_architecture(args):
    tiny_architecture(args)
    base_monotonic_architecture(args)
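

# Illustrative sketch, not part of the original module: the architecture
# functions above mutate an argparse.Namespace in place, filling in any
# hyperparameters that were not set on the command line, so a registered
# name such as "transformer_monotonic_iwslt_de_en" can be selected with
# fairseq-train's --arch flag. A minimal, hypothetical check of the
# default-filling behaviour:
#
#     from argparse import Namespace
#     args = Namespace()                 # hypothetical, empty argument set
#     base_monotonic_architecture(args)  # populates missing defaults
#     assert args.encoder_unidirectional is False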