PyTorch Data Loading #
Data Loading Overview #
In deep learning, efficient data loading is critical to training throughput. PyTorch provides two core classes for handling data: Dataset and DataLoader.
```text
┌─────────────────────────────────────────────────────────────┐
│                    Data Loading Pipeline                     │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│   Raw data                                                   │
│      │                                                       │
│      ▼                                                       │
│   Dataset                                                    │
│   (defines how to fetch a single sample)                     │
│      │                                                       │
│      ▼                                                       │
│   DataLoader                                                 │
│   (batching, shuffling, parallel loading)                    │
│      │                                                       │
│      ▼                                                       │
│   Training loop                                              │
│                                                              │
└─────────────────────────────────────────────────────────────┘
```
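The snippet below is a minimal sketch of this pipeline: it wraps in-memory tensors in a Dataset, batches them with a DataLoader, and iterates in a loop. The random data is purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical in-memory data: 100 samples with 10 features each
features = torch.randn(100, 10)
targets = torch.randint(0, 2, (100,))

dataset = TensorDataset(features, targets)                  # Dataset: per-sample access
loader = DataLoader(dataset, batch_size=16, shuffle=True)   # DataLoader: batching + shuffling

for batch_features, batch_targets in loader:                # training loop consumes batches
    print(batch_features.shape, batch_targets.shape)
    break
```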
The Dataset Class #
Dataset Basics #
Dataset is an abstract class; a custom dataset subclasses it and implements two methods:
- `__len__`: returns the size of the dataset
- `__getitem__`: returns a single sample given an index
```python
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

data = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))
dataset = CustomDataset(data, labels)
print(f"Dataset size: {len(dataset)}")
print(f"First sample: {dataset[0]}")
```
Image Dataset Example #
```python
import os
from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        # Each subdirectory of root_dir is treated as one class
        for label, class_name in enumerate(sorted(os.listdir(root_dir))):
            class_dir = os.path.join(root_dir, class_name)
            if os.path.isdir(class_dir):
                for img_name in os.listdir(class_dir):
                    self.images.append(os.path.join(class_dir, img_name))
                    self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
```
Text Dataset Example #
```python
import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, texts, labels, vocab, max_length=100):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        tokens = text.lower().split()
        token_ids = [self.vocab.get(token, self.vocab['<unk>']) for token in tokens]
        # Pad or truncate to a fixed length
        if len(token_ids) < self.max_length:
            token_ids = token_ids + [self.vocab['<pad>']] * (self.max_length - len(token_ids))
        else:
            token_ids = token_ids[:self.max_length]
        return torch.tensor(token_ids), torch.tensor(label)
```
The DataLoader Class #
Basic DataLoader Usage #
```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.randn(1000, 10)
labels = torch.randint(0, 2, (1000,))
dataset = TensorDataset(data, labels)

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0
)

for batch_data, batch_labels in dataloader:
    print(f"Batch data shape: {batch_data.shape}")
    print(f"Batch labels shape: {batch_labels.shape}")
    break
```
DataLoader Parameters #
```text
┌─────────────────────────────────────────────────────────────┐
│                    DataLoader Parameters                     │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  dataset: Dataset                                            │
│    - the dataset object to load from                         │
│                                                              │
│  batch_size: int                                             │
│    - number of samples per batch                             │
│    - typical values: 32, 64, 128, 256                        │
│                                                              │
│  shuffle: bool                                               │
│    - whether to shuffle the data each epoch                  │
│    - usually True for training                               │
│                                                              │
│  num_workers: int                                            │
│    - number of worker processes for loading                  │
│    - 0: load in the main process                             │
│    - >0: load in parallel with multiple processes            │
│                                                              │
│  pin_memory: bool                                            │
│    - whether to place batches in pinned (page-locked) memory │
│    - recommended True when training on GPU                   │
│                                                              │
│  drop_last: bool                                             │
│    - whether to drop the last incomplete batch               │
│                                                              │
│  collate_fn: callable                                        │
│    - custom function for assembling samples into a batch     │
│                                                              │
└─────────────────────────────────────────────────────────────┘
```
```python
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)
```
Multi-Process Loading #
```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.randn(10000, 10)
labels = torch.randint(0, 2, (10000,))
dataset = TensorDataset(data, labels)

# Single-process loading (num_workers=0)
start = time.time()
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
for _ in dataloader:
    pass
print(f"Single-process time: {time.time() - start:.2f}s")

# Multi-process loading (num_workers=4)
start = time.time()
dataloader = DataLoader(dataset, batch_size=64, num_workers=4)
for _ in dataloader:
    pass
print(f"Multi-process time: {time.time() - start:.2f}s")
```
Custom collate_fn #
```python
import torch
from torch.utils.data import DataLoader, Dataset

class VariableLengthDataset(Dataset):
    def __init__(self):
        self.data = [
            torch.randn(5),
            torch.randn(10),
            torch.randn(3),
            torch.randn(8)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    # Pad every sequence in the batch to the length of the longest one
    lengths = [len(item) for item in batch]
    max_len = max(lengths)
    padded_batch = []
    for item in batch:
        padding = torch.zeros(max_len - len(item))
        padded_item = torch.cat([item, padding])
        padded_batch.append(padded_item)
    return torch.stack(padded_batch), torch.tensor(lengths)

dataset = VariableLengthDataset()
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
for batch, lengths in dataloader:
    print(f"Batch shape: {batch.shape}")
    print(f"Lengths: {lengths}")
```
Data Preprocessing #
torchvision.transforms #
```python
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

image = Image.open('example.jpg')
tensor = transform(image)
print(tensor.shape)  # torch.Size([3, 224, 224])
```
Common Image Transforms #
```python
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
```
Custom Transforms #
```python
import torch
from torchvision import transforms

class AddGaussianNoise:
    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

class Cutout:
    """Zeroes out n_holes square patches of side `length` in a CHW tensor."""
    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = torch.ones((h, w), dtype=torch.float32)
        for _ in range(self.n_holes):
            y = torch.randint(h, (1,)).item()
            x = torch.randint(w, (1,)).item()
            y1 = max(0, y - self.length // 2)
            y2 = min(h, y + self.length // 2)
            x1 = max(0, x - self.length // 2)
            x2 = min(w, x + self.length // 2)
            mask[y1:y2, x1:x2] = 0
        return img * mask.unsqueeze(0)

transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.1),
    Cutout(n_holes=1, length=16)
])
```
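As a quick sanity check, the composed pipeline can be applied to any PIL image; the snippet below builds a random image purely for illustration.

```python
import numpy as np
from PIL import Image

# Hypothetical 64x64 RGB image built from random pixels, just to exercise the pipeline
img = Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
out = transform(img)
print(out.shape)  # torch.Size([3, 64, 64])
```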
Built-in Datasets #
torchvision Datasets #
```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")
```
CIFAR-10 Dataset #
```python
from torchvision import datasets, transforms

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)
```
ImageNet Dataset #
```python
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# ImageFolder treats each subdirectory of root as one class
dataset = datasets.ImageFolder(
    root='./imagenet/train',
    transform=transform
)

print(f"Number of classes: {len(dataset.classes)}")
print(f"Class names: {dataset.classes[:10]}")
```
torchtext Datasets #
```python
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])
print(f"Vocabulary size: {len(vocab)}")
```
Splitting Datasets #
Random Split #
```python
import torch
from torch.utils.data import TensorDataset, random_split

data = torch.randn(1000, 10)
labels = torch.randint(0, 2, (1000,))
dataset = TensorDataset(data, labels)

# 80% / 10% / 10% split
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"Training set: {len(train_dataset)}")
print(f"Validation set: {len(val_dataset)}")
print(f"Test set: {len(test_dataset)}")
```
Stratified Split #
```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

data = np.random.randn(1000, 10)
labels = np.random.randint(0, 5, 1000)

# stratify=labels keeps the class proportions identical in both splits
train_data, test_data, train_labels, test_labels = train_test_split(
    data, labels, test_size=0.2, stratify=labels, random_state=42
)

train_dataset = CustomDataset(train_data, train_labels)
test_dataset = CustomDataset(test_data, test_labels)
```
Data Samplers #
WeightedRandomSampler #
```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Imbalanced dataset: 900 negative samples, 100 positive samples
data = torch.randn(1000, 10)
labels = torch.cat([torch.zeros(900), torch.ones(100)]).long()
dataset = TensorDataset(data, labels)

# Weight each sample inversely to its class frequency
class_counts = torch.bincount(labels)
class_weights = 1.0 / class_counts
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)
for batch_data, batch_labels in dataloader:
    print(f"Positive sample ratio: {batch_labels.float().mean():.2f}")
    break
```
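For comparison, a plain shuffled loader over the same dataset yields batches that mirror the 9:1 imbalance (roughly 0.10 positives per batch), whereas the weighted sampler pushes that ratio toward 0.5:

```python
plain_loader = DataLoader(dataset, batch_size=64, shuffle=True)
_, plain_labels = next(iter(plain_loader))
print(f"Positive sample ratio without sampler: {plain_labels.float().mean():.2f}")
```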
Custom Samplers #
```python
import torch
from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    """Batch sampler that draws an equal number of samples from each class."""
    def __init__(self, labels, batch_size):
        self.labels = torch.as_tensor(labels)
        self.batch_size = batch_size
        self.num_classes = len(torch.unique(self.labels))
        self.class_indices = {}
        for cls in range(self.num_classes):
            self.class_indices[cls] = torch.where(self.labels == cls)[0]

    def __iter__(self):
        samples_per_class = self.batch_size // self.num_classes
        for _ in range(len(self)):
            batch = []
            for cls in range(self.num_classes):
                idx = torch.randint(len(self.class_indices[cls]), (samples_per_class,))
                batch.extend(self.class_indices[cls][idx].tolist())
            yield batch

    def __len__(self):
        # Number of batches per epoch (this sampler yields whole batches of indices)
        return len(self.labels) // self.batch_size

labels = torch.randint(0, 5, (1000,))
sampler = BalancedBatchSampler(labels, batch_size=10)
```
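Because `__iter__` yields complete lists of indices rather than single indices, this sampler should be passed to `DataLoader` through the `batch_sampler` argument (in which case `batch_size` and `shuffle` must not also be set). A minimal usage sketch, reusing the hypothetical labels above with random features:

```python
from torch.utils.data import DataLoader, TensorDataset

data = torch.randn(1000, 10)
dataset = TensorDataset(data, labels)   # labels defined above

loader = DataLoader(dataset, batch_sampler=sampler)
batch_data, batch_labels = next(iter(loader))
print(batch_labels)   # two samples from each of the five classes
```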
Data Loading Best Practices #
Memory Mapping #
```python
import numpy as np
import torch
from torch.utils.data import Dataset

class MemoryMappedDataset(Dataset):
    def __init__(self, data_path, labels_path):
        # mmap_mode='r' keeps the arrays on disk; only accessed slices are read into memory
        self.data = np.load(data_path, mmap_mode='r')
        self.labels = np.load(labels_path, mmap_mode='r')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.from_numpy(self.data[idx].copy()), torch.tensor(self.labels[idx])
```
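The dataset above expects .npy files on disk; the sketch below shows one way such files might be produced (the file names and shapes are placeholders for this example):

```python
import numpy as np

# Hypothetical arrays saved once, then memory-mapped at training time
np.save('features.npy', np.random.randn(100000, 128).astype(np.float32))
np.save('targets.npy', np.random.randint(0, 10, 100000).astype(np.int64))

dataset = MemoryMappedDataset('features.npy', 'targets.npy')
print(len(dataset), dataset[0][0].shape)
```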
Caching #
```python
from functools import lru_cache
from torch.utils.data import Dataset

class CachedDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    # Caches up to 1000 samples; note that each DataLoader worker process
    # keeps its own copy of the cache, so this helps most when num_workers=0
    @lru_cache(maxsize=1000)
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
```
Prefetching Data #
```python
import torch
from torch.utils.data import DataLoader

class PrefetchLoader:
    """Wraps a DataLoader and moves each batch to the target device as it is yielded."""
    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def __iter__(self):
        for batch in self.dataloader:
            # non_blocking=True lets the host-to-device copy overlap with compute
            # when the source tensors are in pinned memory
            yield tuple(x.to(self.device, non_blocking=True) for x in batch)

    def __len__(self):
        return len(self.dataloader)

device = torch.device("cuda")
dataloader = DataLoader(dataset, batch_size=64, pin_memory=True)
prefetch_loader = PrefetchLoader(dataloader, device)
for data, labels in prefetch_loader:
    pass
```
Complete Example #
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load the CIFAR-10 training split twice so the training and validation subsets
# can use different transforms (a single shared dataset object would end up with
# whichever transform was assigned last)
train_base = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
val_base = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)

train_size = int(0.9 * len(train_base))
indices = torch.randperm(len(train_base), generator=torch.Generator().manual_seed(42))
train_dataset = Subset(train_base, indices[:train_size].tolist())
val_dataset = Subset(val_base, indices[train_size:].tolist())

train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return running_loss / len(dataloader), 100. * correct / total

for epoch in range(10):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}, Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
```
Next Steps #
Now that you have the core concepts of PyTorch data loading, continue with Convolutional Neural Networks and start building image classification models!
Last updated: 2026-03-29