# Ray Train 分布式训练

## 什么是 Ray Train？

Ray Train 是 Ray 提供的分布式训练库，简化了分布式机器学习模型的训练过程。它支持 PyTorch、TensorFlow、XGBoost 等主流机器学习框架，并提供统一的 API 进行分布式训练。

text
┌─────────────────────────────────────────────────────────────┐
│                    Ray Train 架构                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  用户代码                                                    │
│  ┌─────────────────────────────────────────────────────┐   │
│  │              train_fn(config)                        │   │
│  │  ├── 数据加载                                        │   │
│  │  ├── 模型定义                                        │   │
│  │  ├── 训练循环                                        │   │
│  │  └── 检查点保存                                      │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                  │
│                          ▼                                  │
│  ┌─────────────────────────────────────────────────────┐   │
│  │                   Trainer                            │   │
│  │  ┌─────────────┐  ┌─────────────┐                   │   │
│  │  │  Worker 0   │  │  Worker 1   │  ...              │   │
│  │  │  (GPU 0)    │  │  (GPU 1)    │                   │   │
│  │  └─────────────┘  └─────────────┘                   │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                  │
│                          ▼                                  │
│  ┌─────────────────────────────────────────────────────┐   │
│  │               ScalingConfig                          │   │
│  │  ├── num_workers: 工作进程数                         │   │
│  │  ├── use_gpu: 是否使用 GPU                          │   │
│  │  └── resources_per_worker: 每个工作进程资源          │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

## 基本用法

### 简单训练

python
import ray
from ray import train
# `from ray import train` is not guaranteed to load the framework-specific
# submodules, so the PyTorch integration is imported explicitly.
from ray.train.torch import TorchTrainer

ray.init()


def train_fn(config):
    """Per-worker training loop that reports a toy, decreasing loss."""
    for epoch in range(10):
        loss = 1.0 / (epoch + 1)
        # Each call forwards this epoch's metrics from the worker to the driver.
        train.report({"loss": loss, "epoch": epoch})


trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=train.ScalingConfig(num_workers=1),
)

result = trainer.fit()
# result.metrics holds the most recently reported metrics.
print(f"Final loss: {result.metrics['loss']}")

ray.shutdown()

### 使用配置

python
import ray
from ray import train
# The PyTorch integration lives in a submodule that must be imported explicitly.
from ray.train.torch import TorchTrainer

ray.init()


def train_fn(config):
    """Per-worker loop whose hyperparameters come from ``train_loop_config``."""
    # `config` is the `train_loop_config` dict passed to the trainer below.
    lr = config["lr"]
    epochs = config["epochs"]

    for epoch in range(epochs):
        loss = lr / (epoch + 1)
        train.report({"loss": loss})


trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"lr": 0.01, "epochs": 20},
    scaling_config=train.ScalingConfig(num_workers=1),
)

result = trainer.fit()
print(f"Final loss: {result.metrics['loss']}")

ray.shutdown()

## PyTorch 集成

### 基础训练

python
import ray
from ray import train
from ray.train.torch import TorchTrainer, get_device, prepare_model
import torch
import torch.nn as nn
import torch.optim as optim

ray.init()


def train_fn(config):
    """Train a small MLP on random data; every worker runs this function.

    Reads ``config["lr"]`` and ``config["epochs"]`` from ``train_loop_config``.
    """
    model = nn.Sequential(
        nn.Linear(10, 32),
        nn.ReLU(),
        nn.Linear(32, 1),
    )

    # prepare_model places the model on this worker's device (and, per the
    # Ray docs, wraps it for DDP when there are several workers).
    model = prepare_model(model)
    # Hand-made tensors must live on the same device as the model, otherwise
    # the forward pass fails with a device mismatch when use_gpu=True.
    device = get_device()

    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    loss_fn = nn.MSELoss()

    for epoch in range(config["epochs"]):
        inputs = torch.randn(32, 10, device=device)
        labels = torch.randn(32, 1, device=device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        train.report({"loss": loss.item(), "epoch": epoch})


trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"lr": 0.001, "epochs": 10},
    # Needs at least 2 GPUs available in the Ray cluster.
    scaling_config=train.ScalingConfig(num_workers=2, use_gpu=True),
)

