Skip to content

losses.py

Code Explained

The provided code defines several loss functions that are essential for training deep learning models, particularly in the context of generative adversarial networks (GANs) and variational models. These loss functions are designed to optimize different aspects of the model, such as feature similarity, adversarial training, and latent space regularization. Below is a detailed explanation of each function:


1. feature_loss

This function computes the feature matching loss, which measures the difference between feature maps extracted from real and generated data. It is commonly used in GANs to ensure that the generator produces outputs with similar intermediate representations to real data.

  • Inputs:
    • fmap_r: Feature maps from real data.
    • fmap_g: Feature maps from generated data.
  • Process:
    • Iterates through corresponding feature maps (dr and dg) from real and generated data.
    • For each layer, computes the mean absolute difference between the feature maps.
    • Detaches the real feature map (rl) to prevent gradients from flowing into the discriminator.
  • Output:
    • The total feature loss, scaled by a factor of 2.

This loss encourages the generator to produce outputs that align with the discriminator’s learned features, improving the perceptual quality of the generated data.


2. discriminator_loss

This function computes the loss for the discriminator in a GAN. The discriminator’s goal is to distinguish between real and generated data.

  • Inputs:
    • disc_real_outputs: Discriminator outputs for real data.
    • disc_generated_outputs: Discriminator outputs for generated data.
  • Process:
    • For each pair of real (dr) and generated (dg) outputs:
      • Computes the loss for real data as the mean squared error between the output and 1 ((1 - dr) ** 2), encouraging the discriminator to classify real data as “real.”
      • Computes the loss for generated data as the mean squared error between the output and 0 (dg ** 2), encouraging the discriminator to classify generated data as “fake.”
      • Accumulates the total loss and stores individual losses for real and generated data.
  • Output:
    • The total discriminator loss.
    • Lists of individual losses for real and generated data.

This loss ensures that the discriminator learns to differentiate real data from generated data effectively.


3. generator_loss

This function computes the loss for the generator in a GAN. The generator’s goal is to produce data that the discriminator classifies as “real.”

  • Inputs:
    • disc_outputs: Discriminator outputs for generated data.
  • Process:
    • For each discriminator output (dg):
      • Computes the loss as the mean squared error between the output and 1 ((1 - dg) ** 2), encouraging the generator to fool the discriminator.
      • Accumulates the total loss and stores individual losses.
  • Output:
    • The total generator loss.
    • A list of individual losses for each discriminator output.

This loss drives the generator to improve the quality of its outputs, making them indistinguishable from real data.


4. kl_loss

This function computes the Kullback-Leibler (KL) divergence loss, which is used in variational models to regularize the latent space.

  • Inputs:
    • z_p, logs_q: Latent variables and log variances from the posterior distribution.
    • m_p, logs_p: Mean and log variances from the prior distribution.
    • z_mask: A mask to ignore padding or irrelevant parts of the input.
  • Process:
    • Computes the KL divergence between the posterior (q) and prior (p) distributions:
      • The first term penalizes the difference in log variances.
      • The second term penalizes the squared difference in means, scaled by the prior variance.
    • Applies the mask to ignore irrelevant regions and normalizes the loss by the sum of the mask.
  • Output:
    • The normalized KL divergence loss.

This loss ensures that the latent space follows the prior distribution, which is crucial for variational inference and sampling.


Key Use Cases

  1. Feature Loss: Improves the perceptual quality of generated data by aligning feature maps.
  2. Discriminator and Generator Losses: Form the adversarial training loop in GANs, where the generator and discriminator compete to improve each other.
  3. KL Loss: Regularizes the latent space in variational models, ensuring meaningful and smooth latent representations.

These loss functions are critical for training models in tasks like image generation, speech synthesis, and other generative applications.

Source Code

import torch
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float()
dg = dg.float()
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
dg = dg.float()
l_dg = torch.mean((1 - dg) ** 2)
gen_losses.append(l_dg)
loss += l_dg
return loss, gen_losses
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l_kl = kl / torch.sum(z_mask)
return l_kl