进阶技巧 #

超参数调优 #

秩(Rank)选择 #

秩是 LoRA 最重要的超参数,直接影响模型的表达能力和训练效率。

text
┌─────────────────────────────────────────────────────────────┐
│                    秩选择指南                                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  任务复杂度 vs 秩选择:                                       │
│                                                             │
│  简单任务 (r=1-4):                                          │
│  ├── 文本分类                                               │
│  ├── 情感分析                                               │
│  └── 简单问答                                               │
│                                                             │
│  中等任务 (r=8-16):                                         │
│  ├── 指令微调                                               │
│  ├── 对话系统                                               │
│  └── 代码补全                                               │
│                                                             │
│  复杂任务 (r=32-64):                                        │
│  ├── 风格迁移                                               │
│  ├── 多任务学习                                             │
│  └── 领域深度适配                                           │
│                                                             │
│  极复杂任务 (r=128+):                                       │
│  ├── 新语言学习                                             │
│  ├── 复杂推理                                               │
│  └── 接近全参数效果                                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

秩搜索实验 #

python
from typing import Optional

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer

def rank_search_experiment(model_name, dataset_name, ranks=None):
    """Report how the LoRA rank changes the number of trainable parameters.

    For every rank the base model is loaded from scratch, wrapped with a LoRA
    adapter on ``q_proj``/``v_proj`` (``lora_alpha`` is fixed at ``2 * r``) and
    the trainable-parameter count and ratio are printed and recorded. No
    training is performed.

    Args:
        model_name: Identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.
        dataset_name: Currently unused; kept so existing callers keep working.
        ranks: Iterable of LoRA ranks to try. ``None`` (the default) means
            ``[4, 8, 16, 32, 64]``.

    Returns:
        A dict mapping each rank to
        ``{"trainable_params": int, "param_ratio": float}``, where
        ``param_ratio`` is expressed as a percentage.
    """
    # A list literal as default would be shared between calls, so use None.
    if ranks is None:
        ranks = [4, 8, 16, 32, 64]

    results = {}

    for r in ranks:
        print(f"\n{'='*50}")
        print(f"实验: rank = {r}")
        print(f"{'='*50}")

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        lora_config = LoraConfig(
            r=r,
            lora_alpha=2 * r,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
        )

        model = get_peft_model(model, lora_config)

        # Walk the parameters once each for the trainable and the total count.
        trainable_params = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        total_params = sum(p.numel() for p in model.parameters())

        results[r] = {
            "trainable_params": trainable_params,
            "param_ratio": trainable_params / total_params * 100,
        }

        print(f"可训练参数: {trainable_params:,}")
        print(f"参数比例: {results[r]['param_ratio']:.2f}%")

        # Drop the model before the next iteration loads another copy.
        del model
        torch.cuda.empty_cache()

    return results

# Run the parameter-count sweep for five ranks. The base model is loaded once
# per rank, and `dataset_name` is accepted but not used by the function.
results = rank_search_experiment(
    model_name="meta-llama/Llama-2-7b-hf",
    dataset_name="tatsu-lab/alpaca",
    ranks=[4, 8, 16, 32, 64]
)

Alpha 调优 #

python
def analyze_alpha_effect(r=8, alphas=(8, 16, 32, 64, 128), base_lr=2e-4):
    """Print how ``lora_alpha`` scales the update for a fixed rank.

    For each alpha the LoRA scaling factor ``alpha / r`` is printed together
    with ``scaling * base_lr``, which the table labels as the "effective"
    learning rate.

    Args:
        r: LoRA rank used as the divisor of the scaling factor.
        alphas: Iterable of integer ``lora_alpha`` values to tabulate.
        base_lr: Optimizer learning rate that the scaling is applied to.

    Raises:
        ValueError: If ``r`` is not positive.
    """
    if r <= 0:
        raise ValueError("r must be a positive integer")

    print("Alpha 对有效学习率的影响:")
    print("-" * 40)

    for alpha in alphas:
        scaling = alpha / r
        effective_lr = scaling * base_lr
        print(f"α={alpha:3d} | 缩放={scaling:5.1f} | 有效LR={effective_lr:.2e}")


analyze_alpha_effect()

目标模块选择 #

python
from peft import LoraConfig

# Named presets of module names to attach LoRA adapters to, ordered from the
# fewest to the most modules.
# NOTE(review): "standard" and "full_attention" list the same four modules,
# so the two presets are currently indistinguishable -- confirm which one was
# meant to differ.
# NOTE(review): the *_proj names presumably match LLaMA-style architectures;
# verify against the model actually being tuned.
target_configs = {
    "minimal": ["q_proj", "v_proj"],
    "standard": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "full_attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "full_model": ["q_proj", "k_proj", "v_proj", "o_proj", 
                   "gate_proj", "up_proj", "down_proj"],
}

