Distributed Training with Ray Train #
What is Ray Train? #
Ray Train is Ray's distributed training framework. It streamlines training machine learning models across a cluster, supports the mainstream deep learning frameworks, and provides a single, unified API for distributed training.
text
┌─────────────────────────────────────────────────────────────┐
│                   Ray Train Architecture                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  User code                                                  │
│  ┌─────────────────────────────────────────────────────┐    │
│  │  train_fn(config)                                   │    │
│  │    ├── data loading                                 │    │
│  │    ├── model definition                             │    │
│  │    ├── training loop                                │    │
│  │    └── checkpoint saving                            │    │
│  └─────────────────────────────────────────────────────┘    │
│                            │                                │
│                            ▼                                │
│  ┌─────────────────────────────────────────────────────┐    │
│  │                       Trainer                       │    │
│  │   ┌─────────────┐   ┌─────────────┐                 │    │
│  │   │  Worker 0   │   │  Worker 1   │   ...           │    │
│  │   │   (GPU 0)   │   │   (GPU 1)   │                 │    │
│  │   └─────────────┘   └─────────────┘                 │    │
│  └─────────────────────────────────────────────────────┘    │
│                            │                                │
│                            ▼                                │
│  ┌─────────────────────────────────────────────────────┐    │
│  │                   ScalingConfig                     │    │
│  │    ├── num_workers: number of training workers      │    │
│  │    ├── use_gpu: whether to use GPUs                 │    │
│  │    └── resources_per_worker: resources per worker   │    │
│  └─────────────────────────────────────────────────────┘    │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Basic Usage #
Simple Training #
python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ray.init()

def train_fn(config):
    # Each worker runs this function; metrics go back to the Trainer.
    for epoch in range(10):
        loss = 1.0 / (epoch + 1)  # dummy loss that decreases over time
        train.report({"loss": loss, "epoch": epoch})

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=ScalingConfig(num_workers=1),
)
result = trainer.fit()
print(f"Final loss: {result.metrics['loss']}")
ray.shutdown()
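`fit()` returns a `ray.train.Result`; beyond the final metrics it exposes a few attributes worth knowing about:
python
# After result = trainer.fit():
print(result.metrics)     # last metrics dict passed to train.report
print(result.checkpoint)  # latest checkpoint, or None if none was saved
print(result.path)        # storage directory for this run
print(result.error)       # the exception if the run failed, else None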
Passing a Config #
python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ray.init()

def train_fn(config):
    # Hyperparameters arrive via train_loop_config.
    lr = config["lr"]
    epochs = config["epochs"]
    for epoch in range(epochs):
        loss = lr / (epoch + 1)
        train.report({"loss": loss})

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"lr": 0.01, "epochs": 20},
    scaling_config=ScalingConfig(num_workers=1),
)
result = trainer.fit()
ray.shutdown()
PyTorch Integration #
Basic Training #
python
import ray
import torch
import torch.nn as nn
import torch.optim as optim
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ray.init()

def train_fn(config):
    model = nn.Sequential(
        nn.Linear(10, 32),
        nn.ReLU(),
        nn.Linear(32, 1),
    )
    # Wraps the model in DistributedDataParallel and moves it to this
    # worker's device.
    model = train.torch.prepare_model(model)
    device = train.torch.get_device()
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    loss_fn = nn.MSELoss()
    for epoch in range(config["epochs"]):
        # Synthetic batch, created on the same device as the model.
        inputs = torch.randn(32, 10, device=device)
        labels = torch.randn(32, 1, device=device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        train.report({"loss": loss.item(), "epoch": epoch})

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"lr": 0.001, "epochs": 10},
    # use_gpu=True requires GPUs in the cluster; set False to run on CPU.
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
result = trainer.fit()
ray.shutdown()
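Each worker can also query its position in the process group through the train context, which is handy for rank-0-only logging; a small sketch (assumes `from ray import train` as above):
python
from ray import train

def train_fn(config):
    ctx = train.get_context()
    # Only the first worker prints, to avoid duplicated log lines.
    if ctx.get_world_rank() == 0:
        print(f"{ctx.get_world_size()} workers, "
              f"local rank {ctx.get_local_rank()}")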
Data Loading #
python
import ray
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ray.init()

def train_fn(config):
    X = torch.randn(1000, 10)
    y = torch.randn(1000, 1)
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    # Adds a DistributedSampler so each worker sees a different shard,
    # and moves batches to the worker's device.
    dataloader = train.torch.prepare_data_loader(dataloader)
    model = nn.Sequential(nn.Linear(10, 1))
    model = train.torch.prepare_model(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()
    for epoch in range(5):
        for data, target in dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            optimizer.step()
        train.report({"epoch": epoch})

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
ray.shutdown()
Checkpointing #
python
import os
import tempfile

import ray
import torch
import torch.nn as nn
from ray import train
from ray.train import Checkpoint
from ray.train.torch import prepare_model

ray.init()

def train_fn(config):
    model = nn.Linear(10, 1)
    model = prepare_model(model)
    # Resume from a checkpoint if the run was restarted.
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            model.load_state_dict(
                torch.load(os.path.join(checkpoint_dir, "model.pt")))
            # Resume after the last completed epoch.
            start_epoch = torch.load(os.path.join(checkpoint_dir, "epoch.pt")) + 1
    for epoch in range(start_epoch, config["epochs"]):
        # ... train for one epoch ...
        # Write the checkpoint to a temporary directory and report it
        # together with the metrics.
        with tempfile.TemporaryDirectory() as tmp_dir:
            torch.save(model.state_dict(), os.path.join(tmp_dir, "model.pt"))
            torch.save(epoch, os.path.join(tmp_dir, "epoch.pt"))
            train.report({"epoch": epoch},
                         checkpoint=Checkpoint.from_directory(tmp_dir))
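To run this loop and bound how many checkpoints are retained, attach the function to a trainer with a `CheckpointConfig`; a sketch continuing the example above (the `num_to_keep` value is illustrative):
python
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"epochs": 20},
    scaling_config=ScalingConfig(num_workers=2),
    run_config=RunConfig(
        # Keep only the two most recent checkpoints on disk.
        checkpoint_config=CheckpointConfig(num_to_keep=2),
    ),
)
result = trainer.fit()
print(result.checkpoint)  # latest checkpoint reported via train.report
ray.shutdown()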
Distributed Strategies #
Data Parallelism #
python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ray.init()

def train_fn(config):
    # create_model / create_dataloader are placeholders for your own code.
    model = create_model()
    model = train.torch.prepare_model(model)  # wraps the model in DDP
    dataloader = create_dataloader()
    # Shards the data across workers via a DistributedSampler.
    dataloader = train.torch.prepare_data_loader(dataloader)
    for epoch in range(config["epochs"]):
        for batch in dataloader:
            ...  # forward/backward/step; DDP averages gradients across workers
        train.report({"epoch": epoch})

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"epochs": 10},  # train_fn reads config["epochs"]
    scaling_config=ScalingConfig(
        num_workers=4,
        use_gpu=True,
    ),
)
result = trainer.fit()
ray.shutdown()
Distributed Configuration #
text
┌─────────────────────────────────────────────────────────────┐
│             Distributed Strategy Configuration              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ScalingConfig:                                             │
│    ├── num_workers: number of training workers              │
│    ├── use_gpu: whether to use GPUs                         │
│    ├── resources_per_worker: per-worker resources,          │
│    │     e.g. {"CPU": 2, "GPU": 1}                          │
│    └── placement_strategy: worker placement policy          │
│                                                             │
│  RunConfig:                                                 │
│    ├── name: run name                                       │
│    ├── storage_path: where results are stored               │
│    ├── checkpoint_config: checkpointing policy              │
│    └── failure_config: failure-handling policy              │
│                                                             │
└─────────────────────────────────────────────────────────────┘
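In code, the same options look like this (the run name and storage path are illustrative):
python
from ray.train import (
    CheckpointConfig, FailureConfig, RunConfig, ScalingConfig,
)

scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 1},  # per-worker resource request
)
run_config = RunConfig(
    name="my-training-run",           # illustrative run name
    storage_path="/tmp/ray_results",  # illustrative storage location
    checkpoint_config=CheckpointConfig(num_to_keep=3),
    failure_config=FailureConfig(max_failures=2),  # retry on worker failure
)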
Framework Integrations #
TensorFlow Integration #
python
import ray
import numpy as np
import tensorflow as tf
from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

ray.init()

def train_fn(config):
    # Ray sets TF_CONFIG on each worker, so the strategy sees the cluster.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(32, activation="relu", input_shape=(10,)),
            tf.keras.layers.Dense(1),
        ])
        model.compile(optimizer="adam", loss="mse")
    for epoch in range(config["epochs"]):
        x = np.random.randn(100, 10)  # synthetic data
        y = np.random.randn(100, 1)
        history = model.fit(x, y, epochs=1, verbose=0)
        train.report({"loss": history.history["loss"][0]})

trainer = TensorflowTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"epochs": 10},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
ray.shutdown()
XGBoost Integration #
python
import numpy as np
import ray
from ray.train import ScalingConfig
from ray.train.xgboost import XGBoostTrainer

ray.init()

# XGBoostTrainer takes the parameters and a Ray Dataset directly; there is
# no hand-written training loop. The dataset is sharded across the workers.
rng = np.random.default_rng(0)
train_ds = ray.data.from_items([
    {"x": float(rng.standard_normal()), "label": int(rng.integers(0, 2))}
    for _ in range(1000)
])

trainer = XGBoostTrainer(
    label_column="label",
    params={
        "objective": "binary:logistic",
        "max_depth": 6,
        "eta": 0.1,
    },
    datasets={"train": train_ds},
    num_boost_round=10,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.metrics)  # e.g. the final train-logloss
ray.shutdown()
Integration with Ray Tune #
python
import ray
from ray import train, tune

ray.init()

def train_fn(config):
    # create_model / train_model are placeholders for your own code.
    model = create_model()
    lr = config["lr"]
    for epoch in range(10):
        loss = train_model(model, lr)
        train.report({"loss": loss})

tuner = tune.Tuner(
    train_fn,
    param_space={
        "lr": tune.loguniform(1e-4, 1e-1),  # sample lr on a log scale
    },
    tune_config=tune.TuneConfig(num_samples=10),  # run 10 trials
)
results = tuner.fit()
ray.shutdown()
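The returned `ResultGrid` can then be queried for the best trial, for example:
python
best = results.get_best_result(metric="loss", mode="min")
print(best.config)   # hyperparameters of the best trial
print(best.metrics)  # its last reported metrics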
Integration with Ray Data #
python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

ray.init()

def train_fn(config):
    # Each worker receives its own shard of the "train" dataset.
    dataset = train.get_dataset_shard("train")
    model = create_model()  # placeholder for your own code
    for epoch in range(config["epochs"]):
        for batch in dataset.iter_batches(batch_size=32):
            ...  # run a training step on the batch
        train.report({"epoch": epoch})

train_ds = ray.data.from_items(
    [{"features": [i], "label": i % 2} for i in range(1000)]
)
trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"epochs": 5},  # train_fn reads config["epochs"]
    datasets={"train": train_ds},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
ray.shutdown()
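For a PyTorch loop, the shard can also yield framework-native tensors directly. A sketch using `iter_torch_batches` (here one dtype is applied to every column, which also casts the integer labels to float):
python
import torch
from ray import train

def train_fn(config):
    shard = train.get_dataset_shard("train")
    # Batches are dicts of tensors keyed by the dataset's column names.
    for batch in shard.iter_torch_batches(batch_size=32, dtypes=torch.float32):
        features, labels = batch["features"], batch["label"]
        ...  # training step on this worker's batch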
Best Practices #
1. Set a Sensible Number of Workers #
python
from ray.train import ScalingConfig

# One full GPU per worker: 4 workers on 4 GPUs.
scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"GPU": 1},
)

# Fractional GPUs: 8 workers sharing 4 GPUs, half a GPU each.
scaling_config = ScalingConfig(
    num_workers=8,
    use_gpu=True,
    resources_per_worker={"GPU": 0.5},
)
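When the cluster size is not known in advance, the worker count can be derived from the resources Ray actually sees; a minimal sketch:
python
import ray
from ray.train import ScalingConfig

ray.init()
# Use one worker per available GPU, falling back to a single CPU worker.
n_gpus = int(ray.cluster_resources().get("GPU", 0))
scaling_config = ScalingConfig(
    num_workers=max(n_gpus, 1),
    use_gpu=n_gpus > 0,
)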
2. Save Checkpoints Regularly #
python
import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint

def train_fn(config):
    model = create_model()  # placeholder for your own code
    for epoch in range(config["epochs"]):
        train_one_epoch(model)  # placeholder
        if epoch % 5 == 0:
            # Persist the model every 5 epochs alongside the metrics.
            with tempfile.TemporaryDirectory() as tmp_dir:
                torch.save(model.state_dict(), os.path.join(tmp_dir, "model.pt"))
                train.report({"epoch": epoch},
                             checkpoint=Checkpoint.from_directory(tmp_dir))
3. Load Data with Ray Data #
python
import ray
from ray.train import DataConfig, ScalingConfig
from ray.train.torch import TorchTrainer

train_ds = ray.data.read_parquet("train.parquet")
val_ds = ray.data.read_parquet("val.parquet")

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    datasets={"train": train_ds, "val": val_ds},
    # Split only "train" across workers; every worker sees the full "val" set.
    dataset_config=DataConfig(datasets_to_split=["train"]),
    scaling_config=ScalingConfig(num_workers=2),
)
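Inside the training function each dataset is then fetched by name. A sketch of consuming both splits (the loop bodies are placeholders):
python
from ray import train

def train_fn(config):
    train_shard = train.get_dataset_shard("train")  # this worker's slice
    val_shard = train.get_dataset_shard("val")      # full dataset on each worker
    for batch in train_shard.iter_batches(batch_size=32):
        ...  # training step
    for batch in val_shard.iter_batches(batch_size=32):
        ...  # evaluation step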
Next Steps #
With Ray Train under your belt, continue to Ray Serve for model serving and learn how to deploy your models to production!
Last updated: 2026-04-05