PyTorch 数据加载 #

数据加载概述 #

在深度学习中,高效的数据加载对于训练至关重要。PyTorch 提供了 Dataset 和 DataLoader 两个核心类来处理数据。

text
┌─────────────────────────────────────────────────────────────┐
│                    数据加载流程                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   原始数据                                                   │
│      │                                                      │
│      ▼                                                      │
│   Dataset                                                   │
│   (定义如何获取单个样本)                                     │
│      │                                                      │
│      ▼                                                      │
│   DataLoader                                                │
│   (批量加载、打乱、并行)                                     │
│      │                                                      │
│      ▼                                                      │
│   训练循环                                                   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Dataset 类 #

Dataset 基本概念 #

Dataset 是一个抽象基类。自定义的 map-style 数据集通常需要实现以下两个方法:

  • __len__: 返回数据集大小
  • __getitem__: 根据索引获取单个样本
python
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset that pairs ``data[i]`` with ``labels[i]``.

    Args:
        data: Indexable container of samples (e.g. a tensor).
        labels: Indexable container of labels, aligned with ``data``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # The dataset size is the number of samples.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

import torch

data = torch.randn(100, 10)           # 100 samples with 10 features each
labels = torch.randint(0, 2, (100,))  # binary labels, one per sample

dataset = CustomDataset(data, labels)
print(f"数据集大小: {len(dataset)}")
first_sample = dataset[0]
print(f"第一个样本: {first_sample}")

图像数据集示例 #

python
import os
from PIL import Image
from torch.utils.data import Dataset
import torch

