Few-shot Learning 应用场景与案例

计算机视觉应用

图像分类

text
┌─────────────────────────────────────────────────────────────┐
│                    图像分类应用                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  应用场景:                                                  │
│  ├── 医疗影像诊断(罕见病识别)                             │
│  ├── 工业缺陷检测(新产品缺陷)                             │
│  ├── 野生动物识别(稀有物种)                               │
│  └── 卫星图像分类(新地物类型)                             │
│                                                             │
│  挑战:                                                      │
│  ├── 类内变异大                                             │
│  ├── 类间相似度高                                           │
│  ├── 样本质量不一                                           │
│  └── 标注成本高                                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

医疗影像诊断案例

python
import torch
import torch.nn as nn
from torchvision import models

class MedicalImageFewShot:
    """Few-shot medical image classifier: ResNet-50 features + prototypical head.

    The encoder is torchvision's ImageNet-pretrained ResNet-50 with its ``fc``
    layer replaced by ``nn.Identity`` so it emits 2048-d feature vectors; the
    ``PrototypicalNetwork`` head scores each query against the per-class mean
    features of the support set.
    """

    def __init__(self, num_classes=5, num_support=5):
        # NOTE: newer torchvision releases deprecate `pretrained=True` in favour
        # of `weights=...`; kept as-is for compatibility with older versions.
        self.encoder = models.resnet50(pretrained=True)
        self.encoder.fc = nn.Identity()

        self.classifier = PrototypicalNetwork(
            feature_dim=2048,
            num_classes=num_classes
        )

        # Intended shots per class; stored for episode samplers, not read below.
        self.num_support = num_support

    def sample_episodes(self, train_data):
        """Yield training episodes from *train_data*.

        Default behaviour: *train_data* is assumed to already be an iterable of
        episode dicts with keys ``support_images``, ``support_labels``,
        ``query_images`` and ``query_labels``. Override this method to sample
        N-way K-shot episodes from a raw labelled dataset instead.
        """
        yield from train_data

    def train(self, train_data, num_epochs=100):
        """Fine-tune encoder and head episodically with cross-entropy on query logits.

        Args:
            train_data: source handed to :meth:`sample_episodes`.
            num_epochs: number of passes over the episodes.
        """
        optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.classifier.parameters()),
            lr=0.001
        )
        criterion = nn.CrossEntropyLoss()

        # predict() leaves the modules in eval mode; switch back so BatchNorm
        # (present in ResNet-50) uses batch statistics while training.
        self.encoder.train()
        self.classifier.train()

        for _ in range(num_epochs):
            for episode in self.sample_episodes(train_data):
                support_features = self.encoder(episode['support_images'])
                query_features = self.encoder(episode['query_images'])

                logits = self.classifier(
                    support_features, episode['support_labels'], query_features
                )
                loss = criterion(logits, episode['query_labels'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict(self, support_images, support_labels, query_images):
        """Return the predicted label id for each query image (argmax over the head's logits)."""
        self.encoder.eval()
        self.classifier.eval()

        with torch.no_grad():
            support_features = self.encoder(support_images)
            query_features = self.encoder(query_images)

            logits = self.classifier(
                support_features, support_labels, query_features
            )

            predictions = logits.argmax(dim=1)

        return predictions

class PrototypicalNetwork(nn.Module):
    """Parameter-free prototypical classification head.

    Each class prototype is the mean of that class's support features; a query's
    logit for a class is the negative Euclidean distance to its prototype.
    Classes that have no support example in the current episode get a logit of
    ``-inf`` so they can never be predicted (and contribute zero probability to
    cross-entropy) instead of producing NaN.
    """

    def __init__(self, feature_dim, num_classes):
        super().__init__()
        # Kept for introspection; the head itself has no learnable weights.
        self.feature_dim = feature_dim
        self.num_classes = num_classes

    def forward(self, support_features, support_labels, query_features):
        """Return (num_query, num_classes) logits = -distance to each class prototype.

        Args:
            support_features: 2-D (num_support, dim) feature matrix.
            support_labels: 1-D integer class ids in ``range(num_classes)``.
            query_features: 2-D (num_query, dim) feature matrix.
        """
        prototypes = self.compute_prototypes(support_features, support_labels)

        distances = torch.cdist(query_features, prototypes)
        logits = -distances

        # Mask classes absent from the support set; their prototype is only a
        # placeholder and must not compete with real classes.
        class_ids = torch.arange(self.num_classes, device=support_labels.device)
        present = (support_labels.reshape(-1, 1) == class_ids).any(dim=0)
        return logits.masked_fill(~present, float('-inf'))

    def compute_prototypes(self, features, labels):
        """Return a (num_classes, dim) tensor of per-class mean features.

        Row ``i`` is the mean of ``features[labels == i]``. A class with no
        example gets a zero row (finite, unlike the NaN an empty mean would
        give); ``forward`` masks such classes out of the logits.
        """
        prototypes = []
        for class_id in range(self.num_classes):
            class_features = features[labels == class_id]
            if class_features.size(0) > 0:
                prototypes.append(class_features.mean(dim=0))
            else:
                prototypes.append(features.new_zeros(features.size(-1)))
        return torch.stack(prototypes)

# Demo: train, then classify one scan against a small labelled support set.
# `train_data` and `load_medical_images` are placeholders that are not defined
# in this snippet — TODO supply them.
medical_fewshot = MedicalImageFewShot(num_classes=5, num_support=5)

medical_fewshot.train(train_data, num_epochs=100)

# NOTE(review): the support set below only uses label ids 0 and 1 (presumably
# one image each), while the model was built with num_classes=5 and
# num_support=5. Ids 2-4 have no support example — confirm that
# PrototypicalNetwork handles classes with an empty support selection.
support_images = load_medical_images(['disease_A_sample.jpg', 'disease_B_sample.jpg'])
support_labels = torch.tensor([0, 1])
query_images = load_medical_images(['patient_scan.jpg'])

# predictions: predicted label id per query image (argmax over the logits).
predictions = medical_fewshot.predict(support_images, support_labels, query_images)

人脸识别案例

python
class FaceRecognitionFewShot:
    """Few-shot face recognition by cosine similarity to enrolled user templates.

    Each registered user is represented by the mean embedding of their enrolment
    images; a query is matched to the most similar template and accepted only if
    the similarity exceeds a threshold.
    """

    def __init__(self):
        self.encoder = self.build_face_encoder()
        self.distance_metric = nn.CosineSimilarity()
        # user_id -> averaged (1, D) embedding, filled by register_user().
        self.user_database = {}

    def build_face_encoder(self):
        """Return an ImageNet ResNet-50 whose head is replaced by a 2048->512 projection.

        NOTE: the new ``fc`` layer is randomly initialised; it should be trained
        (e.g. with a metric-learning loss) before the embeddings are meaningful.
        """
        model = models.resnet50(pretrained=True)
        model.fc = nn.Linear(2048, 512)
        return model

    def _embed(self, image):
        """Encode one un-batched image into a (1, D) embedding in inference mode."""
        # Eval mode so BatchNorm uses running statistics (and does not update
        # them); no_grad because no backward pass is ever taken here.
        self.encoder.eval()
        with torch.no_grad():
            return self.encoder(image.unsqueeze(0))

    def register_user(self, user_id, face_images):
        """Enrol *user_id* from one or more face images; return the stored template.

        Raises:
            ValueError: if *face_images* is empty.
        """
        if len(face_images) == 0:
            raise ValueError("face_images must contain at least one image")

        features = [self._embed(image) for image in face_images]
        user_feature = torch.stack(features).mean(dim=0)

        self.user_database[user_id] = user_feature

        return user_feature

    def recognize(self, query_image, threshold=0.8):
        """Return ``(user_id, score)`` of the best match, or ``(None, score)`` below threshold.

        ``score`` is the best cosine similarity as a Python float (``-1`` when
        no user is registered).
        """
        query_feature = self._embed(query_image)

        best_match = None
        best_score = -1

        for user_id, user_feature in self.user_database.items():
            # .item(): CosineSimilarity returns a 1-element tensor; a plain float
            # is easier for callers to compare and to format.
            score = self.distance_metric(query_feature, user_feature).item()

            if score > best_score:
                best_score = score
                best_match = user_id

        if best_score > threshold:
            return best_match, best_score
        return None, best_score

face_recognition = FaceRecognitionFewShot()

# `image1`..`image3` and `query_image` are placeholders for single face images
# (presumably un-batched tensors, since the class adds the batch dim itself).
face_recognition.register_user('user_001', [image1, image2, image3])

user_id, confidence = face_recognition.recognize(query_image)
# `confidence` may be a one-element tensor; the `.2f` format spec only works on
# Python numbers / 0-d tensors, so convert explicitly before formatting.
print(f"识别结果: {user_id}, 置信度: {float(confidence):.2f}")

自然语言处理应用

文本分类

text
┌─────────────────────────────────────────────────────────────┐
│                    文本分类应用                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  应用场景:                                                  │
│  ├── 意图识别(新意图类别)                                 │
│  ├── 情感分析(新领域)                                     │
│  ├── 新闻分类(新类别)                                     │
│  └── 垃圾邮件检测(新类型)                                 │
│                                                             │
│  挑战:                                                      │
│  ├── 文本多样性                                             │
│  ├── 语义理解                                               │
│  ├── 领域适应                                               │
│  └── 标注成本                                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

意图识别案例

python
import torch
from transformers import BertModel, BertTokenizer

class IntentRecognitionFewShot:
    """Few-shot intent classifier: BERT first-token ([CLS]) embeddings + prototypical head."""

    def __init__(self, model_name='bert-base-uncased', num_classes=10):
        self.encoder = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

        # Use the checkpoint's own hidden size (768 for base models) so larger
        # checkpoints work unchanged.
        self.classifier = PrototypicalNetwork(
            feature_dim=self.encoder.config.hidden_size,
            num_classes=num_classes
        )

    def encode_texts(self, texts):
        """Tokenise *texts* as one padded batch and return their first-token hidden states."""
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )

        outputs = self.encoder(**encoded)

        # Position 0 is the [CLS] token that BertTokenizer prepends.
        features = outputs.last_hidden_state[:, 0, :]

        return features

    def sample_episodes(self, train_data):
        """Yield training episodes from *train_data*.

        Default behaviour: *train_data* is assumed to already be an iterable of
        episode dicts with keys ``support_texts``, ``support_labels``,
        ``query_texts`` and ``query_labels``. Override to sample N-way K-shot
        episodes from a raw labelled corpus instead.
        """
        yield from train_data

    def train(self, train_data, num_epochs=50):
        """Fine-tune encoder and head episodically with cross-entropy on query logits."""
        optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.classifier.parameters()),
            lr=2e-5
        )
        criterion = nn.CrossEntropyLoss()

        # predict() leaves the modules in eval mode; re-enable dropout etc.
        self.encoder.train()
        self.classifier.train()

        for _ in range(num_epochs):
            for episode in self.sample_episodes(train_data):
                support_features = self.encode_texts(episode['support_texts'])
                query_features = self.encode_texts(episode['query_texts'])

                logits = self.classifier(
                    support_features, episode['support_labels'], query_features
                )
                loss = criterion(logits, episode['query_labels'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict(self, support_texts, support_labels, query_texts):
        """Return the predicted label id for each query text (argmax over logits)."""
        self.encoder.eval()
        self.classifier.eval()

        with torch.no_grad():
            support_features = self.encode_texts(support_texts)
            query_features = self.encode_texts(query_texts)

            logits = self.classifier(
                support_features, support_labels, query_features
            )

            predictions = logits.argmax(dim=1)

        return predictions

support_texts = [
    "我想预订一张机票",
    "帮我查一下天气",
    "播放一首音乐"
]
support_labels = torch.tensor([0, 1, 2])

# The examples are Chinese, so use a Chinese BERT checkpoint (768-d hidden
# size, like the English base model). num_classes matches the label ids used
# in the support set (0, 1, 2), so every class has a prototype.
intent_recognition = IntentRecognitionFewShot(
    model_name='bert-base-chinese',
    num_classes=3,
)

query_texts = ["我想订票去北京"]

predictions = intent_recognition.predict(support_texts, support_labels, query_texts)
print(f"预测意图: {predictions}")

命名实体识别案例

python
class NERFewShot:
    """Token-classification (NER) model: BERT encoder + per-token linear tagging head."""

    def __init__(self, model_name='bert-base-uncased', num_tags=None):
        """Build the encoder, tokenizer and a head with *num_tags* outputs.

        Raises:
            ValueError: if *num_tags* (size of the tag set) is not given.
        """
        # Previously read from an undefined module-level `num_tags`; validate
        # before downloading any weights.
        if num_tags is None:
            raise ValueError("num_tags (size of the tag set) must be provided")
        self.num_tags = num_tags

        self.encoder = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

        self.tag_classifier = nn.Linear(self.encoder.config.hidden_size, num_tags)

    def train(self, train_data, num_epochs=30):
        """Fine-tune encoder and head with token-level cross-entropy.

        Each batch is a dict with ``texts`` and ``tags``; ``tags`` must hold one
        id per tokenised position (same layout as the padded token batch).
        Positions tagged ``-100`` are ignored, and padding positions are set to
        ``-100`` automatically when the tokenizer returns an attention mask.
        """
        optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.tag_classifier.parameters()),
            lr=2e-5
        )
        criterion = nn.CrossEntropyLoss(ignore_index=-100)

        # predict() leaves the modules in eval mode; re-enable dropout.
        self.encoder.train()
        self.tag_classifier.train()

        for _ in range(num_epochs):
            for batch in train_data:
                texts = batch['texts']
                tags = batch['tags']

                encoded = self.tokenizer(
                    texts,
                    padding=True,
                    truncation=True,
                    return_tensors='pt'
                )

                outputs = self.encoder(**encoded)
                sequence_features = outputs.last_hidden_state

                logits = self.tag_classifier(sequence_features)

                flat_tags = tags.reshape(-1)
                attention_mask = encoded.get('attention_mask')
                if attention_mask is not None:
                    # Padding tokens carry no label; exclude them from the loss.
                    flat_tags = flat_tags.masked_fill(attention_mask.reshape(-1) == 0, -100)

                loss = criterion(logits.reshape(-1, self.num_tags), flat_tags)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict(self, text):
        """Return the argmax tag id for every token position of *text*."""
        self.encoder.eval()
        self.tag_classifier.eval()

        with torch.no_grad():
            encoded = self.tokenizer(
                text,
                return_tensors='pt'
            )

            outputs = self.encoder(**encoded)
            sequence_features = outputs.last_hidden_state

            logits = self.tag_classifier(sequence_features)
            predictions = logits.argmax(dim=-1)

        return predictions

