PyTorch 高级主题 #
分布式训练 #
分布式训练概述 #
text
┌─────────────────────────────────────────────────────────────┐
│ 分布式训练方式 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 数据并行(Data Parallelism): │
│ - 同一模型复制到多个 GPU │
│ - 数据分割到不同 GPU │
│ - 梯度聚合后更新 │
│ │
│ 模型并行(Model Parallelism): │
│ - 模型分割到多个 GPU │
│ - 每个GPU处理模型的一部分 │
│ - 适合超大模型 │
│ │
│ 流水线并行(Pipeline Parallelism): │
│ - 模型按层分割 │
│ - 数据流水线处理 │
│ - 提高GPU利用率 │
│ │
└─────────────────────────────────────────────────────────────┘
DistributedDataParallel #
python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import torch.multiprocessing as mp
import os
def setup(rank, world_size, master_addr="localhost", master_port="12355"):
    """Initialize the default process group for distributed training.

    Args:
        rank: Global rank of this process (0 .. world_size - 1).
        world_size: Total number of participating processes.
        master_addr: Rendezvous address; defaults preserve the original
            single-node behavior.
        master_port: Rendezvous TCP port as a string.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    # NCCL backend: requires one CUDA device per process.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
    """Tear down the default process group created by setup()."""
    dist.destroy_process_group()
class ToyModel(torch.nn.Module):
    """Small 10 -> 5 MLP used to demonstrate DDP wrapping."""

    def __init__(self):
        super().__init__()
        # Two affine layers with a ReLU non-linearity in between.
        hidden = torch.nn.Linear(10, 10)
        activation = torch.nn.ReLU()
        head = torch.nn.Linear(10, 5)
        self.net = torch.nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Map a (batch, 10) tensor to (batch, 5) outputs."""
        return self.net(x)
def train(rank, world_size):
    """Per-process entry point: train ToyModel under DDP on GPU `rank`."""
    setup(rank, world_size)
    ddp_model = DDP(ToyModel().to(rank), device_ids=[rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)
    # NOTE(review): a raw tensor serves as the dataset here; each "sample"
    # is one row of shape (10,).
    dataset = torch.randn(1000, 10)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
    for epoch in range(10):
        # Re-seed the sampler so every epoch gets a different shuffle.
        sampler.set_epoch(epoch)
        for batch in dataloader:
            optimizer.zero_grad()
            loss = ddp_model(batch.to(rank)).sum()  # dummy loss for the demo
            loss.backward()
            optimizer.step()
    cleanup()
if __name__ == "__main__":
    # Launch one training process per visible GPU.
    n_gpus = torch.cuda.device_count()
    mp.spawn(train, args=(n_gpus,), nprocs=n_gpus, join=True)
使用 torchrun 启动 #
python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
def main():
    """torchrun entry point: one process per GPU, ranks read from env vars."""
    # torchrun has already populated WORLD_SIZE / RANK / MASTER_* for us,
    # so init_process_group needs no explicit arguments.
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # NOTE(review): ToyModel, dataloader, criterion and target are assumed
    # to be defined elsewhere in the training script.
    model = DDP(ToyModel().to(local_rank), device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(10):
        for batch in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(batch), target)
            loss.backward()
            optimizer.step()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
bash
torchrun --nproc_per_node=4 train.py
Fully Sharded Data Parallel #
python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
class LargeModel(nn.Module):
    """Deep stack of 100 identical 1024-wide Linear layers (FSDP demo)."""

    def __init__(self):
        super().__init__()
        width, depth = 1024, 100
        self.layers = nn.ModuleList(
            nn.Linear(width, width) for _ in range(depth)
        )

    def forward(self, x):
        """Apply every layer in sequence; shape stays (batch, 1024)."""
        for layer in self.layers:
            x = layer(x)
        return x
# Shard parameters, gradients and optimizer state across all ranks
# (requires an initialized process group and a CUDA device).
model = LargeModel()
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 style sharding
    device_id=torch.cuda.current_device(),
)
混合精度训练 #
自动混合精度 #
python
import torch
import torch.nn as nn
# torch.cuda.amp.autocast / GradScaler are deprecated; use the
# device-agnostic torch.amp API instead.
from torch.amp import autocast, GradScaler

