/wav2vec_cycle_code/FragmentVC/models/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.modules import (
    AdaptiveSoftmax,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
)


EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
        "encoder_out",  # T x B x C
        "encoder_padding_mask",  # B x T
        "encoder_embedding",  # B x T x C
        "encoder_states",  # List[T x B x C]
    ],
)
class TransformerEncoderEmbedding(nn.Module):
    """Encoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        if isinstance(embed_tokens, nn.ModuleList):
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

    def forward(self, input):
        # embed tokens and positions
        src_tokens = input[0]
        prev_output_tokens = input[2]
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(src_tokens))

            embedded = torch.cat(x_embed_list, dim=-1)
        else:
            embedded = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)
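# Illustrative sketch (not part of the original file): each pipeline stage
# consumes and produces a plain tuple so stages can be chained sequentially.
# Assuming a toy args namespace and the Embedding helper defined at the bottom
# of this module, usage looks roughly like:
#
#     import types
#     args = types.SimpleNamespace(
#         dropout=0.1, max_source_positions=1024, encoder_learned_pos=False,
#         no_token_positional_embeddings=False,
#     )
#     embed = Embedding(num_embeddings=100, embedding_dim=16, padding_idx=1)
#     stage = TransformerEncoderEmbedding(args, embed)
#     src_tokens = torch.randint(2, 100, (4, 7))        # B x T
#     src_lengths = torch.full((4,), 7)                  # unused by this stage
#     prev_out = torch.randint(2, 100, (4, 5))           # B x T_tgt
#     x, pad_mask, prev = stage((src_tokens, src_lengths, prev_out))
#     # x: T x B x C (7 x 4 x 16), pad_mask: B x T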
class TransformerEncoderLayerNorm(nn.Module):
    """
    Layer norm at the end of all encoder layers if
    args.encoder_normalize_before = True
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input):
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        if self.layer_norm:
            x = self.layer_norm(x)

        # keeping track of the incremental_state is not supported yet
        return (x, encoder_padding_mask, prev_output_tokens)
class TransformerDecoderEmbedding(nn.Module):
    """Decoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = (
            sum(e.embedding_dim for e in embed_tokens)
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.embedding_dim
        )
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        padding_idx = (
            embed_tokens[0].padding_idx
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.padding_idx
        )
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

    def forward(self, input):
        mt_task = False
        if isinstance(input, tuple):
            if len(input) == 3:
                encoder_out = input[0]
                encoder_padding_mask = input[1]
                prev_output_tokens = input[2]
                incremental_state = None  # Hardcoding to avoid passing of None objects
                mt_task = True
            else:
                # HACK for now, need to fix (TODO sidgoyal)
                prev_output_tokens = input[0]
                # discard "src_lengths"
                encoder_out = None
                encoder_padding_mask = None
                incremental_state = None
        else:
            prev_output_tokens = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(prev_output_tokens))

            x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x
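# Illustrative note (not part of the original file): this stage accepts either
# the 3-tuple handed along by the encoder pipeline,
# (encoder_out, encoder_padding_mask, prev_output_tokens), in which case it
# returns (x, encoder_out, encoder_padding_mask) for the MT task, or a bare
# prev_output_tokens tensor (decoder-only use), in which case it returns just
# the embedded x in T x B x C layout.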
class TransformerDecoderOutputLayer(nn.Module):
    def __init__(self, args, embed_tokens, dictionary):
        super().__init__()
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.embed_tokens = embed_tokens
        self.output_embed_dim = args.decoder_output_dim
        embed_dim = args.decoder_embed_dim

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
            else None
        )
        self.adaptive_softmax = None
        if args.adaptive_softmax_cutoff is not None:
            assert not isinstance(embed_tokens, nn.ModuleList)
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            self.embed_tokens = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(
                self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5
            )

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input, apply_final_proj=True):
        if isinstance(input, tuple):
            x = input[0]
        else:
            x = input

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        if apply_final_proj:
            x = self.output_layer(x)
        return x

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                if isinstance(self.embed_tokens, nn.ModuleList):
                    output = None
                    for i, emb in enumerate(self.embed_tokens):
                        sidx = i * emb.embedding_dim
                        eidx = (i + 1) * emb.embedding_dim
                        if output is None:
                            output = F.linear(features[:, :, sidx:eidx], emb.weight)
                        else:
                            output += F.linear(features[:, :, sidx:eidx], emb.weight)

                    return output
                else:
                    return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_tokens)
        else:
            return features
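# Illustrative note (not part of the original file): with
# --share-decoder-input-output-embed, output_layer reuses the decoder token
# embedding as the output projection, i.e. logits = features @ embed.weight.T
# via F.linear(features, self.embed_tokens.weight), so no separate
# vocabulary-sized output matrix is learned.  When an adaptive softmax is
# configured, the features are returned as-is and the projection happens
# inside the AdaptiveSoftmax criterion.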
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding
                    elements are indicated by ``1``.
                input[2] (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, used for teacher forcing
        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return (x, encoder_padding_mask, prev_output_tokens)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x
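# Illustrative note (not part of the original file): maybe_layer_norm
# implements both residual-block orderings described in the class docstring.
# With the default encoder_normalize_before=False (post-norm, as in the
# original paper) only the `after=True` call applies LayerNorm:
#     x = LayerNorm(x + Dropout(Sublayer(x)))
# whereas with encoder_normalize_before=True (pre-norm, tensor2tensor style)
# only the `before=True` call applies it:
#     x = x + Dropout(Sublayer(LayerNorm(x)))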
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(src_len, batch, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding
                    elements are indicated by ``1``.
        Returns:
            output (Tuple):
                output[0] (Tensor): decoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (Tensor): encoder output, passed through unchanged
                output[2] (ByteTensor/FloatTensor): encoder padding mask, passed through
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
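# Illustrative note (not part of the original file): buffered_future_mask
# returns a cached additive causal mask. For a length-3 target, the mask added
# to the self-attention logits is
#
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
#
# so position t can only attend to positions <= t.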
def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
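

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file): run a single
    # encoder layer on random data. Assumes fairseq is importable; the
    # hyper-parameter values below are arbitrary illustrations.
    from argparse import Namespace

    args = Namespace(
        encoder_embed_dim=16,
        encoder_attention_heads=4,
        encoder_ffn_embed_dim=32,
        attention_dropout=0.0,
        dropout=0.0,
        encoder_normalize_before=False,
    )
    layer = TransformerEncoderLayer(args)

    seq_len, batch = 7, 2
    x = torch.randn(seq_len, batch, args.encoder_embed_dim)       # T x B x C
    padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)  # B x T, no padding
    prev_output_tokens = torch.zeros(batch, 5, dtype=torch.long)  # dummy pass-through

    out, out_mask, prev = layer((x, padding_mask, prev_output_tokens))
    print(out.shape)  # torch.Size([7, 2, 16])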