Tasks 任务 #

什么是 Tasks? #

Tasks 是 Ray 中最基本的分布式计算单元。通过 @ray.remote 装饰器,任何 Python 函数都可以转换为远程函数,在集群中并行执行。

text
┌─────────────────────────────────────────────────────────────┐
│                    Task 概念图                               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  普通函数                    远程函数 (Task)                 │
│  ┌─────────────┐            ┌─────────────┐                │
│  │ def func(): │            │ @ray.remote │                │
│  │   return x  │ ────────► │ def func(): │                │
│  └─────────────┘            │   return x  │                │
│                             └─────────────┘                │
│                                   │                         │
│                                   ▼                         │
│                             ┌─────────────┐                │
│                             │ func.remote()│               │
│                             │   返回 ObjectRef              │
│                             └─────────────┘                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

定义 Tasks #

基本语法 #

python
import ray

@ray.remote
def my_task(x, y):
    """Add two numbers inside a Ray worker process."""
    total = x + y
    return total

# No explicit ray.init(): Ray starts a local instance on first remote call.
# .remote() schedules the task and immediately returns an ObjectRef (a future).
result_ref = my_task.remote(1, 2)
# ray.get blocks until the value is available; prints 3.
print(ray.get(result_ref))

Task 配置选项 #

python
@ray.remote(
    num_cpus=2,                      # logical CPU cores reserved for the task
    num_gpus=1,                      # GPUs reserved for the task
    memory=1024 * 1024 * 1024,       # memory reservation in bytes (1 GiB)
    max_retries=3,                   # maximum number of retries
    retry_exceptions=True,           # also retry when the task itself raises
    runtime_env={"pip": ["numpy"]},  # pip packages installed for this task's worker
)
def complex_task(data):
    """Return the mean of *data*, computed with NumPy."""
    # Imported inside the task so it resolves in the worker's runtime_env.
    import numpy
    return numpy.mean(data)

配置参数说明 #

参数 说明 默认值
num_cpus CPU 核心数 1
num_gpus GPU 数量 0
memory 调度时预留的内存(字节) None(不预留)
max_retries 最大重试次数(默认只针对系统故障,如 worker 进程崩溃) 3
retry_exceptions 任务抛出应用异常时是否重试(True 或异常类型列表) False
runtime_env 运行时环境 None
max_calls 同一 worker 执行该函数达到此次数后退出并重启 无限制(GPU 任务默认为 1)
scheduling_strategy 调度策略 DEFAULT

执行 Tasks #

远程调用 #

python
import ray

ray.init()

@ray.remote
def compute(x):
    """Square *x* in a worker process."""
    return x ** 2

# Submitting returns an ObjectRef right away; the work happens asynchronously.
ref = compute.remote(10)

# Block until the task finishes, then print 100.
print(ray.get(ref))

ray.shutdown()

并行执行 #

python
import ray
import time

ray.init()

@ray.remote
def process(x):
    """Simulate one second of work, then double *x*."""
    time.sleep(1)
    return 2 * x

# Launch all ten tasks without waiting; they run concurrently on available CPUs.
refs = list(map(process.remote, range(10)))

# Collect every result with a single blocking call. With enough CPUs this takes
# roughly 1 s of wall time rather than 10 s.
print(ray.get(refs))

ray.shutdown()

延迟获取 #

python
import ray

ray.init()

@ray.remote
def task_a():
    """First stage: produce a value."""
    return "A"

@ray.remote
def task_b(a_result):
    """Second stage: receives task_a's *value*, already resolved by Ray."""
    return f"B depends on {a_result}"

# Chain the tasks by passing the ObjectRef straight through; no intermediate
# ray.get is needed, so the driver never blocks between stages.
a_ref = task_a.remote()
b_ref = task_b.remote(a_ref)

# Only the final result is fetched.
print(ray.get(b_ref))

ray.shutdown()

任务依赖 #

依赖图 #

text
┌─────────────────────────────────────────────────────────────┐
│                    任务依赖图                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│                      task_a()                               │
│                         │                                   │
│           ┌─────────────┼─────────────┐                    │
│           ▼             ▼             ▼                    │
│       task_b()      task_c()      task_d()                 │
│           │             │             │                    │
│           └─────────────┼─────────────┘                    │
│                         ▼                                   │
│                     task_e()                                │
│                         │                                   │
│                         ▼                                   │
│                      结果                                    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