model = nn.Linear(1000, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Scales the loss to keep fp16 gradients from underflowing.
scaler = GradScaler("cuda")

for epoch in range(10):
    for x, y in dataloader:
        x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()
        # Forward pass runs in mixed precision inside the autocast region.
        with autocast("cuda"):
            output = model(x)
            loss = nn.CrossEntropyLoss()(output, y)
        # step() unscales gradients first; update() adapts the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
bfloat16 精度 #
python
import torch
import torch.nn as nn

# Cast the whole model to bfloat16 (no GradScaler needed: bf16 keeps
# fp32's exponent range).
model = nn.Linear(1000, 10).cuda().to(torch.bfloat16)
x = torch.randn(32, 1000, dtype=torch.bfloat16, device="cuda")
output = model(x)

# Alternative: feed fp32 inputs and let autocast downcast inside the region.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(x.float())
模型量化 #
动态量化 #
python
import torch
import torch.nn as nn

# Float MLP to be quantized post-training.
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
model.eval()  # dynamic quantization is an inference-time transform

# Replace every nn.Linear with an int8-weight version; activations are
# quantized on the fly at runtime.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8,
)
# NOTE(review): parameters() of the quantized model does not expose the
# packed int8 weights, so this count understates the real size difference.
print(f"原始模型大小: {sum(p.numel() for p in model.parameters())}")
print(f"量化模型大小: {sum(p.numel() for p in quantized_model.parameters())}")
静态量化 #
python
import torch
import torch.nn as nn
class StaticQuantModel(nn.Module):
    """MLP wrapped with Quant/DeQuant stubs for post-training static quantization."""

    def __init__(self):
        super().__init__()
        # Stubs mark where tensors enter/leave the quantized domain; they
        # are identity ops until convert() replaces them.
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)  # fp32 -> int8 after convert()
        hidden = self.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return self.dequant(logits)  # int8 -> fp32 after convert()
model = StaticQuantModel()
model.eval()
# 'fbgemm' is the x86 server backend ('qnnpack' targets ARM).
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)  # insert observers
# Calibration pass: observers record activation ranges on representative data.
with torch.no_grad():
    for data in calibration_dataloader:
        model(data)
torch.quantization.convert(model, inplace=True)  # swap in int8 kernels
量化感知训练 #
python
import torch
import torch.nn as nn
class QATModel(nn.Module):
    """MLP with Quant/DeQuant stubs, trained with fake quantization (QAT)."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # Same topology as the static-quant example; the stubs let
        # prepare_qat() wrap the float ops with fake-quant modules.
        out = self.quant(x)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return self.dequant(out)
model = QATModel()
# The QAT qconfig inserts fake-quant observers that simulate int8 effects
# during training so the weights adapt to quantization error.
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = torch.quantization.prepare_qat(model, inplace=True)
for epoch in range(10):
    train_one_epoch(model, dataloader, criterion, optimizer)
model.eval()
quantized_model = torch.quantization.convert(model)  # finalize the int8 model
模型剪枝 #
非结构化剪枝 #
python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
# Zero out the 30% of first-layer weights with the smallest L1 magnitude.
prune.l1_unstructured(model[0], name='weight', amount=0.3)
print(f"剪枝后权重: {model[0].weight}")
# Make the pruning permanent: drop the mask/orig pair, keep the masked tensor.
prune.remove(model[0], 'weight')
结构化剪枝 #
python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Linear(10, 10)
# Remove 30% of whole rows (dim=0 = output units), ranked by L2 (n=2) norm.
prune.ln_structured(model, name='weight', amount=0.3, n=2, dim=0)
print(f"结构化剪枝后权重: {model.weight}")
全局剪枝 #
python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
# Prune across both Linear layers jointly: the 20% globally smallest
# weights (by L1 magnitude) are removed, so per-layer sparsity may differ.
parameters_to_prune = [
    (model[0], 'weight'),
    (model[2], 'weight'),
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
for module, name in parameters_to_prune:
    # Fixed: Tensor.numel() was misspelled as neel() in the original,
    # which raised AttributeError at runtime.
    print(f"剪枝比例: {torch.sum(module.weight == 0) / module.weight.numel()}")
知识蒸馏 #
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss combining hard-label CE and soft-label KL.

    loss = alpha * CE(student, labels)
         + (1 - alpha) * T^2 * KL(log_softmax(student/T) || softmax(teacher/T))
    """

    def __init__(self, temperature=3.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        # batchmean matches the mathematical definition of KL divergence.
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_output, teacher_output, labels):
        hard_loss = self.ce_loss(student_output, labels)
        t = self.temperature
        # KLDivLoss expects log-probabilities as input and probabilities as
        # target; the T^2 factor keeps gradient magnitudes T-independent.
        log_student = F.log_softmax(student_output / t, dim=1)
        soft_teacher = F.softmax(teacher_output / t, dim=1)
        soft_loss = self.kl_loss(log_student, soft_teacher) * (t ** 2)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