# Build a LoraConfig for every preset (rank 8, alpha 16). The config object is
# only constructed here, not applied to a model; the loop just reports how many
# module names each preset contains.
for name, targets in target_configs.items():
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=targets,
        task_type="CAUSAL_LM",
    )
    print(f"{name}: {len(targets)} 个目标模块")

学习率调度 #

python
from transformers import get_cosine_schedule_with_warmup

def create_scheduler(optimizer, num_training_steps, warmup_ratio=0.1):
    """Create a cosine learning-rate schedule with warmup for ``optimizer``.

    Args:
        optimizer: Optimizer whose learning rate the schedule controls.
        num_training_steps: Total number of training steps.
        warmup_ratio: Fraction of ``num_training_steps`` spent warming up; the
            resulting step count is truncated to an integer.

    Returns:
        The scheduler produced by ``get_cosine_schedule_with_warmup``.
    """
    warmup_steps = int(num_training_steps * warmup_ratio)
    return get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )

# Equivalent setup through the Trainer API: cosine schedule, 10% warmup,
# peak learning rate 2e-4 and weight decay 0.01.
# NOTE(review): no output_dir is passed; some transformers releases require
# it -- confirm against the installed version.
training_args = TrainingArguments(
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
)

多任务学习 #

多 LoRA 管理 #

python
import os
from typing import Dict, List
from dataclasses import dataclass

@dataclass
class LoRAConfig:
    """Metadata record describing one registered LoRA adapter.

    This is a plain bookkeeping dataclass used by ``MultiLoRAManager``; it is
    distinct from ``peft.LoraConfig`` despite the similar name.
    """

    # Unique key under which the adapter is registered.
    name: str
    # Location of the adapter (presumably a directory of saved weights -- TODO confirm).
    path: str
    # Human-readable description of the task the adapter targets.
    task: str
    # LoRA rank of the adapter.
    rank: int
    # Names of the modules the adapter is attached to.
    target_modules: List[str]

class MultiLoRAManager:
    """In-memory registry of LoRA adapter descriptions for one base model.

    The manager only stores ``LoRAConfig`` records keyed by name; it does not
    load models or adapter weights itself.
    """

    def __init__(self, base_model_name: str):
        """Create an empty registry for ``base_model_name``."""
        self.base_model_name = base_model_name
        self.lora_configs: Dict[str, LoRAConfig] = {}
        # Placeholder for the active adapter; nothing in this class sets it yet.
        self.current_lora = None

    def register_lora(self, config: LoRAConfig):
        """Store ``config`` under ``config.name``, replacing any previous entry."""
        self.lora_configs[config.name] = config
        print(f"注册 LoRA: {config.name} (任务: {config.task})")

    def list_loras(self):
        """Print every registered adapter with its task and rank."""
        print("已注册的 LoRA:")
        for name, config in self.lora_configs.items():
            print(f"  - {name}: {config.task} (r={config.rank})")

    def get_lora_config(self, name: str) -> Optional[LoRAConfig]:
        """Return the config registered as ``name``, or ``None`` if unknown."""
        # dict.get yields None for a missing key, hence the Optional return type.
        return self.lora_configs.get(name)

# Example: create a registry for Llama-2 7B and register two adapters.
manager = MultiLoRAManager("meta-llama/Llama-2-7b-hf")

# Medical question-answering adapter: rank 16 on the query/value projections.
manager.register_lora(LoRAConfig(
    name="medical",
    path="./loras/medical",
    task="医疗问答",
    rank=16,
    target_modules=["q_proj", "v_proj"],
))

# Code-generation adapter: rank 32 on all four attention projections.
manager.register_lora(LoRAConfig(
    name="code",
    path="./loras/code",
    task="代码生成",
    rank=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
))

# Print the registry contents.
manager.list_loras()

多任务训练策略 #

python
from datasets import concatenate_datasets, DatasetDict

def prepare_multitask_dataset(
    task_datasets: Dict[str, "Dataset"],
    seed: int = 42,
) -> "Dataset":
    """Tag every example with its task name and merge all tasks into one dataset.

    Args:
        task_datasets: Mapping from task name to the dataset for that task.
        seed: Seed for the final shuffle; the default keeps the previous
            fixed value of 42.

    Returns:
        The concatenation of all datasets, each example carrying a ``"task"``
        field, shuffled with ``seed``.

    Raises:
        ValueError: If ``task_datasets`` is empty.
    """
    # "Dataset" is quoted so the annotation is not evaluated when the function
    # is defined.
    if not task_datasets:
        raise ValueError("task_datasets must contain at least one dataset")

    labelled = []

    for task_name, dataset in task_datasets.items():
        # Bind the current task name as a default argument so each mapper keeps
        # its own label instead of the loop variable's final value.
        def add_task_label(example, task=task_name):
            example["task"] = task
            return example

        labelled.append(dataset.map(add_task_label))

    return concatenate_datasets(labelled).shuffle(seed=seed)

