RLHF Training Pipeline #

Training Pipeline Overview #

text
Full RLHF training pipeline
────────────────────────
Stage 0: Preparation
  ├── Define training objectives
  ├── Provision compute resources
  └── Design the data collection plan
              │
              ▼
Stage 1: Supervised fine-tuning (SFT)
  ├── Prepare instruction data
  ├── Train the SFT model
  └── Evaluate the SFT model
              │
              ▼
Stage 2: Reward model training
  ├── Collect preference data
  ├── Train the reward model
  └── Evaluate the reward model
              │
              ▼
Stage 3: PPO training
  ├── Configure PPO hyperparameters
  ├── Run PPO training
  └── Monitor training metrics
              │
              ▼
Stage 4: Evaluation and iteration
  ├── Evaluate the model
  ├── Analyze failure modes
  └── Iterate and improve
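
Tied together, the stages map directly onto the components built in this section. The sketch below is a minimal driver that chains them, assuming the `train_sft`, `train_reward_model` and `RLHFTrainer` definitions given later on this page; the model name and file paths are placeholders.

python
def run_rlhf_pipeline():
    base_model = "your-base-model"        # placeholder base model
    sft_dir = "outputs/sft"               # hypothetical output directories
    rm_dir = "outputs/reward_model"
    ppo_dir = "outputs/ppo"

    # Stage 1: supervised fine-tuning on instruction data
    train_sft(base_model, "data/sft.jsonl", sft_dir)

    # Stage 2: reward model trained on human preference comparisons
    train_reward_model(sft_dir, "data/preferences.jsonl", rm_dir)

    # Stage 3: PPO against the frozen reward model
    with open("data/ppo_prompts.txt") as f:
        prompts = [line.strip() for line in f if line.strip()]
    RLHFTrainer(sft_model_path=sft_dir,
                reward_model_path=rm_dir,
                output_dir=ppo_dir).train(prompts)

    # Stage 4: evaluation and iteration happen around this loop (see below)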

Stage 0: Preparation #

Define Training Objectives #

text
Clarify the goals:
────────────────────────
├── Intended use: dialogue, writing, code, analysis, etc.
├── Target users: general public, professionals, developers, etc.
├── Safety requirements: content safety, privacy protection, etc.
├── Performance requirements: response latency, output quality, etc.
└── Resource constraints: compute, time, budget, etc.

Define success metrics:
────────────────────────
├── Automatic metrics: perplexity, accuracy, etc.
├── Human evaluation: helpfulness, safety, truthfulness
├── Task-specific metrics: code pass rate, QA accuracy, etc.
└── User feedback: satisfaction, usage rates, etc.

Resource Planning #

text
Compute requirements (rough guide; see the memory estimate below):
────────────────────────
┌──────────────┬────────────┬────────────┬────────────┐
│  Model size  │  SFT       │  RM        │  PPO       │
├──────────────┼────────────┼────────────┼────────────┤
│  7B          │  4x A100   │  4x A100   │  8x A100   │
│  13B         │  8x A100   │  8x A100   │  16x A100  │
│  34B         │  16x A100  │  16x A100  │  32x A100  │
│  70B         │  32x A100  │  32x A100  │  64x A100  │
└──────────────┴────────────┴────────────┴────────────┘

Time estimates:
────────────────────────
├── SFT: 1-3 days
├── RM training: 1-2 days
├── PPO training: 3-7 days
└── Evaluation and iteration: ongoing
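
Where do these GPU counts come from? A rough memory estimate already explains most of it. The sketch below is a back-of-the-envelope calculation assuming bf16 weights trained with Adam (fp32 moments and master weights) and ignoring activations, so the result is a lower bound.

python
def estimate_training_memory_gb(num_params_billion, dtype_bytes=2):
    """Lower-bound memory for full fine-tuning with Adam (activations excluded)."""
    # bf16 weights + bf16 gradients + fp32 Adam m/v + fp32 master weights
    # ~= 2 + 2 + 8 + 4 = 16 bytes per parameter.
    bytes_per_param = dtype_bytes + dtype_bytes + 8 + 4
    return num_params_billion * 1e9 * bytes_per_param / 1024**3

# ~104 GB for a 7B model before activations, hence several 80 GB A100s for SFT.
# PPO also keeps reference, reward and value models resident, hence the larger counts.
print(f"{estimate_training_memory_gb(7):.0f} GB")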

Stage 1: Supervised Fine-Tuning (SFT) #

Data Preparation #

python
import json
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class SFTDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)
    
    def _load_data(self, data_path):
        data = []
        with open(data_path, 'r') as f:
            for line in f:
                item = json.loads(line)
                data.append(item)
        return data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        if "conversations" in item:
            text = self._format_conversation(item["conversations"])
        else:
            text = f"### Human: {item['instruction']}\n### Assistant: {item['output']}"
        
        enc = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        
        # Mask only the padding positions; prompt tokens still contribute to the
        # loss, which is a common simplification in SFT recipes.
        labels = enc["input_ids"].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels.squeeze(0),
        }
    
    def _format_conversation(self, conversations):
        text = ""
        for turn in conversations:
            if turn["role"] == "user":
                text += f"### Human: {turn['content']}\n"
            else:
                text += f"### Assistant: {turn['content']}\n"
        return text.strip()
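
Before launching a run it is worth sanity-checking the dataset. A minimal check, assuming a JSONL file at the hypothetical path `data/sft.jsonl` and any causal-LM tokenizer:

python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-base-model")  # placeholder
tokenizer.pad_token = tokenizer.eos_token

dataset = SFTDataset("data/sft.jsonl", tokenizer)   # hypothetical data path
loader = DataLoader(dataset, batch_size=2, shuffle=True)

batch = next(iter(loader))
print(batch["input_ids"].shape)          # torch.Size([2, 2048])
print((batch["labels"] != -100).sum())   # number of tokens that receive loss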

SFT Training Configuration #

python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)

def train_sft(
    base_model_name,
    train_data_path,
    output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
):
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    
    train_dataset = SFTDataset(train_data_path, tokenizer)
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.03,
        logging_steps=10,
        save_steps=500,
        save_total_limit=3,
        bf16=True,
        gradient_checkpointing=True,
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding=True,
        ),
    )
    
    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    return model
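
A typical call, with placeholder model name and paths:

python
sft_model = train_sft(
    base_model_name="your-base-model",   # placeholder, e.g. a 7B causal LM
    train_data_path="data/sft.jsonl",    # hypothetical instruction dataset
    output_dir="outputs/sft",
)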

SFT Evaluation #

python
def evaluate_sft(model, tokenizer, test_data, max_new_tokens=256):
    model.eval()
    results = []
    
    for item in test_data:
        prompt = f"### Human: {item['instruction']}\n### Assistant:"
        
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
        
        generated = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )
        
        results.append({
            "instruction": item["instruction"],
            "expected": item["output"],
            "generated": generated,
        })
    
    return results

Stage 2: Reward Model Training #

Collecting Preference Data #

python
class PreferenceDataCollector:
    def __init__(self, sft_model, tokenizer, num_responses=4):
        self.model = sft_model
        self.tokenizer = tokenizer
        self.num_responses = num_responses
    
    def collect_responses(self, prompts, temperature=1.0):
        all_responses = []
        
        for prompt in prompts:
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True
            ).to(self.model.device)
            
            responses = []
            for _ in range(self.num_responses):
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=256,
                        temperature=temperature,
                        top_p=0.9,
                        do_sample=True,
                    )
                
                response = self.tokenizer.decode(
                    outputs[0][inputs["input_ids"].shape[1]:],
                    skip_special_tokens=True
                )
                responses.append(response)
            
            all_responses.append({
                "prompt": prompt,
                "responses": responses,
            })
        
        return all_responses
    
    def create_annotation_task(self, collected_data):
        tasks = []
        for item in collected_data:
            task = {
                "prompt": item["prompt"],
                "responses": item["responses"],
                "instruction": "请对以下回复按质量从高到低排序",
            }
            tasks.append(task)
        return tasks
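
The annotation tasks above ask for a full ranking of the sampled responses, while reward model training consumes pairwise (chosen, rejected) comparisons, so rankings are normally expanded into pairs first. A minimal sketch, assuming each annotated item carries a hypothetical `ranking` field listing response indices from best to worst:

python
from itertools import combinations

def ranking_to_pairs(annotated_item):
    """Expand one ranked annotation into pairwise preference examples.

    Assumed input format:
      {"prompt": str, "responses": [str, ...], "ranking": [best_idx, ..., worst_idx]}
    """
    prompt = annotated_item["prompt"]
    responses = annotated_item["responses"]
    ranking = annotated_item["ranking"]

    pairs = []
    # Every response earlier in the ranking beats every response after it,
    # so N ranked responses yield N * (N - 1) / 2 comparison pairs.
    for better_pos, worse_pos in combinations(range(len(ranking)), 2):
        pairs.append({
            "prompt": prompt,
            "chosen": responses[ranking[better_pos]],
            "rejected": responses[ranking[worse_pos]],
        })
    return pairs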

Training the Reward Model #

python
# RewardModel, PreferenceDataset and RewardModelTrainer are assumed to be the
# components defined in the earlier reward-model section of this guide.
def train_reward_model(
    sft_model_name,
    preference_data_path,
    output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    learning_rate=1e-5,
):
    tokenizer = AutoTokenizer.from_pretrained(sft_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    reward_model = RewardModel(sft_model_name)
    
    train_dataset = PreferenceDataset(
        preference_data_path, tokenizer
    )
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.03,
        logging_steps=10,
        save_steps=500,
        bf16=True,
    )
    
    trainer = RewardModelTrainer(
        model=reward_model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )
    
    trainer.train()
    reward_model.save_pretrained(output_dir)
    
    return reward_model
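
Before moving on to PPO, check the reward model on held-out comparisons. The standard metric is pairwise accuracy: how often the chosen response scores above the rejected one (published reward models often land around 65-75% on human comparisons). A minimal sketch, assuming `RewardModel` is an `nn.Module` that returns one scalar per sequence and that eval items carry `prompt`, `chosen` and `rejected` fields:

python
import torch

def evaluate_reward_model(reward_model, tokenizer, eval_pairs, max_length=1024):
    """Pairwise accuracy of the reward model on held-out comparisons."""
    reward_model.eval()
    device = next(reward_model.parameters()).device
    correct = 0

    for item in eval_pairs:
        scores = []
        for response in (item["chosen"], item["rejected"]):
            enc = tokenizer(
                item["prompt"] + response,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            ).to(device)
            with torch.no_grad():
                # Assumed interface: reward_model(input_ids, attention_mask) -> scalar
                score = reward_model(enc["input_ids"], enc["attention_mask"])
            scores.append(score.item())
        correct += scores[0] > scores[1]

    return correct / len(eval_pairs)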

Stage 3: PPO Training #

Complete PPO Training Script #

python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# PPOModel, PPOTrainer, PPODataCollector and RewardModel are assumed to be the
# components built in the earlier PPO and reward-model sections of this guide.

class RLHFTrainer:
    def __init__(
        self,
        sft_model_path,
        reward_model_path,
        output_dir,
        config=None,
    ):
        # Merge any user-supplied config over the defaults so that every key the
        # training loop reads (e.g. ppo_epochs, save_steps) is always present.
        self.config = self.default_config()
        for section, values in (config or {}).items():
            self.config.setdefault(section, {}).update(values)
        self.output_dir = output_dir
        
        self.tokenizer = AutoTokenizer.from_pretrained(sft_model_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        
        self.policy_model = PPOModel(sft_model_path)
        self.ref_model = AutoModelForCausalLM.from_pretrained(sft_model_path)
        self.reward_model = RewardModel.from_pretrained(reward_model_path)
        
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False
        
        self.ppo_trainer = PPOTrainer(
            policy_model=self.policy_model,
            ref_model=self.ref_model,
            reward_model=self.reward_model,
            **self.config["ppo"]
        )
        
        self.data_collector = PPODataCollector(
            policy_model=self.policy_model,
            reward_model=self.reward_model,
            tokenizer=self.tokenizer,
        )
    
    def default_config(self):
        return {
            "ppo": {
                "learning_rate": 1e-6,
                "clip_range": 0.2,
                "kl_coef": 0.05,
                "value_coef": 0.5,
                "entropy_coef": 0.01,
                "gamma": 1.0,
                "gae_lambda": 0.95,
            },
            "training": {
                "total_steps": 100000,
                "batch_size": 64,
                "ppo_epochs": 4,
                "save_steps": 1000,
                "eval_steps": 500,
            },
            "generation": {
                "max_new_tokens": 256,
                "temperature": 1.0,
                "top_p": 0.9,
            }
        }
    
    def train(self, prompts):
        total_steps = self.config["training"]["total_steps"]
        batch_size = self.config["training"]["batch_size"]
        ppo_epochs = self.config["training"]["ppo_epochs"]
        save_steps = self.config["training"]["save_steps"]
        
        step = 0
        epoch = 0
        
        while step < total_steps:
            epoch += 1
            print(f"Epoch {epoch}")
            
            for batch_prompts in self._batch_prompts(prompts, batch_size):
                rollouts = self.data_collector.collect_rollouts(batch_prompts)
                
                for _ in range(ppo_epochs):
                    metrics = self.ppo_trainer.train_step(rollouts)
                    
                    if step % 10 == 0:
                        self._log_metrics(metrics, step)
                    
                    step += 1
                    
                    if step % save_steps == 0:
                        self._save_checkpoint(step)
                    
                    if step >= total_steps:
                        break
                
                if step >= total_steps:
                    break
        
        self._save_checkpoint("final")
    
    def _batch_prompts(self, prompts, batch_size):
        for i in range(0, len(prompts), batch_size):
            yield prompts[i:i + batch_size]
    
    def _log_metrics(self, metrics, step):
        print(f"Step {step}:")
        for key, value in metrics.items():
            print(f"  {key}: {value:.4f}")
    
    def _save_checkpoint(self, step):
        save_path = f"{self.output_dir}/checkpoint-{step}"
        self.policy_model.policy_model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
        print(f"Saved checkpoint to {save_path}")

Training Entry Point #

python
def main():
    import argparse
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft_model", required=True)
    parser.add_argument("--reward_model", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--prompts_file", required=True)
    parser.add_argument("--learning_rate", type=float, default=1e-6)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--total_steps", type=int, default=100000)
    args = parser.parse_args()
    
    with open(args.prompts_file, 'r') as f:
        prompts = [line.strip() for line in f if line.strip()]
    
    config = {
        "ppo": {
            "learning_rate": args.learning_rate,
            "clip_range": 0.2,
            "kl_coef": 0.05,
        },
        "training": {
            "total_steps": args.total_steps,
            "batch_size": args.batch_size,
        }
    }
    
    trainer = RLHFTrainer(
        sft_model_path=args.sft_model,
        reward_model_path=args.reward_model,
        output_dir=args.output_dir,
        config=config,
    )
    
    trainer.train(prompts)

if __name__ == "__main__":
    main()
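
Assuming the script above is saved as `train_rlhf.py` (hypothetical filename), a typical launch looks like:

text
python train_rlhf.py \
    --sft_model outputs/sft \
    --reward_model outputs/reward_model \
    --output_dir outputs/ppo \
    --prompts_file data/ppo_prompts.txt \
    --learning_rate 1e-6 \
    --batch_size 64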

Stage 4: Evaluation and Iteration #

Automatic Evaluation #

python
import numpy as np
import torch

class RLHFEvaluator:
    def __init__(self, model, tokenizer, reward_model=None):
        self.model = model
        self.tokenizer = tokenizer
        self.reward_model = reward_model
    
    def evaluate_perplexity(self, test_data):
        self.model.eval()
        total_loss = 0
        total_tokens = 0
        
        for item in test_data:
            text = f"### Human: {item['instruction']}\n### Assistant: {item['output']}"
            
            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=2048
            ).to(self.model.device)
            
            with torch.no_grad():
                outputs = self.model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    labels=inputs["input_ids"],
                )
            
            total_loss += outputs.loss.item() * inputs["input_ids"].shape[1]
            total_tokens += inputs["input_ids"].shape[1]
        
        perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
        return perplexity.item()
    
    def evaluate_rewards(self, test_data):
        if self.reward_model is None:
            return None
        
        self.model.eval()
        self.reward_model.eval()
        rewards = []
        
        for item in test_data:
            prompt = f"### Human: {item['instruction']}\n### Assistant:"
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                )
            
            full_input_ids = outputs
            attention_mask = torch.ones_like(full_input_ids)
            
            with torch.no_grad():
                reward = self.reward_model(full_input_ids, attention_mask)
            
            rewards.append(reward.item())
        
        return {
            "mean_reward": np.mean(rewards),
            "std_reward": np.std(rewards),
            "min_reward": np.min(rewards),
            "max_reward": np.max(rewards),
        }
    
    def evaluate_safety(self, test_data):
        self.model.eval()
        safety_issues = []
        
        for item in test_data:
            prompt = f"### Human: {item['instruction']}\n### Assistant:"
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.7,
                )
            
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )
            
            issues = self._check_safety(response)
            if issues:
                safety_issues.append({
                    "prompt": item["instruction"],
                    "response": response,
                    "issues": issues,
                })
        
        return {
            "total_samples": len(test_data),
            "issues_found": len(safety_issues),
            "issue_rate": len(safety_issues) / len(test_data),
            "details": safety_issues,
        }
    
    def _check_safety(self, text):
        issues = []
        
        harmful_patterns = [
            "violence", "illegal", "harmful"
        ]
        
        text_lower = text.lower()
        for pattern in harmful_patterns:
            if pattern in text_lower:
                issues.append(f"Potential harmful content: {pattern}")
        
        return issues
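
A typical evaluation run ties the three checks together. In the sketch below, `model`, `tokenizer` and `reward_model` are the trained policy, its tokenizer and the reward model loaded earlier; the test-set path is a placeholder.

python
import json

with open("data/eval.jsonl") as f:          # hypothetical held-out test set
    test_data = [json.loads(line) for line in f]

evaluator = RLHFEvaluator(model, tokenizer, reward_model=reward_model)

print("perplexity:", evaluator.evaluate_perplexity(test_data))
print("rewards:", evaluator.evaluate_rewards(test_data))
safety = evaluator.evaluate_safety(test_data)
print(f"safety issue rate: {safety['issue_rate']:.2%}")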

Human Evaluation #

python
import numpy as np

class HumanEvaluation:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
    def create_evaluation_task(self, test_data, num_samples=100):
        import random
        
        samples = random.sample(test_data, min(num_samples, len(test_data)))
        tasks = []
        
        for item in samples:
            prompt = f"### Human: {item['instruction']}\n### Assistant:"
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.7,
                )
            
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )
            
            tasks.append({
                "id": len(tasks),
                "instruction": item["instruction"],
                "response": response,
                "evaluation_dimensions": [
                    "helpfulness",
                    "truthfulness",
                    "safety",
                    "coherence",
                ]
            })
        
        return tasks
    
    def aggregate_results(self, evaluation_results):
        dimensions = ["helpfulness", "truthfulness", "safety", "coherence"]
        aggregated = {dim: [] for dim in dimensions}
        
        for result in evaluation_results:
            for dim in dimensions:
                if dim in result["scores"]:
                    aggregated[dim].append(result["scores"][dim])
        
        summary = {}
        for dim, scores in aggregated.items():
            summary[dim] = {
                "mean": np.mean(scores),
                "std": np.std(scores),
                "median": np.median(scores),
            }
        
        return summary

Training Best Practices #

Data Quality #

text
High-quality data standards:
────────────────────────
├── SFT data
│   ├── Responses written by trained annotators or domain experts
│   ├── Covers a diverse range of tasks
│   ├── Consistent, well-specified formatting
│   └── Manually reviewed for quality (an automated screening sketch follows this list)
│
├── Preference data
│   ├── Annotators are trained and calibrated
│   ├── Multiple annotators per item, resolved by consensus
│   ├── Covers a wide range of cases
│   └── Regular quality audits
│
└── PPO prompts
    ├── Drawn from real user queries
    ├── Distribution matches the target use cases
    ├── Appropriate difficulty
    └── Continuously refreshed
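
Manual review does not scale on its own, so in practice it sits behind a cheap automated screening pass (referenced in the SFT item above). The sketch below shows a few illustrative filters for SFT data; the field names match the SFTDataset format used earlier, and the thresholds are placeholders to tune per corpus.

python
def screen_sft_data(items, min_output_chars=20, max_output_chars=8000):
    """Cheap automated filters applied before human review (illustrative only)."""
    seen = set()
    kept = []
    for item in items:
        instruction = item.get("instruction", "").strip()
        output = item.get("output", "").strip()

        if not instruction or not output:
            continue  # drop records with empty fields
        if not (min_output_chars <= len(output) <= max_output_chars):
            continue  # drop degenerate or runaway lengths
        key = (instruction, output)
        if key in seen:
            continue  # drop exact duplicates
        seen.add(key)
        kept.append(item)
    return kept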

Training Stability #

text
Keeping training stable:
────────────────────────
├── Monitor the KL divergence
│   ├── Sudden KL spike: lower the learning rate
│   ├── Steadily growing KL: raise the KL coefficient
│   └── Use adaptive KL control (see the sketch after this list)
│
├── Monitor the reward
│   ├── Sudden reward drop: check the reward model
│   ├── Ever-rising reward: watch out for reward hacking
│   └── Cross-check with human evaluation
│
├── Save checkpoints regularly
│   ├── Save every 1000 steps
│   ├── Keep several checkpoints
│   └── Log training metrics
│
└── Early stopping
    ├── Monitor validation performance
    ├── Stop when it degrades
    └── Roll back to the best checkpoint
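
The adaptive KL control mentioned above is usually the proportional controller from Ziegler et al. (2019): raise the KL coefficient when the measured KL drifts above a target value, lower it when it drifts below. A minimal sketch; wiring it into the `PPOTrainer` (updating its `kl_coef` once per batch) is an assumption about that class's internals.

python
class AdaptiveKLController:
    """Proportional KL-coefficient controller (Ziegler et al., 2019)."""

    def __init__(self, init_kl_coef=0.05, target_kl=6.0, horizon=10000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, observed_kl, n_steps):
        # Clipped proportional error so one bad batch cannot blow up the coefficient.
        error = max(-0.2, min(0.2, observed_kl / self.target_kl - 1.0))
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef

# Illustrative use inside the training loop:
#   kl_ctl = AdaptiveKLController()
#   ... after each batch: kl_coef = kl_ctl.update(mean_kl, n_steps=batch_size)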

Iterative Improvement #

text
Iterative improvement loop:
────────────────────────
1. Analyze problems
   ├── Collect failure cases
   ├── Categorize issue types
   └── Prioritize them

2. Make targeted improvements
   ├── Add data for weak areas
   ├── Adjust training hyperparameters
   └── Revise the reward design

3. Validate the improvements
   ├── Side-by-side comparisons
   ├── Human evaluation
   └── A/B testing

4. Deploy the update
   ├── Gradual (canary) rollout
   ├── Monitor feedback
   └── Keep iterating

Next Steps #

You now have the complete RLHF training pipeline in hand. Next, move on to DPO (Direct Preference Optimization) to see a simpler approach to alignment.

Last updated: 2026-04-05