自动依赖解析 #

python
import ray

ray.init()

@ray.remote
def load_data():
    """Produce the input list."""
    return [1, 2, 3, 4, 5]

@ray.remote
def process_data(data):
    """Double every element; *data* arrives as the resolved list, not an ObjectRef."""
    return [x * 2 for x in data]

@ray.remote
def aggregate(results):
    """Sum the processed values."""
    return sum(results)

data_ref = load_data.remote()
# A top-level ObjectRef argument makes Ray wait for the upstream task and pass
# the actual value to the downstream task.
processed_ref = process_data.remote(data_ref)
# Pass the ref directly. Refs nested inside a list (e.g. [processed_ref]) are NOT
# resolved, so aggregate would receive a list containing an ObjectRef.
result_ref = aggregate.remote(processed_ref)

print(ray.get(result_ref))  # 30

ray.shutdown()

复杂依赖 #

python
import ray

ray.init()

@ray.remote
def task_a():
    return "A"

@ray.remote
def task_b(a_result):
    return f"B({a_result})"

@ray.remote
def task_c(a_result):
    return f"C({a_result})"

@ray.remote
def task_d(b_result, c_result):
    return f"D({b_result}, {c_result})"

# Diamond-shaped graph: b and c both depend on a and may run in parallel;
# d waits for both. No ray.get is needed between stages.
a_ref = task_a.remote()
b_ref, c_ref = task_b.remote(a_ref), task_c.remote(a_ref)
d_ref = task_d.remote(b_ref, c_ref)

print(ray.get(d_ref))  # D(B(A), C(A))

ray.shutdown()

资源管理 #

CPU 资源 #

python
import ray

# Expose exactly 4 logical CPUs to Ray.
ray.init(num_cpus=4)

@ray.remote(num_cpus=2)
def cpu_intensive():
    """Hold two logical CPUs for about one second."""
    import time
    time.sleep(1)
    return "CPU task done"

# 4 tasks x 2 CPUs on a 4-CPU node: only two can run at once, so expect ~2 s.
refs = [cpu_intensive.remote() for _ in range(4)]

print(ray.get(refs))

ray.shutdown()

GPU 资源 #

python
import ray

# Declare 2 GPUs to Ray. These are logical resources used for scheduling; on a
# machine without real GPUs, torch.cuda.is_available() will still return False.
ray.init(num_gpus=2)

@ray.remote(num_gpus=1)
def gpu_task():
    """Report whether CUDA is visible inside the worker."""
    import torch
    return torch.cuda.is_available()

@ray.remote(num_gpus=0.5)
def half_gpu_task():
    """Fractional GPU request: two such tasks can share one GPU."""
    # No torch import here: the task does not use it, and importing it would
    # make the task fail on workers without torch installed.
    return "Using half GPU"

print(ray.get(gpu_task.remote()))
print(ray.get(half_gpu_task.remote()))

ray.shutdown()

自定义资源 #

python
import ray

# Advertise 4 units of a user-defined resource on the local node.
ray.init(resources={"custom_resource": 4})

@ray.remote(resources={"custom_resource": 1})
def custom_task():
    # Each call consumes one unit, so at most four can run concurrently.
    return "Using custom resource"

custom_ref = custom_task.remote()
ray.get(custom_ref)

ray.shutdown()

容错机制 #

自动重试 #

python
import ray

ray.init()

# max_retries on its own only covers system failures (e.g. the worker process
# dies). retry_exceptions=True makes Ray also retry when the function raises,
# which is what this flaky task needs.
@ray.remote(max_retries=3, retry_exceptions=True)
def flaky_task():
    """Fail randomly about half the time; succeed otherwise."""
    import random
    if random.random() < 0.5:
        raise Exception("Random failure")
    return "Success"

try:
    result = ray.get(flaky_task.remote())
    print(result)
except ray.exceptions.RayError as e:
    # RayError covers both a task exception (RayTaskError) after all attempts
    # and system-level failures such as a crashed worker.
    print(f"Failed after retries: {e}")

ray.shutdown()

重试特定异常 #

python
import ray

ray.init()

class RetryableError(Exception):
    """Application error that is worth retrying."""
    pass

# retry_exceptions also accepts a list of exception types: only those are
# retried, and any other exception fails the task immediately.
@ray.remote(retry_exceptions=[RetryableError], max_retries=3)
def task_with_retry():
    raise RetryableError("Will retry")

