模型训练 #

训练概述 #

Coqui TTS 提供了完整的模型训练流程,支持从零开始训练或微调预训练模型。

text
┌─────────────────────────────────────────────────────────────┐
│                     训练流程概览                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌──────────┐   ┌──────────┐   ┌──────────┐   ┌──────────┐ │
│  │ 数据准备  │ → │ 配置文件  │ → │ 模型训练  │ → │ 模型评估  │ │
│  └──────────┘   └──────────┘   └──────────┘   └──────────┘ │
│       │              │              │              │       │
│       ↓              ↓              ↓              ↓       │
│   音频+文本       训练参数        训练循环        质量测试   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

数据准备 #

数据集格式 #

text
┌─────────────────────────────────────────────────────────────┐
│                   标准数据集结构                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  dataset/                                                   │
│  ├── wavs/                   # 音频文件目录                  │
│  │   ├── audio_001.wav                                      │
│  │   ├── audio_002.wav                                      │
│  │   └── ...                                                │
│  ├── metadata.csv           # 元数据文件                    │
│  └── metadata_train.csv     # 训练集(可选)                │
│                                                             │
│  metadata.csv 格式:                                         │
│  audio_name|transcription                                    │
│  audio_001|Hello, this is the first sentence.               │
│  audio_002|This is the second sentence.                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

创建数据集 #

python
import os
import shutil
from pathlib import Path

def create_dataset_structure(dataset_path):
    """Create the standard dataset layout: a ``wavs/`` directory plus a
    ``metadata.csv`` with the pipe-separated header row.

    Existing files are left untouched; returns the dataset root as a Path.
    """
    root = Path(dataset_path)

    # Audio directory (idempotent)
    (root / "wavs").mkdir(parents=True, exist_ok=True)

    # Metadata file: only write the header when the file does not exist yet,
    # so re-running never clobbers collected annotations.
    metadata = root / "metadata.csv"
    if not metadata.exists():
        metadata.write_text("audio_name|transcription\n", encoding="utf-8")

    print(f"数据集结构已创建: {root}")
    return root

# 使用
create_dataset_structure("my_dataset")

数据预处理 #

python
import librosa
import soundfile as sf
import numpy as np
from pathlib import Path

def preprocess_audio(input_path, output_path, target_sr=22050):
    """Resample, peak-normalize, and silence-trim a single audio file.

    Args:
        input_path: Source audio file path.
        output_path: Destination WAV path.
        target_sr: Target sample rate in Hz (default 22050).

    Returns:
        The output path that was written.
    """
    audio, sr = librosa.load(input_path, sr=None)

    # Resample if needed. librosa >= 0.10 requires keyword arguments here;
    # the positional form resample(audio, sr, target_sr) raises TypeError.
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)

    # Peak-normalize to 0.95. Guard against all-zero (silent) input, which
    # would otherwise divide by zero and produce NaNs.
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak * 0.95

    # Trim leading/trailing silence
    audio, _ = librosa.effects.trim(audio, top_db=30)

    sf.write(output_path, audio, target_sr)
    return output_path

def preprocess_dataset(input_dir, output_dir, target_sr=22050):
    """Run preprocess_audio over every .wav in input_dir, writing results
    with the same file names into output_dir (created if missing)."""
    src = Path(input_dir)
    dst = Path(output_dir)
    dst.mkdir(parents=True, exist_ok=True)

    for wav in src.glob("*.wav"):
        preprocess_audio(str(wav), str(dst / wav.name), target_sr)
        print(f"处理: {wav.name}")

# 使用
preprocess_dataset("raw_audio", "my_dataset/wavs")

数据集验证 #

python
import pandas as pd
from pathlib import Path
import soundfile as sf

def validate_dataset(dataset_path, target_sr=22050):
    """Validate a dataset directory against its metadata.csv.

    Checks that metadata.csv exists, that every referenced wav file is
    present under ``wavs/``, and that each readable file matches the
    expected sample rate.

    Args:
        dataset_path: Dataset root directory.
        target_sr: Expected sample rate in Hz (previously hardcoded to
            22050; kept as the default for backward compatibility).

    Returns:
        A list of human-readable issue strings (empty when the dataset
        passes all checks).
    """
    dataset_path = Path(dataset_path)
    issues = []

    # metadata.csv is mandatory — bail out early if absent
    metadata_path = dataset_path / "metadata.csv"
    if not metadata_path.exists():
        issues.append("缺少 metadata.csv 文件")
        return issues

    # Pipe-separated metadata: first column is the audio file stem
    df = pd.read_csv(metadata_path, sep="|")
    print(f"总条目: {len(df)}")

    wavs_dir = dataset_path / "wavs"
    missing_files = []
    invalid_files = []

    for _, row in df.iterrows():
        audio_name = row.iloc[0]
        audio_path = wavs_dir / f"{audio_name}.wav"

        if not audio_path.exists():
            missing_files.append(audio_name)
        else:
            try:
                # sf.info reads only the header — cheap even for long files
                info = sf.info(str(audio_path))
                if info.samplerate != target_sr:
                    invalid_files.append((audio_name, f"采样率 {info.samplerate}"))
            except Exception as e:
                # Unreadable/corrupt file: record the reason, keep scanning
                invalid_files.append((audio_name, str(e)))

    if missing_files:
        issues.append(f"缺少 {len(missing_files)} 个音频文件")

    if invalid_files:
        issues.append(f"{len(invalid_files)} 个音频文件有问题")

    # Summary output
    print("\n验证结果:")
    print(f"  总条目: {len(df)}")
    print(f"  缺失文件: {len(missing_files)}")
    print(f"  问题文件: {len(invalid_files)}")

    return issues

# 使用
issues = validate_dataset("my_dataset")
if issues:
    print("\n问题列表:")
    for issue in issues:
        print(f"  - {issue}")

配置文件 #

基础配置示例 #

json
{
    "model": "vits",
    "run_name": "my_vits_model",
    "run_description": "My first VITS model",
    
    "audio": {
        "sample_rate": 22050,
        "output_path": "output",
        "fft_size": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "num_mels": 80,
        "mel_fmin": 0,
        "mel_fmax": 8000
    },
    
    "datasets": [
        {
            "name": "my_dataset",
            "path": "my_dataset/",
            "meta_file_train": "metadata.csv",
            "meta_file_val": "metadata.csv",
            "language": "en"
        }
    ],
    
    "training": {
        "batch_size": 16,
        "eval_batch_size": 8,
        "num_loader_workers": 4,
        "num_eval_loader_workers": 2,
        "run_eval": true,
        "test_delay_epochs": 0,
        "epochs": 1000,
        "learning_rate": 0.001,
        "save_step": 1000,
        "print_step": 100,
        "output_path": "output/",
        "use_tensorboard": true
    }
}

VITS 模型配置 #

json
{
    "model": "vits",
    "run_name": "vits_ljspeech",
    
    "audio": {
        "sample_rate": 22050,
        "fft_size": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "num_mels": 80
    },
    
    "vits_config": {
        "hidden_channels": 192,
        "inter_channels": 192,
        "filter_channels": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0.1,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [8, 8, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4]
    },
    
    "training": {
        "batch_size": 32,
        "epochs": 1000,
        "learning_rate": 0.0002
    }
}

Tacotron2 配置 #

json
{
    "model": "tacotron2",
    "run_name": "tacotron2_ljspeech",
    
    "audio": {
        "sample_rate": 22050,
        "fft_size": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "num_mels": 80
    },
    
    "tacotron_config": {
        "encoder_embedding_dim": 512,
        "encoder_n_convolutions": 3,
        "encoder_kernel_size": 5,
        "decoder_rnn_dim": 1024,
        "decoder_max_step": 2000,
        "attention_rnn_dim": 1024,
        "attention_dim": 128,
        "attention_location_n_filters": 32,
        "attention_location_kernel_size": 31,
        "postnet_embedding_dim": 512,
        "postnet_n_convolutions": 5,
        "postnet_kernel_size": 5
    },
    
    "training": {
        "batch_size": 32,
        "epochs": 1000,
        "learning_rate": 0.001
    }
}

训练命令 #

基础训练命令 #

bash
# 使用配置文件训练
tts --config_path config.json \
    --coq_dataset_path my_dataset/

# 指定模型和数据集
tts --model_name vits \
    --coq_dataset_path my_dataset/ \
    --run_name my_model

使用 Python 训练 #

python
import subprocess

def train_model(config_path, dataset_path, gpu_id=0, check=True):
    """Launch training through the ``tts`` command-line tool.

    Args:
        config_path: Path to the training config JSON.
        dataset_path: Path to the dataset directory.
        gpu_id: GPU index passed via ``--gpus``.
        check: When True (default) raise ``subprocess.CalledProcessError``
            if the training process exits non-zero, instead of silently
            discarding the failure as the previous version did.

    Returns:
        The ``subprocess.CompletedProcess`` so callers can inspect the
        exit code.
    """
    cmd = [
        "tts",
        "--config_path", config_path,
        "--coq_dataset_path", dataset_path,
        "--gpus", str(gpu_id),
    ]

    # Surface failures: the original dropped the result, so a crashed
    # training run looked identical to a successful one.
    return subprocess.run(cmd, check=check)

# 使用
train_model("config.json", "my_dataset/", gpu_id=0)

恢复训练 #

bash
# 从检查点恢复
tts --config_path config.json \
    --restore_path output/run/checkpoint_10000.pth \
    --coq_dataset_path my_dataset/

训练监控 #

TensorBoard 监控 #

bash
# 启动 TensorBoard
tensorboard --logdir output/

# 在浏览器访问
# http://localhost:6006

训练日志分析 #

python
import json
from pathlib import Path

def analyze_training_log(log_path):
    """Scan *.jsonl training logs under log_path and summarize losses.

    Args:
        log_path: Directory containing one or more ``.jsonl`` log files.

    Returns:
        The list of parsed log records (dicts) that contain a "loss" key,
        in sorted-filename then line order.
    """
    log_path = Path(log_path)

    losses = []
    # sorted() makes the record order deterministic across runs;
    # Path.glob order is filesystem-dependent.
    for log_file in sorted(log_path.glob("*.jsonl")):
        with open(log_file, "r") as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Skip malformed lines only; the previous bare
                    # ``except: pass`` also swallowed KeyboardInterrupt
                    # and real bugs.
                    continue
                if "loss" in data:
                    losses.append(data)

    if losses:
        print(f"总步数: {len(losses)}")
        print(f"初始损失: {losses[0].get('loss', 'N/A')}")
        print(f"最终损失: {losses[-1].get('loss', 'N/A')}")

        # Record with the lowest loss; missing loss sorts as +inf
        best = min(losses, key=lambda x: x.get("loss", float("inf")))
        print(f"最佳损失: {best.get('loss', 'N/A')} (步 {best.get('step', 'N/A')})")

    return losses

# 使用
analyze_training_log("output/run/")

训练进度可视化 #

python
import matplotlib.pyplot as plt
import json
from pathlib import Path

def plot_training_progress(log_path, output_path="training_progress.png"):
    """Plot loss-vs-step from *.jsonl logs and save the figure as PNG.

    Args:
        log_path: Directory containing ``.jsonl`` training logs.
        output_path: Where to save the rendered chart.
    """
    log_path = Path(log_path)

    # Collect (step, loss) pairs first so they can be sorted by step —
    # globbing multiple files in arbitrary order would otherwise draw a
    # zig-zag line across the plot.
    points = []
    for log_file in sorted(log_path.glob("*.jsonl")):
        with open(log_file, "r") as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Skip malformed lines only; a bare except here would
                    # hide real bugs and KeyboardInterrupt.
                    continue
                if "loss" in data and "step" in data:
                    points.append((data["step"], data["loss"]))

    points.sort()
    steps = [p[0] for p in points]
    losses = [p[1] for p in points]

    plt.figure(figsize=(10, 6))
    plt.plot(steps, losses)
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.title("Training Progress")
    plt.grid(True)
    plt.savefig(output_path)
    plt.close()

    print(f"图表已保存: {output_path}")

# 使用
plot_training_progress("output/run/")

模型评估 #

合成测试 #

python
from TTS.api import TTS
from pathlib import Path

def test_model(model_path, config_path, test_texts, output_dir):
    """Synthesize each test sentence with a trained model and write
    test_<i>.wav files into output_dir (created if missing)."""
    out_dir = Path(output_dir)
    out_dir.mkdir(exist_ok=True)

    # Load the trained checkpoint once, reuse it for all sentences
    synthesizer = TTS(model_path=model_path, config_path=config_path)

    for idx, sentence in enumerate(test_texts):
        wav_path = out_dir / f"test_{idx}.wav"
        synthesizer.tts_to_file(text=sentence, file_path=str(wav_path))
        print(f"生成: {wav_path}")

# 使用
test_texts = [
    "Hello, this is a test.",
    "The quick brown fox jumps over the lazy dog.",
    "Testing the trained model quality."
]

test_model(
    model_path="output/run/best_model.pth",
    config_path="config.json",
    test_texts=test_texts,
    output_dir="test_output"
)

MOS 评估 #

python
import numpy as np
from pathlib import Path

def evaluate_mos(audio_dir, evaluator_scores):
    """Compute a simplified Mean Opinion Score summary.

    Args:
        audio_dir: Directory holding the rated audio files. Kept for API
            compatibility; this function does not read the audio itself,
            only the scores.
        evaluator_scores: dict mapping audio file name to a list of
            per-evaluator scores, e.g. {"a.wav": [4.0, 4.5]}.

    Returns:
        The overall MOS as a float, averaged over per-file means.
        Returns 0.0 when no scores are given (np.mean([]) would
        otherwise emit a RuntimeWarning and return nan).
    """
    all_scores = []

    for audio_file, scores in evaluator_scores.items():
        avg_score = float(np.mean(scores))
        all_scores.append(avg_score)
        print(f"{audio_file}: MOS = {avg_score:.2f}")

    # Guard the empty case explicitly instead of propagating nan
    overall_mos = float(np.mean(all_scores)) if all_scores else 0.0
    print(f"\n总体 MOS: {overall_mos:.2f}")

    return overall_mos

# 使用示例
scores = {
    "test_0.wav": [4.0, 4.5, 4.0, 4.5],
    "test_1.wav": [3.5, 4.0, 4.0, 3.5],
    "test_2.wav": [4.0, 4.0, 4.5, 4.0],
}

evaluate_mos("test_output", scores)

训练技巧 #

数据增强 #

python
import librosa
import soundfile as sf
import numpy as np

def augment_audio(audio_path, output_path, augment_type="speed"):
    """Apply one data-augmentation transform to an audio file.

    Args:
        audio_path: Source audio path.
        output_path: Destination WAV path.
        augment_type: One of "speed" (random 0.9-1.1x time stretch),
            "pitch" (random -2..+2 semitone shift), or "noise"
            (additive Gaussian noise).

    Returns:
        The output path that was written.

    Raises:
        ValueError: For an unknown augment_type. The previous version
            silently wrote the unmodified audio, hiding typos in the
            augmentation name.
    """
    audio, sr = librosa.load(audio_path, sr=None)

    if augment_type == "speed":
        # Random tempo change, ±10%
        speed = np.random.uniform(0.9, 1.1)
        audio = librosa.effects.time_stretch(audio, rate=speed)

    elif augment_type == "pitch":
        # Random pitch shift in whole semitones, -2..+2
        n_steps = np.random.randint(-2, 3)
        audio = librosa.effects.pitch_shift(audio, sr=sr, n_steps=n_steps)

    elif augment_type == "noise":
        # Low-amplitude additive Gaussian noise
        noise = np.random.randn(len(audio)) * 0.005
        audio = audio + noise

    else:
        raise ValueError(f"unknown augment_type: {augment_type!r}")

    sf.write(output_path, audio, sr)
    return output_path

# 使用
augment_audio("original.wav", "augmented_speed.wav", "speed")

学习率调度 #

json
{
    "training": {
        "learning_rate": 0.001,
        "lr_scheduler": "NoamLR",
        "lr_scheduler_params": {
            "warmup_steps": 4000
        }
    }
}

混合精度训练 #

json
{
    "training": {
        "mixed_precision": true,
        "fp16": true
    }
}

常见问题 #

问题 1:训练不收敛 #

python
# 解决方案:
# 1. 检查数据质量
# 2. 减小学习率
# 3. 增加数据量
# 4. 检查配置文件

# 调整学习率
config["training"]["learning_rate"] = 0.0001

问题 2:GPU 内存不足 #

json
{
    "training": {
        "batch_size": 8,
        "gradient_accumulation_steps": 4
    }
}

问题 3:训练速度慢 #

json
{
    "training": {
        "num_loader_workers": 8,
        "mixed_precision": true,
        "batch_size": 32
    }
}

下一步 #

掌握了模型训练后,继续学习 微调优化,了解如何微调预训练模型!

最后更新:2026-04-05