Advanced Few-shot Learning Architectures #
Transformer-based Methods #
Transformers in Few-shot Learning #
text
┌─────────────────────────────────────────────────────────────┐
│                  Transformer-based Few-shot                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Core idea:                                                 │
│  Use the Transformer's attention mechanism to process the   │
│  Support Set and Query Set jointly                          │
│                                                             │
│  Advantages:                                                │
│  ├── Global modeling capacity                               │
│  ├── Flexible attention mechanism                           │
│  ├── Handles variable-length input                          │
│  └── Strong representational power                          │
│                                                             │
│  Representative methods:                                    │
│  ├── FEAT (Few-shot Embedding Adaptation with Transformer)  │
│  ├── TADAM (Task Dependent Adaptive Metric)                 │
│  └── MetaOptNet                                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘
FEAT Architecture #
python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FEAT(nn.Module):
    def __init__(self, encoder, feature_dim=640, num_heads=8):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        # Set-to-set function: a single Transformer encoder layer that adapts
        # the embeddings to the current task
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=feature_dim,
                nhead=num_heads,
                dim_feedforward=feature_dim * 4,
                dropout=0.1
            ),
            num_layers=1
        )
        # Learnable temperature for the distance-based logits
        self.scale = nn.Parameter(torch.ones(1) * 10.0)

    def forward(self, support_images, support_labels, query_images):
        support_features = self.encoder(support_images)  # (N_s, D)
        query_features = self.encoder(query_images)      # (N_q, D)
        # Concatenate support and query into one sequence of length N_s + N_q
        # with batch size 1 (nn.Transformer expects (seq, batch, dim) by default)
        all_features = torch.cat([support_features, query_features], dim=0).unsqueeze(1)
        transformed = self.transformer(all_features).squeeze(1)
        support_transformed = transformed[:support_features.size(0)]
        query_transformed = transformed[support_features.size(0):]
        prototypes = self.compute_prototypes(support_transformed, support_labels)
        distances = self.compute_distances(query_transformed, prototypes)
        logits = -distances * self.scale
        return logits

    def compute_prototypes(self, features, labels):
        # One prototype per class: the mean of that class's adapted embeddings
        prototypes = []
        for c in torch.unique(labels):
            prototypes.append(features[labels == c].mean(dim=0))
        return torch.stack(prototypes)

    def compute_distances(self, query_features, prototypes):
        # Euclidean distance between each query embedding and each prototype
        return torch.cdist(query_features, prototypes)
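A minimal usage sketch for a single N-way K-shot episode. The `backbone` encoder, image sizes, and episode layout here are illustrative assumptions, not part of the original code:
python
# Hypothetical 5-way 1-shot episode with 15 queries per class; `backbone`
# is assumed to map images to 640-d vectors.
model = FEAT(encoder=backbone, feature_dim=640, num_heads=8)

support_images = torch.randn(5, 3, 84, 84)   # one image per class
support_labels = torch.arange(5)             # class index of each support image
query_images = torch.randn(75, 3, 84, 84)

logits = model(support_images, support_labels, query_images)  # (75, 5)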
Set-to-Set Functions #
text
┌─────────────────────────────────────────────────────────────┐
│                    Set-to-Set Functions                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Problem: how to model the relationship between the         │
│  Support Set and the Query Set?                             │
│                                                             │
│  Solutions:                                                 │
│                                                             │
│  1. Set-to-Set function                                     │
│     f: (Support Set, Query) → Prediction                    │
│     - Use the whole Support Set as context                  │
│     - Aggregate information with attention                  │
│                                                             │
│  2. Cross-Attention                                         │
│     Query attends to Support Set                            │
│     - The query acts as the attention Query                 │
│     - The Support Set provides the Keys and Values          │
│                                                             │
│  3. Self-Attention                                          │
│     Interaction within the Support Set                      │
│     - Learns relations between samples                      │
│     - Strengthens feature representations                   │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Attention-based Enhancements #
Cross-Attention Mechanism #
python
class CrossAttention(nn.Module):
    def __init__(self, feature_dim, num_heads=8):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        self.query_proj = nn.Linear(feature_dim, feature_dim)
        self.key_proj = nn.Linear(feature_dim, feature_dim)
        self.value_proj = nn.Linear(feature_dim, feature_dim)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        # Scale by the per-head dimension, as in standard multi-head attention
        self.scale = self.head_dim ** -0.5

    def forward(self, query_features, support_features):
        # query_features: (N_q, D), support_features: (N_s, D)
        Q = self.query_proj(query_features)
        K = self.key_proj(support_features)
        V = self.value_proj(support_features)
        # Split into heads: (num_heads, N, head_dim)
        Q = Q.view(Q.size(0), self.num_heads, -1).transpose(0, 1)
        K = K.view(K.size(0), self.num_heads, -1).transpose(0, 1)
        V = V.view(V.size(0), self.num_heads, -1).transpose(0, 1)
        # Each query attends over all support samples
        attention = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attention = F.softmax(attention, dim=-1)
        output = torch.matmul(attention, V)
        # Merge heads back: (N_q, D)
        output = output.transpose(0, 1).contiguous().view(query_features.size(0), -1)
        output = self.output_proj(output)
        # Residual connection keeps the original query information
        return output + query_features

class AttentionBasedFewShot(nn.Module):
    def __init__(self, encoder, num_classes, feature_dim=640, num_heads=8):
        super().__init__()
        self.encoder = encoder
        self.num_classes = num_classes
        self.cross_attention = CrossAttention(feature_dim, num_heads)

    def forward(self, support_images, support_labels, query_images):
        support_features = self.encoder(support_images)
        query_features = self.encoder(query_images)
        prototypes = self.compute_prototypes(support_features, support_labels)
        # Enhance query embeddings by attending to the support set
        query_enhanced = self.cross_attention(query_features, support_features)
        distances = torch.cdist(query_enhanced, prototypes)
        logits = -distances
        return logits

    def compute_prototypes(self, features, labels):
        prototypes = []
        for i in range(self.num_classes):
            mask = labels == i
            prototypes.append(features[mask].mean(dim=0))
        return torch.stack(prototypes)
Self-Attention Enhancement #
python
class SelfAttentionEnhanced(nn.Module):
    def __init__(self, feature_dim, num_heads=8):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=0.1
        )
        self.norm = nn.LayerNorm(feature_dim)

    def forward(self, features):
        # features: (N, D). Treat the N samples as a length-N sequence with
        # batch size 1 so every sample can attend to every other sample.
        features = features.unsqueeze(1)  # (N, 1, D)
        attn_output, _ = self.self_attention(features, features, features)
        # Residual connection + LayerNorm, as in a Transformer block
        output = self.norm(features + attn_output)
        return output.squeeze(1)
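A brief usage sketch (the shapes are illustrative):
python
# Hypothetical: enhance 25 support embeddings (5-way 5-shot) of dimension 640
enhance = SelfAttentionEnhanced(feature_dim=640, num_heads=8)
support_features = torch.randn(25, 640)
enhanced = enhance(support_features)  # (25, 640), each sample mixed with the others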
Cross-domain Few-shot Learning #
Cross-domain Problem Definition #
text
┌─────────────────────────────────────────────────────────────┐
│                    Cross-domain Few-shot                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Problem:                                                   │
│  Training and test domains differ, with large gaps          │
│  between their feature distributions                        │
│                                                             │
│  Examples:                                                  │
│  ├── Train: natural photos → Test: sketches                 │
│  ├── Train: ImageNet       → Test: medical images           │
│  └── Train: text           → Test: images                   │
│                                                             │
│  Challenges:                                                │
│  ├── Feature distribution mismatch                          │
│  ├── Domain shift                                           │
│  └── Risk of negative transfer                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Domain Adaptation Methods #
python
class DomainAdaptiveFewShot(nn.Module):
    def __init__(self, encoder, num_classes, feature_dim=640):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        # Binary discriminator: predicts whether a feature is source (0) or target (1)
        self.domain_discriminator = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Linear(feature_dim // 2, 1),
            nn.Sigmoid()
        )
        self.class_classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, images, domain_label=None):
        features = self.encoder(images)
        if domain_label is not None:
            domain_pred = self.domain_discriminator(features)
            domain_loss = nn.BCELoss()(domain_pred, domain_label)
        else:
            domain_loss = torch.tensor(0.0, device=features.device)
        class_pred = self.class_classifier(features)
        return class_pred, domain_loss

    def get_adversarial_features(self, images, lambda_=1.0):
        # Perturb the features against the discriminator's gradient
        features = self.encoder(images)
        features_adv = features.detach().clone().requires_grad_(True)
        domain_pred = self.domain_discriminator(features_adv)
        domain_loss = -torch.log(domain_pred + 1e-8).mean()
        domain_loss.backward()
        features_adv = features - lambda_ * features_adv.grad
        return features_adv

def train_domain_adaptive(model, source_data, target_data, num_epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(num_epochs):
        for (source_images, source_labels), (target_images, _) in zip(source_data, target_data):
            source_domain = torch.zeros(source_images.size(0), 1)
            target_domain = torch.ones(target_images.size(0), 1)
            class_pred_source, domain_loss_source = model(source_images, source_domain)
            _, domain_loss_target = model(target_images, target_domain)
            class_loss = nn.CrossEntropyLoss()(class_pred_source, source_labels)
            # Note: adding the discriminator loss directly trains the discriminator;
            # making the encoder domain-invariant additionally needs a gradient
            # reversal layer (DANN-style), sketched below.
            domain_loss = domain_loss_source + domain_loss_target
            total_loss = class_loss + 0.1 * domain_loss
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
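The loop above only trains the discriminator. To actively push the encoder toward domain-invariant features, DANN-style training inserts a gradient reversal layer between encoder and discriminator. A minimal sketch of such a layer (our own addition, not part of the original code):
python
import torch

class GradReverse(torch.autograd.Function):
    # Identity in the forward pass; flips and scales the gradient going backward
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # The reversed gradient flows back into the encoder
        return -ctx.lambda_ * grad_output, None

def grad_reverse(x, lambda_=1.0):
    return GradReverse.apply(x, lambda_)

# Hypothetical usage inside the model's forward:
#   domain_pred = self.domain_discriminator(grad_reverse(features, lambda_))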
Domain Generalization Methods #
python
class DomainGeneralizationFewShot(nn.Module):
    def __init__(self, encoder, num_classes, feature_dim=640, num_domains=4):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        # One BatchNorm per training domain, plus a shared one for unseen domains
        self.domain_specific_bn = nn.ModuleList([
            nn.BatchNorm1d(feature_dim) for _ in range(num_domains)
        ])
        self.shared_bn = nn.BatchNorm1d(feature_dim)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, images, domain_id=None):
        features = self.encoder(images)
        if domain_id is not None and self.training:
            features = self.domain_specific_bn[domain_id](features)
        else:
            features = self.shared_bn(features)
        class_pred = self.classifier(features)
        return class_pred

class MMDRegularization(nn.Module):
    # First-moment simplification of MMD: match the mean features of the two domains
    def forward(self, source_features, target_features):
        source_mean = source_features.mean(dim=0)
        target_mean = target_features.mean(dim=0)
        mmd_loss = torch.norm(source_mean - target_mean, p=2)
        return mmd_loss
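The regularizer above matches only the first moments of the two feature distributions. Full MMD compares distributions in a kernel space; a common biased estimator with an RBF kernel (our own hedged sketch; the bandwidth `sigma` is an assumed hyperparameter) looks like:
python
def rbf_mmd(x, y, sigma=1.0):
    # Biased MMD^2 estimate with k(a, b) = exp(-||a - b||^2 / (2 * sigma^2))
    def kernel(a, b):
        d2 = torch.cdist(a, b).pow(2)
        return torch.exp(-d2 / (2 * sigma ** 2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()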
Multimodal Few-shot Learning #
Multimodal Fusion #
text
┌─────────────────────────────────────────────────────────────┐
│                     Multimodal Few-shot                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Modalities:                                                │
│  ├── Images                                                 │
│  ├── Text                                                   │
│  ├── Audio                                                  │
│  └── Video                                                  │
│                                                             │
│  Fusion strategies:                                         │
│  ├── Early fusion: feature-level fusion                     │
│  ├── Late fusion: decision-level fusion                     │
│  └── Hybrid fusion: fusion at multiple levels               │
│                                                             │
│  Application scenarios:                                     │
│  ├── Image-text Few-shot                                    │
│  ├── Audio-video Few-shot                                   │
│  └── Cross-modal retrieval                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Image-Text Few-shot #
python
class ImageTextFewShot(nn.Module):
    def __init__(self, image_encoder, text_encoder, feature_dim=512):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        # Project both modalities into a shared embedding space
        self.image_proj = nn.Linear(image_encoder.output_dim, feature_dim)
        self.text_proj = nn.Linear(text_encoder.output_dim, feature_dim)
        self.temperature = nn.Parameter(torch.ones(1) * 0.07)

    def forward(self, images, texts):
        image_features = self.image_proj(self.image_encoder(images))
        text_features = self.text_proj(self.text_encoder(texts))
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        # Cosine-similarity logits between every image and every text
        logits = torch.matmul(image_features, text_features.t()) / self.temperature
        return logits

    def compute_loss(self, images, texts):
        # Symmetric contrastive loss: matched (image, text) pairs lie on the diagonal
        logits = self.forward(images, texts)
        labels = torch.arange(images.size(0)).to(images.device)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)
        return (loss_i2t + loss_t2i) / 2

class CLIPStyleFewShot(nn.Module):
    def __init__(self, clip_model, feature_dim=512):
        super().__init__()
        self.clip = clip_model
        self.feature_dim = feature_dim
        # Learnable prompts (see the PromptLearner sketch below)
        self.prompt_learner = PromptLearner(feature_dim)

    def forward(self, support_images, support_labels, support_texts, query_images):
        support_image_features = F.normalize(self.clip.encode_image(support_images), dim=-1)
        query_image_features = F.normalize(self.clip.encode_image(query_images), dim=-1)
        support_text_features = F.normalize(self.clip.encode_text(support_texts), dim=-1)
        # Average image and text evidence per support sample before building prototypes
        support_features = (support_image_features + support_text_features) / 2
        prototypes = self.compute_prototypes(support_features, support_labels)
        logits = torch.matmul(query_image_features, prototypes.t())
        return logits

    def compute_prototypes(self, features, labels):
        prototypes = []
        for c in torch.unique(labels):
            prototypes.append(features[labels == c].mean(dim=0))
        return torch.stack(prototypes)
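`PromptLearner` is referenced above but never defined in this section. A minimal CoOp-style sketch of what might be intended (entirely our own assumption): a set of learnable context vectors prepended to each class's token embeddings before text encoding.
python
class PromptLearner(nn.Module):
    # Hypothetical minimal version: n_ctx learnable context vectors shared by all classes
    def __init__(self, feature_dim, n_ctx=16):
        super().__init__()
        self.context = nn.Parameter(torch.randn(n_ctx, feature_dim) * 0.02)

    def forward(self, class_token_embeddings):
        # class_token_embeddings: (num_classes, n_tokens, feature_dim)
        n_cls = class_token_embeddings.size(0)
        ctx = self.context.unsqueeze(0).expand(n_cls, -1, -1)
        # Prepend the learned context to each class's token embeddings
        return torch.cat([ctx, class_token_embeddings], dim=1)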
Audio-Video Few-shot #
python
class CrossModalFusion(nn.Module):
    def __init__(self, feature_dim, num_heads=8):
        super().__init__()
        # Bidirectional cross-attention: audio attends to video and vice versa
        self.audio_to_video = nn.MultiheadAttention(feature_dim, num_heads)
        self.video_to_audio = nn.MultiheadAttention(feature_dim, num_heads)
        self.audio_norm = nn.LayerNorm(feature_dim)
        self.video_norm = nn.LayerNorm(feature_dim)

    def forward(self, audio_features, video_features):
        # (N, D) vectors are treated here as length-1 sequences per sample;
        # with token sequences per modality, the same module attends over all tokens.
        audio_enhanced, _ = self.audio_to_video(
            audio_features.unsqueeze(0),
            video_features.unsqueeze(0),
            video_features.unsqueeze(0)
        )
        audio_enhanced = self.audio_norm(audio_features + audio_enhanced.squeeze(0))
        video_enhanced, _ = self.video_to_audio(
            video_features.unsqueeze(0),
            audio_features.unsqueeze(0),
            audio_features.unsqueeze(0)
        )
        video_enhanced = self.video_norm(video_features + video_enhanced.squeeze(0))
        # Simple average of the two enhanced modalities
        fused = (audio_enhanced + video_enhanced) / 2
        return fused

class AudioVideoFewShot(nn.Module):
    def __init__(self, audio_encoder, video_encoder, feature_dim=512):
        super().__init__()
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.audio_proj = nn.Linear(audio_encoder.output_dim, feature_dim)
        self.video_proj = nn.Linear(video_encoder.output_dim, feature_dim)
        self.fusion = CrossModalFusion(feature_dim)

    def forward(self, audio, video):
        audio_features = self.audio_proj(self.audio_encoder(audio))
        video_features = self.video_proj(self.video_encoder(video))
        fused_features = self.fusion(audio_features, video_features)
        return fused_features
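A hedged usage sketch: fuse each support and query clip, then classify queries against class-mean prototypes. The encoders (exposing `.output_dim`), input tensors, and labels are illustrative assumptions:
python
model = AudioVideoFewShot(audio_encoder, video_encoder, feature_dim=512)

support_fused = model(support_audio, support_video)   # (N_s, 512)
query_fused = model(query_audio, query_video)         # (N_q, 512)

prototypes = torch.stack([
    support_fused[support_labels == c].mean(dim=0)
    for c in torch.unique(support_labels)
])
logits = -torch.cdist(query_fused, prototypes)        # nearest-prototype scores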
Task-Adaptive Mechanisms #
Task-Adaptive Feature Modulation #
python
class TaskAdaptiveModulation(nn.Module):
    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        # Map the concatenated prototypes (a task descriptor) to FiLM-style
        # scale (gamma) and shift (beta) parameters
        self.task_encoder = nn.Sequential(
            nn.Linear(feature_dim * num_classes, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim * 2)
        )

    def forward(self, support_features, support_labels):
        prototypes = self.compute_prototypes(support_features, support_labels)
        task_vector = prototypes.view(-1)  # (num_classes * feature_dim,)
        modulation_params = self.task_encoder(task_vector)
        gamma = modulation_params[:self.feature_dim]
        beta = modulation_params[self.feature_dim:]
        return gamma, beta

    def modulate(self, features, gamma, beta):
        # Element-wise affine transform conditioned on the task
        return gamma * features + beta

    def compute_prototypes(self, features, labels):
        prototypes = []
        for i in range(self.num_classes):
            mask = labels == i
            prototypes.append(features[mask].mean(dim=0))
        return torch.stack(prototypes)
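A short usage sketch (shapes are illustrative): derive gamma and beta from the support set, then modulate the query features before classification:
python
tam = TaskAdaptiveModulation(feature_dim=640, num_classes=5)
support_features = torch.randn(25, 640)                # 5-way 5-shot embeddings
support_labels = torch.arange(5).repeat_interleave(5)
gamma, beta = tam(support_features, support_labels)
query_features = torch.randn(75, 640)
modulated = tam.modulate(query_features, gamma, beta)  # task-conditioned features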
Conditional Batch Normalization #
python
class ConditionalBatchNorm(nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        # Affine-free BN; gamma/beta come from a per-class embedding instead
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        # Initialize gamma near 1 and beta at 0 so CBN starts close to plain BN
        self.embed.weight.data[:, :num_features].normal_(1, 0.02)
        self.embed.weight.data[:, num_features:].zero_()

    def forward(self, x, class_id):
        out = self.bn(x)
        # Look up per-sample gamma/beta from the class embedding
        gamma, beta = self.embed(class_id).chunk(2, dim=1)
        return gamma * out + beta
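Usage sketch (illustrative shapes): each sample is normalized and then rescaled with the gamma/beta of its own class:
python
cbn = ConditionalBatchNorm(num_features=640, num_classes=5)
x = torch.randn(32, 640)
class_id = torch.randint(0, 5, (32,))
out = cbn(x, class_id)  # (32, 640), class-conditioned normalization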
Self-supervised Pretraining #
Self-supervised Pretraining Strategies #
text
┌─────────────────────────────────────────────────────────────┐
│                  Self-supervised Pretraining                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Goal: learn better feature representations                 │
│                                                             │
│  Methods:                                                   │
│  ├── Contrastive Learning                                   │
│  │   ├── SimCLR                                             │
│  │   ├── MoCo                                               │
│  │   └── BYOL                                               │
│  ├── Masked Autoencoding                                    │
│  │   └── MAE                                                │
│  └── Self-supervised Classification                         │
│      └── DINO                                               │
│                                                             │
│  Advantages:                                                │
│  ├── No labels required                                     │
│  ├── Learns general-purpose features                        │
│  └── Improves Few-shot performance                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘
SimCLR Pretraining #
python
class SimCLRPretrain(nn.Module):
    def __init__(self, encoder, feature_dim=128, temperature=0.5):
        super().__init__()
        self.encoder = encoder
        self.temperature = temperature
        # Projection head maps encoder features into the contrastive space
        self.projection = nn.Sequential(
            nn.Linear(encoder.output_dim, encoder.output_dim),
            nn.ReLU(),
            nn.Linear(encoder.output_dim, feature_dim)
        )

    def forward(self, x1, x2):
        # x1, x2: two augmented views of the same batch of images
        z1 = F.normalize(self.projection(self.encoder(x1)), dim=-1)
        z2 = F.normalize(self.projection(self.encoder(x2)), dim=-1)
        return z1, z2

    def contrastive_loss(self, z1, z2):
        batch_size = z1.size(0)
        z = torch.cat([z1, z2], dim=0)                        # (2B, d)
        similarity = torch.matmul(z, z.t()) / self.temperature
        # The positive for sample i is its other view: i+B for i < B, i-B otherwise
        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(batch_size)
        ]).to(z.device)
        # Mask out self-similarity in place (rather than removing the diagonal,
        # which would shift column indices out of alignment with the labels)
        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        similarity = similarity.masked_fill(mask, float('-inf'))
        return F.cross_entropy(similarity, labels)
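A minimal pretraining step over unlabeled data. The `backbone`, `unlabeled_loader`, and `augment` helper are illustrative assumptions, not defined in this section:
python
model = SimCLRPretrain(encoder=backbone, feature_dim=128, temperature=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for images, _ in unlabeled_loader:             # labels are ignored
    x1, x2 = augment(images), augment(images)  # two random views (assumed helper)
    z1, z2 = model(x1, x2)
    loss = model.contrastive_loss(z1, z2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()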
Next Steps #
Now that you have covered these advanced Few-shot Learning architectures, continue to Applications and Case Studies to see how Few-shot Learning is used in practice!
Last updated: 2026-04-05