# Actually submit the task. It always raises, so it is attempted 1 + 3 times
# and ray.get then surfaces the final failure.
try:
    ray.get(task_with_retry.remote())
except ray.exceptions.RayTaskError as e:
    print(f"Gave up after retries: {e}")

ray.shutdown()

超时处理 #

python
import ray
import time

ray.init()

@ray.remote
def slow_task():
    time.sleep(10)
    return "done"

ref = slow_task.remote()

try:
    result = ray.get(ref, timeout=2)
except ray.exceptions.GetTimeoutError:
    print("Task timed out")
    # The ray.get timeout only stops *waiting*; the task keeps running and
    # holding its CPU. Cancel it explicitly to release the resources.
    ray.cancel(ref)

ray.shutdown()

高级特性 #

动态远程选项 #

python
import ray

ray.init()

@ray.remote
def base_task(x):
    return x * 2

# .options() overrides the decorator's settings for a single invocation.
ref1 = base_task.options(num_cpus=2).remote(10)

# Requesting a GPU on a cluster that has none leaves the task pending forever,
# and ray.get below would hang. Only ask for a GPU when one is available.
gpu_options = {"num_gpus": 1} if ray.cluster_resources().get("GPU", 0) >= 1 else {}
ref2 = base_task.options(**gpu_options).remote(20)

print(ray.get([ref1, ref2]))  # [20, 40]

ray.shutdown()

调度策略 #

python
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()

# "SPREAD" is one of the built-in string strategies.
@ray.remote(scheduling_strategy="SPREAD")
def spread_task():
    return "Spread across nodes"

@ray.remote
def affinity_task():
    return "Affinity to specific node"

# Node affinity is not a string strategy. Pass a NodeAffinitySchedulingStrategy
# with a real node ID; here the task is pinned to the driver's own node.
# soft=False means the task fails instead of running elsewhere if the node is unavailable.
node_id = ray.get_runtime_context().get_node_id()
affinity_ref = affinity_task.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node_id, soft=False)
).remote()

print(ray.get([spread_task.remote(), affinity_ref]))

ray.shutdown()

运行时环境 #

python
import ray

ray.init()

# runtime_env installs the pinned pandas version for this task's worker,
# independently of what the driver environment has installed.
@ray.remote(runtime_env={"pip": ["pandas==2.0.0"]})
def task_with_deps():
    import pandas
    return pandas.__version__

version_ref = task_with_deps.remote()
print(ray.get(version_ref))

ray.shutdown()

生成器任务 #

python
import ray

ray.init()

# num_returns=3: the single call produces three separate ObjectRefs.
@ray.remote(num_returns=3)
def multi_return():
    return 1, 2, 3

ref1, ref2, ref3 = multi_return.remote()
print(ray.get([ref1, ref2, ref3]))  # [1, 2, 3]

# Generator task: with num_returns=3, each yielded value becomes its own
# ObjectRef and is stored as soon as it is produced.
@ray.remote(num_returns=3)
def multi_yield():
    for i in range(1, 4):
        yield i

print(ray.get(list(multi_yield.remote())))  # [1, 2, 3]

ray.shutdown()

性能优化 #

批处理 #

python
import ray

ray.init()

@ray.remote
def process_single(item):
    return item * 2

@ray.remote
def process_batch(items):
    return [item * 2 for item in items]

data = list(range(1000))

# Naive: one task per element -> 1000 tiny tasks, dominated by per-task overhead.
refs = list(map(process_single.remote, data))
results = ray.get(refs)

# Batched: 10 tasks of 100 elements each amortize that overhead.
batch_size = 100
batches = [data[start:start + batch_size] for start in range(0, len(data), batch_size)]
refs = [process_batch.remote(batch) for batch in batches]
results = []
for chunk in ray.get(refs):
    results.extend(chunk)

ray.shutdown()

使用 ray.put #

python
import ray
import numpy as np

ray.init()

large_data = np.random.rand(10000, 10000)
# Put the array into the object store once; every task reads this shared copy
# instead of having the array serialized into each call.
data_ref = ray.put(large_data)

@ray.remote
def process(data, idx):
    """Return element [idx, 0] of the shared array."""
    # A top-level ObjectRef argument is resolved by Ray before the task runs,
    # so `data` is already the ndarray. Calling ray.get on it would raise.
    return data[idx, 0]

