Tasks 任务 #
什么是 Tasks? #
Tasks 是 Ray 中最基本的分布式计算单元。通过 @ray.remote 装饰器,任何 Python 函数都可以转换为远程函数,在集群中并行执行。
text
┌─────────────────────────────────────────────────────────────┐
│ Task 概念图 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 普通函数 远程函数 (Task) │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ def func(): │ │ @ray.remote │ │
│ │ return x │ ────────► │ def func(): │ │
│ └─────────────┘ │ return x │ │
│ └─────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ func.remote()│ │
│ │ 返回 ObjectRef │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
定义 Tasks #
基本语法 #
python
# Minimal task lifecycle: decorate, invoke with .remote(), fetch with ray.get().
import ray

@ray.remote
def my_task(x, y):
    """Add two values on a Ray worker."""
    return x + y

ref = my_task.remote(1, 2)  # returns an ObjectRef immediately
print(ray.get(ref))         # blocks until the result is ready
Task 配置选项 #
python
# Every scheduling and fault-tolerance knob is set at decoration time.
@ray.remote(
    num_cpus=2,                      # reserve two cores
    num_gpus=1,                      # and one GPU
    memory=1024 ** 3,                # 1 GiB, expressed in bytes
    max_retries=3,                   # re-run up to three times
    retry_exceptions=True,           # ...even on application exceptions
    runtime_env={"pip": ["numpy"]},  # per-task pip environment
)
def complex_task(data):
    """Compute the mean of `data` inside the configured environment."""
    import numpy as np
    return np.mean(data)
配置参数说明 #
| 参数 | 说明 | 默认值 |
|---|---|---|
| num_cpus | CPU 核心数 | 1 |
| num_gpus | GPU 数量 | 0 |
| memory | 内存需求(字节) | 无限制 |
| max_retries | 最大重试次数 | 3 |
| retry_exceptions | 是否重试异常 | False |
| runtime_env | 运行时环境 | None |
| max_calls | 最大调用次数后重启 | 无限制 |
| scheduling_strategy | 调度策略 | DEFAULT |
执行 Tasks #
远程调用 #
python
# End-to-end lifecycle: init the cluster, run one task, fetch, shut down.
import ray

ray.init()

@ray.remote
def compute(x):
    """Square x on a remote worker."""
    return x ** 2

print(ray.get(compute.remote(10)))
ray.shutdown()
并行执行 #
python
# Submit all tasks first, then block once on the whole batch: the ten
# one-second sleeps overlap instead of running back to back.
import ray
import time

ray.init()

@ray.remote
def process(x):
    """Simulate one second of work, then double the input."""
    time.sleep(1)
    return x * 2

futures = [process.remote(i) for i in range(10)]
print(ray.get(futures))
ray.shutdown()
延迟获取 #
python
import ray

ray.init()

@ray.remote
def task_a():
    """Produce the upstream value."""
    return "A"

@ray.remote
def task_b(a_result):
    """Consume task_a's output; Ray resolves the ref before the call."""
    return f"B depends on {a_result}"

# Passing the ObjectRef directly wires the dependency: task_b only starts
# once task_a has finished, with no explicit ray.get() in between.
upstream = task_a.remote()
print(ray.get(task_b.remote(upstream)))
ray.shutdown()
任务依赖 #
依赖图 #
text
┌─────────────────────────────────────────────────────────────┐
│ 任务依赖图 │
├─────────────────────────────────────────────────────────────┤
│ │
│ task_a() │
│ │ │
│ ┌─────────────┼─────────────┐ │
│ ▼ ▼ ▼ │
│ task_b() task_c() task_d() │
│ │ │ │ │
│ └─────────────┼─────────────┘ │
│ ▼ │
│ task_e() │
│ │ │
│ ▼ │
│ 结果 │
│ │
└─────────────────────────────────────────────────────────────┘
自动依赖解析 #
python
import ray

