Best Practices #
Recommended Configurations #
General Fine-Tuning Configuration #
text
┌─────────────────────────────────────────────────────────────┐
│ Recommended Configuration Templates                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│ Standard instruction tuning:                                │
│ ├── rank: 8-16                                              │
│ ├── alpha: 16-32                                            │
│ ├── target_modules: ["q_proj", "v_proj"]                    │
│ ├── dropout: 0.05                                           │
│ ├── learning_rate: 2e-4                                     │
│ └── epochs: 3-5                                             │
│                                                             │
│ Complex-task tuning:                                        │
│ ├── rank: 32-64                                             │
│ ├── alpha: 64-128                                           │
│ ├── target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]│
│ ├── dropout: 0.1                                            │
│ ├── learning_rate: 1e-4                                     │
│ └── epochs: 5-10                                            │
│                                                             │
│ QLoRA:                                                      │
│ ├── rank: 64                                                │
│ ├── alpha: 16                                               │
│ ├── quantization: 4-bit NF4                                 │
│ ├── target_modules: all linear layers                       │
│ └── learning_rate: 2e-4                                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Python Configuration Templates #
python
from peft import LoraConfig
from transformers import TrainingArguments

# Standard instruction tuning: low rank, attention Q/V projections only.
STANDARD_CONFIG = {
    "lora": LoraConfig(
        r=8,
        lora_alpha=16,  # effective scale = alpha / r = 2
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    ),
    "training": TrainingArguments(
        output_dir="./lora-standard",  # required by most transformers versions
        learning_rate=2e-4,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # effective batch size = 4 * 4 = 16
        warmup_ratio=0.1,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
    ),
}

# Complex tasks: higher rank, all four attention projections.
COMPLEX_CONFIG = {
    "lora": LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    ),
    "training": TrainingArguments(
        output_dir="./lora-complex",
        learning_rate=1e-4,
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        warmup_ratio=0.05,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
    ),
}
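The QLoRA row in the table above has no matching template, so here is a minimal sketch. It assumes peft >= 0.8 (for the "all-linear" shortcut) and bitsandbytes installed; the compute dtype is a common choice, not a requirement.
python
import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig

QLORA_CONFIG = {
    # 4-bit NF4 quantization of the frozen base weights.
    # Pass this as quantization_config to AutoModelForCausalLM.from_pretrained.
    "quantization": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: Ampere+ GPU
    ),
    "lora": LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules="all-linear",  # all linear layers; requires peft >= 0.8
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    ),
}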
Data Preparation Best Practices #
Data Quality #
text
Data quality principles:
├── Accuracy: labels must be correct
├── Consistency: uniform format and style
├── Diversity: cover varied scenarios and edge cases
└── Relevance: closely matched to the target task
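The structural parts of these principles can be checked automatically. Below is a minimal sketch, assuming records shaped like {"instruction": ..., "input": ..., "output": ...}; the length thresholds and helper names are illustrative.
python
def is_valid_record(record, min_len=1, max_len=4096):
    """Cheap structural checks; semantic accuracy still needs human review."""
    instruction = record.get("instruction", "").strip()
    output = record.get("output", "").strip()
    if not instruction or not output:
        return False  # drop records with empty fields
    if not (min_len <= len(output) <= max_len):
        return False  # drop truncated or runaway outputs
    return True

def clean_dataset(data):
    seen, cleaned = set(), []
    for record in data:
        key = (record.get("instruction", ""), record.get("input", ""))
        if key in seen:
            continue  # drop exact duplicates (consistency)
        seen.add(key)
        if is_valid_record(record):
            cleaned.append(record)
    return cleaned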
Data Format #
python
def format_instruction_data(instruction, input_text, output):
    # Alpaca-style prompt; the Input block is omitted when there is no context.
    if input_text:
        return f"""### Instruction:
{instruction}
### Input:
{input_text}
### Response:
{output}"""
    else:
        return f"""### Instruction:
{instruction}
### Response:
{output}"""

def format_chat_data(messages):
    # Minimal chat template; prefer tokenizer.apply_chat_template when available.
    formatted = ""
    for message in messages:
        role = message["role"]
        content = message["content"]
        if role == "user":
            formatted += f"<|user|>\n{content}\n"
        elif role == "assistant":
            formatted += f"<|assistant|>\n{content}\n"  # fixed: closing ">" was missing
    return formatted
Data Augmentation #
python
import random

def augment_instruction_data(data, num_augments=2):
    # Cheap paraphrase augmentation: wrap each instruction in prompt templates.
    augmented = []
    paraphrase_templates = [
        "Please {instruction}",
        "Help me {instruction}",
        "You need to {instruction}",
    ]
    for item in data:
        augmented.append(item)  # always keep the original example
        for _ in range(num_augments):
            new_item = item.copy()
            template = random.choice(paraphrase_templates)
            new_item["instruction"] = template.format(instruction=item["instruction"])
            augmented.append(new_item)
    return augmented
Training Best Practices #
Learning Rate Selection #
text
Learning rate guidelines (see the sketch below):
├── Conservative: 1e-4 ~ 2e-4
├── Aggressive: 5e-4 ~ 1e-3 (monitor closely)
├── Small models: can go a bit higher
└── Large models: should go lower
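These guidelines can be folded into a starting-point heuristic. The size brackets and return values below are illustrative assumptions, not values from the original; always confirm with a short trial run.
python
def suggest_learning_rate(num_params_b):
    """Heuristic LoRA starting LR by model size in billions of parameters."""
    if num_params_b <= 3:
        return 3e-4  # small models tolerate a slightly higher LR
    elif num_params_b <= 13:
        return 2e-4  # the common default range
    else:
        return 1e-4  # large models: start low and watch the loss curve

print(suggest_learning_rate(7))  # -> 0.0002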
Training Monitoring #
python
from transformers import Trainer, TrainerCallback

class LoRAMonitorCallback(TrainerCallback):
    """Logs training loss and stops early when it plateaus."""

    def __init__(self):
        self.best_loss = float("inf")
        self.patience = 3  # log events without improvement before stopping
        self.patience_counter = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "loss" in logs:
            current_loss = logs["loss"]
            if current_loss < self.best_loss:
                self.best_loss = current_loss
                self.patience_counter = 0
            else:
                self.patience_counter += 1
            print(f"Step {state.global_step}: loss={current_loss:.4f}, best={self.best_loss:.4f}")
            if self.patience_counter >= self.patience:
                print(f"Early stopping: loss hasn't improved for {self.patience} logs")
                control.should_training_stop = True

# model, training_args, and dataset are assumed to be defined earlier.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    callbacks=[LoRAMonitorCallback()],
)
Gradient Clipping #
python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    max_grad_norm=1.0,  # clip gradient norm to 1.0 to damp unstable updates
)
Troubleshooting Common Issues #
Issue 1: Out of GPU Memory #
text
Remedies:
├── Lower batch_size
├── Increase gradient_accumulation_steps
├── Enable gradient_checkpointing
├── Use QLoRA (4-bit quantization)
├── Target fewer modules
└── Use a smaller rank
python
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # keep the effective batch size at 16
    gradient_checkpointing=True,     # trade recompute for activation memory
    fp16=True,                       # or bf16=True on Ampere+ GPUs
    optim="adamw_torch_fused",
)
Issue 2: Training Does Not Converge #
text
Remedies:
├── Lower the learning rate
├── Increase warmup_steps
├── Check data quality
├── Adjust the alpha/r ratio (LoRA scales its update by alpha/r, so a smaller ratio gives gentler updates)
├── Raise dropout
└── Shrink the batch size
python
training_args = TrainingArguments(
    output_dir="./output",
    learning_rate=5e-5,   # ~4x lower than the 2e-4 default
    warmup_ratio=0.1,     # longer warmup stabilizes early steps
    weight_decay=0.01,
    lr_scheduler_type="cosine",
)
Issue 3: Overfitting #
text
Remedies:
├── Raise dropout
├── Train for fewer epochs
├── Add more data
├── Use data augmentation
├── Shrink the rank
└── Apply early stopping (see the sketch after the snippet below)
python
lora_config = LoraConfig(
    r=8,               # smaller rank = fewer trainable parameters
    lora_dropout=0.1,  # stronger regularization
)
training_args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=2,
    weight_decay=0.1,
)
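For the early-stopping item above, transformers ships a ready-made EarlyStoppingCallback. A minimal sketch follows; it requires periodic evaluation plus load_best_model_at_end, and the dataset variables are placeholders.
python
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    eval_strategy="steps",        # named evaluation_strategy in older versions
    eval_steps=200,
    load_best_model_at_end=True,  # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,   # placeholder: your training split
    eval_dataset=eval_dataset,     # placeholder: your held-out split
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)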
Issue 4: Catastrophic Forgetting #
text
Remedies:
├── Lower the learning rate
├── Train for fewer epochs
├── Use a smaller rank
├── Add regularization
└── Mix in data from the original distribution
python
import random

def mix_with_original_data(new_data, original_data, ratio=0.1):
    # Replay a small sample of the original distribution to reduce forgetting.
    mixed = new_data.copy()
    sample_size = int(len(new_data) * ratio)
    sampled_original = random.sample(original_data, sample_size)
    mixed.extend(sampled_original)
    random.shuffle(mixed)
    return mixed
Issue 5: Slow Inference #
text
Remedies:
├── Merge the LoRA weights into the base model
├── Use quantized inference
├── Batch your requests
├── Serve with vLLM/TGI (see the sketch below)
└── Upgrade the hardware
python
from peft import PeftModel

# Merging folds the LoRA matrices into the base weights,
# removing the adapter overhead at inference time.
model = PeftModel.from_pretrained(base_model, lora_path)
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./merged-model")
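For the vLLM item above, a minimal serving sketch, assuming the merged checkpoint from the previous snippet (with its tokenizer files also saved) lives in ./merged-model:
python
from vllm import LLM, SamplingParams

llm = LLM(model="./merged-model")  # loads the merged checkpoint
params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(["Explain LoRA in one sentence."], params)
print(outputs[0].outputs[0].text)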
Performance Optimization Tips #
Memory Optimization #
python
import gc
import torch

def optimize_gpu_memory():
    # Release cached allocations and unreferenced tensors.
    torch.cuda.empty_cache()
    gc.collect()
    # Cap this process at 95% of the GPU so the system keeps some headroom.
    torch.cuda.set_per_process_memory_fraction(0.95)
    # TF32 and cuDNN autotuning trade a little precision for speed.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

optimize_gpu_memory()
Training Speed Optimization #
python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    dataloader_num_workers=4,     # parallel data loading
    dataloader_pin_memory=True,   # faster host-to-GPU transfers
    dataloader_prefetch_factor=2,
    torch_compile=True,           # PyTorch 2.x graph compilation
    gradient_accumulation_steps=4,
    per_device_train_batch_size=4,
)
Inference Optimization #
python
import torch

def optimize_inference(model):
    model.eval()  # disable dropout and other training-only behavior
    torch.cuda.empty_cache()
    # Merge LoRA weights if this is still a PEFT-wrapped model.
    if hasattr(model, "merge_and_unload"):
        model = model.merge_and_unload()
    return model

def batch_inference(model, tokenizer, prompts, batch_size=8):
    results = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True).to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=256)
        results.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    return results
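One caveat with the batching above: many causal LM tokenizers (the Llama family, for instance) define no pad token, and right padding corrupts generation. A common setup, stated here as an assumption about your tokenizer rather than a universal rule:
python
# Reuse EOS as the pad token if none is defined, and pad on the left
# so every prompt ends exactly where generation begins.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"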
Evaluation Best Practices #
Evaluation Metrics #
python
import evaluate  # datasets.load_metric is deprecated; use the evaluate library

def evaluate_lora_model(predictions, references, model_id="gpt2"):
    # Takes generated texts and reference texts directly; the perplexity
    # metric wants a scorer model id (a string), not the model object itself.
    perplexity = evaluate.load("perplexity", module_type="metric")
    bleu = evaluate.load("bleu")
    rouge = evaluate.load("rouge")
    return {
        "perplexity": perplexity.compute(model_id=model_id, predictions=predictions)["mean_perplexity"],
        "bleu": bleu.compute(predictions=predictions, references=references)["bleu"],
        "rougeL": rouge.compute(predictions=predictions, references=references)["rougeL"],
    }
Human Evaluation #
text
Human evaluation dimensions:
├── Accuracy: is the answer correct?
├── Relevance: does it address the question?
├── Fluency: does the language read naturally?
├── Completeness: does it cover the whole question?
└── Safety: is the content safe?
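These dimensions are easiest to aggregate when raters score on a fixed scale. A minimal record format follows; the 1-5 scale and the field names are illustrative assumptions.
python
from dataclasses import dataclass, asdict

@dataclass
class HumanEvalRecord:
    """One rater's scores for one model response, each on a 1-5 scale (assumed)."""
    response_id: str
    accuracy: int
    relevance: int
    fluency: int
    completeness: int
    safety: int

record = HumanEvalRecord("sample-001", accuracy=5, relevance=4,
                         fluency=5, completeness=4, safety=5)
print(asdict(record))  # ready to dump as JSON for aggregation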
Version Management #
LoRA Version Control #
python
import json
from datetime import datetime

def save_lora_with_metadata(model, path, config, metrics):
    model.save_pretrained(path)
    metadata = {
        "timestamp": datetime.now().isoformat(),
        "config": {
            "r": config.r,
            "lora_alpha": config.lora_alpha,
            # peft may store target_modules as a set, which json can't encode
            "target_modules": sorted(config.target_modules),
            "lora_dropout": config.lora_dropout,
        },
        "metrics": metrics,
        "version": "1.0",
    }
    with open(f"{path}/metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

def load_lora_metadata(path):
    with open(f"{path}/metadata.json", "r") as f:
        return json.load(f)
Experiment Tracking #
python
import wandb

def setup_experiment_tracking(lora_config, training_args):
    # Rank/alpha live on LoraConfig; learning rate and epochs on TrainingArguments,
    # so the function takes both rather than one mixed config object.
    wandb.init(
        project="lora-experiments",
        config={
            "rank": lora_config.r,
            "alpha": lora_config.lora_alpha,
            "target_modules": sorted(lora_config.target_modules),
            "learning_rate": training_args.learning_rate,
            "epochs": training_args.num_train_epochs,
        },
    )

def log_metrics(metrics, step):
    wandb.log(metrics, step=step)
Next Steps #
Now that you have LoRA best practices under your belt, continue to Variant Methods to explore advanced variants such as QLoRA and AdaLoRA!
Last updated: 2026-04-05