TensorFlow Training #
Overview #
TFJob is Kubeflow's custom resource for distributed TensorFlow training jobs. It supports two distribution modes: Parameter Server and AllReduce.
TFJob Architecture #
text
┌─────────────────────────────────────────────────────────────┐
│                      TFJob Architecture                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Parameter Server mode:                                     │
│  ┌─────────────────────────────────────────────────────┐    │
│  │                                                     │    │
│  │   ┌──────┐  ┌──────┐  ┌──────┐                      │    │
│  │   │ PS 0 │  │ PS 1 │  │ PS 2 │  Parameter Servers   │    │
│  │   └──┬───┘  └──┬───┘  └──┬───┘                      │    │
│  │      │         │         │                          │    │
│  │      └─────────┼─────────┘                          │    │
│  │                │                                    │    │
│  │      ┌─────────┼─────────┐                          │    │
│  │      │         │         │                          │    │
│  │   ┌──┴───┐  ┌──┴───┐  ┌──┴───┐                      │    │
│  │   │Worker│  │Worker│  │Worker│  Workers             │    │
│  │   │  0   │  │  1   │  │  2   │                      │    │
│  │   └──────┘  └──────┘  └──────┘                      │    │
│  │                                                     │    │
│  └─────────────────────────────────────────────────────┘    │
│                                                             │
│  AllReduce mode:                                            │
│  ┌─────────────────────────────────────────────────────┐    │
│  │                                                     │    │
│  │   ┌──────┐     ┌──────┐     ┌──────┐                │    │
│  │   │Chief │←───→│Worker│←───→│Worker│                │    │
│  │   │  0   │     │  1   │     │  2   │                │    │
│  │   └──────┘     └──────┘     └──────┘                │    │
│  │                                                     │    │
│  └─────────────────────────────────────────────────────┘    │
│                                                             │
└─────────────────────────────────────────────────────────────┘
TFJob Role Types #
text
TFJob roles:
Chief:
├── Lead node
├── Coordinates training
├── Saves checkpoints
├── Writes TensorBoard logs
└── Runs the training loop
PS (Parameter Server):
├── Parameter server
├── Stores model parameters
├── Aggregates and distributes parameters
└── Used in Parameter Server mode
Worker:
├── Worker node
├── Performs training computations
├── Processes data batches
└── Computes gradient updates
Evaluator:
├── Evaluation node
├── Validates the model
├── Computes evaluation metrics
└── Runs independently of the training loop
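The Training Operator tells each replica which of these roles it plays through the TF_CONFIG environment variable that it injects into every pod. A minimal sketch of reading it (the helper name is ours, not a TensorFlow API):

```python
import json
import os


def current_role():
    """Return (role, index) for this replica from the TF_CONFIG env var.

    TF_CONFIG is a JSON document with "cluster" and "task" keys; when it
    is absent (e.g. local debugging) we fall back to a single chief.
    """
    tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
    task = tf_config.get('task', {})
    return task.get('type', 'chief'), task.get('index', 0)
```

A training script can branch on this tuple, e.g. to run `server.join()` on PS replicas or to save checkpoints only on the chief.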
Creating a TFJob #
Basic Configuration #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: basic-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              command:
                - python
                - /opt/model/train.py
              args:
                - --epochs=10
                - --batch-size=32
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              command:
                - python
                - /opt/model/train.py
              args:
                - --epochs=10
                - --batch-size=32
Parameter Server Mode #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: ps-mode-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    PS:
      replicas: 2
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              resources:
                requests:
                  cpu: "2"
                  memory: "4Gi"
                limits:
                  cpu: "4"
                  memory: "8Gi"
    Worker:
      replicas: 4
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              resources:
                requests:
                  cpu: "4"
                  memory: "8Gi"
                limits:
                  cpu: "8"
                  memory: "16Gi"
AllReduce Mode #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: allreduce-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              # TF_CONFIG is injected automatically by the Training
              # Operator; no manual env configuration is needed.
    Worker:
      replicas: 4
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
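For reference, the TF_CONFIG that the operator injects into each replica is a JSON document of roughly the following shape. The hostnames below are illustrative; the real value points at the job's headless services (e.g. `allreduce-tfjob-worker-1.<namespace>.svc:2222`):

```python
import json

# Illustrative TF_CONFIG for Worker 1 of the job above (hostnames are
# placeholders; the Training Operator generates the real addresses).
tf_config = {
    "cluster": {
        "chief": ["allreduce-tfjob-chief-0:2222"],
        "worker": [
            "allreduce-tfjob-worker-0:2222",
            "allreduce-tfjob-worker-1:2222",
            "allreduce-tfjob-worker-2:2222",
            "allreduce-tfjob-worker-3:2222",
        ],
    },
    "task": {"type": "worker", "index": 1},
}
print(json.dumps(tf_config, indent=2))
```

`tf.distribute.cluster_resolver.TFConfigClusterResolver` reads exactly this variable, so training scripts usually never parse it by hand.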
GPU Training #
Single-Node Multi-GPU #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: single-node-gpu-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0-gpu
              resources:
                limits:
                  nvidia.com/gpu: 4
              volumeMounts:
                - name: data
                  mountPath: /data
          volumes:
            - name: data
              persistentVolumeClaim:
                claimName: training-data-pvc
Multi-Node Multi-GPU #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: multi-node-gpu-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0-gpu
              resources:
                limits:
                  nvidia.com/gpu: 4
    Worker:
      replicas: 4
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0-gpu
              resources:
                limits:
                  nvidia.com/gpu: 4
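When scaling out like this, keep in mind that under the mirrored strategies each GPU is one replica, so the effective (global) batch size grows with the total GPU count. A small sketch (the helper name is ours, not a TensorFlow API):

```python
def global_batch_size(per_replica_batch, num_nodes, gpus_per_node):
    """Effective batch size under (MultiWorker)MirroredStrategy, where
    every GPU processes one per-replica batch per step."""
    return per_replica_batch * num_nodes * gpus_per_node


# The job above runs 1 Chief + 4 Workers, each with 4 GPUs:
print(global_batch_size(64, 5, 4))  # 1280
```

A common rule of thumb is to scale the learning rate (e.g. linearly) along with the global batch size.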
Training Script Examples #
Distributed Training Script #
python
import json
import os

import tensorflow as tf


def setup_distributed_training():
    """Pick a distribution strategy based on the TF_CONFIG env var.

    With tf.distribute.ParameterServerStrategy in TF2, both "ps" and
    "worker" tasks run a tf.distribute.Server and block; only the
    "chief" acts as the coordinator and drives training.
    """
    tf_config = os.environ.get('TF_CONFIG')
    if not tf_config:
        # Single-process fallback (e.g. local debugging).
        return tf.distribute.get_strategy()

    config = json.loads(tf_config)
    task_type = config.get('task', {}).get('type')
    cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

    if task_type in ('ps', 'worker'):
        server = tf.distribute.Server(
            cluster_resolver.cluster_spec(),
            job_name=cluster_resolver.task_type,
            task_index=cluster_resolver.task_id,
            protocol='grpc')
        server.join()  # blocks until the job is torn down

    return tf.distribute.ParameterServerStrategy(cluster_resolver)


def create_model():
    return tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10),
    ])


def main():
    strategy = setup_distributed_training()
    with strategy.scope():
        model = create_model()
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'])

    mnist = tf.keras.datasets.mnist
    (x_train, y_train), _ = mnist.load_data()
    x_train = x_train / 255.0

    # ParameterServerStrategy requires tf.data input for model.fit;
    # plain NumPy arrays are not supported under this strategy.
    train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                .shuffle(60000).batch(64).repeat())

    model.fit(train_ds, epochs=10, steps_per_epoch=len(x_train) // 64)
    model.save('/output/model')


if __name__ == '__main__':
    main()
MultiWorkerMirroredStrategy #
python
import tensorflow as tf


def main():
    # Create the strategy first, before any other TensorFlow ops.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu',
                                   input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10),
        ])
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'])

    mnist = tf.keras.datasets.mnist
    (x_train, y_train), _ = mnist.load_data()
    x_train = x_train.reshape(-1, 28, 28, 1) / 255.0

    model.fit(x_train, y_train, epochs=10, batch_size=64)

    # Every worker must call save(), but only the chief should write
    # to the final output location (see the note below).
    model.save('/output/model')


if __name__ == '__main__':
    main()
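One subtlety in the script above: under MultiWorkerMirroredStrategy all workers must call `model.save()`, but only the chief should write the final artifact; the others write to throwaway scratch paths. A common convention, following the TensorFlow multi-worker tutorial (the helper name is ours):

```python
import json
import os


def model_save_path(base_dir='/output/model'):
    """Return base_dir on the chief and a scratch path elsewhere.

    A worker counts as chief if it is the designated "chief" task, or
    worker 0 in a cluster that has no separate chief replica.
    """
    tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
    task = tf_config.get('task', {})
    is_chief = task.get('type') in (None, 'chief') or (
        task.get('type') == 'worker' and task.get('index') == 0
        and 'chief' not in tf_config.get('cluster', {}))
    if is_chief:
        return base_dir
    return os.path.join('/tmp', f"workertemp_{task.get('index', 0)}")
```

Call `model.save(model_save_path())` on every worker; the scratch copies can be deleted once the job completes.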
Storage Configuration #
Data Storage #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: storage-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              volumeMounts:
                - name: training-data
                  mountPath: /data
                - name: model-output
                  mountPath: /output
                - name: tensorboard-logs
                  mountPath: /logs
          volumes:
            - name: training-data
              persistentVolumeClaim:
                claimName: training-data-pvc
            - name: model-output
              persistentVolumeClaim:
                claimName: model-output-pvc
            - name: tensorboard-logs
              persistentVolumeClaim:
                claimName: tensorboard-logs-pvc
ConfigMap Configuration #
yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: training-config
  namespace: kubeflow-user-example-com
data:
  config.yaml: |
    training:
      epochs: 100
      batch_size: 64
      learning_rate: 0.001
---
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: config-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              volumeMounts:
                - name: config
                  mountPath: /etc/config
          volumes:
            - name: config
              configMap:
                name: training-config
Evaluator Configuration #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: evaluator-tfjob
  namespace: kubeflow-user-example-com
spec:
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              command: ["python", "/opt/model/train.py"]
    Worker:
      replicas: 2
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              command: ["python", "/opt/model/train.py"]
    Evaluator:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
              command: ["python", "/opt/model/evaluate.py"]
Run Policy #
Cleanup Policy #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: clean-tfjob
  namespace: kubeflow-user-example-com
spec:
  runPolicy:
    cleanPodPolicy: Running  # None, Running, or All
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
Timeout and TTL #
yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: timeout-tfjob
  namespace: kubeflow-user-example-com
spec:
  runPolicy:
    activeDeadlineSeconds: 7200    # fail the job after 2 hours
    ttlSecondsAfterFinished: 3600  # clean up 1 hour after completion
  tfReplicaSpecs:
    Chief:
      replicas: 1
      template:
        spec:
          containers:
            - name: tensorflow
              image: tensorflow/tensorflow:2.12.0
Managing TFJobs #
Checking Status #
bash
# List TFJobs
kubectl get tfjobs -n kubeflow-user-example-com
# Show details
kubectl describe tfjob basic-tfjob -n kubeflow-user-example-com
# Show status
kubectl get tfjob basic-tfjob -n kubeflow-user-example-com -o jsonpath='{.status}'
# List the job's Pods
kubectl get pods -n kubeflow-user-example-com -l kubeflow.org/job-name=basic-tfjob
Viewing Logs #
bash
# Chief logs
kubectl logs basic-tfjob-chief-0 -n kubeflow-user-example-com
# Worker logs
kubectl logs basic-tfjob-worker-0 -n kubeflow-user-example-com
# PS logs
kubectl logs basic-tfjob-ps-0 -n kubeflow-user-example-com
# Stream logs
kubectl logs -f basic-tfjob-chief-0 -n kubeflow-user-example-com
Stopping and Deleting #
bash
# Delete a TFJob
kubectl delete tfjob basic-tfjob -n kubeflow-user-example-com
# List TFJobs in all namespaces
kubectl get tfjobs --all-namespaces
Best Practices #
Training Optimization #
text
1. Data pipeline optimization
├── Use the tf.data API
├── Prefetch and cache
├── Load data in parallel
└── Apply data augmentation
2. Choosing a distribution strategy
├── Single node, multiple GPUs: MirroredStrategy
├── Multiple nodes, multiple GPUs: MultiWorkerMirroredStrategy
└── Very large scale: ParameterServerStrategy
3. GPU optimization
├── Mixed-precision training
├── XLA compilation
└── Memory optimization
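Rule 2 above can be written down as a small decision helper. The function and its branching are ours, purely illustrative; the returned names are the corresponding `tf.distribute` strategy classes:

```python
def pick_strategy(num_nodes, gpus_per_node, huge_model=False):
    """Heuristic mapping from cluster topology to a tf.distribute
    strategy name, following the rules of thumb above (illustrative)."""
    if huge_model:
        # Model too large to replicate on every device: shard its
        # variables across parameter servers instead.
        return "ParameterServerStrategy"
    if num_nodes > 1:
        return "MultiWorkerMirroredStrategy"
    if gpus_per_node > 1:
        return "MirroredStrategy"
    return "OneDeviceStrategy"
```

In practice the choice also depends on network bandwidth and model size, so treat this as a starting point rather than a rule.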
Fault Recovery #
text
1. Checkpoint saving
├── Save at regular intervals
├── Save to persistent storage
└── Include the optimizer state
2. Resuming training
├── Load the latest checkpoint
├── Restore the training state
└── Continue training
3. Error handling
├── Restart policies
├── Error logs
└── Alert notifications
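For resuming, TensorFlow's `tf.train.latest_checkpoint` handles native checkpoint files; the underlying idea reduces to picking the highest-numbered checkpoint in the persistent directory. A stdlib sketch (the `ckpt-<step>` file naming is an assumption; adapt it to your own scheme):

```python
import os
import re


def latest_checkpoint(ckpt_dir):
    """Return the path of the highest-numbered 'ckpt-<step>' file in
    ckpt_dir, or None if the directory holds no checkpoints yet."""
    pattern = re.compile(r"ckpt-(\d+)")
    steps = [int(m.group(1)) for name in os.listdir(ckpt_dir)
             if (m := pattern.fullmatch(name))]
    if not steps:
        return None
    return os.path.join(ckpt_dir, f"ckpt-{max(steps)}")
```

A restarted Chief pod can call this on the mounted PVC and, if it returns a path, restore the model and optimizer state before continuing training.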
Next Steps #
Now that you know how to use TFJob, continue with PyTorch Training to learn how to configure distributed PyTorch training!
Last updated: 2026-04-05