Advanced Architectures for Few-shot Learning #

Transformer-based Methods #

Applying Transformers to Few-shot Learning #

text
┌─────────────────────────────────────────────────────────────┐
│                Transformer-based Few-shot                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Core idea:                                                 │
│  Use Transformer attention to jointly process the          │
│  Support Set and the Query Set                              │
│                                                             │
│  Advantages:                                                │
│  ├── Global modeling capacity                               │
│  ├── Flexible attention mechanisms                          │
│  ├── Handles variable-length input                          │
│  └── Strong representational power                          │
│                                                             │
│  Representative methods:                                    │
│  ├── FEAT (Few-shot Embedding Adaptation with Transformer)  │
│  ├── TADAM (Task-Dependent Adaptive Metric)                 │
│  └── MetaOptNet                                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

FEAT Architecture #

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FEAT(nn.Module):
    """FEAT-style embedding adaptation: a Transformer encoder acts as a
    set-to-set function that contextualizes embeddings within a task."""

    def __init__(self, encoder, feature_dim=640, num_heads=8):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=feature_dim,
                nhead=num_heads,
                dim_feedforward=feature_dim * 4,
                dropout=0.1
            ),
            num_layers=1
        )
        
        # Learnable temperature for the distance-based logits
        self.scale = nn.Parameter(torch.ones(1) * 10.0)
    
    def forward(self, support_images, support_labels, query_images):
        support_features = self.encoder(support_images)   # (N_s, D)
        query_features = self.encoder(query_images)       # (N_q, D)
        
        # Treat the whole episode as one sequence of shape (N_s + N_q, 1, D);
        # dim 1 is the batch dimension expected by nn.TransformerEncoder
        all_features = torch.cat([support_features, query_features], dim=0).unsqueeze(1)
        
        transformed = self.transformer(all_features).squeeze(1)
        
        support_transformed = transformed[:support_features.size(0)]
        query_transformed = transformed[support_features.size(0):]
        
        prototypes = self.compute_prototypes(support_transformed, support_labels)
        
        distances = self.compute_distances(query_transformed, prototypes)
        
        # Closer prototypes get higher (less negative) logits
        logits = -distances * self.scale
        
        return logits
    
    def compute_prototypes(self, features, labels):
        # Class prototype = mean of that class's support embeddings
        num_classes = int(labels.max().item()) + 1
        prototypes = []
        for i in range(num_classes):
            class_features = features[labels == i]
            prototypes.append(class_features.mean(dim=0))
        return torch.stack(prototypes)
    
    def compute_distances(self, query_features, prototypes):
        return torch.cdist(query_features, prototypes)
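
A quick smoke test of the episode flow. This usage is illustrative: the identity encoder and all shapes are assumptions standing in for a real backbone such as ResNet-12:

python
# Hypothetical 5-way 1-shot episode over pre-extracted 640-d features
model = FEAT(encoder=nn.Identity(), feature_dim=640)

support = torch.randn(5, 640)       # one support example per class
support_labels = torch.arange(5)    # classes 0..4
query = torch.randn(15, 640)        # 15 query examples

logits = model(support, support_labels, query)
print(logits.shape)                 # torch.Size([15, 5])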

Set-to-Set Functions #

text
┌─────────────────────────────────────────────────────────────┐
│                    Set-to-Set Functions                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Question: how should the Support Set and the Query Set     │
│  interact?                                                  │
│                                                             │
│  Solutions:                                                 │
│                                                             │
│  1. Set-to-set functions                                    │
│     f: (Support Set, Query) → Prediction                    │
│     - Treat the entire Support Set as context               │
│     - Aggregate information via attention                   │
│                                                             │
│  2. Cross-attention                                         │
│     Query attends to Support Set                            │
│     - The query acts as the attention Query                 │
│     - The Support Set provides Keys and Values              │
│                                                             │
│  3. Self-attention                                          │
│     Interaction within the Support Set                      │
│     - Learns relations between samples                      │
│     - Enhances the feature representations                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Attention-based Enhancement #

Cross-Attention Mechanism #

python
class CrossAttention(nn.Module):
    def __init__(self, feature_dim, num_heads=8):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        
        self.query_proj = nn.Linear(feature_dim, feature_dim)
        self.key_proj = nn.Linear(feature_dim, feature_dim)
        self.value_proj = nn.Linear(feature_dim, feature_dim)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        
        # Scale by the per-head dimension, as in standard scaled dot-product attention
        self.scale = self.head_dim ** -0.5
    
    def forward(self, query_features, support_features):
        # query_features: (N_q, D), support_features: (N_s, D)
        Q = self.query_proj(query_features)
        K = self.key_proj(support_features)
        V = self.value_proj(support_features)
        
        # Split into heads: (N, D) -> (num_heads, N, head_dim)
        Q = Q.view(Q.size(0), self.num_heads, -1).transpose(0, 1)
        K = K.view(K.size(0), self.num_heads, -1).transpose(0, 1)
        V = V.view(V.size(0), self.num_heads, -1).transpose(0, 1)
        
        # (num_heads, N_q, N_s): each query attends over all support samples
        attention = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attention = F.softmax(attention, dim=-1)
        
        output = torch.matmul(attention, V)
        # Merge heads back: (num_heads, N_q, head_dim) -> (N_q, D)
        output = output.transpose(0, 1).contiguous().view(query_features.size(0), -1)
        
        output = self.output_proj(output)
        
        # Residual connection keeps the original query information
        return output + query_features

class AttentionBasedFewShot(nn.Module):
    def __init__(self, encoder, feature_dim=640, num_heads=8):
        super().__init__()
        self.encoder = encoder
        self.cross_attention = CrossAttention(feature_dim, num_heads)
    
    def forward(self, support_images, support_labels, query_images):
        support_features = self.encoder(support_images)
        query_features = self.encoder(query_images)
        
        prototypes = self.compute_prototypes(support_features, support_labels)
        
        # Enhance query features by attending over the support set
        query_enhanced = self.cross_attention(query_features, support_features)
        
        distances = torch.cdist(query_enhanced, prototypes)
        logits = -distances
        
        return logits
    
    def compute_prototypes(self, features, labels):
        num_classes = int(labels.max().item()) + 1
        prototypes = []
        for i in range(num_classes):
            class_features = features[labels == i]
            prototypes.append(class_features.mean(dim=0))
        return torch.stack(prototypes)

Self-Attention Enhancement #

python
class SelfAttentionEnhanced(nn.Module):
    def __init__(self, feature_dim, num_heads=8):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            dropout=0.1
        )
        self.norm = nn.LayerNorm(feature_dim)
    
    def forward(self, features):
        # (N, D) -> (N, 1, D): treat the N samples as the sequence (batch size 1),
        # so every sample attends to all others
        features = features.unsqueeze(1)
        
        attn_output, _ = self.self_attention(
            features, features, features
        )
        
        # Residual connection followed by layer normalization
        output = self.norm(features + attn_output)
        
        return output.squeeze(1)

Cross-domain Few-shot Learning #

Cross-domain Problem Definition #

text
┌─────────────────────────────────────────────────────────────┐
│                    Cross-domain Few-shot                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Problem:                                                   │
│  Training and test domains differ, with large gaps          │
│  between their feature distributions                        │
│                                                             │
│  Examples:                                                  │
│  ├── Train: real photos → Test: sketches                    │
│  ├── Train: ImageNet → Test: medical images                 │
│  └── Train: text → Test: images                             │
│                                                             │
│  Challenges:                                                │
│  ├── Feature distribution mismatch                          │
│  ├── Domain shift                                           │
│  └── Risk of negative transfer                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Domain Adaptation Methods #

python
class DomainAdaptiveFewShot(nn.Module):
    def __init__(self, encoder, feature_dim=640, num_classes=5):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        
        # Binary discriminator: source domain (label 0) vs. target domain (label 1)
        self.domain_discriminator = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Linear(feature_dim // 2, 1),
            nn.Sigmoid()
        )
        
        self.class_classifier = nn.Linear(feature_dim, num_classes)
    
    def forward(self, images, domain_label=None):
        features = self.encoder(images)
        
        if domain_label is not None:
            domain_pred = self.domain_discriminator(features)
            domain_loss = nn.BCELoss()(domain_pred, domain_label)
        else:
            domain_loss = 0.0
        
        class_pred = self.class_classifier(features)
        
        return class_pred, domain_loss
    
    def get_adversarial_features(self, images, lambda_=1.0):
        # Shift features in the direction that most confuses the discriminator
        features = self.encoder(images).detach()
        
        features_adv = features.clone().requires_grad_(True)
        
        domain_pred = self.domain_discriminator(features_adv)
        domain_loss = -torch.log(domain_pred + 1e-8).mean()
        
        domain_loss.backward()
        
        return features - lambda_ * features_adv.grad

def train_domain_adaptive(model, source_data, target_data, num_epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(num_epochs):
        for (source_images, source_labels), (target_images, _) in zip(source_data, target_data):
            # Domain labels: 0 for source, 1 for target
            source_domain = torch.zeros(source_images.size(0), 1)
            target_domain = torch.ones(target_images.size(0), 1)
            
            class_pred_source, domain_loss_source = model(source_images, source_domain)
            _, domain_loss_target = model(target_images, target_domain)
            
            class_loss = nn.CrossEntropyLoss()(class_pred_source, source_labels)
            domain_loss = domain_loss_source + domain_loss_target
            
            # The 0.1 weight keeps the domain objective from dominating classification
            total_loss = class_loss + 0.1 * domain_loss
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
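
The loop above trains the discriminator and the classifier jointly, but nothing pushes the encoder itself toward domain-invariant features. A standard way to do that is a DANN-style gradient reversal layer. A minimal sketch; the GradientReversal module below is an illustrative addition, not part of the model above:

python
class GradientReversalFn(torch.autograd.Function):
    """Identity in the forward pass; flips (and scales) gradients in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # The encoder receives the reversed gradient of the domain loss, so
        # minimizing that loss makes features *harder* to classify by domain
        return -ctx.lambda_ * grad_output, None

class GradientReversal(nn.Module):
    def __init__(self, lambda_=1.0):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversalFn.apply(x, self.lambda_)

# Usage: insert between the encoder and the domain discriminator, e.g.
#   domain_pred = self.domain_discriminator(GradientReversal()(features))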

Domain Generalization Methods #

python
class DomainGeneralizationFewShot(nn.Module):
    def __init__(self, encoder, feature_dim=640, num_domains=4, num_classes=5):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        
        # One BatchNorm per training domain to absorb domain-specific statistics
        self.domain_specific_bn = nn.ModuleList([
            nn.BatchNorm1d(feature_dim) for _ in range(num_domains)
        ])
        
        # Shared BatchNorm used at test time, when the domain is unknown
        self.shared_bn = nn.BatchNorm1d(feature_dim)
        
        self.classifier = nn.Linear(feature_dim, num_classes)
    
    def forward(self, images, domain_id=None):
        features = self.encoder(images)
        
        if domain_id is not None and self.training:
            features = self.domain_specific_bn[domain_id](features)
        else:
            features = self.shared_bn(features)
        
        return self.classifier(features)

class MMDRegularization(nn.Module):
    """Simplified MMD: matches only the feature means of the two domains
    (equivalent to MMD with a linear kernel)."""

    def forward(self, source_features, target_features):
        source_mean = source_features.mean(dim=0)
        target_mean = target_features.mean(dim=0)
        
        mmd_loss = torch.norm(source_mean - target_mean, p=2)
        
        return mmd_loss
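
Matching means alone can miss higher-order differences between the domains. A sketch of an MMD estimate with a Gaussian (RBF) kernel; the median-heuristic bandwidth is an assumption, not part of the class above:

python
def gaussian_mmd(source, target, bandwidth=None):
    # Biased MMD^2 estimate between two feature sets, source (n, d) and target (m, d)
    xy = torch.cat([source, target], dim=0)
    sq_dists = torch.cdist(xy, xy) ** 2

    if bandwidth is None:
        # Median heuristic: bandwidth from the median pairwise squared distance
        bandwidth = sq_dists.detach().median()

    k = torch.exp(-sq_dists / (2 * bandwidth + 1e-8))

    n = source.size(0)
    k_ss = k[:n, :n].mean()    # within-source similarity
    k_tt = k[n:, n:].mean()    # within-target similarity
    k_st = k[:n, n:].mean()    # cross-domain similarity

    return k_ss + k_tt - 2 * k_st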

Multimodal Few-shot Learning #

Multimodal Fusion #

text
┌─────────────────────────────────────────────────────────────┐
│                     Multimodal Few-shot                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Modalities:                                                │
│  ├── Images                                                 │
│  ├── Text                                                   │
│  ├── Audio                                                  │
│  └── Video                                                  │
│                                                             │
│  Fusion strategies (see the sketch below):                  │
│  ├── Early fusion: feature-level                            │
│  ├── Late fusion: decision-level                            │
│  └── Hybrid fusion: fuse at multiple levels                 │
│                                                             │
│  Applications:                                              │
│  ├── Image-text Few-shot                                    │
│  ├── Audio-video Few-shot                                   │
│  └── Cross-modal retrieval                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
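
To make the early/late distinction concrete, here is a minimal sketch; the two-modality setup and dimensions are illustrative assumptions:

python
class EarlyFusion(nn.Module):
    """Feature-level fusion: concatenate modality features, classify once."""

    def __init__(self, dim_a, dim_b, num_classes):
        super().__init__()
        self.classifier = nn.Linear(dim_a + dim_b, num_classes)

    def forward(self, feat_a, feat_b):
        return self.classifier(torch.cat([feat_a, feat_b], dim=-1))

class LateFusion(nn.Module):
    """Decision-level fusion: classify each modality separately, average the logits."""

    def __init__(self, dim_a, dim_b, num_classes):
        super().__init__()
        self.head_a = nn.Linear(dim_a, num_classes)
        self.head_b = nn.Linear(dim_b, num_classes)

    def forward(self, feat_a, feat_b):
        return (self.head_a(feat_a) + self.head_b(feat_b)) / 2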

Image-Text Few-shot #

python
class ImageTextFewShot(nn.Module):
    def __init__(self, image_encoder, text_encoder, feature_dim=512):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        
        # Project both modalities into a shared embedding space
        self.image_proj = nn.Linear(image_encoder.output_dim, feature_dim)
        self.text_proj = nn.Linear(text_encoder.output_dim, feature_dim)
        
        # Learnable temperature, initialized as in CLIP
        self.temperature = nn.Parameter(torch.ones(1) * 0.07)
    
    def forward(self, images, texts):
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(texts)
        
        image_features = self.image_proj(image_features)
        text_features = self.text_proj(text_features)
        
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        
        # Pairwise cosine similarities, sharpened by the temperature
        logits = torch.matmul(image_features, text_features.t()) / self.temperature
        
        return logits
    
    def compute_loss(self, images, texts):
        logits = self.forward(images, texts)
        
        # Matched image-text pairs sit on the diagonal
        labels = torch.arange(images.size(0)).to(images.device)
        
        # Symmetric contrastive loss over both retrieval directions
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)
        
        return (loss_i2t + loss_t2i) / 2

class CLIPStyleFewShot(nn.Module):
    def __init__(self, clip_model, feature_dim=512):
        super().__init__()
        self.clip = clip_model
        self.feature_dim = feature_dim
        # A learnable prompt module (CoOp-style) could be added here if desired
    
    def forward(self, support_images, support_texts, support_labels, query_images):
        support_image_features = self.clip.encode_image(support_images)
        query_image_features = self.clip.encode_image(query_images)
        
        support_text_features = self.clip.encode_text(support_texts)
        
        support_image_features = F.normalize(support_image_features, dim=-1)
        support_text_features = F.normalize(support_text_features, dim=-1)
        query_image_features = F.normalize(query_image_features, dim=-1)
        
        # Fuse the two modalities of each support sample by simple averaging
        support_features = (support_image_features + support_text_features) / 2
        
        prototypes = self.compute_prototypes(support_features, support_labels)
        
        # Cosine similarity between query images and multimodal prototypes
        logits = torch.matmul(query_image_features, prototypes.t())
        
        return logits
    
    def compute_prototypes(self, features, labels):
        num_classes = int(labels.max().item()) + 1
        prototypes = [features[labels == i].mean(dim=0) for i in range(num_classes)]
        return torch.stack(prototypes)

Audio-Video Few-shot #

python
class AudioVideoFewShot(nn.Module):
    def __init__(self, audio_encoder, video_encoder, feature_dim=512):
        super().__init__()
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        
        self.audio_proj = nn.Linear(audio_encoder.output_dim, feature_dim)
        self.video_proj = nn.Linear(video_encoder.output_dim, feature_dim)
        
        self.fusion = CrossModalFusion(feature_dim)
    
    def forward(self, audio, video):
        audio_features = self.audio_encoder(audio)
        video_features = self.video_encoder(video)
        
        audio_features = self.audio_proj(audio_features)
        video_features = self.video_proj(video_features)
        
        fused_features = self.fusion(audio_features, video_features)
        
        return fused_features

class CrossModalFusion(nn.Module):
    def __init__(self, feature_dim, num_heads=8):
        super().__init__()
        self.audio_to_video = nn.MultiheadAttention(feature_dim, num_heads)
        self.video_to_audio = nn.MultiheadAttention(feature_dim, num_heads)
        
        self.audio_norm = nn.LayerNorm(feature_dim)
        self.video_norm = nn.LayerNorm(feature_dim)
    
    def forward(self, audio_features, video_features):
        # nn.MultiheadAttention expects (seq_len, batch, dim); treat the N samples
        # as the sequence so each audio sample can attend over all video samples
        audio_seq = audio_features.unsqueeze(1)   # (N, 1, D)
        video_seq = video_features.unsqueeze(1)
        
        audio_enhanced, _ = self.audio_to_video(audio_seq, video_seq, video_seq)
        audio_enhanced = self.audio_norm(audio_features + audio_enhanced.squeeze(1))
        
        video_enhanced, _ = self.video_to_audio(video_seq, audio_seq, audio_seq)
        video_enhanced = self.video_norm(video_features + video_enhanced.squeeze(1))
        
        fused = (audio_enhanced + video_enhanced) / 2
        
        return fused
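
The fused features can then feed a standard nearest-prototype classifier. A brief illustrative usage; DummyEncoder and all shapes are assumptions:

python
class DummyEncoder(nn.Module):
    """Stand-in for a real audio/video backbone over pre-extracted features."""

    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim

    def forward(self, x):
        return x

model = AudioVideoFewShot(DummyEncoder(512), DummyEncoder(512), feature_dim=512)

# Hypothetical 5-way 1-shot episode
support_fused = model(torch.randn(5, 512), torch.randn(5, 512))
query_fused = model(torch.randn(10, 512), torch.randn(10, 512))

labels = torch.arange(5)
prototypes = torch.stack([support_fused[labels == i].mean(0) for i in range(5)])
logits = -torch.cdist(query_fused, prototypes)   # (10, 5)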

Task-Adaptive Mechanisms #

Task-Adaptive Feature Modulation #

python
class TaskAdaptiveModulation(nn.Module):
    """Predicts FiLM-style (gamma, beta) parameters from the task's prototypes,
    so features can be modulated differently for each task."""

    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        
        # Maps the flattened prototypes of a task to 2 * feature_dim parameters
        self.task_encoder = nn.Sequential(
            nn.Linear(feature_dim * num_classes, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim * 2)
        )
    
    def forward(self, support_features, support_labels):
        prototypes = self.compute_prototypes(support_features, support_labels)
        
        # Flatten all prototypes into a single task descriptor
        task_vector = prototypes.view(-1)
        modulation_params = self.task_encoder(task_vector)
        
        gamma = modulation_params[:self.feature_dim]
        beta = modulation_params[self.feature_dim:]
        
        return gamma, beta
    
    def modulate(self, features, gamma, beta):
        # Feature-wise affine transformation conditioned on the task
        return gamma * features + beta
    
    def compute_prototypes(self, features, labels):
        prototypes = []
        for i in range(self.num_classes):
            class_features = features[labels == i]
            prototypes.append(class_features.mean(dim=0))
        return torch.stack(prototypes)
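
How the modulation plugs into an episode; the shapes and the 5-way 5-shot setup are illustrative assumptions:

python
modulator = TaskAdaptiveModulation(feature_dim=640, num_classes=5)

support_features = torch.randn(25, 640)                      # 5-way 5-shot
support_labels = torch.arange(5).repeat_interleave(5)
query_features = torch.randn(15, 640)

gamma, beta = modulator(support_features, support_labels)    # task-conditioned parameters
query_modulated = modulator.modulate(query_features, gamma, beta)
print(query_modulated.shape)                                 # torch.Size([15, 640])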

Conditional Batch Normalization #

python
class ConditionalBatchNorm(nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        
        # Affine-free BN; the affine parameters come from the class embedding
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        
        # Each class id maps to its own (gamma, beta) pair
        self.embed = nn.Embedding(num_classes, num_features * 2)
        self.embed.weight.data[:, :num_features].normal_(1, 0.02)  # gamma ≈ 1
        self.embed.weight.data[:, num_features:].zero_()           # beta = 0
    
    def forward(self, x, class_id):
        out = self.bn(x)
        
        gamma_beta = self.embed(class_id)
        gamma = gamma_beta[:, :self.num_features]
        beta = gamma_beta[:, self.num_features:]
        
        return gamma * out + beta
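
Illustrative usage; the batch size and class ids are assumptions:

python
cbn = ConditionalBatchNorm(num_features=640, num_classes=5)

x = torch.randn(8, 640)                  # batch of feature vectors
class_ids = torch.randint(0, 5, (8,))    # one conditioning class per sample

out = cbn(x, class_ids)                  # normalize, then scale/shift per class
print(out.shape)                         # torch.Size([8, 640])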

Self-supervised Pretraining #

Self-supervised Pretraining Strategies #

text
┌─────────────────────────────────────────────────────────────┐
│                 Self-supervised Pretraining                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Goal: learn better feature representations                 │
│                                                             │
│  Methods:                                                   │
│  ├── Contrastive learning                                   │
│  │   ├── SimCLR                                             │
│  │   ├── MoCo                                               │
│  │   └── BYOL                                               │
│  ├── Masked autoencoding                                    │
│  │   └── MAE                                                │
│  └── Self-distillation                                      │
│      └── DINO                                               │
│                                                             │
│  Advantages:                                                │
│  ├── No labeled data required                               │
│  ├── Learns general-purpose features                        │
│  └── Improves Few-shot performance                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

SimCLR Pretraining #

python
class SimCLRPretrain(nn.Module):
    def __init__(self, encoder, feature_dim=128, temperature=0.5):
        super().__init__()
        self.encoder = encoder
        self.temperature = temperature
        
        self.projection = nn.Sequential(
            nn.Linear(encoder.output_dim, encoder.output_dim),
            nn.ReLU(),
            nn.Linear(encoder.output_dim, feature_dim)
        )
    
    def forward(self, x1, x2):
        h1 = self.encoder(x1)
        h2 = self.encoder(x2)
        
        z1 = self.projection(h1)
        z2 = self.projection(h2)
        
        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)
        
        return z1, z2
    
    def contrastive_loss(self, z1, z2):
        batch_size = z1.size(0)
        
        z = torch.cat([z1, z2], dim=0)   # (2B, D): two views of each image
        
        similarity = torch.matmul(z, z.t()) / self.temperature
        
        # Drop self-similarities: each row keeps 2B - 1 candidates
        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        similarity = similarity[~mask].view(2 * batch_size, -1)
        
        # For row i < B the positive was column i + B; removing the diagonal
        # entry (column i < i + B) shifts it left by one. For rows i >= B the
        # positive column i - B lies before the diagonal and is unaffected.
        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size) - 1,
            torch.arange(batch_size)
        ]).to(z.device)
        
        return F.cross_entropy(similarity, labels)
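
A minimal pretraining step. The augment function, loader, and my_encoder below are placeholders, not defined in this section:

python
model = SimCLRPretrain(encoder=my_encoder, feature_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for images, _ in loader:                        # labels are ignored during pretraining
    x1, x2 = augment(images), augment(images)   # two random views per image
    z1, z2 = model(x1, x2)
    loss = model.contrastive_loss(z1, z2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# model.encoder then serves as the backbone for downstream few-shot episodes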

Next Steps #

Now that you have covered the advanced architectures for Few-shot Learning, continue to Applications and Case Studies to see how Few-shot Learning is used in practice!

Last updated: 2026-04-05