
lightning.py

Code Explained

The VitsModel class is a PyTorch Lightning module designed for training and inference in a VITS (Variational Inference with adversarial learning for end-to-end Text-to-Speech) system. This class integrates various components, including a synthesizer, discriminator, dataset handling, and training logic, to provide a complete pipeline for text-to-speech synthesis. Below is a detailed explanation of its structure and functionality.


Initialization (__init__)

The constructor initializes the model with a wide range of hyperparameters for audio processing, model architecture, and training. The main parameter groups are:

  1. Audio Parameters: Defines settings like FFT filter length, hop length, mel spectrogram channels, and sample rate.
  2. Model Parameters: Configures the architecture of the synthesizer and discriminator, including the number of layers, channels, and kernel sizes.
  3. Training Parameters: Sets up the learning rate, batch size, gradient clipping, and dataset-related settings such as the validation split and the maximum phoneme-ID sequence length.

The constructor also:

  • Initializes the synthesizer (model_g) using the SynthesizerTrn class, which handles text encoding, spectrogram generation, and waveform synthesis.
  • Initializes the discriminator (model_d) using the MultiPeriodDiscriminator class, which evaluates the quality of generated audio.
  • Loads datasets for training, validation, and testing using the _load_datasets method.
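
For orientation, here is a minimal construction sketch. The import path, vocabulary size, and dataset path are placeholders (assumptions, not values taken from the source); most architecture and audio arguments are left at their defaults.

import torch
from vits.lightning import VitsModel  # import path assumed; adjust to your project layout

model = VitsModel(
    num_symbols=256,                      # placeholder: size of the phoneme-id vocabulary
    num_speakers=1,                       # >1 enables speaker conditioning (gin_channels defaults to 512 then)
    sample_rate=22050,
    dataset=["/path/to/dataset.jsonl"],   # placeholder path handed to PiperDataset
    batch_size=16,
    learning_rate=2e-4,
)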

Dataset Handling

The _load_datasets method splits the dataset into training, validation, and test sets based on the provided validation_split and num_test_examples. It uses the PiperDataset class to load utterances and random_split to partition the dataset.
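
The split arithmetic can be illustrated with a toy stand-in for PiperDataset (1,000 items, validation_split=0.1, num_test_examples=5); this mirrors the calculation in _load_datasets rather than reproducing its exact code.

import torch
from torch.utils.data import TensorDataset, random_split

full_dataset = TensorDataset(torch.arange(1000))           # toy stand-in for PiperDataset
valid_set_size = int(len(full_dataset) * 0.1)              # 100
train_set_size = len(full_dataset) - valid_set_size - 5    # 895

train_ds, test_ds, val_ds = random_split(full_dataset, [train_set_size, 5, valid_set_size])
print(len(train_ds), len(test_ds), len(val_ds))            # 895 5 100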


Forward Pass

The forward method performs inference by:

  1. Passing text and its lengths to the synthesizer (model_g).
  2. Generating audio using the synthesizer’s infer method, which applies the noise, length, and duration-noise (noise_scale_w) scales, plus speaker conditioning when a speaker ID is provided.

This method is used during validation to generate audio samples.
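
A hedged sketch of invoking forward directly (the phoneme ids are placeholders; the scales triple is ordered [noise_scale, length_scale, noise_scale_w], matching what forward expects, and the values shown are the ones used in validation_step):

import torch

# Assumes `model` is a trained VitsModel in eval mode.
model.eval()
phoneme_ids = torch.LongTensor([[1, 42, 17, 5, 2]])        # placeholder ids, shape (1, T)
phoneme_lengths = torch.LongTensor([phoneme_ids.size(1)])
scales = [0.667, 1.0, 0.8]                                 # noise_scale, length_scale, noise_scale_w

with torch.no_grad():
    audio = model(phoneme_ids, phoneme_lengths, scales)    # pass sid=... for multi-speaker models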


Data Loaders

The train_dataloader, val_dataloader, and test_dataloader methods create PyTorch DataLoader objects for the respective datasets. They use the UtteranceCollate class to pad and batch data efficiently, ensuring compatibility with the model’s input requirements.
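
Because these are the standard LightningModule dataloader hooks, a Trainer can consume them directly. The sketch below assumes the model was constructed with a dataset (so the splits exist); exact Trainer arguments and behavior depend on the PyTorch Lightning version this code targets (the optimizer_idx signature in training_step implies the older 1.x-style multi-optimizer loop).

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
trainer.fit(model)    # pulls batches from model.train_dataloader() and model.val_dataloader()
trainer.test(model)   # pulls batches from model.test_dataloader()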


Training Steps

The training_step method handles the training logic for both the generator (model_g) and discriminator (model_d):

  1. Generator Training (training_step_g):

    • Unpacks phoneme IDs, phoneme lengths, spectrograms, and audio waveforms from the batch.
    • Passes the data through the synthesizer to generate predictions (y_hat).
    • Computes losses, including the adversarial generator loss, mel spectrogram loss, KL divergence loss, feature matching loss, and duration loss.
    • Logs the total generator loss (loss_gen_all); the weighting of these terms is summarized in the sketch after this list.
  2. Discriminator Training (training_step_d):

    • Evaluates the real audio and the detached generated audio (saved during the generator step) using the discriminator.
    • Computes the discriminator loss and logs it (loss_disc_all).
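
The weighting of the generator terms can be condensed into a small, self-contained sketch. The loss values here are dummy scalars standing in for the real terms computed by the loss functions; only the weights c_mel and c_kl come from the constructor defaults.

import torch

# Dummy scalar losses standing in for the terms computed in training_step_g.
loss_gen, loss_fm, loss_dur, loss_kl_raw, l1_mel = (
    torch.tensor(v) for v in (2.0, 1.5, 0.3, 0.8, 0.5)
)

c_mel, c_kl = 45, 1.0               # default weights from the constructor
loss_mel = l1_mel * c_mel           # mel reconstruction, heavily weighted
loss_kl = loss_kl_raw * c_kl        # KL divergence between posterior and prior

loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl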

Validation

The validation_step method evaluates the model on the validation dataset. It:

  1. Computes the generator and discriminator losses.
  2. Logs the validation loss.
  3. Generates audio samples from the held-out test utterances for qualitative evaluation, which are logged through the experiment logger (e.g., TensorBoard) as audio, roughly as sketched below.
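
With the default TensorBoard logger, the logging call boils down to SummaryWriter.add_audio. This is a minimal standalone sketch using a dummy sine wave and a placeholder log directory, normalized the same way validation_step normalizes its samples.

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/vits_demo")            # placeholder log directory
t = torch.linspace(0, 1, 22050)
test_audio = torch.sin(2 * torch.pi * 440 * t)      # one second of a 440 Hz tone as a stand-in

# Scale into [-1, 1] as in validation_step before logging.
test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))
writer.add_audio("example_utterance", test_audio.unsqueeze(0), sample_rate=22050)
writer.close()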

Optimization

The configure_optimizers method sets up two optimizers (one for the generator and one for the discriminator) using the AdamW optimizer. It also defines learning rate schedulers with exponential decay.
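
With the default lr_decay of 0.999875, the exponential schedule shrinks the learning rate very slowly per epoch. A standalone sketch of the same optimizer/scheduler pairing on a dummy parameter (not the module's own configure_optimizers, just the same PyTorch building blocks):

import torch

param = torch.nn.Parameter(torch.zeros(1))          # dummy parameter standing in for model_g
optimizer = torch.optim.AdamW([param], lr=2e-4, betas=(0.8, 0.99), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999875)

for epoch in range(3):
    optimizer.step()                                # stand-in for a real training epoch
    scheduler.step()
    print(epoch, scheduler.get_last_lr())           # lr shrinks by a factor of 0.999875 per epoch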


Hyperparameter Parsing

The add_model_specific_args static method adds model-specific arguments to a command-line parser. This allows users to configure the model’s hyperparameters directly from the command line, making the training process more flexible.
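
A small sketch of wiring add_model_specific_args into a training script's parser; the script-level --dataset argument here is a placeholder, not part of the method itself.

import argparse
from vits.lightning import VitsModel                # import path assumed; adjust to your layout

parser = argparse.ArgumentParser(prog="train")
parser.add_argument("--dataset", nargs="+", help="placeholder: paths passed through to the model")
parser = VitsModel.add_model_specific_args(parser)

args = parser.parse_args(["--dataset", "data.jsonl", "--batch-size", "16"])
print(args.batch_size, args.hidden_channels)        # 16 192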


Key Features

  1. Multi-Speaker Support: The model can handle multiple speakers by conditioning on speaker embeddings (whose dimension is set by gin_channels).
  2. End-to-End Pipeline: Integrates dataset loading, training, validation, and inference into a single class.
  3. Loss Functions: Combines multiple loss functions (e.g., mel spectrogram loss, KL divergence loss) to optimize both audio quality and alignment.
  4. Scalable Training: Uses PyTorch Lightning’s features for distributed training, mixed precision, and logging.

Use Case

The VitsModel class is designed for training and evaluating a VITS-based text-to-speech system. It is suitable for tasks like:

  • Single-speaker and multi-speaker TTS.
  • Voice cloning by conditioning on specific speaker embeddings.
  • Generating high-quality audio from text inputs.

This modular and extensible design makes it easy to adapt the model to different datasets and use cases.

Source Code

import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from torch import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from .commons import slice_segments
from .dataset import Batch, PiperDataset, UtteranceCollate
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from .models import MultiPeriodDiscriminator, SynthesizerTrn
_LOGGER = logging.getLogger("vits.lightning")

class VitsModel(pl.LightningModule):
    def __init__(
        self,
        num_symbols: int,
        num_speakers: int,
        # audio
        resblock="2",
        resblock_kernel_sizes=(3, 5, 7),
        resblock_dilation_sizes=(
            (1, 2),
            (2, 6),
            (3, 12),
        ),
        upsample_rates=(8, 8, 4),
        upsample_initial_channel=256,
        upsample_kernel_sizes=(16, 16, 8),
        # mel
        filter_length: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        mel_channels: int = 80,
        sample_rate: int = 22050,
        sample_bytes: int = 2,
        channels: int = 1,
        mel_fmin: float = 0.0,
        mel_fmax: Optional[float] = None,
        # model
        inter_channels: int = 192,
        hidden_channels: int = 192,
        filter_channels: int = 768,
        n_heads: int = 2,
        n_layers: int = 6,
        kernel_size: int = 3,
        p_dropout: float = 0.1,
        n_layers_q: int = 3,
        use_spectral_norm: bool = False,
        gin_channels: int = 0,
        use_sdp: bool = True,
        segment_size: int = 8192,
        # training
        dataset: Optional[List[Union[str, Path]]] = None,
        learning_rate: float = 2e-4,
        betas: Tuple[float, float] = (0.8, 0.99),
        eps: float = 1e-9,
        batch_size: int = 1,
        lr_decay: float = 0.999875,
        init_lr_ratio: float = 1.0,
        warmup_epochs: int = 0,
        c_mel: int = 45,
        c_kl: float = 1.0,
        grad_clip: Optional[float] = None,
        num_workers: int = 1,
        seed: int = 1234,
        num_test_examples: int = 5,
        validation_split: float = 0.1,
        max_phoneme_ids: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        if (self.hparams.num_speakers > 1) and (self.hparams.gin_channels <= 0):
            # Default gin_channels for multi-speaker model
            self.hparams.gin_channels = 512

        # Set up models
        self.model_g = SynthesizerTrn(
            n_vocab=self.hparams.num_symbols,
            spec_channels=self.hparams.filter_length // 2 + 1,
            segment_size=self.hparams.segment_size // self.hparams.hop_length,
            inter_channels=self.hparams.inter_channels,
            hidden_channels=self.hparams.hidden_channels,
            filter_channels=self.hparams.filter_channels,
            n_heads=self.hparams.n_heads,
            n_layers=self.hparams.n_layers,
            kernel_size=self.hparams.kernel_size,
            p_dropout=self.hparams.p_dropout,
            resblock=self.hparams.resblock,
            resblock_kernel_sizes=self.hparams.resblock_kernel_sizes,
            resblock_dilation_sizes=self.hparams.resblock_dilation_sizes,
            upsample_rates=self.hparams.upsample_rates,
            upsample_initial_channel=self.hparams.upsample_initial_channel,
            upsample_kernel_sizes=self.hparams.upsample_kernel_sizes,
            n_speakers=self.hparams.num_speakers,
            gin_channels=self.hparams.gin_channels,
            use_sdp=self.hparams.use_sdp,
        )
        self.model_d = MultiPeriodDiscriminator(
            use_spectral_norm=self.hparams.use_spectral_norm
        )

        # Dataset splits
        self._train_dataset: Optional[Dataset] = None
        self._val_dataset: Optional[Dataset] = None
        self._test_dataset: Optional[Dataset] = None
        self._load_datasets(validation_split, num_test_examples, max_phoneme_ids)

        # State kept between training optimizers
        self._y = None
        self._y_hat = None
    def _load_datasets(
        self,
        validation_split: float,
        num_test_examples: int,
        max_phoneme_ids: Optional[int] = None,
    ):
        if self.hparams.dataset is None:
            _LOGGER.debug("No dataset to load")
            return

        full_dataset = PiperDataset(
            self.hparams.dataset, max_phoneme_ids=max_phoneme_ids
        )
        valid_set_size = int(len(full_dataset) * validation_split)
        train_set_size = len(full_dataset) - valid_set_size - num_test_examples

        self._train_dataset, self._test_dataset, self._val_dataset = random_split(
            full_dataset, [train_set_size, num_test_examples, valid_set_size]
        )
    def forward(self, text, text_lengths, scales, sid=None):
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio, *_ = self.model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )

        return audio
    def train_dataloader(self):
        return DataLoader(
            self._train_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self._val_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def test_dataloader(self):
        return DataLoader(
            self._test_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )
    def training_step(self, batch: Batch, batch_idx: int, optimizer_idx: int):
        if optimizer_idx == 0:
            return self.training_step_g(batch)

        if optimizer_idx == 1:
            return self.training_step_d(batch)

    def training_step_g(self, batch: Batch):
        x, x_lengths, y, _, spec, spec_lengths, speaker_ids = (
            batch.phoneme_ids,
            batch.phoneme_lengths,
            batch.audios,
            batch.audio_lengths,
            batch.spectrograms,
            batch.spectrogram_lengths,
            batch.speaker_ids if batch.speaker_ids is not None else None,
        )
        (
            y_hat,
            l_length,
            _attn,
            ids_slice,
            _x_mask,
            z_mask,
            (_z, z_p, m_p, logs_p, _m_q, logs_q),
        ) = self.model_g(x, x_lengths, spec, spec_lengths, speaker_ids)
        self._y_hat = y_hat

        mel = spec_to_mel_torch(
            spec,
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y_mel = slice_segments(
            mel,
            ids_slice,
            self.hparams.segment_size // self.hparams.hop_length,
        )
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1),
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y = slice_segments(
            y,
            ids_slice * self.hparams.hop_length,
            self.hparams.segment_size,
        )  # slice

        # Save for training_step_d
        self._y = y

        _y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.model_d(y, y_hat)

        with autocast(self.device.type, enabled=False):
            # Generator loss
            loss_dur = torch.sum(l_length.float())
            loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.c_mel
            loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.c_kl
            loss_fm = feature_loss(fmap_r, fmap_g)
            loss_gen, _losses_gen = generator_loss(y_d_hat_g)
            loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl

        self.log("loss_gen_all", loss_gen_all)

        return loss_gen_all
    def training_step_d(self, batch: Batch):
        # From training_step_g
        y = self._y
        y_hat = self._y_hat

        y_d_hat_r, y_d_hat_g, _, _ = self.model_d(y, y_hat.detach())

        with autocast(self.device.type, enabled=False):
            # Discriminator
            loss_disc, _losses_disc_r, _losses_disc_g = discriminator_loss(
                y_d_hat_r, y_d_hat_g
            )
            loss_disc_all = loss_disc

        self.log("loss_disc_all", loss_disc_all)

        return loss_disc_all
    def validation_step(self, batch: Batch, batch_idx: int):
        val_loss = self.training_step_g(batch) + self.training_step_d(batch)
        self.log("val_loss", val_loss)

        # Generate audio examples
        for utt_idx, test_utt in enumerate(self._test_dataset):
            text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
            text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
            scales = [0.667, 1.0, 0.8]
            sid = (
                test_utt.speaker_id.to(self.device)
                if test_utt.speaker_id is not None
                else None
            )
            test_audio = self(text, text_lengths, scales, sid=sid).detach()

            # Scale to make louder in [-1, 1]
            test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))

            tag = test_utt.text or str(utt_idx)
            self.logger.experiment.add_audio(
                tag, test_audio, sample_rate=self.hparams.sample_rate
            )

        return val_loss
    def configure_optimizers(self):
        optimizers = [
            torch.optim.AdamW(
                self.model_g.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
            torch.optim.AdamW(
                self.model_d.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
        ]
        schedulers = [
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[0], gamma=self.hparams.lr_decay
            ),
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[1], gamma=self.hparams.lr_decay
            ),
        ]

        return optimizers, schedulers
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("VitsModel")
        parser.add_argument("--batch-size", type=int, required=True)
        parser.add_argument("--validation-split", type=float, default=0.1)
        parser.add_argument("--num-test-examples", type=int, default=5)
        parser.add_argument(
            "--max-phoneme-ids",
            type=int,
            help="Exclude utterances with phoneme id lists longer than this",
        )
        #
        parser.add_argument("--hidden-channels", type=int, default=192)
        parser.add_argument("--inter-channels", type=int, default=192)
        parser.add_argument("--filter-channels", type=int, default=768)
        parser.add_argument("--n-layers", type=int, default=6)
        parser.add_argument("--n-heads", type=int, default=2)
        #
        return parent_parser