Few-shot Learning Best Practices #

Data Preparation #

Data Collection Strategy #

text
Data collection strategy:

1. Training data
   ├── Needs many distinct classes
   ├── Per-class sample counts may be large
   ├── Cover as much variation as possible
   └── Ensure high data quality

2. Validation data
   ├── Classes disjoint from the training classes
   ├── Used for hyperparameter tuning
   └── Guards against overfitting

3. Test data
   ├── Completely unseen classes
   ├── Measures generalization
   └── Simulates real-world deployment

Data Quality Requirements #

text
Data quality checklist:

1. Annotation accuracy
   ├── Annotation consistency checks
   ├── Multi-annotator verification
   └── Annotation quality assessment

2. Sample representativeness
   ├── Cover intra-class variation
   ├── Avoid biased samples
   └── Balance the class distribution

3. Data cleaning
   ├── Remove noisy samples
   ├── Handle missing values
   └── Detect outliers

4. Data augmentation
   ├── Augment in moderation
   ├── Preserve semantic consistency
   └── Avoid over-augmentation
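The annotation-consistency item in the checklist can be made concrete with a simple inter-annotator agreement check; a minimal sketch (the `agreement_rate` helper and the toy labels below are illustrative, not from the original):

```python
def agreement_rate(labels_a, labels_b):
    """Fraction of items on which two annotators agree."""
    assert len(labels_a) == len(labels_b)
    matches = sum(a == b for a, b in zip(labels_a, labels_b))
    return matches / len(labels_a)

# Two annotators labeling the same 8 items
ann_a = ["cat", "dog", "cat", "bird", "dog", "cat", "bird", "dog"]
ann_b = ["cat", "dog", "dog", "bird", "dog", "cat", "cat", "dog"]

rate = agreement_rate(ann_a, ann_b)
print(f"Agreement: {rate:.2f}")  # items 3 and 7 disagree -> 6/8 = 0.75
```

Low agreement flags label noise; multi-annotator verification then resolves the disagreements before training.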

Data Splitting Best Practices #

python
import numpy as np

class FewShotDataSplitter:
    def __init__(self, num_train_classes=64, num_val_classes=16, num_test_classes=20):
        self.num_train_classes = num_train_classes
        self.num_val_classes = num_val_classes
        self.num_test_classes = num_test_classes
    
    def split(self, data, labels):
        unique_classes = np.unique(labels)
        np.random.shuffle(unique_classes)
        
        train_classes = unique_classes[:self.num_train_classes]
        val_classes = unique_classes[
            self.num_train_classes:
            self.num_train_classes + self.num_val_classes
        ]
        test_classes = unique_classes[
            self.num_train_classes + self.num_val_classes:
        ]
        
        train_data = {cls: data[labels == cls] for cls in train_classes}
        val_data = {cls: data[labels == cls] for cls in val_classes}
        test_data = {cls: data[labels == cls] for cls in test_classes}
        
        return train_data, val_data, test_data
    
    def verify_split(self, train_data, val_data, test_data):
        train_classes = set(train_data.keys())
        val_classes = set(val_data.keys())
        test_classes = set(test_data.keys())
        
        assert len(train_classes & val_classes) == 0, "train/val classes overlap"
        assert len(train_classes & test_classes) == 0, "train/test classes overlap"
        assert len(val_classes & test_classes) == 0, "val/test classes overlap"
        
        print("Data split verified")
        print(f"Train classes: {len(train_classes)}")
        print(f"Val classes: {len(val_classes)}")
        print(f"Test classes: {len(test_classes)}")
Model Selection #

Method Selection Guide #

text
Method selection decision tree:

1. Shots available
   ├── 1-shot: metric-learning methods
   │   ├── Prototypical Networks
   │   └── Matching Networks
   └── 5-shot+: meta-learning methods
       ├── MAML
       └── Meta-SGD

2. Compute budget
   ├── Limited: Reptile, Prototypical Networks
   └── Ample: MAML, Meta-SGD

3. Domain similarity
   ├── Similar: metric-learning methods
   └── Different: meta-learning methods + domain adaptation

4. Latency requirements
   ├── Strict: metric-learning methods
   └── Relaxed: meta-learning methods
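Since Prototypical Networks appear in both the 1-shot and limited-compute branches, here is a minimal nearest-prototype classifier in NumPy (a sketch of the core idea, assuming features have already been embedded; the function name and toy episode are illustrative):

```python
import numpy as np

def prototype_classify(support_feats, support_labels, query_feats):
    """Assign each query to the class whose prototype (mean support
    embedding) is nearest in Euclidean distance."""
    classes = np.unique(support_labels)
    protos = np.stack([support_feats[support_labels == c].mean(axis=0)
                       for c in classes])
    # (num_query, num_classes) distance matrix via broadcasting
    dists = np.linalg.norm(query_feats[:, None, :] - protos[None, :, :], axis=2)
    return classes[dists.argmin(axis=1)]

# Toy 2-way 2-shot episode in a 2-d feature space
support = np.array([[0.0, 0.0], [0.2, 0.0], [5.0, 5.0], [5.2, 5.0]])
labels = np.array([0, 0, 1, 1])
queries = np.array([[0.1, 0.1], [4.9, 5.1]])
print(prototype_classify(support, labels, queries))  # [0 1]
```

Inference is a single forward pass plus a distance computation, which is why metric-learning methods sit under the strict-latency branch.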

Backbone Selection #

text
Backbone selection suggestions:

1. Image tasks
   ├── ResNet-18/34: small datasets, fast experiments
   ├── ResNet-50/101: medium datasets, balanced performance
   ├── EfficientNet: chasing best accuracy
   └── Vision Transformer: large-scale pretraining

2. Text tasks
   ├── BERT-base: general-purpose text tasks
   ├── RoBERTa: stronger pretraining
   └── DeBERTa: state-of-the-art performance

3. Audio tasks
   ├── VGGish: general audio features
   ├── PANNs: pretrained audio networks
   └── Wav2Vec 2.0: self-supervised pretraining

Selection principles:
- Prefer pretrained models
- Account for compute budget and inference speed
- Balance performance against complexity

Training Techniques #

Episode Sampling Strategy #

python
import numpy as np

class EpisodeSampler:
    def __init__(self, data, N, K, num_query, balance=True):
        self.data = data
        self.N = N
        self.K = K
        self.num_query = num_query
        self.balance = balance
    
    def sample(self):
        classes = list(self.data.keys())
        
        if self.balance:
            selected_classes = np.random.choice(classes, self.N, replace=False)
        else:
            selected_classes = np.random.choice(
                classes, 
                self.N, 
                replace=False,
                p=self.get_class_weights()
            )
        
        support_set = []
        query_set = []
        
        for i, cls in enumerate(selected_classes):
            samples = self.data[cls]
            
            if len(samples) < self.K + self.num_query:
                # Class too small: sample with replacement (note: this can
                # place the same item in both the support and query sets)
                indices = np.random.choice(
                    len(samples), 
                    self.K + self.num_query, 
                    replace=True
                )
            else:
                indices = np.random.permutation(len(samples))
            
            support_indices = indices[:self.K]
            query_indices = indices[self.K:self.K + self.num_query]
            
            support_set.extend([(samples[idx], i) for idx in support_indices])
            query_set.extend([(samples[idx], i) for idx in query_indices])
        
        return support_set, query_set
    
    def get_class_weights(self):
        class_counts = np.array([len(self.data[cls]) for cls in self.data.keys()])
        weights = 1.0 / class_counts
        weights = weights / weights.sum()
        return weights

sampler = EpisodeSampler(data, N=5, K=5, num_query=15, balance=True)

for episode in range(num_episodes):
    support_set, query_set = sampler.sample()

Learning Rate Scheduling #

python
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

def get_optimizer_and_scheduler(model, config):
    optimizer = optim.Adam(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )
    
    if config['scheduler'] == 'cosine':
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=config['num_epochs'],
            eta_min=config['lr'] * 0.01
        )
    elif config['scheduler'] == 'step':
        scheduler = StepLR(
            optimizer,
            step_size=config['step_size'],
            gamma=config['gamma']
        )
    else:
        scheduler = None
    
    return optimizer, scheduler

optimizer, scheduler = get_optimizer_and_scheduler(model, config)

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, train_loader)
    
    if scheduler is not None:
        scheduler.step()
    
    val_acc = validate(model, val_loader)
    print(f"Epoch {epoch}, Val Acc: {val_acc:.4f}")
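The cosine schedule above follows a closed-form curve; a NumPy sketch of the formula CosineAnnealingLR implements (the constants mirror the config keys above and are illustrative):

```python
import numpy as np

def cosine_lr(epoch, num_epochs, base_lr, eta_min):
    """lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T)) / 2"""
    return eta_min + (base_lr - eta_min) * (1 + np.cos(np.pi * epoch / num_epochs)) / 2

base_lr, eta_min, T = 1e-3, 1e-5, 100
print(cosine_lr(0, T, base_lr, eta_min))       # starts at base_lr
print(cosine_lr(T, T, base_lr, eta_min))       # decays to eta_min
print(cosine_lr(T // 2, T, base_lr, eta_min))  # midpoint: (base_lr + eta_min) / 2
```

The smooth decay avoids the abrupt drops of step schedules, which can destabilize episodic training.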

Regularization Techniques #

python
import torch
import torch.nn as nn

class RegularizedFewShotModel(nn.Module):
    def __init__(self, encoder, classifier, dropout=0.5):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier
        self.dropout = nn.Dropout(dropout)
        
        self.label_smoothing = 0.1
    
    def forward(self, support_data, support_labels, query_data):
        support_features = self.encoder(support_data)
        query_features = self.encoder(query_data)
        
        support_features = self.dropout(support_features)
        query_features = self.dropout(query_features)
        
        logits = self.classifier(support_features, support_labels, query_features)
        
        return logits
    
    def compute_loss(self, logits, labels):
        loss = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)(
            logits, labels
        )
        
        # Manual L2 penalty (sum of parameter norms); skip this if the
        # optimizer already applies weight_decay, or it is counted twice
        l2_reg = 0.0
        for param in self.parameters():
            l2_reg += torch.norm(param, p=2)
        
        total_loss = loss + 0.0001 * l2_reg
        
        return total_loss
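Label smoothing, used in compute_loss above, replaces the one-hot target with a softened distribution; a NumPy sketch of the target transformation (eps matches the 0.1 used above; the helper name is illustrative):

```python
import numpy as np

def smooth_targets(labels, num_classes, eps=0.1):
    """One-hot targets with eps probability mass spread uniformly."""
    onehot = np.eye(num_classes)[labels]
    return onehot * (1 - eps) + eps / num_classes

targets = smooth_targets(np.array([0, 2]), num_classes=5, eps=0.1)
print(targets[0])  # [0.92 0.02 0.02 0.02 0.02]
```

Keeping some mass on the wrong classes discourages overconfident logits, which matters when each episode offers only a handful of support examples.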

Hyperparameter Tuning #

Key Hyperparameters #

text
Key hyperparameters:

1. Episode-related
   ├── N (way): typically 5 or 10
   ├── K (shot): 1, 5, or as the application dictates
   ├── num_query: 15-30
   └── num_episodes: 1000-10000

2. Optimization-related
   ├── learning_rate: 1e-4 to 1e-3
   ├── weight_decay: 1e-5 to 1e-3
   ├── batch_size: 1-4 episodes
   └── num_epochs: 50-200

3. Model-related
   ├── feature_dim: 512-2048
   ├── hidden_dim: 256-1024
   └── dropout: 0.1-0.5

4. Meta-learning-related
   ├── inner_lr: 0.01-0.1
   ├── outer_lr: 0.001-0.01
   └── num_inner_steps: 1-10
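The inner_lr / outer_lr / num_inner_steps triple above is easiest to see in a toy inner-adaptation loop; a first-order sketch on 1-d quadratic tasks (the task family, seed, and all constants are illustrative, not from any real benchmark):

```python
import numpy as np

def inner_adapt(theta, target, inner_lr=0.05, num_inner_steps=5):
    """A few gradient steps on the task loss (theta - target)^2."""
    for _ in range(num_inner_steps):
        theta = theta - inner_lr * 2 * (theta - target)
    return theta

# Meta-train: each task is "regress to target t", targets drawn around 3.0
rng = np.random.default_rng(0)
theta = 0.0
outer_lr = 0.01
for _ in range(2000):
    t = 3.0 + rng.normal(scale=0.5)
    adapted = inner_adapt(theta, t)
    # First-order meta-gradient: gradient of the post-adaptation loss,
    # treating the adapted parameters as if they were theta itself
    theta = theta - outer_lr * 2 * (adapted - t)

print(theta)  # close to the task-family mean of 3.0
```

The meta-learned initialization sits where a few inner steps reach any sampled task quickly, which is exactly the role these three hyperparameters trade off.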

Hyperparameter Search #

python
import optuna

def objective(trial):
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    dropout = trial.suggest_float('dropout', 0.1, 0.5)
    feature_dim = trial.suggest_categorical('feature_dim', [256, 512, 1024])
    
    model = build_model(feature_dim, dropout)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    best_val_acc = 0.0
    
    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, train_loader)
        val_acc = validate(model, val_loader)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
    
    return best_val_acc

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

print(f"Best params: {study.best_params}")
print(f"Best accuracy: {study.best_value}")
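When Optuna is unavailable, the same search space can be explored with plain random search; a stdlib-only sketch (the `evaluate` stub stands in for the train/validate loop above and its peak location is made up for illustration):

```python
import math
import random

random.seed(0)

def sample_config():
    """Draw one configuration from the search space used above."""
    return {
        'lr': 10 ** random.uniform(-5, -2),           # log-uniform 1e-5..1e-2
        'weight_decay': 10 ** random.uniform(-6, -2),
        'dropout': random.uniform(0.1, 0.5),
        'feature_dim': random.choice([256, 512, 1024]),
    }

def evaluate(cfg):
    """Stand-in for training + validation; peaks near lr=1e-3, dropout=0.3."""
    return -abs(math.log10(cfg['lr']) + 3) - abs(cfg['dropout'] - 0.3)

best = max((sample_config() for _ in range(100)), key=evaluate)
print(best['feature_dim'] in (256, 512, 1024))  # True
```

Random search is a reasonable baseline before reaching for Bayesian optimization, and it parallelizes trivially.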

Performance Optimization #

Data Loading Optimization #

python
from torch.utils.data import DataLoader
from torchvision import transforms

class FastDataLoader:
    def __init__(self, dataset, batch_size, num_workers=8, pin_memory=True):
        # Note: pass this transform into the dataset itself; defining it
        # here without wiring it up means it is never applied
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(84),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                std=[0.229, 0.224, 0.225])
        ])
        
        self.loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            prefetch_factor=2,
            persistent_workers=True
        )
    
    def __iter__(self):
        return iter(self.loader)
    
    def __len__(self):
        return len(self.loader)

Mixed-Precision Training #

python
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, optimizer, train_loader, num_epochs):
    scaler = GradScaler()
    
    for epoch in range(num_epochs):
        for batch in train_loader:
            images, labels = batch
            
            with autocast():
                outputs = model(images)
                loss = nn.CrossEntropyLoss()(outputs, labels)
            
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Model Compression #

python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def prune_model(model, amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=amount)
        elif isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)
    
    return model

def quantize_model(model):
    model.eval()
    
    # Dynamic quantization only covers weight-heavy op types such as
    # nn.Linear (and RNNs); Conv2d layers need static quantization instead
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8
    )
    
    return quantized_model

pruned_model = prune_model(model, amount=0.3)
quantized_model = quantize_model(model)
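The effect of the L1-magnitude pruning above can be checked without torch; a NumPy sketch that zeroes the smallest 30% of weights by magnitude and reports the resulting sparsity (helper names are illustrative):

```python
import numpy as np

def magnitude_prune(weights, amount=0.3):
    """Zero out the `amount` fraction of entries smallest in |w|."""
    flat = np.abs(weights).ravel()
    k = int(len(flat) * amount)
    threshold = np.partition(flat, k)[k] if k > 0 else -np.inf
    return np.where(np.abs(weights) < threshold, 0.0, weights)

def sparsity(weights):
    """Fraction of exactly-zero entries."""
    return float((weights == 0).mean())

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64))
pruned = magnitude_prune(w, amount=0.3)
print(f"sparsity: {sparsity(pruned):.2f}")  # roughly 0.30
```

Measuring sparsity before and after is a cheap sanity check that the pruning amount actually took effect.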

Evaluation and Validation #

Evaluation Metrics #

python
import numpy as np

class FewShotEvaluator:
    def __init__(self, model, num_episodes=600):
        self.model = model
        self.num_episodes = num_episodes
    
    def evaluate(self, test_data, N, K):
        accuracies = []
        
        for episode in range(self.num_episodes):
            support_set, query_set = self.sample_episode(test_data, N, K)
            
            predictions = self.model.predict(support_set, query_set)
            
            accuracy = self.compute_accuracy(predictions, query_set)
            accuracies.append(accuracy)
        
        mean_acc = np.mean(accuracies)
        std_acc = np.std(accuracies)
        ci_95 = 1.96 * std_acc / np.sqrt(self.num_episodes)
        
        return {
            'mean_accuracy': mean_acc,
            'std_accuracy': std_acc,
            'ci_95': ci_95
        }
    
    def compute_accuracy(self, predictions, query_set):
        correct = 0
        for pred, (_, label) in zip(predictions, query_set):
            if pred == label:
                correct += 1
        return correct / len(query_set)
    
    def sample_episode(self, data, N, K):
        # Placeholder: plug in an episode sampler (e.g. EpisodeSampler above)
        pass

evaluator = FewShotEvaluator(model, num_episodes=600)

results_5way_1shot = evaluator.evaluate(test_data, N=5, K=1)
results_5way_5shot = evaluator.evaluate(test_data, N=5, K=5)

print(f"5-way 1-shot: {results_5way_1shot['mean_accuracy']:.2f} ± {results_5way_1shot['ci_95']:.2f}")
print(f"5-way 5-shot: {results_5way_5shot['mean_accuracy']:.2f} ± {results_5way_5shot['ci_95']:.2f}")
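The 95% confidence interval reported above is the standard 1.96·σ/√n normal approximation over episodes; a quick NumPy sanity check on synthetic per-episode accuracies (the distribution and its parameters are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
# 600 synthetic per-episode accuracies centered around 0.7
accuracies = np.clip(rng.normal(loc=0.7, scale=0.1, size=600), 0.0, 1.0)

mean_acc = accuracies.mean()
ci_95 = 1.96 * accuracies.std() / np.sqrt(len(accuracies))
print(f"{mean_acc:.3f} ± {ci_95:.3f}")  # interval width shrinks as 1/sqrt(n)
```

This is why 600 episodes is a common evaluation budget: the interval becomes tight enough to separate methods a point or two apart.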

Ablation Studies #

python
def ablation_study(model, test_data, components):
    results = {}
    
    for component in components:
        print(f"Testing component: {component}")
        
        # modify_model is assumed to return a copy of the model with the
        # given component disabled or swapped out
        modified_model = modify_model(model, component)
        
        evaluator = FewShotEvaluator(modified_model)
        result = evaluator.evaluate(test_data, N=5, K=1)
        
        results[component] = result
    
    return results

components = [
    'without_attention',
    'without_data_augmentation',
    'without_pretraining',
    'with_different_backbone'
]

ablation_results = ablation_study(model, test_data, components)

for component, result in ablation_results.items():
    print(f"{component}: {result['mean_accuracy']:.2f}")

Deployment Best Practices #

Model Export #

python
import torch

def export_model(model, save_path):
    model.eval()
    
    dummy_input = torch.randn(1, 3, 84, 84)
    
    traced_model = torch.jit.trace(model, dummy_input)
    
    traced_model.save(save_path)
    
    print(f"Model exported to: {save_path}")

export_model(model, 'fewshot_model.pt')

def export_onnx(model, save_path):
    model.eval()
    
    dummy_input = torch.randn(1, 3, 84, 84)
    
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    
    print(f"ONNX model exported to: {save_path}")

export_onnx(model, 'fewshot_model.onnx')

Inference Optimization #

python
import torch

class OptimizedInference:
    def __init__(self, model_path):
        self.model = torch.jit.load(model_path)
        self.model.eval()
        
        if torch.cuda.is_available():
            self.model = self.model.cuda()
    
    @torch.no_grad()
    def predict(self, support_images, query_images):
        if torch.cuda.is_available():
            support_images = support_images.cuda()
            query_images = query_images.cuda()
        
        with torch.cuda.amp.autocast():
            predictions = self.model(support_images, query_images)
        
        return predictions.cpu()
    
    def batch_predict(self, support_images, query_images, batch_size=32):
        predictions = []
        
        for i in range(0, len(query_images), batch_size):
            batch_query = query_images[i:i+batch_size]
            batch_pred = self.predict(support_images, batch_query)
            predictions.append(batch_pred)
        
        return torch.cat(predictions, dim=0)

Common Problems and Solutions #

Overfitting #

text
Problem: the model performs well on the training set but poorly on the test set

Solutions:
1. Stronger regularization
   ├── Raise Dropout
   ├── Increase L2 regularization
   └── Use label smoothing

2. Data augmentation
   ├── Increase augmentation strength
   ├── Use more augmentation methods
   └── Mixup/CutMix

3. Early stopping
   ├── Monitor validation performance
   ├── Stop training promptly
   └── Save the best checkpoint

4. Reduce model complexity
   ├── Fewer parameters
   ├── Simpler architectures
   └── Lower feature dimensionality
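Mixup, listed under augmentation above, blends pairs of examples and their labels; a NumPy sketch (the mixing coefficient is usually drawn from Beta(alpha, alpha); alpha=0.2 here is a common but illustrative choice):

```python
import numpy as np

def mixup(x1, y1, x2, y2, lam):
    """Convex combination of two examples and their one-hot labels."""
    return lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2

rng = np.random.default_rng(0)
lam = float(rng.beta(0.2, 0.2))  # mixing coefficient ~ Beta(alpha, alpha)

x_a, y_a = np.ones(4), np.array([1.0, 0.0])
x_b, y_b = np.zeros(4), np.array([0.0, 1.0])
x_mix, y_mix = mixup(x_a, y_a, x_b, y_b, lam)

print(f"label mass: {y_mix.sum():.6f}")  # stays 1.000000
```

Because the mixed label stays a valid distribution, the model is trained against soft targets, which combats the memorization that drives few-shot overfitting.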

Class Imbalance #

python
import numpy as np

class BalancedSampler:
    def __init__(self, data, N, K):
        self.data = data
        self.N = N
        self.K = K
    
    def sample(self):
        classes = list(self.data.keys())
        
        class_counts = [len(self.data[cls]) for cls in classes]
        weights = 1.0 / np.array(class_counts)
        weights = weights / weights.sum()
        
        selected_classes = np.random.choice(
            classes, 
            self.N, 
            replace=False, 
            p=weights
        )
        
        support_set = []
        for i, cls in enumerate(selected_classes):
            samples = self.data[cls]
            indices = np.random.choice(len(samples), self.K, replace=False)
            support_set.extend([(samples[idx], i) for idx in indices])
        
        return support_set

Domain Shift #

text
Problem: large gap between the training domain and the test domain

Solutions:
1. Domain adaptation
   ├── DANN (Domain Adversarial Neural Network)
   ├── MMD (Maximum Mean Discrepancy)
   └── CORAL (Correlation Alignment)

2. Domain generalization
   ├── Multi-domain training
   ├── Domain-specific Batch Normalization
   └── Meta-learning for domain generalization

3. Data augmentation
   ├── Style transfer
   ├── Domain randomization
   └── Synthetic data

4. Feature alignment
   ├── Align feature distributions
   ├── Adversarial training
   └── Self-supervised pretraining
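MMD, listed above, quantifies the distance between two feature distributions; a NumPy sketch of the (biased) estimator with an RBF kernel (the bandwidth `gamma` and the toy samples are illustrative):

```python
import numpy as np

def mmd_rbf(X, Y, gamma=1.0):
    """Squared MMD between samples X and Y under an RBF kernel."""
    def kernel(A, B):
        sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=2)
        return np.exp(-gamma * sq)
    return kernel(X, X).mean() + kernel(Y, Y).mean() - 2 * kernel(X, Y).mean()

rng = np.random.default_rng(0)
same = mmd_rbf(rng.normal(size=(200, 2)), rng.normal(size=(200, 2)))
shifted = mmd_rbf(rng.normal(size=(200, 2)), rng.normal(loc=3.0, size=(200, 2)))
print(same < shifted)  # True: shifted domains score a larger discrepancy
```

Adding this quantity to the training loss pulls source and target feature distributions together, which is how MMD-based domain adaptation works.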

Summary #

Few-shot Learning is a powerful tool for tackling data scarcity. By following these best practices, you can:

  1. Prepare data effectively: ensure data quality and a sound split
  2. Choose the right method: match the method to the task's characteristics
  3. Optimize training: apply the right techniques and hyperparameters
  4. Improve performance: lift results with the optimization techniques above
  5. Deploy successfully: ship the model to production

Remember, success with Few-shot Learning takes both theoretical knowledge and hands-on experience. Continuous learning and experimentation are the keys to mastering it!

Last updated: 2026-04-05