PPO 算法 #
PPO 概述 #
PPO(Proximal Policy Optimization,近端策略优化)是 OpenAI 于 2017 年提出的一种强化学习算法,因其简单高效、稳定性好而成为 RLHF 中最常用的策略优化算法。
为什么选择 PPO? #
text
┌─────────────────────────────────────────────────────────────┐
│ PPO 的优势 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 相比传统策略梯度方法: │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ ✅ 训练更稳定 │ │
│ │ ✅ 样本效率更高 │ │
│ │ ✅ 超参数更少 │ │
│ │ ✅ 实现更简单 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 相比 TRPO: │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ ✅ 无需复杂的二阶优化 │ │
│ │ ✅ 计算效率更高 │ │
│ │ ✅ 更容易实现 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 在 RLHF 中的优势: │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ ✅ 适合大规模模型 │ │
│ │ ✅ 训练过程可控 │ │
│ │ ✅ 与 KL 约束天然结合 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
PPO 核心原理 #
策略梯度回顾 #
text
策略梯度定理:
────────────────────────
∇J(θ) = E[∇log π_θ(a|s) * A(s,a)]
其中 A(s,a) 是优势函数,表示动作 a 比平均好多少
问题:
────────────────────────
├── 更新步长难以选择
├── 步长太大:策略剧烈变化,性能下降
├── 步长太小:学习太慢
└── 需要一种自动控制步长的方法
重要性采样 #
PPO 使用重要性采样来复用旧策略收集的数据:
text
重要性采样权重:
────────────────────────
ρ_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
重要性采样目标:
────────────────────────
L(θ) = E[ρ_t(θ) * A_t]
问题:当 θ 与 θ_old 差距较大时,估计方差大
PPO 裁剪目标 #
PPO 的核心创新是裁剪目标函数,限制策略更新幅度:
text
PPO-Clip 目标函数:
────────────────────────
L^CLIP(θ) = E[min(
ρ_t(θ) * A_t,
clip(ρ_t(θ), 1-ε, 1+ε) * A_t
)]
其中:
├── ρ_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
├── ε 是裁剪参数,通常 0.1 ~ 0.3
└── A_t 是优势估计
直观理解:
────────────────────────
├── 当 A_t > 0(好动作):
│ 限制 ρ_t 最大为 1+ε,防止过度增加概率
│
├── 当 A_t < 0(坏动作):
│ 限制 ρ_t 最小为 1-ε,防止过度降低概率
│
└── 效果:策略更新被限制在"信任区域"内
PPO 目标函数图解 #
text
┌─────────────────────────────────────────────────────────────┐
│ PPO 裁剪目标函数 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 当 A_t > 0 时(好动作): │
│ │
│ L │ │
│ ↑ │ ┌─────────── │
│ │ / │
│ │ / │
│ │ / │
│ 0 ────┼──────┼─────────────────→ ρ │
│ 0 1-ε 1+ε │
│ │
│ 当 A_t < 0 时(坏动作): │
│ │
│ L │ │
│ ↑ │ ─────────┐ │
│ │ \ │
│ │ \ │
│ │ \ │
│ 0 ────┼──────────────┼─────────→ ρ │
│ 0 1-ε 1+ε │
│ │
│ 裁剪效果:限制策略变化幅度,保证训练稳定性 │
│ │
└─────────────────────────────────────────────────────────────┘
PPO 完整算法 #
算法流程 #
text
PPO 算法流程:
────────────────────────
for iteration = 1, 2, ... do
1. 使用当前策略 π_θ 收集数据
- 生成回复序列
- 计算奖励
- 计算优势估计
2. 计算优势函数 A_t
- 使用 GAE 估计
- 需要价值函数 V(s)
3. 多轮更新策略
for epoch = 1, 2, ..., K do
for minibatch in data do
计算重要性采样权重 ρ_t
计算 PPO 裁剪损失
计算价值函数损失
计算熵奖励
反向传播更新参数
end for
end for
end for
完整目标函数 #
text
PPO 总损失:
────────────────────────
L_total = L_policy + c1 * L_value - c2 * L_entropy
其中:
├── L_policy:PPO 裁剪策略损失
│ L_policy = -E[min(ρ_t * A_t, clip(ρ_t, 1-ε, 1+ε) * A_t)]
│
├── L_value:价值函数损失
│ L_value = E[(V(s) - V_target)^2]
│
├── L_entropy:熵奖励(鼓励探索)
│ L_entropy = E[H(π(·|s))]
│
├── c1:价值损失系数,通常 0.5
└── c2:熵系数,通常 0.01
PPO 实现 #
核心代码 #
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
class PPOModel(nn.Module):
def __init__(self, model_name):
super().__init__()
self.policy_model = AutoModelForCausalLM.from_pretrained(model_name)
self.value_head = nn.Linear(
self.policy_model.config.hidden_size, 1
)
def forward(self, input_ids, attention_mask=None):
outputs = self.policy_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
logits = outputs.logits
hidden_states = outputs.hidden_states[-1]
values = self.value_head(hidden_states).squeeze(-1)
return logits, values
def get_log_probs(self, logits, labels):
log_probs = F.log_softmax(logits, dim=-1)
return log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)
def generate(self, input_ids, attention_mask, **kwargs):
return self.policy_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs
)
PPO 训练器 #
python
class PPOTrainer:
def __init__(
self,
policy_model,
ref_model,
reward_model,
learning_rate=1e-5,
clip_range=0.2,
kl_coef=0.1,
value_coef=0.5,
entropy_coef=0.01,
gamma=1.0,
gae_lambda=0.95,
):
self.policy_model = policy_model
self.ref_model = ref_model
self.reward_model = reward_model
self.clip_range = clip_range
self.kl_coef = kl_coef
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.gamma = gamma
self.gae_lambda = gae_lambda
self.optimizer = torch.optim.Adam(
policy_model.parameters(),
lr=learning_rate
)
def compute_advantages(self, rewards, values, dones):
advantages = []
gae = 0
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_value = 0
else:
next_value = values[t + 1]
delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
advantages.insert(0, gae)
advantages = torch.tensor(advantages, device=rewards.device)
returns = advantages + values
return advantages, returns
def compute_kl_penalty(self, log_probs, ref_log_probs):
kl = log_probs - ref_log_probs
return kl.sum(dim=-1).mean()
def compute_policy_loss(self, log_probs, old_log_probs, advantages):
ratio = torch.exp(log_probs - old_log_probs)
surr1 = ratio * advantages
surr2 = torch.clamp(
ratio,
1 - self.clip_range,
1 + self.clip_range
) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
return policy_loss
def compute_value_loss(self, values, returns):
return F.mse_loss(values, returns)
def compute_entropy_loss(self, logits):
probs = F.softmax(logits, dim=-1)
log_probs = F.log_softmax(logits, dim=-1)
entropy = -(probs * log_probs).sum(dim=-1)
return -entropy.mean()
def train_step(self, batch):
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
old_log_probs = batch["log_probs"]
advantages = batch["advantages"]
returns = batch["returns"]
with torch.no_grad():
ref_outputs = self.ref_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
ref_log_probs = self.policy_model.get_log_probs(
ref_outputs.logits[:, :-1, :],
input_ids[:, 1:]
)
logits, values = self.policy_model(
input_ids=input_ids,
attention_mask=attention_mask
)
log_probs = self.policy_model.get_log_probs(
logits[:, :-1, :],
input_ids[:, 1:]
)
values = values[:, :-1]
policy_loss = self.compute_policy_loss(
log_probs, old_log_probs, advantages
)
value_loss = self.compute_value_loss(values, returns)
entropy_loss = self.compute_entropy_loss(logits[:, :-1, :])
kl_penalty = self.compute_kl_penalty(log_probs, ref_log_probs)
total_loss = (
policy_loss
+ self.value_coef * value_loss
- self.entropy_coef * entropy_loss
+ self.kl_coef * kl_penalty
)
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(
self.policy_model.parameters(), 1.0
)
self.optimizer.step()
return {
"policy_loss": policy_loss.item(),
"value_loss": value_loss.item(),
"entropy_loss": entropy_loss.item(),
"kl_penalty": kl_penalty.item(),
"total_loss": total_loss.item(),
}
数据收集 #
python
class PPODataCollector:
def __init__(self, policy_model, reward_model, tokenizer,
generation_kwargs=None):
self.policy_model = policy_model
self.reward_model = reward_model
self.tokenizer = tokenizer
self.generation_kwargs = generation_kwargs or {
"max_new_tokens": 256,
"temperature": 1.0,
"top_p": 0.9,
"do_sample": True,
}
def collect_rollouts(self, prompts):
rollouts = []
for prompt in prompts:
enc = self.tokenizer(
prompt, return_tensors="pt", truncation=True
)
input_ids = enc["input_ids"]
attention_mask = enc["attention_mask"]
with torch.no_grad():
output_ids = self.policy_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
**self.generation_kwargs
)
full_input_ids = output_ids
full_attention_mask = torch.ones_like(full_input_ids)
with torch.no_grad():
logits, values = self.policy_model(
input_ids=full_input_ids,
attention_mask=full_attention_mask
)
log_probs = self.policy_model.get_log_probs(
logits[:, :-1, :],
full_input_ids[:, 1:]
)
reward = self.reward_model(
full_input_ids,
full_attention_mask
)
rollout = {
"input_ids": full_input_ids,
"attention_mask": full_attention_mask,
"log_probs": log_probs,
"values": values[:, :-1],
"reward": reward,
}
rollouts.append(rollout)
return rollouts
超参数调优 #
关键超参数 #
text
┌─────────────────────────────────────────────────────────────┐
│ PPO 关键超参数 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 学习率(learning_rate): │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 典型值:1e-6 ~ 5e-6 │ │
│ │ 影响:控制更新步长 │ │
│ │ 建议:从小值开始,监控 KL 散度 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 裁剪范围(clip_range): │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 典型值:0.1 ~ 0.3 │ │
│ │ 影响:限制策略变化幅度 │ │
│ │ 建议:0.2 是常用值,可根据稳定性调整 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ KL 惩罚系数(kl_coef): │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 典型值:0.01 ~ 0.1 │ │
│ │ 影响:控制与参考模型的偏离程度 │ │
│ │ 建议:可使用自适应调整策略 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ PPO 更新轮数(ppo_epochs): │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 典型值:2 ~ 4 │ │
│ │ 影响:数据复用次数 │ │
│ │ 建议:过多会导致过拟合 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 批次大小(batch_size): │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 典型值:64 ~ 512 │ │
│ │ 影响:梯度估计方差 │ │
│ │ 建议:越大越稳定,但需要更多显存 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
自适应 KL 调整 #
python
class AdaptiveKLController:
def __init__(self, init_kl_coef=0.1, target_kl=6.0,
kl_horizon=10000):
self.kl_coef = init_kl_coef
self.target_kl = target_kl
self.kl_horizon = kl_horizon
def update(self, current_kl):
if current_kl < self.target_kl / 1.5:
self.kl_coef /= 2
elif current_kl > self.target_kl * 1.5:
self.kl_coef *= 2
self.kl_coef = max(min(self.kl_coef, 10.0), 0.01)
return self.kl_coef
学习率调度 #
python
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
num_training_steps):
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
0.0,
float(num_training_steps - current_step) /
float(max(1, num_training_steps - num_warmup_steps))
)
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
训练监控 #
关键指标 #
python
class PPOMetrics:
def __init__(self):
self.metrics_history = {
"policy_loss": [],
"value_loss": [],
"entropy": [],
"kl_divergence": [],
"reward": [],
"clip_fraction": [],
}
def update(self, metrics_dict):
for key, value in metrics_dict.items():
if key in self.metrics_history:
self.metrics_history[key].append(value)
def compute_clip_fraction(self, ratios):
return ((ratios < 1 - self.clip_range) |
(ratios > 1 + self.clip_range)).float().mean()
def get_summary(self):
summary = {}
for key, values in self.metrics_history.items():
if values:
summary[key] = {
"mean": np.mean(values[-100:]),
"std": np.std(values[-100:]),
"min": np.min(values[-100:]),
"max": np.max(values[-100:]),
}
return summary
训练可视化 #
python
import wandb
class PPOTrainerWithLogging(PPOTrainer):
def __init__(self, *args, use_wandb=True, **kwargs):
super().__init__(*args, **kwargs)
self.use_wandb = use_wandb
self.global_step = 0
def train_step(self, batch):
metrics = super().train_step(batch)
if self.use_wandb:
wandb.log({
"train/policy_loss": metrics["policy_loss"],
"train/value_loss": metrics["value_loss"],
"train/entropy": -metrics["entropy_loss"],
"train/kl_penalty": metrics["kl_penalty"],
"train/total_loss": metrics["total_loss"],
"train/learning_rate": self.optimizer.param_groups[0]["lr"],
}, step=self.global_step)
self.global_step += 1
return metrics
def log_evaluation(self, eval_metrics):
if self.use_wandb:
wandb.log({
"eval/mean_reward": eval_metrics["mean_reward"],
"eval/reward_std": eval_metrics["reward_std"],
"eval/mean_length": eval_metrics["mean_length"],
}, step=self.global_step)
常见问题与解决方案 #
训练不稳定 #
text
症状:
────────────────────────
├── 损失剧烈波动
├── KL 散度突然增大
├── 奖励下降
└── 模型输出质量变差
原因:
────────────────────────
├── 学习率过大
├── 裁剪范围设置不当
├── KL 约束不足
├── 奖励模型问题
解决方案:
────────────────────────
├── 降低学习率
├── 增大 KL 惩罚系数
├── 减小裁剪范围
├── 检查奖励模型质量
├── 增加批次大小
└── 使用梯度裁剪
奖励黑客 #
text
症状:
────────────────────────
├── 奖励很高但输出质量差
├── 模型学会欺骗奖励模型
└── 输出模式化、重复
解决方案:
────────────────────────
├── 增强 KL 约束
├── 定期更新奖励模型
├── 使用多个奖励模型集成
├── 人工审核高奖励输出
└── 添加额外约束
模型能力退化 #
text
症状:
────────────────────────
├── 语言能力下降
├── 知识遗忘
├── 输出不自然
原因:
────────────────────────
├── 过度优化
├── KL 约束过松
├── 训练轮数过多
解决方案:
────────────────────────
├── 增大 KL 约束
├── 减少训练轮数
├── 使用早停
├── 混合预训练数据
└── 多阶段训练
下一步 #
现在你已经掌握了 PPO 算法的原理和实现,接下来学习 训练流程,了解完整的 RLHF 训练流程!
最后更新:2026-04-05