高级配置 #

分布式训练 #

多 GPU 训练 #

bash
# Multi-GPU training.
# NOTE(review): the `tts` command is Coqui's synthesis (inference) CLI and does
# not train models or accept --gpus; multi-GPU training goes through the
# `trainer` package launcher. Adjust the script path to your installation.
# The dataset location is taken from config.json ("datasets" -> "path").
python -m trainer.distribute --script TTS/bin/train_tts.py \
    --gpus "0,1,2,3" \
    --config_path config.json

# Restrict which GPUs the launcher can see (indices are relative to this list)
CUDA_VISIBLE_DEVICES=0,1 python -m trainer.distribute --script TTS/bin/train_tts.py \
    --gpus "0,1" \
    --config_path config.json

分布式配置 #

json
{
    "distributed": {
        "backend": "nccl",
        "url": "tcp://localhost:54321",
        "world_size": 4,
        "rank": 0
    },
    
    "training": {
        "batch_size": 32,
        "num_loader_workers": 8
    }
}

PyTorch DDP 示例 #

python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed(rank, world_size, backend="nccl", init_method="tcp://localhost:54321"):
    """Join this process to the default ``torch.distributed`` process group.

    Args:
        rank: Index of this process within the group (0 .. world_size - 1).
        world_size: Total number of processes in the group.
        backend: Communication backend; ``"nccl"`` for multi-GPU training,
            ``"gloo"`` for CPU-only machines.
        init_method: Rendezvous URL that every process in the group must share.
    """
    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        rank=rank,
        world_size=world_size,
    )

def train_distributed(rank, world_size, config):
    """Per-process DDP training entry point (the target passed to ``mp.spawn``).

    Args:
        rank: Process index; also used as this process's CUDA device index.
        world_size: Total number of processes / GPUs.
        config: Training config. ``config.epochs`` is read; ``config.learning_rate``
            is used when present (1e-3 otherwise).
    """
    setup_distributed(rank, world_size)
    try:
        # Bind this process to its own GPU so NCCL collectives and any
        # implicit CUDA allocations land on the right device.
        torch.cuda.set_device(rank)

        # Build the model on this rank's GPU and wrap it for gradient all-reduce.
        model = create_model(config).to(rank)
        model = DDP(model, device_ids=[rank])

        # Optimizer over the DDP-wrapped parameters.
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=getattr(config, "learning_rate", 1e-3)
        )

        train_loader = create_dataloader(config, rank, world_size)
        sampler = getattr(train_loader, "sampler", None)

        for epoch in range(config.epochs):
            # A DistributedSampler must be told the epoch, otherwise every
            # epoch sees the same shuffle order.
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)
            for batch in train_loader:
                # NOTE(review): assumes model(batch) returns a scalar loss and
                # that batches are already on this rank's device — confirm.
                optimizer.zero_grad()
                loss = model(batch)
                loss.backward()
                optimizer.step()
    finally:
        # Leave the process group even if training raises.
        dist.destroy_process_group()

# Launch one training process per visible GPU.
import torch.multiprocessing as mp

if __name__ == "__main__":
    # The guard is required: mp.spawn starts children with the "spawn" method,
    # which re-imports this module; without it every child would spawn again.
    # NOTE(review): `config` must be defined before this point (not shown here).
    world_size = torch.cuda.device_count()
    mp.spawn(train_distributed, args=(world_size, config), nprocs=world_size)

自定义模型 #

创建自定义模型 #

python
import torch
import torch.nn as nn
from TTS.tts.models.base_tts import BaseTTS