# NOTE(review): LargeModel / SmallModel / dataloader are assumed to be
# defined elsewhere; the teacher stays frozen in eval mode.
teacher = LargeModel().eval()
student = SmallModel()
criterion = DistillationLoss(temperature=3.0, alpha=0.5)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
for epoch in range(10):
    for x, y in dataloader:
        optimizer.zero_grad()
        # No gradients through the teacher: it only supplies soft targets.
        with torch.no_grad():
            teacher_output = teacher(x)
        loss = criterion(student(x), teacher_output, y)
        loss.backward()
        optimizer.step()
模型部署 #
TorchServe 部署 #
python
import torch
import torch.nn as nn
class Model(nn.Module):
    """Minimal 10 -> 2 linear classifier used by the deployment examples."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        """Map a (batch, 10) tensor to (batch, 2) logits."""
        return self.fc(x)
# Trace the model into TorchScript so TorchServe can load it without the
# original Python class definition.
model = Model()
model.eval()
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model.pt')
bash
torch-model-archiver --model-name mymodel \
--version 1.0 \
--model-file model.py \
--serialized-file model.pt \
--handler image_classifier
torchserve --start --model-store model_store --models mymodel.mar
TensorRT 加速 #
python
import torch
import torch.nn as nn
import torch_tensorrt

# Small CNN to compile with Torch-TensorRT.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10),
).cuda().eval()

# Dynamic batch dimension: TensorRT optimizes for opt_shape but accepts
# any batch size between min_shape and max_shape.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[32, 3, 224, 224],
        max_shape=[64, 3, 224, 224],
        dtype=torch.float32,
    )],
    enabled_precisions={torch.float16},  # allow fp16 kernels where profitable
)

x = torch.randn(32, 3, 224, 224).cuda()
output = trt_model(x)
ONNX Runtime 部署 #
python
import torch
import onnxruntime as ort
import numpy as np

# Export the (previously defined) Model to the ONNX format.
model = Model()
model.eval()
dummy_input = torch.randn(1, 10)
torch.onnx.export(model, dummy_input, "model.onnx")

# Run inference with ONNX Runtime, independently of PyTorch.
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
input_data = np.random.randn(1, 10).astype(np.float32)
outputs = session.run([output_name], {input_name: input_data})
性能优化技巧 #
编译优化 #
python
import torch

model = MyModel().cuda()
# Default mode: balanced compile time vs. speedup.
model = torch.compile(model)
# "reduce-overhead" targets small batches (uses CUDA graphs).
model = torch.compile(model, mode="reduce-overhead")
# "max-autotune" searches longer for the fastest kernels.
model = torch.compile(model, mode="max-autotune")
output = model(input)
内存优化 #
python
import torch

# Inference without building the autograd graph.
with torch.no_grad():
    output = model(x)

# set_to_none=True frees gradient memory instead of zero-filling it.
optimizer.zero_grad(set_to_none=True)

# non_blocking copies overlap host-to-device transfer with compute
# (only truly asynchronous when the source tensors are in pinned memory).
for x, y in dataloader:
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

# Return cached blocks to the driver (helps other processes, not PyTorch).
torch.cuda.empty_cache()
数据加载优化 #
python
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,            # parallel worker processes for loading
    pin_memory=True,          # page-locked host memory -> faster GPU copies
    prefetch_factor=2,        # batches each worker keeps ready in advance
    persistent_workers=True,  # keep workers alive across epochs
)
总结 #
本指南涵盖了 PyTorch 的高级主题:
- 分布式训练:DDP、FSDP 多 GPU 训练
- 混合精度:AMP、bfloat16 加速训练
- 模型量化:动态量化、静态量化、量化感知训练
- 模型剪枝:非结构化、结构化、全局剪枝
- 知识蒸馏:模型压缩技术
- 模型部署:TorchServe、TensorRT、ONNX Runtime
掌握这些高级技术,可以帮助你构建更高效、更强大的深度学习系统!
最后更新:2026-03-29