推荐系统应用

冷启动问题

text
┌─────────────────────────────────────────────────────────────┐
│                    推荐系统应用                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  应用场景:                                                  │
│  ├── 新用户推荐                                             │
│  ├── 新物品推荐                                             │
│  ├── 跨领域推荐                                             │
│  └── 实时个性化                                             │
│                                                             │
│  挑战:                                                      │
│  ├── 用户行为稀疏                                           │
│  ├── 物品特征缺失                                           │
│  ├── 冷启动问题                                             │
│  └── 实时性要求                                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

新用户推荐案例

python
class UserColdStartFewShot:
    """Two-tower recommender for cold-start users.

    Users and items are each mapped from raw feature vectors into a shared
    embedding space by a Linear-ReLU-Linear tower; a (user, item) score is the
    dot product of the two embeddings, so a new user can be scored from their
    features alone.
    """

    def __init__(self, num_items, feature_dim=128, hidden_dim=256, embedding_dim=128):
        self.user_encoder = self._make_tower(feature_dim, hidden_dim, embedding_dim)
        self.item_encoder = self._make_tower(feature_dim, hidden_dim, embedding_dim)

        # Stored for callers; not read by the methods below.
        self.num_items = num_items

    @staticmethod
    def _make_tower(in_dim, hidden_dim, out_dim):
        """Build one Linear-ReLU-Linear encoder tower."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def train(self, train_data, num_epochs=100):
        """Train both towers with BCE on every (user, item) pair of each batch.

        ``interactions`` must be a float matrix shaped like the score matrix,
        i.e. (num batch users, num batch items).
        """
        optimizer = torch.optim.Adam(
            list(self.user_encoder.parameters()) +
            list(self.item_encoder.parameters()),
            lr=0.001
        )
        criterion = nn.BCEWithLogitsLoss()

        self.user_encoder.train()
        self.item_encoder.train()

        for _ in range(num_epochs):
            for batch in train_data:
                user_embeddings = self.user_encoder(batch['user_features'])
                item_embeddings = self.item_encoder(batch['item_features'])

                scores = torch.matmul(user_embeddings, item_embeddings.t())

                loss = criterion(scores, batch['interactions'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def recommend(self, user_feature, item_features, top_k=10):
        """Return the indices of the highest-scoring candidate items, shape (1, k).

        ``k`` is ``top_k`` clamped to the number of candidates, so asking for
        more items than exist returns all of them ranked (instead of raising).
        """
        self.user_encoder.eval()
        self.item_encoder.eval()

        with torch.no_grad():
            user_embedding = self.user_encoder(user_feature.unsqueeze(0))
            item_embeddings = self.item_encoder(item_features)

            scores = torch.matmul(user_embedding, item_embeddings.t())

            k = min(top_k, scores.size(-1))
            top_k_indices = scores.topk(k).indices

        return top_k_indices

# Demo: recommend items for a brand-new user from their features alone.
# `train_data`, `new_user`, `extract_user_features` and
# `extract_all_item_features` are placeholders not defined in this snippet.
recommender = UserColdStartFewShot(num_items=10000, feature_dim=128)

recommender.train(train_data, num_epochs=100)

# Presumably a (128,) user feature vector and a (num_items, 128) item matrix,
# matching feature_dim above — confirm against the extractors.
new_user_feature = extract_user_features(new_user)
item_features = extract_all_item_features()

# recommendations: row indices into `item_features`; shape (1, 10) when there
# are at least 10 candidate items.
recommendations = recommender.recommend(new_user_feature, item_features, top_k=10)
print(f"推荐物品: {recommendations}")

新物品推荐案例

python
class ItemColdStartFewShot:
    """Cold-start item model: a feature tower for items, learned ID embeddings for users.

    A new item is embedded from its features and scored against every known
    user by dot product, to find the users most likely to interact with it.
    """

    def __init__(self, num_users, feature_dim=128, hidden_dim=256, embedding_dim=128):
        self.item_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )

        self.user_embeddings = nn.Embedding(num_users, embedding_dim)

    def train(self, train_data, num_epochs=100):
        """Train with BCE on aligned (item, user, interaction) triples.

        Row ``i`` of ``item_features`` is paired with ``user_ids[i]`` and
        ``interactions[i]`` (a float 0/1 target).
        """
        optimizer = torch.optim.Adam(
            list(self.item_encoder.parameters()) +
            list(self.user_embeddings.parameters()),
            lr=0.001
        )
        criterion = nn.BCEWithLogitsLoss()

        self.item_encoder.train()

        for _ in range(num_epochs):
            for batch in train_data:
                item_embeddings = self.item_encoder(batch['item_features'])
                user_embeddings = self.user_embeddings(batch['user_ids'])

                # Row-wise dot product: one score per (item, user) pair.
                scores = torch.sum(item_embeddings * user_embeddings, dim=1)

                loss = criterion(scores, batch['interactions'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict_users(self, new_item_feature, num_users, top_k=100):
        """Return ids of the users scoring highest for *new_item_feature*, shape (k,).

        ``k`` is ``top_k`` clamped to the number of known users. *num_users* is
        accepted for backward compatibility but unused: every row of the user
        embedding table is a candidate.
        """
        self.item_encoder.eval()

        with torch.no_grad():
            item_embedding = self.item_encoder(new_item_feature.unsqueeze(0))

            all_user_embeddings = self.user_embeddings.weight

            # squeeze(-1) (not squeeze()) keeps a 1-D result even with one user.
            scores = torch.matmul(all_user_embeddings, item_embedding.t()).squeeze(-1)

            k = min(top_k, scores.size(0))
            top_k_users = scores.topk(k).indices

        return top_k_users

医疗健康应用

罕见病诊断

python
class RareDiseaseDiagnosis:
    """Match a patient's symptom vector against disease descriptions in a shared embedding space.

    Two Linear-ReLU-Linear towers embed symptom vectors and disease-feature
    vectors; similarity is their dot product, softmax-normalised over the
    candidate diseases at diagnosis time.
    """

    def __init__(self, num_diseases=50, num_symptoms=None, num_disease_features=None):
        """Build both encoders.

        Raises:
            ValueError: if the input sizes *num_symptoms* or
                *num_disease_features* are not given (they were previously read
                from undefined module-level names).
        """
        if num_symptoms is None or num_disease_features is None:
            raise ValueError("num_symptoms and num_disease_features must be provided")

        self.symptom_encoder = nn.Sequential(
            nn.Linear(num_symptoms, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

        self.disease_encoder = nn.Sequential(
            nn.Linear(num_disease_features, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

        # Stored for callers; not read by the methods below.
        self.num_diseases = num_diseases

    def train(self, train_data, num_epochs=100):
        """Train with cross-entropy: each patient's label indexes a row of the batch's disease_features."""
        optimizer = torch.optim.Adam(
            list(self.symptom_encoder.parameters()) +
            list(self.disease_encoder.parameters()),
            lr=0.001
        )
        criterion = nn.CrossEntropyLoss()

        self.symptom_encoder.train()
        self.disease_encoder.train()

        for _ in range(num_epochs):
            for batch in train_data:
                symptom_embeddings = self.symptom_encoder(batch['symptoms'])
                disease_embeddings = self.disease_encoder(batch['disease_features'])

                similarities = torch.matmul(symptom_embeddings, disease_embeddings.t())

                loss = criterion(similarities, batch['labels'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def diagnose(self, patient_symptoms, known_disease_features, top_k=5):
        """Return ``torch.topk`` of softmax probabilities over *known_disease_features*.

        The result's ``values`` are probabilities and ``indices`` are row
        positions in *known_disease_features*, both shaped (1, k) where ``k`` is
        *top_k* clamped to the number of candidate diseases.
        """
        self.symptom_encoder.eval()
        self.disease_encoder.eval()

        with torch.no_grad():
            symptom_embedding = self.symptom_encoder(patient_symptoms.unsqueeze(0))
            disease_embeddings = self.disease_encoder(known_disease_features)

            similarities = torch.matmul(symptom_embedding, disease_embeddings.t())

            probabilities = torch.softmax(similarities, dim=1)

            k = min(top_k, probabilities.size(-1))
            top_diseases = probabilities.topk(k)

        return top_diseases

# Demo: rank candidate rare diseases for one patient.
# `train_data`, `patient`, `extract_symptoms` and `get_rare_disease_features`
# are placeholders not defined in this snippet — TODO supply them.
# NOTE(review): RareDiseaseDiagnosis' encoders also need the input sizes
# `num_symptoms` / `num_disease_features`, which this call does not pass —
# confirm where they come from.
diagnosis_system = RareDiseaseDiagnosis(num_diseases=50)

diagnosis_system.train(train_data, num_epochs=100)

patient_symptoms = extract_symptoms(patient)
known_diseases = get_rare_disease_features()

# top_diseases is a torch.topk result: values are softmax probabilities,
# indices are row positions in `known_diseases`.
top_diseases = diagnosis_system.diagnose(patient_symptoms, known_diseases)
print(f"可能的疾病: {top_diseases}")

药物发现

python
class DrugDiscoveryFewShot:
    """Drug-target interaction scoring with separate molecule and target encoders.

    A (molecule, target) score is the dot product of the two embeddings,
    trained with BCE against observed interactions.
    """

    def __init__(self):
        # MoleculeEncoder / TargetEncoder are defined outside this snippet;
        # they must output embeddings of the same width.
        self.molecule_encoder = MoleculeEncoder()
        self.target_encoder = TargetEncoder()

    def train(self, train_data, num_epochs=100):
        """Train both encoders; ``interactions`` must be shaped (num molecules, num targets)."""
        optimizer = torch.optim.Adam(
            list(self.molecule_encoder.parameters()) +
            list(self.target_encoder.parameters()),
            lr=0.001
        )
        criterion = nn.BCEWithLogitsLoss()

        # predict_binding() leaves the encoders in eval mode; switch back.
        self.molecule_encoder.train()
        self.target_encoder.train()

        for _ in range(num_epochs):
            for batch in train_data:
                molecule_embeddings = self.molecule_encoder(batch['molecules'])
                target_embeddings = self.target_encoder(batch['targets'])

                scores = torch.matmul(molecule_embeddings, target_embeddings.t())

                loss = criterion(scores, batch['interactions'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict_binding(self, new_molecule, known_targets, top_k=10):
        """Return ``torch.topk`` of binding scores of *new_molecule* against *known_targets*.

        ``values`` are raw scores (logits) and ``indices`` are row positions in
        *known_targets*, shaped (1, k) with ``k = min(top_k, len(known_targets))``.
        """
        self.molecule_encoder.eval()
        self.target_encoder.eval()

        with torch.no_grad():
            molecule_embedding = self.molecule_encoder(new_molecule.unsqueeze(0))
            target_embeddings = self.target_encoder(known_targets)

            binding_scores = torch.matmul(molecule_embedding, target_embeddings.t())

            k = min(top_k, binding_scores.size(-1))
            top_targets = binding_scores.topk(k)

        return top_targets

工业应用

缺陷检测

python
class DefectDetectionFewShot:
    """Few-shot defect detection by nearest-neighbour distance to known defect samples.

    Images are embedded with a ResNet-50 (``fc`` replaced by ``nn.Identity``);
    a query is flagged as defective when its feature lies within a distance
    threshold of at least one support defect feature.
    """

    def __init__(self, threshold=None):
        self.encoder = models.resnet50(pretrained=True)
        self.encoder.fc = nn.Identity()

        self.prototype_layer = PrototypeLayer(feature_dim=2048)

        # Euclidean distance cutoff used by detect(); must be calibrated (e.g.
        # on held-out normal/defect images) before detection.
        self.threshold = threshold

    def train(self, train_data, num_epochs=100):
        """Fine-tune the encoder with the loss computed by ``prototype_layer``.

        NOTE(review): only encoder parameters are optimised; if PrototypeLayer
        has learnable parameters they are not updated here — confirm intended.
        """
        optimizer = torch.optim.Adam(
            self.encoder.parameters(),
            lr=0.001
        )

        # detect() leaves the encoder in eval mode; BatchNorm must use batch
        # statistics again while training.
        self.encoder.train()

        for _ in range(num_epochs):
            for batch in train_data:
                normal_features = self.encoder(batch['normal_images'])
                defect_features = self.encoder(batch['defect_images'])

                loss = self.prototype_layer.compute_loss(
                    normal_features, defect_features, batch['defect_types']
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def detect(self, query_image, support_defects, threshold=None):
        """Return ``(is_defect, min_distance)`` for one un-batched *query_image*.

        ``min_distance`` is the smallest Euclidean feature distance to any
        support defect (smaller = more defect-like); ``is_defect`` is
        ``min_distance < threshold``. *threshold* defaults to ``self.threshold``.

        Raises:
            ValueError: if no threshold is passed and none is configured.
        """
        if threshold is None:
            threshold = getattr(self, 'threshold', None)
        if threshold is None:
            raise ValueError(
                "detect() needs a distance threshold: pass threshold=... or set self.threshold"
            )

        self.encoder.eval()

        with torch.no_grad():
            query_feature = self.encoder(query_image.unsqueeze(0))

            defect_features = self.encoder(support_defects)

            distances = torch.cdist(query_feature, defect_features)

            min_distance = distances.min()

            is_defect = min_distance < threshold

        return is_defect, min_distance

defect_detector = DefectDetectionFewShot()

defect_detector.train(train_data, num_epochs=100)

# NOTE: detect() compares the minimum feature distance against
# `defect_detector.threshold`, which must be calibrated (e.g. on held-out
# images) and set before this call. `train_data`, `load_defect_samples` and
# `load_image` are placeholders not defined in this snippet.
support_defects = load_defect_samples(['scratch_1.jpg', 'dent_1.jpg'])
query_image = load_image('product_sample.jpg')

is_defect, min_distance = defect_detector.detect(query_image, support_defects)
# The second value is a distance to the nearest known defect (smaller = more
# defect-like), not a confidence score, so label it as such.
print(f"是否缺陷: {bool(is_defect)}, 最小距离: {float(min_distance):.2f}")

机器人应用

新任务学习

python
class RobotTaskLearning:
    """Learn a state->action policy from demonstrations, then fine-tune copies for new tasks.

    NOTE: meta_train() is plain multi-task behaviour cloning (one gradient step
    per task per epoch on that task's demonstrations), not a bi-level
    meta-learning algorithm such as MAML.
    """

    def __init__(self):
        # StateEncoder / PolicyNetwork are defined outside this snippet.
        self.state_encoder = StateEncoder()
        self.policy_network = PolicyNetwork()

    def meta_train(self, tasks, num_epochs=100):
        """Regress demonstrated actions (MSE) across all *tasks*, updating the shared networks."""
        optimizer = torch.optim.Adam(
            list(self.state_encoder.parameters()) +
            list(self.policy_network.parameters()),
            lr=0.001
        )
        criterion = nn.MSELoss()

        for _ in range(num_epochs):
            for task in tasks:
                demonstrations = task['demonstrations']

                state_features = self.state_encoder(demonstrations['states'])
                predicted_actions = self.policy_network(state_features)

                loss = criterion(predicted_actions, demonstrations['actions'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def adapt_to_new_task(self, demonstrations, num_steps=10):
        """Fine-tune deep copies of the networks on *demonstrations*; return ``(encoder, policy)``.

        The instance's own networks are left untouched.
        """
        # Local import: `copy` is not imported anywhere else in this snippet.
        import copy

        temp_encoder = copy.deepcopy(self.state_encoder)
        temp_policy = copy.deepcopy(self.policy_network)

        optimizer = torch.optim.Adam(
            list(temp_encoder.parameters()) +
            list(temp_policy.parameters()),
            lr=0.01
        )
        criterion = nn.MSELoss()

        # Same demonstrations every step; read them once.
        states = demonstrations['states']
        actions = demonstrations['actions']

        for _ in range(num_steps):
            state_features = temp_encoder(states)
            predicted_actions = temp_policy(state_features)

            loss = criterion(predicted_actions, actions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return temp_encoder, temp_policy

# Demo: train on a set of tasks, then adapt to a new task from a few
# demonstrations. `training_tasks`, `new_task` and `collect_demonstrations`
# are placeholders not defined in this snippet.
robot_learner = RobotTaskLearning()

robot_learner.meta_train(training_tasks, num_epochs=100)

# Presumably a dict with 'states' and 'actions' tensors (the keys that
# adapt_to_new_task() reads) — confirm against collect_demonstrations.
new_task_demos = collect_demonstrations(new_task)
# Returns fine-tuned deep copies; robot_learner's own networks stay unchanged.
adapted_encoder, adapted_policy = robot_learner.adapt_to_new_task(
    new_task_demos, num_steps=10
)

下一步

现在你已经了解了 Few-shot Learning 的各种应用场景，接下来请学习《最佳实践指南》，掌握实际项目中的最佳实践！

最后更新：2026-04-05