
/wav2vec_cycle_code/FragmentVC/models/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py

https://gitlab.com/lwd17/enhanced_examplar_ae

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.modules import (
    AdaptiveSoftmax,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
)

EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
        "encoder_out",  # T x B x C
        "encoder_padding_mask",  # B x T
        "encoder_embedding",  # B x T x C
        "encoder_states",  # List[T x B x C]
    ],
)


class TransformerEncoderEmbedding(nn.Module):
    """Encoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        if isinstance(embed_tokens, nn.ModuleList):
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

    def forward(self, input):
        # embed tokens and positions
        src_tokens = input[0]
        prev_output_tokens = input[2]
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(src_tokens))
            embedded = torch.cat(x_embed_list, dim=-1)
        else:
            embedded = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)


class TransformerEncoderLayerNorm(nn.Module):
    """
    Layer norm at the end of all encoder layers if
    args.encoder_normalize_before = True
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input):
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        if self.layer_norm:
            x = self.layer_norm(x)
        # keeping track of the incremental_state is not supported yet
        return (x, encoder_padding_mask, prev_output_tokens)
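
# The encoder-side stages in this file (TransformerEncoderEmbedding, the
# TransformerEncoderLayer blocks defined below, and TransformerEncoderLayerNorm)
# all consume and return the same (x, encoder_padding_mask, prev_output_tokens)
# tuple, so they can be chained as plain sequential pipeline partitions. A
# minimal sketch, assuming a parsed `args` namespace with the usual fairseq
# transformer fields (e.g. args.encoder_layers) and an `embed_tokens` module:
#
#   encoder = nn.Sequential(
#       TransformerEncoderEmbedding(args, embed_tokens),
#       *[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)],
#       TransformerEncoderLayerNorm(args, args.encoder_embed_dim),
#   )
#   x, encoder_padding_mask, prev_output_tokens = encoder(
#       (src_tokens, src_lengths, prev_output_tokens)
#   )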


class TransformerDecoderEmbedding(nn.Module):
    """Decoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        input_embed_dim = (
            sum(e.embedding_dim for e in embed_tokens)
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.embedding_dim
        )
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim
        padding_idx = (
            embed_tokens[0].padding_idx
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.padding_idx
        )
        self.max_target_positions = args.max_target_positions
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim
        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )
        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

    def forward(self, input):
        mt_task = False
        if isinstance(input, tuple):
            if len(input) == 3:
                encoder_out = input[0]
                encoder_padding_mask = input[1]
                prev_output_tokens = input[2]
                incremental_state = None  # Hardcoding to avoid passing of None objects
                mt_task = True
            else:
                # HACK for now, need to fix (TODO sidgoyal)
                prev_output_tokens = input[0]
                # discard "src_lengths"
                encoder_out = None
                encoder_padding_mask = None
                incremental_state = None
        else:
            prev_output_tokens = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(prev_output_tokens))
            x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x
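
# Note on the forward above: TransformerDecoderEmbedding accepts either the
# 3-tuple handed over by the encoder pipeline (encoder_out,
# encoder_padding_mask, prev_output_tokens), in which case it flags the batch
# as an MT task and propagates the encoder tensors, or bare prev_output_tokens
# for a decoder-only call, in which case it simply returns the embedded
# sequence. Illustrative only, with `decoder_embed` as a hypothetical instance:
#
#   # MT path: encoder stage output plus target tokens
#   x, enc_out, enc_mask = decoder_embed((enc_out, enc_mask, tgt_tokens))
#   # decoder-only path: just the target tokens
#   x = decoder_embed(tgt_tokens)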


class TransformerDecoderOutputLayer(nn.Module):
    def __init__(self, args, embed_tokens, dictionary):
        super().__init__()
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.embed_tokens = embed_tokens
        self.output_embed_dim = args.decoder_output_dim
        embed_dim = args.decoder_embed_dim

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
            else None
        )
        self.adaptive_softmax = None
        if args.adaptive_softmax_cutoff is not None:
            assert not isinstance(embed_tokens, nn.ModuleList)
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            self.embed_tokens = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(
                self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5
            )

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input, apply_final_proj=True):
        if isinstance(input, tuple):
            x = input[0]
        else:
            x = input

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        if apply_final_proj:
            x = self.output_layer(x)
        return x

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                if isinstance(self.embed_tokens, nn.ModuleList):
                    output = None
                    for i, emb in enumerate(self.embed_tokens):
                        sidx = i * emb.embedding_dim
                        eidx = (i + 1) * emb.embedding_dim
                        if output is None:
                            output = F.linear(features[:, :, sidx:eidx], emb.weight)
                        else:
                            output += F.linear(features[:, :, sidx:eidx], emb.weight)
                    return output
                else:
                    return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_tokens)
        else:
            return features
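
# The decoder-side stages mirror the encoder convention: on the
# machine-translation path, TransformerDecoderEmbedding, the
# TransformerDecoderLayer blocks defined below, and TransformerDecoderOutputLayer
# each pass along an (x, encoder_out, encoder_padding_mask) tuple, with the
# embedding stage taking the encoder outputs plus prev_output_tokens first.
# A minimal sketch, assuming the same `args`, `embed_tokens`, and a fairseq
# `dictionary`:
#
#   decoder = nn.Sequential(
#       TransformerDecoderEmbedding(args, embed_tokens),
#       *[TransformerDecoderLayer(args) for _ in range(args.decoder_layers)],
#       TransformerDecoderOutputLayer(args, embed_tokens, dictionary),
#   )
#   logits = decoder((encoder_out, encoder_padding_mask, prev_output_tokens))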


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
                input[2] (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing
        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return (x, encoder_padding_mask, prev_output_tokens)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x
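
# Note on maybe_layer_norm above (the decoder layer below uses the same helper):
# `before` / `after` mark where the call sits relative to a sub-layer, and the
# XOR with self.normalize_before selects exactly one of the two call sites.
# With normalize_before=False (the default, post-norm as in the original paper),
# `after ^ normalize_before` is True only at the after=True call, so the layer
# norm is applied after the residual add. With normalize_before=True (pre-norm,
# the tensor2tensor style from the class docstring), it is True only at the
# before=True call, so the input is normalized before entering the sub-layer.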


class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
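
# A quick sanity-check sketch for the two initializer helpers above
# (hypothetical sizes, not part of the original module):
#
#   emb = Embedding(num_embeddings=1000, embedding_dim=512, padding_idx=1)
#   assert emb.weight[1].abs().sum().item() == 0.0  # padding row zero-initialized
#   proj = Linear(512, 2048)  # Xavier-uniform weight, zero bias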