Data Preparation #

Why Data Matters #

text
┌──────────────────────────────────────────────────────────┐
│              Data sets the model's ceiling               │
├──────────────────────────────────────────────────────────┤
│                                                          │
│  Good data + simple model  >  bad data + complex model   │
│                                                          │
│  Data quality affects:                                   │
│  ├── The upper bound of model performance                │
│  ├── Training stability                                  │
│  ├── Generalization ability                              │
│  └── Real-world effectiveness                            │
│                                                          │
└──────────────────────────────────────────────────────────┘

Data Collection #

Data Sources #

text
Public datasets:
├── Hugging Face Datasets
│   └── https://huggingface.co/datasets
├── Kaggle
│   └── https://www.kaggle.com/datasets
├── Google Dataset Search
│   └── https://datasetsearch.research.google.com
└── Papers With Code
    └── https://paperswithcode.com/datasets

In-house data:
├── Business logs
├── User feedback
├── Internal documents
├── Expert annotations
└── Web crawling

Synthetic data:
├── GPT-4 generation
├── Data augmentation
├── Template-based generation (see the sketch below)
└── Knowledge distillation
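
Template-based generation is the cheapest of these approaches: fill slot values into hand-written templates to mass-produce labeled samples. A minimal sketch, where the templates, slot values, and labels are made-up examples for illustration only:

python
import random

# Hypothetical templates and slot values; replace with ones that fit your task
TEMPLATES = [
    ("How do I {action} my {product}?", "faq"),
    ("My {product} fails to {action}, please help.", "support"),
]
SLOTS = {
    "action": ["reset", "upgrade", "cancel"],
    "product": ["account", "subscription", "device"],
}

def generate_from_templates(n=100):
    samples = []
    for _ in range(n):
        template, label = random.choice(TEMPLATES)
        filled = template.format(**{k: random.choice(v) for k, v in SLOTS.items()})
        samples.append({"text": filled, "label": label})
    return samples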

Data Collection Strategies #

python
from datasets import load_dataset, Dataset
import json

def collect_from_huggingface(dataset_name, subset=None):
    # Download a public dataset (and optional config/subset) from the Hugging Face Hub
    dataset = load_dataset(dataset_name, subset)
    return dataset

def collect_from_local(file_path):
    # Read a local JSONL file (one JSON object per line)
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    return Dataset.from_list(data)

def collect_from_api(api_url, headers=None):
    # Fetch records over HTTP; assumes the response body is a JSON list of objects
    import requests
    response = requests.get(api_url, headers=headers)
    data = response.json()
    return Dataset.from_list(data)

Data Cleaning #

Cleaning Pipeline #

text
┌──────────────────────────────────────────────────────────┐
│                  Data Cleaning Pipeline                  │
├──────────────────────────────────────────────────────────┤
│                                                          │
│  Raw data                                                │
│     │                                                    │
│     ▼                                                    │
│  Deduplicate ──────── remove duplicate samples           │
│     │                                                    │
│     ▼                                                    │
│  Denoise ──────────── remove noisy content               │
│     │                                                    │
│     ▼                                                    │
│  Unify format ─────── normalize the data format          │
│     │                                                    │
│     ▼                                                    │
│  Quality filter ───── drop low-quality samples           │
│     │                                                    │
│     ▼                                                    │
│  Handle PII ───────── scrub sensitive information        │
│     │                                                    │
│     ▼                                                    │
│  Cleaned data                                            │
│                                                          │
└──────────────────────────────────────────────────────────┘

Deduplication #

python
import hashlib
from datasets import Dataset

def deduplicate_dataset(dataset, text_column='text'):
    seen = set()
    unique_indices = []
    
    for idx, example in enumerate(dataset):
        text = example[text_column]
        text_hash = hashlib.md5(text.encode()).hexdigest()
        
        if text_hash not in seen:
            seen.add(text_hash)
            unique_indices.append(idx)
    
    return dataset.select(unique_indices)

def deduplicate_by_similarity(dataset, text_column='text', threshold=0.95):
    # Near-duplicate removal via TF-IDF cosine similarity; builds a full
    # N x N similarity matrix, so this is only practical for small datasets.
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
    
    texts = [example[text_column] for example in dataset]
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(texts)
    
    similarity_matrix = cosine_similarity(tfidf_matrix)
    
    keep_indices = []
    for i in range(len(texts)):
        should_keep = True
        for j in keep_indices:
            if similarity_matrix[i][j] > threshold:
                should_keep = False
                break
        if should_keep:
            keep_indices.append(i)
    
    return dataset.select(keep_indices)

Denoising #

python
import re

def clean_text(text):
    text = re.sub(r'<[^>]+>', '', text)         # strip HTML tags
    text = re.sub(r'http\S+|www\S+', '', text)   # strip URLs
    text = re.sub(r'\s+', ' ', text)             # collapse whitespace
    text = text.strip()
    return text

def remove_noise(dataset, text_column='text'):
    def clean_example(example):
        example[text_column] = clean_text(example[text_column])
        return example
    
    return dataset.map(clean_example)

def filter_short_texts(dataset, text_column='text', min_length=10):
    return dataset.filter(lambda x: len(x[text_column]) >= min_length)

def filter_long_texts(dataset, text_column='text', max_length=10000):
    return dataset.filter(lambda x: len(x[text_column]) <= max_length)

Sensitive Information Handling #

python
import re

def remove_pii(text):
    # The patterns below target common Chinese formats (11-digit mobile numbers,
    # 18-character national ID, 16-19 digit bank cards) plus e-mail addresses
    text = re.sub(r'\b\d{11}\b', '[PHONE]', text)
    text = re.sub(r'\b[\w\.-]+@[\w\.-]+\.\w+\b', '[EMAIL]', text)
    text = re.sub(r'\b\d{17}[\dXx]\b', '[ID_CARD]', text)
    text = re.sub(r'\b\d{16,19}\b', '[BANK_CARD]', text)
    return text

def anonymize_dataset(dataset, text_column='text'):
    def anonymize_example(example):
        example[text_column] = remove_pii(example[text_column])
        return example
    
    return dataset.map(anonymize_example)
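
Putting the steps together: a sketch that chains the helpers defined above in the order of the cleaning pipeline (format unification and text normalization are covered in later sections).

python
def clean_pipeline(dataset, text_column='text'):
    # Order follows the flow diagram: dedup -> denoise -> quality filter -> PII
    dataset = deduplicate_dataset(dataset, text_column)
    dataset = remove_noise(dataset, text_column)
    dataset = filter_short_texts(dataset, text_column, min_length=10)
    dataset = filter_long_texts(dataset, text_column, max_length=10000)
    dataset = anonymize_dataset(dataset, text_column)
    return dataset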

Data Formats #

Common Data Formats #

text
1. Text classification format
{
    "text": "This is an article about technology",
    "label": "tech"
}

2. Question-answering format
{
    "question": "What is machine learning?",
    "answer": "Machine learning is a branch of artificial intelligence..."
}

3. Conversation format
{
    "conversations": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hello! How can I help you?"}
    ]
}

4. Instruction format
{
    "instruction": "Translate the following sentence into English",
    "input": "你好,世界",
    "output": "Hello, World"
}

5. Preference format (RLHF/DPO)
{
    "prompt": "What is machine learning?",
    "chosen": "Machine learning is a branch of artificial intelligence...",
    "rejected": "Machine learning is when machines learn..."
}
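
Whichever format you pick, it pays to validate records before training rather than failing mid-run. A minimal sketch that checks each record for the fields its format requires (the field names follow the examples above; the helper itself is illustrative):

python
REQUIRED_FIELDS = {
    "classification": ["text", "label"],
    "qa": ["question", "answer"],
    "conversation": ["conversations"],
    "instruction": ["instruction", "input", "output"],
    "preference": ["prompt", "chosen", "rejected"],
}

def missing_fields(record, fmt):
    # Return the required fields absent from the record; an empty list means it is valid
    return [field for field in REQUIRED_FIELDS[fmt] if field not in record]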

Format Conversion #

python
def convert_to_instruction_format(dataset, instruction_template):
    def convert_example(example):
        instruction = instruction_template.format(**example)
        return {
            "instruction": instruction,
            "input": "",
            "output": example.get("output", "")
        }
    
    return dataset.map(convert_example)

def convert_to_conversation_format(dataset):
    def convert_example(example):
        return {
            "conversations": [
                {"role": "user", "content": example["question"]},
                {"role": "assistant", "content": example["answer"]}
            ]
        }
    
    return dataset.map(convert_example)

Dataset Splitting #

python
from datasets import DatasetDict

def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42):
    train_test = dataset.train_test_split(
        test_size=val_ratio + test_ratio,
        seed=seed
    )
    
    val_test = train_test['test'].train_test_split(
        test_size=test_ratio / (val_ratio + test_ratio),
        seed=seed
    )
    
    return DatasetDict({
        'train': train_test['train'],
        'validation': val_test['train'],
        'test': val_test['test']
    })

Data Quality Assessment #

Quality Dimensions #

text
Data quality dimensions:
├── Completeness
│   ├── No missing values
│   ├── All fields present
│   └── Correct format
│
├── Accuracy
│   ├── Correct labels
│   ├── Factually accurate content
│   └── Logically sound
│
├── Consistency
│   ├── Uniform annotation standard
│   ├── Uniform format
│   └── Uniform style
│
├── Diversity (a rough check is sketched after this list)
│   ├── Covers a wide range of cases
│   ├── Avoids duplication
│   └── Balanced distribution
│
└── Relevance
    ├── Relevant to the task
    ├── High information density
    └── Free of noise
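
The checks in the next subsection cover counts and length statistics; diversity is harder to eyeball. One rough proxy is the distinct n-gram ratio, sketched below (whitespace tokenization is an assumption here; Chinese text should be tokenized first):

python
def distinct_ngram_ratio(texts, n=2):
    # Ratio of unique n-grams to total n-grams across the corpus;
    # values near 0 indicate heavily repetitive data.
    total, unique = 0, set()
    for text in texts:
        tokens = text.split()  # assumes whitespace-separated tokens
        ngrams = list(zip(*(tokens[i:] for i in range(n))))
        total += len(ngrams)
        unique.update(ngrams)
    return len(unique) / total if total else 0.0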

Quality Checks #

python
def check_data_quality(dataset):
    stats = {
        'total_samples': len(dataset),
        'avg_text_length': sum(len(example['text']) for example in dataset) / len(dataset),
        'min_text_length': min(len(example['text']) for example in dataset),
        'max_text_length': max(len(example['text']) for example in dataset),
    }
    
    if 'label' in dataset.column_names:
        from collections import Counter
        label_counts = Counter(example['label'] for example in dataset)
        stats['label_distribution'] = dict(label_counts)
    
    return stats

def detect_anomalies(dataset, text_column='text'):
    anomalies = []
    
    for idx, example in enumerate(dataset):
        text = example[text_column]
        
        if len(text) < 5:
            anomalies.append({'index': idx, 'type': 'too_short', 'text': text})
        elif len(text) > 10000:
            anomalies.append({'index': idx, 'type': 'too_long', 'text': text[:100]})
        elif text.count('?') > 10:
            anomalies.append({'index': idx, 'type': 'many_questions', 'text': text[:100]})
    
    return anomalies

Data Augmentation #

Basic Augmentation Methods #

python
import nlpaug.augmenter.word as naw

def synonym_replacement(text, n=2):
    # Replace random words with WordNet synonyms (English text only)
    aug = naw.SynonymAug(aug_src='wordnet')
    return aug.augment(text, n=n)  # recent nlpaug versions return a list of strings

def random_swap(text, n=2):
    # Randomly swap neighbouring words
    aug = naw.RandomWordAug(action='swap')
    return aug.augment(text, n=n)

def random_delete(text, p=0.1):
    # Randomly delete words with probability p
    aug = naw.RandomWordAug(action='delete', aug_p=p)
    return aug.augment(text)

def back_translation(text, src_lang='zh', tgt_lang='en'):
    from transformers import MarianMTModel, MarianTokenizer

    # Translate to the pivot language and back; note that both models are
    # reloaded on every call, so cache them when augmenting a large dataset.
    model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    
    translated = model.generate(**tokenizer(text, return_tensors="pt", padding=True))
    translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
    
    model_name_back = f'Helsinki-NLP/opus-mt-{tgt_lang}-{src_lang}'
    tokenizer_back = MarianTokenizer.from_pretrained(model_name_back)
    model_back = MarianMTModel.from_pretrained(model_name_back)
    
    back_translated = model_back.generate(**tokenizer_back(translated_text, return_tensors="pt", padding=True))
    return tokenizer_back.decode(back_translated[0], skip_special_tokens=True)

LLM-Assisted Augmentation #

python
from openai import OpenAI
from datasets import Dataset

def llm_augmentation(text, task_description, model="gpt-4"):
    client = OpenAI()
    
    prompt = f"""
    Task: {task_description}

    Original text: {text}

    Please produce a version that keeps the meaning but is phrased differently:
    """
    
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7
    )
    
    return response.choices[0].message.content

def generate_variations(dataset, n_variations=2):
    augmented_data = []
    
    for example in dataset:
        augmented_data.append(example)
        
        for _ in range(n_variations):
            variation = llm_augmentation(
                example['text'],
                "改写以下文本,保持语义不变"
            )
            augmented_data.append({
                'text': variation,
                'label': example.get('label')
            })
    
    return Dataset.from_list(augmented_data)

Data Balancing #

Class Balancing #

python
from collections import Counter
import random

from datasets import Dataset

def balance_dataset(dataset, label_column='label', strategy='oversample'):
    label_counts = Counter(example[label_column] for example in dataset)
    max_count = max(label_counts.values())

    if strategy == 'oversample':
        # Duplicate minority-class samples until every class reaches max_count
        balanced_data = []
        for label in label_counts:
            label_data = [ex for ex in dataset if ex[label_column] == label]
            while len(label_data) < max_count:
                label_data.extend(random.sample(label_data, min(len(label_data), max_count - len(label_data))))
            balanced_data.extend(label_data)

        random.shuffle(balanced_data)
        return Dataset.from_list(balanced_data)

    elif strategy == 'undersample':
        # Downsample every class to the size of the smallest class
        min_count = min(label_counts.values())
        balanced_data = []
        for label in label_counts:
            label_data = [ex for ex in dataset if ex[label_column] == label]
            balanced_data.extend(random.sample(label_data, min_count))

        random.shuffle(balanced_data)
        return Dataset.from_list(balanced_data)

Length Balancing #

python
import random

from datasets import Dataset

def balance_by_length(dataset, text_column='text', n_bins=5):
    lengths = [len(example[text_column]) for example in dataset]
    min_len, max_len = min(lengths), max(lengths)
    # Guard against all texts having the same length (bin_size of 0)
    bin_size = max((max_len - min_len) / n_bins, 1)

    # Assign each example to a length bucket
    bins = [[] for _ in range(n_bins)]
    for example in dataset:
        length = len(example[text_column])
        bin_idx = min(int((length - min_len) / bin_size), n_bins - 1)
        bins[bin_idx].append(example)

    # Oversample every non-empty bucket up to the size of the largest one
    max_bin_size = max(len(b) for b in bins)
    balanced_data = []

    for bin_data in bins:
        if not bin_data:
            continue
        while len(bin_data) < max_bin_size:
            bin_data.extend(random.sample(bin_data, min(len(bin_data), max_bin_size - len(bin_data))))
        balanced_data.extend(bin_data)

    random.shuffle(balanced_data)
    return Dataset.from_list(balanced_data)

Data Preprocessing #

Tokenization #

python
from transformers import AutoTokenizer

def tokenize_dataset(dataset, tokenizer_name, text_column='text', max_length=512):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    
    def tokenize_function(examples):
        return tokenizer(
            examples[text_column],
            padding='max_length',
            truncation=True,
            max_length=max_length
        )
    
    return dataset.map(tokenize_function, batched=True)
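
Typical usage, with a concrete tokenizer name given only as an example; swap in the tokenizer of your own base model.

python
# bert-base-chinese is only an example; use your base model's tokenizer
tokenized = tokenize_dataset(dataset, "bert-base-chinese", text_column="text", max_length=512)
print(tokenized[0]["input_ids"][:10])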

Text Normalization #

python
import re
import unicodedata

def normalize_text(text):
    # NFKC folds full-width characters, then lowercase and collapse whitespace
    text = unicodedata.normalize('NFKC', text)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)

    return text.strip()

def normalize_dataset(dataset, text_column='text'):
    def normalize_example(example):
        example[text_column] = normalize_text(example[text_column])
        return example
    
    return dataset.map(normalize_example)

Data Version Management #

python
import json
import hashlib
from datetime import datetime

def save_dataset_version(dataset, output_dir, version_name, metadata=None):
    dataset.save_to_disk(f"{output_dir}/{version_name}")
    
    manifest = {
        'version': version_name,
        'created_at': datetime.now().isoformat(),
        'num_samples': len(dataset),
        'columns': dataset.column_names,
        'metadata': metadata or {}
    }
    
    with open(f"{output_dir}/{version_name}/manifest.json", 'w') as f:
        json.dump(manifest, f, indent=2)

def compute_data_hash(dataset):
    data_str = json.dumps([dict(example) for example in dataset], sort_keys=True)
    return hashlib.sha256(data_str.encode()).hexdigest()
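
A usage sketch combining the two helpers: store the content hash in the manifest metadata so a version can later be verified against the data it was built from (the output path and version name are examples).

python
version_hash = compute_data_hash(dataset)
save_dataset_version(
    dataset,
    output_dir="data/versions",   # example path
    version_name="v1.0",
    metadata={"sha256": version_hash, "note": "cleaned + deduplicated"}
)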

Next Steps #

Now that you have the core data-preparation skills, continue to Model Selection to pick a suitable base model!

Last updated: 2026-04-05