Ray Data 数据处理 #

什么是 Ray Data? #

Ray Data 是 Ray 提供的大规模数据处理库,支持加载、转换和保存各种格式的数据。它提供了简洁的 API,可以轻松处理 TB 级数据。

text
┌─────────────────────────────────────────────────────────────┐
│                    Ray Data 架构                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  数据源                                                      │
│  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐          │
│  │   CSV   │ │ Parquet │ │  JSON   │ │  Image  │          │
│  └─────────┘ └─────────┘ └─────────┘ └─────────┘          │
│       │                                                      │
│       ▼                                                      │
│  ┌─────────────────────────────────────────────────────┐   │
│  │                   Dataset                            │   │
│  │  ┌─────────────────────────────────────────────┐   │   │
│  │  │  Block 1  │  Block 2  │  Block 3  │ ...    │   │   │
│  │  └─────────────────────────────────────────────┘   │   │
│  └─────────────────────────────────────────────────────┘   │
│       │                                                      │
│       ▼                                                      │
│  数据转换                                                    │
│  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐          │
│  │   map   │ │ filter  │ │ flat_map│ │ groupby │          │
│  └─────────┘ └─────────┘ └─────────┘ └─────────┘          │
│       │                                                      │
│       ▼                                                      │
│  数据输出                                                    │
│  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐          │
│  │   CSV   │ │ Parquet │ │  JSON   │ │  Train  │          │
│  └─────────┘ └─────────┘ └─────────┘ └─────────┘          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

创建 Dataset #

从文件读取 #

python
import ray

ray.init()

ds = ray.data.read_csv("data.csv")

ds = ray.data.read_parquet("data.parquet")

ds = ray.data.read_json("data.json")

ds = ray.data.read_csv("s3://bucket/data/*.csv")

print(f"Dataset size: {ds.count()}")

ray.shutdown()

从 Python 对象创建 #

python
import ray
import pandas as pd

ray.init()

ds = ray.data.from_items([1, 2, 3, 4, 5])

ds = ray.data.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))

import numpy as np
ds = ray.data.from_numpy(np.array([[1, 2], [3, 4], [5, 6]]))

print(ds.schema())

ray.shutdown()

从数据库读取 #

python
import ray

ray.init()

ds = ray.data.read_sql(
    "SELECT * FROM users",
    connection_factory=lambda: create_connection("postgresql://...")
)

ray.shutdown()

数据转换 #

map 操作 #

python
import ray

ray.init()

ds = ray.data.from_items([1, 2, 3, 4, 5])

ds = ds.map(lambda x: {"value": x * 2})

ds = ds.map_batches(lambda batch: {"value": batch["value"] + 1})

for row in ds.iter_rows():
    print(row)

ray.shutdown()

filter 操作 #

python
import ray

ray.init()