# Load one training split per task and merge them.
# NOTE(review): "medical_qa", "legal_qa" and "alpaca" look like placeholder
# dataset identifiers -- replace them with real dataset names or local paths.
task_datasets = {
    "medical": load_dataset("medical_qa", split="train"),
    "legal": load_dataset("legal_qa", split="train"),
    "general": load_dataset("alpaca", split="train"),
}

# Combined, task-labelled and shuffled dataset.
multitask_data = prepare_multitask_dataset(task_datasets)

模型合并 #

LoRA 权重合并 #

python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

def merge_lora_weights(base_model, lora_model):
    """Copy identically named parameters from ``lora_model`` into ``base_model``.

    Every parameter of ``base_model`` that also exists under the same name in
    ``lora_model`` is overwritten in place with a copy of the latter's data;
    parameters without a counterpart are left untouched.

    NOTE(review): this only copies tensors whose names match. It does not fold
    low-rank ``B @ A`` updates into the base weights; for PEFT adapters
    ``PeftModel.merge_and_unload()`` is presumably what is wanted -- confirm.

    Args:
        base_model: Module whose parameters are updated in place.
        lora_model: Module providing the replacement parameter values.

    Returns:
        ``base_model`` (the same object), after the update.
    """
    for name, param in base_model.named_parameters():
        try:
            lora_param = lora_model.get_parameter(name)
        except AttributeError:
            # nn.Module.get_parameter raises for unknown names instead of
            # returning None, so a missing counterpart has to be caught here.
            continue
        if lora_param is None:
            continue
        param.data = lora_param.data.clone()

    return base_model

def merge_multiple_loras(base_model_name, lora_paths, weights=None):
    """Compute a weighted sum of the LoRA parameters of several adapters.

    Each adapter is loaded on top of the same base model and every parameter
    whose name contains ``"lora"`` is accumulated as ``sum(weight_i * param_i)``.
    The A and B factors are combined independently, so the result is an
    approximation of the weighted sum of the adapters' full updates.

    Args:
        base_model_name: Identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.
        lora_paths: Sequence of adapter locations for
            ``PeftModel.from_pretrained``.
        weights: One coefficient per adapter. ``None`` means a uniform average.

    Returns:
        A dict mapping LoRA parameter names to the combined tensors.

    Raises:
        ValueError: If ``lora_paths`` is empty or ``weights`` has a different
            length than ``lora_paths``.
    """
    # Validate before the (expensive) model load.
    lora_paths = list(lora_paths)
    if not lora_paths:
        raise ValueError("lora_paths must contain at least one adapter path")

    if weights is None:
        weights = [1.0 / len(lora_paths)] * len(lora_paths)
    else:
        weights = list(weights)
        if len(weights) != len(lora_paths):
            # zip() would silently drop the surplus adapters or weights.
            raise ValueError(
                f"expected {len(lora_paths)} weights, got {len(weights)}"
            )

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    merged_weights = {}

    for lora_path, weight in zip(lora_paths, weights):
        lora_model = PeftModel.from_pretrained(base_model, lora_path)

        for name, param in lora_model.named_parameters():
            if "lora" in name:
                contribution = param.data * weight
                if name in merged_weights:
                    merged_weights[name] += contribution
                else:
                    merged_weights[name] = contribution

        # Remove the adapter layers again so the next adapter is attached to a
        # clean base model rather than stacked on the previous one.
        base_model = lora_model.unload()

    return merged_weights

TIES-Merging #

python
import torch
import torch.nn as nn

def _keep_top_k_magnitudes(delta, top_k):
    """Zero all entries of ``delta`` except the ``top_k`` fraction largest in magnitude."""
    total = delta.numel()
    keep = int(total * top_k)
    if keep <= 0:
        return torch.zeros_like(delta)
    if keep >= total:
        return delta

    # Compare in float32 so half-precision inputs are handled as well.
    magnitudes = delta.abs().float()
    threshold = torch.topk(magnitudes.flatten(), keep, sorted=False).values.min()
    return torch.where(magnitudes >= threshold, delta, torch.zeros_like(delta))


