attentions.py

Code Explained

The code in attentions.py defines the building blocks of a transformer-style sequence model: an Encoder, a Decoder, a MultiHeadAttention module, and a Feed-Forward Network (FFN). All of them operate on channel-first tensors of shape [batch, channels, time] and use 1D convolutions for their projections. Components like these are commonly used in transformer-based models for tasks such as natural language processing, speech synthesis, and time-series analysis. Below is a detailed explanation of each class and its functionality:


Encoder

The Encoder class processes an input sequence into contextual, high-level features. It stacks n_layers identical layers, each containing:

  1. Multi-Head Attention: Self-attention that captures relationships between different positions in the input sequence, using relative position embeddings (window_size, default 4).
  2. Layer Normalization: Applied after each residual addition (once after attention, once after the FFN) to stabilize training.
  3. Feed-Forward Network (FFN): Applies non-linear transformations to enhance feature representation.
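Read together with Encoder.forward in the source code, one encoder layer can be summarized as follows. The notation is introduced only for this summary: m is x_mask, M is the attention mask, and MHA, FFN, LN_1, LN_2 stand for the layer's modules. The two middle lines repeat once per layer. The masking lines run once, before the first layer and after the last.

```latex
\begin{aligned}
M_{ij} &= m_i\, m_j, \qquad x \leftarrow x \odot m \quad \text{(before the first layer)} \\
y &= \mathrm{Dropout}\big(\mathrm{MHA}(x, x, M)\big), \qquad x \leftarrow \mathrm{LN}_1(x + y) \\
y &= \mathrm{Dropout}\big(\mathrm{FFN}(x, m)\big), \qquad x \leftarrow \mathrm{LN}_2(x + y) \\
x &\leftarrow x \odot m \quad \text{(after the last layer)}
\end{aligned}
```

Normalization is applied after each residual sum (post-norm) rather than before each sub-layer.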

Key Features:

  • Initialization: The constructor builds n_layers copies of each sub-module and stores them in nn.ModuleList containers. These are attn_layers (MultiHeadAttention with relative position embeddings of size window_size), norm_layers_1, ffn_layers (FFN), and norm_layers_2.
  • Forward Pass: For each layer, forward applies self-attention, dropout, a residual addition, and layer normalization. It then repeats the same pattern with the FFN, which is a post-norm arrangement. The attention mask attn_mask is the outer product of x_mask with itself, so scores between a real position and a padded one are masked out. The output is also multiplied by x_mask, so padded frames stay zero.
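The following stdlib-only sketch (no PyTorch) reproduces the mask that x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) builds, for a single sequence. The helper name build_attn_mask is hypothetical and is not part of attentions.py.

```python
from typing import List


def build_attn_mask(x_mask: List[int]) -> List[List[int]]:
    """Pure-Python analogue of x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
    for one sequence: entry [i][j] is 1 only if positions i and j are both real."""
    return [[mi * mj for mj in x_mask] for mi in x_mask]


# A sequence of length 4 whose last frame is padding.
mask = build_attn_mask([1, 1, 1, 0])
for row in mask:
    print(row)
```

The row and column of the padded position are zero. After masked_fill with -1e4, real queries therefore give that position essentially no weight. The fully masked padded row still produces an output, but the Encoder discards it when it multiplies by x_mask.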

Decoder

The Decoder class generates output sequences based on the encoder’s output and additional input (e.g., previous decoder outputs). It extends the encoder’s functionality by adding:

  1. Causal Self-Attention: Captures dependencies within the decoder's input sequence, where each position attends only to itself and earlier positions.
  2. Encoder-Decoder Attention: Aligns the decoder’s input with the encoder’s output.

Key Features:

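  • Initialization: Like the encoder, the decoder creates n_layers sets of self-attention, encoder-decoder attention, normalization, and FFN layers. Its FFN layers are built with causal=True. Its attention layers are built without window_size, so they do not use relative position embeddings.
  • Proximal Bias: When proximal_bias=True, the self-attention scores receive an additive bias of -log(1 + |i - j|). This encourages attention to closer positions, which can be useful for tasks like speech synthesis. With proximal_init=True (the decoder's default), each self-attention layer's key projection starts as a copy of its query projection.
  • Forward Pass: The forward method passes the decoder input (x) and the encoder output (h) through self-attention, encoder-decoder attention, and FFN layers. Each sub-layer is followed by dropout, a residual addition, and layer normalization. The self-attention mask from subsequent_mask blocks future positions. The encoder-decoder mask (the product of x_mask and h_mask) blocks padded positions in either sequence.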
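The two decoder masks can be sketched in plain Python for one example. commons.subsequent_mask is not shown on this page, so the lower-triangular shape used here is an assumption about what it returns. Both helper names are hypothetical.

```python
from typing import List


def subsequent_mask_py(length: int) -> List[List[int]]:
    """Assumed behaviour of subsequent_mask: lower-triangular causal mask,
    so position i may attend to position j only when j <= i."""
    return [[1 if j <= i else 0 for j in range(length)] for i in range(length)]


def encdec_mask_py(x_mask: List[int], h_mask: List[int]) -> List[List[int]]:
    """Analogue of h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) for one example:
    rows follow decoder positions (queries), columns follow encoder positions (keys)."""
    return [[xi * hj for hj in h_mask] for xi in x_mask]


print(subsequent_mask_py(3))
print(encdec_mask_py([1, 1, 0], [1, 1, 1, 0]))
```

The encoder-decoder mask is rectangular ([t_x, t_h]) because queries come from the decoder and keys from the encoder. This is also why relative embeddings and the proximal bias, which assume t_s == t_t, are not used there.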

MultiHeadAttention

The MultiHeadAttention class implements the attention mechanism, which allows the model to focus on relevant parts of the input sequence. It supports:

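  1. Self- and Cross-Attention: Queries are projected from x, and keys and values from c. Passing the same tensor as both gives self-attention (as in the Encoder). Passing the encoder output as c gives encoder-decoder attention.
  2. Relative Position Embeddings: When window_size is set, learned embeddings for relative offsets from -window_size to +window_size are added to the attention scores (emb_rel_k) and to the attention output (emb_rel_v). This is only supported for self-attention.
  3. Proximal Bias: Adds a log-distance penalty to the scores to encourage attention to closer positions.
  4. Local Attention: When block_length is set (and a mask is given), scores between positions more than block_length apart are masked out.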
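The proximal bias is a fixed matrix that depends only on sequence length. The sketch below reproduces _attention_bias_proximal in plain Python, dropping the two leading singleton dimensions. The name proximal_bias is hypothetical.

```python
import math
from typing import List


def proximal_bias(length: int) -> List[List[float]]:
    """Entry [i][j] = -log(1 + |j - i|), matching -torch.log1p(torch.abs(diff))
    in _attention_bias_proximal."""
    return [[-math.log1p(abs(j - i)) for j in range(length)] for i in range(length)]


for row in proximal_bias(3):
    print([round(v, 3) for v in row])
```

The penalty is zero on the diagonal and grows logarithmically with distance: -log 2 ≈ -0.693 at offset 1 and -log 3 ≈ -1.099 at offset 2. Distant positions are discouraged but never fully masked.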

Key Features:

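  • Initialization: The constructor creates kernel-size-1 convolutions for the query (conv_q), key (conv_k), and value (conv_v) projections and for the output projection (conv_o). The first three are initialized with Xavier uniform weights. When window_size is set, it also creates the relative embedding tables (emb_rel_k, emb_rel_v). When proximal_init is set, it copies the query weights and bias into the key projection.
  • Attention Mechanism: The attention method splits the channels into n_heads heads of k_channels each and computes dot-product scores scaled by 1/sqrt(k_channels). It optionally adds relative-position logits and the proximal bias, fills masked positions with -1e4, and applies softmax and dropout before weighting the values. The resulting attention weights are also stored in self.attn.
  • Relative Position Handling: _get_relative_embeddings pads or slices the embedding table to the 2 * length - 1 offsets that a sequence of that length needs. _relative_position_to_absolute_position and _absolute_position_to_relative_position convert between relative indexing ([b, h, l, 2l-1]) and absolute indexing ([b, h, l, l]) with a pad-and-reshape trick instead of explicit gathers.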
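The following stdlib-only sketch repeats the pad-and-reshape steps of _relative_position_to_absolute_position on a single (batch, head) slice. It lets you check where each relative logit lands. The name rel_to_abs and the list-based layout are illustrative assumptions, not part of attentions.py.

```python
from typing import List


def rel_to_abs(x: List[List[float]]) -> List[List[float]]:
    """x has shape [l, 2l-1] (column k = relative offset k - (l - 1));
    the result has shape [l, l] (column j = absolute key position)."""
    length = len(x)
    width = 2 * length - 1
    # Pad one zero column on the right -> [l, 2l], then flatten row-major.
    flat = [v for row in x for v in row + [0.0]]
    # Pad length - 1 zeros at the end so the total is (l + 1) * (2l - 1).
    flat += [0.0] * (length - 1)
    # View as [l + 1, 2l - 1], keep the first l rows and the last l columns.
    rows = [flat[r * width:(r + 1) * width] for r in range(length + 1)]
    return [row[length - 1:] for row in rows[:length]]


length = 3
# Entry [i][k] encodes query i (tens digit) and relative offset k - (length - 1).
x = [[10 * i + (k - (length - 1)) for k in range(2 * length - 1)] for i in range(length)]
print(rel_to_abs(x))  # [[0, 1, 2], [9, 10, 11], [18, 19, 20]]
```

Entry [i][j] of the result is the logit for relative offset j - i, so each row is shifted by one relative to the previous row. Offsets beyond ±window_size never need special handling here, because _get_relative_embeddings zero-pads the embedding table and those offsets therefore get zero embeddings.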

Feed-Forward Network (FFN)

The FFN class applies two convolutional layers with an activation function in between. It enhances the model’s capacity to learn complex transformations.

Key Features:

  • Initialization: The constructor initializes two convolutional layers (conv_1 and conv_2) and supports dropout for regularization. The causal flag determines whether causal padding is applied.
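  • Forward Pass: The forward method masks and pads the input, then applies the first convolution, an activation function, and dropout. The activation is ReLU by default, or the sigmoid approximation of GELU, x * sigmoid(1.702 * x), when activation is "gelu". It then masks, pads, and applies the second convolution. The padding methods (_causal_padding, _same_padding) keep the output length equal to the input length. Causal padding places all kernel_size - 1 padding frames on the left, so no output depends on later frames.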
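The padding arithmetic and the activation can be checked with a small stdlib sketch. The helper names are hypothetical. The output-length formula assumes an unstrided, undilated Conv1d with no built-in padding, which matches how conv_1 and conv_2 are constructed.

```python
import math


def same_padding(kernel_size: int) -> tuple:
    """(pad_l, pad_r) as computed in FFN._same_padding."""
    return (kernel_size - 1) // 2, kernel_size // 2


def causal_padding(kernel_size: int) -> tuple:
    """(pad_l, pad_r) as computed in FFN._causal_padding: all padding on the left."""
    return kernel_size - 1, 0


def conv1d_out_len(length: int, kernel_size: int, pad_l: int, pad_r: int) -> int:
    """Output length of a stride-1, dilation-1 Conv1d after explicit padding."""
    return length + pad_l + pad_r - kernel_size + 1


def gelu_sigmoid_approx(x: float) -> float:
    """Activation used when activation == "gelu": x * sigmoid(1.702 * x)."""
    return x / (1.0 + math.exp(-1.702 * x))


for k in (1, 2, 3, 4):
    print(k, same_padding(k), causal_padding(k), conv1d_out_len(10, k, *same_padding(k)))
```

Both schemes add kernel_size - 1 frames in total, so a length-10 input stays length 10 for every kernel size. They differ only in where the zeros go. The sigmoid form is an approximation that stays close to, but not identical to, the exact erf-based GELU.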

Key Concepts

  1. Attention Mechanism: Allows the model to focus on relevant parts of the input sequence, improving its ability to capture long-range dependencies.
  2. Residual Connections: Add the input of a layer to its output, helping to preserve information and stabilize training.
  3. Layer Normalization: Normalizes intermediate outputs to improve convergence and stability.
  4. Dropout: Reduces overfitting by randomly setting a fraction of activations to zero during training.
  5. Relative Position Embeddings: Enhance the model’s ability to capture positional relationships between elements in the sequence.
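To make the attention and masking concepts concrete, here is a single-head scaled dot-product attention in plain Python. It uses time-major rows ([time, d_k]) rather than the [b, d, t] layout of attentions.py, and it omits dropout, relative embeddings, and the proximal bias. The function names are hypothetical.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(
    q: Matrix, k: Matrix, v: Matrix, mask: Optional[List[List[int]]] = None
) -> Matrix:
    """q is [t_t, d_k]; k and v are [t_s, d_k] and [t_s, d_v]; mask is [t_t, t_s]."""
    d_k = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k]
        if mask is not None:
            # Same convention as attentions.py: masked scores are replaced by -1e4.
            scores = [s if mask[i][j] else -1e4 for j, s in enumerate(scores)]
        p = softmax(scores)
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
v = [[1.0], [2.0], [100.0]]
print(scaled_dot_product_attention(q, k, v, mask=[[1, 1, 0]]))
```

With the third key masked, its large value (100.0) has no visible effect on the output. Without the mask, it dominates the weighted sum.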

Use Case

This architecture is well-suited for tasks like:

  • Speech Synthesis: The encoder processes phoneme sequences, and the decoder generates audio features.
  • Machine Translation: The encoder extracts features from the source language, and the decoder generates the target language.
  • Time-Series Forecasting: The encoder captures patterns in historical data, and the decoder predicts future values.

By combining attention mechanisms, feed-forward networks, and normalization, this architecture provides a powerful framework for sequence modeling.

Source Code

```python
import math
import typing

import torch
from torch import nn
from torch.nn import functional as F

# Helpers from sibling modules of the same package (not shown on this page).
from .commons import subsequent_mask
from .modules import LayerNorm


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int = 1,
        p_dropout: float = 0.0,
        window_size: int = 4,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        # x: [b, hidden_channels, t], x_mask: [b, 1, t] -> attn_mask: [b, 1, t, t]
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for attn_layer, norm_layer_1, ffn_layer, norm_layer_2 in zip(
            self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
        ):
            # Self-attention sub-layer with a post-norm residual connection.
            y = attn_layer(x, x, attn_mask)
            y = self.drop(y)
            x = norm_layer_1(x + y)

            # Feed-forward sub-layer with a post-norm residual connection.
            y = ffn_layer(x, x_mask)
            y = self.drop(y)
            x = norm_layer_2(x + y)
        x = x * x_mask
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int = 1,
        p_dropout: float = 0.0,
        proximal_bias: bool = False,
        proximal_init: bool = True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        # Causal mask: each position attends only to itself and earlier positions.
        self_attn_mask = subsequent_mask(x_mask.size(2)).type_as(x)
        # [b, 1, t_x, t_h]: excludes padded positions in either sequence.
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels: int,
        out_channels: int,
        n_heads: int,
        p_dropout: float = 0.0,
        window_size: typing.Optional[int] = None,
        heads_share: bool = True,
        block_length: typing.Optional[int] = None,
        proximal_bias: bool = False,
        proximal_init: bool = False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = torch.zeros(1)  # most recent attention weights, set in forward()

        self.k_channels = channels // n_heads
        # Kernel-size-1 convolutions act as per-frame linear projections.
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            # A single table shared by all heads, or one table per head.
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        # Queries come from x; keys and values come from c.
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (key.size(0), key.size(1), key.size(2), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).type_as(scores)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                # Keep only scores with |j - i| <= block_length.
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length: int):
        # max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                # convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
                (0, 0, pad_length, pad_length, 0, 0),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x = F.pad(x, (0, 1, 0, 0, 0, 0, 0, 0))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        x_flat = F.pad(x_flat, (0, length - 1, 0, 0, 0, 0))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, (2 * length) - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # Pad along the column dimension.
        # x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x = F.pad(x, (0, length - 1, 0, 0, 0, 0, 0, 0))
        x_flat = x.view([batch, heads, (length * length) + (length * (length - 1))])
        # Add 0's at the beginning that will skew the elements after reshape.
        # x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_flat = F.pad(x_flat, (length, 0, 0, 0, 0, 0))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length: int):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float = 0.0,
        activation: str = "",
        causal: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        # Mask, pad, and convolve; padding keeps the time length unchanged.
        if self.causal:
            padding1 = self._causal_padding(x * x_mask)
        else:
            padding1 = self._same_padding(x * x_mask)

        x = self.conv_1(padding1)

        if self.activation == "gelu":
            # Sigmoid approximation of GELU.
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)

        if self.causal:
            padding2 = self._causal_padding(x * x_mask)
        else:
            padding2 = self._same_padding(x * x_mask)

        x = self.conv_2(padding2)

        return x * x_mask

    def _causal_padding(self, x):
        # All kernel_size - 1 padding frames go on the left (no look-ahead).
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        # x = F.pad(x, convert_pad_shape(padding))
        x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
        return x

    def _same_padding(self, x):
        # Padding split between both sides so output length equals input length.
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        # x = F.pad(x, convert_pad_shape(padding))
        x = F.pad(x, (pad_l, pad_r, 0, 0, 0, 0))
        return x
```