Few-shot Learning 核心概念 #

N-way K-shot 定义 #

基本概念 #

N-way K-shot 是 Few-shot Learning 中最基本的概念,用于描述任务的难度和设置:

text
┌─────────────────────────────────────────────────────────────┐
│                    N-way K-shot 定义                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  N-way:任务包含 N 个类别                                    │
│  K-shot:每个类别有 K 个标注样本                             │
│                                                             │
│  示例:5-way 1-shot                                         │
│  ├── 5 个类别                                               │
│  └── 每个类别 1 个样本                                      │
│                                                             │
│  示例:5-way 5-shot                                         │
│  ├── 5 个类别                                               │
│  └── 每个类别 5 个样本                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

常见设置 #

设置 描述 难度 应用场景
5-way 1-shot 5类,每类1样本 人脸识别、字符识别
5-way 5-shot 5类,每类5样本 图像分类、目标检测
10-way 1-shot 10类,每类1样本 很高 大规模分类
20-way 1-shot 20类,每类1样本 极高 Omniglot 字符识别

可视化示例 #

text
5-way 1-shot 示例:

类别 A: [样本1]
类别 B: [样本2]
类别 C: [样本3]
类别 D: [样本4]
类别 E: [样本5]

Support Set: 5 个样本(每类 1 个)
Query Set: 若干待分类样本

5-way 5-shot 示例:

类别 A: [样本1, 样本2, 样本3, 样本4, 样本5]
类别 B: [样本1, 样本2, 样本3, 样本4, 样本5]
类别 C: [样本1, 样本2, 样本3, 样本4, 样本5]
类别 D: [样本1, 样本2, 样本3, 样本4, 样本5]
类别 E: [样本1, 样本2, 样本3, 样本4, 样本5]

Support Set: 25 个样本(每类 5 个)
Query Set: 若干待分类样本

Support Set 和 Query Set #

Support Set(支持集) #

text
┌─────────────────────────────────────────────────────────────┐
│                    Support Set                               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  定义:                                                      │
│  - 用于训练的少量标注样本                                    │
│  - 模型从中学习类别特征                                     │
│                                                             │
│  特点:                                                      │
│  - 每个类别 K 个样本                                        │
│  - 样本数量少但质量要求高                                   │
│  - 需要具有代表性                                           │
│                                                             │
│  作用:                                                      │
│  - 提供类别信息                                             │
│  - 帮助模型理解新类别                                       │
│  - 作为参考标准                                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Query Set(查询集) #

text
┌─────────────────────────────────────────────────────────────┐
│                    Query Set                                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  定义:                                                      │
│  - 用于测试的样本                                           │
│  - 需要模型预测其类别                                       │
│                                                             │
│  特点:                                                      │
│  - 包含 Support Set 中的类别                                │
│  - 用于评估模型学习效果                                     │
│  - 可以有多个样本                                           │
│                                                             │
│  作用:                                                      │
│  - 测试模型的分类能力                                       │
│  - 评估学习效果                                             │
│  - 计算准确率等指标                                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Support Set vs Query Set #

python
Support Set 和 Query Set 的关系:

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│  Support Set(训练用):                                     │
│  ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐            │
│  │ 类别A │ │ 类别B │ │ 类别C │ │ 类别D │ │ 类别E │            │
│  │ 样本 │ │ 样本 │ │ 样本 │ │ 样本 │ │ 样本 │            │
│  └──────┘ └──────┘ └──────┘ └──────┘ └──────┘            │
│                                                             │
│  Query Set(测试用):                                       │
│  ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐            │
│  │  ?   │ │  ?   │ │  ?   │ │  ?   │ │  ?   │            │
│  │ 待分类│ │ 待分类│ │ 待分类│ │ 待分类│ │ 待分类│            │
│  └──────┘ └──────┘ └──────┘ └──────┘ └──────┘            │
│                                                             │
│  任务:根据 Support Set 预测 Query Set 的类别               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Episode 训练方式 #

Episode 的概念 #