def ties_merge(base_model, lora_models, top_k=0.2):
    """Merge fine-tuned models into the base weights with TIES-Merging.

    For every base parameter that also appears in a fine-tuned model's
    ``state_dict`` the three TIES steps are applied to the task vectors
    (fine-tuned minus base):

    1. Trim: per model, keep only the ``top_k`` fraction of entries with the
       largest magnitude.
    2. Elect sign: per entry, take the sign of the sum of the trimmed vectors.
    3. Disjoint merge: per entry, average only the trimmed values whose sign
       agrees with the elected sign.

    Args:
        base_model: Module providing the reference parameters.
        lora_models: Iterable of modules whose ``state_dict`` entries are
            compared against ``base_model`` by parameter name.
        top_k: Fraction in ``[0, 1]`` of entries kept per task vector.

    Returns:
        A dict mapping parameter names to merged tensors. Parameters that no
        fine-tuned model provides are omitted.

    Raises:
        ValueError: If ``top_k`` is outside ``[0, 1]``.
    """
    if not 0.0 <= top_k <= 1.0:
        raise ValueError("top_k must be within [0, 1]")

    # state_dict() walks the whole module tree, so build each one only once
    # instead of once per parameter.
    state_dicts = [model.state_dict() for model in lora_models]
    merged_state_dict = {}

    for name, base_param in base_model.named_parameters():
        # Detach so the merged tensors do not carry autograd history.
        reference = base_param.detach()

        trimmed = []
        for state in state_dicts:
            tuned = state.get(name)
            if tuned is not None:
                delta = tuned.to(reference.device) - reference
                trimmed.append(_keep_top_k_magnitudes(delta, top_k))

        if not trimmed:
            continue

        stacked = torch.stack(trimmed, dim=0)
        elected_sign = torch.sign(stacked.sum(dim=0))

        # Only non-zero entries that agree with the elected sign contribute.
        agrees = (torch.sign(stacked) == elected_sign) & (stacked != 0)
        agreeing_sum = torch.where(
            agrees, stacked, torch.zeros_like(stacked)
        ).sum(dim=0)
        # clamp avoids 0/0 where no model contributes; the sum is 0 there anyway.
        agreeing_count = agrees.sum(dim=0).clamp(min=1)

        merged_state_dict[name] = reference + agreeing_sum / agreeing_count

    return merged_state_dict

线性插值合并 #

python
def linear_interpolate_loras(lora_a_path, lora_b_path, alpha=0.5):
    """Blend two LoRA weight files tensor by tensor.

    Args:
        lora_a_path: Path of the first safetensors file.
        lora_b_path: Path of the second safetensors file.
        alpha: Weight given to the first file; the second receives
            ``1 - alpha``.

    Returns:
        A dict with one entry per key of the first file. Keys present in both
        files hold ``alpha * a + (1 - alpha) * b``; keys found only in the
        first file are passed through unchanged. Keys found only in the second
        file are not included.
    """
    from safetensors.torch import load_file

    tensors_a = load_file(lora_a_path)
    tensors_b = load_file(lora_b_path)

    def blend(key):
        # No counterpart in the second file: keep the first file's tensor.
        if key not in tensors_b:
            return tensors_a[key]
        return alpha * tensors_a[key] + (1 - alpha) * tensors_b[key]

    return {key: blend(key) for key in tensors_a}

权重融合技术 #

SVD 分解融合 #

python
import torch

def svd_merge_loras(lora_weights_list, target_rank=None):
    """Fuse several LoRA adapters into a single (optionally lower-rank) pair.

    The A matrices are concatenated along dim 0 and the B matrices along
    dim 1, so their product equals the sum of the individual ``B_i @ A_i``
    updates. That sum is factorised with an SVD and split evenly between the
    new factors via the square roots of the singular values.

    Args:
        lora_weights_list: Sequence of dicts, each holding ``"lora_A"`` and
            ``"lora_B"`` tensors.
        target_rank: Number of singular components to keep. ``None`` keeps
            all of them.

    Returns:
        ``{"lora_A": new_A, "lora_B": new_B}`` in the dtype of the inputs,
        with ``new_B @ new_A`` approximating the summed update.

    Raises:
        ValueError: If ``target_rank`` is negative.
    """
    if target_rank is not None and target_rank < 0:
        raise ValueError("target_rank must be non-negative")

    stacked_a = torch.cat([w["lora_A"] for w in lora_weights_list], dim=0)
    stacked_b = torch.cat([w["lora_B"] for w in lora_weights_list], dim=1)
    out_dtype = stacked_a.dtype

    # torch.linalg.svd only supports float/double (and complex) inputs, so
    # half-precision adapters are factorised in float32 and cast back.
    if out_dtype in (torch.float16, torch.bfloat16):
        stacked_a = stacked_a.float()
        stacked_b = stacked_b.float()

    combined = stacked_b @ stacked_a

    U, S, Vt = torch.linalg.svd(combined, full_matrices=False)

    if target_rank is not None:
        U = U[:, :target_rank]
        S = S[:target_rank]
        Vt = Vt[:target_rank, :]

    # Scaling rows/columns by broadcasting is equivalent to multiplying with
    # diag(sqrt(S)) but avoids building the dense diagonal matrix.
    sqrt_s = torch.sqrt(S)
    new_A = sqrt_s.unsqueeze(1) * Vt
    new_B = U * sqrt_s.unsqueeze(0)

    return {"lora_A": new_A.to(out_dtype), "lora_B": new_B.to(out_dtype)}