class CustomTTSModel(BaseTTS):
    """Toy encoder/decoder TTS model built on Coqui's ``BaseTTS``.

    The encoder maps 512 -> 256 -> 128 features and the decoder maps
    128 -> 256 -> 512, both acting on the last dimension of the input.
    An optional ``vocoder`` callable turns the decoder output into audio
    in :meth:`inference`.
    """

    def __init__(self, config):
        # NOTE(review): some TTS releases define BaseTTS.__init__ with required
        # arguments (config, ap, tokenizer, ...); confirm this zero-argument
        # call against the installed version.
        super().__init__()
        self.config = config

        # Encoder: 512 -> 128 through a 256-unit hidden layer.
        self.encoder = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
        )

        # Decoder: mirror of the encoder, 128 -> 512.
        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
        )

        # Optional vocoder hook; stays None until a caller assigns one.
        self.vocoder = None

    def forward(self, text, mel_target=None):
        """Encode then decode ``text``.

        Args:
            text: Float tensor whose last dimension is 512 (required by the
                first ``nn.Linear``); despite the name it is not raw text.
            mel_target: Ignored by this implementation.

        Returns:
            Tensor with the same leading dimensions and a last dimension of 512.
        """
        encoded = self.encoder(text)
        return self.decoder(encoded)

    def inference(self, text):
        """Run the model without autograd; vocode the output if a vocoder is set."""
        with torch.no_grad():
            mel = self.forward(text)
            # Compare with None explicitly: container modules such as an empty
            # nn.Sequential are falsy and would silently skip vocoding.
            if self.vocoder is not None:
                return self.vocoder(mel)
            return mel

    def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True):
        """Load weights from ``checkpoint_path`` into this model.

        Accepts either a bare ``state_dict`` or a training checkpoint dict that
        stores the weights under the ``"model"`` key.

        Args:
            config: Unused; kept for signature compatibility.
            checkpoint_path: Path passed to ``torch.load``.
            eval: If True, switch to evaluation mode after loading. (The name
                mirrors the keyword used by TTS model classes and shadows the
                builtin only inside this method.)
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``self``, to allow chaining.
        """
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            checkpoint = checkpoint["model"]
        self.load_state_dict(checkpoint, strict=strict)
        if eval:
            self.eval()
        return self

# 注册模型
from TTS.tts.models import setup_model

def register_custom_model():
    """Add ``CustomTTSModel`` to ``BaseTTS.MODEL_REGISTRY`` under ``"custom_tts"``.

    NOTE(review): confirm that the installed TTS version exposes a
    ``MODEL_REGISTRY`` mapping on ``BaseTTS``; if it does not, this raises
    ``AttributeError``.
    """
    from TTS.tts.models import base_tts

    registry = base_tts.BaseTTS.MODEL_REGISTRY
    registry["custom_tts"] = CustomTTSModel

自定义配置 #

python
from dataclasses import dataclass
from TTS.config import BaseDatasetConfig, BaseTrainingConfig

@dataclass
class CustomModelConfig:
    """Plain configuration container for the custom TTS model.

    Field order defines the generated ``__init__`` signature; keep it stable.
    NOTE(review): CustomTTSModel hard-codes its layer sizes and does not read
    the architecture fields below.
    """

    # Name/key identifying the model.
    model: str = "custom_tts"

    # --- architecture ---
    encoder_hidden: int = 256
    decoder_hidden: int = 256
    num_layers: int = 4

    # --- audio front-end ---
    sample_rate: int = 22050
    num_mels: int = 80

    # --- optimisation ---
    batch_size: int = 16
    learning_rate: float = 0.001
    epochs: int = 100

# Usage: build a config with its default values and pass it to the model.
config = CustomModelConfig()
model = CustomTTSModel(config)

自定义音素转换器 #

创建自定义音素转换器 #

python
from TTS.tts.utils.text.phonemizers import BasePhonemizer

class CustomPhonemizer(BasePhonemizer):
    """Toy phonemizer that spells each whitespace-separated word as its letters.

    With the default separator, ``"Hi there"`` becomes ``"h i|t h e r e"``.

    Besides ``supported_languages``, this implements the other hooks that
    Coqui's ``BasePhonemizer`` declares abstract (``name``, ``is_available``,
    ``version``, ``_phonemize``); without them the class cannot be
    instantiated. NOTE(review): re-check the hook list against the installed
    TTS version.
    """

    def __init__(self, language="en"):
        super().__init__(language)
        # Punctuation characters known to this phonemizer. The letter mapping
        # below does not strip them; they are spelled like any other character.
        self.punctuation = "!.?,;:"

    @staticmethod
    def name():
        """Backend name; matches the key used when registering in PHONEMIZERS."""
        return "custom"

    @classmethod
    def is_available(cls):
        """Always True: the mapping is pure Python with no external backend."""
        return True

    @classmethod
    def version(cls):
        """Version string of this toy backend."""
        return "0.1.0"

    def phonemize(self, text, separator="|", language=None):
        """Convert ``text`` to phonemes, one group per word joined by ``separator``.

        ``language`` is accepted so callers can pass it as they would to other
        phonemizers; it is ignored.
        """
        return self._phonemize(text, separator)

    def _phonemize(self, text, separator="|"):
        # Core conversion: map each word, then join the groups.
        return separator.join(self._word_to_phoneme(word) for word in text.split())

    def _word_to_phoneme(self, word):
        # Simple example: every lowercase character becomes its own "phoneme".
        return " ".join(word.lower())

    @staticmethod
    def supported_languages():
        return ["en", "custom"]