class ImageDataset(Dataset):
    """Image-classification dataset stored as ``root_dir/<class_name>/<image file>``.

    Every sub-directory of ``root_dir`` is one class. Class indices follow the
    sorted sub-directory names (exposed as ``self.classes``). Files placed
    directly in ``root_dir`` are ignored and do not consume a class index.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # Keep only directories *before* enumerating. Enumerating all entries
        # would let a stray file (README, .DS_Store, ...) shift every label.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir order is arbitrary; sort for a reproducible sample order.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                if os.path.isfile(img_path):  # skip nested sub-directories
                    self.images.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # The context manager closes the file handle right away. convert()
        # returns a new, fully loaded image that no longer needs the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

文本数据集示例 #

python
from torch.utils.data import Dataset
import torch

class TextDataset(Dataset):
    """Dataset of whitespace-tokenized texts encoded as fixed-length id tensors.

    Each text is lower-cased and split on whitespace. Tokens are mapped through
    ``vocab``, with unknown tokens mapped to ``vocab['<unk>']``. The result is
    right-padded with ``vocab['<pad>']`` or truncated to exactly
    ``max_length`` ids.

    Args:
        texts: Sequence of raw strings.
        labels: Sequence of labels aligned with ``texts``.
        vocab: Mapping from token to integer id. It must contain ``'<unk>'``
            and ``'<pad>'`` whenever those entries are needed.
        max_length: Length of every returned id sequence.
    """

    def __init__(self, texts, labels, vocab, max_length=100):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text, target = self.texts[idx], self.labels[idx]

        ids = [
            self.vocab.get(word, self.vocab['<unk>'])
            for word in raw_text.lower().split()
        ]

        missing = self.max_length - len(ids)
        if missing > 0:
            ids = ids + [self.vocab['<pad>']] * missing
        else:
            ids = ids[:self.max_length]

        return torch.tensor(ids), torch.tensor(target)

DataLoader 类 #

DataLoader 基本用法 #

python
from torch.utils.data import DataLoader, TensorDataset
import torch

data = torch.randn(1000, 10)
labels = torch.randint(0, 2, (1000,))
dataset = TensorDataset(data, labels)

# Shuffle every epoch, 32 samples per batch, load in the main process.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

# Peek at a single batch only.
batch_data, batch_labels = next(iter(dataloader))
print(f"Batch data shape: {batch_data.shape}")
print(f"Batch labels shape: {batch_labels.shape}")

DataLoader 参数详解 #

text
┌─────────────────────────────────────────────────────────────┐
│                    DataLoader 参数                           │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  dataset: Dataset                                           │
│  - 数据集对象                                               │
│                                                             │
│  batch_size: int                                            │
│  - 每批样本数量                                             │
│  - 典型值: 32, 64, 128, 256                                 │
│                                                             │
│  shuffle: bool                                              │
│  - 是否打乱数据                                             │
│  - 训练时通常为 True                                        │
│                                                             │
│  num_workers: int                                           │
│  - 数据加载进程数                                           │
│  - 0: 主进程加载                                            │
│  - >0: 多进程并行加载                                       │
│                                                             │
│  pin_memory: bool                                           │
│  - 是否使用页锁定(pinned)内存                               │
│  - GPU 训练时建议 True                                      │
│                                                             │
│  drop_last: bool                                            │
│  - 是否丢弃最后不完整的批次                                 │
│                                                             │
│  collate_fn: callable                                       │
│  - 自定义批次整理函数                                       │
│                                                             │
└─────────────────────────────────────────────────────────────┘
python
from torch.utils.data import DataLoader

# Typical training-time configuration: 4 worker processes, pinned host
# memory, and the trailing partial batch dropped.
loader_options = {
    "batch_size": 64,
    "shuffle": True,
    "num_workers": 4,
    "pin_memory": True,
    "drop_last": True,
}
dataloader = DataLoader(dataset, **loader_options)

多进程加载 #

python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.randn(10000, 10)
labels = torch.randint(0, 2, (10000,))
dataset = TensorDataset(data, labels)


def benchmark_loader(dataset, num_workers, batch_size=64):
    """Iterate once over ``dataset`` and return ``(elapsed_seconds, num_batches)``.

    The measured time includes DataLoader construction and, for
    ``num_workers > 0``, worker start-up. Together these are the cost of one pass.
    """
    start = time.perf_counter()  # monotonic, high-resolution clock for timing
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    num_batches = sum(1 for _ in loader)
    return time.perf_counter() - start, num_batches


# On platforms that start workers with "spawn" (Windows, macOS), each worker
# re-imports this module. Code that iterates a multi-worker DataLoader must
# therefore be guarded, otherwise worker start-up fails.
# NOTE: for small in-memory tensors, worker start-up and inter-process transfer
# may make num_workers > 0 slower. The benefit shows when __getitem__ is
# expensive (disk I/O, decoding, augmentation).
if __name__ == "__main__":
    elapsed, _ = benchmark_loader(dataset, num_workers=0)
    print(f"单进程时间: {elapsed:.2f}s")

    elapsed, _ = benchmark_loader(dataset, num_workers=4)
    print(f"多进程时间: {elapsed:.2f}s")

自定义 collate_fn #

python
from torch.utils.data import DataLoader, Dataset
import torch

class VariableLengthDataset(Dataset):
    """Toy dataset of four 1-D random tensors with lengths 5, 10, 3 and 8.

    Used to demonstrate a custom ``collate_fn`` for variable-length samples.
    """

    def __init__(self):
        # Drawn in this order so the random values match sequential randn calls.
        self.data = [torch.randn(n) for n in (5, 10, 3, 8)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch, padding_value=0.0):
    """Pad variable-length samples into one batch tensor.

    Args:
        batch: List of tensors whose first dimension is the (variable) length.
            Any trailing dimensions must match.
        padding_value: Value used to fill the padded positions (default 0).

    Returns:
        ``(padded, lengths)``. ``padded`` has shape ``(B, max_len, *)`` with the
        samples' dtype and device. ``lengths`` is an int64 tensor of the
        original lengths.
    """
    from torch.nn.utils.rnn import pad_sequence

    lengths = torch.tensor([len(item) for item in batch])
    # pad_sequence keeps the samples' dtype/device. Concatenating a default
    # float32 torch.zeros instead would silently turn integer token ids into
    # floats and fail for CUDA tensors.
    padded = pad_sequence(batch, batch_first=True, padding_value=padding_value)
    return padded, lengths

# Batches of two variable-length samples, padded to a common length by collate_fn.
dataset = VariableLengthDataset()
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=2)

for batch, lengths in dataloader:
    print(f"Batch shape: {batch.shape}")  # (2, longest sample in this batch)
    print(f"Lengths: {lengths}")

数据预处理 #

torchvision.transforms #

python
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize(256),      # scale so the shorter side is 256
    transforms.CenterCrop(224),  # central 224x224 patch
    transforms.ToTensor(),       # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Normalize above uses 3-channel statistics, so force RGB. Grayscale, palette
# or RGBA files would otherwise fail or be mis-normalized. The with-block
# closes the file handle once the converted copy has been made.
with Image.open('example.jpg') as img:
    image = img.convert('RGB')
tensor = transform(image)
print(tensor.shape)

常用图像变换 #

python
from torchvision import transforms

# Training-time augmentation steps, applied in list order to a PIL image.
# NOTE(review): Resize((256, 256)) forces a square shape, distorting the aspect
# ratio, before RandomResizedCrop, which already crops and rescales to 224x224
# by itself. Consider whether the Resize step is wanted.
augmentations = [
    transforms.Resize(size=(256, 256)),
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(augmentations)

自定义变换 #

python
from torchvision import transforms
import torch

class AddGaussianNoise:
    """Transform that adds i.i.d. Gaussian noise with the given mean/std to a tensor.

    Args:
        mean: Mean of the added noise.
        std: Standard deviation of the added noise.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Draw the noise on the input's device so the transform also works for
        # CUDA tensors. A CPU-only torch.randn would raise a device mismatch.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

