Few-shot Learning 最佳实践 #
数据准备 #
数据收集策略 #
text
┌─────────────────────────────────────────────────────────────┐
│ 数据收集策略 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 训练数据 │
│ ├── 需要大量不同类别 │
│ ├── 每个类别样本可以较多 │
│ ├── 覆盖尽可能多的变化 │
│ └── 保证数据质量 │
│ │
│ 2. 验证数据 │
│ ├── 使用不同于训练的类别 │
│ ├── 用于超参数调优 │
│ └── 防止过拟合 │
│ │
│ 3. 测试数据 │
│ ├── 完全未见的类别 │
│ ├── 评估泛化能力 │
│ └── 模拟真实场景 │
│ │
└─────────────────────────────────────────────────────────────┘
数据质量要求 #
text
数据质量检查清单:
1. 标注准确性
├── 标注一致性检查
├── 多人标注验证
└── 标注质量评估
2. 样本代表性
├── 覆盖类内变化
├── 避免偏差样本
└── 平衡类别分布
3. 数据清洗
├── 去除噪声样本
├── 处理缺失值
└── 异常值检测
4. 数据增强
├── 适度的数据增强
├── 保持语义一致性
└── 避免过度增强
数据划分最佳实践 #
python
import numpy as np
from sklearn.model_selection import StratifiedKFold
class FewShotDataSplitter:
    """Split a dataset into disjoint train/val/test *class* pools.

    Few-shot evaluation requires that the three splits share no classes,
    so the split is done over class labels, not over individual samples.
    The 64/16/20 defaults follow the standard miniImageNet protocol.
    """

    def __init__(self, num_train_classes=64, num_val_classes=16, num_test_classes=20):
        self.num_train_classes = num_train_classes
        self.num_val_classes = num_val_classes
        self.num_test_classes = num_test_classes

    def split(self, data, labels):
        """Partition `data` by class into three {class: samples} dicts.

        Args:
            data: array-like of samples, indexable by a boolean mask.
            labels: per-sample class labels, same length as `data`.

        Returns:
            (train_data, val_data, test_data) dicts mapping class -> samples.

        Raises:
            ValueError: if there are fewer unique classes than requested.
        """
        unique_classes = np.unique(labels)
        required = (self.num_train_classes + self.num_val_classes
                    + self.num_test_classes)
        # Fail loudly instead of silently producing short splits.
        if len(unique_classes) < required:
            raise ValueError(
                f"need at least {required} classes, got {len(unique_classes)}"
            )
        np.random.shuffle(unique_classes)  # random class-to-split assignment
        train_classes = unique_classes[:self.num_train_classes]
        val_classes = unique_classes[
            self.num_train_classes:
            self.num_train_classes + self.num_val_classes
        ]
        # Bound the test slice so num_test_classes is honored (the original
        # took the whole remaining tail, ignoring the parameter).
        test_classes = unique_classes[
            self.num_train_classes + self.num_val_classes:
            self.num_train_classes + self.num_val_classes + self.num_test_classes
        ]
        train_data = {cls: data[labels == cls] for cls in train_classes}
        val_data = {cls: data[labels == cls] for cls in val_classes}
        test_data = {cls: data[labels == cls] for cls in test_classes}
        return train_data, val_data, test_data

    def verify_split(self, train_data, val_data, test_data):
        """Assert that the three splits have pairwise-disjoint classes."""
        train_classes = set(train_data.keys())
        val_classes = set(val_data.keys())
        test_classes = set(test_data.keys())
        assert len(train_classes & val_classes) == 0, "训练和验证类别有重叠"
        assert len(train_classes & test_classes) == 0, "训练和测试类别有重叠"
        assert len(val_classes & test_classes) == 0, "验证和测试类别有重叠"
        print("数据划分验证通过")
        print(f"训练类别: {len(train_classes)}")
        print(f"验证类别: {len(val_classes)}")
        print(f"测试类别: {len(test_classes)}")