ray.init()

@ray.remote
def load_data():
    """Produce the raw input list."""
    return [1, 2, 3, 4, 5]

@ray.remote
def process_data(data):
    """Double every element."""
    return [x * 2 for x in data]

@ray.remote
def aggregate(results):
    """Sum a flat list of numbers."""
    return sum(results)

data_ref = load_data.remote()
processed_ref = process_data.remote(data_ref)
# Pass the ObjectRef as a top-level argument so Ray resolves it to the
# list before calling aggregate.  The original wrapped it in a Python
# list ([processed_ref]); refs nested inside containers are NOT
# auto-dereferenced, so sum() would have received an ObjectRef and failed.
result_ref = aggregate.remote(processed_ref)
print(ray.get(result_ref))  # 30
ray.shutdown()
复杂依赖 #
python
import ray

ray.init()

@ray.remote
def task_a():
    """Root of the diamond."""
    return "A"

@ray.remote
def task_b(a_result):
    return f"B({a_result})"

@ray.remote
def task_c(a_result):
    return f"C({a_result})"

@ray.remote
def task_d(b_result, c_result):
    """Join both branches of the diamond."""
    return f"D({b_result}, {c_result})"

# Diamond dependency: b and c both consume a; d consumes b and c.
root = task_a.remote()
left = task_b.remote(root)
right = task_c.remote(root)
print(ray.get(task_d.remote(left, right)))
ray.shutdown()
资源管理 #
CPU 资源 #
python
import ray
import time

# The cluster exposes 4 CPUs; each task reserves 2 of them, so Ray runs
# at most two of the four tasks concurrently.
ray.init(num_cpus=4)

@ray.remote(num_cpus=2)
def cpu_intensive():
    """Hold two CPU slots for one second."""
    time.sleep(1)
    return "CPU task done"

futures = [cpu_intensive.remote() for _ in range(4)]
print(ray.get(futures))
ray.shutdown()
GPU 资源 #
python
import ray

ray.init(num_gpus=2)

@ray.remote(num_gpus=1)
def gpu_task():
    """Reserve one whole GPU and report whether CUDA is visible."""
    import torch
    return torch.cuda.is_available()

# Fractional requests are allowed: two num_gpus=0.5 tasks share one GPU.
@ray.remote(num_gpus=0.5)
def half_gpu_task():
    import torch
    return "Using half GPU"

whole = gpu_task.remote()
half = half_gpu_task.remote()
ray.get(whole)
ray.get(half)
ray.shutdown()
自定义资源 #
python
import ray

# Register 4 units of an arbitrary, user-named resource with the cluster.
ray.init(resources={"custom_resource": 4})

@ray.remote(resources={"custom_resource": 1})
def custom_task():
    """Runs only where one unit of custom_resource is free."""
    return "Using custom resource"

ray.get(custom_task.remote())
ray.shutdown()
容错机制 #
自动重试 #
python
import ray

ray.init()

# retry_exceptions=True is required for Ray to re-run a task after an
# application-level exception; max_retries alone only covers system
# failures (e.g. a worker process dying), so the original example never
# actually retried the random failure.
@ray.remote(max_retries=3, retry_exceptions=True)
def flaky_task():
    """Fail randomly about half the time."""
    import random
    if random.random() < 0.5:
        raise Exception("Random failure")
    return "Success"

try:
    print(ray.get(flaky_task.remote()))
except Exception as e:
    # Reached only if all 3 retries also raised.
    print(f"Failed after retries: {e}")
ray.shutdown()
重试特定异常 #
python
import ray

ray.init()

class RetryableError(Exception):
    """Marker exception: only failures of this type are retried."""
    pass

# Restrict retries to RetryableError; any other exception fails fast.
@ray.remote(max_retries=3, retry_exceptions=[RetryableError])
def task_with_retry():
    raise RetryableError("Will retry")

ray.shutdown()
超时处理 #
python
import ray
import time