class Cutout:
    """Transform that zeroes out ``n_holes`` random square patches of an image tensor.

    Each hole is centred on a uniformly random pixel ``(y, x)`` and covers rows
    ``[y - length // 2, y + length // 2)`` and the same column range around
    ``x``, clipped at the image border. It accepts tensors shaped ``(..., H, W)``,
    e.g. ``(C, H, W)``, ``(H, W)`` or a batch ``(N, C, H, W)``. All leading
    dimensions share the same holes.

    Args:
        n_holes: Number of patches to cut out.
        length: Side length of each square patch in pixels.
    """

    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.shape[-2], img.shape[-1]
        # Build the mask on the image's device. float32 keeps the same dtype
        # promotion for the final multiplication as before.
        mask = torch.ones((h, w), dtype=torch.float32, device=img.device)
        half = self.length // 2

        for _ in range(self.n_holes):
            y = torch.randint(h, (1,)).item()
            x = torch.randint(w, (1,)).item()

            y1, y2 = max(0, y - half), min(h, y + half)
            x1, x2 = max(0, x - half), min(w, x + half)
            mask[y1:y2, x1:x2] = 0

        # An (H, W) mask broadcasts over every leading dimension of img.
        return img * mask

# AddGaussianNoise and Cutout call tensor methods, so they must come after ToTensor().
transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(mean=0.0, std=0.1),
    Cutout(n_holes=1, length=16),
])

内置数据集 #

torchvision 数据集 #

python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Normalize with the single-channel MNIST mean (0.1307) and std (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits share the download location and preprocessing.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training data is shuffled; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")

CIFAR-10 数据集 #

python
from torchvision import datasets, transforms

# Per-channel normalisation shared by the train and test pipelines.
# NOTE(review): this std triple is the one widely copied between CIFAR-10
# examples. The per-channel std over the training set is often quoted as about
# (0.247, 0.243, 0.261) instead; confirm which values you intend to use.
cifar10_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training: random 32x32 crops from a 4-pixel padded image, plus horizontal flips.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar10_normalize,
])

# Evaluation: deterministic, tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    cifar10_normalize,
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

ImageNet 数据集 #

python
from torchvision import datasets, transforms

imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])

# Standard training augmentation: random resized 224 crop + horizontal flip.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    imagenet_normalize,
])

# ImageFolder reads one sub-directory per class under `root`.
dataset = datasets.ImageFolder(root='./imagenet/train', transform=transform)

class_names = dataset.classes
print(f"类别数: {len(class_names)}")
print(f"类别名: {class_names[:10]}")

torchtext 数据集 #

python
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# NOTE(review): torchtext development has been halted upstream; confirm the
# installed version still provides AG_NEWS and get_tokenizer.
tokenizer = get_tokenizer('basic_english')