任务算术 #

python
def task_arithmetic(base_model, task_vectors, scaling_factor=1.0):
    """Add scaled task vectors to a copy of the base model's parameters.

    Args:
        base_model: Module whose parameters serve as the starting point; it is
            not modified.
        task_vectors: Mapping from task name to a dict of per-parameter deltas
            keyed by parameter name.
        scaling_factor: Multiplier applied to every delta before it is added.

    Returns:
        A dict mapping each parameter name of ``base_model`` to the base value
        plus the scaled deltas of every task that provides that name.
    """
    merged = {}

    for name, param in base_model.named_parameters():
        # Work on a clone so the base model's own tensor stays untouched.
        accumulated = param.data.clone()

        for vector in task_vectors.values():
            if name in vector:
                accumulated += scaling_factor * vector[name]

        merged[name] = accumulated

    return merged

高级训练技巧 #

梯度检查点 #

python
from transformers import TrainingArguments

# Enable gradient checkpointing and request the non-reentrant implementation
# through the kwargs that are forwarded to it.
training_args = TrainingArguments(
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

混合精度训练 #

python
from transformers import TrainingArguments

# Train in fp16 mixed precision with bf16 explicitly disabled.
# NOTE(review): fp16_opt_level presumably only takes effect with the Apex AMP
# backend -- confirm against the TrainingArguments documentation.
training_args = TrainingArguments(
    fp16=True,
    fp16_opt_level="O1",
    bf16=False,
)

DeepSpeed 集成 #

python
# DeepSpeed configuration for use with the Hugging Face Trainer. Values set to
# "auto" are left for the Trainer to fill in from its own TrainingArguments.
ds_config = {
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto",
        },
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            # WarmupDecayLR needs the total step count to know where the decay
            # ends; without this entry the scheduler cannot be constructed.
            "total_num_steps": "auto",
        },
    },
    "fp16": {
        "enabled": "auto",
        # A loss_scale of 0 selects dynamic loss scaling.
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1,
    },
    # ZeRO stage 2 with the optimizer state offloaded to pinned CPU memory.
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,
        },
    },
}

# Hand the in-memory DeepSpeed config dict to the Trainer.
training_args = TrainingArguments(
    deepspeed=ds_config,
)

性能优化 #

显存优化 #

python
def optimize_memory_usage(memory_fraction=0.9, device=0):
    """Free unused memory and apply process-wide CUDA memory/precision settings.

    Args:
        memory_fraction: Upper bound, as a fraction of the device's memory,
            that this process may allocate on ``device``.
        device: CUDA device the memory limit applies to.

    The CUDA-specific calls are skipped when no CUDA device is available, so
    the function is safe to call on CPU-only hosts. The TF32 flags are set
    unconditionally.
    """
    import gc

    import torch

    # Collect first: tensors that are only kept alive by reference cycles must
    # be freed before the caching allocator can release their blocks.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.set_per_process_memory_fraction(memory_fraction, device)

    # Allow TF32 for matmuls and cuDNN ops.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


optimize_memory_usage()

训练速度优化 #

python
from transformers import TrainingArguments

# Throughput-oriented Trainer settings.
training_args = TrainingArguments(
    # Data loading: four worker processes, pinned memory, prefetch factor 2.
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2,
    
    # Compile the model with torch.compile using the "inductor" backend.
    torch_compile=True,
    torch_compile_backend="inductor",
    
    # Per-device batch size 4 with 4 gradient-accumulation steps.
    gradient_accumulation_steps=4,
    per_device_train_batch_size=4,
)

下一步 #

现在你已经掌握了 LoRA 的进阶技巧。接下来请学习「最佳实践」,了解实战经验和常见问题的解决方案!

最后更新:2026-04-05