元学习方法 #
元学习概述 #
什么是元学习? #
元学习(Meta-Learning)是 Few-shot Learning 的核心技术之一,其核心思想是"学习如何学习"。
text
┌─────────────────────────────────────────────────────────────┐
│ 元学习核心思想 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 传统学习: │
│ 在特定任务上学习 → 得到该任务的模型 │
│ │
│ 元学习: │
│ 在多个任务上学习 → 得到能快速适应新任务的模型 │
│ │
│ 目标: │
│ 学习一个初始模型,使其能够通过少量梯度更新快速适应新任务 │
│ │
└─────────────────────────────────────────────────────────────┘
元学习的分类 #
text
┌─────────────────────────────────────────────────────────────┐
│ 元学习方法分类 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 基于优化的方法 │
│ ├── MAML (Model-Agnostic Meta-Learning) │
│ ├── Meta-SGD │
│ ├── Reptile │
│ └── 学习优化策略 │
│ │
│ 2. 基于模型的方法 │
│ ├── MANN (Memory-Augmented Neural Networks) │
│ ├── Meta Networks │
│ └── 设计快速适应的架构 │
│ │
│ 3. 基于度量的方法 │
│ ├── Siamese Networks │
│ ├── Prototypical Networks │
│ └── 学习相似度度量 │
│ │
└─────────────────────────────────────────────────────────────┘
MAML (Model-Agnostic Meta-Learning) #
MAML 核心思想 #
text
┌─────────────────────────────────────────────────────────────┐
│ MAML 核心思想 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 目标:找到一个好的初始化参数 θ │
│ 使得对于新任务 Ti,只需几步梯度更新就能达到好的性能 │
│ │
│ 过程: │
│ 1. 从初始化参数 θ 开始 │
│ 2. 在任务 Ti 上进行几步梯度更新 │
│ θi' = θ - α∇L(θ, Ti) │
│ 3. 在多个任务上优化 θ │
│ θ ← θ - β∇ΣL(θi', Ti) │
│ │
│ 关键:二阶梯度(梯度的梯度) │
│ │
└─────────────────────────────────────────────────────────────┘
MAML 算法流程 #
text
MAML 算法流程:
输入:
- 任务分布 p(T)
- 学习率 α, β
- 初始化参数 θ
for iteration in range(num_iterations):
1. 采样一批任务: T = {T1, T2, ..., Tn}
2. 对每个任务 Ti:
a. 采样 Support Set: Di^s
b. 采样 Query Set: Di^q
c. 计算梯度并更新:
θi' = θ - α∇θL(θ, Di^s)
d. 计算 Query Set 上的损失:
Li = L(θi', Di^q)
3. 元更新:
θ ← θ - β∇θΣLi
输出:元学习后的参数 θ
MAML 代码实现 #
python
import torch
import torch.nn as nn
import torch.optim as optim
class MAML:
    """Model-Agnostic Meta-Learning (Finn et al., 2017).

    Learns an initialization of ``model``'s parameters such that a few
    gradient steps on a new task's support set give good query-set
    performance.  ``model`` must expose ``forward_with_weights(x, weights)``
    that runs the forward pass with an explicit weight list ordered like
    ``model.parameters()``.
    """

    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5):
        # inner_lr: task-level (inner-loop) step size alpha
        # outer_lr: meta-level (outer-loop) step size beta
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=outer_lr)
        # Reused criterion (was re-instantiated on every forward pass).
        self.criterion = nn.CrossEntropyLoss()

    def inner_loop(self, support_data, support_labels):
        """Adapt on the support set and return the adapted weight list.

        ``create_graph=True`` keeps the graph of the inner updates so the
        outer backward pass can propagate second-order gradients.
        """
        fast_weights = list(self.model.parameters())
        for _ in range(self.num_inner_steps):
            outputs = self.model.forward_with_weights(support_data, fast_weights)
            loss = self.criterion(outputs, support_labels)
            grads = torch.autograd.grad(loss, fast_weights, create_graph=True)
            fast_weights = [w - self.inner_lr * g
                            for w, g in zip(fast_weights, grads)]
        return fast_weights

    def meta_train(self, tasks):
        """One meta-update over a batch of tasks; returns the mean query loss.

        Each task is a dict with ``'support'`` and ``'query'`` entries, both
        ``(data, labels)`` tuples.
        """
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0
        for task in tasks:
            support_data, support_labels = task['support']
            query_data, query_labels = task['query']
            fast_weights = self.inner_loop(support_data, support_labels)
            query_outputs = self.model.forward_with_weights(query_data, fast_weights)
            meta_loss += self.criterion(query_outputs, query_labels)
        meta_loss /= len(tasks)
        meta_loss.backward()  # second-order gradients flow through inner_loop
        self.meta_optimizer.step()
        return meta_loss.item()

    def adapt(self, support_data, support_labels, num_steps=10):
        """Adapt to a new task at test time; returns the adapted weights.

        Fix: the original built an SGD optimizer over ``fast_weights`` and
        never used it (updates were applied manually) — the dead optimizer
        is removed.  Weights are detached up front and after each step so
        test-time adaptation does not accumulate an autograd graph back to
        the meta-parameters.
        """
        fast_weights = [p.detach().clone().requires_grad_(True)
                        for p in self.model.parameters()]
        for _ in range(num_steps):
            outputs = self.model.forward_with_weights(support_data, fast_weights)
            loss = self.criterion(outputs, support_labels)
            grads = torch.autograd.grad(loss, fast_weights)
            fast_weights = [(w - self.inner_lr * g).detach().requires_grad_(True)
                            for w, g in zip(fast_weights, grads)]
        return fast_weights