ray.init()

@ray.remote
def slow_task():
    """Take ten seconds -- long enough to outlive the timeout below."""
    time.sleep(10)
    return "done"

pending = slow_task.remote()
try:
    ray.get(pending, timeout=2)
except ray.exceptions.GetTimeoutError:
    # Only the get() timed out; the task itself keeps running.
    print("Task timed out")
ray.shutdown()
高级特性 #
动态远程选项 #
python
import ray

ray.init()

@ray.remote
def base_task(x):
    """Double x; resource needs are chosen per call via .options()."""
    return x * 2

# .options() overrides the decorator's resource request for a single
# invocation without redefining the function.
cpu_ref = base_task.options(num_cpus=2).remote(10)
gpu_ref = base_task.options(num_gpus=1).remote(20)
print(ray.get([cpu_ref, gpu_ref]))
ray.shutdown()
调度策略 #
python
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()

@ray.remote(scheduling_strategy="SPREAD")
def spread_task():
    """Ask the scheduler to spread instances across nodes."""
    return "Spread across nodes"

# Node affinity takes a strategy object, not a string: the original
# scheduling_strategy="NODE_AFFINITY" / scheduling_strategy_node_id=...
# keywords do not exist in the Ray API.
@ray.remote(
    scheduling_strategy=NodeAffinitySchedulingStrategy(
        node_id=ray.get_runtime_context().get_node_id(),  # pin to this node
        soft=False,  # hard constraint: fail instead of falling back
    )
)
def affinity_task():
    return "Affinity to specific node"

ray.shutdown()
运行时环境 #
python
import ray

ray.init()

# The worker that runs this task gets its own environment with the
# pinned pandas version installed on demand.
@ray.remote(runtime_env={"pip": ["pandas==2.0.0"]})
def task_with_deps():
    """Report the pandas version visible inside the task's runtime env."""
    import pandas as pd
    return pd.__version__

print(ray.get(task_with_deps.remote()))
ray.shutdown()
多返回值任务 #
python
import ray

ray.init()

# num_returns=3 splits the returned tuple into three separate ObjectRefs,
# so each value can be fetched or passed downstream independently.
@ray.remote(num_returns=3)
def multi_return():
    return 1, 2, 3

first, second, third = multi_return.remote()
print(ray.get([first, second, third]))
ray.shutdown()
性能优化 #
批处理 #
python
import ray

ray.init()

@ray.remote
def process_single(item):
    """One task per element -- high per-task scheduling overhead."""
    return item * 2

@ray.remote
def process_batch(items):
    """One task per chunk -- amortizes the per-task overhead."""
    return [item * 2 for item in items]

data = list(range(1000))

# Naive: 1000 tiny tasks.
refs = [process_single.remote(item) for item in data]
results = ray.get(refs)

# Batched: 10 tasks of 100 elements each.
batch_size = 100
chunks = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
refs = [process_batch.remote(chunk) for chunk in chunks]
results = [item for chunk in ray.get(refs) for item in chunk]
ray.shutdown()
使用 ray.put #
python
import ray
import numpy as np

ray.init()

# Put the large array in the object store once; every task then receives
# a shared, zero-copy view instead of its own full copy of the data.
large_data = np.random.rand(10000, 10000)
data_ref = ray.put(large_data)

@ray.remote
def process(data, idx):
    """Read one element of the shared array.

    Ray dereferences a top-level ObjectRef argument automatically, so
    `data` arrives here as the ndarray itself.  The original called
    ray.get() on the already-resolved argument, which raises.
    """
    return data[idx, 0]

refs = [process.remote(data_ref, i) for i in range(100)]
ray.shutdown()
避免过度并行 #
python
import ray

ray.init(num_cpus=4)

@ray.remote
def task(x):
    """Trivial work unit used to compare submission strategies."""
    return x * 2

data = list(range(100))