result = trainer.fit()

ray.shutdown()

### 数据加载

python
import ray
from ray import train
from ray.train.torch import TorchTrainer, prepare_data_loader, prepare_model
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

ray.init()


def train_fn(config):
    """Train a linear model from a standard PyTorch DataLoader on each worker."""
    X = torch.randn(1000, 10)
    y = torch.randn(1000, 1)
    dataset = TensorDataset(X, y)

    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    # Per the Ray docs: adds a DistributedSampler so each worker reads its own
    # shard, and moves every batch to the worker's device.
    dataloader = prepare_data_loader(dataloader)

    model = nn.Sequential(nn.Linear(10, 1))
    model = prepare_model(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    for epoch in range(5):
        # A DistributedSampler only reshuffles between epochs if told the epoch.
        if train.get_context().get_world_size() > 1:
            dataloader.sampler.set_epoch(epoch)

        for data, target in dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            optimizer.step()

        train.report({"epoch": epoch, "loss": loss.item()})


trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=train.ScalingConfig(num_workers=2),
)

result = trainer.fit()

ray.shutdown()

### 检查点

python
import os
import tempfile

import ray
from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer, prepare_model
import torch
import torch.nn as nn

ray.init()


def train_fn(config):
    """Training loop that resumes from, and reports, checkpoints each epoch.

    Reads ``config["epochs"]`` (total number of epochs) from ``train_loop_config``.
    """
    model = nn.Linear(10, 1)
    model = prepare_model(model)

    start_epoch = 0

    # None on a fresh run; otherwise the latest checkpoint reported earlier.
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            state = torch.load(
                os.path.join(checkpoint_dir, "checkpoint.pt"), map_location="cpu"
            )
        model.load_state_dict(state["model"])
        # The saved epoch already finished; resume with the next one instead of
        # repeating it.
        start_epoch = state["epoch"] + 1

    for epoch in range(start_epoch, config["epochs"]):
        pass  # placeholder: one epoch of training goes here

        # There is no train.save_checkpoint(); checkpoints are attached to
        # train.report(). All workers report every epoch, but only rank 0
        # writes files so the checkpoint is not duplicated.
        with tempfile.TemporaryDirectory() as tmp_dir:
            new_checkpoint = None
            if train.get_context().get_world_rank() == 0:
                torch.save(
                    {"model": model.state_dict(), "epoch": epoch},
                    os.path.join(tmp_dir, "checkpoint.pt"),
                )
                new_checkpoint = Checkpoint.from_directory(tmp_dir)
            train.report({"epoch": epoch}, checkpoint=new_checkpoint)


trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"epochs": 10},
    scaling_config=train.ScalingConfig(num_workers=2),
)

result = trainer.fit()

ray.shutdown()

## 分布式策略

### 数据并行

python
import ray
from ray import train
from ray.train.torch import TorchTrainer, prepare_data_loader, prepare_model

ray.init()


def train_fn(config):
    """Data-parallel loop: each worker trains a model replica on its data shard.

    Reads ``config["epochs"]`` from ``train_loop_config``.
    """
    # create_model / create_dataloader are placeholders for user-defined code.
    model = create_model()
    model = prepare_model(model)

    dataloader = create_dataloader()
    dataloader = prepare_data_loader(dataloader)

    for epoch in range(config["epochs"]):
        for batch in dataloader:
            pass  # placeholder: forward / backward / optimizer step
        train.report({"epoch": epoch})


trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    # train_fn reads config["epochs"], so it must be provided here; without it
    # every worker fails with KeyError.
    train_loop_config={"epochs": 10},
    scaling_config=train.ScalingConfig(
        num_workers=4,
        use_gpu=True,
    ),
)

