commons.py
Code Explained
The provided code contains a collection of utility functions commonly used in deep learning models, particularly in sequence modeling and attention-based architectures. These functions handle tasks such as weight initialization, padding, sampling, timing signals, masking, and gradient clipping. Below is a detailed explanation of the key functions:
Weight Initialization
init_weights(m, mean=0.0, std=0.01): Re-initializes, in place, the weights of any module whose class name contains "Conv", drawing them from a normal distribution with the given mean and standard deviation. Biases and all other module types are left untouched. Starting from small weights is a common way to keep early training stable.
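A minimal usage sketch, assuming the function is applied through torch.nn.Module.apply (the call site is not shown in this file, so that is an assumption). The toy model and the variable names are hypothetical and only serve to show which layers are affected.

```python
import torch
from torch import nn


def init_weights(m, mean=0.0, std=0.01):
    # Copied from the source listing: matches on the class name only.
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


torch.manual_seed(0)

# Hypothetical toy model: one convolution and one linear layer.
model = nn.Sequential(nn.Conv1d(4, 8, kernel_size=3), nn.ReLU(), nn.Linear(8, 8))
linear_before = model[2].weight.detach().clone()

# Module.apply calls init_weights on every submodule and on the model itself.
model.apply(init_weights)

conv_std = model[0].weight.std().item()  # close to 0.01 after re-initialization
linear_unchanged = torch.equal(model[2].weight, linear_before)  # Linear is skipped
print(conv_std, linear_unchanged)
```

Because the check is a substring match on the class name, any custom module whose name contains "Conv" would also be re-initialized, provided it has a weight attribute.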
Padding and Sequence Manipulation
get_padding(kernel_size, dilation=1): Computes the padding that keeps the output length equal to the input length for a stride-1 convolution with an odd kernel size. This keeps feature-map sizes consistent through stacked (dilated) convolutions.
intersperse(lst, item): Inserts item between the elements of a list and also at both ends, so a list of n elements becomes 2n + 1 elements long. This can be used, for example, to place a blank token around every symbol in a token sequence.
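The following sketch checks both helpers with plain Python. The function conv_output_length is a hypothetical helper added here for illustration; it uses the standard output-length formula for a 1-D convolution and is not part of commons.py.

```python
def get_padding(kernel_size, dilation=1):
    # Copied from the source listing.
    return int((kernel_size * dilation - dilation) / 2)


def intersperse(lst, item):
    # Copied from the source listing.
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def conv_output_length(length, kernel_size, dilation=1, stride=1):
    # Hypothetical helper: standard 1-D convolution output-length formula.
    padding = get_padding(kernel_size, dilation)
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


same_odd = conv_output_length(100, kernel_size=3)                  # length preserved
same_dilated = conv_output_length(100, kernel_size=5, dilation=2)  # length preserved
even_kernel = conv_output_length(100, kernel_size=4)               # one sample shorter

tokens = intersperse([1, 2, 3], 0)
print(same_odd, same_dilated, even_kernel, tokens)
# prints: 100 100 99 [0, 1, 0, 2, 0, 3, 0]
```

The even-kernel case shows why the "same length" property holds only for odd kernel sizes: the required total padding is odd and cannot be split symmetrically.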
Probability and Sampling
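kl_divergence(m_p, logs_p, m_q, logs_q): Computes the element-wise KL divergence KL(P||Q) between two Gaussian distributions P and Q, each given by its mean (m_p, m_q) and log standard deviation (logs_p, logs_q). No reduction is applied, so the result has the broadcast shape of the inputs. This term is often used in variational autoencoders (VAEs) and other probabilistic models to measure the difference between distributions.
rand_gumbel(shape): Samples from the standard Gumbel distribution by transforming uniform noise. The uniform samples are kept away from 0 and 1 so that the double logarithm stays finite. Gumbel noise is commonly used in Gumbel-Softmax sampling, a differentiable approximation of discrete sampling.
rand_gumbel_like(x): Generates Gumbel samples with the same shape, dtype, and device as the input tensor x.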
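As a sanity check, the sketch below compares kl_divergence against PyTorch's own closed-form KL for Normal distributions and inspects the range of rand_gumbel. The random test inputs and the variable names are illustrative assumptions.

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence as torch_kl


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q), copied from the source listing."""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Copied from the source listing."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


torch.manual_seed(0)
m_p, logs_p = torch.randn(2, 3), 0.1 * torch.randn(2, 3)
m_q, logs_q = torch.randn(2, 3), 0.1 * torch.randn(2, 3)

ours = kl_divergence(m_p, logs_p, m_q, logs_q)
# Reference: Normal takes a standard deviation, so exponentiate the log-stds.
reference = torch_kl(Normal(m_p, logs_p.exp()), Normal(m_q, logs_q.exp()))
kl_matches = torch.allclose(ours, reference, atol=1e-5)

# Uniform samples lie in [0.00001, 0.99999), so the Gumbel samples are bounded.
g = rand_gumbel((10000,))
print(kl_matches, g.min().item(), g.max().item())
```

The result keeps the input shape, so any summation or averaging over channels and time has to be done by the caller.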
Segment Slicing
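slice_segments(x, ids_str, segment_size=4): Extracts one fixed-size segment per batch item from a tensor x of shape [batch, channels, time], starting at the indices in ids_str (negative start indices are clamped to 0). This is useful for tasks like cropping audio or feature sequences.
rand_slice_segments(x, x_lengths=None, segment_size=4): Draws a random start index per batch item so that each segment lies within the valid length given by x_lengths (or within the full time dimension if x_lengths is None), then returns the segments together with their start indices. This is often used for data augmentation or for training models on random subsequences.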
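A small sketch of random cropping on a toy batch. The tensor contents and the lengths are made up for illustration; the two functions are copied from the source listing.

```python
import torch


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = max(0, ids_str[i])
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


torch.manual_seed(0)

# Toy batch: 2 items, 3 channels, 10 frames; the second item has only 6 valid frames.
x = torch.arange(2 * 3 * 10, dtype=torch.float).view(2, 3, 10)
x_lengths = torch.tensor([10, 6])

segments, starts = rand_slice_segments(x, x_lengths, segment_size=4)
print(segments.shape, starts.tolist())
```

Returning the start indices lets the caller crop a second, time-aligned tensor (for example the matching waveform) at the same positions with slice_segments.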
Timing Signals
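get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): Generates a sinusoidal timing signal of shape [1, channels, length], with sines in the first channels // 2 channels and cosines in the next channels // 2 (an odd channel count is zero-padded). Transformer models use such signals to encode positional information in sequences.
add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): Adds the timing signal to the input tensor x of shape [batch, channels, length], broadcasting over the batch dimension, so the model can incorporate positional information.
cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): Concatenates the timing signal to the input tensor x along the specified axis. Because the signal has batch size 1 and torch.cat does not broadcast, with the default axis=1 this works as written only when x also has batch size 1.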
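The sketch below inspects the generated signal and the shapes produced by adding and concatenating it. The input tensors are zero-filled placeholders chosen for illustration; the add and concatenate steps are written inline and mirror what add_timing_signal_1d and cat_timing_signal_1d do.

```python
import math

import torch
from torch.nn import functional as F


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    # Copied from the source listing.
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


signal = get_timing_signal_1d(length=5, channels=8)  # shape [1, 8, 5]

# At position 0 the sine channels are 0 and the cosine channels are 1.
first_column = signal[0, :, 0]

# Addition broadcasts the signal over the batch dimension.
x = torch.zeros(2, 8, 5)
added = x + signal.to(dtype=x.dtype, device=x.device)

# Concatenation along the channel axis needs a matching batch size of 1.
x_single = torch.zeros(1, 8, 5)
concatenated = torch.cat([x_single, signal], 1)

print(signal.shape, added.shape, concatenated.shape)
```

Note that the increment divides by num_timescales - 1, so the function needs at least four channels to avoid a division by zero.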
Masking
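subsequent_mask(length): Creates a lower-triangular mask of shape [1, 1, length, length] in which position i may attend only to positions up to and including i. This prevents attention to future positions and is essential for autoregressive models such as transformer decoders.
sequence_mask(length, max_length=None): Given a tensor of sequence lengths, generates a boolean mask of shape [batch, max_length] that is True at valid positions and False at padding positions, so that padding can be ignored during computation.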
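Both masks are easiest to understand from a small example. The lengths below are arbitrary illustrative values; the functions are copied from the source listing.

```python
from typing import Optional

import torch


def subsequent_mask(length: int):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


def sequence_mask(length, max_length: Optional[int] = None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


causal = subsequent_mask(3)
print(causal[0, 0])
# tensor([[1., 0., 0.],
#         [1., 1., 0.],
#         [1., 1., 1.]])

# Two sequences of lengths 1 and 3, padded to 4 positions.
padding = sequence_mask(torch.tensor([1, 3]), max_length=4)
print(padding)
# tensor([[ True, False, False, False],
#         [ True,  True,  True, False]])
```

The causal mask is a float tensor with two leading singleton dimensions so that it broadcasts over batch and attention heads, while the sequence mask is boolean and usually cast to the data type of the tensor it is multiplied with.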
Activation Fusion
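fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): Adds the two inputs, applies tanh to the first n_channels channels of the sum and sigmoid to the remaining channels, and multiplies the two results element-wise. The function is compiled with torch.jit.script for efficiency, and n_channels is indexed with [0], so it is passed as a one-element tensor. This is the gated activation familiar from WaveNet-style layers.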
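The sketch below reproduces the gated activation in eager mode and compares it with a step-by-step reference. The name gated_activation is a hypothetical un-scripted copy of the function body (the torch.jit.script decorator is omitted here to keep the example simple), and the tensor sizes are arbitrary.

```python
import torch


def gated_activation(input_a, input_b, n_channels):
    # Eager-mode copy of fused_add_tanh_sigmoid_multiply (no torch.jit.script).
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


torch.manual_seed(0)
hidden = 4

# Both inputs carry 2 * hidden channels: one half for tanh, one half for sigmoid.
x = torch.randn(2, 2 * hidden, 6)
cond = torch.randn(2, 2 * hidden, 6)

# n_channels is a one-element integer tensor, as the [0] indexing implies.
acts = gated_activation(x, cond, torch.IntTensor([hidden]))

# Step-by-step reference for comparison.
summed = x + cond
reference = torch.tanh(summed[:, :hidden]) * torch.sigmoid(summed[:, hidden:])
print(acts.shape, torch.allclose(acts, reference))
```

The output has half as many channels as the inputs, which is why layers that use this activation produce 2 * hidden channels before the gate.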
Path Generation
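generate_path(duration, mask): Expands per-token durations (shape [b, 1, t_x]) into a hard monotonic alignment path of shape [b, 1, t_y, t_x], using cumulative durations and a mask of the same shape as the output. Each output frame is assigned to exactly one input token. This is useful in models that align sequences, such as TTS systems.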
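A worked example makes the path construction concrete. The durations (2, 1, 3) and the all-ones mask are illustrative assumptions; the two functions are copied from the source listing.

```python
from typing import Optional

import torch
from torch.nn import functional as F


def sequence_mask(length, max_length: Optional[int] = None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).type_as(mask)
    path = path.view(b, t_x, t_y)
    # Subtract the previous token's prefix mask to keep only this token's frames.
    path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


# Three tokens lasting 2, 1, and 3 frames: 6 output frames in total.
duration = torch.tensor([[[2.0, 1.0, 3.0]]])
mask = torch.ones(1, 1, 6, 3)

path = generate_path(duration, mask)
print(path[0, 0])
# tensor([[1., 0., 0.],
#         [1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.],
#         [0., 0., 1.],
#         [0., 0., 1.]])
```

Each row (output frame) contains a single 1, and summing over the frame axis recovers the original durations.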
Gradient Clipping
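clip_grad_value_(parameters, clip_value, norm_type=2): Clamps every gradient element of the given parameters to the range [-clip_value, clip_value], which keeps individual gradients from becoming too large and destabilizing training. It returns the total gradient norm (of order norm_type) computed before clamping; if clip_value is None, it only computes the norm.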
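The sketch below shows that the returned norm describes the gradients before clamping. The parameter and its hand-set gradient are illustrative; the function is copied from the source listing.

```python
import torch


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        # The norm is measured first, then the gradient is clamped in place.
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


# Toy parameter with a hand-set gradient.
p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([3.0, -4.0, 0.5])

total_norm = clip_grad_value_([p], clip_value=1.0)
print(total_norm)  # sqrt(3^2 + 4^2 + 0.5^2), about 5.025
print(p.grad)      # tensor([ 1.0000, -1.0000,  0.5000])
```

Unlike norm-based clipping, clamping each element independently can change the direction of the gradient; the returned norm is mainly useful for logging.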
Key Use Cases
These utility functions are designed to support various aspects of deep learning model development:
- Weight Initialization: Ensures stable training by initializing weights appropriately.
- Sequence Processing: Handles tasks like padding, masking, and timing signal generation for sequence models.
- Sampling and Augmentation: Provides tools for probabilistic sampling and random segment extraction.
- Gradient Management: Prevents issues like exploding gradients during backpropagation.
Applications
These functions are particularly relevant for:
- Transformer Models: Functions like get_timing_signal_1d and subsequent_mask are essential for implementing transformers.
- Speech and Audio Processing: Segment slicing and path generation are useful for tasks like text-to-speech (TTS) and audio feature extraction.
- Probabilistic Models: KL divergence and Gumbel sampling are key components of VAEs and Gumbel-Softmax models.
By modularizing these utilities, the codebase becomes more reusable and easier to maintain across different projects.
Source Code
import logging
import math
from typing import Optional

import torch
from torch.nn import functional as F

_LOGGER = logging.getLogger("vits.commons")


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = max(0, ids_str[i])
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length: int):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def sequence_mask(length, max_length: Optional[int] = None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).type_as(mask)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm