PyTorch 高级主题 #

分布式训练 #

分布式训练概述 #

text
┌─────────────────────────────────────────────────────────────┐
│                    分布式训练方式                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  数据并行(Data Parallelism):                              │
│  - 同一模型复制到多个 GPU                                    │
│  - 数据分割到不同 GPU                                       │
│  - 梯度聚合后更新                                           │
│                                                             │
│  模型并行(Model Parallelism):                             │
│  - 模型分割到多个 GPU                                       │
│  - 每个GPU处理模型的一部分                                  │
│  - 适合超大模型                                             │
│                                                             │
│  流水线并行(Pipeline Parallelism):                        │
│  - 模型按层分割                                             │
│  - 数据流水线处理                                           │
│  - 提高GPU利用率                                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

DistributedDataParallel #

python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import torch.multiprocessing as mp
import os

def setup(rank, world_size):
    """Join the default process group for single-node multi-GPU training.

    Args:
        rank: index of this process (also used as its GPU index).
        world_size: total number of participating processes.
    """
    # All ranks rendezvous at this address/port; rank 0 hosts the store.
    os.environ.update({
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12355',
    })
    # NCCL is the standard backend for CUDA tensors.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group created by setup()."""
    dist.destroy_process_group()

class ToyModel(torch.nn.Module):
    """Minimal MLP used to demonstrate DDP: 10 -> 10 (ReLU) -> 5."""

    def __init__(self):
        super().__init__()
        layers = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 5),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run the input batch of shape (N, 10) through the stack."""
        return self.net(x)

def train(rank, world_size):
    """Per-process training entry point; `rank` doubles as the GPU index."""
    setup(rank, world_size)

    model = ToyModel().to(rank)
    # DDP averages gradients across processes during backward().
    ddp_model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

    # NOTE(review): a bare tensor works as a dataset (it is indexable and
    # sized), but wrapping it in TensorDataset would be the usual choice.
    dataset = torch.randn(1000, 10)
    # Each replica sees a disjoint 1/world_size shard of the data.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

    for epoch in range(10):
        # Reshuffles the sharding each epoch; without this every epoch
        # iterates in the same order.
        sampler.set_epoch(epoch)
        for batch in dataloader:
            optimizer.zero_grad()
            output = ddp_model(batch.to(rank))
            # Toy objective: no labels exist, so just sum the outputs.
            loss = output.sum()
            loss.backward()
            optimizer.step()

    cleanup()

if __name__ == "__main__":
    # One process per visible GPU; mp.spawn passes the rank as the first
    # positional argument to train().
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

使用 torchrun 启动 #

python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

def main():
    """Training entry point for launch via `torchrun` (one process per GPU)."""
    # torchrun handles the rendezvous; no MASTER_ADDR/PORT setup needed here.
    dist.init_process_group("nccl")

    # LOCAL_RANK is injected by torchrun and selects this process's GPU.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = ToyModel().to(local_rank)
    model = DDP(model, device_ids=[local_rank])

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # NOTE(review): `ToyModel`, `dataloader`, `criterion` and `target` are
    # placeholders assumed to be defined elsewhere; this snippet does not
    # run as-is.
    for epoch in range(10):
        for batch in dataloader:
            optimizer.zero_grad()
            output = model(batch)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    # Launched as: torchrun --nproc_per_node=<gpus> train.py
    main()
bash
torchrun --nproc_per_node=4 train.py

Fully Sharded Data Parallel #

python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

class LargeModel(nn.Module):
    """A deliberately deep MLP (100 x Linear(1024, 1024)) used to motivate FSDP."""

    def __init__(self):
        super().__init__()
        # ModuleList registers every layer so FSDP can shard their parameters.
        self.layers = nn.ModuleList(nn.Linear(1024, 1024) for _ in range(100))

    def forward(self, x):
        """Apply all layers sequentially; input and output are (..., 1024)."""
        for block in self.layers:
            x = block(x)
        return x

# NOTE(review): requires an initialized process group and a CUDA device;
# illustrative only.
model = LargeModel()
model = FSDP(
    model,
    # FULL_SHARD: parameters, gradients and optimizer state are all sharded
    # across ranks (ZeRO-3 style), minimizing per-GPU memory.
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device()
)

混合精度训练 #

自动混合精度 #

python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

# NOTE(review): torch.cuda.amp.autocast / GradScaler are deprecated aliases;
# the torch.amp namespace (already used later in this guide) is preferred.
model = nn.Linear(1000, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# GradScaler scales the loss so float16 gradients do not underflow.
scaler = torch.amp.GradScaler("cuda")

# NOTE(review): `dataloader` must be defined by the caller.
for epoch in range(10):
    for x, y in dataloader:
        x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()

        # Ops inside autocast run in reduced precision where it is safe.
        with torch.autocast(device_type="cuda"):
            output = model(x)
            loss = nn.CrossEntropyLoss()(output, y)

        # Scaled backward, unscale + step, then update the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

bfloat16 精度 #

python
import torch
import torch.nn as nn

# Cast the whole model to bfloat16: same exponent range as float32, so no
# gradient scaler is needed (unlike float16).
model = nn.Linear(1000, 10).cuda()
model = model.to(torch.bfloat16)

x = torch.randn(32, 1000, dtype=torch.bfloat16, device="cuda")
output = model(x)

# Alternative: keep float32 inputs and let autocast pick bfloat16 per-op.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(x.float())

模型量化 #

动态量化 #

python
import torch
import torch.nn as nn

import io

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

# Dynamic quantization is an inference-time transform; use eval mode.
model.eval()

# Replace every nn.Linear with a dynamically quantized (int8-weight) version.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8
)

def serialized_size(m):
    """Return the size in bytes of a model's serialized state_dict."""
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes

# Bug fix: quantized modules pack their weights, so `parameters()` is empty
# on the quantized model and the old parameter-count comparison printed 0.
# Compare serialized sizes instead.
print(f"原始模型大小: {serialized_size(model)} bytes")
print(f"量化模型大小: {serialized_size(quantized_model)} bytes")

静态量化 #

python
import torch
import torch.nn as nn

class StaticQuantModel(nn.Module):
    """MLP wrapped with Quant/DeQuant stubs so static quantization can
    observe activations and later convert the float graph to int8."""

    def __init__(self):
        super().__init__()
        # Stubs mark where tensors enter and leave the quantized region.
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Quantize -> fc1 -> relu -> fc2 -> dequantize."""
        h = self.quant(x)
        h = self.relu(self.fc1(h))
        h = self.fc2(h)
        return self.dequant(h)

model = StaticQuantModel()
# Static quantization is an inference-time transform; eval mode is required.
model.eval()

# 'fbgemm' targets x86 servers ('qnnpack' would target ARM/mobile).
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Insert observers that record activation ranges during calibration.
torch.quantization.prepare(model, inplace=True)

# Calibration pass: run representative data through the observed model.
# NOTE(review): `calibration_dataloader` must be defined by the caller.
with torch.no_grad():
    for data in calibration_dataloader:
        model(data)

# Replace the observed float modules with their int8 counterparts.
torch.quantization.convert(model, inplace=True)

量化感知训练 #

python
import torch
import torch.nn as nn

class QATModel(nn.Module):
    """Same stub-wrapped MLP as the static example, but intended for
    quantization-aware training (fake-quant inserted by prepare_qat)."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Pipe the input through quant -> fc1 -> relu -> fc2 -> dequant."""
        for stage in (self.quant, self.fc1, self.relu, self.fc2, self.dequant):
            x = stage(x)
        return x

model = QATModel()
# The QAT config inserts fake-quant modules that simulate int8 while training.
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = torch.quantization.prepare_qat(model, inplace=True)

# NOTE(review): `train_one_epoch`, `dataloader`, `criterion` and `optimizer`
# are placeholders assumed to be defined elsewhere.
for epoch in range(10):
    train_one_epoch(model, dataloader, criterion, optimizer)

# Convert after training, in eval mode, to obtain the real int8 model.
model.eval()
quantized_model = torch.quantization.convert(model)

模型剪枝 #

非结构化剪枝 #

python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Small MLP whose first Linear layer will be sparsified.
layers = [
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
]
model = nn.Sequential(*layers)

# Zero the 30% of weights with the smallest L1 magnitude. This adds a
# `weight_mask` buffer and reparametrizes weight = weight_orig * weight_mask.
first_fc = model[0]
prune.l1_unstructured(first_fc, name='weight', amount=0.3)

print(f"剪枝后权重: {model[0].weight}")

# Bake the mask into the tensor and drop the reparametrization hooks.
prune.remove(first_fc, 'weight')

结构化剪枝 #

python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Single 10x10 linear layer to demonstrate structured (whole-channel) pruning.
model = nn.Linear(10, 10)

# Remove 30% of the output channels (dim=0), ranked by L2 norm (n=2):
# entire rows of the weight matrix become zero.
prune.ln_structured(model, name='weight', amount=0.3, n=2, dim=0)

print(f"结构化剪枝后权重: {model.weight}")

全局剪枝 #

python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

# Prune both Linear layers jointly: the 20% smallest weights are chosen
# across ALL listed tensors, so per-layer sparsity may differ.
parameters_to_prune = [
    (model[0], 'weight'),
    (model[2], 'weight'),
]

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

for module, name in parameters_to_prune:
    # Bug fix: the original called the nonexistent Tensor.neel();
    # the element-count method is Tensor.numel().
    print(f"剪枝比例: {torch.sum(module.weight == 0) / module.weight.numel()}")

