attentions.py
Code Explained
The provided code defines a neural network architecture for sequence modeling, consisting of an Encoder, Decoder, MultiHeadAttention, and Feed-Forward Network (FFN). These components are commonly used in transformer-based models for tasks such as natural language processing, speech synthesis, and time-series analysis. Below is a detailed explanation of each class and its functionality:
Encoder
The Encoder class is responsible for processing input sequences and extracting high-level features. It consists of multiple layers, each containing:
- Multi-Head Attention: Captures relationships between different positions in the input sequence.
- Layer Normalization: Stabilizes training by normalizing intermediate outputs.
- Feed-Forward Network (FFN): Applies non-linear transformations to enhance feature representation.
Key Features:
- Initialization: The constructor creates `n_layers` sets of attention, normalization, and FFN layers. Each attention layer is a `MultiHeadAttention` instance, and each FFN layer is an `FFN` instance.
- Forward Pass: The `forward` method applies the attention and FFN layers sequentially. Residual connections preserve information from earlier layers, and dropout is applied for regularization. The input mask (`x_mask`) ensures that padding positions are ignored during computation (see the usage sketch below).
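As a rough usage sketch (shapes inferred from the forward pass: inputs are [batch, hidden_channels, time] and the mask is [batch, 1, time]; the import path and hyperparameter values below are illustrative, not the project's actual configuration):

```python
import torch
# Hypothetical import path; adjust to wherever attentions.py lives in your package.
from mypackage.attentions import Encoder

batch, channels, time = 2, 192, 50

encoder = Encoder(
    hidden_channels=channels,   # example values, not the project's defaults
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
)

x = torch.randn(batch, channels, time)            # input features: [B, hidden_channels, T]
lengths = torch.tensor([50, 35])                  # valid (non-padded) length of each sequence
x_mask = (torch.arange(time).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1).float()  # [B, 1, T]

y = encoder(x, x_mask)                            # [B, hidden_channels, T]; padded frames are zeroed
```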
Decoder
The Decoder class generates output sequences based on the encoder’s output and additional input (e.g., previous decoder outputs). It extends the encoder’s functionality by adding:
- Self-Attention: Captures dependencies within the decoder’s input sequence.
- Encoder-Decoder Attention: Aligns the decoder’s input with the encoder’s output.
Key Features:
- Initialization: Similar to the encoder, the decoder initializes multiple layers of self-attention, encoder-decoder attention, normalization, and FFN layers.
- Proximal Bias: Encourages attention to closer positions in the sequence, useful for tasks like speech synthesis.
- Forward Pass: The `forward` method processes the decoder's input (`x`) and the encoder's output (`h`) through self-attention, encoder-decoder attention, and FFN layers. Masks ensure that padding and future positions are ignored during computation (a usage sketch follows this list).
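A matching sketch for the decoder, again with illustrative shapes, example hyperparameters, and an assumed import path. Internally, `subsequent_mask` (imported from `.commons`) blocks attention to future positions, and the causal FFN keeps its convolutions from looking ahead:

```python
import torch
from mypackage.attentions import Decoder  # hypothetical import path

decoder = Decoder(
    hidden_channels=192,   # example values only
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
)

# Encoder output and its padding mask (e.g. produced by the Encoder sketch above).
h = torch.randn(2, 192, 50)
h_mask = torch.ones(2, 1, 50)

# Decoder-side input (e.g. previously generated frames) and its padding mask.
x = torch.randn(2, 192, 40)
x_mask = torch.ones(2, 1, 40)

out = decoder(x, x_mask, h, h_mask)  # [2, 192, 40]
```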
MultiHeadAttention
The MultiHeadAttention class implements the attention mechanism, which allows the model to focus on relevant parts of the input sequence. It supports:
- Self-Attention: When the query, key, and value inputs are the same.
- Relative Position Embeddings: Encodes positional information to enhance attention to nearby elements.
- Proximal Bias: Adds a bias to encourage attention to closer positions.
Key Features:
- Initialization: The constructor creates 1x1 convolutional layers for the query (`conv_q`), key (`conv_k`), and value (`conv_v`) projections, plus an output projection (`conv_o`). It also supports relative position embeddings and proximal initialization.
- Attention Mechanism: The `attention` method computes scaled dot-product attention, optionally incorporating relative position embeddings and a proximal bias (a simplified sketch of this computation follows the list). Masks are applied to ignore padding or enforce causality.
- Relative Position Handling: Helper methods (`_get_relative_embeddings`, `_relative_position_to_absolute_position`, `_absolute_position_to_relative_position`) handle the transformation of relative position embeddings.
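The heart of the `attention` method is standard scaled dot-product attention plus a few optional additions. The following self-contained sketch reproduces the score/softmax/weighted-sum core and the proximal bias term (-log(1 + |i - j|), as computed by `_attention_bias_proximal`), while leaving out the relative-position and block-masking branches:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None, proximal_bias=False):
    """q, k, v: [batch, heads, time, d_k], already projected and split into heads."""
    d_k = q.size(-1)
    scores = torch.matmul(q / math.sqrt(d_k), k.transpose(-2, -1))    # [B, H, T_q, T_k]
    if proximal_bias:
        # Same bias as _attention_bias_proximal: penalize attention to distant positions.
        r = torch.arange(scores.size(-1), dtype=scores.dtype, device=scores.device)
        diff = r.unsqueeze(0) - r.unsqueeze(1)
        scores = scores + (-torch.log1p(diff.abs()))[None, None]
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)                  # same large negative fill as the source
    p_attn = F.softmax(scores, dim=-1)                                # attention weights over key positions
    return torch.matmul(p_attn, v), p_attn

q = k = v = torch.randn(2, 2, 10, 96)      # self-attention: query == key == value
out, attn = scaled_dot_product_attention(q, k, v, proximal_bias=True)
print(out.shape, attn.shape)               # torch.Size([2, 2, 10, 96]) torch.Size([2, 2, 10, 10])
```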
Feed-Forward Network (FFN)
The FFN class applies two convolutional layers with an activation function in between. It enhances the model’s capacity to learn complex transformations.
Key Features:
- Initialization: The constructor initializes two convolutional layers (`conv_1` and `conv_2`) and supports dropout for regularization. The `causal` flag determines whether causal padding is applied.
- Forward Pass: The `forward` method applies the first convolution, followed by an activation function (ReLU by default, or a sigmoid-based GELU approximation when `activation == "gelu"`), dropout, and the second convolution. Padding methods (`_causal_padding`, `_same_padding`) ensure that the output length matches the input length (the two padding schemes are sketched below).
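To see what the two padding helpers do, here is a minimal standalone comparison for a kernel of size 3 (the shapes and values are illustrative):

```python
import torch
import torch.nn.functional as F

kernel_size = 3
x = torch.randn(1, 8, 10)  # [batch, channels, time]

# _same_padding: split kernel_size - 1 zeros between left and right,
# so the convolution keeps the sequence length but each frame may see one future frame.
same = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2))

# _causal_padding: put all kernel_size - 1 zeros on the left,
# so frame t is computed only from frames <= t.
causal = F.pad(x, (kernel_size - 1, 0))

conv = torch.nn.Conv1d(8, 8, kernel_size)
print(conv(same).shape, conv(causal).shape)  # both torch.Size([1, 8, 10])
```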
Key Concepts
- Attention Mechanism: Allows the model to focus on relevant parts of the input sequence, improving its ability to capture long-range dependencies.
- Residual Connections: Add the input of a layer to its output, helping to preserve information and stabilize training.
- Layer Normalization: Normalizes intermediate outputs to improve convergence and stability.
- Dropout: Reduces overfitting by randomly setting a fraction of activations to zero during training.
- Relative Position Embeddings: Enhance the model’s ability to capture positional relationships between elements in the sequence.
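Residual connections, dropout, and layer normalization combine into the same post-norm pattern in every Encoder and Decoder layer: y = sublayer(x); y = dropout(y); x = norm(x + y). Below is a standalone sketch of that pattern; a 1x1 Conv1d stands in for the attention/FFN sub-layer, and torch.nn.LayerNorm plus transposes emulates the channel-wise LayerNorm imported from `.modules` (which is not shown in this file):

```python
import torch
from torch import nn

channels = 192
sublayer = nn.Conv1d(channels, channels, 1)   # stand-in for MultiHeadAttention or FFN
drop = nn.Dropout(0.1)
norm = nn.LayerNorm(channels)                 # applied over the channel dimension below

x = torch.randn(2, channels, 50)                       # [batch, channels, time]
y = drop(sublayer(x))                                  # sub-layer output with dropout
x = norm((x + y).transpose(1, 2)).transpose(1, 2)      # residual connection + normalization
```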
Use Case
This architecture is well-suited for tasks like:
- Speech Synthesis: The encoder processes phoneme sequences, and the decoder generates audio features.
- Machine Translation: The encoder extracts features from the source language, and the decoder generates the target language.
- Time-Series Forecasting: The encoder captures patterns in historical data, and the decoder predicts future values.
By combining attention mechanisms, feed-forward networks, and normalization, this architecture provides a powerful framework for sequence modeling.
Source Code
```python
import math
import typing

import torch
from torch import nn
from torch.nn import functional as F

from .commons import subsequent_mask
from .modules import LayerNorm
```
```python
class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int = 1,
        p_dropout: float = 0.0,
        window_size: int = 4,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for attn_layer, norm_layer_1, ffn_layer, norm_layer_2 in zip(
            self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
        ):
            y = attn_layer(x, x, attn_mask)
            y = self.drop(y)
            x = norm_layer_1(x + y)

            y = ffn_layer(x, x_mask)
            y = self.drop(y)
            x = norm_layer_2(x + y)
        x = x * x_mask
        return x
```
```python
class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int = 1,
        p_dropout: float = 0.0,
        proximal_bias: bool = False,
        proximal_init: bool = True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = subsequent_mask(x_mask.size(2)).type_as(x)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
```
```python
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels: int,
        out_channels: int,
        n_heads: int,
        p_dropout: float = 0.0,
        window_size: typing.Optional[int] = None,
        heads_share: bool = True,
        block_length: typing.Optional[int] = None,
        proximal_bias: bool = False,
        proximal_init: bool = False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = torch.zeros(1)

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (key.size(0), key.size(1), key.size(2), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).type_as(scores)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length: int):
        # max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                # convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
                (0, 0, pad_length, pad_length, 0, 0),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()

        # Concat columns of pad to shift from relative to absolute indexing.
        # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x = F.pad(x, (0, 1, 0, 0, 0, 0, 0, 0))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        x_flat = F.pad(x_flat, (0, length - 1, 0, 0, 0, 0))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, (2 * length) - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()

        # pad along column
        # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x = F.pad(x, (0, length - 1, 0, 0, 0, 0, 0, 0))
        x_flat = x.view([batch, heads, (length * length) + (length * (length - 1))])
        # add 0's in the beginning that will skew the elements after reshape
        # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_flat = F.pad(x_flat, (length, 0, 0, 0, 0, 0))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length: int):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
```
```python
class FFN(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float = 0.0,
        activation: str = "",
        causal: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        if self.causal:
            padding1 = self._causal_padding(x * x_mask)
        else:
            padding1 = self._same_padding(x * x_mask)

        x = self.conv_1(padding1)

        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)

        if self.causal:
            padding2 = self._causal_padding(x * x_mask)
        else:
            padding2 = self._same_padding(x * x_mask)

        x = self.conv_2(padding2)

        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        # x = F.pad(x, convert_pad_shape(padding))
        x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        # x = F.pad(x, convert_pad_shape(padding))
        x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
        return x
```