text
┌─────────────────────────────────────────────────────────────┐
│                    Episode 训练                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  定义:                                                      │
│  - 一次完整的训练过程                                       │
│  - 包含一个 Support Set 和一个 Query Set                    │
│  - 模拟一次测试场景                                         │
│                                                             │
│  特点:                                                      │
│  - 每个 Episode 都是独立的学习任务                          │
│  - 类别随机选择                                             │
│  - 样本随机采样                                             │
│                                                             │
│  目的:                                                      │
│  - 学习如何学习                                             │
│  - 提高泛化能力                                             │
│  - 适应新任务                                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Episode 训练流程 #

python
Episode 训练流程:

for epoch in range(num_epochs):
    for episode in range(num_episodes):
        1. 随机选择 N 个类别
        2. 从每个类别中采样 K 个样本作为 Support Set
        3. 从每个类别中采样若干样本作为 Query Set
        4. 使用 Support Set 进行学习
        5. 在 Query Set 上计算损失
        6. 更新模型参数

Episode 采样示例 #

python
import numpy as np

def sample_episode(data, N, K, num_query):
    """
    采样一个 Episode
    
    Args:
        data: 数据集 {类别: [样本列表]}
        N: 类别数量
        K: 每类样本数量
        num_query: 每类查询样本数量
    
    Returns:
        support_set: 支持集
        query_set: 查询集
    """
    classes = list(data.keys())
    selected_classes = np.random.choice(classes, N, replace=False)
    
    support_set = []
    query_set = []
    
    for cls in selected_classes:
        samples = data[cls]
        indices = np.random.choice(len(samples), K + num_query, replace=False)
        
        support_set.extend([(samples[i], cls) for i in indices[:K]])
        query_set.extend([(samples[i], cls) for i in indices[K:]])
    
    return support_set, query_set

Episode vs 传统训练 #

text
┌─────────────────────────────────────────────────────────────┐
│              传统训练 vs Episode 训练                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  传统训练:                                                  │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  整个数据集 → 批量训练 → 更新参数 → 重复             │   │
│  └─────────────────────────────────────────────────────┘   │
│  - 固定的类别                                               │
│  - 大量样本                                                 │
│  - 学习特定任务                                             │
│                                                             │
│  Episode 训练:                                             │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  Episode 1 → Episode 2 → ... → Episode N            │   │
│  │  (不同类别)  (不同类别)       (不同类别)              │   │
│  └─────────────────────────────────────────────────────┘   │
│  - 随机的类别                                               │
│  - 少量样本                                                 │
│  - 学习如何学习                                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

元学习(Meta-Learning) #

元学习的概念 #

text
┌─────────────────────────────────────────────────────────────┐
│                    元学习定义                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  传统学习:                                                  │
│  学习特定任务的知识                                         │
│  "学会做某件事"                                             │
│                                                             │
│  元学习:                                                    │
│  学习如何学习                                               │
│  "学会如何学习做某件事"                                     │
│                                                             │
│  核心思想:                                                  │
│  - 在多个任务上训练                                         │
│  - 学习通用的学习策略                                       │
│  - 快速适应新任务                                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

元学习的两个层次 #

text
┌─────────────────────────────────────────────────────────────┐
│                    元学习层次                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  内层循环(Inner Loop):                                    │
│  - 在单个任务上学习                                         │
│  - 使用 Support Set                                         │
│  - 任务特定的适应                                           │
│                                                             │
│  外层循环(Outer Loop):                                    │
│  - 跨任务学习                                               │
│  - 使用多个 Episode                                         │
│  - 学习通用的学习策略                                       │
│                                                             │
│  示例:                                                      │
│  for task in tasks:                                         │
│      # Inner Loop                                           │
│      task_model = adapt(base_model, task.support_set)       │
│      task_loss = evaluate(task_model, task.query_set)       │
│      # Outer Loop                                           │
│      base_model = update(base_model, task_loss)             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

元学习的分类 #