result = trainer.fit()

ray.shutdown()

### 分布式配置

text
┌─────────────────────────────────────────────────────────────┐
│                    分布式策略配置                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ScalingConfig:                                             │
│  ├── num_workers: 工作进程数量                              │
│  ├── use_gpu: 是否使用 GPU                                  │
│  ├── resources_per_worker: 每个工作进程资源,                │
│  │     如 {"CPU": 2, "GPU": 1} 或自定义资源                 │
│  └── placement_strategy: 工作进程放置策略(默认 "PACK")     │
│                                                             │
│  RunConfig:                                                 │
│  ├── name: 运行名称                                         │
│  ├── storage_path: 存储路径                                 │
│  ├── checkpoint_config: 检查点配置                          │
│  └── failure_config: 失败处理配置                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

## 框架集成

### TensorFlow 集成

python
import numpy as np
import ray
from ray import train
# The TensorFlow integration must be imported explicitly.
from ray.train.tensorflow import TensorflowTrainer
import tensorflow as tf

ray.init()


def train_fn(config):
    """Keras training loop run on each worker under MultiWorkerMirroredStrategy.

    Reads ``config["epochs"]`` from ``train_loop_config``.
    """
    # Per the Ray docs, TensorflowTrainer sets TF_CONFIG on every worker,
    # which this strategy reads to find its peers.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    # Build and compile inside the scope so the variables are mirrored.
    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
            tf.keras.layers.Dense(1)
        ])
        model.compile(optimizer='adam', loss='mse')

    for epoch in range(config["epochs"]):
        x = np.random.randn(100, 10)
        y = np.random.randn(100, 1)
        history = model.fit(x, y, epochs=1, verbose=0)
        train.report({"loss": history.history["loss"][0]})


trainer = TensorflowTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"epochs": 10},
    scaling_config=train.ScalingConfig(num_workers=2),
)

result = trainer.fit()

ray.shutdown()

### XGBoost 集成

python
import numpy as np
import ray
from ray import train
# The XGBoost integration must be imported explicitly. The
# train_loop_per_worker form of XGBoostTrainer needs a recent Ray 2.x release.
from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer
import xgboost as xgb

ray.init()


def train_fn(config):
    """Per-worker XGBoost training on synthetic binary-classification data.

    Reads ``config["max_depth"]`` and ``config["eta"]`` from ``train_loop_config``.
    """
    X = np.random.randn(1000, 10)
    y = np.random.randint(0, 2, 1000)

    dtrain = xgb.DMatrix(X, label=y)

    params = {
        "objective": "binary:logistic",
        "eval_metric": "logloss",
        "max_depth": config["max_depth"],
        "eta": config["eta"],
    }

    # A Booster object is not a metric. The Ray callback reports the
    # evaluation metrics and saves the model as a checkpoint instead.
    xgb.train(
        params,
        dtrain,
        num_boost_round=10,
        evals=[(dtrain, "train")],
        callbacks=[RayTrainReportCallback()],
    )


trainer = XGBoostTrainer(
    train_loop_per_worker=train_fn,
    train_loop_config={"max_depth": 6, "eta": 0.1},
    scaling_config=train.ScalingConfig(num_workers=2),
)

result = trainer.fit()
print(result.metrics)

ray.shutdown()

### 与 Ray Tune 集成

python
import ray
from ray import train, tune

ray.init()


def train_fn(config):
    """Tune trainable: one trial trains with the sampled ``config["lr"]``."""
    # create_model / train_model are placeholders for user-defined code.
    model = create_model()
    lr = config["lr"]

    for epoch in range(10):
        loss = train_model(model, lr)
        train.report({"loss": loss})


tuner = tune.Tuner(
    train_fn,
    param_space={
        "lr": tune.loguniform(1e-4, 1e-1)
    },
    # metric/mode tell Tune which reported value to optimize; without them
    # trials cannot be ranked and get_best_result() needs explicit arguments.
    tune_config=tune.TuneConfig(metric="loss", mode="min", num_samples=10),
)

