PPO 算法 #

PPO 概述 #

PPO(Proximal Policy Optimization,近端策略优化)是 OpenAI 于 2017 年提出的一种强化学习算法,因其简单高效、稳定性好而成为 RLHF 中最常用的策略优化算法。

为什么选择 PPO? #

text
┌─────────────────────────────────────────────────────────────┐
│                    PPO 的优势                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  相比传统策略梯度方法:                                       │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  ✅ 训练更稳定                                      │   │
│  │  ✅ 样本效率更高                                    │   │
│  │  ✅ 超参数更少                                      │   │
│  │  ✅ 实现更简单                                      │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
│  相比 TRPO:                                                 │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  ✅ 无需复杂的二阶优化                              │   │
│  │  ✅ 计算效率更高                                    │   │
│  │  ✅ 更容易实现                                      │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
│  在 RLHF 中的优势:                                          │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  ✅ 适合大规模模型                                  │   │
│  │  ✅ 训练过程可控                                    │   │
│  │  ✅ 与 KL 约束天然结合                              │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

PPO 核心原理 #

策略梯度回顾 #

text
策略梯度定理:
────────────────────────
∇J(θ) = E[∇log π_θ(a|s) * A(s,a)]

其中 A(s,a) 是优势函数,表示动作 a 比平均好多少

问题:
────────────────────────
├── 更新步长难以选择
├── 步长太大:策略剧烈变化,性能下降
├── 步长太小:学习太慢
└── 需要一种自动控制步长的方法

重要性采样 #

PPO 使用重要性采样来复用旧策略收集的数据:

text
重要性采样权重:
────────────────────────
ρ_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)

重要性采样目标:
────────────────────────
L(θ) = E[ρ_t(θ) * A_t]

问题:当 θ 与 θ_old 差距较大时,估计方差大

PPO 裁剪目标 #

PPO 的核心创新是裁剪目标函数,限制策略更新幅度:

text
PPO-Clip 目标函数:
────────────────────────
L^CLIP(θ) = E[min(
    ρ_t(θ) * A_t,
    clip(ρ_t(θ), 1-ε, 1+ε) * A_t
)]

其中:
├── ρ_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
├── ε 是裁剪参数,通常 0.1 ~ 0.3
└── A_t 是优势估计

直观理解:
────────────────────────
├── 当 A_t > 0(好动作):
│   限制 ρ_t 最大为 1+ε,防止过度增加概率
│
├── 当 A_t < 0(坏动作):
│   限制 ρ_t 最小为 1-ε,防止过度降低概率
│
└── 效果:策略更新被限制在"信任区域"内

PPO 目标函数图解 #

text
┌─────────────────────────────────────────────────────────────┐
│                    PPO 裁剪目标函数                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  当 A_t > 0 时(好动作):                                   │
│                                                             │
│  L     │                                                    │
│  ↑     │          ┌───────────                             │
│        │         /                                          │
│        │        /                                           │
│        │       /                                            │
│  0 ────┼──────┼─────────────────→ ρ                        │
│        0     1-ε  1+ε                                       │
│                                                             │
│  当 A_t < 0 时(坏动作):                                   │
│                                                             │
│  L     │                                                    │
│  ↑     │   ─────────┐                                       │
│        │            \                                       │
│        │             \                                      │
│        │              \                                     │
│  0 ────┼──────────────┼─────────→ ρ                        │
│        0      1-ε    1+ε                                    │
│                                                             │
│  裁剪效果:限制策略变化幅度,保证训练稳定性                   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

PPO 完整算法 #

算法流程 #

text
PPO 算法流程:
────────────────────────
for iteration = 1, 2, ... do
    1. 使用当前策略 π_θ 收集数据
       - 生成回复序列
       - 计算奖励
       - 计算优势估计
    
    2. 计算优势函数 A_t
       - 使用 GAE 估计
       - 需要价值函数 V(s)
    
    3. 多轮更新策略
       for epoch = 1, 2, ..., K do
           for minibatch in data do
               计算重要性采样权重 ρ_t
               计算 PPO 裁剪损失
               计算价值函数损失
               计算熵奖励
               反向传播更新参数
           end for
       end for
end for

完整目标函数 #

text
PPO 总损失:
────────────────────────
L_total = L_policy + c1 * L_value - c2 * L_entropy

其中:
├── L_policy:PPO 裁剪策略损失
│   L_policy = -E[min(ρ_t * A_t, clip(ρ_t, 1-ε, 1+ε) * A_t)]
│
├── L_value:价值函数损失
│   L_value = E[(V(s) - V_target)^2]
│
├── L_entropy:熵奖励(鼓励探索)
│   L_entropy = E[H(π(·|s))]
│
├── c1:价值损失系数,通常 0.5
└── c2:熵系数,通常 0.01

PPO 实现 #

核心代码 #

python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

class PPOModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.policy_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.value_head = nn.Linear(
            self.policy_model.config.hidden_size, 1
        )
    
    def forward(self, input_ids, attention_mask=None):
        outputs = self.policy_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        
        logits = outputs.logits
        hidden_states = outputs.hidden_states[-1]
        values = self.value_head(hidden_states).squeeze(-1)
        
        return logits, values
    
    def get_log_probs(self, logits, labels):
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)
    
    def generate(self, input_ids, attention_mask, **kwargs):
        return self.policy_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

PPO 训练器 #

python
class PPOTrainer:
    def __init__(
        self,
        policy_model,
        ref_model,
        reward_model,
        learning_rate=1e-5,
        clip_range=0.2,
        kl_coef=0.1,
        value_coef=0.5,
        entropy_coef=0.01,
        gamma=1.0,
        gae_lambda=0.95,
    ):
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.reward_model = reward_model
        
        self.clip_range = clip_range
        self.kl_coef = kl_coef
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        
        self.optimizer = torch.optim.Adam(
            policy_model.parameters(),
            lr=learning_rate
        )
    
    def compute_advantages(self, rewards, values, dones):
        advantages = []
        gae = 0
        
        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]
            
            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
            advantages.insert(0, gae)
        
        advantages = torch.tensor(advantages, device=rewards.device)
        returns = advantages + values
        
        return advantages, returns
    
    def compute_kl_penalty(self, log_probs, ref_log_probs):
        kl = log_probs - ref_log_probs
        return kl.sum(dim=-1).mean()
    
    def compute_policy_loss(self, log_probs, old_log_probs, advantages):
        ratio = torch.exp(log_probs - old_log_probs)
        
        surr1 = ratio * advantages
        surr2 = torch.clamp(
            ratio,
            1 - self.clip_range,
            1 + self.clip_range
        ) * advantages
        
        policy_loss = -torch.min(surr1, surr2).mean()
        return policy_loss
    
    def compute_value_loss(self, values, returns):
        return F.mse_loss(values, returns)
    
    def compute_entropy_loss(self, logits):
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)
        entropy = -(probs * log_probs).sum(dim=-1)
        return -entropy.mean()
    
    def train_step(self, batch):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        old_log_probs = batch["log_probs"]
        advantages = batch["advantages"]
        returns = batch["returns"]
        
        with torch.no_grad():
            ref_outputs = self.ref_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )
            ref_log_probs = self.policy_model.get_log_probs(
                ref_outputs.logits[:, :-1, :],
                input_ids[:, 1:]
            )
        
        logits, values = self.policy_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        log_probs = self.policy_model.get_log_probs(
            logits[:, :-1, :],
            input_ids[:, 1:]
        )
        values = values[:, :-1]
        
        policy_loss = self.compute_policy_loss(
            log_probs, old_log_probs, advantages
        )
        
        value_loss = self.compute_value_loss(values, returns)
        
        entropy_loss = self.compute_entropy_loss(logits[:, :-1, :])
        
        kl_penalty = self.compute_kl_penalty(log_probs, ref_log_probs)
        
        total_loss = (
            policy_loss
            + self.value_coef * value_loss
            - self.entropy_coef * entropy_loss
            + self.kl_coef * kl_penalty
        )
        
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.policy_model.parameters(), 1.0
        )
        self.optimizer.step()
        
        return {
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss.item(),
            "entropy_loss": entropy_loss.item(),
            "kl_penalty": kl_penalty.item(),
            "total_loss": total_loss.item(),
        }

数据收集 #

python
class PPODataCollector:
    def __init__(self, policy_model, reward_model, tokenizer, 
                 generation_kwargs=None):
        self.policy_model = policy_model
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        self.generation_kwargs = generation_kwargs or {
            "max_new_tokens": 256,
            "temperature": 1.0,
            "top_p": 0.9,
            "do_sample": True,
        }
    
    def collect_rollouts(self, prompts):
        rollouts = []
        
        for prompt in prompts:
            enc = self.tokenizer(
                prompt, return_tensors="pt", truncation=True
            )
            input_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]
            
            with torch.no_grad():
                output_ids = self.policy_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **self.generation_kwargs
                )
            
            full_input_ids = output_ids
            full_attention_mask = torch.ones_like(full_input_ids)
            
            with torch.no_grad():
                logits, values = self.policy_model(
                    input_ids=full_input_ids,
                    attention_mask=full_attention_mask
                )
                log_probs = self.policy_model.get_log_probs(
                    logits[:, :-1, :],
                    full_input_ids[:, 1:]
                )
                
                reward = self.reward_model(
                    full_input_ids,
                    full_attention_mask
                )
            
            rollout = {
                "input_ids": full_input_ids,
                "attention_mask": full_attention_mask,
                "log_probs": log_probs,
                "values": values[:, :-1],
                "reward": reward,
            }
            rollouts.append(rollout)
        
        return rollouts

超参数调优 #

关键超参数 #