模型选择 #
方法选择指南 #
text
┌─────────────────────────────────────────────────────────────┐
│ 方法选择决策树 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 样本数量 │
│ ├── 1-shot: 度量学习方法 │
│ │ ├── Prototypical Networks │
│ │ └── Matching Networks │
│ └── 5-shot+: 元学习方法 │
│ ├── MAML │
│ └── Meta-SGD │
│ │
│ 2. 计算资源 │
│ ├── 有限: Reptile, Prototypical Networks │
│ └── 充足: MAML, Meta-SGD │
│ │
│ 3. 领域相似性 │
│ ├── 相似: 度量学习方法 │
│ └── 不同: 元学习方法 + 域适应 │
│ │
│ 4. 实时性要求 │
│ ├── 高: 度量学习方法 │
│ └── 低: 元学习方法 │
│ │
└─────────────────────────────────────────────────────────────┘
骨干网络选择 #
python
骨干网络选择建议:
1. 图像任务
├── ResNet-18/34: 小数据集,快速实验
├── ResNet-50/101: 中等数据集,平衡性能
├── EfficientNet: 追求最佳性能
└── Vision Transformer: 大规模预训练
2. 文本任务
├── BERT-base: 通用文本任务
├── RoBERTa: 更好的预训练
└── DeBERTa: 最先进性能
3. 音频任务
├── VGGish: 通用音频特征
├── PANNs: 预训练音频网络
└── Wav2Vec 2.0: 自监督预训练
选择原则:
- 优先使用预训练模型
- 考虑计算资源和推理速度
- 平衡性能和复杂度
训练技巧 #
Episode 采样策略 #
python
class EpisodeSampler:
    """Sample N-way K-shot episodes (support + query sets) from a dataset.

    Args:
        data: dict mapping class -> indexable collection of samples.
        N: number of classes per episode (way).
        K: number of support samples per class (shot).
        num_query: number of query samples per class.
        balance: if True, classes are drawn uniformly; otherwise rare
            classes are over-sampled via inverse-frequency weights.
    """

    def __init__(self, data, N, K, num_query, balance=True):
        self.data = data
        self.N = N
        self.K = K
        self.num_query = num_query
        self.balance = balance

    def sample(self):
        """Return (support_set, query_set) lists of (sample, episode_label) pairs.

        Episode labels are 0..N-1 (the position of the class within this
        episode), not the original class ids.
        """
        classes = list(self.data.keys())
        if self.balance:
            selected_classes = np.random.choice(classes, self.N, replace=False)
        else:
            selected_classes = np.random.choice(
                classes,
                self.N,
                replace=False,
                p=self.get_class_weights()
            )
        support_set = []
        query_set = []
        for i, cls in enumerate(selected_classes):
            samples = self.data[cls]
            n = len(samples)
            if n >= self.K + self.num_query:
                # Enough distinct samples: support and query are disjoint.
                order = np.random.permutation(n)
                support_indices = order[:self.K]
                query_indices = order[self.K:self.K + self.num_query]
            else:
                # Scarce class. The original drew *all* K+num_query indices
                # with replacement, which could place the same sample in both
                # the support and the query set (evaluation leakage). Keep the
                # support unique whenever possible and draw the query from the
                # remaining samples first.
                if n >= self.K:
                    support_indices = np.random.choice(n, self.K, replace=False)
                else:
                    support_indices = np.random.choice(n, self.K, replace=True)
                rest = np.setdiff1d(np.arange(n), support_indices)
                pool = rest if len(rest) > 0 else np.arange(n)
                query_indices = np.random.choice(pool, self.num_query, replace=True)
            support_set.extend([(samples[idx], i) for idx in support_indices])
            query_set.extend([(samples[idx], i) for idx in query_indices])
        return support_set, query_set

    def get_class_weights(self):
        """Inverse-frequency weights, ordered like list(self.data.keys())."""
        class_counts = np.array([len(self.data[cls]) for cls in self.data.keys()])
        weights = 1.0 / class_counts
        weights = weights / weights.sum()
        return weights
# Example usage: draw one fresh episode per training iteration.
# NOTE(review): `data` and `num_episodes` must be defined by the surrounding
# script — they are not created in this snippet.
sampler = EpisodeSampler(data, N=5, K=5, num_query=15, balance=True)
for episode in range(num_episodes):
    support_set, query_set = sampler.sample()
