TensorFlow Training #

Overview #

TFJob is the job type that Kubeflow provides for distributed TensorFlow training. It supports two distributed training modes: Parameter Server and AllReduce.

TFJob Architecture #

text
┌─────────────────────────────────────────────────────────────┐
│                     TFJob architecture                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Parameter server mode:                                     │
│  ┌───────────────────────────────────────────────────────┐  │
│  │                                                       │  │
│  │  ┌──────┐  ┌──────┐  ┌──────┐                         │  │
│  │  │ PS 0 │  │ PS 1 │  │ PS 2 │  Parameter Servers      │  │
│  │  └──┬───┘  └──┬───┘  └──┬───┘                         │  │
│  │     │         │         │                             │  │
│  │     └─────────┼─────────┘                             │  │
│  │               │                                       │  │
│  │     ┌─────────┼─────────┐                             │  │
│  │     │         │         │                             │  │
│  │  ┌──┴───┐  ┌──┴───┐  ┌──┴───┐                         │  │
│  │  │Worker│  │Worker│  │Worker│  Workers                │  │
│  │  │  0   │  │  1   │  │  2   │                         │  │
│  │  └──────┘  └──────┘  └──────┘                         │  │
│  │                                                       │  │
│  └───────────────────────────────────────────────────────┘  │
│                                                             │
│  AllReduce mode:                                            │
│  ┌───────────────────────────────────────────────────────┐  │
│  │                                                       │  │
│  │  ┌──────┐     ┌──────┐     ┌──────┐                   │  │
│  │  │Chief │←───→│Worker│←───→│Worker│                   │  │
│  │  │  0   │     │  1   │     │  2   │                   │  │
│  │  └──────┘     └──────┘     └──────┘                   │  │
│  │                                                       │  │
│  └───────────────────────────────────────────────────────┘  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

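TFJob Role Types #

text
TFJob roles:

Chief:
├── Primary node
├── Coordinates training
├── Saves checkpoints
├── Writes TensorBoard logs
└── Runs the training loop

PS (Parameter Server):
├── Parameter server
├── Stores model parameters
├── Aggregates and distributes parameters
└── Used in parameter server mode

Worker:
├── Worker node
├── Runs the training computation
├── Processes data batches
└── Computes and applies gradient updates

Evaluator:
├── Evaluation node
├── Validates the model
├── Computes evaluation metrics
└── Runs independently of the training process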
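
Each replica learns which of these roles it plays from the `TF_CONFIG` environment variable. The following is a minimal sketch of reading it, assuming the usual `cluster`/`task` JSON layout; `describe_task` is a hypothetical helper name and the host names in the example value are made up for illustration.

```python
import json


def describe_task(tf_config_json):
    """Summarize this replica's role from a TF_CONFIG JSON string."""
    config = json.loads(tf_config_json)
    cluster = config.get("cluster", {})
    task = config.get("task", {})
    task_type = task.get("type", "worker")
    task_index = int(task.get("index", 0))
    # With a Chief replica, the chief handles checkpoints and logging;
    # without one, worker 0 conventionally takes over that duty.
    has_chief = "chief" in cluster
    is_chief = task_type == "chief" or (
        not has_chief and task_type == "worker" and task_index == 0
    )
    return {
        "type": task_type,
        "index": task_index,
        "is_chief": is_chief,
        "num_workers": len(cluster.get("worker", [])),
        "num_ps": len(cluster.get("ps", [])),
    }


# Example value; the host names are illustrative only.
example = json.dumps({
    "cluster": {
        "chief": ["job-chief-0:2222"],
        "worker": ["job-worker-0:2222", "job-worker-1:2222"],
        "ps": ["job-ps-0:2222"],
    },
    "task": {"type": "worker", "index": 1},
})
print(describe_task(example))
```

In a real replica the string would come from `os.environ["TF_CONFIG"]` rather than a hand-written example.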

Creating a TFJob #

Basic Configuration #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: basic-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            command:
            - python
            - /opt/model/train.py
            args:
            - --epochs=10
            - --batch-size=32
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            command:
            - python
            - /opt/model/train.py
            args:
            - --epochs=10
            - --batch-size=32
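
The `args` entries in the manifest reach `train.py` as command-line flags. The original does not show how the script reads them, so the parser below is only a sketch of one way to do it, assuming the script uses `argparse`; the defaults simply mirror the values in the manifest.

```python
import argparse


def parse_args(argv=None):
    """Parse the flags that the TFJob manifest passes in `args`."""
    parser = argparse.ArgumentParser(description="Training entry point (sketch)")
    parser.add_argument("--epochs", type=int, default=10,
                        help="number of passes over the training set")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="batch size; exposed as args.batch_size")
    return parser.parse_args(argv)


# The same flags as in the manifest above.
args = parse_args(["--epochs=10", "--batch-size=32"])
print(args.epochs, args.batch_size)
```

Calling `parse_args()` with no argument reads `sys.argv`, which is what the container command would use.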

Parameter Server Mode #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: ps-mode-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    PS:
      replicas: 2
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            resources:
              requests:
                cpu: "2"
                memory: "4Gi"
              limits:
                cpu: "4"
                memory: "8Gi"
    Worker:
      replicas: 4
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            resources:
              requests:
                cpu: "4"
                memory: "8Gi"
              limits:
                cpu: "8"
                memory: "16Gi"
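
To make the manifest above more concrete, here is a sketch of the `cluster` section that a job with 2 PS and 4 Worker replicas would correspond to. The address pattern `<job>-<type>-<index>.<namespace>.svc:2222` is an assumption about how replica services are named and which port they use; check it against the `TF_CONFIG` that your operator version actually injects. `build_cluster_spec` is a hypothetical helper.

```python
import json


def build_cluster_spec(job_name, namespace, replicas, port=2222):
    """Build a TF_CONFIG-style cluster dict from replica counts.

    The host name pattern and the default port are assumptions made
    for illustration, not values taken from the operator.
    """
    cluster = {}
    for replica_type, count in replicas.items():
        role = replica_type.lower()
        cluster[role] = [
            f"{job_name}-{role}-{i}.{namespace}.svc:{port}"
            for i in range(count)
        ]
    return cluster


cluster = build_cluster_spec(
    "ps-mode-tfjob", "kubeflow-user-example-com", {"PS": 2, "Worker": 4}
)
print(json.dumps(cluster, indent=2))
```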

AllReduce Mode #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: allreduce-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            # TF_CONFIG is injected into each replica by the training operator,
            # so it does not need to be declared in the manifest.
    Worker:
      replicas: 4
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0

GPU Training #

Single Node, Multiple GPUs #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: single-node-gpu-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0-gpu
            resources:
              limits:
                nvidia.com/gpu: 4
            volumeMounts:
            - name: data
              mountPath: /data
          volumes:
          - name: data
            persistentVolumeClaim:
              claimName: training-data-pvc

Multiple Nodes, Multiple GPUs #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: multi-node-gpu-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0-gpu
            resources:
              limits:
                nvidia.com/gpu: 4
    Worker:
      replicas: 4
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0-gpu
            resources:
              limits:
                nvidia.com/gpu: 4
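
The manifest above requests 4 GPUs for each of 5 replicas (1 Chief plus 4 Workers). With synchronous data-parallel training the batch size of the input dataset is generally treated as the global batch and split across replicas, so it helps to derive it from the GPU count. A small sketch of that arithmetic follows; the per-GPU batch of 32 is an illustrative value, not one from the original.

```python
def total_gpus(replica_counts, gpus_per_replica):
    """Total GPUs requested across all replica types."""
    return sum(replica_counts.values()) * gpus_per_replica


def global_batch_size(per_gpu_batch, replica_counts, gpus_per_replica):
    """Global batch size when every GPU processes per_gpu_batch examples."""
    return per_gpu_batch * total_gpus(replica_counts, gpus_per_replica)


replicas = {"Chief": 1, "Worker": 4}  # replica counts from the manifest above
gpus = total_gpus(replicas, gpus_per_replica=4)
batch = global_batch_size(32, replicas, gpus_per_replica=4)  # 32 is illustrative
print(gpus, batch)
```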

Training Script Examples #

Distributed Training Script #

python
import tensorflow as tf
import json
import os

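def setup_distributed_training():
    """Return a tf.distribute strategy based on TF_CONFIG.

    PS and Worker replicas start a TensorFlow server and block; only the
    Chief (the coordinator) gets a ParameterServerStrategy back.
    """
    tf_config = os.environ.get('TF_CONFIG')
    if tf_config:
        task_type = json.loads(tf_config).get('task', {}).get('type')
        cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

        if task_type in ('ps', 'worker'):
            # In TF 2.x parameter server training, PS and Worker tasks only
            # run a server; the chief dispatches the training steps to them.
            server = tf.distribute.Server(
                cluster_resolver.cluster_spec(),
                job_name=cluster_resolver.task_type,
                task_index=cluster_resolver.task_id,
                protocol=cluster_resolver.rpc_layer or 'grpc',
                start=True
            )
            server.join()  # blocks for the lifetime of the pod
        elif task_type == 'chief':
            return tf.distribute.ParameterServerStrategy(cluster_resolver)

    # No TF_CONFIG (or an unhandled role): fall back to the default strategy.
    return tf.distribute.get_strategy()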

def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10)
    ])
    return model

def main():
    strategy = setup_distributed_training()
    
    with strategy.scope():
        model = create_model()
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )
    
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
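    batch_size = 64
    model.fit(
        x_train, y_train,
        epochs=10,
        batch_size=batch_size,
        # ParameterServerStrategy needs explicit step counts because the
        # coordinator cannot infer where an epoch ends.
        steps_per_epoch=len(x_train) // batch_size,
        validation_data=(x_test, y_test),
        validation_steps=len(x_test) // batch_size
    )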
    
    model.save('/output/model')

if __name__ == '__main__':
    main()

MultiWorkerMirroredStrategy #

python
import tensorflow as tf
import os

def main():
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    
    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10)
        ])
        
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )
    
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), _ = mnist.load_data()
    x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
    
    model.fit(x_train, y_train, epochs=10, batch_size=64)
    
    model.save('/output/model')

if __name__ == '__main__':
    main()
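
With `MultiWorkerMirroredStrategy`, every worker runs the same script and therefore every worker calls `model.save`. TensorFlow's multi-worker tutorial handles this by letting the chief write to the real output path while the other workers write to a scratch directory. The sketch below follows that idea; the helper names and the `/tmp` scratch location are assumptions for illustration.

```python
import json
import os


def task_from_env(environ=None):
    """Return (task_type, task_index, cluster_has_chief) from TF_CONFIG."""
    environ = os.environ if environ is None else environ
    tf_config = json.loads(environ.get("TF_CONFIG") or "{}")
    task = tf_config.get("task", {})
    has_chief = "chief" in tf_config.get("cluster", {})
    return task.get("type"), int(task.get("index", 0)), has_chief


def is_chief(task_type, task_index, cluster_has_chief):
    """Decide whether this process should write the final artifacts."""
    if task_type is None:  # single-process run without TF_CONFIG
        return True
    if cluster_has_chief:
        return task_type == "chief"
    return task_type == "worker" and task_index == 0


def save_path(base_dir, task_type, task_index, cluster_has_chief):
    """Return base_dir for the chief and a scratch directory for the rest."""
    if is_chief(task_type, task_index, cluster_has_chief):
        return base_dir
    name = os.path.basename(base_dir.rstrip("/"))
    return os.path.join("/tmp", f"{task_type}-{task_index}", name)


print(save_path("/output/model", "chief", 0, True))
print(save_path("/output/model", "worker", 1, True))
```

In the training script this would replace the fixed path, for example `model.save(save_path('/output/model', *task_from_env()))`.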

Storage Configuration #

Data Storage #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: storage-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            volumeMounts:
            - name: training-data
              mountPath: /data
            - name: model-output
              mountPath: /output
            - name: tensorboard-logs
              mountPath: /logs
          volumes:
          - name: training-data
            persistentVolumeClaim:
              claimName: training-data-pvc
          - name: model-output
            persistentVolumeClaim:
              claimName: model-output-pvc
          - name: tensorboard-logs
            persistentVolumeClaim:
              claimName: tensorboard-logs-pvc

ConfigMap Configuration #

yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: training-config
  namespace: kubeflow-user-example-com
data:
  config.yaml: |
    training:
      epochs: 100
      batch_size: 64
      learning_rate: 0.001
---
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: config-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            volumeMounts:
            - name: config
              mountPath: /etc/config
          volumes:
          - name: config
            configMap:
              name: training-config

Evaluator Configuration #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: evaluator-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            command: ["python", "/opt/model/train.py"]
    Worker:
      replicas: 2
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            command: ["python", "/opt/model/train.py"]
    Evaluator:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
            command: ["python", "/opt/model/evaluate.py"]

Run Policy #

Cleanup Policy #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: clean-tfjob
  namespace: kubeflow-user-example-com
spec:
  runPolicy:
    cleanPodPolicy: Running  # None, Running, All
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0

Timeouts and TTL #

yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: timeout-tfjob
  namespace: kubeflow-user-example-com
spec:
  runPolicy:
    activeDeadlineSeconds: 7200  # 2-hour timeout
    ttlSecondsAfterFinished: 3600  # clean up 1 hour after the job finishes
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
          - name: tensorflow
            image: tensorflow/tensorflow:2.12.0
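
Taken together, the two `runPolicy` values above put an upper bound on how long the job's resources can exist. The sketch below works through that arithmetic, assuming the deadline is measured from the job's start time; the start timestamp is an illustrative value and `run_policy_timeline` is a hypothetical helper.

```python
from datetime import datetime, timedelta, timezone


def run_policy_timeline(start_time, active_deadline_seconds, ttl_seconds_after_finished):
    """Return (latest_finish, latest_cleanup) for a job started at start_time."""
    latest_finish = start_time + timedelta(seconds=active_deadline_seconds)
    latest_cleanup = latest_finish + timedelta(seconds=ttl_seconds_after_finished)
    return latest_finish, latest_cleanup


start = datetime(2024, 1, 1, 9, 0, tzinfo=timezone.utc)  # illustrative start time
finish, cleanup = run_policy_timeline(start, 7200, 3600)
print(finish.isoformat(), cleanup.isoformat())
```

A job that finishes early is cleaned up earlier, since the TTL counts from the moment the job finishes rather than from the deadline.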

Managing TFJobs #

Checking Status #

bash
# List TFJobs
kubectl get tfjobs -n kubeflow-user-example-com

# Show details
kubectl describe tfjob basic-tfjob -n kubeflow-user-example-com

# Show the status
kubectl get tfjob basic-tfjob -n kubeflow-user-example-com -o jsonpath='{.status}'

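# List the job's Pods (recent training-operator releases label Pods with
# training.kubeflow.org/job-name; older releases may use a different key)
kubectl get pods -n kubeflow-user-example-com -l training.kubeflow.org/job-name=basic-tfjob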
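
The `.status` object returned by the commands above carries a list of conditions. A rough sketch of turning that list into a single phase is shown below; it assumes the conditions are ordered oldest to newest and use condition types such as `Created`, `Running`, `Succeeded` and `Failed`. The sample status is hand-written for illustration, not captured from a cluster.

```python
def current_phase(status):
    """Return the type of the most recent condition whose status is "True"."""
    conditions = status.get("conditions", [])
    active = [c["type"] for c in conditions if c.get("status") == "True"]
    return active[-1] if active else "Unknown"


def is_finished(status):
    """True once the job has reached a terminal phase."""
    return current_phase(status) in ("Succeeded", "Failed")


# Hand-written sample, shaped like the status of a completed job.
sample_status = {
    "conditions": [
        {"type": "Created", "status": "True"},
        {"type": "Running", "status": "False"},
        {"type": "Succeeded", "status": "True"},
    ]
}
print(current_phase(sample_status), is_finished(sample_status))
```

To use it against a live job, parse the output of `kubectl get tfjob basic-tfjob -n kubeflow-user-example-com -o json` with `json.loads` and pass in its `status` field.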

Viewing Logs #

bash
# Chief logs
kubectl logs basic-tfjob-chief-0 -n kubeflow-user-example-com

# Worker logs
kubectl logs basic-tfjob-worker-0 -n kubeflow-user-example-com

# PS logs (basic-tfjob has no PS replicas, so this uses ps-mode-tfjob)
kubectl logs ps-mode-tfjob-ps-0 -n kubeflow-user-example-com

# Stream logs in real time
kubectl logs -f basic-tfjob-chief-0 -n kubeflow-user-example-com

Stopping and Deleting #

bash
# Delete a TFJob
kubectl delete tfjob basic-tfjob -n kubeflow-user-example-com

# List TFJobs in all namespaces
kubectl get tfjobs --all-namespaces

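Best Practices #

Training Optimization #

text
1. Data pipeline optimization
   ├── Use the tf.data API
   ├── Prefetch and cache
   ├── Load data in parallel
   └── Data augmentation

2. Choosing a distribution strategy
   ├── Single node, multiple GPUs: MirroredStrategy
   ├── Multiple nodes, multiple GPUs: MultiWorkerMirroredStrategy
   └── Large scale: ParameterServerStrategy

3. GPU optimization
   ├── Mixed-precision training
   ├── XLA compilation
   └── Memory optimization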
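
The strategy choice in item 2 can be written down as a small decision function. This is only a sketch of the rule of thumb listed above; `choose_strategy` is a hypothetical helper that returns the name of a `tf.distribute` class rather than instantiating it, and the `use_parameter_servers` switch stands in for whatever "large scale" means for a given workload.

```python
def choose_strategy(num_nodes, gpus_per_node, use_parameter_servers=False):
    """Map a cluster shape to the name of a tf.distribute strategy class."""
    if use_parameter_servers:
        return "ParameterServerStrategy"
    if num_nodes > 1:
        return "MultiWorkerMirroredStrategy"
    if gpus_per_node > 1:
        return "MirroredStrategy"
    return "default"  # a single device needs no distribution strategy


print(choose_strategy(num_nodes=1, gpus_per_node=4))
print(choose_strategy(num_nodes=5, gpus_per_node=4))
print(choose_strategy(num_nodes=50, gpus_per_node=0, use_parameter_servers=True))
```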

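Fault Recovery #

text
1. Checkpointing
   ├── Save at regular intervals
   ├── Save to persistent storage
   └── Save the optimizer state

2. Resuming training
   ├── Load the latest checkpoint
   ├── Restore the training state
   └── Continue training

3. Error handling
   ├── Restart policy
   ├── Error logs
   └── Alert notifications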
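
The resume step in item 2 comes down to finding the newest checkpoint on persistent storage and restarting from its epoch. Below is a minimal sketch using only the standard library; the `ckpt-<epoch>.weights.h5` naming scheme is a hypothetical convention (epochs numbered from 1), and the demo uses empty placeholder files instead of real checkpoints.

```python
import os
import re
import tempfile

# Hypothetical naming scheme: ckpt-0001.weights.h5, ckpt-0002.weights.h5, ...
CKPT_PATTERN = re.compile(r"^ckpt-(\d+)\.weights\.h5$")


def latest_checkpoint(ckpt_dir):
    """Return (path, epoch) of the newest checkpoint, or (None, 0) if none."""
    best_path, best_epoch = None, 0
    if not os.path.isdir(ckpt_dir):
        return best_path, best_epoch
    for name in os.listdir(ckpt_dir):
        match = CKPT_PATTERN.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(ckpt_dir, name)
    return best_path, best_epoch


# Demo with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as ckpt_dir:
    for epoch in (1, 2, 10):
        open(os.path.join(ckpt_dir, f"ckpt-{epoch:04d}.weights.h5"), "w").close()
    path, epoch = latest_checkpoint(ckpt_dir)
    print(os.path.basename(path), epoch)
```

In a Keras training script the result could feed `model.load_weights(path)` followed by `model.fit(..., initial_epoch=epoch)`, so that a restarted Pod continues rather than starting over.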

Next Steps #

Now that you know how to use TFJob, continue with PyTorch training to learn how to configure distributed PyTorch jobs.

Last updated: 2026-04-05