知识蒸馏 #

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss:
    alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(student_T || teacher_T).
    """

    def __init__(self, temperature=3.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        # batchmean matches the mathematical definition of KL divergence.
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_output, teacher_output, labels):
        t = self.temperature
        # Hard-label term: ordinary cross entropy against the ground truth.
        hard = self.ce_loss(student_output, labels)
        # Soft-label term: KLDivLoss expects log-probs for the input and
        # probs for the target; T^2 keeps gradient magnitudes comparable.
        log_p_student = F.log_softmax(student_output / t, dim=1)
        p_teacher = F.softmax(teacher_output / t, dim=1)
        soft = self.kl_loss(log_p_student, p_teacher) * (t * t)
        return self.alpha * hard + (1 - self.alpha) * soft

# NOTE(review): `LargeModel`/`SmallModel` stand in for any teacher/student
# pair; `dataloader` is assumed to be defined elsewhere.
teacher = LargeModel().eval()
student = SmallModel()

criterion = DistillationLoss(temperature=3.0, alpha=0.5)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

for epoch in range(10):
    for x, y in dataloader:
        optimizer.zero_grad()

        # The teacher only provides soft targets; it needs no gradients.
        with torch.no_grad():
            teacher_output = teacher(x)

        student_output = student(x)
        loss = criterion(student_output, teacher_output, y)

        loss.backward()
        optimizer.step()

模型部署 #

TorchServe 部署 #

python
import torch
import torch.nn as nn

class Model(nn.Module):
    """Tiny example model for serving: a single Linear(10, 2)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        """Map a (N, 10) batch to (N, 2) logits."""
        return self.fc(x)

# Freeze the eval-mode model into TorchScript by tracing a sample input,
# then serialize it for packaging with torch-model-archiver.
model = Model()
model.eval()

example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model.pt')
bash
torch-model-archiver --model-name mymodel \
    --version 1.0 \
    --model-file model.py \
    --serialized-file model.pt \
    --handler image_classifier

torchserve --start --model-store model_store --models mymodel.mar

TensorRT 加速 #

python
import torch
import torch.nn as nn
import torch_tensorrt

# Small conv net to be compiled; TensorRT requires eval mode and CUDA.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10)
).cuda().eval()

# Compile with a dynamic batch dimension: TensorRT optimizes for `opt_shape`
# but accepts anything between `min_shape` and `max_shape`.
trt_model = torch_tensorrt.compile(model, 
    inputs=[torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[32, 3, 224, 224],
        max_shape=[64, 3, 224, 224],
        dtype=torch.float32
    )],
    # Allow TensorRT to use float16 kernels where profitable.
    enabled_precisions={torch.float16}
)

x = torch.randn(32, 3, 224, 224).cuda()
output = trt_model(x)

ONNX Runtime 部署 #

python
import torch
import onnxruntime as ort
import numpy as np

# NOTE(review): `Model` is assumed to be defined elsewhere (any nn.Module
# taking a (N, 10) float input works for this example).
model = Model()
model.eval()

# Export by tracing with a dummy input of the expected shape.
dummy_input = torch.randn(1, 10)
torch.onnx.export(model, dummy_input, "model.onnx")

# ONNX Runtime loads the exported graph independently of PyTorch.
session = ort.InferenceSession("model.onnx")

input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# ONNX Runtime consumes numpy arrays, not torch tensors.
input_data = np.random.randn(1, 10).astype(np.float32)
outputs = session.run([output_name], {input_name: input_data})

性能优化技巧 #

编译优化 #

python
import torch

# NOTE(review): `MyModel` is a placeholder for any nn.Module.
model = MyModel().cuda()

# Default mode: good balance of compile time and speedup.
model = torch.compile(model)

# NOTE(review): each call below REPLACES the previously compiled model;
# the calls illustrate alternative modes, not a pipeline.
model = torch.compile(model, mode="reduce-overhead")
model = torch.compile(model, mode="max-autotune")

# Compilation actually happens lazily on the first call.
# NOTE(review): `input` must be defined by the caller (shadows the builtin).
output = model(input)

内存优化 #

python
import torch

# NOTE(review): `model`, `x`, `optimizer`, `dataloader` and `device` are
# placeholders assumed to be defined elsewhere.
# Inference without autograd bookkeeping saves activation memory.
with torch.no_grad():
    output = model(x)

# set_to_none=True frees gradient tensors instead of zero-filling them.
optimizer.zero_grad(set_to_none=True)

# non_blocking copies can overlap host-to-device transfer with compute
# (effective when the source tensors live in pinned memory).
for x, y in dataloader:
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

# Returns cached blocks to CUDA; frees memory for other processes but does
# not speed up this one.
torch.cuda.empty_cache()

数据加载优化 #

python
from torch.utils.data import DataLoader

# NOTE(review): `dataset` must be defined by the caller.
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,  # parallel worker processes for loading/augmentation
    pin_memory=True,  # page-locked host memory -> faster GPU transfer
    prefetch_factor=2,  # batches prefetched per worker
    persistent_workers=True  # keep workers alive between epochs
)

总结 #

本指南涵盖了 PyTorch 的高级主题:

  • 分布式训练:DDP、FSDP 多 GPU 训练
  • 混合精度:AMP、bfloat16 加速训练
  • 模型量化:动态量化、静态量化、量化感知训练
  • 模型剪枝:非结构化、结构化、全局剪枝
  • 知识蒸馏:模型压缩技术
  • 模型部署:TorchServe、TensorRT、ONNX Runtime

掌握这些高级技术,可以帮助你构建更高效、更强大的深度学习系统!

最后更新:2026-03-29