results = tuner.fit()
best = results.get_best_result()
print(f"Best lr: {best.config['lr']}, loss: {best.metrics['loss']}")

ray.shutdown()

### 与 Ray Data 集成

python
import ray
from ray import train
from ray.train.torch import TorchTrainer

ray.init()


def train_fn(config):
    """Iterate over this worker's shard of the "train" Ray Dataset.

    Reads ``config["epochs"]`` from ``train_loop_config``.
    """
    # The shard of the dataset registered under "train" in `datasets=` below.
    dataset = train.get_dataset_shard("train")

    # create_model is a placeholder for user-defined code.
    model = create_model()

    for epoch in range(config["epochs"]):
        for batch in dataset.iter_batches(batch_size=32):
            pass  # placeholder: training step on batch["features"], batch["label"]

        train.report({"epoch": epoch})


train_ds = ray.data.from_items([{"features": [i], "label": i % 2} for i in range(1000)])

trainer = TorchTrainer(
    train_loop_per_worker=train_fn,
    # train_fn reads config["epochs"]; omitting this raises KeyError on workers.
    train_loop_config={"epochs": 5},
    datasets={"train": train_ds},
    scaling_config=train.ScalingConfig(num_workers=2),
)

result = trainer.fit()

ray.shutdown()

## 最佳实践

### 1. 合理设置工作进程数

python
import ray
from ray import train

# ScalingConfig has no `num_gpus_per_worker` argument; per-worker resources
# are requested through `resources_per_worker`.

# Option 1: one full GPU per worker, so 4 workers need 4 GPUs.
scaling_config = train.ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"GPU": 1},
)

# Option 2: fractional GPUs. Two workers share each GPU, so 8 workers need
# 4 GPUs. Only useful when a worker's model and batch fit in half a GPU.
scaling_config = train.ScalingConfig(
    num_workers=8,
    use_gpu=True,
    resources_per_worker={"GPU": 0.5},
)

### 2. 定期保存检查点

python
import os
import tempfile

import ray
from ray import train
from ray.train import Checkpoint
import torch


def train_fn(config):
    """Train for ``config["epochs"]`` epochs, checkpointing every 5 epochs."""
    # create_model / train_one_epoch are placeholders for user-defined code.
    model = create_model()

    for epoch in range(config["epochs"]):
        train_one_epoch(model)

        metrics = {"epoch": epoch}
        # There is no train.save_checkpoint(); a checkpoint is attached to
        # train.report(). Every worker reports every epoch so the workers stay
        # in step, but only rank 0 writes a checkpoint, and only every 5 epochs.
        if epoch % 5 == 0 and train.get_context().get_world_rank() == 0:
            with tempfile.TemporaryDirectory() as tmp_dir:
                torch.save(
                    {"epoch": epoch, "model": model.state_dict()},
                    os.path.join(tmp_dir, "checkpoint.pt"),
                )
                train.report(metrics, checkpoint=Checkpoint.from_directory(tmp_dir))
        else:
            train.report(metrics)

### 3. 使用 Ray Data 加载数据

python
import ray
from ray import train
from ray.train.torch import TorchTrainer

train_ds = ray.data.read_parquet("train.parquet")
val_ds = ray.data.read_parquet("val.parquet")

trainer = TorchTrainer(
    # train_fn is the per-worker training function, defined elsewhere.
    train_loop_per_worker=train_fn,
    datasets={"train": train_ds, "val": val_ds},
    # Shard only the training set across workers; every worker sees the whole
    # validation set. This replaces the legacy per-dataset
    # DatasetConfig(split=...) dict form.
    dataset_config=train.DataConfig(datasets_to_split=["train"]),
)

## 下一步

掌握了 Ray Train 之后，可以继续学习 Ray Serve 模型服务，了解如何将模型部署到生产环境！

最后更新:2026-04-05