Skip to content

commons.py

Code Explained

The provided code contains a collection of utility functions commonly used in deep learning models, particularly in sequence modeling and attention-based architectures. These functions handle tasks such as weight initialization, padding, sampling, timing signals, masking, and gradient clipping. Below is a detailed explanation of the key functions:


Weight Initialization

  • init_weights(m, mean=0.0, std=0.01): Initializes the weights of convolutional layers (Conv) using a normal distribution with a specified mean and standard deviation. This is crucial for ensuring stable training and avoiding issues like vanishing or exploding gradients.

Padding and Sequence Manipulation

  • get_padding(kernel_size, dilation=1): Computes the padding required to maintain the same spatial dimensions after applying a convolutional layer. This is useful for ensuring consistent output sizes in neural networks.
  • intersperse(lst, item): Inserts a specified item between elements of a list. This can be used for tasks like adding special tokens in sequence processing.

Probability and Sampling

  • kl_divergence(m_p, logs_p, m_q, logs_q): Computes the Kullback-Leibler (KL) divergence between two distributions P and Q. This is often used in variational autoencoders (VAEs) and other probabilistic models to measure the difference between distributions.
  • rand_gumbel(shape): Samples from the Gumbel distribution, which is commonly used in Gumbel-Softmax sampling for differentiable approximations of discrete variables.
  • rand_gumbel_like(x): Generates Gumbel samples with the same shape and device as the input tensor x.

Segment Slicing

  • slice_segments(x, ids_str, segment_size=4): Extracts fixed-size segments from a tensor x based on starting indices ids_str. This is useful for tasks like cropping audio or feature sequences.
  • rand_slice_segments(x, x_lengths=None, segment_size=4): Randomly selects segments from a tensor x, ensuring that the segments are within valid bounds. This is often used for data augmentation or training models on random subsequences.

Timing Signals

  • get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): Generates sinusoidal timing signals, which are used in transformer models to encode positional information in sequences.
  • add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): Adds timing signals to the input tensor x, enabling the model to incorporate positional information.
  • cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): Concatenates timing signals to the input tensor x along a specified axis.

Masking

  • subsequent_mask(length): Creates a triangular mask to prevent attention to future positions in a sequence. This is essential for autoregressive models like decoders in transformers.
  • sequence_mask(length, max_length=None): Generates a binary mask for sequences of varying lengths, ensuring that padding positions are ignored during computation.

Activation Fusion

  • fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): Combines addition, tanh, sigmoid, and element-wise multiplication into a single operation. This is often used in gated activation functions for efficiency.

Path Generation

  • generate_path(duration, mask): Generates a path tensor based on cumulative durations and a mask. This is useful in models that align sequences, such as attention-based TTS systems.

Gradient Clipping

  • clip_grad_value_(parameters, clip_value, norm_type=2): Clips the gradients of model parameters to a specified value, preventing gradients from becoming too large and destabilizing training.

Key Use Cases

These utility functions are designed to support various aspects of deep learning model development:

  1. Weight Initialization: Ensures stable training by initializing weights appropriately.
  2. Sequence Processing: Handles tasks like padding, masking, and timing signal generation for sequence models.
  3. Sampling and Augmentation: Provides tools for probabilistic sampling and random segment extraction.
  4. Gradient Management: Prevents issues like exploding gradients during backpropagation.

Applications

These functions are particularly relevant for:

  • Transformer Models: Functions like get_timing_signal_1d and subsequent_mask are essential for implementing transformers.
  • Speech and Audio Processing: Segment slicing and path generation are useful for tasks like text-to-speech (TTS) and audio feature extraction.
  • Probabilistic Models: KL divergence and Gumbel sampling are key components of VAEs and Gumbel-Softmax models.

By modularizing these utilities, the codebase becomes more reusable and easier to maintain across different projects.

Source Code

import logging
import math
from typing import Optional
import torch
from torch.nn import functional as F
_LOGGER = logging.getLogger("vits.commons")
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = max(0, ids_str[i])
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length: int):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def sequence_mask(length, max_length: Optional[int] = None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).type_as(mask)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm