
models.py

Code Explained

The provided code defines the neural network modules that make up a deep-learning text-to-speech (TTS) system. These modules handle text encoding, duration prediction, latent audio modeling, waveform generation, and adversarial training. Below is a detailed explanation of the key components:


1. StochasticDurationPredictor

This module predicts the duration of phonemes in a probabilistic manner using normalizing flows. It is a critical component for aligning text and audio in TTS systems.

  • Initialization:

    • Defines multiple flows (ConvFlow and Flip) for transforming input distributions.
    • Includes convolutional layers (Conv1d) for feature extraction and conditioning.
    • Supports global conditioning (gin_channels) for multi-speaker models.
  • Forward Pass:

    • In the forward mode (reverse=False), it computes a negative log-likelihood (NLL) term for the ground-truth durations (w), which serves as the duration loss during training.
    • In the reverse mode (reverse=True), it samples noise, runs the flows in reverse, and outputs predicted log-durations (logw).

This module enables flexible and expressive duration modeling, which is essential for natural-sounding speech synthesis.
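
For orientation, here is a minimal usage sketch of both modes. The channel sizes, sequence length, and duration values are illustrative assumptions rather than values taken from this file, and the snippet presumes the class can be imported from this models.py:

    sdp = StochasticDurationPredictor(in_channels=192, filter_channels=192,
                                      kernel_size=3, p_dropout=0.5, n_flows=4)
    x = torch.randn(1, 192, 50)    # text-encoder features [batch, channels, phonemes] (assumed shape)
    x_mask = torch.ones(1, 1, 50)  # all 50 phoneme positions are valid
    w = torch.rand(1, 1, 50) * 10  # ground-truth durations in frames (training targets)

    nll = sdp(x, x_mask, w=w, reverse=False)              # training: per-utterance NLL, shape [1]
    logw = sdp(x, x_mask, reverse=True, noise_scale=0.8)  # inference: sampled log-durations [1, 1, 50]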


2. DurationPredictor

This module predicts the duration of phonemes deterministically, without using probabilistic flows.

  • Initialization:

    • Uses two convolutional layers (Conv1d) with layer normalization (LayerNorm) and dropout for robust feature extraction.
    • Supports global conditioning for multi-speaker models.
  • Forward Pass:

    • Processes the input through the convolutional stack and predicts log-durations as a single-channel output.

This simpler duration predictor is useful for deterministic alignment in TTS systems.
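
A corresponding sketch for the deterministic predictor, reusing the assumed shapes from the previous example:

    dp = DurationPredictor(in_channels=192, filter_channels=256, kernel_size=3, p_dropout=0.5)
    logw = dp(x, x_mask)          # [1, 1, 50] predicted log-durations
    w = torch.exp(logw) * x_mask  # durations in frames, as done at inference time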


3. TextEncoder

The TextEncoder converts input text (represented as phoneme IDs) into a latent representation suitable for speech synthesis.

  • Initialization:

    • Embeds input phoneme IDs into a high-dimensional space using nn.Embedding.
    • Uses a multi-head attention-based encoder (attentions.Encoder) for capturing long-range dependencies in the text.
    • Projects the encoded features into a mean (m) and log-scale (logs) that parameterize the prior distribution over the latent space.
  • Forward Pass:

    • Embeds the input text and applies the attention-based encoder.
    • Outputs the encoded text representation (x), mean (m), log-scale (logs), and a sequence mask (x_mask).

This module is responsible for extracting meaningful features from text, which are used to guide the speech synthesis process.
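
A brief sketch of how the encoder might be called; the vocabulary size and sequence length are assumptions for illustration:

    enc = TextEncoder(n_vocab=256, out_channels=192, hidden_channels=192,
                      filter_channels=768, n_heads=2, n_layers=6,
                      kernel_size=3, p_dropout=0.1)
    phoneme_ids = torch.randint(0, 256, (1, 50))    # [batch, phonemes]
    lengths = torch.tensor([50])
    x, m, logs, x_mask = enc(phoneme_ids, lengths)  # x: [1, 192, 50]; m/logs: prior statistics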


4. ResidualCouplingBlock

This module implements a series of residual coupling layers for transforming latent variables in the flow-based model.

  • Initialization:

    • Defines multiple ResidualCouplingLayer modules, each followed by a Flip operation to alternate the input dimensions.
  • Forward Pass:

    • Applies the coupling layers sequentially in the forward mode.
    • Reverses the flow in the reverse mode to reconstruct the input.

This module is used for flexible and invertible transformations in the latent space.
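
Because every layer is invertible, running the block forward and then in reverse recovers the input. A small sketch with assumed channel sizes:

    flow = ResidualCouplingBlock(channels=192, hidden_channels=192,
                                 kernel_size=5, dilation_rate=1, n_layers=4)
    z = torch.randn(1, 192, 100)
    mask = torch.ones(1, 1, 100)
    z_p = flow(z, mask)                         # forward: latent -> prior space
    z_rec = flow(z_p, mask, reverse=True)       # reverse: prior space -> latent
    assert torch.allclose(z, z_rec, atol=1e-4)  # identical up to numerical error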


5. PosteriorEncoder

The PosteriorEncoder encodes input spectrograms into a latent representation for variational modeling.

  • Initialization:

    • Uses convolutional layers (Conv1d) and a WaveNet-like module (modules.WN) for feature extraction.
    • Projects the features into mean (m) and log-scale (logs) for latent space sampling.
  • Forward Pass:

    • Encodes the input spectrogram and samples latent variables (z) using the reparameterization trick.

This module is essential for learning a latent representation of the input audio.
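
The reparameterization step is the line z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask in the source below. A sketch of a call, with an assumed spectrogram size of 513 bins and 200 frames:

    post_enc = PosteriorEncoder(in_channels=513, out_channels=192, hidden_channels=192,
                                kernel_size=5, dilation_rate=1, n_layers=16)
    spec = torch.randn(1, 513, 200)    # linear spectrogram [batch, bins, frames]
    spec_lengths = torch.tensor([200])
    z, m, logs, y_mask = post_enc(spec, spec_lengths)
    # internally, z is sampled as m + eps * exp(logs), with eps ~ N(0, I)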


6. Generator

The Generator synthesizes waveforms from latent representations.

  • Initialization:

    • Uses transposed convolutions (ConvTranspose1d) for upsampling the latent representation to the audio domain.
    • Includes residual blocks (ResBlock1 or ResBlock2) for refining the generated waveform.
    • Supports global conditioning for multi-speaker models.
  • Forward Pass:

    • Upsamples the latent representation and applies residual blocks to generate the waveform.
    • Outputs the final waveform after applying a tanh activation.

This module is responsible for generating high-quality audio from the latent representation.
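
The total upsampling factor is the product of upsample_rates, so each latent frame becomes that many audio samples. A sketch with commonly used (here assumed) settings:

    gen = Generator(initial_channel=192, resblock="1",
                    resblock_kernel_sizes=(3, 7, 11),
                    resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
                    upsample_rates=(8, 8, 2, 2), upsample_initial_channel=512,
                    upsample_kernel_sizes=(16, 16, 4, 4))
    z = torch.randn(1, 192, 100)
    wav = gen(z)  # [1, 1, 100 * 8 * 8 * 2 * 2] = [1, 1, 25600] samples in [-1, 1]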


7. DiscriminatorP and DiscriminatorS

These modules are part of the adversarial training setup for improving audio quality.

  • DiscriminatorP:

    • Operates on periodic segments of the audio to capture fine-grained details.
    • Uses 2D convolutions (Conv2d) to process the periodic representation.
  • DiscriminatorS:

    • Operates on the full audio signal to capture global features.
    • Uses 1D convolutions (Conv1d), several of them grouped, for efficient processing.

Both discriminators output feature maps (fmap) and classification scores, which are used to compute adversarial losses.
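
To illustrate the periodic view, the sketch below mirrors the reshape that DiscriminatorP applies before its Conv2d stack (the audio length and period are assumed values):

    period = 5
    audio = torch.randn(1, 1, 25600)  # [batch, 1, samples]
    b, c, t = audio.shape
    if t % period != 0:               # pad so the length is divisible by the period
        audio = F.pad(audio, (0, period - t % period), "reflect")
        t = audio.shape[-1]
    folded = audio.view(b, c, t // period, period)  # [1, 1, 5120, 5] fed to the 2D convolutions

    d_p = DiscriminatorP(period=period)
    score, fmap = d_p(torch.randn(1, 1, 25600))     # flattened logits plus per-layer feature maps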


8. MultiPeriodDiscriminator

This module combines multiple discriminators (DiscriminatorP and DiscriminatorS) to evaluate the generated audio at different granularities.

  • Initialization:

    • Includes one DiscriminatorS and multiple DiscriminatorP modules with different periods (2, 3, 5, 7, and 11).
  • Forward Pass:

    • Processes the real and generated audio through all discriminators.
    • Outputs classification scores and feature maps for adversarial training.

This module ensures that the generated audio is realistic across multiple temporal scales.
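
A short sketch of a discriminator pass over real and generated audio (the waveform length is an assumed value):

    mpd = MultiPeriodDiscriminator()
    y_real = torch.randn(1, 1, 8192)
    y_fake = torch.randn(1, 1, 8192)
    scores_r, scores_g, fmaps_r, fmaps_g = mpd(y_real, y_fake)
    # six entries each: one DiscriminatorS plus five DiscriminatorP (periods 2, 3, 5, 7, 11)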


9. SynthesizerTrn

The SynthesizerTrn is the main model for training the TTS system.

  • Initialization:

    • Combines the TextEncoder, PosteriorEncoder, Generator, and ResidualCouplingBlock modules.
    • Includes a StochasticDurationPredictor or a DurationPredictor (selected by use_sdp) for modeling phoneme durations.
    • Supports multi-speaker synthesis with speaker embeddings.
  • Forward Pass:

    • Encodes the input text and spectrogram into latent representations.
    • Aligns the text and audio using monotonic alignment search (monotonic_align.maximum_path) over the prior and posterior statistics.
    • Synthesizes the waveform using the Generator.
  • Inference:

    • Generates audio from text by predicting durations and sampling latent variables.
    • Supports voice conversion by transforming latent variables between speakers.

This module integrates all components into a cohesive pipeline for training and inference in TTS systems.
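
As a rough end-to-end sketch, inference might look like the following; all hyperparameter values here are illustrative assumptions rather than a recommended configuration:

    model = SynthesizerTrn(
        n_vocab=256, spec_channels=513, segment_size=32,
        inter_channels=192, hidden_channels=192, filter_channels=768,
        n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
        resblock="1", resblock_kernel_sizes=(3, 7, 11),
        resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        upsample_rates=(8, 8, 2, 2), upsample_initial_channel=512,
        upsample_kernel_sizes=(16, 16, 4, 4),
    )
    model.eval()
    phoneme_ids = torch.randint(0, 256, (1, 50))
    lengths = torch.tensor([50])
    with torch.no_grad():
        audio, attn, y_mask, _ = model.infer(phoneme_ids, lengths,
                                             noise_scale=0.667, length_scale=1.0)
    # audio: [1, 1, samples], ready to be saved or played back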


Applications

These modules collectively form the backbone of a TTS system. They are designed for tasks such as:

  • Speech Synthesis: Generating natural-sounding speech from text.
  • Voice Conversion: Transforming the voice characteristics of one speaker into those of another.
  • Multi-Speaker TTS: Synthesizing speech for multiple speakers using speaker embeddings.

The modular design allows for flexibility and extensibility, making it suitable for various speech synthesis applications.

Source Code

import math
import typing
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from . import attentions, commons, modules, monotonic_align
from .commons import get_padding, init_weights
class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float,
        n_flows: int = 4,
        gin_channels: int = 0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).type_as(x) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).type_as(x) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw

class DurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        kernel_size: int,
        p_dropout: float,
        gin_channels: int = 0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask

class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab: int,
        out_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(
            commons.sequence_mask(x_lengths, x.size(2)), 1
        ).type_as(x)
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask

class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        n_flows: int = 4,
        gin_channels: int = 0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x

class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int = 0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(
            commons.sequence_mask(x_lengths, x.size(2)), 1
        ).type_as(x)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel: int,
        resblock: typing.Optional[str],
        resblock_kernel_sizes: typing.Tuple[int, ...],
        resblock_dilation_sizes: typing.Tuple[typing.Tuple[int, ...], ...],
        upsample_rates: typing.Tuple[int, ...],
        upsample_initial_channel: int,
        upsample_kernel_sizes: typing.Tuple[int, ...],
        gin_channels: int = 0,
    ):
        super(Generator, self).__init__()
        self.LRELU_SLOPE = 0.1
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock_module = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock_module(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i, up in enumerate(self.ups):
            x = F.leaky_relu(x, self.LRELU_SLOPE)
            x = up(x)
            xs = torch.zeros(1)
            for j, resblock in enumerate(self.resblocks):
                index = j - (i * self.num_kernels)
                if index == 0:
                    xs = resblock(x)
                elif (index > 0) and (index < self.num_kernels):
                    xs += resblock(x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()

class DiscriminatorP(torch.nn.Module):
    def __init__(
        self,
        period: int,
        kernel_size: int = 5,
        stride: int = 3,
        use_spectral_norm: bool = False,
    ):
        super(DiscriminatorP, self).__init__()
        self.LRELU_SLOPE = 0.1
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap

class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        self.LRELU_SLOPE = 0.1
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap

class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        n_vocab: int,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: typing.Tuple[int, ...],
        resblock_dilation_sizes: typing.Tuple[typing.Tuple[int, ...], ...],
        upsample_rates: typing.Tuple[int, ...],
        upsample_initial_channel: int,
        upsample_kernel_sizes: typing.Tuple[int, ...],
        n_speakers: int = 1,
        gin_channels: int = 0,
        use_sdp: bool = True,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        if use_sdp:
            self.dp = StochasticDurationPredictor(
                hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
            )
        else:
            self.dp = DurationPredictor(
                hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
            )

        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, x, x_lengths, y, y_lengths, sid=None):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 1:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        w = attn.sum(2)
        if self.use_sdp:
            l_length = self.dp(x, x_mask, w, g=g)
            l_length = l_length / torch.sum(x_mask)
        else:
            logw_ = torch.log(w + 1e-6) * x_mask
            logw = self.dp(x, x_mask, g=g)
            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
                x_mask
            )  # for averaging

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        )

    def infer(
        self,
        x,
        x_lengths,
        sid=None,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
    ):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 1:
            assert sid is not None, "Missing speaker id"
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)

        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(
            commons.sequence_mask(y_lengths, y_lengths.max()), 1
        ).type_as(x_mask)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        assert self.n_speakers > 1, "n_speakers have to be larger than 1."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)