class MAMLModel(nn.Module):
    """Two-layer MLP whose forward pass can also be run with an externally
    supplied weight list, as required by MAML-style inner loops."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = torch.relu(hidden)
        return self.fc2(hidden)

    def forward_with_weights(self, x, weights):
        # weights = [fc1.weight, fc1.bias, fc2.weight, fc2.bias],
        # i.e. the order produced by self.parameters().
        w1, b1, w2, b2 = weights
        hidden = torch.relu(torch.nn.functional.linear(x, w1, b1))
        return torch.nn.functional.linear(hidden, w2, b2)
# --- Example usage ---------------------------------------------------------
# Fix: `num_epochs` was undefined (NameError at run time).
# `sample_tasks(dataset, num_tasks)` and `dataset` are user-supplied:
# the sampler must return a list of task dicts with 'support' and 'query'
# (data, labels) tuples.
num_epochs = 100

model = MAMLModel(input_dim=784, hidden_dim=64, output_dim=5)
maml = MAML(model, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5)

for epoch in range(num_epochs):
    tasks = sample_tasks(dataset, num_tasks=32)  # user-provided task sampler
    loss = maml.meta_train(tasks)
    print(f"Epoch {epoch}, Meta Loss: {loss:.4f}")
MAML 的优缺点 #
text
┌─────────────────────────────────────────────────────────────┐
│ MAML 优缺点 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 优点: │
│ ✅ 模型无关,适用于任何模型 │
│ ✅ 理论基础扎实 │
│ ✅ 适应速度快 │
│ ✅ 效果好,广泛应用 │
│ │
│ 缺点: │
│ ⚠️ 计算二阶梯度,内存消耗大 │
│ ⚠️ 训练时间长 │
│ ⚠️ 需要仔细调整超参数 │
│ ⚠️ 可能不稳定 │
│ │
└─────────────────────────────────────────────────────────────┘
First-Order MAML (FOMAML) #
FOMAML 核心思想 #
text
┌─────────────────────────────────────────────────────────────┐
│ FOMAML 核心思想 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 问题:MAML 需要计算二阶梯度,计算量大 │
│ │
│ 解决方案:忽略二阶梯度,只使用一阶近似 │
│ │
│ MAML: θ ← θ - β∇θΣL(θ - α∇θL(θ, Di^s), Di^q) │
│   FOMAML: θ ← θ - β∇θ'ΣL(θ', Di^q)                         │
│ where θ' = θ - α∇θL(θ, Di^s) │
│ │
│ 优点:计算简单,速度快 │
│ 缺点:可能损失一些性能 │
│ │
└─────────────────────────────────────────────────────────────┘
FOMAML 代码实现 #
python
class FOMAML:
    """First-Order MAML: identical to MAML except second-order gradients
    are dropped — the query-set gradient w.r.t. the adapted weights is
    applied directly to the initial parameters."""

    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=outer_lr)
        # Reused criterion (was re-instantiated on every forward pass).
        self.criterion = nn.CrossEntropyLoss()

    def inner_loop(self, support_data, support_labels):
        """Adapt on the support set; returns the adapted weight list.

        Fix: the cloned weights are detached from the meta-parameters.
        The original kept the autograd graph from ``clone()`` back to
        ``model.parameters()``, wasting memory across inner steps and
        undermining the point of the first-order approximation.
        """
        fast_weights = [p.detach().clone().requires_grad_(True)
                        for p in self.model.parameters()]
        for _ in range(self.num_inner_steps):
            outputs = self.model.forward_with_weights(support_data, fast_weights)
            loss = self.criterion(outputs, support_labels)
            # No create_graph: first-order only.
            grads = torch.autograd.grad(loss, fast_weights)
            fast_weights = [w - self.inner_lr * g
                            for w, g in zip(fast_weights, grads)]
        return fast_weights

    def meta_train(self, tasks):
        """One first-order meta-update; returns the mean query loss (float)."""
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0
        for task in tasks:
            support_data, support_labels = task['support']
            query_data, query_labels = task['query']
            fast_weights = self.inner_loop(support_data, support_labels)
            query_outputs = self.model.forward_with_weights(query_data, fast_weights)
            task_loss = self.criterion(query_outputs, query_labels)
            # FOMAML approximation: the gradient w.r.t. the adapted weights
            # is accumulated as if it were the gradient w.r.t. the initial
            # parameters.
            grads = torch.autograd.grad(task_loss, fast_weights)
            for p, g in zip(self.model.parameters(), grads):
                if p.grad is None:
                    p.grad = g.detach().clone()
                else:
                    p.grad += g
            meta_loss += task_loss.item()
        # Average accumulated gradients over the task batch.
        for p in self.model.parameters():
            if p.grad is not None:
                p.grad /= len(tasks)
        self.meta_optimizer.step()
        return meta_loss / len(tasks)
Reptile #
Reptile 核心思想 #
text
┌─────────────────────────────────────────────────────────────┐
│ Reptile 核心思想 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Reptile 是 MAML 的简化版本: │
│ │
│ MAML: 优化初始参数,使得梯度更新后效果好 │
│ Reptile: 直接向任务更新后的参数移动 │
│ │
│ 算法: │
│ 1. 采样任务 Ti │
│ 2. 在 Ti 上训练 k 步,得到参数 θi │
│ 3. 更新初始参数:θ ← θ + ε(θi - θ) │
│ │
│ 优点: │
│ - 不需要二阶梯度 │
│ - 实现简单 │
│ - 计算效率高 │
│ │
└─────────────────────────────────────────────────────────────┘
Reptile 算法流程 #
text
Reptile 算法流程:
输入:
- 任务分布 p(T)
- 学习率 ε
- 内循环步数 k
- 初始化参数 θ
for iteration in range(num_iterations):
1. 采样任务 Ti
2. 复制当前参数:
θi = θ
3. 在任务 Ti 上训练 k 步:
for j in range(k):
θi = θi - α∇L(θi, Ti)
4. 元更新:
θ = θ + ε(θi - θ)
输出:元学习后的参数 θ
Reptile 代码实现 #
python
import copy
class Reptile:
    """Reptile meta-learner (Nichol et al., 2018).

    Trains an independent copy of the model normally on each sampled task,
    then moves the meta-parameters a fraction ``outer_lr`` of the way toward
    the task-adapted parameters.  No second-order gradients are needed.
    """

    def __init__(self, model, inner_lr=0.01, outer_lr=0.1, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps

    def _sgd_fit(self, clone, data, labels, steps):
        # Plain SGD training of a detached model copy; returns the copy.
        criterion = nn.CrossEntropyLoss()
        opt = optim.SGD(clone.parameters(), lr=self.inner_lr)
        for _ in range(steps):
            step_loss = criterion(clone(data), labels)
            opt.zero_grad()
            step_loss.backward()
            opt.step()
        return clone

    def train_on_task(self, support_data, support_labels):
        """Return a deep copy of the model trained on one task's support set."""
        return self._sgd_fit(copy.deepcopy(self.model), support_data,
                             support_labels, self.num_inner_steps)

    def meta_train(self, tasks):
        """One Reptile meta-update: theta += eps * mean_i(theta_i - theta)."""
        snapshots = [self.train_on_task(*task['support']).state_dict()
                     for task in tasks]
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                direction = torch.zeros_like(param)
                for sd in snapshots:
                    direction += sd[name] - param
                direction /= len(snapshots)
                param += self.outer_lr * direction

    def adapt(self, support_data, support_labels, num_steps=10):
        """Adapt a copy of the meta-model to a new task and return it."""
        return self._sgd_fit(copy.deepcopy(self.model), support_data,
                             support_labels, num_steps)