train_iter = AG_NEWS(split='train')

def yield_tokens(data_iter):
    """Lazily yield ``tokenizer(text)`` for every ``(label, text)`` pair in ``data_iter``."""
    yield from (tokenizer(text) for _label, text in data_iter)

# '<unk>' is registered as a special token and then used as the default index,
# so out-of-vocabulary lookups return it.
# NOTE(review): depending on the torchtext version, train_iter may be a
# one-shot iterator; re-create it before iterating the data again.
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
unk_index = vocab['<unk>']
vocab.set_default_index(unk_index)

print(f"词汇表大小: {len(vocab)}")

数据集划分 #

随机划分 #

python
from torch.utils.data import random_split
import torch
from torch.utils.data import TensorDataset

data = torch.randn(1000, 10)
labels = torch.randint(0, 2, (1000,))
dataset = TensorDataset(data, labels)

# 80% / 10% / 10%; the test split absorbs any rounding remainder.
n_total = len(dataset)
train_size = int(0.8 * n_total)
val_size = int(0.1 * n_total)
test_size = n_total - (train_size + val_size)

# A seeded generator makes the split reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

print(f"训练集: {len(train_dataset)}")
print(f"验证集: {len(val_dataset)}")
print(f"测试集: {len(test_dataset)}")

分层划分 #

python
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split

class CustomDataset(Dataset):
    """Dataset built from array-likes, stored as float32 features and int64 labels.

    Args:
        data: Array-like of features (e.g. a NumPy array), copied to float32.
        labels: Array-like of integer class labels, copied to int64.
    """

    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features = self.data[idx]
        target = self.labels[idx]
        return features, target

data = np.random.randn(1000, 10)
labels = np.random.randint(0, 5, 1000)  # 5 classes