# Register the phonemizer in Coqui's PHONEMIZERS mapping under the key "custom".
from TTS.tts.utils.text.phonemizers import PHONEMIZERS
PHONEMIZERS["custom"] = CustomPhonemizer

自定义声码器 #

集成自定义声码器 #

python
import torch
import torch.nn as nn
from TTS.vocoder.models import BaseVocoder

class CustomVocoder(BaseVocoder):
    """Minimal mel-spectrogram-to-waveform generator.

    Three ``ConvTranspose1d`` stages upsample an 80-channel input to a single
    output channel, and ``Tanh`` bounds the samples to [-1, 1].
    """

    def __init__(self, config):
        # NOTE(review): BaseVocoder.__init__ may require the config in some TTS
        # releases; confirm this zero-argument call against the installed version.
        super().__init__()
        self.config = config

        # Strides 8, 8 and 2 give 128 output samples per input frame; with no
        # padding the output length is 128 * frames + 146.
        # NOTE(review): the total stride should equal the hop length used to
        # compute the input mel spectrogram — confirm.
        self.generator = nn.Sequential(
            nn.ConvTranspose1d(80, 512, 16, 8),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose1d(512, 256, 16, 8),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose1d(256, 1, 4, 2),
            nn.Tanh(),
        )

    def forward(self, mel):
        """Map ``mel`` of shape ``(batch, 80, frames)`` to ``(batch, 1, samples)``."""
        return self.generator(mel)

    def inference(self, mel):
        """Same as :meth:`forward`, but without building an autograd graph."""
        with torch.no_grad():
            return self.forward(mel)

    def load_checkpoint(self, checkpoint_path, eval=False, strict=True):
        """Load weights from ``checkpoint_path`` into this vocoder.

        Accepts either a bare ``state_dict`` or a training checkpoint dict that
        stores the weights under the ``"model"`` key.

        Args:
            checkpoint_path: Path passed to ``torch.load``.
            eval: If True, switch to evaluation mode after loading (shadows the
                builtin only inside this method).
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``self``, to allow chaining.
        """
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            checkpoint = checkpoint["model"]
        self.load_state_dict(checkpoint, strict=strict)
        if eval:
            self.eval()
        return self

# Usage: build the vocoder from a config and vocode a mel spectrogram.
# NOTE(review): `mel_spectrogram` is not defined in this snippet; it must be a
# float tensor with 80 channels (e.g. batch, 80, frames) to match the first
# ConvTranspose1d.
vocoder = CustomVocoder(config)
audio = vocoder.inference(mel_spectrogram)

性能优化 #

模型量化 #

python
import torch
from TTS.api import TTS

def quantize_model(model, module_types=None, dtype=torch.qint8):
    """Return a dynamically quantized copy of ``model``.

    Dynamic quantization stores the weights of the selected layer types in a
    quantized dtype and quantizes activations on the fly at inference time.

    Args:
        model: Float ``nn.Module`` to quantize.
        module_types: Iterable of layer classes to quantize; defaults to
            ``{torch.nn.Linear}``. Only types with a dynamic-quantization
            mapping (Linear and recurrent layers such as LSTM/GRU) are
            converted; ``Conv1d`` is not supported and would stay in float.
        dtype: Quantized weight dtype.

    Returns:
        The quantized model; ``model`` itself is not modified
        (``quantize_dynamic`` works on a deep copy unless ``inplace=True``).
    """
    if module_types is None:
        module_types = {torch.nn.Linear}
    return torch.quantization.quantize_dynamic(model, set(module_types), dtype=dtype)

# Usage: quantize the model behind the high-level TTS API.
# NOTE(review): quantize_dynamic defaults to inplace=False, so `tts` presumably
# keeps using the float model unless `quantized` is swapped in — confirm.
tts = TTS("tts_models/en/ljspeech/vits")
quantized = quantize_model(tts.synthesizer.tts_model)

ONNX 导出 #

python
import torch
from TTS.api import TTS

def export_to_onnx(model, output_path, sample_input, opset_version=14):
    """Export ``model`` to ONNX, tracing it with ``sample_input``.

    The batch and length axes are exported as dynamic. The output length axis
    gets its own symbolic name: for a TTS model the number of output
    frames/samples typically differs from the number of input tokens, and
    reusing ``"sequence_length"`` would declare the two equal in the graph.

    Args:
        model: ``nn.Module`` to export; it is switched to eval mode first.
        output_path: Destination ``.onnx`` file.
        sample_input: Example input of shape ``(batch, sequence_length)``.
        opset_version: ONNX opset to target.
    """
    model.eval()

    torch.onnx.export(
        model,
        sample_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "sequence_length"},
            # NOTE(review): axis 1 is assumed to be the output's length axis;
            # for outputs shaped (batch, channels, samples) use axis 2 instead.
            "output": {0: "batch_size", 1: "output_length"},
        },
    )

    print(f"模型已导出到: {output_path}")

