Few-shot Learning 应用场景与案例 #
计算机视觉应用 #
图像分类 #
text
┌─────────────────────────────────────────────────────────────┐
│ 图像分类应用 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 应用场景: │
│ ├── 医疗影像诊断(罕见病识别) │
│ ├── 工业缺陷检测(新产品缺陷) │
│ ├── 野生动物识别(稀有物种) │
│ └── 卫星图像分类(新地物类型) │
│ │
│ 挑战: │
│ ├── 类内变异大 │
│ ├── 类间相似度高 │
│ ├── 样本质量不一 │
│ └── 标注成本高 │
│ │
└─────────────────────────────────────────────────────────────┘
医疗影像诊断案例 #
python
import torch
import torch.nn as nn
from torchvision import models
class MedicalImageFewShot:
    """Few-shot classifier for medical images.

    An ImageNet-pretrained ResNet-50 backbone produces 2048-d features; a
    PrototypicalNetwork head classifies query images by distance to class
    prototypes built from a small labelled support set.
    """

    def __init__(self, num_classes=5, num_support=5):
        # Backbone with the final FC replaced by Identity, so forward()
        # returns the 2048-d pooled feature vector.
        self.encoder = models.resnet50(pretrained=True)
        self.encoder.fc = nn.Identity()
        # Metric-based head; PrototypicalNetwork is defined later in this module.
        self.classifier = PrototypicalNetwork(
            feature_dim=2048,
            num_classes=num_classes
        )
        self.num_support = num_support

    def train(self, train_data, num_epochs=100):
        """Episodically train the encoder and classifier on `train_data`.

        NOTE(review): relies on self.sample_episodes(...), which is not
        defined in this snippet — it must yield dicts with keys
        'support_images', 'support_labels', 'query_images', 'query_labels'.
        """
        optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.classifier.parameters()),
            lr=0.001
        )
        criterion = nn.CrossEntropyLoss()  # build once, not per step
        for epoch in range(num_epochs):
            for episode in self.sample_episodes(train_data):
                support_images = episode['support_images']
                support_labels = episode['support_labels']
                query_images = episode['query_images']
                query_labels = episode['query_labels']
                support_features = self.encoder(support_images)
                query_features = self.encoder(query_images)
                logits = self.classifier(
                    support_features, support_labels, query_features
                )
                loss = criterion(logits, query_labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict(self, support_images, support_labels, query_images):
        """Classify `query_images` against a labelled support set.

        Returns a 1-D tensor of predicted class indices.
        """
        self.encoder.eval()
        with torch.no_grad():
            support_features = self.encoder(support_images)
            query_features = self.encoder(query_images)
            logits = self.classifier(
                support_features, support_labels, query_features
            )
            predictions = logits.argmax(dim=1)
        return predictions
class PrototypicalNetwork(nn.Module):
    """Prototypical-network head (Snell et al., 2017 style).

    Classifies queries by negative Euclidean distance to per-class
    prototypes, each prototype being the mean support feature of a class.
    """

    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_classes = num_classes

    def forward(self, support_features, support_labels, query_features):
        """Return (num_query, num_classes) logits = -distance to prototypes."""
        prototypes = self.compute_prototypes(support_features, support_labels)
        distances = torch.cdist(query_features, prototypes)
        # Negate so the closest prototype gets the highest logit.
        logits = -distances
        return logits

    def compute_prototypes(self, features, labels):
        """Mean support feature per class, stacked to (num_classes, dim).

        NOTE(review): a class absent from `labels` yields a NaN prototype
        (mean over an empty set) — every class index 0..num_classes-1 is
        assumed to appear in the support set.
        """
        prototypes = []
        for i in range(self.num_classes):
            mask = labels == i
            class_features = features[mask]
            prototype = class_features.mean(dim=0)
            prototypes.append(prototype)
        return torch.stack(prototypes)
# Example usage: episodic training, then few-shot diagnosis of a new scan.
# NOTE(review): `train_data` and `load_medical_images` are not defined in
# this snippet — presumably supplied by the surrounding project.
medical_fewshot = MedicalImageFewShot(num_classes=5, num_support=5)
medical_fewshot.train(train_data, num_epochs=100)
# Support set: one labelled example image per disease class (labels 0 and 1).
support_images = load_medical_images(['disease_A_sample.jpg', 'disease_B_sample.jpg'])
support_labels = torch.tensor([0, 1])
query_images = load_medical_images(['patient_scan.jpg'])
predictions = medical_fewshot.predict(support_images, support_labels, query_images)
人脸识别案例 #
python
class FaceRecognitionFewShot:
    """Few-shot face recognition by cosine similarity to enrolled users.

    Fix: ``self.user_database`` was used by register_user()/recognize()
    but never initialized — the original raised AttributeError on the
    first enrollment. It is now created in __init__.
    """

    def __init__(self):
        self.encoder = self.build_face_encoder()
        self.distance_metric = nn.CosineSimilarity()
        # user_id -> mean face embedding of that user's enrollment images.
        self.user_database = {}

    def build_face_encoder(self):
        """ResNet-50 backbone projected to a 512-d face embedding."""
        model = models.resnet50(pretrained=True)
        model.fc = nn.Linear(2048, 512)
        return model

    def register_user(self, user_id, face_images):
        """Enroll a user from a few face images; returns the mean embedding."""
        features = []
        with torch.no_grad():  # enrollment is inference-only
            for image in face_images:
                feature = self.encoder(image.unsqueeze(0))
                features.append(feature)
        user_feature = torch.stack(features).mean(dim=0)
        self.user_database[user_id] = user_feature
        return user_feature

    def recognize(self, query_image, threshold=0.8):
        """Return (user_id, score) for the best match above `threshold`,
        otherwise (None, score). Score is cosine similarity."""
        with torch.no_grad():
            query_feature = self.encoder(query_image.unsqueeze(0))
        best_match = None
        best_score = -1
        for user_id, user_feature in self.user_database.items():
            score = self.distance_metric(query_feature, user_feature)
            if score > best_score:
                best_score = score
                best_match = user_id
        if best_score > threshold:
            return best_match, best_score
        else:
            return None, best_score
# Example usage: enroll a user from three face crops, then identify a query.
# NOTE(review): image1/image2/image3 and query_image are not defined in
# this snippet.
face_recognition = FaceRecognitionFewShot()
face_recognition.register_user('user_001', [image1, image2, image3])
user_id, confidence = face_recognition.recognize(query_image)
# Prints the matched user id (or None) and the cosine-similarity score.
print(f"识别结果: {user_id}, 置信度: {confidence:.2f}")
自然语言处理应用 #
文本分类 #
text
┌─────────────────────────────────────────────────────────────┐
│ 文本分类应用 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 应用场景: │
│ ├── 意图识别(新意图类别) │
│ ├── 情感分析(新领域) │
│ ├── 新闻分类(新类别) │
│ └── 垃圾邮件检测(新类型) │
│ │
│ 挑战: │
│ ├── 文本多样性 │
│ ├── 语义理解 │
│ ├── 领域适应 │
│ └── 标注成本 │
│ │
└─────────────────────────────────────────────────────────────┘
意图识别案例 #
python
import torch
from transformers import BertModel, BertTokenizer
class IntentRecognitionFewShot:
    """Few-shot intent classifier: BERT [CLS] features + prototypical head."""

    def __init__(self, model_name='bert-base-uncased', num_classes=10):
        self.encoder = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        # PrototypicalNetwork is defined elsewhere in this module;
        # 768 = BERT-base hidden size.
        self.classifier = PrototypicalNetwork(
            feature_dim=768,
            num_classes=num_classes
        )

    def encode_texts(self, texts):
        """Tokenize `texts` and return one [CLS] embedding per text."""
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        outputs = self.encoder(**encoded)
        # [CLS] token (position 0) as the sentence representation.
        features = outputs.last_hidden_state[:, 0, :]
        return features

    def train(self, train_data, num_epochs=50):
        """Episodic fine-tuning of encoder + classifier.

        NOTE(review): relies on self.sample_episodes(...), which is not
        defined in this snippet.
        """
        optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.classifier.parameters()),
            lr=2e-5
        )
        criterion = nn.CrossEntropyLoss()  # build once, not per step
        for epoch in range(num_epochs):
            for episode in self.sample_episodes(train_data):
                support_texts = episode['support_texts']
                support_labels = episode['support_labels']
                query_texts = episode['query_texts']
                query_labels = episode['query_labels']
                support_features = self.encode_texts(support_texts)
                query_features = self.encode_texts(query_texts)
                logits = self.classifier(
                    support_features, support_labels, query_features
                )
                loss = criterion(logits, query_labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict(self, support_texts, support_labels, query_texts):
        """Classify query texts against a small labelled support set.

        Returns a 1-D tensor of predicted intent indices.
        """
        self.encoder.eval()
        with torch.no_grad():
            support_features = self.encode_texts(support_texts)
            query_features = self.encode_texts(query_texts)
            logits = self.classifier(
                support_features, support_labels, query_features
            )
            predictions = logits.argmax(dim=1)
        return predictions
# Example usage: classify a query utterance against three support intents.
# NOTE(review): predict() is called without any prior training here, so
# the result relies on pretrained BERT features alone.
intent_recognition = IntentRecognitionFewShot(num_classes=10)
# One support example per intent: 0 = book flight, 1 = weather, 2 = play music.
support_texts = [
"我想预订一张机票",
"帮我查一下天气",
"播放一首音乐"
]
support_labels = torch.tensor([0, 1, 2])
query_texts = ["我想订票去北京"]
predictions = intent_recognition.predict(support_texts, support_labels, query_texts)
print(f"预测意图: {predictions}")
命名实体识别案例 #
python
class NERFewShot:
    """Few-shot named-entity tagger: BERT encoder + linear per-token tag head.

    Fix: the original read an undefined global ``num_tags`` (NameError at
    construction); it is now a constructor parameter, stored on the
    instance and used consistently.
    """

    def __init__(self, model_name='bert-base-uncased', num_tags=9):
        # num_tags=9 matches a common BIO tag set (e.g. CoNLL-2003).
        # NOTE(review): confirm the right tag count for this project.
        self.encoder = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.num_tags = num_tags
        self.tag_classifier = nn.Linear(768, num_tags)

    def train(self, train_data, num_epochs=30):
        """Fine-tune encoder + tag head with per-token cross-entropy.

        Batches are dicts with 'texts' (list of strings) and 'tags'
        (LongTensor aligned with the tokenized sequence).
        """
        optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.tag_classifier.parameters()),
            lr=2e-5
        )
        criterion = nn.CrossEntropyLoss()  # build once, not per step
        for epoch in range(num_epochs):
            for batch in train_data:
                texts = batch['texts']
                tags = batch['tags']
                encoded = self.tokenizer(
                    texts,
                    padding=True,
                    truncation=True,
                    return_tensors='pt'
                )
                outputs = self.encoder(**encoded)
                sequence_features = outputs.last_hidden_state
                logits = self.tag_classifier(sequence_features)
                # Flatten (batch, seq, tags) -> (batch*seq, tags) for CE.
                # NOTE(review): padded positions are not masked out of the
                # loss — tag padding should use CrossEntropyLoss's
                # ignore_index; verify against the data pipeline.
                loss = criterion(
                    logits.view(-1, self.num_tags),
                    tags.view(-1)
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict(self, text):
        """Return per-token tag indices for a single input string."""
        self.encoder.eval()
        with torch.no_grad():
            encoded = self.tokenizer(
                text,
                return_tensors='pt'
            )
            outputs = self.encoder(**encoded)
            sequence_features = outputs.last_hidden_state
            logits = self.tag_classifier(sequence_features)
            predictions = logits.argmax(dim=-1)
        return predictions
推荐系统应用 #
冷启动问题 #
text
┌─────────────────────────────────────────────────────────────┐
│ 推荐系统应用 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 应用场景: │
│ ├── 新用户推荐 │
│ ├── 新物品推荐 │
│ ├── 跨领域推荐 │
│ └── 实时个性化 │
│ │
│ 挑战: │
│ ├── 用户行为稀疏 │
│ ├── 物品特征缺失 │
│ ├── 冷启动问题 │
│ └── 实时性要求 │
│ │
└─────────────────────────────────────────────────────────────┘
新用户推荐案例 #
python
class UserColdStartFewShot:
    """Two-tower recommender for new users: encode user and item features
    into a shared 128-d space and score user-item pairs by dot product."""

    def __init__(self, num_items, feature_dim=128):
        # User tower: raw user features -> 128-d embedding.
        self.user_encoder = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        # Item tower: same architecture, separate weights.
        self.item_encoder = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        self.num_items = num_items

    def train(self, train_data, num_epochs=100):
        """Optimize both towers on implicit-feedback batches.

        Each batch is a dict with 'user_features' (U, D), 'item_features'
        (I, D) and 'interactions' (U, I) float {0, 1} labels.
        """
        optimizer = torch.optim.Adam(
            list(self.user_encoder.parameters()) +
            list(self.item_encoder.parameters()),
            lr=0.001
        )
        criterion = nn.BCEWithLogitsLoss()  # build once, not per step
        for epoch in range(num_epochs):
            for batch in train_data:
                user_features = batch['user_features']
                item_features = batch['item_features']
                interactions = batch['interactions']
                user_embeddings = self.user_encoder(user_features)
                item_embeddings = self.item_encoder(item_features)
                # All user-item pair logits in one matmul: (U, I).
                scores = torch.matmul(user_embeddings, item_embeddings.t())
                loss = criterion(scores, interactions)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def recommend(self, user_feature, item_features, top_k=10):
        """Return indices (shape (1, top_k)) of the highest-scoring items
        for a single user's feature vector."""
        self.user_encoder.eval()
        self.item_encoder.eval()
        with torch.no_grad():
            user_embedding = self.user_encoder(user_feature.unsqueeze(0))
            item_embeddings = self.item_encoder(item_features)
            scores = torch.matmul(user_embedding, item_embeddings.t())
            top_k_indices = scores.topk(top_k).indices
        return top_k_indices
# Example usage: train the two-tower model, then recommend for a new user.
# NOTE(review): train_data, extract_user_features, extract_all_item_features
# and new_user are not defined in this snippet.
recommender = UserColdStartFewShot(num_items=10000, feature_dim=128)
recommender.train(train_data, num_epochs=100)
new_user_feature = extract_user_features(new_user)
item_features = extract_all_item_features()
recommendations = recommender.recommend(new_user_feature, item_features, top_k=10)
print(f"推荐物品: {recommendations}")
新物品推荐案例 #
python
class ItemColdStartFewShot:
    """Recommender for brand-new items: an item-feature tower is matched
    against learned per-user embeddings by dot product."""

    def __init__(self, num_users, feature_dim=128):
        # Item tower: raw item features -> 128-d embedding.
        self.item_encoder = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        # One learned 128-d embedding per known user.
        self.user_embeddings = nn.Embedding(num_users, 128)

    def train(self, train_data, num_epochs=100):
        """Train on batches of (item_features, user_ids, interactions),
        where 'interactions' are float {0, 1} implicit-feedback labels
        aligned row-wise with item_features/user_ids."""
        optimizer = torch.optim.Adam(
            list(self.item_encoder.parameters()) +
            list(self.user_embeddings.parameters()),
            lr=0.001
        )
        criterion = nn.BCEWithLogitsLoss()  # build once, not per step
        for epoch in range(num_epochs):
            for batch in train_data:
                item_features = batch['item_features']
                user_ids = batch['user_ids']
                interactions = batch['interactions']
                item_embeddings = self.item_encoder(item_features)
                user_embeddings = self.user_embeddings(user_ids)
                # Per-pair logit: dot product of aligned rows.
                scores = torch.sum(item_embeddings * user_embeddings, dim=1)
                loss = criterion(scores, interactions)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict_users(self, new_item_feature, num_users, top_k=100):
        """Return ids (1-D tensor of length top_k) of the users most
        likely to engage with a new item.

        `num_users` is unused — kept for interface compatibility; the
        user count comes from the embedding table itself.
        """
        self.item_encoder.eval()
        with torch.no_grad():
            item_embedding = self.item_encoder(new_item_feature.unsqueeze(0))
            all_user_embeddings = self.user_embeddings.weight
            scores = torch.matmul(all_user_embeddings, item_embedding.t()).squeeze()
            top_k_users = scores.topk(top_k).indices
        return top_k_users
医疗健康应用 #
罕见病诊断 #
python
class RareDiseaseDiagnosis:
    """Match patient symptom vectors to disease-profile vectors in a
    shared 128-d embedding space.

    Fix: the original read undefined globals ``num_symptoms`` and
    ``num_disease_features`` (NameError at construction); they are now
    constructor parameters with defaults.
    """

    def __init__(self, num_diseases=50, num_symptoms=100,
                 num_disease_features=128):
        # NOTE(review): the defaults for num_symptoms/num_disease_features
        # are placeholders — set them to the project's real feature sizes.
        self.symptom_encoder = nn.Sequential(
            nn.Linear(num_symptoms, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        self.disease_encoder = nn.Sequential(
            nn.Linear(num_disease_features, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        self.num_diseases = num_diseases

    def train(self, train_data, num_epochs=100):
        """Train both encoders so a patient's symptom embedding is most
        similar to the embedding of the correct disease.

        Batches are dicts with 'symptoms' (B, num_symptoms),
        'disease_features' (C, num_disease_features) and integer 'labels' (B,).
        """
        optimizer = torch.optim.Adam(
            list(self.symptom_encoder.parameters()) +
            list(self.disease_encoder.parameters()),
            lr=0.001
        )
        criterion = nn.CrossEntropyLoss()  # build once, not per step
        for epoch in range(num_epochs):
            for batch in train_data:
                symptoms = batch['symptoms']
                disease_features = batch['disease_features']
                labels = batch['labels']
                symptom_embeddings = self.symptom_encoder(symptoms)
                disease_embeddings = self.disease_encoder(disease_features)
                # (B, C) similarity matrix used directly as logits.
                similarities = torch.matmul(symptom_embeddings, disease_embeddings.t())
                loss = criterion(similarities, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def diagnose(self, patient_symptoms, known_disease_features):
        """Return torch.topk (values, indices) of the 5 most probable
        diseases for one patient's symptom vector."""
        self.symptom_encoder.eval()
        self.disease_encoder.eval()
        with torch.no_grad():
            symptom_embedding = self.symptom_encoder(patient_symptoms.unsqueeze(0))
            disease_embeddings = self.disease_encoder(known_disease_features)
            similarities = torch.matmul(symptom_embedding, disease_embeddings.t())
            probabilities = torch.softmax(similarities, dim=1)
            top_diseases = probabilities.topk(5)
        return top_diseases
# Example usage: train, then rank candidate rare diseases for a patient.
# NOTE(review): train_data, extract_symptoms, get_rare_disease_features and
# patient are not defined in this snippet.
diagnosis_system = RareDiseaseDiagnosis(num_diseases=50)
diagnosis_system.train(train_data, num_epochs=100)
patient_symptoms = extract_symptoms(patient)
known_diseases = get_rare_disease_features()
top_diseases = diagnosis_system.diagnose(patient_symptoms, known_diseases)
print(f"可能的疾病: {top_diseases}")
药物发现 #
python
class DrugDiscoveryFewShot:
    """Predict molecule-target binding by scoring molecule embeddings
    against target embeddings with a dot product."""

    def __init__(self):
        # Both encoders are project-defined modules (not in this snippet).
        self.molecule_encoder = MoleculeEncoder()
        self.target_encoder = TargetEncoder()

    def train(self, train_data, num_epochs=100):
        """Train both encoders on batches of (molecules, targets,
        interactions); 'interactions' is a float {0, 1} matrix of
        observed binding between each molecule-target pair."""
        optimizer = torch.optim.Adam(
            list(self.molecule_encoder.parameters()) +
            list(self.target_encoder.parameters()),
            lr=0.001
        )
        criterion = nn.BCEWithLogitsLoss()  # build once, not per step
        for epoch in range(num_epochs):
            for batch in train_data:
                molecules = batch['molecules']
                targets = batch['targets']
                interactions = batch['interactions']
                molecule_embeddings = self.molecule_encoder(molecules)
                target_embeddings = self.target_encoder(targets)
                # All molecule-target pair logits in one matmul.
                scores = torch.matmul(molecule_embeddings, target_embeddings.t())
                loss = criterion(scores, interactions)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def predict_binding(self, new_molecule, known_targets):
        """Score one new molecule against known targets; returns the
        top-10 (values, indices) pair from torch.topk."""
        self.molecule_encoder.eval()
        self.target_encoder.eval()
        with torch.no_grad():
            molecule_embedding = self.molecule_encoder(new_molecule.unsqueeze(0))
            target_embeddings = self.target_encoder(known_targets)
            binding_scores = torch.matmul(molecule_embedding, target_embeddings.t())
            top_targets = binding_scores.topk(10)
        return top_targets
工业应用 #
缺陷检测 #
python
class DefectDetectionFewShot:
    """Few-shot industrial defect detector: flag a query image as
    defective if its feature lies close to any known defect example.

    Fixes: ``self.threshold`` was used in detect() but never set
    (AttributeError) — it is now a constructor parameter; the prototype
    layer's parameters were omitted from the optimizer and are now
    trained alongside the encoder.
    """

    def __init__(self, threshold=0.5):
        # NOTE(review): the default threshold is a placeholder — calibrate
        # it on a validation set; the distance scale depends on the
        # feature space.
        self.encoder = models.resnet50(pretrained=True)
        self.encoder.fc = nn.Identity()
        # PrototypeLayer is assumed to be defined elsewhere in the project.
        self.prototype_layer = PrototypeLayer(feature_dim=2048)
        self.threshold = threshold

    def train(self, train_data, num_epochs=100):
        """Train encoder + prototype layer on normal/defect image batches."""
        optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.prototype_layer.parameters()),
            lr=0.001
        )
        for epoch in range(num_epochs):
            for batch in train_data:
                normal_images = batch['normal_images']
                defect_images = batch['defect_images']
                defect_types = batch['defect_types']
                normal_features = self.encoder(normal_images)
                defect_features = self.encoder(defect_images)
                loss = self.prototype_layer.compute_loss(
                    normal_features, defect_features, defect_types
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def detect(self, query_image, support_defects):
        """Return (is_defect, min_distance): defective when the query's
        minimum Euclidean feature distance to any support defect is
        below self.threshold."""
        self.encoder.eval()
        with torch.no_grad():
            query_feature = self.encoder(query_image.unsqueeze(0))
            defect_features = self.encoder(support_defects)
            distances = torch.cdist(query_feature, defect_features)
            min_distance = distances.min()
            is_defect = min_distance < self.threshold
        return is_defect, min_distance
# Example usage: train, then test a product image against known defect samples.
# NOTE(review): train_data, load_defect_samples and load_image are not
# defined in this snippet. Also note `confidence` here is the minimum
# feature distance (smaller = more defect-like), not a probability.
defect_detector = DefectDetectionFewShot()
defect_detector.train(train_data, num_epochs=100)
support_defects = load_defect_samples(['scratch_1.jpg', 'dent_1.jpg'])
query_image = load_image('product_sample.jpg')
is_defect, confidence = defect_detector.detect(query_image, support_defects)
print(f"是否缺陷: {is_defect}, 置信度: {confidence:.2f}")
机器人应用 #
新任务学习 #
python
class RobotTaskLearning:
    """Meta-learn a manipulation policy from demonstrations, then adapt
    copies of the networks to a new task with a few gradient steps.

    Fix: adapt_to_new_task() used ``copy.deepcopy`` but ``copy`` is never
    imported in this file — the import is now done locally in the method.
    """

    def __init__(self):
        # StateEncoder / PolicyNetwork are project-defined modules
        # (not in this snippet).
        self.state_encoder = StateEncoder()
        self.policy_network = PolicyNetwork()

    def meta_train(self, tasks, num_epochs=100):
        """Behavior-cloning pretraining across tasks (MSE on actions).

        Each task dict carries 'demonstrations' with 'states' and
        'actions' tensors.
        """
        optimizer = torch.optim.Adam(
            list(self.state_encoder.parameters()) +
            list(self.policy_network.parameters()),
            lr=0.001
        )
        criterion = nn.MSELoss()  # build once, not per step
        for epoch in range(num_epochs):
            for task in tasks:
                demonstrations = task['demonstrations']
                states = demonstrations['states']
                actions = demonstrations['actions']
                state_features = self.state_encoder(states)
                predicted_actions = self.policy_network(state_features)
                loss = criterion(predicted_actions, actions)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def adapt_to_new_task(self, demonstrations, num_steps=10):
        """Fine-tune deep copies on a new task's demos and return
        (adapted_encoder, adapted_policy); the meta-trained networks
        are left untouched."""
        import copy  # local import: `copy` is not imported at module level

        temp_encoder = copy.deepcopy(self.state_encoder)
        temp_policy = copy.deepcopy(self.policy_network)
        optimizer = torch.optim.Adam(
            list(temp_encoder.parameters()) +
            list(temp_policy.parameters()),
            lr=0.01  # larger LR than meta-training: few steps, little data
        )
        criterion = nn.MSELoss()
        for _ in range(num_steps):
            states = demonstrations['states']
            actions = demonstrations['actions']
            state_features = temp_encoder(states)
            predicted_actions = temp_policy(state_features)
            loss = criterion(predicted_actions, actions)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return temp_encoder, temp_policy
# Example usage: meta-train across tasks, then adapt to a new task's demos.
# NOTE(review): training_tasks, collect_demonstrations and new_task are not
# defined in this snippet.
robot_learner = RobotTaskLearning()
robot_learner.meta_train(training_tasks, num_epochs=100)
new_task_demos = collect_demonstrations(new_task)
adapted_encoder, adapted_policy = robot_learner.adapt_to_new_task(
new_task_demos, num_steps=10
)
下一步 #
现在你已经了解了 Few-shot Learning 的各种应用场景,接下来学习 最佳实践指南,掌握实际项目中的最佳实践!
最后更新:2026-04-05