# stratify=labels makes the split preserve each class's share in both parts.
train_data, test_data, train_labels, test_labels = train_test_split(
    data,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

train_dataset, test_dataset = (
    CustomDataset(train_data, train_labels),
    CustomDataset(test_data, test_labels),
)

数据采样器 #

WeightedRandomSampler #

python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import numpy as np

data = torch.randn(1000, 10)
# Deliberately imbalanced: 900 negatives followed by 100 positives.
labels = torch.cat([torch.zeros(900), torch.ones(100)]).long()
dataset = TensorDataset(data, labels)

# Weight every sample by the inverse frequency of its class, so that in
# expectation both classes are drawn equally often.
class_counts = torch.bincount(labels)
class_weights = 1.0 / class_counts
sample_weights = class_weights.index_select(0, labels)

sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

batch_data, batch_labels = next(iter(dataloader))
print(f"正样本比例: {batch_labels.float().mean():.2f}")

自定义采样器 #

python
from torch.utils.data import Sampler
import torch

class BalancedBatchSampler(Sampler):
    """Batch sampler that yields class-balanced lists of dataset indices.

    Every batch holds exactly ``batch_size`` indices. Classes are visited
    round-robin in sorted label order, and for each slot a random index of that
    class is drawn with replacement. Each class therefore appears
    ``batch_size // num_classes`` or one more times per batch. Because it
    yields whole batches, use it as ``DataLoader(dataset, batch_sampler=...)``.

    Args:
        labels: 1-D sequence or tensor of integer class labels, one per sample.
        batch_size: Number of indices per yielded batch.
    """

    def __init__(self, labels, batch_size):
        # as_tensor avoids the copy/warning torch.tensor() gives for tensor input.
        self.labels = torch.as_tensor(labels)
        self.batch_size = batch_size
        # Use the actual label values so non-contiguous labels (e.g. {0, 3, 7})
        # work; iterating range(num_classes) would look up classes that don't exist.
        self.classes = torch.unique(self.labels).tolist()
        self.num_classes = len(self.classes)
        self.class_indices = {
            cls: torch.where(self.labels == cls)[0] for cls in self.classes
        }

    def __iter__(self):
        for _ in range(len(self)):
            batch = []
            for slot in range(self.batch_size):
                pool = self.class_indices[self.classes[slot % self.num_classes]]
                batch.append(pool[torch.randint(len(pool), (1,))].item())
            yield batch

    def __len__(self):
        # Number of batches per epoch, which is what DataLoader expects from a batch sampler.
        return len(self.labels) // self.batch_size

# 1000 samples over 5 classes; each batch of 10 holds 2 samples per class.
# The sampler yields whole batches, so pass it as DataLoader(..., batch_sampler=sampler).
labels = torch.randint(0, 5, (1000,))
sampler = BalancedBatchSampler(labels, batch_size=10)

数据加载最佳实践 #

内存映射 #

python
import torch
from torch.utils.data import Dataset
import numpy as np

class MemoryMappedDataset(Dataset):
    """Dataset backed by two ``.npy`` files opened as read-only memory maps.

    Samples are read from disk on demand instead of loading whole arrays into
    RAM. The memory maps are not pickled: they are reopened lazily in each
    process. DataLoader workers started with "spawn" therefore reopen the
    files instead of receiving a full in-memory copy of the arrays.

    Args:
        data_path: Path to the ``.npy`` file with the samples (first axis = sample index).
        labels_path: Path to the ``.npy`` file with one label per sample.
    """

    def __init__(self, data_path, labels_path):
        self.data_path = data_path
        self.labels_path = labels_path
        self._data = None
        self._labels = None
        # Open eagerly so missing or corrupt files fail at construction time.
        self._open()

    def _open(self):
        """Open the memory maps if this process has not done so yet."""
        if self._data is None:
            self._data = np.load(self.data_path, mmap_mode='r')
        if self._labels is None:
            self._labels = np.load(self.labels_path, mmap_mode='r')

    @property
    def data(self):
        self._open()
        return self._data

    @property
    def labels(self):
        self._open()
        return self._labels

    def __getstate__(self):
        # Pickling an np.memmap serialises its entire contents; drop the maps
        # and let the receiving process reopen the files lazily.
        state = self.__dict__.copy()
        state['_data'] = None
        state['_labels'] = None
        return state

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # copy() yields a writable array; the mmap_mode='r' map is read-only.
        return torch.from_numpy(self.data[idx].copy()), torch.tensor(self.labels[idx])

缓存机制 #

python
from torch.utils.data import Dataset
from functools import lru_cache

class CachedDataset(Dataset):
    """Dataset that memoises ``(data[idx], labels[idx])`` in a per-instance LRU cache.

    Each process (e.g. each DataLoader worker) holds its own cache. Cache hits
    return the same objects every time, so downstream code must not modify
    them in place.

    Args:
        data: Indexable container of samples.
        labels: Indexable container of labels aligned with ``data``.
        cache_size: Maximum number of cached items per instance (default 1000).
    """

    def __init__(self, data, labels, cache_size=1000):
        self.data = data
        self.labels = labels
        self.cache_size = cache_size
        self._build_cache()

    def _build_cache(self):
        # A class-level @lru_cache on __getitem__ would key on `self`. It would
        # keep every instance alive for the program's lifetime and share one
        # size budget across all instances.
        self._cached_get = lru_cache(maxsize=self.cache_size)(self._load)

    def _load(self, idx):
        return self.data[idx], self.labels[idx]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self._cached_get(idx)

    def __getstate__(self):
        # An lru_cache wrapper around a bound method cannot be pickled, and
        # pickling is needed for DataLoader workers started with "spawn".
        # Drop it and rebuild an empty cache on the other side.
        state = self.__dict__.copy()
        del state['_cached_get']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._build_cache()

预取数据 #

python
import torch
from torch.utils.data import DataLoader

class PrefetchLoader:
    """Wrap a DataLoader and move each batch to ``device``, one batch ahead on CUDA.

    On a CUDA device, the host-to-device copy of the next batch is issued on a
    side stream while the caller still works on the current batch, so the copy
    can overlap with compute. The wrapped DataLoader should use
    ``pin_memory=True``; non_blocking copies from pageable memory are not
    asynchronous. On other devices batches are simply moved with ``.to()``.

    Args:
        dataloader: Iterable yielding tuples/lists of batch items.
        device: Target ``torch.device`` (or anything ``torch.device`` accepts).
    """

    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = torch.device(device)

    def _to_device(self, batch):
        # Non-tensor entries (e.g. lists of file names) pass through unchanged.
        return tuple(
            x.to(self.device, non_blocking=True) if isinstance(x, torch.Tensor) else x
            for x in batch
        )

    def __iter__(self):
        if self.device.type != "cuda":
            for batch in self.dataloader:
                yield self._to_device(batch)
            return

        copy_stream = torch.cuda.Stream(device=self.device)
        ready = None  # batch whose copy the compute stream already waits for
        for batch in self.dataloader:
            # Start copying the upcoming batch on the side stream ...
            with torch.cuda.stream(copy_stream):
                upcoming = self._to_device(batch)
            # ... while the caller works on the previous one.
            if ready is not None:
                yield ready
            compute_stream = torch.cuda.current_stream(self.device)
            compute_stream.wait_stream(copy_stream)
            for x in upcoming:
                if isinstance(x, torch.Tensor):
                    # The memory was allocated on copy_stream. Record its use on
                    # the compute stream so the caching allocator does not reuse it early.
                    x.record_stream(compute_stream)
            ready = upcoming
        if ready is not None:
            yield ready

    def __len__(self):
        return len(self.dataloader)

# Fall back to the CPU when no GPU is available, instead of failing on the first .to().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pinned host memory only speeds up copies to a CUDA device.
dataloader = DataLoader(dataset, batch_size=64, pin_memory=device.type == "cuda")
prefetch_loader = PrefetchLoader(dataloader, device)

for data, labels in prefetch_loader:
    pass

完整示例 #

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from torch.utils.data import Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Train and validation need different transforms. random_split() returns
# Subsets that share ONE underlying dataset. Setting
# `train_dataset.dataset.transform` and then `val_dataset.dataset.transform`
# overwrites the same attribute, so both splits would end up with
# transform_test and training would run without augmentation. Instead, open
# the training files twice, each with its own transform, and split the
# *indices* once.
train_source = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
val_source = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)

train_size = int(0.9 * len(train_source))
# One seeded permutation: the split is reproducible, and the two slices of it
# never overlap.
permutation = torch.randperm(
    len(train_source), generator=torch.Generator().manual_seed(42)
).tolist()
train_dataset = Subset(train_source, permutation[:train_size])
val_dataset = Subset(val_source, permutation[train_size:])

# Pinned host memory only helps when batches are copied to a CUDA device.
use_pinned_memory = device.type == "cuda"

# NOTE: with num_workers > 0, code that iterates these loaders (the training
# loop) should run under `if __name__ == "__main__":` on platforms that start
# workers with "spawn" (Windows, macOS).
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=use_pinned_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
    pin_memory=use_pinned_memory
)

class SimpleCNN(nn.Module):
    """Small two-convolution CNN for 3x32x32 inputs with 10 output logits.

    Shapes for a (N, 3, 32, 32) input: conv1+pool -> (N, 32, 16, 16),
    conv2+pool -> (N, 64, 8, 8), flatten -> (N, 4096), fc1 -> (N, 512),
    fc2 -> (N, 10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # flatten(x, 1) keeps the batch dimension. view(-1, 64 * 8 * 8) would
        # silently fold extra spatial positions of non-32x32 inputs into the
        # batch; this way fc1 raises a clear shape error instead.
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)

model = SimpleCNN().to(device)
# CrossEntropyLoss consumes the raw logits returned by SimpleCNN.forward.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    ``mean_loss`` is averaged per *sample*: each batch's loss is weighted by
    its size, so a smaller final batch does not skew the epoch loss. This
    assumes ``criterion`` returns the batch-mean loss (``reduction='mean'``).
    An empty ``dataloader`` returns ``(0.0, 0.0)``.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function taking ``(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size  # undo the per-batch mean
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        return 0.0, 0.0
    return running_loss / total, 100.0 * correct / total

NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    print(f"Epoch {epoch+1}, Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")

下一步 #

现在你已经掌握了 PyTorch 数据加载的核心概念,接下来学习 卷积神经网络,开始构建图像分类模型!

最后更新:2026-03-29