config.py
Code Explained
The provided code defines several configuration classes using Python’s dataclass decorator. These classes encapsulate various parameters for audio processing, model architecture, and training, making the codebase modular and easier to manage. Below is a detailed explanation of each class:
1. MelAudioConfig
This class defines the configuration for generating mel spectrograms, which are commonly used in audio processing tasks like speech synthesis. Key attributes include:
- filter_length: The length of the FFT window.
- hop_length: The step size between successive frames.
- win_length: The size of the window applied to each frame.
- mel_channels: The number of mel frequency bins.
- sample_rate: The audio sampling rate.
- sample_bytes and channels: The sample width in bytes (2, i.e. 16-bit) and the number of audio channels (1, i.e. mono).
- mel_fmin and mel_fmax: The minimum and maximum frequencies for the mel filter bank.
These parameters control how raw audio is transformed into mel spectrograms, which serve as input features for models.
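To make the mapping concrete, here is a minimal sketch of how these fields would typically be passed to a mel spectrogram routine, using librosa for illustration (the project itself may compute mels differently; the import path and audio file name are assumptions):

```python
import librosa

from config import MelAudioConfig  # assumes config.py is on the import path

cfg = MelAudioConfig()

# Load mono audio at the configured sampling rate (hypothetical file name).
audio, _ = librosa.load("example.wav", sr=cfg.sample_rate, mono=True)

# Map the config fields onto librosa's mel spectrogram parameters.
mel = librosa.feature.melspectrogram(
    y=audio,
    sr=cfg.sample_rate,
    n_fft=cfg.filter_length,
    hop_length=cfg.hop_length,
    win_length=cfg.win_length,
    n_mels=cfg.mel_channels,
    fmin=cfg.mel_fmin,
    fmax=cfg.mel_fmax,  # None lets librosa default to sample_rate / 2
)
print(mel.shape)  # (mel_channels, n_frames), i.e. (80, n_frames) by default
```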
2. ModelAudioConfig
This class specifies the audio-related parameters for the model’s architecture. Key attributes include:
- resblock: The type of residual block used in the model.
- resblock_kernel_sizes: Kernel sizes for the residual blocks.
- resblock_dilation_sizes: Dilation rates for the residual blocks.
- upsample_rates: Upsampling factors for the model's decoder.
- upsample_initial_channel: The number of channels in the initial upsampling layer.
- upsample_kernel_sizes: Kernel sizes for the upsampling layers.
The class also provides two static methods:
- low_quality(): Returns a configuration optimized for lower-quality audio, with smaller kernel sizes and fewer channels.
- high_quality(): Returns a configuration optimized for higher-quality audio, with larger kernel sizes and more channels.
These methods allow for quick switching between different audio quality settings.
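A useful consistency check on both presets (a hypothetical usage sketch, assuming config.py is importable): the product of upsample_rates is 256 in each case, matching the default hop_length of MelAudioConfig, so the decoder turns one mel frame into exactly one hop of audio samples.

```python
from math import prod

from config import MelAudioConfig, ModelAudioConfig  # assumed import path

for preset in (ModelAudioConfig.low_quality(), ModelAudioConfig.high_quality()):
    # 8 * 8 * 4 == 256 and 8 * 8 * 2 * 2 == 256, matching hop_length
    assert prod(preset.upsample_rates) == MelAudioConfig().hop_length
```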
3. ModelConfig
This class defines the overall configuration for the model, combining audio settings, mel spectrogram settings, and architectural parameters. Key attributes include:
- num_symbols: The number of input symbols (e.g., phonemes or characters).
- n_speakers: The number of speakers in the dataset.
- audio: An instance of ModelAudioConfig.
- mel: An instance of MelAudioConfig (defaults to a new instance).
- inter_channels, hidden_channels, filter_channels: Dimensions for intermediate, hidden, and filter layers.
- n_heads and n_layers: Parameters for multi-head attention and the number of layers in the model.
- kernel_size and p_dropout: The convolution kernel size and dropout probability.
- use_spectral_norm: Whether to apply spectral normalization.
- use_sdp: Whether to use the stochastic duration predictor.
- gin_channels: The number of global conditioning channels, used for multi-speaker models.
- segment_size: The size of audio segments processed by the model.
Properties:
- The class exposes audio-related attributes (e.g., resblock, upsample_rates) as pass-through properties that read directly from the audio configuration.
- is_multispeaker: A property that returns True when n_speakers is greater than 1.
Post-Initialization:
- The __post_init__ method sets gin_channels to 512 when the model is multi-speaker and gin_channels was left at its default of 0, as shown in the sketch below.
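The following sketch (assuming config.py is importable) exercises both the pass-through properties and the __post_init__ logic:

```python
from config import ModelAudioConfig, ModelConfig  # assumed import path

cfg = ModelConfig(
    num_symbols=256,  # e.g., the size of the phoneme alphabet
    n_speakers=4,     # more than one speaker makes this a multi-speaker model
    audio=ModelAudioConfig.high_quality(),
)

assert cfg.is_multispeaker                 # n_speakers > 1
assert cfg.gin_channels == 512             # promoted from 0 by __post_init__
assert cfg.upsample_rates == (8, 8, 2, 2)  # pass-through to cfg.audio
assert cfg.resblock == "1"                 # likewise delegated
```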
4. TrainingConfig
This class encapsulates hyperparameters for training the model. Key attributes include:
- learning_rate: The initial learning rate.
- betas and eps: Parameters for the Adam optimizer.
- fp16_run: Whether to use mixed-precision training for faster computation.
- lr_decay: The learning rate decay factor.
- warmup_epochs: The number of warmup epochs for learning rate scheduling.
- c_mel and c_kl: Coefficients for the mel spectrogram loss and KL divergence loss.
- grad_clip: The maximum gradient value for clipping, preventing exploding gradients.
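One plausible way these hyperparameters plug into a PyTorch training loop is sketched below. The optimizer, scheduler, gradient scaler, and dummy model are illustrative assumptions, not the project's actual training code:

```python
import torch

from config import TrainingConfig  # assumed import path

cfg = TrainingConfig(grad_clip=5.0)
model = torch.nn.Linear(80, 80)  # stand-in for the real model

optimizer = torch.optim.AdamW(
    model.parameters(), lr=cfg.learning_rate, betas=cfg.betas, eps=cfg.eps
)
# lr_decay acts as a per-epoch multiplicative factor on the learning rate.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.lr_decay)
scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16_run)  # mixed precision if fp16_run

# Dummy loss; in practice this would weight terms like c_mel * mel_loss + c_kl * kl_loss.
loss = model(torch.randn(4, 80)).pow(2).mean()
scaler.scale(loss).backward()
if cfg.grad_clip is not None:
    scaler.unscale_(optimizer)  # clip the true gradients, not the scaled ones
    torch.nn.utils.clip_grad_value_(model.parameters(), cfg.grad_clip)
scaler.step(optimizer)
scaler.update()
scheduler.step()  # typically called once per epoch
```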
Purpose and Use Case
These configuration classes are designed to make the codebase modular and flexible. By encapsulating parameters in dedicated classes, the code becomes easier to maintain and extend. For example:
- MelAudioConfig and ModelAudioConfig handle audio-specific settings, enabling quick adjustments for different datasets or quality requirements.
- ModelConfig centralizes model architecture settings, ensuring consistency across components.
- TrainingConfig simplifies the management of training hyperparameters, making it easier to experiment with different setups.
This structure is particularly useful in machine learning projects where configurations often need to be adjusted for different tasks, datasets, or hardware setups.
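Because these are plain dataclasses, experiment variants can also be derived without mutating a shared base configuration. A small hypothetical sketch using dataclasses.replace:

```python
from dataclasses import replace

from config import TrainingConfig  # assumed import path

base = TrainingConfig()
# Derive a tweaked copy for one experiment; `base` itself is untouched.
fast_fp16 = replace(base, learning_rate=1e-4, fp16_run=True, grad_clip=1.0)
```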
Source Code

```python
"""Configuration classes"""
from dataclasses import dataclass, field
from typing import Optional, Tuple


@dataclass
class MelAudioConfig:
    filter_length: int = 1024
    hop_length: int = 256
    win_length: int = 1024
    mel_channels: int = 80
    sample_rate: int = 22050
    sample_bytes: int = 2
    channels: int = 1
    mel_fmin: float = 0.0
    mel_fmax: Optional[float] = None


@dataclass
class ModelAudioConfig:
    resblock: str
    resblock_kernel_sizes: Tuple[int, ...]
    resblock_dilation_sizes: Tuple[Tuple[int, ...], ...]
    upsample_rates: Tuple[int, ...]
    upsample_initial_channel: int
    upsample_kernel_sizes: Tuple[int, ...]

    @staticmethod
    def low_quality() -> "ModelAudioConfig":
        return ModelAudioConfig(
            resblock="2",
            resblock_kernel_sizes=(3, 5, 7),
            resblock_dilation_sizes=(
                (1, 2),
                (2, 6),
                (3, 12),
            ),
            upsample_rates=(8, 8, 4),
            upsample_initial_channel=256,
            upsample_kernel_sizes=(16, 16, 8),
        )

    @staticmethod
    def high_quality() -> "ModelAudioConfig":
        return ModelAudioConfig(
            resblock="1",
            resblock_kernel_sizes=(3, 7, 11),
            resblock_dilation_sizes=(
                (1, 3, 5),
                (1, 3, 5),
                (1, 3, 5),
            ),
            upsample_rates=(8, 8, 2, 2),
            upsample_initial_channel=512,
            upsample_kernel_sizes=(16, 16, 4, 4),
        )


@dataclass
class ModelConfig:
    num_symbols: int
    n_speakers: int
    audio: ModelAudioConfig
    mel: MelAudioConfig = field(default_factory=MelAudioConfig)

    inter_channels: int = 192
    hidden_channels: int = 192
    filter_channels: int = 768
    n_heads: int = 2
    n_layers: int = 6
    kernel_size: int = 3
    p_dropout: float = 0.1
    n_layers_q: int = 3
    use_spectral_norm: bool = False
    gin_channels: int = 0  # single speaker
    use_sdp: bool = True  # StochasticDurationPredictor
    segment_size: int = 8192

    @property
    def is_multispeaker(self) -> bool:
        return self.n_speakers > 1

    @property
    def resblock(self) -> str:
        return self.audio.resblock

    @property
    def resblock_kernel_sizes(self) -> Tuple[int, ...]:
        return self.audio.resblock_kernel_sizes

    @property
    def resblock_dilation_sizes(self) -> Tuple[Tuple[int, ...], ...]:
        return self.audio.resblock_dilation_sizes

    @property
    def upsample_rates(self) -> Tuple[int, ...]:
        return self.audio.upsample_rates

    @property
    def upsample_initial_channel(self) -> int:
        return self.audio.upsample_initial_channel

    @property
    def upsample_kernel_sizes(self) -> Tuple[int, ...]:
        return self.audio.upsample_kernel_sizes

    def __post_init__(self):
        if self.is_multispeaker and (self.gin_channels == 0):
            self.gin_channels = 512


@dataclass
class TrainingConfig:
    learning_rate: float = 2e-4
    betas: Tuple[float, float] = field(default=(0.8, 0.99))
    eps: float = 1e-9
    # batch_size: int = 32
    fp16_run: bool = False
    lr_decay: float = 0.999875
    init_lr_ratio: float = 1.0
    warmup_epochs: int = 0
    c_mel: int = 45
    c_kl: float = 1.0
    grad_clip: Optional[float] = None


# @dataclass
# class PhonemesConfig(DataClassJsonMixin):
#     phoneme_separator: str = " "
#     """Separator between individual phonemes in CSV input"""

#     word_separator: str = "#"
#     """Separator between word phonemes in CSV input (must not match phoneme_separator)"""

#     phoneme_to_id: typing.Optional[typing.Dict[str, int]] = None
#     pad: typing.Optional[str] = "_"
#     bos: typing.Optional[str] = None
#     eos: typing.Optional[str] = None
#     blank: typing.Optional[str] = "#"
#     blank_word: typing.Optional[str] = None
#     blank_between: typing.Union[str, BlankBetween] = BlankBetween.WORDS
#     blank_at_start: bool = True
#     blank_at_end: bool = True
#     simple_punctuation: bool = True
#     punctuation_map: typing.Optional[typing.Dict[str, str]] = None
#     separate: typing.Optional[typing.List[str]] = None
#     separate_graphemes: bool = False
#     separate_tones: bool = False
#     tone_before: bool = False
#     phoneme_map: typing.Optional[typing.Dict[str, typing.List[str]]] = None
#     auto_bos_eos: bool = False
#     minor_break: typing.Optional[str] = IPA.BREAK_MINOR.value
#     major_break: typing.Optional[str] = IPA.BREAK_MAJOR.value
#     break_phonemes_into_graphemes: bool = False
#     break_phonemes_into_codepoints: bool = False
#     drop_stress: bool = False
#     symbols: typing.Optional[typing.List[str]] = None

#     def split_word_phonemes(self, phonemes_str: str) -> typing.List[typing.List[str]]:
#         """Split phonemes string into a list of lists (outer is words, inner is individual phonemes in each word)"""
#         return [
#             word_phonemes_str.split(self.phoneme_separator)
#             if self.phoneme_separator
#             else list(word_phonemes_str)
#             for word_phonemes_str in phonemes_str.split(self.word_separator)
#         ]

#     def join_word_phonemes(self, word_phonemes: typing.List[typing.List[str]]) -> str:
#         """Split phonemes string into a list of lists (outer is words, inner is individual phonemes in each word)"""
#         return self.word_separator.join(
#             self.phoneme_separator.join(wp) for wp in word_phonemes
#         )


# class Phonemizer(str, Enum):
#     SYMBOLS = "symbols"
#     GRUUT = "gruut"
#     ESPEAK = "espeak"
#     EPITRAN = "epitran"


# class Aligner(str, Enum):
#     KALDI_ALIGN = "kaldi_align"


# class TextCasing(str, Enum):
#     LOWER = "lower"
#     UPPER = "upper"


# class MetadataFormat(str, Enum):
#     TEXT = "text"
#     PHONEMES = "phonemes"
#     PHONEME_IDS = "ids"


# @dataclass
# class DatasetConfig:
#     name: str
#     metadata_format: MetadataFormat = MetadataFormat.TEXT
#     multispeaker: bool = False
#     text_language: typing.Optional[str] = None
#     audio_dir: typing.Optional[typing.Union[str, Path]] = None
#     cache_dir: typing.Optional[typing.Union[str, Path]] = None

#     def get_cache_dir(self, output_dir: typing.Union[str, Path]) -> Path:
#         if self.cache_dir is not None:
#             cache_dir = Path(self.cache_dir)
#         else:
#             cache_dir = Path("cache") / self.name

#         if not cache_dir.is_absolute():
#             cache_dir = Path(output_dir) / str(cache_dir)

#         return cache_dir


# @dataclass
# class AlignerConfig:
#     aligner: typing.Optional[Aligner] = None
#     casing: typing.Optional[TextCasing] = None


# @dataclass
# class InferenceConfig:
#     length_scale: float = 1.0
#     noise_scale: float = 0.667
#     noise_w: float = 0.8


# @dataclass
# class TrainingConfig(DataClassJsonMixin):
#     seed: int = 1234
#     epochs: int = 10000
#     learning_rate: float = 2e-4
#     betas: typing.Tuple[float, float] = field(default=(0.8, 0.99))
#     eps: float = 1e-9
#     batch_size: int = 32
#     fp16_run: bool = False
#     lr_decay: float = 0.999875
#     segment_size: int = 8192
#     init_lr_ratio: float = 1.0
#     warmup_epochs: int = 0
#     c_mel: int = 45
#     c_kl: float = 1.0
#     grad_clip: typing.Optional[float] = None

#     min_seq_length: typing.Optional[int] = None
#     max_seq_length: typing.Optional[int] = None

#     min_spec_length: typing.Optional[int] = None
#     max_spec_length: typing.Optional[int] = None

#     min_speaker_utterances: typing.Optional[int] = None

#     last_epoch: int = 1
#     global_step: int = 1
#     best_loss: typing.Optional[float] = None
#     audio: AudioConfig = field(default_factory=AudioConfig)
#     model: ModelConfig = field(default_factory=ModelConfig)
#     phonemes: PhonemesConfig = field(default_factory=PhonemesConfig)
#     text_aligner: AlignerConfig = field(default_factory=AlignerConfig)
#     text_language: typing.Optional[str] = None
#     phonemizer: typing.Optional[Phonemizer] = None
#     datasets: typing.List[DatasetConfig] = field(default_factory=list)
#     inference: InferenceConfig = field(default_factory=InferenceConfig)

#     version: int = 1
#     git_commit: str = ""

#     @property
#     def is_multispeaker(self):
#         return self.model.is_multispeaker or any(d.multispeaker for d in self.datasets)

#     def save(self, config_file: typing.TextIO):
#         """Save config as JSON to a file"""
#         json.dump(self.to_dict(), config_file, indent=4)

#     def get_speaker_id(self, dataset_name: str, speaker_name: str) -> int:
#         if self.speaker_id_map is None:
#             self.speaker_id_map = {}

#         full_speaker_name = f"{dataset_name}_{speaker_name}"
#         speaker_id = self.speaker_id_map.get(full_speaker_name)
#         if speaker_id is None:
#             speaker_id = len(self.speaker_id_map)
#             self.speaker_id_map[full_speaker_name] = speaker_id

#         return speaker_id

#     @staticmethod
#     def load(config_file: typing.TextIO) -> "TrainingConfig":
#         """Load config from a JSON file"""
#         return TrainingConfig.from_json(config_file.read())

#     @staticmethod
#     def load_and_merge(
#         config: "TrainingConfig",
#         config_files: typing.Iterable[typing.Union[str, Path, typing.TextIO]],
#     ) -> "TrainingConfig":
#         """Loads one or more JSON configuration files and overlays them on top of an existing config"""
#         base_dict = config.to_dict()
#         for maybe_config_file in config_files:
#             if isinstance(maybe_config_file, (str, Path)):
#                 # File path
#                 config_file = open(maybe_config_file, "r", encoding="utf-8")
#             else:
#                 # File object
#                 config_file = maybe_config_file

#             with config_file:
#                 # Load new config and overlay on existing config
#                 new_dict = json.load(config_file)
#                 TrainingConfig.recursive_update(base_dict, new_dict)

#         return TrainingConfig.from_dict(base_dict)

#     @staticmethod
#     def recursive_update(
#         base_dict: typing.Dict[typing.Any, typing.Any],
#         new_dict: typing.Mapping[typing.Any, typing.Any],
#     ) -> None:
#         """Recursively overwrites values in base dictionary with values from new dictionary"""
#         for key, value in new_dict.items():
#             if isinstance(value, collections.Mapping) and (
#                 base_dict.get(key) is not None
#             ):
#                 TrainingConfig.recursive_update(base_dict[key], value)
#             else:
#                 base_dict[key] = value
```