学习率调度 #
python
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
def get_optimizer_and_scheduler(model, config):
    """Build an Adam optimizer plus an optional LR scheduler from `config`.

    Recognised `config['scheduler']` values are 'cosine' and 'step';
    anything else yields `scheduler = None`.
    """
    optimizer = optim.Adam(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay'],
    )
    kind = config['scheduler']
    scheduler = None
    if kind == 'cosine':
        # Anneal down to 1% of the base learning rate over the full run.
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=config['num_epochs'],
            eta_min=config['lr'] * 0.01,
        )
    elif kind == 'step':
        scheduler = StepLR(
            optimizer,
            step_size=config['step_size'],
            gamma=config['gamma'],
        )
    return optimizer, scheduler
# Example epoch loop wiring the optimizer/scheduler factory together.
# NOTE(review): `model`, `config`, `num_epochs`, `train_loader`, `val_loader`,
# `train_one_epoch` and `validate` must be defined by the surrounding script.
optimizer, scheduler = get_optimizer_and_scheduler(model, config)
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, train_loader)
    # Step the LR schedule once per epoch (only if one was configured).
    if scheduler is not None:
        scheduler.step()
    val_acc = validate(model, val_loader)
    print(f"Epoch {epoch}, Val Acc: {val_acc:.4f}")
正则化技术 #
python
class RegularizedFewShotModel(nn.Module):
    """Few-shot model wrapper adding dropout, label smoothing and L2 loss.

    Args:
        encoder: feature extractor applied to raw inputs.
        classifier: callable (support_features, support_labels, query_features)
            -> query logits. Its exact contract is defined elsewhere in the
            project.
        dropout: dropout probability applied to both feature sets.
        label_smoothing: smoothing factor for the cross-entropy loss
            (previously hard-coded to 0.1).
        l2_weight: coefficient of the explicit L2 penalty in `compute_loss`
            (previously hard-coded to 0.0001).
    """

    def __init__(self, encoder, classifier, dropout=0.5,
                 label_smoothing=0.1, l2_weight=0.0001):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier
        self.dropout = nn.Dropout(dropout)
        self.label_smoothing = label_smoothing
        self.l2_weight = l2_weight

    def forward(self, support_data, support_labels, query_data):
        """Encode support/query data and return query logits."""
        support_features = self.encoder(support_data)
        query_features = self.encoder(query_data)
        # nn.Dropout is a no-op in eval() mode, so inference is deterministic.
        support_features = self.dropout(support_features)
        query_features = self.dropout(query_features)
        logits = self.classifier(support_features, support_labels, query_features)
        return logits

    def compute_loss(self, logits, labels):
        """Label-smoothed cross entropy plus an explicit L2 penalty.

        NOTE(review): the penalty sums *unsquared* parameter norms
        (torch.norm, p=2), matching the original implementation; classic
        weight decay would use squared norms.
        """
        loss = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)(
            logits, labels
        )
        l2_reg = 0.0
        for param in self.parameters():
            l2_reg += torch.norm(param, p=2)
        total_loss = loss + self.l2_weight * l2_reg
        return total_loss
超参数调优 #
关键超参数 #
text
┌─────────────────────────────────────────────────────────────┐
│ 关键超参数 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. Episode 相关 │
│ ├── N (way): 通常 5 或 10 │
│ ├── K (shot): 1, 5, 或根据实际 │
│ ├── num_query: 15-30 │
│ └── num_episodes: 1000-10000 │
│ │
│ 2. 优化相关 │
│ ├── learning_rate: 1e-4 到 1e-3 │
│ ├── weight_decay: 1e-5 到 1e-3 │
│ ├── batch_size: 1-4 episodes │
│ └── num_epochs: 50-200 │
│ │
│ 3. 模型相关 │
│ ├── feature_dim: 512-2048 │
│ ├── hidden_dim: 256-1024 │
│ └── dropout: 0.1-0.5 │
│ │
│ 4. 元学习相关 │
│ ├── inner_lr: 0.01-0.1 │
│ ├── outer_lr: 0.001-0.01 │
│ └── num_inner_steps: 1-10 │
│ │
└─────────────────────────────────────────────────────────────┘
超参数搜索 #
python
import optuna
def objective(trial):
    """Optuna objective: train with sampled hyperparameters, return best val acc.

    NOTE(review): relies on `build_model`, `train_one_epoch`, `validate`,
    `num_epochs`, `train_loader` and `val_loader` being defined by the
    surrounding script.
    """
    # trial.suggest_loguniform / suggest_uniform are deprecated and removed
    # in Optuna 3.x; suggest_float (with log=True for log-scale ranges) is
    # the supported replacement with identical sampling behavior.
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    dropout = trial.suggest_float('dropout', 0.1, 0.5)
    feature_dim = trial.suggest_categorical('feature_dim', [256, 512, 1024])
    model = build_model(feature_dim, dropout)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, train_loader)
        val_acc = validate(model, val_loader)
        # Track the best epoch, not the last one.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
    return best_val_acc