text
┌─────────────────────────────────────────────────────────────┐
│                    PPO 关键超参数                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  学习率(learning_rate):                                   │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  典型值:1e-6 ~ 5e-6                                │   │
│  │  影响:控制更新步长                                  │   │
│  │  建议:从小值开始,监控 KL 散度                      │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
│  裁剪范围(clip_range):                                    │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  典型值:0.1 ~ 0.3                                  │   │
│  │  影响:限制策略变化幅度                              │   │
│  │  建议:0.2 是常用值,可根据稳定性调整                │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
│  KL 惩罚系数(kl_coef):                                    │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  典型值:0.01 ~ 0.1                                 │   │
│  │  影响:控制与参考模型的偏离程度                      │   │
│  │  建议:可使用自适应调整策略                          │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
│  PPO 更新轮数(ppo_epochs):                                │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  典型值:2 ~ 4                                      │   │
│  │  影响:数据复用次数                                  │   │
│  │  建议:过多会导致过拟合                              │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
│  批次大小(batch_size):                                    │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  典型值:64 ~ 512                                   │   │
│  │  影响:梯度估计方差                                  │   │
│  │  建议:越大越稳定,但需要更多显存                    │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

自适应 KL 调整 #

python
class AdaptiveKLController:
    def __init__(self, init_kl_coef=0.1, target_kl=6.0, 
                 kl_horizon=10000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.kl_horizon = kl_horizon
    
    def update(self, current_kl):
        if current_kl < self.target_kl / 1.5:
            self.kl_coef /= 2
        elif current_kl > self.target_kl * 1.5:
            self.kl_coef *= 2
        
        self.kl_coef = max(min(self.kl_coef, 10.0), 0.01)
        
        return self.kl_coef

学习率调度 #

python
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, 
                                     num_training_steps):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0,
            float(num_training_steps - current_step) / 
            float(max(1, num_training_steps - num_warmup_steps))
        )
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

训练监控 #

关键指标 #

python
class PPOMetrics:
    def __init__(self):
        self.metrics_history = {
            "policy_loss": [],
            "value_loss": [],
            "entropy": [],
            "kl_divergence": [],
            "reward": [],
            "clip_fraction": [],
        }
    
    def update(self, metrics_dict):
        for key, value in metrics_dict.items():
            if key in self.metrics_history:
                self.metrics_history[key].append(value)
    
    def compute_clip_fraction(self, ratios):
        return ((ratios < 1 - self.clip_range) | 
                (ratios > 1 + self.clip_range)).float().mean()
    
    def get_summary(self):
        summary = {}
        for key, values in self.metrics_history.items():
            if values:
                summary[key] = {
                    "mean": np.mean(values[-100:]),
                    "std": np.std(values[-100:]),
                    "min": np.min(values[-100:]),
                    "max": np.max(values[-100:]),
                }
        return summary

训练可视化 #

python
import wandb

class PPOTrainerWithLogging(PPOTrainer):
    def __init__(self, *args, use_wandb=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_wandb = use_wandb
        self.global_step = 0
    
    def train_step(self, batch):
        metrics = super().train_step(batch)
        
        if self.use_wandb:
            wandb.log({
                "train/policy_loss": metrics["policy_loss"],
                "train/value_loss": metrics["value_loss"],
                "train/entropy": -metrics["entropy_loss"],
                "train/kl_penalty": metrics["kl_penalty"],
                "train/total_loss": metrics["total_loss"],
                "train/learning_rate": self.optimizer.param_groups[0]["lr"],
            }, step=self.global_step)
        
        self.global_step += 1
        return metrics
    
    def log_evaluation(self, eval_metrics):
        if self.use_wandb:
            wandb.log({
                "eval/mean_reward": eval_metrics["mean_reward"],
                "eval/reward_std": eval_metrics["reward_std"],
                "eval/mean_length": eval_metrics["mean_length"],
            }, step=self.global_step)

常见问题与解决方案 #

训练不稳定 #

text
症状:
────────────────────────
├── 损失剧烈波动
├── KL 散度突然增大
├── 奖励下降
└── 模型输出质量变差

原因:
────────────────────────
├── 学习率过大
├── 裁剪范围设置不当
├── KL 约束不足
├── 奖励模型问题

解决方案:
────────────────────────
├── 降低学习率
├── 增大 KL 惩罚系数
├── 减小裁剪范围
├── 检查奖励模型质量
├── 增加批次大小
└── 使用梯度裁剪

奖励黑客 #

text
症状:
────────────────────────
├── 奖励很高但输出质量差
├── 模型学会欺骗奖励模型
└── 输出模式化、重复

解决方案:
────────────────────────
├── 增强 KL 约束
├── 定期更新奖励模型
├── 使用多个奖励模型集成
├── 人工审核高奖励输出
└── 添加额外约束

模型能力退化 #

text
症状:
────────────────────────
├── 语言能力下降
├── 知识遗忘
├── 输出不自然

原因:
────────────────────────
├── 过度优化
├── KL 约束过松
├── 训练轮数过多

解决方案:
────────────────────────
├── 增大 KL 约束
├── 减少训练轮数
├── 使用早停
├── 混合预训练数据
└── 多阶段训练

下一步 #

现在你已经掌握了 PPO 算法的原理和实现,接下来学习 训练流程,了解完整的 RLHF 训练流程!

最后更新:2026-04-05