config.py

Code Explained

The code in config.py defines four configuration classes using Python’s dataclass decorator. These classes encapsulate the parameters for mel spectrogram extraction, the model architecture, and training, keeping the codebase modular and easy to manage. Below is a detailed explanation of each class:


1. MelAudioConfig

This class defines the configuration for generating mel spectrograms, which are commonly used in audio processing tasks like speech synthesis. Key attributes include:

  • filter_length: The length of the FFT window.
  • hop_length: The step size between successive frames.
  • win_length: The size of the window applied to each frame.
  • mel_channels: The number of mel frequency bins.
  • sample_rate: The audio sampling rate.
  • mel_fmin and mel_fmax: The minimum and maximum frequencies for the mel filter bank.

These parameters control how raw audio is transformed into mel spectrograms, which serve as input features for models.
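
As a hedged illustration of what these fields control, the sketch below maps them onto a standard mel spectrogram call. It assumes the file shown in the Source Code section is importable as config, and it uses librosa purely for demonstration; the project itself may compute mel spectrograms with its own (e.g., torch-based) routines.

import librosa
import numpy as np

from config import MelAudioConfig  # module name is an assumption

mel_cfg = MelAudioConfig()  # defaults: 22050 Hz, 80 mel bins, 1024-point FFT

audio = np.zeros(mel_cfg.sample_rate, dtype=np.float32)  # 1 second of silence
mel = librosa.feature.melspectrogram(
    y=audio,
    sr=mel_cfg.sample_rate,
    n_fft=mel_cfg.filter_length,
    hop_length=mel_cfg.hop_length,
    win_length=mel_cfg.win_length,
    n_mels=mel_cfg.mel_channels,
    fmin=mel_cfg.mel_fmin,
    fmax=mel_cfg.mel_fmax,  # None means "up to the Nyquist frequency"
)
print(mel.shape)  # (mel_channels, 1 + len(audio) // hop_length) -> (80, 87)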


2. ModelAudioConfig

This class specifies the audio-related parameters for the model’s architecture. Key attributes include:

  • resblock: The type of residual block used in the model.
  • resblock_kernel_sizes: Kernel sizes for the residual blocks.
  • resblock_dilation_sizes: Dilation rates for the residual blocks.
  • upsample_rates: Upsampling factors for the model’s decoder.
  • upsample_initial_channel: The number of channels in the initial upsampling layer.
  • upsample_kernel_sizes: Kernel sizes for the upsampling layers.

The class also provides two static methods:

  • low_quality(): Returns a lighter configuration (resblock type "2", three upsampling stages, 256 initial channels) for lower-quality but faster synthesis.
  • high_quality(): Returns a larger configuration (resblock type "1", four upsampling stages, 512 initial channels) for higher-quality synthesis.

These methods allow for quick switching between different audio quality settings.
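
A small sketch of how the two presets differ in practice, again assuming the file below is importable as config. In VITS/HiFi-GAN-style decoders, the product of upsample_rates is the number of audio samples generated per mel frame, so both presets multiply out to 256, matching MelAudioConfig.hop_length.

from config import ModelAudioConfig  # module name is an assumption

low = ModelAudioConfig.low_quality()
high = ModelAudioConfig.high_quality()

def total_upsampling(cfg: ModelAudioConfig) -> int:
    # Product of the per-stage upsampling factors (samples produced per mel frame).
    factor = 1
    for rate in cfg.upsample_rates:
        factor *= rate
    return factor

print(total_upsampling(low), total_upsampling(high))                # 256 256
print(low.upsample_initial_channel, high.upsample_initial_channel)  # 256 512
print(low.resblock, high.resblock)                                  # 2 1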


3. ModelConfig

This class defines the overall configuration for the model, combining audio settings, mel spectrogram settings, and architectural parameters. Key attributes include:

  • num_symbols: The number of input symbols (e.g., phonemes or characters).
  • n_speakers: The number of speakers in the dataset.
  • audio: An instance of ModelAudioConfig.
  • mel: An instance of MelAudioConfig (created with a default_factory, so each ModelConfig gets its own instance).
  • inter_channels, hidden_channels, filter_channels: Dimensions for intermediate, hidden, and filter layers.
  • n_heads and n_layers: Parameters for multi-head attention and the number of layers in the model.
  • use_spectral_norm: Whether to apply spectral normalization.
  • gin_channels: The number of global conditioning channels, used for multi-speaker models.
  • segment_size: The size of audio segments processed by the model.

Properties:

  • The class provides several properties to expose audio-related attributes (e.g., resblock, upsample_rates) directly from the audio configuration.
  • is_multispeaker: A property that checks if the model is configured for multiple speakers.

Post-Initialization:

  • The __post_init__ method sets gin_channels to 512 when the model is multi-speaker and gin_channels was left at its single-speaker default of 0, as illustrated in the sketch below.
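
The snippet below is a minimal sketch of the post-initialization behavior and the delegating properties, assuming the file is importable as config; the values for num_symbols and n_speakers are purely illustrative.

from config import ModelAudioConfig, ModelConfig  # module name is an assumption

# Single-speaker model: gin_channels keeps its default of 0.
single = ModelConfig(
    num_symbols=256,  # illustrative value
    n_speakers=1,
    audio=ModelAudioConfig.high_quality(),
)
print(single.is_multispeaker, single.gin_channels)  # False 0

# Multi-speaker model: __post_init__ raises gin_channels to 512.
multi = ModelConfig(
    num_symbols=256,
    n_speakers=10,
    audio=ModelAudioConfig.high_quality(),
)
print(multi.is_multispeaker, multi.gin_channels)  # True 512

# The properties simply forward to the nested audio configuration.
print(multi.resblock, multi.upsample_rates)  # 1 (8, 8, 2, 2)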


4. TrainingConfig

This class encapsulates hyperparameters for training the model (a usage sketch follows the list). Key attributes include:

  • learning_rate: The initial learning rate.
  • betas and eps: Parameters for the Adam optimizer.
  • fp16_run: Whether to use mixed-precision training for faster computation.
  • lr_decay: The learning rate decay factor.
  • warmup_epochs: The number of warmup epochs for learning rate scheduling.
  • c_mel and c_kl: Coefficients for mel spectrogram loss and KL divergence loss.
  • grad_clip: The maximum gradient value for clipping (None disables clipping), preventing exploding gradients.
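
The sketch below shows how these hyperparameters would typically be wired into a PyTorch optimizer and scheduler. The training loop itself is not part of this file, so the choice of AdamW, ExponentialLR, and norm-based clipping here is an assumption, as is the module name config.

import torch

from config import TrainingConfig  # module name is an assumption

train_cfg = TrainingConfig(fp16_run=True, grad_clip=5.0)

model = torch.nn.Linear(80, 80)  # stand-in for the real model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_cfg.learning_rate,
    betas=train_cfg.betas,
    eps=train_cfg.eps,
)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=train_cfg.lr_decay)

# In this sketch, grad_clip caps the gradient norm before each optimizer step.
if train_cfg.grad_clip is not None:
    torch.nn.utils.clip_grad_norm_(model.parameters(), train_cfg.grad_clip)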

Purpose and Use Case

These configuration classes are designed to make the codebase modular and flexible. By encapsulating parameters in dedicated classes, the code becomes easier to maintain and extend. For example:

  • MelAudioConfig and ModelAudioConfig handle audio-specific settings, enabling quick adjustments for different datasets or quality requirements.
  • ModelConfig centralizes model architecture settings, ensuring consistency across components.
  • TrainingConfig simplifies the management of training hyperparameters, making it easier to experiment with different setups.

This structure is particularly useful in machine learning projects where configurations often need to be adjusted for different tasks, datasets, or hardware setups.

Source Code

"""Configuration classes"""
from dataclasses import dataclass, field
from typing import Optional, Tuple
@dataclass
class MelAudioConfig:
filter_length: int = 1024
hop_length: int = 256
win_length: int = 1024
mel_channels: int = 80
sample_rate: int = 22050
sample_bytes: int = 2
channels: int = 1
mel_fmin: float = 0.0
mel_fmax: Optional[float] = None


@dataclass
class ModelAudioConfig:
    resblock: str
    resblock_kernel_sizes: Tuple[int, ...]
    resblock_dilation_sizes: Tuple[Tuple[int, ...], ...]
    upsample_rates: Tuple[int, ...]
    upsample_initial_channel: int
    upsample_kernel_sizes: Tuple[int, ...]

    @staticmethod
    def low_quality() -> "ModelAudioConfig":
        return ModelAudioConfig(
            resblock="2",
            resblock_kernel_sizes=(3, 5, 7),
            resblock_dilation_sizes=(
                (1, 2),
                (2, 6),
                (3, 12),
            ),
            upsample_rates=(8, 8, 4),
            upsample_initial_channel=256,
            upsample_kernel_sizes=(16, 16, 8),
        )

    @staticmethod
    def high_quality() -> "ModelAudioConfig":
        return ModelAudioConfig(
            resblock="1",
            resblock_kernel_sizes=(3, 7, 11),
            resblock_dilation_sizes=(
                (1, 3, 5),
                (1, 3, 5),
                (1, 3, 5),
            ),
            upsample_rates=(8, 8, 2, 2),
            upsample_initial_channel=512,
            upsample_kernel_sizes=(16, 16, 4, 4),
        )


@dataclass
class ModelConfig:
    num_symbols: int
    n_speakers: int
    audio: ModelAudioConfig
    mel: MelAudioConfig = field(default_factory=MelAudioConfig)
    inter_channels: int = 192
    hidden_channels: int = 192
    filter_channels: int = 768
    n_heads: int = 2
    n_layers: int = 6
    kernel_size: int = 3
    p_dropout: float = 0.1
    n_layers_q: int = 3
    use_spectral_norm: bool = False
    gin_channels: int = 0  # single speaker
    use_sdp: bool = True  # StochasticDurationPredictor
    segment_size: int = 8192

    @property
    def is_multispeaker(self) -> bool:
        return self.n_speakers > 1

    @property
    def resblock(self) -> str:
        return self.audio.resblock

    @property
    def resblock_kernel_sizes(self) -> Tuple[int, ...]:
        return self.audio.resblock_kernel_sizes

    @property
    def resblock_dilation_sizes(self) -> Tuple[Tuple[int, ...], ...]:
        return self.audio.resblock_dilation_sizes

    @property
    def upsample_rates(self) -> Tuple[int, ...]:
        return self.audio.upsample_rates

    @property
    def upsample_initial_channel(self) -> int:
        return self.audio.upsample_initial_channel

    @property
    def upsample_kernel_sizes(self) -> Tuple[int, ...]:
        return self.audio.upsample_kernel_sizes

    def __post_init__(self):
        if self.is_multispeaker and (self.gin_channels == 0):
            self.gin_channels = 512


@dataclass
class TrainingConfig:
    learning_rate: float = 2e-4
    betas: Tuple[float, float] = field(default=(0.8, 0.99))
    eps: float = 1e-9
    # batch_size: int = 32
    fp16_run: bool = False
    lr_decay: float = 0.999875
    init_lr_ratio: float = 1.0
    warmup_epochs: int = 0
    c_mel: int = 45
    c_kl: float = 1.0
    grad_clip: Optional[float] = None


# @dataclass
# class PhonemesConfig(DataClassJsonMixin):
#     phoneme_separator: str = " "
#     """Separator between individual phonemes in CSV input"""
#     word_separator: str = "#"
#     """Separator between word phonemes in CSV input (must not match phoneme_separator)"""
#     phoneme_to_id: typing.Optional[typing.Dict[str, int]] = None
#     pad: typing.Optional[str] = "_"
#     bos: typing.Optional[str] = None
#     eos: typing.Optional[str] = None
#     blank: typing.Optional[str] = "#"
#     blank_word: typing.Optional[str] = None
#     blank_between: typing.Union[str, BlankBetween] = BlankBetween.WORDS
#     blank_at_start: bool = True
#     blank_at_end: bool = True
#     simple_punctuation: bool = True
#     punctuation_map: typing.Optional[typing.Dict[str, str]] = None
#     separate: typing.Optional[typing.List[str]] = None
#     separate_graphemes: bool = False
#     separate_tones: bool = False
#     tone_before: bool = False
#     phoneme_map: typing.Optional[typing.Dict[str, typing.List[str]]] = None
#     auto_bos_eos: bool = False
#     minor_break: typing.Optional[str] = IPA.BREAK_MINOR.value
#     major_break: typing.Optional[str] = IPA.BREAK_MAJOR.value
#     break_phonemes_into_graphemes: bool = False
#     break_phonemes_into_codepoints: bool = False
#     drop_stress: bool = False
#     symbols: typing.Optional[typing.List[str]] = None
#
#     def split_word_phonemes(self, phonemes_str: str) -> typing.List[typing.List[str]]:
#         """Split phonemes string into a list of lists (outer is words, inner is individual phonemes in each word)"""
#         return [
#             word_phonemes_str.split(self.phoneme_separator)
#             if self.phoneme_separator
#             else list(word_phonemes_str)
#             for word_phonemes_str in phonemes_str.split(self.word_separator)
#         ]
#
#     def join_word_phonemes(self, word_phonemes: typing.List[typing.List[str]]) -> str:
#         """Join a list of lists of phonemes back into a single string"""
#         return self.word_separator.join(
#             self.phoneme_separator.join(wp) for wp in word_phonemes
#         )
#
# class Phonemizer(str, Enum):
#     SYMBOLS = "symbols"
#     GRUUT = "gruut"
#     ESPEAK = "espeak"
#     EPITRAN = "epitran"
#
# class Aligner(str, Enum):
#     KALDI_ALIGN = "kaldi_align"
#
# class TextCasing(str, Enum):
#     LOWER = "lower"
#     UPPER = "upper"
#
# class MetadataFormat(str, Enum):
#     TEXT = "text"
#     PHONEMES = "phonemes"
#     PHONEME_IDS = "ids"
#
# @dataclass
# class DatasetConfig:
#     name: str
#     metadata_format: MetadataFormat = MetadataFormat.TEXT
#     multispeaker: bool = False
#     text_language: typing.Optional[str] = None
#     audio_dir: typing.Optional[typing.Union[str, Path]] = None
#     cache_dir: typing.Optional[typing.Union[str, Path]] = None
#
#     def get_cache_dir(self, output_dir: typing.Union[str, Path]) -> Path:
#         if self.cache_dir is not None:
#             cache_dir = Path(self.cache_dir)
#         else:
#             cache_dir = Path("cache") / self.name
#         if not cache_dir.is_absolute():
#             cache_dir = Path(output_dir) / str(cache_dir)
#         return cache_dir
#
# @dataclass
# class AlignerConfig:
#     aligner: typing.Optional[Aligner] = None
#     casing: typing.Optional[TextCasing] = None
#
# @dataclass
# class InferenceConfig:
#     length_scale: float = 1.0
#     noise_scale: float = 0.667
#     noise_w: float = 0.8
#
# @dataclass
# class TrainingConfig(DataClassJsonMixin):
#     seed: int = 1234
#     epochs: int = 10000
#     learning_rate: float = 2e-4
#     betas: typing.Tuple[float, float] = field(default=(0.8, 0.99))
#     eps: float = 1e-9
#     batch_size: int = 32
#     fp16_run: bool = False
#     lr_decay: float = 0.999875
#     segment_size: int = 8192
#     init_lr_ratio: float = 1.0
#     warmup_epochs: int = 0
#     c_mel: int = 45
#     c_kl: float = 1.0
#     grad_clip: typing.Optional[float] = None
#     min_seq_length: typing.Optional[int] = None
#     max_seq_length: typing.Optional[int] = None
#     min_spec_length: typing.Optional[int] = None
#     max_spec_length: typing.Optional[int] = None
#     min_speaker_utterances: typing.Optional[int] = None
#     last_epoch: int = 1
#     global_step: int = 1
#     best_loss: typing.Optional[float] = None
#     audio: AudioConfig = field(default_factory=AudioConfig)
#     model: ModelConfig = field(default_factory=ModelConfig)
#     phonemes: PhonemesConfig = field(default_factory=PhonemesConfig)
#     text_aligner: AlignerConfig = field(default_factory=AlignerConfig)
#     text_language: typing.Optional[str] = None
#     phonemizer: typing.Optional[Phonemizer] = None
#     datasets: typing.List[DatasetConfig] = field(default_factory=list)
#     inference: InferenceConfig = field(default_factory=InferenceConfig)
#     version: int = 1
#     git_commit: str = ""
#
#     @property
#     def is_multispeaker(self):
#         return self.model.is_multispeaker or any(d.multispeaker for d in self.datasets)
#
#     def save(self, config_file: typing.TextIO):
#         """Save config as JSON to a file"""
#         json.dump(self.to_dict(), config_file, indent=4)
#
#     def get_speaker_id(self, dataset_name: str, speaker_name: str) -> int:
#         if self.speaker_id_map is None:
#             self.speaker_id_map = {}
#         full_speaker_name = f"{dataset_name}_{speaker_name}"
#         speaker_id = self.speaker_id_map.get(full_speaker_name)
#         if speaker_id is None:
#             speaker_id = len(self.speaker_id_map)
#             self.speaker_id_map[full_speaker_name] = speaker_id
#         return speaker_id
#
#     @staticmethod
#     def load(config_file: typing.TextIO) -> "TrainingConfig":
#         """Load config from a JSON file"""
#         return TrainingConfig.from_json(config_file.read())
#
#     @staticmethod
#     def load_and_merge(
#         config: "TrainingConfig",
#         config_files: typing.Iterable[typing.Union[str, Path, typing.TextIO]],
#     ) -> "TrainingConfig":
#         """Loads one or more JSON configuration files and overlays them on top of an existing config"""
#         base_dict = config.to_dict()
#         for maybe_config_file in config_files:
#             if isinstance(maybe_config_file, (str, Path)):
#                 # File path
#                 config_file = open(maybe_config_file, "r", encoding="utf-8")
#             else:
#                 # File object
#                 config_file = maybe_config_file
#             with config_file:
#                 # Load new config and overlay on existing config
#                 new_dict = json.load(config_file)
#                 TrainingConfig.recursive_update(base_dict, new_dict)
#         return TrainingConfig.from_dict(base_dict)
#
#     @staticmethod
#     def recursive_update(
#         base_dict: typing.Dict[typing.Any, typing.Any],
#         new_dict: typing.Mapping[typing.Any, typing.Any],
#     ) -> None:
#         """Recursively overwrites values in base dictionary with values from new dictionary"""
#         for key, value in new_dict.items():
#             if isinstance(value, collections.Mapping) and (
#                 base_dict.get(key) is not None
#             ):
#                 TrainingConfig.recursive_update(base_dict[key], value)
#             else:
#                 base_dict[key] = value