PyTorch Training #

Overview #

PyTorchJob is the distributed training job type that Kubeflow provides for PyTorch. It follows a Master-Worker architecture and trains with PyTorch's DistributedDataParallel (DDP).
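
Before creating any jobs, you can verify that the Training Operator's PyTorchJob CRD is installed in the cluster:

bash
kubectl get crd pytorchjobs.kubeflow.org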

PyTorchJob Architecture #

text
┌─────────────────────────────────────────────────────────────┐
│                   PyTorchJob Architecture                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Master-Worker architecture:                                │
│  ┌─────────────────────────────────────────────────────┐   │
│  │                                                     │   │
│  │  ┌──────────┐                                       │   │
│  │  │  Master  │  Coordinates training, saves model    │   │
│  │  │  (Rank 0)│                                       │   │
│  │  └────┬─────┘                                       │   │
│  │       │                                             │   │
│  │       │  NCCL/Gloo communication                    │   │
│  │       │                                             │   │
│  │  ┌────┴────┐  ┌─────────┐  ┌─────────┐             │   │
│  │  │ Worker  │  │ Worker  │  │ Worker  │             │   │
│  │  │ (Rank 1)│  │ (Rank 2)│  │ (Rank 3)│             │   │
│  │  └─────────┘  └─────────┘  └─────────┘             │   │
│  │                                                     │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
│  Communication backends:                                    │
│  ├── NCCL - GPU communication (recommended)                 │
│  ├── Gloo - CPU communication                               │
│  └── MPI  - high-performance computing                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

PyTorchJob Roles #

text
PyTorchJob roles:

Master:
├── Primary node (Rank 0)
├── Coordinates distributed training
├── Saves checkpoints and the final model
├── Writes logs
└── Initializes the process group

Worker:
├── Worker nodes (Rank 1+)
├── Run the training computation
├── Process data batches
├── Synchronize gradients
└── Apply model updates
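
For either role, the Training Operator injects the standard PyTorch rendezvous environment variables into the container, which is what lets init_process_group(init_method='env://') in the scripts below work without extra flags:

text
MASTER_ADDR  - hostname of the Master (Rank 0) Pod
MASTER_PORT  - port the Master listens on for rendezvous
WORLD_SIZE   - total number of replicas in the job
RANK         - rank of the current replica (the Master is 0)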

Creating a PyTorchJob #

Basic Configuration #

yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: basic-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0
            command:
            - python
            - /opt/model/train.py
            args:
            - --epochs=10
            - --batch-size=32
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0
            command:
            - python
            - /opt/model/train.py
            args:
            - --epochs=10
            - --batch-size=32
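
Assuming the manifest above is saved as basic-pytorchjob.yaml, submit and check it with kubectl:

bash
kubectl apply -f basic-pytorchjob.yaml
kubectl get pytorchjob basic-pytorchjob -n kubeflow-user-example-com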

GPU Training Configuration #

yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: gpu-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
            resources:
              requests:
                cpu: "4"
                memory: "16Gi"
                nvidia.com/gpu: "2"
              limits:
                cpu: "8"
                memory: "32Gi"
                nvidia.com/gpu: "2"
            env:
            - name: NCCL_DEBUG
              value: "INFO"
    Worker:
      replicas: 4
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
            resources:
              limits:
                nvidia.com/gpu: "2"

Multi-Node Multi-GPU #

yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: multi-node-gpu-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
            resources:
              limits:
                nvidia.com/gpu: "4"
            volumeMounts:
            - name: data
              mountPath: /data
            - name: output
              mountPath: /output
          volumes:
          - name: data
            persistentVolumeClaim:
              claimName: training-data-pvc
          - name: output
            persistentVolumeClaim:
              claimName: model-output-pvc
    Worker:
      replicas: 4
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
            resources:
              limits:
                nvidia.com/gpu: "4"
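
With four GPUs per Pod, each Pod has to start one training process per GPU. One common pattern (a sketch, not the only option) is to make torchrun the container command and let it spawn the per-GPU processes from the operator-injected variables:

yaml
# Hypothetical container command for the spec above (5 Pods x 4 GPUs = 20 processes)
command:
- torchrun
- --nnodes=5                     # 1 Master + 4 Workers
- --nproc_per_node=4             # one process per GPU
- --node_rank=$(RANK)            # Pod rank injected by the Training Operator
- --master_addr=$(MASTER_ADDR)
- --master_port=$(MASTER_PORT)
- /opt/model/train.py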

Training Script Examples #

Basic Distributed Training #

python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import os

def setup():
    # MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK are injected by the
    # Training Operator, so env:// initialization needs no extra arguments.
    dist.init_process_group(
        backend='nccl',
        init_method='env://'
    )
    # Bind this process to its local GPU (LOCAL_RANK defaults to 0 when unset).
    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))

def cleanup():
    dist.destroy_process_group()

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def main():
    setup()
    
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    
    if rank == 0:
        print(f"World size: {world_size}")
    
    model = SimpleModel().cuda()
    ddp_model = DDP(model, device_ids=[local_rank])
    
    optimizer = optim.Adam(ddp_model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    # Random data standing in for a real dataset; TensorDataset makes the
    # (features, labels) pair indexable by the DataLoader.
    train_dataset = TensorDataset(
        torch.randn(10000, 784),
        torch.randint(0, 10, (10000,))
    )
    # DistributedSampler gives each rank a disjoint shard of the dataset.
    train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        sampler=train_sampler
    )
    
    for epoch in range(10):
        train_sampler.set_epoch(epoch)
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            if batch_idx % 100 == 0 and rank == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
    
    if rank == 0:
        torch.save(model.state_dict(), '/output/model.pth')
    
    cleanup()

if __name__ == '__main__':
    main()
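
When the basic PyTorchJob above runs this script, each Pod executes exactly one copy of it; the injected MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK variables make the env:// initialization work, so world_size equals the total replica count (1 Master + 2 Workers = 3).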

Using torchrun #

python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import os

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def main():
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    
    dist.init_process_group(backend='nccl')
    
    model = CNN().cuda()
    ddp_model = DDP(model, device_ids=[local_rank])
    
    optimizer = optim.Adam(ddp_model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    # Random images/labels as a stand-in dataset; a bare tuple of tensors is
    # not a valid DataLoader argument, so wrap it in a TensorDataset.
    train_dataset = TensorDataset(
        torch.randn(10000, 1, 28, 28),
        torch.randint(0, 10, (10000,))
    )
    train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        sampler=train_sampler
    )
    
    for epoch in range(10):
        train_sampler.set_epoch(epoch)  # reshuffle the shards every epoch
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    
    if dist.get_rank() == 0:
        torch.save(model.state_dict(), '/output/model.pth')
    
    dist.destroy_process_group()

if __name__ == '__main__':
    main()
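
Because this script reads LOCAL_RANK from the environment, it can also be smoke-tested outside the cluster with torchrun on a single multi-GPU machine (the path is a placeholder):

bash
# Spawn 2 local processes; torchrun sets LOCAL_RANK, RANK, WORLD_SIZE
# and the rendezvous variables for each of them.
torchrun --standalone --nproc_per_node=2 /opt/model/train.py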

Using torch.distributed.launch #

yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: launch-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
            command:
            - python
            - -m
            - torch.distributed.launch   # deprecated in recent PyTorch; prefer torchrun
            - --nproc_per_node=4
            - --nnodes=4
            - --node_rank=0              # correct only for the Master; each Worker needs its own rank
            - --master_addr=$(MASTER_ADDR)   # expanded from env vars injected by the operator
            - --master_port=$(MASTER_PORT)
            - /opt/model/train.py

Storage Configuration #

Data and Model Storage #

yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: storage-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0
            volumeMounts:
            - name: training-data
              mountPath: /data
              readOnly: true
            - name: model-output
              mountPath: /output
            - name: checkpoints
              mountPath: /checkpoints
          volumes:
          - name: training-data
            persistentVolumeClaim:
              claimName: training-data-pvc
          - name: model-output
            persistentVolumeClaim:
              claimName: model-output-pvc
          - name: checkpoints
            persistentVolumeClaim:
              claimName: checkpoints-pvc

Shared Storage Configuration #

yaml
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: shared-data-pvc
  namespace: kubeflow-user-example-com
spec:
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 100Gi
  storageClassName: nfs-storage
---
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: shared-storage-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0
            volumeMounts:
            - name: shared-data
              mountPath: /shared
          volumes:
          - name: shared-data
            persistentVolumeClaim:
              claimName: shared-data-pvc
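
Before starting the job, it is worth confirming that the PVC actually binds (ReadWriteMany support depends on the storage class):

bash
kubectl get pvc shared-data-pvc -n kubeflow-user-example-com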

Environment Variable Configuration #

NCCL Configuration #

yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: nccl-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
            env:
            - name: NCCL_DEBUG
              value: "INFO"              # print NCCL topology/transport info to stdout
            - name: NCCL_SOCKET_IFNAME
              value: "eth0"              # network interface used for NCCL bootstrap
            - name: NCCL_IB_DISABLE
              value: "0"                 # 0 = InfiniBand transport enabled
            - name: NCCL_IB_HCA
              value: "mlx5_0,mlx5_1"     # InfiniBand HCAs to use
    Worker:
      replicas: 4
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
            env:
            - name: NCCL_DEBUG
              value: "INFO"

Training Configuration #

yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: training-config
  namespace: kubeflow-user-example-com
data:
  EPOCHS: "100"
  BATCH_SIZE: "64"
  LEARNING_RATE: "0.001"
---
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: config-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0
            envFrom:
            - configMapRef:
                name: training-config
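
envFrom exposes every key of the ConfigMap as an environment variable, so the training script can read the hyperparameters at startup. A minimal sketch (the keys match the ConfigMap above):

python
import os

# Hyperparameters injected from the training-config ConfigMap via envFrom.
epochs = int(os.environ.get('EPOCHS', '10'))
batch_size = int(os.environ.get('BATCH_SIZE', '32'))
learning_rate = float(os.environ.get('LEARNING_RATE', '0.001'))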

Scheduling Configuration #

Node Selection #

yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: node-selector-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          nodeSelector:
            accelerator: nvidia-tesla-v100
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime

Affinity and Tolerations #

yaml
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: affinity-pytorchjob
  namespace: kubeflow-user-example-com
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      template:
        spec:
          affinity:
            nodeAffinity:
              requiredDuringSchedulingIgnoredDuringExecution:
                nodeSelectorTerms:
                - matchExpressions:
                  - key: gpu-type
                    operator: In
                    values:
                    - v100
                    - a100
            podAntiAffinity:
              preferredDuringSchedulingIgnoredDuringExecution:
              - weight: 100
                podAffinityTerm:
                  labelSelector:
                    matchLabels:
                      app: pytorch-training
                  topologyKey: kubernetes.io/hostname
          tolerations:
          - key: "nvidia.com/gpu"
            operator: "Exists"
            effect: "NoSchedule"
          containers:
          - name: pytorch
            image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
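
After submission, check where the Pods landed to confirm the affinity and toleration rules behaved as intended:

bash
kubectl get pods -n kubeflow-user-example-com -l kubeflow.org/job-name=affinity-pytorchjob -o wide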

Managing PyTorchJobs #

Checking Status #

bash
# List PyTorchJobs
kubectl get pytorchjobs -n kubeflow-user-example-com

# Describe a job
kubectl describe pytorchjob basic-pytorchjob -n kubeflow-user-example-com

# Inspect the status subresource
kubectl get pytorchjob basic-pytorchjob -n kubeflow-user-example-com -o jsonpath='{.status}'

# List the job's Pods
kubectl get pods -n kubeflow-user-example-com -l kubeflow.org/job-name=basic-pytorchjob

Viewing Logs #

bash
# Master logs
kubectl logs basic-pytorchjob-master-0 -n kubeflow-user-example-com

# Worker logs
kubectl logs basic-pytorchjob-worker-0 -n kubeflow-user-example-com

# Stream logs
kubectl logs -f basic-pytorchjob-master-0 -n kubeflow-user-example-com

# Tail logs from every Pod in the job
for pod in $(kubectl get pods -n kubeflow-user-example-com -l kubeflow.org/job-name=basic-pytorchjob -o name); do
  echo "=== $pod ==="
  kubectl logs $pod -n kubeflow-user-example-com --tail=50
done

Stopping and Deleting #

bash
# Delete the PyTorchJob
kubectl delete pytorchjob basic-pytorchjob -n kubeflow-user-example-com

# Force delete (skips graceful termination)
kubectl delete pytorchjob basic-pytorchjob -n kubeflow-user-example-com --force --grace-period=0

Best Practices #

Training Optimization #

text
1. Data loading
   ├── Use DistributedSampler
   ├── Multi-process loading (DataLoader num_workers)
   ├── Prefetching
   └── In-memory caching

2. GPU utilization (a mixed-precision sketch follows this list)
   ├── Mixed-precision training
   ├── Gradient accumulation
   ├── Gradient checkpointing
   └── Model parallelism

3. Communication
   ├── Gradient compression
   ├── Overlapping computation and communication
   └── Synchronized batch normalization (SyncBatchNorm)
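
As a concrete example of the first GPU item, here is a minimal mixed-precision training step. It assumes ddp_model, optimizer, criterion and train_loader are set up as in the earlier DDP examples:

python
import torch

scaler = torch.cuda.amp.GradScaler()         # scales the loss to avoid fp16 underflow

for data, target in train_loader:
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # run the forward pass in float16 where safe
        output = ddp_model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()            # backward on the scaled loss
    scaler.step(optimizer)                   # unscales gradients, then steps
    scaler.update()                          # adjust the scale factor for the next step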

Fault Recovery #

python
import os

import torch
import torch.distributed as dist

def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'] + 1  # resume from the epoch after the saved one
    return 0

def main():
    # model, optimizer, total_epochs and train_one_epoch come from a training
    # setup like the earlier examples; this sketch shows only the resume logic.
    start_epoch = load_checkpoint(model, optimizer, '/checkpoints/latest.pth')

    for epoch in range(start_epoch, total_epochs):
        train_one_epoch()

        # Only rank 0 writes, so replicas never race on the shared volume.
        if dist.get_rank() == 0:
            save_checkpoint(model, optimizer, epoch, '/checkpoints/latest.pth')
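
Combined with restartPolicy: OnFailure in the PyTorchJob spec and the /checkpoints PVC shown earlier, this pattern lets a restarted Pod resume from the latest checkpoint instead of starting over.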

Next Steps #

Now that you have a working grasp of PyTorchJob, continue with MPI distributed training to learn how distributed training is done in high-performance computing scenarios!

Last updated: 2026-04-05