# Submitting only as many tasks as there are CPUs...
refs = [task.remote(x) for x in data[:4]]
results = ray.get(refs)

# ...versus submitting everything at once: Ray queues the excess, so
# over-submission costs scheduling overhead rather than failing.
refs = [task.remote(x) for x in data]
results = ray.get(refs)
ray.shutdown()
实用模式 #
Map-Reduce 模式 #
python
import ray
from collections import Counter

ray.init()

@ray.remote
def map_fn(text):
    """Map step: count word occurrences in one document."""
    return Counter(text.split())

@ray.remote
def reduce_fn(counters):
    """Reduce step: merge the per-document counters.

    `counters` is a plain Python list of ObjectRefs -- Ray does NOT
    dereference refs nested inside containers, so they must be resolved
    with ray.get() here.  (The original updated the Counter with raw
    refs, which raises.)
    """
    total = Counter()
    for c in ray.get(counters):
        total.update(c)
    return total

texts = ["hello world", "hello ray", "world of distributed"]
map_refs = [map_fn.remote(text) for text in texts]
reduce_ref = reduce_fn.remote(map_refs)
print(ray.get(reduce_ref))
ray.shutdown()
流水线模式 #
python
import ray

ray.init()

@ray.remote
def stage1(x):
    """First hop: increment."""
    return x + 1

@ray.remote
def stage2(x):
    """Second hop: double."""
    return x * 2

@ray.remote
def stage3(x):
    """Final hop: square."""
    return x ** 2

def pipeline(x):
    """Chain the three stages; each ref feeds the next task directly."""
    step = stage1.remote(x)
    step = stage2.remote(step)
    step = stage3.remote(step)
    return ray.get(step)

print([pipeline(i) for i in range(5)])
ray.shutdown()
Fork-Join 模式 #
python
import ray

ray.init()

@ray.remote
def compute(x):
    """Fork stage: square one element."""
    return x ** 2

@ray.remote
def aggregate(results):
    """Join stage: sum the forked results.

    `results` arrives as a list of ObjectRefs -- refs nested inside a
    container are not auto-dereferenced -- so they must be resolved
    before summing.  (The original summed raw refs, which raises.)
    """
    return sum(ray.get(results))

data = list(range(10))
fork_refs = [compute.remote(x) for x in data]
join_ref = aggregate.remote(fork_refs)
print(ray.get(join_ref))  # 285
ray.shutdown()
调试技巧 #
日志记录 #
python
import ray
import logging

ray.init()

@ray.remote
def logged_task(x):
    """Log from inside the worker; output lands in the worker's log,
    which Ray forwards to the driver's console."""
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Processing {x}")
    return x * 2

ray.get(logged_task.remote(10))
ray.shutdown()
错误处理 #
python
import ray

ray.init()

@ray.remote
def failing_task():
    """Always fail, to demonstrate driver-side error handling."""
    raise ValueError("Intentional error")

# The worker-side exception is wrapped in RayTaskError and re-raised
# on the driver at the point of ray.get().
future = failing_task.remote()
try:
    ray.get(future)
except ray.exceptions.RayTaskError as e:
    print(f"Task failed: {e}")
ray.shutdown()
检查任务状态 #
python
import ray
import time

ray.init()

@ray.remote
def long_task():
    """Run for ten seconds so its status can be polled meaningfully."""
    time.sleep(10)
    return "done"

ref = long_task.remote()

# timeout=0 makes ray.wait a non-blocking poll: it reports which refs
# are ready right now without waiting.
ready, _ = ray.wait([ref], timeout=0)
print(ready == [])  # True: nothing finished yet

time.sleep(11)
ready, _ = ray.wait([ref], timeout=0)
print(ready != [])  # True: the task has completed
ray.shutdown()
下一步 #
掌握了 Tasks 之后,继续学习 Actors 角色,了解如何创建有状态的分布式服务!
最后更新:2026-04-05