# Run the hyperparameter search: 100 trials, maximizing validation accuracy.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)
print(f"最佳参数: {study.best_params}")
print(f"最佳准确率: {study.best_value}")
性能优化 #
数据加载优化 #
python
from torch.utils.data import DataLoader
from torchvision import transforms
class FastDataLoader:
    """Thin DataLoader wrapper with performance-oriented defaults.

    NOTE(review): `self.transform` is built but never wired into the loader;
    callers appear to be expected to apply it themselves (e.g. inside the
    dataset) — confirm against the call sites.
    """

    def __init__(self, dataset, batch_size, num_workers=8, pin_memory=True):
        # Standard ImageNet-style augmentation at the 84x84 few-shot resolution.
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(84),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # prefetch_factor and persistent_workers are only valid with worker
        # processes; passing them with num_workers=0 raises a ValueError, so
        # only forward them when workers are actually used.
        worker_kwargs = {}
        if num_workers > 0:
            worker_kwargs = {'prefetch_factor': 2, 'persistent_workers': True}
        self.loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            **worker_kwargs,
        )

    def __iter__(self):
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
混合精度训练 #
python
from torch.cuda.amp import autocast, GradScaler
def train_with_amp(model, optimizer, train_loader, num_epochs):
    """Train `model` with automatic mixed precision (AMP).

    GradScaler rescales the loss to avoid fp16 gradient underflow. On
    machines without CUDA, GradScaler/autocast disable themselves and this
    degrades to ordinary fp32 training.
    """
    scaler = GradScaler()
    # Build the criterion once instead of once per batch.
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        epoch_loss = None
        for batch in train_loader:
            images, labels = batch
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss = loss
        # Guard against an empty loader (the original raised NameError on
        # `loss` here when train_loader yielded no batches).
        if epoch_loss is not None:
            print(f"Epoch {epoch}, Loss: {epoch_loss.item():.4f}")
模型压缩 #
python
import torch.nn.utils.prune as prune
def prune_model(model, amount=0.3):
    """Apply L1 unstructured pruning to every Conv2d and Linear layer.

    Args:
        model: module to prune in place.
        amount: fraction of weights to zero out per layer.

    Returns:
        The same model (pruned in place); `weight_mask` buffers stay attached
        until `prune.remove` is called to make the pruning permanent.
    """
    # The original if/elif branches were byte-for-byte identical — a single
    # isinstance check over both layer types is equivalent and clearer.
    for _name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name='weight', amount=amount)
    return model
def quantize_model(model):
    """Dynamically quantize the model's Linear layers to int8.

    Dynamic quantization supports only a few layer types (nn.Linear, LSTM,
    ...); nn.Conv2d is *not* supported and was silently ignored by the
    original `{nn.Linear, nn.Conv2d}` spec, so only nn.Linear is requested
    here — runtime behavior is unchanged.

    Returns:
        A new quantized copy of `model` (the input is left untouched).
    """
    model.eval()
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8
    )
    return quantized_model