text
┌─────────────────────────────────────────────────────────────┐
│                    元学习方法分类                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 基于优化的方法(Optimization-based)                    │
│     ├── MAML(Model-Agnostic Meta-Learning)               │
│     ├── Meta-SGD                                           │
│     ├── Reptile                                            │
│     └── 学习如何优化                                        │
│                                                             │
│  2. 基于模型的方法(Model-based)                           │
│     ├── Memory-augmented Networks                          │
│     ├── Meta Networks                                      │
│     └── 设计能快速适应的模型架构                            │
│                                                             │
│  3. 基于度量的方法(Metric-based)                          │
│     ├── Siamese Networks                                   │
│     ├── Prototypical Networks                              │
│     ├── Matching Networks                                  │
│     └── 学习样本间的相似度度量                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Few-shot Learning 数据集 #

标准数据集 #

text
┌─────────────────────────────────────────────────────────────┐
│                    标准数据集                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Omniglot                                                │
│     ├── 1623 个字符类别                                     │
│     ├── 每个类别 20 个样本                                  │
│     ├── 常用设置:5-way 1-shot, 20-way 1-shot              │
│     └── 字符识别基准                                        │
│                                                             │
│  2. miniImageNet                                            │
│     ├── 100 个类别                                          │
│     ├── 每个类别 600 个样本                                 │
│     ├── 常用设置:5-way 1-shot, 5-way 5-shot               │
│     └── 图像分类基准                                        │
│                                                             │
│  3. tieredImageNet                                          │
│     ├── 608 个类别                                          │
│     ├── 更大的类别数量                                      │
│     ├── 更接近真实场景                                      │
│     └── 高级基准                                            │
│                                                             │
│  4. CUB-200(鸟类识别)                                     │
│     ├── 200 个鸟类类别                                      │
│     ├── 细粒度分类                                          │
│     └── 细粒度识别基准                                      │
│                                                             │
│  5. Meta-Dataset                                            │
│     ├── 多个数据集                                          │
│     ├── 跨领域评估                                          │
│     └── 综合基准                                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

数据集划分 #

python
数据集划分方式:

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│  传统划分:                                                  │
│  ├── Train Set: 训练集                                      │
│  ├── Validation Set: 验证集                                 │
│  └── Test Set: 测试集                                       │
│                                                             │
│  Few-shot 划分:                                            │
│  ├── Train Classes: 训练类别(用于训练 Episode)            │
│  ├── Validation Classes: 验证类别(用于调参)               │
│  └── Test Classes: 测试类别(用于评估,从未见过)           │
│                                                             │
│  关键区别:                                                  │
│  - 测试类别在训练时从未出现                                 │
│  - 评估的是泛化到新类别的能力                               │
│  - 更接近真实应用场景                                       │
│                                                             │
└─────────────────────────────────────────────────────────────┘

评估指标 #

准确率(Accuracy) #

python
准确率计算:

def accuracy(predictions, labels):
    """
    计算准确率
    
    Args:
        predictions: 预测标签
        labels: 真实标签
    
    Returns:
        accuracy: 准确率
    """
    correct = (predictions == labels).sum()
    total = len(labels)
    return correct / total

示例:
- 5-way 1-shot: 98.5% (Omniglot)
- 5-way 5-shot: 99.5% (Omniglot)
- 5-way 1-shot: 48.7% (miniImageNet)
- 5-way 5-shot: 63.1% (miniImageNet)

置信区间 #

python
置信区间计算:

由于 Few-shot Learning 结果有较大随机性,通常报告:

1. 多次运行的准确率
2. 平均准确率 ± 标准差
3. 95% 置信区间

示例报告:
- 5-way 1-shot: 48.70 ± 0.84%
- 5-way 5-shot: 63.11 ± 0.92%

其他指标 #

text
┌─────────────────────────────────────────────────────────────┐
│                    其他评估指标                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 收敛速度                                                │
│     - 达到目标性能需要的 Episode 数量                       │
│     - 适应新任务的速度                                      │
│                                                             │
│  2. 计算效率                                                │
│     - 训练时间                                              │
│     - 推理时间                                              │
│     - 内存占用                                              │
│                                                             │
│  3. 鲁棒性                                                  │
│     - 对样本选择的敏感性                                    │
│     - 对超参数的敏感性                                      │
│     - 跨域性能                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Few-shot Learning 流程 #

训练阶段 #

