Meta-Learning Methods #

Meta-Learning Overview #

What Is Meta-Learning? #

Meta-learning is one of the core techniques behind few-shot learning; its central idea is "learning to learn."

text
┌─────────────────────────────────────────────────────────────┐
│                 Core Idea of Meta-Learning                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Conventional learning:                                     │
│  Train on one task → a model for that task                  │
│                                                             │
│  Meta-learning:                                             │
│  Train across many tasks → a model that adapts quickly     │
│  to new tasks                                               │
│                                                             │
│  Goal:                                                      │
│  Learn an initialization that can adapt to a new task       │
│  with only a few gradient updates                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

A Taxonomy of Meta-Learning Methods #

text
┌─────────────────────────────────────────────────────────────┐
│              Taxonomy of Meta-Learning Methods              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Optimization-based methods                              │
│     ├── MAML (Model-Agnostic Meta-Learning)                 │
│     ├── Meta-SGD                                            │
│     ├── Reptile                                             │
│     └── Learns an optimization strategy                     │
│                                                             │
│  2. Model-based methods                                     │
│     ├── MANN (Memory-Augmented Neural Networks)             │
│     ├── Meta Networks                                       │
│     └── Builds architectures that adapt quickly             │
│                                                             │
│  3. Metric-based methods                                    │
│     ├── Siamese Networks                                    │
│     ├── Prototypical Networks                               │
│     └── Learns a similarity metric                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

MAML (Model-Agnostic Meta-Learning) #

MAML Core Idea #

text
┌─────────────────────────────────────────────────────────────┐
│                      Core Idea of MAML                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Goal: find a good initialization θ                         │
│  so that a new task Ti reaches good performance after       │
│  only a few gradient updates                                │
│                                                             │
│  Procedure:                                                 │
│  1. Start from the initialization θ                         │
│  2. Take a few gradient steps on task Ti:                   │
│     θi' = θ - α∇θL(θ, Ti)                                   │
│  3. Optimize θ across many tasks:                           │
│     θ ← θ - β∇θ ΣL(θi', Ti)                                 │
│                                                             │
│  Key: second-order gradients (gradients of gradients)       │
│                                                             │
└─────────────────────────────────────────────────────────────┘
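The second-order term can be made concrete on a tiny 1-D example. This is an illustrative sketch, not part of MAML itself; `maml_meta_grad` is a made-up helper that differentiates the query loss at the adapted parameter θ' back through the inner update.

```python
import torch

# Toy 1-D illustration of the MAML meta-gradient on L(θ) = (θ - t)².
# maml_meta_grad is an illustrative helper, not a library function.
def maml_meta_grad(theta0: float, t: float, alpha: float) -> float:
    theta = torch.tensor(theta0, requires_grad=True)
    inner_loss = (theta - t) ** 2
    # create_graph=True keeps the graph of the inner step, so the meta-
    # gradient can differentiate through it (the "gradient of a gradient")
    (g,) = torch.autograd.grad(inner_loss, theta, create_graph=True)
    theta_prime = theta - alpha * g          # θ' = θ - α·2(θ - t)
    outer_loss = (theta_prime - t) ** 2      # query loss at adapted params
    (meta_g,) = torch.autograd.grad(outer_loss, theta)
    return meta_g.item()

# Analytically: θ' = (1 - 2α)(θ - t) + t, so dL_outer/dθ = 2(θ' - t)(1 - 2α).
# For θ=1, t=0, α=0.1 this is 2 · 0.8 · 0.8 = 1.28.
print(maml_meta_grad(1.0, 0.0, 0.1))  # ≈ 1.28
```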

MAML Algorithm #

text
MAML algorithm:

Input:
- task distribution p(T)
- learning rates α, β
- initial parameters θ

for iteration in range(num_iterations):
    1. Sample a batch of tasks: T = {T1, T2, ..., Tn}
    
    2. For each task Ti:
       a. Sample a support set Di^s
       b. Sample a query set Di^q
       c. Compute the gradient and update:
          θi' = θ - α∇θL(θ, Di^s)
       d. Compute the loss on the query set:
          Li = L(θi', Di^q)
    
    3. Meta-update:
       θ ← θ - β∇θΣLi

Output: meta-learned parameters θ

MAML Implementation #

python
import torch
import torch.nn as nn
import torch.optim as optim

class MAML:
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=outer_lr)
    
    def inner_loop(self, support_data, support_labels):
        fast_weights = list(self.model.parameters())
        
        for step in range(self.num_inner_steps):
            outputs = self.model.forward_with_weights(support_data, fast_weights)
            loss = nn.CrossEntropyLoss()(outputs, support_labels)
            
            # create_graph=True keeps the inner-loop graph, so the meta-update
            # can backpropagate through it (the second-order term)
            grads = torch.autograd.grad(loss, fast_weights, create_graph=True)
            
            fast_weights = [
                w - self.inner_lr * g 
                for w, g in zip(fast_weights, grads)
            ]
        
        return fast_weights
    
    def meta_train(self, tasks):
        self.meta_optimizer.zero_grad()
        
        meta_loss = 0.0
        
        for task in tasks:
            support_data, support_labels = task['support']
            query_data, query_labels = task['query']
            
            fast_weights = self.inner_loop(support_data, support_labels)
            
            query_outputs = self.model.forward_with_weights(query_data, fast_weights)
            task_loss = nn.CrossEntropyLoss()(query_outputs, query_labels)
            
            meta_loss += task_loss
        
        meta_loss /= len(tasks)
        
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item()
    
    def adapt(self, support_data, support_labels, num_steps=10):
        # Test-time adaptation: plain gradient steps on the support set,
        # no second-order terms needed
        fast_weights = list(self.model.parameters())
        
        for _ in range(num_steps):
            outputs = self.model.forward_with_weights(support_data, fast_weights)
            loss = nn.CrossEntropyLoss()(outputs, support_labels)
            
            grads = torch.autograd.grad(loss, fast_weights)
            fast_weights = [
                w - self.inner_lr * g 
                for w, g in zip(fast_weights, grads)
            ]
        
        return fast_weights

class MAMLModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    
    def forward_with_weights(self, x, weights):
        x = torch.relu(torch.nn.functional.linear(x, weights[0], weights[1]))
        x = torch.nn.functional.linear(x, weights[2], weights[3])
        return x

# Example usage (sample_tasks, dataset, and num_epochs are placeholders
# defined elsewhere)
model = MAMLModel(input_dim=784, hidden_dim=64, output_dim=5)
maml = MAML(model, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5)

for epoch in range(num_epochs):
    tasks = sample_tasks(dataset, num_tasks=32)
    loss = maml.meta_train(tasks)
    print(f"Epoch {epoch}, Meta Loss: {loss:.4f}")
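The loop above relies on a `sample_tasks` helper that is not defined here. A minimal sketch of the episode format it would need to return, using random synthetic data in place of a real dataset (every name and default below is illustrative):

```python
import torch

def sample_tasks(dataset, num_tasks=32, n_way=5, k_shot=1,
                 q_queries=15, dim=784):
    # Illustrative stand-in: builds random N-way K-shot episodes in the
    # {'support': (x, y), 'query': (x, y)} format that meta_train expects.
    # A real implementation would sample n_way classes from `dataset`.
    tasks = []
    for _ in range(num_tasks):
        tasks.append({
            'support': (torch.randn(n_way * k_shot, dim),
                        torch.arange(n_way).repeat_interleave(k_shot)),
            'query':   (torch.randn(n_way * q_queries, dim),
                        torch.arange(n_way).repeat_interleave(q_queries)),
        })
    return tasks

tasks = sample_tasks(None, num_tasks=4)
print(len(tasks), tasks[0]['support'][0].shape)  # 4 torch.Size([5, 784])
```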

Pros and Cons of MAML #

text
┌─────────────────────────────────────────────────────────────┐
│                    Pros and Cons of MAML                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Pros:                                                      │
│  ✅ Model-agnostic: works with any model                    │
│  ✅ Solid theoretical foundation                            │
│  ✅ Fast adaptation                                         │
│  ✅ Strong results, widely used                             │
│                                                             │
│  Cons:                                                      │
│  ⚠️ Second-order gradients: heavy memory use                │
│  ⚠️ Long training time                                      │
│  ⚠️ Hyperparameters need careful tuning                     │
│  ⚠️ Training can be unstable                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

First-Order MAML (FOMAML) #

FOMAML Core Idea #

text
┌─────────────────────────────────────────────────────────────┐
│                    Core Idea of FOMAML                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Problem: MAML's second-order gradients are expensive       │
│                                                             │
│  Solution: drop the second-order term (first-order          │
│  approximation)                                             │
│                                                             │
│  MAML:   θ ← θ - β∇θ ΣL(θ - α∇θL(θ, Di^s), Di^q)            │
│  FOMAML: θ ← θ - β Σ ∇θ'L(θ', Di^q)                         │
│          where θ' = θ - α∇θL(θ, Di^s)                       │
│                                                             │
│  Pros: simpler and faster to compute                        │
│  Cons: may give up some accuracy                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘
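On a 1-D quadratic loss, the first-order shortcut in the box above amounts to evaluating the query gradient at θ' and applying it directly to θ. This is an illustrative sketch; `fomaml_meta_grad` is a made-up name:

```python
import torch

# Toy 1-D illustration of the FOMAML approximation on L(θ) = (θ - t)².
def fomaml_meta_grad(theta0: float, t: float, alpha: float) -> float:
    theta = torch.tensor(theta0, requires_grad=True)
    # Inner step WITHOUT create_graph: the adaptation graph is discarded
    (g,) = torch.autograd.grad((theta - t) ** 2, theta)
    theta_prime = (theta - alpha * g).detach().requires_grad_(True)
    # Meta-gradient = query gradient at θ', treated as the gradient w.r.t. θ
    (meta_g,) = torch.autograd.grad((theta_prime - t) ** 2, theta_prime)
    return meta_g.item()

# For θ=1, t=0, α=0.1: θ' = 0.8 and the first-order meta-gradient is
# 2·θ' = 1.6, while the exact second-order MAML gradient would be 1.28.
print(fomaml_meta_grad(1.0, 0.0, 0.1))  # ≈ 1.6
```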

FOMAML Implementation #

python
class FOMAML:
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=outer_lr)
    
    def inner_loop(self, support_data, support_labels):
        fast_weights = [p.clone() for p in self.model.parameters()]
        
        for step in range(self.num_inner_steps):
            outputs = self.model.forward_with_weights(support_data, fast_weights)
            loss = nn.CrossEntropyLoss()(outputs, support_labels)
            
            # No create_graph: the inner-loop graph is discarded (first-order)
            grads = torch.autograd.grad(loss, fast_weights)
            
            fast_weights = [
                w - self.inner_lr * g 
                for w, g in zip(fast_weights, grads)
            ]
        
        return fast_weights
    
    def meta_train(self, tasks):
        self.meta_optimizer.zero_grad()
        
        meta_loss = 0.0
        
        for task in tasks:
            support_data, support_labels = task['support']
            query_data, query_labels = task['query']
            
            fast_weights = self.inner_loop(support_data, support_labels)
            
            query_outputs = self.model.forward_with_weights(query_data, fast_weights)
            task_loss = nn.CrossEntropyLoss()(query_outputs, query_labels)
            
            # First-order approximation: gradients w.r.t. the adapted weights
            # are applied directly to the initial parameters
            grads = torch.autograd.grad(task_loss, fast_weights)
            
            for i, p in enumerate(self.model.parameters()):
                if p.grad is None:
                    p.grad = grads[i]
                else:
                    p.grad += grads[i]
            
            meta_loss += task_loss.item()
        
        # Average the accumulated gradients over the task batch
        for p in self.model.parameters():
            if p.grad is not None:
                p.grad /= len(tasks)
        
        self.meta_optimizer.step()
        
        return meta_loss / len(tasks)

Reptile #

Reptile Core Idea #

text
┌─────────────────────────────────────────────────────────────┐
│                    Core Idea of Reptile                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Reptile is a simplified relative of MAML:                  │
│                                                             │
│  MAML:    optimize the initialization so that it performs   │
│           well after gradient updates                       │
│  Reptile: move the initialization directly toward the       │
│           task-adapted parameters                           │
│                                                             │
│  Algorithm:                                                 │
│  1. Sample a task Ti                                        │
│  2. Train on Ti for k steps to obtain θi                    │
│  3. Update the initialization: θ ← θ + ε(θi - θ)            │
│                                                             │
│  Pros:                                                      │
│  - No second-order gradients                                │
│  - Simple to implement                                      │
│  - Computationally efficient                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘
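Step 3 of the algorithm is just linear interpolation from the current initialization toward the task-adapted parameters; with made-up scalar numbers:

```python
# Illustrative numbers only: the Reptile meta-update interpolates between
# the initialization and the task-adapted parameters.
theta = 1.0      # current initialization (a scalar stands in for θ)
theta_i = 0.4    # parameters after k SGD steps on task Ti
eps = 0.5        # outer learning rate ε
theta = theta + eps * (theta_i - theta)
print(theta)  # ≈ 0.7 (halfway between 1.0 and 0.4)
```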

Reptile Algorithm #

text
Reptile algorithm:

Input:
- task distribution p(T)
- outer learning rate ε
- number of inner steps k
- initial parameters θ

for iteration in range(num_iterations):
    1. Sample a task Ti
    
    2. Copy the current parameters:
       θi = θ
    
    3. Train on task Ti for k steps:
       for j in range(k):
           θi = θi - α∇L(θi, Ti)
    
    4. Meta-update:
       θ = θ + ε(θi - θ)

Output: meta-learned parameters θ

Reptile Implementation #

python
import copy

class Reptile:
    def __init__(self, model, inner_lr=0.01, outer_lr=0.1, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
    
    def train_on_task(self, support_data, support_labels):
        temp_model = copy.deepcopy(self.model)
        optimizer = optim.SGD(temp_model.parameters(), lr=self.inner_lr)
        
        for _ in range(self.num_inner_steps):
            outputs = temp_model(support_data)
            loss = nn.CrossEntropyLoss()(outputs, support_labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        return temp_model
    
    def meta_train(self, tasks):
        task_weights = []
        
        for task in tasks:
            support_data, support_labels = task['support']
            temp_model = self.train_on_task(support_data, support_labels)
            task_weights.append(temp_model.state_dict())
        
        with torch.no_grad():
            # Reptile meta-update: move each parameter toward the average
            # of the task-adapted weights, θ ← θ + ε·mean(θi - θ)
            for name, param in self.model.named_parameters():
                meta_grad = torch.zeros_like(param)
                
                for weights in task_weights:
                    meta_grad += weights[name] - param
                
                meta_grad /= len(task_weights)
                param += self.outer_lr * meta_grad
    
    def adapt(self, support_data, support_labels, num_steps=10):
        temp_model = copy.deepcopy(self.model)
        optimizer = optim.SGD(temp_model.parameters(), lr=self.inner_lr)
        
        for _ in range(num_steps):
            outputs = temp_model(support_data)
            loss = nn.CrossEntropyLoss()(outputs, support_labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        return temp_model

# Example usage (sample_tasks, dataset, and num_epochs are placeholders
# defined elsewhere)
model = MAMLModel(input_dim=784, hidden_dim=64, output_dim=5)
reptile = Reptile(model, inner_lr=0.01, outer_lr=0.1, num_inner_steps=5)

for epoch in range(num_epochs):
    tasks = sample_tasks(dataset, num_tasks=32)
    reptile.meta_train(tasks)
    print(f"Epoch {epoch} completed")

Reptile vs MAML #

text
┌─────────────────────────────────────────────────────────────┐
│                       Reptile vs MAML                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Computational cost:                                        │
│  ├── MAML: O(n²) - needs second-order gradients             │
│  └── Reptile: O(n) - first-order gradients only             │
│                                                             │
│  Memory use:                                                │
│  ├── MAML: high - stores the inner-loop graph               │
│  └── Reptile: low - stores only the parameters              │
│                                                             │
│  Accuracy:                                                  │
│  ├── MAML: usually better                                   │
│  └── Reptile: slightly lower but close                      │
│                                                             │
│  Implementation effort:                                     │
│  ├── MAML: more complex                                     │
│  └── Reptile: simple                                        │
│                                                             │
│  Recommended for:                                           │
│  ├── MAML: when peak accuracy matters                       │
│  └── Reptile: limited compute or rapid prototyping          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Meta-SGD #

Meta-SGD Core Idea #

text
┌─────────────────────────────────────────────────────────────┐
│                    Core Idea of Meta-SGD                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  MAML's limitation: every parameter shares one learning rate│
│                                                             │
│  Meta-SGD's improvement: learn the learning rates too       │
│                                                             │
│  Update rule:                                               │
│  θi' = θ - α ⊙ ∇L(θ, Ti)                                    │
│  where α is a learnable vector (⊙ is the elementwise        │
│  product)                                                   │
│                                                             │
│  Pros:                                                      │
│  - Adaptive, per-parameter learning rates                   │
│  - More flexible adaptation                                 │
│  - Often outperforms plain MAML                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘
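The elementwise product in the update rule means each parameter entry gets its own learned step size; a minimal illustration with hand-picked tensors:

```python
import torch

# Hand-picked values: α is learned per element in Meta-SGD, so the same
# gradient can move different entries by different amounts.
theta = torch.tensor([1.0, 1.0])
grad = torch.tensor([2.0, 2.0])
alpha = torch.tensor([0.1, 0.5])   # one learned rate per element
theta_new = theta - alpha * grad   # elementwise update
print(theta_new)  # ≈ tensor([0.8, 0.0])
```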

Meta-SGD Implementation #

python
class MetaSGD:
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
        
        self.alpha = nn.Parameter(
            torch.ones_like(torch.cat([p.view(-1) for p in model.parameters()])) * inner_lr
        )
        
        self.meta_optimizer = optim.Adam(
            list(self.model.parameters()) + [self.alpha], 
            lr=outer_lr
        )
    
    def get_alpha_shapes(self):
        alphas = {}
        idx = 0
        for name, param in self.model.named_parameters():
            num_elements = param.numel()
            alphas[name] = self.alpha[idx:idx+num_elements].view(param.shape)
            idx += num_elements
        return alphas
    
    def inner_loop(self, support_data, support_labels):
        fast_weights = list(self.model.parameters())
        alphas = self.get_alpha_shapes()
        # Iterate in the same order as named_parameters() so each gradient
        # is paired with its own learned step size
        param_names = [name for name, _ in self.model.named_parameters()]
        
        for step in range(self.num_inner_steps):
            outputs = self.model.forward_with_weights(support_data, fast_weights)
            loss = nn.CrossEntropyLoss()(outputs, support_labels)
            
            grads = torch.autograd.grad(loss, fast_weights, create_graph=True)
            
            fast_weights = [
                w - alphas[name] * g 
                for w, g, name in zip(fast_weights, grads, param_names)
            ]
        
        return fast_weights
    
    def meta_train(self, tasks):
        self.meta_optimizer.zero_grad()
        
        meta_loss = 0.0
        
        for task in tasks:
            support_data, support_labels = task['support']
            query_data, query_labels = task['query']
            
            fast_weights = self.inner_loop(support_data, support_labels)
            
            query_outputs = self.model.forward_with_weights(query_data, fast_weights)
            task_loss = nn.CrossEntropyLoss()(query_outputs, query_labels)
            
            meta_loss += task_loss
        
        meta_loss /= len(tasks)
        
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item()

ANIL (Almost No Inner Loop) #

ANIL Core Idea #

text
┌─────────────────────────────────────────────────────────────┐
│                      Core Idea of ANIL                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Observation: MAML's inner loop mostly updates the head     │
│  (the classification layer)                                 │
│                                                             │
│  ANIL's changes:                                            │
│  - The inner loop updates only the head                     │
│  - The feature extractor stays fixed during adaptation      │
│  - Greatly reduces computation                              │
│                                                             │
│  Pros:                                                      │
│  - Computationally efficient                                │
│  - Low memory use                                           │
│  - Accuracy on par with MAML                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

ANIL Implementation #

python
class ANIL:
    def __init__(self, feature_extractor, head, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5):
        self.feature_extractor = feature_extractor
        self.head = head
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
        
        self.meta_optimizer = optim.Adam(
            list(self.feature_extractor.parameters()) + list(self.head.parameters()),
            lr=outer_lr
        )
    
    def inner_loop(self, support_data, support_labels):
        fast_head_weights = list(self.head.parameters())
        
        # Features are computed once and detached: the inner loop adapts only
        # the head (a common simplification; it also cuts the second-order
        # path through the support features)
        features = self.feature_extractor(support_data)
        features = features.detach()
        
        for step in range(self.num_inner_steps):
            outputs = self.head.forward_with_weights(features, fast_head_weights)
            loss = nn.CrossEntropyLoss()(outputs, support_labels)
            
            grads = torch.autograd.grad(loss, fast_head_weights, create_graph=True)
            
            fast_head_weights = [
                w - self.inner_lr * g 
                for w, g in zip(fast_head_weights, grads)
            ]
        
        return fast_head_weights
    
    def meta_train(self, tasks):
        self.meta_optimizer.zero_grad()
        
        meta_loss = 0.0
        
        for task in tasks:
            support_data, support_labels = task['support']
            query_data, query_labels = task['query']
            
            fast_head_weights = self.inner_loop(support_data, support_labels)
            
            query_features = self.feature_extractor(query_data)
            query_outputs = self.head.forward_with_weights(query_features, fast_head_weights)
            task_loss = nn.CrossEntropyLoss()(query_outputs, query_labels)
            
            meta_loss += task_loss
        
        meta_loss /= len(tasks)
        
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item()

Method Comparison #

Accuracy Comparison #

Method   | miniImageNet 5-way 1-shot | miniImageNet 5-way 5-shot | Cost  | Memory
---------|---------------------------|---------------------------|-------|-------
MAML     | 48.70%                    | 63.11%                    | O(n²) | High
FOMAML   | 48.07%                    | 63.15%                    | O(n)  | Low
Reptile  | 47.07%                    | 62.74%                    | O(n)  | Low
Meta-SGD | 50.47%                    | 64.03%                    | O(n²) | High
ANIL     | 49.3%                     | 63.8%                     | O(n)  | Low

(The memory column follows the cost column: second-order methods must keep the inner-loop graph in memory.)

Choosing a Method #

text
┌─────────────────────────────────────────────────────────────┐
│                      Choosing a Method                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Best accuracy:                                             │
│  └── Meta-SGD > MAML > ANIL                                 │
│                                                             │
│  Limited compute:                                           │
│  └── Reptile > ANIL > FOMAML                                │
│                                                             │
│  Rapid prototyping:                                         │
│  └── Reptile (simplest to implement)                        │
│                                                             │
│  Production deployment:                                     │
│  └── ANIL (balances accuracy and efficiency)                │
│                                                             │
│  Research baselines:                                        │
│  └── MAML (the standard benchmark)                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Next Steps #

Now that you have covered meta-learning methods, continue to Metric Learning Methods to explore another important family of few-shot learning approaches!

Last updated: 2026-04-05