# Example: prune 30% of the weights in place, then build a quantized copy.
# NOTE(review): `model` must be defined by the surrounding script; pruning
# mutates it before the quantized copy is taken.
pruned_model = prune_model(model, amount=0.3)
quantized_model = quantize_model(model)
评估与验证 #
评估指标 #
python
class FewShotEvaluator:
    """Episode-based evaluator reporting mean accuracy with a 95% CI.

    `sample_episode` is an abstract hook: subclasses must implement it for
    their dataset format before `evaluate` can be used.
    """

    def __init__(self, model, num_episodes=600):
        self.model = model
        # 600 episodes is the conventional protocol for few-shot benchmarks.
        self.num_episodes = num_episodes

    def evaluate(self, test_data, N, K):
        """Run `num_episodes` N-way K-shot episodes and aggregate accuracy.

        Returns:
            dict with 'mean_accuracy', 'std_accuracy' and 'ci_95' keys.
        """
        accuracies = []
        for episode in range(self.num_episodes):
            support_set, query_set = self.sample_episode(test_data, N, K)
            predictions = self.model.predict(support_set, query_set)
            accuracy = self.compute_accuracy(predictions, query_set)
            accuracies.append(accuracy)
        mean_acc = np.mean(accuracies)
        std_acc = np.std(accuracies)
        # Normal-approximation 95% confidence interval of the mean.
        ci_95 = 1.96 * std_acc / np.sqrt(self.num_episodes)
        return {
            'mean_accuracy': mean_acc,
            'std_accuracy': std_acc,
            'ci_95': ci_95
        }

    def compute_accuracy(self, predictions, query_set):
        """Fraction of query items whose prediction matches their label."""
        correct = sum(
            1 for pred, (_, label) in zip(predictions, query_set)
            if pred == label
        )
        return correct / len(query_set)

    def sample_episode(self, data, N, K):
        """Draw one (support_set, query_set) episode from `data`.

        The original body was `pass`, which made `evaluate` fail later with
        an opaque "cannot unpack None" TypeError; raising NotImplementedError
        makes the missing override explicit at the call site.
        """
        raise NotImplementedError("subclasses must implement sample_episode")
# Standard benchmark protocol: 600 episodes, 5-way 1-shot and 5-way 5-shot.
# NOTE(review): `model` and `test_data` come from the surrounding script.
evaluator = FewShotEvaluator(model, num_episodes=600)
results_5way_1shot = evaluator.evaluate(test_data, N=5, K=1)
results_5way_5shot = evaluator.evaluate(test_data, N=5, K=5)
print(f"5-way 1-shot: {results_5way_1shot['mean_accuracy']:.2f} ± {results_5way_1shot['ci_95']:.2f}")
print(f"5-way 5-shot: {results_5way_5shot['mean_accuracy']:.2f} ± {results_5way_5shot['ci_95']:.2f}")
消融实验 #
python
def ablation_study(model, test_data, components):
    """Evaluate `model` with each listed component ablated, one at a time.

    Returns a dict mapping component name to the 5-way 1-shot evaluation
    result produced by FewShotEvaluator.evaluate.
    """
    results = {}
    for component in components:
        print(f"测试组件: {component}")
        variant = modify_model(model, component)
        evaluator = FewShotEvaluator(variant)
        results[component] = evaluator.evaluate(test_data, N=5, K=1)
    return results
# Components to ablate one at a time.
components = [
    'without_attention',
    'without_data_augmentation',
    'without_pretraining',
    'with_different_backbone'
]
# NOTE(review): `model`, `test_data` and `modify_model` (called inside
# ablation_study) must be defined by the surrounding script.
ablation_results = ablation_study(model, test_data, components)
for component, result in ablation_results.items():
    print(f"{component}: {result['mean_accuracy']:.2f}")
部署最佳实践 #
模型导出 #
python
def export_model(model, save_path, input_shape=(1, 3, 84, 84)):
    """Export `model` as a TorchScript archive via tracing.

    Args:
        model: module to export; switched to eval() first so layers like
            dropout/batchnorm trace their inference-time behavior.
        save_path: destination .pt file path.
        input_shape: shape of the random dummy tracing input. Previously the
            84x84 RGB shape was hard-coded; the default preserves it while
            allowing other input sizes.
    """
    model.eval()
    dummy_input = torch.randn(*input_shape)
    traced_model = torch.jit.trace(model, dummy_input)
    traced_model.save(save_path)
    print(f"模型已导出到: {save_path}")