# Usage: export the model behind the high-level TTS API.
# The sample input is a (1, 20) tensor of random integer IDs in [0, 100).
# NOTE(review): the model's forward() may need inputs beyond a single ID
# tensor (e.g. lengths); confirm its call signature before exporting.
tts = TTS("tts_models/en/ljspeech/vits")
sample_input = torch.randint(0, 100, (1, 20))
export_to_onnx(tts.synthesizer.tts_model, "model.onnx", sample_input)

TorchScript 导出 #

python
import torch
from TTS.api import TTS

def export_to_torchscript(model, output_path, sample_input):
    """Trace ``model`` with ``sample_input`` and save the TorchScript result.

    Tracing records only the operations executed for this particular input, so
    data-dependent control flow is frozen to the path taken here.

    Returns:
        The traced module that was written to ``output_path``.
    """
    model.eval()
    traced = torch.jit.trace(model, sample_input)
    traced.save(output_path)
    print(f"TorchScript 模型已保存: {output_path}")
    return traced

# Usage: trace the model behind the high-level TTS API and save it as model.pt.
# Despite its name, `scripted` holds a module produced by torch.jit.trace.
tts = TTS("tts_models/en/ljspeech/vits")
sample_input = torch.randint(0, 100, (1, 20))
scripted = export_to_torchscript(tts.synthesizer.tts_model, "model.pt", sample_input)

内存优化 #

梯度检查点 #

python
from torch.utils.checkpoint import checkpoint

class MemoryEfficientModel(nn.Module):
    """Stack of square Linear layers run under activation checkpointing.

    Each layer's activations are recomputed during backward instead of being
    stored, trading compute for memory.

    Args:
        config: Object exposing ``hidden_size`` and ``num_layers``.
    """

    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.hidden_size)
            for _ in range(config.num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            # use_reentrant=False: the reentrant variant produces no gradients
            # for the layer's parameters when ``x`` does not require grad (the
            # usual case for raw input data), and recent PyTorch warns when the
            # argument is omitted.
            x = checkpoint(layer, x, use_reentrant=False)
        return x

混合精度训练 #

python
import torch
from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, dataloader, optimizer, epochs, grad_clip=None):
    """Train ``model`` with automatic mixed precision and gradient scaling.

    Args:
        model: Module whose call on a batch returns a scalar loss.
        dataloader: Iterable of batches.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over ``dataloader``.
        grad_clip: Optional maximum global L2 norm for the gradients;
            ``None`` disables clipping.
    """
    scaler = GradScaler()
    # Ensure dropout / normalisation layers are in training mode, e.g. if the
    # model was left in eval mode after validation.
    model.train()

    for _ in range(epochs):
        for batch in dataloader:
            optimizer.zero_grad()

            # Forward pass with mixed precision.
            with autocast():
                loss = model(batch)

            # Scale the loss so small low-precision gradients do not underflow.
            scaler.scale(loss).backward()
            if grad_clip is not None:
                # Unscale first so the clipping threshold applies to the true
                # gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            # step() skips the update when inf/NaN gradients were detected.
            scaler.step(optimizer)
            scaler.update()

扩展开发 #

插件系统 #

python
from abc import ABC, abstractmethod

class TTSPlugin(ABC):
    """Abstract pair of hooks applied around speech synthesis.

    Concrete plugins implement ``pre_process`` (text in, text out) and
    ``post_process`` (audio in, audio out). Both are abstract, so a subclass
    missing either one cannot be instantiated.
    """

    @abstractmethod
    def pre_process(self, text):
        """Transform the input text before synthesis."""

    @abstractmethod
    def post_process(self, audio):
        """Transform the synthesized audio."""

class NoiseReductionPlugin(TTSPlugin):
    """Plugin that denoises synthesized audio with the ``noisereduce`` package.

    Args:
        sample_rate: Sample rate of the audio given to ``post_process``.
            Defaults to 22050, the rate used elsewhere in these examples.
    """

    def __init__(self, sample_rate=22050):
        super().__init__()
        self.sample_rate = sample_rate

    def pre_process(self, text):
        """Return ``text`` unchanged; this plugin only modifies audio."""
        return text

    def post_process(self, audio):
        """Return ``audio`` after noise reduction at ``self.sample_rate``."""
        # Imported lazily so the optional dependency is only needed when used.
        import noisereduce as nr
        return nr.reduce_noise(y=audio, sr=self.sample_rate)