ds = ray.data.from_items([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

ds = ds.filter(lambda x: x % 2 == 0)

print(ds.take_all())

ray.shutdown()

flat_map 操作 #

python
import ray

ray.init()

ds = ray.data.from_items([1, 2, 3])

ds = ds.flat_map(lambda x: [x, x * 2])

print(ds.take_all())

ray.shutdown()

聚合操作 #

python
import ray

ray.init()

ds = ray.data.from_items([{"value": i} for i in range(10)])

print(f"Count: {ds.count()}")
print(f"Sum: {ds.sum('value')}")
print(f"Mean: {ds.mean('value')}")
print(f"Min: {ds.min('value')}")
print(f"Max: {ds.max('value')}")

ray.shutdown()

groupby 操作 #

python
import ray

ray.init()

ds = ray.data.from_items([
    {"category": "A", "value": 1},
    {"category": "A", "value": 2},
    {"category": "B", "value": 3},
    {"category": "B", "value": 4},
])

grouped = ds.groupby("category").sum("value")
print(grouped.take_all())

ray.shutdown()

sort 操作 #

python
import ray

ray.init()

ds = ray.data.from_items([{"value": i} for i in [5, 2, 8, 1, 9, 3]])

ds = ds.sort("value")
print(ds.take_all())

ds = ds.sort("value", descending=True)
print(ds.take_all())

ray.shutdown()

shuffle 操作 #

python
import ray

ray.init()

ds = ray.data.from_items([{"id": i, "value": i * 2} for i in range(100)])

ds = ds.random_shuffle()

print(ds.take(5))

ray.shutdown()

数据输出 #

写入文件 #

python
import ray

ray.init()

ds = ray.data.from_items([{"a": i, "b": i * 2} for i in range(100)])

ds.write_csv("output/data.csv")

ds.write_parquet("output/data.parquet")

ds.write_json("output/data.json")

ds.write_csv("s3://bucket/output/data.csv")

ray.shutdown()

转换为其他格式 #

python
import ray

ray.init()

ds = ray.data.from_items([{"a": 1, "b": 2}, {"a": 3, "b": 4}])

df = ds.to_pandas()
print(df)

items = ds.take_all()
print(items)

ray.shutdown()

数据加载 #

批量迭代 #

python
import ray

ray.init()

ds = ray.data.from_items([{"value": i} for i in range(100)])

for batch in ds.iter_batches(batch_size=10):
    print(f"Batch size: {len(batch['value'])}")

for row in ds.iter_rows():
    print(row)

ray.shutdown()

与训练框架集成 #

python
import ray
from ray.data import Dataset

ray.init()

ds = ray.data.from_items([{"features": [i, i+1], "label": i % 2} for i in range(100)])

def to_torch(batch):
    import torch
    return {
        "features": torch.tensor(batch["features"], dtype=torch.float32),
        "label": torch.tensor(batch["label"], dtype=torch.long)
    }

torch_ds = ds.map_batches(to_torch)

for batch in torch_ds.iter_batches(batch_size=10):
    print(batch["features"].shape)

ray.shutdown()

数据处理流水线 #

链式操作 #

python
import ray

ray.init()

ds = (
    ray.data.read_csv("data.csv")
    .filter(lambda row: row["value"] > 0)
    .map(lambda row: {"value": row["value"] * 2, "category": row["category"]})
    .groupby("category")
    .mean("value")
)

print(ds.take_all())

ray.shutdown()

预处理流水线 #

python
import ray

ray.init()

def preprocess_pipeline(data_path):
    ds = ray.data.read_csv(data_path)
    
    ds = ds.filter(lambda row: row["value"] is not None)
    
    ds = ds.map(lambda row: {
        "value": float(row["value"]),
        "category": row["category"].upper()
    })
    
    ds = ds.add_column("processed", lambda row: row["value"] * 2)
    
    return ds

ds = preprocess_pipeline("data.csv")
print(ds.schema())

ray.shutdown()

性能优化 #

并行度控制 #

python
import ray

ray.init()

ds = ray.data.read_csv("data.csv", parallelism=100)

ds = ds.map_batches(
    lambda batch: batch,
    parallelism=50
)

ray.shutdown()

内存优化 #

python
import ray

ray.init()

ds = ray.data.read_parquet(
    "large_data.parquet",
    columns=["col1", "col2"]
)

ds = ds.map_batches(
    lambda batch: batch,
    batch_format="pandas",
    batch_size=10000
)

ray.shutdown()

流式处理 #

python
import ray

ray.init()

ds = ray.data.read_csv("large_data.csv")

for batch in ds.iter_batches(batch_size=1000):
    process_batch(batch)

ray.shutdown()

常用操作示例 #

数据清洗 #

python
import ray

ray.init()

ds = ray.data.from_items([
    {"id": 1, "value": 10},
    {"id": 2, "value": None},
    {"id": 3, "value": 30},
    {"id": 4, "value": None},
])

ds = ds.filter(lambda row: row["value"] is not None)

ds = ds.fill_null(value=0)

print(ds.take_all())

ray.shutdown()

特征工程 #

python
import ray

ray.init()

ds = ray.data.from_items([
    {"text": "hello world"},
    {"text": "ray data processing"},
    {"text": "distributed computing"},
])

def extract_features(row):
    text = row["text"]
    return {
        "text": text,
        "length": len(text),
        "word_count": len(text.split()),
        "uppercase_count": sum(1 for c in text if c.isupper())
    }

ds = ds.map(extract_features)
print(ds.take_all())

ray.shutdown()

数据合并 #

python
import ray

ray.init()

ds1 = ray.data.from_items([{"id": i, "value1": i * 2} for i in range(5)])
ds2 = ray.data.from_items([{"id": i, "value2": i * 3} for i in range(5)])

ds = ds1.union(ds2)
print(f"Union count: {ds.count()}")

ds_zip = ds1.zip(ds2)
print(ds_zip.take(2))

ray.shutdown()

下一步 #

掌握了 Ray Data 之后,继续学习 Ray Train 分布式训练,了解如何进行大规模模型训练!

最后更新:2026-04-05