# Example export call; `model` comes from the surrounding script.
export_model(model, 'fewshot_model.pt')
def export_onnx(model, save_path, input_shape=(1, 3, 84, 84)):
    """Export `model` to ONNX with a dynamic batch dimension.

    Args:
        model: module to export (switched to eval() first).
        save_path: destination .onnx file path.
        input_shape: shape of the random dummy tracing input. Previously the
            84x84 RGB shape was hard-coded; the default preserves it while
            allowing other input sizes.
    """
    model.eval()
    dummy_input = torch.randn(*input_shape)
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,      # bake weights into the graph
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        # Only the batch axis is dynamic; spatial dims stay fixed.
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    print(f"ONNX 模型已导出到: {save_path}")
# Example ONNX export call; `model` comes from the surrounding script.
export_onnx(model, 'fewshot_model.onnx')
推理优化 #
python
class OptimizedInference:
    """Inference wrapper around a TorchScript model with AMP and GPU support."""

    def __init__(self, model_path):
        # Load the serialized TorchScript module and freeze it in eval mode.
        self.model = torch.jit.load(model_path)
        self.model.eval()
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    @torch.no_grad()
    def predict(self, support_images, query_images):
        """Run one forward pass; inputs are moved to the GPU when available."""
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            support_images = support_images.cuda()
            query_images = query_images.cuda()
        # Mixed-precision forward; a no-op when CUDA is unavailable.
        with torch.cuda.amp.autocast():
            outputs = self.model(support_images, query_images)
        return outputs.cpu()

    def batch_predict(self, support_images, query_images, batch_size=32):
        """Predict queries in chunks of `batch_size` and concatenate results."""
        chunks = [
            self.predict(support_images, query_images[start:start + batch_size])
            for start in range(0, len(query_images), batch_size)
        ]
        return torch.cat(chunks, dim=0)
常见问题与解决方案 #
过拟合问题 #
text
问题:模型在训练集上表现好,测试集表现差
解决方案:
1. 增加正则化
├── 提高 Dropout
├── 增加 L2 正则化
└── 使用 Label Smoothing
2. 数据增强
├── 增加增强强度
├── 使用更多增强方法
└── Mixup/CutMix
3. 早停
├── 监控验证集性能
├── 及时停止训练
└── 保存最佳模型
4. 减少模型复杂度
├── 减少参数量
├── 使用更简单的架构
└── 减少特征维度
类别不平衡 #
python
class BalancedSampler:
    """Support-set sampler that favors rare classes via inverse-frequency weights."""

    def __init__(self, data, N, K):
        self.data = data  # dict: class -> indexable samples
        self.N = N        # classes per episode (way)
        self.K = K        # support samples per class (shot)

    def sample(self):
        """Draw N classes (rare classes more likely) and K supports per class.

        Returns:
            List of (sample, episode_label) pairs with episode labels 0..N-1.
        """
        classes = list(self.data.keys())
        class_counts = [len(self.data[cls]) for cls in classes]
        # Inverse-frequency weights: rarer classes get picked more often.
        weights = 1.0 / np.array(class_counts)
        weights = weights / weights.sum()
        selected_classes = np.random.choice(
            classes,
            self.N,
            replace=False,
            p=weights
        )
        support_set = []
        for i, cls in enumerate(selected_classes):
            samples = self.data[cls]
            # The original always used replace=False, which raises ValueError
            # for any class with fewer than K samples — exactly the rare
            # classes this sampler is biased toward. Fall back to sampling
            # with replacement in that case.
            replace = len(samples) < self.K
            indices = np.random.choice(len(samples), self.K, replace=replace)
            support_set.extend([(samples[idx], i) for idx in indices])
        return support_set
域漂移问题 #
text
问题:训练域和测试域差异大
解决方案:
1. 域适应
├── DANN (Domain Adversarial Neural Network)
├── MMD (Maximum Mean Discrepancy)
└── CORAL (Correlation Alignment)
2. 域泛化
├── 多域训练
├── Domain-specific Batch Normalization
└── Meta-learning for Domain Generalization
3. 数据增强
├── 风格迁移
├── 域随机化
└── 合成数据
4. 特征对齐
├── 特征分布对齐
├── 对抗训练
└── 自监督预训练
总结 #
Few-shot Learning 是解决数据稀缺问题的强大工具。通过遵循这些最佳实践,你可以:
- 有效准备数据:确保数据质量和合理的划分
- 选择合适方法:根据任务特点选择最佳方法
- 优化训练过程:使用正确的训练技巧和超参数
- 提升模型性能:通过各种优化技术提高效果
- 成功部署应用:将模型部署到生产环境
记住,Few-shot Learning 的成功需要理论知识和实践经验的结合。持续学习和实验是掌握这一技术的关键!
最后更新:2026-04-05