text
┌─────────────────────────────────────────────────────────────┐
│                    训练阶段                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 准备数据                                                │
│     ├── 划分训练/验证/测试类别                              │
│     └── 构建类别到样本的映射                                │
│                                                             │
│  2. Episode 采样                                            │
│     ├── 随机选择 N 个类别                                   │
│     ├── 采样 Support Set                                    │
│     └── 采样 Query Set                                      │
│                                                             │
│  3. 前向传播                                                │
│     ├── 使用 Support Set 学习                               │
│     └── 在 Query Set 上预测                                 │
│                                                             │
│  4. 计算损失                                                │
│     └── 交叉熵损失                                          │
│                                                             │
│  5. 反向传播                                                │
│     └── 更新模型参数                                        │
│                                                             │
│  6. 重复 2-5                                                │
│     └── 直到收敛                                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

测试阶段 #

text
┌─────────────────────────────────────────────────────────────┐
│                    测试阶段                                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 准备测试数据                                            │
│     ├── 使用从未见过的类别                                  │
│     └── 构建 Support Set 和 Query Set                       │
│                                                             │
│  2. 快速适应                                                │
│     ├── 使用 Support Set                                    │
│     └── 微调或推理                                          │
│                                                             │
│  3. 预测                                                    │
│     └── 在 Query Set 上预测                                 │
│                                                             │
│  4. 评估                                                    │
│     ├── 计算准确率                                          │
│     └── 多次测试取平均                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

代码示例 #

完整的 Episode 训练示例 #

python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader

class FewShotTrainer:
    def __init__(self, model, num_classes=5, num_support=5, num_query=15):
        self.model = model
        self.num_classes = num_classes
        self.num_support = num_support
        self.num_query = num_query
        self.optimizer = optim.Adam(model.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
    
    def sample_episode(self, dataset):
        classes = np.random.choice(
            len(dataset.classes), 
            self.num_classes, 
            replace=False
        )
        
        support_set = []
        query_set = []
        
        for i, cls in enumerate(classes):
            samples = dataset.get_class_samples(cls)
            indices = np.random.permutation(len(samples))
            
            support_indices = indices[:self.num_support]
            query_indices = indices[self.num_support:self.num_support + self.num_query]
            
            for idx in support_indices:
                support_set.append((samples[idx], i))
            
            for idx in query_indices:
                query_set.append((samples[idx], i))
        
        return support_set, query_set
    
    def train_episode(self, support_set, query_set):
        self.model.train()
        self.optimizer.zero_grad()
        
        support_images = torch.stack([s[0] for s in support_set])
        support_labels = torch.tensor([s[1] for s in support_set])
        query_images = torch.stack([q[0] for q in query_set])
        query_labels = torch.tensor([q[1] for q in query_set])
        
        support_features = self.model.extract_features(support_images)
        query_features = self.model.extract_features(query_images)
        
        prototypes = self.compute_prototypes(support_features, support_labels)
        
        predictions = self.classify(query_features, prototypes)
        
        loss = self.criterion(predictions, query_labels)
        
        loss.backward()
        self.optimizer.step()
        
        accuracy = (predictions.argmax(dim=1) == query_labels).float().mean()
        
        return loss.item(), accuracy.item()
    
    def compute_prototypes(self, features, labels):
        prototypes = []
        for i in range(self.num_classes):
            mask = labels == i
            class_features = features[mask]
            prototype = class_features.mean(dim=0)
            prototypes.append(prototype)
        return torch.stack(prototypes)
    
    def classify(self, query_features, prototypes):
        distances = torch.cdist(query_features, prototypes)
        logits = -distances
        return logits

trainer = FewShotTrainer(model, num_classes=5, num_support=5, num_query=15)

for epoch in range(num_epochs):
    for episode in range(num_episodes):
        support_set, query_set = trainer.sample_episode(dataset)
        loss, accuracy = trainer.train_episode(support_set, query_set)
        
        if episode % 100 == 0:
            print(f"Epoch {epoch}, Episode {episode}, "
                  f"Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

下一步 #

现在你已经掌握了 Few-shot Learning 的核心概念,接下来学习 元学习方法,深入了解各种 Few-shot Learning 算法!

最后更新:2026-04-05