RLHF Training Pipeline #
Training Pipeline Overview #
text
┌─────────────────────────────────────────────────────────────┐
│               Complete RLHF Training Pipeline               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Stage 0: Preparation                                       │
│  ┌─────────────────────────────────────────────────────┐   │
│  │ ├── Define training objectives                      │   │
│  │ ├── Provision compute resources                     │   │
│  │ └── Design the data collection plan                 │   │
│  └─────────────────────────────────────────────────────┘   │
│                             │                               │
│                             ▼                               │
│  Stage 1: Supervised Fine-Tuning (SFT)                      │
│  ┌─────────────────────────────────────────────────────┐   │
│  │ ├── Prepare instruction data                        │   │
│  │ ├── Train the SFT model                             │   │
│  │ └── Evaluate the SFT model                          │   │
│  └─────────────────────────────────────────────────────┘   │
│                             │                               │
│                             ▼                               │
│  Stage 2: Reward Model Training                             │
│  ┌─────────────────────────────────────────────────────┐   │
│  │ ├── Collect preference data                         │   │
│  │ ├── Train the reward model                          │   │
│  │ └── Evaluate the reward model                       │   │
│  └─────────────────────────────────────────────────────┘   │
│                             │                               │
│                             ▼                               │
│  Stage 3: PPO Training                                      │
│  ┌─────────────────────────────────────────────────────┐   │
│  │ ├── Configure PPO hyperparameters                   │   │
│  │ ├── Run PPO training                                │   │
│  │ └── Monitor training metrics                        │   │
│  └─────────────────────────────────────────────────────┘   │
│                             │                               │
│                             ▼                               │
│  Stage 4: Evaluation and Iteration                          │
│  ┌─────────────────────────────────────────────────────┐   │
│  │ ├── Evaluate the model                              │   │
│  │ ├── Analyze failure modes                           │   │
│  │ └── Iterate and improve                             │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Stage 0: Preparation #
Define the Training Objectives #
text
Clarify the objectives:
────────────────────────
├── Model purpose: dialogue, writing, code, analysis, etc.
├── Target users: general users, professionals, developers, etc.
├── Safety requirements: content safety, privacy protection, etc.
├── Performance requirements: response latency, output quality, etc.
└── Resource constraints: compute, time, budget, etc.
Define success metrics:
────────────────────────
├── Automatic metrics: perplexity, accuracy, etc.
├── Human evaluation: helpfulness, safety, truthfulness
├── Task-specific metrics: code pass rate, QA accuracy, etc.
└── User feedback: satisfaction, usage rate, etc.
Resource Planning #
text
Compute resource requirements:
────────────────────────
┌────────────────────────────────────────────────┐
│ Model size │ SFT       │ RM        │ PPO       │
├────────────────────────────────────────────────┤
│ 7B         │ 4x A100   │ 4x A100   │ 8x A100   │
│ 13B        │ 8x A100   │ 8x A100   │ 16x A100  │
│ 34B        │ 16x A100  │ 16x A100  │ 32x A100  │
│ 70B        │ 32x A100  │ 32x A100  │ 64x A100  │
└────────────────────────────────────────────────┘
Time estimates:
────────────────────────
├── SFT: 1-3 days
├── RM training: 1-2 days
├── PPO training: 3-7 days
└── Evaluation and iteration: ongoing
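As a sanity check on the table above, you can estimate training memory from the parameter count. A minimal sketch, assuming bf16 weights and gradients plus fp32 AdamW optimizer states (roughly 16-18 bytes per parameter); activations, sequence length, and the parallelism strategy are ignored, so treat the result as a lower bound:
python
def estimate_training_memory_gb(num_params_billion, bytes_per_param=18):
    """Rough lower bound on memory needed to train with AdamW.

    bytes_per_param ~= 2 (bf16 weights) + 2 (bf16 grads)
                     + 4 (fp32 master weights) + 8 (fp32 Adam m, v),
    plus a small allowance for overhead. Activations are NOT included.
    """
    return num_params_billion * 1e9 * bytes_per_param / 1024**3

for size in [7, 13, 34, 70]:
    mem = estimate_training_memory_gb(size)
    # Divide by per-GPU memory (80 GB for an A100-80GB) for a GPU-count floor.
    print(f"{size}B: ~{mem:.0f} GB -> at least {mem / 80:.0f}x A100-80GB")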
Stage 1: Supervised Fine-Tuning (SFT) #
Data Preparation #
python
import json
from torch.utils.data import Dataset
from transformers import AutoTokenizer
class SFTDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=2048):
self.tokenizer = tokenizer
self.max_length = max_length
self.data = self._load_data(data_path)
def _load_data(self, data_path):
data = []
with open(data_path, 'r') as f:
for line in f:
item = json.loads(line)
data.append(item)
return data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
if "conversations" in item:
text = self._format_conversation(item["conversations"])
else:
text = f"### Human: {item['instruction']}\n### Assistant: {item['output']}"
enc = self.tokenizer(
text,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
)
        labels = enc["input_ids"].clone()
        # Ignore padding in the loss. Note that prompt tokens are still
        # trained on here; masking them out as well is a common refinement.
        labels[labels == self.tokenizer.pad_token_id] = -100
return {
"input_ids": enc["input_ids"].squeeze(0),
"attention_mask": enc["attention_mask"].squeeze(0),
"labels": labels.squeeze(0),
}
def _format_conversation(self, conversations):
text = ""
for turn in conversations:
if turn["role"] == "user":
text += f"### Human: {turn['content']}\n"
else:
text += f"### Assistant: {turn['content']}\n"
return text.strip()
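The loader expects one JSON object per line, in either of the two shapes that __getitem__ handles; the content strings below are illustrative:
text
{"instruction": "Explain RLHF in one sentence.", "output": "RLHF fine-tunes a language model with human preference feedback."}
{"conversations": [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello! How can I help?"}]}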
SFT Training Configuration #
python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
def train_sft(
base_model_name,
train_data_path,
output_dir,
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-5,
):
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
)
train_dataset = SFTDataset(train_data_path, tokenizer)
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=learning_rate,
weight_decay=0.01,
warmup_ratio=0.03,
logging_steps=10,
save_steps=500,
save_total_limit=3,
bf16=True,
gradient_checkpointing=True,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
padding=True,
),
)
trainer.train()
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
return model
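A typical invocation might look like the following; the base checkpoint name and paths are placeholders:
python
# Hypothetical paths; substitute your own checkpoint and data locations.
sft_model = train_sft(
    base_model_name="meta-llama/Llama-2-7b-hf",  # any causal LM checkpoint
    train_data_path="data/sft_train.jsonl",      # JSONL as described above
    output_dir="checkpoints/sft",
)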
SFT Evaluation #
python
def evaluate_sft(model, tokenizer, test_data, max_new_tokens=256):
model.eval()
results = []
for item in test_data:
prompt = f"### Human: {item['instruction']}\n### Assistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=0.7,
top_p=0.9,
do_sample=True,
)
generated = tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
results.append({
"instruction": item["instruction"],
"expected": item["output"],
"generated": generated,
})
return results
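Before investing in full evaluation, quick surface statistics over the results can catch obvious regressions; a minimal sketch:
python
results = evaluate_sft(sft_model, tokenizer, test_data)

# Crude surface checks; no substitute for reading samples or human review.
avg_words = sum(len(r["generated"].split()) for r in results) / len(results)
empty = sum(1 for r in results if not r["generated"].strip())
print(f"avg length: {avg_words:.1f} words, empty responses: {empty}")
for r in results[:3]:  # spot-check a few outputs by hand
    print("-" * 40)
    print("PROMPT:   ", r["instruction"])
    print("GENERATED:", r["generated"][:200])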
Stage 2: Reward Model Training #
Preference Data Collection #
python
class PreferenceDataCollector:
def __init__(self, sft_model, tokenizer, num_responses=4):
self.model = sft_model
self.tokenizer = tokenizer
self.num_responses = num_responses
def collect_responses(self, prompts, temperature=1.0):
all_responses = []
for prompt in prompts:
inputs = self.tokenizer(
prompt, return_tensors="pt", truncation=True
).to(self.model.device)
responses = []
for _ in range(self.num_responses):
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
temperature=temperature,
top_p=0.9,
do_sample=True,
)
response = self.tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
responses.append(response)
all_responses.append({
"prompt": prompt,
"responses": responses,
})
return all_responses
def create_annotation_task(self, collected_data):
tasks = []
for item in collected_data:
task = {
"prompt": item["prompt"],
"responses": item["responses"],
"instruction": "请对以下回复按质量从高到低排序",
}
tasks.append(task)
return tasks
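The annotation tasks above ask for a ranking, while pairwise reward-model training consumes (chosen, rejected) pairs. A minimal conversion sketch, assuming each annotated item comes back with a hypothetical `ranking` field listing response indices from best to worst:
python
from itertools import combinations

def ranking_to_pairs(annotated_item):
    """Expand one ranked annotation into pairwise preference examples.

    `ranking` (assumed field) lists response indices best-to-worst,
    e.g. [2, 0, 3, 1] for four responses.
    """
    pairs = []
    for better, worse in combinations(annotated_item["ranking"], 2):
        pairs.append({
            "prompt": annotated_item["prompt"],
            "chosen": annotated_item["responses"][better],
            "rejected": annotated_item["responses"][worse],
        })
    return pairs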
Training the Reward Model #
python
def train_reward_model(
sft_model_name,
preference_data_path,
output_dir,
num_train_epochs=1,
per_device_train_batch_size=16,
learning_rate=1e-5,
):
    tokenizer = AutoTokenizer.from_pretrained(sft_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # RewardModel, PreferenceDataset, and RewardModelTrainer are assumed to be
    # the implementations introduced in the earlier reward-model material.
    reward_model = RewardModel(sft_model_name)
train_dataset = PreferenceDataset(
preference_data_path, tokenizer
)
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=per_device_train_batch_size,
learning_rate=learning_rate,
weight_decay=0.01,
warmup_ratio=0.03,
logging_steps=10,
save_steps=500,
bf16=True,
)
trainer = RewardModelTrainer(
model=reward_model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
trainer.train()
reward_model.save_pretrained(output_dir)
return reward_model
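RewardModelTrainer is used as a black box above; for reference, its core is typically the pairwise Bradley-Terry loss below. A minimal sketch, assuming the reward model returns one scalar per sequence:
python
import torch.nn.functional as F

def pairwise_reward_loss(reward_model, chosen_ids, chosen_mask,
                         rejected_ids, rejected_mask):
    """-log sigmoid(r_chosen - r_rejected), averaged over the batch."""
    r_chosen = reward_model(chosen_ids, chosen_mask)        # shape: (batch,)
    r_rejected = reward_model(rejected_ids, rejected_mask)  # shape: (batch,)
    return -F.logsigmoid(r_chosen - r_rejected).mean()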
Stage 3: PPO Training #
Complete PPO Training Script #
python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class RLHFTrainer:
def __init__(
self,
sft_model_path,
reward_model_path,
output_dir,
config=None,
):
        # Merge user overrides into the defaults so a partial config (like the
        # one built in main() below) does not drop required keys.
        self.config = self.default_config()
        for section, overrides in (config or {}).items():
            self.config.setdefault(section, {}).update(overrides)
self.output_dir = output_dir
self.tokenizer = AutoTokenizer.from_pretrained(sft_model_path)
self.tokenizer.pad_token = self.tokenizer.eos_token
        # PPOModel, PPOTrainer, RewardModel, and PPODataCollector are assumed
        # to be the implementations introduced in the earlier PPO material.
        self.policy_model = PPOModel(sft_model_path)
        self.ref_model = AutoModelForCausalLM.from_pretrained(sft_model_path)
        self.reward_model = RewardModel.from_pretrained(reward_model_path)
self.ref_model.eval()
for param in self.ref_model.parameters():
param.requires_grad = False
self.ppo_trainer = PPOTrainer(
policy_model=self.policy_model,
ref_model=self.ref_model,
reward_model=self.reward_model,
**self.config["ppo"]
)
self.data_collector = PPODataCollector(
policy_model=self.policy_model,
reward_model=self.reward_model,
tokenizer=self.tokenizer,
)
def default_config(self):
return {
"ppo": {
"learning_rate": 1e-6,
"clip_range": 0.2,
"kl_coef": 0.05,
"value_coef": 0.5,
"entropy_coef": 0.01,
"gamma": 1.0,
"gae_lambda": 0.95,
},
"training": {
"total_steps": 100000,
"batch_size": 64,
"ppo_epochs": 4,
"save_steps": 1000,
"eval_steps": 500,
},
"generation": {
"max_new_tokens": 256,
"temperature": 1.0,
"top_p": 0.9,
}
}
def train(self, prompts):
total_steps = self.config["training"]["total_steps"]
batch_size = self.config["training"]["batch_size"]
ppo_epochs = self.config["training"]["ppo_epochs"]
save_steps = self.config["training"]["save_steps"]
step = 0
epoch = 0
while step < total_steps:
epoch += 1
print(f"Epoch {epoch}")
for batch_prompts in self._batch_prompts(prompts, batch_size):
rollouts = self.data_collector.collect_rollouts(batch_prompts)
for _ in range(ppo_epochs):
metrics = self.ppo_trainer.train_step(rollouts)
if step % 10 == 0:
self._log_metrics(metrics, step)
step += 1
if step % save_steps == 0:
self._save_checkpoint(step)
if step >= total_steps:
break
if step >= total_steps:
break
self._save_checkpoint("final")
def _batch_prompts(self, prompts, batch_size):
for i in range(0, len(prompts), batch_size):
yield prompts[i:i + batch_size]
def _log_metrics(self, metrics, step):
print(f"Step {step}:")
for key, value in metrics.items():
print(f" {key}: {value:.4f}")
def _save_checkpoint(self, step):
save_path = f"{self.output_dir}/checkpoint-{step}"
self.policy_model.policy_model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)
print(f"Saved checkpoint to {save_path}")
Training Entry Point #
python
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--sft_model", required=True)
parser.add_argument("--reward_model", required=True)
parser.add_argument("--output_dir", required=True)
parser.add_argument("--prompts_file", required=True)
parser.add_argument("--learning_rate", type=float, default=1e-6)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--total_steps", type=int, default=100000)
args = parser.parse_args()
with open(args.prompts_file, 'r') as f:
prompts = [line.strip() for line in f if line.strip()]
config = {
"ppo": {
"learning_rate": args.learning_rate,
"clip_range": 0.2,
"kl_coef": 0.05,
},
"training": {
"total_steps": args.total_steps,
"batch_size": args.batch_size,
}
}
trainer = RLHFTrainer(
sft_model_path=args.sft_model,
reward_model_path=args.reward_model,
output_dir=args.output_dir,
config=config,
)
trainer.train(prompts)
if __name__ == "__main__":
main()
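Launching the script might look like this; the file name train_rlhf.py and all paths are placeholders:
text
python train_rlhf.py \
  --sft_model checkpoints/sft \
  --reward_model checkpoints/rm \
  --output_dir checkpoints/ppo \
  --prompts_file data/ppo_prompts.txt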
Stage 4: Evaluation and Iteration #
Automatic Evaluation #
python
import numpy as np
import torch

class RLHFEvaluator:
def __init__(self, model, tokenizer, reward_model=None):
self.model = model
self.tokenizer = tokenizer
self.reward_model = reward_model
def evaluate_perplexity(self, test_data):
self.model.eval()
total_loss = 0
total_tokens = 0
for item in test_data:
text = f"### Human: {item['instruction']}\n### Assistant: {item['output']}"
inputs = self.tokenizer(
text, return_tensors="pt", truncation=True, max_length=2048
).to(self.model.device)
with torch.no_grad():
outputs = self.model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
labels=inputs["input_ids"],
)
total_loss += outputs.loss.item() * inputs["input_ids"].shape[1]
total_tokens += inputs["input_ids"].shape[1]
perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
return perplexity.item()
def evaluate_rewards(self, test_data):
if self.reward_model is None:
return None
self.model.eval()
self.reward_model.eval()
rewards = []
for item in test_data:
prompt = f"### Human: {item['instruction']}\n### Assistant:"
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
do_sample=True,
)
            # generate() returns prompt + response ids; score the full sequence.
            full_input_ids = outputs
            attention_mask = torch.ones_like(full_input_ids)
with torch.no_grad():
reward = self.reward_model(full_input_ids, attention_mask)
rewards.append(reward.item())
return {
"mean_reward": np.mean(rewards),
"std_reward": np.std(rewards),
"min_reward": np.min(rewards),
"max_reward": np.max(rewards),
}
def evaluate_safety(self, test_data):
self.model.eval()
safety_issues = []
for item in test_data:
prompt = f"### Human: {item['instruction']}\n### Assistant:"
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
)
response = self.tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
issues = self._check_safety(response)
if issues:
safety_issues.append({
"prompt": item["instruction"],
"response": response,
"issues": issues,
})
return {
"total_samples": len(test_data),
"issues_found": len(safety_issues),
"issue_rate": len(safety_issues) / len(test_data),
"details": safety_issues,
}
def _check_safety(self, text):
        issues = []
        # Naive keyword screen, useful only as a placeholder; a real pipeline
        # should use a trained safety classifier.
        harmful_patterns = [
            "violence", "illegal", "harmful"
        ]
text_lower = text.lower()
for pattern in harmful_patterns:
if pattern in text_lower:
issues.append(f"Potential harmful content: {pattern}")
return issues
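Wiring the evaluator into a run might look like this; policy_model, tokenizer, reward_model, and test_data are whatever you produced in the earlier stages:
python
evaluator = RLHFEvaluator(policy_model, tokenizer, reward_model=reward_model)

print("perplexity:", evaluator.evaluate_perplexity(test_data))
print("rewards:", evaluator.evaluate_rewards(test_data))
safety = evaluator.evaluate_safety(test_data)
print(f"safety issue rate: {safety['issue_rate']:.2%}")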
Human Evaluation #
python
class HumanEvaluation:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def create_evaluation_task(self, test_data, num_samples=100):
import random
samples = random.sample(test_data, min(num_samples, len(test_data)))
tasks = []
for item in samples:
prompt = f"### Human: {item['instruction']}\n### Assistant:"
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
)
response = self.tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
tasks.append({
"id": len(tasks),
"instruction": item["instruction"],
"response": response,
"evaluation_dimensions": [
"helpfulness",
"truthfulness",
"safety",
"coherence",
]
})
return tasks
    def aggregate_results(self, evaluation_results):
        # Each annotator result is expected to carry a "scores" dict keyed by
        # dimension, e.g. {"helpfulness": 4, "safety": 5, ...}.
        dimensions = ["helpfulness", "truthfulness", "safety", "coherence"]
        aggregated = {dim: [] for dim in dimensions}
for result in evaluation_results:
for dim in dimensions:
if dim in result["scores"]:
aggregated[dim].append(result["scores"][dim])
summary = {}
for dim, scores in aggregated.items():
summary[dim] = {
"mean": np.mean(scores),
"std": np.std(scores),
"median": np.median(scores),
}
return summary
Training Best Practices #
Data Quality #
text
Standards for high-quality data:
────────────────────────
├── SFT data
│   ├── Responses written by skilled writers
│   ├── Diverse task coverage
│   ├── Consistent, uniform formatting
│   └── Manual quality review
│
├── Preference data
│   ├── Trained annotators
│   ├── Multiple annotators with consensus resolution
│   ├── Coverage of varied scenarios
│   └── Periodic quality audits
│
└── PPO prompts
    ├── Sourced from real users
    ├── Reasonable distribution
    ├── Appropriate difficulty
    └── Continuously refreshed
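Some of these standards can be checked automatically before any human review. A minimal pre-filter sketch; the field names match the SFT format above and the thresholds are illustrative:
python
def prefilter_sft_data(items, min_chars=10, max_chars=8000):
    """Drop obviously bad SFT examples before human review.

    Illustrative checks only: length bounds and exact-duplicate removal.
    """
    seen = set()
    kept = []
    for item in items:
        text = item.get("output", "")
        key = (item.get("instruction", ""), text)
        if key in seen:  # exact duplicate of an earlier example
            continue
        if not (min_chars <= len(text) <= max_chars):  # degenerate length
            continue
        seen.add(key)
        kept.append(item)
    return kept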
Training Stability #
text
Keeping training stable:
────────────────────────
├── Monitor KL divergence
│   ├── KL spikes suddenly: lower the learning rate
│   ├── KL keeps growing: raise the KL coefficient
│   └── Use adaptive KL control (sketched below)
│
├── Monitor rewards
│   ├── Reward drops suddenly: inspect the reward model
│   ├── Reward climbs steadily: watch for reward hacking
│   └── Cross-check with human evaluation
│
├── Save checkpoints regularly
│   ├── Save every 1000 steps
│   ├── Keep multiple checkpoints
│   └── Log training metrics
│
└── Early stopping
    ├── Monitor validation performance
    ├── Stop when performance degrades
    └── Roll back to the best checkpoint
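Adaptive KL control, mentioned in the list above, adjusts kl_coef toward a target divergence instead of keeping it fixed. A minimal sketch in the style of the controller from Ziegler et al. (2019); the target and horizon values are illustrative:
python
class AdaptiveKLController:
    """Nudge kl_coef so the observed KL tracks a target value."""

    def __init__(self, init_kl_coef=0.05, target_kl=6.0, horizon=10000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Proportional error, clipped so one bad batch cannot swing
        # the coefficient too far in a single update.
        error = max(min(current_kl / self.target_kl - 1, 0.2), -0.2)
        self.kl_coef *= 1 + error * n_steps / self.horizon
        return self.kl_coef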
Iterative Improvement #
text
Iteration loop:
────────────────────────
1. Analyze problems
   ├── Collect failure cases
   ├── Categorize the issues
   └── Set priorities
2. Make targeted improvements
   ├── Add data for weak areas
   ├── Tune training hyperparameters
   └── Revise the reward design
3. Validate the improvements
   ├── Side-by-side comparisons
   ├── Human evaluation
   └── A/B testing
4. Deploy the update
   ├── Gradual rollout
   ├── Monitor feedback
   └── Keep iterating
Next Steps #
You now have the complete RLHF training pipeline. Next, learn about DPO (Direct Preference Optimization), a simpler approach to alignment!
Last updated: 2026-04-05