RLHF 工具与框架 #
工具概览 #
text
┌─────────────────────────────────────────────────────────────┐
│ RLHF 工具生态 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 训练框架: │
│ ├── TRL (Transformer Reinforcement Learning) │
│ ├── DeepSpeed-Chat │
│ ├── Axolotl │
│ ├── RL4LMs │
│ └── OpenRLHF │
│ │
│ 分布式训练: │
│ ├── DeepSpeed │
│ ├── FSDP (Fully Sharded Data Parallel) │
│ └── Megatron-LM │
│ │
│ 数据标注: │
│ ├── Label Studio │
│ ├── Argilla │
│ └── Scale AI │
│ │
│ 评估工具: │
│ ├── lm-evaluation-harness │
│ ├── AlpacaEval │
│ ├── MT-Bench │
│ └── HELM │
│ │
└─────────────────────────────────────────────────────────────┘
TRL(Transformer Reinforcement Learning) #
简介 #
TRL 是 Hugging Face 开发的强化学习训练库,专门用于训练 Transformer 模型,支持 SFT、奖励模型训练和 PPO/DPO 训练。
安装 #
bash
pip install trl
pip install trl[peft]
pip install trl[quantization]
SFT 训练 #
python
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
training_args = TrainingArguments(
output_dir="./sft_model",
num_train_epochs=3,
per_device_train_batch_size=4,
learning_rate=2e-5,
logging_steps=10,
save_steps=500,
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
max_seq_length=512,
)
trainer.train()
DPO 训练 #
python
from trl import DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("./sft_model")
ref_model = AutoModelForCausalLM.from_pretrained("./sft_model")
tokenizer = AutoTokenizer.from_pretrained("./sft_model")
dpo_trainer = DPOTrainer(
model=model,
ref_model=ref_model,
beta=0.1,
train_dataset=preference_dataset,
tokenizer=tokenizer,
args=TrainingArguments(
output_dir="./dpo_model",
num_train_epochs=3,
per_device_train_batch_size=4,
learning_rate=1e-6,
),
)
dpo_trainer.train()
PPO 训练 #
python
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer
config = PPOConfig(
model_name="meta-llama/Llama-2-7b-hf",
learning_rate=1e-6,
batch_size=16,
mini_batch_size=4,
gradient_accumulation_steps=4,
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
ppo_trainer = PPOTrainer(
config=config,
model=model,
ref_model=None,
tokenizer=tokenizer,
dataset=dataset,
data_collator=collator,
)
generation_kwargs = {
"max_new_tokens": 256,
"temperature": 1.0,
"top_p": 0.9,
}
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
query_tensors = batch["input_ids"]
response_tensors = ppo_trainer.generate(
query_tensors, **generation_kwargs
)
batch["response"] = tokenizer.batch_decode(
response_tensors, skip_special_tokens=True
)
rewards = [reward_model(output) for output in batch["response"]]
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, batch, rewards)
DeepSpeed-Chat #
简介 #
DeepSpeed-Chat 是微软开发的端到端 RLHF 训练框架,支持完整的 RLHF 训练流程,针对大规模分布式训练进行了优化。
安装 #
bash
pip install deepspeed
git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/applications/DeepSpeed-Chat
pip install -r requirements.txt
训练脚本 #
bash
python train.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--output_dir ./output \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--learning_rate 2e-5 \
--deepspeed ds_config.json
DeepSpeed 配置 #
json
{
"train_batch_size": 64,
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 4,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 2e-5,
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 2e-5,
"warmup_num_steps": 100
}
},
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
}
}
}
三阶段训练 #
bash
python train.py --step 1 \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--output_dir ./step1_sft
python train.py --step 2 \
--model_name_or_path ./step1_sft \
--output_dir ./step2_rm
python train.py --step 3 \
--model_name_or_path ./step1_sft \
--reward_model_path ./step2_rm \
--output_dir ./step3_ppo
Axolotl #
简介 #
Axolotl 是一个配置驱动的微调工具,支持多种训练方法,包括 SFT、DPO、PPO 等,配置简单,易于使用。
安装 #
bash
pip install axolotl
配置文件 #
yaml
base_model: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: ./data/sft_data.json
type: alpaca
dataset_prepared_path: ./prepared_data
val_set_size: 0.05
output_dir: ./output
sequence_len: 512
sample_packing: false
adapter: qlora
lora_model_dir:
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 10
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
训练命令 #
bash
axolotl train config.yaml
DPO 配置 #
yaml
base_model: ./sft_model
model_type: LlamaForCausalLM
datasets:
- path: ./data/preference_data.json
type: dpo
dpo_beta: 0.1
output_dir: ./dpo_output
num_epochs: 3
micro_batch_size: 4
gradient_accumulation_steps: 4
learning_rate: 1e-6
OpenRLHF #
简介 #
OpenRLHF 是一个高性能的 RLHF 训练框架,支持大规模分布式训练,优化了 PPO 训练效率。
安装 #
bash
pip install openrlhf
训练脚本 #
bash
openrlhf.train_ppo \
--pretrain meta-llama/Llama-2-7b-hf \
--reward_model ./reward_model \
--prompt_data ./prompts.json \
--output_dir ./output \
--num_episodes 1000 \
--max_new_tokens 256 \
--temperature 1.0 \
--learning_rate 1e-6 \
--batch_size 64
数据标注工具 #
Label Studio #
python
import label_studio_sdk
ls_client = label_studio_sdk.Client(
url="http://localhost:8080",
api_key="your-api-key"
)
project = ls_client.start_project(
title="RLHF Preference Annotation",
label_config="""
<View>
<Text name="prompt" value="$prompt"/>
<Header value="Response 1"/>
<Text name="response1" value="$response1"/>
<Header value="Response 2"/>
<Text name="response2" value="$response2"/>
<Choices name="preference" toName="prompt" choice="single-radio">
<Choice value="response1"/>
<Choice value="response2"/>
<Choice value="tie"/>
</Choices>
</View>
"""
)
project.import_tasks(tasks_data)
Argilla #
python
import argilla as rg
rg.init(api_url="http://localhost:6900", api_key="your-api-key")
dataset = rg.FeedbackDataset(
guidelines="Please choose the better response",
fields=[
rg.TextField(name="prompt"),
rg.TextField(name="response1"),
rg.TextField(name="response2"),
],
questions=[
rg.RatingQuestion(
name="preference",
description="Choose the better response",
values=[1, 2]
)
]
)
dataset.push_to_argilla(name="preference-dataset")
评估工具 #
lm-evaluation-harness #
bash
pip install lm-eval
lm_eval --model hf \
--model_args pretrained=./rlhf_model \
--tasks hellaswag,mmlu,truthfulqa \
--batch_size 8
MT-Bench #
bash
git clone https://github.com/lm-sys/FastChat.git
cd FastChat
python gen_model_answer.py \
--model-path ./rlhf_model \
--model-id rlhf-model \
--num-gpus 1
python gen_judgment.py \
--model-list rlhf-model \
--parallel 4
python show_result.py --model-list rlhf-model
AlpacaEval #
bash
pip install alpaca_eval
alpaca_eval \
--model_outputs ./outputs.json \
--reference_outputs ./reference.json \
--annotators_config chatgpt_fn
工具对比 #
text
┌─────────────────────────────────────────────────────────────┐
│ RLHF 工具对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 工具 │ 易用性 │ 功能 │ 性能 │ 适用场景 │
│ ───────────────────────────────────────────────────────── │
│ TRL │ 高 │ 全面 │ 中 │ 快速开发 │
│ DeepSpeed-Chat│ 中 │ 全面 │ 高 │ 大规模训练 │
│ Axolotl │ 高 │ 中 │ 中 │ 配置驱动 │
│ OpenRLHF │ 中 │ 中 │ 高 │ 高性能 PPO │
│ RL4LMs │ 中 │ 中 │ 中 │ 研究实验 │
│ │
└─────────────────────────────────────────────────────────────┘
选择建议 #
text
选择 TRL 的场景:
────────────────────────
├── 快速原型开发
├── 与 Hugging Face 生态集成
├── 中小规模模型训练
└── 学习和研究
选择 DeepSpeed-Chat 的场景:
────────────────────────
├── 大规模分布式训练
├── 企业级生产环境
├── 需要完整 RLHF 流程
└── 有 DeepSpeed 经验
选择 Axolotl 的场景:
────────────────────────
├── 配置驱动的工作流
├── 快速实验不同配置
├── LoRA/QLoRA 微调
└── 简单易用优先
选择 OpenRLHF 的场景:
────────────────────────
├── 高性能 PPO 训练
├── 大规模模型
├── 计算资源充足
└── 追求训练效率
总结 #
本指南介绍了 RLHF 领域的主要工具和框架。选择合适的工具取决于你的具体需求、资源和技术背景。建议从 TRL 开始学习,然后根据需要尝试其他工具。
恭喜你完成了 RLHF 文档的学习!现在你已经掌握了从基础概念到实践应用的完整知识体系。祝你 RLHF 项目顺利!
最后更新:2026-04-05