Data Preparation #
The Importance of Data #
text
┌─────────────────────────────────────────────────────────────┐
│ Data determines the model's ceiling                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│ Good data + a simple model > bad data + a complex model     │
│                                                             │
│ Data quality affects:                                       │
│ ├── The upper bound on model performance                    │
│ ├── Training stability                                      │
│ ├── Generalization                                          │
│ └── Real-world effectiveness                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘
Data Collection #
Data Sources #
text
Public datasets:
├── Hugging Face Datasets
│   └── https://huggingface.co/datasets
├── Kaggle
│   └── https://www.kaggle.com/datasets
├── Google Dataset Search
│   └── https://datasetsearch.research.google.com
└── Papers With Code
    └── https://paperswithcode.com/datasets
In-house data:
├── Business logs
├── User feedback
├── Internal documents
├── Expert annotations
└── Web crawling
Synthetic data:
├── GPT-4 generation
├── Data augmentation
├── Template generation (sketched below)
└── Knowledge distillation
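Template generation is the cheapest of the three; the sketch below shows the idea, with hypothetical templates and slot values chosen purely for illustration:
python
import random

# Hypothetical templates and slot values; in practice these come from your task.
TEMPLATES = [
    "What's the weather like in {city} today?",
    "Recommend some food worth trying in {city}.",
]
CITIES = ["Beijing", "Shanghai", "Shenzhen"]

def generate_from_templates(n=10):
    """Fill templates with random slot values to produce synthetic samples."""
    return [
        {"text": random.choice(TEMPLATES).format(city=random.choice(CITIES))}
        for _ in range(n)
    ]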
Data Collection Strategies #
python
from datasets import load_dataset, Dataset
import json
import requests

def collect_from_huggingface(dataset_name, subset=None):
    """Load a dataset (optionally a subset/config) from the Hugging Face Hub."""
    dataset = load_dataset(dataset_name, subset)
    return dataset

def collect_from_local(file_path):
    """Load a local JSONL file, one JSON object per line."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    return Dataset.from_list(data)

def collect_from_api(api_url, headers=None):
    """Fetch a list of records from an HTTP API endpoint."""
    response = requests.get(api_url, headers=headers)
    response.raise_for_status()  # fail fast on HTTP errors
    data = response.json()
    return Dataset.from_list(data)
Data Cleaning #
Cleaning Pipeline #
text
┌─────────────────────────────────────────────────────────────┐
│ Data Cleaning Pipeline                                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│ Raw data                                                    │
│    │                                                        │
│    ▼                                                        │
│ Deduplicate ────────── remove duplicate records             │
│    │                                                        │
│    ▼                                                        │
│ Denoise ────────────── remove noisy records                 │
│    │                                                        │
│    ▼                                                        │
│ Unify format ───────── standardize the format               │
│    │                                                        │
│    ▼                                                        │
│ Quality filter ─────── drop low-quality records             │
│    │                                                        │
│    ▼                                                        │
│ PII handling ───────── redact sensitive info                │
│    │                                                        │
│    ▼                                                        │
│ Cleaned data                                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘
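The sections that follow implement each stage. A minimal sketch of chaining them into a single pass (the function names refer to the implementations below):
python
def clean_pipeline(dataset):
    """Run the cleaning stages in pipeline order."""
    dataset = deduplicate_dataset(dataset)   # deduplicate
    dataset = remove_noise(dataset)          # denoise / unify format
    dataset = filter_short_texts(dataset)    # quality filter: drop fragments
    dataset = filter_long_texts(dataset)     # quality filter: drop outliers
    dataset = anonymize_dataset(dataset)     # PII handling
    return dataset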
Deduplication #
python
import hashlib
from datasets import Dataset

def deduplicate_dataset(dataset, text_column='text'):
    """Exact deduplication: keep the first occurrence of each MD5 text hash."""
    seen = set()
    unique_indices = []
    for idx, example in enumerate(dataset):
        text = example[text_column]
        text_hash = hashlib.md5(text.encode()).hexdigest()
        if text_hash not in seen:
            seen.add(text_hash)
            unique_indices.append(idx)
    return dataset.select(unique_indices)
def deduplicate_by_similarity(dataset, text_column='text', threshold=0.95):
    """Near-duplicate removal via pairwise TF-IDF cosine similarity.

    Note: builds an n x n similarity matrix, so this is only practical
    for small datasets.
    """
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
    texts = [example[text_column] for example in dataset]
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(texts)
    similarity_matrix = cosine_similarity(tfidf_matrix)
    keep_indices = []
    for i in range(len(texts)):
        # Keep example i only if it is not too similar to any already-kept example
        if all(similarity_matrix[i][j] <= threshold for j in keep_indices):
            keep_indices.append(i)
    return dataset.select(keep_indices)
Denoising #
python
import re

def clean_text(text):
    """Strip HTML tags, URLs, and redundant whitespace."""
    text = re.sub(r'<[^>]+>', '', text)         # HTML tags
    text = re.sub(r'http\S+|www\S+', '', text)  # URLs
    text = re.sub(r'\s+', ' ', text)            # collapse whitespace
    return text.strip()

def remove_noise(dataset, text_column='text'):
    def clean_example(example):
        example[text_column] = clean_text(example[text_column])
        return example
    return dataset.map(clean_example)

def filter_short_texts(dataset, text_column='text', min_length=10):
    return dataset.filter(lambda x: len(x[text_column]) >= min_length)

def filter_long_texts(dataset, text_column='text', max_length=10000):
    return dataset.filter(lambda x: len(x[text_column]) <= max_length)
Handling Sensitive Information #
python
import re

def remove_pii(text):
    """Replace common PII patterns (mainland-China formats) with placeholders."""
    text = re.sub(r'\b\d{11}\b', '[PHONE]', text)                  # 11-digit mobile numbers
    text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w+\b', '[EMAIL]', text)  # email addresses
    text = re.sub(r'\b\d{17}[\dXx]\b', '[ID_CARD]', text)          # 18-character resident IDs
    text = re.sub(r'\b\d{16,19}\b', '[BANK_CARD]', text)           # bank card numbers
    return text

def anonymize_dataset(dataset, text_column='text'):
    def anonymize_example(example):
        example[text_column] = remove_pii(example[text_column])
        return example
    return dataset.map(anonymize_example)
Data Formats #
Common Data Formats #
text
1. Text classification
{
    "text": "This is an article about technology",
    "label": "technology"
}
2. Question answering
{
    "question": "What is machine learning?",
    "answer": "Machine learning is a branch of artificial intelligence..."
}
3. Conversation
{
    "conversations": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hello! How can I help you?"}
    ]
}
4. Instruction
{
    "instruction": "Translate the following sentence into English",
    "input": "你好,世界",
    "output": "Hello, World"
}
5. Preference (RLHF/DPO)
{
    "prompt": "What is machine learning?",
    "chosen": "Machine learning is a branch of artificial intelligence...",
    "rejected": "Machine learning is just machines learning..."
}
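A small validator can catch records that drift from these schemas before training. A sketch, with the required-field sets inferred from the examples above:
python
REQUIRED_FIELDS = {
    "classification": {"text", "label"},
    "qa": {"question", "answer"},
    "conversation": {"conversations"},
    "instruction": {"instruction", "input", "output"},
    "preference": {"prompt", "chosen", "rejected"},
}

def validate_format(example, fmt):
    """Return the set of missing fields (an empty set means the record conforms)."""
    return REQUIRED_FIELDS[fmt] - example.keys()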
Format Conversion #
python
def convert_to_instruction_format(dataset, instruction_template):
    """Render each example into the instruction format via a template string."""
    def convert_example(example):
        instruction = instruction_template.format(**example)
        return {
            "instruction": instruction,
            "input": "",
            "output": example.get("output", "")
        }
    # remove_columns drops the original fields so only the new schema remains
    return dataset.map(convert_example, remove_columns=dataset.column_names)

def convert_to_conversation_format(dataset):
    """Turn QA pairs into single-turn conversations."""
    def convert_example(example):
        return {
            "conversations": [
                {"role": "user", "content": example["question"]},
                {"role": "assistant", "content": example["answer"]}
            ]
        }
    return dataset.map(convert_example, remove_columns=dataset.column_names)
Dataset Splitting #
python
from datasets import DatasetDict

def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42):
    """Split into train/validation/test with two successive train_test_split calls."""
    # First split off the combined validation + test portion
    train_test = dataset.train_test_split(
        test_size=val_ratio + test_ratio,
        seed=seed
    )
    # Then split that portion into validation and test
    val_test = train_test['test'].train_test_split(
        test_size=test_ratio / (val_ratio + test_ratio),
        seed=seed
    )
    return DatasetDict({
        'train': train_test['train'],
        'validation': val_test['train'],
        'test': val_test['test']
    })
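A usage sketch (dataset here stands for any Hugging Face Dataset; the default ratios give an 80/10/10 split):
python
splits = split_dataset(dataset)
print({name: len(split) for name, split in splits.items()})
# e.g. {'train': 8000, 'validation': 1000, 'test': 1000}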
Data Quality Evaluation #
Quality Metrics #
text
Data quality dimensions:
├── Completeness
│   ├── No missing values
│   ├── All fields present
│   └── Correct format
│
├── Accuracy
│   ├── Correct labels
│   ├── Truthful content
│   └── Sound logic
│
├── Consistency
│   ├── Uniform annotation standards
│   ├── Uniform format
│   └── Uniform style
│
├── Diversity
│   ├── Covers varied cases
│   ├── Avoids repetition
│   └── Balanced distribution
│
└── Relevance
    ├── Relevant to the task
    ├── High information density
    └── Free of noise
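The checks below cover length statistics and label balance; completeness, the first dimension, can be spot-checked with a sketch like this:
python
def check_completeness(dataset):
    """Count records where any column is missing or an empty string."""
    incomplete = sum(
        1 for example in dataset
        if any(v is None or (isinstance(v, str) and not v.strip())
               for v in example.values())
    )
    return {"incomplete": incomplete, "total": len(dataset)}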
Quality Checks #
python
from collections import Counter

def check_data_quality(dataset):
    """Compute basic size/length statistics, plus label distribution if present."""
    stats = {
        'total_samples': len(dataset),
        'avg_text_length': sum(len(example['text']) for example in dataset) / len(dataset),
        'min_text_length': min(len(example['text']) for example in dataset),
        'max_text_length': max(len(example['text']) for example in dataset),
    }
    if 'label' in dataset.column_names:
        label_counts = Counter(example['label'] for example in dataset)
        stats['label_distribution'] = dict(label_counts)
    return stats

def detect_anomalies(dataset, text_column='text'):
    """Flag samples that are suspiciously short, long, or question-heavy."""
    anomalies = []
    for idx, example in enumerate(dataset):
        text = example[text_column]
        if len(text) < 5:
            anomalies.append({'index': idx, 'type': 'too_short', 'text': text})
        elif len(text) > 10000:
            anomalies.append({'index': idx, 'type': 'too_long', 'text': text[:100]})
        elif text.count('?') > 10:  # fullwidth question mark, common in Chinese text
            anomalies.append({'index': idx, 'type': 'many_questions', 'text': text[:100]})
    return anomalies
Data Augmentation #
Basic Augmentation Methods #
python
import nlpaug.augmenter.word as naw

def synonym_replacement(text, n=2):
    """Replace words with WordNet synonyms (requires the nltk wordnet data).

    Note: augment returns a list of augmented strings in recent nlpaug versions.
    """
    aug = naw.SynonymAug(aug_src='wordnet')
    return aug.augment(text, n=n)

def random_swap(text, n=2):
    """Randomly swap adjacent words."""
    aug = naw.RandomWordAug(action='swap')
    return aug.augment(text, n=n)

def random_delete(text, p=0.1):
    """Randomly delete words with probability p."""
    aug = naw.RandomWordAug(action='delete', aug_p=p)
    return aug.augment(text)

def back_translation(text, src_lang='zh', tgt_lang='en'):
    """Augment by translating to a pivot language and back.

    Note: this reloads both MarianMT models on every call; cache them for batch use.
    """
    from transformers import MarianMTModel, MarianTokenizer
    # Forward translation (e.g. Helsinki-NLP/opus-mt-zh-en)
    model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    translated = model.generate(**tokenizer(text, return_tensors="pt", padding=True))
    translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
    # Back translation (e.g. Helsinki-NLP/opus-mt-en-zh)
    model_name_back = f'Helsinki-NLP/opus-mt-{tgt_lang}-{src_lang}'
    tokenizer_back = MarianTokenizer.from_pretrained(model_name_back)
    model_back = MarianMTModel.from_pretrained(model_name_back)
    back_translated = model_back.generate(
        **tokenizer_back(translated_text, return_tensors="pt", padding=True))
    return tokenizer_back.decode(back_translated[0], skip_special_tokens=True)
LLM-Assisted Augmentation #
python
from openai import OpenAI
from datasets import Dataset

def llm_augmentation(text, task_description, model="gpt-4"):
    """Ask an LLM for a paraphrase with the same meaning but different wording."""
    client = OpenAI()
    prompt = f"""
Task: {task_description}
Original text: {text}
Please generate a semantically similar version with different wording:
"""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7
    )
    return response.choices[0].message.content

def generate_variations(dataset, n_variations=2):
    """Keep each original sample and append n_variations LLM paraphrases."""
    augmented_data = []
    for example in dataset:
        augmented_data.append(example)
        for _ in range(n_variations):
            variation = llm_augmentation(
                example['text'],
                "Paraphrase the following text while preserving its meaning"
            )
            augmented_data.append({
                'text': variation,
                'label': example.get('label')
            })
    return Dataset.from_list(augmented_data)
Data Balancing #
Class Balancing #
python
from collections import Counter
import random
from datasets import Dataset

def balance_dataset(dataset, label_column='label', strategy='oversample'):
    """Balance class counts by oversampling or undersampling."""
    label_counts = Counter(example[label_column] for example in dataset)
    if strategy == 'oversample':
        # Duplicate minority-class samples up to the majority count
        max_count = max(label_counts.values())
        balanced_data = []
        for label in label_counts:
            label_data = [ex for ex in dataset if ex[label_column] == label]
            while len(label_data) < max_count:
                label_data.extend(random.sample(
                    label_data, min(len(label_data), max_count - len(label_data))))
            balanced_data.extend(label_data)
    elif strategy == 'undersample':
        # Downsample every class to the minority count
        min_count = min(label_counts.values())
        balanced_data = []
        for label in label_counts:
            label_data = [ex for ex in dataset if ex[label_column] == label]
            balanced_data.extend(random.sample(label_data, min_count))
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
    random.shuffle(balanced_data)
    return Dataset.from_list(balanced_data)
Length Balancing #
python
import random
from datasets import Dataset

def balance_by_length(dataset, text_column='text', n_bins=5):
    """Bucket samples by text length and oversample each bucket to the same size."""
    lengths = [len(example[text_column]) for example in dataset]
    min_len, max_len = min(lengths), max(lengths)
    bin_size = max((max_len - min_len) / n_bins, 1)  # avoid division by zero
    bins = [[] for _ in range(n_bins)]
    for example in dataset:
        length = len(example[text_column])
        bin_idx = min(int((length - min_len) / bin_size), n_bins - 1)
        bins[bin_idx].append(example)
    max_bin_size = max(len(b) for b in bins)
    balanced_data = []
    for bin_data in bins:
        # Skip empty bins, which would otherwise loop forever
        while bin_data and len(bin_data) < max_bin_size:
            bin_data.extend(random.sample(
                bin_data, min(len(bin_data), max_bin_size - len(bin_data))))
        balanced_data.extend(bin_data)
    random.shuffle(balanced_data)
    return Dataset.from_list(balanced_data)
Data Preprocessing #
Tokenization #
python
from transformers import AutoTokenizer

def tokenize_dataset(dataset, tokenizer_name, text_column='text', max_length=512):
    """Tokenize a text column with fixed-length padding and truncation."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    def tokenize_function(examples):
        return tokenizer(
            examples[text_column],
            padding='max_length',
            truncation=True,
            max_length=max_length
        )
    return dataset.map(tokenize_function, batched=True)
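Padding everything to max_length wastes compute on short texts; a common alternative is to truncate only and let a collator pad each batch dynamically at training time. A sketch (the model name is a placeholder; use your own tokenizer):
python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model
collator = DataCollatorWithPadding(tokenizer=tokenizer)  # pads each batch to its longest sample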
Data Normalization #
python
import re
import unicodedata

def normalize_text(text):
    """Normalize Unicode forms, case, and whitespace."""
    text = unicodedata.normalize('NFKC', text)  # unify fullwidth/halfwidth variants
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def normalize_dataset(dataset, text_column='text'):
    def normalize_example(example):
        example[text_column] = normalize_text(example[text_column])
        return example
    return dataset.map(normalize_example)
Data Version Management #
python
import json
import hashlib
from datetime import datetime

def save_dataset_version(dataset, output_dir, version_name, metadata=None):
    """Save a dataset snapshot plus a manifest describing the version."""
    dataset.save_to_disk(f"{output_dir}/{version_name}")
    manifest = {
        'version': version_name,
        'created_at': datetime.now().isoformat(),
        'num_samples': len(dataset),
        'columns': dataset.column_names,
        'metadata': metadata or {}
    }
    with open(f"{output_dir}/{version_name}/manifest.json", 'w') as f:
        json.dump(manifest, f, indent=2)

def compute_data_hash(dataset):
    """Fingerprint the dataset contents with a SHA-256 hash."""
    data_str = json.dumps([dict(example) for example in dataset], sort_keys=True)
    return hashlib.sha256(data_str.encode()).hexdigest()
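A usage sketch (the directory and version names are placeholders):
python
version_hash = compute_data_hash(dataset)
save_dataset_version(
    dataset, "data_versions", "v1.0",
    metadata={"sha256": version_hash, "source": "cleaned"}
)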
Next Steps #
Now that you have the core data preparation skills, continue to Model Selection to pick a suitable base model!
Last updated: 2026-04-05