Few-shot Learning 核心概念 #
N-way K-shot 定义 #
基本概念 #
N-way K-shot 是 Few-shot Learning 中最基本的概念,用于描述任务的难度和设置:
text
┌─────────────────────────────────────────────────────────────┐
│ N-way K-shot 定义 │
├─────────────────────────────────────────────────────────────┤
│ │
│ N-way:任务包含 N 个类别 │
│ K-shot:每个类别有 K 个标注样本 │
│ │
│ 示例:5-way 1-shot │
│ ├── 5 个类别 │
│ └── 每个类别 1 个样本 │
│ │
│ 示例:5-way 5-shot │
│ ├── 5 个类别 │
│ └── 每个类别 5 个样本 │
│ │
└─────────────────────────────────────────────────────────────┘
常见设置 #
| 设置 | 描述 | 难度 | 应用场景 |
|---|---|---|---|
| 5-way 1-shot | 5类,每类1样本 | 高 | 人脸识别、字符识别 |
| 5-way 5-shot | 5类,每类5样本 | 中 | 图像分类、目标检测 |
| 10-way 1-shot | 10类,每类1样本 | 很高 | 大规模分类 |
| 20-way 1-shot | 20类,每类1样本 | 极高 | Omniglot 字符识别 |
可视化示例 #
text
5-way 1-shot 示例:
类别 A: [样本1]
类别 B: [样本2]
类别 C: [样本3]
类别 D: [样本4]
类别 E: [样本5]
Support Set: 5 个样本(每类 1 个)
Query Set: 若干待分类样本
5-way 5-shot 示例:
类别 A: [样本1, 样本2, 样本3, 样本4, 样本5]
类别 B: [样本1, 样本2, 样本3, 样本4, 样本5]
类别 C: [样本1, 样本2, 样本3, 样本4, 样本5]
类别 D: [样本1, 样本2, 样本3, 样本4, 样本5]
类别 E: [样本1, 样本2, 样本3, 样本4, 样本5]
Support Set: 25 个样本(每类 5 个)
Query Set: 若干待分类样本
Support Set 和 Query Set #
Support Set(支持集) #
text
┌─────────────────────────────────────────────────────────────┐
│ Support Set │
├─────────────────────────────────────────────────────────────┤
│ │
│ 定义: │
│ - 用于训练的少量标注样本 │
│ - 模型从中学习类别特征 │
│ │
│ 特点: │
│ - 每个类别 K 个样本 │
│ - 样本数量少但质量要求高 │
│ - 需要具有代表性 │
│ │
│ 作用: │
│ - 提供类别信息 │
│ - 帮助模型理解新类别 │
│ - 作为参考标准 │
│ │
└─────────────────────────────────────────────────────────────┘
Query Set(查询集) #
text
┌─────────────────────────────────────────────────────────────┐
│ Query Set │
├─────────────────────────────────────────────────────────────┤
│ │
│ 定义: │
│ - 用于测试的样本 │
│ - 需要模型预测其类别 │
│ │
│ 特点: │
│ - 包含 Support Set 中的类别 │
│ - 用于评估模型学习效果 │
│ - 可以有多个样本 │
│ │
│ 作用: │
│ - 测试模型的分类能力 │
│ - 评估学习效果 │
│ - 计算准确率等指标 │
│ │
└─────────────────────────────────────────────────────────────┘
Support Set vs Query Set #
python
Support Set 和 Query Set 的关系:
┌─────────────────────────────────────────────────────────────┐
│ │
│ Support Set(训练用): │
│ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ │
│ │ 类别A │ │ 类别B │ │ 类别C │ │ 类别D │ │ 类别E │ │
│ │ 样本 │ │ 样本 │ │ 样本 │ │ 样本 │ │ 样本 │ │
│ └──────┘ └──────┘ └──────┘ └──────┘ └──────┘ │
│ │
│ Query Set(测试用): │
│ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ │
│ │ ? │ │ ? │ │ ? │ │ ? │ │ ? │ │
│ │ 待分类│ │ 待分类│ │ 待分类│ │ 待分类│ │ 待分类│ │
│ └──────┘ └──────┘ └──────┘ └──────┘ └──────┘ │
│ │
│ 任务:根据 Support Set 预测 Query Set 的类别 │
│ │
└─────────────────────────────────────────────────────────────┘
Episode 训练方式 #
Episode 的概念 #
text
┌─────────────────────────────────────────────────────────────┐
│ Episode 训练 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 定义: │
│ - 一次完整的训练过程 │
│ - 包含一个 Support Set 和一个 Query Set │
│ - 模拟一次测试场景 │
│ │
│ 特点: │
│ - 每个 Episode 都是独立的学习任务 │
│ - 类别随机选择 │
│ - 样本随机采样 │
│ │
│ 目的: │
│ - 学习如何学习 │
│ - 提高泛化能力 │
│ - 适应新任务 │
│ │
└─────────────────────────────────────────────────────────────┘
Episode 训练流程 #
python
Episode 训练流程:
for epoch in range(num_epochs):
for episode in range(num_episodes):
1. 随机选择 N 个类别
2. 从每个类别中采样 K 个样本作为 Support Set
3. 从每个类别中采样若干样本作为 Query Set
4. 使用 Support Set 进行学习
5. 在 Query Set 上计算损失
6. 更新模型参数
Episode 采样示例 #
python
import numpy as np
def sample_episode(data, N, K, num_query):
"""
采样一个 Episode
Args:
data: 数据集 {类别: [样本列表]}
N: 类别数量
K: 每类样本数量
num_query: 每类查询样本数量
Returns:
support_set: 支持集
query_set: 查询集
"""
classes = list(data.keys())
selected_classes = np.random.choice(classes, N, replace=False)
support_set = []
query_set = []
for cls in selected_classes:
samples = data[cls]
indices = np.random.choice(len(samples), K + num_query, replace=False)
support_set.extend([(samples[i], cls) for i in indices[:K]])
query_set.extend([(samples[i], cls) for i in indices[K:]])
return support_set, query_set
Episode vs 传统训练 #
text
┌─────────────────────────────────────────────────────────────┐
│ 传统训练 vs Episode 训练 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 传统训练: │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 整个数据集 → 批量训练 → 更新参数 → 重复 │ │
│ └─────────────────────────────────────────────────────┘ │
│ - 固定的类别 │
│ - 大量样本 │
│ - 学习特定任务 │
│ │
│ Episode 训练: │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Episode 1 → Episode 2 → ... → Episode N │ │
│ │ (不同类别) (不同类别) (不同类别) │ │
│ └─────────────────────────────────────────────────────┘ │
│ - 随机的类别 │
│ - 少量样本 │
│ - 学习如何学习 │
│ │
└─────────────────────────────────────────────────────────────┘
元学习(Meta-Learning) #
元学习的概念 #
text
┌─────────────────────────────────────────────────────────────┐
│ 元学习定义 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 传统学习: │
│ 学习特定任务的知识 │
│ "学会做某件事" │
│ │
│ 元学习: │
│ 学习如何学习 │
│ "学会如何学习做某件事" │
│ │
│ 核心思想: │
│ - 在多个任务上训练 │
│ - 学习通用的学习策略 │
│ - 快速适应新任务 │
│ │
└─────────────────────────────────────────────────────────────┘
元学习的两个层次 #
text
┌─────────────────────────────────────────────────────────────┐
│ 元学习层次 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 内层循环(Inner Loop): │
│ - 在单个任务上学习 │
│ - 使用 Support Set │
│ - 任务特定的适应 │
│ │
│ 外层循环(Outer Loop): │
│ - 跨任务学习 │
│ - 使用多个 Episode │
│ - 学习通用的学习策略 │
│ │
│ 示例: │
│ for task in tasks: │
│ # Inner Loop │
│ task_model = adapt(base_model, task.support_set) │
│ task_loss = evaluate(task_model, task.query_set) │
│ # Outer Loop │
│ base_model = update(base_model, task_loss) │
│ │
└─────────────────────────────────────────────────────────────┘
元学习的分类 #
text
┌─────────────────────────────────────────────────────────────┐
│ 元学习方法分类 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 基于优化的方法(Optimization-based) │
│ ├── MAML(Model-Agnostic Meta-Learning) │
│ ├── Meta-SGD │
│ ├── Reptile │
│ └── 学习如何优化 │
│ │
│ 2. 基于模型的方法(Model-based) │
│ ├── Memory-augmented Networks │
│ ├── Meta Networks │
│ └── 设计能快速适应的模型架构 │
│ │
│ 3. 基于度量的方法(Metric-based) │
│ ├── Siamese Networks │
│ ├── Prototypical Networks │
│ ├── Matching Networks │
│ └── 学习样本间的相似度度量 │
│ │
└─────────────────────────────────────────────────────────────┘
Few-shot Learning 数据集 #
标准数据集 #
text
┌─────────────────────────────────────────────────────────────┐
│ 标准数据集 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. Omniglot │
│ ├── 1623 个字符类别 │
│ ├── 每个类别 20 个样本 │
│ ├── 常用设置:5-way 1-shot, 20-way 1-shot │
│ └── 字符识别基准 │
│ │
│ 2. miniImageNet │
│ ├── 100 个类别 │
│ ├── 每个类别 600 个样本 │
│ ├── 常用设置:5-way 1-shot, 5-way 5-shot │
│ └── 图像分类基准 │
│ │
│ 3. tieredImageNet │
│ ├── 608 个类别 │
│ ├── 更大的类别数量 │
│ ├── 更接近真实场景 │
│ └── 高级基准 │
│ │
│ 4. CUB-200(鸟类识别) │
│ ├── 200 个鸟类类别 │
│ ├── 细粒度分类 │
│ └── 细粒度识别基准 │
│ │
│ 5. Meta-Dataset │
│ ├── 多个数据集 │
│ ├── 跨领域评估 │
│ └── 综合基准 │
│ │
└─────────────────────────────────────────────────────────────┘
数据集划分 #
python
数据集划分方式:
┌─────────────────────────────────────────────────────────────┐
│ │
│ 传统划分: │
│ ├── Train Set: 训练集 │
│ ├── Validation Set: 验证集 │
│ └── Test Set: 测试集 │
│ │
│ Few-shot 划分: │
│ ├── Train Classes: 训练类别(用于训练 Episode) │
│ ├── Validation Classes: 验证类别(用于调参) │
│ └── Test Classes: 测试类别(用于评估,从未见过) │
│ │
│ 关键区别: │
│ - 测试类别在训练时从未出现 │
│ - 评估的是泛化到新类别的能力 │
│ - 更接近真实应用场景 │
│ │
└─────────────────────────────────────────────────────────────┘
评估指标 #
准确率(Accuracy) #
python
准确率计算:
def accuracy(predictions, labels):
"""
计算准确率
Args:
predictions: 预测标签
labels: 真实标签
Returns:
accuracy: 准确率
"""
correct = (predictions == labels).sum()
total = len(labels)
return correct / total
示例:
- 5-way 1-shot: 98.5% (Omniglot)
- 5-way 5-shot: 99.5% (Omniglot)
- 5-way 1-shot: 48.7% (miniImageNet)
- 5-way 5-shot: 63.1% (miniImageNet)
置信区间 #
python
置信区间计算:
由于 Few-shot Learning 结果有较大随机性,通常报告:
1. 多次运行的准确率
2. 平均准确率 ± 标准差
3. 95% 置信区间
示例报告:
- 5-way 1-shot: 48.70 ± 0.84%
- 5-way 5-shot: 63.11 ± 0.92%
其他指标 #
text
┌─────────────────────────────────────────────────────────────┐
│ 其他评估指标 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 收敛速度 │
│ - 达到目标性能需要的 Episode 数量 │
│ - 适应新任务的速度 │
│ │
│ 2. 计算效率 │
│ - 训练时间 │
│ - 推理时间 │
│ - 内存占用 │
│ │
│ 3. 鲁棒性 │
│ - 对样本选择的敏感性 │
│ - 对超参数的敏感性 │
│ - 跨域性能 │
│ │
└─────────────────────────────────────────────────────────────┘
Few-shot Learning 流程 #
训练阶段 #
text
┌─────────────────────────────────────────────────────────────┐
│ 训练阶段 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 准备数据 │
│ ├── 划分训练/验证/测试类别 │
│ └── 构建类别到样本的映射 │
│ │
│ 2. Episode 采样 │
│ ├── 随机选择 N 个类别 │
│ ├── 采样 Support Set │
│ └── 采样 Query Set │
│ │
│ 3. 前向传播 │
│ ├── 使用 Support Set 学习 │
│ └── 在 Query Set 上预测 │
│ │
│ 4. 计算损失 │
│ └── 交叉熵损失 │
│ │
│ 5. 反向传播 │
│ └── 更新模型参数 │
│ │
│ 6. 重复 2-5 │
│ └── 直到收敛 │
│ │
└─────────────────────────────────────────────────────────────┘
测试阶段 #
text
┌─────────────────────────────────────────────────────────────┐
│ 测试阶段 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 准备测试数据 │
│ ├── 使用从未见过的类别 │
│ └── 构建 Support Set 和 Query Set │
│ │
│ 2. 快速适应 │
│ ├── 使用 Support Set │
│ └── 微调或推理 │
│ │
│ 3. 预测 │
│ └── 在 Query Set 上预测 │
│ │
│ 4. 评估 │
│ ├── 计算准确率 │
│ └── 多次测试取平均 │
│ │
└─────────────────────────────────────────────────────────────┘
代码示例 #
完整的 Episode 训练示例 #
python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
class FewShotTrainer:
def __init__(self, model, num_classes=5, num_support=5, num_query=15):
self.model = model
self.num_classes = num_classes
self.num_support = num_support
self.num_query = num_query
self.optimizer = optim.Adam(model.parameters(), lr=0.001)
self.criterion = nn.CrossEntropyLoss()
def sample_episode(self, dataset):
classes = np.random.choice(
len(dataset.classes),
self.num_classes,
replace=False
)
support_set = []
query_set = []
for i, cls in enumerate(classes):
samples = dataset.get_class_samples(cls)
indices = np.random.permutation(len(samples))
support_indices = indices[:self.num_support]
query_indices = indices[self.num_support:self.num_support + self.num_query]
for idx in support_indices:
support_set.append((samples[idx], i))
for idx in query_indices:
query_set.append((samples[idx], i))
return support_set, query_set
def train_episode(self, support_set, query_set):
self.model.train()
self.optimizer.zero_grad()
support_images = torch.stack([s[0] for s in support_set])
support_labels = torch.tensor([s[1] for s in support_set])
query_images = torch.stack([q[0] for q in query_set])
query_labels = torch.tensor([q[1] for q in query_set])
support_features = self.model.extract_features(support_images)
query_features = self.model.extract_features(query_images)
prototypes = self.compute_prototypes(support_features, support_labels)
predictions = self.classify(query_features, prototypes)
loss = self.criterion(predictions, query_labels)
loss.backward()
self.optimizer.step()
accuracy = (predictions.argmax(dim=1) == query_labels).float().mean()
return loss.item(), accuracy.item()
def compute_prototypes(self, features, labels):
prototypes = []
for i in range(self.num_classes):
mask = labels == i
class_features = features[mask]
prototype = class_features.mean(dim=0)
prototypes.append(prototype)
return torch.stack(prototypes)
def classify(self, query_features, prototypes):
distances = torch.cdist(query_features, prototypes)
logits = -distances
return logits
trainer = FewShotTrainer(model, num_classes=5, num_support=5, num_query=15)
for epoch in range(num_epochs):
for episode in range(num_episodes):
support_set, query_set = trainer.sample_episode(dataset)
loss, accuracy = trainer.train_episode(support_set, query_set)
if episode % 100 == 0:
print(f"Epoch {epoch}, Episode {episode}, "
f"Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")
下一步 #
现在你已经掌握了 Few-shot Learning 的核心概念,接下来学习 元学习方法,深入了解各种 Few-shot Learning 算法!
最后更新:2026-04-05