class PluginManager:
    """Ordered plugin pipeline.

    Plugins run in registration order; each one receives the previous
    plugin's output.
    """

    def __init__(self):
        self.plugins = []

    def register(self, plugin):
        """Append ``plugin`` to the end of the pipeline."""
        self.plugins.append(plugin)

    def _run(self, hook_name, value):
        # Thread ``value`` through the named hook of every registered plugin.
        result = value
        for registered in self.plugins:
            result = getattr(registered, hook_name)(result)
        return result

    def pre_process(self, text):
        """Apply every plugin's ``pre_process`` to ``text`` in order."""
        return self._run("pre_process", text)

    def post_process(self, audio):
        """Apply every plugin's ``post_process`` to ``audio`` in order."""
        return self._run("post_process", audio)

# Usage: create a manager and register the noise-reduction plugin.
manager = PluginManager()
manager.register(NoiseReductionPlugin())

自定义训练器 #

python
from trainer import Trainer, TrainerArgs

class CustomTrainer(Trainer):
    """``trainer.Trainer`` subclass that records an extra loss metric.

    NOTE(review): confirm that the installed ``trainer.Trainer`` actually
    invokes ``training_step`` / ``validation_step``. If it dispatches to the
    model's own step methods instead, these overrides are never called.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Latest value of each custom metric, keyed by metric name.
        self.custom_metrics = {}

    def training_step(self, batch):
        """Run the model on ``batch`` and return the ``"loss"`` entry of its output.

        ``batch`` is unpacked as keyword arguments; the model output must be a
        mapping containing ``"loss"``.
        """
        loss = self.model(**batch)["loss"]
        # .item() copies the scalar to Python (a device sync when on GPU).
        self.custom_metrics["custom_loss"] = loss.item()
        return loss

    def validation_step(self, batch):
        """Run the model on ``batch`` (unpacked as keyword arguments) and return its outputs."""
        return self.model(**batch)

# Usage: `config` and `model` are assumed to come from the earlier snippets.
# NOTE(review): trainer.Trainer likely expects a Coqpit-based training config
# rather than a plain dataclass — confirm before running.
trainer = CustomTrainer(
    TrainerArgs(),
    config,
    output_path="output",
    model=model
)
trainer.fit()

高级配置选项 #

完整配置示例 #

json
{
    "model": "vits",
    "run_name": "advanced_vits",
    
    "audio": {
        "sample_rate": 22050,
        "fft_size": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "num_mels": 80,
        "mel_fmin": 0,
        "mel_fmax": 8000,
        "spec_gain": 1,
        "signal_norm": true,
        "min_level_db": -100,
        "ref_level_db": 20,
        "preemphasis": 0.0,
        "do_trim_silence": true,
        "trim_db": 45
    },
    
    "datasets": [
        {
            "name": "custom_dataset",
            "path": "dataset/",
            "meta_file_train": "metadata.csv",
            "meta_file_val": "metadata_val.csv",
            "language": "en",
            "cleaners": ["english_cleaners"],
            "phonemizer": "espeak",
            "text_cleaner": "phoneme_cleaners"
        }
    ],
    
    "training": {
        "batch_size": 32,
        "eval_batch_size": 16,
        "num_loader_workers": 8,
        "num_eval_loader_workers": 4,
        
        "epochs": 1000,
        "save_step": 1000,
        "print_step": 100,
        "run_eval": true,
        "test_delay_epochs": 5,
        
        "learning_rate": 0.0002,
        "lr_scheduler": "NoamLR",
        "lr_scheduler_params": {
            "warmup_steps": 4000
        },
        
        "optimizer": "AdamW",
        "optimizer_params": {
            "betas": [0.8, 0.99],
            "eps": 1e-9,
            "weight_decay": 0.01
        },
        
        "gradient_accumulation_steps": 1,
        "mixed_precision": true,
        "grad_clip": 5.0,
        
        "early_stopping": true,
        "early_stopping_patience": 10,
        
        "use_tensorboard": true,
        "tensorboard_log_dir": "logs/"
    },
    
    "model_params": {
        "hidden_channels": 192,
        "inter_channels": 192,
        "filter_channels": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0.1,
        
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [8, 8, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4]
    },
    
    "distributed": {
        "backend": "nccl",
        "url": "tcp://localhost:54321"
    }
}

下一步 #

掌握了高级配置后，请继续学习 API 服务，了解如何部署 TTS 服务！

最后更新：2026-04-05