# --- Example usage ---------------------------------------------------------
# Fix: `num_epochs` was undefined (NameError at run time).
# `sample_tasks(dataset, num_tasks)` and `dataset` are user-supplied; each
# task dict needs at least a 'support' (data, labels) tuple.
num_epochs = 100

model = MAMLModel(input_dim=784, hidden_dim=64, output_dim=5)
reptile = Reptile(model, inner_lr=0.01, outer_lr=0.1, num_inner_steps=5)

for epoch in range(num_epochs):
    tasks = sample_tasks(dataset, num_tasks=32)  # user-provided task sampler
    reptile.meta_train(tasks)
    print(f"Epoch {epoch} completed")
Reptile vs MAML #
text
┌─────────────────────────────────────────────────────────────┐
│ Reptile vs MAML │
├─────────────────────────────────────────────────────────────┤
│ │
│ 计算复杂度: │
│ ├── MAML: O(n²) - 需要二阶梯度 │
│ └── Reptile: O(n) - 只需要一阶梯度 │
│ │
│ 内存消耗: │
│ ├── MAML: 高 - 需要存储计算图 │
│ └── Reptile: 低 - 只需要存储参数 │
│ │
│ 性能: │
│ ├── MAML: 通常更好 │
│ └── Reptile: 略低但接近 │
│ │
│ 实现难度: │
│ ├── MAML: 较复杂 │
│ └── Reptile: 简单 │
│ │
│ 推荐场景: │
│ ├── MAML: 追求最佳性能 │
│ └── Reptile: 计算资源有限或快速原型 │
│ │
└─────────────────────────────────────────────────────────────┘
Meta-SGD #
Meta-SGD 核心思想 #
text
┌─────────────────────────────────────────────────────────────┐
│ Meta-SGD 核心思想 │
├─────────────────────────────────────────────────────────────┤
│ │
│ MAML 的问题:所有参数使用相同的学习率 │
│ │
│ Meta-SGD 的改进:学习自适应的学习率 │
│ │
│ 更新规则: │
│ θi' = θ - α ⊙ ∇L(θ, Ti) │
│ 其中 α 是可学习的参数向量(逐元素乘法) │
│ │
│ 优点: │
│ - 自适应学习率 │
│ - 更灵活的适应 │
│ - 通常比 MAML 性能更好 │
│ │
└─────────────────────────────────────────────────────────────┘
Meta-SGD 代码实现 #
python
class MetaSGD:
    """Meta-SGD (Li et al., 2017): MAML with a learned per-parameter
    inner-loop learning rate ``alpha``, optimized jointly with the model."""

    def __init__(self, model, inner_lr=0.01, outer_lr=0.001, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
        # One learnable step size per scalar parameter, initialized to
        # inner_lr, laid out flat in model.parameters() order.
        total = sum(p.numel() for p in model.parameters())
        self.alpha = nn.Parameter(torch.full((total,), float(inner_lr)))
        self.meta_optimizer = optim.Adam(
            list(self.model.parameters()) + [self.alpha],
            lr=outer_lr,
        )
        self.criterion = nn.CrossEntropyLoss()

    def get_alpha_shapes(self):
        """Return {param_name: alpha slice viewed with the parameter's shape}.

        Fix: iterates ``named_parameters()`` — matching the flat ``alpha``
        layout — instead of ``state_dict()``, which may also contain buffers
        and would misalign the slices.
        """
        alphas = {}
        offset = 0
        for name, param in self.model.named_parameters():
            count = param.numel()
            alphas[name] = self.alpha[offset:offset + count].view(param.shape)
            offset += count
        return alphas

    def inner_loop(self, support_data, support_labels):
        """Adapt with the learned per-parameter step sizes; returns weights.

        Fix: the update zips gradients with ``named_parameters()`` names;
        the original zipped against ``state_dict().keys()``, whose order can
        diverge from ``parameters()`` (buffers interleave), silently pairing
        gradients with the wrong alpha slices.
        """
        names = [name for name, _ in self.model.named_parameters()]
        fast_weights = list(self.model.parameters())
        alphas = self.get_alpha_shapes()
        for _ in range(self.num_inner_steps):
            outputs = self.model.forward_with_weights(support_data, fast_weights)
            loss = self.criterion(outputs, support_labels)
            # create_graph=True so meta-gradients also flow into alpha.
            grads = torch.autograd.grad(loss, fast_weights, create_graph=True)
            fast_weights = [w - alphas[name] * g
                            for w, g, name in zip(fast_weights, grads, names)]
        return fast_weights

    def meta_train(self, tasks):
        """One meta-update of model parameters and alpha; returns mean loss."""
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0
        for task in tasks:
            support_data, support_labels = task['support']
            query_data, query_labels = task['query']
            fast_weights = self.inner_loop(support_data, support_labels)
            query_outputs = self.model.forward_with_weights(query_data, fast_weights)
            meta_loss += self.criterion(query_outputs, query_labels)
        meta_loss /= len(tasks)
        meta_loss.backward()
        self.meta_optimizer.step()
        return meta_loss.item()
ANIL (Almost No Inner Loop) #
ANIL 核心思想 #
text
┌─────────────────────────────────────────────────────────────┐
│ ANIL 核心思想 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 观察:MAML 的内循环主要更新头部(分类层) │
│ │
│ ANIL 的改进: │
│ - 内循环只更新头部 │
│ - 特征提取器保持不变 │
│ - 大大减少计算量 │
│ │
│ 优点: │
│ - 计算效率高 │
│ - 内存消耗少 │
│ - 性能与 MAML 相当 │
│ │
└─────────────────────────────────────────────────────────────┘
ANIL 代码实现 #
python
class ANIL:
    """ANIL (Raghu et al., 2020): a MAML variant whose inner loop adapts
    only the classifier head; the feature extractor is updated solely by
    the outer (meta) optimization."""

    def __init__(self, feature_extractor, head, inner_lr=0.01, outer_lr=0.001,
                 num_inner_steps=5):
        self.feature_extractor = feature_extractor
        self.head = head
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.num_inner_steps = num_inner_steps
        all_params = list(self.feature_extractor.parameters())
        all_params += list(self.head.parameters())
        self.meta_optimizer = optim.Adam(all_params, lr=outer_lr)

    def inner_loop(self, support_data, support_labels):
        """Adapt only the head on the support set; returns adapted head weights."""
        weights = list(self.head.parameters())
        # Features are computed once and detached: the extractor stays
        # frozen throughout inner-loop adaptation.
        feats = self.feature_extractor(support_data)
        feats = feats.detach()
        for _ in range(self.num_inner_steps):
            logits = self.head.forward_with_weights(feats, weights)
            step_loss = nn.CrossEntropyLoss()(logits, support_labels)
            grads = torch.autograd.grad(step_loss, weights, create_graph=True)
            weights = [w - self.inner_lr * g for w, g in zip(weights, grads)]
        return weights

    def meta_train(self, tasks):
        """One meta-update over a batch of tasks; returns the mean query loss."""
        self.meta_optimizer.zero_grad()
        total = 0.0
        for task in tasks:
            s_data, s_labels = task['support']
            q_data, q_labels = task['query']
            head_weights = self.inner_loop(s_data, s_labels)
            q_feats = self.feature_extractor(q_data)
            q_logits = self.head.forward_with_weights(q_feats, head_weights)
            total += nn.CrossEntropyLoss()(q_logits, q_labels)
        total /= len(tasks)
        total.backward()
        self.meta_optimizer.step()
        return total.item()
方法对比 #
性能对比 #
| 方法 | miniImageNet 5-way 1-shot | miniImageNet 5-way 5-shot | 计算复杂度 | 内存消耗 |
|---|---|---|---|---|
| MAML | 48.70% | 63.11% | O(n²) | 高 |
| FOMAML | 48.07% | 63.15% | O(n) | 中 |
| Reptile | 47.07% | 62.74% | O(n) | 低 |
| Meta-SGD | 50.47% | 64.03% | O(n²) | 高 |
| ANIL | 49.3% | 63.8% | O(n) | 低 |
选择建议 #
text
┌─────────────────────────────────────────────────────────────┐
│ 方法选择建议 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 追求最佳性能: │
│ └── Meta-SGD > MAML > ANIL │
│ │
│ 计算资源有限: │
│ └── Reptile > ANIL > FOMAML │
│ │
│ 快速原型开发: │
│ └── Reptile(实现最简单) │
│ │
│ 生产环境部署: │
│ └── ANIL(平衡性能和效率) │
│ │
│ 研究实验: │
│ └── MAML(标准基准) │
│ │
└─────────────────────────────────────────────────────────────┘
下一步 #
现在你已经掌握了元学习方法,接下来学习 度量学习方法,了解另一类重要的 Few-shot Learning 方法!
最后更新:2026-04-05