refs = [process.remote(data_ref, i) for i in range(100)]
print(ray.get(refs)[:5])

ray.shutdown()

避免过度并行 #

python
import ray

ray.init(num_cpus=4)

@ray.remote
def task(x):
    return x * 2

data = list(range(100))

# Submitting one task per item all at once queues everything up front.
# Instead, keep at most MAX_IN_FLIGHT tasks pending and drain results as they
# finish. Note: results are collected in completion order, not input order.
MAX_IN_FLIGHT = 4
in_flight = []
results = []
for x in data:
    if len(in_flight) >= MAX_IN_FLIGHT:
        # Block until at least one pending task completes.
        ready, in_flight = ray.wait(in_flight, num_returns=1)
        results.extend(ray.get(ready))
    in_flight.append(task.remote(x))
results.extend(ray.get(in_flight))

ray.shutdown()

实用模式 #

Map-Reduce 模式 #

python
import ray
from collections import Counter

ray.init()

@ray.remote
def map_fn(text):
    """Count the words in one text."""
    words = text.split()
    return Counter(words)

@ray.remote
def reduce_fn(*counters):
    """Merge any number of per-text Counters into one total."""
    total = Counter()
    for c in counters:
        total.update(c)
    return total

texts = ["hello world", "hello ray", "world of distributed"]

map_refs = [map_fn.remote(text) for text in texts]
# Unpack the refs into top-level arguments so Ray resolves each one to its
# Counter. Refs nested inside a list argument would be passed through unresolved.
reduce_ref = reduce_fn.remote(*map_refs)

print(ray.get(reduce_ref))

ray.shutdown()

流水线模式 #

python
import ray

ray.init()

@ray.remote
def stage1(x):
    return x + 1

@ray.remote
def stage2(x):
    return x * 2

@ray.remote
def stage3(x):
    return x ** 2

def pipeline(x):
    """Submit the three-stage chain for *x* and return the final ObjectRef."""
    r1 = stage1.remote(x)
    r2 = stage2.remote(r1)
    return stage3.remote(r2)

# Submit every pipeline first, then block once. Calling ray.get inside the loop
# would run the five pipelines one after another instead of in parallel.
print(ray.get([pipeline(i) for i in range(5)]))  # [4, 16, 36, 64, 100]

ray.shutdown()

Fork-Join 模式 #

python
import ray

ray.init()

@ray.remote
def compute(x):
    return x ** 2

@ray.remote
def aggregate(*results):
    """Join step: sum the values produced by all forked tasks."""
    return sum(results)

data = list(range(10))

# Fork: one task per element.
fork_refs = [compute.remote(x) for x in data]
# Join: unpack the refs into top-level arguments so Ray resolves them to ints.
# A list of refs passed as a single argument would not be resolved.
join_ref = aggregate.remote(*fork_refs)

print(ray.get(join_ref))  # 285

ray.shutdown()

调试技巧 #

日志记录 #

python
import ray
import logging

ray.init()

@ray.remote
def logged_task(x):
    """Log the input from inside the worker, then return it doubled."""
    # The task runs in a worker process, so logging is configured there.
    logging.basicConfig(level=logging.INFO)
    logging.info("Processing %s", x)
    doubled = x * 2
    return doubled

ray.get(logged_task.remote(10))

ray.shutdown()

错误处理 #

python
import ray

ray.init()

@ray.remote
def failing_task():
    raise ValueError("Intentional error")

ref = failing_task.remote()

try:
    ray.get(ref)
except ray.exceptions.RayTaskError as err:
    # The remote ValueError is re-raised on the driver wrapped in a RayTaskError.
    print(f"Task failed: {err}")

ray.shutdown()

检查任务状态 #

python
import ray
import time

ray.init()

@ray.remote
def long_task():
    time.sleep(10)
    return "done"

ref = long_task.remote()

# timeout=0 polls without blocking; ray.wait returns (ready, not_ready).
ready, not_ready = ray.wait([ref], timeout=0)
print(ready == [])  # True: still running

# Wait for completion (up to 30 s) instead of sleeping a fixed 11 s, which is
# racy when the worker starts late.
ready, not_ready = ray.wait([ref], timeout=30)
print(ready != [])  # True: finished

ray.shutdown()

下一步 #

掌握了 Tasks 之后,继续学习 Actors 角色,了解如何创建有状态的分布式服务!

最后更新:2026-04-05